Commit 352e6668 authored by Brian Pickrell
Browse files

Brian/Umang changes, still doesn't build

parent d46b6972
Pipeline #666 failed with stages
in 0 seconds
...@@ -108,16 +108,6 @@ static rocblas_int get_batch_stride(const argument& a) ...@@ -108,16 +108,6 @@ static rocblas_int get_batch_stride(const argument& a)
return a.get_shape().strides()[a.get_shape().strides().size() - 3]; return a.get_shape().strides()[a.get_shape().strides().size() - 3];
} }
/**
* The rocblas API calls we may be interested in. Each one takes a slightly different
* argument list, generated by create_gemm_args().
*
* create_gemm_args() switches on this value to decide which argument pack to build.
*/
enum ROCBLAS_CALL
{
// Non-batched call (rocblas_gemm_ex); create_gemm_args() folds the batch count
// into m (m *= num_matrices) before packing the arguments.
ROCBLAS_GEMM_EX,
// Batched call (rocblas_gemm_strided_batched_ex); its argument list also includes
// a batch stride per operand and num_matrices.
ROCBLAS_GEMM_STRIDED_BATCHED_EX,
// NOTE(review): presumably intended for rocblas_gemm_ex_get_solutions; the matching
// case in create_gemm_args() is still commented out in this revision.
ROCBLAS_GEMM_EX_GET_SOLUTIONS,
};
template <class T> template <class T>
void gemm_impl(context& ctx, void gemm_impl(context& ctx,
...@@ -129,6 +119,7 @@ void gemm_impl(context& ctx, ...@@ -129,6 +119,7 @@ void gemm_impl(context& ctx,
bool compute_fp32) bool compute_fp32)
{ {
output_shape.visit_type([&](auto as) { // TODO: not needed? output_shape.visit_type([&](auto as) { // TODO: not needed?
(void)as;
auto out_lens = output_shape.lens(); auto out_lens = output_shape.lens();
auto num_matrices = std::accumulate( auto num_matrices = std::accumulate(
out_lens.rbegin() + 2, out_lens.rend(), std::size_t{1}, std::multiplies<std::size_t>()); out_lens.rbegin() + 2, out_lens.rend(), std::size_t{1}, std::multiplies<std::size_t>());
...@@ -142,7 +133,7 @@ void gemm_impl(context& ctx, ...@@ -142,7 +133,7 @@ void gemm_impl(context& ctx,
// column-major format. When doing a C = A * B, we actually do // column-major format. When doing a C = A * B, we actually do
// C^T = (B^T) * (A^T). That is the reason we input args[1] as // C^T = (B^T) * (A^T). That is the reason we input args[1] as
// A and args[0] as B in calling the rocblas_gemm. // A and args[0] as B in calling the rocblas_gemm.
// auto to_invoke = auto to_invoke =
create_gemm_args(ctx, ROCBLAS_CALL::ROCBLAS_GEMM_EX, output_shape, args, create_gemm_args(ctx, ROCBLAS_CALL::ROCBLAS_GEMM_EX, output_shape, args,
alpha, beta, int8_x4_format, compute_fp32); alpha, beta, int8_x4_format, compute_fp32);
// rocblas_invoke(&rocblas_gemm_ex, // rocblas_invoke(&rocblas_gemm_ex,
...@@ -150,7 +141,7 @@ void gemm_impl(context& ctx, ...@@ -150,7 +141,7 @@ void gemm_impl(context& ctx,
} }
else else
{ {
// auto to_invoke = auto to_invoke =
create_gemm_args(ctx, ROCBLAS_CALL::ROCBLAS_GEMM_STRIDED_BATCHED_EX, create_gemm_args(ctx, ROCBLAS_CALL::ROCBLAS_GEMM_STRIDED_BATCHED_EX,
output_shape, args, alpha, beta, int8_x4_format, compute_fp32); output_shape, args, alpha, beta, int8_x4_format, compute_fp32);
// rocblas_invoke(&rocblas_gemm_strided_batched_ex, // rocblas_invoke(&rocblas_gemm_strided_batched_ex,
...@@ -225,7 +216,7 @@ void gemm(context& ctx, ...@@ -225,7 +216,7 @@ void gemm(context& ctx,
* a set of MigraphX arguments. * a set of MigraphX arguments.
*/ */
template <class T> template <class T>
auto create_gemm_args(context& ctx, static auto create_gemm_args(context& ctx,
ROCBLAS_CALL rocblas_call, ROCBLAS_CALL rocblas_call,
const shape& output_shape, const shape& output_shape,
const std::vector<argument>& inputs, const std::vector<argument>& inputs,
...@@ -273,25 +264,27 @@ auto create_gemm_args(context& ctx, ...@@ -273,25 +264,27 @@ auto create_gemm_args(context& ctx,
auto a_lens = inputs[0].get_shape().lens(); auto a_lens = inputs[0].get_shape().lens();
auto b_lens = inputs[1].get_shape().lens(); auto b_lens = inputs[1].get_shape().lens();
return output_shape.visit_type([&](auto as) { void * alpha_v = nullptr;
void* beta_v = nullptr;
output_shape.visit_type([&](auto as) {
auto alpha_r = as(alpha); auto alpha_r = as(alpha);
auto beta_r = as(beta); auto beta_r = as(beta);
// use void pointer to select different data type if using fp32 mode // use void pointer to select different data type if using fp32 mode
void* alpha_v = &alpha_r; alpha_v = &alpha_r;
void* beta_v = &beta_r; beta_v = &beta_r;
if(compute_fp32) if(compute_fp32)
{ {
alpha_v = &alpha; alpha_v = &alpha;
beta_v = &beta; beta_v = &beta;
} }
});
auto out_lens = output_shape.lens(); auto out_lens = output_shape.lens();
rocblas_int m = out_lens[dim_0]; rocblas_int m = out_lens[dim_0];
rocblas_int n = out_lens[dim_1]; rocblas_int n = out_lens[dim_1];
rocblas_int k = inputs[0].get_shape().lens()[dim_1]; rocblas_int k = inputs[0].get_shape().lens()[dim_1];
auto to_pointer = [&](auto&& arg) { return as.from(arg.data()); }; auto to_pointer = [&](auto&& arg) { return reinterpret_cast<T*>(arg.data()); };
if(inputs[0].get_shape().type() == shape::int8_type and (k % 4) != 0 and int8_x4_format) if(inputs[0].get_shape().type() == shape::int8_type and (k % 4) != 0 and int8_x4_format)
{ {
MIGRAPHX_THROW("create_gemm_args: k size of int8 type input must be multiple of 4!"); MIGRAPHX_THROW("create_gemm_args: k size of int8 type input must be multiple of 4!");
...@@ -302,7 +295,7 @@ auto create_gemm_args(context& ctx, ...@@ -302,7 +295,7 @@ auto create_gemm_args(context& ctx,
switch(rocblas_call){ switch(rocblas_call){
case ROCBLAS_GEMM_EX: case ROCBLAS_GEMM_EX:
{ {
m *= num_matrices; m *= num_matrices;
return pack( return pack(
...@@ -337,46 +330,49 @@ auto create_gemm_args(context& ctx, ...@@ -337,46 +330,49 @@ auto create_gemm_args(context& ctx,
0, 0,
flag); flag);
} }
// case ROCBLAS_GEMM_STRIDED_BATCHED_EX:
// { case ROCBLAS_GEMM_STRIDED_BATCHED_EX:
// auto a_stride = get_batch_stride(inputs[0]); default:
// auto b_stride = get_batch_stride(inputs[1]); {
// auto c_stride = get_batch_stride(inputs[2]); auto a_stride = get_batch_stride(inputs[0]);
// auto d_stride = is_3inputs ? get_batch_stride(inputs[3]) : c_stride; auto b_stride = get_batch_stride(inputs[1]);
// return pack( auto c_stride = get_batch_stride(inputs[2]);
// // rocblas_invoke( &rocblas_gemm_strided_batched_ex, auto d_stride = is_3inputs ? get_batch_stride(inputs[3]) : c_stride;
// ctx.get_stream().get_rocblas(), return pack(
// transb ? rocblas_operation_transpose : rocblas_operation_none, // rocblas_invoke( &rocblas_gemm_strided_batched_ex,
// transa ? rocblas_operation_transpose : rocblas_operation_none, ctx.get_stream().get_rocblas(),
// n, transb ? rocblas_operation_transpose : rocblas_operation_none,
// m, transa ? rocblas_operation_transpose : rocblas_operation_none,
// k, n,
// alpha_v, m,
// to_pointer(inputs.at(1)), k,
// arg_type, alpha_v,
// ldb, to_pointer(inputs.at(1)),
// b_stride, arg_type,
// to_pointer(inputs.at(0)), ldb,
// arg_type, b_stride,
// lda, to_pointer(inputs.at(0)),
// a_stride, arg_type,
// beta_v, lda,
// to_pointer(inputs[2]), a_stride,
// output_type, beta_v,
// ldc, to_pointer(inputs[2]),
// c_stride, output_type,
// is_3inputs ? to_pointer(inputs[3]) : to_pointer(inputs[2]), ldc,
// output_type, c_stride,
// ldd, is_3inputs ? to_pointer(inputs[3]) : to_pointer(inputs[2]),
// d_stride, output_type,
// num_matrices, ldd,
// compute_type, d_stride,
// rocblas_gemm_algo_standard, num_matrices,
// 0, compute_type,
// flag); rocblas_gemm_algo_standard,
// } 0,
flag);
}
// case ROCBLAS_GEMM_EX_GET_SOLUTIONS: // case ROCBLAS_GEMM_EX_GET_SOLUTIONS:
// default:
// // the original macro in rocBLAS-internal/rocBLAS/clients/samples/example_user_driven_tuning.cpp is // // the original macro in rocBLAS-internal/rocBLAS/clients/samples/example_user_driven_tuning.cpp is
// // Note different order of m, n, k // // Note different order of m, n, k
// // #define GEMM_EX_ARGS \ // // #define GEMM_EX_ARGS \
...@@ -386,10 +382,15 @@ auto create_gemm_args(context& ctx, ...@@ -386,10 +382,15 @@ auto create_gemm_args(context& ctx,
// #define GEMM_EX_ARGS \ // #define GEMM_EX_ARGS \
// handle, transa, transb, m, n, k, alpha_v, da, type, lda, db, type, ldb, beta_v, dc, type, ldc, \ // handle, transa, transb, m, n, k, alpha_v, da, type, lda, db, type, ldb, beta_v, dc, type, ldc, \
// dc, type, ldc, type, rocblas_gemm_algo_solution_index // dc, type, ldc, type, rocblas_gemm_algo_solution_index
// return pack(ctx.get_stream().get_rocblas()); // return pack(ctx.get_stream().get_rocblas());
// Get number of solutions
// rocblas_int size;
// CHECK_ROCBLAS_ERROR(
// rocblas_gemm_ex_get_solutions(GEMM_EX_ARGS, rocblas_gemm_flags_none, NULL, &size));
} // end switch
// default: // default:
// MIGRAPHX_THROW ("create_gemm_args(): rocBLAS command not supported"); // MIGRAPHX_THROW ("create_gemm_args(): rocBLAS command not supported");
}});
} }
} // namespace gpu } // namespace gpu
......
...@@ -31,6 +31,16 @@ ...@@ -31,6 +31,16 @@
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
/**
* The rocblas API calls we may be interested in. Each one takes a slightly different
* argument list, generated by create_gemm_args().
*
* create_gemm_args() switches on this value to decide which argument pack to build.
*/
enum ROCBLAS_CALL
{
// Non-batched call (rocblas_gemm_ex); create_gemm_args() folds the batch count
// into m (m *= num_matrices) before packing the arguments.
ROCBLAS_GEMM_EX,
// Batched call (rocblas_gemm_strided_batched_ex); its argument list also includes
// a batch stride per operand and num_matrices.
ROCBLAS_GEMM_STRIDED_BATCHED_EX,
// NOTE(review): presumably intended for rocblas_gemm_ex_get_solutions; the matching
// case in create_gemm_args() is still commented out in this revision.
ROCBLAS_GEMM_EX_GET_SOLUTIONS,
};
void gemm(context& ctx, void gemm(context& ctx,
const shape& output_shape, const shape& output_shape,
...@@ -47,13 +57,22 @@ void gemm(context& ctx, ...@@ -47,13 +57,22 @@ void gemm(context& ctx,
bool int8_x4_format, bool int8_x4_format,
bool compute_fp32); bool compute_fp32);
template <class T> // template <class T>
auto create_gemm_args(context& ctx, // auto create_gemm_args(context& ctx,
const std::vector<argument>& inputs); // const std::vector<argument>& inputs);
// The version with just shapes will use null pointers for the buffers // The version with just shapes will use null pointers for the buffers
template <class T> template <class T>
auto create_gemm_args(context& ctx, const std::vector<shape>& inputs); auto create_gemm_args(context& ctx, const std::vector<shape>& inputs);
/**
* Builds the argument list for the rocblas call selected by rocblas_call from
* a set of MigraphX arguments.
*
* @param ctx            GPU context; the rocblas handle is taken from its stream.
* @param rocblas_call   Which rocblas API the argument list is built for.
* @param output_shape   Shape of the result; its lens supply the m and n dimensions.
* @param inputs         Operand buffers: inputs[0] and inputs[1] are the A and B
*                       matrices, inputs[2] (and inputs[3] when present) the C/D buffers.
* @param alpha          Scalar multiplier passed to rocblas.
* @param beta           Scalar multiplier passed to rocblas.
* @param int8_x4_format When set and the inputs are int8, k must be a multiple of 4;
*                       otherwise the implementation throws.
* @param compute_fp32   When set, alpha and beta are handed to rocblas as T rather
*                       than converted to the output type.
* @return               The packed argument list (result of pack(...)).
*
* NOTE(review): this declaration is `static` and has a deduced (`auto`) return type,
* but no definition is visible in this header. A function with a deduced return type
* cannot be called before its definition is seen, and `static` gives each translation
* unit its own copy, so callers that only include this header will likely fail to
* compile -- confirm, and consider moving the definition into the header.
*/
template <class T>
static auto create_gemm_args(context& ctx,
ROCBLAS_CALL rocblas_call,
const shape& output_shape,
const std::vector<argument>& inputs,
T alpha,
T beta,
bool int8_x4_format,
bool compute_fp32);
} // namespace gpu } // namespace gpu
} // namespace MIGRAPHX_INLINE_NS } // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx } // namespace migraphx
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment