Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
095e49a3
Commit
095e49a3
authored
May 20, 2022
by
turneram
Browse files
Add transposectx and transposeqkv ref tests
parent
7757cfd0
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
61 additions
and
0 deletions
+61
-0
test/ref_ops_test.cpp
test/ref_ops_test.cpp
+61
-0
No files found.
test/ref_ops_test.cpp
View file @
095e49a3
...
...
@@ -666,6 +666,67 @@ TEST_CASE(batch_norm_inference_test)
EXPECT
(
migraphx
::
verify_range
(
result_vector
,
gold
));
}
TEST_CASE
(
bert_transpose_ops_test
)
{
{
// transposeQKV
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
std
::
vector
<
int
>
bsknh
{
2
,
384
,
3
,
12
,
64
};
const
int
elements
=
std
::
accumulate
(
bsknh
.
begin
(),
bsknh
.
end
(),
1
,
std
::
multiplies
<
int
>
());
migraphx
::
shape
sh
{
migraphx
::
shape
::
float_type
,
bsknh
};
std
::
vector
<
float
>
data
(
elements
);
std
::
iota
(
data
.
begin
(),
data
.
end
(),
0
);
auto
l1
=
mm
->
add_literal
(
migraphx
::
literal
{
sh
,
data
});
mm
->
add_instruction
(
migraphx
::
make_op
(
"transposeqkv"
),
l1
);
p
.
compile
(
migraphx
::
ref
::
target
{});
auto
result
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
result_vector
(
elements
);
result
.
visit
([
&
](
auto
output
)
{
result_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
gold
(
elements
);
migraphx
::
program
p2
;
auto
*
mm2
=
p2
.
get_main_module
();
auto
l2
=
mm2
->
add_literal
(
migraphx
::
literal
{
sh
,
data
});
// BSKNH->KBNSH : perm=2,0,3,1,4
mm2
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
2
,
0
,
3
,
1
,
4
}}}),
l2
);
p2
.
compile
(
migraphx
::
ref
::
target
{});
auto
result2
=
p2
.
eval
({}).
back
();
result2
.
visit
([
&
](
auto
output
)
{
gold
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify_range
(
result_vector
,
gold
));
}
{
// transposeCtx
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
std
::
vector
<
int
>
bnsh
{
2
,
12
,
384
,
64
};
const
int
elements
=
std
::
accumulate
(
bnsh
.
begin
(),
bnsh
.
end
(),
1
,
std
::
multiplies
<
int
>
());
migraphx
::
shape
sh
{
migraphx
::
shape
::
float_type
,
bnsh
};
std
::
vector
<
float
>
data
(
elements
);
std
::
iota
(
data
.
begin
(),
data
.
end
(),
0
);
auto
l1
=
mm
->
add_literal
(
migraphx
::
literal
{
sh
,
data
});
mm
->
add_instruction
(
migraphx
::
make_op
(
"transposectx"
),
l1
);
p
.
compile
(
migraphx
::
ref
::
target
{});
auto
result
=
p
.
eval
({}).
back
();
std
::
vector
<
float
>
result_vector
(
elements
);
result
.
visit
([
&
](
auto
output
)
{
result_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
std
::
vector
<
float
>
gold
(
elements
);
migraphx
::
program
p2
;
auto
*
mm2
=
p2
.
get_main_module
();
auto
l2
=
mm2
->
add_literal
(
migraphx
::
literal
{
sh
,
data
});
// BNSH->BSNH : perm=0,2,1,3
mm2
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
0
,
2
,
1
,
3
}}}),
l2
);
p2
.
compile
(
migraphx
::
ref
::
target
{});
auto
result2
=
p2
.
eval
({}).
back
();
result2
.
visit
([
&
](
auto
output
)
{
gold
.
assign
(
output
.
begin
(),
output
.
end
());
});
EXPECT
(
migraphx
::
verify_range
(
result_vector
,
gold
));
}
}
TEST_CASE
(
broadcast_test
)
{
migraphx
::
program
p
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment