Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
1974671d
Commit
1974671d
authored
May 17, 2022
by
turneram
Browse files
Formatting
parent
fe9a42f1
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
48 additions
and
46 deletions
+48
-46
src/include/migraphx/op/transposectx.hpp
src/include/migraphx/op/transposectx.hpp
+5
-5
src/include/migraphx/op/transposeqkv.hpp
src/include/migraphx/op/transposeqkv.hpp
+7
-6
src/onnx/parse_attention.cpp
src/onnx/parse_attention.cpp
+4
-8
src/targets/gpu/jit/bert_transpose.cpp
src/targets/gpu/jit/bert_transpose.cpp
+4
-2
src/targets/gpu/kernels/include/migraphx/kernels/transposectx.hpp
...ets/gpu/kernels/include/migraphx/kernels/transposectx.hpp
+8
-8
src/targets/gpu/kernels/include/migraphx/kernels/transposeqkv.hpp
...ets/gpu/kernels/include/migraphx/kernels/transposeqkv.hpp
+8
-8
test/ref_ops_test.cpp
test/ref_ops_test.cpp
+12
-9
No files found.
src/include/migraphx/op/transposectx.hpp
View file @
1974671d
...
...
@@ -52,7 +52,7 @@ struct transposectx
const
int
NH
=
num_heads
*
head_size
;
const
int
NHS
=
NH
*
sequence_length
;
//const int in_offset = s * head_size + n * sequence_length * head_size + b * NHS;
//
const int in_offset = s * head_size + n * sequence_length * head_size + b * NHS;
const
int
out_offset
=
n
*
head_size
+
s
*
NH
+
b
*
NHS
;
...
...
src/include/migraphx/op/transposeqkv.hpp
View file @
1974671d
...
...
@@ -58,7 +58,8 @@ struct transposeqkv
const
int
NH
=
num_heads
*
H
;
const
int
NHS
=
NH
*
sequence_length
;
const
int
out_offset
=
s
*
H
+
n
*
sequence_length
*
H
+
b
*
NHS
+
m
*
NHS
*
batch_size
;
const
int
out_offset
=
s
*
H
+
n
*
sequence_length
*
H
+
b
*
NHS
+
m
*
NHS
*
batch_size
;
output
[
out_offset
+
j
]
=
input
[
i
];
});
...
...
src/onnx/parse_attention.cpp
View file @
1974671d
...
...
@@ -79,10 +79,7 @@ struct parse_attention : op_parser<parse_attention>
auto
ones
=
info
.
add_literal
(
migraphx
::
literal
{
migraphx
::
shape
{
bias_type
,
ones_lens
},
ones_vec
});
bias
=
info
.
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
{
n
,
1
}}}),
bias
);
auto
gemm_1
=
info
.
add_instruction
(
migraphx
::
make_op
(
"dot"
),
bias
,
ones
);
auto
gemm_1
=
info
.
add_instruction
(
migraphx
::
make_op
(
"dot"
),
bias
,
ones
);
gemm_1
=
info
.
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
1
,
0
}}}),
gemm_1
);
...
...
@@ -99,8 +96,7 @@ struct parse_attention : op_parser<parse_attention>
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
{
batch_size
,
sequence_length
,
3
,
num_heads
,
head_size
}}}),
add_gemms
);
auto
transqkv
=
info
.
add_instruction
(
migraphx
::
make_op
(
"transposeqkv"
),
add_gemms
);
auto
transqkv
=
info
.
add_instruction
(
migraphx
::
make_op
(
"transposeqkv"
),
add_gemms
);
// transqkv has shape 3xBxNxSxH
// => Q, K, V: each has size BxNxSxH
...
...
src/targets/gpu/jit/bert_transpose.cpp
View file @
1974671d
...
...
@@ -40,7 +40,8 @@ struct transposectx_compiler : compiler<transposectx_compiler>
operation
compile_op
(
context
&
ctx
,
const
std
::
vector
<
shape
>&
inputs
,
const
value
&
v
)
const
{
hip_compile_options
options
;
options
.
set_launch_params
(
v
,
compute_global_for
(
ctx
,
inputs
.
back
().
elements
()),
inputs
.
front
().
lens
().
back
());
options
.
set_launch_params
(
v
,
compute_global_for
(
ctx
,
inputs
.
back
().
elements
()),
inputs
.
front
().
lens
().
back
());
options
.
output
=
inputs
.
back
();
options
.
inputs
=
inputs
;
options
.
kernel_name
=
"transposectx_kernel"
;
...
...
@@ -78,7 +79,8 @@ struct transposeqkv_compiler : compiler<transposeqkv_compiler>
operation
compile_op
(
context
&
ctx
,
const
std
::
vector
<
shape
>&
inputs
,
const
value
&
v
)
const
{
hip_compile_options
options
;
options
.
set_launch_params
(
v
,
compute_global_for
(
ctx
,
inputs
.
back
().
elements
()),
inputs
.
front
().
lens
().
back
());
options
.
set_launch_params
(
v
,
compute_global_for
(
ctx
,
inputs
.
back
().
elements
()),
inputs
.
front
().
lens
().
back
());
options
.
output
=
inputs
.
back
();
options
.
inputs
=
inputs
;
options
.
kernel_name
=
"transposeqkv_kernel"
;
...
...
src/targets/gpu/kernels/include/migraphx/kernels/transposectx.hpp
View file @
1974671d
...
...
@@ -28,7 +28,7 @@ __device__ void transposectx(const T& input_t, const U& output_t)
const
int
NHS
=
NH
*
sequence_length
;
const
int
out_offset
=
n
*
head_size
+
s
*
NH
+
b
*
NHS
;
if
(
index
.
local
<
1024
)
if
(
index
.
local
<
1024
)
output_t
[
out_offset
+
idx
[
3
]]
=
input_t
[
index
.
global
];
}
...
...
src/targets/gpu/kernels/include/migraphx/kernels/transposeqkv.hpp
View file @
1974671d
...
...
@@ -23,7 +23,7 @@ __device__ void transposeqkv(const T& input_t, const U& output_t)
const
int
s
=
idx
[
1
];
const
int
m
=
idx
[
2
];
const
int
n
=
idx
[
3
];
//const int j = idx[4];
//
const int j = idx[4];
const
int
num_heads
=
lens
[
3
];
const
int
sequence_length
=
lens
[
1
];
...
...
test/ref_ops_test.cpp
View file @
1974671d
...
...
@@ -687,13 +687,14 @@ TEST_CASE(bert_transpose_ops_test)
migraphx::program p2;
auto* mm2 = p2.get_main_module();
auto l2 = mm2->add_literal(migraphx::literal{sh, data});
mm2->add_instruction(migraphx::make_op("transpose", {{"permutation", {2, 0, 3, 1, 4}}}), l2);
mm2->add_instruction(migraphx::make_op("transpose", {{"permutation", {2, 0, 3, 1, 4}}}),
l2);
p2.compile(migraphx::ref::target{});
auto result2 = p2.eval({}).back();
std::vector<float> result_vector2(k * b * n * s * h);
result2.visit([&](auto output) { result_vector2.assign(output.begin(), output.end()); });
for
(auto& i : result_vector2)
for(auto& i : result_vector2)
std::cout << i << ", ";
std::cout << std::endl;
...
...
@@ -730,7 +731,9 @@ TEST_CASE(bert_transpose_ops_test)
auto result = p.eval({}).back();
std::vector<float> result_vector(b * n * s * h);
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
std::vector<float> gold{0, 1, 2, 3, 12, 13, 14, 15, 4, 5, 6, 7, 16, 17, 18, 19, 8, 9, 10, 11, 20, 21, 22, 23, 24, 25, 26, 27, 36, 37, 38, 39, 28, 29, 30, 31, 40, 41, 42, 43, 32, 33, 34, 35, 44, 45, 46, 47};
std::vector<float> gold{0, 1, 2, 3, 12, 13, 14, 15, 4, 5, 6, 7, 16, 17, 18, 19,
8, 9, 10, 11, 20, 21, 22, 23, 24, 25, 26, 27, 36, 37, 38, 39,
28, 29, 30, 31, 40, 41, 42, 43, 32, 33, 34, 35, 44, 45, 46, 47};
EXPECT(migraphx::verify_range(result_vector, gold));
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment