Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
f23a2e2a
Commit
f23a2e2a
authored
Feb 11, 2025
by
Jakub Piasecki
Browse files
resolved conflicts
parents
f3eb5a18
c0adab48
Changes
340
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
724 additions
and
216 deletions
+724
-216
include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp
include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp
+21
-7
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
+42
-48
include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp
include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp
+237
-46
include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp
include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp
+17
-6
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp
...ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp
+2
-0
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp
...tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp
+82
-16
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp
.../ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp
+253
-64
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp
...le/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp
+2
-1
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp
...e/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp
+21
-7
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp
...line/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp
+1
-5
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp
...e/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp
+11
-3
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp
+16
-4
include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp
...gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp
+1
-8
include/ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp
include/ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp
+12
-1
include/ck_tile/ops/image_to_column.hpp
include/ck_tile/ops/image_to_column.hpp
+1
-0
include/ck_tile/ops/layernorm2d.hpp
include/ck_tile/ops/layernorm2d.hpp
+1
-0
include/ck_tile/ops/norm_reduce.hpp
include/ck_tile/ops/norm_reduce.hpp
+1
-0
include/ck_tile/ops/permute.hpp
include/ck_tile/ops/permute.hpp
+1
-0
include/ck_tile/ops/reduce.hpp
include/ck_tile/ops/reduce.hpp
+1
-0
include/ck_tile/ops/rmsnorm2d.hpp
include/ck_tile/ops/rmsnorm2d.hpp
+1
-0
No files found.
include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp
View file @
f23a2e2a
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024
-2025
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/host/concat.hpp"
namespace
ck_tile
{
...
...
@@ -57,6 +59,18 @@ struct BatchedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
using
BLayout
=
typename
Base
::
BLayout
;
using
CLayout
=
typename
Base
::
CLayout
;
/**
 * @brief Builds the textual name of this batched GEMM kernel instance.
 *
 * The name is the '_'-joined sequence of: the literal "gemm_batched",
 * `gemm_prec_str<ADataType, BDataType>`, and three 'x'-joined groups read from
 * GemmPipeline: (kMPerBlock, kNPerBlock, kKPerBlock),
 * (GetVectorSizeA(), GetVectorSizeB(), GetVectorSizeC()) and (kPadM, kPadN, kPadK).
 *
 * @return The composed kernel name.
 */
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
    // clang-format off
    using Pipeline = GemmPipeline;

    // Each group is joined with 'x' first, then all parts are joined with '_'.
    const auto block_tile_part  = concat('x', Pipeline::kMPerBlock, Pipeline::kNPerBlock, Pipeline::kKPerBlock);
    const auto vector_size_part = concat('x', Pipeline::GetVectorSizeA(), Pipeline::GetVectorSizeB(), Pipeline::GetVectorSizeC());
    const auto padding_part     = concat('x', Pipeline::kPadM, Pipeline::kPadN, Pipeline::kPadK);

    return concat('_', "gemm_batched", gemm_prec_str<ADataType, BDataType>,
                  block_tile_part,
                  vector_size_part,
                  padding_part);
    // clang-format on
}
struct
BatchedGemmKernelArgs
:
GemmKernelArgs
{
index_t
batch_stride_A
;
...
...
@@ -70,7 +84,7 @@ struct BatchedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
__host__
static
constexpr
auto
GridSize
(
index_t
M
,
index_t
N
,
index_t
KBatch
,
index_t
batch_count
)
{
return
TilePartitioner
::
GridSize
(
M
,
N
,
KBatch
*
batch_count
);
return
dim3
(
TilePartitioner
::
GridSize
(
M
,
N
)
,
batch_count
,
KBatch
);
}
/**
 * @brief Returns the workgroup (block) dimensions used to launch this kernel.
 *
 * @return A dim3 constructed from the single value Base::KernelBlockSize.
 */
__host__ static constexpr auto BlockSize()
{
    const auto threads_per_block = Base::KernelBlockSize;
    return dim3(threads_per_block);
}
...
...
@@ -101,14 +115,14 @@ struct BatchedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
CK_TILE_DEVICE
void
operator
()(
BatchedGemmKernelArgs
kargs
)
const
{
const
auto
[
iM
,
iN
]
=
TilePartitioner
::
GetOutputTileIndex
(
blockIdx
.
x
,
blockIdx
.
y
);
const
auto
[
iM
,
iN
]
=
TilePartitioner
{
kargs
.
M
,
kargs
.
N
}.
GetOutputTileIndex
(
blockIdx
.
x
);
const
index_t
i_m
=
__builtin_amdgcn_readfirstlane
(
iM
*
TilePartitioner
::
MPerBlock
);
const
index_t
i_n
=
__builtin_amdgcn_readfirstlane
(
iN
*
TilePartitioner
::
NPerBlock
);
const
auto
i_batch
=
__builtin_amdgcn_readfirstlane
(
blockIdx
.
z
/
kargs
.
KBatch
);
const
auto
i_
k
=
__builtin_amdgcn_readfirstlane
(
blockIdx
.
z
-
i_batch
*
kargs
.
KBatch
);
const
auto
i_batch
=
__builtin_amdgcn_readfirstlane
(
blockIdx
.
y
);
const
auto
i_
splitk
=
__builtin_amdgcn_readfirstlane
(
blockIdx
.
z
);
const
typename
Base
::
SplitKBatchOffset
splitk_batch_offset
(
kargs
,
i_k
);
const
typename
Base
::
SplitKBatchOffset
splitk_batch_offset
(
kargs
,
i_
split
k
);
// options
const
auto
batch_stride_A
=
__builtin_amdgcn_readfirstlane
(
kargs
.
batch_stride_A
);
...
...
@@ -128,7 +142,7 @@ struct BatchedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
// allocate LDS
__shared__
char
smem_ptr
[
GetSmemSize
()];
if
(
kargs
.
KB
atch
==
1
)
if
(
kargs
.
k_b
atch
==
1
)
{
this
->
RunGemm
(
a_ptr
,
b_ptr
,
c_ptr
,
smem_ptr
,
kargs
,
splitk_batch_offset
,
i_m
,
i_n
);
}
...
...
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
View file @
f23a2e2a
...
...
@@ -8,6 +8,7 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/host/concat.hpp"
namespace
ck_tile
{
...
...
@@ -75,12 +76,19 @@ struct GemmKernel
static
constexpr
auto
I1
=
number
<
1
>
();
static
constexpr
auto
I2
=
number
<
2
>
();
__host__
static
constexpr
auto
GridSize
(
index_t
M
,
index_t
N
,
index_t
KBatch
)
[[
nodiscard
]]
CK_TILE_HOST
static
const
std
::
string
GetName
(
)
{
return
TilePartitioner
::
GridSize
(
M
,
N
,
KBatch
);
// clang-format off
return
concat
(
'_'
,
"gemm"
,
gemm_prec_str
<
ADataType
,
BDataType
>
,
GemmPipeline
::
GetName
());
// clang-format on
}
__host__
static
constexpr
auto
BlockSize
()
{
return
dim3
(
KernelBlockSize
);
}
/**
 * @brief Computes the launch grid for this GEMM kernel.
 *
 * @param M      GEMM's M dimension, forwarded to TilePartitioner::GridSize.
 * @param N      GEMM's N dimension, forwarded to TilePartitioner::GridSize.
 * @param KBatch Value placed in the grid's Z dimension.
 * @return dim3(TilePartitioner::GridSize(M, N), 1, KBatch).
 */
CK_TILE_HOST static constexpr auto GridSize(index_t M, index_t N, index_t KBatch)
{
    // X comes from the tile partitioner; Y is fixed to 1; Z carries KBatch.
    const auto partitioner_grid = TilePartitioner::GridSize(M, N);
    return dim3(partitioner_grid, 1, KBatch);
}
/**
 * @brief Returns the workgroup (block) dimensions used to launch this kernel.
 *
 * @return A dim3 constructed from the single value KernelBlockSize.
 */
CK_TILE_HOST static constexpr auto BlockSize()
{
    const auto threads_per_block = KernelBlockSize;
    return dim3(threads_per_block);
}
struct
GemmKernelArgs
{
...
...
@@ -93,7 +101,7 @@ struct GemmKernel
index_t
stride_A
;
index_t
stride_B
;
index_t
stride_C
;
index_t
KB
atch
;
index_t
k_b
atch
;
};
CK_TILE_HOST
static
constexpr
GemmKernelArgs
MakeKernelArgs
(
const
GemmHostArgs
&
hostArgs
)
...
...
@@ -121,7 +129,7 @@ struct GemmKernel
const
std
::
size_t
k_id
=
blockIdx
.
z
)
{
constexpr
auto
K1
=
TilePartitioner
::
BlockGemmShape
::
WarpTile
::
at
(
number
<
2
>
{});
const
index_t
K_t
=
kargs
.
KB
atch
*
K1
;
const
index_t
K_t
=
kargs
.
k_b
atch
*
K1
;
const
index_t
KRead
=
(
kargs
.
K
+
K_t
-
1
)
/
K_t
*
K1
;
if
constexpr
(
std
::
is_same_v
<
tensor_layout
::
gemm
::
RowMajor
,
ALayout
>
)
...
...
@@ -142,13 +150,13 @@ struct GemmKernel
b_k_split_offset
=
k_id
*
KRead
;
}
if
(
k_id
<
static_cast
<
uint32_t
>
(
kargs
.
KB
atch
-
1
))
if
(
k_id
<
static_cast
<
uint32_t
>
(
kargs
.
k_b
atch
-
1
))
{
splitted_k
=
KRead
;
}
else
{
splitted_k
=
kargs
.
K
-
KRead
*
(
kargs
.
KB
atch
-
1
);
splitted_k
=
kargs
.
K
-
KRead
*
(
kargs
.
k_b
atch
-
1
);
}
}
...
...
@@ -159,14 +167,10 @@ struct GemmKernel
CK_TILE_HOST
static
bool
IsSupportedArgument
(
const
GemmKernelArgs
&
kargs
)
{
constexpr
bool
is_output_c_reg_transposed
=
EpiloguePipeline
::
IsOutputTransposed
()
!=
GemmPipeline
::
IsTransposeC
();
if
constexpr
(
!
((
GemmPipeline
::
VectorSizeC
%
2
==
0
&&
std
::
is_same_v
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>
&&
is_output_c_reg_transposed
)
||
!
(
std
::
is_same_v
<
CDataType
,
fp16_t
>
||
std
::
is_same_v
<
CDataType
,
bf16_t
>
)))
if
constexpr
(
EpiloguePipeline
::
template
GetVectorSizeC
<
CDataType
>()
%
2
!=
0
&&
is_any_of
<
CDataType
,
fp16_t
,
bf16_t
>::
value
)
{
if
(
kargs
.
KB
atch
!=
1
)
if
(
kargs
.
k_b
atch
!=
1
)
{
std
::
cerr
<<
"Conditions not met for Kbatch >1 !"
<<
std
::
endl
;
return
false
;
...
...
@@ -182,7 +186,7 @@ struct GemmKernel
<<
std
::
endl
;
return
false
;
}
if
(
kargs
.
K
%
GemmPipeline
::
VectorSizeA
!=
0
)
if
(
kargs
.
K
%
GemmPipeline
::
Get
VectorSizeA
()
!=
0
)
{
std
::
cerr
<<
"K is not a multiple of vector load size for A tensor!"
<<
std
::
endl
;
return
false
;
...
...
@@ -197,7 +201,7 @@ struct GemmKernel
<<
std
::
endl
;
return
false
;
}
if
(
kargs
.
M
%
GemmPipeline
::
VectorSizeA
!=
0
)
if
(
kargs
.
M
%
GemmPipeline
::
Get
VectorSizeA
()
!=
0
)
{
std
::
cerr
<<
"M is not a multiple of vector load size for A tensor!"
<<
std
::
endl
;
return
false
;
...
...
@@ -213,7 +217,7 @@ struct GemmKernel
<<
std
::
endl
;
return
false
;
}
if
(
kargs
.
N
%
GemmPipeline
::
VectorSizeB
!=
0
)
if
(
kargs
.
N
%
GemmPipeline
::
Get
VectorSizeB
()
!=
0
)
{
std
::
cerr
<<
"N is not a multiple of vector load size for B tensor!"
<<
std
::
endl
;
return
false
;
...
...
@@ -228,7 +232,7 @@ struct GemmKernel
<<
std
::
endl
;
return
false
;
}
if
(
kargs
.
K
%
GemmPipeline
::
VectorSizeB
!=
0
)
if
(
kargs
.
K
%
GemmPipeline
::
Get
VectorSizeB
()
!=
0
)
{
std
::
cerr
<<
"K is not a multiple of vector load size for B tensor!"
<<
std
::
endl
;
return
false
;
...
...
@@ -244,7 +248,7 @@ struct GemmKernel
<<
std
::
endl
;
return
false
;
}
if
(
kargs
.
N
%
Gemm
Pipeline
::
VectorSizeC
!=
0
)
if
(
kargs
.
N
%
Epilogue
Pipeline
::
template
GetVectorSizeC
<
CDataType
>()
!=
0
)
{
std
::
cerr
<<
"N is not a multiple of vector load size for C tensor!"
<<
std
::
endl
;
return
false
;
...
...
@@ -259,7 +263,7 @@ struct GemmKernel
<<
std
::
endl
;
return
false
;
}
if
(
kargs
.
M
%
Gemm
Pipeline
::
VectorSizeC
!=
0
)
if
(
kargs
.
M
%
Epilogue
Pipeline
::
template
GetVectorSizeC
<
CDataType
>()
!=
0
)
{
std
::
cerr
<<
"M is not a multiple of vector load size for C tensor!"
<<
std
::
endl
;
return
false
;
...
...
@@ -275,14 +279,6 @@ struct GemmKernel
const
GemmKernelArgs
&
kargs
,
const
SplitKBatchOffset
&
splitk_batch_offset
)
{
// const auto idxs = TilePartitioner{}();
// const auto i_m = idxs.at(number<0>{});
// const auto i_n = idxs.at(number<1>{});
// // options
// const ADataType* a_start = static_cast<const ADataType*>(kargs.a_ptr);
// const BDataType* b_start = static_cast<const BDataType*>(kargs.b_ptr);
// // Convert pointers to tensor views
// auto a_tensor_view = [&]() {
const
auto
&
a_tensor_view
=
[
&
]()
{
if
constexpr
(
std
::
is_same_v
<
ALayout
,
tensor_layout
::
gemm
::
RowMajor
>
)
{
...
...
@@ -290,7 +286,7 @@ struct GemmKernel
a_ptr
,
make_tuple
(
kargs
.
M
,
splitk_batch_offset
.
splitted_k
),
make_tuple
(
kargs
.
stride_A
,
1
),
number
<
GemmPipeline
::
VectorSizeA
>
{},
number
<
GemmPipeline
::
Get
VectorSizeA
()
>
{},
number
<
1
>
{});
}
else
...
...
@@ -299,7 +295,7 @@ struct GemmKernel
a_ptr
,
make_tuple
(
splitk_batch_offset
.
splitted_k
,
kargs
.
M
),
make_tuple
(
kargs
.
stride_A
,
1
),
number
<
GemmPipeline
::
VectorSizeA
>
{},
number
<
GemmPipeline
::
Get
VectorSizeA
()
>
{},
number
<
1
>
{});
}
}();
...
...
@@ -311,7 +307,7 @@ struct GemmKernel
b_ptr
,
make_tuple
(
splitk_batch_offset
.
splitted_k
,
kargs
.
N
),
make_tuple
(
kargs
.
stride_B
,
1
),
number
<
GemmPipeline
::
VectorSizeB
>
{},
number
<
GemmPipeline
::
Get
VectorSizeB
()
>
{},
number
<
1
>
{});
}
else
...
...
@@ -320,7 +316,7 @@ struct GemmKernel
b_ptr
,
make_tuple
(
kargs
.
N
,
splitk_batch_offset
.
splitted_k
),
make_tuple
(
kargs
.
stride_B
,
1
),
number
<
GemmPipeline
::
VectorSizeB
>
{},
number
<
GemmPipeline
::
Get
VectorSizeB
()
>
{},
number
<
1
>
{});
}
}();
...
...
@@ -333,7 +329,7 @@ struct GemmKernel
c_ptr
,
make_tuple
(
kargs
.
M
,
kargs
.
N
),
make_tuple
(
kargs
.
stride_C
,
1
),
number
<
Gemm
Pipeline
::
VectorSizeC
>
{},
number
<
Epilogue
Pipeline
::
template
GetVectorSizeC
<
CDataType
>()
>
{},
number
<
1
>
{});
}
else
...
...
@@ -501,22 +497,14 @@ struct GemmKernel
// Run Epilogue Pipeline
auto
&
c_block_window
=
gemm_tile_windows
.
at
(
I2
);
constexpr
bool
is_output_c_reg_transposed
=
EpiloguePipeline
::
IsOutputTransposed
()
!=
GemmPipeline
::
IsTransposeC
();
if
constexpr
((
DstInMemOp
==
memory_operation_enum
::
set
)
||
(
sizeof
(
CDataType
)
>
2
)
||
(
GemmPipeline
::
VectorSizeC
%
2
==
0
&&
std
::
is_same_v
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>
&&
is_output_c_reg_transposed
))
{
EpiloguePipeline
{}
.
template
operator
()
<
decltype
(
c_block_window
),
decltype
(
c_block_tile
),
DstInMemOp
>(
c_block_window
,
c_block_tile
);
}
EpiloguePipeline
{}
.
template
operator
()
<
decltype
(
c_block_window
),
decltype
(
c_block_tile
),
DstInMemOp
>(
c_block_window
,
c_block_tile
,
smem_ptr
);
}
CK_TILE_DEVICE
void
operator
()(
GemmKernelArgs
kargs
)
const
{
const
auto
[
iM
,
iN
]
=
TilePartitioner
::
GetOutputTileIndex
(
blockIdx
.
x
,
blockIdx
.
y
);
const
auto
[
iM
,
iN
]
=
TilePartitioner
{
kargs
.
M
,
kargs
.
N
}.
GetOutputTileIndex
(
blockIdx
.
x
);
const
index_t
i_m
=
__builtin_amdgcn_readfirstlane
(
iM
*
TilePartitioner
::
MPerBlock
);
const
index_t
i_n
=
__builtin_amdgcn_readfirstlane
(
iN
*
TilePartitioner
::
NPerBlock
);
...
...
@@ -531,14 +519,20 @@ struct GemmKernel
// allocate LDS
__shared__
char
smem_ptr
[
GetSmemSize
()];
if
(
kargs
.
KB
atch
==
1
)
if
(
kargs
.
k_b
atch
==
1
)
{
RunGemm
(
a_ptr
,
b_ptr
,
c_ptr
,
smem_ptr
,
kargs
,
splitk_batch_offset
,
i_m
,
i_n
);
}
else
{
RunGemm
<
memory_operation_enum
::
atomic_add
>
(
a_ptr
,
b_ptr
,
c_ptr
,
smem_ptr
,
kargs
,
splitk_batch_offset
,
i_m
,
i_n
);
// Do not compile in case where we have unsupported
// VectorSizeC & data type configuration.
if
constexpr
(
!
(
EpiloguePipeline
::
template
GetVectorSizeC
<
CDataType
>()
%
2
!=
0
&&
is_any_of
<
CDataType
,
fp16_t
,
bf16_t
>::
value
))
{
RunGemm
<
memory_operation_enum
::
atomic_add
>
(
a_ptr
,
b_ptr
,
c_ptr
,
smem_ptr
,
kargs
,
splitk_batch_offset
,
i_m
,
i_n
);
}
}
}
};
...
...
include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp
View file @
f23a2e2a
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
/**
* @file
* GemmTilePartitioner allows customized mapping between a workgroup and the C-tile it computes.
*/
#pragma once
#include "ck_tile/core.hpp"
namespace
ck_tile
{
/** @brief Struct representing 2D block index mapping into 3D output tile space. */
/**
* @brief Class providing 2D workgroup index mapping into 2D output GEMM C-tile space.
*
*/
template
<
typename
BlockGemmShapeType
>
struct
GemmTile2DPartitioner
{
...
...
@@ -17,21 +25,32 @@ struct GemmTile2DPartitioner
static
constexpr
index_t
NPerBlock
=
BlockGemmShape
::
kN
;
static
constexpr
index_t
KPerBlock
=
BlockGemmShape
::
kK
;
/** @brief Returns 3D grid size. */
CK_TILE_HOST
static
constexpr
auto
GridSize
(
index_t
M
,
index_t
N
,
index_t
batch_size
)
noexcept
(
noexcept
(
MPerBlock
!=
0
&&
NPerBlock
!=
0
))
->
dim3
CK_TILE_HOST_DEVICE
GemmTile2DPartitioner
()
noexcept
=
delete
;
CK_TILE_HOST_DEVICE
GemmTile2DPartitioner
([[
maybe_unused
]]
index_t
M
,
[[
maybe_unused
]]
index_t
N
)
noexcept
;
/**
* @brief Calculates GEMM kernel grid size.
*
* @param M GEMM's M dimension.
* @param N GEMM's N dimension.
* @return dim3 Structure holding grid's X,Y and Z dimensions.
*/
CK_TILE_HOST
static
auto
GridSize
(
index_t
M
,
index_t
N
)
noexcept
(
noexcept
(
MPerBlock
!=
0
&&
NPerBlock
!=
0
))
->
dim3
{
const
index_t
GridDimX
=
(
M
+
MPerBlock
-
1
)
/
MPerBlock
;
const
index_t
GridDimY
=
(
N
+
NPerBlock
-
1
)
/
NPerBlock
;
const
index_t
GridDimZ
=
batch_size
;
return
dim3
(
GridDimX
,
GridDimY
,
GridDimZ
);
return
dim3
(
GridDimX
,
GridDimY
,
1
);
}
/**
* @brief Returns the number of loops.
* @param [in] K is dimension
* @brief Calculate number of loop iterations over GEMM's K dimension.
*
* @param K GEMM's K dimension.
* @return index_t The number of loop iterations over K dimension.
*/
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetLoopNum
(
index_t
K
)
noexcept
->
index_t
CK_TILE_HOST_DEVICE
static
auto
GetLoopNum
(
index_t
K
)
noexcept
->
index_t
{
return
integer_divide_ceil
(
K
,
KPerBlock
);
}
...
...
@@ -42,8 +61,15 @@ struct GemmTile2DPartitioner
* @param [in] blockIdy is blockIdx.y
* @return Returns the output tile indexes.
*/
CK_TILE_DEVICE
static
constexpr
auto
GetOutputTileIndex
(
index_t
blockIdx
,
index_t
blockIdy
)
noexcept
/**
* @brief Calculate workgroup 2D index mapping into 2D output C-tile space.
*
* @param blockIdx WGP's X index.
* @param blockIdy WGP's Y index.
* @return const tuple<index_t, index_t> Tuple containing 2D output C-tile index.
*/
CK_TILE_DEVICE
static
auto
GetOutputTileIndex
(
index_t
blockIdx
,
index_t
blockIdy
)
noexcept
->
const
tuple
<
index_t
,
index_t
>
{
const
index_t
iM
=
__builtin_amdgcn_readfirstlane
(
blockIdx
);
...
...
@@ -53,61 +79,71 @@ struct GemmTile2DPartitioner
};
/**
* @brief Struct representing 1D block index mapping into 2D output tile space.
* @brief Class providing 1D WGP index mapping into 2D output C-tile space.
*
* @tparam BlockGemmShape_ A class providing basic GEMM parameters. See \ref TileGemmShape.
*/
template
<
typename
BlockGemmShape
Type
>
template
<
typename
BlockGemmShape
_
>
struct
GemmTile1DPartitioner
{
using
BlockGemmShape
=
remove_cvref_t
<
BlockGemmShape
Type
>
;
using
BlockGemmShape
=
remove_cvref_t
<
BlockGemmShape
_
>
;
static
constexpr
index_t
MPerBlock
=
BlockGemmShape
::
kM
;
static
constexpr
index_t
NPerBlock
=
BlockGemmShape
::
kN
;
static
constexpr
index_t
KPerBlock
=
BlockGemmShape
::
kK
;
/** @brief delete default ctr with no any object */
constexpr
GemmTile1DPartitioner
()
noexcept
=
delete
;
/** @brief constructs an object that does contain a N value. */
constexpr
GemmTile1DPartitioner
(
index_t
N
)
noexcept
{
N_
=
N
;
}
CK_TILE_HOST_DEVICE
GemmTile1DPartitioner
()
noexcept
=
delete
;
/** @brief Returns 1D grid size. */
CK_TILE_HOST
static
constexpr
auto
GridSize
(
index_t
M
,
index_t
N
)
noexcept
(
noexcept
(
MPerBlock
!=
0
&&
NPerBlock
!=
0
))
->
dim3
/**
* @brief Construct a new GemmTile1DPartitioner object.
*
* @param M GEMM's M dimension.
* @param N GEMM's N dimension.
*/
CK_TILE_HOST_DEVICE
GemmTile1DPartitioner
([[
maybe_unused
]]
index_t
M
,
index_t
N
)
noexcept
{
const
index_t
GridDimX
=
(
M
+
MPerBlock
-
1
)
/
MPerBlock
;
const
index_t
GridDimY
=
(
N
+
NPerBlock
-
1
)
/
NPerBlock
;
return
dim3
(
GridDimX
*
GridDimY
,
1
,
1
);
N_
=
N
;
}
/**
* @brief Returns the number of blocks in N.
* @param [in] N is dimension
* @brief Calculates GEMM kernel grid size.
*
* @param M GEMM's M dimension.
* @param N GEMM's N dimension.
* @return dim3 Structure holding grid's X,Y and Z dimensions.
*/
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetNBlock
(
index_t
N
)
noexcept
->
index_t
CK_TILE_HOST
static
auto
GridSize
(
index_t
M
,
index_t
N
)
noexcept
(
noexcept
(
MPerBlock
!=
0
&&
NPerBlock
!=
0
))
->
index_t
{
return
integer_divide_ceil
(
N
,
NPerBlock
);
const
index_t
GridDimX
=
(
M
+
MPerBlock
-
1
)
/
MPerBlock
;
const
index_t
GridDimY
=
(
N
+
NPerBlock
-
1
)
/
NPerBlock
;
return
GridDimX
*
GridDimY
;
}
/**
* @brief Returns the number of loops.
* @param [in] K is dimension
* @brief Calculate number of loop iterations over GEMM's K dimension.
*
* @param K GEMM's K dimension.
* @return index_t The number of loop iterations over K dimension.
*/
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetLoopNum
(
index_t
K
)
noexcept
->
index_t
CK_TILE_HOST_DEVICE
static
auto
GetLoopNum
(
index_t
K
)
noexcept
->
index_t
{
return
integer_divide_ceil
(
K
,
KPerBlock
);
}
/**
* @brief The function returns 2D output tile space.
* @param [in] blockIdx is blockIdx.x - block_start.
* */
CK_TILE_DEVICE
static
constexpr
auto
GetOutputTileIndex
(
index_t
blockIdx
)
noexcept
* @brief Calculate workgroup 1D index mapping into 2D output C-tile space.
*
* @param blockIdx WGP's index.
* @return const tuple<index_t, index_t> Tuple containing 2D output C-tile index.
*/
CK_TILE_DEVICE
static
auto
GetOutputTileIndex
(
index_t
blockIdx
)
noexcept
->
const
tuple
<
index_t
,
index_t
>
{
const
index_t
NBlock
=
GetN
Block
(
N_
);
const
index_t
NBlock
s
=
integer_divide_ceil
(
N_
,
NPer
Block
);
const
index_t
iM
=
__builtin_amdgcn_readfirstlane
(
blockIdx
/
NBlock
);
const
index_t
iN
=
__builtin_amdgcn_readfirstlane
(
blockIdx
-
(
iM
)
*
NBlock
);
const
index_t
iM
=
__builtin_amdgcn_readfirstlane
(
blockIdx
/
NBlock
s
);
const
index_t
iN
=
__builtin_amdgcn_readfirstlane
(
blockIdx
-
iM
*
NBlock
s
);
return
make_tuple
(
iM
,
iN
);
}
...
...
@@ -141,21 +177,176 @@ struct HasFnOneArgImpl<T, std::void_t<decltype(std::declval<T>().GetOutputTileIn
* enable-if `GetOutputTileIndex`-fn is std::true_type when `GetOutputTileIndex`-fn is well-formed,
* otherwise std::false_type.
*/
template
<
typename
Partitioner
Fn
,
typename
=
typename
std
::
enable_if_t
<
HasFnOneArgImpl
<
Partitioner
Fn
>{}
>>
template
<
typename
Tile
Partitioner
,
typename
=
typename
std
::
enable_if_t
<
HasFnOneArgImpl
<
Tile
Partitioner
>{}
>>
struct
OffsettedTile1DPartitioner
{
/**
* @brief The function subtracts the block's start (offset) from 1D raw-indexes.
* @param [in] block_start is `blockIdx.x - block_start`.
* @return Returns a `tuple` [Im, In] shifted index, used to shift 1d-tile index.
* @param [in] block_start Workgroup offset.
* @param [in] M Gemm's M dimension.
* @param [in] N Gemm's N dimension.
* @return Returns a `tuple` [Im, In] with shifted index.
*/
[[
nodiscard
]]
CK_TILE_DEVICE
static
constexpr
auto
GetOffsetedTileIndex
(
index_t
block_start
,
index_t
N
)
noexcept
[[
nodiscard
]]
CK_TILE_DEVICE
static
auto
GetOffsetedTileIndex
(
index_t
block_start
,
index_t
M
,
index_t
N
)
noexcept
->
const
tuple
<
index_t
,
index_t
>
{
const
auto
[
iM
,
iN
]
=
Partitioner
Fn
(
N
)
.
GetOutputTileIndex
(
blockIdx
.
x
-
block_start
);
const
auto
[
iM
,
iN
]
=
Tile
Partitioner
{
M
,
N
}
.
GetOutputTileIndex
(
blockIdx
.
x
-
block_start
);
return
make_tuple
(
iM
,
iN
);
}
};
/**
* @brief Class mapping 1D block index into 2D output tile space.
*
* @note It groups spatially workgroups in order to better utilize caches.
* It is using grouped Rows of column-vectors WGP pattern. It's optimized
* for gfx94x-like multiple-die chip.
*
* @tparam GroupNum - The number of big groups.
* @tparam M01 - The number of groups in M dim within spatially local WGPs,
*
*/
template
<
typename
BlockGemmShapeType
,
index_t
GroupNum
,
index_t
M01
>
struct
GemmSpatiallyLocalTilePartitioner
{
using
BlockGemmShape
=
remove_cvref_t
<
BlockGemmShapeType
>
;
static
constexpr
index_t
MPerBlock
=
BlockGemmShape
::
kM
;
static
constexpr
index_t
NPerBlock
=
BlockGemmShape
::
kN
;
static
constexpr
index_t
KPerBlock
=
BlockGemmShape
::
kK
;
CK_TILE_HOST_DEVICE
GemmSpatiallyLocalTilePartitioner
()
noexcept
=
delete
;
CK_TILE_HOST_DEVICE
GemmSpatiallyLocalTilePartitioner
(
index_t
M_
,
index_t
N_
)
noexcept
:
M
(
M_
),
N
(
N_
)
{
}
/**
* @brief Calculates GEMM kernel grid size.
*
* @param M GEMM's M dimension.
* @param N GEMM's N dimension.
* @return index_t A total number of workgroups.
*/
CK_TILE_HOST
static
auto
GridSize
(
index_t
M
,
index_t
N
)
noexcept
(
noexcept
(
MPerBlock
!=
0
&&
NPerBlock
!=
0
))
->
index_t
{
const
index_t
GridDimX
=
integer_divide_ceil
(
M
,
MPerBlock
);
const
index_t
GridDimY
=
integer_divide_ceil
(
N
,
NPerBlock
);
return
GridDimX
*
GridDimY
;
}
/**
* @brief Calculate number of loop iterations over GEMM's K dimension.
*
* @param K GEMM's K dimension.
* @return index_t The number of loop iterations over K dimension.
*/
CK_TILE_HOST_DEVICE
static
auto
GetLoopNum
(
index_t
K
)
noexcept
->
index_t
{
return
integer_divide_ceil
(
K
,
KPerBlock
);
}
/**
* @brief Calculate workgroup 1D index mapping into 2D output C-tile space.
*
* @param [in] block_1d_id WGP's index.
* @return const tuple<index_t, index_t> Tuple containing 2D output C-tile index.
*/
CK_TILE_DEVICE
auto
GetOutputTileIndex
(
index_t
block_1d_id
)
noexcept
->
const
tuple
<
index_t
,
index_t
>
{
const
auto
M0
=
integer_divide_ceil
(
M
,
MPerBlock
);
const
auto
N0
=
integer_divide_ceil
(
N
,
NPerBlock
);
if
(
M0
==
1
)
{
return
make_tuple
(
0
,
block_1d_id
);
}
else
if
(
N0
==
1
)
{
return
make_tuple
(
block_1d_id
,
0
);
}
// block_1d_id = block_1d_id % (M0 * N0); // swallow batch index
else
{
const
auto
group_size
=
integer_divide_ceil
(
M0
*
N0
,
GroupNum
);
const
auto
big_group_num
=
GroupNum
-
(
group_size
*
GroupNum
-
M0
*
N0
);
const
auto
group_id_y
=
block_1d_id
/
GroupNum
;
const
auto
group_id_x
=
block_1d_id
-
group_id_y
*
GroupNum
;
const
auto
remap_block_1d_id
=
group_id_x
<=
big_group_num
?
group_id_x
*
group_size
+
group_id_y
:
group_id_x
*
group_size
+
big_group_num
-
group_id_x
+
group_id_y
;
const
index_t
idx_M0
=
remap_block_1d_id
/
N0
;
const
index_t
idx_N0
=
remap_block_1d_id
-
idx_M0
*
N0
;
const
index_t
M0_tmp
=
M0
/
M01
;
const
index_t
M0_mod_M01
=
M0
-
M0_tmp
*
M01
;
const
auto
M01_adapt
=
(
idx_M0
<
M0
-
M0_mod_M01
)
?
M01
:
M0_mod_M01
;
const
index_t
idx_M00
=
idx_M0
/
M01
;
const
index_t
idx_M01
=
idx_M0
-
idx_M00
*
M01
;
const
index_t
idx_N0_M01_local
=
idx_N0
+
idx_M01
*
N0
;
/**
* idxN0
*
* |< mtx N >|
*
* NPerBlock NPerBlock NPerBlock NPerBlock
* N_0 N_1 N_2 N_3
* - |-----------|-----------|-----------|-----|-----|-
* ^ | - - 0 |/----> 2 | | | |
* | | | / | | | | | M_0 MPerBlock
* | M | /| | | | | |
* |-0---|---/-|-----|-----|-----------|-----|-----|-
* | 1 | / | | | blockid | | |
* idxM0 | | | / | V | 5 | | | M_1 MPerBlock
* | - V 1 | - 3 | | | |
* |-----------|-----------|-----------|-----|-----|-
* mtx M | | | | | |
* | | | | | | M_2 MPerBlock
* | | | | | |
* |-----------|-----------|-----------|-----|-----|-
* | | | | | |
* | | | | | | M_3 MPerBlock
* | | | | | |
* |-----------|-----------|-----------|-----|-----|-
* V | | | | | |
* - |-----------|-----------|-----------|-----|-----|- M_4 MPerBlock
* | | | | | |
* |-----------|-----------|-----------|-----|-----|-
* Example:
* assume:
* M0 = 5
* N0 = 4
* block_1d_id = 5
* M01 = 2
*
* idx_N0 = 1
* idx_M0 = 1
* M01_adapt = 2
* idx_M00 = 0
* idx_M01 = 1
* idx_N0_M01_local = 5
* output {1, 2}
*/
const
index_t
N_out
=
idx_N0_M01_local
/
M01_adapt
;
const
index_t
idx_loc_mod_M01
=
idx_N0_M01_local
-
N_out
*
M01_adapt
;
return
make_tuple
(
idx_loc_mod_M01
+
idx_M00
*
M01
,
N_out
);
}
}
private:
index_t
M
;
index_t
N
;
};
}
// namespace ck_tile
include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp
View file @
f23a2e2a
...
...
@@ -64,6 +64,18 @@ struct GroupedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
}
};
/**
 * @brief Builds the textual name of this grouped GEMM kernel instance.
 *
 * The name is the '_'-joined sequence of: the literal "gemm_grouped",
 * `gemm_prec_str<ADataType, BDataType>`, and three 'x'-joined groups read from
 * GemmPipeline: (kMPerBlock, kNPerBlock, kKPerBlock),
 * (GetVectorSizeA(), GetVectorSizeB(), GetVectorSizeC()) and (kPadM, kPadN, kPadK).
 *
 * @return The composed kernel name.
 */
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
    // clang-format off
    using Pipeline = GemmPipeline;

    // Each group is joined with 'x' first, then all parts are joined with '_'.
    const auto block_tile_part  = concat('x', Pipeline::kMPerBlock, Pipeline::kNPerBlock, Pipeline::kKPerBlock);
    const auto vector_size_part = concat('x', Pipeline::GetVectorSizeA(), Pipeline::GetVectorSizeB(), Pipeline::GetVectorSizeC());
    const auto padding_part     = concat('x', Pipeline::kPadM, Pipeline::kPadN, Pipeline::kPadK);

    return concat('_', "gemm_grouped", gemm_prec_str<ADataType, BDataType>,
                  block_tile_part,
                  vector_size_part,
                  padding_part);
    // clang-format on
}
__host__
static
auto
GetWorkSpaceSize
(
const
std
::
vector
<
GroupedGemmHostArgs
>&
gemm_descs
)
->
std
::
size_t
{
...
...
@@ -77,8 +89,8 @@ struct GroupedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
index_t
grid_size
=
0
;
for
(
const
auto
&
it_desc
:
gemm_descs
)
{
const
auto
dim3
=
TilePartitioner
::
GridSize
(
it_desc
.
M
,
it_desc
.
N
);
grid_size
+=
dim3
.
x
*
dim3
.
y
*
1
;
const
auto
local_grid_size
=
TilePartitioner
::
GridSize
(
it_desc
.
M
,
it_desc
.
N
);
grid_size
+=
local_grid_size
*
it_desc
.
k_batch
;
}
return
dim3
(
grid_size
,
1
,
1
);
}
...
...
@@ -106,8 +118,7 @@ struct GroupedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
const
index_t
stride_b
=
gemm_descs
[
i
].
stride_B
;
const
index_t
stride_c
=
gemm_descs
[
i
].
stride_C
;
const
auto
dim3
=
TilePartitioner
::
GridSize
(
M
,
N
);
const
index_t
grid_size_grp
=
dim3
.
x
;
const
index_t
grid_size_grp
=
TilePartitioner
::
GridSize
(
M
,
N
)
*
gemm_descs
[
i
].
k_batch
;
const
index_t
block_start
=
grid_size
;
const
index_t
block_end
=
grid_size
+
grid_size_grp
;
...
...
@@ -138,8 +149,8 @@ struct GroupedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
CK_TILE_DEVICE
void
Run
(
const
GemmTransKernelArg
&
kargs
)
const
{
const
auto
[
iM
,
iN
]
=
OffsetTile1DPartitioner
::
GetOffsetedTileIndex
(
kargs
.
block_start
,
kargs
.
group_karg
.
N
);
const
auto
[
iM
,
iN
]
=
OffsetTile1DPartitioner
::
GetOffsetedTileIndex
(
kargs
.
block_start
,
kargs
.
group_karg
.
M
,
kargs
.
group_karg
.
N
);
const
index_t
i_m
=
__builtin_amdgcn_readfirstlane
(
iM
*
TilePartitioner
::
MPerBlock
);
const
index_t
i_n
=
__builtin_amdgcn_readfirstlane
(
iN
*
TilePartitioner
::
NPerBlock
);
...
...
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp
View file @
f23a2e2a
...
...
@@ -21,6 +21,8 @@ struct GemmPipelineAgBgCrImplBase
static
constexpr
index_t
NPerBlock
=
BlockGemmShape
::
kN
;
static
constexpr
index_t
KPerBlock
=
BlockGemmShape
::
kK
;
/**
 * @brief Exposes the problem's TransposeC setting.
 *
 * @return The value of Problem::TransposeC.
 */
CK_TILE_HOST_DEVICE static constexpr auto TransposeC()
{
    const auto transpose_c = Problem::TransposeC;
    return transpose_c;
}
template
<
typename
DstBlockTile
,
typename
SrcTileWindow
,
typename
DramTileWindowStep
>
CK_TILE_DEVICE
void
GlobalPrefetch
(
DstBlockTile
&
dst_block_tile
,
SrcTileWindow
&
dram_tile_window
,
...
...
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp
View file @
f23a2e2a
...
...
@@ -3,10 +3,14 @@
#pragma once
#include <string>
#include <sstream>
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp"
#include "ck_tile/host/concat.hpp"
namespace
ck_tile
{
...
...
@@ -20,6 +24,8 @@ struct BaseGemmPipelineAgBgCrCompV3
static
constexpr
index_t
PrefillStages
=
1
;
static
constexpr
index_t
GlobalBufferNum
=
1
;
/**
 * @brief Exposes the problem's TransposeC setting.
 *
 * @return The value of Problem::TransposeC.
 */
CK_TILE_HOST_DEVICE static constexpr auto TransposeC()
{
    const auto transpose_c = Problem::TransposeC;
    return transpose_c;
}
CK_TILE_HOST
static
constexpr
bool
BlockHasHotloop
(
index_t
num_loop
)
{
return
num_loop
>
PrefetchStages
;
...
...
@@ -62,9 +68,9 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
static
constexpr
index_t
NPerBlock
=
BlockGemmShape
::
kN
;
static
constexpr
index_t
KPerBlock
=
BlockGemmShape
::
kK
;
static
constexpr
index_t
VectorSizeA
=
Policy
::
template
GetVectorSizeA
<
Problem
>();
static
constexpr
index_t
VectorSizeB
=
Policy
::
template
GetVectorSizeB
<
Problem
>();
static
constexpr
index_t
VectorSizeC
=
Policy
::
template
GetVectorSizeC
<
Problem
>();
/**
 * @brief Vector size used for the A tensor.
 *
 * @return The value of Policy::GetVectorSizeA<Problem>().
 */
static constexpr index_t GetVectorSizeA()
{
    const index_t vector_size_a = Policy::template GetVectorSizeA<Problem>();
    return vector_size_a;
}
/**
 * @brief Vector size used for the B tensor.
 *
 * @return The value of Policy::GetVectorSizeB<Problem>().
 */
static constexpr index_t GetVectorSizeB()
{
    const index_t vector_size_b = Policy::template GetVectorSizeB<Problem>();
    return vector_size_b;
}
/**
 * @brief Vector size used for the C tensor.
 *
 * @return The value of Policy::GetVectorSizeC<Problem>().
 */
static constexpr index_t GetVectorSizeC()
{
    const index_t vector_size_c = Policy::template GetVectorSizeC<Problem>();
    return vector_size_c;
}
static
constexpr
bool
kPadM
=
Problem
::
kPadM
;
static
constexpr
bool
kPadN
=
Problem
::
kPadN
;
...
...
@@ -76,14 +82,68 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
using
Base
::
PrefetchStages
;
/**
 * @brief Builds the textual name of this pipeline instance.
 *
 * The name is the '_'-joined sequence of: the literal "pipeline_AgBgCrCompV3",
 * BlockSize, and two 'x'-joined groups:
 * (GetVectorSizeA(), GetVectorSizeB(), GetVectorSizeC()) and (kPadM, kPadN, kPadK).
 *
 * @return The composed pipeline name.
 */
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
    // clang-format off
    // Each group is joined with 'x' first, then all parts are joined with '_'.
    const auto vector_size_part = concat('x', GetVectorSizeA(), GetVectorSizeB(), GetVectorSizeC());
    const auto padding_part     = concat('x', kPadM, kPadN, kPadK);

    return concat('_', "pipeline_AgBgCrCompV3", BlockSize,
                  vector_size_part,
                  padding_part);
    // clang-format on
}
/**
 * @brief Shared-memory size required by this pipeline, as reported by the policy.
 *
 * @return The value of Policy::GetSmemSize<Problem>().
 */
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
    const index_t smem_size = Policy::template GetSmemSize<Problem>();
    return smem_size;
}
CK_TILE_HOST
_DEVICE
static
constexpr
auto
IsTransposeC
()
CK_TILE_HOST
static
std
::
string
Print
()
{
return
Policy
::
template
IsTransposeC
<
Problem
>();
constexpr
index_t
MPerXDL
=
BlockGemm
::
WarpGemm
::
kM
;
constexpr
index_t
NPerXDL
=
BlockGemm
::
WarpGemm
::
kN
;
constexpr
index_t
KPerXDL
=
BlockGemm
::
WarpGemm
::
WarpGemmAttribute
::
Impl
::
kK
;
constexpr
index_t
WaveSize
=
64
;
constexpr
index_t
WaveNumM
=
BlockGemmShape
::
BlockWarps
::
at
(
I0
{});
constexpr
index_t
WaveNumN
=
BlockGemmShape
::
BlockWarps
::
at
(
I1
{});
// Below should be equal to AK1|BK1
constexpr
index_t
A_LDS_Read_Width
=
Policy
::
template
GetSmemPackA
<
Problem
>();
constexpr
index_t
B_LDS_Read_Width
=
Policy
::
template
GetSmemPackB
<
Problem
>();
constexpr
index_t
A_LDS_Write_Width
=
Policy
::
template
GetSmemPackA
<
Problem
>();
constexpr
index_t
B_LDS_Write_Width
=
Policy
::
template
GetSmemPackB
<
Problem
>();
constexpr
index_t
A_Buffer_Load_Inst_Num
=
MPerBlock
*
KPerBlock
/
(
BlockSize
*
GetVectorSizeA
());
constexpr
index_t
B_Buffer_Load_Inst_Num
=
NPerBlock
*
KPerBlock
/
(
BlockSize
*
GetVectorSizeB
());
constexpr
index_t
A_LDS_Write_Inst_Num
=
MPerBlock
*
KPerBlock
/
(
BlockSize
*
A_LDS_Write_Width
);
constexpr
index_t
B_LDS_Write_Inst_Num
=
NPerBlock
*
KPerBlock
/
(
BlockSize
*
B_LDS_Write_Width
);
constexpr
index_t
A_LDS_Read_Inst_Num
=
WaveNumN
*
MPerBlock
*
KPerBlock
/
(
BlockSize
*
A_LDS_Read_Width
);
constexpr
index_t
B_LDS_Read_Inst_Num
=
WaveNumM
*
MPerBlock
*
KPerBlock
/
(
BlockSize
*
B_LDS_Read_Width
);
constexpr
index_t
C_MFMA_Inst_Num
=
MPerBlock
*
NPerBlock
*
KPerBlock
/
(
BlockSize
/
WaveSize
)
/
(
MPerXDL
*
NPerXDL
*
KPerXDL
);
auto
str
=
std
::
stringstream
{};
str
<<
"A/B vector size: "
<<
GetVectorSizeA
()
<<
", "
<<
GetVectorSizeB
()
<<
"
\n
"
<<
"A/B LDS read/write width: "
<<
A_LDS_Read_Width
<<
", "
<<
B_LDS_Read_Width
<<
"
\n
"
<<
"A/B buffer load inst: "
<<
A_Buffer_Load_Inst_Num
<<
", "
<<
B_Buffer_Load_Inst_Num
<<
"
\n
"
<<
"A/B LDS write inst: "
<<
A_LDS_Write_Inst_Num
<<
", "
<<
B_LDS_Write_Inst_Num
<<
"
\n
"
<<
"A/B LDS read inst: "
<<
A_LDS_Read_Inst_Num
<<
", "
<<
B_LDS_Read_Inst_Num
<<
"
\n
"
<<
"C MFMA inst: "
<<
C_MFMA_Inst_Num
<<
"
\n
"
<<
"KPack: "
<<
BlockGemm
::
Traits
::
KPack
<<
"
\n
"
<<
"PrefetchStages: "
<<
PrefetchStages
<<
"
\n
"
;
return
str
.
str
();
}
template
<
GemmPipelineScheduler
Scheduler
>
...
...
@@ -98,29 +158,35 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
CK_TILE_DEVICE
static
constexpr
auto
HotLoopScheduler
()
{
constexpr
index_t
MPerXDL
=
BlockGemm
Shape
::
Warp
Tile
::
at
(
I0
{})
;
constexpr
index_t
NPerXDL
=
BlockGemm
Shape
::
Warp
Tile
::
at
(
I1
{})
;
constexpr
index_t
KPerXDL
=
BlockGemm
Shape
::
WarpTile
::
at
(
I2
{})
;
constexpr
index_t
MPerXDL
=
BlockGemm
::
Warp
Gemm
::
kM
;
constexpr
index_t
NPerXDL
=
BlockGemm
::
Warp
Gemm
::
kN
;
constexpr
index_t
KPerXDL
=
BlockGemm
::
WarpGemm
::
WarpGemmAttribute
::
Impl
::
kK
;
constexpr
index_t
WaveSize
=
64
;
constexpr
index_t
WaveNumM
=
BlockGemmShape
::
BlockWarps
::
at
(
I0
{});
constexpr
index_t
WaveNumN
=
BlockGemmShape
::
BlockWarps
::
at
(
I1
{});
constexpr
index_t
A_LDS_Read_Width
=
KPerXDL
;
constexpr
index_t
B_LDS_Read_Width
=
KPerXDL
;
// Below should be equal to AK1|BK1
constexpr
index_t
A_LDS_Read_Width
=
Policy
::
template
GetSmemPackA
<
Problem
>();
constexpr
index_t
B_LDS_Read_Width
=
Policy
::
template
GetSmemPackB
<
Problem
>();
constexpr
index_t
A_LDS_Write_Width
=
Policy
::
template
GetSmemPackA
<
Problem
>();
constexpr
index_t
B_LDS_Write_Width
=
Policy
::
template
GetSmemPackB
<
Problem
>();
constexpr
index_t
A_Buffer_Load_Inst_Num
=
MPerBlock
*
KPerBlock
/
(
BlockSize
*
VectorSizeA
);
MPerBlock
*
KPerBlock
/
(
BlockSize
*
Get
VectorSizeA
()
);
constexpr
index_t
B_Buffer_Load_Inst_Num
=
NPerBlock
*
KPerBlock
/
(
BlockSize
*
VectorSizeB
);
NPerBlock
*
KPerBlock
/
(
BlockSize
*
Get
VectorSizeB
()
);
constexpr
index_t
A_LDS_Write_Inst_Num
=
MPerBlock
*
KPerBlock
/
(
BlockSize
*
KPerXDL
);
constexpr
index_t
B_LDS_Write_Inst_Num
=
NPerBlock
*
KPerBlock
/
(
BlockSize
*
KPerXDL
);
constexpr
index_t
A_LDS_Write_Inst_Num
=
MPerBlock
*
KPerBlock
/
(
BlockSize
*
A_LDS_Write_Width
);
constexpr
index_t
B_LDS_Write_Inst_Num
=
NPerBlock
*
KPerBlock
/
(
BlockSize
*
B_LDS_Write_Width
);
constexpr
index_t
A_LDS_Read_Inst_Num
=
WaveNumN
*
MPerBlock
*
KPerBlock
/
(
BlockSize
*
KPerXDL
);
WaveNumN
*
MPerBlock
*
KPerBlock
/
(
BlockSize
*
A_LDS_Read_Width
);
constexpr
index_t
B_LDS_Read_Inst_Num
=
WaveNumM
*
MPerBlock
*
KPerBlock
/
(
BlockSize
*
KPerXDL
);
WaveNumM
*
MPerBlock
*
KPerBlock
/
(
BlockSize
*
B_LDS_Read_Width
);
constexpr
index_t
C_MFMA_Inst_Num
=
MPerBlock
*
NPerBlock
*
KPerBlock
/
(
BlockSize
/
WaveSize
)
/
...
...
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp
View file @
f23a2e2a
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024
-2025
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -7,6 +7,7 @@
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp"
#include "ck_tile/host/concat.hpp"
namespace
ck_tile
{
...
...
@@ -20,6 +21,8 @@ struct BaseGemmPipelineAgBgCrMem
using
BDataType
=
remove_cvref_t
<
typename
Problem
::
BDataType
>
;
using
BlockGemmShape
=
remove_cvref_t
<
typename
Problem
::
BlockGemmShape
>
;
// Compile-time accessor that forwards the TransposeC setting declared by Problem.
CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; }
static
constexpr
index_t
BlockSize
=
Problem
::
kBlockSize
;
static
constexpr
index_t
MPerBlock
=
BlockGemmShape
::
kM
;
static
constexpr
index_t
NPerBlock
=
BlockGemmShape
::
kN
;
...
...
@@ -88,7 +91,7 @@ struct BaseGemmPipelineAgBgCrMem
// LocalPreFillStages: 1
// LocalPreFetchStages: 0
// LocalSharedMemoryBuffer: 1
template
<
typename
Problem
,
typename
Policy
=
GemmPipelineA
GmemBGmemCRegV1Default
Policy
>
template
<
typename
Problem
,
typename
Policy
=
Universal
GemmPipelineA
gBgCr
Policy
>
struct
GemmPipelineAgBgCrMem
:
public
BaseGemmPipelineAgBgCrMem
<
Problem
>
{
using
Base
=
BaseGemmPipelineAgBgCrMem
<
Problem
>
;
...
...
@@ -113,9 +116,9 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
static
constexpr
index_t
NPerBlock
=
BlockGemmShape
::
kN
;
static
constexpr
index_t
KPerBlock
=
BlockGemmShape
::
kK
;
static
constexpr
index_t
VectorSizeA
=
Policy
::
template
GetVectorSizeA
<
Problem
>();
static
constexpr
index_t
VectorSizeB
=
Policy
::
template
GetVectorSizeB
<
Problem
>();
static
constexpr
index_t
VectorSizeC
=
Policy
::
template
GetVectorSizeC
<
Problem
>();
// Vector-size accessors for the A, B and C operands. Each one simply forwards to the
// Policy query of the same name, instantiated for this pipeline's Problem.
static constexpr index_t GetVectorSizeA() { return Policy::template GetVectorSizeA<Problem>(); }
static constexpr index_t GetVectorSizeB() { return Policy::template GetVectorSizeB<Problem>(); }
static constexpr index_t GetVectorSizeC() { return Policy::template GetVectorSizeC<Problem>(); }
static
constexpr
bool
kPadM
=
Problem
::
kPadM
;
static
constexpr
bool
kPadN
=
Problem
::
kPadN
;
...
...
@@ -126,6 +129,16 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
static
constexpr
auto
TailNum
=
Problem
::
TailNum
;
static
constexpr
auto
Scheduler
=
Problem
::
Scheduler
;
// Builds a host-side identifier string for this pipeline instance.
//
// The result joins, with '_': the fixed tag "pipeline_AgBgCrMem", the block tile
// extents MPerBlock/NPerBlock/KPerBlock joined with 'x', the A/B/C vector sizes
// joined with 'x', and the M/N/K padding flags joined with 'x'.
//
// Two fixes relative to the previous text:
//  * the tag read "pipeline_AgBgCrMe", a truncation of the struct's name;
//  * the return type was `const std::string`; a top-level const on a by-value
//    return prevents callers from moving the result and forces a copy.
// NOTE(review): if any tooling matches on the old truncated tag, update it as well.
[[nodiscard]] CK_TILE_HOST static std::string GetName()
{
    // clang-format off
    return concat('_', "pipeline_AgBgCrMem",
                  concat('x', MPerBlock, NPerBlock, KPerBlock),
                  concat('x', GetVectorSizeA(), GetVectorSizeB(), GetVectorSizeC()),
                  concat('x', kPadM, kPadN, kPadK));
    // clang-format on
}
using
Base
::
PrefetchStages
;
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSize
()
...
...
@@ -133,11 +146,6 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
return
Policy
::
template
GetSmemSize
<
Problem
>();
}
CK_TILE_HOST_DEVICE
static
constexpr
auto
IsTransposeC
()
{
return
Policy
::
template
IsTransposeC
<
Problem
>();
}
template
<
GemmPipelineScheduler
Scheduler
>
struct
PipelineImpl
:
public
PipelineImplBase
{
...
...
@@ -168,11 +176,22 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
"A/B Dram block window should have the same data type as appropriate "
"([A|B]DataType) defined in Problem definition!"
);
static_assert
(
MPerBlock
==
ADramBlockWindowTmp
{}.
get_window_lengths
()[
I0
{}]
&&
NPerBlock
==
BDramBlockWindowTmp
{}.
get_window_lengths
()[
I0
{}]
&&
KPerBlock
==
ADramBlockWindowTmp
{}.
get_window_lengths
()[
I1
{}],
"A/B block window appropriate sizes must be equal to MPerBlock/NPerblock"
" or KPerBlock!"
);
constexpr
bool
is_a_col_major
=
std
::
is_same_v
<
ALayout
,
tensor_layout
::
gemm
::
ColumnMajor
>
;
constexpr
bool
is_b_row_major
=
std
::
is_same_v
<
BLayout
,
tensor_layout
::
gemm
::
RowMajor
>
;
static_assert
(
is_a_col_major
?
(
KPerBlock
==
ADramBlockWindowTmp
{}.
get_window_lengths
()[
I0
{}]
&&
MPerBlock
==
ADramBlockWindowTmp
{}.
get_window_lengths
()[
I1
{}])
:
(
MPerBlock
==
ADramBlockWindowTmp
{}.
get_window_lengths
()[
I0
{}]
&&
KPerBlock
==
ADramBlockWindowTmp
{}.
get_window_lengths
()[
I1
{}]),
"A block window has incorrect lengths for defined ALayout!"
);
static_assert
(
is_b_row_major
?
(
KPerBlock
==
BDramBlockWindowTmp
{}.
get_window_lengths
()[
I0
{}]
&&
NPerBlock
==
BDramBlockWindowTmp
{}.
get_window_lengths
()[
I1
{}])
:
(
NPerBlock
==
BDramBlockWindowTmp
{}.
get_window_lengths
()[
I0
{}]
&&
KPerBlock
==
BDramBlockWindowTmp
{}.
get_window_lengths
()[
I1
{}]),
"B block window has incorrect lengths for defined BLayout!"
);
// ------------------------------------------------------------------------------------
// Definitions of all needed tiles
...
...
@@ -216,25 +235,59 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
tuple_array
<
ABlockTile
,
PrefetchStages
>
a_block_tiles
;
tuple_array
<
BBlockTile
,
PrefetchStages
>
b_block_tiles
;
using
ADramTileWindowStep
=
typename
ADramBlockWindowTmp
::
BottomTensorIndex
;
using
BDramTileWindowStep
=
typename
BDramBlockWindowTmp
::
BottomTensorIndex
;
constexpr
ADramTileWindowStep
a_dram_tile_window_step
=
is_a_col_major
?
make_array
(
KPerBlock
,
0
)
:
make_array
(
0
,
KPerBlock
);
constexpr
BDramTileWindowStep
b_dram_tile_window_step
=
is_b_row_major
?
make_array
(
KPerBlock
,
0
)
:
make_array
(
0
,
KPerBlock
);
// -----------------------------------------------------------------------------------------
// Gemm pipeline start
// prefetch
// global read 0
Base
::
GlobalPrefetch
(
a_block_tiles
.
get
(
I0
{}),
a_copy_dram_window
);
Base
::
GlobalPrefetch
(
b_block_tiles
.
get
(
I0
{}),
b_copy_dram_window
);
Base
::
GlobalPrefetch
(
a_block_tiles
.
get
(
I0
{}),
a_copy_dram_window
,
a_dram_tile_window_step
);
Base
::
GlobalPrefetch
(
b_block_tiles
.
get
(
I0
{}),
b_copy_dram_window
,
b_dram_tile_window_step
);
// initialize C
tile_elementwise_inout
([](
auto
&
c
)
{
c
=
0
;
},
c_block_tile
);
// LDS write 0
Base
::
LocalPrefill
(
a_copy_lds_window
,
a_block_tiles
.
get
(
I0
{}),
a_element_func
);
Base
::
LocalPrefill
(
b_copy_lds_window
,
b_block_tiles
.
get
(
I0
{}),
b_element_func
);
if
constexpr
(
is_a_col_major
)
{
auto
a_shuffle_tmp
=
make_static_distributed_tensor
<
ADataType
>
(
Policy
::
template
MakeShuffledARegTileDistribution
<
Problem
>());
transpose_tile2d
(
a_shuffle_tmp
,
a_block_tiles
.
get
(
I0
{}));
Base
::
LocalPrefill
(
a_copy_lds_window
,
a_shuffle_tmp
,
a_element_func
);
}
else
{
Base
::
LocalPrefill
(
a_copy_lds_window
,
a_block_tiles
.
get
(
I0
{}),
a_element_func
);
}
if
constexpr
(
is_b_row_major
)
{
auto
b_shuffle_tmp
=
make_static_distributed_tensor
<
BDataType
>
(
Policy
::
template
MakeShuffledBRegTileDistribution
<
Problem
>());
transpose_tile2d
(
b_shuffle_tmp
,
b_block_tiles
.
get
(
I0
{}));
Base
::
LocalPrefill
(
b_copy_lds_window
,
b_shuffle_tmp
,
b_element_func
);
}
else
{
Base
::
LocalPrefill
(
b_copy_lds_window
,
b_block_tiles
.
get
(
I0
{}),
b_element_func
);
}
// Global prefetch [1, PrefetchStages]
static_for
<
1
,
PrefetchStages
,
1
>
{}([
&
](
auto
prefetch_idx
)
{
Base
::
GlobalPrefetch
(
a_block_tiles
.
get
(
number
<
prefetch_idx
>
{}),
a_copy_dram_window
);
Base
::
GlobalPrefetch
(
b_block_tiles
.
get
(
number
<
prefetch_idx
>
{}),
b_copy_dram_window
);
Base
::
GlobalPrefetch
(
a_block_tiles
.
get
(
number
<
prefetch_idx
>
{}),
a_copy_dram_window
,
a_dram_tile_window_step
);
Base
::
GlobalPrefetch
(
b_block_tiles
.
get
(
number
<
prefetch_idx
>
{}),
b_copy_dram_window
,
b_dram_tile_window_step
);
});
// main body
...
...
@@ -250,19 +303,45 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
block_sync_lds
();
Base
::
LocalPrefill
(
a_copy_lds_window
,
a_block_tiles
.
get
(
number
<
(
prefetch_idx
+
1
)
%
PrefetchStages
>
{}),
a_element_func
);
Base
::
LocalPrefill
(
b_copy_lds_window
,
b_block_tiles
.
get
(
number
<
(
prefetch_idx
+
1
)
%
PrefetchStages
>
{}),
b_element_func
);
if
constexpr
(
is_a_col_major
)
{
auto
a_shuffle_tmp
=
make_static_distributed_tensor
<
ADataType
>
(
Policy
::
template
MakeShuffledARegTileDistribution
<
Problem
>());
transpose_tile2d
(
a_shuffle_tmp
,
a_block_tiles
.
get
(
number
<
(
prefetch_idx
+
1
)
%
PrefetchStages
>
{}));
Base
::
LocalPrefill
(
a_copy_lds_window
,
a_shuffle_tmp
,
a_element_func
);
}
else
{
Base
::
LocalPrefill
(
a_copy_lds_window
,
a_block_tiles
.
get
(
number
<
(
prefetch_idx
+
1
)
%
PrefetchStages
>
{}),
a_element_func
);
}
if
constexpr
(
is_b_row_major
)
{
auto
b_shuffle_tmp
=
make_static_distributed_tensor
<
BDataType
>
(
Policy
::
template
MakeShuffledBRegTileDistribution
<
Problem
>());
transpose_tile2d
(
b_shuffle_tmp
,
b_block_tiles
.
get
(
number
<
(
prefetch_idx
+
1
)
%
PrefetchStages
>
{}));
Base
::
LocalPrefill
(
b_copy_lds_window
,
b_shuffle_tmp
,
b_element_func
);
}
else
{
Base
::
LocalPrefill
(
b_copy_lds_window
,
b_block_tiles
.
get
(
number
<
(
prefetch_idx
+
1
)
%
PrefetchStages
>
{}),
b_element_func
);
}
Base
::
GlobalPrefetch
(
a_block_tiles
.
get
(
number
<
prefetch_idx
>
{}),
a_copy_dram_window
);
a_copy_dram_window
,
a_dram_tile_window_step
);
Base
::
GlobalPrefetch
(
b_block_tiles
.
get
(
number
<
prefetch_idx
>
{}),
b_copy_dram_window
);
b_copy_dram_window
,
b_dram_tile_window_step
);
});
i
+=
PrefetchStages
;
...
...
@@ -278,12 +357,32 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
block_sync_lds
();
Base
::
LocalPrefill
(
a_copy_lds_window
,
a_block_tiles
.
get
(
number
<
prefetch_idx
>
{}),
a_element_func
);
Base
::
LocalPrefill
(
b_copy_lds_window
,
b_block_tiles
.
get
(
number
<
prefetch_idx
>
{}),
b_element_func
);
if
constexpr
(
is_a_col_major
)
{
auto
a_shuffle_tmp
=
make_static_distributed_tensor
<
ADataType
>
(
Policy
::
template
MakeShuffledARegTileDistribution
<
Problem
>());
transpose_tile2d
(
a_shuffle_tmp
,
a_block_tiles
.
get
(
number
<
prefetch_idx
>
{}));
Base
::
LocalPrefill
(
a_copy_lds_window
,
a_shuffle_tmp
,
a_element_func
);
}
else
{
Base
::
LocalPrefill
(
a_copy_lds_window
,
a_block_tiles
.
get
(
number
<
prefetch_idx
>
{}),
a_element_func
);
}
if
constexpr
(
is_b_row_major
)
{
auto
b_shuffle_tmp
=
make_static_distributed_tensor
<
BDataType
>
(
Policy
::
template
MakeShuffledBRegTileDistribution
<
Problem
>());
transpose_tile2d
(
b_shuffle_tmp
,
b_block_tiles
.
get
(
number
<
prefetch_idx
>
{}));
Base
::
LocalPrefill
(
b_copy_lds_window
,
b_shuffle_tmp
,
b_element_func
);
}
else
{
Base
::
LocalPrefill
(
b_copy_lds_window
,
b_block_tiles
.
get
(
number
<
prefetch_idx
>
{}),
b_element_func
);
}
});
block_sync_lds
();
...
...
@@ -355,11 +454,22 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
"A/B Dram block window should have the same data type as appropriate "
"([A|B]DataType) defined in Problem definition!"
);
static_assert
(
MPerBlock
==
ADramBlockWindowTmp
{}.
get_window_lengths
()[
I0
{}]
&&
NPerBlock
==
BDramBlockWindowTmp
{}.
get_window_lengths
()[
I0
{}]
&&
KPerBlock
==
ADramBlockWindowTmp
{}.
get_window_lengths
()[
I1
{}],
"A/B block window appropriate sizes must be equal to MPerBlock/NPerblock"
" or KPerBlock!"
);
constexpr
bool
is_a_col_major
=
std
::
is_same_v
<
ALayout
,
tensor_layout
::
gemm
::
ColumnMajor
>
;
constexpr
bool
is_b_row_major
=
std
::
is_same_v
<
BLayout
,
tensor_layout
::
gemm
::
RowMajor
>
;
static_assert
(
is_a_col_major
?
(
KPerBlock
==
ADramBlockWindowTmp
{}.
get_window_lengths
()[
I0
{}]
&&
MPerBlock
==
ADramBlockWindowTmp
{}.
get_window_lengths
()[
I1
{}])
:
(
MPerBlock
==
ADramBlockWindowTmp
{}.
get_window_lengths
()[
I0
{}]
&&
KPerBlock
==
ADramBlockWindowTmp
{}.
get_window_lengths
()[
I1
{}]),
"A block window has incorrect lengths for defined ALayout!"
);
static_assert
(
is_b_row_major
?
(
KPerBlock
==
BDramBlockWindowTmp
{}.
get_window_lengths
()[
I0
{}]
&&
NPerBlock
==
BDramBlockWindowTmp
{}.
get_window_lengths
()[
I1
{}])
:
(
NPerBlock
==
BDramBlockWindowTmp
{}.
get_window_lengths
()[
I0
{}]
&&
KPerBlock
==
BDramBlockWindowTmp
{}.
get_window_lengths
()[
I1
{}]),
"B block window has incorrect lengths for defined BLayout!"
);
// ------------------------------------------------------------------------------------
// Definitions of all needed tiles
...
...
@@ -403,25 +513,58 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
tuple_array
<
ABlockTile
,
PrefetchStages
>
a_block_tiles
;
tuple_array
<
BBlockTile
,
PrefetchStages
>
b_block_tiles
;
using
ADramTileWindowStep
=
typename
ADramBlockWindowTmp
::
BottomTensorIndex
;
using
BDramTileWindowStep
=
typename
BDramBlockWindowTmp
::
BottomTensorIndex
;
constexpr
ADramTileWindowStep
a_dram_tile_window_step
=
is_a_col_major
?
make_array
(
KPerBlock
,
0
)
:
make_array
(
0
,
KPerBlock
);
constexpr
BDramTileWindowStep
b_dram_tile_window_step
=
is_b_row_major
?
make_array
(
KPerBlock
,
0
)
:
make_array
(
0
,
KPerBlock
);
// -----------------------------------------------------------------------------------------
// Gemm pipeline start
// prefetch
// global read 0
Base
::
GlobalPrefetch
(
a_block_tiles
.
get
(
I0
{}),
a_copy_dram_window
);
Base
::
GlobalPrefetch
(
b_block_tiles
.
get
(
I0
{}),
b_copy_dram_window
);
Base
::
GlobalPrefetch
(
a_block_tiles
.
get
(
I0
{}),
a_copy_dram_window
,
a_dram_tile_window_step
);
Base
::
GlobalPrefetch
(
b_block_tiles
.
get
(
I0
{}),
b_copy_dram_window
,
b_dram_tile_window_step
);
// initialize C
tile_elementwise_inout
([](
auto
&
c
)
{
c
=
0
;
},
c_block_tile
);
// LDS write 0
Base
::
LocalPrefill
(
a_copy_lds_window
,
a_block_tiles
.
get
(
I0
{}),
a_element_func
);
Base
::
LocalPrefill
(
b_copy_lds_window
,
b_block_tiles
.
get
(
I0
{}),
b_element_func
);
if
constexpr
(
is_a_col_major
)
{
auto
a_shuffle_tmp
=
make_static_distributed_tensor
<
ADataType
>
(
Policy
::
template
MakeShuffledARegTileDistribution
<
Problem
>());
transpose_tile2d
(
a_shuffle_tmp
,
a_block_tiles
.
get
(
I0
{}));
Base
::
LocalPrefill
(
a_copy_lds_window
,
a_shuffle_tmp
,
a_element_func
);
}
else
{
Base
::
LocalPrefill
(
a_copy_lds_window
,
a_block_tiles
.
get
(
I0
{}),
a_element_func
);
}
if
constexpr
(
is_b_row_major
)
{
auto
b_shuffle_tmp
=
make_static_distributed_tensor
<
BDataType
>
(
Policy
::
template
MakeShuffledBRegTileDistribution
<
Problem
>());
transpose_tile2d
(
b_shuffle_tmp
,
b_block_tiles
.
get
(
I0
{}));
Base
::
LocalPrefill
(
b_copy_lds_window
,
b_shuffle_tmp
,
b_element_func
);
}
else
{
Base
::
LocalPrefill
(
b_copy_lds_window
,
b_block_tiles
.
get
(
I0
{}),
b_element_func
);
}
// Global prefetch [1, PrefetchStages]
static_for
<
1
,
PrefetchStages
,
1
>
{}([
&
](
auto
prefetch_idx
)
{
Base
::
GlobalPrefetch
(
a_block_tiles
.
get
(
number
<
prefetch_idx
>
{}),
a_copy_dram_window
);
Base
::
GlobalPrefetch
(
b_block_tiles
.
get
(
number
<
prefetch_idx
>
{}),
b_copy_dram_window
);
Base
::
GlobalPrefetch
(
a_block_tiles
.
get
(
number
<
prefetch_idx
>
{}),
a_copy_dram_window
,
a_dram_tile_window_step
);
Base
::
GlobalPrefetch
(
b_block_tiles
.
get
(
number
<
prefetch_idx
>
{}),
b_copy_dram_window
,
b_dram_tile_window_step
);
});
// main body
...
...
@@ -435,19 +578,45 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
block_gemm
(
c_block_tile
,
a_lds_gemm_window
,
b_lds_gemm_window
);
// no second block_sync_lds because it's interwave
Base
::
LocalPrefill
(
a_copy_lds_window
,
a_block_tiles
.
get
(
number
<
(
prefetch_idx
+
1
)
%
PrefetchStages
>
{}),
a_element_func
);
Base
::
LocalPrefill
(
b_copy_lds_window
,
b_block_tiles
.
get
(
number
<
(
prefetch_idx
+
1
)
%
PrefetchStages
>
{}),
b_element_func
);
if
constexpr
(
is_a_col_major
)
{
auto
a_shuffle_tmp
=
make_static_distributed_tensor
<
ADataType
>
(
Policy
::
template
MakeShuffledARegTileDistribution
<
Problem
>());
transpose_tile2d
(
a_shuffle_tmp
,
a_block_tiles
.
get
(
number
<
(
prefetch_idx
+
1
)
%
PrefetchStages
>
{}));
Base
::
LocalPrefill
(
a_copy_lds_window
,
a_shuffle_tmp
,
a_element_func
);
}
else
{
Base
::
LocalPrefill
(
a_copy_lds_window
,
a_block_tiles
.
get
(
number
<
(
prefetch_idx
+
1
)
%
PrefetchStages
>
{}),
a_element_func
);
}
if
constexpr
(
is_b_row_major
)
{
auto
b_shuffle_tmp
=
make_static_distributed_tensor
<
BDataType
>
(
Policy
::
template
MakeShuffledBRegTileDistribution
<
Problem
>());
transpose_tile2d
(
b_shuffle_tmp
,
b_block_tiles
.
get
(
number
<
(
prefetch_idx
+
1
)
%
PrefetchStages
>
{}));
Base
::
LocalPrefill
(
b_copy_lds_window
,
b_shuffle_tmp
,
b_element_func
);
}
else
{
Base
::
LocalPrefill
(
b_copy_lds_window
,
b_block_tiles
.
get
(
number
<
(
prefetch_idx
+
1
)
%
PrefetchStages
>
{}),
b_element_func
);
}
Base
::
GlobalPrefetch
(
a_block_tiles
.
get
(
number
<
prefetch_idx
>
{}),
a_copy_dram_window
);
a_copy_dram_window
,
a_dram_tile_window_step
);
Base
::
GlobalPrefetch
(
b_block_tiles
.
get
(
number
<
prefetch_idx
>
{}),
b_copy_dram_window
);
b_copy_dram_window
,
b_dram_tile_window_step
);
});
i
+=
PrefetchStages
;
...
...
@@ -460,12 +629,32 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
block_gemm
(
c_block_tile
,
a_lds_gemm_window
,
b_lds_gemm_window
);
// no second block_sync_lds because it's interwave
Base
::
LocalPrefill
(
a_copy_lds_window
,
a_block_tiles
.
get
(
number
<
prefetch_idx
>
{}),
a_element_func
);
Base
::
LocalPrefill
(
b_copy_lds_window
,
b_block_tiles
.
get
(
number
<
prefetch_idx
>
{}),
b_element_func
);
if
constexpr
(
is_a_col_major
)
{
auto
a_shuffle_tmp
=
make_static_distributed_tensor
<
ADataType
>
(
Policy
::
template
MakeShuffledARegTileDistribution
<
Problem
>());
transpose_tile2d
(
a_shuffle_tmp
,
a_block_tiles
.
get
(
number
<
prefetch_idx
>
{}));
Base
::
LocalPrefill
(
a_copy_lds_window
,
a_shuffle_tmp
,
a_element_func
);
}
else
{
Base
::
LocalPrefill
(
a_copy_lds_window
,
a_block_tiles
.
get
(
number
<
prefetch_idx
>
{}),
a_element_func
);
}
if
constexpr
(
is_b_row_major
)
{
auto
b_shuffle_tmp
=
make_static_distributed_tensor
<
BDataType
>
(
Policy
::
template
MakeShuffledBRegTileDistribution
<
Problem
>());
transpose_tile2d
(
b_shuffle_tmp
,
b_block_tiles
.
get
(
number
<
prefetch_idx
>
{}));
Base
::
LocalPrefill
(
b_copy_lds_window
,
b_shuffle_tmp
,
b_element_func
);
}
else
{
Base
::
LocalPrefill
(
b_copy_lds_window
,
b_block_tiles
.
get
(
number
<
prefetch_idx
>
{}),
b_element_func
);
}
});
block_sync_lds
();
...
...
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp
View file @
f23a2e2a
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024
-2025
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <ostream>
#include <sstream>
#include "ck_tile/core.hpp"
...
...
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp
View file @
f23a2e2a
...
...
@@ -5,6 +5,7 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
#include "ck_tile/host/concat.hpp"
namespace
ck_tile
{
...
...
@@ -31,21 +32,33 @@ struct GemmPipelineAGmemBGmemCRegV1
static
constexpr
index_t
kNPerBlock
=
BlockGemmShape
::
kN
;
static
constexpr
index_t
kKPerBlock
=
BlockGemmShape
::
kK
;
static
constexpr
index_t
VectorSizeA
=
Problem
::
VectorSizeA
;
static
constexpr
index_t
VectorSizeB
=
Problem
::
VectorSizeB
;
static
constexpr
index_t
VectorSizeC
=
Problem
::
VectorSizeC
;
// Vector-size accessors for the A, B and C operands. Each one returns the
// corresponding compile-time constant declared by Problem.
static constexpr index_t GetVectorSizeA() { return Problem::VectorSizeA; }
static constexpr index_t GetVectorSizeB() { return Problem::VectorSizeB; }
static constexpr index_t GetVectorSizeC() { return Problem::VectorSizeC; }
static
constexpr
bool
kPadM
=
Problem
::
kPadM
;
static
constexpr
bool
kPadN
=
Problem
::
kPadN
;
static
constexpr
bool
kPadK
=
Problem
::
kPadK
;
static
constexpr
index_t
kLdsAlignmentInBytes
=
16
;
// Builds a host-side identifier string for this pipeline instance.
//
// The result joins, with '_': the fixed tag "pipeline_AGmemBGmemCRegV1", the block
// tile extents and BlockSize joined with 'x', the A/B/C vector sizes joined with
// 'x', and the M/N/K padding flags joined with 'x'.
//
// Returns a plain (non-const) std::string: a top-level const on a by-value return
// type prevents callers from moving the result and forces a copy instead.
[[nodiscard]] CK_TILE_HOST static std::string GetName()
{
    // clang-format off
    return concat('_', "pipeline_AGmemBGmemCRegV1",
                  concat('x', kMPerBlock, kNPerBlock, kKPerBlock, BlockSize),
                  concat('x', GetVectorSizeA(), GetVectorSizeB(), GetVectorSizeC()),
                  concat('x', kPadM, kPadN, kPadK));
    // clang-format on
}
// Compile-time accessor that forwards the TransposeC setting declared by Problem.
CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; }
// Returns the value of Policy::GetSmemSize instantiated for this pipeline's Problem.
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
    return Policy::template GetSmemSize<Problem>();
}
CK_TILE_HOST_DEVICE
static
constexpr
auto
IsTransposeC
()
{
return
Policy
::
IsTransposeC
();
}
template
<
typename
ADramBlockWindowTmp
,
typename
BDramBlockWindowTmp
,
typename
AElementFunction
,
...
...
@@ -75,8 +88,9 @@ struct GemmPipelineAGmemBGmemCRegV1
auto
a_lds_block
=
make_tensor_view
<
address_space_enum
::
lds
>
(
p_a_lds
,
a_lds_block_desc
);
constexpr
index_t
a_lds_block_space_size_aligned
=
integer_divide_ceil
(
sizeof
(
ADataType
)
*
a_lds_block_desc
.
get_element_space_size
(),
16
)
*
16
;
integer_divide_ceil
(
sizeof
(
ADataType
)
*
a_lds_block_desc
.
get_element_space_size
(),
kLdsAlignmentInBytes
)
*
kLdsAlignmentInBytes
;
// B tile in LDS
BDataType
*
p_b_lds
=
static_cast
<
BDataType
*>
(
...
...
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp
View file @
f23a2e2a
...
...
@@ -16,8 +16,6 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
static
constexpr
auto
I1
=
number
<
1
>
{};
static
constexpr
auto
I2
=
number
<
2
>
{};
static
constexpr
bool
TransposeC
=
true
;
// 3d + padding
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeALdsBlockDescriptor
()
...
...
@@ -383,8 +381,6 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
}
}
CK_TILE_HOST_DEVICE
static
constexpr
auto
IsTransposeC
()
{
return
TransposeC
;
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetBlockGemm
()
{
...
...
@@ -397,7 +393,7 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
WarpTile
::
at
(
I0
),
WarpTile
::
at
(
I1
),
WarpTile
::
at
(
I2
),
TransposeC
>
;
Problem
::
TransposeC
>
;
using
BlockGemmPolicy
=
BlockGemmASmemBSmemCRegV1CustomPolicy
<
typename
Problem
::
ADataType
,
typename
Problem
::
BDataType
,
typename
Problem
::
CDataType
,
...
...
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp
View file @
f23a2e2a
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp"
#include "ck_tile/host/concat.hpp"
namespace
ck_tile
{
...
...
@@ -25,6 +26,15 @@ struct GemmPipelineAGmemBGmemCRegV2
static
constexpr
index_t
kNPerBlock
=
BlockGemmShape
::
kN
;
static
constexpr
index_t
kKPerBlock
=
BlockGemmShape
::
kK
;
// Builds a host-side identifier string for this pipeline instance.
//
// The result joins, with '_': the fixed tag "pipeline_AGmemBGmemCRegV2" and the
// block tile extents plus kBlockSize joined with 'x'.
//
// Returns a plain (non-const) std::string: a top-level const on a by-value return
// type prevents callers from moving the result and forces a copy instead.
[[nodiscard]] CK_TILE_HOST static std::string GetName()
{
    // clang-format off
    return concat('_', "pipeline_AGmemBGmemCRegV2",
                  concat('x', kMPerBlock, kNPerBlock, kKPerBlock, kBlockSize));
    // clang-format on
}
// Compile-time accessor that forwards the TransposeC setting declared by Problem.
CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; }
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetStaticLdsSize
()
{
return
integer_divide_ceil
(
...
...
@@ -36,8 +46,6 @@ struct GemmPipelineAGmemBGmemCRegV2
Policy
::
template
MakeBLdsBlockDescriptor
<
Problem
>().
get_element_space_size
();
}
CK_TILE_HOST_DEVICE
static
constexpr
auto
IsTransposeC
()
{
return
Policy
::
IsTransposeC
();
}
template
<
typename
ADramBlockWindowTmp
,
typename
BDramBlockWindowTmp
,
typename
AElementFunction
,
...
...
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp
View file @
f23a2e2a
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/host/concat.hpp"
namespace
ck_tile
{
...
...
@@ -27,15 +28,27 @@ struct GemmPipelineProblemBase
using
BLayout
=
remove_cvref_t
<
typename
Traits
::
BLayout
>
;
using
CLayout
=
remove_cvref_t
<
typename
Traits
::
CLayout
>
;
static
constexpr
bool
TransposeC
=
Traits
::
TransposeC
;
static
constexpr
index_t
kBlockSize
=
BlockGemmShape
::
NumWarps
*
get_warp_size
();
static
constexpr
bool
kPadM
=
Traits
::
kPadM
;
static
constexpr
bool
kPadN
=
Traits
::
kPadN
;
static
constexpr
bool
kPadK
=
Traits
::
kPadK
;
static
constexpr
auto
Scheduler
=
GemmPipelineScheduler
::
Default
;
static
constexpr
auto
Scheduler
=
GemmPipelineScheduler
::
Default
;
static
constexpr
index_t
VectorLoadSize
=
Traits
::
_VectorSize
;
// Builds a host-side identifier string for this problem description.
//
// The result joins, with '_': the fixed tag "gemm_problem", VectorLoadSize and
// kBlockSize joined with 'x', the M/N/K padding flags joined with 'x', and the
// Scheduler value.
//
// Returns a plain (non-const) std::string: a top-level const on a by-value return
// type prevents callers from moving the result and forces a copy instead.
[[nodiscard]] CK_TILE_HOST static std::string GetName()
{
    // clang-format off
    return concat('_', "gemm_problem",
                  concat('x', VectorLoadSize, kBlockSize),
                  concat('x', kPadM, kPadN, kPadK),
                  Scheduler);
    // clang-format on
}
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetAlignmentA
()
{
if
constexpr
(
std
::
is_same_v
<
ALayout
,
ck_tile
::
tensor_layout
::
gemm
::
ColumnMajor
>
)
...
...
@@ -111,7 +124,6 @@ struct GemmPipelineProblemBase
return
kPadK
?
1
:
GetAlignmentB
();
}
}();
static
constexpr
index_t
VectorSizeC
=
[]()
{
if
constexpr
(
std
::
is_same_v
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>
)
{
...
...
include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp
View file @
f23a2e2a
...
...
@@ -185,7 +185,6 @@ struct UniversalGemmPipelineAgBgCrPolicy
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeALdsBlockDescriptor
()
{
using
ADataType
=
remove_cvref_t
<
typename
Problem
::
ADataType
>
;
constexpr
index_t
MPerBlock
=
Problem
::
BlockGemmShape
::
kM
;
...
...
@@ -519,7 +518,7 @@ struct UniversalGemmPipelineAgBgCrPolicy
using
ALayout
=
remove_cvref_t
<
typename
Problem
::
ALayout
>
;
static_assert
(
std
::
is_same_v
<
ALayout
,
ck_tile
::
tensor_layout
::
gemm
::
ColumnMajor
>
);
constexpr
index_t
BlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
MPerBlock
=
Problem
::
BlockGemmShape
::
k
N
;
constexpr
index_t
MPerBlock
=
Problem
::
BlockGemmShape
::
k
M
;
constexpr
index_t
KPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
constexpr
index_t
VecLoadSize
=
GetVectorSizeA
<
Problem
>
();
...
...
@@ -549,12 +548,6 @@ struct UniversalGemmPipelineAgBgCrPolicy
return
TileEncodingPattern
::
MakeShuffled2DStaticTileDistribution
();
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
IsTransposeC
()
{
return
Problem
::
TransposeC
;
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetBlockGemm
()
{
...
...
include/ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp
View file @
f23a2e2a
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host/concat.hpp"
namespace
ck_tile
{
...
...
@@ -19,6 +20,16 @@ struct TileGemmShape
static
constexpr
index_t
kM
=
BlockTile
::
at
(
number
<
0
>
{});
static
constexpr
index_t
kN
=
BlockTile
::
at
(
number
<
1
>
{});
static
constexpr
index_t
kK
=
BlockTile
::
at
(
number
<
2
>
{});
CK_TILE_HOST
static
std
::
string
GetName
()
{
// clang-format off
return
concat
(
'_'
,
"tile_gemm_shape"
,
concat
(
'x'
,
kM
,
kN
,
kK
,
NumWarps
),
concat
(
'x'
,
BlockWarps
::
at
(
number
<
0
>
{}),
BlockWarps
::
at
(
number
<
1
>
{}),
BlockWarps
::
at
(
number
<
2
>
{})),
concat
(
'x'
,
(
WarpTile
::
at
(
number
<
0
>
{})),
WarpTile
::
at
(
number
<
1
>
{}),
WarpTile
::
at
(
number
<
2
>
{})));
// clang-format on
}
};
}
// namespace ck_tile
include/ck_tile/ops/image_to_column.hpp
View file @
f23a2e2a
...
...
@@ -8,3 +8,4 @@
#include "ck_tile/ops/image_to_column/pipeline/tile_image_to_column_shape.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"
include/ck_tile/ops/layernorm2d.hpp
View file @
f23a2e2a
...
...
@@ -11,3 +11,4 @@
#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_traits.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"
include/ck_tile/ops/norm_reduce.hpp
View file @
f23a2e2a
...
...
@@ -8,3 +8,4 @@
#include "ck_tile/ops/norm_reduce/thread/thread_welford.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"
include/ck_tile/ops/permute.hpp
View file @
f23a2e2a
...
...
@@ -7,3 +7,4 @@
#include "ck_tile/ops/permute/pipeline/generic_petmute_problem.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"
include/ck_tile/ops/reduce.hpp
View file @
f23a2e2a
...
...
@@ -9,3 +9,4 @@
#include "ck_tile/ops/reduce/block/block_reduce2d_problem.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"
include/ck_tile/ops/rmsnorm2d.hpp
View file @
f23a2e2a
...
...
@@ -11,3 +11,4 @@
#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_traits.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"
Prev
1
…
8
9
10
11
12
13
14
15
16
17
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment