Unverified commit 2c5d5fee authored by kahmed10, committed by GitHub

Layernorm onnx support (#599)



* fix pad calc

* bert tf passes correctness

* formatting

* add test

* formatting

* remove comment

* add inline

* formatting

* fix order for literal

* formatting

* test no mul_add

* formatting

* debug layernorm

* debug layernorm

* manual merge

* more progress

* formatting

* remove miopen batchnorm

* remove headers

* Fix compile error with no dpp reductions

* fix indices

* formatting

* change matcher

* formatting

* remove binds

* formatting

* disable tf matcher

* formatting

* use fast div

* formatting

* fix matcher

* formatting

* remove comment

* move find_matches

* add assert

* formatting

* fix deepcode issue
Co-authored-by: Paul <pfultz2@yahoo.com>
Co-authored-by: Shucai Xiao <shucai.xiao@amd.com>
Co-authored-by: mvermeulen <5479696+mvermeulen@users.noreply.github.com>
parent d612e976
@@ -38,6 +38,7 @@ add_library(migraphx_device
     device/gather.cpp
     device/gelu.cpp
     device/int8_gemm_pack.cpp
+    device/layernorm.cpp
     device/log.cpp
     device/logsoftmax.cpp
     device/max.cpp
@@ -242,6 +242,23 @@ constexpr index_int compute_block_size(index_int n, index_int max_block_size)
     return block_size;
 }
 
+inline std::vector<index_int> get_reduce_lens(const std::vector<size_t>& input_lens,
+                                              const std::vector<size_t>& output_lens)
+{
+    std::vector<index_int> reduce_lens;
+    std::transform(output_lens.begin(),
+                   output_lens.end(),
+                   input_lens.begin(),
+                   std::back_inserter(reduce_lens),
+                   [](auto x, auto y) -> index_int {
+                       if(x == y)
+                           return 1;
+                       else
+                           return y;
+                   });
+    return reduce_lens;
+}
+
 template <class Op, class T, class Input, class Output>
 void reduce_multi_impl(hipStream_t stream,
                        const argument& result,
@@ -309,29 +326,19 @@ void reduce(hipStream_t stream,
 {
     auto&& output_shape = result.get_shape();
     auto&& input_shape  = arg.get_shape();
-    assert(output_shape.lens().size() == input_shape.lens().size());
+    auto input_lens     = input_shape.lens();
+    auto output_lens    = output_shape.lens();
+    assert(output_lens.size() == input_lens.size());
     if(input_shape.standard() and output_shape.standard() and
-       output_shape.lens().back() != input_shape.lens().back() and
-       std::equal(output_shape.lens().begin(),
-                  std::prev(output_shape.lens().end()),
-                  input_shape.lens().begin()))
+       output_lens.back() != input_lens.back() and
+       std::equal(output_lens.begin(), std::prev(output_lens.end()), input_lens.begin()))
     {
         reduce_standard_impl(
-            stream, result, arg, op, init, read_input, read_output, input_shape.lens().back());
+            stream, result, arg, op, init, read_input, read_output, input_lens.back());
     }
     else
     {
-        std::vector<index_int> reduce_lens;
-        std::transform(output_shape.lens().begin(),
-                       output_shape.lens().end(),
-                       input_shape.lens().begin(),
-                       std::back_inserter(reduce_lens),
-                       [](auto x, auto y) -> index_int {
-                           if(x == y)
-                               return 1;
-                           else
-                               return y;
-                       });
+        std::vector<index_int> reduce_lens = get_reduce_lens(input_lens, output_lens);
         shape reduce_slice{output_shape.type(), reduce_lens};
         reduce_multi_impl(stream, result, arg, op, init, read_input, read_output, reduce_slice);
     }
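For reference, the helper hoisted out here maps each dimension to 1 when the input and output extents match, and to the input extent when they differ, i.e. it yields the per-output-element slice shape that reduce_multi_impl iterates over. A minimal host-only sketch of the same mapping, not part of the commit (reduce_lens_sketch is an illustrative name, and plain size_t stands in for index_int):

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <iterator>
#include <vector>

// Same mapping as get_reduce_lens above: 1 for kept dimensions,
// the input extent for reduced dimensions.
std::vector<std::size_t> reduce_lens_sketch(const std::vector<std::size_t>& in,
                                            const std::vector<std::size_t>& out)
{
    std::vector<std::size_t> result;
    std::transform(out.begin(),
                   out.end(),
                   in.begin(),
                   std::back_inserter(result),
                   [](std::size_t o, std::size_t i) { return o == i ? std::size_t{1} : i; });
    return result;
}

int main()
{
    // Reducing a {2, 3, 4} tensor over its last axis gives output lens {2, 3, 1};
    // the reduce slice seen by each output element is then {1, 1, 4}.
    assert((reduce_lens_sketch({2, 3, 4}, {2, 3, 1}) == std::vector<std::size_t>{1, 1, 4}));
}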
#include <migraphx/gpu/device/layernorm.hpp>
#include <migraphx/gpu/device/reduce.hpp>
#include <migraphx/gpu/device/pow.hpp>
#include <migraphx/gpu/device/fast_div.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {

// m = x - mean(x)
// m / sqrt(mean(m ^ 2) + 1e-12)
void layernorm(hipStream_t stream, const argument& result, const argument& arg1)
{
    // Normalization runs over the last dimension; one block handles one row.
    auto relements = arg1.get_shape().lens().back();
    assert(relements <= 1024);
    auto nelements    = result.get_shape().elements() / relements;
    auto input_shape  = arg1.get_shape();
    auto output_shape = result.get_shape();
    auto reduce_output_lens(output_shape.lens());
    reduce_output_lens.back() = 1;
    std::vector<index_int> reduce_lens = get_reduce_lens(input_shape.lens(), reduce_output_lens);

    hip_visit_all(result, arg1)([&](auto output, auto input) {
        using value_type                 = typename decltype(input)::value_type;
        const std::size_t max_block_size = 256;
        const std::size_t block_size     = compute_block_size(relements, max_block_size);
        const std::size_t block_size_div = encode_divisor(block_size);
        gs_launch(stream, nelements * block_size, block_size)([=](auto i, auto idx) __device__ {
            const auto out_idx  = i / block_size;
            const auto base_idx = out_idx * relements;
            // Each thread caches its slice of the row in registers; with
            // relements <= 1024 and up to 256 lanes this is at most 4 values.
            value_type x_data[4];
            // j - idx.local is a multiple of block_size, so the fast division
            // recovers the register slot holding element j.
            auto x = [&](auto j) -> value_type& {
                return x_data[fast_div(j - idx.local, block_size_div)];
            };
            idx.local_stride(relements,
                             [&](auto j) __device__ { x(j) = input.data()[base_idx + j]; });
            // m = mean(x)
            auto m = block_reduce<max_block_size>(
                         idx, sum{}, 0, relements, [&](auto j) __device__ { return x(j); }) /
                     relements;
            idx.local_stride(relements, [&](auto j) __device__ { x(j) = x(j) - m; });
            // r = mean((x - m)^2), the biased variance
            auto r = block_reduce<max_block_size>(
                         idx, sum{}, 0, relements, [&](auto j) __device__ { return x(j) * x(j); }) /
                     relements;
            idx.local_stride(relements, [&](auto j) __device__ {
                output.data()[base_idx + j] = x(j) * ::rsqrt(r + 1e-12);
            });
        });
    });
}

} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
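For reference, the kernel computes per row of length relements: m = x - mean(x), then m / sqrt(mean(m^2) + 1e-12), which is a layernorm without scale and bias (those remain separate mul/add instructions, as the matcher below shows). A plain-CPU sketch of the same arithmetic for cross-checking, not part of the commit (layernorm_reference, rows, and cols are illustrative names):

#include <cmath>
#include <cstddef>
#include <vector>

// Reference for the per-row normalization the kernel performs:
// m = x - mean(x);  y = m / sqrt(mean(m^2) + 1e-12)
std::vector<float> layernorm_reference(const std::vector<float>& x,
                                       std::size_t rows,
                                       std::size_t cols)
{
    std::vector<float> y(x.size());
    for(std::size_t r = 0; r < rows; r++)
    {
        const float* row = x.data() + r * cols;
        float mean = 0.0f;
        for(std::size_t j = 0; j < cols; j++)
            mean += row[j];
        mean /= cols;

        float var = 0.0f;
        for(std::size_t j = 0; j < cols; j++)
            var += (row[j] - mean) * (row[j] - mean);
        var /= cols; // biased variance, matching mean(m^2) in the kernel

        const float inv = 1.0f / std::sqrt(var + 1e-12f);
        for(std::size_t j = 0; j < cols; j++)
            y[r * cols + j] = (row[j] - mean) * inv;
    }
    return y;
}

int main()
{
    // One row of five values, mirroring the {1, 1, 5} shape used by the tests.
    auto y = layernorm_reference({1.0f, 2.0f, 3.0f, 4.0f, 5.0f}, 1, 5);
    return y.size() == 5 ? 0 : 1;
}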
@@ -6,6 +6,9 @@
 #include <migraphx/gpu/clip.hpp>
 #include <migraphx/gpu/convolution.hpp>
 #include <migraphx/gpu/oper.hpp>
+#include <migraphx/gpu/add.hpp>
+#include <migraphx/gpu/mul.hpp>
+#include <migraphx/gpu/device/layernorm.hpp>
 #include <migraphx/gpu/device/gelu.hpp>
 #include <migraphx/gpu/device/mul_add.hpp>
 #include <migraphx/gpu/device/add_clip.hpp>
@@ -205,6 +208,10 @@ struct hip_add_tanh : binary_device<hip_add_tanh, &device::add_tanh>
 {
 };
 
+struct hip_layernorm : unary_device<hip_layernorm, &device::layernorm>
+{
+};
+
 struct hip_gelu : unary_device<hip_gelu, &device::gelu>
 {
 };
@@ -249,6 +256,50 @@ void move_standard_front(std::vector<instruction_ref>& args)
     std::swap(*it, args.front());
 }
 
+struct find_layernorm
+{
+    template <class... Ts>
+    static auto multibroadcast_op(Ts... xs)
+    {
+        return match::name("multibroadcast")(match::arg(0)(xs...));
+    }
+
+    static auto x_minus_mean()
+    {
+        return match::name("gpu::sub")(
+            match::arg(0)(match::any().bind("x")),
+            match::arg(1)(multibroadcast_op(match::name("gpu::reduce_mean"))));
+    }
+
+    static auto variance()
+    {
+        return match::name("gpu::reduce_mean")(match::arg(0)(
+            match::name("gpu::pow")(match::arg(0)(x_minus_mean()),
+                                    match::arg(1)(multibroadcast_op(match::has_value(2.0f))))));
+    }
+
+    static auto layernorm_onnx()
+    {
+        return match::name("gpu::div")(
+            match::arg(0)(x_minus_mean()),
+            match::arg(1)(multibroadcast_op(
+                match::name("gpu::sqrt")(match::arg(0)(match::name("gpu::add")(match::either_arg(
+                    0, 1)(variance(), multibroadcast_op(match::has_value(1e-12f)))))))));
+    }
+
+    auto matcher() const { return layernorm_onnx(); }
+
+    void apply(program& p, match::matcher_result r) const
+    {
+        auto ins   = r.result;
+        auto x_ins = r.instructions["x"];
+        auto args  = ins->inputs();
+        p.replace_instruction(ins, hip_layernorm{}, x_ins, args.back());
+    }
+};
+
 struct find_gelu
 {
@@ -658,6 +709,7 @@ void fuse_ops::apply(program& p) const
     run_passes(p, {dead_code_elimination{}});
     match::find_matches(p, find_triadd{});
     match::find_matches(p,
+                        find_layernorm{},
                         find_conv_bias_relu{ctx},
                         find_conv_bias{ctx},
                         find_add_gelu{},
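A note on the matcher: it runs during fuse_ops, after lowering, so it matches the GPU-lowered names (gpu::sub, gpu::reduce_mean, gpu::pow, gpu::div), and match::has_value(1e-12f) ties the fusion to the same epsilon that is hard-coded in the device kernel; a graph built with a different epsilon simply will not be fused. The scale and bias of a full layernorm sit outside the matched subgraph, so they remain as ordinary mul/add instructions after the replacement, which passes along only the bound input x and args.back(), the output allocation that lowered GPU ops take as their last argument.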
#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_LAYERNORM_HPP
#define MIGRAPHX_GUARD_RTGLIB_DEVICE_LAYERNORM_HPP

#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
#include <hip/hip_runtime_api.h>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {

void layernorm(hipStream_t stream, const argument& result, const argument& arg1);

} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

#endif
@@ -1028,6 +1028,37 @@ struct test_triadd_tanh : verify_program<test_triadd_tanh>
     }
 };
 
+struct test_layernorm : verify_program<test_layernorm>
+{
+    migraphx::program create_program() const
+    {
+        migraphx::program p;
+        std::vector<size_t> dims{1, 1, 5};
+        auto x     = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, dims});
+        auto scale = p.add_parameter("scale", migraphx::shape{migraphx::shape::float_type, {5}});
+        auto bias  = p.add_parameter("bias", migraphx::shape{migraphx::shape::float_type, {5}});
+        auto epsilon  = p.add_literal(1e-12f);
+        auto exponent = p.add_literal(migraphx::literal{
+            migraphx::shape{migraphx::shape::float_type, {1, 1, 5}}, {2, 2, 2, 2, 2}});
+        auto mean        = p.add_instruction(migraphx::op::reduce_mean({2}), x);
+        auto mean_mbcast = p.add_instruction(migraphx::op::multibroadcast{dims}, mean);
+        auto sub         = p.add_instruction(migraphx::op::sub{}, x, mean_mbcast);
+        auto pow         = p.add_instruction(migraphx::op::pow{}, sub, exponent);
+        auto var         = p.add_instruction(migraphx::op::reduce_mean({2}), pow);
+        auto epsilon_mbcast = p.add_instruction(migraphx::op::multibroadcast{{1, 1, 1}}, epsilon);
+        auto add_epsilon = p.add_instruction(migraphx::op::add{}, var, epsilon_mbcast);
+        auto sqrt        = p.add_instruction(migraphx::op::sqrt{}, add_epsilon);
+        auto sqrt_mbcast = p.add_instruction(migraphx::op::multibroadcast{dims}, sqrt);
+        auto div         = p.add_instruction(migraphx::op::div{}, sub, sqrt_mbcast);
+        auto scale_mbcast = p.add_instruction(migraphx::op::multibroadcast{dims}, scale);
+        auto mul          = p.add_instruction(migraphx::op::mul{}, scale_mbcast, div);
+        auto bias_mbcast  = p.add_instruction(migraphx::op::multibroadcast{dims}, bias);
+        p.add_instruction(migraphx::op::add{}, mul, bias_mbcast);
+        return p;
+    }
+};
+
 struct test_sigmoid : verify_program<test_sigmoid>
 {
     migraphx::program create_program() const
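test_layernorm rebuilds, in MIGraphX IR, the same graph the ONNX importer produces for this pattern (see layernorm_test below); under the verify harness, which runs a program on both the reference and GPU targets and compares the results, it exercises find_layernorm and the fused hip_layernorm kernel end to end.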
@@ -1476,6 +1476,69 @@ def instance_norm_val_test():
     return ([node], [], [y], [x_tensor, scale_tensor, bias_tensor])
 
 
+@onnx_test
+def layernorm_test():
+    x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [1, 1, 5])
+    y = helper.make_tensor_value_info('1', TensorProto.FLOAT, [1, 1, 5])
+    scale = helper.make_tensor_value_info('scale', TensorProto.FLOAT, [5])
+    bias = helper.make_tensor_value_info('bias', TensorProto.FLOAT, [5])
+    axes = [2]
+    pow_2 = np.array([[[2, 2, 2, 2, 2]]])
+    epsilon = np.array([1e-12])
+    pow_tensor = helper.make_tensor(name='pow',
+                                    data_type=TensorProto.FLOAT,
+                                    dims=pow_2.shape,
+                                    vals=pow_2.flatten().astype(np.float32))
+    epsilon_tensor = helper.make_tensor(name='epsilon',
+                                        data_type=TensorProto.FLOAT,
+                                        dims=epsilon.shape,
+                                        vals=epsilon.flatten().astype(np.float32))
+    mean = onnx.helper.make_node('ReduceMean',
+                                 inputs=['0'],
+                                 outputs=['mean_out'],
+                                 axes=axes)
+    sub_mean = onnx.helper.make_node('Sub',
+                                     inputs=['0', 'mean_out'],
+                                     outputs=['sub_out'])
+    sub_pow = onnx.helper.make_node('Pow',
+                                    inputs=['sub_out', 'pow'],
+                                    outputs=['pow_out'])
+    var = onnx.helper.make_node('ReduceMean',
+                                inputs=['pow_out'],
+                                outputs=['var_out'],
+                                axes=axes)
+    add = onnx.helper.make_node('Add',
+                                inputs=['var_out', 'epsilon'],
+                                outputs=['add_out'])
+    sqrt = onnx.helper.make_node('Sqrt',
+                                 inputs=['add_out'],
+                                 outputs=['sqrt_out'])
+    div = onnx.helper.make_node('Div',
+                                inputs=['sub_out', 'sqrt_out'],
+                                outputs=['div_out'])
+    mul = onnx.helper.make_node('Mul',
+                                inputs=['scale', 'div_out'],
+                                outputs=['mul_out'])
+    bias_add = onnx.helper.make_node('Add',
+                                     inputs=['mul_out', 'bias'],
+                                     outputs=['1'])
+    return ([mean, sub_mean, sub_pow, var, add, sqrt, div, mul,
+             bias_add], [x, scale, bias], [y], [pow_tensor, epsilon_tensor])
+
+
 @onnx_test
 def leaky_relu_test():
     x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [3])