Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
18e65656
Commit
18e65656
authored
Dec 23, 2022
by
rocking
Browse files
Refine naming
parent
44b66c41
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
19 additions
and
19 deletions
+19
-19
include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_welford_second_half_layernorm2d.hpp
...mm_layernorm/gridwise_welford_second_half_layernorm2d.hpp
+19
-19
No files found.
include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_welford_second_half_layernorm2d.hpp
View file @
18e65656
...
@@ -96,8 +96,8 @@ struct GridwiseWelfordSecondHalfLayernorm2d
...
@@ -96,8 +96,8 @@ struct GridwiseWelfordSecondHalfLayernorm2d
HDataType
*
__restrict__
p_h_grid
,
HDataType
*
__restrict__
p_h_grid
,
const
EHGridDesc_M_N
&
e_grid_desc_m_n
,
const
EHGridDesc_M_N
&
e_grid_desc_m_n
,
const
EHGridDesc_M_N
&
h_grid_desc_m_n
,
const
EHGridDesc_M_N
&
h_grid_desc_m_n
,
const
MeanVarGridDesc_M_NBlock
&
mean_var_grid_desc_m_n
,
const
MeanVarGridDesc_M_NBlock
&
mean_var_grid_desc_m_n
block
,
const
CountGridDesc_M_NBlock
&
count_grid_desc_m_n
,
const
CountGridDesc_M_NBlock
&
count_grid_desc_m_n
block
,
const
GammaBetaGridDesc_N
&
gamma_grid_desc_n
,
const
GammaBetaGridDesc_N
&
gamma_grid_desc_n
,
const
GammaBetaGridDesc_N
&
beta_grid_desc_n
,
const
GammaBetaGridDesc_N
&
beta_grid_desc_n
,
index_t
numMeanVarCountBlockTileIteration_N
,
index_t
numMeanVarCountBlockTileIteration_N
,
...
@@ -121,13 +121,13 @@ struct GridwiseWelfordSecondHalfLayernorm2d
...
@@ -121,13 +121,13 @@ struct GridwiseWelfordSecondHalfLayernorm2d
p_e_grid
,
e_grid_desc_m_n
.
GetElementSpaceSize
());
p_e_grid
,
e_grid_desc_m_n
.
GetElementSpaceSize
());
const
auto
welford_mean_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
const
auto
welford_mean_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_in_welford_mean_grid
,
mean_var_grid_desc_m_n
.
GetElementSpaceSize
());
p_in_welford_mean_grid
,
mean_var_grid_desc_m_n
block
.
GetElementSpaceSize
());
const
auto
welford_var_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
const
auto
welford_var_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_in_welford_var_grid
,
mean_var_grid_desc_m_n
.
GetElementSpaceSize
());
p_in_welford_var_grid
,
mean_var_grid_desc_m_n
block
.
GetElementSpaceSize
());
const
auto
welford_count_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
const
auto
welford_count_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_in_welford_count_grid
,
count_grid_desc_m_n
.
GetElementSpaceSize
());
p_in_welford_count_grid
,
count_grid_desc_m_n
block
.
GetElementSpaceSize
());
const
auto
gamma_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
const
auto
gamma_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_gamma_grid
,
gamma_grid_desc_n
.
GetElementSpaceSize
());
p_gamma_grid
,
gamma_grid_desc_n
.
GetElementSpaceSize
());
...
@@ -186,7 +186,7 @@ struct GridwiseWelfordSecondHalfLayernorm2d
...
@@ -186,7 +186,7 @@ struct GridwiseWelfordSecondHalfLayernorm2d
1
,
1
,
1
,
1
,
true
>
(
true
>
(
mean_var_grid_desc_m_n
,
mean_var_grid_desc_m_n
block
,
make_multi_index
(
block_work_idx
[
I0
]
*
M_BlockTileSize
+
make_multi_index
(
block_work_idx
[
I0
]
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
,
thread_m_cluster_id
*
MThreadSliceSize
,
thread_n_cluster_id
));
thread_n_cluster_id
));
...
@@ -202,7 +202,7 @@ struct GridwiseWelfordSecondHalfLayernorm2d
...
@@ -202,7 +202,7 @@ struct GridwiseWelfordSecondHalfLayernorm2d
1
,
1
,
1
,
1
,
true
>
(
true
>
(
mean_var_grid_desc_m_n
,
mean_var_grid_desc_m_n
block
,
make_multi_index
(
block_work_idx
[
I0
]
*
M_BlockTileSize
+
make_multi_index
(
block_work_idx
[
I0
]
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
,
thread_m_cluster_id
*
MThreadSliceSize
,
thread_n_cluster_id
));
thread_n_cluster_id
));
...
@@ -218,7 +218,7 @@ struct GridwiseWelfordSecondHalfLayernorm2d
...
@@ -218,7 +218,7 @@ struct GridwiseWelfordSecondHalfLayernorm2d
1
,
1
,
1
,
1
,
true
>
(
true
>
(
count_grid_desc_m_n
,
count_grid_desc_m_n
block
,
make_multi_index
(
block_work_idx
[
I0
]
*
M_BlockTileSize
+
make_multi_index
(
block_work_idx
[
I0
]
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
,
thread_m_cluster_id
*
MThreadSliceSize
,
thread_n_cluster_id
));
thread_n_cluster_id
));
...
@@ -289,8 +289,8 @@ struct GridwiseWelfordSecondHalfLayernorm2d
...
@@ -289,8 +289,8 @@ struct GridwiseWelfordSecondHalfLayernorm2d
h_element_op
);
h_element_op
);
// step1: Merge mean and variance
// step1: Merge mean and variance
constexpr
auto
mean_var_count_thread_copy_step_0_n
=
constexpr
auto
mean_var_count_thread_copy_step_
I
0_n
=
make_multi_index
(
0
,
NThreadClusterSize
);
make_multi_index
(
I
0
,
NThreadClusterSize
);
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
I
)
{
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
I
)
{
welford_mean_thread_buf
(
I
)
=
type_convert
<
ComputeDataType
>
(
0.0
f
);
welford_mean_thread_buf
(
I
)
=
type_convert
<
ComputeDataType
>
(
0.0
f
);
...
@@ -300,19 +300,19 @@ struct GridwiseWelfordSecondHalfLayernorm2d
...
@@ -300,19 +300,19 @@ struct GridwiseWelfordSecondHalfLayernorm2d
for
(
index_t
n
=
0
;
n
<
numMeanVarCountBlockTileIteration_N
;
++
n
)
for
(
index_t
n
=
0
;
n
<
numMeanVarCountBlockTileIteration_N
;
++
n
)
{
{
threadwise_mean_load_m_nblock
.
Run
(
mean_var_grid_desc_m_n
,
threadwise_mean_load_m_nblock
.
Run
(
mean_var_grid_desc_m_n
block
,
welford_mean_global_val_buf
,
welford_mean_global_val_buf
,
thread_buffer_desc_m_1
,
thread_buffer_desc_m_1
,
make_tuple
(
I0
,
I0
),
make_tuple
(
I0
,
I0
),
in_welford_mean_thread_buf
);
in_welford_mean_thread_buf
);
threadwise_var_load_m_nblock
.
Run
(
mean_var_grid_desc_m_n
,
threadwise_var_load_m_nblock
.
Run
(
mean_var_grid_desc_m_n
block
,
welford_var_global_val_buf
,
welford_var_global_val_buf
,
thread_buffer_desc_m_1
,
thread_buffer_desc_m_1
,
make_tuple
(
I0
,
I0
),
make_tuple
(
I0
,
I0
),
in_welford_var_thread_buf
);
in_welford_var_thread_buf
);
threadwise_count_load_m_nblock
.
Run
(
count_grid_desc_m_n
,
threadwise_count_load_m_nblock
.
Run
(
count_grid_desc_m_n
block
,
welford_count_global_val_buf
,
welford_count_global_val_buf
,
thread_buffer_desc_m_1
,
thread_buffer_desc_m_1
,
make_tuple
(
I0
,
I0
),
make_tuple
(
I0
,
I0
),
...
@@ -325,12 +325,12 @@ struct GridwiseWelfordSecondHalfLayernorm2d
...
@@ -325,12 +325,12 @@ struct GridwiseWelfordSecondHalfLayernorm2d
welford_var_thread_buf
,
welford_var_thread_buf
,
welford_count_thread_buf
);
welford_count_thread_buf
);
threadwise_mean_load_m_nblock
.
MoveSrcSliceWindow
(
mean_var_grid_desc_m_n
,
threadwise_mean_load_m_nblock
.
MoveSrcSliceWindow
(
mean_var_grid_desc_m_n
block
,
mean_var_count_thread_copy_step_0_n
);
mean_var_count_thread_copy_step_
I
0_n
);
threadwise_var_load_m_nblock
.
MoveSrcSliceWindow
(
mean_var_grid_desc_m_n
,
threadwise_var_load_m_nblock
.
MoveSrcSliceWindow
(
mean_var_grid_desc_m_n
block
,
mean_var_count_thread_copy_step_0_n
);
mean_var_count_thread_copy_step_
I
0_n
);
threadwise_count_load_m_nblock
.
MoveSrcSliceWindow
(
count_grid_desc_m_n
,
threadwise_count_load_m_nblock
.
MoveSrcSliceWindow
(
count_grid_desc_m_n
block
,
mean_var_count_thread_copy_step_0_n
);
mean_var_count_thread_copy_step_
I
0_n
);
}
}
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
I
)
{
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
I
)
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment