Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
e51806cc
Commit
e51806cc
authored
Apr 06, 2023
by
illsilin
Browse files
merge down from public repo
parents
e0c2a70f
fde6d274
Changes
205
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
292 additions
and
599 deletions
+292
-599
include/ck/tensor_operation/gpu/device/impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp
...u/device/impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp
+13
-1
include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp
...gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp
+2
-1
include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp
...pu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp
+8
-1
include/ck/tensor_operation/gpu/device/impl/device_gemm_bias_add_reduce_xdl_cshuffle.hpp
.../device/impl/device_gemm_bias_add_reduce_xdl_cshuffle.hpp
+9
-1
include/ck/tensor_operation/gpu/device/impl/device_gemm_bias_e_permute_xdl.hpp
...ration/gpu/device/impl/device_gemm_bias_e_permute_xdl.hpp
+13
-1
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_dl.hpp
...r_operation/gpu/device/impl/device_gemm_multiple_d_dl.hpp
+2
-2
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp
...e/impl/device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp
+9
-1
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp
.../gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp
+52
-87
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp
...n/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp
+8
-0
include/ck/tensor_operation/gpu/device/impl/device_gemm_reduce_xdl_cshuffle.hpp
...ation/gpu/device/impl/device_gemm_reduce_xdl_cshuffle.hpp
+9
-1
include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp
.../ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp
+53
-87
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl.hpp
...e/ck/tensor_operation/gpu/device/impl/device_gemm_xdl.hpp
+5
-5
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp
...or_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp
+9
-1
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_layernorm_cshuffle.hpp
...on/gpu/device/impl/device_gemm_xdl_layernorm_cshuffle.hpp
+9
-1
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp
...tion/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp
+54
-403
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp
...vice_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp
+9
-1
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_dl.hpp
...impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_dl.hpp
+2
-1
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_xdl_cshuffle.hpp
...e_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_xdl_cshuffle.hpp
+11
-1
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp
...e_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp
+9
-1
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp
...impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp
+6
-2
No files found.
include/ck/tensor_operation/gpu/device/impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp
View file @
e51806cc
...
...
@@ -631,7 +631,19 @@ struct DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_
<<
BlockSize
<<
", "
<<
MPerBlock
<<
", "
<<
NPerBlock
<<
", "
<<
K0PerBlock
<<
K0PerBlock
<<
", "
<<
K1
<<
", "
<<
MPerXDL
<<
", "
<<
NPerXDL
<<
", "
<<
MXdlPerWave
<<
", "
<<
NXdlPerWave
<<
", "
<<
ABlockTransferSrcScalarPerVector
<<
", "
<<
ABlockTransferDstScalarPerVector_K1
<<
", "
<<
BBlockTransferSrcScalarPerVector
<<
", "
<<
BBlockTransferDstScalarPerVector_K1
<<
", "
<<
CShuffleMXdlPerWavePerShuffle
<<
", "
<<
CShuffleNXdlPerWavePerShuffle
<<
", "
<<
CBlockTransferScalarPerVector_NWaveNPerXdl
<<
">"
;
// clang-format on
...
...
include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp
View file @
e51806cc
...
...
@@ -1567,7 +1567,8 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Dl
<<
BlockSize
<<
", "
<<
MPerBlock
<<
", "
<<
NPerBlock
<<
", "
<<
K0PerBlock
<<
K0PerBlock
<<
", "
<<
K1
<<
">"
;
if
constexpr
(
ConvBackwardDataSpecialization
==
ConvolutionBackwardDataSpecialization
::
Filter1x1Stride1Pad0
){
...
...
include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp
View file @
e51806cc
...
...
@@ -1552,7 +1552,14 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
<<
BlockSize
<<
", "
<<
MPerBlock
<<
", "
<<
NPerBlock
<<
", "
<<
K0PerBlock
<<
K0PerBlock
<<
", "
<<
K1
<<
", "
<<
MXdlPerWave
<<
", "
<<
NXdlPerWave
<<
", "
<<
ABlockTransferSrcScalarPerVector
<<
", "
<<
ABlockTransferDstScalarPerVector_K1
<<
", "
<<
BBlockTransferSrcScalarPerVector
<<
", "
<<
BBlockTransferDstScalarPerVector_K1
<<
">"
;
if
constexpr
(
ConvBackwardDataSpecialization
==
ConvolutionBackwardDataSpecialization
::
Filter1x1Stride1Pad0
){
...
...
include/ck/tensor_operation/gpu/device/impl/device_gemm_bias_add_reduce_xdl_cshuffle.hpp
View file @
e51806cc
...
...
@@ -862,7 +862,15 @@ struct DeviceGemmBiasAddReduce_Xdl_CShuffle : public DeviceGemmReduce<1, ReduceO
<<
NPerBlock
<<
", "
<<
KPerBlock
<<
", "
<<
AK1
<<
", "
<<
BK1
<<
BK1
<<
", "
<<
MPerXDL
<<
", "
<<
NPerXDL
<<
", "
<<
MXdlPerWave
<<
", "
<<
NXdlPerWave
<<
", "
<<
ABlockTransferSrcScalarPerVector
<<
", "
<<
BBlockTransferSrcScalarPerVector
<<
", "
<<
CShuffleMXdlPerWavePerShuffle
<<
", "
<<
CShuffleNXdlPerWavePerShuffle
<<
">"
;
// clang-format on
...
...
include/ck/tensor_operation/gpu/device/impl/device_gemm_bias_e_permute_xdl.hpp
View file @
e51806cc
...
...
@@ -561,7 +561,19 @@ struct DeviceGemmBiasEPermute_Xdl : public DeviceGemmBiasCPermute<AElementwiseOp
<<
NPerBlock
<<
", "
<<
KPerBlock
<<
", "
<<
AK1
<<
", "
<<
BK1
<<
BK1
<<
", "
<<
K1
<<
", "
<<
MPerXDL
<<
", "
<<
NPerXDL
<<
", "
<<
MXdlPerWave
<<
", "
<<
NXdlPerWave
<<
", "
<<
ABlockTransferSrcScalarPerVector
<<
", "
<<
ABlockTransferDstScalarPerVector_K1
<<
", "
<<
BBlockTransferSrcScalarPerVector
<<
", "
<<
BBlockTransferDstScalarPerVector_K1
<<
", "
<<
CShuffleMXdlPerWavePerShuffle
<<
", "
<<
CShuffleNXdlPerWavePerShuffle
<<
", "
<<
CBlockTransferScalarPerVector_NWaveNPerXdl
<<
">"
;
// clang-format on
...
...
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_dl.hpp
View file @
e51806cc
...
...
@@ -51,7 +51,7 @@ __global__ void
const
Block2CTileMap
block_2_ctile_map
)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
defined(__gfx1030__))
defined(__gfx90a__) ||
defined(__gfx1030__))
constexpr
index_t
shared_block_size
=
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()
/
sizeof
(
ABDataType
);
...
...
@@ -552,7 +552,7 @@ struct DeviceGemmMultipleD_Dl : public DeviceGemmMultipleD<ALayout,
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
if
(
ck
::
get_device_name
()
==
"gfx906"
||
ck
::
get_device_name
()
==
"gfx908"
||
ck
::
get_device_name
()
==
"gfx1030"
)
ck
::
get_device_name
()
==
"gfx90a"
||
ck
::
get_device_name
()
==
"gfx1030"
)
{
return
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_k0_m_k1_
,
arg
.
b_grid_desc_k0_n_k1_
,
arg
.
e_grid_desc_m_n_
);
...
...
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp
View file @
e51806cc
...
...
@@ -671,7 +671,15 @@ struct DeviceGemmMultipleDMultipleR_Xdl_CShuffle
<<
KPerBlock
<<
", "
<<
AK1
<<
", "
<<
BK1
<<
", "
<<
getGemmSpecializationString
(
GemmSpec
)
<<
getGemmSpecializationString
(
GemmSpec
)
<<
", "
<<
MPerXDL
<<
", "
<<
NPerXDL
<<
", "
<<
MXdlPerWave
<<
", "
<<
NXdlPerWave
<<
", "
<<
ABlockTransferSrcScalarPerVector
<<
", "
<<
BBlockTransferSrcScalarPerVector
<<
", "
<<
CShuffleMXdlPerWavePerShuffle
<<
", "
<<
CShuffleNXdlPerWavePerShuffle
<<
">"
;
// clang-format on
...
...
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp
View file @
e51806cc
...
...
@@ -86,120 +86,84 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
// K1 = Max Vector Access Pixels
static
constexpr
auto
K1Number
=
Number
<
K1
>
{};
static
auto
MakeAGridDescriptor_K0_M_K1
(
index_t
M
,
index_t
K
,
index_t
StrideA
)
{
assert
(
K
%
K1
==
0
);
const
index_t
K0
=
K
/
K1
;
static
constexpr
auto
matrix_padder
=
MatrixPadder
<
GemmSpec
,
index_t
,
index_t
,
index_t
>
{
MPerBlock
,
NPerBlock
,
K0PerBlock
*
K1
};
const
auto
a_grid_desc_m_k
=
[
&
]()
{
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
ALayout
>::
value
)
static
auto
MakeAGridDescriptor_K0_M_K1
(
index_t
MRaw
,
index_t
KRaw
,
index_t
StrideA
)
{
const
auto
a_grid_desc_mraw_kraw
=
[
&
]()
{
if
constexpr
(
is_same_v
<
tensor_layout
::
gemm
::
RowMajor
,
ALayout
>
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
M
,
K
),
make_tuple
(
StrideA
,
I1
));
return
make_naive_tensor_descriptor
(
make_tuple
(
MRaw
,
KRaw
),
make_tuple
(
StrideA
,
I1
));
}
#ifdef ENABLE_COLMAJOR
else
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
ColumnMajor
,
ALayout
>::
value
)
else
if
constexpr
(
is_same_v
<
tensor_layout
::
gemm
::
ColumnMajor
,
ALayout
>
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
M
,
K
),
make_tuple
(
I1
,
StrideA
));
return
make_naive_tensor_descriptor
(
make_tuple
(
MRaw
,
KRaw
),
make_tuple
(
I1
,
StrideA
));
}
#endif
}();
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
MNPadding
)
{
const
auto
PadM
=
(
MPerBlock
-
M
%
MPerBlock
)
%
MPerBlock
;
return
transform_tensor_descriptor
(
a_grid_desc_m_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
K0
,
K1Number
)),
make_right_pad_transform
(
M
,
PadM
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
}
else
{
return
transform_tensor_descriptor
(
a_grid_desc_m_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
K0
,
K1Number
)),
make_pass_through_transform
(
M
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
}
}
static
auto
MakeBGridDescriptor_K0_N_K1
(
index_t
K
,
index_t
N
,
index_t
StrideB
)
{
const
auto
a_grid_desc_m_k
=
matrix_padder
.
PadADescriptor_M_K
(
a_grid_desc_mraw_kraw
);
const
auto
M
=
a_grid_desc_m_k
.
GetLength
(
I0
);
const
auto
K
=
a_grid_desc_m_k
.
GetLength
(
I1
);
assert
(
K
%
K1
==
0
);
const
index_t
K0
=
K
/
K1
;
const
auto
b_grid_desc_k_n
=
[
&
]()
{
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
BLayout
>::
value
)
return
transform_tensor_descriptor
(
a_grid_desc_m_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
K0
,
K1Number
)),
make_pass_through_transform
(
M
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
}
static
auto
MakeBGridDescriptor_K0_N_K1
(
index_t
KRaw
,
index_t
NRaw
,
index_t
StrideB
)
{
const
auto
b_grid_desc_nraw_kraw
=
[
&
]()
{
if
constexpr
(
is_same_v
<
tensor_layout
::
gemm
::
RowMajor
,
BLayout
>
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
K
,
N
),
make_tuple
(
StrideB
,
I1
));
return
make_naive_tensor_descriptor
(
make_tuple
(
NRaw
,
KRaw
),
make_tuple
(
I1
,
StrideB
));
}
else
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
ColumnMajor
,
BLayout
>
::
value
)
else
if
constexpr
(
is_same
_v
<
tensor_layout
::
gemm
::
ColumnMajor
,
BLayout
>
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
K
,
N
),
make_tuple
(
I1
,
StrideB
));
return
make_naive_tensor_descriptor
(
make_tuple
(
NRaw
,
KRaw
),
make_tuple
(
StrideB
,
I1
));
}
}();
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
MNPadding
)
{
const
auto
PadN
=
(
NPerBlock
-
N
%
NPerBlock
)
%
NPerBlock
;
return
transform_tensor_descriptor
(
b_grid_desc_k_n
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
K0
,
K1Number
)),
make_right_pad_transform
(
N
,
PadN
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
}
else
{
return
transform_tensor_descriptor
(
b_grid_desc_k_n
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
K0
,
K1Number
)),
make_pass_through_transform
(
N
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
}
const
auto
b_grid_desc_n_k
=
matrix_padder
.
PadBDescriptor_N_K
(
b_grid_desc_nraw_kraw
);
const
auto
N
=
b_grid_desc_n_k
.
GetLength
(
I0
);
const
auto
K
=
b_grid_desc_n_k
.
GetLength
(
I1
);
assert
(
K
%
K1
==
0
);
const
index_t
K0
=
K
/
K1
;
return
transform_tensor_descriptor
(
b_grid_desc_n_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
K0
,
K1Number
)),
make_pass_through_transform
(
N
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
}
template
<
typename
ELayout_
>
static
auto
MakeEGridDescriptor_M_N
(
index_t
M
,
index_t
N
,
index_t
StrideE
)
static
auto
MakeEGridDescriptor_M_N
(
index_t
M
Raw
,
index_t
N
Raw
,
index_t
StrideE
)
{
const
auto
e_grid_desc_m
_n
=
[
&
]()
{
const
auto
e_grid_desc_m
raw_nraw
=
[
&
]()
{
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
ELayout_
>::
value
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
M
,
N
),
make_tuple
(
StrideE
,
I1
));
return
make_naive_tensor_descriptor
(
make_tuple
(
MRaw
,
NRaw
),
make_tuple
(
StrideE
,
I1
));
}
else
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
ColumnMajor
,
ELayout_
>::
value
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
M
,
N
),
make_tuple
(
I1
,
StrideE
));
return
make_naive_tensor_descriptor
(
make_tuple
(
MRaw
,
NRaw
),
make_tuple
(
I1
,
StrideE
));
}
}();
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
MNPadding
)
{
const
auto
PadM
=
(
MPerBlock
-
M
%
MPerBlock
)
%
MPerBlock
;
const
auto
PadN
=
(
NPerBlock
-
N
%
NPerBlock
)
%
NPerBlock
;
return
transform_tensor_descriptor
(
e_grid_desc_m_n
,
make_tuple
(
make_right_pad_transform
(
M
,
PadM
),
make_right_pad_transform
(
N
,
PadN
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
}
else
{
return
transform_tensor_descriptor
(
e_grid_desc_m_n
,
make_tuple
(
make_pass_through_transform
(
M
),
make_pass_through_transform
(
N
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
}
return
matrix_padder
.
PadCDescriptor_M_N
(
e_grid_desc_mraw_nraw
);
}
static
auto
MakeDsGridDescriptor_M_N
(
const
std
::
array
<
index_t
,
NumDTensor
>&
Ms
,
...
...
@@ -512,7 +476,8 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
if
(
ck
::
get_device_name
()
==
"gfx1100"
)
if
(
ck
::
get_device_name
()
==
"gfx1100"
||
ck
::
get_device_name
()
==
"gfx1101"
||
ck
::
get_device_name
()
==
"gfx1102"
)
{
if
constexpr
(
!
(
is_same_v
<
AccDataType
,
float
>
||
is_same_v
<
AccDataType
,
int32_t
>
))
{
...
...
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp
View file @
e51806cc
...
...
@@ -682,6 +682,14 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
<<
KPerBlock
<<
", "
<<
AK1
<<
", "
<<
BK1
<<
", "
<<
MPerXDL
<<
", "
<<
NPerXDL
<<
", "
<<
MXdlPerWave
<<
", "
<<
NXdlPerWave
<<
", "
<<
ABlockTransferSrcScalarPerVector
<<
", "
<<
BBlockTransferSrcScalarPerVector
<<
", "
<<
CShuffleMXdlPerWavePerShuffle
<<
", "
<<
CShuffleNXdlPerWavePerShuffle
<<
", "
<<
getGemmSpecializationString
(
GemmSpec
)
<<
">"
<<
" LoopScheduler: "
...
...
include/ck/tensor_operation/gpu/device/impl/device_gemm_reduce_xdl_cshuffle.hpp
View file @
e51806cc
...
...
@@ -822,7 +822,15 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceOperatio
<<
NPerBlock
<<
", "
<<
KPerBlock
<<
", "
<<
AK1
<<
", "
<<
BK1
<<
BK1
<<
", "
<<
MPerXDL
<<
", "
<<
NPerXDL
<<
", "
<<
MXdlPerWave
<<
", "
<<
NXdlPerWave
<<
", "
<<
ABlockTransferSrcScalarPerVector
<<
", "
<<
BBlockTransferSrcScalarPerVector
<<
", "
<<
CShuffleMXdlPerWavePerShuffle
<<
", "
<<
CShuffleNXdlPerWavePerShuffle
<<
">"
;
// clang-format on
...
...
include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp
View file @
e51806cc
...
...
@@ -12,6 +12,7 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
...
...
@@ -78,119 +79,83 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
// K1 = Max Vector Access Pixels
static
constexpr
auto
K1Number
=
Number
<
K1
>
{};
static
auto
MakeAGridDescriptor_K0_M_K1
(
index_t
M
,
index_t
K
,
index_t
StrideA
)
{
assert
(
K
%
K1
==
0
);
const
index_t
K0
=
K
/
K1
;
static
constexpr
auto
matrix_padder
=
MatrixPadder
<
GemmSpec
,
index_t
,
index_t
,
index_t
>
{
MPerBlock
,
NPerBlock
,
K0PerBlock
*
K1
};
const
auto
a_grid_desc_m_k
=
[
&
]()
{
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
ALayout
>::
value
)
static
auto
MakeAGridDescriptor_K0_M_K1
(
index_t
MRaw
,
index_t
KRaw
,
index_t
StrideA
)
{
const
auto
a_grid_desc_mraw_kraw
=
[
&
]()
{
if
constexpr
(
is_same_v
<
tensor_layout
::
gemm
::
RowMajor
,
ALayout
>
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
M
,
K
),
make_tuple
(
StrideA
,
I1
));
return
make_naive_tensor_descriptor
(
make_tuple
(
MRaw
,
KRaw
),
make_tuple
(
StrideA
,
I1
));
}
#ifdef ENABLE_COLMAJOR
else
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
ColumnMajor
,
ALayout
>::
value
)
else
if
constexpr
(
is_same_v
<
tensor_layout
::
gemm
::
ColumnMajor
,
ALayout
>
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
M
,
K
),
make_tuple
(
I1
,
StrideA
));
return
make_naive_tensor_descriptor
(
make_tuple
(
MRaw
,
KRaw
),
make_tuple
(
I1
,
StrideA
));
}
#endif
}();
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
MNPadding
)
{
const
auto
PadM
=
(
MPerBlock
-
M
%
MPerBlock
)
%
MPerBlock
;
return
transform_tensor_descriptor
(
a_grid_desc_m_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
K0
,
K1Number
)),
make_right_pad_transform
(
M
,
PadM
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
}
else
{
return
transform_tensor_descriptor
(
a_grid_desc_m_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
K0
,
K1Number
)),
make_pass_through_transform
(
M
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
}
}
static
auto
MakeBGridDescriptor_K0_N_K1
(
index_t
K
,
index_t
N
,
index_t
StrideB
)
{
const
auto
a_grid_desc_m_k
=
matrix_padder
.
PadADescriptor_M_K
(
a_grid_desc_mraw_kraw
);
const
auto
M
=
a_grid_desc_m_k
.
GetLength
(
I0
);
const
auto
K
=
a_grid_desc_m_k
.
GetLength
(
I1
);
assert
(
K
%
K1
==
0
);
const
index_t
K0
=
K
/
K1
;
const
auto
b_grid_desc_k_n
=
[
&
]()
{
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
BLayout
>::
value
)
return
transform_tensor_descriptor
(
a_grid_desc_m_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
K0
,
K1Number
)),
make_pass_through_transform
(
M
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
}
static
auto
MakeBGridDescriptor_K0_N_K1
(
index_t
KRaw
,
index_t
NRaw
,
index_t
StrideB
)
{
const
auto
b_grid_desc_nraw_kraw
=
[
&
]()
{
if
constexpr
(
is_same_v
<
tensor_layout
::
gemm
::
RowMajor
,
BLayout
>
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
K
,
N
),
make_tuple
(
StrideB
,
I1
));
return
make_naive_tensor_descriptor
(
make_tuple
(
NRaw
,
KRaw
),
make_tuple
(
I1
,
StrideB
));
}
else
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
ColumnMajor
,
BLayout
>
::
value
)
else
if
constexpr
(
is_same
_v
<
tensor_layout
::
gemm
::
ColumnMajor
,
BLayout
>
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
K
,
N
),
make_tuple
(
I1
,
StrideB
));
return
make_naive_tensor_descriptor
(
make_tuple
(
NRaw
,
KRaw
),
make_tuple
(
StrideB
,
I1
));
}
}();
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
MNPadding
)
{
const
auto
PadN
=
(
NPerBlock
-
N
%
NPerBlock
)
%
NPerBlock
;
return
transform_tensor_descriptor
(
b_grid_desc_k_n
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
K0
,
K1Number
)),
make_right_pad_transform
(
N
,
PadN
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
}
else
{
return
transform_tensor_descriptor
(
b_grid_desc_k_n
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
K0
,
K1Number
)),
make_pass_through_transform
(
N
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
}
const
auto
b_grid_desc_n_k
=
matrix_padder
.
PadBDescriptor_N_K
(
b_grid_desc_nraw_kraw
);
const
auto
N
=
b_grid_desc_n_k
.
GetLength
(
I0
);
const
auto
K
=
b_grid_desc_n_k
.
GetLength
(
I1
);
assert
(
K
%
K1
==
0
);
const
index_t
K0
=
K
/
K1
;
return
transform_tensor_descriptor
(
b_grid_desc_n_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
K0
,
K1Number
)),
make_pass_through_transform
(
N
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
}
static
auto
MakeCGridDescriptor_M_N
(
index_t
M
,
index_t
N
,
index_t
StrideC
)
static
auto
MakeCGridDescriptor_M_N
(
index_t
M
Raw
,
index_t
N
Raw
,
index_t
StrideC
)
{
const
auto
c_grid_desc_m
_n
=
[
&
]()
{
const
auto
c_grid_desc_m
raw_nraw
=
[
&
]()
{
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
CLayout
>::
value
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
M
,
N
),
make_tuple
(
StrideC
,
I1
));
return
make_naive_tensor_descriptor
(
make_tuple
(
MRaw
,
NRaw
),
make_tuple
(
StrideC
,
I1
));
}
else
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
ColumnMajor
,
CLayout
>::
value
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
M
,
N
),
make_tuple
(
I1
,
StrideC
));
return
make_naive_tensor_descriptor
(
make_tuple
(
MRaw
,
NRaw
),
make_tuple
(
I1
,
StrideC
));
}
}();
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
MNPadding
)
{
const
auto
PadM
=
(
MPerBlock
-
M
%
MPerBlock
)
%
MPerBlock
;
const
auto
PadN
=
(
NPerBlock
-
N
%
NPerBlock
)
%
NPerBlock
;
return
transform_tensor_descriptor
(
c_grid_desc_m_n
,
make_tuple
(
make_right_pad_transform
(
M
,
PadM
),
make_right_pad_transform
(
N
,
PadN
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
}
else
{
return
transform_tensor_descriptor
(
c_grid_desc_m_n
,
make_tuple
(
make_pass_through_transform
(
M
),
make_pass_through_transform
(
N
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
}
return
matrix_padder
.
PadCDescriptor_M_N
(
c_grid_desc_mraw_nraw
);
}
// Gridwise descriptor, mapping to whole given provblem.
...
...
@@ -439,7 +404,8 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
if
(
ck
::
get_device_name
()
==
"gfx1100"
)
if
(
ck
::
get_device_name
()
==
"gfx1100"
||
ck
::
get_device_name
()
==
"gfx1101"
||
ck
::
get_device_name
()
==
"gfx1102"
)
{
if
constexpr
(
!
(
is_same_v
<
AccDataType
,
float
>
||
is_same_v
<
AccDataType
,
int32_t
>
))
{
...
...
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl.hpp
View file @
e51806cc
...
...
@@ -77,8 +77,6 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
static
auto
MakeAGridDescriptor_K0_M_K1
(
index_t
M
,
index_t
K
,
index_t
StrideA
)
{
assert
(
K
%
K1
==
0
);
const
index_t
K0
=
K
/
K1
;
const
auto
a_grid_desc_m_k
=
[
&
]()
{
...
...
@@ -116,8 +114,6 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
static
auto
MakeBGridDescriptor_K0_N_K1
(
index_t
K
,
index_t
N
,
index_t
StrideB
)
{
assert
(
K
%
K1
==
0
);
const
index_t
K0
=
K
/
K1
;
const
auto
b_grid_desc_k_n
=
[
&
]()
{
...
...
@@ -551,7 +547,11 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
<<
MPerXDL
<<
", "
<<
NPerXDL
<<
", "
<<
MXdlPerWave
<<
", "
<<
NXdlPerWave
<<
NXdlPerWave
<<
", "
<<
ABlockTransferSrcScalarPerVector
<<
", "
<<
ABlockTransferDstScalarPerVector_K1
<<
", "
<<
BBlockTransferSrcScalarPerVector
<<
", "
<<
BBlockTransferDstScalarPerVector_K1
<<
">"
<<
" NumPrefetch: "
<<
NumPrefetch
<<
", "
...
...
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp
View file @
e51806cc
...
...
@@ -683,7 +683,15 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
<<
NPerBlock
<<
", "
<<
KPerBlock
<<
", "
<<
AK1
<<
", "
<<
BK1
<<
BK1
<<
", "
<<
MPerXDL
<<
", "
<<
NPerXDL
<<
", "
<<
MXdlPerWave
<<
", "
<<
NXdlPerWave
<<
", "
<<
ABlockTransferSrcScalarPerVector
<<
", "
<<
BBlockTransferSrcScalarPerVector
<<
", "
<<
CShuffleMXdlPerWavePerShuffle
<<
", "
<<
CShuffleNXdlPerWavePerShuffle
<<
">"
<<
" LoopScheduler: "
<<
LoopSchedToString
[
LoopSched
]
<<
", "
...
...
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_layernorm_cshuffle.hpp
View file @
e51806cc
...
...
@@ -761,7 +761,15 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator
<<
NPerBlock
<<
", "
<<
KPerBlock
<<
", "
<<
AK1
<<
", "
<<
BK1
<<
BK1
<<
", "
<<
MPerXDL
<<
", "
<<
NPerXDL
<<
", "
<<
MXdlPerWave
<<
", "
<<
NXdlPerWave
<<
", "
<<
ABlockTransferSrcScalarPerVector
<<
", "
<<
BBlockTransferSrcScalarPerVector
<<
", "
<<
CShuffleMXdlPerWavePerShuffle
<<
", "
<<
CShuffleNXdlPerWavePerShuffle
<<
">"
;
// clang-format on
...
...
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp
View file @
e51806cc
...
...
@@ -73,157 +73,18 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
static
constexpr
auto
K1Number
=
Number
<
K1
>
{};
static
auto
MakeAGridDescriptor_KBatch_K0_M_K1
(
index_t
M
,
index_t
K
,
index_t
StrideA
,
int
KBatch
,
int
KPad
)
{
assert
(
KPad
%
(
K1
*
KBatch
)
==
0
);
const
index_t
K0
=
KPad
/
(
K1
*
KBatch
);
const
auto
a_grid_desc_m_k
=
[
&
]()
{
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
ALayout
>::
value
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
M
,
K
),
make_tuple
(
StrideA
,
I1
));
}
else
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
ColumnMajor
,
ALayout
>::
value
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
M
,
K
),
make_tuple
(
I1
,
StrideA
));
}
}();
const
auto
a_grid_desc_m_kpad
=
transform_tensor_descriptor
(
a_grid_desc_m_k
,
make_tuple
(
make_pass_through_transform
(
M
),
make_right_pad_transform
(
K
,
KPad
-
K
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
MNPadding
)
{
const
auto
PadM
=
(
MPerBlock
-
M
%
MPerBlock
)
%
MPerBlock
;
return
transform_tensor_descriptor
(
a_grid_desc_m_kpad
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
KBatch
,
K0
,
K1Number
)),
make_right_pad_transform
(
M
,
PadM
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
1
,
3
>
{},
Sequence
<
2
>
{}));
}
else
{
return
transform_tensor_descriptor
(
a_grid_desc_m_kpad
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
KBatch
,
K0
,
K1Number
)),
make_pass_through_transform
(
M
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
1
,
3
>
{},
Sequence
<
2
>
{}));
}
}
static
auto
MakeBGridDescriptor_KBatch_K0_N_K1
(
index_t
K
,
index_t
N
,
index_t
StrideB
,
int
KBatch
,
int
KPad
)
{
assert
(
KPad
%
(
K1
*
KBatch
)
==
0
);
const
index_t
K0
=
KPad
/
(
K1
*
KBatch
);
const
auto
b_grid_desc_k_n
=
[
&
]()
{
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
BLayout
>::
value
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
K
,
N
),
make_tuple
(
StrideB
,
I1
));
}
else
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
ColumnMajor
,
BLayout
>::
value
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
K
,
N
),
make_tuple
(
I1
,
StrideB
));
}
}();
const
auto
b_grid_desc_kpad_n
=
transform_tensor_descriptor
(
b_grid_desc_k_n
,
make_tuple
(
make_right_pad_transform
(
K
,
KPad
-
K
),
make_pass_through_transform
(
N
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
MNPadding
)
{
const
auto
PadN
=
(
NPerBlock
-
N
%
NPerBlock
)
%
NPerBlock
;
return
transform_tensor_descriptor
(
b_grid_desc_kpad_n
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
KBatch
,
K0
,
K1Number
)),
make_right_pad_transform
(
N
,
PadN
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
1
,
3
>
{},
Sequence
<
2
>
{}));
}
else
{
return
transform_tensor_descriptor
(
b_grid_desc_kpad_n
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
KBatch
,
K0
,
K1Number
)),
make_pass_through_transform
(
N
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
1
,
3
>
{},
Sequence
<
2
>
{}));
}
}
static
auto
MakeCGridDescriptor_M_N
(
index_t
M
,
index_t
N
,
index_t
StrideC
)
{
const
auto
c_grid_desc_m_n
=
[
&
]()
{
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
CLayout
>::
value
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
M
,
N
),
make_tuple
(
StrideC
,
I1
));
}
else
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
ColumnMajor
,
CLayout
>::
value
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
M
,
N
),
make_tuple
(
I1
,
StrideC
));
}
}();
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
MNPadding
)
{
const
auto
PadM
=
(
MPerBlock
-
M
%
MPerBlock
)
%
MPerBlock
;
const
auto
PadN
=
(
NPerBlock
-
N
%
NPerBlock
)
%
NPerBlock
;
return
transform_tensor_descriptor
(
c_grid_desc_m_n
,
make_tuple
(
make_right_pad_transform
(
M
,
PadM
),
make_right_pad_transform
(
N
,
PadN
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
}
else
{
return
transform_tensor_descriptor
(
c_grid_desc_m_n
,
make_tuple
(
make_pass_through_transform
(
M
),
make_pass_through_transform
(
N
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
}
}
static
auto
GetKPad
(
index_t
K
,
index_t
KBatch
)
{
const
index_t
K0
=
math
::
integer_divide_ceil
(
K
,
K1
*
K0PerBlock
*
KBatch
)
*
K0PerBlock
;
const
index_t
KPad
=
KBatch
*
K0
*
K1
;
return
KPad
;
}
using
AGridDesc_K0_M_K1
=
decltype
(
MakeAGridDescriptor_KBatch_K0_M_K1
(
1
,
1
,
1
,
1
,
1
));
using
BGridDesc_K0_N_K1
=
decltype
(
MakeBGridDescriptor_KBatch_K0_N_K1
(
1
,
1
,
1
,
1
,
1
));
using
CGridDesc_M_N
=
decltype
(
MakeCGridDescriptor_M_N
(
1
,
1
,
1
));
// GridwiseGemm
using
GridwiseGemm
=
GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
<
BlockSize
,
ADataType
,
// TODO: distinguish A/B datatype
AccDataType
,
CDataType
,
InMemoryDataOperationEnum
::
Set
,
AGridDesc_K0_M_K1
,
BGridDesc_K0_N_K1
,
CGridDesc_M_N
,
ALayout
,
BLayout
,
CLayout
,
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
,
GemmSpec
,
MPerBlock
,
NPerBlock
,
K0PerBlock
,
...
...
@@ -253,236 +114,64 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
CBlockTransferScalarPerVector_NWaveNPerXDL
,
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
>
;
// GridwiseGemm
using
GridwiseGemmAtomicAdd
=
GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
<
BlockSize
,
ADataType
,
// TODO: distinguish A/B datatype
AccDataType
,
CDataType
,
InMemoryDataOperationEnum
::
AtomicAdd
,
AGridDesc_K0_M_K1
,
BGridDesc_K0_N_K1
,
CGridDesc_M_N
,
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
,
MPerBlock
,
NPerBlock
,
K0PerBlock
,
MPerXDL
,
NPerXDL
,
K1
,
MXdlPerWave
,
NXdlPerWave
,
ABlockTransferThreadClusterLengths_K0_M_K1
,
ABlockTransferThreadClusterArrangeOrder
,
ABlockTransferSrcAccessOrder
,
ABlockTransferSrcVectorDim
,
ABlockTransferSrcScalarPerVector
,
ABlockTransferDstScalarPerVector_K1
,
false
,
// AThreadTransferSrcResetCoordinateAfterRun,
ABlockLdsAddExtraM
,
BBlockTransferThreadClusterLengths_K0_N_K1
,
BBlockTransferThreadClusterArrangeOrder
,
BBlockTransferSrcAccessOrder
,
BBlockTransferSrcVectorDim
,
BBlockTransferSrcScalarPerVector
,
BBlockTransferDstScalarPerVector_K1
,
false
,
// BThreadTransferSrcResetCoordinateAfterRun,
BBlockLdsAddExtraN
,
CShuffleMRepeatPerShuffle
,
CShuffleNRepeatPerShuffle
,
CBlockTransferScalarPerVector_NWaveNPerXDL
,
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
>
;
using
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
=
decltype
(
GridwiseGemm
::
MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
(
CGridDesc_M_N
{}));
using
Block2CTileMap
=
typename
GridwiseGemm
::
CBlockClusterAdaptor
;
// Argument
struct
Argument
:
public
BaseArgument
{
Argument
(
const
ADataType
*
p_a_grid
,
const
BDataType
*
p_b_grid
,
CDataType
*
p_c_grid
,
index_t
M
,
index_t
N
,
index_t
K
,
index_t
StrideA
,
index_t
StrideB
,
index_t
StrideC
,
index_t
M01
,
index_t
N01
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
,
index_t
k_batch
)
:
p_a_grid_
{
p_a_grid
},
p_b_grid_
{
p_b_grid
},
p_c_grid_
{
p_c_grid
},
a_grid_desc_kbatch_k0_m_k1_
{},
b_grid_desc_kbatch_k0_n_k1_
{},
c_grid_desc_m_n_
{},
c_grid_desc_mblock_mperblock_nblock_nperblock_
{},
block_2_ctile_map_
{},
M01_
{
M01
},
N01_
{
N01
},
a_element_op_
{
a_element_op
},
b_element_op_
{
b_element_op
},
c_element_op_
{
c_element_op
},
k_batch_
{
k_batch
}
{
int
KPad
=
DeviceGemmXdlSplitKCShuffle
::
GetKPad
(
K
,
k_batch_
);
a_grid_desc_kbatch_k0_m_k1_
=
DeviceGemmXdlSplitKCShuffle
::
MakeAGridDescriptor_KBatch_K0_M_K1
(
M
,
K
,
StrideA
,
k_batch_
,
KPad
);
b_grid_desc_kbatch_k0_n_k1_
=
DeviceGemmXdlSplitKCShuffle
::
MakeBGridDescriptor_KBatch_K0_N_K1
(
K
,
N
,
StrideB
,
k_batch_
,
KPad
);
c_grid_desc_m_n_
=
DeviceGemmXdlSplitKCShuffle
::
MakeCGridDescriptor_M_N
(
M
,
N
,
StrideC
);
block_2_ctile_map_
=
GridwiseGemm
::
MakeCBlockClusterAdaptor
(
c_grid_desc_m_n_
,
M01
,
N01
,
k_batch_
);
if
(
GridwiseGemm
::
CheckValidity
(
a_grid_desc_kbatch_k0_m_k1_
,
b_grid_desc_kbatch_k0_n_k1_
,
c_grid_desc_m_n_
,
block_2_ctile_map_
))
{
c_grid_desc_mblock_mperblock_nblock_nperblock_
=
GridwiseGemm
::
MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
(
c_grid_desc_m_n_
);
}
}
// private:
const
ADataType
*
p_a_grid_
;
const
BDataType
*
p_b_grid_
;
CDataType
*
p_c_grid_
;
AGridDesc_K0_M_K1
a_grid_desc_kbatch_k0_m_k1_
;
BGridDesc_K0_N_K1
b_grid_desc_kbatch_k0_n_k1_
;
CGridDesc_M_N
c_grid_desc_m_n_
;
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock_
;
Block2CTileMap
block_2_ctile_map_
;
index_t
M01_
;
index_t
N01_
;
AElementwiseOperation
a_element_op_
;
BElementwiseOperation
b_element_op_
;
CElementwiseOperation
c_element_op_
;
index_t
k_batch_
;
};
using
Argument
=
typename
GridwiseGemm
::
Argument
;
// Invoker
struct
Invoker
:
public
BaseInvoker
{
using
Argument
=
DeviceGemmXdlSplitKCShuffle
::
Argument
;
void
Print
(
const
Argument
&
arg
)
{
std
::
cout
<<
"arg.a_grid_desc_kbatch_k0_m_k1_{"
<<
arg
.
a_grid_desc_kbatch_k0_m_k1_
.
GetLength
(
I0
)
<<
", "
<<
arg
.
a_grid_desc_kbatch_k0_m_k1_
.
GetLength
(
I1
)
<<
", "
<<
arg
.
a_grid_desc_kbatch_k0_m_k1_
.
GetLength
(
I2
)
<<
", "
<<
arg
.
a_grid_desc_kbatch_k0_m_k1_
.
GetLength
(
I3
)
<<
"}"
<<
std
::
endl
;
std
::
cout
<<
"arg.b_grid_desc_kbatch_k0_n_k1_{"
<<
arg
.
b_grid_desc_kbatch_k0_n_k1_
.
GetLength
(
I0
)
<<
", "
<<
arg
.
b_grid_desc_kbatch_k0_n_k1_
.
GetLength
(
I1
)
<<
", "
<<
arg
.
b_grid_desc_kbatch_k0_n_k1_
.
GetLength
(
I2
)
<<
", "
<<
arg
.
b_grid_desc_kbatch_k0_n_k1_
.
GetLength
(
I3
)
<<
"}"
<<
std
::
endl
;
std
::
cout
<<
"arg.c_grid_desc_m_n_{ "
<<
arg
.
c_grid_desc_m_n_
.
GetLength
(
I0
)
<<
", "
<<
arg
.
c_grid_desc_m_n_
.
GetLength
(
I1
)
<<
"}"
<<
std
::
endl
;
}
void
Print
(
const
Argument
&
karg
)
{
karg
.
Print
();
}
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
float
Run
(
const
Argument
&
k
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
if
(
stream_config
.
log_level_
>
0
)
{
Print
(
arg
);
Print
(
k
arg
);
}
const
auto
kbatch
=
arg
.
a_grid_desc_kbatch_k0_m_k1_
.
GetLength
(
I0
)
;
const
auto
kbatch
=
k
arg
.
k_batch
;
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_kbatch_k0_m_k1_
,
arg
.
b_grid_desc_kbatch_k0_n_k1_
,
arg
.
c_grid_desc_m_n_
,
arg
.
block_2_ctile_map_
))
if
(
!
GridwiseGemm
::
CheckValidity
(
karg
))
{
throw
std
::
runtime_error
(
"wrong! GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 has invalid setting"
);
"wrong! GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 has invalid "
"setting"
);
}
const
index_t
grid_size
=
arg
.
block_2_ctile_map_
.
CalculateGridSize
(
arg
.
c_grid_desc_m_n_
);
const
auto
K0
=
arg
.
a_grid_desc_kbatch_k0_m_k1_
.
GetLength
(
I1
);
index_t
gdx
,
gdy
,
gdz
;
std
::
tie
(
gdx
,
gdy
,
gdz
)
=
GridwiseGemm
::
CalculateGridSize
(
karg
);
const
auto
K0
=
karg
.
K0
;
const
bool
has_main_k0_block_loop
=
GridwiseGemm
::
CalculateHasMainK0BlockLoop
(
K0
);
float
ave_time
=
0
;
const
auto
Run
=
[
&
](
const
auto
&
kernel
)
{
hipGetErrorString
(
hipMemset
(
arg
.
p_c_grid_
,
0
,
arg
.
c_grid_desc_mblock_mperblock_nblock_nperblock_
.
GetElementSpaceSize
()
*
sizeof
(
CDataType
)));
ave_time
=
launch_and_time_kernel
(
stream_config
,
kernel
,
dim3
(
grid_size
),
dim3
(
BlockSize
),
0
,
arg
.
p_a_grid_
,
arg
.
p_b_grid_
,
arg
.
p_c_grid_
,
arg
.
a_grid_desc_kbatch_k0_m_k1_
,
arg
.
b_grid_desc_kbatch_k0_n_k1_
,
arg
.
c_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
c_element_op_
,
arg
.
block_2_ctile_map_
);
if
(
kbatch
>
1
)
hipGetErrorString
(
hipMemset
(
karg
.
p_c_grid
,
0
,
karg
.
M
*
karg
.
N
*
sizeof
(
CDataType
)));
ave_time
=
launch_and_time_kernel
(
stream_config
,
kernel
,
dim3
(
gdx
,
gdy
,
gdz
),
dim3
(
BlockSize
),
0
,
karg
);
};
if
(
has_main_k0_block_loop
)
{
if
(
kbatch
==
1
)
{
const
auto
kernel
=
kernel_gemm_xdlops_v2r4r2
<
GridwiseGemm
,
ADataType
,
// TODO: distiguish A/B datatype
CDataType
,
remove_reference_t
<
DeviceGemmXdlSplitKCShuffle
::
AGridDesc_K0_M_K1
>
,
remove_reference_t
<
DeviceGemmXdlSplitKCShuffle
::
BGridDesc_K0_N_K1
>
,
remove_reference_t
<
DeviceGemmXdlSplitKCShuffle
::
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
>
,
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
,
remove_reference_t
<
DeviceGemmXdlSplitKCShuffle
::
Block2CTileMap
>
,
true
>
;
const
auto
kernel
=
kernel_gemm_xdlops_v2r4r2_simplified
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
Set
>
;
Run
(
kernel
);
}
else
{
const
auto
kernel
=
kernel_gemm_xdlops_v2r4r2
<
GridwiseGemmAtomicAdd
,
ADataType
,
// TODO: distiguish A/B datatype
CDataType
,
remove_reference_t
<
DeviceGemmXdlSplitKCShuffle
::
AGridDesc_K0_M_K1
>
,
remove_reference_t
<
DeviceGemmXdlSplitKCShuffle
::
BGridDesc_K0_N_K1
>
,
remove_reference_t
<
DeviceGemmXdlSplitKCShuffle
::
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
>
,
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
,
remove_reference_t
<
DeviceGemmXdlSplitKCShuffle
::
Block2CTileMap
>
,
true
>
;
const
auto
kernel
=
kernel_gemm_xdlops_v2r4r2_simplified
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
AtomicAdd
>
;
Run
(
kernel
);
}
...
...
@@ -491,37 +180,19 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
{
if
(
kbatch
==
1
)
{
const
auto
kernel
=
kernel_gemm_xdlops_v2r4r2
<
GridwiseGemm
,
ADataType
,
// TODO: distiguish A/B datatype
CDataType
,
remove_reference_t
<
DeviceGemmXdlSplitKCShuffle
::
AGridDesc_K0_M_K1
>
,
remove_reference_t
<
DeviceGemmXdlSplitKCShuffle
::
BGridDesc_K0_N_K1
>
,
remove_reference_t
<
DeviceGemmXdlSplitKCShuffle
::
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
>
,
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
,
remove_reference_t
<
DeviceGemmXdlSplitKCShuffle
::
Block2CTileMap
>
,
false
>
;
const
auto
kernel
=
kernel_gemm_xdlops_v2r4r2_simplified
<
GridwiseGemm
,
false
,
InMemoryDataOperationEnum
::
Set
>
;
Run
(
kernel
);
}
else
{
const
auto
kernel
=
kernel_gemm_xdlops_v2r4r2
<
GridwiseGemmAtomicAdd
,
ADataType
,
// TODO: distiguish A/B datatype
CDataType
,
remove_reference_t
<
DeviceGemmXdlSplitKCShuffle
::
AGridDesc_K0_M_K1
>
,
remove_reference_t
<
DeviceGemmXdlSplitKCShuffle
::
BGridDesc_K0_N_K1
>
,
remove_reference_t
<
DeviceGemmXdlSplitKCShuffle
::
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
>
,
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
,
remove_reference_t
<
DeviceGemmXdlSplitKCShuffle
::
Block2CTileMap
>
,
false
>
;
const
auto
kernel
=
kernel_gemm_xdlops_v2r4r2_simplified
<
GridwiseGemm
,
false
,
InMemoryDataOperationEnum
::
AtomicAdd
>
;
Run
(
kernel
);
}
...
...
@@ -544,12 +215,9 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
return
true
;
}
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
static
bool
IsSupportedArgument
(
const
Argument
&
k
arg
)
{
return
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_kbatch_k0_m_k1_
,
arg
.
b_grid_desc_kbatch_k0_n_k1_
,
arg
.
c_grid_desc_m_n_
,
arg
.
block_2_ctile_map_
);
return
GridwiseGemm
::
CheckValidity
(
karg
);
}
// polymorphic
...
...
@@ -567,9 +235,9 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
index_t
StrideA
,
index_t
StrideB
,
index_t
StrideC
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
,
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
,
index_t
KBatch
)
{
return
Argument
{
p_a
,
...
...
@@ -581,11 +249,10 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
StrideA
,
StrideB
,
StrideC
,
1
,
1
,
a_element_op
,
b_element_op
,
c_element_op
,
GridwiseGemm
::
CalculateMPadded
(
M
),
GridwiseGemm
::
CalculateNPadded
(
N
),
GridwiseGemm
::
CalculateKPadded
(
K
),
GridwiseGemm
::
CalculateK0
(
K
,
KBatch
),
KBatch
};
}
...
...
@@ -601,9 +268,9 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
index_t
StrideA
,
index_t
StrideB
,
index_t
StrideC
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
,
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
,
ck
::
index_t
KBatch
=
1
)
override
{
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
ADataType
*>
(
p_a
),
...
...
@@ -615,11 +282,10 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
StrideA
,
StrideB
,
StrideC
,
1
,
1
,
a_element_op
,
b_element_op
,
c_element_op
,
GridwiseGemm
::
CalculateMPadded
(
M
),
GridwiseGemm
::
CalculateNPadded
(
N
),
GridwiseGemm
::
CalculateKPadded
(
K
),
GridwiseGemm
::
CalculateK0
(
K
,
KBatch
),
KBatch
);
}
...
...
@@ -630,22 +296,7 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
}
// polymorphic
std
::
string
GetTypeString
()
const
override
{
auto
str
=
std
::
stringstream
();
// clang-format off
str
<<
"DeviceGemmXdlSplitKCShuffle"
<<
"<"
<<
BlockSize
<<
", "
<<
MPerBlock
<<
", "
<<
NPerBlock
<<
", "
<<
K0PerBlock
<<
">"
;
// clang-format on
return
str
.
str
();
}
std
::
string
GetTypeString
()
const
override
{
return
GridwiseGemm
::
GetTypeString
();
}
};
}
// namespace device
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp
View file @
e51806cc
...
...
@@ -1004,7 +1004,15 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
<<
KPerBlock
<<
", "
<<
AK1
<<
", "
<<
BK1
<<
", "
<<
getConvBackwardDataSpecializationString
(
ConvBackwardDataSpecialization
)
<<
getConvBackwardDataSpecializationString
(
ConvBackwardDataSpecialization
)
<<
", "
<<
MPerXDL
<<
", "
<<
NPerXDL
<<
", "
<<
MXdlPerWave
<<
", "
<<
NXdlPerWave
<<
", "
<<
ABlockTransferSrcScalarPerVector
<<
", "
<<
BBlockTransferSrcScalarPerVector
<<
", "
<<
CShuffleMXdlPerWavePerShuffle
<<
", "
<<
CShuffleNXdlPerWavePerShuffle
<<
">"
;
return
str
.
str
();
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_dl.hpp
View file @
e51806cc
...
...
@@ -1203,7 +1203,8 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
<<
MPerBlock
<<
", "
<<
NPerBlock
<<
", "
<<
K0PerBlock
<<
", "
<<
getConvBackwardWeightSpecializationString
(
ConvBackwardWeightSpecialization
)
<<
getConvBackwardWeightSpecializationString
(
ConvBackwardWeightSpecialization
)
<<
", "
<<
K1
<<
">"
;
// clang-format on
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_xdl_cshuffle.hpp
View file @
e51806cc
...
...
@@ -1232,7 +1232,17 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle
<<
MPerBlock
<<
", "
<<
NPerBlock
<<
", "
<<
K0PerBlock
<<
", "
<<
getConvBackwardWeightSpecializationString
(
ConvBackwardWeightSpecialization
)
<<
getConvBackwardWeightSpecializationString
(
ConvBackwardWeightSpecialization
)
<<
", "
<<
K1
<<
", "
<<
MXdlPerWave
<<
", "
<<
NXdlPerWave
<<
", "
<<
ABlockTransferSrcScalarPerVector
<<
", "
<<
ABlockTransferDstScalarPerVector_K1
<<
", "
<<
BBlockTransferSrcScalarPerVector
<<
", "
<<
BBlockTransferDstScalarPerVector_K1
<<
", "
<<
CShuffleMXdlPerWavePerShuffle
<<
", "
<<
CShuffleNXdlPerWavePerShuffle
<<
", "
<<
CBlockTransferScalarPerVector_NWaveNPerXdl
<<
">"
;
// clang-format on
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp
View file @
e51806cc
...
...
@@ -1093,7 +1093,15 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
<<
MPerBlock
<<
", "
<<
NPerBlock
<<
", "
<<
KPerBlock
<<
", "
<<
getConvForwardSpecializationString
(
ConvForwardSpecialization
)
<<
getConvForwardSpecializationString
(
ConvForwardSpecialization
)
<<
", "
<<
MPerXDL
<<
", "
<<
NPerXDL
<<
", "
<<
MXdlPerWave
<<
", "
<<
NXdlPerWave
<<
", "
<<
ABlockTransferSrcScalarPerVector
<<
", "
<<
BBlockTransferSrcScalarPerVector
<<
", "
<<
CShuffleMXdlPerWavePerShuffle
<<
", "
<<
CShuffleNXdlPerWavePerShuffle
<<
">"
;
// clang-format on
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp
View file @
e51806cc
...
...
@@ -579,7 +579,8 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
namespace
ctc
=
tensor_layout
::
convolution
;
// check device
if
(
get_device_name
()
==
"gfx1100"
)
if
(
get_device_name
()
==
"gfx1100"
||
get_device_name
()
==
"gfx1101"
||
ck
::
get_device_name
()
==
"gfx1102"
)
{
if
constexpr
(
!
(
is_same_v
<
AccDataType
,
float
>
||
is_same_v
<
AccDataType
,
int32_t
>
))
{
...
...
@@ -837,7 +838,10 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
<<
MPerBlock
<<
", "
<<
NPerBlock
<<
", "
<<
KPerBlock
<<
", "
<<
getConvForwardSpecializationString
(
ConvForwardSpecialization
)
<<
getConvForwardSpecializationString
(
ConvForwardSpecialization
)
<<
", "
<<
K1
<<
", "
<<
ABlockTransferSrcScalarPerVector
<<
", "
<<
BBlockTransferSrcScalarPerVector
<<
">"
;
// clang-format on
...
...
Prev
1
2
3
4
5
6
7
8
9
10
11
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment