Commit 1974671d authored by turneram

Formatting

parent fe9a42f1
@@ -48,15 +48,15 @@ struct transposectx
         int num_heads = lens.at(1);
         int sequence_length = lens.at(2);
         int head_size = lens.back();
         const int NH = num_heads * head_size;
         const int NHS = NH * sequence_length;
-        //const int in_offset = s * head_size + n * sequence_length * head_size + b * NHS;
+        // const int in_offset = s * head_size + n * sequence_length * head_size + b * NHS;
         const int out_offset = n * head_size + s * NH + b * NHS;
         const int j = idx.back();
         output[out_offset + j] = input[i];
     });
 });
...
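The out_offset formula above implements the BxNxSxH → BxSxNxH permutation: H stays contiguous, and each (b, n, s) slice of H elements moves to position (b, s, n). A minimal host-side sketch of the same index math (plain C++, not MIGraphX code; the iota-filled buffer and loop structure are illustrative):

```cpp
#include <cassert>
#include <numeric>
#include <vector>

int main()
{
    const int B = 2, N = 2, S = 3, H = 4;
    std::vector<float> input(B * N * S * H);
    std::iota(input.begin(), input.end(), 0.0f);
    std::vector<float> output(input.size());

    const int NH  = N * H;  // output elements per (b, s) slice
    const int NHS = NH * S; // output elements per batch
    for(int b = 0; b < B; ++b)
        for(int n = 0; n < N; ++n)
            for(int s = 0; s < S; ++s)
                for(int j = 0; j < H; ++j)
                {
                    const int in_offset  = ((b * N + n) * S + s) * H;
                    const int out_offset = n * H + s * NH + b * NHS;
                    output[out_offset + j] = input[in_offset + j];
                }

    // Spot check: output(b=0, s=1, n=1, h=0) == input(b=0, n=1, s=1, h=0)
    assert(output[1 * NH + 1 * H] == input[(0 * N + 1) * S * H + 1 * H]);
    return 0;
}
```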
@@ -33,7 +33,7 @@ struct transposeqkv
     argument compute(const shape& output_shape, std::vector<argument> args) const
     {
         // Input: BxSxKxNxH
         // Output: KxBxNxSxH
         // K is the number of identical matrices
@@ -53,12 +53,13 @@ struct transposeqkv
         const int num_heads = lens[3];
         const int sequence_length = lens[1];
         const int batch_size = lens[0];
         const int H = lens.back();
         const int NH = num_heads * H;
         const int NHS = NH * sequence_length;
-        const int out_offset = s * H + n * sequence_length * H + b * NHS + m * NHS * batch_size;
+        const int out_offset =
+            s * H + n * sequence_length * H + b * NHS + m * NHS * batch_size;
         output[out_offset + j] = input[i];
     });
...
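The transposeqkv offset splits a fused BxSxKxNxH QKV tensor into K separate BxNxSxH matrices stacked along the leading axis. A standalone sketch of the same index math, using the test's dimensions (b=1, s=2, k=3, n=2, h=1); the loop structure is illustrative, not the operator's actual implementation:

```cpp
#include <iostream>
#include <numeric>
#include <vector>

int main()
{
    const int B = 1, S = 2, K = 3, N = 2, H = 1; // the test's dimensions
    std::vector<float> input(B * S * K * N * H);
    std::iota(input.begin(), input.end(), 0.0f);
    std::vector<float> output(input.size());

    const int NH  = N * H;
    const int NHS = NH * S;
    for(int b = 0; b < B; ++b)
        for(int s = 0; s < S; ++s)
            for(int m = 0; m < K; ++m) // m selects one of the K stacked matrices
                for(int n = 0; n < N; ++n)
                    for(int j = 0; j < H; ++j)
                    {
                        const int in_offset  = (((b * S + s) * K + m) * N + n) * H;
                        const int out_offset = s * H + n * S * H + b * NHS + m * NHS * B;
                        output[out_offset + j] = input[in_offset + j];
                    }

    // Same order as a transpose with permutation {2, 0, 3, 1, 4}:
    // 0, 6, 1, 7, 2, 8, 3, 9, 4, 10, 5, 11
    for(float x : output)
        std::cout << x << ", ";
    std::cout << "\n";
    return 0;
}
```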
@@ -79,14 +79,11 @@ struct parse_attention : op_parser<parse_attention>
         auto ones =
             info.add_literal(migraphx::literal{migraphx::shape{bias_type, ones_lens}, ones_vec});
         bias = info.add_instruction(migraphx::make_op("reshape", {{"dims", {n, 1}}}), bias);
-        auto gemm_1 = info.add_instruction(
-            migraphx::make_op("dot"),
-            bias,
-            ones);
+        auto gemm_1 = info.add_instruction(migraphx::make_op("dot"), bias, ones);
         gemm_1 =
             info.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), gemm_1);
         /// Use row-major => results(N, M) = 1 * input x weights + 1 x B
         auto input_sq = info.add_instruction(
             migraphx::make_op("reshape", {{"dims", {batch_size * sequence_length, hidden_size}}}),
             input);
@@ -99,8 +96,7 @@ struct parse_attention : op_parser<parse_attention>
             migraphx::make_op("reshape",
                               {{"dims", {batch_size, sequence_length, 3, num_heads, head_size}}}),
             add_gemms);
-        auto transqkv = info.add_instruction(
-            migraphx::make_op("transposeqkv"), add_gemms);
+        auto transqkv = info.add_instruction(migraphx::make_op("transposeqkv"), add_gemms);
         // transqkv has shape 3xBxNxSxH
         // => Q, K, V: each has size BxNxSxH
@@ -155,7 +151,7 @@ struct parse_attention : op_parser<parse_attention>
         // Inference mask is all 1s => masking can be skipped
         auto softmax = info.add_instruction(migraphx::make_op("softmax", {{"axis", 3}}), gemm3);
         // compute P*V
         auto gemm4 = info.add_instruction(migraphx::make_op("dot"), softmax, v_t);
         // result is BxNxSxH, transpose to BxSxNxH and reshape to BxSxN*H
...
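For reference, the subgraph built here is standard scaled dot-product attention: scores = Q*K^T, softmax over the last axis, then P*V. A tiny single-batch, single-head numeric sketch; the 1/sqrt(H) scaling is an assumption here, since this hunk does not show where the parser applies it:

```cpp
#include <cmath>
#include <iostream>
#include <vector>

int main()
{
    const int S = 2, H = 2; // sequence length, head size
    // Row-major SxH matrices; identity-like Q and K keep the numbers readable.
    std::vector<float> q{1, 0, 0, 1}, k{1, 0, 0, 1}, v{1, 2, 3, 4};

    std::vector<float> p(S * S); // attention probabilities, SxS
    for(int i = 0; i < S; ++i)
    {
        float row_sum = 0;
        for(int j = 0; j < S; ++j)
        {
            float dot = 0;
            for(int h = 0; h < H; ++h)
                dot += q[i * H + h] * k[j * H + h]; // (Q * K^T)(i, j)
            p[i * S + j] = std::exp(dot / std::sqrt(float(H)));
            row_sum += p[i * S + j];
        }
        for(int j = 0; j < S; ++j)
            p[i * S + j] /= row_sum; // softmax over the last axis
    }

    // out = P * V, an SxH result
    for(int i = 0; i < S; ++i)
        for(int h = 0; h < H; ++h)
        {
            float acc = 0;
            for(int j = 0; j < S; ++j)
                acc += p[i * S + j] * v[j * H + h];
            std::cout << acc << (h + 1 == H ? "\n" : " ");
        }
    return 0;
}
```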
@@ -40,7 +40,8 @@ struct transposectx_compiler : compiler<transposectx_compiler>
     operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
     {
         hip_compile_options options;
-        options.set_launch_params(v, compute_global_for(ctx, inputs.back().elements()), inputs.front().lens().back());
+        options.set_launch_params(
+            v, compute_global_for(ctx, inputs.back().elements()), inputs.front().lens().back());
         options.output = inputs.back();
         options.inputs = inputs;
         options.kernel_name = "transposectx_kernel";
@@ -78,7 +79,8 @@ struct transposeqkv_compiler : compiler<transposeqkv_compiler>
     operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
     {
         hip_compile_options options;
-        options.set_launch_params(v, compute_global_for(ctx, inputs.back().elements()), inputs.front().lens().back());
+        options.set_launch_params(
+            v, compute_global_for(ctx, inputs.back().elements()), inputs.front().lens().back());
         options.output = inputs.back();
         options.inputs = inputs;
         options.kernel_name = "transposeqkv_kernel";
...
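Both compilers launch one thread per output element, taking the workgroup (local) size from the innermost dimension H of the first input. A generic sketch of the usual round-up, under the assumption that compute_global_for rounds the element count up to a multiple of the local size (its definition is not part of this diff):

```cpp
#include <cstddef>
#include <iostream>

// Smallest multiple of `local` that covers `elements`: one thread per element,
// padded so the grid divides evenly into workgroups.
std::size_t global_for(std::size_t elements, std::size_t local)
{
    return ((elements + local - 1) / local) * local;
}

int main()
{
    // e.g. a 2x2x3x4 output (48 elements) with local size H = 4
    std::cout << global_for(48, 4) << "\n"; // 48
    std::cout << global_for(50, 4) << "\n"; // 52
    return 0;
}
```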
@@ -11,12 +11,12 @@ __device__ void transposectx(const T& input_t, const U& output_t)
 {
     // Input: BxNxSxH
     // Output: BxSxNxH
     auto index = make_index();
     auto input_shape = input_t.get_shape();
     auto lens = input_shape.lens;
     const int num_heads = lens[1];
     const int sequence_length = lens[2];
     int head_size = lens[3];
     auto idx = input_shape.multi(index.global);
@@ -24,11 +24,11 @@ __device__ void transposectx(const T& input_t, const U& output_t)
     int s = idx[2];
     int b = idx[0];
     const int NH = num_heads * head_size;
     const int NHS = NH * sequence_length;
     const int out_offset = n * head_size + s * NH + b * NHS;
-    if (index.local < 1024)
+    if(index.local < 1024)
         output_t[out_offset + idx[3]] = input_t[index.global];
 }
...
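The kernel guards with index.local < 1024, presumably against the usual HIP per-workgroup thread limit. The more common pattern in flat-indexed kernels is to guard the global index against the element count, since the rounded-up global size can exceed it; a generic sketch of that pattern (not the MIGraphX kernel):

```cpp
#include <cstddef>
#include <iostream>
#include <vector>

// Simulate a launch of `global` threads over `nelements` items.
void run(std::size_t global, std::size_t nelements, const float* in, float* out)
{
    for(std::size_t id = 0; id < global; ++id)
        if(id < nelements) // the guard: padding threads do nothing
            out[id] = in[id];
}

int main()
{
    std::vector<float> in{1, 2, 3}, out(3, 0);
    run(4 /* rounded-up global size */, in.size(), in.data(), out.data());
    std::cout << out[0] << out[1] << out[2] << "\n"; // prints 123
    return 0;
}
```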
@@ -13,9 +13,9 @@ __device__ void transposeqkv(const T& input_t, const U& output_t)
     // Output: KxBxNxSxH
     // K is the number of identical matrices
     auto index = make_index();
     auto input_shape = input_t.get_shape();
     auto lens = input_shape.lens;
     auto idx = input_shape.multi(index.global);
@@ -23,14 +23,14 @@ __device__ void transposeqkv(const T& input_t, const U& output_t)
     const int s = idx[1];
     const int m = idx[2];
     const int n = idx[3];
-    //const int j = idx[4];
+    // const int j = idx[4];
     const int num_heads = lens[3];
     const int sequence_length = lens[1];
     const int batch_size = lens[0];
     const int H = lens[4];
     const int NH = num_heads * H;
     const int NHS = NH * sequence_length;
     const int out_offset = s * H + n * sequence_length * H + b * NHS + m * NHS * batch_size;
...
@@ -671,12 +671,12 @@ TEST_CASE(bert_transpose_ops_test)
 {
     // transposeQKV
     migraphx::program p;
     auto* mm = p.get_main_module();
     const int k = 3, b = 1, n = 2, s = 2, h = 1;
     migraphx::shape sh{migraphx::shape::float_type, {b, s, k, n, h}};
     std::vector<float> data(b * s * k * n * h);
     std::iota(data.begin(), data.end(), 0);
     auto l1 = mm->add_literal(migraphx::literal{sh, data});
     mm->add_instruction(migraphx::make_op("transposeqkv"), l1);
     p.compile(migraphx::ref::target{});
     auto result = p.eval({}).back();
@@ -686,14 +686,15 @@ TEST_CASE(bert_transpose_ops_test)
     migraphx::program p2;
     auto* mm2 = p2.get_main_module();
     auto l2 = mm2->add_literal(migraphx::literal{sh, data});
-    mm2->add_instruction(migraphx::make_op("transpose", {{"permutation", {2, 0, 3, 1, 4}}}), l2);
+    mm2->add_instruction(migraphx::make_op("transpose", {{"permutation", {2, 0, 3, 1, 4}}}),
+                         l2);
     p2.compile(migraphx::ref::target{});
     auto result2 = p2.eval({}).back();
     std::vector<float> result_vector2(k * b * n * s * h);
     result2.visit([&](auto output) { result_vector2.assign(output.begin(), output.end()); });
-    for (auto& i : result_vector2)
+    for(auto& i : result_vector2)
         std::cout << i << ", ";
     std::cout << std::endl;
@@ -706,7 +707,7 @@ TEST_CASE(bert_transpose_ops_test)
     migraphx::shape s{migraphx::shape::float_type, {2, 2, 2, 2}};
     std::vector<float> data(2 * 2 * 2 * 2);
     std::iota(data.begin(), data.end(), 0);
     auto l1 = mm->add_literal(migraphx::literal{s, data});
     mm->add_instruction(migraphx::make_op("transposectx"), l1);
     p.compile(migraphx::ref::target{});
     auto result = p.eval({}).back();
@@ -719,18 +720,20 @@ TEST_CASE(bert_transpose_ops_test)
 {
     // transposeCtx
     migraphx::program p;
     auto* mm = p.get_main_module();
     const int b = 2, n = 2, s = 3, h = 4;
     migraphx::shape sh{migraphx::shape::float_type, {b, n, s, h}};
     std::vector<float> data(b * n * s * h);
     std::iota(data.begin(), data.end(), 0);
     auto l1 = mm->add_literal(migraphx::literal{sh, data});
     mm->add_instruction(migraphx::make_op("transposectx"), l1);
     p.compile(migraphx::ref::target{});
     auto result = p.eval({}).back();
     std::vector<float> result_vector(b * n * s * h);
     result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
-    std::vector<float> gold{0, 1, 2, 3, 12, 13, 14, 15, 4, 5, 6, 7, 16, 17, 18, 19, 8, 9, 10, 11, 20, 21, 22, 23, 24, 25, 26, 27, 36, 37, 38, 39, 28, 29, 30, 31, 40, 41, 42, 43, 32, 33, 34, 35, 44, 45, 46, 47};
+    std::vector<float> gold{0,  1,  2,  3,  12, 13, 14, 15, 4,  5,  6,  7,  16, 17, 18, 19,
+                            8,  9,  10, 11, 20, 21, 22, 23, 24, 25, 26, 27, 36, 37, 38, 39,
+                            28, 29, 30, 31, 40, 41, 42, 43, 32, 33, 34, 35, 44, 45, 46, 47};
     EXPECT(migraphx::verify_range(result_vector, gold));
 }
...
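The gold vector in the last case can be reproduced independently: transposectx is a {0, 2, 1, 3} permutation of the BxNxSxH input. A short reference loop (plain C++) that prints the expected sequence:

```cpp
#include <iostream>
#include <numeric>
#include <vector>

int main()
{
    const int b = 2, n = 2, s = 3, h = 4;
    std::vector<float> data(b * n * s * h);
    std::iota(data.begin(), data.end(), 0.0f);

    std::vector<float> gold(data.size());
    for(int bi = 0; bi < b; ++bi)
        for(int si = 0; si < s; ++si)
            for(int ni = 0; ni < n; ++ni)
                for(int hi = 0; hi < h; ++hi)
                    // output is BxSxNxH; read from the BxNxSxH input
                    gold[((bi * s + si) * n + ni) * h + hi] =
                        data[((bi * n + ni) * s + si) * h + hi];

    for(float v : gold)
        std::cout << v << ", "; // prints 0, 1, 2, 3, 12, 13, 14, 15, 4, ...
    std::cout << "\n";
    return 0;
}
```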