Commit e2bbfca1 authored by turneram

Formatting

parent 6e67ccad
@@ -87,8 +87,7 @@ struct layernorm
            mean_square = sqrt(mean_square / norm_size - mean * mean + epsilon);
            for(std::size_t i = 0; i < norm_size; ++i)
            {
                output[offset + i] = (data[offset + i] - mean) / mean_square;
                /* if(args.size() == 3)
                    output[offset + i] =
                        (data[offset + i] - mean) / mean_square * weights[i] + bias[i];
......
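Aside (not part of the diff): assuming mean already holds the row mean and mean_square the row sum of squares over the norm_size elements, the kernel above normalizes using the biased-variance identity Var[x] = E[x^2] - (E[x])^2, i.e.

$$ \text{output}_i = \frac{x_i - \mu}{\sqrt{\tfrac{1}{N}\sum_j x_j^2 - \mu^2 + \epsilon}}, \qquad \mu = \frac{1}{N}\sum_j x_j, $$

and the commented-out three-input branch would additionally apply the per-element scale and bias, multiplying the same quotient by weights[i] and adding bias[i].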
@@ -15,25 +15,25 @@ struct parse_attention : op_parser<parse_attention>
                          const onnx_parser& parser,
                          onnx_parser::node_info info,
                          const std::vector<instruction_ref>& args) const
    {
        auto input      = args[0];
        auto weights    = args[1];
        auto bias       = args[2];
        auto mask_index = args[3];
        instruction_ref past;
        instruction_ref extra_add_qk;
        bool is_past         = false;
        bool is_extra_add_qk = false;
        if(args.size() > 4)
        {
            past    = args[4];
            is_past = true;
        }
        if(args.size() == 6)
        {
            is_extra_add_qk = true;
            extra_add_qk    = args[5];
        }
        // ORT default is 12
@@ -42,112 +42,123 @@ struct parse_attention : op_parser<parse_attention>
            num_heads = info.attributes.at("num_heads").i();
        // input shape: (batch_size, sequence_length, input_hidden_size)
        auto input_lens        = input->get_shape().lens();
        auto batch_size        = input_lens.at(0);
        auto sequence_length   = input_lens.at(1);
        auto input_hidden_size = input_lens.at(2);
        // bias shape: (3 * hidden_size)
        auto bias_lens   = bias->get_shape().lens();
        auto hidden_size = bias_lens.at(0) / 3;
        auto head_size   = hidden_size / num_heads;
        int past_sequence_length = 0;
        // GetPresent
        // Input and output shapes:
        // past : (2, batch_size, num_heads, past_sequence_length, head_size)
        // present : (2, batch_size, num_heads, past_sequence_length + sequence_length,
        // head_size)
        std::vector<std::size_t> present_lens{2, batch_size, num_heads, sequence_length, head_size};
        if(is_past)
        {
            auto past_lens       = past->get_shape().lens();
            past_sequence_length = past_lens.at(3);
            present_lens[3] += past_lens[3];
        }
        // Use GEMM for the fully connected projection.
        auto m = batch_size * sequence_length;
        auto n = bias_lens.front();
        auto k = input_hidden_size;
        // Bias shape is (N), broadcast using B(N, M) = 1 * bias(N, 1) x ones(1, M) + 0 * B.
        auto bias_type = bias->get_shape().type();
        std::vector<float> ones_vec(m, 1);
        std::vector<std::size_t> ones_lens{1, m};
        auto ones =
            info.add_literal(migraphx::literal{migraphx::shape{bias_type, ones_lens}, ones_vec});
        bias        = info.add_instruction(migraphx::make_op("reshape", {{"dims", {n, 1}}}), bias);
        auto gemm_1 = info.add_instruction(
            migraphx::make_op("dot"),
            bias,
            ones /* info.make_contiguous(mb_bias), info.make_contiguous(ones) */);
        gemm_1 =
            info.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), gemm_1);
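        // Illustrative aside with hypothetical small sizes (n = 3, m = 2), showing what the
        // broadcast trick above produces; these values are not from the commit.
        //
        //   bias reshaped to (n, 1) = [[b0], [b1], [b2]],   ones literal (1, m) = [[1, 1]]
        //   dot(bias, ones)         = [[b0, b0], [b1, b1], [b2, b2]]   -- shape (n, m)
        //   transpose({1, 0})       = [[b0, b1, b2], [b0, b1, b2]]     -- shape (m, n)
        //
        // So gemm_1 holds the (3 * hidden_size)-long bias row repeated once per position, i.e.
        // for each of the m = batch_size * sequence_length rows of the input projection it is
        // about to be added to.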
        /// ORT: Gemm, note that ROCM assumes col-major, so result(N, M) = 1 * weights x input + 1 x
        /// B. Assume row-major => results(N, M) = 1 * input x weights + 1 x B ?
        auto input_sq = info.add_instruction(
            migraphx::make_op("reshape", {{"dims", {batch_size * sequence_length, hidden_size}}}),
            input);
        auto gemm_2    = info.add_instruction(migraphx::make_op("dot"), input_sq, weights);
        auto add_gemms = info.add_instruction(migraphx::make_op("add"), gemm_1, gemm_2);
        // LaunchAttentionKernel:
        // LaunchTransQkv
        // input should be BxSx3xNxH => scratch3: 3xBxNxSxH
        add_gemms = info.add_instruction(
            migraphx::make_op("reshape",
                              {{"dims", {batch_size, sequence_length, 3, num_heads, head_size}}}),
            add_gemms);
        std::vector<std::size_t> qkv_perm{2, 0, 3, 1, 4};
        auto transqkv = info.add_instruction(
            migraphx::make_op("transpose", {{"permutation", qkv_perm}}), add_gemms);
        // now scratch3 has Q, K, V: each has size BxNxSxH
        // => transqkv has shape 3xBxNxSxH
        auto batches        = batch_size * num_heads;
        auto size_per_batch = sequence_length * head_size;
        auto total_size     = batches * size_per_batch;
        auto q_t = info.add_instruction(
            migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), transqkv);
        auto k_t = info.add_instruction(
            migraphx::make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), transqkv);
        auto v_t = info.add_instruction(
            migraphx::make_op("slice", {{"axes", {0}}, {"starts", {2}}, {"ends", {3}}}), transqkv);
        q_t = info.add_instruction(make_op("squeeze", {{"axes", {0}}}), q_t);
        k_t = info.add_instruction(make_op("squeeze", {{"axes", {0}}}), k_t);
        v_t = info.add_instruction(make_op("squeeze", {{"axes", {0}}}), v_t);
        if(is_past)
        {
            k_t = info.add_instruction(migraphx::make_op("concat", {{"axis", 3}}), past, k_t);
            v_t = info.add_instruction(
                migraphx::make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {3}}}), k_t);
        }
        // Raw attention mask could be 2D (BxS) or 3D (BxSxS*) or 4D(Bx1xMxM), where M is the max
        // sequence length.
        auto mask_index_lens        = mask_index->get_shape().lens();
        bool use_raw_attention_mask = mask_index_lens.size() >= 2;
        // compute Q*K' (as K'*Q), scaled by 1/sqrt(H) and store in scratch1: BxNxSxS*
        // Q: BxNxSxH, K (present_k): BxNxS*xH, Q*K': BxNxSxS*
        const float rsqrt_head_size   = 1.f / sqrt(static_cast<float>(head_size));
        const int all_sequence_length = past_sequence_length + sequence_length;
        const int temp_matrix_size    = sequence_length * all_sequence_length;
        // For the raw attention mask, the scalar 1/sqrt(H) is moved to the softmax computation.
        const float alpha = use_raw_attention_mask ? 1.0 : rsqrt_head_size;
        // K{B,N,S,H} -> K'{B,N,H,S}
        k_t        = info.add_instruction(make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), k_t);
        auto gemm3 = info.add_instruction(migraphx::make_op("dot"), q_t, k_t);
        if(is_extra_add_qk)
            gemm3 = info.add_instruction(make_op("add"), gemm3, extra_add_qk);
        auto alpha_lit = info.add_instruction(
            migraphx::make_op("multibroadcast", {{"out_lens", gemm3->get_shape().lens()}}),
            info.add_literal(
                migraphx::literal{migraphx::shape{gemm3->get_shape().type()}, {alpha}}));
        gemm3 =
            info.add_instruction(migraphx::make_op("mul"), gemm3, info.make_contiguous(alpha_lit));
        // apply softmax and store result P to scratch2: BxNxSxS*
        std::vector<float> mask(batch_size * num_heads * sequence_length * all_sequence_length, 0);
        if(false and mask_index_lens.size() >= 2) {}
        else if(false and mask_index_lens.size() == 1)
        {
        }
        // else => no mask
        auto softmax = info.add_instruction(migraphx::make_op("softmax", {{"axis", 3}}), gemm3);
@@ -156,8 +167,11 @@ struct parse_attention : op_parser<parse_attention>
        auto gemm4 = info.add_instruction(migraphx::make_op("dot"), softmax, v_t);
        // scratch3 is BxNxSxH, transpose to output BxSxNxH
        gemm4 = info.add_instruction(
            migraphx::make_op("transpose", {{"permutation", {0, 2, 1, 3}}}), gemm4);
        gemm4 = info.add_instruction(
            make_op("reshape", {{"dims", {batch_size, sequence_length, num_heads * head_size}}}),
            info.make_contiguous(gemm4));
        return gemm4;
    }
};
......
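Aside (not part of the diff): with the masking branches still stubbed out, the graph built by parse_attention reduces to maskless scaled dot-product attention over per-head Q, K, V tensors in the BxNxSxH layout. A minimal plain-C++ reference of that computation is sketched below; the names, buffer layout (flat row-major), and the tiny main() are made up for illustration, and the past/extra_add_qk paths are ignored.

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Naive reference: out[b][i] = softmax(Q[b][i] . K[b]^T / sqrt(h)) . V[b],
// where b indexes the flattened batch * num_heads dimension, s is the
// sequence length and h the head size.
void attention_ref(const std::vector<float>& q,
                   const std::vector<float>& k,
                   const std::vector<float>& v,
                   std::vector<float>& out,
                   int bn,
                   int s,
                   int h)
{
    const float scale = 1.0f / std::sqrt(static_cast<float>(h));
    std::vector<float> scores(s);
    for(int b = 0; b < bn; ++b)
    {
        const float* Q = q.data() + b * s * h;
        const float* K = k.data() + b * s * h;
        const float* V = v.data() + b * s * h;
        float* O       = out.data() + b * s * h;
        for(int i = 0; i < s; ++i)
        {
            // scaled dot products of query row i against every key row
            float smax = -1e30f;
            for(int j = 0; j < s; ++j)
            {
                float dot = 0.0f;
                for(int d = 0; d < h; ++d)
                    dot += Q[i * h + d] * K[j * h + d];
                scores[j] = dot * scale;
                smax      = std::max(smax, scores[j]);
            }
            // numerically stable softmax over the key dimension
            float denom = 0.0f;
            for(int j = 0; j < s; ++j)
            {
                scores[j] = std::exp(scores[j] - smax);
                denom += scores[j];
            }
            // weighted sum of value rows
            for(int d = 0; d < h; ++d)
            {
                float acc = 0.0f;
                for(int j = 0; j < s; ++j)
                    acc += scores[j] * V[j * h + d];
                O[i * h + d] = acc / denom;
            }
        }
    }
}

int main()
{
    // one head, two positions, head_size = 2
    std::vector<float> q{1, 0, 0, 1}, k{1, 0, 0, 1}, v{1, 2, 3, 4}, out(4, 0);
    attention_ref(q, k, v, out, 1, 2, 2);
    for(float x : out)
        std::printf("%f\n", x);
}

The graph in the parser additionally concatenates past keys/values and adds the optional extra_add_qk term before the softmax, as shown in the code above.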
@@ -17,7 +17,7 @@ struct parse_layernorm : op_parser<parse_layernorm>
                          const std::vector<instruction_ref>& args) const
    {
        float epsilon = 1e-3f;
        int64_t axis  = -1;
        if(contains(info.attributes, "epsilon"))
        {
            epsilon = parser.parse_value(info.attributes.at("epsilon")).at<float>();
@@ -27,9 +27,10 @@ struct parse_layernorm : op_parser<parse_layernorm>
            axis = parser.parse_value(info.attributes.at("axis")).at<int64_t>();
        }
        auto layernorm = info.add_instruction(
            make_op("layernorm", {{"epsilon", epsilon}, {"axis", axis}}), args.front());

        if(args.size() == 3)
        {
            layernorm = info.add_instruction(make_op("mul"), layernorm, args.at(1));
            layernorm = info.add_instruction(make_op("add"), layernorm, args.at(2));
......
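Aside (not part of the diff): assuming this parser targets the usual LayerNormalization contract, the three-input case above composes the plain layernorm op with an elementwise scale and bias,

$$ y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} \odot \gamma + \beta, $$

with args.at(1) supplying the scale gamma and args.at(2) the bias beta.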
@@ -12,7 +12,8 @@ namespace device {
void layernorm(hipStream_t stream, const argument& result, const argument& arg1);
// void layernorm(hipStream_t stream, const argument& result, const argument& arg1, const argument&
// arg2, const argument& arg3, const int64_t axis);
void triadd_layernorm(hipStream_t stream,
                      const argument& result,
......
@@ -19,12 +19,13 @@ argument hip_layernorm::compute(context& ctx, const shape&, const std::vector<ar
    {
        auto n_dim      = args.front().get_shape().lens().size();
        auto tuned_axis = tune_axis(n_dim, op.axis, op.name());
        device::layernorm(ctx.get_stream().get(), args.back(), args[0], args[1], args[2],
                          tuned_axis);
    }
    else */
    std::cout << "calling device::ln" << std::endl;
    {
        device::layernorm(ctx.get_stream().get(), args.back(), args[0]);
        std::cout << "called device::ln" << std::endl;
    }
......
@@ -394,7 +394,7 @@ struct miopen_apply
        apply_map.emplace(op_name, [=](instruction_ref ins) {
            auto output = insert_allocation(ins, ins->get_shape());
            std::vector<instruction_ref> refs = ins->inputs();
            if(op_name == "layernorm")
            {
                std::cout << "layernorm op" << std::endl;
            }
......