Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
8c3d43cf
Commit
8c3d43cf
authored
Oct 16, 2024
by
rocking
Browse files
Fix bug of padding
parent
629257f9
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
13 additions
and
24 deletions
+13
-24
include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp
...ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp
+13
-24
No files found.
include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp
View file @
8c3d43cf
...
...
@@ -125,29 +125,17 @@ struct Layernorm2dFwd
return
out_dstr_tensor
;
}
CK_TILE_DEVICE
static
int
Get
WelfordMax
Count
(
int
N
)
CK_TILE_DEVICE
static
in
dex_
t
Get
LastloopIntraLaneReduce
Count
(
in
dex_
t
N
)
{
constexpr
ck_tile
::
index_t
kNThreadPerBlock
=
kNPerBlock
/
kNPerThread
;
constexpr
ck_tile
::
index_t
kNThreadSliceSize
=
kNPerThread
*
kNRepeat
;
constexpr
ck_tile
::
index_t
kNThreadStepSize
=
kNThreadPerBlock
*
kNPerThread
;
int
thread_id_n
=
get_thread_id
()
%
kNThreadPerBlock
;
int
max_count
=
__builtin_amdgcn_readfirstlane
(
N
<
kNPerBlock
?
0
:
kNThreadSliceSize
*
(
N
/
kNPerBlock
));
int
n_per_block_tail_loop
=
__builtin_amdgcn_readfirstlane
(
N
-
max_count
*
kNThreadPerBlock
);
if
(
n_per_block_tail_loop
>
0
)
{
static_for
<
0
,
kNRepeat
,
1
>
{}([
&
](
auto
i
)
{
int
thread_max_n
=
(
thread_id_n
+
1
)
*
kNPerThread
+
kNThreadStepSize
*
i
;
int
delta
=
thread_max_n
-
n_per_block_tail_loop
;
delta
=
clamp
(
thread_max_n
-
n_per_block_tail_loop
,
0
,
kNPerThread
);
max_count
+=
kNPerThread
-
delta
;
});
}
return
max_count
;
using
S
=
typename
Problem
::
BlockShape
;
index_t
LastloopN
=
N
%
kNPerBlock
==
0
?
kNPerBlock
:
N
%
kNPerBlock
;
constexpr
index_t
NThread
=
S
::
kNWarpPerBlock
*
S
::
kNThreadPerWarp
;
index_t
iNLane
=
get_thread_id
()
%
NThread
;
index_t
iN0
=
LastloopN
/
(
S
::
kNPerThread
*
S
::
kNThreadPerWarp
);
index_t
iN1
=
(
LastloopN
%
(
S
::
kNPerThread
*
S
::
kNThreadPerWarp
))
/
S
::
kNPerThread
;
index_t
N2
=
(
LastloopN
%
(
S
::
kNPerThread
*
S
::
kNThreadPerWarp
))
%
S
::
kNPerThread
;
index_t
iN3
=
iNLane
<
iN1
?
S
::
kNPerThread
:
iNLane
==
iN1
?
N2
:
0
;
return
iN0
*
S
::
kNPerThread
+
iN3
;
}
template
<
typename
XBlockWindow
,
...
...
@@ -167,7 +155,7 @@ struct Layernorm2dFwd
ComputeDataType
epsilon
,
ck_tile
::
index_t
N
)
const
{
int
welford_max_count
=
Get
WelfordMax
Count
(
N
);
in
dex_
t
welford_max_count
=
Get
LastloopIntraLaneReduce
Count
(
N
);
ThreadWelford
<
ComputeDataType
,
XDataType
>
thread_welford
{
welford_max_count
};
using
XTensorType
=
decltype
(
load_tile
(
x_block_window
));
...
...
@@ -244,7 +232,8 @@ struct Layernorm2dFwd
index_t
num_n_tile_iteration
=
__builtin_amdgcn_readfirstlane
(
integer_divide_ceil
(
N
,
kNPerBlock
));
int
welford_max_count
=
GetWelfordMaxCount
(
N
);
index_t
intra_thread_count
=
kNRepeat
*
kNPerThread
*
(
num_n_tile_iteration
-
1
);
index_t
welford_max_count
=
intra_thread_count
+
GetLastloopIntraLaneReduceCount
(
N
);
ThreadWelford
<
ComputeDataType
,
XDataType
>
thread_welford
{
welford_max_count
};
using
XTensorType
=
decltype
(
load_tile
(
x_block_window
));
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment