Unverified Commit 7534546a authored by mvermeulen's avatar mvermeulen Committed by GitHub
Browse files

Merge pull request #344 from ROCmSoftwarePlatform/support_rocblas_api_change

Support rocblas api change
parents 8f9a766f 92d2b409
...@@ -8,51 +8,6 @@ namespace migraphx { ...@@ -8,51 +8,6 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
// Variadic pass-through to rocblas_gemm_ex: every argument is perfect-forwarded
// and the call's result is returned to the caller as a rocblas_status.
template <class... Args>
rocblas_status generic_rocblas_gemm_ex(Args&&... args)
{
    auto status = rocblas_gemm_ex(std::forward<Args>(args)...);
    return status;
}
// Variadic pass-through to rocblas_gemm_strided_batched_ex: every argument is
// perfect-forwarded and the call's result is returned as a rocblas_status.
template <class... Args>
rocblas_status generic_rocblas_batched_gemm_ex(Args&&... args)
{
    auto status = rocblas_gemm_strided_batched_ex(std::forward<Args>(args)...);
    return status;
}
// Type-mapping trait. The primary template is the identity mapping: a type
// with no dedicated specialization maps to itself.
template <class Scalar>
struct compute_rocblas_type
{
    using type = Scalar;
};

// Const-qualified inputs: map the unqualified type first, then re-apply the
// const qualifier to whatever that mapping produced.
template <class Scalar>
struct compute_rocblas_type<const Scalar>
{
    using type = typename compute_rocblas_type<Scalar>::type const;
};
// Full specialization of compute_rocblas_type for `half`: its mapped type is
// rocblas_half. Both `half` and `rocblas_half` are declared outside this view.
template <>
struct compute_rocblas_type<half>
{
using type = rocblas_half;
};
// Shorthand for the type that compute_rocblas_type maps Scalar to.
template <class Scalar>
using rb_type = typename compute_rocblas_type<Scalar>::type;
// Value overload: returns a copy of `x` viewed as its mapped type rb_type<Value>.
// When the mapping is the identity this is a plain copy.
// NOTE(review): for mapped types this reads `x` through a reference to a
// different type, which presumably relies on both types sharing the same
// size and representation -- confirm for every specialization.
template <class Value>
rb_type<Value> to_rocblas_type(Value x)
{
    using mapped_type = rb_type<Value>;
    const mapped_type& viewed = reinterpret_cast<const mapped_type&>(x);
    return viewed;
}
// Pointer overload: reinterprets `x` as a pointer to the mapped type
// rb_type<Pointee>. The address is unchanged, and a const-qualified pointee
// stays const because the type mapping re-applies the qualifier.
template <class Pointee>
rb_type<Pointee>* to_rocblas_type(Pointee* x)
{
    using mapped_pointer = rb_type<Pointee>*;
    return reinterpret_cast<mapped_pointer>(x);
}
shape rocblas_quant_gemm::compute_shape(const std::vector<shape>& inputs) const shape rocblas_quant_gemm::compute_shape(const std::vector<shape>& inputs) const
{ {
std::vector<shape> in_shapes(inputs); std::vector<shape> in_shapes(inputs);
...@@ -102,13 +57,13 @@ argument rocblas_quant_gemm::compute(context& ctx, ...@@ -102,13 +57,13 @@ argument rocblas_quant_gemm::compute(context& ctx,
auto a_lens = args[0].get_shape().lens(); auto a_lens = args[0].get_shape().lens();
auto b_lens = args[1].get_shape().lens(); auto b_lens = args[1].get_shape().lens();
output_shape.visit_type([&](auto as) { output_shape.visit_type([&](auto as) {
auto alpha_r = to_rocblas_type(as(op.alpha)); auto alpha_r = as(op.alpha);
auto beta_r = to_rocblas_type(as(beta)); auto beta_r = as(beta);
auto out_lens = output_shape.lens(); auto out_lens = output_shape.lens();
rocblas_int m = out_lens[dim_0]; rocblas_int m = out_lens[dim_0];
rocblas_int n = out_lens[dim_1]; rocblas_int n = out_lens[dim_1];
rocblas_int k = args[0].get_shape().lens()[dim_1]; rocblas_int k = args[0].get_shape().lens()[dim_1];
auto to_pointer = [&](auto&& arg) { return to_rocblas_type(as.from(arg.data())); }; auto to_pointer = [&](auto&& arg) { return as.from(arg.data()); };
assert(k % 4 == 0); assert(k % 4 == 0);
auto num_matrices = std::accumulate( auto num_matrices = std::accumulate(
...@@ -119,36 +74,36 @@ argument rocblas_quant_gemm::compute(context& ctx, ...@@ -119,36 +74,36 @@ argument rocblas_quant_gemm::compute(context& ctx,
// column-major format. When doing a C = A * B, we actually do // column-major format. When doing a C = A * B, we actually do
// C^T = (B^T) * (A^T). That is the reason we input args[1] as // C^T = (B^T) * (A^T). That is the reason we input args[1] as
// A and args[0] as B in calling the rocblas_gemm. // A and args[0] as B in calling the rocblas_gemm.
generic_rocblas_gemm_ex(ctx.get_stream().get_rocblas(), rocblas_gemm_ex(ctx.get_stream().get_rocblas(),
transb ? rocblas_operation_transpose : rocblas_operation_none, transb ? rocblas_operation_transpose : rocblas_operation_none,
transa ? rocblas_operation_transpose : rocblas_operation_none, transa ? rocblas_operation_transpose : rocblas_operation_none,
n, n,
m, m,
k, k,
&alpha_r, &alpha_r,
to_pointer(args.at(1)), to_pointer(args.at(1)),
rocblas_datatype_i8_r, rocblas_datatype_i8_r,
ldb, ldb,
to_pointer(args.at(0)), to_pointer(args.at(0)),
rocblas_datatype_i8_r, rocblas_datatype_i8_r,
lda, lda,
&beta_r, &beta_r,
to_pointer(args[2]), to_pointer(args[2]),
rocblas_datatype_i32_r, rocblas_datatype_i32_r,
ldc, ldc,
is_3inputs ? to_pointer(args[3]) : to_pointer(args[2]), is_3inputs ? to_pointer(args[3]) : to_pointer(args[2]),
rocblas_datatype_i32_r, rocblas_datatype_i32_r,
ldc, ldc,
rocblas_datatype_i32_r, rocblas_datatype_i32_r,
rocblas_gemm_algo_standard, rocblas_gemm_algo_standard,
0, 0,
0, 0,
nullptr, nullptr,
nullptr); nullptr);
} }
else else
{ {
generic_rocblas_batched_gemm_ex( rocblas_gemm_strided_batched_ex(
ctx.get_stream().get_rocblas(), ctx.get_stream().get_rocblas(),
transb ? rocblas_operation_transpose : rocblas_operation_none, transb ? rocblas_operation_transpose : rocblas_operation_none,
transa ? rocblas_operation_transpose : rocblas_operation_none, transa ? rocblas_operation_transpose : rocblas_operation_none,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment