Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
71eea17c
"...composable_kernel_rocm.git" did not exist on "4b616aad52807740908071e90e06e184d3177357"
Commit
71eea17c
authored
Oct 29, 2024
by
Aleksander Dudek
Browse files
Batched gemm - counting strides
parent
5ab76075
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
19 additions
and
15 deletions
+19
-15
example/ck_tile/05_batched_gemm/batched_gemm_basic.cpp
example/ck_tile/05_batched_gemm/batched_gemm_basic.cpp
+1
-1
example/ck_tile/05_batched_gemm/batched_gemm_basic.hpp
example/ck_tile/05_batched_gemm/batched_gemm_basic.hpp
+10
-10
include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp
include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp
+8
-4
No files found.
example/ck_tile/05_batched_gemm/batched_gemm_basic.cpp
View file @
71eea17c
...
@@ -96,7 +96,7 @@ float gemm_calc(const batched_gemm_basic_args& args, const ck_tile::stream_confi
...
@@ -96,7 +96,7 @@ float gemm_calc(const batched_gemm_basic_args& args, const ck_tile::stream_confi
args
.
batch_stride_C
,
args
.
batch_stride_C
,
args
.
batch_count
);
args
.
batch_count
);
const
dim3
grids
=
Kernel
::
GridSize
(
args
.
M
,
args
.
N
,
args
.
k
batch
);
const
dim3
grids
=
Kernel
::
GridSize
(
args
.
M
,
args
.
N
,
args
.
batch
_count
);
constexpr
dim3
blocks
=
Kernel
::
BlockSize
();
constexpr
dim3
blocks
=
Kernel
::
BlockSize
();
if
(
s
.
log_level_
>
0
)
if
(
s
.
log_level_
>
0
)
...
...
example/ck_tile/05_batched_gemm/batched_gemm_basic.hpp
View file @
71eea17c
...
@@ -73,16 +73,16 @@ auto create_args(int argc, char* argv[])
...
@@ -73,16 +73,16 @@ auto create_args(int argc, char* argv[])
{
{
ck_tile
::
ArgParser
arg_parser
;
ck_tile
::
ArgParser
arg_parser
;
arg_parser
.
insert
(
"b"
,
"1"
,
"batch size"
)
arg_parser
.
insert
(
"b"
,
"1"
,
"batch size"
)
.
insert
(
"m"
,
"
3840
"
,
"m dimension"
)
.
insert
(
"m"
,
"
256
"
,
"m dimension"
)
.
insert
(
"n"
,
"
4096
"
,
"n dimension"
)
.
insert
(
"n"
,
"
128
"
,
"n dimension"
)
.
insert
(
"k"
,
"
4096
"
,
"k dimension"
)
.
insert
(
"k"
,
"
128
"
,
"k dimension"
)
.
insert
(
"stride_a"
,
"
0
"
,
"Tensor A stride"
)
.
insert
(
"stride_a"
,
"
128
"
,
"Tensor A stride"
)
.
insert
(
"stride_b"
,
"
0
"
,
"Tensor B stride"
)
.
insert
(
"stride_b"
,
"
128
"
,
"Tensor B stride"
)
.
insert
(
"stride_c"
,
"
0
"
,
"Tensor C stride"
)
.
insert
(
"stride_c"
,
"
128
"
,
"Tensor C stride"
)
.
insert
(
"batch_stride_a"
,
"
0
"
,
"Batch A stride"
)
.
insert
(
"batch_stride_a"
,
"
32768
"
,
"Batch A stride"
)
.
insert
(
"batch_stride_b"
,
"
0
"
,
"Batch B stride"
)
.
insert
(
"batch_stride_b"
,
"
16384
"
,
"Batch B stride"
)
.
insert
(
"batch_stride_c"
,
"
0
"
,
"Batch C stride"
)
.
insert
(
"batch_stride_c"
,
"
32768
"
,
"Batch C stride"
)
.
insert
(
"batch_count"
,
"1"
,
"Batch count"
)
.
insert
(
"batch_count"
,
"1
6
"
,
"Batch count"
)
.
insert
(
"v"
,
"2"
,
"0. No validation, 1. Validation on CPU, 2. Validation on GPU"
)
.
insert
(
"v"
,
"2"
,
"0. No validation, 1. Validation on CPU, 2. Validation on GPU"
)
.
insert
(
"prec"
,
"fp16"
,
"data type. fp16/bf16/fp8/bf8"
)
.
insert
(
"prec"
,
"fp16"
,
"data type. fp16/bf16/fp8/bf8"
)
.
insert
(
"warmup"
,
"50"
,
"number of iterations before benchmark the kernel"
)
.
insert
(
"warmup"
,
"50"
,
"number of iterations before benchmark the kernel"
)
...
...
include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp
View file @
71eea17c
...
@@ -89,9 +89,12 @@ struct BatchedGemmKernel
...
@@ -89,9 +89,12 @@ struct BatchedGemmKernel
CK_TILE_DEVICE
void
operator
()(
BatchedGemmCommonKargs
kargs
)
const
CK_TILE_DEVICE
void
operator
()(
BatchedGemmCommonKargs
kargs
)
const
{
{
const
auto
[
i_m
,
i_n
]
=
TilePartitioner
{}();
const
auto
[
i_m
,
i_n
]
=
TilePartitioner
{}();
// options
// const auto i_k = blockIdx.z;
const
ADataType
*
a_start
=
static_cast
<
const
ADataType
*>
(
kargs
.
a_ptr
);
// options
const
BDataType
*
b_start
=
static_cast
<
const
BDataType
*>
(
kargs
.
b_ptr
);
const
ADataType
*
a_start
=
static_cast
<
const
ADataType
*>
(
kargs
.
a_ptr
);
//+ __builtin_amdgcn_readfirstlane(i_k * kargs.batch_stride_A);
const
BDataType
*
b_start
=
static_cast
<
const
BDataType
*>
(
kargs
.
b_ptr
);
//+ __builtin_amdgcn_readfirstlane(i_k * kargs.batch_stride_B);
// Convert pointers to tensor views
// Convert pointers to tensor views
auto
a_tensor_view
=
[
&
]()
{
auto
a_tensor_view
=
[
&
]()
{
if
constexpr
(
std
::
is_same_v
<
ALayout
,
tensor_layout
::
gemm
::
RowMajor
>
)
if
constexpr
(
std
::
is_same_v
<
ALayout
,
tensor_layout
::
gemm
::
RowMajor
>
)
...
@@ -169,7 +172,8 @@ struct BatchedGemmKernel
...
@@ -169,7 +172,8 @@ struct BatchedGemmKernel
auto
c_block_tile
=
auto
c_block_tile
=
GemmPipeline
{}.
template
operator
()(
a_block_window
,
b_block_window
,
num_loop
,
smem_ptr
);
GemmPipeline
{}.
template
operator
()(
a_block_window
,
b_block_window
,
num_loop
,
smem_ptr
);
CDataType
*
c_start
=
static_cast
<
CDataType
*>
(
kargs
.
c_ptr
);
CDataType
*
c_start
=
static_cast
<
CDataType
*>
(
kargs
.
c_ptr
);
//; + __builtin_amdgcn_readfirstlane(i_k * kargs.batch_stride_C);
auto
c_tensor_view
=
[
&
]()
{
auto
c_tensor_view
=
[
&
]()
{
if
constexpr
(
std
::
is_same_v
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>
)
if
constexpr
(
std
::
is_same_v
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>
)
{
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment