Commit cf8ccba4 authored by Paul

Merge branch 'bert-opt2' into bert-opt3

parents 9747cc44 bd70cd8d
@@ -74,6 +74,9 @@ struct unsqueeze
         MIGRAPHX_THROW("UNSQUEEZE: Input must be a scalar");
     }
+    if(steps.size() > axes.size())
+        MIGRAPHX_THROW("UNSQUEEZE: Steps provided with no axis");
+
     std::size_t new_size = old_lens.size() + axes.size();
     std::vector<std::size_t> new_lens(new_size);
...
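The added guard rejects a steps list longer than the axes list. For orientation, a stepped axis splits an existing dimension: the inserted axis takes the step as its length and the following original dimension is divided by it, while axes without a step become size-1 dimensions. A minimal standalone sketch of that shape arithmetic, consistent with the test cases later in this commit (`unsqueeze_lens` is an illustrative name, not the library implementation):

```cpp
#include <algorithm>
#include <cstddef>
#include <stdexcept>
#include <vector>

std::vector<std::size_t> unsqueeze_lens(std::vector<std::size_t> old_lens,
                                        const std::vector<std::size_t>& axes,
                                        const std::vector<std::size_t>& steps)
{
    // The check added above: every step needs a corresponding axis
    if(steps.size() > axes.size())
        throw std::runtime_error("steps provided with no axis");
    std::size_t new_size = old_lens.size() + axes.size();
    std::vector<std::size_t> new_lens(new_size, 1);
    std::size_t p = 0; // cursor into old_lens
    for(std::size_t i = 0; i < new_size; i++)
    {
        auto it = std::find(axes.begin(), axes.end(), i);
        if(it == axes.end())
        {
            new_lens[i] = old_lens[p++]; // untouched original dimension
            continue;
        }
        std::size_t k = it - axes.begin();
        if(k < steps.size())
        {
            // A stepped axis takes size `step`; the next original dimension
            // must divide evenly by it (.at throws when the stepped axis is
            // at the end and no dimension follows, matching step_at_end)
            auto step = steps[k];
            if(step == 0 or old_lens.at(p) % step != 0)
                throw std::runtime_error("invalid step");
            new_lens[i] = step;
            old_lens[p] /= step;
        }
        // axes without a step stay as size-1 dimensions
    }
    return new_lens; // {3,4,10}, axes {2,4,5}, steps {2} -> {3,4,2,5,1,1}
}
```

The final comment matches test_unsqueeze_multiple_axes_step added at the bottom of this commit.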
@@ -275,7 +275,7 @@ struct find_concat_transpose
 {
     auto matcher() const
     {
-        return match::name("concat")(match::all_of[match::inputs()](match::transpose_shape()));
+        return match::name("concat")(match::all_of[match::inputs()](match::name("transpose")));
     }

     void apply(module& m, const match::matcher_result& mr) const
...
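Narrowing the matcher is deliberate: judging by its name and the new test below, match::transpose_shape() fires on any input whose shape is merely transposed, while match::name("transpose") requires an actual transpose instruction. An unsqueeze of a transposed tensor has a transposed shape but is not a transpose op, so the concat rewrite no longer triggers on that pattern; the new transpose_unsqueeze_concat test at the end of this commit pins that behavior down.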
@@ -48,6 +48,7 @@
 #include <migraphx/instruction.hpp>
 #include <migraphx/register_op.hpp>
 #include <migraphx/array.hpp>
+#include <migraphx/permutation.hpp>
 #include <migraphx/make_op.hpp>
 #include <migraphx/op/clip.hpp>
 #include <migraphx/op/contiguous.hpp>
@@ -1063,6 +1064,64 @@ struct find_gemm_pointwise
     }
 };
+
+// Fuses gemm -> transpose -> gpu::contiguous: the gemm output is allocated
+// pre-transposed so the trailing contiguous copy becomes unnecessary
+struct find_contiguous_transpose_gemm
+{
+    auto matcher() const
+    {
+        return match::name("gpu::contiguous")(match::arg(0)(
+            match::name("transpose")(
+                match::arg(0)(match::name("gpu::gemm")(match::used_once()).bind("gemm")))
+                .bind("transpose")));
+    }
+
+    // True when perm is exactly the identity permutation with i and j exchanged
+    template <class Vector>
+    static bool is_swapped(const Vector& perm, std::size_t i, std::size_t j)
+    {
+        if(i >= perm.size() or j >= perm.size())
+            return false;
+        auto perm2 = perm;
+        std::iota(perm2.begin(), perm2.end(), 0);
+        std::swap(perm2[i], perm2[j]);
+        return perm2 == perm;
+    }
+
+    void apply(module& m, const match::matcher_result& r) const
+    {
+        auto ins       = r.result;
+        auto gemm      = r.instructions["gemm"];
+        auto alloc     = gemm->inputs().back();
+        auto transpose = r.instructions["transpose"];
+        auto perm      = transpose->get_operator().to_value()["permutation"].to_vector<int64_t>();
+        auto iperm     = invert_permutation(perm);
+        if(perm.size() < 3)
+            return;
+        if(not is_swapped(perm, perm.size() - 3, perm.size() - 2))
+            return;
+        auto lens = gemm->get_shape().lens();
+        if(lens.size() > 3 and
+           not std::all_of(lens.begin(), lens.end() - 3, [](auto i) { return i == 1; }))
+            return;
+        auto gemmv           = gemm->get_operator().to_value();
+        gemmv["trans_batch"] = 1;
+        // Allocate the output in the transposed layout and have the gemm
+        // write through a transpose view of it
+        auto s = shape{alloc->get_shape().type(), reorder_dims(alloc->get_shape().lens(), iperm)};
+        auto new_alloc = m.insert_instruction(gemm, make_op("allocate", {{"shape", to_value(s)}}));
+        auto alloc_transpose =
+            m.insert_instruction(gemm, make_op("transpose", {{"permutation", perm}}), new_alloc);
+        auto inputs   = gemm->inputs();
+        inputs.back() = alloc_transpose;
+        auto new_gemm = m.insert_instruction(gemm, make_op("gpu::gemm", gemmv), inputs);
+        auto gemm_transpose = m.insert_instruction(gemm, transpose->get_operator(), new_gemm);
+        m.replace_instruction(ins, gemm_transpose);
+    }
+};
+
 struct find_commutative_broadcast
 {
     auto matcher() const
@@ -1164,6 +1223,7 @@ void fuse_ops::apply(module& m) const
         find_gemm_add{},
         find_layernorm_pointwise{},
         find_gemm_pointwise{},
+        find_contiguous_transpose_gemm{},
         find_commutative_broadcast{});
     match::find_matches(m, find_contiguous{});
 }
...
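The heart of the new fusion is is_swapped, which asks whether a permutation is exactly the identity with two positions exchanged; when it holds, the rewrite re-allocates the gemm output so the transposed view is what gets written, and the trailing gpu::contiguous copy drops out. A self-contained check of the helper, with logic copied from the diff (the main() driver and sample permutations are illustrative):

```cpp
#include <cstddef>
#include <numeric>
#include <utility>
#include <vector>

template <class Vector>
bool is_swapped(const Vector& perm, std::size_t i, std::size_t j)
{
    if(i >= perm.size() or j >= perm.size())
        return false;
    auto perm2 = perm;
    std::iota(perm2.begin(), perm2.end(), 0); // identity permutation...
    std::swap(perm2[i], perm2[j]);            // ...with i and j exchanged
    return perm2 == perm;
}

int main()
{
    std::vector<std::size_t> a{0, 2, 1, 3}; // identity with dims 1, 2 swapped
    std::vector<std::size_t> b{2, 0, 1, 3}; // a 3-cycle, not a single swap
    // The fusion only fires when the transpose exchanges the two dimensions
    // just before the last one (perm.size()-3 and perm.size()-2), e.g. a
    // BERT-style (batch, heads, seq, dim) -> (batch, seq, heads, dim)
    bool fuse_a = is_swapped(a, a.size() - 3, a.size() - 2); // true
    bool fuse_b = is_swapped(b, b.size() - 3, b.size() - 2); // false
    return (fuse_a and not fuse_b) ? 0 : 1;
}
```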
@@ -24,6 +24,7 @@
 #include <rocblas.h>
 #include <migraphx/gpu/gemm_impl.hpp>
 #include <migraphx/reduce_dims.hpp>
+#include <migraphx/permutation.hpp>

 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
@@ -67,6 +68,19 @@ void blas_shape(const shape& s)
         MIGRAPHX_THROW("GPU_GEMM: Batch dimension is not collapsible");
 }
+
+// Swap the innermost batch dimension with the one trans_batch places to its
+// right, keeping the lengths and re-laying the strides accordingly
+shape transpose_batch(const shape& s, unsigned trans_batch)
+{
+    if(trans_batch == 0)
+        return s;
+    if(s.lens().size() < 3)
+        return s;
+    auto batch = s.lens().size() - 3;
+    std::vector<int64_t> perm(s.lens().size());
+    std::iota(perm.begin(), perm.end(), 0);
+    std::swap(perm[batch], perm[batch + trans_batch]);
+    return shape::from_permutation(s.type(), s.lens(), perm);
+}

 template <class R, class... Ts, class... Us>
 R rocblas_invoke(R (*f)(Ts...), Us... xs)
 {
...
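transpose_batch swaps the innermost batch dimension (position size-3) with the dimension trans_batch slots to its right, then uses shape::from_permutation to keep the same lengths while laying the strides out in the swapped order. The permutation it builds, pulled out as a standalone sketch (`batch_perm` is an illustrative name; in this commit trans_batch is only ever 0 or 1):

```cpp
#include <cstddef>
#include <cstdint>
#include <numeric>
#include <utility>
#include <vector>

// Sketch of the permutation constructed inside transpose_batch
std::vector<int64_t> batch_perm(std::size_t ndim, unsigned trans_batch)
{
    std::vector<int64_t> perm(ndim);
    std::iota(perm.begin(), perm.end(), 0);
    if(trans_batch != 0 and ndim >= 3)
    {
        std::size_t batch = ndim - 3; // innermost batch dimension
        std::swap(perm[batch], perm[batch + trans_batch]);
    }
    return perm; // ndim = 4, trans_batch = 1  ->  {0, 2, 1, 3}
}
```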
@@ -42,15 +42,17 @@ namespace gpu {

 struct context;

 void blas_shape(const shape& s);
+shape transpose_batch(const shape& s, unsigned trans_batch);

 template <class Op>
 struct rocblas_gemm
 {
     Op op;
     float alpha          = 1;
     float beta           = 0;
     bool int8_x4_format  = true;
     bool compute_fp32    = false;
+    unsigned trans_batch = 0;

     template <class Self, class F>
     static auto reflect(Self& self, F f)
@@ -58,7 +60,9 @@ struct rocblas_gemm
         return pack_join(migraphx::reflect(self.op, f),
                          pack(f(self.alpha, "alpha"),
                               f(self.beta, "beta"),
-                              f(self.int8_x4_format, "int8_x4_format")));
+                              f(self.int8_x4_format, "int8_x4_format"),
+                              f(self.compute_fp32, "compute_fp32"),
+                              f(self.trans_batch, "trans_batch")));
     }

     std::string name() const
@@ -98,10 +102,10 @@ struct rocblas_gemm
                               to_string(cmat_shape.type()) +
                               ", it must be: " + to_string(op_out_shape.type()));
             }
-            return op_out_shape;
+            return transpose_batch(op_out_shape, trans_batch);
         }
-        return op.compute_shape(in_shapes);
+        return transpose_batch(op.compute_shape(in_shapes), trans_batch);
     }

     argument
...
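Two fixes ride along with trans_batch here: reflect() now also exposes compute_fp32, which was already a member (see the unchanged field above) but had not been reflected before, so it was invisible to the reflection machinery; and both return paths of compute_shape() now route through transpose_batch, so a gemm flagged with trans_batch reports the batch-swapped output shape.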
@@ -133,8 +133,6 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
         lowering{&ctx, options.offload_copy},
         eliminate_contiguous{"gpu::contiguous"},
         dead_code_elimination{},
-        replace_allocate{gpu_allocation_model{}, options.offload_copy},
-        dead_code_elimination{},
         eliminate_concat{concat_gpu_optimization{}},
         dead_code_elimination{},
         pack_int8_args{},
@@ -143,6 +141,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
         dead_code_elimination{},
         fuse_ops{&ctx, options.fast_math},
         dead_code_elimination{},
+        replace_allocate{gpu_allocation_model{}, options.offload_copy},
+        dead_code_elimination{},
         compile_ops{&ctx},
         dead_code_elimination{},
         write_literals{&ctx},
...
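replace_allocate moves from early in the pipeline (just after lowering) to after fuse_ops. The ordering matters for the new fusion: find_contiguous_transpose_gemm inserts a plain allocate instruction for the pre-transposed gemm output, and replace_allocate{gpu_allocation_model{}, ...} has to run afterwards to lower that allocation for the GPU target.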
@@ -1551,7 +1551,7 @@ TEST_CASE(test_unsqueeze_step_non_divisable)
     throws_shape(migraphx::make_op("unsqueeze", {{"axes", {2}}, {"steps", {2}}}), s1);
 }

-TEST_CASE(test_unsqueeze_step_non_zero)
+TEST_CASE(test_unsqueeze_step_zero)
 {
     migraphx::shape s1{migraphx::shape::float_type, {4, 5, 12}};
     throws_shape(migraphx::make_op("unsqueeze", {{"axes", {2}}, {"steps", {0}}}), s1);
@@ -1563,6 +1563,12 @@ TEST_CASE(test_unsqueeze_step_at_end)
     throws_shape(migraphx::make_op("unsqueeze", {{"axes", {3}}, {"steps", {2}}}), s1);
 }

+TEST_CASE(test_unsqueeze_mismatch_step_axis)
+{
+    migraphx::shape s1{migraphx::shape::float_type, {4, 5, 12}};
+    throws_shape(migraphx::make_op("unsqueeze", {{"axes", {2}}, {"steps", {2, 3}}}), s1);
+}
+
 TEST_CASE(test_unsqueeze_negative_axis)
 {
     migraphx::shape s1{migraphx::shape::float_type, {4, 5, 3}};
@@ -1659,6 +1665,13 @@ TEST_CASE(test_unsqueeze_multiple_axes_4)
     expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {5, 4, 2}}}), s1);
 }

+TEST_CASE(test_unsqueeze_multiple_axes_step)
+{
+    migraphx::shape s1{migraphx::shape::float_type, {3, 4, 10}};
+    migraphx::shape s2{migraphx::shape::float_type, {3, 4, 2, 5, 1, 1}};
+    expect_shape(s2, migraphx::make_op("unsqueeze", {{"axes", {2, 4, 5}}, {"steps", {2}}}), s1);
+}
+
 TEST_CASE(transpose_shape)
 {
     migraphx::shape input{migraphx::shape::float_type, {2, 2}};
...
@@ -1141,6 +1141,38 @@ TEST_CASE(transpose_contiguous_reshape_binary_broadcast)
     EXPECT(m1 == m2);
 }

+TEST_CASE(transpose_unsqueeze_concat)
+{
+    migraphx::module m1;
+    {
+        auto l0 = m1.add_parameter("0", migraphx::shape{migraphx::shape::float_type, {1, 2, 1, 1}});
+        auto lt0 =
+            m1.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), l0);
+        auto l1 = m1.add_parameter("1", migraphx::shape{migraphx::shape::float_type, {1, 2, 1, 1}});
+        auto lt1 =
+            m1.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), l1);
+        auto l2 = m1.add_parameter("2", migraphx::shape{migraphx::shape::float_type, {1, 2, 1, 1}});
+        auto lt2 =
+            m1.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), l2);
+        std::vector<migraphx::instruction_ref> args{lt0, lt1, lt2};
+        std::vector<migraphx::instruction_ref> unsqueezed_args;
+        int64_t axis = 3;
+
+        std::transform(
+            args.begin(),
+            args.end(),
+            std::back_inserter(unsqueezed_args),
+            [&](migraphx::instruction_ref arg) {
+                return m1.add_instruction(migraphx::make_op("unsqueeze", {{"axes", {axis}}}), arg);
+            });
+        m1.add_instruction(migraphx::make_op("concat", {{"axis", axis}}), unsqueezed_args);
+    }
+    // TODO: This could be simplified to a single transpose after concat
+    migraphx::module m2 = m1;
+    run_pass(m1);
+    EXPECT(m1 == m2);
+}
+
 TEST_CASE(transpose_slice)
 {
     migraphx::module m1;
...