"docs/en_US/networkmorphismTuner.md" did not exist on "efa479b0a91ff4c2b8a0f4cbb58c4f48a3747d47"
Commit 1974671d authored by turneram

Formatting

parent fe9a42f1
@@ -52,7 +52,7 @@ struct transposectx
     const int NH = num_heads * head_size;
     const int NHS = NH * sequence_length;
-    //const int in_offset = s * head_size + n * sequence_length * head_size + b * NHS;
+    // const int in_offset = s * head_size + n * sequence_length * head_size + b * NHS;
     const int out_offset = n * head_size + s * NH + b * NHS;
...
@@ -58,7 +58,8 @@ struct transposeqkv
     const int NH = num_heads * H;
     const int NHS = NH * sequence_length;
-    const int out_offset = s * H + n * sequence_length * H + b * NHS + m * NHS * batch_size;
+    const int out_offset =
+        s * H + n * sequence_length * H + b * NHS + m * NHS * batch_size;
     output[out_offset + j] = input[i];
 });
...
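A minimal standalone sketch (not part of the commit, toy sizes assumed) checking that the transposeqkv out_offset formula above is the row-major index of (m, b, n, s) in a 3 x B x N x S x H output, i.e. the (2, 0, 3, 1, 4) permutation of a B x S x 3 x N x H input, as exercised by the test case further down:

#include <cassert>

int main()
{
    const int batch_size = 2, sequence_length = 3, num_heads = 2, H = 4; // assumed toy sizes
    const int NH  = num_heads * H;
    const int NHS = NH * sequence_length;
    for(int b = 0; b < batch_size; ++b)
        for(int s = 0; s < sequence_length; ++s)
            for(int m = 0; m < 3; ++m)
                for(int n = 0; n < num_heads; ++n)
                {
                    // same formula as in the kernel above
                    const int out_offset =
                        s * H + n * sequence_length * H + b * NHS + m * NHS * batch_size;
                    // row-major index of (m, b, n, s, 0) in a 3 x B x N x S x H tensor
                    const int expected =
                        ((m * batch_size + b) * num_heads + n) * sequence_length * H + s * H;
                    assert(out_offset == expected);
                }
    return 0;
}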
@@ -79,10 +79,7 @@ struct parse_attention : op_parser<parse_attention>
     auto ones =
         info.add_literal(migraphx::literal{migraphx::shape{bias_type, ones_lens}, ones_vec});
     bias = info.add_instruction(migraphx::make_op("reshape", {{"dims", {n, 1}}}), bias);
-    auto gemm_1 = info.add_instruction(
-        migraphx::make_op("dot"),
-        bias,
-        ones);
+    auto gemm_1 = info.add_instruction(migraphx::make_op("dot"), bias, ones);
     gemm_1 =
         info.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), gemm_1);
@@ -99,8 +96,7 @@ struct parse_attention : op_parser<parse_attention>
         migraphx::make_op("reshape",
                           {{"dims", {batch_size, sequence_length, 3, num_heads, head_size}}}),
         add_gemms);
-    auto transqkv = info.add_instruction(
-        migraphx::make_op("transposeqkv"), add_gemms);
+    auto transqkv = info.add_instruction(migraphx::make_op("transposeqkv"), add_gemms);
     // transqkv has shape 3xBxNxSxH
     // => Q, K, V: each has size BxNxSxH
...
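As the comment notes, transqkv is laid out as 3 x B x N x S x H, so Q, K and V are three contiguous B*N*S*H blocks along the leading dimension. A hypothetical sketch, illustrative only and not the parser's actual slicing code:

#include <cstddef>
#include <vector>

int main()
{
    // toy sizes, assumed for illustration
    const std::size_t batch_size = 2, num_heads = 2, sequence_length = 3, head_size = 4;
    const std::size_t chunk = batch_size * num_heads * sequence_length * head_size;

    std::vector<float> transqkv(3 * chunk); // packed as 3 x B x N x S x H

    const float* q = transqkv.data();             // Q: first  B x N x S x H block
    const float* k = transqkv.data() + chunk;     // K: second B x N x S x H block
    const float* v = transqkv.data() + 2 * chunk; // V: third  B x N x S x H block
    (void)q;
    (void)k;
    (void)v;
    return 0;
}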
...@@ -40,7 +40,8 @@ struct transposectx_compiler : compiler<transposectx_compiler> ...@@ -40,7 +40,8 @@ struct transposectx_compiler : compiler<transposectx_compiler>
operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
{ {
hip_compile_options options; hip_compile_options options;
options.set_launch_params(v, compute_global_for(ctx, inputs.back().elements()), inputs.front().lens().back()); options.set_launch_params(
v, compute_global_for(ctx, inputs.back().elements()), inputs.front().lens().back());
options.output = inputs.back(); options.output = inputs.back();
options.inputs = inputs; options.inputs = inputs;
options.kernel_name = "transposectx_kernel"; options.kernel_name = "transposectx_kernel";
@@ -78,7 +79,8 @@ struct transposeqkv_compiler : compiler<transposeqkv_compiler>
     operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
     {
         hip_compile_options options;
-        options.set_launch_params(v, compute_global_for(ctx, inputs.back().elements()), inputs.front().lens().back());
+        options.set_launch_params(
+            v, compute_global_for(ctx, inputs.back().elements()), inputs.front().lens().back());
         options.output = inputs.back();
         options.inputs = inputs;
         options.kernel_name = "transposeqkv_kernel";
...
@@ -28,7 +28,7 @@ __device__ void transposectx(const T& input_t, const U& output_t)
     const int NHS = NH * sequence_length;
     const int out_offset = n * head_size + s * NH + b * NHS;
-    if (index.local < 1024)
+    if(index.local < 1024)
         output_t[out_offset + idx[3]] = input_t[index.global];
 }
...
@@ -23,7 +23,7 @@ __device__ void transposeqkv(const T& input_t, const U& output_t)
     const int s = idx[1];
     const int m = idx[2];
     const int n = idx[3];
-    //const int j = idx[4];
+    // const int j = idx[4];
     const int num_heads = lens[3];
     const int sequence_length = lens[1];
...
@@ -687,13 +687,14 @@ TEST_CASE(bert_transpose_ops_test)
     migraphx::program p2;
     auto* mm2 = p2.get_main_module();
     auto l2 = mm2->add_literal(migraphx::literal{sh, data});
-    mm2->add_instruction(migraphx::make_op("transpose", {{"permutation", {2, 0, 3, 1, 4}}}), l2);
+    mm2->add_instruction(migraphx::make_op("transpose", {{"permutation", {2, 0, 3, 1, 4}}}),
+                         l2);
     p2.compile(migraphx::ref::target{});
     auto result2 = p2.eval({}).back();
     std::vector<float> result_vector2(k * b * n * s * h);
     result2.visit([&](auto output) { result_vector2.assign(output.begin(), output.end()); });
-    for (auto& i : result_vector2)
+    for(auto& i : result_vector2)
         std::cout << i << ", ";
     std::cout << std::endl;
@@ -730,7 +731,9 @@ TEST_CASE(bert_transpose_ops_test)
     auto result = p.eval({}).back();
     std::vector<float> result_vector(b * n * s * h);
     result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
-    std::vector<float> gold{0, 1, 2, 3, 12, 13, 14, 15, 4, 5, 6, 7, 16, 17, 18, 19, 8, 9, 10, 11, 20, 21, 22, 23, 24, 25, 26, 27, 36, 37, 38, 39, 28, 29, 30, 31, 40, 41, 42, 43, 32, 33, 34, 35, 44, 45, 46, 47};
+    std::vector<float> gold{0,  1,  2,  3,  12, 13, 14, 15, 4,  5,  6,  7,  16, 17, 18, 19,
+                            8,  9,  10, 11, 20, 21, 22, 23, 24, 25, 26, 27, 36, 37, 38, 39,
+                            28, 29, 30, 31, 40, 41, 42, 43, 32, 33, 34, 35, 44, 45, 46, 47};
     EXPECT(migraphx::verify_range(result_vector, gold));
 }
...
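The gold vector above is just the B x N x S x H iota input re-laid out as B x S x N x H. A minimal standalone sketch that reproduces it with the same out_offset formula as the transposectx kernel (dimensions b=2, n=2, s=3, h=4 and the 0..47 iota input are inferred from the 48 gold values; they are not visible in this diff):

#include <cstddef>
#include <iostream>
#include <vector>

int main()
{
    const int b_dim = 2, num_heads = 2, sequence_length = 3, head_size = 4; // inferred sizes
    const int NH  = num_heads * head_size;
    const int NHS = NH * sequence_length;

    // iota input in B x N x S x H order
    std::vector<float> input(b_dim * NHS);
    for(std::size_t i = 0; i < input.size(); ++i)
        input[i] = static_cast<float>(i);

    // re-lay out as B x S x N x H using the kernel's out_offset formula
    std::vector<float> output(input.size());
    int in_index = 0;
    for(int b = 0; b < b_dim; ++b)
        for(int n = 0; n < num_heads; ++n)
            for(int s = 0; s < sequence_length; ++s)
                for(int h = 0; h < head_size; ++h)
                {
                    const int out_offset = n * head_size + s * NH + b * NHS;
                    output[out_offset + h] = input[in_index++];
                }

    for(float x : output)
        std::cout << x << ", "; // prints 0, 1, 2, 3, 12, 13, ... matching the gold vector
    std::cout << std::endl;
    return 0;
}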