Commit 3d03e158 authored by Paul's avatar Paul
Browse files

IMprove implicit casting and add a method to extract a tensor view directly

parent 9a24414f
...@@ -59,6 +59,7 @@ struct raw_data : raw_data_base ...@@ -59,6 +59,7 @@ struct raw_data : raw_data_base
s.visit_type([&](auto as) { v(make_view(s, as.from(buffer))); }); s.visit_type([&](auto as) { v(make_view(s, as.from(buffer))); });
} }
/// Returns true if the raw data is only one element
bool single() const bool single() const
{ {
auto&& s = static_cast<const Derived&>(*this).get_shape(); auto&& s = static_cast<const Derived&>(*this).get_shape();
...@@ -86,17 +87,31 @@ struct raw_data : raw_data_base ...@@ -86,17 +87,31 @@ struct raw_data : raw_data_base
template <class T> template <class T>
operator T() operator T()
{ {
assert(self->single());
return self->template at<T>(); return self->template at<T>();
} }
template <class T> template <class T>
operator T*() operator T*()
{ {
// TODO: Check type using type = std::remove_cv_t<T>;
return reinterpret_cast<T*>(self->data()); assert((std::is_void<T>{} or std::is_same<char, type>{} or std::is_same<unsigned char, type>{} or self->get_shape().type() == rtg::shape::get_type<T>{}));
return reinterpret_cast<type*>(self->data());
} }
}; };
auto_cast get() const { return {static_cast<const Derived*>(this)}; } /// Implicit conversion of raw data pointer
auto_cast implicit() const { return {static_cast<const Derived*>(this)}; }
/// Get a tensor_view to the data
template<class T>
tensor_view<T> get() const
{
auto&& s = static_cast<const Derived&>(*this).get_shape();
auto&& buffer = static_cast<const Derived&>(*this).data();
if(s.type() != rtg::shape::get_type<T>{})
RTG_THROW("Incorrect data type for raw data");
return make_view(s, reinterpret_cast<T*>(buffer));
}
}; };
template <class T, template <class T,
......
...@@ -115,31 +115,31 @@ struct miopen_convolution ...@@ -115,31 +115,31 @@ struct miopen_convolution
float alpha = 1, beta = 0; float alpha = 1, beta = 0;
int algo_count; int algo_count;
miopenConvAlgoPerf_t perf; miopenConvAlgoPerf_t perf;
miopenFindConvolutionForwardAlgorithm(args[0].get(), miopenFindConvolutionForwardAlgorithm(args[0].implicit(),
x_desc.get(), x_desc.get(),
args[1].get(), args[1].implicit(),
w_desc.get(), w_desc.get(),
args[2].get(), args[2].implicit(),
cd.get(), cd.get(),
y_desc.get(), y_desc.get(),
args[3].get(), args[3].implicit(),
1, 1,
&algo_count, &algo_count,
&perf, &perf,
nullptr, nullptr,
0, 0,
false); false);
miopenConvolutionForward(args[0].get(), miopenConvolutionForward(args[0].implicit(),
&alpha, &alpha,
x_desc.get(), x_desc.get(),
args[1].get(), args[1].implicit(),
w_desc.get(), w_desc.get(),
args[2].get(), args[2].implicit(),
cd.get(), cd.get(),
perf.fwd_algo, perf.fwd_algo,
&beta, &beta,
y_desc.get(), y_desc.get(),
args[3].get(), args[3].implicit(),
nullptr, nullptr,
0); 0);
return args[3]; return args[3];
...@@ -161,14 +161,14 @@ struct miopen_relu ...@@ -161,14 +161,14 @@ struct miopen_relu
float alpha = 1, beta = 0; float alpha = 1, beta = 0;
auto x_desc = make_tensor(args[1].get_shape()); auto x_desc = make_tensor(args[1].get_shape());
auto y_desc = make_tensor(output_shape); auto y_desc = make_tensor(output_shape);
miopenActivationForward(args[0].get(), miopenActivationForward(args[0].implicit(),
ad.get(), ad.get(),
&alpha, &alpha,
x_desc.get(), x_desc.get(),
args[1].get(), args[1].implicit(),
&beta, &beta,
y_desc.get(), y_desc.get(),
args[2].get()); args[2].implicit());
return args[2]; return args[2];
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment