Unverified Commit 08dbaa12 authored by Charlie Lin's avatar Charlie Lin Committed by GitHub
Browse files

Merge branch 'develop' into dynamic_reduce

parents 4c11aeb5 b41c1f01
......@@ -198,7 +198,7 @@ struct check_shapes
*/
const check_shapes& same_ndims() const
{
if(not this->same([](const shape& s) { return s.max_lens().size(); }))
if(not this->same([](const shape& s) { return s.ndim(); }))
MIGRAPHX_THROW(prefix() + "Number of dimensions do not match");
return *this;
}
......
......@@ -28,6 +28,7 @@
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
#include <migraphx/gemm.hpp>
#include <migraphx/dyn_output.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
......@@ -38,41 +39,69 @@ struct dot
std::string name() const { return "dot"; }
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.same_type().has(2);
check_shapes{inputs, *this, true}.same_type().same_ndims().has(2);
const shape& a = inputs.at(0);
const shape& b = inputs.at(1);
auto t = a.type();
if(not std::all_of(
inputs.begin(), inputs.end(), [](auto s) { return s.lens().size() >= 2; }))
if(not std::all_of(inputs.begin(), inputs.end(), [](auto s) { return s.ndim() >= 2; }))
{
MIGRAPHX_THROW("DOT: dot only accept 2 or more dims operands");
MIGRAPHX_THROW("DOT: dot only accepts operands with 2 or more dimensions ");
}
// only handle the case that the batch size of a and b are the same
if(not std::equal(
a.lens().rbegin() + 2, a.lens().rend(), b.lens().rbegin() + 2, b.lens().rend()))
if(a.dynamic() or b.dynamic())
{
MIGRAPHX_THROW("DOT: batch size of A and B mismatch: {" + to_string_range(a.lens()) +
"} x {" + to_string_range(b.lens()) + "}");
auto s0 = a.to_dynamic();
auto s1 = b.to_dynamic();
if(not std::equal(s0.dyn_dims().rbegin() + 2,
s0.dyn_dims().rend(),
s1.dyn_dims().rbegin() + 2,
s1.dyn_dims().rend()))
{
MIGRAPHX_THROW("DOT: dynamic outer dimensions of A and B mismatch: {" +
to_string_range(s0.dyn_dims()) + "} x {" +
to_string_range(s1.dyn_dims()) + "}");
}
std::size_t dim_0 = s0.ndim() - 2;
std::size_t dim_1 = s0.ndim() - 1;
if(s0.dyn_dims()[dim_1] != s1.dyn_dims()[dim_0])
{
MIGRAPHX_THROW("DOT: dynamic inner dimensions do not match: {" +
to_string_range(s0.dyn_dims()) + "} x {" +
to_string_range(s1.dyn_dims()) + "}");
}
auto out_dyn_dims = s0.dyn_dims();
out_dyn_dims[dim_1] = s1.dyn_dims()[dim_1];
return {t, out_dyn_dims};
}
std::size_t dim_0 = a.lens().size() - 2;
std::size_t dim_1 = a.lens().size() - 1;
if(a.lens()[dim_1] != b.lens()[dim_0])
else
{
MIGRAPHX_THROW("DOT: inner dimensions do not match: {" + to_string_range(a.lens()) +
"} x {" + to_string_range(b.lens()) + "}");
}
// only handle the case that all the dimensions except the last two are the same
if(not std::equal(
a.lens().rbegin() + 2, a.lens().rend(), b.lens().rbegin() + 2, b.lens().rend()))
{
MIGRAPHX_THROW("DOT: static outer dimensions of A and B mismatch: {" +
to_string_range(a.lens()) + "} x {" + to_string_range(b.lens()) +
"}");
}
auto out_lens = a.lens();
out_lens[dim_1] = b.lens()[dim_1];
return {t, out_lens};
std::size_t dim_0 = a.ndim() - 2;
std::size_t dim_1 = a.ndim() - 1;
if(a.lens()[dim_1] != b.lens()[dim_0])
{
MIGRAPHX_THROW("DOT: static inner dimensions do not match: {" +
to_string_range(a.lens()) + "} x {" + to_string_range(b.lens()) +
"}");
}
auto out_lens = a.lens();
out_lens[dim_1] = b.lens()[dim_1];
return {t, out_lens};
}
}
argument compute(shape output_shape, std::vector<argument> args) const
argument compute(const dyn_output& dyn_out, std::vector<argument> args) const
{
argument result = argument{output_shape};
argument result = argument{dyn_out.computed_shape};
visit_all(result, args[0], args[1])(
[&](auto cmat, auto amat, auto bmat) { gemm(cmat, amat, bmat, 1.0f, 0.0f); });
return result;
......
......@@ -53,15 +53,15 @@ struct softmax
std::string name() const { return "softmax"; }
shape normalize_compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs, *this}.has(1);
if(inputs.at(0).packed())
check_shapes{inputs, *this, true}.has(1);
auto s0 = inputs[0];
if(s0.dynamic() or s0.packed())
{
return inputs.at(0);
return s0;
}
else
{
auto lens = inputs.at(0).lens();
return {inputs.at(0).type(), lens};
return {s0.type(), s0.lens()};
}
}
......
......@@ -185,7 +185,7 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std
options.push_back("-fno-gpu-rdc");
options.push_back(" -O" + string_value_of(MIGRAPHX_GPU_OPTIMIZE{}, "3"));
options.push_back("-Wno-cuda-compat");
options.push_back("--cuda-gpu-arch=" + arch);
options.push_back("--offload-arch=" + arch);
prog.compile(options);
return {prog.get_code_obj()};
}
......@@ -237,7 +237,7 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std
}
else if(is_hip_clang_compiler())
{
params += " --cuda-gpu-arch=" + arch;
params += " --offload-arch=" + arch;
params += " --cuda-device-only";
params += " -O" + string_value_of(MIGRAPHX_GPU_OPTIMIZE{}, "3") + " ";
}
......
......@@ -383,9 +383,9 @@ struct ref_gemm
std::string name() const { return "ref::dot"; }
shape compute_shape(const std::vector<shape>& inputs) const { return op.compute_shape(inputs); }
argument compute(context&, const shape& output_shape, std::vector<argument> args) const
argument compute(context&, const dyn_output& dyn_out, std::vector<argument> args) const
{
argument result{output_shape};
argument result{dyn_out.computed_shape};
migemm(result, args[0], args[1], 1.0f, 0.0f);
return result;
......@@ -449,10 +449,10 @@ struct ref_softmax : auto_register_op<ref_softmax<Op>>
{
return op.normalize_compute_shape(inputs);
}
argument compute(context&, const shape& output_shape, std::vector<argument> args) const
argument compute(context&, const dyn_output& dyn_out, std::vector<argument> args) const
{
argument result{output_shape};
auto batch_lens = output_shape.lens();
argument result{dyn_out.computed_shape};
auto batch_lens = dyn_out.computed_shape.lens();
int64_t tuned_axis = tune_axis(args[0].get_shape().lens().size(), op.axis, op.name());
std::size_t n_dims = batch_lens[tuned_axis];
batch_lens[tuned_axis] = 1;
......@@ -475,7 +475,7 @@ struct ref_softmax : auto_register_op<ref_softmax<Op>>
for(std::size_t j = 0; j < n_dims; ++j)
{
idx[tuned_axis] = j;
std::size_t index = output_shape.index(idx);
std::size_t index = dyn_out.computed_shape.index(idx);
output[index] = std::exp(input[index] - batch_max[i]);
}
......
......@@ -5843,6 +5843,16 @@ def softmax_nonstd_input_test():
return ([node0, node1], [x], [y])
@onnx_test
def softmax_dyn_test():
x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [None, 3, 4, 4])
y = helper.make_tensor_value_info('1', TensorProto.FLOAT, [None, 3, 4, 4])
node = onnx.helper.make_node('Softmax', inputs=['0'], outputs=['1'])
return ([node], [x], [y])
@onnx_test
def softsign_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [5])
......
......@@ -5701,6 +5701,23 @@ TEST_CASE(softmax_nonstd_input_test)
EXPECT(p == prog);
}
TEST_CASE(softmax_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter(
"0",
migraphx::shape{migraphx::shape::float_type, {{1, 4, 0}, {3, 3, 0}, {4, 4, 0}, {4, 4, 0}}});
auto ret = mm->add_instruction(migraphx::make_op("softmax", {{"axis", -1}}), l0);
mm->add_return({ret});
migraphx::onnx_options options;
options.default_dyn_dim_value = {1, 4, 0};
auto prog = migraphx::parse_onnx("softmax_dyn_test.onnx", options);
EXPECT(p == prog);
}
TEST_CASE(softplus_test)
{
migraphx::program p;
......
......@@ -467,6 +467,199 @@ TEST_CASE(deconvolution_shape)
weights_3d);
}
TEST_CASE(dot_ndim_error0)
{
migraphx::shape s_m1{migraphx::shape::float_type, {5}};
migraphx::shape s_m2{migraphx::shape::float_type, {5}};
throws_shape(migraphx::make_op("dot"), s_m1, s_m2);
}
TEST_CASE(dot_ndim_error1)
{
migraphx::shape s_m1{migraphx::shape::float_type, {5}};
migraphx::shape s_m2{migraphx::shape::float_type, {5, 2}};
throws_shape(migraphx::make_op("dot"), s_m1, s_m2);
}
TEST_CASE(dot_ndim_error2)
{
migraphx::shape s_m1{migraphx::shape::float_type, {1, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {5}};
throws_shape(migraphx::make_op("dot"), s_m1, s_m2);
}
TEST_CASE(dot_ndim_error3)
{
migraphx::shape s_m1{migraphx::shape::float_type, {1, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {6, 5, 4}};
throws_shape(migraphx::make_op("dot"), s_m1, s_m2);
}
TEST_CASE(dot_ndim_error4)
{
migraphx::shape s_m1{migraphx::shape::float_type, {4, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {1, 1, 5, 7}};
throws_shape(migraphx::make_op("dot"), s_m1, s_m2);
}
TEST_CASE(dot_mismatch_inner_error0)
{
migraphx::shape s_m1{migraphx::shape::float_type, {4, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {10, 8}};
throws_shape(migraphx::make_op("dot"), s_m1, s_m2);
}
TEST_CASE(dot_mismatch_inner_error1)
{
migraphx::shape s_m1{migraphx::shape::float_type, {4, 6}};
migraphx::shape s_m2{migraphx::shape::float_type, {5, 8}};
throws_shape(migraphx::make_op("dot"), s_m1, s_m2);
}
TEST_CASE(dot_mismatch_inner_error2)
{
migraphx::shape s_m1{migraphx::shape::float_type, {1, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {4, 4}};
throws_shape(migraphx::make_op("dot"), s_m1, s_m2);
}
TEST_CASE(dot_mismatch_inner_error3)
{
migraphx::shape s_m1{migraphx::shape::float_type, {1, 1, 4, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {1, 2, 5, 7}};
throws_shape(migraphx::make_op("dot"), s_m1, s_m2);
}
TEST_CASE(dot_mismatch_outer_error)
{
migraphx::shape s_m1{migraphx::shape::float_type, {1, 4, 6}};
migraphx::shape s_m2{migraphx::shape::float_type, {2, 5, 8}};
throws_shape(migraphx::make_op("dot"), s_m1, s_m2);
}
TEST_CASE(dot_2D_test0)
{
migraphx::shape s_m1{migraphx::shape::float_type, {4, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {5, 8}};
expect_shape(
migraphx::shape{migraphx::shape::float_type, {4, 8}}, migraphx::make_op("dot"), s_m1, s_m2);
}
TEST_CASE(dot_2D_test1)
{
migraphx::shape s_m1{migraphx::shape::float_type, {1, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {5, 4}};
expect_shape(
migraphx::shape{migraphx::shape::float_type, {1, 4}}, migraphx::make_op("dot"), s_m1, s_m2);
}
TEST_CASE(dot_2D_test2)
{
migraphx::shape s_m1{migraphx::shape::float_type, {4, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {5, 8}};
expect_shape(
migraphx::shape{migraphx::shape::float_type, {4, 8}}, migraphx::make_op("dot"), s_m1, s_m2);
}
TEST_CASE(dot_2D_test3)
{
migraphx::shape s_m1{migraphx::shape::float_type, {1, 1}};
migraphx::shape s_m2{migraphx::shape::float_type, {1, 1}};
expect_shape(
migraphx::shape{migraphx::shape::float_type, {1, 1}}, migraphx::make_op("dot"), s_m1, s_m2);
}
TEST_CASE(dot_3D_test0)
{
migraphx::shape s_m1{migraphx::shape::float_type, {1, 4, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {1, 5, 8}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 4, 8}},
migraphx::make_op("dot"),
s_m1,
s_m2);
}
TEST_CASE(dot_3D_test_1)
{
migraphx::shape s_m1{migraphx::shape::float_type, {6, 1, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {6, 5, 4}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {6, 1, 4}},
migraphx::make_op("dot"),
s_m1,
s_m2);
}
TEST_CASE(dot_3D_test2)
{
migraphx::shape s_m1{migraphx::shape::float_type, {1, 4, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {1, 5, 7}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 4, 7}},
migraphx::make_op("dot"),
s_m1,
s_m2);
}
TEST_CASE(dot_4D_test)
{
migraphx::shape s_m1{migraphx::shape::float_type, {1, 6, 1, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {1, 6, 5, 4}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 6, 1, 4}},
migraphx::make_op("dot"),
s_m1,
s_m2);
}
TEST_CASE(dot_dyn_static_test0)
{
migraphx::shape s_m1{migraphx::shape::float_type, {{1, 4, 0}, {5, 5, 0}}};
migraphx::shape s_m2{migraphx::shape::float_type, {5, 8}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {{1, 4, 0}, {8, 8, 0}}},
migraphx::make_op("dot"),
s_m1,
s_m2);
}
TEST_CASE(dot_dyn_static_mismatch_error)
{
migraphx::shape s_m1{migraphx::shape::float_type, {{1, 4, 0}, {3, 3, 0}, {5, 5, 0}, {5, 5, 0}}};
migraphx::shape s_m2{migraphx::shape::float_type, {5, 8}};
throws_shape(migraphx::make_op("dot"), s_m1, s_m2);
}
TEST_CASE(dot_dyn_dyn_test0)
{
migraphx::shape s_m1{migraphx::shape::float_type, {{1, 4, 0}, {5, 5, 0}}};
migraphx::shape s_m2{migraphx::shape::float_type, {{5, 5, 0}, {6, 8, 8}}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {{1, 4, 0}, {6, 8, 8}}},
migraphx::make_op("dot"),
s_m1,
s_m2);
}
TEST_CASE(dot_dyn_dyn_test1)
{
migraphx::shape s_m1{migraphx::shape::float_type, {{1, 4, 0}, {4, 5, 5}}};
migraphx::shape s_m2{migraphx::shape::float_type, {{4, 5, 5}, {6, 8, 8}}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {{1, 4, 0}, {6, 8, 8}}},
migraphx::make_op("dot"),
s_m1,
s_m2);
}
TEST_CASE(dot_dyn_mismatch_test0)
{
migraphx::shape s_m1{migraphx::shape::float_type, {{1, 4, 0}, {5, 5, 0}, {5, 5, 0}}};
migraphx::shape s_m2{migraphx::shape::float_type, {1, 5, 8}};
throws_shape(migraphx::make_op("dot"), s_m1, s_m2);
}
TEST_CASE(dot_dyn_mismatch_test1)
{
migraphx::shape s_m1{migraphx::shape::float_type, {{4, 4, 0}, {5, 5, 0}, {2, 5, 0}}};
migraphx::shape s_m2{migraphx::shape::float_type, {4, 5, 8}};
throws_shape(migraphx::make_op("dot"), s_m1, s_m2);
}
TEST_CASE(flatten_shape)
{
migraphx::shape input{migraphx::shape::float_type, {2, 4, 6, 8}};
......@@ -638,46 +831,6 @@ TEST_CASE(gather)
}
}
// 3 input arguments
TEST_CASE(gemm)
{
{
migraphx::shape s_m1{migraphx::shape::float_type, {4, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {10, 8}};
throws_shape(migraphx::make_op("dot"), s_m1, s_m2);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {4, 6}};
migraphx::shape s_m2{migraphx::shape::float_type, {5, 8}};
throws_shape(migraphx::make_op("dot"), s_m1, s_m2);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {4, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {5, 8}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {4, 8}},
migraphx::make_op("dot"),
s_m1,
s_m2);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {1, 4, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {1, 5, 8}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 4, 8}},
migraphx::make_op("dot"),
s_m1,
s_m2);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {1, 4, 6}};
migraphx::shape s_m2{migraphx::shape::float_type, {2, 5, 8}};
throws_shape(migraphx::make_op("dot"), s_m1, s_m2);
}
}
TEST_CASE(get_tuple_elem_test)
{
migraphx::shape s0{migraphx::shape::bool_type, {1, 1}};
......@@ -1131,106 +1284,6 @@ TEST_CASE(lstm)
}
}
// 2 inputs arguments
TEST_CASE(matmul)
{
{
migraphx::shape s_m1{migraphx::shape::float_type, {5}};
migraphx::shape s_m2{migraphx::shape::float_type, {5}};
throws_shape(migraphx::make_op("dot"), s_m1, s_m2);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {5}};
migraphx::shape s_m2{migraphx::shape::float_type, {5, 2}};
throws_shape(migraphx::make_op("dot"), s_m1, s_m2);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {1, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {5}};
throws_shape(migraphx::make_op("dot"), s_m1, s_m2);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {1, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {5, 4}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 4}},
migraphx::make_op("dot"),
s_m1,
s_m2);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {1, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {4, 4}};
throws_shape(migraphx::make_op("dot"), s_m1, s_m2);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {1, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {6, 5, 4}};
throws_shape(migraphx::make_op("dot"), s_m1, s_m2);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {6, 1, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {6, 5, 4}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {6, 1, 4}},
migraphx::make_op("dot"),
s_m1,
s_m2);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {1, 6, 1, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {1, 6, 5, 4}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 6, 1, 4}},
migraphx::make_op("dot"),
s_m1,
s_m2);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {4, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {5, 8}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {4, 8}},
migraphx::make_op("dot"),
s_m1,
s_m2);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {1, 1}};
migraphx::shape s_m2{migraphx::shape::float_type, {1, 1}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 1}},
migraphx::make_op("dot"),
s_m1,
s_m2);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {1, 4, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {1, 5, 7}};
expect_shape(migraphx::shape{migraphx::shape::float_type, {1, 4, 7}},
migraphx::make_op("dot"),
s_m1,
s_m2);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {4, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {1, 1, 5, 7}};
throws_shape(migraphx::make_op("dot"), s_m1, s_m2);
}
{
migraphx::shape s_m1{migraphx::shape::float_type, {1, 1, 4, 5}};
migraphx::shape s_m2{migraphx::shape::float_type, {1, 2, 5, 7}};
throws_shape(migraphx::make_op("dot"), s_m1, s_m2);
}
}
TEST_CASE(multibroadcast)
{
{
......@@ -2173,6 +2226,56 @@ TEST_CASE(slice_shape)
TEST_CASE(softmax) { test_softmax_variations<migraphx::op::softmax>(); }
TEST_CASE(softmax_dyn0)
{
migraphx::shape input{migraphx::shape::float_type,
{{1, 4, 0}, {3, 3, 0}, {4, 4, 0}, {5, 5, 0}}};
expect_shape(input, migraphx::make_op("softmax", {{"axis", 0}}), input);
}
TEST_CASE(softmax_dyn1)
{
migraphx::shape input{migraphx::shape::float_type,
{{1, 1, 0}, {3, 3, 0}, {4, 6, 0}, {5, 8, 6}}};
expect_shape(input, migraphx::make_op("softmax", {{"axis", 0}}), input);
}
TEST_CASE(test_argmax)
{
{
migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::int64_type, {1, 3, 4, 5}},
migraphx::make_op("argmax", {{"axis", 0}}),
input);
}
{
migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::int64_type, {2, 1, 4, 5}},
migraphx::make_op("argmax", {{"axis", 1}}),
input);
}
{
migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::int64_type, {2, 3, 1, 5}},
migraphx::make_op("argmax", {{"axis", 2}}),
input);
}
{
migraphx::shape input{migraphx::shape::half_type, {2, 3, 4, 5}};
expect_shape(migraphx::shape{migraphx::shape::int64_type, {2, 3, 4, 1}},
migraphx::make_op("argmax", {{"axis", 3}}),
input);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
throws_shape(migraphx::make_op("argmax", {{"axis", 4}}), input);
}
}
TEST_CASE(test_argmin)
{
{
......
......@@ -35,7 +35,7 @@
#include <migraphx/half.hpp>
template <class T>
void matmul_test()
void dot_2d_test()
{
migraphx::program p;
......@@ -82,11 +82,11 @@ void matmul_test()
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(c, results_vector));
}
TEST_CASE_REGISTER(matmul_test<float>)
TEST_CASE_REGISTER(matmul_test<double>)
TEST_CASE_REGISTER(dot_2d_test<float>)
TEST_CASE_REGISTER(dot_2d_test<double>)
template <class T>
void matmul_test_ex()
void dot_4d_test()
{
migraphx::program p;
......@@ -133,10 +133,10 @@ void matmul_test_ex()
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(c, results_vector));
}
TEST_CASE_REGISTER(matmul_test_ex<float>)
TEST_CASE_REGISTER(matmul_test_ex<double>)
TEST_CASE_REGISTER(dot_4d_test<float>)
TEST_CASE_REGISTER(dot_4d_test<double>)
TEST_CASE(matmul_mutli_dim_2)
TEST_CASE(dot_3D_test)
{
migraphx::program p;
......@@ -189,7 +189,7 @@ TEST_CASE(matmul_mutli_dim_2)
EXPECT(migraphx::verify_range(m, m_res));
}
TEST_CASE(gemm_mutli_dim_2_beta0)
TEST_CASE(dot_3D_C_test0)
{
migraphx::program p;
......@@ -265,7 +265,7 @@ TEST_CASE(gemm_mutli_dim_2_beta0)
EXPECT(migraphx::verify_range(m, m_res));
}
TEST_CASE(gemm_beta_0)
TEST_CASE(dot_3D_C_test1)
{
migraphx::program p;
......@@ -324,7 +324,7 @@ TEST_CASE(gemm_beta_0)
EXPECT(migraphx::verify_range(m, m_res));
}
TEST_CASE(matmul_mutli_dim_2_3)
TEST_CASE(dot_4D_test1)
{
migraphx::program p;
......@@ -363,7 +363,7 @@ TEST_CASE(matmul_mutli_dim_2_3)
EXPECT(migraphx::verify_range(m, m_res));
}
TEST_CASE(gemm_mutli_dim1_2_3)
TEST_CASE(dot_4D_alpha_beta_test)
{
migraphx::program p;
......@@ -417,7 +417,7 @@ TEST_CASE(gemm_mutli_dim1_2_3)
EXPECT(migraphx::verify_range(m, m_res));
}
TEST_CASE(gemm_mutli_3args)
TEST_CASE(dot_4D_alpha_beta_C_test)
{
migraphx::program p;
......@@ -469,7 +469,7 @@ TEST_CASE(gemm_mutli_3args)
EXPECT(migraphx::verify_range(m, m_res));
}
TEST_CASE(gemm_3args)
TEST_CASE(dot_2D_C_test0)
{
{
migraphx::program p;
......@@ -533,7 +533,7 @@ TEST_CASE(gemm_3args)
}
}
TEST_CASE(matmul_vv_inner_product)
TEST_CASE(dot_vv_inner_product)
{
{
migraphx::program p;
......@@ -608,7 +608,7 @@ TEST_CASE(matmul_vv_inner_product)
}
}
TEST_CASE(matmul_vm)
TEST_CASE(dot_vm)
{
{
migraphx::program p;
......@@ -778,7 +778,7 @@ TEST_CASE(matmul_vm)
}
}
TEST_CASE(matmul_mv)
TEST_CASE(dot_mv)
{
{
migraphx::program p;
......@@ -899,7 +899,7 @@ TEST_CASE(matmul_mv)
}
}
TEST_CASE(matmul_mm1)
TEST_CASE(dot_mm1)
{
{
migraphx::program p;
......@@ -1006,7 +1006,7 @@ TEST_CASE(matmul_mm1)
}
}
TEST_CASE(matmul_mm2)
TEST_CASE(dot_mm2)
{
{
migraphx::program p;
......@@ -1193,6 +1193,113 @@ TEST_CASE(matmul_mm2)
}
}
TEST_CASE(dot_dyn_2D_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape a_shape{migraphx::shape::float_type, {{1, 4, 0}, {5, 5, 0}}};
auto ap = mm->add_parameter("a", a_shape);
migraphx::shape b_shape{migraphx::shape::float_type, {5, 3}};
auto bp = mm->add_parameter("b", b_shape);
mm->add_instruction(migraphx::make_op("dot"), ap, bp);
p.compile(migraphx::ref::target{});
std::vector<float> a = {-0.00925222, 0.56250403, 0.70107397, 0.75402161, -0.505885,
1.33628943, -0.11413, -0.31270559, 1.59336732, -0.19361027,
-0.91620867, 0.40108416, -0.06969921, 0.68483471, -0.39906632,
-1.66423624, 0.69040076, -1.31490171, -0.11282616, -0.79391814};
std::vector<float> b = {6.09568541e-01,
-6.10527007e-01,
3.66646462e-01,
1.18951101e-01,
5.58777432e-01,
-3.21296298e-01,
-5.95997198e-01,
-5.01425721e-01,
-2.84606807e-01,
-5.73673557e-01,
-8.99430260e-01,
-4.25103093e-01,
1.53027987e+00,
-3.81407415e-04,
-3.29650255e-01};
migraphx::shape input_fixed_shape{migraphx::shape::float_type, {4, 5}};
migraphx::parameter_map params;
params["a"] = migraphx::argument(input_fixed_shape, a.data());
params["b"] = migraphx::argument(b_shape, b.data());
auto result = p.eval(params).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> c = {-1.56327541e+00,
-7.09570140e-01,
-5.37424982e-01,
-2.22994831e-01,
-2.15586437e+00,
2.09177941e-03,
-1.47279677e+00,
2.02627040e-01,
-6.04527691e-01,
-1.29885596e+00,
2.16294914e+00,
-1.48101497e-01};
EXPECT(migraphx::verify_range(c, results_vector));
}
TEST_CASE(dot_dyn_4D_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape a_shape{migraphx::shape::float_type,
{{1, 1, 0}, {1, 1, 0}, {4, 6, 4}, {5, 5, 0}}};
auto al = mm->add_parameter("a", a_shape);
migraphx::shape b_shape{migraphx::shape::float_type, {1, 1, 5, 3}};
auto bl = mm->add_parameter("b", b_shape);
mm->add_instruction(migraphx::make_op("dot"), al, bl);
p.compile(migraphx::ref::target{});
std::vector<float> a = {-0.00925222, 0.56250403, 0.70107397, 0.75402161, -0.505885,
1.33628943, -0.11413, -0.31270559, 1.59336732, -0.19361027,
-0.91620867, 0.40108416, -0.06969921, 0.68483471, -0.39906632,
-1.66423624, 0.69040076, -1.31490171, -0.11282616, -0.79391814};
std::vector<float> b = {6.09568541e-01,
-6.10527007e-01,
3.66646462e-01,
1.18951101e-01,
5.58777432e-01,
-3.21296298e-01,
-5.95997198e-01,
-5.01425721e-01,
-2.84606807e-01,
-5.73673557e-01,
-8.99430260e-01,
-4.25103093e-01,
1.53027987e+00,
-3.81407415e-04,
-3.29650255e-01};
migraphx::shape input_fixed_shape0{migraphx::shape::float_type, {1, 1, 4, 5}};
migraphx::shape input_fixed_shape1{migraphx::shape::float_type, {1, 1, 5, 3}};
migraphx::parameter_map params;
params["a"] = migraphx::argument(input_fixed_shape0, a.data());
params["b"] = migraphx::argument(input_fixed_shape1, b.data());
auto result = p.eval(params).back();
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> c = {-1.56327541e+00,
-7.09570140e-01,
-5.37424982e-01,
-2.22994831e-01,
-2.15586437e+00,
2.09177941e-03,
-1.47279677e+00,
2.02627040e-01,
-6.04527691e-01,
-1.29885596e+00,
2.16294914e+00,
-1.48101497e-01};
EXPECT(migraphx::verify_range(c, results_vector));
}
TEST_CASE(quant_dot_2args_multi4)
{
{
......
......@@ -7215,6 +7215,69 @@ TEST_CASE(softmax_test)
EXPECT(migraphx::verify_range(results_vector, s));
}
TEST_CASE(softmax_dyn_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape a_shape{migraphx::shape::float_type,
{{1, 10, 0}, {1, 3, 3}, {4, 4, 0}, {2, 2, 2}}};
auto al = mm->add_parameter("a", a_shape);
mm->add_instruction(migraphx::make_op("softmax", {{"axis", 1}}), al);
p.compile(migraphx::ref::target{});
std::vector<float> a = {
-5.61869681e-01, 9.07827199e-01, 1.29255986e+00, 3.18533443e-02, -1.22183852e-03,
-2.83830553e-01, -1.03245842e+00, -9.28322077e-01, -8.82696748e-01, 1.11327164e-01,
-9.20038462e-01, 8.47388089e-01, 2.51734018e-01, 1.50563884e+00, 2.23056650e+00,
-6.17576987e-02, -1.00264274e-01, -6.10369384e-01, 1.17537189e+00, -2.51560897e-01,
-8.50333512e-01, -8.03578615e-01, -6.51194930e-01, -2.58137047e-01, 4.65528190e-01,
3.23284641e-02, -1.54700470e+00, 1.38096774e+00, 5.39869189e-01, -7.56884992e-01,
1.81503093e+00, -2.11269641e+00, 1.92466557e+00, 1.77230799e+00, 2.21660900e+00,
1.56777036e+00, -2.08995026e-03, 3.50566894e-01, -1.15042710e+00, -1.18577778e+00,
8.90633047e-01, -6.63949102e-02, 1.44661188e+00, 1.59215283e+00, -2.56262213e-01,
9.39079225e-01, 4.07298543e-02, 3.86590779e-01, 6.09607756e-01, 8.22331488e-01,
-2.82126725e-01, -9.49052632e-01, -4.24012303e-01, -5.32990396e-01, -3.18386006e+00,
3.27092171e-01, -1.33315325e+00, 3.62459183e-01, 3.74710828e-01, -1.30302286e+00,
1.79680198e-01, -4.51832324e-01, 4.34282750e-01, -7.09520102e-01, 6.20333970e-01,
-1.28712380e+00, 2.04130828e-01, -7.70607769e-01, 1.61889160e+00, -1.50951004e+00,
-4.10505563e-01, -3.56566496e-02, -1.29747534e+00, -1.49967879e-01, 7.77626812e-01,
-8.28408226e-02, 2.73412596e-02, 5.79780899e-03, 9.87900198e-02, -7.95276761e-01,
-1.38536084e+00, -6.63573861e-01, 3.89783204e-01, -1.30670881e+00, -7.62425125e-01,
-4.04883057e-01, 6.24344349e-01, 3.68128955e-01, -1.01577950e+00, -3.06715906e-01,
5.67961395e-01, 2.98198581e-01, -1.63613629e+00, -3.75131965e-01, -6.75393403e-01,
2.59172034e+00, 6.75538957e-01, 9.07939598e-02, 1.92257717e-01, -1.21592450e+00,
-2.73682117e-01, 1.25232983e+00, -1.39969170e+00, -1.91483587e-01, 2.57732719e-01,
3.10056299e-01, 1.41833842e+00, -1.81386679e-01, 3.92868072e-01, -8.14771175e-01,
2.02392387e+00, -9.42091495e-02, -3.77683818e-01, 2.05638766e+00, 2.93796062e-01,
-6.02131486e-01, 2.70461679e-01, -8.92358482e-01, 1.04388881e+00, 2.66154885e-01};
migraphx::parameter_map params;
migraphx::shape input_fixed_shape{migraphx::shape::float_type, {5, 3, 4, 2}};
params["a"] = migraphx::argument(input_fixed_shape, a.data());
auto result = p.eval(params).back();
std::vector<float> results_vector(120);
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<float> s = {
0.30191708, 0.59879845, 0.50029165, 0.24915339, 0.36823985, 0.13190967, 0.0349741,
0.18750034, 0.21905553, 0.27000085, 0.0547399, 0.56318235, 0.47422904, 0.78964758,
0.91381913, 0.44601166, 0.47902739, 0.13120073, 0.4449684, 0.18766427, 0.15753111,
0.07844277, 0.05120674, 0.36648798, 0.14637007, 0.13152322, 0.01560997, 0.29065287,
0.49196178, 0.10550152, 0.81890774, 0.06369215, 0.62972021, 0.74931765, 0.67285055,
0.35034987, 0.28612873, 0.31931475, 0.04220394, 0.16093165, 0.22390974, 0.11915915,
0.3115395, 0.35899726, 0.22190949, 0.57518375, 0.13888834, 0.7753762, 0.4642328,
0.57055861, 0.21954368, 0.34515455, 0.09486015, 0.40631217, 0.01842281, 0.48770609,
0.06652815, 0.36023033, 0.42343026, 0.24226256, 0.17348589, 0.44066274, 0.6865865,
0.17296699, 0.46923906, 0.06921105, 0.3570261, 0.4125829, 0.73165393, 0.15302512,
0.29499072, 0.33932695, 0.30852377, 0.40762195, 0.40170741, 0.36259529, 0.60848355,
0.42618036, 0.31721094, 0.02960522, 0.28256637, 0.24389413, 0.2725659, 0.10663581,
0.27622163, 0.28264219, 0.53652936, 0.09476089, 0.40890986, 0.34848392, 0.32572666,
0.53076893, 0.11529481, 0.29117745, 0.14625968, 0.8756339, 0.49818122, 0.10656087,
0.1813329, 0.17664003, 0.21410346, 0.80408043, 0.02315119, 0.27155462, 0.32804728,
0.13268511, 0.61795473, 0.49703068, 0.41696799, 0.10175809, 0.71028161, 0.29929739,
0.17377149, 0.76075399, 0.20071237, 0.32632929, 0.36892858, 0.09416146, 0.26656723,
0.42914796};
EXPECT(migraphx::verify_range(results_vector, s));
}
TEST_CASE(sqdiff_test)
{
migraphx::program p;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment