"vscode:/vscode.git/clone" did not exist on "6e051e01e9a9e55a5ce36eb0c740396630c0bbd7"
Commit 4e028da0 authored by Shucai Xiao
Browse files

More changes to the GEMM implementation for the cppcheck fix.

parent 908ed025
......@@ -13,8 +13,6 @@ struct context;
struct miopen_quant_gemm
{
op::quant_dot op;
argument arg_a;
argument arg_b;
template <class Self, class F>
static auto reflect(Self& self, F f)
......
......@@ -180,14 +180,14 @@ struct miopen_apply
auto&& op = any_cast<op::quant_dot>(ins->get_operator());
auto inputs = ins->inputs();
auto in_shapes = to_shapes(inputs);
auto arg_a = allocate_gpu(in_shapes[0]);
auto arg_b = allocate_gpu(in_shapes[1]);
auto quant_dot = miopen_quant_gemm{op, arg_a, arg_b};
auto pack_a = insert_allocation(ins, in_shapes[0], "pack_a");
auto pack_b = insert_allocation(ins, in_shapes[1], "pack_b");
auto output = insert_allocation(ins, ins->get_shape());
inputs.push_back(pack_a);
inputs.push_back(pack_b);
inputs.push_back(output);
return prog->replace_instruction(ins, quant_dot, inputs);
return prog->replace_instruction(ins, miopen_quant_gemm{op}, inputs);
});
}
......
......@@ -54,11 +54,11 @@ rb_type<T>* to_rocblas_type(T* x)
shape miopen_quant_gemm::compute_shape(const std::vector<shape>& inputs) const
{
    // The lowering pass appends three extra allocations to the argument
    // list (pack_a, pack_b, and the output buffer); strip them so that only
    // the actual gemm operands are shape-checked and forwarded to the op.
    std::vector<shape> in_shapes(inputs);
    in_shapes.erase(in_shapes.end() - 3, in_shapes.end());
    // Packed gemm operands must be laid out contiguously, not broadcast.
    check_shapes{in_shapes}.not_broadcasted();
    return op.compute_shape(in_shapes);
}
argument miopen_quant_gemm::compute(context& ctx,
......@@ -70,23 +70,24 @@ argument miopen_quant_gemm::compute(context& ctx,
auto n_dim = output_shape.lens().size();
auto dim_1 = n_dim - 1;
auto dim_0 = n_dim - 2;
auto arg_num = args.size();
rocblas_int lda = args[0].get_shape().strides()[transa ? dim_1 : dim_0];
rocblas_int ldb = args[1].get_shape().strides()[transb ? dim_1 : dim_0];
rocblas_int ldc = args[2].get_shape().strides()[dim_0];
rocblas_int ldc = args[arg_num - 1].get_shape().strides()[dim_0];
if(!transb)
{
device::pack_a(ctx.get_stream().get(), arg_b, args[1]);
device::pack_a(ctx.get_stream().get(), args[arg_num - 2], args[1]);
}
// need to pack A in this scenario, use the algorithm to pack B in the
// comment of the API
if(transa)
{
device::pack_b(ctx.get_stream().get(), arg_a, args[0]);
device::pack_b(ctx.get_stream().get(), args[arg_num - 3], args[0]);
}
bool is_3inputs = (args.size() == 4);
bool is_3inputs = (arg_num == 6);
int32_t beta = 0;
if(is_3inputs)
{
......@@ -120,17 +121,17 @@ argument miopen_quant_gemm::compute(context& ctx,
m,
k,
&alpha_r,
(!transb) ? to_pointer(arg_b) : to_pointer(args.at(1)),
(!transb) ? to_pointer(args[arg_num - 2]) : to_pointer(args.at(1)),
rocblas_datatype_i8_r,
ldb,
transa ? to_pointer(arg_a) : to_pointer(args.at(0)),
transa ? to_pointer(args[arg_num - 3]) : to_pointer(args.at(0)),
rocblas_datatype_i8_r,
lda,
&beta_r,
to_pointer(args[2]),
rocblas_datatype_i32_r,
ldc,
is_3inputs ? to_pointer(args.at(3)) : to_pointer(args[2]),
to_pointer(args[arg_num - 1]),
rocblas_datatype_i32_r,
ldc,
rocblas_datatype_i32_r,
......@@ -150,11 +151,11 @@ argument miopen_quant_gemm::compute(context& ctx,
m,
k,
&alpha_r,
(!transb) ? to_pointer(arg_b) : to_pointer(args.at(1)),
(!transb) ? to_pointer(args[arg_num - 2]) : to_pointer(args.at(1)),
rocblas_datatype_i8_r,
ldb,
k * n,
transa ? to_pointer(arg_a) : to_pointer(args.at(0)),
transa ? to_pointer(args[arg_num - 3]) : to_pointer(args.at(0)),
rocblas_datatype_i8_r,
lda,
m * k,
......@@ -163,7 +164,7 @@ argument miopen_quant_gemm::compute(context& ctx,
rocblas_datatype_i32_r,
ldc,
m * n,
is_3inputs ? to_pointer(args.at(3)) : to_pointer(args[2]),
to_pointer(args[arg_num - 1]),
rocblas_datatype_i32_r,
ldc,
m * n,
......@@ -177,7 +178,7 @@ argument miopen_quant_gemm::compute(context& ctx,
}
});
return is_3inputs ? args.at(3) : args[2];
return args[arg_num - 1];
}
} // namespace gpu
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.