Commit 352e6668 authored by Brian Pickrell's avatar Brian Pickrell
Browse files

Brian/Umang changes, still doesn't build

parent d46b6972
Pipeline #666 failed with stages
in 0 seconds
......@@ -108,16 +108,6 @@ static rocblas_int get_batch_stride(const argument& a)
return a.get_shape().strides()[a.get_shape().strides().size() - 3];
}
/**
* The rocblas API calls we may be interested in. Each one takes a slightly different
* argument list, generated by create_gemm_args().
*/
enum ROCBLAS_CALL
{
ROCBLAS_GEMM_EX,
ROCBLAS_GEMM_STRIDED_BATCHED_EX,
ROCBLAS_GEMM_EX_GET_SOLUTIONS,
};
template <class T>
void gemm_impl(context& ctx,
......@@ -129,6 +119,7 @@ void gemm_impl(context& ctx,
bool compute_fp32)
{
output_shape.visit_type([&](auto as) { // TODO: not needed?
(void)as;
auto out_lens = output_shape.lens();
auto num_matrices = std::accumulate(
out_lens.rbegin() + 2, out_lens.rend(), std::size_t{1}, std::multiplies<std::size_t>());
......@@ -142,7 +133,7 @@ void gemm_impl(context& ctx,
// column-major format. When doing a C = A * B, we actually do
// C^T = (B^T) * (A^T). That is the reason we input args[1] as
// A and args[0] as B in calling the rocblas_gemm.
// auto to_invoke =
auto to_invoke =
create_gemm_args(ctx, ROCBLAS_CALL::ROCBLAS_GEMM_EX, output_shape, args,
alpha, beta, int8_x4_format, compute_fp32);
// rocblas_invoke(&rocblas_gemm_ex,
......@@ -150,7 +141,7 @@ void gemm_impl(context& ctx,
}
else
{
// auto to_invoke =
auto to_invoke =
create_gemm_args(ctx, ROCBLAS_CALL::ROCBLAS_GEMM_STRIDED_BATCHED_EX,
output_shape, args, alpha, beta, int8_x4_format, compute_fp32);
// rocblas_invoke(&rocblas_gemm_strided_batched_ex,
......@@ -225,7 +216,7 @@ void gemm(context& ctx,
* a set of MigraphX arguments.
*/
template <class T>
auto create_gemm_args(context& ctx,
static auto create_gemm_args(context& ctx,
ROCBLAS_CALL rocblas_call,
const shape& output_shape,
const std::vector<argument>& inputs,
......@@ -273,25 +264,27 @@ auto create_gemm_args(context& ctx,
auto a_lens = inputs[0].get_shape().lens();
auto b_lens = inputs[1].get_shape().lens();
return output_shape.visit_type([&](auto as) {
void * alpha_v = nullptr;
void* beta_v = nullptr;
output_shape.visit_type([&](auto as) {
auto alpha_r = as(alpha);
auto beta_r = as(beta);
// use void pointer to select different data type if using fp32 mode
void* alpha_v = &alpha_r;
void* beta_v = &beta_r;
alpha_v = &alpha_r;
beta_v = &beta_r;
if(compute_fp32)
{
alpha_v = &alpha;
beta_v = &beta;
}
});
auto out_lens = output_shape.lens();
rocblas_int m = out_lens[dim_0];
rocblas_int n = out_lens[dim_1];
rocblas_int k = inputs[0].get_shape().lens()[dim_1];
auto to_pointer = [&](auto&& arg) { return as.from(arg.data()); };
auto to_pointer = [&](auto&& arg) { return reinterpret_cast<T*>(arg.data()); };
if(inputs[0].get_shape().type() == shape::int8_type and (k % 4) != 0 and int8_x4_format)
{
MIGRAPHX_THROW("create_gemm_args: k size of int8 type input must be multiple of 4!");
......@@ -302,7 +295,7 @@ auto create_gemm_args(context& ctx,
switch(rocblas_call){
case ROCBLAS_GEMM_EX:
{
{
m *= num_matrices;
return pack(
......@@ -337,46 +330,49 @@ auto create_gemm_args(context& ctx,
0,
flag);
}
// case ROCBLAS_GEMM_STRIDED_BATCHED_EX:
// {
// auto a_stride = get_batch_stride(inputs[0]);
// auto b_stride = get_batch_stride(inputs[1]);
// auto c_stride = get_batch_stride(inputs[2]);
// auto d_stride = is_3inputs ? get_batch_stride(inputs[3]) : c_stride;
// return pack(
// // rocblas_invoke( &rocblas_gemm_strided_batched_ex,
// ctx.get_stream().get_rocblas(),
// transb ? rocblas_operation_transpose : rocblas_operation_none,
// transa ? rocblas_operation_transpose : rocblas_operation_none,
// n,
// m,
// k,
// alpha_v,
// to_pointer(inputs.at(1)),
// arg_type,
// ldb,
// b_stride,
// to_pointer(inputs.at(0)),
// arg_type,
// lda,
// a_stride,
// beta_v,
// to_pointer(inputs[2]),
// output_type,
// ldc,
// c_stride,
// is_3inputs ? to_pointer(inputs[3]) : to_pointer(inputs[2]),
// output_type,
// ldd,
// d_stride,
// num_matrices,
// compute_type,
// rocblas_gemm_algo_standard,
// 0,
// flag);
// }
case ROCBLAS_GEMM_STRIDED_BATCHED_EX:
default:
{
auto a_stride = get_batch_stride(inputs[0]);
auto b_stride = get_batch_stride(inputs[1]);
auto c_stride = get_batch_stride(inputs[2]);
auto d_stride = is_3inputs ? get_batch_stride(inputs[3]) : c_stride;
return pack(
// rocblas_invoke( &rocblas_gemm_strided_batched_ex,
ctx.get_stream().get_rocblas(),
transb ? rocblas_operation_transpose : rocblas_operation_none,
transa ? rocblas_operation_transpose : rocblas_operation_none,
n,
m,
k,
alpha_v,
to_pointer(inputs.at(1)),
arg_type,
ldb,
b_stride,
to_pointer(inputs.at(0)),
arg_type,
lda,
a_stride,
beta_v,
to_pointer(inputs[2]),
output_type,
ldc,
c_stride,
is_3inputs ? to_pointer(inputs[3]) : to_pointer(inputs[2]),
output_type,
ldd,
d_stride,
num_matrices,
compute_type,
rocblas_gemm_algo_standard,
0,
flag);
}
// case ROCBLAS_GEMM_EX_GET_SOLUTIONS:
// default:
// // the original macro in rocBLAS-internal/rocBLAS/clients/samples/example_user_driven_tuning.cpp is
// // Note different order of m, n, k
// // #define GEMM_EX_ARGS \
......@@ -386,10 +382,15 @@ auto create_gemm_args(context& ctx,
// #define GEMM_EX_ARGS \
// handle, transa, transb, m, n, k, alpha_v, da, type, lda, db, type, ldb, beta_v, dc, type, ldc, \
// dc, type, ldc, type, rocblas_gemm_algo_solution_index
// return pack(ctx.get_stream().get_rocblas());
// return pack(ctx.get_stream().get_rocblas());
// Get number of solutions
// rocblas_int size;
// CHECK_ROCBLAS_ERROR(
// rocblas_gemm_ex_get_solutions(GEMM_EX_ARGS, rocblas_gemm_flags_none, NULL, &size));
} // end switch
// default:
// MIGRAPHX_THROW ("create_gemm_args(): rocBLAS command not supported");
}});
}
} // namespace gpu
......
......@@ -31,6 +31,16 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
/**
* The rocblas API calls we may be interested in. Each one takes a slightly different
* argument list, generated by create_gemm_args().
*/
enum ROCBLAS_CALL
{
ROCBLAS_GEMM_EX,
ROCBLAS_GEMM_STRIDED_BATCHED_EX,
ROCBLAS_GEMM_EX_GET_SOLUTIONS,
};
void gemm(context& ctx,
const shape& output_shape,
......@@ -47,13 +57,22 @@ void gemm(context& ctx,
bool int8_x4_format,
bool compute_fp32);
template <class T>
auto create_gemm_args(context& ctx,
const std::vector<argument>& inputs);
// template <class T>
// auto create_gemm_args(context& ctx,
// const std::vector<argument>& inputs);
// The version with just shapes will use null pointers for the buffers
template <class T>
auto create_gemm_args(context& ctx, const std::vector<shape>& inputs);
template <class T>
static auto create_gemm_args(context& ctx,
ROCBLAS_CALL rocblas_call,
const shape& output_shape,
const std::vector<argument>& inputs,
T alpha,
T beta,
bool int8_x4_format,
bool compute_fp32);
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment