Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
228bc13a
Unverified
Commit
228bc13a
authored
Dec 19, 2024
by
M.Emin Ozturk
Committed by
GitHub
Dec 19, 2024
Browse files
Merge branch 'develop' into gemm_bf16_sk_muozturk
parents
39bc3b83
e758d006
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
78 additions
and
22 deletions
+78
-22
example/ck_tile/17_grouped_gemm/grouped_gemm.hpp
example/ck_tile/17_grouped_gemm/grouped_gemm.hpp
+13
-7
example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc
example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc
+21
-13
include/ck_tile/host/arg_parser.hpp
include/ck_tile/host/arg_parser.hpp
+44
-2
No files found.
example/ck_tile/17_grouped_gemm/grouped_gemm.hpp
View file @
228bc13a
...
@@ -34,13 +34,19 @@ using grouped_gemm_kargs = ck_tile::GroupedGemmHostArgs;
...
@@ -34,13 +34,19 @@ using grouped_gemm_kargs = ck_tile::GroupedGemmHostArgs;
auto
create_args
(
int
argc
,
char
*
argv
[])
auto
create_args
(
int
argc
,
char
*
argv
[])
{
{
ck_tile
::
ArgParser
arg_parser
;
ck_tile
::
ArgParser
arg_parser
;
arg_parser
.
insert
(
"a_layout"
,
"R"
,
"A tensor data layout - Row by default"
)
arg_parser
.
insert
(
"Ms"
,
""
,
"M dimensions - empty by default."
)
.
insert
(
"b_layout"
,
"R"
,
"B tensor data layout - Row by default"
)
.
insert
(
"Ns"
,
""
,
"N dimensions - empty by default."
)
.
insert
(
"c_layout"
,
"R"
,
"C tensor data layout - Row by default"
)
.
insert
(
"Ks"
,
""
,
"K dimensions - empty by default."
)
.
insert
(
"validate"
,
"1"
,
"0. No validation, 1. Validation on CPU"
)
.
insert
(
"stride_As"
,
""
,
"Tensor A strides - it is empty by default."
)
.
insert
(
"warmup"
,
"10"
,
"number of iterations before benchmark the kernel"
)
.
insert
(
"stride_Bs"
,
""
,
"Tensor B strides - it is empty by default."
)
.
insert
(
"repeat"
,
"100"
,
"number of iterations to benchmark the kernel"
)
.
insert
(
"stride_Cs"
,
""
,
"Tensor C strides - it is empty by default."
)
.
insert
(
"group_count"
,
"16"
,
"group count"
);
.
insert
(
"a_layout"
,
"R"
,
"A tensor data layout - Row by default."
)
.
insert
(
"b_layout"
,
"R"
,
"B tensor data layout - Row by default."
)
.
insert
(
"c_layout"
,
"R"
,
"C tensor data layout - Row by default."
)
.
insert
(
"validate"
,
"1"
,
"0. No validation, 1. Validation on CPU."
)
.
insert
(
"warmup"
,
"10"
,
"number of iterations before benchmark the kernel."
)
.
insert
(
"repeat"
,
"100"
,
"number of iterations to benchmark the kernel."
)
.
insert
(
"group_count"
,
"16"
,
"group count."
);
bool
result
=
arg_parser
.
parse
(
argc
,
argv
);
bool
result
=
arg_parser
.
parse
(
argc
,
argv
);
return
std
::
make_tuple
(
result
,
arg_parser
);
return
std
::
make_tuple
(
result
,
arg_parser
);
...
...
example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc
View file @
228bc13a
...
@@ -53,26 +53,34 @@ int run_grouped_gemm_example_with_layouts(int argc,
...
@@ -53,26 +53,34 @@ int run_grouped_gemm_example_with_layouts(int argc,
return
-
1
;
return
-
1
;
};
};
auto
valid_input_data
=
[
&
](
int
group_count
,
const
auto
&...
args
)
{
return
!
(
args
.
empty
()
||
...
)
&&
group_count
==
(
args
.
size
()
==
...
);
};
const
int
group_count
=
arg_parser
.
get_int
(
"group_count"
);
const
int
group_count
=
arg_parser
.
get_int
(
"group_count"
);
const
int
repeat
=
arg_parser
.
get_int
(
"repeat"
);
const
int
repeat
=
arg_parser
.
get_int
(
"repeat"
);
const
int
warmup
=
arg_parser
.
get_int
(
"warmup"
);
const
int
warmup
=
arg_parser
.
get_int
(
"warmup"
);
std
::
vector
<
ck_tile
::
index_t
>
Ms
;
std
::
vector
<
ck_tile
::
index_t
>
Ms
=
arg_parser
.
get_int_vec
(
"Ms"
)
;
std
::
vector
<
ck_tile
::
index_t
>
Ns
;
std
::
vector
<
ck_tile
::
index_t
>
Ns
=
arg_parser
.
get_int_vec
(
"Ns"
)
;
std
::
vector
<
ck_tile
::
index_t
>
Ks
;
std
::
vector
<
ck_tile
::
index_t
>
Ks
=
arg_parser
.
get_int_vec
(
"Ks"
)
;
std
::
vector
<
ck_tile
::
index_t
>
stride_As
;
std
::
vector
<
ck_tile
::
index_t
>
stride_As
=
arg_parser
.
get_int_vec
(
"stride_As"
)
;
std
::
vector
<
ck_tile
::
index_t
>
stride_Bs
;
std
::
vector
<
ck_tile
::
index_t
>
stride_Bs
=
arg_parser
.
get_int_vec
(
"stride_Bs"
)
;
std
::
vector
<
ck_tile
::
index_t
>
stride_Cs
;
std
::
vector
<
ck_tile
::
index_t
>
stride_Cs
=
arg_parser
.
get_int_vec
(
"stride_Cs"
)
;
for
(
int
i
=
0
;
i
<
group_count
;
i
++
)
if
(
!
valid_input_data
(
group_count
,
Ms
,
Ns
,
Ks
,
stride_As
,
stride_Bs
,
stride_Cs
)
)
{
{
Ms
.
push_back
(
256
+
256
*
i
);
std
::
cout
<<
"Please check the input data. Default values will be used."
<<
std
::
endl
;
Ns
.
push_back
(
128
+
128
*
i
);
for
(
int
i
=
0
;
i
<
group_count
;
i
++
)
Ks
.
push_back
(
128
+
64
*
i
);
{
Ms
.
push_back
(
256
+
256
*
i
);
Ns
.
push_back
(
128
+
128
*
i
);
Ks
.
push_back
(
128
+
64
*
i
);
stride_As
.
push_back
(
Ks
[
i
]);
stride_As
.
push_back
(
Ks
[
i
]);
stride_Bs
.
push_back
(
Ks
[
i
]);
stride_Bs
.
push_back
(
Ks
[
i
]);
stride_Cs
.
push_back
(
Ns
[
i
]);
stride_Cs
.
push_back
(
Ns
[
i
]);
}
}
}
std
::
vector
<
ck_tile
::
HostTensor
<
ADataType
>>
a_m_k_tensors
;
std
::
vector
<
ck_tile
::
HostTensor
<
ADataType
>>
a_m_k_tensors
;
...
...
include/ck_tile/host/arg_parser.hpp
View file @
228bc13a
...
@@ -15,11 +15,14 @@
...
@@ -15,11 +15,14 @@
namespace
ck_tile
{
namespace
ck_tile
{
/*
/*
* a host side utility, arg parser for
* a host side utility, arg parser for, either
* -[key0]=[value0] -[key1]=[value1] ...
* -[key0] = [value0, value1, value2]
* or
* -[key0]=[value0] -[key1]=[value1] ...
*/
*/
class
ArgParser
class
ArgParser
{
{
public:
public:
class
Arg
class
Arg
{
{
...
@@ -187,6 +190,45 @@ class ArgParser
...
@@ -187,6 +190,45 @@ class ArgParser
return
value
;
return
value
;
}
}
std
::
vector
<
std
::
string
>
get_string_vec
(
const
std
::
string
&
name
,
const
std
::
string
&
delimiter
=
","
)
const
{
if
(
get_str
(
name
).
empty
())
{
return
{};
}
std
::
string
s
=
get_str
(
name
);
std
::
vector
<
std
::
string
>
tokens
;
size_t
pos
=
0
;
std
::
string
token
;
while
((
pos
=
s
.
find
(
delimiter
))
!=
std
::
string
::
npos
)
{
token
=
s
.
substr
(
0
,
pos
);
tokens
.
push_back
(
token
);
s
.
erase
(
0
,
pos
+
delimiter
.
length
());
}
tokens
.
push_back
(
s
);
return
tokens
;
}
std
::
vector
<
int
>
get_int_vec
(
const
std
::
string
&
name
,
const
std
::
string
&
delimiter
=
","
)
const
{
if
(
get_str
(
name
).
empty
())
{
return
{};
}
const
std
::
vector
<
std
::
string
>
args
=
get_string_vec
(
name
,
delimiter
);
std
::
vector
<
int
>
tokens
;
tokens
.
reserve
(
static_cast
<
int
>
(
args
.
size
()));
for
(
const
std
::
string
&
token
:
args
)
{
int
value
=
atoi
(
token
.
c_str
());
tokens
.
push_back
(
value
);
}
return
tokens
;
}
private:
private:
std
::
unordered_map
<
std
::
string
,
Arg
>
input_map
;
std
::
unordered_map
<
std
::
string
,
Arg
>
input_map
;
std
::
vector
<
std
::
string
>
keys
;
std
::
vector
<
std
::
string
>
keys
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment