"vscode:/vscode.git/clone" did not exist on "94db0f8a4dfe96cd8b0914c6f7a1a843947d3ca1"
Commit 59386637 authored by Paul's avatar Paul
Browse files

Merge branch 'develop' into simplify-more-reshapes

parents 6690765c ed6542ee
@@ -62,6 +62,7 @@ namespace py = pybind11;
     PYBIND11_MODULE(__VA_ARGS__) \
     MIGRAPHX_POP_WARNING
+#define MIGRAPHX_PYTHON_GENERATE_SHAPE_ENUM(x, t) .value(#x, migraphx::shape::type_t::x)
 
 namespace migraphx {
 migraphx::value to_value(py::kwargs kwargs);
@@ -235,7 +236,8 @@ migraphx::shape to_shape(const py::buffer_info& info)
 MIGRAPHX_PYBIND11_MODULE(migraphx, m)
 {
-    py::class_<migraphx::shape>(m, "shape")
+    py::class_<migraphx::shape> shape_cls(m, "shape");
+    shape_cls
         .def(py::init([](py::kwargs kwargs) {
             auto v = migraphx::to_value(kwargs);
             auto t = migraphx::shape::parse_type(v.get("type", "float"));
@@ -261,6 +263,9 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
         .def("__ne__", std::not_equal_to<migraphx::shape>{})
         .def("__repr__", [](const migraphx::shape& s) { return migraphx::to_string(s); });
+    py::enum_<migraphx::shape::type_t>(shape_cls, "type_t")
+        MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_PYTHON_GENERATE_SHAPE_ENUM);
 
     py::class_<migraphx::argument>(m, "argument", py::buffer_protocol())
         .def_buffer([](migraphx::argument& x) -> py::buffer_info { return to_buffer_info(x); })
         .def(py::init([](py::buffer b) {
@@ -282,7 +287,9 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
     py::class_<migraphx::target>(m, "target");
 
-    py::class_<migraphx::instruction_ref>(m, "instruction_ref");
+    py::class_<migraphx::instruction_ref>(m, "instruction_ref")
+        .def("shape", [](migraphx::instruction_ref i) { return i->get_shape(); })
+        .def("op", [](migraphx::instruction_ref i) { return i->get_operator(); });
 
     py::class_<migraphx::module, std::unique_ptr<migraphx::module, py::nodelete>>(m, "module")
         .def("print", [](const migraphx::module& mm) { std::cout << mm << std::endl; })
...
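Side note on the macro added above: `MIGRAPHX_SHAPE_VISIT_TYPES` is an X-macro that applies a generator macro to every (enum name, C++ type) pair, so the one-line `MIGRAPHX_PYTHON_GENERATE_SHAPE_ENUM` definition expands into one `.value(...)` registration per shape type on the `py::enum_`. A minimal, self-contained sketch of the same pattern, using a hypothetical two-entry type list rather than the real one:

```cpp
#include <cstdint>
#include <iostream>
#include <string>

// Hypothetical reduced type list; the real MIGRAPHX_SHAPE_VISIT_TYPES covers
// every shape type as (enum_name, cpp_type) pairs.
#define VISIT_TYPES(m) m(float_type, float) m(int32_type, int32_t)

enum class type_t
{
    float_type,
    int32_type
};

// Generator macro: one switch case per visited type, analogous to how
// MIGRAPHX_PYTHON_GENERATE_SHAPE_ENUM produces one .value(#x, ...) per type.
#define GENERATE_CASE(x, t) \
    case type_t::x: return #x;

std::string type_name(type_t v)
{
    switch(v)
    {
        VISIT_TYPES(GENERATE_CASE)
    }
    return "unknown";
}

int main() { std::cout << type_name(type_t::int32_type) << "\n"; } // prints "int32_type"
```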
@@ -361,29 +361,26 @@ std::size_t shape::index(std::size_t i) const
     }
 }
 
-std::vector<std::size_t> shape::multi(std::size_t i) const
+std::vector<std::size_t> shape::multi(std::size_t idx) const
 {
-    assert(this->standard());
+    assert(idx < elements());
     std::vector<std::size_t> indices(lens().size());
-    multi_copy(i, indices.data(), indices.data() + lens().size());
+    multi_copy(idx, indices.data(), indices.data() + lens().size());
     return indices;
 }
 
-void shape::multi_copy(std::size_t i, std::size_t* start, const std::size_t* end) const
+void shape::multi_copy(std::size_t idx, std::size_t* start, const std::size_t* end) const
 {
-    assert(this->standard());
+    size_t tidx = idx;
     (void)end;
+    assert(idx < elements());
     assert(lens().size() <= (end - start));
-    std::transform(strides().begin(),
-                   strides().end(),
-                   lens().begin(),
-                   start,
-                   [&](std::size_t stride, std::size_t len) {
-                       assert(len > 0 and stride > 0);
-                       return (i / stride) % len;
-                   });
+    for(size_t ii = lens().size() - 1; ii > 0; ii--)
+    {
+        *(start + ii) = tidx % lens()[ii];
+        tidx = tidx / lens()[ii];
+    }
+    *start = tidx;
 }
 
 bool shape::packed() const
...
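Why this rewrite lifts the `standard()` restriction: the old code divided by strides, which is only valid for a standard (packed, row-major) layout, while the new loop derives the multi-index purely from `lens()` by repeated modulo/divide from the innermost dimension outward. A standalone sketch of that arithmetic (the `flat_to_multi` helper is hypothetical, not part of MIGraphX):

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <iostream>
#include <numeric>
#include <vector>

// Decompose a flat element index into a multi-index using only the dimension
// lengths, mirroring the loop in the new shape::multi_copy.
std::vector<std::size_t> flat_to_multi(std::size_t idx, const std::vector<std::size_t>& lens)
{
    const std::size_t elements =
        std::accumulate(lens.begin(), lens.end(), std::size_t{1}, std::multiplies<>{});
    assert(idx < elements);
    (void)elements; // silence unused-variable warnings in release builds
    std::vector<std::size_t> indices(lens.size());
    for(std::size_t i = lens.size() - 1; i > 0; i--)
    {
        indices[i] = idx % lens[i]; // position within this dimension
        idx /= lens[i];             // carry the remainder outward
    }
    indices[0] = idx;
    return indices;
}

int main()
{
    // For lens {2, 3, 4}: 17 = 1*(3*4) + 1*4 + 1, so the multi-index is {1, 1, 1}.
    for(auto x : flat_to_multi(17, {2, 3, 4}))
        std::cout << x << ' ';
    std::cout << '\n';
}
```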
@@ -28,6 +28,7 @@
 #include <migraphx/functional.hpp>
 #include <migraphx/make_op.hpp>
 #include <migraphx/ranges.hpp>
+#include <migraphx/matcher.hpp>
 
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
@@ -67,6 +68,37 @@ has_one_dyn_dim(const std::unordered_map<std::string, shape>& param_shapes)
             dds_it->max};
 }
 
+namespace {
+struct find_static_2in_broadcasts
+{
+    // Convert a 2-input static-shape broadcast/multibroadcast into the 1-input
+    // version. Some compiler passes (e.g. simplify_algebra) only support the
+    // 1-input versions of the broadcasting operators.
+    auto matcher() const
+    {
+        return match::broadcast(match::nargs(2),
+                                match::arg(0)(match::static_shape()),
+                                match::arg(1)(match::static_shape()));
+    }
+
+    void apply(module& m, const match::matcher_result& mr) const
+    {
+        auto ins          = mr.result;
+        auto out_lens     = ins->get_shape().lens();
+        auto broadcast_op = ins->get_operator();
+        if(broadcast_op.name() == "broadcast")
+        {
+            broadcast_op.from_value({{"out_lens", out_lens}});
+        }
+        else
+        {
+            broadcast_op.from_value({{"out_lens", out_lens}, {"out_dyn_dims", {}}});
+        }
+        m.replace_instruction(ins, broadcast_op, ins->inputs().at(0));
+    }
+};
+} // namespace
+
 /**
  * Makes all the shapes in the dynamic_dimension range.
  * Probably won't work for `if` and `loop` instructions, depending on how the submodules for those
@@ -97,6 +129,7 @@ void split_single_dyn_dim::apply(module_pass_manager& mpm) const
             dd_check->dyn_param_str, migraphx::shape{dyn_param_shape.type(), static_lens});
         auto outputs = submod->add_instructions(mm, map_ins);
         submod->add_return({outputs});
+        match::find_matches(*submod, find_static_2in_broadcasts{});
         submodules.push_back(submod);
     }
     // redirect to select_module operator and return
...
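For readers unfamiliar with the two forms: the 2-input `broadcast`/`multibroadcast` takes its output dimensions from a second input (useful while shapes are still dynamic), whereas the 1-input form carries `out_lens` as an operator attribute. Once a submodule's shapes are static, the pass above rewrites the former into the latter. A hedged sketch of both forms with the migraphx C++ API (illustrative shapes, not taken from the patch):

```cpp
#include <cstddef>
#include <vector>
#include <migraphx/make_op.hpp>
#include <migraphx/module.hpp>
#include <migraphx/shape.hpp>

// Illustrative only: the 2-input form reads its output dims from the second
// input; the 1-input form bakes them in as the out_lens attribute.
void broadcast_forms(migraphx::module& m)
{
    auto x = m.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 1}});
    auto y = m.add_parameter("y", migraphx::shape{migraphx::shape::float_type, {3, 4}});

    // 2-input version: output lens come from y's (static) shape.
    auto two_in = m.add_instruction(migraphx::make_op("multibroadcast"), x, y);

    // 1-input version after the rewrite: out_lens is an operator attribute.
    std::vector<std::size_t> out_lens = {3, 4};
    auto one_in =
        m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", out_lens}}), x);
    (void)two_in;
    (void)one_in;
}
```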
@@ -33,7 +33,7 @@ if(NOT TARGET MIOpen)
     message(SEND_ERROR "Cant find miopen")
 endif()
 
-set(MIGRAPHX_USE_HIPRTC OFF CACHE BOOL "Use hipRTC APIs")
+set(MIGRAPHX_USE_HIPRTC OFF CACHE BOOL "Use hipClang APIs")
 
 include(Embed)
 file(GLOB KERNEL_FILES ${CONFIGURE_DEPENDS}
...
@@ -280,6 +280,14 @@ std::string generate_reduce(const module& m, const std::string& name)
                not input->get_shape().broadcasted();
     });
     auto inner_names = names;
+    for(auto input : ins->inputs())
+    {
+        if(input->name() != "@param")
+            continue;
+        if(contains(tensors, input))
+            continue;
+        inner_names[input] += "[out_idx]";
+    }
     for(auto input : tensors)
         inner_names[input] += "_lambda_param";
     auto call_function =
@@ -308,6 +316,8 @@ std::string generate_reduce(const module& m, const std::string& name)
     });
     f.set_attributes({"__device__", "__attribute__((const))"}).set_generic_types(m).set_name(name);
     f.add_generic_param("r");
+    f.add_generic_param("out_idx");
+    f.unused_param("out_idx");
     g.create_function(f);
     return g.str();
 }
...
@@ -120,12 +120,10 @@ void hip_visit_all_impl(const shape& s, F f, V&& v, Ts&&... xs)
     if(not std::all_of(
            types.begin(), types.end(), [&](migraphx::shape::type_t t) { return t == s.type(); }))
         MIGRAPHX_THROW("Types must be the same");
-    std::initializer_list<index_int> ranks = {
-        static_cast<index_int>(get_shape(xs).lens().size())...};
-    if(not std::all_of(
-           ranks.begin(), ranks.end(), [&](index_int r) { return r == s.lens().size(); }))
+    std::initializer_list<index_int> ranks = {static_cast<index_int>(get_shape(xs).ndim())...};
+    if(not std::all_of(ranks.begin(), ranks.end(), [&](index_int r) { return r == s.ndim(); }))
         MIGRAPHX_THROW("Ranks must be the same");
-    visit_tensor_size(s.lens().size(), [&](auto ndim) {
+    visit_tensor_size(s.ndim(), [&](auto ndim) {
         s.visit_type(hip_visitor([&](auto as) { v(f(xs, ndim, as)...); }));
     });
 }
@@ -133,12 +131,10 @@ void hip_visit_all_impl(const shape& s, F f, V&& v, Ts&&... xs)
 
 template <class V, class F, class... Ts>
 void hip_visit_views_impl(const shape& s, F f, V&& v, Ts&&... xs)
 {
-    std::initializer_list<index_int> ranks = {
-        static_cast<index_int>(get_shape(xs).lens().size())...};
-    if(not std::all_of(
-           ranks.begin(), ranks.end(), [&](index_int r) { return r == s.lens().size(); }))
+    std::initializer_list<index_int> ranks = {static_cast<index_int>(get_shape(xs).ndim())...};
+    if(not std::all_of(ranks.begin(), ranks.end(), [&](index_int r) { return r == s.ndim(); }))
         MIGRAPHX_THROW("Ranks must be the same");
-    visit_tensor_size(s.lens().size(), [&](auto ndim) { v(f(xs, ndim)...); });
+    visit_tensor_size(s.ndim(), [&](auto ndim) { v(f(xs, ndim)...); });
 }
 
 template <class F>
...
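The rank check relies on pack expansion into a braced initializer list: `{static_cast<index_int>(get_shape(xs).ndim())...}` evaluates `ndim()` once per variadic argument, and `std::all_of` then compares every entry against the reference shape. A minimal sketch of the same idiom with plain containers (the `same_rank` helper is hypothetical, using `std::size_t` instead of `index_int`):

```cpp
#include <algorithm>
#include <cstddef>
#include <initializer_list>
#include <iostream>
#include <vector>

// Check that every container in the pack has the same "rank" (size) as the
// reference, via pack expansion into an initializer_list.
template <class... Ts>
bool same_rank(std::size_t rank, const Ts&... xs)
{
    std::initializer_list<std::size_t> ranks = {xs.size()...}; // one entry per argument
    return std::all_of(ranks.begin(), ranks.end(), [&](std::size_t r) { return r == rank; });
}

int main()
{
    std::vector<int> a{1, 2, 3};
    std::vector<int> b{4, 5, 6};
    std::vector<int> c{7, 8};
    std::cout << same_rank(3, a, b) << "\n"; // 1
    std::cout << same_rank(3, a, c) << "\n"; // 0
}
```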
@@ -26,5 +26,6 @@ file(GLOB GPU_DRIVER_SRCS ${CONFIGURE_DEPENDS} ${CMAKE_CURRENT_SOURCE_DIR}/*.cpp
 add_executable(gpu-driver
     ${GPU_DRIVER_SRCS}
 )
+rocm_clang_tidy_check(gpu-driver)
 target_include_directories(gpu-driver PRIVATE include)
 target_link_libraries(gpu-driver PRIVATE migraphx_gpu)
@@ -44,7 +44,7 @@ struct auto_register_action
     template <class T>
     static void apply()
     {
-        auto name = get_type_name<T>();
+        const auto& name = get_type_name<T>();
         register_action(name.substr(name.rfind("::") + 2),
                         [](auto&&... xs) { T::apply(std::forward<decltype(xs)>(xs)...); });
     }
...
@@ -140,13 +140,8 @@ void gemm_impl(context& ctx,
         compute_type = rocblas_datatype_f32_r;
     }
 
-#if ROCBLAS_VERSION_MAJOR >= 2 && ROCBLAS_VERSION_MINOR >= 38
     rocblas_gemm_flags flag =
         int8_x4_format ? rocblas_gemm_flags_pack_int8x4 : rocblas_gemm_flags_none;
-#else
-    (void)int8_x4_format;
-    int flag = 0;
-#endif
 
     auto a_lens = args[0].get_shape().lens();
     auto b_lens = args[1].get_shape().lens();
...
@@ -29,6 +29,7 @@
 #include <migraphx/literal.hpp>
 #include <migraphx/check_shapes.hpp>
 #include <migraphx/functional.hpp>
+#include <migraphx/dyn_output.hpp>
 #include <utility>
 
 namespace migraphx {
@@ -112,7 +113,7 @@ struct hip_copy_to_gpu
     std::string name() const { return "hip::copy_to_gpu"; }
     shape compute_shape(std::vector<shape> inputs) const
     {
-        check_shapes{inputs, *this}.has(1, 2).same_type();
+        check_shapes{inputs, *this, true}.has(1, 2).same_type();
         return inputs.at(0);
     }
     argument compute(context& ctx, const shape&, const std::vector<argument>& args) const
@@ -121,6 +122,10 @@ struct hip_copy_to_gpu
         if(args.size() == 1)
             return input;
         argument result = args[1].share();
+        if(result.get_shape().dynamic())
+        {
+            result = result.reshape(args[0].get_shape());
+        }
         gpu_copy(ctx, input, result);
         // Associate the input since it was registered with hip
         return {result.get_shape(), [input, result]() mutable { return result.data(); }};
@@ -138,19 +143,24 @@ struct hip_copy_from_gpu
     std::string name() const { return "hip::copy_from_gpu"; }
     shape compute_shape(std::vector<shape> inputs) const
     {
-        check_shapes{inputs, *this}.has(1, 2).same_type();
+        check_shapes{inputs, *this, true}.has(1, 2).same_type();
        return inputs.at(0);
     }
     argument
-    compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const
+    compute(context& ctx, const dyn_output& dyn_out, const std::vector<argument>& args) const
     {
         if(args.size() == 1)
         {
-            argument result = allocate_gpu(output_shape, true);
+            argument result = allocate_gpu(dyn_out.computed_shape, true);
             gpu_copy(ctx, args[0], result);
             return result;
         }
-        copy_from_gpu(ctx, args[0], args[1]);
+        argument input = args[0].share();
+        if(input.get_shape().dynamic())
+        {
+            input = input.reshape(args[1].get_shape());
+        }
+        copy_from_gpu(ctx, input, args[1]);
         return args[1];
     }
     std::ptrdiff_t output_alias(const std::vector<shape>& args) const
...
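Both copy ops now tolerate a preallocated buffer whose declared shape is dynamic: the buffer is shared and re-viewed with the concrete static shape of the other operand before the copy runs. A small sketch of that decision, grounded in the `share()`/`reshape()`/`dynamic()` calls used in the patch (the `copy_ready` helper name is hypothetical):

```cpp
#include <migraphx/argument.hpp>
#include <migraphx/shape.hpp>

// Pick a copy-ready view of a preallocated buffer: if it was declared with a
// dynamic shape, re-view it with the concrete shape of the other operand.
migraphx::argument copy_ready(const migraphx::argument& buffer, const migraphx::shape& concrete)
{
    migraphx::argument result = buffer.share();
    if(result.get_shape().dynamic())
        result = result.reshape(concrete);
    return result;
}
```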
@@ -244,13 +244,13 @@ __device__ void print_once(Ts... xs)
 template <class... Ts>
 __device__ void println(Ts... xs)
 {
-    print_each(&coutln, xs...);
+    print_each(&cout, xs..., '\n');
 }
 
 template <class... Ts>
 __device__ void println_once(Ts... xs)
 {
-    print_each_once(&coutln, xs...);
+    print_each_once(&cout, xs..., '\n');
 }
 
 } // namespace migraphx
...
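The fix drops the separate `coutln` printer: appending `'\n'` to the argument pack makes the newline just one more printed value, flowing through the same per-argument path as everything else. A host-side sketch of the idiom with a C++17 fold (simplified stand-ins for the device `print_each`/`cout`):

```cpp
#include <iostream>

// Print every argument through the same sink; a '\n' appended to the pack is
// just one more argument, which is exactly how the new println forwards it.
template <class Stream, class... Ts>
void print_each(Stream* os, Ts... xs)
{
    ((*os << xs), ...); // C++17 fold expression over the pack
}

template <class... Ts>
void println(Ts... xs)
{
    print_each(&std::cout, xs..., '\n');
}

int main() { println("result = ", 42); }
```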
@@ -570,7 +570,7 @@ template <class Algo, class Reduced, class Output, class F>
 __device__ void fused_reduce(Output output, F f)
 {
     Algo::template run<Reduced>([&](auto out_idx, auto r) {
-        auto result = f(r);
+        auto result = f(r, out_idx);
         if constexpr(reduce::is_inner_storage<decltype(result)>{})
         {
             r.inner([&](auto& y, auto x) { y = x; })(output, result);
...
@@ -135,7 +135,7 @@ constexpr vec<vec_type<T>, N> vec_packed_at(T x, I i)
         return vec<T, N>{x};
     else
     {
-        MIGRAPHX_ASSERT((i + N) < vec_size<T>());
+        MIGRAPHX_ASSERT((i + N) <= vec_size<T>());
         vec<vec_type<T>, N> result = {0};
         for(int j = 0; j < N; j++)
         {
...
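The assertion change fixes an off-by-one: reading `N` lanes starting at offset `i` touches indices `i` through `i + N - 1`, so the last valid start is `vec_size<T>() - N` and the correct precondition is `i + N <= vec_size<T>()`; the old strict `<` wrongly rejected the final full chunk. A small standalone illustration (hypothetical `read_chunk` over a plain array):

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <iostream>

// Reading N elements starting at i touches [i, i + N); the last valid start is
// Size - N, so the precondition is i + N <= Size (a strict < bans the final chunk).
template <std::size_t N, std::size_t Size>
std::array<int, N> read_chunk(const std::array<int, Size>& data, std::size_t i)
{
    assert(i + N <= Size);
    std::array<int, N> out{};
    for(std::size_t j = 0; j < N; j++)
        out[j] = data[i + j];
    return out;
}

int main()
{
    std::array<int, 4> v{1, 2, 3, 4};
    auto chunk = read_chunk<2>(v, 2); // last full chunk; the old '<' assert would fire here
    std::cout << chunk[0] << " " << chunk[1] << "\n"; // 3 4
}
```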
@@ -197,9 +197,13 @@ struct mlir_program
             result = mlirF64TypeGet(ctx.get());
         else if(as.is_integral())
         {
-            if(as.is_signed())
-                result = mlirIntegerTypeSignedGet(ctx.get(), as.size() * 8);
-            else
-                result = mlirIntegerTypeGet(ctx.get(), as.size() * 8);
+            // Note: rocMLIR uses signless integer types for tensor types. These
+            // translate to a signed implementation for the currently supported
+            // operations.
+            if(as.is_unsigned())
+            {
+                MIGRAPHX_THROW("Unsupported type: " + std::to_string(as.type_enum()));
+            }
+            result = mlirIntegerTypeGet(ctx.get(), as.size() * 8);
         }
         else
@@ -483,7 +487,7 @@ struct mlir_program
 static value get_operator_value(const operation& op)
 {
     auto v = op.to_value();
 
-    if(op.name() == "convolution")
+    if(op.name() == "convolution" or op.name() == "quant_convolution")
     {
         // Adjust symmetrical padding
         if(v.at("padding").size() == v.at("stride").size())
...
@@ -55,24 +55,15 @@ const std::unordered_set<std::string>& get_rocblas_fp32_archs()
 bool get_compute_fp32_flag()
 {
-    bool compute_fp32 = false;
-#if ROCBLAS_VERSION_MAJOR >= 2 && ROCBLAS_VERSION_MINOR >= 38
     const auto device_name = trim(split_string(get_device_name(), ':').front());
-    if(contains(get_rocblas_fp32_archs(), device_name))
-        compute_fp32 = true;
-#endif
-    return compute_fp32;
+    return contains(get_rocblas_fp32_archs(), device_name);
 }
 
 bool get_int8_x4_format(context& ctx)
 {
-    bool int8_x4_format = true;
-#if ROCBLAS_VERSION_MAJOR >= 2 && ROCBLAS_VERSION_MINOR >= 38
     rocblas_gemm_flags flag;
     rocblas_query_int8_layout_flag(ctx.get_stream().get_rocblas(), &flag);
-    int8_x4_format = (flag == rocblas_gemm_flags_pack_int8x4);
-#endif
-    return int8_x4_format;
+    return flag == rocblas_gemm_flags_pack_int8x4;
 }
 
 } // namespace gpu
 } // namespace MIGRAPHX_INLINE_NS
...
@@ -72,7 +72,6 @@ inline namespace MIGRAPHX_INLINE_NS {
 namespace gpu {
 
 MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_SCHEDULE_PASS)
-MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_POINTWISE_FUSION)
 MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_REDUCE_FUSION)
 MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_NHWC)
 
 struct id_pass
@@ -129,7 +128,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
         dead_code_elimination{},
         auto_contiguous{},
         optimize_module{},
-        enable_pass(not enabled(MIGRAPHX_DISABLE_POINTWISE_FUSION{}), fuse_pointwise{}),
+        fuse_pointwise{},
         dead_code_elimination{},
         enable_pass(not enabled(MIGRAPHX_DISABLE_REDUCE_FUSION{}), fuse_reduce{}),
         dead_code_elimination{},
...
@@ -110,7 +110,7 @@ function(add_test_executable TEST_NAME)
     add_test_command(${TEST_NAME} ${TEST_COMMAND})
     add_dependencies(tests ${TEST_NAME})
     add_dependencies(check ${TEST_NAME})
-    target_link_libraries(${TEST_NAME} migraphx migraphx_onnx)
+    target_link_libraries(${TEST_NAME} migraphx migraphx_onnx migraphx_ref)
     target_include_directories(${TEST_NAME} PUBLIC include)
 endfunction(add_test_executable)
@@ -163,7 +163,7 @@ foreach(ONNX_TEST ${ONNX_TESTS})
     set(TEST_NAME test_${BASE_NAME})
     add_executable(${TEST_NAME} ${ONNX_TEST})
     rocm_clang_tidy_check(${TEST_NAME})
-    target_link_libraries(${TEST_NAME} migraphx_onnx)
+    target_link_libraries(${TEST_NAME} migraphx_onnx migraphx_ref)
     target_include_directories(${TEST_NAME} PUBLIC include)
     add_test(NAME ${TEST_NAME} COMMAND $<TARGET_FILE:${TEST_NAME}> WORKING_DIRECTORY ${TEST_ONNX_DIR})
     add_dependencies(tests ${TEST_NAME})
...
@@ -25,7 +25,7 @@ function(add_api_test TEST_NAME TEST_SRC TEST_DIR)
     set(NAME test_api_${TEST_NAME})
     add_executable(${NAME} EXCLUDE_FROM_ALL ${TEST_SRC})
     rocm_clang_tidy_check(${NAME})
-    target_link_libraries(${NAME} migraphx_c migraphx)
+    target_link_libraries(${NAME} migraphx_c migraphx migraphx_all_targets)
     target_include_directories(${NAME} PUBLIC ../include)
     add_test(NAME ${NAME} COMMAND $<TARGET_FILE:${NAME}> WORKING_DIRECTORY ${TEST_DIR})
     add_dependencies(tests ${NAME})
@@ -59,7 +59,7 @@ if(MIGRAPHX_ENABLE_GPU)
     list(APPEND CMAKE_PREFIX_PATH /opt/rocm)
     find_package(hip)
     add_api_test(gpu test_gpu.cpp ${TEST_ONNX_DIR})
-    target_link_libraries(test_api_gpu hip::host)
+    target_link_libraries(test_api_gpu)
     add_api_test(custom_op_gpu test_custom_op_gpu.cpp ${TEST_ONNX_DIR})
-    target_link_libraries(test_api_custom_op_gpu hip::host)
+    target_link_libraries(test_api_custom_op_gpu)
 endif()
@@ -329,4 +329,36 @@ TEST_CASE(all_scalar_input)
     EXPECT(p1 == p2);
 }
 
+TEST_CASE(no_input)
+{
+    migraphx::program p;
+    {
+        auto* mm = p.get_main_module();
+        migraphx::shape g_shape{migraphx::shape::int64_type, {1}, {0}};
+        migraphx::shape s_indices{migraphx::shape::int32_type, {3}};
+        std::vector<int> indices{3, 800, 800};
+        auto a0  = mm->add_literal(migraphx::literal{s_indices, indices});
+        auto a1  = mm->add_literal(migraphx::literal{g_shape, {1}});
+        int axis = 0;
+        auto out = mm->add_instruction(migraphx::make_op("gather", {{"axis", axis}}), a0, a1);
+        mm->add_return({out});
+    }
+    run_pass(p);
+
+    // This should NOT create a pointwise module if there are no inputs here.
+    migraphx::program p2;
+    {
+        auto* mm = p2.get_main_module();
+        migraphx::shape g_shape{migraphx::shape::int64_type, {1}, {0}};
+        migraphx::shape s_indices{migraphx::shape::int32_type, {3}};
+        std::vector<int> indices{3, 800, 800};
+        auto a0  = mm->add_literal(migraphx::literal{s_indices, indices});
+        auto a1  = mm->add_literal(migraphx::literal{g_shape, {1}});
+        int axis = 0;
+        auto out = mm->add_instruction(migraphx::make_op("gather", {{"axis", axis}}), a0, a1);
+        mm->add_return({out});
+    }
+
+    EXPECT(p == p2);
+}
+
 int main(int argc, const char* argv[]) { test::run(argc, argv); }
@@ -213,4 +213,37 @@ module {
     EXPECT(verify_mlir(m));
 }
 
+TEST_CASE(conv_int8_dequantize_quantize)
+{
+    const std::string mlir_output = R"__migraphx__(
+module {
+  func.func @main(%arg0: tensor<2x8x3x3xi8>, %arg1: tensor<1x8x4x4xi8>, %arg2: tensor<1x2x2x2xf32>, %arg3: tensor<1x2x2x2xi32>) -> tensor<1x2x2x2xi32> attributes {arch = "", kernel = "mixr"} {
+    %0 = migraphx.quant_convolution(%arg1, %arg0) {dilation = [1, 1], group = 1 : i64, padding = [0, 0, 0, 0], padding_mode = 0 : i64, stride = [1, 1]} : (tensor<1x8x4x4xi8>, tensor<2x8x3x3xi8>) -> tensor<1x2x2x2xi32>
+    %1 = migraphx.dequantizelinear(%0, %arg2, %arg3) : (tensor<1x2x2x2xi32>, tensor<1x2x2x2xf32>, tensor<1x2x2x2xi32>) -> tensor<1x2x2x2xf32>
+    %2 = migraphx.quantizelinear(%1, %arg2, %arg3) : (tensor<1x2x2x2xf32>, tensor<1x2x2x2xf32>, tensor<1x2x2x2xi32>) -> tensor<1x2x2x2xi32>
+    return %2 : tensor<1x2x2x2xi32>
+  }
+}
+)__migraphx__";
+    migraphx::module m;
+    auto x    = m.add_parameter("x", {migraphx::shape::int8_type, {1, 8, 4, 4}});
+    auto w    = m.add_parameter("w", {migraphx::shape::int8_type, {2, 8, 3, 3}});
+    auto conv = m.add_instruction(migraphx::make_op("quant_convolution"), x, w);
+    migraphx::shape ss{migraphx::shape::float_type, {1, 2, 2, 2}};
+    migraphx::shape sz{migraphx::shape::int32_type, {1, 2, 2, 2}};
+    auto input2  = m.add_parameter("x_scale", ss);
+    auto input3  = m.add_parameter("x_zero_point", sz);
+    auto dequant = m.add_instruction(migraphx::make_op("dequantizelinear"), conv, input2, input3);
+    auto r       = m.add_instruction(migraphx::make_op("quantizelinear"), dequant, input2, input3);
+    m.add_return({r});
+    auto s = migraphx::gpu::dump_mlir(m);
+    // Skip test if MLIR is not enabled
+    if(s.empty())
+        return;
+    CHECK(encode(s) == encode(mlir_output));
+    EXPECT(verify_mlir(m));
+}
+
 int main(int argc, const char* argv[]) { test::run(argc, argv); }