Commit fe9a42f1 authored by turneram

Fix transpose kernels

parent 96663815
@@ -20,15 +20,6 @@ namespace op {
 struct transposectx
 {
-    int head_size = 64;
-    bool reversed_bs = false;
-    template <class Self, class F>
-    static auto reflect(Self& self, F f)
-    {
-        return pack(f(self.head_size, "head_size"), f(self.reversed_bs, "reversed_bs"));
-    }
     std::string name() const { return "transposectx"; }
     shape compute_shape(std::vector<shape> inputs) const
@@ -45,10 +36,28 @@ struct transposectx
         // Input: BxNxSxH
         // Output: BxSxNxH
         argument result{output_shape};
+        auto in_s = args.front().get_shape();
+        auto lens = in_s.lens();
         visit_all(result, args.front())([&](auto output, const auto input) {
             par_for(output_shape.elements(), [&](auto i) {
-                // TODO: calculate in_offet and out_offset
-                output[i] = input[i];
+                auto idx = in_s.multi(i);
+                int n = idx.at(1);
+                int s = idx.at(2);
+                int b = idx.front();
+                int num_heads = lens.at(1);
+                int sequence_length = lens.at(2);
+                int head_size = lens.back();
+                const int NH = num_heads * head_size;
+                const int NHS = NH * sequence_length;
+                //const int in_offset = s * head_size + n * sequence_length * head_size + b * NHS;
+                const int out_offset = n * head_size + s * NH + b * NHS;
+                const int j = idx.back();
+                output[out_offset + j] = input[i];
             });
         });
...
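Editor's note: the new reference implementation above maps the input element at multi-index (b, n, s, h) of a BxNxSxH tensor to offset n*H + s*N*H + b*N*S*H in the BxSxNxH output. A minimal standalone sketch of the same arithmetic (plain C++, not part of this commit; the dimensions are borrowed from the unit test added later in the diff):

// Checks the BxNxSxH -> BxSxNxH offset arithmetic used by the reference transposectx.
#include <cassert>
#include <cstddef>
#include <numeric>
#include <vector>

int main()
{
    const int B = 2, N = 2, S = 3, H = 4;
    std::vector<float> in(B * N * S * H);
    std::iota(in.begin(), in.end(), 0.0f);
    std::vector<float> out(in.size());
    for(std::size_t i = 0; i < in.size(); ++i)
    {
        // recover (b, n, s, h) from the row-major input index
        int h = i % H;
        int s = (i / H) % S;
        int n = (i / (H * S)) % N;
        int b = i / (H * S * N);
        const int NH  = N * H;
        const int NHS = NH * S;
        const int out_offset = n * H + s * NH + b * NHS;
        out[out_offset + h] = in[i];
    }
    // spot-check: out[b=0][s=1][n=0][h=2] == in[b=0][n=0][s=1][h=2]
    assert(out[1 * N * H + 2] == in[1 * H + 2]);
    return 0;
}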
@@ -20,15 +20,6 @@ namespace op {
 struct transposeqkv
 {
-    int head_size = 64;
-    bool reversed_bs = false;
-    template <class Self, class F>
-    static auto reflect(Self& self, F f)
-    {
-        return pack(f(self.head_size, "head_size"), f(self.reversed_bs, "reversed_bs"));
-    }
     std::string name() const { return "transposeqkv"; }
     shape compute_shape(std::vector<shape> inputs) const
@@ -42,14 +33,34 @@ struct transposeqkv
     argument compute(const shape& output_shape, std::vector<argument> args) const
     {
-        // Input: BxSxKxNxH or SxBxKxNxH
+        // Input: BxSxKxNxH
         // Output: KxBxNxSxH
         // K is the number of identical matrix
+        auto in_s = args.front().get_shape();
+        auto lens = in_s.lens();
         argument result{output_shape};
         visit_all(result, args.front())([&](auto output, const auto input) {
             par_for(output_shape.elements(), [&](auto i) {
-                // TODO: calculate in_offet and out_offset
-                output[i] = input[i];
+                auto idx = in_s.multi(i);
+                const int b = idx.front();
+                const int s = idx.at(1);
+                const int m = idx.at(2);
+                const int n = idx.at(3);
+                const int j = idx.back();
+                const int num_heads = lens[3];
+                const int sequence_length = lens[1];
+                const int batch_size = lens[0];
+                const int H = lens.back();
+                const int NH = num_heads * H;
+                const int NHS = NH * sequence_length;
+                const int out_offset = s * H + n * sequence_length * H + b * NHS + m * NHS * batch_size;
+                output[out_offset + j] = input[i];
             });
         });
...
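Editor's note: likewise, transposeqkv sends the input element at (b, s, m, n, h) of a BxSxKxNxH tensor to offset s*H + n*S*H + b*N*S*H + m*B*N*S*H in the KxBxNxSxH output, which is the same permutation as transpose with {2, 0, 3, 1, 4}. A standalone sketch of the arithmetic (plain C++, not part of this commit; dimensions match the transposeqkv unit test added later in the diff):

// Checks the BxSxKxNxH -> KxBxNxSxH offset arithmetic used by the reference transposeqkv.
#include <cassert>
#include <cstddef>
#include <numeric>
#include <vector>

int main()
{
    const int B = 1, S = 2, K = 3, N = 2, H = 1;
    std::vector<float> in(B * S * K * N * H);
    std::iota(in.begin(), in.end(), 0.0f);
    std::vector<float> out(in.size());
    for(std::size_t i = 0; i < in.size(); ++i)
    {
        // recover (b, s, m, n, h) from the row-major input index
        int h = i % H;
        int n = (i / H) % N;
        int m = (i / (H * N)) % K;
        int s = (i / (H * N * K)) % S;
        int b = i / (H * N * K * S);
        const int NH  = N * H;
        const int NHS = NH * S;
        const int out_offset = s * H + n * S * H + b * NHS + m * NHS * B;
        out[out_offset + h] = in[i];
    }
    // spot-check: out(m=1,b=0,n=0,s=1,h=0) == in(b=0,s=1,m=1,n=0,h=0)
    assert(out[1 * (B * N * S * H) + 1 * H] == in[1 * (K * N * H) + 1 * (N * H)]);
    return 0;
}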
@@ -82,31 +82,28 @@ struct parse_attention : op_parser<parse_attention>
         auto gemm_1 = info.add_instruction(
             migraphx::make_op("dot"),
             bias,
-            ones /* info.make_contiguous(mb_bias), info.make_contiguous(ones) */);
+            ones);
         gemm_1 =
             info.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), gemm_1);
-        /// ORT: Gemm, note that ROCM assumes col-major, so result(N, M) = 1 * weights x input + 1 x
-        /// B. Assume row-major => results(N, M) = 1 * input x weights + 1 x B ?
+        /// Use row-major => results(N, M) = 1 * input x weights + 1 x B
         auto input_sq = info.add_instruction(
             migraphx::make_op("reshape", {{"dims", {batch_size * sequence_length, hidden_size}}}),
             input);
         auto gemm_2 = info.add_instruction(migraphx::make_op("dot"), input_sq, weights);
         auto add_gemms = info.add_instruction(migraphx::make_op("add"), gemm_1, gemm_2);
-        // LaunchAttentionKernel:
-        // LaunchTransQkv
+        // LaunchTransQkv
         // input should be BxSx3xNxH => scratch3: 3xBxNxSxH
         add_gemms = info.add_instruction(
             migraphx::make_op("reshape",
                               {{"dims", {batch_size, sequence_length, 3, num_heads, head_size}}}),
             add_gemms);
-        std::vector<std::size_t> qkv_perm{2, 0, 3, 1, 4};
         auto transqkv = info.add_instruction(
-            migraphx::make_op("transposeqkv", {{"head_size", head_size}}), add_gemms);
-        // now scratch3 has Q, K, V: each has size BxNxSxH
-        // => transqkv has shape 3xBxNxSxH
+            migraphx::make_op("transposeqkv"), add_gemms);
+        // transqkv has shape 3xBxNxSxH
+        // => Q, K, V: each has size BxNxSxH
         auto batches = batch_size * num_heads;
         auto size_per_batch = sequence_length * head_size;
         auto total_size = batches * size_per_batch;
@@ -158,12 +155,11 @@ struct parse_attention : op_parser<parse_attention>
         // Inference mask is all 1s => masking can be skipped
         auto softmax = info.add_instruction(migraphx::make_op("softmax", {{"axis", 3}}), gemm3);
-        // compute P*V (as V*P), and store in scratch3: BxNxSxH
+        // compute P*V
         auto gemm4 = info.add_instruction(migraphx::make_op("dot"), softmax, v_t);
-        // scratch3 is BxNxSxH, transpose to output BxSxNxH
-        gemm4 = info.add_instruction(migraphx::make_op("transposectx", {{"head_size", head_size}}),
-                                     gemm4);
+        // result is BxNxSxH, transpose to BxSxNxH and reshape to BxSxN*H
+        gemm4 = info.add_instruction(migraphx::make_op("transposectx"), gemm4);
         gemm4 = info.add_instruction(
             make_op("reshape", {{"dims", {batch_size, sequence_length, num_heads * head_size}}}),
             info.make_contiguous(gemm4));
...
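Editor's note: a hedged sketch of the shape bookkeeping through this parser (plain C++, not part of this commit). The batch and sequence values are assumptions taken from the old verify test ({1, 12, 128, 64}); the weights shape is inferred from the reshape to BxSx3xNxH.

#include <cstdio>

int main()
{
    const int batch_size = 1, sequence_length = 128, num_heads = 12, head_size = 64;
    const int hidden_size = num_heads * head_size;            // 768
    // gemm_2: (B*S, hidden) x weights -> (B*S, 3*hidden), weights assumed (hidden, 3*hidden)
    const int gemm_rows = batch_size * sequence_length;       // 128
    const int gemm_cols = 3 * hidden_size;                    // 2304
    // reshape to BxSx3xNxH, then transposeqkv -> 3xBxNxSxH
    const int batches        = batch_size * num_heads;        // 12
    const int size_per_batch = sequence_length * head_size;   // 8192
    const int total_size     = batches * size_per_batch;      // 98304 elements per Q/K/V
    std::printf("gemm: %dx%d, batches=%d, size_per_batch=%d, total=%d\n",
                gemm_rows, gemm_cols, batches, size_per_batch, total_size);
    return 0;
}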
@@ -19,17 +19,13 @@ namespace gpu {
 static const char* const transposectx_kernel = R"__migraphx__(
 #include <migraphx/kernels/index.hpp>
 #include <migraphx/kernels/transposectx.hpp>
-#include <migraphx/kernels/basic_ops.hpp>
-#include <migraphx/kernels/integral_constant.hpp>
-#include <migraphx/kernels/generic_constant.hpp>
 #include <args.hpp>
 namespace migraphx {
 extern "C" {
 __global__ void transposectx_kernel(void* input_p, void* output_p)
 {
     make_tensors()(input_p, output_p)([](auto input, auto output) {
-        auto settings = make_transposectx_settings(MIGRAPHX_MAKE_CONSTANT(int64_t{HEAD_SIZE}), MIGRAPHX_MAKE_CONSTANT(bool{REVERSED_BS}));
-        transposectx(input, output, settings);
+        transposectx(input, output);
     });
 }
@@ -44,21 +40,11 @@ struct transposectx_compiler : compiler<transposectx_compiler>
     operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
     {
         hip_compile_options options;
-        options.set_launch_params(v, compute_global_for(ctx, inputs.back().elements()), 64);
+        options.set_launch_params(v, compute_global_for(ctx, inputs.back().elements()), inputs.front().lens().back());
         options.output = inputs.back();
        options.inputs = inputs;
         options.kernel_name = "transposectx_kernel";
-        // head_size
-        assert(v.contains("head_size"));
-        auto head_size = v.at("head_size").to<int>();
-        options.params += " -DHEAD_SIZE=" + std::to_string(head_size);
-        // reversed_bs
-        assert(v.contains("reversed_bs"));
-        auto reversed_bs = v.at("reversed_bs").to<bool>();
-        options.params += " -DREVERSED_BS=" + std::to_string(reversed_bs);
         return compile_hip_code_object(transposectx_kernel, options);
     }
@@ -71,17 +57,13 @@ struct transposectx_compiler : compiler<transposectx_compiler>
 static const char* const transposeqkv_kernel = R"__migraphx__(
 #include <migraphx/kernels/index.hpp>
 #include <migraphx/kernels/transposeqkv.hpp>
-#include <migraphx/kernels/basic_ops.hpp>
-#include <migraphx/kernels/integral_constant.hpp>
-#include <migraphx/kernels/generic_constant.hpp>
 #include <args.hpp>
 namespace migraphx {
 extern "C" {
 __global__ void transposeqkv_kernel(void* input_p, void* output_p)
 {
     make_tensors()(input_p, output_p)([](auto input, auto output) {
-        auto settings = make_transposeqkv_settings(MIGRAPHX_MAKE_CONSTANT(int64_t{HEAD_SIZE}), MIGRAPHX_MAKE_CONSTANT(bool{REVERSED_BS}));
-        transposeqkv(input, output, settings);
+        transposeqkv(input, output);
     });
 }
@@ -96,21 +78,11 @@ struct transposeqkv_compiler : compiler<transposeqkv_compiler>
     operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
     {
         hip_compile_options options;
-        options.set_launch_params(v, compute_global_for(ctx, inputs.back().elements()), 64);
+        options.set_launch_params(v, compute_global_for(ctx, inputs.back().elements()), inputs.front().lens().back());
         options.output = inputs.back();
         options.inputs = inputs;
         options.kernel_name = "transposeqkv_kernel";
-        // head_size
-        assert(v.contains("head_size"));
-        auto head_size = v.at("head_size").to<int>();
-        options.params += " -DHEAD_SIZE=" + std::to_string(head_size);
-        // reversed_bs
-        assert(v.contains("reversed_bs"));
-        auto reversed_bs = v.at("reversed_bs").to<bool>();
-        options.params += " -DREVERSED_BS=" + std::to_string(reversed_bs);
         return compile_hip_code_object(transposeqkv_kernel, options);
     }
...
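Editor's note: the compilers above now derive the thread block size from the innermost (head) dimension of the input, inputs.front().lens().back(), instead of a hard-coded 64. A small arithmetic check (plain C++, not MIGraphX; the shape is taken from the test_transposectx verify program later in this diff):

#include <cstdio>

int main()
{
    const int lens[4] = {16, 12, 384, 64}; // BxNxSxH from test_transposectx
    const int elements   = lens[0] * lens[1] * lens[2] * lens[3]; // 4718592 work items
    const int block_size = lens[3];                               // lens().back() == 64
    std::printf("elements=%d block=%d\n", elements, block_size);
    return 0;
}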
@@ -7,55 +7,29 @@
 namespace migraphx {
 template <class T, class U>
-struct transposectx_settings
-{
-    T head_size{};
-    U reversed_bs{};
-};
-template <class... Ts>
-constexpr transposectx_settings<Ts...> make_transposectx_settings(Ts... xs)
-{
-    return {xs...};
-}
-template <class T, class U, class Settings>
-__device__ void transposectx(const T& input_t, const U& output_t, Settings st)
+__device__ void transposectx(const T& input_t, const U& output_t)
 {
     // Input: BxNxSxH
     // Output: BxSxNxH
-    auto head_size = st.head_size;
-    auto reversed_bs = st.reversed_bs;
-    int n = threadIdx.y;
-    int s = blockIdx.x;
-    int b = blockIdx.y;
-    int num_heads = blockDim.y;
-    int sequence_length = gridDim.x;
+    auto index = make_index();
+    auto input_shape = input_t.get_shape();
+    auto lens = input_shape.lens;
+    const int num_heads = lens[1];
+    const int sequence_length = lens[2];
+    int head_size = lens[3];
+    auto idx = input_shape.multi(index.global);
+    int n = idx[1];
+    int s = idx[2];
+    int b = idx[0];
     const int NH = num_heads * head_size;
     const int NHS = NH * sequence_length;
-    const int in_offset = s * head_size + n * sequence_length * head_size + b * NHS;
-    int out_offset = 0;
-    if(reversed_bs)
-    {
-        const int batch_size = gridDim.y;
-        const int BNH = NH * batch_size;
-        out_offset = n * head_size + b * NH + s * BNH;
-    }
-    else
-    {
-        out_offset = n * head_size + s * NH + b * NHS;
-    }
-    const int i = threadIdx.x;
-    if(i < head_size)
-    {
-        output_t[out_offset + i] = input_t[in_offset + i];
-    }
+    const int out_offset = n * head_size + s * NH + b * NHS;
+    if (index.local < 1024)
+        output_t[out_offset + idx[3]] = input_t[index.global];
 }
 } // namespace migraphx
...
@@ -7,57 +7,36 @@
 namespace migraphx {
 template <class T, class U>
-struct transposeqkv_settings
-{
-    T head_size{};
-    U reversed_bs{};
-};
-template <class... Ts>
-constexpr transposeqkv_settings<Ts...> make_transposeqkv_settings(Ts... xs)
-{
-    return {xs...};
-}
-template <class T, class U, class Settings>
-__device__ void transposeqkv(const T& input_t, const U& output_t, Settings st)
+__device__ void transposeqkv(const T& input_t, const U& output_t)
 {
     // Input: BxSxKxNxH or SxBxKxNxH
     // Output: KxBxNxSxH
     // K is the number of identical matrix
-    auto H = st.head_size;
-    auto reversed_bs = st.reversed_bs;
-    int n = threadIdx.y;
-    int s = blockIdx.x;
-    int b = blockIdx.y;
-    int m = blockIdx.z; // matrix id
-    const int num_heads = blockDim.y;
-    const int sequence_length = gridDim.x;
-    const int batch_size = gridDim.y;
-    const int chunk_num = gridDim.z;
-    const int NH = num_heads * H;
-    const int NHS = NH * sequence_length;
-    int in_offset = 0;
-    if(reversed_bs)
-    {
-        const int BNH = NH * batch_size;
-        in_offset = n * H + (m + b * chunk_num) * NH + s * BNH * chunk_num;
-    }
-    else
-    {
-        in_offset = n * H + (m + s * chunk_num) * NH + b * NHS * chunk_num;
-    }
+    auto index = make_index();
+    auto input_shape = input_t.get_shape();
+    auto lens = input_shape.lens;
+    auto idx = input_shape.multi(index.global);
+    const int b = idx[0];
+    const int s = idx[1];
+    const int m = idx[2];
+    const int n = idx[3];
+    //const int j = idx[4];
+    const int num_heads = lens[3];
+    const int sequence_length = lens[1];
+    const int batch_size = lens[0];
+    const int H = lens[4];
+    const int NH = num_heads * H;
+    const int NHS = NH * sequence_length;
     const int out_offset = s * H + n * sequence_length * H + b * NHS + m * NHS * batch_size;
-    const int i = threadIdx.x;
-    if(i < H)
+    if(index.global < input_shape.elements())
     {
-        output_t[out_offset + i] = input_t[in_offset + i];
+        output_t[out_offset + idx[4]] = input_t[index.global];
     }
 }
...
@@ -666,6 +666,76 @@ TEST_CASE(batch_norm_inference_test)
     EXPECT(migraphx::verify_range(result_vector, gold));
 }
+TEST_CASE(bert_transpose_ops_test)
+{
+    {
+        // transposeQKV
+        migraphx::program p;
+        auto* mm = p.get_main_module();
+        const int k = 3, b = 1, n = 2, s = 2, h = 1;
+        migraphx::shape sh{migraphx::shape::float_type, {b, s, k, n, h}};
+        std::vector<float> data(b * s * k * n * h);
+        std::iota(data.begin(), data.end(), 0);
+        auto l1 = mm->add_literal(migraphx::literal{sh, data});
+        mm->add_instruction(migraphx::make_op("transposeqkv"), l1);
+        p.compile(migraphx::ref::target{});
+        auto result = p.eval({}).back();
+        std::vector<float> result_vector(k * b * n * s * h);
+        result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
+        std::vector<float> gold{0, 1, 4, 5, 8, 9, 2, 3, 6, 7, 10, 11};
+        migraphx::program p2;
+        auto* mm2 = p2.get_main_module();
+        auto l2 = mm2->add_literal(migraphx::literal{sh, data});
+        mm2->add_instruction(migraphx::make_op("transpose", {{"permutation", {2, 0, 3, 1, 4}}}), l2);
+        p2.compile(migraphx::ref::target{});
+        auto result2 = p2.eval({}).back();
+        std::vector<float> result_vector2(k * b * n * s * h);
+        result2.visit([&](auto output) { result_vector2.assign(output.begin(), output.end()); });
+        for (auto& i : result_vector2)
+            std::cout << i << ", ";
+        std::cout << std::endl;
+        EXPECT(migraphx::verify_range(result_vector, result_vector2));
+    }
+    {
+        // transposeCtx
+        migraphx::program p;
+        auto* mm = p.get_main_module();
+        migraphx::shape s{migraphx::shape::float_type, {2, 2, 2, 2}};
+        std::vector<float> data(2 * 2 * 2 * 2);
+        std::iota(data.begin(), data.end(), 0);
+        auto l1 = mm->add_literal(migraphx::literal{s, data});
+        mm->add_instruction(migraphx::make_op("transposectx"), l1);
+        p.compile(migraphx::ref::target{});
+        auto result = p.eval({}).back();
+        std::vector<float> result_vector(2 * 2 * 2 * 2);
+        result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
+        std::vector<float> gold{0, 1, 4, 5, 2, 3, 6, 7, 8, 9, 12, 13, 10, 11, 14, 15};
+        EXPECT(migraphx::verify_range(result_vector, gold));
+    }
+    {
+        // transposeCtx
+        migraphx::program p;
+        auto* mm = p.get_main_module();
+        const int b = 2, n = 2, s = 3, h = 4;
+        migraphx::shape sh{migraphx::shape::float_type, {b, n, s, h}};
+        std::vector<float> data(b * n * s * h);
+        std::iota(data.begin(), data.end(), 0);
+        auto l1 = mm->add_literal(migraphx::literal{sh, data});
+        mm->add_instruction(migraphx::make_op("transposectx"), l1);
+        p.compile(migraphx::ref::target{});
+        auto result = p.eval({}).back();
+        std::vector<float> result_vector(b * n * s * h);
+        result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
+        std::vector<float> gold{0, 1, 2, 3, 12, 13, 14, 15, 4, 5, 6, 7, 16, 17, 18, 19, 8, 9, 10, 11, 20, 21, 22, 23, 24, 25, 26, 27, 36, 37, 38, 39, 28, 29, 30, 31, 40, 41, 42, 43, 32, 33, 34, 35, 44, 45, 46, 47};
+        EXPECT(migraphx::verify_range(result_vector, gold));
+    }
+}
 TEST_CASE(broadcast_test)
 {
     migraphx::program p;
...
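Editor's note: the transposectx gold vectors in the test above follow from out[b][s][n][h] = in[b][n][s][h]. A standalone sketch that regenerates the 2x2x2x2 gold (plain C++, not part of this commit):

#include <cstdio>
#include <numeric>
#include <vector>

int main()
{
    const int B = 2, N = 2, S = 2, H = 2;
    std::vector<int> in(B * N * S * H);
    std::iota(in.begin(), in.end(), 0);
    std::vector<int> gold(in.size());
    for(int b = 0; b < B; ++b)
        for(int s = 0; s < S; ++s)
            for(int n = 0; n < N; ++n)
                for(int h = 0; h < H; ++h)
                    gold[((b * S + s) * N + n) * H + h] = in[((b * N + n) * S + s) * H + h];
    for(int v : gold)
        std::printf("%d, ", v); // prints 0, 1, 4, 5, 2, 3, 6, 7, 8, 9, 12, 13, 10, 11, 14, 15,
    return 0;
}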
@@ -11,8 +11,8 @@ struct test_transposectx : verify_program<test_transposectx>
         migraphx::program p;
         auto* mm = p.get_main_module();
         auto x =
-            mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 12, 128, 64}});
-        mm->add_instruction(migraphx::make_op("transposectx", {{"head_size", 64}}), x);
+            mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {16, 12, 384, 64}});
+        mm->add_instruction(migraphx::make_op("transposectx"), x);
         p.debug_print();
         return p;
     }
...
@@ -11,7 +11,7 @@ struct test_transposeqkv : verify_program<test_transposeqkv>
         auto* mm = p.get_main_module();
         auto x =
             mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 1, 3, 12, 64}});
-        mm->add_instruction(migraphx::make_op("transposeqkv", {{"head_size", 64}}), x);
+        mm->add_instruction(migraphx::make_op("transposeqkv"), x);
         p.debug_print();
         return p;
     }
...