Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
15acaee9
Unverified
Commit
15acaee9
authored
Sep 15, 2023
by
Umang Yadav
Committed by
GitHub
Sep 15, 2023
Browse files
Preserve layout of fused kernel for `layernorm+pointwise` (#2185)
parent
74ba9649
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
43 additions
and
20 deletions
+43
-20
src/targets/gpu/prefuse_ops.cpp
src/targets/gpu/prefuse_ops.cpp
+21
-18
test/verify/test_layernorm.cpp
test/verify/test_layernorm.cpp
+22
-2
No files found.
src/targets/gpu/prefuse_ops.cpp
View file @
15acaee9
...
@@ -21,6 +21,7 @@
...
@@ -21,6 +21,7 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
* THE SOFTWARE.
*/
*/
#include <migraphx/permutation.hpp>
#include <migraphx/gpu/prefuse_ops.hpp>
#include <migraphx/gpu/prefuse_ops.hpp>
#include <migraphx/match/layernorm.hpp>
#include <migraphx/match/layernorm.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/check_shapes.hpp>
...
@@ -45,40 +46,42 @@ struct layernorm_base
...
@@ -45,40 +46,42 @@ struct layernorm_base
}
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
,
std
::
vector
<
module_ref
>
mods
)
const
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
,
std
::
vector
<
module_ref
>
mods
)
const
{
{
std
::
size_t
nargs
=
1
;
std
::
size_t
nargs
=
N
;
if
(
not
mods
.
empty
())
if
(
not
mods
.
empty
())
{
{
auto
*
pm
=
mods
.
front
();
auto
*
pm
=
mods
.
front
();
nargs
=
pm
->
get_parameter_names
().
size
();
nargs
+
=
pm
->
get_parameter_names
().
size
()
-
1
;
}
}
check_shapes
{
inputs
,
static_cast
<
const
Derived
&>
(
*
this
)}.
has
(
nargs
+
N
);
check_shapes
{
inputs
,
static_cast
<
const
Derived
&>
(
*
this
)}.
has
(
nargs
);
auto
s
=
inputs
.
a
t
(
0
);
auto
s
=
inputs
.
fron
t
();
auto
t
=
s
.
type
();
auto
t
=
s
.
type
();
if
(
not
mods
.
empty
())
if
(
not
mods
.
empty
())
t
=
mods
.
front
()
->
get_output_shapes
().
front
().
type
();
t
=
mods
.
front
()
->
get_output_shapes
().
front
().
type
();
if
(
s
.
scalar
())
{
// Scalar output if all inputs are scalar
return
s
;
if
(
inputs
.
front
().
elements
()
==
1
and
}
all_of
(
inputs
,
[](
const
auto
&
ss
)
{
return
ss
.
scalar
();
}))
else
if
(
s
.
broadcasted
())
return
inputs
.
front
();
{
auto
l_s
=
shape
::
from_permutation
(
return
{
t
,
s
.
lens
()};
t
,
s
.
lens
(),
find_permutation
(
std
::
vector
<
shape
>
(
inputs
.
begin
(),
inputs
.
begin
()
+
N
)));
}
// just prelayernorm or preadd_layernorm
else
if
(
nargs
<=
N
)
{
return
l_s
;
return
s
.
with_lens
(
t
,
s
.
lens
());
// else, layernorm + pointwise fusion, preserve layout of fused op
}
std
::
vector
<
shape
>
lp_s
(
inputs
.
begin
()
+
N
,
inputs
.
end
());
lp_s
.
insert
(
lp_s
.
begin
(),
l_s
);
return
shape
::
from_permutation
(
t
,
s
.
lens
(),
find_permutation
(
lp_s
));
}
}
};
};
struct
layernorm
:
layernorm_base
<
layernorm
,
0
>
struct
layernorm
:
layernorm_base
<
layernorm
,
1
>
{
{
std
::
string
name
()
const
{
return
"gpu::prelayernorm"
;
}
std
::
string
name
()
const
{
return
"gpu::prelayernorm"
;
}
};
};
MIGRAPHX_REGISTER_OP
(
layernorm
);
MIGRAPHX_REGISTER_OP
(
layernorm
);
struct
add_layernorm
:
layernorm_base
<
add_layernorm
,
1
>
struct
add_layernorm
:
layernorm_base
<
add_layernorm
,
2
>
{
{
std
::
string
name
()
const
{
return
"gpu::preadd_layernorm"
;
}
std
::
string
name
()
const
{
return
"gpu::preadd_layernorm"
;
}
};
};
...
...
test/verify/test_layernorm.cpp
View file @
15acaee9
...
@@ -49,7 +49,8 @@ migraphx::instruction_ref add_layernorm(migraphx::module& m,
...
@@ -49,7 +49,8 @@ migraphx::instruction_ref add_layernorm(migraphx::module& m,
auto
pow
=
m
.
add_instruction
(
migraphx
::
make_op
(
"pow"
),
sub
,
exponent_mbcast
);
auto
pow
=
m
.
add_instruction
(
migraphx
::
make_op
(
"pow"
),
sub
,
exponent_mbcast
);
auto
var
=
m
.
add_instruction
(
migraphx
::
make_op
(
"reduce_mean"
,
{{
"axes"
,
{
2
}}}),
pow
);
auto
var
=
m
.
add_instruction
(
migraphx
::
make_op
(
"reduce_mean"
,
{{
"axes"
,
{
2
}}}),
pow
);
auto
epsilon_mbcast
=
m
.
add_instruction
(
auto
epsilon_mbcast
=
m
.
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
1
,
dims
.
at
(
1
),
1
}}}),
epsilon
);
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
dims
.
at
(
0
),
dims
.
at
(
1
),
1
}}}),
epsilon
);
auto
add_epsilon
=
m
.
add_instruction
(
migraphx
::
make_op
(
"add"
),
var
,
epsilon_mbcast
);
auto
add_epsilon
=
m
.
add_instruction
(
migraphx
::
make_op
(
"add"
),
var
,
epsilon_mbcast
);
auto
sqrt
=
m
.
add_instruction
(
migraphx
::
make_op
(
"sqrt"
),
add_epsilon
);
auto
sqrt
=
m
.
add_instruction
(
migraphx
::
make_op
(
"sqrt"
),
add_epsilon
);
auto
sqrt_mbcast
=
auto
sqrt_mbcast
=
...
@@ -57,7 +58,8 @@ migraphx::instruction_ref add_layernorm(migraphx::module& m,
...
@@ -57,7 +58,8 @@ migraphx::instruction_ref add_layernorm(migraphx::module& m,
auto
div
=
m
.
add_instruction
(
migraphx
::
make_op
(
"div"
),
sub
,
sqrt_mbcast
);
auto
div
=
m
.
add_instruction
(
migraphx
::
make_op
(
"div"
),
sub
,
sqrt_mbcast
);
auto
scale_mbcast
=
auto
scale_mbcast
=
m
.
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
dims
}}),
scale
);
m
.
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
dims
}}),
scale
);
auto
mul
=
m
.
add_instruction
(
migraphx
::
make_op
(
"mul"
),
scale_mbcast
,
div
);
auto
mul
=
m
.
add_instruction
(
migraphx
::
make_op
(
"mul"
),
div
,
scale_mbcast
);
auto
bias_mbcast
=
auto
bias_mbcast
=
m
.
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
dims
}}),
bias
);
m
.
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
dims
}}),
bias
);
return
m
.
add_instruction
(
migraphx
::
make_op
(
"add"
),
mul
,
bias_mbcast
);
return
m
.
add_instruction
(
migraphx
::
make_op
(
"add"
),
mul
,
bias_mbcast
);
...
@@ -161,3 +163,21 @@ struct test_layernorm_triadd_large : verify_program<test_layernorm_triadd_large>
...
@@ -161,3 +163,21 @@ struct test_layernorm_triadd_large : verify_program<test_layernorm_triadd_large>
return
p
;
return
p
;
}
}
};
};
struct
test_add_layernorm_add_gemm_nonstd
:
verify_program
<
test_add_layernorm_add_gemm_nonstd
>
{
migraphx
::
program
create_program
()
const
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
s
=
migraphx
::
shape
::
from_permutation
(
migraphx
::
shape
::
float_type
,
{
8
,
1
,
16
},
{
1
,
2
,
0
});
auto
x
=
mm
->
add_parameter
(
"x"
,
s
);
auto
y
=
mm
->
add_parameter
(
"y"
,
s
);
auto
z
=
mm
->
add_parameter
(
"z"
,
migraphx
::
shape
{
migraphx
::
shape
::
float_type
,
{
8
,
16
,
64
}});
auto
add
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
x
,
y
);
auto
layernorm_ins
=
add_layernorm
(
*
mm
,
add
,
s
.
lens
());
mm
->
add_instruction
(
migraphx
::
make_op
(
"dot"
),
layernorm_ins
,
z
);
return
p
;
}
};
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment