Commit f02c2856 authored by umangyadav's avatar umangyadav
Browse files

use workspace size from get_solution()

parent ed7973d1
......@@ -119,7 +119,6 @@ shape miopen_convolution::find(context& ctx, const shape& output_shape, std::vec
false);
if(status != miopenStatusSuccess)
MIGRAPHX_THROW("MIOpen Convolution: find convolution failed");
algo = perf.fwd_algo;
size_t solution_count;
......@@ -146,8 +145,10 @@ shape miopen_convolution::find(context& ctx, const shape& output_shape, std::vec
MIGRAPHX_THROW("MIOpen Convolution: get solution failed");
solution_id = solutions.front().solution_id;
algo = solutions.front().algorithm;
workspace_size = solutions.front().workspace_size;
return shape{shape::int8_type, {perf.memory}};
return shape{shape::int8_type, {workspace_size}};
}
void miopen_convolution::finalize(context& ctx,
......
......@@ -119,7 +119,6 @@ shape miopen_deconvolution::find(context& ctx, const shape& output_shape, std::v
false);
if(status != miopenStatusSuccess)
MIGRAPHX_THROW("MIOpen Deconvolution: find convolution failed");
algo = perf.fwd_algo;
size_t solution_count;
......@@ -146,8 +145,10 @@ shape miopen_deconvolution::find(context& ctx, const shape& output_shape, std::v
MIGRAPHX_THROW("MIOpen Deconvolution: get solution failed");
solution_id = solutions.front().solution_id;
algo = solutions.front().algorithm;
workspace_size = solutions.front().workspace_size;
return shape{shape::int8_type, {perf.memory}};
return shape{shape::int8_type, {workspace_size}};
}
void miopen_deconvolution::finalize(context& ctx,
......
......@@ -210,7 +210,7 @@ MIGRAPHX_PRED_MATCHER(fusable_conv, instruction_ref ins)
auto conv = any_cast<miopen_convolution>(ins->get_operator());
if(conv.op.group > 1)
return false;
if(wei.lens()[1] > 512 and conv.algo != miopenConvolutionFwdAlgoWinograd)
if(wei.lens()[1] > 512 and conv.algo != miopenConvolutionAlgoWinograd)
return false;
// Do not fuse non-symmetric input
......@@ -220,7 +220,7 @@ MIGRAPHX_PRED_MATCHER(fusable_conv, instruction_ref ins)
auto op = conv.op;
// Dont fuse winograd for non-3x3s since there is no fused windograd for those configs
if(conv.algo == miopenConvolutionFwdAlgoWinograd and wei.lens()[2] != 3 and
if(conv.algo == miopenConvolutionAlgoWinograd and wei.lens()[2] != 3 and
wei.lens()[3] != 3 and contains({{1, 1}}, op.stride))
return false;
return contains({{0, 0, 0, 0}, {1, 1, 1, 1}, {2, 2, 2, 2}}, op.padding) and
......
......@@ -38,7 +38,7 @@ struct miopen_convolution
{
op::convolution op;
shared<convolution_descriptor> cd = nullptr;
miopenConvFwdAlgorithm_t algo{};
miopenConvAlgorithm_t algo{};
uint64_t solution_id = 0;
template <class Self, class F>
......
......@@ -38,7 +38,7 @@ struct miopen_deconvolution
{
op::deconvolution op;
shared<convolution_descriptor> cd;
miopenConvFwdAlgorithm_t algo{};
miopenConvAlgorithm_t algo{};
uint64_t solution_id = 0;
template <class Self, class F>
......
......@@ -40,7 +40,7 @@ struct miopen_quant_convolution
op::quant_convolution op;
bool int8_x4_format = false;
shared<convolution_descriptor> cd;
miopenConvFwdAlgorithm_t algo{};
miopenConvAlgorithm_t algo{};
uint64_t solution_id = 0;
template <class Self, class F>
......
......@@ -43,26 +43,23 @@ argument miopen_quant_convolution::compute(context& ctx,
auto w_desc = make_tensor(args[1].get_shape(), int8_x4_format);
auto y_desc = make_tensor(output_shape);
float alpha = 1;
float beta = 0;
if(solution_id == 0)
MIGRAPHX_THROW("MIOpen Convolution: invalid solution ID");
auto status = miopenConvolutionForward(ctx.get_stream().get_miopen(),
&alpha,
x_desc.get(),
args[0].implicit(),
auto status = miopenConvolutionForwardImmediate(ctx.get_stream().get_miopen(),
w_desc.get(),
args[1].implicit(),
x_desc.get(),
args[0].implicit(),
cd.get(),
algo,
&beta,
y_desc.get(),
args[3].implicit(),
args[2].implicit(),
args[2].get_shape().bytes());
args[2].get_shape().bytes(),
solution_id);
if(status != miopenStatusSuccess)
{
MIGRAPHX_THROW("QUANT_CONVOLUTION: run convolution forward failed");
}
return args[3];
}
......@@ -115,7 +112,6 @@ shape miopen_quant_convolution::find(context& ctx,
false);
if(status != miopenStatusSuccess)
MIGRAPHX_THROW("MIOpen Quant Convolution: find convolution failed");
algo = perf.fwd_algo;
size_t solution_count;
......@@ -143,7 +139,10 @@ shape miopen_quant_convolution::find(context& ctx,
solution_id = solutions.front().solution_id;
return shape{shape::int8_type, {perf.memory}};
algo = solutions.front().algorithm;
workspace_size = solutions.front().workspace_size;
return shape{shape::int8_type, {workspace_size}};
}
void miopen_quant_convolution::finalize(context& ctx,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment