Commit 9c172a77 authored by Shucai Xiao
Browse files

put packing args of int8 input to a pass

parent b679e200
......@@ -14,6 +14,7 @@ struct pack_int8_args
{
std::string name() const { return "gpu::pack_int8_args"; }
void apply(program& p) const;
shape pack_int8_shape(const shape& s) const;
};
} // namespace gpu
......
......@@ -36,7 +36,8 @@ struct miopen_quant_convolution
return shapes.size() - 1;
}
shape pack_int8_shape(const shape& s);
private:
shape pack_int8_shape(const shape& s) const;
};
} // namespace gpu
......
......@@ -188,19 +188,11 @@ struct miopen_apply
auto ws = conv.compile(ctx, ins->get_shape(), to_shapes(ins->inputs()));
auto args = ins->inputs();
auto arg_x_vec4 = insert_allocation(ins, conv.pack_int8_shape(args[0]->get_shape()));
auto arg_x_packed =
prog->insert_instruction(ins, miopen_int8_conv_pack{}, {args[0], arg_x_vec4});
auto arg_y_vec4 = insert_allocation(ins, conv.pack_int8_shape(args[1]->get_shape()));
auto arg_y_packed =
prog->insert_instruction(ins, miopen_int8_conv_pack{}, {args[1], arg_y_vec4});
auto workspace = insert_allocation(ins, ws, "workspace");
auto output = insert_allocation(ins, ins->get_shape());
return prog->replace_instruction(
ins, conv, arg_x_packed, arg_y_packed, workspace, output);
ins, conv, args[0], args[1], workspace, output);
});
}
......
#include <migraphx/gpu/pack_int8_args.hpp>
#include <migraphx/gpu/int8_gemm_pack.hpp>
#include <migraphx/gpu/int8_conv_pack.hpp>
#include <migraphx/gpu/hip.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/program.hpp>
......@@ -37,10 +38,34 @@ void pack_int8_args::apply(program& p) const
}
else if(ins->name() == "gpu::quant_convolution")
{
auto inputs = ins->inputs();
auto packed_x = p.insert_instruction(ins, hip_allocate{pack_int8_shape(inputs[0]->get_shape())});
auto output_x = p.insert_instruction(ins, miopen_int8_conv_pack{}, {inputs[0], packed_x});
instruction::replace_argument(ins, inputs[0], output_x);
auto packed_w = p.insert_instruction(ins, hip_allocate{pack_int8_shape(inputs[1]->get_shape())});
auto output_w = p.insert_instruction(ins, miopen_int8_conv_pack{}, {inputs[1], packed_w});
instruction::replace_argument(ins, inputs[1], output_w);
}
}
}
shape pack_int8_args::pack_int8_shape(const shape& s) const
{
    // Packing is only defined for int8 tensors; reject anything else up front.
    if(s.type() != shape::int8_type)
    {
        MIGRAPHX_THROW("PACK_INT8_ARGS: only process int8_type");
    }

    auto packed_lens    = s.lens();
    auto packed_strides = s.strides();
    // Round the channel dimension (index 1) up to the next multiple of 4,
    // then recompute the batch stride so the padded channels are accounted for.
    packed_lens[1]    = (packed_lens[1] + 3) / 4 * 4;
    packed_strides[0] = packed_strides[1] * packed_lens[1];

    return {s.type(), packed_lens, packed_strides};
}
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -105,7 +105,7 @@ void miopen_quant_convolution::finalize(context& ctx,
MIGRAPHX_THROW("Workspace has changed during finalization.");
}
shape miopen_quant_convolution::pack_int8_shape(const shape& s)
shape miopen_quant_convolution::pack_int8_shape(const shape& s) const
{
if(s.type() != shape::int8_type)
{
......
......@@ -54,7 +54,7 @@ rb_type<T>* to_rocblas_type(T* x)
shape rocblas_quant_gemm::compute_shape(const std::vector<shape>& inputs) const
{
std::vector<shape> in_shapes(inputs);
in_shapes.erase(in_shapes.begin() + in_shapes.size() - 3, in_shapes.end());
in_shapes.pop_back();
check_shapes{in_shapes}.not_broadcasted();
return op.compute_shape(in_shapes);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment