"docs/_static/vscode:/vscode.git/clone" did not exist on "0997c33310ed5e496c79a2b3c659cacc0a2aeba2"
Commit ed53d437 authored by Paul's avatar Paul
Browse files

Format

parent 6f7ee0b7
...@@ -1075,8 +1075,10 @@ struct find_contiguous_tranpose_gemm ...@@ -1075,8 +1075,10 @@ struct find_contiguous_tranpose_gemm
{ {
auto matcher() const auto matcher() const
{ {
return match::name("gpu::contiguous")(match::arg(0)(match::name("transpose")(match::arg(0)(match::name("gpu::gemm")(match::used_once()).bind("gemm"))).bind("transpose")) return match::name("gpu::contiguous")(match::arg(0)(
); match::name("transpose")(
match::arg(0)(match::name("gpu::gemm")(match::used_once()).bind("gemm")))
.bind("transpose")));
} }
template <class Vector> template <class Vector>
...@@ -1092,33 +1094,35 @@ struct find_contiguous_tranpose_gemm ...@@ -1092,33 +1094,35 @@ struct find_contiguous_tranpose_gemm
void apply(module& m, const match::matcher_result& r) const void apply(module& m, const match::matcher_result& r) const
{ {
auto ins = r.result; auto ins = r.result;
auto gemm = r.instructions["gemm"]; auto gemm = r.instructions["gemm"];
auto alloc = gemm->inputs().back(); auto alloc = gemm->inputs().back();
auto transpose = r.instructions["transpose"]; auto transpose = r.instructions["transpose"];
auto perm = transpose->get_operator().to_value()["permutation"].to_vector<int64_t>(); auto perm = transpose->get_operator().to_value()["permutation"].to_vector<int64_t>();
auto iperm = invert_permutation(perm); auto iperm = invert_permutation(perm);
if (perm.size() < 3) if(perm.size() < 3)
return; return;
if (not is_swapped(perm, perm.size() - 3, perm.size() - 2)) if(not is_swapped(perm, perm.size() - 3, perm.size() - 2))
return; return;
auto lens = gemm->get_shape().lens(); auto lens = gemm->get_shape().lens();
if (lens.size() > 3 and not std::all_of(lens.begin(), lens.end() - 3, [](auto i) { return i == 1; })) if(lens.size() > 3 and
not std::all_of(lens.begin(), lens.end() - 3, [](auto i) { return i == 1; }))
return; return;
auto gemmv = gemm->get_operator().to_value(); auto gemmv = gemm->get_operator().to_value();
gemmv["trans_batch"] = 1; gemmv["trans_batch"] = 1;
auto s = shape{alloc->get_shape().type(), reorder_dims(alloc->get_shape().lens(), iperm)}; auto s = shape{alloc->get_shape().type(), reorder_dims(alloc->get_shape().lens(), iperm)};
auto new_alloc = m.insert_instruction(gemm, make_op("allocate", {{"shape", to_value(s)}})); auto new_alloc = m.insert_instruction(gemm, make_op("allocate", {{"shape", to_value(s)}}));
auto alloc_transpose = m.insert_instruction(gemm, make_op("transpose", {{"permutation", perm}}), new_alloc); auto alloc_transpose =
m.insert_instruction(gemm, make_op("transpose", {{"permutation", perm}}), new_alloc);
auto inputs = gemm->inputs(); auto inputs = gemm->inputs();
inputs.back() = alloc_transpose; inputs.back() = alloc_transpose;
auto new_gemm = m.insert_instruction(gemm, make_op("gpu::gemm", gemmv), inputs); auto new_gemm = m.insert_instruction(gemm, make_op("gpu::gemm", gemmv), inputs);
auto gemm_transpoe = m.insert_instruction(gemm, transpose->get_operator(), new_gemm); auto gemm_transpoe = m.insert_instruction(gemm, transpose->get_operator(), new_gemm);
m.replace_instruction(ins, gemm_transpoe); m.replace_instruction(ins, gemm_transpoe);
......
...@@ -70,14 +70,14 @@ void blas_shape(const shape& s) ...@@ -70,14 +70,14 @@ void blas_shape(const shape& s)
shape transpose_batch(const shape& s, unsigned trans_batch) shape transpose_batch(const shape& s, unsigned trans_batch)
{ {
if (trans_batch == 0) if(trans_batch == 0)
return s; return s;
if (s.lens().size() < 3) if(s.lens().size() < 3)
return s; return s;
auto batch = s.lens().size() - 3; auto batch = s.lens().size() - 3;
std::vector<int64_t> perm(s.lens().size()); std::vector<int64_t> perm(s.lens().size());
std::iota(perm.begin(), perm.end(), 0); std::iota(perm.begin(), perm.end(), 0);
std::swap(perm[batch], perm[batch+trans_batch]); std::swap(perm[batch], perm[batch + trans_batch]);
return reorder_shape(s, perm); return reorder_shape(s, perm);
} }
......
...@@ -48,10 +48,10 @@ template <class Op> ...@@ -48,10 +48,10 @@ template <class Op>
struct rocblas_gemm struct rocblas_gemm
{ {
Op op; Op op;
float alpha = 1; float alpha = 1;
float beta = 0; float beta = 0;
bool int8_x4_format = true; bool int8_x4_format = true;
bool compute_fp32 = false; bool compute_fp32 = false;
unsigned trans_batch = 0; unsigned trans_batch = 0;
template <class Self, class F> template <class Self, class F>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment