Commit a797f890 authored by Paul Fultz II's avatar Paul Fultz II Committed by mvermeulen
Browse files

Fix bug in bert accuraccy (#385)

* Fix bug in bert accuraccy

* Formatting

* add another test

* Fix add and overflow

* Formatting

* Fix bug in shape_for_each

* Use front instead of iterator

* Use result.front()

* Split add_unary files

* Formatting

* Fix incorrect last index

* Remove comment

* Inline function

* Fix carry check

* Fix metadata errors

* Formatting

* Reflow

* Reflow
parent a625f7b4
#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_MUL_ADD_RELU_HPP
#define MIGRAPHX_GUARD_RTGLIB_DEVICE_MUL_ADD_RELU_HPP
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
#include <hip/hip_runtime_api.h>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
void mul_add_relu(hipStream_t stream,
const argument& result,
const argument& arg1,
const argument& arg2,
const argument& arg3);
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
......@@ -1635,8 +1635,7 @@ TEST_CASE(contiguous_test)
result.visit([&](auto output) { results_vector.assign(output.begin(), output.end()); });
std::vector<size_t> new_lens = {1, 3, 2, 2};
std::vector<size_t> new_strides = {12, 1, 6, 3};
std::vector<float> gold = {0, 3, 6, 9, 1, 4, 7, 10, 2, 5, 8, 11};
EXPECT(migraphx::verify_range(results_vector, gold));
EXPECT(migraphx::verify_range(results_vector, data));
}
TEST_CASE(identity_test)
......
......@@ -1643,6 +1643,32 @@ struct test_contiguous : verify_program<test_contiguous>
}
};
struct test_contiguous_broadcast : verify_program<test_contiguous_broadcast>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {1, 2}, {0, 1}};
auto x = p.add_parameter("x", s);
p.add_instruction(migraphx::op::contiguous{}, x);
EXPECT(p.get_shape().standard());
return p;
}
};
struct test_contiguous_broadcast_transpose : verify_program<test_contiguous_broadcast_transpose>
{
migraphx::program create_program() const
{
migraphx::program p;
migraphx::shape s{migraphx::shape::float_type, {1, 3072, 768}, {0, 1, 3072}};
auto x = p.add_parameter("x", s);
p.add_instruction(migraphx::op::contiguous{}, x);
EXPECT(p.get_shape().standard());
return p;
}
};
struct test_transpose : verify_program<test_transpose>
{
migraphx::program create_program() const
......
......@@ -29,6 +29,15 @@ TEST_CASE(test_shape_packed_default)
EXPECT(not s.broadcasted());
}
TEST_CASE(test_shape_standard)
{
migraphx::shape s{migraphx::shape::float_type, {2, 2, 3}, {6, 3, 1}};
EXPECT(s.standard());
EXPECT(s.packed());
EXPECT(not s.transposed());
EXPECT(not s.broadcasted());
}
TEST_CASE(test_shape_packed)
{
migraphx::shape s{migraphx::shape::float_type, {2, 2}, {2, 1}};
......@@ -56,6 +65,33 @@ TEST_CASE(test_shape_transposed2)
EXPECT(not s.broadcasted());
}
TEST_CASE(test_shape_overlap)
{
migraphx::shape s{migraphx::shape::float_type, {2, 2, 3}, {6, 3, 2}};
EXPECT(not s.standard());
EXPECT(not s.packed());
EXPECT(not s.transposed());
EXPECT(not s.broadcasted());
}
TEST_CASE(test_shape_overlap2)
{
migraphx::shape s{migraphx::shape::float_type, {2, 2, 3}, {6, 2, 1}};
EXPECT(not s.standard());
EXPECT(not s.packed());
EXPECT(not s.transposed());
EXPECT(not s.broadcasted());
}
TEST_CASE(test_shape_overlap3)
{
migraphx::shape s{migraphx::shape::float_type, {2, 2, 3}, {4, 2, 1}};
EXPECT(not s.standard());
EXPECT(not s.packed());
EXPECT(not s.transposed());
EXPECT(not s.broadcasted());
}
TEST_CASE(test_shape_broadcasted)
{
migraphx::shape s{migraphx::shape::float_type, {2, 2}, {1, 0}};
......@@ -65,6 +101,42 @@ TEST_CASE(test_shape_broadcasted)
EXPECT(s.broadcasted());
}
TEST_CASE(test_shape_broadcasted2)
{
migraphx::shape s{migraphx::shape::float_type, {1, 2}, {0, 1}};
EXPECT(not s.standard());
EXPECT(s.packed());
EXPECT(not s.transposed());
EXPECT(s.broadcasted());
}
TEST_CASE(test_shape_broadcasted3)
{
migraphx::shape s{migraphx::shape::float_type, {3, 2}, {0, 1}};
EXPECT(not s.standard());
EXPECT(not s.packed());
EXPECT(not s.transposed());
EXPECT(s.broadcasted());
}
TEST_CASE(test_shape_broadcasted4)
{
migraphx::shape s{migraphx::shape::float_type, {2, 2, 3}, {6, 0, 1}};
EXPECT(not s.standard());
EXPECT(not s.packed());
EXPECT(not s.transposed());
EXPECT(s.broadcasted());
}
TEST_CASE(test_shape_broadcasted5)
{
migraphx::shape s{migraphx::shape::float_type, {2, 2, 3}, {1, 0, 6}};
EXPECT(not s.standard());
EXPECT(not s.packed());
EXPECT(s.transposed());
EXPECT(s.broadcasted());
}
TEST_CASE(test_shape_default_copy)
{
migraphx::shape s1{};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment