gaoqiong / composable_kernel_ROCM / Commits / 7572a691

Commit 7572a691, authored Feb 15, 2025 by coderfeli
merge develop
Parents: 7796fc73, 6b6fcd37
Changes: 452
Showing 20 changed files with 2524 additions and 979 deletions (+2524, -979).
Changed files (additions / deletions):

  include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp  (+44 / -9)
  include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp  (+326 / -108)
  include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp  (+307 / -29)
  include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp  (+81 / -190)
  include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp  (+39 / -21)
  include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp  (+165 / -34)
  include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp  (+559 / -0)
  include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4_default_policy.hpp  (+92 / -0)
  include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp  (+259 / -62)
  include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp  (+2 / -1)
  include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp  (+28 / -18)
  include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp  (+11 / -120)
  include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp  (+11 / -1)
  include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp  (+54 / -21)
  include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp  (+458 / -344)
  include/ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp  (+12 / -1)
  include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp  (+26 / -0)
  include/ck_tile/ops/image_to_column.hpp  (+2 / -1)
  include/ck_tile/ops/layernorm2d.hpp  (+2 / -1)
  include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp  (+46 / -18)
Too many changes to show: to preserve performance, only 452 of 452+ files are displayed.
include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp — view file @ 7572a691

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

 #include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp"
 #include "ck_tile/ops/common.hpp"
+#include "ck_tile/host/concat.hpp"

 namespace ck_tile {
...
@@ -57,6 +59,18 @@ struct BatchedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
     using BLayout = typename Base::BLayout;
     using CLayout = typename Base::CLayout;

+    [[nodiscard]] CK_TILE_HOST static const std::string GetName()
+    {
+        // clang-format off
+        using P_ = GemmPipeline;
+        return concat('_', "gemm_batched", gemm_prec_str<ADataType, BDataType>,
+                      concat('x', P_::kMPerBlock, P_::kNPerBlock, P_::kKPerBlock),
+                      concat('x', P_::GetVectorSizeA(), P_::GetVectorSizeB(), P_::GetVectorSizeC()),
+                      concat('x', P_::kPadM, P_::kPadN, P_::kPadK));
+        // clang-format on
+    }
+
     struct BatchedGemmKernelArgs : GemmKernelArgs
     {
         index_t batch_stride_A;
...
@@ -67,9 +81,10 @@ struct BatchedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
     using KernelArgs = BatchedGemmKernelArgs;

-    __host__ static constexpr auto GridSize(index_t M, index_t N, index_t batch_count)
+    __host__ static constexpr auto GridSize(index_t M, index_t N, index_t KBatch, index_t batch_count)
     {
-        return TilePartitioner::GridSize(M, N, batch_count);
+        return dim3(TilePartitioner::GridSize(M, N), batch_count, KBatch);
     }

     __host__ static constexpr auto BlockSize() { return dim3(Base::KernelBlockSize); }
...
@@ -85,7 +100,8 @@ struct BatchedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
                           hostArgs.K,
                           hostArgs.stride_A,
                           hostArgs.stride_B,
-                          hostArgs.stride_C},
+                          hostArgs.stride_C,
+                          hostArgs.k_batch},
                          hostArgs.batch_stride_A,
                          hostArgs.batch_stride_B,
                          hostArgs.batch_stride_C,
...
@@ -99,23 +115,42 @@ struct BatchedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
     CK_TILE_DEVICE void operator()(BatchedGemmKernelArgs kargs) const
     {
-        const auto [i_m, i_n] = TilePartitioner{}();
-        const auto i_batch    = __builtin_amdgcn_readfirstlane(blockIdx.z);
+        const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(blockIdx.x);
+        const index_t i_m   = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock);
+        const index_t i_n   = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);
+        const auto i_batch  = __builtin_amdgcn_readfirstlane(blockIdx.y);
+        const auto i_splitk = __builtin_amdgcn_readfirstlane(blockIdx.z);
+        const typename Base::SplitKBatchOffset splitk_batch_offset(kargs, i_splitk);

         // options
         const auto batch_stride_A = __builtin_amdgcn_readfirstlane(kargs.batch_stride_A);
         const auto batch_offset_A = __builtin_amdgcn_readfirstlane(i_batch * batch_stride_A);
-        const ADataType* a_ptr = static_cast<const ADataType*>(kargs.a_ptr) + batch_offset_A;
+        const ADataType* a_ptr = static_cast<const ADataType*>(kargs.a_ptr) + batch_offset_A +
+                                 splitk_batch_offset.a_k_split_offset;

         const auto batch_stride_B = __builtin_amdgcn_readfirstlane(kargs.batch_stride_B);
         const auto batch_offset_B = __builtin_amdgcn_readfirstlane(i_batch * batch_stride_B);
-        const BDataType* b_ptr = static_cast<const BDataType*>(kargs.b_ptr) + batch_offset_B;
+        const BDataType* b_ptr = static_cast<const BDataType*>(kargs.b_ptr) + batch_offset_B +
+                                 splitk_batch_offset.b_k_split_offset;

         const auto batch_stride_C = __builtin_amdgcn_readfirstlane(kargs.batch_stride_C);
         const auto batch_offset_C = __builtin_amdgcn_readfirstlane(i_batch * batch_stride_C);
         CDataType* c_ptr = static_cast<CDataType*>(kargs.c_ptr) + batch_offset_C;

-        this->RunGemm(a_ptr, b_ptr, c_ptr, kargs, i_m, i_n);
+        // allocate LDS
+        __shared__ char smem_ptr[GetSmemSize()];
+
+        if(kargs.k_batch == 1)
+        {
+            this->RunGemm(a_ptr, b_ptr, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n);
+        }
+        else
+        {
+            this->template RunGemm<memory_operation_enum::atomic_add>(
+                a_ptr, b_ptr, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n);
+        }
     }
 };
...
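The GetName helpers added in this commit build a kernel identifier by joining tile sizes, vector sizes and pad flags with separators. A minimal standalone sketch of an analogous join helper (this is not the real ck_tile::concat from ck_tile/host/concat.hpp, and the numbers below are made-up values shown only to illustrate the resulting name format):

    #include <iostream>
    #include <sstream>
    #include <string>

    // Join an arbitrary list of printable values with a single-character separator.
    template <typename... Ts>
    std::string concat(char sep, const Ts&... parts)
    {
        std::ostringstream oss;
        bool first = true;
        auto append = [&](const auto& p) {
            if(!first) { oss << sep; }
            oss << p;
            first = false;
        };
        (append(parts), ...);
        return oss.str();
    }

    int main()
    {
        // e.g. a hypothetical 128x128x32 tile, 8/8/8 vector sizes, no padding:
        std::cout << concat('_', "gemm_batched", concat('x', 128, 128, 32),
                            concat('x', 8, 8, 8), concat('x', 0, 0, 0))
                  << "\n"; // gemm_batched_128x128x32_8x8x8_0x0x0
    }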
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp — view file @ 7572a691

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once
...
@@ -8,7 +8,7 @@
 #include "ck_tile/core.hpp"
 #include "ck_tile/ops/common.hpp"
-#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
+#include "ck_tile/host/concat.hpp"

 namespace ck_tile {
...
@@ -69,18 +69,26 @@ struct GemmKernel
     using ADataType = remove_cvref_t<typename GemmPipeline::ADataType>;
     using BDataType = remove_cvref_t<typename GemmPipeline::BDataType>;
     // Below type is actually accumulation data type - the output of block GEMM.
     using CDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;

     static constexpr auto I0 = number<0>();
     static constexpr auto I1 = number<1>();
     static constexpr auto I2 = number<2>();

-    __host__ static constexpr auto GridSize(index_t M, index_t N, index_t KBatch)
+    [[nodiscard]] CK_TILE_HOST static const std::string GetName()
     {
-        return TilePartitioner::GridSize(M, N, KBatch);
+        // clang-format off
+        return concat('_', "gemm", gemm_prec_str<ADataType, BDataType>, GemmPipeline::GetName());
+        // clang-format on
     }

-    __host__ static constexpr auto BlockSize() { return dim3(KernelBlockSize); }
+    CK_TILE_HOST static constexpr auto GridSize(index_t M, index_t N, index_t KBatch)
+    {
+        return dim3(TilePartitioner::GridSize(M, N), 1, KBatch);
+    }
+
+    CK_TILE_HOST static constexpr auto BlockSize() { return dim3(KernelBlockSize); }

     struct GemmKernelArgs
     {
...
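A rough standalone illustration of the split-K grid introduced above: X carries the output-tile count while Z carries the KBatch factor (the tile sizes and the Dim3 struct below are assumptions for this sketch, not the ck_tile types):

    #include <cstdio>

    // Hypothetical tile sizes, for illustration only.
    constexpr int MPerBlock = 128;
    constexpr int NPerBlock = 128;

    struct Dim3 { unsigned x, y, z; };

    // Mirrors the idea of dim3(TilePartitioner::GridSize(M, N), 1, KBatch):
    // X counts output tiles, Z counts split-K slices.
    Dim3 launch_grid(int M, int N, int KBatch)
    {
        const int tiles_m = (M + MPerBlock - 1) / MPerBlock;
        const int tiles_n = (N + NPerBlock - 1) / NPerBlock;
        return {static_cast<unsigned>(tiles_m * tiles_n), 1u, static_cast<unsigned>(KBatch)};
    }

    int main()
    {
        const Dim3 g = launch_grid(4096, 4096, 4); // 32*32 tiles, 4 K-slices
        std::printf("grid = (%u, %u, %u)\n", g.x, g.y, g.z); // grid = (1024, 1, 4)
    }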
@@ -93,6 +101,7 @@ struct GemmKernel
         index_t stride_A;
         index_t stride_B;
         index_t stride_C;
+        index_t k_batch;
     };

     CK_TILE_HOST static constexpr GemmKernelArgs MakeKernelArgs(const GemmHostArgs& hostArgs)
...
@@ -105,121 +114,188 @@ struct GemmKernel
                              hostArgs.K,
                              hostArgs.stride_A,
                              hostArgs.stride_B,
-                             hostArgs.stride_C};
+                             hostArgs.stride_C,
+                             hostArgs.k_batch};
     }

-    // CK_TILE_HOST static constexpr GemmKernelArgs MakeKernelArgs(const void* a_ptr,
-    //                                                             const void* b_ptr,
-    //                                                             void* c_ptr,
-    //                                                             index_t M,
-    //                                                             index_t N,
-    //                                                             index_t K,
-    //                                                             index_t stride_A,
-    //                                                             index_t stride_B,
-    //                                                             index_t stride_C)
-    // {
-    //     return GemmKernelArgs{a_ptr, b_ptr, c_ptr, M, N, K, stride_A, stride_B, stride_C};
-    // }
-
     CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
     {
         return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
     }

+    struct SplitKBatchOffset
+    {
+        __device__ SplitKBatchOffset(const GemmKernelArgs& kargs, const std::size_t k_id = blockIdx.z)
+        {
+            constexpr auto K1   = TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{});
+            const index_t K_t   = kargs.k_batch * K1;
+            const index_t KRead = (kargs.K + K_t - 1) / K_t * K1;
+
+            if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
+            {
+                a_k_split_offset = k_id * KRead;
+            }
+            else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
+            {
+                a_k_split_offset = k_id * KRead * kargs.stride_A;
+            }
+
+            if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
+            {
+                b_k_split_offset = k_id * KRead * kargs.stride_B;
+            }
+            else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
+            {
+                b_k_split_offset = k_id * KRead;
+            }
+
+            if(k_id < static_cast<uint32_t>(kargs.k_batch - 1))
+            {
+                splitted_k = KRead;
+            }
+            else
+            {
+                splitted_k = kargs.K - KRead * (kargs.k_batch - 1);
+            }
+        }
+
+        index_t a_k_split_offset;
+        index_t b_k_split_offset;
+        index_t splitted_k;
+    };
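The per-slice K extent computed by SplitKBatchOffset can be checked with a small host-side sketch of the same arithmetic (K1 is a made-up warp-tile K value here; only the formula is mirrored, not the layout-dependent pointer offsets):

    #include <cstdio>

    // Illustrative stand-in for TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{}).
    constexpr int K1 = 32;

    // Mirrors the splitted_k computation: every slice but the last reads KRead,
    // the last slice takes whatever remains of K.
    int splitted_k(int K, int k_batch, int k_id)
    {
        const int K_t   = k_batch * K1;
        const int KRead = (K + K_t - 1) / K_t * K1;
        return (k_id < k_batch - 1) ? KRead : K - KRead * (k_batch - 1);
    }

    int main()
    {
        const int K = 1000, k_batch = 4;
        for(int k_id = 0; k_id < k_batch; ++k_id)
            std::printf("k_id=%d splitted_k=%d\n", k_id, splitted_k(K, k_batch, k_id));
        // KRead = ceil(1000 / 128) * 32 = 256, so the slices read 256, 256, 256 and 232.
    }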
     CK_TILE_HOST static bool IsSupportedArgument(const GemmKernelArgs& kargs)
     {
+        if constexpr(EpiloguePipeline::template GetVectorSizeC<CDataType>() % 2 != 0 &&
+                     is_any_of<CDataType, fp16_t, bf16_t>::value)
+        {
+            if(kargs.k_batch != 1)
+            {
+                std::cerr << "Conditions not met for Kbatch >1 !" << std::endl;
+                return false;
+            }
+        }
         if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
         {
-            if(kargs.K % TilePartitioner::kK != 0 && GemmPipeline::kPadK == false)
+            if(kargs.K % TilePartitioner::KPerBlock != 0 && GemmPipeline::kPadK == false)
             {
                 std::cerr << "Can't support K that is not a multiple of KPerBlock"
                              " without padding!" << std::endl;
                 return false;
             }
-            if(kargs.K % GemmPipeline::VectorSizeA != 0)
+            if(kargs.K % GemmPipeline::GetVectorSizeA() != 0)
             {
                 std::cerr << "K is not a multiple of vector load size for A tensor!" << std::endl;
                 return false;
             }
         }
         else
         {
-            if(kargs.M % TilePartitioner::kM != 0 && GemmPipeline::kPadM == false)
+            if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false)
             {
                 std::cerr << "Can't support M that is not a multiple of MPerBlock"
                              " without padding!" << std::endl;
                 return false;
             }
-            if(kargs.M % GemmPipeline::VectorSizeA != 0)
+            if(kargs.M % GemmPipeline::GetVectorSizeA() != 0)
             {
                 std::cerr << "M is not a multiple of vector load size for A tensor!" << std::endl;
                 return false;
             }
         }
         if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
         {
-            if(kargs.N % TilePartitioner::kN != 0 && GemmPipeline::kPadN == false)
+            if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false)
             {
                 std::cerr << "Can't support N that is not a multiple of NPerBlock"
                              " without padding!" << std::endl;
                 return false;
             }
-            if(kargs.N % GemmPipeline::VectorSizeB != 0)
+            if(kargs.N % GemmPipeline::GetVectorSizeB() != 0)
             {
                 std::cerr << "N is not a multiple of vector load size for B tensor!" << std::endl;
                 return false;
             }
         }
         else
         {
-            if(kargs.K % TilePartitioner::kK != 0 && GemmPipeline::kPadK == false)
+            if(kargs.K % TilePartitioner::KPerBlock != 0 && GemmPipeline::kPadK == false)
             {
                 std::cerr << "Can't support K that is not a multiple of KPerBlock"
                              " without padding!" << std::endl;
                 return false;
             }
-            if(kargs.K % GemmPipeline::VectorSizeB != 0)
+            if(kargs.K % GemmPipeline::GetVectorSizeB() != 0)
             {
                 std::cerr << "K is not a multiple of vector load size for B tensor!" << std::endl;
                 return false;
             }
         }
         if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
         {
-            if(kargs.N % TilePartitioner::kN != 0 && GemmPipeline::kPadN == false)
+            if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false)
             {
                 std::cerr << "Can't support N that is not a multiple of NPerBlock"
                              " without padding!" << std::endl;
                 return false;
             }
-            if(kargs.N % GemmPipeline::VectorSizeC != 0)
+            if(kargs.N % EpiloguePipeline::template GetVectorSizeC<CDataType>() != 0)
             {
                 std::cerr << "N is not a multiple of vector load size for C tensor!" << std::endl;
                 return false;
             }
         }
         else
         {
-            if(kargs.M % TilePartitioner::kM != 0 && GemmPipeline::kPadM == false)
+            if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false)
             {
                 std::cerr << "Can't support M that is not a multiple of MPerBlock"
                              " without padding!" << std::endl;
                 return false;
             }
-            if(kargs.M % GemmPipeline::VectorSizeC != 0)
+            if(kargs.M % EpiloguePipeline::template GetVectorSizeC<CDataType>() != 0)
             {
                 std::cerr << "M is not a multiple of vector load size for C tensor!" << std::endl;
                 return false;
             }
         }
         return true;
     }

-    CK_TILE_DEVICE auto MakeGemmTensorViews(const ADataType* a_ptr,
-                                            const BDataType* b_ptr,
-                                            CDataType* c_ptr,
-                                            const GemmKernelArgs& kargs) const
+    template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
+    CK_TILE_DEVICE static auto MakeGemmTensorViews(const ADataType* a_ptr,
+                                                   const BDataType* b_ptr,
+                                                   CDataType* c_ptr,
+                                                   const GemmKernelArgs& kargs,
+                                                   const SplitKBatchOffset& splitk_batch_offset)
     {
         const auto& a_tensor_view = [&]() {
             if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
             {
                 return make_naive_tensor_view<address_space_enum::global>(
                     a_ptr,
-                    make_tuple(kargs.M, kargs.K),
+                    make_tuple(kargs.M, splitk_batch_offset.splitted_k),
                     make_tuple(kargs.stride_A, 1),
-                    number<GemmPipeline::VectorSizeA>{},
+                    number<GemmPipeline::GetVectorSizeA()>{},
                     number<1>{});
             }
             else
             {
                 return make_naive_tensor_view<address_space_enum::global>(
                     a_ptr,
-                    make_tuple(kargs.M, kargs.K),
-                    make_tuple(1, kargs.stride_A),
-                    number<1>{},
+                    make_tuple(splitk_batch_offset.splitted_k, kargs.M),
+                    make_tuple(kargs.stride_A, 1),
+                    number<GemmPipeline::GetVectorSizeA()>{},
                     number<1>{});
             }
         }();
...
@@ -229,35 +305,36 @@ struct GemmKernel
             {
                 return make_naive_tensor_view<address_space_enum::global>(
                     b_ptr,
-                    make_tuple(kargs.N, kargs.K),
-                    make_tuple(1, kargs.stride_B),
-                    number<1>{},
+                    make_tuple(splitk_batch_offset.splitted_k, kargs.N),
+                    make_tuple(kargs.stride_B, 1),
+                    number<GemmPipeline::GetVectorSizeB()>{},
                     number<1>{});
             }
             else
             {
                 return make_naive_tensor_view<address_space_enum::global>(
                     b_ptr,
-                    make_tuple(kargs.N, kargs.K),
+                    make_tuple(kargs.N, splitk_batch_offset.splitted_k),
                     make_tuple(kargs.stride_B, 1),
-                    number<GemmPipeline::VectorSizeB>{},
+                    number<GemmPipeline::GetVectorSizeB()>{},
                     number<1>{});
             }
         }();

         // TODO: enable vector write for C in ColMajor
         const auto& c_tensor_view = [&]() {
             if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
             {
-                return make_naive_tensor_view<address_space_enum::global>(
+                return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
                     c_ptr,
                     make_tuple(kargs.M, kargs.N),
                     make_tuple(kargs.stride_C, 1),
-                    number<GemmPipeline::VectorSizeC>{},
+                    number<EpiloguePipeline::template GetVectorSizeC<CDataType>()>{},
                     number<1>{});
             }
             else
             {
-                return make_naive_tensor_view<address_space_enum::global>(
+                return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
                     c_ptr,
                     make_tuple(kargs.M, kargs.N),
                     make_tuple(1, kargs.stride_C),
...
@@ -270,23 +347,23 @@ struct GemmKernel
     }

     template <typename TensorView>
-    CK_TILE_DEVICE auto MakeGemmPadViews(const TensorView& views) const
+    CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views)
     {
         const auto& a_pad_view = [&]() {
             const auto& a_tensor_view = views.at(I0);
             if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
             {
-                return pad_tensor_view(a_tensor_view,
-                                       make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kK>{}),
-                                       sequence<false, GemmPipeline::kPadK>{});
+                return pad_tensor_view(a_tensor_view,
+                                       make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
+                                       sequence<false, GemmPipeline::kPadK>{});
             }
             else
             {
-                return pad_tensor_view(a_tensor_view,
-                                       make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kK>{}),
-                                       sequence<GemmPipeline::kPadM, false>{});
+                return pad_tensor_view(a_tensor_view,
+                                       make_tuple(number<TilePartitioner::KPerBlock>{}, number<TilePartitioner::MPerBlock>{}),
+                                       sequence<false, GemmPipeline::kPadM>{});
             }
         }();
...
@@ -294,35 +371,36 @@ struct GemmKernel
             const auto& b_tensor_view = views.at(I1);
             if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
             {
-                return pad_tensor_view(b_tensor_view,
-                                       make_tuple(number<TilePartitioner::kN>{}, number<TilePartitioner::kK>{}),
-                                       sequence<false, GemmPipeline::kPadK>{});
+                return pad_tensor_view(b_tensor_view,
+                                       make_tuple(number<TilePartitioner::NPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
+                                       sequence<false, GemmPipeline::kPadK>{});
             }
             else
             {
-                return pad_tensor_view(b_tensor_view,
-                                       make_tuple(number<TilePartitioner::kN>{}, number<TilePartitioner::kK>{}),
-                                       sequence<GemmPipeline::kPadN, false>{});
+                return pad_tensor_view(b_tensor_view,
+                                       make_tuple(number<TilePartitioner::KPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
+                                       sequence<false, GemmPipeline::kPadN>{});
             }
         }();

         // TODO vector write in for C in ColMajor
         const auto& c_pad_view = [&]() {
             const auto& c_tensor_view = views.at(I2);
             if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
             {
-                return pad_tensor_view(c_tensor_view,
-                                       make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kN>{}),
-                                       sequence<false, GemmPipeline::kPadN>{});
+                return pad_tensor_view(c_tensor_view,
+                                       make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
+                                       sequence<false, GemmPipeline::kPadN>{});
             }
             else
             {
-                return pad_tensor_view(c_tensor_view,
-                                       make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kN>{}),
-                                       sequence<GemmPipeline::kPadM, false>{});
+                return pad_tensor_view(c_tensor_view,
+                                       make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
+                                       sequence<GemmPipeline::kPadM, false>{});
             }
         }();
...
@@ -330,25 +408,50 @@ struct GemmKernel
     }

     template <typename PadView>
-    CK_TILE_DEVICE auto MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n) const
+    CK_TILE_DEVICE static auto MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n)
     {
         const auto& a_pad_view = views.at(I0);
-        const auto& a_block_window = make_tile_window(
-            a_pad_view,
-            make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kK>{}),
-            {i_m, 0});
-
         const auto& b_pad_view = views.at(I1);
-        const auto& b_block_window = make_tile_window(
-            b_pad_view,
-            make_tuple(number<TilePartitioner::kN>{}, number<TilePartitioner::kK>{}),
-            {i_n, 0});
-
         const auto& c_pad_view = views.at(I2);

+        const auto& a_block_window = [&]() {
+            if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
+            {
+                return make_tile_window(a_pad_view,
+                                        make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
+                                        {i_m, 0});
+            }
+            else
+            {
+                return make_tile_window(a_pad_view,
+                                        make_tuple(number<TilePartitioner::KPerBlock>{}, number<TilePartitioner::MPerBlock>{}),
+                                        {0, i_m});
+            }
+        }();
+
+        const auto& b_block_window = [&]() {
+            if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
+            {
+                return make_tile_window(b_pad_view,
+                                        make_tuple(number<TilePartitioner::NPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
+                                        {i_n, 0});
+            }
+            else
+            {
+                return make_tile_window(b_pad_view,
+                                        make_tuple(number<TilePartitioner::KPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
+                                        {0, i_n});
+            }
+        }();
+
         auto c_block_window = make_tile_window(
             c_pad_view,
-            make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kN>{}),
+            make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
             {i_m, i_n});

         return make_tuple(a_block_window, b_block_window, c_block_window);
...
@@ -360,47 +463,162 @@ struct GemmKernel
    /**
     * @brief Runs single GEMM problem cooperatively by whole workgroup.
     *
     * @param a_ptr input A pointer
     * @param b_ptr input B pointer
     * @param c_ptr output C pointer
+    * @param smem_ptr_0 The start memory pointer of the shared memory block.
     * @param kargs GEMM kernel arguments
+    * @param splitk_batch_offset Utility structure used to calculate k batch.
     * @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup.
     * @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup.
+    *
+    * @tparam DstInMemOp Destination memory operation (default: set).
     */
-    CK_TILE_DEVICE void RunGemm(const ADataType* a_ptr,
-                                const BDataType* b_ptr,
-                                CDataType* c_ptr,
-                                const GemmKernelArgs& kargs,
-                                const index_t block_idx_m,
-                                const index_t block_idx_n) const
+    template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
+    CK_TILE_DEVICE static void RunGemm(const ADataType* a_ptr,
+                                       const BDataType* b_ptr,
+                                       CDataType* c_ptr,
+                                       void* smem_ptr_0,
+                                       const GemmKernelArgs& kargs,
+                                       const SplitKBatchOffset& splitk_batch_offset,
+                                       const index_t block_idx_m,
+                                       const index_t block_idx_n)
     {
         // Create Gemm tensor views, pad views and tile windows
-        const auto& gemm_tensor_views_tuple = MakeGemmTensorViews(a_ptr, b_ptr, c_ptr, kargs);
-
-        // allocate LDS
-        __shared__ char smem_ptr[GetSmemSize()];
+        const auto& gemm_tensor_views_tuple =
+            MakeGemmTensorViews<DstInMemOp>(a_ptr, b_ptr, c_ptr, kargs, splitk_batch_offset);

         const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
         auto gemm_tile_windows     = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);

-        const index_t num_loop = TilePartitioner::GetLoopNum(kargs.K);
+        const index_t num_loop = TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k);

         // Run GEMM cooperatively by whole workgroup.
         const auto& a_block_window = gemm_tile_windows.at(I0);
         const auto& b_block_window = gemm_tile_windows.at(I1);

-        const auto& c_block_tile =
-            GemmPipeline{}.template operator()(a_block_window, b_block_window, num_loop, smem_ptr);
+        const auto& c_block_tile =
+            GemmPipeline{}.template operator()(a_block_window, b_block_window, num_loop, smem_ptr_0);

         // Run Epilogue Pipeline
         auto& c_block_window = gemm_tile_windows.at(I2);
-        EpiloguePipeline{}(c_block_window, c_block_tile);
+        EpiloguePipeline{}.template operator()<decltype(c_block_window), decltype(c_block_tile), DstInMemOp>(
+            c_block_window, c_block_tile, smem_ptr_0);
     }

+    /**
+     * @brief Runs single GEMM problem cooperatively by whole workgroup.
+     *
+     * @note RunGemm2LDS uses two shared memory buffers with a ping-pong buffering mechanism.
+     *
+     * @param a_ptr input A pointer
+     * @param b_ptr input B pointer
+     * @param c_ptr output C pointer
+     * @param smem_ptr_0 The starting pointer of 1st shared memory block.
+     * @param smem_ptr_1 The starting pointer of 2nd shared memory block.
+     * @param kargs GEMM kernel arguments
+     * @param splitk_batch_offset Utility structure used to calculate k batch.
+     * @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup.
+     * @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup.
+     *
+     * @tparam DstInMemOp Destination memory operation (default: set).
+     */
+    template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
+    CK_TILE_DEVICE static void RunGemm2LDS(const ADataType* a_ptr,
+                                           const BDataType* b_ptr,
+                                           CDataType* c_ptr,
+                                           void* __restrict__ smem_ptr_0,
+                                           void* __restrict__ smem_ptr_1,
+                                           const GemmKernelArgs& kargs,
+                                           const SplitKBatchOffset& splitk_batch_offset,
+                                           const index_t block_idx_m,
+                                           const index_t block_idx_n)
+    {
+        // Create Gemm tensor views, pad views and tile windows
+        const auto& gemm_tensor_views_tuple =
+            MakeGemmTensorViews<DstInMemOp>(a_ptr, b_ptr, c_ptr, kargs, splitk_batch_offset);
+        const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
+        auto gemm_tile_windows     = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
+
+        const index_t num_loop = TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k);
+
+        // Run GEMM cooperatively by whole workgroup.
+        const auto& a_block_window = gemm_tile_windows.at(I0);
+        const auto& b_block_window = gemm_tile_windows.at(I1);
+
+        const auto& c_block_tile = GemmPipeline{}.template operator()(
+            a_block_window, b_block_window, num_loop, smem_ptr_0, smem_ptr_1);
+
+        // Run Epilogue Pipeline
+        auto& c_block_window = gemm_tile_windows.at(I2);
+        EpiloguePipeline{}.template operator()<decltype(c_block_window), decltype(c_block_tile), DstInMemOp>(
+            c_block_window, c_block_tile, smem_ptr_0);
+    }

     CK_TILE_DEVICE void operator()(GemmKernelArgs kargs) const
     {
-        const auto [i_m, i_n] = TilePartitioner{}();
+        const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(blockIdx.x);
+        const index_t i_m   = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock);
+        const index_t i_n   = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);
+
+        const SplitKBatchOffset splitk_batch_offset(kargs);

         // options
-        const ADataType* a_ptr = static_cast<const ADataType*>(kargs.a_ptr);
-        const BDataType* b_ptr = static_cast<const BDataType*>(kargs.b_ptr);
-        CDataType* c_ptr       = static_cast<CDataType*>(kargs.c_ptr);
+        const ADataType* a_ptr = static_cast<const ADataType*>(kargs.a_ptr) + splitk_batch_offset.a_k_split_offset;
+        const BDataType* b_ptr = static_cast<const BDataType*>(kargs.b_ptr) + splitk_batch_offset.b_k_split_offset;
+        CDataType* c_ptr       = static_cast<CDataType*>(kargs.c_ptr);

-        RunGemm(a_ptr, b_ptr, c_ptr, kargs, i_m, i_n);
+        // allocate LDS
+        __shared__ char smem_ptr_0[GetSmemSize()];
+        __shared__ char smem_ptr_1[GetSmemSize()];
+
+        if(kargs.k_batch == 1)
+        {
+            if constexpr(GemmPipeline::DoubleSmemBuffer == true)
+            {
+                RunGemm2LDS(a_ptr, b_ptr, c_ptr, smem_ptr_0, smem_ptr_1, kargs, splitk_batch_offset, i_m, i_n);
+            }
+            else
+            {
+                RunGemm(a_ptr, b_ptr, c_ptr, smem_ptr_0, kargs, splitk_batch_offset, i_m, i_n);
+            }
+        }
+        else
+        {
+            // Do not compile in case where we have unsupported
+            // VectorSizeC & data type configuration.
+            if constexpr(!(EpiloguePipeline::template GetVectorSizeC<CDataType>() % 2 != 0 &&
+                           is_any_of<CDataType, fp16_t, bf16_t>::value))
+            {
+                if constexpr(GemmPipeline::DoubleSmemBuffer == true)
+                {
+                    RunGemm2LDS<memory_operation_enum::atomic_add>(
+                        a_ptr, b_ptr, c_ptr, smem_ptr_0, smem_ptr_1, kargs, splitk_batch_offset, i_m, i_n);
+                }
+                else
+                {
+                    RunGemm<memory_operation_enum::atomic_add>(
+                        a_ptr, b_ptr, c_ptr, smem_ptr_0, kargs, splitk_batch_offset, i_m, i_n);
+                }
+            }
+        }
     }
 };
...
include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp — view file @ 7572a691

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

 /**
  * @file
  * GemmTilePartitioner allows customized mapping between a workgroup and the C-tile it computes.
  */

 #pragma once

 #include "ck_tile/core.hpp"

 namespace ck_tile {

-template <typename BlockGemmShape_>
-struct GemmTilePartitioner
+/**
+ * @brief Class providing 2D workgroup index mapping into 2D output GEMM C-tile space.
+ */
+template <typename BlockGemmShapeType>
+struct GemmTile2DPartitioner
 {
-    using BlockGemmShape = remove_cvref_t<BlockGemmShape_>;
+    using BlockGemmShape = remove_cvref_t<BlockGemmShapeType>;

-    static constexpr index_t kM = BlockGemmShape::kM;
-    static constexpr index_t kN = BlockGemmShape::kN;
-    static constexpr index_t kK = BlockGemmShape::kK;
+    static constexpr index_t MPerBlock = BlockGemmShape::kM;
+    static constexpr index_t NPerBlock = BlockGemmShape::kN;
+    static constexpr index_t KPerBlock = BlockGemmShape::kK;
+
+    CK_TILE_HOST_DEVICE GemmTile2DPartitioner() noexcept = delete;
+    CK_TILE_HOST_DEVICE GemmTile2DPartitioner([[maybe_unused]] index_t M,
+                                              [[maybe_unused]] index_t N) noexcept;

-    CK_TILE_HOST static constexpr auto GridSize(index_t M, index_t N, index_t batch_size)
+    /**
+     * @brief Calculates GEMM kernel grid size.
+     *
+     * @param M GEMM's M dimension.
+     * @param N GEMM's N dimension.
+     * @return dim3 Structure holding grid's X, Y and Z dimensions.
+     */
+    CK_TILE_HOST static auto GridSize(index_t M, index_t N) noexcept(noexcept(MPerBlock != 0 && NPerBlock != 0))
+        -> dim3
     {
-        index_t GridDimX = (M + kM - 1) / kM;
-        index_t GridDimY = (N + kN - 1) / kN;
-        index_t GridDimZ = batch_size;
-        return dim3(GridDimX, GridDimY, GridDimZ);
+        const index_t GridDimX = (M + MPerBlock - 1) / MPerBlock;
+        const index_t GridDimY = (N + NPerBlock - 1) / NPerBlock;
+        return dim3(GridDimX, GridDimY, 1);
     }

-    CK_TILE_HOST_DEVICE static constexpr auto GetLoopNum(index_t K)
+    /**
+     * @brief Calculate number of loop iterations over GEMM's K dimension.
+     *
+     * @param K GEMM's K dimension.
+     * @return index_t The number of loop iterations over K dimension.
+     */
+    CK_TILE_HOST_DEVICE static auto GetLoopNum(index_t K) noexcept -> index_t
     {
-        return integer_divide_ceil(K, kK);
+        return integer_divide_ceil(K, KPerBlock);
     }

-    /**
-     * @brief The function returns 2D output tile space.
-     * @param [in] blockIdx is blockIdx.x
-     * @param [in] blockIdy is blockIdx.y
-     * @return Returns the output tile indexes.
-     */
-    CK_TILE_DEVICE auto operator()()
+    /**
+     * @brief Calculate workgroup 2D index mapping into 2D output C-tile space.
+     *
+     * @param blockIdx WGP's X index.
+     * @param blockIdy WGP's Y index.
+     * @return const tuple<index_t, index_t> Tuple containing 2D output C-tile index.
+     */
+    CK_TILE_DEVICE static auto GetOutputTileIndex(index_t blockIdx, index_t blockIdy) noexcept
+        -> const tuple<index_t, index_t>
     {
-        const index_t iM = __builtin_amdgcn_readfirstlane(blockIdx.x * kM);
-        const index_t iN = __builtin_amdgcn_readfirstlane(blockIdx.y * kN);
+        const index_t iM = __builtin_amdgcn_readfirstlane(blockIdx);
+        const index_t iN = __builtin_amdgcn_readfirstlane(blockIdy);
         return make_tuple(iM, iN);
     }
 };
/**
 * @brief Class providing 1D WGP index mapping into 2D output C-tile space.
 *
 * @tparam BlockGemmShape_ A class providing basic GEMM parameters. \link TileGemmShape
 */
template <typename BlockGemmShape_>
struct GemmTile1DPartitioner
{
...
@@ -45,30 +92,261 @@ struct GemmTile1DPartitioner
     static constexpr index_t NPerBlock = BlockGemmShape::kN;
     static constexpr index_t KPerBlock = BlockGemmShape::kK;

-    CK_TILE_HOST static constexpr auto GridSize(index_t M, index_t N)
+    CK_TILE_HOST_DEVICE GemmTile1DPartitioner() noexcept = delete;
+
+    /**
+     * @brief Construct a new GemmTile1DPartitioner object.
+     *
+     * @param M GEMM's M dimension.
+     * @param N GEMM's N dimension.
+     */
+    CK_TILE_HOST_DEVICE GemmTile1DPartitioner([[maybe_unused]] index_t M, index_t N) noexcept
     {
-        index_t GridDimX = (M + MPerBlock - 1) / MPerBlock;
-        index_t GridDimY = (N + NPerBlock - 1) / NPerBlock;
-        return dim3(GridDimX * GridDimY, 1, 1);
+        N_ = N;
     }

-    CK_TILE_HOST_DEVICE static constexpr auto GetNBlock(index_t N)
+    /**
+     * @brief Calculates GEMM kernel grid size.
+     *
+     * @param M GEMM's M dimension.
+     * @param N GEMM's N dimension.
+     * @return index_t A total number of workgroups.
+     */
+    CK_TILE_HOST static auto GridSize(index_t M, index_t N) noexcept(noexcept(MPerBlock != 0 && NPerBlock != 0))
+        -> index_t
     {
-        return integer_divide_ceil(N, NPerBlock);
+        const index_t GridDimX = (M + MPerBlock - 1) / MPerBlock;
+        const index_t GridDimY = (N + NPerBlock - 1) / NPerBlock;
+        return GridDimX * GridDimY;
     }

-    CK_TILE_HOST_DEVICE static constexpr auto GetLoopNum(index_t K)
+    /**
+     * @brief Calculate number of loop iterations over GEMM's K dimension.
+     *
+     * @param K GEMM's K dimension.
+     * @return index_t The number of loop iterations over K dimension.
+     */
+    CK_TILE_HOST_DEVICE static auto GetLoopNum(index_t K) noexcept -> index_t
     {
         return integer_divide_ceil(K, KPerBlock);
     }

-    CK_TILE_DEVICE auto operator()(index_t blockOffset, index_t NBlockSize)
+    /**
+     * @brief Calculate workgroup 1D index mapping into 2D output C-tile space.
+     *
+     * @param blockIdx WGP's index.
+     * @return const tuple<index_t, index_t> Tuple containing 2D output C-tile index.
+     */
+    CK_TILE_DEVICE static auto GetOutputTileIndex(index_t blockIdx) noexcept -> const tuple<index_t, index_t>
+    {
+        const index_t NBlocks = integer_divide_ceil(N_, NPerBlock);
+
+        const index_t iM = __builtin_amdgcn_readfirstlane(blockIdx / NBlocks);
+        const index_t iN = __builtin_amdgcn_readfirstlane(blockIdx - iM * NBlocks);
+        return make_tuple(iM, iN);
+    }
+
+    private:
+    CK_TILE_DEVICE static index_t N_;
 };
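A minimal sketch of the row-major 1D-to-2D decomposition performed by GetOutputTileIndex above (plain host code with assumed N and NPerBlock values, not the device implementation):

    #include <cstdio>
    #include <utility>

    // Row-major decomposition of a flat workgroup id into (tile row, tile column).
    std::pair<int, int> output_tile_index(int block_idx, int N, int NPerBlock)
    {
        const int NBlocks = (N + NPerBlock - 1) / NPerBlock;
        const int iM      = block_idx / NBlocks;
        const int iN      = block_idx - iM * NBlocks;
        return {iM, iN};
    }

    int main()
    {
        // With N = 512 and NPerBlock = 128 there are 4 tile columns,
        // so flat id 9 maps to tile (2, 1).
        auto [iM, iN] = output_tile_index(9, 512, 128);
        std::printf("iM=%d iN=%d\n", iM, iN);
    }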
+/**
+ * @brief `GemmTile1DPartitioner::GetOutputTileIndex`'s std::false specialization,
+ *        checking expression validity in-place for ill-formed.
+ */
+template <typename, typename = void>
+struct HasFnOneArgImpl : std::false_type
+{
+};
+
+/**
+ * @brief `GemmTile1DPartitioner::GetOutputTileIndex`'s std::true specialization,
+ *        checking expression validity in-place for well-formed.
+ * @note: `1` - a constant value indicating the number of parameters in the function.
+ */
+template <typename T>
+struct HasFnOneArgImpl<T, std::void_t<decltype(std::declval<T>().GetOutputTileIndex(1))>> : std::true_type
+{
+};
+
+/**
+ * @brief Struct used to calculate offseted tile indexes.
+ * @note: The struct supports the 1D-Partitioner mechanism,
+ *        enable-if `GetOutputTileIndex`-fn is std::true_type when `GetOutputTileIndex`-fn is well-formed,
+ *        otherwise std::false_type.
+ */
+template <typename TilePartitioner,
+          typename = typename std::enable_if_t<HasFnOneArgImpl<TilePartitioner>{}>>
+struct OffsettedTile1DPartitioner
+{
+    /**
+     * @brief The function subtracts the block's start (offset) from 1D raw-indexes.
+     * @param [in] block_start Workgroup offset.
+     * @param [in] M Gemm's M dimension.
+     * @param [in] N Gemm's N dimension.
+     * @return Returns a `tuple` [Im, In] with shifted index.
+     */
+    [[nodiscard]] CK_TILE_DEVICE static auto
+    GetOffsetedTileIndex(index_t block_start, index_t M, index_t N) noexcept -> const tuple<index_t, index_t>
+    {
-        index_t iM = __builtin_amdgcn_readfirstlane((blockIdx.x - blockOffset) / GetNBlock(NBlockSize) * MPerBlock);
-        index_t iN = __builtin_amdgcn_readfirstlane((blockIdx.x - blockOffset) % GetNBlock(NBlockSize) * NPerBlock);
+        const auto [iM, iN] = TilePartitioner{M, N}.GetOutputTileIndex(blockIdx.x - block_start);
+        return make_tuple(iM, iN);
+    }
+};
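The enable_if constraint above relies on the classic std::void_t detection idiom; a compact self-contained illustration of the same pattern (the two structs below are invented for the example):

    #include <type_traits>

    // Primary template: assume the member function is absent.
    template <typename, typename = void>
    struct HasGetOutputTileIndex : std::false_type {};

    // Specialization chosen only when T{}.GetOutputTileIndex(1) is a valid expression.
    template <typename T>
    struct HasGetOutputTileIndex<T,
        std::void_t<decltype(std::declval<T>().GetOutputTileIndex(1))>> : std::true_type {};

    struct WithFn    { int GetOutputTileIndex(int) { return 0; } };
    struct WithoutFn {};

    static_assert(HasGetOutputTileIndex<WithFn>::value, "detected");
    static_assert(!HasGetOutputTileIndex<WithoutFn>::value, "not detected");

    int main() { return 0; }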
+/**
+ * @brief Class mapping 1D block index into 2D output tile space.
+ *
+ * @note It groups spatially workgroups in order to better utilize caches.
+ *       It is using grouped Rows of column-vectors WGP pattern. It's optimized
+ *       for gfx94x-like multiple-die chip.
+ *
+ * @tparam GroupNum - The number of big groups.
+ * @tparam M01 - The number of groups in M dim within spatially local WGPs.
+ */
+template <typename BlockGemmShapeType, index_t GroupNum, index_t M01>
+struct GemmSpatiallyLocalTilePartitioner
+{
+    using BlockGemmShape = remove_cvref_t<BlockGemmShapeType>;
+
+    static constexpr index_t MPerBlock = BlockGemmShape::kM;
+    static constexpr index_t NPerBlock = BlockGemmShape::kN;
+    static constexpr index_t KPerBlock = BlockGemmShape::kK;
+
+    CK_TILE_HOST_DEVICE GemmSpatiallyLocalTilePartitioner() noexcept = delete;
+    CK_TILE_HOST_DEVICE GemmSpatiallyLocalTilePartitioner(index_t M_, index_t N_) noexcept : M(M_), N(N_) {}
+
+    /**
+     * @brief Calculates GEMM kernel grid size.
+     *
+     * @param M GEMM's M dimension.
+     * @param N GEMM's N dimension.
+     * @return index_t A total number of workgroups.
+     */
+    CK_TILE_HOST static auto GridSize(index_t M, index_t N) noexcept(noexcept(MPerBlock != 0 && NPerBlock != 0))
+        -> index_t
+    {
+        const index_t GridDimX = integer_divide_ceil(M, MPerBlock);
+        const index_t GridDimY = integer_divide_ceil(N, NPerBlock);
+        return GridDimX * GridDimY;
+    }
+
+    /**
+     * @brief Calculate number of loop iterations over GEMM's K dimension.
+     *
+     * @param K GEMM's K dimension.
+     * @return index_t The number of loop iterations over K dimension.
+     */
+    CK_TILE_HOST_DEVICE static auto GetLoopNum(index_t K) noexcept -> index_t
+    {
+        return integer_divide_ceil(K, KPerBlock);
+    }
+
+    /**
+     * @brief Calculate workgroup 1D index mapping into 2D output C-tile space.
+     *
+     * @param [in] block_1d_id WGP's index.
+     * @return const tuple<index_t, index_t> Tuple containing 2D output C-tile index.
+     */
+    CK_TILE_DEVICE auto GetOutputTileIndex(index_t block_1d_id) noexcept -> const tuple<index_t, index_t>
+    {
+        const auto M0 = integer_divide_ceil(M, MPerBlock);
+        const auto N0 = integer_divide_ceil(N, NPerBlock);
+
+        if(M0 == 1)
+        {
+            return make_tuple(0, block_1d_id);
+        }
+        else if(N0 == 1)
+        {
+            return make_tuple(block_1d_id, 0);
+        }
+        // block_1d_id = block_1d_id % (M0 * N0); // swallow batch index
+        else
+        {
+            const auto group_size    = integer_divide_ceil(M0 * N0, GroupNum);
+            const auto big_group_num = GroupNum - (group_size * GroupNum - M0 * N0);
+            const auto group_id_y    = block_1d_id / GroupNum;
+            const auto group_id_x    = block_1d_id - group_id_y * GroupNum;
+            const auto remap_block_1d_id =
+                group_id_x <= big_group_num
+                    ? group_id_x * group_size + group_id_y
+                    : group_id_x * group_size + big_group_num - group_id_x + group_id_y;
+
+            const index_t idx_M0 = remap_block_1d_id / N0;
+            const index_t idx_N0 = remap_block_1d_id - idx_M0 * N0;
+
+            const index_t M0_tmp     = M0 / M01;
+            const index_t M0_mod_M01 = M0 - M0_tmp * M01;
+
+            const auto M01_adapt = (idx_M0 < M0 - M0_mod_M01) ? M01 : M0_mod_M01;
+
+            const index_t idx_M00          = idx_M0 / M01;
+            const index_t idx_M01          = idx_M0 - idx_M00 * M01;
+            const index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0;
+
+            /**
+             * (The original source carries an ASCII diagram here showing how consecutive
+             *  block ids are swizzled over the MPerBlock x NPerBlock output-tile grid.)
+             *
+             * Example:
+             * assume:
+             *   M0 = 5
+             *   N0 = 4
+             *   block_1d_id = 5
+             *   M01 = 2
+             *
+             *   idx_N0 = 1
+             *   idx_M0 = 1
+             *   M01_adapt = 2
+             *   idx_M00 = 0
+             *   idx_M01 = 1
+             *   idx_N0_M01_local = 5
+             *   output {1, 2}
+             */
+            const index_t N_out           = idx_N0_M01_local / M01_adapt;
+            const index_t idx_loc_mod_M01 = idx_N0_M01_local - N_out * M01_adapt;
+
+            return make_tuple(idx_loc_mod_M01 + idx_M00 * M01, N_out);
+        }
+    }
+
+    private:
+    index_t M;
+    index_t N;
+};
 } // namespace ck_tile
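The worked example in the comment above (M0 = 5, N0 = 4, block_1d_id = 5, M01 = 2, giving tile {1, 2}) can be reproduced by transcribing the index arithmetic into plain host code. The sketch below uses GroupNum = 1, for which the group remapping is the identity; it is only a check of the math, not the device code:

    #include <cstdio>
    #include <utility>

    // Host-side transcription of the swizzling arithmetic from
    // GemmSpatiallyLocalTilePartitioner::GetOutputTileIndex (general M0 > 1, N0 > 1 path).
    std::pair<int, int> output_tile(int block_1d_id, int M0, int N0, int GroupNum, int M01)
    {
        const int group_size    = (M0 * N0 + GroupNum - 1) / GroupNum;
        const int big_group_num = GroupNum - (group_size * GroupNum - M0 * N0);
        const int group_id_y    = block_1d_id / GroupNum;
        const int group_id_x    = block_1d_id - group_id_y * GroupNum;
        const int remap         = group_id_x <= big_group_num
                                      ? group_id_x * group_size + group_id_y
                                      : group_id_x * group_size + big_group_num - group_id_x + group_id_y;

        const int idx_M0     = remap / N0;
        const int idx_N0     = remap - idx_M0 * N0;
        const int M0_mod_M01 = M0 - (M0 / M01) * M01;
        const int M01_adapt  = (idx_M0 < M0 - M0_mod_M01) ? M01 : M0_mod_M01;
        const int idx_M00    = idx_M0 / M01;
        const int idx_M01    = idx_M0 - idx_M00 * M01;
        const int idx_local  = idx_N0 + idx_M01 * N0;

        const int N_out = idx_local / M01_adapt;
        return {idx_local - N_out * M01_adapt + idx_M00 * M01, N_out};
    }

    int main()
    {
        auto [m, n] = output_tile(/*block_1d_id=*/5, /*M0=*/5, /*N0=*/4, /*GroupNum=*/1, /*M01=*/2);
        std::printf("{%d, %d}\n", m, n); // {1, 2}, matching the example in the comment
    }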
include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp — view file @ 7572a691

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

 #include <iostream>
 #include <string>

 #include "ck_tile/core/numeric/math.hpp"
 #include "ck_tile/core/utility/literals.hpp"
 #include "ck_tile/core/utility/amd_address_space.hpp"
 #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
 #include "ck_tile/core.hpp"
 #include "ck_tile/ops/common.hpp"
 #include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp"
 #include "ck_tile/host.hpp"

 namespace ck_tile {

-struct GroupedGemmHostArgs
+struct GroupedGemmHostArgs : public ck_tile::GemmHostArgs
 {
-    const void* a_ptr;
-    const void* b_ptr;
-    void* c_ptr;
-    index_t M;
-    index_t N;
-    index_t K;
-    index_t stride_A;
-    index_t stride_B;
-    index_t stride_C;
+    CK_TILE_HOST GroupedGemmHostArgs() noexcept = default;
+    CK_TILE_HOST GroupedGemmHostArgs(const void* a_ptr_,
+                                     const void* b_ptr_,
+                                     void* c_ptr_,
+                                     ck_tile::index_t M_,
+                                     ck_tile::index_t N_,
+                                     ck_tile::index_t K_,
+                                     ck_tile::index_t stride_A_,
+                                     ck_tile::index_t stride_B_,
+                                     ck_tile::index_t stride_C_)
+        : GemmHostArgs(a_ptr_, b_ptr_, c_ptr_, KBatch, M_, N_, K_, stride_A_, stride_B_, stride_C_)
+    {
+    }
+
+    private:
+    static constexpr index_t KBatch = 1;
 };
 template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
-struct GroupedGemmKernel
+struct GroupedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, EpiloguePipeline_>
 {
     using TilePartitioner  = remove_cvref_t<TilePartitioner_>;
     using GemmPipeline     = remove_cvref_t<GemmPipeline_>;
     using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
     using ALayout          = remove_cvref_t<typename GemmPipeline::ALayout>;
     using BLayout          = remove_cvref_t<typename GemmPipeline::BLayout>;
     using CLayout          = remove_cvref_t<typename GemmPipeline::CLayout>;

+    using ADataType = remove_cvref_t<typename GemmPipeline::ADataType>;
+    using BDataType = remove_cvref_t<typename GemmPipeline::BDataType>;
+    using CDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
+
+    using OffsetTile1DPartitioner = OffsettedTile1DPartitioner<TilePartitioner>;
+    using Base                    = GemmKernel<TilePartitioner_, GemmPipeline_, EpiloguePipeline_>;
+    using GemmKernelArgs          = typename Base::GemmKernelArgs;
+
     static constexpr index_t KernelBlockSize = GemmPipeline::BlockSize;

     struct GemmTransKernelArg
     {
-        GroupedGemmHostArgs group_karg;
+        GemmKernelArgs group_karg;
         ck_tile::index_t block_start;
         ck_tile::index_t block_end;

         GemmTransKernelArg() = default;
-        GemmTransKernelArg(GroupedGemmHostArgs&& karg, index_t bl_start, index_t bl_end)
+        GemmTransKernelArg(GemmKernelArgs&& karg, index_t bl_start, index_t bl_end)
             : group_karg{karg}, block_start{bl_start}, block_end{bl_end}
         {
         }
     };

-    __host__ static size_t GetWorkSpaceSize(const std::vector<GroupedGemmHostArgs>& gemm_descs)
+    [[nodiscard]] CK_TILE_HOST static const std::string GetName()
     {
-        return gemm_descs.size() * sizeof(GemmTransKernelArg);
+        // clang-format off
+        using P_ = GemmPipeline;
+        return concat('_', "gemm_grouped", gemm_prec_str<ADataType, BDataType>,
+                      concat('x', P_::kMPerBlock, P_::kNPerBlock, P_::kKPerBlock),
+                      concat('x', P_::GetVectorSizeA(), P_::GetVectorSizeB(), P_::GetVectorSizeC()),
+                      concat('x', P_::kPadM, P_::kPadN, P_::kPadK));
+        // clang-format on
     }

-    __host__ static constexpr auto BlockSize() { return dim3(KernelBlockSize); }
+    __host__ static auto GetWorkSpaceSize(const std::vector<GroupedGemmHostArgs>& gemm_descs) -> std::size_t
+    {
+        return gemm_descs.size() * sizeof(GemmTransKernelArg);
+    }

-    using Hargs = GroupedGemmHostArgs;
+    __host__ static constexpr auto BlockSize() -> dim3 { return dim3(KernelBlockSize); }

-    __host__ static constexpr auto GridSize(const std::vector<Hargs>& gemm_descs)
+    __host__ static constexpr auto GridSize(const std::vector<GroupedGemmHostArgs>& gemm_descs)
     {
         index_t grid_size = 0;
         for(const auto& it_desc : gemm_descs)
         {
-            const auto dim3 = TilePartitioner::GridSize(it_desc.M, it_desc.N);
-            grid_size += dim3.x * dim3.y * 1;
+            const auto local_grid_size = TilePartitioner::GridSize(it_desc.M, it_desc.N);
+            grid_size += local_grid_size * it_desc.k_batch;
         }
         return dim3(grid_size, 1, 1);
     }

-    CK_TILE_HOST static auto MakeKargs(const std::vector<Hargs>& gemm_descs)
+    CK_TILE_HOST static auto MakeKargs(const std::vector<GroupedGemmHostArgs>& gemm_descs)
+        -> std::vector<GemmTransKernelArg>
     {
         std::vector<GemmTransKernelArg> gemm_kernel_args_;
         index_t group_count = ck_tile::type_convert<ck_tile::index_t>(gemm_descs.size());
...
@@ -99,23 +118,23 @@ struct GroupedGemmKernel
             const index_t stride_b = gemm_descs[i].stride_B;
             const index_t stride_c = gemm_descs[i].stride_C;

-            const auto dim3             = TilePartitioner::GridSize(M, N);
-            const index_t grid_size_grp = dim3.x * 1 * 1;
+            const index_t grid_size_grp = TilePartitioner::GridSize(M, N) * gemm_descs[i].k_batch;

             const index_t block_start = grid_size;
             const index_t block_end   = grid_size + grid_size_grp;

             grid_size += grid_size_grp;

-            auto karg = GroupedGemmHostArgs{type_convert<const ADataType*>(gemm_descs[i].a_ptr),
-                                            type_convert<const BDataType*>(gemm_descs[i].b_ptr),
-                                            type_convert<CDataType*>(gemm_descs[i].c_ptr),
-                                            M,
-                                            N,
-                                            K,
-                                            stride_a,
-                                            stride_b,
-                                            stride_c};
+            auto karg = GemmKernelArgs{type_convert<const ADataType*>(gemm_descs[i].a_ptr),
+                                       type_convert<const BDataType*>(gemm_descs[i].b_ptr),
+                                       type_convert<CDataType*>(gemm_descs[i].c_ptr),
+                                       M,
+                                       N,
+                                       K,
+                                       stride_a,
+                                       stride_b,
+                                       stride_c,
+                                       gemm_descs[i].k_batch};

             gemm_kernel_args_.emplace_back(std::move(karg), block_start, block_end);
         }
...
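A minimal sketch of the block-range bookkeeping performed in MakeKargs above: each group claims a contiguous range of workgroup ids sized by its tile count times its k_batch factor (the group sizes below are invented for the example):

    #include <cstdio>
    #include <vector>

    struct Range { int block_start; int block_end; };

    // Prefix-sum style accumulation of per-group workgroup-id ranges.
    std::vector<Range> make_block_ranges(const std::vector<int>& tiles_per_group,
                                         const std::vector<int>& k_batch_per_group)
    {
        std::vector<Range> ranges;
        int grid_size = 0;
        for(size_t i = 0; i < tiles_per_group.size(); ++i)
        {
            const int grp = tiles_per_group[i] * k_batch_per_group[i];
            ranges.push_back({grid_size, grid_size + grp});
            grid_size += grp;
        }
        return ranges;
    }

    int main()
    {
        for(const Range& r : make_block_ranges({64, 16, 128}, {1, 4, 1}))
            std::printf("[%d, %d)\n", r.block_start, r.block_end); // [0,64) [64,128) [128,256)
    }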
@@ -123,162 +142,34 @@ struct GroupedGemmKernel
         return gemm_kernel_args_;
     }

-    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
+    CK_TILE_HOST_DEVICE static constexpr auto GetSmemSize() -> index_t
     {
         return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
     }

-    CK_TILE_DEVICE void Run(const Hargs& kargs, const index_t block_start) const
+    CK_TILE_DEVICE void Run(const GemmTransKernelArg& kargs) const
     {
-        const auto [i_m, i_n] = TilePartitioner{}(block_start, kargs.N);
-        // options
-        const ADataType* a_start = static_cast<const ADataType*>(kargs.a_ptr);
-        const BDataType* b_start = static_cast<const BDataType*>(kargs.b_ptr);
-        // Convert pointers to tensor views
-        auto a_tensor_view = [&]() {
-            if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
-            {
-                return make_naive_tensor_view<address_space_enum::global>(
-                    a_start, make_tuple(kargs.M, kargs.K), make_tuple(kargs.stride_A, 1),
-                    number<GemmPipeline::VectorSizeA>{}, number<1>{});
-            }
-            else
-            {
-                return make_naive_tensor_view<address_space_enum::global>(
-                    a_start, make_tuple(kargs.M, kargs.K), make_tuple(1, kargs.stride_A),
-                    number<1>{}, number<1>{});
-            }
-        }();
-
-        auto b_tensor_view = [&]() {
-            if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
-            {
-                return make_naive_tensor_view<address_space_enum::global>(
-                    b_start, make_tuple(kargs.N, kargs.K), make_tuple(1, kargs.stride_B),
-                    number<1>{}, number<1>{});
-            }
-            else
-            {
-                return make_naive_tensor_view<address_space_enum::global>(
-                    b_start, make_tuple(kargs.N, kargs.K), make_tuple(kargs.stride_B, 1),
-                    number<GemmPipeline::VectorSizeB>{}, number<1>{});
-            }
-        }();
-
-        auto a_pad_view = [&]() {
-            if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
-            {
-                return pad_tensor_view(a_tensor_view,
-                                       make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
-                                       sequence<false, GemmPipeline::kPadK>{});
-            }
-            else
-            {
-                return pad_tensor_view(a_tensor_view,
-                                       make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
-                                       sequence<GemmPipeline::kPadM, false>{});
-            }
-        }();
-
-        auto a_block_window = make_tile_window(
-            a_pad_view,
-            make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
-            {i_m, 0});
-
-        auto b_pad_view = [&]() {
-            if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
-            {
-                return pad_tensor_view(b_tensor_view,
-                                       make_tuple(number<TilePartitioner::NPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
-                                       sequence<false, GemmPipeline::kPadK>{});
-            }
-            else
-            {
-                return pad_tensor_view(b_tensor_view,
-                                       make_tuple(number<TilePartitioner::NPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
-                                       sequence<GemmPipeline::kPadN, false>{});
-            }
-        }();
-
-        auto b_block_window = make_tile_window(
-            b_pad_view,
-            make_tuple(number<TilePartitioner::NPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
-            {i_n, 0});
-
-        // allocate LDS
-        __shared__ char smem_ptr[GetSmemSize()];
-
-        const index_t num_loop = TilePartitioner::GetLoopNum(kargs.K);
-
-        // Run GEMM cooperatively by whole workgroup.
-        auto c_block_tile = GemmPipeline{}.template operator()(a_block_window, b_block_window, num_loop, smem_ptr);
-
-        CDataType* c_start = static_cast<CDataType*>(kargs.c_ptr);
-        auto c_tensor_view = [&]() {
-            if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
-            {
-                return make_naive_tensor_view<address_space_enum::global>(
-                    c_start, make_tuple(kargs.M, kargs.N), make_tuple(kargs.stride_C, 1),
-                    number<GemmPipeline::VectorSizeC>{}, number<1>{});
-            }
-            else
-            {
-                return make_naive_tensor_view<address_space_enum::global>(
-                    c_start, make_tuple(kargs.M, kargs.N), make_tuple(1, kargs.stride_C),
-                    number<1>{}, number<1>{});
-            }
-        }();
-
-        auto c_pad_view = [&]() {
-            if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
-            {
-                return pad_tensor_view(c_tensor_view,
-                                       make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
-                                       sequence<false, GemmPipeline::kPadN>{});
-            }
-            else
-            {
-                return pad_tensor_view(c_tensor_view,
-                                       make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
-                                       sequence<GemmPipeline::kPadM, false>{});
-            }
-        }();
-
-        auto CBlockWindow_pad = make_tile_window(
-            c_pad_view,
-            make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
-            {i_m, i_n});
-
-        EpiloguePipeline{}(CBlockWindow_pad, c_block_tile);
+        const auto [iM, iN] = OffsetTile1DPartitioner::GetOffsetedTileIndex(
+            kargs.block_start, kargs.group_karg.M, kargs.group_karg.N);
+
+        const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock);
+        const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);
+
+        const typename Base::SplitKBatchOffset splitk_batch_offset(kargs.group_karg, blockIdx.z);
+
+        const ADataType* a_ptr = static_cast<const ADataType*>(kargs.group_karg.a_ptr);
+        const BDataType* b_ptr = static_cast<const BDataType*>(kargs.group_karg.b_ptr);
+        CDataType* c_ptr       = static_cast<CDataType*>(kargs.group_karg.c_ptr);
+
+        // allocate LDS
+        __shared__ char smem_ptr[GetSmemSize()];
+
+        this->RunGemm(a_ptr, b_ptr, c_ptr, smem_ptr, kargs.group_karg, splitk_batch_offset, i_m, i_n);
     }

     CK_TILE_DEVICE void operator()(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
-                                   int group_count) const
+                                   index_t group_count) const
     {
         const index_t block_id   = ck_tile::get_block_1d_id();
         const auto gemm_desc_ptr = reinterpret_cast<const GemmTransKernelArg*>(
...
@@ -286,7 +177,7 @@ struct GroupedGemmKernel
         index_t left     = 0;
         index_t right    = group_count;
-        index_t group_id = index_t((left + right) / 2);
+        index_t group_id = index_t((left + right) >> 1);

         while((!(block_id >= gemm_desc_ptr[group_id].block_start &&
                  block_id < gemm_desc_ptr[group_id].block_end)) &&
...
@@ -300,10 +191,10 @@ struct GroupedGemmKernel
             {
                 left = group_id;
             }
-            group_id = index_t((left + right) / 2);
+            group_id = index_t((left + right) >> 1);
         }

-        Run(gemm_desc_ptr[group_id].group_karg, gemm_desc_ptr[group_id].block_start);
+        Run(gemm_desc_ptr[group_id]);
     }
 };
...
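The grouped kernel's operator() locates the owning group with a binary search over those half-open block ranges; a host-side sketch of the same lookup, assuming contiguous, sorted ranges as produced by MakeKargs:

    #include <cstdio>
    #include <vector>

    struct Range { int block_start; int block_end; };

    // Bisect until block_id falls inside [block_start, block_end).
    int find_group(const std::vector<Range>& r, int block_id)
    {
        int left = 0, right = static_cast<int>(r.size());
        int group_id = (left + right) >> 1;
        while(!(block_id >= r[group_id].block_start && block_id < r[group_id].block_end))
        {
            if(block_id < r[group_id].block_start) { right = group_id; }
            else                                   { left = group_id; }
            group_id = (left + right) >> 1;
        }
        return group_id;
    }

    int main()
    {
        const std::vector<Range> ranges = {{0, 64}, {64, 128}, {128, 256}};
        std::printf("%d %d %d\n",
                    find_group(ranges, 3), find_group(ranges, 100), find_group(ranges, 255)); // 0 1 2
    }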
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp — view file @ 7572a691

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

 #include "ck_tile/core.hpp"
 #include "ck_tile/ops/common.hpp"

 namespace ck_tile {
...
@@ -12,18 +13,23 @@ struct GemmPipelineAgBgCrImplBase
 {
     using ADataType      = remove_cvref_t<typename Problem::ADataType>;
     using BDataType      = remove_cvref_t<typename Problem::BDataType>;
     using ALayout        = remove_cvref_t<typename Problem::ALayout>;
     using BLayout        = remove_cvref_t<typename Problem::BLayout>;
     using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;

     static constexpr index_t MPerBlock = BlockGemmShape::kM;
     static constexpr index_t NPerBlock = BlockGemmShape::kN;
     static constexpr index_t KPerBlock = BlockGemmShape::kK;

-    template <typename DstBlockTile, typename SrcTileWindow>
+    CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; }
+
+    template <typename DstBlockTile, typename SrcTileWindow, typename DramTileWindowStep>
     CK_TILE_DEVICE void GlobalPrefetch(DstBlockTile& dst_block_tile,
-                                       SrcTileWindow& dram_tile_window) const
+                                       SrcTileWindow& dram_tile_window,
+                                       const DramTileWindowStep& dram_tile_window_step) const
     {
         load_tile(dst_block_tile, dram_tile_window);
-        move_tile_window(dram_tile_window, {0, KPerBlock});
+        move_tile_window(dram_tile_window, dram_tile_window_step);
     }

     template <typename DstTileWindow, typename SrcBlockTile, typename ElementFunction>
...
@@ -35,20 +41,26 @@ struct GemmPipelineAgBgCrImplBase
         store_tile(lds_tile_window, block_tile_tmp);
     }

+    template <typename DstBlockTile, typename SrcTileWindow>
+    CK_TILE_DEVICE void LocalPrefetch(DstBlockTile& dst_block_tile,
+                                      const SrcTileWindow& lds_tile_window) const
+    {
+        load_tile(dst_block_tile, lds_tile_window);
+    }
+
     CK_TILE_DEVICE auto GetABLdsTensorViews(void* p_smem) const
     {
         // A tile in LDS
-        ADataType* p_a_lds = static_cast<ADataType*>(p_smem);
+        ADataType* __restrict__ p_a_lds = static_cast<ADataType*>(p_smem);
         constexpr auto a_lds_block_desc = Policy::template MakeALdsBlockDescriptor<Problem>();
         auto a_lds_block = make_tensor_view<address_space_enum::lds>(p_a_lds, a_lds_block_desc);

         // TODO: LDS alignment should come from Policy!
-        constexpr index_t a_lds_block_space_size_aligned =
-            integer_divide_ceil(sizeof(ADataType) * a_lds_block_desc.get_element_space_size(), 16) * 16;
+        constexpr index_t a_lds_block_space_size_aligned =
+            integer_least_multiple(sizeof(ADataType) * a_lds_block_desc.get_element_space_size(), 16);

         // B tile in LDS
-        BDataType* p_b_lds = static_cast<BDataType*>(
+        BDataType* __restrict__ p_b_lds = static_cast<BDataType*>(
             static_cast<void*>(static_cast<char*>(p_smem) + a_lds_block_space_size_aligned));
         constexpr auto b_lds_block_desc = Policy::template MakeBLdsBlockDescriptor<Problem>();
         auto b_lds_block = make_tensor_view<address_space_enum::lds>(p_b_lds, b_lds_block_desc);
...
@@ -60,19 +72,21 @@ struct GemmPipelineAgBgCrImplBase
    CK_TILE_DEVICE auto GetAWindows(const ADramBlockWindowTmp& a_dram_block_window_tmp,
                                    const ALdsTensorView& a_lds_block_view) const
    {
        constexpr bool is_col_major = std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>;
        using YPerTile = std::conditional_t<is_col_major, number<KPerBlock>, number<MPerBlock>>;
        using XPerTile = std::conditional_t<is_col_major, number<MPerBlock>, number<KPerBlock>>;

        // A DRAM tile window for load
        auto a_copy_dram_window =
            make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(),
                             make_tuple(number<MPerBlock>{}, number<KPerBlock>{}),
                             make_tuple(YPerTile{}, XPerTile{}),
                             a_dram_block_window_tmp.get_window_origin(),
                             Policy::template MakeADramTileDistribution<Problem>());

        // A LDS tile window for store
        auto a_copy_lds_window = make_tile_window(
            a_lds_block_view, make_tuple(number<MPerBlock>{}, number<KPerBlock>{}), {0, 0},
            a_copy_dram_window.get_tile_distribution());
        auto a_copy_lds_window = make_tile_window(
            a_lds_block_view, make_tuple(number<MPerBlock>{}, number<KPerBlock>{}), {0, 0});

        auto a_lds_gemm_window = make_tile_window(
            a_lds_block_view, make_tuple(number<MPerBlock>{}, number<KPerBlock>{}), {0, 0});
...
...
@@ -86,18 +100,22 @@ struct GemmPipelineAgBgCrImplBase
    CK_TILE_DEVICE auto GetBWindows(const BDramBlockWindowTmp& b_dram_block_window_tmp,
                                    const BLdsTensorView& b_lds_block_view) const
    {
        constexpr bool is_row_major = std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
        using YPerTile = std::conditional_t<is_row_major, number<KPerBlock>, number<NPerBlock>>;
        using XPerTile = std::conditional_t<is_row_major, number<NPerBlock>, number<KPerBlock>>;

        auto b_copy_dram_window =
            make_tile_window(b_dram_block_window_tmp.get_bottom_tensor_view(),
                             make_tuple(number<NPerBlock>{}, number<KPerBlock>{}),
                             make_tuple(YPerTile{}, XPerTile{}),
                             b_dram_block_window_tmp.get_window_origin(),
                             Policy::template MakeBDramTileDistribution<Problem>());

        // TODO: Do we really need those two tile windows???
        // They're exactly same...
        // B LDS tile window for store
        auto b_copy_lds_window = make_tile_window(
            b_lds_block_view, make_tuple(number<NPerBlock>{}, number<KPerBlock>{}), {0, 0},
            b_copy_dram_window.get_tile_distribution());
        auto b_copy_lds_window = make_tile_window(
            b_lds_block_view, make_tuple(number<NPerBlock>{}, number<KPerBlock>{}), {0, 0});

        auto b_lds_gemm_window = make_tile_window(
            b_lds_block_view, make_tuple(number<NPerBlock>{}, number<KPerBlock>{}), {0, 0});
...
...
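GetABLdsTensorViews above carves one shared-memory allocation into an A region followed by a B region, with the A footprint rounded up to a 16-byte multiple via integer_least_multiple. A small standalone sketch of that offset arithmetic; the 128x32 fp16 tile shape is an assumption chosen purely for illustration.

#include <cstddef>

// Round x up to the nearest multiple of `align`, mirroring integer_least_multiple above.
constexpr std::size_t least_multiple(std::size_t x, std::size_t align)
{
    return ((x + align - 1) / align) * align;
}

// Example: a 128x32 fp16 A tile occupies 8192 bytes, which is already 16-byte aligned,
// so the B tile view would start at byte offset 8192 within the same LDS block.
static_assert(least_multiple(2 /*bytes per fp16*/ * 128 * 32, 16) == 8192, "aligned A footprint");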
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp
View file @ 7572a691
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <string>
#include <sstream>
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp"
#include "ck_tile/host/concat.hpp"
namespace ck_tile {
...
...
@@ -20,6 +24,8 @@ struct BaseGemmPipelineAgBgCrCompV3
    static constexpr index_t PrefillStages   = 1;
    static constexpr index_t GlobalBufferNum = 1;

    CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; }

    CK_TILE_HOST static constexpr bool BlockHasHotloop(index_t num_loop)
    {
        return num_loop > PrefetchStages;
...
...
@@ -37,7 +43,7 @@ struct BaseGemmPipelineAgBgCrCompV3
// LocalPreFillStages: 1
// LocalPreFetchStages: 1
// LocalSharedMemoryBuffer: 1
template <typename Problem, typename Policy = GemmPipelineAGmemBGmemCRegV1DefaultPolicy>
template <typename Problem, typename Policy = UniversalGemmPipelineAgBgCrPolicy>
struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
{
    using Base = BaseGemmPipelineAgBgCrCompV3<Problem>;
...
...
@@ -62,26 +68,86 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
    static constexpr index_t NPerBlock = BlockGemmShape::kN;
    static constexpr index_t KPerBlock = BlockGemmShape::kK;

    static constexpr index_t VectorSizeA = Problem::VectorSizeA;
    static constexpr index_t VectorSizeB = Problem::VectorSizeB;
    static constexpr index_t VectorSizeC = Problem::VectorSizeC;
    static constexpr index_t GetVectorSizeA() { return Policy::template GetVectorSizeA<Problem>(); }
    static constexpr index_t GetVectorSizeB() { return Policy::template GetVectorSizeB<Problem>(); }
    static constexpr index_t GetVectorSizeC() { return Policy::template GetVectorSizeC<Problem>(); }

    static constexpr bool kPadM = Problem::kPadM;
    static constexpr bool kPadN = Problem::kPadN;
    static constexpr bool kPadK = Problem::kPadK;

    // Where is the right place for HasHotLoop and TailNum ???
    static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer;
    static constexpr bool HasHotLoop = Problem::HasHotLoop;
    static constexpr auto TailNum    = Problem::TailNum;
    static constexpr auto Scheduler  = Problem::Scheduler;

    using Base::PrefetchStages;

    [[nodiscard]] CK_TILE_HOST static const std::string GetName()
    {
        // clang-format off
        return concat('_', "pipeline_AgBgCrCompV3",
                      BlockSize,
                      concat('x', GetVectorSizeA(), GetVectorSizeB(), GetVectorSizeC()),
                      concat('x', kPadM, kPadN, kPadK));
        // clang-format on
    }

    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
    {
        return Policy::template GetSmemSize<Problem>();
    }

    CK_TILE_HOST static std::string Print()
    {
        constexpr index_t MPerXDL = BlockGemm::WarpGemm::kM;
        constexpr index_t NPerXDL = BlockGemm::WarpGemm::kN;
        constexpr index_t KPerXDL = BlockGemm::WarpGemm::WarpGemmAttribute::Impl::kK;

        constexpr index_t WaveSize = 64;
        constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(I0{});
        constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1{});

        // Below should be equal to AK1|BK1
        constexpr index_t A_LDS_Read_Width = Policy::template GetSmemPackA<Problem>();
        constexpr index_t B_LDS_Read_Width = Policy::template GetSmemPackB<Problem>();

        constexpr index_t A_LDS_Write_Width = Policy::template GetSmemPackA<Problem>();
        constexpr index_t B_LDS_Write_Width = Policy::template GetSmemPackB<Problem>();

        constexpr index_t A_Buffer_Load_Inst_Num =
            MPerBlock * KPerBlock / (BlockSize * GetVectorSizeA());
        constexpr index_t B_Buffer_Load_Inst_Num =
            NPerBlock * KPerBlock / (BlockSize * GetVectorSizeB());

        constexpr index_t A_LDS_Write_Inst_Num =
            MPerBlock * KPerBlock / (BlockSize * A_LDS_Write_Width);
        constexpr index_t B_LDS_Write_Inst_Num =
            NPerBlock * KPerBlock / (BlockSize * B_LDS_Write_Width);

        constexpr index_t A_LDS_Read_Inst_Num =
            WaveNumN * MPerBlock * KPerBlock / (BlockSize * A_LDS_Read_Width);
        constexpr index_t B_LDS_Read_Inst_Num =
            WaveNumM * MPerBlock * KPerBlock / (BlockSize * B_LDS_Read_Width);

        constexpr index_t C_MFMA_Inst_Num =
            MPerBlock * NPerBlock * KPerBlock / (BlockSize / WaveSize) /
            (MPerXDL * NPerXDL * KPerXDL);

        auto str = std::stringstream{};

        str << "A/B vector size: " << GetVectorSizeA() << ", " << GetVectorSizeB() << "\n"
            << "A/B LDS read/write width: " << A_LDS_Read_Width << ", " << B_LDS_Read_Width << "\n"
            << "A/B buffer load inst: " << A_Buffer_Load_Inst_Num << ", " << B_Buffer_Load_Inst_Num << "\n"
            << "A/B LDS write inst: " << A_LDS_Write_Inst_Num << ", " << B_LDS_Write_Inst_Num << "\n"
            << "A/B LDS read inst: " << A_LDS_Read_Inst_Num << ", " << B_LDS_Read_Inst_Num << "\n"
            << "C MFMA inst: " << C_MFMA_Inst_Num << "\n"
            << "KPack: " << BlockGemm::Traits::KPack << "\n"
            << "PrefetchStages: " << PrefetchStages << "\n";
        return str.str();
    }

    template <GemmPipelineScheduler Scheduler>
    struct PipelineImpl : public PipelineImplBase
    {
...
...
@@ -94,29 +160,35 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
        CK_TILE_DEVICE static constexpr auto HotLoopScheduler()
        {
            constexpr index_t MPerXDL = BlockGemmShape::WarpTile::at(I0{});
            constexpr index_t NPerXDL = BlockGemmShape::WarpTile::at(I1{});
            constexpr index_t KPerXDL = BlockGemmShape::WarpTile::at(I2{});
            constexpr index_t MPerXDL = BlockGemm::WarpGemm::kM;
            constexpr index_t NPerXDL = BlockGemm::WarpGemm::kN;
            constexpr index_t KPerXDL = BlockGemm::WarpGemm::WarpGemmAttribute::Impl::kK;

            constexpr index_t WaveSize = 64;
            constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(I0{});
            constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1{});

            constexpr index_t A_LDS_Read_Width = KPerXDL;
            constexpr index_t B_LDS_Read_Width = KPerXDL;
            // Below should be equal to AK1|BK1
            constexpr index_t A_LDS_Read_Width = Policy::template GetSmemPackA<Problem>();
            constexpr index_t B_LDS_Read_Width = Policy::template GetSmemPackB<Problem>();

            constexpr index_t A_LDS_Write_Width = Policy::template GetSmemPackA<Problem>();
            constexpr index_t B_LDS_Write_Width = Policy::template GetSmemPackB<Problem>();

            constexpr index_t A_Buffer_Load_Inst_Num =
                MPerBlock * KPerBlock / (BlockSize * VectorSizeA);
                MPerBlock * KPerBlock / (BlockSize * GetVectorSizeA());
            constexpr index_t B_Buffer_Load_Inst_Num =
                NPerBlock * KPerBlock / (BlockSize * VectorSizeB);
                NPerBlock * KPerBlock / (BlockSize * GetVectorSizeB());

            constexpr index_t A_LDS_Write_Inst_Num = MPerBlock * KPerBlock / (BlockSize * KPerXDL);
            constexpr index_t B_LDS_Write_Inst_Num = NPerBlock * KPerBlock / (BlockSize * KPerXDL);
            constexpr index_t A_LDS_Write_Inst_Num =
                MPerBlock * KPerBlock / (BlockSize * A_LDS_Write_Width);
            constexpr index_t B_LDS_Write_Inst_Num =
                NPerBlock * KPerBlock / (BlockSize * B_LDS_Write_Width);

            constexpr index_t A_LDS_Read_Inst_Num =
                WaveNumN * MPerBlock * KPerBlock / (BlockSize * KPerXDL);
                WaveNumN * MPerBlock * KPerBlock / (BlockSize * A_LDS_Read_Width);
            constexpr index_t B_LDS_Read_Inst_Num =
                WaveNumM * MPerBlock * KPerBlock / (BlockSize * KPerXDL);
                WaveNumM * MPerBlock * KPerBlock / (BlockSize * B_LDS_Read_Width);

            constexpr index_t C_MFMA_Inst_Num =
                MPerBlock * NPerBlock * KPerBlock / (BlockSize / WaveSize) /
...
...
@@ -246,11 +318,22 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
                          "A/B Dram block window should have the same data type as appropriate "
                          "([A|B]DataType) defined in Problem definition!");

            static_assert(MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
                              NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
                              KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}],
                          "A/B block window appropriate sizes must be equal to MPerBlock/NPerblock"
                          " or KPerBlock!");
            constexpr bool is_a_col_major =
                std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>;
            constexpr bool is_b_row_major = std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;

            static_assert(is_a_col_major
                              ? (KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
                                 MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}])
                              : (MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
                                 KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]),
                          "A block window has incorrect lengths for defined ALayout!");
            static_assert(is_b_row_major
                              ? (KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
                                 NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}])
                              : (NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
                                 KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]),
                          "B block window has incorrect lengths for defined BLayout!");

            // ------------------------------------------------------------------------------------
            // Definitions of all needed tiles
...
...
@@ -285,23 +368,51 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
            ABlockTile a_block_tile;
            BBlockTile b_block_tile;

            using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex;
            using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex;

            constexpr ADramTileWindowStep a_dram_tile_window_step =
                is_a_col_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
            constexpr BDramTileWindowStep b_dram_tile_window_step =
                is_b_row_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);

            // -----------------------------------------------------------------------------------------
            // Gemm pipeline start

            // prefetch
            // global read 0
            Base::GlobalPrefetch(a_block_tile, a_copy_dram_window);
            Base::GlobalPrefetch(b_block_tile, b_copy_dram_window);
            Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step);
            Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step);

            // initialize C
            tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);

            // LDS write 0
            Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func);
            Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func);
            if constexpr(is_a_col_major)
            {
                auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
                    Policy::template MakeShuffledARegTileDistribution<Problem>());
                transpose_tile2d(a_shuffle_tmp, a_block_tile);
                Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
            }
            else
            {
                Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func);
            }
            if constexpr(is_b_row_major)
            {
                auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
                    Policy::template MakeShuffledBRegTileDistribution<Problem>());
                transpose_tile2d(b_shuffle_tmp, b_block_tile);
                Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
            }
            else
            {
                Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func);
            }

            Base::GlobalPrefetch(a_block_tile, a_copy_dram_window);
            Base::GlobalPrefetch(b_block_tile, b_copy_dram_window);
            Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step);
            Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step);

            block_sync_lds();
            block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
...
...
@@ -316,11 +427,31 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
                {
                    block_sync_lds();

                    Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func);
                    Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func);
                    Base::GlobalPrefetch(a_block_tile, a_copy_dram_window);
                    Base::GlobalPrefetch(b_block_tile, b_copy_dram_window);
                    if constexpr(is_a_col_major)
                    {
                        auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
                            Policy::template MakeShuffledARegTileDistribution<Problem>());
                        transpose_tile2d(a_shuffle_tmp, a_block_tile);
                        Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
                    }
                    else
                    {
                        Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func);
                    }
                    if constexpr(is_b_row_major)
                    {
                        auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
                            Policy::template MakeShuffledBRegTileDistribution<Problem>());
                        transpose_tile2d(b_shuffle_tmp, b_block_tile);
                        Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
                    }
                    else
                    {
                        Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func);
                    }
                    Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step);
                    Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step);

                    block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
...
...
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp
0 → 100644
View file @ 7572a691
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4_default_policy.hpp"

namespace ck_tile {

// A Tile Window: global memory
// B Tile Window: global memory
// C Distributed tensor: register
template <typename Problem>
struct BaseGemmPipelineAgBgCrCompV4
{
    static constexpr index_t PrefetchStages  = 2;
    static constexpr index_t PrefillStages   = 1;
    static constexpr index_t GlobalBufferNum = 1;

    CK_TILE_HOST static constexpr bool BlockHasHotloop(index_t num_loop)
    {
        return num_loop > PrefetchStages;
    }

    CK_TILE_HOST static constexpr TailNumber GetBlockLoopTailNum(index_t num_loop)
    {
        if(num_loop % PrefetchStages == 1)
        {
            return TailNumber::Three;
        }
        else
        {
            return TailNumber::Two;
        }
    }
};

/**
 * @brief Compute optimized pipeline version 4
 *
 * This version introduces a dual LDS window mechanism using a ping-pong buffer approach
 * for more efficient data handling from global memory. Unlike compute version 3, this method
 * allows one LDS to fetch data from global memory while the other LDS executes warps for MFMA
 * matrix multiplication. This dual operation helps in keeping the Warp unit continuously busy,
 * thereby significantly reducing memory load times and enhancing overall performance.
 *
 * @note This version shows improved performance over Compute Version 3 with the same block tile.
 * It is particularly more efficient for large matrices where M, N, and K are greater than 8K,
 * even when Compute Version 3's block size is twice that of Compute Version 4.
 */
template <typename Problem, typename Policy = GemmPipelineAgBgCrCompV4DefaultPolicy>
struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
{
    using Base             = BaseGemmPipelineAgBgCrCompV4<Problem>;
    using PipelineImplBase = GemmPipelineAgBgCrImplBase<Problem, Policy>;

    using ADataType = remove_cvref_t<typename Problem::ADataType>;
    using BDataType = remove_cvref_t<typename Problem::BDataType>;
    using CDataType = remove_cvref_t<typename Problem::CDataType>;

    using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;

    using ALayout = remove_cvref_t<typename Problem::ALayout>;
    using BLayout = remove_cvref_t<typename Problem::BLayout>;
    using CLayout = remove_cvref_t<typename Problem::CLayout>;

    using BlockGemm = remove_cvref_t<decltype(Policy::template GetBlockGemm<Problem>())>;

    using I0 = number<0>;
    using I1 = number<1>;
    using I2 = number<2>;

    static constexpr index_t BlockSize = Problem::kBlockSize;

    static constexpr index_t MPerBlock = BlockGemmShape::kM;
    static constexpr index_t NPerBlock = BlockGemmShape::kN;
    static constexpr index_t KPerBlock = BlockGemmShape::kK;

    static constexpr index_t GetVectorSizeA() { return Policy::template GetVectorSizeA<Problem>(); }
    static constexpr index_t GetVectorSizeB() { return Policy::template GetVectorSizeB<Problem>(); }
    static constexpr index_t GetVectorSizeC() { return Policy::template GetVectorSizeC<Problem>(); }

    static constexpr bool kPadM = Problem::kPadM;
    static constexpr bool kPadN = Problem::kPadN;
    static constexpr bool kPadK = Problem::kPadK;

    static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer;

    static constexpr bool HasHotLoop = Problem::HasHotLoop;
    static constexpr auto TailNum    = Problem::TailNum;
    static constexpr auto Scheduler  = Problem::Scheduler;

    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
    {
        return Policy::template GetSmemSize<Problem>();
    }

    CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC()
    {
        return Policy::template IsTransposeC<Problem>();
    }

    template <GemmPipelineScheduler Scheduler>
    struct PipelineImpl : public PipelineImplBase
    {
    };

    template <>
    struct PipelineImpl<GemmPipelineScheduler::Intrawave> : public PipelineImplBase
    {
        using Base = PipelineImplBase;

        CK_TILE_DEVICE static constexpr auto HotLoopScheduler()
        {
            constexpr index_t MPerXDL = BlockGemmShape::WarpTile::at(I0{});
            constexpr index_t NPerXDL = BlockGemmShape::WarpTile::at(I1{});
            constexpr index_t KPerXDL = BlockGemmShape::WarpTile::at(I2{});

            constexpr index_t WaveSize = 64;
            constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(I0{});
            constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1{});

            constexpr index_t A_LDS_Read_Width = KPerXDL;
            constexpr index_t B_LDS_Read_Width = KPerXDL;

            constexpr index_t A_Buffer_Load_Inst_Num =
                MPerBlock * KPerBlock / (BlockSize * GetVectorSizeA());
            constexpr index_t B_Buffer_Load_Inst_Num =
                NPerBlock * KPerBlock / (BlockSize * GetVectorSizeB());

            constexpr index_t A_LDS_Write_Inst_Num = MPerBlock * KPerBlock / (BlockSize * KPerXDL);
            constexpr index_t B_LDS_Write_Inst_Num = NPerBlock * KPerBlock / (BlockSize * KPerXDL);

            constexpr index_t A_LDS_Read_Inst_Num =
                WaveNumN * MPerBlock * KPerBlock / (BlockSize * KPerXDL);
            constexpr index_t B_LDS_Read_Inst_Num =
                WaveNumM * MPerBlock * KPerBlock / (BlockSize * KPerXDL);

            constexpr index_t C_MFMA_Inst_Num =
                MPerBlock * NPerBlock * KPerBlock / (BlockSize / WaveSize) /
                (MPerXDL * NPerXDL * KPerXDL);

            constexpr auto num_ds_read_inst_a =
                A_LDS_Read_Width * sizeof(ADataType) == 16 ? A_LDS_Read_Inst_Num
                                                           : A_LDS_Read_Inst_Num / 2;
            constexpr auto num_ds_read_inst_b =
                B_LDS_Read_Width * sizeof(BDataType) == 16 ? B_LDS_Read_Inst_Num
                                                           : B_LDS_Read_Inst_Num / 2;

            constexpr auto num_ds_read_inst     = num_ds_read_inst_a + num_ds_read_inst_b;
            constexpr auto num_ds_write_inst    = A_LDS_Write_Inst_Num + B_LDS_Write_Inst_Num;
            constexpr auto num_buffer_load_inst = A_Buffer_Load_Inst_Num + B_Buffer_Load_Inst_Num;

            constexpr auto num_issue = num_buffer_load_inst;

            static_for<0, num_buffer_load_inst, 1>{}([&](auto i) {
                ignore = i;
                __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA : 1
                __builtin_amdgcn_sched_group_barrier(
                    0x100, num_ds_read_inst / num_issue, 0);       // DS read : 2
                __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA: 1
                __builtin_amdgcn_sched_group_barrier(
                    0x200, num_ds_write_inst / num_issue, 0);      // DS write : 1
                __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA : 1
                __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read : 1
                __builtin_amdgcn_sched_group_barrier(
                    0x008, C_MFMA_Inst_Num / num_issue - 3, 0);    // MFMA : 5
            });
            __builtin_amdgcn_sched_barrier(0);
        }

        template <bool HasHotLoop,
                  TailNumber TailNum,
                  typename ADramBlockWindowTmp,
                  typename BDramBlockWindowTmp,
                  typename AElementFunction,
                  typename BElementFunction>
        CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
                                       const AElementFunction& a_element_func,
                                       const BDramBlockWindowTmp& b_dram_block_window_tmp,
                                       const BElementFunction& b_element_func,
                                       index_t num_loop,
                                       void* __restrict__ p_smem_0,
                                       void* __restrict__ p_smem_1) const
        {
            static_assert(
                std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
                    std::is_same_v<BDataType,
                                   remove_cvref_t<typename BDramBlockWindowTmp::DataType>>,
                "Data Type conflict on A and B matrix input data type.");

            constexpr bool is_a_col_major =
                std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>;
            constexpr bool is_b_row_major = std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;

            static_assert(is_a_col_major
                              ? (KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
                                 MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}])
                              : (MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
                                 KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]),
                          "A block window has incorrect lengths for defined ALayout!");
            static_assert(is_b_row_major
                              ? (KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
                                 NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}])
                              : (NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
                                 KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]),
                          "B block window has incorrect lengths for defined BLayout!");

            ////////////// global window & register /////////////////
            // A DRAM tile window for load
            auto a_copy_dram_window =
                make_tile_window_linear(a_dram_block_window_tmp.get_bottom_tensor_view(),
                                        make_tuple(number<MPerBlock>{}, number<KPerBlock>{}),
                                        a_dram_block_window_tmp.get_window_origin(),
                                        Policy::template MakeADramTileDistribution<Problem>());

            // B DRAM tile window for load
            auto b_copy_dram_window =
                make_tile_window_linear(b_dram_block_window_tmp.get_bottom_tensor_view(),
                                        make_tuple(number<NPerBlock>{}, number<KPerBlock>{}),
                                        b_dram_block_window_tmp.get_window_origin(),
                                        Policy::template MakeBDramTileDistribution<Problem>());

            // A register tile for global load
            constexpr auto ABlockTileDistr = a_copy_dram_window.get_tile_distribution();
            constexpr auto BBlockTileDistr = b_copy_dram_window.get_tile_distribution();

            using ABlockTile = decltype(make_static_distributed_tensor<ADataType>(ABlockTileDistr));
            using BBlockTile = decltype(make_static_distributed_tensor<BDataType>(BBlockTileDistr));

            ABlockTile a_global_load_tile;
            BBlockTile b_global_load_tile;

            using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex;
            using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex;

            constexpr ADramTileWindowStep a_dram_tile_window_step =
                is_a_col_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
            constexpr BDramTileWindowStep b_dram_tile_window_step =
                is_b_row_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);

            // global prefetch 0
            // global read 0
            Base::GlobalPrefetch(a_global_load_tile, a_copy_dram_window, a_dram_tile_window_step);
            Base::GlobalPrefetch(b_global_load_tile, b_copy_dram_window, b_dram_tile_window_step);

            ////////////// LDS desc, window & register /////////////////
            auto&& [a_lds_block0, b_lds_block0] = Base::GetABLdsTensorViews(p_smem_0);
            auto&& [a_lds_block1, b_lds_block1] = Base::GetABLdsTensorViews(p_smem_1);

            auto a_copy_lds_window0 = make_tile_window(
                a_lds_block0, make_tuple(number<MPerBlock>{}, number<KPerBlock>{}), {0, 0});
            auto a_copy_lds_window1 = make_tile_window(
                a_lds_block1, make_tuple(number<MPerBlock>{}, number<KPerBlock>{}), {0, 0});
            auto b_copy_lds_window0 = make_tile_window(
                b_lds_block0, make_tuple(number<NPerBlock>{}, number<KPerBlock>{}), {0, 0});
            auto b_copy_lds_window1 = make_tile_window(
                b_lds_block1, make_tuple(number<NPerBlock>{}, number<KPerBlock>{}), {0, 0});

            // Block GEMM
            auto block_gemm   = BlockGemm();
            auto c_block_tile = block_gemm.MakeCBlockTile();

            // initialize C
            tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);

            // LDS write 0
            if constexpr(is_a_col_major)
            {
                auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
                    Policy::template MakeShuffledARegTileDistribution<Problem>());
                transpose_tile2d(a_shuffle_tmp, a_global_load_tile);
                Base::LocalPrefill(a_copy_lds_window0, a_shuffle_tmp, a_element_func);
            }
            else
            {
                Base::LocalPrefill(a_copy_lds_window0, a_global_load_tile, a_element_func);
            }
            if constexpr(is_b_row_major)
            {
                auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
                    Policy::template MakeShuffledBRegTileDistribution<Problem>());
                transpose_tile2d(b_shuffle_tmp, b_global_load_tile);
                Base::LocalPrefill(b_copy_lds_window0, b_shuffle_tmp, b_element_func);
            }
            else
            {
                Base::LocalPrefill(b_copy_lds_window0, b_global_load_tile, b_element_func);
            }

            // global read 1
            Base::GlobalPrefetch(a_global_load_tile, a_copy_dram_window, a_dram_tile_window_step);
            Base::GlobalPrefetch(b_global_load_tile, b_copy_dram_window, b_dram_tile_window_step);

            block_sync_lds();

            constexpr auto ALdsTileDistr = decltype(
                make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode())){};
            constexpr auto BLdsTileDistr = decltype(
                make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode())){};

            using ALdsTile = decltype(make_static_distributed_tensor<ADataType>(ALdsTileDistr));
            using BLdsTile = decltype(make_static_distributed_tensor<BDataType>(BLdsTileDistr));

            ALdsTile a_block_tile0;
            ALdsTile a_block_tile1;
            BLdsTile b_block_tile0;
            BLdsTile b_block_tile1;

            auto a_lds_ld_window0 = make_tile_window_linear(
                a_lds_block0, make_tuple(number<MPerBlock>{}, number<KPerBlock>{}), {0, 0},
                ALdsTileDistr);
            auto a_lds_ld_window1 = make_tile_window_linear(
                a_lds_block1, make_tuple(number<MPerBlock>{}, number<KPerBlock>{}), {0, 0},
                ALdsTileDistr);
            auto b_lds_ld_window0 = make_tile_window_linear(
                b_lds_block0, make_tuple(number<NPerBlock>{}, number<KPerBlock>{}), {0, 0},
                BLdsTileDistr);
            auto b_lds_ld_window1 = make_tile_window_linear(
                b_lds_block1, make_tuple(number<NPerBlock>{}, number<KPerBlock>{}), {0, 0},
                BLdsTileDistr);

            Base::LocalPrefetch(a_block_tile0, a_lds_ld_window0);
            Base::LocalPrefetch(b_block_tile0, b_lds_ld_window0);

            if constexpr(is_a_col_major)
            {
                auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
                    Policy::template MakeShuffledARegTileDistribution<Problem>());
                transpose_tile2d(a_shuffle_tmp, a_global_load_tile);
                Base::LocalPrefill(a_copy_lds_window1, a_shuffle_tmp, a_element_func);
            }
            else
            {
                Base::LocalPrefill(a_copy_lds_window1, a_global_load_tile, a_element_func);
            }
            if constexpr(is_b_row_major)
            {
                auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
                    Policy::template MakeShuffledBRegTileDistribution<Problem>());
                transpose_tile2d(b_shuffle_tmp, b_global_load_tile);
                Base::LocalPrefill(b_copy_lds_window1, b_shuffle_tmp, b_element_func);
            }
            else
            {
                Base::LocalPrefill(b_copy_lds_window1, b_global_load_tile, b_element_func);
            }

            Base::GlobalPrefetch(a_global_load_tile, a_copy_dram_window, a_dram_tile_window_step);
            Base::GlobalPrefetch(b_global_load_tile, b_copy_dram_window, b_dram_tile_window_step);

            if(HasHotLoop)
            {
                // minus 2 because we have ping-pong double buffer.
                index_t iCounter = __builtin_amdgcn_readfirstlane(num_loop - 2);
                do
                {
                    // ping
                    {
                        block_sync_lds();

                        Base::LocalPrefetch(a_block_tile1, a_lds_ld_window1);
                        Base::LocalPrefetch(b_block_tile1, b_lds_ld_window1);

                        if constexpr(is_a_col_major)
                        {
                            auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
                                Policy::template MakeShuffledARegTileDistribution<Problem>());
                            transpose_tile2d(a_shuffle_tmp, a_global_load_tile);
                            Base::LocalPrefill(a_copy_lds_window0, a_shuffle_tmp, a_element_func);
                        }
                        else
                        {
                            Base::LocalPrefill(
                                a_copy_lds_window0, a_global_load_tile, a_element_func);
                        }
                        if constexpr(is_b_row_major)
                        {
                            auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
                                Policy::template MakeShuffledBRegTileDistribution<Problem>());
                            transpose_tile2d(b_shuffle_tmp, b_global_load_tile);
                            Base::LocalPrefill(b_copy_lds_window0, b_shuffle_tmp, b_element_func);
                        }
                        else
                        {
                            Base::LocalPrefill(
                                b_copy_lds_window0, b_global_load_tile, b_element_func);
                        }

                        Base::GlobalPrefetch(
                            a_global_load_tile, a_copy_dram_window, a_dram_tile_window_step);
                        Base::GlobalPrefetch(
                            b_global_load_tile, b_copy_dram_window, b_dram_tile_window_step);

                        // gemm
                        block_gemm(c_block_tile, a_block_tile0, b_block_tile0);

                        HotLoopScheduler();
                        __builtin_amdgcn_sched_barrier(0);
                    }
                    // pong
                    {
                        block_sync_lds();

                        Base::LocalPrefetch(a_block_tile0, a_lds_ld_window0);
                        Base::LocalPrefetch(b_block_tile0, b_lds_ld_window0);

                        if constexpr(is_a_col_major)
                        {
                            auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
                                Policy::template MakeShuffledARegTileDistribution<Problem>());
                            transpose_tile2d(a_shuffle_tmp, a_global_load_tile);
                            Base::LocalPrefill(a_copy_lds_window1, a_shuffle_tmp, a_element_func);
                        }
                        else
                        {
                            Base::LocalPrefill(
                                a_copy_lds_window1, a_global_load_tile, a_element_func);
                        }
                        if constexpr(is_b_row_major)
                        {
                            auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
                                Policy::template MakeShuffledBRegTileDistribution<Problem>());
                            transpose_tile2d(b_shuffle_tmp, b_global_load_tile);
                            Base::LocalPrefill(b_copy_lds_window1, b_shuffle_tmp, b_element_func);
                        }
                        else
                        {
                            Base::LocalPrefill(
                                b_copy_lds_window1, b_global_load_tile, b_element_func);
                        }

                        Base::GlobalPrefetch(
                            a_global_load_tile, a_copy_dram_window, a_dram_tile_window_step);
                        Base::GlobalPrefetch(
                            b_global_load_tile, b_copy_dram_window, b_dram_tile_window_step);

                        // gemm
                        block_gemm(c_block_tile, a_block_tile1, b_block_tile1);

                        HotLoopScheduler();
                        __builtin_amdgcn_sched_barrier(0);
                    }
                    iCounter -= 2;
                } while(iCounter > 1);
            }

            // tail 3
            if(TailNum == TailNumber::Three)
            {
                // 3
                {
                    block_sync_lds();

                    Base::LocalPrefetch(a_block_tile1, a_lds_ld_window1);
                    Base::LocalPrefetch(b_block_tile1, b_lds_ld_window1);

                    if constexpr(is_a_col_major)
                    {
                        auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
                            Policy::template MakeShuffledARegTileDistribution<Problem>());
                        transpose_tile2d(a_shuffle_tmp, a_global_load_tile);
                        Base::LocalPrefill(a_copy_lds_window0, a_shuffle_tmp, a_element_func);
                    }
                    else
                    {
                        Base::LocalPrefill(a_copy_lds_window0, a_global_load_tile, a_element_func);
                    }
                    if constexpr(is_b_row_major)
                    {
                        auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
                            Policy::template MakeShuffledBRegTileDistribution<Problem>());
                        transpose_tile2d(b_shuffle_tmp, b_global_load_tile);
                        Base::LocalPrefill(b_copy_lds_window0, b_shuffle_tmp, b_element_func);
                    }
                    else
                    {
                        Base::LocalPrefill(b_copy_lds_window0, b_global_load_tile, b_element_func);
                    }

                    block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
                }
                // 2
                {
                    block_sync_lds();

                    Base::LocalPrefetch(a_block_tile0, a_lds_ld_window0);
                    Base::LocalPrefetch(a_block_tile0, a_lds_ld_window0);

                    block_gemm(c_block_tile, a_block_tile1, b_block_tile1);
                }
                // 1
                {
                    block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
                    __builtin_amdgcn_sched_barrier(0);
                }
            }
            else
            {
                // 2
                {
                    block_sync_lds();

                    Base::LocalPrefetch(a_block_tile1, a_lds_ld_window1);
                    Base::LocalPrefetch(b_block_tile1, b_lds_ld_window1);

                    block_gemm(c_block_tile, a_block_tile0, b_block_tile0);

                    static_for<0, 8, 1>{}([&](auto i) {
                        ignore = i;
                        __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
                        __builtin_amdgcn_sched_group_barrier(0x008, 8, 0); // MFMA
                    });
                    __builtin_amdgcn_sched_barrier(0);
                }
                // 1
                {
                    block_gemm(c_block_tile, a_block_tile1, b_block_tile1);
                    __builtin_amdgcn_sched_barrier(0);
                }
            }

            return c_block_tile;
        }
    };

    template <typename ADramBlockWindowTmp,
              typename BDramBlockWindowTmp,
              typename AElementFunction,
              typename BElementFunction>
    CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
                                   const AElementFunction& a_element_func,
                                   const BDramBlockWindowTmp& b_dram_block_window_tmp,
                                   const BElementFunction& b_element_func,
                                   index_t num_loop,
                                   void* p_smem_0,
                                   void* p_smem_1) const
    {
        return PipelineImpl<Scheduler>{}.template operator()<HasHotLoop, TailNum>(
            a_dram_block_window_tmp,
            a_element_func,
            b_dram_block_window_tmp,
            b_element_func,
            num_loop,
            p_smem_0,
            p_smem_1);
    }

    public:
    template <typename ADramBlockWindowTmp, typename BDramBlockWindowTmp>
    CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
                                   const BDramBlockWindowTmp& b_dram_block_window_tmp,
                                   const index_t num_loop,
                                   void* __restrict__ p_smem_0,
                                   void* __restrict__ p_smem_1) const
    {
        return PipelineImpl<Scheduler>{}.template operator()<HasHotLoop, TailNum>(
            a_dram_block_window_tmp,
            [](const ADataType& a) { return a; },
            b_dram_block_window_tmp,
            [](const BDataType& b) { return b; },
            num_loop,
            p_smem_0,
            p_smem_1);
    }
};

} // namespace ck_tile
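The ping-pong scheme described in the doxygen comment above splits the loop count between a two-deep prefetch, a hot loop that drains iterations in pairs, and a two- or three-stage tail. A minimal standalone host-side sketch of that loop accounting (the enum and function names below simply mirror the values used by this pipeline and are not the library's API):

#include <cassert>

enum class TailNumber { Two, Three };

constexpr int PrefetchStages = 2; // ping-pong depth of this pipeline

constexpr bool block_has_hotloop(int num_loop) { return num_loop > PrefetchStages; }

// Odd loop counts leave three block-tiles in flight after the paired hot loop,
// even loop counts leave two, matching GetBlockLoopTailNum above.
constexpr TailNumber tail_num(int num_loop)
{
    return (num_loop % PrefetchStages == 1) ? TailNumber::Three : TailNumber::Two;
}

int main()
{
    assert(block_has_hotloop(5) && tail_num(5) == TailNumber::Three); // 5 = hot pair(s) + 3-stage tail
    assert(block_has_hotloop(6) && tail_num(6) == TailNumber::Two);   // 6 = hot pair(s) + 2-stage tail
    assert(!block_has_hotloop(2) && tail_num(2) == TailNumber::Two);  // too short for a hot loop
    return 0;
}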
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4_default_policy.hpp
0 → 100644
View file @ 7572a691
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp"

namespace ck_tile {

// Default policy for GemmPipelineAGmemBGmemCregComputeV4, except the block gemm method, it shares
// the same vector size implementation, SmemSize, Global memory tile distiribution as the
// UniversalGemm Pipeline Policy.
// Default policy class should not be templated, put template on
// member functions instead.
struct GemmPipelineAgBgCrCompV4DefaultPolicy
    : public UniversalGemmBasePolicy<GemmPipelineAgBgCrCompV4DefaultPolicy>
{
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor()
    {
        using namespace ck_tile;

        constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
        constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
        constexpr index_t KPack      = GetSmemPackA<Problem>();

        constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
            make_tuple(number<kKPerBlock / KPack>{}, number<kMPerBlock>{}, number<KPack>{}),
            make_tuple(number<kMPerBlock * KPack>{}, number<KPack>{}, number<1>{}),
            number<KPack>{},
            number<1>{});

        constexpr auto a_lds_block_desc = transform_tensor_descriptor(
            a_lds_block_desc_0,
            make_tuple(make_pass_through_transform(number<kMPerBlock>{}),
                       make_merge_transform(
                           make_tuple(number<kKPerBlock>{} / KPack, number<KPack>{}))),
            make_tuple(sequence<1>{}, sequence<0, 2>{}),
            make_tuple(sequence<0>{}, sequence<1>{}));

        return a_lds_block_desc;
    }

    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor()
    {
        constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN;
        constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
        constexpr index_t KPack      = GetSmemPackB<Problem>();

        constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor(
            make_tuple(number<kKPerBlock / KPack>{}, number<kNPerBlock>{}, number<KPack>{}),
            make_tuple(number<(kNPerBlock)*KPack>{}, number<KPack>{}, number<1>{}),
            number<KPack>{},
            number<1>{});

        constexpr auto b_lds_block_desc = transform_tensor_descriptor(
            b_lds_block_desc_0,
            make_tuple(make_pass_through_transform(number<kNPerBlock>{}),
                       make_merge_transform(
                           make_tuple(number<kKPerBlock / KPack>{}, number<KPack>{}))),
            make_tuple(sequence<1>{}, sequence<0, 2>{}),
            make_tuple(sequence<0>{}, sequence<1>{}));

        return b_lds_block_desc;
    }

    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm()
    {
        using AccDataType = float;
        using BlockWarps  = typename Problem::BlockGemmShape::BlockWarps;
        using WarpTile    = typename Problem::BlockGemmShape::WarpTile;
        using WarpGemm    = WarpGemmMfmaDispatcher<typename Problem::ADataType,
                                                   typename Problem::BDataType,
                                                   AccDataType,
                                                   WarpTile::at(I0),
                                                   WarpTile::at(I1),
                                                   WarpTile::at(I2),
                                                   Problem::TransposeC>;

        using BlockGemmPolicy = BlockGemmARegBRegCRegV1CustomPolicy<typename Problem::ADataType,
                                                                    typename Problem::BDataType,
                                                                    typename Problem::CDataType,
                                                                    BlockWarps,
                                                                    WarpGemm>;
        return BlockGemmARegBRegCRegV1<Problem, BlockGemmPolicy>{};
    }
};

} // namespace ck_tile
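MakeALdsBlockDescriptor above stores the A tile in LDS as a (K/KPack, M, KPack) box with strides (M*KPack, KPack, 1) and then exposes it as an (M, K) view by merging the outer-K and KPack dimensions. Assuming the merge behaves as just described, the resulting address mapping can be checked with a small standalone sketch (names and tile sizes here are illustrative only):

#include <cstddef>

// Linear LDS offset of logical element (m, k) for a tile stored as
// (K/KPack, M, KPack) with strides (M*KPack, KPack, 1).
constexpr std::size_t a_lds_offset(std::size_t m, std::size_t k,
                                   std::size_t MPerBlock, std::size_t KPack)
{
    return (k / KPack) * (MPerBlock * KPack) + m * KPack + (k % KPack);
}

// Consecutive k within one KPack group stay contiguous for a given row m...
static_assert(a_lds_offset(0, 1, 128, 8) == a_lds_offset(0, 0, 128, 8) + 1, "contiguous KPack");
// ...while crossing a KPack boundary jumps by a whole M*KPack slab.
static_assert(a_lds_offset(0, 8, 128, 8) == 128 * 8, "next K slab");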
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp
View file @ 7572a691
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -7,6 +7,7 @@
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp"
#include "ck_tile/host/concat.hpp"
namespace ck_tile {
...
...
@@ -20,6 +21,8 @@ struct BaseGemmPipelineAgBgCrMem
    using BDataType      = remove_cvref_t<typename Problem::BDataType>;
    using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;

    CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; }

    static constexpr index_t BlockSize = Problem::kBlockSize;
    static constexpr index_t MPerBlock = BlockGemmShape::kM;
    static constexpr index_t NPerBlock = BlockGemmShape::kN;
...
...
@@ -88,7 +91,7 @@ struct BaseGemmPipelineAgBgCrMem
// LocalPreFillStages: 1
// LocalPreFetchStages: 0
// LocalSharedMemoryBuffer: 1
template <typename Problem, typename Policy = GemmPipelineAGmemBGmemCRegV1DefaultPolicy>
template <typename Problem, typename Policy = UniversalGemmPipelineAgBgCrPolicy>
struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
{
    using Base = BaseGemmPipelineAgBgCrMem<Problem>;
...
...
@@ -104,27 +107,40 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
    using CLayout = remove_cvref_t<typename Problem::CLayout>;

    using BlockGemm = remove_cvref_t<decltype(Policy::template GetBlockGemm<Problem>())>;

    using I0 = number<0>;
    using I1 = number<1>;
    using I2 = number<2>;
    using I0 = number<0>;
    using I1 = number<1>;
    using I2 = number<2>;

    static constexpr index_t MPerBlock = BlockGemmShape::kM;
    static constexpr index_t NPerBlock = BlockGemmShape::kN;
    static constexpr index_t KPerBlock = BlockGemmShape::kK;

    static constexpr index_t VectorSizeA = Problem::VectorSizeA;
    static constexpr index_t VectorSizeB = Problem::VectorSizeB;
    static constexpr index_t VectorSizeC = Problem::VectorSizeC;
    static constexpr index_t GetVectorSizeA() { return Policy::template GetVectorSizeA<Problem>(); }
    static constexpr index_t GetVectorSizeB() { return Policy::template GetVectorSizeB<Problem>(); }
    static constexpr index_t GetVectorSizeC() { return Policy::template GetVectorSizeC<Problem>(); }

    static constexpr bool kPadM = Problem::kPadM;
    static constexpr bool kPadN = Problem::kPadN;
    static constexpr bool kPadK = Problem::kPadK;
    static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer;

    // Where is the right place for HasHotLoop and TailNum ???
    static constexpr bool HasHotLoop = Problem::HasHotLoop;
    static constexpr auto TailNum    = Problem::TailNum;
    static constexpr auto Scheduler  = Problem::Scheduler;

    [[nodiscard]] CK_TILE_HOST static const std::string GetName()
    {
        // clang-format off
        return concat('_', "pipeline_AgBgCrMe",
                      concat('x', MPerBlock, NPerBlock, KPerBlock),
                      concat('x', GetVectorSizeA(), GetVectorSizeB(), GetVectorSizeC()),
                      concat('x', kPadM, kPadN, kPadK));
        // clang-format on
    }

    using Base::PrefetchStages;

    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
...
...
@@ -162,11 +178,22 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
                      "A/B Dram block window should have the same data type as appropriate "
                      "([A|B]DataType) defined in Problem definition!");

        static_assert(MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
                          NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
                          KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}],
                      "A/B block window appropriate sizes must be equal to MPerBlock/NPerblock"
                      " or KPerBlock!");
        constexpr bool is_a_col_major = std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>;
        constexpr bool is_b_row_major = std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;

        static_assert(is_a_col_major
                          ? (KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
                             MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}])
                          : (MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
                             KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]),
                      "A block window has incorrect lengths for defined ALayout!");
        static_assert(is_b_row_major
                          ? (KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
                             NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}])
                          : (NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
                             KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]),
                      "B block window has incorrect lengths for defined BLayout!");

        // ------------------------------------------------------------------------------------
        // Definitions of all needed tiles
...
...
@@ -210,25 +237,59 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
        tuple_array<ABlockTile, PrefetchStages> a_block_tiles;
        tuple_array<BBlockTile, PrefetchStages> b_block_tiles;

        using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex;
        using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex;

        constexpr ADramTileWindowStep a_dram_tile_window_step =
            is_a_col_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
        constexpr BDramTileWindowStep b_dram_tile_window_step =
            is_b_row_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);

        // -----------------------------------------------------------------------------------------
        // Gemm pipeline start

        // prefetch
        // global read 0
        Base::GlobalPrefetch(a_block_tiles.get(I0{}), a_copy_dram_window);
        Base::GlobalPrefetch(b_block_tiles.get(I0{}), b_copy_dram_window);
        Base::GlobalPrefetch(a_block_tiles.get(I0{}), a_copy_dram_window, a_dram_tile_window_step);
        Base::GlobalPrefetch(b_block_tiles.get(I0{}), b_copy_dram_window, b_dram_tile_window_step);

        // initialize C
        tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);

        // LDS write 0
        Base::LocalPrefill(a_copy_lds_window, a_block_tiles.get(I0{}), a_element_func);
        Base::LocalPrefill(b_copy_lds_window, b_block_tiles.get(I0{}), b_element_func);
        if constexpr(is_a_col_major)
        {
            auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
                Policy::template MakeShuffledARegTileDistribution<Problem>());
            transpose_tile2d(a_shuffle_tmp, a_block_tiles.get(I0{}));
            Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
        }
        else
        {
            Base::LocalPrefill(a_copy_lds_window, a_block_tiles.get(I0{}), a_element_func);
        }
        if constexpr(is_b_row_major)
        {
            auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
                Policy::template MakeShuffledBRegTileDistribution<Problem>());
            transpose_tile2d(b_shuffle_tmp, b_block_tiles.get(I0{}));
            Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
        }
        else
        {
            Base::LocalPrefill(b_copy_lds_window, b_block_tiles.get(I0{}), b_element_func);
        }

        // Global prefetch [1, PrefetchStages]
        static_for<1, PrefetchStages, 1>{}([&](auto prefetch_idx) {
            Base::GlobalPrefetch(a_block_tiles.get(number<prefetch_idx>{}), a_copy_dram_window);
            Base::GlobalPrefetch(b_block_tiles.get(number<prefetch_idx>{}), b_copy_dram_window);
            Base::GlobalPrefetch(a_block_tiles.get(number<prefetch_idx>{}),
                                 a_copy_dram_window,
                                 a_dram_tile_window_step);
            Base::GlobalPrefetch(b_block_tiles.get(number<prefetch_idx>{}),
                                 b_copy_dram_window,
                                 b_dram_tile_window_step);
        });

        // main body
...
...
@@ -244,19 +305,45 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
            {
                block_sync_lds();

                Base::LocalPrefill(
                    a_copy_lds_window,
                    a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}),
                    a_element_func);
                Base::LocalPrefill(
                    b_copy_lds_window,
                    b_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}),
                    b_element_func);
                if constexpr(is_a_col_major)
                {
                    auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
                        Policy::template MakeShuffledARegTileDistribution<Problem>());
                    transpose_tile2d(
                        a_shuffle_tmp,
                        a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}));
                    Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
                }
                else
                {
                    Base::LocalPrefill(
                        a_copy_lds_window,
                        a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}),
                        a_element_func);
                }
                if constexpr(is_b_row_major)
                {
                    auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
                        Policy::template MakeShuffledBRegTileDistribution<Problem>());
                    transpose_tile2d(
                        b_shuffle_tmp,
                        b_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}));
                    Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
                }
                else
                {
                    Base::LocalPrefill(
                        b_copy_lds_window,
                        b_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}),
                        b_element_func);
                }

                Base::GlobalPrefetch(a_block_tiles.get(number<prefetch_idx>{}),
                                     a_copy_dram_window);
                                     a_copy_dram_window,
                                     a_dram_tile_window_step);
                Base::GlobalPrefetch(b_block_tiles.get(number<prefetch_idx>{}),
                                     b_copy_dram_window);
                                     b_copy_dram_window,
                                     b_dram_tile_window_step);
            });

            i += PrefetchStages;
...
...
@@ -272,12 +359,32 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
                block_sync_lds();

                Base::LocalPrefill(a_copy_lds_window,
                                   a_block_tiles.get(number<prefetch_idx>{}),
                                   a_element_func);
                Base::LocalPrefill(b_copy_lds_window,
                                   b_block_tiles.get(number<prefetch_idx>{}),
                                   b_element_func);
                if constexpr(is_a_col_major)
                {
                    auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
                        Policy::template MakeShuffledARegTileDistribution<Problem>());
                    transpose_tile2d(a_shuffle_tmp, a_block_tiles.get(number<prefetch_idx>{}));
                    Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
                }
                else
                {
                    Base::LocalPrefill(a_copy_lds_window,
                                       a_block_tiles.get(number<prefetch_idx>{}),
                                       a_element_func);
                }
                if constexpr(is_b_row_major)
                {
                    auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
                        Policy::template MakeShuffledBRegTileDistribution<Problem>());
                    transpose_tile2d(b_shuffle_tmp, b_block_tiles.get(number<prefetch_idx>{}));
                    Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
                }
                else
                {
                    Base::LocalPrefill(b_copy_lds_window,
                                       b_block_tiles.get(number<prefetch_idx>{}),
                                       b_element_func);
                }
            });

            block_sync_lds();
...
...
@@ -349,11 +456,22 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
                      "A/B Dram block window should have the same data type as appropriate "
                      "([A|B]DataType) defined in Problem definition!");

        static_assert(MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
                          NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
                          KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}],
                      "A/B block window appropriate sizes must be equal to MPerBlock/NPerblock"
                      " or KPerBlock!");
        constexpr bool is_a_col_major = std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>;
        constexpr bool is_b_row_major = std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;

        static_assert(is_a_col_major
                          ? (KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
                             MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}])
                          : (MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
                             KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]),
                      "A block window has incorrect lengths for defined ALayout!");
        static_assert(is_b_row_major
                          ? (KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
                             NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}])
                          : (NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
                             KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]),
                      "B block window has incorrect lengths for defined BLayout!");

        // ------------------------------------------------------------------------------------
        // Definitions of all needed tiles
...
...
@@ -397,25 +515,58 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
        tuple_array<ABlockTile, PrefetchStages> a_block_tiles;
        tuple_array<BBlockTile, PrefetchStages> b_block_tiles;

        using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex;
        using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex;

        constexpr ADramTileWindowStep a_dram_tile_window_step =
            is_a_col_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
        constexpr BDramTileWindowStep b_dram_tile_window_step =
            is_b_row_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);

        // -----------------------------------------------------------------------------------------
        // Gemm pipeline start

        // prefetch
        // global read 0
        Base::GlobalPrefetch(a_block_tiles.get(I0{}), a_copy_dram_window);
        Base::GlobalPrefetch(b_block_tiles.get(I0{}), b_copy_dram_window);
        Base::GlobalPrefetch(a_block_tiles.get(I0{}), a_copy_dram_window, a_dram_tile_window_step);
        Base::GlobalPrefetch(b_block_tiles.get(I0{}), b_copy_dram_window, b_dram_tile_window_step);

        // initialize C
        tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);

        // LDS write 0
        Base::LocalPrefill(a_copy_lds_window, a_block_tiles.get(I0{}), a_element_func);
        Base::LocalPrefill(b_copy_lds_window, b_block_tiles.get(I0{}), b_element_func);
        if constexpr(is_a_col_major)
        {
            auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
                Policy::template MakeShuffledARegTileDistribution<Problem>());
            transpose_tile2d(a_shuffle_tmp, a_block_tiles.get(I0{}));
            Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
        }
        else
        {
            Base::LocalPrefill(a_copy_lds_window, a_block_tiles.get(I0{}), a_element_func);
        }
        if constexpr(is_b_row_major)
        {
            auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
                Policy::template MakeShuffledBRegTileDistribution<Problem>());
            transpose_tile2d(b_shuffle_tmp, b_block_tiles.get(I0{}));
            Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
        }
        else
        {
            Base::LocalPrefill(b_copy_lds_window, b_block_tiles.get(I0{}), b_element_func);
        }

        // Global prefetch [1, PrefetchStages]
        static_for<1, PrefetchStages, 1>{}([&](auto prefetch_idx) {
            Base::GlobalPrefetch(a_block_tiles.get(number<prefetch_idx>{}), a_copy_dram_window);
            Base::GlobalPrefetch(b_block_tiles.get(number<prefetch_idx>{}), b_copy_dram_window);
            Base::GlobalPrefetch(a_block_tiles.get(number<prefetch_idx>{}),
                                 a_copy_dram_window,
                                 a_dram_tile_window_step);
            Base::GlobalPrefetch(b_block_tiles.get(number<prefetch_idx>{}),
                                 b_copy_dram_window,
                                 b_dram_tile_window_step);
        });

        // main body
...
...
@@ -429,19 +580,45 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
block_gemm
(
c_block_tile
,
a_lds_gemm_window
,
b_lds_gemm_window
);
// no second block_sync_lds because it's interwave
-                Base::LocalPrefill(a_copy_lds_window,
-                                   a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}),
-                                   a_element_func);
-                Base::LocalPrefill(b_copy_lds_window,
-                                   b_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}),
-                                   b_element_func);
+                if constexpr(is_a_col_major)
+                {
+                    auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
+                        Policy::template MakeShuffledARegTileDistribution<Problem>());
+                    transpose_tile2d(a_shuffle_tmp,
+                                     a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}));
+                    Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
+                }
+                else
+                {
+                    Base::LocalPrefill(a_copy_lds_window,
+                                       a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}),
+                                       a_element_func);
+                }
+                if constexpr(is_b_row_major)
+                {
+                    auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
+                        Policy::template MakeShuffledBRegTileDistribution<Problem>());
+                    transpose_tile2d(b_shuffle_tmp,
+                                     b_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}));
+                    Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
+                }
+                else
+                {
+                    Base::LocalPrefill(b_copy_lds_window,
+                                       b_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}),
+                                       b_element_func);
+                }

                Base::GlobalPrefetch(a_block_tiles.get(number<prefetch_idx>{}),
-                                     a_copy_dram_window);
+                                     a_copy_dram_window,
+                                     a_dram_tile_window_step);
                Base::GlobalPrefetch(b_block_tiles.get(number<prefetch_idx>{}),
-                                     b_copy_dram_window);
+                                     b_copy_dram_window,
+                                     b_dram_tile_window_step);
            });
            i += PrefetchStages;
...
...
@@ -454,12 +631,32 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
            block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
// no second block_sync_lds because it's interwave
-            Base::LocalPrefill(a_copy_lds_window, a_block_tiles.get(number<prefetch_idx>{}), a_element_func);
-            Base::LocalPrefill(b_copy_lds_window, b_block_tiles.get(number<prefetch_idx>{}), b_element_func);
+            if constexpr(is_a_col_major)
+            {
+                auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
+                    Policy::template MakeShuffledARegTileDistribution<Problem>());
+                transpose_tile2d(a_shuffle_tmp, a_block_tiles.get(number<prefetch_idx>{}));
+                Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
+            }
+            else
+            {
+                Base::LocalPrefill(a_copy_lds_window, a_block_tiles.get(number<prefetch_idx>{}), a_element_func);
+            }
+            if constexpr(is_b_row_major)
+            {
+                auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
+                    Policy::template MakeShuffledBRegTileDistribution<Problem>());
+                transpose_tile2d(b_shuffle_tmp, b_block_tiles.get(number<prefetch_idx>{}));
+                Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
+            }
+            else
+            {
+                Base::LocalPrefill(b_copy_lds_window, b_block_tiles.get(number<prefetch_idx>{}), b_element_func);
+            }
        });
        block_sync_lds();
...
...
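For reference, the hot loop above rotates through PrefetchStages register tiles with modular indexing: stage (i + 1) % PrefetchStages is refilled from global memory while stage i % PrefetchStages is being written to LDS and consumed. A minimal standalone C++ sketch of that ring-buffer rotation (array size, loop count and names are illustrative, not taken from ck_tile):

#include <array>
#include <cstdio>

constexpr int PrefetchStages = 2;

int main()
{
    std::array<int, PrefetchStages> tile_buffer{}; // stands in for a_block_tiles / b_block_tiles
    int next_value = 0;

    // fill all stages once (the "global read 0 .. PrefetchStages-1" phase)
    for(int s = 0; s < PrefetchStages; ++s)
        tile_buffer[s] = next_value++;

    // main body: consume stage i, refill the slot that will be needed next
    for(int i = 0; i < 8; ++i)
    {
        const int consume = i % PrefetchStages;
        const int refill  = (i + 1) % PrefetchStages;
        std::printf("iter %d: consume slot %d (value %d), refill slot %d\n",
                    i, consume, tile_buffer[consume], refill);
        tile_buffer[refill] = next_value++;
    }
    return 0;
}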
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp
View file @ 7572a691
// SPDX-License-Identifier: MIT
-// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <ostream>
#include <sstream>
#include "ck_tile/core.hpp"
...
...
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp
View file @ 7572a691
// SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
#include "ck_tile/host/concat.hpp"
namespace ck_tile {
...
...
@@ -23,31 +24,39 @@ struct GemmPipelineAGmemBGmemCRegV1
    using BLayout = remove_cvref_t<typename Problem::BLayout>;
    using CLayout = remove_cvref_t<typename Problem::CLayout>;
+    using BlockGemm = remove_cvref_t<decltype(Policy::template GetBlockGemm<Problem>())>;

    static constexpr index_t BlockSize  = Problem::kBlockSize;
    static constexpr index_t kMPerBlock = BlockGemmShape::kM;
    static constexpr index_t kNPerBlock = BlockGemmShape::kN;
    static constexpr index_t kKPerBlock = BlockGemmShape::kK;

-    static constexpr index_t VectorSizeA = Problem::VectorSizeA;
-    static constexpr index_t VectorSizeB = Problem::VectorSizeB;
-    static constexpr index_t VectorSizeC = Problem::VectorSizeC;
+    static constexpr index_t GetVectorSizeA() { return Problem::VectorSizeA; }
+    static constexpr index_t GetVectorSizeB() { return Problem::VectorSizeB; }
+    static constexpr index_t GetVectorSizeC() { return Problem::VectorSizeC; }

    static constexpr bool kPadM = Problem::kPadM;
    static constexpr bool kPadN = Problem::kPadN;
    static constexpr bool kPadK = Problem::kPadK;

-    CK_TILE_HOST_DEVICE static constexpr index_t GetStaticLdsSize()
+    static constexpr index_t kLdsAlignmentInBytes = 16;
+
+    [[nodiscard]] CK_TILE_HOST static const std::string GetName()
    {
-        return integer_divide_ceil(sizeof(ADataType) *
-                   Policy::template MakeALdsBlockDescriptor<Problem>().get_element_space_size(), 16) * 16 +
-               sizeof(BDataType) * Policy::template MakeBLdsBlockDescriptor<Problem>().get_element_space_size();
        // clang-format off
+        return concat('_', "pipeline_AGmemBGmemCRegV1",
+                      concat('x', kMPerBlock, kNPerBlock, kKPerBlock, BlockSize),
+                      concat('x', GetVectorSizeA(), GetVectorSizeB(), GetVectorSizeC()),
+                      concat('x', kPadM, kPadN, kPadK));
        // clang-format on
    }
    // For the basic gemm pipelien DoubleSmemBuffer set to be false naturally.
+    static constexpr bool DoubleSmemBuffer = false;
+    CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; }

    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
    {
        return Policy::template GetSmemSize<Problem>();
...
...
@@ -82,8 +91,9 @@ struct GemmPipelineAGmemBGmemCRegV1
        auto a_lds_block = make_tensor_view<address_space_enum::lds>(p_a_lds, a_lds_block_desc);

        constexpr index_t a_lds_block_space_size_aligned =
-            integer_divide_ceil(sizeof(ADataType) * a_lds_block_desc.get_element_space_size(), 16) * 16;
+            integer_divide_ceil(sizeof(ADataType) * a_lds_block_desc.get_element_space_size(),
+                                kLdsAlignmentInBytes) * kLdsAlignmentInBytes;

        // B tile in LDS
        BDataType* p_b_lds = static_cast<BDataType*>(
...
...
@@ -124,7 +134,7 @@ struct GemmPipelineAGmemBGmemCRegV1
            b_lds_block, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {0, 0});

        // Block GEMM
-        auto block_gemm = Policy::template GetBlockGemm<Problem>();
+        auto block_gemm = BlockGemm();

        // Acc register tile
        auto c_block_tile = decltype(block_gemm(a_lds_gemm_window, b_lds_gemm_window)){};
...
...
@@ -146,7 +156,7 @@ struct GemmPipelineAGmemBGmemCRegV1
        if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>)
        {
            auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
-                Policy::template MakeShuffledARegBlockDescriptor<Problem>());
+                Policy::template MakeShuffledARegBlockDistribution<Problem>());
            shuffle_tile(a_shuffle_tmp, a_block_tile);
            const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_shuffle_tmp);
            store_tile(a_copy_lds_window, a_block_tile_tmp);
...
...
@@ -160,7 +170,7 @@ struct GemmPipelineAGmemBGmemCRegV1
        if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
        {
            auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
-                Policy::template MakeShuffledBRegBlockDescriptor<Problem>());
+                Policy::template MakeShuffledBRegBlockDistribution<Problem>());
            shuffle_tile(b_shuffle_tmp, b_block_tile);
            const auto b_block_tile_tmp = tile_elementwise_in(b_element_func, b_shuffle_tmp);
            store_tile(b_copy_lds_window, b_block_tile_tmp);
...
...
@@ -197,7 +207,7 @@ struct GemmPipelineAGmemBGmemCRegV1
            if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
            {
                auto b_shuffle_tmp_loop = make_static_distributed_tensor<BDataType>(
-                    Policy::template MakeShuffledBRegBlockDescriptor<Problem>());
+                    Policy::template MakeShuffledBRegBlockDistribution<Problem>());
                shuffle_tile(b_shuffle_tmp_loop, b_block_tile);
                store_tile(b_copy_lds_window, tile_elementwise_in(b_element_func, b_shuffle_tmp_loop));
...
...
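The pipeline above rounds the A-tile LDS footprint up to kLdsAlignmentInBytes (16) with a ceil-divide-then-multiply. A small standalone sketch of that alignment arithmetic (the helper names mirror the idea only; values are made up):

#include <cstddef>
#include <cstdio>

// Round a byte count up to a multiple of the LDS alignment.
constexpr std::size_t integer_divide_ceil(std::size_t a, std::size_t b) { return (a + b - 1) / b; }
constexpr std::size_t align_up(std::size_t bytes, std::size_t alignment = 16)
{
    return integer_divide_ceil(bytes, alignment) * alignment;
}

static_assert(align_up(1000) == 1008); // not a multiple of 16 -> padded up
static_assert(align_up(1024) == 1024); // already aligned -> unchanged

int main() { std::printf("%zu\n", align_up(2 * 128 * 32)); }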
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp
View file @ 7572a691
// SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -12,38 +12,10 @@ namespace ck_tile {
// Default policy class should not be templated, put template on member functions instead
struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
{
-    static constexpr auto I0 = number<0>{};
-    static constexpr auto I1 = number<1>{};
-    static constexpr auto I2 = number<2>{};
#if 0
// 2d
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor()
{
using namespace ck_tile;
constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
constexpr auto a_lds_block_desc =
make_naive_tensor_descriptor_packed(make_tuple(kMPerBlock, kKPerBlock), number<32>{});
return a_lds_block_desc;
}
// 2d
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor()
{
using namespace ck_tile;
constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
constexpr auto b_lds_block_desc =
make_naive_tensor_descriptor_packed(make_tuple(kNPerBlock, kKPerBlock), number<32>{});
return b_lds_block_desc;
}
#elif 1
// 3d + padding
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor()
...
...
@@ -53,7 +25,6 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
        constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
        constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
// TODO: this 8 is AK1! should be a policy parameter!
        constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
            make_tuple(number<kKPerBlock / 8>{}, number<kMPerBlock>{}, number<8>{}),
            make_tuple(number<(kMPerBlock + 1) * 8>{}, number<8>{}, number<1>{}),
...
...
@@ -114,8 +85,7 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
    {
        constexpr index_t smem_size_a = GetSmemSizeA<Problem>();
        constexpr index_t smem_size_b = GetSmemSizeB<Problem>();
-        index_t smem_size = 0;
-        smem_size += smem_size_a + smem_size_b;
+        constexpr index_t smem_size = smem_size_a + smem_size_b;
        return smem_size;
    }
...
...
@@ -123,88 +93,15 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto GetSmemPackA()
    {
-        using ADataType = remove_cvref_t<typename Problem::ADataType>;
-        return Problem::VectorLoadSize / sizeof(ADataType);
+        return Problem::VectorLoadSize;
    }

    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto GetSmemPackB()
    {
-        using BDataType = remove_cvref_t<typename Problem::BDataType>;
-        return Problem::VectorLoadSize / sizeof(BDataType);
-    }
-#elif 1
-    // fake XOR
-    template <typename Problem>
-    CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor()
-    {
-        using namespace ck_tile;
-        using ADataType = remove_cvref_t<typename Problem::ADataType>;
-        constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
-        constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
-        constexpr auto a_lds_block_desc_d1_d2_d3 = make_naive_tensor_descriptor_packed(
-            make_tuple(number<kMPerBlock / 2>{}, number<2>{}, number<kKPerBlock>{}), number<kKPerBlock>{});
-        constexpr index_t kK1 = 16 / sizeof(ADataType);
-        constexpr auto a_lds_block_desc_d4_d5_d6 = transform_tensor_descriptor(
-            a_lds_block_desc_d1_d2_d3,
-            make_tuple(make_xor_transform(make_tuple(number<kMPerBlock / 2>{}, number<kKPerBlock>{}), kK1),
-                       make_pass_through_transform(2)),
-            make_tuple(sequence<0, 2>{}, sequence<1>{}),
-            make_tuple(sequence<0, 2>{}, sequence<1>{}));
-        constexpr auto a_lds_block_desc_m_k = transform_tensor_descriptor(
-            a_lds_block_desc_d4_d5_d6,
-            make_tuple(make_merge_transform(make_tuple(number<kMPerBlock / 2>{}, number<2>{})),
-                       make_pass_through_transform(kKPerBlock)),
-            make_tuple(sequence<0, 1>{}, sequence<2>{}),
-            make_tuple(sequence<0>{}, sequence<1>{}));
-        return a_lds_block_desc_m_k;
+        return Problem::VectorLoadSize;
    }
-    // fake XOR
-    template <typename Problem>
-    CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor()
-    {
-        using namespace ck_tile;
-        using BDataType = remove_cvref_t<typename Problem::BDataType>;
-        constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN;
-        constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
-        constexpr auto b_lds_block_desc_d1_d2_d3 = make_naive_tensor_descriptor_packed(
-            make_tuple(number<kNPerBlock / 2>{}, number<2>{}, number<kKPerBlock>{}), number<kKPerBlock>{});
-        constexpr index_t kK1 = 16 / sizeof(BDataType);
-        constexpr auto b_lds_block_desc_d4_d5_d6 = transform_tensor_descriptor(
-            b_lds_block_desc_d1_d2_d3,
-            make_tuple(make_xor_transform(make_tuple(number<kNPerBlock / 2>{}, number<kKPerBlock>{}), kK1),
-                       make_pass_through_transform(2)),
-            make_tuple(sequence<0, 2>{}, sequence<1>{}),
-            make_tuple(sequence<0, 2>{}, sequence<1>{}));
-        constexpr auto b_lds_block_desc_n_k = transform_tensor_descriptor(
-            b_lds_block_desc_d4_d5_d6,
-            make_tuple(make_merge_transform(make_tuple(number<kNPerBlock / 2>{}, number<2>{})),
-                       make_pass_through_transform(kKPerBlock)),
-            make_tuple(sequence<0, 1>{}, sequence<2>{}),
-            make_tuple(sequence<0>{}, sequence<1>{}));
-        return b_lds_block_desc_n_k;
-    }
#endif

    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto MakeADramTileDistribution()
    {
...
...
@@ -269,7 +166,6 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
        static_assert(M0 * M1 * M2 == MPerBlock,
                      "Incorrect M0, M2, M1 configuration! "
                      "M0, M1, M2 must cover whole MPerBlock!");
        return make_static_tile_distribution(
            tile_distribution_encoding<sequence<1>,
                tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
...
...
@@ -390,7 +286,7 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
    }

    template <typename Problem>
-    CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledBRegBlockDescriptor()
+    CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledBRegBlockDistribution()
    {
        using BLayout   = remove_cvref_t<typename Problem::BLayout>;
        using BDataType = remove_cvref_t<typename Problem::BDataType>;
...
...
@@ -438,11 +334,11 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
    }

    template <typename Problem>
-    CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledARegBlockDescriptor()
+    CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledARegBlockDistribution()
    {
        using ALayout   = remove_cvref_t<typename Problem::ALayout>;
        using ADataType = remove_cvref_t<typename Problem::ADataType>;
-        static_assert(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::RowMajor>);
+        static_assert(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>);
        constexpr index_t kBlockSize = Problem::kBlockSize;
        constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
        constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
...
...
@@ -488,11 +384,6 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm()
    {
-        constexpr bool TransposeC = false;
+        constexpr auto I0 = number<0>{};
+        constexpr auto I1 = number<1>{};
+        constexpr auto I2 = number<2>{};
        using AccDataType = float;
        using BlockWarps  = typename Problem::BlockGemmShape::BlockWarps;
        using WarpTile    = typename Problem::BlockGemmShape::WarpTile;
...
...
@@ -502,7 +393,7 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
                                                WarpTile::at(I0), WarpTile::at(I1), WarpTile::at(I2),
-                                               TransposeC>;
+                                               Problem::TransposeC>;
        using BlockGemmPolicy = BlockGemmASmemBSmemCRegV1CustomPolicy<typename Problem::ADataType,
                                                                      typename Problem::BDataType,
                                                                      typename Problem::CDataType,
...
...
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp
View file @ 7572a691
// SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp"
#include "ck_tile/host/concat.hpp"
namespace ck_tile {
...
...
@@ -25,6 +26,15 @@ struct GemmPipelineAGmemBGmemCRegV2
    static constexpr index_t kNPerBlock = BlockGemmShape::kN;
    static constexpr index_t kKPerBlock = BlockGemmShape::kK;

+    [[nodiscard]] CK_TILE_HOST static const std::string GetName()
+    {
+        // clang-format off
+        return concat('_', "pipeline_AGmemBGmemCRegV2",
+                      concat('x', kMPerBlock, kNPerBlock, kKPerBlock, kBlockSize));
+        // clang-format on
+    }
+
+    CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; }

    CK_TILE_HOST_DEVICE static constexpr index_t GetStaticLdsSize()
    {
        return integer_divide_ceil(
...
...
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp
View file @ 7572a691
// SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/host/concat.hpp"
namespace ck_tile {
...
...
@@ -11,10 +13,10 @@ template <typename ADataType_,
          typename BDataType_,
          typename CDataType_,
          typename BlockGemmShape_,
-          typename TileGemmTraits_>
+          typename Traits_>
struct GemmPipelineProblemBase
{
-    using GemmTraits = remove_cvref_t<TileGemmTraits_>;
+    using Traits = remove_cvref_t<Traits_>;

    using ADataType = remove_cvref_t<ADataType_>;
    using BDataType = remove_cvref_t<BDataType_>;
...
...
@@ -22,18 +24,32 @@ struct GemmPipelineProblemBase
    using BlockGemmShape = remove_cvref_t<BlockGemmShape_>;

-    using ALayout = remove_cvref_t<typename GemmTraits::ALayout>;
-    using BLayout = remove_cvref_t<typename GemmTraits::BLayout>;
-    using CLayout = remove_cvref_t<typename GemmTraits::CLayout>;
+    using ALayout = remove_cvref_t<typename Traits::ALayout>;
+    using BLayout = remove_cvref_t<typename Traits::BLayout>;
+    using CLayout = remove_cvref_t<typename Traits::CLayout>;

-    static constexpr index_t VectorLoadSize = GemmTraits::_VectorSize;
+    static constexpr index_t kBlockSize = BlockGemmShape::NumWarps * get_warp_size();
+    static constexpr bool TransposeC = Traits::TransposeC;

-    static constexpr bool kPadM = GemmTraits::kPadM;
-    static constexpr bool kPadN = GemmTraits::kPadN;
-    static constexpr bool kPadK = GemmTraits::kPadK;
-    static constexpr index_t kBlockSize = BlockGemmShape::NumWarps * get_warp_size();
-    static constexpr auto Scheduler = GemmPipelineScheduler::Default;
+    static constexpr bool kPadM = Traits::kPadM;
+    static constexpr bool kPadN = Traits::kPadN;
+    static constexpr bool kPadK = Traits::kPadK;
+    static constexpr bool DoubleSmemBuffer = Traits::DoubleSmemBuffer;
+    static constexpr auto Scheduler = GemmPipelineScheduler::Default;
+    static constexpr index_t VectorLoadSize = Traits::_VectorSize;

+    [[nodiscard]] CK_TILE_HOST static const std::string GetName()
+    {
+        // clang-format off
+        return concat('_', "gemm_problem",
+                      concat('x', VectorLoadSize, kBlockSize),
+                      concat('x', kPadM, kPadN, kPadK), Scheduler);
+        // clang-format on
+    }

    CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentA()
    {
...
...
@@ -110,7 +126,6 @@ struct GemmPipelineProblemBase
            return kPadK ? 1 : GetAlignmentB();
        }
    }();

-    static constexpr index_t VectorSizeC = []() {
        if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
        {
...
...
@@ -128,27 +143,45 @@ template <typename ADataType_,
          typename BDataType_,
          typename CDataType_,
          typename BlockGemmShape_,
-          typename TileGemmTraits_>
+          typename Traits_>
using GemmPipelineProblem =
-    GemmPipelineProblemBase<ADataType_, BDataType_, CDataType_, BlockGemmShape_, TileGemmTraits_>;
+    GemmPipelineProblemBase<ADataType_, BDataType_, CDataType_, BlockGemmShape_, Traits_>;

template <typename ADataType_,
          typename BDataType_,
          typename CDataType_,
          typename BlockGemmShape_,
-          typename TileGemmTraits_,
+          typename Traits_,
          GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave,
          bool HasHotLoop_ = true,
          TailNumber TailNum_ = TailNumber::Full>
-struct UniversalGemmPipelineProblem
-    : public GemmPipelineProblemBase<ADataType_, BDataType_, CDataType_, BlockGemmShape_, TileGemmTraits_>
+struct UniversalGemmPipelineProblem
{
+    using Traits = remove_cvref_t<Traits_>;
+    using ADataType = remove_cvref_t<ADataType_>;
+    using BDataType = remove_cvref_t<BDataType_>;
+    using CDataType = remove_cvref_t<CDataType_>;
+    using BlockGemmShape = remove_cvref_t<BlockGemmShape_>;
+
+    using ALayout = remove_cvref_t<typename Traits::ALayout>;
+    using BLayout = remove_cvref_t<typename Traits::BLayout>;
+    using CLayout = remove_cvref_t<typename Traits::CLayout>;
+
+    static constexpr index_t kBlockSize = BlockGemmShape::NumWarps * get_warp_size();
+    static constexpr bool kPadM = Traits::kPadM;
+    static constexpr bool kPadN = Traits::kPadN;
+    static constexpr bool kPadK = Traits::kPadK;
+    static constexpr bool DoubleSmemBuffer = Traits::DoubleSmemBuffer;
+
    static constexpr auto Scheduler  = Scheduler_;
    static constexpr auto HasHotLoop = HasHotLoop_;
    static constexpr auto TailNum    = TailNum_;
+    static constexpr bool TransposeC = Traits::TransposeC;
};
} // namespace ck_tile
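The problem descriptors above derive the workgroup size as warps-per-block times the wavefront size. A standalone constexpr illustration (the 4 and 64 are example values, not taken from the diff):

#include <cstdio>

// Threads per workgroup = warps per block * wavefront size.
constexpr int get_warp_size() { return 64; } // wavefront size assumed for gfx9-class GPUs
constexpr int NumWarps   = 4;                // illustrative warp count
constexpr int kBlockSize = NumWarps * get_warp_size();
static_assert(kBlockSize == 256);

int main() { std::printf("kBlockSize = %d\n", kBlockSize); }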
include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp
View file @ 7572a691
// SPDX-License-Identifier: MIT
-// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
namespace ck_tile {
// UniversalGemm Policy
-struct UniversalGemmPipelineAgBgCrPolicy
+template <typename Derived>
+struct UniversalGemmBasePolicy
{
    static constexpr auto I0 = number<0>{};
    static constexpr auto I1 = number<1>{};
    static constexpr auto I2 = number<2>{};

-    static constexpr bool TransposeC = true;
-    template <typename Problem, typename DataType, index_t MNPerBlock>
-    CK_TILE_HOST_DEVICE static constexpr auto GetVectorLoadSize()
+    static constexpr auto ATileAccessPattern = tile_distribution_pattern::thread_raked;
+    static constexpr auto BTileAccessPattern = tile_distribution_pattern::thread_raked;
/**
* @brief Get the maximum global memory vector load size.
*
* @tparam Problem The UniversalGemmPipelineProblem object.
* @tparam DataType The tensor data type we're considering.
* @tparam MNPerBlock The MPerBlock or NPerBlock value depending on tensor (A/B).
* @tparam XPerTile The contiguous Tile dimension size.
* @return Maximum DRAM vector load size.
*/
+    template <typename Problem, typename DataType, index_t MNPerBlock, index_t XPerTile>
+    CK_TILE_HOST_DEVICE static constexpr auto GetGlobalVectorLoadSize()
    {
        constexpr index_t BlockSize = Problem::kBlockSize;
        constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
        constexpr index_t elements_per_thread = MNPerBlock * KPerBlock / BlockSize;

-        if constexpr(elements_per_thread % (16 / sizeof(DataType)) == 0)
        // Assume DataType is even!
+        if constexpr(XPerTile % (16 / sizeof(DataType)) == 0 &&
+                     elements_per_thread % (16 / sizeof(DataType)) == 0)
        {
            return (16 / sizeof(DataType));
        }
-        else if constexpr(elements_per_thread % (8 / sizeof(DataType)) == 0)
+        else if constexpr(XPerTile % (8 / sizeof(DataType)) == 0 &&
+                          elements_per_thread % (8 / sizeof(DataType)) == 0)
        {
            return (8 / sizeof(DataType));
        }
-        else if constexpr(elements_per_thread % (4 / sizeof(DataType)) == 0 && sizeof(DataType) >= 4)
+        else if constexpr(sizeof(DataType) >= 4 && XPerTile % (4 / sizeof(DataType)) == 0 &&
+                          elements_per_thread % (4 / sizeof(DataType)) == 0)
        {
            return (4 / sizeof(DataType));
        }
-        else if constexpr(elements_per_thread % (2 / sizeof(DataType)) == 0 && sizeof(DataType) >= 2)
+        else if constexpr(sizeof(DataType) >= 2 && XPerTile % (2 / sizeof(DataType)) == 0 &&
+                          elements_per_thread % (2 / sizeof(DataType)) == 0)
        {
            return (2 / sizeof(DataType));
        }
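The selection above picks the widest global load that still divides both the contiguous tile extent and the per-thread element count. A simplified standalone C++ sketch of that widest-fit rule (runtime parameters stand in for the compile-time template values; the name pick_vector_load_size is illustrative):

#include <cstdio>

// Return the largest number of elements per load (16, 8, 4 or 2 bytes wide)
// that evenly divides both the contiguous extent and the per-thread work.
template <typename DataType>
constexpr int pick_vector_load_size(int x_per_tile, int elements_per_thread)
{
    constexpr int elem = static_cast<int>(sizeof(DataType));
    for(int bytes = 16; bytes >= 2; bytes /= 2)
    {
        const int lanes = bytes / elem;
        if(lanes >= 1 && x_per_tile % lanes == 0 && elements_per_thread % lanes == 0)
            return lanes;
    }
    return 1;
}

static_assert(pick_vector_load_size<unsigned short>(64, 64) == 8); // fp16: 16 B / 2 B per element
static_assert(pick_vector_load_size<float>(6, 6) == 2);            // fp32: falls back to 2 elements

int main() { std::printf("%d\n", pick_vector_load_size<float>(64, 32)); }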
...
...
@@ -50,119 +63,246 @@ struct UniversalGemmPipelineAgBgCrPolicy
    }

    template <typename Problem>
-    CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor()
+    CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeA()
    {
-        using ADataType = remove_cvref_t<typename Problem::ADataType>;
+        using ALayout   = remove_cvref_t<typename Problem::ALayout>;
+        using ADataType = remove_cvref_t<typename Problem::ADataType>;
        constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
        constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
-        constexpr index_t KPack = GetVectorLoadSize<Problem, ADataType, MPerBlock>();
-        constexpr auto DataTypeSize = sizeof(ADataType);
-        constexpr auto MLdsLayer = (32 * 4 / KPerBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / KPerBlock / DataTypeSize);
+        if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::RowMajor>)
+        {
+            return GetGlobalVectorLoadSize<Problem, ADataType, MPerBlock, KPerBlock>();
+        }
+        else
+        {
+            return GetGlobalVectorLoadSize<Problem, ADataType, MPerBlock, MPerBlock>();
+        }
+    }

-        constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
-            make_tuple(number<KPerBlock / KPack * MLdsLayer>{}, number<MPerBlock / MLdsLayer>{}, number<KPack>{}),
-            make_tuple(number<KPack>{}, number<KPerBlock * MLdsLayer>{}, number<1>{}),
-            number<KPack>{}, number<1>{});
+    template <typename Problem>
+    CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeB()
+    {
+        using BLayout   = remove_cvref_t<typename Problem::BLayout>;
+        using BDataType = remove_cvref_t<typename Problem::BDataType>;
+        constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
+        constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
-        constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
-            a_lds_block_desc_0,
-            make_tuple(make_xor_transform(make_tuple(number<MPerBlock / MLdsLayer>{},
-                                                     number<KPerBlock / KPack * MLdsLayer>{})),
-                       make_pass_through_transform(number<KPack>{})),
-            make_tuple(sequence<1, 0>{}, sequence<2>{}), make_tuple(sequence<1, 0>{}, sequence<2>{}));
+        if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
+        {
+            return GetGlobalVectorLoadSize<Problem, BDataType, NPerBlock, NPerBlock>();
+        }
+        else
+        {
+            return GetGlobalVectorLoadSize<Problem, BDataType, NPerBlock, KPerBlock>();
+        }
+    }
-        constexpr auto a_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor(
-            a_lds_block_desc_permuted,
-            make_tuple(make_unmerge_transform(make_tuple(number<KPerBlock / KPack>{}, number<MLdsLayer>{})),
-                       make_pass_through_transform(number<MPerBlock / MLdsLayer>{}),
-                       make_pass_through_transform(number<KPack>{})),
-            make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
-            make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{}));
/**
* @brief Get the vector store size for C tensor.
*
* @tparam Problem - Gemm pipeline problem class.
*
* @note The vector store size for output C tensor would depend on multiple factors
* like its data layout and warp gemm C transposition. In general it would
* be the number of consecutive elements in contiguous C dimension hold by
* single thread.
*
* @return The vector store size for C tensor.
*/
+    template <typename Problem>
+    CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeC()
+    {
+        using BlockGemm = remove_cvref_t<decltype(Derived::template GetBlockGemm<Problem>())>;
+        using WG        = typename BlockGemm::WarpGemm;
-        constexpr auto a_lds_block_desc = transform_tensor_descriptor(
-            a_lds_block_desc_xk0_mnldslayer_mn_xk1,
-            make_tuple(make_merge_transform_v3_division_mod(
-                           make_tuple(number<MPerBlock / MLdsLayer>{}, number<MLdsLayer>{})),
-                       make_merge_transform_v3_division_mod(
-                           make_tuple(number<KPerBlock / KPack>{}, number<KPack>{}))),
-            make_tuple(sequence<1, 2>{}, sequence<0, 3>{}),
-            make_tuple(sequence<0>{}, sequence<1>{}));
+        constexpr bool TransposeC = Problem::TransposeC;
+        using CLayout   = typename Problem::CLayout;
+        using CWarpDstr = typename WG::CWarpDstr;
-        return a_lds_block_desc;
// N is contiguous dimension
+        if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
+        {
+            if constexpr(TransposeC)
+            {
// In this case each thread has multiple consecutive elements in
// N dimension, however consecutive threads' elements have stride.
+                constexpr index_t NDimY = CWarpDstr::NDimY;
+                constexpr auto c_warp_y_lengths = CWarpDstr{}.get_ys_to_d_descriptor().get_lengths();
+                static_assert(WG::WarpGemmAttribute::Impl::kCM1PerLane ==
+                              c_warp_y_lengths.get(number<NDimY - 1>{}));
+                return c_warp_y_lengths.get(number<NDimY - 1>{});
+            }
+            else
+            {
// In this case each thread has just a single item in Ndim
+                return WG::WarpGemmAttribute::Impl::kCNLane / WG::kN;
+            }
+        }
// M is contiguous dimension
+        else if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::ColumnMajor>)
+        {
+            if constexpr(TransposeC)
+            {
// In this case each thread has just a single item in Mdim
+                return WG::WarpGemmAttribute::Impl::kCNLane / WG::kN;
+            }
+            else
+            {
// In this case each thread has multiple consecutive elements in
// M dimension, however consecutive threads' elements have stride.
+                constexpr index_t NDimY = CWarpDstr::NDimY;
+                constexpr auto c_warp_y_lengths = CWarpDstr{}.get_ys_to_d_descriptor().get_lengths();
+                static_assert(WG::WarpGemmAttribute::Impl::kCM1PerLane ==
+                              c_warp_y_lengths.get(number<NDimY - 1>{}));
+                return c_warp_y_lengths.get(number<NDimY - 1>{});
+            }
+        }
+        else
+        {
+            static_assert(false, "Unsupported CLayout!");
+        }
+    }

+    template <typename Problem>
+    CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC() { return Problem::TransposeC; }

+    template <typename Problem>
+    CK_TILE_HOST_DEVICE static constexpr auto MakeADramTileDistribution()
+    {
-        using BDataType = remove_cvref_t<typename Problem::BDataType>;
+        using ALayout = remove_cvref_t<typename Problem::ALayout>;
-        constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
-        constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
-        constexpr index_t KPack = GetVectorLoadSize<Problem, BDataType, NPerBlock>();
+        constexpr index_t BlockSize   = Problem::kBlockSize;
+        constexpr index_t MPerBlock   = Problem::BlockGemmShape::kM;
+        constexpr index_t KPerBlock   = Problem::BlockGemmShape::kK;
+        constexpr index_t VecLoadSize = GetVectorSizeA<Problem>();
-        constexpr auto DataTypeSize = sizeof(BDataType);
-        constexpr auto NLdsLayer = (32 * 4 / KPerBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / KPerBlock / DataTypeSize);
// Tile: MPerBlock X KPerBlock
+        if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::RowMajor>)
+        {
+            using TileEncodingPattern = TileDistributionEncodingPattern2D<BlockSize, MPerBlock, KPerBlock,
+                                                                          VecLoadSize, ATileAccessPattern>;
+            return TileEncodingPattern::Make2DStaticTileDistribution();
+        }
// Tile: KPerBlock X MPerBlock
+        else
+        {
+            using TileEncodingPattern = TileDistributionEncodingPattern2D<BlockSize, KPerBlock, MPerBlock,
+                                                                          VecLoadSize, ATileAccessPattern>;
+            return TileEncodingPattern::Make2DStaticTileDistribution();
+        }
+    }
-        constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor(
-            make_tuple(number<KPerBlock / KPack * NLdsLayer>{}, number<NPerBlock / NLdsLayer>{}, number<KPack>{}),
-            make_tuple(number<KPack>{}, number<KPerBlock * NLdsLayer>{}, number<1>{}),
-            number<KPack>{}, number<1>{});

+    template <typename Problem>
+    CK_TILE_HOST_DEVICE static constexpr auto MakeBDramTileDistribution()
+    {
+        using BLayout = remove_cvref_t<typename Problem::BLayout>;
-        constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor(
-            b_lds_block_desc_0,
-            make_tuple(make_xor_transform(make_tuple(number<NPerBlock / NLdsLayer>{},
-                                                     number<KPerBlock / KPack * NLdsLayer>{})),
-                       make_pass_through_transform(number<KPack>{})),
-            make_tuple(sequence<1, 0>{}, sequence<2>{}), make_tuple(sequence<1, 0>{}, sequence<2>{}));
+        constexpr index_t BlockSize   = Problem::kBlockSize;
+        constexpr index_t NPerBlock   = Problem::BlockGemmShape::kN;
+        constexpr index_t KPerBlock   = Problem::BlockGemmShape::kK;
+        constexpr index_t VecLoadSize = GetVectorSizeB<Problem>();
-        constexpr auto b_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor(
-            b_lds_block_desc_permuted,
-            make_tuple(make_unmerge_transform(make_tuple(number<KPerBlock / KPack>{}, number<NLdsLayer>{})),
-                       make_pass_through_transform(number<NPerBlock / NLdsLayer>{}),
-                       make_pass_through_transform(number<KPack>{})),
-            make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
-            make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{}));
// Tile: KPerBlock X NPerBlock
+        if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
+        {
+            using TileEncodingPattern = TileDistributionEncodingPattern2D<BlockSize, KPerBlock, NPerBlock,
+                                                                          VecLoadSize, BTileAccessPattern>;
+            return TileEncodingPattern::Make2DStaticTileDistribution();
+        }
// Tile: NPerBlock X KPerBlock
+        else
+        {
+            using TileEncodingPattern = TileDistributionEncodingPattern2D<BlockSize, NPerBlock, KPerBlock,
+                                                                          VecLoadSize, BTileAccessPattern>;
+            return TileEncodingPattern::Make2DStaticTileDistribution();
+        }
+    }
-        constexpr auto b_lds_block_desc = transform_tensor_descriptor(
-            b_lds_block_desc_xk0_mnldslayer_mn_xk1,
-            make_tuple(make_merge_transform_v3_division_mod(
-                           make_tuple(number<NPerBlock / NLdsLayer>{}, number<NLdsLayer>{})),
-                       make_merge_transform_v3_division_mod(
-                           make_tuple(number<KPerBlock / KPack>{}, number<KPack>{}))),
-            make_tuple(sequence<1, 2>{}, sequence<0, 3>{}),
-            make_tuple(sequence<0>{}, sequence<1>{}));
-        return b_lds_block_desc;

+    template <typename Problem>
+    CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledARegTileDistribution()
+    {
+        using ALayout = remove_cvref_t<typename Problem::ALayout>;
+        static_assert(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>);
+        constexpr index_t BlockSize   = Problem::kBlockSize;
+        constexpr index_t MPerBlock   = Problem::BlockGemmShape::kN;
+        constexpr index_t KPerBlock   = Problem::BlockGemmShape::kK;
+        constexpr index_t VecLoadSize = GetVectorSizeA<Problem>();
+        using TileEncodingPattern = TileDistributionEncodingPattern2D<BlockSize, KPerBlock, MPerBlock,
+                                                                      VecLoadSize, ATileAccessPattern>;
+        return TileEncodingPattern::MakeShuffled2DStaticTileDistribution();
+    }

+    template <typename Problem>
+    CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledBRegTileDistribution()
+    {
+        using BLayout = remove_cvref_t<typename Problem::BLayout>;
+        static_assert(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>);
+        constexpr index_t BlockSize   = Problem::kBlockSize;
+        constexpr index_t NPerBlock   = Problem::BlockGemmShape::kN;
+        constexpr index_t KPerBlock   = Problem::BlockGemmShape::kK;
+        constexpr index_t VecLoadSize = GetVectorSizeB<Problem>();
+        using TileEncodingPattern = TileDistributionEncodingPattern2D<BlockSize, KPerBlock, NPerBlock,
+                                                                      VecLoadSize, BTileAccessPattern>;
+        return TileEncodingPattern::MakeShuffled2DStaticTileDistribution();
+    }

+    template <typename Problem>
+    CK_TILE_HOST_DEVICE static constexpr auto GetSmemPackA()
+    {
+        using BlockGemm = remove_cvref_t<decltype(Derived::template GetBlockGemm<Problem>())>;
+        constexpr index_t KPack = BlockGemm::Traits::KPack;
+        return KPack;
+    }

+    template <typename Problem>
+    CK_TILE_HOST_DEVICE static constexpr auto GetSmemPackB()
+    {
+        using BlockGemm = remove_cvref_t<decltype(Derived::template GetBlockGemm<Problem>())>;
+        constexpr index_t KPack = BlockGemm::Traits::KPack;
+        return KPack;
+    }

    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeA()
    {
-        constexpr index_t smem_size_a = sizeof(typename Problem::ADataType) *
-                                        MakeALdsBlockDescriptor<Problem>().get_element_space_size();
+        constexpr auto a_lds_desc = Derived::template MakeALdsBlockDescriptor<Problem>();
+        constexpr index_t smem_size_a = integer_least_multiple(
+            sizeof(typename Problem::ADataType) * a_lds_desc.get_element_space_size(), 16);
        return smem_size_a;
    }

    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeB()
    {
-        constexpr index_t smem_size_b = sizeof(typename Problem::BDataType) *
-                                        MakeBLdsBlockDescriptor<Problem>().get_element_space_size();
+        constexpr auto b_lds_desc = Derived::template MakeBLdsBlockDescriptor<Problem>();
+        constexpr index_t smem_size_b = integer_least_multiple(
+            sizeof(typename Problem::BDataType) * b_lds_desc.get_element_space_size(), 16);
        return smem_size_b;
    }
...
...
@@ -171,298 +311,272 @@ struct UniversalGemmPipelineAgBgCrPolicy
    {
        constexpr index_t smem_size_a = GetSmemSizeA<Problem>();
        constexpr index_t smem_size_b = GetSmemSizeB<Problem>();
-        index_t smem_size = 0;
-        smem_size += smem_size_a + smem_size_b;
-        return smem_size;
+        return smem_size_a + smem_size_b;
    }
};
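The base/derived split above relies on the CRTP idiom: the base policy is parameterized on Derived and calls back into it for GetBlockGemm() and the LDS descriptors, while the derived policy reuses the shared helpers. A minimal standalone sketch of that pattern (all names and numbers are illustrative, not ck_tile code):

#include <cstdio>

// Base class calls into Derived so the derived policy supplies the pieces
// while inheriting the shared aggregate computation.
template <typename Derived>
struct BasePolicy
{
    static constexpr int GetSmemSize() { return Derived::MakeALdsSize() + Derived::MakeBLdsSize(); }
};

struct MyPolicy : public BasePolicy<MyPolicy>
{
    static constexpr int MakeALdsSize() { return 4096; }
    static constexpr int MakeBLdsSize() { return 8192; }
};

static_assert(MyPolicy::GetSmemSize() == 12288);

int main() { std::printf("%d\n", MyPolicy::GetSmemSize()); }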
// UniversalGemm Policy
struct
UniversalGemmPipelineAgBgCrPolicy
:
public
UniversalGemmBasePolicy
<
UniversalGemmPipelineAgBgCrPolicy
>
{
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeA
DramTileDistribution
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeA
LdsBlockDescriptor
()
{
using
ADataType
=
remove_cvref_t
<
typename
Problem
::
ADataType
>
;
using
ALayout
=
remove_cvref_t
<
typename
Problem
::
ALayout
>
;
constexpr
index_t
BlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
MPerBlock
=
Problem
::
BlockGemmShape
::
kM
;
constexpr
index_t
KPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
constexpr
index_t
KPack
=
GetSmemPackA
<
Problem
>
();
if
constexpr
(
std
::
is_same_v
<
ALayout
,
ck_tile
::
tensor_layout
::
gemm
::
ColumnMajor
>
)
{
constexpr
index_t
M1
=
Problem
::
VectorLoadSize
/
sizeof
(
ADataType
);
constexpr
index_t
M0
=
MPerBlock
/
M1
;
constexpr
index_t
total_pixels
=
MPerBlock
*
KPerBlock
/
BlockSize
;
static_assert
(
total_pixels
%
M1
==
0
);
constexpr
index_t
K3
=
total_pixels
/
M1
;
constexpr
index_t
KPack
=
GetVectorLoadSize
<
Problem
,
ADataType
,
MPerBlock
>
();
static_assert
(
KPack
%
K3
==
0
);
constexpr
index_t
K2
=
KPack
/
K3
;
if
constexpr
(
get_warp_size
()
%
(
K2
*
M0
)
==
0
)
{
constexpr
index_t
K1
=
get_warp_size
()
/
(
K2
*
M0
);
constexpr
index_t
K0
=
BlockSize
/
get_warp_size
();
static_assert
(
KPerBlock
==
K0
*
K1
*
K2
*
K3
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
M0
,
M1
>
,
sequence
<
K0
,
K1
,
K2
,
K3
>>
,
tuple
<
sequence
<
2
>
,
sequence
<
2
,
1
,
2
>>
,
tuple
<
sequence
<
0
>
,
sequence
<
1
,
0
,
2
>>
,
sequence
<
2
,
1
>
,
sequence
<
3
,
1
>>
{});
}
else
{
constexpr
index_t
K1
=
(
K2
*
M0
)
/
get_warp_size
();
constexpr
index_t
K2_m
=
K2
/
K1
;
constexpr
index_t
K0
=
BlockSize
/
get_warp_size
()
/
K1
;
static_assert
(
KPerBlock
==
K0
*
K1
*
K2_m
*
K3
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
M0
,
M1
>
,
sequence
<
K0
,
K1
,
K2_m
,
K3
>>
,
tuple
<
sequence
<
2
,
2
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
0
,
1
>
,
sequence
<
0
,
2
>>
,
sequence
<
2
,
1
>
,
sequence
<
3
,
1
>>
{});
}
}
else
{
constexpr
index_t
K1
=
Problem
::
VectorLoadSize
/
sizeof
(
ADataType
);
constexpr
index_t
K0
=
KPerBlock
/
K1
;
constexpr
index_t
M2
=
get_warp_size
()
/
K0
;
if
constexpr
(
get_warp_size
()
%
(
M2
*
K0
)
==
0
)
{
constexpr
index_t
M1
=
BlockSize
/
get_warp_size
();
static_assert
(
M2
!=
0
,
"M2 is zero, which will lead to a division by zero error."
);
static_assert
(
M1
!=
0
,
"M1 is zero, which will lead to a division by zero error."
);
constexpr
index_t
M0
=
MPerBlock
/
(
M2
*
M1
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
M0
,
M1
,
M2
>
,
sequence
<
K0
,
K1
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
2
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
1
>>
{});
}
else
{
constexpr
index_t
M0
=
BlockSize
/
get_warp_size
();
constexpr
index_t
M1
=
MPerBlock
/
(
M2
*
M0
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
M0
,
M1
,
M2
>
,
sequence
<
K0
,
K1
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
0
>
,
sequence
<
2
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
1
,
1
>>
{});
}
}
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeBDramTileDistribution
()
{
using
BDataType
=
remove_cvref_t
<
typename
Problem
::
BDataType
>
;
using
BLayout
=
remove_cvref_t
<
typename
Problem
::
BLayout
>
;
constexpr
index_t
BlockSize
=
Problem
::
kBlockSize
;
constexpr
auto
DataTypeSize
=
sizeof
(
ADataType
);
constexpr
auto
MLdsLayer
=
(
32
*
4
/
KPerBlock
/
DataTypeSize
)
<
1
?
1
:
(
32
*
4
/
KPerBlock
/
DataTypeSize
);
constexpr
index_t
NPerBlock
=
Problem
::
BlockGemmShape
::
kN
;
constexpr
index_t
KPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
constexpr
auto
a_lds_block_desc_0
=
make_naive_tensor_descriptor
(
make_tuple
(
number
<
KPerBlock
/
KPack
*
MLdsLayer
>
{},
number
<
MPerBlock
/
MLdsLayer
>
{},
number
<
KPack
>
{}),
make_tuple
(
number
<
KPack
>
{},
number
<
KPerBlock
*
MLdsLayer
>
{},
number
<
1
>
{}),
number
<
KPack
>
{},
number
<
1
>
{});
if
constexpr
(
std
::
is_same_v
<
BLayout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
)
{
constexpr
index_t
N1
=
Problem
::
VectorLoadSize
/
sizeof
(
BDataType
);
constexpr
index_t
N0
=
NPerBlock
/
N1
;
constexpr
index_t
total_pixels
=
NPerBlock
*
KPerBlock
/
BlockSize
;
static_assert
(
total_pixels
%
N1
==
0
);
constexpr
index_t
K3
=
total_pixels
/
N1
;
constexpr
index_t
KPack
=
GetVectorLoadSize
<
Problem
,
BDataType
,
NPerBlock
>
();
static_assert
(
KPack
%
K3
==
0
);
constexpr
index_t
K2
=
KPack
/
K3
;
if
constexpr
(
get_warp_size
()
%
(
K2
*
N0
)
==
0
)
{
constexpr
index_t
K1
=
get_warp_size
()
/
(
K2
*
N0
);
constexpr
index_t
K0
=
BlockSize
/
get_warp_size
();
static_assert
(
KPerBlock
==
K0
*
K1
*
K2
*
K3
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
N0
,
N1
>
,
sequence
<
K0
,
K1
,
K2
,
K3
>>
,
tuple
<
sequence
<
2
>
,
sequence
<
2
,
1
,
2
>>
,
tuple
<
sequence
<
0
>
,
sequence
<
1
,
0
,
2
>>
,
sequence
<
2
,
1
>
,
sequence
<
3
,
1
>>
{});
}
else
{
constexpr
index_t
K1
=
(
K2
*
N0
)
/
get_warp_size
();
constexpr
index_t
K2_m
=
K2
/
K1
;
constexpr
index_t
K0
=
BlockSize
/
get_warp_size
()
/
K1
;
static_assert
(
KPerBlock
==
K0
*
K1
*
K2_m
*
K3
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
N0
,
N1
>
,
sequence
<
K0
,
K1
,
K2_m
,
K3
>>
,
tuple
<
sequence
<
2
,
2
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
0
,
1
>
,
sequence
<
0
,
2
>>
,
sequence
<
2
,
1
>
,
sequence
<
3
,
1
>>
{});
}
}
else
{
constexpr
auto
a_lds_block_desc_permuted
=
transform_tensor_descriptor
(
a_lds_block_desc_0
,
make_tuple
(
make_xor_transform
(
make_tuple
(
number
<
MPerBlock
/
MLdsLayer
>
{},
number
<
KPerBlock
/
KPack
*
MLdsLayer
>
{})),
make_pass_through_transform
(
number
<
KPack
>
{})),
make_tuple
(
sequence
<
1
,
0
>
{},
sequence
<
2
>
{}),
make_tuple
(
sequence
<
1
,
0
>
{},
sequence
<
2
>
{}));
constexpr
index_t
K1
=
Problem
::
VectorLoadSize
/
sizeof
(
BDataType
);
constexpr
index_t
K0
=
KPerBlock
/
K1
;
constexpr
index_t
N2
=
get_warp_size
()
/
K0
;
// coalesce reading for each blocks
if
constexpr
(
get_warp_size
()
%
(
N2
*
K0
)
==
0
)
{
constexpr
index_t
N1
=
BlockSize
/
get_warp_size
();
static_assert
(
N2
!=
0
,
"N2 is zero, which will lead to a division by zero error."
);
static_assert
(
N1
!=
0
,
"N1 is zero, which will lead to a division by zero error."
);
constexpr
index_t
N0
=
NPerBlock
/
(
N2
*
N1
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
N0
,
N1
,
N2
>
,
sequence
<
K0
,
K1
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
2
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
1
>>
{});
}
// coalesce reading for each warps
else
{
constexpr
index_t
N0
=
BlockSize
/
get_warp_size
();
constexpr
index_t
N1
=
NPerBlock
/
(
N2
*
N0
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
N0
,
N1
,
N2
>
,
sequence
<
K0
,
K1
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
0
>
,
sequence
<
2
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
1
,
1
>>
{});
}
}
}
constexpr
auto
a_lds_block_desc_xk0_mnldslayer_mn_xk1
=
transform_tensor_descriptor
(
a_lds_block_desc_permuted
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
number
<
KPerBlock
/
KPack
>
{},
number
<
MLdsLayer
>
{})),
make_pass_through_transform
(
number
<
MPerBlock
/
MLdsLayer
>
{}),
make_pass_through_transform
(
number
<
KPack
>
{})),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{},
sequence
<
2
>
{}),
make_tuple
(
sequence
<
0
,
2
>
{},
sequence
<
1
>
{},
sequence
<
3
>
{}));
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeShuffledARegBlockDescriptor
()
{
using
ALayout
=
remove_cvref_t
<
typename
Problem
::
ALayout
>
;
using
ADataType
=
remove_cvref_t
<
typename
Problem
::
ADataType
>
;
static_assert
(
std
::
is_same_v
<
ALayout
,
ck_tile
::
tensor_layout
::
gemm
::
ColumnMajor
>
);
constexpr
index_t
BlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
MPerBlock
=
Problem
::
BlockGemmShape
::
kN
;
constexpr
index_t
KPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
constexpr
auto
a_lds_block_desc
=
transform_tensor_descriptor
(
a_lds_block_desc_xk0_mnldslayer_mn_xk1
,
make_tuple
(
make_merge_transform_v3_division_mod
(
make_tuple
(
number
<
MPerBlock
/
MLdsLayer
>
{},
number
<
MLdsLayer
>
{})),
make_merge_transform_v3_division_mod
(
make_tuple
(
number
<
KPerBlock
/
KPack
>
{},
number
<
KPack
>
{}))),
make_tuple
(
sequence
<
1
,
2
>
{},
sequence
<
0
,
3
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}));
constexpr
index_t
M1
=
Problem
::
VectorLoadSize
/
sizeof
(
ADataType
);
constexpr
index_t
M0
=
MPerBlock
/
M1
;
constexpr
index_t
total_pixels
=
MPerBlock
*
KPerBlock
/
BlockSize
;
static_assert
(
total_pixels
%
M1
==
0
);
constexpr
index_t
K3
=
total_pixels
/
M1
;
constexpr
index_t
kKPack
=
GetVectorLoadSize
<
Problem
,
ADataType
,
MPerBlock
>
();
static_assert
(
kKPack
%
K3
==
0
);
constexpr
index_t
K2
=
kKPack
/
K3
;
// TODO: this dimention could be outside single wave
constexpr
index_t
warp_size
=
get_warp_size
();
if
constexpr
(
warp_size
%
(
K2
*
M0
)
==
0
)
{
constexpr
index_t
K1
=
warp_size
/
(
K2
*
M0
);
constexpr
index_t
K0
=
BlockSize
/
warp_size
;
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
M0
,
M1
>
,
sequence
<
K0
,
K1
,
K2
,
K3
>>
,
tuple
<
sequence
<
2
>
,
sequence
<
2
,
1
,
2
>>
,
tuple
<
sequence
<
0
>
,
sequence
<
1
,
0
,
2
>>
,
sequence
<
1
,
2
>
,
sequence
<
1
,
3
>>
{});
}
else
{
constexpr
index_t
K1
=
(
K2
*
M0
)
/
get_warp_size
();
constexpr
index_t
K2_m
=
K2
/
K1
;
constexpr
index_t
K0
=
BlockSize
/
get_warp_size
()
/
K1
;
static_assert
(
KPerBlock
==
K0
*
K1
*
K2_m
*
K3
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
M0
,
M1
>
,
sequence
<
K0
,
K1
,
K2_m
,
K3
>>
,
tuple
<
sequence
<
2
,
2
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
0
,
1
>
,
sequence
<
0
,
2
>>
,
sequence
<
1
,
2
>
,
sequence
<
1
,
3
>>
{});
}
return
a_lds_block_desc
;
}
/**
* @brief Create LDS block descriptor for B tensor.
*
* @tparam Problem Gemm pipeline problem.
* @return B tensor LDS block descriptor.
*/
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
Make
ShuffledBReg
BlockDescriptor
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
Make
BLds
BlockDescriptor
()
{
using
BLayout
=
remove_cvref_t
<
typename
Problem
::
BLayout
>
;
//
using BLayout = remove_cvref_t<typename Problem::BLayout>;
using
BDataType
=
remove_cvref_t
<
typename
Problem
::
BDataType
>
;
static_assert
(
std
::
is_same_v
<
BLayout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
);
constexpr
index_t
BlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
NPerBlock
=
Problem
::
BlockGemmShape
::
kN
;
constexpr
index_t
KPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
constexpr
index_t
N1
=
Problem
::
VectorLoadSize
/
sizeof
(
BDataType
);
constexpr
index_t
N0
=
NPerBlock
/
N1
;
constexpr
index_t
total_pixels
=
NPerBlock
*
KPerBlock
/
BlockSize
;
static_assert
(
total_pixels
%
N1
==
0
);
constexpr
index_t
K3
=
total_pixels
/
N1
;
constexpr
index_t
kKPack
=
GetVectorLoadSize
<
Problem
,
BDataType
,
NPerBlock
>
();
static_assert
(
kKPack
%
K3
==
0
);
constexpr
index_t
K2
=
kKPack
/
K3
;
// TODO: this dimention could be outside single wave
constexpr
index_t
warp_size
=
get_warp_size
();
if
constexpr
(
warp_size
%
(
K2
*
N0
)
==
0
)
#if 1
// if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::ColumnMajor>)
{
constexpr
index_t
K1
=
warp_size
/
(
K2
*
N0
);
constexpr
index_t
K0
=
BlockSize
/
warp_size
;
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
N0
,
N1
>
,
sequence
<
K0
,
K1
,
K2
,
K3
>>
,
tuple
<
sequence
<
2
>
,
sequence
<
2
,
1
,
2
>>
,
tuple
<
sequence
<
0
>
,
sequence
<
1
,
0
,
2
>>
,
sequence
<
1
,
2
>
,
sequence
<
1
,
3
>>
{});
constexpr
index_t
KPack
=
GetSmemPackB
<
Problem
>
();
constexpr
auto
BK0
=
number
<
KPerBlock
/
KPack
>
{};
constexpr
auto
DataTypeSize
=
sizeof
(
BDataType
);
constexpr
auto
NLdsLayer
=
(
32
*
4
/
KPerBlock
/
DataTypeSize
)
<
1
?
1
:
(
32
*
4
/
KPerBlock
/
DataTypeSize
);
constexpr
auto
b_lds_block_desc_0
=
make_naive_tensor_descriptor
(
make_tuple
(
BK0
*
number
<
NLdsLayer
>
{},
number
<
NPerBlock
/
NLdsLayer
>
{},
number
<
KPack
>
{}),
make_tuple
(
number
<
KPack
>
{},
number
<
KPerBlock
*
NLdsLayer
>
{},
number
<
1
>
{}),
number
<
KPack
>
{},
number
<
1
>
{});
constexpr
auto
b_lds_block_desc_permuted
=
transform_tensor_descriptor
(
b_lds_block_desc_0
,
make_tuple
(
make_xor_transform
(
make_tuple
(
number
<
NPerBlock
/
NLdsLayer
>
{},
BK0
*
number
<
NLdsLayer
>
{})),
make_pass_through_transform
(
number
<
KPack
>
{})),
make_tuple
(
sequence
<
1
,
0
>
{},
sequence
<
2
>
{}),
make_tuple
(
sequence
<
1
,
0
>
{},
sequence
<
2
>
{}));
constexpr
auto
b_lds_block_desc_bk0_nldslayer_n_bk1
=
transform_tensor_descriptor
(
b_lds_block_desc_permuted
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
BK0
,
number
<
NLdsLayer
>
{})),
make_pass_through_transform
(
number
<
NPerBlock
/
NLdsLayer
>
{}),
make_pass_through_transform
(
number
<
KPack
>
{})),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{},
sequence
<
2
>
{}),
make_tuple
(
sequence
<
0
,
2
>
{},
sequence
<
1
>
{},
sequence
<
3
>
{}));
constexpr
auto
b_lds_block_desc
=
transform_tensor_descriptor
(
b_lds_block_desc_bk0_nldslayer_n_bk1
,
make_tuple
(
make_merge_transform_v3_division_mod
(
make_tuple
(
number
<
NPerBlock
/
NLdsLayer
>
{},
number
<
NLdsLayer
>
{})),
make_merge_transform_v3_division_mod
(
make_tuple
(
BK0
,
number
<
KPack
>
{}))),
make_tuple
(
sequence
<
1
,
2
>
{},
sequence
<
0
,
3
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}));
return
b_lds_block_desc
;
}
else
#else
else
// B is Row Major
{
constexpr
index_t
K1
=
(
K2
*
N0
)
/
get_warp_size
();
constexpr
index_t
K2_m
=
K2
/
K1
;
constexpr
index_t
K0
=
BlockSize
/
get_warp_size
()
/
K1
;
static_assert
(
KPerBlock
==
K0
*
K1
*
K2_m
*
K3
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
N0
,
N1
>
,
sequence
<
K0
,
K1
,
K2_m
,
K3
>>
,
tuple
<
sequence
<
2
,
2
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
0
,
1
>
,
sequence
<
0
,
2
>>
,
sequence
<
1
,
2
>
,
sequence
<
1
,
3
>>
{});
constexpr
index_t
BlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
VecLoadSize
=
GetVectorSizeB
<
Problem
>
();
using
TileEncodingPattern
=
TileDistributionEncodingPattern2D
<
BlockSize
,
KPerBlock
,
NPerBlock
,
                                                               VecLoadSize,
                                                               BTileAccessPattern>;

        constexpr auto BK0 = number<TileEncodingPattern::X1>{};
        constexpr auto BK1 = number<TileEncodingPattern::Y0>{};
        // constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I1);
        constexpr auto N0 = TileEncodingPattern::X0;
        constexpr auto N1 = NPerBlock / N0;

        using WarpTile         = typename Problem::BlockGemmShape::WarpTile;
        constexpr auto NPerXdl = number<WarpTile::at(I1)>{};

        // constexpr auto KThreadWrite =
        //     BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I0);
        constexpr auto KThreadWrite     = TileEncodingPattern::Y2;
        constexpr auto K0PerThreadWrite = BK0 / KThreadWrite;
        constexpr auto KThreadRead      = 64 / NPerXdl;
        constexpr auto K0PerThreadRead  = BK0 / KThreadRead;

        constexpr auto kfold = (BK1 * N0 * sizeof(BDataType) > 128)
                                   ? 1
                                   : 128 / (BK1 * N0 * sizeof(BDataType));

        constexpr auto KThreadReadPerm =
            (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
                ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
                : KThreadRead;

        // 1<=npair<=n0
        constexpr auto npair = (BK1 * NPerXdl * sizeof(BDataType) > 128)
                                   ? 1
                                   : ((128 / (BK1 * NPerXdl * sizeof(BDataType))) > N0
                                          ? N0
                                          : 128 / (BK1 * NPerXdl * sizeof(BDataType)));

        constexpr auto b_lds_block_desc = make_naive_tensor_descriptor_packed(
            make_tuple(number<KThreadWrite / kfold / KThreadReadPerm>{},
                       number<K0PerThreadWrite>{},
                       number<KThreadReadPerm * N1>{},
                       number<kfold * N0 / npair>{},
                       number<npair>{},
                       BK1));

        constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor(
            b_lds_block_desc,
            make_tuple(
                make_pass_through_transform(number<KThreadWrite / kfold / KThreadReadPerm>{}),
                make_pass_through_transform(number<K0PerThreadWrite>{}),
                make_xor_transform(
                    make_tuple(number<KThreadReadPerm * N1>{}, number<kfold * N0 / npair>{})),
                make_pass_through_transform(number<npair>{}),
                make_pass_through_transform(BK1)),
            make_tuple(
                sequence<0>{}, sequence<1>{}, sequence<2, 3>{}, sequence<4>{}, sequence<5>{}),
            make_tuple(
                sequence<0>{}, sequence<1>{}, sequence<2, 3>{}, sequence<4>{}, sequence<5>{}));

        constexpr auto b_lds_block_desc_unmerged = transform_tensor_descriptor(
            b_lds_block_desc_permuted,
            make_tuple(
                make_pass_through_transform(number<KThreadWrite / kfold / KThreadReadPerm>{}),
                make_pass_through_transform(number<K0PerThreadWrite>{}),
                make_unmerge_transform(make_tuple(number<KThreadReadPerm>{}, number<N1>{})),
                make_unmerge_transform(make_tuple(number<kfold>{}, number<N0 / npair>{})),
                make_pass_through_transform(number<npair>{}),
                make_pass_through_transform(BK1)),
            make_tuple(sequence<0>{},
                       sequence<1>{},
                       sequence<2>{},
                       sequence<3>{},
                       sequence<4>{},
                       sequence<5>{}),
            make_tuple(sequence<1>{},
                       sequence<2>{},
                       sequence<0, 3>{},
                       sequence<4, 5>{},
                       sequence<6>{},
                       sequence<7>{}));

        // constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor(
        //     b_lds_block_desc_unmerged,
        //     make_tuple(make_merge_transform_v3_division_mod(
        //                    make_tuple(number<KThreadReadPerm>{},
        //                               number<KThreadWrite / kfold / KThreadReadPerm>{},
        //                               number<kfold>{},
        //                               number<K0PerThreadWrite>{})),
        //                make_merge_transform_v3_division_mod(
        //                    make_tuple(number<N0 / npair>{}, number<npair>{}, number<N1>{})),
        //                make_pass_through_transform(BK1)),
        //     make_tuple(sequence<0, 1, 4, 2>{}, sequence<5, 6, 3>{}, sequence<7>{}),
        //     make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}));

        constexpr auto b_lds_block_desc_kn = transform_tensor_descriptor(
            b_lds_block_desc_unmerged,
            make_tuple(make_merge_transform_v3_division_mod(
                           make_tuple(number<KThreadReadPerm>{},
                                      number<KThreadWrite / kfold / KThreadReadPerm>{},
                                      number<kfold>{},
                                      number<K0PerThreadWrite>{},
                                      BK1)),
                       make_merge_transform_v3_division_mod(
                           make_tuple(number<N0 / npair>{}, number<npair>{}, number<N1>{}))),
            make_tuple(sequence<0, 1, 4, 2, 7>{}, sequence<5, 6, 3>{}),
            make_tuple(sequence<1>{}, sequence<0>{}));

        // return b_lds_block_desc_bk0_n_bk1;
        return b_lds_block_desc_kn;

        // constexpr auto b_lds_block_desc_bk0_n_bk1 = make_naive_tensor_descriptor(
        //     make_tuple(BK0, number<NPerBlock>{}, number<KPack>{}),
        //     make_tuple(number<KPack>{}, number<KPerBlock>{}, number<1>{}),
        //     number<KPack>{},
        //     number<1>{});
        // constexpr auto b_lds_block_desc = transform_tensor_descriptor(
        //     b_lds_block_desc_bk0_n_bk1,
        //     make_tuple(make_pass_through_transform(number<NPerBlock>{}),
        //                make_merge_transform_v3_division_mod(make_tuple(BK0,
        //                number<KPack>{}))),
        //     make_tuple(sequence<1>{}, sequence<0, 2>{}),
        //     make_tuple(sequence<0>{}, sequence<1>{}));
        // return b_lds_block_desc;
    }
#endif
}
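To make the kfold / npair selection above concrete, here is a small stand-alone arithmetic check. The element size (2 bytes, as for fp16) and the tile parameters BK1, N0 and NPerXdl are assumed values chosen only for illustration, not taken from any particular CK Tile configuration; the two ternary expressions mirror the descriptor code above.

// Illustrative worked example of the kfold / npair formulas (assumed inputs).
#include <cstddef>

int main()
{
    constexpr std::size_t elem    = 2;  // assumed sizeof(BDataType) for fp16
    constexpr std::size_t BK1     = 8;  // assumed
    constexpr std::size_t N0      = 4;  // assumed
    constexpr std::size_t NPerXdl = 32; // assumed

    // Same ternaries as in the LDS descriptor.
    constexpr std::size_t kfold =
        (BK1 * N0 * elem > 128) ? 1 : 128 / (BK1 * N0 * elem); // 8*4*2 = 64 -> kfold = 2

    constexpr std::size_t npair =
        (BK1 * NPerXdl * elem > 128)
            ? 1
            : ((128 / (BK1 * NPerXdl * elem)) > N0 ? N0
                                                   : 128 / (BK1 * NPerXdl * elem)); // 512 > 128 -> 1

    static_assert(kfold == 2 && npair == 1, "worked example under the assumed sizes");
    return 0;
}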
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm()
    {
        using AccDataType = float;
        using BlockWarps  = typename Problem::BlockGemmShape::BlockWarps;
        using WarpTile    = typename Problem::BlockGemmShape::WarpTile;

        using WarpGemm = WarpGemmMfmaDispatcher<typename Problem::ADataType,
                                                typename Problem::BDataType,
                                                AccDataType,
                                                typename Problem::CDataType,
                                                WarpTile::at(I0),
                                                WarpTile::at(I1),
                                                WarpTile::at(I2),
-                                               TransposeC>;
+                                               Problem::TransposeC>;

        using BlockGemmPolicy =
            BlockGemmASmemBSmemCRegV1CustomPolicy<typename Problem::ADataType,
                                                  typename Problem::BDataType,
                                                  typename Problem::CDataType,
                                                  BlockWarps,
                                                  WarpGemm>;

-       return BlockGemmASmemBSmemCRegV1<Problem, BlockGemmPolicy>{};
+       return BlockUniversalGemmAsBsCr<Problem, BlockGemmPolicy>{};
    }
};
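A minimal sketch (plain C++, not CK Tile code) of why the diff qualifies the flag as Problem::TransposeC: the transpose setting is read from the problem type itself, so the warp-GEMM selection cannot silently pick up an unrelated TransposeC name from an enclosing scope. ProblemA, ProblemB and WarpGemmSelector below are made-up illustration types.

// Generic sketch: the policy flag travels with the Problem type.
struct ProblemA { static constexpr bool TransposeC = true;  };
struct ProblemB { static constexpr bool TransposeC = false; };

template <bool Transpose>
struct WarpGemmSelector { static constexpr bool transposes = Transpose; };

template <typename Problem>
constexpr auto make_warp_gemm()
{
    // Mirrors the fix in GetBlockGemm(): the flag is taken from Problem itself.
    return WarpGemmSelector<Problem::TransposeC>{};
}

static_assert(decltype(make_warp_gemm<ProblemA>())::transposes);
static_assert(!decltype(make_warp_gemm<ProblemB>())::transposes);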
...
include/ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp

// SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/host/concat.hpp"

namespace ck_tile {
...
@@ -19,6 +20,16 @@ struct TileGemmShape
    static constexpr index_t kM = BlockTile::at(number<0>{});
    static constexpr index_t kN = BlockTile::at(number<1>{});
    static constexpr index_t kK = BlockTile::at(number<2>{});

+   CK_TILE_HOST static std::string GetName()
+   {
+       // clang-format off
+       return concat('_', "tile_gemm_shape",
+                     concat('x', kM, kN, kK, NumWarps),
+                     concat('x', BlockWarps::at(number<0>{}), BlockWarps::at(number<1>{}), BlockWarps::at(number<2>{})),
+                     concat('x', (WarpTile::at(number<0>{})), WarpTile::at(number<1>{}), WarpTile::at(number<2>{})));
+       // clang-format on
+   }
};

} // namespace ck_tile
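A rough sketch of the string GetName() is expected to produce, using a stand-in join() helper instead of ck_tile::concat() and an assumed 256x128x64 block tile with a 2x2x1 warp grid and a 32x32x16 warp tile; the exact separator behaviour should be checked against ck_tile/host/concat.hpp.

// Stand-alone illustration of the name format (join() is a hypothetical helper).
#include <iostream>
#include <sstream>
#include <string>

template <typename... Ts>
std::string join(char sep, Ts&&... ts)
{
    std::ostringstream oss;
    int i = 0;
    ((oss << (i++ ? std::string(1, sep) : "") << ts), ...); // separator between items only
    return oss.str();
}

int main()
{
    // Assumed shape: 256x128x64 block tile, 4 warps arranged 2x2x1, 32x32x16 warp tile.
    std::cout << join('_',
                      "tile_gemm_shape",
                      join('x', 256, 128, 64, 4),
                      join('x', 2, 2, 1),
                      join('x', 32, 32, 16))
              << "\n"; // prints: tile_gemm_shape_256x128x64x4_2x2x1_32x32x16
}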
include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp
...
@@ -19,11 +19,37 @@ struct TileGemmTraits
    static constexpr bool kPadN = kPadN_;
    static constexpr bool kPadK = kPadK_;

+   // TODO this can't be hardcoded here! Should be in policy!
+   static constexpr int _VectorSize = 16;

    using ALayout = ALayout_;
    using BLayout = BLayout_;
    using CLayout = CLayout_;

+   static constexpr bool TransposeC = false;
};

+template <bool kPadM_,
+          bool kPadN_,
+          bool kPadK_,
+          bool DoubleSmemBuffer_,
+          typename ALayout_,
+          typename BLayout_,
+          typename CLayout_,
+          bool TransposeC_ = false>
+struct TileGemmUniversalTraits
+{
+    static constexpr bool kPadM            = kPadM_;
+    static constexpr bool kPadN            = kPadN_;
+    static constexpr bool kPadK            = kPadK_;
+    static constexpr bool DoubleSmemBuffer = DoubleSmemBuffer_;
+
+    using ALayout = ALayout_;
+    using BLayout = BLayout_;
+    using CLayout = CLayout_;
+
+    static constexpr bool TransposeC = TransposeC_;
+};

} // namespace ck_tile
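A compile-time usage sketch for the new TileGemmUniversalTraits. It assumes the RowMajor/ColumnMajor tags from ck_tile/ops/common/tensor_layout.hpp and that tile_gemm_traits.hpp is includable on its own; the layout and padding choices are illustrative only.

// Minimal usage sketch (assumed includes and layout tags).
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp"

using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;

// No M/N/K padding, double LDS buffering enabled, C left untransposed (the default).
using Traits = ck_tile::TileGemmUniversalTraits<false, false, false,
                                                /*DoubleSmemBuffer=*/true,
                                                Row, Col, Row>;

static_assert(Traits::DoubleSmemBuffer, "double-buffered LDS requested");
static_assert(!Traits::TransposeC, "TransposeC defaults to false");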
include/ck_tile/ops/image_to_column.hpp

// SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once
...
@@ -8,3 +8,4 @@
 #include "ck_tile/ops/image_to_column/pipeline/tile_image_to_column_shape.hpp"
 #include "ck_tile/ops/common/generic_2d_block_shape.hpp"
 #include "ck_tile/ops/common/tensor_layout.hpp"
+#include "ck_tile/ops/common/utils.hpp"
include/ck_tile/ops/layernorm2d.hpp

// SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once
...
@@ -11,3 +11,4 @@
 #include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_traits.hpp"
 #include "ck_tile/ops/common/generic_2d_block_shape.hpp"
 #include "ck_tile/ops/common/tensor_layout.hpp"
+#include "ck_tile/ops/common/utils.hpp"
include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp

// SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once
...
@@ -14,7 +14,8 @@ struct Layernorm2dFwdHostArgs
{
    const void* p_x;          // [m ,n], input, fp16/bf16
    const void* p_x_residual; // [m ,n], shortcut input, prec same as input, nullptr if not used
-   const void* p_x_scale;    // [1 ,n], smooth scale input, fp32, nullptr if not used
+   const void* p_sm_scale;   // [1 ,n], smooth scale input, fp32, nullptr if not used
+   const void* p_x_bias;     // [1, n], bias, prec same as input
    const void* p_gamma;      // [1, n], gamma, prec same as input
    const void* p_beta;       // [1, n], beta, prec same as input
...
@@ -42,15 +43,16 @@ struct Layernorm2dFwd
    using Epilogue = remove_cvref_t<Epilogue_>;
    using Problem  = typename Pipeline::Problem;

-   using XDataType       = remove_cvref_t<typename Problem::XDataType>;
-   using GammaDataType   = remove_cvref_t<typename Problem::GammaDataType>;
-   using BetaDataType    = remove_cvref_t<typename Problem::BetaDataType>;
-   using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
-   using YDataType       = remove_cvref_t<typename Problem::YDataType>;
-   using MeanDataType    = remove_cvref_t<typename Problem::MeanDataType>;
-   using InvStdDataType  = remove_cvref_t<typename Problem::InvStdDataType>;
-   using XScaleDataType  = remove_cvref_t<typename Problem::XScaleDataType>;
-   using YScaleDataType  = remove_cvref_t<typename Problem::YScaleDataType>;
+   using XDataType           = remove_cvref_t<typename Problem::XDataType>;
+   using XBiasDataType       = remove_cvref_t<typename Problem::XBiasDataType>;
+   using GammaDataType       = remove_cvref_t<typename Problem::GammaDataType>;
+   using BetaDataType        = remove_cvref_t<typename Problem::BetaDataType>;
+   using ComputeDataType     = remove_cvref_t<typename Problem::ComputeDataType>;
+   using YDataType           = remove_cvref_t<typename Problem::YDataType>;
+   using MeanDataType        = remove_cvref_t<typename Problem::MeanDataType>;
+   using InvStdDataType      = remove_cvref_t<typename Problem::InvStdDataType>;
+   using SmoothScaleDataType = remove_cvref_t<typename Problem::SmoothScaleDataType>;
+   using YScaleDataType      = remove_cvref_t<typename Problem::YScaleDataType>;

    // for simplicity, shortcut input/output type is same as X
    using XResidualDataType = XDataType;
...
@@ -67,6 +69,7 @@ struct Layernorm2dFwd
    static constexpr bool kPadM    = false; // always no need to pad along M
    static constexpr bool kPadN    = Problem::Traits::kPadN;
    static constexpr bool kTwoPass = Problem::Traits::kTwoPass;

+   static constexpr auto kXbias      = Problem::Traits::kXbias;
    static constexpr auto kFusedAdd   = Problem::Traits::kFusedAdd;
    static constexpr auto kFusedQuant = Problem::Traits::kFusedQuant;
...
@@ -81,7 +84,8 @@ struct Layernorm2dFwd
    {
        const void* p_x;          // [m ,n], input, fp16/bf16
        const void* p_x_residual; // [m ,n], shortcut input, prec same as input, nullptr if not used
-       const void* p_x_scale;    // [1 ,n], smooth scale input, fp32, nullptr if not used
+       const void* p_sm_scale;   // [1 ,n], smooth scale input, fp32, nullptr if not used
+       const void* p_x_bias;     // [1, n], bias, prec same as input
        const void* p_gamma;      // [1, n], gamma, prec same as input
        const void* p_beta;       // [1, n], beta, prec same as input
...
@@ -107,7 +111,8 @@ struct Layernorm2dFwd
    {
        return Kargs{hargs.p_x,
                     hargs.p_x_residual,
-                    hargs.p_x_scale,
+                    hargs.p_sm_scale,
+                    hargs.p_x_bias,
                     hargs.p_gamma,
                     hargs.p_beta,
                     hargs.p_y,
...
@@ -152,6 +157,7 @@ struct Layernorm2dFwd
        using S_ = typename Problem::BlockShape;
        auto surfix = [&]() {
            std::string n;
+           if(kXbias != Layernorm2dXBiasEnum::NO_BIAS)
+               n += _SS_("_") + Layernorm2dXBiasEnumName<kXbias>::name;
            if(kFusedAdd != Layernorm2dFusedAddEnum::NO_ADD)
                n += _SS_("_") + Layernorm2dFusedAddEnumName<kFusedAdd>::name;
            if(kFusedQuant != Layernorm2dFusedQuantEnum::NO_SWEEP)
                n += _SS_("_") + Layernorm2dFusedQuantEnumName<kFusedQuant>::name;
            if(kPadN)
                n += "_pn";
...
@@ -165,7 +171,7 @@ struct Layernorm2dFwd
            base_str += _SS_("_") + _SS_(t2s<YDataType>::name);
        }
        if(kFusedQuant == Layernorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT)
        {
-           base_str += _SS_("_sx") + _SS_(t2s<XScaleDataType>::name);
+           base_str += _SS_("_sx") + _SS_(t2s<SmoothScaleDataType>::name);
            base_str += _SS_("_sy") + _SS_(t2s<YScaleDataType>::name);
        }
        if(kFusedQuant == Layernorm2dFusedQuantEnum::DYNAMIC_QUANT)
        {
...
@@ -228,6 +234,27 @@ struct Layernorm2dFwd
            }
        }();

+       const auto x_bias_window = [&]() {
+           if constexpr(kXbias == Layernorm2dXBiasEnum::ADD_BIAS)
+           {
+               const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
+                   static_cast<const XBiasDataType*>(kargs.p_x_bias),
+                   make_tuple(kargs.n),
+                   make_tuple(1),
+                   number<Vector_N>{},
+                   number<1>{});
+               const auto tmp2_ =
+                   pad_tensor_view(tmp_, make_tuple(number<Block_N>{}), sequence<false>{});
+               return make_tile_window(tmp2_, make_tuple(number<Block_N>{}), {0});
+           }
+           else
+           {
+               return make_null_tile_window(make_tuple(number<Block_N>{}));
+           }
+       }();

        const auto gamma_window = [&]() {
            const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
                static_cast<const GammaDataType*>(kargs.p_gamma),
...
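The x_bias_window above relies on an immediately-invoked lambda whose if constexpr branches return different types: a real tile window when the bias is fused in, or a zero-cost null window otherwise. A stand-alone C++17 illustration of that pattern, with made-up RealWindow/NullWindow types standing in for CK Tile's windows:

// Generic sketch of the conditionally-null window pattern (not CK code).
#include <iostream>

struct RealWindow { int width; };
struct NullWindow {};

template <bool HasBias>
auto make_bias_window(int block_n)
{
    return [&]() {
        if constexpr(HasBias)
            return RealWindow{block_n}; // would be make_tile_window(...) in the kernel
        else
            return NullWindow{};        // would be make_null_tile_window(...)
    }();
}

int main()
{
    auto w1 = make_bias_window<true>(64);
    auto w2 = make_bias_window<false>(64);
    std::cout << w1.width << "\n";  // 64
    static_assert(sizeof(w2) == 1); // empty placeholder, nothing to load
}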
@@ -329,18 +356,18 @@ struct Layernorm2dFwd
            return make_null_tile_window(make_tuple(number<Block_M>{}));
        }();

-       auto x_scale_window = [&]() {
+       auto sm_scale_window = [&]() {
            if constexpr(kFusedQuant == Layernorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT)
            {
                const auto win_ = [&]() {
                    const auto tmp_0_ =
                        make_naive_tensor_view_packed<address_space_enum::global>(
-                           static_cast<const XScaleDataType*>(kargs.p_x_scale),
+                           static_cast<const SmoothScaleDataType*>(kargs.p_sm_scale),
                            make_tuple(kargs.n),
                            number<Vector_N>{});
                    return pad_tensor_view(tmp_0_,
                                           make_tuple(number<Block_N>{}),
-                                          sequence<false>{}); // x_scale no need pad
+                                          sequence<false>{}); // sm_scale no need pad
                }();
                return make_tile_window(win_, make_tuple(number<Block_N>{}), {0});
            }
...
@@ -371,13 +398,14 @@ struct Layernorm2dFwd
        Pipeline{}(x_window,
                   x_residual_window,
+                  x_bias_window,
                   gamma_window,
                   beta_window,
                   y_window,
                   y_residual_window,
                   mean_window,
                   inv_std_window,
-                  x_scale_window,
+                  sm_scale_window,
                   y_scale_window,
                   static_cast<const ComputeDataType>(kargs.epsilon),
                   kargs.n,
...