Commit 69145ea1 authored by Shucai Xiao

Refine the implementation.

parent 32addf31
@@ -13,6 +13,8 @@ struct context;
 struct miopen_quant_gemm
 {
     op::quant_dot op;
+    mutable argument arg_a{};
+    mutable argument arg_b{};
     template <class Self, class F>
     static auto reflect(Self& self, F f)
...
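The two new `mutable` members are the heart of this change: `compute` on these GPU ops is `const`, so pack buffers can only be cached across calls through mutable state. A minimal sketch of the pattern, with a hypothetical `buffer` type standing in for `migraphx::argument`:

```cpp
#include <cstddef>
#include <vector>

// Hypothetical stand-in for migraphx::argument: a default-constructed
// instance owns no storage and reports empty().
struct buffer
{
    std::vector<std::byte> data;
    bool empty() const { return data.empty(); }
};

struct cached_gemm
{
    mutable buffer pack_b{}; // survives across calls to the const method

    void compute(std::size_t bytes) const
    {
        if(pack_b.empty())
            pack_b.data.resize(bytes); // allocated on the first call only
        // ... pack the operand into pack_b and launch the GEMM ...
    }
};
```

The trade-off is that the buffer is owned by the op instance rather than threaded through the argument list, which is what lets the lowering code below shed its special-case allocation logic.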
@@ -98,6 +98,7 @@ struct miopen_apply
         add_generic_op<hip_min>("min");
         add_extend_op<miopen_gemm, op::dot>("dot");
+        add_extend_op<miopen_quant_gemm, op::quant_dot>("quant_dot");
         add_extend_op<miopen_contiguous, op::contiguous>("contiguous");
         add_extend_op<hip_concat, op::concat>("concat");
         add_extend_op<miopen_softmax, op::softmax>("softmax");
@@ -109,7 +110,6 @@ struct miopen_apply
         add_lrn_op();
         add_convolution_op();
         add_quant_convolution_op();
-        add_quant_gemm_op();
         add_pooling_op();
         add_batch_norm_inference_op();
     }
@@ -172,46 +172,6 @@ struct miopen_apply
         });
     }
-    void add_quant_gemm_op()
-    {
-        apply_map.emplace("quant_dot", [=](instruction_ref ins) {
-            auto&& op = any_cast<op::quant_dot>(ins->get_operator());
-            std::vector<instruction_ref> refs = ins->inputs();
-            // Add additional arguments if packing is needed. Since lowering runs
-            // after auto_contiguous and before eliminate_contiguous, the shapes
-            // of all direct inputs are standard, so a transposed input can never
-            // be seen here directly. To detect one, check whether the argument is
-            // the output of a contiguous instruction; if so, examine the shape of
-            // that contiguous operator's input.
-            auto prev_ins = refs.at(0);
-            if(prev_ins->name() == "gpu::contiguous")
-            {
-                auto input = prev_ins->inputs().front();
-                if(input->get_shape().transposed())
-                {
-                    auto pack_a = insert_allocation(input, input->get_shape());
-                    // Replace this input of quant_gemm with the input of the
-                    // contiguous instruction; the contiguous then becomes dead
-                    // code if this was its only use.
-                    refs.at(0) = input;
-                    instruction::replace_argument(ins, prev_ins, input);
-                    refs.push_back(pack_a);
-                }
-            }
-            if(!refs.at(1)->get_shape().transposed())
-            {
-                auto pack_b = insert_allocation(refs.at(1), refs.at(1)->get_shape());
-                refs.push_back(pack_b);
-            }
-            auto output = insert_allocation(ins, ins->get_shape());
-            refs.push_back(output);
-            return prog->replace_instruction(ins, miopen_quant_gemm{op}, refs);
-        });
-    }
     void add_pooling_op()
     {
         apply_map.emplace("pooling", [=](instruction_ref ins) {
...
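This removal is the other half of the change. The old lowering appended pack buffers to the instruction's argument list, and `compute` then had to locate them by counting backwards from the end with `pack_arg_num`. With `quant_dot` now registered through the generic `add_extend_op` path (see the earlier hunk), only the output allocation is appended. An illustrative mock, not MIGraphX code, of the two layouts:

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

int main()
{
    // Old layout when B needed packing: a pack buffer hidden before the output.
    std::vector<std::string> old_args{"A", "B", "pack_b", "output"};
    std::size_t pack_arg_num = 1;
    assert(old_args.at(old_args.size() - 2) == "pack_b"); // fragile indexing
    assert(old_args.size() - pack_arg_num == 3);          // 2 inputs + output

    // New layout: always {A, B, [C], output}; pack buffers live in the op.
    std::vector<std::string> new_args{"A", "B", "output"};
    assert(new_args.back() == "output"); // nothing hidden in the middle
}
```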
@@ -56,16 +56,6 @@ shape miopen_quant_gemm::compute_shape(const std::vector<shape>& inputs) const
 {
     std::vector<shape> input_shapes(inputs);
     input_shapes.pop_back();
-    if(!inputs.at(1).transposed())
-    {
-        input_shapes.pop_back();
-    }
-    if(inputs.at(0).transposed())
-    {
-        input_shapes.pop_back();
-    }
     check_shapes{input_shapes}.not_broadcasted();
     return op.compute_shape(input_shapes);
 }
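With the pack buffers gone from the argument list, `compute_shape` no longer has to guess how many trailing allocations to discard; exactly one output allocation is appended, so a single `pop_back` suffices. A sketch of the simplified contract, with a placeholder `shape` and a hypothetical helper name:

```cpp
#include <vector>

struct shape{}; // placeholder for migraphx::shape

// The GPU op sees the operator's own inputs plus exactly one trailing
// output allocation; only that allocation is stripped before delegating
// to the reference operator's shape computation.
std::vector<shape> op_input_shapes(std::vector<shape> inputs)
{
    inputs.pop_back();
    return inputs;
}
```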
@@ -74,8 +64,6 @@ argument miopen_quant_gemm::compute(context& ctx,
                                     const shape& output_shape,
                                     const std::vector<argument>& args) const
 {
-    // Note: the packing of B MUST be handled before the packing of A.
-    auto arg_res = args.back();
     bool transa = args[0].get_shape().transposed();
     bool transb = args[1].get_shape().transposed();
     auto n_dim  = output_shape.lens().size();
@@ -83,28 +71,29 @@ argument miopen_quant_gemm::compute(context& ctx,
     auto dim_0 = n_dim - 2;
     rocblas_int lda = args[0].get_shape().strides()[transa ? dim_1 : dim_0];
     rocblas_int ldb = args[1].get_shape().strides()[transb ? dim_1 : dim_0];
-    rocblas_int ldc = arg_res.get_shape().strides()[dim_0];
-    auto arg_b = args.at(1);
-    std::size_t pack_arg_num = 0;
+    rocblas_int ldc = args[2].get_shape().strides()[dim_0];
     if(!transb)
     {
-        arg_b = args.at(args.size() - 2);
-        ++pack_arg_num;
+        if(arg_b.empty())
+        {
+            arg_b = allocate_gpu(args[1].get_shape());
+        }
         device::pack_a(ctx.get_stream().get(), arg_b, args[1]);
     }
     // We need to pack A in this scenario; use the algorithm for packing B
     // described in the comment of the API.
-    auto arg_a = args.at(0);
     if(transa)
     {
-        arg_a = args.at(args.size() - 2 - pack_arg_num);
-        ++pack_arg_num;
+        if(arg_a.empty())
+        {
+            arg_a = allocate_gpu(args.at(0).get_shape());
+        }
         device::pack_b(ctx.get_stream().get(), arg_a, args[0]);
     }
-    bool is_3inputs = (args.size() - pack_arg_num == 4);
+    bool is_3inputs = (args.size() == 4);
     int8_t beta = 0;
     if(is_3inputs)
     {
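Both operands now follow the same "allocate once, repack every call" shape: the allocation is guarded by `empty()` because it depends only on the input shape, while `device::pack_a`/`device::pack_b` must rerun each time since the input data changes between executions. A condensed sketch of the shared pattern, with hypothetical template stand-ins:

```cpp
// Buffer, Src, Alloc, and Pack are stand-ins; in the real code they are
// migraphx::argument, allocate_gpu, and device::pack_a / device::pack_b.
template <class Buffer, class Src, class Alloc, class Pack>
const Buffer& pack_into(Buffer& cache, const Src& src, Alloc alloc, Pack pack)
{
    if(cache.empty())
        cache = alloc(src); // one-time allocation, cached in the op instance
    pack(cache, src);       // repacking must happen on every call
    return cache;
}
```

The asymmetry carries over unchanged from the old code: B is packed when it is *not* transposed and A when it *is*, since whether the int8 data already satisfies rocBLAS's packed-layout requirement depends on the memory order of the k dimension in the column-major call.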
@@ -138,17 +127,17 @@ argument miopen_quant_gemm::compute(context& ctx,
                 m,
                 k,
                 &alpha_r,
-                to_pointer(arg_b),
+                (!transb) ? to_pointer(arg_b) : to_pointer(args.at(1)),
                 rocblas_datatype_i8_r,
                 ldb,
-                to_pointer(arg_a),
+                transa ? to_pointer(arg_a) : to_pointer(args.at(0)),
                 rocblas_datatype_i8_r,
                 lda,
                 &beta_r,
                 to_pointer(args[2]),
                 rocblas_datatype_i32_r,
                 ldc,
-                to_pointer(arg_res),
+                is_3inputs ? to_pointer(args.at(3)) : to_pointer(args[2]),
                 rocblas_datatype_i32_r,
                 ldc,
                 rocblas_datatype_i32_r,
@@ -168,11 +157,11 @@ argument miopen_quant_gemm::compute(context& ctx,
                 m,
                 k,
                 &alpha_r,
-                to_pointer(arg_b),
+                (!transb) ? to_pointer(arg_b) : to_pointer(args.at(1)),
                 rocblas_datatype_i8_r,
                 ldb,
                 k * n,
-                to_pointer(arg_a),
+                transa ? to_pointer(arg_a) : to_pointer(args.at(0)),
                 rocblas_datatype_i8_r,
                 lda,
                 m * k,
@@ -181,7 +170,7 @@ argument miopen_quant_gemm::compute(context& ctx,
                 rocblas_datatype_i32_r,
                 ldc,
                 m * n,
-                to_pointer(arg_res),
+                is_3inputs ? to_pointer(args.at(3)) : to_pointer(args[2]),
                 rocblas_datatype_i32_r,
                 ldc,
                 m * n,
@@ -195,7 +184,7 @@ argument miopen_quant_gemm::compute(context& ctx,
         }
     });
-    return arg_res;
+    return is_3inputs ? args.at(3) : args[2];
 }
 } // namespace gpu
...
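Finally, the return convention: `quant_dot` takes either two inputs (A, B) or three (A, B, C), and the lowering appends one output allocation, so `args` has either three or four entries. When there is no C, `args[2]` is the output buffer itself, which the GEMM call above can safely pass as C because `beta` is zero in that case. A small self-checking mock of the selection logic:

```cpp
#include <cassert>
#include <string>
#include <vector>

// Mock of the output selection in compute(); the strings stand in for
// device buffers.
std::string result_of(const std::vector<std::string>& args)
{
    bool is_3inputs = (args.size() == 4); // A, B, C, output
    return is_3inputs ? args.at(3) : args[2];
}

int main()
{
    assert(result_of({"A", "B", "C", "out"}) == "out"); // C provided
    assert(result_of({"A", "B", "out"}) == "out");      // args[2] is the output
}
```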