Commit 37351ed6 authored by turneram

Remove transpose kernels

parent 48187e79
@@ -185,8 +185,6 @@ register_migraphx_ops(
tan
topk
transpose
-transposectx
-transposeqkv
unary_not
undefined
unknown
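Reviewer note: once the two names are dropped from `register_migraphx_ops`, they no longer resolve through `make_op` (an assumption about the registration macro, consistent with how `make_op` is used below), so every call site has to move to the generic `transpose`, as the parser hunks in this commit do:

```cpp
// Hypothetical snippet, not part of the diff:
auto ok = migraphx::make_op("transpose", {{"permutation", {0, 2, 1, 3}}}); // resolves
// migraphx::make_op("transposectx"); // would now fail: name is unregistered
```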
#ifndef MIGRAPHX_GUARD_OPERATORS_TRANSPOSECTX_HPP
#define MIGRAPHX_GUARD_OPERATORS_TRANSPOSECTX_HPP
#include <array>
#include <migraphx/argument.hpp>
#include <migraphx/par_for.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <cmath>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct transposectx
{
std::string name() const { return "transposectx"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1).standard();
auto lens = inputs.front().lens();
std::vector<std::size_t> out_lens{lens[0], lens[2], lens[1], lens[3]};
return {inputs.front().type(), out_lens};
}
argument compute(const shape& output_shape, std::vector<argument> args) const
{
// Input: BxNxSxH
// Output: BxSxNxH
argument result{output_shape};
auto in_s = args.front().get_shape();
auto lens = in_s.lens();
visit_all(result, args.front())([&](auto output, const auto input) {
par_for(output_shape.elements(), [&](auto i) {
auto idx = in_s.multi(i);
int n = idx.at(1);
int s = idx.at(2);
int b = idx.front();
int num_heads = lens.at(1);
int sequence_length = lens.at(2);
int head_size = lens.back();
const int NH = num_heads * head_size;
const int NHS = NH * sequence_length;
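// Row-major offset of (b, s, n, 0) in the BxSxNxH output.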
const int out_offset = n * head_size + s * NH + b * NHS;
const int j = idx.back();
output[out_offset + j] = input[i];
});
});
return result;
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
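Reviewer note: everything this removed header computed is a single axis permutation, BxNxSxH -> BxSxNxH, i.e. permutation {0, 2, 1, 3}, which the stock `transpose` op already expresses. A minimal sketch of the replacement, mirroring the ref test later in this diff (standard MIGraphX API; the shape is just an example):

```cpp
#include <migraphx/program.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/shape.hpp>

migraphx::program make_ctx_transpose()
{
    migraphx::program p;
    auto* mm = p.get_main_module();
    auto x   = mm->add_parameter(
        "x", migraphx::shape{migraphx::shape::float_type, {2, 12, 384, 64}});
    // BxNxSxH -> BxSxNxH: swap the N (heads) and S (sequence) axes.
    mm->add_instruction(
        migraphx::make_op("transpose", {{"permutation", {0, 2, 1, 3}}}), x);
    return p;
}
```

One behavioral difference: the removed op wrote a standard-layout output, whereas the generic `transpose` yields a strided view; see the note after the kernel headers below for how that is reconciled.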
#ifndef MIGRAPHX_GUARD_OPERATORS_TRANSPOSEQKV_HPP
#define MIGRAPHX_GUARD_OPERATORS_TRANSPOSEQKV_HPP
#include <array>
#include <migraphx/argument.hpp>
#include <migraphx/par_for.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <cmath>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct transposeqkv
{
std::string name() const { return "transposeqkv"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1).standard();
auto lens = inputs.front().lens();
std::vector<std::size_t> out_lens{lens[2], lens[0], lens[3], lens[1], lens[4]};
return {inputs.front().type(), out_lens};
}
argument compute(const shape& output_shape, std::vector<argument> args) const
{
// Input: BxSxKxNxH
// Output: KxBxNxSxH
// K is the number of stacked, same-shaped matrices (3 here: Q, K, V)
auto in_s = args.front().get_shape();
auto lens = in_s.lens();
argument result{output_shape};
visit_all(result, args.front())([&](auto output, const auto input) {
par_for(output_shape.elements(), [&](auto i) {
auto idx = in_s.multi(i);
const int b = idx.front();
const int s = idx.at(1);
const int m = idx.at(2);
const int n = idx.at(3);
const int j = idx.back();
const int num_heads = lens[3];
const int sequence_length = lens[1];
const int batch_size = lens[0];
const int H = lens.back();
const int NH = num_heads * H;
const int NHS = NH * sequence_length;
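// Row-major offset of (m, b, n, s, 0) in the KxBxNxSxH output.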
const int out_offset =
s * H + n * sequence_length * H + b * NHS + m * NHS * batch_size;
output[out_offset + j] = input[i];
});
});
return result;
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
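The offset algebra above is easy to misread, so here is a tiny standalone check (hypothetical, not part of the codebase) that the removed formula equals the row-major index of (m, b, n, s, j) in the KxBxNxSxH output, which is exactly what `transpose` with permutation {2, 0, 3, 1, 4} computes:

```cpp
#include <cassert>
#include <cstddef>

int main()
{
    const std::size_t B = 2, S = 4, K = 3, N = 2, H = 5;
    for(std::size_t m = 0; m < K; ++m)
        for(std::size_t b = 0; b < B; ++b)
            for(std::size_t n = 0; n < N; ++n)
                for(std::size_t s = 0; s < S; ++s)
                    for(std::size_t j = 0; j < H; ++j)
                    {
                        const std::size_t NH  = N * H;
                        const std::size_t NHS = NH * S;
                        // Offset formula from the removed operator:
                        const std::size_t out_offset =
                            s * H + n * S * H + b * NHS + m * NHS * B;
                        // Row-major index of (m, b, n, s, j) in KxBxNxSxH:
                        const std::size_t expected =
                            (((m * B + b) * N + n) * S + s) * H + j;
                        assert(out_offset + j == expected);
                    }
    return 0;
}
```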
@@ -109,8 +109,6 @@
#include <migraphx/op/tan.hpp>
#include <migraphx/op/topk.hpp>
#include <migraphx/op/transpose.hpp>
-#include <migraphx/op/transposectx.hpp>
-#include <migraphx/op/transposeqkv.hpp>
#include <migraphx/op/unary.hpp>
#include <migraphx/op/unary_not.hpp>
#include <migraphx/op/undefined.hpp>
@@ -65,7 +65,7 @@ struct parse_attention : op_parser<parse_attention>
migraphx::make_op("reshape",
{{"dims", {batch_size, sequence_length, 3, num_heads, head_size}}}),
add_gemms);
-auto transqkv = info.add_instruction(migraphx::make_op("transposeqkv"), add_gemms);
+auto transqkv = info.add_instruction(migraphx::make_op("transpose", {{"permutation", {2, 0, 3, 1, 4}}}), add_gemms);
// Q, K, V: each has size BxNxSxH
auto q_t = info.add_instruction(
@@ -99,7 +99,7 @@ struct parse_attention : op_parser<parse_attention>
auto gemm4 = info.add_instruction(migraphx::make_op("dot"), softmax, v_t);
// result is BxNxSxH, transpose to BxSxNxH and reshape to BxSxHiddenSize
-gemm4 = info.add_instruction(migraphx::make_op("transposectx"), gemm4);
+gemm4 = info.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1, 3}}}), gemm4);
return info.add_instruction(
make_op("reshape", {{"dims", {batch_size, sequence_length, num_heads * head_size}}}),
gemm4);
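Reading the two replacements together: {2, 0, 3, 1, 4} pulls the stacked K axis to the front so Q, K, and V can be split off, and {0, 2, 1, 3} swaps the head and sequence axes back before the final reshape to BxSxHiddenSize. The Q/K/V extraction itself is truncated by the hunk above; a hypothetical sketch of one extraction (a guess at the elided code, labeled as such):

```cpp
// Hypothetical: take Q as the k = 0 slab of the KxBxNxSxH tensor,
// then drop the leading axis to get BxNxSxH.
auto q_t = info.add_instruction(
    migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}),
    transqkv);
q_t = info.add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), q_t);
```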
#include <migraphx/gpu/compiler.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/cpp_generator.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/reduce_dims.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/module.hpp>
#include <migraphx/pass_manager.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
static const char* const transposectx_kernel = R"__migraphx__(
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/transposectx.hpp>
#include <args.hpp>
namespace migraphx {
extern "C" {
__global__ void transposectx_kernel(void* input_p, void* output_p)
{
make_tensors()(input_p, output_p)([](auto input, auto output) {
transposectx(input, output);
});
}
}
} // namespace migraphx
)__migraphx__";
struct transposectx_compiler : compiler<transposectx_compiler>
{
std::vector<std::string> names() const { return {"transposectx"}; }
operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
{
hip_compile_options options;
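// Size the launch from the output element count; the innermost
// head-size dim 'h' is used as the workgroup size.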
auto h = inputs.front().lens().back();
options.set_launch_params(v, compute_global_for(ctx, inputs.back().elements(), h), h);
options.output = inputs.back();
options.inputs = inputs;
options.kernel_name = "transposectx_kernel";
return compile_hip_code_object(transposectx_kernel, options);
}
compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
{
return replace(compile_op(ctx, to_shapes(ins->inputs()), op.to_value()));
}
};
static const char* const transposeqkv_kernel = R"__migraphx__(
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/transposeqkv.hpp>
#include <args.hpp>
namespace migraphx {
extern "C" {
__global__ void transposeqkv_kernel(void* input_p, void* output_p)
{
make_tensors()(input_p, output_p)([](auto input, auto output) {
transposeqkv(input, output);
});
}
}
} // namespace migraphx
)__migraphx__";
struct transposeqkv_compiler : compiler<transposeqkv_compiler>
{
std::vector<std::string> names() const { return {"transposeqkv"}; }
operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
{
hip_compile_options options;
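// Same launch scheme as the transposectx compiler above.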
auto h = inputs.front().lens().back();
options.set_launch_params(v, compute_global_for(ctx, inputs.back().elements(), h), h);
options.output = inputs.back();
options.inputs = inputs;
options.kernel_name = "transposeqkv_kernel";
return compile_hip_code_object(transposeqkv_kernel, options);
}
compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
{
return replace(compile_op(ctx, to_shapes(ins->inputs()), op.to_value()));
}
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#ifndef MIGRAPHX_GUARD_KERNELS_TRANSPOSECTX_HPP
#define MIGRAPHX_GUARD_KERNELS_TRANSPOSECTX_HPP
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/algorithm.hpp>
namespace migraphx {
template <class T, class U>
__device__ void transposectx(const T& input_t, const U& output_t)
{
// Input: BxNxSxH
// Output: BxSxNxH
auto index = make_index();
auto input_shape = input_t.get_shape();
auto lens = input_shape.lens;
const int num_heads = lens[1];
const int sequence_length = lens[2];
int head_size = lens[3];
auto idx = input_shape.multi(index.global);
int n = idx[1];
int s = idx[2];
int b = idx[0];
const int NH = num_heads * head_size;
const int NHS = NH * sequence_length;
const int out_offset = n * head_size + s * NH + b * NHS;
if(index.global < input_shape.elements())
output_t[out_offset + idx[3]] = input_t[index.global];
}
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_KERNELS_TRANSPOSEQKV_HPP
#define MIGRAPHX_GUARD_KERNELS_TRANSPOSEQKV_HPP
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/algorithm.hpp>
namespace migraphx {
template <class T, class U>
__device__ void transposeqkv(const T& input_t, const U& output_t)
{
// Input: BxSxKxNxH (the index math below assumes this layout)
// Output: KxBxNxSxH
// K is the number of stacked, same-shaped matrices (3 here: Q, K, V)
auto index = make_index();
auto input_shape = input_t.get_shape();
auto lens = input_shape.lens;
auto idx = input_shape.multi(index.global);
const int b = idx[0];
const int s = idx[1];
const int m = idx[2];
const int n = idx[3];
const int num_heads = lens[3];
const int sequence_length = lens[1];
const int batch_size = lens[0];
const int H = lens[4];
const int NH = num_heads * H;
const int NHS = NH * sequence_length;
const int out_offset = s * H + n * sequence_length * H + b * NHS + m * NHS * batch_size;
if(index.global < input_shape.elements())
{
output_t[out_offset + idx[4]] = input_t[index.global];
}
}
} // namespace migraphx
#endif
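With both kernel headers removed, nothing bespoke remains on the GPU side: the generic `transpose` lowers to a zero-copy strided view, and where a standard layout is needed downstream the existing `contiguous` kernel materializes it (an assumption about MIGraphX's lowering, consistent with the parser change above).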
@@ -666,67 +666,6 @@ TEST_CASE(batch_norm_inference_test)
EXPECT(migraphx::verify_range(result_vector, gold));
}
TEST_CASE(bert_transpose_ops_test)
{
{
// transposeQKV
migraphx::program p;
auto* mm = p.get_main_module();
std::vector<int> bsknh{2, 384, 3, 12, 64};
const int elements = std::accumulate(bsknh.begin(), bsknh.end(), 1, std::multiplies<int>());
migraphx::shape sh{migraphx::shape::float_type, bsknh};
std::vector<float> data(elements);
std::iota(data.begin(), data.end(), 0);
auto l1 = mm->add_literal(migraphx::literal{sh, data});
mm->add_instruction(migraphx::make_op("transposeqkv"), l1);
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
std::vector<float> result_vector(elements);
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<float> gold(elements);
migraphx::program p2;
auto* mm2 = p2.get_main_module();
auto l2 = mm2->add_literal(migraphx::literal{sh, data});
// BSKNH->KBNSH : perm=2,0,3,1,4
mm2->add_instruction(migraphx::make_op("transpose", {{"permutation", {2, 0, 3, 1, 4}}}),
l2);
p2.compile(migraphx::ref::target{});
auto result2 = p2.eval({}).back();
result2.visit([&](auto output) { gold.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(result_vector, gold));
}
{
// transposeCtx
migraphx::program p;
auto* mm = p.get_main_module();
std::vector<int> bnsh{2, 12, 384, 64};
const int elements = std::accumulate(bnsh.begin(), bnsh.end(), 1, std::multiplies<int>());
migraphx::shape sh{migraphx::shape::float_type, bnsh};
std::vector<float> data(elements);
std::iota(data.begin(), data.end(), 0);
auto l1 = mm->add_literal(migraphx::literal{sh, data});
mm->add_instruction(migraphx::make_op("transposectx"), l1);
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
std::vector<float> result_vector(elements);
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<float> gold(elements);
migraphx::program p2;
auto* mm2 = p2.get_main_module();
auto l2 = mm2->add_literal(migraphx::literal{sh, data});
// BNSH->BSNH : perm=0,2,1,3
mm2->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1, 3}}}), l2);
p2.compile(migraphx::ref::target{});
auto result2 = p2.eval({}).back();
result2.visit([&](auto output) { gold.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(result_vector, gold));
}
}
TEST_CASE(broadcast_test)
{
migraphx::program p;
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct test_transposectx : verify_program<test_transposectx>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
auto x =
mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 12, 128, 64}});
mm->add_instruction(migraphx::make_op("transposectx"), x);
return p;
}
};
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct test_transposeqkv : verify_program<test_transposeqkv>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
auto x = mm->add_parameter(
"x", migraphx::shape{migraphx::shape::float_type, {2, 384, 3, 12, 64}});
mm->add_instruction(migraphx::make_op("transposeqkv"), x);
return p;
}
};
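Both `verify_program` cases above existed only to compare the reference and GPU implementations of the custom ops; with the ops deleted there is nothing left to exercise, and the generic permutations are presumably already covered by the existing transpose verify tests (an assumption about the wider suite).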
@@ -3,7 +3,7 @@
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
-struct test_layernorm : verify_program<test_layernorm>
+struct test_layernorm_op : verify_program<test_layernorm_op>
{
migraphx::program create_program() const
{