Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
e941f59f
Commit
e941f59f
authored
Nov 01, 2024
by
Andriy Roshchenko
Browse files
Merge branch gfx950 into andriy/lwpck-2413
parents
fe9d9812
7da48908
Changes
353
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
639 additions
and
454 deletions
+639
-454
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_tile_partitioner.hpp
...ile/ops/fmha/kernel/fmha_fwd_splitkv_tile_partitioner.hpp
+5
-4
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp
...a/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp
+12
-15
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp
...eline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp
+12
-15
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp
.../fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp
+73
-164
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp
...fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp
+18
-9
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp
...lock_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp
+46
-21
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp
...mha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp
+20
-16
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_default_policy.hpp
...ock_fmha_fwd_splitkv_pipeline_qr_ks_vs_default_policy.hpp
+15
-6
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp
...ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp
+12
-5
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp
...k_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp
+17
-16
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp
.../ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp
+19
-18
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp
...le/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp
+11
-11
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp
...k_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp
+14
-12
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp
...a/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp
+94
-134
include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp
include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp
+26
-5
include/ck_tile/ops/gemm.hpp
include/ck_tile/ops/gemm.hpp
+5
-0
include/ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1.hpp
.../ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1.hpp
+1
-1
include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp
...e/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp
+1
-1
include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_one_warp_v1.hpp
...ops/gemm/block/block_gemm_areg_bsmem_creg_one_warp_v1.hpp
+237
-0
include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1.hpp
.../ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1.hpp
+1
-1
No files found.
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_tile_partitioner.hpp
View file @
e941f59f
...
...
@@ -26,8 +26,8 @@ struct FmhaFwdSplitKVTilePartitioner
{
// TODO: this may need tuning
return
dim3
(
ck_tile
::
integer_divide_ceil
(
max_seqlen_q
,
kM0
)
*
ck_tile
::
integer_divide_ceil
(
hdim_v
,
kN1
),
nhead
*
num_splits
,
ck_tile
::
integer_divide_ceil
(
hdim_v
,
kN1
)
*
num_splits
,
nhead
,
batch_size
);
}
...
...
@@ -42,8 +42,9 @@ struct FmhaFwdSplitKVTilePartitioner
return
ck_tile
::
make_tuple
(
quotient
,
modulus
);
};
const
auto
[
i_tile_m
,
i_tile_n
]
=
f
(
blockIdx
.
x
,
num_tile_n1
);
const
auto
[
i_nhead
,
i_split
]
=
f
(
blockIdx
.
y
,
num_splits
);
const
auto
[
mn
,
i_split
]
=
f
(
blockIdx
.
x
,
num_splits
);
const
auto
[
i_tile_m
,
i_tile_n
]
=
f
(
mn
,
num_tile_n1
);
const
index_t
i_nhead
=
blockIdx
.
y
;
const
index_t
i_batch
=
blockIdx
.
z
;
return
ck_tile
::
make_tuple
(
i_tile_m
,
i_tile_n
,
i_split
,
i_nhead
,
i_batch
);
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp
View file @
e941f59f
...
...
@@ -178,13 +178,13 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
k_lds_ptr
,
Policy
::
template
MakeKLdsWriteBlockDescriptor
<
Problem
>());
auto
k_lds_write_window
=
make_tile_window
(
k_lds
,
make_tuple
(
number
<
kN0
>
{},
number
<
k
K0
>
{}),
{
0
,
0
});
make_tile_window
(
k_lds
,
make_tuple
(
number
<
kN0
>
{},
number
<
k
QKHeaddim
>
{}),
{
0
,
0
});
auto
k_lds_read_window
=
make_tile_window
(
k_lds_write_window
.
get_bottom_tensor_view
(),
make_tuple
(
number
<
kN0
>
{},
number
<
kK0
>
{}),
k_lds_write_window
.
get_window_origin
(),
Policy
::
template
MakeKReg
Slice
BlockDescriptor
<
Problem
>());
Policy
::
template
MakeKRegBlockDescriptor
<
Problem
>());
auto
k_reg_tensor
=
make_static_distributed_tensor
<
KDataType
>
(
Policy
::
template
MakeKRegBlockDescriptor
<
Problem
>());
...
...
@@ -204,16 +204,13 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
v_lds_ptr
,
Policy
::
template
MakeVLdsWriteBlockDescriptor
<
Problem
>());
auto
v_lds_write_window
=
make_tile_window
(
v_lds
,
make_tuple
(
number
<
kN0
>
{},
number
<
k
K2
>
{}),
{
0
,
0
});
make_tile_window
(
v_lds
,
make_tuple
(
number
<
kN0
>
{},
number
<
k
VHeaddim
>
{}),
{
0
,
0
});
auto
v_lds_read_window
=
make_tile_window
(
v_lds_write_window
.
get_bottom_tensor_view
(),
make_tuple
(
number
<
kN0
>
{},
number
<
kK2
>
{}),
v_lds_write_window
.
get_window_origin
(),
Policy
::
template
MakeVRegSliceBlockDescriptor
<
Problem
>());
auto
v_reg_tensor
=
make_static_distributed_tensor
<
VDataType
>
(
Policy
::
template
MakeVRegBlockDescriptor
<
Problem
>());
Policy
::
template
MakeVRegBlockDescriptor
<
Problem
>());
//------------------------------------------------------------------
// KT, Reg ->LDS ->Reg
...
...
@@ -227,7 +224,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
kt_lds_ptr
,
Policy
::
template
MakeShuffledKLdsWriteBlockDescriptor
<
Problem
>());
auto
shuffled_k_lds_write_window
=
make_tile_window
(
shuffled_k_lds_write
,
make_tuple
(
number
<
kN0
>
{},
number
<
k
K0
>
{}),
{
0
,
0
});
shuffled_k_lds_write
,
make_tuple
(
number
<
kN0
>
{},
number
<
k
QKHeaddim
>
{}),
{
0
,
0
});
auto
kt_lds_read
=
make_tensor_view
<
address_space_enum
::
lds
>
(
kt_lds_ptr
,
Policy
::
template
MakeKTLdsReadBlockDescriptor
<
Problem
>());
...
...
@@ -257,7 +254,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
block_sync_lds
();
v_reg_tensor
=
load_tile
(
v_lds_read_window
);
auto
v_reg_tensor
=
load_tile
(
v_lds_read_window
);
block_sync_lds
();
//---------------------------- Loop Load in ----------------------------//
// Q: HBM ->Reg ->LDS
...
...
@@ -276,7 +273,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
q_lds_ptr
,
Policy
::
template
MakeQLdsBlockDescriptor
<
Problem
>());
auto
q_lds_window
=
make_tile_window
(
q_lds
,
make_tuple
(
number
<
kM0
>
{},
number
<
k
K0
>
{}),
{
0
,
0
});
make_tile_window
(
q_lds
,
make_tuple
(
number
<
kM0
>
{},
number
<
k
QKHeaddim
>
{}),
{
0
,
0
});
auto
q_lds_read_window
=
make_tile_window
(
q_lds_window
.
get_bottom_tensor_view
(),
...
...
@@ -297,7 +294,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
qt_lds_ptr
,
Policy
::
template
MakeShuffledQLdsWriteBlockDescriptor
<
Problem
>());
auto
shuffled_q_lds_write_window
=
make_tile_window
(
shuffled_q_lds_write
,
make_tuple
(
number
<
kM0
>
{},
number
<
k
K0
>
{}),
{
0
,
0
});
shuffled_q_lds_write
,
make_tuple
(
number
<
kM0
>
{},
number
<
k
QKHeaddim
>
{}),
{
0
,
0
});
auto
qt_lds_read
=
make_tensor_view
<
address_space_enum
::
lds
>
(
qt_lds_ptr
,
Policy
::
template
MakeQTLdsReadBlockDescriptor
<
Problem
>());
...
...
@@ -322,7 +319,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
do_lds_ptr
,
Policy
::
template
MakeOGradLdsBlockDescriptor
<
Problem
>());
auto
do_lds_window
=
make_tile_window
(
do_lds
,
make_tuple
(
number
<
kM0
>
{},
number
<
k
K2
>
{}),
{
0
,
0
});
make_tile_window
(
do_lds
,
make_tuple
(
number
<
kM0
>
{},
number
<
k
VHeaddim
>
{}),
{
0
,
0
});
auto
do_lds_read_window
=
make_tile_window
(
do_lds_window
.
get_bottom_tensor_view
(),
...
...
@@ -341,7 +338,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
dot_lds_ptr
,
Policy
::
template
MakeShuffledOGradLdsWriteBlockDescriptor
<
Problem
>());
auto
shuffled_do_lds_write_window
=
make_tile_window
(
shuffled_do_lds_write
,
make_tuple
(
number
<
kM0
>
{},
number
<
k
K2
>
{}),
{
0
,
0
});
shuffled_do_lds_write
,
make_tuple
(
number
<
kM0
>
{},
number
<
k
VHeaddim
>
{}),
{
0
,
0
});
auto
dot_read_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
dot_lds_ptr
,
Policy
::
template
MakeOGradTLdsReadBlockDescriptor
<
Problem
>());
...
...
@@ -483,9 +480,9 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
index_t
i_total_loops
=
0
;
index_t
seqlen_q_step
=
seqlen_q_start
;
static_assert
(
kQKHeaddim
=
=
kK0
,
"kQKHeaddim should equal
t
o kK0"
);
static_assert
(
kQKHeaddim
>
=
kK0
,
"kQKHeaddim should
be
equal o
r greater than
kK0"
);
static_assert
(
kM0
==
kK1
,
"kM0 should equal to kK1"
);
static_assert
(
kVHeaddim
=
=
kK2
,
"kVHeaddim should equal
t
o kK2"
);
static_assert
(
kVHeaddim
>
=
kK2
,
"kVHeaddim should
be
equal o
r greater than
kK2"
);
static_assert
(
kM0
==
kK3
,
"kM0 should equal to kK3"
);
constexpr
index_t
k4_loops
=
kN0
/
kK4
;
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp
View file @
e941f59f
...
...
@@ -178,13 +178,13 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
k_lds_ptr
,
Policy
::
template
MakeKLdsWriteBlockDescriptor
<
Problem
>());
auto
k_lds_write_window
=
make_tile_window
(
k_lds
,
make_tuple
(
number
<
kN0
>
{},
number
<
k
K0
>
{}),
{
0
,
0
});
make_tile_window
(
k_lds
,
make_tuple
(
number
<
kN0
>
{},
number
<
k
QKHeaddim
>
{}),
{
0
,
0
});
auto
k_lds_read_window
=
make_tile_window
(
k_lds_write_window
.
get_bottom_tensor_view
(),
make_tuple
(
number
<
kN0
>
{},
number
<
kK0
>
{}),
k_lds_write_window
.
get_window_origin
(),
Policy
::
template
MakeKReg
Slice
BlockDescriptor
<
Problem
>());
Policy
::
template
MakeKRegBlockDescriptor
<
Problem
>());
auto
k_reg_tensor
=
make_static_distributed_tensor
<
KDataType
>
(
Policy
::
template
MakeKRegBlockDescriptor
<
Problem
>());
...
...
@@ -204,16 +204,13 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
v_lds_ptr
,
Policy
::
template
MakeVLdsWriteBlockDescriptor
<
Problem
>());
auto
v_lds_write_window
=
make_tile_window
(
v_lds
,
make_tuple
(
number
<
kN0
>
{},
number
<
k
K2
>
{}),
{
0
,
0
});
make_tile_window
(
v_lds
,
make_tuple
(
number
<
kN0
>
{},
number
<
k
VHeaddim
>
{}),
{
0
,
0
});
auto
v_lds_read_window
=
make_tile_window
(
v_lds_write_window
.
get_bottom_tensor_view
(),
make_tuple
(
number
<
kN0
>
{},
number
<
kK2
>
{}),
v_lds_write_window
.
get_window_origin
(),
Policy
::
template
MakeVRegSliceBlockDescriptor
<
Problem
>());
auto
v_reg_tensor
=
make_static_distributed_tensor
<
VDataType
>
(
Policy
::
template
MakeVRegBlockDescriptor
<
Problem
>());
Policy
::
template
MakeVRegBlockDescriptor
<
Problem
>());
//------------------------------------------------------------------
// KT, Reg ->LDS ->Reg
...
...
@@ -227,7 +224,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
kt_lds_ptr
,
Policy
::
template
MakeShuffledKLdsWriteBlockDescriptor
<
Problem
>());
auto
shuffled_k_lds_write_window
=
make_tile_window
(
shuffled_k_lds_write
,
make_tuple
(
number
<
kN0
>
{},
number
<
k
K0
>
{}),
{
0
,
0
});
shuffled_k_lds_write
,
make_tuple
(
number
<
kN0
>
{},
number
<
k
QKHeaddim
>
{}),
{
0
,
0
});
auto
kt_lds_read
=
make_tensor_view
<
address_space_enum
::
lds
>
(
kt_lds_ptr
,
Policy
::
template
MakeKTLdsReadBlockDescriptor
<
Problem
>());
...
...
@@ -257,7 +254,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
block_sync_lds
();
v_reg_tensor
=
load_tile
(
v_lds_read_window
);
auto
v_reg_tensor
=
load_tile
(
v_lds_read_window
);
//---------------------------- Loop Load in ----------------------------//
// Q: HBM ->Reg ->LDS
auto
q_dram_window
=
...
...
@@ -275,7 +272,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
q_lds_ptr
,
Policy
::
template
MakeQLdsBlockDescriptor
<
Problem
>());
auto
q_lds_window
=
make_tile_window
(
q_lds
,
make_tuple
(
number
<
kM0
>
{},
number
<
k
K0
>
{}),
{
0
,
0
});
make_tile_window
(
q_lds
,
make_tuple
(
number
<
kM0
>
{},
number
<
k
QKHeaddim
>
{}),
{
0
,
0
});
auto
q_lds_read_window
=
make_tile_window
(
q_lds_window
.
get_bottom_tensor_view
(),
...
...
@@ -296,7 +293,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
qt_lds_ptr
,
Policy
::
template
MakeShuffledQLdsWriteBlockDescriptor
<
Problem
>());
auto
shuffled_q_lds_write_window
=
make_tile_window
(
shuffled_q_lds_write
,
make_tuple
(
number
<
kM0
>
{},
number
<
k
K0
>
{}),
{
0
,
0
});
shuffled_q_lds_write
,
make_tuple
(
number
<
kM0
>
{},
number
<
k
QKHeaddim
>
{}),
{
0
,
0
});
auto
qt_lds_read
=
make_tensor_view
<
address_space_enum
::
lds
>
(
qt_lds_ptr
,
Policy
::
template
MakeQTLdsReadBlockDescriptor
<
Problem
>());
...
...
@@ -321,7 +318,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
do_lds_ptr
,
Policy
::
template
MakeOGradLdsBlockDescriptor
<
Problem
>());
auto
do_lds_window
=
make_tile_window
(
do_lds
,
make_tuple
(
number
<
kM0
>
{},
number
<
k
K2
>
{}),
{
0
,
0
});
make_tile_window
(
do_lds
,
make_tuple
(
number
<
kM0
>
{},
number
<
k
VHeaddim
>
{}),
{
0
,
0
});
auto
do_lds_read_window
=
make_tile_window
(
do_lds_window
.
get_bottom_tensor_view
(),
...
...
@@ -340,7 +337,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
dot_lds_ptr
,
Policy
::
template
MakeShuffledOGradLdsWriteBlockDescriptor
<
Problem
>());
auto
shuffled_do_lds_write_window
=
make_tile_window
(
shuffled_do_lds_write
,
make_tuple
(
number
<
kM0
>
{},
number
<
k
K2
>
{}),
{
0
,
0
});
shuffled_do_lds_write
,
make_tuple
(
number
<
kM0
>
{},
number
<
k
VHeaddim
>
{}),
{
0
,
0
});
auto
dot_read_lds
=
make_tensor_view
<
address_space_enum
::
lds
>
(
dot_lds_ptr
,
Policy
::
template
MakeOGradTLdsReadBlockDescriptor
<
Problem
>());
...
...
@@ -482,9 +479,9 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
index_t
i_total_loops
=
0
;
index_t
seqlen_q_step
=
seqlen_q_start
;
static_assert
(
kQKHeaddim
=
=
kK0
,
"kQKHeaddim should equal
t
o kK0"
);
static_assert
(
kQKHeaddim
>
=
kK0
,
"kQKHeaddim should
be
equal o
r greater than
kK0"
);
static_assert
(
kM0
==
kK1
,
"kM0 should equal to kK1"
);
static_assert
(
kVHeaddim
=
=
kK2
,
"kVHeaddim should equal
t
o kK2"
);
static_assert
(
kVHeaddim
>
=
kK2
,
"kVHeaddim should
be
equal o
r greater than
kK2"
);
static_assert
(
kM0
==
kK3
,
"kM0 should equal to kK3"
);
constexpr
index_t
k4_loops
=
kN0
/
kK4
;
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp
View file @
e941f59f
...
...
@@ -5,9 +5,8 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/gemm/
pipeline/gemm_pipeline
_problem.hpp"
#include "ck_tile/ops/gemm/
block/block_gemm
_problem.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp"
...
...
@@ -27,20 +26,15 @@ struct BlockFmhaBwdPipelineDefaultPolicy
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetQKBlockGemm
()
{
using
GemmProblem
=
GemmPipelineProblem
<
typename
Problem
::
QDataType
,
typename
Problem
::
KDataType
,
typename
Problem
::
AccDataType
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kM0
,
Problem
::
BlockFmhaShape
::
kN0
,
Problem
::
BlockFmhaShape
::
kK0
>
,
typename
Problem
::
BlockFmhaShape
::
Gemm0BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm0WarpTile
>
,
TileGemmTraits
<
Problem
::
kPadSeqLenQ
,
Problem
::
kPadSeqLenK
,
Problem
::
kPadHeadDimQ
,
typename
tensor_layout
::
gemm
::
RowMajor
,
typename
tensor_layout
::
gemm
::
ColumnMajor
,
typename
tensor_layout
::
gemm
::
RowMajor
>>
;
BlockGemmProblem
<
typename
Problem
::
QDataType
,
typename
Problem
::
KDataType
,
typename
Problem
::
AccDataType
,
Problem
::
kBlockSize
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kM0
,
Problem
::
BlockFmhaShape
::
kN0
,
Problem
::
BlockFmhaShape
::
kK0
>
,
typename
Problem
::
BlockFmhaShape
::
Gemm0BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm0WarpTile
>>
;
using
WarpGemm
=
WarpGemmMfmaDispatcher
<
typename
Problem
::
QDataType
,
...
...
@@ -66,20 +60,15 @@ struct BlockFmhaBwdPipelineDefaultPolicy
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetPTOGradTBlockGemm
()
{
using
GemmProblem
=
GemmPipelineProblem
<
typename
Problem
::
GemmDataType
,
typename
Problem
::
OGradDataType
,
typename
Problem
::
AccDataType
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kN0
,
Problem
::
BlockFmhaShape
::
kVHeaddim
,
Problem
::
BlockFmhaShape
::
kK1
>
,
typename
Problem
::
BlockFmhaShape
::
Gemm1BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm1WarpTile
>
,
TileGemmTraits
<
Problem
::
kPadSeqLenQ
,
Problem
::
kPadHeadDimV
,
Problem
::
kPadHeadDimV
,
typename
tensor_layout
::
gemm
::
RowMajor
,
typename
tensor_layout
::
gemm
::
ColumnMajor
,
typename
tensor_layout
::
gemm
::
RowMajor
>>
;
BlockGemmProblem
<
typename
Problem
::
GemmDataType
,
typename
Problem
::
OGradDataType
,
typename
Problem
::
AccDataType
,
Problem
::
kBlockSize
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kN0
,
Problem
::
BlockFmhaShape
::
kVHeaddim
,
Problem
::
BlockFmhaShape
::
kK1
>
,
typename
Problem
::
BlockFmhaShape
::
Gemm1BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm1WarpTile
>>
;
using
WarpGemm
=
WarpGemmMfmaDispatcher
<
typename
Problem
::
GemmDataType
,
...
...
@@ -104,20 +93,15 @@ struct BlockFmhaBwdPipelineDefaultPolicy
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetOGradVBlockGemm
()
{
using
GemmProblem
=
GemmPipelineProblem
<
typename
Problem
::
OGradDataType
,
typename
Problem
::
VDataType
,
typename
Problem
::
AccDataType
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kM0
,
Problem
::
BlockFmhaShape
::
kN0
,
Problem
::
BlockFmhaShape
::
kK2
>
,
typename
Problem
::
BlockFmhaShape
::
Gemm2BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm2WarpTile
>
,
TileGemmTraits
<
Problem
::
kPadSeqLenQ
,
Problem
::
kPadSeqLenK
,
Problem
::
kPadHeadDimQ
,
typename
tensor_layout
::
gemm
::
RowMajor
,
typename
tensor_layout
::
gemm
::
ColumnMajor
,
typename
tensor_layout
::
gemm
::
RowMajor
>>
;
BlockGemmProblem
<
typename
Problem
::
OGradDataType
,
typename
Problem
::
VDataType
,
typename
Problem
::
AccDataType
,
Problem
::
kBlockSize
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kM0
,
Problem
::
BlockFmhaShape
::
kN0
,
Problem
::
BlockFmhaShape
::
kK2
>
,
typename
Problem
::
BlockFmhaShape
::
Gemm2BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm2WarpTile
>>
;
using
WarpGemm
=
WarpGemmMfmaDispatcher
<
typename
Problem
::
OGradDataType
,
...
...
@@ -143,20 +127,15 @@ struct BlockFmhaBwdPipelineDefaultPolicy
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetSGradTQTBlockGemm
()
{
using
GemmProblem
=
GemmPipelineProblem
<
typename
Problem
::
GemmDataType
,
typename
Problem
::
QDataType
,
typename
Problem
::
AccDataType
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kN0
,
Problem
::
BlockFmhaShape
::
kQKHeaddim
,
Problem
::
BlockFmhaShape
::
kK3
>
,
typename
Problem
::
BlockFmhaShape
::
Gemm3BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm3WarpTile
>
,
TileGemmTraits
<
Problem
::
kPadSeqLenK
,
Problem
::
kPadHeadDimQ
,
Problem
::
kPadSeqLenK
,
typename
tensor_layout
::
gemm
::
RowMajor
,
typename
tensor_layout
::
gemm
::
ColumnMajor
,
typename
tensor_layout
::
gemm
::
RowMajor
>>
;
BlockGemmProblem
<
typename
Problem
::
GemmDataType
,
typename
Problem
::
QDataType
,
typename
Problem
::
AccDataType
,
Problem
::
kBlockSize
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kN0
,
Problem
::
BlockFmhaShape
::
kQKHeaddim
,
Problem
::
BlockFmhaShape
::
kK3
>
,
typename
Problem
::
BlockFmhaShape
::
Gemm3BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm3WarpTile
>>
;
using
WarpGemm
=
WarpGemmMfmaDispatcher
<
typename
Problem
::
GemmDataType
,
...
...
@@ -181,20 +160,15 @@ struct BlockFmhaBwdPipelineDefaultPolicy
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetSGradKTBlockGemm
()
{
using
GemmProblem
=
GemmPipelineProblem
<
typename
Problem
::
GemmDataType
,
typename
Problem
::
KDataType
,
typename
Problem
::
AccDataType
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kM0
,
Problem
::
BlockFmhaShape
::
kQKHeaddim
,
Problem
::
BlockFmhaShape
::
kK4
>
,
typename
Problem
::
BlockFmhaShape
::
Gemm4BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm4WarpTile
>
,
TileGemmTraits
<
Problem
::
kPadSeqLenQ
,
Problem
::
kPadHeadDimQ
,
Problem
::
kPadSeqLenK
,
typename
tensor_layout
::
gemm
::
RowMajor
,
typename
tensor_layout
::
gemm
::
ColumnMajor
,
typename
tensor_layout
::
gemm
::
RowMajor
>>
;
BlockGemmProblem
<
typename
Problem
::
GemmDataType
,
typename
Problem
::
KDataType
,
typename
Problem
::
AccDataType
,
Problem
::
kBlockSize
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kM0
,
Problem
::
BlockFmhaShape
::
kQKHeaddim
,
Problem
::
BlockFmhaShape
::
kK4
>
,
typename
Problem
::
BlockFmhaShape
::
Gemm4BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm4WarpTile
>>
;
using
WarpGemm
=
WarpGemmMfmaDispatcher
<
typename
Problem
::
GemmDataType
,
...
...
@@ -222,7 +196,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
using
QDataType
=
remove_cvref_t
<
typename
Problem
::
QDataType
>
;
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kMNPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
K0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
QKHeaddim
;
constexpr
index_t
kMaxVecLoad
=
16
/
sizeof
(
QDataType
);
constexpr
index_t
kMinVecLoad
=
4
/
sizeof
(
QDataType
);
...
...
@@ -241,7 +215,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
using
KDataType
=
remove_cvref_t
<
typename
Problem
::
KDataType
>
;
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kMNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
K0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
QKHeaddim
;
constexpr
index_t
kMaxVecLoad
=
16
/
sizeof
(
KDataType
);
constexpr
index_t
kMinVecLoad
=
4
/
sizeof
(
KDataType
);
...
...
@@ -260,7 +234,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
using
VDataType
=
remove_cvref_t
<
typename
Problem
::
VDataType
>
;
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kMNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
K2
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
VHeaddim
;
constexpr
index_t
kMaxVecLoad
=
16
/
sizeof
(
VDataType
);
constexpr
index_t
total_pixels
=
kMNPerBlock
*
kKPerBlock
/
kBlockSize
;
...
...
@@ -280,7 +254,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
using
OGradDataType
=
remove_cvref_t
<
typename
Problem
::
OGradDataType
>
;
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kMNPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
K2
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
VHeaddim
;
constexpr
index_t
kMaxVecLoad
=
16
/
sizeof
(
OGradDataType
);
constexpr
index_t
kMinVecLoad
=
4
/
sizeof
(
OGradDataType
);
...
...
@@ -341,7 +315,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
{
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
K0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
QKHeaddim
;
constexpr
index_t
total_pixels
=
kNPerBlock
*
kKPerBlock
/
kBlockSize
;
...
...
@@ -353,7 +327,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
{
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
K0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
QKHeaddim
;
constexpr
index_t
total_pixels
=
kNPerBlock
*
kKPerBlock
/
kBlockSize
;
return
total_pixels
/
GetAlignmentK
<
Problem
>
();
...
...
@@ -364,7 +338,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
{
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
K2
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
VHeaddim
;
constexpr
index_t
total_pixels
=
kNPerBlock
*
kKPerBlock
/
kBlockSize
;
...
...
@@ -402,7 +376,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
K0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
QKHeaddim
;
constexpr
index_t
K1
=
GetAlignmentK
<
Problem
>
();
constexpr
index_t
K0
=
kKPerBlock
/
K1
;
...
...
@@ -425,7 +399,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
K2
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
VHeaddim
;
constexpr
index_t
K1
=
GetAlignmentV
<
Problem
>
();
constexpr
index_t
K0
=
kKPerBlock
/
K1
;
...
...
@@ -448,7 +422,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kMPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
K0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
QKHeaddim
;
constexpr
index_t
K1
=
GetAlignmentQ
<
Problem
>
();
constexpr
index_t
K0
=
kKPerBlock
/
K1
;
...
...
@@ -471,7 +445,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kMPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
K2
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
VHeaddim
;
constexpr
index_t
K1
=
GetAlignmentOGrad
<
Problem
>
();
constexpr
index_t
K0
=
kKPerBlock
/
K1
;
...
...
@@ -842,44 +816,12 @@ struct BlockFmhaBwdPipelineDefaultPolicy
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeKLdsWriteBlockDescriptor
()
{
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
K0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
QKHeaddim
;
constexpr
index_t
kKPack
=
GetSmemKPackK
<
Problem
>
();
return
MakeXLdsBlockDescriptor
<
kNPerBlock
,
kKPerBlock
,
kKPack
>
();
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeKRegSliceBlockDescriptor
()
{
using
BlockGemm
=
remove_cvref_t
<
decltype
(
GetQKBlockGemm
<
Problem
>
())
>
;
constexpr
auto
config
=
BlockGemm
::
Policy
::
template
GetWarpGemmMWarpNWarp
<
Problem
>();
using
WarpGemm
=
remove_cvref_t
<
decltype
(
config
.
template
at
<
0
>())
>
;
constexpr
index_t
MWarp
=
Problem
::
BlockFmhaShape
::
Gemm0BlockWarps
::
at
(
number
<
0
>
{});
constexpr
index_t
NWarp
=
Problem
::
BlockFmhaShape
::
Gemm0BlockWarps
::
at
(
number
<
1
>
{});
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kK0
;
constexpr
index_t
NIterPerWarp
=
kNPerBlock
/
(
NWarp
*
WarpGemm
::
kN
);
constexpr
index_t
KIterPerWarp
=
kKPerBlock
/
WarpGemm
::
kK
;
constexpr
auto
k_block_outer_dstr_encoding
=
tile_distribution_encoding
<
sequence
<
MWarp
>
,
tuple
<
sequence
<
NIterPerWarp
,
NWarp
>
,
sequence
<
KIterPerWarp
>>
,
tuple
<
sequence
<
0
,
1
>>
,
tuple
<
sequence
<
0
,
1
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
0
>>
{};
constexpr
auto
k_block_dstr_encode
=
detail
::
make_embed_tile_distribution_encoding
(
k_block_outer_dstr_encoding
,
typename
WarpGemm
::
BWarpDstrEncoding
{});
constexpr
auto
k_block_dstr
=
make_static_tile_distribution
(
k_block_dstr_encode
);
return
k_block_dstr
;
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeKRegBlockDescriptor
()
{
...
...
@@ -891,7 +833,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr
index_t
NWarp
=
Problem
::
BlockFmhaShape
::
Gemm0BlockWarps
::
at
(
number
<
1
>
{});
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
QKHeaddim
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
K0
;
constexpr
index_t
NIterPerWarp
=
kNPerBlock
/
(
NWarp
*
WarpGemm
::
kN
);
constexpr
index_t
KIterPerWarp
=
kKPerBlock
/
WarpGemm
::
kK
;
...
...
@@ -916,45 +858,13 @@ struct BlockFmhaBwdPipelineDefaultPolicy
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeVLdsWriteBlockDescriptor
()
{
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
K2
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
VHeaddim
;
constexpr
index_t
kVPack
=
GetSmemKPackV
<
Problem
>
();
return
MakeXLdsBlockDescriptor
<
kNPerBlock
,
kKPerBlock
,
kVPack
>
();
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeVRegSliceBlockDescriptor
()
{
using
BlockGemm
=
remove_cvref_t
<
decltype
(
GetOGradVBlockGemm
<
Problem
>
())
>
;
constexpr
auto
config
=
BlockGemm
::
Policy
::
template
GetWarpGemmMWarpNWarp
<
Problem
>();
using
WarpGemm
=
remove_cvref_t
<
decltype
(
config
.
template
at
<
0
>())
>
;
constexpr
index_t
MWarp
=
Problem
::
BlockFmhaShape
::
Gemm2BlockWarps
::
at
(
number
<
0
>
{});
constexpr
index_t
NWarp
=
Problem
::
BlockFmhaShape
::
Gemm2BlockWarps
::
at
(
number
<
1
>
{});
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kK2
;
constexpr
index_t
NIterPerWarp
=
kNPerBlock
/
(
NWarp
*
WarpGemm
::
kN
);
constexpr
index_t
KIterPerWarp
=
kKPerBlock
/
WarpGemm
::
kK
;
constexpr
auto
v_block_outer_dstr_encoding
=
tile_distribution_encoding
<
sequence
<
MWarp
>
,
tuple
<
sequence
<
NIterPerWarp
,
NWarp
>
,
sequence
<
KIterPerWarp
>>
,
tuple
<
sequence
<
0
,
1
>>
,
tuple
<
sequence
<
0
,
1
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
0
>>
{};
constexpr
auto
v_block_dstr_encode
=
detail
::
make_embed_tile_distribution_encoding
(
v_block_outer_dstr_encoding
,
typename
WarpGemm
::
BWarpDstrEncoding
{});
constexpr
auto
v_block_dstr
=
make_static_tile_distribution
(
v_block_dstr_encode
);
return
v_block_dstr
;
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeVRegBlockDescriptor
()
{
...
...
@@ -966,7 +876,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr
index_t
NWarp
=
Problem
::
BlockFmhaShape
::
Gemm2BlockWarps
::
at
(
number
<
1
>
{});
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
VHeaddim
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
K2
;
constexpr
index_t
NIterPerWarp
=
kNPerBlock
/
(
NWarp
*
WarpGemm
::
kN
);
constexpr
index_t
KIterPerWarp
=
kKPerBlock
/
WarpGemm
::
kK
;
...
...
@@ -992,7 +902,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
{
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
K0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
QKHeaddim
;
constexpr
index_t
K1
=
GetAlignmentK
<
Problem
>
();
constexpr
index_t
K0
=
kKPerBlock
/
K1
;
...
...
@@ -1074,7 +984,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeQLdsBlockDescriptor
()
{
constexpr
index_t
kMPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
K0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
QKHeaddim
;
constexpr
index_t
kKPack
=
GetSmemKPackQ
<
Problem
>
();
...
...
@@ -1118,7 +1028,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
{
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
K0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
QKHeaddim
;
constexpr
index_t
K1
=
GetAlignmentQ
<
Problem
>
();
constexpr
index_t
K0
=
kKPerBlock
/
K1
;
...
...
@@ -1281,7 +1191,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
{
// Hold full block data
constexpr
index_t
kMPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
K2
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
VHeaddim
;
constexpr
index_t
kKPack
=
GetSmemKPackOGrad
<
Problem
>
();
...
...
@@ -1325,7 +1235,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
{
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
K2
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
VHeaddim
;
constexpr
index_t
K1
=
GetAlignmentOGrad
<
Problem
>
();
constexpr
index_t
K0
=
kKPerBlock
/
K1
;
...
...
@@ -1885,6 +1795,8 @@ struct BlockFmhaBwdPipelineDefaultPolicy
static
constexpr
index_t
kN0
=
Problem
::
BlockFmhaShape
::
kN0
;
static
constexpr
index_t
kQKHeaddim
=
Problem
::
BlockFmhaShape
::
kQKHeaddim
;
static
constexpr
index_t
kVHeaddim
=
Problem
::
BlockFmhaShape
::
kVHeaddim
;
static
constexpr
index_t
kK0
=
Problem
::
BlockFmhaShape
::
kK0
;
static
constexpr
index_t
kK2
=
Problem
::
BlockFmhaShape
::
kK2
;
static
constexpr
index_t
kK4
=
Problem
::
BlockFmhaShape
::
kK4
;
static
constexpr
index_t
WarpGemmM
=
...
...
@@ -1899,14 +1811,12 @@ struct BlockFmhaBwdPipelineDefaultPolicy
// Compute
static
constexpr
index_t
Gemm0MFMA
=
kM0
*
kN0
*
kQKHeaddim
/
(
kBlockSize
/
get_warp_size
()
*
WarpGemmM
*
WarpGemmN
*
WarpGemmK
);
kM0
*
kN0
*
kK0
/
(
kBlockSize
/
get_warp_size
()
*
WarpGemmM
*
WarpGemmN
*
WarpGemmK
);
static
constexpr
index_t
Gemm1MFMA
=
kM0
*
kN0
*
kVHeaddim
/
(
kBlockSize
/
get_warp_size
()
*
WarpGemmM
*
WarpGemmN
*
WarpGemmK
);
static
constexpr
index_t
Gemm2MFMA
=
kN0
*
kVHeaddim
*
kM0
/
(
kBlockSize
/
get_warp_size
()
*
WarpGemmM
*
WarpGemmN
*
WarpGemmK
);
static
constexpr
index_t
Gemm2MFMA
=
kM0
*
kN0
*
kK2
/
(
kBlockSize
/
get_warp_size
()
*
WarpGemmM
*
WarpGemmN
*
WarpGemmK
);
static
constexpr
index_t
Gemm3MFMA
=
kN0
*
kQKHeaddim
*
kM0
/
(
kBlockSize
/
get_warp_size
()
*
WarpGemmM
*
WarpGemmN
*
WarpGemmK
);
...
...
@@ -1929,13 +1839,12 @@ struct BlockFmhaBwdPipelineDefaultPolicy
kM0
*
kQKHeaddim
/
get_warp_size
()
/
GetTransposedAlignmentQ
<
Problem
>
();
static
constexpr
index_t
SGradT_LDS_READ_P1
=
kM0
*
kK4
/
(
get_warp_size
()
*
Gemm4MWarp
)
/
GetSmemKPackSGrad
<
Problem
>
();
static
constexpr
index_t
Q_LDS_READ
=
kM0
*
kQKHeaddim
/
kBlockSize
/
GetAlignmentQ
<
Problem
>
();
static
constexpr
index_t
Q_LDS_READ
=
kM0
*
kK0
/
kBlockSize
/
GetAlignmentQ
<
Problem
>
();
static
constexpr
index_t
LSE_LDS_READ
=
WarpGemmM
==
16
?
kM0
/
(
4
*
4
)
:
kM0
/
(
2
*
4
);
static
constexpr
index_t
SGradT_LDS_READ_P2
=
kM0
*
(
kN0
-
kK4
)
/
(
get_warp_size
()
*
Gemm4MWarp
)
/
GetSmemKPackSGrad
<
Problem
>
();
static
constexpr
index_t
OGrad_LDS_READ
=
kM0
*
k
VHeaddim
/
kBlockSize
/
GetAlignmentOGrad
<
Problem
>
();
kM0
*
k
K2
/
kBlockSize
/
GetAlignmentOGrad
<
Problem
>
();
static
constexpr
index_t
D_LDS_READ
=
WarpGemmM
==
16
?
kM0
/
(
4
*
4
)
:
kM0
/
(
2
*
4
);
// LDS Write
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp
View file @
e941f59f
...
...
@@ -12,6 +12,16 @@ namespace detail {
template
<
index_t
N
>
struct
log2
;
template
<
>
struct
log2
<
4
>
:
std
::
integral_constant
<
index_t
,
2
>
{
};
template
<
>
struct
log2
<
8
>
:
std
::
integral_constant
<
index_t
,
3
>
{
};
template
<
>
struct
log2
<
16
>
:
std
::
integral_constant
<
index_t
,
4
>
{
...
...
@@ -72,18 +82,18 @@ struct BlockFmhaFwdSplitKVCombinePipeline
{
if
constexpr
(
kHeadDimV
<=
32
)
{
constexpr
std
::
array
<
int
,
4
>
occupancy
{
3
,
3
,
3
,
1
};
return
occupancy
[
detail
::
log2
<
kMaxSplits
>::
value
-
4
];
constexpr
std
::
array
occupancy
{
3
,
3
,
3
,
3
,
3
,
1
};
return
occupancy
[
detail
::
log2
<
kMaxSplits
>::
value
-
2
];
}
else
if
constexpr
(
kHeadDimV
<=
128
)
{
constexpr
std
::
array
<
int
,
4
>
occupancy
{
3
,
3
,
2
,
1
};
return
occupancy
[
detail
::
log2
<
kMaxSplits
>::
value
-
4
];
constexpr
std
::
array
occupancy
{
3
,
3
,
3
,
3
,
2
,
1
};
return
occupancy
[
detail
::
log2
<
kMaxSplits
>::
value
-
2
];
}
else
if
constexpr
(
kHeadDimV
<=
256
)
{
constexpr
std
::
array
<
int
,
4
>
occupancy
{
2
,
2
,
2
,
1
};
return
occupancy
[
detail
::
log2
<
kMaxSplits
>::
value
-
4
];
constexpr
std
::
array
occupancy
{
2
,
2
,
2
,
2
,
2
,
1
};
return
occupancy
[
detail
::
log2
<
kMaxSplits
>::
value
-
2
];
}
}
}();
...
...
@@ -138,9 +148,8 @@ struct BlockFmhaFwdSplitKVCombinePipeline
auto
lse_accum
=
make_static_distributed_tensor
<
LSEDataType
>
(
Policy
::
template
MakeLSEaccRegTileDistribution
<
Problem
>());
// copy LDS (shape=[kM0, kMaxSplits]) to lse_accum (shape=[kM0, max(kMaxSplits, warp_size)])
// this will extend the distributed tensor width so that each thread in wave have data to
// reduce.
// copy LDS (shape=[kM0, kMaxSplits]) to lse_accum (shape=[kM0, kMaxSplits])
// and fill up -INF values outside the [kM0, num_splits] region.
{
constexpr
auto
spans
=
decltype
(
lse_accum
)
::
get_distributed_spans
();
sweep_tile_span
(
spans
[
number
<
0
>
{}],
[
&
](
auto
idx0
)
{
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp
View file @
e941f59f
...
...
@@ -10,11 +10,26 @@ namespace ck_tile {
struct
BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
{
template
<
index_t
BlockSize
,
index_t
M
,
index_t
N
,
typename
DataType
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetVectorSizeForTile
()
{
constexpr
index_t
PixelsPerThread
=
(
M
*
N
)
/
BlockSize
;
static_assert
(
0
<
PixelsPerThread
);
constexpr
index_t
MaxNPerThread
=
16
/
sizeof
(
DataType
);
constexpr
index_t
NPerThread
=
min
(
MaxNPerThread
,
PixelsPerThread
);
return
NPerThread
;
}
// alignment for dram lse tile (shape=[kMaxSplits, kM0])
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetAlignmentLSE
()
{
using
LSEDataType
=
remove_cvref_t
<
typename
Problem
::
LSEDataType
>
;
return
16
/
sizeof
(
LSEDataType
);
return
GetVectorSizeForTile
<
Problem
::
kBlockSize
,
Problem
::
kMaxSplits
,
Problem
::
kM0
,
typename
Problem
::
LSEDataType
>
();
}
template
<
typename
Problem
>
...
...
@@ -47,29 +62,31 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
MakeLSEaccLdsBlockDescriptor
<
Problem
>
().
get_element_space_size
();
}
// shape=[kMaxSplits, kM0]
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeLSEaccDramTileDistribution
()
{
using
LSEDataType
=
remove_cvref_t
<
typename
Problem
::
LSEDataType
>
;
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kNumWarps
=
Problem
::
kNumWarps
;
constexpr
index_t
kNPerBlock
=
Problem
::
kM0
;
constexpr
index_t
kMPerBlock
=
Problem
::
kMaxSplits
;
constexpr
index_t
NPerThread
=
16
/
sizeof
(
LSEDataType
);
constexpr
index_t
NThreads
=
kNPerBlock
/
NPerThread
;
constexpr
index_t
NPerThread
=
GetVectorSizeForTile
<
kBlockSize
,
kMPerBlock
,
kNPerBlock
,
LSEDataType
>
();
constexpr
index_t
NThreads
=
kNPerBlock
/
NPerThread
;
constexpr
index_t
MThreadsPerWarp
=
get_warp_size
()
/
NThreads
;
constexpr
index_t
TotalWarps
=
kBlockSize
/
get_warp_size
();
constexpr
index_t
MPerThread
=
kMPerBlock
/
(
TotalWarps
*
MThreadsPerWarp
);
constexpr
index_t
MPerThread
=
kMPerBlock
/
(
kNumWarps
*
MThreadsPerWarp
);
static_assert
(
NThreads
*
NPerThread
==
kNPerBlock
);
static_assert
(
MPerThread
*
Total
Warps
*
MThreadsPerWarp
==
kMPerBlock
);
static_assert
(
MPerThread
*
kNum
Warps
*
MThreadsPerWarp
==
kMPerBlock
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
MPerThread
,
Total
Warps
,
MThreadsPerWarp
>
,
tuple
<
sequence
<
MPerThread
,
kNum
Warps
,
MThreadsPerWarp
>
,
sequence
<
NThreads
,
NPerThread
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
2
,
0
>>
,
...
...
@@ -77,15 +94,18 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
sequence
<
0
,
1
>>
{});
}
// 3d + padding, [kMaxSplits, kM0]
// 3d + padding,
shape=
[kMaxSplits, kM0]
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeLSEaccLdsStoreBlockDescriptor
()
{
using
LSEDataType
=
remove_cvref_t
<
typename
Problem
::
LSEDataType
>
;
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kMPerBlock
=
Problem
::
kMaxSplits
;
constexpr
index_t
kNPerBlock
=
Problem
::
kM0
;
constexpr
index_t
NPack
=
16
/
sizeof
(
LSEDataType
);
constexpr
index_t
NPack
=
GetVectorSizeForTile
<
kBlockSize
,
kMPerBlock
,
kNPerBlock
,
LSEDataType
>
();
constexpr
auto
lse_acc_lds_block_desc_0
=
make_naive_tensor_descriptor
(
make_tuple
(
number
<
kNPerBlock
/
NPack
>
{},
number
<
kMPerBlock
>
{},
number
<
NPack
>
{}),
...
...
@@ -103,15 +123,18 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
return
lse_acc_lds_block_desc
;
}
// 3d + padding, [kM0, kMaxSplits]
// 3d + padding,
shape=
[kM0, kMaxSplits]
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeLSEaccLdsBlockDescriptor
()
{
using
LSEDataType
=
remove_cvref_t
<
typename
Problem
::
LSEDataType
>
;
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kMPerBlock
=
Problem
::
kMaxSplits
;
constexpr
index_t
kNPerBlock
=
Problem
::
kM0
;
constexpr
index_t
NPack
=
16
/
sizeof
(
LSEDataType
);
constexpr
index_t
NPack
=
GetVectorSizeForTile
<
kBlockSize
,
kMPerBlock
,
kNPerBlock
,
LSEDataType
>
();
constexpr
auto
lse_acc_lds_block_desc_0
=
make_naive_tensor_descriptor
(
make_tuple
(
number
<
kNPerBlock
/
NPack
>
{},
number
<
kMPerBlock
>
{},
number
<
NPack
>
{}),
...
...
@@ -134,26 +157,28 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
{
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kNPerBlock
=
max
(
Problem
::
kMaxSplits
,
get_warp_size
())
;
constexpr
index_t
kNPerBlock
=
Problem
::
kMaxSplits
;
constexpr
index_t
kMPerBlock
=
Problem
::
kM0
;
constexpr
index_t
NThreads
=
get_warp_size
()
;
constexpr
index_t
NThreads
=
4
;
constexpr
index_t
NPerThread
=
kNPerBlock
/
NThreads
;
constexpr
index_t
MThreads
=
kBlockSize
/
NThreads
;
constexpr
index_t
MPerThread
=
kMPerBlock
/
MThreads
;
constexpr
index_t
MThreads
=
kBlockSize
/
NThreads
;
constexpr
index_t
MPerThread
=
kMPerBlock
/
MThreads
;
constexpr
index_t
MWarps
=
kBlockSize
/
get_warp_size
();
constexpr
index_t
MThreadPerWarp
=
get_warp_size
()
/
NThreads
;
static_assert
(
NThreads
*
NPerThread
==
kNPerBlock
);
static_assert
(
M
Threads
*
MPerThread
==
kMPerBlock
);
static_assert
(
M
Warps
*
MThreadPerWarp
*
MPerThread
==
kMPerBlock
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
M
Threads
,
MPerThread
>
,
sequence
<
NThreads
,
NPerThread
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
2
>>
,
tuple
<
sequence
<
0
>
,
sequence
<
0
>>
,
tuple
<
sequence
<
M
Warps
,
MThreadPerWarp
,
MPerThread
>
,
sequence
<
NThreads
,
NPerThread
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
2
,
1
>>
,
tuple
<
sequence
<
0
>
,
sequence
<
0
,
1
>>
,
sequence
<
1
,
2
>
,
sequence
<
1
,
1
>>
{});
sequence
<
2
,
1
>>
{});
}
template
<
typename
Problem
>
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp
View file @
e941f59f
...
...
@@ -34,12 +34,13 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
static
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
static
constexpr
index_t
kM0
=
BlockFmhaShape
::
kM0
;
static
constexpr
index_t
kN0
=
BlockFmhaShape
::
kN0
;
static
constexpr
index_t
kK0
=
BlockFmhaShape
::
kK0
;
static
constexpr
index_t
kN1
=
BlockFmhaShape
::
kN1
;
static
constexpr
index_t
kK1
=
BlockFmhaShape
::
kK1
;
static
constexpr
index_t
kK0BlockLength
=
BlockFmhaShape
::
kK0BlockLength
;
static
constexpr
index_t
kM0
=
BlockFmhaShape
::
kM0
;
static
constexpr
index_t
kN0
=
BlockFmhaShape
::
kN0
;
static
constexpr
index_t
kK0
=
BlockFmhaShape
::
kK0
;
static
constexpr
index_t
kN1
=
BlockFmhaShape
::
kN1
;
static
constexpr
index_t
kK1
=
BlockFmhaShape
::
kK1
;
static
constexpr
index_t
kQKHeaddim
=
BlockFmhaShape
::
kQKHeaddim
;
static
constexpr
index_t
kSubQKHeaddim
=
BlockFmhaShape
::
kSubQKHeaddim
;
static
constexpr
bool
kIsGroupMode
=
Problem
::
kIsGroupMode
;
static
constexpr
bool
kPadSeqLenQ
=
Problem
::
kPadSeqLenQ
;
...
...
@@ -64,6 +65,9 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
return
kPadSeqLenK
?
1
:
Policy
::
template
GetAlignmentV
<
Problem
>();
}();
static
constexpr
index_t
kAlignmentOacc
=
kPadHeadDimV
?
1
:
Policy
::
template
GetAlignmentOacc
<
Problem
>();
static
constexpr
index_t
kAlignmentBias
=
kPadSeqLenK
?
1
:
Policy
::
template
GetAlignmentBias
<
Problem
>();
...
...
@@ -72,22 +76,22 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
return
Problem
::
kBlockPerCu
;
else
{
if
constexpr
(
k
K0BlockLength
<=
32
)
if
constexpr
(
k
QKHeaddim
<=
32
)
{
return
2
;
}
else
if
constexpr
(
k
K0BlockLength
<=
64
)
else
if
constexpr
(
k
QKHeaddim
<=
64
)
{
return
3
;
}
else
if
constexpr
(
k
K0BlockLength
<=
128
)
else
if
constexpr
(
k
QKHeaddim
<=
128
)
{
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
)
return
1
;
else
return
2
;
}
else
if
constexpr
(
k
K0BlockLength
<=
256
)
else
if
constexpr
(
k
QKHeaddim
<=
256
)
{
return
1
;
}
...
...
@@ -252,11 +256,11 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
k_dram_block_window_lengths
,
{
adjusted_seqlen_k_start
,
0
});
const
auto
bias_origin
=
bias_dram_block_window_tmp
.
get_window_origin
();
auto
bias_dram_window
=
make_tile_window
(
bias_dram_block_window_tmp
.
get_bottom_tensor_view
(),
bias_dram_block_window_tmp
.
get_window_lengths
(),
{
bias_origin
.
at
(
number
<
0
>
{}),
adjusted_seqlen_k_start
},
// M/N
Policy
::
template
MakeBiasDramTileDistribution
<
Problem
,
decltype
(
gemm_0
)>());
auto
bias_dram_window
=
make_tile_window
(
bias_dram_block_window_tmp
.
get_bottom_tensor_view
(),
bias_dram_block_window_tmp
.
get_window_lengths
(),
{
bias_origin
.
at
(
number
<
0
>
{}),
adjusted_seqlen_k_start
},
// M/N
Policy
::
template
MakeBiasDramTileDistribution
<
decltype
(
gemm_0
)>());
auto
[
i_page_block_v
,
v_dram_window
]
=
v_page_block_navigator
.
make_tile_window
(
v_dram_block_window_lengths
,
...
...
@@ -267,7 +271,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
// prefetch K tile
index_t
i_total_loops
=
0
;
constexpr
index_t
k0_loops
=
k
K0BlockLength
/
kK0
;
constexpr
index_t
k0_loops
=
k
QKHeaddim
/
kK0
;
constexpr
index_t
k1_loops
=
kN0
/
kK1
;
static_assert
(
2
<=
k0_loops
);
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_default_policy.hpp
View file @
e941f59f
...
...
@@ -9,11 +9,20 @@
namespace
ck_tile
{
// This pipeline is qkv all located in LDS
using
BlockFmhaFwdSplitKVPipelineQRKSVSDefaultPolicy
=
BlockFmhaPipelineQXKSVSCustomPolicy
<
/* QLoadOnce = */
true
,
/* AsyncCopyK = */
false
,
/* AsyncCopyV = */
false
,
/* NumPrefetchK = */
1
,
/* NumPrefetchV = */
1
>
;
struct
BlockFmhaFwdSplitKVPipelineQRKSVSDefaultPolicy
:
BlockFmhaPipelineQXKSVSCustomPolicy
<
/* QLoadOnce = */
true
,
/* AsyncCopyK = */
false
,
/* AsyncCopyV = */
false
,
/* NumPrefetchK = */
1
,
/* NumPrefetchV = */
1
>
{
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetAlignmentOacc
()
{
using
OaccDataType
=
remove_cvref_t
<
typename
Problem
::
OaccDataType
>
;
return
static_cast
<
index_t
>
(
16
/
sizeof
(
OaccDataType
));
}
};
}
// namespace ck_tile
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp
View file @
e941f59f
...
...
@@ -39,8 +39,11 @@ struct BlockFmhaPipelineProblem
using
FmhaMask
=
remove_cvref_t
<
FmhaMask_
>
;
using
Traits
=
remove_cvref_t
<
Traits_
>
;
static
constexpr
index_t
kBlockSize
=
BlockFmhaShape
::
NumWarps
*
get_warp_size
();
static
constexpr
bool
kIsGroupMode
=
kIsGroupMode_
;
static
constexpr
index_t
kNumGemm0Warps
=
BlockFmhaShape
::
NumGemm0Warps
;
static
constexpr
index_t
kNumGemm1Warps
=
BlockFmhaShape
::
NumGemm1Warps
;
static
constexpr
index_t
kBlockSize
=
BlockFmhaShape
::
NumWarps
*
get_warp_size
();
static
constexpr
bool
kIsGroupMode
=
kIsGroupMode_
;
// attributes from traits
static
constexpr
bool
kPadSeqLenQ
=
Traits
::
kPadSeqLenQ
;
...
...
@@ -84,8 +87,11 @@ struct BlockFmhaFwdSplitKVPipelineProblem
using
FmhaMask
=
remove_cvref_t
<
FmhaMask_
>
;
using
Traits
=
remove_cvref_t
<
Traits_
>
;
static
constexpr
index_t
kBlockSize
=
BlockFmhaShape
::
NumWarps
*
get_warp_size
();
static
constexpr
bool
kIsGroupMode
=
kIsGroupMode_
;
static
constexpr
index_t
kNumGemm0Warps
=
BlockFmhaShape
::
NumGemm0Warps
;
static
constexpr
index_t
kNumGemm1Warps
=
BlockFmhaShape
::
NumGemm1Warps
;
static
constexpr
index_t
kBlockSize
=
BlockFmhaShape
::
NumWarps
*
get_warp_size
();
static
constexpr
bool
kIsGroupMode
=
kIsGroupMode_
;
// attributes from traits
static
constexpr
bool
kPadSeqLenQ
=
Traits
::
kPadSeqLenQ
;
...
...
@@ -115,7 +121,8 @@ struct BlockFmhaSplitKVCombinePipelineProblem
using
ODataType
=
remove_cvref_t
<
ODataType_
>
;
using
Traits
=
remove_cvref_t
<
Traits_
>
;
static
constexpr
index_t
kBlockSize
=
256
;
static
constexpr
index_t
kNumWarps
=
kM0_
/
(
get_warp_size
()
/
4
);
static
constexpr
index_t
kBlockSize
=
kNumWarps
*
get_warp_size
();
static
constexpr
bool
kIsGroupMode
=
kIsGroupMode_
;
static
constexpr
index_t
kHeadDimV
=
HeadDimV_
;
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp
View file @
e941f59f
...
...
@@ -37,12 +37,13 @@ struct BlockFmhaPipelineQRKSVS
static
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
static
constexpr
index_t
kM0
=
BlockFmhaShape
::
kM0
;
static
constexpr
index_t
kN0
=
BlockFmhaShape
::
kN0
;
static
constexpr
index_t
kK0
=
BlockFmhaShape
::
kK0
;
static
constexpr
index_t
kN1
=
BlockFmhaShape
::
kN1
;
static
constexpr
index_t
kK1
=
BlockFmhaShape
::
kK1
;
static
constexpr
index_t
kK0BlockLength
=
BlockFmhaShape
::
kK0BlockLength
;
static
constexpr
index_t
kM0
=
BlockFmhaShape
::
kM0
;
static
constexpr
index_t
kN0
=
BlockFmhaShape
::
kN0
;
static
constexpr
index_t
kK0
=
BlockFmhaShape
::
kK0
;
static
constexpr
index_t
kN1
=
BlockFmhaShape
::
kN1
;
static
constexpr
index_t
kK1
=
BlockFmhaShape
::
kK1
;
static
constexpr
index_t
kQKHeaddim
=
BlockFmhaShape
::
kQKHeaddim
;
static
constexpr
index_t
kSubQKHeaddim
=
BlockFmhaShape
::
kSubQKHeaddim
;
static
constexpr
bool
kIsGroupMode
=
Problem
::
kIsGroupMode
;
static
constexpr
bool
kPadSeqLenQ
=
Problem
::
kPadSeqLenQ
;
...
...
@@ -76,22 +77,22 @@ struct BlockFmhaPipelineQRKSVS
return
Problem
::
kBlockPerCu
;
else
{
if
constexpr
(
k
K0BlockLength
<=
32
)
if
constexpr
(
k
QKHeaddim
<=
32
)
{
return
2
;
}
else
if
constexpr
(
k
K0BlockLength
<=
64
)
else
if
constexpr
(
k
QKHeaddim
<=
64
)
{
return
3
;
}
else
if
constexpr
(
k
K0BlockLength
<=
128
)
else
if
constexpr
(
k
QKHeaddim
<=
128
)
{
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
)
return
1
;
else
return
2
;
}
else
if
constexpr
(
k
K0BlockLength
<=
256
)
else
if
constexpr
(
k
QKHeaddim
<=
256
)
{
return
1
;
}
...
...
@@ -242,11 +243,11 @@ struct BlockFmhaPipelineQRKSVS
{
seqlen_k_start
,
0
});
const
auto
bias_origin
=
bias_dram_block_window_tmp
.
get_window_origin
();
auto
bias_dram_window
=
make_tile_window
(
bias_dram_block_window_tmp
.
get_bottom_tensor_view
(),
bias_dram_block_window_tmp
.
get_window_lengths
(),
{
bias_origin
.
at
(
number
<
0
>
{}),
seqlen_k_start
},
// M/N
Policy
::
template
MakeBiasDramTileDistribution
<
Problem
,
decltype
(
gemm_0
)>());
auto
bias_dram_window
=
make_tile_window
(
bias_dram_block_window_tmp
.
get_bottom_tensor_view
(),
bias_dram_block_window_tmp
.
get_window_lengths
(),
{
bias_origin
.
at
(
number
<
0
>
{}),
seqlen_k_start
},
// M/N
Policy
::
template
MakeBiasDramTileDistribution
<
decltype
(
gemm_0
)>());
auto
randval_dram_window
=
dropout
.
template
MakeRandvalDramWindow
<
decltype
(
gemm_0
)>(
randval_dram_block_window_tmp
,
seqlen_k_start
);
...
...
@@ -261,7 +262,7 @@ struct BlockFmhaPipelineQRKSVS
// prefetch K tile
index_t
i_total_loops
=
0
;
constexpr
index_t
k0_loops
=
k
K0BlockLength
/
kK0
;
constexpr
index_t
k0_loops
=
k
QKHeaddim
/
kK0
;
constexpr
index_t
k1_loops
=
kN0
/
kK1
;
static_assert
(
2
<=
k0_loops
);
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp
View file @
e941f59f
...
...
@@ -38,12 +38,13 @@ struct BlockFmhaPipelineQRKSVSAsync
static
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
static
constexpr
index_t
kM0
=
BlockFmhaShape
::
kM0
;
static
constexpr
index_t
kN0
=
BlockFmhaShape
::
kN0
;
static
constexpr
index_t
kK0
=
BlockFmhaShape
::
kK0
;
static
constexpr
index_t
kN1
=
BlockFmhaShape
::
kN1
;
static
constexpr
index_t
kK1
=
BlockFmhaShape
::
kK1
;
static
constexpr
index_t
kK0BlockLength
=
BlockFmhaShape
::
kK0BlockLength
;
static
constexpr
index_t
kM0
=
BlockFmhaShape
::
kM0
;
static
constexpr
index_t
kN0
=
BlockFmhaShape
::
kN0
;
static
constexpr
index_t
kK0
=
BlockFmhaShape
::
kK0
;
static
constexpr
index_t
kN1
=
BlockFmhaShape
::
kN1
;
static
constexpr
index_t
kK1
=
BlockFmhaShape
::
kK1
;
static
constexpr
index_t
kQKHeaddim
=
BlockFmhaShape
::
kQKHeaddim
;
static
constexpr
index_t
kSubQKHeaddim
=
BlockFmhaShape
::
kSubQKHeaddim
;
static
constexpr
bool
kIsGroupMode
=
Problem
::
kIsGroupMode
;
// TODO: seq_q always support padding, hdim_q/v support multiple of vector(like 8x)
...
...
@@ -87,7 +88,7 @@ struct BlockFmhaPipelineQRKSVSAsync
return
1
;
}
if
constexpr
(
k
K0BlockLength
<=
32
)
if
constexpr
(
k
QKHeaddim
<=
32
)
{
if
constexpr
(
kPadSeqLenK
&&
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
&&
FmhaMask
::
IsMasking
)
...
...
@@ -95,21 +96,21 @@ struct BlockFmhaPipelineQRKSVSAsync
else
return
2
;
}
else
if
constexpr
(
k
K0BlockLength
<=
64
)
else
if
constexpr
(
k
QKHeaddim
<=
64
)
{
if
constexpr
(
kPadSeqLenK
&&
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
)
return
2
;
else
return
3
;
}
else
if
constexpr
(
k
K0BlockLength
<=
128
)
else
if
constexpr
(
k
QKHeaddim
<=
128
)
{
if
constexpr
(
kPadSeqLenK
&&
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
)
return
1
;
else
return
2
;
}
else
if
constexpr
(
k
K0BlockLength
<=
256
)
else
if
constexpr
(
k
QKHeaddim
<=
256
)
{
return
1
;
}
...
...
@@ -314,11 +315,11 @@ struct BlockFmhaPipelineQRKSVSAsync
}();
const
auto
bias_origin
=
bias_dram_block_window_tmp
.
get_window_origin
();
auto
bias_dram_window
=
make_tile_window
(
bias_dram_block_window_tmp
.
get_bottom_tensor_view
(),
bias_dram_block_window_tmp
.
get_window_lengths
(),
{
bias_origin
.
at
(
number
<
0
>
{}),
seqlen_k_start
},
// M/N
Policy
::
template
MakeBiasDramTileDistribution
<
Problem
,
decltype
(
gemm_0
)>());
auto
bias_dram_window
=
make_tile_window
(
bias_dram_block_window_tmp
.
get_bottom_tensor_view
(),
bias_dram_block_window_tmp
.
get_window_lengths
(),
{
bias_origin
.
at
(
number
<
0
>
{}),
seqlen_k_start
},
// M/N
Policy
::
template
MakeBiasDramTileDistribution
<
decltype
(
gemm_0
)>());
auto
randval_dram_window
=
dropout
.
template
MakeRandvalDramWindow
<
decltype
(
gemm_0
)>(
randval_dram_block_window_tmp
,
seqlen_k_start
);
...
...
@@ -334,12 +335,12 @@ struct BlockFmhaPipelineQRKSVSAsync
move_tile_window
(
k_dram_window
,
{
0
,
kK0
});
__builtin_amdgcn_sched_barrier
(
0
);
buffer_load_fence
(
k_dram_window
.
get_num_access
(),
q
.
get_thread_buffer
());
buffer_load_fence
(
k_dram_window
.
get_num_
of_
access
(),
q
.
get_thread_buffer
());
(
void
)
q_element_func
;
// ??? rocm-6.x if use q element func will have scratch on hdim=64/32
// auto q_tile = q; // tile_elementwise_in(q_element_func, q);
index_t
i_total_loops
=
0
;
constexpr
index_t
k0_loops
=
k
K0BlockLength
/
kK0
;
constexpr
index_t
k0_loops
=
k
QKHeaddim
/
kK0
;
constexpr
index_t
k1_loops
=
kN0
/
kK1
;
static_assert
(
1
<=
k0_loops
);
...
...
@@ -359,7 +360,7 @@ struct BlockFmhaPipelineQRKSVSAsync
if
constexpr
(
i_k0
<
k0_loops
-
1
)
move_tile_window
(
k_dram_window
,
{
0
,
kK0
});
async_load_fence
(
k_dram_window
.
get_num_access
());
async_load_fence
(
k_dram_window
.
get_num_
of_
access
());
__builtin_amdgcn_s_barrier
();
__builtin_amdgcn_sched_barrier
(
0
);
gemm_0
(
s_acc
,
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_fp8.hpp
View file @
e941f59f
...
...
@@ -36,12 +36,12 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8
static
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
static
constexpr
index_t
kM0
=
BlockFmhaShape
::
kM0
;
static
constexpr
index_t
kN0
=
BlockFmhaShape
::
kN0
;
static
constexpr
index_t
kK0
=
BlockFmhaShape
::
kK0
;
static
constexpr
index_t
kN1
=
BlockFmhaShape
::
kN1
;
static
constexpr
index_t
kK1
=
BlockFmhaShape
::
kK1
;
static
constexpr
index_t
k
K0BlockLength
=
BlockFmhaShape
::
k
K0BlockLength
;
static
constexpr
index_t
kM0
=
BlockFmhaShape
::
kM0
;
static
constexpr
index_t
kN0
=
BlockFmhaShape
::
kN0
;
static
constexpr
index_t
kK0
=
BlockFmhaShape
::
kK0
;
static
constexpr
index_t
kN1
=
BlockFmhaShape
::
kN1
;
static
constexpr
index_t
kK1
=
BlockFmhaShape
::
kK1
;
static
constexpr
index_t
k
QKHeaddim
=
BlockFmhaShape
::
k
QKHeaddim
;
static
constexpr
bool
kIsGroupMode
=
Problem
::
kIsGroupMode
;
static
constexpr
bool
kPadSeqLenQ
=
Problem
::
kPadSeqLenQ
;
...
...
@@ -75,22 +75,22 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8
return
Problem
::
kBlockPerCu
;
else
{
if
constexpr
(
k
K0BlockLength
<=
32
)
if
constexpr
(
k
QKHeaddim
<=
32
)
{
return
2
;
}
else
if
constexpr
(
k
K0BlockLength
<=
64
)
else
if
constexpr
(
k
QKHeaddim
<=
64
)
{
return
3
;
}
else
if
constexpr
(
k
K0BlockLength
<=
128
)
else
if
constexpr
(
k
QKHeaddim
<=
128
)
{
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
)
return
1
;
else
return
2
;
}
else
if
constexpr
(
k
K0BlockLength
<=
256
)
else
if
constexpr
(
k
QKHeaddim
<=
256
)
{
return
1
;
}
...
...
@@ -232,7 +232,7 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8
// prefetch K tile
index_t
i_total_loops
=
0
;
constexpr
index_t
k0_loops
=
k
K0BlockLength
/
kK0
;
constexpr
index_t
k0_loops
=
k
QKHeaddim
/
kK0
;
constexpr
index_t
k1_loops
=
kN0
/
kK1
;
static_assert
(
2
<=
k0_loops
);
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp
View file @
e941f59f
...
...
@@ -9,9 +9,10 @@
namespace
ck_tile
{
/// NOTICE: we no-longer use this pipeline.
// This pipeline is qkv all located in LDS
template
<
typename
Problem_
,
typename
Policy_
=
BlockFmhaPipelineQSKSVSDefaultPolicy
>
struct
BlockFmhaPipelineQSKSVS
struct
[[
deprecated
]]
BlockFmhaPipelineQSKSVS
{
using
Problem
=
remove_cvref_t
<
Problem_
>
;
using
Policy
=
remove_cvref_t
<
Policy_
>
;
...
...
@@ -35,12 +36,13 @@ struct BlockFmhaPipelineQSKSVS
static
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
static
constexpr
index_t
kM0
=
BlockFmhaShape
::
kM0
;
static
constexpr
index_t
kN0
=
BlockFmhaShape
::
kN0
;
static
constexpr
index_t
kK0
=
BlockFmhaShape
::
kK0
;
static
constexpr
index_t
kN1
=
BlockFmhaShape
::
kN1
;
static
constexpr
index_t
kK1
=
BlockFmhaShape
::
kK1
;
static
constexpr
index_t
kK0BlockLength
=
BlockFmhaShape
::
kK0BlockLength
;
static
constexpr
index_t
kM0
=
BlockFmhaShape
::
kM0
;
static
constexpr
index_t
kN0
=
BlockFmhaShape
::
kN0
;
static
constexpr
index_t
kK0
=
BlockFmhaShape
::
kK0
;
static
constexpr
index_t
kN1
=
BlockFmhaShape
::
kN1
;
static
constexpr
index_t
kK1
=
BlockFmhaShape
::
kK1
;
static
constexpr
index_t
kQKHeaddim
=
BlockFmhaShape
::
kQKHeaddim
;
static
constexpr
index_t
kSubQKHeaddim
=
BlockFmhaShape
::
kSubQKHeaddim
;
static
constexpr
bool
kIsGroupMode
=
Problem
::
kIsGroupMode
;
static
constexpr
bool
kPadSeqLenQ
=
Problem
::
kPadSeqLenQ
;
...
...
@@ -55,22 +57,22 @@ struct BlockFmhaPipelineQSKSVS
return
Problem
::
kBlockPerCu
;
else
{
if
constexpr
(
k
K0BlockLength
<=
32
)
if
constexpr
(
k
QKHeaddim
<=
32
)
{
return
2
;
}
else
if
constexpr
(
k
K0BlockLength
<=
64
)
else
if
constexpr
(
k
QKHeaddim
<=
64
)
{
return
3
;
}
else
if
constexpr
(
k
K0BlockLength
<=
128
)
else
if
constexpr
(
k
QKHeaddim
<=
128
)
{
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
)
return
1
;
else
return
2
;
}
else
if
constexpr
(
k
K0BlockLength
<=
256
)
else
if
constexpr
(
k
QKHeaddim
<=
256
)
{
return
1
;
}
...
...
@@ -234,7 +236,7 @@ struct BlockFmhaPipelineQSKSVS
// prefetch K tile
index_t
i_total_loops
=
0
;
constexpr
index_t
k0_loops
=
k
K0BlockLength
/
kK0
;
constexpr
index_t
k0_loops
=
k
QKHeaddim
/
kK0
;
constexpr
index_t
k1_loops
=
kN0
/
kK1
;
static_assert
(
2
<=
k0_loops
);
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp
View file @
e941f59f
...
...
@@ -5,9 +5,9 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_problem.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp"
...
...
@@ -15,6 +15,7 @@
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_one_warp_v1.hpp"
// TODO: remove this
#define K_LDS_LOAD_USE_OFFSET_TRANSFORM 0
...
...
@@ -54,7 +55,7 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true>
constexpr
index_t
MWarp
=
config
.
template
at
<
1
>();
constexpr
index_t
kMPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
K0BlockLength
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
k
SubQKHeaddim
;
constexpr
index_t
K2
=
WG
::
kK
/
WG
::
WarpGemmAttribute
::
Impl
::
kABKLane
;
constexpr
index_t
K1
=
WG
::
WarpGemmAttribute
::
Impl
::
kABKLane
;
...
...
@@ -64,51 +65,72 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true>
constexpr
index_t
M1
=
MWarp
;
constexpr
index_t
M0
=
kMPerBlock
/
(
M2
*
M1
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
M0
,
M1
,
M2
>
,
sequence
<
K0
,
K1
,
K2
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
2
,
1
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
sequence
<
1
,
2
,
2
>
,
sequence
<
0
,
0
,
2
>>
{});
if
constexpr
(
1
<
Problem
::
kNumGemm0Warps
)
{
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
M0
,
M1
,
M2
>
,
sequence
<
K0
,
K1
,
K2
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
2
,
1
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
sequence
<
1
,
2
,
2
>
,
sequence
<
0
,
0
,
2
>>
{});
}
else
{
static_assert
(
MWarp
==
1
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
M0
,
M1
,
M2
>
,
sequence
<
K0
,
K1
,
K2
>>
,
tuple
<
sequence
<
2
,
1
>>
,
tuple
<
sequence
<
1
,
2
>>
,
sequence
<
1
,
2
,
2
>
,
sequence
<
0
,
0
,
2
>>
{});
}
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetQKBlockGemm
()
{
using
GemmProblem
=
GemmPipelineProblem
<
typename
Problem
::
QDataType
,
typename
Problem
::
KDataType
,
typename
Problem
::
SaccDataType
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kM0
,
Problem
::
BlockFmhaShape
::
kN0
,
Problem
::
BlockFmhaShape
::
kK0
>
,
typename
Problem
::
BlockFmhaShape
::
Gemm0BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm0WarpTile
>
,
TileGemmTraits
<
Problem
::
kPadSeqLenQ
,
Problem
::
kPadSeqLenK
,
Problem
::
kPadHeadDimQ
,
typename
tensor_layout
::
gemm
::
RowMajor
,
typename
tensor_layout
::
gemm
::
ColumnMajor
,
typename
tensor_layout
::
gemm
::
RowMajor
>>
;
BlockGemmProblem
<
typename
Problem
::
QDataType
,
typename
Problem
::
KDataType
,
typename
Problem
::
SaccDataType
,
Problem
::
kNumGemm0Warps
*
get_warp_size
(),
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kM0
,
Problem
::
BlockFmhaShape
::
kN0
,
Problem
::
BlockFmhaShape
::
kK0
>
,
typename
Problem
::
BlockFmhaShape
::
Gemm0BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm0WarpTile
>>
;
constexpr
auto
warp_gemm
=
[]()
{
constexpr
index_t
WarpGemmM
=
Problem
::
BlockFmhaShape
::
Gemm0WarpTile
::
at
(
number
<
0
>
{});
static_assert
(
WarpGemmM
==
16
||
WarpGemmM
==
32
);
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
QDataType
,
half_t
>
&&
std
::
is_same_v
<
typename
Problem
::
KDataType
,
half_t
>
&&
std
::
is_same_v
<
typename
Problem
::
SaccDataType
,
float
>
)
{
return
WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution
{};
if
constexpr
(
WarpGemmM
==
32
)
return
WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution
{};
else
// WarpGemmM == 16
return
WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution
{};
}
else
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
QDataType
,
bf16_t
>
&&
std
::
is_same_v
<
typename
Problem
::
KDataType
,
bf16_t
>
&&
std
::
is_same_v
<
typename
Problem
::
SaccDataType
,
float
>
)
{
return
WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution
{};
if
constexpr
(
WarpGemmM
==
32
)
return
WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution
{};
else
// WarpGemmM == 16
return
WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution
{};
}
else
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
QDataType
,
fp8_t
>
&&
std
::
is_same_v
<
typename
Problem
::
KDataType
,
fp8_t
>
&&
std
::
is_same_v
<
typename
Problem
::
SaccDataType
,
float
>
)
{
static_assert
(
WarpGemmM
==
32
);
// TODO: hard-coded here; other values may produce incorrect results
constexpr
index_t
swizzle_factor
=
4
;
return
WarpGemmMfmaFp8Fp8F32M32N32K16SwizzleBTransposedCDistribution
<
...
...
@@ -123,12 +145,16 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true>
typename
Problem
::
BlockFmhaShape
::
Gemm0BlockWarps
,
decltype
(
warp_gemm
)
>
;
return
BlockGemmARegBSmemCRegV2
<
GemmProblem
,
BlockGemmPolicy
>
{};
if
constexpr
(
1
<
Problem
::
kNumGemm0Warps
)
return
BlockGemmARegBSmemCRegV2
<
GemmProblem
,
BlockGemmPolicy
>
{};
else
return
BlockGemmARegBSmemCRegOneWarpV1
<
GemmProblem
,
BlockGemmPolicy
>
{};
}
};
/// NOTICE: we no-longer use this policy.
template
<
>
struct
BlockFmhaPipelineQXCustomPolicy
<
/* QLoadOnce = */
false
>
struct
[[
deprecated
]]
BlockFmhaPipelineQXCustomPolicy
<
/* QLoadOnce = */
false
>
{
static
constexpr
bool
QLoadOnce
=
false
;
...
...
@@ -207,20 +233,15 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetQKBlockGemm
()
{
using
GemmProblem
=
GemmPipelineProblem
<
typename
Problem
::
QDataType
,
typename
Problem
::
KDataType
,
typename
Problem
::
SaccDataType
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kM0
,
Problem
::
BlockFmhaShape
::
kN0
,
Problem
::
BlockFmhaShape
::
kK0
>
,
typename
Problem
::
BlockFmhaShape
::
Gemm0BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm0WarpTile
>
,
TileGemmTraits
<
Problem
::
kPadSeqLenQ
,
Problem
::
kPadSeqLenK
,
Problem
::
kPadHeadDimQ
,
typename
tensor_layout
::
gemm
::
RowMajor
,
typename
tensor_layout
::
gemm
::
ColumnMajor
,
typename
tensor_layout
::
gemm
::
RowMajor
>>
;
BlockGemmProblem
<
typename
Problem
::
QDataType
,
typename
Problem
::
KDataType
,
typename
Problem
::
SaccDataType
,
Problem
::
kBlockSize
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kM0
,
Problem
::
BlockFmhaShape
::
kN0
,
Problem
::
BlockFmhaShape
::
kK0
>
,
typename
Problem
::
BlockFmhaShape
::
Gemm0BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm0WarpTile
>>
;
constexpr
auto
warp_gemm
=
[]()
{
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
QDataType
,
half_t
>
&&
...
...
@@ -302,6 +323,9 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
// Buffer order selected for LdsBufferSequence<3, 3, 3, 3>. The argument order
// matches the lookup LdsBufferSequence<NumPrefetchK, NumPrefetchV, k0_loops, k1_loops>.
template <>
struct LdsBufferSequence<3, 3, 3, 3>
{
    using type = sequence<1, 2, 0, 1, 2, 0>;
};
// Buffer order selected for LdsBufferSequence<3, 3, 3, 4>. The argument order
// matches the lookup LdsBufferSequence<NumPrefetchK, NumPrefetchV, k0_loops, k1_loops>.
template <>
struct LdsBufferSequence<3, 3, 3, 4>
{
    using type = sequence<1, 2, 0, 0, 1, 2, 0>;
};
// Buffer order selected for LdsBufferSequence<3, 3, 2, 2>. The argument order
// matches the lookup LdsBufferSequence<NumPrefetchK, NumPrefetchV, k0_loops, k1_loops>.
template <>
struct LdsBufferSequence<3, 3, 2, 2>
{
    using type = sequence<1, 2, 1, 0>;
};
// clang-format on
...
...
@@ -311,12 +335,12 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
{
using
BlockFmhaShape
=
remove_cvref_t
<
typename
Problem
::
BlockFmhaShape
>
;
constexpr
index_t
kN0
=
BlockFmhaShape
::
kN0
;
constexpr
index_t
kK0
=
BlockFmhaShape
::
kK0
;
constexpr
index_t
kK1
=
BlockFmhaShape
::
kK1
;
constexpr
index_t
k
K0BlockLength
=
BlockFmhaShape
::
k
K0BlockLength
;
constexpr
index_t
kN0
=
BlockFmhaShape
::
kN0
;
constexpr
index_t
kK0
=
BlockFmhaShape
::
kK0
;
constexpr
index_t
kK1
=
BlockFmhaShape
::
kK1
;
constexpr
index_t
k
QKHeaddim
=
BlockFmhaShape
::
k
QKHeaddim
;
constexpr
index_t
k0_loops
=
k
K0BlockLength
/
kK0
;
constexpr
index_t
k0_loops
=
k
QKHeaddim
/
kK0
;
constexpr
index_t
k1_loops
=
kN0
/
kK1
;
return
typename
LdsBufferSequence
<
NumPrefetchK
,
NumPrefetchV
,
k0_loops
,
k1_loops
>::
type
{};
...
...
@@ -363,12 +387,15 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN1
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kK1
;
constexpr
index_t
total_pixels
=
kNPerBlock
*
kKPerBlock
/
kBlockSize
;
constexpr
index_t
kMaxVecLoad
=
min
(
total_pixels
,
static_cast
<
index_t
>
(
16
/
sizeof
(
VDataType
)));
constexpr
index_t
kMinVecLoad
=
4
/
sizeof
(
VDataType
);
// TODO: not correct!
if
constexpr
(
total_pixels
>
4
)
return
4
;
else
return
2
;
constexpr
index_t
kVecLoad
=
((
total_pixels
/
kMaxVecLoad
)
>=
kMinVecLoad
)
?
kMaxVecLoad
:
(
total_pixels
/
kMinVecLoad
)
;
return
kVecLoad
;
}
else
{
...
...
@@ -382,10 +409,8 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
using
BlockGemm
=
remove_cvref_t
<
decltype
(
QXPolicy
::
template
GetQKBlockGemm
<
Problem
>())
>
;
constexpr
auto
config
=
BlockGemm
::
Policy
::
template
GetWarpGemmMWarpNWarp
<
Problem
>();
using
WG
=
remove_cvref_t
<
decltype
(
config
.
template
at
<
0
>())
>
;
using
CWarpDstr
=
typename
WG
::
CWarpDstr
;
constexpr
auto
vec
=
CWarpDstr
{}.
get_ys_to_d_descriptor
().
get_lengths
().
at
(
number
<
CWarpDstr
::
NDimY
-
1
>
{});
return
vec
;
return
WG
::
WarpGemmAttribute
::
Impl
::
kCM1PerLane
;
}
template
<
typename
Problem
>
...
...
@@ -394,10 +419,8 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
using
BlockGemm
=
remove_cvref_t
<
decltype
(
GetKVBlockGemm
<
Problem
>
())
>
;
constexpr
auto
config
=
BlockGemm
::
Policy
::
template
GetWarpGemmMWarpNWarp
<
Problem
>();
using
WG
=
remove_cvref_t
<
decltype
(
config
.
template
at
<
0
>())
>
;
using
CWarpDstr
=
typename
WG
::
CWarpDstr
;
constexpr
auto
vec
=
CWarpDstr
{}.
get_ys_to_d_descriptor
().
get_lengths
().
at
(
number
<
CWarpDstr
::
NDimY
-
1
>
{});
return
vec
;
return
WG
::
WarpGemmAttribute
::
Impl
::
kCM1PerLane
;
}
template
<
typename
Problem
>
...
...
@@ -448,44 +471,12 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
return
max
(
SingleKSize
,
SingleVSize
);
}
// Builds the static tile distribution for the Q block tile.
//
// The block is kM0 x kK0BlockLength (taken from Problem::BlockFmhaShape). The
// warp layout (warp GEMM type, MWarp, NWarp) is queried from BlockGemm's
// policy, and the warp GEMM's A-operand encoding is embedded into the outer
// (per-block) encoding.
template <typename Problem, typename BlockGemm>
CK_TILE_HOST_DEVICE static constexpr auto MakeQRegBlockDescriptor()
{
    constexpr auto warp_setup =
        BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();

    using WarpGemm = remove_cvref_t<decltype(warp_setup.template at<0>())>;

    constexpr index_t m_warps = warp_setup.template at<1>();
    constexpr index_t n_warps = warp_setup.template at<2>();

    // Number of warp-GEMM tiles each warp covers along M and along K.
    constexpr index_t m_repeats =
        Problem::BlockFmhaShape::kM0 / (m_warps * WarpGemm::kM);
    constexpr index_t k_repeats =
        Problem::BlockFmhaShape::kK0BlockLength / WarpGemm::kK;

    using OuterEncoding =
        tile_distribution_encoding<sequence<n_warps>,
                                   tuple<sequence<m_repeats, m_warps>, sequence<k_repeats>>,
                                   tuple<sequence<1, 0>>,
                                   tuple<sequence<1, 0>>,
                                   sequence<1, 2>,
                                   sequence<0, 0>>;

    constexpr auto full_encoding = detail::make_embed_tile_distribution_encoding(
        OuterEncoding{}, typename WarpGemm::AWarpDstrEncoding{});

    constexpr auto distribution = make_static_tile_distribution(full_encoding);

    return distribution;
}
// TODO: this is used for non async copy desc. unify in the future
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeKLdsBlockDescriptor
()
{
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kK
1
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kK
0
;
constexpr
index_t
kKPack
=
GetSmemKPackK
<
Problem
>
();
constexpr
auto
k_lds_block_desc_0
=
make_naive_tensor_descriptor
(
...
...
@@ -885,36 +876,10 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
}
}
template
<
typename
Problem
,
typename
BlockGemm
>
template
<
typename
BlockGemm
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeBiasDramTileDistribution
()
{
constexpr
index_t
MPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
NPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
auto
config
=
BlockGemm
::
Policy
::
template
GetWarpGemmMWarpNWarp
<
Problem
>();
using
WG
=
remove_cvref_t
<
decltype
(
config
.
template
at
<
0
>())
>
;
constexpr
index_t
MWarp
=
config
.
template
at
<
1
>();
constexpr
index_t
NWarp
=
config
.
template
at
<
2
>();
constexpr
index_t
MIterPerWarp
=
MPerBlock
/
(
MWarp
*
WG
::
kM
);
constexpr
index_t
NIterPerWarp
=
NPerBlock
/
(
NWarp
*
WG
::
kN
);
// Construct C-Block-HostTensor
constexpr
auto
c_block_outer_dstr_encoding
=
tile_distribution_encoding
<
sequence
<>
,
tuple
<
sequence
<
MIterPerWarp
,
MWarp
>
,
sequence
<
NIterPerWarp
,
NWarp
>>
,
tuple
<
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
,
1
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
0
>>
{};
constexpr
auto
c_block_dstr_encode
=
detail
::
make_embed_tile_distribution_encoding
(
c_block_outer_dstr_encoding
,
typename
WG
::
CWarpDstrEncoding
{});
constexpr
auto
c_block_dstr
=
make_static_tile_distribution
(
c_block_dstr_encode
);
return
c_block_dstr
;
return
BlockGemm
::
MakeCBlockTile
().
get_tile_distribution
();
}
template
<
typename
Problem
>
...
...
@@ -968,20 +933,15 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetKVBlockGemm
()
{
using
GemmProblem
=
GemmPipelineProblem
<
typename
Problem
::
PDataType
,
typename
Problem
::
VDataType
,
typename
Problem
::
OaccDataType
,
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kM0
,
Problem
::
BlockFmhaShape
::
kN1
,
Problem
::
BlockFmhaShape
::
kK1
>
,
typename
Problem
::
BlockFmhaShape
::
Gemm1BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm1WarpTile
>
,
TileGemmTraits
<
Problem
::
kPadSeqLenQ
,
Problem
::
kPadSeqLenK
,
Problem
::
kPadHeadDimQ
,
typename
tensor_layout
::
gemm
::
RowMajor
,
typename
tensor_layout
::
gemm
::
ColumnMajor
,
typename
tensor_layout
::
gemm
::
RowMajor
>>
;
BlockGemmProblem
<
typename
Problem
::
PDataType
,
typename
Problem
::
VDataType
,
typename
Problem
::
OaccDataType
,
Problem
::
kNumGemm1Warps
*
get_warp_size
(),
TileGemmShape
<
sequence
<
Problem
::
BlockFmhaShape
::
kM0
,
Problem
::
BlockFmhaShape
::
kN1
,
Problem
::
BlockFmhaShape
::
kK1
>
,
typename
Problem
::
BlockFmhaShape
::
Gemm1BlockWarps
,
typename
Problem
::
BlockFmhaShape
::
Gemm1WarpTile
>>
;
auto
warp_gemm
=
[
&
]()
{
if
constexpr
(
std
::
is_same_v
<
typename
Problem
::
KDataType
,
fp8_t
>
&&
...
...
include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp
View file @
e941f59f
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -7,6 +7,20 @@
namespace
ck_tile
{
// Maps a length to the tile length used for it:
//   96  -> 128
//   160 -> 256
//   any power of two -> unchanged
// Every other length is unsupported and yields 0.
static CK_TILE_HOST_DEVICE constexpr index_t ceil_to_qualified_tile_length(index_t len)
{
    switch(len)
    {
    case 96: return 128;
    case 160: return 256;
    default: break;
    }

    // only length of 96, 160 and power-of-two is supported
    const bool single_bit_or_zero = (len & (len - 1)) == 0;
    return single_bit_or_zero ? len : 0;
}
template
<
typename
BlockTile_
,
// sequence<...
typename
Gemm0BlockWarps_
,
typename
Gemm0WarpTile_
,
...
...
@@ -21,20 +35,27 @@ struct TileFmhaShape
using
Gemm1BlockWarps
=
remove_cvref_t
<
Gemm1BlockWarps_
>
;
using
Gemm1WarpTile
=
remove_cvref_t
<
Gemm1WarpTile_
>
;
static
constexpr
index_t
NumWarps
=
static
constexpr
index_t
Num
Gemm0
Warps
=
reduce_on_sequence
(
Gemm0BlockWarps
{},
multiplies
{},
number
<
1
>
{});
static
constexpr
index_t
NumGemm1Warps
=
reduce_on_sequence
(
Gemm1BlockWarps
{},
multiplies
{},
number
<
1
>
{});
static_assert
(
NumGemm1Warps
%
NumGemm0Warps
==
0
);
static
constexpr
index_t
NumWarps
=
max
(
NumGemm0Warps
,
NumGemm1Warps
);
static_assert
(
NumWarps
==
reduce_on_sequence
(
Gemm1BlockWarps
{},
multiplies
{},
number
<
1
>
{})
);
static_assert
(
std
::
is_same_v
<
Gemm0WarpTile
,
Gemm1WarpTile
>
);
static
constexpr
index_t
kM0
=
BlockTile
::
at
(
number
<
0
>
{});
// tile size along q seqlen
static
constexpr
index_t
kN0
=
BlockTile
::
at
(
number
<
1
>
{});
// tile size along k seqlen
static
constexpr
index_t
kK0
=
BlockTile
::
at
(
number
<
2
>
{});
// tile size along qk gemm unroll
static
constexpr
index_t
kN1
=
BlockTile
::
at
(
number
<
3
>
{});
// tile size along v head_dim
static
constexpr
index_t
kK1
=
BlockTile
::
at
(
number
<
4
>
{});
// tile size along kv gemm unroll
static
constexpr
index_t
k
K0BlockLength
=
static
constexpr
index_t
k
QKHeaddim
=
BlockTile
::
at
(
number
<
5
>
{});
// total length of K0, used for pipelines that need to load Q at
// once (or repeatedly load Q as a whole tile)
static_assert
(
kK0BlockLength
%
kK0
==
0
,
"kK0BlockLength should be divisible by kK0"
);
static_assert
(
kQKHeaddim
%
kK0
==
0
,
"kQKHeaddim should be divisible by kK0"
);
static
constexpr
index_t
kSubQKHeaddim
=
ceil_to_qualified_tile_length
(
kQKHeaddim
);
// v, rowmajor : seqlen*hdim, colmajor : hdim*seqlen
static
constexpr
bool
IsVLayoutRowMajor
=
IsVLayoutRowMajor_
;
...
...
include/ck_tile/ops/gemm.hpp
View file @
e941f59f
...
...
@@ -8,6 +8,7 @@
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_one_warp_v1.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_default_policy.hpp"
...
...
@@ -23,11 +24,14 @@
#include "ck_tile/ops/gemm/block/block_gemm_problem.hpp"
#include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp"
#include "ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
...
...
@@ -35,4 +39,5 @@
#include "ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_impl.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
include/ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1.hpp
View file @
e941f59f
...
...
@@ -32,7 +32,7 @@ struct BlockGemmARegBGmemCRegV1
BlockGemmProblem
<
ADataType
,
BDataType
,
CDataType
,
kBlockSize
,
BlockGemmShape
>
,
BlockGemmARegBGmemCRegV1DefaultPolicy
>
;
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetStaticLdsSize
()
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetStaticLdsSize
()
{
return
sizeof
(
BDataType
)
*
Policy
::
template
MakeBSmemBlockDescriptor
<
Problem
>().
get_element_space_size
();
...
...
include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp
View file @
e941f59f
...
...
@@ -157,7 +157,7 @@ struct BlockGemmARegBRegCRegV1
});
}
CK_TILE_DEVICE
constexpr
auto
MakeCBlockTile
()
const
CK_TILE_DEVICE
static
constexpr
auto
MakeCBlockTile
()
{
constexpr
index_t
MPerBlock
=
BlockGemmShape
::
kM
;
constexpr
index_t
NPerBlock
=
BlockGemmShape
::
kN
;
...
...
include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_one_warp_v1.hpp
0 → 100644
View file @
e941f59f
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_default_policy.hpp"
namespace
ck_tile
{
// Block GEMM for a block that consists of exactly one warp: kBlockSize must
// equal get_warp_size() and the policy must report MWarp == NWarp == 1 (both
// conditions are enforced with static_assert).
//   A is block distributed tensor
//   B is block window on shared memory
//   C is block distributed tensor
template <typename Problem_, typename Policy_ = BlockGemmARegBSmemCRegV1DefaultPolicy>
struct BlockGemmARegBSmemCRegOneWarpV1
{
    using Problem        = remove_cvref_t<Problem_>;
    using Policy         = remove_cvref_t<Policy_>;
    using ADataType      = remove_cvref_t<typename Problem::ADataType>;
    using BDataType      = remove_cvref_t<typename Problem::BDataType>;
    using CDataType      = remove_cvref_t<typename Problem::CDataType>;
    using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;

    static constexpr index_t kBlockSize = Problem::kBlockSize;

    static_assert(kBlockSize == get_warp_size(), "Check failed!");

    // C += A * B
    //
    // c_block_tensor     : accumulator tile, updated in place
    // a_block_tensor_tmp : A tile; its thread buffer is copied into a tensor
    //                      with the distribution this GEMM expects
    // b_block_window_tmp : window over B, from which per-warp windows are built
    template <typename CBlockTensor, typename ABlockTensorTmp, typename BBlockWindowTmp>
    CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
                                   const ABlockTensorTmp& a_block_tensor_tmp,
                                   const BBlockWindowTmp& b_block_window_tmp) const
    {
        static_assert(
            std::is_same_v<ADataType, remove_cv_t<typename ABlockTensorTmp::DataType>> &&
                std::is_same_v<BDataType, remove_cv_t<typename BBlockWindowTmp::DataType>> &&
                std::is_same_v<CDataType, remove_cv_t<typename CBlockTensor::DataType>>,
            "wrong!");

        // Tile extents are taken from the problem's block shape rather than
        // from the lengths of the operands themselves.
        constexpr index_t MPerBlock = BlockGemmShape::kM;
        constexpr index_t NPerBlock = BlockGemmShape::kN;
        constexpr index_t KPerBlock = BlockGemmShape::kK;

        constexpr auto warp_setup = Policy::template GetWarpGemmMWarpNWarp<Problem>();

        using WG = remove_cvref_t<decltype(warp_setup.template at<0>())>;

        constexpr index_t MWarp = warp_setup.template at<1>();
        constexpr index_t NWarp = warp_setup.template at<2>();

        static_assert(MWarp == 1 && NWarp == 1, "Check failed!");

        // Number of warp-GEMM invocations needed along each dimension.
        constexpr index_t MRepeat = MPerBlock / (MWarp * WG::kM);
        constexpr index_t NRepeat = NPerBlock / (NWarp * WG::kN);
        constexpr index_t KRepeat = KPerBlock / WG::kK;

        // Window step between consecutive iterations along N and K.
        constexpr index_t NStep = NPerBlock / NRepeat;
        constexpr index_t KStep = KPerBlock / KRepeat;

        // NWarp == 1, so the warp index along N is fixed at 0.
        const index_t warp_n_id = 0;

        using AOuterEncoding =
            tile_distribution_encoding<sequence<NWarp>,
                                       tuple<sequence<MRepeat, MWarp>, sequence<KRepeat>>,
                                       tuple<sequence<1, 0>>,
                                       tuple<sequence<1, 0>>,
                                       sequence<1, 2>,
                                       sequence<0, 0>>;

        using COuterEncoding =
            tile_distribution_encoding<sequence<>,
                                       tuple<sequence<MRepeat>, sequence<NRepeat>>,
                                       tuple<>,
                                       tuple<>,
                                       sequence<1, 2>,
                                       sequence<0, 0>>;

        constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
            AOuterEncoding{}, typename WG::AWarpDstrEncoding{});

        constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
            COuterEncoding{}, typename WG::CWarpDstrEncoding{});

        // check C-block-distribution
        static_assert(
            std::is_same_v<remove_cvref_t<decltype(c_block_dstr_encode)>,
                           remove_cvref_t<decltype(CBlockTensor::get_tile_distribution()
                                                       .get_static_tile_distribution_encoding())>>,
            "wrong!");

        constexpr auto a_block_dstr = make_static_tile_distribution(a_block_dstr_encode);

        // construct A-block-tensor from A-Block-tensor-tmp
        // FIXME: need method to check a_block_tensor and a_block_tensor_tmp have equivalent
        // distribution
        auto a_block_tensor =
            make_static_distributed_tensor<typename ABlockTensorTmp::DataType>(a_block_dstr);

        a_block_tensor.get_thread_buffer() = a_block_tensor_tmp.get_thread_buffer();

        // construct B-warp-window
        auto b_warp_window_tmp = make_tile_window(
            b_block_window_tmp.get_bottom_tensor_view(),
            make_tuple(number<WG::kN>{}, number<WG::kK>{}),
            b_block_window_tmp.get_window_origin() + multi_index<2>{warp_n_id * WG::kN, 0},
            make_static_tile_distribution(typename WG::BWarpDstrEncoding{}));

        // One B window per (n, k) iteration.
        // FIXME: using array will cause register spill, hence the statically
        // indexed array filled through static_for instead of runtime loops.
        statically_indexed_array<statically_indexed_array<decltype(b_warp_window_tmp), KRepeat>,
                                 NRepeat>
            b_warp_windows;

        static_for<0, NRepeat, 1>{}([&](auto n_idx) {
            static_for<0, KRepeat, 1>{}([&](auto k_idx) {
                b_warp_windows(n_idx)(k_idx) = b_warp_window_tmp;

                move_tile_window(b_warp_windows(n_idx)(k_idx), {n_idx * NStep, k_idx * KStep});
            });
        });

        using AWarpDstr = typename WG::AWarpDstr;
        using CWarpDstr = typename WG::CWarpDstr;

        using AWarpTensor = typename WG::AWarpTensor;
        using CWarpTensor = typename WG::CWarpTensor;

        constexpr auto a_warp_y_lengths =
            to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
        constexpr auto c_warp_y_lengths =
            to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());

        constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
        constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};

        // hot loop:
        static_for<0, KRepeat, 1>{}([&](auto k_idx) {
            static_for<0, MRepeat, 1>{}([&](auto m_idx) {
                // read A warp tensor from A block tensor
                AWarpTensor a_warp_tensor;

                a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
                    merge_sequences(sequence<m_idx, k_idx>{}, a_warp_y_index_zeros),
                    merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));

                static_for<0, NRepeat, 1>{}([&](auto n_idx) {
                    // read B warp tensor from B Block window
                    const auto b_warp_tensor = load_tile(b_warp_windows(n_idx)(k_idx));

                    // read C warp tensor from C block tensor
                    CWarpTensor c_warp_tensor;

                    c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
                        merge_sequences(sequence<m_idx, n_idx>{}, c_warp_y_index_zeros),
                        merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));

                    // warp GEMM
                    WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);

                    // write C warp tensor into C block tensor
                    c_block_tensor.set_y_sliced_thread_data(
                        merge_sequences(sequence<m_idx, n_idx>{}, c_warp_y_index_zeros),
                        merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
                        c_warp_tensor.get_thread_buffer());
                });
            });
        });
    }

    // Creates a C block tensor (element type CDataType) with the distribution
    // that the accumulating operator() requires of its C argument.
    CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
    {
        constexpr index_t MPerBlock = BlockGemmShape::kM;
        constexpr index_t NPerBlock = BlockGemmShape::kN;

        constexpr auto warp_setup = Policy::template GetWarpGemmMWarpNWarp<Problem>();

        using WG = remove_cvref_t<decltype(warp_setup.template at<0>())>;

        constexpr index_t MWarp = warp_setup.template at<1>();
        constexpr index_t NWarp = warp_setup.template at<2>();

        static_assert(MWarp == 1 && NWarp == 1, "Check failed!");

        constexpr index_t MRepeat = MPerBlock / (MWarp * WG::kM);
        constexpr index_t NRepeat = NPerBlock / (NWarp * WG::kN);

        using COuterEncoding =
            tile_distribution_encoding<sequence<>,
                                       tuple<sequence<MRepeat>, sequence<NRepeat>>,
                                       tuple<>,
                                       tuple<>,
                                       sequence<1, 2>,
                                       sequence<0, 0>>;

        constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
            COuterEncoding{}, typename WG::CWarpDstrEncoding{});

        // The outer encoding has no P dimension, so the embedded encoding must
        // end up with exactly one.
        static_assert(decltype(c_block_dstr_encode)::NDimP == 1, "Check failed!");

        constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);

        return make_static_distributed_tensor<CDataType>(c_block_dstr);
    }

    // C = A * B
    //
    // Creates a C tile with MakeCBlockTile(), runs the accumulating overload on
    // it and returns it.
    template <typename ABlockTensorTmp, typename BBlockWindowTmp>
    CK_TILE_DEVICE auto operator()(const ABlockTensorTmp& a_block_tensor_tmp,
                                   const BBlockWindowTmp& b_block_window_tmp) const
    {
        auto c_block_tensor = MakeCBlockTile();

        (*this)(c_block_tensor, a_block_tensor_tmp, b_block_window_tmp);

        return c_block_tensor;
    }
};
}
// namespace ck_tile
include/ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1.hpp
View file @
e941f59f
...
...
@@ -181,7 +181,7 @@ struct BlockGemmARegBSmemCRegV1
});
}
CK_TILE_DEVICE
constexpr
auto
MakeCBlockTile
()
const
CK_TILE_DEVICE
static
constexpr
auto
MakeCBlockTile
()
{
constexpr
index_t
MPerBlock
=
BlockGemmShape
::
kM
;
constexpr
index_t
NPerBlock
=
BlockGemmShape
::
kN
;
...
...
Prev
1
…
7
8
9
10
11
12
13
14
15
…
18
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment