Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
f20e48f1
Commit
f20e48f1
authored
Nov 05, 2024
by
aska-0096
Browse files
Merge branch 'develop' of
https://github.com/ROCm/composable_kernel
into update_cka8w8
parents
b97c6876
0c9012fb
Changes
361
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
671 additions
and
13 deletions
+671
-13
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_mem_intra_xdl.inc
...nstance/gpu/grouped_convolution_forward_mem_intra_xdl.inc
+32
-0
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl.inc
...peration_instance/gpu/grouped_convolution_forward_xdl.inc
+32
-0
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl_large_tensor.inc
...ance/gpu/grouped_convolution_forward_xdl_large_tensor.inc
+16
-0
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl_merged_groups.inc
...nce/gpu/grouped_convolution_forward_xdl_merged_groups.inc
+30
-0
library/include/ck/library/utility/check_err.hpp
library/include/ck/library/utility/check_err.hpp
+127
-0
library/src/tensor_operation_instance/gpu/CMakeLists.txt
library/src/tensor_operation_instance/gpu/CMakeLists.txt
+27
-3
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn.hpp
...device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn.hpp
+6
-4
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_i8_i8_bf16/device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn.hpp
...device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn.hpp
+99
-0
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_i8_i8_bf16/device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_comp_default_instance.cpp
...ultiply_xdl_i8_i8_bf16_mk_nk_mn_comp_default_instance.cpp
+32
-0
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_i8_i8_bf16/device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_comp_kpadding_instance.cpp
...ltiply_xdl_i8_i8_bf16_mk_nk_mn_comp_kpadding_instance.cpp
+32
-0
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_i8_i8_bf16/device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_mem_v1_default_instance.cpp
...tiply_xdl_i8_i8_bf16_mk_nk_mn_mem_v1_default_instance.cpp
+33
-0
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_i8_i8_bf16/device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_mem_v1_kpadding_instance.cpp
...iply_xdl_i8_i8_bf16_mk_nk_mn_mem_v1_kpadding_instance.cpp
+33
-0
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_i8_i8_bf16/device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_mem_v2_default_instance.cpp
...tiply_xdl_i8_i8_bf16_mk_nk_mn_mem_v2_default_instance.cpp
+33
-0
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_i8_i8_bf16/device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_mem_v2_kpadding_instance.cpp
...iply_xdl_i8_i8_bf16_mk_nk_mn_mem_v2_kpadding_instance.cpp
+33
-0
library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn.hpp
...f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn.hpp
+3
-2
library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn.hpp
...f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn.hpp
+6
-4
library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/CMakeLists.txt
...ion_instance/gpu/grouped_conv2d_bwd_weight/CMakeLists.txt
+1
-0
library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp
...conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp
+46
-0
library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt
..._operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt
+11
-0
library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_int8_comp_instance.cpp
...d_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_int8_comp_instance.cpp
+39
-0
No files found.
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_mem_intra_xdl.inc
View file @
f20e48f1
...
...
@@ -57,6 +57,22 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_intra_instances
PassThrough
>>>&
instances
);
#endif
#ifdef CK_ENABLE_INT8
void
add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_intra_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleABD
<
2
,
NHWGC
,
GKYXC
,
Empty_Tuple
,
NHWGK
,
int8_t
,
int8_t
,
Empty_Tuple
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
#endif
// grouped conv2d forward, NGCHW/GKYXC/NGKHW
#ifdef CK_ENABLE_FP16
void
add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_mem_intra_instances
(
...
...
@@ -90,6 +106,22 @@ void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_mem_intra_instances
PassThrough
>>>&
instances
);
#endif
#ifdef CK_ENABLE_INT8
void
add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_int8_mem_intra_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleABD
<
2
,
NGCHW
,
GKYXC
,
Empty_Tuple
,
NGKHW
,
int8_t
,
int8_t
,
Empty_Tuple
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
#endif
#ifdef CK_ENABLE_BF16
// grouped conv3d forward, NDHWGC/GKZYXC/NDHWGK
void
add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instances
(
...
...
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl.inc
View file @
f20e48f1
...
...
@@ -171,6 +171,22 @@ void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instances(
PassThrough
>>>&
instances
);
#endif
#ifdef CK_ENABLE_INT8
void
add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleABD
<
2
,
NHWGC
,
GKYXC
,
Empty_Tuple
,
NHWGK
,
int8_t
,
int8_t
,
Empty_Tuple
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
#endif
// grouped conv2d forward, NGCHW/GKYXC/NGKHW
#ifdef CK_ENABLE_FP16
void
add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_instances
(
...
...
@@ -204,6 +220,22 @@ void add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_instances(
PassThrough
>>>&
instances
);
#endif
#ifdef CK_ENABLE_INT8
void
add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_int8_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleABD
<
2
,
NGCHW
,
GKYXC
,
Empty_Tuple
,
NGKHW
,
int8_t
,
int8_t
,
Empty_Tuple
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
#endif
#ifdef CK_ENABLE_BF16
// grouped conv3d forward, GNDHWC/GKZYXC/GNDHWK
void
add_device_grouped_conv3d_fwd_xdl_gndhwc_gkzyxc_gndhwk_bf16_instances
(
...
...
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl_large_tensor.inc
View file @
f20e48f1
...
...
@@ -57,6 +57,22 @@ void add_device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_instan
PassThrough
>>>&
instances
);
#endif
#ifdef CK_ENABLE_INT8
void
add_device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_int8_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleABD
<
2
,
NHWGC
,
GKYXC
,
Empty_Tuple
,
NHWGK
,
int8_t
,
int8_t
,
Empty_Tuple
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
#endif
#ifdef CK_ENABLE_BF16
// grouped conv3d forward, NDHWGC/GKZYXC/NDHWGK
void
add_device_grouped_conv3d_fwd_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instances
(
...
...
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl_merged_groups.inc
View file @
f20e48f1
...
...
@@ -85,6 +85,36 @@ void add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_f32_insta
PassThrough
>>>&
instances
);
#endif
#ifdef CK_ENABLE_INT8
void
add_device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_int8_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleABD
<
2
,
NHWGC
,
GKYXC
,
Empty_Tuple
,
NHWGK
,
int8_t
,
int8_t
,
Empty_Tuple
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
void
add_device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_int8_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleABD
<
2
,
NGCHW
,
GKYXC
,
Empty_Tuple
,
NGKHW
,
int8_t
,
int8_t
,
Empty_Tuple
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
#endif
#ifdef CK_ENABLE_BF16
// grouped conv3d forward, NDHWGC/GKZYXC/NDHWGK
void
add_device_grouped_conv3d_fwd_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instances
(
...
...
library/include/ck/library/utility/check_err.hpp
View file @
f20e48f1
...
...
@@ -23,6 +23,130 @@
namespace
ck
{
namespace
utils
{
template
<
typename
ComputeDataType
,
typename
OutDataType
,
typename
AccDataType
=
ComputeDataType
>
double
get_relative_threshold
(
const
int
numberOfAccumulations
=
1
)
{
using
F8
=
ck
::
f8_t
;
using
F16
=
ck
::
half_t
;
using
BF16
=
ck
::
bhalf_t
;
using
F32
=
float
;
using
I8
=
int8_t
;
using
I32
=
int32_t
;
static_assert
(
is_same_v
<
ComputeDataType
,
F8
>
||
is_same_v
<
ComputeDataType
,
F16
>
||
is_same_v
<
ComputeDataType
,
BF16
>
||
is_same_v
<
ComputeDataType
,
F32
>
||
is_same_v
<
ComputeDataType
,
I8
>
||
is_same_v
<
ComputeDataType
,
I32
>
||
is_same_v
<
ComputeDataType
,
int
>
,
"Warning: Unhandled ComputeDataType for setting up the relative threshold!"
);
double
compute_error
=
0
;
if
constexpr
(
is_same_v
<
ComputeDataType
,
I8
>
||
is_same_v
<
ComputeDataType
,
I32
>
||
is_same_v
<
ComputeDataType
,
int
>
)
{
return
0
;
}
else
{
compute_error
=
std
::
pow
(
2
,
-
NumericUtils
<
ComputeDataType
>::
mant
)
*
0.5
;
}
static_assert
(
is_same_v
<
OutDataType
,
F8
>
||
is_same_v
<
OutDataType
,
F16
>
||
is_same_v
<
OutDataType
,
BF16
>
||
is_same_v
<
OutDataType
,
F32
>
||
is_same_v
<
OutDataType
,
I8
>
||
is_same_v
<
OutDataType
,
I32
>
||
is_same_v
<
OutDataType
,
int
>
,
"Warning: Unhandled OutDataType for setting up the relative threshold!"
);
double
output_error
=
0
;
if
constexpr
(
is_same_v
<
OutDataType
,
I8
>
||
is_same_v
<
OutDataType
,
I32
>
||
is_same_v
<
OutDataType
,
int
>
)
{
return
0
;
}
else
{
output_error
=
std
::
pow
(
2
,
-
NumericUtils
<
OutDataType
>::
mant
)
*
0.5
;
}
double
midway_error
=
std
::
max
(
compute_error
,
output_error
);
static_assert
(
is_same_v
<
AccDataType
,
F8
>
||
is_same_v
<
AccDataType
,
F16
>
||
is_same_v
<
AccDataType
,
BF16
>
||
is_same_v
<
AccDataType
,
F32
>
||
is_same_v
<
AccDataType
,
I8
>
||
is_same_v
<
AccDataType
,
I32
>
||
is_same_v
<
AccDataType
,
int
>
,
"Warning: Unhandled AccDataType for setting up the relative threshold!"
);
double
acc_error
=
0
;
if
constexpr
(
is_same_v
<
AccDataType
,
I8
>
||
is_same_v
<
AccDataType
,
I32
>
||
is_same_v
<
AccDataType
,
int
>
)
{
return
0
;
}
else
{
acc_error
=
std
::
pow
(
2
,
-
NumericUtils
<
AccDataType
>::
mant
)
*
0.5
*
numberOfAccumulations
;
}
return
std
::
max
(
acc_error
,
midway_error
);
}
template
<
typename
ComputeDataType
,
typename
OutDataType
,
typename
AccDataType
=
ComputeDataType
>
double
get_absolute_threshold
(
const
double
max_possible_num
,
const
int
numberOfAccumulations
=
1
)
{
using
F8
=
ck
::
f8_t
;
using
F16
=
ck
::
half_t
;
using
BF16
=
ck
::
bhalf_t
;
using
F32
=
float
;
using
I8
=
int8_t
;
using
I32
=
int32_t
;
static_assert
(
is_same_v
<
ComputeDataType
,
F8
>
||
is_same_v
<
ComputeDataType
,
F16
>
||
is_same_v
<
ComputeDataType
,
BF16
>
||
is_same_v
<
ComputeDataType
,
F32
>
||
is_same_v
<
ComputeDataType
,
I8
>
||
is_same_v
<
ComputeDataType
,
I32
>
||
is_same_v
<
ComputeDataType
,
int
>
,
"Warning: Unhandled ComputeDataType for setting up the absolute threshold!"
);
auto
expo
=
std
::
log2
(
std
::
abs
(
max_possible_num
));
double
compute_error
=
0
;
if
constexpr
(
is_same_v
<
ComputeDataType
,
I8
>
||
is_same_v
<
ComputeDataType
,
I32
>
||
is_same_v
<
ComputeDataType
,
int
>
)
{
return
0
;
}
else
{
compute_error
=
std
::
pow
(
2
,
expo
-
NumericUtils
<
ComputeDataType
>::
mant
)
*
0.5
;
}
static_assert
(
is_same_v
<
OutDataType
,
F8
>
||
is_same_v
<
OutDataType
,
F16
>
||
is_same_v
<
OutDataType
,
BF16
>
||
is_same_v
<
OutDataType
,
F32
>
||
is_same_v
<
OutDataType
,
I8
>
||
is_same_v
<
OutDataType
,
I32
>
||
is_same_v
<
OutDataType
,
int
>
,
"Warning: Unhandled OutDataType for setting up the absolute threshold!"
);
double
output_error
=
0
;
if
constexpr
(
is_same_v
<
OutDataType
,
I8
>
||
is_same_v
<
OutDataType
,
I32
>
||
is_same_v
<
OutDataType
,
int
>
)
{
return
0
;
}
else
{
output_error
=
std
::
pow
(
2
,
expo
-
NumericUtils
<
OutDataType
>::
mant
)
*
0.5
;
}
double
midway_error
=
std
::
max
(
compute_error
,
output_error
);
static_assert
(
is_same_v
<
AccDataType
,
F8
>
||
is_same_v
<
AccDataType
,
F16
>
||
is_same_v
<
AccDataType
,
BF16
>
||
is_same_v
<
AccDataType
,
F32
>
||
is_same_v
<
AccDataType
,
I8
>
||
is_same_v
<
AccDataType
,
I32
>
||
is_same_v
<
AccDataType
,
int
>
,
"Warning: Unhandled AccDataType for setting up the absolute threshold!"
);
double
acc_error
=
0
;
if
constexpr
(
is_same_v
<
AccDataType
,
I8
>
||
is_same_v
<
AccDataType
,
I32
>
||
is_same_v
<
AccDataType
,
int
>
)
{
return
0
;
}
else
{
acc_error
=
std
::
pow
(
2
,
expo
-
NumericUtils
<
AccDataType
>::
mant
)
*
0.5
*
numberOfAccumulations
;
}
return
std
::
max
(
acc_error
,
midway_error
);
}
template
<
typename
Range
,
typename
RefRange
>
typename
std
::
enable_if
<
std
::
is_same_v
<
ranges
::
range_value_t
<
Range
>
,
ranges
::
range_value_t
<
RefRange
>>
&&
...
...
@@ -253,11 +377,13 @@ check_err(const Range& out,
int
err_count
=
0
;
double
err
=
0
;
double
max_err
=
std
::
numeric_limits
<
float
>::
min
();
for
(
std
::
size_t
i
=
0
;
i
<
ref
.
size
();
++
i
)
{
const
double
o
=
type_convert
<
float
>
(
*
std
::
next
(
std
::
begin
(
out
),
i
));
const
double
r
=
type_convert
<
float
>
(
*
std
::
next
(
std
::
begin
(
ref
),
i
));
err
=
std
::
abs
(
o
-
r
);
if
(
err
>
atol
+
rtol
*
std
::
abs
(
r
)
||
!
std
::
isfinite
(
o
)
||
!
std
::
isfinite
(
r
))
{
max_err
=
err
>
max_err
?
err
:
max_err
;
...
...
@@ -270,6 +396,7 @@ check_err(const Range& out,
res
=
false
;
}
}
if
(
!
res
)
{
std
::
cerr
<<
std
::
setw
(
12
)
<<
std
::
setprecision
(
7
)
<<
"max err: "
<<
max_err
...
...
library/src/tensor_operation_instance/gpu/CMakeLists.txt
View file @
f20e48f1
...
...
@@ -67,6 +67,21 @@ function(add_instance_library INSTANCE_NAME)
list
(
REMOVE_ITEM ARGN
"
${
source
}
"
)
endif
()
endforeach
()
# Do not build gemm_universal_f8 or gemm_multiply_multiply_f8 for any targets except gfx94
if
(
NOT CK_USE_FP8_ON_UNSUPPORTED_ARCH
)
foreach
(
source IN LISTS ARGN
)
if
(
NOT INST_TARGETS MATCHES
"gfx94"
AND source MATCHES
"gemm_multiply_multiply_xdl_f8"
)
message
(
"removing gemm_multiply_multiply_f8 instance
${
source
}
"
)
list
(
REMOVE_ITEM ARGN
"
${
source
}
"
)
endif
()
endforeach
()
foreach
(
source IN LISTS ARGN
)
if
(
NOT INST_TARGETS MATCHES
"gfx94"
AND source MATCHES
"gemm_xdl_universal"
AND source MATCHES
"_f8_"
)
message
(
"removing gemm_universal_f8 instance
${
source
}
"
)
list
(
REMOVE_ITEM ARGN
"
${
source
}
"
)
endif
()
endforeach
()
endif
()
#only continue if there are some source files left on the list
if
(
ARGN
)
set
(
INST_OBJ
)
...
...
@@ -74,11 +89,20 @@ function(add_instance_library INSTANCE_NAME)
set
(
INST_TARGETS
${
SUPPORTED_GPU_TARGETS
}
)
if
(
source MATCHES
"_xdl"
)
list
(
REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201
)
elseif
(
ARGN
MATCHES
"_wmma"
)
elseif
(
source
MATCHES
"_wmma"
)
list
(
REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030
)
elseif
(
ARGN
MATCHES
"mha"
)
elseif
(
source
MATCHES
"mha"
)
list
(
REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx908 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201
)
endif
()
#only build the fp8 gemm instances for gfx908/90a if the build argument is set
if
(
NOT CK_USE_FP8_ON_UNSUPPORTED_ARCH
)
if
(
source MATCHES
"gemm_xdl_universal"
AND source MATCHES
"f8"
)
list
(
REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx908 gfx90a gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201
)
endif
()
if
(
source MATCHES
"gemm_multiply_multiply_f8"
)
list
(
REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx908 gfx90a gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201
)
endif
()
endif
()
set
(
offload_targets
)
foreach
(
target IN LISTS INST_TARGETS
)
string
(
APPEND offload_targets
"--offload-arch=
${
target
}
"
)
...
...
@@ -108,7 +132,7 @@ function(add_instance_library INSTANCE_NAME)
# flags to compress the library
if
(
NOT WIN32 AND
${
hip_VERSION_FLAT
}
GREATER 600241132
)
message
(
"Adding --offload-compress flag for
${
INSTANCE_NAME
}
"
)
#
message("Adding --offload-compress flag for ${INSTANCE_NAME}")
target_compile_options
(
${
INSTANCE_NAME
}
PRIVATE --offload-compress
)
endif
()
...
...
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn.hpp
View file @
f20e48f1
...
...
@@ -36,12 +36,12 @@ static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
template
<
GemmSpecialization
GemmSpec
>
using
device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_instances
=
std
::
tuple
<
// clang-format off
// clang-format off
//################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
//################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
//################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
#if defined(__gfx94__) || defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH)
// Compute friendly
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
F8
,
F8
,
Tuple
<
F32
,
F32
>
,
BF16
,
F32
,
F32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
256
,
256
,
256
,
64
,
16
,
16
,
32
,
32
,
4
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
,
8
,
1
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v4
,
F8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
F8
,
F8
,
Tuple
<
F32
,
F32
>
,
BF16
,
F32
,
F32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
256
,
128
,
128
,
128
,
16
,
16
,
32
,
32
,
2
,
2
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
,
8
,
1
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v4
,
F8
>
,
...
...
@@ -58,17 +58,18 @@ using device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_instances = std
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
F8
,
F8
,
Tuple
<
F32
,
F32
>
,
BF16
,
F32
,
F32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
256
,
128
,
64
,
128
,
16
,
16
,
32
,
32
,
2
,
1
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
,
8
,
1
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
,
F8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
F8
,
F8
,
Tuple
<
F32
,
F32
>
,
BF16
,
F32
,
F32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
256
,
64
,
128
,
128
,
16
,
16
,
32
,
32
,
1
,
2
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
,
8
,
1
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
,
F8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
F8
,
F8
,
Tuple
<
F32
,
F32
>
,
BF16
,
F32
,
F32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
256
,
64
,
64
,
128
,
16
,
16
,
32
,
32
,
1
,
1
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
,
8
,
1
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
,
F8
>
#endif
// clang-format on
>
;
template
<
BlockGemmPipelineScheduler
BlkGemmPipeSched
,
GemmSpecialization
GemmSpec
>
using
device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_instances
=
std
::
tuple
<
// clang-format off
// clang-format off
//################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
//################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
//################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
#if defined(__gfx94__) || defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH)
// Latency friendly
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
F8
,
F8
,
Tuple
<
F32
,
F32
>
,
BF16
,
F32
,
F32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
128
,
32
,
16
,
128
,
16
,
16
,
16
,
16
,
1
,
1
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
S
<
2
,
2
,
1
>
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v1
,
F8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
F8
,
F8
,
Tuple
<
F32
,
F32
>
,
BF16
,
F32
,
F32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
64
,
16
,
16
,
128
,
16
,
16
,
16
,
16
,
1
,
1
,
S
<
8
,
8
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
8
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
S
<
4
,
4
,
1
>
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v1
,
F8
>
,
...
...
@@ -90,6 +91,7 @@ using device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_instances = std:
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
F8
,
F8
,
Tuple
<
F32
,
F32
>
,
BF16
,
F32
,
F32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
128
,
32
,
128
,
128
,
16
,
16
,
32
,
32
,
1
,
2
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
S
<
8
,
8
,
1
>
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
,
F8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
F8
,
F8
,
Tuple
<
F32
,
F32
>
,
BF16
,
F32
,
F32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
256
,
16
,
256
,
128
,
16
,
16
,
16
,
16
,
1
,
4
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
S
<
4
,
4
,
1
>
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
,
F8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
F8
,
F8
,
Tuple
<
F32
,
F32
>
,
BF16
,
F32
,
F32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
256
,
32
,
256
,
128
,
16
,
16
,
32
,
32
,
1
,
2
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
S
<
8
,
8
,
1
>
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
,
F8
>
#endif
// clang-format on
>
;
}
// namespace instance
...
...
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_i8_i8_bf16/device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn.hpp
0 → 100644
View file @
f20e48f1
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
using
I8
=
int8_t
;
using
I32
=
int
;
using
BF16
=
bhalf_t
;
using
F32
=
float
;
using
Row
=
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
tensor_layout
::
gemm
::
ColumnMajor
;
template
<
index_t
...
Is
>
using
S
=
Sequence
<
Is
...
>
;
using
PassThrough
=
element_wise
::
PassThrough
;
using
MultiplyMultiply
=
element_wise
::
MultiplyMultiply
;
static
constexpr
auto
GemmDefault
=
GemmSpecialization
::
Default
;
static
constexpr
auto
GemmKPadding
=
GemmSpecialization
::
KPadding
;
static
constexpr
auto
GemmMNPadding
=
GemmSpecialization
::
MNPadding
;
static
constexpr
auto
GemmMNKPadding
=
GemmSpecialization
::
MNKPadding
;
static
constexpr
auto
Intrawave
=
BlockGemmPipelineScheduler
::
Intrawave
;
static
constexpr
auto
Interwave
=
BlockGemmPipelineScheduler
::
Interwave
;
template
<
GemmSpecialization
GemmSpec
>
using
device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_comp_instances
=
std
::
tuple
<
// clang-format off
//################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
//################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
//################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// Compute friendly
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
256
,
256
,
256
,
64
,
16
,
16
,
32
,
32
,
4
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
,
8
,
1
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v4
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
256
,
128
,
128
,
128
,
16
,
16
,
32
,
32
,
2
,
2
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
,
8
,
1
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v4
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
256
,
128
,
128
,
64
,
16
,
16
,
32
,
32
,
2
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
,
8
,
1
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v4
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
256
,
256
,
256
,
128
,
16
,
16
,
16
,
16
,
8
,
8
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
,
8
,
1
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
256
,
256
,
256
,
64
,
16
,
16
,
16
,
16
,
8
,
8
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
,
8
,
1
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
256
,
224
,
256
,
128
,
16
,
16
,
16
,
16
,
7
,
8
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
2
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
,
8
,
1
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
256
,
256
,
224
,
128
,
16
,
16
,
16
,
16
,
8
,
7
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
2
,
1
,
S
<
1
,
64
,
1
,
4
>
,
S
<
8
,
8
,
1
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
256
,
128
,
128
,
128
,
16
,
16
,
32
,
32
,
2
,
2
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
,
8
,
1
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
256
,
128
,
128
,
128
,
16
,
16
,
32
,
32
,
2
,
2
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
,
8
,
1
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v5
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
256
,
128
,
256
,
64
,
16
,
16
,
32
,
32
,
2
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
,
8
,
1
>
,
BlockGemmPipelineScheduler
::
Interwave
,
BlockGemmPipelineVersion
::
v1
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
256
,
256
,
128
,
64
,
16
,
16
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
,
8
,
1
>
,
BlockGemmPipelineScheduler
::
Interwave
,
BlockGemmPipelineVersion
::
v1
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
256
,
128
,
128
,
128
,
16
,
16
,
32
,
32
,
2
,
2
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
,
8
,
1
>
,
BlockGemmPipelineScheduler
::
Interwave
,
BlockGemmPipelineVersion
::
v1
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
256
,
128
,
64
,
128
,
16
,
16
,
32
,
32
,
2
,
1
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
,
8
,
1
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
256
,
64
,
128
,
128
,
16
,
16
,
32
,
32
,
1
,
2
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
,
8
,
1
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
256
,
64
,
64
,
128
,
16
,
16
,
32
,
32
,
1
,
1
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
,
8
,
1
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
,
I8
>
// clang-format oI
>
;
template
<
BlockGemmPipelineScheduler
BlkGemmPipeSched
,
GemmSpecialization
GemmSpec
>
using
device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_mem_instances
=
std
::
tuple
<
// clang-format off
//################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
//################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
//################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// Latency friendly
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
128
,
32
,
16
,
128
,
16
,
16
,
16
,
16
,
1
,
1
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
S
<
2
,
2
,
1
>
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v1
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
64
,
16
,
16
,
128
,
16
,
16
,
16
,
16
,
1
,
1
,
S
<
8
,
8
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
8
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
S
<
4
,
4
,
1
>
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v1
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
128
,
16
,
32
,
128
,
16
,
16
,
16
,
16
,
1
,
1
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
S
<
4
,
4
,
1
>
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v1
,
I8
>
,
// Memory friendly
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
256
,
256
,
32
,
128
,
16
,
16
,
32
,
32
,
2
,
1
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
4
,
4
,
1
>
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
256
,
256
,
16
,
128
,
16
,
16
,
16
,
16
,
4
,
1
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
2
,
2
,
1
>
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
128
,
128
,
32
,
128
,
16
,
16
,
32
,
32
,
2
,
1
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
S
<
4
,
4
,
1
>
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
128
,
128
,
16
,
128
,
16
,
16
,
16
,
16
,
4
,
1
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
S
<
2
,
2
,
1
>
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
128
,
64
,
32
,
128
,
16
,
16
,
32
,
32
,
1
,
1
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
S
<
4
,
4
,
1
>
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
128
,
64
,
16
,
128
,
16
,
16
,
16
,
16
,
2
,
1
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
S
<
2
,
2
,
1
>
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
128
,
32
,
16
,
128
,
16
,
16
,
16
,
16
,
1
,
1
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
S
<
2
,
2
,
1
>
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
64
,
16
,
16
,
64
,
16
,
16
,
16
,
16
,
1
,
1
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
S
<
4
,
4
,
1
>
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
64
,
16
,
16
,
128
,
16
,
16
,
16
,
16
,
1
,
1
,
S
<
8
,
8
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
8
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
S
<
4
,
4
,
1
>
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
128
,
16
,
32
,
128
,
16
,
16
,
16
,
16
,
1
,
1
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
S
<
4
,
4
,
1
>
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
128
,
16
,
64
,
128
,
16
,
16
,
16
,
16
,
1
,
2
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
S
<
4
,
4
,
1
>
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
128
,
32
,
64
,
128
,
16
,
16
,
32
,
32
,
1
,
1
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
S
<
8
,
8
,
1
>
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
128
,
16
,
128
,
128
,
16
,
16
,
16
,
16
,
1
,
4
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
S
<
4
,
4
,
1
>
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
128
,
32
,
128
,
128
,
16
,
16
,
32
,
32
,
1
,
2
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
S
<
8
,
8
,
1
>
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
256
,
16
,
256
,
128
,
16
,
16
,
16
,
16
,
1
,
4
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
S
<
4
,
4
,
1
>
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
,
I8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
I32
,
I32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
256
,
32
,
256
,
128
,
16
,
16
,
32
,
32
,
1
,
2
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
S
<
8
,
8
,
1
>
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
,
I8
>
// clang-format oI
>
;
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_i8_i8_bf16/device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_comp_default_instance.cpp
0 → 100644
View file @
f20e48f1
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
void
add_device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_comp_default_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleDSplitK
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
PassThrough
,
PassThrough
,
MultiplyMultiply
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_comp_instances
<
GemmDefault
>
{});
}
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_i8_i8_bf16/device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_comp_kpadding_instance.cpp
0 → 100644
View file @
f20e48f1
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
void
add_device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_comp_kpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleDSplitK
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
PassThrough
,
PassThrough
,
MultiplyMultiply
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_comp_instances
<
GemmKPadding
>
{});
}
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_i8_i8_bf16/device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_mem_v1_default_instance.cpp
0 → 100644
View file @
f20e48f1
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
void
add_device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_mem_v1_default_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleDSplitK
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
PassThrough
,
PassThrough
,
MultiplyMultiply
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_mem_instances
<
Intrawave
,
GemmDefault
>
{});
}
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_i8_i8_bf16/device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_mem_v1_kpadding_instance.cpp
0 → 100644
View file @
f20e48f1
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
void
add_device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_mem_v1_kpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleDSplitK
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
PassThrough
,
PassThrough
,
MultiplyMultiply
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_mem_instances
<
Intrawave
,
GemmKPadding
>
{});
}
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_i8_i8_bf16/device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_mem_v2_default_instance.cpp
0 → 100644
View file @
f20e48f1
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
void
add_device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_mem_v2_default_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleDSplitK
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
PassThrough
,
PassThrough
,
MultiplyMultiply
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_mem_instances
<
Interwave
,
GemmDefault
>
{});
}
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_i8_i8_bf16/device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_mem_v2_kpadding_instance.cpp
0 → 100644
View file @
f20e48f1
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
void
add_device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_mem_v2_kpadding_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleDSplitK
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
I8
,
I8
,
Tuple
<
F32
,
F32
>
,
BF16
,
PassThrough
,
PassThrough
,
MultiplyMultiply
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_gemm_multiply_multiply_xdl_i8_i8_bf16_mk_nk_mn_mem_instances
<
Interwave
,
GemmKPadding
>
{});
}
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn.hpp
View file @
f20e48f1
...
...
@@ -62,12 +62,12 @@ using device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_comp_instances = std::tuple<
template
<
BlockGemmPipelineScheduler
BlkGemmPipeSched
,
GemmSpecialization
GemmSpec
>
using
device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_mem_instances
=
std
::
tuple
<
// clang-format off
// clang-format off
//#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
//#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
//#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
#if defined(__gfx94__) || defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH)
// Latency friendly
DeviceGemm_Xdl_CShuffleV3
<
Row
,
Row
,
Row
,
F8
,
F8
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
128
,
32
,
16
,
128
,
16
,
4
,
16
,
16
,
1
,
1
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
32
,
4
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
4
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
2
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v1
,
F8
>
,
DeviceGemm_Xdl_CShuffleV3
<
Row
,
Row
,
Row
,
F8
,
F8
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
64
,
16
,
16
,
128
,
16
,
4
,
16
,
16
,
1
,
1
,
S
<
8
,
8
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
32
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
8
,
4
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v1
,
F8
>
,
...
...
@@ -90,6 +90,7 @@ using device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_mem_instances = std::tuple<
DeviceGemm_Xdl_CShuffleV3
<
Row
,
Row
,
Row
,
F8
,
F8
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
128
,
16
,
64
,
128
,
16
,
4
,
16
,
16
,
1
,
2
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
32
,
4
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
16
,
4
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
,
F8
>
,
DeviceGemm_Xdl_CShuffleV3
<
Row
,
Row
,
Row
,
F8
,
F8
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
128
,
16
,
128
,
128
,
16
,
8
,
16
,
16
,
1
,
4
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
16
,
8
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
16
,
8
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
,
F8
>
,
DeviceGemm_Xdl_CShuffleV3
<
Row
,
Row
,
Row
,
F8
,
F8
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
16
,
256
,
128
,
8
,
8
,
16
,
16
,
1
,
4
,
S
<
16
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
16
,
8
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
,
F8
>
#endif
// clang-format on
>
;
}
// namespace instance
...
...
library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn.hpp
View file @
f20e48f1
...
...
@@ -35,12 +35,12 @@ static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
template
<
GemmSpecialization
GemmSpec
>
using
device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_instances
=
std
::
tuple
<
// clang-format off
// clang-format off
//#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
//#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
//#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
#if defined(__gfx94__) || defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH)
// Compute friendly
DeviceGemm_Xdl_CShuffleV3
<
Row
,
Col
,
Row
,
F8
,
F8
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
256
,
256
,
64
,
16
,
16
,
32
,
32
,
4
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v4
,
F8
>
,
DeviceGemm_Xdl_CShuffleV3
<
Row
,
Col
,
Row
,
F8
,
F8
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
128
,
128
,
128
,
16
,
16
,
32
,
32
,
2
,
2
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v4
,
F8
>
,
...
...
@@ -57,17 +57,18 @@ using device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_instances = std::tuple<
DeviceGemm_Xdl_CShuffleV3
<
Row
,
Col
,
Row
,
F8
,
F8
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
128
,
64
,
128
,
16
,
16
,
32
,
32
,
2
,
1
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
,
F8
>
,
// DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 64, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
DeviceGemm_Xdl_CShuffleV3
<
Row
,
Col
,
Row
,
F8
,
F8
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
64
,
64
,
128
,
16
,
16
,
32
,
32
,
1
,
1
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
,
F8
>
#endif
// clang-format on
>
;
template
<
BlockGemmPipelineScheduler
BlkGemmPipeSched
,
GemmSpecialization
GemmSpec
>
using
device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_instances
=
std
::
tuple
<
// clang-format off
// clang-format off
//#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
//#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
//#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
#if defined(__gfx94__) || defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH)
// Latency friendly
DeviceGemm_Xdl_CShuffleV3
<
Row
,
Col
,
Row
,
F8
,
F8
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
128
,
32
,
16
,
128
,
16
,
16
,
16
,
16
,
1
,
1
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
2
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v1
,
F8
>
,
DeviceGemm_Xdl_CShuffleV3
<
Row
,
Col
,
Row
,
F8
,
F8
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
64
,
16
,
16
,
128
,
16
,
16
,
16
,
16
,
1
,
1
,
S
<
8
,
8
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
8
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v1
,
F8
>
,
...
...
@@ -97,6 +98,7 @@ using device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_instances = std::tuple<
DeviceGemm_Xdl_CShuffleV3
<
Row
,
Col
,
Row
,
F8
,
F8
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
128
,
32
,
128
,
128
,
16
,
16
,
32
,
32
,
1
,
2
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
,
F8
>
,
DeviceGemm_Xdl_CShuffleV3
<
Row
,
Col
,
Row
,
F8
,
F8
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
16
,
256
,
128
,
16
,
16
,
16
,
16
,
1
,
4
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
,
F8
>
,
DeviceGemm_Xdl_CShuffleV3
<
Row
,
Col
,
Row
,
F8
,
F8
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
32
,
256
,
128
,
16
,
16
,
32
,
32
,
1
,
2
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
8
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
,
F8
>
#endif
// clang-format on
>
;
}
// namespace instance
...
...
library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/CMakeLists.txt
View file @
f20e48f1
...
...
@@ -6,6 +6,7 @@ set(GROUPED_CONV2D_BWD_WEIGHT
xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp
xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp
xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_f32_bf16_instance.cpp
xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp
xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev2_instance.cpp
xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev5_instance.cpp
xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_f16_pipev2_instance.cpp
...
...
library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp
0 → 100644
View file @
f20e48f1
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_xdl_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
void
add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvBwdWeight
<
2
,
NHWGC
,
GKYXC
,
NHWGK
,
BF16
,
BF16
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
)
{
// 1. Default
add_device_operation_instances
(
instances
,
device_grouped_conv_bwd_weight_xdl_c_shuffle_bf16_instances
<
2
,
NHWGC
,
GKYXC
,
NHWGK
,
ConvBwdWeightDefault
>
{});
// 2. Filter1x1Stride1Pad0
add_device_operation_instances
(
instances
,
device_grouped_conv_bwd_weight_xdl_c_shuffle_bf16_instances
<
2
,
NHWGC
,
GKYXC
,
NHWGK
,
ConvBwdWeightFilter1x1Stride1Pad0
>
{});
}
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/CMakeLists.txt
View file @
f20e48f1
...
...
@@ -9,45 +9,56 @@ add_instance_library(device_grouped_conv2d_fwd_instance
xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp
xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp
xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp
xdl/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_instance.cpp
# NGCHW, GKYXC, NGKHW
xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_instance.cpp
xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_instance.cpp
xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_int8_instance.cpp
# large tensor
# NHWGC, GKYXC, NHWGK
xdl/large_tensor/device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instance.cpp
xdl/large_tensor/device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f16_instance.cpp
xdl/large_tensor/device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_instance.cpp
xdl/large_tensor/device_grouped_conv2d_fwd_xdl_large_tensor_nhwgc_gkyxc_nhwgk_int8_instance.cpp
# merged groups
# NHWGC, GKYXC, NHWGK
xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_bf16_instance.cpp
xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f16_instance.cpp
xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_f32_instance.cpp
xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_nhwgc_gkyxc_nhwgk_int8_instance.cpp
# NGCHW, GKYXC, NGKHW
xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_f16_instance.cpp
xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_f32_instance.cpp
xdl/merged_groups/device_grouped_conv2d_fwd_xdl_merged_groups_ngchw_gkyxc_ngkhw_int8_instance.cpp
#mem
# NHWGC, GKYXC, NHWGK
xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_intra_instance.cpp
xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_mem_intra_instance.cpp
xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_intra_instance.cpp
xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_intra_instance.cpp
# NHWGC, GKYXC, NHWGK
xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_mem_inter_instance.cpp
xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_mem_inter_instance.cpp
xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_inter_instance.cpp
xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_inter_instance.cpp
# NGCHW, GKYXC, NGKHW
xdl/mem/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_mem_intra_instance.cpp
xdl/mem/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_mem_intra_instance.cpp
xdl/mem/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_int8_mem_intra_instance.cpp
# NGCHW, GKYXC, NGKHW
xdl/mem/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_mem_inter_instance.cpp
xdl/mem/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_mem_inter_instance.cpp
xdl/mem/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_int8_mem_inter_instance.cpp
#comp
# NHWGC, GKYXC, NHWGK
xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_bf16_comp_instance.cpp
xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f16_comp_instance.cpp
xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_comp_instance.cpp
xdl/comp/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_comp_instance.cpp
# NGCHW, GKYXC, NGKHW
xdl/comp/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f16_comp_instance.cpp
xdl/comp/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_f32_comp_instance.cpp
xdl/comp/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_int8_comp_instance.cpp
#dl
# GNHWC, GKYXC, GNHWK
dl/device_grouped_conv2d_fwd_dl_gnhwc_gkyxc_gnhwk_f16_instance.cpp
...
...
library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_int8_comp_instance.cpp
0 → 100644
View file @
f20e48f1
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
void
add_device_grouped_conv2d_fwd_xdl_ngchw_gkyxc_ngkhw_int8_comp_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleABD
<
2
,
NGCHW
,
GKYXC
,
Empty_Tuple
,
NGKHW
,
int8_t
,
int8_t
,
Empty_Tuple
,
int8_t
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_grouped_conv_fwd_xdl_int8_comp_instances
<
2
,
NGCHW
,
GKYXC
,
Empty_Tuple
,
NGKHW
,
ConvFwdDefault
>
{});
}
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
Prev
1
…
12
13
14
15
16
17
18
19
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment