gaoqiong / composable_kernel_ROCM / Commits

Commit c881136b, authored Jan 01, 2025 by Po Yen Chen
Merge branch 'develop' into ck_tile/support-vllm-kcache-layout
Parents: c5e8e14f 4e076909

Showing 20 changed files (of 75 in the merge) with 1255 additions and 284 deletions (+1255 -284)
include/ck_tile/core/tensor/static_distributed_tensor.hpp                                      +1    -0
include/ck_tile/host.hpp                                                                       +1    -1
include/ck_tile/host/arg_parser.hpp                                                           +44    -2
include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp                                                  +1    -1
include/ck_tile/ops/common.hpp                                                                 +1    -1
include/ck_tile/ops/elementwise.hpp                                                            +1    -1
include/ck_tile/ops/epilogue.hpp                                                               +1    -1
include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp                                            +27    -4
include/ck_tile/ops/epilogue/default_2d_epilogue.hpp                                          +22    -4
include/ck_tile/ops/flatmm.hpp                                                                 +1    -1
include/ck_tile/ops/fmha.hpp                                                                   +3    -3
include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp                                  +20    -8
include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp                                           +75    -9
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp                           +66   -27
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_tile_partitioner.hpp                  +0   -48
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp                                    +6    -3
include/ck_tile/ops/fmha/kernel/fmha_fwd_tile_partitioner.hpp                                  +0  -105
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp                 +65   -18
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp +126   -47
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp +794   -0
include/ck_tile/core/tensor/static_distributed_tensor.hpp
...
@@ -29,6 +29,7 @@ struct static_distributed_tensor
         remove_cvref_t<decltype(StaticTileDistribution{}.get_ys_to_d_descriptor())>;

     static constexpr index_t kThreadElementSpaceSize = ThreadTensorDesc{}.get_element_space_size();
+    static_assert(0 < kThreadElementSpaceSize, "Make sure tile distribution is valid");

     CK_TILE_HOST_DEVICE static constexpr auto get_num_of_dimension()
     {
...
include/ck_tile/host.hpp
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once
...
include/ck_tile/host/arg_parser.hpp
...
@@ -15,11 +15,14 @@
 namespace ck_tile {

 /*
- * a host side utility, arg parser for
- * -[key0]=[value0] -[key1]=[value1] ...
+ * a host side utility, arg parser for, either
+ * -[key0] = [value0, value1, value2]
+ * or
+ * -[key0]=[value0] -[key1]=[value1] ...
  */
 class ArgParser
 {
     public:
     class Arg
     {
...
@@ -187,6 +190,45 @@ class ArgParser
         return value;
     }

+    std::vector<std::string> get_string_vec(const std::string& name,
+                                            const std::string& delimiter = ",") const
+    {
+        if(get_str(name).empty())
+        {
+            return {};
+        }
+
+        std::string s = get_str(name);
+        std::vector<std::string> tokens;
+
+        size_t pos = 0;
+        std::string token;
+        while((pos = s.find(delimiter)) != std::string::npos)
+        {
+            token = s.substr(0, pos);
+            tokens.push_back(token);
+            s.erase(0, pos + delimiter.length());
+        }
+        tokens.push_back(s);
+
+        return tokens;
+    }
+
+    std::vector<int> get_int_vec(const std::string& name,
+                                 const std::string& delimiter = ",") const
+    {
+        if(get_str(name).empty())
+        {
+            return {};
+        }
+
+        const std::vector<std::string> args = get_string_vec(name, delimiter);
+        std::vector<int> tokens;
+        tokens.reserve(static_cast<int>(args.size()));
+
+        for(const std::string& token : args)
+        {
+            int value = atoi(token.c_str());
+            tokens.push_back(value);
+        }
+
+        return tokens;
+    }
+
     private:
     std::unordered_map<std::string, Arg> input_map;
     std::vector<std::string> keys;
...
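The new get_string_vec/get_int_vec helpers split a single argument value on a delimiter, so one flag can carry a list. A minimal standalone sketch of the same tokenization (plain C++, independent of the ck_tile class) shows the behavior, including that the remainder after the last delimiter is always pushed:

// Standalone sketch of the tokenization used by ArgParser::get_string_vec()
// and get_int_vec(); not the ck_tile class itself.
#include <cstdlib>
#include <iostream>
#include <string>
#include <vector>

std::vector<std::string> split(std::string s, const std::string& delimiter = ",")
{
    if(s.empty()) { return {}; } // mirrors the get_str(name).empty() guard
    std::vector<std::string> tokens;
    size_t pos = 0;
    while((pos = s.find(delimiter)) != std::string::npos)
    {
        tokens.push_back(s.substr(0, pos));
        s.erase(0, pos + delimiter.length());
    }
    tokens.push_back(s); // remainder after the last delimiter
    return tokens;
}

int main()
{
    // e.g. a command line of the form: -lengths=4,8,16
    for(const std::string& t : split("4,8,16"))
        std::cout << std::atoi(t.c_str()) << '\n'; // prints 4, 8, 16
}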
include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once
...
include/ck_tile/ops/common.hpp
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once
...
include/ck_tile/ops/elementwise.hpp
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once
...
include/ck_tile/ops/epilogue.hpp
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once
...
include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once
...
@@ -56,6 +56,13 @@ struct CShuffleEpilogue
     // No additional shared memory needed
     CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return 0; }

+    CK_TILE_HOST_DEVICE static constexpr bool IsOutputTransposed()
+    {
+        // TODO: At now CShuffle doesn't allow to vector store after permute.
+        // It should be fixed and this function should return true.
+        return false;
+    }
+
     template <typename OAccTile>
     CK_TILE_DEVICE void permute_tile_data(OAccTile& o_acc_tile)
     {
...
@@ -111,7 +118,9 @@ struct CShuffleEpilogue
         }
     }

-    template <typename ODramWindowTmp, typename OAccTile>
+    template <typename ODramWindowTmp,
+              typename OAccTile,
+              memory_operation_enum out_memory_data_op = memory_operation_enum::set>
     CK_TILE_DEVICE auto operator()(ODramWindowTmp& o_dram_window_tmp, OAccTile& o_acc_tile)
     {
         const auto& current_window_origin = o_dram_window_tmp.get_window_origin();
...
@@ -158,12 +167,26 @@ struct CShuffleEpilogue
         // Store the tile data to the permuted location
         if constexpr(kPadM || kPadN)
         {
-            store_tile_raw(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
+            if constexpr(out_memory_data_op == memory_operation_enum::set)
+            {
+                store_tile_raw(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
+            }
+            else
+            {
+                update_tile_raw(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
+            }
             buffer_store_fence();
         }
         else
         {
-            store_tile(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
+            if constexpr(out_memory_data_op == memory_operation_enum::set)
+            {
+                store_tile(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
+            }
+            else
+            {
+                update_tile(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
+            }
         }
     }
 };
...
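The new out_memory_data_op template parameter defaults to memory_operation_enum::set, so existing call sites are unchanged, while any non-set value routes the epilogue through the update (read-modify-write) path. A minimal standalone model of that dispatch, with a reduced two-value enum standing in for ck_tile's memory_operation_enum, shows the intended call pattern:

// Minimal model of the new epilogue dispatch, not ck_tile itself: `set`
// overwrites the destination (store_tile path), anything else accumulates
// into it (update_tile path).
#include <iostream>

enum class memory_operation_enum { set, add }; // illustrative subset

template <memory_operation_enum out_memory_data_op = memory_operation_enum::set>
void epilogue_store(float& dst, float src)
{
    if constexpr(out_memory_data_op == memory_operation_enum::set)
        dst = src;  // store_tile / store_tile_raw path
    else
        dst += src; // update_tile / update_tile_raw path
}

int main()
{
    float out = 1.0f;
    epilogue_store(out, 2.0f);                             // set: out == 2
    epilogue_store<memory_operation_enum::add>(out, 2.0f); // update: out == 4
    std::cout << out << '\n';
}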
include/ck_tile/ops/epilogue/default_2d_epilogue.hpp
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once
...
@@ -35,21 +35,39 @@ struct Default2DEpilogue
     CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return 0; }

+    CK_TILE_HOST_DEVICE static constexpr bool IsOutputTransposed() { return false; }
+
     // TODO: this function assume store out vector size is the same as OAccTile last dimension size
     // how do we fix this ?
-    template <typename ODramWindowTmp, typename OAccTile>
+    template <typename ODramWindowTmp,
+              typename OAccTile,
+              memory_operation_enum out_memory_data_op = memory_operation_enum::set>
     CK_TILE_DEVICE auto operator()(ODramWindowTmp& o_dram_window_tmp, const OAccTile& o_acc_tile)
     {

         // TODO: this is ugly
         if constexpr(UseRawStore && (kPadM || kPadN))
         {
-            store_tile_raw(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
+            if constexpr(out_memory_data_op == memory_operation_enum::set)
+            {
+                store_tile_raw(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
+            }
+            else
+            {
+                update_tile_raw(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
+            }
             buffer_store_fence();
         }
         else
         {
-            store_tile(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
+            if constexpr(out_memory_data_op == memory_operation_enum::set)
+            {
+                store_tile(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
+            }
+            else
+            {
+                update_tile(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
+            }
         }
     }
 };
...
include/ck_tile/ops/flatmm.hpp
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once
...
include/ck_tile/ops/fmha.hpp
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once
...
@@ -14,9 +14,7 @@
 #include "ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_tile_partitioner.hpp"
 #include "ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp"
 #include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp"
-#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_tile_partitioner.hpp"
 #include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp"
-#include "ck_tile/ops/fmha/kernel/fmha_fwd_tile_partitioner.hpp"
 #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_convert_dq.hpp"
 #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o.hpp"
 #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp"
...
@@ -28,6 +26,8 @@
 #include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline_default_policy.hpp"
 #include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp"
 #include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp"
+#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp"
+#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs_default_policy.hpp"
 #include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp"
 #include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_async.hpp"
 #include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_async_default_policy.hpp"
...
include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp
...
@@ -10,10 +10,9 @@
 namespace ck_tile {

-template <typename TilePartitioner_, typename FmhaPipeline_>
+template <typename FmhaPipeline_>
 struct FmhaFwdAppendKVKernel
 {
-    using TilePartitioner = ck_tile::remove_cvref_t<TilePartitioner_>;
     using FmhaPipeline = ck_tile::remove_cvref_t<FmhaPipeline_>;
     static constexpr ck_tile::index_t kBlockSize  = FmhaPipeline::kBlockSize;
     static constexpr ck_tile::index_t kBlockPerCu = FmhaPipeline::kBlockPerCu;
...
@@ -234,12 +233,25 @@ struct FmhaFwdAppendKVKernel
         return kargs;
     }

-    __host__ static constexpr auto GridSize(ck_tile::index_t batch_size,
-                                            ck_tile::index_t nhead,
-                                            ck_tile::index_t seqlen_q,
-                                            ck_tile::index_t seqlen_knew)
+    CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size,
+                                                ck_tile::index_t nhead,
+                                                ck_tile::index_t seqlen_q,
+                                                ck_tile::index_t seqlen_knew)
     {
-        return TilePartitioner::GridSize(batch_size, nhead, seqlen_q, seqlen_knew);
+        // TODO: this may need tuning
+        return dim3(std::max(ck_tile::integer_divide_ceil(seqlen_q, FmhaPipeline::kM0),
+                             ck_tile::integer_divide_ceil(seqlen_knew, FmhaPipeline::kN0)),
+                    nhead,
+                    batch_size);
     }

+    CK_TILE_DEVICE static constexpr auto GetTileIndex(const Kargs& /* kargs */)
+    {
+        const index_t i_tile  = blockIdx.x;
+        const index_t i_nhead = blockIdx.y;
+        const index_t i_batch = blockIdx.z;
+
+        return ck_tile::make_tuple(i_tile, i_nhead, i_batch);
+    }
+
     __host__ static constexpr auto BlockSize() { return dim3(kBlockSize); }
...
@@ -247,7 +259,7 @@ struct FmhaFwdAppendKVKernel
     CK_TILE_DEVICE void operator()(Kargs kargs) const
     {
         // divide problem
-        const auto [i_tile, i_nhead, i_batch] = TilePartitioner{}();
+        const auto [i_tile, i_nhead, i_batch] = GetTileIndex(kargs);

         const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile * FmhaPipeline::kM0);
         const index_t i_n0 = __builtin_amdgcn_readfirstlane(i_tile * FmhaPipeline::kN0);
...
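The inlined appendkv GridSize launches one x-dimension that must cover both the Q-rotation tiles and the K-append tiles, hence the std::max of the two ceil-divisions. A standalone arithmetic check of that grid computation, with illustrative tile sizes rather than a shipped configuration:

// Standalone check of the appendkv grid computation; kM0/kN0 are
// illustrative, not taken from a specific pipeline configuration.
#include <algorithm>
#include <iostream>

int integer_divide_ceil(int a, int b) { return (a + b - 1) / b; }

int main()
{
    const int kM0 = 64, kN0 = 64;
    const int seqlen_q = 1, seqlen_knew = 300; // decode step appending 300 tokens
    // one grid dimension covers both the Q tiles and the K-append tiles
    const int grid_x = std::max(integer_divide_ceil(seqlen_q, kM0),
                                integer_divide_ceil(seqlen_knew, kN0));
    std::cout << grid_x << '\n'; // 5: dominated by the K-append side
}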
include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp
...
@@ -20,10 +20,9 @@
 namespace ck_tile {

-template <typename TilePartitioner_, typename FmhaPipeline_, typename EpiloguePipeline_>
+template <typename FmhaPipeline_, typename EpiloguePipeline_>
 struct FmhaFwdKernel
 {
-    using TilePartitioner  = ck_tile::remove_cvref_t<TilePartitioner_>;
     using FmhaPipeline     = ck_tile::remove_cvref_t<FmhaPipeline_>;
     using EpiloguePipeline = ck_tile::remove_cvref_t<EpiloguePipeline_>;
     static constexpr ck_tile::index_t kBlockSize = FmhaPipeline::kBlockSize;
...
@@ -71,7 +70,8 @@ struct FmhaFwdKernel
         using bfs  = typename FmhaPipeline::BlockFmhaShape;
         using g0br = typename bfs::Gemm0BlockWarps;
         using g1br = typename bfs::Gemm1BlockWarps;
-        using gwt  = typename bfs::Gemm0WarpTile;
+        using g0wt = typename bfs::Gemm0WarpTile;
+        using g1wt = typename bfs::Gemm1WarpTile;
         #define _SS_ std::string
         #define _TS_ std::to_string
         auto pn = [&] () {
...
@@ -83,12 +83,13 @@ struct FmhaFwdKernel
             return n.empty() ? n : std::string("p") + n; }();
         return
            _SS_("fmha_fwd_d") + _TS_(bfs::kQKHeaddim) + "_" + _SS_(t2s<QDataType>::name) +
-           "_" + (kIsGroupMode ? "group" : "batch") + "_" + _SS_(TilePartitioner::name) + "_"
+           "_" + (kIsGroupMode ? "group" : "batch") + "_"
            "b" + _TS_(bfs::kM0) + "x" + _TS_(bfs::kN0) + "x" + _TS_(bfs::kK0) + "x" +
            _TS_(bfs::kN1) + "x" + _TS_(bfs::kK1) + "x" + _TS_(bfs::kQKHeaddim) + "_" +
            "r" + _TS_(g0br::at(ck_tile::number<0>{})) + "x" + _TS_(g0br::at(ck_tile::number<1>{})) + "x" + _TS_(g0br::at(ck_tile::number<2>{})) + "_" +
            "r" + _TS_(g1br::at(ck_tile::number<0>{})) + "x" + _TS_(g1br::at(ck_tile::number<1>{})) + "x" + _TS_(g1br::at(ck_tile::number<2>{})) + "_" +
-           "w" + _TS_(gwt::at(ck_tile::number<0>{})) + "x" + _TS_(gwt::at(ck_tile::number<1>{})) + "x" + _TS_(gwt::at(ck_tile::number<2>{})) + "_" +
+           "w" + _TS_(g0wt::at(ck_tile::number<0>{})) + "x" + _TS_(g0wt::at(ck_tile::number<1>{})) + "x" + _TS_(g0wt::at(ck_tile::number<2>{})) + "_" +
+           "w" + _TS_(g1wt::at(ck_tile::number<0>{})) + "x" + _TS_(g1wt::at(ck_tile::number<1>{})) + "x" + _TS_(g1wt::at(ck_tile::number<2>{})) + "_" +
            (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" +
            "v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") +
            (pn.empty() ? "" : "_" + pn) +
            (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
...
@@ -865,9 +866,75 @@ struct FmhaFwdKernel
     CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_,
                                                 ck_tile::index_t nhead_,
                                                 ck_tile::index_t seqlen_q_,
-                                                ck_tile::index_t hdim_v_)
+                                                ck_tile::index_t hdim_v_,
+                                                bool has_padded_seqlen_k = false)
     {
-        return TilePartitioner::GridSize(batch_size_, nhead_, seqlen_q_, hdim_v_);
+        // has_padded_seqlen_k is determined by checking (seqlen_k_ptr != nullptr)
+        if(has_padded_seqlen_k)
+        {
+            // TODO: this may need tuning
+            return dim3(nhead_,
+                        batch_size_,
+                        ck_tile::integer_divide_ceil(seqlen_q_, FmhaPipeline::kM0) *
+                            ck_tile::integer_divide_ceil(hdim_v_, FmhaPipeline::kN1));
+        }
+        else
+        {
+            // TODO: this may need tuning
+            return dim3(ck_tile::integer_divide_ceil(seqlen_q_, FmhaPipeline::kM0) *
+                            ck_tile::integer_divide_ceil(hdim_v_, FmhaPipeline::kN1),
+                        nhead_,
+                        batch_size_);
+        }
     }

+    CK_TILE_DEVICE static constexpr auto GetTileIndex(const Kargs& kargs)
+    {
+        bool has_padded_seqlen_k = false;
+        if constexpr(kIsGroupMode)
+            has_padded_seqlen_k = (kargs.seqlen_k_ptr != nullptr);
+
+        if(has_padded_seqlen_k)
+        {
+            // const index_t num_tile_m0 = seqlen_q / kM0;
+            const index_t num_tile_n1 =
+                ck_tile::integer_divide_ceil(kargs.hdim_v, FmhaPipeline::kN1);
+
+            const index_t i_block = blockIdx.z;
+            const index_t i_nhead = blockIdx.x;
+            const index_t i_batch = blockIdx.y;
+
+            const auto f = [](index_t dividend, index_t divisor) {
+                index_t quotient = dividend / divisor;
+                index_t modulus  = dividend - quotient * divisor;
+                return ck_tile::make_tuple(quotient, modulus);
+            };
+
+            const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1);
+
+            return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
+        }
+        else
+        {
+            // const index_t num_tile_m0 = seqlen_q / kM0;
+            const index_t num_tile_n1 =
+                ck_tile::integer_divide_ceil(kargs.hdim_v, FmhaPipeline::kN1);
+
+            const index_t i_block = blockIdx.x;
+            const index_t i_nhead = blockIdx.y;
+            const index_t i_batch = blockIdx.z;
+
+            const auto f = [](index_t dividend, index_t divisor) {
+                index_t quotient = dividend / divisor;
+                index_t modulus  = dividend - quotient * divisor;
+                return ck_tile::make_tuple(quotient, modulus);
+            };
+
+            const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1);
+
+            return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
+        }
+    }
+
     CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }
...
@@ -883,8 +950,7 @@ struct FmhaFwdKernel
         __shared__ char smem_ptr[GetSmemSize()];

         // divide problem
-        const auto [i_tile_m, i_tile_n, i_nhead, i_batch] =
-            TilePartitioner{}(kargs.seqlen_q, kargs.hdim_v);
+        const auto [i_tile_m, i_tile_n, i_nhead, i_batch] = GetTileIndex(kargs);

         const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * FmhaPipeline::kM0);
         const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1);
...
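GetTileIndex() undoes GridSize()'s ceil-div product: the flat block index is split into (i_tile_m, i_tile_n) with a single integer divide, the quotient/modulus lambda avoiding a second hardware divide. A standalone model of that decomposition, with illustrative tile sizes:

// Standalone model of the GetTileIndex() decomposition. Not the kernel
// itself; tile sizes are illustrative.
#include <iostream>

int integer_divide_ceil(int a, int b) { return (a + b - 1) / b; }

int main()
{
    const int kM0 = 128, kN1 = 32;
    const int seqlen_q = 300, hdim_v = 64;
    const int num_tile_n1 = integer_divide_ceil(hdim_v, kN1);            // 2
    const int grid_x = integer_divide_ceil(seqlen_q, kM0) * num_tile_n1; // 3 * 2 = 6
    for(int i_block = 0; i_block < grid_x; ++i_block)
    {
        const int quotient = i_block / num_tile_n1;            // i_tile_m
        const int modulus  = i_block - quotient * num_tile_n1; // i_tile_n
        std::cout << i_block << " -> (" << quotient << ", " << modulus << ")\n";
    }
}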
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp
...
@@ -5,12 +5,13 @@
 namespace ck_tile {

-template <typename TilePartitioner_, typename FmhaPipeline_, typename EpiloguePipeline_>
+template <typename FmhaPipeline_, typename EpiloguePipeline_>
 struct FmhaFwdSplitKVCombineKernel
 {
-    using TilePartitioner  = remove_cvref_t<TilePartitioner_>;
     using FmhaPipeline     = remove_cvref_t<FmhaPipeline_>;
     using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;

+    static constexpr index_t kNumWarps   = FmhaPipeline::kNumWarps;
     static constexpr index_t kBlockSize  = FmhaPipeline::kBlockSize;
     static constexpr index_t kBlockPerCu = FmhaPipeline::kBlockPerCu;
     static_assert(kBlockPerCu > 0);
...
@@ -50,8 +51,7 @@ struct FmhaFwdSplitKVCombineKernel
         return _SS_("fmha_fwd_splitkv_combine_d") + _TS_(FmhaPipeline::kHeadDimV) + "_" +
                _SS_(t2s<ODataType>::name) + "_" + (kIsGroupMode ? "group" : "batch") + "_"
-               "b" + _TS_(FmhaPipeline::kM0) + "x" + _TS_(FmhaPipeline::kN1) + "_" +
+               "b" + _TS_(FmhaPipeline::kN1) + "_" +
                (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) +
                _SS_(FmhaPipeline::name) + (pn.empty() ? "" : "_" + pn) +
...
@@ -234,12 +234,35 @@ struct FmhaFwdSplitKVCombineKernel
         return kargs;
     }

-    __host__ static constexpr auto GridSize(ck_tile::index_t batch_size,
-                                            ck_tile::index_t nhead,
-                                            ck_tile::index_t max_seqlen_q,
-                                            ck_tile::index_t hdim_v)
+    CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size,
+                                                ck_tile::index_t nhead,
+                                                ck_tile::index_t max_seqlen_q,
+                                                ck_tile::index_t hdim_v)
     {
-        return TilePartitioner::GridSize(batch_size, nhead, max_seqlen_q, hdim_v);
+        // TODO: this may need tuning
+        return dim3(ck_tile::integer_divide_ceil(max_seqlen_q, FmhaPipeline::kM0) *
+                        ck_tile::integer_divide_ceil(hdim_v, FmhaPipeline::kN1),
+                    nhead,
+                    batch_size);
     }

+    CK_TILE_DEVICE static constexpr auto GetTileIndex(const Kargs& kargs)
+    {
+        const index_t num_tile_n1 =
+            ck_tile::integer_divide_ceil(kargs.hdim_v, FmhaPipeline::kN1);
+
+        const index_t i_block = blockIdx.x;
+        const index_t i_nhead = blockIdx.y;
+        const index_t i_batch = blockIdx.z;
+
+        const auto f = [](index_t dividend, index_t divisor) {
+            index_t quotient = dividend / divisor;
+            index_t modulus  = dividend - quotient * divisor;
+            return ck_tile::make_tuple(quotient, modulus);
+        };
+
+        const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1);
+
+        return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
+    }
+
     __host__ static constexpr auto BlockSize() { return dim3(kBlockSize); }
...
@@ -255,8 +278,7 @@ struct FmhaFwdSplitKVCombineKernel
         __shared__ char smem_ptr[GetSmemSize()];

         // divide problem
-        const auto [i_tile_m, i_tile_n, i_nhead, i_batch] =
-            TilePartitioner{}(kargs.seqlen_q, kargs.hdim_v);
+        const auto [i_tile_m, i_tile_n, i_nhead, i_batch] = GetTileIndex(kargs);

         const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * FmhaPipeline::kM0);
         const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1);
...
@@ -339,37 +361,56 @@ struct FmhaFwdSplitKVCombineKernel
                 number<FmhaPipeline::kAlignmentOacc>{},
                 number<1>{});

+            // read 4 * (kM0, kN1) o_acc tiles simultaneously by 4 warps
             const auto o_acc_dram_view = pad_tensor_view(
                 o_acc_dram_naive,
-                make_tuple(number<1>{}, number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN1>{}),
-                sequence<false, kPadSeqLenQ, kPadHeadDimV>{});
+                make_tuple(number<kNumWarps>{}, number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN1>{}),
+                sequence<true, kPadSeqLenQ, kPadHeadDimV>{});

+            const index_t padded_num_splits =
+                o_acc_dram_view.get_tensor_descriptor().get_lengths()[number<0>{}];
             const index_t padded_seqlen_q =
                 o_acc_dram_view.get_tensor_descriptor().get_lengths()[number<1>{}];
             const index_t padded_hdim_v =
                 o_acc_dram_view.get_tensor_descriptor().get_lengths()[number<2>{}];

-            return transform_tensor_view(
-                o_acc_dram_view,
-                make_tuple(make_merge_transform(make_tuple(kargs.num_splits, padded_seqlen_q)),
-                           make_pass_through_transform(padded_hdim_v)),
-                make_tuple(sequence<0, 1>{}, sequence<2>{}),
-                make_tuple(sequence<0>{}, sequence<1>{}));
+            const index_t num_m_tiles = integer_divide_floor(padded_seqlen_q, FmhaPipeline::kM0);
+
+            // transform tensor view by following steps, given shape: (padded_num_splits,
+            // padded_seqlen_q, padded_hdim_v)
+            // 1. unmerge to (padded_num_splits, num_m_tiles, kM0, padded_hdim_v)
+            // 2. transpose to (num_m_tiles, padded_num_splits, kM0, padded_hdim_v)
+            // 3. merge to (num_m_tiles * padded_num_splits * kM0, padded_hdim_v)
+            auto transposed = transform_tensor_view(
+                o_acc_dram_view,
+                make_tuple(make_pass_through_transform(padded_num_splits),
+                           make_unmerge_transform(make_tuple(num_m_tiles, FmhaPipeline::kM0)),
+                           make_pass_through_transform(padded_hdim_v)),
+                make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
+                make_tuple(sequence<1>{}, sequence<0, 2>{}, sequence<3>{}));
+
+            return transform_tensor_view(
+                transposed,
+                make_tuple(make_merge_transform(
+                               make_tuple(num_m_tiles, padded_num_splits, FmhaPipeline::kM0)),
+                           make_pass_through_transform(padded_hdim_v)),
+                make_tuple(sequence<0, 1, 2>{}, sequence<3>{}),
+                make_tuple(sequence<0>{}, sequence<1>{}));
         }();

         auto lse_acc_dram_window = make_tile_window(
             lse_acc_dram,
-            [&]() {
-                return make_tuple(number<FmhaPipeline::kMaxSplits>{}, number<FmhaPipeline::kM0>{});
-            }(),
+            make_tuple(number<FmhaPipeline::kMaxSplits>{}, number<FmhaPipeline::kM0>{}),
             {0, i_m0});

+        const index_t padded_num_splits =
+            integer_divide_ceil(kargs.num_splits, kNumWarps) * kNumWarps;
+
         auto o_acc_dram_window =
             make_tile_window(o_acc_dram,
-                             [&]() {
-                                 return make_tuple(number<FmhaPipeline::kM0>{},
-                                                   number<FmhaPipeline::kN1>{});
-                             }(),
-                             {i_m0, i_n1});
+                             make_tuple(number<kNumWarps * FmhaPipeline::kM0>{},
+                                        number<FmhaPipeline::kN1>{}),
+                             {i_tile_m * padded_num_splits * FmhaPipeline::kM0, i_n1});

         // LSE DRAM window
         auto lse_dram_window = [&, i_nhead_ = i_nhead]() {
...
@@ -410,7 +451,6 @@ struct FmhaFwdSplitKVCombineKernel
                 identity{},                                          // lse_element_func
                 composes(saturates<fp8_t>{}, scales{kargs.scale_o}), // o_acc_element_func
                 kargs.num_splits,
-                kargs.seqlen_q,
                 smem_ptr);
         }
         else
...
@@ -419,7 +459,6 @@ struct FmhaFwdSplitKVCombineKernel
                 o_acc_dram_window,
                 lse_dram_window,
                 kargs.num_splits,
-                kargs.seqlen_q,
                 smem_ptr);
         }
     }();
...
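The reshaped o_acc view places the kNumWarps splits of one M-tile in consecutive rows, so a single windowed load feeds all warps at once. The unmerge/transpose/merge row reordering can be checked with plain index arithmetic; shapes below are illustrative:

// Plain-index model of the 3-step view change (unmerge -> transpose -> merge):
// a (padded_num_splits, padded_seqlen_q, hdim) tensor is re-read as
// (num_m_tiles * padded_num_splits * kM0, hdim) rows. Not ck_tile code.
#include <cassert>

int main()
{
    const int kM0 = 16, padded_num_splits = 4;
    const int padded_seqlen_q = 64; // -> num_m_tiles = 4
    const int num_m_tiles = padded_seqlen_q / kM0;
    for(int s = 0; s < padded_num_splits; ++s)
        for(int q = 0; q < padded_seqlen_q; ++q)
        {
            const int m_tile = q / kM0, row_in_tile = q % kM0;
            // row index in the transformed (merged) view: splits of the same
            // M-tile become adjacent kM0-row groups
            const int row = (m_tile * padded_num_splits + s) * kM0 + row_in_tile;
            assert(0 <= row && row < num_m_tiles * padded_num_splits * kM0);
        }
}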
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_tile_partitioner.hpp  (deleted, 100644 → 0)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"

namespace ck_tile {

template <index_t kM0_, index_t kN1_>
struct FmhaFwdSplitKVCombineTilePartitioner
{
    static constexpr ck_tile::index_t kM0 = kM0_;
    static constexpr ck_tile::index_t kN1 = kN1_;

    CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size,
                                                ck_tile::index_t nhead,
                                                ck_tile::index_t max_seqlen_q,
                                                ck_tile::index_t hdim_v)
    {
        // TODO: this may need tuning
        return dim3(ck_tile::integer_divide_ceil(max_seqlen_q, kM0) *
                        ck_tile::integer_divide_ceil(hdim_v, kN1),
                    nhead,
                    batch_size);
    }

    CK_TILE_DEVICE auto operator()(ck_tile::index_t /*seqlen_q*/, ck_tile::index_t hdim_v)
    {
        const index_t num_tile_n1 = ck_tile::integer_divide_ceil(hdim_v, kN1);

        const index_t i_block = blockIdx.x;
        const index_t i_nhead = blockIdx.y;
        const index_t i_batch = blockIdx.z;

        const auto f = [](index_t dividend, index_t divisor) {
            index_t quotient = dividend / divisor;
            index_t modulus  = dividend - quotient * divisor;
            return ck_tile::make_tuple(quotient, modulus);
        };

        const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1);

        return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
    }
};

} // namespace ck_tile
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp
...
@@ -44,6 +44,7 @@ struct FmhaFwdSplitKVKernel
     static constexpr bool kPadHeadDimQ      = FmhaPipeline::kPadHeadDimQ;
     static constexpr bool kPadHeadDimV      = FmhaPipeline::kPadHeadDimV;
     static constexpr auto BiasEnum          = FmhaPipeline::BiasEnum;
+    static constexpr bool kStoreLSE         = FmhaPipeline::kStoreLSE;
     static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant;
     static constexpr bool kIsPagedKV        = FmhaPipeline::Problem::kIsPagedKV;
...
@@ -66,7 +67,8 @@ struct FmhaFwdSplitKVKernel
         using bfs  = typename FmhaPipeline::BlockFmhaShape;
         using g0br = typename bfs::Gemm0BlockWarps;
         using g1br = typename bfs::Gemm1BlockWarps;
-        using gwt  = typename bfs::Gemm0WarpTile;
+        using g0wt = typename bfs::Gemm0WarpTile;
+        using g1wt = typename bfs::Gemm1WarpTile;
         #define _SS_ std::string
         #define _TS_ std::to_string
         auto pn = [&] () {
...
@@ -83,11 +85,12 @@ struct FmhaFwdSplitKVKernel
            _TS_(bfs::kN1) + "x" + _TS_(bfs::kK1) + "x" + _TS_(bfs::kQKHeaddim) + "_" +
            "r" + _TS_(g0br::at(ck_tile::number<0>{})) + "x" + _TS_(g0br::at(ck_tile::number<1>{})) + "x" + _TS_(g0br::at(ck_tile::number<2>{})) + "_" +
            "r" + _TS_(g1br::at(ck_tile::number<0>{})) + "x" + _TS_(g1br::at(ck_tile::number<1>{})) + "x" + _TS_(g1br::at(ck_tile::number<2>{})) + "_" +
-           "w" + _TS_(gwt::at(ck_tile::number<0>{})) + "x" + _TS_(gwt::at(ck_tile::number<1>{})) + "x" + _TS_(gwt::at(ck_tile::number<2>{})) + "_" +
+           "w" + _TS_(g0wt::at(ck_tile::number<0>{})) + "x" + _TS_(g0wt::at(ck_tile::number<1>{})) + "x" + _TS_(g0wt::at(ck_tile::number<2>{})) + "_" +
+           "w" + _TS_(g1wt::at(ck_tile::number<0>{})) + "x" + _TS_(g1wt::at(ck_tile::number<1>{})) + "x" + _TS_(g1wt::at(ck_tile::number<2>{})) + "_" +
            (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" +
            "v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") +
            (pn.empty() ? "" : "_" + pn) +
            (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
-           (kHasMask ? "_" + _SS_(FmhaMask::name) : "") + (kDoFp8StaticQuant ? "_squant" : "") +
-           (kIsPagedKV ? "_pagedkv" : "");
+           (kHasMask ? "_" + _SS_(FmhaMask::name) : "") + (kStoreLSE ? "_lse" : "") +
+           (kDoFp8StaticQuant ? "_squant" : "") + (kIsPagedKV ? "_pagedkv" : "");
         #undef _SS_
         #undef _TS_
         // clang-format on
...
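The name builder appends feature suffixes in a fixed order, so the new kStoreLSE flag surfaces as "_lse" between the mask and squant fields. A minimal model of that concatenation scheme, with illustrative flag values and a made-up base name:

// Minimal model of the kernel-name suffix scheme; the base string and flag
// values are illustrative, not a real generated name.
#include <iostream>
#include <string>

int main()
{
    const bool kStoreLSE = true, kDoFp8StaticQuant = false, kIsPagedKV = true;
    std::string name = "fmha_fwd_splitkv_d128_fp16_batch";
    name += (kStoreLSE ? "_lse" : "");
    name += (kDoFp8StaticQuant ? "_squant" : "");
    name += (kIsPagedKV ? "_pagedkv" : "");
    std::cout << name << '\n'; // fmha_fwd_splitkv_d128_fp16_batch_lse_pagedkv
}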
include/ck_tile/ops/fmha/kernel/fmha_fwd_tile_partitioner.hpp  (deleted, 100644 → 0)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"

namespace ck_tile {

template <typename BlockFmhaShape_>
struct FmhaFwdTilePartitioner
{
    using BlockFmhaShape = ck_tile::remove_cvref_t<BlockFmhaShape_>;

    static constexpr ck_tile::index_t kM0 = BlockFmhaShape::kM0;
    static constexpr ck_tile::index_t kN0 = BlockFmhaShape::kN0;
    static constexpr ck_tile::index_t kK0 = BlockFmhaShape::kK0;
    static constexpr ck_tile::index_t kN1 = BlockFmhaShape::kN1;
    static constexpr ck_tile::index_t kK1 = BlockFmhaShape::kK1;

    static constexpr const char* name = "shb";

    CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_,
                                                ck_tile::index_t nhead_,
                                                ck_tile::index_t seqlen_q_,
                                                ck_tile::index_t hdim_v_)
    {
        // TODO: this may need tuning
        return dim3(ck_tile::integer_divide_ceil(seqlen_q_, kM0) *
                        ck_tile::integer_divide_ceil(hdim_v_, kN1),
                    nhead_,
                    batch_size_);
    }

    CK_TILE_DEVICE auto operator()(ck_tile::index_t /*seqlen_q*/, ck_tile::index_t hdim_v)
    {
        // const index_t num_tile_m0 = seqlen_q / kM0;
        const index_t num_tile_n1 = ck_tile::integer_divide_ceil(hdim_v, kN1);

        const index_t i_block = blockIdx.x;
        const index_t i_nhead = blockIdx.y;
        const index_t i_batch = blockIdx.z;

        const auto f = [](index_t dividend, index_t divisor) {
            index_t quotient = dividend / divisor;
            index_t modulus  = dividend - quotient * divisor;
            return ck_tile::make_tuple(quotient, modulus);
        };

        const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1);

        return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
    }
};

template <typename BlockFmhaShape_>
using FmhaFwdTilePartitioner_SHB = FmhaFwdTilePartitioner<BlockFmhaShape_>;

template <typename BlockFmhaShape_>
struct FmhaFwdTilePartitioner_HBS
{
    using BlockFmhaShape = ck_tile::remove_cvref_t<BlockFmhaShape_>;

    static constexpr ck_tile::index_t kM0 = BlockFmhaShape::kM0;
    static constexpr ck_tile::index_t kN0 = BlockFmhaShape::kN0;
    static constexpr ck_tile::index_t kK0 = BlockFmhaShape::kK0;
    static constexpr ck_tile::index_t kN1 = BlockFmhaShape::kN1;
    static constexpr ck_tile::index_t kK1 = BlockFmhaShape::kK1;

    static constexpr const char* name = "hbs";

    CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_,
                                                ck_tile::index_t nhead_,
                                                ck_tile::index_t seqlen_q_,
                                                ck_tile::index_t hdim_v_)
    {
        // TODO: this may need tuning
        return dim3(nhead_,
                    batch_size_,
                    ck_tile::integer_divide_ceil(seqlen_q_, kM0) *
                        ck_tile::integer_divide_ceil(hdim_v_, kN1));
    }

    CK_TILE_DEVICE auto operator()(ck_tile::index_t /*seqlen_q*/, ck_tile::index_t hdim_v)
    {
        // const index_t num_tile_m0 = seqlen_q / kM0;
        const index_t num_tile_n1 = ck_tile::integer_divide_ceil(hdim_v, kN1);

        const index_t i_block = blockIdx.z;
        const index_t i_nhead = blockIdx.x;
        const index_t i_batch = blockIdx.y;

        const auto f = [](index_t dividend, index_t divisor) {
            index_t quotient = dividend / divisor;
            index_t modulus  = dividend - quotient * divisor;
            return ck_tile::make_tuple(quotient, modulus);
        };

        const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1);

        return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
    }
};

} // namespace ck_tile
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp
...
@@ -53,6 +53,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline
     using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>;
     using ODataType    = remove_cvref_t<typename Problem::ODataType>;

+    static constexpr index_t kNumWarps  = Problem::kNumWarps;
     static constexpr index_t kBlockSize = Problem::kBlockSize;
     static constexpr index_t kHeadDimV  = Problem::kHeadDimV;
...
@@ -117,7 +118,6 @@ struct BlockFmhaFwdSplitKVCombinePipeline
                const LSEElementFunction& lse_element_func,
                const OaccElementFunction& o_acc_element_func,
                index_t num_splits,
-               index_t seqlen_q,
                void* smem_ptr) const
     {
         // lse_acc tile in LDS
...
@@ -143,11 +143,12 @@ struct BlockFmhaFwdSplitKVCombinePipeline
         // copy lse_acc tile (shape=[kMaxSplits, kM0]) to LDS (shape=[kMaxSplits, kM0]).
         auto lse_acc_tile = load_tile(lse_acc_dram_window);
         store_tile(lse_acc_lds_write_window, lse_acc_tile);
-        block_sync_lds();

         auto lse_accum = make_static_distributed_tensor<LSEDataType>(
             Policy::template MakeLSEaccRegTileDistribution<Problem>());

+        __builtin_amdgcn_sched_barrier(0);
+        block_sync_lds();
+
         // copy LDS (shape=[kM0, kMaxSplits]) to lse_accum (shape=[kM0, kMaxSplits])
         // and fill up -INF values outside the [kM0, num_splits] region.
         {
...
@@ -264,46 +265,94 @@ struct BlockFmhaFwdSplitKVCombinePipeline
                 }
             });
         }

-        block_sync_lds();
-
         if constexpr(kStoreLSE)
         {
             store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse_logsum));
         }

-        auto o_acc_dist = Policy::template MakeOaccDramTileDistribution<Problem>();
-        auto o_acc_dram_window =
+        auto o_acc_4_dist = Policy::template MakeOacc4DramTileDistribution<Problem>();
+        auto o_acc_4_dram_window =
             make_tile_window(o_acc_dram_block_window_tmp.get_bottom_tensor_view(),
                              o_acc_dram_block_window_tmp.get_window_lengths(),
                              o_acc_dram_block_window_tmp.get_window_origin(),
-                             o_acc_dist);
+                             o_acc_4_dist);

-        auto o_acc = make_static_distributed_tensor<OaccDataType>(o_acc_dist);
-        clear_tile(o_acc);
-
-        const index_t padded_seqlen_q = integer_divide_ceil(seqlen_q, kM0) * kM0;
-
-        for(index_t i_split = 0; i_split < num_splits; ++i_split)
+        // shape=[4 * KM0, kN1]
+        auto o_acc_4 = make_static_distributed_tensor<OaccDataType>(o_acc_4_dist);
+        clear_tile(o_acc_4);
+
+        const index_t padded_num_splits = integer_divide_ceil(num_splits, kNumWarps) * kNumWarps;
+
+        __builtin_amdgcn_sched_barrier(0);
+        block_sync_lds();
+
+        // each warp handles a [KM0, kN1] tile
+        for(index_t split_start = 0; split_start < padded_num_splits; split_start += kNumWarps)
         {
-            auto o_tile = load_tile(o_acc_dram_window);
+            auto o_tile = load_tile(o_acc_4_dram_window);
+
+            const index_t i_split   = split_start + get_warp_id();
+            const index_t row_start = kM0 * get_warp_id();
             {
-                constexpr auto spans = decltype(o_acc)::get_distributed_spans();
+                constexpr auto spans = decltype(o_acc_4)::get_distributed_spans();
                 sweep_tile_span(spans[number<0>{}], [&](auto idx0) {
                     sweep_tile_span(spans[number<1>{}], [&](auto idx1) {
                         constexpr auto i_j_idx = make_tuple(idx0, idx1);
                         const auto x_indices = get_x_indices_from_distributed_indices(
-                            o_acc.get_tile_distribution(), i_j_idx);
+                            o_acc_4.get_tile_distribution(), i_j_idx);

                         const auto row = x_indices.at(number<0>{});
-                        const LSEDataType lse_scale = lse_acc_lds(row, i_split);
-                        o_acc(i_j_idx) += lse_scale * o_tile(i_j_idx);
+                        const LSEDataType lse_scale = lse_acc_lds(row - row_start, i_split);
+                        o_acc_4(i_j_idx) += lse_scale * o_tile(i_j_idx);
                     });
                 });
             }
-            move_tile_window(o_acc_dram_window, {padded_seqlen_q, 0});
+            move_tile_window(o_acc_4_dram_window, {kNumWarps * kM0, 0});
         }

+        // 4 o_acc tiles in LDS. shape=[4 * kM0, kN1]
+        OaccDataType* o_acc_4_lds_ptr = static_cast<OaccDataType*>(static_cast<void*>(
+            static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeLSEacc<Problem>()));
+
+        {
+            auto o_acc_4_lds_window = [&]() {
+                auto desc = Policy::template MakeOacc4LdsBlockDescriptor<Problem>();
+                auto view = make_tensor_view<address_space_enum::lds>(o_acc_4_lds_ptr, desc);
+                return make_tile_window(view, desc.get_lengths(), {0, 0});
+            }();
+            store_tile(o_acc_4_lds_window, o_acc_4);
+        }
+
+        auto o_acc_dist = Policy::template MakeOaccDramTileDistribution<Problem>();
+        auto o_acc_4_lds_window = [&]() {
+            auto desc = Policy::template MakeOacc4LdsBlockDescriptor<Problem>();
+            auto view = make_tensor_view<address_space_enum::lds>(o_acc_4_lds_ptr, desc);
+            return make_tile_window(view, desc.get_lengths(), {0, 0}, o_acc_dist);
+        }();
+
+        auto o_acc = make_static_distributed_tensor<OaccDataType>(o_acc_dist);
+        clear_tile(o_acc);
+
+        __builtin_amdgcn_sched_barrier(0);
+        block_sync_lds();
+
+        static_for<0, kNumWarps, 1>{}([&](auto) {
+            auto o_acc_in = load_tile(o_acc_4_lds_window);
+            {
+                constexpr auto spans = decltype(o_acc)::get_distributed_spans();
+                sweep_tile_span(spans[number<0>{}], [&](auto idx0) {
+                    sweep_tile_span(spans[number<1>{}], [&](auto idx1) {
+                        constexpr auto i_j_idx = make_tuple(idx0, idx1);
+                        o_acc(i_j_idx) += o_acc_in(i_j_idx);
+                    });
+                });
+            }
+            move_tile_window(o_acc_4_lds_window, {kM0, 0});
+        });
+
         o_acc = tile_elementwise_in(o_acc_element_func, o_acc);

         return o_acc;
...
@@ -316,7 +365,6 @@ struct BlockFmhaFwdSplitKVCombinePipeline
                const OaccDramBlockWindow& o_acc_dram_block_window,
                LSEDramBlockWindow& lse_dram_block_window,
                index_t num_splits,
-               index_t seqlen_q,
                void* smem_ptr) const
     {
         return operator()(lse_acc_dram_block_window,
...
@@ -325,7 +373,6 @@ struct BlockFmhaFwdSplitKVCombinePipeline
                           identity{},
                           identity{},
                           num_splits,
-                          seqlen_q,
                           smem_ptr);
     }
 };
...
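The reworked combine loop is a two-stage reduction: each warp accumulates its own subset of splits into a private [kM0, kN1] tile, the kNumWarps partial tiles meet in LDS, and a final pass sums them. A scalar CPU model of that control flow (one float stands in for a whole tile; the padded tail contributes zero here, whereas the real kernel handles it through the padded tensor view):

// CPU model of the warp-parallel split reduction; not device code.
#include <cassert>

int main()
{
    const int kNumWarps = 4, num_splits = 10;
    const int padded_num_splits = (num_splits + kNumWarps - 1) / kNumWarps * kNumWarps;
    float split_value[16] = {}; // per-split contributions (zero-padded tail)
    for(int s = 0; s < num_splits; ++s) split_value[s] = 1.0f;

    float lds[4] = {}; // per-warp partial accumulators (modeled LDS)
    for(int start = 0; start < padded_num_splits; start += kNumWarps)
        for(int w = 0; w < kNumWarps; ++w) // warps run concurrently on the GPU
            lds[w] += split_value[start + w];

    float o_acc = 0.0f;
    for(int w = 0; w < kNumWarps; ++w) o_acc += lds[w]; // final cross-warp sum
    assert(o_acc == static_cast<float>(num_splits));
}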
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp
...
@@ -10,23 +10,38 @@ namespace ck_tile {

 struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
 {
-    template <index_t BlockSize, index_t M, index_t N, typename DataType>
+    template <index_t NumWarps, index_t M, index_t N, typename DataType>
+    CK_TILE_HOST_DEVICE static constexpr auto GetMaxNumWarpsForTile()
+    {
+        static_assert(NumWarps == 1 || NumWarps == 2 || NumWarps == 4);
+
+        constexpr index_t ElemPerThread = (M * N) / (NumWarps * get_warp_size());
+        if constexpr(0 < ElemPerThread)
+        {
+            return NumWarps;
+        }
+        else
+        {
+            // try dividing tile by smaller # of warps
+            return GetMaxNumWarpsForTile<NumWarps / 2, M, N, DataType>();
+        }
+    }
+
+    template <index_t NumWarps, index_t M, index_t N, typename DataType>
     CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeForTile()
     {
-        constexpr index_t PixelsPerThread = (M * N) / BlockSize;
-        static_assert(0 < PixelsPerThread);
+        constexpr index_t MaxNumWarps   = GetMaxNumWarpsForTile<NumWarps, M, N, DataType>();
+        constexpr index_t ElemPerThread = (M * N) / (MaxNumWarps * get_warp_size());

         constexpr index_t MaxNPerThread = 16 / sizeof(DataType);
-        constexpr index_t NPerThread    = min(MaxNPerThread, PixelsPerThread);
-
-        return NPerThread;
+        return min(MaxNPerThread, ElemPerThread);
     }

     // alignment for dram lse tile (shape=[kMaxSplits, kM0])
     template <typename Problem>
     CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentLSE()
     {
-        return GetVectorSizeForTile<Problem::kBlockSize,
+        return GetVectorSizeForTile<Problem::kNumWarps,
                                     Problem::kMaxSplits,
                                     Problem::kM0,
                                     typename Problem::LSEDataType>();
...
@@ -56,40 +71,54 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
     }

     template <typename Problem>
-    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
+    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeLSEacc()
     {
         return sizeof(typename Problem::LSEDataType) *
                MakeLSEaccLdsBlockDescriptor<Problem>().get_element_space_size();
     }

+    template <typename Problem>
+    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeOacc4()
+    {
+        return sizeof(typename Problem::OaccDataType) *
+               MakeOacc4LdsBlockDescriptor<Problem>().get_element_space_size();
+    }
+
+    template <typename Problem>
+    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
+    {
+        return GetSmemSizeLSEacc<Problem>() + GetSmemSizeOacc4<Problem>();
+    }
+
     // shape=[kMaxSplits, kM0]
     template <typename Problem>
     CK_TILE_HOST_DEVICE static constexpr auto MakeLSEaccDramTileDistribution()
     {
         using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;

-        constexpr index_t kBlockSize = Problem::kBlockSize;
-        constexpr index_t kNumWarps  = Problem::kNumWarps;
-
         constexpr index_t kMPerBlock = Problem::kMaxSplits;
         constexpr index_t kNPerBlock = Problem::kM0;

+        constexpr index_t MaxNumWarps =
+            GetMaxNumWarpsForTile<Problem::kNumWarps, kMPerBlock, kNPerBlock, LSEDataType>();
+        constexpr index_t Replicate = Problem::kNumWarps / MaxNumWarps;
+
         constexpr index_t NPerThread =
-            GetVectorSizeForTile<kBlockSize, kMPerBlock, kNPerBlock, LSEDataType>();
+            GetVectorSizeForTile<MaxNumWarps, kMPerBlock, kNPerBlock, LSEDataType>();
         constexpr index_t NThreads = kNPerBlock / NPerThread;

         constexpr index_t MThreadsPerWarp = get_warp_size() / NThreads;
-        constexpr index_t MPerThread      = kMPerBlock / (kNumWarps * MThreadsPerWarp);
+        constexpr index_t MPerThread      = kMPerBlock / (MaxNumWarps * MThreadsPerWarp);

+        static_assert(MPerThread * MaxNumWarps * MThreadsPerWarp == kMPerBlock);
         static_assert(NThreads * NPerThread == kNPerBlock);
-        static_assert(MPerThread * kNumWarps * MThreadsPerWarp == kMPerBlock);

         return make_static_tile_distribution(
-            tile_distribution_encoding<sequence<1>,
-                                       tuple<sequence<MPerThread, kNumWarps, MThreadsPerWarp>,
+            tile_distribution_encoding<sequence<Replicate>,
+                                       tuple<sequence<MPerThread, MaxNumWarps, MThreadsPerWarp>,
                                              sequence<NThreads, NPerThread>>,
-                                       tuple<sequence<1>, sequence<1, 2>>,
-                                       tuple<sequence<1>, sequence<2, 0>>,
+                                       tuple<sequence<0, 1>, sequence<1, 2>>,
+                                       tuple<sequence<0, 1>, sequence<2, 0>>,
                                        sequence<1, 2>,
                                        sequence<0, 1>>{});
     }
...
@@ -100,17 +129,15 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
     {
         using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;

-        constexpr index_t kBlockSize = Problem::kBlockSize;
-
-        constexpr index_t kMPerBlock = Problem::kMaxSplits;
-        constexpr index_t kNPerBlock = Problem::kM0;
+        constexpr index_t kMPerBlock = Problem::kM0;
+        constexpr index_t kNPerBlock = Problem::kMaxSplits;

         constexpr index_t NPack =
-            GetVectorSizeForTile<kBlockSize, kMPerBlock, kNPerBlock, LSEDataType>();
+            GetVectorSizeForTile<Problem::kNumWarps, kMPerBlock, kNPerBlock, LSEDataType>();

         constexpr auto lse_acc_lds_block_desc_0 = make_naive_tensor_descriptor(
             make_tuple(number<kNPerBlock / NPack>{}, number<kMPerBlock>{}, number<NPack>{}),
             make_tuple(number<(kMPerBlock + 1) * NPack>{}, number<NPack>{}, number<1>{}),
             number<8>{},
             number<NPack>{},
             number<1>{});

         constexpr auto lse_acc_lds_block_desc = transform_tensor_descriptor(
...
@@ -129,17 +156,15 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
     {
         using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;

-        constexpr index_t kBlockSize = Problem::kBlockSize;
-
-        constexpr index_t kMPerBlock = Problem::kMaxSplits;
-        constexpr index_t kNPerBlock = Problem::kM0;
+        constexpr index_t kMPerBlock = Problem::kM0;
+        constexpr index_t kNPerBlock = Problem::kMaxSplits;

         constexpr index_t NPack =
-            GetVectorSizeForTile<kBlockSize, kMPerBlock, kNPerBlock, LSEDataType>();
+            GetVectorSizeForTile<Problem::kNumWarps, kMPerBlock, kNPerBlock, LSEDataType>();

         constexpr auto lse_acc_lds_block_desc_0 = make_naive_tensor_descriptor(
             make_tuple(number<kNPerBlock / NPack>{}, number<kMPerBlock>{}, number<NPack>{}),
             make_tuple(number<(kMPerBlock + 1) * NPack>{}, number<NPack>{}, number<1>{}),
             number<8>{},
             number<NPack>{},
             number<1>{});

         constexpr auto lse_acc_t_lds_block_desc = transform_tensor_descriptor(
...
@@ -152,33 +177,86 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
         return lse_acc_t_lds_block_desc;
     }

+    // 3d + padding, shape=[4 * kM0, kN1]
+    template <typename Problem>
+    CK_TILE_HOST_DEVICE static constexpr auto MakeOacc4LdsBlockDescriptor()
+    {
+        using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
+
+        constexpr index_t kMPerBlock = 4 * Problem::kM0;
+        constexpr index_t kNPerBlock = Problem::kN1;
+
+        constexpr index_t NPack =
+            GetVectorSizeForTile<Problem::kNumWarps, kMPerBlock, kNPerBlock, LSEDataType>();
+
+        constexpr auto o_acc_lds_block_desc_0 = make_naive_tensor_descriptor(
+            make_tuple(number<kNPerBlock / NPack>{}, number<kMPerBlock>{}, number<NPack>{}),
+            make_tuple(number<(kMPerBlock + 1) * NPack>{}, number<NPack>{}, number<1>{}),
+            number<8>{},
+            number<1>{});
+
+        constexpr auto o_acc_t_lds_block_desc = transform_tensor_descriptor(
+            o_acc_lds_block_desc_0,
+            make_tuple(make_pass_through_transform(kMPerBlock),
+                       make_merge_transform(make_tuple(kNPerBlock / NPack, NPack))),
+            make_tuple(sequence<1>{}, sequence<0, 2>{}),
+            make_tuple(sequence<1>{}, sequence<0>{}));
+
+        return o_acc_t_lds_block_desc;
+    }
+
+    // shape=[kM0, kMaxSplits]
     template <typename Problem>
     CK_TILE_HOST_DEVICE static constexpr auto MakeLSEaccRegTileDistribution()
     {
-        constexpr index_t kBlockSize = Problem::kBlockSize;
-
         constexpr index_t kMPerBlock = Problem::kM0;
         constexpr index_t kNPerBlock = Problem::kMaxSplits;

-        constexpr index_t NThreads   = 4;
-        constexpr index_t NPerThread = kNPerBlock / NThreads;
+        constexpr index_t MaxNThreads = 8;
+        constexpr index_t NThreads    = min(kNPerBlock, MaxNThreads);
+        constexpr index_t NPerThread  = kNPerBlock / NThreads;

-        constexpr index_t MThreads   = kBlockSize / NThreads;
-        constexpr index_t MPerThread = kMPerBlock / MThreads;
+        constexpr index_t MPerThread = 1;
+        constexpr index_t MThreads   = kMPerBlock / MPerThread;

-        constexpr index_t MWarps         = kBlockSize / get_warp_size();
         constexpr index_t MThreadPerWarp = get_warp_size() / NThreads;

+        constexpr index_t MaxNumWarps = (MThreads * NThreads) / get_warp_size();
+        constexpr index_t Replicate   = Problem::kNumWarps / MaxNumWarps;
+
+        static_assert(MaxNumWarps * MThreadPerWarp * MPerThread == kMPerBlock);
         static_assert(NThreads * NPerThread == kNPerBlock);
-        static_assert(MWarps * MThreadPerWarp * MPerThread == kMPerBlock);

         return make_static_tile_distribution(
-            tile_distribution_encoding<sequence<1>,
-                                       tuple<sequence<MWarps, MThreadPerWarp, MPerThread>,
-                                             sequence<NThreads, NPerThread>>,
-                                       tuple<sequence<1>, sequence<2, 1>>,
-                                       tuple<sequence<0>, sequence<0, 1>>,
-                                       sequence<1, 2>,
-                                       sequence<2, 1>>{});
+            tile_distribution_encoding<sequence<Replicate>,
+                                       tuple<sequence<MaxNumWarps, MThreadPerWarp, MPerThread>,
+                                             sequence<NThreads, NPerThread>>,
+                                       tuple<sequence<0, 1>, sequence<2, 1>>,
+                                       tuple<sequence<0, 0>, sequence<0, 1>>,
+                                       sequence<1, 2>,
+                                       sequence<2, 1>>{});
     }

+    // similar to MakeOaccDramTileDistribution(), but duplicate same 1-warp encoding 4 times on M
+    // direction
+    template <typename Problem>
+    CK_TILE_HOST_DEVICE static constexpr auto MakeOacc4DramTileDistribution()
+    {
+        constexpr index_t kMPerBlock = Problem::kM0; // real kMPerBlock we want is (4 * kM0)
+        constexpr index_t kNPerBlock = Problem::kN1;
+
+        static_assert(get_warp_size() <= kMPerBlock * kNPerBlock);
+
+        constexpr index_t M1 = 1; // compose encoding base on 1 warp
+        constexpr index_t M2 = min(kMPerBlock / M1, get_warp_size());
+        constexpr index_t N0 = get_warp_size() / M2;
+        constexpr index_t N1 = kNPerBlock / N0;
+        constexpr index_t M0 = kMPerBlock / (M2 * M1);
+
+        return make_static_tile_distribution(
+            tile_distribution_encoding<sequence<1>,
+                                       tuple<sequence<4, M0, M1, M2>, sequence<N0, N1>>,
+                                       tuple<sequence<1, 1>, sequence<1, 2>>,
+                                       tuple<sequence<0, 2>, sequence<3, 0>>,
+                                       sequence<1, 2>,
+                                       sequence<1, 1>>{});
+    }
+
     template <typename Problem>
...
@@ -187,6 +265,7 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
         constexpr index_t kBlockSize = Problem::kBlockSize;
         constexpr index_t kMPerBlock = Problem::kM0;
         constexpr index_t kNPerBlock = Problem::kN1;
+        static_assert(kBlockSize <= kMPerBlock * kNPerBlock);

         constexpr index_t M1 = kBlockSize / get_warp_size();
         constexpr index_t M2 = min(kMPerBlock / M1, get_warp_size());
...
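GetMaxNumWarpsForTile() guards against tiles that are too small to give every lane of every warp at least one element: when (M * N) / (NumWarps * warp_size) rounds down to zero, it recursively retries with half the warps. A standalone constexpr check of that fallback, assuming a 64-lane wavefront:

// Standalone constexpr model of the warp-count fallback; warp size 64 is an
// assumption for illustration, and the DataType parameter is dropped since
// the recursion does not use it.
#include <iostream>

constexpr int get_warp_size() { return 64; }

template <int NumWarps, int M, int N>
constexpr int GetMaxNumWarpsForTile()
{
    static_assert(NumWarps == 1 || NumWarps == 2 || NumWarps == 4);
    if constexpr(0 < (M * N) / (NumWarps * get_warp_size()))
        return NumWarps;
    else
        return GetMaxNumWarpsForTile<NumWarps / 2, M, N>(); // halve and retry
}

int main()
{
    // 16x8 = 128 elements cannot feed 4 warps (256 lanes) -> falls back to 2
    std::cout << GetMaxNumWarpsForTile<4, 16, 8>() << '\n';  // 2
    std::cout << GetMaxNumWarpsForTile<4, 64, 64>() << '\n'; // 4
}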
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp
0 → 100644
View file @
c881136b
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs_default_policy.hpp"
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
namespace
ck_tile
{
// This pipeline is qkv all located in LDS
template
<
typename
Problem_
,
typename
Policy_
=
BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVSDefaultPolicy
>
struct
BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS
{
using
Problem
=
remove_cvref_t
<
Problem_
>
;
using
Policy
=
remove_cvref_t
<
Policy_
>
;
using
QDataType
=
remove_cvref_t
<
typename
Problem
::
QDataType
>
;
using
KDataType
=
remove_cvref_t
<
typename
Problem
::
KDataType
>
;
using
VDataType
=
remove_cvref_t
<
typename
Problem
::
VDataType
>
;
using
SaccDataType
=
remove_cvref_t
<
typename
Problem
::
SaccDataType
>
;
using
SMPLComputeDataType
=
remove_cvref_t
<
typename
Problem
::
SMPLComputeDataType
>
;
using
BiasDataType
=
remove_cvref_t
<
typename
Problem
::
BiasDataType
>
;
using
LSEDataType
=
remove_cvref_t
<
typename
Problem
::
LSEDataType
>
;
using
PDataType
=
remove_cvref_t
<
typename
Problem
::
PDataType
>
;
using
OaccDataType
=
remove_cvref_t
<
typename
Problem
::
OaccDataType
>
;
using
ODataType
=
remove_cvref_t
<
typename
Problem
::
ODataType
>
;
using
FmhaMask
=
remove_cvref_t
<
typename
Problem
::
FmhaMask
>
;
using
BlockFmhaShape
=
remove_cvref_t
<
typename
Problem
::
BlockFmhaShape
>
;
using
VLayout
=
remove_cvref_t
<
typename
BlockFmhaShape
::
VLayout
>
;
static
constexpr
bool
kQLoadOnce
=
true
;
// if q_tile load whole block length (hdim) at once
static_assert
(
kQLoadOnce
==
Policy
::
QLoadOnce
);
static
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
static
constexpr
index_t
kM0
=
BlockFmhaShape
::
kM0
;
static
constexpr
index_t
kN0
=
BlockFmhaShape
::
kN0
;
static
constexpr
index_t
kK0
=
BlockFmhaShape
::
kK0
;
static
constexpr
index_t
kN1
=
BlockFmhaShape
::
kN1
;
static
constexpr
index_t
kK1
=
BlockFmhaShape
::
kK1
;
static
constexpr
index_t
kQKHeaddim
=
BlockFmhaShape
::
kQKHeaddim
;
static
constexpr
index_t
kSubQKHeaddim
=
BlockFmhaShape
::
kSubQKHeaddim
;
static
constexpr
bool
kIsGroupMode
=
Problem
::
kIsGroupMode
;
static
constexpr
bool
kPadSeqLenQ
=
Problem
::
kPadSeqLenQ
;
static
constexpr
bool
kPadSeqLenK
=
Problem
::
kPadSeqLenK
;
static
constexpr
bool
kPadHeadDimQ
=
Problem
::
kPadHeadDimQ
;
static
constexpr
bool
kPadHeadDimV
=
Problem
::
kPadHeadDimV
;
static
constexpr
auto
BiasEnum
=
Problem
::
BiasEnum
;
static
constexpr
bool
kStoreLSE
=
Problem
::
kStoreLSE
;
static
constexpr
bool
kIsPagedKV
=
Problem
::
kIsPagedKV
;
static
constexpr
bool
kHasUnevenSplits
=
Problem
::
kHasUnevenSplits
;
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
// ... together with tensor distribution. tensor dist should able to overwrite this
static
constexpr
index_t
kAlignmentQ
=
kPadHeadDimQ
?
1
:
Policy
::
template
GetAlignmentQ
<
Problem
>();
static
constexpr
index_t
kAlignmentK
=
kPadHeadDimQ
?
1
:
Policy
::
template
GetAlignmentK
<
Problem
>();
static
constexpr
index_t
kAlignmentV
=
[]()
{
if
constexpr
(
std
::
is_same_v
<
VLayout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
)
return
kPadHeadDimV
?
1
:
Policy
::
template
GetAlignmentV
<
Problem
>();
else
return
kPadSeqLenK
?
1
:
Policy
::
template
GetAlignmentV
<
Problem
>();
}();
static
constexpr
index_t
kAlignmentOacc
=
kPadHeadDimV
?
1
:
Policy
::
template
GetAlignmentOacc
<
Problem
>();
static
constexpr
index_t
kAlignmentBias
=
kPadSeqLenK
?
1
:
Policy
::
template
GetAlignmentBias
<
Problem
>();
static
constexpr
index_t
kBlockPerCu
=
[]()
{
if
constexpr
(
Problem
::
kBlockPerCu
!=
-
1
)
return
Problem
::
kBlockPerCu
;
else
{
if
constexpr
(
kQKHeaddim
<=
32
)
{
return
2
;
}
else
if
constexpr
(
kQKHeaddim
<=
64
)
{
return
3
;
}
else
if
constexpr
(
kQKHeaddim
<=
128
)
{
if
constexpr
(
BiasEnum
==
BlockAttentionBiasEnum
::
ELEMENTWISE_BIAS
)
return
1
;
else
return
2
;
}
else
if
constexpr
(
kQKHeaddim
<=
256
)
{
return
1
;
}
}
}();
static
constexpr
const
char
*
name
=
"qr_nwarp_sshuffle"
;
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSize
()
{
return
Policy
::
template
GetSmemSize
<
Problem
>();
}
template
<
typename
QDramBlockWindowTmp
,
typename
KDramBlockWindowLengths
,
typename
KPageBlockNavigator
,
typename
VDramBlockWindowLengths
,
typename
VPageBlockNavigator
,
typename
BiasDramBlockWindowTmp
,
typename
LSEaccDramBlockWindowTmp
,
typename
QElementFunction
,
typename
KElementFunction
,
typename
VElementFunction
,
typename
BiasElementFunction
,
typename
LSEaccElementFunction
,
typename
SAccElementFunction
,
typename
PComputeElementFunction
,
typename
OAccElementFunction
,
typename
PositionEncoding
>
CK_TILE_HOST_DEVICE
auto
operator
()(
const
QDramBlockWindowTmp
&
q_dram_block_window_tmp
,
// M0*K0 tile
const
QElementFunction
&
q_element_func
,
const
KDramBlockWindowLengths
&
k_dram_block_window_lengths
,
// N0*K0 tile
const
KPageBlockNavigator
&
k_page_block_navigator
,
const
KElementFunction
&
k_element_func
,
const
VDramBlockWindowLengths
&
v_dram_block_window_lengths
,
// N1*K1 tile
const
VPageBlockNavigator
&
v_page_block_navigator
,
const
VElementFunction
&
v_element_func
,
const
BiasDramBlockWindowTmp
&
bias_dram_block_window_tmp
,
// M0*N0 tile
const
BiasElementFunction
&
bias_element_func
,
LSEaccDramBlockWindowTmp
&
lse_acc_dram_window_tmp
,
// M0*1 tile
const
LSEaccElementFunction
&
lse_acc_element_func
,
const
SAccElementFunction
&
s_acc_element_func
,
const
PComputeElementFunction
&
p_compute_element_func
,
const
OAccElementFunction
&
o_acc_element_func
,
index_t
num_splits
,
index_t
i_split
,
FmhaMask
mask
,
PositionEncoding
position_encoding
,
float
scale_s
,
index_t
kv_l2p_offset
,
// logical-to-physical offset of seqlen_k coordinate
void
*
smem_ptr
)
const
{
        static_assert(
            std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
                std::is_same_v<KDataType,
                               remove_cvref_t<typename KPageBlockNavigator::DataType>> &&
                std::is_same_v<VDataType,
                               remove_cvref_t<typename VPageBlockNavigator::DataType>>,
            "wrong!");
        static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
                          kSubQKHeaddim == QDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
                          kN0 == KDramBlockWindowLengths{}[number<0>{}] &&
                          kK0 == KDramBlockWindowLengths{}[number<1>{}] &&
                          kN1 == VDramBlockWindowLengths{}[number<0>{}] &&
                          kK1 == VDramBlockWindowLengths{}[number<1>{}] &&
                          kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
                          kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
                      "wrong!");
        // Q tile in LDS
        QDataType* q_lds_ptr =
            static_cast<QDataType*>(static_cast<void*>(static_cast<char*>(smem_ptr)));
        auto q_lds = make_tensor_view<address_space_enum::lds>(
            q_lds_ptr, Policy::template MakeQLdsBlockDescriptor<Problem>());
        // K tile in LDS
        KDataType* k_lds_ptr =
            static_cast<KDataType*>(static_cast<void*>(static_cast<char*>(smem_ptr)));
        auto k_lds = make_tensor_view<address_space_enum::lds>(
            k_lds_ptr, Policy::template MakeKLdsBlockDescriptor<Problem>());
        auto k_lds_window =
            make_tile_window(k_lds, make_tuple(number<kN0>{}, number<kK0>{}), {0, 0});
        // V tile in LDS
        auto v_lds = make_tensor_view<address_space_enum::lds>(
            reinterpret_cast<VDataType*>(static_cast<char*>(smem_ptr) +
                                         max(Policy::template GetSmemSizeQ<Problem>(),
                                             Policy::template GetSmemSizeK<Problem>())),
            Policy::template MakeVLdsBlockDescriptor<Problem>());
        auto v_lds_window = make_tile_window(
            v_lds, Policy::template MakeVLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
        // S tile in LDS
        auto s_lds = make_tensor_view<address_space_enum::lds>(
            reinterpret_cast<SaccDataType*>(reinterpret_cast<char*>(smem_ptr) +
                                            max(Policy::template GetSmemSizeQ<Problem>(),
                                                Policy::template GetSmemSizeK<Problem>())),
            Policy::template MakeSLdsBlockDescriptor<Problem>());
        auto s_write_lds_window = make_tile_window(
            s_lds, Policy::template MakeSLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
        auto s_read_lds_window =
            make_tile_window(s_lds,
                             Policy::template MakeSLdsBlockDescriptor<Problem>().get_lengths(),
                             {0, 0},
                             Policy::template MakeSRegTileDistribution<Problem>());
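        // LDS layout note: q_lds and k_lds both start at smem_ptr, while v_lds and s_lds both
        // start at smem_ptr + max(smem(Q), smem(K)). The aliasing appears intentional: Q is
        // copied into registers before K streaming begins, and within one iteration the S
        // shuffle finishes (guarded by block_sync_lds) before V tiles are written to the same
        // region.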
        // Block GEMM
        constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
        constexpr auto gemm_1 = Policy::template GetKVBlockGemm<Problem>();
        auto q_dram_window =
            make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
                             q_dram_block_window_tmp.get_window_lengths(),
                             q_dram_block_window_tmp.get_window_origin(),
                             Policy::template MakeQDramTileDistribution<Problem>());

        // load Q here, will store Q into LDS to maximize throughput
        auto origin_q = load_tile(q_dram_window);

        using SaccBlockTileType = decltype(gemm_0.MakeCBlockTile());
        auto s_acc              = SaccBlockTileType{};
        // reduction functions for softmax
        const auto f_max = [](auto e0, auto e1) { return max(e0, e1); };
        const auto f_sum = [](auto e0, auto e1) { return e0 + e1; };

        using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile());
        auto o_acc              = OaccBlockTileType{};
        // infer Sacc, S, P, M, L, Oacc type
        using SBlockTileType = decltype(cast_tile<SMPLComputeDataType>(o_acc));

        using MLBlockTileType = decltype(block_tile_reduce<SMPLComputeDataType>(
            SBlockTileType{}, sequence<1>{}, f_max, SMPLComputeDataType{0}));
        // init M, L
        auto m = MLBlockTileType{};
        auto l = MLBlockTileType{};

        clear_tile(o_acc);
        set_tile(m, -numeric<SMPLComputeDataType>::infinity());
        clear_tile(l);
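        // Online-softmax running state (per row of the M0 tile):
        //   m     : running row maximum of the logits, initialized to -inf
        //   l     : running row sum of exp(logits - m)
        //   o_acc : unnormalized output accumulator; the final O is o_acc / l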
        const auto q_origin = q_dram_window.get_window_origin();
        const auto [logical_seqlen_k_start, logical_seqlen_k_end] = mask.GetTileRangeAlongX(
            q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{}, num_splits, i_split);
        // check early exit if there is no work to do
        if constexpr(FmhaMask::IsMasking || kPadSeqLenK || kHasUnevenSplits)
        {
            const index_t logical_num_total_loop =
                integer_divide_ceil(logical_seqlen_k_end - logical_seqlen_k_start, kN0);
            if(logical_num_total_loop <= 0)
            {
                if constexpr(kStoreLSE)
                {
                    auto lse_acc =
                        make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());

                    set_tile(lse_acc, -numeric<SMPLComputeDataType>::infinity());

                    if(get_thread_local_1d_id() < kM0)
                    {
                        store_tile(lse_acc_dram_window_tmp,
                                   tile_elementwise_in(lse_acc_element_func, lse_acc));
                    }
                }

                // Note: o_acc is already all-cleared here, so return it directly.
                // Note: Q was loaded but without a fence; ignore it.
                return o_acc;
            }
        }
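        // Writing lse_acc = -inf for an empty split makes this split contribute zero weight
        // when the per-split partial results are merged (exp(-inf - max) == 0); presumably
        // this is what the split-kv combine pass relies on.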
        const index_t physical_seqlen_k_start = logical_seqlen_k_start + kv_l2p_offset;
        const index_t physical_seqlen_k_end   = logical_seqlen_k_end + kv_l2p_offset;

        // make sure the first tile is completely located in a page-block (the page-block size
        // should be divisible by kN0)
        // relationship between the *_start variables: aligned_physical_seqlen_k_start <=
        // physical_seqlen_k_start, logical_seqlen_k_start <= physical_seqlen_k_start
        const index_t aligned_physical_seqlen_k_start =
            [&, physical_seqlen_k_start_ = physical_seqlen_k_start] {
                if constexpr(kIsPagedKV)
                {
                    return kN0 * integer_divide_floor(physical_seqlen_k_start_, kN0);
                }
                else
                {
                    return physical_seqlen_k_start_;
                }
            }();
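        // Example (assuming kN0 = 64 purely for illustration): with a paged KV cache and
        // physical_seqlen_k_start = 100, the window start is floored to 64 so that the first
        // kN0-wide tile never straddles a page-block boundary; columns 64..99 are masked off
        // later by the kHasUnevenSplits check inside the main loop.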
        const index_t num_total_loop =
            integer_divide_ceil(physical_seqlen_k_end - aligned_physical_seqlen_k_start, kN0);

        auto [i_page_block_k, k_dram_block_window] = k_page_block_navigator.make_tile_window(
            k_dram_block_window_lengths, {aligned_physical_seqlen_k_start, 0});

        const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
        auto bias_dram_window  = make_tile_window(
            bias_dram_block_window_tmp.get_bottom_tensor_view(),
            bias_dram_block_window_tmp.get_window_lengths(),
            {bias_origin.at(number<0>{}),
             logical_seqlen_k_start -
                 (physical_seqlen_k_start - aligned_physical_seqlen_k_start)}, // M/N
            Policy::template MakeBiasDramTileDistribution<decltype(gemm_0)>());
        auto [i_page_block_v, v_dram_window] = v_page_block_navigator.make_tile_window(
            v_dram_block_window_lengths,
            {0, aligned_physical_seqlen_k_start}, // TODO: hdim split?
            Policy::template MakeVDramTileDistribution<Problem>());
        // store Q into LDS
        __builtin_amdgcn_sched_barrier(0);
        auto q_lds_window_for_store = make_tile_window(
            q_lds, Policy::template MakeQLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
        store_tile(q_lds_window_for_store, origin_q);
        __builtin_amdgcn_sched_barrier(0);

        // load Q from LDS
        __builtin_amdgcn_sched_barrier(0);
        auto q_lds_window_for_load = make_tile_window(
            q_lds,
            Policy::template MakeQLdsBlockDescriptor<Problem>().get_lengths(),
            {0, 0},
            Policy::template MakeQRegTileDistribution<Problem, decltype(gemm_0)>());
        block_sync_lds();
        auto q = load_tile(q_lds_window_for_load);
        __builtin_amdgcn_sched_barrier(0);

        auto q_tile = tile_elementwise_in(q_element_func, q);
        // prefetch K tile
        index_t i_total_loops      = 0;
        constexpr index_t k0_loops = kQKHeaddim / kK0;
        constexpr index_t k1_loops = kN0 / kK1;

        static_assert(2 <= k0_loops);
        static_assert(1 <= k1_loops);
        auto k_dram_window = make_tile_window(
            k_dram_block_window,
            Policy::template MakeKDramTileDistribution<Problem>()); // K DRAM tile window for load

        // load the first tile of the first iteration and store it to LDS
        auto k_block_tile = load_tile(k_dram_window);

        // moving k_dram_window is an in-page-block operation, so there is
        // no need to invoke k_page_block_navigator.move_tile_window() here.
        move_tile_window(k_dram_window, {0, kK0});
        store_tile(k_lds_window, tile_elementwise_in(k_element_func, k_block_tile));
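        // Main-loop K pipeline: while gemm_0 consumes K sub-tile i from LDS, sub-tile i+1 sits
        // in registers waiting to be written to LDS, and sub-tile i+2 is being fetched from
        // global memory (see the "LDS write i + 1" / "global read i + 2" comments below).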
        do
        {
            // STAGE 1, QK gemm
            clear_tile(s_acc); // initialize C

            // load the second tile of the first iteration
            k_block_tile = load_tile(k_dram_window);

            if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
            {
                // prevent from messing up the order of global loads
                __builtin_amdgcn_sched_barrier(0);
            }
            const auto bias_tile = load_tile(bias_dram_window); // load bias tile
            if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
            {
                // prevent from messing up the order of global loads
                __builtin_amdgcn_sched_barrier(0);
            }
            if constexpr(k0_loops > 2)
            {
                static_for<0, k0_loops - 2, 1>{}([&](auto i_k0) {
                    block_sync_lds();
                    gemm_0(s_acc,
                           get_slice_tile(q_tile,
                                          sequence<0, i_k0 * kK0>{},
                                          sequence<kM0, (i_k0 + 1) * kK0>{}),
                           k_lds_window);
                    block_sync_lds();

                    move_tile_window(k_dram_window, {0, kK0});
                    store_tile(
                        k_lds_window,
                        tile_elementwise_in(k_element_func, k_block_tile)); // LDS write i + 1
                    k_block_tile = load_tile(k_dram_window);                // global read i + 2
                });
            }

            const auto v_prefetch = load_tile(v_dram_window); // prefetch load v tile
            { // tail
                block_sync_lds();
                gemm_0(s_acc,
                       get_slice_tile(q_tile,
                                      sequence<0, (k0_loops - 2) * kK0>{},
                                      sequence<kM0, (k0_loops - 1) * kK0>{}),
                       k_lds_window);
                block_sync_lds();

                store_tile(k_lds_window, tile_elementwise_in(k_element_func, k_block_tile));
                block_sync_lds();

                gemm_0(s_acc,
                       get_slice_tile(q_tile,
                                      sequence<0, (k0_loops - 1) * kK0>{},
                                      sequence<kM0, k0_loops * kK0>{}),
                       k_lds_window);
            }
            // STAGE 2, scale_s, add bias, mask, softmax
            if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
            {
                s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
                tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc);
                tile_elementwise_inout(
                    [&](auto& x, const auto& y) {
#if !CK_TILE_FMHA_FWD_FAST_EXP2
                        x += type_convert<SaccDataType>(bias_element_func(y));
#else
                        x += log2e_v<SaccDataType> *
                             type_convert<SaccDataType>(bias_element_func(y));
#endif
                    },
                    s_acc,
                    bias_tile);
            }
            else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
            {
                const auto k_origin = k_page_block_navigator.to_global_window_origin(
                    i_page_block_k, k_dram_block_window.get_window_origin());
                constexpr auto s_spans = decltype(s_acc)::get_distributed_spans();
                s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
                sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) {
                    sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) {
                        const auto tile_idx = get_x_indices_from_distributed_indices(
                            s_acc.get_tile_distribution(), make_tuple(idx0, idx1));
                        const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
                        const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
                        constexpr auto i_j_idx = make_tuple(idx0, idx1);

                        s_acc(i_j_idx) *= scale_s;
                        // position_encoding accepts only logical coordinates, so convert here
                        position_encoding.update(s_acc(i_j_idx), row, col - kv_l2p_offset);
                    });
                });
            }
            else
            {
                s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
#if !CK_TILE_FMHA_FWD_FAST_EXP2
                tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc);
#endif
            }
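            // Note on the fast-exp2 variants above: exp(x) == exp2(x * log2e). In the
            // elementwise-bias branch the bias is pre-multiplied by log2e_v, while in the
            // no-bias branch s_acc is left unscaled so that scale_s can be folded into the
            // exp2 argument during the softmax below; both keep the two code paths
            // numerically equivalent.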
            move_tile_window(bias_dram_window, {0, kN0});
            /// TODO: only check in first/last iteration without increasing code size
            if constexpr(kHasUnevenSplits)
            {
                const auto k_origin = k_page_block_navigator.to_global_window_origin(
                    i_page_block_k, k_dram_block_window.get_window_origin());
                set_tile_if(s_acc,
                            -numeric<SMPLComputeDataType>::infinity(),
                            [&,
                             physical_seqlen_k_start_ = physical_seqlen_k_start,
                             physical_seqlen_k_end_   = physical_seqlen_k_end](auto tile_idx) {
                                const auto col =
                                    k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
                                if constexpr(kIsPagedKV)
                                {
                                    return col < physical_seqlen_k_start_ ||
                                           physical_seqlen_k_end_ <= col;
                                }
                                else
                                {
                                    return physical_seqlen_k_end_ <= col;
                                }
                            });
            }
            if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
            {
                const auto k_origin = k_page_block_navigator.to_global_window_origin(
                    i_page_block_k, k_dram_block_window.get_window_origin());
                // mask accepts only logical coordinates, so convert here
                bool need_perpixel_check =
                    mask.IsEdgeTile(q_origin.at(number<0>{}),
                                    k_origin.at(number<0>{}) - kv_l2p_offset,
                                    number<kM0>{},
                                    number<kN0>{});

                if(need_perpixel_check)
                {
                    set_tile_if(
                        s_acc, -numeric<SMPLComputeDataType>::infinity(), [&](auto tile_idx) {
                            const auto row =
                                q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
                            const auto col =
                                k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
                            return mask.IsOutOfBound(row, col - kv_l2p_offset);
                        });
                }
            }
            __builtin_amdgcn_sched_barrier(0);
            // load the first K tile for the next iteration
            if(i_total_loops < num_total_loop - 1)
            {
                // move K tile windows
                i_page_block_k = k_page_block_navigator.move_tile_window(
                    i_page_block_k, k_dram_block_window, {kN0, 0});
                k_dram_window = make_tile_window(
                    k_dram_block_window,
                    Policy::template MakeKDramTileDistribution<Problem>()); // K DRAM tile window
                // load the first tile of the next iteration (it is stored to LDS at the end of
                // this iteration)
                k_block_tile = load_tile(k_dram_window);
            }
            __builtin_amdgcn_sched_barrier(0);
            const auto s = cast_tile<SMPLComputeDataType>(s_acc); // S{j}

            // shuffle S through LDS so that its tile layout matches what gemm_1 requires
            store_tile(s_write_lds_window, s);
            block_sync_lds();
            auto s_new = load_tile(s_read_lds_window);
            auto m_local = block_tile_reduce<SMPLComputeDataType>(
                s_new,
                sequence<1>{},
                f_max,
                -numeric<SMPLComputeDataType>::infinity()); // m_local = rowmax(S{j})
            block_tile_reduce_sync(m_local, f_max, bool_constant<false>{});

            const auto m_old = m; // m{j-1}
            tile_elementwise_inout(
                [](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local); // m{j}
            auto p_compute = make_static_distributed_tensor<SMPLComputeDataType>(
                s_new.get_tile_distribution()); // Pcompute{j}

            static const auto get_validated_m = [](SMPLComputeDataType raw_m) {
                /// NOTICE: bias might be a materialized mask containing -inf values, which
                /// needs consideration
                if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
                             FmhaMask::IsMasking)
                {
                    return raw_m == -numeric<SMPLComputeDataType>::infinity()
                               ? type_convert<SMPLComputeDataType>(0.f)
                               : raw_m;
                }
                else
                {
                    return raw_m;
                }
            };
            constexpr auto p_spans = decltype(p_compute)::get_distributed_spans();
            sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
                constexpr auto i_idx = make_tuple(idx0);
#if CK_TILE_FMHA_FWD_FAST_EXP2
                auto row_max = scale_s * get_validated_m(m[i_idx]);
#endif
                sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
                    constexpr auto i_j_idx = make_tuple(idx0, idx1);
#if CK_TILE_FMHA_FWD_FAST_EXP2
                    if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
                                 BiasEnum == BlockAttentionBiasEnum::ALIBI)
                    {
                        p_compute(i_j_idx) = exp2(s_new[i_j_idx] - get_validated_m(m[i_idx]));
                    }
                    else
                    {
                        p_compute(i_j_idx) = exp2(scale_s * s_new[i_j_idx] - row_max);
                    }
#else
                    p_compute(i_j_idx) = exp(s_new[i_j_idx] - get_validated_m(m[i_idx]));
#endif
                });
            });
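            // Fast-exp2 path, no bias: s_acc stays unscaled and scale_s is applied inside the
            // exponent, p = exp2(scale_s * s - scale_s * m). Judging from the LSE computation
            // below (which divides by C_LOG2E), scale_s appears to already carry the log2(e)
            // factor in this mode, so exp2(scale_s * x) == exp(softmax_scale * x).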
            auto rowsum_p = block_tile_reduce<SMPLComputeDataType>(
                p_compute, sequence<1>{}, f_sum, SMPLComputeDataType{0}); // rowsum(Pcompute{j})
            block_tile_reduce_sync(rowsum_p, f_sum, bool_constant<false>{});

            const auto p =
                cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, p_compute));
            // l{j}, Oacc{j}
            constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
            sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
                constexpr auto i_idx = make_tuple(idx0);
#if CK_TILE_FMHA_FWD_FAST_EXP2
                const auto tmp = [&]() {
                    if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
                                 BiasEnum == BlockAttentionBiasEnum::ALIBI)
                    {
                        return exp2(m_old[i_idx] - get_validated_m(m[i_idx]));
                    }
                    else
                    {
                        auto row_max = scale_s * get_validated_m(m[i_idx]);
                        return exp2(scale_s * m_old[i_idx] - row_max);
                    }
                }();
#else
                const auto tmp = exp(m_old[i_idx] - get_validated_m(m[i_idx]));
#endif
                l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx];
                sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
                    constexpr auto i_j_idx = make_tuple(idx0, idx1);
                    // FIXME: this uses a different equation from the FA v2 paper,
                    // but produces the correct result.
                    // Is the equation wrong?
                    o_acc(i_j_idx) *= tmp;
                });
            });
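            // Per-row rescale implementing the streaming-softmax recurrence:
            //   l_j     = exp(m_{j-1} - m_j) * l_{j-1} + rowsum(P_j)
            //   o_acc_j = exp(m_{j-1} - m_j) * o_acc_{j-1} + P_j @ V_j   (gemm_1 below adds P@V)
            // The division by l is deferred to a single pass after the loop, which is why only
            // the multiplicative factor tmp is applied here.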
            block_sync_lds();
            if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
            {
                auto v_shuffle_tmp = make_static_distributed_tensor<VDataType>(
                    Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
                shuffle_tile(v_shuffle_tmp, v_prefetch);
                store_tile(
                    v_lds_window,
                    tile_elementwise_in(v_element_func, v_shuffle_tmp)); // store the prefetch
            }
            else
            {
                store_tile(v_lds_window,
                           tile_elementwise_in(v_element_func, v_prefetch)); // store the prefetch
            }
            i_page_block_v =
                v_page_block_navigator.move_tile_window(i_page_block_v, v_dram_window, {0, kK1});
            // STAGE 3, KV gemm
            if constexpr(k1_loops > 1)
            {
                static_for<0, k1_loops - 1, 1>{}([&,
                                                  &i_page_block_v_ = i_page_block_v,
                                                  &v_dram_window_  = v_dram_window](auto i_k1) {
                    const auto v = load_tile(v_dram_window_); // load next v
                    block_sync_lds();
                    gemm_1(o_acc,
                           get_slice_tile(
                               p, sequence<0, i_k1 * kK1>{}, sequence<kM0, (i_k1 + 1) * kK1>{}),
                           v_lds_window);
                    block_sync_lds();

                    if constexpr(std::is_same_v<VLayout,
                                                ck_tile::tensor_layout::gemm::RowMajor>)
                    {
                        auto v_shuffle_tmp = make_static_distributed_tensor<VDataType>(
                            Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
                        shuffle_tile(v_shuffle_tmp, v);
                        store_tile(
                            v_lds_window,
                            tile_elementwise_in(v_element_func, v_shuffle_tmp)); // store next v
                    }
                    else
                    {
                        store_tile(v_lds_window,
                                   tile_elementwise_in(v_element_func, v)); // store next v
                    }
                    i_page_block_v_ = v_page_block_navigator.move_tile_window(
                        i_page_block_v_, v_dram_window_, {0, kK1});
                });
            }
            // tail
            {
                block_sync_lds();
                gemm_1(o_acc,
                       get_slice_tile(
                           p, sequence<0, (k1_loops - 1) * kK1>{}, sequence<kM0, k1_loops * kK1>{}),
                       v_lds_window);
                block_sync_lds();
            }
            __builtin_amdgcn_sched_barrier(0);
            // store the first K tile for the next iteration
            if(i_total_loops < num_total_loop - 1)
            {
                // store the first tile of the next iteration to LDS
                // moving k_dram_window is an in-page-block operation, so there is
                // no need to invoke k_page_block_navigator.move_tile_window() here.
                move_tile_window(k_dram_window, {0, kK0});
                store_tile(k_lds_window, tile_elementwise_in(k_element_func, k_block_tile));
            }
        } while(++i_total_loops < num_total_loop);
        if constexpr(kStoreLSE)
        {
            // store lse acc
            auto lse_acc =
                make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());

            constexpr auto lse_acc_spans = decltype(lse_acc)::get_distributed_spans();
            sweep_tile_span(lse_acc_spans[number<0>{}], [&, m_ = m, l_ = l](auto idx0) {
                constexpr auto i_idx = make_tuple(idx0);
#if CK_TILE_FMHA_FWD_FAST_EXP2
                if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
                             BiasEnum == BlockAttentionBiasEnum::ALIBI)
                {
                    lse_acc(i_idx) = m_[i_idx] / C_LOG2E + log(l_[i_idx]);
                }
                else
                {
                    lse_acc(i_idx) = m_[i_idx] * scale_s / C_LOG2E + log(l_[i_idx]);
                }
#else
                lse_acc(i_idx) = m_[i_idx] + log(l_[i_idx]);
#endif
            });

            if(get_thread_local_1d_id() < kM0)
            {
                store_tile(lse_acc_dram_window_tmp,
                           tile_elementwise_in(lse_acc_element_func, lse_acc));
            }
        }
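        // lse_acc = m + log(l) is the per-split log-sum-exp of the scaled logits; presumably
        // the split-kv combine kernel uses it to weight and merge the partial o_acc of each
        // split.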
        // finally, O
        constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
        sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
            constexpr auto i_idx = make_tuple(idx0);
            const auto tmp       = [&]() {
                if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
                             FmhaMask::IsMasking)
                {
                    return l[i_idx] == 0.f ? 0.f : 1 / l[i_idx];
                }
                else
                    return 1 / l[i_idx];
            }();
            sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
                constexpr auto i_j_idx = make_tuple(idx0, idx1);
                o_acc(i_j_idx) *= tmp;
            });
        });
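        // Rows that never saw an unmasked element end up with l == 0; the guard above outputs
        // 0 for them instead of propagating inf/NaN from the 1/l division.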
        o_acc = tile_elementwise_in(o_acc_element_func, o_acc);

        return o_acc;
    }
    template <typename QDramBlockWindowTmp,
              typename KDramBlockWindowLengths,
              typename KPageBlockNavigator,
              typename VDramBlockWindowLengths,
              typename VPageBlockNavigator,
              typename BiasDramBlockWindowTmp,
              typename LSEaccDramBlockWindowTmp,
              typename PositionEncoding>
    CK_TILE_HOST_DEVICE auto
    operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
               const KDramBlockWindowLengths& k_dram_block_window_lengths, // N0*K0 tile
               const KPageBlockNavigator& k_page_block_navigator,
               const VDramBlockWindowLengths& v_dram_block_window_lengths, // N1*K1 tile
               const VPageBlockNavigator& v_page_block_navigator,
               const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
               LSEaccDramBlockWindowTmp& lse_acc_dram_block_window_tmp, // M0*1 tile
               index_t num_splits,
               index_t i_split,
               FmhaMask mask,
               PositionEncoding position_encoding,
               float scale_s,
               index_t kv_l2p_offset, // logical-to-physical offset of seqlen_k coordinate
               void* smem_ptr) const
    {
        return operator()(q_dram_block_window_tmp,
                          identity{},
                          k_dram_block_window_lengths,
                          k_page_block_navigator,
                          identity{},
                          v_dram_block_window_lengths,
                          v_page_block_navigator,
                          identity{},
                          bias_dram_block_window_tmp,
                          identity{},
                          lse_acc_dram_block_window_tmp,
                          identity{},
                          identity{},
                          identity{},
                          identity{},
                          num_splits,
                          i_split,
                          mask,
                          position_encoding,
                          scale_s,
                          kv_l2p_offset,
                          smem_ptr);
    }
};

} // namespace ck_tile