Commit b5d93611 authored by turneram

Merge remote-tracking branch 'origin/bert-attention-no-transpose-ops' into bert-perf2

parents 27dc554d 38163d54
#include <cmath>
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/instruction.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
struct parse_attention : op_parser<parse_attention>
{
std::vector<op_desc> operators() const { return {{"Attention"}}; }
instruction_ref parse(const op_desc& /* opd */,
const onnx_parser& /* parser */,
onnx_parser::node_info info,
const std::vector<instruction_ref>& args) const
{
auto input = args[0];
auto weights = args[1];
auto bias = args[2];
// mask_index = args[3];
// The raw attention mask is 2-D (B x S) and all ones for BERT-base and BERT-large
// inference, so masking is skipped below.
// Default num_heads: 12 for BERT-base, 16 for BERT-large.
std::size_t num_heads = 12;
if(contains(info.attributes, "num_heads"))
num_heads = info.attributes.at("num_heads").i();
// input shape: (batch_size, sequence_length, input_hidden_size)
auto input_lens = input->get_shape().lens();
auto batch_size = input_lens.at(0);
auto sequence_length = input_lens.at(1);
// bias shape= (3 * hidden_size)
auto bias_lens = bias->get_shape().lens();
auto hidden_size = bias_lens.at(0) / 3;
auto head_size = hidden_size / num_heads;
// Use a GEMM for the fully connected QKV projection.
auto m = batch_size * sequence_length;
auto n = bias_lens.front();
// Bias shape is (N), broadcast using B(N, M) = 1 * bias(N, 1) x ones(1, M) + 0 * B.
auto bias_type = bias->get_shape().type();
std::vector<float> ones_vec(m, 1);
std::vector<std::size_t> ones_lens{1, m};
auto ones =
info.add_literal(migraphx::literal{migraphx::shape{bias_type, ones_lens}, ones_vec});
bias = info.add_instruction(migraphx::make_op("reshape", {{"dims", {n, 1}}}), bias);
auto gemm_1 = info.add_instruction(migraphx::make_op("dot"), bias, ones);
gemm_1 =
info.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), gemm_1);
// results(M, N) = input x weights + B, with B the broadcast bias computed above
auto input_rs = info.add_instruction(
migraphx::make_op("reshape", {{"dims", {batch_size * sequence_length, hidden_size}}}),
input);
auto gemm_2 = info.add_instruction(migraphx::make_op("dot"), input_rs, weights);
auto add_gemms = info.add_instruction(migraphx::make_op("add"), gemm_1, gemm_2);
// LaunchTransQkv: BxSx3xNxH => 3xBxNxSxH
add_gemms = info.add_instruction(
migraphx::make_op("reshape",
{{"dims", {batch_size, sequence_length, 3, num_heads, head_size}}}),
add_gemms);
auto transqkv = info.add_instruction(
migraphx::make_op("transpose", {{"permutation", {2, 0, 3, 1, 4}}}), add_gemms);
// Q, K, V: each has size BxNxSxH
auto q_t = info.add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), transqkv);
auto k_t = info.add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), transqkv);
auto v_t = info.add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {2}}, {"ends", {3}}}), transqkv);
q_t = info.add_instruction(make_op("squeeze", {{"axes", {0}}}), q_t);
k_t = info.add_instruction(make_op("squeeze", {{"axes", {0}}}), k_t);
v_t = info.add_instruction(make_op("squeeze", {{"axes", {0}}}), v_t);
// compute Q*K' scaled by 1/sqrt(H)
// Q: BxNxSxH, K (present_k): BxNxSxH => Q*K': BxNxSxS
const float alpha = 1.0f / std::sqrt(static_cast<float>(head_size));
// K{B,N,S,H} -> K'{B,N,H,S}
k_t = info.add_instruction(make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), k_t);
auto gemm3 = info.add_instruction(migraphx::make_op("dot"), q_t, k_t);
auto alpha_lit = info.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", gemm3->get_shape().lens()}}),
info.add_literal(
migraphx::literal{migraphx::shape{gemm3->get_shape().type()}, {alpha}}));
gemm3 =
info.add_instruction(migraphx::make_op("mul"), gemm3, info.make_contiguous(alpha_lit));
// Inference mask is all 1s => masking can be skipped
// P = softmax result: BxNxSxS
auto softmax = info.add_instruction(migraphx::make_op("softmax", {{"axis", 3}}), gemm3);
// compute P*V: (BxNxSxS) x (BxNxSxH) => BxNxSxH
auto gemm4 = info.add_instruction(migraphx::make_op("dot"), softmax, v_t);
// result is BxNxSxH, transpose to BxSxNxH and reshape to BxSxHiddenSize
// transposeCtx
gemm4 = info.add_instruction(
migraphx::make_op("transpose", {{"permutation", {0, 2, 1, 3}}}), gemm4);
return info.add_instruction(
make_op("reshape", {{"dims", {batch_size, sequence_length, num_heads * head_size}}}),
info.make_contiguous(gemm4));
}
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
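For reference, the graph built by this parser computes the same result as the following NumPy sketch (illustrative only, not part of the commit; the name attention_ref is ours, and the mask is assumed to be all ones as noted above):

import numpy as np

def attention_ref(x, w, b, num_heads):
    # x: (B, S, hidden), w: (hidden, 3*hidden), b: (3*hidden,)
    B, S, _ = x.shape
    hidden = b.shape[0] // 3
    head = hidden // num_heads
    qkv = x.reshape(B * S, -1) @ w + b                    # fully connected QKV projection
    qkv = qkv.reshape(B, S, 3, num_heads, head)
    qkv = qkv.transpose(2, 0, 3, 1, 4)                    # BxSx3xNxH -> 3xBxNxSxH
    q, k, v = qkv[0], qkv[1], qkv[2]                      # each (B, N, S, H)
    scores = q @ k.transpose(0, 1, 3, 2) / np.sqrt(head)  # Q*K' scaled: (B, N, S, S)
    e = np.exp(scores - scores.max(-1, keepdims=True))
    p = e / e.sum(-1, keepdims=True)                      # softmax over the last axis
    out = p @ v                                           # (B, N, S, H)
    return out.transpose(0, 2, 1, 3).reshape(B, S, hidden)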
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/instruction.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
struct parse_layernorm : op_parser<parse_layernorm>
{
std::vector<op_desc> operators() const { return {{"LayerNormalization"}}; }
instruction_ref parse(const op_desc& /*opd*/,
const onnx_parser& parser,
onnx_parser::node_info info,
const std::vector<instruction_ref>& args) const
{
// Decompose the fused LayerNormalization op into primitive ops so MIGraphX
// can handle the fusion itself.
auto x = args.front();
auto x_type = x->get_shape().type();
auto weights = args.at(1);
auto bias = args.at(2);
float epsilon = 1e-12f;
int64_t axis = -1;
if(contains(info.attributes, "epsilon"))
{
epsilon = parser.parse_value(info.attributes.at("epsilon")).at<float>();
}
if(contains(info.attributes, "axis"))
{
axis = parser.parse_value(info.attributes.at("axis")).at<int64_t>();
}
auto epsilon_lit = info.add_literal(literal{shape{x_type, {1}}, {epsilon}});
auto exponent = info.add_literal(literal{shape{x_type, {1}}, {2.0}});
auto dims = x->get_shape().lens();
auto mean = info.add_instruction(migraphx::make_op("reduce_mean", {{"axes", {axis}}}), x);
auto mean_mbcast =
info.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", dims}}), mean);
auto sub = info.add_instruction(migraphx::make_op("sub"), x, mean_mbcast);
auto exponent_mbcast = info.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", dims}}), exponent);
auto pow = info.add_instruction(migraphx::make_op("pow"), sub, exponent_mbcast);
auto var = info.add_instruction(migraphx::make_op("reduce_mean", {{"axes", {axis}}}), pow);
auto add_epsilon = info.add_broadcastable_binary_op("add", var, epsilon_lit);
auto sqrt = info.add_instruction(migraphx::make_op("sqrt"), add_epsilon);
auto div = info.add_broadcastable_binary_op("div", sub, sqrt);
auto mul = info.add_broadcastable_binary_op("mul", div, weights);
return info.add_broadcastable_binary_op("add", mul, bias);
}
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
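The op sequence above (reduce_mean, sub, pow, reduce_mean, add, sqrt, div, mul, add) is equivalent to this NumPy sketch (illustrative only; layernorm_ref is our name):

import numpy as np

def layernorm_ref(x, w, b, axis=-1, epsilon=1e-12):
    mean = x.mean(axis=axis, keepdims=True)
    var = ((x - mean) ** 2).mean(axis=axis, keepdims=True)
    return (x - mean) / np.sqrt(var + epsilon) * w + b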
attention_test.onnx: binary ONNX protobuf for the attention_test graph (inputs: input, weights, bias, mask_index; output: result; attribute num_heads; not human-readable).
@@ -210,6 +210,27 @@ def atanh_test():
    return ([node], [x], [y])
@onnx_test
def attention_test():
    input = helper.make_tensor_value_info('input', TensorProto.FLOAT,
                                          [2, 384, 768])
    weights = helper.make_tensor_value_info('weights', TensorProto.FLOAT,
                                            [768, 2304])
    bias = helper.make_tensor_value_info('bias', TensorProto.FLOAT, [2304])
    mask_index = helper.make_tensor_value_info('mask_index', TensorProto.INT64,
                                               [2, 384])
    result = helper.make_tensor_value_info('result', TensorProto.FLOAT,
                                           [2, 384, 768])
    node = helper.make_node('Attention',
                            inputs=['input', 'weights', 'bias', 'mask_index'],
                            outputs=['result'],
                            num_heads=12,
                            name="Attention_0")
    return ([node], [input, weights, bias, mask_index], [result])
@onnx_test
def averagepool_1d_test():
    x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [1, 3, 5])
@@ -2831,6 +2852,23 @@ def layernorm_test():
        bias_add], [x, scale, bias], [y], [pow_tensor, epsilon_tensor])
@onnx_test
def layernorm_op_test():
    x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [1, 2, 3])
    w = helper.make_tensor_value_info('w', TensorProto.FLOAT, [3])
    b = helper.make_tensor_value_info('b', TensorProto.FLOAT, [3])
    output = helper.make_tensor_value_info('output', TensorProto.FLOAT,
                                           [1, 2, 3])
    node = onnx.helper.make_node('LayerNormalization',
                                 inputs=['x', 'w', 'b'],
                                 outputs=["output"],
                                 epsilon=1e-5,
                                 axis=-1)
    return ([node], [x, w, b], [output])
@onnx_test
def leaky_relu_test():
    x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [3])
layernorm_op_test.onnx: binary ONNX protobuf for the layernorm_op_test graph (inputs: x, w, b; output: output; attributes axis and epsilon; not human-readable).
@@ -32,6 +32,32 @@
#include <migraphx/onnx.hpp>
#include "test.hpp"
TEST_CASE(attention_test)
{
auto p = migraphx::parse_onnx("attention_test.onnx");
p.compile(migraphx::ref::target{});
migraphx::shape s_i{migraphx::shape::float_type, {2, 384, 768}};
migraphx::shape s_w{migraphx::shape::float_type, {768, 2304}};
migraphx::shape s_b{migraphx::shape::float_type, {2304}};
migraphx::shape s_m{migraphx::shape::int64_type, {2, 384}};
std::vector<float> input_v(2 * 384 * 768, 1);
std::vector<float> weights_v(768 * 2304, 1);
std::vector<float> bias_v(2304, 1);
std::vector<int64_t> mask_index_v(2 * 384, 1);
migraphx::parameter_map pp;
pp["input"] = migraphx::argument(s_i, input_v.data());
pp["weights"] = migraphx::argument(s_w, weights_v.data());
pp["bias"] = migraphx::argument(s_b, bias_v.data());
pp["mask_index"] = migraphx::argument(s_m, mask_index_v.data());
auto result = p.eval(pp).back();
std::vector<float> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
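// With all-ones inputs, every QKV projection element is 768 (a dot product of
// 768 ones) plus a bias of 1, i.e. 769. Softmax over identical scores is the
// uniform 1/384, so P*V averages 384 copies of 769, leaving 769 everywhere.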
std::vector<float> gold(2 * 384 * 768, 769);
EXPECT(migraphx::verify_range(result_vector, gold));
}
TEST_CASE(averagepool_notset_test)
{
auto p = migraphx::parse_onnx("averagepool_notset_test.onnx");
@@ -469,6 +495,31 @@ TEST_CASE(instance_norm_3d_test)
EXPECT(migraphx::verify_range(result_vector, gold));
}
TEST_CASE(layernorm_op_test)
{
migraphx::program p = migraphx::parse_onnx("layernorm_op_test.onnx");
p.compile(migraphx::ref::target{});
migraphx::shape sx{migraphx::shape::float_type, {1, 2, 3}};
migraphx::shape swb{migraphx::shape::float_type, {3}};
std::vector<float> x_vec{1.0, 2.0, 3.0, 4.0, 5.0, 6.0};
std::vector<float> w_vec{1.0, 1.0, 1.0};
std::vector<float> b_vec{0.0, 0.0, 0.0};
migraphx::parameter_map pp;
pp["x"] = migraphx::argument(sx, x_vec.data());
pp["w"] = migraphx::argument(swb, w_vec.data());
pp["b"] = migraphx::argument(swb, b_vec.data());
auto result = p.eval(pp).back();
std::vector<float> result_vector(6);
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
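// Each row ({1,2,3} and {4,5,6}) has variance 2/3 about its mean, so
// (x - mean) / sqrt(2/3 + eps) gives {-1.22474, 0, 1.22474} per row.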
std::vector<float> gold{-1.22474f, 0.0f, 1.22474f, -1.22474f, 0.0f, 1.22474f};
EXPECT(migraphx::verify_range(result_vector, gold));
}
TEST_CASE(lessorequal_test)
{
migraphx::program p = migraphx::parse_onnx("lessorequal_test.onnx");