Commit aa95521b authored by Paul's avatar Paul
Browse files

Formatting

parent 5f37fd2d
...@@ -12,8 +12,8 @@ using miopen_handle = RTG_MANAGE_PTR(miopenHandle_t, miopenDestroy); ...@@ -12,8 +12,8 @@ using miopen_handle = RTG_MANAGE_PTR(miopenHandle_t, miopenDestroy);
using tensor_descriptor = RTG_MANAGE_PTR(miopenTensorDescriptor_t, miopenDestroyTensorDescriptor); using tensor_descriptor = RTG_MANAGE_PTR(miopenTensorDescriptor_t, miopenDestroyTensorDescriptor);
using convolution_descriptor = RTG_MANAGE_PTR(miopenConvolutionDescriptor_t, using convolution_descriptor = RTG_MANAGE_PTR(miopenConvolutionDescriptor_t,
miopenDestroyConvolutionDescriptor); miopenDestroyConvolutionDescriptor);
using pooling_descriptor = RTG_MANAGE_PTR(miopenPoolingDescriptor_t, using pooling_descriptor = RTG_MANAGE_PTR(miopenPoolingDescriptor_t,
miopenDestroyPoolingDescriptor); miopenDestroyPoolingDescriptor);
using activation_descriptor = RTG_MANAGE_PTR(miopenActivationDescriptor_t, using activation_descriptor = RTG_MANAGE_PTR(miopenActivationDescriptor_t,
miopenDestroyActivationDescriptor); miopenDestroyActivationDescriptor);
...@@ -60,17 +60,19 @@ inline convolution_descriptor make_conv(const rtg::convolution& op) ...@@ -60,17 +60,19 @@ inline convolution_descriptor make_conv(const rtg::convolution& op)
inline pooling_descriptor make_pooling(const rtg::pooling& op) inline pooling_descriptor make_pooling(const rtg::pooling& op)
{ {
miopenPoolingMode_t mode; miopenPoolingMode_t mode;
if(op.mode == "max") mode = miopenPoolingMax; if(op.mode == "max")
else mode = miopenPoolingAverage; mode = miopenPoolingMax;
else
mode = miopenPoolingAverage;
auto p = make_obj<pooling_descriptor>(&miopenCreatePoolingDescriptor); auto p = make_obj<pooling_descriptor>(&miopenCreatePoolingDescriptor);
miopenSet2dPoolingDescriptor(p.get(), miopenSet2dPoolingDescriptor(p.get(),
mode, mode,
op.lengths[0], op.lengths[0],
op.lengths[1], op.lengths[1],
op.padding[0], op.padding[0],
op.padding[1], op.padding[1],
op.stride[0], op.stride[0],
op.stride[1]); op.stride[1]);
return p; return p;
} }
......
...@@ -78,16 +78,16 @@ struct miopen_pooling ...@@ -78,16 +78,16 @@ struct miopen_pooling
float alpha = 1, beta = 0; float alpha = 1, beta = 0;
miopenPoolingForward(args[0].implicit(), miopenPoolingForward(args[0].implicit(),
pd.get(), pd.get(),
&alpha, &alpha,
x_desc.get(), x_desc.get(),
args[1].implicit(), args[1].implicit(),
&beta, &beta,
y_desc.get(), y_desc.get(),
args[2].implicit(), args[2].implicit(),
false, false,
nullptr, nullptr,
0); 0);
return args[2]; return args[2];
} }
...@@ -180,11 +180,8 @@ struct miopen_apply ...@@ -180,11 +180,8 @@ struct miopen_apply
auto pd = make_pooling(op); auto pd = make_pooling(op);
auto output = insert_allocation(ins, ins->result); auto output = insert_allocation(ins, ins->result);
prog->replace_instruction(ins, prog->replace_instruction(
miopen_pooling{op, std::move(pd)}, ins, miopen_pooling{op, std::move(pd)}, handle, ins->arguments.at(0), output);
handle,
ins->arguments.at(0),
output);
} }
void apply_activation(instruction_ref ins) void apply_activation(instruction_ref ins)
......
...@@ -93,7 +93,8 @@ struct test_conv_pooling ...@@ -93,7 +93,8 @@ struct test_conv_pooling
} }
}; };
int main() { int main()
verify_program<test_conv_relu>(); {
verify_program<test_conv_pooling>(); verify_program<test_conv_relu>();
verify_program<test_conv_pooling>();
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment