Commit 151dd91a authored by turneram's avatar turneram
Browse files

Merge remote-tracking branch 'origin/ck-flash-attn' into gemm-perf

parents 280e76d0 5b2b7489
...@@ -24,6 +24,7 @@ ...@@ -24,6 +24,7 @@
#ifndef MIGRAPHX_GUARD_AMDMIGRAPHX_CPU_POINTWISE_HPP #ifndef MIGRAPHX_GUARD_AMDMIGRAPHX_CPU_POINTWISE_HPP
#define MIGRAPHX_GUARD_AMDMIGRAPHX_CPU_POINTWISE_HPP #define MIGRAPHX_GUARD_AMDMIGRAPHX_CPU_POINTWISE_HPP
#include <array>
#include <migraphx/config.hpp> #include <migraphx/config.hpp>
#include <migraphx/context.hpp> #include <migraphx/context.hpp>
#include <migraphx/check_shapes.hpp> #include <migraphx/check_shapes.hpp>
......
...@@ -48,10 +48,18 @@ else() ...@@ -48,10 +48,18 @@ else()
set(MIGRAPHX_USE_HIPRTC ON CACHE BOOL "Use hipRTC APIs") set(MIGRAPHX_USE_HIPRTC ON CACHE BOOL "Use hipRTC APIs")
endif() endif()
include(Embed)
file(GLOB KERNEL_FILES CONFIGURE_DEPENDS file(GLOB KERNEL_FILES CONFIGURE_DEPENDS
${CMAKE_CURRENT_SOURCE_DIR}/kernels/include/migraphx/kernels/*.hpp) ${CMAKE_CURRENT_SOURCE_DIR}/kernels/include/migraphx/kernels/*.hpp)
message(STATUS "KERNEL_FILES: ${KERNEL_FILES}") message(STATUS "KERNEL_FILES: ${KERNEL_FILES}")
if(WIN32)
# TODO: re-enable when CK is ported to Windows
list(REMOVE_ITEM KERNEL_FILES
${CMAKE_CURRENT_SOURCE_DIR}/kernels/include/migraphx/kernels/ck_gemm.hpp
${CMAKE_CURRENT_SOURCE_DIR}/kernels/include/migraphx/kernels/ck.hpp)
endif()
include(Embed)
add_embed_library(migraphx_kernels ${KERNEL_FILES} RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}/kernels/include/) add_embed_library(migraphx_kernels ${KERNEL_FILES} RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}/kernels/include/)
configure_file(device/targets.hpp.in include/migraphx/gpu/device/targets.hpp) configure_file(device/targets.hpp.in include/migraphx/gpu/device/targets.hpp)
......
...@@ -248,7 +248,7 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std ...@@ -248,7 +248,7 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std
{ {
if(src.path.extension() != ".cpp") if(src.path.extension() != ".cpp")
continue; continue;
std::cout << std::string(src.content.first, src.len()) << std::endl; std::cout << std::string(src.content) << std::endl;
} }
} }
auto p = dynamic_loader::path(&compile_hip_src_with_hiprtc); auto p = dynamic_loader::path(&compile_hip_src_with_hiprtc);
...@@ -338,7 +338,7 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std ...@@ -338,7 +338,7 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std
{ {
if(src.path.extension() != ".cpp") if(src.path.extension() != ".cpp")
continue; continue;
std::cout << std::string(src.content.first, src.len()) << std::endl; std::cout << std::string(src.content) << std::endl;
} }
} }
...@@ -359,9 +359,7 @@ bool hip_has_flags(const std::vector<std::string>& flags) ...@@ -359,9 +359,7 @@ bool hip_has_flags(const std::vector<std::string>& flags)
join_strings(flags, " ") + " -x hip -c --offload-arch=gfx900 --cuda-device-only"; join_strings(flags, " ") + " -x hip -c --offload-arch=gfx900 --cuda-device-only";
std::string src; std::string src;
src_file input; src_file input{"main.cpp", src};
input.path = "main.cpp";
input.content = std::make_pair(src.data(), src.data() + src.size());
try try
{ {
......
...@@ -172,21 +172,17 @@ operation compile_hip_code_object(const std::string& content, hip_compile_option ...@@ -172,21 +172,17 @@ operation compile_hip_code_object(const std::string& content, hip_compile_option
assert(options.inputs.size() == options.virtual_inputs.size() or assert(options.inputs.size() == options.virtual_inputs.size() or
options.virtual_inputs.empty()); options.virtual_inputs.empty());
std::vector<src_file> srcs = options.additional_src_files; std::vector<src_file> srcs = options.additional_src_files;
std::transform(migraphx_kernels().begin(), static auto kernels{::migraphx_kernels()};
migraphx_kernels().end(), std::transform(
std::back_inserter(srcs), kernels.begin(),
[](auto&& p) { kernels.end(),
auto&& name = p.first; std::back_inserter(srcs),
auto&& c = p.second; [](const std::pair<std::string_view, std::string_view>& elem) { return src_file{elem}; });
auto path = name; srcs.emplace_back("main.cpp", content);
return src_file{path, c};
});
srcs.push_back(src_file{fs::path{"main.cpp"},
std::make_pair(content.data(), content.data() + content.size())});
auto args_hpp = auto args_hpp =
generate_args_hpp(options.virtual_inputs.empty() ? options.inputs : options.virtual_inputs); generate_args_hpp(options.virtual_inputs.empty() ? options.inputs : options.virtual_inputs);
srcs.push_back(src_file{fs::path{"args.hpp"}, srcs.emplace_back("args.hpp", args_hpp);
std::make_pair(args_hpp.data(), args_hpp.data() + args_hpp.size())});
options.params += " -DMIGRAPHX_NGLOBAL=" + std::to_string(options.global); options.params += " -DMIGRAPHX_NGLOBAL=" + std::to_string(options.global);
options.params += " -DMIGRAPHX_NLOCAL=" + std::to_string(options.local); options.params += " -DMIGRAPHX_NLOCAL=" + std::to_string(options.local);
options.params += " " + join_strings(compiler_warnings(), " "); options.params += " " + join_strings(compiler_warnings(), " ");
......
...@@ -24,7 +24,7 @@ ...@@ -24,7 +24,7 @@
#ifndef MIGRAPHX_GUARD_DEVICE_TARGETS_CPP #ifndef MIGRAPHX_GUARD_DEVICE_TARGETS_CPP
#define MIGRAPHX_GUARD_DEVICE_TARGETS_CPP #define MIGRAPHX_GUARD_DEVICE_TARGETS_CPP
#include <migraphx/config.hpp> #include <migraphx/gpu/device/config.hpp>
#include <string> #include <string>
#include <vector> #include <vector>
...@@ -34,9 +34,13 @@ namespace gpu { ...@@ -34,9 +34,13 @@ namespace gpu {
namespace device { namespace device {
#define MIGRAPHX_GPU_TARGETS "@GPU_TARGETS@" // NOLINT #define MIGRAPHX_GPU_TARGETS "@GPU_TARGETS@" // NOLINT
MIGRAPHX_DEVICE_EXPORT
const std::vector<std::string>& get_targets(); const std::vector<std::string>& get_targets();
MIGRAPHX_DEVICE_EXPORT
std::string get_targets_as_string(); std::string get_targets_as_string();
MIGRAPHX_DEVICE_EXPORT
std::string get_device_name(); std::string get_device_name();
} // namespace device } // namespace device
......
...@@ -27,6 +27,7 @@ ...@@ -27,6 +27,7 @@
#include <migraphx/msgpack.hpp> #include <migraphx/msgpack.hpp>
#include <migraphx/file_buffer.hpp> #include <migraphx/file_buffer.hpp>
#include <migraphx/ranges.hpp> #include <migraphx/ranges.hpp>
#include <array>
#include <iostream> #include <iostream>
#include <cstring> #include <cstring>
......
...@@ -28,6 +28,7 @@ ...@@ -28,6 +28,7 @@
#include <migraphx/env.hpp> #include <migraphx/env.hpp>
#include <migraphx/shape.hpp> #include <migraphx/shape.hpp>
#include <migraphx/stringutils.hpp> #include <migraphx/stringutils.hpp>
#include <string_view>
#include "ck/host/device_gemm_multiple_d.hpp" #include "ck/host/device_gemm_multiple_d.hpp"
#include "ck/host/device_batched_gemm_softmax_gemm.hpp" #include "ck/host/device_batched_gemm_softmax_gemm.hpp"
...@@ -55,7 +56,7 @@ template <class P> ...@@ -55,7 +56,7 @@ template <class P>
std::string ck_disable_warnings(P p) std::string ck_disable_warnings(P p)
{ {
return interpolate_string(disable_warning_pragma, return interpolate_string(disable_warning_pragma,
{{"content", std::string{p.first, p.second}}}); {{"content", std::string{p.data(), p.size()}}});
} }
static std::unordered_map<std::string, std::string> create_ck_header_strings() static std::unordered_map<std::string, std::string> create_ck_header_strings()
...@@ -64,8 +65,8 @@ static std::unordered_map<std::string, std::string> create_ck_header_strings() ...@@ -64,8 +65,8 @@ static std::unordered_map<std::string, std::string> create_ck_header_strings()
auto ck_headers = ck::host::GetHeaders(); auto ck_headers = ck::host::GetHeaders();
std::transform( std::transform(
ck_headers.begin(), ck_headers.end(), std::inserter(result, result.begin()), [&](auto&& p) { ck_headers.begin(), ck_headers.end(), std::inserter(result, result.begin()), [&](auto& p) {
return std::make_pair(p.first, ck_disable_warnings(p.second)); return std::pair<std::string, std::string>(p.first, ck_disable_warnings(p.second));
}); });
return result; return result;
} }
...@@ -74,11 +75,10 @@ static std::vector<src_file> create_ck_headers() ...@@ -74,11 +75,10 @@ static std::vector<src_file> create_ck_headers()
{ {
static const auto& header_strings = create_ck_header_strings(); static const auto& header_strings = create_ck_header_strings();
std::vector<src_file> srcs; std::vector<src_file> srcs;
std::transform( std::transform(header_strings.begin(),
header_strings.begin(), header_strings.end(), std::back_inserter(srcs), [&](auto&& p) { header_strings.end(),
return src_file{fs::path{p.first}, std::back_inserter(srcs),
{p.second.data(), p.second.data() + p.second.size()}}; [&](auto& p) { return src_file{p}; });
});
return srcs; return srcs;
} }
......
...@@ -45,10 +45,7 @@ MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_HIPRTC_WORKAROUNDS); ...@@ -45,10 +45,7 @@ MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_HIPRTC_WORKAROUNDS);
struct hiprtc_src_file struct hiprtc_src_file
{ {
hiprtc_src_file() = default; hiprtc_src_file() = default;
hiprtc_src_file(const src_file& s) hiprtc_src_file(const src_file& s) : path(s.path.string()), content(s.content) {}
: path(s.path.string()), content(s.content.first, s.content.second)
{
}
std::string path; std::string path;
std::string content; std::string content;
template <class Self, class F> template <class Self, class F>
......
...@@ -135,7 +135,7 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler> ...@@ -135,7 +135,7 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
{ {
const auto& c_shape = inputs.back(); const auto& c_shape = inputs.back();
auto tuning_value = v.get("tuning_value", 0); auto tuning_value = v.get("tuning_value", 34);
auto batch_count = get_batch_count(c_shape); auto batch_count = get_batch_count(c_shape);
auto problem = create_problem(inputs, v); auto problem = create_problem(inputs, v);
......
...@@ -320,7 +320,10 @@ struct mlir_program ...@@ -320,7 +320,10 @@ struct mlir_program
MlirType make_tensor(const shape& s) const MlirType make_tensor(const shape& s) const
{ {
assert(s.standard()); if(not s.standard())
MIGRAPHX_THROW("MLIR expects all tensors to be in standard shape");
if(s.dynamic())
MIGRAPHX_THROW("MLIR does not support dynamic shapes");
std::vector<int64_t> lens(s.lens().begin(), s.lens().end()); std::vector<int64_t> lens(s.lens().begin(), s.lens().end());
return mlirRankedTensorTypeGet( return mlirRankedTensorTypeGet(
lens.size(), lens.data(), make_type(s.type()), mlirAttributeGetNull()); lens.size(), lens.data(), make_type(s.type()), mlirAttributeGetNull());
......
...@@ -45,8 +45,7 @@ struct parse_reshape : op_parser<parse_reshape> ...@@ -45,8 +45,7 @@ struct parse_reshape : op_parser<parse_reshape>
auto s = args[1]->eval(); auto s = args[1]->eval();
std::vector<int64_t> dims; std::vector<int64_t> dims;
s.visit([&](auto v) { copy(v, std::back_inserter(dims)); }); s.visit([&](auto v) { copy(v, std::back_inserter(dims)); });
return info.add_instruction(make_op("reshape", {{"dims", dims}}), return info.add_instruction(make_op("reshape", {{"dims", dims}}), args[0]);
info.make_contiguous(args[0]));
} }
}; };
......
...@@ -155,7 +155,7 @@ int main() {} ...@@ -155,7 +155,7 @@ int main() {}
migraphx::src_file make_src_file(const std::string& name, const std::string& content) migraphx::src_file make_src_file(const std::string& name, const std::string& content)
{ {
return {name, std::make_pair(content.data(), content.data() + content.size())}; return {name, content};
} }
TEST_CASE(simple_compile_hip) TEST_CASE(simple_compile_hip)
......
...@@ -64,7 +64,7 @@ int main() {} ...@@ -64,7 +64,7 @@ int main() {}
migraphx::src_file make_src_file(const std::string& name, const std::string& content) migraphx::src_file make_src_file(const std::string& name, const std::string& content)
{ {
return {name, std::make_pair(content.data(), content.data() + content.size())}; return {name, content};
} }
hip_stream_ptr get_stream() hip_stream_ptr get_stream()
......
...@@ -339,6 +339,8 @@ inline std::ostream& operator<<(std::ostream& os, const color& c) ...@@ -339,6 +339,8 @@ inline std::ostream& operator<<(std::ostream& os, const color& c)
static const bool use_color = isatty(STDOUT_FILENO) != 0; static const bool use_color = isatty(STDOUT_FILENO) != 0;
if(use_color) if(use_color)
return os << "\033[" << static_cast<std::size_t>(c) << "m"; return os << "\033[" << static_cast<std::size_t>(c) << "m";
#else
(void)c;
#endif #endif
return os; return os;
} }
......
...@@ -48,9 +48,7 @@ compile_function(const std::string& src, const std::string& flags, const std::st ...@@ -48,9 +48,7 @@ compile_function(const std::string& src, const std::string& flags, const std::st
migraphx::src_compiler compiler; migraphx::src_compiler compiler;
compiler.flags = flags + "-std=c++14 -fPIC -shared"; compiler.flags = flags + "-std=c++14 -fPIC -shared";
compiler.output = "libsimple.so"; compiler.output = "libsimple.so";
migraphx::src_file f; migraphx::src_file f{"main.cpp", src};
f.path = "main.cpp";
f.content = std::make_pair(src.data(), src.data() + src.size());
auto image = compiler.compile({f}); auto image = compiler.compile({f});
return migraphx::dynamic_loader{image}.get_function<F>(fname); return migraphx::dynamic_loader{image}.get_function<F>(fname);
} }
......
...@@ -97,9 +97,12 @@ TEST_CASE(test_msgpack_bool) ...@@ -97,9 +97,12 @@ TEST_CASE(test_msgpack_bool)
TEST_CASE(test_msgpack_float) TEST_CASE(test_msgpack_float)
{ {
migraphx::value v = 3.0; // changed all double values in this code to not end with .0 because on msgpack for Windows if
// input type is double and ends with .0 it could be converted to uint64_t or int64_t and the
// goal of these functions is to test double without conversions
migraphx::value v = 3.01;
auto buffer = migraphx::to_msgpack(v); auto buffer = migraphx::to_msgpack(v);
EXPECT(buffer == msgpack_buffer(3.0)); EXPECT(buffer == msgpack_buffer(3.01));
EXPECT(migraphx::from_msgpack(buffer) == v); EXPECT(migraphx::from_msgpack(buffer) == v);
} }
...@@ -129,10 +132,10 @@ TEST_CASE(test_msgpack_empty_array) ...@@ -129,10 +132,10 @@ TEST_CASE(test_msgpack_empty_array)
TEST_CASE(test_msgpack_object) TEST_CASE(test_msgpack_object)
{ {
migraphx::value v = {{"one", 1.0}, {"three", 3.0}, {"two", 2.0}}; migraphx::value v = {{"one", 1.01}, {"three", 3.01}, {"two", 2.01}};
auto buffer = migraphx::to_msgpack(v); auto buffer = migraphx::to_msgpack(v);
EXPECT(buffer == msgpack_buffer(std::map<std::string, double>{ EXPECT(buffer == msgpack_buffer(std::map<std::string, double>{
{"one", 1.0}, {"three", 3.0}, {"two", 2.0}})); {"one", 1.01}, {"three", 3.01}, {"two", 2.01}}));
EXPECT(migraphx::from_msgpack(buffer) == v); EXPECT(migraphx::from_msgpack(buffer) == v);
} }
...@@ -157,17 +160,17 @@ struct foo ...@@ -157,17 +160,17 @@ struct foo
TEST_CASE(test_msgpack_object_class) TEST_CASE(test_msgpack_object_class)
{ {
migraphx::value v = {{"a", 1.0}, {"b", "abc"}}; migraphx::value v = {{"a", 1.01}, {"b", "abc"}};
auto buffer = migraphx::to_msgpack(v); auto buffer = migraphx::to_msgpack(v);
EXPECT(buffer == msgpack_buffer(foo{1.0, "abc"})); EXPECT(buffer == msgpack_buffer(foo{1.01, "abc"}));
EXPECT(migraphx::from_msgpack(buffer) == v); EXPECT(migraphx::from_msgpack(buffer) == v);
} }
TEST_CASE(test_msgpack_array_class) TEST_CASE(test_msgpack_array_class)
{ {
migraphx::value v = {{{"a", 1.0}, {"b", "abc"}}, {{"a", 3.0}, {"b", "xyz"}}}; migraphx::value v = {{{"a", 1.01}, {"b", "abc"}}, {{"a", 3.01}, {"b", "xyz"}}};
auto buffer = migraphx::to_msgpack(v); auto buffer = migraphx::to_msgpack(v);
EXPECT(buffer == msgpack_buffer(std::vector<foo>{foo{1.0, "abc"}, foo{3.0, "xyz"}})); EXPECT(buffer == msgpack_buffer(std::vector<foo>{foo{1.01, "abc"}, foo{3.01, "xyz"}}));
EXPECT(migraphx::from_msgpack(buffer) == v); EXPECT(migraphx::from_msgpack(buffer) == v);
} }
......
...@@ -37,7 +37,6 @@ ...@@ -37,7 +37,6 @@
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
#include <migraphx/check_shapes.hpp> #include <migraphx/check_shapes.hpp>
#include <migraphx/functional.hpp> #include <migraphx/functional.hpp>
#include <basic_ops.hpp>
#include <migraphx/compile_options.hpp> #include <migraphx/compile_options.hpp>
#include <migraphx/register_target.hpp> #include <migraphx/register_target.hpp>
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
......
...@@ -5149,6 +5149,61 @@ def prelu_brcst_test(): ...@@ -5149,6 +5149,61 @@ def prelu_brcst_test():
return ([node], [arg0, arg1], [arg_out]) return ([node], [arg0, arg1], [arg_out])
@onnx_test()
def qlinearadd_test():
a = helper.make_tensor_value_info('A', TensorProto.UINT8, [64])
sc_a = helper.make_tensor('A_scale', TensorProto.FLOAT, [], [0.05])
zero_pt_a = helper.make_tensor('A_zero_point', TensorProto.UINT8, [], [0])
b = helper.make_tensor_value_info('B', TensorProto.UINT8, [64])
sc_b = helper.make_tensor('B_scale', TensorProto.FLOAT, [], [0.05])
zero_pt_b = helper.make_tensor('B_zero_point', TensorProto.UINT8, [],
[128])
sc_c = helper.make_tensor('C_scale', TensorProto.FLOAT, [], [0.05])
zero_pt_c = helper.make_tensor('C_zero_point', TensorProto.UINT8, [], [64])
c = helper.make_tensor_value_info('C', TensorProto.UINT8, [64])
node = onnx.helper.make_node(
'QLinearAdd',
inputs=[
'A', 'A_scale', 'A_zero_point', 'B', 'B_scale', 'B_zero_point',
'C_scale', 'C_zero_point'
],
outputs=['C'],
)
return ([node], [a, b], [c],
[sc_a, zero_pt_a, sc_b, zero_pt_b, sc_c, zero_pt_c])
@onnx_test()
def qlinearadd_bcast_test():
a = helper.make_tensor_value_info('A', TensorProto.INT8, [64])
sc_a = helper.make_tensor('A_scale', TensorProto.FLOAT, [], [0.05])
zero_pt_a = helper.make_tensor('A_zero_point', TensorProto.INT8, [], [0])
b = helper.make_tensor_value_info('B', TensorProto.INT8, [1, 1, 64])
sc_b = helper.make_tensor('B_scale', TensorProto.FLOAT, [], [0.05])
zero_pt_b = helper.make_tensor('B_zero_point', TensorProto.INT8, [], [32])
sc_c = helper.make_tensor('C_scale', TensorProto.FLOAT, [], [0.05])
zero_pt_c = helper.make_tensor('C_zero_point', TensorProto.INT8, [], [-64])
c = helper.make_tensor_value_info('C', TensorProto.INT8, [1, 1, 64])
node = onnx.helper.make_node(
'QLinearAdd',
inputs=[
'A', 'A_scale', 'A_zero_point', 'B', 'B_scale', 'B_zero_point',
'C_scale', 'C_zero_point'
],
outputs=['C'],
)
return ([node], [a, b], [c],
[sc_a, zero_pt_a, sc_b, zero_pt_b, sc_c, zero_pt_c])
@onnx_test() @onnx_test()
def quantizelinear_test(): def quantizelinear_test():
arg0 = helper.make_tensor_value_info('0', TensorProto.FLOAT, [5]) arg0 = helper.make_tensor_value_info('0', TensorProto.FLOAT, [5])
......
...@@ -1772,8 +1772,7 @@ TEST_CASE(depthtospace_test) ...@@ -1772,8 +1772,7 @@ TEST_CASE(depthtospace_test)
mm->add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 2, 2, 5, 5}}}), l0); mm->add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 2, 2, 5, 5}}}), l0);
auto tmp2 = mm->add_instruction( auto tmp2 = mm->add_instruction(
migraphx::make_op("transpose", {{"permutation", {0, 3, 4, 1, 5, 2}}}), tmp1); migraphx::make_op("transpose", {{"permutation", {0, 3, 4, 1, 5, 2}}}), tmp1);
auto tmp3 = mm->add_instruction(migraphx::make_op("contiguous"), tmp2); mm->add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 10, 10}}}), tmp2);
mm->add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 10, 10}}}), tmp3);
auto prog = optimize_onnx("depthtospace_test.onnx"); auto prog = optimize_onnx("depthtospace_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -1787,8 +1786,7 @@ TEST_CASE(depthtospace_crd_test) ...@@ -1787,8 +1786,7 @@ TEST_CASE(depthtospace_crd_test)
mm->add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 2, 2, 5, 5}}}), l0); mm->add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 2, 2, 5, 5}}}), l0);
auto tmp2 = mm->add_instruction( auto tmp2 = mm->add_instruction(
migraphx::make_op("transpose", {{"permutation", {0, 1, 4, 2, 5, 3}}}), tmp1); migraphx::make_op("transpose", {{"permutation", {0, 1, 4, 2, 5, 3}}}), tmp1);
auto tmp3 = mm->add_instruction(migraphx::make_op("contiguous"), tmp2); mm->add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 10, 10}}}), tmp2);
mm->add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 10, 10}}}), tmp3);
auto prog = optimize_onnx("depthtospace_crd_test.onnx"); auto prog = optimize_onnx("depthtospace_crd_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -1802,8 +1800,7 @@ TEST_CASE(depthtospace_simple_test) ...@@ -1802,8 +1800,7 @@ TEST_CASE(depthtospace_simple_test)
mm->add_instruction(migraphx::make_op("reshape", {{"dims", {1, 2, 2, 2, 2, 3}}}), l0); mm->add_instruction(migraphx::make_op("reshape", {{"dims", {1, 2, 2, 2, 2, 3}}}), l0);
auto tmp2 = mm->add_instruction( auto tmp2 = mm->add_instruction(
migraphx::make_op("transpose", {{"permutation", {0, 3, 4, 1, 5, 2}}}), tmp1); migraphx::make_op("transpose", {{"permutation", {0, 3, 4, 1, 5, 2}}}), tmp1);
auto tmp3 = mm->add_instruction(migraphx::make_op("contiguous"), tmp2); mm->add_instruction(migraphx::make_op("reshape", {{"dims", {1, 2, 4, 6}}}), tmp2);
mm->add_instruction(migraphx::make_op("reshape", {{"dims", {1, 2, 4, 6}}}), tmp3);
auto prog = optimize_onnx("depthtospace_simple_test.onnx"); auto prog = optimize_onnx("depthtospace_simple_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -1817,8 +1814,7 @@ TEST_CASE(spacetodepth_test) ...@@ -1817,8 +1814,7 @@ TEST_CASE(spacetodepth_test)
mm->add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 5, 2, 5, 2}}}), l0); mm->add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 5, 2, 5, 2}}}), l0);
auto tmp2 = mm->add_instruction( auto tmp2 = mm->add_instruction(
migraphx::make_op("transpose", {{"permutation", {0, 3, 5, 1, 2, 4}}}), tmp1); migraphx::make_op("transpose", {{"permutation", {0, 3, 5, 1, 2, 4}}}), tmp1);
auto tmp3 = mm->add_instruction(migraphx::make_op("contiguous"), tmp2); mm->add_instruction(migraphx::make_op("reshape", {{"dims", {2, 8, 5, 5}}}), tmp2);
mm->add_instruction(migraphx::make_op("reshape", {{"dims", {2, 8, 5, 5}}}), tmp3);
auto prog = optimize_onnx("spacetodepth_test.onnx"); auto prog = optimize_onnx("spacetodepth_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -1832,8 +1828,7 @@ TEST_CASE(spacetodepth_simple_test) ...@@ -1832,8 +1828,7 @@ TEST_CASE(spacetodepth_simple_test)
mm->add_instruction(migraphx::make_op("reshape", {{"dims", {1, 2, 2, 2, 3, 2}}}), l0); mm->add_instruction(migraphx::make_op("reshape", {{"dims", {1, 2, 2, 2, 3, 2}}}), l0);
auto tmp2 = mm->add_instruction( auto tmp2 = mm->add_instruction(
migraphx::make_op("transpose", {{"permutation", {0, 3, 5, 1, 2, 4}}}), tmp1); migraphx::make_op("transpose", {{"permutation", {0, 3, 5, 1, 2, 4}}}), tmp1);
auto tmp3 = mm->add_instruction(migraphx::make_op("contiguous"), tmp2); mm->add_instruction(migraphx::make_op("reshape", {{"dims", {1, 8, 2, 3}}}), tmp2);
mm->add_instruction(migraphx::make_op("reshape", {{"dims", {1, 8, 2, 3}}}), tmp3);
auto prog = optimize_onnx("spacetodepth_simple_test.onnx"); auto prog = optimize_onnx("spacetodepth_simple_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -4856,6 +4851,59 @@ TEST_CASE(prelu_brcst_test) ...@@ -4856,6 +4851,59 @@ TEST_CASE(prelu_brcst_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(qlinearadd_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto a = mm->add_parameter("A", {migraphx::shape::uint8_type, {64}});
auto b = mm->add_parameter("B", {migraphx::shape::uint8_type, {64}});
auto sc_a = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {0.05}});
auto z_pt_a = mm->add_literal(migraphx::literal{migraphx::shape::uint8_type, {0}});
auto sc_b = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {0.05}});
auto z_pt_b = mm->add_literal(migraphx::literal{migraphx::shape::uint8_type, {128}});
auto sc_c = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {0.05}});
auto z_pt_c = mm->add_literal(migraphx::literal{migraphx::shape::uint8_type, {64}});
auto scale_a_bcast =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {64}}}), sc_a);
auto z_pt_a_bcast =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {64}}}), z_pt_a);
auto fp_a =
mm->add_instruction(migraphx::make_op("dequantizelinear"), a, scale_a_bcast, z_pt_a_bcast);
auto scale_b_bcast =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {64}}}), sc_b);
auto z_pt_b_bcast =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {64}}}), z_pt_b);
auto fp_b =
mm->add_instruction(migraphx::make_op("dequantizelinear"), b, scale_b_bcast, z_pt_b_bcast);
auto fp_c = mm->add_instruction(migraphx::make_op("add"), fp_a, fp_b);
auto scale_c_bcast =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {64}}}), sc_c);
auto z_pt_c_bcast =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {64}}}), z_pt_c);
auto c =
mm->add_instruction(migraphx::make_op("quantizelinear"), fp_c, scale_c_bcast, z_pt_c_bcast);
mm->add_return({c});
auto prog = migraphx::parse_onnx("qlinearadd_test.onnx");
EXPECT(p.sort() == prog.sort());
}
TEST_CASE(quantizelinear_test) TEST_CASE(quantizelinear_test)
{ {
migraphx::program p; migraphx::program p;
...@@ -5438,12 +5486,9 @@ TEST_CASE(reshape_test) ...@@ -5438,12 +5486,9 @@ TEST_CASE(reshape_test)
migraphx::literal{migraphx::shape{migraphx::shape::int64_type, {2}}, reshape_dims}); migraphx::literal{migraphx::shape{migraphx::shape::int64_type, {2}}, reshape_dims});
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {4, 2, 3}}); auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {4, 2, 3}});
op.dims = reshape_dims; op.dims = reshape_dims;
auto c0 = mm->add_instruction(migraphx::make_op("contiguous"), l0); mm->add_instruction(op, l0);
mm->add_instruction(op, c0); mm->add_instruction(op, l0);
auto c1 = mm->add_instruction(migraphx::make_op("contiguous"), l0);
mm->add_instruction(op, c1);
auto prog = optimize_onnx("reshape_test.onnx"); auto prog = optimize_onnx("reshape_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
} }
...@@ -5456,8 +5501,7 @@ TEST_CASE(reshape_non_standard_test) ...@@ -5456,8 +5501,7 @@ TEST_CASE(reshape_non_standard_test)
auto x = mm->add_parameter("x", s); auto x = mm->add_parameter("x", s);
auto tran_x = auto tran_x =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), x); mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), x);
auto cont_x = mm->add_instruction(migraphx::make_op("contiguous"), tran_x); mm->add_instruction(migraphx::make_op("reshape", {{"dims", {4, 3, 2}}}), tran_x);
mm->add_instruction(migraphx::make_op("reshape", {{"dims", {4, 3, 2}}}), cont_x);
auto prog = optimize_onnx("reshape_non_standard_test.onnx"); auto prog = optimize_onnx("reshape_non_standard_test.onnx");
EXPECT(p == prog); EXPECT(p == prog);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment