gaoqiong / composable_kernel_ROCM
Commit 5ab76075, authored Oct 24, 2024 by Aleksander Dudek
Batched gemm - passed batch args
parent 533204d6
Showing 1 changed file with 36 additions and 12 deletions
example/ck_tile/05_batched_gemm/run_batched_gemm_example.inc
...
...
@@ -14,6 +14,10 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
                  ck_tile::index_t stride_B,
                  ck_tile::index_t stride_C,
                  ck_tile::index_t kbatch,
                  ck_tile::index_t batch_stride_A,
                  ck_tile::index_t batch_stride_B,
                  ck_tile::index_t batch_stride_C,
                  ck_tile::index_t batch_count,
                  int n_warmup,
                  int n_repeat)
{
...
...
@@ -28,6 +32,10 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
    args.stride_A = stride_A;
    args.stride_B = stride_B;
    args.stride_C = stride_C;
    args.batch_stride_A = batch_stride_A;
    args.batch_stride_B = batch_stride_B;
    args.batch_stride_C = batch_stride_C;
    args.batch_count = batch_count;

    float ave_time = gemm_calc<ALayout, BLayout, CLayout>(
        args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
...
...
@@ -63,6 +71,18 @@ int run_batched_gemm_example(int argc, char* argv[])
    ck_tile::index_t stride_C = arg_parser.get_int("stride_c");
    ck_tile::index_t batch_size = arg_parser.get_int("b");
    ck_tile::index_t batch_stride_A = arg_parser.get_int("batch_stride_a");
    ck_tile::index_t batch_stride_B = arg_parser.get_int("batch_stride_b");
    ck_tile::index_t batch_stride_C = arg_parser.get_int("batch_stride_c");
    ck_tile::index_t batch_count = arg_parser.get_int("batch_count");
    std::cout << "Received args: " << std::endl;
    std::cout << "batch_stride_A: " << batch_stride_A << '\n'
              << "batch_stride_B: " << batch_stride_B << '\n'
              << "batch_stride_C: " << batch_stride_C << '\n'
              << "batch_count: " << batch_count << std::endl;
    int n_warmup = arg_parser.get_int("warmup");
    int n_repeat = arg_parser.get_int("repeat");
...
...
@@ -137,6 +157,10 @@ int run_batched_gemm_example(int argc, char* argv[])
stride_B,
stride_C,
batch_size,
batch_stride_A,
batch_stride_B,
batch_stride_C,
batch_count,
n_warmup,
n_repeat);
...
...
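For readers unfamiliar with the new arguments, the following is a minimal host-side sketch of what batch strides and a batch count conventionally mean in a strided batched GEMM. It is not taken from this commit and does not use the ck_tile API: `reference_batched_gemm` is a hypothetical helper, and it assumes row-major A, B, and C, with `stride_*` as the distance between consecutive rows and `batch_stride_*` as the distance between the first elements of consecutive batches. Whether the ck_tile kernel interprets the fields exactly this way is an assumption.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical CPU reference, NOT part of ck_tile.
// Computes C[b] = A[b] * B[b] for every b in [0, batch_count).
// Assumption: all three matrices are row-major.
//   stride_X       : elements between consecutive rows of one matrix
//   batch_stride_X : elements between the starts of consecutive batches
void reference_batched_gemm(const std::vector<float>& a,
                            const std::vector<float>& b,
                            std::vector<float>& c,
                            std::size_t M, std::size_t N, std::size_t K,
                            std::size_t stride_A, std::size_t stride_B,
                            std::size_t stride_C,
                            std::size_t batch_stride_A,
                            std::size_t batch_stride_B,
                            std::size_t batch_stride_C,
                            std::size_t batch_count)
{
    if(batch_count == 0 || M == 0 || N == 0 || K == 0)
        return; // nothing to compute

    // Smallest buffer size that holds every element touched below.
    auto required = [&](std::size_t rows, std::size_t cols,
                        std::size_t stride, std::size_t batch_stride) {
        return (batch_count - 1) * batch_stride + (rows - 1) * stride + cols;
    };
    assert(a.size() >= required(M, K, stride_A, batch_stride_A));
    assert(b.size() >= required(K, N, stride_B, batch_stride_B));
    assert(c.size() >= required(M, N, stride_C, batch_stride_C));

    for(std::size_t batch = 0; batch < batch_count; ++batch)
    {
        // The batch stride selects which matrix of the batch is used.
        const float* a_ptr = a.data() + batch * batch_stride_A;
        const float* b_ptr = b.data() + batch * batch_stride_B;
        float* c_ptr       = c.data() + batch * batch_stride_C;

        for(std::size_t m = 0; m < M; ++m)
        {
            for(std::size_t n = 0; n < N; ++n)
            {
                float acc = 0.0f;
                for(std::size_t k = 0; k < K; ++k)
                    acc += a_ptr[m * stride_A + k] * b_ptr[k * stride_B + n];
                c_ptr[m * stride_C + n] = acc;
            }
        }
    }
}
```

Under these assumptions, densely packed batches would use `batch_stride_A = M * K`, `batch_stride_B = K * N`, and `batch_stride_C = M * N`. Larger values leave padding between consecutive matrices.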