Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
c7096299
Commit
c7096299
authored
May 17, 2022
by
turneram
Browse files
Use parse_layernorm to un-fuse layernorm op
parent
ebfbae82
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
28 additions
and
11 deletions
+28
-11
src/onnx/parse_layernorm.cpp
src/onnx/parse_layernorm.cpp
+27
-9
src/targets/gpu/kernels/include/migraphx/kernels/transposectx.hpp
...ets/gpu/kernels/include/migraphx/kernels/transposectx.hpp
+1
-1
src/targets/gpu/kernels/include/migraphx/kernels/transposeqkv.hpp
...ets/gpu/kernels/include/migraphx/kernels/transposeqkv.hpp
+0
-1
No files found.
src/onnx/parse_layernorm.cpp
View file @
c7096299
...
@@ -2,6 +2,8 @@
...
@@ -2,6 +2,8 @@
#include <migraphx/ranges.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/op/layernorm.hpp>
#include <migraphx/op/layernorm.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/instruction.hpp>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
@@ -16,7 +18,14 @@ struct parse_layernorm : op_parser<parse_layernorm>
...
@@ -16,7 +18,14 @@ struct parse_layernorm : op_parser<parse_layernorm>
onnx_parser
::
node_info
info
,
onnx_parser
::
node_info
info
,
const
std
::
vector
<
instruction_ref
>&
args
)
const
const
std
::
vector
<
instruction_ref
>&
args
)
const
{
{
float
epsilon
=
1e-3
f
;
// un-fuse layernorm op so migraphx can handle fusion instead
auto
x
=
args
.
front
();
auto
x_type
=
x
->
get_shape
().
type
();
auto
weights
=
args
.
at
(
1
);
auto
bias
=
args
.
at
(
2
);
float
epsilon
=
1e-12
f
;
int64_t
axis
=
-
1
;
int64_t
axis
=
-
1
;
if
(
contains
(
info
.
attributes
,
"epsilon"
))
if
(
contains
(
info
.
attributes
,
"epsilon"
))
{
{
...
@@ -26,16 +35,25 @@ struct parse_layernorm : op_parser<parse_layernorm>
...
@@ -26,16 +35,25 @@ struct parse_layernorm : op_parser<parse_layernorm>
{
{
epsilon
=
parser
.
parse_value
(
info
.
attributes
.
at
(
"axis"
)).
at
<
int64_t
>
();
epsilon
=
parser
.
parse_value
(
info
.
attributes
.
at
(
"axis"
)).
at
<
int64_t
>
();
}
}
auto
epsilon_lit
=
info
.
add_literal
(
literal
{
shape
{
x_type
,
{
1
}},
{
epsilon
}});
auto
exponent
=
info
.
add_literal
(
literal
{
shape
{
x_type
,
{
1
}},
{
2.0
}});
auto
dims
=
x
->
get_shape
().
lens
();
auto
layernorm
=
info
.
add_instruction
(
auto
mean
=
info
.
add_instruction
(
migraphx
::
make_op
(
"reduce_mean"
,
{{
"axes"
,
{
axis
}}}),
x
);
make_op
(
"layernorm"
,
{{
"epsilon"
,
epsilon
},
{
"axis"
,
axis
}}),
args
.
front
());
auto
mean_mbcast
=
info
.
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
dims
}}),
mean
);
if
(
args
.
size
()
>=
2
)
auto
sub
=
info
.
add_instruction
(
migraphx
::
make_op
(
"sub"
),
x
,
mean_mbcast
);
layernorm
=
info
.
add_broadcastable_binary_op
(
"mul"
,
layernorm
,
args
.
at
(
1
));
auto
exponent_mbcast
=
if
(
args
.
size
()
==
3
)
info
.
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
dims
}}),
exponent
);
layernorm
=
info
.
add_broadcastable_binary_op
(
"add"
,
layernorm
,
args
.
at
(
2
));
auto
pow
=
info
.
add_instruction
(
migraphx
::
make_op
(
"pow"
),
sub
,
exponent_mbcast
);
auto
var
=
info
.
add_instruction
(
migraphx
::
make_op
(
"reduce_mean"
,
{{
"axes"
,
{
axis
}}}),
pow
);
auto
add_epsilon
=
info
.
add_broadcastable_binary_op
(
"add"
,
var
,
epsilon_lit
);
auto
sqrt
=
info
.
add_instruction
(
migraphx
::
make_op
(
"sqrt"
),
add_epsilon
);
auto
div
=
info
.
add_broadcastable_binary_op
(
"div"
,
sub
,
sqrt
);
auto
mul
=
info
.
add_broadcastable_binary_op
(
"mul"
,
div
,
weights
);
return
layernorm
;
return
info
.
add_broadcastable_binary_op
(
"add"
,
mul
,
bias
)
;
}
}
};
};
...
...
src/targets/gpu/kernels/include/migraphx/kernels/transposectx.hpp
View file @
c7096299
...
@@ -28,7 +28,7 @@ __device__ void transposectx(const T& input_t, const U& output_t)
...
@@ -28,7 +28,7 @@ __device__ void transposectx(const T& input_t, const U& output_t)
const
int
NHS
=
NH
*
sequence_length
;
const
int
NHS
=
NH
*
sequence_length
;
const
int
out_offset
=
n
*
head_size
+
s
*
NH
+
b
*
NHS
;
const
int
out_offset
=
n
*
head_size
+
s
*
NH
+
b
*
NHS
;
if
(
index
.
lo
c
al
<
1024
)
if
(
index
.
g
lo
b
al
<
input_shape
.
elements
()
)
output_t
[
out_offset
+
idx
[
3
]]
=
input_t
[
index
.
global
];
output_t
[
out_offset
+
idx
[
3
]]
=
input_t
[
index
.
global
];
}
}
...
...
src/targets/gpu/kernels/include/migraphx/kernels/transposeqkv.hpp
View file @
c7096299
...
@@ -23,7 +23,6 @@ __device__ void transposeqkv(const T& input_t, const U& output_t)
...
@@ -23,7 +23,6 @@ __device__ void transposeqkv(const T& input_t, const U& output_t)
const
int
s
=
idx
[
1
];
const
int
s
=
idx
[
1
];
const
int
m
=
idx
[
2
];
const
int
m
=
idx
[
2
];
const
int
n
=
idx
[
3
];
const
int
n
=
idx
[
3
];
// const int j = idx[4];
const
int
num_heads
=
lens
[
3
];
const
int
num_heads
=
lens
[
3
];
const
int
sequence_length
=
lens
[
1
];
const
int
sequence_length
=
lens
[
1
];
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment