"test/git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "c7c549cc9a338897b9e2172f13dec563ba628c17"
Unverified Commit b8b4630b authored by nives-vukovic's avatar nives-vukovic Committed by GitHub
Browse files

Fix trilu operator computation logic (#2212)

parent 6072b2c4
......@@ -56,9 +56,6 @@ struct parse_trilu : op_parser<parse_trilu>
k = arg_k.at<int>();
}
if(k < 0)
MIGRAPHX_THROW("PARSE_TRILU: negative k values not supported");
if(contains(info.attributes, "upper"))
{
upper = static_cast<bool>(info.attributes.at("upper").i());
......@@ -69,9 +66,12 @@ struct parse_trilu : op_parser<parse_trilu>
// when creating the mask, if upper == 1,
// the inner triangle will have values set to 0
std::vector<bool> mask_mat(num_rows * num_cols, upper);
// if upper == 0, kth diagonal must also be masked
if(not upper)
k++;
for(size_t i = 0; i < num_rows; i++)
{
for(size_t j = 0; j < std::min(k, static_cast<int>(num_cols)); j++)
for(int j = 0; j < std::min(k, static_cast<int>(num_cols)); j++)
{
mask_mat[i * num_cols + j] = not upper;
}
......
......@@ -8573,7 +8573,7 @@ def transpose_gather_test():
@onnx_test()
def trilu_test():
def triu_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [3, 4])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3, 4])
......@@ -8586,7 +8586,7 @@ def trilu_test():
@onnx_test()
def trilu_batch_diff_k_test():
def triu_batch_diff_k_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [2, 2, 3])
k = np.array([2])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [2, 2, 3])
......@@ -8604,7 +8604,24 @@ def trilu_batch_diff_k_test():
@onnx_test()
def trilu_lower_test():
def tril_batch_diff_k_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [2, 2, 3])
k = np.array([2])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [2, 2, 3])
k_tensor = helper.make_tensor(name='k',
data_type=TensorProto.INT64,
dims=k.shape,
vals=k.astype(np.int64))
node = onnx.helper.make_node('Trilu',
inputs=['x', 'k'],
outputs=['y'],
upper=0)
return ([node], [x], [y], [k_tensor])
@onnx_test()
def tril_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [3, 4])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3, 4])
......@@ -8613,7 +8630,7 @@ def trilu_lower_test():
@onnx_test()
def trilu_neg_k_test():
def triu_neg_k_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [3, 4])
k = np.array([-1])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3, 4])
......@@ -8627,7 +8644,23 @@ def trilu_neg_k_test():
@onnx_test()
def trilu_out_k_test():
def tril_neg_k_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [3, 4])
k = np.array([-1])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3, 4])
k_tensor = helper.make_tensor(name='k',
data_type=TensorProto.INT64,
dims=k.shape,
vals=k.astype(np.int64))
node = onnx.helper.make_node('Trilu',
inputs=['x', 'k'],
outputs=['y'],
upper=0)
return ([node], [x], [y], [k_tensor])
@onnx_test()
def triu_out_k_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [3, 4])
k = np.array([5])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3, 4])
......@@ -8641,7 +8674,23 @@ def trilu_out_k_test():
@onnx_test()
def trilu_row_one_test():
def tril_out_k_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [3, 4])
k = np.array([5])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3, 4])
k_tensor = helper.make_tensor(name='k',
data_type=TensorProto.INT64,
dims=k.shape,
vals=k.astype(np.int64))
node = onnx.helper.make_node('Trilu',
inputs=['x', 'k'],
outputs=['y'],
upper=0)
return ([node], [x], [y], [k_tensor])
@onnx_test()
def triu_row_one_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [1, 4])
k = np.array([1])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [1, 4])
......@@ -8658,6 +8707,23 @@ def trilu_row_one_test():
return ([node], [x], [y], [k_tensor])
@onnx_test()
def tril_row_one_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [1, 4])
k = np.array([1])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [1, 4])
k_tensor = helper.make_tensor(name='k',
data_type=TensorProto.INT64,
dims=k.shape,
vals=k.astype(np.int64))
node = onnx.helper.make_node('Trilu',
inputs=['x', 'k'],
outputs=['y'],
upper=0)
return ([node], [x], [y], [k_tensor])
@onnx_test()
def undefined_test():
x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [2, 3, 4, 5])
......
......@@ -8091,11 +8091,6 @@ TEST_CASE(transpose_gather_test)
EXPECT(p.sort() == prog.sort());
}
TEST_CASE(trilu_neg_k_test)
{
EXPECT(test::throws([&] { migraphx::parse_onnx("trilu_neg_k_test.onnx"); }));
}
TEST_CASE(undefined_test)
{
migraphx::program p;
......
trilu_batch_diff_k_test:i
triu_batch_diff_k_test:h

x
ky"Trilutrilu_batch_diff_k_test*
ky"Trilutriu_batch_diff_k_test*
:BkZ
x

......@@ -12,4 +12,4 @@



B
\ No newline at end of file
B
\ No newline at end of file
trilu_neg_k_test:c
triu_neg_k_test:b

x
ky"Trilutrilu_neg_k_test*:
ky"Trilutriu_neg_k_test*:
BkZ
x

......@@ -10,4 +10,4 @@
y


B
\ No newline at end of file
B
\ No newline at end of file
trilu_out_k_test:Z
triu_out_k_test:Y

x
ky"Trilutrilu_out_k_test*
ky"Trilutriu_out_k_test*
:BkZ
x

......@@ -10,4 +10,4 @@
y


B
\ No newline at end of file
B
\ No newline at end of file
trilu_row_one_test:\
triu_row_one_test:[

x
ky"Trilutrilu_row_one_test*
ky"Trilutriu_row_one_test*
:BkZ
x

......@@ -10,4 +10,4 @@
y


B
\ No newline at end of file
B
\ No newline at end of file

trilu_test:E
 triu_test:D
xy"Trilu
trilu_testZ
xy"Trilu triu_testZ
x


......@@ -10,4 +8,4 @@ trilu_testZ
y


B
\ No newline at end of file
B
\ No newline at end of file
......@@ -2233,9 +2233,10 @@ std::vector<float> gen_trilu_test(const migraphx::shape& s, const migraphx::prog
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
return result_vector;
}
TEST_CASE(trilu_test)
TEST_CASE(triu_test)
{
migraphx::program p = migraphx::parse_onnx("trilu_test.onnx");
migraphx::program p = migraphx::parse_onnx("triu_test.onnx");
std::vector<float> result_vector = gen_trilu_test({migraphx::shape::float_type, {3, 4}}, p);
......@@ -2244,9 +2245,9 @@ TEST_CASE(trilu_test)
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
TEST_CASE(trilu_batch_diff_k_test)
TEST_CASE(triu_batch_diff_k_test)
{
migraphx::program p = migraphx::parse_onnx("trilu_batch_diff_k_test.onnx");
migraphx::program p = migraphx::parse_onnx("triu_batch_diff_k_test.onnx");
std::vector<float> result_vector = gen_trilu_test({migraphx::shape::float_type, {2, 2, 3}}, p);
......@@ -2255,9 +2256,42 @@ TEST_CASE(trilu_batch_diff_k_test)
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
TEST_CASE(trilu_lower_test)
TEST_CASE(tril_test)
{
migraphx::program p = migraphx::parse_onnx("trilu_lower_test.onnx");
migraphx::program p = migraphx::parse_onnx("tril_test.onnx");
std::vector<float> result_vector = gen_trilu_test({migraphx::shape::float_type, {3, 4}}, p);
std::vector<float> gold = {1, 0, 0, 0, 5, 6, 0, 0, 9, 10, 11, 0};
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
TEST_CASE(tril_batch_diff_k_test)
{
migraphx::program p = migraphx::parse_onnx("tril_batch_diff_k_test.onnx");
std::vector<float> result_vector = gen_trilu_test({migraphx::shape::float_type, {2, 2, 3}}, p);
std::vector<float> gold = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12};
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
TEST_CASE(triu_neg_k_test)
{
migraphx::program p = migraphx::parse_onnx("triu_neg_k_test.onnx");
std::vector<float> result_vector = gen_trilu_test({migraphx::shape::float_type, {3, 4}}, p);
std::vector<float> gold = {1, 2, 3, 4, 5, 6, 7, 8, 0, 10, 11, 12};
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
TEST_CASE(tril_neg_k_test)
{
migraphx::program p = migraphx::parse_onnx("tril_neg_k_test.onnx");
std::vector<float> result_vector = gen_trilu_test({migraphx::shape::float_type, {3, 4}}, p);
......@@ -2266,9 +2300,9 @@ TEST_CASE(trilu_lower_test)
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
TEST_CASE(trilu_out_k_test)
TEST_CASE(triu_out_k_test)
{
migraphx::program p = migraphx::parse_onnx("trilu_out_k_test.onnx");
migraphx::program p = migraphx::parse_onnx("triu_out_k_test.onnx");
std::vector<float> result_vector = gen_trilu_test({migraphx::shape::float_type, {3, 4}}, p);
......@@ -2277,9 +2311,20 @@ TEST_CASE(trilu_out_k_test)
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
TEST_CASE(trilu_row_one_test)
TEST_CASE(tril_out_k_test)
{
migraphx::program p = migraphx::parse_onnx("tril_out_k_test.onnx");
std::vector<float> result_vector = gen_trilu_test({migraphx::shape::float_type, {3, 4}}, p);
std::vector<float> gold = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12};
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
TEST_CASE(triu_row_one_test)
{
migraphx::program p = migraphx::parse_onnx("trilu_row_one_test.onnx");
migraphx::program p = migraphx::parse_onnx("triu_row_one_test.onnx");
std::vector<float> result_vector = gen_trilu_test({migraphx::shape::float_type, {1, 4}}, p);
......@@ -2288,4 +2333,15 @@ TEST_CASE(trilu_row_one_test)
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
TEST_CASE(tril_row_one_test)
{
migraphx::program p = migraphx::parse_onnx("tril_row_one_test.onnx");
std::vector<float> result_vector = gen_trilu_test({migraphx::shape::float_type, {1, 4}}, p);
std::vector<float> gold = {1, 2, 0, 0};
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
......@@ -590,9 +590,6 @@ def disabled_tests_onnx_1_9_0(backend_test):
backend_test.exclude(r'test_gru_batchwise_cpu')
backend_test.exclude(r'test_lstm_batchwise_cpu')
backend_test.exclude(r'test_simple_rnn_batchwise_cpu')
backend_test.exclude(r'test_tril_cpu')
backend_test.exclude(r'test_tril_one_row_neg_cpu')
backend_test.exclude(r'test_tril_square_cpu')
# from OnnxBackendPyTorchConvertedModelTest
backend_test.exclude(r'test_MaxPool1d_stride_padding_dilation_cpu')
backend_test.exclude(r'test_MaxPool2d_stride_padding_dilation_cpu')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment