Commit ed53d437 authored by Paul
Browse files

Format

parent 6f7ee0b7
...@@ -1075,8 +1075,10 @@ struct find_contiguous_tranpose_gemm ...@@ -1075,8 +1075,10 @@ struct find_contiguous_tranpose_gemm
{ {
auto matcher() const auto matcher() const
{ {
return match::name("gpu::contiguous")(match::arg(0)(match::name("transpose")(match::arg(0)(match::name("gpu::gemm")(match::used_once()).bind("gemm"))).bind("transpose")) return match::name("gpu::contiguous")(match::arg(0)(
); match::name("transpose")(
match::arg(0)(match::name("gpu::gemm")(match::used_once()).bind("gemm")))
.bind("transpose")));
} }
template <class Vector> template <class Vector>
...@@ -1099,14 +1101,15 @@ struct find_contiguous_tranpose_gemm ...@@ -1099,14 +1101,15 @@ struct find_contiguous_tranpose_gemm
auto perm = transpose->get_operator().to_value()["permutation"].to_vector<int64_t>(); auto perm = transpose->get_operator().to_value()["permutation"].to_vector<int64_t>();
auto iperm = invert_permutation(perm); auto iperm = invert_permutation(perm);
if (perm.size() < 3) if(perm.size() < 3)
return; return;
if (not is_swapped(perm, perm.size() - 3, perm.size() - 2)) if(not is_swapped(perm, perm.size() - 3, perm.size() - 2))
return; return;
auto lens = gemm->get_shape().lens(); auto lens = gemm->get_shape().lens();
if (lens.size() > 3 and not std::all_of(lens.begin(), lens.end() - 3, [](auto i) { return i == 1; })) if(lens.size() > 3 and
not std::all_of(lens.begin(), lens.end() - 3, [](auto i) { return i == 1; }))
return; return;
auto gemmv = gemm->get_operator().to_value(); auto gemmv = gemm->get_operator().to_value();
...@@ -1114,7 +1117,8 @@ struct find_contiguous_tranpose_gemm ...@@ -1114,7 +1117,8 @@ struct find_contiguous_tranpose_gemm
auto s = shape{alloc->get_shape().type(), reorder_dims(alloc->get_shape().lens(), iperm)}; auto s = shape{alloc->get_shape().type(), reorder_dims(alloc->get_shape().lens(), iperm)};
auto new_alloc = m.insert_instruction(gemm, make_op("allocate", {{"shape", to_value(s)}})); auto new_alloc = m.insert_instruction(gemm, make_op("allocate", {{"shape", to_value(s)}}));
auto alloc_transpose = m.insert_instruction(gemm, make_op("transpose", {{"permutation", perm}}), new_alloc); auto alloc_transpose =
m.insert_instruction(gemm, make_op("transpose", {{"permutation", perm}}), new_alloc);
auto inputs = gemm->inputs(); auto inputs = gemm->inputs();
inputs.back() = alloc_transpose; inputs.back() = alloc_transpose;
......
...@@ -70,14 +70,14 @@ void blas_shape(const shape& s) ...@@ -70,14 +70,14 @@ void blas_shape(const shape& s)
// Returns `s` with the axis at index `size - 3` (called `batch` below) swapped
// with the axis `trans_batch` positions after it.
//
// Parameters:
//   s           - shape whose dimensions are permuted.
//   trans_batch - offset, relative to the `batch` axis, of the axis to swap
//                 with. 0 means "no transposition".
//
// Returns `s` unchanged when `trans_batch` is 0 or when `s` has fewer than
// three dimensions; otherwise the result of `reorder_shape(s, perm)` where
// `perm` is the identity permutation with the two axes exchanged.
//
// Throws std::out_of_range when `batch + trans_batch` is not a valid axis of
// `s` (i.e. `trans_batch > 2`).
shape transpose_batch(const shape& s, unsigned trans_batch)
{
    if(trans_batch == 0)
        return s;
    const auto ndim = s.lens().size();
    if(ndim < 3)
        return s;
    const auto batch = ndim - 3;
    std::vector<int64_t> perm(ndim);
    std::iota(perm.begin(), perm.end(), 0);
    // `batch` is always in range here, but `batch + trans_batch` is only valid
    // for trans_batch <= 2. Use the bounds-checked accessor so an oversized
    // offset raises an exception instead of reading/writing past the end of
    // `perm`.
    std::swap(perm[batch], perm.at(batch + trans_batch));
    return reorder_shape(s, perm);
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment