Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
e71aa1d6
Unverified
Commit
e71aa1d6
authored
Nov 03, 2023
by
carlushuang
Committed by
GitHub
Nov 03, 2023
Browse files
unify q persistent in register (#24)
* unify q persistent in register * add refactor warp_gemm dispatcher
parent
02d69525
Changes
14
Hide whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
928 additions
and
31 deletions
+928
-31
example/91_tile_program/fmha_fwd.cpp
example/91_tile_program/fmha_fwd.cpp
+12
-4
example/91_tile_program/fmha_fwd_kernel.hpp
example/91_tile_program/fmha_fwd_kernel.hpp
+11
-5
include/ck/tile_program/block_tile/block_gemm_areg_bgmem_creg_v1.hpp
...tile_program/block_tile/block_gemm_areg_bgmem_creg_v1.hpp
+6
-8
include/ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_v1.hpp
...tile_program/block_tile/block_gemm_areg_bsmem_creg_v1.hpp
+3
-1
include/ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp
...lock_tile/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp
+49
-0
include/ck/tile_program/block_tile/block_gemm_asmem_bsmem_creg_v1.hpp
...ile_program/block_tile/block_gemm_asmem_bsmem_creg_v1.hpp
+3
-1
include/ck/tile_program/block_tile/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp
...ock_tile/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp
+59
-0
include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qkvs.hpp
..._program/block_tile_pipeline/block_fmha_pipeline_qkvs.hpp
+2
-1
include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs.hpp
...gram/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs.hpp
+347
-0
include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp
..._pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp
+280
-0
include/ck/tile_program/tile/tile_fmha_shape.hpp
include/ck/tile_program/tile/tile_fmha_shape.hpp
+19
-11
include/ck/tile_program/warp_tile/warp_gemm.hpp
include/ck/tile_program/warp_tile/warp_gemm.hpp
+6
-0
include/ck/tile_program/warp_tile/warp_gemm_attribute_mfma.hpp
...de/ck/tile_program/warp_tile/warp_gemm_attribute_mfma.hpp
+84
-0
include/ck/tile_program/warp_tile/warp_gemm_dispatcher.hpp
include/ck/tile_program/warp_tile/warp_gemm_dispatcher.hpp
+47
-0
No files found.
example/91_tile_program/fmha_fwd.cpp
View file @
e71aa1d6
#include <cstring>
#include <ostream>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
...
...
@@ -15,6 +16,8 @@
#include "ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qkvs.hpp"
#include "ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qkvs_default_policy.hpp"
#include "ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs.hpp"
#include "ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp"
#include "ck/tile_program/block_tile_pipeline/block_fmha_pipeline_problem.hpp"
#include "ck/tile_program/tile/tile_fmha_shape.hpp"
...
...
@@ -33,10 +36,14 @@ using PDataType = ck::half_t; // data type for A matrix of second gemm
using
OaccDataType
=
float
;
// data type for second gemm accumulation
using
ODataType
=
ck
::
half_t
;
// M0 N0 K0 N1 K1
// M0 N0 K0 N1 K1
K0L
// using FmhaShape = ck::tile_program::TileFmhaShape<128, 64, 64, 128, 64>;
// using FmhaShape = ck::tile_program::TileFmhaShape<128, 256, 32, 128, 32>;
using
FmhaShape
=
ck
::
tile_program
::
TileFmhaShape
<
128
,
128
,
32
,
128
,
32
>
;
using
FmhaBlockTile
=
ck
::
Sequence
<
128
,
128
,
32
,
128
,
32
,
128
>
;
using
FmhaBlockWarps
=
ck
::
Sequence
<
4
,
1
,
1
>
;
using
FmhaWarpTile
=
ck
::
Sequence
<
32
,
32
,
16
>
;
using
FmhaShape
=
ck
::
tile_program
::
TileFmhaShape
<
FmhaBlockTile
,
FmhaBlockWarps
,
FmhaWarpTile
,
FmhaBlockWarps
,
FmhaWarpTile
>
;
using
FmhaTilePartitioner
=
FmhaFwdTilePartitioner
<
FmhaShape
>
;
using
FmhaPipelineProblem
=
ck
::
tile_program
::
block
::
BlockFmhaPipelineProblem
<
QDataType
,
...
...
@@ -49,7 +56,8 @@ using FmhaPipelineProblem = ck::tile_program::block::BlockFmhaPipelineProblem<QD
ODataType
,
256
,
// BlockSize
FmhaShape
>
;
using
FmhaPipeline
=
ck
::
tile_program
::
block
::
BlockFmhaPipelineQKVS
<
FmhaPipelineProblem
>
;
// using FmhaPipeline = ck::tile_program::block::BlockFmhaPipelineQKVS<FmhaPipelineProblem>;
using
FmhaPipeline
=
ck
::
tile_program
::
block
::
BlockFmhaPipelineQRKSVS
<
FmhaPipelineProblem
>
;
using
FmhaEpilogue
=
FmhaFwdEpilogue
<
FmhaFwdEpilogueProblem
<
OaccDataType
,
ODataType
>>
;
using
FmhaKernel
=
FmhaFwdKernel
<
FmhaTilePartitioner
,
FmhaPipeline
,
FmhaEpilogue
>
;
...
...
@@ -134,7 +142,7 @@ int main(int argc, char* argv[])
<<
", seqlen_k:"
<<
seqlen_k
<<
", hdim_q:"
<<
hdim_q
<<
", hdim_v:"
<<
hdim_v
<<
", scale:"
<<
scale
<<
", i_perm:"
<<
i_perm
<<
", o_perm:"
<<
o_perm
<<
", grid_size "
<<
kGridSize
.
x
<<
"x"
<<
kGridSize
.
y
<<
"x"
<<
kGridSize
.
z
<<
std
::
endl
;
<<
std
::
flush
<<
std
::
endl
;
constexpr
ck
::
index_t
kWarpPerCu
=
8
;
// 2 warps per SIMD
constexpr
ck
::
index_t
kWarpPerBlock
=
kBlockSize
.
x
/
warpSize
;
...
...
example/91_tile_program/fmha_fwd_kernel.hpp
View file @
e71aa1d6
...
...
@@ -11,7 +11,7 @@
// P[seqlen_q, seqlen_k] = Softmax(S[seqlen_q, seqlen_k])
// O[seqlen_q, hdim_v] = P[seqlen_q, seqlen_k] * V[hdim_v, seqlen_k]
#define C_LOG2E
1.44269504088896340736
// log2(e)
#define C_LOG2E 1.44269504088896340736 // log2(e)
template
<
typename
TilePartitioner_
,
typename
FmhaPipeline_
,
typename
EpiloguePipeline_
>
struct
FmhaFwdKernel
...
...
@@ -148,10 +148,16 @@ struct FmhaFwdKernel
Number
<
32
>
{},
Number
<
1
>
{});
auto
q_dram_window
=
make_tile_window
(
q_dram
,
make_tuple
(
Number
<
FmhaPipeline
::
kM0
>
{},
Number
<
FmhaPipeline
::
kK0
>
{}),
{
i_m0
,
0
});
auto
q_dram_window
=
make_tile_window
(
q_dram
,
[
&
]()
{
if
constexpr
(
FmhaPipeline
::
kQLoadOnce
)
return
make_tuple
(
Number
<
FmhaPipeline
::
kM0
>
{},
Number
<
FmhaPipeline
::
kK0BlockLength
>
{});
else
return
make_tuple
(
Number
<
FmhaPipeline
::
kM0
>
{},
Number
<
FmhaPipeline
::
kK0
>
{});
}(),
{
i_m0
,
0
});
auto
k_dram_window
=
make_tile_window
(
k_dram
,
make_tuple
(
Number
<
FmhaPipeline
::
kN0
>
{},
Number
<
FmhaPipeline
::
kK0
>
{}),
{
0
,
0
});
...
...
include/ck/tile_program/block_tile/block_gemm_areg_bgmem_creg_v1.hpp
View file @
e71aa1d6
...
...
@@ -26,9 +26,11 @@ namespace block {
// This will:
// 1. Load B from global memory into shared memory and then
// 2. Call BlockGemmARegSGmemCRegV1
template
<
typename
Problem
,
typename
Policy
=
BlockGemmARegBGmemCRegV1DefaultPolicy
>
template
<
typename
Problem
_
,
typename
Policy
_
=
BlockGemmARegBGmemCRegV1DefaultPolicy
>
struct
BlockGemmARegBGmemCRegV1
{
using
Problem
=
remove_cvref_t
<
Problem_
>
;
using
Policy
=
remove_cvref_t
<
Policy_
>
;
using
ADataType
=
remove_cvref_t
<
typename
Problem
::
ADataType
>
;
using
BDataType
=
remove_cvref_t
<
typename
Problem
::
BDataType
>
;
using
CDataType
=
remove_cvref_t
<
typename
Problem
::
CDataType
>
;
...
...
@@ -37,13 +39,9 @@ struct BlockGemmARegBGmemCRegV1
static
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
// use BlockGemmARegBSmemCRegV1 as the underlying block-GEMM implementation
using
BlockGemmARegBSmemCRegImpl
=
BlockGemmARegBSmemCRegV1
<
BlockGemmARegBSmemCRegProblem
<
ADataType
,
BDataType
,
CDataType
,
kBlockSize
,
BlockGemmShape
>
,
BlockGemmARegBSmemCRegV1DefaultPolicy
>
;
using
BlockGemmARegBSmemCRegImpl
=
BlockGemmARegBSmemCRegV1
<
BlockGemmARegBSmemCRegProblem
<
ADataType
,
BDataType
,
CDataType
,
kBlockSize
,
BlockGemmShape
>
,
BlockGemmARegBSmemCRegV1DefaultPolicy
>
;
__host__
__device__
static
constexpr
ck
::
index_t
GetStaticLdsSize
()
{
...
...
include/ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_v1.hpp
View file @
e71aa1d6
...
...
@@ -23,9 +23,11 @@ namespace block {
// A is block distributed tensor
// B is block window on shared memory
// C is block distributed tensor
template
<
typename
Problem
,
typename
Policy
=
BlockGemmARegBSmemCRegV1DefaultPolicy
>
template
<
typename
Problem
_
,
typename
Policy
_
=
BlockGemmARegBSmemCRegV1DefaultPolicy
>
struct
BlockGemmARegBSmemCRegV1
{
using
Problem
=
remove_cvref_t
<
Problem_
>
;
using
Policy
=
remove_cvref_t
<
Policy_
>
;
using
ADataType
=
remove_cvref_t
<
typename
Problem
::
ADataType
>
;
using
BDataType
=
remove_cvref_t
<
typename
Problem
::
BDataType
>
;
using
CDataType
=
remove_cvref_t
<
typename
Problem
::
CDataType
>
;
...
...
include/ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp
0 → 100644
View file @
e71aa1d6
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"
#include "ck/tile_program/tile/tile_distribution.hpp"
#include "ck/tile_program/tile/tile_elementwise.hpp"
#include "ck/tile_program/tile/tile_gemm_shape.hpp"
#include "ck/tile_program/warp_tile/warp_gemm.hpp"
namespace
ck
{
namespace
tile_program
{
namespace
block
{
template
<
typename
AType_
,
typename
BType_
,
typename
CType_
,
typename
BlockWarps_
,
typename
WarpGemm_
>
struct
BlockGemmARegBSmemCRegV1CustomPolicy
{
using
AType
=
remove_cvref_t
<
AType_
>
;
using
BType
=
remove_cvref_t
<
BType_
>
;
using
CType
=
remove_cvref_t
<
CType_
>
;
using
BlockWarps
=
remove_cvref_t
<
BlockWarps_
>
;
static
constexpr
index_t
kMWarps
=
BlockWarps
::
At
(
Number
<
0
>
{});
static
constexpr
index_t
kNWarps
=
BlockWarps
::
At
(
Number
<
1
>
{});
static
constexpr
index_t
kKWarps
=
BlockWarps
::
At
(
Number
<
2
>
{});
using
WarpGemm
=
remove_cvref_t
<
WarpGemm_
>
;
template
<
typename
Problem
>
__host__
__device__
static
constexpr
auto
GetWarpGemmMWarpNWarp
()
{
using
namespace
ck
::
tile_program
::
warp
;
return
make_tuple
(
WarpGemm
{},
kMWarps
,
kNWarps
);
}
};
}
// namespace block
}
// namespace tile_program
}
// namespace ck
include/ck/tile_program/block_tile/block_gemm_asmem_bsmem_creg_v1.hpp
View file @
e71aa1d6
...
...
@@ -24,9 +24,11 @@ namespace block {
// A is block window on shared memory
// B is block window on shared memory
// C is block distributed tensor
template
<
typename
Problem
,
typename
Policy
=
BlockGemmASmemBSmemCRegV1DefaultPolicy
>
template
<
typename
Problem
_
,
typename
Policy
_
=
BlockGemmASmemBSmemCRegV1DefaultPolicy
>
struct
BlockGemmASmemBSmemCRegV1
{
using
Problem
=
remove_cvref_t
<
Problem_
>
;
using
Policy
=
remove_cvref_t
<
Policy_
>
;
using
ADataType
=
remove_cvref_t
<
typename
Problem
::
ADataType
>
;
using
BDataType
=
remove_cvref_t
<
typename
Problem
::
BDataType
>
;
using
CDataType
=
remove_cvref_t
<
typename
Problem
::
CDataType
>
;
...
...
include/ck/tile_program/block_tile/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp
0 → 100644
View file @
e71aa1d6
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"
#include "ck/tile_program/tile/tile_distribution.hpp"
#include "ck/tile_program/tile/tile_elementwise.hpp"
#include "ck/tile_program/tile/tile_gemm_shape.hpp"
#include "ck/tile_program/warp_tile/warp_gemm_dispatcher.hpp"
namespace
ck
{
namespace
tile_program
{
namespace
block
{
// Default policy for BlockGemmASmemBSmemCRegV1
// Default policy class should not be templated, put template on member functions instead
template
<
typename
AType_
,
typename
BType_
,
typename
CType_
,
typename
BlockWarps_
,
typename
WarpTile_
,
bool
TranposeC_
>
struct
BlockGemmASmemBSmemCRegV1CustomPolicy
{
using
AType
=
remove_cvref_t
<
AType_
>
;
using
BType
=
remove_cvref_t
<
BType_
>
;
using
CType
=
remove_cvref_t
<
CType_
>
;
using
BlockWarps
=
remove_cvref_t
<
BlockWarps_
>
;
using
WarpTile
=
remove_cvref_t
<
WarpTile_
>
;
static
constexpr
index_t
BlockMWarps
=
BlockWarps
::
At
(
Number
<
0
>
{});
static
constexpr
index_t
BlockNWarps
=
BlockWarps
::
At
(
Number
<
1
>
{});
static
constexpr
index_t
BlockKWarps
=
BlockWarps
::
At
(
Number
<
2
>
{});
static
constexpr
index_t
MPerWarp
=
WarpTile
::
At
(
Number
<
0
>
{});
static
constexpr
index_t
NPerWarp
=
WarpTile
::
At
(
Number
<
1
>
{});
static
constexpr
index_t
KPerWarp
=
WarpTile
::
At
(
Number
<
2
>
{});
static
constexpr
bool
TranposeC
=
TranposeC_
;
using
WarpGemm
=
ck
::
tile_program
::
warp
::
WarpGemmMfmaDispatcher
<
AType
,
BType
,
CType
,
MPerWarp
,
NPerWarp
,
KPerWarp
,
TranposeC
>
;
template
<
typename
Problem
>
__host__
__device__
static
constexpr
auto
GetWarpGemmMWarpNWarp
()
{
using
namespace
ck
::
tile_program
::
warp
;
return
make_tuple
(
WarpGemm
{},
BlockMWarps
,
BlockNWarps
);
}
};
}
// namespace block
}
// namespace tile_program
}
// namespace ck
include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qkvs.hpp
View file @
e71aa1d6
...
...
@@ -35,7 +35,8 @@ struct BlockFmhaPipelineQKVS
using
OaccDataType
=
remove_cvref_t
<
typename
Problem
::
OaccDataType
>
;
using
ODataType
=
remove_cvref_t
<
typename
Problem
::
ODataType
>
;
using
BlockFmhaShape
=
remove_cvref_t
<
typename
Problem
::
BlockFmhaShape
>
;
using
BlockFmhaShape
=
remove_cvref_t
<
typename
Problem
::
BlockFmhaShape
>
;
static
constexpr
bool
kQLoadOnce
=
false
;
// if q load whole block length (hdim) at once
static
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
...
...
include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs.hpp
0 → 100644
View file @
e71aa1d6
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"
#include "ck/tile_program/tile/tile_distribution.hpp"
#include "ck/tile_program/tile/load_tile.hpp"
#include "ck/tile_program/tile/store_tile.hpp"
#include "ck/tile_program/tile/tile_elementwise.hpp"
#include "ck/tile_program/tile/tile_gemm_shape.hpp"
#include "ck/tile_program/tile/slice_tile.hpp"
#include "ck/tile_program/warp_tile/warp_gemm.hpp"
#include "ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp"
#include "ck/tile_program/block_tile/block_reduce.hpp"
namespace
ck
{
namespace
tile_program
{
namespace
block
{
// This pipeline is qkv all located in LDS
template
<
typename
Problem
,
typename
Policy
=
BlockFmhaPipelineQRKSVSDefaultPolicy
>
struct
BlockFmhaPipelineQRKSVS
{
using
QDataType
=
remove_cvref_t
<
typename
Problem
::
QDataType
>
;
using
KDataType
=
remove_cvref_t
<
typename
Problem
::
KDataType
>
;
using
VDataType
=
remove_cvref_t
<
typename
Problem
::
VDataType
>
;
using
SaccDataType
=
remove_cvref_t
<
typename
Problem
::
SaccDataType
>
;
using
SMPLComputeDataType
=
remove_cvref_t
<
typename
Problem
::
SMPLComputeDataType
>
;
using
PDataType
=
remove_cvref_t
<
typename
Problem
::
PDataType
>
;
using
OaccDataType
=
remove_cvref_t
<
typename
Problem
::
OaccDataType
>
;
using
ODataType
=
remove_cvref_t
<
typename
Problem
::
ODataType
>
;
using
BlockFmhaShape
=
remove_cvref_t
<
typename
Problem
::
BlockFmhaShape
>
;
static
constexpr
bool
kQLoadOnce
=
true
;
// if q load whole block length (hdim) at once
static
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
static
constexpr
index_t
kM0
=
BlockFmhaShape
::
kM0
;
static
constexpr
index_t
kN0
=
BlockFmhaShape
::
kN0
;
static
constexpr
index_t
kK0
=
BlockFmhaShape
::
kK0
;
static
constexpr
index_t
kN1
=
BlockFmhaShape
::
kN1
;
static
constexpr
index_t
kK1
=
BlockFmhaShape
::
kK1
;
static
constexpr
index_t
kK0BlockLength
=
BlockFmhaShape
::
kK0BlockLength
;
__host__
__device__
static
constexpr
ck
::
index_t
GetSmemSize
()
{
return
Policy
::
template
GetSmemSize
<
Problem
>();
}
template
<
typename
QDramBlockWindowTmp
,
typename
KDramBlockWindowTmp
,
typename
VDramBlockWindowTmp
,
typename
QElementFunction
,
typename
KElementFunction
,
typename
VElementFunction
>
__host__
__device__
auto
operator
()(
const
QDramBlockWindowTmp
&
q_dram_block_window_tmp
,
// M0*K0 tile
const
QElementFunction
&
q_element_func
,
const
KDramBlockWindowTmp
&
k_dram_block_window_tmp
,
// N0*K0 tile
const
KElementFunction
&
k_element_func
,
const
VDramBlockWindowTmp
&
v_dram_block_window_tmp
,
// N1*K1 tile
const
VElementFunction
&
v_element_func
,
float
scale
,
index_t
num_total_loop
,
index_t
/*num_sub_loop_qk*/
,
// in this pipeline, the 1st gemm loop must be static
void
*
smem_ptr
)
const
{
static_assert
(
is_same_v
<
QDataType
,
remove_cvref_t
<
typename
QDramBlockWindowTmp
::
DataType
>>
&&
is_same_v
<
KDataType
,
remove_cvref_t
<
typename
KDramBlockWindowTmp
::
DataType
>>
&&
is_same_v
<
VDataType
,
remove_cvref_t
<
typename
VDramBlockWindowTmp
::
DataType
>>
,
"wrong!"
);
static_assert
(
kM0
==
QDramBlockWindowTmp
{}.
GetWindowLengths
()[
Number
<
0
>
{}]
&&
kN0
==
KDramBlockWindowTmp
{}.
GetWindowLengths
()[
Number
<
0
>
{}]
&&
kK0
==
KDramBlockWindowTmp
{}.
GetWindowLengths
()[
Number
<
1
>
{}]
&&
kN1
==
VDramBlockWindowTmp
{}.
GetWindowLengths
()[
Number
<
0
>
{}]
&&
kK1
==
VDramBlockWindowTmp
{}.
GetWindowLengths
()[
Number
<
1
>
{}],
"wrong!"
);
// K tile in LDS
KDataType
*
k_lds_ptr
=
static_cast
<
KDataType
*>
(
static_cast
<
void
*>
(
static_cast
<
char
*>
(
smem_ptr
)
+
Policy
::
template
GetSmemSizeQ
<
Problem
>()));
auto
k_lds
=
make_tensor_view
<
AddressSpaceEnum
::
Lds
>
(
k_lds_ptr
,
Policy
::
template
MakeKLdsBlockDescriptor
<
Problem
>());
auto
k_lds_window
=
make_tile_window
(
k_lds
,
make_tuple
(
Number
<
kN0
>
{},
Number
<
kK0
>
{}),
{
0
,
0
});
// V tile in LDS
auto
v_lds
=
make_tensor_view
<
AddressSpaceEnum
::
Lds
>
(
reinterpret_cast
<
VDataType
*>
(
smem_ptr
),
Policy
::
template
MakeVLdsBlockDescriptor
<
Problem
>());
auto
v_lds_window
=
make_tile_window
(
v_lds
,
make_tuple
(
Number
<
kN1
>
{},
Number
<
kK1
>
{}),
{
0
,
0
});
// Block GEMM
constexpr
auto
gemm_0
=
Policy
::
template
GetQKBlockGemm
<
Problem
>();
constexpr
auto
gemm_1
=
Policy
::
template
GetKVBlockGemm
<
Problem
>();
auto
q_dram_window
=
make_tile_window
(
q_dram_block_window_tmp
.
GetBottomTensorView
(),
q_dram_block_window_tmp
.
GetWindowLengths
(),
q_dram_block_window_tmp
.
GetWindowOrigin
(),
Policy
::
template
MakeQDramTileDistribution
<
Problem
,
decltype
(
gemm_0
)>());
auto
q
=
load_tile
(
q_dram_window
);
// persistent q register tile
auto
s_acc
=
decltype
(
gemm_0
(
get_slice_tile
(
tile_elementwise_in
(
q_element_func
,
q
),
Sequence
<
0
,
0
>
{},
Sequence
<
kM0
,
kK0
>
{}),
k_lds_window
)){};
// reduction function for softmax
const
auto
f_max
=
[](
auto
e0
,
auto
e1
)
{
return
max
(
e0
,
e1
);
};
const
auto
f_sum
=
[](
auto
e0
,
auto
e1
)
{
return
e0
+
e1
;
};
// infer Sacc, S, P, M, L, Oacc type
using
SBlockTileType
=
decltype
(
tile_elementwise_in
(
type_convert
<
SMPLComputeDataType
,
SaccDataType
>
,
s_acc
));
using
PBlockTileType
=
decltype
(
tile_elementwise_in
(
type_convert
<
PDataType
,
SaccDataType
>
,
s_acc
));
using
MLBlockTileType
=
decltype
(
block_tile_reduce
<
SMPLComputeDataType
>
(
SBlockTileType
{},
Sequence
<
1
>
{},
f_max
,
SMPLComputeDataType
{
0
}));
using
OaccBlockTileType
=
decltype
(
gemm_1
(
get_slice_tile
(
PBlockTileType
{},
Sequence
<
0
,
0
>
{},
Sequence
<
kM0
,
kK1
>
{}),
v_lds_window
));
// init Oacc, M, L
auto
o_acc
=
OaccBlockTileType
{};
auto
m
=
MLBlockTileType
{};
auto
l
=
MLBlockTileType
{};
tile_elementwise_inout
([](
auto
&
e
)
{
e
=
0
;
},
o_acc
);
tile_elementwise_inout
([](
auto
&
e
)
{
e
=
NumericLimits
<
SMPLComputeDataType
>::
Lowest
();
},
m
);
tile_elementwise_inout
([](
auto
&
e
)
{
e
=
0
;
},
l
);
auto
k_dram_block_window
=
k_dram_block_window_tmp
;
auto
v_dram_window
=
make_tile_window
(
v_dram_block_window_tmp
.
GetBottomTensorView
(),
v_dram_block_window_tmp
.
GetWindowLengths
(),
v_dram_block_window_tmp
.
GetWindowOrigin
(),
Policy
::
template
MakeVDramTileDistribution
<
Problem
>());
auto
q_tile
=
tile_elementwise_in
(
q_element_func
,
q
);
index_t
i_total_loops
=
0
;
do
{
// STAGE 1, QK gemm
auto
k_dram_window
=
make_tile_window
(
k_dram_block_window
.
GetBottomTensorView
(),
k_dram_block_window
.
GetWindowLengths
(),
k_dram_block_window
.
GetWindowOrigin
(),
Policy
::
template
MakeKDramTileDistribution
<
Problem
>());
// K DRAM tile window for
// load
auto
k_block_tile
=
load_tile
(
k_dram_window
);
{
move_tile_window
(
k_dram_window
,
{
0
,
kK0
});
tile_elementwise_inout
([](
auto
&
c
)
{
c
=
0
;
},
s_acc
);
// Initialize C
store_tile
(
k_lds_window
,
tile_elementwise_in
(
k_element_func
,
k_block_tile
));
// LDS write 0
k_block_tile
=
load_tile
(
k_dram_window
);
// global read 1
}
// index_t i_k0_loops = num_sub_loop_qk - 2;
constexpr
index_t
k0_loops
=
kK0BlockLength
/
kK0
;
if
constexpr
(
k0_loops
>
2
)
{
static_for
<
0
,
k0_loops
-
2
,
1
>
{}([
&
](
auto
i_k0
)
{
block_sync_lds
();
gemm_0
(
s_acc
,
get_slice_tile
(
q_tile
,
Sequence
<
0
,
i_k0
*
kK0
>
{},
Sequence
<
kM0
,
(
i_k0
+
1
)
*
kK0
>
{}),
k_lds_window
);
block_sync_lds
();
move_tile_window
(
k_dram_window
,
{
0
,
kK0
});
store_tile
(
k_lds_window
,
tile_elementwise_in
(
k_element_func
,
k_block_tile
));
// LDS write i + 1
k_block_tile
=
load_tile
(
k_dram_window
);
// global read i + 2
});
}
const
auto
v_prefetch
=
load_tile
(
v_dram_window
);
// prefetch load v tile
{
// tail
block_sync_lds
();
gemm_0
(
s_acc
,
get_slice_tile
(
q_tile
,
Sequence
<
0
,
(
k0_loops
-
2
)
*
kK0
>
{},
Sequence
<
kM0
,
(
k0_loops
-
1
)
*
kK0
>
{}),
k_lds_window
);
block_sync_lds
();
store_tile
(
k_lds_window
,
tile_elementwise_in
(
k_element_func
,
k_block_tile
));
block_sync_lds
();
gemm_0
(
s_acc
,
get_slice_tile
(
q_tile
,
Sequence
<
0
,
(
k0_loops
-
1
)
*
kK0
>
{},
Sequence
<
kM0
,
k0_loops
*
kK0
>
{}),
k_lds_window
);
}
// STAGE 2, scale softmax
tile_elementwise_inout
([
&
scale
](
auto
&
x
)
{
x
=
x
*
scale
;
},
s_acc
);
const
auto
s
=
tile_elementwise_in
(
type_convert
<
SMPLComputeDataType
,
SaccDataType
>
,
s_acc
);
// S{j}
auto
m_local
=
block_tile_reduce
<
SMPLComputeDataType
>
(
s
,
Sequence
<
1
>
{},
f_max
,
NumericLimits
<
SMPLComputeDataType
>::
Lowest
());
// m_local = rowmax(S{j})
block_tile_reduce_sync
(
m_local
,
f_max
);
const
auto
m_old
=
m
;
// m{j-1}
tile_elementwise_inout
(
[](
auto
&
e0
,
auto
e1
,
auto
e2
)
{
e0
=
max
(
e1
,
e2
);
},
m
,
m_old
,
m_local
);
// m{j}
auto
p_compute
=
make_static_distributed_tensor
<
SMPLComputeDataType
>
(
s
.
GetTileDistribution
());
// Pcompute{j}
constexpr
auto
p_spans
=
decltype
(
p_compute
)
::
GetDistributedSpans
();
sweep_tile_span
(
p_spans
[
Number
<
0
>
{}],
[
&
](
auto
idx0
)
{
constexpr
auto
i_idx
=
make_tuple
(
idx0
);
sweep_tile_span
(
p_spans
[
Number
<
1
>
{}],
[
&
](
auto
idx1
)
{
constexpr
auto
i_j_idx
=
make_tuple
(
idx0
,
idx1
);
p_compute
(
i_j_idx
)
=
math
::
exp
(
s
[
i_j_idx
]
-
m
[
i_idx
]);
});
});
auto
rowsum_p
=
block_tile_reduce
<
SMPLComputeDataType
>
(
p_compute
,
Sequence
<
1
>
{},
f_sum
,
SMPLComputeDataType
{
0
});
// rowsum(Pcompute{j})
block_tile_reduce_sync
(
rowsum_p
,
f_sum
);
// l{j}, Oacc{j}
constexpr
auto
o_spans
=
decltype
(
o_acc
)
::
GetDistributedSpans
();
sweep_tile_span
(
o_spans
[
Number
<
0
>
{}],
[
&
](
auto
idx0
)
{
constexpr
auto
i_idx
=
make_tuple
(
idx0
);
const
auto
tmp
=
math
::
exp
(
m_old
[
i_idx
]
-
m
[
i_idx
]);
l
(
i_idx
)
=
tmp
*
l
[
i_idx
]
+
rowsum_p
[
i_idx
];
sweep_tile_span
(
o_spans
[
Number
<
1
>
{}],
[
&
](
auto
idx1
)
{
constexpr
auto
i_j_idx
=
make_tuple
(
idx0
,
idx1
);
// FIXME: this use different equation from FA v2 paper,
// but produce correc result.
// Is the equation wrong?
o_acc
(
i_j_idx
)
*=
tmp
;
});
});
block_sync_lds
();
store_tile
(
v_lds_window
,
tile_elementwise_in
(
v_element_func
,
v_prefetch
));
// store the prefetch
move_tile_window
(
v_dram_window
,
{
0
,
kK1
});
const
auto
p
=
tile_elementwise_in
(
type_convert
<
PDataType
,
SMPLComputeDataType
>
,
p_compute
);
// STAGE 3, KV gemm
constexpr
index_t
k1_loops
=
kN0
/
kK1
;
if
constexpr
(
k1_loops
>
1
)
{
static_for
<
0
,
k1_loops
-
1
,
1
>
{}([
&
](
auto
i_k1
)
{
const
auto
v
=
load_tile
(
v_dram_window
);
// load next v
block_sync_lds
();
gemm_1
(
o_acc
,
get_slice_tile
(
p
,
Sequence
<
0
,
i_k1
*
kK1
>
{},
Sequence
<
kM0
,
(
i_k1
+
1
)
*
kK1
>
{}),
v_lds_window
);
block_sync_lds
();
store_tile
(
v_lds_window
,
tile_elementwise_in
(
v_element_func
,
v
));
// store next v
move_tile_window
(
v_dram_window
,
{
0
,
kK1
});
});
}
// tail
{
block_sync_lds
();
gemm_1
(
o_acc
,
get_slice_tile
(
p
,
Sequence
<
0
,
(
k1_loops
-
1
)
*
kK1
>
{},
Sequence
<
kM0
,
kN0
>
{}),
v_lds_window
);
block_sync_lds
();
}
// move K tile windows
move_tile_window
(
k_dram_block_window
,
{
kN0
,
0
});
i_total_loops
++
;
}
while
(
i_total_loops
<
num_total_loop
);
// finally, O
constexpr
auto
o_spans
=
decltype
(
o_acc
)
::
GetDistributedSpans
();
sweep_tile_span
(
o_spans
[
Number
<
0
>
{}],
[
&
](
auto
idx0
)
{
constexpr
auto
i_idx
=
make_tuple
(
idx0
);
const
auto
tmp
=
1
/
l
[
i_idx
];
sweep_tile_span
(
o_spans
[
Number
<
1
>
{}],
[
&
](
auto
idx1
)
{
constexpr
auto
i_j_idx
=
make_tuple
(
idx0
,
idx1
);
o_acc
(
i_j_idx
)
*=
tmp
;
});
});
return
o_acc
;
}
template
<
typename
QDramBlockWindowTmp
,
typename
KDramBlockWindowTmp
,
typename
VDramBlockWindowTmp
>
__host__
__device__
auto
operator
()(
const
QDramBlockWindowTmp
&
q_dram_block_window_tmp
,
// M0*K0 tile
const
KDramBlockWindowTmp
&
k_dram_block_window_tmp
,
// N0*K0 tile
const
VDramBlockWindowTmp
&
v_dram_block_window_tmp
,
// N1*K1 tile
float
scale
,
index_t
num_total_loop
,
index_t
num_sub_loop_qk
,
void
*
smem_ptr
)
const
{
return
operator
()(
q_dram_block_window_tmp
,
[](
const
QDataType
&
x
)
{
return
x
;
},
k_dram_block_window_tmp
,
[](
const
KDataType
&
x
)
{
return
x
;
},
v_dram_block_window_tmp
,
[](
const
VDataType
&
x
)
{
return
x
;
},
scale
,
num_total_loop
,
num_sub_loop_qk
,
smem_ptr
);
}
};
}
// namespace block
}
// namespace tile_program
}
// namespace ck
include/ck/tile_program/block_tile_pipeline/block_fmha_pipeline_qr_ks_vs_default_policy.hpp
0 → 100644
View file @
e71aa1d6
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"
#include "ck/tile_program/tile/tile_distribution.hpp"
#include "ck/tile_program/tile/tile_elementwise.hpp"
#include "ck/tile_program/tile/tile_gemm_shape.hpp"
#include "ck/tile_program/warp_tile/warp_gemm_dispatcher.hpp"
#include "ck/tile_program/block_tile_pipeline/block_gemm_pipeline_problem.hpp"
#include "ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_v1.hpp"
#include "ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp"
namespace
ck
{
namespace
tile_program
{
namespace
block
{
// This pipeline is qkv all located in LDS
struct
BlockFmhaPipelineQRKSVSDefaultPolicy
{
template
<
typename
Problem
,
typename
BlockGemm
>
__host__
__device__
static
constexpr
auto
MakeQRegBlockDescriptor
()
{
constexpr
index_t
kMPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kK0BlockLength
;
constexpr
auto
config
=
BlockGemm
::
Policy
::
template
GetWarpGemmMWarpNWarp
<
Problem
>();
using
WG
=
remove_cvref_t
<
decltype
(
config
.
template
At
<
0
>())
>
;
constexpr
index_t
MWarp
=
config
.
template
At
<
1
>();
constexpr
index_t
NWarp
=
config
.
template
At
<
2
>();
constexpr
index_t
MIterPerWarp
=
kMPerBlock
/
(
MWarp
*
WG
::
kM
);
constexpr
index_t
KIterPerWarp
=
kKPerBlock
/
WG
::
kK
;
constexpr
auto
q_block_outer_dstr_encoding
=
StaticTileDistributionEncoding
<
Sequence
<
NWarp
>
,
Tuple
<
Sequence
<
MIterPerWarp
,
MWarp
>
,
Sequence
<
KIterPerWarp
>>
,
Tuple
<
Sequence
<
1
,
0
>>
,
Tuple
<
Sequence
<
1
,
0
>>
,
Sequence
<
1
,
2
>
,
Sequence
<
0
,
0
>>
{};
constexpr
auto
q_block_dstr_encode
=
detail
::
make_embed_tile_distribution_encoding
(
q_block_outer_dstr_encoding
,
typename
WG
::
AWarpDstrEncoding
{});
constexpr
auto
q_block_dstr
=
make_static_tile_distribution
(
q_block_dstr_encode
);
return
q_block_dstr
;
}
// 3d + padding
template
<
typename
Problem
>
__host__
__device__
static
constexpr
auto
MakeKLdsBlockDescriptor
()
{
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kK1
;
constexpr
auto
k_lds_block_desc_0
=
make_naive_tensor_descriptor
(
make_tuple
(
Number
<
kKPerBlock
/
8
>
{},
Number
<
kNPerBlock
>
{},
Number
<
8
>
{}),
make_tuple
(
Number
<
(
kNPerBlock
+
1
)
*
8
>
{},
Number
<
8
>
{},
Number
<
1
>
{}),
Number
<
8
>
{},
Number
<
1
>
{});
constexpr
auto
k_lds_block_desc
=
transform_tensor_descriptor
(
k_lds_block_desc_0
,
make_tuple
(
make_pass_through_transform
(
kNPerBlock
),
make_merge_transform
(
make_tuple
(
kKPerBlock
/
8
,
8
))),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
,
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
return
k_lds_block_desc
;
}
// 3d + padding
template
<
typename
Problem
>
__host__
__device__
static
constexpr
auto
MakeVLdsBlockDescriptor
()
{
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN1
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kK1
;
constexpr
index_t
kPad
=
1
;
constexpr
index_t
kK1
=
8
;
constexpr
auto
v_lds_block_desc_0
=
make_naive_tensor_descriptor
(
make_tuple
(
Number
<
kKPerBlock
/
kK1
>
{},
Number
<
kNPerBlock
>
{},
Number
<
kK1
>
{}),
make_tuple
(
Number
<
(
kNPerBlock
+
kPad
)
*
kK1
>
{},
Number
<
kK1
>
{},
Number
<
1
>
{}),
Number
<
kK1
>
{},
Number
<
1
>
{});
constexpr
auto
v_lds_block_desc
=
transform_tensor_descriptor
(
v_lds_block_desc_0
,
make_tuple
(
make_pass_through_transform
(
kNPerBlock
),
make_merge_transform
(
make_tuple
(
Number
<
kKPerBlock
/
kK1
>
{},
Number
<
kK1
>
{}))),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
,
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
return
v_lds_block_desc
;
}
template
<
typename
Problem
>
__host__
__device__
static
constexpr
ck
::
index_t
GetSmemSizeQ
()
{
return
0
;
}
template
<
typename
Problem
>
__host__
__device__
static
constexpr
ck
::
index_t
GetSmemSize
()
{
constexpr
index_t
smem_size_gemm_0
=
GetSmemSizeQ
<
Problem
>
()
+
sizeof
(
typename
Problem
::
KDataType
)
*
MakeKLdsBlockDescriptor
<
Problem
>
().
GetElementSpaceSize
();
constexpr
index_t
smem_size_gemm_1
=
MakeVLdsBlockDescriptor
<
Problem
>
().
GetElementSpaceSize
()
*
sizeof
(
typename
Problem
::
VDataType
);
// TODO: consider shuffle requirement
return
math
::
max
(
smem_size_gemm_0
,
smem_size_gemm_1
);
}
template
<
typename
Problem
,
typename
BlockGemm
>
__host__
__device__
static
constexpr
auto
MakeQDramTileDistribution
()
{
constexpr
auto
config
=
BlockGemm
::
Policy
::
template
GetWarpGemmMWarpNWarp
<
Problem
>();
using
WG
=
remove_cvref_t
<
decltype
(
config
.
template
At
<
0
>())
>
;
constexpr
index_t
MWarp
=
config
.
template
At
<
1
>();
constexpr
index_t
kMPerBlock
=
Problem
::
BlockFmhaShape
::
kM0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kK0BlockLength
;
constexpr
index_t
K2
=
WG
::
kK
/
WG
::
WarpGemmAttribute
::
Impl
::
kABKLane
;
constexpr
index_t
K1
=
WG
::
WarpGemmAttribute
::
Impl
::
kABKLane
;
constexpr
index_t
K0
=
kKPerBlock
/
(
K1
*
K2
);
constexpr
index_t
M2
=
WG
::
WarpGemmAttribute
::
Impl
::
kAMLane
;
constexpr
index_t
M1
=
MWarp
;
constexpr
index_t
M0
=
kMPerBlock
/
(
M2
*
M1
);
return
make_static_tile_distribution
(
StaticTileDistributionEncoding
<
Sequence
<
1
>
,
Tuple
<
Sequence
<
M0
,
M1
,
M2
>
,
Sequence
<
K0
,
K1
,
K2
>>
,
Tuple
<
Sequence
<
1
>
,
Sequence
<
2
,
1
>>
,
Tuple
<
Sequence
<
1
>
,
Sequence
<
1
,
2
>>
,
Sequence
<
2
,
1
,
2
>
,
Sequence
<
0
,
0
,
2
>>
{});
}
template
<
typename
Problem
>
__host__
__device__
static
constexpr
auto
MakeKDramTileDistribution
()
{
using
KDataType
=
remove_cvref_t
<
typename
Problem
::
KDataType
>
;
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN0
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kK0
;
constexpr
index_t
K1
=
16
/
sizeof
(
KDataType
);
constexpr
index_t
K0
=
kKPerBlock
/
K1
;
constexpr
index_t
N2
=
get_warp_size
()
/
K0
;
#if 1 // coalesce reading for each blocks
constexpr
index_t
N1
=
kBlockSize
/
get_warp_size
();
constexpr
index_t
N0
=
kNPerBlock
/
(
N2
*
N1
);
return
make_static_tile_distribution
(
StaticTileDistributionEncoding
<
Sequence
<
1
>
,
Tuple
<
Sequence
<
N0
,
N1
,
N2
>
,
Sequence
<
K0
,
K1
>>
,
Tuple
<
Sequence
<
1
>
,
Sequence
<
1
,
2
>>
,
Tuple
<
Sequence
<
1
>
,
Sequence
<
2
,
0
>>
,
Sequence
<
1
,
2
>
,
Sequence
<
0
,
1
>>
{});
#else // coalesce reading for each warps
constexpr
index_t
N0
=
kBlockSize
/
get_warp_size
();
constexpr
index_t
N1
=
kNPerBlock
/
(
N2
*
N0
);
return
make_static_tile_distribution
(
StaticTileDistributionEncoding
<
Sequence
<
1
>
,
Tuple
<
Sequence
<
N0
,
N1
,
N2
>
,
Sequence
<
K0
,
K1
>>
,
Tuple
<
Sequence
<
1
>
,
Sequence
<
1
,
2
>>
,
Tuple
<
Sequence
<
0
>
,
Sequence
<
2
,
0
>>
,
Sequence
<
1
,
2
>
,
Sequence
<
1
,
1
>>
{});
#endif
}
template
<
typename
Problem
>
__device__
static
constexpr
auto
MakeVDramTileDistribution
()
{
using
VDataType
=
remove_cvref_t
<
typename
Problem
::
VDataType
>
;
;
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kNPerBlock
=
Problem
::
BlockFmhaShape
::
kN1
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockFmhaShape
::
kK1
;
constexpr
index_t
K1
=
16
/
sizeof
(
VDataType
);
constexpr
index_t
K0
=
kKPerBlock
/
K1
;
constexpr
index_t
N2
=
get_warp_size
()
/
K0
;
constexpr
index_t
N1
=
kBlockSize
/
get_warp_size
();
constexpr
index_t
N0
=
kNPerBlock
/
(
N2
*
N1
);
return
make_static_tile_distribution
(
StaticTileDistributionEncoding
<
Sequence
<
1
>
,
Tuple
<
Sequence
<
N0
,
N1
,
N2
>
,
Sequence
<
K0
,
K1
>>
,
Tuple
<
Sequence
<
1
>
,
Sequence
<
1
,
2
>>
,
Tuple
<
Sequence
<
1
>
,
Sequence
<
2
,
0
>>
,
Sequence
<
1
,
2
>
,
Sequence
<
0
,
1
>>
{});
}
template
<
typename
Problem
>
__host__
__device__
static
constexpr
auto
GetQKBlockGemm
()
{
using
BlockGemmProblem
=
BlockGemmPipelineProblem
<
typename
Problem
::
QDataType
,
typename
Problem
::
KDataType
,
typename
Problem
::
SaccDataType
,
Problem
::
kBlockSize
,
TileGemmShape
<
Problem
::
BlockFmhaShape
::
kM0
,
Problem
::
BlockFmhaShape
::
kN0
,
Problem
::
BlockFmhaShape
::
kK0
>>
;
// using WarpGemm = ck::tile_program::warp::WarpGemmMfmaDispatcher<typename
// Problem::QDataType, typename Problem::KDataType, typename Problem::SaccDataType,
// Problem::BlockFmhaShape::Gemm0WarpTile::At(Number<0>{}),
// Problem::BlockFmhaShape::Gemm0WarpTile::At(Number<1>{}),
// Problem::BlockFmhaShape::Gemm0WarpTile::At(Number<2>{}), true>;
using
WarpGemm
=
warp
::
WarpGemmImpl
<
warp
::
WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB
<
warp
::
WarpGemmAttributeMfmaImplF16F16F32M32N32K8
,
2
>>
;
using
BlockGemmPolicy
=
BlockGemmARegBSmemCRegV1CustomPolicy
<
typename
Problem
::
QDataType
,
typename
Problem
::
KDataType
,
typename
Problem
::
SaccDataType
,
typename
Problem
::
BlockFmhaShape
::
Gemm0BlockWarps
,
WarpGemm
>
;
return
BlockGemmARegBSmemCRegV1
<
BlockGemmProblem
,
BlockGemmPolicy
>
{};
}
template
<
typename
Problem
>
__host__
__device__
static
constexpr
auto
GetKVBlockGemm
()
{
using
BlockGemmProblem
=
BlockGemmPipelineProblem
<
typename
Problem
::
PDataType
,
typename
Problem
::
VDataType
,
typename
Problem
::
OaccDataType
,
Problem
::
kBlockSize
,
TileGemmShape
<
Problem
::
BlockFmhaShape
::
kM0
,
Problem
::
BlockFmhaShape
::
kN1
,
Problem
::
BlockFmhaShape
::
kK1
>>
;
// using BlockGemmPolicy = BlockGemmARegBSmemCRegV1DefaultPolicy;
using
WarpGemm
=
ck
::
tile_program
::
warp
::
WarpGemmMfmaDispatcher
<
typename
Problem
::
PDataType
,
typename
Problem
::
VDataType
,
typename
Problem
::
OaccDataType
,
Problem
::
BlockFmhaShape
::
Gemm1WarpTile
::
At
(
Number
<
0
>
{}),
Problem
::
BlockFmhaShape
::
Gemm1WarpTile
::
At
(
Number
<
1
>
{}),
Problem
::
BlockFmhaShape
::
Gemm1WarpTile
::
At
(
Number
<
2
>
{}),
true
>
;
using
BlockGemmPolicy
=
BlockGemmARegBSmemCRegV1CustomPolicy
<
typename
Problem
::
PDataType
,
typename
Problem
::
VDataType
,
typename
Problem
::
OaccDataType
,
typename
Problem
::
BlockFmhaShape
::
Gemm1BlockWarps
,
WarpGemm
>
;
return
BlockGemmARegBSmemCRegV1
<
BlockGemmProblem
,
BlockGemmPolicy
>
{};
}
};
}
// namespace block
}
// namespace tile_program
}
// namespace ck
include/ck/tile_program/tile/tile_fmha_shape.hpp
View file @
e71aa1d6
...
...
@@ -8,19 +8,27 @@
namespace
ck
{
namespace
tile_program
{
template
<
index_t
kM0PerTile_
,
// tile size along q seqlen
index_t
kN0PerTile_
,
// tile size along k seqlen
index_t
kK0PerTile_
,
// tile size along qk gemm unroll
index_t
kN1PerTile_
,
// tile size along v head_dim
index_t
kK1PerTile_
// tile size along kv gemm unroll
>
template
<
typename
BlockTile_
,
// Sequence<...
typename
Gemm0BlockWarps_
,
typename
Gemm0WarpTile_
,
typename
Gemm1BlockWarps_
,
typename
Gemm1WarpTile_
>
struct
TileFmhaShape
{
static
constexpr
index_t
kM0
=
kM0PerTile_
;
static
constexpr
index_t
kN0
=
kN0PerTile_
;
static
constexpr
index_t
kK0
=
kK0PerTile_
;
static
constexpr
index_t
kN1
=
kN1PerTile_
;
static
constexpr
index_t
kK1
=
kK1PerTile_
;
using
BlockTile
=
remove_cvref_t
<
BlockTile_
>
;
using
Gemm0BlockWarps
=
remove_cvref_t
<
Gemm0BlockWarps_
>
;
using
Gemm0WarpTile
=
remove_cvref_t
<
Gemm0WarpTile_
>
;
using
Gemm1BlockWarps
=
remove_cvref_t
<
Gemm1BlockWarps_
>
;
using
Gemm1WarpTile
=
remove_cvref_t
<
Gemm1WarpTile_
>
;
static
constexpr
index_t
kM0
=
BlockTile
::
At
(
Number
<
0
>
{});
// tile size along q seqlen
static
constexpr
index_t
kN0
=
BlockTile
::
At
(
Number
<
1
>
{});
// tile size along k seqlen
static
constexpr
index_t
kK0
=
BlockTile
::
At
(
Number
<
2
>
{});
// tile size along qk gemm unroll
static
constexpr
index_t
kN1
=
BlockTile
::
At
(
Number
<
3
>
{});
// tile size along v head_dim
static
constexpr
index_t
kK1
=
BlockTile
::
At
(
Number
<
4
>
{});
// tile size along kv gemm unroll
static
constexpr
index_t
kK0BlockLength
=
BlockTile
::
At
(
Number
<
5
>
{});
// total length of K0, used for pipeline that need load Q at
// once (or repeately load Q as a whole tile)
};
}
// namespace tile_program
...
...
include/ck/tile_program/warp_tile/warp_gemm.hpp
View file @
e71aa1d6
...
...
@@ -22,9 +22,15 @@ using WarpGemmMfmaF16F16F32M16N16K16 =
using
WarpGemmMfmaF16F16F32M32N32K16
=
WarpGemmImpl
<
WarpGemmAtrributeMfmaIterateK
<
WarpGemmAttributeMfmaImplF16F16F32M32N32K8
,
2
>>
;
using
WarpGemmMfmaF16F16F32M16N16K32
=
WarpGemmImpl
<
WarpGemmAtrributeMfmaIterateK
<
WarpGemmAttributeMfmaImplF16F16F32M16N16K16
,
2
>>
;
using
WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution
=
WarpGemmImpl
<
WarpGemmAtrributeMfmaTransposedCDistribution
<
WarpGemmAttributeMfmaImplF16F16F32M32N32K8
>>
;
using
WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution
=
WarpGemmImpl
<
WarpGemmAtrributeMfmaTransposedCDistribution
<
WarpGemmAttributeMfmaImplF16F16F32M16N16K16
>>
;
using
WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution
=
WarpGemmImpl
<
WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution
<
WarpGemmAttributeMfmaImplF16F16F32M32N32K8
,
...
...
include/ck/tile_program/warp_tile/warp_gemm_attribute_mfma.hpp
View file @
e71aa1d6
...
...
@@ -287,6 +287,90 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution
}
};
template
<
typename
WarpGemmAttributeMfmaImpl_
,
index_t
kKIter
>
struct
WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB
{
using
Impl
=
remove_cvref_t
<
WarpGemmAttributeMfmaImpl_
>
;
// swap A and B
using
ADataType
=
typename
Impl
::
BDataType
;
using
BDataType
=
typename
Impl
::
ADataType
;
using
CDataType
=
typename
Impl
::
CDataType
;
using
AVecType
=
typename
vector_type_maker
<
typename
Impl
::
BVecType
,
kKIter
>::
type
::
type
;
using
BVecType
=
typename
vector_type_maker
<
typename
Impl
::
AVecType
,
kKIter
>::
type
::
type
;
using
CVecType
=
typename
Impl
::
CVecType
;
static
constexpr
index_t
kM
=
Impl
::
kN
;
static
constexpr
index_t
kN
=
Impl
::
kM
;
static
constexpr
index_t
kK
=
Impl
::
kK
*
kKIter
;
using
AWarpDstrEncoding
=
StaticTileDistributionEncoding
<
Sequence
<>
,
Tuple
<
Sequence
<
Impl
::
kBNLane
>
,
Sequence
<
Impl
::
kABKLane
,
Impl
::
kABKPerLane
*
kKIter
>>
,
Tuple
<
Sequence
<
2
,
1
>>
,
Tuple
<
Sequence
<
0
,
0
>>
,
Sequence
<
2
>
,
Sequence
<
1
>>
;
using
BWarpDstrEncoding
=
StaticTileDistributionEncoding
<
Sequence
<>
,
Tuple
<
Sequence
<
Impl
::
kAMLane
/
(
Impl
::
kABKPerLane
*
Impl
::
kABKLane
*
2
),
Impl
::
kABKLane
,
2
,
Impl
::
kABKPerLane
>
,
Sequence
<
Impl
::
kABKLane
,
Impl
::
kABKPerLane
*
kKIter
>>
,
Tuple
<
Sequence
<
2
,
1
,
1
,
1
,
1
>>
,
Tuple
<
Sequence
<
0
,
0
,
2
,
1
,
3
>>
,
Sequence
<
2
>
,
Sequence
<
1
>>
;
using
CWarpDstrEncoding
=
StaticTileDistributionEncoding
<
Sequence
<>
,
Tuple
<
Sequence
<
Impl
::
kCNLane
>
,
Sequence
<
Impl
::
kCM0PerLane
,
Impl
::
kCMLane
,
Impl
::
kCM1PerLane
>>
,
Tuple
<
Sequence
<
2
,
1
>>
,
Tuple
<
Sequence
<
1
,
0
>>
,
Sequence
<
2
,
2
>
,
Sequence
<
0
,
2
>>
;
// c_vec += a_vec * b_vec
__device__
void
operator
()(
CVecType
&
c_vec
,
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
)
const
{
const
auto
a_vector
=
typename
vector_type_maker
<
AVecType
,
1
>::
type
{
a_vec
};
const
auto
b_vector
=
typename
vector_type_maker
<
BVecType
,
1
>::
type
{
b_vec
};
// swap A and B, value and type
static_for
<
0
,
kKIter
,
1
>
{}([
&
](
auto
iKIter
)
{
Impl
{}(
c_vec
,
b_vector
.
template
AsType
<
typename
Impl
::
AVecType
>()[
iKIter
],
a_vector
.
template
AsType
<
typename
Impl
::
BVecType
>()[
iKIter
]);
});
}
// c_vec = a_vec * b_vec
__device__
CVecType
operator
()(
const
AVecType
&
a_vec
,
const
BVecType
&
b_vec
)
const
{
const
auto
a_vector
=
typename
vector_type_maker
<
AVecType
,
1
>::
type
{
a_vec
};
const
auto
b_vector
=
typename
vector_type_maker
<
BVecType
,
1
>::
type
{
b_vec
};
constexpr
auto
I0
=
Number
<
0
>
{};
// swap A and B, value and type
auto
c_vec
=
Impl
{}(
b_vector
.
template
AsType
<
typename
Impl
::
AVecType
>()[
I0
],
a_vector
.
template
AsType
<
typename
Impl
::
BVecType
>()[
I0
]);
static_for
<
1
,
kKIter
,
1
>
{}([
&
](
auto
iKIter
)
{
Impl
{}(
c_vec
,
b_vector
.
template
AsType
<
typename
Impl
::
AVecType
>()[
iKIter
],
a_vector
.
template
AsType
<
typename
Impl
::
BVecType
>()[
iKIter
]);
});
return
c_vec
;
}
};
}
// namespace warp
}
// namespace tile_program
}
// namespace ck
include/ck/tile_program/warp_tile/warp_gemm_dispatcher.hpp
0 → 100644
View file @
e71aa1d6
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tile_program/warp_tile/warp_gemm.hpp"
namespace
ck
{
namespace
tile_program
{
namespace
warp
{
namespace
impl
{
template
<
typename
AType
,
typename
BType
,
typename
CType
,
index_t
MPerWave
,
index_t
NPerWave
,
index_t
KPerWave
,
bool
TransposeC
>
struct
WarpGemmMfmaDispatcher
;
// clang-format off
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck
::
half_t
,
ck
::
half_t
,
float
,
32
,
32
,
8
,
false
>
{
using
Type
=
WarpGemmMfmaF16F16F32M32N32K8
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck
::
half_t
,
ck
::
half_t
,
float
,
32
,
32
,
8
,
true
>
{
using
Type
=
WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck
::
half_t
,
ck
::
half_t
,
float
,
32
,
32
,
16
,
false
>
{
using
Type
=
WarpGemmMfmaF16F16F32M32N32K16
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck
::
half_t
,
ck
::
half_t
,
float
,
32
,
32
,
16
,
true
>
{
using
Type
=
WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck
::
half_t
,
ck
::
half_t
,
float
,
16
,
16
,
16
,
false
>
{
using
Type
=
WarpGemmMfmaF16F16F32M16N16K16
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck
::
half_t
,
ck
::
half_t
,
float
,
16
,
16
,
16
,
true
>
{
using
Type
=
WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck
::
half_t
,
ck
::
half_t
,
float
,
16
,
16
,
32
,
false
>
{
using
Type
=
WarpGemmMfmaF16F16F32M16N16K32
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck
::
half_t
,
ck
::
half_t
,
float
,
16
,
16
,
32
,
true
>
{
using
Type
=
WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution
;
};
// clang-format on
}
// namespace impl
template
<
typename
AType
,
typename
BType
,
typename
CType
,
index_t
MPerWave
,
index_t
NPerWave
,
index_t
KPerWave
,
bool
TransposeC
>
using
WarpGemmMfmaDispatcher
=
typename
impl
::
WarpGemmMfmaDispatcher
<
AType
,
BType
,
CType
,
MPerWave
,
NPerWave
,
KPerWave
,
TransposeC
>::
Type
;
}
// namespace warp
}
// namespace tile_program
}
// namespace ck
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment