Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
84213e27
Commit
84213e27
authored
Sep 08, 2023
by
Harisankar Sadasivan
Browse files
modified for correctness pr#881
parent
1c03a65d
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
33 additions
and
28 deletions
+33
-28
example/53_gemv_splitk/README.md
example/53_gemv_splitk/README.md
+14
-14
example/53_gemv_splitk/common.hpp
example/53_gemv_splitk/common.hpp
+11
-8
example/53_gemv_splitk/run_gemv_splitk_example.inc
example/53_gemv_splitk/run_gemv_splitk_example.inc
+8
-6
No files found.
example/53_gemv_splitk/README.md
View file @
84213e27
# Instructions for ```example_gem
m_xdl
```
# Instructions for ```example_gem
v_splitk
```
## Run ```example_gem
m_xdl
```
## Run ```example_gem
v_splitk
```
```
bash
#arg1: verification (0=no, 1=yes)
#arg2: initialization (0=no init, 1=integer value, 2=decimal value)
#arg3: run kernel # of times (>1)
#arg4: number of split
K
batches
./bin/example_gem
m_xdl
0 1 5 151
#arg4: number of split
k
batches
./bin/example_gem
v_splitk
0 1 5 151
```
Result (MI
10
0 @
1087
Mhz, 1
33.
5TFlops peak FP16)
Result (MI
25
0 @
800
Mhz, 1
81.0
5TFlops peak FP16)
```
a_m_k: dim 2, lengths {
3840, 4096
}, strides {4
096
, 1}
b_k_n: dim 2, lengths {4
096, 4096
}, strides {1, 4
096
}
c_m_n: dim 2, lengths {
3840, 4096
}, strides {
4096
, 1}
arg.a_grid_desc_k0_m_k1_{
512, 3840
, 8}
arg.b_grid_desc_k0_n_k1_{
512, 4096
, 8}
arg.c_grid_desc_m_n_{
3840, 4096
}
launch_and_time_kernel: grid_dim {
480
, 1, 1}, block_dim {
25
6, 1, 1}
a_m_k: dim 2, lengths {
1, 4608
}, strides {4
608
, 1}
b_k_n: dim 2, lengths {4
608, 1104
}, strides {1, 4
608
}
c_m_n: dim 2, lengths {
1, 1104
}, strides {
1104
, 1}
arg.a_grid_desc_
kbatch_
k0_m_k1_{
1,4, 1
, 8}
arg.b_grid_desc_
kbatch_
k0_n_k1_{
1,4, 1104
, 8}
arg.c_grid_desc_m_n_{
1, 1104
}
launch_and_time_kernel: grid_dim {
1359
, 1, 1}, block_dim {6
4
, 1, 1}
Warm up
Start running
5
times...
Perf:
1.19685 ms, 107.657 TFlops, 78.8501
GB/s
Start running
10
times...
Perf:
0.0191358 ms, 0.531698 TFlops,532.295
GB/s
```
example/53_gemv_splitk/common.hpp
View file @
84213e27
...
...
@@ -55,25 +55,28 @@ parse_cmd_args(int argc, char* argv[], ProblemSize& problem_size, ExecutionConfi
{
// use default case
}
else
if
(
argc
==
4
)
else
if
(
argc
==
5
)
{
config
.
do_verification
=
std
::
stoi
(
argv
[
1
]);
config
.
init_method
=
std
::
stoi
(
argv
[
2
]);
config
.
time_kernel
=
std
::
stoi
(
argv
[
3
]);
problem_size
.
k_batch
=
std
::
stoi
(
argv
[
4
]);
}
else
if
(
argc
==
1
0
)
else
if
(
argc
==
1
1
)
{
config
.
do_verification
=
std
::
stoi
(
argv
[
1
]);
config
.
init_method
=
std
::
stoi
(
argv
[
2
]);
config
.
time_kernel
=
std
::
stoi
(
argv
[
3
]);
problem_size
.
k_batch
=
std
::
stoi
(
argv
[
4
]);
problem_size
.
M
=
std
::
stoi
(
argv
[
4
]);
problem_size
.
N
=
std
::
stoi
(
argv
[
5
]);
problem_size
.
K
=
std
::
stoi
(
argv
[
6
]);
problem_size
.
M
=
std
::
stoi
(
argv
[
5
]);
problem_size
.
N
=
std
::
stoi
(
argv
[
6
]);
problem_size
.
K
=
std
::
stoi
(
argv
[
7
]);
problem_size
.
S
trideA
=
std
::
stoi
(
argv
[
7
]);
problem_size
.
S
trideB
=
std
::
stoi
(
argv
[
8
]);
problem_size
.
S
trideC
=
std
::
stoi
(
argv
[
9
]);
problem_size
.
s
tride
_
A
=
std
::
stoi
(
argv
[
8
]);
problem_size
.
s
tride
_
B
=
std
::
stoi
(
argv
[
9
]);
problem_size
.
s
tride
_
C
=
std
::
stoi
(
argv
[
10
]);
}
else
{
...
...
example/53_gemv_splitk/run_gemv_splitk_example.inc
View file @
84213e27
...
...
@@ -103,12 +103,13 @@ bool run_gemv(const ProblemSize& problem_size, const ExecutionConfig& config)
return
true
;
}
c_m_n_device_buf
.
Zero
();
c_m_n_device_buf
.
Set
Zero
();
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
false
});
// Run prior to verification
if
(
config
.
do_verification
)
{
auto
ref_gemv
=
ReferenceGemmInstance
{};
auto
ref_invoker
=
ref_gemv
.
MakeInvoker
();
...
...
@@ -124,11 +125,9 @@ bool run_gemv(const ProblemSize& problem_size, const ExecutionConfig& config)
c_m_n_device_result
=
c_m_n_device_result_converted
.
CopyAsType
<
CDataType
>
();
return
ck
::
utils
::
check_err
(
c_m_n_device_result_converted
,
c_m_n_host_result
);
#else
c_m_n_device_buf
.
FromDevice
(
c_m_n_device_result
.
mData
.
data
());
return
ck
::
utils
::
check_err
(
c_m_n_device_result
,
c_m_n_host_result
);
#endif
}
...
...
@@ -146,13 +145,16 @@ bool run_gemv(const ProblemSize& problem_size, const ExecutionConfig& config)
std
::
cout
<<
"Perf: "
<<
ave_time
<<
" ms, "
<<
tflops
<<
" TFlops, "
<<
gb_per_sec
<<
" GB/s, "
<<
gemv
.
GetTypeString
()
<<
std
::
endl
;
return
true
;
#ifdef BUILD_INT4_EXAMPLE
return
ck
::
utils
::
check_err
(
c_m_n_device_result_converted
,
c_m_n_host_result
);
#else
return
ck
::
utils
::
check_err
(
c_m_n_device_result
,
c_m_n_host_result
);
#endif
}
bool
run_gemv_example
(
int
argc
,
char
*
argv
[])
{
ProblemSize
problem_size
;
// problem_size.M = 1;
ExecutionConfig
config
;
if
(
argc
==
1
)
{
...
...
@@ -185,7 +187,7 @@ bool run_gemv_example(int argc, char* argv[])
printf
(
"arg1: verification (0=no, 1=yes)
\n
"
);
printf
(
"arg2: initialization (0=no init, 1=integer value, 2=decimal value)
\n
"
);
printf
(
"arg3: time kernel (0=no, 1=yes)
\n
"
);
printf
(
"arg4:
KBatch
\n
"
);
printf
(
"arg4:
splitk
\n
"
);
printf
(
"arg5 to 11: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC
\n
"
);
exit
(
0
);
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment