Commit 59e4f91c authored by Shucai Xiao
Browse files

add gpu implementation for quant_gemm

parent 56b65698
...@@ -46,6 +46,7 @@ add_library(migraphx_gpu ...@@ -46,6 +46,7 @@ add_library(migraphx_gpu
target.cpp target.cpp
lowering.cpp lowering.cpp
gemm.cpp gemm.cpp
quant_gemm.cpp
pooling.cpp pooling.cpp
convolution.cpp convolution.cpp
softmax.cpp softmax.cpp
......
...@@ -41,6 +41,7 @@ ...@@ -41,6 +41,7 @@
#include <migraphx/gpu/batchnorm.hpp> #include <migraphx/gpu/batchnorm.hpp>
#include <migraphx/gpu/pooling.hpp> #include <migraphx/gpu/pooling.hpp>
#include <migraphx/gpu/gemm.hpp> #include <migraphx/gpu/gemm.hpp>
#include <migraphx/gpu/quant_gemm.hpp>
#include <migraphx/gpu/concat.hpp> #include <migraphx/gpu/concat.hpp>
#include <migraphx/gpu/pad.hpp> #include <migraphx/gpu/pad.hpp>
#include <migraphx/gpu/gather.hpp> #include <migraphx/gpu/gather.hpp>
...@@ -95,6 +96,7 @@ struct miopen_apply ...@@ -95,6 +96,7 @@ struct miopen_apply
add_generic_op<hip_min>("min"); add_generic_op<hip_min>("min");
add_extend_op<miopen_gemm, op::dot>("dot"); add_extend_op<miopen_gemm, op::dot>("dot");
add_extend_op<miopen_quant_gemm, op::quant_dot>("dot");
add_extend_op<miopen_contiguous, op::contiguous>("contiguous"); add_extend_op<miopen_contiguous, op::contiguous>("contiguous");
add_extend_op<hip_concat, op::concat>("concat"); add_extend_op<hip_concat, op::concat>("concat");
add_extend_op<miopen_softmax, op::softmax>("softmax"); add_extend_op<miopen_softmax, op::softmax>("softmax");
......
#include <migraphx/gpu/quant_gemm.hpp> #include <migraphx/gpu/quant_gemm.hpp>
#include <migraphx/gpu/gemm.hpp>
#include <migraphx/gpu/context.hpp> #include <migraphx/gpu/context.hpp>
namespace migraphx { namespace migraphx {
...@@ -6,129 +7,15 @@ inline namespace MIGRAPHX_INLINE_NS { ...@@ -6,129 +7,15 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
template <class... Ts> template <class... Ts>
rocblas_status generic_rocblas_scal(shape::as<float>, Ts&&... xs) rocblas_status generic_rocblas_gemm_ex(Ts&&... xs)
{ {
return rocblas_sscal(std::forward<Ts>(xs)...); return rocblas_gemm_ex(std::forward<Ts>(xs)...);
} }
template <class... Ts> template <class... Ts>
rocblas_status generic_rocblas_scal(shape::as<double>, Ts&&... xs) rocblas_status generic_rocblas_batched_gemm_ex(Ts&&... xs)
{ {
return rocblas_dscal(std::forward<Ts>(xs)...); return rocblas_gemm_strided_batched_ex(std::forward<Ts>(xs)...);
}
template <class T, class... Ts>
rocblas_status generic_rocblas_scal(shape::as<T>, Ts&&...)
{
MIGRAPHX_THROW("GENERIC_ROCBLAS_SCAL: type unsupported by rocblas");
}
template <class... Ts>
rocblas_status generic_rocblas_axpy(shape::as<half>, Ts&&... xs)
{
return rocblas_haxpy(std::forward<Ts>(xs)...);
}
template <class... Ts>
rocblas_status generic_rocblas_axpy(shape::as<float>, Ts&&... xs)
{
return rocblas_saxpy(std::forward<Ts>(xs)...);
}
template <class... Ts>
rocblas_status generic_rocblas_axpy(shape::as<double>, Ts&&... xs)
{
return rocblas_daxpy(std::forward<Ts>(xs)...);
}
template <class T, class... Ts>
rocblas_status generic_rocblas_axpy(shape::as<T>, Ts&&...)
{
MIGRAPHX_THROW("GENERIC_ROCBLAS_AXPY: type unsupported by rocblas");
}
template <class... Ts>
rocblas_status generic_rocblas_dot(shape::as<float>, Ts&&... xs)
{
return rocblas_sdot(std::forward<Ts>(xs)...);
}
template <class... Ts>
rocblas_status generic_rocblas_dot(shape::as<double>, Ts&&... xs)
{
return rocblas_ddot(std::forward<Ts>(xs)...);
}
template <class T, class... Ts>
rocblas_status generic_rocblas_dot(shape::as<T>, Ts&&...)
{
MIGRAPHX_THROW("GENERIC_ROCBLAS_DOT: type unsupported by rocblas");
}
template <class... Ts>
rocblas_status generic_rocblas_gemv(shape::as<float>, Ts&&... xs)
{
return rocblas_sgemv(std::forward<Ts>(xs)...);
}
template <class... Ts>
rocblas_status generic_rocblas_gemv(shape::as<double>, Ts&&... xs)
{
return rocblas_dgemv(std::forward<Ts>(xs)...);
}
template <class T, class... Ts>
rocblas_status generic_rocblas_gemv(shape::as<T>, Ts&&...)
{
MIGRAPHX_THROW("GENERIC_ROCBLAS_GEMMV: type unsupported by rocblas");
}
template <class... Ts>
rocblas_status generic_rocblas_batched_gemm(shape::as<float>, Ts&&... xs)
{
return rocblas_sgemm_strided_batched(std::forward<Ts>(xs)...);
}
template <class... Ts>
rocblas_status generic_rocblas_batched_gemm(shape::as<double>, Ts&&... xs)
{
return rocblas_dgemm_strided_batched(std::forward<Ts>(xs)...);
}
template <class... Ts>
rocblas_status generic_rocblas_batched_gemm(shape::as<half>, Ts&&... xs)
{
return rocblas_hgemm_strided_batched(std::forward<Ts>(xs)...);
}
template <class T, class... Ts>
rocblas_status generic_rocblas_batched_gemm(shape::as<T>, Ts&&...)
{
MIGRAPHX_THROW("GENERIC_ROCBLAS_BATCHED_GEMM: type unsupported by rocblas");
}
template <class... Ts>
rocblas_status generic_rocblas_gemm(shape::as<float>, Ts&&... xs)
{
return rocblas_sgemm(std::forward<Ts>(xs)...);
}
template <class... Ts>
rocblas_status generic_rocblas_gemm(shape::as<double>, Ts&&... xs)
{
return rocblas_dgemm(std::forward<Ts>(xs)...);
}
template <class... Ts>
rocblas_status generic_rocblas_gemm(shape::as<half>, Ts&&... xs)
{
return rocblas_hgemm(std::forward<Ts>(xs)...);
}
template <class T, class... Ts>
rocblas_status generic_rocblas_gemm(shape::as<T>, Ts&&...)
{
MIGRAPHX_THROW("GENERIC_ROCBLAS_GEMM: type unsupported by rocblas");
} }
template <class T> template <class T>
...@@ -164,8 +51,6 @@ rb_type<T>* to_rocblas_type(T* x) ...@@ -164,8 +51,6 @@ rb_type<T>* to_rocblas_type(T* x)
return reinterpret_cast<rb_type<T>*>(x); return reinterpret_cast<rb_type<T>*>(x);
} }
rocblas_half to_rocblas_type(half x) { return reinterpret_cast<const rocblas_half&>(x); }
shape miopen_quant_gemm::compute_shape(const std::vector<shape>& inputs) const shape miopen_quant_gemm::compute_shape(const std::vector<shape>& inputs) const
{ {
std::vector<shape> input_shapes(inputs.begin(), inputs.begin() + inputs.size() - 1); std::vector<shape> input_shapes(inputs.begin(), inputs.begin() + inputs.size() - 1);
...@@ -181,14 +66,6 @@ argument miopen_quant_gemm::compute(context& ctx, ...@@ -181,14 +66,6 @@ argument miopen_quant_gemm::compute(context& ctx,
float beta = 0.0f; float beta = 0.0f;
if(is_3inputs) if(is_3inputs)
{ {
output_shape.visit_type([&](auto as) {
auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); };
hipMemcpyAsync(to_pointer(args[3]),
to_pointer(args[2]),
output_shape.bytes(),
hipMemcpyDeviceToDevice,
ctx.get_stream().get());
});
beta = op.beta; beta = op.beta;
} }
...@@ -209,13 +86,16 @@ argument miopen_quant_gemm::compute(context& ctx, ...@@ -209,13 +86,16 @@ argument miopen_quant_gemm::compute(context& ctx,
rocblas_int m = out_lens[dim_0]; rocblas_int m = out_lens[dim_0];
rocblas_int n = out_lens[dim_1]; rocblas_int n = out_lens[dim_1];
rocblas_int k = args[0].get_shape().lens()[dim_1]; rocblas_int k = args[0].get_shape().lens()[dim_1];
auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); };
assert(k % 4 == 0);
assert(transa && (lda % 4 == 0));
assert(!transb && (ldb % 4 == 0));
auto num_matrices = std::accumulate( auto num_matrices = std::accumulate(
out_lens.rbegin() + 2, out_lens.rend(), std::size_t{1}, std::multiplies<std::size_t>()); out_lens.rbegin() + 2, out_lens.rend(), std::size_t{1}, std::multiplies<std::size_t>());
auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); };
if(num_matrices == 1) if(num_matrices == 1)
{ {
generic_rocblas_gemm(as, generic_rocblas_gemm_ex(ctx.get_stream().get_rocblas(),
ctx.get_stream().get_rocblas(),
transb ? rocblas_operation_transpose : rocblas_operation_none, transb ? rocblas_operation_transpose : rocblas_operation_none,
transa ? rocblas_operation_transpose : rocblas_operation_none, transa ? rocblas_operation_transpose : rocblas_operation_none,
n, n,
...@@ -223,17 +103,25 @@ argument miopen_quant_gemm::compute(context& ctx, ...@@ -223,17 +103,25 @@ argument miopen_quant_gemm::compute(context& ctx,
k, k,
&alpha_r, &alpha_r,
to_pointer(args[1]), to_pointer(args[1]),
rocblas_datatype_i8_r,
ldb, ldb,
to_pointer(args[0]), to_pointer(args[0]),
rocblas_datatype_i8_r,
lda, lda,
&beta_r, &beta_r,
to_pointer(args[2]),
rocblas_datatype_i32_r,
ldc,
(is_3inputs ? to_pointer(args[3]) : to_pointer(args[2])), (is_3inputs ? to_pointer(args[3]) : to_pointer(args[2])),
ldc); rocblas_datatype_i32_r,
ldc,
rocblas_datatype_i32_r,
rocblas_gemm_algo_standard,
0, 0, nullptr, nullptr);
} }
else else
{ {
generic_rocblas_batched_gemm( generic_rocblas_batched_gemm_ex(
as,
ctx.get_stream().get_rocblas(), ctx.get_stream().get_rocblas(),
transb ? rocblas_operation_transpose : rocblas_operation_none, transb ? rocblas_operation_transpose : rocblas_operation_none,
transa ? rocblas_operation_transpose : rocblas_operation_none, transa ? rocblas_operation_transpose : rocblas_operation_none,
...@@ -242,16 +130,26 @@ argument miopen_quant_gemm::compute(context& ctx, ...@@ -242,16 +130,26 @@ argument miopen_quant_gemm::compute(context& ctx,
k, k,
&alpha_r, &alpha_r,
to_pointer(args[1]), to_pointer(args[1]),
rocblas_datatype_i8_r,
ldb, ldb,
k * n, k * n,
to_pointer(args[0]), to_pointer(args[0]),
rocblas_datatype_i8_r,
lda, lda,
m * k, m * k,
&beta_r, &beta_r,
to_pointer(args[2]),
rocblas_datatype_i32_r,
ldc,
m * n,
(is_3inputs ? to_pointer(args[3]) : to_pointer(args[2])), (is_3inputs ? to_pointer(args[3]) : to_pointer(args[2])),
rocblas_datatype_i32_r,
ldc, ldc,
m * n, m * n,
num_matrices); num_matrices,
rocblas_datatype_i32_r,
rocblas_gemm_algo_standard,
0, 0, nullptr, nullptr);
} }
}); });
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment