Commit d46b6972 authored by Brian Pickrell's avatar Brian Pickrell
Browse files

create_gemm_args() work in progress. This version builds only if I comment...

create_gemm_args() work in progress.  This version builds only if I comment out all but 1 return from create_gemm_args().
parent 7c7e7ba8
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
...@@ -108,6 +108,17 @@ static rocblas_int get_batch_stride(const argument& a) ...@@ -108,6 +108,17 @@ static rocblas_int get_batch_stride(const argument& a)
return a.get_shape().strides()[a.get_shape().strides().size() - 3]; return a.get_shape().strides()[a.get_shape().strides().size() - 3];
} }
/**
* The rocblas API calls we may be interested in. Each one takes a slightly different
* argument list, generated by create_gemm_args().
*/
enum ROCBLAS_CALL
{
ROCBLAS_GEMM_EX,
ROCBLAS_GEMM_STRIDED_BATCHED_EX,
ROCBLAS_GEMM_EX_GET_SOLUTIONS,
};
template <class T> template <class T>
void gemm_impl(context& ctx, void gemm_impl(context& ctx,
const shape& output_shape, const shape& output_shape,
...@@ -117,69 +128,8 @@ void gemm_impl(context& ctx, ...@@ -117,69 +128,8 @@ void gemm_impl(context& ctx,
bool int8_x4_format, bool int8_x4_format,
bool compute_fp32) bool compute_fp32)
{ {
const bool is_3inputs = (args.size() == 4); output_shape.visit_type([&](auto as) { // TODO: not needed?
if(not is_3inputs)
{
beta = 0;
}
bool transa = is_transposed(args[0].get_shape());
bool transb = is_transposed(args[1].get_shape());
auto n_dim = output_shape.lens().size();
auto dim_1 = n_dim - 1;
auto dim_0 = n_dim - 2;
rocblas_int lda = args[0].get_shape().strides()[transa ? dim_1 : dim_0];
rocblas_int ldb = args[1].get_shape().strides()[transb ? dim_1 : dim_0];
rocblas_int ldc = args[2].get_shape().strides()[dim_0];
rocblas_int ldd = is_3inputs ? args[3].get_shape().strides()[dim_0] : ldc;
rocblas_datatype arg_type = get_type(args[0].get_shape().type());
auto output_type = arg_type;
if(output_type == rocblas_datatype_i8_r)
{
output_type = rocblas_datatype_i32_r;
}
auto compute_type = output_type;
if(compute_fp32)
{
if(arg_type == rocblas_datatype_f16_r)
compute_type = rocblas_datatype_f32_r;
}
#if ROCBLAS_VERSION_MAJOR >= 2 && ROCBLAS_VERSION_MINOR >= 38
rocblas_gemm_flags flag =
int8_x4_format ? rocblas_gemm_flags_pack_int8x4 : rocblas_gemm_flags_none;
#else
(void)int8_x4_format;
int flag = 0;
#endif
auto a_lens = args[0].get_shape().lens();
auto b_lens = args[1].get_shape().lens();
output_shape.visit_type([&](auto as) {
auto alpha_r = as(alpha);
auto beta_r = as(beta);
// use void pointer to select different data type if using fp32 mode
void* alpha_v = &alpha_r;
void* beta_v = &beta_r;
if(compute_fp32)
{
alpha_v = &alpha;
beta_v = &beta;
}
auto out_lens = output_shape.lens(); auto out_lens = output_shape.lens();
rocblas_int m = out_lens[dim_0];
rocblas_int n = out_lens[dim_1];
rocblas_int k = args[0].get_shape().lens()[dim_1];
auto to_pointer = [&](auto&& arg) { return as.from(arg.data()); };
if(args[0].get_shape().type() == shape::int8_type and (k % 4) != 0 and int8_x4_format)
{
MIGRAPHX_THROW("ROCBLAS_GEMM: k size of int8 type input must be mutlple of 4!");
}
auto num_matrices = std::accumulate( auto num_matrices = std::accumulate(
out_lens.rbegin() + 2, out_lens.rend(), std::size_t{1}, std::multiplies<std::size_t>()); out_lens.rbegin() + 2, out_lens.rend(), std::size_t{1}, std::multiplies<std::size_t>());
if(num_matrices == 1 or (num_matrices > 1 and get_batch_stride(args[1]) == 0)) if(num_matrices == 1 or (num_matrices > 1 and get_batch_stride(args[1]) == 0))
...@@ -187,74 +137,24 @@ void gemm_impl(context& ctx, ...@@ -187,74 +137,24 @@ void gemm_impl(context& ctx,
// If the batch dimension of B is broadcasted, then we can // If the batch dimension of B is broadcasted, then we can
// multiply m by the batch_size and use rocblas_gemm_ex // multiply m by the batch_size and use rocblas_gemm_ex
// instead of rocblas_gemm_strided_batched_ex. // instead of rocblas_gemm_strided_batched_ex.
m *= num_matrices;
// the rocblas_gemm API handles inputs and output matrices as // the rocblas_gemm API handles inputs and output matrices as
// column-major format. When doing a C = A * B, we actually do // column-major format. When doing a C = A * B, we actually do
// C^T = (B^T) * (A^T). That is the reason we input args[1] as // C^T = (B^T) * (A^T). That is the reason we input args[1] as
// A and args[0] as B in calling the rocblas_gemm. // A and args[0] as B in calling the rocblas_gemm.
rocblas_invoke(&rocblas_gemm_ex, // auto to_invoke =
ctx.get_stream().get_rocblas(), create_gemm_args(ctx, ROCBLAS_CALL::ROCBLAS_GEMM_EX, output_shape, args,
transb ? rocblas_operation_transpose : rocblas_operation_none, alpha, beta, int8_x4_format, compute_fp32);
transa ? rocblas_operation_transpose : rocblas_operation_none, // rocblas_invoke(&rocblas_gemm_ex,
n, // to_invoke);
m,
k,
alpha_v,
to_pointer(args.at(1)),
arg_type,
ldb,
to_pointer(args.at(0)),
arg_type,
lda,
beta_v,
to_pointer(args[2]),
output_type,
ldc,
is_3inputs ? to_pointer(args[3]) : to_pointer(args[2]),
output_type,
ldd,
compute_type,
rocblas_gemm_algo_standard,
0,
flag);
} }
else else
{ {
auto a_stride = get_batch_stride(args[0]); // auto to_invoke =
auto b_stride = get_batch_stride(args[1]); create_gemm_args(ctx, ROCBLAS_CALL::ROCBLAS_GEMM_STRIDED_BATCHED_EX,
auto c_stride = get_batch_stride(args[2]); output_shape, args, alpha, beta, int8_x4_format, compute_fp32);
auto d_stride = is_3inputs ? get_batch_stride(args[3]) : c_stride; // rocblas_invoke(&rocblas_gemm_strided_batched_ex,
rocblas_invoke(&rocblas_gemm_strided_batched_ex, // to_invoke);
ctx.get_stream().get_rocblas(),
transb ? rocblas_operation_transpose : rocblas_operation_none,
transa ? rocblas_operation_transpose : rocblas_operation_none,
n,
m,
k,
alpha_v,
to_pointer(args.at(1)),
arg_type,
ldb,
b_stride,
to_pointer(args.at(0)),
arg_type,
lda,
a_stride,
beta_v,
to_pointer(args[2]),
output_type,
ldc,
c_stride,
is_3inputs ? to_pointer(args[3]) : to_pointer(args[2]),
output_type,
ldd,
d_stride,
num_matrices,
compute_type,
rocblas_gemm_algo_standard,
0,
flag);
} }
}); });
} }
...@@ -268,7 +168,6 @@ std::vector<argument> generate_arguments(const std::vector<shape>& shapes, unsig ...@@ -268,7 +168,6 @@ std::vector<argument> generate_arguments(const std::vector<shape>& shapes, unsig
return args; return args;
} }
// from perf.cpp // from perf.cpp
using milliseconds = std::chrono::duration<double, std::milli>; using milliseconds = std::chrono::duration<double, std::milli>;
std::pair<double, double> std::pair<double, double>
...@@ -321,20 +220,19 @@ void gemm(context& ctx, ...@@ -321,20 +220,19 @@ void gemm(context& ctx,
} }
/** /**
* Create a list of the arguments needed for rocBLAS GEMM calls, from * Create a list of the arguments needed for rocBLAS GEMM calls, from
* a set of MigraphX arguments. * a set of MigraphX arguments.
*/ */
template <class T> template <class T>
auto create_gemm_args auto create_gemm_args(context& ctx,
(context& ctx, ROCBLAS_CALL rocblas_call,
const shape& output_shape, const shape& output_shape,
const std::vector<argument>& inputs, const std::vector<argument>& inputs,
T alpha, T alpha,
T beta, T beta,
bool int8_x4_format, bool int8_x4_format,
bool compute_fp32) bool compute_fp32)
{ {
const bool is_3inputs = (inputs.size() == 4); const bool is_3inputs = (inputs.size() == 4);
if(not is_3inputs) if(not is_3inputs)
...@@ -396,88 +294,102 @@ auto create_gemm_args ...@@ -396,88 +294,102 @@ auto create_gemm_args
auto to_pointer = [&](auto&& arg) { return as.from(arg.data()); }; auto to_pointer = [&](auto&& arg) { return as.from(arg.data()); };
if(inputs[0].get_shape().type() == shape::int8_type and (k % 4) != 0 and int8_x4_format) if(inputs[0].get_shape().type() == shape::int8_type and (k % 4) != 0 and int8_x4_format)
{ {
MIGRAPHX_THROW("ROCBLAS_GEMM: k size of int8 type input must be mutlple of 4!"); MIGRAPHX_THROW("create_gemm_args: k size of int8 type input must be multiple of 4!");
} }
auto num_matrices = std::accumulate( auto num_matrices = std::accumulate(
out_lens.rbegin() + 2, out_lens.rend(), std::size_t{1}, std::multiplies<std::size_t>()); out_lens.rbegin() + 2, out_lens.rend(), std::size_t{1}, std::multiplies<std::size_t>());
if(num_matrices == 1 or (num_matrices > 1 and get_batch_stride(inputs[1]) == 0))
{
// If the batch dimension of B is broadcasted, then we can
// multiply m by the batch_size and use rocblas_gemm_ex
// instead of rocblas_gemm_strided_batched_ex.
m *= num_matrices;
return pack (
// the rocblas_gemm API handles inputs and output matrices as switch(rocblas_call){
// column-major format. When doing a C = A * B, we actually do case ROCBLAS_GEMM_EX:
// C^T = (B^T) * (A^T). That is the reason we input inputs[1] as {
// A and inputs[0] as B in calling the rocblas_gemm. m *= num_matrices;
// rocblas_invoke(&rocblas_gemm_ex,
ctx.get_stream().get_rocblas(), return pack(
transb ? rocblas_operation_transpose : rocblas_operation_none,
transa ? rocblas_operation_transpose : rocblas_operation_none, // the rocblas_gemm API handles inputs and output matrices as
n, // column-major format. When doing a C = A * B, we actually do
m, // C^T = (B^T) * (A^T). That is the reason we input inputs[1] as
k, // A and inputs[0] as B in calling the rocblas_gemm.
alpha_v, // rocblas_invoke(&rocblas_gemm_ex,
to_pointer(inputs.at(1)), ctx.get_stream().get_rocblas(),
arg_type, transb ? rocblas_operation_transpose : rocblas_operation_none,
ldb, transa ? rocblas_operation_transpose : rocblas_operation_none,
to_pointer(inputs.at(0)), n,
arg_type, m,
lda, k,
beta_v, alpha_v,
to_pointer(inputs[2]), to_pointer(inputs.at(1)),
output_type, arg_type,
ldc, ldb,
is_3inputs ? to_pointer(inputs[3]) : to_pointer(inputs[2]), to_pointer(inputs.at(0)),
output_type, arg_type,
ldd, lda,
compute_type, beta_v,
rocblas_gemm_algo_standard, to_pointer(inputs[2]),
0, output_type,
flag); ldc,
} is_3inputs ? to_pointer(inputs[3]) : to_pointer(inputs[2]),
else output_type,
{ ldd,
auto a_stride = get_batch_stride(inputs[0]); compute_type,
auto b_stride = get_batch_stride(inputs[1]); rocblas_gemm_algo_standard,
auto c_stride = get_batch_stride(inputs[2]); 0,
auto d_stride = is_3inputs ? get_batch_stride(inputs[3]) : c_stride; flag);
return pack ( }
// rocblas_invoke(&rocblas_gemm_strided_batched_ex, // case ROCBLAS_GEMM_STRIDED_BATCHED_EX:
ctx.get_stream().get_rocblas(), // {
transb ? rocblas_operation_transpose : rocblas_operation_none, // auto a_stride = get_batch_stride(inputs[0]);
transa ? rocblas_operation_transpose : rocblas_operation_none, // auto b_stride = get_batch_stride(inputs[1]);
n, // auto c_stride = get_batch_stride(inputs[2]);
m, // auto d_stride = is_3inputs ? get_batch_stride(inputs[3]) : c_stride;
k, // return pack(
alpha_v, // // rocblas_invoke( &rocblas_gemm_strided_batched_ex,
to_pointer(inputs.at(1)), // ctx.get_stream().get_rocblas(),
arg_type, // transb ? rocblas_operation_transpose : rocblas_operation_none,
ldb, // transa ? rocblas_operation_transpose : rocblas_operation_none,
b_stride, // n,
to_pointer(inputs.at(0)), // m,
arg_type, // k,
lda, // alpha_v,
a_stride, // to_pointer(inputs.at(1)),
beta_v, // arg_type,
to_pointer(inputs[2]), // ldb,
output_type, // b_stride,
ldc, // to_pointer(inputs.at(0)),
c_stride, // arg_type,
is_3inputs ? to_pointer(inputs[3]) : to_pointer(inputs[2]), // lda,
output_type, // a_stride,
ldd, // beta_v,
d_stride, // to_pointer(inputs[2]),
num_matrices, // output_type,
compute_type, // ldc,
rocblas_gemm_algo_standard, // c_stride,
0, // is_3inputs ? to_pointer(inputs[3]) : to_pointer(inputs[2]),
flag); // output_type,
} // ldd,
}); // d_stride,
// num_matrices,
// compute_type,
// rocblas_gemm_algo_standard,
// 0,
// flag);
// }
// case ROCBLAS_GEMM_EX_GET_SOLUTIONS:
// // the original macro in rocBLAS-internal/rocBLAS/clients/samples/example_user_driven_tuning.cpp is
// // Note different order of m, n, k
// // #define GEMM_EX_ARGS \
// // handle, transa, transb, m, n, k, &alpha, da, type, lda, db, type, ldb, &beta, dc, type, ldc,
// // \
// // dc, type, ldc, type, rocblas_gemm_algo_solution_index
// #define GEMM_EX_ARGS \
// handle, transa, transb, m, n, k, alpha_v, da, type, lda, db, type, ldb, beta_v, dc, type, ldc, \
// dc, type, ldc, type, rocblas_gemm_algo_solution_index
// return pack(ctx.get_stream().get_rocblas());
// default:
// MIGRAPHX_THROW ("create_gemm_args(): rocBLAS command not supported");
}});
} }
} // namespace gpu } // namespace gpu
......
/* /*
* The MIT License (MIT) * The MIT License (MIT)
* *
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal * of this software and associated documentation files (the "Software"), to deal
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment