Commit 9a758ec4 authored by Alan Turner

Merge remote-tracking branch 'origin/optimize' into ck-gsg

parents ffe2c0cc 913ae362
@@ -196,12 +196,21 @@ argument to_gpu(const argument& arg, bool host)
 argument from_gpu(const argument& arg)
 {
     argument result;
-    arg.visit([&](auto x) {
-        using type = typename decltype(x)::value_type;
-        auto v = read_from_gpu<type>(arg.data(), x.get_shape().bytes() / sizeof(type));
-        // cppcheck-suppress returnDanglingLifetime
-        result = {x.get_shape(), [v]() mutable { return v.data(); }};
-    });
+    arg.visit(
+        [&](auto x) {
+            using type = typename decltype(x)::value_type;
+            auto v = read_from_gpu<type>(arg.data(), x.get_shape().bytes() / sizeof(type));
+            // cppcheck-suppress returnDanglingLifetime
+            result = {x.get_shape(), [v]() mutable { return v.data(); }};
+        },
+        [&](const auto& xs) {
+            std::vector<argument> args;
+            std::transform(xs.begin(), xs.end(), std::back_inserter(args), [&](auto x) {
+                return from_gpu(x);
+            });
+            result = argument{args};
+        });
     return result;
 }
...
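from_gpu now handles tuple arguments as well: the second visitor copies each sub-argument back recursively and rebuilds the tuple on the host. A minimal usage sketch, modelled on the tuple_to_from_gpu test added later in this commit (the data values and the helper name are illustrative):

```cpp
#include <migraphx/argument.hpp>
#include <migraphx/gpu/hip.hpp>
#include <vector>

void round_trip_tuple()
{
    migraphx::shape s1{migraphx::shape::float_type, {2, 3}};
    migraphx::shape s2{migraphx::shape::int32_type, {2, 4}};
    std::vector<float> d1(s1.elements(), 1.0f);
    std::vector<int> d2(s2.elements(), 2);
    auto g1 = migraphx::gpu::to_gpu(migraphx::argument{s1, d1.data()});
    auto g2 = migraphx::gpu::to_gpu(migraphx::argument{s2, d2.data()});
    // One call now copies the whole tuple back; each sub-argument is
    // transferred with the tensor path shown in the first lambda above.
    auto host_tuple = migraphx::gpu::from_gpu(migraphx::argument({g1, g2}));
    auto subs       = host_tuple.get_sub_objects(); // two host-side arguments
}
```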
@@ -24,7 +24,6 @@
 #include <migraphx/gpu/compiler.hpp>
 #include <migraphx/make_op.hpp>
 #include <migraphx/gpu/context.hpp>
 #include <migraphx/gpu/mlir.hpp>
 namespace migraphx {
...
@@ -32,7 +32,13 @@
 #include <mlir-c/Dialect/MIGraphX.h>
 #include <mlir-c/IntegerSet.h>
 #include <mlir-c/Pass.h>
-#include <mlir-c/Registration.h>
+#include <mutex>
+#if !defined(MLIR_MIGRAPHX_DIALECT_API_VERSION) || MLIR_MIGRAPHX_DIALECT_API_VERSION != 3
+#warning "Incompatible version of rocMLIR library used, disabling"
+#undef MIGRAPHX_MLIR
+#else
+#include <mlir-c/RegisterRocMLIR.h>
+#endif
 #endif
 #include <migraphx/env.hpp>
@@ -50,10 +56,6 @@
 #include <deque>
 #include <variant>
-#if defined(MLIR_MIGRAPHX_DIALECT_API_VERSION) && MLIR_MIGRAPHX_DIALECT_API_VERSION >= 2
-#define MIGRAPHX_MLIR_BARE_POINTER
-#endif
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
 namespace gpu {
@@ -168,9 +170,11 @@ struct mlir_program
           location(mlirLocationUnknownGet(ctx.get())),
           mmodule(mlirModuleCreateEmpty(location))
     {
-        MlirDialectHandle mixr_handle = mlirGetDialectHandle__migraphx__();
-        mlirDialectHandleRegisterDialect(mixr_handle, ctx.get());
-        mlirRegisterAllDialects(ctx.get());
+        MlirDialectRegistry registry = mlirDialectRegistryCreate();
+        mlirRegisterRocMLIRDialects(registry);
+        mlirContextAppendDialectRegistry(ctx.get(), registry);
+        mlirContextLoadAllAvailableDialects(ctx.get());
+        mlirDialectRegistryDestroy(registry);
         mlirContextSetAllowUnregisteredDialects(ctx.get(), true /*allow*/);
     }
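The context now registers the rocMLIR dialects through an explicitly created registry, which has to be destroyed once it has been appended. A minimal RAII sketch (a hypothetical helper, not part of this commit; it only reuses the C-API calls shown above) of how the registry lifetime could be tied to a scope:

```cpp
// Hypothetical helper; MlirDialectRegistry, MlirContext and the functions
// below come from the rocMLIR C API used in the constructor above.
struct dialect_registry_guard
{
    MlirDialectRegistry registry = mlirDialectRegistryCreate();
    ~dialect_registry_guard() { mlirDialectRegistryDestroy(registry); }

    void apply(MlirContext ctx) const
    {
        mlirRegisterRocMLIRDialects(registry);
        mlirContextAppendDialectRegistry(ctx, registry);
        mlirContextLoadAllAvailableDialects(ctx);
    }
};
```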
@@ -452,7 +456,8 @@ struct mlir_program
         auto ops = create_operation_state("func.func");
         ops.add_attributes({{"function_type", make_function_type(inputs, outputs)},
                             {"sym_name", std::string("main")},
-                            {"kernel", std::string("mixr")}});
+                            {"kernel", std::string("mixr")},
+                            {"arch", target_arch}});
         ops.add_region(std::move(region));
         insert(body, std::move(ops));
@@ -512,7 +517,8 @@ struct mlir_program
         pp =
             problem_params{ins->get_operator(), to_shapes(ins->inputs()), ins->get_shape()};
         // check if HW supports xdlops
-        bool xdlops = contains(get_xdlops_archs(), target_name);
+        auto target_chip = trim(split_string(target_arch, ':').front());
+        bool xdlops = contains(get_xdlops_archs(), target_chip);
         std::string tuned = get_tune_params(xdlops);
         if(not tuned.empty())
             ops.add_attributes({{"perf_config", tuned}});
@@ -540,7 +546,7 @@ struct mlir_program
         // 1st pipeline to call
         mlirMIGraphXAddHighLevelPipeline(pm.get());
         // 2nd pipeline to call
-        mlirMIGraphXAddBackendPipeline(pm.get(), target_name.c_str(), "amdgcn-amd-amdhsa", "");
+        mlirMIGraphXAddBackendPipeline(pm.get(), target_arch.c_str());
         mlirPassManagerRun(pm.get(), mmodule.get());

         code_object_op op{};
@@ -550,16 +556,7 @@ struct mlir_program
         return op;
     }

-    void find_target()
-    {
-        std::string tname = get_device_name();
-        // HACK: Since MLIR can't handle the full target name
-        target_name = trim(split_string(tname, ':').front());
-        if(tname.size() != target_name.size())
-            std::cout
-                << "*************** WARNING: MLIR may not compile the correct target features for: "
-                << tname << std::endl;
-    }
+    void find_target() { target_arch = get_device_name(); }

     std::pair<std::size_t, std::size_t> get_launch_params() const
     {
@@ -588,7 +585,7 @@ struct mlir_program
     mlir_module mmodule;
     problem_params pp;
     std::deque<std::string> strings{};
-    std::string target_name;
+    std::string target_arch;
 };

 std::string dump_mlir(const module& m)
@@ -650,6 +647,10 @@ code_object_op compile_mlir(const context&, module m, const std::vector<instruct...
     const bool trace = enabled(MIGRAPHX_TRACE_MLIR{});
     if(trace)
         std::cout << m << std::endl;
+    // set mutex while llvm thread support is disabled.
+    static std::mutex g_mlirc_mutex; // NOLINT
+    const std::lock_guard<std::mutex> lock(g_mlirc_mutex);
+
     mlir_program mp;
     mp.find_target();
     mp.parse(m);
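Because the bundled LLVM is built without thread support, every MLIR compilation is serialized through a function-local mutex. The generic shape of that guard, as a minimal sketch:

```cpp
#include <mutex>

void compile_one_kernel(/* ... */)
{
    // A function-local static mutex is shared by all callers, so concurrent
    // compile requests queue up here instead of entering the library in parallel.
    static std::mutex compile_mutex; // NOLINT
    const std::lock_guard<std::mutex> lock(compile_mutex);
    // ... call into the non-thread-safe compiler here ...
}
```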
@@ -669,46 +670,9 @@ instruction_ref insert_mlir(module& m,
     std::vector<instruction_ref> refs;
     std::size_t last = 0;

-#ifdef MIGRAPHX_MLIR_BARE_POINTER
     refs.reserve(inputs.size());
     std::copy(inputs.begin(), inputs.end(), std::back_inserter(refs));
     last = refs.size() - 1;
-#else
-    refs.reserve(inputs.size() * 15);
-    std::unordered_map<uint64_t, instruction_ref> literal_map{};
-    auto get_literal = [&](uint64_t value) {
-        auto fi = literal_map.find(value);
-        if(fi != literal_map.end())
-            return fi->second;
-        auto lit = m.add_literal(value);
-        literal_map.emplace(value, lit);
-        return lit;
-    };
-
-    for(auto input : inputs)
-    {
-        const size_t offset = 0;
-        auto s = input->get_shape();
-        last = refs.size();
-        refs.push_back(input);
-        refs.push_back(input);
-        refs.push_back(get_literal(offset)); // offset
-        // dim sizes
-        std::transform(s.lens().begin(),
-                       s.lens().end(),
-                       std::back_inserter(refs),
-                       [&](const auto& lval) { return get_literal(lval); });
-        // refs.push_back(get_literal(1)); // G
-        // dim strides
-        std::transform(s.strides().begin(),
-                       s.strides().end(),
-                       std::back_inserter(refs),
-                       [&](const auto& lval) { return get_literal(lval); });
-        // refs.push_back(get_literal(1)); // G
-    }
-#endif
     co.expected_inputs = to_shapes(refs);
     co.output_arg = last;
     return m.insert_instruction(ins, co, refs);
...
@@ -27,6 +27,7 @@
 #include <migraphx/stringutils.hpp>
 #include <migraphx/permutation.hpp>
 #include <fstream>
+#include <mutex>

 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
@@ -88,6 +89,9 @@ std::string generate_miopen_config(const problem_params& pp)
 auto query_miopen_db(const std::string& query)
 {
+    static std::mutex g_db_mutex; // NOLINT
+    const std::lock_guard<std::mutex> lock(g_db_mutex);
+
     // TODO: Store db as a static variable
     const auto dbpath = fs::path{"/opt"} / "rocm" / "share" / "miopen" / "db" / "miopen.db";
     // Check if db file exists.
...
@@ -383,9 +383,9 @@ struct ref_gemm
     std::string name() const { return "ref::dot"; }
     shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); }
-    argument compute(context&, const shape& output_shape, std::vector<argument> args) const
+    argument compute(context&, const dyn_output& dyn_out, std::vector<argument> args) const
     {
-        argument result{output_shape};
+        argument result{dyn_out.computed_shape};
         migemm(result, args[0], args[1], 1.0f, 0.0f);
         return result;
@@ -449,10 +449,10 @@ struct ref_softmax : auto_register_op<ref_softmax<Op>>
     {
         return op.normalize_compute_shape(inputs);
     }
-    argument compute(context&, const shape& output_shape, std::vector<argument> args) const
+    argument compute(context&, const dyn_output& dyn_out, std::vector<argument> args) const
     {
-        argument result{output_shape};
-        auto batch_lens = output_shape.lens();
+        argument result{dyn_out.computed_shape};
+        auto batch_lens = dyn_out.computed_shape.lens();
         int64_t tuned_axis = tune_axis(args[0].get_shape().lens().size(), op.axis, op.name());
         std::size_t n_dims = batch_lens[tuned_axis];
         batch_lens[tuned_axis] = 1;
@@ -475,7 +475,7 @@ struct ref_softmax : auto_register_op<ref_softmax<Op>>
                 for(std::size_t j = 0; j < n_dims; ++j)
                 {
                     idx[tuned_axis] = j;
-                    std::size_t index = output_shape.index(idx);
+                    std::size_t index = dyn_out.computed_shape.index(idx);
                     output[index] = std::exp(input[index] - batch_max[i]);
                 }
...
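Both reference ops now take a dyn_output instead of the literal output shape; for fixed-shape programs dyn_out.computed_shape carries exactly the value the old output_shape parameter did. A hedged sketch of a minimal op written against the new signature (ref_identity is hypothetical; shape, argument, context, dyn_output and visit_all are the MIGraphX names used in the file above, so the fragment assumes that file's includes and namespace):

```cpp
// Hypothetical op in the style of the reference ops above; for a dynamic-shape
// program, computed_shape is the output shape resolved from the actual inputs,
// so the allocation here is always concrete.
struct ref_identity
{
    std::string name() const { return "ref::identity"; }
    shape compute_shape(const std::vector<shape>& inputs) const { return inputs.front(); }
    argument compute(context&, const dyn_output& dyn_out, std::vector<argument> args) const
    {
        argument result{dyn_out.computed_shape};
        visit_all(result, args.front())(
            [](auto output, auto input) { std::copy(input.begin(), input.end(), output.begin()); });
        return result;
    }
};
```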
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <test.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/gpu/hip.hpp>
#include <migraphx/gpu/target.hpp>
TEST_CASE(tuple_to_from_gpu)
{
migraphx::shape s1{migraphx::shape::float_type, {2, 3}};
migraphx::shape s2{migraphx::shape::int32_type, {2, 4}};
std::vector<float> p1_data = {1.1, 2.2, 3.3, 4.4, 5.5, 6.6};
std::vector<int> p2_data = {1, 2, 3, 4, 5, 6, 7, 8};
auto p1 = migraphx::argument{s1, p1_data.data()};
auto p2 = migraphx::argument{s2, p2_data.data()};
auto p1_gpu = migraphx::gpu::to_gpu(p1);
auto p2_gpu = migraphx::gpu::to_gpu(p2);
auto p_tuple = migraphx::gpu::from_gpu(migraphx::argument({p1_gpu, p2_gpu}));
std::vector<migraphx::argument> results = p_tuple.get_sub_objects();
std::vector<float> result1;
results[0].visit([&](auto output) { result1.assign(output.begin(), output.end()); });
std::vector<int> result2;
results[1].visit([&](auto output) { result2.assign(output.begin(), output.end()); });
EXPECT(result1 == p1_data);
EXPECT(result2 == p2_data);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
@@ -140,7 +140,7 @@ TEST_CASE(conv)
 {
     const std::string mlir_output = R"__migraphx__(
 module {
-  func.func @main(%arg0: tensor<2x8x3x3xf32>, %arg1: tensor<1x8x4x4xf32>) -> tensor<1x2x2x2xf32> attributes {kernel = "mixr"} {
+  func.func @main(%arg0: tensor<2x8x3x3xf32>, %arg1: tensor<1x8x4x4xf32>) -> tensor<1x2x2x2xf32> attributes {arch = "", kernel = "mixr"} {
     %0 = migraphx.convolution(%arg1, %arg0) {dilation = [1, 1], group = 1 : i64, padding = [0, 0, 0, 0], padding_mode = 0 : i64, stride = [1, 1]} : (tensor<1x8x4x4xf32>, tensor<2x8x3x3xf32>) -> tensor<1x2x2x2xf32>
     return %0 : tensor<1x2x2x2xf32>
   }
@@ -163,7 +163,7 @@ TEST_CASE(conv_add_relu)
 {
     const std::string mlir_output = R"__migraphx__(
 module {
-  func.func @main(%arg0: tensor<1x2x2x2xf32>, %arg1: tensor<2x8x3x3xf32>, %arg2: tensor<1x8x4x4xf32>) -> tensor<1x2x2x2xf32> attributes {kernel = "mixr"} {
+  func.func @main(%arg0: tensor<1x2x2x2xf32>, %arg1: tensor<2x8x3x3xf32>, %arg2: tensor<1x8x4x4xf32>) -> tensor<1x2x2x2xf32> attributes {arch = "", kernel = "mixr"} {
     %0 = migraphx.convolution(%arg2, %arg1) {dilation = [1, 1], group = 1 : i64, padding = [0, 0, 0, 0], padding_mode = 0 : i64, stride = [1, 1]} : (tensor<1x8x4x4xf32>, tensor<2x8x3x3xf32>) -> tensor<1x2x2x2xf32>
     %1 = migraphx.add(%0, %arg0) : (tensor<1x2x2x2xf32>, tensor<1x2x2x2xf32>) -> tensor<1x2x2x2xf32>
     %2 = migraphx.relu(%1) : (tensor<1x2x2x2xf32>) -> tensor<1x2x2x2xf32>
...
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/instruction.hpp>
#include <migraphx/program.hpp>
#include <migraphx/make_op.hpp>
#include "test.hpp"
TEST_CASE(check_undefined)
{
migraphx::module m;
auto und = m.add_instruction(migraphx::make_op("undefined"));
auto cov = m.add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::half_type}}), und);
auto abs = m.add_instruction(migraphx::make_op("abs"), cov);
migraphx::shape xs{migraphx::shape::float_type, {2, 3}};
std::vector<float> datax = {1, 2, 3, 4, 5, 6};
auto lit = m.add_literal(migraphx::literal(xs, datax));
auto mul = m.add_instruction(migraphx::make_op("mul"), lit, lit);
EXPECT(und->is_undefined());
EXPECT(cov->is_undefined());
EXPECT(abs->is_undefined());
EXPECT(not lit->is_undefined());
EXPECT(not mul->is_undefined());
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
(Binary file: ONNX protobuf for external_constant_test — a graph whose Constant node reads its tensor from the external data file external_constant_test.weight at offset 48, length 24. Raw protobuf bytes omitted.)
@@ -181,6 +181,24 @@ TEST_CASE(argmax_test)
    EXPECT(p == prog);
}
TEST_CASE(argmax_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter(
"x",
migraphx::shape{migraphx::shape::float_type, {{1, 4, 0}, {4, 4, 0}, {5, 5, 0}, {6, 6, 0}}});
auto ins = mm->add_instruction(migraphx::make_op("argmax", {{"axis", 2}}), l0);
auto ret = mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {2}}}), ins);
mm->add_return({ret});
migraphx::onnx_options options;
options.default_dyn_dim_value = {1, 4, 0};
auto prog = parse_onnx("argmax_dyn_test.onnx", options);
EXPECT(p == prog);
}
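The dynamic-shape tests in this file build parameters from three-value dimension entries. A short sketch of the same shapes built directly, under the assumption that each triple is {min, max, opt} with 0 meaning no preferred size:

```cpp
#include <migraphx/shape.hpp>

void dyn_shape_sketch()
{
    // The dynamic parameter shape used by argmax_dyn_test, built directly.
    // Assumption: each {min, max, opt} triple bounds one dimension, and
    // opt == 0 means "no preferred size".
    migraphx::shape dyn_s{migraphx::shape::float_type,
                          {{1, 4, 0}, {4, 4, 0}, {5, 5, 0}, {6, 6, 0}}};
    // At evaluation time a concrete shape inside those bounds is supplied:
    migraphx::shape fixed_s{migraphx::shape::float_type, {2, 4, 5, 6}};
    (void)dyn_s;
    (void)fixed_s;
}
```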
TEST_CASE(argmin_test)
{
    migraphx::program p;
...
@@ -1800,6 +1818,16 @@ migraphx::program create_external_data_prog()
    return p;
}
TEST_CASE(external_constant_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
mm->add_literal(migraphx::literal{{migraphx::shape::int64_type, {3}}, {0, 1, 2}});
auto prog = optimize_onnx("external_constant_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(external_data_test)
{
    migraphx::program p = create_external_data_prog();
...
@@ -1959,6 +1987,23 @@ TEST_CASE(flatten_nonstd_test)
    EXPECT(p == prog);
}
TEST_CASE(flatten_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter(
"0",
migraphx::shape{migraphx::shape::float_type, {{1, 4, 0}, {3, 3, 0}, {4, 4, 0}, {5, 5, 0}}});
auto c0 = mm->add_instruction(migraphx::make_op("contiguous"), l0);
auto ret = mm->add_instruction(migraphx::make_op("flatten", {{"axis", 2}}), c0);
mm->add_return({ret});
migraphx::onnx_options options;
options.default_dyn_dim_value = {1, 4, 0};
auto prog = parse_onnx("flatten_dyn_test.onnx", options);
EXPECT(p == prog);
}
TEST_CASE(floor_test)
{
    migraphx::program p;
...
@@ -5604,6 +5649,23 @@ TEST_CASE(softmax_nonstd_input_test)
    EXPECT(p == prog);
}
TEST_CASE(softmax_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter(
"0",
migraphx::shape{migraphx::shape::float_type, {{1, 4, 0}, {3, 3, 0}, {4, 4, 0}, {4, 4, 0}}});
auto ret = mm->add_instruction(migraphx::make_op("softmax", {{"axis", -1}}), l0);
mm->add_return({ret});
migraphx::onnx_options options;
options.default_dyn_dim_value = {1, 4, 0};
auto prog = migraphx::parse_onnx("softmax_dyn_test.onnx", options);
EXPECT(p == prog);
}
TEST_CASE(softplus_test)
{
    migraphx::program p;
...
@@ -5802,6 +5864,29 @@ TEST_CASE(squeeze_unsqueeze_test)
    EXPECT(p == prog);
}
TEST_CASE(squeeze_unsqueeze_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
std::vector<int64_t> squeeze_axes{0, 2, 3, 5};
std::vector<int64_t> unsqueeze_axes{0, 1, 3, 5};
auto l0 = mm->add_parameter(
"0",
migraphx::shape{migraphx::shape::float_type,
{{1, 1, 0}, {1, 4, 0}, {1, 1, 0}, {1, 1, 0}, {1, 4, 0}, {1, 1, 0}}});
auto c0 = mm->add_instruction(migraphx::make_op("contiguous"), l0);
auto l1 = mm->add_instruction(migraphx::make_op("squeeze", {{"axes", squeeze_axes}}), c0);
auto c1 = mm->add_instruction(migraphx::make_op("contiguous"), l1);
auto ret = mm->add_instruction(migraphx::make_op("unsqueeze", {{"axes", unsqueeze_axes}}), c1);
mm->add_return({ret});
migraphx::onnx_options options;
options.default_dyn_dim_value = {1, 4, 0};
auto prog = parse_onnx("squeeze_unsqueeze_dyn_test.onnx", options);
EXPECT(p == prog);
}
TEST_CASE(squeeze_axes_input_test)
{
    migraphx::program p;
...
@@ -35,7 +35,7 @@
 #include <migraphx/half.hpp>

 template <class T>
-void matmul_test()
+void dot_2d_test()
 {
     migraphx::program p;
@@ -82,11 +82,11 @@ void matmul_test()
     result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
     EXPECT(migraphx::verify_range(c, results_vector));
 }
-TEST_CASE_REGISTER(matmul_test<float>)
-TEST_CASE_REGISTER(matmul_test<double>)
+TEST_CASE_REGISTER(dot_2d_test<float>)
+TEST_CASE_REGISTER(dot_2d_test<double>)

 template <class T>
-void matmul_test_ex()
+void dot_4d_test()
 {
     migraphx::program p;
@@ -133,10 +133,10 @@ void matmul_test_ex()
     result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
     EXPECT(migraphx::verify_range(c, results_vector));
 }
-TEST_CASE_REGISTER(matmul_test_ex<float>)
-TEST_CASE_REGISTER(matmul_test_ex<double>)
+TEST_CASE_REGISTER(dot_4d_test<float>)
+TEST_CASE_REGISTER(dot_4d_test<double>)

-TEST_CASE(matmul_mutli_dim_2)
+TEST_CASE(dot_3D_test)
 {
     migraphx::program p;
@@ -189,7 +189,7 @@ TEST_CASE(matmul_mutli_dim_2)
     EXPECT(migraphx::verify_range(m, m_res));
 }

-TEST_CASE(gemm_mutli_dim_2_beta0)
+TEST_CASE(dot_3D_C_test0)
 {
     migraphx::program p;
@@ -265,7 +265,7 @@ TEST_CASE(gemm_mutli_dim_2_beta0)
     EXPECT(migraphx::verify_range(m, m_res));
 }

-TEST_CASE(gemm_beta_0)
+TEST_CASE(dot_3D_C_test1)
 {
     migraphx::program p;
@@ -324,7 +324,7 @@ TEST_CASE(gemm_beta_0)
     EXPECT(migraphx::verify_range(m, m_res));
 }

-TEST_CASE(matmul_mutli_dim_2_3)
+TEST_CASE(dot_4D_test1)
 {
     migraphx::program p;
@@ -363,7 +363,7 @@ TEST_CASE(matmul_mutli_dim_2_3)
     EXPECT(migraphx::verify_range(m, m_res));
 }

-TEST_CASE(gemm_mutli_dim1_2_3)
+TEST_CASE(dot_4D_alpha_beta_test)
 {
     migraphx::program p;
@@ -417,7 +417,7 @@ TEST_CASE(gemm_mutli_dim1_2_3)
     EXPECT(migraphx::verify_range(m, m_res));
 }

-TEST_CASE(gemm_mutli_3args)
+TEST_CASE(dot_4D_alpha_beta_C_test)
 {
     migraphx::program p;
@@ -469,7 +469,7 @@ TEST_CASE(gemm_mutli_3args)
     EXPECT(migraphx::verify_range(m, m_res));
 }

-TEST_CASE(gemm_3args)
+TEST_CASE(dot_2D_C_test0)
 {
     {
         migraphx::program p;
@@ -533,7 +533,7 @@ TEST_CASE(gemm_3args)
     }
 }

-TEST_CASE(matmul_vv_inner_product)
+TEST_CASE(dot_vv_inner_product)
 {
     {
         migraphx::program p;
@@ -608,7 +608,7 @@ TEST_CASE(matmul_vv_inner_product)
     }
 }

-TEST_CASE(matmul_vm)
+TEST_CASE(dot_vm)
 {
     {
         migraphx::program p;
@@ -778,7 +778,7 @@ TEST_CASE(matmul_vm)
     }
 }

-TEST_CASE(matmul_mv)
+TEST_CASE(dot_mv)
 {
     {
         migraphx::program p;
@@ -899,7 +899,7 @@ TEST_CASE(matmul_mv)
     }
 }

-TEST_CASE(matmul_mm1)
+TEST_CASE(dot_mm1)
 {
     {
         migraphx::program p;
@@ -1006,7 +1006,7 @@ TEST_CASE(matmul_mm1)
     }
 }

-TEST_CASE(matmul_mm2)
+TEST_CASE(dot_mm2)
 {
     {
         migraphx::program p;
@@ -1193,6 +1193,113 @@ TEST_CASE(matmul_mm2)
     }
 }
TEST_CASE(dot_dyn_2D_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape a_shape{migraphx::shape::float_type, {{1, 4, 0}, {5, 5, 0}}};
auto ap = mm->add_parameter("a", a_shape);
migraphx::shape b_shape{migraphx::shape::float_type, {5, 3}};
auto bp = mm->add_parameter("b", b_shape);
mm->add_instruction(migraphx::make_op("dot"), ap, bp);
p.compile(migraphx::ref::target{});
std::vector<float> a = {-0.00925222, 0.56250403, 0.70107397, 0.75402161, -0.505885,
1.33628943, -0.11413, -0.31270559, 1.59336732, -0.19361027,
-0.91620867, 0.40108416, -0.06969921, 0.68483471, -0.39906632,
-1.66423624, 0.69040076, -1.31490171, -0.11282616, -0.79391814};
std::vector<float> b = {6.09568541e-01,
-6.10527007e-01,
3.66646462e-01,
1.18951101e-01,
5.58777432e-01,
-3.21296298e-01,
-5.95997198e-01,
-5.01425721e-01,
-2.84606807e-01,
-5.73673557e-01,
-8.99430260e-01,
-4.25103093e-01,
1.53027987e+00,
-3.81407415e-04,
-3.29650255e-01};
migraphx::shape input_fixed_shape{migraphx::shape::float_type, {4, 5}};
migraphx::parameter_map params;
params["a"] = migraphx::argument(input_fixed_shape, a.data());
params["b"] = migraphx::argument(b_shape, b.data());
auto result = p.eval(params).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> c = {-1.56327541e+00,
-7.09570140e-01,
-5.37424982e-01,
-2.22994831e-01,
-2.15586437e+00,
2.09177941e-03,
-1.47279677e+00,
2.02627040e-01,
-6.04527691e-01,
-1.29885596e+00,
2.16294914e+00,
-1.48101497e-01};
EXPECT(migraphx::verify_range(c, results_vector));
}
TEST_CASE(dot_dyn_4D_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape a_shape{migraphx::shape::float_type,
{{1, 1, 0}, {1, 1, 0}, {4, 6, 4}, {5, 5, 0}}};
auto al = mm->add_parameter("a", a_shape);
migraphx::shape b_shape{migraphx::shape::float_type, {1, 1, 5, 3}};
auto bl = mm->add_parameter("b", b_shape);
mm->add_instruction(migraphx::make_op("dot"), al, bl);
p.compile(migraphx::ref::target{});
std::vector<float> a = {-0.00925222, 0.56250403, 0.70107397, 0.75402161, -0.505885,
1.33628943, -0.11413, -0.31270559, 1.59336732, -0.19361027,
-0.91620867, 0.40108416, -0.06969921, 0.68483471, -0.39906632,
-1.66423624, 0.69040076, -1.31490171, -0.11282616, -0.79391814};
std::vector<float> b = {6.09568541e-01,
-6.10527007e-01,
3.66646462e-01,
1.18951101e-01,
5.58777432e-01,
-3.21296298e-01,
-5.95997198e-01,
-5.01425721e-01,
-2.84606807e-01,
-5.73673557e-01,
-8.99430260e-01,
-4.25103093e-01,
1.53027987e+00,
-3.81407415e-04,
-3.29650255e-01};
migraphx::shape input_fixed_shape0{migraphx::shape::float_type, {1, 1, 4, 5}};
migraphx::shape input_fixed_shape1{migraphx::shape::float_type, {1, 1, 5, 3}};
migraphx::parameter_map params;
params["a"] = migraphx::argument(input_fixed_shape0, a.data());
params["b"] = migraphx::argument(input_fixed_shape1, b.data());
auto result = p.eval(params).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> c = {-1.56327541e+00,
-7.09570140e-01,
-5.37424982e-01,
-2.22994831e-01,
-2.15586437e+00,
2.09177941e-03,
-1.47279677e+00,
2.02627040e-01,
-6.04527691e-01,
-1.29885596e+00,
2.16294914e+00,
-1.48101497e-01};
EXPECT(migraphx::verify_range(c, results_vector));
}
TEST_CASE(quant_dot_2args_multi4)
{
    {
...