"git@developer.sourcefind.cn:gaoqiong/composable_kernel.git" did not exist on "4349abe1791184c86b0e9dcb2255225b38c408c4"
Commit eda59383 authored by Max Podkorytov's avatar Max Podkorytov
Browse files

add parsing grouped conv fwd instances

parent 7d576f17
# SPDX-License-Identifier: MIT
# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
import logging
import os
import subprocess
from dataclasses import replace
from functools import lru_cache
from typing import List
from ..util import library_path
from .op import CKGroupedConvFwdOp
log = logging.getLogger(__name__)
def _ck_conv_instances_path():
conv_instances_path = os.path.join( # noqa: F821
library_path(),
"include",
"ck",
"library",
"tensor_operation_instance",
"gpu",
"grouped_conv_fwd",
)
if not os.path.exists(conv_instances_path):
log.error(
"CK library conv instances path %s does not exist", conv_instances_path
)
return None
return conv_instances_path
def parse_instances(str_instances: List[str]) -> List[CKGroupedConvFwdOp]:
"""
Parse the lines containing Grouped Convolution Forward template instances
into `CKGroupedConvFwdOp` instances
"""
def maybe_int(s):
try:
return int(s)
except ValueError:
return s
op_instances = []
# TODO: maybe use libclang for parsing C++ code in the future
# to avoid this hacky parsing logic below ? :) - copilot
for line in str_instances:
s_template_args = line.split("DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3")[
-1
].strip("<>, ")
template_args = []
i_current = 0
while i_current < len(s_template_args):
if s_template_args[i_current] == " ":
# skip whitespace
i_current += 1
continue
elif s_template_args[i_current : i_current + 2] == "S<":
# parse template S<Index...>
i_next = s_template_args.find(">", i_current)
template_args.append(
tuple(map(int, s_template_args[i_current + 2 : i_next].split(",")))
)
i_current = i_next + 2
else:
# all string attributes must be either type aliases or global constants in C++
i_next = s_template_args.find(",", i_current)
template_args.append(
maybe_int(
s_template_args[i_current : i_next if i_next != -1 else None]
)
)
if i_next != -1:
i_current = i_next + 1
if i_next == -1:
break
template_args[0] = -1 # n_dim_spatial
template_args[3] = tuple() # ds_layout
template_args[9] = tuple() # ds_element_dtype
new_instance = CKGroupedConvFwdOp(
*template_args, # type: ignore[arg-type]
)
op_instances.append(new_instance)
return op_instances
@lru_cache(None)
def gen_conv_ops_library() -> List[CKGroupedConvFwdOp]:
"""
Parse the Grouped Convolution Forward instances
defined in the Composable Kernel library folder.
"""
ck_library_dir = _ck_conv_instances_path()
if not ck_library_dir:
return []
grep_result = subprocess.run(
[
"grep",
"-inR",
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3",
ck_library_dir,
],
capture_output=True,
text=True,
)
op_instances = parse_instances(grep_result.stdout.strip().split("\n"))
log.debug("ck instances from library: %d", len(op_instances))
schedulers = [
"BlockGemmPipelineScheduler::Intrawave",
"BlockGemmPipelineScheduler::Interwave",
]
conv_specs = [
"ConvolutionForwardSpecialization::Default",
"ConvolutionForwardSpecialization::Filter1x1Pad0",
"ConvolutionForwardSpecialization::Filter1x1Stride1Pad0",
"ConvolutionForwardSpecialization::OddC",
]
# substitute templated args by looping through their domains
substitute_instances = []
for instance in op_instances:
sub_scheduler = (
instance.block_gemm_pipeline_scheduler == "BlkGemmPipeSched"
)
sub_spec = instance.conv_forward_specialization == "ConvSpec"
schedulers_range = (
schedulers if sub_scheduler else [instance.block_gemm_pipeline_scheduler]
)
spec_range = conv_specs if sub_spec else [instance.conv_forward_specialization]
for scheduler in schedulers_range:
for spec in spec_range:
for channels_last in [True, False]:
if channels_last:
a_layout = "NHWGC"
e_layout = "NHWGK"
else:
a_layout = "NGCHW"
e_layout = "NGKHW"
substitute_instances.append(
replace(
instance,
block_gemm_pipeline_scheduler=scheduler,
conv_forward_specialization=spec,
gemm_specialization="GemmSpecialization::MNKPadding",
n_dim_spatial=2,
a_layout=a_layout,
b_layout="GKYXC",
e_layout=e_layout,
)
)
return substitute_instances
if __name__ == "__main__":
print(gen_conv_ops_library())
# SPDX-License-Identifier: MIT
# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
from dataclasses import asdict, dataclass
from typing import Optional, Tuple
@dataclass
class CKGroupedConvFwdOp:
n_dim_spatial: int
a_layout: str
b_layout: str
ds_layout: Tuple[str]
e_layout: str
a_element_dtype: str
b_element_dtype: str
acc_dtype: str
c_shuffle_dtype: str
ds_element_dtype: Tuple[str]
e_element_dtype: str
a_elementwise_op: str
b_elementwise_op: str
cde_elementwise_op: str
conv_forward_specialization: str
gemm_specialization: str
block_size: int
m_per_block: int
n_per_block: int
k_per_block: int
ak1: int
bk1: int
m_per_xdl: int
n_per_xdl: int
m_xdl_per_wave: int
n_xdl_per_wave: int
a_block_transfer_thread_cluster_lengths_ak0_m_ak1: Tuple[int, int, int]
a_block_transfer_thread_cluster_arrange_order: Tuple[int, int, int]
a_block_transfer_src_access_order: Tuple[int, int, int]
a_block_transfer_src_vector_dim: int
a_block_transfer_src_scalar_per_vector: int
a_block_transfer_dst_scalar_per_vector_ak1: int
a_block_lds_extra_m: bool
b_block_transfer_thread_cluster_lengths_bk0_n_bk1: Tuple[int, int, int]
b_block_transfer_thread_cluster_arrange_order: Tuple[int, int, int]
b_block_transfer_src_access_order: Tuple[int, int, int]
b_block_transfer_src_vector_dim: int
b_block_transfer_src_scalar_per_vector: int
b_block_transfer_dst_scalar_per_vector_bk1: int
b_block_lds_extra_n: bool
c_shuffle_m_xdl_per_wave_per_shuffle: int
c_shuffle_n_xdl_per_wave_per_shuffle: int
cde_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block: Tuple[ # noqa
int,
int,
int,
int,
]
cde_block_transfer_scalar_per_vector_n_per_block: int
block_gemm_pipeline_scheduler: str
block_gemm_pipeline_version: str
a_compute_dtype: Optional[str] = None
b_compute_dtype: Optional[str] = None
def name(self):
# cpp alias for template instance
return (
f"ck_device_grouped_convolution_fwd_multiple_abd_xdl_c_shuffle_v3_"
f"{self.key_name()}"
)
def key_name(self):
# TBD; must be unique per instance. Intended to use as dict key
return "_".join(
[
"K"
+ field_name.replace("_", "").lower()
+ "V"
+ (
"x".join(map(str, iter(field_value)))
if isinstance(field_value, tuple)
else str(field_value).replace(":", "")
)
for field_name, field_value in self.dict_items()
]
)
def dict_items(self):
return asdict(self).items()
# SPDX-License-Identifier: MIT
# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
import logging import logging
import os import os
import subprocess import subprocess
from dataclasses import fields, replace from dataclasses import replace
from functools import lru_cache, partial from functools import lru_cache, partial
from typing import List from typing import List
......
# SPDX-License-Identifier: MIT
# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
from dataclasses import asdict, dataclass from dataclasses import asdict, dataclass
from typing import Optional, Tuple from typing import Optional, Tuple
......
# SPDX-License-Identifier: MIT
# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
import functools import functools
import os import os
@functools.lru_cache(None) @functools.lru_cache(None)
def library_path(): def library_path():
return os.path.join(os.path.dirname(__file__), 'library') return os.path.join(os.path.dirname(__file__), "library")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment