Unverified Commit dc8edad8 authored by Chris Austen's avatar Chris Austen Committed by GitHub
Browse files

Merge branch 'develop' into check-mlir-perf

parents b9cbfd8e 32f0b028
...@@ -27,15 +27,20 @@ ...@@ -27,15 +27,20 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
struct test_sinh : verify_program<test_sinh> template <migraphx::shape::type_t DType>
struct test_sinh : verify_program<test_sinh<DType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {16}}; migraphx::shape s{DType, {16}};
auto x = mm->add_parameter("x", s); auto x = mm->add_parameter("x", s);
mm->add_instruction(migraphx::make_op("sinh"), x); mm->add_instruction(migraphx::make_op("sinh"), x);
return p; return p;
} }
}; };
template struct test_sinh<migraphx::shape::float_type>;
template struct test_sinh<migraphx::shape::half_type>;
template struct test_sinh<migraphx::shape::fp8e4m3fnuz_type>;
...@@ -48,3 +48,7 @@ template struct test_softmax<0, migraphx::shape::half_type>; ...@@ -48,3 +48,7 @@ template struct test_softmax<0, migraphx::shape::half_type>;
template struct test_softmax<1, migraphx::shape::half_type>; template struct test_softmax<1, migraphx::shape::half_type>;
template struct test_softmax<2, migraphx::shape::half_type>; template struct test_softmax<2, migraphx::shape::half_type>;
template struct test_softmax<3, migraphx::shape::half_type>; template struct test_softmax<3, migraphx::shape::half_type>;
template struct test_softmax<0, migraphx::shape::fp8e4m3fnuz_type>;
template struct test_softmax<1, migraphx::shape::fp8e4m3fnuz_type>;
template struct test_softmax<2, migraphx::shape::fp8e4m3fnuz_type>;
template struct test_softmax<3, migraphx::shape::fp8e4m3fnuz_type>;
...@@ -27,16 +27,21 @@ ...@@ -27,16 +27,21 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
struct test_sqrt : verify_program<test_sqrt> template <migraphx::shape::type_t DType>
struct test_sqrt : verify_program<test_sqrt<DType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {2, 3, 4, 6}}; migraphx::shape s{DType, {2, 3, 4, 6}};
auto param = mm->add_parameter("x", s); auto param = mm->add_parameter("x", s);
auto param_abs = mm->add_instruction(migraphx::make_op("abs"), param); auto param_abs = mm->add_instruction(migraphx::make_op("abs"), param);
mm->add_instruction(migraphx::make_op("sqrt"), param_abs); mm->add_instruction(migraphx::make_op("sqrt"), param_abs);
return p; return p;
} }
}; };
template struct test_sqrt<migraphx::shape::float_type>;
template struct test_sqrt<migraphx::shape::half_type>;
template struct test_sqrt<migraphx::shape::fp8e4m3fnuz_type>;
...@@ -27,15 +27,20 @@ ...@@ -27,15 +27,20 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
struct test_tan : verify_program<test_tan> template <migraphx::shape::type_t DType>
struct test_tan : verify_program<test_tan<DType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
migraphx::shape s{migraphx::shape::float_type, {16}}; migraphx::shape s{DType, {16}};
auto x = mm->add_parameter("x", s); auto x = mm->add_parameter("x", s);
mm->add_instruction(migraphx::make_op("tan"), x); mm->add_instruction(migraphx::make_op("tan"), x);
return p; return p;
} }
}; };
template struct test_tan<migraphx::shape::float_type>;
template struct test_tan<migraphx::shape::half_type>;
template struct test_tan<migraphx::shape::fp8e4m3fnuz_type>;
...@@ -27,14 +27,19 @@ ...@@ -27,14 +27,19 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
struct test_tanh : verify_program<test_tanh> template <migraphx::shape::type_t DType>
struct test_tanh : verify_program<test_tanh<DType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
migraphx::program p; migraphx::program p;
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {4, 3, 3, 3}}); auto x = mm->add_parameter("x", migraphx::shape{DType, {4, 3, 3, 3}});
mm->add_instruction(migraphx::make_op("tanh"), x); mm->add_instruction(migraphx::make_op("tanh"), x);
return p; return p;
} }
}; };
template struct test_tanh<migraphx::shape::float_type>;
template struct test_tanh<migraphx::shape::half_type>;
template struct test_tanh<migraphx::shape::fp8e4m3fnuz_type>;
...@@ -27,7 +27,8 @@ ...@@ -27,7 +27,8 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
struct test_where : verify_program<test_where> template <migraphx::shape::type_t DType>
struct test_where : verify_program<test_where<DType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -44,3 +45,7 @@ struct test_where : verify_program<test_where> ...@@ -44,3 +45,7 @@ struct test_where : verify_program<test_where>
return p; return p;
}; };
}; };
template struct test_where<migraphx::shape::float_type>;
template struct test_where<migraphx::shape::half_type>;
template struct test_where<migraphx::shape::fp8e4m3fnuz_type>;
...@@ -134,9 +134,10 @@ def check_correctness(gold_outputs, ...@@ -134,9 +134,10 @@ def check_correctness(gold_outputs,
if not np.allclose(gold_outputs[i], outputs[i], rtol, atol): if not np.allclose(gold_outputs[i], outputs[i], rtol, atol):
ret = False ret = False
if verbose: if verbose:
with np.printoptions(threshold=np.inf):
print('\nOutput {} is incorrect ...'.format(i)) print('\nOutput {} is incorrect ...'.format(i))
print('Expected value: \n{}'.format(gold_outputs[i])) print('Expected value: \n{}\n'.format(gold_outputs[i]))
print('......') print('\n......\n')
print('Actual value: \n{}\n'.format(outputs[i])) print('Actual value: \n{}\n'.format(outputs[i]))
else: else:
print('Outputs do not match') print('Outputs do not match')
......
...@@ -63,7 +63,8 @@ def clang_format(against, apply=False, path=CLANG_FORMAT_PATH): ...@@ -63,7 +63,8 @@ def clang_format(against, apply=False, path=CLANG_FORMAT_PATH):
print(f"{git_clang_format} not installed. Skipping format.") print(f"{git_clang_format} not installed. Skipping format.")
return return
diff_flag = "" if apply else "--diff" diff_flag = "" if apply else "--diff"
run(f"{git_clang_format} --binary {clang_format} {diff_flag} {base}") run(f"{git_clang_format} --extensions c,cpp,hpp,h,cl,hip,in --binary {clang_format} {diff_flag} {base}"
)
def get_files_changed(against, ext=('py')): def get_files_changed(against, ext=('py')):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment