Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
e8b54cb3
Commit
e8b54cb3
authored
Aug 29, 2023
by
Alan Turner
Browse files
Update parse_instance_strings
parent
4f7d9bbe
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
382 additions
and
112 deletions
+382
-112
.gitignore
.gitignore
+3
-0
library/src/jit_library/util/file_templates.py
library/src/jit_library/util/file_templates.py
+177
-0
library/src/jit_library/util/make_instance_strings.py
library/src/jit_library/util/make_instance_strings.py
+199
-110
test/CMakeLists.txt
test/CMakeLists.txt
+3
-2
No files found.
.gitignore
View file @
e8b54cb3
...
@@ -63,3 +63,6 @@ _templates/
...
@@ -63,3 +63,6 @@ _templates/
_toc.yml
_toc.yml
docBin/
docBin/
_doxygen/
_doxygen/
# pycache
__pycache__/
library/src/jit_library/util/file_templates.py
0 → 100644
View file @
e8b54cb3
out_file_with_quant
=
"""// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include <vector>
#include <memory>
namespace ck {{
namespace host {{
namespace instance {{
struct {op_name}_instances
{{
static inline std::vector<std::string> {col_row_name} =
{{
{col_row_instances}
}};
static inline std::vector<std::string> {col_col_name} =
{{
{col_col_instances}
}};
static inline std::vector<std::string> {row_row_name} =
{{
{row_row_instances}
}};
static inline std::vector<std::string> {row_col_name} =
{{
{row_col_instances}
}};
static inline std::vector<std::string> {int8_col_row_name} =
{{
{int8_col_row_instances}
}};
static inline std::vector<std::string> {int8_col_col_name} =
{{
{int8_col_col_instances}
}};
static inline std::vector<std::string> {int8_row_row_name} =
{{
{int8_row_row_instances}
}};
static inline std::vector<std::string> {int8_row_col_name} =
{{
{int8_row_col_instances}
}};
static auto get_col_row_instances(const bool quantize)
{{
return quantize ? {int8_col_row_name} :
{col_row_name};
}}
static auto get_col_col_instances(const bool quantize)
{{
return quantize ? {int8_col_col_name} :
{col_col_name};
}}
static auto get_row_row_instances(const bool quantize)
{{
return quantize ? {int8_row_row_name} :
{row_row_name};
}}
static auto get_row_col_instances(const bool quantize)
{{
return quantize ? {int8_row_col_name} :
{row_col_name};
}}
static auto get_include_header()
{{
return "{include_header}";
}}
}};
}} // namespace instance
}} // namespace host
}} // namespace ck
"""
out_file_no_quant
=
"""// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include <vector>
#include <memory>
namespace ck {{
namespace host {{
namespace instance {{
struct {op_name}_instances
{{
static inline std::vector<std::string> {instances_name} =
{{
{instances}
}};
static auto get_instances()
{{
return {instances_name};
}}
static auto get_include_header()
{{
return "{include_header}";
}}
}};
}} // namespace instance
}} // namespace host
}} // namespace ck
"""
def
get_device_gemm_multiple_d_file
(
op_name
,
col_row_name
,
col_row_instances
,
col_col_name
,
col_col_instances
,
row_row_name
,
row_row_instances
,
row_col_name
,
row_col_instances
,
int8_col_row_name
,
int8_col_row_instances
,
int8_col_col_name
,
int8_col_col_instances
,
int8_row_row_name
,
int8_row_row_instances
,
int8_row_col_name
,
int8_row_col_instances
,
include_header
):
return
out_file_with_quant
.
format
(
op_name
=
op_name
,
col_row_name
=
col_row_name
,
col_row_instances
=
col_row_instances
,
col_col_name
=
col_col_name
,
col_col_instances
=
col_col_instances
,
row_row_name
=
row_row_name
,
row_row_instances
=
row_row_instances
,
row_col_name
=
row_col_name
,
row_col_instances
=
row_col_instances
,
int8_col_row_name
=
int8_col_row_name
,
int8_col_row_instances
=
int8_col_row_instances
,
int8_col_col_name
=
int8_col_col_name
,
int8_col_col_instances
=
int8_col_col_instances
,
int8_row_row_name
=
int8_row_row_name
,
int8_row_row_instances
=
int8_row_row_instances
,
int8_row_col_name
=
int8_row_col_name
,
int8_row_col_instances
=
int8_row_col_instances
,
include_header
=
include_header
)
def
get_device_gemm_softmax_gemm_file
(
op_name
,
instances_name
,
instances
,
include_header
):
return
out_file_no_quant
.
format
(
op_name
=
op_name
,
instances_name
=
instances_name
,
instances
=
instances
,
include_header
=
include_header
)
library/src/jit_library/util/make_instance_strings.py
View file @
e8b54cb3
import
argparse
,
re
,
json
,
os
,
sys
import
argparse
,
re
,
json
,
os
,
sys
,
file_templates
out_file
=
"""// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include <vector>
#include <memory>
namespace ck {{
namespace host {{
namespace instance {{
struct {op_name}_instances
{{
static inline std::vector<std::string> {col_row_name} =
{{
{col_row_instances}
}};
static inline std::vector<std::string> {col_col_name} =
{{
{col_col_instances}
}};
static inline std::vector<std::string> {row_row_name} =
{{
{row_row_instances}
}};
static inline std::vector<std::string> {row_col_name} =
{{
{row_col_instances}
}};
static inline std::vector<std::string> {int8_col_row_name} =
{{
{int8_col_row_instances}
}};
static inline std::vector<std::string> {int8_col_col_name} =
{{
{int8_col_col_instances}
}};
static inline std::vector<std::string> {int8_row_row_name} =
{{
{int8_row_row_instances}
}};
static inline std::vector<std::string> {int8_row_col_name} =
{{
{int8_row_col_instances}
}};
static auto get_col_row_instances(const bool quantize)
{{
return quantize ? {int8_col_row_name} :
{col_row_name};
}}
static auto get_col_col_instances(const bool quantize)
{{
return quantize ? {int8_col_col_name} :
{col_col_name};
}}
static auto get_row_row_instances(const bool quantize)
{{
return quantize ? {int8_row_row_name} :
{row_row_name};
}}
static auto get_row_col_instances(const bool quantize)
{{
return quantize ? {int8_row_col_name} :
{row_col_name};
}}
static auto get_include_header()
{{
return "{include_header}";
}}
}};
}} // namespace instance
}} // namespace host
}} // namespace ck
"""
def
strip_sequences
(
str
):
def
strip_sequences
(
str
):
matches
=
re
.
findall
(
r
'S<\d+(?:,\s*\d+)*>'
,
str
)
matches
=
re
.
findall
(
r
'S<\d+(?:,\s*\d+)*>'
,
str
)
...
@@ -251,27 +161,206 @@ def parse_instances(source, out_dir):
...
@@ -251,27 +161,206 @@ def parse_instances(source, out_dir):
int8_file
=
"/quantization/gemm/device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_instance.hpp"
int8_file
=
"/quantization/gemm/device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_instance.hpp"
int8_instances
=
get_int8_instances
(
source
,
int8_file
,
"DeviceGemmMultipleD_Xdl_CShuffle"
)
int8_instances
=
get_int8_instances
(
source
,
int8_file
,
"DeviceGemmMultipleD_Xdl_CShuffle"
)
with
open
(
os
.
path
.
join
(
out_dir
,
out_file_name
),
"w+"
)
as
f
:
with
open
(
os
.
path
.
join
(
out_dir
,
out_file_name
),
"w+"
)
as
f
:
f
.
write
(
out_file
.
format
(
op_name
=
op_name
,
f
.
write
(
file_templates
.
get_device_gemm_multiple_d_file
(
col_row_name
=
col_row_name
,
op_name
,
col_row_instances
=
"
\n
"
.
join
(
col_row_instances
),
col_row_name
,
col_col_name
=
col_col_name
,
"
\n
"
.
join
(
col_row_instances
),
col_col_instances
=
"
\n
"
.
join
(
col_col_instances
),
col_col_name
,
row_row_name
=
row_row_name
,
"
\n
"
.
join
(
col_col_instances
),
row_row_instances
=
"
\n
"
.
join
(
row_row_instances
),
row_row_name
,
row_col_name
=
row_col_name
,
"
\n
"
.
join
(
row_row_instances
),
row_col_instances
=
"
\n
"
.
join
(
row_col_instances
),
row_col_name
,
int8_col_row_name
=
int8_instances
[
"col_row_name"
],
"
\n
"
.
join
(
row_col_instances
),
int8_col_row_instances
=
"
\n
"
.
join
(
int8_instances
[
"col_row"
]),
int8_instances
[
"col_row_name"
],
int8_col_col_name
=
int8_instances
[
"col_col_name"
],
"
\n
"
.
join
(
int8_instances
[
"col_row"
]),
int8_col_col_instances
=
"
\n
"
.
join
(
int8_instances
[
"col_col"
]),
int8_instances
[
"col_col_name"
],
int8_row_row_name
=
int8_instances
[
"row_row_name"
],
"
\n
"
.
join
(
int8_instances
[
"col_col"
]),
int8_row_row_instances
=
"
\n
"
.
join
(
int8_instances
[
"row_row"
]),
int8_instances
[
"row_row_name"
],
int8_row_col_name
=
int8_instances
[
"row_col_name"
],
"
\n
"
.
join
(
int8_instances
[
"row_row"
]),
int8_row_col_instances
=
"
\n
"
.
join
(
int8_instances
[
"row_col"
]),
int8_instances
[
"row_col_name"
],
include_header
=
include_header
))
"
\n
"
.
join
(
int8_instances
[
"row_col"
]),
include_header
))
def
parse_device_gemm_multiple_d_instances
(
source
,
out_dir
):
aliases
=
{
"F16_F16_Tuple"
:
"ck::Tuple<F16,F16>"
,
"Row_Row_Tuple"
:
"ck::Tuple<Row,Row>"
,
"Empty_Tuple"
:
"ck::Tuple<>"
,
"LoopScheduler"
:
"ck::LoopScheduler"
,
"PipelineVersion"
:
"ck::PipelineVersion"
,
"Row"
:
"ck::tensor_layout::gemm::RowMajor"
,
"Col"
:
"ck::tensor_layout::gemm::ColumnMajor"
,
"F16"
:
"ck::half_t"
,
"F32"
:
"float"
,
"OutElementOp"
:
"PassThrough"
}
device_ops
=
{
"gemm_add_add_fastgelu"
:
"DeviceGemmMultipleD_Xdl_CShuffle"
,
#"batched_gemm_softmax_gemm": "DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle"
}
for
root_
,
dirs_
,
files_
in
os
.
walk
(
source
):
for
dir
in
dirs_
:
op_name
=
os
.
path
.
split
(
dir
)[
-
1
]
if
op_name
not
in
device_ops
:
continue
col_row_name
=
""
col_col_name
=
""
row_row_name
=
""
row_col_name
=
""
row_row_instances
=
[]
col_row_instances
=
[]
row_col_instances
=
[]
col_col_instances
=
[]
for
root
,
dirs
,
files
in
os
.
walk
(
os
.
path
.
join
(
root_
,
dir
)):
for
file
in
files
:
if
not
file
.
endswith
(
".cpp"
):
continue
;
file_name
=
os
.
path
.
split
(
file
)[
-
1
]
is_row_row
=
bool
(
re
.
search
(
".*mk.*kn.*"
,
file_name
))
is_col_row
=
bool
(
re
.
search
(
".*km.*kn.*"
,
file_name
))
is_row_col
=
bool
(
re
.
search
(
".*mk.*nk.*"
,
file_name
))
is_col_col
=
bool
(
re
.
search
(
".*km.*nk.*"
,
file_name
))
if
is_row_row
:
row_row_name
=
file_name
[:
-
4
]
if
is_col_row
:
col_row_name
=
file_name
[:
-
4
]
if
is_row_col
:
row_col_name
=
file_name
[:
-
4
]
if
is_col_col
:
col_col_name
=
file_name
[:
-
4
]
instances_list
=
[]
template_name
=
device_ops
[
op_name
]
include_header
=
""
with
open
(
os
.
path
.
join
(
root
,
file
))
as
f
:
for
line
in
f
:
if
"impl"
in
line
:
include_header
=
line
.
replace
(
"#include
\"
"
,
""
).
replace
(
"
\"
"
,
""
).
replace
(
"
\n
"
,
""
)
elif
template_name
in
line
:
# Turn all whitespace into single spaces
new_line
=
" "
.
join
(
line
.
split
())
# Remove whitespace from S<*>
new_line
=
strip_sequences
(
new_line
)
new_line
=
remove_commas_and_brackets
(
new_line
)
last_char
=
"
\n
"
if
new_line
[
-
1
]
==
","
:
last_char
=
",
\n
"
new_line
=
new_line
[:
-
1
]
new_line
=
' "ck::tensor_operation::device::'
+
new_line
+
'",'
for
key
in
aliases
:
new_line
=
new_line
.
replace
(
key
,
aliases
[
key
])
instances_list
.
append
(
new_line
)
instances_list
[
-
1
]
=
instances_list
[
-
1
][:
-
1
]
if
is_row_row
:
row_row_instances
=
instances_list
if
is_col_row
:
col_row_instances
=
instances_list
if
is_row_col
:
row_col_instances
=
instances_list
if
is_col_col
:
col_col_instances
=
instances_list
out_file_name
=
op_name
+
"_instances.hpp"
if
not
os
.
path
.
exists
(
out_dir
):
os
.
mkdir
(
out_dir
)
int8_file
=
"/quantization/gemm/device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_instance.hpp"
int8_instances
=
get_int8_instances
(
source
,
int8_file
,
"DeviceGemmMultipleD_Xdl_CShuffle"
)
with
open
(
os
.
path
.
join
(
out_dir
,
out_file_name
),
"w+"
)
as
f
:
f
.
write
(
file_templates
.
get_device_gemm_multiple_d_file
(
op_name
,
col_row_name
,
"
\n
"
.
join
(
col_row_instances
),
col_col_name
,
"
\n
"
.
join
(
col_col_instances
),
row_row_name
,
"
\n
"
.
join
(
row_row_instances
),
row_col_name
,
"
\n
"
.
join
(
row_col_instances
),
int8_instances
[
"col_row_name"
],
"
\n
"
.
join
(
int8_instances
[
"col_row"
]),
int8_instances
[
"col_col_name"
],
"
\n
"
.
join
(
int8_instances
[
"col_col"
]),
int8_instances
[
"row_row_name"
],
"
\n
"
.
join
(
int8_instances
[
"row_row"
]),
int8_instances
[
"row_col_name"
],
"
\n
"
.
join
(
int8_instances
[
"row_col"
]),
include_header
))
def
parse_param_names
(
file
):
param_names
=
[]
for
line
in
file
:
if
bool
(
re
.
search
(
r
"\s*//#+"
,
line
)):
names
=
line
.
split
(
'|'
)
names
=
[
n
.
strip
()
for
n
in
names
]
if
not
param_names
:
param_names
=
[
""
]
*
len
(
names
)
param_names
=
[
a
+
b
for
a
,
b
in
zip
(
param_names
,
names
)]
elif
param_names
:
param_names
[
0
]
=
line
.
split
(
'<'
)[
0
].
strip
()
file
.
seek
(
0
)
return
param_names
[:
-
1
]
file
.
seek
(
0
)
return
param_names
[:
-
1
]
def
parse_device_batched_gemm_softmax_gemm_instances
(
source
,
out_dir
):
aliases
=
{
"Row"
:
"ck::tensor_layout::gemm::RowMajor"
,
"Col"
:
"ck::tensor_layout::gemm::ColumnMajor"
,
"F16"
:
"ck::half_t"
,
"F32"
:
"float"
,
"PassThrough"
:
"ck::tensor_operation::element_wise::PassThrough"
,
"Scale"
:
"ck::tensor_operation::element_wise::Scale"
,
"GemmPadded"
:
"ck::tensor_operation::device::GemmSpecialization::MNKOPadding"
,
"GemmDefault"
:
"ck::tensor_operation::device::GemmSpecialization::Default"
}
device_ops
=
{
"batched_gemm_softmax_gemm"
:
"DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle"
}
for
root_
,
dirs_
,
files_
in
os
.
walk
(
source
):
for
dir
in
dirs_
:
op_name
=
os
.
path
.
split
(
dir
)[
-
1
]
if
"permute"
in
op_name
or
op_name
not
in
device_ops
:
continue
for
root
,
dirs
,
files
in
os
.
walk
(
os
.
path
.
join
(
root_
,
dir
)):
for
file
in
files
:
if
not
file
.
endswith
(
".cpp"
):
continue
;
file_name
=
os
.
path
.
split
(
file
)[
-
1
]
instances_name
=
file_name
[:
-
4
]
instances_list
=
[]
template_name
=
device_ops
[
op_name
]
include_header
=
""
with
open
(
os
.
path
.
join
(
root
,
file
))
as
f
:
param_names
=
parse_param_names
(
f
)
# for i in range(len(param_names)):
# print(f"{i}: {param_names[i]}")
for
line
in
f
:
if
"impl"
in
line
:
include_header
=
line
.
replace
(
"#include
\"
"
,
""
).
replace
(
"
\"
"
,
""
).
replace
(
"
\n
"
,
""
)
elif
template_name
in
line
:
# Turn all whitespace into single spaces
new_line
=
" "
.
join
(
line
.
split
())
# Remove whitespace from S<*>
new_line
=
strip_sequences
(
new_line
)
new_line
=
remove_commas_and_brackets
(
new_line
)
last_char
=
"
\n
"
if
new_line
[
-
1
]
==
","
:
last_char
=
",
\n
"
new_line
=
new_line
[:
-
1
]
new_line
=
' "ck::tensor_operation::device::'
+
new_line
+
'",'
for
key
in
aliases
:
new_line
=
new_line
.
replace
(
key
,
aliases
[
key
])
masking
=
new_line
.
replace
(
"Masking"
,
"true"
)
no_masking
=
new_line
.
replace
(
"Masking"
,
"false"
)
instances_list
.
append
(
masking
)
instances_list
.
append
(
no_masking
)
out_file_name
=
op_name
+
"_instances.hpp"
if
not
os
.
path
.
exists
(
out_dir
):
os
.
mkdir
(
out_dir
)
with
open
(
os
.
path
.
join
(
out_dir
,
out_file_name
),
"w+"
)
as
f
:
f
.
write
(
file_templates
.
get_device_gemm_softmax_gemm_file
(
op_name
,
instances_name
,
"
\n
"
.
join
(
instances_list
),
include_header
))
def
run
(
args
):
def
run
(
args
):
parse_instances
(
args
[
0
],
args
[
1
])
parse_device_gemm_multiple_d_instances
(
args
[
0
],
args
[
1
])
parse_device_batched_gemm_softmax_gemm_instances
(
args
[
0
],
args
[
1
])
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
run
(
sys
.
argv
[
1
:])
run
(
sys
.
argv
[
1
:])
test/CMakeLists.txt
View file @
e8b54cb3
...
@@ -63,6 +63,7 @@ else()
...
@@ -63,6 +63,7 @@ else()
add_subdirectory
(
pool_fwd
)
add_subdirectory
(
pool_fwd
)
add_subdirectory
(
batched_gemm_multi_d
)
add_subdirectory
(
batched_gemm_multi_d
)
add_subdirectory
(
grouped_convnd_bwd_data
)
add_subdirectory
(
grouped_convnd_bwd_data
)
if
(
GPU_TARGETS MATCHES
"gfx11"
)
if
(
GPU_TARGETS MATCHES
"gfx11"
)
add_subdirectory
(
wmma_op
)
add_subdirectory
(
wmma_op
)
endif
()
endif
()
endif
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment