Commit f02c2856 authored by umangyadav
Browse files

use workspace size from get_solution()

parent ed7973d1
...@@ -119,7 +119,6 @@ shape miopen_convolution::find(context& ctx, const shape& output_shape, std::vec ...@@ -119,7 +119,6 @@ shape miopen_convolution::find(context& ctx, const shape& output_shape, std::vec
false); false);
if(status != miopenStatusSuccess) if(status != miopenStatusSuccess)
MIGRAPHX_THROW("MIOpen Convolution: find convolution failed"); MIGRAPHX_THROW("MIOpen Convolution: find convolution failed");
algo = perf.fwd_algo;
size_t solution_count; size_t solution_count;
...@@ -146,8 +145,10 @@ shape miopen_convolution::find(context& ctx, const shape& output_shape, std::vec ...@@ -146,8 +145,10 @@ shape miopen_convolution::find(context& ctx, const shape& output_shape, std::vec
MIGRAPHX_THROW("MIOpen Convolution: get solution failed"); MIGRAPHX_THROW("MIOpen Convolution: get solution failed");
solution_id = solutions.front().solution_id; solution_id = solutions.front().solution_id;
algo = solutions.front().algorithm;
workspace_size = solutions.front().workspace_size;
return shape{shape::int8_type, {perf.memory}}; return shape{shape::int8_type, {workspace_size}};
} }
void miopen_convolution::finalize(context& ctx, void miopen_convolution::finalize(context& ctx,
......
...@@ -119,7 +119,6 @@ shape miopen_deconvolution::find(context& ctx, const shape& output_shape, std::v ...@@ -119,7 +119,6 @@ shape miopen_deconvolution::find(context& ctx, const shape& output_shape, std::v
false); false);
if(status != miopenStatusSuccess) if(status != miopenStatusSuccess)
MIGRAPHX_THROW("MIOpen Deconvolution: find convolution failed"); MIGRAPHX_THROW("MIOpen Deconvolution: find convolution failed");
algo = perf.fwd_algo;
size_t solution_count; size_t solution_count;
...@@ -146,8 +145,10 @@ shape miopen_deconvolution::find(context& ctx, const shape& output_shape, std::v ...@@ -146,8 +145,10 @@ shape miopen_deconvolution::find(context& ctx, const shape& output_shape, std::v
MIGRAPHX_THROW("MIOpen Deconvolution: get solution failed"); MIGRAPHX_THROW("MIOpen Deconvolution: get solution failed");
solution_id = solutions.front().solution_id; solution_id = solutions.front().solution_id;
algo = solutions.front().algorithm;
workspace_size = solutions.front().workspace_size;
return shape{shape::int8_type, {perf.memory}}; return shape{shape::int8_type, {workspace_size}};
} }
void miopen_deconvolution::finalize(context& ctx, void miopen_deconvolution::finalize(context& ctx,
......
...@@ -210,7 +210,7 @@ MIGRAPHX_PRED_MATCHER(fusable_conv, instruction_ref ins) ...@@ -210,7 +210,7 @@ MIGRAPHX_PRED_MATCHER(fusable_conv, instruction_ref ins)
auto conv = any_cast<miopen_convolution>(ins->get_operator()); auto conv = any_cast<miopen_convolution>(ins->get_operator());
if(conv.op.group > 1) if(conv.op.group > 1)
return false; return false;
if(wei.lens()[1] > 512 and conv.algo != miopenConvolutionFwdAlgoWinograd) if(wei.lens()[1] > 512 and conv.algo != miopenConvolutionAlgoWinograd)
return false; return false;
// Do not fuse non-symmetric input // Do not fuse non-symmetric input
...@@ -220,7 +220,7 @@ MIGRAPHX_PRED_MATCHER(fusable_conv, instruction_ref ins) ...@@ -220,7 +220,7 @@ MIGRAPHX_PRED_MATCHER(fusable_conv, instruction_ref ins)
auto op = conv.op; auto op = conv.op;
// Dont fuse winograd for non-3x3s since there is no fused windograd for those configs // Dont fuse winograd for non-3x3s since there is no fused windograd for those configs
if(conv.algo == miopenConvolutionFwdAlgoWinograd and wei.lens()[2] != 3 and if(conv.algo == miopenConvolutionAlgoWinograd and wei.lens()[2] != 3 and
wei.lens()[3] != 3 and contains({{1, 1}}, op.stride)) wei.lens()[3] != 3 and contains({{1, 1}}, op.stride))
return false; return false;
return contains({{0, 0, 0, 0}, {1, 1, 1, 1}, {2, 2, 2, 2}}, op.padding) and return contains({{0, 0, 0, 0}, {1, 1, 1, 1}, {2, 2, 2, 2}}, op.padding) and
......
...@@ -38,7 +38,7 @@ struct miopen_convolution ...@@ -38,7 +38,7 @@ struct miopen_convolution
{ {
op::convolution op; op::convolution op;
shared<convolution_descriptor> cd = nullptr; shared<convolution_descriptor> cd = nullptr;
miopenConvFwdAlgorithm_t algo{}; miopenConvAlgorithm_t algo{};
uint64_t solution_id = 0; uint64_t solution_id = 0;
template <class Self, class F> template <class Self, class F>
......
...@@ -38,7 +38,7 @@ struct miopen_deconvolution ...@@ -38,7 +38,7 @@ struct miopen_deconvolution
{ {
op::deconvolution op; op::deconvolution op;
shared<convolution_descriptor> cd; shared<convolution_descriptor> cd;
miopenConvFwdAlgorithm_t algo{}; miopenConvAlgorithm_t algo{};
uint64_t solution_id = 0; uint64_t solution_id = 0;
template <class Self, class F> template <class Self, class F>
......
...@@ -40,7 +40,7 @@ struct miopen_quant_convolution ...@@ -40,7 +40,7 @@ struct miopen_quant_convolution
op::quant_convolution op; op::quant_convolution op;
bool int8_x4_format = false; bool int8_x4_format = false;
shared<convolution_descriptor> cd; shared<convolution_descriptor> cd;
miopenConvFwdAlgorithm_t algo{}; miopenConvAlgorithm_t algo{};
uint64_t solution_id = 0; uint64_t solution_id = 0;
template <class Self, class F> template <class Self, class F>
......
...@@ -43,26 +43,23 @@ argument miopen_quant_convolution::compute(context& ctx, ...@@ -43,26 +43,23 @@ argument miopen_quant_convolution::compute(context& ctx,
auto w_desc = make_tensor(args[1].get_shape(), int8_x4_format); auto w_desc = make_tensor(args[1].get_shape(), int8_x4_format);
auto y_desc = make_tensor(output_shape); auto y_desc = make_tensor(output_shape);
float alpha = 1; if(solution_id == 0)
float beta = 0; MIGRAPHX_THROW("MIOpen Convolution: invalid solution ID");
auto status = miopenConvolutionForward(ctx.get_stream().get_miopen(), auto status = miopenConvolutionForwardImmediate(ctx.get_stream().get_miopen(),
&alpha, w_desc.get(),
x_desc.get(), args[1].implicit(),
args[0].implicit(), x_desc.get(),
w_desc.get(), args[0].implicit(),
args[1].implicit(), cd.get(),
cd.get(), y_desc.get(),
algo, args[3].implicit(),
&beta, args[2].implicit(),
y_desc.get(), args[2].get_shape().bytes(),
args[3].implicit(), solution_id);
args[2].implicit(),
args[2].get_shape().bytes());
if(status != miopenStatusSuccess) if(status != miopenStatusSuccess)
{
MIGRAPHX_THROW("QUANT_CONVOLUTION: run convolution forward failed"); MIGRAPHX_THROW("QUANT_CONVOLUTION: run convolution forward failed");
}
return args[3]; return args[3];
} }
...@@ -115,7 +112,6 @@ shape miopen_quant_convolution::find(context& ctx, ...@@ -115,7 +112,6 @@ shape miopen_quant_convolution::find(context& ctx,
false); false);
if(status != miopenStatusSuccess) if(status != miopenStatusSuccess)
MIGRAPHX_THROW("MIOpen Quant Convolution: find convolution failed"); MIGRAPHX_THROW("MIOpen Quant Convolution: find convolution failed");
algo = perf.fwd_algo;
size_t solution_count; size_t solution_count;
...@@ -143,7 +139,10 @@ shape miopen_quant_convolution::find(context& ctx, ...@@ -143,7 +139,10 @@ shape miopen_quant_convolution::find(context& ctx,
solution_id = solutions.front().solution_id; solution_id = solutions.front().solution_id;
return shape{shape::int8_type, {perf.memory}}; algo = solutions.front().algorithm;
workspace_size = solutions.front().workspace_size;
return shape{shape::int8_type, {workspace_size}};
} }
void miopen_quant_convolution::finalize(context& ctx, void miopen_quant_convolution::finalize(context& ctx,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment