Commit d46b6972 authored by Brian Pickrell

create_gemm_args() work in progress.  This version builds only if I comment out all but 1 return from create_gemm_args().
parent 7c7e7ba8
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
@@ -108,6 +108,17 @@ static rocblas_int get_batch_stride(const argument& a)
return a.get_shape().strides()[a.get_shape().strides().size() - 3];
}
/**
* The rocblas API calls we may be interested in. Each one takes a slightly different
* argument list, generated by create_gemm_args().
*/
enum ROCBLAS_CALL
{
ROCBLAS_GEMM_EX,
ROCBLAS_GEMM_STRIDED_BATCHED_EX,
ROCBLAS_GEMM_EX_GET_SOLUTIONS,
};
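// Illustrative (not wired up in this WIP): the caller picks the enum value matching the
// rocBLAS entry point it intends to invoke, e.g.
//     create_gemm_args(ctx, ROCBLAS_CALL::ROCBLAS_GEMM_EX, output_shape, args,
//                      alpha, beta, int8_x4_format, compute_fp32);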
template <class T>
void gemm_impl(context& ctx,
const shape& output_shape,
@@ -117,69 +128,8 @@ void gemm_impl(context& ctx,
bool int8_x4_format,
bool compute_fp32)
{
const bool is_3inputs = (args.size() == 4);
if(not is_3inputs)
{
beta = 0;
}
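// With only two data inputs there is no C operand to accumulate, so beta is forced to zero.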
bool transa = is_transposed(args[0].get_shape());
bool transb = is_transposed(args[1].get_shape());
auto n_dim = output_shape.lens().size();
auto dim_1 = n_dim - 1;
auto dim_0 = n_dim - 2;
rocblas_int lda = args[0].get_shape().strides()[transa ? dim_1 : dim_0];
rocblas_int ldb = args[1].get_shape().strides()[transb ? dim_1 : dim_0];
rocblas_int ldc = args[2].get_shape().strides()[dim_0];
rocblas_int ldd = is_3inputs ? args[3].get_shape().strides()[dim_0] : ldc;
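// Leading dimensions come from the operands' strides; when there is no third input,
// args[2] is itself the output buffer, so ldd simply reuses ldc.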
rocblas_datatype arg_type = get_type(args[0].get_shape().type());
auto output_type = arg_type;
if(output_type == rocblas_datatype_i8_r)
{
output_type = rocblas_datatype_i32_r;
}
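// int8 GEMMs accumulate into 32-bit integers, so the output (and, below, compute) type
// is widened to i32.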
auto compute_type = output_type;
if(compute_fp32)
{
if(arg_type == rocblas_datatype_f16_r)
compute_type = rocblas_datatype_f32_r;
}
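// When compute_fp32 is requested, half-precision inputs are accumulated in fp32;
// otherwise the compute type matches the output type.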
#if ROCBLAS_VERSION_MAJOR >= 2 && ROCBLAS_VERSION_MINOR >= 38
rocblas_gemm_flags flag =
int8_x4_format ? rocblas_gemm_flags_pack_int8x4 : rocblas_gemm_flags_none;
#else
(void)int8_x4_format;
int flag = 0;
#endif
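// The pack-int8x4 flag is only used when the rocBLAS version guard above passes;
// older versions get a plain zero flag.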
auto a_lens = args[0].get_shape().lens();
auto b_lens = args[1].get_shape().lens();
output_shape.visit_type([&](auto as) {
auto alpha_r = as(alpha);
auto beta_r = as(beta);
// use void pointer to select different data type if using fp32 mode
void* alpha_v = &alpha_r;
void* beta_v = &beta_r;
if(compute_fp32)
{
alpha_v = &alpha;
beta_v = &beta;
}
auto out_lens = output_shape.lens();
rocblas_int m = out_lens[dim_0];
rocblas_int n = out_lens[dim_1];
rocblas_int k = args[0].get_shape().lens()[dim_1];
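// m and n are the last two dimensions of the (logical) output shape; k is the
// contraction dimension, i.e. the last logical dimension of A (args[0]).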
auto to_pointer = [&](auto&& arg) { return as.from(arg.data()); };
if(args[0].get_shape().type() == shape::int8_type and (k % 4) != 0 and int8_x4_format)
{
MIGRAPHX_THROW("ROCBLAS_GEMM: k size of int8 type input must be mutlple of 4!");
}
auto num_matrices = std::accumulate(
out_lens.rbegin() + 2, out_lens.rend(), std::size_t{1}, std::multiplies<std::size_t>());
if(num_matrices == 1 or (num_matrices > 1 and get_batch_stride(args[1]) == 0))
@@ -187,74 +137,24 @@ void gemm_impl(context& ctx,
// If the batch dimension of B is broadcasted, then we can
// multiply m by the batch_size and use rocblas_gemm_ex
// instead of rocblas_gemm_strided_batched_ex.
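// Illustrative example: for an output of lens {4, 64, 32} with B broadcast (batch
// stride 0), num_matrices is 4 and the whole batch collapses into one 256 x 32 GEMM.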
m *= num_matrices;
// the rocblas_gemm API handles inputs and output matrices as
// column-major format. When doing a C = A * B, we actually do
// C^T = (B^T) * (A^T). That is the reason we input args[1] as
// A and args[0] as B in calling the rocblas_gemm.
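// A row-major M x N matrix has the same memory layout as a column-major N x M matrix,
// so computing C^T = B^T * A^T in column-major terms yields row-major C without any
// explicit data movement.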
rocblas_invoke(&rocblas_gemm_ex,
ctx.get_stream().get_rocblas(),
transb ? rocblas_operation_transpose : rocblas_operation_none,
transa ? rocblas_operation_transpose : rocblas_operation_none,
n,
m,
k,
alpha_v,
to_pointer(args.at(1)),
arg_type,
ldb,
to_pointer(args.at(0)),
arg_type,
lda,
beta_v,
to_pointer(args[2]),
output_type,
ldc,
is_3inputs ? to_pointer(args[3]) : to_pointer(args[2]),
output_type,
ldd,
compute_type,
rocblas_gemm_algo_standard,
0,
flag);
// auto to_invoke =
create_gemm_args(ctx, ROCBLAS_CALL::ROCBLAS_GEMM_EX, output_shape, args,
alpha, beta, int8_x4_format, compute_fp32);
// rocblas_invoke(&rocblas_gemm_ex,
// to_invoke);
}
else
{
auto a_stride = get_batch_stride(args[0]);
auto b_stride = get_batch_stride(args[1]);
auto c_stride = get_batch_stride(args[2]);
auto d_stride = is_3inputs ? get_batch_stride(args[3]) : c_stride;
rocblas_invoke(&rocblas_gemm_strided_batched_ex,
ctx.get_stream().get_rocblas(),
transb ? rocblas_operation_transpose : rocblas_operation_none,
transa ? rocblas_operation_transpose : rocblas_operation_none,
n,
m,
k,
alpha_v,
to_pointer(args.at(1)),
arg_type,
ldb,
b_stride,
to_pointer(args.at(0)),
arg_type,
lda,
a_stride,
beta_v,
to_pointer(args[2]),
output_type,
ldc,
c_stride,
is_3inputs ? to_pointer(args[3]) : to_pointer(args[2]),
output_type,
ldd,
d_stride,
num_matrices,
compute_type,
rocblas_gemm_algo_standard,
0,
flag);
// auto to_invoke =
create_gemm_args(ctx, ROCBLAS_CALL::ROCBLAS_GEMM_STRIDED_BATCHED_EX,
output_shape, args, alpha, beta, int8_x4_format, compute_fp32);
// rocblas_invoke(&rocblas_gemm_strided_batched_ex,
// to_invoke);
}
});
}
@@ -268,7 +168,6 @@ std::vector<argument> generate_arguments(const std::vector<shape>& shapes, unsig
return args;
}
// from perf.cpp
using milliseconds = std::chrono::duration<double, std::milli>;
std::pair<double, double>
@@ -321,14 +220,13 @@ void gemm(context& ctx,
}
/**
* Create a list of the arguments needed for rocBLAS GEMM calls, from
* a set of MigraphX arguments.
*/
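// Intended consumption, per the commented-out call sites in gemm_impl() (a sketch,
// assuming migraphx::pack returns a callable that forwards its captured arguments):
//     auto to_invoke = create_gemm_args(ctx, ROCBLAS_CALL::ROCBLAS_GEMM_EX, output_shape,
//                                       args, alpha, beta, int8_x4_format, compute_fp32);
//     to_invoke([&](auto&&... xs) { rocblas_invoke(&rocblas_gemm_ex, xs...); });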
template <class T>
auto create_gemm_args(context& ctx,
ROCBLAS_CALL rocblas_call,
const shape& output_shape,
const std::vector<argument>& inputs,
T alpha,
@@ -396,18 +294,18 @@ auto create_gemm_args
auto to_pointer = [&](auto&& arg) { return as.from(arg.data()); };
if(inputs[0].get_shape().type() == shape::int8_type and (k % 4) != 0 and int8_x4_format)
{
MIGRAPHX_THROW("ROCBLAS_GEMM: k size of int8 type input must be mutlple of 4!");
MIGRAPHX_THROW("create_gemm_args: k size of int8 type input must be multiple of 4!");
}
auto num_matrices = std::accumulate(
out_lens.rbegin() + 2, out_lens.rend(), std::size_t{1}, std::multiplies<std::size_t>());
if(num_matrices == 1 or (num_matrices > 1 and get_batch_stride(inputs[1]) == 0))
{
// If the batch dimension of B is broadcasted, then we can
// multiply m by the batch_size and use rocblas_gemm_ex
// instead of rocblas_gemm_strided_batched_ex.
switch(rocblas_call)
{
case ROCBLAS_GEMM_EX:
{
m *= num_matrices;
return pack(
// the rocblas_gemm API handles inputs and output matrices as
// column-major format. When doing a C = A * B, we actually do
@@ -439,45 +337,59 @@ auto create_gemm_args
0,
flag);
}
else
{
auto a_stride = get_batch_stride(inputs[0]);
auto b_stride = get_batch_stride(inputs[1]);
auto c_stride = get_batch_stride(inputs[2]);
auto d_stride = is_3inputs ? get_batch_stride(inputs[3]) : c_stride;
return pack (
// rocblas_invoke(&rocblas_gemm_strided_batched_ex,
ctx.get_stream().get_rocblas(),
transb ? rocblas_operation_transpose : rocblas_operation_none,
transa ? rocblas_operation_transpose : rocblas_operation_none,
n,
m,
k,
alpha_v,
to_pointer(inputs.at(1)),
arg_type,
ldb,
b_stride,
to_pointer(inputs.at(0)),
arg_type,
lda,
a_stride,
beta_v,
to_pointer(inputs[2]),
output_type,
ldc,
c_stride,
is_3inputs ? to_pointer(inputs[3]) : to_pointer(inputs[2]),
output_type,
ldd,
d_stride,
num_matrices,
compute_type,
rocblas_gemm_algo_standard,
0,
flag);
}
});
// case ROCBLAS_GEMM_STRIDED_BATCHED_EX:
// {
// auto a_stride = get_batch_stride(inputs[0]);
// auto b_stride = get_batch_stride(inputs[1]);
// auto c_stride = get_batch_stride(inputs[2]);
// auto d_stride = is_3inputs ? get_batch_stride(inputs[3]) : c_stride;
// return pack(
// // rocblas_invoke( &rocblas_gemm_strided_batched_ex,
// ctx.get_stream().get_rocblas(),
// transb ? rocblas_operation_transpose : rocblas_operation_none,
// transa ? rocblas_operation_transpose : rocblas_operation_none,
// n,
// m,
// k,
// alpha_v,
// to_pointer(inputs.at(1)),
// arg_type,
// ldb,
// b_stride,
// to_pointer(inputs.at(0)),
// arg_type,
// lda,
// a_stride,
// beta_v,
// to_pointer(inputs[2]),
// output_type,
// ldc,
// c_stride,
// is_3inputs ? to_pointer(inputs[3]) : to_pointer(inputs[2]),
// output_type,
// ldd,
// d_stride,
// num_matrices,
// compute_type,
// rocblas_gemm_algo_standard,
// 0,
// flag);
// }
// case ROCBLAS_GEMM_EX_GET_SOLUTIONS:
// // the original macro in rocBLAS-internal/rocBLAS/clients/samples/example_user_driven_tuning.cpp is
// // Note different order of m, n, k
// // #define GEMM_EX_ARGS \
// // handle, transa, transb, m, n, k, &alpha, da, type, lda, db, type, ldb, &beta, dc, type, ldc,
// // \
// // dc, type, ldc, type, rocblas_gemm_algo_solution_index
// #define GEMM_EX_ARGS \
// handle, transa, transb, m, n, k, alpha_v, da, type, lda, db, type, ldb, beta_v, dc, type, ldc, \
// dc, type, ldc, type, rocblas_gemm_algo_solution_index
// return pack(ctx.get_stream().get_rocblas());
// default:
// MIGRAPHX_THROW ("create_gemm_args(): rocBLAS command not supported");
}});
}
} // namespace gpu