Commit dc21cca1 authored by Paul's avatar Paul
Browse files

Formatting

parent f20f1990
...@@ -15,14 +15,13 @@ using convolution_descriptor = RTG_MANAGE_PTR(miopenConvolutionDescriptor_t, ...@@ -15,14 +15,13 @@ using convolution_descriptor = RTG_MANAGE_PTR(miopenConvolutionDescriptor_t,
using activation_descriptor = RTG_MANAGE_PTR(miopenActivationDescriptor_t, using activation_descriptor = RTG_MANAGE_PTR(miopenActivationDescriptor_t,
miopenDestroyActivationDescriptor); miopenDestroyActivationDescriptor);
template <class Result, class F, class... Ts>
template<class Result, class F, class... Ts>
Result make_obj(F f, Ts... xs) Result make_obj(F f, Ts... xs)
{ {
typename Result::pointer x = nullptr; typename Result::pointer x = nullptr;
auto status = f(&x, xs...); auto status = f(&x, xs...);
Result r{x}; Result r{x};
if (status != miopenStatusSuccess) if(status != miopenStatusSuccess)
RTG_THROW("MIOpen call failed"); RTG_THROW("MIOpen call failed");
return r; return r;
} }
...@@ -34,8 +33,10 @@ tensor_descriptor make_tensor(const rtg::shape& s) ...@@ -34,8 +33,10 @@ tensor_descriptor make_tensor(const rtg::shape& s)
std::vector<int> lens(s.lens().begin(), s.lens().end()); std::vector<int> lens(s.lens().begin(), s.lens().end());
std::vector<int> strides(s.strides().begin(), s.strides().end()); std::vector<int> strides(s.strides().begin(), s.strides().end());
miopenDataType_t d; miopenDataType_t d;
if(s.type() == shape::float_type) d = miopenFloat; if(s.type() == shape::float_type)
else RTG_THROW("Unsupported type"); d = miopenFloat;
else
RTG_THROW("Unsupported type");
miopenSetTensorDescriptor(t.get(), d, s.lens().size(), lens.data(), strides.data()); miopenSetTensorDescriptor(t.get(), d, s.lens().size(), lens.data(), strides.data());
return t; return t;
} }
...@@ -43,7 +44,14 @@ tensor_descriptor make_tensor(const rtg::shape& s) ...@@ -43,7 +44,14 @@ tensor_descriptor make_tensor(const rtg::shape& s)
convolution_descriptor make_conv(const rtg::convolution& op) convolution_descriptor make_conv(const rtg::convolution& op)
{ {
auto c = make_obj<convolution_descriptor>(&miopenCreateConvolutionDescriptor); auto c = make_obj<convolution_descriptor>(&miopenCreateConvolutionDescriptor);
miopenInitConvolutionDescriptor(c.get(), miopenConvolution, op.padding[0], op.padding[1], op.stride[0], op.stride[1], op.dilation[0], op.dilation[1]); miopenInitConvolutionDescriptor(c.get(),
miopenConvolution,
op.padding[0],
op.padding[1],
op.stride[0],
op.stride[1],
op.dilation[0],
op.dilation[1]);
return c; return c;
} }
...@@ -66,8 +74,33 @@ struct miopen_convolution ...@@ -66,8 +74,33 @@ struct miopen_convolution
int algo_count; int algo_count;
miopenConvAlgoPerf_t perf; miopenConvAlgoPerf_t perf;
miopenFindConvolutionForwardAlgorithm(args[0].data(), x_desc.get(), args[1].data(), w_desc, args[2].data(), cd.get(), y_desc, args[4].data(), 1, &algo_count, &perf, args[3].data(), args[3].get_shape().bytes(), false); miopenFindConvolutionForwardAlgorithm(args[0].data(),
miopenConvolutionForward(args[0].data(), &alpha, x_desc, args[1].data(), w_desc, args[2].data(), cd.get(), perf.fwd_algo, &beta, y_desc, args[4].data(), args[3].data(), args[3].get_shape().bytes()); x_desc.get(),
args[1].data(),
w_desc,
args[2].data(),
cd.get(),
y_desc,
args[4].data(),
1,
&algo_count,
&perf,
args[3].data(),
args[3].get_shape().bytes(),
false);
miopenConvolutionForward(args[0].data(),
&alpha,
x_desc,
args[1].data(),
w_desc,
args[2].data(),
cd.get(),
perf.fwd_algo,
&beta,
y_desc,
args[4].data(),
args[3].data(),
args[3].get_shape().bytes());
return result; return result;
} }
}; };
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment