Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
129e58ae
Commit
129e58ae
authored
Jun 05, 2024
by
Adam Osewski
Browse files
Merge remote-tracking branch 'origin/develop' into aosewski/ggemm_multi_d2
parents
9bebfd42
cb0645be
Changes
188
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1050 additions
and
141 deletions
+1050
-141
library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp
..._multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp
+58
-0
library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_multiply_bias_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp
...i_abd_multiply_bias_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp
+58
-0
library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_multiply_bias_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp
..._multiply_bias_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp
+1
-106
library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_multiply_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp
...i_abd_multiply_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp
+58
-0
library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/CMakeLists.txt
...ion_instance/gpu/grouped_conv2d_bwd_weight/CMakeLists.txt
+3
-1
library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev2_instance.cpp
...t_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev2_instance.cpp
+4
-11
library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev5_instance.cpp
...t_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev5_instance.cpp
+41
-0
library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/CMakeLists.txt
...ion_instance/gpu/grouped_conv3d_bwd_weight/CMakeLists.txt
+10
-8
library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev2_instance.cpp
...wo_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev2_instance.cpp
+4
-11
library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev5_instance.cpp
...wo_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev5_instance.cpp
+41
-0
library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale/CMakeLists.txt
..._instance/gpu/grouped_conv3d_fwd_convscale/CMakeLists.txt
+5
-0
library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale/xdl/device_grouped_conv3d_fwd_xdl_convscale_ndhwgc_gkzyxc_ndhwgk_f8_instance.cpp
...3d_fwd_xdl_convscale_ndhwgc_gkzyxc_ndhwgk_f8_instance.cpp
+62
-0
profiler/include/profiler/profile_grouped_gemm_fixed_nk_impl.hpp
...r/include/profiler/profile_grouped_gemm_fixed_nk_impl.hpp
+1
-1
profiler/include/profiler/profile_grouped_gemm_impl.hpp
profiler/include/profiler/profile_grouped_gemm_impl.hpp
+1
-1
profiler/include/profiler/profile_grouped_gemm_tile_loop_impl.hpp
.../include/profiler/profile_grouped_gemm_tile_loop_impl.hpp
+1
-1
profiler/include/profiler/profile_grouped_gemm_two_stage_impl.hpp
.../include/profiler/profile_grouped_gemm_two_stage_impl.hpp
+1
-1
pyproject.toml
pyproject.toml
+36
-0
python/ck4inductor/__init__.py
python/ck4inductor/__init__.py
+0
-0
python/ck4inductor/universal_gemm/gen_instances.py
python/ck4inductor/universal_gemm/gen_instances.py
+570
-0
python/ck4inductor/universal_gemm/op.py
python/ck4inductor/universal_gemm/op.py
+95
-0
No files found.
library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp
0 → 100644
View file @
129e58ae
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp"
#include "device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_common.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
void
add_device_gemm_xdl_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_v1_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleABD
<
AsLayout
,
ck
::
Tuple
<
B0Layout
>
,
ck
::
Tuple
<
B1Layout
>
,
ELayout
,
AsDataType
,
ck
::
Tuple
<
B0DataType
>
,
ck
::
Tuple
<
B1DataType
>
,
EDataType
,
AElementOp
,
PassThrough
,
Multiply
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_comp_instances
<
ck
::
Tuple
<
B0Layout
>
,
ck
::
Tuple
<
B1Layout
>
,
ck
::
Tuple
<
B0DataType
>
,
ck
::
Tuple
<
B1DataType
>
,
PassThrough
,
Multiply
,
GemmMNKPadding
,
Interwave
>
{});
add_device_operation_instances
(
instances
,
device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances
<
ck
::
Tuple
<
B0Layout
>
,
ck
::
Tuple
<
B1Layout
>
,
ck
::
Tuple
<
B0DataType
>
,
ck
::
Tuple
<
B1DataType
>
,
PassThrough
,
Multiply
,
GemmMNKPadding
,
Interwave
>
{});
}
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_multiply_bias_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp
0 → 100644
View file @
129e58ae
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp"
#include "device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_common.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
void
add_device_gemm_xdl_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_bias_v1_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleABD
<
AsLayout
,
ck
::
Tuple
<
B0Layout
>
,
ck
::
Tuple
<
D0Layout
,
B1Layout
>
,
ELayout
,
AsDataType
,
ck
::
Tuple
<
B0DataType
>
,
ck
::
Tuple
<
D0DataType
,
B1DataType
>
,
EDataType
,
AElementOp
,
PassThrough
,
MultiplyAdd
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_comp_instances
<
ck
::
Tuple
<
B0Layout
>
,
ck
::
Tuple
<
D0Layout
,
B1Layout
>
,
ck
::
Tuple
<
B0DataType
>
,
ck
::
Tuple
<
D0DataType
,
B1DataType
>
,
PassThrough
,
MultiplyAdd
,
GemmMNKPadding
,
Interwave
>
{});
add_device_operation_instances
(
instances
,
device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances
<
ck
::
Tuple
<
B0Layout
>
,
ck
::
Tuple
<
D0Layout
,
B1Layout
>
,
ck
::
Tuple
<
B0DataType
>
,
ck
::
Tuple
<
D0DataType
,
B1DataType
>
,
PassThrough
,
MultiplyAdd
,
GemmMNKPadding
,
Interwave
>
{});
}
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_multiply_bias_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp
View file @
129e58ae
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
...
...
@@ -52,111 +52,6 @@ void add_device_gemm_xdl_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_bias_gelu_v1_i
Interwave
>
{});
}
void
add_device_gemm_xdl_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_bias_v1_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleABD
<
AsLayout
,
ck
::
Tuple
<
B0Layout
>
,
ck
::
Tuple
<
D0Layout
,
B1Layout
>
,
ELayout
,
AsDataType
,
ck
::
Tuple
<
B0DataType
>
,
ck
::
Tuple
<
D0DataType
,
B1DataType
>
,
EDataType
,
AElementOp
,
PassThrough
,
MultiplyAdd
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_comp_instances
<
ck
::
Tuple
<
B0Layout
>
,
ck
::
Tuple
<
D0Layout
,
B1Layout
>
,
ck
::
Tuple
<
B0DataType
>
,
ck
::
Tuple
<
D0DataType
,
B1DataType
>
,
PassThrough
,
MultiplyAdd
,
GemmMNKPadding
,
Interwave
>
{});
add_device_operation_instances
(
instances
,
device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances
<
ck
::
Tuple
<
B0Layout
>
,
ck
::
Tuple
<
D0Layout
,
B1Layout
>
,
ck
::
Tuple
<
B0DataType
>
,
ck
::
Tuple
<
D0DataType
,
B1DataType
>
,
PassThrough
,
MultiplyAdd
,
GemmMNKPadding
,
Interwave
>
{});
}
void
add_device_gemm_xdl_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_v1_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleABD
<
AsLayout
,
ck
::
Tuple
<
B0Layout
>
,
ck
::
Tuple
<
B1Layout
>
,
ELayout
,
AsDataType
,
ck
::
Tuple
<
B0DataType
>
,
ck
::
Tuple
<
B1DataType
>
,
EDataType
,
AElementOp
,
PassThrough
,
Multiply
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_comp_instances
<
ck
::
Tuple
<
B0Layout
>
,
ck
::
Tuple
<
B1Layout
>
,
ck
::
Tuple
<
B0DataType
>
,
ck
::
Tuple
<
B1DataType
>
,
PassThrough
,
Multiply
,
GemmMNKPadding
,
Interwave
>
{});
add_device_operation_instances
(
instances
,
device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances
<
ck
::
Tuple
<
B0Layout
>
,
ck
::
Tuple
<
B1Layout
>
,
ck
::
Tuple
<
B0DataType
>
,
ck
::
Tuple
<
B1DataType
>
,
PassThrough
,
Multiply
,
GemmMNKPadding
,
Interwave
>
{});
}
void
add_device_gemm_xdl_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_gelu_v1_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleABD
<
AsLayout
,
ck
::
Tuple
<
B0Layout
>
,
ck
::
Tuple
<
B1Layout
>
,
ELayout
,
AsDataType
,
ck
::
Tuple
<
B0DataType
>
,
ck
::
Tuple
<
B1DataType
>
,
EDataType
,
AElementOp
,
PassThrough
,
MultiplyFastGelu
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_comp_instances
<
ck
::
Tuple
<
B0Layout
>
,
ck
::
Tuple
<
B1Layout
>
,
ck
::
Tuple
<
B0DataType
>
,
ck
::
Tuple
<
B1DataType
>
,
PassThrough
,
MultiplyFastGelu
,
GemmMNKPadding
,
Interwave
>
{});
add_device_operation_instances
(
instances
,
device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances
<
ck
::
Tuple
<
B0Layout
>
,
ck
::
Tuple
<
B1Layout
>
,
ck
::
Tuple
<
B0DataType
>
,
ck
::
Tuple
<
B1DataType
>
,
PassThrough
,
MultiplyFastGelu
,
GemmMNKPadding
,
Interwave
>
{});
}
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
...
...
library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_multiply_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp
0 → 100644
View file @
129e58ae
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp"
#include "device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_common.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
void
add_device_gemm_xdl_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_gelu_v1_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleABD
<
AsLayout
,
ck
::
Tuple
<
B0Layout
>
,
ck
::
Tuple
<
B1Layout
>
,
ELayout
,
AsDataType
,
ck
::
Tuple
<
B0DataType
>
,
ck
::
Tuple
<
B1DataType
>
,
EDataType
,
AElementOp
,
PassThrough
,
MultiplyFastGelu
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_comp_instances
<
ck
::
Tuple
<
B0Layout
>
,
ck
::
Tuple
<
B1Layout
>
,
ck
::
Tuple
<
B0DataType
>
,
ck
::
Tuple
<
B1DataType
>
,
PassThrough
,
MultiplyFastGelu
,
GemmMNKPadding
,
Interwave
>
{});
add_device_operation_instances
(
instances
,
device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances
<
ck
::
Tuple
<
B0Layout
>
,
ck
::
Tuple
<
B1Layout
>
,
ck
::
Tuple
<
B0DataType
>
,
ck
::
Tuple
<
B1DataType
>
,
PassThrough
,
MultiplyFastGelu
,
GemmMNKPadding
,
Interwave
>
{});
}
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/CMakeLists.txt
View file @
129e58ae
...
...
@@ -6,7 +6,9 @@ set(GROUPED_CONV2D_BWD_WEIGHT
xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp
xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp
xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp
xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp
)
xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev2_instance.cpp
xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev5_instance.cpp
)
if
(
DL_KERNELS
)
list
(
APPEND GROUPED_CONV2D_BWD_WEIGHT
...
...
library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp
→
library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_
pipev2_
instance.cpp
View file @
129e58ae
...
...
@@ -10,7 +10,7 @@ namespace device {
namespace
instance
{
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
void
add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_instances
(
void
add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_
pipev2_
instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvBwdWeight
<
2
,
NHWGC
,
GKYXC
,
...
...
@@ -30,16 +30,9 @@ void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_in
NHWGC
,
GKYXC
,
NHWGK
,
ConvBwdWeightDefault
>
{});
// 2. Filter1x1Stride1Pad0
add_device_operation_instances
(
instances
,
device_grouped_conv_bwd_weight_two_stage_xdl_c_shuffle_f16_instances
<
2
,
NHWGC
,
GKYXC
,
NHWGK
,
ConvBwdWeightFilter1x1Stride1Pad0
>
{});
ConvBwdWeightDefault
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v2
>
{});
}
}
// namespace instance
...
...
library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev5_instance.cpp
0 → 100644
View file @
129e58ae
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
void
add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev5_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvBwdWeight
<
2
,
NHWGC
,
GKYXC
,
NHWGK
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
)
{
// 1. Default
add_device_operation_instances
(
instances
,
device_grouped_conv_bwd_weight_two_stage_xdl_c_shuffle_f16_instances
<
2
,
NHWGC
,
GKYXC
,
NHWGK
,
ConvBwdWeightDefault
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v5
>
{});
}
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/CMakeLists.txt
View file @
129e58ae
# XDL_DL_WMMA_KERNELS
# XDL_DL_WMMA_KERNELS
set
(
GROUPED_CONV3D_BWD_WEIGHT
xdl/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp
xdl/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp
xdl/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_instance.cpp
xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp
xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
)
xdl/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp
xdl/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp
xdl/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_instance.cpp
xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp
xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev2_instance.cpp
xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev5_instance.cpp
)
if
(
DL_KERNELS
)
list
(
APPEND GROUPED_CONV3D_BWD_WEIGHT
...
...
library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
→
library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_
pipev2_
instance.cpp
View file @
129e58ae
...
...
@@ -10,7 +10,7 @@ namespace device {
namespace
instance
{
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
void
add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances
(
void
add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_
pipev2_
instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvBwdWeight
<
3
,
NDHWGC
,
GKZYXC
,
...
...
@@ -30,16 +30,9 @@ void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16
NDHWGC
,
GKZYXC
,
NDHWGK
,
ConvBwdWeightDefault
>
{});
// 2. Filter1x1Stride1Pad0
add_device_operation_instances
(
instances
,
device_grouped_conv_bwd_weight_two_stage_xdl_c_shuffle_f16_instances
<
3
,
NDHWGC
,
GKZYXC
,
NDHWGK
,
ConvBwdWeightFilter1x1Stride1Pad0
>
{});
ConvBwdWeightDefault
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v2
>
{});
}
}
// namespace instance
...
...
library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev5_instance.cpp
0 → 100644
View file @
129e58ae
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
void
add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev5_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvBwdWeight
<
3
,
NDHWGC
,
GKZYXC
,
NDHWGK
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
)
{
// 1. Default
add_device_operation_instances
(
instances
,
device_grouped_conv_bwd_weight_two_stage_xdl_c_shuffle_f16_instances
<
3
,
NDHWGC
,
GKZYXC
,
NDHWGK
,
ConvBwdWeightDefault
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v5
>
{});
}
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale/CMakeLists.txt
0 → 100644
View file @
129e58ae
# ONLY XDL_KERNELS
set
(
GROUPED_CONV3D_FWD_CONVSCALE
xdl/device_grouped_conv3d_fwd_xdl_convscale_ndhwgc_gkzyxc_ndhwgk_f8_instance.cpp
)
add_instance_library
(
device_grouped_conv3d_fwd_convscale_instance
${
GROUPED_CONV3D_FWD_CONVSCALE
}
)
library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_convscale/xdl/device_grouped_conv3d_fwd_xdl_convscale_ndhwgc_gkzyxc_ndhwgk_f8_instance.cpp
0 → 100644
View file @
129e58ae
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_outelementop_instance.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
using
ConvScale
=
ck
::
tensor_operation
::
element_wise
::
ConvScale
;
void
add_device_grouped_conv3d_fwd_xdl_convscale_ndhwgc_gkzyxc_ndhwgk_f8_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvFwdMultipleABD
<
3
,
NDHWGC
,
GKZYXC
,
ck
::
Tuple
<>
,
NDHWGK
,
F8
,
F8
,
ck
::
Tuple
<>
,
F8
,
PassThrough
,
PassThrough
,
ConvScale
,
F8
,
F8
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_grouped_conv_fwd_xdl_outelementop_f8_instances
<
3
,
NDHWGC
,
GKZYXC
,
ck
::
Tuple
<>
,
NDHWGK
,
ConvFwdDefault
,
ConvScale
>
{});
add_device_operation_instances
(
instances
,
device_grouped_conv_fwd_xdl_outelementop_f8_instances
<
3
,
NDHWGC
,
GKZYXC
,
ck
::
Tuple
<>
,
NDHWGK
,
ConvFwd1x1P0
,
ConvScale
>
{});
add_device_operation_instances
(
instances
,
device_grouped_conv_fwd_xdl_outelementop_f8_instances
<
3
,
NDHWGC
,
GKZYXC
,
ck
::
Tuple
<>
,
NDHWGK
,
ConvFwd1x1S1P0
,
ConvScale
>
{});
}
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
profiler/include/profiler/profile_grouped_gemm_fixed_nk_impl.hpp
View file @
129e58ae
...
...
@@ -88,7 +88,7 @@ bool profile_grouped_gemm_fixed_nk_impl(int do_verification,
c_m_n_host_results
.
push_back
(
Tensor
<
CDataType
>
(
f_host_tensor_descriptor
(
Ms
[
i
],
Ns
[
i
],
StrideCs
[
i
],
CLayout
{})));
if
(
ck
::
EnvIsEnabled
(
ENV
(
CK_LOGGING
)))
if
(
ck
::
EnvIsEnabled
(
CK_
ENV
(
CK_LOGGING
)))
{
std
::
cout
<<
"group: "
<<
i
<<
" a_m_k["
<<
i
<<
"]:"
<<
a_m_k
[
i
].
mDesc
<<
", b_k_n["
<<
i
<<
"]:"
<<
b_k_n
[
i
].
mDesc
<<
", c_m_n_device_results["
<<
i
...
...
profiler/include/profiler/profile_grouped_gemm_impl.hpp
View file @
129e58ae
...
...
@@ -87,7 +87,7 @@ bool profile_grouped_gemm_impl(int do_verification,
c_m_n_host_results
.
push_back
(
Tensor
<
CDataType
>
(
f_host_tensor_descriptor
(
Ms
[
i
],
Ns
[
i
],
StrideCs
[
i
],
CLayout
{})));
if
(
ck
::
EnvIsEnabled
(
ENV
(
CK_LOGGING
)))
if
(
ck
::
EnvIsEnabled
(
CK_
ENV
(
CK_LOGGING
)))
{
std
::
cout
<<
"group: "
<<
i
<<
" a_m_k["
<<
i
<<
"]:"
<<
a_m_k
[
i
].
mDesc
<<
", b_k_n["
<<
i
<<
"]:"
<<
b_k_n
[
i
].
mDesc
<<
", c_m_n_device_results["
<<
i
...
...
profiler/include/profiler/profile_grouped_gemm_tile_loop_impl.hpp
View file @
129e58ae
...
...
@@ -82,7 +82,7 @@ bool profile_grouped_gemm_tile_loop_impl(int do_verification,
Tensor
<
CDataType
>
(
f_host_tensor_descriptor
(
Ms
[
i
],
Ns
[
i
],
StrideCs
[
i
],
CLayout
{})));
c_m_n_host_results
.
push_back
(
Tensor
<
CDataType
>
(
f_host_tensor_descriptor
(
Ms
[
i
],
Ns
[
i
],
StrideCs
[
i
],
CLayout
{})));
if
(
ck
::
EnvIsEnabled
(
ENV
(
CK_LOGGING
)))
if
(
ck
::
EnvIsEnabled
(
CK_
ENV
(
CK_LOGGING
)))
{
std
::
cout
<<
"group: "
<<
i
<<
" a_m_k["
<<
i
<<
"]:"
<<
a_m_k
[
i
].
mDesc
<<
", b_k_n["
<<
i
<<
"]:"
<<
b_k_n
[
i
].
mDesc
<<
", c_m_n_device_results["
<<
i
...
...
profiler/include/profiler/profile_grouped_gemm_two_stage_impl.hpp
View file @
129e58ae
...
...
@@ -88,7 +88,7 @@ bool profile_grouped_gemm_two_stage_impl(int do_verification,
c_m_n_host_results
.
push_back
(
Tensor
<
CDataType
>
(
f_host_tensor_descriptor
(
Ms
[
i
],
Ns
[
i
],
StrideCs
[
i
],
CLayout
{})));
if
(
ck
::
EnvIsEnabled
(
ENV
(
CK_LOGGING
)))
if
(
ck
::
EnvIsEnabled
(
CK_
ENV
(
CK_LOGGING
)))
{
std
::
cout
<<
"group: "
<<
i
<<
" a_m_k["
<<
i
<<
"]:"
<<
a_m_k
[
i
].
mDesc
<<
", b_k_n["
<<
i
<<
"]:"
<<
b_k_n
[
i
].
mDesc
<<
", c_m_n_device_results["
<<
i
...
...
pyproject.toml
0 → 100644
View file @
129e58ae
[build-system]
requires
=
[
"setuptools"
,
"setuptools-scm"
]
build-backend
=
"setuptools.build_meta"
[project]
name
=
"rocm-composable-kernel"
dynamic
=
["version"]
description
=
"Composable Kernel, performance-critical kernels for machine learning workloads"
readme
=
"README.md"
requires-python
=
">=3.8"
license
=
{
file
=
"LICENSE"
}
classifiers
=
[
"Programming Language :: Python :: 3"
,
"License :: OSI Approved :: MIT License"
,
"Operating System :: OS Independent"
,
]
dependencies
=
[]
[project.urls]
"Homepage"
=
"https://github.com/rocm/composable_kernel"
"Bug
Tracker"
=
"https://github.com/rocm/composable_kernel/issues"
[tool.setuptools]
packages
=
[
"ck4inductor"
,
"ck4inductor.include"
,
"ck4inductor.library"
]
[tool.setuptools.package-dir]
ck4inductor
=
"python/ck4inductor"
"ck4inductor.include"
=
"include"
"ck4inductor.library"
=
"library"
[tool.setuptools.package-data]
"ck4inductor.include"
=
["ck/**/*.hpp"]
"ck4inductor.library"
=
["src/tensor_operation_instance/gpu/gemm_universal/**/*.hpp"]
[tool.setuptools.dynamic]
version
=
{
attr
=
"setuptools_scm.get_version"
}
python/ck4inductor/__init__.py
0 → 100644
View file @
129e58ae
python/ck4inductor/universal_gemm/gen_instances.py
0 → 100644
View file @
129e58ae
import
logging
import
os
import
subprocess
from
dataclasses
import
fields
,
replace
from
functools
import
lru_cache
,
partial
from
typing
import
List
from
..util
import
library_path
from
.op
import
CKGemmOperation
log
=
logging
.
getLogger
(
__name__
)
def
_ck_library_dir
():
gemm_instances_path
=
os
.
path
.
join
(
library_path
(),
"src"
,
"tensor_operation_instance"
,
"gpu"
,
"gemm_universal"
)
if
not
os
.
path
.
exists
(
gemm_instances_path
):
log
.
error
(
"CK library path %s does not exist"
,
gemm_instances_path
)
return
None
return
gemm_instances_path
def
parse_instances
(
str_instances
:
List
[
str
])
->
List
[
CKGemmOperation
]:
"""
Parse the lines containing Universal Gemm template instances into `CKGemmOperation` instances
"""
def
maybe_int
(
s
):
try
:
return
int
(
s
)
except
ValueError
:
return
s
op_instances
=
[]
for
line
in
str_instances
:
s_template_args
=
line
.
split
(
"DeviceGemm_Xdl_CShuffleV3"
)[
-
1
].
strip
(
"<>, "
)
template_args
=
[]
i_current
=
0
while
i_current
<
len
(
s_template_args
):
if
s_template_args
[
i_current
]
==
" "
:
# skip whitespace
i_current
+=
1
continue
elif
s_template_args
[
i_current
:
i_current
+
2
]
==
"S<"
:
# parse template S<Index...>
i_next
=
s_template_args
.
find
(
">"
,
i_current
)
template_args
.
append
(
tuple
(
map
(
int
,
s_template_args
[
i_current
+
2
:
i_next
].
split
(
","
)))
)
i_current
=
i_next
+
2
else
:
# all string attributes must be either type aliases or global constants in C++
i_next
=
s_template_args
.
find
(
","
,
i_current
)
template_args
.
append
(
maybe_int
(
s_template_args
[
i_current
:
i_next
if
i_next
!=
-
1
else
None
]
)
)
if
i_next
!=
-
1
:
i_current
=
i_next
+
1
if
i_next
==
-
1
:
break
# pad with `None`s for the fields which are not defined in the instance
new_instance
=
CKGemmOperation
(
*
template_args
,
# type: ignore[arg-type]
*
((
None
,)
*
(
len
(
fields
(
CKGemmOperation
))
-
len
(
template_args
))),
)
# the last 2 template parameters are optional
# if they are absent, substitute them with default values from Universal Gemm C++ template declaration
if
new_instance
.
a_compute_dtype
is
None
:
new_instance
.
a_compute_dtype
=
new_instance
.
c_element_dtype
if
new_instance
.
b_compute_dtype
is
None
:
new_instance
.
b_compute_dtype
=
new_instance
.
c_element_dtype
op_instances
.
append
(
new_instance
)
return
op_instances
def
default_instances
()
->
List
[
CKGemmOperation
]:
# fallback: known working op instance for problem size M=2240 K=256 N=2048
# all string attributes must be either type aliases or global constants in C++
return
[
CKGemmOperation
(
a_layout
=
"Row"
,
b_layout
=
"Row"
,
c_layout
=
"Row"
,
a_element_dtype
=
"F16"
,
b_element_dtype
=
"F16"
,
c_element_dtype
=
"F16"
,
a_compute_dtype
=
"F16"
,
b_compute_dtype
=
"F16"
,
acc_dtype
=
"F32"
,
c_shuffle_dtype
=
"F16"
,
a_elementwise_op
=
"PassThrough"
,
b_elementwise_op
=
"PassThrough"
,
c_elementwise_op
=
"PassThrough"
,
gemm_specialization
=
"GemmSpecialization::Default"
,
block_size
=
256
,
m_per_block
=
224
,
n_per_block
=
256
,
k_per_block
=
64
,
a_k1
=
8
,
b_k1
=
2
,
m_per_xdl
=
16
,
n_per_xdl
=
16
,
m_xdl_per_wave
=
7
,
n_xdl_per_wave
=
8
,
a_block_transfer_thread_cluster_lengths_ak0_m_ak1
=
(
8
,
32
,
1
),
a_block_transfer_thread_cluster_arrange_order
=
(
1
,
0
,
2
),
a_block_transfer_src_access_order
=
(
1
,
0
,
2
),
a_block_transfer_src_vector_dim
=
2
,
a_block_transfer_src_scalar_per_vector
=
8
,
a_block_transfer_dst_scalar_per_vector_ak1
=
8
,
a_block_lds_extra_m
=
0
,
# type: ignore[arg-type]
b_block_transfer_thread_cluster_lengths_bk0_n_bk1
=
(
8
,
32
,
1
),
b_block_transfer_thread_cluster_arrange_order
=
(
0
,
2
,
1
),
b_block_transfer_src_access_order
=
(
0
,
2
,
1
),
b_block_transfer_src_vector_dim
=
1
,
b_block_transfer_src_scalar_per_vector
=
8
,
b_block_transfer_dst_scalar_per_vector_bk1
=
2
,
b_block_lds_extra_n
=
0
,
# type: ignore[arg-type]
c_shuffle_m_xdl_per_wave_per_shuffle
=
1
,
c_shuffle_n_xdl_per_wave_per_shuffle
=
2
,
c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block
=
(
1
,
32
,
1
,
8
,
),
c_shuffle_block_transfer_scalar_per_vector_n_per_block
=
8
,
block_gemm_pipeline_scheduler
=
"BlockGemmPipelineScheduler::Intrawave"
,
block_gemm_pipeline_version
=
"BlockGemmPipelineVersion::v3"
,
)
]
@
lru_cache
(
None
)
def
gen_ops_library
()
->
List
[
CKGemmOperation
]:
"""
Parse the Universal Gemm instances defined in the composable kernel library folder.
"""
ck_library_dir
=
_ck_library_dir
()
if
not
ck_library_dir
:
return
[]
grep_result
=
subprocess
.
run
(
[
"grep"
,
"-inR"
,
"DeviceGemm_Xdl_CShuffleV3"
,
_ck_library_dir
(),
],
capture_output
=
True
,
text
=
True
,
)
op_instances
=
parse_instances
(
grep_result
.
stdout
.
strip
().
split
(
"
\n
"
))
log
.
debug
(
"ck instances from library: %d"
,
len
(
op_instances
))
schedulers
=
[
"BlockGemmPipelineScheduler::Intrawave"
,
"BlockGemmPipelineScheduler::Interwave"
,
]
gemm_specs
=
[
"GemmSpecialization::Default"
,
"GemmSpecialization::MPadding"
,
"GemmSpecialization::NPadding"
,
"GemmSpecialization::KPadding"
,
"GemmSpecialization::MNPadding"
,
"GemmSpecialization::MKPadding"
,
"GemmSpecialization::NKPadding"
,
"GemmSpecialization::MNKPadding"
,
]
# substitute templated args by looping through their domains
substitute_instances
=
[]
for
instance
in
op_instances
:
sub_scheduler
=
instance
.
block_gemm_pipeline_scheduler
==
"BlkGemmPipeSched"
sub_spec
=
instance
.
gemm_specialization
==
"GemmSpec"
schedulers_range
=
(
schedulers
if
sub_scheduler
else
[
instance
.
block_gemm_pipeline_scheduler
]
)
spec_range
=
gemm_specs
if
sub_spec
else
[
instance
.
gemm_specialization
]
for
scheduler
in
schedulers_range
:
for
spec
in
spec_range
:
substitute_instances
.
append
(
replace
(
instance
,
block_gemm_pipeline_scheduler
=
scheduler
,
gemm_specialization
=
spec
,
)
)
return
substitute_instances
@lru_cache(None)
def gen_ops_preselected() -> List[CKGemmOperation]:
    """
    Manually selected (through benchmarking) F16/F16/F16 Row/Col/Row instances

    The constants below are benchmark-tuned; do not edit them casually.
    """
    # Base template: fixes layouts, dtypes, elementwise ops, K-tiling and the
    # A/B block-transfer parameters shared by every preselected instance.
    # Tile shapes and pipeline settings come from the more specific partials
    # and call sites below.
    ck_gemm_f16_rcr = partial(
        CKGemmOperation,
        a_layout="Row",
        b_layout="Col",
        c_layout="Row",
        a_element_dtype="F16",
        b_element_dtype="F16",
        c_element_dtype="F16",
        acc_dtype="F32",
        c_shuffle_dtype="F16",
        a_elementwise_op="PassThrough",
        b_elementwise_op="PassThrough",
        c_elementwise_op="PassThrough",
        k_per_block=64,
        a_k1=8,
        b_k1=8,
        a_block_transfer_thread_cluster_arrange_order=(1, 0, 2),
        a_block_transfer_src_access_order=(1, 0, 2),
        a_block_transfer_src_vector_dim=2,
        a_block_transfer_src_scalar_per_vector=8,
        a_block_transfer_dst_scalar_per_vector_ak1=8,
        a_block_lds_extra_m=0,
        b_block_transfer_thread_cluster_arrange_order=(1, 0, 2),
        b_block_transfer_src_access_order=(1, 0, 2),
        b_block_transfer_src_vector_dim=2,
        b_block_transfer_src_scalar_per_vector=8,
        b_block_transfer_dst_scalar_per_vector_bk1=8,
        b_block_lds_extra_n=0,
        a_compute_dtype="F16",
        b_compute_dtype="F16",
    )
    # "Compute friendly": 256-thread workgroup, 8x32x1 A/B thread clusters.
    ck_gemm_f16_rcr_compute_friendly = partial(
        ck_gemm_f16_rcr,
        block_size=256,
        a_block_transfer_thread_cluster_lengths_ak0_m_ak1=(8, 32, 1),
        b_block_transfer_thread_cluster_lengths_bk0_n_bk1=(8, 32, 1),
        c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block=(
            1,
            32,
            1,
            8,
        ),
        c_shuffle_block_transfer_scalar_per_vector_n_per_block=8,
    )
    # "Memory friendly": 128-thread workgroup, Interwave scheduler + v2 pipeline.
    ck_gemm_f16_rcr_memory_friendly = partial(
        ck_gemm_f16_rcr,
        block_size=128,
        a_block_transfer_thread_cluster_lengths_ak0_m_ak1=(8, 16, 1),
        b_block_transfer_thread_cluster_lengths_bk0_n_bk1=(8, 16, 1),
        block_gemm_pipeline_scheduler="BlockGemmPipelineScheduler::Interwave",
        block_gemm_pipeline_version="BlockGemmPipelineVersion::v2",
    )
    # "Latency friendly": small 16x16 xdl tiles, Intrawave scheduler + v1 pipeline.
    ck_gemm_f16_rcr_latency_friendly = partial(
        ck_gemm_f16_rcr,
        gemm_specialization="GemmSpecialization::Default",
        block_size=128,
        m_per_xdl=16,
        n_per_xdl=16,
        m_xdl_per_wave=1,
        n_xdl_per_wave=1,
        a_block_transfer_thread_cluster_lengths_ak0_m_ak1=(8, 16, 1),
        b_block_transfer_thread_cluster_lengths_bk0_n_bk1=(8, 16, 1),
        c_shuffle_m_xdl_per_wave_per_shuffle=1,
        c_shuffle_n_xdl_per_wave_per_shuffle=1,
        c_shuffle_block_transfer_scalar_per_vector_n_per_block=4,
        block_gemm_pipeline_scheduler="BlockGemmPipelineScheduler::Intrawave",
        block_gemm_pipeline_version="BlockGemmPipelineVersion::v1",
    )
    return [
        # compute-friendly, large 224x256 tile, pipeline v3
        ck_gemm_f16_rcr_compute_friendly(
            gemm_specialization="GemmSpecialization::MNKPadding",
            m_per_block=224,
            n_per_block=256,
            m_per_xdl=16,
            n_per_xdl=16,
            m_xdl_per_wave=7,
            n_xdl_per_wave=8,
            c_shuffle_m_xdl_per_wave_per_shuffle=1,
            c_shuffle_n_xdl_per_wave_per_shuffle=2,
            block_gemm_pipeline_scheduler="BlockGemmPipelineScheduler::Intrawave",
            block_gemm_pipeline_version="BlockGemmPipelineVersion::v3",
        ),
        # compute-friendly 128x128 / 32x32 xdl, padded, pipeline v3 / v4 / v5
        ck_gemm_f16_rcr_compute_friendly(
            gemm_specialization="GemmSpecialization::MNKPadding",
            m_per_block=128,
            n_per_block=128,
            m_per_xdl=32,
            n_per_xdl=32,
            m_xdl_per_wave=2,
            n_xdl_per_wave=2,
            c_shuffle_m_xdl_per_wave_per_shuffle=1,
            c_shuffle_n_xdl_per_wave_per_shuffle=1,
            block_gemm_pipeline_scheduler="BlockGemmPipelineScheduler::Intrawave",
            block_gemm_pipeline_version="BlockGemmPipelineVersion::v3",
        ),
        ck_gemm_f16_rcr_compute_friendly(
            gemm_specialization="GemmSpecialization::MNKPadding",
            m_per_block=128,
            n_per_block=128,
            m_per_xdl=32,
            n_per_xdl=32,
            m_xdl_per_wave=2,
            n_xdl_per_wave=2,
            c_shuffle_m_xdl_per_wave_per_shuffle=1,
            c_shuffle_n_xdl_per_wave_per_shuffle=1,
            block_gemm_pipeline_scheduler="BlockGemmPipelineScheduler::Intrawave",
            block_gemm_pipeline_version="BlockGemmPipelineVersion::v4",
        ),
        ck_gemm_f16_rcr_compute_friendly(
            gemm_specialization="GemmSpecialization::MNKPadding",
            m_per_block=128,
            n_per_block=128,
            m_per_xdl=32,
            n_per_xdl=32,
            m_xdl_per_wave=2,
            n_xdl_per_wave=2,
            c_shuffle_m_xdl_per_wave_per_shuffle=1,
            c_shuffle_n_xdl_per_wave_per_shuffle=1,
            block_gemm_pipeline_scheduler="BlockGemmPipelineScheduler::Intrawave",
            block_gemm_pipeline_version="BlockGemmPipelineVersion::v5",
        ),
        # same tile shapes without padding, pipeline v3 / v4 / v5
        ck_gemm_f16_rcr_compute_friendly(
            gemm_specialization="GemmSpecialization::Default",
            m_per_block=128,
            n_per_block=128,
            m_per_xdl=32,
            n_per_xdl=32,
            m_xdl_per_wave=2,
            n_xdl_per_wave=2,
            c_shuffle_m_xdl_per_wave_per_shuffle=1,
            c_shuffle_n_xdl_per_wave_per_shuffle=1,
            block_gemm_pipeline_scheduler="BlockGemmPipelineScheduler::Intrawave",
            block_gemm_pipeline_version="BlockGemmPipelineVersion::v3",
        ),
        ck_gemm_f16_rcr_compute_friendly(
            gemm_specialization="GemmSpecialization::Default",
            m_per_block=128,
            n_per_block=128,
            m_per_xdl=32,
            n_per_xdl=32,
            m_xdl_per_wave=2,
            n_xdl_per_wave=2,
            c_shuffle_m_xdl_per_wave_per_shuffle=1,
            c_shuffle_n_xdl_per_wave_per_shuffle=1,
            block_gemm_pipeline_scheduler="BlockGemmPipelineScheduler::Intrawave",
            block_gemm_pipeline_version="BlockGemmPipelineVersion::v4",
        ),
        ck_gemm_f16_rcr_compute_friendly(
            gemm_specialization="GemmSpecialization::Default",
            m_per_block=128,
            n_per_block=128,
            m_per_xdl=32,
            n_per_xdl=32,
            m_xdl_per_wave=2,
            n_xdl_per_wave=2,
            c_shuffle_m_xdl_per_wave_per_shuffle=1,
            c_shuffle_n_xdl_per_wave_per_shuffle=1,
            block_gemm_pipeline_scheduler="BlockGemmPipelineScheduler::Intrawave",
            block_gemm_pipeline_version="BlockGemmPipelineVersion::v5",
        ),
        # memory-friendly small-tile instances (16..128 x 16..128)
        ck_gemm_f16_rcr_memory_friendly(
            gemm_specialization="GemmSpecialization::Default",
            m_per_block=16,
            n_per_block=32,
            m_per_xdl=16,
            n_per_xdl=16,
            m_xdl_per_wave=1,
            n_xdl_per_wave=1,
            c_shuffle_m_xdl_per_wave_per_shuffle=1,
            c_shuffle_n_xdl_per_wave_per_shuffle=1,
            c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block=(
                1,
                16,
                1,
                8,
            ),
            c_shuffle_block_transfer_scalar_per_vector_n_per_block=4,
        ),
        ck_gemm_f16_rcr_memory_friendly(
            gemm_specialization="GemmSpecialization::MNKPadding",
            m_per_block=16,
            n_per_block=32,
            m_per_xdl=16,
            n_per_xdl=16,
            m_xdl_per_wave=1,
            n_xdl_per_wave=1,
            c_shuffle_m_xdl_per_wave_per_shuffle=1,
            c_shuffle_n_xdl_per_wave_per_shuffle=1,
            c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block=(
                1,
                16,
                1,
                8,
            ),
            c_shuffle_block_transfer_scalar_per_vector_n_per_block=4,
        ),
        ck_gemm_f16_rcr_memory_friendly(
            gemm_specialization="GemmSpecialization::MNKPadding",
            m_per_block=16,
            n_per_block=64,
            m_per_xdl=16,
            n_per_xdl=16,
            m_xdl_per_wave=1,
            n_xdl_per_wave=2,
            c_shuffle_m_xdl_per_wave_per_shuffle=1,
            c_shuffle_n_xdl_per_wave_per_shuffle=2,
            c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block=(
                1,
                16,
                1,
                8,
            ),
            c_shuffle_block_transfer_scalar_per_vector_n_per_block=8,
        ),
        ck_gemm_f16_rcr_memory_friendly(
            gemm_specialization="GemmSpecialization::MNKPadding",
            m_per_block=32,
            n_per_block=64,
            m_per_xdl=32,
            n_per_xdl=32,
            m_xdl_per_wave=1,
            n_xdl_per_wave=1,
            c_shuffle_m_xdl_per_wave_per_shuffle=1,
            c_shuffle_n_xdl_per_wave_per_shuffle=1,
            c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block=(
                1,
                16,
                1,
                8,
            ),
            c_shuffle_block_transfer_scalar_per_vector_n_per_block=8,
        ),
        ck_gemm_f16_rcr_memory_friendly(
            gemm_specialization="GemmSpecialization::MNKPadding",
            m_per_block=32,
            n_per_block=128,
            m_per_xdl=32,
            n_per_xdl=32,
            m_xdl_per_wave=1,
            n_xdl_per_wave=2,
            c_shuffle_m_xdl_per_wave_per_shuffle=1,
            c_shuffle_n_xdl_per_wave_per_shuffle=1,
            c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block=(
                1,
                16,
                1,
                8,
            ),
            c_shuffle_block_transfer_scalar_per_vector_n_per_block=8,
        ),
        ck_gemm_f16_rcr_memory_friendly(
            gemm_specialization="GemmSpecialization::Default",
            m_per_block=32,
            n_per_block=16,
            m_per_xdl=16,
            n_per_xdl=16,
            m_xdl_per_wave=1,
            n_xdl_per_wave=1,
            c_shuffle_m_xdl_per_wave_per_shuffle=1,
            c_shuffle_n_xdl_per_wave_per_shuffle=1,
            c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block=(
                1,
                32,
                1,
                4,
            ),
            c_shuffle_block_transfer_scalar_per_vector_n_per_block=4,
        ),
        ck_gemm_f16_rcr_memory_friendly(
            gemm_specialization="GemmSpecialization::MNKPadding",
            m_per_block=32,
            n_per_block=16,
            m_per_xdl=16,
            n_per_xdl=16,
            m_xdl_per_wave=1,
            n_xdl_per_wave=1,
            c_shuffle_m_xdl_per_wave_per_shuffle=1,
            c_shuffle_n_xdl_per_wave_per_shuffle=1,
            c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block=(
                1,
                32,
                1,
                4,
            ),
            c_shuffle_block_transfer_scalar_per_vector_n_per_block=4,
        ),
        ck_gemm_f16_rcr_memory_friendly(
            gemm_specialization="GemmSpecialization::MNKPadding",
            m_per_block=64,
            n_per_block=16,
            m_per_xdl=16,
            n_per_xdl=16,
            m_xdl_per_wave=2,
            n_xdl_per_wave=1,
            c_shuffle_m_xdl_per_wave_per_shuffle=2,
            c_shuffle_n_xdl_per_wave_per_shuffle=1,
            c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block=(
                1,
                64,
                1,
                2,
            ),
            c_shuffle_block_transfer_scalar_per_vector_n_per_block=8,
        ),
        ck_gemm_f16_rcr_memory_friendly(
            gemm_specialization="GemmSpecialization::MNKPadding",
            m_per_block=64,
            n_per_block=32,
            m_per_xdl=32,
            n_per_xdl=32,
            m_xdl_per_wave=1,
            n_xdl_per_wave=1,
            c_shuffle_m_xdl_per_wave_per_shuffle=1,
            c_shuffle_n_xdl_per_wave_per_shuffle=1,
            c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block=(
                1,
                32,
                1,
                4,
            ),
            c_shuffle_block_transfer_scalar_per_vector_n_per_block=8,
        ),
        ck_gemm_f16_rcr_memory_friendly(
            gemm_specialization="GemmSpecialization::MNKPadding",
            m_per_block=128,
            n_per_block=32,
            m_per_xdl=32,
            n_per_xdl=32,
            m_xdl_per_wave=2,
            n_xdl_per_wave=1,
            c_shuffle_m_xdl_per_wave_per_shuffle=2,
            c_shuffle_n_xdl_per_wave_per_shuffle=1,
            c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block=(
                1,
                32,
                1,
                4,
            ),
            c_shuffle_block_transfer_scalar_per_vector_n_per_block=8,
        ),
        # latency-friendly minimal tiles
        ck_gemm_f16_rcr_latency_friendly(
            m_per_block=16,
            n_per_block=32,
            c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block=(
                1,
                16,
                1,
                8,
            ),
        ),
        ck_gemm_f16_rcr_latency_friendly(
            m_per_block=32,
            n_per_block=16,
            c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block=(
                1,
                32,
                1,
                4,
            ),
        ),
    ]
if __name__ == "__main__":
    # When run as a script, dump every parsed library instance to stdout.
    ops = gen_ops_library()
    print(ops)
python/ck4inductor/universal_gemm/op.py
0 → 100644
View file @
129e58ae
from
dataclasses
import
asdict
,
dataclass
from
typing
import
Optional
,
Tuple
@dataclass
class CKGemmOperation:
    """
    A python dataclass storing the template parameters of a CK Universal Gemm template instance
    """

    # matrix memory layouts
    a_layout: str
    b_layout: str
    c_layout: str
    # element / accumulator / shuffle dtypes
    a_element_dtype: str
    b_element_dtype: str
    c_element_dtype: str
    acc_dtype: str
    c_shuffle_dtype: str
    # elementwise epilogue ops
    a_elementwise_op: str
    b_elementwise_op: str
    c_elementwise_op: str
    gemm_specialization: str
    # workgroup and tile shape
    block_size: int
    m_per_block: int
    n_per_block: int
    k_per_block: int
    a_k1: int
    b_k1: int
    m_per_xdl: int
    n_per_xdl: int
    m_xdl_per_wave: int
    n_xdl_per_wave: int
    # A-matrix block transfer parameters
    a_block_transfer_thread_cluster_lengths_ak0_m_ak1: Tuple[int, int, int]
    a_block_transfer_thread_cluster_arrange_order: Tuple[int, int, int]
    a_block_transfer_src_access_order: Tuple[int, int, int]
    a_block_transfer_src_vector_dim: int
    a_block_transfer_src_scalar_per_vector: int
    a_block_transfer_dst_scalar_per_vector_ak1: int
    a_block_lds_extra_m: bool
    # B-matrix block transfer parameters
    b_block_transfer_thread_cluster_lengths_bk0_n_bk1: Tuple[int, int, int]
    b_block_transfer_thread_cluster_arrange_order: Tuple[int, int, int]
    b_block_transfer_src_access_order: Tuple[int, int, int]
    b_block_transfer_src_vector_dim: int
    b_block_transfer_src_scalar_per_vector: int
    b_block_transfer_dst_scalar_per_vector_bk1: int
    b_block_lds_extra_n: bool
    # C-shuffle epilogue parameters
    c_shuffle_m_xdl_per_wave_per_shuffle: int
    c_shuffle_n_xdl_per_wave_per_shuffle: int
    c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block: Tuple[
        int, int, int, int
    ]
    c_shuffle_block_transfer_scalar_per_vector_n_per_block: int
    # pipeline selection (version may be templated/absent)
    block_gemm_pipeline_scheduler: str
    block_gemm_pipeline_version: Optional[str]
    # optional compute dtype overrides
    a_compute_dtype: Optional[str]
    b_compute_dtype: Optional[str]

    def name(self):
        # cpp alias for template instance
        return "ck_devicegemm_xdl_shuffle_v3_" + self.key_name()

    def key_name(self):
        # TBD; must be unique per instance. Intended to use as dict key
        tokens = []
        for field_name, field_value in self.dict_items():
            if isinstance(field_value, tuple):
                # tuples render as their elements joined with 'x', e.g. 8x32x1
                rendered = "x".join(str(v) for v in field_value)
            else:
                # scalars/strings render verbatim with C++ '::' colons stripped
                rendered = str(field_value).replace(":", "")
            tokens.append("K" + field_name.replace("_", "").lower() + "V" + rendered)
        return "_".join(tokens)

    def dict_items(self):
        # (field_name, field_value) pairs in declaration order
        return asdict(self).items()
Prev
1
…
5
6
7
8
9
10
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment