Commit 4c1e707b authored by Khalique's avatar Khalique
Browse files

fix test cases, revert code

parent 366d4f83
...@@ -33,12 +33,22 @@ argument softmax(hipStream_t stream, ...@@ -33,12 +33,22 @@ argument softmax(hipStream_t stream,
std::size_t row_start = i * n_dims; std::size_t row_start = i * n_dims;
// get max // get max
auto batch_max = input_ptr[row_start]; auto batch_max = input_ptr[row_start];
for(std::size_t j = 1; j < n_dims; ++j)
{
auto ind = row_start + j;
batch_max = std::max(to_hip_type(batch_max), to_hip_type(input_ptr[ind]));
}
for(std::size_t j = 0; j < n_dims; ++j)
{
auto ind = row_start + j;
output_ptr[ind] = input_ptr[ind] - batch_max;
}
for(std::size_t j = 0; j < n_dims; ++j) for(std::size_t j = 0; j < n_dims; ++j)
{ {
auto ind = row_start + j; auto ind = row_start + j;
auto hip_type_input = to_hip_type(input_ptr[ind]); output_ptr[ind] = exp(to_hip_type(input_ptr[ind]));
batch_max = std::max(to_hip_type(batch_max), hip_type_input);
output_ptr[ind] = ::exp(hip_type_input);
} }
auto batch_sum = output_ptr[row_start]; auto batch_sum = output_ptr[row_start];
......
...@@ -574,7 +574,7 @@ struct test_softmax : verify_program<test_softmax> ...@@ -574,7 +574,7 @@ struct test_softmax : verify_program<test_softmax>
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto x = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {5, 3, 4, 2}}); auto x = p.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {5, 3, 1, 1}});
p.add_instruction(migraphx::op::softmax{}, x); p.add_instruction(migraphx::op::softmax{}, x);
return p; return p;
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment