Unverified commit 5d0ca2a6, authored by Shucai Xiao, committed by GitHub

Log softmax nonstd input shape (#740)



* fix a bug where softmax/logsoftmax could not handle a non-standard input shape

* clang format

* fix review comments

* clang format

* refine test to have more code coverage

* clang format
Co-authored-by: mvermeulen <5479696+mvermeulen@users.noreply.github.com>
parent 4ef8ae15
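
The fix hinges on what MIGraphX means by a "non-standard" shape: a view whose strides are not the dense row-major defaults, e.g. the output of a slice or transpose. A minimal self-contained sketch of the distinction (an illustration, not part of this commit; it assumes only the public migraphx::shape API):

#include <migraphx/shape.hpp>
#include <iostream>

int main()
{
    // A freshly allocated 6x9 float tensor gets row-major strides {9, 1},
    // so its shape is both standard and packed.
    migraphx::shape full{migraphx::shape::float_type, {6, 9}};
    std::cout << full.standard() << " " << full.packed() << "\n"; // 1 1

    // A 3x4 slice of it keeps the parent's strides {9, 1}: the elements are
    // no longer dense, so the shape is neither standard nor packed. The old
    // check_shapes{...}.standard() call rejected exactly this kind of input.
    migraphx::shape sliced{migraphx::shape::float_type, {3, 4}, {9, 1}};
    std::cout << sliced.standard() << " " << sliced.packed() << "\n"; // 0 0
}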
@@ -30,8 +30,15 @@ struct logsoftmax
     std::string name() const { return "logsoftmax"; }
     shape normalize_compute_shape(std::vector<shape> inputs) const
     {
-        check_shapes{inputs, *this}.has(1).standard();
-        return inputs.at(0);
+        if(inputs.at(0).packed())
+        {
+            return inputs.at(0);
+        }
+        else
+        {
+            auto lens = inputs.at(0).lens();
+            return {inputs.at(0).type(), lens};
+        }
     }
     auto output() const
...
@@ -30,8 +30,16 @@ struct softmax
     std::string name() const { return "softmax"; }
     shape normalize_compute_shape(std::vector<shape> inputs) const
     {
-        check_shapes{inputs, *this}.has(1).standard();
-        return inputs.at(0);
+        check_shapes{inputs, *this}.has(1);
+        if(inputs.at(0).packed())
+        {
+            return inputs.at(0);
+        }
+        else
+        {
+            auto lens = inputs.at(0).lens();
+            return {inputs.at(0).type(), lens};
+        }
     }
     auto output() const
...
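
In both hunks the non-packed branch returns {inputs.at(0).type(), lens} rather than the input shape itself. Constructing a migraphx::shape from a type and lengths alone computes default row-major strides, so the op always advertises a standard output even when its input is a strided view. A hedged fragment illustrating this:

// shape{type, lens} derives dense row-major strides from the lengths,
// so the result is standard regardless of how the input was laid out.
migraphx::shape out{migraphx::shape::float_type, {3, 4}};
// out.strides() == {4, 1}, out.standard() == true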
@@ -403,6 +403,7 @@ struct cpu_softmax : auto_register_op<cpu_softmax<Op>>
     std::string name() const { return "cpu::" + op.name(); }
     shape compute_shape(const std::vector<shape>& inputs) const
     {
+        check_shapes{inputs, *this}.has(1).standard();
         return op.normalize_compute_shape(inputs);
     }
     argument compute(context&, const shape& output_shape, std::vector<argument> args) const
...
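
Note that the .standard() requirement does not disappear; it moves onto the CPU implementation, whose kernel still assumes a dense layout. This is safe if, by the time cpu::softmax executes, lowering has materialized strided views into packed buffers (e.g. via MIGraphX's auto_contiguous pass; that pipeline detail is an assumption here, not stated by the commit). A fragment sketching the idea, assuming a module mm and a parameter x:

// Sketch (not from this commit): a contiguous instruction copies a strided
// slice into a freshly packed buffer, so downstream kernels such as
// cpu::softmax can rely on standard strides.
auto sliced = mm->add_instruction(
    migraphx::make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {4}}}), x);
auto packed = mm->add_instruction(migraphx::make_op("contiguous"), sliced);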
@@ -1821,6 +1821,24 @@ def logsoftmax_test():
     return ([node], [x], [y])
 
 
+@onnx_test
+def logsoftmax_nonstd_input_test():
+    x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [6, 9])
+    y = helper.make_tensor_value_info('1', TensorProto.FLOAT, [3, 4])
+    z = helper.make_tensor_value_info('2', TensorProto.FLOAT, [3, 4])
+
+    node0 = onnx.helper.make_node('Slice',
+                                  inputs=['0'],
+                                  axes=[0, 1],
+                                  starts=[1, 0],
+                                  ends=[4, 4],
+                                  outputs=['1'])
+
+    node1 = onnx.helper.make_node('LogSoftmax', inputs=['1'], outputs=['2'])
+
+    return ([node0, node1], [x], [z])
+
+
 @onnx_test
 def lrn_test():
     x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [1, 28, 24, 24])
...
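
For reference, the expected output of the generated model can be computed directly with NumPy; the Slice output corresponds to the view x[1:4, 0:4], which is non-contiguous, precisely the case the fix targets. A hedged check, not part of the test suite:

import numpy as np

x = np.random.rand(6, 9).astype(np.float32)
s = x[1:4, 0:4]  # non-contiguous 3x4 view, like the Slice output above

# Row-wise LogSoftmax (axis=1), matching the ONNX default for 2-D input.
e = np.exp(s - s.max(axis=1, keepdims=True))
ref = np.log(e / e.sum(axis=1, keepdims=True))
print(ref.shape)  # (3, 4)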
@@ -2951,6 +2969,24 @@ def softmax_test():
     return ([node], [x], [y])
 
 
+@onnx_test
+def softmax_nonstd_input_test():
+    x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [6, 8])
+    y = helper.make_tensor_value_info('1', TensorProto.FLOAT, [3, 4])
+    z = helper.make_tensor_value_info('2', TensorProto.FLOAT, [3, 4])
+
+    node0 = onnx.helper.make_node('Slice',
+                                  inputs=['0'],
+                                  axes=[0, 1],
+                                  starts=[1, 0],
+                                  ends=[4, 4],
+                                  outputs=['1'])
+
+    node1 = onnx.helper.make_node('Softmax', inputs=['1'], outputs=['2'])
+
+    return ([node0, node1], [x], [z])
+
+
 @onnx_test
 def split_minus_axis_test():
     x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [10, 15])
...
@@ -1657,6 +1657,21 @@ TEST_CASE(logsoftmax_test)
     EXPECT(p == prog);
 }
 
+TEST_CASE(logsoftmax_nonstd_input_test)
+{
+    migraphx::program p;
+    auto* mm = p.get_main_module();
+    auto l0  = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {6, 9}});
+    auto l1  = mm->add_instruction(
+        migraphx::make_op("slice", {{"axes", {0, 1}}, {"starts", {1, 0}}, {"ends", {4, 4}}}), l0);
+    auto l2 = mm->add_instruction(migraphx::make_op("logsoftmax", {{"axis", 1}}), l1);
+    mm->add_return({l2});
+
+    auto prog = migraphx::parse_onnx("logsoftmax_nonstd_input_test.onnx");
+
+    EXPECT(p == prog);
+}
+
 TEST_CASE(lrn_test)
 {
     migraphx::program p;
...
@@ -2715,6 +2730,21 @@ TEST_CASE(softmax_test)
     EXPECT(p == prog);
 }
 
+TEST_CASE(softmax_nonstd_input_test)
+{
+    migraphx::program p;
+    auto* mm = p.get_main_module();
+    auto l0  = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {6, 8}});
+    auto l1  = mm->add_instruction(
+        migraphx::make_op("slice", {{"axes", {0, 1}}, {"starts", {1, 0}}, {"ends", {4, 4}}}), l0);
+    auto l2 = mm->add_instruction(migraphx::make_op("softmax", {{"axis", 1}}), l1);
+    mm->add_return({l2});
+
+    auto prog = migraphx::parse_onnx("softmax_nonstd_input_test.onnx");
+
+    EXPECT(p == prog);
+}
+
 TEST_CASE(split_minus_axis_test)
 {
     migraphx::program p;
...
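
Beyond graph equality, the parsed program's reported output shape can also be sanity-checked; parse_onnx and get_output_shapes are existing MIGraphX APIs, though this particular check is an illustration rather than part of the commit's tests:

// Sketch: the softmax output should come back as a standard {3, 4} shape,
// since normalize_compute_shape now rebuilds it from type and lengths.
migraphx::program prog = migraphx::parse_onnx("softmax_nonstd_input_test.onnx");
auto outs = prog.get_output_shapes();
// outs.front().lens()     == {3, 4}
// outs.front().standard() == true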
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct test_logsoftmax1 : verify_program<test_logsoftmax1>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {5, 3, 3, 4}});
auto tx = mm->add_instruction(migraphx::make_op("transpose", {{"dims", {2, 3, 0, 1}}}), x);
auto r = mm->add_instruction(migraphx::make_op("logsoftmax", {{"axis", 0}}), tx);
mm->add_return({r});
return p;
}
};
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct test_softmax3 : verify_program<test_softmax3>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {5, 3, 3, 4}});
auto sx = mm->add_instruction(
migraphx::make_op("slice", {{"axes", {0, 3}}, {"starts", {1, 1}}, {"ends", {5, 4}}}),
x);
auto r = mm->add_instruction(migraphx::make_op("softmax", {{"axis", 0}}), sx);
mm->add_return({r});
return p;
}
};