Commit cd96c1c8 authored by turneram's avatar turneram
Browse files

Formatting

parent 37351ed6
...@@ -65,7 +65,8 @@ struct parse_attention : op_parser<parse_attention> ...@@ -65,7 +65,8 @@ struct parse_attention : op_parser<parse_attention>
migraphx::make_op("reshape", migraphx::make_op("reshape",
{{"dims", {batch_size, sequence_length, 3, num_heads, head_size}}}), {{"dims", {batch_size, sequence_length, 3, num_heads, head_size}}}),
add_gemms); add_gemms);
auto transqkv = info.add_instruction(migraphx::make_op("transpose", {{"permutation", {2, 0, 3, 1, 4}}}), add_gemms); auto transqkv = info.add_instruction(
migraphx::make_op("transpose", {{"permutation", {2, 0, 3, 1, 4}}}), add_gemms);
// Q, K, V: each has size BxNxSxH // Q, K, V: each has size BxNxSxH
auto q_t = info.add_instruction( auto q_t = info.add_instruction(
...@@ -99,7 +100,8 @@ struct parse_attention : op_parser<parse_attention> ...@@ -99,7 +100,8 @@ struct parse_attention : op_parser<parse_attention>
auto gemm4 = info.add_instruction(migraphx::make_op("dot"), softmax, v_t); auto gemm4 = info.add_instruction(migraphx::make_op("dot"), softmax, v_t);
// result is BxNxSxH, transpose to BxSxNxH and reshape to BxSxHiddenSize // result is BxNxSxH, transpose to BxSxNxH and reshape to BxSxHiddenSize
gemm4 = info.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1, 3}}}), gemm4); gemm4 = info.add_instruction(
migraphx::make_op("transpose", {{"permutation", {0, 2, 1, 3}}}), gemm4);
return info.add_instruction( return info.add_instruction(
make_op("reshape", {{"dims", {batch_size, sequence_length, num_heads * head_size}}}), make_op("reshape", {{"dims", {batch_size, sequence_length, num_heads * head_size}}}),
gemm4); gemm4);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment