gaoqiong / composable_kernel · Commits

Commit 7e003d31, authored Feb 28, 2023 by aska-0096

    Porting new blockwise gemm to flash attention

Parent: 84b4ada5

Showing 6 changed files with 209 additions and 136 deletions (+209 −136)
example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_permute_wmma_fp16.cpp (+7 −7)
example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm_permute.inc (+6 −6)
include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp (+4 −2)
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp (+51 −35)
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp (+121 −74)
include/ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp (+20 −12)
example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_permute_wmma_fp16.cpp

@@ -101,7 +101,7 @@ using DeviceGemmInstance =
         8,          // K1
         // Gemm 1
         64,         // NPerBlock
-        32,         // LPerBlock
+        32,         // LTilePerBlock
         8,          // L1
         16,         // MPerWMMA
         16,         // LPerWMMA

@@ -124,7 +124,7 @@ using DeviceGemmInstance =
         8,
         8,
         true,
-        S<4, 8, 8>, // B1BlockTransfer LN -> L0 N L1
+        S<4, 8, 8>, // B1BlockTransfer NL -> L0 N L1
         S<0, 2, 1>,
         S<0, 2, 1>,
         1,
example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm_permute.inc

@@ -122,20 +122,20 @@ int run(int argc, char* argv[])
         b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
         b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
         break;
-    case 5: // Rand: b1 ; unit: a b0  fail
+    case 5: // Rand: b1 b0 ; unit: a
         a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<ADataType>{});
-        b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
+        b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
         b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 2});
         break;
-    case 6: // Rand: b0 ; unit: a b1  pass
-        a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_1<ADataType>{});
+    case 6: // Rand: a b0 ; unit: b1  pass
+        a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
         b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
         b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
         break;
-    case 7: // Rand: a ; unit: b0 b1  pass
+    case 7: // Rand: a b1 ; unit: b0  pass
         a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_2<ADataType>{-2, 2});
         b0_gs_ns_ks.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
-        b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
+        b1_gs_os_ns.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-2, 2});
         break;
    default:
        a_gs_ms_ks.GenerateTensorValue(GeneratorTensor_Sequential<2>{});
include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp

@@ -7,6 +7,7 @@
 #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
 #include "ck/tensor_operation/gpu/warp/wmma_gemm.hpp"
 #include "ck/tensor_description/tensor_adaptor.hpp"
+#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"

 #define CK_MNK_LOOP

@@ -340,6 +341,7 @@ struct BlockwiseGemmWMMA
                               b_thread_desc_,
                               make_tuple(I0, n0, I0, I0, I0),
                               b_thread_buf);

                vector_type<FloatA, WmmaK> a_thread_vec;
                vector_type<FloatB, WmmaK> b_thread_vec;

@@ -413,7 +415,7 @@ struct BlockwiseGemmWMMA
                                   A_K1,
                                   0x76543210,
                                   0xfedcba98,
-                                  true>;
+                                  TransposeC ? false : true>;
    };

    template <bool EnableLds>

@@ -448,7 +450,7 @@ struct BlockwiseGemmWMMA
                                   B_K1,
                                   0x76543210,
                                   0xfedcba98,
-                                  false>;
+                                  TransposeC ? true : false>;
    };

    typename AThreadCopySelector<AEnableLds>::type a_thread_copy_;
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp

@@ -56,11 +56,11 @@ template <index_t NumDimG,
          ck::index_t KPerBlock,
          ck::index_t K1,
          ck::index_t NPerBlock,
          ck::index_t LPerBlock,
+         ck::index_t LTilePerBlock,
          ck::index_t L1,
-         ck::index_t MPerWMMA,
-         ck::index_t LPerWMMA,
-         ck::index_t NPerWMMA,
+         ck::index_t MPerWmma,
+         ck::index_t LPerWmma,
+         ck::index_t NPerWmma,
          ck::index_t MRepeat,
          ck::index_t LRepeat,
          ck::index_t NRepeat,
@@ -134,15 +134,18 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};
    static constexpr auto I4 = Number<4>{};
    static constexpr auto I5 = Number<5>{};
-   static constexpr auto WmmaK = 16;

    static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma);
    static constexpr auto LWaves = LPerBlock / (LRepeat * LPerWmma);
    static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma);
+   static constexpr auto WmmaK  = 16;

-   // static constexpr auto B0EnableLds = MWaves == 1 ? false : true;
-   // static constexpr auto B1EnableLds = MWaves == 1 ? false : true;
+   static constexpr auto AEnableLds  = LWaves == 1 ? false : true;
+   static constexpr auto B0EnableLds = MWaves == 1 ? false : true;
+   static constexpr auto B1EnableLds = MWaves == 1 ? false : true;

    using Transform = TransformBatchedContractionContractionToBatchedGemmGemm<
        Sequence<NumDimG, NumDimM, NumDimL, NumDimK, NumDimN>,
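The new AEnableLds / B0EnableLds / B1EnableLds constants tie the LDS-staging decision to the wave decomposition computed just above. A rough standalone sketch of that relationship, with illustrative tile sizes (none of these numbers are taken from the commit):

    #include <cstdio>

    int main()
    {
        // Illustrative tile shape, not from the commit.
        constexpr int MPerBlock = 128, MRepeat = 2, MPerWmma = 16;
        constexpr int LPerBlock = 32,  LRepeat = 1, LPerWmma = 16;

        constexpr int MWaves = MPerBlock / (MRepeat * MPerWmma); // 4
        constexpr int LWaves = LPerBlock / (LRepeat * LPerWmma); // 2

        // Mirrors the flags above: LDS staging is skipped only when a single
        // wave covers the corresponding dimension.
        constexpr bool AEnableLds  = LWaves == 1 ? false : true;
        constexpr bool B0EnableLds = MWaves == 1 ? false : true;

        std::printf("MWaves=%d LWaves=%d AEnableLds=%d B0EnableLds=%d\n",
                    MWaves, LWaves, AEnableLds, B0EnableLds);
        return 0;
    }

With a single wave in the relevant dimension, each wave presumably already owns the whole tile it needs, so the copy can stay in registers; otherwise the tile is staged through LDS and shared.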
@@ -166,12 +169,15 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
    {
        return Transform::MakeAGridDescriptor_AKWmma_MBlockRepeat_MWaves_AKRow_MPerWmma_AK1(
            Transform::MakeAGridDescriptor_M_K(a_gs_ms_ks_lengths_vec, a_gs_ms_ks_strides_vec),
-           WmmaK,
-           Number<MRepeat>{},
-           Number<MWaves>{},
-           Number<MPerWmma>{},
-           Number<K1>{})
+           Number<WmmaK>{},
+           Number<MRepeat>{},
+           Number<MWaves>{},
+           Number<MPerWmma>{},
+           Number<K1>{});
    }

-   static auto MakeB0GridDescriptor_BK0_L_BK1(const std::vector<index_t>& b0_gs_ls_ks_lengths_vec,
+   static auto MakeB0GridDescriptor(const std::vector<index_t>& b0_gs_ls_ks_lengths_vec,
                                     const std::vector<index_t>& b0_gs_ls_ks_strides_vec)
    {
        return Transform::MakeB0GridDescriptor_BK0_N_BK1(

@@ -188,7 +194,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
    }

    using AGridDesc            = decltype(MakeAGridDescriptor({}, {}));
-   using B0GridDesc_BK0_L_BK1 = decltype(MakeB0GridDescriptor_BK0_L_BK1({}, {}));
+   using B0GridDesc_BK0_L_BK1 = decltype(MakeB0GridDescriptor({}, {}));
    using B1GridDesc_BL0_N_BL1 = decltype(MakeB1GridDescriptor_BL0_N_BL1({}, {}));
    using CGridDesc_M_N        = decltype(Transform::MakeCGridDescriptor_M_N({}, {}));
    using AGridDesc_G_M_K      = decltype(Transform::MakeAGridDescriptor_G_M_K({}, {}));

@@ -277,11 +283,11 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
        KPerBlock,
        K1,
        NPerBlock,
        LPerBlock,
+       LTilePerBlock,
        L1,
-       MPerWMMA,
-       LPerWMMA,
-       NPerWMMA,
+       MPerWmma,
+       LPerWmma,
+       NPerWmma,
        MRepeat,
        LRepeat,
        NRepeat,

@@ -357,10 +363,9 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
              p_b0_grid_{p_b0_grid},
              p_b1_grid_{p_b1_grid},
              p_c_grid_{p_c_grid},
-             a_grid_desc_ak0_m_ak1_{
-                 DeviceOp::MakeAGridDescriptor(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)},
-             b0_grid_desc_bk0_l_bk1_{DeviceOp::MakeB0GridDescriptor_BK0_L_BK1(
-                 b0_gs_ls_ks_lengths, b0_gs_ls_ks_strides)},
+             a_grid_desc{DeviceOp::MakeAGridDescriptor(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)},
+             b0_grid_desc_bk0_l_bk1_{
+                 DeviceOp::MakeB0GridDescriptor(b0_gs_ls_ks_lengths, b0_gs_ls_ks_strides)},
              b1_grid_desc_bl0_n_bl1_{
                  DeviceOp::MakeB1GridDescriptor_BL0_N_BL1(b1_gs_ns_ls_lengths, b1_gs_ns_ls_strides)},
              c_grid_desc_m_n_{

@@ -405,7 +410,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
            ignore = acc1_biases_gs_ms_ns_lengths;
            ignore = acc1_biases_gs_ms_ns_strides;

-           if(GridwiseOp::CheckValidity(a_grid_desc_ak0_m_ak1_,
+           if(GridwiseOp::CheckValidity(a_grid_desc,
                                         b0_grid_desc_bk0_l_bk1_,
                                         b1_grid_desc_bl0_n_bl1_,
                                         c_grid_desc_m_n_,

@@ -424,7 +429,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
        CDataType* p_c_grid_;
        // Tensor Descriptors
-       AGridDesc a_grid_desc_ak0_m_ak1_;
+       AGridDesc a_grid_desc;
        B0GridDesc_BK0_L_BK1 b0_grid_desc_bk0_l_bk1_;
        B1GridDesc_BL0_N_BL1 b1_grid_desc_bl0_n_bl1_;
        CGridDesc_M_N c_grid_desc_m_n_;
@@ -473,8 +478,17 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
            const index_t grid_size =
                arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * arg.batch_count_;

-           const auto K =
-               arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
+           const auto K = [&]() {
+               if constexpr(AEnableLds)
+               {
+                   return arg.a_grid_desc.GetLength(I0) * arg.a_grid_desc.GetLength(I2);
+               }
+               else
+               {
+                   return arg.a_grid_desc.GetLength(I0) * arg.a_grid_desc.GetLength(I3) *
+                          arg.a_grid_desc.GetLength(I5);
+               }
+           }();

            auto launch_kernel = [&](auto has_main_k_block_loop) {
                const auto kernel = kernel_batched_gemm_softmax_gemm_wmma_cshuffle<
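The invoker now recovers the total K extent from whichever A-descriptor layout is active: the 3-D (AK0, M, AK1) form when A is staged through LDS, or the 6-D (KWmma, MRepeat, MWaves, KRow, MPerWmma, K1) form when it is read directly. A minimal sketch of the same arithmetic with made-up descriptor lengths (not from the commit):

    #include <array>
    #include <cstdio>

    int main()
    {
        // Hypothetical descriptor lengths.
        // LDS path:    (AK0, M, AK1)                                -> K = AK0 * AK1
        // direct path: (KWmma, MRepeat, MWaves, KRow, MPerWmma, K1) -> K = KWmma * KRow * K1
        constexpr bool AEnableLds = false;
        const std::array<int, 6> len = {4, 2, 4, 2, 16, 8};

        const int K = AEnableLds ? len[0] * len[2] : len[0] * len[3] * len[5];
        std::printf("K = %d\n", K); // 4 * 2 * 8 = 64 on the direct path
        return 0;
    }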
@@ -506,7 +520,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
                    arg.p_b0_grid_,
                    arg.p_b1_grid_,
                    arg.p_c_grid_,
-                   arg.a_grid_desc_ak0_m_ak1_,
+                   arg.a_grid_desc,
                    arg.b0_grid_desc_bk0_l_bk1_,
                    arg.b1_grid_desc_bl0_n_bl1_,
                    arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,

@@ -551,20 +565,23 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
        {
            if constexpr(!(is_same_v<Acc0DataType, float> || is_same_v<Acc0DataType, int32_t>))
            {
+               printf("DeviceOp: Acc0 Type err");
                return false;
            }

            if constexpr(!(is_same_v<Acc1DataType, float> || is_same_v<Acc1DataType, int32_t>))
            {
+               printf("DeviceOp: Acc1 Type err");
                return false;
            }
        }
        else
        {
+           printf("DeviceOp: Arch err");
            return false;
        }

-       if(!GridwiseOp::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
+       if(!GridwiseOp::CheckValidity(arg.a_grid_desc,
                                      arg.b0_grid_desc_bk0_l_bk1_,
                                      arg.b1_grid_desc_bl0_n_bl1_,
                                      arg.c_grid_desc_m_n_,

@@ -575,13 +592,10 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
        // Check if C permute dimension matches GEMM + GEMM shape
        const index_t c_g = arg.c_grid_desc_g_m_n_.GetLength(I0); // unpadded
-       const index_t c_m = arg.c_grid_desc_m_n_.GetLength(I0);
-       const index_t c_n = arg.c_grid_desc_m_n_.GetLength(I1);
-       const index_t a_m = arg.a_grid_desc_ak0_m_ak1_.GetLength(I1);
-       const index_t b1_n = arg.b1_grid_desc_bl0_n_bl1_.GetLength(I1);

-       if(!(c_g == arg.batch_count_ && c_m == a_m && c_n == b1_n))
+       if(!(c_g == arg.batch_count_))
        {
+           printf("DeviceOp: BatchCount err");
            return false;
        }

@@ -604,6 +618,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
             b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 &&
             c_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0))
        {
+           printf("DeviceOp: Data Transfer Vector scalar err");
            return false;
        }

@@ -619,6 +634,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
        if(!(a_stride_lowest == 1 || b0_stride_lowest == 1 || b1_stride_lowest == 1 ||
             c_stride_lowest == 1))
        {
+           printf("DeviceOp: Data Vectorize transfer err");
            return false;
        }

@@ -765,7 +781,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
            << K1 << ", "
            << MPerBlock << ", "
            << NPerBlock << ", "
            << LPerBlock << ", "
+           << LTilePerBlock << ", "
            << L1 <<
            getGemmSpecializationString(GemmSpec) << ", "
            << "ASpec" << getTensorSpecializationString(ASpec) << ", "
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp

@@ -23,7 +23,7 @@ template <typename GridwiseGemm,
          typename FloatB0,
          typename FloatB1,
          typename FloatC,
-         typename AGridDesc_AK0_M_AK1,
+         typename AGridDesc,
          typename B0GridDesc_BK0_L_BK1,
          typename B1GridDesc_BL0_N_BL1,
          typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,

@@ -45,7 +45,7 @@ __global__ void
        const FloatB0* __restrict__ p_b0_grid,
        const FloatB1* __restrict__ p_b1_grid,
        FloatC* __restrict__ p_c_grid,
-       const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
+       const AGridDesc a_grid_desc,
        const B0GridDesc_BK0_L_BK1 b0_grid_desc_bk0_l_bk1,
        const B1GridDesc_BL0_N_BL1 b1_grid_desc_l0_n_l1,
        const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock

@@ -81,7 +81,7 @@ __global__ void
        p_b1_grid + b1_batch_offset,
        p_c_grid + c_batch_offset,
        p_shared,
-       a_grid_desc_ak0_m_ak1,
+       a_grid_desc,
        b0_grid_desc_bk0_l_bk1,
        b1_grid_desc_l0_n_l1,
        c_grid_desc_mblock_mperblock_nblock_nperblock,

@@ -97,7 +97,7 @@ __global__ void
    ignore = p_b0_grid;
    ignore = p_b1_grid;
    ignore = p_c_grid;
-   ignore = a_grid_desc_ak0_m_ak1;
+   ignore = a_grid_desc;
    ignore = b0_grid_desc_bk0_l_bk1;
    ignore = b1_grid_desc_l0_n_l1;
    ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;

@@ -128,7 +128,7 @@ template <typename FloatA,
          typename B1ElementwiseOperation,
          typename CElementwiseOperation,
          InMemoryDataOperationEnum CGlobalMemoryDataOperation,
-         typename AGridDesc_AK0_M_AK1,
+         typename AGridDesc,
          typename B0GridDesc_BK0_L_BK1,
          typename B1GridDesc_BL0_N_BL1,
          typename CGridDesc_M_N,

@@ -137,7 +137,7 @@ template <typename FloatA,
          index_t KPerBlock,
          index_t K1Value,
          index_t NPerBlock,
          index_t LPerBlock,
+         index_t LTilePerBlock,
          index_t L1Value,
          index_t MPerWmma,
          index_t LPerWmma,

@@ -194,10 +194,10 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
    static constexpr auto I7 = Number<7>{};

    static constexpr auto AK1 = Number<K1Value>{};
    static constexpr auto BK0 = Number<KPerBlock / K1Value>{};
    static constexpr auto BK1 = Number<K1Value>{};

-   static constexpr auto L0PerBlock = LPerBlock / L1Value;
+   static constexpr auto L0PerBlock = LTilePerBlock / L1Value;
    static constexpr auto AL0 = Number<L0PerBlock / 2>{};
    static constexpr auto AL1 = Number<L1Value>{};
    static constexpr auto BL0 = Number<L0PerBlock>{};

@@ -209,8 +209,12 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
    using ThisThreadBlock = ThisThreadBlock<BlockSize>;

-   using GridwiseGemmPipe = remove_cvref_t<decltype(GridwiseGemmPipeline_Selector<
-       PipelineVer, AEnableLds, B0EnableLds, NumGemmKPrefetchStage, LoopSched>())>;
+   using GridwiseGemmPipe =
+       remove_cvref_t<decltype(GridwiseGemmPipeline_Selector<PipelineVer,
+                                                             AEnableLds,
+                                                             B0EnableLds,
+                                                             NumGemmKPrefetchStage,
+                                                             LoopSched>())>;

    __host__ __device__ static constexpr auto MakeABlockDescriptor()
    {

@@ -238,7 +242,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
            constexpr auto KWmmaPerblock = KPerBlock / WmmaK;
            // KWmma->MRepeat->MWave->KRow->MPerWmma->K1 Per Thread
            return make_naive_tensor_descriptor(
-               make_tuple(Number<KWmmaPerblock>{}, Number<MRepeat>{}, I1, I1, I1, K1),
+               make_tuple(Number<KWmmaPerblock>{}, Number<MRepeat>{}, I1, I1, I1, AK1),
                make_tuple(Number<MRepeat>{} * AK1, AK1, AK1, AK1, AK1, I1));
        }
    }();

@@ -351,7 +355,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
    {
        constexpr index_t B_K0   = B1BlockDesc_BL0_N_BL1{}.GetLength(I0);
        constexpr index_t B_K1   = B1BlockDesc_BL0_N_BL1{}.GetLength(I2);
        constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWmma);

        return transform_tensor_descriptor(
            B1BlockDesc_BL0_N_BL1{},
            make_tuple(make_pass_through_transform(Number<B_K0>{}),
@@ -399,16 +403,19 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
    {
        // LDS allocation for A and B: be careful of alignment
        const index_t gemm0_bytes_end =
-           (SharedMemTrait::a_block_space_size_aligned +
-            SharedMemTrait::b0_block_space_size_aligned);
+           (SharedMemTrait::a_block_space_size_aligned * sizeof(FloatA) +
+            SharedMemTrait::b0_block_space_size_aligned * sizeof(FloatB0));
        const index_t gemm1_bytes_end =
-           (SharedMemTrait::b1_block_space_offset +
-            SharedMemTrait::b1_block_space_size_aligned);
+           (SharedMemTrait::b1_block_space_offset +
+            SharedMemTrait::b1_block_space_size_aligned * sizeof(FloatB1));
        const index_t softmax_bytes_end =
-           SharedMemTrait::reduction_space_offset +
-           SharedMemTrait::reduction_space_size_aligned;
+           SharedMemTrait::reduction_space_offset +
+           SharedMemTrait::reduction_space_size_aligned * sizeof(FloatAcc0);
-       const index_t c_block_bytes_end = SharedMemTrait::c_block_space_size;
+       const index_t c_block_bytes_end =
+           SharedMemTrait::c_block_space_size * sizeof(FloatCShuffle);

        return math::max(gemm0_bytes_end, gemm1_bytes_end, softmax_bytes_end, c_block_bytes_end);
    }
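GetSharedMemoryNumberOfByte now scales each aligned element count by the size of its element type instead of treating the counts as bytes. A small standalone sketch of the resulting budget, assuming fp16 inputs and an fp32 softmax accumulator (all sizes are illustrative, not from the commit):

    #include <algorithm>
    #include <cstdio>

    int main()
    {
        // Hypothetical aligned element counts per stage.
        const int a_elems = 128 * 32, b0_elems = 128 * 32, b1_elems = 32 * 64;
        const int reduction_elems = 256, c_shuffle_elems = 128 * 64;

        const int gemm0_bytes    = a_elems * 2 + b0_elems * 2; // fp16 A and B0
        const int gemm1_bytes    = b1_elems * 2;               // fp16 B1
        const int softmax_bytes  = reduction_elems * 4;        // fp32 accumulator
        const int cshuffle_bytes = c_shuffle_elems * 2;        // fp16 shuffle buffer

        // The stages reuse the same LDS, so only the maximum is allocated.
        const int lds_bytes =
            std::max({gemm0_bytes, gemm1_bytes, softmax_bytes, cshuffle_bytes});
        std::printf("LDS budget: %d bytes\n", lds_bytes);
        return 0;
    }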
@@ -416,7 +423,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
    // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
    template <typename Block2CTileMap>
    __host__ __device__ static constexpr bool
-   CheckValidity(const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
+   CheckValidity(const AGridDesc& a_grid_desc,
                  const B0GridDesc_BK0_L_BK1& b0_grid_desc_bk0_l_bk1,
                  const B1GridDesc_BL0_N_BL1& b1_grid_desc_l0_n_l1,
                  const CGridDesc_M_N& c_grid_desc_m_n,

@@ -426,19 +433,48 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
                          (LPerBlock % (LPerWmma * LRepeat)) == 0,
                      "Invalid tuning param!");

-       const auto M = a_grid_desc_ak0_m_ak1.GetLength(I1);
+       const auto GetAProblemsizeMK = [&]() {
+           if constexpr(AEnableLds)
+           {
+               return make_tuple(a_grid_desc.GetLength(I1),
+                                 a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I2));
+           }
+           else
+           {
+               return make_tuple(a_grid_desc.GetLength(I1) * a_grid_desc.GetLength(I2) *
+                                     a_grid_desc.GetLength(I4),
+                                 a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I3) *
+                                     a_grid_desc.GetLength(I5));
+           }
+       };
+
+       const auto M = GetAProblemsizeMK()[I0];
        const auto L = b0_grid_desc_bk0_l_bk1.GetLength(I1);
-       const auto K =
-           a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2);
+       const auto K = GetAProblemsizeMK()[I1];
        const auto N = b1_grid_desc_l0_n_l1.GetLength(I1);
-       const auto KPerBlock = K0PerBlock * K1Value;

        if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1)))
        {
+           printf("GridwiseOp: M/N Length err, A_M/N = %d, %d | C_M/N = %d, %d\n",
+                  M,
+                  N,
+                  c_grid_desc_m_n.GetLength(I0),
+                  c_grid_desc_m_n.GetLength(I1));
            return false;
        }

        if(!(M % MPerBlock == 0 && L % LPerBlock == 0 && K % KPerBlock == 0 && N % NPerBlock == 0))
        {
+           printf("GridwiseOp: M/L/K/N Division err, M/L/K/N = %d, %d, %d, %d | "
+                  "M/L/K/NPerBlock = %d, %d, %d, %d\n",
+                  M, L, K, N, MPerBlock, LPerBlock, KPerBlock, NPerBlock);
            return false;
        }
@@ -446,18 +482,23 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
        const auto num_gemm0_k_loop = K / KPerBlock;
        if(!GridwiseGemmPipe::IsSupported(num_gemm0_k_loop))
        {
+           printf("GridwiseOp: outer loop unsupport\n");
            return false;
        }

        // check gemm1 gridwise gemm pipeline
-       if(!(LPerBlock % (L0PerBlock * L1Value) == 0))
+       if(!(LPerBlock % LTilePerBlock == 0))
        {
+           printf("GridwiseOp: inner loop division, L/LTilePerblock: %d, %d\n",
+                  LPerBlock,
+                  LTilePerBlock);
            return false;
        }

-       const auto num_gemm1_k_inner_loop = LPerBlock / (L0PerBlock * L1Value);
+       const auto num_gemm1_k_inner_loop = LPerBlock / LTilePerBlock;
        if(!GridwiseGemmPipe::IsSupported(num_gemm1_k_inner_loop))
        {
+           printf("GridwiseOp: inner loop unsupport\n");
            return false;
        }
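The gemm1 pipeline check is now phrased directly in terms of LTilePerBlock. A tiny sketch of the same divisibility check with example tuning values (not taken from the commit):

    #include <cstdio>

    int main()
    {
        constexpr int LPerBlock = 128, LTilePerBlock = 32; // illustrative values

        if(LPerBlock % LTilePerBlock != 0)
        {
            std::printf("GridwiseOp: inner loop division, L/LTilePerblock: %d, %d\n",
                        LPerBlock, LTilePerBlock);
            return 1;
        }

        const int num_gemm1_k_inner_loop = LPerBlock / LTilePerBlock;
        std::printf("num_gemm1_k_inner_loop = %d\n", num_gemm1_k_inner_loop); // 4
        return 0;
    }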
@@ -472,7 +513,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
    __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
    {
-       const index_t num_loop = K / (K0PerBlock * K1Value);
+       const index_t num_loop = K / KPerBlock;

        return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
    }

@@ -514,28 +555,38 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
        // LDS allocation for A and B: be careful of alignment
        static constexpr auto max_lds_align = math::lcm(math::lcm(AK1, BK1), BL1);

-       static constexpr auto a_block_space_size_aligned =
-           AEnableLds ? math::integer_least_multiple(
-                            MakeABlockDescriptor().GetElementSpaceSize() * sizeof(FloatA),
-                            max_lds_align)
-                      : 0;
-       static constexpr auto b0_block_space_size_aligned =
-           B0EnableLds ? math::integer_least_multiple(
-                             GetB0BlockDescriptor_BK0PerBlock_LPerBlock_BK1().GetElementSpaceSize() *
-                                 sizeof(FloatB0),
-                             max_lds_align)
-                       : 0;
-       static constexpr auto b1_block_space_size_aligned =
-           B1EnableLds ? math::integer_least_multiple(
-                             GetB1BlockDescriptor_BL0PerBlock_NPerBlock_BL1().GetElementSpaceSize() *
-                                 sizeof(FloatB1),
-                             max_lds_align)
-                       : 0;
+       static constexpr auto a_block_space_size_aligned =
+           AEnableLds ? math::integer_least_multiple(MakeABlockDescriptor().GetElementSpaceSize(),
+                                                     max_lds_align)
+                      : 0;
+       static constexpr auto b0_block_space_size_aligned =
+           B0EnableLds
+               ? math::integer_least_multiple(
+                     GetB0BlockDescriptor_BK0PerBlock_LPerBlock_BK1().GetElementSpaceSize(),
+                     max_lds_align)
+               : 0;
+       static constexpr auto b1_block_space_size_aligned =
+           B1EnableLds
+               ? math::integer_least_multiple(
+                     GetB1BlockDescriptor_BL0PerBlock_NPerBlock_BL1().GetElementSpaceSize(),
+                     max_lds_align)
+               : 0;

        static constexpr auto a_block_space_offset  = 0;
-       static constexpr auto b0_block_space_offset = a_block_space_size_aligned.value;
+       static constexpr auto b0_block_space_offset = a_block_space_size_aligned;
        static constexpr auto b1_block_space_offset = 0;

        // LDS allocation for reduction
        // Feature to add, IntraThread Reduction
        static constexpr index_t reduction_space_size_aligned =
-           math::integer_least_multiple(BlockSize, max_lds_align) * sizeof(FloatAcc0);
+           math::integer_least_multiple(BlockSize, max_lds_align);

        static constexpr auto reduction_space_offset = 0;

        // LDS allocation for C shuffle in LDS
        static constexpr auto c_block_space_size =
            GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat()
-               .GetElementSpaceSize() * sizeof(FloatCShuffle);
+               .GetElementSpaceSize();
    };

@@ -546,7 +597,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
        const FloatB1* __restrict__ p_b1_grid,
        FloatC* __restrict__ p_c_grid,
        void* __restrict__ p_shared,
-       const AGridDesc_AK0_M_AK1& a_grid_desc_k0_m_k1,
+       const AGridDesc& a_grid_desc,
        const B0GridDesc_BK0_L_BK1& b0_grid_desc_k0_l_k1,
        const B1GridDesc_BL0_N_BL1& b1_grid_desc_l0_n_l1,
        const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&

@@ -563,7 +614,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
        /*******************************************************************************/
        // Memory buffer zone.
        const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-           p_a_grid, a_grid_desc_k0_m_k1.GetElementSpaceSize());
+           p_a_grid, a_grid_desc.GetElementSpaceSize());
        const auto b0_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_b0_grid, b0_grid_desc_k0_l_k1.GetElementSpaceSize());
        const auto b1_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(

@@ -610,17 +661,18 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
        auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<FloatA*>(p_shared) + SharedMemTrait::a_block_space_offset,
            SharedMemTrait::a_block_space_size_aligned);

        auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1<
            ThisThreadBlock,
            /* typename SrcElementwiseOperation,              */ AElementwiseOperation,
            /* typename DstElementwiseOperation,              */ ck::tensor_operation::element_wise::PassThrough,
            /* InMemoryDataOperationEnum DstInMemOp,          */ InMemoryDataOperationEnum::Set,
-           /* typename BlockSliceLengths,                    */ Sequence<AK0, MPerBlock, AK1>,
+           /* typename BlockSliceLengths,                    */ Sequence<AK0PerBlock, MPerBlock, AK1>,
            /* typename ThreadClusterLengths,                 */ ABlockTransferThreadClusterLengths_K0_M_K1,
            /* typename ThreadClusterArrangeOrder,            */ ABlockTransferThreadClusterArrangeOrder,
            /* typename SrcData,                              */ FloatA,
            /* typename DstData,                              */ FloatA,
-           /* typename SrcDesc,                              */ decltype(a_grid_desc_k0_m_k1),
+           /* typename SrcDesc,                              */ decltype(a_grid_desc),
            /* typename DstDesc,                              */ decltype(a_block_desc),
            /* typename SrcDimAccessOrder,                    */ ABlockTransferSrcAccessOrder,
            /* typename DstDimAccessOrder,                    */ Sequence<0, 1, 2>,

@@ -632,7 +684,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
            /* index_t DstScalarStrideInVector,               */ 1,
            /* bool ThreadTransferSrcResetCoordinateAfterRun, */ AThreadTransferSrcResetCoordinateAfterRun,
            /* bool ThreadTransferDstResetCoordinateAfterRun, */ true>(
-           a_grid_desc_k0_m_k1,
+           a_grid_desc,
            make_multi_index(0, m_block_data_idx_on_grid, 0),
            a_element_op,
            a_block_desc,

@@ -713,7 +765,6 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
        /*******************************************************************************/
        // Gemm0
-       constexpr auto WmmaK = 16;
        constexpr auto KPack = math::integer_least_multiple(K1Value, WmmaK);

        auto blockwise_gemm0 = BlockwiseGemmWMMA<

@@ -725,7 +776,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
            decltype(MakeB0BlockDescriptor_K0_L0_L1_L2_K1(b0_block_desc_k0perblock_lperblock_k1)),
            MPerBlock,
            LPerBlock,
-           K0PerBlock * K1Value,
+           KPerBlock,
            MPerWmma,
            LPerWmma,
            MRepeat,

@@ -759,18 +810,20 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
        /*******************************************************************************/
        // LDS allocation for A and B: be careful of alignment
        auto b0_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<FloatB0*>(p_shared) + SharedMemTrait::b0_block_space_offset,
-           b0_block_desc_k0perblock_lperblock_k1.GetElementSpaceSize());
+           SharedMemTrait::b0_block_space_size_aligned);

        // Shift Per SUB_K
        constexpr auto a_block_slice_copy_step  = MakeABlockSliceCopyStep();
-       constexpr auto b0_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
+       constexpr auto b0_block_slice_copy_step = make_multi_index(BK0, 0, 0);

        const auto a_block_reset_copy_step = [&]() {
            if constexpr(AEnableLds)
            {
-               return make_multi_index(-a_grid_desc_k0_m_k1.GetLength(I0), 0, 0);
+               return make_multi_index(-a_grid_desc.GetLength(I0), 0, 0);
            }
            else
            {
-               return make_multi_index(-a_grid_desc_k0_m_k1.GetLength(I0), 0, 0, 0, 0, 0);
+               return make_multi_index(-a_grid_desc.GetLength(I0), 0, 0, 0, 0, 0);
            }
        }();

@@ -836,24 +889,23 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
        constexpr auto b1_block_desc_l0perblock_nperblock_l1 =
            GetB1BlockDescriptor_BL0PerBlock_NPerBlock_BL1();
        constexpr auto b1_block_slice_copy_step = make_multi_index(BL0, 0, 0);

        // Acc0 thread buffer -> A1 thread buffer -> blockwise gemm
        // A1 matrix in VGPR
        constexpr auto A1ThreadSlice_L0PerBlock_MPerBlock_L1 =
            make_tuple(Number<AL0 * AL1 / laccvgprs>{},
                       Number<mrepeat * mwave * mthreadpersubgroup>{},
-                      Number<laccvgprs>{});
+                      // Data duplicated dimension
+                      Number<laccvgprs>{});

        constexpr auto A1ThreadSliceL0PerBlock = A1ThreadSlice_L0PerBlock_MPerBlock_L1[I0];
        constexpr auto A1ThreadSliceMPerBlock  = A1ThreadSlice_L0PerBlock_MPerBlock_L1[I1];
        constexpr auto A1ThreadSliceL1         = A1ThreadSlice_L0PerBlock_MPerBlock_L1[I2];
-       // A1 has duplicated data
-       constexpr auto A1ThreadDuplicatedDim = I2 * A1ThreadSliceL1;
        constexpr auto a1_thread_desc_l0perblock_mperblock_l1 = make_naive_tensor_descriptor(
-           make_tuple(A1ThreadSliceL0PerBlock, A1ThreadSliceMPerBlock, A1ThreadDuplicatedDim),
-           make_tuple(A1ThreadSliceMPerBlock * A1ThreadDuplicatedDim, A1ThreadDuplicatedDim, I1));
+           make_tuple(A1ThreadSliceL0PerBlock, A1ThreadSliceMPerBlock, A1ThreadSliceL1),
+           make_tuple(A1ThreadSliceMPerBlock * A1ThreadSliceL1, A1ThreadSliceL1, I1));

        // A1 matrix blockwise copy
-       auto a1_blockwise_copy = ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow<
+       auto a1_blockwise_copy = ThreadwiseTensorSliceTransfer_StaticToStatic<
            FloatAcc0,
            FloatA,
            decltype(acc0_thread_desc_l0perblock_mperblock_l1),

@@ -862,12 +914,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
            Sequence<A1ThreadSliceL0PerBlock, A1ThreadSliceMPerBlock, A1ThreadSliceL1>,
            Sequence<0, 1, 2>,
            2,
-           laccvgprs,
-           // dst Rowlane
-           // 0x76543210 0xfedcba98
-           // src Rowlane
-           0x76543210,
-           0xfedcba98,
-           false>{};
+           laccvgprs>{tensor_operation::element_wise::PassThrough{}};

        // B1 matrix blockwise copy
        auto b1_blockwise_copy =

@@ -904,7 +951,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
            a1_thread_desc_l0perblock_mperblock_l1.GetElementSpaceSize());

        auto b1_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<FloatB1*>(p_shared) + SharedMemTrait::b1_block_space_offset,
-           b1_block_desc_l0perblock_nperblock_l1.GetElementSpaceSize());
+           SharedMemTrait::b1_block_space_size_aligned);

        auto blockwise_gemm1 = BlockwiseGemmWMMA<
            BlockSize,

@@ -915,7 +962,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
            decltype(MakeB1BlockDescriptor_L0_N0_N1_N2_L1(b1_block_desc_l0perblock_nperblock_l1)),
            MPerBlock,
            NPerBlock,
-           BL0 * BL1,
+           LTilePerBlock,
            MPerWmma,
            NPerWmma,
            MRepeat,

@@ -926,13 +973,16 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
        auto acc1_thread_buf = blockwise_gemm1.GetCThreadBuffer();

        const index_t num_gemm1_l_block_outer_loop =
            b0_grid_desc_k0_l_k1.GetLength(I1) / LPerBlock;
-       constexpr index_t num_gemm1_l_block_inner_loop = LPerBlock / (BL0 * BL1);
+       constexpr index_t num_gemm1_l_block_inner_loop = LPerBlock / LTilePerBlock;

        // Initialize C
        StaticBuffer<AddressSpaceEnum::Vgpr, FloatAcc1, acc1_thread_buf.Size(), true> c_thread_buf;
        c_thread_buf.Clear();

        /*******************************************************************************/
        //                              Kernel Main Stage
        //
        // Flash Attention
+       // Dao, Tri, et al. "Flashattention: Fast and memory-efficient exact attention
+       // with io-awareness." arXiv preprint arXiv:2205.14135 (2022).
        index_t gemm1_l_block_outer_index = 0;

@@ -947,7 +997,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
                continue;
            }
            // gemm0 start, A-B swaped
-           GridwiseGemmPipe::template Run<HasMainKBlockLoop>(a_grid_desc_k0_m_k1,
+           GridwiseGemmPipe::template Run<HasMainKBlockLoop>(a_grid_desc,
                                                              a_block_desc,
                                                              a_blockwise_copy,
                                                              a_grid_buf,

@@ -1019,10 +1069,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
                           [&](auto i) { acc_element_op(acc0_thread_buf(i), acc0_thread_buf[i]); });
            }

            block_sync_lds();
-           // gemm0 end
-           // gemm0 incorrect

            // Tiled softmax start
            // softmax
            SoftmaxBuf& max = blockwise_softmax.max_value_buf;

@@ -1130,7 +1177,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
                });
            });

-           a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1,
+           a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc,
                                                a_block_reset_copy_step); // rewind K
            b0_blockwise_copy.MoveSrcSliceWindow(b0_grid_desc_k0_l_k1,
                                                 b0_block_reset_copy_step); // rewind K and step N
include/ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp

@@ -179,22 +179,30 @@ struct TransformBatchedContractionContractionToBatchedGemmGemm
                                           make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
    }

-   template <typename AGridDesc_M_K, typename Number>
+   template <typename AGridDesc_M_K,
+             typename WmmaK,
+             typename MRepeat,
+             typename MWaves,
+             typename MPerWmma,
+             typename AK1>
    __host__ __device__ static constexpr auto
    MakeAGridDescriptor_AKWmma_MBlockRepeat_MWaves_AKRow_MPerWmma_AK1(
-       const AGridDesc_M_K& a_grid_desc_m_k,
-       const Number& WmmaK,
-       const Number& MRepeat,
-       const Number& MWaves,
-       const Number& MPerWmma,
-       const Number& AK1)
+       const AGridDesc_M_K& a_grid_desc_m_k,
+       const WmmaK&,
+       const MRepeat&,
+       const MWaves&,
+       const MPerWmma&,
+       const AK1&)
    {
-       const auto M0 = a_grid_desc_m_k.GetLength(I0) / MPerBlcok;
+       const auto M0 = a_grid_desc_m_k.GetLength(I0) / MPerBlock;
        const auto K  = a_grid_desc_m_k.GetLength(I1);
-       const auto AKWmma    = K / WmmaK;
-       constexpr auto AKRow = WmmaK / K1;
+       const auto AKWmma    = K / WmmaK{};
+       constexpr auto AKRow = WmmaK{} / AK1{};

        return transform_tensor_descriptor(
            a_grid_desc_m_k,
-           make_tuple(make_unmerge_transform(make_tuple(AKWmma, Number<AKRow>{}, AK1)),
-                      make_unmerge_transform(make_tuple(M0 * MRepeat, MWaves, MPerWmma))),
+           make_tuple(make_unmerge_transform(make_tuple(AKWmma, AKRow, AK1{})),
+                      make_unmerge_transform(make_tuple(M0 * MRepeat{}, MWaves{}, MPerWmma{}))),
            make_tuple(Sequence<1>{}, Sequence<0>{}),
            make_tuple(Sequence<0, 3, 5>{}, Sequence<1, 2, 4>{}));
    }
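MakeAGridDescriptor_AKWmma_MBlockRepeat_MWaves_AKRow_MPerWmma_AK1 unmerges K into (AKWmma, AKRow, AK1) and M into (M0 * MRepeat, MWaves, MPerWmma), then orders the result as [AKWmma, MBlockRepeat, MWaves, AKRow, MPerWmma, AK1]. A plain-arithmetic sketch of that index split for one (m, k) coordinate, with illustrative sizes (not from the commit):

    #include <cstdio>

    int main()
    {
        constexpr int AK1 = 8, WmmaK = 16, AKRow = WmmaK / AK1; // 2
        constexpr int MPerWmma = 16, MWaves = 4;

        const int m = 70, k = 21; // an arbitrary (m, k) element

        // K -> (AKWmma, AKRow, AK1), slowest to fastest
        const int ak1    = k % AK1;
        const int akrow  = (k / AK1) % AKRow;
        const int akwmma = k / (AK1 * AKRow);

        // M -> (M0 * MRepeat, MWaves, MPerWmma), slowest to fastest
        const int mperwmma = m % MPerWmma;
        const int mwave    = (m / MPerWmma) % MWaves;
        const int mrep     = m / (MPerWmma * MWaves);

        // Output dimension order: [AKWmma, MBlockRepeat, MWaves, AKRow, MPerWmma, AK1]
        std::printf("(m,k)=(%d,%d) -> [%d, %d, %d, %d, %d, %d]\n",
                    m, k, akwmma, mrep, mwave, akrow, mperwmma, ak1);
        return 0;
    }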