Unverified commit 3becd974, authored by Paul Fultz II, committed by GitHub

Reduce types generated for hip kernels (#814)

* Remove unused data types
* Formatting
* Reduce types generated for hip kernels
* Formatting
* Fix onnx tests
* Formatting
Co-authored-by: mvermeulen <5479696+mvermeulen@users.noreply.github.com>
parent 40bab788
@@ -12,35 +12,44 @@
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
 namespace {
+struct alpha_beta
+{
+    float alpha = 0.0;
+    float beta  = 0.0;
+};
+
+alpha_beta get_alpha_beta(const operation& op)
+{
+    auto v = op.to_value();
+    return {v.at("alpha").to<float>(), v.at("beta").to<float>()};
+}
+
 struct find_dot_add
 {
-    auto matcher() const { return match::name("dot")(match::nargs(3)); }
+    auto matcher() const { return match::name("dot", "quant_dot")(match::nargs(3)); }

     void apply(module& p, const match::matcher_result& r) const
     {
         auto ins = r.result;
-        auto dot = any_cast<op::dot>(ins->get_operator());
-        if(not float_equal(dot.beta, 1) and
-           not contains({shape::float_type, shape::half_type, shape::double_type},
-                        ins->get_shape().type()))
-            return;
+        auto dot = get_alpha_beta(ins->get_operator());
         auto a_ins = ins->inputs()[0];
         auto b_ins = ins->inputs()[1];
         if(not float_equal(dot.alpha, 1))
         {
-            auto alpha = p.add_literal(literal{shape{ins->get_shape().type()}, {dot.alpha}});
+            auto alpha = p.add_literal(literal{shape{a_ins->get_shape().type()}, {dot.alpha}});
             auto alpha_broadcast = p.insert_instruction(
                 ins,
                 make_op("multibroadcast", {{"output_lens", a_ins->get_shape().lens()}}),
                 alpha);
             a_ins = p.insert_instruction(ins, make_op("mul"), a_ins, alpha_broadcast);
         }
-        auto dot_ins = p.insert_instruction(ins, make_op("dot", {{"beta", 0}}), a_ins, b_ins);
+        auto dot_ins = p.insert_instruction(ins, make_op(ins->name(), {{"beta", 0}}), a_ins, b_ins);
         auto c_ins = ins->inputs()[2];
         if(not float_equal(dot.beta, 1))
         {
-            auto beta = p.add_literal(literal{shape{ins->get_shape().type()}, {dot.beta}});
+            auto beta = p.add_literal(literal{shape{c_ins->get_shape().type()}, {dot.beta}});
             auto beta_broadcast = p.insert_instruction(
                 ins, make_op("multibroadcast", {{"output_lens", ins->get_shape().lens()}}), beta);
             c_ins = p.insert_instruction(ins, make_op("mul"), c_ins, beta_broadcast);
@@ -51,24 +60,24 @@ struct find_dot_add
 struct find_dot_alpha
 {
-    auto matcher() const { return match::name("dot")(match::nargs(2)); }
+    auto matcher() const { return match::name("dot", "quant_dot")(match::nargs(2)); }

     void apply(module& p, const match::matcher_result& r) const
     {
         auto ins = r.result;
-        auto dot = any_cast<op::dot>(ins->get_operator());
+        auto dot = get_alpha_beta(ins->get_operator());
         auto a_ins = ins->inputs()[0];
         auto b_ins = ins->inputs()[1];
         if(not float_equal(dot.alpha, 1))
         {
-            auto alpha = p.add_literal(literal{shape{ins->get_shape().type()}, {dot.alpha}});
+            auto alpha = p.add_literal(literal{shape{a_ins->get_shape().type()}, {dot.alpha}});
             auto alpha_broadcast = p.insert_instruction(
                 ins,
                 make_op("multibroadcast", {{"output_lens", a_ins->get_shape().lens()}}),
                 alpha);
             a_ins = p.insert_instruction(ins, make_op("mul"), a_ins, alpha_broadcast);
         }
-        p.replace_instruction(ins, make_op("dot", {{"beta", 0}}), a_ins, b_ins);
+        p.replace_instruction(ins, make_op(ins->name(), {{"beta", 0}}), a_ins, b_ins);
     }
 };
...
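For context, this rewrite relies on the GEMM identity dot(A, B, C; alpha, beta) = (alpha * A) x B + beta * C: scale A up front, emit the dot with beta = 0, and add the scaled C explicitly, so the GPU kernel never needs alpha/beta handling. A minimal standalone sketch of that equivalence (plain C++, not the MIGraphX API; the 2x2 values are illustrative and chosen so float arithmetic is exact):

#include <array>
#include <cassert>

// dot(A, B, C; alpha, beta) == (alpha * A) x B + beta * C, 2x2 row-major.
int main()
{
    std::array<float, 4> a{1, 2, 3, 4}, b{5, 6, 7, 8}, c{1, 1, 1, 1};
    const float alpha = 2.0f, beta = 0.5f;
    std::array<float, 4> fused{}, rewritten{};
    for(int i = 0; i < 2; i++)
        for(int j = 0; j < 2; j++)
        {
            float acc = 0, scaled_acc = 0;
            for(int k = 0; k < 2; k++)
            {
                acc += a[i * 2 + k] * b[k * 2 + j];                  // plain A x B
                scaled_acc += (alpha * a[i * 2 + k]) * b[k * 2 + j]; // pre-scaled A
            }
            fused[i * 2 + j]     = alpha * acc + beta * c[i * 2 + j]; // original fused dot
            rewritten[i * 2 + j] = scaled_acc + beta * c[i * 2 + j];  // mul + beta=0 dot + add
        }
    assert(fused == rewritten);
}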
@@ -13,6 +13,8 @@ void eliminate_data_type::apply(module& m) const
 {
     if(ins->name()[0] == '@')
         continue;
+    if(ins->name() == "convert")
+        continue;
     auto inputs = ins->inputs();
     std::transform(inputs.begin(), inputs.end(), inputs.begin(), [&](auto i) {
         if(types.count(i->get_shape().type()) == 0)
...
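The new convert guard matters because this pass is what inserts convert instructions in the first place; revisiting them would wrap the converts it just created, and the rewrite would never reach a fixed point. A toy model of that behavior (hypothetical names and a simplified instruction representation, not the real pass):

#include <cassert>
#include <set>
#include <string>
#include <utility>
#include <vector>

// Toy model of the pass: an instruction is (name, type). Ops with an
// unsupported type are bracketed by converts and computed in float.
using ins = std::pair<std::string, std::string>;

std::vector<ins> run_pass(const std::vector<ins>& prog, const std::set<std::string>& unsupported)
{
    std::vector<ins> out;
    for(const auto& i : prog)
    {
        // The two guards from the diff: builtins and converts are left alone.
        if(i.first[0] == '@' or i.first == "convert")
        {
            out.push_back(i);
            continue;
        }
        if(unsupported.count(i.second) > 0)
        {
            out.push_back({"convert", "float"});  // convert inputs to float
            out.push_back({i.first, "float"});    // compute in float
            out.push_back({"convert", i.second}); // convert the result back
        }
        else
            out.push_back(i);
    }
    return out;
}

int main()
{
    std::vector<ins> prog{{"@param", "int32"}, {"add", "int32"}};
    auto once  = run_pass(prog, {"int32"});
    auto twice = run_pass(once, {"int32"});
    // Without the convert guard, the trailing {"convert", "int32"} would be
    // wrapped again on every application instead of reaching a fixed point.
    assert(once == twice);
}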
@@ -7,9 +7,33 @@ inline namespace MIGRAPHX_INLINE_NS {
 namespace gpu {
 namespace device {

+void contiguous_nonstandard(hipStream_t stream, const argument& result, const argument& arg)
+{
+    shape s{result.get_shape().type(), result.get_shape().lens()};
+    visit_all(result, arg)([&](auto output_v, auto input_v) {
+        hip_visit_views(output_v, input_v, s)([&](auto output, auto input, auto standard_shape) {
+            mi_gs_launch(stream,
+                         standard_shape)([=](auto idx) __device__ { output[idx] = input[idx]; });
+        });
+    });
+}
+
+void contiguous_packed(hipStream_t stream, const argument& result, const argument& arg)
+{
+    index_int nelements = result.get_shape().elements();
+    visit_all(result, arg)([&](auto output_v, auto input_v) {
+        const auto* input = device_cast(input_v.data());
+        auto* output      = device_cast(output_v.data());
+        gs_launch(stream, nelements)([=](auto i) __device__ { output[i] = input[i]; });
+    });
+}
+
 void contiguous(hipStream_t stream, const argument& result, const argument& arg)
 {
-    nary(stream, result, arg)([](auto x) __device__ { return x; });
+    if(result.get_shape() == arg.get_shape() and result.get_shape().packed())
+        contiguous_packed(stream, result, arg);
+    else
+        contiguous_nonstandard(stream, result, arg);
 }

 } // namespace device
...
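The dispatch above sends same-shape packed tensors down a flat 1-D copy and everything else down the strided, multi-index path. A simplified sketch of the layout test this relies on (stricter than migraphx::shape::packed(), which also accepts permuted dense strides; shown here as a standard row-major check):

#include <cassert>
#include <cstddef>
#include <vector>

// Dense row-major check: stride[i] must equal the product of the lengths
// to its right, so elements occupy one contiguous, ordered block.
bool is_standard(const std::vector<std::size_t>& lens, const std::vector<std::size_t>& strides)
{
    std::size_t expected = 1;
    for(std::size_t i = lens.size(); i > 0; i--)
    {
        if(strides[i - 1] != expected)
            return false;
        expected *= lens[i - 1];
    }
    return true;
}

int main()
{
    assert(is_standard({2, 3}, {3, 1}));  // flat 1-D copy is valid
    assert(!is_standard({2, 3}, {1, 2})); // transposed view: needs the strided path
}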
@@ -51,6 +51,50 @@ auto get_shape(const T& x) -> decltype(x.get_shape())
     return x.get_shape();
 }

+template <class T>
+struct is_hip_type : std::false_type
+{
+};
+template <>
+struct is_hip_type<float> : std::true_type
+{
+};
+template <>
+struct is_hip_type<half> : std::true_type
+{
+};
+template <>
+struct is_hip_type<bool> : std::true_type
+{
+};
+template <>
+struct is_hip_type<std::int8_t> : std::true_type
+{
+};
+template <>
+struct is_hip_type<std::uint8_t> : std::true_type
+{
+};
+
+template <class T, class V, MIGRAPHX_REQUIRES(is_hip_type<typename T::type>{})>
+void hip_visitor_invoke(T as, V&& v)
+{
+    v(as);
+}
+
+template <class T, class V, MIGRAPHX_REQUIRES(not is_hip_type<typename T::type>{})>
+void hip_visitor_invoke(T, V&&)
+{
+    MIGRAPHX_THROW(std::string("Unsupported data type on GPU: ") + __PRETTY_FUNCTION__);
+}
+
+template <class V>
+auto hip_visitor(V v)
+{
+    return [=](auto as) { hip_visitor_invoke(as, v); };
+}
+
 template <class V, class F, class... Ts>
 void hip_visit_all_impl(const shape& s, F f, V&& v, Ts&&... xs)
 {
@@ -62,8 +106,9 @@ void hip_visit_all_impl(const shape& s, F f, V&& v, Ts&&... xs)
                                       static_cast<index_int>(get_shape(xs).lens().size())...};
     if(!std::all_of(ranks.begin(), ranks.end(), [&](index_int r) { return r == s.lens().size(); }))
         MIGRAPHX_THROW("Ranks must be the same");
-    visit_tensor_size(s.lens().size(),
-                      [&](auto ndim) { s.visit_type([&](auto as) { v(f(xs, ndim, as)...); }); });
+    visit_tensor_size(s.lens().size(), [&](auto ndim) {
+        s.visit_type(hip_visitor([&](auto as) { v(f(xs, ndim, as)...); }));
+    });
 }

 template <class V, class F, class... Ts>
...
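The trait plus the two MIGRAPHX_REQUIRES overloads form a whitelist: visitor bodies are only instantiated, and therefore device code only generated, for the five listed types, while anything else fails at runtime with a clear error. A standalone sketch of the same pattern using if constexpr in place of the SFINAE overloads (illustrative, not the MIGraphX headers):

#include <cstdint>
#include <iostream>
#include <stdexcept>
#include <type_traits>

// Whitelist trait: only these types get visitor bodies instantiated.
template <class T>
struct is_hip_type : std::false_type {};
template <>
struct is_hip_type<float> : std::true_type {};
template <>
struct is_hip_type<std::int8_t> : std::true_type {};

// Dispatch on the trait; for non-whitelisted types the visitor call is
// discarded at compile time, so no code is generated for them.
template <class T, class V>
void visit_type(V&& v)
{
    if constexpr(is_hip_type<T>{})
        v(T{});
    else
        throw std::runtime_error("Unsupported data type on GPU");
}

int main()
{
    visit_type<float>([](auto x) { std::cout << sizeof(x) << "-byte type visited\n"; });
    try
    {
        visit_type<double>([](auto) {}); // double is not whitelisted: throws
    }
    catch(const std::exception& e)
    {
        std::cout << e.what() << "\n";
    }
}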
@@ -7,6 +7,7 @@
 #include <migraphx/eliminate_common_subexpression.hpp>
 #include <migraphx/eliminate_concat.hpp>
 #include <migraphx/eliminate_contiguous.hpp>
+#include <migraphx/eliminate_data_type.hpp>
 #include <migraphx/eliminate_identity.hpp>
 #include <migraphx/eliminate_pad.hpp>
 #include <migraphx/memory_coloring.hpp>
@@ -43,12 +44,19 @@ MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_SCHEDULE_PASS)
 std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_options& options) const
 {
     auto& ctx = any_cast<context>(gctx);
+    std::set<shape::type_t> unsupported_types(shape::types().begin(), shape::types().end());
+    unsupported_types.erase(shape::type_t::float_type);
+    unsupported_types.erase(shape::type_t::half_type);
+    unsupported_types.erase(shape::type_t::bool_type);
+    unsupported_types.erase(shape::type_t::int8_type);
+    unsupported_types.erase(shape::type_t::uint8_type);
     // clang-format off
     return
     {
         normalize_ops{},
         decompose{},
         dead_code_elimination{},
+        eliminate_data_type{unsupported_types, shape::type_t::float_type},
         simplify_reshapes{},
         eliminate_identity{},
         eliminate_pad{},
...
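The unsupported set is built by subtraction: start from every type MIGraphX knows and erase the five that HIP kernels are generated for; eliminate_data_type then rewrites the remainder to float before GPU lowering. A small sketch of that set-difference logic with illustrative type names (the real list comes from shape::types()):

#include <iostream>
#include <set>
#include <string>

// Everything minus the HIP-supported types is what gets converted to float.
int main()
{
    std::set<std::string> unsupported{"bool",  "half",  "float", "double", "int8",
                                      "uint8", "int16", "int32", "int64"};
    for(const auto& t : {"float", "half", "bool", "int8", "uint8"})
        unsupported.erase(t);
    for(const auto& t : unsupported)
        std::cout << t << " -> float\n"; // rewritten before GPU lowering
}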
@@ -128,8 +128,21 @@ TEST_CASE(dot_add_beta_int)
         m1.add_instruction(migraphx::make_op("dot", {{"alpha", 1.0}, {"beta", 0.5}}), x, y, z);
         m1.add_instruction(migraphx::make_op("identity"), dot);
     }
-    migraphx::module m2 = m1;
     run_pass(m1);
+    migraphx::module m2;
+    {
+        auto x = m2.add_parameter("x", migraphx::shape{migraphx::shape::int32_type, {2, 2}});
+        auto y = m2.add_parameter("y", migraphx::shape{migraphx::shape::int32_type, {2, 2}});
+        auto z = m2.add_parameter("z", migraphx::shape{migraphx::shape::int32_type, {2, 2}});
+        auto dot = m2.add_instruction(migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), x, y);
+        auto beta =
+            m2.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::int32_type}, {0.5}});
+        auto beta_broadcast = m2.add_instruction(
+            migraphx::make_op("multibroadcast", {{"output_lens", {2, 2}}}), beta);
+        auto mul = m2.add_instruction(migraphx::make_op("mul"), z, beta_broadcast);
+        auto add = m2.add_instruction(migraphx::make_op("add"), dot, mul);
+        m2.add_instruction(migraphx::make_op("identity"), add);
+    }
     EXPECT(m1 == m2);
 }
...
@@ -14,7 +14,7 @@ const std::string write_2s = R"__migraphx__(
 #include <hip/hip_runtime.h>

 extern "C" {
-__global__ void write(int* data)
+__global__ void write(int8_t* data)
 {
     int num = threadIdx.x + blockDim.x * blockIdx.x;
     data[num] = 2;
@@ -31,7 +31,7 @@ const std::string add_2s_binary = R"__migraphx__(
 #include <hip/hip_runtime.h>

 extern "C" {
-__global__ void add_2(std::int32_t* x, std::int32_t* y)
+__global__ void add_2(std::int8_t* x, std::int8_t* y)
 {
     int num = threadIdx.x + blockDim.x * blockIdx.x;
     y[num] = x[num] + 2;
@@ -89,14 +89,14 @@ TEST_CASE(simple_compile_hip)
         {make_src_file("main.cpp", write_2s)}, "", get_device_name());
     EXPECT(binaries.size() == 1);

-    migraphx::argument input{{migraphx::shape::int32_type, {5}}};
+    migraphx::argument input{{migraphx::shape::int8_type, {5}}};
     auto ginput = migraphx::gpu::to_gpu(input);
     migraphx::gpu::kernel k{binaries.front(), "write"};
-    k.launch(nullptr, input.get_shape().elements(), 1024)(ginput.cast<int>());
+    k.launch(nullptr, input.get_shape().elements(), 1024)(ginput.cast<std::int8_t>());
     auto output = migraphx::gpu::from_gpu(ginput);

     EXPECT(output != input);
-    auto data = output.get<int>();
+    auto data = output.get<std::int8_t>();
     EXPECT(migraphx::all_of(data, [](auto x) { return x == 2; }));
 }
@@ -106,7 +106,7 @@ TEST_CASE(code_object_hip)
         {make_src_file("main.cpp", add_2s_binary)}, "", get_device_name());
     EXPECT(binaries.size() == 1);

-    migraphx::shape input{migraphx::shape::int32_type, {5}};
+    migraphx::shape input{migraphx::shape::int8_type, {5}};
     std::vector<migraphx::shape> expected_inputs = {input, input};
     auto co = migraphx::make_op("gpu::code_object",
...
@@ -37,7 +37,8 @@ inline void compile_check(migraphx::program& p, const migraphx::target& t, bool
     auto shapes = p.get_output_shapes();
     std::stringstream ss;
     migraphx::compile_options options;
-    options.trace = migraphx::tracer{ss};
+    if(show_trace)
+        options.trace = migraphx::tracer{std::cout};
     p.compile(t, options);
     if(shapes.size() != p.get_output_shapes().size())
     {
@@ -55,11 +56,6 @@ inline void compile_check(migraphx::program& p, const migraphx::target& t, bool
         throw std::runtime_error("Compiling program with " + name + " alters its shape");
     }
-    if(show_trace)
-    {
-        std::cout << ss.str() << std::endl;
-    }
 }

 target_info run_verify::get_target_info(const std::string& name) const
...