Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
dd0255ba
"docs/git@developer.sourcefind.cn:SIYIXNI/vllm.git" did not exist on "7e1d5e5308fa3549dfed1821188d588260a03c8a"
Commit
dd0255ba
authored
Sep 06, 2022
by
rocking
Browse files
Add gridwise gemm + welford
parent
e9f656fa
Changes
2
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
1343 additions
and
17 deletions
+1343
-17
include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp
.../device/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp
+297
-17
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_welford_xdl_cshuffle.hpp
...pu/grid/gridwise_gemm_multiple_d_welford_xdl_cshuffle.hpp
+1046
-0
No files found.
include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp
View file @
dd0255ba
...
...
@@ -12,10 +12,96 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_welford_xdl_cshuffle.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "device_base.hpp"
namespace ck
{

// GPU entry point for the fused "GEMM + multiple-D elementwise + Welford" kernel.
// This wrapper only allocates the LDS scratch buffer required by GridwiseGemm and
// forwards every argument to GridwiseGemm::Run; all real work happens there.
//
// Template parameters:
//   GridwiseGemm      - grid-wise implementation providing template Run<>() and
//                       GetSharedMemoryNumberOfByte()
//   ABDataType        - element type shared by the A and B input matrices
//                       (TODO in the device op notes: distinguish A/B data types)
//   DsPointer         - pointer bundle for the auxiliary D tensors
//   EDataType         - element type of the GEMM/elementwise output E
//   FDataType/GDataType - element types of the two extra outputs written by the
//                       Welford stage (presumably running mean and variance*count
//                       partials consumed by a later layernorm reduction --
//                       confirm against the gridwise implementation)
//   *ElementwiseOperation - functors applied to A, B, and the C/D/E fusion
//   *GridDesc*        - tensor descriptors pre-transformed for block-wise access
//   Block2ETileMap    - maps a block id to its tile of E
//   HasMainKBlockLoop - compile-time selection of the K-loop variant
template <typename GridwiseGemm,
          typename ABDataType,
          typename DsPointer,
          typename EDataType,
          typename FDataType,
          typename GDataType,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CDEElementwiseOperation,
          typename AGridDesc_AK0_M_AK1,
          typename BGridDesc_BK0_N_BK1,
          typename DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
          typename EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
          typename FGridDescriptor_MBlock_MPerBlock_NBlock,
          typename GGridDescriptor_MBlock_MPerBlock_NBlock,
          typename Block2ETileMap,
          bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
        kernel_gemm_multiple_d_welford_xdl_cshuffle(
            const ABDataType* __restrict__ p_a_grid,
            const ABDataType* __restrict__ p_b_grid,
            DsPointer p_ds_grid,
            EDataType* __restrict__ p_e_grid,
            FDataType* __restrict__ p_f_grid,
            GDataType* __restrict__ p_g_grid,
            const AElementwiseOperation a_element_op,
            const BElementwiseOperation b_element_op,
            const CDEElementwiseOperation cde_element_op,
            const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
            const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
            const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
                ds_grid_desc_mblock_mperblock_nblock_nperblock,
            const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
                e_grid_desc_mblock_mperblock_nblock_nperblock,
            const FGridDescriptor_MBlock_MPerBlock_NBlock f_grid_desc_mblock_mperblock_nblock,
            const GGridDescriptor_MBlock_MPerBlock_NBlock g_grid_desc_mblock_mperblock_nblock,
            const Block2ETileMap block_2_etile_map)
{
// The kernel body is compiled only for host passes and for the XDL-capable
// targets gfx908/gfx90a; other device targets get the empty stub below.
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
    // Raw LDS scratch, sized by the gridwise implementation and reinterpreted
    // inside Run().
    __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

    GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
                                                  p_b_grid,
                                                  p_ds_grid,
                                                  p_e_grid,
                                                  p_f_grid,
                                                  p_g_grid,
                                                  p_shared,
                                                  a_element_op,
                                                  b_element_op,
                                                  cde_element_op,
                                                  a_grid_desc_ak0_m_ak1,
                                                  b_grid_desc_bk0_n_bk1,
                                                  ds_grid_desc_mblock_mperblock_nblock_nperblock,
                                                  e_grid_desc_mblock_mperblock_nblock_nperblock,
                                                  f_grid_desc_mblock_mperblock_nblock,
                                                  g_grid_desc_mblock_mperblock_nblock,
                                                  block_2_etile_map);
#else
    // Unsupported device target: reference every parameter via `ignore` so the
    // compiler does not emit unused-argument warnings for the empty body.
    ignore = p_a_grid;
    ignore = p_b_grid;
    ignore = p_ds_grid;
    ignore = p_e_grid;
    ignore = p_f_grid;
    ignore = p_g_grid;
    ignore = a_element_op;
    ignore = b_element_op;
    ignore = cde_element_op;
    ignore = a_grid_desc_ak0_m_ak1;
    ignore = b_grid_desc_bk0_n_bk1;
    ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock;
    ignore = e_grid_desc_mblock_mperblock_nblock_nperblock;
    ignore = f_grid_desc_mblock_mperblock_nblock;
    ignore = g_grid_desc_mblock_mperblock_nblock;
    ignore = block_2_etile_map;
#endif
}

} // namespace ck
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
...
...
@@ -43,7 +129,7 @@ template <typename ALayout,
typename
CShuffleDataType
,
typename
DsDataType
,
typename
EDataType
,
typename
F
DataType
,
typename
H
DataType
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CDEElementwiseOperation
,
...
...
@@ -81,7 +167,9 @@ template <typename ALayout,
LoopScheduler
LoopSched
=
make_default_loop_scheduler
()>
struct
DeviceGemmMultipleDLayernorm_Xdl_CShuffle
:
public
BaseOperator
{
using
DeviceOp
=
DeviceGemmMultipleDLayernorm_Xdl_CShuffle
;
using
DeviceOp
=
DeviceGemmMultipleDLayernorm_Xdl_CShuffle
;
using
FDataType
=
CShuffleDataType
;
using
GDataType
=
CShuffleDataType
;
static
constexpr
index_t
NumDTensor
=
DsDataType
::
Size
();
...
...
@@ -162,8 +250,64 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
using
BGridDesc_N_K
=
decltype
(
MakeBGridDescriptor_N_K
(
1
,
1
,
1
));
using
DsGridDesc_M_N
=
remove_cvref_t
<
decltype
(
MakeDsGridDescriptor_M_N
({},
{},
{}))
>
;
using
EGridDesc_M_N
=
decltype
(
MakeGridDescriptor_M_N
<
ELayout
>
(
1
,
1
,
1
));
using
FGridDesc_M_N
=
decltype
(
MakeGridDescriptor_M_N
<
ELayout
>
(
1
,
1
,
1
));
using
GGridDesc_M_N
=
decltype
(
MakeGridDescriptor_M_N
<
ELayout
>
(
1
,
1
,
1
));
using
HGridDesc_M_N
=
decltype
(
MakeGridDescriptor_M_N
<
HLayout
>
(
1
,
1
,
1
));
// GridwiseGemm
using
GridwiseGemm
=
GridwiseGemmMultipleDWelford_xdl_cshuffle
<
ADataType
,
// TODO: distinguish A/B datatype
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
FDataType
,
GDataType
,
AElementwiseOperation
,
BElementwiseOperation
,
CDEElementwiseOperation
,
InMemoryDataOperationEnum
::
Set
,
AGridDesc_M_K
,
BGridDesc_N_K
,
DsGridDesc_M_N
,
EGridDesc_M_N
,
FGridDesc_M_N
,
GGridDesc_M_N
,
NumGemmKPrefetchStage
,
BlockSize
,
MPerBlock
,
NPerBlock
,
KPerBlock
,
AK1
,
BK1
,
MPerXDL
,
NPerXDL
,
MXdlPerWave
,
NXdlPerWave
,
ABlockTransferThreadClusterLengths_AK0_M_AK1
,
ABlockTransferThreadClusterArrangeOrder
,
ABlockTransferSrcAccessOrder
,
ABlockTransferSrcVectorDim
,
ABlockTransferSrcScalarPerVector
,
ABlockTransferDstScalarPerVector_AK1
,
false
,
ABlockLdsExtraM
,
BBlockTransferThreadClusterLengths_BK0_N_BK1
,
BBlockTransferThreadClusterArrangeOrder
,
BBlockTransferSrcAccessOrder
,
BBlockTransferSrcVectorDim
,
BBlockTransferSrcScalarPerVector
,
BBlockTransferDstScalarPerVector_BK1
,
false
,
BBlockLdsExtraN
,
CShuffleMXdlPerWavePerShuffle
,
CShuffleNXdlPerWavePerShuffle
,
ReduceThreadTransferClusterLengths_MPerBlock_NPerBlock
,
ReduceThreadTransferScalarPerVector_NPerBlock
,
LoopSched
>
;
using
Block2ETileMap
=
typename
GridwiseGemm
::
DefaultBlock2ETileMap
;
// Argument
struct
Argument
:
public
BaseArgument
{
...
...
@@ -171,7 +315,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
const
void
*
p_b_grid
,
std
::
array
<
const
void
*
,
NumDTensor
>
p_ds_grid
,
void
*
p_e_grid
,
void
*
p_
f
_grid
,
void
*
p_
h
_grid
,
index_t
MRaw
,
index_t
NRaw
,
index_t
KRaw
,
...
...
@@ -188,18 +332,70 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
p_b_grid_
{
static_cast
<
const
BDataType
*>
(
p_b_grid
)},
p_ds_grid_
{},
p_e_grid_
{
static_cast
<
EDataType
*>
(
p_e_grid
)},
p_f_grid_
{
static_cast
<
FDataType
*>
(
p_f_grid
)},
p_f_grid_
{
nullptr
},
p_g_grid_
{
nullptr
},
p_h_grid_
{
static_cast
<
HDataType
*>
(
p_h_grid
)},
a_grid_desc_m_k_
{
DeviceOp
::
MakeAGridDescriptor_M_K
(
MRaw
,
KRaw
,
StrideA
)},
b_grid_desc_n_k_
{
DeviceOp
::
MakeBGridDescriptor_N_K
(
KRaw
,
NRaw
,
StrideB
)},
ds_grid_desc_m_n_
{},
e_grid_desc_m_n_
{
DeviceOp
::
MakeGridDescriptor_M_N
<
ELayout
>
(
MRaw
,
NRaw
,
StrideE
)},
f_grid_desc_m_n_
{
DeviceOp
::
MakeGridDescriptor_M_N
<
ELayout
>
(
MRaw
,
math
::
integer_divide_ceil
(
NRaw
,
NPerBlock
),
math
::
integer_divide_ceil
(
NRaw
,
NPerBlock
))},
g_grid_desc_m_n_
{
DeviceOp
::
MakeGridDescriptor_M_N
<
ELayout
>
(
MRaw
,
math
::
integer_divide_ceil
(
NRaw
,
NPerBlock
),
math
::
integer_divide_ceil
(
NRaw
,
NPerBlock
))},
h_grid_desc_m_n_
{
DeviceOp
::
MakeGridDescriptor_M_N
<
HLayout
>
(
MRaw
,
NRaw
,
StrideH
)},
block_2_etile_map_
{
GridwiseGemm
::
MakeDefaultBlock2ETileMap
(
e_grid_desc_m_n_
)},
a_element_op_
{
a_element_op
},
b_element_op_
{
b_element_op
},
cde_element_op_
{
cde_element_op
},
h_element_op_
{
h_element_op
}
{
// TODO
int
welford_size
=
MRaw
*
math
::
integer_divide_ceil
(
NRaw
,
NPerBlock
);
hip_check_error
(
hipMalloc
(
&
p_f_grid_
,
sizeof
(
FDataType
)
*
welford_size
));
hip_check_error
(
hipMalloc
(
&
p_g_grid_
,
sizeof
(
GDataType
)
*
welford_size
));
// populate pointer, desc for Ds
static_for
<
0
,
NumDTensor
,
1
>
{}([
&
](
auto
i
)
{
using
DLayout
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
DsLayout
>>
;
using
DDataType
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
DsDataType
>>
;
// D pointer
p_ds_grid_
(
i
)
=
static_cast
<
const
DDataType
*>
(
p_ds_grid
[
i
]);
// D desc
ds_grid_desc_m_n_
(
i
)
=
DeviceOp
::
MakeGridDescriptor_M_N
<
DLayout
>
(
MRaw
,
NRaw
,
StrideDs
[
i
]);
});
// populate desc for Ds/E/F/G
if
(
GridwiseGemm
::
CheckValidity
(
a_grid_desc_m_k_
,
b_grid_desc_n_k_
,
ds_grid_desc_m_n_
,
e_grid_desc_m_n_
,
f_grid_desc_m_n_
,
g_grid_desc_m_n_
,
block_2_etile_map_
))
{
ds_grid_desc_mblock_mperblock_nblock_nperblock_
=
GridwiseGemm
::
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
ds_grid_desc_m_n_
);
e_grid_desc_mblock_mperblock_nblock_nperblock_
=
GridwiseGemm
::
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
e_grid_desc_m_n_
);
f_grid_desc_mblock_mperblock_nblock_
=
GridwiseGemm
::
MakeFGGridDescriptor_MBlock_MPerBlock_NBlock
(
f_grid_desc_m_n_
);
g_grid_desc_mblock_mperblock_nblock_
=
GridwiseGemm
::
MakeFGGridDescriptor_MBlock_MPerBlock_NBlock
(
g_grid_desc_m_n_
);
}
// TODO - H
}
void
Print
()
const
...
...
@@ -216,20 +412,35 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
// pointers
const
ADataType
*
p_a_grid_
;
const
BDataType
*
p_b_grid_
;
// FIXME - typename GridwiseGemm::DsGridPointer
std
::
array
<
const
void
*
,
NumDTensor
>
p_ds_grid_
;
typename
GridwiseGemm
::
DsGridPointer
p_ds_grid_
;
EDataType
*
p_e_grid_
;
FDataType
*
p_f_grid_
;
FDataType
*
p_f_grid_
;
// mean
GDataType
*
p_g_grid_
;
// variance * count
HDataType
*
p_h_grid_
;
// tensor descriptors for problem definiton
AGridDesc_M_K
a_grid_desc_m_k_
;
BGridDesc_N_K
b_grid_desc_n_k_
;
DsGridDesc_M_N
ds_grid_desc_m_n_
;
EGridDesc_M_N
e_grid_desc_m_n_
;
FGridDesc_M_N
f_grid_desc_m_n_
;
GGridDesc_M_N
g_grid_desc_m_n_
;
HGridDesc_M_N
h_grid_desc_m_n_
;
// TODO - tensor descriptors for block/thread-wise copy
// TODO - block-to-e-tile map
// tensor descriptors for block/thread-wise copy
typename
GridwiseGemm
::
DefaultAGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1_
;
typename
GridwiseGemm
::
DefaultBGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1_
;
typename
GridwiseGemm
::
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
ds_grid_desc_mblock_mperblock_nblock_nperblock_
;
typename
GridwiseGemm
::
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock_
;
typename
GridwiseGemm
::
FGridDescriptor_MBlock_MPerBlock_NBlock
f_grid_desc_mblock_mperblock_nblock_
;
typename
GridwiseGemm
::
GGridDescriptor_MBlock_MPerBlock_NBlock
g_grid_desc_mblock_mperblock_nblock_
;
// block-to-e-tile map
Block2ETileMap
block_2_etile_map_
;
// element-wise op
AElementwiseOperation
a_element_op_
;
...
...
@@ -243,10 +454,79 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
{
using
Argument
=
DeviceOp
::
Argument
;
float
Run
(
const
Argument
&
,
const
StreamConfig
&
)
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{}
)
{
// TODO
return
0
;
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_m_k_
,
arg
.
b_grid_desc_n_k_
,
arg
.
ds_grid_desc_m_n_
,
arg
.
e_grid_desc_m_n_
,
arg
.
f_grid_desc_m_n_
,
arg
.
g_grid_desc_m_n_
,
arg
.
block_2_etile_map_
))
{
throw
std
::
runtime_error
(
"wrong! GridwiseGemm has invalid setting"
);
}
const
index_t
grid_size
=
arg
.
block_2_etile_map_
.
CalculateGridSize
(
arg
.
e_grid_desc_m_n_
);
const
auto
K
=
arg
.
a_grid_desc_ak0_m_ak1_
.
GetLength
(
I0
)
*
arg
.
a_grid_desc_ak0_m_ak1_
.
GetLength
(
I2
);
auto
launch_kernel
=
[
&
](
auto
has_main_k_block_loop
)
{
constexpr
bool
has_main_loop
=
has_main_k_block_loop
.
value
;
const
auto
kernel
=
kernel_gemm_multiple_d_welford_xdl_cshuffle
<
GridwiseGemm
,
ADataType
,
// TODO: distiguish A/B datatype
typename
GridwiseGemm
::
DsGridPointer
,
EDataType
,
FDataType
,
GDataType
,
AElementwiseOperation
,
BElementwiseOperation
,
CDEElementwiseOperation
,
typename
GridwiseGemm
::
DefaultAGridDesc_AK0_M_AK1
,
typename
GridwiseGemm
::
DefaultBGridDesc_BK0_N_BK1
,
typename
GridwiseGemm
::
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
GridwiseGemm
::
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
GridwiseGemm
::
FGridDescriptor_MBlock_MPerBlock_NBlock
,
typename
GridwiseGemm
::
GGridDescriptor_MBlock_MPerBlock_NBlock
,
typename
GridwiseGemm
::
DefaultBlock2ETileMap
,
has_main_loop
>
;
return
launch_and_time_kernel
(
stream_config
,
kernel
,
dim3
(
grid_size
),
dim3
(
BlockSize
),
0
,
arg
.
p_a_grid_
,
arg
.
p_b_grid_
,
arg
.
p_ds_grid_
,
arg
.
p_e_grid_
,
arg
.
p_f_grid_
,
arg
.
p_g_grid_
,
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
cde_element_op_
,
arg
.
a_grid_desc_ak0_m_ak1_
,
arg
.
b_grid_desc_bk0_n_bk1_
,
arg
.
ds_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
e_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
f_grid_desc_mblock_mperblock_nblock_
,
arg
.
g_grid_desc_mblock_mperblock_nblock_
,
arg
.
block_2_etile_map_
);
};
if
(
GridwiseGemm
::
CalculateHasMainKBlockLoop
(
K
))
{
return
launch_kernel
(
integral_constant
<
bool
,
true
>
{});
}
else
{
return
launch_kernel
(
integral_constant
<
bool
,
false
>
{});
}
}
// polymorphic
...
...
@@ -264,7 +544,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
return
false
;
}
return
fals
e
;
return
tru
e
;
}
// polymorphic
...
...
@@ -277,7 +557,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
const
void
*
p_b
,
std
::
array
<
const
void
*
,
NumDTensor
>
p_ds
,
void
*
p_e
,
void
*
p_
f
,
void
*
p_
h
,
index_t
MRaw
,
index_t
NRaw
,
index_t
KRaw
,
...
...
@@ -295,7 +575,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
p_b
,
p_ds
,
p_e
,
p_
f
,
p_
h
,
MRaw
,
NRaw
,
KRaw
,
...
...
@@ -317,7 +597,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
const
void
*
p_b
,
std
::
array
<
const
void
*
,
NumDTensor
>
p_ds
,
void
*
p_e
,
void
*
p_
f
,
void
*
p_
h
,
index_t
MRaw
,
index_t
NRaw
,
index_t
KRaw
,
...
...
@@ -335,7 +615,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
p_b
,
p_ds
,
p_e
,
p_
f
,
p_
h
,
MRaw
,
NRaw
,
KRaw
,
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_welford_xdl_cshuffle.hpp
0 → 100644
View file @
dd0255ba
This diff is collapsed.
Click to expand it.
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment