Commit 2c779d0b authored by Shucai Xiao

simplify the gemm call for int8

parent 8f9a766f
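The wrappers removed below exist only to remap framework types that rocBLAS spells differently (`half` becomes `rocblas_half`); for every other type the trait is an identity and `to_rocblas_type` is a no-op. The int8 gemm path works on `int8_t` operands with `int32_t` alpha/beta and output, so nothing needs remapping and `rocblas_gemm_ex` / `rocblas_gemm_strided_batched_ex` can be called directly. A minimal sketch of that reasoning, using toy stand-in types rather than the real MIGraphX/rocBLAS headers:

```cpp
#include <cstdint>
#include <type_traits>

// Toy stand-ins for illustration only; the real types live in the
// MIGraphX and rocBLAS headers.
struct half         { std::uint16_t bits; };
struct rocblas_half { std::uint16_t bits; };

// Same shape as the trait removed by this commit: identity by default,
// with a single specialization that remaps half.
template <class T>
struct compute_rocblas_type { using type = T; };
template <>
struct compute_rocblas_type<half> { using type = rocblas_half; };

int main()
{
    // half is the only type that actually needs the remap...
    static_assert(std::is_same<compute_rocblas_type<half>::type, rocblas_half>::value, "half remaps");
    // ...while the int8 gemm operands and accumulator map to themselves,
    // so the conversion layer adds nothing on this code path.
    static_assert(std::is_same<compute_rocblas_type<std::int8_t>::type, std::int8_t>::value, "int8 is a no-op");
    static_assert(std::is_same<compute_rocblas_type<std::int32_t>::type, std::int32_t>::value, "int32 is a no-op");
    return 0;
}
```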
@@ -8,51 +8,6 @@ namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
 namespace gpu {
-template <class... Ts>
-rocblas_status generic_rocblas_gemm_ex(Ts&&... xs)
-{
-    return rocblas_gemm_ex(std::forward<Ts>(xs)...);
-}
-template <class... Ts>
-rocblas_status generic_rocblas_batched_gemm_ex(Ts&&... xs)
-{
-    return rocblas_gemm_strided_batched_ex(std::forward<Ts>(xs)...);
-}
-template <class T>
-struct compute_rocblas_type
-{
-    using type = T;
-};
-template <class T>
-struct compute_rocblas_type<const T>
-{
-    using type = const typename compute_rocblas_type<T>::type;
-};
-template <>
-struct compute_rocblas_type<half>
-{
-    using type = rocblas_half;
-};
-template <class T>
-using rb_type = typename compute_rocblas_type<T>::type;
-template <class T>
-rb_type<T> to_rocblas_type(T x)
-{
-    return reinterpret_cast<const rb_type<T>&>(x);
-}
-template <class T>
-rb_type<T>* to_rocblas_type(T* x)
-{
-    return reinterpret_cast<rb_type<T>*>(x);
-}
 shape rocblas_quant_gemm::compute_shape(const std::vector<shape>& inputs) const
 {
     std::vector<shape> in_shapes(inputs);
@@ -102,13 +57,13 @@ argument rocblas_quant_gemm::compute(context& ctx,
     auto a_lens = args[0].get_shape().lens();
     auto b_lens = args[1].get_shape().lens();
     output_shape.visit_type([&](auto as) {
-        auto alpha_r = to_rocblas_type(as(op.alpha));
-        auto beta_r  = to_rocblas_type(as(beta));
+        auto alpha_r = as(op.alpha);
+        auto beta_r  = as(beta);
         auto out_lens = output_shape.lens();
         rocblas_int m = out_lens[dim_0];
         rocblas_int n = out_lens[dim_1];
         rocblas_int k = args[0].get_shape().lens()[dim_1];
-        auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); };
+        auto to_pointer = [&](auto&& arg) { return as.from(arg.data()); };
         assert(k % 4 == 0);
         auto num_matrices = std::accumulate(
@@ -119,7 +74,7 @@ argument rocblas_quant_gemm::compute(context& ctx,
             // column-major format. When doing a C = A * B, we actually do
             // C^T = (B^T) * (A^T). That is the reason we input args[1] as
             // A and args[0] as B in calling the rocblas_gemm.
-            generic_rocblas_gemm_ex(ctx.get_stream().get_rocblas(),
+            rocblas_gemm_ex(ctx.get_stream().get_rocblas(),
                             transb ? rocblas_operation_transpose : rocblas_operation_none,
                             transa ? rocblas_operation_transpose : rocblas_operation_none,
                             n,
@@ -148,7 +103,7 @@ argument rocblas_quant_gemm::compute(context& ctx,
         }
         else
         {
-            generic_rocblas_batched_gemm_ex(
+            rocblas_gemm_strided_batched_ex(
                 ctx.get_stream().get_rocblas(),
                 transb ? rocblas_operation_transpose : rocblas_operation_none,
                 transa ? rocblas_operation_transpose : rocblas_operation_none,
...
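The comment in the diff about column-major layout refers to the standard trick for feeding row-major buffers to a column-major BLAS: a row-major matrix read column-major is its transpose, so asking the library for C^T = (B^T) * (A^T) with the operands and the m/n dimensions swapped leaves the output buffer already holding C in row-major order, with no transposition kernel needed. A minimal CPU sketch of that identity, independent of rocBLAS (the helper name `col_major_gemm` is made up for illustration):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Reference column-major GEMM: C = A * B with A (m x k), B (k x n), C (m x n),
// each stored column-major with the given leading dimensions.
void col_major_gemm(int m, int n, int k,
                    const std::int8_t* A, int lda,
                    const std::int8_t* B, int ldb,
                    std::int32_t* C, int ldc)
{
    for(int j = 0; j < n; ++j)
        for(int i = 0; i < m; ++i)
        {
            std::int32_t acc = 0;
            for(int p = 0; p < k; ++p)
                acc += std::int32_t(A[i + p * lda]) * std::int32_t(B[p + j * ldb]);
            C[i + j * ldc] = acc;
        }
}

int main()
{
    // Row-major A (2 x 4) and B (4 x 3); we want the row-major product C (2 x 3).
    const int m = 2, n = 3, k = 4;
    std::vector<std::int8_t> A = {1, 2, 3, 4,
                                  5, 6, 7, 8};
    std::vector<std::int8_t> B = {1, 0, 2,
                                  0, 1, 0,
                                  1, 1, 1,
                                  2, 0, 3};
    std::vector<std::int32_t> C(m * n);

    // A row-major buffer read column-major is the transpose, so ask the
    // column-major routine for C^T = B^T * A^T: pass B first, A second, and
    // swap m and n. The output buffer then already holds C in row-major order.
    col_major_gemm(n, m, k, B.data(), n, A.data(), k, C.data(), n);

    // Spot check: C[0][0] = 1*1 + 2*0 + 3*1 + 4*2 = 12.
    assert(C[0] == 12);
    return 0;
}
```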