Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
c5fd087e
Commit
c5fd087e
authored
Mar 06, 2023
by
aska-0096
Browse files
Attn, skip b lds
parent
6e28a8ac
Changes
4
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
609 additions
and
237 deletions
+609
-237
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp
...evice_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp
+61
-34
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp
...grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp
+491
-202
include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp
include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp
+1
-1
include/ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp
...tion/operator_transform/transform_contraction_to_gemm.hpp
+56
-0
No files found.
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp
View file @
c5fd087e
...
...
@@ -180,27 +180,57 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
static
auto
MakeB0GridDescriptor
(
const
std
::
vector
<
index_t
>&
b0_gs_ls_ks_lengths_vec
,
const
std
::
vector
<
index_t
>&
b0_gs_ls_ks_strides_vec
)
{
return
Transform
::
MakeB0GridDescriptor_BK0_N_BK1
(
Transform
::
MakeB0GridDescriptor_N_K
(
b0_gs_ls_ks_lengths_vec
,
b0_gs_ls_ks_strides_vec
),
Number
<
K1
>
{});
if
constexpr
(
B0EnableLds
)
{
return
Transform
::
MakeB0GridDescriptor_BK0_N_BK1
(
Transform
::
MakeB0GridDescriptor_N_K
(
b0_gs_ls_ks_lengths_vec
,
b0_gs_ls_ks_strides_vec
),
Number
<
K1
>
{});
}
else
{
return
Transform
::
MakeB0GridDescriptor_BKWmma_LBlockRepeat_LWaves_BKRow_LPerWmma_BK1
(
Transform
::
MakeB0GridDescriptor_N_K
(
b0_gs_ls_ks_lengths_vec
,
b0_gs_ls_ks_strides_vec
),
Number
<
WmmaK
>
{},
Number
<
LRepeat
>
{},
Number
<
LWaves
>
{},
Number
<
LPerWmma
>
{},
Number
<
K1
>
{});
}
}
static
auto
MakeB1GridDescriptor
_BL0_N_BL1
(
const
std
::
vector
<
index_t
>&
b1_gs_ns_ls_lengths_vec
,
const
std
::
vector
<
index_t
>&
b1_gs_ns_ls_strides_vec
)
static
auto
MakeB1GridDescriptor
(
const
std
::
vector
<
index_t
>&
b1_gs_ns_ls_lengths_vec
,
const
std
::
vector
<
index_t
>&
b1_gs_ns_ls_strides_vec
)
{
return
Transform
::
MakeB1GridDescriptor_BK0_N_BK1
(
Transform
::
MakeB1GridDescriptor_N_K
(
b1_gs_ns_ls_lengths_vec
,
b1_gs_ns_ls_strides_vec
),
Number
<
L1
>
{});
if
constexpr
(
B1EnableLds
)
{
return
Transform
::
MakeB1GridDescriptor_BK0_N_BK1
(
Transform
::
MakeB1GridDescriptor_N_K
(
b1_gs_ns_ls_lengths_vec
,
b1_gs_ns_ls_strides_vec
),
Number
<
L1
>
{});
}
else
{
return
Transform
::
MakeB1GridDescriptor_BLWmma_NBlockRepeat_NWaves_BLRow_NPerWmma_BL1
(
Transform
::
MakeB1GridDescriptor_N_K
(
b1_gs_ns_ls_lengths_vec
,
b1_gs_ns_ls_strides_vec
),
Number
<
WmmaK
>
{},
Number
<
NRepeat
>
{},
Number
<
NWaves
>
{},
Number
<
NPerWmma
>
{},
Number
<
L1
>
{});
}
}
using
AGridDesc
=
decltype
(
MakeAGridDescriptor
({},
{}));
using
B0GridDesc
_BK0_L_BK1
=
decltype
(
MakeB0GridDescriptor
({},
{}));
using
B1GridDesc
_BL0_N_BL1
=
decltype
(
MakeB1GridDescriptor
_BL0_N_BL1
({},
{}));
using
CGridDesc_M_N
=
decltype
(
Transform
::
MakeCGridDescriptor_M_N
({},
{}));
using
AGridDesc_G_M_K
=
decltype
(
Transform
::
MakeAGridDescriptor_G_M_K
({},
{}));
using
B0GridDesc_G_L_K
=
decltype
(
Transform
::
MakeB0GridDescriptor_G_N_K
({},
{}));
using
B1GridDesc_G_N_L
=
decltype
(
Transform
::
MakeB1GridDescriptor_G_N_K
({},
{}));
using
CGridDesc_G_M_N
=
decltype
(
Transform
::
MakeCGridDescriptor_G_M_N
({},
{}));
using
AGridDesc
=
decltype
(
MakeAGridDescriptor
({},
{}));
using
B0GridDesc
=
decltype
(
MakeB0GridDescriptor
({},
{}));
using
B1GridDesc
=
decltype
(
MakeB1GridDescriptor
({},
{}));
using
CGridDesc_M_N
=
decltype
(
Transform
::
MakeCGridDescriptor_M_N
({},
{}));
using
AGridDesc_G_M_K
=
decltype
(
Transform
::
MakeAGridDescriptor_G_M_K
({},
{}));
using
B0GridDesc_G_L_K
=
decltype
(
Transform
::
MakeB0GridDescriptor_G_N_K
({},
{}));
using
B1GridDesc_G_N_L
=
decltype
(
Transform
::
MakeB1GridDescriptor_G_N_K
({},
{}));
using
CGridDesc_G_M_N
=
decltype
(
Transform
::
MakeCGridDescriptor_G_M_N
({},
{}));
constexpr
static
auto
make_MaskOutPredicate
()
{
...
...
@@ -274,8 +304,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
InMemoryDataOperationEnum
::
Set
,
// InMemory Data Descriptor
AGridDesc
,
B0GridDesc
_BK0_L_BK1
,
B1GridDesc
_BL0_N_BL1
,
B0GridDesc
,
B1GridDesc
,
CGridDesc_M_N
,
// Tiling Family
MPerBlock
,
...
...
@@ -364,10 +394,10 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
p_b1_grid_
{
p_b1_grid
},
p_c_grid_
{
p_c_grid
},
a_grid_desc
{
DeviceOp
::
MakeAGridDescriptor
(
a_gs_ms_ks_lengths
,
a_gs_ms_ks_strides
)},
b0_grid_desc
_bk0_l_bk1_
{
b0_grid_desc
{
DeviceOp
::
MakeB0GridDescriptor
(
b0_gs_ls_ks_lengths
,
b0_gs_ls_ks_strides
)},
b1_grid_desc
_bl0_n_bl1_
{
DeviceOp
::
MakeB1GridDescriptor_BL0_N_BL1
(
b1_gs_ns_ls_lengths
,
b1_gs_ns_ls_strides
)},
b1_grid_desc
{
DeviceOp
::
MakeB1GridDescriptor
(
b1_gs_ns_ls_lengths
,
b1_gs_ns_ls_strides
)},
c_grid_desc_m_n_
{
Transform
::
MakeCGridDescriptor_M_N
(
c_gs_ms_ns_lengths
,
c_gs_ms_ns_strides
)},
a_grid_desc_g_m_k_
{
...
...
@@ -410,11 +440,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
ignore
=
acc1_biases_gs_ms_ns_lengths
;
ignore
=
acc1_biases_gs_ms_ns_strides
;
if
(
GridwiseOp
::
CheckValidity
(
a_grid_desc
,
b0_grid_desc_bk0_l_bk1_
,
b1_grid_desc_bl0_n_bl1_
,
c_grid_desc_m_n_
,
block_2_ctile_map_
))
if
(
GridwiseOp
::
CheckValidity
(
a_grid_desc
,
b0_grid_desc
,
b1_grid_desc
,
c_grid_desc_m_n_
,
block_2_ctile_map_
))
{
c_grid_desc_mblock_mperblock_nblock_nperblock_
=
GridwiseOp
::
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
...
...
@@ -430,8 +457,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
// Tensor Descriptors
AGridDesc
a_grid_desc
;
B0GridDesc
_BK0_L_BK1
b0_grid_desc
_bk0_l_bk1_
;
B1GridDesc
_BL0_N_BL1
b1_grid_desc
_bl0_n_bl1_
;
B0GridDesc
b0_grid_desc
;
B1GridDesc
b1_grid_desc
;
CGridDesc_M_N
c_grid_desc_m_n_
;
AGridDesc_G_M_K
a_grid_desc_g_m_k_
;
...
...
@@ -498,8 +525,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
B1DataType
,
CDataType
,
DeviceOp
::
AGridDesc
,
DeviceOp
::
B0GridDesc
_BK0_L_BK1
,
DeviceOp
::
B1GridDesc
_BL0_N_BL1
,
DeviceOp
::
B0GridDesc
,
DeviceOp
::
B1GridDesc
,
typename
GridwiseOp
::
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
AElementwiseOperation
,
B0ElementwiseOperation
,
...
...
@@ -521,8 +548,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
arg
.
p_b1_grid_
,
arg
.
p_c_grid_
,
arg
.
a_grid_desc
,
arg
.
b0_grid_desc
_bk0_l_bk1_
,
arg
.
b1_grid_desc
_bl0_n_bl1_
,
arg
.
b0_grid_desc
,
arg
.
b1_grid_desc
,
arg
.
c_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
a_element_op_
,
arg
.
b0_element_op_
,
...
...
@@ -582,8 +609,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
}
if
(
!
GridwiseOp
::
CheckValidity
(
arg
.
a_grid_desc
,
arg
.
b0_grid_desc
_bk0_l_bk1_
,
arg
.
b1_grid_desc
_bl0_n_bl1_
,
arg
.
b0_grid_desc
,
arg
.
b1_grid_desc
,
arg
.
c_grid_desc_m_n_
,
arg
.
block_2_ctile_map_
))
{
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp
View file @
c5fd087e
This diff is collapsed.
Click to expand it.
include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp
View file @
c5fd087e
...
...
@@ -719,7 +719,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
// Thread-wise copy
// KPerBlock/WmmaK -> NRepeat -> NWaves -> WmmaK/K1 -> NPerWmma -> K1
constexpr
auto
KWmmaPerBlock
=
KPerBlock
/
WmmaK
;
auto
b_block_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
A
DataType
>
(
auto
b_block_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
B
DataType
>
(
b_block_desc
.
GetElementSpaceSize
());
// Limitation: NumDim of Src and Dst descriptor should be identical
...
...
include/ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp
View file @
c5fd087e
...
...
@@ -247,6 +247,34 @@ struct TransformBatchedContractionContractionToBatchedGemmGemm
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
}
// Re-views a 2-D B0 grid descriptor as a 6-D descriptor by splitting each of
// its two dimensions into three factors.
//
// b_grid_desc_l_k : source descriptor; dimension 0 is read as L and
//                   dimension 1 as K (per the parameter name).
// WmmaK, LRepeat, LWaves, LPerWmma, BK1 : tag types that are only
//                   value-initialised here (e.g. WmmaK{}); the unnamed
//                   function arguments exist purely for type deduction.
//
// Resulting dimension order, as the function name spells it:
//   0: BKWmma       = K / WmmaK
//   1: LBlockRepeat = (L / NPerBlock) * LRepeat
//   2: LWaves
//   3: BKRow        = WmmaK / BK1
//   4: LPerWmma
//   5: BK1
//
// NOTE(review): NPerBlock comes from the enclosing struct, which is not
// visible here; both divisions are presumably expected to be exact - confirm
// against the callers' validity checks.
template <typename BGridDesc_L_K,
          typename WmmaK,
          typename LRepeat,
          typename LWaves,
          typename LPerWmma,
          typename BK1>
__host__ __device__ static constexpr auto
MakeB0GridDescriptor_BKWmma_LBlockRepeat_LWaves_BKRow_LPerWmma_BK1(
    const BGridDesc_L_K& b_grid_desc_l_k,
    const WmmaK&,
    const LRepeat&,
    const LWaves&,
    const LPerWmma&,
    const BK1&)
{
    // Extents of the two source dimensions.
    const auto l_block_count = b_grid_desc_l_k.GetLength(I0) / NPerBlock;
    const auto k_length      = b_grid_desc_l_k.GetLength(I1);

    // Factors of the K split.
    const auto k_wmma_count           = k_length / WmmaK{};
    constexpr auto k_rows_per_wmma    = WmmaK{} / BK1{};

    // K -> (BKWmma, BKRow, BK1)
    const auto split_k =
        make_unmerge_transform(make_tuple(k_wmma_count, k_rows_per_wmma, BK1{}));

    // L -> (L0 * LRepeat, LWaves, LPerWmma)
    const auto split_l =
        make_unmerge_transform(make_tuple(l_block_count * LRepeat{}, LWaves{}, LPerWmma{}));

    // Source dim 1 (K) feeds output dims {0, 3, 5};
    // source dim 0 (L) feeds output dims {1, 2, 4}.
    return transform_tensor_descriptor(
        b_grid_desc_l_k,
        make_tuple(split_k, split_l),
        make_tuple(Sequence<1>{}, Sequence<0>{}),
        make_tuple(Sequence<0, 3, 5>{}, Sequence<1, 2, 4>{}));
}
//
// B1
//
...
...
@@ -288,6 +316,34 @@ struct TransformBatchedContractionContractionToBatchedGemmGemm
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
}
// Re-views a 2-D B1 grid descriptor as a 6-D descriptor by splitting each of
// its two dimensions into three factors.
//
// b_grid_desc_n_l : source descriptor; dimension 0 is read as N and
//                   dimension 1 as L (per the parameter name).
// WmmaL, NRepeat, NWaves, NPerWmma, BL1 : tag types that are only
//                   value-initialised here (e.g. WmmaL{}); the unnamed
//                   function arguments exist purely for type deduction.
//
// Resulting dimension order, as the function name spells it:
//   0: BLWmma       = L / WmmaL
//   1: NBlockRepeat = (N / OPerBlock) * NRepeat
//   2: NWaves
//   3: BLRow        = WmmaL / BL1
//   4: NPerWmma
//   5: BL1
//
// NOTE(review): OPerBlock comes from the enclosing struct, which is not
// visible here; both divisions are presumably expected to be exact - confirm
// against the callers' validity checks.
template <typename BGridDesc_N_L,
          typename WmmaL,
          typename NRepeat,
          typename NWaves,
          typename NPerWmma,
          typename BL1>
__host__ __device__ static constexpr auto
MakeB1GridDescriptor_BLWmma_NBlockRepeat_NWaves_BLRow_NPerWmma_BL1(
    const BGridDesc_N_L& b_grid_desc_n_l,
    const WmmaL&,
    const NRepeat&,
    const NWaves&,
    const NPerWmma&,
    const BL1&)
{
    // Extents of the two source dimensions.
    const auto n_block_count = b_grid_desc_n_l.GetLength(I0) / OPerBlock;
    const auto l_length      = b_grid_desc_n_l.GetLength(I1);

    // Factors of the L split.
    const auto l_wmma_count           = l_length / WmmaL{};
    constexpr auto l_rows_per_wmma    = WmmaL{} / BL1{};

    // L -> (BLWmma, BLRow, BL1)
    const auto split_l =
        make_unmerge_transform(make_tuple(l_wmma_count, l_rows_per_wmma, BL1{}));

    // N -> (N0 * NRepeat, NWaves, NPerWmma)
    const auto split_n =
        make_unmerge_transform(make_tuple(n_block_count * NRepeat{}, NWaves{}, NPerWmma{}));

    // Source dim 1 (L) feeds output dims {0, 3, 5};
    // source dim 0 (N) feeds output dims {1, 2, 4}.
    return transform_tensor_descriptor(
        b_grid_desc_n_l,
        make_tuple(split_l, split_n),
        make_tuple(Sequence<1>{}, Sequence<0>{}),
        make_tuple(Sequence<0, 3, 5>{}, Sequence<1, 2, 4>{}));
}
//
// C
//
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment