Commit ed53d437 authored by Paul

Format

parent 6f7ee0b7
@@ -1075,8 +1075,10 @@ struct find_contiguous_tranpose_gemm
 {
     auto matcher() const
     {
-        return match::name("gpu::contiguous")(match::arg(0)(match::name("transpose")(match::arg(0)(match::name("gpu::gemm")(match::used_once()).bind("gemm"))).bind("transpose"))
-        );
+        return match::name("gpu::contiguous")(match::arg(0)(
+            match::name("transpose")(
+                match::arg(0)(match::name("gpu::gemm")(match::used_once()).bind("gemm")))
+                .bind("transpose")));
     }
     template <class Vector>
@@ -1092,33 +1094,35 @@ struct find_contiguous_tranpose_gemm
     void apply(module& m, const match::matcher_result& r) const
     {
-        auto ins = r.result;
-        auto gemm = r.instructions["gemm"];
-        auto alloc = gemm->inputs().back();
+        auto ins       = r.result;
+        auto gemm      = r.instructions["gemm"];
+        auto alloc     = gemm->inputs().back();
         auto transpose = r.instructions["transpose"];
-        auto perm = transpose->get_operator().to_value()["permutation"].to_vector<int64_t>();
-        auto iperm = invert_permutation(perm);
+        auto perm      = transpose->get_operator().to_value()["permutation"].to_vector<int64_t>();
+        auto iperm     = invert_permutation(perm);
-        if (perm.size() < 3)
+        if(perm.size() < 3)
             return;
-        if (not is_swapped(perm, perm.size() - 3, perm.size() - 2))
+        if(not is_swapped(perm, perm.size() - 3, perm.size() - 2))
             return;
         auto lens = gemm->get_shape().lens();
-        if (lens.size() > 3 and not std::all_of(lens.begin(), lens.end() - 3, [](auto i) { return i == 1; }))
+        if(lens.size() > 3 and
+           not std::all_of(lens.begin(), lens.end() - 3, [](auto i) { return i == 1; }))
             return;
-        auto gemmv = gemm->get_operator().to_value();
+        auto gemmv           = gemm->get_operator().to_value();
         gemmv["trans_batch"] = 1;
         auto s = shape{alloc->get_shape().type(), reorder_dims(alloc->get_shape().lens(), iperm)};
         auto new_alloc = m.insert_instruction(gemm, make_op("allocate", {{"shape", to_value(s)}}));
-        auto alloc_transpose = m.insert_instruction(gemm, make_op("transpose", {{"permutation", perm}}), new_alloc);
+        auto alloc_transpose =
+            m.insert_instruction(gemm, make_op("transpose", {{"permutation", perm}}), new_alloc);
-        auto inputs = gemm->inputs();
-        inputs.back() = alloc_transpose;
-        auto new_gemm = m.insert_instruction(gemm, make_op("gpu::gemm", gemmv), inputs);
+        auto inputs   = gemm->inputs();
+        inputs.back() = alloc_transpose;
+        auto new_gemm = m.insert_instruction(gemm, make_op("gpu::gemm", gemmv), inputs);
         auto gemm_transpoe = m.insert_instruction(gemm, transpose->get_operator(), new_gemm);
         m.replace_instruction(ins, gemm_transpoe);
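Note: the rewrite above eliminates the gpu::contiguous copy by letting the gemm write directly into a transposed allocation. The guards require the transpose permutation to be a plain swap of the dimensions at positions size-3 and size-2, and any extra leading dimensions of the gemm output to be 1; the gemm is then re-created with trans_batch = 1 against a transposed view of a fresh allocation. The sketch below is illustrative only: is_swapped and invert_permutation here are stand-ins for the helpers the pass actually calls, written to show the property being checked.

// Illustrative stand-ins, not the MIGraphX implementations.
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <numeric>
#include <utility>
#include <vector>

// True when perm equals the identity permutation with positions i and j exchanged.
bool is_swapped(const std::vector<int64_t>& perm, std::size_t i, std::size_t j)
{
    if(i >= perm.size() or j >= perm.size())
        return false;
    std::vector<int64_t> expected(perm.size());
    std::iota(expected.begin(), expected.end(), 0);
    std::swap(expected[i], expected[j]);
    return perm == expected;
}

// Inverse permutation: iperm[perm[i]] == i for every i.
std::vector<int64_t> invert_permutation(const std::vector<int64_t>& perm)
{
    std::vector<int64_t> iperm(perm.size());
    for(std::size_t i = 0; i < perm.size(); i++)
        iperm[perm[i]] = static_cast<int64_t>(i);
    return iperm;
}

int main()
{
    // A rank-3 transpose that swaps the two dimensions feeding the gemm
    // (positions size-3 and size-2) passes the checks in apply().
    std::vector<int64_t> perm = {1, 0, 2};
    assert(is_swapped(perm, perm.size() - 3, perm.size() - 2));
    // A single swap is its own inverse, so iperm == perm here.
    assert(invert_permutation(perm) == perm);
    return 0;
}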
@@ -70,14 +70,14 @@ void blas_shape(const shape& s)
 shape transpose_batch(const shape& s, unsigned trans_batch)
 {
-    if (trans_batch == 0)
+    if(trans_batch == 0)
         return s;
-    if (s.lens().size() < 3)
+    if(s.lens().size() < 3)
         return s;
     auto batch = s.lens().size() - 3;
     std::vector<int64_t> perm(s.lens().size());
     std::iota(perm.begin(), perm.end(), 0);
-    std::swap(perm[batch], perm[batch+trans_batch]);
+    std::swap(perm[batch], perm[batch + trans_batch]);
     return reorder_shape(s, perm);
 }
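Note: trans_batch is an offset from the batch dimension (lens.size() - 3), so trans_batch = 1 swaps the batch dimension with the one after it. A worked example of the permutation this builds; batch_permutation is a hypothetical helper used only for illustration, mirroring the logic of transpose_batch without migraphx::shape or reorder_shape.

// Illustration only: same permutation logic as transpose_batch above.
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <numeric>
#include <utility>
#include <vector>

std::vector<int64_t> batch_permutation(std::size_t rank, unsigned trans_batch)
{
    std::vector<int64_t> perm(rank);
    std::iota(perm.begin(), perm.end(), 0);
    // Mirrors the early returns: no swap requested, or fewer than 3 dimensions.
    if(trans_batch == 0 or rank < 3)
        return perm;
    auto batch = rank - 3;
    std::swap(perm[batch], perm[batch + trans_batch]);
    return perm;
}

int main()
{
    // Rank 3, trans_batch = 1: batch = 0, so dimensions 0 and 1 are exchanged.
    for(auto d : batch_permutation(3, 1))
        std::cout << d << ' '; // prints: 1 0 2
    std::cout << '\n';
    // Rank 4, trans_batch = 1: batch = 1, so dimensions 1 and 2 are exchanged.
    for(auto d : batch_permutation(4, 1))
        std::cout << d << ' '; // prints: 0 2 1 3
    std::cout << '\n';
    return 0;
}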
@@ -48,10 +48,10 @@ template <class Op>
 struct rocblas_gemm
 {
     Op op;
-    float alpha = 1;
-    float beta = 0;
-    bool int8_x4_format = true;
-    bool compute_fp32 = false;
+    float alpha          = 1;
+    float beta           = 0;
+    bool int8_x4_format  = true;
+    bool compute_fp32    = false;
     unsigned trans_batch = 0;
     template <class Self, class F>