Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
dc65f4c6
Commit
dc65f4c6
authored
May 24, 2023
by
Alan Turner
Browse files
Use vectors for Ds types and layouts params
parent
e2878e25
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
51 additions
and
70 deletions
+51
-70
library/src/jit_library/include/device_gemm_multiple_d.hpp
library/src/jit_library/include/device_gemm_multiple_d.hpp
+51
-70
No files found.
library/src/jit_library/include/device_gemm_multiple_d.hpp
View file @
dc65f4c6
...
...
@@ -26,25 +26,6 @@ struct Solution
std
::
string
template_str
;
index_t
block_size
;
index_t
grid_size
;
// Constructs a Solution from its three parts: `s` is copied into
// template_str, `b` is stored as block_size and `g` as grid_size.
Solution
(
std
::
string
s
,
index_t
b
,
index_t
g
)
:
template_str
(
s
),
block_size
(
b
),
grid_size
(
g
)
{}
// Returns template_str by value (the `auto` return type deduces a
// non-reference type, so the caller receives a copy).
auto
GetStr
()
const
{
return
template_str
;
}
// Returns the stored block_size value.
auto
GetBlockSize
()
const
{
return
block_size
;
}
// Returns the stored grid_size value.
auto
GetGridSize
()
const
{
return
grid_size
;
}
};
std
::
string
GetGemmSpec
(
const
index_t
m
,
...
...
@@ -84,20 +65,20 @@ const std::unordered_set<std::string>& get_xdlop_archs()
struct
Problem
{
index_t
M
;
index_t
N
;
index_t
K
;
index_t
NumDTensors
;
bool
Trans
A
;
bool
Trans
B
;
bool
TransCDE
;
std
::
string
ADataType
;
std
::
string
BDataType
;
std
::
string
CD
EDataType
;
std
::
string
AElementOp
;
std
::
string
B
ElementOp
;
std
::
string
CDE
ElementOp
;
std
::
string
CDE
Layout
;
index_t
M
=
0
;
index_t
N
=
0
;
index_t
K
=
0
;
bool
TransA
=
false
;
bool
Trans
B
=
false
;
bool
Trans
E
=
false
;
std
::
vector
<
bool
>
DsLayout
=
{}
;
std
::
string
ADataType
=
"ck::half_t"
;
std
::
string
BDataType
=
"ck::half_t"
;
std
::
string
EDataType
=
"ck::half_t"
;
std
::
vector
<
std
::
string
>
DsDataType
=
{}
;
std
::
string
A
ElementOp
=
"ck::tensor_operation::element_wise::PassThrough"
;
std
::
string
B
ElementOp
=
"ck::tensor_operation::element_wise::PassThrough"
;
std
::
string
CDE
ElementOp
=
"ck::Tuple<>"
;
static
const
index_t
ds_layout_idx
=
3
;
static
const
index_t
ds_data_type_idx
=
9
;
...
...
@@ -110,6 +91,7 @@ struct Problem
static
const
index_t
n_per_block_idx
=
18
;
static
const
index_t
k_per_block_idx
=
19
;
private:
auto
GetInstances
(
const
std
::
string
&
arch
)
const
{
std
::
vector
<
std
::
string
>
instances
;
...
...
@@ -128,45 +110,33 @@ struct Problem
return
instances
;
}
// Forwards to ck_headers() and returns whatever it produces; no state of
// this object is read.
auto
GetHeaders
()
const
{
return
ck_headers
();
}
auto
GetIncludeHeader
()
const
// Builds the string "ck::Tuple<L0, L1, ...>" describing one layout per entry
// of `layouts`: a `true` entry is emitted as
// "ck::tensor_layout::gemm::ColumnMajor", a `false` entry as
// "ck::tensor_layout::gemm::RowMajor". Entries are separated by ", " and an
// empty vector yields "ck::Tuple<>".
//
// The body previously began with a leftover
// `return instance::gemm_add_add_fastgelu_instances{}.get_include_header();`
// which made the tuple construction unreachable; that statement is removed.
auto MakeLayoutTuple(const std::vector<bool>& layouts) const
{
    std::string layout_tuple = "ck::Tuple<";
    bool first               = true;
    for(const bool is_column_major : layouts)
    {
        // Separator goes between entries only, never before the first one.
        if(!first)
        {
            layout_tuple += ", ";
        }
        layout_tuple += is_column_major ? "ck::tensor_layout::gemm::ColumnMajor"
                                        : "ck::tensor_layout::gemm::RowMajor";
        first = false;
    }
    return layout_tuple + ">";
}
Problem
(
index_t
m
,
index_t
n
,
index_t
k
,
index_t
numDTensors
,
bool
transA
,
bool
transB
,
bool
transCDE
,
std
::
string
aDataType
,
std
::
string
bDataType
,
std
::
string
cdeDataType
,
std
::
string
aElementOp
,
std
::
string
bElementOp
,
std
::
string
cdeElementOp
,
std
::
string
cdeLayout
)
:
M
(
m
),
N
(
n
),
K
(
k
),
NumDTensors
(
numDTensors
),
TransA
(
transA
),
TransB
(
transB
),
TransCDE
(
transCDE
),
ADataType
(
aDataType
),
BDataType
(
bDataType
),
CDEDataType
(
cdeDataType
),
AElementOp
(
aElementOp
),
BElementOp
(
bElementOp
),
CDEElementOp
(
cdeElementOp
),
CDELayout
(
cdeLayout
)
// Builds the string "ck::Tuple<T0, T1, ...>" from the given type names.
// Names are copied verbatim and joined with ", "; an empty vector yields
// "ck::Tuple<>".
auto MakeTypeTuple(const std::vector<std::string>& types) const
{
    std::string joined   = "ck::Tuple<";
    bool needs_separator = false;
    for(const auto& type_name : types)
    {
        // Emit the separator before every element except the first.
        if(needs_separator)
        {
            joined += ", ";
        }
        joined += type_name;
        needs_separator = true;
    }
    joined += ">";
    return joined;
}
auto
MakeSolution
(
index_t
idx
,
const
std
::
string
&
arch
)
const
...
...
@@ -178,8 +148,8 @@ struct Problem
params
[
a_elementwise_op_idx
]
=
AElementOp
;
params
[
b_elementwise_op_idx
]
=
BElementOp
;
params
[
ds_layout_idx
]
=
CDE
Layout
;
params
[
ds_data_type_idx
]
=
CDE
DataType
;
params
[
ds_layout_idx
]
=
MakeLayoutTuple
(
Ds
Layout
)
;
params
[
ds_data_type_idx
]
=
MakeTypeTuple
(
Ds
DataType
)
;
params
[
ds_elementwise_op_idx
]
=
CDEElementOp
;
auto
block_size_str
=
params
[
block_size_idx
];
auto
m_per_block_str
=
params
[
m_per_block_idx
];
...
...
@@ -201,6 +171,17 @@ struct Problem
return
Solution
{
str
,
block_size
,
grid_size
};
}
public:
// Forwards to ck_headers() and returns whatever it produces; no state of
// this object is read.
auto
GetHeaders
()
const
{
return
ck_headers
();
}
// Returns the result of get_include_header() called on a temporary,
// default-constructed instance::gemm_add_add_fastgelu_instances object.
auto
GetIncludeHeader
()
const
{
return
instance
::
gemm_add_add_fastgelu_instances
{}.
get_include_header
();
}
auto
GetSolutions
(
const
std
::
string
&
arch
)
const
{
std
::
vector
<
Solution
>
solutions
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment