Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
61386bf9
Commit
61386bf9
authored
May 25, 2023
by
Alan Turner
Browse files
Add edatatype and scalars_per_vector workaround
parent
6289e36f
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
18 additions
and
3 deletions
+18
-3
library/src/jit_library/include/device_gemm_multiple_d.hpp
library/src/jit_library/include/device_gemm_multiple_d.hpp
+15
-0
library/src/jit_library/util/make_instance_strings.py
library/src/jit_library/util/make_instance_strings.py
+3
-3
No files found.
library/src/jit_library/include/device_gemm_multiple_d.hpp
View file @
61386bf9
...
@@ -82,6 +82,7 @@ struct Problem
...
@@ -82,6 +82,7 @@ struct Problem
static
const
index_t
ds_layout_idx
=
3
;
static
const
index_t
ds_layout_idx
=
3
;
static
const
index_t
ds_data_type_idx
=
9
;
static
const
index_t
ds_data_type_idx
=
9
;
static
const
index_t
e_data_type_idx
=
10
;
static
const
index_t
a_elementwise_op_idx
=
11
;
static
const
index_t
a_elementwise_op_idx
=
11
;
static
const
index_t
b_elementwise_op_idx
=
12
;
static
const
index_t
b_elementwise_op_idx
=
12
;
static
const
index_t
ds_elementwise_op_idx
=
13
;
static
const
index_t
ds_elementwise_op_idx
=
13
;
...
@@ -146,12 +147,26 @@ private:
...
@@ -146,12 +147,26 @@ private:
std
::
istringstream
iss
(
template_str
);
std
::
istringstream
iss
(
template_str
);
std
::
vector
<
std
::
string
>
params
(
std
::
istream_iterator
<
std
::
string
>
{
iss
},
std
::
vector
<
std
::
string
>
params
(
std
::
istream_iterator
<
std
::
string
>
{
iss
},
std
::
istream_iterator
<
std
::
string
>
());
std
::
istream_iterator
<
std
::
string
>
());
if
(
ADataType
==
"int8_t"
and
BDataType
==
"int8_t"
)
{
// Change CBlockTransfer ScalarPerVector if Ds contains other types
if
(
std
::
any_of
(
DsDataType
.
begin
(),
DsDataType
.
end
(),
[](
auto
t
)
{
return
t
==
"ck::half_t"
;
}))
{
params
[
params
.
size
()
-
3
]
=
"8"
;
}
if
(
std
::
any_of
(
DsDataType
.
begin
(),
DsDataType
.
end
(),
[](
auto
t
)
{
return
t
==
"float"
;
}))
{
params
[
params
.
size
()
-
3
]
=
"4"
;
}
}
params
[
a_elementwise_op_idx
]
=
AElementOp
;
params
[
a_elementwise_op_idx
]
=
AElementOp
;
params
[
b_elementwise_op_idx
]
=
BElementOp
;
params
[
b_elementwise_op_idx
]
=
BElementOp
;
params
[
ds_layout_idx
]
=
MakeLayoutTuple
(
DsLayout
);
params
[
ds_layout_idx
]
=
MakeLayoutTuple
(
DsLayout
);
params
[
ds_data_type_idx
]
=
MakeTypeTuple
(
DsDataType
);
params
[
ds_data_type_idx
]
=
MakeTypeTuple
(
DsDataType
);
params
[
ds_elementwise_op_idx
]
=
CDEElementOp
;
params
[
ds_elementwise_op_idx
]
=
CDEElementOp
;
params
[
e_data_type_idx
]
=
EDataType
;
auto
block_size_str
=
params
[
block_size_idx
];
auto
block_size_str
=
params
[
block_size_idx
];
auto
m_per_block_str
=
params
[
m_per_block_idx
];
auto
m_per_block_str
=
params
[
m_per_block_idx
];
auto
n_per_block_str
=
params
[
n_per_block_idx
];
auto
n_per_block_str
=
params
[
n_per_block_idx
];
...
...
library/src/jit_library/util/make_instance_strings.py
View file @
61386bf9
...
@@ -154,9 +154,9 @@ def get_int8_instances(src, file, template_name):
...
@@ -154,9 +154,9 @@ def get_int8_instances(src, file, template_name):
for
key
in
aliases
:
for
key
in
aliases
:
new_line
=
new_line
.
replace
(
key
,
aliases
[
key
])
new_line
=
new_line
.
replace
(
key
,
aliases
[
key
])
versions
.
append
(
new_line
.
replace
(
"GemmPipeline"
,
"ck:PipelineVersion::v1"
).
replace
(
"GemmLoopScheduler"
,
"ck::LoopScheduler::Default"
))
versions
.
append
(
new_line
.
replace
(
"GemmPipeline"
,
"ck:
:
PipelineVersion::v1"
).
replace
(
"GemmLoopScheduler"
,
"ck::LoopScheduler::Default"
))
versions
.
append
(
new_line
.
replace
(
"GemmPipeline"
,
"ck:PipelineVersion::v1"
).
replace
(
"GemmLoopScheduler"
,
"ck::LoopScheduler::Interwave"
))
versions
.
append
(
new_line
.
replace
(
"GemmPipeline"
,
"ck:
:
PipelineVersion::v1"
).
replace
(
"GemmLoopScheduler"
,
"ck::LoopScheduler::Interwave"
))
versions
.
append
(
new_line
.
replace
(
"GemmPipeline"
,
"ck:PipelineVersion::v2"
).
replace
(
"GemmLoopScheduler"
,
"ck::LoopScheduler::Default"
))
versions
.
append
(
new_line
.
replace
(
"GemmPipeline"
,
"ck:
:
PipelineVersion::v2"
).
replace
(
"GemmLoopScheduler"
,
"ck::LoopScheduler::Default"
))
if
"ck::tensor_layout::gemm::RowMajor ck::tensor_layout::gemm::RowMajor"
in
new_line
:
if
"ck::tensor_layout::gemm::RowMajor ck::tensor_layout::gemm::RowMajor"
in
new_line
:
instances
[
"row_row"
].
extend
(
versions
)
instances
[
"row_row"
].
extend
(
versions
)
elif
"ck::tensor_layout::gemm::RowMajor ck::tensor_layout::gemm::ColumnMajor"
in
new_line
:
elif
"ck::tensor_layout::gemm::RowMajor ck::tensor_layout::gemm::ColumnMajor"
in
new_line
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment