Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
de99db23
Commit
de99db23
authored
Mar 29, 2022
by
Shucai Xiao
Browse files
simplify the layernorm kernel arguments
parent
780fffc8
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
9 additions
and
17 deletions
+9
-17
src/targets/gpu/device/include/migraphx/gpu/device/visit.hpp
src/targets/gpu/device/include/migraphx/gpu/device/visit.hpp
+5
-1
src/targets/gpu/device/layernorm.cpp
src/targets/gpu/device/layernorm.cpp
+4
-16
No files found.
src/targets/gpu/device/include/migraphx/gpu/device/visit.hpp
View file @
de99db23
...
@@ -176,7 +176,11 @@ template <index_int N, class T, class... Ts>
...
@@ -176,7 +176,11 @@ template <index_int N, class T, class... Ts>
auto
hip_vec_visit_all
(
T
&&
x
,
Ts
&&
...
xs
)
auto
hip_vec_visit_all
(
T
&&
x
,
Ts
&&
...
xs
)
{
{
return
[
&
](
auto
f
)
{
return
[
&
](
auto
f
)
{
hip_visit_all_impl
(
get_shape
(
x
),
auto
sx
=
get_shape
(
x
);
auto
lens
=
sx
.
lens
();
lens
.
back
()
/=
N
;
shape
ssx
{
sx
.
type
(),
lens
};
hip_visit_all_impl
(
ssx
,
make_hip_convert
([](
auto
*
p
)
{
return
as_vec
<
N
>
(
device_cast
(
p
));
}),
make_hip_convert
([](
auto
*
p
)
{
return
as_vec
<
N
>
(
device_cast
(
p
));
}),
f
,
f
,
x
,
x
,
...
...
src/targets/gpu/device/layernorm.cpp
View file @
de99db23
...
@@ -81,16 +81,14 @@ __device__ auto auto_block_reduce(index idx, Op op, T init, index_int n, F f)
...
@@ -81,16 +81,14 @@ __device__ auto auto_block_reduce(index idx, Op op, T init, index_int n, F f)
}
}
template
<
index_int
MaxBlockSize
,
class
Input
,
class
Output
>
template
<
index_int
MaxBlockSize
,
class
Input
,
class
Output
>
__device__
void
layernorm
(
index_int
i
,
__device__
void
layernorm
(
index
idx
,
index
idx
,
std
::
size_t
block_size_div
,
index_int
relements
,
index_int
relements
,
Input
input
,
Input
input
,
Output
output
)
Output
output
)
{
{
using
value_type
=
decltype
(
input
(
idx
.
local
));
using
value_type
=
decltype
(
input
(
idx
.
local
));
const
auto
relements_v
=
relements
/
vector_size
<
value_type
>
{};
const
auto
relements_v
=
relements
/
vector_size
<
value_type
>
{};
const
auto
out_idx
=
fast_div
(
i
,
block_size_div
)
;
const
auto
out_idx
=
blockIdx
.
x
;
const
auto
base_idx
=
out_idx
*
relements_v
;
const
auto
base_idx
=
out_idx
*
relements_v
;
const
auto
input_idx
=
base_idx
+
idx
.
local
;
const
auto
input_idx
=
base_idx
+
idx
.
local
;
const
bool
in_range
=
idx
.
local
<
relements_v
;
const
bool
in_range
=
idx
.
local
<
relements_v
;
...
@@ -133,14 +131,11 @@ void layernorm_vec_impl(hipStream_t stream,
...
@@ -133,14 +131,11 @@ void layernorm_vec_impl(hipStream_t stream,
const
auto
relements_v
=
relements
/
N
;
const
auto
relements_v
=
relements
/
N
;
const
std
::
size_t
max_block_size
=
256
;
const
std
::
size_t
max_block_size
=
256
;
const
std
::
size_t
block_size
=
compute_block_size
(
relements_v
,
max_block_size
);
const
std
::
size_t
block_size
=
compute_block_size
(
relements_v
,
max_block_size
);
const
std
::
size_t
block_size_div
=
encode_divisor
(
block_size
);
assert
(
relements_v
<=
block_size
);
assert
(
relements_v
<=
block_size
);
gs_launch
(
stream
,
nelements
*
block_size
,
block_size
)([
=
](
auto
i
,
auto
idx
)
__device__
{
gs_launch
(
stream
,
nelements
*
block_size
,
block_size
)([
=
](
auto
,
auto
idx
)
__device__
{
layernorm
<
max_block_size
>
(
layernorm
<
max_block_size
>
(
i
,
idx
,
idx
,
block_size_div
,
relements
,
relements
,
[
&
](
auto
input_idx
)
{
return
in
(
inputs
.
data
()[
input_idx
]...);
},
[
&
](
auto
input_idx
)
{
return
in
(
inputs
.
data
()[
input_idx
]...);
},
[
&
](
auto
input_idx
,
auto
x
)
{
[
&
](
auto
input_idx
,
auto
x
)
{
...
@@ -162,14 +157,11 @@ void layernorm_impl(hipStream_t stream,
...
@@ -162,14 +157,11 @@ void layernorm_impl(hipStream_t stream,
hip_visit_all
(
result
,
args
...)([
&
](
auto
output
,
auto
...
inputs
)
{
hip_visit_all
(
result
,
args
...)([
&
](
auto
output
,
auto
...
inputs
)
{
const
std
::
size_t
max_block_size
=
256
;
const
std
::
size_t
max_block_size
=
256
;
const
std
::
size_t
block_size
=
compute_block_size
(
relements
,
max_block_size
);
const
std
::
size_t
block_size
=
compute_block_size
(
relements
,
max_block_size
);
const
std
::
size_t
block_size_div
=
encode_divisor
(
block_size
);
assert
(
relements
<=
block_size
);
assert
(
relements
<=
block_size
);
gs_launch
(
stream
,
nelements
*
block_size
,
block_size
)([
=
](
auto
i
,
auto
idx
)
__device__
{
gs_launch
(
stream
,
nelements
*
block_size
,
block_size
)([
=
](
auto
,
auto
idx
)
__device__
{
layernorm
<
max_block_size
>
(
layernorm
<
max_block_size
>
(
i
,
idx
,
idx
,
block_size_div
,
relements
,
relements
,
[
&
](
auto
input_idx
)
{
return
in
(
inputs
.
data
()[
input_idx
]...);
},
[
&
](
auto
input_idx
)
{
return
in
(
inputs
.
data
()[
input_idx
]...);
},
[
&
](
auto
input_idx
,
auto
x
)
{
[
&
](
auto
input_idx
,
auto
x
)
{
...
@@ -188,10 +180,6 @@ auto layernorm_fusion(hipStream_t stream,
...
@@ -188,10 +180,6 @@ auto layernorm_fusion(hipStream_t stream,
return
[
=
](
auto
input
,
auto
output
)
{
return
[
=
](
auto
input
,
auto
output
)
{
auto
relements
=
arg1
.
get_shape
().
lens
().
back
();
auto
relements
=
arg1
.
get_shape
().
lens
().
back
();
auto
nelements
=
result
.
get_shape
().
elements
()
/
relements
;
auto
nelements
=
result
.
get_shape
().
elements
()
/
relements
;
// auto output_shape = result.get_shape();
// auto reduce_output_lens(output_shape.lens());
// reduce_output_lens.back() = 1;
if
((
relements
%
4
)
==
0
)
if
((
relements
%
4
)
==
0
)
layernorm_vec_impl
<
4
>
(
layernorm_vec_impl
<
4
>
(
stream
,
nelements
,
relements
,
input
,
output
,
result
,
arg1
,
args
...);
stream
,
nelements
,
relements
,
input
,
output
,
result
,
arg1
,
args
...);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment