Commit 8d7a8a6c authored by Artur Wojcik's avatar Artur Wojcik
Browse files

Merge branch 'develop' into uif2-initial

parents 25b33431 a09dc502
...@@ -27,7 +27,8 @@ ...@@ -27,7 +27,8 @@
#include <migraphx/generate.hpp> #include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp> #include <migraphx/make_op.hpp>
struct test_where : verify_program<test_where> template <migraphx::shape::type_t DType>
struct test_where : verify_program<test_where<DType>>
{ {
migraphx::program create_program() const migraphx::program create_program() const
{ {
...@@ -44,3 +45,7 @@ struct test_where : verify_program<test_where> ...@@ -44,3 +45,7 @@ struct test_where : verify_program<test_where>
return p; return p;
}; };
}; };
template struct test_where<migraphx::shape::float_type>;
template struct test_where<migraphx::shape::half_type>;
template struct test_where<migraphx::shape::fp8e4m3fnuz_type>;
...@@ -134,10 +134,11 @@ def check_correctness(gold_outputs, ...@@ -134,10 +134,11 @@ def check_correctness(gold_outputs,
if not np.allclose(gold_outputs[i], outputs[i], rtol, atol): if not np.allclose(gold_outputs[i], outputs[i], rtol, atol):
ret = False ret = False
if verbose: if verbose:
print('\nOutput {} is incorrect ...'.format(i)) with np.printoptions(threshold=np.inf):
print('Expected value: \n{}'.format(gold_outputs[i])) print('\nOutput {} is incorrect ...'.format(i))
print('......') print('Expected value: \n{}\n'.format(gold_outputs[i]))
print('Actual value: \n{}\n'.format(outputs[i])) print('\n......\n')
print('Actual value: \n{}\n'.format(outputs[i]))
else: else:
print('Outputs do not match') print('Outputs do not match')
break break
......
...@@ -63,7 +63,8 @@ def clang_format(against, apply=False, path=CLANG_FORMAT_PATH): ...@@ -63,7 +63,8 @@ def clang_format(against, apply=False, path=CLANG_FORMAT_PATH):
print(f"{git_clang_format} not installed. Skipping format.") print(f"{git_clang_format} not installed. Skipping format.")
return return
diff_flag = "" if apply else "--diff" diff_flag = "" if apply else "--diff"
run(f"{git_clang_format} --binary {clang_format} {diff_flag} {base}") run(f"{git_clang_format} --extensions c,cpp,hpp,h,cl,hip,in --binary {clang_format} {diff_flag} {base}"
)
def get_files_changed(against, ext=('py')): def get_files_changed(against, ext=('py')):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment