"git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "63db86bdb8e78c13ecf5b4b89b963894c43f21bc"
Commit 6f7ee0b7 authored by Paul

Fuse batch transposed gemm

parent 9cb9bc09
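
In short: `gpu::contiguous(transpose(gpu::gemm(...)))` is rewritten so the gemm writes directly into memory laid out for the transposed result, and the explicit copy disappears; the new `trans_batch` field tells the rocBLAS wrapper that its output shape is batch-transposed. The key shape bookkeeping is that allocating with inversely permuted dimensions and then viewing that buffer through the original transpose recovers the gemm's own output dims. A minimal standalone sketch of that round trip (local re-implementations of `reorder_dims`/`invert_permutation` for illustration only; the dims are made up):

```cpp
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

// out[i] = dims[perm[i]], mirroring migraphx::reorder_dims
static std::vector<std::size_t> reorder_dims(const std::vector<std::size_t>& dims,
                                             const std::vector<std::int64_t>& perm)
{
    std::vector<std::size_t> out(dims.size());
    for(std::size_t i = 0; i < dims.size(); i++)
        out[i] = dims[perm[i]];
    return out;
}

// iperm[perm[i]] = i, mirroring migraphx::invert_permutation
static std::vector<std::int64_t> invert_permutation(const std::vector<std::int64_t>& perm)
{
    std::vector<std::int64_t> iperm(perm.size());
    for(std::size_t i = 0; i < perm.size(); i++)
        iperm[perm[i]] = static_cast<std::int64_t>(i);
    return iperm;
}

int main()
{
    std::vector<std::size_t> gemm_lens = {8, 16, 32}; // hypothetical {batch, M, N} output
    std::vector<std::int64_t> perm     = {1, 0, 2};   // swaps batch with M
    auto iperm = invert_permutation(perm);

    // The pass allocates with the inversely permuted dims...
    auto alloc_lens = reorder_dims(gemm_lens, iperm); // {16, 8, 32}
    // ...then views the buffer through transpose(perm), recovering the gemm's
    // own dims, so the gemm writes straight into the transposed layout.
    auto viewed = reorder_dims(alloc_lens, perm);     // {8, 16, 32}
    std::cout << (viewed == gemm_lens ? "round-trips" : "mismatch") << "\n";
}
```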
@@ -48,6 +48,7 @@
 #include <migraphx/instruction.hpp>
 #include <migraphx/register_op.hpp>
 #include <migraphx/array.hpp>
+#include <migraphx/permutation.hpp>
 #include <migraphx/make_op.hpp>
 #include <migraphx/op/clip.hpp>
 #include <migraphx/op/contiguous.hpp>
@@ -1070,6 +1071,60 @@ struct find_gemm_pointwise
     }
 };
 
+struct find_contiguous_transpose_gemm
+{
+    auto matcher() const
+    {
+        return match::name("gpu::contiguous")(match::arg(0)(
+            match::name("transpose")(
+                match::arg(0)(match::name("gpu::gemm")(match::used_once()).bind("gemm")))
+                .bind("transpose")));
+    }
+
+    // True when perm equals the identity permutation with elements i and j swapped
+    template <class Vector>
+    static bool is_swapped(const Vector& perm, std::size_t i, std::size_t j)
+    {
+        if(i >= perm.size() or j >= perm.size())
+            return false;
+        auto perm2 = perm;
+        std::iota(perm2.begin(), perm2.end(), 0);
+        std::swap(perm2[i], perm2[j]);
+        return perm2 == perm;
+    }
+
+    void apply(module& m, const match::matcher_result& r) const
+    {
+        auto ins       = r.result;
+        auto gemm      = r.instructions["gemm"];
+        auto alloc     = gemm->inputs().back();
+        auto transpose = r.instructions["transpose"];
+        auto perm  = transpose->get_operator().to_value()["permutation"].to_vector<int64_t>();
+        auto iperm = invert_permutation(perm);
+        if(perm.size() < 3)
+            return;
+        // Only fuse when the transpose swaps the batch dimension with the matrix rows
+        if(not is_swapped(perm, perm.size() - 3, perm.size() - 2))
+            return;
+        // Any leading dimensions beyond the last three must be singleton
+        auto lens = gemm->get_shape().lens();
+        if(lens.size() > 3 and
+           not std::all_of(lens.begin(), lens.end() - 3, [](auto i) { return i == 1; }))
+            return;
+        auto gemmv           = gemm->get_operator().to_value();
+        gemmv["trans_batch"] = 1;
+        // Allocate the output with inversely permuted dims, then view it through the
+        // transpose so the gemm writes directly into the transposed layout
+        auto s = shape{alloc->get_shape().type(), reorder_dims(alloc->get_shape().lens(), iperm)};
+        auto new_alloc = m.insert_instruction(gemm, make_op("allocate", {{"shape", to_value(s)}}));
+        auto alloc_transpose =
+            m.insert_instruction(gemm, make_op("transpose", {{"permutation", perm}}), new_alloc);
+        auto inputs   = gemm->inputs();
+        inputs.back() = alloc_transpose;
+        auto new_gemm       = m.insert_instruction(gemm, make_op("gpu::gemm", gemmv), inputs);
+        auto gemm_transpose = m.insert_instruction(gemm, transpose->get_operator(), new_gemm);
+        m.replace_instruction(ins, gemm_transpose);
+    }
+};
+
 struct find_commutative_broadcast
 {
     auto matcher() const
@@ -1144,6 +1199,7 @@ void fuse_ops::apply(module& m) const
                         find_triadd_layernorm{},
                         find_gemm_add{},
                         find_gemm_pointwise{},
+                        find_contiguous_transpose_gemm{},
                         find_commutative_broadcast{});
     match::find_matches(m, find_contiguous{});
 }
...
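
The eligibility check above only fires when the transpose is exactly a swap of the batch dimension (third from the right) with the row dimension. A small self-contained sketch of that predicate, mirroring `is_swapped` with hypothetical inputs:

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <numeric>
#include <vector>

// True when perm equals the identity permutation with i and j swapped
static bool is_swapped(const std::vector<std::int64_t>& perm, std::size_t i, std::size_t j)
{
    if(i >= perm.size() or j >= perm.size())
        return false;
    std::vector<std::int64_t> identity(perm.size());
    std::iota(identity.begin(), identity.end(), 0);
    std::swap(identity[i], identity[j]);
    return identity == perm;
}

int main()
{
    // {1, 0, 2} swaps the batch dim with the row dim of a rank-3 tensor: fusable
    assert(is_swapped({1, 0, 2}, 0, 1));
    // {2, 1, 0} also moves the innermost dim, so the pass bails out
    assert(not is_swapped({2, 1, 0}, 0, 1));
}
```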
@@ -24,6 +24,7 @@
 #include <rocblas.h>
 #include <migraphx/gpu/gemm_impl.hpp>
 #include <migraphx/reduce_dims.hpp>
+#include <migraphx/permutation.hpp>
 
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
@@ -67,6 +68,19 @@ void blas_shape(const shape& s)
         MIGRAPHX_THROW("GPU_GEMM: Batch dimension is not collapsible");
 }
 
+// Swap the dimension third from the right with the one trans_batch places
+// after it, returning the permuted shape
+shape transpose_batch(const shape& s, unsigned trans_batch)
+{
+    if(trans_batch == 0)
+        return s;
+    if(s.lens().size() < 3)
+        return s;
+    auto batch = s.lens().size() - 3;
+    std::vector<int64_t> perm(s.lens().size());
+    std::iota(perm.begin(), perm.end(), 0);
+    std::swap(perm[batch], perm[batch + trans_batch]);
+    return reorder_shape(s, perm);
+}
+
 template <class R, class... Ts, class... Us>
 R rocblas_invoke(R (*f)(Ts...), Us... xs)
 {
...
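
To make the new helper concrete: with `trans_batch = 1`, `transpose_batch` swaps the dimension third from the right with the one after it, i.e. the batch dim with M for a rank-3 `{batch, M, N}` shape. A hypothetical worked example on bare dimension vectors (the real function operates on `migraphx::shape` and also permutes strides; this sketch assumes `batch + trans_batch` is in range, as the original does):

```cpp
#include <cstddef>
#include <iostream>
#include <vector>

// Same index arithmetic as transpose_batch, applied to a plain dims vector
static std::vector<std::size_t> transpose_batch_dims(std::vector<std::size_t> lens,
                                                     unsigned trans_batch)
{
    if(trans_batch == 0 or lens.size() < 3)
        return lens;
    auto batch = lens.size() - 3; // third dimension from the right
    std::swap(lens[batch], lens[batch + trans_batch]);
    return lens;
}

int main()
{
    // {2, 3, 4} -> {3, 2, 4}: batch and M trade places, N is untouched
    for(auto d : transpose_batch_dims({2, 3, 4}, 1))
        std::cout << d << " ";
    std::cout << "\n";
}
```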
@@ -42,6 +42,7 @@ namespace gpu {
 struct context;
 
 void blas_shape(const shape& s);
+shape transpose_batch(const shape& s, unsigned trans_batch);
 
 template <class Op>
 struct rocblas_gemm
@@ -51,6 +52,7 @@ struct rocblas_gemm
     float beta           = 0;
     bool int8_x4_format  = true;
     bool compute_fp32    = false;
+    unsigned trans_batch = 0;
 
     template <class Self, class F>
     static auto reflect(Self& self, F f)
@@ -58,7 +60,9 @@ struct rocblas_gemm
         return pack_join(migraphx::reflect(self.op, f),
                          pack(f(self.alpha, "alpha"),
                               f(self.beta, "beta"),
-                              f(self.int8_x4_format, "int8_x4_format")));
+                              f(self.int8_x4_format, "int8_x4_format"),
+                              f(self.compute_fp32, "compute_fp32"),
+                              f(self.trans_batch, "trans_batch")));
     }
 
     std::string name() const
@@ -98,10 +102,10 @@ struct rocblas_gemm
                                to_string(cmat_shape.type()) +
                                ", it must be: " + to_string(op_out_shape.type()));
             }
-            return op_out_shape;
+            return transpose_batch(op_out_shape, trans_batch);
         }
-        return op.compute_shape(in_shapes);
+        return transpose_batch(op.compute_shape(in_shapes), trans_batch);
     }
 
     argument
...
@@ -132,8 +132,6 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_options& options) const
         lowering{&ctx, options.offload_copy},
         eliminate_contiguous{"gpu::contiguous"},
         dead_code_elimination{},
-        replace_allocate{gpu_allocation_model{}, options.offload_copy},
-        dead_code_elimination{},
         eliminate_concat{concat_gpu_optimization{}},
         dead_code_elimination{},
         pack_int8_args{},
@@ -142,6 +140,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_options& options) const
         dead_code_elimination{},
         fuse_ops{&ctx, options.fast_math},
         dead_code_elimination{},
+        replace_allocate{gpu_allocation_model{}, options.offload_copy},
+        dead_code_elimination{},
         compile_ops{&ctx},
         dead_code_elimination{},
         write_literals{&ctx},
...
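
Note on the pass reordering: `replace_allocate` (with its trailing `dead_code_elimination`) now runs after `fuse_ops` instead of before it, presumably because the new fusion inserts a fresh generic `allocate` instruction for the transposed output buffer, and moving the pass ensures those late allocations still get lowered by `gpu_allocation_model`.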