Unverified Commit bf0a4713 authored by Paul Fultz II, committed by GitHub

Improve applicable batched gemms (#1214)

* Improve applicable batched gemms for bert
parent 150d6d20
@@ -16,10 +16,8 @@ bool reduce_dim(std::vector<shape>& shapes, std::size_t n)
         auto bstride = s.strides()[n + 1];
         auto blen    = s.lens()[n + 1];
-        if(astride == bstride * blen)
-        {
+        if(astride == bstride * blen or alen == 1)
             new_lens.push_back(alen * blen);
-        }
     }
     if(new_lens.size() != shapes.size())
         return false;
@@ -37,10 +35,25 @@ bool reduce_dim(std::vector<shape>& shapes, std::size_t n)
     return true;
 }
 
+void reduce_dim1(std::vector<shape>& shapes)
+{
+    if(std::any_of(shapes.begin(), shapes.end(), [&](const auto& s) {
+           return s.lens().size() < 2 or s.lens().back() != 1;
+       }))
+        return;
+    for(auto& s : shapes)
+    {
+        auto lens    = s.lens();
+        auto strides = s.strides();
+        lens.pop_back();
+        strides.pop_back();
+        s = shape{s.type(), lens, strides};
+    }
+}
+
 std::size_t reduce_dim_all(std::vector<shape>& shapes, std::size_t n)
 {
     while(reduce_dim(shapes, n) and n < shapes.size()) {}
     return n + 1;
 }
 
 void reduce_dim_all(std::vector<shape>& shapes)
@@ -48,6 +61,7 @@ void reduce_dim_all(std::vector<shape>& shapes)
     std::size_t n = 0;
     while(n < shapes.front().lens().size() - 1)
         n = reduce_dim_all(shapes, n);
+    reduce_dim1(shapes);
 }
 
 std::vector<std::size_t> base_lens(const std::vector<shape>& shapes)
...
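A minimal sketch of what the two changes above enable, using the shapes exercised by the new tests at the bottom of this diff; the make_shape helper is an assumption mirroring the test fixture:

    #include <migraphx/reduce_dims.hpp>
    #include <migraphx/shape.hpp>
    #include <vector>

    // Assumed helper, modeled on the one in the reduce_dims tests.
    static migraphx::shape make_shape(std::vector<std::size_t> lens,
                                      std::vector<std::size_t> strides)
    {
        return {migraphx::shape::float_type, std::move(lens), std::move(strides)};
    }

    void sketch()
    {
        // Leading length-1 axis: the strides are not contiguous
        // (589824 != 64 * 12), but the new `alen == 1` branch merges it anyway.
        auto a = migraphx::reduce_dims({make_shape({1, 12}, {589824, 64})});
        // a.front() == make_shape({12}, {64})

        // Trailing length-1 axis: dropped by the new reduce_dim1 pass,
        // which pops the last dimension when every shape ends in 1.
        auto b = migraphx::reduce_dims({make_shape({12, 1}, {64, 589824})});
        // b.front() == make_shape({12}, {64})
    }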
 #include <rocblas.h>
 #include <migraphx/gpu/gemm_impl.hpp>
+#include <migraphx/reduce_dims.hpp>
 
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
@@ -27,6 +28,22 @@ rocblas_datatype get_type(shape::type_t type)
     MIGRAPHX_THROW("ROCBLAS_GEMM: data type not supported!");
 }
 
+void blas_shape(const shape& s)
+{
+    if(s.lens().size() < 2)
+        return;
+    if(std::none_of(s.strides().end() - 2, s.strides().end(), [&](auto i) { return i == 1; }))
+        MIGRAPHX_THROW("GPU_GEMM: needs to have one matrix stride as 1");
+    if(s.lens().size() < 3)
+        return;
+    shape batch_shape{s.type(),
+                      {s.lens().begin(), s.lens().end() - 2},
+                      {s.strides().begin(), s.strides().end() - 2}};
+    auto batch_shapes = reduce_dims({batch_shape});
+    if(batch_shapes.front().lens().size() != 1)
+        MIGRAPHX_THROW("GPU_GEMM: Batch dimension is not collapsible");
+}
+
 template <class R, class... Ts, class... Us>
 R rocblas_invoke(R (*f)(Ts...), Us... xs)
 {
@@ -36,6 +53,18 @@ R rocblas_invoke(R (*f)(Ts...), Us... xs)
     return f(xs..., nullptr, nullptr);
 }
 
+static bool is_transposed(const shape& s)
+{
+    if(not s.transposed())
+        return false;
+    return s.strides().back() != 1;
+}
+
+static rocblas_int get_batch_stride(const argument& a)
+{
+    return a.get_shape().strides()[a.get_shape().strides().size() - 3];
+}
+
 template <class T>
 void gemm_impl(context& ctx,
                const shape& output_shape,
@@ -45,8 +74,8 @@ void gemm_impl(context& ctx,
                bool int8_x4_format,
                bool compute_fp32)
 {
-    bool transa = args[0].get_shape().transposed();
-    bool transb = args[1].get_shape().transposed();
+    bool transa = is_transposed(args[0].get_shape());
+    bool transb = is_transposed(args[1].get_shape());
     auto n_dim  = output_shape.lens().size();
     auto dim_1  = n_dim - 1;
     auto dim_0  = n_dim - 2;
@@ -142,6 +171,9 @@ void gemm_impl(context& ctx,
     }
     else
     {
+        auto a_stride = get_batch_stride(args[0]);
+        auto b_stride = get_batch_stride(args[1]);
+        auto c_stride = get_batch_stride(args[2]);
         rocblas_invoke(&rocblas_gemm_strided_batched_ex,
                        ctx.get_stream().get_rocblas(),
                        transb ? rocblas_operation_transpose : rocblas_operation_none,
@@ -153,20 +185,20 @@ void gemm_impl(context& ctx,
                        to_pointer(args.at(1)),
                        arg_type,
                        ldb,
-                       k * n,
+                       b_stride,
                        to_pointer(args.at(0)),
                        arg_type,
                        lda,
-                       m * k,
+                       a_stride,
                        beta_v,
                        to_pointer(args[2]),
                        output_type,
                        ldc,
-                       m * n,
+                       c_stride,
                        is_3inputs ? to_pointer(args[3]) : to_pointer(args[2]),
                        output_type,
                        ldc,
-                       m * n,
+                       c_stride,
                        num_matrices,
                        compute_type,
                        rocblas_gemm_algo_standard,
...
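Because shapes only need a collapsible batch (not a packed one), gemm_impl can no longer assume consecutive matrices are m * k, k * n, and m * n elements apart; it now reads the real stride of the innermost batch axis. A standalone, hedged illustration of the kind of layout this unlocks; the 4-D BERT-style shape is an extrapolation from the {1, 12}/{589824, 64} batch part in the new tests, not taken from the PR itself:

    #include <cstddef>
    #include <iostream>
    #include <vector>

    // Standalone mirror of the new get_batch_stride helper (illustration only):
    // the distance between consecutive matrices is the third-from-last stride.
    static std::size_t batch_stride(const std::vector<std::size_t>& strides)
    {
        return strides[strides.size() - 3];
    }

    int main()
    {
        // Hypothetical view of the per-head matrices inside a fused QKV buffer:
        // 12 heads of 384x64 matrices sharing row stride 768, so consecutive
        // matrices start only 64 elements apart and physically interleave.
        std::vector<std::size_t> lens    = {1, 12, 384, 64};
        std::vector<std::size_t> strides = {589824, 64, 768, 1};

        // The old code would have handed rocBLAS k * n = 384 * 64 = 24576;
        // the actual batch stride is 64.
        std::cout << "batch stride: " << batch_stride(strides) << "\n";
    }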
@@ -18,6 +18,8 @@ namespace gpu {
 
 struct context;
 
+void blas_shape(const shape& s);
+
 template <class Op>
 struct rocblas_gemm
 {
@@ -50,13 +52,14 @@ struct rocblas_gemm
         std::vector<shape> in_shapes(inputs);
         in_shapes.pop_back();
         check_shapes{in_shapes, *this}.not_broadcasted();
-        batch_not_transposed(inputs[0].strides());
-        batch_not_transposed(inputs[1].strides());
+        blas_shape(inputs[0]);
+        blas_shape(inputs[1]);
         // if gemm and add are fused
-        if(not float_equal(beta, 0))
+        if(in_shapes.size() > 2)
         {
             auto cmat_shape = in_shapes.back();
             in_shapes.pop_back();
+            blas_shape(cmat_shape);
             auto op_out_shape = op.compute_shape(in_shapes);
             if(cmat_shape.lens() != op_out_shape.lens())
             {
@@ -71,6 +74,7 @@ struct rocblas_gemm
                                to_string(cmat_shape.type()) +
                                ", it must be: " + to_string(op_out_shape.type()));
             }
+            return op_out_shape;
         }
         return op.compute_shape(in_shapes);
@@ -96,28 +100,6 @@ struct rocblas_gemm
         return args.back();
     }
 
-    void batch_not_transposed(const std::vector<std::size_t>& strides) const
-    {
-        if(strides.size() <= 2)
-            return;
-        auto dim_0       = strides.size() - 2;
-        auto matrix_size = std::max(strides[dim_0], strides[dim_0 + 1]);
-        std::vector<std::size_t> batch(strides.begin(), strides.begin() + dim_0);
-        if(std::all_of(batch.begin(), batch.end(), [&](auto i) { return (i < matrix_size); }))
-        {
-            MIGRAPHX_THROW("GPU_GEMM: matrix size and batch size {" + to_string_range(strides) +
-                           "} are transposed!");
-        }
-        if(std::adjacent_find(batch.begin(), batch.end(), [&](auto i, auto j) {
-               return (i < j or i < matrix_size or j < matrix_size);
-           }) != batch.end())
-        {
-            MIGRAPHX_THROW("GPU_GEMM: batch size {" + to_string_range(strides) +
-                           "} is transposed!");
-        }
-    }
-
     std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
     {
         return shapes.size() - 1;
...
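For contrast with the removed batch_not_transposed: the old rule required every batch stride to be at least the matrix size and the batch strides to be non-increasing, which rejects exactly the interleaved layout sketched earlier, even though the strided-batched rocBLAS call takes an explicit stride. A hedged standalone replay of the old rule (logic copied from the removed lines; old_check_throws is a name invented here):

    #include <algorithm>
    #include <cstddef>
    #include <iostream>
    #include <vector>

    // True where the removed batch_not_transposed would have thrown.
    static bool old_check_throws(const std::vector<std::size_t>& strides)
    {
        if(strides.size() <= 2)
            return false;
        auto dim_0       = strides.size() - 2;
        auto matrix_size = std::max(strides[dim_0], strides[dim_0 + 1]);
        std::vector<std::size_t> batch(strides.begin(), strides.begin() + dim_0);
        if(std::all_of(batch.begin(), batch.end(), [&](auto i) { return i < matrix_size; }))
            return true;
        return std::adjacent_find(batch.begin(), batch.end(), [&](auto i, auto j) {
                   return i < j or i < matrix_size or j < matrix_size;
               }) != batch.end();
    }

    int main()
    {
        // Interleaved heads from the earlier sketch: the innermost batch
        // stride (64) is smaller than the 384x64 matrix, so the old rule
        // rejects it, while the new blas_shape accepts it because the batch
        // part {1, 12}/{589824, 64} collapses to a single dimension.
        std::vector<std::size_t> strides = {589824, 64, 768, 1};
        std::cout << (old_check_throws(strides) ? "rejected" : "accepted") << "\n";
    }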
@@ -109,6 +109,29 @@ TEST_CASE(transposed1)
     EXPECT(eshapes == rshapes);
 }
 
+TEST_CASE(non_packed_empty1)
+{
+    std::vector<migraphx::shape> ishapes = {make_shape({1, 12}, {589824, 64})};
+    std::vector<migraphx::shape> eshapes = {make_shape({12}, {64})};
+    auto rshapes                         = migraphx::reduce_dims(ishapes);
+    EXPECT(eshapes == rshapes);
+}
+
+TEST_CASE(non_packed_empty2)
+{
+    std::vector<migraphx::shape> ishapes = {make_shape({12, 1}, {64, 589824})};
+    std::vector<migraphx::shape> eshapes = {make_shape({12}, {64})};
+    auto rshapes                         = migraphx::reduce_dims(ishapes);
+    EXPECT(eshapes == rshapes);
+}
+
+TEST_CASE(single_dim)
+{
+    std::vector<migraphx::shape> ishapes = {make_shape({1}, {1})};
+    auto rshapes                         = migraphx::reduce_dims(ishapes);
+    EXPECT(ishapes == rshapes);
+}
+
 TEST_CASE(empty)
 {
     auto rshapes = migraphx::reduce_dims({});
...