Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
02236580
"...composable_kernel-1.git" did not exist on "ccc4a1d365999a3e15623f490314e66c2d671389"
Commit
02236580
authored
Oct 14, 2024
by
rocking
Browse files
refine welford max count calculation
parent
96568141
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
28 additions
and
47 deletions
+28
-47
include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp
...ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp
+28
-47
No files found.
include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp
View file @
02236580
...
@@ -31,9 +31,9 @@ struct Layernorm2dFwd
...
@@ -31,9 +31,9 @@ struct Layernorm2dFwd
static
constexpr
ck_tile
::
index_t
kMPerBlock
=
Problem
::
BlockShape
::
kMPerBlock
;
static
constexpr
ck_tile
::
index_t
kMPerBlock
=
Problem
::
BlockShape
::
kMPerBlock
;
static
constexpr
ck_tile
::
index_t
kNPerBlock
=
Problem
::
BlockShape
::
kNPerBlock
;
static
constexpr
ck_tile
::
index_t
kNPerBlock
=
Problem
::
BlockShape
::
kNPerBlock
;
static
constexpr
bool
kPadM
=
false
;
// TODO - Problem::kPadM
static
constexpr
bool
kPadM
=
false
;
// TODO -
BlockLayernorm2dFwd
Problem::kPadM
static
constexpr
bool
kPadN
=
Problem
::
kPadN
;
static
constexpr
bool
kPadN
=
Problem
::
kPadN
;
static
constexpr
bool
kTwoPass
=
Problem
::
kTwoPass
;
static
constexpr
bool
kTwoPass
=
Problem
::
kTwoPass
;
static
constexpr
ck_tile
::
index_t
kNThreadPerWarp
=
Problem
::
BlockShape
::
kNThreadPerWarp
;
static
constexpr
ck_tile
::
index_t
kNThreadPerWarp
=
Problem
::
BlockShape
::
kNThreadPerWarp
;
static
constexpr
ck_tile
::
index_t
kNPerThread
=
Problem
::
BlockShape
::
kNPerThread
;
static
constexpr
ck_tile
::
index_t
kNPerThread
=
Problem
::
BlockShape
::
kNPerThread
;
...
@@ -106,21 +106,6 @@ struct Layernorm2dFwd
...
@@ -106,21 +106,6 @@ struct Layernorm2dFwd
sequence
<
0
,
3
>>
{});
sequence
<
0
,
3
>>
{});
}
}
template
<
typename
Dstr
>
CK_TILE_DEVICE
static
constexpr
auto
GetNPerThread
(
Dstr
)
{
constexpr
auto
nDstrSpan
=
Dstr
::
get_distributed_spans
().
template
at
<
1
>();
using
Lengths
=
decltype
(
nDstrSpan
.
impl_
);
ck_tile
::
index_t
ret
=
1
;
ck_tile
::
static_for
<
0
,
Lengths
::
size
(),
1
>
{}(
[
&
](
auto
idx
)
{
ret
*=
Lengths
::
template
at
(
idx
);
});
return
ret
;
}
template
<
typename
DistributedTensor
>
template
<
typename
DistributedTensor
>
CK_TILE_DEVICE
static
auto
InvSqrt
(
const
DistributedTensor
&
in_dstr_tensor
,
CK_TILE_DEVICE
static
auto
InvSqrt
(
const
DistributedTensor
&
in_dstr_tensor
,
const
ComputeDataType
epsilon
)
const
ComputeDataType
epsilon
)
...
@@ -139,20 +124,25 @@ struct Layernorm2dFwd
...
@@ -139,20 +124,25 @@ struct Layernorm2dFwd
return
out_dstr_tensor
;
return
out_dstr_tensor
;
}
}
CK_TILE_HOST_DEVICE
static
constexpr
auto
CK_TILE_DEVICE
static
int
GetWelfordMaxCount
(
int
N
)
GetLastloopLayerNormIntraLaneReduceCount
(
index_t
NLength
)
{
{
using
S
=
typename
Problem
::
BlockShape
;
constexpr
ck_tile
::
index_t
kNThreadPerBlock
=
kNPerBlock
/
kNPerThread
;
// S::kNWarpPerBlock, S::kNThreadPerWarp, S::kNPerThread
auto
LastloopN
=
NLength
%
kNPerBlock
==
0
?
kNPerBlock
:
NLength
%
kNPerBlock
;
int
thread_id_n
=
get_thread_id
()
%
kNThreadPerBlock
;
constexpr
auto
NThread
=
S
::
kNWarpPerBlock
*
S
::
kNThreadPerWarp
;
int
max_count
=
auto
iNLane
=
get_thread_local_1d_id
()
%
NThread
;
__builtin_amdgcn_readfirstlane
(
N
<
kNPerBlock
?
0
:
kNPerThread
*
(
N
/
kNPerBlock
));
auto
iN0
=
LastloopN
/
(
S
::
kNPerThread
*
S
::
kNThreadPerWarp
);
int
n_per_block_tail_loop
=
auto
iN1
=
(
LastloopN
%
(
S
::
kNPerThread
*
S
::
kNThreadPerWarp
))
/
S
::
kNPerThread
;
__builtin_amdgcn_readfirstlane
(
N
-
max_count
*
kNThreadPerBlock
);
auto
N2
=
(
LastloopN
%
(
S
::
kNPerThread
*
S
::
kNThreadPerWarp
))
%
S
::
kNPerThread
;
auto
iN3
=
iNLane
<
iN1
?
S
::
kNPerThread
:
iNLane
==
iN1
?
N2
:
0
;
if
(
n_per_block_tail_loop
>
0
)
{
return
iN0
*
S
::
kNPerThread
+
iN3
;
int
thread_max_n
=
(
thread_id_n
+
1
)
*
kNPerThread
;
int
delta
=
thread_max_n
-
n_per_block_tail_loop
;
delta
=
clamp
(
thread_max_n
-
n_per_block_tail_loop
,
0
,
kNPerThread
);
max_count
+=
kNPerThread
-
delta
;
}
return
max_count
;
}
}
template
<
typename
XBlockWindow
,
template
<
typename
XBlockWindow
,
...
@@ -172,8 +162,8 @@ struct Layernorm2dFwd
...
@@ -172,8 +162,8 @@ struct Layernorm2dFwd
ComputeDataType
epsilon
,
ComputeDataType
epsilon
,
ck_tile
::
index_t
N
)
const
ck_tile
::
index_t
N
)
const
{
{
auto
intra_thread
_count
_last
=
Get
LastloopLayerNormIntraLaneReduce
Count
(
N
);
int
welford_max
_count
=
Get
WelfordMax
Count
(
N
);
ThreadWelford
<
ComputeDataType
,
XDataType
>
thread_welford
{
intra_thread_count_las
t
};
ThreadWelford
<
ComputeDataType
,
XDataType
>
thread_welford
{
welford_max_coun
t
};
using
XTensorType
=
decltype
(
load_tile
(
x_block_window
));
using
XTensorType
=
decltype
(
load_tile
(
x_block_window
));
auto
mean_compute_block_tensor
=
auto
mean_compute_block_tensor
=
...
@@ -246,15 +236,11 @@ struct Layernorm2dFwd
...
@@ -246,15 +236,11 @@ struct Layernorm2dFwd
ComputeDataType
epsilon
,
ComputeDataType
epsilon
,
ck_tile
::
index_t
N
)
const
ck_tile
::
index_t
N
)
const
{
{
using
S
=
typename
Problem
::
BlockShape
;
index_t
num_n_tile_iteration
=
index_t
num_n_tile_iteration
=
__builtin_amdgcn_readfirstlane
((
N
+
kNPerBlock
-
1
)
/
kNPerBlock
);
__builtin_amdgcn_readfirstlane
(
integer_divide_ceil
(
N
,
kNPerBlock
));
auto
intra_thread_count
=
S
::
kNRepeat
*
S
::
kNPerThread
*
(
num_n_tile_iteration
-
1
);
auto
intra_thread_count_last
=
GetLastloopLayerNormIntraLaneReduceCount
(
N
);
ThreadWelford
<
ComputeDataType
,
XDataType
>
thread_welford
{
intra_thread_c
ount
}
;
int
welford_max_count
=
GetWelfordMaxC
ount
(
N
)
;
ThreadWelford
<
ComputeDataType
,
XDataType
>
thread_welford
_last
{
intra_thread_count_las
t
};
ThreadWelford
<
ComputeDataType
,
XDataType
>
thread_welford
{
welford_max_coun
t
};
using
XTensorType
=
decltype
(
load_tile
(
x_block_window
));
using
XTensorType
=
decltype
(
load_tile
(
x_block_window
));
auto
mean_compute_block_tensor
=
auto
mean_compute_block_tensor
=
...
@@ -265,19 +251,13 @@ struct Layernorm2dFwd
...
@@ -265,19 +251,13 @@ struct Layernorm2dFwd
clear_tile
(
mean_compute_block_tensor
);
clear_tile
(
mean_compute_block_tensor
);
clear_tile
(
var_compute_block_tensor
);
clear_tile
(
var_compute_block_tensor
);
for
(
int
iN
=
__builtin_amdgcn_readfirstlane
(
0
);
iN
<
num_n_tile_iteration
-
1
;
++
iN
)
for
(
int
iN
=
__builtin_amdgcn_readfirstlane
(
0
);
iN
<
num_n_tile_iteration
;
++
iN
)
{
{
const
auto
x_block_tensor
=
load_tile
(
x_block_window
);
const
auto
x_block_tensor
=
load_tile
(
x_block_window
);
thread_welford
(
x_block_tensor
,
mean_compute_block_tensor
,
var_compute_block_tensor
);
thread_welford
(
x_block_tensor
,
mean_compute_block_tensor
,
var_compute_block_tensor
);
move_tile_window
(
x_block_window
,
{
0
,
kNPerBlock
});
move_tile_window
(
x_block_window
,
{
0
,
kNPerBlock
});
}
}
const
auto
x_block_tensor_
=
load_tile
(
x_block_window
);
thread_welford_last
.
cur_count_
+=
intra_thread_count
;
thread_welford_last
.
max_count_
+=
intra_thread_count
;
thread_welford_last
(
x_block_tensor_
,
mean_compute_block_tensor
,
var_compute_block_tensor
);
thread_welford
.
cur_count_
+=
intra_thread_count_last
;
// TODO: support cross warp Welford
// TODO: support cross warp Welford
WarpMergeWelford
<
ComputeDataType
,
true
>
{}(
WarpMergeWelford
<
ComputeDataType
,
true
>
{}(
...
@@ -295,6 +275,7 @@ struct Layernorm2dFwd
...
@@ -295,6 +275,7 @@ struct Layernorm2dFwd
ck_tile
::
index_t
stride_to_right_most_window
=
ck_tile
::
index_t
stride_to_right_most_window
=
N
%
kNPerBlock
==
0
?
N
-
kNPerBlock
:
N
-
N
%
kNPerBlock
;
N
%
kNPerBlock
==
0
?
N
-
kNPerBlock
:
N
-
N
%
kNPerBlock
;
move_tile_window
(
x_block_window
,
{
0
,
-
kNPerBlock
});
move_tile_window
(
gamma_block_window
,
{
stride_to_right_most_window
});
move_tile_window
(
gamma_block_window
,
{
stride_to_right_most_window
});
move_tile_window
(
beta_block_window
,
{
stride_to_right_most_window
});
move_tile_window
(
beta_block_window
,
{
stride_to_right_most_window
});
move_tile_window
(
y_block_window
,
{
0
,
stride_to_right_most_window
});
move_tile_window
(
y_block_window
,
{
0
,
stride_to_right_most_window
});
...
...
gaoqiong
@gaoqiong
mentioned in commit
4071440c
·
Feb 18, 2025
mentioned in commit
4071440c
mentioned in commit 4071440c125d4d7af31c9f21a18702ee13263bff
Toggle commit list
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment