Commit fdcb8d92 authored by Paul's avatar Paul
Browse files

Add format flag to convolution

parent 5e5ed37a
......@@ -14,6 +14,7 @@ struct context;
struct miopen_quant_convolution
{
op::quant_convolution op;
bool int8_x4_format = false;
shared<convolution_descriptor> cd;
miopenConvFwdAlgorithm_t algo{};
miopenHandle_t handle = nullptr;
......@@ -22,7 +23,9 @@ struct miopen_quant_convolution
static auto reflect(Self& self, F f)
{
// TODO: Add algo
return op::quant_convolution::reflect(self.op, f);
// return op::quant_convolution::reflect(self.op, f);
return pack(f(self.op, "op"),
f(self.int8_x4_format, "int8_x4_format"));
}
std::string name() const { return "gpu::quant_convolution"; }
......
......@@ -365,7 +365,7 @@ struct miopen_apply
{
apply_map.emplace("quant_convolution", [=](instruction_ref ins) {
auto&& op = any_cast<op::quant_convolution>(ins->get_operator());
auto conv = miopen_quant_convolution{op, make_conv(op)};
auto conv = miopen_quant_convolution{op, int8_x4_format, make_conv(op)};
auto ws = conv.compile(get_context(), ins->get_shape(), to_shapes(ins->inputs()));
auto args = ins->inputs();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment