Commit 59386637 authored by Paul

Merge branch 'develop' into simplify-more-reshapes

parents 6690765c ed6542ee
......@@ -62,6 +62,7 @@ namespace py = pybind11;
PYBIND11_MODULE(__VA_ARGS__) \
MIGRAPHX_POP_WARNING
#define MIGRAPHX_PYTHON_GENERATE_SHAPE_ENUM(x, t) .value(#x, migraphx::shape::type_t::x)
namespace migraphx {
migraphx::value to_value(py::kwargs kwargs);
......@@ -235,7 +236,8 @@ migraphx::shape to_shape(const py::buffer_info& info)
MIGRAPHX_PYBIND11_MODULE(migraphx, m)
{
py::class_<migraphx::shape>(m, "shape")
py::class_<migraphx::shape> shape_cls(m, "shape");
shape_cls
.def(py::init([](py::kwargs kwargs) {
auto v = migraphx::to_value(kwargs);
auto t = migraphx::shape::parse_type(v.get("type", "float"));
......@@ -261,6 +263,9 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
.def("__ne__", std::not_equal_to<migraphx::shape>{})
.def("__repr__", [](const migraphx::shape& s) { return migraphx::to_string(s); });
py::enum_<migraphx::shape::type_t>(shape_cls, "type_t")
MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_PYTHON_GENERATE_SHAPE_ENUM);
py::class_<migraphx::argument>(m, "argument", py::buffer_protocol())
.def_buffer([](migraphx::argument& x) -> py::buffer_info { return to_buffer_info(x); })
.def(py::init([](py::buffer b) {
......@@ -282,7 +287,9 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
py::class_<migraphx::target>(m, "target");
py::class_<migraphx::instruction_ref>(m, "instruction_ref");
py::class_<migraphx::instruction_ref>(m, "instruction_ref")
.def("shape", [](migraphx::instruction_ref i) { return i->get_shape(); })
.def("op", [](migraphx::instruction_ref i) { return i->get_operator(); });
py::class_<migraphx::module, std::unique_ptr<migraphx::module, py::nodelete>>(m, "module")
.def("print", [](const migraphx::module& mm) { std::cout << mm << std::endl; })
......
......@@ -361,29 +361,26 @@ std::size_t shape::index(std::size_t i) const
}
}
std::vector<std::size_t> shape::multi(std::size_t i) const
std::vector<std::size_t> shape::multi(std::size_t idx) const
{
assert(this->standard());
assert(idx < elements());
std::vector<std::size_t> indices(lens().size());
multi_copy(i, indices.data(), indices.data() + lens().size());
multi_copy(idx, indices.data(), indices.data() + lens().size());
return indices;
}
void shape::multi_copy(std::size_t i, std::size_t* start, const std::size_t* end) const
void shape::multi_copy(std::size_t idx, std::size_t* start, const std::size_t* end) const
{
assert(this->standard());
size_t tidx = idx;
(void)end;
assert(idx < elements());
assert(lens().size() <= (end - start));
std::transform(strides().begin(),
strides().end(),
lens().begin(),
start,
[&](std::size_t stride, std::size_t len) {
assert(len > 0 and stride > 0);
return (i / stride) % len;
});
for(size_t ii = lens().size() - 1; ii > 0; ii--)
{
*(start + ii) = tidx % lens()[ii];
tidx = tidx / lens()[ii];
}
*start = tidx;
}
bool shape::packed() const
......
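Note on the multi_copy rewrite above: the new body decomposes the flat element index over lens() in row-major order (repeated modulo/divide starting from the innermost dimension) instead of dividing by strides. A minimal standalone sketch of the same decomposition, with illustrative names that are not part of MIGraphX:

#include <cassert>
#include <cstddef>
#include <vector>

// unflatten: row-major inverse of idx = i0*l1*l2*... + i1*l2*... + ... + i_last
std::vector<std::size_t> unflatten(std::size_t idx, const std::vector<std::size_t>& lens)
{
    assert(not lens.empty());
    std::vector<std::size_t> out(lens.size());
    for(std::size_t d = lens.size() - 1; d > 0; d--)
    {
        out[d] = idx % lens[d]; // index within dimension d
        idx /= lens[d];         // carry the remainder to more-significant dims
    }
    out[0] = idx; // whatever remains is the outermost index
    return out;
}

// Example: unflatten(5, {2, 3}) == {1, 2}, since 5 == 1*3 + 2.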
......@@ -28,6 +28,7 @@
#include <migraphx/functional.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/matcher.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -67,6 +68,37 @@ has_one_dyn_dim(const std::unordered_map<std::string, shape>& param_shapes)
dds_it->max};
}
namespace {
struct find_static_2in_broadcasts
{
// Convert a two-input static-shape broadcast/multibroadcast into the one-input
// version. Some compiler passes (e.g. simplify_algebra) only support the
// one-input versions of the broadcasting operators.
auto matcher() const
{
return match::broadcast(match::nargs(2),
match::arg(0)(match::static_shape()),
match::arg(1)(match::static_shape()));
}
void apply(module& m, const match::matcher_result& mr) const
{
auto ins = mr.result;
auto out_lens = ins->get_shape().lens();
auto broadcast_op = ins->get_operator();
if(broadcast_op.name() == "broadcast")
{
broadcast_op.from_value({{"out_lens", out_lens}});
}
else
{
broadcast_op.from_value({{"out_lens", out_lens}, {"out_dyn_dims", {}}});
}
m.replace_instruction(ins, broadcast_op, ins->inputs().at(0));
}
};
} // namespace
/**
* Makes all the shapes in the dynamic_dimension range.
* Probably won't work for `if` and `loop` instructions, depending on how the submodules for those
......@@ -97,6 +129,7 @@ void split_single_dyn_dim::apply(module_pass_manager& mpm) const
dd_check->dyn_param_str, migraphx::shape{dyn_param_shape.type(), static_lens});
auto outputs = submod->add_instructions(mm, map_ins);
submod->add_return({outputs});
match::find_matches(*submod, find_static_2in_broadcasts{});
submodules.push_back(submod);
}
// redirect to select_module operator and return
......
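Note on find_static_2in_broadcasts above: it folds the output lens computed from the two static inputs into the operator itself, so later passes only see the one-input form. A hedged sketch of the rewrite, assuming the usual MIGraphX module API and that the struct (which lives in an anonymous namespace in this file) is visible; the parameter shapes are purely illustrative:

#include <migraphx/make_op.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/module.hpp>

// before: y = multibroadcast(x, ref)              (two static-shape inputs)
// after:  y = multibroadcast[out_lens={3, 4}](x)  (one input, lens folded into the op)
void sketch()
{
    migraphx::module m;
    auto x   = m.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 4}});
    auto ref = m.add_parameter("ref", migraphx::shape{migraphx::shape::float_type, {3, 4}});
    auto y   = m.add_instruction(migraphx::make_op("multibroadcast"), x, ref);
    m.add_return({y});
    migraphx::match::find_matches(m, find_static_2in_broadcasts{});
    // apply() has now replaced y with a one-input multibroadcast whose
    // out_lens attribute is {3, 4}.
}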
......@@ -33,7 +33,7 @@ if(NOT TARGET MIOpen)
message(SEND_ERROR "Cant find miopen")
endif()
set(MIGRAPHX_USE_HIPRTC OFF CACHE BOOL "Use hipRTC APIs")
set(MIGRAPHX_USE_HIPRTC OFF CACHE BOOL "Use hipClang APIs")
include(Embed)
file(GLOB KERNEL_FILES ${CONFIGURE_DEPENDS}
......
......@@ -280,6 +280,14 @@ std::string generate_reduce(const module& m, const std::string& name)
not input->get_shape().broadcasted();
});
auto inner_names = names;
for(auto input : ins->inputs())
{
if(input->name() != "@param")
continue;
if(contains(tensors, input))
continue;
inner_names[input] += "[out_idx]";
}
for(auto input : tensors)
inner_names[input] += "_lambda_param";
auto call_function =
......@@ -308,6 +316,8 @@ std::string generate_reduce(const module& m, const std::string& name)
});
f.set_attributes({"__device__", "__attribute__((const))"}).set_generic_types(m).set_name(name);
f.add_generic_param("r");
f.add_generic_param("out_idx");
f.unused_param("out_idx");
g.create_function(f);
return g.str();
}
......
......@@ -120,12 +120,10 @@ void hip_visit_all_impl(const shape& s, F f, V&& v, Ts&&... xs)
if(not std::all_of(
types.begin(), types.end(), [&](migraphx::shape::type_t t) { return t == s.type(); }))
MIGRAPHX_THROW("Types must be the same");
std::initializer_list<index_int> ranks = {
static_cast<index_int>(get_shape(xs).lens().size())...};
if(not std::all_of(
ranks.begin(), ranks.end(), [&](index_int r) { return r == s.lens().size(); }))
std::initializer_list<index_int> ranks = {static_cast<index_int>(get_shape(xs).ndim())...};
if(not std::all_of(ranks.begin(), ranks.end(), [&](index_int r) { return r == s.ndim(); }))
MIGRAPHX_THROW("Ranks must be the same");
visit_tensor_size(s.lens().size(), [&](auto ndim) {
visit_tensor_size(s.ndim(), [&](auto ndim) {
s.visit_type(hip_visitor([&](auto as) { v(f(xs, ndim, as)...); }));
});
}
......@@ -133,12 +131,10 @@ void hip_visit_all_impl(const shape& s, F f, V&& v, Ts&&... xs)
template <class V, class F, class... Ts>
void hip_visit_views_impl(const shape& s, F f, V&& v, Ts&&... xs)
{
std::initializer_list<index_int> ranks = {
static_cast<index_int>(get_shape(xs).lens().size())...};
if(not std::all_of(
ranks.begin(), ranks.end(), [&](index_int r) { return r == s.lens().size(); }))
std::initializer_list<index_int> ranks = {static_cast<index_int>(get_shape(xs).ndim())...};
if(not std::all_of(ranks.begin(), ranks.end(), [&](index_int r) { return r == s.ndim(); }))
MIGRAPHX_THROW("Ranks must be the same");
visit_tensor_size(s.lens().size(), [&](auto ndim) { v(f(xs, ndim)...); });
visit_tensor_size(s.ndim(), [&](auto ndim) { v(f(xs, ndim)...); });
}
template <class F>
......
......@@ -26,5 +26,6 @@ file(GLOB GPU_DRIVER_SRCS ${CONFIGURE_DEPENDS} ${CMAKE_CURRENT_SOURCE_DIR}/*.cpp
add_executable(gpu-driver
${GPU_DRIVER_SRCS}
)
rocm_clang_tidy_check(gpu-driver)
target_include_directories(gpu-driver PRIVATE include)
target_link_libraries(gpu-driver PRIVATE migraphx_gpu)
......@@ -44,7 +44,7 @@ struct auto_register_action
template <class T>
static void apply()
{
auto name = get_type_name<T>();
const auto& name = get_type_name<T>();
register_action(name.substr(name.rfind("::") + 2),
[](auto&&... xs) { T::apply(std::forward<decltype(xs)>(xs)...); });
}
......
......@@ -140,13 +140,8 @@ void gemm_impl(context& ctx,
compute_type = rocblas_datatype_f32_r;
}
#if ROCBLAS_VERSION_MAJOR >= 2 && ROCBLAS_VERSION_MINOR >= 38
rocblas_gemm_flags flag =
int8_x4_format ? rocblas_gemm_flags_pack_int8x4 : rocblas_gemm_flags_none;
#else
(void)int8_x4_format;
int flag = 0;
#endif
auto a_lens = args[0].get_shape().lens();
auto b_lens = args[1].get_shape().lens();
......
......@@ -29,6 +29,7 @@
#include <migraphx/literal.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/dyn_output.hpp>
#include <utility>
namespace migraphx {
......@@ -112,7 +113,7 @@ struct hip_copy_to_gpu
std::string name() const { return "hip::copy_to_gpu"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1, 2).same_type();
check_shapes{inputs, *this, true}.has(1, 2).same_type();
return inputs.at(0);
}
argument compute(context& ctx, const shape&, const std::vector<argument>& args) const
......@@ -121,6 +122,10 @@ struct hip_copy_to_gpu
if(args.size() == 1)
return input;
argument result = args[1].share();
if(result.get_shape().dynamic())
{
result = result.reshape(args[0].get_shape());
}
gpu_copy(ctx, input, result);
// Associate the input since it was registered with hip
return {result.get_shape(), [input, result]() mutable { return result.data(); }};
......@@ -138,19 +143,24 @@ struct hip_copy_from_gpu
std::string name() const { return "hip::copy_from_gpu"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1, 2).same_type();
check_shapes{inputs, *this, true}.has(1, 2).same_type();
return inputs.at(0);
}
argument
compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const
compute(context& ctx, const dyn_output& dyn_out, const std::vector<argument>& args) const
{
if(args.size() == 1)
{
argument result = allocate_gpu(output_shape, true);
argument result = allocate_gpu(dyn_out.computed_shape, true);
gpu_copy(ctx, args[0], result);
return result;
}
copy_from_gpu(ctx, args[0], args[1]);
argument input = args[0].share();
if(input.get_shape().dynamic())
{
input = input.reshape(args[1].get_shape());
}
copy_from_gpu(ctx, input, args[1]);
return args[1];
}
std::ptrdiff_t output_alias(const std::vector<shape>& args) const
......
......@@ -244,13 +244,13 @@ __device__ void print_once(Ts... xs)
template <class... Ts>
__device__ void println(Ts... xs)
{
print_each(&coutln, xs...);
print_each(&cout, xs..., '\n');
}
template <class... Ts>
__device__ void println_once(Ts... xs)
{
print_each_once(&coutln, xs...);
print_each_once(&cout, xs..., '\n');
}
} // namespace migraphx
......
......@@ -570,7 +570,7 @@ template <class Algo, class Reduced, class Output, class F>
__device__ void fused_reduce(Output output, F f)
{
Algo::template run<Reduced>([&](auto out_idx, auto r) {
auto result = f(r);
auto result = f(r, out_idx);
if constexpr(reduce::is_inner_storage<decltype(result)>{})
{
r.inner([&](auto& y, auto x) { y = x; })(output, result);
......
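Note: this hunk pairs with the generate_reduce changes above. fused_reduce now forwards the output element index to the generated lambda, and generate_reduce declares (and marks as possibly unused) a matching out_idx generic parameter, so non-reduced "@param" inputs can be subscripted per output element. A comment-only sketch of the resulting call shape; identifiers are illustrative:

// Algo::template run<Reduced>([&](auto out_idx, auto r) {
//     auto result = f(r, out_idx);   // was: f(r)
//     ...
// });
//
// Inside the generated function, a "@param" input that is not one of the
// reduced tensors is emitted as x[out_idx], while the reduced tensor inputs
// keep their *_lambda_param spelling.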
......@@ -135,7 +135,7 @@ constexpr vec<vec_type<T>, N> vec_packed_at(T x, I i)
return vec<T, N>{x};
else
{
MIGRAPHX_ASSERT((i + N) < vec_size<T>());
MIGRAPHX_ASSERT((i + N) <= vec_size<T>());
vec<vec_type<T>, N> result = {0};
for(int j = 0; j < N; j++)
{
......
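Note on the relaxed bound above: reading N lanes starting at lane i touches lanes i through i + N - 1, so the last valid start position satisfies i + N == vec_size. A tiny illustration, with numbers chosen only for the example:

// With vec_size == 4 and N == 2, i == 2 reads lanes 2 and 3; i + N == 4,
// which the previous strict '<' wrongly rejected even though it is in bounds.
static_assert(2 + 2 <= 4, "last in-bounds packed read of a 4-wide vector");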
......@@ -197,10 +197,14 @@ struct mlir_program
result = mlirF64TypeGet(ctx.get());
else if(as.is_integral())
{
if(as.is_signed())
result = mlirIntegerTypeSignedGet(ctx.get(), as.size() * 8);
else
result = mlirIntegerTypeGet(ctx.get(), as.size() * 8);
// Note: rocMLIR uses signless integer types for tensor types. These
// translate to a signed implementation for the currently supported
// operations.
if(as.is_unsigned())
{
MIGRAPHX_THROW("Unsupported type: " + std::to_string(as.type_enum()));
}
result = mlirIntegerTypeGet(ctx.get(), as.size() * 8);
}
else
MIGRAPHX_THROW("Unsupported type: " + std::to_string(as.type_enum()));
......@@ -483,7 +487,7 @@ struct mlir_program
static value get_operator_value(const operation& op)
{
auto v = op.to_value();
if(op.name() == "convolution")
if(op.name() == "convolution" or op.name() == "quant_convolution")
{
// Adjust symmetrical padding
if(v.at("padding").size() == v.at("stride").size())
......
......@@ -55,24 +55,15 @@ const std::unordered_set<std::string>& get_rocblas_fp32_archs()
bool get_compute_fp32_flag()
{
bool compute_fp32 = false;
#if ROCBLAS_VERSION_MAJOR >= 2 && ROCBLAS_VERSION_MINOR >= 38
const auto device_name = trim(split_string(get_device_name(), ':').front());
if(contains(get_rocblas_fp32_archs(), device_name))
compute_fp32 = true;
#endif
return compute_fp32;
return contains(get_rocblas_fp32_archs(), device_name);
}
bool get_int8_x4_format(context& ctx)
{
bool int8_x4_format = true;
#if ROCBLAS_VERSION_MAJOR >= 2 && ROCBLAS_VERSION_MINOR >= 38
rocblas_gemm_flags flag;
rocblas_query_int8_layout_flag(ctx.get_stream().get_rocblas(), &flag);
int8_x4_format = (flag == rocblas_gemm_flags_pack_int8x4);
#endif
return int8_x4_format;
return flag == rocblas_gemm_flags_pack_int8x4;
}
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
......
......@@ -72,7 +72,6 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_SCHEDULE_PASS)
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_POINTWISE_FUSION)
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_REDUCE_FUSION)
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_NHWC)
struct id_pass
......@@ -129,7 +128,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
dead_code_elimination{},
auto_contiguous{},
optimize_module{},
enable_pass(not enabled(MIGRAPHX_DISABLE_POINTWISE_FUSION{}), fuse_pointwise{}),
fuse_pointwise{},
dead_code_elimination{},
enable_pass(not enabled(MIGRAPHX_DISABLE_REDUCE_FUSION{}), fuse_reduce{}),
dead_code_elimination{},
......
......@@ -110,7 +110,7 @@ function(add_test_executable TEST_NAME)
add_test_command(${TEST_NAME} ${TEST_COMMAND})
add_dependencies(tests ${TEST_NAME})
add_dependencies(check ${TEST_NAME})
target_link_libraries(${TEST_NAME} migraphx migraphx_onnx)
target_link_libraries(${TEST_NAME} migraphx migraphx_onnx migraphx_ref)
target_include_directories(${TEST_NAME} PUBLIC include)
endfunction(add_test_executable)
......@@ -163,7 +163,7 @@ foreach(ONNX_TEST ${ONNX_TESTS})
set(TEST_NAME test_${BASE_NAME})
add_executable(${TEST_NAME} ${ONNX_TEST})
rocm_clang_tidy_check(${TEST_NAME})
target_link_libraries(${TEST_NAME} migraphx_onnx)
target_link_libraries(${TEST_NAME} migraphx_onnx migraphx_ref)
target_include_directories(${TEST_NAME} PUBLIC include)
add_test(NAME ${TEST_NAME} COMMAND $<TARGET_FILE:${TEST_NAME}> WORKING_DIRECTORY ${TEST_ONNX_DIR})
add_dependencies(tests ${TEST_NAME})
......
......@@ -25,7 +25,7 @@ function(add_api_test TEST_NAME TEST_SRC TEST_DIR)
set(NAME test_api_${TEST_NAME})
add_executable(${NAME} EXCLUDE_FROM_ALL ${TEST_SRC})
rocm_clang_tidy_check(${NAME})
target_link_libraries(${NAME} migraphx_c migraphx)
target_link_libraries(${NAME} migraphx_c migraphx migraphx_all_targets)
target_include_directories(${NAME} PUBLIC ../include)
add_test(NAME ${NAME} COMMAND $<TARGET_FILE:${NAME}> WORKING_DIRECTORY ${TEST_DIR})
add_dependencies(tests ${NAME})
......@@ -59,7 +59,7 @@ if(MIGRAPHX_ENABLE_GPU)
list(APPEND CMAKE_PREFIX_PATH /opt/rocm)
find_package(hip)
add_api_test(gpu test_gpu.cpp ${TEST_ONNX_DIR})
target_link_libraries(test_api_gpu hip::host)
target_link_libraries(test_api_gpu)
add_api_test(custom_op_gpu test_custom_op_gpu.cpp ${TEST_ONNX_DIR})
target_link_libraries(test_api_custom_op_gpu hip::host)
target_link_libraries(test_api_custom_op_gpu)
endif()
......@@ -329,4 +329,36 @@ TEST_CASE(all_scalar_input)
EXPECT(p1 == p2);
}
TEST_CASE(no_input)
{
migraphx::program p;
{
auto* mm = p.get_main_module();
migraphx::shape g_shape{migraphx::shape::int64_type, {1}, {0}};
migraphx::shape s_indices{migraphx::shape::int32_type, {3}};
std::vector<int> indices{3, 800, 800};
auto a0 = mm->add_literal(migraphx::literal{s_indices, indices});
auto a1 = mm->add_literal(migraphx::literal{g_shape, {1}});
int axis = 0;
auto out = mm->add_instruction(migraphx::make_op("gather", {{"axis", axis}}), a0, a1);
mm->add_return({out});
}
run_pass(p);
// This should NOT create a pointwise module if there are no inputs here.
migraphx::program p2;
{
auto* mm = p2.get_main_module();
migraphx::shape g_shape{migraphx::shape::int64_type, {1}, {0}};
migraphx::shape s_indices{migraphx::shape::int32_type, {3}};
std::vector<int> indices{3, 800, 800};
auto a0 = mm->add_literal(migraphx::literal{s_indices, indices});
auto a1 = mm->add_literal(migraphx::literal{g_shape, {1}});
int axis = 0;
auto out = mm->add_instruction(migraphx::make_op("gather", {{"axis", axis}}), a0, a1);
mm->add_return({out});
}
EXPECT(p == p2);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......@@ -213,4 +213,37 @@ module {
EXPECT(verify_mlir(m));
}
TEST_CASE(conv_int8_dequantize_quantize)
{
const std::string mlir_output = R"__migraphx__(
module {
func.func @main(%arg0: tensor<2x8x3x3xi8>, %arg1: tensor<1x8x4x4xi8>, %arg2: tensor<1x2x2x2xf32>, %arg3: tensor<1x2x2x2xi32>) -> tensor<1x2x2x2xi32> attributes {arch = "", kernel = "mixr"} {
%0 = migraphx.quant_convolution(%arg1, %arg0) {dilation = [1, 1], group = 1 : i64, padding = [0, 0, 0, 0], padding_mode = 0 : i64, stride = [1, 1]} : (tensor<1x8x4x4xi8>, tensor<2x8x3x3xi8>) -> tensor<1x2x2x2xi32>
%1 = migraphx.dequantizelinear(%0, %arg2, %arg3) : (tensor<1x2x2x2xi32>, tensor<1x2x2x2xf32>, tensor<1x2x2x2xi32>) -> tensor<1x2x2x2xf32>
%2 = migraphx.quantizelinear(%1, %arg2, %arg3) : (tensor<1x2x2x2xf32>, tensor<1x2x2x2xf32>, tensor<1x2x2x2xi32>) -> tensor<1x2x2x2xi32>
return %2 : tensor<1x2x2x2xi32>
}
}
)__migraphx__";
migraphx::module m;
auto x = m.add_parameter("x", {migraphx::shape::int8_type, {1, 8, 4, 4}});
auto w = m.add_parameter("w", {migraphx::shape::int8_type, {2, 8, 3, 3}});
auto conv = m.add_instruction(migraphx::make_op("quant_convolution"), x, w);
migraphx::shape ss{migraphx::shape::float_type, {1, 2, 2, 2}};
migraphx::shape sz{migraphx::shape::int32_type, {1, 2, 2, 2}};
auto input2 = m.add_parameter("x_scale", ss);
auto input3 = m.add_parameter("x_zero_point", sz);
auto dequant = m.add_instruction(migraphx::make_op("dequantizelinear"), conv, input2, input3);
auto r = m.add_instruction(migraphx::make_op("quantizelinear"), dequant, input2, input3);
m.add_return({r});
auto s = migraphx::gpu::dump_mlir(m);
// Skip test if MLIR is not enabled
if(s.empty())
return;
CHECK(encode(s) == encode(mlir_output));
EXPECT(verify_mlir(m));
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }