Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
eda59383
Commit
eda59383
authored
Oct 18, 2024
by
Max Podkorytov
Browse files
add parsing grouped conv fwd instances
parent
7d576f17
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
271 additions
and
2 deletions
+271
-2
python/ck4inductor/grouped_conv_fwd/gen_instances.py
python/ck4inductor/grouped_conv_fwd/gen_instances.py
+167
-0
python/ck4inductor/grouped_conv_fwd/op.py
python/ck4inductor/grouped_conv_fwd/op.py
+93
-0
python/ck4inductor/universal_gemm/gen_instances.py
python/ck4inductor/universal_gemm/gen_instances.py
+4
-1
python/ck4inductor/universal_gemm/op.py
python/ck4inductor/universal_gemm/op.py
+3
-0
python/ck4inductor/util.py
python/ck4inductor/util.py
+4
-1
No files found.
python/ck4inductor/grouped_conv_fwd/gen_instances.py
0 → 100644
View file @
eda59383
# SPDX-License-Identifier: MIT
# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
import
logging
import
os
import
subprocess
from
dataclasses
import
replace
from
functools
import
lru_cache
from
typing
import
List
from
..util
import
library_path
from
.op
import
CKGroupedConvFwdOp
log
=
logging
.
getLogger
(
__name__
)
def
_ck_conv_instances_path
():
conv_instances_path
=
os
.
path
.
join
(
# noqa: F821
library_path
(),
"include"
,
"ck"
,
"library"
,
"tensor_operation_instance"
,
"gpu"
,
"grouped_conv_fwd"
,
)
if
not
os
.
path
.
exists
(
conv_instances_path
):
log
.
error
(
"CK library conv instances path %s does not exist"
,
conv_instances_path
)
return
None
return
conv_instances_path
def
parse_instances
(
str_instances
:
List
[
str
])
->
List
[
CKGroupedConvFwdOp
]:
"""
Parse the lines containing Grouped Convolution Forward template instances
into `CKGroupedConvFwdOp` instances
"""
def
maybe_int
(
s
):
try
:
return
int
(
s
)
except
ValueError
:
return
s
op_instances
=
[]
# TODO: maybe use libclang for parsing C++ code in the future
# to avoid this hacky parsing logic below ? :) - copilot
for
line
in
str_instances
:
s_template_args
=
line
.
split
(
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3"
)[
-
1
].
strip
(
"<>, "
)
template_args
=
[]
i_current
=
0
while
i_current
<
len
(
s_template_args
):
if
s_template_args
[
i_current
]
==
" "
:
# skip whitespace
i_current
+=
1
continue
elif
s_template_args
[
i_current
:
i_current
+
2
]
==
"S<"
:
# parse template S<Index...>
i_next
=
s_template_args
.
find
(
">"
,
i_current
)
template_args
.
append
(
tuple
(
map
(
int
,
s_template_args
[
i_current
+
2
:
i_next
].
split
(
","
)))
)
i_current
=
i_next
+
2
else
:
# all string attributes must be either type aliases or global constants in C++
i_next
=
s_template_args
.
find
(
","
,
i_current
)
template_args
.
append
(
maybe_int
(
s_template_args
[
i_current
:
i_next
if
i_next
!=
-
1
else
None
]
)
)
if
i_next
!=
-
1
:
i_current
=
i_next
+
1
if
i_next
==
-
1
:
break
template_args
[
0
]
=
-
1
# n_dim_spatial
template_args
[
3
]
=
tuple
()
# ds_layout
template_args
[
9
]
=
tuple
()
# ds_element_dtype
new_instance
=
CKGroupedConvFwdOp
(
*
template_args
,
# type: ignore[arg-type]
)
op_instances
.
append
(
new_instance
)
return
op_instances
@
lru_cache
(
None
)
def
gen_conv_ops_library
()
->
List
[
CKGroupedConvFwdOp
]:
"""
Parse the Grouped Convolution Forward instances
defined in the Composable Kernel library folder.
"""
ck_library_dir
=
_ck_conv_instances_path
()
if
not
ck_library_dir
:
return
[]
grep_result
=
subprocess
.
run
(
[
"grep"
,
"-inR"
,
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3"
,
ck_library_dir
,
],
capture_output
=
True
,
text
=
True
,
)
op_instances
=
parse_instances
(
grep_result
.
stdout
.
strip
().
split
(
"
\n
"
))
log
.
debug
(
"ck instances from library: %d"
,
len
(
op_instances
))
schedulers
=
[
"BlockGemmPipelineScheduler::Intrawave"
,
"BlockGemmPipelineScheduler::Interwave"
,
]
conv_specs
=
[
"ConvolutionForwardSpecialization::Default"
,
"ConvolutionForwardSpecialization::Filter1x1Pad0"
,
"ConvolutionForwardSpecialization::Filter1x1Stride1Pad0"
,
"ConvolutionForwardSpecialization::OddC"
,
]
# substitute templated args by looping through their domains
substitute_instances
=
[]
for
instance
in
op_instances
:
sub_scheduler
=
(
instance
.
block_gemm_pipeline_scheduler
==
"BlkGemmPipeSched"
)
sub_spec
=
instance
.
conv_forward_specialization
==
"ConvSpec"
schedulers_range
=
(
schedulers
if
sub_scheduler
else
[
instance
.
block_gemm_pipeline_scheduler
]
)
spec_range
=
conv_specs
if
sub_spec
else
[
instance
.
conv_forward_specialization
]
for
scheduler
in
schedulers_range
:
for
spec
in
spec_range
:
for
channels_last
in
[
True
,
False
]:
if
channels_last
:
a_layout
=
"NHWGC"
e_layout
=
"NHWGK"
else
:
a_layout
=
"NGCHW"
e_layout
=
"NGKHW"
substitute_instances
.
append
(
replace
(
instance
,
block_gemm_pipeline_scheduler
=
scheduler
,
conv_forward_specialization
=
spec
,
gemm_specialization
=
"GemmSpecialization::MNKPadding"
,
n_dim_spatial
=
2
,
a_layout
=
a_layout
,
b_layout
=
"GKYXC"
,
e_layout
=
e_layout
,
)
)
return
substitute_instances
if
__name__
==
"__main__"
:
print
(
gen_conv_ops_library
())
python/ck4inductor/grouped_conv_fwd/op.py
0 → 100644
View file @
eda59383
# SPDX-License-Identifier: MIT
# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
from
dataclasses
import
asdict
,
dataclass
from
typing
import
Optional
,
Tuple
@
dataclass
class
CKGroupedConvFwdOp
:
n_dim_spatial
:
int
a_layout
:
str
b_layout
:
str
ds_layout
:
Tuple
[
str
]
e_layout
:
str
a_element_dtype
:
str
b_element_dtype
:
str
acc_dtype
:
str
c_shuffle_dtype
:
str
ds_element_dtype
:
Tuple
[
str
]
e_element_dtype
:
str
a_elementwise_op
:
str
b_elementwise_op
:
str
cde_elementwise_op
:
str
conv_forward_specialization
:
str
gemm_specialization
:
str
block_size
:
int
m_per_block
:
int
n_per_block
:
int
k_per_block
:
int
ak1
:
int
bk1
:
int
m_per_xdl
:
int
n_per_xdl
:
int
m_xdl_per_wave
:
int
n_xdl_per_wave
:
int
a_block_transfer_thread_cluster_lengths_ak0_m_ak1
:
Tuple
[
int
,
int
,
int
]
a_block_transfer_thread_cluster_arrange_order
:
Tuple
[
int
,
int
,
int
]
a_block_transfer_src_access_order
:
Tuple
[
int
,
int
,
int
]
a_block_transfer_src_vector_dim
:
int
a_block_transfer_src_scalar_per_vector
:
int
a_block_transfer_dst_scalar_per_vector_ak1
:
int
a_block_lds_extra_m
:
bool
b_block_transfer_thread_cluster_lengths_bk0_n_bk1
:
Tuple
[
int
,
int
,
int
]
b_block_transfer_thread_cluster_arrange_order
:
Tuple
[
int
,
int
,
int
]
b_block_transfer_src_access_order
:
Tuple
[
int
,
int
,
int
]
b_block_transfer_src_vector_dim
:
int
b_block_transfer_src_scalar_per_vector
:
int
b_block_transfer_dst_scalar_per_vector_bk1
:
int
b_block_lds_extra_n
:
bool
c_shuffle_m_xdl_per_wave_per_shuffle
:
int
c_shuffle_n_xdl_per_wave_per_shuffle
:
int
cde_block_transfer_cluster_lengths_m_block_m_per_block_n_block_n_per_block
:
Tuple
[
# noqa
int
,
int
,
int
,
int
,
]
cde_block_transfer_scalar_per_vector_n_per_block
:
int
block_gemm_pipeline_scheduler
:
str
block_gemm_pipeline_version
:
str
a_compute_dtype
:
Optional
[
str
]
=
None
b_compute_dtype
:
Optional
[
str
]
=
None
def
name
(
self
):
# cpp alias for template instance
return
(
f
"ck_device_grouped_convolution_fwd_multiple_abd_xdl_c_shuffle_v3_"
f
"
{
self
.
key_name
()
}
"
)
def
key_name
(
self
):
# TBD; must be unique per instance. Intended to use as dict key
return
"_"
.
join
(
[
"K"
+
field_name
.
replace
(
"_"
,
""
).
lower
()
+
"V"
+
(
"x"
.
join
(
map
(
str
,
iter
(
field_value
)))
if
isinstance
(
field_value
,
tuple
)
else
str
(
field_value
).
replace
(
":"
,
""
)
)
for
field_name
,
field_value
in
self
.
dict_items
()
]
)
def
dict_items
(
self
):
return
asdict
(
self
).
items
()
python/ck4inductor/universal_gemm/gen_instances.py
View file @
eda59383
# SPDX-License-Identifier: MIT
# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
import
logging
import
os
import
subprocess
from
dataclasses
import
fields
,
replace
from
dataclasses
import
replace
from
functools
import
lru_cache
,
partial
from
typing
import
List
...
...
python/ck4inductor/universal_gemm/op.py
View file @
eda59383
# SPDX-License-Identifier: MIT
# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
from
dataclasses
import
asdict
,
dataclass
from
typing
import
Optional
,
Tuple
...
...
python/ck4inductor/util.py
View file @
eda59383
# SPDX-License-Identifier: MIT
# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
import
functools
import
os
@
functools
.
lru_cache
(
None
)
def
library_path
():
return
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
'
library
'
)
return
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
"
library
"
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment