"example/vscode:/vscode.git/clone" did not exist on "cef3d91fb0ac81b671abb5e94e9268d2a80cbc08"
Commit 7757cfd0 authored by turneram

Remove non-inference portions of parse_attention

parent 5a62e9e7
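
For orientation before the diff: a standalone sketch of the shape bookkeeping along the inference-only path this commit keeps (fused QKV GEMM, reshape to BxSx3xNxH, transposeqkv to 3xBxNxSxH, slice into Q/K/V, scaled Q*K' scores, softmax, P*V, transposectx and final reshape). It is plain C++ that only prints shapes, not the MIGraphX API, and the example sizes (batch 1, sequence 128, hidden 768, 12 heads) are BERT-base-like assumptions rather than values taken from the commit.

#include <cstddef>
#include <iostream>
#include <vector>

// Print a shape as "name: (d0, d1, ...)".
static void print_dims(const char* name, const std::vector<std::size_t>& dims)
{
    std::cout << name << ": (";
    for(std::size_t i = 0; i < dims.size(); ++i)
        std::cout << dims[i] << (i + 1 < dims.size() ? ", " : ")");
    std::cout << "\n";
}

int main()
{
    // Illustrative BERT-base-like sizes (assumptions, not from the commit).
    const std::size_t batch_size      = 1;
    const std::size_t sequence_length = 128;
    const std::size_t hidden_size     = 768;
    const std::size_t num_heads       = 12;
    const std::size_t head_size       = hidden_size / num_heads; // 64

    // Fused QKV GEMM: (B*S, hidden) x (hidden, 3*hidden), bias broadcast to (B*S, 3*hidden)
    print_dims("fused qkv gemm", {batch_size * sequence_length, 3 * hidden_size});
    // Reshape to BxSx3xNxH, then transposeqkv to 3xBxNxSxH
    print_dims("reshape", {batch_size, sequence_length, 3, num_heads, head_size});
    print_dims("transposeqkv", {3, batch_size, num_heads, sequence_length, head_size});
    // Slice + squeeze: Q, K, V each BxNxSxH
    print_dims("q/k/v", {batch_size, num_heads, sequence_length, head_size});
    // Q x K' (scaled by 1/sqrt(H)): BxNxSxS, softmax over the last axis
    print_dims("scores", {batch_size, num_heads, sequence_length, sequence_length});
    // P x V: BxNxSxH, transposectx to BxSxNxH, reshape to BxSx(N*H)
    print_dims("output", {batch_size, sequence_length, num_heads * head_size});
    return 0;
}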
@@ -11,32 +11,18 @@ struct parse_attention : op_parser<parse_attention>
 {
     std::vector<op_desc> operators() const { return {{"Attention"}}; }
-    instruction_ref parse(const op_desc& /*opd*/,
-                          const onnx_parser& parser,
+    instruction_ref parse(const op_desc& /* opd */,
+                          const onnx_parser& /* parser */,
                           onnx_parser::node_info info,
                           const std::vector<instruction_ref>& args) const
     {
         auto input = args[0];
         auto weights = args[1];
         auto bias = args[2];
-        auto mask_index = args[3];
-        instruction_ref past;
-        instruction_ref extra_add_qk;
-        bool is_past = false;
-        bool is_extra_add_qk = false;
-        if(args.size() > 4)
-        {
-            past = args[4];
-            is_past = true;
-        }
-        if(args.size() == 6)
-        {
-            is_extra_add_qk = true;
-            extra_add_qk = args[5];
-        }
-        // ORT default is 12
+        // mask_index = args[3];
+        // Raw attention mask is 2d (BxS) and all 1s for BERT-base and BERT-large inference
+        // BERT-base default is 12, BERT-large default is 16
         std::size_t num_heads = 12;
         if(contains(info.attributes, "num_heads"))
             num_heads = info.attributes.at("num_heads").i();
@@ -47,30 +33,14 @@ struct parse_attention : op_parser<parse_attention>
         auto sequence_length = input_lens.at(1);
         auto input_hidden_size = input_lens.at(2);
-        // bias shape: (3 * hidden_size)
+        // bias shape= (3 * hidden_size)
         auto bias_lens = bias->get_shape().lens();
         auto hidden_size = bias_lens.at(0) / 3;
         auto head_size = hidden_size / num_heads;
-        int past_sequence_length = 0;
-        // GetPresent
-        // Input and output shapes:
-        // past : (2, batch_size, num_heads, past_sequence_length, head_size)
-        // present : (2, batch_size, num_heads, past_sequence_length + sequence_length,
-        // head_size)
-        std::vector<std::size_t> present_lens{2, batch_size, num_heads, sequence_length, head_size};
-        if(is_past)
-        {
-            auto past_lens = past->get_shape().lens();
-            past_sequence_length = past_lens.at(3);
-            present_lens[3] += past_lens[3];
-        }
         // Use GEMM for fully connection.
         auto m = batch_size * sequence_length;
         auto n = bias_lens.front();
-        auto k = input_hidden_size;
         // Bias shape is (N), broadcast using B(N, M) = 1 * bias(N, 1) x ones(1, M) + 0 * B.
         auto bias_type = bias->get_shape().type();
@@ -83,27 +53,21 @@ struct parse_attention : op_parser<parse_attention>
         gemm_1 =
             info.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), gemm_1);
-        /// Use row-major => results(N, M) = 1 * input x weights + 1 x B
-        auto input_sq = info.add_instruction(
+        /// results(N, M) = 1 * input x weights + 1 x B
+        auto input_rs = info.add_instruction(
             migraphx::make_op("reshape", {{"dims", {batch_size * sequence_length, hidden_size}}}),
             input);
-        auto gemm_2 = info.add_instruction(migraphx::make_op("dot"), input_sq, weights);
+        auto gemm_2 = info.add_instruction(migraphx::make_op("dot"), input_rs, weights);
         auto add_gemms = info.add_instruction(migraphx::make_op("add"), gemm_1, gemm_2);
-        // LaunchTransQkv
-        // input should be BxSx3xNxH => scratch3: 3xBxNxSxH
+        // LaunchTransQkv: BxSx3xNxH => 3xBxNxSxH
         add_gemms = info.add_instruction(
             migraphx::make_op("reshape",
                               {{"dims", {batch_size, sequence_length, 3, num_heads, head_size}}}),
             add_gemms);
         auto transqkv = info.add_instruction(migraphx::make_op("transposeqkv"), add_gemms);
-        // transqkv has shape 3xBxNxSxH
-        // => Q, K, V: each has size BxNxSxH
-        auto batches = batch_size * num_heads;
-        auto size_per_batch = sequence_length * head_size;
-        auto total_size = batches * size_per_batch;
+        // Q, K, V: each has size BxNxSxH
         auto q_t = info.add_instruction(
             migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), transqkv);
         auto k_t = info.add_instruction(
@@ -114,32 +78,12 @@ struct parse_attention : op_parser<parse_attention>
         k_t = info.add_instruction(make_op("squeeze", {{"axes", {0}}}), k_t);
         v_t = info.add_instruction(make_op("squeeze", {{"axes", {0}}}), v_t);
-        if(is_past)
-        {
-            k_t = info.add_instruction(migraphx::make_op("concat", {{"axis", 3}}), past, k_t);
-            v_t = info.add_instruction(
-                migraphx::make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {3}}}), k_t);
-        }
-        // Raw attention mask could be 2D (BxS) or 3D (BxSxS*) or 4D(Bx1xMxM), where M is the max
-        // sequence length.
-        auto mask_index_lens = mask_index->get_shape().lens();
-        bool use_raw_attention_mask = mask_index_lens.size() >= 2;
-        // compute Q*K' (as K'*Q), scaled by 1/sqrt(H) and store in scratch1: BxNxSxS*
-        // Q: BxNxSxH, K (present_k): BxNxS*xH, Q*K': BxNxSxS*
-        const float rsqrt_head_size = 1.f / sqrt(static_cast<float>(head_size));
-        const int all_sequence_length = past_sequence_length + sequence_length;
-        const int temp_matrix_size = sequence_length * all_sequence_length;
-        // For raw attention mask, the scalar if 1/sqrt(H) is moved to softmax computation.
-        const float alpha = use_raw_attention_mask ? 1.0 : rsqrt_head_size;
+        // compute Q*K' scaled by 1/sqrt(H)
+        // Q: BxNxSxH, K (present_k): BxNxSxH => Q*K': BxNxSxS
+        const float alpha = 1.f / sqrt(static_cast<float>(head_size));
         // K{B,N,S,H} -> K'{B,N,H,S}
         k_t = info.add_instruction(make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), k_t);
         auto gemm3 = info.add_instruction(migraphx::make_op("dot"), q_t, k_t);
-        if(is_extra_add_qk)
-            gemm3 = info.add_instruction(make_op("add"), gemm3, extra_add_qk);
         auto alpha_lit = info.add_instruction(
             migraphx::make_op("multibroadcast", {{"out_lens", gemm3->get_shape().lens()}}),
             info.add_literal(
@@ -147,19 +91,18 @@ struct parse_attention : op_parser<parse_attention>
         gemm3 =
             info.add_instruction(migraphx::make_op("mul"), gemm3, info.make_contiguous(alpha_lit));
-        // apply softmax and store result P to scratch2: BxNxSxS*
         // Inference mask is all 1s => masking can be skipped
+        // P = softmax result: BxNxSxS
         auto softmax = info.add_instruction(migraphx::make_op("softmax", {{"axis", 3}}), gemm3);
-        // compute P*V
+        // compute P*V: (BxNxSxS) x (BxNxSxH) => BxNxSxH
         auto gemm4 = info.add_instruction(migraphx::make_op("dot"), softmax, v_t);
-        // result is BxNxSxH, transpose to BxSxNxH and reshape to BxSxN*H
+        // result is BxNxSxH, transpose to BxSxNxH and reshape to BxSxHiddenSize
         gemm4 = info.add_instruction(migraphx::make_op("transposectx"), gemm4);
-        gemm4 = info.add_instruction(
+        return info.add_instruction(
             make_op("reshape", {{"dims", {batch_size, sequence_length, num_heads * head_size}}}),
-            info.make_contiguous(gemm4));
-        return gemm4;
+            gemm4);
     }
 };
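
For reference, a minimal single-head version of the computation the simplified parser now emits: scores = Q*K' / sqrt(H), P = softmax(scores), context = P*V, with no mask, no past state, and no extra_add_qk term. This is a plain C++ sketch over row-major arrays, not MIGraphX instructions; the function name, parameter names, and the tiny sizes in main are illustrative only.

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// q, k, v are S x H matrices stored row-major; returns the S x H context matrix.
std::vector<float> single_head_attention(const std::vector<float>& q,
                                         const std::vector<float>& k,
                                         const std::vector<float>& v,
                                         std::size_t s,
                                         std::size_t h)
{
    const float scale = 1.f / std::sqrt(static_cast<float>(h));
    // scores = Q * K', scaled by 1/sqrt(H): S x S
    std::vector<float> scores(s * s, 0.f);
    for(std::size_t i = 0; i < s; ++i)
        for(std::size_t j = 0; j < s; ++j)
        {
            float acc = 0.f;
            for(std::size_t d = 0; d < h; ++d)
                acc += q[i * h + d] * k[j * h + d];
            scores[i * s + j] = acc * scale;
        }
    // P = softmax(scores) along the last axis (inference mask is all 1s, so no masking)
    for(std::size_t i = 0; i < s; ++i)
    {
        float max_v = scores[i * s];
        for(std::size_t j = 1; j < s; ++j)
            max_v = std::max(max_v, scores[i * s + j]);
        float sum = 0.f;
        for(std::size_t j = 0; j < s; ++j)
        {
            scores[i * s + j] = std::exp(scores[i * s + j] - max_v);
            sum += scores[i * s + j];
        }
        for(std::size_t j = 0; j < s; ++j)
            scores[i * s + j] /= sum;
    }
    // context = P * V: S x H
    std::vector<float> context(s * h, 0.f);
    for(std::size_t i = 0; i < s; ++i)
        for(std::size_t d = 0; d < h; ++d)
        {
            float acc = 0.f;
            for(std::size_t j = 0; j < s; ++j)
                acc += scores[i * s + j] * v[j * h + d];
            context[i * h + d] = acc;
        }
    return context;
}

int main()
{
    const std::size_t s = 4, h = 8; // tiny illustrative sizes
    std::vector<float> q(s * h, 0.1f), k(s * h, 0.2f), v(s * h, 0.3f);
    auto context = single_head_attention(q, k, v, s, h);
    return context.size() == s * h ? 0 : 1;
}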