gaoqiong / composable_kernel / Commits / c5fd087e

Commit c5fd087e, authored Mar 06, 2023 by aska-0096

    Attn, skip b lds

Parent: 6e28a8ac

Showing 4 changed files with 609 additions and 237 deletions (+609 -237).
Changed files:

  +61  -34   include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp
  +491 -202  include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp
  +1   -1    include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp
  +56  -0    include/ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp
@@ -179,23 +179,53 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
     static auto MakeB0GridDescriptor(const std::vector<index_t>& b0_gs_ls_ks_lengths_vec,
                                      const std::vector<index_t>& b0_gs_ls_ks_strides_vec)
     {
+        if constexpr(B0EnableLds)
+        {
             return Transform::MakeB0GridDescriptor_BK0_N_BK1(
                 Transform::MakeB0GridDescriptor_N_K(b0_gs_ls_ks_lengths_vec, b0_gs_ls_ks_strides_vec),
                 Number<K1>{});
+        }
+        else
+        {
+            return Transform::MakeB0GridDescriptor_BKWmma_LBlockRepeat_LWaves_BKRow_LPerWmma_BK1(
+                Transform::MakeB0GridDescriptor_N_K(b0_gs_ls_ks_lengths_vec, b0_gs_ls_ks_strides_vec),
+                Number<WmmaK>{}, Number<LRepeat>{}, Number<LWaves>{}, Number<LPerWmma>{}, Number<K1>{});
+        }
     }
 
-    static auto MakeB1GridDescriptor_BL0_N_BL1(const std::vector<index_t>& b1_gs_ns_ls_lengths_vec,
-                                               const std::vector<index_t>& b1_gs_ns_ls_strides_vec)
+    static auto MakeB1GridDescriptor(const std::vector<index_t>& b1_gs_ns_ls_lengths_vec,
+                                     const std::vector<index_t>& b1_gs_ns_ls_strides_vec)
     {
+        if constexpr(B1EnableLds)
+        {
             return Transform::MakeB1GridDescriptor_BK0_N_BK1(
                 Transform::MakeB1GridDescriptor_N_K(b1_gs_ns_ls_lengths_vec, b1_gs_ns_ls_strides_vec),
                 Number<L1>{});
+        }
+        else
+        {
+            return Transform::MakeB1GridDescriptor_BLWmma_NBlockRepeat_NWaves_BLRow_NPerWmma_BL1(
+                Transform::MakeB1GridDescriptor_N_K(b1_gs_ns_ls_lengths_vec, b1_gs_ns_ls_strides_vec),
+                Number<WmmaK>{}, Number<NRepeat>{}, Number<NWaves>{}, Number<NPerWmma>{}, Number<L1>{});
+        }
     }
 
     using AGridDesc        = decltype(MakeAGridDescriptor({}, {}));
-    using B0GridDesc_BK0_L_BK1 = decltype(MakeB0GridDescriptor({}, {}));
-    using B1GridDesc_BL0_N_BL1 = decltype(MakeB1GridDescriptor_BL0_N_BL1({}, {}));
+    using B0GridDesc       = decltype(MakeB0GridDescriptor({}, {}));
+    using B1GridDesc       = decltype(MakeB1GridDescriptor({}, {}));
     using CGridDesc_M_N    = decltype(Transform::MakeCGridDescriptor_M_N({}, {}));
     using AGridDesc_G_M_K  = decltype(Transform::MakeAGridDescriptor_G_M_K({}, {}));
     using B0GridDesc_G_L_K = decltype(Transform::MakeB0GridDescriptor_G_N_K({}, {}));
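The hunk above selects the B0/B1 grid-descriptor shape at compile time from the B0EnableLds/B1EnableLds flags: the LDS path keeps the 3-D (K0, N, K1) view, while the skip-LDS path exposes a wmma-tiled multi-dimensional view that each thread reads straight into registers. A minimal, self-contained sketch of this compile-time dispatch pattern (the descriptor types and dimension values below are placeholders, not the CK API):

#include <array>
#include <cstdio>

// Hypothetical stand-ins for the two descriptor shapes selected by the flag.
struct DescK0NK1  { std::array<int, 3> lengths; }; // LDS path: (K0, N, K1)
struct DescWmma6D { std::array<int, 6> lengths; }; // skip-LDS path: (KWmma, Repeat, Waves, Row, PerWmma, K1)

template <bool EnableLds>
auto make_b_grid_descriptor(int K, int N, int K1)
{
    if constexpr(EnableLds)
    {
        return DescK0NK1{{K / K1, N, K1}};                    // 3-D block view staged through LDS
    }
    else
    {
        constexpr int WmmaK = 16;
        return DescWmma6D{{K / WmmaK, N / 16, 1, 2, 16, K1}}; // per-thread wmma view, read straight to VGPR
    }
}

int main()
{
    auto lds  = make_b_grid_descriptor<true>(64, 128, 8);  // decltype differs per branch,
    auto vgpr = make_b_grid_descriptor<false>(64, 128, 8); // mirroring B0GridDesc/B1GridDesc above
    std::printf("%zu %zu\n", lds.lengths.size(), vgpr.lengths.size());
    return 0;
}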
@@ -274,8 +304,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
         InMemoryDataOperationEnum::Set,
         // InMemory Data Descriptor
         AGridDesc,
-        B0GridDesc_BK0_L_BK1,
-        B1GridDesc_BL0_N_BL1,
+        B0GridDesc,
+        B1GridDesc,
         CGridDesc_M_N,
         // Tiling Family
         MPerBlock,
@@ -364,10 +394,10 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
               p_b1_grid_{p_b1_grid},
               p_c_grid_{p_c_grid},
               a_grid_desc{DeviceOp::MakeAGridDescriptor(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)},
-              b0_grid_desc_bk0_l_bk1_{DeviceOp::MakeB0GridDescriptor(b0_gs_ls_ks_lengths, b0_gs_ls_ks_strides)},
-              b1_grid_desc_bl0_n_bl1_{DeviceOp::MakeB1GridDescriptor_BL0_N_BL1(b1_gs_ns_ls_lengths, b1_gs_ns_ls_strides)},
+              b0_grid_desc{DeviceOp::MakeB0GridDescriptor(b0_gs_ls_ks_lengths, b0_gs_ls_ks_strides)},
+              b1_grid_desc{DeviceOp::MakeB1GridDescriptor(b1_gs_ns_ls_lengths, b1_gs_ns_ls_strides)},
               c_grid_desc_m_n_{Transform::MakeCGridDescriptor_M_N(c_gs_ms_ns_lengths, c_gs_ms_ns_strides)},
               a_grid_desc_g_m_k_{
@@ -410,11 +440,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
             ignore = acc1_biases_gs_ms_ns_lengths;
             ignore = acc1_biases_gs_ms_ns_strides;
 
             if(GridwiseOp::CheckValidity(a_grid_desc,
-                                         b0_grid_desc_bk0_l_bk1_,
-                                         b1_grid_desc_bl0_n_bl1_,
+                                         b0_grid_desc,
+                                         b1_grid_desc,
                                          c_grid_desc_m_n_,
                                          block_2_ctile_map_))
             {
                 c_grid_desc_mblock_mperblock_nblock_nperblock_ =
                     GridwiseOp::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
@@ -430,8 +457,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
         // Tensor Descriptors
         AGridDesc a_grid_desc;
-        B0GridDesc_BK0_L_BK1 b0_grid_desc_bk0_l_bk1_;
-        B1GridDesc_BL0_N_BL1 b1_grid_desc_bl0_n_bl1_;
+        B0GridDesc b0_grid_desc;
+        B1GridDesc b1_grid_desc;
         CGridDesc_M_N c_grid_desc_m_n_;
         AGridDesc_G_M_K a_grid_desc_g_m_k_;
@@ -498,8 +525,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
             B1DataType,
             CDataType,
             DeviceOp::AGridDesc,
-            DeviceOp::B0GridDesc_BK0_L_BK1,
-            DeviceOp::B1GridDesc_BL0_N_BL1,
+            DeviceOp::B0GridDesc,
+            DeviceOp::B1GridDesc,
             typename GridwiseOp::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
             AElementwiseOperation,
             B0ElementwiseOperation,
@@ -521,8 +548,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
                 arg.p_b1_grid_,
                 arg.p_c_grid_,
                 arg.a_grid_desc,
-                arg.b0_grid_desc_bk0_l_bk1_,
-                arg.b1_grid_desc_bl0_n_bl1_,
+                arg.b0_grid_desc,
+                arg.b1_grid_desc,
                 arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
                 arg.a_element_op_,
                 arg.b0_element_op_,
@@ -582,8 +609,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
         }
 
         if(!GridwiseOp::CheckValidity(arg.a_grid_desc,
-                                      arg.b0_grid_desc_bk0_l_bk1_,
-                                      arg.b1_grid_desc_bl0_n_bl1_,
+                                      arg.b0_grid_desc,
+                                      arg.b1_grid_desc,
                                       arg.c_grid_desc_m_n_,
                                       arg.block_2_ctile_map_))
         {
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp
@@ -18,14 +18,14 @@
 namespace ck {
 
-template <typename GridwiseGemm,
-          typename FloatA,
-          typename FloatB0,
-          typename FloatB1,
-          typename FloatC,
+template <typename GridwiseOp,
+          typename ADataType,
+          typename B0DataType,
+          typename B1DataType,
+          typename CDataType,
           typename AGridDesc,
-          typename B0GridDesc_BK0_L_BK1,
-          typename B1GridDesc_BL0_N_BL1,
+          typename B0GridDesc,
+          typename B1GridDesc,
           typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
           typename AElementwiseOperation,
           typename B0ElementwiseOperation,
@@ -41,13 +41,13 @@ __global__ void
     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
         kernel_batched_gemm_softmax_gemm_wmma_cshuffle(
-            const FloatA* __restrict__ p_a_grid,
-            const FloatB0* __restrict__ p_b0_grid,
-            const FloatB1* __restrict__ p_b1_grid,
-            FloatC* __restrict__ p_c_grid,
+            const ADataType* __restrict__ p_a_grid,
+            const B0DataType* __restrict__ p_b0_grid,
+            const B1DataType* __restrict__ p_b1_grid,
+            CDataType* __restrict__ p_c_grid,
             const AGridDesc a_grid_desc,
-            const B0GridDesc_BK0_L_BK1 b0_grid_desc_bk0_l_bk1,
-            const B1GridDesc_BL0_N_BL1 b1_grid_desc_l0_n_l1,
+            const B0GridDesc b0_grid_desc,
+            const B1GridDesc b1_grid_desc,
             const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
                 c_grid_desc_mblock_mperblock_nblock_nperblock,
             const AElementwiseOperation a_element_op,
@@ -61,7 +61,7 @@ __global__ void
             const Block2CTileMap block_2_ctile_map)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__))
-    __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
+    __shared__ char p_shared[GridwiseOp::GetSharedMemoryNumberOfByte()];
 
     const index_t num_blocks_per_batch =
         __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
@@ -76,14 +76,14 @@ __global__ void
     const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
         static_cast<long_index_t>(compute_base_ptr_of_batch.GetCBasePtr(g_idx)));
 
-    GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
+    GridwiseOp::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
                                                   p_b0_grid + b0_batch_offset,
                                                   p_b1_grid + b1_batch_offset,
                                                   p_c_grid + c_batch_offset,
                                                   p_shared,
                                                   a_grid_desc,
-                                                  b0_grid_desc_bk0_l_bk1,
-                                                  b1_grid_desc_l0_n_l1,
+                                                  b0_grid_desc,
+                                                  b1_grid_desc,
                                                   c_grid_desc_mblock_mperblock_nblock_nperblock,
                                                   a_element_op,
                                                   b0_element_op,
@@ -98,8 +98,8 @@ __global__ void
     ignore = p_b1_grid;
     ignore = p_c_grid;
     ignore = a_grid_desc;
-    ignore = b0_grid_desc_bk0_l_bk1;
-    ignore = b1_grid_desc_l0_n_l1;
+    ignore = b0_grid_desc;
+    ignore = b1_grid_desc;
     ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
     ignore = a_element_op;
     ignore = b0_element_op;
@@ -115,13 +115,13 @@ __global__ void
 // Gemm0: A [M x K] x B0 [K x L] = Acc [M x L]
 // Gemm1: Acc [M x L] x B1 [L x N] = C [M x N]
-template <typename FloatA,
-          typename FloatB0,
-          typename FloatAcc0,
-          typename FloatB1,
-          typename FloatAcc1,
-          typename FloatCShuffle,
-          typename FloatC,
+template <typename ADataType,
+          typename B0DataType,
+          typename Acc0DataType,
+          typename B1DataType,
+          typename Acc1DataType,
+          typename CShuffleDataType,
+          typename CDataType,
           typename AElementwiseOperation,
           typename B0ElementwiseOperation,
           typename AccElementwiseOperation,
@@ -129,8 +129,8 @@ template <typename FloatA,
           typename CElementwiseOperation,
           InMemoryDataOperationEnum CGlobalMemoryDataOperation,
           typename AGridDesc,
-          typename B0GridDesc_BK0_L_BK1,
-          typename B1GridDesc_BL0_N_BL1,
+          typename B0GridDesc,
+          typename B1GridDesc,
           typename CGridDesc_M_N,
           index_t MPerBlock,
           index_t LPerBlock,
@@ -163,7 +163,7 @@ template <typename FloatA,
           index_t B0BlockTransferDstScalarPerVector_K1,
           bool B0ThreadTransferSrcResetCoordinateAfterRun,
           bool B0EnableLds,
-          bool B0BlockLdsExtraN,
+          bool B0BlockLdsExtraL,
           typename B1BlockTransferThreadClusterLengths_L0_N_L1,
           typename B1BlockTransferThreadClusterArrangeOrder,
           typename B1BlockTransferSrcAccessOrder,
@@ -204,8 +204,10 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
     static constexpr auto BL1 = Number<L1Value>{};
 
     static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma);
+    static constexpr auto LWaves = LPerBlock / (LRepeat * LPerWmma);
     static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma);
 
     static constexpr auto WmmaK = 16;
+    static constexpr auto WmmaL = 16;
 
     using ThisThreadBlock = ThisThreadBlock<BlockSize>;
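The new LWaves/WmmaL constants follow the same per-block bookkeeping as the existing MWaves/NWaves: a block tile is split into repeats, waves, and one WMMA fragment per wave, so waves-per-dimension is the block extent divided by repeat times per-wmma extent. A stand-alone arithmetic check of that relation (the tile numbers below are illustrative, not taken from a real instance):

#include <cassert>

int main()
{
    // Illustrative tile sizes (not from a particular DeviceOp instance).
    constexpr int LPerBlock = 128, LRepeat = 2, LPerWmma = 16;
    constexpr int WmmaL = 16;

    // Same formula as LWaves = LPerBlock / (LRepeat * LPerWmma) in the struct above.
    constexpr int LWaves = LPerBlock / (LRepeat * LPerWmma); // 128 / (2 * 16) = 4 waves along L
    static_assert(LRepeat * LWaves * LPerWmma == LPerBlock, "tile must factor exactly");

    // WmmaL is the L extent consumed by one WMMA instruction, so one block covers
    // LPerBlock / WmmaL such slices per outer iteration.
    static_assert(LPerBlock % WmmaL == 0, "LPerBlock must be a multiple of WmmaL");
    assert(LWaves == 4);
    return 0;
}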
@@ -250,6 +252,73 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
         return a_block_desc;
     }
 
+    __host__ __device__ static constexpr auto MakeB0BlockDescriptor()
+    {
+        constexpr auto b0_block_desc = [&]() {
+            if constexpr(B0EnableLds)
+            {
+                // K0->L->BK1 Per Block
+                constexpr auto K0PerBlock    = KPerBlock / BK1;
+                constexpr auto max_lds_align = BK1;
+
+                if constexpr(B0BlockLdsExtraL)
+                {
+                    return make_naive_tensor_descriptor(
+                        make_tuple(Number<K0PerBlock>{}, Number<LPerBlock>{}, BK1),
+                        make_tuple(Number<LPerBlock + 1>{} * BK1, BK1, I1));
+                }
+                else
+                {
+                    return make_naive_tensor_descriptor_aligned(
+                        make_tuple(Number<K0PerBlock>{}, Number<LPerBlock>{}, BK1), max_lds_align);
+                }
+            }
+            else
+            {
+                constexpr auto KWmmaPerblock = KPerBlock / WmmaK;
+                // KWmma->NRepeat->NWave->NRow->NPerWmma->BK1 Per Thread
+                return make_naive_tensor_descriptor(
+                    make_tuple(Number<KWmmaPerblock>{}, Number<LRepeat>{}, I1, I1, I1, BK1),
+                    make_tuple(Number<LRepeat>{} * BK1, BK1, BK1, BK1, BK1, I1));
+            }
+        }();
+
+        return b0_block_desc;
+    }
+
+    __host__ __device__ static constexpr auto MakeB1BlockDescriptor()
+    {
+        constexpr auto b1_block_desc = [&]() {
+            if constexpr(B1EnableLds)
+            {
+                // L0->N->BL1 Per Block
+                constexpr auto max_lds_align = BL1;
+
+                if constexpr(B1BlockLdsExtraN)
+                {
+                    return make_naive_tensor_descriptor(
+                        make_tuple(Number<L0PerBlock>{}, Number<NPerBlock>{}, BL1),
+                        make_tuple(Number<NPerBlock + 1>{} * BL1, BL1, I1));
+                }
+                else
+                {
+                    return make_naive_tensor_descriptor_aligned(
+                        make_tuple(Number<L0PerBlock>{}, Number<NPerBlock>{}, BL1), max_lds_align);
+                }
+            }
+            else
+            {
+                constexpr auto LWmmaPerblock = LPerBlock / WmmaL;
+                // LWmma->NRepeat->NWave->NRow->LPerWmma->BL1 Per Thread
+                return make_naive_tensor_descriptor(
+                    make_tuple(Number<LWmmaPerblock>{}, Number<NRepeat>{}, I1, I1, I1, BL1),
+                    make_tuple(Number<NRepeat>{} * BL1, BL1, BL1, BL1, BL1, I1));
+            }
+        }();
+
+        return b1_block_desc;
+    }
+
     __host__ __device__ static constexpr auto MakeABlockSliceCopyStep()
     {
         constexpr auto a_block_copy_step = [&]() {
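In the LDS branch of MakeB0BlockDescriptor above, the B0BlockLdsExtraL path pads the K0 row stride by one BK1-wide element ((LPerBlock + 1) * BK1 instead of LPerBlock * BK1), a common trick to stagger rows across LDS banks. A small host-side sketch of how such a padded (K0, L, K1) layout maps indices to offsets, using plain strides rather than the CK descriptor machinery:

#include <cstdio>

// Hypothetical offset computation for a (K0, L, K1) block laid out in LDS with an
// optional one-element pad on the L dimension, mirroring the stride tuple above:
//   padded:   ((L + 1) * K1, K1, 1)
//   unpadded: ( L      * K1, K1, 1)
constexpr int offset(int k0, int l, int k1, int L, int K1, bool extra_pad)
{
    const int l_stride  = K1;
    const int k0_stride = (extra_pad ? L + 1 : L) * K1;
    return k0 * k0_stride + l * l_stride + k1;
}

int main()
{
    constexpr int L = 64, K1 = 8;
    // With the pad, successive k0 rows are shifted by one extra K1-vector,
    // which staggers their starting address in LDS.
    std::printf("row 0 at %d, row 1 at %d (padded)\n",
                offset(0, 0, 0, L, K1, true), offset(1, 0, 0, L, K1, true));
    std::printf("row 0 at %d, row 1 at %d (unpadded)\n",
                offset(0, 0, 0, L, K1, false), offset(1, 0, 0, L, K1, false));
    return 0;
}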
@@ -270,6 +339,44 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
         return a_block_copy_step;
     }
 
+    __host__ __device__ static constexpr auto MakeB0BlockSliceCopyStep()
+    {
+        constexpr auto b0_block_copy_step = [&]() {
+            if constexpr(B0EnableLds)
+            {
+                constexpr auto K0PerBlock = KPerBlock / BK1;
+                return make_multi_index(K0PerBlock, 0, 0);
+            }
+            else
+            {
+                constexpr auto KWmmaPerBlock = KPerBlock / WmmaK;
+                return make_multi_index(KWmmaPerBlock, 0, 0, 0, 0, 0);
+            }
+        }();
+
+        return b0_block_copy_step;
+    }
+
+    __host__ __device__ static constexpr auto MakeB1BlockSliceCopyStep()
+    {
+        constexpr auto b1_block_copy_step = [&]() {
+            if constexpr(B1EnableLds)
+            {
+                return make_multi_index(L0PerBlock, 0, 0);
+            }
+            else
+            {
+                constexpr auto LWmmaPerBlock = LTilePerBlock / WmmaL;
+                return make_multi_index(LWmmaPerBlock, 0, 0, 0, 0, 0);
+            }
+        }();
+
+        return b1_block_copy_step;
+    }
+
     // Describe how data read from (LDS/VGPR) buffer
     template <typename ABlockDesc_>
     __host__ __device__ static constexpr auto MakeAWaveDescriptor(const ABlockDesc_&)
@@ -323,26 +430,61 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
         return a_wave_desc;
     }
 
-    template <typename B0BlockDesc_BK0_L_BK1>
-    __host__ __device__ static constexpr auto
-    MakeB0BlockDescriptor_K0_L0_L1_L2_K1(const B0BlockDesc_BK0_L_BK1&)
-    {
-        constexpr index_t B_K0   = B0BlockDesc_BK0_L_BK1{}.GetLength(I0);
-        constexpr index_t B_K1   = B0BlockDesc_BK0_L_BK1{}.GetLength(I2);
-        constexpr index_t LWaves = LPerBlock / (LRepeat * LPerWmma);
-
-        return transform_tensor_descriptor(
-            B0BlockDesc_BK0_L_BK1{},
-            make_tuple(make_pass_through_transform(Number<B_K0>{}),
-                       make_unmerge_transform(
-                           make_tuple(Number<LRepeat>{}, Number<LWaves>{}, Number<LPerWmma>{})),
-                       make_pass_through_transform(Number<B_K1>{})),
-            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
-            make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}));
-    }
+    template <typename B0BlockDesc_>
+    __host__ __device__ static constexpr auto MakeB0WaveDescriptor(const B0BlockDesc_&)
+    {
+        constexpr auto b0_wave_desc = [&]() {
+            if constexpr(B0EnableLds)
+            {
+                // BK0_L_BK1 -> BK0_LRepeat_Lwaves_LPerWmma_BK1
+                constexpr auto B_K0 = B0BlockDesc_{}.GetLength(I0);
+                constexpr auto B_K1 = B0BlockDesc_{}.GetLength(I2);
+
+                return transform_tensor_descriptor(
+                    B0BlockDesc_{},
+                    make_tuple(make_pass_through_transform(Number<B_K0>{}),
+                               make_unmerge_transform(
+                                   make_tuple(Number<LRepeat>{}, Number<LWaves>{}, Number<LPerWmma>{})),
+                               make_pass_through_transform(Number<B_K1>{})),
+                    make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
+                    make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}));
+            }
+            else
+            {
+                // KWmma_LRepeat_LWave_KRow_LPerWmma_K1 -> K0_LRepeat_Lwaves_LPerWmma_K1
+                constexpr auto KWmma = B0BlockDesc_{}.GetLength(I0);
+                constexpr auto B_K1  = B0BlockDesc_{}.GetLength(I5);
+
+                // Workaround, Freeze transform
+                return transform_tensor_descriptor(
+                    B0BlockDesc_{},
+                    make_tuple(make_freeze_transform(I0),
+                               make_pass_through_transform(Number<KWmma>{}),
+                               make_pass_through_transform(Number<LRepeat>{}),
+                               make_pass_through_transform(I1),
+                               make_pass_through_transform(I1),
+                               make_pass_through_transform(Number<B_K1>{})),
+                    make_tuple(Sequence<3>{}, Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<5>{}),
+                    make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
+            }
+        }();
+        return b0_wave_desc;
+    }
 
     template <typename A1BlockDesc_AL0_M_AL1>
     __host__ __device__ static constexpr auto
-    MakeA1BlockDescriptor_L0_M0_M1_M2_L1(const A1BlockDesc_AL0_M_AL1&)
+    MakeA1WaveDescriptor_L0_M0_M1_M2_L1(const A1BlockDesc_AL0_M_AL1&)
     {
         constexpr index_t A_L0 = A1BlockDesc_AL0_M_AL1{}.GetLength(I0);
         constexpr index_t A_L1 = A1BlockDesc_AL0_M_AL1{}.GetLength(I2);
@@ -356,37 +498,56 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
             make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}));
     }
 
-    template <typename B1BlockDesc_BL0_N_BL1>
-    __host__ __device__ static constexpr auto
-    MakeB1BlockDescriptor_L0_N0_N1_N2_L1(const B1BlockDesc_BL0_N_BL1&)
-    {
-        constexpr index_t B_K0 = B1BlockDesc_BL0_N_BL1{}.GetLength(I0);
-        constexpr index_t B_K1 = B1BlockDesc_BL0_N_BL1{}.GetLength(I2);
-
-        return transform_tensor_descriptor(
-            B1BlockDesc_BL0_N_BL1{},
-            make_tuple(make_pass_through_transform(Number<B_K0>{}),
-                       make_unmerge_transform(
-                           make_tuple(Number<NRepeat>{}, Number<NWaves>{}, Number<NPerWmma>{})),
-                       make_pass_through_transform(Number<B_K1>{})),
-            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
-            make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}));
-    }
-
-    __host__ __device__ static constexpr auto GetB0BlockDescriptor_BK0PerBlock_LPerBlock_BK1()
-    {
-        // B matrix in LDS memory, dst of blockwise copy
-        return make_naive_tensor_descriptor(
-            make_tuple(BK0, Number<LPerBlock>{}, BK1),
-            make_tuple(Number<LPerBlock + B0BlockLdsExtraN>{} * BK1, BK1, I1));
-    }
-
-    __host__ __device__ static constexpr auto GetB1BlockDescriptor_BL0PerBlock_NPerBlock_BL1()
-    {
-        // B1 matrix in LDS memory, dst of blockwise copy
-        return make_naive_tensor_descriptor(
-            make_tuple(BL0, Number<NPerBlock>{}, BL1),
-            make_tuple(Number<NPerBlock + B1BlockLdsExtraN>{} * BL1, BL1, I1));
-    }
+    template <typename B1BlockDesc_>
+    __host__ __device__ static constexpr auto MakeB1WaveDescriptor(const B1BlockDesc_&)
+    {
+        constexpr auto b1_wave_desc = [&]() {
+            if constexpr(B1EnableLds)
+            {
+                // BL0_N_BL1 -> BL0_NRepeat_Nwaves_NPerWmma_BL1
+                constexpr auto B_L0 = B1BlockDesc_{}.GetLength(I0);
+                constexpr auto B_L1 = B1BlockDesc_{}.GetLength(I2);
+
+                return transform_tensor_descriptor(
+                    B1BlockDesc_{},
+                    make_tuple(make_pass_through_transform(Number<B_L0>{}),
+                               make_unmerge_transform(
+                                   make_tuple(Number<NRepeat>{}, Number<NWaves>{}, Number<NPerWmma>{})),
+                               make_pass_through_transform(Number<B_L1>{})),
+                    make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
+                    make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}));
+            }
+            else
+            {
+                // LWmma_NRepeat_NWave_LRow_NPerWmma_L1 -> L0_NRepeat_Nwaves_NPerWmma_L1
+                constexpr auto LWmma = B1BlockDesc_{}.GetLength(I0);
+                constexpr auto B_L1  = B1BlockDesc_{}.GetLength(I5);
+
+                // Workaround, Freeze transform
+                return transform_tensor_descriptor(
+                    B1BlockDesc_{},
+                    make_tuple(make_freeze_transform(I0),
+                               make_pass_through_transform(Number<LWmma>{}),
+                               make_pass_through_transform(Number<NRepeat>{}),
+                               make_pass_through_transform(I1),
+                               make_pass_through_transform(I1),
+                               make_pass_through_transform(Number<B_L1>{})),
+                    make_tuple(Sequence<3>{}, Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}, Sequence<5>{}),
+                    make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
+            }
+        }();
+
+        return b1_wave_desc;
+    }
 
     __host__ __device__ static constexpr auto
@@ -410,29 +571,28 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
     {
         // LDS allocation for A and B: be careful of alignment
         const index_t gemm0_bytes_end =
-            (SharedMemTrait::a_block_space_size_aligned * sizeof(FloatA) +
-             SharedMemTrait::b0_block_space_size_aligned * sizeof(FloatB0));
+            (SharedMemTrait::a_block_space_size_aligned * sizeof(ADataType) +
+             SharedMemTrait::b0_block_space_size_aligned * sizeof(B0DataType));
         const index_t gemm1_bytes_end =
             (SharedMemTrait::b1_block_space_offset +
-             SharedMemTrait::b1_block_space_size_aligned * sizeof(FloatB1));
+             SharedMemTrait::b1_block_space_size_aligned * sizeof(B1DataType));
         const index_t softmax_bytes_end =
             SharedMemTrait::reduction_space_offset +
-            SharedMemTrait::reduction_space_size_aligned * sizeof(FloatAcc0);
+            SharedMemTrait::reduction_space_size_aligned * sizeof(Acc0DataType);
         const index_t c_block_bytes_end =
-            SharedMemTrait::c_block_space_size * sizeof(FloatCShuffle);
+            SharedMemTrait::c_block_space_size * sizeof(CShuffleDataType);
 
         return math::max(gemm0_bytes_end, gemm1_bytes_end, softmax_bytes_end, c_block_bytes_end);
     }
 
     // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
     template <typename Block2CTileMap>
     __host__ __device__ static constexpr bool
     CheckValidity(const AGridDesc& a_grid_desc,
-                  const B0GridDesc_BK0_L_BK1& b0_grid_desc_bk0_l_bk1,
-                  const B1GridDesc_BL0_N_BL1& b1_grid_desc_l0_n_l1,
+                  const B0GridDesc& b0_grid_desc,
+                  const B1GridDesc& b1_grid_desc,
                   const CGridDesc_M_N& c_grid_desc_m_n,
                   const Block2CTileMap& block_2_ctile_map)
     {
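GetSharedMemoryNumberOfByte above sizes the block's LDS as the maximum over the stages that reuse the same scratch space: gemm0's A + B0 staging, gemm1's B1 staging, the softmax reduction workspace, and the C-shuffle tile; with B0EnableLds or B1EnableLds false the corresponding aligned sizes become zero and shrink the budget. A stand-alone sketch of that budgeting, with made-up element counts in place of the SharedMemTrait values:

#include <algorithm>
#include <cstdint>
#include <cstdio>

int main()
{
    // Hypothetical per-stage element counts; in the kernel these come from SharedMemTrait.
    constexpr std::size_t a_elems = 128 * 32, b0_elems = 128 * 32, b1_elems = 128 * 32;
    constexpr std::size_t reduction_elems = 256, c_shuffle_elems = 128 * 64;
    constexpr std::size_t b1_offset = 0, reduction_offset = 0; // offsets into the shared arena

    using half_t = std::uint16_t; // stand-in for a 16-bit A/B/CShuffle element type
    using acc_t  = float;         // stand-in for the softmax accumulator type

    // End-of-use byte mark of each stage; the LDS arena only needs the largest of them,
    // mirroring math::max(gemm0_bytes_end, gemm1_bytes_end, softmax_bytes_end, c_block_bytes_end).
    const std::size_t gemm0_bytes_end   = a_elems * sizeof(half_t) + b0_elems * sizeof(half_t);
    const std::size_t gemm1_bytes_end   = b1_offset + b1_elems * sizeof(half_t);
    const std::size_t softmax_bytes_end = reduction_offset + reduction_elems * sizeof(acc_t);
    const std::size_t c_block_bytes_end = c_shuffle_elems * sizeof(half_t);

    const std::size_t lds_bytes =
        std::max({gemm0_bytes_end, gemm1_bytes_end, softmax_bytes_end, c_block_bytes_end});
    std::printf("LDS bytes needed: %zu\n", lds_bytes);
    return 0;
}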
@@ -455,10 +615,40 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
             }
         };
 
+        const auto GetB0ProblemsizeLK = [&]() {
+            if constexpr(B0EnableLds)
+            {
+                return make_tuple(b0_grid_desc.GetLength(I1),
+                                  b0_grid_desc.GetLength(I0) * b0_grid_desc.GetLength(I2));
+            }
+            else
+            {
+                return make_tuple(b0_grid_desc.GetLength(I1) * b0_grid_desc.GetLength(I2) *
+                                      b0_grid_desc.GetLength(I4),
+                                  b0_grid_desc.GetLength(I0) * b0_grid_desc.GetLength(I3) *
+                                      b0_grid_desc.GetLength(I5));
+            }
+        };
+
+        const auto GetB1ProblemsizeNL = [&]() {
+            if constexpr(B1EnableLds)
+            {
+                return make_tuple(b1_grid_desc.GetLength(I1),
+                                  b1_grid_desc.GetLength(I0) * b1_grid_desc.GetLength(I2));
+            }
+            else
+            {
+                return make_tuple(b1_grid_desc.GetLength(I1) * b1_grid_desc.GetLength(I2) *
+                                      b1_grid_desc.GetLength(I4),
+                                  b1_grid_desc.GetLength(I0) * b1_grid_desc.GetLength(I3) *
+                                      b1_grid_desc.GetLength(I5));
+            }
+        };
+
         const auto M = GetAProblemsizeMK()[I0];
-        const auto L = b0_grid_desc_bk0_l_bk1.GetLength(I1);
+        const auto L = GetB0ProblemsizeLK()[I0];
         const auto K = GetAProblemsizeMK()[I1];
-        const auto N = b1_grid_desc_l0_n_l1.GetLength(I1);
+        const auto N = GetB1ProblemsizeNL()[I0];
 
         if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1)))
         {
@@ -567,16 +757,12 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
                                                    max_lds_align)
                             : 0;
         static constexpr auto b0_block_space_size_aligned =
             B0EnableLds
                 ? math::integer_least_multiple(
-                      GetB0BlockDescriptor_BK0PerBlock_LPerBlock_BK1().GetElementSpaceSize(),
-                      max_lds_align)
+                      MakeB0BlockDescriptor().GetElementSpaceSize(), max_lds_align)
                 : 0;
         static constexpr auto b1_block_space_size_aligned =
             B1EnableLds
                 ? math::integer_least_multiple(
-                      GetB1BlockDescriptor_BL0PerBlock_NPerBlock_BL1().GetElementSpaceSize(),
-                      max_lds_align)
+                      MakeB1BlockDescriptor().GetElementSpaceSize(), max_lds_align)
                 : 0;
 
         static constexpr auto a_block_space_offset = 0;
@@ -599,14 +785,14 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
     template <bool HasMainKBlockLoop,
               typename C0MatrixMask,
               typename Block2CTileMap = DefaultBlock2CTileMap>
-    __device__ static void Run(const FloatA* __restrict__ p_a_grid,
-                               const FloatB0* __restrict__ p_b0_grid,
-                               const FloatB1* __restrict__ p_b1_grid,
-                               FloatC* __restrict__ p_c_grid,
+    __device__ static void Run(const ADataType* __restrict__ p_a_grid,
+                               const B0DataType* __restrict__ p_b0_grid,
+                               const B1DataType* __restrict__ p_b1_grid,
+                               CDataType* __restrict__ p_c_grid,
                                void* __restrict__ p_shared,
                                const AGridDesc& a_grid_desc,
-                               const B0GridDesc_BK0_L_BK1& b0_grid_desc_k0_l_k1,
-                               const B1GridDesc_BL0_N_BL1& b1_grid_desc_l0_n_l1,
+                               const B0GridDesc& b0_grid_desc,
+                               const B1GridDesc& b1_grid_desc,
                                const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
                                    c_grid_desc_mblock_mperblock_nblock_nperblock,
                                const AElementwiseOperation& a_element_op,
@@ -623,9 +809,9 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
         const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_a_grid, a_grid_desc.GetElementSpaceSize());
         const auto b0_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-            p_b0_grid, b0_grid_desc_k0_l_k1.GetElementSpaceSize());
+            p_b0_grid, b0_grid_desc.GetElementSpaceSize());
         const auto b1_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-            p_b1_grid, b1_grid_desc_l0_n_l1.GetElementSpaceSize());
+            p_b1_grid, b1_grid_desc.GetElementSpaceSize());
         auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
@@ -648,17 +834,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
         /*******************************************************************************/
         // BlockLevel, A/B Matrix ThreadMapping in LDS, As Destinaion of BlockWise_Copy
-        const auto K = [&](){
-            if constexpr(AEnableLds){
-                return a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I2);
-            }
-            else{
-                return a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I3) *
-                       a_grid_desc.GetLength(I5);
-            }
-        }();
-
         constexpr auto a_block_desc = MakeABlockDescriptor();
-        constexpr auto b0_block_desc_k0perblock_lperblock_k1 = GetB0BlockDescriptor_BK0PerBlock_LPerBlock_BK1();
+        constexpr auto b0_block_desc = MakeB0BlockDescriptor();
 
         auto a_block_trait = [&](){
             // A matrix blockwise copy
@@ -666,7 +843,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
             {
                 constexpr auto AK0PerBlock = KPerBlock/ AK1;
                 auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-                    static_cast<FloatA*>(p_shared) + SharedMemTrait::a_block_space_offset,
+                    static_cast<ADataType*>(p_shared) + SharedMemTrait::a_block_space_offset,
                     SharedMemTrait::a_block_space_size_aligned);
 
                 auto a_blockwise_copy =
@@ -677,8 +854,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
                     /* typename BlockSliceLengths, */ Sequence<AK0PerBlock, MPerBlock, AK1>,
                     /* typename ThreadClusterLengths, */ ABlockTransferThreadClusterLengths_K0_M_K1,
                     /* typename ThreadClusterArrangeOrder, */ ABlockTransferThreadClusterArrangeOrder,
-                    /* typename SrcData, */ FloatA,
-                    /* typename DstData, */ FloatA,
+                    /* typename SrcData, */ ADataType,
+                    /* typename DstData, */ ADataType,
                     /* typename SrcDesc, */ decltype(a_grid_desc),
                     /* typename DstDesc, */ decltype(a_block_desc),
                     /* typename SrcDimAccessOrder, */ ABlockTransferSrcAccessOrder,
@@ -705,13 +882,13 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
                 // Thread-wise copy
                 // KPerBlock/WmmaK -> MRepeat -> MWaves -> WmmaK/K1 -> MPerWmma -> K1
                 constexpr auto KWmmaPerBlock = KPerBlock / WmmaK;
-                auto a_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatA>(
+                auto a_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ADataType>(
                     a_block_desc.GetElementSpaceSize());
 
                 // Limitation: NumDim of Src and Dst descriptor should be identical
                 auto a_blockwise_copy =
-                    ThreadwiseTensorSliceTransfer_v2<FloatA,
-                                                     FloatA,
+                    ThreadwiseTensorSliceTransfer_v2<ADataType,
+                                                     ADataType,
                                                      decltype(a_grid_desc),
                                                      decltype(a_block_desc),
                                                      Sequence<Number<KWmmaPerBlock>{},
@@ -737,7 +914,13 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
             }
         };
 
+        // B matrix blockwise copy
+        auto b0_block_trait = [&](){
+            if constexpr(B0EnableLds)
+            {
+                auto b0_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
+                    static_cast<B0DataType*>(p_shared) + SharedMemTrait::b0_block_space_offset,
+                    SharedMemTrait::b0_block_space_size_aligned);
 
                 auto b0_blockwise_copy =
                     ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
                                                         B0ElementwiseOperation,
@@ -746,10 +929,10 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
                                                         Sequence<BK0, LPerBlock, BK1>,
                                                         B0BlockTransferThreadClusterLengths_K0_L_K1,
                                                         B0BlockTransferThreadClusterArrangeOrder,
-                                                        FloatB0,
-                                                        FloatB0,
-                                                        decltype(b0_grid_desc_k0_l_k1),
-                                                        decltype(b0_block_desc_k0perblock_lperblock_k1),
+                                                        B0DataType,
+                                                        B0DataType,
+                                                        decltype(b0_grid_desc),
+                                                        decltype(b0_block_desc),
                                                         B0BlockTransferSrcAccessOrder,
                                                         Sequence<0, 1, 2>,
                                                         B0BlockTransferSrcVectorDim,
@@ -760,27 +943,69 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
                                                         1,
                                                         B0ThreadTransferSrcResetCoordinateAfterRun,
                                                         true>(
-                        b0_grid_desc_k0_l_k1,
+                        b0_grid_desc,
                         make_multi_index(0, 0, 0),
                         b0_element_op,
-                        b0_block_desc_k0perblock_lperblock_k1,
+                        b0_block_desc,
                         make_multi_index(0, 0, 0),
                         ck::tensor_operation::element_wise::PassThrough{});
+
+                return make_tuple(b0_block_buf, b0_blockwise_copy);
+            }
+            else
+            {
+                // Thread-wise copy
+                // KPerBlock/WmmaK -> LRepeat -> LWaves -> KRow -> LPerWmma -> K1
+                constexpr auto KWmmaPerBlock = KPerBlock / WmmaK;
+                auto b0_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, B0DataType>(
+                    b0_block_desc.GetElementSpaceSize());
+
+                // Limitation: NumDim of Src and Dst descriptor should be identical
+                auto b0_blockwise_copy =
+                    ThreadwiseTensorSliceTransfer_v2<B0DataType,
+                                                     B0DataType,
+                                                     decltype(b0_grid_desc),
+                                                     decltype(b0_block_desc),
+                                                     Sequence<Number<KWmmaPerBlock>{},
+                                                              Number<LRepeat>{},
+                                                              I1,
+                                                              I1,
+                                                              I1,
+                                                              Number<K1Value>{}>,
+                                                     Sequence<0, 1, 2, 3, 4, 5>,
+                                                     5,
+                                                     B0BlockTransferSrcScalarPerVector,
+                                                     B0ThreadTransferSrcResetCoordinateAfterRun,
+                                                     true>(
+                        b0_grid_desc,
+                        make_multi_index(0,
+                                         0 / (LWaves * LPerWmma),
+                                         get_thread_local_1d_id() / 32,
+                                         (get_thread_local_1d_id() % 32) / 16,
+                                         get_thread_local_1d_id() % 16,
+                                         0));
+
+                return make_tuple(b0_block_buf, b0_blockwise_copy);
+            }
+        };
 
         auto a_block_buf      = a_block_trait()[I0];
         auto a_blockwise_copy = a_block_trait()[I1];
 
+        auto b0_block_buf      = b0_block_trait()[I0];
+        auto b0_blockwise_copy = b0_block_trait()[I1];
+
         /*******************************************************************************/
         // Gemm0
         constexpr auto KPack = math::integer_least_multiple(K1Value, WmmaK);
 
         auto blockwise_gemm0 = BlockwiseGemmWMMA<
             BlockSize,
-            FloatA,
-            FloatB0,
-            FloatAcc0,
+            ADataType,
+            B0DataType,
+            Acc0DataType,
             decltype(MakeAWaveDescriptor(a_block_desc)),
-            decltype(MakeB0BlockDescriptor_K0_L0_L1_L2_K1(b0_block_desc_k0perblock_lperblock_k1)),
+            decltype(MakeB0WaveDescriptor(b0_block_desc)),
             MPerBlock,
             LPerBlock,
             KPerBlock,
...
@@ -817,15 +1042,9 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
...
@@ -817,15 +1042,9 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}));
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}));
/*******************************************************************************/
/*******************************************************************************/
// LDS allocation for A and B: be careful of alignment
auto
b0_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
FloatB0
*>
(
p_shared
)
+
SharedMemTrait
::
b0_block_space_offset
,
SharedMemTrait
::
b0_block_space_size_aligned
);
// Shift Per SUB_K
// Shift Per SUB_K
constexpr
auto
a_block_slice_copy_step
=
MakeABlockSliceCopyStep
();
constexpr
auto
a_block_slice_copy_step
=
MakeABlockSliceCopyStep
();
constexpr
auto
b0_block_slice_copy_step
=
m
ake
_multi_index
(
BK0
,
0
,
0
);
constexpr
auto
b0_block_slice_copy_step
=
M
ake
B0BlockSliceCopyStep
(
);
const
auto
a_block_reset_copy_step
=
[
&
](){
const
auto
a_block_reset_copy_step
=
[
&
](){
if
constexpr
(
AEnableLds
){
if
constexpr
(
AEnableLds
){
...
@@ -836,14 +1055,30 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
...
@@ -836,14 +1055,30 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
}
}
}();
}();
const
auto
b0_block_reset_copy_step
=
make_multi_index
(
-
b0_grid_desc_k0_l_k1
.
GetLength
(
I0
),
LPerBlock
,
0
);
const
auto
b0_block_reset_copy_step
=
[
&
](){
if
constexpr
(
B0EnableLds
){
return
make_multi_index
(
-
b0_grid_desc
.
GetLength
(
I0
),
LPerBlock
,
0
);
}
else
{
return
make_multi_index
(
-
b0_grid_desc
.
GetLength
(
I0
),
LRepeat
,
0
,
0
,
0
,
0
);
}
}();
const
auto
K
=
[
&
](){
if
constexpr
(
AEnableLds
){
return
a_grid_desc
.
GetLength
(
I0
)
*
a_grid_desc
.
GetLength
(
I2
);
}
else
{
return
a_grid_desc
.
GetLength
(
I0
)
*
a_grid_desc
.
GetLength
(
I3
)
*
a_grid_desc
.
GetLength
(
I5
);
}
}();
const
index_t
KBlockMainLoop
=
__builtin_amdgcn_readfirstlane
(
K
/
KPerBlock
);
const
index_t
KBlockMainLoop
=
__builtin_amdgcn_readfirstlane
(
K
/
KPerBlock
);
/*******************************************************************************/
/*******************************************************************************/
// softmax
// softmax
/*******************************************************************************/
/*******************************************************************************/
auto
workspace_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
auto
workspace_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
FloatAcc0
*>
(
p_shared
)
+
SharedMemTrait
::
reduction_space_offset
,
static_cast
<
Acc0DataType
*>
(
p_shared
)
+
SharedMemTrait
::
reduction_space_offset
,
SharedMemTrait
::
reduction_space_size_aligned
);
SharedMemTrait
::
reduction_space_size_aligned
);
// get acc0 7D thread cluster
// get acc0 7D thread cluster
constexpr
auto
thread_cluster_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs
=
constexpr
auto
thread_cluster_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs
=
...
@@ -879,7 +1114,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
...
@@ -879,7 +1114,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
make_tuple
(
mrepeat
*
mwave
*
mthreadpersubgroup
,
lrepeat
*
lwave
*
lsubgroup
*
laccvgprs
));
make_tuple
(
mrepeat
*
mwave
*
mthreadpersubgroup
,
lrepeat
*
lwave
*
lsubgroup
*
laccvgprs
));
auto
blockwise_softmax
=
BlockwiseSoftmax
<
BlockSize
,
auto
blockwise_softmax
=
BlockwiseSoftmax
<
BlockSize
,
FloatAcc0
,
Acc0DataType
,
decltype
(
threadid_to_l_n_thread_cluster_adaptor
),
decltype
(
threadid_to_l_n_thread_cluster_adaptor
),
decltype
(
thread_cluster_desc_m_l
),
decltype
(
thread_cluster_desc_m_l
),
decltype
(
thread_slice_desc_m_l
)
>
{};
decltype
(
thread_slice_desc_m_l
)
>
{};
...
@@ -889,15 +1124,11 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
...
@@ -889,15 +1124,11 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
SoftmaxBuf
running_sum
,
running_sum_new
,
running_max
,
running_max_new
;
SoftmaxBuf
running_sum
,
running_sum_new
,
running_max
,
running_max_new
;
running_sum
=
0
;
running_sum
=
0
;
running_sum_new
=
0
;
running_sum_new
=
0
;
running_max
=
NumericLimits
<
FloatAcc0
>::
Lowest
();
running_max
=
NumericLimits
<
Acc0DataType
>::
Lowest
();
running_max_new
=
NumericLimits
<
FloatAcc0
>::
Lowest
();
running_max_new
=
NumericLimits
<
Acc0DataType
>::
Lowest
();
/*******************************************************************************/
/*******************************************************************************/
// set up Gemm1
// set up Gemm1
/*******************************************************************************/
/*******************************************************************************/
// B1 matrix in LDS memory, dst of blockwise copy
constexpr
auto
b1_block_desc_l0perblock_nperblock_l1
=
GetB1BlockDescriptor_BL0PerBlock_NPerBlock_BL1
();
constexpr
auto
b1_block_slice_copy_step
=
make_multi_index
(
BL0
,
0
,
0
);
// Acc0 thread buffer -> A1 thread buffer -> blockwise gemm
// Acc0 thread buffer -> A1 thread buffer -> blockwise gemm
// A1 matrix in VGPR
// A1 matrix in VGPR
constexpr
auto
A1ThreadSlice_L0PerBlock_MPerBlock_L1
=
make_tuple
(
constexpr
auto
A1ThreadSlice_L0PerBlock_MPerBlock_L1
=
make_tuple
(
...
@@ -915,8 +1146,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
...
@@ -915,8 +1146,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
// A1 matrix blockwise copy
// A1 matrix blockwise copy
auto
a1_blockwise_copy
=
ThreadwiseTensorSliceTransfer_StaticToStatic
<
auto
a1_blockwise_copy
=
ThreadwiseTensorSliceTransfer_StaticToStatic
<
FloatAcc0
,
Acc0DataType
,
FloatA
,
ADataType
,
decltype
(
acc0_thread_desc_l0perblock_mperblock_l1
),
decltype
(
acc0_thread_desc_l0perblock_mperblock_l1
),
decltype
(
a1_thread_desc_l0perblock_mperblock_l1
),
decltype
(
a1_thread_desc_l0perblock_mperblock_l1
),
tensor_operation
::
element_wise
::
PassThrough
,
tensor_operation
::
element_wise
::
PassThrough
,
...
@@ -925,7 +1156,18 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
...
@@ -925,7 +1156,18 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
2
,
2
,
laccvgprs
>
{
tensor_operation
::
element_wise
::
PassThrough
{}};
laccvgprs
>
{
tensor_operation
::
element_wise
::
PassThrough
{}};
// B1 matrix blockwise copy
auto
a1_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
ADataType
>
(
a1_thread_desc_l0perblock_mperblock_l1
.
GetElementSpaceSize
());
constexpr
auto
b1_block_desc
=
MakeB1BlockDescriptor
();
auto
b1_block_trait
=
[
&
](){
if
constexpr
(
B1EnableLds
)
{
auto
b1_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
B1DataType
*>
(
p_shared
)
+
SharedMemTrait
::
b1_block_space_offset
,
SharedMemTrait
::
b1_block_space_size_aligned
);
auto
b1_blockwise_copy
=
auto
b1_blockwise_copy
=
ThreadGroupTensorSliceTransfer_v4r1
<
ThisThreadBlock
,
ThreadGroupTensorSliceTransfer_v4r1
<
ThisThreadBlock
,
/* typename SrcElementwiseOperation, */
B1ElementwiseOperation
,
/* typename SrcElementwiseOperation, */
B1ElementwiseOperation
,
...
@@ -934,10 +1176,10 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
...
@@ -934,10 +1176,10 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
/* typename BlockSliceLengths, */
Sequence
<
BL0
,
NPerBlock
,
BL1
>
,
/* typename BlockSliceLengths, */
Sequence
<
BL0
,
NPerBlock
,
BL1
>
,
/* typename ThreadClusterLengths, */
B1BlockTransferThreadClusterLengths_L0_N_L1
,
/* typename ThreadClusterLengths, */
B1BlockTransferThreadClusterLengths_L0_N_L1
,
/* typename ThreadClusterArrangeOrder, */
B1BlockTransferThreadClusterArrangeOrder
,
/* typename ThreadClusterArrangeOrder, */
B1BlockTransferThreadClusterArrangeOrder
,
/* typename SrcData, */
FloatB1
,
/* typename SrcData, */
B1DataType
,
/* typename DstData, */
FloatB1
,
/* typename DstData, */
B1DataType
,
/* typename SrcDesc, */
decltype
(
b1_grid_desc
_l0_n_l1
),
/* typename SrcDesc, */
decltype
(
b1_grid_desc
),
/* typename DstDesc, */
decltype
(
b1_block_desc
_l0perblock_nperblock_l1
),
/* typename DstDesc, */
decltype
(
b1_block_desc
),
/* typename SrcDimAccessOrder, */
B1BlockTransferSrcAccessOrder
,
/* typename SrcDimAccessOrder, */
B1BlockTransferSrcAccessOrder
,
/* typename DstDimAccessOrder, */
Sequence
<
1
,
0
,
2
>
,
/* typename DstDimAccessOrder, */
Sequence
<
1
,
0
,
2
>
,
/* index_t SrcVectorDim, */
B1BlockTransferSrcVectorDim
,
/* index_t SrcVectorDim, */
B1BlockTransferSrcVectorDim
,
...
@@ -949,26 +1191,64 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
...
@@ -949,26 +1191,64 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
/* bool ThreadTransferSrcResetCoordinateAfterRun, */
B1ThreadTransferSrcResetCoordinateAfterRun
,
/* bool ThreadTransferSrcResetCoordinateAfterRun, */
B1ThreadTransferSrcResetCoordinateAfterRun
,
/* bool ThreadTransferDstResetCoordinateAfterRun, */
true
,
// DstResetCoord
/* bool ThreadTransferDstResetCoordinateAfterRun, */
true
,
// DstResetCoord
NumGemmKPrefetchStage
>
(
NumGemmKPrefetchStage
>
(
b1_grid_desc
_l0_n_l1
,
b1_grid_desc
,
make_multi_index
(
0
,
n_block_data_idx_on_grid
,
0
),
make_multi_index
(
0
,
n_block_data_idx_on_grid
,
0
),
b1_element_op
,
b1_element_op
,
b1_block_desc
_l0perblock_nperblock_l1
,
b1_block_desc
,
make_multi_index
(
0
,
0
,
0
),
make_multi_index
(
0
,
0
,
0
),
tensor_operation
::
element_wise
::
PassThrough
{});
tensor_operation
::
element_wise
::
PassThrough
{});
auto
a1_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
FloatA
>
(
return
make_tuple
(
b1_block_buf
,
b1_blockwise_copy
);
a1_thread_desc_l0perblock_mperblock_l1
.
GetElementSpaceSize
());
}
auto
b1_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
else
static_cast
<
FloatB1
*>
(
p_shared
)
+
SharedMemTrait
::
b1_block_space_offset
,
{
SharedMemTrait
::
b1_block_space_size_aligned
);
// Thread-wise copy
// KPerBlock/WmmaK -> NRepeat -> NWaves -> WmmaK/K1 -> NPerWmma -> K1
constexpr
auto
LWmmaPerBlock
=
LTilePerBlock
/
WmmaL
;
auto
b1_block_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
B1DataType
>
(
b1_block_desc
.
GetElementSpaceSize
());
// Limitation: NumDim of Src and Dst descriptor should be identical
auto
b1_blockwise_copy
=
ThreadwiseTensorSliceTransfer_v2
<
B1DataType
,
B1DataType
,
decltype
(
b1_grid_desc
),
decltype
(
b1_block_desc
),
Sequence
<
Number
<
LWmmaPerBlock
>
{},
Number
<
NRepeat
>
{},
I1
,
I1
,
I1
,
Number
<
L1Value
>
{}
>
,
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
B1BlockTransferSrcScalarPerVector
,
B1ThreadTransferSrcResetCoordinateAfterRun
,
true
>
(
b1_grid_desc
,
make_multi_index
(
0
,
n_block_data_idx_on_grid
/
(
NWaves
*
NPerWmma
),
get_thread_local_1d_id
()
/
32
,
(
get_thread_local_1d_id
()
%
32
)
/
16
,
get_thread_local_1d_id
()
%
16
,
0
));
return
make_tuple
(
b1_block_buf
,
b1_blockwise_copy
);
}
};
auto
b1_block_buf
=
b1_block_trait
()[
I0
];
auto
b1_blockwise_copy
=
b1_block_trait
()[
I1
];
constexpr
auto
b1_block_slice_copy_step
=
MakeB1BlockSliceCopyStep
();
auto
blockwise_gemm1
=
auto
blockwise_gemm1
=
BlockwiseGemmWMMA
<
BlockSize
,
BlockwiseGemmWMMA
<
BlockSize
,
FloatA
,
ADataType
,
FloatB1
,
B1DataType
,
FloatAcc1
,
Acc1DataType
,
decltype
(
MakeA1
Block
Descriptor_L0_M0_M1_M2_L1
(
a1_thread_desc_l0perblock_mperblock_l1
)),
decltype
(
MakeA1
Wave
Descriptor_L0_M0_M1_M2_L1
(
a1_thread_desc_l0perblock_mperblock_l1
)),
decltype
(
MakeB1
Block
Descriptor
_L0_N0_N1_N2_L1
(
b1_block_desc_l0perblock_nperblock_l1
)),
decltype
(
MakeB1
Wave
Descriptor
(
b1_block_desc
)),
MPerBlock
,
MPerBlock
,
NPerBlock
,
NPerBlock
,
LTilePerBlock
,
LTilePerBlock
,
...
@@ -983,11 +1263,20 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
...
@@ -983,11 +1263,20 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
auto
acc1_thread_buf
=
blockwise_gemm1
.
GetCThreadBuffer
();
auto
acc1_thread_buf
=
blockwise_gemm1
.
GetCThreadBuffer
();
const
index_t
num_gemm1_l_block_outer_loop
=
b0_grid_desc_k0_l_k1
.
GetLength
(
I1
)
/
LPerBlock
;
const
auto
L
=
[
&
](){
if
constexpr
(
B0EnableLds
){
return
b0_grid_desc
.
GetLength
(
I1
);
}
else
{
return
b0_grid_desc
.
GetLength
(
I1
)
*
b0_grid_desc
.
GetLength
(
I2
)
*
b0_grid_desc
.
GetLength
(
I4
);
}
}();
const
index_t
num_gemm1_l_block_outer_loop
=
L
/
LPerBlock
;
constexpr
index_t
num_gemm1_l_block_inner_loop
=
LPerBlock
/
LTilePerBlock
;
constexpr
index_t
num_gemm1_l_block_inner_loop
=
LPerBlock
/
LTilePerBlock
;
// Initialize C
// Initialize C
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
FloatAcc1
,
acc1_thread_buf
.
Size
(),
true
>
c_thread_buf
;
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
Acc1DataType
,
acc1_thread_buf
.
Size
(),
true
>
c_thread_buf
;
c_thread_buf
.
Clear
();
c_thread_buf
.
Clear
();
/*******************************************************************************/
/*******************************************************************************/
...
@@ -1014,8 +1303,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
...
@@ -1014,8 +1303,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
a_grid_buf
,
a_grid_buf
,
a_block_buf
,
a_block_buf
,
a_block_slice_copy_step
,
a_block_slice_copy_step
,
b0_grid_desc
_k0_l_k1
,
b0_grid_desc
,
b0_block_desc
_k0perblock_lperblock_k1
,
b0_block_desc
,
b0_blockwise_copy
,
b0_blockwise_copy
,
b0_grid_buf
,
b0_grid_buf
,
b0_block_buf
,
b0_block_buf
,
...
@@ -1106,20 +1395,20 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
             acc1_thread_buf.Clear();

             // preload data into LDS
-            b1_blockwise_copy.RunRead(b1_grid_desc_l0_n_l1, b1_grid_buf);
+            b1_blockwise_copy.RunRead(b1_grid_desc, b1_grid_buf);

-            b1_blockwise_copy.MoveSrcSliceWindow(b1_grid_desc_l0_n_l1,
+            b1_blockwise_copy.MoveSrcSliceWindow(b1_grid_desc,
                                                  b1_block_slice_copy_step);

             block_sync_lds(); // wait for reduction LDS read

-            b1_blockwise_copy.RunWrite(b1_block_desc_l0perblock_nperblock_l1, b1_block_buf);
+            b1_blockwise_copy.RunWrite(b1_block_desc, b1_block_buf);

             // main body
             if constexpr(num_gemm1_l_block_inner_loop > 1)
             {
                 static_for<0, num_gemm1_l_block_inner_loop - 1, 1>{}([&](auto i) {
-                    // Data cast from FloatAcc0 to FloatA happen here
+                    // Data cast from Acc0DataType to ADataType happen here
                     a1_blockwise_copy.Run(acc0_thread_desc_l0perblock_mperblock_l1,
                                           make_tuple(Number<i * A1ThreadSliceL0PerBlock>{}, I0, I0),
                                           acc0_thread_buf,
...
@@ -1127,7 +1416,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
                                           make_tuple(I0, I0, I0),
                                           a1_thread_buf);

-                    b1_blockwise_copy.RunRead(b1_grid_desc_l0_n_l1, b1_grid_buf);
+                    b1_blockwise_copy.RunRead(b1_grid_desc, b1_grid_buf);

                     block_sync_lds();
...
@@ -1135,10 +1424,10 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
                     block_sync_lds();

-                    b1_blockwise_copy.MoveSrcSliceWindow(b1_grid_desc_l0_n_l1,
+                    b1_blockwise_copy.MoveSrcSliceWindow(b1_grid_desc,
                                                          b1_block_slice_copy_step);

-                    b1_blockwise_copy.RunWrite(b1_block_desc_l0perblock_nperblock_l1, b1_block_buf);
+                    b1_blockwise_copy.RunWrite(b1_block_desc, b1_block_buf);
                 });
             }
             // tail
...
@@ -1177,9 +1466,9 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
             static_for<0, c_thread_buf_slice_m, 1>{}([&](auto iM) {
                 static_for<0, c_thread_buf_slice_n, 1>{}([&](auto iN) {
                     auto I = Number<c_thread_slice_desc_m_n.CalculateOffset(make_tuple(iM, iN))>{};
-                    FloatAcc1 acc1 = acc1_thread_buf[I]; // P*V
+                    Acc1DataType acc1 = acc1_thread_buf[I]; // P*V
-                    FloatAcc1 c = c_thread_buf[I]; // O
+                    Acc1DataType c = c_thread_buf[I]; // O
-                    FloatAcc1 c_new =
+                    Acc1DataType c_new =
                         (running_sum[iM] * math::exp(running_max[iM] - running_max_new[iM]) * c +
                          math::exp(max[iM] - running_max_new[iM]) * acc1) /
                         running_sum_new[iM];
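Aside (a restatement of the code above, not text from the commit): the c_new expression is the standard online-softmax (flash-attention style) rescaling that folds the current tile's P*V partial product into the running output. Writing running_max/running_sum as m_prev/l_prev, the current tile's row maximum max[iM] as m_tile, and the merged statistics running_max_new/running_sum_new as m_new/l_new, the update is

    c_new = ( l_prev * exp(m_prev - m_new) * c + exp(m_tile - m_new) * acc1 ) / l_new

where c is the previously accumulated output row (O) and acc1 is the current tile's P*V contribution.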
...
@@ -1190,7 +1479,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
             a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc,
                                                 a_block_reset_copy_step); // rewind K
-            b0_blockwise_copy.MoveSrcSliceWindow(b0_grid_desc_k0_l_k1,
+            b0_blockwise_copy.MoveSrcSliceWindow(b0_grid_desc,
                                                  b0_block_reset_copy_step); // rewind K and step N

             // update before next j iteration
...
@@ -1220,7 +1509,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
             GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat();

         auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-            static_cast<FloatCShuffle*>(p_shared),
+            static_cast<CShuffleDataType*>(p_shared),
             c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat.GetElementSpaceSize());

         constexpr auto c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs =
             transform_tensor_descriptor(
...
@@ -1268,8 +1557,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
         // shuffle: threadwise copy C from VGPR to LDS
         auto c_thread_copy_vgpr_to_lds =
-            ThreadwiseTensorSliceTransfer_v1r3<FloatAcc1,
+            ThreadwiseTensorSliceTransfer_v1r3<Acc1DataType,
-                                               FloatCShuffle,
+                                               CShuffleDataType,
                                                decltype(c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs),
                                                decltype(c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs),
                                                ck::tensor_operation::element_wise::PassThrough,
...
@@ -1307,8 +1596,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
                          CShuffleNRepeatPerShuffle * NWave * NPerWmma>, // BlockSliceLengths,
                 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
                 Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
-                FloatCShuffle, // typename SrcData,
+                CShuffleDataType, // typename SrcData,
-                FloatC, // typename DstData,
+                CDataType, // typename DstData,
                 decltype(c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat),
                 decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
                 Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp
View file @
c5fd087e
...
@@ -719,7 +719,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
             // Thread-wise copy
             // KPerBlock/WmmaK -> NRepeat -> NWaves -> WmmaK/K1 -> NPerWmma -> K1
             constexpr auto KWmmaPerBlock = KPerBlock / WmmaK;
-            auto b_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ADataType>(
+            auto b_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
                 b_block_desc.GetElementSpaceSize());

             // Limitation: NumDim of Src and Dst descriptor should be identical
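Aside (hypothetical snippet, not CK code): the one-character fix above matters because in this path b_block_buf is a register-resident staging tile for the B operand, and in mixed-precision GEMMs the A and B element types can differ, so declaring the tile with the A type would change its byte size and the meaning of every offset computed from the B descriptor. A minimal, self-contained illustration of that size mismatch, with element types chosen purely as an example:

    #include <cstdint>

    // Illustration only: these element types are assumptions, not the kernel's real types.
    using AElem = float;       // e.g. A operand held as fp32
    using BElem = std::int8_t; // e.g. B operand held as int8

    constexpr int kTileElems = 16; // some per-thread tile size

    // Sizing the B tile with the A element type gives a different byte size,
    // so the tile must be typed with the B element type.
    static_assert(sizeof(AElem) * kTileElems != sizeof(BElem) * kTileElems,
                  "B tile must be typed with the B element type");

    int main() { return 0; }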
...
include/ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp
View file @
c5fd087e
...
@@ -247,6 +247,34 @@ struct TransformBatchedContractionContractionToBatchedGemmGemm
             make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
     }

+    template <typename BGridDesc_L_K,
+              typename WmmaK,
+              typename LRepeat,
+              typename LWaves,
+              typename LPerWmma,
+              typename BK1>
+    __host__ __device__ static constexpr auto
+    MakeB0GridDescriptor_BKWmma_LBlockRepeat_LWaves_BKRow_LPerWmma_BK1(
+        const BGridDesc_L_K& b_grid_desc_l_k,
+        const WmmaK&,
+        const LRepeat&,
+        const LWaves&,
+        const LPerWmma&,
+        const BK1&)
+    {
+        const auto L0 = b_grid_desc_l_k.GetLength(I0) / NPerBlock;
+        const auto K  = b_grid_desc_l_k.GetLength(I1);
+
+        const auto BKWmma    = K / WmmaK{};
+        constexpr auto BKRow = WmmaK{} / BK1{};
+
+        return transform_tensor_descriptor(
+            b_grid_desc_l_k,
+            make_tuple(make_unmerge_transform(make_tuple(BKWmma, BKRow, BK1{})),
+                       make_unmerge_transform(make_tuple(L0 * LRepeat{}, LWaves{}, LPerWmma{}))),
+            make_tuple(Sequence<1>{}, Sequence<0>{}),
+            make_tuple(Sequence<0, 3, 5>{}, Sequence<1, 2, 4>{}));
+    }
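Aside (illustrative values assumed, not taken from the commit): the transform above turns the 2-D (L, K) view of B0 into the 6-D (BKWmma, LBlockRepeat, LWaves, BKRow, LPerWmma, BK1) view that the LDS-skipping wave copy indexes directly, via one unmerge on each original axis. As a worked example, take

    K = 64, WmmaK = 16, BK1 = 8
    L = 256, NPerBlock = 128, LRepeat = 4, LWaves = 2, LPerWmma = 16

then

    BKWmma = K / WmmaK = 4,   BKRow = WmmaK / BK1 = 2,   L0 = L / NPerBlock = 2

and the resulting 6-D lengths are (4, L0*LRepeat = 8, 2, 2, 16, 8), with 4*2*8 = 64 = K and 8*2*16 = 256 = L, so the two unmerges cover both original axes exactly.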
     //
     // B1
     //
...
@@ -288,6 +316,34 @@ struct TransformBatchedContractionContractionToBatchedGemmGemm
             make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
     }

+    template <typename BGridDesc_N_L,
+              typename WmmaL,
+              typename NRepeat,
+              typename NWaves,
+              typename NPerWmma,
+              typename BL1>
+    __host__ __device__ static constexpr auto
+    MakeB1GridDescriptor_BLWmma_NBlockRepeat_NWaves_BLRow_NPerWmma_BL1(
+        const BGridDesc_N_L& b_grid_desc_n_l,
+        const WmmaL&,
+        const NRepeat&,
+        const NWaves&,
+        const NPerWmma&,
+        const BL1&)
+    {
+        const auto N0 = b_grid_desc_n_l.GetLength(I0) / OPerBlock;
+        const auto L  = b_grid_desc_n_l.GetLength(I1);
+
+        const auto BLWmma    = L / WmmaL{};
+        constexpr auto BLRow = WmmaL{} / BL1{};
+
+        return transform_tensor_descriptor(
+            b_grid_desc_n_l,
+            make_tuple(make_unmerge_transform(make_tuple(BLWmma, BLRow, BL1{})),
+                       make_unmerge_transform(make_tuple(N0 * NRepeat{}, NWaves{}, NPerWmma{}))),
+            make_tuple(Sequence<1>{}, Sequence<0>{}),
+            make_tuple(Sequence<0, 3, 5>{}, Sequence<1, 2, 4>{}));
+    }

     //
     // C
     //
...