Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
439f96bc
Commit
439f96bc
authored
Feb 15, 2023
by
charlie
Browse files
Bracket change
parent
c9497134
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
814 additions
and
818 deletions
+814
-818
test/ref_ops_test.cpp
test/ref_ops_test.cpp
+814
-818
No files found.
test/ref_ops_test.cpp
View file @
439f96bc
...
...
@@ -7413,9 +7413,10 @@ TEST_CASE(select_module_reduce_test1)
std
::
vector
<
float
>
results_vector
;
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
gold
{
-
5
,
12
,
7
,
4
,
-
5
,
12
,
7
,
4
};
}
TEST_CASE(scatternd_reduction_dyn_test)
{
TEST_CASE
(
scatternd_reduction_dyn_test
)
{
// reduction = add, with dynamic input shapes
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
...
...
@@ -7437,10 +7438,9 @@ TEST_CASE(select_module_reduce_test1)
migraphx
::
parameter_map
params
;
migraphx
::
shape
input_fixed_shape0
{
migraphx
::
shape
::
float_type
,
{
4
,
4
,
4
}};
// data
std::vector<float> input_data{1, 2, 3, 4, 5, 6, 7, 8, 8, 7, 6, 5, 4, 3, 2, 1,
1, 2, 3, 4, 5, 6, 7, 8, 8, 7, 6, 5, 4, 3, 2, 1,
8, 7, 6, 5, 4, 3, 2, 1, 1, 2, 3, 4, 5, 6, 7, 8,
8, 7, 6, 5, 4, 3, 2, 1, 1, 2, 3, 4, 5, 6, 7, 8};
std
::
vector
<
float
>
input_data
{
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
8
,
7
,
6
,
5
,
4
,
3
,
2
,
1
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
8
,
7
,
6
,
5
,
4
,
3
,
2
,
1
,
8
,
7
,
6
,
5
,
4
,
3
,
2
,
1
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
8
,
7
,
6
,
5
,
4
,
3
,
2
,
1
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
};
std
::
vector
<
uint64_t
>
input_index
{
0
,
2
};
migraphx
::
shape
input_fixed_shape1
{
migraphx
::
shape
::
float_type
,
{
2
,
4
,
4
}};
// updates
std
::
vector
<
float
>
input_updates
{
5
,
5
,
5
,
5
,
6
,
6
,
6
,
6
,
7
,
7
,
7
,
7
,
8
,
8
,
8
,
8
,
...
...
@@ -7458,10 +7458,10 @@ TEST_CASE(select_module_reduce_test1)
9
,
8
,
7
,
6
,
6
,
5
,
4
,
3
,
4
,
5
,
6
,
7
,
9
,
10
,
11
,
12
,
8
,
7
,
6
,
5
,
4
,
3
,
2
,
1
,
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
};
EXPECT
(
migraphx
::
verify_range
(
results_vector
,
gold
));
}
}
TEST_CASE(sigmoid_test)
{
TEST_CASE
(
sigmoid_test
)
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{
2
,
2
}};
...
...
@@ -7473,10 +7473,10 @@ TEST_CASE(select_module_reduce_test1)
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
gold
{
sigmoid
(
-
1
),
sigmoid
(
2
),
sigmoid
(
-
3
),
sigmoid
(
4
)};
EXPECT
(
migraphx
::
verify_range
(
results_vector
,
gold
));
}
}
TEST_CASE(sigmoid_dyn_test)
{
TEST_CASE
(
sigmoid_dyn_test
)
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{{
2
,
4
,
0
},
{
2
,
2
,
0
}}};
...
...
@@ -7493,10 +7493,10 @@ TEST_CASE(select_module_reduce_test1)
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
gold
{
sigmoid
(
-
1
),
sigmoid
(
2
),
sigmoid
(
-
3
),
sigmoid
(
4
)};
EXPECT
(
migraphx
::
verify_range
(
results_vector
,
gold
));
}
}
TEST_CASE(sign_test)
{
TEST_CASE
(
sign_test
)
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{
5
}};
...
...
@@ -7509,10 +7509,10 @@ TEST_CASE(select_module_reduce_test1)
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
gold
=
{
1.0
,
1.0
,
-
1.0
,
-
1.0
,
0.0
};
EXPECT
(
migraphx
::
verify_range
(
results_vector
,
gold
));
}
}
TEST_CASE(sign_dyn_test)
{
TEST_CASE
(
sign_dyn_test
)
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
::
dynamic_dimension
dd
{
3
,
8
,
0
};
...
...
@@ -7530,10 +7530,10 @@ TEST_CASE(select_module_reduce_test1)
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
gold
=
{
1.0
,
1.0
,
-
1.0
,
-
1.0
,
0.0
};
EXPECT
(
migraphx
::
verify_range
(
results_vector
,
gold
));
}
}
TEST_CASE(sin_test)
{
TEST_CASE
(
sin_test
)
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{
3
}};
...
...
@@ -7548,10 +7548,10 @@ TEST_CASE(select_module_reduce_test1)
std
::
transform
(
gold
.
begin
(),
gold
.
end
(),
gold
.
begin
(),
[](
float
n
)
->
float
{
return
sinf
(
n
);
});
EXPECT
(
migraphx
::
verify_range
(
results_vector
,
gold
));
}
}
TEST_CASE(sin_dyn_test)
{
TEST_CASE
(
sin_dyn_test
)
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
::
dynamic_dimension
dd
{
3
,
8
,
0
};
...
...
@@ -7571,10 +7571,10 @@ TEST_CASE(select_module_reduce_test1)
std
::
transform
(
gold
.
begin
(),
gold
.
end
(),
gold
.
begin
(),
[](
float
n
)
->
float
{
return
sinf
(
n
);
});
EXPECT
(
migraphx
::
verify_range
(
results_vector
,
gold
));
}
}
TEST_CASE(sinh_test)
{
TEST_CASE
(
sinh_test
)
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{
2
,
2
}};
...
...
@@ -7589,10 +7589,10 @@ TEST_CASE(select_module_reduce_test1)
std
::
transform
(
gold
.
begin
(),
gold
.
end
(),
gold
.
begin
(),
[](
float
n
)
->
float
{
return
sinhf
(
n
);
});
EXPECT
(
migraphx
::
verify_range
(
results_vector
,
gold
));
}
}
TEST_CASE(sinh_dynamic_test)
{
TEST_CASE
(
sinh_dynamic_test
)
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{{
2
,
4
,
0
},
{
2
,
4
,
0
}}};
...
...
@@ -7611,10 +7611,10 @@ TEST_CASE(select_module_reduce_test1)
std
::
transform
(
gold
.
begin
(),
gold
.
end
(),
gold
.
begin
(),
[](
float
n
)
->
float
{
return
sinhf
(
n
);
});
EXPECT
(
migraphx
::
verify_range
(
results_vector
,
gold
));
}
}
TEST_CASE(slice_test)
{
TEST_CASE
(
slice_test
)
{
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
...
...
@@ -7643,8 +7643,8 @@ TEST_CASE(select_module_reduce_test1)
migraphx
::
shape
s
{
migraphx
::
shape
::
int32_type
,
{
2
,
2
,
3
}};
auto
l0
=
mm
->
add_literal
(
migraphx
::
literal
{
s
,
data
});
mm
->
add_instruction
(
migraphx::make_op(
"slice",
{{"axes", {0, 1, 2}}, {"starts", {0, 0, 0}}, {"ends", {2, 2, 2}}}),
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
0
,
1
,
2
}},
{
"starts"
,
{
0
,
0
,
0
}},
{
"ends"
,
{
2
,
2
,
2
}}}),
l0
);
migraphx
::
shape
s2
{
migraphx
::
shape
::
int32_type
,
{
2
,
2
,
2
},
{
6
,
3
,
1
}};
EXPECT
(
p
.
get_output_shapes
().
back
()
==
s2
);
...
...
@@ -7657,10 +7657,10 @@ TEST_CASE(select_module_reduce_test1)
EXPECT
(
migraphx
::
verify_range
(
results_vector
,
gold
));
EXPECT
(
result
.
get_shape
()
==
sresult
);
}
}
}
TEST_CASE(softmax_simple_test)
{
TEST_CASE
(
softmax_simple_test
)
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
std
::
vector
<
float
>
a
=
{
0.25
,
0.75
};
...
...
@@ -7673,10 +7673,10 @@ TEST_CASE(select_module_reduce_test1)
std
::
vector
<
float
>
results_vector
(
2
);
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify_range
(
results_vector
,
s
));
}
}
TEST_CASE(softmax_test)
{
TEST_CASE
(
softmax_test
)
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
std
::
vector
<
float
>
a
=
{
...
...
@@ -7733,10 +7733,10 @@ TEST_CASE(select_module_reduce_test1)
std
::
vector
<
float
>
results_vector
(
120
);
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify_range
(
results_vector
,
s
));
}
}
TEST_CASE(softmax_dyn_test)
{
TEST_CASE
(
softmax_dyn_test
)
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
a_shape
{
migraphx
::
shape
::
float_type
,
...
...
@@ -7796,10 +7796,10 @@ TEST_CASE(select_module_reduce_test1)
0.17377149
,
0.76075399
,
0.20071237
,
0.32632929
,
0.36892858
,
0.09416146
,
0.26656723
,
0.42914796
};
EXPECT
(
migraphx
::
verify_range
(
results_vector
,
s
));
}
}
TEST_CASE(sqdiff_test)
{
TEST_CASE
(
sqdiff_test
)
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{
3
}};
...
...
@@ -7812,10 +7812,10 @@ TEST_CASE(select_module_reduce_test1)
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
gold
=
{
4
,
4
,
4
};
EXPECT
(
migraphx
::
verify_range
(
results_vector
,
gold
));
}
}
TEST_CASE(sqdiff_dyn_test)
{
TEST_CASE
(
sqdiff_dyn_test
)
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
std
::
vector
<
migraphx
::
shape
::
dynamic_dimension
>
dd
{{
2
,
6
,
0
}};
...
...
@@ -7836,10 +7836,10 @@ TEST_CASE(select_module_reduce_test1)
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
gold
=
{
4
,
4
,
4
};
EXPECT
(
migraphx
::
verify_range
(
results_vector
,
gold
));
}
}
TEST_CASE(sqrt_test)
{
TEST_CASE
(
sqrt_test
)
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{
5
}};
...
...
@@ -7854,10 +7854,10 @@ TEST_CASE(select_module_reduce_test1)
std
::
transform
(
gold
.
begin
(),
gold
.
end
(),
gold
.
begin
(),
[](
float
n
)
->
float
{
return
sqrtf
(
n
);
});
EXPECT
(
migraphx
::
verify_range
(
results_vector
,
gold
));
}
}
TEST_CASE(sqrt_dynamic_test)
{
TEST_CASE
(
sqrt_dynamic_test
)
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
::
dynamic_dimension
dd
{
3
,
8
,
0
};
...
...
@@ -7877,10 +7877,10 @@ TEST_CASE(select_module_reduce_test1)
std
::
transform
(
gold
.
begin
(),
gold
.
end
(),
gold
.
begin
(),
[](
float
n
)
->
float
{
return
sqrtf
(
n
);
});
EXPECT
(
migraphx
::
verify_range
(
results_vector
,
gold
));
}
}
TEST_CASE(squeeze_test)
{
TEST_CASE
(
squeeze_test
)
{
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
...
...
@@ -7918,10 +7918,10 @@ TEST_CASE(select_module_reduce_test1)
auto
result
=
p
.
eval
({}).
back
();
EXPECT
(
result
.
get_shape
()
==
s2
);
}
}
}
TEST_CASE(squeeze_dyn_test)
{
TEST_CASE
(
squeeze_dyn_test
)
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
s1
{
migraphx
::
shape
::
float_type
,
...
...
@@ -7937,10 +7937,10 @@ TEST_CASE(select_module_reduce_test1)
auto
result
=
p
.
eval
(
params0
).
back
();
migraphx
::
shape
s2
{
migraphx
::
shape
::
float_type
,
{
4
,
3
,
1
,
3
}};
EXPECT
(
result
.
get_shape
()
==
s2
);
}
}
TEST_CASE(step_test)
{
TEST_CASE
(
step_test
)
{
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
...
...
@@ -7974,10 +7974,10 @@ TEST_CASE(select_module_reduce_test1)
migraphx
::
shape
s2
{
migraphx
::
shape
::
float_type
,
{
1
,
2
,
2
,
1
}};
EXPECT
(
result
.
get_shape
()
==
s2
);
}
}
}
TEST_CASE(sub_test)
{
TEST_CASE
(
sub_test
)
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{
3
}};
...
...
@@ -7990,10 +7990,10 @@ TEST_CASE(select_module_reduce_test1)
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
gold
=
{
-
2
,
-
2
,
-
2
};
EXPECT
(
migraphx
::
verify_range
(
results_vector
,
gold
));
}
}
TEST_CASE(sub_dyn_test)
{
TEST_CASE
(
sub_dyn_test
)
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
std
::
vector
<
migraphx
::
shape
::
dynamic_dimension
>
dd
{{
2
,
6
,
0
}};
...
...
@@ -8014,10 +8014,10 @@ TEST_CASE(select_module_reduce_test1)
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
gold
=
{
-
2
,
-
2
,
-
2
};
EXPECT
(
migraphx
::
verify_range
(
results_vector
,
gold
));
}
}
TEST_CASE(tan_test)
{
TEST_CASE
(
tan_test
)
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{
3
}};
...
...
@@ -8032,10 +8032,10 @@ TEST_CASE(select_module_reduce_test1)
std
::
transform
(
gold
.
begin
(),
gold
.
end
(),
gold
.
begin
(),
[](
float
n
)
->
float
{
return
tanf
(
n
);
});
EXPECT
(
migraphx
::
verify_range
(
results_vector
,
gold
));
}
}
TEST_CASE(tan_dynamic_test)
{
TEST_CASE
(
tan_dynamic_test
)
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
::
dynamic_dimension
dd
{
3
,
8
,
0
};
...
...
@@ -8055,10 +8055,10 @@ TEST_CASE(select_module_reduce_test1)
std
::
transform
(
gold
.
begin
(),
gold
.
end
(),
gold
.
begin
(),
[](
float
n
)
->
float
{
return
tanf
(
n
);
});
EXPECT
(
migraphx
::
verify_range
(
results_vector
,
gold
));
}
}
TEST_CASE(tanh_test)
{
TEST_CASE
(
tanh_test
)
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{
2
,
2
}};
...
...
@@ -8073,10 +8073,10 @@ TEST_CASE(select_module_reduce_test1)
std
::
transform
(
gold
.
begin
(),
gold
.
end
(),
gold
.
begin
(),
[](
float
n
)
->
float
{
return
tanhf
(
n
);
});
EXPECT
(
migraphx
::
verify_range
(
results_vector
,
gold
));
}
}
TEST_CASE(tanh_dynamic_test)
{
TEST_CASE
(
tanh_dynamic_test
)
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
::
dynamic_dimension
dd
{
3
,
8
,
0
};
...
...
@@ -8096,10 +8096,10 @@ TEST_CASE(select_module_reduce_test1)
std
::
transform
(
gold
.
begin
(),
gold
.
end
(),
gold
.
begin
(),
[](
float
n
)
->
float
{
return
tanhf
(
n
);
});
EXPECT
(
migraphx
::
verify_range
(
results_vector
,
gold
));
}
}
TEST_CASE(topk_test)
{
TEST_CASE
(
topk_test
)
{
auto
create_program
=
[](
int64_t
k
,
int64_t
axis
,
int
largest
)
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
...
...
@@ -8148,10 +8148,10 @@ TEST_CASE(select_module_reduce_test1)
std
::
vector
<
int64_t
>
gold_ind
=
{
4
,
2
,
0
,
1
,
3
,
1
,
4
,
0
,
3
,
0
,
4
,
2
};
EXPECT
(
results
.
second
==
gold_ind
);
}
}
}
TEST_CASE(transpose_test)
{
TEST_CASE
(
transpose_test
)
{
migraphx
::
shape
a_shape
{
migraphx
::
shape
::
float_type
,
{
1
,
2
,
2
,
3
}};
std
::
vector
<
float
>
data
(
12
);
std
::
iota
(
data
.
begin
(),
data
.
end
(),
0
);
...
...
@@ -8177,19 +8177,17 @@ TEST_CASE(select_module_reduce_test1)
auto
result2
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
results_vector
(
12
);
result2.visit(
[&](auto output) { results_vector.assign(output.begin(), output.end()); });
result2
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
gold
=
{
0
,
3
,
6
,
9
,
1
,
4
,
7
,
10
,
2
,
5
,
8
,
11
};
EXPECT
(
migraphx
::
verify_range
(
results_vector
,
gold
));
}
}
}
TEST_CASE(transpose_dyn_test)
{
TEST_CASE
(
transpose_dyn_test
)
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx::shape s{migraphx::shape::float_type,
{{1, 4, 0}, {2, 2, 0}, {2, 2, 0}, {3, 3, 0}}};
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{{
1
,
4
,
0
},
{
2
,
2
,
0
},
{
2
,
2
,
0
},
{
3
,
3
,
0
}}};
auto
l
=
mm
->
add_parameter
(
"X"
,
s
);
std
::
vector
<
int64_t
>
perm
=
{
0
,
3
,
1
,
2
};
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
perm
}}),
l
);
...
...
@@ -8209,10 +8207,10 @@ TEST_CASE(select_module_reduce_test1)
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
gold
=
{
0
,
3
,
6
,
9
,
1
,
4
,
7
,
10
,
2
,
5
,
8
,
11
};
EXPECT
(
migraphx
::
verify_range
(
results_vector
,
gold
));
}
}
TEST_CASE(unsqueeze_test)
{
TEST_CASE
(
unsqueeze_test
)
{
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
...
...
@@ -8237,10 +8235,10 @@ TEST_CASE(select_module_reduce_test1)
auto
result
=
p
.
eval
({}).
back
();
EXPECT
(
result
.
get_shape
()
==
s2
);
}
}
}
TEST_CASE(unsqueeze_dyn_test)
{
TEST_CASE
(
unsqueeze_dyn_test
)
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
...
...
@@ -8256,10 +8254,10 @@ TEST_CASE(select_module_reduce_test1)
auto
result
=
p
.
eval
(
params0
).
back
();
migraphx
::
shape
s2
{
migraphx
::
shape
::
float_type
,
{
4
,
1
,
3
,
3
}};
EXPECT
(
result
.
get_shape
()
==
s2
);
}
}
TEST_CASE(where_test)
{
TEST_CASE
(
where_test
)
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
sb
{
migraphx
::
shape
::
bool_type
,
{
3
,
3
}};
...
...
@@ -8283,10 +8281,10 @@ TEST_CASE(select_module_reduce_test1)
gold
[
i
]
=
b
[
i
]
?
x
[
i
]
:
y
[
i
];
EXPECT
(
migraphx
::
verify_range
(
result_vec
,
gold
));
}
}
TEST_CASE(where_dyn_test)
{
TEST_CASE
(
where_dyn_test
)
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
sb
{
migraphx
::
shape
::
bool_type
,
{{
2
,
3
,
0
},
{
2
,
3
,
0
}}};
...
...
@@ -8314,10 +8312,10 @@ TEST_CASE(select_module_reduce_test1)
result
.
visit
([
&
](
auto
output
)
{
results_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
gold
{
1
,
1
,
1
,
2
,
2
,
2
,
1
,
2
,
1
};
EXPECT
(
migraphx
::
verify_range
(
results_vector
,
gold
));
}
}
TEST_CASE(where_broadcasted_inputs_test)
{
TEST_CASE
(
where_broadcasted_inputs_test
)
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
sb
{
migraphx
::
shape
::
bool_type
,
{
3
,
3
}};
...
...
@@ -8327,10 +8325,8 @@ TEST_CASE(select_module_reduce_test1)
auto
lb
=
mm
->
add_literal
(
migraphx
::
literal
{
sb
,
b
});
auto
lx
=
mm
->
add_literal
(
migraphx
::
literal
(
1.0
f
));
auto
ly
=
mm
->
add_literal
(
migraphx
::
literal
(
2.0
f
));
auto mbx =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3, 3}}}), lx);
auto mby =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3, 3}}}), ly);
auto
mbx
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
3
,
3
}}}),
lx
);
auto
mby
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
3
,
3
}}}),
ly
);
auto
w
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"where"
),
lb
,
mbx
,
mby
);
mm
->
add_return
({
w
});
p
.
compile
(
migraphx
::
ref
::
target
{});
...
...
@@ -8344,6 +8340,6 @@ TEST_CASE(select_module_reduce_test1)
gold
[
i
]
=
b
[
i
]
?
x
[
i
]
:
y
[
i
];
EXPECT
(
migraphx
::
verify_range
(
result_vec
,
gold
));
}
}
int main(int argc, const char* argv[]) { test::run(argc, argv); }
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment