Unverified commit bf0a4713 authored by Paul Fultz II, committed by GitHub

Improve applicable batched gemms (#1214)

* Improve applicable batched gemms for bert
parent 150d6d20
@@ -16,11 +16,9 @@ bool reduce_dim(std::vector<shape>& shapes, std::size_t n)
auto bstride = s.strides()[n + 1];
auto blen = s.lens()[n + 1];
if(astride == bstride * blen)
{
if(astride == bstride * blen or alen == 1)
new_lens.push_back(alen * blen);
}
}
if(new_lens.size() != shapes.size())
return false;
std::size_t i = 0;
@@ -37,10 +35,25 @@ bool reduce_dim(std::vector<shape>& shapes, std::size_t n)
return true;
}
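
The new "or alen == 1" branch lets an axis of length 1 fold into its neighbour even when its stride does not match the packed product, which is the situation the non_packed_empty1 test added below exercises. A minimal sketch of the accepted case (element type assumed, not part of the diff):

// Sketch only: a leading axis of length 1 whose stride is not the packed
// value 64 * 12 = 768.
std::vector<migraphx::shape> in = {
    migraphx::shape{migraphx::shape::float_type, {1, 12}, {589824, 64}}};
auto out = migraphx::reduce_dims(in);
// Previously the stride mismatch blocked the merge; with the alen == 1 branch
// the pair collapses to lens {12} with stride {64}.
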
void reduce_dim1(std::vector<shape>& shapes)
{
if(std::any_of(shapes.begin(), shapes.end(), [&](const auto& s) {
return s.lens().size() < 2 or s.lens().back() != 1;
}))
return;
for(auto& s : shapes)
{
auto lens = s.lens();
auto strides = s.strides();
lens.pop_back();
strides.pop_back();
s = shape{s.type(), lens, strides};
}
}
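
reduce_dim1 handles the complementary case: when every shape ends in an axis of length 1, that trailing axis is simply dropped, since the pairwise merge above cannot fold it when its stride is unrelated to the rest of the layout. A minimal sketch (element type assumed, not part of the diff), mirroring the non_packed_empty2 test added below:

std::vector<migraphx::shape> shapes = {
    migraphx::shape{migraphx::shape::float_type, {12, 1}, {64, 589824}}};
// reduce_dim cannot merge the pair because 64 != 589824 * 1, so reduce_dim1
// pops the trailing axis, leaving lens {12} with stride {64}.
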
std::size_t reduce_dim_all(std::vector<shape>& shapes, std::size_t n)
{
while(reduce_dim(shapes, n) and n < shapes.size()) {}
return n + 1;
}
void reduce_dim_all(std::vector<shape>& shapes)
@@ -48,6 +61,7 @@ void reduce_dim_all(std::vector<shape>& shapes)
std::size_t n = 0;
while(n < shapes.front().lens().size() - 1)
n = reduce_dim_all(shapes, n);
reduce_dim1(shapes);
}
std::vector<std::size_t> base_lens(const std::vector<shape>& shapes)
......
#include <rocblas.h>
#include <migraphx/gpu/gemm_impl.hpp>
#include <migraphx/reduce_dims.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
@@ -27,6 +28,22 @@ rocblas_datatype get_type(shape::type_t type)
MIGRAPHX_THROW("ROCBLAS_GEMM: data type not supported!");
}
void blas_shape(const shape& s)
{
if(s.lens().size() < 2)
return;
if(std::none_of(s.strides().end() - 2, s.strides().end(), [&](auto i) { return i == 1; }))
MIGRAPHX_THROW("GPU_GEMM: needs to have one matrix stride as 1");
if(s.lens().size() < 3)
return;
shape batch_shape{s.type(),
{s.lens().begin(), s.lens().end() - 2},
{s.strides().begin(), s.strides().end() - 2}};
auto batch_shapes = reduce_dims({batch_shape});
if(batch_shapes.front().lens().size() != 1)
MIGRAPHX_THROW("GPU_GEMM: Batch dimension is not collapsible");
}
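
blas_shape is the validity check used by the gemm operator's compute_shape further down: one of the two innermost strides must be 1 (a row- or column-major matrix), and any leading batch dimensions must collapse, via reduce_dims, to a single strided dimension that rocblas_gemm_strided_batched_ex can address. A rough sketch of what passes and what throws, with assumed shapes (not part of the diff):

migraphx::shape ok{migraphx::shape::float_type, {2, 3, 4, 5}};  // packed strides {60, 20, 5, 1}
migraphx::shape bad{migraphx::shape::float_type, {2, 3, 4, 5}, {20, 60, 5, 1}};
blas_shape(ok);   // batch dims {2, 3} collapse to a single dim of 6: accepted
blas_shape(bad);  // swapped batch strides cannot be collapsed to one dim: throws
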
template <class R, class... Ts, class... Us>
R rocblas_invoke(R (*f)(Ts...), Us... xs)
{
@@ -36,6 +53,18 @@ R rocblas_invoke(R (*f)(Ts...), Us... xs)
return f(xs..., nullptr, nullptr);
}
static bool is_transposed(const shape& s)
{
if(not s.transposed())
return false;
return s.strides().back() != 1;
}
static rocblas_int get_batch_stride(const argument& a)
{
return a.get_shape().strides()[a.get_shape().strides().size() - 3];
}
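
Two helpers support the strided batched path. is_transposed now reports a transpose only when the innermost stride is not 1, so permuted views that are still unit-stride along their rows are not flagged. get_batch_stride reads the stride of the third-from-last dimension instead of assuming a packed batch. A minimal sketch of the latter, with assumed values (not part of the diff):

// A non-packed batched input, e.g. a view sliced along the last axis of a
// larger {8, 12, 192} tensor.
migraphx::shape a{migraphx::shape::float_type, {8, 12, 64}, {2304, 192, 1}};
// get_batch_stride would return a.strides()[a.strides().size() - 3] == 2304,
// the real distance between consecutive matrices, where the packed value
// would be 12 * 64 = 768.
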
template <class T>
void gemm_impl(context& ctx,
const shape& output_shape,
@@ -45,8 +74,8 @@ void gemm_impl(context& ctx,
bool int8_x4_format,
bool compute_fp32)
{
bool transa = args[0].get_shape().transposed();
bool transb = args[1].get_shape().transposed();
bool transa = is_transposed(args[0].get_shape());
bool transb = is_transposed(args[1].get_shape());
auto n_dim = output_shape.lens().size();
auto dim_1 = n_dim - 1;
auto dim_0 = n_dim - 2;
@@ -142,6 +171,9 @@ void gemm_impl(context& ctx,
}
else
{
auto a_stride = get_batch_stride(args[0]);
auto b_stride = get_batch_stride(args[1]);
auto c_stride = get_batch_stride(args[2]);
rocblas_invoke(&rocblas_gemm_strided_batched_ex,
ctx.get_stream().get_rocblas(),
transb ? rocblas_operation_transpose : rocblas_operation_none,
@@ -153,20 +185,20 @@ void gemm_impl(context& ctx,
to_pointer(args.at(1)),
arg_type,
ldb,
k * n,
b_stride,
to_pointer(args.at(0)),
arg_type,
lda,
m * k,
a_stride,
beta_v,
to_pointer(args[2]),
output_type,
ldc,
m * n,
c_stride,
is_3inputs ? to_pointer(args[3]) : to_pointer(args[2]),
output_type,
ldc,
m * n,
c_stride,
num_matrices,
compute_type,
rocblas_gemm_algo_standard,
......
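
Taken together, the strided batched call above now passes the batch strides read from each argument's shape rather than the packed products m * k, k * n and m * n. A sketch with assumed shapes (not part of the diff):

// C{8, 12, 32} = A{8, 12, 64} x B{8, 64, 32}, where A is a non-packed view
// with strides {2304, 192, 1}.
// Old call: stride between consecutive A matrices hard-coded as m * k = 12 * 64 = 768.
// New call: a_stride = 2304 (read from A's shape), b_stride = 64 * 32 = 2048,
//           c_stride = 12 * 32 = 384, so such a strided A can be handed to
//           rocblas_gemm_strided_batched_ex directly.
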
@@ -18,6 +18,8 @@ namespace gpu {
struct context;
void blas_shape(const shape& s);
template <class Op>
struct rocblas_gemm
{
@@ -50,13 +52,14 @@ struct rocblas_gemm
std::vector<shape> in_shapes(inputs);
in_shapes.pop_back();
check_shapes{in_shapes, *this}.not_broadcasted();
batch_not_transposed(inputs[0].strides());
batch_not_transposed(inputs[1].strides());
blas_shape(inputs[0]);
blas_shape(inputs[1]);
// if gemm and add are fused
if(not float_equal(beta, 0))
if(in_shapes.size() > 2)
{
auto cmat_shape = in_shapes.back();
in_shapes.pop_back();
blas_shape(cmat_shape);
auto op_out_shape = op.compute_shape(in_shapes);
if(cmat_shape.lens() != op_out_shape.lens())
{
@@ -71,6 +74,7 @@ struct rocblas_gemm
to_string(cmat_shape.type()) +
", it must be: " + to_string(op_out_shape.type()));
}
return op_out_shape;
}
return op.compute_shape(in_shapes);
@@ -96,28 +100,6 @@ struct rocblas_gemm
return args.back();
}
void batch_not_transposed(const std::vector<std::size_t>& strides) const
{
if(strides.size() <= 2)
return;
auto dim_0 = strides.size() - 2;
auto matrix_size = std::max(strides[dim_0], strides[dim_0 + 1]);
std::vector<std::size_t> batch(strides.begin(), strides.begin() + dim_0);
if(std::all_of(batch.begin(), batch.end(), [&](auto i) { return (i < matrix_size); }))
{
MIGRAPHX_THROW("GPU_GEMM: matrix size and batch size {" + to_string_range(strides) +
"} are transposed!");
}
if(std::adjacent_find(batch.begin(), batch.end(), [&](auto i, auto j) {
return (i < j or i < matrix_size or j < matrix_size);
}) != batch.end())
{
MIGRAPHX_THROW("GPU_GEMM: batch size {" + to_string_range(strides) +
"} is transposed!");
}
}
std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
{
return shapes.size() - 1;
......
@@ -109,6 +109,29 @@ TEST_CASE(transposed1)
EXPECT(eshapes == rshapes);
}
TEST_CASE(non_packed_empty1)
{
std::vector<migraphx::shape> ishapes = {make_shape({1, 12}, {589824, 64})};
std::vector<migraphx::shape> eshapes = {make_shape({12}, {64})};
auto rshapes = migraphx::reduce_dims(ishapes);
EXPECT(eshapes == rshapes);
}
TEST_CASE(non_packed_empty2)
{
std::vector<migraphx::shape> ishapes = {make_shape({12, 1}, {64, 589824})};
std::vector<migraphx::shape> eshapes = {make_shape({12}, {64})};
auto rshapes = migraphx::reduce_dims(ishapes);
EXPECT(eshapes == rshapes);
}
TEST_CASE(single_dim)
{
std::vector<migraphx::shape> ishapes = {make_shape({1}, {1})};
auto rshapes = migraphx::reduce_dims(ishapes);
EXPECT(ishapes == rshapes);
}
TEST_CASE(empty)
{
auto rshapes = migraphx::reduce_dims({});
......