Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
6289e36f
Commit
6289e36f
authored
May 25, 2023
by
Alan Turner
Browse files
Add int8 instances
parent
6dd246a6
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
112 additions
and
13 deletions
+112
-13
library/src/jit_library/include/device_gemm_multiple_d.hpp
library/src/jit_library/include/device_gemm_multiple_d.hpp
+5
-4
library/src/jit_library/util/make_instance_strings.py
library/src/jit_library/util/make_instance_strings.py
+107
-9
No files found.
library/src/jit_library/include/device_gemm_multiple_d.hpp
View file @
6289e36f
...
@@ -95,17 +95,18 @@ private:
...
@@ -95,17 +95,18 @@ private:
auto
GetInstances
(
const
std
::
string
&
arch
)
const
auto
GetInstances
(
const
std
::
string
&
arch
)
const
{
{
std
::
vector
<
std
::
string
>
instances
;
std
::
vector
<
std
::
string
>
instances
;
const
bool
quantize
=
ADataType
==
"int8_t"
and
BDataType
==
"int8_t"
;
if
(
get_xdlop_archs
().
find
(
arch
)
!=
get_xdlop_archs
().
end
())
if
(
get_xdlop_archs
().
find
(
arch
)
!=
get_xdlop_archs
().
end
())
{
{
instance
::
gemm_add_add_fastgelu_instances
all_instances
{};
instance
::
gemm_add_add_fastgelu_instances
all_instances
{};
if
(
TransA
and
TransB
)
if
(
TransA
and
TransB
)
instances
=
all_instances
.
get_col_col_instances
();
instances
=
all_instances
.
get_col_col_instances
(
quantize
);
else
if
(
TransA
and
not
TransB
)
else
if
(
TransA
and
not
TransB
)
instances
=
all_instances
.
get_col_row_instances
();
instances
=
all_instances
.
get_col_row_instances
(
quantize
);
else
if
(
not
TransA
and
not
TransB
)
else
if
(
not
TransA
and
not
TransB
)
instances
=
all_instances
.
get_row_row_instances
();
instances
=
all_instances
.
get_row_row_instances
(
quantize
);
else
else
instances
=
all_instances
.
get_row_col_instances
();
instances
=
all_instances
.
get_row_col_instances
(
quantize
);
}
}
return
instances
;
return
instances
;
}
}
...
...
library/src/jit_library/util/make_instance_strings.py
View file @
6289e36f
...
@@ -36,24 +36,48 @@ struct {op_name}_instances
...
@@ -36,24 +36,48 @@ struct {op_name}_instances
{row_col_instances}
{row_col_instances}
}};
}};
static
auto get_col_row_instances()
static
inline std::vector<std::string> {int8_col_row_name} =
{{
{{
return {col_row_name};
{int8_col_row_instances}
}};
static inline std::vector<std::string> {int8_col_col_name} =
{{
{int8_col_col_instances}
}};
static inline std::vector<std::string> {int8_row_row_name} =
{{
{int8_row_row_instances}
}};
static inline std::vector<std::string> {int8_row_col_name} =
{{
{int8_row_col_instances}
}};
static auto get_col_row_instances(const bool quantize)
{{
return quantize ? {int8_col_row_name} :
{col_row_name};
}}
}}
static auto get_col_col_instances()
static auto get_col_col_instances(
const bool quantize
)
{{
{{
return {col_col_name};
return quantize ? {int8_col_col_name} :
{col_col_name};
}}
}}
static auto get_row_row_instances()
static auto get_row_row_instances(
const bool quantize
)
{{
{{
return {row_row_name};
return quantize ? {int8_row_row_name} :
{row_row_name};
}}
}}
static auto get_row_col_instances()
static auto get_row_col_instances(
const bool quantize
)
{{
{{
return {row_col_name};
return quantize ? {int8_row_col_name} :
{row_col_name};
}}
}}
static auto get_include_header()
static auto get_include_header()
...
@@ -87,19 +111,83 @@ def remove_commas_and_brackets(string):
...
@@ -87,19 +111,83 @@ def remove_commas_and_brackets(string):
return
string
return
string
def get_int8_instances(src, file, template_name):
    """Collect int8 instance strings from one instance header, grouped by layout.

    The file at ``src + file`` is read line by line:

    * lines containing ``"impl"`` are skipped;
    * lines containing ``"using"`` are searched for a layout tag pair
      (``mk``/``kn``, ``mk``/``nk``, ``km``/``nk``, ``km``/``kn``) and, on a
      match, the ``device_gemm...instance`` identifier found on that line is
      stored under the corresponding ``*_name`` key;
    * lines containing ``template_name`` are normalised, prefixed with
      ``ck::tensor_operation::device::``, alias-expanded and emitted once per
      pipeline-version / loop-scheduler combination.

    Args:
        src: Directory prefix of the header; concatenated with ``file`` as-is.
        file: Path of the header relative to ``src`` (the visible caller
            passes it with a leading ``/``).
        template_name: Substring that marks a line as an instance definition.

    Returns:
        dict with the keys ``row_row``, ``row_col``, ``col_col``, ``col_row``
        (lists of quoted, comma-terminated instance strings; the last entry
        of each non-empty list has its trailing comma removed) and
        ``row_row_name``, ``row_col_name``, ``col_col_name``, ``col_row_name``
        (the matched identifier, or the initial empty list when no matching
        ``using`` line was seen).
    """
    row_major = "ck::tensor_layout::gemm::RowMajor"
    col_major = "ck::tensor_layout::gemm::ColumnMajor"

    aliases = {"Empty_Tuple": "ck::Tuple<>",
               "Row": row_major,
               "Col": col_major,
               "OutElementOp": "PassThrough"}
    instances = {"row_row": [],
                 "row_col": [],
                 "col_col": [],
                 "col_row": [],
                 "row_row_name": [],
                 "row_col_name": [],
                 "col_col_name": [],
                 "col_row_name": []}

    # Checked in this order; the first matching pattern wins.
    name_patterns = ((".*mk.*kn.*", "row_row_name"),
                     (".*mk.*nk.*", "row_col_name"),
                     (".*km.*nk.*", "col_col_name"),
                     (".*km.*kn.*", "col_row_name"))

    # Checked in this order; the first matching layout pair wins.
    layout_markers = ((row_major + " " + row_major, "row_row"),
                      (row_major + " " + col_major, "row_col"),
                      (col_major + " " + col_major, "col_col"),
                      (col_major + " " + row_major, "col_row"))

    # Every instance line is expanded into these three variants.
    # The pipeline strings use "ck::" (the original had a single colon,
    # which is not a valid C++ qualified name).
    pipeline_variants = (("ck::PipelineVersion::v1", "ck::LoopScheduler::Default"),
                         ("ck::PipelineVersion::v1", "ck::LoopScheduler::Interwave"),
                         ("ck::PipelineVersion::v2", "ck::LoopScheduler::Default"))

    path = src + file
    with open(path) as f:
        for line in f:
            if "impl" in line:
                # These lines take precedence over the "using" and template
                # checks below and contribute nothing to the result.
                continue

            if "using" in line:
                for pattern, name_key in name_patterns:
                    if re.search(pattern, line):
                        instances[name_key] = re.search("device_gemm.*instance", line).group()
                        break
                continue

            if template_name not in line:
                continue

            # Turn all whitespace into single spaces
            new_line = " ".join(line.split())
            # Remove whitespace from S<*>
            new_line = strip_sequences(new_line)
            new_line = remove_commas_and_brackets(new_line)
            # Drop a trailing comma left over from the tuple element list.
            if new_line.endswith(","):
                new_line = new_line[:-1]
            new_line = ' "ck::tensor_operation::device::' + new_line + '",'

            for key, value in aliases.items():
                new_line = new_line.replace(key, value)

            versions = [new_line.replace("GemmPipeline", pipeline)
                                .replace("GemmLoopScheduler", scheduler)
                        for pipeline, scheduler in pipeline_variants]

            for marker, layout in layout_markers:
                if marker in new_line:
                    instances[layout].extend(versions)
                    break

    # Strip the trailing comma from the final entry of each layout list so the
    # joined text forms a valid initializer list. Layouts without instances
    # are left as empty lists instead of raising IndexError.
    for layout in ("row_row", "row_col", "col_col", "col_row"):
        if instances[layout]:
            instances[layout][-1] = instances[layout][-1][:-1]

    return instances
def
parse_instances
(
source
):
def
parse_instances
(
source
):
out_dir
=
os
.
path
.
join
(
source
,
"../../../src/jit_library/solution_instances"
)
out_dir
=
os
.
path
.
join
(
source
,
"../../../src/jit_library/solution_instances"
)
aliases
=
{
"F16_F16_Tuple"
:
"ck::Tuple<F16,F16>"
,
aliases
=
{
"F16_F16_Tuple"
:
"ck::Tuple<F16,F16>"
,
"Row_Row_Tuple"
:
"ck::Tuple<Row,Row>"
,
"Row_Row_Tuple"
:
"ck::Tuple<Row,Row>"
,
"Empty_Tuple"
:
"ck::Tuple<>"
,
"LoopScheduler"
:
"ck::LoopScheduler"
,
"LoopScheduler"
:
"ck::LoopScheduler"
,
"PipelineVersion"
:
"ck::PipelineVersion"
,
"PipelineVersion"
:
"ck::PipelineVersion"
,
"Row"
:
"ck::tensor_layout::gemm::RowMajor"
,
"Row"
:
"ck::tensor_layout::gemm::RowMajor"
,
"Col"
:
"ck::tensor_layout::gemm::ColumnMajor"
,
"Col"
:
"ck::tensor_layout::gemm::ColumnMajor"
,
"F16"
:
"ck::half_t"
,
"F16"
:
"ck::half_t"
,
"F32"
:
"float"
}
"F32"
:
"float"
,
"OutElementOp"
:
"PassThrough"
}
device_ops
=
{
"gemm_add_add_fastgelu"
:
"DeviceGemmMultipleD_Xdl_CShuffle"
,
device_ops
=
{
"gemm_add_add_fastgelu"
:
"DeviceGemmMultipleD_Xdl_CShuffle"
,
#"batched_gemm_softmax_gemm": "DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle"
#"batched_gemm_softmax_gemm": "DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle"
}
}
for
root_
,
dirs_
,
files_
in
os
.
walk
(
source
):
for
root_
,
dirs_
,
files_
in
os
.
walk
(
source
):
for
dir
in
dirs_
:
for
dir
in
dirs_
:
op_name
=
os
.
path
.
split
(
dir
)[
-
1
]
op_name
=
os
.
path
.
split
(
dir
)[
-
1
]
...
@@ -163,6 +251,8 @@ def parse_instances(source):
...
@@ -163,6 +251,8 @@ def parse_instances(source):
out_file_name
=
op_name
+
"_instances.hpp"
out_file_name
=
op_name
+
"_instances.hpp"
if
not
os
.
path
.
exists
(
out_dir
):
if
not
os
.
path
.
exists
(
out_dir
):
os
.
mkdir
(
out_dir
)
os
.
mkdir
(
out_dir
)
int8_file
=
"/quantization/gemm/device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_instance.hpp"
int8_instances
=
get_int8_instances
(
source
,
int8_file
,
"DeviceGemmMultipleD_Xdl_CShuffle"
)
with
open
(
os
.
path
.
join
(
out_dir
,
out_file_name
),
"w+"
)
as
f
:
with
open
(
os
.
path
.
join
(
out_dir
,
out_file_name
),
"w+"
)
as
f
:
f
.
write
(
out_file
.
format
(
op_name
=
op_name
,
f
.
write
(
out_file
.
format
(
op_name
=
op_name
,
col_row_name
=
col_row_name
,
col_row_name
=
col_row_name
,
...
@@ -173,6 +263,14 @@ def parse_instances(source):
...
@@ -173,6 +263,14 @@ def parse_instances(source):
row_row_instances
=
"
\n
"
.
join
(
row_row_instances
),
row_row_instances
=
"
\n
"
.
join
(
row_row_instances
),
row_col_name
=
row_col_name
,
row_col_name
=
row_col_name
,
row_col_instances
=
"
\n
"
.
join
(
row_col_instances
),
row_col_instances
=
"
\n
"
.
join
(
row_col_instances
),
int8_col_row_name
=
int8_instances
[
"col_row_name"
],
int8_col_row_instances
=
"
\n
"
.
join
(
int8_instances
[
"col_row"
]),
int8_col_col_name
=
int8_instances
[
"col_col_name"
],
int8_col_col_instances
=
"
\n
"
.
join
(
int8_instances
[
"col_col"
]),
int8_row_row_name
=
int8_instances
[
"row_row_name"
],
int8_row_row_instances
=
"
\n
"
.
join
(
int8_instances
[
"row_row"
]),
int8_row_col_name
=
int8_instances
[
"row_col_name"
],
int8_row_col_instances
=
"
\n
"
.
join
(
int8_instances
[
"row_col"
]),
include_header
=
include_header
))
include_header
=
include_header
))
def
run
():
def
run
():
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment