Commit fdcb8d92 authored by Paul's avatar Paul
Browse files

Add format flag to convolution

parent 5e5ed37a
...@@ -14,6 +14,7 @@ struct context; ...@@ -14,6 +14,7 @@ struct context;
struct miopen_quant_convolution struct miopen_quant_convolution
{ {
op::quant_convolution op; op::quant_convolution op;
bool int8_x4_format = false;
shared<convolution_descriptor> cd; shared<convolution_descriptor> cd;
miopenConvFwdAlgorithm_t algo{}; miopenConvFwdAlgorithm_t algo{};
miopenHandle_t handle = nullptr; miopenHandle_t handle = nullptr;
...@@ -22,7 +23,9 @@ struct miopen_quant_convolution ...@@ -22,7 +23,9 @@ struct miopen_quant_convolution
static auto reflect(Self& self, F f) static auto reflect(Self& self, F f)
{ {
// TODO: Add algo // TODO: Add algo
return op::quant_convolution::reflect(self.op, f); // return op::quant_convolution::reflect(self.op, f);
return pack(f(self.op, "op"),
f(self.int8_x4_format, "int8_x4_format"));
} }
std::string name() const { return "gpu::quant_convolution"; } std::string name() const { return "gpu::quant_convolution"; }
......
...@@ -365,7 +365,7 @@ struct miopen_apply ...@@ -365,7 +365,7 @@ struct miopen_apply
{ {
apply_map.emplace("quant_convolution", [=](instruction_ref ins) { apply_map.emplace("quant_convolution", [=](instruction_ref ins) {
auto&& op = any_cast<op::quant_convolution>(ins->get_operator()); auto&& op = any_cast<op::quant_convolution>(ins->get_operator());
auto conv = miopen_quant_convolution{op, make_conv(op)}; auto conv = miopen_quant_convolution{op, int8_x4_format, make_conv(op)};
auto ws = conv.compile(get_context(), ins->get_shape(), to_shapes(ins->inputs())); auto ws = conv.compile(get_context(), ins->get_shape(), to_shapes(ins->inputs()));
auto args = ins->inputs(); auto args = ins->inputs();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment