Commit 385452d4 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

add test cases for the logsoftmax operator.

parent 05d2e2ca
...@@ -945,7 +945,7 @@ struct logsoftmax ...@@ -945,7 +945,7 @@ struct logsoftmax
shape compute_shape(std::vector<shape> inputs) const shape compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs}.has(1); check_shapes{inputs}.has(1);
if(axis < 0 || axis >= inputs[0].lens().size()) if(axis < 0 || axis > inputs[0].lens().size())
{ {
MIGRAPHX_THROW("LogSoftMax: input axis value " + std::to_string(axis) + MIGRAPHX_THROW("LogSoftMax: input axis value " + std::to_string(axis) +
" is out of range"); " is out of range");
......
...@@ -652,18 +652,18 @@ struct cpu_logsoftmax ...@@ -652,18 +652,18 @@ struct cpu_logsoftmax
std::vector<value_type> batch_max(batch_shape.elements(), std::vector<value_type> batch_max(batch_shape.elements(),
std::numeric_limits<value_type>::lowest()); std::numeric_limits<value_type>::lowest());
shape_for_each(output_shape, [&](auto idx) { shape_for_each(output_shape, [&](auto idx) {
auto index = compute_batch_index(idx, batch_shape, op.axis); auto index = this->compute_batch_index(idx, batch_shape, op.axis);
batch_max[index] = std::max(batch_max[index], input(idx.begin(), idx.end())); batch_max[index] = std::max(batch_max[index], input(idx.begin(), idx.end()));
}); });
shape_for_each(output_shape, [&](auto idx) { shape_for_each(output_shape, [&](auto idx) {
auto index = compute_batch_index(idx, batch_shape, op.axis); auto index = this->compute_batch_index(idx, batch_shape, op.axis);
output(idx.begin(), idx.end()) = input(idx.begin(), idx.end()) - batch_max[index]; output(idx.begin(), idx.end()) = input(idx.begin(), idx.end()) - batch_max[index];
}); });
std::vector<value_type> batch_sum(batch_shape.elements(), value_type(0)); std::vector<value_type> batch_sum(batch_shape.elements(), value_type(0));
shape_for_each(output_shape, [&](auto idx) { shape_for_each(output_shape, [&](auto idx) {
auto index = compute_batch_index(idx, batch_shape, op.axis); auto index = this->compute_batch_index(idx, batch_shape, op.axis);
batch_sum[index] += std::exp(output(idx.begin(), idx.end())); batch_sum[index] += std::exp(output(idx.begin(), idx.end()));
}); });
...@@ -673,7 +673,7 @@ struct cpu_logsoftmax ...@@ -673,7 +673,7 @@ struct cpu_logsoftmax
} }
shape_for_each(output_shape, [&](auto idx) { shape_for_each(output_shape, [&](auto idx) {
auto index = compute_batch_index(idx, batch_shape, op.axis); auto index = this->compute_batch_index(idx, batch_shape, op.axis);
output(idx.begin(), idx.end()) -= batch_sum[index]; output(idx.begin(), idx.end()) -= batch_sum[index];
}); });
}); });
......
...@@ -1036,6 +1036,186 @@ TEST_CASE(softmax_test) ...@@ -1036,6 +1036,186 @@ TEST_CASE(softmax_test)
EXPECT(migraphx::verify_range(results_vector, s)); EXPECT(migraphx::verify_range(results_vector, s));
} }
TEST_CASE(logsoftmax_test_axis_0)
{
migraphx::program p;
std::vector<float> a = {
1.93885877, -1.20006269, 0.90960855, 0.42108916, -1.50797544, -1.31047913,
1.07816336, -1.13288733, -0.86411064, 0.97800238, 0.76631385, 2.07962834,
-0.8940665 , -1.62855592, -0.53763057, -1.48165117, -0.64154112, 0.42486547,
0.89330917, -2.42022666, 0.192611 , -0.01257413, -1.5326607 , 0.53137897,
-1.52383859, 0.46994381, 0.00453619, 0.0066996 , 1.58394908, 0.84216752,
-0.04137941, -0.88580789, 1.44055158, -0.17621241, -1.98917923, -0.08610038,
0.79020567, -0.67714548, 0.42774631, 0.1376574 , 2.23569227, 1.16681234,
-1.21191456, -0.28411502, -0.18688975, 1.67552548, 2.48357974, 0.95891282,
-0.06616535, -0.99628491, 1.04314606, -1.22943315, 0.76930403, 0.31106618};
std::vector<float> s = {
-2.71138556, -5.85030702, -3.74063578, -4.22915517, -6.15821977, -5.96072346,
-3.57208097, -5.78313166, -5.51435497, -3.67224195, -3.88393048, -2.57061599,
-5.54431083, -6.27880025, -5.1878749 , -6.1318955 , -5.29178545, -4.22537886,
-3.75693516, -7.07047099, -4.45763333, -4.66281846, -6.18290503, -4.11886536,
-6.17408292, -4.18030052, -4.64570814, -4.64354473, -3.06629525, -3.80807681,
-4.69162374, -5.53605222, -3.20969275, -4.82645674, -6.63942356, -4.73634471,
-3.86003866, -5.32738981, -4.22249802, -4.51258693, -2.41455206, -3.48343199,
-5.86215889, -4.93435935, -4.83713408, -2.97471885, -2.16666459, -3.69133151,
-4.71640968, -5.64652924, -3.60709827, -5.87967748, -3.8809403 , -4.33917815};
migraphx::shape a_shape{migraphx::shape::float_type, {2, 3, 3, 3}};
auto al = p.add_literal(migraphx::literal{a_shape, a});
int axis = 0;
p.add_instruction(migraphx::op::logsoftmax{axis}, al);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(results_vector, s));
}
TEST_CASE(logsoftmax_test_axis_1)
{
migraphx::program p;
std::vector<float> a = {
1.93885877, -1.20006269, 0.90960855, 0.42108916, -1.50797544, -1.31047913,
1.07816336, -1.13288733, -0.86411064, 0.97800238, 0.76631385, 2.07962834,
-0.8940665 , -1.62855592, -0.53763057, -1.48165117, -0.64154112, 0.42486547,
0.89330917, -2.42022666, 0.192611 , -0.01257413, -1.5326607 , 0.53137897,
-1.52383859, 0.46994381, 0.00453619, 0.0066996 , 1.58394908, 0.84216752,
-0.04137941, -0.88580789, 1.44055158, -0.17621241, -1.98917923, -0.08610038,
0.79020567, -0.67714548, 0.42774631, 0.1376574 , 2.23569227, 1.16681234,
-1.21191456, -0.28411502, -0.18688975, 1.67552548, 2.48357974, 0.95891282,
-0.06616535, -0.99628491, 1.04314606, -1.22943315, 0.76930403, 0.31106618};
std::vector<float> s = {
-1.77931988, -4.91824134, -2.80857010, -3.29708949, -5.22615409, -5.02865778,
-2.64001529, -4.85106598, -4.58228929, -2.74017627, -2.95186480, -1.63855031,
-4.61224515, -5.34673457, -4.25580922, -5.19982982, -4.35971977, -3.29331318,
-2.82486948, -6.13840531, -3.52556765, -3.73075278, -5.25083935, -3.18679968,
-5.24201724, -3.24823484, -3.71364246, -4.14309917, -2.56584969, -3.30763125,
-4.19117818, -5.03560666, -2.70924719, -4.32601118, -6.13897800, -4.23589915,
-3.35959310, -4.82694425, -3.72205246, -4.01214137, -1.91410650, -2.98298643,
-5.36171333, -4.43391379, -4.33668852, -2.47427329, -1.66621903, -3.19088595,
-4.21596412, -5.14608368, -3.10665271, -5.37923192, -3.38049474, -3.83873259};
migraphx::shape a_shape{migraphx::shape::float_type, {2, 3, 3, 3}};
auto al = p.add_literal(migraphx::literal{a_shape, a});
int axis = 1;
p.add_instruction(migraphx::op::logsoftmax{axis}, al);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(results_vector, s));
}
TEST_CASE(logsoftmax_test_axis_2)
{
migraphx::program p;
std::vector<float> a = {
1.93885877, -1.20006269, 0.90960855, 0.42108916, -1.50797544, -1.31047913,
1.07816336, -1.13288733, -0.86411064, 0.97800238, 0.76631385, 2.07962834,
-0.8940665 , -1.62855592, -0.53763057, -1.48165117, -0.64154112, 0.42486547,
0.89330917, -2.42022666, 0.192611 , -0.01257413, -1.5326607 , 0.53137897,
-1.52383859, 0.46994381, 0.00453619, 0.0066996 , 1.58394908, 0.84216752,
-0.04137941, -0.88580789, 1.44055158, -0.17621241, -1.98917923, -0.08610038,
0.79020567, -0.67714548, 0.42774631, 0.1376574 , 2.23569227, 1.16681234,
-1.21191456, -0.28411502, -0.18688975, 1.67552548, 2.48357974, 0.95891282,
-0.06616535, -0.99628491, 1.04314606, -1.22943315, 0.76930403, 0.31106618};
std::vector<float> s = {
-0.79763715, -3.93655861, -1.82688737, -2.31540676, -4.24447136, -4.04697505,
-1.65833256, -3.86938325, -3.60060656, -1.81223672, -2.02392525, -0.71061076,
-3.68430560, -4.41879502, -3.32786967, -4.27189027, -3.43178022, -2.36537363,
-1.35498658, -4.66852241, -2.05568475, -2.26086988, -3.78095645, -1.71691678,
-3.77213434, -1.77835194, -2.24375956, -2.74631770, -1.16906822, -1.91084978,
-2.79439671, -3.63882519, -1.31246572, -2.92922971, -4.74219653, -2.83911768,
-2.19738500, -3.66473615, -2.55984436, -2.84993327, -0.75189840, -1.82077833,
-4.19950523, -3.27170569, -3.17448042, -1.65286841, -0.84481415, -2.36948107,
-3.39455924, -4.32467880, -2.28524783, -4.55782704, -2.55908986, -3.01732771};
migraphx::shape a_shape{migraphx::shape::float_type, {2, 3, 3, 3}};
auto al = p.add_literal(migraphx::literal{a_shape, a});
int axis = 2;
p.add_instruction(migraphx::op::logsoftmax{axis}, al);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(results_vector, s));
}
TEST_CASE(logsoftmax_test_axis_3)
{
migraphx::program p;
std::vector<float> a = {
1.93885877, -1.20006269, 0.90960855, 0.42108916, -1.50797544, -1.31047913,
1.07816336, -1.13288733, -0.86411064, 0.97800238, 0.76631385, 2.07962834,
-0.8940665 , -1.62855592, -0.53763057, -1.48165117, -0.64154112, 0.42486547,
0.89330917, -2.42022666, 0.192611 , -0.01257413, -1.5326607 , 0.53137897,
-1.52383859, 0.46994381, 0.00453619, 0.0066996 , 1.58394908, 0.84216752,
-0.04137941, -0.88580789, 1.44055158, -0.17621241, -1.98917923, -0.08610038,
0.79020567, -0.67714548, 0.42774631, 0.1376574 , 2.23569227, 1.16681234,
-1.21191456, -0.28411502, -0.18688975, 1.67552548, 2.48357974, 0.95891282,
-0.06616535, -0.99628491, 1.04314606, -1.22943315, 0.76930403, 0.31106618};
std::vector<float> s = {
-0.33690375, -3.47582521, -1.36615397, -0.27936556, -2.20843016, -2.01093385,
-0.22551114, -2.43656183, -2.16778514, -1.57241522, -1.78410375, -0.47078926,
-1.06745881, -1.80194823, -0.71102288, -2.30719726, -1.46708721, -0.40068062,
-0.42698261, -3.74051844, -1.12768078, -1.07891856, -2.59900513, -0.53496546,
-2.56139951, -0.56761711, -1.03302473, -2.09771276, -0.52046328, -1.26224484,
-1.76322959, -2.60765807, -0.28129860, -0.81424303, -2.62720985, -0.72413100,
-0.65570381, -2.12305496, -1.01816317, -2.48063402, -0.38259915, -1.45147908,
-1.84310238, -0.91530284, -0.81807757, -1.31692881, -0.50887455, -2.03354147,
-1.48767160, -2.41779116, -0.37836019, -2.56853147, -0.56979429, -1.02803214};
migraphx::shape a_shape{migraphx::shape::float_type, {2, 3, 3, 3}};
auto al = p.add_literal(migraphx::literal{a_shape, a});
int axis = 3;
p.add_instruction(migraphx::op::logsoftmax{axis}, al);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(results_vector, s));
}
TEST_CASE(logsoftmax_test_axis_4)
{
migraphx::program p;
std::vector<float> a = {
1.93885877, -1.20006269, 0.90960855, 0.42108916, -1.50797544, -1.31047913,
1.07816336, -1.13288733, -0.86411064, 0.97800238, 0.76631385, 2.07962834,
-0.8940665 , -1.62855592, -0.53763057, -1.48165117, -0.64154112, 0.42486547,
0.89330917, -2.42022666, 0.192611 , -0.01257413, -1.5326607 , 0.53137897,
-1.52383859, 0.46994381, 0.00453619, 0.0066996 , 1.58394908, 0.84216752,
-0.04137941, -0.88580789, 1.44055158, -0.17621241, -1.98917923, -0.08610038,
0.79020567, -0.67714548, 0.42774631, 0.1376574 , 2.23569227, 1.16681234,
-1.21191456, -0.28411502, -0.18688975, 1.67552548, 2.48357974, 0.95891282,
-0.06616535, -0.99628491, 1.04314606, -1.22943315, 0.76930403, 0.31106618};
std::vector<float> s = {
0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000,
0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000,
0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000,
0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000,
0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000,
0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000,
0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000,
0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000,
0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000};
migraphx::shape a_shape{migraphx::shape::float_type, {2, 3, 3, 3}};
auto al = p.add_literal(migraphx::literal{a_shape, a});
int axis = 4;
p.add_instruction(migraphx::op::logsoftmax{axis}, al);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(results_vector, s));
}
TEST_CASE(conv2d_test) TEST_CASE(conv2d_test)
{ {
migraphx::program p; migraphx::program p;
......
logsoftmax-example:l

xy"
LogSoftmax*
axistest_logsoftmaxZ
x




b
y




B
\ No newline at end of file
...@@ -651,4 +651,15 @@ TEST_CASE(add_fp16_test) ...@@ -651,4 +651,15 @@ TEST_CASE(add_fp16_test)
EXPECT(p == prog); EXPECT(p == prog);
} }
TEST_CASE(logsoftmax)
{
migraphx::program p;
auto l0 = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}});
int axis = 1;
p.add_instruction(migraphx::op::logsoftmax{axis}, l0);
auto prog = migraphx::parse_onnx("logsoftmax_test.onnx");
EXPECT(p == prog);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -316,6 +316,61 @@ TEST_CASE(gather) ...@@ -316,6 +316,61 @@ TEST_CASE(gather)
} }
} }
TEST_CASE(logsoftmax)
{
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
int axis = 0;
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}},
migraphx::op::logsoftmax{axis},
input);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
int axis = 1;
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}},
migraphx::op::logsoftmax{axis},
input);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
int axis = 2;
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}},
migraphx::op::logsoftmax{axis},
input);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
int axis = 3;
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}},
migraphx::op::logsoftmax{axis},
input);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
int axis = 4;
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}},
migraphx::op::logsoftmax{axis},
input);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
int axis = 5;
throws_shape(migraphx::op::logsoftmax{axis}, input);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
int axis = -1;
throws_shape(migraphx::op::logsoftmax{axis}, input);
}
}
TEST_CASE(rnn) TEST_CASE(rnn)
{ {
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment