Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
420c0312
Commit
420c0312
authored
May 25, 2023
by
Paul
Browse files
Use enum for data types
parent
3905f4a2
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
31 additions
and
23 deletions
+31
-23
library/src/jit_library/include/ck/host/common.hpp
library/src/jit_library/include/ck/host/common.hpp
+2
-0
library/src/jit_library/include/ck/host/device_gemm_multiple_d.hpp
...rc/jit_library/include/ck/host/device_gemm_multiple_d.hpp
+9
-14
library/src/jit_library/src/common.cpp
library/src/jit_library/src/common.cpp
+11
-0
library/src/jit_library/src/device_gemm_multiple_d.cpp
library/src/jit_library/src/device_gemm_multiple_d.cpp
+9
-9
No files found.
library/src/jit_library/include/ck/host/common.hpp
View file @
420c0312
...
...
@@ -24,6 +24,8 @@ enum class DataType {
Int32
};
std
::
string
ToString
(
DataType
dt
);
std
::
unordered_map
<
std
::
string
,
std
::
pair
<
const
char
*
,
const
char
*>>
GetHeaders
();
std
::
size_t
integer_divide_ceil
(
std
::
size_t
x
,
std
::
size_t
y
);
...
...
library/src/jit_library/include/ck/host/device_gemm_multiple_d.hpp
View file @
420c0312
...
...
@@ -24,11 +24,11 @@ struct Problem
bool
TransA
=
false
;
bool
TransB
=
false
;
bool
TransE
=
false
;
std
::
vector
<
bool
>
Ds
Layout
=
{};
std
::
string
ADataType
=
"ck
::
h
alf
_t"
;
std
::
string
BDataType
=
"ck
::
h
alf
_t"
;
std
::
string
EDataType
=
"ck
::
h
alf
_t"
;
std
::
vector
<
std
::
string
>
DsDataType
=
{};
std
::
vector
<
bool
>
Ds
Trans
=
{};
DataType
ADataType
=
DataType
::
H
alf
;
DataType
BDataType
=
DataType
::
H
alf
;
DataType
EDataType
=
DataType
::
H
alf
;
std
::
vector
<
DataType
>
DsDataType
=
{};
std
::
string
AElementOp
=
"ck::tensor_operation::element_wise::PassThrough"
;
std
::
string
BElementOp
=
"ck::tensor_operation::element_wise::PassThrough"
;
std
::
string
CDEElementOp
=
"ck::Tuple<>"
;
...
...
@@ -45,19 +45,14 @@ struct Problem
static
const
std
::
size_t
n_per_block_idx
=
18
;
static
const
std
::
size_t
k_per_block_idx
=
19
;
private:
std
::
vector
<
std
::
string
>
GetInstances
(
const
std
::
string
&
arch
)
const
;
std
::
string
GetIncludeHeader
()
const
;
std
::
string
MakeLayoutTuple
(
const
std
::
vector
<
bool
>&
layouts
)
const
;
std
::
vector
<
Solution
>
GetSolutions
(
const
std
::
string
&
arch
)
const
;
std
::
string
MakeTypeTuple
(
const
std
::
vector
<
std
::
string
>&
types
)
const
;
private:
std
::
vector
<
std
::
string
>
GetInstances
(
const
std
::
string
&
arch
)
const
;
Solution
MakeSolution
(
std
::
size_t
idx
,
const
std
::
string
&
arch
)
const
;
public:
std
::
string
GetIncludeHeader
()
const
;
std
::
vector
<
Solution
>
GetSolutions
(
const
std
::
string
&
arch
)
const
;
};
}
// namespace device_gemm_multiple_d
...
...
library/src/jit_library/src/common.cpp
View file @
420c0312
...
...
@@ -5,6 +5,17 @@
namespace
ck
{
namespace
host
{
std
::
string
ToString
(
DataType
dt
)
{
switch
(
dt
)
{
case
DataType
::
Float
:
return
"float"
;
case
DataType
::
Half
:
return
"ck::half_t"
;
case
DataType
::
Int8
:
return
"int8_t"
;
case
DataType
::
Int32
:
return
"int32_t"
;
}
throw
std
::
runtime_error
(
"Incorrect data type"
);
}
std
::
unordered_map
<
std
::
string
,
std
::
pair
<
const
char
*
,
const
char
*>>
GetHeaders
()
{
return
ck_headers
();
...
...
library/src/jit_library/src/device_gemm_multiple_d.cpp
View file @
420c0312
...
...
@@ -46,7 +46,7 @@ const std::unordered_set<std::string>& get_xdlop_archs()
std
::
vector
<
std
::
string
>
Problem
::
GetInstances
(
const
std
::
string
&
arch
)
const
{
std
::
vector
<
std
::
string
>
instances
;
const
bool
quantize
=
ADataType
==
"int8_t"
and
BDataType
==
"int8_t"
;
const
bool
quantize
=
ADataType
==
DataType
::
Int8
and
BDataType
==
DataType
::
Int8
;
if
(
get_xdlop_archs
().
find
(
arch
)
!=
get_xdlop_archs
().
end
())
{
instance
::
gemm_add_add_fastgelu_instances
all_instances
{};
...
...
@@ -62,7 +62,7 @@ std::vector<std::string> Problem::GetInstances(const std::string& arch) const
return
instances
;
}
std
::
string
Problem
::
MakeLayoutTuple
(
const
std
::
vector
<
bool
>&
layouts
)
const
std
::
string
MakeLayoutTuple
(
const
std
::
vector
<
bool
>&
layouts
)
{
std
::
string
layout_tuple
=
"ck::Tuple<"
;
auto
it
=
layouts
.
begin
();
...
...
@@ -77,13 +77,13 @@ std::string Problem::MakeLayoutTuple(const std::vector<bool>& layouts) const
return
layout_tuple
+
">"
;
}
std
::
string
Problem
::
MakeTypeTuple
(
const
std
::
vector
<
std
::
string
>&
types
)
const
std
::
string
MakeTypeTuple
(
const
std
::
vector
<
DataType
>&
types
)
{
std
::
string
type_tuple
=
"ck::Tuple<"
;
auto
it
=
types
.
begin
();
while
(
it
!=
types
.
end
())
{
type_tuple
+=
*
it
;
type_tuple
+=
ToString
(
*
it
)
;
it
=
std
::
next
(
it
);
if
(
it
!=
types
.
end
())
type_tuple
+=
", "
;
...
...
@@ -98,14 +98,14 @@ Solution Problem::MakeSolution(std::size_t idx, const std::string& arch) const
std
::
vector
<
std
::
string
>
params
(
std
::
istream_iterator
<
std
::
string
>
{
iss
},
std
::
istream_iterator
<
std
::
string
>
());
if
(
ADataType
==
"int8_t"
and
BDataType
==
"int8_t"
)
if
(
ADataType
==
DataType
::
Int8
and
BDataType
==
DataType
::
Int8
)
{
// Change CBlockTransfer ScalarPerVector if Ds contains other types
if
(
std
::
any_of
(
DsDataType
.
begin
(),
DsDataType
.
end
(),
[](
auto
t
)
{
return
t
==
"ck
::
h
alf
_t"
;
}))
if
(
std
::
any_of
(
DsDataType
.
begin
(),
DsDataType
.
end
(),
[](
auto
t
)
{
return
t
==
DataType
::
H
alf
;
}))
{
params
[
params
.
size
()
-
3
]
=
"8"
;
}
if
(
std
::
any_of
(
DsDataType
.
begin
(),
DsDataType
.
end
(),
[](
auto
t
)
{
return
t
==
"f
loat
"
;
}))
if
(
std
::
any_of
(
DsDataType
.
begin
(),
DsDataType
.
end
(),
[](
auto
t
)
{
return
t
==
DataType
::
F
loat
;
}))
{
params
[
params
.
size
()
-
3
]
=
"4"
;
}
...
...
@@ -113,10 +113,10 @@ Solution Problem::MakeSolution(std::size_t idx, const std::string& arch) const
params
[
a_elementwise_op_idx
]
=
AElementOp
;
params
[
b_elementwise_op_idx
]
=
BElementOp
;
params
[
ds_layout_idx
]
=
MakeLayoutTuple
(
Ds
Layout
);
params
[
ds_layout_idx
]
=
MakeLayoutTuple
(
Ds
Trans
);
params
[
ds_data_type_idx
]
=
MakeTypeTuple
(
DsDataType
);
params
[
ds_elementwise_op_idx
]
=
CDEElementOp
;
params
[
e_data_type_idx
]
=
EDataType
;
params
[
e_data_type_idx
]
=
ToString
(
EDataType
)
;
auto
block_size_str
=
params
[
block_size_idx
];
auto
m_per_block_str
=
params
[
m_per_block_idx
];
auto
n_per_block_str
=
params
[
n_per_block_idx
];
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment