gaoqiong / composable_kernel_ROCM · Commits

Commit b9eb4de3, authored Jul 23, 2024 by Jun Liu

    Merge branch 'develop' into amd-develop

Parents: be3fbf7f, d22713a7

Changes: 130 files in the commit; showing 20 changed files with 1982 additions and 1220 deletions (+1982, -1220).
Files shown (20 of 130):

  include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp  (+515, -996)
  library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp  (+1, -0)
  library/include/ck/library/tensor_operation_instance/gpu/gemm_ab_scale.hpp  (+226, -0)
  library/include/ck/library/tensor_operation_instance/gpu/gemm_multiply_multiply.hpp  (+225, -0)
  library/include/ck/library/tensor_operation_instance/gpu/gemm_universal.hpp  (+81, -3)
  library/include/ck/library/tensor_operation_instance/gpu/gemm_universal_reduce.hpp  (+457, -0)
  library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp  (+0, -96)
  library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp  (+0, -13)
  library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl_merged_groups.inc  (+0, -112)
  library/src/tensor_operation_instance/gpu/gemm_ab_scale/CMakeLists.txt  (+14, -0)
  library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128.hpp  (+85, -0)
  library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_default_instance.cpp  (+37, -0)
  library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_kpadding_instance.cpp  (+37, -0)
  library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_mnkpadding_instance.cpp  (+37, -0)
  library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_mnpadding_instance.cpp  (+37, -0)
  library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_default_instance.cpp  (+38, -0)
  library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_kpadding_instance.cpp  (+38, -0)
  library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_mnkpadding_instance.cpp  (+38, -0)
  library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/CMakeLists.txt  (+17, -0)
  library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn.hpp  (+99, -0)
include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp  (view file @ b9eb4de3)
@@ -14,22 +14,15 @@
 namespace ck {
 namespace tensor_operation {

-// function to be used on device, emulates std::accumulate
-template <typename T, typename ForwardIterator, typename Size>
-__host__ __device__ auto mult_accumulate_n(ForwardIterator first, Size count, T init)
-{
-    for(ForwardIterator x = first; x != first + count; x++)
-    {
-        init *= *x;
-    }
-    return init;
-}
-
 template <index_t NDimSpatial,
           device::ConvolutionForwardSpecialization ConvForwardSpecialization,
+          bool SplitN              = false,
+          typename ADataType       = float,
+          typename CDataType       = float,
           index_t NumGroupsToMerge = 1>
 struct TransformConvFwdToGemm
 {
     private:
     static constexpr auto I0 = Number<0>{};
     static constexpr auto I1 = Number<1>{};
     static constexpr auto I2 = Number<2>{};
@@ -37,10 +30,10 @@ struct TransformConvFwdToGemm
     static constexpr auto I4 = Number<4>{};
     static constexpr auto I5 = Number<5>{};

-    static long_index_t
-    calculate_element_space_size_impl(const std::array<index_t, NDimSpatial + 3>& lengths,
-                                      const std::array<index_t, NDimSpatial + 3>& strides,
-                                      index_t i)
+    template <typename ConvDimsType>
+    static long_index_t calculate_element_space_size_impl(const ConvDimsType& lengths,
+                                                          const ConvDimsType& strides,
+                                                          index_t i)
     {
         long_index_t acc = 1;
         for(; i < (NDimSpatial + 3); i++)
@@ -52,11 +45,11 @@ struct TransformConvFwdToGemm
         return acc;
     }

-    template <typename ADataType, typename CDataType>
-    static index_t GetSplitedNSize(const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
-                                   const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
-                                   const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
-                                   const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_strides)
+    template <typename ConvDimsType>
+    static index_t GetSplitedNSize(const ConvDimsType& a_g_n_c_wis_lengths,
+                                   const ConvDimsType& a_g_n_c_wis_strides,
+                                   const ConvDimsType& c_g_n_k_wos_lengths,
+                                   const ConvDimsType& c_g_n_k_wos_strides)
     {
         const long_index_t a_element_space_size =
             calculate_element_space_size_impl(a_g_n_c_wis_lengths, a_g_n_c_wis_strides, I1);
@@ -102,6 +95,216 @@ struct TransformConvFwdToGemm
         }
     }

     public:
+    __host__ __device__ constexpr TransformConvFwdToGemm() {}
+
+    template <typename ConvDimsType,
+              typename ConvSpatialDimsType,
+              index_t NDim                                   = NDimSpatial,
+              typename std::enable_if<NDim == 1, bool>::type = false>
+    __host__ __device__ TransformConvFwdToGemm(const ConvDimsType& a_g_n_c_wis_lengths,
+                                               const ConvDimsType& a_g_n_c_wis_strides,
+                                               const ConvDimsType& b_g_k_c_xs_lengths,
+                                               const ConvDimsType& b_g_k_c_xs_strides,
+                                               const ConvDimsType& c_g_n_k_wos_lengths,
+                                               const ConvDimsType& c_g_n_k_wos_strides,
+                                               const ConvSpatialDimsType& conv_filter_strides,
+                                               const ConvSpatialDimsType& conv_filter_dilations,
+                                               const ConvSpatialDimsType& input_left_pads,
+                                               const ConvSpatialDimsType& input_right_pads)
+        : Di_{I1}, Hi_{I1}, Wi_{a_g_n_c_wis_lengths[I3]},
+          Do_{I1}, Ho_{I1}, Wo_{c_g_n_k_wos_lengths[I3]},
+          Z_{I1}, Y_{I1}, X_{b_g_k_c_xs_lengths[I3]},
+          K_{c_g_n_k_wos_lengths[I2]}, C_{b_g_k_c_xs_lengths[I2]},
+          DiStride_{I1}, HiStride_{I1}, WiStride_{a_g_n_c_wis_strides[I3]},
+          WoStride_{c_g_n_k_wos_strides[I3]}, XStride_{b_g_k_c_xs_strides[I3]},
+          CStrideTensorA_{a_g_n_c_wis_strides[I2]}, CStrideTensorB_{b_g_k_c_xs_strides[I2]},
+          KStrideTensorB_{b_g_k_c_xs_strides[I1]}, KStrideTensorC_{c_g_n_k_wos_strides[I2]},
+          NStrideTensorA_{a_g_n_c_wis_strides[I1]},
+          GStrideTensorA_{a_g_n_c_wis_strides[I0]}, GStrideTensorB_{b_g_k_c_xs_strides[I0]},
+          GStrideTensorC_{c_g_n_k_wos_strides[I0]},
+          ConvStrideD_{I1}, ConvStrideH_{I1}, ConvStrideW_{conv_filter_strides[I0]},
+          ConvDilationD_{I1}, ConvDilationH_{I1}, ConvDilationW_{conv_filter_dilations[I0]},
+          InLeftPadD_{I0}, InLeftPadH_{I0}, InLeftPadW_{input_left_pads[I0]},
+          InRightPadD_{I0}, InRightPadH_{I0}, InRightPadW_{input_right_pads[I0]},
+          ZYX_{X_}
+    {
+        static_assert(is_same_v<ConvSpatialDimsType, std::array<index_t, NDimSpatial>> ||
+                      is_same_v<ConvSpatialDimsType, ck::Array<index_t, NDimSpatial>>);
+        static_assert(is_same_v<ConvDimsType, std::array<index_t, NDimSpatial + I3>> ||
+                      is_same_v<ConvDimsType, ck::Array<index_t, NDimSpatial + I3>>);
+
+        if constexpr(SplitN)
+        {
+            N_ = GetSplitedNSize(
+                a_g_n_c_wis_lengths, a_g_n_c_wis_strides, c_g_n_k_wos_lengths, c_g_n_k_wos_strides);
+        }
+        else
+        {
+            N_ = c_g_n_k_wos_lengths[I1];
+        }
+        NDoHoWo_ = N_ * Wo_;
+    }
+
+    template <typename ConvDimsType,
+              typename ConvSpatialDimsType,
+              index_t NDim                                   = NDimSpatial,
+              typename std::enable_if<NDim == 2, bool>::type = false>
+    __host__ __device__ TransformConvFwdToGemm(const ConvDimsType& a_g_n_c_wis_lengths,
+                                               const ConvDimsType& a_g_n_c_wis_strides,
+                                               const ConvDimsType& b_g_k_c_xs_lengths,
+                                               const ConvDimsType& b_g_k_c_xs_strides,
+                                               const ConvDimsType& c_g_n_k_wos_lengths,
+                                               const ConvDimsType& c_g_n_k_wos_strides,
+                                               const ConvSpatialDimsType& conv_filter_strides,
+                                               const ConvSpatialDimsType& conv_filter_dilations,
+                                               const ConvSpatialDimsType& input_left_pads,
+                                               const ConvSpatialDimsType& input_right_pads)
+        : Di_{I1}, Hi_{a_g_n_c_wis_lengths[I3]}, Wi_{a_g_n_c_wis_lengths[I4]},
+          Do_{I1}, Ho_{c_g_n_k_wos_lengths[I3]}, Wo_{c_g_n_k_wos_lengths[I4]},
+          Z_{I1}, Y_{b_g_k_c_xs_lengths[I3]}, X_{b_g_k_c_xs_lengths[I4]},
+          K_{c_g_n_k_wos_lengths[I2]}, C_{b_g_k_c_xs_lengths[I2]},
+          DiStride_{I1}, HiStride_{a_g_n_c_wis_strides[I3]}, WiStride_{a_g_n_c_wis_strides[I4]},
+          WoStride_{c_g_n_k_wos_strides[I4]}, XStride_{b_g_k_c_xs_strides[I4]},
+          CStrideTensorA_{a_g_n_c_wis_strides[I2]}, CStrideTensorB_{b_g_k_c_xs_strides[I2]},
+          KStrideTensorB_{b_g_k_c_xs_strides[I1]}, KStrideTensorC_{c_g_n_k_wos_strides[I2]},
+          NStrideTensorA_{a_g_n_c_wis_strides[I1]},
+          GStrideTensorA_{a_g_n_c_wis_strides[I0]}, GStrideTensorB_{b_g_k_c_xs_strides[I0]},
+          GStrideTensorC_{c_g_n_k_wos_strides[I0]},
+          ConvStrideD_{I1}, ConvStrideH_{conv_filter_strides[I0]}, ConvStrideW_{conv_filter_strides[I1]},
+          ConvDilationD_{I1}, ConvDilationH_{conv_filter_dilations[I0]},
+          ConvDilationW_{conv_filter_dilations[I1]},
+          InLeftPadD_{I0}, InLeftPadH_{input_left_pads[I0]}, InLeftPadW_{input_left_pads[I1]},
+          InRightPadD_{I0}, InRightPadH_{input_right_pads[I0]}, InRightPadW_{input_right_pads[I1]},
+          ZYX_{Y_ * X_}
+    {
+        static_assert(is_same_v<ConvSpatialDimsType, std::array<index_t, NDimSpatial>> ||
+                      is_same_v<ConvSpatialDimsType, ck::Array<index_t, NDimSpatial>>);
+        static_assert(is_same_v<ConvDimsType, std::array<index_t, NDimSpatial + I3>> ||
+                      is_same_v<ConvDimsType, ck::Array<index_t, NDimSpatial + I3>>);
+
+        if constexpr(SplitN)
+        {
+            N_ = GetSplitedNSize(
+                a_g_n_c_wis_lengths, a_g_n_c_wis_strides, c_g_n_k_wos_lengths, c_g_n_k_wos_strides);
+        }
+        else
+        {
+            N_ = c_g_n_k_wos_lengths[I1];
+        }
+        NDoHoWo_ = N_ * Ho_ * Wo_;
+    }
+
+    template <typename ConvDimsType,
+              typename ConvSpatialDimsType,
+              index_t NDim                                   = NDimSpatial,
+              typename std::enable_if<NDim == 3, bool>::type = false>
+    __host__ __device__ TransformConvFwdToGemm(const ConvDimsType& a_g_n_c_wis_lengths,
+                                               const ConvDimsType& a_g_n_c_wis_strides,
+                                               const ConvDimsType& b_g_k_c_xs_lengths,
+                                               const ConvDimsType& b_g_k_c_xs_strides,
+                                               const ConvDimsType& c_g_n_k_wos_lengths,
+                                               const ConvDimsType& c_g_n_k_wos_strides,
+                                               const ConvSpatialDimsType& conv_filter_strides,
+                                               const ConvSpatialDimsType& conv_filter_dilations,
+                                               const ConvSpatialDimsType& input_left_pads,
+                                               const ConvSpatialDimsType& input_right_pads)
+        : Di_{a_g_n_c_wis_lengths[I3]}, Hi_{a_g_n_c_wis_lengths[I4]}, Wi_{a_g_n_c_wis_lengths[I5]},
+          Do_{c_g_n_k_wos_lengths[I3]}, Ho_{c_g_n_k_wos_lengths[I4]}, Wo_{c_g_n_k_wos_lengths[I5]},
+          Z_{b_g_k_c_xs_lengths[I3]}, Y_{b_g_k_c_xs_lengths[I4]}, X_{b_g_k_c_xs_lengths[I5]},
+          K_{c_g_n_k_wos_lengths[I2]}, C_{b_g_k_c_xs_lengths[I2]},
+          DiStride_{a_g_n_c_wis_strides[I3]}, HiStride_{a_g_n_c_wis_strides[I4]},
+          WiStride_{a_g_n_c_wis_strides[I5]},
+          WoStride_{c_g_n_k_wos_strides[I5]}, XStride_{b_g_k_c_xs_strides[I5]},
+          CStrideTensorA_{a_g_n_c_wis_strides[I2]}, CStrideTensorB_{b_g_k_c_xs_strides[I2]},
+          KStrideTensorB_{b_g_k_c_xs_strides[I1]}, KStrideTensorC_{c_g_n_k_wos_strides[I2]},
+          NStrideTensorA_{a_g_n_c_wis_strides[I1]},
+          GStrideTensorA_{a_g_n_c_wis_strides[I0]}, GStrideTensorB_{b_g_k_c_xs_strides[I0]},
+          GStrideTensorC_{c_g_n_k_wos_strides[I0]},
+          ConvStrideD_{conv_filter_strides[I0]}, ConvStrideH_{conv_filter_strides[I1]},
+          ConvStrideW_{conv_filter_strides[I2]},
+          ConvDilationD_{conv_filter_dilations[I0]}, ConvDilationH_{conv_filter_dilations[I1]},
+          ConvDilationW_{conv_filter_dilations[I2]},
+          InLeftPadD_{input_left_pads[I0]}, InLeftPadH_{input_left_pads[I1]},
+          InLeftPadW_{input_left_pads[I2]},
+          InRightPadD_{input_right_pads[I0]}, InRightPadH_{input_right_pads[I1]},
+          InRightPadW_{input_right_pads[I2]},
+          ZYX_{Z_ * Y_ * X_}
+    {
+        static_assert(is_same_v<ConvSpatialDimsType, std::array<index_t, NDimSpatial>> ||
+                      is_same_v<ConvSpatialDimsType, ck::Array<index_t, NDimSpatial>>);
+        static_assert(is_same_v<ConvDimsType, std::array<index_t, NDimSpatial + I3>> ||
+                      is_same_v<ConvDimsType, ck::Array<index_t, NDimSpatial + I3>>);
+
+        if constexpr(SplitN)
+        {
+            N_ = GetSplitedNSize(
+                a_g_n_c_wis_lengths, a_g_n_c_wis_strides, c_g_n_k_wos_lengths, c_g_n_k_wos_strides);
+        }
+        else
+        {
+            N_ = c_g_n_k_wos_lengths[I1];
+        }
+        NDoHoWo_ = N_ * Do_ * Ho_ * Wo_;
+    }
+
     // TODO: implement ck::tensor_layout::convolution that describe packed/strided dimemsion as
     // properties
     template <typename ALayout,
@@ -110,53 +313,26 @@ struct TransformConvFwdToGemm
                   is_same_v<ALayout, tensor_layout::convolution::NWGC> ||
                   is_same_v<ALayout, tensor_layout::convolution::GNWC>),
                  bool>::type = false>
-    static auto
-    MakeADescriptor_M_K(const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
-                        const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
-                        const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
-                        const std::array<index_t, NDimSpatial + 3>& /* b_g_k_c_xs_strides */,
-                        const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
-                        const std::array<index_t, NDimSpatial + 3>& /* c_g_n_k_wos_strides */,
-                        const std::array<index_t, NDimSpatial>& conv_filter_strides,
-                        const std::array<index_t, NDimSpatial>& conv_filter_dilations,
-                        const std::array<index_t, NDimSpatial>& input_left_pads,
-                        const std::array<index_t, NDimSpatial>& input_right_pads,
-                        const index_t N)
+    __host__ __device__ auto MakeADescriptor_M_K() const
     {
-        const index_t C           = a_g_n_c_wis_lengths[I2];
-        const index_t Wi          = a_g_n_c_wis_lengths[I3];
-        const index_t Wo          = c_g_n_k_wos_lengths[I3];
-        const index_t ConvStrideW = conv_filter_strides[I0];
-        const index_t GStride     = a_g_n_c_wis_strides[I0];
-        const index_t NStride     = a_g_n_c_wis_strides[I1];
-        const auto CStride        = a_g_n_c_wis_strides[I2];
-        const index_t WiStride    = a_g_n_c_wis_strides[I3];
-
         if constexpr(ConvForwardSpecialization ==
                      device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
         {
-            const index_t NHoWo =
-                N * ck::accumulate_n<index_t>(
-                        c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>());
             if constexpr(NumGroupsToMerge == 1)
             {
-                return make_naive_tensor_descriptor(make_tuple(NHoWo, C),
-                                                    make_tuple(WiStride, CStride));
+                return make_naive_tensor_descriptor(make_tuple(NDoHoWo_, C_),
+                                                    make_tuple(WiStride_, CStrideTensorA_));
             }
             else
             {
                 const auto in_gemmm_groups_gemmk_desc = make_naive_tensor_descriptor(
-                    make_tuple(NHoWo, NumGroupsToMerge, C), make_tuple(WiStride, GStride, CStride));
+                    make_tuple(NDoHoWo_, NumGroupsToMerge, C_),
+                    make_tuple(WiStride_, GStrideTensorA_, CStrideTensorA_));
                 return transform_tensor_descriptor(
                     in_gemmm_groups_gemmk_desc,
-                    make_tuple(make_merge_transform(make_tuple(NHoWo, NumGroupsToMerge)),
-                               make_pass_through_transform(C)),
+                    make_tuple(make_merge_transform(make_tuple(NDoHoWo_, NumGroupsToMerge)),
+                               make_pass_through_transform(C_)),
                     make_tuple(Sequence<0, 1>{}, Sequence<2>{}),
                     make_tuple(Sequence<0>{}, Sequence<1>{}));
             }
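For orientation, the hunks above change how callers obtain the A descriptor: the transform is now built once from the length/stride arrays (via the constructors added earlier in this file) and MakeADescriptor_M_K() becomes an argument-free const member instead of a static function that received every array plus N. The sketch below illustrates that call pattern; it is not taken from this commit, it assumes the translation unit is compiled as HIP (the __host__ __device__ qualifiers require it), and the ConvolutionForwardSpecialization::Default value, the NHWGC layout tag, and the concrete shapes and strides are illustrative assumptions only.

    // Hypothetical caller-side sketch (2D forward convolution, single group), not part of this commit.
    #include <array>
    #include "ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp"

    using ck::index_t;

    void make_a_desc_example()
    {
        constexpr index_t NDimSpatial = 2;

        // Dimension arrays are ordered G, N, C, <spatial...>, as the constructors above expect.
        const std::array<index_t, NDimSpatial + 3> a_lengths{1, 4, 64, 28, 28};          // input  G,N,C,Hi,Wi
        const std::array<index_t, NDimSpatial + 3> a_strides{64, 50176, 1, 1792, 64};    // packed NHWGC
        const std::array<index_t, NDimSpatial + 3> b_lengths{1, 256, 64, 3, 3};          // weight G,K,C,Y,X
        const std::array<index_t, NDimSpatial + 3> b_strides{147456, 576, 1, 192, 64};   // packed GKYXC
        const std::array<index_t, NDimSpatial + 3> c_lengths{1, 4, 256, 28, 28};         // output G,N,K,Ho,Wo
        const std::array<index_t, NDimSpatial + 3> c_strides{256, 200704, 1, 7168, 256}; // packed NHWGK
        const std::array<index_t, NDimSpatial> filter_strides{1, 1}, filter_dilations{1, 1};
        const std::array<index_t, NDimSpatial> left_pads{1, 1}, right_pads{1, 1};

        // Old API: static MakeADescriptor_M_K(a_lengths, a_strides, ..., right_pads, N).
        // New API: construct the transform once, then query descriptors with no arguments.
        const ck::tensor_operation::TransformConvFwdToGemm<
            NDimSpatial,
            ck::tensor_operation::device::ConvolutionForwardSpecialization::Default>
            transform{a_lengths, a_strides, b_lengths, b_strides, c_lengths, c_strides,
                      filter_strides, filter_dilations, left_pads, right_pads};

        const auto a_grid_desc_m_k =
            transform.MakeADescriptor_M_K<ck::tensor_layout::convolution::NHWGC>();
        (void)a_grid_desc_m_k;
    }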
@@ -164,35 +340,30 @@ struct TransformConvFwdToGemm
         else if constexpr(ConvForwardSpecialization ==
                           device::ConvolutionForwardSpecialization::Filter3x3)
         {
-            const index_t ConvDilationW = conv_filter_dilations[0];
-            const index_t InLeftPadW    = input_left_pads[0];
-            const index_t InRightPadW   = input_right_pads[0];
-
             if constexpr(NumGroupsToMerge == 1)
             {
                 const auto in_n_wi_c_desc = make_naive_tensor_descriptor(
-                    make_tuple(N, Wi), make_tuple(NStride, WiStride));
+                    make_tuple(N_, Wi_), make_tuple(NStrideTensorA_, WiStride_));
                 const auto in_n_wip_c_desc = transform_tensor_descriptor(
                     in_n_wi_c_desc,
-                    make_tuple(make_pass_through_transform(N),
-                               make_pad_transform(Wi, InLeftPadW, InRightPadW)),
+                    make_tuple(make_pass_through_transform(N_),
+                               make_pad_transform(Wi_, InLeftPadW_, InRightPadW_)),
                     make_tuple(Sequence<0>{}, Sequence<1>{}),
                     make_tuple(Sequence<0>{}, Sequence<1>{}));
                 const auto in_n_x_wo_c_desc = transform_tensor_descriptor(
                     in_n_wip_c_desc,
-                    make_tuple(make_pass_through_transform(N),
-                               make_embed_transform(make_tuple(Number<3>{}, Wo),
-                                                    make_tuple(ConvDilationW, ConvStrideW))),
+                    make_tuple(make_pass_through_transform(N_),
+                               make_embed_transform(make_tuple(Number<3>{}, Wo_),
                                                    make_tuple(ConvDilationW_, ConvStrideW_))),
                     make_tuple(Sequence<0>{}, Sequence<1>{}),
                     make_tuple(Sequence<0>{}, Sequence<1, 2>{}));
                 return transform_tensor_descriptor(
                     in_n_x_wo_c_desc,
-                    make_tuple(make_merge_transform(make_tuple(N, Wo)),
+                    make_tuple(make_merge_transform(make_tuple(N_, Wo_)),
                                make_pass_through_transform(Number<3>{})),
                     make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
                     make_tuple(Sequence<0>{}, Sequence<1>{}));
@@ -200,28 +371,29 @@ struct TransformConvFwdToGemm
             else
             {
                 const auto in_n_wi_c_desc = make_naive_tensor_descriptor(
-                    make_tuple(N, Wi, NumGroupsToMerge), make_tuple(NStride, WiStride, GStride));
+                    make_tuple(N_, Wi_, NumGroupsToMerge),
+                    make_tuple(NStrideTensorA_, WiStride_, GStrideTensorA_));
                 const auto in_n_wip_c_desc = transform_tensor_descriptor(
                     in_n_wi_c_desc,
-                    make_tuple(make_pass_through_transform(N),
-                               make_pad_transform(Wi, InLeftPadW, InRightPadW),
+                    make_tuple(make_pass_through_transform(N_),
+                               make_pad_transform(Wi_, InLeftPadW_, InRightPadW_),
                                make_pass_through_transform(NumGroupsToMerge)),
                     make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
                     make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
                 const auto in_n_x_wo_c_desc = transform_tensor_descriptor(
                     in_n_wip_c_desc,
-                    make_tuple(make_pass_through_transform(N),
-                               make_embed_transform(make_tuple(Number<3>{}, Wo),
-                                                    make_tuple(ConvDilationW, ConvStrideW)),
+                    make_tuple(make_pass_through_transform(N_),
+                               make_embed_transform(make_tuple(Number<3>{}, Wo_),
+                                                    make_tuple(ConvDilationW_, ConvStrideW_)),
                                make_pass_through_transform(NumGroupsToMerge)),
                     make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
                     make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
                 return transform_tensor_descriptor(
                     in_n_x_wo_c_desc,
-                    make_tuple(make_merge_transform(make_tuple(N, Wo, NumGroupsToMerge)),
+                    make_tuple(make_merge_transform(make_tuple(N_, Wo_, NumGroupsToMerge)),
                                make_pass_through_transform(Number<3>{})),
                     make_tuple(Sequence<0, 2, 3>{}, Sequence<1>{}),
                     make_tuple(Sequence<0>{}, Sequence<1>{}));
@@ -233,110 +405,108 @@ struct TransformConvFwdToGemm
             if constexpr(NumGroupsToMerge == 1)
             {
                 const auto in_n_wi_c_desc = make_naive_tensor_descriptor(
-                    make_tuple(N, Wi, C), make_tuple(NStride, WiStride, CStride));
+                    make_tuple(N_, Wi_, C_), make_tuple(NStrideTensorA_, WiStride_, CStrideTensorA_));
                 const auto in_n_wo_c_desc = transform_tensor_descriptor(
                     in_n_wi_c_desc,
-                    make_tuple(make_pass_through_transform(N),
-                               make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
-                               make_pass_through_transform(C)),
+                    make_tuple(make_pass_through_transform(N_),
+                               make_embed_transform(make_tuple(Wo_), make_tuple(ConvStrideW_)),
+                               make_pass_through_transform(C_)),
                     make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
                     make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
                 return transform_tensor_descriptor(
                     in_n_wo_c_desc,
-                    make_tuple(make_merge_transform(make_tuple(N, Wo)),
-                               make_pass_through_transform(C)),
+                    make_tuple(make_merge_transform(make_tuple(N_, Wo_)),
+                               make_pass_through_transform(C_)),
                     make_tuple(Sequence<0, 1>{}, Sequence<2>{}),
                     make_tuple(Sequence<0>{}, Sequence<1>{}));
             }
             else
             {
-                const auto in_n_wi_c_desc = make_naive_tensor_descriptor(
-                    make_tuple(N, Wi, NumGroupsToMerge, C),
-                    make_tuple(NStride, WiStride, GStride, CStride));
+                const auto in_n_wi_c_desc = make_naive_tensor_descriptor(
+                    make_tuple(N_, Wi_, NumGroupsToMerge, C_),
+                    make_tuple(NStrideTensorA_, WiStride_, GStrideTensorA_, CStrideTensorA_));
                 const auto in_n_wo_c_desc = transform_tensor_descriptor(
                     in_n_wi_c_desc,
-                    make_tuple(make_pass_through_transform(N),
-                               make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
+                    make_tuple(make_pass_through_transform(N_),
+                               make_embed_transform(make_tuple(Wo_), make_tuple(ConvStrideW_)),
                                make_pass_through_transform(NumGroupsToMerge),
-                               make_pass_through_transform(C)),
+                               make_pass_through_transform(C_)),
                     make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
                     make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
                 return transform_tensor_descriptor(
                     in_n_wo_c_desc,
-                    make_tuple(make_merge_transform(make_tuple(N, Wo, NumGroupsToMerge)),
-                               make_pass_through_transform(C)),
+                    make_tuple(make_merge_transform(make_tuple(N_, Wo_, NumGroupsToMerge)),
+                               make_pass_through_transform(C_)),
                     make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}),
                     make_tuple(Sequence<0>{}, Sequence<1>{}));
             }
         }
         else
         {
-            const index_t X             = b_g_k_c_xs_lengths[3];
-            const index_t ConvDilationW = conv_filter_dilations[0];
-            const index_t InLeftPadW    = input_left_pads[0];
-            const index_t InRightPadW   = input_right_pads[0];
-
             if constexpr(NumGroupsToMerge == 1)
             {
                 const auto in_n_wi_c_desc = make_naive_tensor_descriptor(
-                    make_tuple(N, Wi, C), make_tuple(NStride, WiStride, CStride));
+                    make_tuple(N_, Wi_, C_), make_tuple(NStrideTensorA_, WiStride_, CStrideTensorA_));
                 const auto in_n_wip_c_desc = transform_tensor_descriptor(
                     in_n_wi_c_desc,
-                    make_tuple(make_pass_through_transform(N),
-                               make_pad_transform(Wi, InLeftPadW, InRightPadW),
-                               make_pass_through_transform(C)),
+                    make_tuple(make_pass_through_transform(N_),
+                               make_pad_transform(Wi_, InLeftPadW_, InRightPadW_),
+                               make_pass_through_transform(C_)),
                     make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
                     make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
                 const auto in_n_x_wo_c_desc = transform_tensor_descriptor(
                     in_n_wip_c_desc,
-                    make_tuple(make_pass_through_transform(N),
-                               make_embed_transform(make_tuple(X, Wo),
-                                                    make_tuple(ConvDilationW, ConvStrideW)),
-                               make_pass_through_transform(C)),
+                    make_tuple(make_pass_through_transform(N_),
+                               make_embed_transform(make_tuple(X_, Wo_),
                                                    make_tuple(ConvDilationW_, ConvStrideW_)),
+                               make_pass_through_transform(C_)),
                     make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
                     make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
                 return transform_tensor_descriptor(
                     in_n_x_wo_c_desc,
-                    make_tuple(make_merge_transform(make_tuple(N, Wo)),
-                               make_merge_transform(make_tuple(X, C))),
+                    make_tuple(make_merge_transform(make_tuple(N_, Wo_)),
+                               make_merge_transform(make_tuple(X_, C_))),
                     make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{}),
                     make_tuple(Sequence<0>{}, Sequence<1>{}));
             }
             else
             {
-                const auto in_n_wi_c_desc = make_naive_tensor_descriptor(
-                    make_tuple(N, Wi, NumGroupsToMerge, C),
-                    make_tuple(NStride, WiStride, GStride, CStride));
+                const auto in_n_wi_c_desc = make_naive_tensor_descriptor(
+                    make_tuple(N_, Wi_, NumGroupsToMerge, C_),
+                    make_tuple(NStrideTensorA_, WiStride_, GStrideTensorA_, CStrideTensorA_));
                 const auto in_n_wip_c_desc = transform_tensor_descriptor(
                     in_n_wi_c_desc,
-                    make_tuple(make_pass_through_transform(N),
-                               make_pad_transform(Wi, InLeftPadW, InRightPadW),
+                    make_tuple(make_pass_through_transform(N_),
+                               make_pad_transform(Wi_, InLeftPadW_, InRightPadW_),
                                make_pass_through_transform(NumGroupsToMerge),
-                               make_pass_through_transform(C)),
+                               make_pass_through_transform(C_)),
                     make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
                     make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
                 const auto in_n_x_wo_c_desc = transform_tensor_descriptor(
                     in_n_wip_c_desc,
-                    make_tuple(make_pass_through_transform(N),
-                               make_embed_transform(make_tuple(X, Wo),
-                                                    make_tuple(ConvDilationW, ConvStrideW)),
+                    make_tuple(make_pass_through_transform(N_),
+                               make_embed_transform(make_tuple(X_, Wo_),
                                                    make_tuple(ConvDilationW_, ConvStrideW_)),
                                make_pass_through_transform(NumGroupsToMerge),
-                               make_pass_through_transform(C)),
+                               make_pass_through_transform(C_)),
                     make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
                     make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}, Sequence<4>{}));
                 return transform_tensor_descriptor(
                     in_n_x_wo_c_desc,
-                    make_tuple(make_merge_transform(make_tuple(N, Wo, NumGroupsToMerge)),
-                               make_merge_transform(make_tuple(X, C))),
+                    make_tuple(make_merge_transform(make_tuple(N_, Wo_, NumGroupsToMerge)),
+                               make_merge_transform(make_tuple(X_, C_))),
                     make_tuple(Sequence<0, 2, 3>{}, Sequence<1, 4>{}),
                     make_tuple(Sequence<0>{}, Sequence<1>{}));
             }
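As an aside on what the merge transforms in this hunk produce (an illustrative note, not code from the diff): for the default 1D path, the returned A descriptor is a two-dimensional (GemmM, GemmK) view in which GemmM merges the output positions (N_, Wo_) — together with NumGroupsToMerge when groups are folded in — and GemmK merges the filter taps and input channels (X_, C_). A quick stand-alone check of those sizes, using made-up example numbers:

    #include <cstdint>
    #include <iostream>

    int main()
    {
        // Example 1D convolution sizes only; any consistent shape works here.
        const std::int64_t N = 4, Wo = 28, X = 3, C = 64;
        const std::int64_t gemm_m = N * Wo; // merge(N_, Wo_)  -> rows of the implicit GEMM
        const std::int64_t gemm_k = X * C;  // merge(X_, C_)   -> reduction dimension
        std::cout << "GemmM=" << gemm_m << " GemmK=" << gemm_k << "\n"; // GemmM=112 GemmK=192
    }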
@@ -349,57 +519,27 @@ struct TransformConvFwdToGemm
                   is_same_v<ALayout, tensor_layout::convolution::NHWGC> ||
                   is_same_v<ALayout, tensor_layout::convolution::GNHWC>),
                  bool>::type = false>
-    static auto
-    MakeADescriptor_M_K(const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
-                        const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
-                        const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
-                        const std::array<index_t, NDimSpatial + 3>& /* b_g_k_c_xs_strides */,
-                        const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
-                        const std::array<index_t, NDimSpatial + 3>& /* c_g_n_k_wos_strides */,
-                        const std::array<index_t, NDimSpatial>& conv_filter_strides,
-                        const std::array<index_t, NDimSpatial>& conv_filter_dilations,
-                        const std::array<index_t, NDimSpatial>& input_left_pads,
-                        const std::array<index_t, NDimSpatial>& input_right_pads,
-                        const index_t N)
+    __host__ __device__ auto MakeADescriptor_M_K() const
     {
-        const index_t C           = a_g_n_c_wis_lengths[2];
-        const index_t Hi          = a_g_n_c_wis_lengths[3];
-        const index_t Wi          = a_g_n_c_wis_lengths[4];
-        const index_t Ho          = c_g_n_k_wos_lengths[3];
-        const index_t Wo          = c_g_n_k_wos_lengths[4];
-        const index_t ConvStrideH = conv_filter_strides[0];
-        const index_t ConvStrideW = conv_filter_strides[1];
-        const index_t GStride     = a_g_n_c_wis_strides[I0];
-        const index_t NStride     = a_g_n_c_wis_strides[I1];
-        const index_t CStride     = a_g_n_c_wis_strides[I2];
-        const index_t HiStride    = a_g_n_c_wis_strides[I3];
-        const index_t WiStride    = a_g_n_c_wis_strides[I4];
-
         if constexpr(ConvForwardSpecialization ==
                      device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
         {
-            const index_t NHoWo =
-                N * ck::accumulate_n<index_t>(
-                        c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>());
             if constexpr(NumGroupsToMerge == 1)
            {
-                return make_naive_tensor_descriptor(make_tuple(NHoWo, C),
-                                                    make_tuple(WiStride, CStride));
+                return make_naive_tensor_descriptor(make_tuple(NDoHoWo_, C_),
+                                                    make_tuple(WiStride_, CStrideTensorA_));
             }
             else
             {
                 const auto in_gemmm_groups_gemmk_desc = make_naive_tensor_descriptor(
-                    make_tuple(NHoWo, NumGroupsToMerge, C), make_tuple(WiStride, GStride, CStride));
+                    make_tuple(NDoHoWo_, NumGroupsToMerge, C_),
+                    make_tuple(WiStride_, GStrideTensorA_, CStrideTensorA_));
                 return transform_tensor_descriptor(
                     in_gemmm_groups_gemmk_desc,
-                    make_tuple(make_merge_transform(make_tuple(NHoWo, NumGroupsToMerge)),
-                               make_pass_through_transform(C)),
+                    make_tuple(make_merge_transform(make_tuple(NDoHoWo_, NumGroupsToMerge)),
+                               make_pass_through_transform(C_)),
                     make_tuple(Sequence<0, 1>{}, Sequence<2>{}),
                     make_tuple(Sequence<0>{}, Sequence<1>{}));
             }
@@ -407,73 +547,65 @@ struct TransformConvFwdToGemm
         else if constexpr(ConvForwardSpecialization ==
                           device::ConvolutionForwardSpecialization::Filter3x3)
         {
-            const index_t ConvDilationH = conv_filter_dilations[0];
-            const index_t ConvDilationW = conv_filter_dilations[1];
-            const index_t InLeftPadH    = input_left_pads[0];
-            const index_t InLeftPadW    = input_left_pads[1];
-            const index_t InRightPadH   = input_right_pads[0];
-            const index_t InRightPadW   = input_right_pads[1];
-
             if constexpr(NumGroupsToMerge == 1)
             {
                 const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor(
-                    make_tuple(N, Hi, Wi), make_tuple(NStride, HiStride, WiStride));
+                    make_tuple(N_, Hi_, Wi_), make_tuple(NStrideTensorA_, HiStride_, WiStride_));
                 const auto in_n_hip_wip_c_desc = transform_tensor_descriptor(
                     in_n_hi_wi_c_desc,
-                    make_tuple(make_pass_through_transform(N),
-                               make_pad_transform(Hi, InLeftPadH, InRightPadH),
-                               make_pad_transform(Wi, InLeftPadW, InRightPadW)),
+                    make_tuple(make_pass_through_transform(N_),
+                               make_pad_transform(Hi_, InLeftPadH_, InRightPadH_),
+                               make_pad_transform(Wi_, InLeftPadW_, InRightPadW_)),
                     make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
                     make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
                 const auto in_n_y_ho_x_wo_c_desc = transform_tensor_descriptor(
                     in_n_hip_wip_c_desc,
-                    make_tuple(make_pass_through_transform(N),
-                               make_embed_transform(make_tuple(Number<3>{}, Ho),
-                                                    make_tuple(ConvDilationH, ConvStrideH)),
-                               make_embed_transform(make_tuple(Number<3>{}, Wo),
-                                                    make_tuple(ConvDilationW, ConvStrideW))),
+                    make_tuple(make_pass_through_transform(N_),
+                               make_embed_transform(make_tuple(Number<3>{}, Ho_),
                                                    make_tuple(ConvDilationH_, ConvStrideH_)),
+                               make_embed_transform(make_tuple(Number<3>{}, Wo_),
                                                    make_tuple(ConvDilationW_, ConvStrideW_))),
                     make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
                     make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}));
                 return transform_tensor_descriptor(
                     in_n_y_ho_x_wo_c_desc,
-                    make_tuple(make_merge_transform(make_tuple(N, Ho, Wo)),
+                    make_tuple(make_merge_transform(make_tuple(N_, Ho_, Wo_)),
                                make_merge_transform(make_tuple(Number<3>{}, Number<3>{}))),
                     make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3>{}),
                     make_tuple(Sequence<0>{}, Sequence<1>{}));
             }
             else
             {
                 const auto in_n_hi_wi_groups_c_desc = make_naive_tensor_descriptor(
-                    make_tuple(N, Hi, Wi, NumGroupsToMerge),
-                    make_tuple(NStride, HiStride, WiStride, GStride));
+                    make_tuple(N_, Hi_, Wi_, NumGroupsToMerge),
+                    make_tuple(NStrideTensorA_, HiStride_, WiStride_, GStrideTensorA_));
                 const auto in_n_hip_wip_groups_c_desc = transform_tensor_descriptor(
                     in_n_hi_wi_groups_c_desc,
-                    make_tuple(make_pass_through_transform(N),
-                               make_pad_transform(Hi, InLeftPadH, InRightPadH),
-                               make_pad_transform(Wi, InLeftPadW, InRightPadW),
+                    make_tuple(make_pass_through_transform(N_),
+                               make_pad_transform(Hi_, InLeftPadH_, InRightPadH_),
+                               make_pad_transform(Wi_, InLeftPadW_, InRightPadW_),
                                make_pass_through_transform(NumGroupsToMerge)),
                     make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
                     make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
                 const auto in_n_y_ho_x_wo_groups_c_desc = transform_tensor_descriptor(
                     in_n_hip_wip_groups_c_desc,
-                    make_tuple(make_pass_through_transform(N),
-                               make_embed_transform(make_tuple(Number<3>{}, Ho),
-                                                    make_tuple(ConvDilationH, ConvStrideH)),
-                               make_embed_transform(make_tuple(Number<3>{}, Wo),
-                                                    make_tuple(ConvDilationW, ConvStrideW)),
+                    make_tuple(make_pass_through_transform(N_),
+                               make_embed_transform(make_tuple(Number<3>{}, Ho_),
                                                    make_tuple(ConvDilationH_, ConvStrideH_)),
+                               make_embed_transform(make_tuple(Number<3>{}, Wo_),
                                                    make_tuple(ConvDilationW_, ConvStrideW_)),
                                make_pass_through_transform(NumGroupsToMerge)),
                     make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
                     make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
                 return transform_tensor_descriptor(
                     in_n_y_ho_x_wo_groups_c_desc,
-                    make_tuple(make_merge_transform(make_tuple(N, Ho, Wo, NumGroupsToMerge)),
+                    make_tuple(make_merge_transform(make_tuple(N_, Ho_, Wo_, NumGroupsToMerge)),
                                make_merge_transform(make_tuple(Number<3>{}, Number<3>{}))),
                     make_tuple(Sequence<0, 2, 4, 5>{}, Sequence<1, 3>{}),
                     make_tuple(Sequence<0>{}, Sequence<1>{}));
@@ -485,37 +617,39 @@ struct TransformConvFwdToGemm
             if constexpr(NumGroupsToMerge == 1)
             {
                 const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor(
-                    make_tuple(N, Hi, Wi, C), make_tuple(NStride, HiStride, WiStride, CStride));
+                    make_tuple(N_, Hi_, Wi_, C_),
+                    make_tuple(NStrideTensorA_, HiStride_, WiStride_, CStrideTensorA_));
                 const auto in_n_ho_wo_c_desc = transform_tensor_descriptor(
                     in_n_hi_wi_c_desc,
-                    make_tuple(make_pass_through_transform(N),
-                               make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)),
-                               make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
-                               make_pass_through_transform(C)),
+                    make_tuple(make_pass_through_transform(N_),
+                               make_embed_transform(make_tuple(Ho_), make_tuple(ConvStrideH_)),
+                               make_embed_transform(make_tuple(Wo_), make_tuple(ConvStrideW_)),
+                               make_pass_through_transform(C_)),
                     make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
                     make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
                 return transform_tensor_descriptor(
                     in_n_ho_wo_c_desc,
-                    make_tuple(make_merge_transform(make_tuple(N, Ho, Wo)),
-                               make_pass_through_transform(C)),
+                    make_tuple(make_merge_transform(make_tuple(N_, Ho_, Wo_)),
+                               make_pass_through_transform(C_)),
                     make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}),
                     make_tuple(Sequence<0>{}, Sequence<1>{}));
             }
             else
             {
                 const auto in_n_hi_wi_groups_c_desc = make_naive_tensor_descriptor(
-                    make_tuple(N, Hi, Wi, NumGroupsToMerge, C),
-                    make_tuple(NStride, HiStride, WiStride, GStride, CStride));
+                    make_tuple(N_, Hi_, Wi_, NumGroupsToMerge, C_),
+                    make_tuple(NStrideTensorA_, HiStride_, WiStride_, GStrideTensorA_, CStrideTensorA_));
                 const auto in_n_ho_wo_groups_c_desc = transform_tensor_descriptor(
                     in_n_hi_wi_groups_c_desc,
-                    make_tuple(make_pass_through_transform(N),
-                               make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)),
-                               make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
+                    make_tuple(make_pass_through_transform(N_),
+                               make_embed_transform(make_tuple(Ho_), make_tuple(ConvStrideH_)),
+                               make_embed_transform(make_tuple(Wo_), make_tuple(ConvStrideW_)),
                                make_pass_through_transform(NumGroupsToMerge),
-                               make_pass_through_transform(C)),
+                               make_pass_through_transform(C_)),
                     make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
                     make_tuple(
@@ -523,55 +657,44 @@ struct TransformConvFwdToGemm
                 return transform_tensor_descriptor(
                     in_n_ho_wo_groups_c_desc,
-                    make_tuple(make_merge_transform(make_tuple(N, Ho, Wo, NumGroupsToMerge)),
-                               make_pass_through_transform(C)),
+                    make_tuple(make_merge_transform(make_tuple(N_, Ho_, Wo_, NumGroupsToMerge)),
+                               make_pass_through_transform(C_)),
                     make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}),
                     make_tuple(Sequence<0>{}, Sequence<1>{}));
             }
         }
         else
         {
-            const index_t Y             = b_g_k_c_xs_lengths[3];
-            const index_t X             = b_g_k_c_xs_lengths[4];
-            const index_t ConvDilationH = conv_filter_dilations[0];
-            const index_t ConvDilationW = conv_filter_dilations[1];
-            const index_t InLeftPadH    = input_left_pads[0];
-            const index_t InLeftPadW    = input_left_pads[1];
-            const index_t InRightPadH   = input_right_pads[0];
-            const index_t InRightPadW   = input_right_pads[1];
-
             if constexpr(NumGroupsToMerge == 1)
             {
                 const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor(
-                    make_tuple(N, Hi, Wi, C), make_tuple(NStride, HiStride, WiStride, CStride));
+                    make_tuple(N_, Hi_, Wi_, C_),
+                    make_tuple(NStrideTensorA_, HiStride_, WiStride_, CStrideTensorA_));
                 const auto in_n_hip_wip_c_desc = transform_tensor_descriptor(
                     in_n_hi_wi_c_desc,
-                    make_tuple(make_pass_through_transform(N),
-                               make_pad_transform(Hi, InLeftPadH, InRightPadH),
-                               make_pad_transform(Wi, InLeftPadW, InRightPadW),
-                               make_pass_through_transform(C)),
+                    make_tuple(make_pass_through_transform(N_),
+                               make_pad_transform(Hi_, InLeftPadH_, InRightPadH_),
+                               make_pad_transform(Wi_, InLeftPadW_, InRightPadW_),
+                               make_pass_through_transform(C_)),
                     make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
                     make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
                 const auto in_n_y_ho_x_wo_c_desc = transform_tensor_descriptor(
                     in_n_hip_wip_c_desc,
-                    make_tuple(make_pass_through_transform(N),
-                               make_embed_transform(make_tuple(Y, Ho),
-                                                    make_tuple(ConvDilationH, ConvStrideH)),
-                               make_embed_transform(make_tuple(X, Wo),
-                                                    make_tuple(ConvDilationW, ConvStrideW)),
-                               make_pass_through_transform(C)),
+                    make_tuple(make_pass_through_transform(N_),
+                               make_embed_transform(make_tuple(Y_, Ho_),
                                                    make_tuple(ConvDilationH_, ConvStrideH_)),
+                               make_embed_transform(make_tuple(X_, Wo_),
                                                    make_tuple(ConvDilationW_, ConvStrideW_)),
+                               make_pass_through_transform(C_)),
                     make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
                     make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
                 return transform_tensor_descriptor(
                     in_n_y_ho_x_wo_c_desc,
-                    make_tuple(make_merge_transform(make_tuple(N, Ho, Wo)),
-                               make_merge_transform(make_tuple(Y, X, C))),
+                    make_tuple(make_merge_transform(make_tuple(N_, Ho_, Wo_)),
+                               make_merge_transform(make_tuple(Y_, X_, C_))),
                     make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}),
                     make_tuple(Sequence<0>{}, Sequence<1>{}));
             }
@@ -579,16 +702,17 @@ struct TransformConvFwdToGemm
             {
                 const auto in_n_hi_wi_groups_c_desc = make_naive_tensor_descriptor(
-                    make_tuple(N, Hi, Wi, NumGroupsToMerge, C),
-                    make_tuple(NStride, HiStride, WiStride, GStride, CStride));
+                    make_tuple(N_, Hi_, Wi_, NumGroupsToMerge, C_),
+                    make_tuple(NStrideTensorA_, HiStride_, WiStride_, GStrideTensorA_, CStrideTensorA_));
                 const auto in_n_hip_wip_groups_c_desc = transform_tensor_descriptor(
                     in_n_hi_wi_groups_c_desc,
-                    make_tuple(make_pass_through_transform(N),
-                               make_pad_transform(Hi, InLeftPadH, InRightPadH),
-                               make_pad_transform(Wi, InLeftPadW, InRightPadW),
+                    make_tuple(make_pass_through_transform(N_),
+                               make_pad_transform(Hi_, InLeftPadH_, InRightPadH_),
+                               make_pad_transform(Wi_, InLeftPadW_, InRightPadW_),
                                make_pass_through_transform(NumGroupsToMerge),
-                               make_pass_through_transform(C)),
+                               make_pass_through_transform(C_)),
                     make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
                     make_tuple(
@@ -596,13 +720,13 @@ struct TransformConvFwdToGemm
                 const auto in_n_y_ho_x_wo_groups_c_desc = transform_tensor_descriptor(
                     in_n_hip_wip_groups_c_desc,
-                    make_tuple(make_pass_through_transform(N),
-                               make_embed_transform(make_tuple(Y, Ho),
-                                                    make_tuple(ConvDilationH, ConvStrideH)),
-                               make_embed_transform(make_tuple(X, Wo),
-                                                    make_tuple(ConvDilationW, ConvStrideW)),
+                    make_tuple(make_pass_through_transform(N_),
+                               make_embed_transform(make_tuple(Y_, Ho_),
                                                    make_tuple(ConvDilationH_, ConvStrideH_)),
+                               make_embed_transform(make_tuple(X_, Wo_),
                                                    make_tuple(ConvDilationW_, ConvStrideW_)),
                                make_pass_through_transform(NumGroupsToMerge),
-                               make_pass_through_transform(C)),
+                               make_pass_through_transform(C_)),
                     make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
                     make_tuple(Sequence<0>{},
@@ -613,8 +737,8 @@ struct TransformConvFwdToGemm
                 return transform_tensor_descriptor(
                     in_n_y_ho_x_wo_groups_c_desc,
-                    make_tuple(make_merge_transform(make_tuple(N, Ho, Wo, NumGroupsToMerge)),
-                               make_merge_transform(make_tuple(Y, X, C))),
+                    make_tuple(make_merge_transform(make_tuple(N_, Ho_, Wo_, NumGroupsToMerge)),
+                               make_merge_transform(make_tuple(Y_, X_, C_))),
                     make_tuple(Sequence<0, 2, 4, 5>{}, Sequence<1, 3, 6>{}),
                     make_tuple(Sequence<0>{}, Sequence<1>{}));
             }
@@ -627,63 +751,27 @@ struct TransformConvFwdToGemm
                   is_same_v<ALayout, tensor_layout::convolution::NDHWGC> ||
                   is_same_v<ALayout, tensor_layout::convolution::GNDHWC>),
                  bool>::type = false>
-    static auto
-    MakeADescriptor_M_K(const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
-                        const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
-                        const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
-                        const std::array<index_t, NDimSpatial + 3>& /* b_g_k_c_xs_strides */,
-                        const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
-                        const std::array<index_t, NDimSpatial + 3>& /* c_g_n_k_wos_strides*/,
-                        const std::array<index_t, NDimSpatial>& conv_filter_strides,
-                        const std::array<index_t, NDimSpatial>& conv_filter_dilations,
-                        const std::array<index_t, NDimSpatial>& input_left_pads,
-                        const std::array<index_t, NDimSpatial>& input_right_pads,
-                        const index_t N)
+    __host__ __device__ auto MakeADescriptor_M_K() const
     {
-        const index_t C           = a_g_n_c_wis_lengths[2];
-        const index_t Di          = a_g_n_c_wis_lengths[3];
-        const index_t Hi          = a_g_n_c_wis_lengths[4];
-        const index_t Wi          = a_g_n_c_wis_lengths[5];
-        const index_t Do          = c_g_n_k_wos_lengths[3];
-        const index_t Ho          = c_g_n_k_wos_lengths[4];
-        const index_t Wo          = c_g_n_k_wos_lengths[5];
-        const index_t ConvStrideD = conv_filter_strides[0];
-        const index_t ConvStrideH = conv_filter_strides[1];
-        const index_t ConvStrideW = conv_filter_strides[2];
-        const index_t GStride     = a_g_n_c_wis_strides[I0];
-        const index_t NStride     = a_g_n_c_wis_strides[I1];
-        const index_t CStride     = a_g_n_c_wis_strides[I2];
-        const index_t DiStride    = a_g_n_c_wis_strides[I3];
-        const index_t HiStride    = a_g_n_c_wis_strides[I4];
-        const index_t WiStride    = a_g_n_c_wis_strides[I5];
-
         if constexpr(ConvForwardSpecialization ==
                      device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
         {
-            const index_t NDoHoWo =
-                N * ck::accumulate_n<index_t>(
-                        c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>());
             if constexpr(NumGroupsToMerge == 1)
             {
-                return make_naive_tensor_descriptor(make_tuple(NDoHoWo, C),
-                                                    make_tuple(WiStride, CStride));
+                return make_naive_tensor_descriptor(make_tuple(NDoHoWo_, C_),
                                                    make_tuple(WiStride_, CStrideTensorA_));
             }
             else
             {
-                const auto in_gemmm_groups_gemmk_desc = make_naive_tensor_descriptor(
-                    make_tuple(NDoHoWo, NumGroupsToMerge, C), make_tuple(WiStride, GStride, CStride));
+                const auto in_gemmm_groups_gemmk_desc = make_naive_tensor_descriptor(
+                    make_tuple(NDoHoWo_, NumGroupsToMerge, C_),
+                    make_tuple(WiStride_, GStrideTensorA_, CStrideTensorA_));
                 return transform_tensor_descriptor(
                     in_gemmm_groups_gemmk_desc,
-                    make_tuple(make_merge_transform(make_tuple(NDoHoWo, NumGroupsToMerge)),
-                               make_pass_through_transform(C)),
+                    make_tuple(make_merge_transform(make_tuple(NDoHoWo_, NumGroupsToMerge)),
+                               make_pass_through_transform(C_)),
                     make_tuple(Sequence<0, 1>{}, Sequence<2>{}),
                     make_tuple(Sequence<0>{}, Sequence<1>{}));
             }
@@ -691,41 +779,30 @@ struct TransformConvFwdToGemm
         else if constexpr(ConvForwardSpecialization ==
                           device::ConvolutionForwardSpecialization::Filter3x3)
         {
-            const index_t ConvDilationD = conv_filter_dilations[0];
-            const index_t ConvDilationH = conv_filter_dilations[1];
-            const index_t ConvDilationW = conv_filter_dilations[2];
-            const index_t InLeftPadD    = input_left_pads[0];
-            const index_t InLeftPadH    = input_left_pads[1];
-            const index_t InLeftPadW    = input_left_pads[2];
-            const index_t InRightPadD   = input_right_pads[0];
-            const index_t InRightPadH   = input_right_pads[1];
-            const index_t InRightPadW   = input_right_pads[2];
-
             if constexpr(NumGroupsToMerge == 1)
             {
                 const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor(
-                    make_tuple(N, Di, Hi, Wi),
-                    make_tuple(NStride, DiStride, HiStride, WiStride));
+                    make_tuple(N_, Di_, Hi_, Wi_),
+                    make_tuple(NStrideTensorA_, DiStride_, HiStride_, WiStride_));
                 const auto in_n_hip_wip_c_desc = transform_tensor_descriptor(
                     in_n_di_hi_wi_c_desc,
-                    make_tuple(make_pass_through_transform(N),
-                               make_pad_transform(Di, InLeftPadD, InRightPadD),
-                               make_pad_transform(Hi, InLeftPadH, InRightPadH),
-                               make_pad_transform(Wi, InLeftPadW, InRightPadW)),
+                    make_tuple(make_pass_through_transform(N_),
+                               make_pad_transform(Di_, InLeftPadD_, InRightPadD_),
+                               make_pad_transform(Hi_, InLeftPadH_, InRightPadH_),
+                               make_pad_transform(Wi_, InLeftPadW_, InRightPadW_)),
                     make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
                     make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
                 const auto in_n_z_do_y_ho_x_wo_c_desc = transform_tensor_descriptor(
                     in_n_hip_wip_c_desc,
-                    make_tuple(make_pass_through_transform(N),
-                               make_embed_transform(make_tuple(Number<3>{}, Do),
-                                                    make_tuple(ConvDilationD, ConvStrideD)),
-                               make_embed_transform(make_tuple(Number<3>{}, Ho),
-                                                    make_tuple(ConvDilationH, ConvStrideH)),
-                               make_embed_transform(make_tuple(Number<3>{}, Wo),
-                                                    make_tuple(ConvDilationW, ConvStrideW))),
+                    make_tuple(make_pass_through_transform(N_),
+                               make_embed_transform(make_tuple(Number<3>{}, Do_),
                                                    make_tuple(ConvDilationD_, ConvStrideD_)),
+                               make_embed_transform(make_tuple(Number<3>{}, Ho_),
                                                    make_tuple(ConvDilationH_, ConvStrideH_)),
+                               make_embed_transform(make_tuple(Number<3>{}, Wo_),
                                                    make_tuple(ConvDilationW_, ConvStrideW_))),
                     make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
                     make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5, 6>{}));
@@ -733,7 +810,7 @@ struct TransformConvFwdToGemm
                 return transform_tensor_descriptor(
                     in_n_z_do_y_ho_x_wo_c_desc,
-                    make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo)),
+                    make_tuple(make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_)),
                                make_merge_transform(
                                    make_tuple(Number<3>{}, Number<3>{}, Number<3>{}))),
                     make_tuple(Sequence<0, 2, 4, 6>{}, Sequence<1, 3, 5>{}),
                     make_tuple(Sequence<0>{}, Sequence<1>{}));
@@ -741,15 +818,15 @@ struct TransformConvFwdToGemm
             else
             {
                 const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor(
-                    make_tuple(N, Di, Hi, Wi, NumGroupsToMerge),
-                    make_tuple(NStride, DiStride, HiStride, WiStride, GStride));
+                    make_tuple(N_, Di_, Hi_, Wi_, NumGroupsToMerge),
+                    make_tuple(NStrideTensorA_, DiStride_, HiStride_, WiStride_, GStrideTensorA_));
                 const auto in_n_hip_wip_c_desc = transform_tensor_descriptor(
                     in_n_di_hi_wi_c_desc,
-                    make_tuple(make_pass_through_transform(N),
-                               make_pad_transform(Di, InLeftPadD, InRightPadD),
-                               make_pad_transform(Hi, InLeftPadH, InRightPadH),
-                               make_pad_transform(Wi, InLeftPadW, InRightPadW),
+                    make_tuple(make_pass_through_transform(N_),
+                               make_pad_transform(Di_, InLeftPadD_, InRightPadD_),
+                               make_pad_transform(Hi_, InLeftPadH_, InRightPadH_),
+                               make_pad_transform(Wi_, InLeftPadW_, InRightPadW_),
                                make_pass_through_transform(NumGroupsToMerge)),
                     make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
@@ -758,13 +835,13 @@ struct TransformConvFwdToGemm
                 const auto in_n_z_do_y_ho_x_wo_c_desc = transform_tensor_descriptor(
                     in_n_hip_wip_c_desc,
-                    make_tuple(make_pass_through_transform(N),
-                               make_embed_transform(make_tuple(Number<3>{}, Do),
-                                                    make_tuple(ConvDilationD, ConvStrideD)),
-                               make_embed_transform(make_tuple(Number<3>{}, Ho),
-                                                    make_tuple(ConvDilationH, ConvStrideH)),
-                               make_embed_transform(make_tuple(Number<3>{}, Wo),
-                                                    make_tuple(ConvDilationW, ConvStrideW)),
+                    make_tuple(make_pass_through_transform(N_),
+                               make_embed_transform(make_tuple(Number<3>{}, Do_),
                                                    make_tuple(ConvDilationD_, ConvStrideD_)),
+                               make_embed_transform(make_tuple(Number<3>{}, Ho_),
                                                    make_tuple(ConvDilationH_, ConvStrideH_)),
+                               make_embed_transform(make_tuple(Number<3>{}, Wo_),
                                                    make_tuple(ConvDilationW_, ConvStrideW_)),
                                make_pass_through_transform(NumGroupsToMerge)),
                     make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
@@ -777,7 +854,7 @@ struct TransformConvFwdToGemm
                 return transform_tensor_descriptor(
                     in_n_z_do_y_ho_x_wo_c_desc,
-                    make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo, NumGroupsToMerge)),
+                    make_tuple(make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_, NumGroupsToMerge)),
                                make_merge_transform(
                                    make_tuple(Number<3>{}, Number<3>{}, Number<3>{}))),
                     make_tuple(Sequence<0, 2, 4, 6, 7>{}, Sequence<1, 3, 5>{}),
                     make_tuple(Sequence<0>{}, Sequence<1>{}));
@@ -789,16 +866,16 @@ struct TransformConvFwdToGemm
             if constexpr(NumGroupsToMerge == 1)
             {
                 const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor(
-                    make_tuple(N, Di, Hi, Wi, C),
-                    make_tuple(NStride, DiStride, HiStride, WiStride, CStride));
+                    make_tuple(N_, Di_, Hi_, Wi_, C_),
+                    make_tuple(NStrideTensorA_, DiStride_, HiStride_, WiStride_, CStrideTensorA_));
                 const auto in_n_do_ho_wo_c_desc = transform_tensor_descriptor(
                     in_n_di_hi_wi_c_desc,
-                    make_tuple(make_pass_through_transform(N),
-                               make_embed_transform(make_tuple(Do), make_tuple(ConvStrideD)),
-                               make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)),
-                               make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
-                               make_pass_through_transform(C)),
+                    make_tuple(make_pass_through_transform(N_),
+                               make_embed_transform(make_tuple(Do_), make_tuple(ConvStrideD_)),
+                               make_embed_transform(make_tuple(Ho_), make_tuple(ConvStrideH_)),
+                               make_embed_transform(make_tuple(Wo_), make_tuple(ConvStrideW_)),
+                               make_pass_through_transform(C_)),
                     make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
                     make_tuple(
@@ -806,25 +883,30 @@ struct TransformConvFwdToGemm
                 return transform_tensor_descriptor(
                     in_n_do_ho_wo_c_desc,
-                    make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo)),
-                               make_pass_through_transform(C)),
+                    make_tuple(make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_)),
+                               make_pass_through_transform(C_)),
                     make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}),
                     make_tuple(Sequence<0>{}, Sequence<1>{}));
             }
             else
             {
                 const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor(
-                    make_tuple(N, Di, Hi, Wi, NumGroupsToMerge, C),
-                    make_tuple(NStride, DiStride, HiStride, WiStride, GStride, CStride));
+                    make_tuple(N_, Di_, Hi_, Wi_, NumGroupsToMerge, C_),
+                    make_tuple(NStrideTensorA_, DiStride_, HiStride_, WiStride_, GStrideTensorA_, CStrideTensorA_));
                 const auto in_n_do_ho_wo_c_desc = transform_tensor_descriptor(
                     in_n_di_hi_wi_c_desc,
-                    make_tuple(make_pass_through_transform(N),
-                               make_embed_transform(make_tuple(Do), make_tuple(ConvStrideD)),
-                               make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)),
-                               make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
+                    make_tuple(make_pass_through_transform(N_),
+                               make_embed_transform(make_tuple(Do_), make_tuple(ConvStrideD_)),
+                               make_embed_transform(make_tuple(Ho_), make_tuple(ConvStrideH_)),
+                               make_embed_transform(make_tuple(Wo_), make_tuple(ConvStrideW_)),
                                make_pass_through_transform(NumGroupsToMerge),
-                               make_pass_through_transform(C)),
+                               make_pass_through_transform(C_)),
                     make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{},
@@ -840,43 +922,28 @@ struct TransformConvFwdToGemm
                 return transform_tensor_descriptor(
                     in_n_do_ho_wo_c_desc,
-                    make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo, NumGroupsToMerge)),
-                               make_pass_through_transform(C)),
+                    make_tuple(make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_, NumGroupsToMerge)),
+                               make_pass_through_transform(C_)),
                     make_tuple(Sequence<0, 1, 2, 3, 4>{}, Sequence<5>{}),
                     make_tuple(Sequence<0>{}, Sequence<1>{}));
             }
         }
         else
         {
-            const index_t Z             = b_g_k_c_xs_lengths[3];
-            const index_t Y             = b_g_k_c_xs_lengths[4];
-            const index_t X             = b_g_k_c_xs_lengths[5];
-            const index_t ConvDilationD = conv_filter_dilations[0];
-            const index_t ConvDilationH = conv_filter_dilations[1];
-            const index_t ConvDilationW = conv_filter_dilations[2];
-            const index_t InLeftPadD    = input_left_pads[0];
-            const index_t InLeftPadH    = input_left_pads[1];
-            const index_t InLeftPadW    = input_left_pads[2];
-            const index_t InRightPadD   = input_right_pads[0];
-            const index_t InRightPadH   = input_right_pads[1];
-            const index_t InRightPadW   = input_right_pads[2];
-
            if constexpr(NumGroupsToMerge == 1)
            {
                 const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor(
-                    make_tuple(N, Di, Hi, Wi, C),
-                    make_tuple(NStride, DiStride, HiStride, WiStride, CStride));
+                    make_tuple(N_, Di_, Hi_, Wi_, C_),
+                    make_tuple(NStrideTensorA_, DiStride_, HiStride_, WiStride_, CStrideTensorA_));
                 const auto in_n_hip_wip_c_desc = transform_tensor_descriptor(
                     in_n_di_hi_wi_c_desc,
-                    make_tuple(make_pass_through_transform(N),
-                               make_pad_transform(Di, InLeftPadD, InRightPadD),
-                               make_pad_transform(Hi, InLeftPadH, InRightPadH),
-                               make_pad_transform(Wi, InLeftPadW, InRightPadW),
-                               make_pass_through_transform(C)),
+                    make_tuple(make_pass_through_transform(N_),
+                               make_pad_transform(Di_, InLeftPadD_, InRightPadD_),
+                               make_pad_transform(Hi_, InLeftPadH_, InRightPadH_),
+                               make_pad_transform(Wi_, InLeftPadW_, InRightPadW_),
+                               make_pass_through_transform(C_)),
                     make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
                     make_tuple(
@@ -884,14 +951,14 @@ struct TransformConvFwdToGemm
            const auto in_n_z_do_y_ho_x_wo_c_desc = transform_tensor_descriptor(
                in_n_hip_wip_c_desc,
-               make_tuple(make_pass_through_transform(N),
-                          make_embed_transform(make_tuple(Z, Do), make_tuple(ConvDilationD, ConvStrideD)),
-                          make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
-                          make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
-                          make_pass_through_transform(C)),
+               make_tuple(make_pass_through_transform(N_),
+                          make_embed_transform(make_tuple(Z_, Do_), make_tuple(ConvDilationD_, ConvStrideD_)),
+                          make_embed_transform(make_tuple(Y_, Ho_), make_tuple(ConvDilationH_, ConvStrideH_)),
+                          make_embed_transform(make_tuple(X_, Wo_), make_tuple(ConvDilationW_, ConvStrideW_)),
+                          make_pass_through_transform(C_)),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
                make_tuple(Sequence<0>{},
...
...
@@ -902,25 +969,30 @@ struct TransformConvFwdToGemm
            return transform_tensor_descriptor(
                in_n_z_do_y_ho_x_wo_c_desc,
-               make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo)),
-                          make_merge_transform(make_tuple(Z, Y, X, C))),
+               make_tuple(make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_)),
+                          make_merge_transform(make_tuple(Z_, Y_, X_, C_))),
                make_tuple(Sequence<0, 2, 4, 6>{}, Sequence<1, 3, 5, 7>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));
        }
        else
        {
            const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor(
-               make_tuple(N, Di, Hi, Wi, NumGroupsToMerge, C),
-               make_tuple(NStride, DiStride, HiStride, WiStride, GStride, CStride));
+               make_tuple(N_, Di_, Hi_, Wi_, NumGroupsToMerge, C_),
+               make_tuple(NStrideTensorA_, DiStride_, HiStride_, WiStride_, GStrideTensorA_, CStrideTensorA_));
            const auto in_n_hip_wip_c_desc = transform_tensor_descriptor(
                in_n_di_hi_wi_c_desc,
-               make_tuple(make_pass_through_transform(N),
-                          make_pad_transform(Di, InLeftPadD, InRightPadD),
-                          make_pad_transform(Hi, InLeftPadH, InRightPadH),
-                          make_pad_transform(Wi, InLeftPadW, InRightPadW),
+               make_tuple(make_pass_through_transform(N_),
+                          make_pad_transform(Di_, InLeftPadD_, InRightPadD_),
+                          make_pad_transform(Hi_, InLeftPadH_, InRightPadH_),
+                          make_pad_transform(Wi_, InLeftPadW_, InRightPadW_),
                           make_pass_through_transform(NumGroupsToMerge),
-                          make_pass_through_transform(C)),
+                          make_pass_through_transform(C_)),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{},
...
...
@@ -936,15 +1008,15 @@ struct TransformConvFwdToGemm
            const auto in_n_z_do_y_ho_x_wo_c_desc = transform_tensor_descriptor(
                in_n_hip_wip_c_desc,
-               make_tuple(make_pass_through_transform(N),
-                          make_embed_transform(make_tuple(Z, Do), make_tuple(ConvDilationD, ConvStrideD)),
-                          make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
-                          make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
+               make_tuple(make_pass_through_transform(N_),
+                          make_embed_transform(make_tuple(Z_, Do_), make_tuple(ConvDilationD_, ConvStrideD_)),
+                          make_embed_transform(make_tuple(Y_, Ho_), make_tuple(ConvDilationH_, ConvStrideH_)),
+                          make_embed_transform(make_tuple(X_, Wo_), make_tuple(ConvDilationW_, ConvStrideW_)),
                           make_pass_through_transform(NumGroupsToMerge),
-                          make_pass_through_transform(C)),
+                          make_pass_through_transform(C_)),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{},
...
...
@@ -960,8 +1032,9 @@ struct TransformConvFwdToGemm
            return transform_tensor_descriptor(
                in_n_z_do_y_ho_x_wo_c_desc,
-               make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo, NumGroupsToMerge)),
-                          make_merge_transform(make_tuple(Z, Y, X, C))),
+               make_tuple(make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_, NumGroupsToMerge)),
+                          make_merge_transform(make_tuple(Z_, Y_, X_, C_))),
                make_tuple(Sequence<0, 2, 4, 6, 7>{}, Sequence<1, 3, 5, 8>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));
        }
...
...
@@ -973,19 +1046,8 @@ struct TransformConvFwdToGemm
                  is_same_v<BLayout, tensor_layout::convolution::GKYXC> ||
                  is_same_v<BLayout, tensor_layout::convolution::GKZYXC>,
              bool>::type = false>
-   static auto MakeBDescriptor_N_K(const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
-                                   const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides)
+   __host__ __device__ auto MakeBDescriptor_N_K() const
    {
        const index_t K = b_g_k_c_xs_lengths[1];
        const index_t C = b_g_k_c_xs_lengths[2];
        const index_t YX = ck::accumulate_n<index_t>(
            b_g_k_c_xs_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>());
        const index_t GStride = b_g_k_c_xs_strides[I0];
        const index_t KStride = b_g_k_c_xs_strides[I1];
        const index_t CStride = b_g_k_c_xs_strides[I2];
        if constexpr(ConvForwardSpecialization == device::ConvolutionForwardSpecialization::Filter3x3)
        {
...
...
@@ -996,17 +1058,17 @@ struct TransformConvFwdToGemm
            if constexpr(NumGroupsToMerge == 1)
            {
-               return make_naive_tensor_descriptor_packed(make_tuple(K, FilterSizeNumType{}));
+               return make_naive_tensor_descriptor_packed(make_tuple(K_, FilterSizeNumType{}));
            }
            else
            {
                const auto wei_gemmn_groups_gemmk_desc = make_naive_tensor_descriptor(
-                   make_tuple(K, NumGroupsToMerge, FilterSizeNumType{}),
-                   make_tuple(KStride, GStride, CStride));
+                   make_tuple(K_, NumGroupsToMerge, FilterSizeNumType{}),
+                   make_tuple(KStrideTensorB_, GStrideTensorB_, CStrideTensorB_));
                return transform_tensor_descriptor(
                    wei_gemmn_groups_gemmk_desc,
-                   make_tuple(make_merge_transform(make_tuple(K, NumGroupsToMerge)),
+                   make_tuple(make_merge_transform(make_tuple(K_, NumGroupsToMerge)),
                               make_pass_through_transform(FilterSizeNumType{})),
                    make_tuple(Sequence<0, 1>{}, Sequence<2>{}),
                    make_tuple(Sequence<0>{}, Sequence<1>{}));
...
...
@@ -1016,16 +1078,17 @@ struct TransformConvFwdToGemm
        {
            if constexpr(NumGroupsToMerge == 1)
            {
-               return make_naive_tensor_descriptor_packed(make_tuple(K, YX * C));
+               return make_naive_tensor_descriptor_packed(make_tuple(K_, ZYX_ * C_));
            }
            else
            {
                const auto wei_gemmn_groups_gemmk_desc = make_naive_tensor_descriptor(
-                   make_tuple(K, NumGroupsToMerge, YX * C),
-                   make_tuple(KStride, GStride, CStride));
+                   make_tuple(K_, NumGroupsToMerge, ZYX_ * C_),
+                   make_tuple(KStrideTensorB_, GStrideTensorB_, CStrideTensorB_));
                return transform_tensor_descriptor(
                    wei_gemmn_groups_gemmk_desc,
-                   make_tuple(make_merge_transform(make_tuple(K, NumGroupsToMerge)),
-                              make_pass_through_transform(YX * C)),
+                   make_tuple(make_merge_transform(make_tuple(K_, NumGroupsToMerge)),
+                              make_pass_through_transform(ZYX_ * C_)),
                    make_tuple(Sequence<0, 1>{}, Sequence<2>{}),
                    make_tuple(Sequence<0>{}, Sequence<1>{}));
            }
...
...
@@ -1041,25 +1104,14 @@ struct TransformConvFwdToGemm
                  is_same_v<BLayout, tensor_layout::convolution::KYXGC> ||
                  is_same_v<BLayout, tensor_layout::convolution::KZYXGC>,
              bool>::type = false>
-   static auto MakeBDescriptor_N_K(const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
-                                   const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides)
+   __host__ __device__ auto MakeBDescriptor_N_K() const
    {
        const index_t K = b_g_k_c_xs_lengths[1];
        const index_t C = b_g_k_c_xs_lengths[2];
        const index_t YX = ck::accumulate_n<index_t>(
            b_g_k_c_xs_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>());
        const index_t KStride = b_g_k_c_xs_strides[1];
        const index_t XStride = b_g_k_c_xs_strides[2 + NDimSpatial];
        const auto CStride    = I1;
        const auto wei_k_yx_c_desc = make_naive_tensor_descriptor(
-           make_tuple(K, YX, C), make_tuple(KStride, XStride, CStride));
+           make_tuple(K_, ZYX_, C_), make_tuple(KStrideTensorB_, XStride_, CStrideTensorB_));
        const auto wei_gemmn_gemmk_desc = transform_tensor_descriptor(
            wei_k_yx_c_desc,
-           make_tuple(make_pass_through_transform(K), make_merge_transform(make_tuple(YX, C))),
+           make_tuple(make_pass_through_transform(K_), make_merge_transform(make_tuple(ZYX_, C_))),
            make_tuple(Sequence<0>{}, Sequence<1, 2>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}));
...
...
@@ -1071,24 +1123,14 @@ struct TransformConvFwdToGemm
                  is_same_v<CLayout, tensor_layout::convolution::GNHWK> ||
                  is_same_v<CLayout, tensor_layout::convolution::GNDHWK>,
              bool>::type = false>
-   static auto MakeCDescriptor_M_N(const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
-                                   const std::array<index_t, NDimSpatial + 3>& /* c_g_n_k_wos_strides */,
-                                   const index_t N)
+   __host__ __device__ auto MakeCDescriptor_M_N() const
    {
        const index_t K     = c_g_n_k_wos_lengths[2];
        const index_t NHoWo = N * ck::accumulate_n<index_t>(
            c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>());
-       const auto out_gemmm_gemmn_desc = make_naive_tensor_descriptor_packed(make_tuple(NHoWo, K));
-       return out_gemmm_gemmn_desc;
+       return make_naive_tensor_descriptor_packed(make_tuple(NDoHoWo_, K_));
    }

    template <typename CLayout,
              typename std::enable_if<is_same_v<CLayout, tensor_layout::convolution::G_NW_K> ||
                                          is_same_v<CLayout, tensor_layout::convolution::G_NHW_K> ||
                                          is_same_v<CLayout, tensor_layout::convolution::G_NDHW_K> ||
...
...
@@ -1096,39 +1138,28 @@ struct TransformConvFwdToGemm
                  is_same_v<CLayout, tensor_layout::convolution::NHWGK> ||
                  is_same_v<CLayout, tensor_layout::convolution::NDHWGK>,
              bool>::type = false>
-   static auto MakeCDescriptor_M_N(const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
-                                   const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_strides,
-                                   const index_t N)
+   __host__ __device__ auto MakeCDescriptor_M_N() const
    {
        const index_t K        = c_g_n_k_wos_lengths[2];
        const index_t KStride  = I1;
        const index_t WoStride = c_g_n_k_wos_strides[NDimSpatial + 2];
        const index_t GStride  = c_g_n_k_wos_strides[0];
        const index_t NHoWo    = N * ck::accumulate_n<index_t>(
            c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>());
        if constexpr(NumGroupsToMerge == 1)
        {
-           return make_naive_tensor_descriptor(make_tuple(NHoWo, K), make_tuple(WoStride, KStride));
+           return make_naive_tensor_descriptor(make_tuple(NDoHoWo_, K_),
+                                               make_tuple(WoStride_, KStrideTensorC_));
        }
        else
        {
-           const auto nhwo_groups_k_1_desc = make_naive_tensor_descriptor(
-               make_tuple(NHoWo, NumGroupsToMerge, K, 1),
-               make_tuple(WoStride, GStride, KStride, GStride));
+           const auto nhwo_groups_k_1_desc = make_naive_tensor_descriptor(
+               make_tuple(NDoHoWo_, NumGroupsToMerge, K_, 1),
+               make_tuple(WoStride_, GStrideTensorC_, KStrideTensorC_, GStrideTensorC_));
            // Padd 1 to NumGroupsToMerge
            const auto padded_desc = transform_tensor_descriptor(
                nhwo_groups_k_1_desc,
-               make_tuple(make_pass_through_transform(NHoWo),
+               make_tuple(make_pass_through_transform(NDoHoWo_),
                           make_pass_through_transform(NumGroupsToMerge),
-                          make_pass_through_transform(K),
+                          make_pass_through_transform(K_),
                           make_pad_transform(1, 0, NumGroupsToMerge - 1)),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
-           // We need only matrices from diagonal. Xor returns 0 for the same
+           // We need only matrices from diagonal. X_or returns 0 for the same
            // values. So if matrices is not on diagonal then it will be stored in padding.
            // To avoid use of modulo after xor we assume that NumBatch to merge is power of 2.
            static_assert(NumGroupsToMerge == 1 || NumGroupsToMerge == 2 || NumGroupsToMerge == 4 ||
...
...
@@ -1136,16 +1167,16 @@ struct TransformConvFwdToGemm
                          NumGroupsToMerge == 32 || NumGroupsToMerge == 64);
            const auto unmerged_padded_desc = transform_tensor_descriptor(
                padded_desc,
-               make_tuple(make_pass_through_transform(NHoWo),
+               make_tuple(make_pass_through_transform(NDoHoWo_),
                           make_xor_transform(make_tuple(NumGroupsToMerge, NumGroupsToMerge)),
-                          make_pass_through_transform(K)),
+                          make_pass_through_transform(K_)),
                make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2>{}),
                make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2>{}));
            // Merge To M, N
            return transform_tensor_descriptor(
                unmerged_padded_desc,
-               make_tuple(make_merge_transform(make_tuple(NHoWo, NumGroupsToMerge)),
-                          make_merge_transform(make_tuple(K, NumGroupsToMerge))),
+               make_tuple(make_merge_transform(make_tuple(NDoHoWo_, NumGroupsToMerge)),
+                          make_merge_transform(make_tuple(K_, NumGroupsToMerge))),
                make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));
        }
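The xor/padding trick used above can be checked in isolation: for a power-of-two NumGroupsToMerge, g0 ^ g1 is zero exactly when the two group indices match, so only the diagonal blocks survive and every off-diagonal pair is routed into the [1, NumGroupsToMerge - 1] pad region. A minimal standalone sketch (not part of the header; the value of NumGroupsToMerge below is illustrative only):

// Standalone demo of the diagonal/xor property relied on by make_xor_transform above.
// Assumption: NumGroupsToMerge is a power of two, as the static_assert enforces.
#include <cstdio>

int main()
{
    constexpr int NumGroupsToMerge = 4; // illustrative value
    for(int g0 = 0; g0 < NumGroupsToMerge; ++g0)
    {
        for(int g1 = 0; g1 < NumGroupsToMerge; ++g1)
        {
            // Prints 0 only on the diagonal (g0 == g1); all other values land in
            // [1, NumGroupsToMerge - 1], i.e. inside the region added by
            // make_pad_transform(1, 0, NumGroupsToMerge - 1).
            std::printf("%d ", g0 ^ g1);
        }
        std::printf("\n");
    }
    return 0;
}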
...
...
@@ -1155,542 +1186,34 @@ struct TransformConvFwdToGemm
    template <typename CLayout,
              typename std::enable_if<is_same_v<CLayout, tensor_layout::convolution::G_K>, bool>::type = false>
-   static auto MakeCDescriptor_M_N(const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
-                                   const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_strides,
-                                   const index_t N)
+   __host__ __device__ auto MakeCDescriptor_M_N() const
    {
        const index_t K       = c_g_n_k_wos_lengths[2];
        const index_t KStride = c_g_n_k_wos_strides[2];
        const index_t NHoWo   = N * ck::accumulate_n<index_t>(
            c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>());
        const auto out_gemmm_gemmn_desc =
-           make_naive_tensor_descriptor(make_tuple(NHoWo, K), make_tuple(I0, KStride));
+           make_naive_tensor_descriptor(make_tuple(NDoHoWo_, K_), make_tuple(I0, KStrideTensorC_));
        return out_gemmm_gemmn_desc;
    }

    // Overloaded functions for hipRTC purposes
    template <typename ALayout,
              typename std::enable_if<NDimSpatial == 1 &&
                                          (is_same_v<ALayout, tensor_layout::convolution::G_NW_C> ||
                                           is_same_v<ALayout, tensor_layout::convolution::NWGC> ||
                                           is_same_v<ALayout, tensor_layout::convolution::GNWC>),
                                      bool>::type = false>
    __host__ __device__ static auto
    MakeADescriptor_M_K(const ck::Array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
                        const ck::Array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
                        const ck::Array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
                        const ck::Array<index_t, NDimSpatial + 3>& /* b_g_k_c_xs_strides */,
                        const ck::Array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
                        const ck::Array<index_t, NDimSpatial + 3>& /* c_g_n_k_wos_strides */,
                        const ck::Array<index_t, NDimSpatial>& conv_filter_strides,
                        const ck::Array<index_t, NDimSpatial>& conv_filter_dilations,
                        const ck::Array<index_t, NDimSpatial>& input_left_pads,
                        const ck::Array<index_t, NDimSpatial>& input_right_pads)
    {
        const index_t N  = a_g_n_c_wis_lengths[1];
        const index_t C  = a_g_n_c_wis_lengths[2];
        const index_t Wi = a_g_n_c_wis_lengths[3];
        const index_t Wo = c_g_n_k_wos_lengths[3];
        const index_t ConvStrideW = conv_filter_strides[0];
        if constexpr(ConvForwardSpecialization == device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
        {
            const index_t NHoWo = N * ck::accumulate_n<index_t>(
                c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>());
            // This is different
            const index_t WiStride = a_g_n_c_wis_strides[2 + NDimSpatial];
            const auto CStride     = I1;
            const auto in_gemmm_gemmk_desc =
                make_naive_tensor_descriptor(make_tuple(NHoWo, C), make_tuple(WiStride, CStride));
            return in_gemmm_gemmk_desc;
        }
        else if constexpr(ConvForwardSpecialization == device::ConvolutionForwardSpecialization::Filter1x1Pad0)
        {
            // This is different
            const index_t NStride  = a_g_n_c_wis_strides[1];
            const index_t WiStride = a_g_n_c_wis_strides[3];
            const auto CStride     = I1;
            const auto in_n_wi_c_desc = make_naive_tensor_descriptor(
                make_tuple(N, Wi, C), make_tuple(NStride, WiStride, CStride));
            const auto in_n_wo_c_desc = transform_tensor_descriptor(
                in_n_wi_c_desc,
                make_tuple(make_pass_through_transform(N),
                           make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
                           make_pass_through_transform(C)),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
            const auto in_gemmm_gemmk_desc = transform_tensor_descriptor(
                in_n_wo_c_desc,
                make_tuple(make_merge_transform(make_tuple(N, Wo)), make_pass_through_transform(C)),
                make_tuple(Sequence<0, 1>{}, Sequence<2>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));
            return in_gemmm_gemmk_desc;
        }
        else
        {
            const index_t X = b_g_k_c_xs_lengths[3];
            const index_t ConvDilationW = conv_filter_dilations[0];
            const index_t InLeftPadW    = input_left_pads[0];
            const index_t InRightPadW   = input_right_pads[0];
            // This is different
            const index_t NStride  = a_g_n_c_wis_strides[1];
            const index_t WiStride = a_g_n_c_wis_strides[3];
            const auto CStride     = I1;
            const auto in_n_wi_c_desc = make_naive_tensor_descriptor(
                make_tuple(N, Wi, C), make_tuple(NStride, WiStride, CStride));
            const auto in_n_wip_c_desc = transform_tensor_descriptor(
                in_n_wi_c_desc,
                make_tuple(make_pass_through_transform(N),
                           make_pad_transform(Wi, InLeftPadW, InRightPadW),
                           make_pass_through_transform(C)),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
            const auto in_n_x_wo_c_desc = transform_tensor_descriptor(
                in_n_wip_c_desc,
                make_tuple(make_pass_through_transform(N),
                           make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
                           make_pass_through_transform(C)),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
                make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));
            const auto in_gemmm_gemmk_desc = transform_tensor_descriptor(
                in_n_x_wo_c_desc,
                make_tuple(make_merge_transform(make_tuple(N, Wo)), make_merge_transform(make_tuple(X, C))),
                make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));
            return in_gemmm_gemmk_desc;
        }
    }

    template <typename ALayout,
              typename std::enable_if<NDimSpatial == 2 &&
                                          (is_same_v<ALayout, tensor_layout::convolution::G_NHW_C> ||
                                           is_same_v<ALayout, tensor_layout::convolution::NHWGC> ||
                                           is_same_v<ALayout, tensor_layout::convolution::GNHWC>),
                                      bool>::type = false>
    __host__ __device__ static auto
    MakeADescriptor_M_K(const ck::Array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
                        const ck::Array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
                        const ck::Array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
                        const ck::Array<index_t, NDimSpatial + 3>& /* b_g_k_c_xs_strides */,
                        const ck::Array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
                        const ck::Array<index_t, NDimSpatial + 3>& /* c_g_n_k_wos_strides */,
                        const ck::Array<index_t, NDimSpatial>& conv_filter_strides,
                        const ck::Array<index_t, NDimSpatial>& conv_filter_dilations,
                        const ck::Array<index_t, NDimSpatial>& input_left_pads,
                        const ck::Array<index_t, NDimSpatial>& input_right_pads)
    {
        const index_t N  = a_g_n_c_wis_lengths[1];
        const index_t C  = a_g_n_c_wis_lengths[2];
        const index_t Hi = a_g_n_c_wis_lengths[3];
        const index_t Wi = a_g_n_c_wis_lengths[4];
        const index_t Ho = c_g_n_k_wos_lengths[3];
        const index_t Wo = c_g_n_k_wos_lengths[4];
        const index_t ConvStrideH = conv_filter_strides[0];
        const index_t ConvStrideW = conv_filter_strides[1];
        if constexpr(ConvForwardSpecialization == device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
        {
            const index_t NHoWo = N * ck::accumulate_n<index_t>(
                c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>());
            // This is different
            const index_t WiStride = a_g_n_c_wis_strides[2 + NDimSpatial];
            const auto CStride     = I1;
            const auto in_gemmm_gemmk_desc =
                make_naive_tensor_descriptor(make_tuple(NHoWo, C), make_tuple(WiStride, CStride));
            return in_gemmm_gemmk_desc;
        }
        else if constexpr(ConvForwardSpecialization == device::ConvolutionForwardSpecialization::Filter1x1Pad0)
        {
            // This is different
            const index_t NStride  = a_g_n_c_wis_strides[1];
            const index_t HiStride = a_g_n_c_wis_strides[3];
            const index_t WiStride = a_g_n_c_wis_strides[4];
            const auto CStride     = I1;
            const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor(
                make_tuple(N, Hi, Wi, C), make_tuple(NStride, HiStride, WiStride, CStride));
            const auto in_n_ho_wo_c_desc = transform_tensor_descriptor(
                in_n_hi_wi_c_desc,
                make_tuple(make_pass_through_transform(N),
                           make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)),
                           make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
                           make_pass_through_transform(C)),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
            const auto in_gemmm_gemmk_desc = transform_tensor_descriptor(
                in_n_ho_wo_c_desc,
                make_tuple(make_merge_transform(make_tuple(N, Ho, Wo)), make_pass_through_transform(C)),
                make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));
            return in_gemmm_gemmk_desc;
        }
        else
        {
            const index_t Y = b_g_k_c_xs_lengths[3];
            const index_t X = b_g_k_c_xs_lengths[4];
            const index_t ConvDilationH = conv_filter_dilations[0];
            const index_t ConvDilationW = conv_filter_dilations[1];
            const index_t InLeftPadH    = input_left_pads[0];
            const index_t InLeftPadW    = input_left_pads[1];
            const index_t InRightPadH   = input_right_pads[0];
            const index_t InRightPadW   = input_right_pads[1];
            // This is different
            const index_t NStride  = a_g_n_c_wis_strides[1];
            const index_t HiStride = a_g_n_c_wis_strides[3];
            const index_t WiStride = a_g_n_c_wis_strides[4];
            const auto CStride     = I1;
            const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor(
                make_tuple(N, Hi, Wi, C), make_tuple(NStride, HiStride, WiStride, CStride));
            const auto in_n_hip_wip_c_desc = transform_tensor_descriptor(
                in_n_hi_wi_c_desc,
                make_tuple(make_pass_through_transform(N),
                           make_pad_transform(Hi, InLeftPadH, InRightPadH),
                           make_pad_transform(Wi, InLeftPadW, InRightPadW),
                           make_pass_through_transform(C)),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
            const auto in_n_y_ho_x_wo_c_desc = transform_tensor_descriptor(
                in_n_hip_wip_c_desc,
                make_tuple(make_pass_through_transform(N),
                           make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
                           make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
                           make_pass_through_transform(C)),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
                make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
            const auto in_gemmm_gemmk_desc = transform_tensor_descriptor(
                in_n_y_ho_x_wo_c_desc,
                make_tuple(make_merge_transform(make_tuple(N, Ho, Wo)),
                           make_merge_transform(make_tuple(Y, X, C))),
                make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));
            return in_gemmm_gemmk_desc;
        }
    }

    template <typename ALayout,
              typename std::enable_if<NDimSpatial == 3 &&
                                          (is_same_v<ALayout, tensor_layout::convolution::G_NDHW_C> ||
                                           is_same_v<ALayout, tensor_layout::convolution::NDHWGC> ||
                                           is_same_v<ALayout, tensor_layout::convolution::GNDHWC>),
                                      bool>::type = false>
    static auto
    MakeADescriptor_M_K(const ck::Array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
                        const ck::Array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
                        const ck::Array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
                        const ck::Array<index_t, NDimSpatial + 3>& /* b_g_k_c_xs_strides */,
                        const ck::Array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
                        const ck::Array<index_t, NDimSpatial + 3>& /* c_g_n_k_wos_strides */,
                        const ck::Array<index_t, NDimSpatial>& conv_filter_strides,
                        const ck::Array<index_t, NDimSpatial>& conv_filter_dilations,
                        const ck::Array<index_t, NDimSpatial>& input_left_pads,
                        const ck::Array<index_t, NDimSpatial>& input_right_pads)
    {
        const index_t N  = a_g_n_c_wis_lengths[1];
        const index_t C  = a_g_n_c_wis_lengths[2];
        const index_t Di = a_g_n_c_wis_lengths[3];
        const index_t Hi = a_g_n_c_wis_lengths[4];
        const index_t Wi = a_g_n_c_wis_lengths[5];
        const index_t Do = c_g_n_k_wos_lengths[3];
        const index_t Ho = c_g_n_k_wos_lengths[4];
        const index_t Wo = c_g_n_k_wos_lengths[5];
        const index_t ConvStrideD = conv_filter_strides[0];
        const index_t ConvStrideH = conv_filter_strides[1];
        const index_t ConvStrideW = conv_filter_strides[2];
        if constexpr(ConvForwardSpecialization == device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
        {
            const index_t NDoHoWo = N * ck::accumulate_n<index_t>(
                c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>());
            // This is different
            const index_t WiStride = a_g_n_c_wis_strides[2 + NDimSpatial];
            const auto CStride     = I1;
            const auto in_gemmm_gemmk_desc =
                make_naive_tensor_descriptor(make_tuple(NDoHoWo, C), make_tuple(WiStride, CStride));
            return in_gemmm_gemmk_desc;
        }
        else if constexpr(ConvForwardSpecialization == device::ConvolutionForwardSpecialization::Filter1x1Pad0)
        {
            // This is different
            const index_t NStride  = a_g_n_c_wis_strides[1];
            const index_t DiStride = a_g_n_c_wis_strides[3];
            const index_t HiStride = a_g_n_c_wis_strides[4];
            const index_t WiStride = a_g_n_c_wis_strides[5];
            const auto CStride     = I1;
            const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor(
                make_tuple(N, Di, Hi, Wi, C), make_tuple(NStride, DiStride, HiStride, WiStride, CStride));
            const auto in_n_do_ho_wo_c_desc = transform_tensor_descriptor(
                in_n_di_hi_wi_c_desc,
                make_tuple(make_pass_through_transform(N),
                           make_embed_transform(make_tuple(Do), make_tuple(ConvStrideD)),
                           make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)),
                           make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
                           make_pass_through_transform(C)),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
            const auto in_gemmm_gemmk_desc = transform_tensor_descriptor(
                in_n_do_ho_wo_c_desc,
                make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo)), make_pass_through_transform(C)),
                make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));
            return in_gemmm_gemmk_desc;
        }
        else
        {
            const index_t Z = b_g_k_c_xs_lengths[3];
            const index_t Y = b_g_k_c_xs_lengths[4];
            const index_t X = b_g_k_c_xs_lengths[5];
            const index_t ConvDilationD = conv_filter_dilations[0];
            const index_t ConvDilationH = conv_filter_dilations[1];
            const index_t ConvDilationW = conv_filter_dilations[2];
            const index_t InLeftPadD  = input_left_pads[0];
            const index_t InLeftPadH  = input_left_pads[1];
            const index_t InLeftPadW  = input_left_pads[2];
            const index_t InRightPadD = input_right_pads[0];
            const index_t InRightPadH = input_right_pads[1];
            const index_t InRightPadW = input_right_pads[2];
            // This is different
            const index_t NStride  = a_g_n_c_wis_strides[1];
            const index_t DiStride = a_g_n_c_wis_strides[3];
            const index_t HiStride = a_g_n_c_wis_strides[4];
            const index_t WiStride = a_g_n_c_wis_strides[5];
            const auto CStride     = I1;
            const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor(
                make_tuple(N, Di, Hi, Wi, C), make_tuple(NStride, DiStride, HiStride, WiStride, CStride));
            const auto in_n_hip_wip_c_desc = transform_tensor_descriptor(
                in_n_di_hi_wi_c_desc,
                make_tuple(make_pass_through_transform(N),
                           make_pad_transform(Di, InLeftPadD, InRightPadD),
                           make_pad_transform(Hi, InLeftPadH, InRightPadH),
                           make_pad_transform(Wi, InLeftPadW, InRightPadW),
                           make_pass_through_transform(C)),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
            const auto in_n_z_do_y_ho_x_wo_c_desc = transform_tensor_descriptor(
                in_n_hip_wip_c_desc,
                make_tuple(make_pass_through_transform(N),
                           make_embed_transform(make_tuple(Z, Do), make_tuple(ConvDilationD, ConvStrideD)),
                           make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
                           make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
                           make_pass_through_transform(C)),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
                make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5, 6>{}, Sequence<7>{}));
            const auto in_gemmm_gemmk_desc = transform_tensor_descriptor(
                in_n_z_do_y_ho_x_wo_c_desc,
                make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo)),
                           make_merge_transform(make_tuple(Z, Y, X, C))),
                make_tuple(Sequence<0, 2, 4, 6>{}, Sequence<1, 3, 5, 7>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));
            return in_gemmm_gemmk_desc;
        }
    }

    template <typename BLayout,
              typename std::enable_if<is_same_v<BLayout, tensor_layout::convolution::GKXC> ||
                                          is_same_v<BLayout, tensor_layout::convolution::GKYXC> ||
                                          is_same_v<BLayout, tensor_layout::convolution::GKZYXC>,
                                      bool>::type = false>
    __host__ __device__ static auto
    MakeBDescriptor_N_K(const ck::Array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
                        const ck::Array<index_t, NDimSpatial + 3>& /* b_g_k_c_xs_strides */)
    {
        const index_t K  = b_g_k_c_xs_lengths[1];
        const index_t C  = b_g_k_c_xs_lengths[2];
        const index_t YX = mult_accumulate_n<index_t>(b_g_k_c_xs_lengths.begin() + 3, NDimSpatial, 1);
        const auto wei_gemmn_gemmk_desc = make_naive_tensor_descriptor_packed(make_tuple(K, YX * C));
        return wei_gemmn_gemmk_desc;
    }

    template <typename BLayout,
              typename std::enable_if<is_same_v<BLayout, tensor_layout::convolution::G_K_X_C> ||
                                          is_same_v<BLayout, tensor_layout::convolution::G_K_YX_C> ||
                                          is_same_v<BLayout, tensor_layout::convolution::G_K_ZYX_C> ||
                                          is_same_v<BLayout, tensor_layout::convolution::KXGC> ||
                                          is_same_v<BLayout, tensor_layout::convolution::KYXGC> ||
                                          is_same_v<BLayout, tensor_layout::convolution::KZYXGC>,
                                      bool>::type = false>
    __host__ __device__ static auto
    MakeBDescriptor_N_K(const ck::Array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
                        const ck::Array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides)
    {
        const index_t K  = b_g_k_c_xs_lengths[1];
        const index_t C  = b_g_k_c_xs_lengths[2];
        const index_t YX = mult_accumulate_n<index_t>(b_g_k_c_xs_lengths.begin() + 3, NDimSpatial, 1);
        const index_t KStride = b_g_k_c_xs_strides[1];
        const index_t XStride = b_g_k_c_xs_strides[2 + NDimSpatial];
        const auto CStride    = I1;
        const auto wei_k_yx_c_desc = make_naive_tensor_descriptor(
            make_tuple(K, YX, C), make_tuple(KStride, XStride, CStride));
        const auto wei_gemmn_gemmk_desc = transform_tensor_descriptor(
            wei_k_yx_c_desc,
            make_tuple(make_pass_through_transform(K), make_merge_transform(make_tuple(YX, C))),
            make_tuple(Sequence<0>{}, Sequence<1, 2>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}));
        return wei_gemmn_gemmk_desc;
    }

    template <typename CLayout,
              typename std::enable_if<is_same_v<CLayout, tensor_layout::convolution::GNWK> ||
                                          is_same_v<CLayout, tensor_layout::convolution::GNHWK> ||
                                          is_same_v<CLayout, tensor_layout::convolution::GNDHWK>,
                                      bool>::type = false>
    __host__ __device__ static auto
    MakeCDescriptor_M_N(const ck::Array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
                        const ck::Array<index_t, NDimSpatial + 3>& /* c_g_n_k_wos_strides */)
    {
        const index_t N = c_g_n_k_wos_lengths[1];
        const index_t K = c_g_n_k_wos_lengths[2];
        const index_t NHoWo =
            N * mult_accumulate_n<index_t>(c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1);
        const auto out_gemmm_gemmn_desc = make_naive_tensor_descriptor_packed(make_tuple(NHoWo, K));
        return out_gemmm_gemmn_desc;
    }

    template <typename CLayout,
              typename std::enable_if<is_same_v<CLayout, tensor_layout::convolution::G_NW_K> ||
                                          is_same_v<CLayout, tensor_layout::convolution::G_NHW_K> ||
                                          is_same_v<CLayout, tensor_layout::convolution::G_NDHW_K> ||
                                          is_same_v<CLayout, tensor_layout::convolution::NWGK> ||
                                          is_same_v<CLayout, tensor_layout::convolution::NHWGK> ||
                                          is_same_v<CLayout, tensor_layout::convolution::NDHWGK>,
                                      bool>::type = false>
    __host__ __device__ static auto
    MakeCDescriptor_M_N(const ck::Array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
                        const ck::Array<index_t, NDimSpatial + 3>& c_g_n_k_wos_strides)
    {
        const index_t N        = c_g_n_k_wos_lengths[1];
        const index_t K        = c_g_n_k_wos_lengths[2];
        const auto KStride     = I1;
        const index_t WoStride = c_g_n_k_wos_strides[NDimSpatial + 2];
        const index_t NHoWo =
            N * mult_accumulate_n<index_t>(c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1);
        const auto out_gemmm_gemmn_desc =
            make_naive_tensor_descriptor(make_tuple(NHoWo, K), make_tuple(WoStride, KStride));
        return out_gemmm_gemmn_desc;
    }

    // for output bias
    template <typename CLayout,
              typename std::enable_if<is_same_v<CLayout, tensor_layout::convolution::G_K>, bool>::type = false>
    __host__ __device__ static auto
    MakeCDescriptor_M_N(const ck::Array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
                        const ck::Array<index_t, NDimSpatial + 3>& c_g_n_k_wos_strides)
    {
        const index_t N       = c_g_n_k_wos_lengths[1];
        const index_t K       = c_g_n_k_wos_lengths[2];
        const index_t KStride = c_g_n_k_wos_strides[2];
        const index_t NHoWo =
            N * mult_accumulate_n<index_t>(c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1);
        const auto out_gemmm_gemmn_desc =
            make_naive_tensor_descriptor(make_tuple(NHoWo, K), make_tuple(I0, KStride));
        return out_gemmm_gemmn_desc;
    }

    public:
    index_t N_;

    private:
    const index_t Di_, Hi_, Wi_;
    const index_t Do_, Ho_, Wo_;
    const index_t Z_, Y_, X_;
    const index_t K_, C_;
    const index_t DiStride_, HiStride_, WiStride_;
    const index_t WoStride_;
    const index_t XStride_;
    const index_t CStrideTensorA_, CStrideTensorB_, KStrideTensorB_, KStrideTensorC_;
    const index_t NStrideTensorA_;
    const index_t GStrideTensorA_, GStrideTensorB_, GStrideTensorC_;
    const index_t ConvStrideD_, ConvStrideH_, ConvStrideW_;
    const index_t ConvDilationD_, ConvDilationH_, ConvDilationW_;
    const index_t InLeftPadD_, InLeftPadH_, InLeftPadW_;
    const index_t InRightPadD_, InRightPadH_, InRightPadW_;
    const index_t ZYX_;
    index_t NDoHoWo_;
};
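For orientation, the descriptors built by this struct realize the usual implicit-GEMM (im2col) view of forward convolution: GemmM = N * Do * Ho * Wo, GemmN = K and GemmK = Z * Y * X * C per group. A minimal host-side sketch of that arithmetic, assuming the standard output-length formula (the helper names below are illustrative and not part of this header):

#include <cstdint>

// Output spatial length of a padded, dilated, strided convolution dimension.
inline std::int64_t conv_out_len(std::int64_t in, std::int64_t pad_l, std::int64_t pad_r,
                                 std::int64_t dilation, std::int64_t filter, std::int64_t stride)
{
    return (in + pad_l + pad_r - dilation * (filter - 1) - 1) / stride + 1;
}

struct GemmShape { std::int64_t M, N, K; };

// GEMM problem implied by a 3D forward convolution (per group).
inline GemmShape conv3d_fwd_as_gemm(std::int64_t N, std::int64_t K, std::int64_t C,
                                    std::int64_t Di, std::int64_t Hi, std::int64_t Wi,
                                    std::int64_t Z, std::int64_t Y, std::int64_t X,
                                    const std::int64_t stride[3], const std::int64_t dilation[3],
                                    const std::int64_t pad_l[3], const std::int64_t pad_r[3])
{
    const std::int64_t Do = conv_out_len(Di, pad_l[0], pad_r[0], dilation[0], Z, stride[0]);
    const std::int64_t Ho = conv_out_len(Hi, pad_l[1], pad_r[1], dilation[1], Y, stride[1]);
    const std::int64_t Wo = conv_out_len(Wi, pad_l[2], pad_r[2], dilation[2], X, stride[2]);
    return {N * Do * Ho * Wo, K, Z * Y * X * C};
}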
// wrapper class to call member functions on TransformConvToGemm struct at runtime
...
...
@@ -1702,26 +1225,22 @@ struct TransformConv
    template <index_t NDimSpatial, device::ConvolutionForwardSpecialization ConvForwardSpecialization>
    auto
-   transform_func(ck::Array<index_t, NDimSpatial + 3> out_lengths,
-                  ck::Array<index_t, NDimSpatial + 3> out_strides,
-                  TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization> conv_fwd_to_gemm)
+   transform_func(TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization> conv_fwd_to_gemm)
    {
        if(NDimSpatial == 2)
        {
            return conv_fwd_to_gemm
-               .template MakeCDescriptor_M_N<ck::tensor_layout::convolution::NHWGK>(out_lengths, out_strides);
+               .template MakeCDescriptor_M_N<ck::tensor_layout::convolution::NHWGK>();
        }
        else if(NDimSpatial == 3)
        {
            return conv_fwd_to_gemm
-               .template MakeCDescriptor_M_N<tensor_layout::convolution::NDHWGK>(out_lengths, out_strides);
+               .template MakeCDescriptor_M_N<tensor_layout::convolution::NDHWGK>();
        }
        else if(NDimSpatial == 1)
        {
-           return conv_fwd_to_gemm.template MakeCDescriptor_M_N<tensor_layout::convolution::NWGK>(out_lengths, out_strides);
+           return conv_fwd_to_gemm.template MakeCDescriptor_M_N<tensor_layout::convolution::NWGK>();
        }
    }
};
...
...
library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp
...
...
@@ -108,6 +108,7 @@ using FastGelu = ck::tensor_operation::element_wise::FastGelu;
using MultiplyFastGelu = ck::tensor_operation::element_wise::MultiplyFastGelu;
using AddMultiply      = ck::tensor_operation::element_wise::AddMultiply;
using MultiplyAdd      = ck::tensor_operation::element_wise::MultiplyAdd;
using MultiplyMultiply = ck::tensor_operation::element_wise::MultiplyMultiply;
using ScaleAdd         = ck::tensor_operation::element_wise::ScaleAdd;
using Gelu             = ck::tensor_operation::element_wise::Gelu;
using Swish            = ck::tensor_operation::element_wise::Swish;
...
...
library/include/ck/library/tensor_operation_instance/gpu/gemm_ab_scale.hpp
0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <vector>
#include <memory>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_ab_scale.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
#if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_FP8))
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_default_instances(
    std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16,
                                                            128, 128, 128, PassThrough, PassThrough, PassThrough>>>& instances);
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_kpadding_instances(
    std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16,
                                                            128, 128, 128, PassThrough, PassThrough, PassThrough>>>& instances);
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_mnpadding_instances(
    std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16,
                                                            128, 128, 128, PassThrough, PassThrough, PassThrough>>>& instances);
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_mnkpadding_instances(
    std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16,
                                                            128, 128, 128, PassThrough, PassThrough, PassThrough>>>& instances);
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_default_instances(
    std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16,
                                                            128, 128, 128, PassThrough, PassThrough, PassThrough>>>& instances);
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_kpadding_instances(
    std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16,
                                                            128, 128, 128, PassThrough, PassThrough, PassThrough>>>& instances);
void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_mnkpadding_instances(
    std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16,
                                                            128, 128, 128, PassThrough, PassThrough, PassThrough>>>& instances);
#endif
template <typename A0DataType, typename A1DataType, typename B0DataType, typename B1DataType,
          typename CDataType, typename ALayout, typename BLayout, typename CLayout>
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMultipleD_ABScale<
    ALayout, BLayout, Tuple<>, CLayout, A0DataType, A1DataType, B0DataType, B1DataType, Tuple<>, CDataType,
    128, 128, 128, ck::tensor_operation::element_wise::PassThrough,
    ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough>>
{
    using DeviceOp = DeviceGemmMultipleD_ABScale<
        ALayout, BLayout, Tuple<>, CLayout, A0DataType, A1DataType, B0DataType, B1DataType, Tuple<>, CDataType,
        128, 128, 128, ck::tensor_operation::element_wise::PassThrough,
        ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough>;

    static auto GetInstances()
    {
        std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
#if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_FP8))
        if constexpr(is_same_v<A0DataType, f8_t> && is_same_v<B0DataType, f8_t> &&
                     is_same_v<CDataType, bhalf_t>)
        {
            if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> && is_same_v<CLayout, Row>)
            {
                add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_default_instances(op_ptrs);
                add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_kpadding_instances(op_ptrs);
                add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_mnpadding_instances(op_ptrs);
                add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_mnkpadding_instances(op_ptrs);
                add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_default_instances(op_ptrs);
                add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_kpadding_instances(op_ptrs);
                add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_mnkpadding_instances(op_ptrs);
            }
        }
#endif
        return op_ptrs;
    }
};

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
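A hedged usage sketch for the factory added in this header: a typical consumer instantiates DeviceOperationInstanceFactory for the concrete DeviceGemmMultipleD_ABScale signature and enumerates the registered instances. Argument construction and kernel launch are omitted, and the layout/type aliases below are assumptions based on the declarations above:

// Sketch only: list the f8 x f8 -> bf16 AB-scale GEMM instances registered by this header.
#include <iostream>
#include "ck/library/tensor_operation_instance/gpu/gemm_ab_scale.hpp"

void list_gemm_ab_scale_instances()
{
    using Row  = ck::tensor_layout::gemm::RowMajor;
    using Col  = ck::tensor_layout::gemm::ColumnMajor;
    using Pass = ck::tensor_operation::element_wise::PassThrough;

    using DeviceOp = ck::tensor_operation::device::DeviceGemmMultipleD_ABScale<
        Row, Col, ck::Tuple<>, Row, ck::f8_t, float, ck::f8_t, float, ck::Tuple<>, ck::bhalf_t,
        128, 128, 128, Pass, Pass, Pass>;

    const auto op_ptrs =
        ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<DeviceOp>::GetInstances();

    for(const auto& op : op_ptrs)
        std::cout << op->GetTypeString() << '\n'; // each instance reports its tuning configuration
}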
library/include/ck/library/tensor_operation_instance/gpu/gemm_multiply_multiply.hpp
0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <vector>
#include <memory>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
#if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_FP8))
void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_default_instances(
    std::vector<std::unique_ptr<DeviceGemmMultipleD<Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16,
                                                    PassThrough, PassThrough, MultiplyMultiply>>>& instances);
void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_kpadding_instances(
    std::vector<std::unique_ptr<DeviceGemmMultipleD<Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16,
                                                    PassThrough, PassThrough, MultiplyMultiply>>>& instances);
void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mnpadding_instances(
    std::vector<std::unique_ptr<DeviceGemmMultipleD<Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16,
                                                    PassThrough, PassThrough, MultiplyMultiply>>>& instances);
void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mnkpadding_instances(
    std::vector<std::unique_ptr<DeviceGemmMultipleD<Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16,
                                                    PassThrough, PassThrough, MultiplyMultiply>>>& instances);
void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_default_instances(
    std::vector<std::unique_ptr<DeviceGemmMultipleD<Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16,
                                                    PassThrough, PassThrough, MultiplyMultiply>>>& instances);
void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_kpadding_instances(
    std::vector<std::unique_ptr<DeviceGemmMultipleD<Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16,
                                                    PassThrough, PassThrough, MultiplyMultiply>>>& instances);
void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_mnkpadding_instances(
    std::vector<std::unique_ptr<DeviceGemmMultipleD<Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16,
                                                    PassThrough, PassThrough, MultiplyMultiply>>>& instances);
void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_default_instances(
    std::vector<std::unique_ptr<DeviceGemmMultipleD<Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16,
                                                    PassThrough, PassThrough, MultiplyMultiply>>>& instances);
void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_instances(
    std::vector<std::unique_ptr<DeviceGemmMultipleD<Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16,
                                                    PassThrough, PassThrough, MultiplyMultiply>>>& instances);
void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_mnkpadding_instances(
    std::vector<std::unique_ptr<DeviceGemmMultipleD<Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16,
                                                    PassThrough, PassThrough, MultiplyMultiply>>>& instances);
#endif
template <typename ADataType, typename BDataType, typename CDataType,
          typename ALayout, typename BLayout, typename CLayout>
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMultipleD<
    ALayout, BLayout, Tuple<Row, Col>, CLayout, ADataType, BDataType, Tuple<F32, F32>, CDataType,
    ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough,
    ck::tensor_operation::element_wise::MultiplyMultiply>>
{
    using DeviceOp = DeviceGemmMultipleD<
        ALayout, BLayout, Tuple<Row, Col>, CLayout, ADataType, BDataType, Tuple<F32, F32>, CDataType,
        ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough,
        ck::tensor_operation::element_wise::MultiplyMultiply>;

    static auto GetInstances()
    {
        std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
#if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_FP8))
        if constexpr(is_same_v<ADataType, f8_t> && is_same_v<BDataType, f8_t> &&
                     is_same_v<CDataType, bhalf_t>)
        {
            if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> && is_same_v<CLayout, Row>)
            {
                add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_default_instances(op_ptrs);
                add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_kpadding_instances(op_ptrs);
                add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mnpadding_instances(op_ptrs);
                add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mnkpadding_instances(op_ptrs);
                add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_default_instances(op_ptrs);
                add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_kpadding_instances(op_ptrs);
                add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_mnkpadding_instances(op_ptrs);
                add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_default_instances(op_ptrs);
                add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_instances(op_ptrs);
                add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_mnkpadding_instances(op_ptrs);
            }
        }
#endif
        return op_ptrs;
    }
};

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
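For context on the MultiplyMultiply epilogue used by the instances above: it combines the GEMM result with the two auxiliary D tensors multiplicatively, which is how the f8 scale factors are applied. A simplified host-side scalar model (the broadcast direction of d0/d1 is an illustrative assumption; the real element-wise operator runs on the device element types):

// Host-side model of e = c * d0 * d1 as used with the Tuple<Row, Col> D layouts above.
#include <cstddef>
#include <vector>

void apply_multiply_multiply(std::vector<float>& e, const std::vector<float>& c,
                             const std::vector<float>& d0, const std::vector<float>& d1,
                             std::size_t M, std::size_t N)
{
    for(std::size_t m = 0; m < M; ++m)
        for(std::size_t n = 0; n < N; ++n)
            e[m * N + n] = c[m * N + n] * d0[m] * d1[n]; // d0 broadcast per row, d1 per column (assumed)
}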
library/include/ck/library/tensor_operation_instance/gpu/gemm_universal.hpp
// SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -315,7 +315,7 @@ void add_device_gemm_xdl_universal_f8_f16_f16_mk_nk_mn_mem_v2_mnkpadding_instanc
        DeviceGemmV2<Row, Col, Row, F8, F16, F16, PassThrough, PassThrough, PassThrough>>>& instances);
#endif
-#ifdef CK_ENABLE_FP16
+#ifdef CK_ENABLE_BF16
void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_kn_mn_comp_default_instances(
    std::vector<std::unique_ptr<
        DeviceGemmV2<Row, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
...
...
@@ -416,6 +416,57 @@ void add_device_gemm_xdl_universal_bf16_bf16_bf16_mk_nk_mn_mem_v2_mnkpadding_ins
        DeviceGemmV2<Row, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>& instances);
#endif
#if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_FP8))
void add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_default_instances(
    std::vector<std::unique_ptr<
        DeviceGemmV2<Row, Col, Row, F8, F8, BF16, PassThrough, PassThrough, PassThrough>>>& instances);
void add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_kpadding_instances(
    std::vector<std::unique_ptr<
        DeviceGemmV2<Row, Col, Row, F8, F8, BF16, PassThrough, PassThrough, PassThrough>>>& instances);
void add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_mnpadding_instances(
    std::vector<std::unique_ptr<
        DeviceGemmV2<Row, Col, Row, F8, F8, BF16, PassThrough, PassThrough, PassThrough>>>& instances);
void add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_mnkpadding_instances(
    std::vector<std::unique_ptr<
        DeviceGemmV2<Row, Col, Row, F8, F8, BF16, PassThrough, PassThrough, PassThrough>>>& instances);
void add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v1_default_instances(
    std::vector<std::unique_ptr<
        DeviceGemmV2<Row, Col, Row, F8, F8, BF16, PassThrough, PassThrough, PassThrough>>>& instances);
void add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v1_kpadding_instances(
    std::vector<std::unique_ptr<
        DeviceGemmV2<Row, Col, Row, F8, F8, BF16, PassThrough, PassThrough, PassThrough>>>& instances);
void add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v1_mnkpadding_instances(
    std::vector<std::unique_ptr<
        DeviceGemmV2<Row, Col, Row, F8, F8, BF16, PassThrough, PassThrough, PassThrough>>>& instances);
void add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v2_default_instances(
    std::vector<std::unique_ptr<
        DeviceGemmV2<Row, Col, Row, F8, F8, BF16, PassThrough, PassThrough, PassThrough>>>& instances);
void add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_instances(
    std::vector<std::unique_ptr<
        DeviceGemmV2<Row, Col, Row, F8, F8, BF16, PassThrough, PassThrough, PassThrough>>>& instances);
void add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v2_mnkpadding_instances(
    std::vector<std::unique_ptr<
        DeviceGemmV2<Row, Col, Row, F8, F8, BF16, PassThrough, PassThrough, PassThrough>>>& instances);
#endif
template <typename ADataType,
          typename BDataType,
...
...
@@ -596,7 +647,7 @@ struct DeviceOperationInstanceFactory<
}
}
#endif
-#ifdef CK_ENABLE_FP16
+#ifdef CK_ENABLE_BF16
        if constexpr(is_same_v<ADataType, bhalf_t> && is_same_v<BDataType, bhalf_t> &&
                     is_same_v<CDataType, bhalf_t>)
{
...
...
@@ -653,6 +704,33 @@ struct DeviceOperationInstanceFactory<
                    op_ptrs);
            }
        }
#endif
#if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_FP8))
        if constexpr(is_same_v<ADataType, f8_t> && is_same_v<BDataType, f8_t> &&
                     is_same_v<CDataType, bhalf_t>)
        {
            if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> && is_same_v<CLayout, Row>)
            {
                add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_default_instances(op_ptrs);
                add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_kpadding_instances(op_ptrs);
                add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_mnpadding_instances(op_ptrs);
                add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_mnkpadding_instances(op_ptrs);
                add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v1_default_instances(op_ptrs);
                add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v1_kpadding_instances(op_ptrs);
                add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v1_mnkpadding_instances(op_ptrs);
                add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v2_default_instances(op_ptrs);
                add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_instances(op_ptrs);
                add_device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_v2_mnkpadding_instances(op_ptrs);
            }
        }
#endif
        return op_ptrs;
    }
...
...
library/include/ck/library/tensor_operation_instance/gpu/gemm_universal_reduce.hpp
0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <vector>
#include <memory>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3r1.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

using DsLayout   = ck::Tuple<>;
using DsDataType = ck::Tuple<>;
#ifdef CK_ENABLE_FP16
void add_device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_comp_default_instances(
    std::vector<std::unique_ptr<DeviceGemmV2R1<Row, Row, DsLayout, Row, F16, F16, DsDataType, F16,
                                               PassThrough, PassThrough, PassThrough>>>& instances);
void add_device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_comp_kpadding_instances(
    std::vector<std::unique_ptr<DeviceGemmV2R1<Row, Row, DsLayout, Row, F16, F16, DsDataType, F16,
                                               PassThrough, PassThrough, PassThrough>>>& instances);
void add_device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_comp_mnpadding_instances(
    std::vector<std::unique_ptr<DeviceGemmV2R1<Row, Row, DsLayout, Row, F16, F16, DsDataType, F16,
                                               PassThrough, PassThrough, PassThrough>>>& instances);
void add_device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_comp_mnkpadding_instances(
    std::vector<std::unique_ptr<DeviceGemmV2R1<Row, Row, DsLayout, Row, F16, F16, DsDataType, F16,
                                               PassThrough, PassThrough, PassThrough>>>& instances);
void add_device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_mem_v1_default_instances(
    std::vector<std::unique_ptr<DeviceGemmV2R1<Row, Row, DsLayout, Row, F16, F16, DsDataType, F16,
                                               PassThrough, PassThrough, PassThrough>>>& instances);
void add_device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_mem_v1_kpadding_instances(
    std::vector<std::unique_ptr<DeviceGemmV2R1<Row, Row, DsLayout, Row, F16, F16, DsDataType, F16,
                                               PassThrough, PassThrough, PassThrough>>>& instances);
void add_device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_mem_v1_mnkpadding_instances(
    std::vector<std::unique_ptr<DeviceGemmV2R1<Row, Row, DsLayout, Row, F16, F16, DsDataType, F16,
                                               PassThrough, PassThrough, PassThrough>>>& instances);
void add_device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_mem_v2_default_instances(
    std::vector<std::unique_ptr<DeviceGemmV2R1<Row, Row, DsLayout, Row, F16, F16, DsDataType, F16,
                                               PassThrough, PassThrough, PassThrough>>>& instances);
void add_device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_mem_v2_kpadding_instances(
    std::vector<std::unique_ptr<DeviceGemmV2R1<Row, Row, DsLayout, Row, F16, F16, DsDataType, F16,
                                               PassThrough, PassThrough, PassThrough>>>& instances);
void add_device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_mem_v2_mnkpadding_instances(
    std::vector<std::unique_ptr<DeviceGemmV2R1<Row, Row, DsLayout, Row, F16, F16, DsDataType, F16,
                                               PassThrough, PassThrough, PassThrough>>>& instances);
#endif
#if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_INT8))
void
add_device_gemm_xdl_universal_reduce_bf16_i8_bf16_mk_kn_mn_comp_default_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2R1
<
Row
,
Row
,
DsLayout
,
Row
,
BF16
,
I8
,
DsDataType
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_universal_reduce_bf16_i8_bf16_mk_kn_mn_comp_kpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2R1
<
Row
,
Row
,
DsLayout
,
Row
,
BF16
,
I8
,
DsDataType
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_universal_reduce_bf16_i8_bf16_mk_kn_mn_comp_mnkpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2R1
<
Row
,
Row
,
DsLayout
,
Row
,
BF16
,
I8
,
DsDataType
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_universal_reduce_bf16_i8_bf16_mk_kn_mn_mem_v2_default_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2R1
<
Row
,
Row
,
DsLayout
,
Row
,
BF16
,
I8
,
DsDataType
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_universal_reduce_bf16_i8_bf16_mk_kn_mn_mem_v2_kpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2R1
<
Row
,
Row
,
DsLayout
,
Row
,
BF16
,
I8
,
DsDataType
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_universal_reduce_bf16_i8_bf16_mk_kn_mn_comp_mnpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2R1
<
Row
,
Row
,
DsLayout
,
Row
,
BF16
,
I8
,
DsDataType
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_universal_reduce_bf16_i8_bf16_mk_kn_mn_mem_v2_mnkpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2R1
<
Row
,
Row
,
DsLayout
,
Row
,
BF16
,
I8
,
DsDataType
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
#endif
#ifdef CK_ENABLE_BF16
void
add_device_gemm_xdl_universal_reduce_bf16_bf16_bf16_mk_kn_mn_comp_default_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2R1
<
Row
,
Row
,
DsLayout
,
Row
,
BF16
,
BF16
,
DsDataType
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_universal_reduce_bf16_bf16_bf16_mk_kn_mn_comp_kpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2R1
<
Row
,
Row
,
DsLayout
,
Row
,
BF16
,
BF16
,
DsDataType
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_universal_reduce_bf16_bf16_bf16_mk_kn_mn_comp_mnkpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2R1
<
Row
,
Row
,
DsLayout
,
Row
,
BF16
,
BF16
,
DsDataType
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_universal_reduce_bf16_bf16_bf16_mk_kn_mn_mem_v2_default_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2R1
<
Row
,
Row
,
DsLayout
,
Row
,
BF16
,
BF16
,
DsDataType
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_universal_reduce_bf16_bf16_bf16_mk_kn_mn_mem_v2_kpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2R1
<
Row
,
Row
,
DsLayout
,
Row
,
BF16
,
BF16
,
DsDataType
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_universal_reduce_bf16_bf16_bf16_mk_kn_mn_comp_mnpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2R1
<
Row
,
Row
,
DsLayout
,
Row
,
BF16
,
BF16
,
DsDataType
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_gemm_xdl_universal_reduce_bf16_bf16_bf16_mk_kn_mn_mem_v2_mnkpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmV2R1
<
Row
,
Row
,
DsLayout
,
Row
,
BF16
,
BF16
,
DsDataType
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
#endif
template
<
typename
ADataType
,
typename
BDataType
,
typename
DsDataType
,
typename
CDataType
,
typename
ALayout
,
typename
BLayout
,
typename
DsLayout
,
typename
CLayout
>
struct
DeviceOperationInstanceFactory
<
ck
::
tensor_operation
::
device
::
DeviceGemmV2R1
<
ALayout
,
BLayout
,
DsLayout
,
CLayout
,
ADataType
,
BDataType
,
DsDataType
,
CDataType
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
>>
{
using
DeviceOp
=
DeviceGemmV2R1
<
ALayout
,
BLayout
,
DsLayout
,
CLayout
,
ADataType
,
BDataType
,
DsDataType
,
CDataType
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
>
;
static
auto
GetInstances
()
{
std
::
vector
<
std
::
unique_ptr
<
DeviceOp
>>
op_ptrs
;
#ifdef CK_ENABLE_FP16
if
constexpr
(
is_same_v
<
ADataType
,
half_t
>
&&
is_same_v
<
BDataType
,
half_t
>
&&
is_same_v
<
CDataType
,
half_t
>
)
{
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
is_same_v
<
BLayout
,
Row
>
&&
is_same_v
<
CLayout
,
Row
>
)
{
add_device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_comp_default_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_comp_kpadding_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_comp_mnpadding_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_comp_mnkpadding_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_mem_v1_default_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_mem_v1_kpadding_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_mem_v1_mnkpadding_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_mem_v2_default_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_mem_v2_kpadding_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_reduce_f16_f16_f16_mk_kn_mn_mem_v2_mnkpadding_instances
(
op_ptrs
);
}
}
#endif
#if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_INT8))
if
constexpr
(
is_same_v
<
ADataType
,
bhalf_t
>
&&
is_same_v
<
BDataType
,
int8_t
>
&&
is_same_v
<
CDataType
,
bhalf_t
>
)
{
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
is_same_v
<
BLayout
,
Row
>
&&
is_same_v
<
CLayout
,
Row
>
)
{
add_device_gemm_xdl_universal_reduce_bf16_i8_bf16_mk_kn_mn_comp_default_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_reduce_bf16_i8_bf16_mk_kn_mn_comp_kpadding_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_reduce_bf16_i8_bf16_mk_kn_mn_comp_mnkpadding_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_reduce_bf16_i8_bf16_mk_kn_mn_mem_v2_default_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_reduce_bf16_i8_bf16_mk_kn_mn_mem_v2_kpadding_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_reduce_bf16_i8_bf16_mk_kn_mn_comp_mnpadding_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_reduce_bf16_i8_bf16_mk_kn_mn_mem_v2_mnkpadding_instances
(
op_ptrs
);
}
}
#endif
#ifdef CK_ENABLE_BF16
if
constexpr
(
is_same_v
<
ADataType
,
bhalf_t
>
&&
is_same_v
<
BDataType
,
bhalf_t
>
&&
is_same_v
<
CDataType
,
bhalf_t
>
)
{
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
is_same_v
<
BLayout
,
Row
>
&&
is_same_v
<
CLayout
,
Row
>
)
{
add_device_gemm_xdl_universal_reduce_bf16_bf16_bf16_mk_kn_mn_comp_default_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_reduce_bf16_bf16_bf16_mk_kn_mn_comp_kpadding_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_reduce_bf16_bf16_bf16_mk_kn_mn_comp_mnkpadding_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_reduce_bf16_bf16_bf16_mk_kn_mn_mem_v2_default_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_reduce_bf16_bf16_bf16_mk_kn_mn_mem_v2_kpadding_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_reduce_bf16_bf16_bf16_mk_kn_mn_comp_mnpadding_instances
(
op_ptrs
);
add_device_gemm_xdl_universal_reduce_bf16_bf16_bf16_mk_kn_mn_mem_v2_mnkpadding_instances
(
op_ptrs
);
}
}
#endif
return
op_ptrs
;
}
};
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
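The factory specialization above is how client code enumerates these kernels: it picks the registration functions that match the requested data types and layouts and returns every registered instance. A minimal usage sketch follows; it assumes the usual CK `BaseOperator::GetTypeString()` accessor, which is not part of this diff, so treat it as illustrative only.

// Minimal sketch: list the f16 row-major instances exposed by the new factory specialization.
#include <iostream>
#include "ck/ck.hpp"
#include "ck/library/tensor_operation_instance/gpu/gemm_universal_reduce.hpp"

int main()
{
    using Row         = ck::tensor_layout::gemm::RowMajor;
    using F16         = ck::half_t;
    using PassThrough = ck::tensor_operation::element_wise::PassThrough;
    using DeviceOp    = ck::tensor_operation::device::DeviceGemmV2R1<Row, Row, ck::Tuple<>, Row, F16, F16, ck::Tuple<>, F16, PassThrough, PassThrough, PassThrough>;

    // GetInstances() appends one object per registered instance and returns the vector.
    auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<DeviceOp>::GetInstances();

    std::cout << "found " << op_ptrs.size() << " instances\n";
    for(const auto& op : op_ptrs)
        std::cout << op->GetTypeString() << "\n";   // assumption: GetTypeString() is available on the op
}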
library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_merged_groups_instance.hpp (deleted, mode 100644 → 0)

// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

using BF16 = ck::bhalf_t;
using F16  = ck::half_t;
using F32  = float;

template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

using Empty_Tuple = ck::Tuple<>;

using namespace ck::tensor_layout::convolution;

using PassThrough = ck::tensor_operation::element_wise::PassThrough;

static constexpr auto ConvFwdDefault =
    ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;

static constexpr auto ConvFwd3x3 = ConvolutionForwardSpecialization::Filter3x3;

static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding;

template <index_t NDimSpatial,
          typename ALayout,
          typename BLayout,
          typename DsLayout,
          typename ELayout,
          ConvolutionForwardSpecialization ConvSpec>
using device_grouped_conv_fwd_xdl_merged_groups_bf16_instances = std::tuple<
    // clang-format off
    //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| ACompute| BCompute| BlockGemm| NumGroups|
    //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| Type| Type| Pipeline| ToMerge|
    //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| | | Scheduler| |
    // Instances with NumGroupsPerBatch > 1
    DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, DsLayout, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 64, 64, 16, 16, 4, 4, 16, 16, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, BF16, BF16, LoopScheduler::Default, 8>,
    DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, DsLayout, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 64, 64, 16, 16, 4, 4, 16, 16, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, BF16, BF16, LoopScheduler::Default, 16>,
    DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, DsLayout, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 64, 64, 16, 16, 4, 4, 16, 16, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, BF16, BF16, LoopScheduler::Default, 32>
    // clang-format on
    >;

template <index_t NDimSpatial,
          typename ALayout,
          typename BLayout,
          typename DsLayout,
          typename ELayout,
          ConvolutionForwardSpecialization ConvSpec>
using device_grouped_conv_fwd_xdl_merged_groups_f16_instances = std::tuple<
    // clang-format off
    //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
    //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
    //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
    // Instances with NumGroupsPerBatch > 1
    DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, DsLayout, F16, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 64, 64, 16, 16, 4, 4, 16, 16, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, F16, F16, LoopScheduler::Default, 8>,
    DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, DsLayout, F16, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 64, 64, 16, 16, 4, 4, 16, 16, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, F16, F16, LoopScheduler::Default, 16>,
    DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, DsLayout, F16, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 64, 64, 16, 16, 4, 4, 16, 16, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, F16, F16, LoopScheduler::Default, 32>
    // clang-format on
    >;

template <index_t NDimSpatial,
          typename ALayout,
          typename BLayout,
          typename DsLayout,
          typename ELayout,
          ConvolutionForwardSpecialization ConvSpec>
using device_grouped_conv_fwd_xdl_merged_groups_f32_instances = std::tuple<
    // clang-format off
    //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
    //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
    //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
    // Instances with NumGroupsPerBatch > 1
    DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, DsLayout, F32, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 64, 64, 16, 16, 4, 4, 16, 16, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, F32, F32, LoopScheduler::Default, 8>,
    DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, DsLayout, F32, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 64, 64, 16, 16, 4, 4, 16, 16, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, F32, F32, LoopScheduler::Default, 16>,
    DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, DsLayout, F32, PassThrough, PassThrough, PassThrough, ConvSpec, GemmMNKPadding, 1, 64, 64, 16, 16, 4, 4, 16, 16, 4, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, F32, F32, LoopScheduler::Default, 32>
    // clang-format on
    >;

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp

@@ -17,7 +17,6 @@
 #endif
 #ifdef CK_USE_XDL
 #include "grouped_convolution_forward_xdl.inc"
-#include "grouped_convolution_forward_xdl_merged_groups.inc"
 #include "grouped_convolution_forward_comp_xdl.inc"
 #include "grouped_convolution_forward_mem_inter_xdl.inc"
 #include "grouped_convolution_forward_mem_intra_xdl.inc"
...
@@ -200,8 +199,6 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
                      is_same_v<BComputeType, float>)
         {
             add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instances(op_ptrs);
-            add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_instances(op_ptrs);
             add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_comp_instances(op_ptrs);
             add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_intra_instances(op_ptrs);
...
@@ -215,8 +212,6 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
                      is_same_v<BComputeType, half_t>)
         {
             add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instances(op_ptrs);
-            add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f16_instances(op_ptrs);
             add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_instances(op_ptrs);
             add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_mem_intra_instances(op_ptrs);
...
@@ -232,8 +227,6 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
                      is_same_v<BComputeType, ck::bhalf_t>)
         {
             add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_instances(op_ptrs);
-            add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instances(op_ptrs);
             add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instances(op_ptrs);
             add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instances(op_ptrs);
...
@@ -291,8 +284,6 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
                      is_same_v<BComputeType, float>)
         {
             add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(op_ptrs);
-            add_device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_instances(op_ptrs);
             add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_comp_instances(op_ptrs);
             add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_intra_instances(op_ptrs);
...
@@ -347,8 +338,6 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
                      is_same_v<BComputeType, half_t>)
         {
             add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances(op_ptrs);
-            add_device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f16_instances(op_ptrs);
             add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_instances(op_ptrs);
             add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_mem_intra_instances(op_ptrs);
...
@@ -364,8 +353,6 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
                      is_same_v<BComputeType, ck::bhalf_t>)
         {
             add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instances(op_ptrs);
-            add_device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instances(op_ptrs);
             add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instances(op_ptrs);
             add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instances(op_ptrs);
...
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl_merged_groups.inc (deleted, mode 100644 → 0)

// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

// grouped conv2d forward, NHWGC/GKYXC/NHWGK
#ifdef CK_ENABLE_BF16
void add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instances(std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2, NHWGC, GKYXC, Empty_Tuple, NHWGK, BF16, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f16_instances(std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F16, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_instances(std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2, NHWGC, GKYXC, Empty_Tuple, NHWGK, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough>>>& instances);
#endif

#ifdef CK_ENABLE_BF16
// grouped conv3d forward, NDHWGC/GKZYXC/NDHWGK
void add_device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instances(std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, BF16, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP16
void add_device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f16_instances(std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, F16, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP32
void add_device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_instances(std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3, NDHWGC, GKZYXC, Empty_Tuple, NDHWGK, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough>>>& instances);
#endif

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
library/src/tensor_operation_instance/gpu/gemm_ab_scale/CMakeLists.txt (new file, mode 100644)

# ONLY XDL_KERNELS
set(GEMM_AB_SCALE_INSTANCES)

list(APPEND GEMM_AB_SCALE_INSTANCES
    device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_default_instance.cpp
    device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_kpadding_instance.cpp
    device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_mnpadding_instance.cpp
    device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_mnkpadding_instance.cpp
    device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_default_instance.cpp
    device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_kpadding_instance.cpp
    device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_mnkpadding_instance.cpp
)

add_instance_library(device_gemm_ab_scale_instance ${GEMM_AB_SCALE_INSTANCES})
library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128.hpp (new file, mode 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_ab_scale.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

using F8   = f8_t;
using BF16 = bhalf_t;
using F32  = float;

using Row = tensor_layout::gemm::RowMajor;
using Col = tensor_layout::gemm::ColumnMajor;

template <index_t... Is>
using S = Sequence<Is...>;

using PassThrough = element_wise::PassThrough;

static constexpr auto GemmDefault    = GemmSpecialization::Default;
static constexpr auto GemmKPadding   = GemmSpecialization::KPadding;
static constexpr auto GemmMNPadding  = GemmSpecialization::MNPadding;
static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding;

static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;

template <GemmSpecialization GemmSpec>
using device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_instances = std::tuple<
    // clang-format off
    //################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| Scale| Scale| Scale| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
    //################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
    //################################| | | | | | | | | | | Operation| Operation| Operation| | | M| N| K| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Version|
    // Compute friendly
    // Spill in current compiler
    // DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 224, 256, 128, 16, 16, 16, 16, 7, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
    // DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3< Row, Col, Tuple<>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 256, 224, 128, 16, 16, 16, 16, 8, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
    DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3<Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 128, 128, 128, 128, 16, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
    DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3<Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 128, 128, 64, 128, 16, 16, 32, 32, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
    DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3<Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 128, 64, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
    DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3<Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 128, 64, 64, 128, 16, 16, 32, 32, 1, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>
    // clang-format on
    >;

template <BlockGemmPipelineScheduler BlkGemmPipeSched, GemmSpecialization GemmSpec>
using device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_instances = std::tuple<
    // clang-format off
    //################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| Scale| Scale| Scale| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
    //################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
    //################################| | | | | | | | | | | Operation| Operation| Operation| | | M| N| K| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Version|
    // Latency friendly
    DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3<Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 128, 32, 16, 128, 16, 16, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<2, 2, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>,
    DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3<Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 128, 128, 128, 16, 16, 128, 16, 16, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, S<4, 4, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>,
    DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3<Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 128, 16, 32, 128, 16, 16, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<4, 4, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>,
    // Memory friendly
    DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3<Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 128, 128, 32, 128, 16, 16, 32, 32, 2, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<4, 4, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>,
    DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3<Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 128, 128, 16, 128, 16, 16, 16, 16, 4, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<2, 2, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>,
    DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3<Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 128, 64, 32, 128, 16, 16, 32, 32, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<4, 4, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>,
    DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3<Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 128, 64, 16, 128, 16, 16, 16, 16, 2, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<2, 2, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>,
    DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3<Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 128, 32, 16, 128, 16, 16, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<2, 2, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>,
    DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3<Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 128, 128, 128, 16, 16, 64, 16, 16, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, S<4, 4, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>,
    DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3<Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 64, 128, 128, 128, 16, 16, 128, 16, 16, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, S<4, 4, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>,
    DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3<Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 128, 16, 32, 128, 16, 16, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<4, 4, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>,
    DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3<Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 128, 16, 64, 128, 16, 16, 16, 16, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<4, 4, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>,
    DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3<Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 128, 32, 64, 128, 16, 16, 32, 32, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>,
    DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3<Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 128, 16, 128, 128, 16, 16, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<4, 4, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>,
    DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3<Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 128, 128, 128, 32, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>
    // clang-format on
    >;

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
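Each alias above is a std::tuple of fully specialized kernel types; the registration translation units that follow instantiate one tuple per GemmSpecialization (and, for the memory-friendly list, per pipeline scheduler) and hand it to add_device_operation_instances, which appends one heap-allocated object per tuple element to the caller's vector. The snippet below is a simplified illustration of that helper's idea, not the actual CK implementation; the name carries a "_sketch" suffix to make that explicit.

// Illustrative only: a minimal stand-in for add_device_operation_instances (C++17 fold expression).
#include <memory>
#include <tuple>
#include <vector>

template <typename BaseOp, typename... Instances>
void add_device_operation_instances_sketch(std::vector<std::unique_ptr<BaseOp>>& ops,
                                           const std::tuple<Instances...>&)
{
    // One new object per tuple element; each concrete instance derives from the BaseOp interface.
    (ops.push_back(std::make_unique<Instances>()), ...);
}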
library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_default_instance.cpp (new file, mode 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#include "device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_default_instances(
    std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, 128, 128, 128, PassThrough, PassThrough, PassThrough>>>& instances)
{
    add_device_operation_instances(
        instances,
        device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_instances<GemmDefault>{});
}

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
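A caller that wants exactly these instances can invoke the registration function declared above directly instead of going through the factory header; a minimal sketch, reusing the type aliases from the instance header (Row, Col, F8, F32, BF16, PassThrough) and wrapped in a hypothetical helper function for illustration:

// Sketch: collect the f8/f8 -> bf16 AB-scale GEMM instances registered by this translation unit.
void collect_ab_scale_comp_default_instances_sketch()   // hypothetical name, for illustration only
{
    using DeviceOp = ck::tensor_operation::device::DeviceGemmMultipleD_ABScale<
        Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, 128, 128, 128,
        PassThrough, PassThrough, PassThrough>;

    std::vector<std::unique_ptr<DeviceOp>> instances;
    ck::tensor_operation::device::instance::
        add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_default_instances(instances);
    // `instances` now holds one object per entry of the GemmDefault "comp" tuple.
}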
library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_kpadding_instance.cpp (new file, mode 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#include "device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_kpadding_instances(
    std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, 128, 128, 128, PassThrough, PassThrough, PassThrough>>>& instances)
{
    add_device_operation_instances(
        instances,
        device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_instances<GemmKPadding>{});
}

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_mnkpadding_instance.cpp (new file, mode 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#include "device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_mnkpadding_instances(
    std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, 128, 128, 128, PassThrough, PassThrough, PassThrough>>>& instances)
{
    add_device_operation_instances(
        instances,
        device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_instances<GemmMNKPadding>{});
}

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_mnpadding_instance.cpp (new file, mode 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#include "device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_mnpadding_instances(
    std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, 128, 128, 128, PassThrough, PassThrough, PassThrough>>>& instances)
{
    add_device_operation_instances(
        instances,
        device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_instances<GemmMNPadding>{});
}

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_default_instance.cpp (new file, mode 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#include "device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_default_instances(
    std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, 128, 128, 128, PassThrough, PassThrough, PassThrough>>>& instances)
{
    add_device_operation_instances(
        instances,
        device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_instances<Intrawave, GemmDefault>{});
}

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_kpadding_instance.cpp (new file, mode 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#include "device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_kpadding_instances(
    std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, 128, 128, 128, PassThrough, PassThrough, PassThrough>>>& instances)
{
    add_device_operation_instances(
        instances,
        device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_instances<Intrawave, GemmKPadding>{});
}

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_mnkpadding_instance.cpp (new file, mode 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#include "device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_mnkpadding_instances(
    std::vector<std::unique_ptr<DeviceGemmMultipleD_ABScale<Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, 128, 128, 128, PassThrough, PassThrough, PassThrough>>>& instances)
{
    add_device_operation_instances(
        instances,
        device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_instances<Intrawave, GemmMNKPadding>{});
}

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/CMakeLists.txt (new file, mode 100644)

# ONLY XDL_KERNELS
set(GEMM_MULTIPLY_MULTIPLY_INSTANCES)

list(APPEND GEMM_MULTIPLY_MULTIPLY_INSTANCES
    device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_default_instance.cpp
    device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_kpadding_instance.cpp
    device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mnpadding_instance.cpp
    device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_mnkpadding_instance.cpp
    device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_default_instance.cpp
    device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_kpadding_instance.cpp
    device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_mnkpadding_instance.cpp
    device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_default_instance.cpp
    device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_kpadding_instance.cpp
    device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_mnkpadding_instance.cpp
)

add_instance_library(device_gemm_multiply_multiply_instance ${GEMM_MULTIPLY_MULTIPLY_INSTANCES})
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn.hpp
0 → 100644
View file @
b9eb4de3
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
using
F8
=
f8_t
;
using
BF16
=
bhalf_t
;
using
F32
=
float
;
using
Row
=
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
tensor_layout
::
gemm
::
ColumnMajor
;
template
<
index_t
...
Is
>
using
S
=
Sequence
<
Is
...
>
;
using
PassThrough
=
element_wise
::
PassThrough
;
using
MultiplyMultiply
=
element_wise
::
MultiplyMultiply
;
static
constexpr
auto
GemmDefault
=
GemmSpecialization
::
Default
;
static
constexpr
auto
GemmKPadding
=
GemmSpecialization
::
KPadding
;
static
constexpr
auto
GemmMNPadding
=
GemmSpecialization
::
MNPadding
;
static
constexpr
auto
GemmMNKPadding
=
GemmSpecialization
::
MNKPadding
;
static
constexpr
auto
Intrawave
=
BlockGemmPipelineScheduler
::
Intrawave
;
static
constexpr
auto
Interwave
=
BlockGemmPipelineScheduler
::
Interwave
;
template
<
GemmSpecialization
GemmSpec
>
using
device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_instances
=
std
::
tuple
<
// clang-format off
//################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
//################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
//################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
        // Compute friendly
        DeviceGemmMultiD_Xdl_CShuffle_V3<Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 256, 64, 16, 16, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4, F8>,
        DeviceGemmMultiD_Xdl_CShuffle_V3<Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 128, 16, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4, F8>,
        DeviceGemmMultiD_Xdl_CShuffle_V3<Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 64, 16, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v4, F8>,
        DeviceGemmMultiD_Xdl_CShuffle_V3<Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 256, 64, 16, 16, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
        DeviceGemmMultiD_Xdl_CShuffle_V3<Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 256, 64, 16, 16, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5, F8>,
        DeviceGemmMultiD_Xdl_CShuffle_V3<Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 256, 64, 16, 16, 16, 16, 8, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
        DeviceGemmMultiD_Xdl_CShuffle_V3<Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 224, 256, 128, 16, 16, 16, 16, 7, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
        DeviceGemmMultiD_Xdl_CShuffle_V3<Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 224, 128, 16, 16, 16, 16, 8, 7, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 1, S<1, 64, 1, 4>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
        DeviceGemmMultiD_Xdl_CShuffle_V3<Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 128, 16, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
        DeviceGemmMultiD_Xdl_CShuffle_V3<Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 128, 16, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5, F8>,
        DeviceGemmMultiD_Xdl_CShuffle_V3<Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 256, 64, 16, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1, F8>,
        DeviceGemmMultiD_Xdl_CShuffle_V3<Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1, F8>,
        DeviceGemmMultiD_Xdl_CShuffle_V3<Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 128, 16, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1, F8>,
        DeviceGemmMultiD_Xdl_CShuffle_V3<Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 64, 128, 16, 16, 32, 32, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
        DeviceGemmMultiD_Xdl_CShuffle_V3<Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 64, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
        DeviceGemmMultiD_Xdl_CShuffle_V3<Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 64, 64, 128, 16, 16, 32, 32, 1, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>
    // clang-format on
    >;
template <BlockGemmPipelineScheduler BlkGemmPipeSched, GemmSpecialization GemmSpec>
using device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_instances = std::tuple<
// clang-format off
//################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
//################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
//################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Version|
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
        // Latency friendly
        DeviceGemmMultiD_Xdl_CShuffle_V3<Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 32, 16, 128, 16, 16, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<2, 2, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>,
        DeviceGemmMultiD_Xdl_CShuffle_V3<Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 64, 16, 16, 128, 16, 16, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, S<4, 4, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>,
        DeviceGemmMultiD_Xdl_CShuffle_V3<Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 16, 32, 128, 16, 16, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<4, 4, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>,
        // Memory friendly
        DeviceGemmMultiD_Xdl_CShuffle_V3<Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 32, 128, 16, 16, 32, 32, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<4, 4, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>,
        DeviceGemmMultiD_Xdl_CShuffle_V3<Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 256, 16, 128, 16, 16, 16, 16, 4, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<2, 2, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>,
        DeviceGemmMultiD_Xdl_CShuffle_V3<Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 128, 32, 128, 16, 16, 32, 32, 2, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<4, 4, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>,
        DeviceGemmMultiD_Xdl_CShuffle_V3<Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 128, 16, 128, 16, 16, 16, 16, 4, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<2, 2, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>,
        DeviceGemmMultiD_Xdl_CShuffle_V3<Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 64, 32, 128, 16, 16, 32, 32, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<4, 4, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>,
        DeviceGemmMultiD_Xdl_CShuffle_V3<Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 64, 16, 128, 16, 16, 16, 16, 2, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<2, 2, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>,
        DeviceGemmMultiD_Xdl_CShuffle_V3<Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 32, 16, 128, 16, 16, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<2, 2, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>,
        DeviceGemmMultiD_Xdl_CShuffle_V3<Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 64, 16, 16, 64, 16, 16, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, S<4, 4, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>,
        DeviceGemmMultiD_Xdl_CShuffle_V3<Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 64, 16, 16, 128, 16, 16, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, S<4, 4, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>,
        DeviceGemmMultiD_Xdl_CShuffle_V3<Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 16, 32, 128, 16, 16, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<4, 4, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>,
        DeviceGemmMultiD_Xdl_CShuffle_V3<Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 16, 64, 128, 16, 16, 16, 16, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<4, 4, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>,
        DeviceGemmMultiD_Xdl_CShuffle_V3<Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 32, 64, 128, 16, 16, 32, 32, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>,
        DeviceGemmMultiD_Xdl_CShuffle_V3<Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 16, 128, 128, 16, 16, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<4, 4, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>,
        DeviceGemmMultiD_Xdl_CShuffle_V3<Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 128, 32, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>,
        DeviceGemmMultiD_Xdl_CShuffle_V3<Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 16, 256, 128, 16, 16, 16, 16, 1, 4, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<4, 4, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>,
        DeviceGemmMultiD_Xdl_CShuffle_V3<Row, Col, Tuple<Row, Col>, Row, F8, F8, Tuple<F32, F32>, BF16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 32, 256, 128, 16, 16, 32, 32, 1, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8, 8, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v2, F8>
    // clang-format on
    >;
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
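The tuples above are only type lists; a companion instance translation unit normally materializes them into type-erased device operations. The sketch below illustrates that registration pattern under stated assumptions: it presumes the usual add_device_operation_instances helper is visible along with this header, and the function name, the DeviceOp template parameter, and the chosen scheduler/specialization values are illustrative rather than the exact interface added by this commit.

// Sketch only: how a *_instance.cpp file typically consumes one of the tuples
// defined above. Names marked below are illustrative, not from this commit.
#include <memory>
#include <vector>

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

// DeviceOp stands in for the multiply-multiply GEMM base interface expected by
// the operation factory; the real alias lives in the factory headers.
template <typename DeviceOp>
void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v2_sketch(
    std::vector<std::unique_ptr<DeviceOp>>& instances)
{
    // Bind the template parameters to a concrete scheduler and GEMM
    // specialization, then append one object per tuple element.
    add_device_operation_instances(
        instances,
        device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_instances<
            BlockGemmPipelineScheduler::Interwave,
            GemmSpecialization::MNKPadding>{});
}

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck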