Commit 3ea9fe4c authored by turneram's avatar turneram
Browse files

Add attention, layernorm op, transposectx, and transposeqkv

parent 4a312201
...@@ -117,6 +117,7 @@ register_migraphx_ops( ...@@ -117,6 +117,7 @@ register_migraphx_ops(
if_op if_op
im2col im2col
isnan isnan
layernorm
leaky_relu leaky_relu
less less
load load
...@@ -184,6 +185,8 @@ register_migraphx_ops( ...@@ -184,6 +185,8 @@ register_migraphx_ops(
tan tan
topk topk
transpose transpose
transposectx
transposeqkv
unary_not unary_not
undefined undefined
unknown unknown
......
#ifndef MIGRAPHX_GUARD_OPERATORS_LAYERNORMALIZATION_HPP
#define MIGRAPHX_GUARD_OPERATORS_LAYERNORMALIZATION_HPP
#include <array>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/par_for.hpp>
#include <cmath>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct layernorm
{
float epsilon = 1e-3;
int64_t axis = -1;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return pack(f(self.epsilon, "epsilon"), f(self.axis, "axis"));
}
value attributes() const
{
value normalize;
normalize["axis"] = value::array{normalize_attribute::include_min};
return {{"normalize_axes", normalize}};
}
std::string name() const { return "layernorm"; }
shape normalize_compute_shape(std::vector<shape> inputs) const
{
if(inputs.size() == 2)
{
if(inputs.at(1).lens().front() != inputs.front().lens().at(axis))
MIGRAPHX_THROW("LAYERNORM: weights have wrong shape");
}
if(inputs.size() == 3)
{
if(inputs.at(2).lens().front() != inputs.front().lens().at(axis))
MIGRAPHX_THROW("LAYERNORM: bias has wrong shape");
}
return inputs.front();
}
argument compute(const shape& output_shape, std::vector<argument> args) const
{
argument result{output_shape};
auto x_lens = args.front().get_shape().lens();
auto norm_count = std::accumulate(
x_lens.begin(), x_lens.begin() + axis, std::size_t{1}, std::multiplies<std::size_t>());
auto norm_size = std::accumulate(
x_lens.begin() + axis, x_lens.end(), std::size_t{1}, std::multiplies<std::size_t>());
if(args.size() == 3)
{
visit_all(result, args[0], args[1], args[2])(
[&](auto output, auto data, auto weights, auto bias) {
par_for(norm_count, [&](auto idx) {
auto offset = idx * norm_size;
double mean = 0;
double mean_square = 0;
for(std::size_t i = 0; i < norm_size; ++i)
{
mean += data[offset + i];
mean_square += data[offset + i] * data[offset + i];
}
mean /= norm_size;
mean_square = sqrt(mean_square / norm_size - mean * mean + epsilon);
for(std::size_t i = 0; i < norm_size; ++i)
{
if(args.size() == 3)
output[offset + i] =
(data[offset + i] - mean) / mean_square * weights[i] + bias[i];
else
output[offset + i] =
(data[offset + i] - mean) / mean_square * weights[i];
}
});
});
}
else
{
visit_all(result, args[0])([&](auto output, auto data) {
par_for(norm_count, [&](auto idx) {
auto offset = idx * norm_size;
double mean = 0;
double mean_square = 0;
for(std::size_t i = 0; i < norm_size; ++i)
{
mean += data[offset + i];
mean_square += data[offset + i] * data[offset + i];
}
mean /= norm_size;
mean_square = sqrt(mean_square / norm_size - mean * mean + epsilon);
for(std::size_t i = 0; i < norm_size; ++i)
{
output[offset + i] = (data[offset + i] - mean) / mean_square;
// scale and bias handled by pointwise ops
}
});
});
}
return result;
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_OPERATORS_TRANSPOSECTX_HPP
#define MIGRAPHX_GUARD_OPERATORS_TRANSPOSECTX_HPP
#include <array>
#include <migraphx/argument.hpp>
#include <migraphx/par_for.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <cmath>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct transposectx
{
std::string name() const { return "transposectx"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1).standard();
auto lens = inputs.front().lens();
std::vector<std::size_t> out_lens{lens[0], lens[2], lens[1], lens[3]};
return {inputs.front().type(), out_lens};
}
argument compute(const shape& output_shape, std::vector<argument> args) const
{
// Input: BxNxSxH
// Output: BxSxNxH
argument result{output_shape};
auto in_s = args.front().get_shape();
auto lens = in_s.lens();
visit_all(result, args.front())([&](auto output, const auto input) {
par_for(output_shape.elements(), [&](auto i) {
auto idx = in_s.multi(i);
int n = idx.at(1);
int s = idx.at(2);
int b = idx.front();
int num_heads = lens.at(1);
int sequence_length = lens.at(2);
int head_size = lens.back();
const int NH = num_heads * head_size;
const int NHS = NH * sequence_length;
const int out_offset = n * head_size + s * NH + b * NHS;
const int j = idx.back();
output[out_offset + j] = input[i];
});
});
return result;
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_OPERATORS_TRANSPOSEQKV_HPP
#define MIGRAPHX_GUARD_OPERATORS_TRANSPOSEQKV_HPP
#include <array>
#include <migraphx/argument.hpp>
#include <migraphx/par_for.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <migraphx/value.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <cmath>
#include <utility>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {
struct transposeqkv
{
std::string name() const { return "transposeqkv"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1).standard();
auto lens = inputs.front().lens();
std::vector<std::size_t> out_lens{lens[2], lens[0], lens[3], lens[1], lens[4]};
return {inputs.front().type(), out_lens};
}
argument compute(const shape& output_shape, std::vector<argument> args) const
{
// Input: BxSxKxNxH
// Output: KxBxNxSxH
// K is the number of identical matrix
auto in_s = args.front().get_shape();
auto lens = in_s.lens();
argument result{output_shape};
visit_all(result, args.front())([&](auto output, const auto input) {
par_for(output_shape.elements(), [&](auto i) {
auto idx = in_s.multi(i);
const int b = idx.front();
const int s = idx.at(1);
const int m = idx.at(2);
const int n = idx.at(3);
const int j = idx.back();
const int num_heads = lens[3];
const int sequence_length = lens[1];
const int batch_size = lens[0];
const int H = lens.back();
const int NH = num_heads * H;
const int NHS = NH * sequence_length;
const int out_offset =
s * H + n * sequence_length * H + b * NHS + m * NHS * batch_size;
output[out_offset + j] = input[i];
});
});
return result;
}
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
...@@ -43,6 +43,7 @@ ...@@ -43,6 +43,7 @@
#include <migraphx/op/if_op.hpp> #include <migraphx/op/if_op.hpp>
#include <migraphx/op/im2col.hpp> #include <migraphx/op/im2col.hpp>
#include <migraphx/op/isnan.hpp> #include <migraphx/op/isnan.hpp>
#include <migraphx/op/layernorm.hpp>
#include <migraphx/op/leaky_relu.hpp> #include <migraphx/op/leaky_relu.hpp>
#include <migraphx/op/less.hpp> #include <migraphx/op/less.hpp>
#include <migraphx/op/load.hpp> #include <migraphx/op/load.hpp>
...@@ -108,6 +109,8 @@ ...@@ -108,6 +109,8 @@
#include <migraphx/op/tan.hpp> #include <migraphx/op/tan.hpp>
#include <migraphx/op/topk.hpp> #include <migraphx/op/topk.hpp>
#include <migraphx/op/transpose.hpp> #include <migraphx/op/transpose.hpp>
#include <migraphx/op/transposectx.hpp>
#include <migraphx/op/transposeqkv.hpp>
#include <migraphx/op/unary.hpp> #include <migraphx/op/unary.hpp>
#include <migraphx/op/unary_not.hpp> #include <migraphx/op/unary_not.hpp>
#include <migraphx/op/undefined.hpp> #include <migraphx/op/undefined.hpp>
......
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/instruction.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
struct parse_attention : op_parser<parse_attention>
{
std::vector<op_desc> operators() const { return {{"Attention"}}; }
instruction_ref parse(const op_desc& /*opd*/,
const onnx_parser& parser,
onnx_parser::node_info info,
const std::vector<instruction_ref>& args) const
{
auto input = args[0];
auto weights = args[1];
auto bias = args[2];
auto mask_index = args[3];
instruction_ref past;
instruction_ref extra_add_qk;
bool is_past = false;
bool is_extra_add_qk = false;
if(args.size() > 4)
{
past = args[4];
is_past = true;
}
if(args.size() == 6)
{
is_extra_add_qk = true;
extra_add_qk = args[5];
}
// ORT default is 12
std::size_t num_heads = 12;
if(contains(info.attributes, "num_heads"))
num_heads = info.attributes.at("num_heads").i();
// input shape: (batch_size, sequence_length, input_hidden_size)
auto input_lens = input->get_shape().lens();
auto batch_size = input_lens.at(0);
auto sequence_length = input_lens.at(1);
auto input_hidden_size = input_lens.at(2);
// bias shape: (3 * hidden_size)
auto bias_lens = bias->get_shape().lens();
auto hidden_size = bias_lens.at(0) / 3;
auto head_size = hidden_size / num_heads;
int past_sequence_length = 0;
// GetPresent
// Input and output shapes:
// past : (2, batch_size, num_heads, past_sequence_length, head_size)
// present : (2, batch_size, num_heads, past_sequence_length + sequence_length,
// head_size)
std::vector<std::size_t> present_lens{2, batch_size, num_heads, sequence_length, head_size};
if(is_past)
{
auto past_lens = past->get_shape().lens();
past_sequence_length = past_lens.at(3);
present_lens[3] += past_lens[3];
}
// Use GEMM for fully connection.
auto m = batch_size * sequence_length;
auto n = bias_lens.front();
auto k = input_hidden_size;
// Bias shape is (N), broadcast using B(N, M) = 1 * bias(N, 1) x ones(1, M) + 0 * B.
auto bias_type = bias->get_shape().type();
std::vector<float> ones_vec(m, 1);
std::vector<std::size_t> ones_lens{1, m};
auto ones =
info.add_literal(migraphx::literal{migraphx::shape{bias_type, ones_lens}, ones_vec});
bias = info.add_instruction(migraphx::make_op("reshape", {{"dims", {n, 1}}}), bias);
auto gemm_1 = info.add_instruction(migraphx::make_op("dot"), bias, ones);
gemm_1 =
info.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), gemm_1);
/// Use row-major => results(N, M) = 1 * input x weights + 1 x B
auto input_sq = info.add_instruction(
migraphx::make_op("reshape", {{"dims", {batch_size * sequence_length, hidden_size}}}),
input);
auto gemm_2 = info.add_instruction(migraphx::make_op("dot"), input_sq, weights);
auto add_gemms = info.add_instruction(migraphx::make_op("add"), gemm_1, gemm_2);
// LaunchTransQkv
// input should be BxSx3xNxH => scratch3: 3xBxNxSxH
add_gemms = info.add_instruction(
migraphx::make_op("reshape",
{{"dims", {batch_size, sequence_length, 3, num_heads, head_size}}}),
add_gemms);
auto transqkv = info.add_instruction(migraphx::make_op("transposeqkv"), add_gemms);
// transqkv has shape 3xBxNxSxH
// => Q, K, V: each has size BxNxSxH
auto batches = batch_size * num_heads;
auto size_per_batch = sequence_length * head_size;
auto total_size = batches * size_per_batch;
auto q_t = info.add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), transqkv);
auto k_t = info.add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), transqkv);
auto v_t = info.add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {2}}, {"ends", {3}}}), transqkv);
q_t = info.add_instruction(make_op("squeeze", {{"axes", {0}}}), q_t);
k_t = info.add_instruction(make_op("squeeze", {{"axes", {0}}}), k_t);
v_t = info.add_instruction(make_op("squeeze", {{"axes", {0}}}), v_t);
if(is_past)
{
k_t = info.add_instruction(migraphx::make_op("concat", {{"axis", 3}}), past, k_t);
v_t = info.add_instruction(
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {3}}}), k_t);
}
// Raw attention mask could be 2D (BxS) or 3D (BxSxS*) or 4D(Bx1xMxM), where M is the max
// sequence length.
auto mask_index_lens = mask_index->get_shape().lens();
bool use_raw_attention_mask = mask_index_lens.size() >= 2;
// compute Q*K' (as K'*Q), scaled by 1/sqrt(H) and store in scratch1: BxNxSxS*
// Q: BxNxSxH, K (present_k): BxNxS*xH, Q*K': BxNxSxS*
const float rsqrt_head_size = 1.f / sqrt(static_cast<float>(head_size));
const int all_sequence_length = past_sequence_length + sequence_length;
const int temp_matrix_size = sequence_length * all_sequence_length;
// For raw attention mask, the scalar if 1/sqrt(H) is moved to softmax computation.
const float alpha = use_raw_attention_mask ? 1.0 : rsqrt_head_size;
// K{B,N,S,H} -> K'{B,N,H,S}
k_t = info.add_instruction(make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), k_t);
auto gemm3 = info.add_instruction(migraphx::make_op("dot"), q_t, k_t);
if(is_extra_add_qk)
gemm3 = info.add_instruction(make_op("add"), gemm3, extra_add_qk);
auto alpha_lit = info.add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", gemm3->get_shape().lens()}}),
info.add_literal(
migraphx::literal{migraphx::shape{gemm3->get_shape().type()}, {alpha}}));
gemm3 =
info.add_instruction(migraphx::make_op("mul"), gemm3, info.make_contiguous(alpha_lit));
// apply softmax and store result P to scratch2: BxNxSxS*
// Inference mask is all 1s => masking can be skipped
auto softmax = info.add_instruction(migraphx::make_op("softmax", {{"axis", 3}}), gemm3);
// compute P*V
auto gemm4 = info.add_instruction(migraphx::make_op("dot"), softmax, v_t);
// result is BxNxSxH, transpose to BxSxNxH and reshape to BxSxN*H
gemm4 = info.add_instruction(migraphx::make_op("transposectx"), gemm4);
gemm4 = info.add_instruction(
make_op("reshape", {{"dims", {batch_size, sequence_length, num_heads * head_size}}}),
info.make_contiguous(gemm4));
return gemm4;
}
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/op/layernorm.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/instruction.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {
struct parse_layernorm : op_parser<parse_layernorm>
{
std::vector<op_desc> operators() const { return {{"LayerNormalization"}}; }
instruction_ref parse(const op_desc& /*opd*/,
const onnx_parser& parser,
onnx_parser::node_info info,
const std::vector<instruction_ref>& args) const
{
float epsilon = 1e-3f;
int64_t axis = -1;
if(contains(info.attributes, "epsilon"))
{
epsilon = parser.parse_value(info.attributes.at("epsilon")).at<float>();
}
if(contains(info.attributes, "axis"))
{
epsilon = parser.parse_value(info.attributes.at("axis")).at<int64_t>();
}
auto layernorm = info.add_instruction(make_op("layernorm", {{"epsilon", epsilon}, {"axis", axis}}), args.front());
if (args.size() == 3)
{
layernorm = info.add_broadcastable_binary_op("mul", layernorm, args.at(1));
layernorm = info.add_broadcastable_binary_op("add", layernorm, args.at(2));
}
return layernorm;
}
};
} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
...@@ -148,6 +148,7 @@ add_library(migraphx_gpu ...@@ -148,6 +148,7 @@ add_library(migraphx_gpu
int8_conv_pack.cpp int8_conv_pack.cpp
int8_gemm_pack.cpp int8_gemm_pack.cpp
kernel.cpp kernel.cpp
layernorm.cpp
lowering.cpp lowering.cpp
logsoftmax.cpp logsoftmax.cpp
loop.cpp loop.cpp
...@@ -204,6 +205,7 @@ register_migraphx_gpu_ops(hip_ ...@@ -204,6 +205,7 @@ register_migraphx_gpu_ops(hip_
floor floor
gather gather
greater greater
layernorm
less less
log log
logsoftmax logsoftmax
......
#ifndef MIGRAPHX_GUARD_RTGLIB_LAYERNORM_HPP
#define MIGRAPHX_GUARD_RTGLIB_LAYERNORM_HPP
#include <migraphx/op/layernorm.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/reflect.hpp>
#include <migraphx/argument.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
struct context;
struct hip_layernorm
{
op::layernorm op;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return migraphx::reflect(self.op, f);
}
std::string name() const { return "gpu::layernorm"; }
shape compute_shape(std::vector<shape> inputs) const;
argument
compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const;
void finalize(context&, const shape&, const std::vector<shape>&);
std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
{
return shapes.size() - 1;
}
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#include <migraphx/gpu/compiler.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/cpp_generator.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/reduce_dims.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/module.hpp>
#include <migraphx/pass_manager.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
static const char* const transposectx_kernel = R"__migraphx__(
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/transposectx.hpp>
#include <args.hpp>
namespace migraphx {
extern "C" {
__global__ void transposectx_kernel(void* input_p, void* output_p)
{
make_tensors()(input_p, output_p)([](auto input, auto output) {
transposectx(input, output);
});
}
}
} // namespace migraphx
)__migraphx__";
struct transposectx_compiler : compiler<transposectx_compiler>
{
std::vector<std::string> names() const { return {"transposectx"}; }
operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
{
hip_compile_options options;
options.set_launch_params(
v, compute_global_for(ctx, inputs.back().elements()), inputs.front().lens().back());
options.output = inputs.back();
options.inputs = inputs;
options.kernel_name = "transposectx_kernel";
return compile_hip_code_object(transposectx_kernel, options);
}
compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
{
return replace(compile_op(ctx, to_shapes(ins->inputs()), op.to_value()));
}
};
static const char* const transposeqkv_kernel = R"__migraphx__(
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/transposeqkv.hpp>
#include <args.hpp>
namespace migraphx {
extern "C" {
__global__ void transposeqkv_kernel(void* input_p, void* output_p)
{
make_tensors()(input_p, output_p)([](auto input, auto output) {
transposeqkv(input, output);
});
}
}
} // namespace migraphx
)__migraphx__";
struct transposeqkv_compiler : compiler<transposeqkv_compiler>
{
std::vector<std::string> names() const { return {"transposeqkv"}; }
operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
{
hip_compile_options options;
options.set_launch_params(
v, compute_global_for(ctx, inputs.back().elements()), inputs.front().lens().back());
options.output = inputs.back();
options.inputs = inputs;
options.kernel_name = "transposeqkv_kernel";
return compile_hip_code_object(transposeqkv_kernel, options);
}
compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
{
return replace(compile_op(ctx, to_shapes(ins->inputs()), op.to_value()));
}
};
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#ifndef MIGRAPHX_GUARD_KERNELS_TRANSPOSECTX_HPP
#define MIGRAPHX_GUARD_KERNELS_TRANSPOSECTX_HPP
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/algorithm.hpp>
namespace migraphx {
template <class T, class U>
__device__ void transposectx(const T& input_t, const U& output_t)
{
// Input: BxNxSxH
// Output: BxSxNxH
auto index = make_index();
auto input_shape = input_t.get_shape();
auto lens = input_shape.lens;
const int num_heads = lens[1];
const int sequence_length = lens[2];
int head_size = lens[3];
auto idx = input_shape.multi(index.global);
int n = idx[1];
int s = idx[2];
int b = idx[0];
const int NH = num_heads * head_size;
const int NHS = NH * sequence_length;
const int out_offset = n * head_size + s * NH + b * NHS;
if(index.global < input_shape.elements())
output_t[out_offset + idx[3]] = input_t[index.global];
}
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_KERNELS_TRANSPOSEQKV_HPP
#define MIGRAPHX_GUARD_KERNELS_TRANSPOSEQKV_HPP
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/algorithm.hpp>
namespace migraphx {
template <class T, class U>
__device__ void transposeqkv(const T& input_t, const U& output_t)
{
// Input: BxSxKxNxH or SxBxKxNxH
// Output: KxBxNxSxH
// K is the number of identical matrix
auto index = make_index();
auto input_shape = input_t.get_shape();
auto lens = input_shape.lens;
auto idx = input_shape.multi(index.global);
const int b = idx[0];
const int s = idx[1];
const int m = idx[2];
const int n = idx[3];
const int num_heads = lens[3];
const int sequence_length = lens[1];
const int batch_size = lens[0];
const int H = lens[4];
const int NH = num_heads * H;
const int NHS = NH * sequence_length;
const int out_offset = s * H + n * sequence_length * H + b * NHS + m * NHS * batch_size;
if(index.global < input_shape.elements())
{
output_t[out_offset + idx[4]] = input_t[index.global];
}
}
} // namespace migraphx
#endif
#include <migraphx/gpu/layernorm.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/device/layernorm.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
shape hip_layernorm::compute_shape(std::vector<shape> inputs) const
{
inputs.pop_back();
return op.normalize_compute_shape(inputs);
}
argument hip_layernorm::compute(context& ctx, const shape&, const std::vector<argument>& args) const
{
device::layernorm(ctx.get_stream().get(), args.back(), args[0]);
return args.back();
}
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
...@@ -11,6 +11,7 @@ ...@@ -11,6 +11,7 @@
#include <migraphx/op/dot.hpp> #include <migraphx/op/dot.hpp>
#include <migraphx/op/elu.hpp> #include <migraphx/op/elu.hpp>
#include <migraphx/op/if_op.hpp> #include <migraphx/op/if_op.hpp>
#include <migraphx/op/layernorm.hpp>
#include <migraphx/op/leaky_relu.hpp> #include <migraphx/op/leaky_relu.hpp>
#include <migraphx/op/lrn.hpp> #include <migraphx/op/lrn.hpp>
#include <migraphx/op/pooling.hpp> #include <migraphx/op/pooling.hpp>
...@@ -29,6 +30,7 @@ ...@@ -29,6 +30,7 @@
#include <migraphx/gpu/gemm.hpp> #include <migraphx/gpu/gemm.hpp>
#include <migraphx/gpu/greater.hpp> #include <migraphx/gpu/greater.hpp>
#include <migraphx/gpu/int8_conv_pack.hpp> #include <migraphx/gpu/int8_conv_pack.hpp>
#include <migraphx/gpu/layernorm.hpp>
#include <migraphx/gpu/leaky_relu.hpp> #include <migraphx/gpu/leaky_relu.hpp>
#include <migraphx/gpu/less.hpp> #include <migraphx/gpu/less.hpp>
#include <migraphx/gpu/logical_and.hpp> #include <migraphx/gpu/logical_and.hpp>
...@@ -139,6 +141,7 @@ struct miopen_apply ...@@ -139,6 +141,7 @@ struct miopen_apply
add_generic_op("exp"); add_generic_op("exp");
add_generic_op("floor"); add_generic_op("floor");
add_generic_op("greater"); add_generic_op("greater");
add_generic_op("layernorm");
add_generic_op("less"); add_generic_op("less");
add_generic_op("log"); add_generic_op("log");
add_generic_op("logical_and"); add_generic_op("logical_and");
......
...@@ -2622,6 +2622,22 @@ def layernorm_test(): ...@@ -2622,6 +2622,22 @@ def layernorm_test():
bias_add], [x, scale, bias], [y], [pow_tensor, epsilon_tensor]) bias_add], [x, scale, bias], [y], [pow_tensor, epsilon_tensor])
@onnx_test
def layernorm_op_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [1, 2, 3])
w = helper.make_tensor_value_info('w', TensorProto.FLOAT, [3])
b = helper.make_tensor_value_info('b', TensorProto.FLOAT, [3])
output = helper.make_tensor_value_info('output', TensorProto.FLOAT,
[1, 2, 3])
node = onnx.helper.make_node('LayerNormalization',
inputs=['x', 'w', 'b'],
outputs=["output"],
epsilon=1e-5)
return ([node], [x, w, b], [output])
@onnx_test @onnx_test
def leaky_relu_test(): def leaky_relu_test():
x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [3]) x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [3])
......
...@@ -446,6 +446,31 @@ TEST_CASE(instance_norm_3d_test) ...@@ -446,6 +446,31 @@ TEST_CASE(instance_norm_3d_test)
EXPECT(migraphx::verify_range(result_vector, gold)); EXPECT(migraphx::verify_range(result_vector, gold));
} }
TEST_CASE(layernorm_op_test)
{
migraphx::program p = migraphx::parse_onnx("layernorm_op_test.onnx");
p.compile(migraphx::ref::target{});
migraphx::shape sx{migraphx::shape::float_type, {1, 2, 3}};
migraphx::shape swb{migraphx::shape::float_type, {3}};
std::vector<float> x_vec{1.0, 2.0, 3.0, 4.0, 5.0, 6.0};
std::vector<float> w_vec{1.0, 1.0, 1.0};
std::vector<float> b_vec{0.0, 0.0, 0.0};
migraphx::parameter_map pp;
pp["x"] = migraphx::argument(sx, x_vec.data());
pp["w"] = migraphx::argument(swb, w_vec.data());
pp["b"] = migraphx::argument(swb, b_vec.data());
auto result = p.eval(pp).back();
std::vector<float> result_vector(6);
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{-1.22474f, 0.0f, 1.22474f, -1.22474f, 0.0f, 1.22474f};
EXPECT(migraphx::verify_range(result_vector, gold));
}
TEST_CASE(lessorequal_test) TEST_CASE(lessorequal_test)
{ {
migraphx::program p = migraphx::parse_onnx("lessorequal_test.onnx"); migraphx::program p = migraphx::parse_onnx("lessorequal_test.onnx");
......
...@@ -2435,6 +2435,50 @@ TEST_CASE(imagescaler_test) ...@@ -2435,6 +2435,50 @@ TEST_CASE(imagescaler_test)
EXPECT(migraphx::verify_range(results_vector, gold)); EXPECT(migraphx::verify_range(results_vector, gold));
} }
TEST_CASE(layernorm_test)
{
{
// with scale and bias
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape sx{migraphx::shape::float_type, {1, 2, 3}};
migraphx::shape swb{migraphx::shape::float_type, {3}};
std::vector<float> x_vec{1.0, 2.0, 3.0, 4.0, 5.0, 6.0};
auto x = mm->add_literal(migraphx::literal{sx, x_vec});
auto w = mm->add_literal(migraphx::literal{swb, {1.0, 1.0, 1.0}});
auto b = mm->add_literal(migraphx::literal{swb, {0.0, 0.0, 0.0}});
mm->add_instruction(migraphx::make_op("layernorm", {{"epsilon", 1e-5}}), x, w, b);
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
std::vector<float> results_vector(1 * 2 * 3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {-1.22474f, 0.0f, 1.22474f, -1.22474f, 0.0f, 1.22474f};
for(auto&& i : results_vector)
std::cout << i << ", ";
std::cout << std::endl;
EXPECT(migraphx::verify_range(results_vector, gold));
}
{
// without scale and bias
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape sx{migraphx::shape::float_type, {1, 2, 3}};
std::vector<float> x_vec{1.0, 2.0, 3.0, 4.0, 5.0, 6.0};
auto x = mm->add_literal(migraphx::literal{sx, x_vec});
mm->add_instruction(migraphx::make_op("layernorm", {{"epsilon", 1e-5}}), x);
p.compile(migraphx::ref::target{});
auto result = p.eval({}).back();
std::vector<float> results_vector(1 * 2 * 3);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> gold = {-1.22474f, 0.0f, 1.22474f, -1.22474f, 0.0f, 1.22474f};
for(auto&& i : results_vector)
std::cout << i << ", ";
std::cout << std::endl;
EXPECT(migraphx::verify_range(results_vector, gold));
}
}
TEST_CASE(leaky_relu_test) TEST_CASE(leaky_relu_test)
{ {
migraphx::program p; migraphx::program p;
......
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct test_layernorm : verify_program<test_layernorm>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 384, 768}});
mm->add_instruction(migraphx::make_op("layernorm", {{"axis", -1}}), x);
return p;
}
};
\ No newline at end of file
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct test_transposectx : verify_program<test_transposectx>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 12, 128, 64}});
mm->add_instruction(migraphx::make_op("transposectx"), x);
return p;
}
};
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct test_transposeqkv : verify_program<test_transposeqkv>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {2, 384, 3, 12, 64}});
mm->add_instruction(migraphx::make_op("transposeqkv"), x);
return p;
}
};
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment