Commit 421ecad6 authored by Alan Turner

Merge remote-tracking branch 'origin/develop' into ck-gsg

parents 3d0426e9 7cf05301
@@ -100,10 +100,10 @@ struct find_static_2in_broadcasts
 } // namespace

 /**
- * Makes all the shapes in the dynamic_dimension range.
- * Probably won't work for `if` and `loop` instructions, depending on how the submodules for those
- * work. Inserts select_module instruction to the top. Replaces return, bypassing other
- * instructions.
+ * Makes all the shapes in the dynamic_dimension range. Probably won't work for `if`
+ * and `loop` instructions, depending on how the submodules for those
+ * work. Inserts select_module instruction to the top. Replaces return, bypassing other
+ * instructions. Skips if the dynamic parameter outputs to a select_module operator.
  */
 void split_single_dyn_dim::apply(module_pass_manager& mpm) const
 {
@@ -111,7 +111,13 @@ void split_single_dyn_dim::apply(module_pass_manager& mpm) const
     auto param_names = mm->get_parameter_names();
     auto param_shapes = mm->get_parameter_shapes();
     optional<dynamic_dimensions_check> dd_check = has_one_dyn_dim(param_shapes);
-    if(dd_check.has_value())
+    auto any_sm_next = [&](auto ddc) {
+        auto p_outputs = mm->get_parameter(ddc->dyn_param_str)->outputs();
+        return std::any_of(p_outputs.cbegin(), p_outputs.cend(), [](auto ins) {
+            return ins->name() == "select_module";
+        });
+    };
+    if(dd_check.has_value() and not any_sm_next(dd_check))
     {
         const auto& dyn_param = mm->get_parameter(dd_check->dyn_param_str);
         auto dyn_param_shape = mm->get_parameter_shape(dd_check->dyn_param_str);
...
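Note: the new guard keeps the pass idempotent: if the dynamic parameter already feeds a select_module, the graph was split on a previous run and is left alone. Conceptually, the pass enumerates one static shape per value in the dynamic range and compiles a submodule for each, with select_module choosing among them at run time. A toy sketch of that enumeration (plain C++, not the MIGraphX API):

```cpp
#include <cstddef>
#include <vector>

struct static_shape { std::vector<std::size_t> lens; };

// Hypothetical helper: fix the single dynamic axis to each value in
// [min_dim, max_dim]; the real pass compiles one submodule per entry and
// inserts select_module to pick the right one at run time.
std::vector<static_shape> enumerate_static_shapes(std::vector<std::size_t> lens,
                                                  std::size_t dyn_axis,
                                                  std::size_t min_dim,
                                                  std::size_t max_dim)
{
    std::vector<static_shape> out;
    for(std::size_t d = min_dim; d <= max_dim; ++d)
    {
        lens[dyn_axis] = d; // pin the dynamic axis to one concrete value
        out.push_back({lens});
    }
    return out;
}
```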
@@ -33,11 +33,7 @@ if(NOT TARGET MIOpen)
     message(SEND_ERROR "Cant find miopen")
 endif()

-if(BUILD_DEV)
-    set(MIGRAPHX_USE_HIPRTC OFF CACHE BOOL "Use hipRTC APIs")
-else()
-    set(MIGRAPHX_USE_HIPRTC ON CACHE BOOL "Use hipRTC APIs")
-endif()
+set(MIGRAPHX_USE_HIPRTC OFF CACHE BOOL "Use hipClang APIs")

 include(Embed)
 file(GLOB KERNEL_FILES ${CONFIGURE_DEPENDS}
...
@@ -29,6 +29,7 @@
 #include <migraphx/module.hpp>
 #include <migraphx/dead_code_elimination.hpp>
 #include <migraphx/eliminate_common_subexpression.hpp>
+#include <migraphx/rewrite_quantization.hpp>
 #include <migraphx/cpp_generator.hpp>
 #include <migraphx/pass_manager.hpp>
 #include <migraphx/instruction.hpp>
@@ -171,7 +172,8 @@ std::string make_transformer_args(std::vector<std::string> transformers)
 void generate_pointwise(cpp_generator& gg, const module& pm, const std::string& name)
 {
     module m = pm;
-    run_passes(m, {eliminate_common_subexpression{}, dead_code_elimination{}});
+    run_passes(m,
+               {rewrite_quantization{}, eliminate_common_subexpression{}, dead_code_elimination{}});
     cpp_generator g;
     g.fmap([](const std::string& fname) { return "migraphx::" + fname; });
     g.add_point_op("where", "${function:where}(${0}, ${1}, ${2})");
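Note: running rewrite_quantization inside generate_pointwise matters because the target-level pipeline now skips that pass when MLIR is enabled (see the target.cpp hunks further down), so the pointwise code generator must lower quantizelinear/dequantizelinear itself. A hedged sketch of what the pass conceptually does per element, following the standard dequantizelinear semantics:

```cpp
#include <cstdint>

// Sketch only, not the pass source: rewrite_quantization lowers
// dequantizelinear(x, scale, zp) into primitive pointwise ops the generator
// already supports, elementwise equivalent to:
float dequantize_elem(std::int8_t x, float scale, std::int8_t zp)
{
    return (static_cast<float>(x) - static_cast<float>(zp)) * scale;
}
```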
@@ -280,6 +282,14 @@ std::string generate_reduce(const module& m, const std::string& name)
                    not input->get_shape().broadcasted();
         });
         auto inner_names = names;
+        for(auto input : ins->inputs())
+        {
+            if(input->name() != "@param")
+                continue;
+            if(contains(tensors, input))
+                continue;
+            inner_names[input] += "[out_idx]";
+        }
         for(auto input : tensors)
             inner_names[input] += "_lambda_param";
         auto call_function =
@@ -308,6 +318,8 @@ std::string generate_reduce(const module& m, const std::string& name)
     });
     f.set_attributes({"__device__", "__attribute__((const))"}).set_generic_types(m).set_name(name);
     f.add_generic_param("r");
+    f.add_generic_param("out_idx");
+    f.unused_param("out_idx");
     g.create_function(f);
     return g.str();
 }
...
@@ -120,12 +120,10 @@ void hip_visit_all_impl(const shape& s, F f, V&& v, Ts&&... xs)
     if(not std::all_of(
            types.begin(), types.end(), [&](migraphx::shape::type_t t) { return t == s.type(); }))
         MIGRAPHX_THROW("Types must be the same");
-    std::initializer_list<index_int> ranks = {
-        static_cast<index_int>(get_shape(xs).lens().size())...};
-    if(not std::all_of(
-           ranks.begin(), ranks.end(), [&](index_int r) { return r == s.lens().size(); }))
+    std::initializer_list<index_int> ranks = {static_cast<index_int>(get_shape(xs).ndim())...};
+    if(not std::all_of(ranks.begin(), ranks.end(), [&](index_int r) { return r == s.ndim(); }))
         MIGRAPHX_THROW("Ranks must be the same");
-    visit_tensor_size(s.lens().size(), [&](auto ndim) {
+    visit_tensor_size(s.ndim(), [&](auto ndim) {
         s.visit_type(hip_visitor([&](auto as) { v(f(xs, ndim, as)...); }));
     });
 }
@@ -133,12 +131,10 @@ void hip_visit_all_impl(const shape& s, F f, V&& v, Ts&&... xs)
 template <class V, class F, class... Ts>
 void hip_visit_views_impl(const shape& s, F f, V&& v, Ts&&... xs)
 {
-    std::initializer_list<index_int> ranks = {
-        static_cast<index_int>(get_shape(xs).lens().size())...};
-    if(not std::all_of(
-           ranks.begin(), ranks.end(), [&](index_int r) { return r == s.lens().size(); }))
+    std::initializer_list<index_int> ranks = {static_cast<index_int>(get_shape(xs).ndim())...};
+    if(not std::all_of(ranks.begin(), ranks.end(), [&](index_int r) { return r == s.ndim(); }))
         MIGRAPHX_THROW("Ranks must be the same");
-    visit_tensor_size(s.lens().size(), [&](auto ndim) { v(f(xs, ndim)...); });
+    visit_tensor_size(s.ndim(), [&](auto ndim) { v(f(xs, ndim)...); });
 }

 template <class F>
...
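Note: both hunks above are the same mechanical cleanup: `shape::ndim()` replaces the longer `lens().size()` spelling and, unlike `lens()`, is also well-defined for dynamic shapes. A small illustration, assuming the usual `migraphx::shape` constructors:

```cpp
#include <cassert>
#include <migraphx/shape.hpp>

int main()
{
    migraphx::shape s{migraphx::shape::float_type, {2, 3, 4}};
    assert(s.ndim() == 3);               // rank query
    assert(s.ndim() == s.lens().size()); // equivalent for static shapes
}
```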
@@ -26,5 +26,6 @@ file(GLOB GPU_DRIVER_SRCS ${CONFIGURE_DEPENDS} ${CMAKE_CURRENT_SOURCE_DIR}/*.cpp
 add_executable(gpu-driver
     ${GPU_DRIVER_SRCS}
 )
+rocm_clang_tidy_check(gpu-driver)
 target_include_directories(gpu-driver PRIVATE include)
 target_link_libraries(gpu-driver PRIVATE migraphx_gpu)
@@ -44,7 +44,7 @@ struct auto_register_action
     template <class T>
     static void apply()
     {
-        auto name = get_type_name<T>();
+        const auto& name = get_type_name<T>();
         register_action(name.substr(name.rfind("::") + 2),
                         [](auto&&... xs) { T::apply(std::forward<decltype(xs)>(xs)...); });
     }
...
@@ -38,6 +38,27 @@ namespace gpu {

 MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_MLIR);

+bool mlir_enabled()
+{
+#ifdef MIGRAPHX_MLIR
+    const bool mlir_enabled = enabled(MIGRAPHX_ENABLE_MLIR{});
+    if(mlir_enabled)
+    {
+        return true;
+    }
+    else
+    {
+        std::cerr << "WARNING: MIGraphX built with MLIR but it is not enabled. Please set the env "
+                     "var MIGRAPHX_ENABLE_MLIR to use MLIR kernel generator."
+                  << std::endl;
+        return false;
+    }
+#else
+    return false;
+#endif
+}
+
 #ifdef MIGRAPHX_MLIR

 struct mlir_op
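Note: this hoists the environment-variable check that used to live inside fuse_mlir::apply (removed in a later hunk) into a reusable query, so the pass scheduler in target.cpp can gate several passes on one answer. A hedged sketch of the intended call-site pattern (the surrounding logic is illustrative):

```cpp
// mlir_enabled() is the real function above; the branches are illustrative.
if(migraphx::gpu::mlir_enabled())
{
    // schedule fuse_mlir and leave quantized ops for the MLIR kernels
}
else
{
    // fall back: rewrite_quantization runs in the target pipeline instead
}
```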
@@ -58,8 +79,41 @@ struct mlir_op
             MIGRAPHX_THROW("should have one submodule.");
         if(inputs.size() < 2)
             MIGRAPHX_THROW("should have at least two inputs.");
-        auto n = inputs.size();
-        return op.compute_shape({inputs[n - 2], inputs[n - 1]});
+
+        module_ref mod = mods[0];
+        auto type = mod->get_output_shapes().front().type();
+        std::unordered_map<instruction_ref, shape> ins_shapes;
+        size_t param_cnt = 0;
+        std::vector<std::string> names = mod->get_parameter_names();
+        std::sort(names.begin(), names.end());
+        for(std::string param_name : names)
+        {
+            ins_shapes[mod->get_parameter(param_name)] = inputs[param_cnt++];
+        }
+        for(auto ins : iterator_for(*mod))
+        {
+            if(ins->name() == "@param")
+            {
+                continue;
+            }
+            if(ins->name() == "@literal")
+            {
+                ins_shapes[ins] = ins->get_shape();
+                continue;
+            }
+            if(ins->name() == "@return")
+            {
+                return ins_shapes[ins->inputs().at(0)].with_type(type);
+            }
+            std::vector<shape> input_shapes;
+            input_shapes.resize(ins->inputs().size());
+            std::transform(ins->inputs().begin(),
+                           ins->inputs().end(),
+                           input_shapes.begin(),
+                           [&](auto in) { return ins_shapes[in]; });
+            ins_shapes[ins] = ins->get_operator().compute_shape(input_shapes);
+        }
+        MIGRAPHX_THROW("No return found in the submodule");
     }
 };
 MIGRAPHX_REGISTER_OP(mlir_op);
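Note: mlir_op::compute_shape no longer trusts the last two inputs alone; it re-runs shape inference across the whole submodule and returns the @return operand's shape cast to the module's output type via with_type. The walk relies on iterator_for(*mod) visiting instructions in program order, so each operand's shape is already computed before its consumer is reached. A standalone miniature of that propagation pattern (toy types, not the MIGraphX API):

```cpp
#include <cstddef>
#include <stdexcept>
#include <vector>

// Toy model: each node lists operand indices of earlier nodes and infers an
// output "shape" (here just an int) from its operand shapes.
struct node
{
    std::vector<std::size_t> inputs;
    int (*infer)(const std::vector<int>&);
};

int propagate(const std::vector<node>& prog, std::vector<int> shapes)
{
    // shapes is pre-seeded for parameters; extend it node by node
    for(std::size_t i = shapes.size(); i < prog.size(); ++i)
    {
        std::vector<int> in;
        for(auto j : prog[i].inputs)
            in.push_back(shapes.at(j)); // already computed: program order
        shapes.push_back(prog[i].infer(in));
    }
    if(shapes.empty())
        throw std::runtime_error("no return found");
    return shapes.back(); // the @return operand's shape in the real pass
}
```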
@@ -68,7 +122,7 @@ namespace {

 MIGRAPHX_PRED_MATCHER(is_mlir_conv, instruction_ref ins)
 {
-    if(ins->name() != "convolution")
+    if(ins->name() != "convolution" and ins->name() != "quant_convolution")
         return false;
     value v = ins->get_operator().to_value();
     auto group = v.at("group").to<int>();
@@ -89,6 +143,53 @@ struct find_mlir_op
         return match::name("pointwise")(match::any_of[match::inputs()](dot_or_conv.bind("x")));
     }

+    std::unordered_map<instruction_ref, instruction_ref>
+    create_param_map_with_literals(module_ref mm, const module* pm, const shape& shape) const
+    {
+        std::unordered_map<instruction_ref, instruction_ref> ins_map;
+        for(auto ins : iterator_for(*pm))
+        {
+            if(ins->name() != "@literal")
+            {
+                continue;
+            }
+            literal r = ins->get_literal();
+            instruction_ref literal = mm->add_literal(r);
+            instruction_ref mbcast = mm->add_instruction(
+                make_op("multibroadcast", {{"out_lens", shape.lens()}}), literal);
+            ins_map[ins] = mbcast;
+        }
+        return ins_map;
+    }
+
+    std::tuple<instruction_ref, std::vector<instruction_ref>>
+    fuse_input_ops_and_gemm_based_op(module_ref mm, instruction_ref gemm_based_op) const
+    {
+        std::vector<instruction_ref> top_inputs;
+        std::vector<instruction_ref> imm_inputs;
+        size_t input_cnt = 0;
+        for(instruction_ref input : gemm_based_op->inputs())
+        {
+            std::vector<operation> op_stream;
+            while(contains({"slice", "transpose", "contiguous", "reshape"}, input->name()))
+            {
+                op_stream.push_back(input->get_operator());
+                input = input->inputs().at(0);
+            }
+            top_inputs.push_back(input);
+            instruction_ref prev_input =
+                mm->add_parameter("y" + std::to_string(input_cnt++), input->get_shape());
+            for(const auto& op : reverse(op_stream))
+            {
+                prev_input = mm->add_instruction(op, {prev_input});
+            }
+            imm_inputs.push_back(prev_input);
+        }
+        instruction_ref new_gemm_based_op =
+            mm->add_instruction(gemm_based_op->get_operator(), imm_inputs);
+        return {new_gemm_based_op, top_inputs};
+    }
+
     void apply(module_pass_manager& mpm, const match::matcher_result& r) const
     {
         auto ins = r.result;
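Note: the two helpers above set up the new fusion layout. Literals from the pointwise module are copied into the MLIR submodule and multibroadcast to the anchor shape, while pure view ops (slice, transpose, contiguous, reshape) feeding the GEMM/convolution are peeled off in the outer graph and replayed inside the submodule on fresh y-parameters, so the outer mlir_op consumes the raw top_inputs. An illustrative before/after sketch (shapes and names hypothetical):

```cpp
// Outer graph before fusion:        mlir_* submodule after fusion:
//   x -> transpose -+                 y0 = @param       (raw x)
//                   +-> conv -> add   t  = transpose(y0)   // replayed view op
//   w --------------+                 y1 = @param       (raw w)
//                                     c  = conv(t, y1)     // anchor_op
//                                     ... fused pointwise ops consume c
// Outer graph after fusion:
//   add' = mlir_op[conv](pointwise inputs..., x, w)  // x, w == top_inputs
```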
@@ -98,33 +199,42 @@ struct find_mlir_op
         auto names = pm->get_parameter_names();
         // Whitelist pointwise operators
         if(std::any_of(pm->begin(), pm->end(), [](const auto& i) {
-               return not contains(
-                   {"@literal", "@param", "@return", "convolution", "dot", "add", "relu"},
-                   i.name());
+               return not contains({"@literal",
+                                    "@param",
+                                    "@return",
+                                    "convolution",
+                                    "quant_convolution",
+                                    "dot",
+                                    "add",
+                                    "relu",
+                                    "dequantizelinear",
+                                    "quantizelinear",
+                                    "mul"},
+                                   i.name());
            }))
             return;
-        // Only fuse with fp32/fp16
+        // Only fuse with fp32/fp16/int8/int32
         if(std::any_of(ins->inputs().begin(), ins->inputs().end(), [&](auto i) {
-               return not contains({shape::type_t::float_type, shape::type_t::half_type},
+               return not contains({shape::type_t::float_type,
+                                    shape::type_t::half_type,
+                                    shape::type_t::int8_type,
+                                    shape::type_t::int32_type},
                                    i->get_shape().type());
            }))
             return;
         std::sort(names.begin(), names.end());
         module_ref mm = mpm.create_module("mlir_" + pm->name());
         mm->set_bypass();
-        std::unordered_map<instruction_ref, instruction_ref> param_map;
-        auto x = mm->add_parameter("x" + std::to_string(names.size()),
-                                   gemm_based_op->inputs().at(0)->get_shape());
-        auto w = mm->add_parameter("x" + std::to_string(names.size() + 1),
-                                   gemm_based_op->inputs().at(1)->get_shape());
-        auto conv = mm->add_instruction(gemm_based_op->get_operator(), {x, w});
+        std::unordered_map<instruction_ref, instruction_ref> param_map =
+            create_param_map_with_literals(mm, pm, gemm_based_op->get_shape());
+        auto [anchor_op, top_inputs] = fuse_input_ops_and_gemm_based_op(mm, gemm_based_op);
         std::transform(names.begin(),
                        names.end(),
                        ins->inputs().begin(),
                        std::inserter(param_map, param_map.end()),
-                       [&](auto name, auto input) {
+                       [&, &anchor_op = anchor_op](auto name, auto input) {
                            if(input == x_ins)
-                               return std::make_pair(pm->get_parameter(name), conv);
+                               return std::make_pair(pm->get_parameter(name), anchor_op);
                            return std::make_pair(pm->get_parameter(name),
                                                  mm->add_parameter(name, input->get_shape()));
                        });
@@ -135,7 +245,7 @@ struct find_mlir_op
                      ins->inputs().end(),
                      std::back_inserter(inputs),
                      [&](auto input) { return input != gemm_based_op; });
-        inputs.insert(inputs.end(), gemm_based_op->inputs().begin(), gemm_based_op->inputs().end());
+        inputs.insert(inputs.end(), top_inputs.begin(), top_inputs.end());
         mpm.get_module().replace_instruction(
             ins, mlir_op{gemm_based_op->get_operator()}, inputs, {mm});
     }
@@ -148,17 +258,7 @@ struct find_mlir_op
 void fuse_mlir::apply(module_pass_manager& mpm) const
 {
 #ifdef MIGRAPHX_MLIR
-    const bool mlir_enabled = enabled(MIGRAPHX_ENABLE_MLIR{});
-    if(mlir_enabled)
-    {
-        match::find_matches(mpm, find_mlir_op{});
-    }
-    else
-    {
-        std::cerr << "WARNING: MIGraphX built with MLIR but it is not enabled. Please set the env "
-                     "var MIGRAPHX_ENABLE_MLIR to use MLIR kernel generator."
-                  << std::endl;
-    }
+    match::find_matches(mpm, find_mlir_op{});
 #else
     (void)mpm;
 #endif
...
@@ -157,13 +157,8 @@ void gemm_impl(context& ctx,
         compute_type = rocblas_datatype_f32_r;
     }

-#if ROCBLAS_VERSION_MAJOR >= 2 && ROCBLAS_VERSION_MINOR >= 38
     rocblas_gemm_flags flag =
         int8_x4_format ? rocblas_gemm_flags_pack_int8x4 : rocblas_gemm_flags_none;
-#else
-    (void)int8_x4_format;
-    int flag = 0;
-#endif

     auto a_lens = args[0].get_shape().lens();
     auto b_lens = args[1].get_shape().lens();
...
@@ -34,6 +34,8 @@ struct module_pass_manager;

 namespace gpu {

+bool mlir_enabled();
+
 struct fuse_mlir
 {
     context* ctx = nullptr;
...
@@ -29,6 +29,7 @@
 #include <migraphx/literal.hpp>
 #include <migraphx/check_shapes.hpp>
 #include <migraphx/functional.hpp>
+#include <migraphx/dyn_output.hpp>
 #include <utility>

 namespace migraphx {
@@ -112,7 +113,7 @@ struct hip_copy_to_gpu
     std::string name() const { return "hip::copy_to_gpu"; }
     shape compute_shape(std::vector<shape> inputs) const
     {
-        check_shapes{inputs, *this}.has(1, 2).same_type();
+        check_shapes{inputs, *this, true}.has(1, 2).same_type();
         return inputs.at(0);
     }
     argument compute(context& ctx, const shape&, const std::vector<argument>& args) const
@@ -121,6 +122,10 @@ struct hip_copy_to_gpu
         if(args.size() == 1)
             return input;
         argument result = args[1].share();
+        if(result.get_shape().dynamic())
+        {
+            result = result.reshape(args[0].get_shape());
+        }
         gpu_copy(ctx, input, result);
         // Associate the input since it was registered with hip
         return {result.get_shape(), [input, result]() mutable { return result.data(); }};
@@ -138,19 +143,24 @@ struct hip_copy_from_gpu
     std::string name() const { return "hip::copy_from_gpu"; }
     shape compute_shape(std::vector<shape> inputs) const
     {
-        check_shapes{inputs, *this}.has(1, 2).same_type();
+        check_shapes{inputs, *this, true}.has(1, 2).same_type();
         return inputs.at(0);
     }
     argument
-    compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const
+    compute(context& ctx, const dyn_output& dyn_out, const std::vector<argument>& args) const
     {
         if(args.size() == 1)
         {
-            argument result = allocate_gpu(output_shape, true);
+            argument result = allocate_gpu(dyn_out.computed_shape, true);
             gpu_copy(ctx, args[0], result);
             return result;
         }
-        copy_from_gpu(ctx, args[0], args[1]);
+        argument input = args[0].share();
+        if(input.get_shape().dynamic())
+        {
+            input = input.reshape(args[1].get_shape());
+        }
+        copy_from_gpu(ctx, input, args[1]);
         return args[1];
     }
     std::ptrdiff_t output_alias(const std::vector<shape>& args) const
...
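Note: passing true as the third check_shapes argument lets dynamic shapes through shape checking, and dyn_output carries the shape computed at run time for a dynamic output. The reshape-before-copy pattern then rebinds a shared buffer that still carries a dynamic shape to the concrete shape of the data actually being moved. A toy standalone model of that rebinding idea (not the MIGraphX argument API):

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Toy "buffer view" whose logical extent may be unresolved (dynamic) until
// the producer's concrete size is known; 0 stands in for "dynamic" here.
struct buffer_view
{
    std::vector<float>* data;
    std::size_t extent;
    bool dynamic() const { return extent == 0; }
    buffer_view reshape(std::size_t n) const { return {data, n}; }
};

int main()
{
    std::vector<float> storage(16);
    buffer_view out{&storage, 0};    // preallocated, shape still dynamic
    std::size_t runtime_elems = 12;  // concrete shape of the incoming tensor
    if(out.dynamic())
        out = out.reshape(runtime_elems); // rebind before the copy
    assert(out.extent == 12);
}
```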
@@ -244,13 +244,13 @@ __device__ void print_once(Ts... xs)
 template <class... Ts>
 __device__ void println(Ts... xs)
 {
-    print_each(&coutln, xs...);
+    print_each(&cout, xs..., '\n');
 }

 template <class... Ts>
 __device__ void println_once(Ts... xs)
 {
-    print_each_once(&coutln, xs...);
+    print_each_once(&cout, xs..., '\n');
 }

 } // namespace migraphx
...
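Note: instead of a dedicated coutln writer, the newline now rides along as one more variadic argument, so a single printer handles both cases. A plain-C++ host-side analogue of the same trick (the diff's version targets device code):

```cpp
#include <iostream>

template <class... Ts>
void print_each(Ts... xs)
{
    (std::cout << ... << xs); // C++17 fold expression over all arguments
}

template <class... Ts>
void println(Ts... xs)
{
    print_each(xs..., '\n'); // '\n' appended as the last argument
}

int main() { println("x = ", 42); }
```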
@@ -570,7 +570,7 @@ template <class Algo, class Reduced, class Output, class F>
 __device__ void fused_reduce(Output output, F f)
 {
     Algo::template run<Reduced>([&](auto out_idx, auto r) {
-        auto result = f(r);
+        auto result = f(r, out_idx);
         if constexpr(reduce::is_inner_storage<decltype(result)>{})
         {
             r.inner([&](auto& y, auto x) { y = x; })(output, result);
...
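Note: this pairs with the generate_reduce hunks earlier: the generated callable now receives out_idx as a second argument so non-reduced @param inputs can be indexed per output element (the inner_names[input] += "[out_idx]" rewrite), while unused_param keeps warnings quiet when no input needs the index. A sketch of what a generated lambda might look like after the change (hypothetical generated code, shown as comments):

```cpp
// Hypothetical output of generate_reduce after this change:
// auto reduce_fn = [](auto r, auto out_idx) {
//     // xs is a non-reduced @param, indexed per output element:
//     return r.reduce(op::sum{}, 0, op::id{})(ys) + xs[out_idx];
// };
```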
@@ -135,7 +135,7 @@ constexpr vec<vec_type<T>, N> vec_packed_at(T x, I i)
         return vec<T, N>{x};
     else
     {
-        MIGRAPHX_ASSERT((i + N) < vec_size<T>());
+        MIGRAPHX_ASSERT((i + N) <= vec_size<T>());
         vec<vec_type<T>, N> result = {0};
         for(int j = 0; j < N; j++)
         {
...
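Note: the old bound was off by one: reading N lanes starting at i touches indices i through i+N-1, so the in-bounds condition is i+N <= vec_size<T>(). A worked check:

```cpp
// With vec_size<T>() == 4 and N == 2, i == 2 reads lanes {2, 3}: the last
// valid chunk. The old assert (2 + 2) < 4 rejected it; <= accepts it while
// still rejecting i == 3, which would read lane 4 out of bounds.
static_assert(2 + 2 <= 4, "last full chunk is in bounds");
static_assert(!(3 + 2 <= 4), "i == 3 would read past the end");
```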
@@ -197,10 +197,14 @@ struct mlir_program
             result = mlirF64TypeGet(ctx.get());
         else if(as.is_integral())
         {
-            if(as.is_signed())
-                result = mlirIntegerTypeSignedGet(ctx.get(), as.size() * 8);
-            else
-                result = mlirIntegerTypeGet(ctx.get(), as.size() * 8);
+            // Note: rocMLIR uses signless integer types for tensor types. These
+            // translate to a signed implementation for the currently supported
+            // operations.
+            if(as.is_unsigned())
+            {
+                MIGRAPHX_THROW("Unsupported type: " + std::to_string(as.type_enum()));
+            }
+            result = mlirIntegerTypeGet(ctx.get(), as.size() * 8);
         }
         else
             MIGRAPHX_THROW("Unsupported type: " + std::to_string(as.type_enum()));
@@ -320,7 +324,8 @@ struct mlir_program
                                      std::string,
                                      value,
                                      std::vector<value>,
-                                     MlirType>;
+                                     MlirType,
+                                     MlirAttribute>;
     using named_attribute_t = std::pair<std::string_view, attribute_t>;

     MlirNamedAttribute name_attribute(const named_attribute_t& na) const
@@ -477,13 +482,17 @@ struct mlir_program
     {
         if(ins->name() == "@return")
             return "func.return";
+        if(ins->name() == "@literal")
+        {
+            return "tosa.const";
+        }
         return "migraphx." + ins->name();
     }

     static value get_operator_value(const operation& op)
     {
         auto v = op.to_value();
-        if(op.name() == "convolution")
+        if(op.name() == "convolution" or op.name() == "quant_convolution")
         {
             // Adjust symetrical padding
             if(v.at("padding").size() == v.at("stride").size())
@@ -528,11 +537,24 @@ struct mlir_program
         {
             if(ins->name() == "@param")
                 continue;
+            if(ins->name() == "contiguous")
+            {
+                ins_map[ins] = ins_map[ins->inputs().at(0)];
+                continue;
+            }
             auto name = get_name(ins);
             auto ops = create_operation_state(name);
             ops.add_attribute_value(get_operator_value(ins->get_operator()));
             if(ins->name() != "@return")
                 ops.add_results({get_shape(ins)});
+            if(ins->name() == "@literal")
+            {
+                literal r = ins->get_literal();
+                MlirType tensor_type = make_tensor(ins->get_shape());
+                MlirAttribute mlir_value_attr =
+                    mlirDenseElementsAttrRawBufferGet(tensor_type, r.get_shape().bytes(), r.data());
+                ops.add_attributes({{"value", mlir_value_attr}});
+            }
             if(ins->name() == "convolution" or ins->name() == "dot")
             {
                 pp =
@@ -735,12 +757,13 @@ code_object_op compile_mlir(const context&, module m, const std::vector<instruction_ref>& inputs)
 {
     adjust_param_shapes(m, inputs);
     const bool trace = enabled(MIGRAPHX_TRACE_MLIR{});
-    if(trace)
-        std::cout << m << std::endl;

     // set mutex while llvm thread support is disabled.
     static std::mutex g_mlirc_mutex; // NOLINT
     const std::lock_guard<std::mutex> lock(g_mlirc_mutex);

+    if(trace)
+        std::cout << m << std::endl;
+
     mlir_program mp;
     mp.find_target();
     mp.parse(m);
...
@@ -55,24 +55,15 @@ const std::unordered_set<std::string>& get_rocblas_fp32_archs()

 bool get_compute_fp32_flag()
 {
-    bool compute_fp32 = false;
-#if ROCBLAS_VERSION_MAJOR >= 2 && ROCBLAS_VERSION_MINOR >= 38
     const auto device_name = trim(split_string(get_device_name(), ':').front());
-    if(contains(get_rocblas_fp32_archs(), device_name))
-        compute_fp32 = true;
-#endif
-    return compute_fp32;
+    return contains(get_rocblas_fp32_archs(), device_name);
 }

 bool get_int8_x4_format(context& ctx)
 {
-    bool int8_x4_format = true;
-#if ROCBLAS_VERSION_MAJOR >= 2 && ROCBLAS_VERSION_MINOR >= 38
     rocblas_gemm_flags flag;
     rocblas_query_int8_layout_flag(ctx.get_stream().get_rocblas(), &flag);
-    int8_x4_format = (flag == rocblas_gemm_flags_pack_int8x4);
-#endif
-    return int8_x4_format;
+    return flag == rocblas_gemm_flags_pack_int8x4;
 }

 } // namespace gpu
 } // namespace MIGRAPHX_INLINE_NS
...
@@ -74,7 +74,6 @@ inline namespace MIGRAPHX_INLINE_NS {
 namespace gpu {

 MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_SCHEDULE_PASS)
-MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_POINTWISE_FUSION)
 MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_REDUCE_FUSION)
 MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_NHWC)

 struct id_pass
@@ -104,12 +103,12 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_options& options) const
     // clang-format off
     return
     {
-        enable_pass(options.split_single_dyn_dim, split_single_dyn_dim{}),
-        enable_pass(options.split_single_dyn_dim, dead_code_elimination{}),
+        split_single_dyn_dim{},
+        dead_code_elimination{},
         normalize_ops{},
         dead_code_elimination{},
         simplify_qdq{},
-        rewrite_quantization{},
+        enable_pass(not mlir_enabled(), rewrite_quantization{}),
         dead_code_elimination{},
         eliminate_data_type{unsupported_types, shape::type_t::float_type},
         simplify_reshapes{},
@@ -133,11 +132,11 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_options& options) const
         fuse_ck_gemm_softmax_gemm{&ctx},
         dead_code_elimination{},
         optimize_module{},
-        enable_pass(not enabled(MIGRAPHX_DISABLE_POINTWISE_FUSION{}), fuse_pointwise{}),
+        fuse_pointwise{},
         dead_code_elimination{},
         enable_pass(not enabled(MIGRAPHX_DISABLE_REDUCE_FUSION{}), fuse_reduce{}),
         dead_code_elimination{},
-        fuse_mlir{&ctx},
+        enable_pass(mlir_enabled(), fuse_mlir{&ctx}),
         dead_code_elimination{},
         fuse_ck{&ctx},
         dead_code_elimination{},
...
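Note: enable_pass substitutes a no-op when its condition is false, so the pass list stays declarative; combined with mlir_enabled(), a single runtime check decides whether quantization is rewritten on the MIGraphX side or left to the MLIR kernels. A hedged sketch of the idiom, assuming id_pass (declared above in target.cpp) is a do-nothing pass:

```cpp
// Illustrative only; the real enable_pass lives in MIGraphX's pass
// utilities and its exact signature is assumed here.
template <class Pass>
migraphx::pass enable_pass(bool condition, Pass p)
{
    if(condition)
        return p;     // run the real pass
    return id_pass{}; // otherwise a no-op placeholder keeps the list's shape
}
```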
@@ -46,6 +46,7 @@ std::vector<std::string> get_op_parsers()
                    op_parser_map().end(),
                    std::back_inserter(result),
                    [&](auto&& p) { return p.first; });
+    std::sort(result.begin(), result.end());
     return result;
 }
...
@@ -22,6 +22,7 @@
  * THE SOFTWARE.
  */
 #include <migraphx/tf/tf_parser.hpp>
+#include <migraphx/tf/op_parser.hpp>
 #include <iostream>
 #include <fstream>
 #include <unordered_map>
@@ -62,5 +63,7 @@ program parse_tf(const std::string& name, const tf_options& options)
     return std::move(parser.prog);
 }

+std::vector<std::string> get_tf_operators() { return tf::get_op_parsers(); }
+
 } // namespace MIGRAPHX_INLINE_NS
 } // namespace migraphx
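Note: together with the std::sort added in get_op_parsers above, the new entry point gives API users a stable, alphabetized list of supported TensorFlow operators. A hedged usage sketch (assuming get_tf_operators is exposed from the public tf header, as the hunk suggests):

```cpp
#include <iostream>
#include <migraphx/tf.hpp>

int main()
{
    // Prints one supported TF operator name per line, in sorted order.
    for(const auto& name : migraphx::get_tf_operators())
        std::cout << name << "\n";
}
```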
@@ -48,6 +48,7 @@ add_api_test(assign test_assign.cpp ${TEST_ONNX_DIR})
 add_api_test(compile_options test_compile_options.cpp ${TEST_ONNX_DIR})
 add_api_test(lookup test_lookup.cpp ${TEST_ONNX_DIR})
 add_api_test(module_construct test_module_construct.cpp ${TEST_ONNX_DIR})
+add_api_test(dynamic_shape test_dynamic_shape.cpp ${TEST_ONNX_DIR})
 add_api_test(ref test_cpu.cpp ${TEST_ONNX_DIR})
 add_api_test(save_load test_save_load.cpp ${TEST_ONNX_DIR})
 add_api_test(op test_op_construct.cpp ${TEST_ONNX_DIR})
...
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/migraphx.h>
#include <migraphx/migraphx.hpp>
#include "test.hpp"
TEST_CASE(create_dynamic_dimensions)
{
migraphx::dynamic_dimension dd0{1, 4};
EXPECT(not dd0.is_fixed());
migraphx::dynamic_dimension dd1{4, 4};
EXPECT(dd1.is_fixed());
migraphx::optimals opts{1, 2, 4};
migraphx::dynamic_dimension dd2{1, 4, opts};
migraphx::dynamic_dimensions dyn_dims0{dd0, dd1, dd2};
CHECK(bool{dyn_dims0[0] == dd0});
CHECK(bool{dyn_dims0[1] == dd1});
CHECK(bool{dyn_dims0[2] == dd2});
CHECK(bool{dyn_dims0[2] != dd0});
EXPECT(dyn_dims0.size() == 3);
}
TEST_CASE(create_dynamic_shape)
{
migraphx::dynamic_dimensions dyn_dims(migraphx::dynamic_dimension{1, 4},
migraphx::dynamic_dimension{78, 92},
migraphx::dynamic_dimension{1, 4, {1, 4}});
migraphx::shape dyn_shape{migraphx_shape_float_type, dyn_dims};
CHECK(bool{dyn_shape.dynamic()});
CHECK(bool{dyn_shape.dyn_dims()[0] == migraphx::dynamic_dimension{1, 4}});
migraphx::shape static_shape{migraphx_shape_float_type, {3, 8}};
EXPECT(not static_shape.dynamic());
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }