Commit 0b12d3c7 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

merge logsoftmax changes here.

parents 9b53cf55 6590c790
......@@ -957,7 +957,7 @@ struct logsoftmax
shape compute_shape(std::vector<shape> inputs) const
{
check_shapes{inputs}.has(1);
if(axis < 0 || axis >= inputs[0].lens().size())
if(axis < 0 || axis > inputs[0].lens().size())
{
MIGRAPHX_THROW("LogSoftMax: input axis value " + std::to_string(axis) +
" is out of range");
......
......@@ -652,18 +652,18 @@ struct cpu_logsoftmax
std::vector<value_type> batch_max(batch_shape.elements(),
std::numeric_limits<value_type>::lowest());
shape_for_each(output_shape, [&](auto idx) {
auto index = compute_batch_index(idx, batch_shape, op.axis);
auto index = this->compute_batch_index(idx, batch_shape, op.axis);
batch_max[index] = std::max(batch_max[index], input(idx.begin(), idx.end()));
});
shape_for_each(output_shape, [&](auto idx) {
auto index = compute_batch_index(idx, batch_shape, op.axis);
auto index = this->compute_batch_index(idx, batch_shape, op.axis);
output(idx.begin(), idx.end()) = input(idx.begin(), idx.end()) - batch_max[index];
});
std::vector<value_type> batch_sum(batch_shape.elements(), value_type(0));
shape_for_each(output_shape, [&](auto idx) {
auto index = compute_batch_index(idx, batch_shape, op.axis);
auto index = this->compute_batch_index(idx, batch_shape, op.axis);
batch_sum[index] += std::exp(output(idx.begin(), idx.end()));
});
......@@ -673,7 +673,7 @@ struct cpu_logsoftmax
}
shape_for_each(output_shape, [&](auto idx) {
auto index = compute_batch_index(idx, batch_shape, op.axis);
auto index = this->compute_batch_index(idx, batch_shape, op.axis);
output(idx.begin(), idx.end()) -= batch_sum[index];
});
});
......
......@@ -1085,6 +1085,176 @@ TEST_CASE(softmax_test)
EXPECT(migraphx::verify_range(results_vector, s));
}
TEST_CASE(logsoftmax_test_axis_0)
{
migraphx::program p;
std::vector<float> a = {
1.93885877, -1.20006269, 0.90960855, 0.42108916, -1.50797544, -1.31047913, 1.07816336,
-1.13288733, -0.86411064, 0.97800238, 0.76631385, 2.07962834, -0.8940665, -1.62855592,
-0.53763057, -1.48165117, -0.64154112, 0.42486547, 0.89330917, -2.42022666, 0.192611,
-0.01257413, -1.5326607, 0.53137897, -1.52383859, 0.46994381, 0.00453619, 0.0066996,
1.58394908, 0.84216752, -0.04137941, -0.88580789, 1.44055158, -0.17621241, -1.98917923,
-0.08610038, 0.79020567, -0.67714548, 0.42774631, 0.1376574, 2.23569227, 1.16681234,
-1.21191456, -0.28411502, -0.18688975, 1.67552548, 2.48357974, 0.95891282, -0.06616535,
-0.99628491, 1.04314606, -1.22943315, 0.76930403, 0.31106618};
std::vector<float> s = {
-2.71138556, -5.85030702, -3.74063578, -4.22915517, -6.15821977, -5.96072346, -3.57208097,
-5.78313166, -5.51435497, -3.67224195, -3.88393048, -2.57061599, -5.54431083, -6.27880025,
-5.1878749, -6.1318955, -5.29178545, -4.22537886, -3.75693516, -7.07047099, -4.45763333,
-4.66281846, -6.18290503, -4.11886536, -6.17408292, -4.18030052, -4.64570814, -4.64354473,
-3.06629525, -3.80807681, -4.69162374, -5.53605222, -3.20969275, -4.82645674, -6.63942356,
-4.73634471, -3.86003866, -5.32738981, -4.22249802, -4.51258693, -2.41455206, -3.48343199,
-5.86215889, -4.93435935, -4.83713408, -2.97471885, -2.16666459, -3.69133151, -4.71640968,
-5.64652924, -3.60709827, -5.87967748, -3.8809403, -4.33917815};
migraphx::shape a_shape{migraphx::shape::float_type, {2, 3, 3, 3}};
auto al = p.add_literal(migraphx::literal{a_shape, a});
int axis = 0;
p.add_instruction(migraphx::op::logsoftmax{axis}, al);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(results_vector, s));
}
TEST_CASE(logsoftmax_test_axis_1)
{
migraphx::program p;
std::vector<float> a = {
1.93885877, -1.20006269, 0.90960855, 0.42108916, -1.50797544, -1.31047913, 1.07816336,
-1.13288733, -0.86411064, 0.97800238, 0.76631385, 2.07962834, -0.8940665, -1.62855592,
-0.53763057, -1.48165117, -0.64154112, 0.42486547, 0.89330917, -2.42022666, 0.192611,
-0.01257413, -1.5326607, 0.53137897, -1.52383859, 0.46994381, 0.00453619, 0.0066996,
1.58394908, 0.84216752, -0.04137941, -0.88580789, 1.44055158, -0.17621241, -1.98917923,
-0.08610038, 0.79020567, -0.67714548, 0.42774631, 0.1376574, 2.23569227, 1.16681234,
-1.21191456, -0.28411502, -0.18688975, 1.67552548, 2.48357974, 0.95891282, -0.06616535,
-0.99628491, 1.04314606, -1.22943315, 0.76930403, 0.31106618};
std::vector<float> s = {
-1.77931988, -4.91824134, -2.80857010, -3.29708949, -5.22615409, -5.02865778, -2.64001529,
-4.85106598, -4.58228929, -2.74017627, -2.95186480, -1.63855031, -4.61224515, -5.34673457,
-4.25580922, -5.19982982, -4.35971977, -3.29331318, -2.82486948, -6.13840531, -3.52556765,
-3.73075278, -5.25083935, -3.18679968, -5.24201724, -3.24823484, -3.71364246, -4.14309917,
-2.56584969, -3.30763125, -4.19117818, -5.03560666, -2.70924719, -4.32601118, -6.13897800,
-4.23589915, -3.35959310, -4.82694425, -3.72205246, -4.01214137, -1.91410650, -2.98298643,
-5.36171333, -4.43391379, -4.33668852, -2.47427329, -1.66621903, -3.19088595, -4.21596412,
-5.14608368, -3.10665271, -5.37923192, -3.38049474, -3.83873259};
migraphx::shape a_shape{migraphx::shape::float_type, {2, 3, 3, 3}};
auto al = p.add_literal(migraphx::literal{a_shape, a});
int axis = 1;
p.add_instruction(migraphx::op::logsoftmax{axis}, al);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(results_vector, s));
}
TEST_CASE(logsoftmax_test_axis_2)
{
migraphx::program p;
std::vector<float> a = {
1.93885877, -1.20006269, 0.90960855, 0.42108916, -1.50797544, -1.31047913, 1.07816336,
-1.13288733, -0.86411064, 0.97800238, 0.76631385, 2.07962834, -0.8940665, -1.62855592,
-0.53763057, -1.48165117, -0.64154112, 0.42486547, 0.89330917, -2.42022666, 0.192611,
-0.01257413, -1.5326607, 0.53137897, -1.52383859, 0.46994381, 0.00453619, 0.0066996,
1.58394908, 0.84216752, -0.04137941, -0.88580789, 1.44055158, -0.17621241, -1.98917923,
-0.08610038, 0.79020567, -0.67714548, 0.42774631, 0.1376574, 2.23569227, 1.16681234,
-1.21191456, -0.28411502, -0.18688975, 1.67552548, 2.48357974, 0.95891282, -0.06616535,
-0.99628491, 1.04314606, -1.22943315, 0.76930403, 0.31106618};
std::vector<float> s = {
-0.79763715, -3.93655861, -1.82688737, -2.31540676, -4.24447136, -4.04697505, -1.65833256,
-3.86938325, -3.60060656, -1.81223672, -2.02392525, -0.71061076, -3.68430560, -4.41879502,
-3.32786967, -4.27189027, -3.43178022, -2.36537363, -1.35498658, -4.66852241, -2.05568475,
-2.26086988, -3.78095645, -1.71691678, -3.77213434, -1.77835194, -2.24375956, -2.74631770,
-1.16906822, -1.91084978, -2.79439671, -3.63882519, -1.31246572, -2.92922971, -4.74219653,
-2.83911768, -2.19738500, -3.66473615, -2.55984436, -2.84993327, -0.75189840, -1.82077833,
-4.19950523, -3.27170569, -3.17448042, -1.65286841, -0.84481415, -2.36948107, -3.39455924,
-4.32467880, -2.28524783, -4.55782704, -2.55908986, -3.01732771};
migraphx::shape a_shape{migraphx::shape::float_type, {2, 3, 3, 3}};
auto al = p.add_literal(migraphx::literal{a_shape, a});
int axis = 2;
p.add_instruction(migraphx::op::logsoftmax{axis}, al);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(results_vector, s));
}
TEST_CASE(logsoftmax_test_axis_3)
{
migraphx::program p;
std::vector<float> a = {
1.93885877, -1.20006269, 0.90960855, 0.42108916, -1.50797544, -1.31047913, 1.07816336,
-1.13288733, -0.86411064, 0.97800238, 0.76631385, 2.07962834, -0.8940665, -1.62855592,
-0.53763057, -1.48165117, -0.64154112, 0.42486547, 0.89330917, -2.42022666, 0.192611,
-0.01257413, -1.5326607, 0.53137897, -1.52383859, 0.46994381, 0.00453619, 0.0066996,
1.58394908, 0.84216752, -0.04137941, -0.88580789, 1.44055158, -0.17621241, -1.98917923,
-0.08610038, 0.79020567, -0.67714548, 0.42774631, 0.1376574, 2.23569227, 1.16681234,
-1.21191456, -0.28411502, -0.18688975, 1.67552548, 2.48357974, 0.95891282, -0.06616535,
-0.99628491, 1.04314606, -1.22943315, 0.76930403, 0.31106618};
std::vector<float> s = {
-0.33690375, -3.47582521, -1.36615397, -0.27936556, -2.20843016, -2.01093385, -0.22551114,
-2.43656183, -2.16778514, -1.57241522, -1.78410375, -0.47078926, -1.06745881, -1.80194823,
-0.71102288, -2.30719726, -1.46708721, -0.40068062, -0.42698261, -3.74051844, -1.12768078,
-1.07891856, -2.59900513, -0.53496546, -2.56139951, -0.56761711, -1.03302473, -2.09771276,
-0.52046328, -1.26224484, -1.76322959, -2.60765807, -0.28129860, -0.81424303, -2.62720985,
-0.72413100, -0.65570381, -2.12305496, -1.01816317, -2.48063402, -0.38259915, -1.45147908,
-1.84310238, -0.91530284, -0.81807757, -1.31692881, -0.50887455, -2.03354147, -1.48767160,
-2.41779116, -0.37836019, -2.56853147, -0.56979429, -1.02803214};
migraphx::shape a_shape{migraphx::shape::float_type, {2, 3, 3, 3}};
auto al = p.add_literal(migraphx::literal{a_shape, a});
int axis = 3;
p.add_instruction(migraphx::op::logsoftmax{axis}, al);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(results_vector, s));
}
TEST_CASE(logsoftmax_test_axis_4)
{
migraphx::program p;
std::vector<float> a = {
1.93885877, -1.20006269, 0.90960855, 0.42108916, -1.50797544, -1.31047913, 1.07816336,
-1.13288733, -0.86411064, 0.97800238, 0.76631385, 2.07962834, -0.8940665, -1.62855592,
-0.53763057, -1.48165117, -0.64154112, 0.42486547, 0.89330917, -2.42022666, 0.192611,
-0.01257413, -1.5326607, 0.53137897, -1.52383859, 0.46994381, 0.00453619, 0.0066996,
1.58394908, 0.84216752, -0.04137941, -0.88580789, 1.44055158, -0.17621241, -1.98917923,
-0.08610038, 0.79020567, -0.67714548, 0.42774631, 0.1376574, 2.23569227, 1.16681234,
-1.21191456, -0.28411502, -0.18688975, 1.67552548, 2.48357974, 0.95891282, -0.06616535,
-0.99628491, 1.04314606, -1.22943315, 0.76930403, 0.31106618};
std::vector<float> s = {0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000,
0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000,
0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000,
0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000,
0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000,
0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000,
0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000,
0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000,
0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000, 0.00000000};
migraphx::shape a_shape{migraphx::shape::float_type, {2, 3, 3, 3}};
auto al = p.add_literal(migraphx::literal{a_shape, a});
int axis = 4;
p.add_instruction(migraphx::op::logsoftmax{axis}, al);
p.compile(migraphx::cpu::target{});
auto result = p.eval({});
std::vector<float> results_vector;
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
EXPECT(migraphx::verify_range(results_vector, s));
}
TEST_CASE(conv2d_test)
{
migraphx::program p;
......
......@@ -2925,6 +2925,34 @@ struct test_lstm_bidirct_default_actv2
}
};
template <int Axis>
struct test_logsoftmax
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3, 4, 5, 6}};
auto param = p.add_parameter("0", s);
p.add_instruction(migraphx::op::logsoftmax{Axis}, param);
return p;
}
};
template <int Axis>
struct test_logsoftmax_1
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {3}};
auto param = p.add_parameter("0", s);
p.add_instruction(migraphx::op::logsoftmax{Axis}, param);
return p;
}
};
int main()
{
verify_program<test_relu_lrn>();
......@@ -3040,4 +3068,11 @@ int main()
verify_program<test_lstm_bidirct_default_actv>();
verify_program<test_lstm_bidirct_default_actv1>();
verify_program<test_lstm_bidirct_default_actv2>();
verify_program<test_logsoftmax<0>>();
verify_program<test_logsoftmax<1>>();
verify_program<test_logsoftmax<2>>();
verify_program<test_logsoftmax<3>>();
verify_program<test_logsoftmax<4>>();
verify_program<test_logsoftmax_1<0>>();
verify_program<test_logsoftmax_1<1>>();
}
logsoftmax-example:l

xy"
LogSoftmax*
axistest_logsoftmaxZ
x




b
y




B
\ No newline at end of file
......@@ -672,4 +672,15 @@ TEST_CASE(add_fp16_test)
EXPECT(p == prog);
}
TEST_CASE(logsoftmax)
{
migraphx::program p;
auto l0 = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {3, 4, 5, 6}});
int axis = 1;
p.add_instruction(migraphx::op::logsoftmax{axis}, l0);
auto prog = migraphx::parse_onnx("logsoftmax_test.onnx");
EXPECT(p == prog);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......@@ -316,6 +316,7 @@ TEST_CASE(gather)
}
}
<<<<<<< HEAD
TEST_CASE(dot)
{
{
......@@ -390,6 +391,61 @@ TEST_CASE(dot)
}
}
TEST_CASE(logsoftmax)
{
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
int axis = 0;
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}},
migraphx::op::logsoftmax{axis},
input);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
int axis = 1;
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}},
migraphx::op::logsoftmax{axis},
input);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
int axis = 2;
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}},
migraphx::op::logsoftmax{axis},
input);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
int axis = 3;
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}},
migraphx::op::logsoftmax{axis},
input);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
int axis = 4;
expect_shape(migraphx::shape{migraphx::shape::float_type, {2, 3, 4, 5}},
migraphx::op::logsoftmax{axis},
input);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
int axis = 5;
throws_shape(migraphx::op::logsoftmax{axis}, input);
}
{
migraphx::shape input{migraphx::shape::float_type, {2, 3, 4, 5}};
int axis = -1;
throws_shape(migraphx::op::logsoftmax{axis}, input);
}
}
TEST_CASE(rnn)
{
{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment