gaoqiong / composable_kernel

Commit c83bad81 authored Dec 22, 2022 by rocking

Rewrite the 2nd kernel: use multiple blocks along the N dimension in the layernorm kernel

parent 90546fbe
Showing 2 changed files with 73 additions and 77 deletions

include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp      +9  -8
include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_welford_second_half_layernorm2d.hpp   +64 -69
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp

...
@@ -142,7 +142,7 @@ __global__ void
            const GammaBetaGridDesc_N gamma_grid_desc_n,
            const GammaBetaGridDesc_N beta_grid_desc_n,
            index_t numMeanVarCountBlockTileIteration_N,
            index_t numNormBlockTileIteration_N,
            index_t NBlockClusterLength,
            ComputeDataType epsilon,
            HElementwiseOperation h_element_op)
{
...

@@ -160,7 +160,7 @@ __global__ void
                                        gamma_grid_desc_n,
                                        beta_grid_desc_n,
                                        numMeanVarCountBlockTileIteration_N,
                                        numNormBlockTileIteration_N,
                                        NBlockClusterLength,
                                        epsilon,
                                        h_element_op);
}
...

@@ -557,7 +557,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
                    DeviceOp::MakeEHGridDescriptor_M_N<DLayout>(MRaw, NRaw, StrideDs[i]);
            });
-           // populate desc for Ds/E/F/G
+           // populate desc for Ds/E/mean/var/count
            if(GridwiseGemmWelford::CheckValidity(a_grid_desc_m_k_,
                                                  b_grid_desc_n_k_,
                                                  ds_grid_desc_m_n_,
...

@@ -736,14 +736,15 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
                                                              arg.block_2_etile_map_,
                                                              arg.NRaw_);

-           grid_size = math::integer_divide_ceil(M, LayernormBlockTileSize_M_N::At(0));
+           index_t MBlockClusterLength =
+               math::integer_divide_ceil(M, LayernormBlockTileSize_M_N::At(0));
+           index_t NBlockClusterLength =
+               math::integer_divide_ceil(N, LayernormBlockTileSize_M_N::At(1));
+           grid_size = MBlockClusterLength * NBlockClusterLength;

            index_t numMeanVarCountBlockTileIteration_N =
                math::integer_divide_ceil(arg.gemm_nblock_, LayernormThreadClusterSize_M_N::At(I1));
            index_t numNormBlockTileIteration_N =
                math::integer_divide_ceil(N, LayernormBlockTileSize_M_N::At(I1));

            avg_time += launch_and_time_kernel(stream_config,
                                               kernel_welford_layernorm,
...

@@ -764,7 +765,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
                                               arg.gamma_grid_desc_n_,
                                               arg.beta_grid_desc_n_,
                                               numMeanVarCountBlockTileIteration_N,
                                               numNormBlockTileIteration_N,
                                               NBlockClusterLength,
                                               arg.epsilon_,
                                               arg.h_element_op_);
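For readers skimming the diff: the launch-side change above now sizes the grid over both the M and N tile counts instead of M alone. A minimal standalone sketch of the same arithmetic, using hypothetical tile sizes rather than the real LayernormBlockTileSize_M_N values:

    #include <cstdint>

    using index_t = int32_t;

    // Hypothetical tile sizes, for illustration only; the real values come from
    // LayernormBlockTileSize_M_N in the device op.
    constexpr index_t M_BlockTileSize = 32;
    constexpr index_t N_BlockTileSize = 128;

    constexpr index_t integer_divide_ceil(index_t x, index_t y) { return (x + y - 1) / y; }

    // One workgroup per (M tile, N tile) pair, mirroring the new grid_size computation.
    index_t make_grid_size(index_t M, index_t N)
    {
        const index_t MBlockClusterLength = integer_divide_ceil(M, M_BlockTileSize);
        const index_t NBlockClusterLength = integer_divide_ceil(N, N_BlockTileSize);
        return MBlockClusterLength * NBlockClusterLength;
    }

With M = 1000 and N = 1024 and the illustrative 32 x 128 tile this gives 32 * 8 = 256 workgroups, whereas the previous launch sized the grid only as ceil(M / 32) = 32.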
...

include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_welford_second_half_layernorm2d.hpp

...
@@ -101,13 +101,16 @@ struct GridwiseWelfordSecondHalfLayernorm2d
                               const GammaBetaGridDesc_N& gamma_grid_desc_n,
                               const GammaBetaGridDesc_N& beta_grid_desc_n,
                               index_t numMeanVarCountBlockTileIteration_N,
                               index_t numNormBlockTileIteration_N,
                               index_t NBlockClusterLength,
                               ComputeDataType epsilon,
                               HElementwiseOperation h_element_op)
    {
        // Thread/Block id
        const index_t thread_local_id = get_thread_local_1d_id();
        const index_t block_global_id = get_block_1d_id();
        const auto block_work_idx =
            make_tuple(block_global_id / NBlockClusterLength, block_global_id % NBlockClusterLength);

        const auto thread_cluster_idx =
            thread_cluster_desc_m_n.CalculateBottomIndex(make_multi_index(thread_local_id));
        const auto thread_m_cluster_id = thread_cluster_idx[I0];
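Because the grid is now two dimensional, the kernel recovers its (M, N) block coordinate from the 1D block id with a divide and a modulo by NBlockClusterLength, and the N coordinate feeds into every source/destination window offset below. A plain C++ sketch of that decomposition (hypothetical helper, not part of the commit):

    #include <utility>

    using index_t = int;

    // Sketch of the new indexing: mirrors
    //   block_work_idx = make_tuple(block_global_id / NBlockClusterLength,
    //                               block_global_id % NBlockClusterLength)
    std::pair<index_t, index_t> decompose_block_id(index_t block_global_id, index_t NBlockClusterLength)
    {
        const index_t m_block = block_global_id / NBlockClusterLength; // row in the block grid
        const index_t n_block = block_global_id % NBlockClusterLength; // column in the block grid
        return {m_block, n_block};
    }

    // A thread's global start coordinate then combines block and thread-cluster offsets,
    // as in the updated make_multi_index(...) calls further down:
    //   m_start = m_block * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize
    //   n_start = n_block * N_BlockTileSize + thread_n_cluster_id * NThreadSliceSize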
...

@@ -152,22 +155,22 @@ struct GridwiseWelfordSecondHalfLayernorm2d
-       StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize * ESrcVectorSize, true>
+       StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize * NThreadSliceSize, true>
            e_thread_buf;
-       StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize * GammaSrcVectorSize, true>
+       StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize * NThreadSliceSize, true>
            gamma_thread_buf;
-       StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize * BetaSrcVectorSize, true>
+       StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize * NThreadSliceSize, true>
            beta_thread_buf;
-       StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize * HDstVectorSize, true>
+       StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MThreadSliceSize * NThreadSliceSize, true>
            h_thread_buf;
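The e/gamma/beta/h register tiles above are now sized MThreadSliceSize * NThreadSliceSize, i.e. each thread keeps its full M x N slice of the block tile in registers and addresses it through a flattened (m, n) offset. A tiny illustrative sketch of that addressing, with made-up slice sizes and a plain array standing in for StaticBuffer:

    // Illustrative slice sizes; in the kernel these are compile-time template parameters.
    constexpr int MThreadSliceSize = 4;
    constexpr int NThreadSliceSize = 8;

    // Per-thread register tile, analogous in element count to e_thread_buf / h_thread_buf above.
    float h_tile[MThreadSliceSize * NThreadSliceSize];

    // Row-major flattening of an (m, n) element, playing the same role that
    // thread_buffer_desc_m_n.CalculateOffset(make_tuple(m, n)) plays in the loops below.
    constexpr int flatten(int m, int n) { return m * NThreadSliceSize + n; }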
...

@@ -184,7 +187,7 @@ struct GridwiseWelfordSecondHalfLayernorm2d
                                                       1,
                                                       true>(
                mean_var_grid_desc_m_n,
-               make_multi_index(block_global_id * M_BlockTileSize +
+               make_multi_index(block_work_idx[I0] * M_BlockTileSize +
                                     thread_m_cluster_id * MThreadSliceSize,
                                 thread_n_cluster_id));
...

@@ -200,7 +203,7 @@ struct GridwiseWelfordSecondHalfLayernorm2d
                                                       1,
                                                       true>(
                mean_var_grid_desc_m_n,
-               make_multi_index(block_global_id * M_BlockTileSize +
+               make_multi_index(block_work_idx[I0] * M_BlockTileSize +
                                     thread_m_cluster_id * MThreadSliceSize,
                                 thread_n_cluster_id));
...

@@ -216,7 +219,7 @@ struct GridwiseWelfordSecondHalfLayernorm2d
                                                       1,
                                                       true>(
                count_grid_desc_m_n,
-               make_multi_index(block_global_id * M_BlockTileSize +
+               make_multi_index(block_work_idx[I0] * M_BlockTileSize +
                                     thread_m_cluster_id * MThreadSliceSize,
                                 thread_n_cluster_id));
...

@@ -232,9 +235,9 @@ struct GridwiseWelfordSecondHalfLayernorm2d
                                                       1,
                                                       true>(
                e_grid_desc_m_n,
-               make_multi_index(block_global_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
-                                thread_n_cluster_id * NThreadSliceSize));
+               make_multi_index(block_work_idx[I0] * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
+                                block_work_idx[I1] * N_BlockTileSize + thread_n_cluster_id * NThreadSliceSize));

        auto threadwise_gamma_load_m_n = ThreadwiseTensorSliceTransfer_v2<GammaDataType,
...

@@ -247,7 +250,9 @@ struct GridwiseWelfordSecondHalfLayernorm2d
                                                            GammaSrcVectorSize,
                                                            1,
                                                            true>(
-               gamma_grid_desc_n, make_multi_index(thread_n_cluster_id * NThreadSliceSize));
+               gamma_grid_desc_n,
+               make_multi_index(block_work_idx[I1] * N_BlockTileSize +
+                                thread_n_cluster_id * NThreadSliceSize));

        auto threadwise_beta_load_m_n = ThreadwiseTensorSliceTransfer_v2<BetaDataType,
...

@@ -260,7 +265,9 @@ struct GridwiseWelfordSecondHalfLayernorm2d
                                                           BetaSrcVectorSize,
                                                           1,
                                                           true>(
-               beta_grid_desc_n, make_multi_index(thread_n_cluster_id * NThreadSliceSize));
+               beta_grid_desc_n,
+               make_multi_index(block_work_idx[I1] * N_BlockTileSize +
+                                thread_n_cluster_id * NThreadSliceSize));

        auto threadwise_h_store_m_n = ThreadwiseTensorSliceTransfer_v1r3<ComputeDataType,
...

@@ -276,13 +283,13 @@ struct GridwiseWelfordSecondHalfLayernorm2d
                                                             1,
                                                             true>(
                h_grid_desc_m_n,
-               make_multi_index(block_global_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
-                                thread_n_cluster_id * NThreadSliceSize),
+               make_multi_index(block_work_idx[I0] * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
+                                block_work_idx[I1] * N_BlockTileSize + thread_n_cluster_id * NThreadSliceSize),
                h_element_op);

        // step1: Merge mean and variance
-       constexpr auto mean_var_count_thread_copy_step_m_n =
+       constexpr auto mean_var_count_thread_copy_step_0_n =
            make_multi_index(0, NThreadClusterSize);

        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
...

@@ -320,11 +327,11 @@ struct GridwiseWelfordSecondHalfLayernorm2d
                                              welford_count_thread_buf);

            threadwise_mean_load_m_nblock.MoveSrcSliceWindow(mean_var_grid_desc_m_n,
-                                                            mean_var_count_thread_copy_step_m_n);
+                                                            mean_var_count_thread_copy_step_0_n);
            threadwise_var_load_m_nblock.MoveSrcSliceWindow(mean_var_grid_desc_m_n,
-                                                           mean_var_count_thread_copy_step_m_n);
+                                                           mean_var_count_thread_copy_step_0_n);
            threadwise_count_load_m_nblock.MoveSrcSliceWindow(count_grid_desc_m_n,
-                                                             mean_var_count_thread_copy_step_m_n);
+                                                             mean_var_count_thread_copy_step_0_n);
        }

        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
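Step 1 above iterates numMeanVarCountBlockTileIteration_N times, each time pulling another tile of per-GEMM-block partial mean/variance/count and folding it into the per-row running statistics (the CK threadwise Welford helper itself sits in the collapsed context of this hunk). For reference, the standard parallel combination of two partial results looks like the sketch below; this is illustrative only and is not the composable_kernel implementation:

    struct WelfordPartial
    {
        float mean;
        float var;   // population variance of the partial chunk
        int   count;
    };

    // Merge two partial (mean, var, count) results into one, using the usual
    // parallel-variance recurrence. Sketch only.
    WelfordPartial welford_merge(const WelfordPartial& a, const WelfordPartial& b)
    {
        const int count = a.count + b.count;
        if(count == 0)
            return {0.0f, 0.0f, 0};
        const float delta = b.mean - a.mean;
        const float mean  = a.mean + delta * b.count / count;
        // Combine sums of squared deviations, then renormalize to a variance.
        const float m2 =
            a.var * a.count + b.var * b.count + delta * delta * a.count * b.count / count;
        return {mean, m2 / count, count};
    }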
...

@@ -336,66 +343,54 @@ struct GridwiseWelfordSecondHalfLayernorm2d
        });

        // step2: normalization
+       for(index_t reducedTiles = 0; reducedTiles < numNormBlockTileIteration_N; ++reducedTiles)
+       {
            // h[m, n] = [(e[m, n] - mean[m]) / sqrt(var[m] + eps)] * gamma[n] + beta[n]
            threadwise_e_load_m_n.Run(e_grid_desc_m_n,
                                      e_global_val_buf,
                                      thread_buffer_desc_m_n,
                                      make_tuple(I0, I0),
                                      e_thread_buf);

            static_for<0, MThreadSliceSize, 1>{}([&](auto m) {
                auto divisor = 1 / __builtin_amdgcn_sqrtf(welford_var_thread_buf(m) + epsilon);
                static_for<0, NThreadSliceSize, 1>{}([&](auto n) {
                    constexpr auto m_n = thread_buffer_desc_m_n.CalculateOffset(make_tuple(m, n));
                    h_thread_buf(Number<m_n>{}) =
                        (e_thread_buf(Number<m_n>{}) - welford_mean_thread_buf(m)) * divisor;
                });
            });

            threadwise_gamma_load_m_n.Run(gamma_grid_desc_n,
                                          gamma_global_val_buf,
                                          thread_buffer_desc_n,
                                          make_tuple(I0),
                                          gamma_thread_buf);

            static_for<0, MThreadSliceSize, 1>{}([&](auto m) {
                static_for<0, NThreadSliceSize, 1>{}([&](auto n) {
                    constexpr auto m_n = thread_buffer_desc_m_n.CalculateOffset(make_tuple(m, n));
                    h_thread_buf(Number<m_n>{}) =
                        h_thread_buf(Number<m_n>{}) * gamma_thread_buf(n);
                });
            });

            threadwise_beta_load_m_n.Run(beta_grid_desc_n,
                                         beta_global_val_buf,
                                         thread_buffer_desc_n,
                                         make_tuple(I0),
                                         beta_thread_buf);

            static_for<0, MThreadSliceSize, 1>{}([&](auto m) {
                static_for<0, NThreadSliceSize, 1>{}([&](auto n) {
                    constexpr auto m_n = thread_buffer_desc_m_n.CalculateOffset(make_tuple(m, n));
                    h_thread_buf(Number<m_n>{}) =
                        h_thread_buf(Number<m_n>{}) + beta_thread_buf(n);
                });
            });

            threadwise_h_store_m_n.Run(thread_buffer_desc_m_n,
                                       make_tuple(I0, I0),
                                       h_thread_buf,
                                       h_grid_desc_m_n,
                                       h_global_val_buf);

+           threadwise_e_load_m_n.MoveSrcSliceWindow(e_grid_desc_m_n,
+                                                    make_multi_index(0, N_BlockTileSize));
+           threadwise_gamma_load_m_n.MoveSrcSliceWindow(gamma_grid_desc_n,
+                                                        make_multi_index(N_BlockTileSize));
+           threadwise_beta_load_m_n.MoveSrcSliceWindow(beta_grid_desc_n,
+                                                       make_multi_index(N_BlockTileSize));
+           threadwise_h_store_m_n.MoveDstSliceWindow(h_grid_desc_m_n,
+                                                     make_multi_index(0, N_BlockTileSize));
+       }
    } // run
};
...
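Step 2 above applies h[m, n] = ((e[m, n] - mean[m]) / sqrt(var[m] + eps)) * gamma[n] + beta[n], now one N_BlockTileSize-wide tile per loop iteration, with the slice windows advanced along N at the end of each pass. A scalar host-side reference of the same formula, independent of the CK types and handy for validating the kernel (a sketch, assuming row-major M x N layout and already-merged per-row statistics):

    #include <cmath>
    #include <vector>

    // Scalar reference for the normalization step; e and h are row-major M x N.
    void layernorm_reference(const std::vector<float>& e,
                             const std::vector<float>& mean,
                             const std::vector<float>& var,
                             const std::vector<float>& gamma,
                             const std::vector<float>& beta,
                             std::vector<float>& h,
                             int M,
                             int N,
                             float epsilon)
    {
        for(int m = 0; m < M; ++m)
        {
            const float inv_std = 1.0f / std::sqrt(var[m] + epsilon);
            for(int n = 0; n < N; ++n)
                h[m * N + n] = (e[m * N + n] - mean[m]) * inv_std * gamma[n] + beta[n];
        }
    }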