Commit ba7a370a authored by turneram

Formatting

parent eea36256
@@ -87,8 +87,7 @@ struct layernorm
         mean_square = sqrt(mean_square / norm_size - mean * mean + epsilon);
         for(std::size_t i = 0; i < norm_size; ++i)
         {
-            output[offset + i] =
-                (data[offset + i] - mean) / mean_square;
+            output[offset + i] = (data[offset + i] - mean) / mean_square;
             /* if(args.size() == 3)
                 output[offset + i] =
                     (data[offset + i] - mean) / mean_square * weights[i] + bias[i];
...
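For reference, below is a minimal standalone sketch of the per-row normalization this loop performs, using the same mean / mean_square / epsilon formula as the hunk above. The helper name layernorm_rows and the driver in main are illustrative only, not part of MIGraphX.

// Standalone sketch of the normalization above: each row of length norm_size is
// centered by its mean and divided by sqrt(E[x^2] - mean^2 + epsilon).
#include <cmath>
#include <cstddef>
#include <iostream>
#include <vector>

void layernorm_rows(const std::vector<float>& data,
                    std::vector<float>& output,
                    std::size_t norm_size,
                    float epsilon = 1e-5f)
{
    for(std::size_t offset = 0; offset < data.size(); offset += norm_size)
    {
        float mean = 0, mean_square = 0;
        for(std::size_t i = 0; i < norm_size; ++i)
        {
            mean += data[offset + i];
            mean_square += data[offset + i] * data[offset + i];
        }
        mean /= norm_size;
        // sqrt of the biased variance plus epsilon, matching the formula in the diff
        mean_square = std::sqrt(mean_square / norm_size - mean * mean + epsilon);
        for(std::size_t i = 0; i < norm_size; ++i)
            output[offset + i] = (data[offset + i] - mean) / mean_square;
    }
}

int main()
{
    std::vector<float> x{1, 2, 3, 4, 5, 6};
    std::vector<float> y(x.size());
    layernorm_rows(x, y, 3); // two rows of three elements each
    for(float v : y)
        std::cout << v << ' ';
    std::cout << '\n';
}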
@@ -25,12 +25,12 @@ struct parse_attention : op_parser<parse_attention>
         instruction_ref extra_add_qk;
         bool is_past = false;
         bool is_extra_add_qk = false;
-        if (args.size() > 4)
+        if(args.size() > 4)
         {
             past = args[4];
             is_past = true;
         }
-        if (args.size() == 6)
+        if(args.size() == 6)
         {
             is_extra_add_qk = true;
             extra_add_qk = args[5];
@@ -53,14 +53,14 @@ struct parse_attention : op_parser<parse_attention>
         auto head_size = hidden_size / num_heads;
         int past_sequence_length = 0;
         // GetPresent
         // Input and output shapes:
         // past : (2, batch_size, num_heads, past_sequence_length, head_size)
-        // present : (2, batch_size, num_heads, past_sequence_length + sequence_length, head_size)
+        // present : (2, batch_size, num_heads, past_sequence_length + sequence_length,
+        // head_size)
         std::vector<std::size_t> present_lens{2, batch_size, num_heads, sequence_length, head_size};
-        if (is_past)
+        if(is_past)
         {
             auto past_lens = past->get_shape().lens();
             past_sequence_length = past_lens.at(3);
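The GetPresent comment above describes the present-state shape as the past shape with the new sequence length added on the time axis. A small illustrative sketch of that shape arithmetic follows; compute_present_lens is a hypothetical helper used only for illustration, not a MIGraphX API.

// Shape bookkeeping from the GetPresent comment:
// past    : (2, batch_size, num_heads, past_sequence_length, head_size)
// present : (2, batch_size, num_heads, past_sequence_length + sequence_length, head_size)
#include <cstddef>
#include <iostream>
#include <vector>

std::vector<std::size_t> compute_present_lens(const std::vector<std::size_t>& past_lens,
                                              std::size_t sequence_length)
{
    return {2,
            past_lens.at(1),                   // batch_size
            past_lens.at(2),                   // num_heads
            past_lens.at(3) + sequence_length, // past + new sequence length
            past_lens.at(4)};                  // head_size
}

int main()
{
    std::vector<std::size_t> past_lens{2, 1, 12, 16, 64}; // B=1, N=12, past_S=16, H=64
    for(auto d : compute_present_lens(past_lens, 8))      // sequence_length = 8
        std::cout << d << ' ';                            // prints: 2 1 12 24 64
    std::cout << '\n';
}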
@@ -76,24 +76,34 @@ struct parse_attention : op_parser<parse_attention>
         auto bias_type = bias->get_shape().type();
         std::vector<float> ones_vec(m, 1);
         std::vector<std::size_t> ones_lens{1, m};
-        auto ones = info.add_literal(migraphx::literal{migraphx::shape{bias_type, ones_lens}, ones_vec});
+        auto ones =
+            info.add_literal(migraphx::literal{migraphx::shape{bias_type, ones_lens}, ones_vec});
         bias = info.add_instruction(migraphx::make_op("reshape", {{"dims", {n, 1}}}), bias);
-        auto gemm_1 = info.add_instruction(migraphx::make_op("dot"), bias, ones/* info.make_contiguous(mb_bias), info.make_contiguous(ones) */);
-        gemm_1 = info.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), gemm_1);
+        auto gemm_1 = info.add_instruction(
+            migraphx::make_op("dot"),
+            bias,
+            ones /* info.make_contiguous(mb_bias), info.make_contiguous(ones) */);
+        gemm_1 =
+            info.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), gemm_1);
-        /// ORT: Gemm, note that ROCM assumes col-major, so result(N, M) = 1 * weights x input + 1 x B.
-        /// Assume row-major => results(N, M) = 1 * input x weights + 1 x B ?
-        auto input_sq = info.add_instruction(migraphx::make_op("reshape", {{"dims", {batch_size * sequence_length, hidden_size}}}), input);
+        /// ORT: Gemm, note that ROCM assumes col-major, so result(N, M) = 1 * weights x input + 1 x
+        /// B. Assume row-major => results(N, M) = 1 * input x weights + 1 x B ?
+        auto input_sq = info.add_instruction(
+            migraphx::make_op("reshape", {{"dims", {batch_size * sequence_length, hidden_size}}}),
+            input);
         auto gemm_2 = info.add_instruction(migraphx::make_op("dot"), input_sq, weights);
         auto add_gemms = info.add_instruction(migraphx::make_op("add"), gemm_1, gemm_2);
         // LaunchAttentionKernel:
         // LaunchTransQkv
         // input should be BxSx3xNxH => scratch3: 3xBxNxSxH
-        add_gemms = info.add_instruction(migraphx::make_op("reshape", {{"dims", {batch_size, sequence_length, 3, num_heads, head_size}}}), add_gemms);
+        add_gemms = info.add_instruction(
+            migraphx::make_op("reshape",
+                              {{"dims", {batch_size, sequence_length, 3, num_heads, head_size}}}),
+            add_gemms);
         std::vector<std::size_t> qkv_perm{2, 0, 3, 1, 4};
-        auto transqkv = info.add_instruction(migraphx::make_op("transpose", {{"permutation", qkv_perm}}), add_gemms);
+        auto transqkv = info.add_instruction(
+            migraphx::make_op("transpose", {{"permutation", qkv_perm}}), add_gemms);
         // now scratch3 has Q, K, V: each has size BxNxSxH
         // => transqkv has shape 3xBxNxSxH
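As the comments in this hunk note, the {2, 0, 3, 1, 4} transpose turns the BxSx3xNxH tensor into 3xBxNxSxH. Below is a tiny standalone sketch of that dimension permutation; permute_dims and the example sizes are illustrative only.

// Shape effect of the qkv_perm transpose described above:
// dims (B, S, 3, N, H) permuted by {2, 0, 3, 1, 4} become (3, B, N, S, H).
#include <cstddef>
#include <iostream>
#include <vector>

std::vector<std::size_t> permute_dims(const std::vector<std::size_t>& dims,
                                      const std::vector<std::size_t>& perm)
{
    std::vector<std::size_t> out(dims.size());
    for(std::size_t i = 0; i < perm.size(); ++i)
        out[i] = dims[perm[i]]; // output axis i takes its length from input axis perm[i]
    return out;
}

int main()
{
    std::vector<std::size_t> dims{1, 8, 3, 12, 64}; // B=1, S=8, 3, N=12, H=64
    std::vector<std::size_t> qkv_perm{2, 0, 3, 1, 4};
    for(auto d : permute_dims(dims, qkv_perm))
        std::cout << d << ' '; // prints: 3 1 12 8 64
    std::cout << '\n';
}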
@@ -101,20 +111,25 @@ struct parse_attention : op_parser<parse_attention>
         auto size_per_batch = sequence_length * head_size;
         auto total_size = batches * size_per_batch;
-        auto q_t = info.add_instruction(migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), transqkv);
-        auto k_t = info.add_instruction(migraphx::make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), transqkv);
-        auto v_t = info.add_instruction(migraphx::make_op("slice", {{"axes", {0}}, {"starts", {2}}, {"ends", {3}}}), transqkv);
+        auto q_t = info.add_instruction(
+            migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), transqkv);
+        auto k_t = info.add_instruction(
+            migraphx::make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), transqkv);
+        auto v_t = info.add_instruction(
+            migraphx::make_op("slice", {{"axes", {0}}, {"starts", {2}}, {"ends", {3}}}), transqkv);
         q_t = info.add_instruction(make_op("squeeze", {{"axes", {0}}}), q_t);
         k_t = info.add_instruction(make_op("squeeze", {{"axes", {0}}}), k_t);
         v_t = info.add_instruction(make_op("squeeze", {{"axes", {0}}}), v_t);
-        if (is_past)
+        if(is_past)
         {
             k_t = info.add_instruction(migraphx::make_op("concat", {{"axis", 3}}), past, k_t);
-            v_t = info.add_instruction(migraphx::make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {3}}}), k_t);
+            v_t = info.add_instruction(
+                migraphx::make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {3}}}), k_t);
         }
-        // Raw attention mask could be 2D (BxS) or 3D (BxSxS*) or 4D(Bx1xMxM), where M is the max sequence length.
+        // Raw attention mask could be 2D (BxS) or 3D (BxSxS*) or 4D(Bx1xMxM), where M is the max
+        // sequence length.
         auto mask_index_lens = mask_index->get_shape().lens();
         bool use_raw_attention_mask = mask_index_lens.size() >= 2;
@@ -127,27 +142,23 @@ struct parse_attention : op_parser<parse_attention>
         // For raw attention mask, the scalar if 1/sqrt(H) is moved to softmax computation.
         const float alpha = use_raw_attention_mask ? 1.0 : rsqrt_head_size;
         // K{B,N,S,H} -> K'{B,N,H,S}
         k_t = info.add_instruction(make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), k_t);
         auto gemm3 = info.add_instruction(migraphx::make_op("dot"), q_t, k_t);
-        if (is_extra_add_qk)
+        if(is_extra_add_qk)
             gemm3 = info.add_instruction(make_op("add"), gemm3, extra_add_qk);
         auto alpha_lit = info.add_instruction(
             migraphx::make_op("multibroadcast", {{"out_lens", gemm3->get_shape().lens()}}),
-            info.add_literal(migraphx::literal{migraphx::shape{gemm3->get_shape().type()}, {alpha}}));
-        gemm3 = info.add_instruction(migraphx::make_op("mul"), gemm3, info.make_contiguous(alpha_lit));
+            info.add_literal(
+                migraphx::literal{migraphx::shape{gemm3->get_shape().type()}, {alpha}}));
+        gemm3 =
+            info.add_instruction(migraphx::make_op("mul"), gemm3, info.make_contiguous(alpha_lit));
         // apply softmax and store result P to scratch2: BxNxSxS*
-        std::vector<float> mask(batch_size*num_heads*sequence_length*all_sequence_length, 0);
+        std::vector<float> mask(batch_size * num_heads * sequence_length * all_sequence_length, 0);
-        if (false and mask_index_lens.size() >= 2)
-        {
-        }
-        else if (false and mask_index_lens.size() == 1)
+        if(false and mask_index_lens.size() >= 2) {}
+        else if(false and mask_index_lens.size() == 1)
         {
         }
         // else => no mask
         auto softmax = info.add_instruction(migraphx::make_op("softmax", {{"axis", 3}}), gemm3);
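This hunk assembles scores = softmax((Q x K') * alpha) out of dot, mul and softmax on axis 3. Below is a minimal single-head, plain-C++ sketch of the same math; attention_scores and the example values are illustrative only and not the MIGraphX program built above.

// Single-head sketch: scores = softmax((Q * K^T) / sqrt(head_size)) over the key axis.
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <iostream>
#include <vector>

// Row-major S x H matrices q and k; the result is a row-stochastic S x S matrix.
std::vector<float> attention_scores(const std::vector<float>& q,
                                    const std::vector<float>& k,
                                    std::size_t seq_len,
                                    std::size_t head_size)
{
    const float alpha = 1.0f / std::sqrt(static_cast<float>(head_size));
    std::vector<float> scores(seq_len * seq_len, 0.0f);
    for(std::size_t i = 0; i < seq_len; ++i)
        for(std::size_t j = 0; j < seq_len; ++j)
        {
            float dot = 0.0f;
            for(std::size_t h = 0; h < head_size; ++h)
                dot += q[i * head_size + h] * k[j * head_size + h]; // Q * K^T
            scores[i * seq_len + j] = dot * alpha;                  // scale by 1/sqrt(H)
        }
    // softmax over the last (key) axis, like the axis-3 softmax on BxNxSxS*
    for(std::size_t i = 0; i < seq_len; ++i)
    {
        float m = scores[i * seq_len];
        for(std::size_t j = 1; j < seq_len; ++j)
            m = std::max(m, scores[i * seq_len + j]);
        float sum = 0.0f;
        for(std::size_t j = 0; j < seq_len; ++j)
        {
            scores[i * seq_len + j] = std::exp(scores[i * seq_len + j] - m);
            sum += scores[i * seq_len + j];
        }
        for(std::size_t j = 0; j < seq_len; ++j)
            scores[i * seq_len + j] /= sum;
    }
    return scores;
}

int main()
{
    std::vector<float> q{1, 0, 0, 1}; // S=2, H=2
    std::vector<float> k{1, 0, 0, 1};
    for(float s : attention_scores(q, k, 2, 2))
        std::cout << s << ' ';
    std::cout << '\n';
}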
@@ -156,8 +167,11 @@ struct parse_attention : op_parser<parse_attention>
         auto gemm4 = info.add_instruction(migraphx::make_op("dot"), softmax, v_t);
         // scratch3 is BxNxSxH, transpose to output BxSxNxH
-        gemm4 = info.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1, 3}}}), gemm4);
-        gemm4 = info.add_instruction(make_op("reshape", {{"dims", {batch_size, sequence_length, num_heads * head_size}}}), info.make_contiguous(gemm4));
+        gemm4 = info.add_instruction(
+            migraphx::make_op("transpose", {{"permutation", {0, 2, 1, 3}}}), gemm4);
+        gemm4 = info.add_instruction(
+            make_op("reshape", {{"dims", {batch_size, sequence_length, num_heads * head_size}}}),
+            info.make_contiguous(gemm4));
         return gemm4;
     }
 };
...
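The last hunk above transposes the BxNxSxH attention output to BxSxNxH and reshapes it to BxSx(N*H), so head n of token s occupies columns [n*H, (n+1)*H) of row (b, s). A short sketch of the index mapping this implies; the sizes below are made up for illustration.

// Index mapping for transpose {0, 2, 1, 3} followed by a contiguous reshape:
// element (b, n, s, h) of the BxNxSxH input lands at flat position
// ((b * S + s) * N + n) * H + h of the BxSx(N*H) output.
#include <cstddef>
#include <iostream>

int main()
{
    const std::size_t B = 1, N = 2, S = 3, H = 4;
    for(std::size_t b = 0; b < B; ++b)
        for(std::size_t n = 0; n < N; ++n)
            for(std::size_t s = 0; s < S; ++s)
                for(std::size_t h = 0; h < H; ++h)
                {
                    std::size_t in_pos  = ((b * N + n) * S + s) * H + h; // BxNxSxH layout
                    std::size_t out_pos = ((b * S + s) * N + n) * H + h; // BxSx(N*H) layout
                    std::cout << in_pos << " -> " << out_pos << '\n';
                }
}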
@@ -27,9 +27,10 @@ struct parse_layernorm : op_parser<parse_layernorm>
             epsilon = parser.parse_value(info.attributes.at("axis")).at<int64_t>();
         }
-        auto layernorm = info.add_instruction(make_op("layernorm", {{"epsilon", epsilon}, {"axis", axis}}), args.front());
-        if (args.size() == 3)
+        auto layernorm = info.add_instruction(
+            make_op("layernorm", {{"epsilon", epsilon}, {"axis", axis}}), args.front());
+        if(args.size() == 3)
         {
             layernorm = info.add_instruction(make_op("mul"), layernorm, args.at(1));
             layernorm = info.add_instruction(make_op("add"), layernorm, args.at(2));
...
@@ -12,7 +12,8 @@ namespace device {
 void layernorm(hipStream_t stream, const argument& result, const argument& arg1);
-//void layernorm(hipStream_t stream, const argument& result, const argument& arg1, const argument& arg2, const argument& arg3, const int64_t axis);
+// void layernorm(hipStream_t stream, const argument& result, const argument& arg1, const argument&
+// arg2, const argument& arg3, const int64_t axis);
 void triadd_layernorm(hipStream_t stream,
                       const argument& result,
...
@@ -19,7 +19,8 @@ argument hip_layernorm::compute(context& ctx, const shape&, const std::vector<ar
     {
         auto n_dim = args.front().get_shape().lens().size();
         auto tuned_axis = tune_axis(n_dim, op.axis, op.name());
-        device::layernorm(ctx.get_stream().get(), args.back(), args[0], args[1], args[2], tuned_axis);
+        device::layernorm(ctx.get_stream().get(), args.back(), args[0], args[1], args[2],
+                          tuned_axis);
     }
     else */
     std::cout << "calling device::ln" << std::endl;
...
@@ -389,7 +389,7 @@ struct miopen_apply
         apply_map.emplace(op_name, [=](instruction_ref ins) {
             auto output = insert_allocation(ins, ins->get_shape());
             std::vector<instruction_ref> refs = ins->inputs();
-            if (op_name == "layernorm")
+            if(op_name == "layernorm")
             {
                 std::cout << "layernorm op" << std::endl;
             }
...