Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
853a5183
Unverified
Commit
853a5183
authored
Jan 30, 2025
by
arai713
Committed by
GitHub
Jan 30, 2025
Browse files
Merge branch 'develop' into codegen_hiprtc
parents
5574422a
e6d41804
Changes
21
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
386 additions
and
341 deletions
+386
-341
Jenkinsfile
Jenkinsfile
+1
-1
docs/sphinx/requirements.in
docs/sphinx/requirements.in
+1
-1
docs/sphinx/requirements.txt
docs/sphinx/requirements.txt
+1
-1
example/ck_tile/03_gemm/gemm_basic.cpp
example/ck_tile/03_gemm/gemm_basic.cpp
+13
-24
example/ck_tile/03_gemm/universal_gemm.cpp
example/ck_tile/03_gemm/universal_gemm.cpp
+14
-4
example/ck_tile/16_batched_gemm/batched_gemm.cpp
example/ck_tile/16_batched_gemm/batched_gemm.cpp
+16
-26
example/ck_tile/17_grouped_gemm/grouped_gemm.cpp
example/ck_tile/17_grouped_gemm/grouped_gemm.cpp
+20
-26
include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp
include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp
+146
-151
include/ck_tile/ops/epilogue/default_2d_epilogue.hpp
include/ck_tile/ops/epilogue/default_2d_epilogue.hpp
+97
-4
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
+17
-32
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp
...ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp
+2
-0
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp
...tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp
+7
-10
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp
.../ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp
+6
-9
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp
...e/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp
+5
-5
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp
...line/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp
+1
-5
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp
...e/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp
+3
-3
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp
+3
-2
include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp
...gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp
+0
-6
test/ck_tile/batched_gemm/test_batched_gemm_util.hpp
test/ck_tile/batched_gemm/test_batched_gemm_util.hpp
+17
-26
test/ck_tile/gemm/test_gemm_pipeline_util.hpp
test/ck_tile/gemm/test_gemm_pipeline_util.hpp
+16
-5
No files found.
Jenkinsfile
View file @
853a5183
...
@@ -796,7 +796,7 @@ pipeline {
...
@@ -796,7 +796,7 @@ pipeline {
booleanParam
(
booleanParam
(
name:
"RUN_CK_TILE_GEMM_TESTS"
,
name:
"RUN_CK_TILE_GEMM_TESTS"
,
defaultValue:
false
,
defaultValue:
false
,
description:
"Run the ck_tile GEMM tests (default: O
FF
)"
)
description:
"Run the ck_tile GEMM tests (default: O
N
)"
)
booleanParam
(
booleanParam
(
name:
"BUILD_INSTANCES_ONLY"
,
name:
"BUILD_INSTANCES_ONLY"
,
defaultValue:
false
,
defaultValue:
false
,
...
...
docs/sphinx/requirements.in
View file @
853a5183
rocm-docs-core==1.1
4.1
rocm-docs-core==1.1
5.0
sphinxcontrib-bibtex==2.6.3
sphinxcontrib-bibtex==2.6.3
docs/sphinx/requirements.txt
View file @
853a5183
...
@@ -199,7 +199,7 @@ requests==2.32.3
...
@@ -199,7 +199,7 @@ requests==2.32.3
# via
# via
# pygithub
# pygithub
# sphinx
# sphinx
rocm-docs-core==1.1
4.1
rocm-docs-core==1.1
5.0
# via -r requirements.in
# via -r requirements.in
rpds-py==0.22.3
rpds-py==0.22.3
# via
# via
...
...
example/ck_tile/03_gemm/gemm_basic.cpp
View file @
853a5183
...
@@ -20,10 +20,6 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
...
@@ -20,10 +20,6 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
constexpr
bool
kPadN
=
false
;
constexpr
bool
kPadN
=
false
;
constexpr
bool
kPadK
=
false
;
constexpr
bool
kPadK
=
false
;
constexpr
bool
kTilePermute
=
false
;
// The rank and permutation will also be generate out by the CodeGen part.
constexpr
ck_tile
::
index_t
kOutputRank
=
2
;
constexpr
int
kBlockPerCu
=
1
;
constexpr
int
kBlockPerCu
=
1
;
// This part comes from the Codegen
// This part comes from the Codegen
...
@@ -39,11 +35,6 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
...
@@ -39,11 +35,6 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
constexpr
ck_tile
::
index_t
N_Warp_Tile
=
32
;
constexpr
ck_tile
::
index_t
N_Warp_Tile
=
32
;
constexpr
ck_tile
::
index_t
K_Warp_Tile
=
8
;
constexpr
ck_tile
::
index_t
K_Warp_Tile
=
8
;
// Whether doing the CShuffle (transpose before the global memory), depending on the output
// layout.
constexpr
bool
CShuffleEpilogue
=
std
::
is_same_v
<
CLayout
,
ck_tile
::
tensor_layout
::
gemm
::
ColumnMajor
>
;
using
CodegenGemmShape
=
using
CodegenGemmShape
=
ck_tile
::
TileGemmShape
<
ck_tile
::
sequence
<
M_Tile
,
N_Tile
,
K_Tile
>
,
ck_tile
::
TileGemmShape
<
ck_tile
::
sequence
<
M_Tile
,
N_Tile
,
K_Tile
>
,
ck_tile
::
sequence
<
M_Warp
,
N_Warp
,
K_Warp
>
,
ck_tile
::
sequence
<
M_Warp
,
N_Warp
,
K_Warp
>
,
...
@@ -51,26 +42,24 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
...
@@ -51,26 +42,24 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
using
TilePartitioner
=
ck_tile
::
GemmTile2DPartitioner
<
CodegenGemmShape
>
;
using
TilePartitioner
=
ck_tile
::
GemmTile2DPartitioner
<
CodegenGemmShape
>
;
using
GemmEpilogue
=
std
::
conditional_t
<
CShuffleEpilogue
,
ck_tile
::
CShuffleEpilogue
<
ck_tile
::
CShuffleEpilogueProblem
<
AccDataType
,
CDataType
,
kPadM
,
kPadN
,
kTilePermute
,
kOutputRank
,
1
,
0
,
TilePartitioner
::
MPerBlock
,
TilePartitioner
::
NPerBlock
>>
,
ck_tile
::
Default2DEpilogue
<
ck_tile
::
Default2DEpilogueProblem
<
AccDataType
,
CDataType
,
kPadM
,
kPadN
>>>
;
using
CodegenGemmTraits
=
using
CodegenGemmTraits
=
ck_tile
::
TileGemmTraits
<
kPadM
,
kPadN
,
kPadK
,
ALayout
,
BLayout
,
CLayout
>
;
ck_tile
::
TileGemmTraits
<
kPadM
,
kPadN
,
kPadK
,
ALayout
,
BLayout
,
CLayout
>
;
using
CodegenPipelineProblem
=
ck_tile
::
using
CodegenPipelineProblem
=
ck_tile
::
GemmPipelineProblem
<
ADataType
,
BDataType
,
AccDataType
,
CodegenGemmShape
,
CodegenGemmTraits
>
;
GemmPipelineProblem
<
ADataType
,
BDataType
,
AccDataType
,
CodegenGemmShape
,
CodegenGemmTraits
>
;
using
CodegenGemmPipeline
=
ck_tile
::
GemmPipelineAGmemBGmemCRegV1
<
CodegenPipelineProblem
>
;
using
CodegenGemmPipeline
=
ck_tile
::
GemmPipelineAGmemBGmemCRegV1
<
CodegenPipelineProblem
>
;
using
GemmEpilogue
=
ck_tile
::
CShuffleEpilogue
<
ck_tile
::
CShuffleEpilogueProblem
<
AccDataType
,
CDataType
,
CLayout
,
CodegenPipelineProblem
::
kBlockSize
,
TilePartitioner
::
MPerBlock
,
TilePartitioner
::
NPerBlock
,
M_Warp
,
N_Warp
,
M_Warp_Tile
,
N_Warp_Tile
,
K_Warp_Tile
,
CodegenPipelineProblem
::
TransposeC
>>
;
// ToDo: Will add the codegen part to test different pipeline policies in GEMM.
// ToDo: Will add the codegen part to test different pipeline policies in GEMM.
// Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy.
// Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy.
using
Kernel
=
ck_tile
::
GemmKernel
<
TilePartitioner
,
CodegenGemmPipeline
,
GemmEpilogue
>
;
using
Kernel
=
ck_tile
::
GemmKernel
<
TilePartitioner
,
CodegenGemmPipeline
,
GemmEpilogue
>
;
...
...
example/ck_tile/03_gemm/universal_gemm.cpp
View file @
853a5183
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024
-2025
, Advanced Micro Devices, Inc. All rights reserved.
#include <hip/hip_runtime.h>
#include <hip/hip_runtime.h>
...
@@ -60,9 +60,6 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
...
@@ -60,9 +60,6 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
ck_tile
::
sequence
<
M_Warp_Tile
,
N_Warp_Tile
,
K_Warp_Tile
>>
;
ck_tile
::
sequence
<
M_Warp_Tile
,
N_Warp_Tile
,
K_Warp_Tile
>>
;
using
TilePartitioner
=
ck_tile
::
GemmTile2DPartitioner
<
GemmShape
>
;
using
TilePartitioner
=
ck_tile
::
GemmTile2DPartitioner
<
GemmShape
>
;
using
GemmEpilogue
=
ck_tile
::
Default2DEpilogue
<
ck_tile
::
Default2DEpilogueProblem
<
AccDataType
,
CDataType
,
kPadM
,
kPadN
>>
;
using
Traits
=
ck_tile
::
TileGemmTraits
<
kPadM
,
kPadN
,
kPadK
,
ALayout
,
BLayout
,
CLayout
>
;
using
Traits
=
ck_tile
::
TileGemmTraits
<
kPadM
,
kPadN
,
kPadK
,
ALayout
,
BLayout
,
CLayout
>
;
using
GemmUniversalTraits
=
ck_tile
::
using
GemmUniversalTraits
=
ck_tile
::
TileGemmUniversalTraits
<
kPadM
,
kPadN
,
kPadK
,
ALayout
,
BLayout
,
CLayout
,
TransposeC
>
;
TileGemmUniversalTraits
<
kPadM
,
kPadN
,
kPadK
,
ALayout
,
BLayout
,
CLayout
,
TransposeC
>
;
...
@@ -95,6 +92,19 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
...
@@ -95,6 +92,19 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
using
GemmPipeline
=
using
GemmPipeline
=
GEMM_PIPELINE
<
UniversalGemmProblem
,
ck_tile
::
UniversalGemmPipelineAgBgCrPolicy
>
;
GEMM_PIPELINE
<
UniversalGemmProblem
,
ck_tile
::
UniversalGemmPipelineAgBgCrPolicy
>
;
using
GemmEpilogue
=
ck_tile
::
CShuffleEpilogue
<
ck_tile
::
CShuffleEpilogueProblem
<
AccDataType
,
CDataType
,
CLayout
,
GemmPipelineProblem
::
kBlockSize
,
TilePartitioner
::
MPerBlock
,
TilePartitioner
::
NPerBlock
,
M_Warp
,
N_Warp
,
M_Warp_Tile
,
N_Warp_Tile
,
K_Warp_Tile
,
UniversalGemmProblem
::
TransposeC
>>
;
using
Kernel
=
ck_tile
::
GemmKernel
<
TilePartitioner
,
GemmPipeline
,
GemmEpilogue
>
;
using
Kernel
=
ck_tile
::
GemmKernel
<
TilePartitioner
,
GemmPipeline
,
GemmEpilogue
>
;
auto
kargs
=
Kernel
::
MakeKernelArgs
(
args
);
auto
kargs
=
Kernel
::
MakeKernelArgs
(
args
);
...
...
example/ck_tile/16_batched_gemm/batched_gemm.cpp
View file @
853a5183
...
@@ -19,12 +19,9 @@ template <typename ALayout, typename BLayout, typename CLayout>
...
@@ -19,12 +19,9 @@ template <typename ALayout, typename BLayout, typename CLayout>
float
batched_gemm
(
const
ck_tile
::
BatchedGemmHostArgs
&
args
,
const
ck_tile
::
stream_config
&
s
)
float
batched_gemm
(
const
ck_tile
::
BatchedGemmHostArgs
&
args
,
const
ck_tile
::
stream_config
&
s
)
{
{
// The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part.
// The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part.
constexpr
bool
kPadM
=
false
;
constexpr
bool
kPadM
=
false
;
constexpr
bool
kPadN
=
false
;
constexpr
bool
kPadN
=
false
;
constexpr
bool
kPadK
=
false
;
constexpr
bool
kPadK
=
false
;
constexpr
bool
kTilePermute
=
false
;
// The rank and permutation will also be generate out by the CodeGen part.
constexpr
ck_tile
::
index_t
kOutputRank
=
2
;
constexpr
int
kBlockPerCu
=
1
;
constexpr
int
kBlockPerCu
=
1
;
...
@@ -41,11 +38,6 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre
...
@@ -41,11 +38,6 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre
constexpr
ck_tile
::
index_t
N_Warp_Tile
=
32
;
constexpr
ck_tile
::
index_t
N_Warp_Tile
=
32
;
constexpr
ck_tile
::
index_t
K_Warp_Tile
=
8
;
constexpr
ck_tile
::
index_t
K_Warp_Tile
=
8
;
// Whether doing the CShuffle (transpose before the global memory), depending on the output
// layout.
constexpr
bool
CShuffleEpilogue
=
std
::
is_same_v
<
CLayout
,
ck_tile
::
tensor_layout
::
gemm
::
ColumnMajor
>
;
using
CodegenGemmShape
=
using
CodegenGemmShape
=
ck_tile
::
TileGemmShape
<
ck_tile
::
sequence
<
M_Tile
,
N_Tile
,
K_Tile
>
,
ck_tile
::
TileGemmShape
<
ck_tile
::
sequence
<
M_Tile
,
N_Tile
,
K_Tile
>
,
ck_tile
::
sequence
<
M_Warp
,
N_Warp
,
K_Warp
>
,
ck_tile
::
sequence
<
M_Warp
,
N_Warp
,
K_Warp
>
,
...
@@ -53,26 +45,24 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre
...
@@ -53,26 +45,24 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre
using
TilePartitioner
=
ck_tile
::
GemmTile2DPartitioner
<
CodegenGemmShape
>
;
using
TilePartitioner
=
ck_tile
::
GemmTile2DPartitioner
<
CodegenGemmShape
>
;
using
GemmEpilogue
=
std
::
conditional_t
<
CShuffleEpilogue
,
ck_tile
::
CShuffleEpilogue
<
ck_tile
::
CShuffleEpilogueProblem
<
AccDataType
,
CDataType
,
kPadM
,
kPadN
,
kTilePermute
,
kOutputRank
,
1
,
0
,
TilePartitioner
::
MPerBlock
,
TilePartitioner
::
NPerBlock
>>
,
ck_tile
::
Default2DEpilogue
<
ck_tile
::
Default2DEpilogueProblem
<
AccDataType
,
CDataType
,
kPadM
,
kPadN
>>>
;
using
CodegenGemmTraits
=
using
CodegenGemmTraits
=
ck_tile
::
TileGemmTraits
<
kPadM
,
kPadN
,
kPadK
,
ALayout
,
BLayout
,
CLayout
>
;
ck_tile
::
TileGemmTraits
<
kPadM
,
kPadN
,
kPadK
,
ALayout
,
BLayout
,
CLayout
>
;
using
CodegenPipelineProblem
=
ck_tile
::
using
CodegenPipelineProblem
=
ck_tile
::
GemmPipelineProblem
<
ADataType
,
BDataType
,
AccDataType
,
CodegenGemmShape
,
CodegenGemmTraits
>
;
GemmPipelineProblem
<
ADataType
,
BDataType
,
AccDataType
,
CodegenGemmShape
,
CodegenGemmTraits
>
;
using
CodegenGemmPipeline
=
ck_tile
::
GemmPipelineAGmemBGmemCRegV1
<
CodegenPipelineProblem
>
;
using
CodegenGemmPipeline
=
ck_tile
::
GemmPipelineAGmemBGmemCRegV1
<
CodegenPipelineProblem
>
;
using
GemmEpilogue
=
ck_tile
::
CShuffleEpilogue
<
ck_tile
::
CShuffleEpilogueProblem
<
AccDataType
,
CDataType
,
CLayout
,
CodegenPipelineProblem
::
kBlockSize
,
TilePartitioner
::
MPerBlock
,
TilePartitioner
::
NPerBlock
,
M_Warp
,
N_Warp
,
M_Warp_Tile
,
N_Warp_Tile
,
K_Warp_Tile
,
CodegenPipelineProblem
::
TransposeC
>>
;
// ToDo: Will add the codegen part to test different pipeline policies in GEMM.
// ToDo: Will add the codegen part to test different pipeline policies in GEMM.
// Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy.
// Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy.
using
Kernel
=
ck_tile
::
BatchedGemmKernel
<
TilePartitioner
,
CodegenGemmPipeline
,
GemmEpilogue
>
;
using
Kernel
=
ck_tile
::
BatchedGemmKernel
<
TilePartitioner
,
CodegenGemmPipeline
,
GemmEpilogue
>
;
...
...
example/ck_tile/17_grouped_gemm/grouped_gemm.cpp
View file @
853a5183
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024
-2025
, Advanced Micro Devices, Inc. All rights reserved.
#include <hip/hip_runtime.h>
#include <hip/hip_runtime.h>
...
@@ -20,12 +20,9 @@ namespace {
...
@@ -20,12 +20,9 @@ namespace {
struct
GroupedGemmKernelParam
struct
GroupedGemmKernelParam
{
{
static
const
bool
kPadM
=
false
;
static
const
bool
kPadM
=
false
;
static
const
bool
kPadN
=
false
;
static
const
bool
kPadN
=
false
;
static
const
bool
kPadK
=
false
;
static
const
bool
kPadK
=
false
;
static
const
bool
kTilePermute
=
false
;
static
const
ck_tile
::
index_t
kOutputRank
=
2
;
static
const
int
kBlockPerCu
=
1
;
static
const
int
kBlockPerCu
=
1
;
static
const
ck_tile
::
index_t
M_Tile
=
128
;
static
const
ck_tile
::
index_t
M_Tile
=
128
;
...
@@ -54,24 +51,6 @@ using CodegenGemmShape =
...
@@ -54,24 +51,6 @@ using CodegenGemmShape =
using
TilePartitioner
=
ck_tile
::
GemmTile1DPartitioner
<
CodegenGemmShape
>
;
using
TilePartitioner
=
ck_tile
::
GemmTile1DPartitioner
<
CodegenGemmShape
>
;
template
<
typename
CLayout
>
using
GemmEpilogue
=
std
::
conditional_t
<
std
::
is_same_v
<
CLayout
,
ck_tile
::
tensor_layout
::
gemm
::
ColumnMajor
>
,
ck_tile
::
CShuffleEpilogue
<
ck_tile
::
CShuffleEpilogueProblem
<
AccDataType
,
CDataType
,
GroupedGemmKernelParam
::
kPadM
,
GroupedGemmKernelParam
::
kPadN
,
GroupedGemmKernelParam
::
kTilePermute
,
GroupedGemmKernelParam
::
kOutputRank
,
1
,
0
,
TilePartitioner
::
MPerBlock
,
TilePartitioner
::
NPerBlock
>>
,
ck_tile
::
Default2DEpilogue
<
ck_tile
::
Default2DEpilogueProblem
<
AccDataType
,
CDataType
,
GroupedGemmKernelParam
::
kPadM
,
GroupedGemmKernelParam
::
kPadN
>>>
;
template
<
typename
ALayout
,
typename
BLayout
,
typename
CLayout
>
template
<
typename
ALayout
,
typename
BLayout
,
typename
CLayout
>
using
CodegenGemmTraits
=
ck_tile
::
TileGemmTraits
<
GroupedGemmKernelParam
::
kPadM
,
using
CodegenGemmTraits
=
ck_tile
::
TileGemmTraits
<
GroupedGemmKernelParam
::
kPadM
,
GroupedGemmKernelParam
::
kPadN
,
GroupedGemmKernelParam
::
kPadN
,
...
@@ -92,10 +71,25 @@ template <typename ALayout, typename BLayout, typename CLayout>
...
@@ -92,10 +71,25 @@ template <typename ALayout, typename BLayout, typename CLayout>
using
CodegenGemmPipeline
=
using
CodegenGemmPipeline
=
ck_tile
::
GemmPipelineAGmemBGmemCRegV1
<
CodegenPipelineProblem
<
ALayout
,
BLayout
,
CLayout
>>
;
ck_tile
::
GemmPipelineAGmemBGmemCRegV1
<
CodegenPipelineProblem
<
ALayout
,
BLayout
,
CLayout
>>
;
template
<
typename
ALayout
,
typename
BLayout
,
typename
CLayout
>
using
GemmEpilogue
=
ck_tile
::
CShuffleEpilogue
<
ck_tile
::
CShuffleEpilogueProblem
<
AccDataType
,
CDataType
,
CLayout
,
CodegenPipelineProblem
<
ALayout
,
BLayout
,
CLayout
>::
kBlockSize
,
TilePartitioner
::
MPerBlock
,
TilePartitioner
::
NPerBlock
,
GroupedGemmKernelParam
::
M_Warp
,
GroupedGemmKernelParam
::
N_Warp
,
GroupedGemmKernelParam
::
M_Warp_Tile
,
GroupedGemmKernelParam
::
N_Warp_Tile
,
GroupedGemmKernelParam
::
K_Warp_Tile
,
CodegenPipelineProblem
<
ALayout
,
BLayout
,
CLayout
>::
TransposeC
>>
;
template
<
typename
ALayout
,
typename
BLayout
,
typename
CLayout
>
template
<
typename
ALayout
,
typename
BLayout
,
typename
CLayout
>
using
Kernel
=
ck_tile
::
GroupedGemmKernel
<
TilePartitioner
,
using
Kernel
=
ck_tile
::
GroupedGemmKernel
<
TilePartitioner
,
CodegenGemmPipeline
<
ALayout
,
BLayout
,
CLayout
>
,
CodegenGemmPipeline
<
ALayout
,
BLayout
,
CLayout
>
,
GemmEpilogue
<
CLayout
>>
;
GemmEpilogue
<
ALayout
,
BLayout
,
CLayout
>>
;
};
// namespace
};
// namespace
std
::
size_t
get_workspace_size
(
const
std
::
vector
<
grouped_gemm_kargs
>&
gemm_descs
)
std
::
size_t
get_workspace_size
(
const
std
::
vector
<
grouped_gemm_kargs
>&
gemm_descs
)
...
...
include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp
View file @
853a5183
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
#
define CK_TILE_MAX_RANK 5
#
include "ck_tile/ops/common/tensor_layout.hpp"
namespace
ck_tile
{
namespace
ck_tile
{
// this epilogue aiming to store a matrix with different layout from the shared memory to the global
// memory.
template
<
typename
AccDataType_
,
template
<
typename
AccDataType_
,
typename
ODataType_
,
typename
ODataType_
,
bool
kPadM_
,
typename
CLayout_
,
bool
kPadN_
,
index_t
kBlockSize_
,
bool
kTilePermute_
,
index_t
kM_
,
index_t
kRank_
,
index_t
kN_
,
index_t
kPerm0
,
index_t
kMWave_
,
index_t
kPerm1
,
index_t
kNWave_
,
index_t
TileSize0
,
index_t
kMPerXdl_
,
index_t
TileSize1
,
index_t
kNPerXdl_
,
index_t
kPerm2
=
0
,
index_t
kKPerXdl_
,
index_t
kPerm3
=
0
,
bool
isCTransposed_
>
index_t
kPerm4
=
0
,
index_t
TileSize2
=
0
,
index_t
TileSize3
=
0
,
index_t
TileSize4
=
0
>
struct
CShuffleEpilogueProblem
struct
CShuffleEpilogueProblem
{
{
using
AccDataType
=
remove_cvref_t
<
AccDataType_
>
;
using
AccDataType
=
remove_cvref_t
<
AccDataType_
>
;
using
ODataType
=
remove_cvref_t
<
ODataType_
>
;
using
ODataType
=
remove_cvref_t
<
ODataType_
>
;
static
constexpr
bool
kPadM
=
kPadM_
;
using
CLayout
=
remove_cvref_t
<
CLayout_
>
;
static
constexpr
bool
kPadN
=
kPadN_
;
static
constexpr
index_t
kBlockSize
=
kBlockSize_
;
static
constexpr
bool
kTilePermute
=
kTilePermute_
;
static
constexpr
index_t
kMPerBlock
=
kM_
;
static
constexpr
index_t
kRank
=
kRank_
;
static
constexpr
index_t
kNPerBlock
=
kN_
;
static
constexpr
index_t
kPerm
[
CK_TILE_MAX_RANK
]
=
{
kPerm0
,
kPerm1
,
kPerm2
,
kPerm3
,
kPerm4
};
static
constexpr
index_t
kMWave
=
kMWave_
;
static
constexpr
index_t
tile_sizes
[
CK_TILE_MAX_RANK
]
=
{
static
constexpr
index_t
kNWave
=
kNWave_
;
TileSize0
,
TileSize1
,
TileSize2
,
TileSize3
,
TileSize4
};
static
constexpr
index_t
kMPerXdl
=
kMPerXdl_
;
static
constexpr
index_t
kNPerXdl
=
kNPerXdl_
;
static
constexpr
index_t
kKPerXdl
=
kKPerXdl_
;
static
constexpr
index_t
isCTransposed
=
isCTransposed_
;
};
};
template
<
typename
Problem_
,
typename
Policy_
=
void
>
template
<
typename
Problem_
,
typename
Policy_
=
void
>
struct
CShuffleEpilogue
struct
CShuffleEpilogue
{
{
using
Problem
=
remove_cvref_t
<
Problem_
>
;
using
Problem
=
remove_cvref_t
<
Problem_
>
;
using
AccDataType
=
remove_cvref_t
<
typename
Problem
::
AccDataType
>
;
using
AccDataType
=
remove_cvref_t
<
typename
Problem
::
AccDataType
>
;
using
ODataType
=
remove_cvref_t
<
typename
Problem
::
ODataType
>
;
using
ODataType
=
remove_cvref_t
<
typename
Problem
::
ODataType
>
;
static
constexpr
bool
kPadM
=
Problem
::
kPadM
;
using
CLayout
=
remove_cvref_t
<
typename
Problem
::
CLayout
>
;
static
constexpr
bool
kPadN
=
Problem
::
kPadN
;
static
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
const
index_t
*
kPerm
=
Problem
::
kPerm
;
static
constexpr
index_t
kMPerBlock
=
Problem
::
kMPerBlock
;
static
constexpr
bool
kTilePermute
=
Problem
::
kTilePermute
;
static
constexpr
index_t
kNPerBlock
=
Problem
::
kNPerBlock
;
static
constexpr
index_t
kRank
=
Problem
::
kRank
;
static
constexpr
index_t
kMWave
=
Problem
::
kMWave
;
const
index_t
*
tile_sizes
=
Problem
::
tile_sizes
;
static
constexpr
index_t
kNWave
=
Problem
::
kNWave
;
static
constexpr
index_t
kMPerXdl
=
Problem
::
kMPerXdl
;
// No additional shared memory needed
static
constexpr
index_t
kNPerXdl
=
Problem
::
kNPerXdl
;
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSize
()
{
return
0
;
}
static
constexpr
index_t
kKPerXdl
=
Problem
::
kKPerXdl
;
static
constexpr
index_t
isCTransposed
=
Problem
::
isCTransposed
;
CK_TILE_HOST_DEVICE
static
constexpr
bool
IsOutputTransposed
()
static
constexpr
index_t
kMPerIteration
=
kMPerXdl
*
kMWave
;
static
constexpr
index_t
kNPerIteration
=
kNPerXdl
*
kNWave
;
using
WG
=
WarpGemmMfmaDispatcher
<
ODataType
,
ODataType
,
AccDataType
,
kMPerXdl
,
kNPerXdl
,
kKPerXdl
,
isCTransposed
>
;
using
CWarpDstr
=
typename
WG
::
CWarpDstr
;
using
CWarpTensor
=
typename
WG
::
CWarpTensor
;
/**
* @brief Get the vector store size for C tensor.
*
* @note The vector store size for output C tensor would depend on multiple factors
* like its data layout and warp gemm C transposition. In general it would
* be the number of consecutive elements in contiguous C dimension hold by
* single thread.
*
* @return The vector store size for C tensor.
*/
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetVectorSizeC
()
{
{
// TODO: At now CShuffle doesn't allow to vector store after permute.
constexpr
index_t
MaxVectorStoreSize
=
16
;
// It should be fixed and this function should return true.
return
MaxVectorStoreSize
/
sizeof
(
ODataType
);
return
false
;
}
}
template
<
typename
OAccTi
le
>
template
<
typename
Prob
le
m
>
CK_TILE_DEVICE
void
permute_tile_data
(
OAccTile
&
o_acc_tile
)
CK_TILE_
HOST_
DEVICE
static
constexpr
auto
MakeLdsBlockDescriptor
(
)
{
{
using
DataType
=
typename
OAccTile
::
DataType
;
// N is contiguous dimension
if
constexpr
(
std
::
is_same_v
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>
)
// Get thread buffer
auto
&
thread_buf
=
o_acc_tile
.
get_thread_buffer
();
// Create a temporary buffer to hold the permuted data
thread_buffer
<
DataType
,
OAccTile
::
kThreadElementSpaceSize
>
permuted_thread_buf
;
// Get the lengths of each dimension
auto
thread_tensor_lengths
=
o_acc_tile
.
get_lengths
();
// Total number of elements
index_t
total_elements
=
OAccTile
::
kThreadElementSpaceSize
;
// Iterate over all elements
for
(
index_t
linear_idx
=
0
;
linear_idx
<
total_elements
;
++
linear_idx
)
{
{
// Convert linear index to multi-dimensional indices
return
make_naive_tensor_descriptor
(
array
<
index_t
,
kRank
>
indices
;
make_tuple
(
number
<
kMWave
*
kMPerXdl
>
{},
number
<
kNWave
*
kNPerXdl
>
{}),
index_t
remaining
=
linear_idx
;
make_tuple
(
number
<
kNWave
*
kNPerXdl
>
{},
number
<
1
>
{}));
static_for
<
0
,
kRank
,
1
>
{}([
&
](
auto
i
)
{
constexpr
auto
rev_i
=
kRank
-
1
-
i
;
indices
(
rev_i
)
=
remaining
%
thread_tensor_lengths
.
get
(
number
<
rev_i
>
{});
remaining
/=
thread_tensor_lengths
.
get
(
number
<
rev_i
>
{});
});
// Apply the permutation
array
<
index_t
,
kRank
>
permuted_indices
;
static_for
<
0
,
kRank
,
1
>
{}(
[
&
](
auto
i
)
{
permuted_indices
(
i
)
=
indices
.
get
(
number
<
Problem
::
kPerm
[
i
]
>
{});
});
// Compute offsets
index_t
dst_offset
=
0
;
index_t
stride
=
1
;
static_for
<
0
,
kRank
,
1
>
{}([
&
](
auto
i
)
{
constexpr
auto
rev_i
=
kRank
-
1
-
i
;
dst_offset
+=
permuted_indices
[
rev_i
]
*
stride
;
stride
*=
thread_tensor_lengths
.
get
(
number
<
rev_i
>
{});
});
// Move the data
permuted_thread_buf
(
dst_offset
)
=
thread_buf
[
linear_idx
];
}
}
// M is contiguous dimension
// Copy the permuted data back to the original thread buffer
else
if
constexpr
(
std
::
is_same_v
<
CLayout
,
tensor_layout
::
gemm
::
ColumnMajor
>
)
for
(
index_t
i
=
0
;
i
<
total_elements
;
++
i
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
number
<
kMWave
*
kMPerXdl
>
{},
number
<
kNWave
*
kNPerXdl
>
{}),
make_tuple
(
number
<
1
>
{},
number
<
kMWave
*
kMPerXdl
>
{}));
}
else
{
{
thread_buf
.
set_as
(
i
,
permuted_thread_buf
.
get
(
i
)
);
static_assert
(
false
,
"Unsupported CLayout!"
);
}
}
}
}
template
<
typename
ODramWindowTmp
,
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSize
()
{
return
kMWave
*
kNWave
*
kMPerXdl
*
kNPerXdl
*
sizeof
(
ODataType
);
}
template
<
typename
ODramWindow
,
typename
OAccTile
,
typename
OAccTile
,
memory_operation_enum
out_memory_data_op
=
memory_operation_enum
::
set
>
memory_operation_enum
out_memory_data_op
=
memory_operation_enum
::
set
>
CK_TILE_DEVICE
auto
operator
()(
ODramWindowTmp
&
o_dram_window_tmp
,
OAccTile
&
o_acc_tile
)
CK_TILE_DEVICE
auto
operator
()(
ODramWindow
&
out_dram_window
,
const
OAccTile
&
o_acc_tile
,
void
*
p_smem
)
{
{
const
auto
&
current_window_origin
=
o_dram_window_tmp
.
get_window_origin
();
// Compute the tile coordinates by dividing the window origin by the tile sizes
index_t
tile_coords
[
CK_TILE_MAX_RANK
]
=
{
0
};
for
(
index_t
i
=
0
;
i
<
kRank
;
++
i
)
{
tile_coords
[
i
]
=
current_window_origin
[
i
]
/
tile_sizes
[
i
];
// printf("The tile_coord is: %d", tile_coords[i]);
}
// Apply the permutation to the tile coordinates
index_t
permuted_tile_coords
[
CK_TILE_MAX_RANK
];
for
(
index_t
i
=
0
;
i
<
kRank
;
++
i
)
{
permuted_tile_coords
[
i
]
=
tile_coords
[
kPerm
[
i
]];
// printf("The new permuted_tile_coords is: %d", permuted_tile_coords[i]);
}
// Compute the permuted window origin
const
index_t
iMWarp
=
get_warp_id
()
/
kNWave
;
index_t
permuted_window_origin
[
CK_TILE_MAX_RANK
]
=
{
0
};
const
index_t
iNWarp
=
get_warp_id
()
-
iMWarp
*
kNWave
;
for
(
index_t
i
=
0
;
i
<
kRank
;
++
i
)
{
constexpr
auto
lds_block_desc
=
MakeLdsBlockDescriptor
<
Problem
>
();
permuted_window_origin
[
i
]
=
permuted_tile_coords
[
i
]
*
tile_sizes
[
i
];
auto
o_lds_block
=
make_tensor_view
<
address_space_enum
::
lds
>
(
// printf("The new permuted_window_origin is: %d", permuted_window_origin[i]);
static_cast
<
ODataType
*>
(
p_smem
),
lds_block_desc
);
}
auto
in_lds_window
=
make_tile_window
(
o_lds_block
,
typename
ODramWindowTmp
::
BottomTensorIndex
step
=
{};
make_tuple
(
number
<
kMPerXdl
>
{},
number
<
kNPerXdl
>
{}),
for
(
index_t
i
=
0
;
i
<
kRank
;
++
i
)
{
number
<
kMPerXdl
>
{}
*
iMWarp
,
number
<
kNPerXdl
>
{}
*
iNWarp
});
{
auto
out_lds_window
=
step
[
i
]
=
permuted_window_origin
[
i
]
-
current_window_origin
[
i
];
make_tile_window
(
o_lds_block
,
}
make_tuple
(
number
<
kMWave
*
kMPerXdl
>
{},
number
<
kNWave
*
kNPerXdl
>
{}),
{
0
,
0
});
using
SFC
=
space_filling_curve
<
sequence
<
kMPerBlock
,
kNPerBlock
>
,
sequence
<
0
,
1
>
,
sequence
<
kMPerXdl
*
kMWave
,
kNPerXdl
*
kNWave
>>
;
constexpr
index_t
num_access
=
SFC
::
get_num_of_access
();
using
TileEncodingPattern
=
TileDistributionEncodingPattern2D
<
kBlockSize
,
kMPerIteration
,
kNPerIteration
,
GetVectorSizeC
(),
tile_distribution_pattern
::
thread_raked
>
;
constexpr
auto
dram_tile_distribution
=
TileEncodingPattern
::
Make2DStaticTileDistribution
();
constexpr
auto
c_warp_y_lengths
=
to_sequence
(
CWarpDstr
{}.
get_ys_to_d_descriptor
().
get_lengths
());
constexpr
auto
c_warp_y_index_zeros
=
uniform_sequence_gen_t
<
CWarpDstr
::
NDimY
,
0
>
{};
CWarpTensor
c_warp_in_tensor
;
static_for
<
0
,
num_access
,
1
>
{}([
&
](
auto
iAccess
)
{
constexpr
auto
idx_y_start
=
SFC
::
get_index
(
iAccess
);
constexpr
auto
mIter
=
number
<
idx_y_start
.
at
(
number
<
0
>
{})
/
(
kMPerXdl
*
kMWave
)
>
{};
constexpr
auto
nIter
=
number
<
idx_y_start
.
at
(
number
<
1
>
{})
/
(
kNPerXdl
*
kNWave
)
>
{};
c_warp_in_tensor
.
get_thread_buffer
()
=
o_acc_tile
.
get_y_sliced_thread_data
(
merge_sequences
(
sequence
<
mIter
,
nIter
>
{},
c_warp_y_index_zeros
),
merge_sequences
(
sequence
<
1
,
1
>
{},
c_warp_y_lengths
));
const
auto
c_warp_in_tensor_casted
=
cast_tile
<
ODataType
>
(
c_warp_in_tensor
);
block_sync_lds
();
store_tile
(
in_lds_window
,
c_warp_in_tensor_casted
);
block_sync_lds
();
const
auto
c_out_tensor
=
load_tile
(
make_tile_window
(
out_lds_window
,
dram_tile_distribution
));
// Move the window
move_tile_window
(
o_dram_window_tmp
,
step
);
// Permute the data within the tile if necessary
if
constexpr
(
kTilePermute
)
{
permute_tile_data
(
o_acc_tile
);
}
// Store the tile data to the permuted location
if
constexpr
(
kPadM
||
kPadN
)
{
if
constexpr
(
out_memory_data_op
==
memory_operation_enum
::
set
)
if
constexpr
(
out_memory_data_op
==
memory_operation_enum
::
set
)
{
{
store_tile
_raw
(
o_dram_window
_tmp
,
cast_tile
<
ODataType
>
(
o_acc_tile
)
);
store_tile
(
o
ut
_dram_window
,
c_out_tensor
);
}
}
else
else
{
{
update_tile
_raw
(
o_dram_window
_tmp
,
cast_tile
<
ODataType
>
(
o_acc_tile
)
);
update_tile
(
o
ut
_dram_window
,
c_out_tensor
);
}
}
buffer_store_fence
();
if
constexpr
(
iAccess
!=
num_access
-
1
)
}
else
{
if
constexpr
(
out_memory_data_op
==
memory_operation_enum
::
set
)
{
{
store_tile
(
o_dram_window_tmp
,
cast_tile
<
ODataType
>
(
o_acc_tile
));
constexpr
auto
step
=
SFC
::
get_forward_step
(
iAccess
);
move_tile_window
(
out_dram_window
,
{
step
.
at
(
number
<
0
>
{}),
step
.
at
(
number
<
1
>
{})});
}
}
else
});
{
update_tile
(
o_dram_window_tmp
,
cast_tile
<
ODataType
>
(
o_acc_tile
));
}
}
}
}
};
};
}
// namespace ck_tile
}
// namespace ck_tile
include/ck_tile/ops/epilogue/default_2d_epilogue.hpp
View file @
853a5183
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
namespace
ck_tile
{
namespace
ck_tile
{
...
@@ -23,6 +25,26 @@ struct Default2DEpilogueProblem
...
@@ -23,6 +25,26 @@ struct Default2DEpilogueProblem
static
constexpr
bool
UseRawStore
=
UseRawStore_
;
static
constexpr
bool
UseRawStore
=
UseRawStore_
;
};
};
template
<
typename
AccDataType_
,
typename
ODataType_
,
typename
CLayout_
,
bool
kPadM_
,
bool
kPadN_
,
index_t
kMPerXdl_
,
index_t
kNPerXdl_
,
index_t
kKPerXdl_
,
bool
isCTransposed_
,
bool
UseRawStore_
=
true
>
struct
DefaultGemm2DEpilogueProblem
:
public
Default2DEpilogueProblem
<
AccDataType_
,
ODataType_
,
kPadM_
,
kPadN_
,
UseRawStore_
>
{
using
CLayout
=
remove_cvref_t
<
CLayout_
>
;
static
constexpr
index_t
kMPerXdl
=
kMPerXdl_
;
static
constexpr
index_t
kNPerXdl
=
kNPerXdl_
;
static
constexpr
index_t
kKPerXdl
=
kKPerXdl_
;
static
constexpr
index_t
isCTransposed
=
isCTransposed_
;
};
template
<
typename
Problem_
,
typename
Policy_
=
void
>
template
<
typename
Problem_
,
typename
Policy_
=
void
>
struct
Default2DEpilogue
struct
Default2DEpilogue
{
{
...
@@ -35,14 +57,13 @@ struct Default2DEpilogue
...
@@ -35,14 +57,13 @@ struct Default2DEpilogue
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSize
()
{
return
0
;
}
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSize
()
{
return
0
;
}
CK_TILE_HOST_DEVICE
static
constexpr
bool
IsOutputTransposed
()
{
return
false
;
}
// TODO: this function assume store out vector size is the same as OAccTile last dimension size
// TODO: this function assume store out vector size is the same as OAccTile last dimension size
// how do we fix this ?
// how do we fix this ?
template
<
typename
ODramWindowTmp
,
template
<
typename
ODramWindowTmp
,
typename
OAccTile
,
typename
OAccTile
,
memory_operation_enum
out_memory_data_op
=
memory_operation_enum
::
set
>
memory_operation_enum
out_memory_data_op
=
memory_operation_enum
::
set
>
CK_TILE_DEVICE
auto
operator
()(
ODramWindowTmp
&
o_dram_window_tmp
,
const
OAccTile
&
o_acc_tile
)
CK_TILE_DEVICE
auto
operator
()(
ODramWindowTmp
&
o_dram_window_tmp
,
const
OAccTile
&
o_acc_tile
,
void
*
=
nullptr
)
{
{
// TODO: this is ugly
// TODO: this is ugly
...
@@ -71,4 +92,76 @@ struct Default2DEpilogue
...
@@ -71,4 +92,76 @@ struct Default2DEpilogue
}
}
}
}
};
};
template
<
typename
Problem_
,
typename
Policy_
=
void
>
struct
DefaultGemm2DEpilogue
:
public
Default2DEpilogue
<
Problem_
,
Policy_
>
{
using
Problem
=
remove_cvref_t
<
Problem_
>
;
using
AccDataType
=
remove_cvref_t
<
typename
Problem
::
AccDataType
>
;
using
ODataType
=
remove_cvref_t
<
typename
Problem
::
ODataType
>
;
using
CLayout
=
remove_cvref_t
<
typename
Problem
::
CLayout
>
;
static
constexpr
index_t
kMPerXdl
=
Problem
::
kMPerXdl
;
static
constexpr
index_t
kNPerXdl
=
Problem
::
kNPerXdl
;
static
constexpr
index_t
kKPerXdl
=
Problem
::
kKPerXdl
;
static
constexpr
index_t
isCTransposed
=
Problem
::
isCTransposed
;
using
WG
=
WarpGemmMfmaDispatcher
<
ODataType
,
ODataType
,
AccDataType
,
kMPerXdl
,
kNPerXdl
,
kKPerXdl
,
isCTransposed
>
;
using
CWarpDstr
=
typename
WG
::
CWarpDstr
;
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetVectorSizeC
()
{
// N is contiguous dimension
if
constexpr
(
std
::
is_same_v
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>
)
{
if
constexpr
(
isCTransposed
)
{
// In this case each thread has multiple consecutive elements in
// N dimension, however consecutive threads' elements have stride.
constexpr
index_t
NDimY
=
CWarpDstr
::
NDimY
;
constexpr
auto
c_warp_y_lengths
=
CWarpDstr
{}.
get_ys_to_d_descriptor
().
get_lengths
();
static_assert
(
WG
::
WarpGemmAttribute
::
Impl
::
kCM1PerLane
==
c_warp_y_lengths
.
get
(
number
<
NDimY
-
1
>
{}));
return
c_warp_y_lengths
.
get
(
number
<
NDimY
-
1
>
{});
}
else
{
// In this case each thread has just a single item in Ndim
return
WG
::
WarpGemmAttribute
::
Impl
::
kCNLane
/
WG
::
kN
;
}
}
// M is contiguous dimension
else
if
constexpr
(
std
::
is_same_v
<
CLayout
,
tensor_layout
::
gemm
::
ColumnMajor
>
)
{
if
constexpr
(
isCTransposed
)
{
// In this case each thread has just a single item in Mdim
return
WG
::
WarpGemmAttribute
::
Impl
::
kCNLane
/
WG
::
kN
;
}
else
{
// In this case each thread has multiple consecutive elements in
// M dimension, however consecutive threads' elements have stride.
constexpr
index_t
NDimY
=
CWarpDstr
::
NDimY
;
constexpr
auto
c_warp_y_lengths
=
CWarpDstr
{}.
get_ys_to_d_descriptor
().
get_lengths
();
static_assert
(
WG
::
WarpGemmAttribute
::
Impl
::
kCM1PerLane
==
c_warp_y_lengths
.
get
(
number
<
NDimY
-
1
>
{}));
return
c_warp_y_lengths
.
get
(
number
<
NDimY
-
1
>
{});
}
}
else
{
static_assert
(
false
,
"Unsupported CLayout!"
);
}
}
};
}
// namespace ck_tile
}
// namespace ck_tile
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
View file @
853a5183
...
@@ -159,12 +159,8 @@ struct GemmKernel
...
@@ -159,12 +159,8 @@ struct GemmKernel
CK_TILE_HOST
static
bool
IsSupportedArgument
(
const
GemmKernelArgs
&
kargs
)
CK_TILE_HOST
static
bool
IsSupportedArgument
(
const
GemmKernelArgs
&
kargs
)
{
{
constexpr
bool
is_output_c_reg_transposed
=
if
constexpr
(
EpiloguePipeline
::
GetVectorSizeC
()
%
2
!=
0
&&
EpiloguePipeline
::
IsOutputTransposed
()
!=
GemmPipeline
::
IsTransposeC
();
is_any_of
<
CDataType
,
fp16_t
,
bf16_t
>::
value
)
if
constexpr
(
!
((
GemmPipeline
::
VectorSizeC
%
2
==
0
&&
std
::
is_same_v
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>
&&
is_output_c_reg_transposed
)
||
!
(
std
::
is_same_v
<
CDataType
,
fp16_t
>
||
std
::
is_same_v
<
CDataType
,
bf16_t
>
)))
{
{
if
(
kargs
.
KBatch
!=
1
)
if
(
kargs
.
KBatch
!=
1
)
{
{
...
@@ -182,7 +178,7 @@ struct GemmKernel
...
@@ -182,7 +178,7 @@ struct GemmKernel
<<
std
::
endl
;
<<
std
::
endl
;
return
false
;
return
false
;
}
}
if
(
kargs
.
K
%
GemmPipeline
::
VectorSizeA
!=
0
)
if
(
kargs
.
K
%
GemmPipeline
::
Get
VectorSizeA
()
!=
0
)
{
{
std
::
cerr
<<
"K is not a multiple of vector load size for A tensor!"
<<
std
::
endl
;
std
::
cerr
<<
"K is not a multiple of vector load size for A tensor!"
<<
std
::
endl
;
return
false
;
return
false
;
...
@@ -197,7 +193,7 @@ struct GemmKernel
...
@@ -197,7 +193,7 @@ struct GemmKernel
<<
std
::
endl
;
<<
std
::
endl
;
return
false
;
return
false
;
}
}
if
(
kargs
.
M
%
GemmPipeline
::
VectorSizeA
!=
0
)
if
(
kargs
.
M
%
GemmPipeline
::
Get
VectorSizeA
()
!=
0
)
{
{
std
::
cerr
<<
"M is not a multiple of vector load size for A tensor!"
<<
std
::
endl
;
std
::
cerr
<<
"M is not a multiple of vector load size for A tensor!"
<<
std
::
endl
;
return
false
;
return
false
;
...
@@ -213,7 +209,7 @@ struct GemmKernel
...
@@ -213,7 +209,7 @@ struct GemmKernel
<<
std
::
endl
;
<<
std
::
endl
;
return
false
;
return
false
;
}
}
if
(
kargs
.
N
%
GemmPipeline
::
VectorSizeB
!=
0
)
if
(
kargs
.
N
%
GemmPipeline
::
Get
VectorSizeB
()
!=
0
)
{
{
std
::
cerr
<<
"N is not a multiple of vector load size for B tensor!"
<<
std
::
endl
;
std
::
cerr
<<
"N is not a multiple of vector load size for B tensor!"
<<
std
::
endl
;
return
false
;
return
false
;
...
@@ -228,7 +224,7 @@ struct GemmKernel
...
@@ -228,7 +224,7 @@ struct GemmKernel
<<
std
::
endl
;
<<
std
::
endl
;
return
false
;
return
false
;
}
}
if
(
kargs
.
K
%
GemmPipeline
::
VectorSizeB
!=
0
)
if
(
kargs
.
K
%
GemmPipeline
::
Get
VectorSizeB
()
!=
0
)
{
{
std
::
cerr
<<
"K is not a multiple of vector load size for B tensor!"
<<
std
::
endl
;
std
::
cerr
<<
"K is not a multiple of vector load size for B tensor!"
<<
std
::
endl
;
return
false
;
return
false
;
...
@@ -244,7 +240,7 @@ struct GemmKernel
...
@@ -244,7 +240,7 @@ struct GemmKernel
<<
std
::
endl
;
<<
std
::
endl
;
return
false
;
return
false
;
}
}
if
(
kargs
.
N
%
Gemm
Pipeline
::
VectorSizeC
!=
0
)
if
(
kargs
.
N
%
Epilogue
Pipeline
::
Get
VectorSizeC
()
!=
0
)
{
{
std
::
cerr
<<
"N is not a multiple of vector load size for C tensor!"
<<
std
::
endl
;
std
::
cerr
<<
"N is not a multiple of vector load size for C tensor!"
<<
std
::
endl
;
return
false
;
return
false
;
...
@@ -259,7 +255,7 @@ struct GemmKernel
...
@@ -259,7 +255,7 @@ struct GemmKernel
<<
std
::
endl
;
<<
std
::
endl
;
return
false
;
return
false
;
}
}
if
(
kargs
.
M
%
Gemm
Pipeline
::
VectorSizeC
!=
0
)
if
(
kargs
.
M
%
Epilogue
Pipeline
::
Get
VectorSizeC
()
!=
0
)
{
{
std
::
cerr
<<
"M is not a multiple of vector load size for C tensor!"
<<
std
::
endl
;
std
::
cerr
<<
"M is not a multiple of vector load size for C tensor!"
<<
std
::
endl
;
return
false
;
return
false
;
...
@@ -275,14 +271,6 @@ struct GemmKernel
...
@@ -275,14 +271,6 @@ struct GemmKernel
const
GemmKernelArgs
&
kargs
,
const
GemmKernelArgs
&
kargs
,
const
SplitKBatchOffset
&
splitk_batch_offset
)
const
SplitKBatchOffset
&
splitk_batch_offset
)
{
{
// const auto idxs = TilePartitioner{}();
// const auto i_m = idxs.at(number<0>{});
// const auto i_n = idxs.at(number<1>{});
// // options
// const ADataType* a_start = static_cast<const ADataType*>(kargs.a_ptr);
// const BDataType* b_start = static_cast<const BDataType*>(kargs.b_ptr);
// // Convert pointers to tensor views
// auto a_tensor_view = [&]() {
const
auto
&
a_tensor_view
=
[
&
]()
{
const
auto
&
a_tensor_view
=
[
&
]()
{
if
constexpr
(
std
::
is_same_v
<
ALayout
,
tensor_layout
::
gemm
::
RowMajor
>
)
if
constexpr
(
std
::
is_same_v
<
ALayout
,
tensor_layout
::
gemm
::
RowMajor
>
)
{
{
...
@@ -290,7 +278,7 @@ struct GemmKernel
...
@@ -290,7 +278,7 @@ struct GemmKernel
a_ptr
,
a_ptr
,
make_tuple
(
kargs
.
M
,
splitk_batch_offset
.
splitted_k
),
make_tuple
(
kargs
.
M
,
splitk_batch_offset
.
splitted_k
),
make_tuple
(
kargs
.
stride_A
,
1
),
make_tuple
(
kargs
.
stride_A
,
1
),
number
<
GemmPipeline
::
VectorSizeA
>
{},
number
<
GemmPipeline
::
Get
VectorSizeA
()
>
{},
number
<
1
>
{});
number
<
1
>
{});
}
}
else
else
...
@@ -299,7 +287,7 @@ struct GemmKernel
...
@@ -299,7 +287,7 @@ struct GemmKernel
a_ptr
,
a_ptr
,
make_tuple
(
splitk_batch_offset
.
splitted_k
,
kargs
.
M
),
make_tuple
(
splitk_batch_offset
.
splitted_k
,
kargs
.
M
),
make_tuple
(
kargs
.
stride_A
,
1
),
make_tuple
(
kargs
.
stride_A
,
1
),
number
<
GemmPipeline
::
VectorSizeA
>
{},
number
<
GemmPipeline
::
Get
VectorSizeA
()
>
{},
number
<
1
>
{});
number
<
1
>
{});
}
}
}();
}();
...
@@ -311,7 +299,7 @@ struct GemmKernel
...
@@ -311,7 +299,7 @@ struct GemmKernel
b_ptr
,
b_ptr
,
make_tuple
(
splitk_batch_offset
.
splitted_k
,
kargs
.
N
),
make_tuple
(
splitk_batch_offset
.
splitted_k
,
kargs
.
N
),
make_tuple
(
kargs
.
stride_B
,
1
),
make_tuple
(
kargs
.
stride_B
,
1
),
number
<
GemmPipeline
::
VectorSizeB
>
{},
number
<
GemmPipeline
::
Get
VectorSizeB
()
>
{},
number
<
1
>
{});
number
<
1
>
{});
}
}
else
else
...
@@ -320,7 +308,7 @@ struct GemmKernel
...
@@ -320,7 +308,7 @@ struct GemmKernel
b_ptr
,
b_ptr
,
make_tuple
(
kargs
.
N
,
splitk_batch_offset
.
splitted_k
),
make_tuple
(
kargs
.
N
,
splitk_batch_offset
.
splitted_k
),
make_tuple
(
kargs
.
stride_B
,
1
),
make_tuple
(
kargs
.
stride_B
,
1
),
number
<
GemmPipeline
::
VectorSizeB
>
{},
number
<
GemmPipeline
::
Get
VectorSizeB
()
>
{},
number
<
1
>
{});
number
<
1
>
{});
}
}
}();
}();
...
@@ -333,7 +321,7 @@ struct GemmKernel
...
@@ -333,7 +321,7 @@ struct GemmKernel
c_ptr
,
c_ptr
,
make_tuple
(
kargs
.
M
,
kargs
.
N
),
make_tuple
(
kargs
.
M
,
kargs
.
N
),
make_tuple
(
kargs
.
stride_C
,
1
),
make_tuple
(
kargs
.
stride_C
,
1
),
number
<
Gemm
Pipeline
::
VectorSizeC
>
{},
number
<
Epilogue
Pipeline
::
Get
VectorSizeC
()
>
{},
number
<
1
>
{});
number
<
1
>
{});
}
}
else
else
...
@@ -501,16 +489,13 @@ struct GemmKernel
...
@@ -501,16 +489,13 @@ struct GemmKernel
// Run Epilogue Pipeline
// Run Epilogue Pipeline
auto
&
c_block_window
=
gemm_tile_windows
.
at
(
I2
);
auto
&
c_block_window
=
gemm_tile_windows
.
at
(
I2
);
constexpr
bool
is_output_c_reg_transposed
=
if
constexpr
(
DstInMemOp
==
memory_operation_enum
::
set
||
EpiloguePipeline
::
IsOutputTransposed
()
!=
GemmPipeline
::
IsTransposeC
();
!
(
EpiloguePipeline
::
GetVectorSizeC
()
%
2
!=
0
&&
if
constexpr
((
DstInMemOp
==
memory_operation_enum
::
set
)
||
(
sizeof
(
CDataType
)
>
2
)
||
is_any_of
<
CDataType
,
fp16_t
,
bf16_t
>::
value
))
(
GemmPipeline
::
VectorSizeC
%
2
==
0
&&
std
::
is_same_v
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>
&&
is_output_c_reg_transposed
))
{
{
EpiloguePipeline
{}
EpiloguePipeline
{}
.
template
operator
()
<
decltype
(
c_block_window
),
decltype
(
c_block_tile
),
DstInMemOp
>(
.
template
operator
()
<
decltype
(
c_block_window
),
decltype
(
c_block_tile
),
DstInMemOp
>(
c_block_window
,
c_block_tile
);
c_block_window
,
c_block_tile
,
smem_ptr
);
}
}
}
}
...
...
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp
View file @
853a5183
...
@@ -21,6 +21,8 @@ struct GemmPipelineAgBgCrImplBase
...
@@ -21,6 +21,8 @@ struct GemmPipelineAgBgCrImplBase
static
constexpr
index_t
NPerBlock
=
BlockGemmShape
::
kN
;
static
constexpr
index_t
NPerBlock
=
BlockGemmShape
::
kN
;
static
constexpr
index_t
KPerBlock
=
BlockGemmShape
::
kK
;
static
constexpr
index_t
KPerBlock
=
BlockGemmShape
::
kK
;
CK_TILE_HOST_DEVICE
static
constexpr
auto
TransposeC
()
{
return
Problem
::
TransposeC
;
}
template
<
typename
DstBlockTile
,
typename
SrcTileWindow
,
typename
DramTileWindowStep
>
template
<
typename
DstBlockTile
,
typename
SrcTileWindow
,
typename
DramTileWindowStep
>
CK_TILE_DEVICE
void
GlobalPrefetch
(
DstBlockTile
&
dst_block_tile
,
CK_TILE_DEVICE
void
GlobalPrefetch
(
DstBlockTile
&
dst_block_tile
,
SrcTileWindow
&
dram_tile_window
,
SrcTileWindow
&
dram_tile_window
,
...
...
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp
View file @
853a5183
...
@@ -20,6 +20,8 @@ struct BaseGemmPipelineAgBgCrCompV3
...
@@ -20,6 +20,8 @@ struct BaseGemmPipelineAgBgCrCompV3
static
constexpr
index_t
PrefillStages
=
1
;
static
constexpr
index_t
PrefillStages
=
1
;
static
constexpr
index_t
GlobalBufferNum
=
1
;
static
constexpr
index_t
GlobalBufferNum
=
1
;
CK_TILE_HOST_DEVICE
static
constexpr
auto
TransposeC
()
{
return
Problem
::
TransposeC
;
}
CK_TILE_HOST
static
constexpr
bool
BlockHasHotloop
(
index_t
num_loop
)
CK_TILE_HOST
static
constexpr
bool
BlockHasHotloop
(
index_t
num_loop
)
{
{
return
num_loop
>
PrefetchStages
;
return
num_loop
>
PrefetchStages
;
...
@@ -62,9 +64,9 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
...
@@ -62,9 +64,9 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
static
constexpr
index_t
NPerBlock
=
BlockGemmShape
::
kN
;
static
constexpr
index_t
NPerBlock
=
BlockGemmShape
::
kN
;
static
constexpr
index_t
KPerBlock
=
BlockGemmShape
::
kK
;
static
constexpr
index_t
KPerBlock
=
BlockGemmShape
::
kK
;
static
constexpr
index_t
VectorSizeA
=
Policy
::
template
GetVectorSizeA
<
Problem
>();
static
constexpr
index_t
Get
VectorSizeA
()
{
return
Policy
::
template
GetVectorSizeA
<
Problem
>();
}
static
constexpr
index_t
VectorSizeB
=
Policy
::
template
GetVectorSizeB
<
Problem
>();
static
constexpr
index_t
Get
VectorSizeB
()
{
return
Policy
::
template
GetVectorSizeB
<
Problem
>();
}
static
constexpr
index_t
VectorSizeC
=
Policy
::
template
GetVectorSizeC
<
Problem
>();
static
constexpr
index_t
Get
VectorSizeC
()
{
return
Policy
::
template
GetVectorSizeC
<
Problem
>();
}
static
constexpr
bool
kPadM
=
Problem
::
kPadM
;
static
constexpr
bool
kPadM
=
Problem
::
kPadM
;
static
constexpr
bool
kPadN
=
Problem
::
kPadN
;
static
constexpr
bool
kPadN
=
Problem
::
kPadN
;
...
@@ -81,11 +83,6 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
...
@@ -81,11 +83,6 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
return
Policy
::
template
GetSmemSize
<
Problem
>();
return
Policy
::
template
GetSmemSize
<
Problem
>();
}
}
CK_TILE_HOST_DEVICE
static
constexpr
auto
IsTransposeC
()
{
return
Policy
::
template
IsTransposeC
<
Problem
>();
}
template
<
GemmPipelineScheduler
Scheduler
>
template
<
GemmPipelineScheduler
Scheduler
>
struct
PipelineImpl
:
public
PipelineImplBase
struct
PipelineImpl
:
public
PipelineImplBase
{
{
...
@@ -110,9 +107,9 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
...
@@ -110,9 +107,9 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
constexpr
index_t
B_LDS_Read_Width
=
KPerXDL
;
constexpr
index_t
B_LDS_Read_Width
=
KPerXDL
;
constexpr
index_t
A_Buffer_Load_Inst_Num
=
constexpr
index_t
A_Buffer_Load_Inst_Num
=
MPerBlock
*
KPerBlock
/
(
BlockSize
*
VectorSizeA
);
MPerBlock
*
KPerBlock
/
(
BlockSize
*
Get
VectorSizeA
()
);
constexpr
index_t
B_Buffer_Load_Inst_Num
=
constexpr
index_t
B_Buffer_Load_Inst_Num
=
NPerBlock
*
KPerBlock
/
(
BlockSize
*
VectorSizeB
);
NPerBlock
*
KPerBlock
/
(
BlockSize
*
Get
VectorSizeB
()
);
constexpr
index_t
A_LDS_Write_Inst_Num
=
MPerBlock
*
KPerBlock
/
(
BlockSize
*
KPerXDL
);
constexpr
index_t
A_LDS_Write_Inst_Num
=
MPerBlock
*
KPerBlock
/
(
BlockSize
*
KPerXDL
);
constexpr
index_t
B_LDS_Write_Inst_Num
=
NPerBlock
*
KPerBlock
/
(
BlockSize
*
KPerXDL
);
constexpr
index_t
B_LDS_Write_Inst_Num
=
NPerBlock
*
KPerBlock
/
(
BlockSize
*
KPerXDL
);
...
...
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp
View file @
853a5183
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024
-2025
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -20,6 +20,8 @@ struct BaseGemmPipelineAgBgCrMem
...
@@ -20,6 +20,8 @@ struct BaseGemmPipelineAgBgCrMem
using
BDataType
=
remove_cvref_t
<
typename
Problem
::
BDataType
>
;
using
BDataType
=
remove_cvref_t
<
typename
Problem
::
BDataType
>
;
using
BlockGemmShape
=
remove_cvref_t
<
typename
Problem
::
BlockGemmShape
>
;
using
BlockGemmShape
=
remove_cvref_t
<
typename
Problem
::
BlockGemmShape
>
;
CK_TILE_HOST_DEVICE
static
constexpr
auto
TransposeC
()
{
return
Problem
::
TransposeC
;
}
static
constexpr
index_t
BlockSize
=
Problem
::
kBlockSize
;
static
constexpr
index_t
BlockSize
=
Problem
::
kBlockSize
;
static
constexpr
index_t
MPerBlock
=
BlockGemmShape
::
kM
;
static
constexpr
index_t
MPerBlock
=
BlockGemmShape
::
kM
;
static
constexpr
index_t
NPerBlock
=
BlockGemmShape
::
kN
;
static
constexpr
index_t
NPerBlock
=
BlockGemmShape
::
kN
;
...
@@ -113,9 +115,9 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
...
@@ -113,9 +115,9 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
static
constexpr
index_t
NPerBlock
=
BlockGemmShape
::
kN
;
static
constexpr
index_t
NPerBlock
=
BlockGemmShape
::
kN
;
static
constexpr
index_t
KPerBlock
=
BlockGemmShape
::
kK
;
static
constexpr
index_t
KPerBlock
=
BlockGemmShape
::
kK
;
static
constexpr
index_t
VectorSizeA
=
Policy
::
template
GetVectorSizeA
<
Problem
>();
static
constexpr
index_t
Get
VectorSizeA
()
{
return
Policy
::
template
GetVectorSizeA
<
Problem
>();
}
static
constexpr
index_t
VectorSizeB
=
Policy
::
template
GetVectorSizeB
<
Problem
>();
static
constexpr
index_t
Get
VectorSizeB
()
{
return
Policy
::
template
GetVectorSizeB
<
Problem
>();
}
static
constexpr
index_t
VectorSizeC
=
Policy
::
template
GetVectorSizeC
<
Problem
>();
static
constexpr
index_t
Get
VectorSizeC
()
{
return
Policy
::
template
GetVectorSizeC
<
Problem
>();
}
static
constexpr
bool
kPadM
=
Problem
::
kPadM
;
static
constexpr
bool
kPadM
=
Problem
::
kPadM
;
static
constexpr
bool
kPadN
=
Problem
::
kPadN
;
static
constexpr
bool
kPadN
=
Problem
::
kPadN
;
...
@@ -133,11 +135,6 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
...
@@ -133,11 +135,6 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
return
Policy
::
template
GetSmemSize
<
Problem
>();
return
Policy
::
template
GetSmemSize
<
Problem
>();
}
}
CK_TILE_HOST_DEVICE
static
constexpr
auto
IsTransposeC
()
{
return
Policy
::
template
IsTransposeC
<
Problem
>();
}
template
<
GemmPipelineScheduler
Scheduler
>
template
<
GemmPipelineScheduler
Scheduler
>
struct
PipelineImpl
:
public
PipelineImplBase
struct
PipelineImpl
:
public
PipelineImplBase
{
{
...
...
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp
View file @
853a5183
...
@@ -31,21 +31,21 @@ struct GemmPipelineAGmemBGmemCRegV1
...
@@ -31,21 +31,21 @@ struct GemmPipelineAGmemBGmemCRegV1
static
constexpr
index_t
kNPerBlock
=
BlockGemmShape
::
kN
;
static
constexpr
index_t
kNPerBlock
=
BlockGemmShape
::
kN
;
static
constexpr
index_t
kKPerBlock
=
BlockGemmShape
::
kK
;
static
constexpr
index_t
kKPerBlock
=
BlockGemmShape
::
kK
;
static
constexpr
index_t
VectorSizeA
=
Problem
::
VectorSizeA
;
static
constexpr
index_t
Get
VectorSizeA
()
{
return
Problem
::
VectorSizeA
;
}
static
constexpr
index_t
VectorSizeB
=
Problem
::
VectorSizeB
;
static
constexpr
index_t
Get
VectorSizeB
()
{
return
Problem
::
VectorSizeB
;
}
static
constexpr
index_t
VectorSizeC
=
Problem
::
VectorSizeC
;
static
constexpr
index_t
Get
VectorSizeC
()
{
return
Problem
::
VectorSizeC
;
}
static
constexpr
bool
kPadM
=
Problem
::
kPadM
;
static
constexpr
bool
kPadM
=
Problem
::
kPadM
;
static
constexpr
bool
kPadN
=
Problem
::
kPadN
;
static
constexpr
bool
kPadN
=
Problem
::
kPadN
;
static
constexpr
bool
kPadK
=
Problem
::
kPadK
;
static
constexpr
bool
kPadK
=
Problem
::
kPadK
;
CK_TILE_HOST_DEVICE
static
constexpr
auto
TransposeC
()
{
return
Problem
::
TransposeC
;
}
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSize
()
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSize
()
{
{
return
Policy
::
template
GetSmemSize
<
Problem
>();
return
Policy
::
template
GetSmemSize
<
Problem
>();
}
}
CK_TILE_HOST_DEVICE
static
constexpr
auto
IsTransposeC
()
{
return
Policy
::
IsTransposeC
();
}
template
<
typename
ADramBlockWindowTmp
,
template
<
typename
ADramBlockWindowTmp
,
typename
BDramBlockWindowTmp
,
typename
BDramBlockWindowTmp
,
typename
AElementFunction
,
typename
AElementFunction
,
...
...
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp
View file @
853a5183
...
@@ -16,8 +16,6 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
...
@@ -16,8 +16,6 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
static
constexpr
auto
I1
=
number
<
1
>
{};
static
constexpr
auto
I1
=
number
<
1
>
{};
static
constexpr
auto
I2
=
number
<
2
>
{};
static
constexpr
auto
I2
=
number
<
2
>
{};
static
constexpr
bool
TransposeC
=
true
;
// 3d + padding
// 3d + padding
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeALdsBlockDescriptor
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeALdsBlockDescriptor
()
...
@@ -383,8 +381,6 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
...
@@ -383,8 +381,6 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
}
}
}
}
CK_TILE_HOST_DEVICE
static
constexpr
auto
IsTransposeC
()
{
return
TransposeC
;
}
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetBlockGemm
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetBlockGemm
()
{
{
...
@@ -397,7 +393,7 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
...
@@ -397,7 +393,7 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
WarpTile
::
at
(
I0
),
WarpTile
::
at
(
I0
),
WarpTile
::
at
(
I1
),
WarpTile
::
at
(
I1
),
WarpTile
::
at
(
I2
),
WarpTile
::
at
(
I2
),
TransposeC
>
;
Problem
::
TransposeC
>
;
using
BlockGemmPolicy
=
BlockGemmASmemBSmemCRegV1CustomPolicy
<
typename
Problem
::
ADataType
,
using
BlockGemmPolicy
=
BlockGemmASmemBSmemCRegV1CustomPolicy
<
typename
Problem
::
ADataType
,
typename
Problem
::
BDataType
,
typename
Problem
::
BDataType
,
typename
Problem
::
CDataType
,
typename
Problem
::
CDataType
,
...
...
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp
View file @
853a5183
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -25,6 +25,8 @@ struct GemmPipelineAGmemBGmemCRegV2
...
@@ -25,6 +25,8 @@ struct GemmPipelineAGmemBGmemCRegV2
static
constexpr
index_t
kNPerBlock
=
BlockGemmShape
::
kN
;
static
constexpr
index_t
kNPerBlock
=
BlockGemmShape
::
kN
;
static
constexpr
index_t
kKPerBlock
=
BlockGemmShape
::
kK
;
static
constexpr
index_t
kKPerBlock
=
BlockGemmShape
::
kK
;
CK_TILE_HOST_DEVICE
static
constexpr
auto
TransposeC
()
{
return
Problem
::
TransposeC
;
}
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetStaticLdsSize
()
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetStaticLdsSize
()
{
{
return
integer_divide_ceil
(
return
integer_divide_ceil
(
...
@@ -36,8 +38,6 @@ struct GemmPipelineAGmemBGmemCRegV2
...
@@ -36,8 +38,6 @@ struct GemmPipelineAGmemBGmemCRegV2
Policy
::
template
MakeBLdsBlockDescriptor
<
Problem
>().
get_element_space_size
();
Policy
::
template
MakeBLdsBlockDescriptor
<
Problem
>().
get_element_space_size
();
}
}
CK_TILE_HOST_DEVICE
static
constexpr
auto
IsTransposeC
()
{
return
Policy
::
IsTransposeC
();
}
template
<
typename
ADramBlockWindowTmp
,
template
<
typename
ADramBlockWindowTmp
,
typename
BDramBlockWindowTmp
,
typename
BDramBlockWindowTmp
,
typename
AElementFunction
,
typename
AElementFunction
,
...
...
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp
View file @
853a5183
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -27,6 +27,8 @@ struct GemmPipelineProblemBase
...
@@ -27,6 +27,8 @@ struct GemmPipelineProblemBase
using
BLayout
=
remove_cvref_t
<
typename
Traits
::
BLayout
>
;
using
BLayout
=
remove_cvref_t
<
typename
Traits
::
BLayout
>
;
using
CLayout
=
remove_cvref_t
<
typename
Traits
::
CLayout
>
;
using
CLayout
=
remove_cvref_t
<
typename
Traits
::
CLayout
>
;
static
constexpr
bool
TransposeC
=
Traits
::
TransposeC
;
static
constexpr
index_t
kBlockSize
=
BlockGemmShape
::
NumWarps
*
get_warp_size
();
static
constexpr
index_t
kBlockSize
=
BlockGemmShape
::
NumWarps
*
get_warp_size
();
static
constexpr
bool
kPadM
=
Traits
::
kPadM
;
static
constexpr
bool
kPadM
=
Traits
::
kPadM
;
...
@@ -111,7 +113,6 @@ struct GemmPipelineProblemBase
...
@@ -111,7 +113,6 @@ struct GemmPipelineProblemBase
return
kPadK
?
1
:
GetAlignmentB
();
return
kPadK
?
1
:
GetAlignmentB
();
}
}
}();
}();
static
constexpr
index_t
VectorSizeC
=
[]()
{
static
constexpr
index_t
VectorSizeC
=
[]()
{
if
constexpr
(
std
::
is_same_v
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>
)
if
constexpr
(
std
::
is_same_v
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>
)
{
{
...
...
include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp
View file @
853a5183
...
@@ -549,12 +549,6 @@ struct UniversalGemmPipelineAgBgCrPolicy
...
@@ -549,12 +549,6 @@ struct UniversalGemmPipelineAgBgCrPolicy
return
TileEncodingPattern
::
MakeShuffled2DStaticTileDistribution
();
return
TileEncodingPattern
::
MakeShuffled2DStaticTileDistribution
();
}
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
IsTransposeC
()
{
return
Problem
::
TransposeC
;
}
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetBlockGemm
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetBlockGemm
()
{
{
...
...
test/ck_tile/batched_gemm/test_batched_gemm_util.hpp
View file @
853a5183
...
@@ -29,12 +29,9 @@ class TestCkTileBatchedGemm : public ::testing::Test
...
@@ -29,12 +29,9 @@ class TestCkTileBatchedGemm : public ::testing::Test
const
ck_tile
::
stream_config
&
s
)
const
ck_tile
::
stream_config
&
s
)
{
{
// The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part.
// The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part.
constexpr
bool
kPadM
=
false
;
constexpr
bool
kPadM
=
false
;
constexpr
bool
kPadN
=
false
;
constexpr
bool
kPadN
=
false
;
constexpr
bool
kPadK
=
false
;
constexpr
bool
kPadK
=
false
;
constexpr
bool
kTilePermute
=
false
;
// The rank and permutation will also be generate out by the CodeGen part.
constexpr
ck_tile
::
index_t
kOutputRank
=
2
;
constexpr
int
kBlockPerCu
=
1
;
constexpr
int
kBlockPerCu
=
1
;
...
@@ -51,11 +48,6 @@ class TestCkTileBatchedGemm : public ::testing::Test
...
@@ -51,11 +48,6 @@ class TestCkTileBatchedGemm : public ::testing::Test
constexpr
ck_tile
::
index_t
N_Warp_Tile
=
32
;
constexpr
ck_tile
::
index_t
N_Warp_Tile
=
32
;
constexpr
ck_tile
::
index_t
K_Warp_Tile
=
8
;
constexpr
ck_tile
::
index_t
K_Warp_Tile
=
8
;
// Whether doing the CShuffle (transpose before the global memory), depending on the output
// layout.
constexpr
bool
CShuffleEpilogue
=
std
::
is_same_v
<
CLayout
,
ck_tile
::
tensor_layout
::
gemm
::
ColumnMajor
>
;
using
CodegenGemmShape
=
using
CodegenGemmShape
=
ck_tile
::
TileGemmShape
<
ck_tile
::
sequence
<
M_Tile
,
N_Tile
,
K_Tile
>
,
ck_tile
::
TileGemmShape
<
ck_tile
::
sequence
<
M_Tile
,
N_Tile
,
K_Tile
>
,
ck_tile
::
sequence
<
M_Warp
,
N_Warp
,
K_Warp
>
,
ck_tile
::
sequence
<
M_Warp
,
N_Warp
,
K_Warp
>
,
...
@@ -63,21 +55,6 @@ class TestCkTileBatchedGemm : public ::testing::Test
...
@@ -63,21 +55,6 @@ class TestCkTileBatchedGemm : public ::testing::Test
using
TilePartitioner
=
ck_tile
::
GemmTile2DPartitioner
<
CodegenGemmShape
>
;
using
TilePartitioner
=
ck_tile
::
GemmTile2DPartitioner
<
CodegenGemmShape
>
;
using
GemmEpilogue
=
std
::
conditional_t
<
CShuffleEpilogue
,
ck_tile
::
CShuffleEpilogue
<
ck_tile
::
CShuffleEpilogueProblem
<
AccDataType
,
CDataType
,
kPadM
,
kPadN
,
kTilePermute
,
kOutputRank
,
1
,
0
,
TilePartitioner
::
MPerBlock
,
TilePartitioner
::
NPerBlock
>>
,
ck_tile
::
Default2DEpilogue
<
ck_tile
::
Default2DEpilogueProblem
<
AccDataType
,
CDataType
,
kPadM
,
kPadN
>>>
;
using
CodegenGemmTraits
=
using
CodegenGemmTraits
=
ck_tile
::
TileGemmTraits
<
kPadM
,
kPadN
,
kPadK
,
ALayout
,
BLayout
,
CLayout
>
;
ck_tile
::
TileGemmTraits
<
kPadM
,
kPadN
,
kPadK
,
ALayout
,
BLayout
,
CLayout
>
;
...
@@ -88,6 +65,20 @@ class TestCkTileBatchedGemm : public ::testing::Test
...
@@ -88,6 +65,20 @@ class TestCkTileBatchedGemm : public ::testing::Test
CodegenGemmTraits
>
;
CodegenGemmTraits
>
;
using
CodegenGemmPipeline
=
ck_tile
::
GemmPipelineAGmemBGmemCRegV1
<
CodegenPipelineProblem
>
;
using
CodegenGemmPipeline
=
ck_tile
::
GemmPipelineAGmemBGmemCRegV1
<
CodegenPipelineProblem
>
;
using
GemmEpilogue
=
ck_tile
::
CShuffleEpilogue
<
ck_tile
::
CShuffleEpilogueProblem
<
AccDataType
,
CDataType
,
CLayout
,
CodegenGemmPipeline
::
BlockSize
,
TilePartitioner
::
MPerBlock
,
TilePartitioner
::
NPerBlock
,
M_Warp
,
N_Warp
,
M_Warp_Tile
,
N_Warp_Tile
,
K_Warp_Tile
,
CodegenPipelineProblem
::
TransposeC
>>
;
using
Kernel
=
using
Kernel
=
ck_tile
::
BatchedGemmKernel
<
TilePartitioner
,
CodegenGemmPipeline
,
GemmEpilogue
>
;
ck_tile
::
BatchedGemmKernel
<
TilePartitioner
,
CodegenGemmPipeline
,
GemmEpilogue
>
;
...
...
test/ck_tile/gemm/test_gemm_pipeline_util.hpp
View file @
853a5183
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024
-2025
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
#include <sstream>
#include <sstream>
...
@@ -65,9 +65,6 @@ class TestCkTileGemmPipeline : public ::testing::Test
...
@@ -65,9 +65,6 @@ class TestCkTileGemmPipeline : public ::testing::Test
ck_tile
::
sequence
<
M_Warp_Tile
,
N_Warp_Tile
,
K_Warp_Tile
>>
;
ck_tile
::
sequence
<
M_Warp_Tile
,
N_Warp_Tile
,
K_Warp_Tile
>>
;
using
TilePartitioner
=
ck_tile
::
GemmTile2DPartitioner
<
GemmShape
>
;
using
TilePartitioner
=
ck_tile
::
GemmTile2DPartitioner
<
GemmShape
>
;
using
GemmEpilogue
=
ck_tile
::
Default2DEpilogue
<
ck_tile
::
Default2DEpilogueProblem
<
AccDataType
,
CDataType
,
kPadM
,
kPadN
>>
;
using
Traits
=
ck_tile
::
TileGemmTraits
<
kPadM
,
kPadN
,
kPadK
,
ALayout
,
BLayout
,
CLayout
>
;
using
Traits
=
ck_tile
::
TileGemmTraits
<
kPadM
,
kPadN
,
kPadK
,
ALayout
,
BLayout
,
CLayout
>
;
using
GemmUniversalTraits
=
ck_tile
::
using
GemmUniversalTraits
=
ck_tile
::
TileGemmUniversalTraits
<
kPadM
,
kPadN
,
kPadK
,
ALayout
,
BLayout
,
CLayout
,
TransposeC
>
;
TileGemmUniversalTraits
<
kPadM
,
kPadN
,
kPadK
,
ALayout
,
BLayout
,
CLayout
,
TransposeC
>
;
...
@@ -106,6 +103,20 @@ class TestCkTileGemmPipeline : public ::testing::Test
...
@@ -106,6 +103,20 @@ class TestCkTileGemmPipeline : public ::testing::Test
ck_tile
::
GemmPipelineAgBgCrCompV3
<
UniversalGemmProblem
,
ck_tile
::
GemmPipelineAgBgCrCompV3
<
UniversalGemmProblem
,
ck_tile
::
UniversalGemmPipelineAgBgCrPolicy
>>
;
ck_tile
::
UniversalGemmPipelineAgBgCrPolicy
>>
;
using
GemmEpilogue
=
ck_tile
::
CShuffleEpilogue
<
ck_tile
::
CShuffleEpilogueProblem
<
AccDataType
,
CDataType
,
CLayout
,
GemmPipeline
::
BlockSize
,
TilePartitioner
::
MPerBlock
,
TilePartitioner
::
NPerBlock
,
M_Warp
,
N_Warp
,
M_Warp_Tile
,
N_Warp_Tile
,
K_Warp_Tile
,
UniversalGemmProblem
::
TransposeC
>>
;
using
Kernel
=
ck_tile
::
GemmKernel
<
TilePartitioner
,
GemmPipeline
,
GemmEpilogue
>
;
using
Kernel
=
ck_tile
::
GemmKernel
<
TilePartitioner
,
GemmPipeline
,
GemmEpilogue
>
;
auto
kargs
=
Kernel
::
MakeKernelArgs
(
args
);
auto
kargs
=
Kernel
::
MakeKernelArgs
(
args
);
...
@@ -244,7 +255,7 @@ class TestCkTileGemmPipeline : public ::testing::Test
...
@@ -244,7 +255,7 @@ class TestCkTileGemmPipeline : public ::testing::Test
public:
public:
std
::
vector
<
int
>
k_batches_
;
std
::
vector
<
int
>
k_batches_
;
void
SetUp
()
override
{
k_batches_
=
{
1
};
}
void
SetUp
()
override
{
k_batches_
=
{
1
,
2
};
}
template
<
bool
PadM
=
true
,
bool
PadN
=
true
,
bool
PadK
=
true
>
template
<
bool
PadM
=
true
,
bool
PadN
=
true
,
bool
PadK
=
true
>
void
Run
(
const
int
M
,
void
Run
(
const
int
M
,
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment