Commit 8e824ed1 authored by Shucai Xiao
Browse files

fix bugs in gpu quant_dot implementation

parent 433f854a
......@@ -61,11 +61,12 @@ struct hip_tensor_descriptor
{
std::copy(s.lens().begin(), s.lens().end(), lens);
std::copy(s.strides().begin(), s.strides().end(), strides);
indices.resize(s.strides().size());
std::iota(indices.begin(), indices.end(), 0);
std::sort(indices.begin(), indices.end(), [&](size_t i1, size_t i2) {
return strides[i1] > strides[i2];
std::vector<std::size_t> vec_idx(s.lens().size());
std::iota(vec_idx.begin(), vec_idx.end(), 0);
std::sort(vec_idx.begin(), vec_idx.end(), [&](size_t i, size_t j) {
return strides[i] > strides[j];
});
std::copy(vec_idx.begin(), vec_idx.end(), indices);
}
__device__ __host__ hip_index<NDim> multi(size_t idx) const
......@@ -79,6 +80,7 @@ struct hip_tensor_descriptor
}
return result;
}
__device__ __host__ size_t linear(hip_index<NDim> s) const
{
size_t idx = 0;
......@@ -88,7 +90,7 @@ struct hip_tensor_descriptor
}
size_t lens[NDim] = {};
size_t strides[NDim] = {};
std::vector<size_t> indices{};
size_t indices[NDim] = {};
};
} // namespace device
......
......@@ -55,8 +55,8 @@ void pack_b(hipStream_t stream, const argument& result, const argument& arg)
gs_launch(stream, nelements)([=](auto ii) {
const size_t nb = 4;
auto idx = desc.multi(ii);
std::size_t i_n = idx[1];
std::size_t i_k = idx[0];
std::size_t i_n = idx[dim_0];
std::size_t i_k = idx[dim_1];
std::size_t offset = ii / m_size * m_size;
out_ptr[i_k % nb + (i_n + (i_k / nb) * ldb) * nb + offset] =
in_ptr[i_n + i_k * ldb + offset];
......
......@@ -178,11 +178,26 @@ struct miopen_apply
auto&& op = any_cast<op::quant_dot>(ins->get_operator());
std::vector<instruction_ref> refs = ins->inputs();
// add additional arguments if need packing
if(refs.at(0)->get_shape().transposed())
        // add additional arguments if packing is needed. Since lowering runs
        // after auto_contiguous and before eliminate_contiguous, the shapes
        // of all inputs are standard, so an input shape can never appear
        // transposed here. To detect a transposed input anyway, check whether
        // this argument is the output of a contiguous operator; if so, inspect
        // the shape of that contiguous operator's input instead.
auto prev_ins = refs.at(0);
if (prev_ins->name() == "gpu::contiguous")
{
auto pack_a = insert_allocation(refs.at(0), refs.at(0)->get_shape());
refs.push_back(pack_a);
auto input = prev_ins->inputs().front();
if (input->get_shape().transposed())
{
auto pack_a = insert_allocation(input, input->get_shape());
            // replace one of the inputs of quant_gemm: use the input of the
            // contiguous instead of its output. The contiguous then becomes
            // dead code if this instruction was its only user.
refs.at(0) = input;
instruction::replace_argument(ins, prev_ins, input);
refs.push_back(pack_a);
}
}
if(!refs.at(1)->get_shape().transposed())
......
......@@ -75,6 +75,7 @@ argument miopen_quant_gemm::compute(context& ctx,
const std::vector<argument>& args) const
{
// handling the packing of B MUST be before handling that for A
auto arg_res = args.back();
bool transa = args[0].get_shape().transposed();
bool transb = args[1].get_shape().transposed();
auto n_dim = output_shape.lens().size();
......@@ -82,7 +83,7 @@ argument miopen_quant_gemm::compute(context& ctx,
auto dim_0 = n_dim - 2;
rocblas_int lda = args[0].get_shape().strides()[transa ? dim_1 : dim_0];
rocblas_int ldb = args[1].get_shape().strides()[transb ? dim_1 : dim_0];
rocblas_int ldc = args[2].get_shape().strides()[dim_0];
rocblas_int ldc = arg_res.get_shape().strides()[dim_0];
auto arg_b = args.at(1);
std::size_t pack_arg_num = 0;
......@@ -147,7 +148,7 @@ argument miopen_quant_gemm::compute(context& ctx,
to_pointer(args[2]),
rocblas_datatype_i32_r,
ldc,
(is_3inputs ? to_pointer(args[3]) : to_pointer(args[2])),
to_pointer(arg_res),
rocblas_datatype_i32_r,
ldc,
rocblas_datatype_i32_r,
......@@ -180,7 +181,7 @@ argument miopen_quant_gemm::compute(context& ctx,
rocblas_datatype_i32_r,
ldc,
m * n,
(is_3inputs ? to_pointer(args[3]) : to_pointer(args[2])),
to_pointer(arg_res),
rocblas_datatype_i32_r,
ldc,
m * n,
......@@ -194,7 +195,7 @@ argument miopen_quant_gemm::compute(context& ctx,
}
});
return (is_3inputs ? args[3] : args[2]);
return arg_res;
}
} // namespace gpu
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment