Unverified commit b8b4630b, authored by nives-vukovic, committed by GitHub

Fix trilu operator computation logic (#2212)

parent 6072b2c4
@@ -56,9 +56,6 @@ struct parse_trilu : op_parser<parse_trilu>
             k = arg_k.at<int>();
         }
 
-        if(k < 0)
-            MIGRAPHX_THROW("PARSE_TRILU: negative k values not supported");
-
         if(contains(info.attributes, "upper"))
         {
             upper = static_cast<bool>(info.attributes.at("upper").i());
@@ -69,9 +66,12 @@ struct parse_trilu : op_parser<parse_trilu>
         // when creating the mask, if upper == 1,
         // the inner triangle will have values set to 0
         std::vector<bool> mask_mat(num_rows * num_cols, upper);
+        // if upper == 0, kth diagonal must also be masked
+        if(not upper)
+            k++;
         for(size_t i = 0; i < num_rows; i++)
         {
-            for(size_t j = 0; j < std::min(static_cast<int>(i) + k, static_cast<int>(num_cols)); j++)
+            for(int j = 0; j < std::min(static_cast<int>(i) + k, static_cast<int>(num_cols)); j++)
             {
                 mask_mat[i * num_cols + j] = not upper;
             }
...
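Note on the fix: the parser builds a boolean mask matching ONNX Trilu semantics. With upper == 1 the mask starts all-true and every position strictly below the kth diagonal (j < i + k) is cleared; with upper == 0 the mask starts all-false and every position on or below the kth diagonal (j <= i + k) is set, which is why k is incremented first in the lower-triangular case. A minimal numpy sketch of the same mask construction (trilu_ref is an illustrative name, not MIGraphX code):

    import numpy as np

    def trilu_ref(x, k=0, upper=True):
        # mask starts out as `upper`: all-keep for triu, all-drop for tril
        num_rows, num_cols = x.shape
        mask = np.full((num_rows, num_cols), upper, dtype=bool)
        if not upper:
            k += 1  # keep the kth diagonal itself in the lower-triangular case
        for i in range(num_rows):
            # columns [0, i + k) lie strictly below the kth diagonal
            for j in range(min(i + k, num_cols)):
                mask[i, j] = not upper
        return x * mask

    x = np.arange(1, 13, dtype=np.float32).reshape(3, 4)
    assert np.array_equal(trilu_ref(x, k=-1, upper=True), np.triu(x, -1))
    assert np.array_equal(trilu_ref(x, k=-1, upper=False), np.tril(x, -1))

Negative k now falls out naturally: range(min(i + k, num_cols)) stays empty until i + k becomes positive, which is why the dedicated MIGRAPHX_THROW guard above and the throwing trilu_neg_k_test case below can both be removed.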
@@ -8573,7 +8573,7 @@ def transpose_gather_test():
 @onnx_test()
-def trilu_test():
+def triu_test():
     x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [3, 4])
     y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3, 4])
@@ -8586,7 +8586,7 @@ def trilu_test():
 @onnx_test()
-def trilu_batch_diff_k_test():
+def triu_batch_diff_k_test():
     x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [2, 2, 3])
     k = np.array([2])
     y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [2, 2, 3])
@@ -8604,7 +8604,24 @@ def trilu_batch_diff_k_test():
 @onnx_test()
-def trilu_lower_test():
+def tril_batch_diff_k_test():
+    x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [2, 2, 3])
+    k = np.array([2])
+    y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [2, 2, 3])
+    k_tensor = helper.make_tensor(name='k',
+                                  data_type=TensorProto.INT64,
+                                  dims=k.shape,
+                                  vals=k.astype(np.int64))
+    node = onnx.helper.make_node('Trilu',
+                                 inputs=['x', 'k'],
+                                 outputs=['y'],
+                                 upper=0)
+    return ([node], [x], [y], [k_tensor])
+
+
+@onnx_test()
+def tril_test():
     x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [3, 4])
     y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3, 4])
@@ -8613,7 +8630,7 @@ def trilu_lower_test():
 @onnx_test()
-def trilu_neg_k_test():
+def triu_neg_k_test():
     x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [3, 4])
     k = np.array([-1])
     y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3, 4])
@@ -8627,7 +8644,23 @@ def trilu_neg_k_test():
 @onnx_test()
-def trilu_out_k_test():
+def tril_neg_k_test():
+    x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [3, 4])
+    k = np.array([-1])
+    y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3, 4])
+    k_tensor = helper.make_tensor(name='k',
+                                  data_type=TensorProto.INT64,
+                                  dims=k.shape,
+                                  vals=k.astype(np.int64))
+    node = onnx.helper.make_node('Trilu',
+                                 inputs=['x', 'k'],
+                                 outputs=['y'],
+                                 upper=0)
+    return ([node], [x], [y], [k_tensor])
+
+
+@onnx_test()
+def triu_out_k_test():
     x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [3, 4])
     k = np.array([5])
     y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3, 4])
@@ -8641,7 +8674,23 @@ def trilu_out_k_test():
 @onnx_test()
-def trilu_row_one_test():
+def tril_out_k_test():
+    x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [3, 4])
+    k = np.array([5])
+    y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3, 4])
+    k_tensor = helper.make_tensor(name='k',
+                                  data_type=TensorProto.INT64,
+                                  dims=k.shape,
+                                  vals=k.astype(np.int64))
+    node = onnx.helper.make_node('Trilu',
+                                 inputs=['x', 'k'],
+                                 outputs=['y'],
+                                 upper=0)
+    return ([node], [x], [y], [k_tensor])
+
+
+@onnx_test()
+def triu_row_one_test():
     x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [1, 4])
     k = np.array([1])
     y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [1, 4])
...@@ -8658,6 +8707,23 @@ def trilu_row_one_test(): ...@@ -8658,6 +8707,23 @@ def trilu_row_one_test():
return ([node], [x], [y], [k_tensor]) return ([node], [x], [y], [k_tensor])
@onnx_test()
def tril_row_one_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [1, 4])
k = np.array([1])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [1, 4])
k_tensor = helper.make_tensor(name='k',
data_type=TensorProto.INT64,
dims=k.shape,
vals=k.astype(np.int64))
node = onnx.helper.make_node('Trilu',
inputs=['x', 'k'],
outputs=['y'],
upper=0)
return ([node], [x], [y], [k_tensor])
@onnx_test() @onnx_test()
def undefined_test(): def undefined_test():
x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [2, 3, 4, 5]) x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [2, 3, 4, 5])
......
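Each new tril_* generator mirrors its triu_* counterpart with upper=0 and k supplied as an int64 initializer. For a standalone sanity check of the same node construction outside the @onnx_test harness, a model built this way can be compared against numpy's np.tril (a sketch assuming onnxruntime is installed; it is not part of this suite):

    import numpy as np
    import onnx
    from onnx import TensorProto, helper
    import onnxruntime as ort

    x_info = helper.make_tensor_value_info('x', TensorProto.FLOAT, [3, 4])
    y_info = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3, 4])
    k = np.array([-1])
    k_tensor = helper.make_tensor(name='k', data_type=TensorProto.INT64,
                                  dims=k.shape, vals=k.astype(np.int64))
    node = helper.make_node('Trilu', inputs=['x', 'k'], outputs=['y'], upper=0)
    graph = helper.make_graph([node], 'tril_check', [x_info], [y_info],
                              initializer=[k_tensor])
    # Trilu was introduced in opset 14
    model = helper.make_model(graph, opset_imports=[helper.make_opsetid('', 14)])
    onnx.checker.check_model(model)

    x = np.arange(1, 13, dtype=np.float32).reshape(3, 4)
    sess = ort.InferenceSession(model.SerializeToString(),
                                providers=['CPUExecutionProvider'])
    (y,) = sess.run(None, {'x': x})
    assert np.array_equal(y, np.tril(x, int(k[0])))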
@@ -8091,11 +8091,6 @@ TEST_CASE(transpose_gather_test)
     EXPECT(p.sort() == prog.sort());
 }
 
-TEST_CASE(trilu_neg_k_test)
-{
-    EXPECT(test::throws([&] { migraphx::parse_onnx("trilu_neg_k_test.onnx"); }));
-}
-
 TEST_CASE(undefined_test)
 {
     migraphx::program p;
...
[Binary .onnx test files, not rendered: trilu_batch_diff_k_test.onnx, trilu_neg_k_test.onnx, trilu_out_k_test.onnx, trilu_row_one_test.onnx, and trilu_test.onnx are renamed to their triu_* counterparts, with the graph names embedded in the protobufs updated to match; the tril_*.onnx files referenced by the new tests are added as new binaries.]
...
@@ -2233,9 +2233,10 @@ std::vector<float> gen_trilu_test(const migraphx::shape& s, const migraphx::program& p)
     result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
     return result_vector;
 }
 
-TEST_CASE(trilu_test)
+TEST_CASE(triu_test)
 {
-    migraphx::program p = migraphx::parse_onnx("trilu_test.onnx");
+    migraphx::program p = migraphx::parse_onnx("triu_test.onnx");
     std::vector<float> result_vector = gen_trilu_test({migraphx::shape::float_type, {3, 4}}, p);
 
@@ -2244,9 +2245,9 @@ TEST_CASE(trilu_test)
     EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
 }
 
-TEST_CASE(trilu_batch_diff_k_test)
+TEST_CASE(triu_batch_diff_k_test)
 {
-    migraphx::program p = migraphx::parse_onnx("trilu_batch_diff_k_test.onnx");
+    migraphx::program p = migraphx::parse_onnx("triu_batch_diff_k_test.onnx");
     std::vector<float> result_vector = gen_trilu_test({migraphx::shape::float_type, {2, 2, 3}}, p);
 
@@ -2255,9 +2256,42 @@ TEST_CASE(trilu_batch_diff_k_test)
     EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
 }
 
-TEST_CASE(trilu_lower_test)
+TEST_CASE(tril_test)
 {
-    migraphx::program p = migraphx::parse_onnx("trilu_lower_test.onnx");
+    migraphx::program p = migraphx::parse_onnx("tril_test.onnx");
+    std::vector<float> result_vector = gen_trilu_test({migraphx::shape::float_type, {3, 4}}, p);
+
+    std::vector<float> gold = {1, 0, 0, 0, 5, 6, 0, 0, 9, 10, 11, 0};
+    EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
+}
+
+TEST_CASE(tril_batch_diff_k_test)
+{
+    migraphx::program p = migraphx::parse_onnx("tril_batch_diff_k_test.onnx");
+    std::vector<float> result_vector = gen_trilu_test({migraphx::shape::float_type, {2, 2, 3}}, p);
+
+    std::vector<float> gold = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12};
+    EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
+}
+
+TEST_CASE(triu_neg_k_test)
+{
+    migraphx::program p = migraphx::parse_onnx("triu_neg_k_test.onnx");
+    std::vector<float> result_vector = gen_trilu_test({migraphx::shape::float_type, {3, 4}}, p);
+
+    std::vector<float> gold = {1, 2, 3, 4, 5, 6, 7, 8, 0, 10, 11, 12};
+    EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
+}
+
+TEST_CASE(tril_neg_k_test)
+{
+    migraphx::program p = migraphx::parse_onnx("tril_neg_k_test.onnx");
     std::vector<float> result_vector = gen_trilu_test({migraphx::shape::float_type, {3, 4}}, p);
 
@@ -2266,9 +2300,9 @@ TEST_CASE(trilu_lower_test)
     EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
 }
 
-TEST_CASE(trilu_out_k_test)
+TEST_CASE(triu_out_k_test)
 {
-    migraphx::program p = migraphx::parse_onnx("trilu_out_k_test.onnx");
+    migraphx::program p = migraphx::parse_onnx("triu_out_k_test.onnx");
     std::vector<float> result_vector = gen_trilu_test({migraphx::shape::float_type, {3, 4}}, p);
 
@@ -2277,9 +2311,20 @@ TEST_CASE(trilu_out_k_test)
     EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
 }
 
-TEST_CASE(trilu_row_one_test)
+TEST_CASE(tril_out_k_test)
+{
+    migraphx::program p = migraphx::parse_onnx("tril_out_k_test.onnx");
+    std::vector<float> result_vector = gen_trilu_test({migraphx::shape::float_type, {3, 4}}, p);
+
+    std::vector<float> gold = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12};
+    EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
+}
+
+TEST_CASE(triu_row_one_test)
 {
-    migraphx::program p = migraphx::parse_onnx("trilu_row_one_test.onnx");
+    migraphx::program p = migraphx::parse_onnx("triu_row_one_test.onnx");
     std::vector<float> result_vector = gen_trilu_test({migraphx::shape::float_type, {1, 4}}, p);
 
@@ -2288,4 +2333,15 @@ TEST_CASE(trilu_row_one_test)
     EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
 }
 
+TEST_CASE(tril_row_one_test)
+{
+    migraphx::program p = migraphx::parse_onnx("tril_row_one_test.onnx");
+    std::vector<float> result_vector = gen_trilu_test({migraphx::shape::float_type, {1, 4}}, p);
+
+    std::vector<float> gold = {1, 2, 0, 0};
+    EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
+}
+
 int main(int argc, const char* argv[]) { test::run(argc, argv); }
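For reference, every gold vector in these cases is just the row-major content of numpy's triu/tril applied to the iota data that gen_trilu_test appears to feed the program (judging by the gold values), so they are easy to regenerate when shapes change (gen_gold is an illustrative helper, not part of the suite):

    import numpy as np

    def gen_gold(dims, k=0, upper=True):
        # iota-filled tensor matching the test input data
        x = np.arange(1, np.prod(dims) + 1, dtype=np.float32).reshape(dims)
        # np.triu/np.tril operate on the last two axes, so batched shapes work
        y = np.triu(x, k) if upper else np.tril(x, k)
        return y.flatten().tolist()

    print(gen_gold((3, 4), k=0, upper=False))     # tril_test: 1,0,0,0,5,6,0,0,9,10,11,0
    print(gen_gold((1, 4), k=1, upper=False))     # tril_row_one_test: 1,2,0,0
    print(gen_gold((2, 2, 3), k=2, upper=False))  # tril_batch_diff_k_test: 1..12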
@@ -590,9 +590,6 @@ def disabled_tests_onnx_1_9_0(backend_test):
     backend_test.exclude(r'test_gru_batchwise_cpu')
     backend_test.exclude(r'test_lstm_batchwise_cpu')
     backend_test.exclude(r'test_simple_rnn_batchwise_cpu')
-    backend_test.exclude(r'test_tril_cpu')
-    backend_test.exclude(r'test_tril_one_row_neg_cpu')
-    backend_test.exclude(r'test_tril_square_cpu')
 
     # from OnnxBackendPyTorchConvertedModelTest
     backend_test.exclude(r'test_MaxPool1d_stride_padding_dilation_cpu')
     backend_test.exclude(r'test_MaxPool2d_stride_padding_dilation_cpu')
...