Unverified commit 3becd974 authored by Paul Fultz II, committed by GitHub

Reduce types generated for hip kernels (#814)



* Remove unused data types

* Formatting

* Reduce types generated for hip kernels

* Formatting

* Fix onnx tests

* Formatting
Co-authored-by: mvermeulen <5479696+mvermeulen@users.noreply.github.com>
parent 40bab788
@@ -12,35 +12,44 @@
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
 namespace {
+struct alpha_beta
+{
+    float alpha = 0.0;
+    float beta = 0.0;
+};
+alpha_beta get_alpha_beta(const operation& op)
+{
+    auto v = op.to_value();
+    return {v.at("alpha").to<float>(), v.at("beta").to<float>()};
+}
 struct find_dot_add
 {
-    auto matcher() const { return match::name("dot")(match::nargs(3)); }
+    auto matcher() const { return match::name("dot", "quant_dot")(match::nargs(3)); }
     void apply(module& p, const match::matcher_result& r) const
     {
-        auto ins = r.result;
-        auto dot = any_cast<op::dot>(ins->get_operator());
-        if(not float_equal(dot.beta, 1) and
-           not contains({shape::float_type, shape::half_type, shape::double_type},
-                        ins->get_shape().type()))
-            return;
+        auto ins = r.result;
+        auto dot = get_alpha_beta(ins->get_operator());
         auto a_ins = ins->inputs()[0];
         auto b_ins = ins->inputs()[1];
         if(not float_equal(dot.alpha, 1))
         {
-            auto alpha = p.add_literal(literal{shape{ins->get_shape().type()}, {dot.alpha}});
+            auto alpha = p.add_literal(literal{shape{a_ins->get_shape().type()}, {dot.alpha}});
             auto alpha_broadcast = p.insert_instruction(
                 ins,
                 make_op("multibroadcast", {{"output_lens", a_ins->get_shape().lens()}}),
                 alpha);
             a_ins = p.insert_instruction(ins, make_op("mul"), a_ins, alpha_broadcast);
         }
-        auto dot_ins = p.insert_instruction(ins, make_op("dot", {{"beta", 0}}), a_ins, b_ins);
+        auto dot_ins = p.insert_instruction(ins, make_op(ins->name(), {{"beta", 0}}), a_ins, b_ins);
         auto c_ins = ins->inputs()[2];
         if(not float_equal(dot.beta, 1))
         {
-            auto beta = p.add_literal(literal{shape{ins->get_shape().type()}, {dot.beta}});
+            auto beta = p.add_literal(literal{shape{c_ins->get_shape().type()}, {dot.beta}});
             auto beta_broadcast = p.insert_instruction(
                 ins, make_op("multibroadcast", {{"output_lens", ins->get_shape().lens()}}), beta);
             c_ins = p.insert_instruction(ins, make_op("mul"), c_ins, beta_broadcast);
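This hunk (and the matching two-argument case below) decomposes the fused gemm epilogue: rather than requiring every GEMM kernel to support alpha/beta scaling for every data type, alpha is folded into an explicit elementwise multiply on A and beta into a multiply on C, leaving a plain dot with beta = 0. The rewrite relies on the identity alpha*(A·B) + beta*C == ((alpha*A)·B) + beta*C. A scalar sketch of that identity (illustrative C++ only, not MIGraphX code):

```cpp
#include <cassert>

// Scalar stand-in for the rewrite: scaling A before the dot equals scaling
// the dot's result, so the fused alpha/beta epilogue can be decomposed into
// plain mul, dot(beta=0) and add instructions.
int main()
{
    float a = 3.0f, b = 4.0f, c = 5.0f;
    float alpha = 2.0f, beta = 0.5f;

    float fused      = alpha * (a * b) + beta * c; // gemm with alpha/beta epilogue
    float decomposed = (alpha * a) * b + beta * c; // mul, then dot(beta=0), then add
    assert(fused == decomposed);
    return 0;
}
```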
@@ -51,24 +60,24 @@ struct find_dot_add
 struct find_dot_alpha
 {
-    auto matcher() const { return match::name("dot")(match::nargs(2)); }
+    auto matcher() const { return match::name("dot", "quant_dot")(match::nargs(2)); }
     void apply(module& p, const match::matcher_result& r) const
     {
         auto ins = r.result;
-        auto dot = any_cast<op::dot>(ins->get_operator());
+        auto dot = get_alpha_beta(ins->get_operator());
         auto a_ins = ins->inputs()[0];
         auto b_ins = ins->inputs()[1];
         if(not float_equal(dot.alpha, 1))
         {
-            auto alpha = p.add_literal(literal{shape{ins->get_shape().type()}, {dot.alpha}});
+            auto alpha = p.add_literal(literal{shape{a_ins->get_shape().type()}, {dot.alpha}});
             auto alpha_broadcast = p.insert_instruction(
                 ins,
                 make_op("multibroadcast", {{"output_lens", a_ins->get_shape().lens()}}),
                 alpha);
             a_ins = p.insert_instruction(ins, make_op("mul"), a_ins, alpha_broadcast);
         }
-        p.replace_instruction(ins, make_op("dot", {{"beta", 0}}), a_ins, b_ins);
+        p.replace_instruction(ins, make_op(ins->name(), {{"beta", 0}}), a_ins, b_ins);
     }
 };
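Two details in these hunks carry the weight: make_op(ins->name(), ...) re-emits whichever operator was matched, so a quant_dot stays a quant_dot; and the alpha/beta literals are now built from the input types (a_ins, c_ins) instead of the dot's output type. That distinction only matters for quant_dot, whose int8 inputs produce an int32 output, so an output-typed literal could not be multiplied into the inputs. A minimal illustration of that type split (hypothetical helper, not the MIGraphX API):

```cpp
#include <cstdint>

// quant_dot-style typing: int8 operands, int32 result. A literal meant to
// scale the int8 input therefore has to be int8 itself; building it with
// the output type (int32) only works for a plain dot.
std::int32_t quant_dot_1x1(std::int8_t a, std::int8_t b)
{
    return std::int32_t{a} * std::int32_t{b}; // accumulate in the wider type
}

int main()
{
    std::int8_t alpha = 2; // literal in the *input* type
    std::int8_t a = 3, b = 4;
    auto scaled = static_cast<std::int8_t>(alpha * a);
    return quant_dot_1x1(scaled, b) == 24 ? 0 : 1;
}
```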
@@ -13,6 +13,8 @@ void eliminate_data_type::apply(module& m) const
     {
         if(ins->name()[0] == '@')
             continue;
+        if(ins->name() == "convert")
+            continue;
         auto inputs = ins->inputs();
         std::transform(inputs.begin(), inputs.end(), inputs.begin(), [&](auto i) {
             if(types.count(i->get_shape().type()) == 0)
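The new guard keeps eliminate_data_type away from convert instructions: the pass works precisely by inserting converts around operations whose types are unsupported, so revisiting a convert would only stack further conversions on top of the ones it just created. A toy analogue of the loop (standalone sketch, not the actual pass):

```cpp
#include <iostream>
#include <string>
#include <vector>

// Walk the program, leave internal ops ("@...") and converts alone, and wrap
// anything computing in an unsupported type with convert/compute/convert-back.
struct ins
{
    std::string name;
    std::string type;
};

int main()
{
    std::vector<ins> prog{{"@param", "int32"}, {"dot", "int32"}, {"convert", "float"}};
    for(const auto& i : prog)
    {
        if(i.name.front() == '@' or i.name == "convert")
            continue; // rewriting a convert would just stack more converts
        if(i.type != "float")
            std::cout << i.name << ": convert -> compute in float -> convert back\n";
    }
    return 0;
}
```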
@@ -7,9 +7,33 @@ inline namespace MIGRAPHX_INLINE_NS {
 namespace gpu {
 namespace device {
+void contiguous_nonstandard(hipStream_t stream, const argument& result, const argument& arg)
+{
+    shape s{result.get_shape().type(), result.get_shape().lens()};
+    visit_all(result, arg)([&](auto output_v, auto input_v) {
+        hip_visit_views(output_v, input_v, s)([&](auto output, auto input, auto standard_shape) {
+            mi_gs_launch(stream,
+                         standard_shape)([=](auto idx) __device__ { output[idx] = input[idx]; });
+        });
+    });
+}
+void contiguous_packed(hipStream_t stream, const argument& result, const argument& arg)
+{
+    index_int nelements = result.get_shape().elements();
+    visit_all(result, arg)([&](auto output_v, auto input_v) {
+        const auto* input = device_cast(input_v.data());
+        auto* output = device_cast(output_v.data());
+        gs_launch(stream, nelements)([=](auto i) __device__ { output[i] = input[i]; });
+    });
+}
 void contiguous(hipStream_t stream, const argument& result, const argument& arg)
 {
-    nary(stream, result, arg)([](auto x) __device__ { return x; });
+    if(result.get_shape() == arg.get_shape() and result.get_shape().packed())
+        contiguous_packed(stream, result, arg);
+    else
+        contiguous_nonstandard(stream, result, arg);
 }
 } // namespace device
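contiguous previously always went through nary, which drags in the fully shape-aware kernel machinery. The split adds a fast path: when output and input share the same packed shape, element i occupies slot i in both buffers, so a flat copy suffices; only non-standard layouts pay for multi-dimensional index arithmetic. A host-side sketch of the two cases (plain C++, not HIP):

```cpp
#include <array>
#include <cstddef>

// Packed case: identical packed shapes mean out[i] and in[i] coincide, so a
// flat loop needs no rank- or stride-specific code.
void copy_packed(const float* in, float* out, std::size_t n)
{
    for(std::size_t i = 0; i < n; i++)
        out[i] = in[i];
}

// Non-standard case (2D example): the input is addressed through strides,
// so each element requires the index -> offset mapping the packed path skips.
void copy_strided_2d(const float* in,
                     float* out,
                     std::array<std::size_t, 2> lens,
                     std::array<std::size_t, 2> strides)
{
    for(std::size_t r = 0; r < lens[0]; r++)
        for(std::size_t c = 0; c < lens[1]; c++)
            out[r * lens[1] + c] = in[r * strides[0] + c * strides[1]];
}
```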
@@ -51,6 +51,50 @@ auto get_shape(const T& x) -> decltype(x.get_shape())
     return x.get_shape();
 }
+template <class T>
+struct is_hip_type : std::false_type
+{
+};
+template <>
+struct is_hip_type<float> : std::true_type
+{
+};
+template <>
+struct is_hip_type<half> : std::true_type
+{
+};
+template <>
+struct is_hip_type<bool> : std::true_type
+{
+};
+template <>
+struct is_hip_type<std::int8_t> : std::true_type
+{
+};
+template <>
+struct is_hip_type<std::uint8_t> : std::true_type
+{
+};
+template <class T, class V, MIGRAPHX_REQUIRES(is_hip_type<typename T::type>{})>
+void hip_visitor_invoke(T as, V&& v)
+{
+    v(as);
+}
+template <class T, class V, MIGRAPHX_REQUIRES(not is_hip_type<typename T::type>{})>
+void hip_visitor_invoke(T, V&&)
+{
+    MIGRAPHX_THROW(std::string("Unsupported data type on GPU: ") + __PRETTY_FUNCTION__);
+}
+template <class V>
+auto hip_visitor(V v)
+{
+    return [=](auto as) { hip_visitor_invoke(as, v); };
+}
 template <class V, class F, class... Ts>
 void hip_visit_all_impl(const shape& s, F f, V&& v, Ts&&... xs)
 {
@@ -62,8 +106,9 @@ void hip_visit_all_impl(const shape& s, F f, V&& v, Ts&&... xs)
         static_cast<index_int>(get_shape(xs).lens().size())...};
     if(!std::all_of(ranks.begin(), ranks.end(), [&](index_int r) { return r == s.lens().size(); }))
         MIGRAPHX_THROW("Ranks must be the same");
-    visit_tensor_size(s.lens().size(),
-                      [&](auto ndim) { s.visit_type([&](auto as) { v(f(xs, ndim, as)...); }); });
+    visit_tensor_size(s.lens().size(), [&](auto ndim) {
+        s.visit_type(hip_visitor([&](auto as) { v(f(xs, ndim, as)...); }));
+    });
 }
 template <class V, class F, class... Ts>
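This is the heart of the type reduction: s.visit_type used to instantiate the visiting lambda, and with it device code, for every shape::type_t. Routing the visitor through hip_visitor sends non-whitelisted types to an overload that never calls the lambda, so kernels are generated only for float, half, bool, int8 and uint8, and any other type surfaces as a runtime error rather than object-code bloat. A standalone sketch of the trick, with std::enable_if_t standing in for MIGRAPHX_REQUIRES and raw values standing in for the T::type wrappers the real header inspects:

```cpp
#include <cstdint>
#include <stdexcept>
#include <type_traits>

template <class T> struct is_hip_type : std::false_type {};
template <> struct is_hip_type<float> : std::true_type {};
template <> struct is_hip_type<bool> : std::true_type {};
template <> struct is_hip_type<std::int8_t> : std::true_type {};
template <> struct is_hip_type<std::uint8_t> : std::true_type {};
// (half is whitelisted too in the real header; it is not standard C++.)

// Whitelisted types call the visitor, instantiating its body for T.
template <class T, class V, std::enable_if_t<is_hip_type<T>{}, int> = 0>
void visit_invoke(T x, V&& v)
{
    v(x);
}

// Every other type selects this overload; v is never called, so nothing is
// instantiated for T and the mistake is reported at runtime instead.
template <class T, class V, std::enable_if_t<not is_hip_type<T>{}, int> = 0>
void visit_invoke(T, V&&)
{
    throw std::runtime_error("Unsupported data type on GPU");
}

template <class V>
auto make_visitor(V v)
{
    return [=](auto x) { visit_invoke(x, v); };
}
```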
@@ -7,6 +7,7 @@
 #include <migraphx/eliminate_common_subexpression.hpp>
 #include <migraphx/eliminate_concat.hpp>
 #include <migraphx/eliminate_contiguous.hpp>
+#include <migraphx/eliminate_data_type.hpp>
 #include <migraphx/eliminate_identity.hpp>
 #include <migraphx/eliminate_pad.hpp>
 #include <migraphx/memory_coloring.hpp>
@@ -43,12 +44,19 @@ MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_SCHEDULE_PASS)
 std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_options& options) const
 {
     auto& ctx = any_cast<context>(gctx);
+    std::set<shape::type_t> unsupported_types(shape::types().begin(), shape::types().end());
+    unsupported_types.erase(shape::type_t::float_type);
+    unsupported_types.erase(shape::type_t::half_type);
+    unsupported_types.erase(shape::type_t::bool_type);
+    unsupported_types.erase(shape::type_t::int8_type);
+    unsupported_types.erase(shape::type_t::uint8_type);
     // clang-format off
     return
     {
         normalize_ops{},
         decompose{},
         dead_code_elimination{},
+        eliminate_data_type{unsupported_types, shape::type_t::float_type},
         simplify_reshapes{},
         eliminate_identity{},
         eliminate_pad{},
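The pass list builds unsupported_types as a complement: every shape::type_t minus the five the HIP backend keeps. eliminate_data_type runs early enough that doubles, int32s and the rest are lowered to float before any GPU-specific pass sees them. A small sketch of that complement construction (illustrative type names, not the MIGraphX enum):

```cpp
#include <iostream>
#include <set>
#include <string>

int main()
{
    // Start from every IR type and erase what the GPU kernels keep; the rest
    // is handed to eliminate_data_type to be computed as float.
    std::set<std::string> unsupported{"bool",  "half",   "float",  "double",
                                      "uint8", "uint16", "uint32", "uint64",
                                      "int8",  "int16",  "int32",  "int64"};
    for(const char* kept : {"float", "half", "bool", "int8", "uint8"})
        unsupported.erase(kept);

    for(const auto& t : unsupported)
        std::cout << t << " -> converted to float on the GPU\n";
    return 0;
}
```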
@@ -128,8 +128,21 @@ TEST_CASE(dot_add_beta_int)
         m1.add_instruction(migraphx::make_op("dot", {{"alpha", 1.0}, {"beta", 0.5}}), x, y, z);
         m1.add_instruction(migraphx::make_op("identity"), dot);
     }
-    migraphx::module m2 = m1;
     run_pass(m1);
+    migraphx::module m2;
+    {
+        auto x = m2.add_parameter("x", migraphx::shape{migraphx::shape::int32_type, {2, 2}});
+        auto y = m2.add_parameter("y", migraphx::shape{migraphx::shape::int32_type, {2, 2}});
+        auto z = m2.add_parameter("z", migraphx::shape{migraphx::shape::int32_type, {2, 2}});
+        auto dot = m2.add_instruction(migraphx::make_op("dot", {{"alpha", 1}, {"beta", 0}}), x, y);
+        auto beta =
+            m2.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::int32_type}, {0.5}});
+        auto beta_broadcast = m2.add_instruction(
+            migraphx::make_op("multibroadcast", {{"output_lens", {2, 2}}}), beta);
+        auto mul = m2.add_instruction(migraphx::make_op("mul"), z, beta_broadcast);
+        auto add = m2.add_instruction(migraphx::make_op("add"), dot, mul);
+        m2.add_instruction(migraphx::make_op("identity"), add);
+    }
     EXPECT(m1 == m2);
 }
@@ -14,7 +14,7 @@ const std::string write_2s = R"__migraphx__(
 #include <hip/hip_runtime.h>
 extern "C" {
-__global__ void write(int* data)
+__global__ void write(int8_t* data)
 {
     int num = threadIdx.x + blockDim.x * blockIdx.x;
     data[num] = 2;
@@ -31,7 +31,7 @@ const std::string add_2s_binary = R"__migraphx__(
 #include <hip/hip_runtime.h>
 extern "C" {
-__global__ void add_2(std::int32_t* x, std::int32_t* y)
+__global__ void add_2(std::int8_t* x, std::int8_t* y)
 {
     int num = threadIdx.x + blockDim.x * blockIdx.x;
     y[num] = x[num] + 2;
@@ -89,14 +89,14 @@ TEST_CASE(simple_compile_hip)
         {make_src_file("main.cpp", write_2s)}, "", get_device_name());
     EXPECT(binaries.size() == 1);
-    migraphx::argument input{{migraphx::shape::int32_type, {5}}};
+    migraphx::argument input{{migraphx::shape::int8_type, {5}}};
     auto ginput = migraphx::gpu::to_gpu(input);
     migraphx::gpu::kernel k{binaries.front(), "write"};
-    k.launch(nullptr, input.get_shape().elements(), 1024)(ginput.cast<int>());
+    k.launch(nullptr, input.get_shape().elements(), 1024)(ginput.cast<std::int8_t>());
    auto output = migraphx::gpu::from_gpu(ginput);
     EXPECT(output != input);
-    auto data = output.get<int>();
+    auto data = output.get<std::int8_t>();
     EXPECT(migraphx::all_of(data, [](auto x) { return x == 2; }));
 }
@@ -106,7 +106,7 @@ TEST_CASE(code_object_hip)
         {make_src_file("main.cpp", add_2s_binary)}, "", get_device_name());
     EXPECT(binaries.size() == 1);
-    migraphx::shape input{migraphx::shape::int32_type, {5}};
+    migraphx::shape input{migraphx::shape::int8_type, {5}};
     std::vector<migraphx::shape> expected_inputs = {input, input};
     auto co = migraphx::make_op("gpu::code_object",
@@ -37,7 +37,8 @@ inline void compile_check(migraphx::program& p, const migraphx::target& t, bool
     auto shapes = p.get_output_shapes();
     std::stringstream ss;
     migraphx::compile_options options;
-    options.trace = migraphx::tracer{ss};
+    if(show_trace)
+        options.trace = migraphx::tracer{std::cout};
     p.compile(t, options);
     if(shapes.size() != p.get_output_shapes().size())
     {
@@ -55,11 +56,6 @@ inline void compile_check(migraphx::program& p, const migraphx::target& t, bool
             throw std::runtime_error("Compiling program with " + name + " alters its shape");
         }
     }
-    if(show_trace)
-    {
-        std::cout << ss.str() << std::endl;
-    }
 }
 target_info run_verify::get_target_info(const std::string& name) const