Commit 0ccee797 authored by turneram's avatar turneram
Browse files

Add transposectx and transposeqkv

parent f0ff480d
......@@ -185,6 +185,8 @@ register_migraphx_ops(
tan
topk
transpose
transposectx
transposeqkv
unary_not
undefined
unknown
......
#ifndef MIGRAPHX_GUARD_OPERATORS_TRANSPOSECTX_HPP
#define MIGRAPHX_GUARD_OPERATORS_TRANSPOSECTX_HPP
#include <array>
#include <migraphx/argument.hpp>
#include <migraphx/par_for.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <cmath>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct transposectx
{
int head_size = 64;
bool reversed_bs = false;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.head_size, "head_size"), f(self.reversed_bs, "reversed_bs"));
}
std::string name() const { return "transposectx"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1).standard();
auto lens = inputs.front().lens();
std::vector<std::size_t> out_lens{lens[0], lens[2], lens[1], lens[3]};
return {inputs.front().type(), out_lens};
}
argument compute(const shape& output_shape, std::vector<argument> args) const
{
// Input: BxNxSxH
// Output: BxSxNxH
argument result{output_shape};
visit_all(result, args.front())([&](auto output, const auto input) {
par_for(output_shape.elements(), [&](auto i) {
// TODO: calculate in_offet and out_offset
output[i] = input[i];
});
});
return result;
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_OPERATORS_TRANSPOSEQKV_HPP
#define MIGRAPHX_GUARD_OPERATORS_TRANSPOSEQKV_HPP
#include <array>
#include <migraphx/argument.hpp>
#include <migraphx/par_for.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <cmath>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct transposeqkv
{
int head_size = 64;
bool reversed_bs = false;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.head_size, "head_size"), f(self.reversed_bs, "reversed_bs"));
}
std::string name() const { return "transposeqkv"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1).standard();
auto lens = inputs.front().lens();
std::vector<std::size_t> out_lens{lens[2], lens[0], lens[3], lens[1], lens[4]};
return {inputs.front().type(), out_lens};
}
argument compute(const shape& output_shape, std::vector<argument> args) const
{
// Input: BxSxKxNxH or SxBxKxNxH
// Output: KxBxNxSxH
// K is the number of identical matrix
argument result{output_shape};
visit_all(result, args.front())([&](auto output, const auto input) {
par_for(output_shape.elements(), [&](auto i) {
// TODO: calculate in_offet and out_offset
output[i] = input[i];
});
});
return result;
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
......@@ -109,6 +109,8 @@
#include <migraphx/op/tan.hpp>
#include <migraphx/op/topk.hpp>
#include <migraphx/op/transpose.hpp>
#include <migraphx/op/transposectx.hpp>
#include <migraphx/op/transposeqkv.hpp>
#include <migraphx/op/unary.hpp>
#include <migraphx/op/unary_not.hpp>
#include <migraphx/op/undefined.hpp>
......
......@@ -103,7 +103,7 @@ struct parse_attention : op_parser<parse_attention>
add_gemms);
std::vector<std::size_t> qkv_perm{2, 0, 3, 1, 4};
auto transqkv = info.add_instruction(
migraphx::make_op("transpose", {{"permutation", qkv_perm}}), add_gemms);
migraphx::make_op("transposeqkv", {{"head_size", head_size}}), add_gemms);
// now scratch3 has Q, K, V: each has size BxNxSxH
// => transqkv has shape 3xBxNxSxH
......@@ -163,7 +163,7 @@ struct parse_attention : op_parser<parse_attention>
// scratch3 is BxNxSxH, transpose to output BxSxNxH
gemm4 = info.add_instruction(
migraphx::make_op("transpose", {{"permutation", {0, 2, 1, 3}}}), gemm4);
migraphx::make_op("transposectx", {{"head_size", head_size}}), gemm4);
gemm4 = info.add_instruction(
make_op("reshape", {{"dims", {batch_size, sequence_length, num_heads * head_size}}}),
info.make_contiguous(gemm4));
......
#include <migraphx/gpu/compiler.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/cpp_generator.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/reduce_dims.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/module.hpp>
#include <migraphx/pass_manager.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
static const char* const transposectx_kernel = R"__migraphx__(
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/transposectx.hpp>
#include <migraphx/kernels/basic_ops.hpp>
#include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/generic_constant.hpp>
#include <args.hpp>
namespace migraphx {
extern "C" {
__global__ void transposectx_kernel(void* input_p, void* output_p)
{
make_tensors()(input_p, output_p)([](auto input, auto output) {
auto settings = make_transposectx_settings(MIGRAPHX_MAKE_CONSTANT(int64_t{HEAD_SIZE}), MIGRAPHX_MAKE_CONSTANT(bool{REVERSED_BS}));
transposectx(input, output, settings);
});
}
}
} // namespace migraphx
)__migraphx__";
struct transposectx_compiler : compiler<transposectx_compiler>
{
std::vector<std::string> names() const { return {"transposectx"}; }
operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
{
hip_compile_options options;
options.set_launch_params(v, compute_global_for(ctx, inputs.back().elements()), 64);
options.output = inputs.back();
options.inputs = inputs;
options.kernel_name = "transposectx_kernel";
// head_size
assert(v.contains("head_size"));
auto head_size = v.at("head_size").to<int>();
options.params += " -DHEAD_SIZE=" + std::to_string(head_size);
// reversed_bs
assert(v.contains("reversed_bs"));
auto reversed_bs = v.at("reversed_bs").to<bool>();
options.params += " -DREVERSED_BS=" + std::to_string(reversed_bs);
return compile_hip_code_object(transposectx_kernel, options);
}
compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
{
return replace(compile_op(ctx, to_shapes(ins->inputs()), op.to_value()));
}
};
static const char* const transposeqkv_kernel = R"__migraphx__(
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/transposeqkv.hpp>
#include <migraphx/kernels/basic_ops.hpp>
#include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/generic_constant.hpp>
#include <args.hpp>
namespace migraphx {
extern "C" {
__global__ void transposeqkv_kernel(void* input_p, void* output_p)
{
make_tensors()(input_p, output_p)([](auto input, auto output) {
auto settings = make_transposeqkv_settings(MIGRAPHX_MAKE_CONSTANT(int64_t{HEAD_SIZE}), MIGRAPHX_MAKE_CONSTANT(bool{REVERSED_BS}));
transposeqkv(input, output, settings);
});
}
}
} // namespace migraphx
)__migraphx__";
struct transposeqkv_compiler : compiler<transposeqkv_compiler>
{
std::vector<std::string> names() const { return {"transposeqkv"}; }
operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
{
hip_compile_options options;
options.set_launch_params(v, compute_global_for(ctx, inputs.back().elements()), 64);
options.output = inputs.back();
options.inputs = inputs;
options.kernel_name = "transposeqkv_kernel";
// head_size
assert(v.contains("head_size"));
auto head_size = v.at("head_size").to<int>();
options.params += " -DHEAD_SIZE=" + std::to_string(head_size);
// reversed_bs
assert(v.contains("reversed_bs"));
auto reversed_bs = v.at("reversed_bs").to<bool>();
options.params += " -DREVERSED_BS=" + std::to_string(reversed_bs);
return compile_hip_code_object(transposeqkv_kernel, options);
}
compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
{
return replace(compile_op(ctx, to_shapes(ins->inputs()), op.to_value()));
}
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#ifndef MIGRAPHX_GUARD_KERNELS_TRANSPOSECTX_HPP
#define MIGRAPHX_GUARD_KERNELS_TRANSPOSECTX_HPP
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/algorithm.hpp>
namespace migraphx {
template <class T, class U>
struct transposectx_settings
{
T head_size{};
U reversed_bs{};
};
template <class... Ts>
constexpr transposectx_settings<Ts...> make_transposectx_settings(Ts... xs)
{
return {xs...};
}
template <class T, class U, class Settings>
__device__ void transposectx(const T& input_t, const U& output_t, Settings st)
{
// Input: BxNxSxH
// Output: BxSxNxH
auto head_size = st.head_size;
auto reversed_bs = st.reversed_bs;
int n = threadIdx.y;
int s = blockIdx.x;
int b = blockIdx.y;
int num_heads = blockDim.y;
int sequence_length = gridDim.x;
const int NH = num_heads * head_size;
const int NHS = NH * sequence_length;
const int in_offset = s * head_size + n * sequence_length * head_size + b * NHS;
int out_offset = 0;
if (reversed_bs) {
const int batch_size = gridDim.y;
const int BNH = NH * batch_size;
out_offset = n * head_size + b * NH + s * BNH;
} else {
out_offset = n * head_size + s * NH + b * NHS;
}
const int i = threadIdx.x;
if (i < head_size) {
output_t[out_offset + i] = input_t[in_offset + i];
}
}
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_KERNELS_TRANSPOSEQKV_HPP
#define MIGRAPHX_GUARD_KERNELS_TRANSPOSEQKV_HPP
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/algorithm.hpp>
namespace migraphx {
template <class T, class U>
struct transposeqkv_settings
{
T head_size{};
U reversed_bs{};
};
template <class... Ts>
constexpr transposeqkv_settings<Ts...> make_transposeqkv_settings(Ts... xs)
{
return {xs...};
}
template <class T, class U, class Settings>
__device__ void transposeqkv(const T& input_t, const U& output_t, Settings st)
{
// Input: BxSxKxNxH or SxBxKxNxH
// Output: KxBxNxSxH
// K is the number of identical matrix
auto H = st.head_size;
auto reversed_bs = st.reversed_bs;
int n = threadIdx.y;
int s = blockIdx.x;
int b = blockIdx.y;
int m = blockIdx.z; // matrix id
const int num_heads = blockDim.y;
const int sequence_length = gridDim.x;
const int batch_size = gridDim.y;
const int chunk_num = gridDim.z;
const int NH = num_heads * H;
const int NHS = NH * sequence_length;
int in_offset = 0;
if (reversed_bs) {
const int BNH = NH * batch_size;
in_offset = n * H + (m + b * chunk_num) * NH + s * BNH * chunk_num;
} else {
in_offset = n * H + (m + s * chunk_num) * NH + b * NHS * chunk_num;
}
const int out_offset = s * H + n * sequence_length * H + b * NHS + m * NHS * batch_size;
const int i = threadIdx.x;
if (i < H) {
output_t[out_offset + i] = input_t[in_offset + i];
}
}
} // namespace migraphx
#endif
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct test_transposectx : verify_program<test_transposectx>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 12, 128, 64}});
mm->add_instruction(migraphx::make_op("transposectx", {{"head_size", 64}}), x);
p.debug_print();
return p;
}
};
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct test_transposeqkv : verify_program<test_transposeqkv>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 1, 3, 12, 64}});
mm->add_instruction(migraphx::make_op("transposeqkv", {{"head_size", 64}}), x);
p.debug_print();
return p;
}
};
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment