Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
799cde32
"Src/vscode:/vscode.git/clone" did not exist on "ae781c5430a1a4aeb531421f727102997a2c0d3b"
Commit
799cde32
authored
Dec 11, 2024
by
Aleksander Dudek
Browse files
[CK TILE] Refactor GemmKernel - review changes part1
parent
ed528d76
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
59 additions
and
56 deletions
+59
-56
example/ck_tile/16_batched_gemm/batched_gemm.hpp
example/ck_tile/16_batched_gemm/batched_gemm.hpp
+1
-1
include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp
include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp
+5
-14
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
+52
-40
test/ck_tile/batched_gemm/test_batched_gemm_util.hpp
test/ck_tile/batched_gemm/test_batched_gemm_util.hpp
+1
-1
No files found.
example/ck_tile/16_batched_gemm/batched_gemm.hpp
View file @
799cde32
...
...
@@ -29,7 +29,7 @@ using BDataType = Types::BDataType;
using
AccDataType
=
Types
::
AccDataType
;
using
CDataType
=
Types
::
CDataType
;
struct
batched_gemm_kargs
:
public
ck_tile
::
BatchedGemmH
ostA
rgs
struct
batched_gemm_kargs
:
public
ck_tile
::
BatchedGemmH
a
rgs
{
};
...
...
include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp
View file @
799cde32
...
...
@@ -7,17 +7,8 @@
namespace
ck_tile
{
struct
BatchedGemmH
ostA
rgs
struct
BatchedGemmH
args
:
GemmHa
rgs
{
const
void
*
a_ptr
;
const
void
*
b_ptr
;
void
*
c_ptr
;
index_t
M
;
index_t
N
;
index_t
K
;
index_t
stride_A
;
index_t
stride_B
;
index_t
stride_C
;
index_t
batch_stride_A
;
index_t
batch_stride_B
;
index_t
batch_stride_C
;
...
...
@@ -29,7 +20,7 @@ struct BatchedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
{
using
Base
=
GemmKernel
<
TilePartitioner_
,
GemmPipeline_
,
EpiloguePipeline_
>
;
using
Gemm
Common
Kargs
=
typename
Base
::
Gemm
Common
Kargs
;
using
GemmKargs
=
typename
Base
::
GemmKargs
;
using
ADataType
=
typename
Base
::
ADataType
;
using
BDataType
=
typename
Base
::
BDataType
;
...
...
@@ -42,7 +33,7 @@ struct BatchedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
using
BLayout
=
typename
Base
::
BLayout
;
using
CLayout
=
typename
Base
::
CLayout
;
struct
BatchedGemmKargs
:
Gemm
Common
Kargs
struct
BatchedGemmKargs
:
GemmKargs
{
index_t
batch_stride_A
;
index_t
batch_stride_B
;
...
...
@@ -51,7 +42,7 @@ struct BatchedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
};
using
Kargs
=
BatchedGemmKargs
;
using
Hargs
=
BatchedGemmH
ostA
rgs
;
using
Hargs
=
BatchedGemmH
a
rgs
;
__host__
static
constexpr
auto
GridSize
(
const
Hargs
&
k
)
{
...
...
@@ -102,7 +93,7 @@ struct BatchedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
const
auto
batch_offset_C
=
__builtin_amdgcn_readfirstlane
(
i_batch
*
batch_stride_C
);
CDataType
*
c_start
=
static_cast
<
CDataType
*>
(
kargs
.
c_ptr
)
+
batch_offset_C
;
this
->
r
un
_common_gemm_pipeline
(
a_start
,
b_start
,
c_start
,
kargs
,
i_m
,
i_n
);
this
->
R
un
Gemm
(
a_start
,
b_start
,
c_start
,
kargs
,
i_m
,
i_n
);
}
};
...
...
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
View file @
799cde32
...
...
@@ -12,6 +12,19 @@
namespace
ck_tile
{
struct
GemmHargs
{
const
void
*
a_ptr
;
const
void
*
b_ptr
;
void
*
c_ptr
;
index_t
M
;
index_t
N
;
index_t
K
;
index_t
stride_A
;
index_t
stride_B
;
index_t
stride_C
;
};
template
<
typename
TilePartitioner_
,
typename
GemmPipeline_
,
typename
EpiloguePipeline_
>
struct
GemmKernel
{
...
...
@@ -25,7 +38,6 @@ struct GemmKernel
using
ADataType
=
remove_cvref_t
<
typename
GemmPipeline
::
ADataType
>
;
using
BDataType
=
remove_cvref_t
<
typename
GemmPipeline
::
BDataType
>
;
// using CAccDataType = remove_cvref_t<typename GemmPipeline::CDataType>;
using
CDataType
=
remove_cvref_t
<
typename
EpiloguePipeline
::
ODataType
>
;
static
constexpr
auto
I0
=
number
<
0
>
();
...
...
@@ -39,7 +51,7 @@ struct GemmKernel
__host__
static
constexpr
auto
BlockSize
()
{
return
dim3
(
KernelBlockSize
);
}
struct
Gemm
Common
Kargs
struct
GemmKargs
{
const
void
*
a_ptr
;
const
void
*
b_ptr
;
...
...
@@ -52,7 +64,7 @@ struct GemmKernel
index_t
stride_C
;
};
CK_TILE_HOST
static
constexpr
Gemm
Common
Kargs
MakeKargs
(
const
void
*
a_ptr
,
CK_TILE_HOST
static
constexpr
GemmKargs
MakeKargs
(
const
void
*
a_ptr
,
const
void
*
b_ptr
,
void
*
c_ptr
,
index_t
M
,
...
...
@@ -62,7 +74,7 @@ struct GemmKernel
index_t
stride_B
,
index_t
stride_C
)
{
return
Gemm
Common
Kargs
{
a_ptr
,
b_ptr
,
c_ptr
,
M
,
N
,
K
,
stride_A
,
stride_B
,
stride_C
};
return
GemmKargs
{
a_ptr
,
b_ptr
,
c_ptr
,
M
,
N
,
K
,
stride_A
,
stride_B
,
stride_C
};
}
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSize
()
...
...
@@ -70,7 +82,7 @@ struct GemmKernel
return
max
(
GemmPipeline
::
GetSmemSize
(),
EpiloguePipeline
::
GetSmemSize
());
}
CK_TILE_HOST
static
bool
IsSupportedArgument
(
const
Gemm
Common
Kargs
&
kargs
)
CK_TILE_HOST
static
bool
IsSupportedArgument
(
const
GemmKargs
&
kargs
)
{
if
constexpr
(
std
::
is_same_v
<
ALayout
,
tensor_layout
::
gemm
::
RowMajor
>
)
{
...
...
@@ -143,16 +155,16 @@ struct GemmKernel
return
true
;
}
CK_TILE_DEVICE
auto
m
ake
_g
emm
_t
ensor
_v
iews
(
const
ADataType
*
a_
start
,
const
BDataType
*
b_
start
,
CDataType
*
c_
start
,
const
Gemm
Common
Kargs
&
kargs
)
const
CK_TILE_DEVICE
auto
M
ake
G
emm
T
ensor
V
iews
(
const
ADataType
*
a_
ptr
,
const
BDataType
*
b_
ptr
,
CDataType
*
c_
ptr
,
const
GemmKargs
&
kargs
)
const
{
auto
a_tensor_view
=
[
&
]()
{
if
constexpr
(
std
::
is_same_v
<
ALayout
,
tensor_layout
::
gemm
::
RowMajor
>
)
{
return
make_naive_tensor_view
<
address_space_enum
::
global
>
(
a_
start
,
a_
ptr
,
make_tuple
(
kargs
.
M
,
kargs
.
K
),
make_tuple
(
kargs
.
stride_A
,
1
),
number
<
GemmPipeline
::
VectorSizeA
>
{},
...
...
@@ -161,7 +173,7 @@ struct GemmKernel
else
{
return
make_naive_tensor_view
<
address_space_enum
::
global
>
(
a_
start
,
a_
ptr
,
make_tuple
(
kargs
.
M
,
kargs
.
K
),
make_tuple
(
1
,
kargs
.
stride_A
),
number
<
1
>
{},
...
...
@@ -173,7 +185,7 @@ struct GemmKernel
if
constexpr
(
std
::
is_same_v
<
BLayout
,
tensor_layout
::
gemm
::
RowMajor
>
)
{
return
make_naive_tensor_view
<
address_space_enum
::
global
>
(
b_
start
,
b_
ptr
,
make_tuple
(
kargs
.
N
,
kargs
.
K
),
make_tuple
(
1
,
kargs
.
stride_B
),
number
<
1
>
{},
...
...
@@ -182,7 +194,7 @@ struct GemmKernel
else
{
return
make_naive_tensor_view
<
address_space_enum
::
global
>
(
b_
start
,
b_
ptr
,
make_tuple
(
kargs
.
N
,
kargs
.
K
),
make_tuple
(
kargs
.
stride_B
,
1
),
number
<
GemmPipeline
::
VectorSizeB
>
{},
...
...
@@ -194,7 +206,7 @@ struct GemmKernel
if
constexpr
(
std
::
is_same_v
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>
)
{
return
make_naive_tensor_view
<
address_space_enum
::
global
>
(
c_
start
,
c_
ptr
,
make_tuple
(
kargs
.
M
,
kargs
.
N
),
make_tuple
(
kargs
.
stride_C
,
1
),
number
<
GemmPipeline
::
VectorSizeC
>
{},
...
...
@@ -203,7 +215,7 @@ struct GemmKernel
else
{
return
make_naive_tensor_view
<
address_space_enum
::
global
>
(
c_
start
,
c_
ptr
,
make_tuple
(
kargs
.
M
,
kargs
.
N
),
make_tuple
(
1
,
kargs
.
stride_C
),
number
<
1
>
{},
...
...
@@ -215,7 +227,7 @@ struct GemmKernel
}
template
<
typename
TensorView
>
CK_TILE_DEVICE
auto
m
ake
_g
emm
_pad_v
iews
(
TensorView
&
&
views
)
const
CK_TILE_DEVICE
auto
M
ake
G
emm
PadV
iews
(
TensorView
&
views
)
const
{
auto
a_pad_view
=
[
&
]()
{
auto
a_tensor_view
=
views
.
at
(
I0
);
...
...
@@ -276,7 +288,7 @@ struct GemmKernel
template
<
typename
PadView
>
CK_TILE_DEVICE
auto
m
ake
_g
emm
_t
ile
_w
indows
(
PadView
&
&
views
,
const
index_t
i_m
,
const
index_t
i_n
)
const
M
ake
G
emm
T
ile
W
indows
(
PadView
&
views
,
const
index_t
i_m
,
const
index_t
i_n
)
const
{
auto
a_pad_view
=
views
.
at
(
I0
);
auto
a_block_window
=
make_tile_window
(
...
...
@@ -299,18 +311,18 @@ struct GemmKernel
return
make_tuple
(
a_block_window
,
b_block_window
,
c_block_window
);
}
CK_TILE_DEVICE
void
r
un
_common_gemm_pipeline
(
const
ADataType
*
a_
start
,
const
BDataType
*
b_
start
,
CDataType
*
c_
start
,
const
Gemm
Common
Kargs
&
kargs
,
const
index_t
i
_m
,
const
index_t
i
_n
)
const
CK_TILE_DEVICE
void
R
un
Gemm
(
const
ADataType
*
a_
ptr
,
const
BDataType
*
b_
ptr
,
CDataType
*
c_
ptr
,
const
GemmKargs
&
kargs
,
const
index_t
block_idx
_m
,
const
index_t
block_idx
_n
)
const
{
// Convert pointers to tensor views
const
auto
gemm_tensor_views_tuple
=
make_gemm_tensor_views
(
a_start
,
b_start
,
c_start
,
kargs
);
const
auto
gemm_
pad_views
=
make_gemm_pad_views
(
gemm_tensor_views_tuple
);
const
auto
gemm_tile_windows
=
m
ake
_g
emm
_t
ile
_w
indows
(
gemm_pad_views
,
i_m
,
i
_n
);
const
auto
&
gemm_tensor_views_tuple
=
MakeGemmTensorViews
(
a_ptr
,
b_ptr
,
c_ptr
,
kargs
);
const
auto
&
gemm_pad_views
=
MakeGemmPadViews
(
gemm_tensor_views_tuple
);
const
auto
&
gemm_
tile_windows
=
M
ake
G
emm
T
ile
W
indows
(
gemm_pad_views
,
block_idx_m
,
block_idx
_n
);
// allocate LDS
__shared__
char
smem_ptr
[
GetSmemSize
()];
...
...
@@ -329,15 +341,15 @@ struct GemmKernel
EpiloguePipeline
{}(
c_block_window
,
c_block_tile
);
}
CK_TILE_DEVICE
void
operator
()(
Gemm
Common
Kargs
kargs
)
const
CK_TILE_DEVICE
void
operator
()(
GemmKargs
kargs
)
const
{
const
auto
[
i_m
,
i_n
]
=
TilePartitioner
{}();
// options
const
ADataType
*
a_
start
=
static_cast
<
const
ADataType
*>
(
kargs
.
a_ptr
);
const
BDataType
*
b_
start
=
static_cast
<
const
BDataType
*>
(
kargs
.
b_ptr
);
CDataType
*
c_
start
=
static_cast
<
CDataType
*>
(
kargs
.
c_ptr
);
const
ADataType
*
a_
ptr
=
static_cast
<
const
ADataType
*>
(
kargs
.
a_ptr
);
const
BDataType
*
b_
ptr
=
static_cast
<
const
BDataType
*>
(
kargs
.
b_ptr
);
CDataType
*
c_
ptr
=
static_cast
<
CDataType
*>
(
kargs
.
c_ptr
);
r
un
_common_gemm_pipeline
(
a_start
,
b_start
,
c_start
,
kargs
,
i_m
,
i_n
);
R
un
Gemm
(
a_ptr
,
b_ptr
,
c_ptr
,
kargs
,
i_m
,
i_n
);
}
};
...
...
test/ck_tile/batched_gemm/test_batched_gemm_util.hpp
View file @
799cde32
...
...
@@ -24,7 +24,7 @@ class TestCkTileBatchedGemm : public ::testing::Test
using
AccDataType
=
std
::
tuple_element_t
<
5
,
Tuple
>
;
using
CDataType
=
std
::
tuple_element_t
<
6
,
Tuple
>
;
struct
batched_gemm_kargs
:
public
ck_tile
::
BatchedGemmH
ostA
rgs
struct
batched_gemm_kargs
:
public
ck_tile
::
BatchedGemmH
a
rgs
{
};
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment