Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
f3e5a74e
Commit
f3e5a74e
authored
Dec 05, 2024
by
Aleksander Dudek
Browse files
Gemm Kernel Refactor part1
parent
feb9a2bd
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
46 additions
and
28 deletions
+46
-28
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
+46
-28
No files found.
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
View file @
f3e5a74e
...
@@ -28,6 +28,10 @@ struct GemmKernel
...
@@ -28,6 +28,10 @@ struct GemmKernel
// using CAccDataType = remove_cvref_t<typename GemmPipeline::CDataType>;
// using CAccDataType = remove_cvref_t<typename GemmPipeline::CDataType>;
using
CDataType
=
remove_cvref_t
<
typename
EpiloguePipeline
::
ODataType
>
;
using
CDataType
=
remove_cvref_t
<
typename
EpiloguePipeline
::
ODataType
>
;
static
constexpr
auto
I0
=
number
<
0
>
();
static
constexpr
auto
I1
=
number
<
1
>
();
static
constexpr
auto
I2
=
number
<
2
>
();
__host__
static
constexpr
auto
GridSize
(
index_t
M
,
index_t
N
,
index_t
KBatch
)
__host__
static
constexpr
auto
GridSize
(
index_t
M
,
index_t
N
,
index_t
KBatch
)
{
{
return
TilePartitioner
::
GridSize
(
M
,
N
,
KBatch
);
return
TilePartitioner
::
GridSize
(
M
,
N
,
KBatch
);
...
@@ -139,13 +143,11 @@ struct GemmKernel
...
@@ -139,13 +143,11 @@ struct GemmKernel
return
true
;
return
true
;
}
}
CK_TILE_DEVICE
void
operator
()(
GemmCommonKargs
kargs
)
const
CK_TILE_DEVICE
auto
make_gemm_tensor_views
(
const
ADataType
*
a_start
,
const
BDataType
*
b_start
,
CDataType
*
c_start
,
const
GemmCommonKargs
&
kargs
)
const
{
{
const
auto
[
i_m
,
i_n
]
=
TilePartitioner
{}();
// options
const
ADataType
*
a_start
=
static_cast
<
const
ADataType
*>
(
kargs
.
a_ptr
);
const
BDataType
*
b_start
=
static_cast
<
const
BDataType
*>
(
kargs
.
b_ptr
);
// Convert pointers to tensor views
auto
a_tensor_view
=
[
&
]()
{
auto
a_tensor_view
=
[
&
]()
{
if
constexpr
(
std
::
is_same_v
<
ALayout
,
tensor_layout
::
gemm
::
RowMajor
>
)
if
constexpr
(
std
::
is_same_v
<
ALayout
,
tensor_layout
::
gemm
::
RowMajor
>
)
{
{
...
@@ -188,7 +190,43 @@ struct GemmKernel
...
@@ -188,7 +190,43 @@ struct GemmKernel
}
}
}();
}();
auto
c_tensor_view
=
[
&
]()
{
if
constexpr
(
std
::
is_same_v
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>
)
{
return
make_naive_tensor_view
<
address_space_enum
::
global
>
(
c_start
,
make_tuple
(
kargs
.
M
,
kargs
.
N
),
make_tuple
(
kargs
.
stride_C
,
1
),
number
<
GemmPipeline
::
VectorSizeC
>
{},
number
<
1
>
{});
}
else
{
return
make_naive_tensor_view
<
address_space_enum
::
global
>
(
c_start
,
make_tuple
(
kargs
.
M
,
kargs
.
N
),
make_tuple
(
1
,
kargs
.
stride_C
),
number
<
1
>
{},
number
<
1
>
{});
}
}();
return
make_tuple
(
a_tensor_view
,
b_tensor_view
,
c_tensor_view
);
}
CK_TILE_DEVICE
void
operator
()(
GemmCommonKargs
kargs
)
const
{
const
auto
[
i_m
,
i_n
]
=
TilePartitioner
{}();
// options
const
ADataType
*
a_start
=
static_cast
<
const
ADataType
*>
(
kargs
.
a_ptr
);
const
BDataType
*
b_start
=
static_cast
<
const
BDataType
*>
(
kargs
.
b_ptr
);
CDataType
*
c_start
=
static_cast
<
CDataType
*>
(
kargs
.
c_ptr
);
// Convert pointers to tensor views
const
auto
gemm_tensor_views_tuple
=
make_gemm_tensor_views
(
a_start
,
b_start
,
c_start
,
kargs
);
auto
a_pad_view
=
[
&
]()
{
auto
a_pad_view
=
[
&
]()
{
auto
a_tensor_view
=
gemm_tensor_views_tuple
.
at
(
I0
);
if
constexpr
(
std
::
is_same_v
<
ALayout
,
tensor_layout
::
gemm
::
RowMajor
>
)
if
constexpr
(
std
::
is_same_v
<
ALayout
,
tensor_layout
::
gemm
::
RowMajor
>
)
{
{
return
pad_tensor_view
(
return
pad_tensor_view
(
...
@@ -212,6 +250,7 @@ struct GemmKernel
...
@@ -212,6 +250,7 @@ struct GemmKernel
{
i_m
,
0
});
{
i_m
,
0
});
auto
b_pad_view
=
[
&
]()
{
auto
b_pad_view
=
[
&
]()
{
auto
b_tensor_view
=
gemm_tensor_views_tuple
.
at
(
I1
);
if
constexpr
(
std
::
is_same_v
<
BLayout
,
tensor_layout
::
gemm
::
ColumnMajor
>
)
if
constexpr
(
std
::
is_same_v
<
BLayout
,
tensor_layout
::
gemm
::
ColumnMajor
>
)
{
{
return
pad_tensor_view
(
return
pad_tensor_view
(
...
@@ -242,29 +281,8 @@ struct GemmKernel
...
@@ -242,29 +281,8 @@ struct GemmKernel
auto
c_block_tile
=
auto
c_block_tile
=
GemmPipeline
{}.
template
operator
()(
a_block_window
,
b_block_window
,
num_loop
,
smem_ptr
);
GemmPipeline
{}.
template
operator
()(
a_block_window
,
b_block_window
,
num_loop
,
smem_ptr
);
CDataType
*
c_start
=
static_cast
<
CDataType
*>
(
kargs
.
c_ptr
);
auto
c_tensor_view
=
[
&
]()
{
if
constexpr
(
std
::
is_same_v
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>
)
{
return
make_naive_tensor_view
<
address_space_enum
::
global
>
(
c_start
,
make_tuple
(
kargs
.
M
,
kargs
.
N
),
make_tuple
(
kargs
.
stride_C
,
1
),
number
<
GemmPipeline
::
VectorSizeC
>
{},
number
<
1
>
{});
}
else
{
return
make_naive_tensor_view
<
address_space_enum
::
global
>
(
c_start
,
make_tuple
(
kargs
.
M
,
kargs
.
N
),
make_tuple
(
1
,
kargs
.
stride_C
),
number
<
1
>
{},
number
<
1
>
{});
}
}();
auto
c_pad_view
=
[
&
]()
{
auto
c_pad_view
=
[
&
]()
{
auto
c_tensor_view
=
gemm_tensor_views_tuple
.
at
(
I2
);
if
constexpr
(
std
::
is_same_v
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>
)
if
constexpr
(
std
::
is_same_v
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>
)
{
{
return
pad_tensor_view
(
return
pad_tensor_view
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment