"vscode:/vscode.git/clone" did not exist on "3c2c5869ad719d41d87f6aca8a71e683ebcadc76"
Commit 3d03e158 authored by Paul's avatar Paul
Browse files

IMprove implicit casting and add a method to extract a tensor view directly

parent 9a24414f
......@@ -59,6 +59,7 @@ struct raw_data : raw_data_base
s.visit_type([&](auto as) { v(make_view(s, as.from(buffer))); });
}
/// Returns true if the raw data is only one element
bool single() const
{
auto&& s = static_cast<const Derived&>(*this).get_shape();
......@@ -86,17 +87,31 @@ struct raw_data : raw_data_base
template <class T>
operator T()
{
assert(self->single());
return self->template at<T>();
}
template <class T>
operator T*()
{
// TODO: Check type
return reinterpret_cast<T*>(self->data());
using type = std::remove_cv_t<T>;
assert((std::is_void<T>{} or std::is_same<char, type>{} or std::is_same<unsigned char, type>{} or self->get_shape().type() == rtg::shape::get_type<T>{}));
return reinterpret_cast<type*>(self->data());
}
};
auto_cast get() const { return {static_cast<const Derived*>(this)}; }
/// Implicit conversion of raw data pointer
auto_cast implicit() const { return {static_cast<const Derived*>(this)}; }
/// Get a tensor_view to the data
template<class T>
tensor_view<T> get() const
{
auto&& s = static_cast<const Derived&>(*this).get_shape();
auto&& buffer = static_cast<const Derived&>(*this).data();
if(s.type() != rtg::shape::get_type<T>{})
RTG_THROW("Incorrect data type for raw data");
return make_view(s, reinterpret_cast<T*>(buffer));
}
};
template <class T,
......
......@@ -115,31 +115,31 @@ struct miopen_convolution
float alpha = 1, beta = 0;
int algo_count;
miopenConvAlgoPerf_t perf;
miopenFindConvolutionForwardAlgorithm(args[0].get(),
miopenFindConvolutionForwardAlgorithm(args[0].implicit(),
x_desc.get(),
args[1].get(),
args[1].implicit(),
w_desc.get(),
args[2].get(),
args[2].implicit(),
cd.get(),
y_desc.get(),
args[3].get(),
args[3].implicit(),
1,
&algo_count,
&perf,
nullptr,
0,
false);
miopenConvolutionForward(args[0].get(),
miopenConvolutionForward(args[0].implicit(),
&alpha,
x_desc.get(),
args[1].get(),
args[1].implicit(),
w_desc.get(),
args[2].get(),
args[2].implicit(),
cd.get(),
perf.fwd_algo,
&beta,
y_desc.get(),
args[3].get(),
args[3].implicit(),
nullptr,
0);
return args[3];
......@@ -161,14 +161,14 @@ struct miopen_relu
float alpha = 1, beta = 0;
auto x_desc = make_tensor(args[1].get_shape());
auto y_desc = make_tensor(output_shape);
miopenActivationForward(args[0].get(),
miopenActivationForward(args[0].implicit(),
ad.get(),
&alpha,
x_desc.get(),
args[1].get(),
args[1].implicit(),
&beta,
y_desc.get(),
args[2].get());
args[2].implicit());
return args[2];
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment