Commit 7c7e7ba8 authored by Brian Pickrell
Browse files

work in progress. Add function definitions generate_arguments(), time_op(),...

work in progress.  Add function definitions generate_arguments(), time_op(), create_gemm_args().  Compiles, but not functional.
parent d478675c
......@@ -26,6 +26,12 @@
#include <migraphx/reduce_dims.hpp>
#include <migraphx/permutation.hpp>
// #include <migraphx/config.hpp>
// #include <migraphx/gpu/context.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/time.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
......@@ -253,6 +259,45 @@ void gemm_impl(context& ctx,
});
}
std::vector<argument> generate_arguments(const std::vector<shape>& shapes, unsigned long seed = 0)
{
std::vector<argument> args;
std::transform(shapes.begin(), shapes.end(), std::back_inserter(args), [&](auto& s) {
return to_gpu(generate_argument(s, seed++));
});
return args;
}
// from perf.cpp
using milliseconds = std::chrono::duration<double, std::milli>;

/**
 * Run `op` repeatedly on generated inputs and report its average run time.
 *
 * The op is executed once untimed after perf measurement is enabled, then
 * `n` more times. For every timed run the host-side duration (measured with
 * time<milliseconds>) and the value of gctx.get_elapsed_ms() are accumulated.
 *
 * @param ictx   GPU context; it is copied into a type-erased migraphx::context
 * @param op     operation to time
 * @param inputs input shapes; arguments are created with generate_arguments()
 * @param n      number of timed runs. Not validated: both sums are divided by
 *               `n`, so n == 0 divides by zero.
 * @return {average host time, average device time} over the `n` timed runs
 */
std::pair<double, double>
time_op(context& ictx, operation op, const std::vector<shape>& inputs, int n)
{
    // TODO: Use std::ref
    migraphx::context ctx = ictx;
    auto& gctx            = any_cast<migraphx::gpu::context>(ctx);

    auto out_shape = op.compute_shape(inputs);
    // op.finalize(ctx, output, inputs);
    auto gpu_args = generate_arguments(inputs);

    // One complete execution: launch the op, then wait for the context to finish.
    auto run_once = [&] {
        op.compute(ctx, out_shape, gpu_args);
        ctx.finish();
    };

    gctx.enable_perf_measurement();
    // Untimed first run, executed before any measurement is accumulated.
    run_once();

    double host_total   = 0.0;
    double device_total = 0.0;
    for(auto iteration : range(n))
    {
        (void)iteration;
        host_total += time<milliseconds>(run_once);
        device_total += gctx.get_elapsed_ms();
    }
    return {host_total / n, device_total / n};
}
void gemm(context& ctx,
const shape& output_shape,
const std::vector<argument>& args,
......@@ -275,6 +320,166 @@ void gemm(context& ctx,
gemm_impl(ctx, output_shape, args, alpha, beta, int8_x4_format, compute_fp32);
}
/**
 * Create a list of the arguments needed for rocBLAS GEMM calls, from
 * a set of MIGraphX arguments.
 *
 * The argument list is assembled in the order used by rocblas_gemm_ex (single
 * matrix, or batch folded into m) or rocblas_gemm_strided_batched_ex
 * (otherwise) and handed to pack().
 *
 * @param ctx            GPU context; supplies the rocBLAS handle via
 *                       ctx.get_stream().get_rocblas()
 * @param output_shape   shape of the GEMM result; its last two dimensions give
 *                       m and n, the leading dimensions give the batch count
 * @param inputs         inputs[0] = A, inputs[1] = B, inputs[2] = C. When
 *                       exactly four inputs are given, inputs[3] is used as
 *                       the D (output) buffer; otherwise inputs[2] is used for
 *                       both C and D. At least three inputs are required.
 * @param alpha          scalar multiplier passed as rocBLAS alpha
 * @param beta           scalar multiplier passed as rocBLAS beta; forced to 0
 *                       unless exactly four inputs are given
 * @param int8_x4_format selects rocblas_gemm_flags_pack_int8x4 when the rocBLAS
 *                       version provides it; ignored otherwise
 * @param compute_fp32   when true, f16 inputs use an f32 compute type and
 *                       alpha/beta are passed unconverted
 * @return whatever output_shape.visit_type() returns for the visitor below
 * @throws via MIGRAPHX_THROW when the input is int8, int8_x4_format is set and
 *         k is not a multiple of 4
 *
 * NOTE(review): the scalar pointers placed in the pack refer to objects local
 * to this call (alpha_r/beta_r inside the visitor, or the by-value alpha/beta
 * parameters). If pack() keeps them for later use they will dangle once this
 * function returns -- confirm before relying on the result.
 * NOTE(review): the two branches of the visitor return pack() results built
 * from different argument lists; confirm this is compatible with the
 * visitor's deduced return type and with what visit_type() returns.
 */
template <class T>
auto create_gemm_args(context& ctx,
                      const shape& output_shape,
                      const std::vector<argument>& inputs,
                      T alpha,
                      T beta,
                      bool int8_x4_format,
                      bool compute_fp32)
{
    const bool is_3inputs = (inputs.size() == 4);
    if(not is_3inputs)
    {
        // Without a separate C input there is nothing to accumulate into.
        beta = 0;
    }

    bool transa = is_transposed(inputs[0].get_shape());
    bool transb = is_transposed(inputs[1].get_shape());
    auto n_dim  = output_shape.lens().size();
    auto dim_1  = n_dim - 1;
    auto dim_0  = n_dim - 2;

    // Leading dimensions are taken from the stride of the second-to-last
    // axis, or of the last axis when the matrix is transposed.
    rocblas_int lda = inputs[0].get_shape().strides()[transa ? dim_1 : dim_0];
    rocblas_int ldb = inputs[1].get_shape().strides()[transb ? dim_1 : dim_0];
    rocblas_int ldc = inputs[2].get_shape().strides()[dim_0];
    rocblas_int ldd = is_3inputs ? inputs[3].get_shape().strides()[dim_0] : ldc;

    rocblas_datatype arg_type = get_type(inputs[0].get_shape().type());
    // int8 inputs produce int32 outputs; every other type keeps the input type.
    auto output_type = arg_type;
    if(output_type == rocblas_datatype_i8_r)
    {
        output_type = rocblas_datatype_i32_r;
    }
    auto compute_type = output_type;
    if(compute_fp32)
    {
        if(arg_type == rocblas_datatype_f16_r)
            compute_type = rocblas_datatype_f32_r;
    }

#if ROCBLAS_VERSION_MAJOR >= 2 && ROCBLAS_VERSION_MINOR >= 38
    rocblas_gemm_flags flag =
        int8_x4_format ? rocblas_gemm_flags_pack_int8x4 : rocblas_gemm_flags_none;
#else
    (void)int8_x4_format;
    int flag = 0;
#endif

    return output_shape.visit_type([&](auto as) {
        auto alpha_r = as(alpha);
        auto beta_r  = as(beta);

        // use void pointer to select different data type if using fp32 mode
        void* alpha_v = &alpha_r;
        void* beta_v  = &beta_r;

        if(compute_fp32)
        {
            alpha_v = &alpha;
            beta_v  = &beta;
        }

        auto out_lens   = output_shape.lens();
        rocblas_int m   = out_lens[dim_0];
        rocblas_int n   = out_lens[dim_1];
        rocblas_int k   = inputs[0].get_shape().lens()[dim_1];
        auto to_pointer = [&](auto&& arg) { return as.from(arg.data()); };
        if(inputs[0].get_shape().type() == shape::int8_type and (k % 4) != 0 and int8_x4_format)
        {
            MIGRAPHX_THROW("ROCBLAS_GEMM: k size of int8 type input must be mutlple of 4!");
        }

        // Product of every output dimension except the last two.
        auto num_matrices = std::accumulate(
            out_lens.rbegin() + 2, out_lens.rend(), std::size_t{1}, std::multiplies<std::size_t>());
        if(num_matrices == 1 or (num_matrices > 1 and get_batch_stride(inputs[1]) == 0))
        {
            // If the batch dimension of B is broadcasted, then we can
            // multiply m by the batch_size and use rocblas_gemm_ex
            // instead of rocblas_gemm_strided_batched_ex.
            m *= num_matrices;

            // the rocblas_gemm API handles inputs and output matrices as
            // column-major format. When doing a C = A * B, we actually do
            // C^T = (B^T) * (A^T). That is the reason we input inputs[1] as
            // A and inputs[0] as B in calling the rocblas_gemm.
            // Argument order matches: rocblas_invoke(&rocblas_gemm_ex, ...)
            return pack(ctx.get_stream().get_rocblas(),
                        transb ? rocblas_operation_transpose : rocblas_operation_none,
                        transa ? rocblas_operation_transpose : rocblas_operation_none,
                        n,
                        m,
                        k,
                        alpha_v,
                        to_pointer(inputs.at(1)),
                        arg_type,
                        ldb,
                        to_pointer(inputs.at(0)),
                        arg_type,
                        lda,
                        beta_v,
                        to_pointer(inputs[2]),
                        output_type,
                        ldc,
                        is_3inputs ? to_pointer(inputs[3]) : to_pointer(inputs[2]),
                        output_type,
                        ldd,
                        compute_type,
                        rocblas_gemm_algo_standard,
                        0,
                        flag);
        }
        else
        {
            auto a_stride = get_batch_stride(inputs[0]);
            auto b_stride = get_batch_stride(inputs[1]);
            auto c_stride = get_batch_stride(inputs[2]);
            auto d_stride = is_3inputs ? get_batch_stride(inputs[3]) : c_stride;

            // Same A/B swap as the non-batched branch, with a batch stride
            // following each leading dimension.
            // Argument order matches: rocblas_invoke(&rocblas_gemm_strided_batched_ex, ...)
            return pack(ctx.get_stream().get_rocblas(),
                        transb ? rocblas_operation_transpose : rocblas_operation_none,
                        transa ? rocblas_operation_transpose : rocblas_operation_none,
                        n,
                        m,
                        k,
                        alpha_v,
                        to_pointer(inputs.at(1)),
                        arg_type,
                        ldb,
                        b_stride,
                        to_pointer(inputs.at(0)),
                        arg_type,
                        lda,
                        a_stride,
                        beta_v,
                        to_pointer(inputs[2]),
                        output_type,
                        ldc,
                        c_stride,
                        is_3inputs ? to_pointer(inputs[3]) : to_pointer(inputs[2]),
                        output_type,
                        ldd,
                        d_stride,
                        num_matrices,
                        compute_type,
                        rocblas_gemm_algo_standard,
                        0,
                        flag);
        }
    });
}
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......@@ -43,6 +43,9 @@ struct context;
void blas_shape(const shape& s);
shape transpose_batch(const shape& s, unsigned trans_batch);
// Times `op` on arguments generated from `inputs`. The definition in this
// change runs the op once untimed, then `n` timed runs, and returns
// {average host time, average device time}; `n` defaults to 100 here.
std::pair<double, double>
time_op(context& ictx, operation op, const std::vector<shape>& inputs, int n = 100);
// Declaration only; the definition is not part of this change. `seed`
// defaults to 0.
// NOTE(review): presumably this is the generator from migraphx/generate.hpp --
// confirm that redeclaring it here is intended.
argument generate_argument(shape s, unsigned long seed = 0);
template <class Op>
struct rocblas_gemm
......
......@@ -47,6 +47,13 @@ void gemm(context& ctx,
bool int8_x4_format,
bool compute_fp32);
// Declaration of create_gemm_args taking the GEMM input arguments.
// `T` does not appear in the parameter list, so it cannot be deduced and must
// be supplied explicitly by the caller.
// NOTE(review): the definition added in this change takes
// (ctx, output_shape, inputs, alpha, beta, int8_x4_format, compute_fp32);
// no definition matching this two-parameter signature is visible -- confirm
// which signature is intended.
template <class T>
auto create_gemm_args(context& ctx,
const std::vector<argument>& inputs);
// The version with just shapes will use null pointers for the buffers
// (as above, `T` must be given explicitly; no definition is visible here).
template <class T>
auto create_gemm_args(context& ctx, const std::vector<shape>& inputs);
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment