Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
63b152d6
Commit
63b152d6
authored
Oct 17, 2024
by
danyao12
Browse files
Merge branch 'develop' into ck_tile/fa_bwd_v3
parents
ae2d7d2b
14c3cfb1
Changes
132
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1239 additions
and
366 deletions
+1239
-366
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp
.../fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp
+88
-148
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp
...fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp
+55
-45
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp
...lock_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp
+15
-8
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp
...mha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp
+4
-5
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp
...a/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp
+35
-31
include/ck_tile/ops/gemm.hpp
include/ck_tile/ops/gemm.hpp
+7
-5
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
+6
-9
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp
...e/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp
+7
-3
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp
...line/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp
+2
-2
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp
...e/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp
+3
-3
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp
...line/gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp
+4
-5
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp
+10
-7
include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp
...gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp
+424
-0
include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp
include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp
+27
-0
include/ck_tile/ops/image_to_column.hpp
include/ck_tile/ops/image_to_column.hpp
+9
-0
include/ck_tile/ops/image_to_column/kernel/image_to_column_kernel.hpp
...ile/ops/image_to_column/kernel/image_to_column_kernel.hpp
+224
-0
include/ck_tile/ops/image_to_column/pipeline/block_image_to_column_problem.hpp
...mage_to_column/pipeline/block_image_to_column_problem.hpp
+27
-0
include/ck_tile/ops/image_to_column/pipeline/tile_image_to_column_shape.hpp
...s/image_to_column/pipeline/tile_image_to_column_shape.hpp
+32
-0
include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp
...ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp
+247
-86
include/ck_tile/ops/layernorm2d/pipeline/block_layernorm2d_fwd_problem.hpp
...ps/layernorm2d/pipeline/block_layernorm2d_fwd_problem.hpp
+13
-9
No files found.
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp
View file @
63b152d6
...
...
@@ -5,7 +5,7 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/gemm/
pipeline
/block_gemm_
pipeline_
problem.hpp"
#include "ck_tile/ops/gemm/
block
/block_gemm_problem.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
...
...
@@ -25,10 +25,11 @@ struct BlockFmhaBwdPipelineDefaultPolicy
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetQKBlockGemm
()
{
using
Block
GemmProblem
=
BlockGemmPipelineProblem
<
typename
Problem
::
QDataType
,
using
GemmProblem
=
BlockGemmProblem
<
typename
Problem
::
QDataType
,
typename
Problem
::
KDataType
,
typename
Problem
::
AccDataType
,
Problem
::
kBlockSize
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kM0
,
Problem
::
BlockFmhaShape
::
kN0
,
Problem
::
BlockFmhaShape
::
kK0
>
,
...
...
@@ -52,16 +53,17 @@ struct BlockFmhaBwdPipelineDefaultPolicy
typename
Problem
::
BlockFmhaShape
::
Gemm0BlockWarps
,
WarpGemm
>
;
return
BlockGemmARegBRegCRegV1
<
Block
GemmProblem
,
BlockGemmPolicy
>
{};
return
BlockGemmARegBRegCRegV1
<
GemmProblem
,
BlockGemmPolicy
>
{};
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetPTOGradTBlockGemm
()
{
using
Block
GemmProblem
=
BlockGemmPipelineProblem
<
typename
Problem
::
GemmDataType
,
using
GemmProblem
=
BlockGemmProblem
<
typename
Problem
::
GemmDataType
,
typename
Problem
::
OGradDataType
,
typename
Problem
::
AccDataType
,
Problem
::
kBlockSize
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kN0
,
Problem
::
BlockFmhaShape
::
kVHeaddim
,
Problem
::
BlockFmhaShape
::
kK1
>
,
...
...
@@ -84,16 +86,17 @@ struct BlockFmhaBwdPipelineDefaultPolicy
typename
Problem
::
BlockFmhaShape
::
Gemm1BlockWarps
,
WarpGemm
>
;
return
BlockGemmARegBRegCRegV1
<
Block
GemmProblem
,
BlockGemmPolicy
>
{};
return
BlockGemmARegBRegCRegV1
<
GemmProblem
,
BlockGemmPolicy
>
{};
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetOGradVBlockGemm
()
{
using
Block
GemmProblem
=
BlockGemmPipelineProblem
<
typename
Problem
::
OGradDataType
,
using
GemmProblem
=
BlockGemmProblem
<
typename
Problem
::
OGradDataType
,
typename
Problem
::
VDataType
,
typename
Problem
::
AccDataType
,
Problem
::
kBlockSize
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kM0
,
Problem
::
BlockFmhaShape
::
kN0
,
Problem
::
BlockFmhaShape
::
kK2
>
,
...
...
@@ -117,16 +120,17 @@ struct BlockFmhaBwdPipelineDefaultPolicy
typename
Problem
::
BlockFmhaShape
::
Gemm2BlockWarps
,
WarpGemm
>
;
return
BlockGemmARegBRegCRegV1
<
Block
GemmProblem
,
BlockGemmPolicy
>
{};
return
BlockGemmARegBRegCRegV1
<
GemmProblem
,
BlockGemmPolicy
>
{};
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetSGradTQTBlockGemm
()
{
using
Block
GemmProblem
=
BlockGemmPipelineProblem
<
typename
Problem
::
GemmDataType
,
using
GemmProblem
=
BlockGemmProblem
<
typename
Problem
::
GemmDataType
,
typename
Problem
::
QDataType
,
typename
Problem
::
AccDataType
,
Problem
::
kBlockSize
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kN0
,
Problem
::
BlockFmhaShape
::
kQKHeaddim
,
Problem
::
BlockFmhaShape
::
kK3
>
,
...
...
@@ -149,16 +153,17 @@ struct BlockFmhaBwdPipelineDefaultPolicy
typename
Problem
::
BlockFmhaShape
::
Gemm3BlockWarps
,
WarpGemm
>
;
return
BlockGemmARegBRegCRegV1
<
Block
GemmProblem
,
BlockGemmPolicy
>
{};
return
BlockGemmARegBRegCRegV1
<
GemmProblem
,
BlockGemmPolicy
>
{};
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetSGradKTBlockGemm
()
{
using
Block
GemmProblem
=
BlockGemmPipelineProblem
<
typename
Problem
::
GemmDataType
,
using
GemmProblem
=
BlockGemmProblem
<
typename
Problem
::
GemmDataType
,
typename
Problem
::
KDataType
,
typename
Problem
::
AccDataType
,
Problem
::
kBlockSize
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kM0
,
Problem
::
BlockFmhaShape
::
kQKHeaddim
,
Problem
::
BlockFmhaShape
::
kK4
>
,
...
...
@@ -181,7 +186,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
typename
Problem
::
BlockFmhaShape
::
Gemm4BlockWarps
,
WarpGemm
>
;
return
BlockGemmARegBRegCRegV1
<
Block
GemmProblem
,
BlockGemmPolicy
>
{};
return
BlockGemmARegBRegCRegV1
<
GemmProblem
,
BlockGemmPolicy
>
{};
}
// these are for global load
...
...
@@ -191,7 +196,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
using
QDataType
=
remove_cvref_t
<
typename
Problem
::
QDataType
>
;
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kMNPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
K0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
QKHeaddim
;
constexpr
index_t
kMaxVecLoad
=
16
/
sizeof
(
QDataType
);
constexpr
index_t
kMinVecLoad
=
4
/
sizeof
(
QDataType
);
...
...
@@ -210,7 +215,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
using
KDataType
=
remove_cvref_t
<
typename
Problem
::
KDataType
>
;
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kMNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
K0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
QKHeaddim
;
constexpr
index_t
kMaxVecLoad
=
16
/
sizeof
(
KDataType
);
constexpr
index_t
kMinVecLoad
=
4
/
sizeof
(
KDataType
);
...
...
@@ -229,7 +234,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
using
VDataType
=
remove_cvref_t
<
typename
Problem
::
VDataType
>
;
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kMNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
K2
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
VHeaddim
;
constexpr
index_t
kMaxVecLoad
=
16
/
sizeof
(
VDataType
);
constexpr
index_t
total_pixels
=
kMNPerBlock
*
kKPerBlock
/
kBlockSize
;
...
...
@@ -249,7 +254,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
using
OGradDataType
=
remove_cvref_t
<
typename
Problem
::
OGradDataType
>
;
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kMNPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
K2
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
VHeaddim
;
constexpr
index_t
kMaxVecLoad
=
16
/
sizeof
(
OGradDataType
);
constexpr
index_t
kMinVecLoad
=
4
/
sizeof
(
OGradDataType
);
...
...
@@ -310,7 +315,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
{
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
K0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
QKHeaddim
;
constexpr
index_t
total_pixels
=
kNPerBlock
*
kKPerBlock
/
kBlockSize
;
...
...
@@ -322,7 +327,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
{
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
K0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
QKHeaddim
;
constexpr
index_t
total_pixels
=
kNPerBlock
*
kKPerBlock
/
kBlockSize
;
return
total_pixels
/
GetAlignmentK
<
Problem
>
();
...
...
@@ -333,7 +338,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
{
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
K2
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
VHeaddim
;
constexpr
index_t
total_pixels
=
kNPerBlock
*
kKPerBlock
/
kBlockSize
;
...
...
@@ -371,7 +376,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
K0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
QKHeaddim
;
constexpr
index_t
K1
=
GetAlignmentK
<
Problem
>
();
constexpr
index_t
K0
=
kKPerBlock
/
K1
;
...
...
@@ -394,7 +399,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
K2
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
VHeaddim
;
constexpr
index_t
K1
=
GetAlignmentV
<
Problem
>
();
constexpr
index_t
K0
=
kKPerBlock
/
K1
;
...
...
@@ -417,7 +422,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kMPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
K0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
QKHeaddim
;
constexpr
index_t
K1
=
GetAlignmentQ
<
Problem
>
();
constexpr
index_t
K0
=
kKPerBlock
/
K1
;
...
...
@@ -440,7 +445,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kMPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
K2
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
VHeaddim
;
constexpr
index_t
K1
=
GetAlignmentOGrad
<
Problem
>
();
constexpr
index_t
K0
=
kKPerBlock
/
K1
;
...
...
@@ -811,44 +816,12 @@ struct BlockFmhaBwdPipelineDefaultPolicy
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeKLdsWriteBlockDescriptor
()
{
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
K0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
QKHeaddim
;
constexpr
index_t
kKPack
=
GetSmemKPackK
<
Problem
>
();
return
MakeXLdsBlockDescriptor
<
kNPerBlock
,
kKPerBlock
,
kKPack
>
();
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeKRegSliceBlockDescriptor
()
{
using
BlockGemm
=
remove_cvref_t
<
decltype
(
GetQKBlockGemm
<
Problem
>
())
>
;
constexpr
auto
config
=
BlockGemm
::
Policy
::
template
GetWarpGemmMWarpNWarp
<
Problem
>();
using
WarpGemm
=
remove_cvref_t
<
decltype
(
config
.
template
at
<
0
>())
>
;
constexpr
index_t
MWarp
=
Problem
::
BlockFmhaShape
::
Gemm0BlockWarps
::
at
(
number
<
0
>
{});
constexpr
index_t
NWarp
=
Problem
::
BlockFmhaShape
::
Gemm0BlockWarps
::
at
(
number
<
1
>
{});
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kK0
;
constexpr
index_t
NIterPerWarp
=
kNPerBlock
/
(
NWarp
*
WarpGemm
::
kN
);
constexpr
index_t
KIterPerWarp
=
kKPerBlock
/
WarpGemm
::
kK
;
constexpr
auto
k_block_outer_dstr_encoding
=
tile_distribution_encoding
<
sequence
<
MWarp
>
,
tuple
<
sequence
<
NIterPerWarp
,
NWarp
>
,
sequence
<
KIterPerWarp
>>
,
tuple
<
sequence
<
0
,
1
>>
,
tuple
<
sequence
<
0
,
1
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
0
>>
{};
constexpr
auto
k_block_dstr_encode
=
detail
::
make_embed_tile_distribution_encoding
(
k_block_outer_dstr_encoding
,
typename
WarpGemm
::
BWarpDstrEncoding
{});
constexpr
auto
k_block_dstr
=
make_static_tile_distribution
(
k_block_dstr_encode
);
return
k_block_dstr
;
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeKRegBlockDescriptor
()
{
...
...
@@ -860,7 +833,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr
index_t
NWarp
=
Problem
::
BlockFmhaShape
::
Gemm0BlockWarps
::
at
(
number
<
1
>
{});
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
QKHeaddim
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
K0
;
constexpr
index_t
NIterPerWarp
=
kNPerBlock
/
(
NWarp
*
WarpGemm
::
kN
);
constexpr
index_t
KIterPerWarp
=
kKPerBlock
/
WarpGemm
::
kK
;
...
...
@@ -885,45 +858,13 @@ struct BlockFmhaBwdPipelineDefaultPolicy
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeVLdsWriteBlockDescriptor
()
{
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
K2
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
VHeaddim
;
constexpr
index_t
kVPack
=
GetSmemKPackV
<
Problem
>
();
return
MakeXLdsBlockDescriptor
<
kNPerBlock
,
kKPerBlock
,
kVPack
>
();
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeVRegSliceBlockDescriptor
()
{
using
BlockGemm
=
remove_cvref_t
<
decltype
(
GetOGradVBlockGemm
<
Problem
>
())
>
;
constexpr
auto
config
=
BlockGemm
::
Policy
::
template
GetWarpGemmMWarpNWarp
<
Problem
>();
using
WarpGemm
=
remove_cvref_t
<
decltype
(
config
.
template
at
<
0
>())
>
;
constexpr
index_t
MWarp
=
Problem
::
BlockFmhaShape
::
Gemm2BlockWarps
::
at
(
number
<
0
>
{});
constexpr
index_t
NWarp
=
Problem
::
BlockFmhaShape
::
Gemm2BlockWarps
::
at
(
number
<
1
>
{});
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kK2
;
constexpr
index_t
NIterPerWarp
=
kNPerBlock
/
(
NWarp
*
WarpGemm
::
kN
);
constexpr
index_t
KIterPerWarp
=
kKPerBlock
/
WarpGemm
::
kK
;
constexpr
auto
v_block_outer_dstr_encoding
=
tile_distribution_encoding
<
sequence
<
MWarp
>
,
tuple
<
sequence
<
NIterPerWarp
,
NWarp
>
,
sequence
<
KIterPerWarp
>>
,
tuple
<
sequence
<
0
,
1
>>
,
tuple
<
sequence
<
0
,
1
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
0
>>
{};
constexpr
auto
v_block_dstr_encode
=
detail
::
make_embed_tile_distribution_encoding
(
v_block_outer_dstr_encoding
,
typename
WarpGemm
::
BWarpDstrEncoding
{});
constexpr
auto
v_block_dstr
=
make_static_tile_distribution
(
v_block_dstr_encode
);
return
v_block_dstr
;
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeVRegBlockDescriptor
()
{
...
...
@@ -935,7 +876,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr
index_t
NWarp
=
Problem
::
BlockFmhaShape
::
Gemm2BlockWarps
::
at
(
number
<
1
>
{});
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
VHeaddim
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
K2
;
constexpr
index_t
NIterPerWarp
=
kNPerBlock
/
(
NWarp
*
WarpGemm
::
kN
);
constexpr
index_t
KIterPerWarp
=
kKPerBlock
/
WarpGemm
::
kK
;
...
...
@@ -961,7 +902,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
{
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
K0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
QKHeaddim
;
constexpr
index_t
K1
=
GetAlignmentK
<
Problem
>
();
constexpr
index_t
K0
=
kKPerBlock
/
K1
;
...
...
@@ -1043,7 +984,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeQLdsBlockDescriptor
()
{
constexpr
index_t
kMPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
K0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
QKHeaddim
;
constexpr
index_t
kKPack
=
GetSmemKPackQ
<
Problem
>
();
...
...
@@ -1087,7 +1028,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
{
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
K0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
QKHeaddim
;
constexpr
index_t
K1
=
GetAlignmentQ
<
Problem
>
();
constexpr
index_t
K0
=
kKPerBlock
/
K1
;
...
...
@@ -1250,7 +1191,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
{
// Hold full block data
constexpr
index_t
kMPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
K2
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
VHeaddim
;
constexpr
index_t
kKPack
=
GetSmemKPackOGrad
<
Problem
>
();
...
...
@@ -1294,7 +1235,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
{
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
K2
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
VHeaddim
;
constexpr
index_t
K1
=
GetAlignmentOGrad
<
Problem
>
();
constexpr
index_t
K0
=
kKPerBlock
/
K1
;
...
...
@@ -1727,7 +1668,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
}
template
<
>
CK_TILE_DEVICE
static
constexpr
void
GemmStagedScheduler
<
0
>
()
CK_TILE_DEVICE
constexpr
void
GemmStagedScheduler
<
0
>
()
{
// Mem: Q, LSE, OGrad, D global load, OGrad^T LDS load
// Comp: Q x K
...
...
@@ -1759,7 +1700,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
}
template
<
>
CK_TILE_DEVICE
static
constexpr
void
GemmStagedScheduler
<
1
>
()
CK_TILE_DEVICE
constexpr
void
GemmStagedScheduler
<
1
>
()
{
// Mem: Q^T LDS load
// Comp: OGrad x V
...
...
@@ -1777,7 +1718,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
}
template
<
>
CK_TILE_DEVICE
static
constexpr
void
GemmStagedScheduler
<
2
>
()
CK_TILE_DEVICE
constexpr
void
GemmStagedScheduler
<
2
>
()
{
// Mem: Q, QT, LSE, OGrad, OGradT, D, LDS store
// Comp: PT x OGrad
...
...
@@ -1796,7 +1737,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
}
template
<
>
CK_TILE_DEVICE
static
constexpr
void
GemmStagedScheduler
<
3
>
()
CK_TILE_DEVICE
constexpr
void
GemmStagedScheduler
<
3
>
()
{
// Mem: SGradT LDS store, SGrad, Q, LSE LDS load.
// Comp: SGradT x QT
...
...
@@ -1830,7 +1771,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
}
template
<
>
CK_TILE_DEVICE
static
constexpr
void
GemmStagedScheduler
<
4
>
()
CK_TILE_DEVICE
constexpr
void
GemmStagedScheduler
<
4
>
()
{
// Mem: SGrad, OGrad, D LDS load.
// Comp: SGrad x KT
...
...
@@ -1854,6 +1795,8 @@ struct BlockFmhaBwdPipelineDefaultPolicy
static
constexpr
index_t
kN0
=
Problem
::
BlockFmhaShape
::
kN0
;
static
constexpr
index_t
kQKHeaddim
=
Problem
::
BlockFmhaShape
::
kQKHeaddim
;
static
constexpr
index_t
kVHeaddim
=
Problem
::
BlockFmhaShape
::
kVHeaddim
;
static
constexpr
index_t
kK0
=
Problem
::
BlockFmhaShape
::
kK0
;
static
constexpr
index_t
kK2
=
Problem
::
BlockFmhaShape
::
kK2
;
static
constexpr
index_t
kK4
=
Problem
::
BlockFmhaShape
::
kK4
;
static
constexpr
index_t
WarpGemmM
=
...
...
@@ -1868,14 +1811,12 @@ struct BlockFmhaBwdPipelineDefaultPolicy
// Compute
static
constexpr
index_t
Gemm0MFMA
=
kM0
*
kN0
*
kQKHeaddim
/
(
kBlockSize
/
get_warp_size
()
*
WarpGemmM
*
WarpGemmN
*
WarpGemmK
);
kM0
*
kN0
*
kK0
/
(
kBlockSize
/
get_warp_size
()
*
WarpGemmM
*
WarpGemmN
*
WarpGemmK
);
static
constexpr
index_t
Gemm1MFMA
=
kM0
*
kN0
*
kVHeaddim
/
(
kBlockSize
/
get_warp_size
()
*
WarpGemmM
*
WarpGemmN
*
WarpGemmK
);
static
constexpr
index_t
Gemm2MFMA
=
kN0
*
kVHeaddim
*
kM0
/
(
kBlockSize
/
get_warp_size
()
*
WarpGemmM
*
WarpGemmN
*
WarpGemmK
);
static
constexpr
index_t
Gemm2MFMA
=
kM0
*
kN0
*
kK2
/
(
kBlockSize
/
get_warp_size
()
*
WarpGemmM
*
WarpGemmN
*
WarpGemmK
);
static
constexpr
index_t
Gemm3MFMA
=
kN0
*
kQKHeaddim
*
kM0
/
(
kBlockSize
/
get_warp_size
()
*
WarpGemmM
*
WarpGemmN
*
WarpGemmK
);
...
...
@@ -1898,13 +1839,12 @@ struct BlockFmhaBwdPipelineDefaultPolicy
kM0
*
kQKHeaddim
/
get_warp_size
()
/
GetTransposedAlignmentQ
<
Problem
>
();
static
constexpr
index_t
SGradT_LDS_READ_P1
=
kM0
*
kK4
/
(
get_warp_size
()
*
Gemm4MWarp
)
/
GetSmemKPackSGrad
<
Problem
>
();
static
constexpr
index_t
Q_LDS_READ
=
kM0
*
kQKHeaddim
/
kBlockSize
/
GetAlignmentQ
<
Problem
>
();
static
constexpr
index_t
Q_LDS_READ
=
kM0
*
kK0
/
kBlockSize
/
GetAlignmentQ
<
Problem
>
();
static
constexpr
index_t
LSE_LDS_READ
=
WarpGemmM
==
16
?
kM0
/
(
4
*
4
)
:
kM0
/
(
2
*
4
);
static
constexpr
index_t
SGradT_LDS_READ_P2
=
kM0
*
(
kN0
-
kK4
)
/
(
get_warp_size
()
*
Gemm4MWarp
)
/
GetSmemKPackSGrad
<
Problem
>
();
static
constexpr
index_t
OGrad_LDS_READ
=
kM0
*
k
VHeaddim
/
kBlockSize
/
GetAlignmentOGrad
<
Problem
>
();
kM0
*
k
K2
/
kBlockSize
/
GetAlignmentOGrad
<
Problem
>
();
static
constexpr
index_t
D_LDS_READ
=
WarpGemmM
==
16
?
kM0
/
(
4
*
4
)
:
kM0
/
(
2
*
4
);
// LDS Write
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp
View file @
63b152d6
...
...
@@ -107,7 +107,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline
const
LSEElementFunction
&
lse_element_func
,
const
OaccElementFunction
&
o_acc_element_func
,
index_t
num_splits
,
index_t
max_
seqlen_q
,
index_t
seqlen_q
,
void
*
smem_ptr
)
const
{
// lse_acc tile in LDS
...
...
@@ -172,22 +172,27 @@ struct BlockFmhaFwdSplitKVCombinePipeline
lse_accum
,
sequence
<
1
>
{},
f_max
,
-
numeric
<
LSEDataType
>::
infinity
());
block_tile_reduce_sync
(
lse_max
,
f_max
,
bool_constant
<
false
>
{});
static
const
auto
get_validated_m
=
[](
LSEDataType
raw_m
)
{
return
raw_m
==
-
numeric
<
LSEDataType
>::
infinity
()
?
type_convert
<
LSEDataType
>
(
0.
f
)
:
raw_m
;
};
decltype
(
lse_accum
)
lse_exp
;
{
constexpr
auto
spans
=
decltype
(
lse_exp
)
::
get_distributed_spans
();
sweep_tile_span
(
spans
[
number
<
0
>
{}],
[
&
](
auto
idx0
)
{
constexpr
auto
i_idx
=
make_tuple
(
idx0
);
if
(
lse_max
[
i_idx
]
==
-
numeric
<
LSEDataType
>::
infinity
())
{
sweep_tile_span
(
spans
[
number
<
1
>
{}],
[
&
](
auto
idx1
)
{
constexpr
auto
i_j_idx
=
make_tuple
(
idx0
,
idx1
);
lse_exp
(
i_j_idx
)
=
ck_tile
::
exp
(
lse_accum
(
i_j_idx
)
-
get_validated_m
(
lse_max
(
i_idx
)));
lse_exp
(
i_j_idx
)
=
ck_tile
::
type_convert
<
LSEDataType
>
(
0.0
f
);
});
}
else
{
sweep_tile_span
(
spans
[
number
<
1
>
{}],
[
&
](
auto
idx1
)
{
constexpr
auto
i_j_idx
=
make_tuple
(
idx0
,
idx1
);
lse_exp
(
i_j_idx
)
=
ck_tile
::
exp
(
lse_accum
(
i_j_idx
)
-
lse_max
(
i_idx
));
});
}
});
}
...
...
@@ -201,15 +206,10 @@ struct BlockFmhaFwdSplitKVCombinePipeline
sweep_tile_span
(
spans
[
number
<
0
>
{}],
[
&
](
auto
idx0
)
{
constexpr
auto
i_idx
=
make_tuple
(
idx0
);
if
(
lse_sum
(
i_idx
)
==
0.
f
||
lse_sum
(
i_idx
)
!=
lse_sum
(
i_idx
))
{
lse_logsum
(
i_idx
)
=
numeric
<
LSEDataType
>::
infinity
();
}
if
(
lse_sum
[
i_idx
]
==
ck_tile
::
type_convert
<
LSEDataType
>
(
0.0
f
))
lse_logsum
(
i_idx
)
=
-
numeric
<
LSEDataType
>::
infinity
();
else
{
lse_logsum
(
i_idx
)
=
ck_tile
::
log
(
lse_sum
(
i_idx
))
+
get_validated_m
(
lse_max
(
i_idx
));
}
lse_logsum
(
i_idx
)
=
ck_tile
::
log
(
lse_sum
(
i_idx
))
+
lse_max
(
i_idx
);
});
}
...
...
@@ -218,6 +218,25 @@ struct BlockFmhaFwdSplitKVCombinePipeline
constexpr
auto
spans
=
decltype
(
lse_accum
)
::
get_distributed_spans
();
sweep_tile_span
(
spans
[
number
<
0
>
{}],
[
&
](
auto
idx0
)
{
constexpr
auto
i_idx
=
make_tuple
(
idx0
);
if
(
lse_logsum
(
i_idx
)
==
-
numeric
<
LSEDataType
>::
infinity
())
{
sweep_tile_span
(
spans
[
number
<
1
>
{}],
[
&
](
auto
idx1
)
{
constexpr
auto
i_j_idx
=
make_tuple
(
idx0
,
idx1
);
const
auto
x_indices
=
get_x_indices_from_distributed_indices
(
lse_accum
.
get_tile_distribution
(),
i_j_idx
);
const
auto
col
=
x_indices
.
at
(
number
<
1
>
{});
if
(
col
<
num_splits
)
{
const
auto
row
=
x_indices
.
at
(
number
<
0
>
{});
lse_acc_lds
(
row
,
col
)
=
ck_tile
::
type_convert
<
LSEDataType
>
(
0.0
f
);
}
});
}
else
{
sweep_tile_span
(
spans
[
number
<
1
>
{}],
[
&
](
auto
idx1
)
{
constexpr
auto
i_j_idx
=
make_tuple
(
idx0
,
idx1
);
...
...
@@ -233,22 +252,13 @@ struct BlockFmhaFwdSplitKVCombinePipeline
ck_tile
::
exp
(
lse_accum
(
i_j_idx
)
-
lse_logsum
(
i_idx
));
}
});
}
});
}
block_sync_lds
();
if
constexpr
(
kStoreLSE
)
{
constexpr
auto
spans
=
decltype
(
lse_logsum
)
::
get_distributed_spans
();
sweep_tile_span
(
spans
[
number
<
0
>
{}],
[
&
](
auto
idx0
)
{
constexpr
auto
i_idx
=
make_tuple
(
idx0
);
if
(
lse_logsum
(
i_idx
)
==
numeric
<
LSEDataType
>::
infinity
())
{
lse_logsum
(
i_idx
)
=
-
numeric
<
LSEDataType
>::
infinity
();
}
});
store_tile
(
lse_dram_window_tmp
,
tile_elementwise_in
(
lse_element_func
,
lse_logsum
));
}
...
...
@@ -261,7 +271,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline
auto
o_acc
=
make_static_distributed_tensor
<
OaccDataType
>
(
o_acc_dist
);
clear_tile
(
o_acc
);
const
index_t
padded_
max_
seqlen_q
=
integer_divide_ceil
(
max_
seqlen_q
,
kM0
)
*
kM0
;
const
index_t
padded_seqlen_q
=
integer_divide_ceil
(
seqlen_q
,
kM0
)
*
kM0
;
for
(
index_t
i_split
=
0
;
i_split
<
num_splits
;
++
i_split
)
{
...
...
@@ -282,7 +292,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline
});
}
move_tile_window
(
o_acc_dram_window
,
{
padded_
max_
seqlen_q
,
0
});
move_tile_window
(
o_acc_dram_window
,
{
padded_seqlen_q
,
0
});
}
o_acc
=
tile_elementwise_in
(
o_acc_element_func
,
o_acc
);
...
...
@@ -297,7 +307,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline
const
OaccDramBlockWindow
&
o_acc_dram_block_window
,
LSEDramBlockWindow
&
lse_dram_block_window
,
index_t
num_splits
,
index_t
max_
seqlen_q
,
index_t
seqlen_q
,
void
*
smem_ptr
)
const
{
return
operator
()(
lse_acc_dram_block_window
,
...
...
@@ -306,7 +316,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline
identity
{},
identity
{},
num_splits
,
max_
seqlen_q
,
seqlen_q
,
smem_ptr
);
}
};
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp
View file @
63b152d6
...
...
@@ -21,14 +21,23 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetAlignmentOacc
()
{
using
OaccDataType
=
remove_cvref_t
<
typename
Problem
::
OaccDataType
>
;
return
16
/
sizeof
(
OaccDataType
);
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kMPerBlock
=
Problem
::
kM0
;
constexpr
index_t
kNPerBlock
=
Problem
::
kN1
;
constexpr
index_t
M1
=
kBlockSize
/
get_warp_size
();
constexpr
index_t
M2
=
min
(
kMPerBlock
/
M1
,
get_warp_size
());
constexpr
index_t
N0
=
get_warp_size
()
/
M2
;
constexpr
index_t
N1
=
kNPerBlock
/
N0
;
return
min
(
N1
,
static_cast
<
index_t
>
(
16
/
sizeof
(
OaccDataType
)));
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetAlignmentO
()
{
using
ODataType
=
remove_cvref_t
<
typename
Problem
::
ODataType
>
;
return
16
/
sizeof
(
ODataType
);
return
GetAlignmentOacc
<
Problem
>
();
}
template
<
typename
Problem
>
...
...
@@ -150,16 +159,14 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeOaccDramTileDistribution
()
{
using
OaccDataType
=
remove_cvref_t
<
typename
Problem
::
OaccDataType
>
;
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kMPerBlock
=
Problem
::
kM0
;
constexpr
index_t
kNPerBlock
=
Problem
::
kN1
;
constexpr
index_t
N1
=
16
/
sizeof
(
OaccDataType
);
constexpr
index_t
N0
=
kNPerBlock
/
N1
;
constexpr
index_t
M2
=
get_warp_size
()
/
N0
;
constexpr
index_t
M1
=
kBlockSize
/
get_warp_size
();
constexpr
index_t
M2
=
min
(
kMPerBlock
/
M1
,
get_warp_size
());
constexpr
index_t
N0
=
get_warp_size
()
/
M2
;
constexpr
index_t
N1
=
kNPerBlock
/
N0
;
constexpr
index_t
M0
=
kMPerBlock
/
(
M2
*
M1
);
return
make_static_tile_distribution
(
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp
View file @
63b152d6
...
...
@@ -64,8 +64,6 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
return
kPadSeqLenK
?
1
:
Policy
::
template
GetAlignmentV
<
Problem
>();
}();
static
constexpr
index_t
kAlignmentO
=
kPadHeadDimV
?
1
:
Policy
::
template
GetAlignmentO
<
Problem
>();
static
constexpr
index_t
kAlignmentBias
=
kPadSeqLenK
?
1
:
Policy
::
template
GetAlignmentBias
<
Problem
>();
...
...
@@ -212,8 +210,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
const
auto
[
seqlen_k_start
,
seqlen_k_end
]
=
mask
.
GetTileRangeAlongX
(
q_origin
.
at
(
number
<
0
>
{}),
number
<
kM0
>
{},
number
<
kN0
>
{},
num_splits
,
i_split
);
// check early exit if
masked and
no work to do
.
if
constexpr
(
FmhaMask
::
IsMasking
||
kHasUnevenSplits
)
// check early exit if no work to do
if
constexpr
(
FmhaMask
::
IsMasking
||
kPadSeqLenK
||
kHasUnevenSplits
)
{
const
index_t
original_num_total_loop
=
integer_divide_ceil
(
seqlen_k_end
-
seqlen_k_start
,
kN0
);
...
...
@@ -616,7 +614,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
sweep_tile_span
(
o_spans
[
number
<
0
>
{}],
[
&
](
auto
idx0
)
{
constexpr
auto
i_idx
=
make_tuple
(
idx0
);
const
auto
tmp
=
[
&
]()
{
if
constexpr
(
FmhaMask
::
IsMasking
)
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
||
FmhaMask
::
IsMasking
)
{
return
l
[
i_idx
]
==
0.
f
?
0.
f
:
1
/
l
[
i_idx
];
}
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp
View file @
63b152d6
...
...
@@ -5,7 +5,8 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_problem.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_problem.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
...
...
@@ -75,10 +76,11 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetQKBlockGemm
()
{
using
Block
GemmProblem
=
BlockGemmPipelineProblem
<
typename
Problem
::
QDataType
,
using
GemmProblem
=
BlockGemmProblem
<
typename
Problem
::
QDataType
,
typename
Problem
::
KDataType
,
typename
Problem
::
SaccDataType
,
Problem
::
kBlockSize
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kM0
,
Problem
::
BlockFmhaShape
::
kN0
,
Problem
::
BlockFmhaShape
::
kK0
>
,
...
...
@@ -116,7 +118,7 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true>
typename
Problem
::
BlockFmhaShape
::
Gemm0BlockWarps
,
decltype
(
warp_gemm
)
>
;
return
BlockGemmARegBSmemCRegV2
<
Block
GemmProblem
,
BlockGemmPolicy
>
{};
return
BlockGemmARegBSmemCRegV2
<
GemmProblem
,
BlockGemmPolicy
>
{};
}
};
...
...
@@ -199,10 +201,11 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetQKBlockGemm
()
{
using
Block
GemmProblem
=
BlockGemmPipelineProblem
<
typename
Problem
::
QDataType
,
using
GemmProblem
=
BlockGemmProblem
<
typename
Problem
::
QDataType
,
typename
Problem
::
KDataType
,
typename
Problem
::
SaccDataType
,
Problem
::
kBlockSize
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kM0
,
Problem
::
BlockFmhaShape
::
kN0
,
Problem
::
BlockFmhaShape
::
kK0
>
,
...
...
@@ -240,7 +243,7 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false>
typename
Problem
::
BlockFmhaShape
::
Gemm0BlockWarps
,
decltype
(
warp_gemm
)
>
;
return
BlockGemmASmemBSmemCRegV1
<
Block
GemmProblem
,
BlockGemmPolicy
>
{};
return
BlockGemmASmemBSmemCRegV1
<
GemmProblem
,
BlockGemmPolicy
>
{};
}
};
...
...
@@ -954,10 +957,11 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetKVBlockGemm
()
{
using
Block
GemmProblem
=
BlockGemmPipelineProblem
<
typename
Problem
::
PDataType
,
using
GemmProblem
=
BlockGemmProblem
<
typename
Problem
::
PDataType
,
typename
Problem
::
VDataType
,
typename
Problem
::
OaccDataType
,
Problem
::
kBlockSize
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kM0
,
Problem
::
BlockFmhaShape
::
kN1
,
Problem
::
BlockFmhaShape
::
kK1
>
,
...
...
@@ -996,7 +1000,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
typename
Problem
::
OaccDataType
,
typename
Problem
::
BlockFmhaShape
::
Gemm1BlockWarps
,
WarpGemm
>
;
return
BlockGemmARegBSmemCRegV2
<
Block
GemmProblem
,
BlockGemmPolicy
>
{};
return
BlockGemmARegBSmemCRegV2
<
GemmProblem
,
BlockGemmPolicy
>
{};
}
};
...
...
include/ck_tile/ops/gemm.hpp
View file @
63b152d6
...
...
@@ -23,12 +23,14 @@
#include "ck_tile/ops/gemm/block/block_gemm_problem.hpp"
#include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp"
#include "ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp"
#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1.hpp"
#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2.hpp"
#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_problem.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp"
...
...
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
View file @
63b152d6
...
...
@@ -11,20 +11,12 @@
namespace
ck_tile
{
template
<
typename
TilePartitioner_
,
typename
GemmPipeline_
,
typename
EpiloguePipeline_
,
typename
LayoutA_
,
typename
LayoutB_
,
typename
LayoutC_
>
template
<
typename
TilePartitioner_
,
typename
GemmPipeline_
,
typename
EpiloguePipeline_
>
struct
GemmKernel
{
using
TilePartitioner
=
remove_cvref_t
<
TilePartitioner_
>
;
using
GemmPipeline
=
remove_cvref_t
<
GemmPipeline_
>
;
using
EpiloguePipeline
=
remove_cvref_t
<
EpiloguePipeline_
>
;
using
LayoutA
=
remove_cvref_t
<
LayoutA_
>
;
using
LayoutB
=
remove_cvref_t
<
LayoutB_
>
;
using
LayoutC
=
remove_cvref_t
<
LayoutC_
>
;
static
constexpr
index_t
KernelBlockSize
=
GemmPipeline
::
kBlockSize
;
using
ADataType
=
remove_cvref_t
<
typename
GemmPipeline
::
ADataType
>
;
...
...
@@ -32,6 +24,10 @@ struct GemmKernel
using
CAccDataType
=
remove_cvref_t
<
typename
GemmPipeline
::
CDataType
>
;
using
CODataType
=
remove_cvref_t
<
typename
EpiloguePipeline
::
ODataType
>
;
using
LayoutA
=
remove_cvref_t
<
typename
GemmPipeline
::
LayoutA
>
;
using
LayoutB
=
remove_cvref_t
<
typename
GemmPipeline
::
LayoutB
>
;
using
LayoutC
=
remove_cvref_t
<
typename
GemmPipeline
::
LayoutC
>
;
__host__
static
constexpr
auto
GridSize
(
index_t
M_size
,
index_t
N_size
,
index_t
Batch_size
)
{
return
TilePartitioner
::
GridSize
(
M_size
,
N_size
,
Batch_size
);
...
...
@@ -184,6 +180,7 @@ struct GemmKernel
c_pad_view
,
make_tuple
(
number
<
TilePartitioner
::
kM
>
{},
number
<
TilePartitioner
::
kN
>
{}),
{
i_m
,
i_n
});
EpiloguePipeline
{}(
CBlockWindow_pad
,
acc
);
}
};
...
...
include/ck_tile/ops/gemm/pipeline/
block_
gemm_pipeline_agmem_bgmem_creg_v1.hpp
→
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp
View file @
63b152d6
...
...
@@ -4,15 +4,15 @@
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/pipeline/
block_
gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
namespace
ck_tile
{
// A Tile Window: global memory
// B Tile Window: global memory
// C Distributed tensor: register
template
<
typename
Problem
,
typename
Policy
=
Block
GemmPipelineAGmemBGmemCRegV1DefaultPolicy
>
struct
Block
GemmPipelineAGmemBGmemCRegV1
template
<
typename
Problem
,
typename
Policy
=
GemmPipelineAGmemBGmemCRegV1DefaultPolicy
>
struct
GemmPipelineAGmemBGmemCRegV1
{
using
ADataType
=
remove_cvref_t
<
typename
Problem
::
ADataType
>
;
using
BDataType
=
remove_cvref_t
<
typename
Problem
::
BDataType
>
;
...
...
@@ -33,6 +33,10 @@ struct BlockGemmPipelineAGmemBGmemCRegV1
static
constexpr
bool
kPadB
=
Problem
::
kPadB
;
static
constexpr
bool
kPadC
=
Problem
::
kPadC
;
using
LayoutA
=
remove_cvref_t
<
typename
Problem
::
LayoutA
>
;
using
LayoutB
=
remove_cvref_t
<
typename
Problem
::
LayoutB
>
;
using
LayoutC
=
remove_cvref_t
<
typename
Problem
::
LayoutC
>
;
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetStaticLdsSize
()
{
return
ck_tile
::
integer_divide_ceil
(
...
...
include/ck_tile/ops/gemm/pipeline/
block_
gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp
→
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp
View file @
63b152d6
...
...
@@ -7,9 +7,9 @@
namespace
ck_tile
{
// Default policy for
Block
GemmPipelineAGmemBGmemCRegV1
// Default policy for GemmPipelineAGmemBGmemCRegV1
// Default policy class should not be templated, put template on member functions instead
struct
Block
GemmPipelineAGmemBGmemCRegV1DefaultPolicy
struct
GemmPipelineAGmemBGmemCRegV1DefaultPolicy
{
#if 0
// 2d
...
...
include/ck_tile/ops/gemm/pipeline/
block_
gemm_pipeline_agmem_bgmem_creg_v2.hpp
→
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp
View file @
63b152d6
...
...
@@ -4,15 +4,15 @@
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/pipeline/
block_
gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp"
namespace
ck_tile
{
// A Tile Window: global memory
// B Tile Window: global memory
// C Distributed tensor: register
template
<
typename
Problem
,
typename
Policy
=
Block
GemmPipelineAGmemBGmemCRegV2DefaultPolicy
>
struct
Block
GemmPipelineAGmemBGmemCRegV2
template
<
typename
Problem
,
typename
Policy
=
GemmPipelineAGmemBGmemCRegV2DefaultPolicy
>
struct
GemmPipelineAGmemBGmemCRegV2
{
using
ADataType
=
remove_cvref_t
<
typename
Problem
::
ADataType
>
;
using
BDataType
=
remove_cvref_t
<
typename
Problem
::
BDataType
>
;
...
...
include/ck_tile/ops/gemm/pipeline/
block_
gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp
→
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp
View file @
63b152d6
...
...
@@ -7,12 +7,11 @@
namespace
ck_tile
{
// Default policy for
Block
GemmPipelineAGmemBGmemCRegV2
// Default policy for GemmPipelineAGmemBGmemCRegV2
// Default policy class should not be templated, put template on member functions instead
// NOTE: policy should be binded to its corresponding operation. It's just a coincidence that
// BlockGemmPipelineAGmemBGmemCRegV2DefaultPolicy is the same as
// BlockGemmPipelineAGmemBGmemCRegV1DefaultPolicy
using
BlockGemmPipelineAGmemBGmemCRegV2DefaultPolicy
=
BlockGemmPipelineAGmemBGmemCRegV1DefaultPolicy
;
// GemmPipelineAGmemBGmemCRegV2DefaultPolicy is the same as
// GemmPipelineAGmemBGmemCRegV1DefaultPolicy
using
GemmPipelineAGmemBGmemCRegV2DefaultPolicy
=
GemmPipelineAGmemBGmemCRegV1DefaultPolicy
;
}
// namespace ck_tile
include/ck_tile/ops/gemm/pipeline/
block_
gemm_pipeline_problem.hpp
→
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp
View file @
63b152d6
...
...
@@ -13,20 +13,23 @@ template <typename ADataType_,
typename
BDataType_
,
typename
CDataType_
,
typename
BlockGemmShape_
,
bool
kPadA_
=
false
,
bool
kPadB_
=
false
,
bool
kPadC_
=
false
>
struct
BlockGemmPipelineProblem
typename
TileGemmTraits_
>
struct
GemmPipelineProblem
{
using
ADataType
=
remove_cvref_t
<
ADataType_
>
;
using
BDataType
=
remove_cvref_t
<
BDataType_
>
;
using
CDataType
=
remove_cvref_t
<
CDataType_
>
;
using
BlockGemmShape
=
remove_cvref_t
<
BlockGemmShape_
>
;
using
GemmTraits
=
remove_cvref_t
<
TileGemmTraits_
>
;
static
constexpr
index_t
kBlockSize
=
BlockGemmShape
::
NumWarps
*
get_warp_size
();
static
constexpr
bool
kPadA
=
kPadA_
;
static
constexpr
bool
kPadB
=
kPadB_
;
static
constexpr
bool
kPadC
=
kPadC_
;
static
constexpr
bool
kPadA
=
GemmTraits
::
kPadA
;
static
constexpr
bool
kPadB
=
GemmTraits
::
kPadB
;
static
constexpr
bool
kPadC
=
GemmTraits
::
kPadC
;
using
LayoutA
=
remove_cvref_t
<
typename
GemmTraits
::
LayoutA
>
;
using
LayoutB
=
remove_cvref_t
<
typename
GemmTraits
::
LayoutB
>
;
using
LayoutC
=
remove_cvref_t
<
typename
GemmTraits
::
LayoutC
>
;
static
constexpr
index_t
AlignmentA
=
kPadA
?
1
:
VectorLoadSize
/
sizeof
(
ADataType
);
static
constexpr
index_t
AlignmentB
=
kPadB
?
1
:
VectorLoadSize
/
sizeof
(
BDataType
);
...
...
include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp
0 → 100644
View file @
63b152d6
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
namespace
ck_tile
{
// UniversalGemm Policy
template
<
typename
LayoutA_
,
typename
LayoutB_
,
typename
LayoutC_
>
struct
UniversalGemmPipelineAgBgCrPolicy
{
using
LayoutA
=
remove_cvref_t
<
LayoutA_
>
;
using
LayoutB
=
remove_cvref_t
<
LayoutB_
>
;
using
LayoutC
=
remove_cvref_t
<
LayoutC_
>
;
static
constexpr
auto
I0
=
number
<
0
>
{};
static
constexpr
auto
I1
=
number
<
1
>
{};
static
constexpr
auto
I2
=
number
<
2
>
{};
static
constexpr
bool
TransposeC
=
true
;
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeALdsBlockDescriptor
()
{
using
WarpGemm
=
WarpGemmMfmaDispatcher
<
typename
Problem
::
ADataType
,
typename
Problem
::
BDataType
,
typename
Problem
::
CDataType
,
Problem
::
BlockGemmShape
::
WarpTile
::
at
(
I0
),
Problem
::
BlockGemmShape
::
WarpTile
::
at
(
I1
),
Problem
::
BlockGemmShape
::
WarpTile
::
at
(
I2
),
TransposeC
>
;
using
ADataType
=
remove_cvref_t
<
typename
Problem
::
ADataType
>
;
constexpr
index_t
MPerBlock
=
Problem
::
BlockGemmShape
::
kM
;
constexpr
index_t
KPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
constexpr
index_t
K1
=
WarpGemm
::
kK
;
constexpr
index_t
K0
=
KPerBlock
/
K1
;
if
constexpr
(
std
::
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
LayoutA
>::
value
)
{
constexpr
auto
MLdsLayer
=
32
*
4
/
KPerBlock
/
sizeof
(
ADataType
)
<
1
?
1
:
32
*
4
/
KPerBlock
/
sizeof
(
ADataType
);
constexpr
auto
a_lds_block_desc
=
make_naive_tensor_descriptor
(
make_tuple
(
K0
*
number
<
MLdsLayer
>
{},
number
<
MPerBlock
/
MLdsLayer
>
{},
K1
),
make_tuple
(
K1
,
number
<
KPerBlock
*
MLdsLayer
>
{},
I1
));
constexpr
auto
a_lds_block_desc_permuted
=
transform_tensor_descriptor
(
a_lds_block_desc
,
make_tuple
(
make_xor_transform
(
make_tuple
(
number
<
MPerBlock
/
MLdsLayer
>
{},
number
<
K0
*
MLdsLayer
>
{})),
make_pass_through_transform
(
K1
)),
make_tuple
(
sequence
<
1
,
0
>
{},
sequence
<
2
>
{}),
make_tuple
(
sequence
<
1
,
0
>
{},
sequence
<
2
>
{}));
constexpr
auto
a_lds_block_desc_ak0_kMLdsLayer_m_ak1
=
transform_tensor_descriptor
(
a_lds_block_desc_permuted
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
K0
,
number
<
MLdsLayer
>
{})),
make_pass_through_transform
(
number
<
MPerBlock
/
MLdsLayer
>
{}),
make_pass_through_transform
(
K1
)),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{},
sequence
<
2
>
{}),
make_tuple
(
sequence
<
0
,
2
>
{},
sequence
<
1
>
{},
sequence
<
3
>
{}));
constexpr
auto
a_lds_block_desc_m_k
=
transform_tensor_descriptor
(
a_lds_block_desc_ak0_kMLdsLayer_m_ak1
,
make_tuple
(
make_merge_transform_v3_division_mod
(
make_tuple
(
K0
,
K1
)),
make_merge_transform_v3_division_mod
(
make_tuple
(
number
<
MPerBlock
/
MLdsLayer
>
{},
number
<
MLdsLayer
>
{}))),
make_tuple
(
sequence
<
0
,
3
>
{},
sequence
<
1
,
2
>
{}),
make_tuple
(
sequence
<
1
>
{},
sequence
<
0
>
{}));
return
a_lds_block_desc_m_k
;
}
else
// ColumnMajor A
{
// kfold and mpair dimension is not always required.
// more dimension in merge_transform increase the difficulty of generating immarg offset
// for compiler.
constexpr
auto
M0
=
get_warp_size
()
*
Problem
::
BlockGemmShape
::
BlockWarps
::
at
(
I0
);
constexpr
auto
M1
=
MPerBlock
/
M0
;
constexpr
auto
KThreadWrite
=
Problem
::
kBlockSize
/
M0
;
constexpr
auto
K0PerThreadWrite
=
K0
/
KThreadWrite
;
constexpr
auto
KThreadRead
=
64
/
WarpGemm
::
kM
;
constexpr
auto
K0PerThreadRead
=
K0
/
KThreadRead
;
constexpr
auto
kfold
=
(
K1
*
M0
*
sizeof
(
ADataType
)
>
128
)
?
1
:
128
/
(
K1
*
M0
*
sizeof
(
ADataType
));
constexpr
auto
KThreadReadPerm
=
(
kfold
*
K0PerThreadWrite
/
K0PerThreadRead
)
>
1
?
KThreadRead
/
(
kfold
*
K0PerThreadWrite
/
K0PerThreadRead
)
:
KThreadRead
;
// 1<=mpair<=kN0
constexpr
auto
mpair
=
(
K1
*
WarpGemm
::
kM
*
sizeof
(
ADataType
)
>
128
)
?
1
:
((
128
/
(
K1
*
WarpGemm
::
kM
*
sizeof
(
ADataType
)))
>
M0
?
M0
:
128
/
(
K1
*
WarpGemm
::
kM
*
sizeof
(
ADataType
)));
constexpr
auto
a_lds_block_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
number
<
KThreadWrite
/
kfold
/
KThreadReadPerm
>
{},
number
<
K0PerThreadWrite
>
{},
number
<
KThreadReadPerm
*
M1
>
{},
number
<
kfold
*
M0
/
mpair
>
{},
number
<
mpair
>
{},
K1
));
constexpr
auto
a_lds_block_desc_permuted
=
transform_tensor_descriptor
(
a_lds_block_desc
,
make_tuple
(
make_pass_through_transform
(
number
<
KThreadWrite
/
kfold
/
KThreadReadPerm
>
{}),
make_pass_through_transform
(
number
<
K0PerThreadWrite
>
{}),
make_xor_transform
(
make_tuple
(
number
<
KThreadReadPerm
*
M1
>
{},
number
<
kfold
*
M0
/
mpair
>
{})),
make_pass_through_transform
(
number
<
mpair
>
{}),
make_pass_through_transform
(
K1
)),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{},
sequence
<
2
,
3
>
{},
sequence
<
4
>
{},
sequence
<
5
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{},
sequence
<
2
,
3
>
{},
sequence
<
4
>
{},
sequence
<
5
>
{}));
constexpr
auto
a_lds_block_desc_unmerged
=
transform_tensor_descriptor
(
a_lds_block_desc_permuted
,
make_tuple
(
make_pass_through_transform
(
number
<
KThreadWrite
/
kfold
/
KThreadReadPerm
>
{}),
make_pass_through_transform
(
number
<
K0PerThreadWrite
>
{}),
make_unmerge_transform
(
make_tuple
(
number
<
KThreadReadPerm
>
{},
number
<
M1
>
{})),
make_unmerge_transform
(
make_tuple
(
number
<
kfold
>
{},
number
<
M0
/
mpair
>
{})),
make_pass_through_transform
(
number
<
mpair
>
{}),
make_pass_through_transform
(
K1
)),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{},
sequence
<
2
>
{},
sequence
<
3
>
{},
sequence
<
4
>
{},
sequence
<
5
>
{}),
make_tuple
(
sequence
<
1
>
{},
sequence
<
2
>
{},
sequence
<
0
,
3
>
{},
sequence
<
4
,
5
>
{},
sequence
<
6
>
{},
sequence
<
7
>
{}));
constexpr
auto
a_lds_block_desc_m_k
=
transform_tensor_descriptor
(
a_lds_block_desc_unmerged
,
make_tuple
(
make_merge_transform_v3_division_mod
(
make_tuple
(
number
<
KThreadReadPerm
>
{},
number
<
KThreadWrite
/
kfold
/
KThreadReadPerm
>
{},
number
<
kfold
>
{},
number
<
K0PerThreadWrite
>
{},
K1
)),
make_merge_transform_v3_division_mod
(
make_tuple
(
number
<
M0
/
mpair
>
{},
number
<
mpair
>
{},
number
<
M1
>
{}))),
make_tuple
(
sequence
<
0
,
1
,
4
,
2
,
7
>
{},
sequence
<
5
,
6
,
3
>
{}),
make_tuple
(
sequence
<
1
>
{},
sequence
<
0
>
{}));
return
a_lds_block_desc_m_k
;
}
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeBLdsBlockDescriptor
()
{
using
WarpGemm
=
WarpGemmMfmaDispatcher
<
typename
Problem
::
ADataType
,
typename
Problem
::
BDataType
,
typename
Problem
::
CDataType
,
Problem
::
BlockGemmShape
::
WarpTile
::
at
(
I0
),
Problem
::
BlockGemmShape
::
WarpTile
::
at
(
I1
),
Problem
::
BlockGemmShape
::
WarpTile
::
at
(
I2
),
TransposeC
>
;
using
BDataType
=
remove_cvref_t
<
typename
Problem
::
BDataType
>
;
constexpr
index_t
NPerBlock
=
Problem
::
BlockGemmShape
::
kN
;
constexpr
index_t
KPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
constexpr
index_t
K1
=
WarpGemm
::
kK
;
constexpr
index_t
K0
=
KPerBlock
/
K1
;
if
constexpr
(
std
::
is_same
<
tensor_layout
::
gemm
::
ColumnMajor
,
LayoutB
>::
value
)
{
// NLdsLayer * K0 as logical Bank
constexpr
auto
NLdsLayer
=
32
*
4
/
KPerBlock
/
sizeof
(
BDataType
)
<
1
?
1
:
32
*
4
/
KPerBlock
/
sizeof
(
BDataType
);
;
constexpr
auto
b_lds_block_desc
=
make_naive_tensor_descriptor
(
make_tuple
(
K0
*
number
<
NLdsLayer
>
{},
number
<
NPerBlock
/
NLdsLayer
>
{},
K1
),
make_tuple
(
K1
,
number
<
KPerBlock
*
NLdsLayer
>
{},
I1
));
constexpr
auto
b_lds_block_desc_permuted
=
transform_tensor_descriptor
(
b_lds_block_desc
,
make_tuple
(
make_xor_transform
(
make_tuple
(
number
<
NPerBlock
/
NLdsLayer
>
{},
number
<
K0
*
NLdsLayer
>
{})),
make_pass_through_transform
(
K1
)),
make_tuple
(
sequence
<
1
,
0
>
{},
sequence
<
2
>
{}),
make_tuple
(
sequence
<
1
,
0
>
{},
sequence
<
2
>
{}));
constexpr
auto
b_lds_block_desc_bk0_kNLdsLayer_n_bk1
=
transform_tensor_descriptor
(
b_lds_block_desc_permuted
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
K0
,
number
<
NLdsLayer
>
{})),
make_pass_through_transform
(
number
<
NPerBlock
/
NLdsLayer
>
{}),
make_pass_through_transform
(
K1
)),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{},
sequence
<
2
>
{}),
make_tuple
(
sequence
<
0
,
2
>
{},
sequence
<
1
>
{},
sequence
<
3
>
{}));
constexpr
auto
b_lds_block_desc_n_k
=
transform_tensor_descriptor
(
b_lds_block_desc_bk0_kNLdsLayer_n_bk1
,
make_tuple
(
make_merge_transform_v3_division_mod
(
make_tuple
(
K0
,
K1
)),
make_merge_transform_v3_division_mod
(
make_tuple
(
number
<
NPerBlock
/
NLdsLayer
>
{},
number
<
NLdsLayer
>
{}))),
make_tuple
(
sequence
<
0
,
3
>
{},
sequence
<
1
,
2
>
{}),
make_tuple
(
sequence
<
1
>
{},
sequence
<
0
>
{}));
return
b_lds_block_desc_n_k
;
}
else
// RowMajor B
{
constexpr
auto
N0
=
get_warp_size
()
*
Problem
::
BlockGemmShape
::
BlockWarps
::
at
(
I1
);
constexpr
auto
N1
=
NPerBlock
/
N0
;
constexpr
auto
KThreadWrite
=
Problem
::
kBlockSize
/
N0
;
constexpr
auto
K0PerThreadWrite
=
K0
/
KThreadWrite
;
constexpr
auto
KThreadRead
=
64
/
WarpGemm
::
kN
;
constexpr
auto
K0PerThreadRead
=
K0
/
KThreadRead
;
constexpr
auto
kfold
=
(
K1
*
N0
*
sizeof
(
BDataType
)
>
128
)
?
1
:
128
/
(
K1
*
N0
*
sizeof
(
BDataType
));
constexpr
auto
KThreadReadPerm
=
(
kfold
*
K0PerThreadWrite
/
K0PerThreadRead
)
>
1
?
KThreadRead
/
(
kfold
*
K0PerThreadWrite
/
K0PerThreadRead
)
:
KThreadRead
;
// 1<=npair<=kN0
constexpr
auto
npair
=
(
K1
*
WarpGemm
::
kN
*
sizeof
(
BDataType
)
>
128
)
?
1
:
((
128
/
(
K1
*
WarpGemm
::
kN
*
sizeof
(
BDataType
)))
>
N0
?
N0
:
128
/
(
K1
*
WarpGemm
::
kN
*
sizeof
(
BDataType
)));
constexpr
auto
b_lds_block_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
number
<
KThreadWrite
/
kfold
/
KThreadReadPerm
>
{},
number
<
K0PerThreadWrite
>
{},
number
<
KThreadReadPerm
*
N1
>
{},
number
<
kfold
*
N0
/
npair
>
{},
number
<
npair
>
{},
K1
));
constexpr
auto
b_lds_block_desc_permuted
=
transform_tensor_descriptor
(
b_lds_block_desc
,
make_tuple
(
make_pass_through_transform
(
number
<
KThreadWrite
/
kfold
/
KThreadReadPerm
>
{}),
make_pass_through_transform
(
number
<
K0PerThreadWrite
>
{}),
make_xor_transform
(
make_tuple
(
number
<
KThreadReadPerm
*
N1
>
{},
number
<
kfold
*
N0
/
npair
>
{})),
make_pass_through_transform
(
number
<
npair
>
{}),
make_pass_through_transform
(
K1
)),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{},
sequence
<
2
,
3
>
{},
sequence
<
4
>
{},
sequence
<
5
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{},
sequence
<
2
,
3
>
{},
sequence
<
4
>
{},
sequence
<
5
>
{}));
constexpr
auto
b_lds_block_desc_unmerged
=
transform_tensor_descriptor
(
b_lds_block_desc_permuted
,
make_tuple
(
make_pass_through_transform
(
number
<
KThreadWrite
/
kfold
/
KThreadReadPerm
>
{}),
make_pass_through_transform
(
number
<
K0PerThreadWrite
>
{}),
make_unmerge_transform
(
make_tuple
(
number
<
KThreadReadPerm
>
{},
number
<
N1
>
{})),
make_unmerge_transform
(
make_tuple
(
number
<
kfold
>
{},
number
<
N0
/
npair
>
{})),
make_pass_through_transform
(
number
<
npair
>
{}),
make_pass_through_transform
(
K1
)),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{},
sequence
<
2
>
{},
sequence
<
3
>
{},
sequence
<
4
>
{},
sequence
<
5
>
{}),
make_tuple
(
sequence
<
1
>
{},
sequence
<
2
>
{},
sequence
<
0
,
3
>
{},
sequence
<
4
,
5
>
{},
sequence
<
6
>
{},
sequence
<
7
>
{}));
constexpr
auto
b_lds_block_desc_n_k
=
transform_tensor_descriptor
(
b_lds_block_desc_unmerged
,
make_tuple
(
make_merge_transform_v3_division_mod
(
make_tuple
(
number
<
KThreadReadPerm
>
{},
number
<
KThreadWrite
/
kfold
/
KThreadReadPerm
>
{},
number
<
kfold
>
{},
number
<
K0PerThreadWrite
>
{},
K1
)),
make_merge_transform_v3_division_mod
(
make_tuple
(
number
<
N0
/
npair
>
{},
number
<
npair
>
{},
number
<
N1
>
{}))),
make_tuple
(
sequence
<
0
,
1
,
4
,
2
,
7
>
{},
sequence
<
5
,
6
,
3
>
{}),
make_tuple
(
sequence
<
1
>
{},
sequence
<
0
>
{}));
return
b_lds_block_desc_n_k
;
}
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSizeA
()
{
constexpr
index_t
smem_size_a
=
sizeof
(
typename
Problem
::
ADataType
)
*
MakeALdsBlockDescriptor
<
Problem
>
().
get_element_space_size
();
return
smem_size_a
;
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSizeB
()
{
constexpr
index_t
smem_size_b
=
sizeof
(
typename
Problem
::
BDataType
)
*
MakeBLdsBlockDescriptor
<
Problem
>
().
get_element_space_size
();
return
smem_size_b
;
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSize
()
{
constexpr
index_t
smem_size_a
=
GetSmemSizeA
<
Problem
>
();
constexpr
index_t
smem_size_b
=
GetSmemSizeB
<
Problem
>
();
index_t
smem_size
=
0
;
smem_size
+=
smem_size_a
+
smem_size_b
;
return
smem_size
;
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeADramTileDistribution
()
{
using
WarpGemm
=
WarpGemmMfmaDispatcher
<
typename
Problem
::
ADataType
,
typename
Problem
::
BDataType
,
typename
Problem
::
CDataType
,
Problem
::
BlockGemmShape
::
WarpTile
::
at
(
I0
),
Problem
::
BlockGemmShape
::
WarpTile
::
at
(
I1
),
Problem
::
BlockGemmShape
::
WarpTile
::
at
(
I2
),
TransposeC
>
;
constexpr
index_t
BlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
MPerBlock
=
Problem
::
BlockGemmShape
::
kM
;
constexpr
index_t
KPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
constexpr
index_t
K1
=
WarpGemm
::
kK
;
constexpr
index_t
K0
=
KPerBlock
/
K1
;
constexpr
index_t
M2
=
get_warp_size
()
/
K0
;
constexpr
index_t
M1
=
BlockSize
/
get_warp_size
();
static_assert
(
M2
!=
0
,
"M2 is zero, which will lead to a division by zero error."
);
static_assert
(
M1
!=
0
,
"M1 is zero, which will lead to a division by zero error."
);
constexpr
index_t
M0
=
MPerBlock
/
(
M2
*
M1
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
M0
,
M1
,
M2
>
,
sequence
<
K0
,
K1
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
2
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
1
>>
{});
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeBDramTileDistribution
()
{
using
WarpGemm
=
WarpGemmMfmaDispatcher
<
typename
Problem
::
ADataType
,
typename
Problem
::
BDataType
,
typename
Problem
::
CDataType
,
Problem
::
BlockGemmShape
::
WarpTile
::
at
(
I0
),
Problem
::
BlockGemmShape
::
WarpTile
::
at
(
I1
),
Problem
::
BlockGemmShape
::
WarpTile
::
at
(
I2
),
TransposeC
>
;
constexpr
index_t
BlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
NPerBlock
=
Problem
::
BlockGemmShape
::
kN
;
constexpr
index_t
KPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
constexpr
index_t
K1
=
WarpGemm
::
kK
;
constexpr
index_t
K0
=
KPerBlock
/
K1
;
constexpr
index_t
N2
=
get_warp_size
()
/
K0
;
constexpr
index_t
N1
=
BlockSize
/
get_warp_size
();
static_assert
(
N2
!=
0
,
"M2 is zero, which will lead to a division by zero error."
);
static_assert
(
N1
!=
0
,
"M1 is zero, which will lead to a division by zero error."
);
constexpr
index_t
N0
=
NPerBlock
/
(
N2
*
N1
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
N0
,
N1
,
N2
>
,
sequence
<
K0
,
K1
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
2
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
1
>>
{});
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetBlockGemm
()
{
using
AccDataType
=
float
;
using
BlockWarps
=
typename
Problem
::
BlockGemmShape
::
BlockWarps
;
using
WarpTile
=
typename
Problem
::
BlockGemmShape
::
WarpTile
;
using
WarpGemm
=
WarpGemmMfmaDispatcher
<
typename
Problem
::
ADataType
,
typename
Problem
::
BDataType
,
AccDataType
,
WarpTile
::
at
(
I0
),
WarpTile
::
at
(
I1
),
WarpTile
::
at
(
I2
),
TransposeC
>
;
using
BlockGemmPolicy
=
BlockGemmASmemBSmemCRegV1CustomPolicy
<
typename
Problem
::
ADataType
,
typename
Problem
::
BDataType
,
typename
Problem
::
CDataType
,
BlockWarps
,
WarpGemm
>
;
return
BlockGemmASmemBSmemCRegV1
<
Problem
,
BlockGemmPolicy
>
{};
}
};
}
// namespace ck_tile
include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp
0 → 100644
View file @
63b152d6
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace
ck_tile
{
template
<
bool
kPadA_
,
bool
kPadB_
,
bool
kPadC_
,
typename
LayoutA_
,
typename
LayoutB_
,
typename
LayoutC_
>
struct
TileGemmTraits
{
static
constexpr
bool
kPadA
=
kPadA_
;
static
constexpr
bool
kPadB
=
kPadB_
;
static
constexpr
bool
kPadC
=
kPadC_
;
using
LayoutA
=
LayoutA_
;
using
LayoutB
=
LayoutB_
;
using
LayoutC
=
LayoutC_
;
};
}
// namespace ck_tile
include/ck_tile/ops/image_to_column.hpp
0 → 100644
View file @
63b152d6
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/image_to_column/kernel/image_to_column_kernel.hpp"
#include "ck_tile/ops/image_to_column/pipeline/block_image_to_column_problem.hpp"
#include "ck_tile/ops/image_to_column/pipeline/tile_image_to_column_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
include/ck_tile/ops/image_to_column/kernel/image_to_column_kernel.hpp
0 → 100644
View file @
63b152d6
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
namespace
ck_tile
{
template
<
typename
Problem_
>
struct
ImageToColumn
{
static
constexpr
auto
I0
=
number
<
0
>
{};
static
constexpr
auto
I1
=
number
<
1
>
{};
static
constexpr
auto
I2
=
number
<
2
>
{};
static
constexpr
auto
I3
=
number
<
3
>
{};
static
constexpr
auto
I4
=
number
<
4
>
{};
using
Problem
=
remove_cvref_t
<
Problem_
>
;
using
InDataType
=
remove_cvref_t
<
typename
Problem
::
InDataType
>
;
using
OutDataType
=
remove_cvref_t
<
typename
Problem
::
OutDataType
>
;
static
constexpr
index_t
NDimSpatial
=
Problem
::
NDimSpatial
;
static
constexpr
index_t
AligmentIn
=
Problem
::
AligmentIn
;
static
constexpr
index_t
AligmentOut
=
Problem
::
AligmentOut
;
static_assert
(
NDimSpatial
==
2
,
"Not supported."
);
static
constexpr
index_t
kMPerBlock
=
Problem
::
BlockShape
::
kMPerBlock
;
static
constexpr
index_t
kKPerBlock
=
Problem
::
BlockShape
::
kKPerBlock
;
struct
Kargs
{
const
void
*
p_in
;
void
*
p_out
;
const
long_index_t
G
;
const
long_index_t
N
;
const
long_index_t
C
;
const
array
<
long_index_t
,
NDimSpatial
>
input_spatial_lengths
;
const
array
<
long_index_t
,
NDimSpatial
>
filter_spatial_lengths
;
const
array
<
long_index_t
,
NDimSpatial
>
output_spatial_lengths
;
const
array
<
long_index_t
,
NDimSpatial
+
3
>
image_g_n_c_wis_strides
;
const
array
<
long_index_t
,
3
>
gemm_g_m_k_strides
;
const
array
<
long_index_t
,
NDimSpatial
>
conv_filter_strides
;
const
array
<
long_index_t
,
NDimSpatial
>
conv_filter_dilations
;
const
array
<
long_index_t
,
NDimSpatial
>
input_left_pads
;
const
array
<
long_index_t
,
NDimSpatial
>
input_right_pads
;
};
CK_TILE_HOST
static
constexpr
Kargs
MakeKargs
(
const
void
*
p_in
,
void
*
p_out
,
const
long_index_t
G
,
const
long_index_t
N
,
const
long_index_t
C
,
const
array
<
long_index_t
,
NDimSpatial
>
input_spatial_lengths
,
const
array
<
long_index_t
,
NDimSpatial
>
filter_spatial_lengths
,
const
array
<
long_index_t
,
NDimSpatial
>
output_spatial_lengths
,
const
array
<
long_index_t
,
NDimSpatial
+
3
>
image_g_n_c_wis_strides
,
const
array
<
long_index_t
,
3
>
gemm_g_m_k_strides
,
const
array
<
long_index_t
,
NDimSpatial
>
conv_filter_strides
,
const
array
<
long_index_t
,
NDimSpatial
>
conv_filter_dilations
,
const
array
<
long_index_t
,
NDimSpatial
>
input_left_pads
,
const
array
<
long_index_t
,
NDimSpatial
>
input_right_pads
)
{
return
Kargs
{
p_in
,
p_out
,
G
,
N
,
C
,
input_spatial_lengths
,
filter_spatial_lengths
,
output_spatial_lengths
,
image_g_n_c_wis_strides
,
gemm_g_m_k_strides
,
conv_filter_strides
,
conv_filter_dilations
,
input_left_pads
,
input_right_pads
};
}
CK_TILE_HOST
static
constexpr
auto
GridSize
(
index_t
GemmM
,
index_t
GemmK
,
index_t
Batch
)
{
return
dim3
(
integer_divide_ceil
(
GemmM
,
kMPerBlock
),
integer_divide_ceil
(
GemmK
,
kKPerBlock
),
Batch
);
}
CK_TILE_HOST
static
constexpr
auto
BlockSize
()
{
return
Problem
::
BlockShape
::
kBlockSize
;
}
CK_TILE_DEVICE
auto
MakeImageMKDesc
(
const
Kargs
&
kargs
)
const
{
static_assert
(
NDimSpatial
==
2
,
"Not supported."
);
const
auto
in_n_hi_wi_c_desc
=
make_naive_tensor_descriptor
(
make_tuple
(
kargs
.
N
,
kargs
.
input_spatial_lengths
[
I0
],
kargs
.
input_spatial_lengths
[
I1
],
kargs
.
C
),
make_tuple
(
kargs
.
image_g_n_c_wis_strides
[
I1
],
kargs
.
image_g_n_c_wis_strides
[
I3
],
kargs
.
image_g_n_c_wis_strides
[
I4
],
kargs
.
image_g_n_c_wis_strides
[
I2
]),
number
<
AligmentIn
>
{},
I1
);
const
auto
in_n_hip_wip_c_desc
=
transform_tensor_descriptor
(
in_n_hi_wi_c_desc
,
make_tuple
(
make_pass_through_transform
(
kargs
.
N
),
make_pad_transform
(
kargs
.
input_spatial_lengths
[
I0
],
kargs
.
input_left_pads
[
I0
],
kargs
.
input_right_pads
[
I0
]),
make_pad_transform
(
kargs
.
input_spatial_lengths
[
I1
],
kargs
.
input_left_pads
[
I1
],
kargs
.
input_right_pads
[
I1
]),
make_pass_through_transform
(
kargs
.
C
)),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{},
sequence
<
2
>
{},
sequence
<
3
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{},
sequence
<
2
>
{},
sequence
<
3
>
{}));
const
auto
in_n_y_ho_x_wo_c_desc
=
transform_tensor_descriptor
(
in_n_hip_wip_c_desc
,
make_tuple
(
make_pass_through_transform
(
kargs
.
N
),
make_embed_transform
(
make_tuple
(
kargs
.
filter_spatial_lengths
[
I0
],
kargs
.
output_spatial_lengths
[
I0
]),
make_tuple
(
kargs
.
conv_filter_dilations
[
I0
],
kargs
.
conv_filter_strides
[
I0
])),
make_embed_transform
(
make_tuple
(
kargs
.
filter_spatial_lengths
[
I1
],
kargs
.
output_spatial_lengths
[
I1
]),
make_tuple
(
kargs
.
conv_filter_dilations
[
I1
],
kargs
.
conv_filter_strides
[
I1
])),
make_pass_through_transform
(
kargs
.
C
)),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{},
sequence
<
2
>
{},
sequence
<
3
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
,
2
>
{},
sequence
<
3
,
4
>
{},
sequence
<
5
>
{}));
return
transform_tensor_descriptor
(
in_n_y_ho_x_wo_c_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
kargs
.
N
,
kargs
.
output_spatial_lengths
[
I0
],
kargs
.
output_spatial_lengths
[
I1
])),
make_merge_transform
(
make_tuple
(
kargs
.
filter_spatial_lengths
[
I0
],
kargs
.
filter_spatial_lengths
[
I1
],
kargs
.
C
))),
make_tuple
(
sequence
<
0
,
2
,
4
>
{},
sequence
<
1
,
3
,
5
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}));
}
CK_TILE_DEVICE
auto
CalculateMKDims
(
const
Kargs
&
kargs
)
const
{
static_assert
(
NDimSpatial
==
2
,
"Not supported."
);
const
index_t
M
=
kargs
.
N
*
static_cast
<
index_t
>
(
kargs
.
output_spatial_lengths
[
I0
]
*
kargs
.
output_spatial_lengths
[
I1
]);
const
index_t
K
=
kargs
.
C
*
static_cast
<
index_t
>
(
kargs
.
filter_spatial_lengths
[
I0
]
*
kargs
.
filter_spatial_lengths
[
I1
]);
return
make_tuple
(
M
,
K
);
}
CK_TILE_DEVICE
static
constexpr
auto
MakeBlockTileDistribution
()
{
using
P
=
typename
Problem
::
BlockShape
;
// P: {kMWarpPerBlock * kKWarpPerBlock, kMThreadPerWarp * kKThreadPerWarp}
// Y: {kMPerThread, kKPerThread}
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
P
::
kMWarpPerBlock
,
P
::
kMThreadPerWarp
,
P
::
kMPerThread
>
,
sequence
<
P
::
kKWarpPerBlock
,
P
::
kKThreadPerWarp
,
P
::
kKPerThread
>>
,
tuple
<
sequence
<
1
,
2
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
0
,
0
>
,
sequence
<
1
,
1
>>
,
sequence
<
1
,
2
>
,
sequence
<
2
,
2
>>
{});
}
CK_TILE_DEVICE
void
ConvTensorRearrange
(
const
Kargs
&
kargs
)
const
{
const
auto
[
M
,
K
]
=
CalculateMKDims
(
kargs
);
const
index_t
iM
=
__builtin_amdgcn_readfirstlane
(
blockIdx
.
x
*
kMPerBlock
);
const
index_t
iK
=
__builtin_amdgcn_readfirstlane
(
blockIdx
.
y
*
kKPerBlock
);
const
index_t
iBatch
=
__builtin_amdgcn_readfirstlane
(
blockIdx
.
z
);
const
auto
in_offset
=
iBatch
*
kargs
.
image_g_n_c_wis_strides
[
I0
];
const
auto
out_offset
=
iBatch
*
kargs
.
gemm_g_m_k_strides
[
I0
];
const
auto
image_m_k
=
make_tensor_view
<
address_space_enum
::
global
>
(
static_cast
<
const
InDataType
*>
(
kargs
.
p_in
)
+
in_offset
,
MakeImageMKDesc
(
kargs
));
const
auto
gemm_m_k
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
static_cast
<
OutDataType
*>
(
kargs
.
p_out
)
+
out_offset
,
make_tuple
(
M
,
K
),
make_tuple
(
kargs
.
gemm_g_m_k_strides
[
I1
],
kargs
.
gemm_g_m_k_strides
[
I2
]),
number
<
AligmentOut
>
{},
I1
);
const
auto
image_m_k_padded
=
pad_tensor_view
(
image_m_k
,
make_tuple
(
number
<
kMPerBlock
>
{},
number
<
kKPerBlock
>
{}),
sequence
<
false
,
true
>
{});
const
auto
gemm_m_k_padded
=
pad_tensor_view
(
gemm_m_k
,
make_tuple
(
number
<
kMPerBlock
>
{},
number
<
kKPerBlock
>
{}),
sequence
<
false
,
true
>
{});
constexpr
auto
dstr
=
MakeBlockTileDistribution
();
const
auto
image_tile
=
make_tile_window
(
image_m_k_padded
,
make_tuple
(
number
<
kMPerBlock
>
{},
number
<
kKPerBlock
>
{}),
{
iM
,
iK
},
dstr
);
auto
gemm_tile
=
make_tile_window
(
gemm_m_k_padded
,
make_tuple
(
number
<
kMPerBlock
>
{},
number
<
kKPerBlock
>
{}),
{
iM
,
iK
},
dstr
);
// load from Global
const
auto
loaded_tile
=
load_tile
(
image_tile
);
// save to Global
store_tile
(
gemm_tile
,
loaded_tile
);
}
CK_TILE_DEVICE
void
operator
()(
Kargs
&
kargs
)
const
{
ConvTensorRearrange
(
kargs
);
}
};
}
// namespace ck_tile
include/ck_tile/ops/image_to_column/pipeline/block_image_to_column_problem.hpp
0 → 100644
View file @
63b152d6
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/utility/type_traits.hpp"
namespace
ck_tile
{
template
<
typename
InDataType_
,
typename
OutDataType_
,
typename
BlockShape_
,
index_t
NDimSpatial_
,
index_t
AligmentIn_
,
index_t
AligmentOut_
>
struct
BlockImageToColumnProblem
{
using
InDataType
=
remove_cvref_t
<
InDataType_
>
;
using
OutDataType
=
remove_cvref_t
<
OutDataType_
>
;
using
BlockShape
=
remove_cvref_t
<
BlockShape_
>
;
static
constexpr
index_t
NDimSpatial
=
NDimSpatial_
;
static
constexpr
index_t
AligmentIn
=
AligmentIn_
;
static
constexpr
index_t
AligmentOut
=
AligmentOut_
;
};
}
// namespace ck_tile
include/ck_tile/ops/image_to_column/pipeline/tile_image_to_column_shape.hpp
0 → 100644
View file @
63b152d6
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace
ck_tile
{
template
<
typename
ThreadTile
,
// Sequence<...
typename
WarpTile
,
// Sequence<...
typename
BlockTile
>
// Sequence<...
struct
TileImageToColumnShape
{
static
constexpr
index_t
kMPerThread
=
ThreadTile
::
at
(
number
<
0
>
{});
static
constexpr
index_t
kKPerThread
=
ThreadTile
::
at
(
number
<
1
>
{});
static
constexpr
index_t
kMPerWarp
=
WarpTile
::
at
(
number
<
0
>
{});
static
constexpr
index_t
kKPerWarp
=
WarpTile
::
at
(
number
<
1
>
{});
static
constexpr
index_t
kMThreadPerWarp
=
kMPerWarp
/
kMPerThread
;
static
constexpr
index_t
kKThreadPerWarp
=
kKPerWarp
/
kKPerThread
;
static
constexpr
index_t
kMPerBlock
=
BlockTile
::
at
(
number
<
0
>
{});
static
constexpr
index_t
kKPerBlock
=
BlockTile
::
at
(
number
<
1
>
{});
static
constexpr
index_t
kMWarpPerBlock
=
kMPerBlock
/
kMPerWarp
;
static
constexpr
index_t
kKWarpPerBlock
=
kKPerBlock
/
kKPerWarp
;
static
constexpr
index_t
kBlockSize
=
warpSize
*
kMWarpPerBlock
*
kKWarpPerBlock
;
};
}
// namespace ck_tile
include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp
View file @
63b152d6
...
...
@@ -31,8 +31,14 @@ struct Layernorm2dFwd
static
constexpr
ck_tile
::
index_t
kMPerBlock
=
Problem
::
BlockShape
::
kMPerBlock
;
static
constexpr
ck_tile
::
index_t
kNPerBlock
=
Problem
::
BlockShape
::
kNPerBlock
;
static
constexpr
bool
kPadM
=
Problem
::
kPadM
;
static
constexpr
bool
kPadN
=
Problem
::
kPadN
;
static
constexpr
ck_tile
::
index_t
kNThreadPerWarp
=
Problem
::
BlockShape
::
kNThreadPerWarp
;
static
constexpr
ck_tile
::
index_t
kNPerThread
=
Problem
::
BlockShape
::
kNPerThread
;
static
constexpr
auto
I0
=
number
<
0
>
{};
static
constexpr
auto
I1
=
number
<
1
>
{};
struct
Kargs
{
...
...
@@ -96,19 +102,25 @@ struct Layernorm2dFwd
sequence
<
2
>>
{});
}
template
<
typename
Dstr
>
CK_TILE_DEVICE
static
constexpr
auto
GetNPerThread
(
Dstr
)
CK_TILE_DEVICE
static
int
GetWelfordMaxCount
(
int
N
)
{
constexpr
auto
nDstrSpan
=
Dstr
::
get_distributed_spans
().
template
at
<
1
>();
using
Lengths
=
decltype
(
nDstrSpan
.
impl_
);
constexpr
ck_tile
::
index_t
kNThreadPerBlock
=
kNPerBlock
/
kNPerThread
;
ck_tile
::
index_t
ret
=
1
;
int
thread_id_n
=
get_thread_id
()
%
kNThreadPerBlock
;
int
max_count
=
__builtin_amdgcn_readfirstlane
(
N
<
kNPerBlock
?
0
:
kNPerThread
*
(
N
/
kNPerBlock
));
int
n_per_block_tail_loop
=
__builtin_amdgcn_readfirstlane
(
N
-
max_count
*
kNThreadPerBlock
);
ck_tile
::
static_for
<
0
,
Lengths
::
size
(),
1
>
{}(
[
&
](
auto
idx
)
{
ret
*=
Lengths
::
template
at
(
idx
);
});
if
(
n_per_block_tail_loop
>
0
)
{
int
thread_max_n
=
(
thread_id_n
+
1
)
*
kNPerThread
;
int
delta
=
thread_max_n
-
n_per_block_tail_loop
;
delta
=
clamp
(
thread_max_n
-
n_per_block_tail_loop
,
0
,
kNPerThread
);
max_count
+=
kNPerThread
-
delta
;
}
return
re
t
;
return
max_coun
t
;
}
template
<
typename
DistributedTensor
>
...
...
@@ -129,42 +141,29 @@ struct Layernorm2dFwd
return
out_dstr_tensor
;
}
template
<
bool
Cond
=
(
kHasGamma
&&
kHasBeta
)>
CK_TILE_DEVICE
std
::
enable_if_t
<
Cond
>
TwoPassLayernorm2dFwd
(
const
XDataType
*
p_x
,
const
GammaDataType
*
p_gamma
,
const
BetaDataType
*
p_beta
,
YDataType
*
p_y
,
MeanDataType
*
p_mean
,
InvStdDataType
*
p_invStd
,
const
ComputeDataType
epsilon
,
ck_tile
::
index_t
M
,
template
<
typename
XBlockWindow
,
typename
GammaBlockWindow
,
typename
BetaBlockWindow
,
typename
YBlockWindow
,
typename
MeanBlockWindow
,
typename
InvStdBlockWindow
,
bool
Cond
=
(
kHasGamma
&&
kHasBeta
)>
CK_TILE_DEVICE
std
::
enable_if_t
<
Cond
>
TwoPassLayernorm2dFwd
(
XBlockWindow
&
x_block_window
,
GammaBlockWindow
&
gamma_block_window
,
BetaBlockWindow
&
beta_block_window
,
YBlockWindow
&
y_block_window
,
MeanBlockWindow
&
mean_block_window
,
InvStdBlockWindow
&
inv_std_block_window
,
ComputeDataType
epsilon
,
ck_tile
::
index_t
N
)
const
{
constexpr
auto
I0
=
number
<
0
>
{};
constexpr
auto
I1
=
number
<
1
>
{};
// TODO - Optimize tail loop to reduce move_tile_window()
index_t
num_n_tile_iteration
=
__builtin_amdgcn_readfirstlane
(
integer_divide_ceil
(
N
,
kNPerBlock
));
const
auto
x_m_n
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
p_x
,
make_tuple
(
M
,
N
),
make_tuple
(
N
,
1
),
number
<
32
>
{},
number
<
1
>
{});
const
auto
gamma_n
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
p_gamma
,
make_tuple
(
N
),
make_tuple
(
1
),
number
<
32
>
{},
number
<
1
>
{});
const
auto
beta_n
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
p_beta
,
make_tuple
(
N
),
make_tuple
(
1
),
number
<
32
>
{},
number
<
1
>
{});
const
auto
iM
=
get_block_id
()
*
kMPerBlock
;
constexpr
auto
xDstr
=
MakeXBlockTileDistribution
();
auto
x_block_window
=
make_tile_window
(
x_m_n
,
make_tuple
(
number
<
kMPerBlock
>
{},
number
<
kNPerBlock
>
{}),
{
iM
,
0
},
xDstr
);
index_t
num_n_tile_iteration
=
__builtin_amdgcn_readfirstlane
(
N
/
kNPerBlock
);
// TODO: padding - handle max_count if N % kNPerBlock != 0
constexpr
auto
NPerThread
=
GetNPerThread
(
xDstr
);
ThreadWelford
<
ComputeDataType
,
XDataType
>
thread_welford
{
type_convert
<
int
>
(
NPerThread
*
N
/
kNPerBlock
)};
int
welford_max_count
=
GetWelfordMaxCount
(
N
);
ThreadWelford
<
ComputeDataType
,
XDataType
>
thread_welford
{
welford_max_count
};
using
XTensorType
=
decltype
(
load_tile
(
x_block_window
));
auto
mean_compute_block_tensor
=
...
...
@@ -190,44 +189,14 @@ struct Layernorm2dFwd
auto
inv_std_compute_block_tensor
=
InvSqrt
(
var_compute_block_tensor
,
epsilon
);
if
constexpr
(
kSaveMean
)
{
const
auto
mean_m
=
make_naive_tensor_view_packed
<
address_space_enum
::
global
>
(
p_mean
,
make_tuple
(
M
),
number
<
32
>
{});
auto
mean_block_window
=
make_tile_window
(
mean_m
,
make_tuple
(
number
<
kMPerBlock
>
{}),
{
iM
});
store_tile
(
mean_block_window
,
cast_tile
<
MeanDataType
>
(
mean_compute_block_tensor
));
}
if
constexpr
(
kSaveInvStd
)
{
const
auto
inv_std_m
=
make_naive_tensor_view_packed
<
address_space_enum
::
global
>
(
p_invStd
,
make_tuple
(
M
),
number
<
32
>
{});
auto
inv_std_block_window
=
make_tile_window
(
inv_std_m
,
make_tuple
(
number
<
kMPerBlock
>
{}),
{
iM
});
store_tile
(
inv_std_block_window
,
cast_tile
<
MeanDataType
>
(
inv_std_compute_block_tensor
));
}
// TODO: Extract normalize pipeline
const
auto
y_m_n
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
p_y
,
make_tuple
(
M
,
N
),
make_tuple
(
N
,
1
),
number
<
32
>
{},
number
<
1
>
{});
auto
y_block_window
=
make_tile_window
(
y_m_n
,
make_tuple
(
number
<
kMPerBlock
>
{},
number
<
kNPerBlock
>
{}),
{
iM
,
0
});
constexpr
auto
gammaDstr
=
MakeGammaBetaBlockTileDistribution
();
constexpr
auto
betaDstr
=
gammaDstr
;
auto
gamma_block_window
=
make_tile_window
(
gamma_n
,
make_tuple
(
number
<
kNPerBlock
>
{}),
{
0
},
gammaDstr
);
auto
beta_block_window
=
make_tile_window
(
beta_n
,
make_tuple
(
number
<
kMPerBlock
>
{},
number
<
kNPerBlock
>
{}),
{
0
},
betaDstr
);
store_tile
(
inv_std_block_window
,
cast_tile
<
InvStdDataType
>
(
inv_std_compute_block_tensor
));
// reverse read x to reuse cache
ck_tile
::
index_t
stride_to_right_most_window
=
N
-
kNPerBlock
;
ck_tile
::
index_t
stride_to_right_most_window
=
N
%
kNPerBlock
==
0
?
N
-
kNPerBlock
:
N
-
N
%
kNPerBlock
;
move_tile_window
(
x_block_window
,
{
0
,
-
kNPerBlock
});
move_tile_window
(
gamma_block_window
,
{
stride_to_right_most_window
});
...
...
@@ -274,16 +243,208 @@ struct Layernorm2dFwd
}
}
template
<
typename
XBlockWindow
,
typename
GammaBlockWindow
,
typename
BetaBlockWindow
,
typename
YBlockWindow
,
typename
MeanBlockWindow
,
typename
InvStdBlockWindow
,
bool
Cond
=
(
kHasGamma
&&
kHasBeta
)>
CK_TILE_DEVICE
std
::
enable_if_t
<
Cond
>
OnePassLayernorm2dFwd
(
XBlockWindow
&
x_block_window
,
GammaBlockWindow
&
gamma_block_window
,
BetaBlockWindow
&
beta_block_window
,
YBlockWindow
&
y_block_window
,
MeanBlockWindow
&
mean_block_window
,
InvStdBlockWindow
&
inv_std_block_window
,
ComputeDataType
epsilon
,
ck_tile
::
index_t
N
)
const
{
int
welford_max_count
=
GetWelfordMaxCount
(
N
);
ThreadWelford
<
ComputeDataType
,
XDataType
>
thread_welford
{
welford_max_count
};
using
XTensorType
=
decltype
(
load_tile
(
x_block_window
));
auto
mean_compute_block_tensor
=
thread_welford
.
template
MakeInitialMeanVarDistributedTensor
<
XTensorType
>();
auto
var_compute_block_tensor
=
thread_welford
.
template
MakeInitialMeanVarDistributedTensor
<
XTensorType
>();
clear_tile
(
mean_compute_block_tensor
);
clear_tile
(
var_compute_block_tensor
);
const
auto
x_block_tensor
=
load_tile
(
x_block_window
);
thread_welford
(
x_block_tensor
,
mean_compute_block_tensor
,
var_compute_block_tensor
);
// TODO: support cross warp Welford
WarpMergeWelford
<
ComputeDataType
,
true
>
{}(
mean_compute_block_tensor
,
var_compute_block_tensor
,
thread_welford
.
cur_count_
);
auto
inv_std_compute_block_tensor
=
InvSqrt
(
var_compute_block_tensor
,
epsilon
);
if
constexpr
(
kSaveMean
)
store_tile
(
mean_block_window
,
cast_tile
<
MeanDataType
>
(
mean_compute_block_tensor
));
if
constexpr
(
kSaveInvStd
)
store_tile
(
inv_std_block_window
,
cast_tile
<
InvStdDataType
>
(
inv_std_compute_block_tensor
));
// normalize
const
auto
gamma_block_tensor
=
load_tile
(
gamma_block_window
);
const
auto
beta_block_tensor
=
load_tile
(
beta_block_window
);
constexpr
auto
x_spans
=
decltype
(
x_block_tensor
)
::
get_distributed_spans
();
auto
y_block_tensor
=
make_static_distributed_tensor
<
YDataType
>
(
x_block_tensor
.
get_tile_distribution
());
sweep_tile_span
(
x_spans
[
I1
],
[
&
](
auto
idx1
)
{
constexpr
auto
j_idx
=
make_tuple
(
idx1
);
const
auto
gamma
=
type_convert
<
ComputeDataType
>
(
gamma_block_tensor
[
j_idx
]);
const
auto
beta
=
type_convert
<
ComputeDataType
>
(
beta_block_tensor
[
j_idx
]);
sweep_tile_span
(
x_spans
[
I0
],
[
&
](
auto
idx0
)
{
constexpr
auto
i_idx
=
make_tuple
(
idx0
);
constexpr
auto
i_j_idx
=
make_tuple
(
idx0
,
idx1
);
const
auto
mean
=
mean_compute_block_tensor
[
i_idx
];
const
auto
inv_std
=
inv_std_compute_block_tensor
[
i_idx
];
const
auto
x
=
type_convert
<
ComputeDataType
>
(
x_block_tensor
[
i_j_idx
]);
auto
y
=
(
x
-
mean
)
*
inv_std
*
gamma
+
beta
;
y_block_tensor
(
i_j_idx
)
=
type_convert
<
YDataType
>
(
y
);
});
});
store_tile
(
y_block_window
,
y_block_tensor
);
}
CK_TILE_DEVICE
void
operator
()(
Kargs
kargs
)
const
{
TwoPassLayernorm2dFwd
(
static_cast
<
const
XDataType
*>
(
kargs
.
p_x
),
const
auto
x_m_n
=
[
&
]()
{
const
auto
x_dram_naive
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
static_cast
<
const
XDataType
*>
(
kargs
.
p_x
),
make_tuple
(
kargs
.
M
,
kargs
.
N
),
make_tuple
(
kargs
.
N
,
1
),
number
<
kNPerThread
>
{},
number
<
1
>
{});
return
pad_tensor_view
(
x_dram_naive
,
make_tuple
(
number
<
kMPerBlock
>
{},
number
<
kNPerBlock
>
{}),
sequence
<
kPadM
,
kPadN
>
{});
}();
const
auto
gamma_n
=
[
&
]()
{
const
auto
gamma_dram_naive
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
static_cast
<
const
GammaDataType
*>
(
kargs
.
p_gamma
),
make_tuple
(
kargs
.
N
),
make_tuple
(
1
),
number
<
kNPerThread
>
{},
number
<
1
>
{});
return
pad_tensor_view
(
gamma_dram_naive
,
make_tuple
(
number
<
kNPerBlock
>
{}),
sequence
<
kPadN
>
{});
}();
const
auto
beta_n
=
[
&
]()
{
const
auto
gamma_dram_naive
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
static_cast
<
const
BetaDataType
*>
(
kargs
.
p_beta
),
make_tuple
(
kargs
.
N
),
make_tuple
(
1
),
number
<
kNPerThread
>
{},
number
<
1
>
{});
return
pad_tensor_view
(
gamma_dram_naive
,
make_tuple
(
number
<
kNPerBlock
>
{}),
sequence
<
kPadN
>
{});
}();
const
auto
iM
=
get_block_id
()
*
kMPerBlock
;
constexpr
auto
xDstr
=
MakeXBlockTileDistribution
();
auto
x_block_window
=
make_tile_window
(
x_m_n
,
make_tuple
(
number
<
kMPerBlock
>
{},
number
<
kNPerBlock
>
{}),
{
iM
,
0
},
xDstr
);
const
auto
y_m_n
=
[
&
]()
{
const
auto
y_dram_naive
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
static_cast
<
YDataType
*>
(
kargs
.
p_y
),
make_tuple
(
kargs
.
M
,
kargs
.
N
),
make_tuple
(
kargs
.
N
,
1
),
number
<
kNPerThread
>
{},
number
<
1
>
{});
return
pad_tensor_view
(
y_dram_naive
,
make_tuple
(
number
<
kMPerBlock
>
{},
number
<
kNPerBlock
>
{}),
sequence
<
kPadM
,
kPadN
>
{});
}();
auto
y_block_window
=
make_tile_window
(
y_m_n
,
make_tuple
(
number
<
kMPerBlock
>
{},
number
<
kNPerBlock
>
{}),
{
iM
,
0
});
constexpr
auto
gammaDstr
=
MakeGammaBetaBlockTileDistribution
();
constexpr
auto
betaDstr
=
gammaDstr
;
auto
gamma_block_window
=
make_tile_window
(
gamma_n
,
make_tuple
(
number
<
kNPerBlock
>
{}),
{
0
},
gammaDstr
);
auto
beta_block_window
=
make_tile_window
(
beta_n
,
make_tuple
(
number
<
kMPerBlock
>
{},
number
<
kNPerBlock
>
{}),
{
0
},
betaDstr
);
auto
mean_block_window
=
[
&
]()
{
if
constexpr
(
kSaveMean
)
{
const
auto
mean_m
=
[
&
]()
{
const
auto
mean_dram_naive
=
make_naive_tensor_view_packed
<
address_space_enum
::
global
>
(
static_cast
<
MeanDataType
*>
(
kargs
.
p_mean
),
make_tuple
(
kargs
.
M
),
number
<
1
>
{});
return
pad_tensor_view
(
mean_dram_naive
,
make_tuple
(
number
<
kMPerBlock
>
{}),
sequence
<
kPadM
>
{});
}();
return
make_tile_window
(
mean_m
,
make_tuple
(
number
<
kMPerBlock
>
{}),
{
iM
});
}
else
return
make_null_tile_window
(
make_tuple
(
number
<
kMPerBlock
>
{}));
}();
auto
inv_std_block_window
=
[
&
]()
{
if
constexpr
(
kSaveInvStd
)
{
const
auto
inv_std_m
=
[
&
]()
{
const
auto
inv_std_dram_naive
=
make_naive_tensor_view_packed
<
address_space_enum
::
global
>
(
static_cast
<
InvStdDataType
*>
(
kargs
.
p_invStd
),
make_tuple
(
kargs
.
M
),
number
<
1
>
{});
return
pad_tensor_view
(
inv_std_dram_naive
,
make_tuple
(
number
<
kMPerBlock
>
{}),
sequence
<
kPadM
>
{});
}();
return
make_tile_window
(
inv_std_m
,
make_tuple
(
number
<
kMPerBlock
>
{}),
{
iM
});
}
else
return
make_null_tile_window
(
make_tuple
(
number
<
kMPerBlock
>
{}));
}();
if
(
kargs
.
N
<=
kNPerBlock
)
OnePassLayernorm2dFwd
(
x_block_window
,
gamma_block_window
,
beta_block_window
,
y_block_window
,
mean_block_window
,
inv_std_block_window
,
static_cast
<
const
ComputeDataType
>
(
kargs
.
epsilon
),
kargs
.
N
);
else
TwoPassLayernorm2dFwd
(
x_block_window
,
gamma_block_window
,
beta_block_window
,
y_block_window
,
mean_block_window
,
inv_std_block_window
,
static_cast
<
const
ComputeDataType
>
(
kargs
.
epsilon
),
kargs
.
M
,
kargs
.
N
);
}
};
...
...
include/ck_tile/ops/layernorm2d/pipeline/block_layernorm2d_fwd_problem.hpp
View file @
63b152d6
...
...
@@ -14,7 +14,9 @@ template <typename XDataType_,
typename
YDataType_
,
typename
MeanDataType_
,
typename
InvStdDataType_
,
typename
BlockShape_
>
typename
BlockShape_
,
bool
kPadM_
,
bool
kPadN_
>
struct
BlockLayernorm2dFwdProblem
{
using
XDataType
=
remove_cvref_t
<
XDataType_
>
;
...
...
@@ -25,6 +27,8 @@ struct BlockLayernorm2dFwdProblem
using
MeanDataType
=
remove_cvref_t
<
MeanDataType_
>
;
using
InvStdDataType
=
remove_cvref_t
<
InvStdDataType_
>
;
using
BlockShape
=
remove_cvref_t
<
BlockShape_
>
;
static
constexpr
bool
kPadM
=
kPadM_
;
static
constexpr
bool
kPadN
=
kPadN_
;
};
}
// namespace ck_tile
Prev
1
2
3
4
5
6
7
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment