gaoqiong / composable_kernel · Commits

Commit bee4e344
authored May 19, 2023 by aska-0096
parent fd4ff3a7

(5/5) attention pass, todo: debug lds perf bug

Showing 7 changed files with 276 additions and 248 deletions
example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_bias_relu_add_wmma_example.inc  +2 -2
include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp  +48 -49
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp  +36 -25
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp  +18 -19
include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp  +15 -15
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp  +130 -117
include/ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp  +27 -21
example/30_grouped_conv_fwd_multiple_d/run_grouped_conv_fwd_bias_relu_add_wmma_example.inc
include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp
@@ -365,9 +365,8 @@ struct BlockwiseGemmWMMA
         static_for<0, NRepeat, 1>{}([&](auto n0) {
             static_for<0, MRepeat, 1>{}([&](auto m0) {
-                static_for<0, KPerBlock / WmmaK, 1>{}(
-                    [&](auto k) { // k=0,1,2 instead of
-                                  // k=0,kpack*1, ...
+                static_for<0, KPerBlock / WmmaK, 1>{}([&](auto k) { // k=0,1,2 instead of
+                                                                    // k=0,kpack*1, ...
                     // read B
                     b_thread_copy_.Run(
                         b_block_desc_k0_n0_n1_n2_k1,
                         make_tuple(Number<k * WmmaK / B_K1 / B_KRow>{}, n0, I0, I0, I0, I0),
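In the reflowed loop above, k counts WMMA tiles directly (k = 0, 1, 2, ... as the comment notes), and the K0 coordinate handed to b_thread_copy_.Run is k * WmmaK / B_K1 / B_KRow. A standalone sketch of that index arithmetic, using assumed example values (KPerBlock = 64, WmmaK = 16, B_K1 = 8, B_KRow = 2) rather than anything taken from this commit:

    // Sketch only: mirrors the loop bound and source-offset arithmetic above
    // with made-up tile sizes; it is not CK code and uses no CK headers.
    #include <cstdio>

    int main()
    {
        constexpr int KPerBlock = 64, WmmaK = 16, B_K1 = 8, B_KRow = 2;

        // k walks over WMMA tiles (0, 1, 2, 3), not over kpack-scaled indices.
        for(int k = 0; k < KPerBlock / WmmaK; ++k)
        {
            const int k0 = k * WmmaK / B_K1 / B_KRow; // K0 coordinate passed to the copy
            std::printf("k = %d -> K0 = %d\n", k, k0); // prints 0, 1, 2, 3
        }
        return 0;
    }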
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp
@@ -136,6 +136,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
     static constexpr auto I3 = Number<3>{};
     static constexpr auto I4 = Number<4>{};
     static constexpr auto I5 = Number<5>{};
+    static constexpr auto I6 = Number<6>{};
     static constexpr auto WmmaK = 16;
@@ -175,8 +176,10 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
             }
             else
             {
-                return Transform::MakeAGridDescriptor_AKWmma_MBlockRepeat_MWaves_AKRow_MPerWmma_AK1(
+                return Transform::MakeAGridDescriptor_AKWmma_MBlockRepeat_MWaves_AK0PerWmma_AKRow_MPerWmma_AK1(
                     Transform::MakeAGridDescriptor_M_K(a_gs_ms_ks_lengths_vec, a_gs_ms_ks_strides_vec),
                     Number<WmmaK>{},
                     Number<MRepeat>{},
                     Number<MWaves>{},
@@ -197,7 +200,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
             }
             else
             {
-                return Transform::MakeB0GridDescriptor_BKWmma_LBlockRepeat_LWaves_BKRow_LPerWmma_BK1(
+                return Transform::MakeB0GridDescriptor_BKWmma_LBlockRepeat_LWaves_BK0PerWmma_BKRow_LPerWmma_BK1(
                     Transform::MakeB0GridDescriptor_N_K(b0_gs_ls_ks_lengths_vec,
                                                         b0_gs_ls_ks_strides_vec),
                     Number<WmmaK>{},
@@ -220,7 +224,8 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
             }
             else
             {
-                return Transform::MakeB1GridDescriptor_BLWmma_NBlockRepeat_NWaves_BLRow_NPerWmma_BL1(
+                return Transform::MakeB1GridDescriptor_BLWmma_NBlockRepeat_NWaves__BL0PerWmma_BLRow_NPerWmma_BL1(
                     Transform::MakeB1GridDescriptor_N_K(b1_gs_ns_ls_lengths_vec,
                                                         b1_gs_ns_ls_strides_vec),
                     Number<WmmaK>{},
@@ -521,7 +526,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
             else
             {
                 return arg.a_grid_desc.GetLength(I0) * arg.a_grid_desc.GetLength(I3) *
-                       arg.a_grid_desc.GetLength(I5);
+                       arg.a_grid_desc.GetLength(I4) * arg.a_grid_desc.GetLength(I6);
             }
         }();
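With the new AK0PerWmma dimension the A grid descriptor is 7-D (see the descriptor makers in transform_contraction_to_gemm.hpp further down), so the K extent recovered here is the product of the lengths at I0, I3, I4, I6 (AKWmma, AK0PerWmma, AKRow, AK1) instead of I0, I3, I5. A compile-time check with assumed example sizes, not taken from this commit, that both products describe the same K:

    // Sketch only (not CK code): the K extent recovered from the old 6-D and
    // new 7-D A-descriptor layouts is the same.
    int main()
    {
        constexpr int K = 128, WmmaK = 16, AK1 = 8;

        // Old 6-D layout: [AKWmma, MBlockRepeat, MWaves, AKRow, MPerWmma, AK1]
        constexpr int OldAKRow = WmmaK / AK1; // 2
        static_assert((K / WmmaK) * OldAKRow * AK1 == K, "old length product equals K");

        // New 7-D layout: [AKWmma, MBlockRepeat, MWaves, AK0PerWmma, AKRow, MPerWmma, AK1]
        constexpr int AKRow      = 2;
        constexpr int AK0PerWmma = WmmaK / AKRow / AK1; // 1
        static_assert((K / WmmaK) * AK0PerWmma * AKRow * AK1 == K, "new length product equals K");

        return 0;
    }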
@@ -826,7 +831,13 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
             << "CSpec" << getTensorSpecializationString(CSpec) << ", "
             << getMaskingSpecializationString(MaskingSpec)
             << ">"
-            << " NumPrefetch: "
-            << NumPrefetch << ", "
+            << " AEnableLds: "
+            << AEnableLds << ", "
+            << "B0EnableLds: "
+            << B0EnableLds << ", "
+            << "B1EnableLds: "
+            << B1EnableLds << ", "
+            << "NumPrefetch: "
+            << NumPrefetch << ", "
             << "LoopScheduler: "
             << LoopSchedToString[LoopSched] << ", "
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp
@@ -468,8 +468,7 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
                 remove_reference_t<typename GridwiseOp::DefaultBlock2CTileMap>,
-                has_main_k_block_loop>;
-            // Last Option is W/O
+                has_main_k_block_loop>; // Last Option is W/O

             return launch_and_time_kernel(stream_config,
                                           kernel,
                                           dim3(grid_size),
                                           dim3(BlockSize),
include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp
This diff is collapsed.
include/ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp
@@ -186,7 +186,7 @@ struct TransformBatchedContractionContractionToBatchedGemmGemm
               typename MPerWmma,
               typename AK1>
     __host__ __device__ static constexpr auto
-    MakeAGridDescriptor_AKWmma_MBlockRepeat_MWaves_AKRow_MPerWmma_AK1(
+    MakeAGridDescriptor_AKWmma_MBlockRepeat_MWaves_AK0PerWmma_AKRow_MPerWmma_AK1(
         const AGridDesc_M_K& a_grid_desc_m_k,
         const WmmaK&,
         const MRepeat&,
@@ -197,14 +197,16 @@ struct TransformBatchedContractionContractionToBatchedGemmGemm
         const auto M0 = a_grid_desc_m_k.GetLength(I0) / MPerBlock;
         const auto K  = a_grid_desc_m_k.GetLength(I1);

         const auto AKWmma = K / WmmaK{};
-        constexpr auto AKRow = WmmaK{} / AK1{};
+        constexpr auto AKRow      = 2;
+        constexpr auto AK0PerWmma = WmmaK{} / AKRow / AK1{};

         return transform_tensor_descriptor(
             a_grid_desc_m_k,
-            make_tuple(make_unmerge_transform(make_tuple(AKWmma, AKRow, AK1{})),
+            make_tuple(make_unmerge_transform(
+                           make_tuple(AKWmma, Number<AK0PerWmma>{}, Number<AKRow>{}, AK1{})),
                        make_unmerge_transform(make_tuple(M0 * MRepeat{}, MWaves{}, MPerWmma{}))),
             make_tuple(Sequence<1>{}, Sequence<0>{}),
-            make_tuple(Sequence<0, 3, 5>{}, Sequence<1, 2, 4>{}));
+            make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{}));
     }
     //
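The B0 and B1 descriptor makers below mirror this A case: the K (or L) axis is now unmerged into four factors (AKWmma, AK0PerWmma, AKRow, AK1) instead of three, so the output dimension order becomes [AKWmma, MBlockRepeat, MWaves, AK0PerWmma, AKRow, MPerWmma, AK1], which is why the lower-dimension index sequences change from <0, 3, 5> / <1, 2, 4> to <0, 3, 4, 6> / <1, 2, 5>. A standalone sketch of the resulting coordinate split, using assumed tile sizes rather than anything taken from this commit:

    // Sketch only (not CK code): decompose an (m, k) element of the flat M x K
    // view into the 7-D coordinate order produced by the new A-descriptor transform.
    #include <array>
    #include <cstdio>

    int main()
    {
        // Assumed example parameters.
        constexpr int WmmaK = 16, AK1 = 8, MWaves = 4, MPerWmma = 16;
        constexpr int AKRow      = 2;                   // fixed to 2 by this change
        constexpr int AK0PerWmma = WmmaK / AKRow / AK1; // 16 / 2 / 8 = 1

        const int m = 70, k = 37; // arbitrary element of the M x K view

        // Unmerge K into (AKWmma, AK0PerWmma, AKRow, AK1), last factor fastest.
        const int ak1        = k % AK1;
        const int akRow      = (k / AK1) % AKRow;
        const int ak0PerWmma = (k / (AK1 * AKRow)) % AK0PerWmma;
        const int akWmma     = k / WmmaK;

        // Unmerge M into (MBlockRepeat, MWaves, MPerWmma).
        const int mPerWmma = m % MPerWmma;
        const int mWave    = (m / MPerWmma) % MWaves;
        const int mRepeat  = m / (MPerWmma * MWaves);

        // Output order: K pieces land in dims {0, 3, 4, 6}, M pieces in dims {1, 2, 5}.
        const std::array<int, 7> coord = {akWmma, mRepeat, mWave, ak0PerWmma, akRow, mPerWmma, ak1};
        for(int d = 0; d < 7; ++d)
            std::printf("dim %d -> %d\n", d, coord[d]);
        return 0;
    }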
@@ -254,7 +256,7 @@ struct TransformBatchedContractionContractionToBatchedGemmGemm
               typename LPerWmma,
               typename BK1>
     __host__ __device__ static constexpr auto
-    MakeB0GridDescriptor_BKWmma_LBlockRepeat_LWaves_BKRow_LPerWmma_BK1(
+    MakeB0GridDescriptor_BKWmma_LBlockRepeat_LWaves_BK0PerWmma_BKRow_LPerWmma_BK1(
         const BGridDesc_L_K& b_grid_desc_l_k,
         const WmmaK&,
         const LRepeat&,
@@ -265,14 +267,16 @@ struct TransformBatchedContractionContractionToBatchedGemmGemm
         const auto L0 = b_grid_desc_l_k.GetLength(I0) / NPerBlock;
         const auto K  = b_grid_desc_l_k.GetLength(I1);

         const auto BKWmma = K / WmmaK{};
-        constexpr auto BKRow = WmmaK{} / BK1{};
+        constexpr auto BKRow      = 2;
+        constexpr auto BK0PerWmma = WmmaK{} / BKRow / BK1{};

         return transform_tensor_descriptor(
             b_grid_desc_l_k,
-            make_tuple(make_unmerge_transform(make_tuple(BKWmma, BKRow, BK1{})),
+            make_tuple(make_unmerge_transform(
+                           make_tuple(BKWmma, Number<BK0PerWmma>{}, Number<BKRow>{}, BK1{})),
                        make_unmerge_transform(make_tuple(L0 * LRepeat{}, LWaves{}, LPerWmma{}))),
             make_tuple(Sequence<1>{}, Sequence<0>{}),
-            make_tuple(Sequence<0, 3, 5>{}, Sequence<1, 2, 4>{}));
+            make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{}));
     }
     //
@@ -323,7 +327,7 @@ struct TransformBatchedContractionContractionToBatchedGemmGemm
               typename NPerWmma,
               typename BL1>
     __host__ __device__ static constexpr auto
-    MakeB1GridDescriptor_BLWmma_NBlockRepeat_NWaves_BLRow_NPerWmma_BL1(
+    MakeB1GridDescriptor_BLWmma_NBlockRepeat_NWaves__BL0PerWmma_BLRow_NPerWmma_BL1(
         const BGridDesc_N_L& b_grid_desc_n_l,
         const WmmaL&,
         const NRepeat&,
@@ -334,14 +338,16 @@ struct TransformBatchedContractionContractionToBatchedGemmGemm
         const auto N0 = b_grid_desc_n_l.GetLength(I0) / OPerBlock;
         const auto L  = b_grid_desc_n_l.GetLength(I1);

         const auto BLWmma = L / WmmaL{};
-        constexpr auto BLRow = WmmaL{} / BL1{};
+        constexpr auto BLRow      = 2;
+        constexpr auto BL0PerWmma = WmmaL{} / BLRow / BL1{};

         return transform_tensor_descriptor(
             b_grid_desc_n_l,
-            make_tuple(make_unmerge_transform(make_tuple(BLWmma, BLRow, BL1{})),
+            make_tuple(make_unmerge_transform(
+                           make_tuple(BLWmma, Number<BL0PerWmma>{}, Number<BLRow>{}, BL1{})),
                        make_unmerge_transform(make_tuple(N0 * NRepeat{}, NWaves{}, NPerWmma{}))),
             make_tuple(Sequence<1>{}, Sequence<0>{}),
-            make_tuple(Sequence<0, 3, 5>{}, Sequence<1, 2, 4>{}));
+            make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{}));
     }
     //