gaoqiong / composable_kernel

Commit 8bf23425
Authored Dec 08, 2022 by rocking
Parent cb17765e

calculate max count for tail block
Showing 4 changed files with 54 additions and 24 deletions (+54 -24).
example/21_gemm_layernorm/gemm_add_add_layernorm_xdl_fp16.cpp  (+11 -11)
include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp  (+9 -7)
include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp  (+33 -4)
include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_welford_second_half_layernorm2d.hpp  (+1 -2)
example/21_gemm_layernorm/gemm_add_add_layernorm_xdl_fp16.cpp

@@ -56,15 +56,15 @@ using BElementOp = PassThrough;
 using CDEElementOp = AddReluAdd;
 using HElementOp   = PassThrough;

-static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
+static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNKPadding;

 // clang-format off
 using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleDLayernorm_Xdl_CShuffle
-//######| ALayout| BLayout| DsLayout| HLayout| AData| BData| AccData| CShuffle| DsData| GammaData| BetaData| HData| A| B| CDE| H| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| PostShuffle| PostShuffle|
+//######| ALayout| BLayout| DsLayout| HLayout| AData| BData| AccData| CShuffle| DsData| GammaData| BetaData| HData| A| B| CDE| H| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| PostShuffle| PostShuffle| Layernorm| Layernorm| Layernorm| Layernorm| Layernorm| Layernorm| Layernorm|
-//######| | | | | Type| Type| Type| DataType| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector|
+//######| | | | | Type| Type| Type| DataType| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| ThreadClusterSize| ThreadSliceSize| ESrcHDst| ESrc| HDst| GammaSrc| BetaSrc|
-//######| | | | | | | | | | | | | Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _M_N| _M_N|
+//######| | | | | | | | | | | | | Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _M_N| _M_N| _M_N| _M_N| VectorDim| VectorSize| VectorSize| VectorSize| VectorSize|
-//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | LayernormThreadClusterSize_M_N, LayernormThreadSliceSize_M_N
+//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
-        < ALayout, BLayout, DsLayout, HLayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, GammaDataType, BetaDataType, HDataType, AElementOp, BElementOp, CDEElementOp, HElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<64, 4>, 4, S<8, 32>, S<1, 8>, 1, 8, 8, 8, 8, 1>;
+        < ALayout, BLayout, DsLayout, HLayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, GammaDataType, BetaDataType, HDataType, AElementOp, BElementOp, CDEElementOp, HElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<64, 4>, 4, S<8, 32>, S<1, 8>, 1, 8, 8, 8, 8>;
 // clang-format on

 auto f_host_tensor_descriptor1d = [](std::size_t len, std::size_t stride) {

@@ -149,11 +149,11 @@ int main()
     ck::index_t N = 1024;
     ck::index_t K = 1024;
-    ck::index_t StrideA = 1024;
+    ck::index_t StrideA = K;
-    ck::index_t StrideB = 1024;
+    ck::index_t StrideB = K;
     ck::index_t StrideD0 = 0;
-    ck::index_t StrideD1 = 1024;
+    ck::index_t StrideD1 = N;
-    ck::index_t StrideH = 1024;
+    ck::index_t StrideH = N;
     float epsilon = 1e-5;

@@ -253,7 +253,7 @@ int main()
        e_device_buf.FromDevice(e_m_n.mData.data());
        h_device_buf.FromDevice(h_m_n.mData.data());

-       pass &= ck::utils::check_err(e_m_n, e_m_n_host);
+       pass &= ck::utils::check_err(e_m_n, e_m_n_host, "Error: Incorrect results e_m_n");
        pass &=
            ck::utils::check_err(h_m_n, h_m_n_host, "Error: Incorrect results h_m_n", 1e-2, 1e-2);
    }
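Besides labelling the check_err calls and deriving the strides from K and N instead of the hard-coded 1024, the example now builds the device op with GemmSpecialization::MNKPadding, so problem sizes no longer have to be exact multiples of the block tile. Reading the aligned parameter comment, the instance above appears to use MPerBlock/NPerBlock/KPerBlock = 256/128/32. The snippet below is only a stand-alone illustration of that padding arithmetic; pad_to_multiple is an invented helper, not a CK function.

    #include <cstdio>

    // Round a dimension up to the next multiple of its block tile, the usual
    // integer_divide_ceil-style padding implied by GemmSpecialization::MNKPadding.
    constexpr int pad_to_multiple(int len, int tile)
    {
        return ((len + tile - 1) / tile) * tile;
    }

    int main()
    {
        const int NPerBlock = 128; // N tile of the instance above (assumption)
        for(int n : {1024, 1000, 130})
        {
            const int padded = pad_to_multiple(n, NPerBlock);
            const int nblock = padded / NPerBlock;
            std::printf("N = %4d -> padded N = %4d (%d N-blocks, tail block holds %d valid columns)\n",
                        n, padded, nblock, n - (nblock - 1) * NPerBlock);
        }
        return 0;
    }

For N = 1000 this reports 8 N-blocks with only 104 valid columns in the last one, which is exactly the tail-block situation the rest of the commit handles.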
include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp

@@ -59,7 +59,8 @@ __global__ void
             e_grid_desc_mblock_mperblock_nblock_nperblock,
         const MeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock
             mean_var_count_grid_desc_mblock_mperblock_nblock,
-        const Block2ETileMap block_2_etile_map)
+        const Block2ETileMap block_2_etile_map,
+        index_t NRaw)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
     __shared__ char p_shared[GridwiseGemmWelford::GetSharedMemoryNumberOfByte()];

@@ -81,7 +82,8 @@ __global__ void
         ds_grid_desc_mblock_mperblock_nblock_nperblock,
         e_grid_desc_mblock_mperblock_nblock_nperblock,
         mean_var_count_grid_desc_mblock_mperblock_nblock,
-        block_2_etile_map);
+        block_2_etile_map,
+        NRaw);
 #else
     ignore = p_a_grid;
     ignore = p_b_grid;

@@ -99,6 +101,7 @@ __global__ void
     ignore = e_grid_desc_mblock_mperblock_nblock_nperblock;
     ignore = mean_var_count_grid_desc_mblock_mperblock_nblock;
     ignore = block_2_etile_map;
+    ignore = NRaw;
 #endif
 }

@@ -225,7 +228,6 @@ template <typename ALayout,
           index_t LayernormHDstVectorSize,
           index_t LayernormGammaSrcVectorSize,
           index_t LayernormBetaSrcVectorSize,
-          index_t LayernormMeanVarSrcDstVectorSize,
           LoopScheduler LoopSched = make_default_loop_scheduler()>
 struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
 {

@@ -329,7 +331,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
         }();

         return PadTensorDescriptor(
-            grid_desc_m_n, make_tuple(MPerBlock, NPerBlock), Sequence<true, true>{});
+            grid_desc_m_n, make_tuple(MPerBlock, NBlock), Sequence<true, false>{});
     }

     template <typename LayOut>

@@ -487,8 +489,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
         LayernormESrcVectorSize,
         LayernormHDstVectorSize,
         LayernormGammaSrcVectorSize,
-        LayernormBetaSrcVectorSize,
-        LayernormMeanVarSrcDstVectorSize>;
+        LayernormBetaSrcVectorSize>;

     // Argument
     struct Argument : public BaseArgument

@@ -732,7 +733,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
                 arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_,
                 arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
                 arg.mean_var_count_grid_desc_mblock_mperblock_nblock_,
-                arg.block_2_etile_map_);
+                arg.block_2_etile_map_,
+                arg.NRaw_);

             grid_size = math::integer_divide_ceil(M, LayernormBlockTileSize_M_N::At(0));
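These hunks are mostly plumbing: the unpadded N extent travels from the host-side Argument (arg.NRaw_) through the kernel entry point into the gridwise GEMM + Welford run, and LayernormMeanVarSrcDstVectorSize is dropped from the template parameters, matching its removal from GridwiseWelfordSecondHalfLayernorm2d below. A rough host-side sketch of the pattern, with invented names (Args, valid_cols_in_block), showing how a raw extent plus the padded block count identifies the tail block and its valid width:

    // Invented illustration (not CK code): carrying the raw extent alongside the
    // padded geometry lets a block work out how many of its columns are real.
    struct Args
    {
        int n_raw;    // extent requested by the user
        int n_padded; // extent the grid was sized for (a multiple of n_per_block)
    };

    int valid_cols_in_block(const Args& args, int block_id_n, int n_per_block)
    {
        const int nblock = args.n_padded / n_per_block;
        // Only the last N block can be partially filled.
        return (block_id_n == nblock - 1) ? args.n_raw - n_per_block * (nblock - 1)
                                          : n_per_block;
    }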
include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp

@@ -240,7 +240,6 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
             Number<NumDTensor>{});
     }

-    // TODO - MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
     template <typename GridDescriptor_M_N>
     __host__ __device__ static constexpr auto
     MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock(const GridDescriptor_M_N& grid_desc_m_n)

@@ -381,7 +380,8 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
             e_grid_desc_mblock_mperblock_nblock_nperblock,
         const MeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock&
             mean_var_count_grid_desc_mblock_mperblock_nblock,
-        const Block2ETileMap& block_2_etile_map)
+        const Block2ETileMap& block_2_etile_map,
+        index_t NRaw)
     {
         const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());

@@ -879,9 +879,38 @@ struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle
         Array<mean_var_vgpr_type, num_shuffleM> var_thread_bufs;
         Array<welford_count_vgpr_type, num_shuffleM> welford_count_thread_bufs;

+        int max_count = PostShuffleThreadSliceSize_N * num_shuffleN;
+        const auto nblock = mean_var_count_grid_desc_mblock_mperblock_nblock.GetLength(I2);
+
+        // tail block
+        if(block_work_idx[I1] % nblock == nblock - 1)
+        {
+            constexpr index_t NPerShuffleBlock = CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl;
+            int NPerBlockTail  = NRaw - NPerBlock * (nblock - 1);
+            int thread_max_len =
+                PostShuffleThreadSliceSize_N * (post_shuffle_thread_cluster_idx[I1] + 1);
+            int shuffle_step = 0;
+            while(thread_max_len <= NPerBlockTail && shuffle_step < num_shuffleN)
+            {
+                ++shuffle_step;
+                thread_max_len += NPerShuffleBlock;
+            }
+
+            int delta = 0;
+            if(thread_max_len - NPerBlockTail > PostShuffleThreadSliceSize_N)
+                delta = 0;
+            else if(NPerBlockTail > thread_max_len)
+                delta = PostShuffleThreadSliceSize_N;
+            else
+                delta = PostShuffleThreadSliceSize_N - thread_max_len + NPerBlockTail;
+
+            max_count = shuffle_step * PostShuffleThreadSliceSize_N + delta;
+        }
+
         static_for<0, num_shuffleM, 1>{}([&](auto i) {
-            // TODO - padding
-            threadwise_welfords(i).max_count_ = PostShuffleThreadSliceSize_N * num_shuffleN;
+            threadwise_welfords(i).max_count_ = max_count;
             mean_thread_bufs(i) = make_static_buffer<AddressSpaceEnum::Vgpr, AccDataType>(
                 thread_welford_dst_desc_m.GetElementSpaceSize());
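The new block is the heart of the commit: for a thread in the tail N block it walks its shuffle iterations, counts how many full slices of PostShuffleThreadSliceSize_N still fall inside the NRaw boundary, and adds the partial remainder for the one slice that straddles it. Below is a host-side restatement of the same arithmetic, useful for checking the formula with concrete numbers; all geometry values are made up, tail_max_count and its parameter names are illustrative, and n_per_shuffle stands in for CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl.

    #include <cstdio>

    // Host-side restatement of the tail-block max_count computation added above.
    // thread_slice_n = PostShuffleThreadSliceSize_N
    // n_shuffle      = num_shuffleN (shuffle iterations along N)
    // n_per_shuffle  = columns covered by one shuffle iteration
    // cluster_idx_n  = post_shuffle_thread_cluster_idx[I1] (thread position along N)
    // n_tail         = NPerBlockTail (valid columns left in the last N block)
    int tail_max_count(int thread_slice_n, int n_shuffle, int n_per_shuffle,
                       int cluster_idx_n, int n_tail)
    {
        int thread_max_len = thread_slice_n * (cluster_idx_n + 1);
        int shuffle_step   = 0;
        while(thread_max_len <= n_tail && shuffle_step < n_shuffle)
        {
            ++shuffle_step;
            thread_max_len += n_per_shuffle;
        }

        int delta = 0;
        if(thread_max_len - n_tail > thread_slice_n)
            delta = 0;                                        // slice lies entirely past the valid region
        else if(n_tail > thread_max_len)
            delta = thread_slice_n;                           // slice lies entirely inside the valid region
        else
            delta = thread_slice_n - thread_max_len + n_tail; // slice straddles the boundary

        return shuffle_step * thread_slice_n + delta;
    }

    int main()
    {
        // Made-up geometry: 8 columns per thread slice, 4 shuffle iterations,
        // 32 columns per shuffle iteration, and only 100 of the last block's
        // 128 columns valid.
        for(int cluster_idx_n = 0; cluster_idx_n < 4; ++cluster_idx_n)
            std::printf("thread %d along N -> max_count = %d\n",
                        cluster_idx_n, tail_max_count(8, 4, 32, cluster_idx_n, 100));
        return 0;
    }

With these numbers the thread covering columns 0-7 of each shuffle iteration gets max_count = 28 (three full 8-wide slices plus a 4-wide remainder at the boundary), while the other three threads get 24, matching a column-by-column count of their valid elements.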
include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_welford_second_half_layernorm2d.hpp

@@ -39,8 +39,7 @@ template <typename EDataType,
           index_t ESrcVectorSize,
           index_t HDstVectorSize,
           index_t GammaSrcVectorSize,
-          index_t BetaSrcVectorSize,
-          index_t MeanVarSrcDstVectorSize>
+          index_t BetaSrcVectorSize>
 struct GridwiseWelfordSecondHalfLayernorm2d
 {
     // TODO - Support ESrcHDstVectorDim == 0