Commit 4e028da0 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

More changes to the GEMM implementation to address cppcheck findings.

parent 908ed025
...@@ -13,8 +13,6 @@ struct context; ...@@ -13,8 +13,6 @@ struct context;
struct miopen_quant_gemm struct miopen_quant_gemm
{ {
op::quant_dot op; op::quant_dot op;
argument arg_a;
argument arg_b;
template <class Self, class F> template <class Self, class F>
static auto reflect(Self& self, F f) static auto reflect(Self& self, F f)
......
...@@ -180,14 +180,14 @@ struct miopen_apply ...@@ -180,14 +180,14 @@ struct miopen_apply
auto&& op = any_cast<op::quant_dot>(ins->get_operator()); auto&& op = any_cast<op::quant_dot>(ins->get_operator());
auto inputs = ins->inputs(); auto inputs = ins->inputs();
auto in_shapes = to_shapes(inputs); auto in_shapes = to_shapes(inputs);
auto arg_a = allocate_gpu(in_shapes[0]); auto pack_a = insert_allocation(ins, in_shapes[0], "pack_a");
auto arg_b = allocate_gpu(in_shapes[1]); auto pack_b = insert_allocation(ins, in_shapes[1], "pack_b");
auto quant_dot = miopen_quant_gemm{op, arg_a, arg_b};
auto output = insert_allocation(ins, ins->get_shape()); auto output = insert_allocation(ins, ins->get_shape());
inputs.push_back(pack_a);
inputs.push_back(pack_b);
inputs.push_back(output); inputs.push_back(output);
return prog->replace_instruction(ins, quant_dot, inputs); return prog->replace_instruction(ins, miopen_quant_gemm{op}, inputs);
}); });
} }
......
...@@ -54,11 +54,11 @@ rb_type<T>* to_rocblas_type(T* x) ...@@ -54,11 +54,11 @@ rb_type<T>* to_rocblas_type(T* x)
shape miopen_quant_gemm::compute_shape(const std::vector<shape>& inputs) const shape miopen_quant_gemm::compute_shape(const std::vector<shape>& inputs) const
{ {
std::vector<shape> input_shapes(inputs); std::vector<shape> in_shapes(inputs);
input_shapes.pop_back(); in_shapes.erase(in_shapes.begin() + in_shapes.size() - 3, in_shapes.end());
check_shapes{input_shapes}.not_broadcasted(); check_shapes{in_shapes}.not_broadcasted();
return op.compute_shape(input_shapes); return op.compute_shape(in_shapes);
} }
argument miopen_quant_gemm::compute(context& ctx, argument miopen_quant_gemm::compute(context& ctx,
...@@ -70,23 +70,24 @@ argument miopen_quant_gemm::compute(context& ctx, ...@@ -70,23 +70,24 @@ argument miopen_quant_gemm::compute(context& ctx,
auto n_dim = output_shape.lens().size(); auto n_dim = output_shape.lens().size();
auto dim_1 = n_dim - 1; auto dim_1 = n_dim - 1;
auto dim_0 = n_dim - 2; auto dim_0 = n_dim - 2;
auto arg_num = args.size();
rocblas_int lda = args[0].get_shape().strides()[transa ? dim_1 : dim_0]; rocblas_int lda = args[0].get_shape().strides()[transa ? dim_1 : dim_0];
rocblas_int ldb = args[1].get_shape().strides()[transb ? dim_1 : dim_0]; rocblas_int ldb = args[1].get_shape().strides()[transb ? dim_1 : dim_0];
rocblas_int ldc = args[2].get_shape().strides()[dim_0]; rocblas_int ldc = args[arg_num - 1].get_shape().strides()[dim_0];
if(!transb) if(!transb)
{ {
device::pack_a(ctx.get_stream().get(), arg_b, args[1]); device::pack_a(ctx.get_stream().get(), args[arg_num - 2], args[1]);
} }
// need to pack A in this scenario, use the algorithm to pack B in the // need to pack A in this scenario, use the algorithm to pack B in the
// comment of the API // comment of the API
if(transa) if(transa)
{ {
device::pack_b(ctx.get_stream().get(), arg_a, args[0]); device::pack_b(ctx.get_stream().get(), args[arg_num - 3], args[0]);
} }
bool is_3inputs = (args.size() == 4); bool is_3inputs = (arg_num == 6);
int32_t beta = 0; int32_t beta = 0;
if(is_3inputs) if(is_3inputs)
{ {
...@@ -120,17 +121,17 @@ argument miopen_quant_gemm::compute(context& ctx, ...@@ -120,17 +121,17 @@ argument miopen_quant_gemm::compute(context& ctx,
m, m,
k, k,
&alpha_r, &alpha_r,
(!transb) ? to_pointer(arg_b) : to_pointer(args.at(1)), (!transb) ? to_pointer(args[arg_num - 2]) : to_pointer(args.at(1)),
rocblas_datatype_i8_r, rocblas_datatype_i8_r,
ldb, ldb,
transa ? to_pointer(arg_a) : to_pointer(args.at(0)), transa ? to_pointer(args[arg_num - 3]) : to_pointer(args.at(0)),
rocblas_datatype_i8_r, rocblas_datatype_i8_r,
lda, lda,
&beta_r, &beta_r,
to_pointer(args[2]), to_pointer(args[2]),
rocblas_datatype_i32_r, rocblas_datatype_i32_r,
ldc, ldc,
is_3inputs ? to_pointer(args.at(3)) : to_pointer(args[2]), to_pointer(args[arg_num - 1]),
rocblas_datatype_i32_r, rocblas_datatype_i32_r,
ldc, ldc,
rocblas_datatype_i32_r, rocblas_datatype_i32_r,
...@@ -150,11 +151,11 @@ argument miopen_quant_gemm::compute(context& ctx, ...@@ -150,11 +151,11 @@ argument miopen_quant_gemm::compute(context& ctx,
m, m,
k, k,
&alpha_r, &alpha_r,
(!transb) ? to_pointer(arg_b) : to_pointer(args.at(1)), (!transb) ? to_pointer(args[arg_num - 2]) : to_pointer(args.at(1)),
rocblas_datatype_i8_r, rocblas_datatype_i8_r,
ldb, ldb,
k * n, k * n,
transa ? to_pointer(arg_a) : to_pointer(args.at(0)), transa ? to_pointer(args[arg_num - 3]) : to_pointer(args.at(0)),
rocblas_datatype_i8_r, rocblas_datatype_i8_r,
lda, lda,
m * k, m * k,
...@@ -163,7 +164,7 @@ argument miopen_quant_gemm::compute(context& ctx, ...@@ -163,7 +164,7 @@ argument miopen_quant_gemm::compute(context& ctx,
rocblas_datatype_i32_r, rocblas_datatype_i32_r,
ldc, ldc,
m * n, m * n,
is_3inputs ? to_pointer(args.at(3)) : to_pointer(args[2]), to_pointer(args[arg_num - 1]),
rocblas_datatype_i32_r, rocblas_datatype_i32_r,
ldc, ldc,
m * n, m * n,
...@@ -177,7 +178,7 @@ argument miopen_quant_gemm::compute(context& ctx, ...@@ -177,7 +178,7 @@ argument miopen_quant_gemm::compute(context& ctx,
} }
}); });
return is_3inputs ? args.at(3) : args[2]; return args[arg_num - 1];
} }
} // namespace gpu } // namespace gpu
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment