Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
fbd9d357
Unverified
Commit
fbd9d357
authored
May 28, 2024
by
Illia Silin
Committed by
GitHub
May 28, 2024
Browse files
Merge pull request #68 from ROCm/merge_from_public
Merge from public
parents
395b155a
22593e25
Changes
55
Hide whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
907 additions
and
153 deletions
+907
-153
library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_multiply_bias_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp
..._multiply_bias_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp
+1
-106
library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_multiply_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp
...i_abd_multiply_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp
+58
-0
library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/CMakeLists.txt
...ion_instance/gpu/grouped_conv2d_bwd_weight/CMakeLists.txt
+3
-1
library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev2_instance.cpp
...t_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev2_instance.cpp
+4
-11
library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev5_instance.cpp
...t_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev5_instance.cpp
+41
-0
library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/CMakeLists.txt
...ion_instance/gpu/grouped_conv3d_bwd_weight/CMakeLists.txt
+10
-8
library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev2_instance.cpp
...wo_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev2_instance.cpp
+4
-11
library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev5_instance.cpp
...wo_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev5_instance.cpp
+41
-0
pyproject.toml
pyproject.toml
+36
-0
python/ck4inductor/__init__.py
python/ck4inductor/__init__.py
+0
-0
python/ck4inductor/universal_gemm/gen_instances.py
python/ck4inductor/universal_gemm/gen_instances.py
+570
-0
python/ck4inductor/universal_gemm/op.py
python/ck4inductor/universal_gemm/op.py
+95
-0
python/ck4inductor/util.py
python/ck4inductor/util.py
+7
-0
test/CMakeLists.txt
test/CMakeLists.txt
+31
-3
test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp
...uped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp
+6
-13
No files found.
library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_multiply_bias_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp
View file @
fbd9d357
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
...
...
@@ -52,111 +52,6 @@ void add_device_gemm_xdl_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_bias_gelu_v1_i
Interwave
>
{});
}
void
add_device_gemm_xdl_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_bias_v1_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleABD
<
AsLayout
,
ck
::
Tuple
<
B0Layout
>
,
ck
::
Tuple
<
D0Layout
,
B1Layout
>
,
ELayout
,
AsDataType
,
ck
::
Tuple
<
B0DataType
>
,
ck
::
Tuple
<
D0DataType
,
B1DataType
>
,
EDataType
,
AElementOp
,
PassThrough
,
MultiplyAdd
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_comp_instances
<
ck
::
Tuple
<
B0Layout
>
,
ck
::
Tuple
<
D0Layout
,
B1Layout
>
,
ck
::
Tuple
<
B0DataType
>
,
ck
::
Tuple
<
D0DataType
,
B1DataType
>
,
PassThrough
,
MultiplyAdd
,
GemmMNKPadding
,
Interwave
>
{});
add_device_operation_instances
(
instances
,
device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances
<
ck
::
Tuple
<
B0Layout
>
,
ck
::
Tuple
<
D0Layout
,
B1Layout
>
,
ck
::
Tuple
<
B0DataType
>
,
ck
::
Tuple
<
D0DataType
,
B1DataType
>
,
PassThrough
,
MultiplyAdd
,
GemmMNKPadding
,
Interwave
>
{});
}
void
add_device_gemm_xdl_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_v1_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleABD
<
AsLayout
,
ck
::
Tuple
<
B0Layout
>
,
ck
::
Tuple
<
B1Layout
>
,
ELayout
,
AsDataType
,
ck
::
Tuple
<
B0DataType
>
,
ck
::
Tuple
<
B1DataType
>
,
EDataType
,
AElementOp
,
PassThrough
,
Multiply
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_comp_instances
<
ck
::
Tuple
<
B0Layout
>
,
ck
::
Tuple
<
B1Layout
>
,
ck
::
Tuple
<
B0DataType
>
,
ck
::
Tuple
<
B1DataType
>
,
PassThrough
,
Multiply
,
GemmMNKPadding
,
Interwave
>
{});
add_device_operation_instances
(
instances
,
device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances
<
ck
::
Tuple
<
B0Layout
>
,
ck
::
Tuple
<
B1Layout
>
,
ck
::
Tuple
<
B0DataType
>
,
ck
::
Tuple
<
B1DataType
>
,
PassThrough
,
Multiply
,
GemmMNKPadding
,
Interwave
>
{});
}
void
add_device_gemm_xdl_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_gelu_v1_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleABD
<
AsLayout
,
ck
::
Tuple
<
B0Layout
>
,
ck
::
Tuple
<
B1Layout
>
,
ELayout
,
AsDataType
,
ck
::
Tuple
<
B0DataType
>
,
ck
::
Tuple
<
B1DataType
>
,
EDataType
,
AElementOp
,
PassThrough
,
MultiplyFastGelu
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_comp_instances
<
ck
::
Tuple
<
B0Layout
>
,
ck
::
Tuple
<
B1Layout
>
,
ck
::
Tuple
<
B0DataType
>
,
ck
::
Tuple
<
B1DataType
>
,
PassThrough
,
MultiplyFastGelu
,
GemmMNKPadding
,
Interwave
>
{});
add_device_operation_instances
(
instances
,
device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances
<
ck
::
Tuple
<
B0Layout
>
,
ck
::
Tuple
<
B1Layout
>
,
ck
::
Tuple
<
B0DataType
>
,
ck
::
Tuple
<
B1DataType
>
,
PassThrough
,
MultiplyFastGelu
,
GemmMNKPadding
,
Interwave
>
{});
}
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
...
...
library/src/tensor_operation_instance/gpu/gemm_multi_abd/device_gemm_xdl_multi_abd_multiply_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp
0 → 100644
View file @
fbd9d357
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp"
#include "device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_common.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
void
add_device_gemm_xdl_multi_abd_multiply_bf16_i8_bf16_mk_kn_mn_gelu_v1_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleABD
<
AsLayout
,
ck
::
Tuple
<
B0Layout
>
,
ck
::
Tuple
<
B1Layout
>
,
ELayout
,
AsDataType
,
ck
::
Tuple
<
B0DataType
>
,
ck
::
Tuple
<
B1DataType
>
,
EDataType
,
AElementOp
,
PassThrough
,
MultiplyFastGelu
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_comp_instances
<
ck
::
Tuple
<
B0Layout
>
,
ck
::
Tuple
<
B1Layout
>
,
ck
::
Tuple
<
B0DataType
>
,
ck
::
Tuple
<
B1DataType
>
,
PassThrough
,
MultiplyFastGelu
,
GemmMNKPadding
,
Interwave
>
{});
add_device_operation_instances
(
instances
,
device_gemm_xdl_multi_abd_bf16_i8_bf16_mk_kn_mn_mem_instances
<
ck
::
Tuple
<
B0Layout
>
,
ck
::
Tuple
<
B1Layout
>
,
ck
::
Tuple
<
B0DataType
>
,
ck
::
Tuple
<
B1DataType
>
,
PassThrough
,
MultiplyFastGelu
,
GemmMNKPadding
,
Interwave
>
{});
}
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/CMakeLists.txt
View file @
fbd9d357
...
...
@@ -6,7 +6,9 @@ set(GROUPED_CONV2D_BWD_WEIGHT
xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp
xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp
xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp
xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp
)
xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev2_instance.cpp
xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev5_instance.cpp
)
if
(
DL_KERNELS
)
list
(
APPEND GROUPED_CONV2D_BWD_WEIGHT
...
...
library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp
→
library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_
pipev2_
instance.cpp
View file @
fbd9d357
...
...
@@ -10,7 +10,7 @@ namespace device {
namespace
instance
{
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
void
add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_instances
(
void
add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_
pipev2_
instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvBwdWeight
<
2
,
NHWGC
,
GKYXC
,
...
...
@@ -30,16 +30,9 @@ void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_in
NHWGC
,
GKYXC
,
NHWGK
,
ConvBwdWeightDefault
>
{});
// 2. Filter1x1Stride1Pad0
add_device_operation_instances
(
instances
,
device_grouped_conv_bwd_weight_two_stage_xdl_c_shuffle_f16_instances
<
2
,
NHWGC
,
GKYXC
,
NHWGK
,
ConvBwdWeightFilter1x1Stride1Pad0
>
{});
ConvBwdWeightDefault
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v2
>
{});
}
}
// namespace instance
...
...
library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev5_instance.cpp
0 → 100644
View file @
fbd9d357
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
void
add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev5_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvBwdWeight
<
2
,
NHWGC
,
GKYXC
,
NHWGK
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
)
{
// 1. Default
add_device_operation_instances
(
instances
,
device_grouped_conv_bwd_weight_two_stage_xdl_c_shuffle_f16_instances
<
2
,
NHWGC
,
GKYXC
,
NHWGK
,
ConvBwdWeightDefault
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v5
>
{});
}
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/CMakeLists.txt
View file @
fbd9d357
# XDL_DL_WMMA_KERNELS
# XDL_DL_WMMA_KERNELS
set
(
GROUPED_CONV3D_BWD_WEIGHT
xdl/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp
xdl/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp
xdl/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_instance.cpp
xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp
xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
)
xdl/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp
xdl/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp
xdl/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_instance.cpp
xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp
xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev2_instance.cpp
xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev5_instance.cpp
)
if
(
DL_KERNELS
)
list
(
APPEND GROUPED_CONV3D_BWD_WEIGHT
...
...
library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
→
library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_
pipev2_
instance.cpp
View file @
fbd9d357
...
...
@@ -10,7 +10,7 @@ namespace device {
namespace
instance
{
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
void
add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances
(
void
add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_
pipev2_
instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvBwdWeight
<
3
,
NDHWGC
,
GKZYXC
,
...
...
@@ -30,16 +30,9 @@ void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16
NDHWGC
,
GKZYXC
,
NDHWGK
,
ConvBwdWeightDefault
>
{});
// 2. Filter1x1Stride1Pad0
add_device_operation_instances
(
instances
,
device_grouped_conv_bwd_weight_two_stage_xdl_c_shuffle_f16_instances
<
3
,
NDHWGC
,
GKZYXC
,
NDHWGK
,
ConvBwdWeightFilter1x1Stride1Pad0
>
{});
ConvBwdWeightDefault
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v2
>
{});
}
}
// namespace instance
...
...
library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev5_instance.cpp
0 → 100644
View file @
fbd9d357
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k]
void
add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev5_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvBwdWeight
<
3
,
NDHWGC
,
GKZYXC
,
NDHWGK
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
)
{
// 1. Default
add_device_operation_instances
(
instances
,
device_grouped_conv_bwd_weight_two_stage_xdl_c_shuffle_f16_instances
<
3
,
NDHWGC
,
GKZYXC
,
NDHWGK
,
ConvBwdWeightDefault
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v5
>
{});
}
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
pyproject.toml
0 → 100644
View file @
fbd9d357
[build-system]
requires
=
[
"setuptools"
,
"setuptools-scm"
]
build-backend
=
"setuptools.build_meta"
[project]
name
=
"rocm-composable-kernel"
dynamic
=
["version"]
description
=
"Composable Kernel, performance-critical kernels for machine learning workloads"
readme
=
"README.md"
requires-python
=
">=3.8"
license
=
{
file
=
"LICENSE"
}
classifiers
=
[
"Programming Language :: Python :: 3"
,
"License :: OSI Approved :: MIT License"
,
"Operating System :: OS Independent"
,
]
dependencies
=
[]
[project.urls]
"Homepage"
=
"https://github.com/rocm/composable_kernel"
"Bug
Tracker"
=
"https://github.com/rocm/composable_kernel/issues"
[tool.setuptools]
packages
=
[
"ck4inductor"
,
"ck4inductor.include"
,
"ck4inductor.library"
]
[tool.setuptools.package-dir]
ck4inductor
=
"python/ck4inductor"
"ck4inductor.include"
=
"include"
"ck4inductor.library"
=
"library"
[tool.setuptools.package-data]
"ck4inductor.include"
=
["ck/**/*.hpp"]
"ck4inductor.library"
=
["src/tensor_operation_instance/gpu/gemm_universal/**/*.hpp"]
[tool.setuptools.dynamic]
version
=
{
attr
=
"setuptools_scm.get_version"
}
python/ck4inductor/__init__.py
0 → 100644
View file @
fbd9d357
python/ck4inductor/universal_gemm/gen_instances.py
0 → 100644
View file @
fbd9d357
import
logging
import
os
import
subprocess
from
dataclasses
import
fields
,
replace
from
functools
import
lru_cache
,
partial
from
typing
import
List
from
..util
import
library_path
from
.op
import
CKGemmOperation
log
=
logging
.
getLogger
(
__name__
)
def
_ck_library_dir
():
gemm_instances_path
=
os
.
path
.
join
(
library_path
(),
"src"
,
"tensor_operation_instance"
,
"gpu"
,
"gemm_universal"
)
if
not
os
.
path
.
exists
(
gemm_instances_path
):
log
.
error
(
"CK library path %s does not exist"
,
gemm_instances_path
)
return
None
return
gemm_instances_path
def
parse_instances
(
str_instances
:
List
[
str
])
->
List
[
CKGemmOperation
]:
"""
Parse the lines containing Universal Gemm template instances into `CKGemmOperation` instances
"""
def
maybe_int
(
s
):
try
:
return
int
(
s
)
except
ValueError
:
return
s
op_instances
=
[]
for
line
in
str_instances
:
s_template_args
=
line
.
split
(
"DeviceGemm_Xdl_CShuffleV3"
)[
-
1
].
strip
(
"<>, "
)
template_args
=
[]
i_current
=
0
while
i_current
<
len
(
s_template_args
):
if
s_template_args
[
i_current
]
==
" "
:
# skip whitespace
i_current
+=
1
continue
elif
s_template_args
[
i_current
:
i_current
+
2
]
==
"S<"
:
# parse template S<Index...>
i_next
=
s_template_args
.
find
(
">"
,
i_current
)
template_args
.
append
(
tuple
(
map
(
int
,
s_template_args
[
i_current
+
2
:
i_next
].
split
(
","
)))
)
i_current
=
i_next
+
2
else
:
# all string attributes must be either type aliases or global constants in C++
i_next
=
s_template_args
.
find
(
","
,
i_current
)
template_args
.
append
(
maybe_int
(
s_template_args
[
i_current
:
i_next
if
i_next
!=
-
1
else
None
]
)
)
if
i_next
!=
-
1
:
i_current
=
i_next
+
1
if
i_next
==
-
1
:
break
# pad with `None`s for the fields which are not defined in the instance
new_instance
=
CKGemmOperation
(
*
template_args
,
# type: ignore[arg-type]
*
((
None
,)
*
(
len
(
fields
(
CKGemmOperation
))
-
len
(
template_args
))),
)
# the last 2 template parameters are optional
# if they are absent, substitute them with default values from Universal Gemm C++ template declaration
if
new_instance
.
a_compute_dtype
is
None
:
new_instance
.
a_compute_dtype
=
new_instance
.
c_element_dtype
if
new_instance
.
b_compute_dtype
is
None
:
new_instance
.
b_compute_dtype
=
new_instance
.
c_element_dtype
op_instances
.
append
(
new_instance
)
return
op_instances
def
default_instances
()
->
List
[
CKGemmOperation
]:
# fallback: known working op instance for problem size M=2240 K=256 N=2048
# all string attributes must be either type aliases or global constants in C++
return
[
CKGemmOperation
(
a_layout
=
"Row"
,
b_layout
=
"Row"
,
c_layout
=
"Row"
,
a_element_dtype
=
"F16"
,
b_element_dtype
=
"F16"
,
c_element_dtype
=
"F16"
,
a_compute_dtype
=
"F16"
,
b_compute_dtype
=
"F16"
,
acc_dtype
=
"F32"
,
c_shuffle_dtype
=
"F16"
,
a_elementwise_op
=
"PassThrough"
,
b_elementwise_op
=
"PassThrough"
,
c_elementwise_op
=
"PassThrough"
,
gemm_specialization
=
"GemmSpecialization::Default"
,
block_size
=
256
,
m_per_block
=
224
,
n_per_block
=
256
,
k_per_block
=
64
,
a_k1
=
8
,
b_k1
=
2
,
m_per_xdl
=
16
,
n_per_xdl
=
16
,
m_xdl_per_wave
=
7
,
n_xdl_per_wave
=
8
,
a_block_transfer_thread_cluster_lengths_ak0_m_ak1
=
(
8
,
32
,
1
),
a_block_transfer_thread_cluster_arrange_order
=
(
1
,
0
,
2
),
a_block_transfer_src_access_order
=
(
1
,
0
,
2
),
a_block_transfer_src_vector_dim
=
2
,
a_block_transfer_src_scalar_per_vector
=
8
,
a_block_transfer_dst_scalar_per_vector_ak1
=
8
,
a_block_lds_extra_m
=
0
,
# type: ignore[arg-type]
b_block_transfer_thread_cluster_lengths_bk0_n_bk1
=
(
8
,
32
,
1
),
b_block_transfer_thread_cluster_arrange_order
=
(
0
,
2
,
1
),
b_block_transfer_src_access_order
=
(
0
,
2
,
1
),
b_block_transfer_src_vector_dim
=
1
,
b_block_transfer_src_scalar_per_vector
=
8
,
b_block_transfer_dst_scalar_per_vector_bk1
=
2
,
b_block_lds_extra_n
=
0
,
# type: ignore[arg-type]
c_shuffle_m_xdl_per_wave_per_shuffle
=
1
,
c_shuffle_n_xdl_per_wave_per_shuffle
=
2
,
c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block
=
(
1
,
32
,
1
,
8
,
),
c_shuffle_block_transfer_scalar_per_vector_n_per_block
=
8
,
block_gemm_pipeline_scheduler
=
"BlockGemmPipelineScheduler::Intrawave"
,
block_gemm_pipeline_version
=
"BlockGemmPipelineVersion::v3"
,
)
]
@
lru_cache
(
None
)
def
gen_ops_library
()
->
List
[
CKGemmOperation
]:
"""
Parse the Universal Gemm instances defined in the composable kernel library folder.
"""
ck_library_dir
=
_ck_library_dir
()
if
not
ck_library_dir
:
return
[]
grep_result
=
subprocess
.
run
(
[
"grep"
,
"-inR"
,
"DeviceGemm_Xdl_CShuffleV3"
,
_ck_library_dir
(),
],
capture_output
=
True
,
text
=
True
,
)
op_instances
=
parse_instances
(
grep_result
.
stdout
.
strip
().
split
(
"
\n
"
))
log
.
debug
(
"ck instances from library: %d"
,
len
(
op_instances
))
schedulers
=
[
"BlockGemmPipelineScheduler::Intrawave"
,
"BlockGemmPipelineScheduler::Interwave"
,
]
gemm_specs
=
[
"GemmSpecialization::Default"
,
"GemmSpecialization::MPadding"
,
"GemmSpecialization::NPadding"
,
"GemmSpecialization::KPadding"
,
"GemmSpecialization::MNPadding"
,
"GemmSpecialization::MKPadding"
,
"GemmSpecialization::NKPadding"
,
"GemmSpecialization::MNKPadding"
,
]
# substitute templated args by looping through their domains
substitute_instances
=
[]
for
instance
in
op_instances
:
sub_scheduler
=
instance
.
block_gemm_pipeline_scheduler
==
"BlkGemmPipeSched"
sub_spec
=
instance
.
gemm_specialization
==
"GemmSpec"
schedulers_range
=
(
schedulers
if
sub_scheduler
else
[
instance
.
block_gemm_pipeline_scheduler
]
)
spec_range
=
gemm_specs
if
sub_spec
else
[
instance
.
gemm_specialization
]
for
scheduler
in
schedulers_range
:
for
spec
in
spec_range
:
substitute_instances
.
append
(
replace
(
instance
,
block_gemm_pipeline_scheduler
=
scheduler
,
gemm_specialization
=
spec
,
)
)
return
substitute_instances
@
lru_cache
(
None
)
def
gen_ops_preselected
()
->
List
[
CKGemmOperation
]:
"""
Manually selected (through benchmarking) F16/F16/F16 Row/Col/Row instances
"""
ck_gemm_f16_rcr
=
partial
(
CKGemmOperation
,
a_layout
=
"Row"
,
b_layout
=
"Col"
,
c_layout
=
"Row"
,
a_element_dtype
=
"F16"
,
b_element_dtype
=
"F16"
,
c_element_dtype
=
"F16"
,
acc_dtype
=
"F32"
,
c_shuffle_dtype
=
"F16"
,
a_elementwise_op
=
"PassThrough"
,
b_elementwise_op
=
"PassThrough"
,
c_elementwise_op
=
"PassThrough"
,
k_per_block
=
64
,
a_k1
=
8
,
b_k1
=
8
,
a_block_transfer_thread_cluster_arrange_order
=
(
1
,
0
,
2
),
a_block_transfer_src_access_order
=
(
1
,
0
,
2
),
a_block_transfer_src_vector_dim
=
2
,
a_block_transfer_src_scalar_per_vector
=
8
,
a_block_transfer_dst_scalar_per_vector_ak1
=
8
,
a_block_lds_extra_m
=
0
,
b_block_transfer_thread_cluster_arrange_order
=
(
1
,
0
,
2
),
b_block_transfer_src_access_order
=
(
1
,
0
,
2
),
b_block_transfer_src_vector_dim
=
2
,
b_block_transfer_src_scalar_per_vector
=
8
,
b_block_transfer_dst_scalar_per_vector_bk1
=
8
,
b_block_lds_extra_n
=
0
,
a_compute_dtype
=
"F16"
,
b_compute_dtype
=
"F16"
,
)
ck_gemm_f16_rcr_compute_friendly
=
partial
(
ck_gemm_f16_rcr
,
block_size
=
256
,
a_block_transfer_thread_cluster_lengths_ak0_m_ak1
=
(
8
,
32
,
1
),
b_block_transfer_thread_cluster_lengths_bk0_n_bk1
=
(
8
,
32
,
1
),
c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block
=
(
1
,
32
,
1
,
8
,
),
c_shuffle_block_transfer_scalar_per_vector_n_per_block
=
8
,
)
ck_gemm_f16_rcr_memory_friendly
=
partial
(
ck_gemm_f16_rcr
,
block_size
=
128
,
a_block_transfer_thread_cluster_lengths_ak0_m_ak1
=
(
8
,
16
,
1
),
b_block_transfer_thread_cluster_lengths_bk0_n_bk1
=
(
8
,
16
,
1
),
block_gemm_pipeline_scheduler
=
"BlockGemmPipelineScheduler::Interwave"
,
block_gemm_pipeline_version
=
"BlockGemmPipelineVersion::v2"
,
)
ck_gemm_f16_rcr_latency_friendly
=
partial
(
ck_gemm_f16_rcr
,
gemm_specialization
=
"GemmSpecialization::Default"
,
block_size
=
128
,
m_per_xdl
=
16
,
n_per_xdl
=
16
,
m_xdl_per_wave
=
1
,
n_xdl_per_wave
=
1
,
a_block_transfer_thread_cluster_lengths_ak0_m_ak1
=
(
8
,
16
,
1
),
b_block_transfer_thread_cluster_lengths_bk0_n_bk1
=
(
8
,
16
,
1
),
c_shuffle_m_xdl_per_wave_per_shuffle
=
1
,
c_shuffle_n_xdl_per_wave_per_shuffle
=
1
,
c_shuffle_block_transfer_scalar_per_vector_n_per_block
=
4
,
block_gemm_pipeline_scheduler
=
"BlockGemmPipelineScheduler::Intrawave"
,
block_gemm_pipeline_version
=
"BlockGemmPipelineVersion::v1"
,
)
return
[
ck_gemm_f16_rcr_compute_friendly
(
gemm_specialization
=
"GemmSpecialization::MNKPadding"
,
m_per_block
=
224
,
n_per_block
=
256
,
m_per_xdl
=
16
,
n_per_xdl
=
16
,
m_xdl_per_wave
=
7
,
n_xdl_per_wave
=
8
,
c_shuffle_m_xdl_per_wave_per_shuffle
=
1
,
c_shuffle_n_xdl_per_wave_per_shuffle
=
2
,
block_gemm_pipeline_scheduler
=
"BlockGemmPipelineScheduler::Intrawave"
,
block_gemm_pipeline_version
=
"BlockGemmPipelineVersion::v3"
,
),
ck_gemm_f16_rcr_compute_friendly
(
gemm_specialization
=
"GemmSpecialization::MNKPadding"
,
m_per_block
=
128
,
n_per_block
=
128
,
m_per_xdl
=
32
,
n_per_xdl
=
32
,
m_xdl_per_wave
=
2
,
n_xdl_per_wave
=
2
,
c_shuffle_m_xdl_per_wave_per_shuffle
=
1
,
c_shuffle_n_xdl_per_wave_per_shuffle
=
1
,
block_gemm_pipeline_scheduler
=
"BlockGemmPipelineScheduler::Intrawave"
,
block_gemm_pipeline_version
=
"BlockGemmPipelineVersion::v3"
,
),
ck_gemm_f16_rcr_compute_friendly
(
gemm_specialization
=
"GemmSpecialization::MNKPadding"
,
m_per_block
=
128
,
n_per_block
=
128
,
m_per_xdl
=
32
,
n_per_xdl
=
32
,
m_xdl_per_wave
=
2
,
n_xdl_per_wave
=
2
,
c_shuffle_m_xdl_per_wave_per_shuffle
=
1
,
c_shuffle_n_xdl_per_wave_per_shuffle
=
1
,
block_gemm_pipeline_scheduler
=
"BlockGemmPipelineScheduler::Intrawave"
,
block_gemm_pipeline_version
=
"BlockGemmPipelineVersion::v4"
,
),
ck_gemm_f16_rcr_compute_friendly
(
gemm_specialization
=
"GemmSpecialization::MNKPadding"
,
m_per_block
=
128
,
n_per_block
=
128
,
m_per_xdl
=
32
,
n_per_xdl
=
32
,
m_xdl_per_wave
=
2
,
n_xdl_per_wave
=
2
,
c_shuffle_m_xdl_per_wave_per_shuffle
=
1
,
c_shuffle_n_xdl_per_wave_per_shuffle
=
1
,
block_gemm_pipeline_scheduler
=
"BlockGemmPipelineScheduler::Intrawave"
,
block_gemm_pipeline_version
=
"BlockGemmPipelineVersion::v5"
,
),
ck_gemm_f16_rcr_compute_friendly
(
gemm_specialization
=
"GemmSpecialization::Default"
,
m_per_block
=
128
,
n_per_block
=
128
,
m_per_xdl
=
32
,
n_per_xdl
=
32
,
m_xdl_per_wave
=
2
,
n_xdl_per_wave
=
2
,
c_shuffle_m_xdl_per_wave_per_shuffle
=
1
,
c_shuffle_n_xdl_per_wave_per_shuffle
=
1
,
block_gemm_pipeline_scheduler
=
"BlockGemmPipelineScheduler::Intrawave"
,
block_gemm_pipeline_version
=
"BlockGemmPipelineVersion::v3"
,
),
ck_gemm_f16_rcr_compute_friendly
(
gemm_specialization
=
"GemmSpecialization::Default"
,
m_per_block
=
128
,
n_per_block
=
128
,
m_per_xdl
=
32
,
n_per_xdl
=
32
,
m_xdl_per_wave
=
2
,
n_xdl_per_wave
=
2
,
c_shuffle_m_xdl_per_wave_per_shuffle
=
1
,
c_shuffle_n_xdl_per_wave_per_shuffle
=
1
,
block_gemm_pipeline_scheduler
=
"BlockGemmPipelineScheduler::Intrawave"
,
block_gemm_pipeline_version
=
"BlockGemmPipelineVersion::v4"
,
),
ck_gemm_f16_rcr_compute_friendly
(
gemm_specialization
=
"GemmSpecialization::Default"
,
m_per_block
=
128
,
n_per_block
=
128
,
m_per_xdl
=
32
,
n_per_xdl
=
32
,
m_xdl_per_wave
=
2
,
n_xdl_per_wave
=
2
,
c_shuffle_m_xdl_per_wave_per_shuffle
=
1
,
c_shuffle_n_xdl_per_wave_per_shuffle
=
1
,
block_gemm_pipeline_scheduler
=
"BlockGemmPipelineScheduler::Intrawave"
,
block_gemm_pipeline_version
=
"BlockGemmPipelineVersion::v5"
,
),
ck_gemm_f16_rcr_memory_friendly
(
gemm_specialization
=
"GemmSpecialization::Default"
,
m_per_block
=
16
,
n_per_block
=
32
,
m_per_xdl
=
16
,
n_per_xdl
=
16
,
m_xdl_per_wave
=
1
,
n_xdl_per_wave
=
1
,
c_shuffle_m_xdl_per_wave_per_shuffle
=
1
,
c_shuffle_n_xdl_per_wave_per_shuffle
=
1
,
c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block
=
(
1
,
16
,
1
,
8
,
),
c_shuffle_block_transfer_scalar_per_vector_n_per_block
=
4
,
),
ck_gemm_f16_rcr_memory_friendly
(
gemm_specialization
=
"GemmSpecialization::MNKPadding"
,
m_per_block
=
16
,
n_per_block
=
32
,
m_per_xdl
=
16
,
n_per_xdl
=
16
,
m_xdl_per_wave
=
1
,
n_xdl_per_wave
=
1
,
c_shuffle_m_xdl_per_wave_per_shuffle
=
1
,
c_shuffle_n_xdl_per_wave_per_shuffle
=
1
,
c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block
=
(
1
,
16
,
1
,
8
,
),
c_shuffle_block_transfer_scalar_per_vector_n_per_block
=
4
,
),
ck_gemm_f16_rcr_memory_friendly
(
gemm_specialization
=
"GemmSpecialization::MNKPadding"
,
m_per_block
=
16
,
n_per_block
=
64
,
m_per_xdl
=
16
,
n_per_xdl
=
16
,
m_xdl_per_wave
=
1
,
n_xdl_per_wave
=
2
,
c_shuffle_m_xdl_per_wave_per_shuffle
=
1
,
c_shuffle_n_xdl_per_wave_per_shuffle
=
2
,
c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block
=
(
1
,
16
,
1
,
8
,
),
c_shuffle_block_transfer_scalar_per_vector_n_per_block
=
8
,
),
ck_gemm_f16_rcr_memory_friendly
(
gemm_specialization
=
"GemmSpecialization::MNKPadding"
,
m_per_block
=
32
,
n_per_block
=
64
,
m_per_xdl
=
32
,
n_per_xdl
=
32
,
m_xdl_per_wave
=
1
,
n_xdl_per_wave
=
1
,
c_shuffle_m_xdl_per_wave_per_shuffle
=
1
,
c_shuffle_n_xdl_per_wave_per_shuffle
=
1
,
c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block
=
(
1
,
16
,
1
,
8
,
),
c_shuffle_block_transfer_scalar_per_vector_n_per_block
=
8
,
),
ck_gemm_f16_rcr_memory_friendly
(
gemm_specialization
=
"GemmSpecialization::MNKPadding"
,
m_per_block
=
32
,
n_per_block
=
128
,
m_per_xdl
=
32
,
n_per_xdl
=
32
,
m_xdl_per_wave
=
1
,
n_xdl_per_wave
=
2
,
c_shuffle_m_xdl_per_wave_per_shuffle
=
1
,
c_shuffle_n_xdl_per_wave_per_shuffle
=
1
,
c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block
=
(
1
,
16
,
1
,
8
,
),
c_shuffle_block_transfer_scalar_per_vector_n_per_block
=
8
,
),
ck_gemm_f16_rcr_memory_friendly
(
gemm_specialization
=
"GemmSpecialization::Default"
,
m_per_block
=
32
,
n_per_block
=
16
,
m_per_xdl
=
16
,
n_per_xdl
=
16
,
m_xdl_per_wave
=
1
,
n_xdl_per_wave
=
1
,
c_shuffle_m_xdl_per_wave_per_shuffle
=
1
,
c_shuffle_n_xdl_per_wave_per_shuffle
=
1
,
c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block
=
(
1
,
32
,
1
,
4
,
),
c_shuffle_block_transfer_scalar_per_vector_n_per_block
=
4
,
),
ck_gemm_f16_rcr_memory_friendly
(
gemm_specialization
=
"GemmSpecialization::MNKPadding"
,
m_per_block
=
32
,
n_per_block
=
16
,
m_per_xdl
=
16
,
n_per_xdl
=
16
,
m_xdl_per_wave
=
1
,
n_xdl_per_wave
=
1
,
c_shuffle_m_xdl_per_wave_per_shuffle
=
1
,
c_shuffle_n_xdl_per_wave_per_shuffle
=
1
,
c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block
=
(
1
,
32
,
1
,
4
,
),
c_shuffle_block_transfer_scalar_per_vector_n_per_block
=
4
,
),
ck_gemm_f16_rcr_memory_friendly
(
gemm_specialization
=
"GemmSpecialization::MNKPadding"
,
m_per_block
=
64
,
n_per_block
=
16
,
m_per_xdl
=
16
,
n_per_xdl
=
16
,
m_xdl_per_wave
=
2
,
n_xdl_per_wave
=
1
,
c_shuffle_m_xdl_per_wave_per_shuffle
=
2
,
c_shuffle_n_xdl_per_wave_per_shuffle
=
1
,
c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block
=
(
1
,
64
,
1
,
2
,
),
c_shuffle_block_transfer_scalar_per_vector_n_per_block
=
8
,
),
ck_gemm_f16_rcr_memory_friendly
(
gemm_specialization
=
"GemmSpecialization::MNKPadding"
,
m_per_block
=
64
,
n_per_block
=
32
,
m_per_xdl
=
32
,
n_per_xdl
=
32
,
m_xdl_per_wave
=
1
,
n_xdl_per_wave
=
1
,
c_shuffle_m_xdl_per_wave_per_shuffle
=
1
,
c_shuffle_n_xdl_per_wave_per_shuffle
=
1
,
c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block
=
(
1
,
32
,
1
,
4
,
),
c_shuffle_block_transfer_scalar_per_vector_n_per_block
=
8
,
),
ck_gemm_f16_rcr_memory_friendly
(
gemm_specialization
=
"GemmSpecialization::MNKPadding"
,
m_per_block
=
128
,
n_per_block
=
32
,
m_per_xdl
=
32
,
n_per_xdl
=
32
,
m_xdl_per_wave
=
2
,
n_xdl_per_wave
=
1
,
c_shuffle_m_xdl_per_wave_per_shuffle
=
2
,
c_shuffle_n_xdl_per_wave_per_shuffle
=
1
,
c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block
=
(
1
,
32
,
1
,
4
,
),
c_shuffle_block_transfer_scalar_per_vector_n_per_block
=
8
,
),
ck_gemm_f16_rcr_latency_friendly
(
m_per_block
=
16
,
n_per_block
=
32
,
c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block
=
(
1
,
16
,
1
,
8
,
),
),
ck_gemm_f16_rcr_latency_friendly
(
m_per_block
=
32
,
n_per_block
=
16
,
c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block
=
(
1
,
32
,
1
,
4
,
),
),
]
if
__name__
==
"__main__"
:
print
(
gen_ops_library
())
python/ck4inductor/universal_gemm/op.py
0 → 100644
View file @
fbd9d357
from
dataclasses
import
asdict
,
dataclass
from
typing
import
Optional
,
Tuple
@dataclass
class CKGemmOperation:
    """
    A python dataclass storing the template parameters of a CK Universal Gemm template instance.

    Field order mirrors the template-argument order of the C++
    ``DeviceGemm_Xdl_CShuffleV3`` instance, so iterating the fields in
    declaration order reproduces the template argument list.
    """

    # Matrix layouts (CK layout tag names, e.g. "Row" / "Col").
    a_layout: str
    b_layout: str
    c_layout: str
    # Element data types of the A/B/C matrices.
    a_element_dtype: str
    b_element_dtype: str
    c_element_dtype: str
    # Accumulator and C-shuffle intermediate data types.
    acc_dtype: str
    c_shuffle_dtype: str
    # Elementwise operations applied to A/B inputs and the C output.
    a_elementwise_op: str
    b_elementwise_op: str
    c_elementwise_op: str
    # Padding specialization, e.g. "GemmSpecialization::MNKPadding".
    gemm_specialization: str
    # Threadblock size and per-block tile sizes.
    block_size: int
    m_per_block: int
    n_per_block: int
    k_per_block: int
    # K-dim vector widths for the A/B block transfers.
    a_k1: int
    b_k1: int
    # XDL (MFMA) instruction tile sizes and waves per block.
    m_per_xdl: int
    n_per_xdl: int
    m_xdl_per_wave: int
    n_xdl_per_wave: int
    # A-matrix blockwise-copy parameters.
    a_block_transfer_thread_cluster_lengths_ak0_m_ak1: Tuple[int, int, int]
    a_block_transfer_thread_cluster_arrange_order: Tuple[int, int, int]
    a_block_transfer_src_access_order: Tuple[int, int, int]
    a_block_transfer_src_vector_dim: int
    a_block_transfer_src_scalar_per_vector: int
    a_block_transfer_dst_scalar_per_vector_ak1: int
    a_block_lds_extra_m: bool
    # B-matrix blockwise-copy parameters.
    b_block_transfer_thread_cluster_lengths_bk0_n_bk1: Tuple[int, int, int]
    b_block_transfer_thread_cluster_arrange_order: Tuple[int, int, int]
    b_block_transfer_src_access_order: Tuple[int, int, int]
    b_block_transfer_src_vector_dim: int
    b_block_transfer_src_scalar_per_vector: int
    b_block_transfer_dst_scalar_per_vector_bk1: int
    b_block_lds_extra_n: bool
    # C-shuffle epilogue parameters.
    c_shuffle_m_xdl_per_wave_per_shuffle: int
    c_shuffle_n_xdl_per_wave_per_shuffle: int
    c_shuffle_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block: Tuple[
        int, int, int, int
    ]
    c_shuffle_block_transfer_scalar_per_vector_n_per_block: int
    # Pipeline scheduling knobs; version and compute dtypes are optional
    # (None means the template's default is used).
    block_gemm_pipeline_scheduler: str
    block_gemm_pipeline_version: Optional[str]
    a_compute_dtype: Optional[str]
    b_compute_dtype: Optional[str]

    def name(self):
        """Return the C++ alias used for this template instance."""
        # cpp alias for template instance
        return f"ck_devicegemm_xdl_shuffle_v3_{self.key_name()}"

    def key_name(self):
        """Return a string encoding of all fields, unique per instance.

        Each field is rendered as ``K<fieldname><V><value>`` with
        underscores stripped from the name; tuple values are joined with
        "x", and ``::`` scope separators are stripped from string values.
        Intended for use as a dict key.
        """
        # TBD; must be unique per instance. Intended to use as dict key
        return "_".join(
            [
                "K"
                + field_name.replace("_", "").lower()
                + "V"
                + (
                    # map() accepts any iterable directly; the original
                    # redundant iter() wrapper has been dropped.
                    "x".join(map(str, field_value))
                    if isinstance(field_value, tuple)
                    else str(field_value).replace(":", "")
                )
                for field_name, field_value in self.dict_items()
            ]
        )

    def dict_items(self):
        """Return (field_name, field_value) pairs in declaration order."""
        return asdict(self).items()
python/ck4inductor/util.py
0 → 100644
View file @
fbd9d357
import
functools
import
os
@functools.lru_cache(None)
def library_path():
    """Absolute path of the 'library' directory shipped next to this module.

    Computed once and memoized for the life of the process.
    """
    package_dir = os.path.dirname(__file__)
    return os.path.join(package_dir, 'library')
test/CMakeLists.txt
View file @
fbd9d357
...
...
@@ -40,6 +40,13 @@ function(add_test_executable TEST_NAME)
endif
()
endforeach
()
endif
()
if
(
INSTANCES_ONLY
)
set
(
TEST_TARGETS
${
DEFAULT_GPU_TARGETS
}
)
else
()
set
(
TEST_TARGETS
${
GPU_TARGETS
}
)
endif
()
foreach
(
source IN LISTS ARGN
)
if
(
NOT DEFINED DL_KERNELS AND source MATCHES
"_dl"
)
message
(
"removing dl test
${
source
}
"
)
...
...
@@ -47,20 +54,27 @@ function(add_test_executable TEST_NAME)
endif
()
endforeach
()
foreach
(
source IN LISTS ARGN
)
if
(
NOT
GPU
_TARGETS MATCHES
"gfx9"
AND source MATCHES
"xdl"
)
if
(
NOT
TEST
_TARGETS MATCHES
"gfx9"
AND source MATCHES
"xdl"
)
message
(
"removing xdl test
${
source
}
"
)
list
(
REMOVE_ITEM ARGN
"
${
source
}
"
)
endif
()
endforeach
()
foreach
(
source IN LISTS ARGN
)
if
(
NOT GPU_TARGETS MATCHES
"gfx11"
AND NOT GPU_TARGETS MATCHES
"gfx12"
AND source MATCHES
"wmma"
)
if
(
NOT GPU_TARGETS MATCHES
"gfx11"
AND NOT GPU_TARGETS MATCHES
"gfx12"
AND source MATCHES
"wmma"
)
message
(
"removing wmma test
${
source
}
"
)
list
(
REMOVE_ITEM ARGN
"
${
source
}
"
)
endif
()
endforeach
()
#only continue if there are some source files left on the list
if
(
ARGN
)
if
(
ARGN MATCHES
"_xdl"
)
list
(
REMOVE_ITEM TEST_TARGETS gfx1030 gfx1100 gfx1101 gfx1102 gfx1103
)
elseif
(
ARGN MATCHES
"_wmma"
)
list
(
REMOVE_ITEM TEST_TARGETS gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030
)
endif
()
set_source_files_properties
(
${
ARGN
}
PROPERTIES LANGUAGE HIP
)
add_executable
(
${
TEST_NAME
}
${
ARGN
}
)
set_property
(
TARGET
${
TEST_NAME
}
PROPERTY HIP_ARCHITECTURES
${
TEST_TARGETS
}
)
target_link_libraries
(
${
TEST_NAME
}
PRIVATE getopt::getopt
)
add_test
(
NAME
${
TEST_NAME
}
COMMAND $<TARGET_FILE:
${
TEST_NAME
}
>
)
add_dependencies
(
tests
${
TEST_NAME
}
)
...
...
@@ -105,6 +119,13 @@ function(add_gtest_executable TEST_NAME)
endif
()
endforeach
()
endif
()
if
(
INSTANCES_ONLY
)
set
(
TEST_TARGETS
${
DEFAULT_GPU_TARGETS
}
)
else
()
set
(
TEST_TARGETS
${
GPU_TARGETS
}
)
endif
()
foreach
(
source IN LISTS ARGN
)
if
(
NOT DEFINED DL_KERNELS AND source MATCHES
"_dl"
)
message
(
"removing dl test
${
source
}
"
)
...
...
@@ -112,7 +133,7 @@ function(add_gtest_executable TEST_NAME)
endif
()
endforeach
()
foreach
(
source IN LISTS ARGN
)
if
(
NOT
GPU
_TARGETS MATCHES
"gfx9"
AND source MATCHES
"xdl"
)
if
(
NOT
TEST
_TARGETS MATCHES
"gfx9"
AND source MATCHES
"xdl"
)
message
(
"removing xdl test
${
source
}
"
)
list
(
REMOVE_ITEM ARGN
"
${
source
}
"
)
endif
()
...
...
@@ -125,7 +146,14 @@ function(add_gtest_executable TEST_NAME)
endforeach
()
#only continue if there are some source files left on the list
if
(
ARGN
)
if
(
ARGN MATCHES
"_xdl"
)
list
(
REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103
)
elseif
(
ARGN MATCHES
"_wmma"
)
list
(
REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030
)
endif
()
set_source_files_properties
(
${
ARGN
}
PROPERTIES LANGUAGE HIP
)
add_executable
(
${
TEST_NAME
}
${
ARGN
}
)
set_property
(
TARGET
${
TEST_NAME
}
PROPERTY HIP_ARCHITECTURES
${
TEST_TARGETS
}
)
add_dependencies
(
tests
${
TEST_NAME
}
)
add_dependencies
(
check
${
TEST_NAME
}
)
...
...
test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight.cpp
View file @
fbd9d357
...
...
@@ -32,19 +32,8 @@ class TestGroupedConvndBwdWeight : public ::testing::Test
std
::
vector
<
ck
::
utils
::
conv
::
ConvParam
>
conv_params
;
std
::
vector
<
ck
::
index_t
>
split_ks
{
1
,
2
};
bool
skip_case
(
const
ck
::
utils
::
conv
::
ConvParam
&
params
,
const
ck
::
index_t
split_k
)
bool
skip_case
(
const
ck
::
index_t
split_k
)
{
// Odd K or C values are supported only by DL and WMMA
// kernels (only applies to fp16)
// DL and WMMA kernels currently support only `split_k=1`
if
constexpr
(
std
::
is_same_v
<
InDataType
,
ck
::
half_t
>
)
{
if
(
split_k
!=
1
&&
(
params
.
K_
%
2
!=
0
||
params
.
C_
%
2
!=
0
))
{
return
true
;
}
}
// 1d NWGC is only supported by DL kernel
// DL kernel is only supported for split_k=1
if
constexpr
(
std
::
is_same_v
<
InLayout
,
NWGC
>
&&
std
::
is_same_v
<
OutLayout
,
NWGK
>
)
...
...
@@ -100,7 +89,7 @@ class TestGroupedConvndBwdWeight : public ::testing::Test
{
for
(
auto
&
param
:
conv_params
)
{
if
(
!
skip_case
(
param
,
split_k
))
if
(
!
skip_case
(
split_k
))
{
pass
=
pass
&&
ck
::
profiler
::
profile_grouped_conv_bwd_weight_impl
<
NDimSpatial
{},
InLayout
,
...
...
@@ -189,6 +178,8 @@ TYPED_TEST(TestGroupedConvndBwdWeight2d, Test2D)
this
->
conv_params
.
push_back
({
2
,
1
,
1
,
1
,
32
,
{
3
,
3
},
{
32
,
32
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
}});
this
->
conv_params
.
push_back
({
2
,
1
,
1
,
64
,
3
,
{
3
,
3
},
{
32
,
32
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
}});
this
->
conv_params
.
push_back
({
2
,
1
,
1
,
1
,
1
,
{
3
,
3
},
{
32
,
32
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
}});
this
->
conv_params
.
push_back
(
{
2
,
16
,
16
,
1
,
1
,
{
3
,
3
},
{
28
,
28
},
{
2
,
2
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
}});
this
->
Run
();
}
...
...
@@ -207,5 +198,7 @@ TYPED_TEST(TestGroupedConvndBwdWeight3d, Test3D)
{
3
,
1
,
1
,
64
,
3
,
{
3
,
3
,
3
},
{
32
,
32
,
32
},
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
1
,
1
,
1
}});
this
->
conv_params
.
push_back
(
{
3
,
1
,
1
,
1
,
1
,
{
3
,
3
,
3
},
{
32
,
32
,
32
},
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
1
,
1
,
1
}});
this
->
conv_params
.
push_back
(
{
3
,
16
,
16
,
1
,
1
,
{
3
,
3
,
3
},
{
28
,
28
,
28
},
{
2
,
2
,
2
},
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
1
,
1
,
1
}});
this
->
Run
();
}
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment