gaoqiong / composable_kernel · Commits · bee4e344

Commit bee4e344, authored May 19, 2023 by aska-0096

    (5/5) attention pass, todo: debug lds perf bug

parent fd4ff3a7
Showing 7 changed files with 276 additions and 248 deletions.
example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_bias_relu_add_wmma_example.inc  +2 -2
include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp  +48 -49
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp  +36 -25
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp  +18 -19
include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp  +15 -15
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp  +130 -117
include/ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp  +27 -21
example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_bias_relu_add_wmma_example.inc
include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp
@@ -365,9 +365,8 @@ struct BlockwiseGemmWMMA
         static_for<0, NRepeat, 1>{}([&](auto n0) {
             static_for<0, MRepeat, 1>{}([&](auto m0) {
-                static_for<0, KPerBlock / WmmaK, 1>{}(
-                    [&](auto k) { // k=0,1,2 instead of k=0,kpack*1, ...
-                        // read B
+                static_for<0, KPerBlock / WmmaK, 1>{}([&](auto k) { // k=0,1,2 instead of
+                                                                    // k=0,kpack*1, ... read B
                     b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
                                        make_tuple(Number<k * WmmaK / B_K1 / B_KRow>{}, n0, I0, I0, I0, I0),
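In this loop, `k` steps once per WMMA chunk and the B read derives its K0 source coordinate as `k * WmmaK / B_K1 / B_KRow`. A minimal standalone sketch of that arithmetic; the constants are illustrative assumptions, not values fixed by this diff:

```cpp
// Standalone sketch of the K-iteration arithmetic; constants are assumed.
#include <cstdio>

int main()
{
    constexpr int KPerBlock = 64; // assumed block tile in K
    constexpr int WmmaK     = 16; // one WMMA instruction's K extent
    constexpr int B_K1      = 8;  // assumed innermost K vector
    constexpr int B_KRow    = 2;  // assumed K rows per WMMA

    // k counts WMMA chunks (k = 0,1,2,... as in the comment), and the K0
    // source coordinate for the B read is derived from it.
    for(int k = 0; k < KPerBlock / WmmaK; ++k)
    {
        printf("k = %d -> B source K0 = %d\n", k, k * WmmaK / B_K1 / B_KRow);
    }
    return 0;
}
```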
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp
@@ -136,6 +136,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
     static constexpr auto I3 = Number<3>{};
     static constexpr auto I4 = Number<4>{};
     static constexpr auto I5 = Number<5>{};
+    static constexpr auto I6 = Number<6>{};

     static constexpr auto WmmaK = 16;
@@ -175,8 +176,10 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
         }
         else
         {
-            return Transform::MakeAGridDescriptor_AKWmma_MBlockRepeat_MWaves_AKRow_MPerWmma_AK1(
-                Transform::MakeAGridDescriptor_M_K(a_gs_ms_ks_lengths_vec, a_gs_ms_ks_strides_vec),
+            return Transform::
+                MakeAGridDescriptor_AKWmma_MBlockRepeat_MWaves_AK0PerWmma_AKRow_MPerWmma_AK1(
+                    Transform::MakeAGridDescriptor_M_K(a_gs_ms_ks_lengths_vec,
+                                                       a_gs_ms_ks_strides_vec),
                 Number<WmmaK>{},
                 Number<MRepeat>{},
                 Number<MWaves>{},
@@ -197,7 +200,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
         }
         else
         {
-            return Transform::MakeB0GridDescriptor_BKWmma_LBlockRepeat_LWaves_BKRow_LPerWmma_BK1(
+            return Transform::
+                MakeB0GridDescriptor_BKWmma_LBlockRepeat_LWaves_BK0PerWmma_BKRow_LPerWmma_BK1(
                 Transform::MakeB0GridDescriptor_N_K(b0_gs_ls_ks_lengths_vec,
                                                     b0_gs_ls_ks_strides_vec),
                 Number<WmmaK>{},
@@ -220,7 +224,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
         }
         else
         {
-            return Transform::MakeB1GridDescriptor_BLWmma_NBlockRepeat_NWaves_BLRow_NPerWmma_BL1(
+            return Transform::
+                MakeB1GridDescriptor_BLWmma_NBlockRepeat_NWaves__BL0PerWmma_BLRow_NPerWmma_BL1(
                 Transform::MakeB1GridDescriptor_N_K(b1_gs_ns_ls_lengths_vec,
                                                     b1_gs_ns_ls_strides_vec),
                 Number<WmmaK>{},
@@ -521,7 +526,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
             else
             {
                 return arg.a_grid_desc.GetLength(I0) * arg.a_grid_desc.GetLength(I3) *
-                       arg.a_grid_desc.GetLength(I5);
+                       arg.a_grid_desc.GetLength(I4) * arg.a_grid_desc.GetLength(I6);
             }
         }();
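The length query changes because the A grid descriptor gains a dimension in this commit: the K extent is now factored over I0 (KWmma), I3 (K0PerWmma), I4 (KRow), and I6 (K1) instead of I0/I3/I5. A small sanity sketch with assumed sizes:

```cpp
// Sanity sketch: the 7D GetLength product recovers the full K extent.
// Sizes are assumptions chosen so that K = KWmma * K0PerWmma * KRow * K1.
#include <cassert>

int main()
{
    constexpr int K1        = 8;              // innermost vector (assumed)
    constexpr int KRow      = 2;              // rows per WMMA (assumed)
    constexpr int K0PerWmma = 16 / KRow / K1; // WmmaK / KRow / K1 = 1
    constexpr int KWmma     = 4;              // K / WmmaK (assumed K = 64)

    // Old 6D layout: K spread over (I0, I3, I5) = (KWmma, KRow, K1).
    constexpr int k_old = KWmma * KRow * K1;
    // New 7D layout: K spread over (I0, I3, I4, I6).
    constexpr int k_new = KWmma * K0PerWmma * KRow * K1;

    static_assert(k_new == 64, "7D factorization recovers the full K extent");
    // The 6D product would miss the K0PerWmma factor whenever it is > 1,
    // which is why the length query moves to I4 * I6.
    assert(k_old * K0PerWmma == k_new);
    return 0;
}
```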
@@ -826,7 +831,13 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
             << "CSpec" << getTensorSpecializationString(CSpec) << ", "
             << getMaskingSpecializationString(MaskingSpec)
             << ">"
-            << " NumPrefetch: "
+            << " AEnableLds: "
+            << AEnableLds << ", "
+            << "B0EnableLds: "
+            << B0EnableLds << ", "
+            << "B1EnableLds: "
+            << B1EnableLds << ", "
+            << "NumPrefetch: "
             << NumPrefetch << ", "
             << "LoopScheduler: "
             << LoopSchedToString[LoopSched] << ", "
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp
@@ -468,8 +468,7 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
                 remove_reference_t<typename GridwiseOp::DefaultBlock2CTileMap>,
                 has_main_k_block_loop>; // Last Option is W/O
-            return
-                launch_and_time_kernel(stream_config,
+            return launch_and_time_kernel(stream_config,
                                        kernel,
                                        dim3(grid_size),
                                        dim3(BlockSize),
include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp
@@ -243,10 +243,23 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
         else
         {
             constexpr auto KWmmaPerblock = KPerBlock / WmmaK;
-            // KWmma->MRepeat->MWave->KRow->MPerWmma->K1 Per Thread
+            constexpr auto K0PerWmma     = WmmaK / 2 / AK1;
+            // KWmma->MRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread
             return make_naive_tensor_descriptor(
-                make_tuple(Number<KWmmaPerblock>{}, Number<MRepeat>{}, I1, I1, I1, AK1),
-                make_tuple(Number<MRepeat>{} * AK1, AK1, AK1, AK1, AK1, I1));
+                make_tuple(Number<KWmmaPerblock>{},
+                           Number<MRepeat>{},
+                           I1,
+                           Number<K0PerWmma>{},
+                           I1,
+                           I1,
+                           AK1),
+                make_tuple(Number<MRepeat>{} * Number<K0PerWmma>{} * AK1,
+                           Number<K0PerWmma>{} * AK1,
+                           Number<K0PerWmma>{} * AK1,
+                           AK1,
+                           AK1,
+                           AK1,
+                           I1));
         }
     }();
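In the new 7-dimensional VGPR-path descriptor only KWmma, MRepeat, K0PerWmma, and AK1 have extent greater than one; the stride tuple makes the unit wave/row/lane dimensions free. A hand-computed offset sketch under assumed extents:

```cpp
// Hand-computed offsets for the new 7D a_block descriptor, showing what the
// stride tuple encodes. The extents below are assumptions for illustration.
#include <cstdio>

int main()
{
    constexpr int MRepeat = 2, K0PerWmma = 2, AK1 = 4; // assumed extents
    // Strides from the second make_tuple:
    // (MRepeat*K0PerWmma*AK1, K0PerWmma*AK1, K0PerWmma*AK1, AK1, AK1, AK1, 1)
    const int stride[7] = {MRepeat * K0PerWmma * AK1,
                           K0PerWmma * AK1,
                           K0PerWmma * AK1,
                           AK1,
                           AK1,
                           AK1,
                           1};

    auto offset = [&](const int idx[7]) {
        int off = 0;
        for(int d = 0; d < 7; ++d)
            off += idx[d] * stride[d];
        return off;
    };

    const int next_kwmma[7]   = {1, 0, 0, 0, 0, 0, 0};
    const int next_mrepeat[7] = {0, 1, 0, 0, 0, 0, 0};
    const int next_k0[7]      = {0, 0, 0, 1, 0, 0, 0};
    // With the assumed extents: steps of 16, 8 and 4 elements respectively.
    printf("KWmma step %d, MRepeat step %d, K0PerWmma step %d\n",
           offset(next_kwmma), offset(next_mrepeat), offset(next_k0));
    return 0;
}
```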
@@ -277,10 +290,23 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
         else
         {
             constexpr auto KWmmaPerblock = KPerBlock / WmmaK;
-            // KWmma->NRepeat->NWave->NRow->NPerWmma->BK1 Per Thread
+            constexpr auto K0PerWmma     = WmmaK / 2 / BK1;
+            // KWmma->NRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread
             return make_naive_tensor_descriptor(
-                make_tuple(Number<KWmmaPerblock>{}, Number<LRepeat>{}, I1, I1, I1, BK1),
-                make_tuple(Number<LRepeat>{} * BK1, BK1, BK1, BK1, BK1, I1));
+                make_tuple(Number<KWmmaPerblock>{},
+                           Number<LRepeat>{},
+                           I1,
+                           Number<K0PerWmma>{},
+                           I1,
+                           I1,
+                           BK1),
+                make_tuple(Number<LRepeat>{} * Number<K0PerWmma>{} * BK1,
+                           Number<K0PerWmma>{} * BK1,
+                           Number<K0PerWmma>{} * BK1,
+                           BK1,
+                           BK1,
+                           BK1,
+                           I1));
         }
     }();
@@ -310,10 +336,23 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
         else
         {
             constexpr auto LWmmaPerblock = LPerBlock / WmmaL;
-            // LWmma->NRepeat->NWave->NRow->LPerWmma->BL1 Per Thread
+            constexpr auto L0PerWmma     = WmmaL / 2 / BL1;
+            // LWmma->NRepeat->MWave->L0PerWmma->LRow->MPerWmma->L1 Per Thread
             return make_naive_tensor_descriptor(
-                make_tuple(Number<LWmmaPerblock>{}, Number<NRepeat>{}, I1, I1, I1, BL1),
-                make_tuple(Number<NRepeat>{} * BL1, BL1, BL1, BL1, BL1, I1));
+                make_tuple(Number<LWmmaPerblock>{},
+                           Number<NRepeat>{},
+                           I1,
+                           Number<L0PerWmma>{},
+                           I1,
+                           I1,
+                           BL1),
+                make_tuple(Number<NRepeat>{} * Number<L0PerWmma>{} * BL1,
+                           Number<L0PerWmma>{} * BL1,
+                           Number<L0PerWmma>{} * BL1,
+                           BL1,
+                           BL1,
+                           BL1,
+                           I1));
         }
     }();
@@ -333,7 +372,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
         {
             constexpr auto KWmmaPerBlock = KPerBlock / WmmaK;
-            return make_multi_index(KWmmaPerBlock, 0, 0, 0, 0, 0);
+            return make_multi_index(KWmmaPerBlock, 0, 0, 0, 0, 0, 0);
         }
     }();
@@ -353,7 +392,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
         {
             constexpr auto KWmmaPerBlock = KPerBlock / WmmaK;
-            return make_multi_index(KWmmaPerBlock, 0, 0, 0, 0, 0);
+            return make_multi_index(KWmmaPerBlock, 0, 0, 0, 0, 0, 0);
         }
     }();
@@ -371,7 +410,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
         {
             constexpr auto LWmmaPerBlock = LTilePerBlock / WmmaL;
-            return make_multi_index(LWmmaPerBlock, 0, 0, 0, 0, 0);
+            return make_multi_index(LWmmaPerBlock, 0, 0, 0, 0, 0, 0);
         }
     }();
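All three slice-copy steps grow to seven components to match the new descriptor rank; only the leading WMMA-chunk coordinate ever advances. A trivial sketch (tile sizes assumed):

```cpp
// Minimal sketch of the 7-component slice step: the copy window only moves
// along dimension 0 (the KWmma/LWmma chunk index); the trailing zeros cover
// the added K0PerWmma dimension plus the existing repeat/wave/row/K1 dims.
#include <array>
#include <cstdio>

using MultiIndex7 = std::array<int, 7>;

int main()
{
    constexpr int KPerBlock = 64, WmmaK = 16; // assumed tile sizes
    const MultiIndex7 step{KPerBlock / WmmaK, 0, 0, 0, 0, 0, 0};

    MultiIndex7 window{0, 0, 0, 0, 0, 0, 0};
    for(int iter = 0; iter < 3; ++iter)
    {
        for(std::size_t d = 0; d < window.size(); ++d)
            window[d] += step[d];
        printf("after iter %d: KWmma coord = %d\n", iter, window[0]);
    }
    return 0;
}
```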
@@ -389,42 +428,30 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
             // AK0_M_AK1 -> AK0_MRepeat_Mwaves_MPerWmma_AK1
             constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0);
             constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2);
+            constexpr auto A_KRow = I1;
             return transform_tensor_descriptor(
                 ABlockDesc_{},
-                make_tuple(make_pass_through_transform(Number<A_K0>{}),
+                make_tuple(make_unmerge_transform(make_tuple(Number<A_K0>{}, A_KRow)),
                            make_unmerge_transform(make_tuple(
                                Number<MRepeat>{}, Number<MWaves>{}, Number<MPerWmma>{})),
                            make_pass_through_transform(Number<A_K1>{})),
                 make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
-                make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}));
+                make_tuple(Sequence<0, 3>{}, Sequence<1, 2, 4>{}, Sequence<5>{}));
         }
         else
         {
-            // KWmma_MRepeat_MWave_KRow_MPerWmma_K1 -> K0_MRepeat_Mwaves_MPerWmma_K1
+            // KWmma_MRepeat_MWave_K0PerWmma_KRow_MPerWmma_K1 -> K0_MRepeat_Mwaves_MPerWmma_K1
             constexpr auto KWmma = ABlockDesc_{}.GetLength(I0);
-            constexpr auto A_K1  = ABlockDesc_{}.GetLength(I5);
-            // Workaround, Freeze transform
-            return transform_tensor_descriptor(
-                ABlockDesc_{},
-                make_tuple(make_freeze_transform(I0),
-                           make_pass_through_transform(Number<KWmma>{}),
-                           make_pass_through_transform(Number<MRepeat>{}),
-                           make_pass_through_transform(I1),
-                           make_pass_through_transform(I1),
-                           make_pass_through_transform(Number<A_K1>{})),
-                make_tuple(Sequence<3>{}, Sequence<0>{}, Sequence<1>{}, Sequence<2>{},
-                           Sequence<4>{}, Sequence<5>{}),
-                make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<1>{}, Sequence<2>{},
-                           Sequence<3>{}, Sequence<4>{}));
+            constexpr auto K0PerWmma = ABlockDesc_{}.GetLength(I3);
+            constexpr auto A_KRow    = ABlockDesc_{}.GetLength(I4);
+            constexpr auto A_K1      = ABlockDesc_{}.GetLength(I6);
+            return make_naive_tensor_descriptor_packed(make_tuple(Number<KWmma * K0PerWmma>{},
+                                                                  Number<MRepeat>{},
+                                                                  I1,
+                                                                  Number<A_KRow>{},
+                                                                  I1,
+                                                                  Number<A_K1>{}));
         }
     }();
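The old path squeezed the per-thread view out of the block descriptor with a freeze-transform workaround; the new path states the per-thread shape directly as a packed descriptor. "Packed" means each stride is the product of the lengths to its right, as this generic sketch shows (the extents mirror the new per-thread shape but the numeric values are assumptions):

```cpp
// Generic sketch of the stride rule behind make_naive_tensor_descriptor_packed:
// each stride is the product of all lengths to its right. The extents mirror
// the new shape (KWmma*K0PerWmma, MRepeat, 1, A_KRow, 1, A_K1), values assumed.
#include <cstdio>

int main()
{
    constexpr int n = 6;
    const int len[n] = {4, 2, 1, 2, 1, 4}; // assumed extents

    int stride[n];
    stride[n - 1] = 1;
    for(int d = n - 2; d >= 0; --d)
        stride[d] = stride[d + 1] * len[d + 1];

    for(int d = 0; d < n; ++d)
        printf("dim %d: length %d stride %d\n", d, len[d], stride[d]);
    // With these numbers: strides 16, 8, 8, 4, 4, 1 - a fully packed layout,
    // so no freeze transform is needed to skip a coordinate.
    return 0;
}
```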
@@ -441,42 +468,31 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
             // BK0_L_BK1 -> BK0_LRepeat_Lwaves_LPerWmma_BK1
             constexpr auto B_K0 = B0BlockDesc_{}.GetLength(I0);
             constexpr auto B_K1 = B0BlockDesc_{}.GetLength(I2);
+            constexpr auto B_KRow = I1;
             return transform_tensor_descriptor(
                 B0BlockDesc_{},
-                make_tuple(make_pass_through_transform(Number<B_K0>{}),
+                make_tuple(make_unmerge_transform(make_tuple(Number<B_K0>{}, B_KRow)),
                            make_unmerge_transform(make_tuple(
                                Number<LRepeat>{}, Number<LWaves>{}, Number<LPerWmma>{})),
                            make_pass_through_transform(Number<B_K1>{})),
                 make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
-                make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}));
+                make_tuple(Sequence<0, 3>{}, Sequence<1, 2, 4>{}, Sequence<5>{}));
         }
         else
         {
-            // KWmma_LRepeat_LWave_KRow_LPerWmma_K1 -> K0_LRepeat_Lwaves_LPerWmma_K1
+            // KWmma_MRepeat_MWave_K0PerWmma_KRow_MPerWmma_K1 -> K0_MRepeat_Mwaves_MPerWmma_K1
             constexpr auto KWmma = B0BlockDesc_{}.GetLength(I0);
-            constexpr auto B_K1  = B0BlockDesc_{}.GetLength(I5);
+            constexpr auto K0PerWmma = B0BlockDesc_{}.GetLength(I3);
+            constexpr auto B_KRow    = B0BlockDesc_{}.GetLength(I4);
+            constexpr auto B_K1      = B0BlockDesc_{}.GetLength(I6);
             // Workaround, Freeze transform
-            return transform_tensor_descriptor(
-                B0BlockDesc_{},
-                make_tuple(make_freeze_transform(I0),
-                           make_pass_through_transform(Number<KWmma>{}),
-                           make_pass_through_transform(Number<LRepeat>{}),
-                           make_pass_through_transform(I1),
-                           make_pass_through_transform(I1),
-                           make_pass_through_transform(Number<B_K1>{})),
-                make_tuple(Sequence<3>{}, Sequence<0>{}, Sequence<1>{}, Sequence<2>{},
-                           Sequence<4>{}, Sequence<5>{}),
-                make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<1>{}, Sequence<2>{},
-                           Sequence<3>{}, Sequence<4>{}));
+            return make_naive_tensor_descriptor_packed(make_tuple(Number<KWmma * K0PerWmma>{},
+                                                                  Number<LRepeat>{},
+                                                                  I1,
+                                                                  Number<B_KRow>{},
+                                                                  I1,
+                                                                  Number<B_K1>{}));
         }
     }();
@@ -489,14 +505,14 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
         {
             constexpr index_t A_L0 = A1BlockDesc_AL0_M_AL1{}.GetLength(I0);
             constexpr index_t A_L1 = A1BlockDesc_AL0_M_AL1{}.GetLength(I2);
+            constexpr auto A_LRow  = I1;
             return transform_tensor_descriptor(
                 A1BlockDesc_AL0_M_AL1{},
-                make_tuple(make_pass_through_transform(Number<A_L0>{}),
+                make_tuple(make_unmerge_transform(make_tuple(Number<A_L0>{}, A_LRow)),
                            make_unmerge_transform(make_tuple(Number<MRepeat>{}, I1, I1)),
                            make_pass_through_transform(Number<A_L1>{})),
                 make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
-                make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}));
+                make_tuple(Sequence<0, 3>{}, Sequence<1, 2, 4>{}, Sequence<5>{}));
         }

         template <typename B1BlockDesc_>
@@ -509,42 +525,29 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
             // BL0_N_BL1 -> BL0_NRepeat_Nwaves_NPerWmma_BL1
             constexpr auto B_L0 = B1BlockDesc_{}.GetLength(I0);
             constexpr auto B_L1 = B1BlockDesc_{}.GetLength(I2);
+            constexpr auto B_LRow = I1;
             return transform_tensor_descriptor(
                 B1BlockDesc_{},
-                make_tuple(make_pass_through_transform(Number<B_L0>{}),
+                make_tuple(make_unmerge_transform(make_tuple(Number<B_L0>{}, B_LRow)),
                            make_unmerge_transform(make_tuple(
                                Number<NRepeat>{}, Number<NWaves>{}, Number<NPerWmma>{})),
                            make_pass_through_transform(Number<B_L1>{})),
                 make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
-                make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}));
+                make_tuple(Sequence<0, 3>{}, Sequence<1, 2, 4>{}, Sequence<5>{}));
         }
         else
         {
             // LWmma_NRepeat_NWave_LRow_NPerWmma_L1 -> L0_NRepeat_Nwaves_NPerWmma_L1
             constexpr auto LWmma = B1BlockDesc_{}.GetLength(I0);
-            constexpr auto B_L1  = B1BlockDesc_{}.GetLength(I5);
+            constexpr auto L0PerWmma = B1BlockDesc_{}.GetLength(I3);
+            constexpr auto B_LRow    = B1BlockDesc_{}.GetLength(I4);
+            constexpr auto B_L1      = B1BlockDesc_{}.GetLength(I6);
             // Workaround, Freeze transform
-            return transform_tensor_descriptor(
-                B1BlockDesc_{},
-                make_tuple(make_freeze_transform(I0),
-                           make_pass_through_transform(Number<LWmma>{}),
-                           make_pass_through_transform(Number<NRepeat>{}),
-                           make_pass_through_transform(I1),
-                           make_pass_through_transform(I1),
-                           make_pass_through_transform(Number<B_L1>{})),
-                make_tuple(Sequence<3>{}, Sequence<0>{}, Sequence<1>{}, Sequence<2>{},
-                           Sequence<4>{}, Sequence<5>{}),
-                make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<1>{}, Sequence<2>{},
-                           Sequence<3>{}, Sequence<4>{}));
+            return make_naive_tensor_descriptor_packed(make_tuple(Number<LWmma * L0PerWmma>{},
+                                                                  Number<NRepeat>{},
+                                                                  I1,
+                                                                  Number<B_LRow>{},
+                                                                  I1,
+                                                                  Number<B_L1>{}));
        }
    }();
@@ -610,9 +613,9 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
         else
         {
             return make_tuple(a_grid_desc.GetLength(I1) * a_grid_desc.GetLength(I2) *
-                                  a_grid_desc.GetLength(I4),
+                                  a_grid_desc.GetLength(I5),
                               a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I3) *
-                                  a_grid_desc.GetLength(I5));
+                                  a_grid_desc.GetLength(I4) * a_grid_desc.GetLength(I6));
         }
     };
@@ -625,9 +628,9 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
         else
         {
             return make_tuple(b0_grid_desc.GetLength(I1) * b0_grid_desc.GetLength(I2) *
-                                  b0_grid_desc.GetLength(I4),
+                                  b0_grid_desc.GetLength(I5),
                               b0_grid_desc.GetLength(I0) * b0_grid_desc.GetLength(I3) *
-                                  b0_grid_desc.GetLength(I5));
+                                  b0_grid_desc.GetLength(I4) * b0_grid_desc.GetLength(I6));
         }
     };
@@ -640,9 +643,9 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
         else
         {
             return make_tuple(b1_grid_desc.GetLength(I1) * b1_grid_desc.GetLength(I2) *
-                                  b1_grid_desc.GetLength(I4),
+                                  b1_grid_desc.GetLength(I5),
                               b1_grid_desc.GetLength(I0) * b1_grid_desc.GetLength(I3) *
-                                  b1_grid_desc.GetLength(I5));
+                                  b1_grid_desc.GetLength(I4) * b1_grid_desc.GetLength(I6));
         }
     };
@@ -884,6 +887,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
         // Thread-wise copy
         // KPerBlock/WmmaK -> MRepeat -> MWaves -> WmmaK/K1 -> MPerWmma -> K1
         constexpr auto KWmmaPerBlock = KPerBlock / WmmaK;
+        constexpr auto K0PerWmma     = WmmaK / 2 / K1Value;
         auto a_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ADataType>(
             a_block_desc.GetElementSpaceSize());
@@ -896,11 +900,12 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
             Sequence<Number<KWmmaPerBlock>{},
                      Number<MRepeat>{},
                      I1,
+                     Number<K0PerWmma>{},
                      I1,
                      I1,
                      Number<K1Value>{}>,
-            Sequence<0, 1, 2, 3, 4, 5>,
-            5,
+            Sequence<0, 1, 2, 3, 4, 5, 6>,
+            6,
             ABlockTransferSrcScalarPerVector,
             AThreadTransferSrcResetCoordinateAfterRun,
             true>(
@@ -908,6 +913,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
             make_multi_index(0,
                              m_block_data_idx_on_grid / (MWaves * MPerWmma),
                              get_thread_local_1d_id() / 32,
+                             0,
                              (get_thread_local_1d_id() % 32) / 16,
                              get_thread_local_1d_id() % 16,
                              0));
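The copy origin decodes the flat thread id assuming a 32-wide wave split into two rows of 16 lanes: `tid / 32` selects the wave, `(tid % 32) / 16` the new KRow coordinate, and `tid % 16` the lane. A quick sketch of that mapping:

```cpp
// Sketch of the thread-id decomposition used for the copy origin, assuming a
// wave size of 32 split into 2 K-rows of 16 lanes.
#include <cstdio>

int main()
{
    for(int tid : {0, 15, 16, 31, 32, 63})
    {
        const int wave = tid / 32;        // MWaves coordinate
        const int krow = (tid % 32) / 16; // new KRow coordinate
        const int lane = tid % 16;        // MPerWmma coordinate
        printf("tid %2d -> wave %d, krow %d, lane %2d\n", tid, wave, krow, lane);
    }
    return 0;
}
```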
@@ -960,6 +966,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
         // Thread-wise copy
         // KPerBlock/WmmaK -> LRepeat -> LWaves -> KRow -> LPerWmma -> K1
         constexpr auto KWmmaPerBlock = KPerBlock / WmmaK;
+        constexpr auto K0PerWmma     = WmmaK / 2 / K1Value;
         auto b0_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, B0DataType>(
             b0_block_desc.GetElementSpaceSize());
@@ -972,11 +979,12 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
             Sequence<Number<KWmmaPerBlock>{},
                      Number<LRepeat>{},
                      I1,
+                     Number<K0PerWmma>{},
                      I1,
                      I1,
                      Number<K1Value>{}>,
-            Sequence<0, 1, 2, 3, 4, 5>,
-            5,
+            Sequence<0, 1, 2, 3, 4, 5, 6>,
+            6,
             B0BlockTransferSrcScalarPerVector,
             B0ThreadTransferSrcResetCoordinateAfterRun,
             true>(
@@ -984,6 +992,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
             make_multi_index(0,
                              0 / (LWaves * LPerWmma),
                              get_thread_local_1d_id() / 32,
+                             0,
                              (get_thread_local_1d_id() % 32) / 16,
                              get_thread_local_1d_id() % 16,
                              0));
@@ -1054,7 +1063,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
             return make_multi_index(-a_grid_desc.GetLength(I0), 0, 0);
         }
         else
         {
-            return make_multi_index(-a_grid_desc.GetLength(I0), 0, 0, 0, 0, 0);
+            return make_multi_index(-a_grid_desc.GetLength(I0), 0, 0, 0, 0, 0, 0);
         }
     }();
@@ -1063,7 +1072,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
             return make_multi_index(-b0_grid_desc.GetLength(I0), LPerBlock, 0);
         }
         else
         {
-            return make_multi_index(-b0_grid_desc.GetLength(I0), LRepeat, 0, 0, 0, 0);
+            return make_multi_index(-b0_grid_desc.GetLength(I0), LRepeat, 0, 0, 0, 0, 0);
         }
     }();
@@ -1072,7 +1081,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
             return a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I2);
         }
         else
         {
-            return a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I3) * a_grid_desc.GetLength(I5);
+            return a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I3) *
+                   a_grid_desc.GetLength(I4) * a_grid_desc.GetLength(I6);
         }
     }();
@@ -1208,6 +1218,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
         // Thread-wise copy
         // KPerBlock/WmmaK -> NRepeat -> NWaves -> WmmaK/K1 -> NPerWmma -> K1
         constexpr auto LWmmaPerBlock = LTilePerBlock / WmmaL;
+        constexpr auto L0PerWmma     = WmmaL / 2 / L1Value;
         auto b1_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, B1DataType>(
             b1_block_desc.GetElementSpaceSize());
@@ -1220,11 +1231,12 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
...
@@ -1220,11 +1231,12 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
Sequence
<
Number
<
LWmmaPerBlock
>
{},
Sequence
<
Number
<
LWmmaPerBlock
>
{},
Number
<
NRepeat
>
{},
Number
<
NRepeat
>
{},
I1
,
I1
,
Number
<
L0PerWmma
>
{},
I1
,
I1
,
I1
,
I1
,
Number
<
L1Value
>
{}
>
,
Number
<
L1Value
>
{}
>
,
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
>
,
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
,
6
>
,
5
,
6
,
B1BlockTransferSrcScalarPerVector
,
B1BlockTransferSrcScalarPerVector
,
B1ThreadTransferSrcResetCoordinateAfterRun
,
B1ThreadTransferSrcResetCoordinateAfterRun
,
true
>
(
true
>
(
...
@@ -1232,6 +1244,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
...
@@ -1232,6 +1244,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
make_multi_index
(
0
,
make_multi_index
(
0
,
n_block_data_idx_on_grid
/
(
NWaves
*
NPerWmma
),
n_block_data_idx_on_grid
/
(
NWaves
*
NPerWmma
),
get_thread_local_1d_id
()
/
32
,
get_thread_local_1d_id
()
/
32
,
0
,
(
get_thread_local_1d_id
()
%
32
)
/
16
,
(
get_thread_local_1d_id
()
%
32
)
/
16
,
get_thread_local_1d_id
()
%
16
,
get_thread_local_1d_id
()
%
16
,
0
));
0
));
...
@@ -1262,7 +1275,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
...
@@ -1262,7 +1275,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
KPack
,
KPack
,
false
,
false
,
B1EnableLds
,
B1EnableLds
,
true
>
{
make_tuple
(
0
,
0
,
0
,
0
,
0
)};
true
>
{
make_tuple
(
0
,
0
,
0
,
0
,
0
,
0
)};
auto
acc1_thread_buf
=
blockwise_gemm1
.
GetCThreadBuffer
();
auto
acc1_thread_buf
=
blockwise_gemm1
.
GetCThreadBuffer
();
...
@@ -1271,7 +1284,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
...
@@ -1271,7 +1284,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
return
b0_grid_desc
.
GetLength
(
I1
);
return
b0_grid_desc
.
GetLength
(
I1
);
}
}
else
{
else
{
return
b0_grid_desc
.
GetLength
(
I1
)
*
b0_grid_desc
.
GetLength
(
I2
)
*
b0_grid_desc
.
GetLength
(
I
4
);
return
b0_grid_desc
.
GetLength
(
I1
)
*
b0_grid_desc
.
GetLength
(
I2
)
*
b0_grid_desc
.
GetLength
(
I
5
);
}
}
}();
}();
...
...
include/ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp
@@ -186,7 +186,7 @@ struct TransformBatchedContractionContractionToBatchedGemmGemm
               typename MPerWmma,
               typename AK1>
     __host__ __device__ static constexpr auto
-    MakeAGridDescriptor_AKWmma_MBlockRepeat_MWaves_AKRow_MPerWmma_AK1(
+    MakeAGridDescriptor_AKWmma_MBlockRepeat_MWaves_AK0PerWmma_AKRow_MPerWmma_AK1(
         const AGridDesc_M_K& a_grid_desc_m_k,
         const WmmaK&,
         const MRepeat&,

@@ -197,14 +197,16 @@ struct TransformBatchedContractionContractionToBatchedGemmGemm
         const auto M0 = a_grid_desc_m_k.GetLength(I0) / MPerBlock;
         const auto K  = a_grid_desc_m_k.GetLength(I1);
         const auto AKWmma = K / WmmaK{};
-        constexpr auto AKRow = WmmaK{} / AK1{};
+        constexpr auto AKRow      = 2;
+        constexpr auto AK0PerWmma = WmmaK{} / AKRow / AK1{};

         return transform_tensor_descriptor(
             a_grid_desc_m_k,
-            make_tuple(make_unmerge_transform(make_tuple(AKWmma, AKRow, AK1{})),
+            make_tuple(make_unmerge_transform(
+                           make_tuple(AKWmma, Number<AK0PerWmma>{}, Number<AKRow>{}, AK1{})),
                        make_unmerge_transform(make_tuple(M0 * MRepeat{}, MWaves{}, MPerWmma{}))),
             make_tuple(Sequence<1>{}, Sequence<0>{}),
-            make_tuple(Sequence<0, 3, 5>{}, Sequence<1, 2, 4>{}));
+            make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{}));
     }
     //
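The substantive change in this file: AKRow previously absorbed the whole WmmaK/AK1 ratio; it is now pinned to 2, and the remainder becomes the new AK0PerWmma factor, so WmmaK = AK0PerWmma * AKRow * AK1. A sketch of both factorizations (WmmaK = 16 comes from the diff; the AK1 values are assumptions):

```cpp
// Sketch of the K-inside-WMMA factorization before and after this change.
// WmmaK = 16 is taken from the diff; the AK1 values are assumed.
#include <cstdio>

int main()
{
    constexpr int WmmaK = 16;
    for(int AK1 : {2, 4, 8})
    {
        const int old_AKRow  = WmmaK / AK1;        // old: one factor next to AK1
        const int AKRow      = 2;                  // new: fixed row count
        const int AK0PerWmma = WmmaK / AKRow / AK1; // new middle factor
        printf("AK1 %d: old (AKRow x AK1) = %d x %d, "
               "new (AK0PerWmma x AKRow x AK1) = %d x %d x %d\n",
               AK1, old_AKRow, AK1, AK0PerWmma, AKRow, AK1);
    }
    return 0;
}
```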
@@ -254,7 +256,7 @@ struct TransformBatchedContractionContractionToBatchedGemmGemm
               typename LPerWmma,
               typename BK1>
     __host__ __device__ static constexpr auto
-    MakeB0GridDescriptor_BKWmma_LBlockRepeat_LWaves_BKRow_LPerWmma_BK1(
+    MakeB0GridDescriptor_BKWmma_LBlockRepeat_LWaves_BK0PerWmma_BKRow_LPerWmma_BK1(
         const BGridDesc_L_K& b_grid_desc_l_k,
         const WmmaK&,
         const LRepeat&,

@@ -265,14 +267,16 @@ struct TransformBatchedContractionContractionToBatchedGemmGemm
         const auto L0 = b_grid_desc_l_k.GetLength(I0) / NPerBlock;
         const auto K  = b_grid_desc_l_k.GetLength(I1);
         const auto BKWmma = K / WmmaK{};
-        constexpr auto BKRow = WmmaK{} / BK1{};
+        constexpr auto BKRow      = 2;
+        constexpr auto BK0PerWmma = WmmaK{} / BKRow / BK1{};

         return transform_tensor_descriptor(
             b_grid_desc_l_k,
-            make_tuple(make_unmerge_transform(make_tuple(BKWmma, BKRow, BK1{})),
+            make_tuple(make_unmerge_transform(
+                           make_tuple(BKWmma, Number<BK0PerWmma>{}, Number<BKRow>{}, BK1{})),
                        make_unmerge_transform(make_tuple(L0 * LRepeat{}, LWaves{}, LPerWmma{}))),
             make_tuple(Sequence<1>{}, Sequence<0>{}),
-            make_tuple(Sequence<0, 3, 5>{}, Sequence<1, 2, 4>{}));
+            make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{}));
     }
     //
@@ -323,7 +327,7 @@ struct TransformBatchedContractionContractionToBatchedGemmGemm
               typename NPerWmma,
               typename BL1>
     __host__ __device__ static constexpr auto
-    MakeB1GridDescriptor_BLWmma_NBlockRepeat_NWaves_BLRow_NPerWmma_BL1(
+    MakeB1GridDescriptor_BLWmma_NBlockRepeat_NWaves__BL0PerWmma_BLRow_NPerWmma_BL1(
         const BGridDesc_N_L& b_grid_desc_n_l,
         const WmmaL&,
         const NRepeat&,

@@ -334,14 +338,16 @@ struct TransformBatchedContractionContractionToBatchedGemmGemm
         const auto N0 = b_grid_desc_n_l.GetLength(I0) / OPerBlock;
         const auto L  = b_grid_desc_n_l.GetLength(I1);
         const auto BLWmma = L / WmmaL{};
-        constexpr auto BLRow = WmmaL{} / BL1{};
+        constexpr auto BLRow      = 2;
+        constexpr auto BL0PerWmma = WmmaL{} / BLRow / BL1{};

         return transform_tensor_descriptor(
             b_grid_desc_n_l,
-            make_tuple(make_unmerge_transform(make_tuple(BLWmma, BLRow, BL1{})),
+            make_tuple(make_unmerge_transform(
+                           make_tuple(BLWmma, Number<BL0PerWmma>{}, Number<BLRow>{}, BL1{})),
                        make_unmerge_transform(make_tuple(N0 * NRepeat{}, NWaves{}, NPerWmma{}))),
             make_tuple(Sequence<1>{}, Sequence<0>{}),
-            make_tuple(Sequence<0, 3, 5>{}, Sequence<1, 2, 4>{}));
+            make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{}));
     }
     //
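With the new Sequence<0, 3, 4, 6> output mapping, a flat K coordinate unmerges into (KWmma, K0PerWmma, KRow, K1), with K1 varying fastest. A decomposition sketch under assumed extents:

```cpp
// Sketch of the 4-way K unmerge (KWmma, K0PerWmma, KRow, K1) that the new
// Sequence<0, 3, 4, 6> mapping feeds; extents are assumed (WmmaK = 16, K1 = 4).
#include <cstdio>

int main()
{
    constexpr int K1 = 4, KRow = 2, K0PerWmma = 16 / KRow / K1, KWmma = 2;

    for(int k : {0, 5, 17, 31})
    {
        int rest = k;
        const int k1  = rest % K1;        rest /= K1;
        const int row = rest % KRow;      rest /= KRow;
        const int k0  = rest % K0PerWmma; rest /= K0PerWmma;
        const int kw  = rest % KWmma;
        printf("K %2d -> (KWmma %d, K0PerWmma %d, KRow %d, K1 %d)\n",
               k, kw, k0, row, k1);
    }
    return 0;
}
```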