Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
c881136b
Commit
c881136b
authored
Jan 01, 2025
by
Po Yen Chen
Browse files
Merge branch 'develop' into ck_tile/support-vllm-kcache-layout
parents
c5e8e14f
4e076909
Changes
75
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1255 additions
and
284 deletions
+1255
-284
include/ck_tile/core/tensor/static_distributed_tensor.hpp
include/ck_tile/core/tensor/static_distributed_tensor.hpp
+1
-0
include/ck_tile/host.hpp
include/ck_tile/host.hpp
+1
-1
include/ck_tile/host/arg_parser.hpp
include/ck_tile/host/arg_parser.hpp
+44
-2
include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp
include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp
+1
-1
include/ck_tile/ops/common.hpp
include/ck_tile/ops/common.hpp
+1
-1
include/ck_tile/ops/elementwise.hpp
include/ck_tile/ops/elementwise.hpp
+1
-1
include/ck_tile/ops/epilogue.hpp
include/ck_tile/ops/epilogue.hpp
+1
-1
include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp
include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp
+27
-4
include/ck_tile/ops/epilogue/default_2d_epilogue.hpp
include/ck_tile/ops/epilogue/default_2d_epilogue.hpp
+22
-4
include/ck_tile/ops/flatmm.hpp
include/ck_tile/ops/flatmm.hpp
+1
-1
include/ck_tile/ops/fmha.hpp
include/ck_tile/ops/fmha.hpp
+3
-3
include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp
include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp
+20
-8
include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp
include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp
+75
-9
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp
..._tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp
+66
-27
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_tile_partitioner.hpp
...fmha/kernel/fmha_fwd_splitkv_combine_tile_partitioner.hpp
+0
-48
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp
+6
-3
include/ck_tile/ops/fmha/kernel/fmha_fwd_tile_partitioner.hpp
...ude/ck_tile/ops/fmha/kernel/fmha_fwd_tile_partitioner.hpp
+0
-105
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp
...fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp
+65
-18
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp
...lock_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp
+126
-47
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp
...ock_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp
+794
-0
No files found.
include/ck_tile/core/tensor/static_distributed_tensor.hpp
View file @
c881136b
...
...
@@ -29,6 +29,7 @@ struct static_distributed_tensor
remove_cvref_t
<
decltype
(
StaticTileDistribution
{}.
get_ys_to_d_descriptor
())
>
;
static
constexpr
index_t
kThreadElementSpaceSize
=
ThreadTensorDesc
{}.
get_element_space_size
();
static_assert
(
0
<
kThreadElementSpaceSize
,
"Make sure tile distribution is valid"
);
CK_TILE_HOST_DEVICE
static
constexpr
auto
get_num_of_dimension
()
{
...
...
include/ck_tile/host.hpp
View file @
c881136b
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
include/ck_tile/host/arg_parser.hpp
View file @
c881136b
...
...
@@ -15,11 +15,14 @@
namespace
ck_tile
{
/*
* a host side utility, arg parser for
* -[key0]=[value0] -[key1]=[value1] ...
* a host side utility, arg parser for, either
* -[key0] = [value0, value1, value2]
* or
* -[key0]=[value0] -[key1]=[value1] ...
*/
class
ArgParser
{
public:
class
Arg
{
...
...
@@ -187,6 +190,45 @@ class ArgParser
return
value
;
}
std
::
vector
<
std
::
string
>
get_string_vec
(
const
std
::
string
&
name
,
const
std
::
string
&
delimiter
=
","
)
const
{
if
(
get_str
(
name
).
empty
())
{
return
{};
}
std
::
string
s
=
get_str
(
name
);
std
::
vector
<
std
::
string
>
tokens
;
size_t
pos
=
0
;
std
::
string
token
;
while
((
pos
=
s
.
find
(
delimiter
))
!=
std
::
string
::
npos
)
{
token
=
s
.
substr
(
0
,
pos
);
tokens
.
push_back
(
token
);
s
.
erase
(
0
,
pos
+
delimiter
.
length
());
}
tokens
.
push_back
(
s
);
return
tokens
;
}
std
::
vector
<
int
>
get_int_vec
(
const
std
::
string
&
name
,
const
std
::
string
&
delimiter
=
","
)
const
{
if
(
get_str
(
name
).
empty
())
{
return
{};
}
const
std
::
vector
<
std
::
string
>
args
=
get_string_vec
(
name
,
delimiter
);
std
::
vector
<
int
>
tokens
;
tokens
.
reserve
(
static_cast
<
int
>
(
args
.
size
()));
for
(
const
std
::
string
&
token
:
args
)
{
int
value
=
atoi
(
token
.
c_str
());
tokens
.
push_back
(
value
);
}
return
tokens
;
}
private:
std
::
unordered_map
<
std
::
string
,
Arg
>
input_map
;
std
::
vector
<
std
::
string
>
keys
;
...
...
include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp
View file @
c881136b
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
include/ck_tile/ops/common.hpp
View file @
c881136b
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
include/ck_tile/ops/elementwise.hpp
View file @
c881136b
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
include/ck_tile/ops/epilogue.hpp
View file @
c881136b
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp
View file @
c881136b
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -56,6 +56,13 @@ struct CShuffleEpilogue
// No additional shared memory needed
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSize
()
{
return
0
;
}
CK_TILE_HOST_DEVICE
static
constexpr
bool
IsOutputTransposed
()
{
// TODO: At now CShuffle doesn't allow to vector store after permute.
// It should be fixed and this function should return true.
return
false
;
}
template
<
typename
OAccTile
>
CK_TILE_DEVICE
void
permute_tile_data
(
OAccTile
&
o_acc_tile
)
{
...
...
@@ -111,7 +118,9 @@ struct CShuffleEpilogue
}
}
template
<
typename
ODramWindowTmp
,
typename
OAccTile
>
template
<
typename
ODramWindowTmp
,
typename
OAccTile
,
memory_operation_enum
out_memory_data_op
=
memory_operation_enum
::
set
>
CK_TILE_DEVICE
auto
operator
()(
ODramWindowTmp
&
o_dram_window_tmp
,
OAccTile
&
o_acc_tile
)
{
const
auto
&
current_window_origin
=
o_dram_window_tmp
.
get_window_origin
();
...
...
@@ -158,12 +167,26 @@ struct CShuffleEpilogue
// Store the tile data to the permuted location
if
constexpr
(
kPadM
||
kPadN
)
{
store_tile_raw
(
o_dram_window_tmp
,
cast_tile
<
ODataType
>
(
o_acc_tile
));
if
constexpr
(
out_memory_data_op
==
memory_operation_enum
::
set
)
{
store_tile_raw
(
o_dram_window_tmp
,
cast_tile
<
ODataType
>
(
o_acc_tile
));
}
else
{
update_tile_raw
(
o_dram_window_tmp
,
cast_tile
<
ODataType
>
(
o_acc_tile
));
}
buffer_store_fence
();
}
else
{
store_tile
(
o_dram_window_tmp
,
cast_tile
<
ODataType
>
(
o_acc_tile
));
if
constexpr
(
out_memory_data_op
==
memory_operation_enum
::
set
)
{
store_tile
(
o_dram_window_tmp
,
cast_tile
<
ODataType
>
(
o_acc_tile
));
}
else
{
update_tile
(
o_dram_window_tmp
,
cast_tile
<
ODataType
>
(
o_acc_tile
));
}
}
}
};
...
...
include/ck_tile/ops/epilogue/default_2d_epilogue.hpp
View file @
c881136b
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -35,21 +35,39 @@ struct Default2DEpilogue
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSize
()
{
return
0
;
}
CK_TILE_HOST_DEVICE
static
constexpr
bool
IsOutputTransposed
()
{
return
false
;
}
// TODO: this function assume store out vector size is the same as OAccTile last dimension size
// how do we fix this ?
template
<
typename
ODramWindowTmp
,
typename
OAccTile
>
template
<
typename
ODramWindowTmp
,
typename
OAccTile
,
memory_operation_enum
out_memory_data_op
=
memory_operation_enum
::
set
>
CK_TILE_DEVICE
auto
operator
()(
ODramWindowTmp
&
o_dram_window_tmp
,
const
OAccTile
&
o_acc_tile
)
{
// TODO: this is ugly
if
constexpr
(
UseRawStore
&&
(
kPadM
||
kPadN
))
{
store_tile_raw
(
o_dram_window_tmp
,
cast_tile
<
ODataType
>
(
o_acc_tile
));
if
constexpr
(
out_memory_data_op
==
memory_operation_enum
::
set
)
{
store_tile_raw
(
o_dram_window_tmp
,
cast_tile
<
ODataType
>
(
o_acc_tile
));
}
else
{
update_tile_raw
(
o_dram_window_tmp
,
cast_tile
<
ODataType
>
(
o_acc_tile
));
}
buffer_store_fence
();
}
else
{
store_tile
(
o_dram_window_tmp
,
cast_tile
<
ODataType
>
(
o_acc_tile
));
if
constexpr
(
out_memory_data_op
==
memory_operation_enum
::
set
)
{
store_tile
(
o_dram_window_tmp
,
cast_tile
<
ODataType
>
(
o_acc_tile
));
}
else
{
update_tile
(
o_dram_window_tmp
,
cast_tile
<
ODataType
>
(
o_acc_tile
));
}
}
}
};
...
...
include/ck_tile/ops/flatmm.hpp
View file @
c881136b
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
include/ck_tile/ops/fmha.hpp
View file @
c881136b
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -14,9 +14,7 @@
#include "ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_tile_partitioner.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_tile_partitioner.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_tile_partitioner.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_convert_dq.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp"
...
...
@@ -28,6 +26,8 @@
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_async.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_async_default_policy.hpp"
...
...
include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp
View file @
c881136b
...
...
@@ -10,10 +10,9 @@
namespace
ck_tile
{
template
<
typename
TilePartitioner_
,
typename
FmhaPipeline_
>
template
<
typename
FmhaPipeline_
>
struct
FmhaFwdAppendKVKernel
{
using
TilePartitioner
=
ck_tile
::
remove_cvref_t
<
TilePartitioner_
>
;
using
FmhaPipeline
=
ck_tile
::
remove_cvref_t
<
FmhaPipeline_
>
;
static
constexpr
ck_tile
::
index_t
kBlockSize
=
FmhaPipeline
::
kBlockSize
;
static
constexpr
ck_tile
::
index_t
kBlockPerCu
=
FmhaPipeline
::
kBlockPerCu
;
...
...
@@ -234,12 +233,25 @@ struct FmhaFwdAppendKVKernel
return
kargs
;
}
__host__
static
constexpr
auto
GridSize
(
ck_tile
::
index_t
batch_size
,
ck_tile
::
index_t
nhead
,
ck_tile
::
index_t
seqlen_q
,
ck_tile
::
index_t
seqlen_knew
)
CK_TILE_HOST
static
constexpr
auto
GridSize
(
ck_tile
::
index_t
batch_size
,
ck_tile
::
index_t
nhead
,
ck_tile
::
index_t
seqlen_q
,
ck_tile
::
index_t
seqlen_knew
)
{
return
TilePartitioner
::
GridSize
(
batch_size
,
nhead
,
seqlen_q
,
seqlen_knew
);
// TODO: this may need tuning
return
dim3
(
std
::
max
(
ck_tile
::
integer_divide_ceil
(
seqlen_q
,
FmhaPipeline
::
kM0
),
ck_tile
::
integer_divide_ceil
(
seqlen_knew
,
FmhaPipeline
::
kN0
)),
nhead
,
batch_size
);
}
CK_TILE_DEVICE
static
constexpr
auto
GetTileIndex
(
const
Kargs
&
/* kargs */
)
{
const
index_t
i_tile
=
blockIdx
.
x
;
const
index_t
i_nhead
=
blockIdx
.
y
;
const
index_t
i_batch
=
blockIdx
.
z
;
return
ck_tile
::
make_tuple
(
i_tile
,
i_nhead
,
i_batch
);
}
__host__
static
constexpr
auto
BlockSize
()
{
return
dim3
(
kBlockSize
);
}
...
...
@@ -247,7 +259,7 @@ struct FmhaFwdAppendKVKernel
CK_TILE_DEVICE
void
operator
()(
Kargs
kargs
)
const
{
// divide problem
const
auto
[
i_tile
,
i_nhead
,
i_batch
]
=
Tile
Partitioner
{}(
);
const
auto
[
i_tile
,
i_nhead
,
i_batch
]
=
Get
Tile
Index
(
kargs
);
const
index_t
i_m0
=
__builtin_amdgcn_readfirstlane
(
i_tile
*
FmhaPipeline
::
kM0
);
const
index_t
i_n0
=
__builtin_amdgcn_readfirstlane
(
i_tile
*
FmhaPipeline
::
kN0
);
...
...
include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp
View file @
c881136b
...
...
@@ -20,10 +20,9 @@
namespace
ck_tile
{
template
<
typename
TilePartitioner_
,
typename
FmhaPipeline_
,
typename
EpiloguePipeline_
>
template
<
typename
FmhaPipeline_
,
typename
EpiloguePipeline_
>
struct
FmhaFwdKernel
{
using
TilePartitioner
=
ck_tile
::
remove_cvref_t
<
TilePartitioner_
>
;
using
FmhaPipeline
=
ck_tile
::
remove_cvref_t
<
FmhaPipeline_
>
;
using
EpiloguePipeline
=
ck_tile
::
remove_cvref_t
<
EpiloguePipeline_
>
;
static
constexpr
ck_tile
::
index_t
kBlockSize
=
FmhaPipeline
::
kBlockSize
;
...
...
@@ -71,7 +70,8 @@ struct FmhaFwdKernel
using
bfs
=
typename
FmhaPipeline
::
BlockFmhaShape
;
using
g0br
=
typename
bfs
::
Gemm0BlockWarps
;
using
g1br
=
typename
bfs
::
Gemm1BlockWarps
;
using
gwt
=
typename
bfs
::
Gemm0WarpTile
;
using
g0wt
=
typename
bfs
::
Gemm0WarpTile
;
using
g1wt
=
typename
bfs
::
Gemm1WarpTile
;
#define _SS_ std::string
#define _TS_ std::to_string
auto
pn
=
[
&
]
()
{
...
...
@@ -83,12 +83,13 @@ struct FmhaFwdKernel
return
n
.
empty
()
?
n
:
std
::
string
(
"p"
)
+
n
;
}();
return
_SS_
(
"fmha_fwd_d"
)
+
_TS_
(
bfs
::
kQKHeaddim
)
+
"_"
+
_SS_
(
t2s
<
QDataType
>::
name
)
+
"_"
+
(
kIsGroupMode
?
"group"
:
"batch"
)
+
"_"
+
_SS_
(
TilePartitioner
::
name
)
+
"_"
"_"
+
(
kIsGroupMode
?
"group"
:
"batch"
)
+
"_"
"b"
+
_TS_
(
bfs
::
kM0
)
+
"x"
+
_TS_
(
bfs
::
kN0
)
+
"x"
+
_TS_
(
bfs
::
kK0
)
+
"x"
+
_TS_
(
bfs
::
kN1
)
+
"x"
+
_TS_
(
bfs
::
kK1
)
+
"x"
+
_TS_
(
bfs
::
kQKHeaddim
)
+
"_"
+
"r"
+
_TS_
(
g0br
::
at
(
ck_tile
::
number
<
0
>
{}))
+
"x"
+
_TS_
(
g0br
::
at
(
ck_tile
::
number
<
1
>
{}))
+
"x"
+
_TS_
(
g0br
::
at
(
ck_tile
::
number
<
2
>
{}))
+
"_"
+
"r"
+
_TS_
(
g1br
::
at
(
ck_tile
::
number
<
0
>
{}))
+
"x"
+
_TS_
(
g1br
::
at
(
ck_tile
::
number
<
1
>
{}))
+
"x"
+
_TS_
(
g1br
::
at
(
ck_tile
::
number
<
2
>
{}))
+
"_"
+
"w"
+
_TS_
(
gwt
::
at
(
ck_tile
::
number
<
0
>
{}))
+
"x"
+
_TS_
(
gwt
::
at
(
ck_tile
::
number
<
1
>
{}))
+
"x"
+
_TS_
(
gwt
::
at
(
ck_tile
::
number
<
2
>
{}))
+
"_"
+
"w"
+
_TS_
(
g0wt
::
at
(
ck_tile
::
number
<
0
>
{}))
+
"x"
+
_TS_
(
g0wt
::
at
(
ck_tile
::
number
<
1
>
{}))
+
"x"
+
_TS_
(
g0wt
::
at
(
ck_tile
::
number
<
2
>
{}))
+
"_"
+
"w"
+
_TS_
(
g1wt
::
at
(
ck_tile
::
number
<
0
>
{}))
+
"x"
+
_TS_
(
g1wt
::
at
(
ck_tile
::
number
<
1
>
{}))
+
"x"
+
_TS_
(
g1wt
::
at
(
ck_tile
::
number
<
2
>
{}))
+
"_"
+
(
kBlockPerCuInput
==
-
1
?
""
:
(
"o"
+
_TS_
(
kBlockPerCu
)
+
"_"
))
+
_SS_
(
FmhaPipeline
::
name
)
+
"_"
+
"v"
+
(
std
::
is_same_v
<
VLayout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
?
"r"
:
"c"
)
+
(
pn
.
empty
()
?
""
:
"_"
+
pn
)
+
(
BiasEnum
==
BlockAttentionBiasEnum
::
NO_BIAS
?
_SS_
(
""
)
:
(
_SS_
(
"_"
)
+
BlockAttentionBiasEnumToStr
<
BiasEnum
>::
name
))
+
...
...
@@ -865,9 +866,75 @@ struct FmhaFwdKernel
CK_TILE_HOST
static
constexpr
auto
GridSize
(
ck_tile
::
index_t
batch_size_
,
ck_tile
::
index_t
nhead_
,
ck_tile
::
index_t
seqlen_q_
,
ck_tile
::
index_t
hdim_v_
)
ck_tile
::
index_t
hdim_v_
,
bool
has_padded_seqlen_k
=
false
)
{
return
TilePartitioner
::
GridSize
(
batch_size_
,
nhead_
,
seqlen_q_
,
hdim_v_
);
// has_padded_seqlen_k is determined by checking (seqlen_k_ptr != nullptr)
if
(
has_padded_seqlen_k
)
{
// TODO: this may need tuning
return
dim3
(
nhead_
,
batch_size_
,
ck_tile
::
integer_divide_ceil
(
seqlen_q_
,
FmhaPipeline
::
kM0
)
*
ck_tile
::
integer_divide_ceil
(
hdim_v_
,
FmhaPipeline
::
kN1
));
}
else
{
// TODO: this may need tuning
return
dim3
(
ck_tile
::
integer_divide_ceil
(
seqlen_q_
,
FmhaPipeline
::
kM0
)
*
ck_tile
::
integer_divide_ceil
(
hdim_v_
,
FmhaPipeline
::
kN1
),
nhead_
,
batch_size_
);
}
}
CK_TILE_DEVICE
static
constexpr
auto
GetTileIndex
(
const
Kargs
&
kargs
)
{
bool
has_padded_seqlen_k
=
false
;
if
constexpr
(
kIsGroupMode
)
has_padded_seqlen_k
=
(
kargs
.
seqlen_k_ptr
!=
nullptr
);
if
(
has_padded_seqlen_k
)
{
// const index_t num_tile_m0 = seqlen_q / kM0;
const
index_t
num_tile_n1
=
ck_tile
::
integer_divide_ceil
(
kargs
.
hdim_v
,
FmhaPipeline
::
kN1
);
const
index_t
i_block
=
blockIdx
.
z
;
const
index_t
i_nhead
=
blockIdx
.
x
;
const
index_t
i_batch
=
blockIdx
.
y
;
const
auto
f
=
[](
index_t
dividend
,
index_t
divisor
)
{
index_t
quotient
=
dividend
/
divisor
;
index_t
modulus
=
dividend
-
quotient
*
divisor
;
return
ck_tile
::
make_tuple
(
quotient
,
modulus
);
};
const
auto
[
i_tile_m
,
i_tile_n
]
=
f
(
i_block
,
num_tile_n1
);
return
ck_tile
::
make_tuple
(
i_tile_m
,
i_tile_n
,
i_nhead
,
i_batch
);
}
else
{
// const index_t num_tile_m0 = seqlen_q / kM0;
const
index_t
num_tile_n1
=
ck_tile
::
integer_divide_ceil
(
kargs
.
hdim_v
,
FmhaPipeline
::
kN1
);
const
index_t
i_block
=
blockIdx
.
x
;
const
index_t
i_nhead
=
blockIdx
.
y
;
const
index_t
i_batch
=
blockIdx
.
z
;
const
auto
f
=
[](
index_t
dividend
,
index_t
divisor
)
{
index_t
quotient
=
dividend
/
divisor
;
index_t
modulus
=
dividend
-
quotient
*
divisor
;
return
ck_tile
::
make_tuple
(
quotient
,
modulus
);
};
const
auto
[
i_tile_m
,
i_tile_n
]
=
f
(
i_block
,
num_tile_n1
);
return
ck_tile
::
make_tuple
(
i_tile_m
,
i_tile_n
,
i_nhead
,
i_batch
);
}
}
CK_TILE_HOST
static
constexpr
auto
BlockSize
()
{
return
dim3
(
kBlockSize
);
}
...
...
@@ -883,8 +950,7 @@ struct FmhaFwdKernel
__shared__
char
smem_ptr
[
GetSmemSize
()];
// divide problem
const
auto
[
i_tile_m
,
i_tile_n
,
i_nhead
,
i_batch
]
=
TilePartitioner
{}(
kargs
.
seqlen_q
,
kargs
.
hdim_v
);
const
auto
[
i_tile_m
,
i_tile_n
,
i_nhead
,
i_batch
]
=
GetTileIndex
(
kargs
);
const
index_t
i_m0
=
__builtin_amdgcn_readfirstlane
(
i_tile_m
*
FmhaPipeline
::
kM0
);
const
index_t
i_n1
=
__builtin_amdgcn_readfirstlane
(
i_tile_n
*
FmhaPipeline
::
kN1
);
...
...
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp
View file @
c881136b
...
...
@@ -5,12 +5,13 @@
namespace
ck_tile
{
template
<
typename
TilePartitioner_
,
typename
FmhaPipeline_
,
typename
EpiloguePipeline_
>
template
<
typename
FmhaPipeline_
,
typename
EpiloguePipeline_
>
struct
FmhaFwdSplitKVCombineKernel
{
using
TilePartitioner
=
remove_cvref_t
<
TilePartitioner_
>
;
using
FmhaPipeline
=
remove_cvref_t
<
FmhaPipeline_
>
;
using
EpiloguePipeline
=
remove_cvref_t
<
EpiloguePipeline_
>
;
using
FmhaPipeline
=
remove_cvref_t
<
FmhaPipeline_
>
;
using
EpiloguePipeline
=
remove_cvref_t
<
EpiloguePipeline_
>
;
static
constexpr
index_t
kNumWarps
=
FmhaPipeline
::
kNumWarps
;
static
constexpr
index_t
kBlockSize
=
FmhaPipeline
::
kBlockSize
;
static
constexpr
index_t
kBlockPerCu
=
FmhaPipeline
::
kBlockPerCu
;
static_assert
(
kBlockPerCu
>
0
);
...
...
@@ -50,8 +51,7 @@ struct FmhaFwdSplitKVCombineKernel
return
_SS_
(
"fmha_fwd_splitkv_combine_d"
)
+
_TS_
(
FmhaPipeline
::
kHeadDimV
)
+
"_"
+
_SS_
(
t2s
<
ODataType
>::
name
)
+
"_"
+
(
kIsGroupMode
?
"group"
:
"batch"
)
+
"_"
"b"
+
_TS_
(
FmhaPipeline
::
kM0
)
+
"x"
+
_TS_
(
FmhaPipeline
::
kN1
)
+
"_"
+
"b"
+
_TS_
(
FmhaPipeline
::
kN1
)
+
"_"
+
(
kBlockPerCuInput
==
-
1
?
""
:
(
"o"
+
_TS_
(
kBlockPerCu
)
+
"_"
))
+
_SS_
(
FmhaPipeline
::
name
)
+
(
pn
.
empty
()
?
""
:
"_"
+
pn
)
+
...
...
@@ -234,12 +234,35 @@ struct FmhaFwdSplitKVCombineKernel
return
kargs
;
}
__host__
static
constexpr
auto
GridSize
(
ck_tile
::
index_t
batch_size
,
ck_tile
::
index_t
nhead
,
ck_tile
::
index_t
max_seqlen_q
,
ck_tile
::
index_t
hdim_v
)
CK_TILE_HOST
static
constexpr
auto
GridSize
(
ck_tile
::
index_t
batch_size
,
ck_tile
::
index_t
nhead
,
ck_tile
::
index_t
max_seqlen_q
,
ck_tile
::
index_t
hdim_v
)
{
// TODO: this may need tuning
return
dim3
(
ck_tile
::
integer_divide_ceil
(
max_seqlen_q
,
FmhaPipeline
::
kM0
)
*
ck_tile
::
integer_divide_ceil
(
hdim_v
,
FmhaPipeline
::
kN1
),
nhead
,
batch_size
);
}
CK_TILE_DEVICE
static
constexpr
auto
GetTileIndex
(
const
Kargs
&
kargs
)
{
return
TilePartitioner
::
GridSize
(
batch_size
,
nhead
,
max_seqlen_q
,
hdim_v
);
const
index_t
num_tile_n1
=
ck_tile
::
integer_divide_ceil
(
kargs
.
hdim_v
,
FmhaPipeline
::
kN1
);
const
index_t
i_block
=
blockIdx
.
x
;
const
index_t
i_nhead
=
blockIdx
.
y
;
const
index_t
i_batch
=
blockIdx
.
z
;
const
auto
f
=
[](
index_t
dividend
,
index_t
divisor
)
{
index_t
quotient
=
dividend
/
divisor
;
index_t
modulus
=
dividend
-
quotient
*
divisor
;
return
ck_tile
::
make_tuple
(
quotient
,
modulus
);
};
const
auto
[
i_tile_m
,
i_tile_n
]
=
f
(
i_block
,
num_tile_n1
);
return
ck_tile
::
make_tuple
(
i_tile_m
,
i_tile_n
,
i_nhead
,
i_batch
);
}
__host__
static
constexpr
auto
BlockSize
()
{
return
dim3
(
kBlockSize
);
}
...
...
@@ -255,8 +278,7 @@ struct FmhaFwdSplitKVCombineKernel
__shared__
char
smem_ptr
[
GetSmemSize
()];
// divide problem
const
auto
[
i_tile_m
,
i_tile_n
,
i_nhead
,
i_batch
]
=
TilePartitioner
{}(
kargs
.
seqlen_q
,
kargs
.
hdim_v
);
const
auto
[
i_tile_m
,
i_tile_n
,
i_nhead
,
i_batch
]
=
GetTileIndex
(
kargs
);
const
index_t
i_m0
=
__builtin_amdgcn_readfirstlane
(
i_tile_m
*
FmhaPipeline
::
kM0
);
const
index_t
i_n1
=
__builtin_amdgcn_readfirstlane
(
i_tile_n
*
FmhaPipeline
::
kN1
);
...
...
@@ -339,37 +361,56 @@ struct FmhaFwdSplitKVCombineKernel
number
<
FmhaPipeline
::
kAlignmentOacc
>
{},
number
<
1
>
{});
// read 4 * (kM0, kN1) o_acc tiles simultaneously by 4 warps
const
auto
o_acc_dram_view
=
pad_tensor_view
(
o_acc_dram_naive
,
make_tuple
(
number
<
1
>
{},
number
<
FmhaPipeline
::
kM0
>
{},
number
<
FmhaPipeline
::
kN1
>
{}),
sequence
<
false
,
kPadSeqLenQ
,
kPadHeadDimV
>
{});
make_tuple
(
number
<
kNumWarps
>
{},
number
<
FmhaPipeline
::
kM0
>
{},
number
<
FmhaPipeline
::
kN1
>
{}),
sequence
<
true
,
kPadSeqLenQ
,
kPadHeadDimV
>
{});
const
index_t
padded_num_splits
=
o_acc_dram_view
.
get_tensor_descriptor
().
get_lengths
()[
number
<
0
>
{}];
const
index_t
padded_seqlen_q
=
o_acc_dram_view
.
get_tensor_descriptor
().
get_lengths
()[
number
<
1
>
{}];
const
index_t
padded_hdim_v
=
o_acc_dram_view
.
get_tensor_descriptor
().
get_lengths
()[
number
<
2
>
{}];
return
transform_tensor_view
(
const
index_t
num_m_tiles
=
integer_divide_floor
(
padded_seqlen_q
,
FmhaPipeline
::
kM0
);
// transform tensor view by following steps, given shape: (padded_num_splits,
// padded_seqlen_q, padded_hdim_v)
// 1. unmerge to (padded_num_splits, num_m_tiles, kM0, padded_hdim_v)
// 2. transpose to (num_m_tiles, padded_num_splits, kM0, padded_hdim_v)
// 3. merge to (num_m_tiles * padded_num_splits * kM0, padded_hdim_v)
auto
transposed
=
transform_tensor_view
(
o_acc_dram_view
,
make_tuple
(
make_merge_transform
(
make_tuple
(
kargs
.
num_splits
,
padded_seqlen_q
)),
make_tuple
(
make_pass_through_transform
(
padded_num_splits
),
make_unmerge_transform
(
make_tuple
(
num_m_tiles
,
FmhaPipeline
::
kM0
)),
make_pass_through_transform
(
padded_hdim_v
)),
make_tuple
(
sequence
<
0
,
1
>
{},
sequence
<
2
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{},
sequence
<
2
>
{}),
make_tuple
(
sequence
<
1
>
{},
sequence
<
0
,
2
>
{},
sequence
<
3
>
{}));
return
transform_tensor_view
(
transposed
,
make_tuple
(
make_merge_transform
(
make_tuple
(
num_m_tiles
,
padded_num_splits
,
FmhaPipeline
::
kM0
)),
make_pass_through_transform
(
padded_hdim_v
)),
make_tuple
(
sequence
<
0
,
1
,
2
>
{},
sequence
<
3
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}));
}();
auto
lse_acc_dram_window
=
make_tile_window
(
lse_acc_dram
,
[
&
]()
{
return
make_tuple
(
number
<
FmhaPipeline
::
kMaxSplits
>
{},
number
<
FmhaPipeline
::
kM0
>
{});
}(),
make_tuple
(
number
<
FmhaPipeline
::
kMaxSplits
>
{},
number
<
FmhaPipeline
::
kM0
>
{}),
{
0
,
i_m0
});
const
index_t
padded_num_splits
=
integer_divide_ceil
(
kargs
.
num_splits
,
kNumWarps
)
*
kNumWarps
;
auto
o_acc_dram_window
=
make_tile_window
(
o_acc_dram
,
[
&
]()
{
return
make_tuple
(
number
<
FmhaPipeline
::
kM0
>
{},
number
<
FmhaPipeline
::
kN1
>
{});
}(),
{
i_m0
,
i_n1
});
make_tuple
(
number
<
kNumWarps
*
FmhaPipeline
::
kM0
>
{},
number
<
FmhaPipeline
::
kN1
>
{}),
{
i_tile_m
*
padded_num_splits
*
FmhaPipeline
::
kM0
,
i_n1
});
// LSE DRAM window
auto
lse_dram_window
=
[
&
,
i_nhead_
=
i_nhead
]()
{
...
...
@@ -410,7 +451,6 @@ struct FmhaFwdSplitKVCombineKernel
identity
{},
// lse_element_func
composes
(
saturates
<
fp8_t
>
{},
scales
{
kargs
.
scale_o
}),
// o_acc_element_func
kargs
.
num_splits
,
kargs
.
seqlen_q
,
smem_ptr
);
}
else
...
...
@@ -419,7 +459,6 @@ struct FmhaFwdSplitKVCombineKernel
o_acc_dram_window
,
lse_dram_window
,
kargs
.
num_splits
,
kargs
.
seqlen_q
,
smem_ptr
);
}
}();
...
...
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_tile_partitioner.hpp
deleted
100644 → 0
View file @
c5e8e14f
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace
ck_tile
{
template
<
index_t
kM0_
,
index_t
kN1_
>
struct
FmhaFwdSplitKVCombineTilePartitioner
{
static
constexpr
ck_tile
::
index_t
kM0
=
kM0_
;
static
constexpr
ck_tile
::
index_t
kN1
=
kN1_
;
CK_TILE_HOST
static
constexpr
auto
GridSize
(
ck_tile
::
index_t
batch_size
,
ck_tile
::
index_t
nhead
,
ck_tile
::
index_t
max_seqlen_q
,
ck_tile
::
index_t
hdim_v
)
{
// TODO: this may need tuning
return
dim3
(
ck_tile
::
integer_divide_ceil
(
max_seqlen_q
,
kM0
)
*
ck_tile
::
integer_divide_ceil
(
hdim_v
,
kN1
),
nhead
,
batch_size
);
}
CK_TILE_DEVICE
auto
operator
()(
ck_tile
::
index_t
/*seqlen_q*/
,
ck_tile
::
index_t
hdim_v
)
{
const
index_t
num_tile_n1
=
ck_tile
::
integer_divide_ceil
(
hdim_v
,
kN1
);
const
index_t
i_block
=
blockIdx
.
x
;
const
index_t
i_nhead
=
blockIdx
.
y
;
const
index_t
i_batch
=
blockIdx
.
z
;
const
auto
f
=
[](
index_t
dividend
,
index_t
divisor
)
{
index_t
quotient
=
dividend
/
divisor
;
index_t
modulus
=
dividend
-
quotient
*
divisor
;
return
ck_tile
::
make_tuple
(
quotient
,
modulus
);
};
const
auto
[
i_tile_m
,
i_tile_n
]
=
f
(
i_block
,
num_tile_n1
);
return
ck_tile
::
make_tuple
(
i_tile_m
,
i_tile_n
,
i_nhead
,
i_batch
);
}
};
}
// namespace ck_tile
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp
View file @
c881136b
...
...
@@ -44,6 +44,7 @@ struct FmhaFwdSplitKVKernel
static
constexpr
bool
kPadHeadDimQ
=
FmhaPipeline
::
kPadHeadDimQ
;
static
constexpr
bool
kPadHeadDimV
=
FmhaPipeline
::
kPadHeadDimV
;
static
constexpr
auto
BiasEnum
=
FmhaPipeline
::
BiasEnum
;
static
constexpr
bool
kStoreLSE
=
FmhaPipeline
::
kStoreLSE
;
static
constexpr
bool
kDoFp8StaticQuant
=
FmhaPipeline
::
Problem
::
kDoFp8StaticQuant
;
static
constexpr
bool
kIsPagedKV
=
FmhaPipeline
::
Problem
::
kIsPagedKV
;
...
...
@@ -66,7 +67,8 @@ struct FmhaFwdSplitKVKernel
using
bfs
=
typename
FmhaPipeline
::
BlockFmhaShape
;
using
g0br
=
typename
bfs
::
Gemm0BlockWarps
;
using
g1br
=
typename
bfs
::
Gemm1BlockWarps
;
using
gwt
=
typename
bfs
::
Gemm0WarpTile
;
using
g0wt
=
typename
bfs
::
Gemm0WarpTile
;
using
g1wt
=
typename
bfs
::
Gemm1WarpTile
;
#define _SS_ std::string
#define _TS_ std::to_string
auto
pn
=
[
&
]
()
{
...
...
@@ -83,11 +85,12 @@ struct FmhaFwdSplitKVKernel
_TS_
(
bfs
::
kN1
)
+
"x"
+
_TS_
(
bfs
::
kK1
)
+
"x"
+
_TS_
(
bfs
::
kQKHeaddim
)
+
"_"
+
"r"
+
_TS_
(
g0br
::
at
(
ck_tile
::
number
<
0
>
{}))
+
"x"
+
_TS_
(
g0br
::
at
(
ck_tile
::
number
<
1
>
{}))
+
"x"
+
_TS_
(
g0br
::
at
(
ck_tile
::
number
<
2
>
{}))
+
"_"
+
"r"
+
_TS_
(
g1br
::
at
(
ck_tile
::
number
<
0
>
{}))
+
"x"
+
_TS_
(
g1br
::
at
(
ck_tile
::
number
<
1
>
{}))
+
"x"
+
_TS_
(
g1br
::
at
(
ck_tile
::
number
<
2
>
{}))
+
"_"
+
"w"
+
_TS_
(
gwt
::
at
(
ck_tile
::
number
<
0
>
{}))
+
"x"
+
_TS_
(
gwt
::
at
(
ck_tile
::
number
<
1
>
{}))
+
"x"
+
_TS_
(
gwt
::
at
(
ck_tile
::
number
<
2
>
{}))
+
"_"
+
"w"
+
_TS_
(
g0wt
::
at
(
ck_tile
::
number
<
0
>
{}))
+
"x"
+
_TS_
(
g0wt
::
at
(
ck_tile
::
number
<
1
>
{}))
+
"x"
+
_TS_
(
g0wt
::
at
(
ck_tile
::
number
<
2
>
{}))
+
"_"
+
"w"
+
_TS_
(
g1wt
::
at
(
ck_tile
::
number
<
0
>
{}))
+
"x"
+
_TS_
(
g1wt
::
at
(
ck_tile
::
number
<
1
>
{}))
+
"x"
+
_TS_
(
g1wt
::
at
(
ck_tile
::
number
<
2
>
{}))
+
"_"
+
(
kBlockPerCuInput
==
-
1
?
""
:
(
"o"
+
_TS_
(
kBlockPerCu
)
+
"_"
))
+
_SS_
(
FmhaPipeline
::
name
)
+
"_"
+
"v"
+
(
std
::
is_same_v
<
VLayout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
?
"r"
:
"c"
)
+
(
pn
.
empty
()
?
""
:
"_"
+
pn
)
+
(
BiasEnum
==
BlockAttentionBiasEnum
::
NO_BIAS
?
_SS_
(
""
)
:
(
_SS_
(
"_"
)
+
BlockAttentionBiasEnumToStr
<
BiasEnum
>::
name
))
+
(
kHasMask
?
"_"
+
_SS_
(
FmhaMask
::
name
)
:
""
)
+
(
kDoFp8StaticQuant
?
"_squant"
:
""
)
+
(
kIsPagedKV
?
"_pagedkv"
:
""
);
(
kHasMask
?
"_"
+
_SS_
(
FmhaMask
::
name
)
:
""
)
+
(
kStoreLSE
?
"_lse"
:
""
)
+
(
kDoFp8StaticQuant
?
"_squant"
:
""
)
+
(
kIsPagedKV
?
"_pagedkv"
:
""
);
#undef _SS_
#undef _TS_
// clang-format on
...
...
include/ck_tile/ops/fmha/kernel/fmha_fwd_tile_partitioner.hpp
deleted
100644 → 0
View file @
c5e8e14f
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace
ck_tile
{
template
<
typename
BlockFmhaShape_
>
struct
FmhaFwdTilePartitioner
{
using
BlockFmhaShape
=
ck_tile
::
remove_cvref_t
<
BlockFmhaShape_
>
;
static
constexpr
ck_tile
::
index_t
kM0
=
BlockFmhaShape
::
kM0
;
static
constexpr
ck_tile
::
index_t
kN0
=
BlockFmhaShape
::
kN0
;
static
constexpr
ck_tile
::
index_t
kK0
=
BlockFmhaShape
::
kK0
;
static
constexpr
ck_tile
::
index_t
kN1
=
BlockFmhaShape
::
kN1
;
static
constexpr
ck_tile
::
index_t
kK1
=
BlockFmhaShape
::
kK1
;
static
constexpr
const
char
*
name
=
"shb"
;
CK_TILE_HOST
static
constexpr
auto
GridSize
(
ck_tile
::
index_t
batch_size_
,
ck_tile
::
index_t
nhead_
,
ck_tile
::
index_t
seqlen_q_
,
ck_tile
::
index_t
hdim_v_
)
{
// TODO: this may need tuning
return
dim3
(
ck_tile
::
integer_divide_ceil
(
seqlen_q_
,
kM0
)
*
ck_tile
::
integer_divide_ceil
(
hdim_v_
,
kN1
),
nhead_
,
batch_size_
);
}
CK_TILE_DEVICE
auto
operator
()(
ck_tile
::
index_t
/*seqlen_q*/
,
ck_tile
::
index_t
hdim_v
)
{
// const index_t num_tile_m0 = seqlen_q / kM0;
const
index_t
num_tile_n1
=
ck_tile
::
integer_divide_ceil
(
hdim_v
,
kN1
);
const
index_t
i_block
=
blockIdx
.
x
;
const
index_t
i_nhead
=
blockIdx
.
y
;
const
index_t
i_batch
=
blockIdx
.
z
;
const
auto
f
=
[](
index_t
dividend
,
index_t
divisor
)
{
index_t
quotient
=
dividend
/
divisor
;
index_t
modulus
=
dividend
-
quotient
*
divisor
;
return
ck_tile
::
make_tuple
(
quotient
,
modulus
);
};
const
auto
[
i_tile_m
,
i_tile_n
]
=
f
(
i_block
,
num_tile_n1
);
return
ck_tile
::
make_tuple
(
i_tile_m
,
i_tile_n
,
i_nhead
,
i_batch
);
}
};
// Alias: FmhaFwdTilePartitioner already uses the "shb" grid layout
// (grid.x = seqlen_q/hdim_v tiles, grid.y = nhead, grid.z = batch).
template <typename BlockFmhaShape_>
using FmhaFwdTilePartitioner_SHB = FmhaFwdTilePartitioner<BlockFmhaShape_>;
template <typename BlockFmhaShape_>
struct FmhaFwdTilePartitioner_HBS
{
    using BlockFmhaShape = ck_tile::remove_cvref_t<BlockFmhaShape_>;

    static constexpr ck_tile::index_t kM0 = BlockFmhaShape::kM0;
    static constexpr ck_tile::index_t kN0 = BlockFmhaShape::kN0;
    static constexpr ck_tile::index_t kK0 = BlockFmhaShape::kK0;
    static constexpr ck_tile::index_t kN1 = BlockFmhaShape::kN1;
    static constexpr ck_tile::index_t kK1 = BlockFmhaShape::kK1;

    static constexpr const char* name = "hbs";

    /// Grid layout: x = nhead, y = batch, z = (#seqlen_q tiles) * (#hdim_v tiles).
    CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_,
                                                ck_tile::index_t nhead_,
                                                ck_tile::index_t seqlen_q_,
                                                ck_tile::index_t hdim_v_)
    {
        // TODO: this may need tuning
        const auto num_tile_m0 = ck_tile::integer_divide_ceil(seqlen_q_, kM0);
        const auto num_tile_n1 = ck_tile::integer_divide_ceil(hdim_v_, kN1);
        return dim3(nhead_, batch_size_, num_tile_m0 * num_tile_n1);
    }

    /// Map the current block to (i_tile_m, i_tile_n, i_nhead, i_batch).
    /// blockIdx.z linearizes the (m, n) tile pair; seqlen_q is unused here.
    CK_TILE_DEVICE auto operator()(ck_tile::index_t /*seqlen_q*/, ck_tile::index_t hdim_v)
    {
        const index_t num_tile_n1 = ck_tile::integer_divide_ceil(hdim_v, kN1);

        const index_t linear_tile = blockIdx.z;
        const index_t i_nhead     = blockIdx.x;
        const index_t i_batch     = blockIdx.y;

        // De-linearize blockIdx.z into the (m, n) tile coordinates.
        const index_t i_tile_m = linear_tile / num_tile_n1;
        const index_t i_tile_n = linear_tile - i_tile_m * num_tile_n1;

        return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
    }
};
}
// namespace ck_tile
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp
View file @
c881136b
...
...
@@ -53,6 +53,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline
using
OaccDataType
=
remove_cvref_t
<
typename
Problem
::
OaccDataType
>
;
using
ODataType
=
remove_cvref_t
<
typename
Problem
::
ODataType
>
;
static
constexpr
index_t
kNumWarps
=
Problem
::
kNumWarps
;
static
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
static
constexpr
index_t
kHeadDimV
=
Problem
::
kHeadDimV
;
...
...
@@ -117,7 +118,6 @@ struct BlockFmhaFwdSplitKVCombinePipeline
const
LSEElementFunction
&
lse_element_func
,
const
OaccElementFunction
&
o_acc_element_func
,
index_t
num_splits
,
index_t
seqlen_q
,
void
*
smem_ptr
)
const
{
// lse_acc tile in LDS
...
...
@@ -143,11 +143,12 @@ struct BlockFmhaFwdSplitKVCombinePipeline
// copy lse_acc tile (shape=[kMaxSplits, kM0]) to LDS (shape=[kMaxSplits, kM0]).
auto
lse_acc_tile
=
load_tile
(
lse_acc_dram_window
);
store_tile
(
lse_acc_lds_write_window
,
lse_acc_tile
);
block_sync_lds
();
auto
lse_accum
=
make_static_distributed_tensor
<
LSEDataType
>
(
Policy
::
template
MakeLSEaccRegTileDistribution
<
Problem
>());
__builtin_amdgcn_sched_barrier
(
0
);
block_sync_lds
();
// copy LDS (shape=[kM0, kMaxSplits]) to lse_accum (shape=[kM0, kMaxSplits])
// and fill up -INF values outside the [kM0, num_splits] region.
{
...
...
@@ -264,46 +265,94 @@ struct BlockFmhaFwdSplitKVCombinePipeline
}
});
}
block_sync_lds
();
if
constexpr
(
kStoreLSE
)
{
store_tile
(
lse_dram_window_tmp
,
tile_elementwise_in
(
lse_element_func
,
lse_logsum
));
}
auto
o_acc_dist
=
Policy
::
template
MakeOaccDramTileDistribution
<
Problem
>();
auto
o_acc_dram_window
=
auto
o_acc_
4_
dist
=
Policy
::
template
MakeOacc
4
DramTileDistribution
<
Problem
>();
auto
o_acc_
4_
dram_window
=
make_tile_window
(
o_acc_dram_block_window_tmp
.
get_bottom_tensor_view
(),
o_acc_dram_block_window_tmp
.
get_window_lengths
(),
o_acc_dram_block_window_tmp
.
get_window_origin
(),
o_acc_dist
);
auto
o_acc
=
make_static_distributed_tensor
<
OaccDataType
>
(
o_acc_dist
);
clear_tile
(
o_acc
);
o_acc_4_dist
);
const
index_t
padded_seqlen_q
=
integer_divide_ceil
(
seqlen_q
,
kM0
)
*
kM0
;
// shape=[4 * KM0, kN1]
auto
o_acc_4
=
make_static_distributed_tensor
<
OaccDataType
>
(
o_acc_4_dist
);
clear_tile
(
o_acc_4
);
for
(
index_t
i_split
=
0
;
i_split
<
num_splits
;
++
i_split
)
const
index_t
padded_num_splits
=
integer_divide_ceil
(
num_splits
,
kNumWarps
)
*
kNumWarps
;
__builtin_amdgcn_sched_barrier
(
0
);
block_sync_lds
();
// each warp handles a [KM0, kN1] tile
for
(
index_t
split_start
=
0
;
split_start
<
padded_num_splits
;
split_start
+=
kNumWarps
)
{
auto
o_tile
=
load_tile
(
o_acc_dram_window
);
auto
o_tile
=
load_tile
(
o_acc_4_dram_window
);
const
index_t
i_split
=
split_start
+
get_warp_id
();
const
index_t
row_start
=
kM0
*
get_warp_id
();
{
constexpr
auto
spans
=
decltype
(
o_acc
)
::
get_distributed_spans
();
constexpr
auto
spans
=
decltype
(
o_acc
_4
)
::
get_distributed_spans
();
sweep_tile_span
(
spans
[
number
<
0
>
{}],
[
&
](
auto
idx0
)
{
sweep_tile_span
(
spans
[
number
<
1
>
{}],
[
&
](
auto
idx1
)
{
constexpr
auto
i_j_idx
=
make_tuple
(
idx0
,
idx1
);
const
auto
x_indices
=
get_x_indices_from_distributed_indices
(
o_acc
.
get_tile_distribution
(),
i_j_idx
);
o_acc
_4
.
get_tile_distribution
(),
i_j_idx
);
const
auto
row
=
x_indices
.
at
(
number
<
0
>
{});
const
LSEDataType
lse_scale
=
lse_acc_lds
(
row
,
i_split
);
o_acc
(
i_j_idx
)
+=
lse_scale
*
o_tile
(
i_j_idx
);
const
LSEDataType
lse_scale
=
lse_acc_lds
(
row
-
row_start
,
i_split
);
o_acc
_4
(
i_j_idx
)
+=
lse_scale
*
o_tile
(
i_j_idx
);
});
});
}
move_tile_window
(
o_acc_dram_window
,
{
padded_seqlen_q
,
0
});
move_tile_window
(
o_acc_4_dram_window
,
{
kNumWarps
*
kM0
,
0
});
}
// 4 o_acc tiles in LDS. shape=[4 * kM0, kN1]
OaccDataType
*
o_acc_4_lds_ptr
=
static_cast
<
OaccDataType
*>
(
static_cast
<
void
*>
(
static_cast
<
char
*>
(
smem_ptr
)
+
Policy
::
template
GetSmemSizeLSEacc
<
Problem
>()));
{
auto
o_acc_4_lds_window
=
[
&
]()
{
auto
desc
=
Policy
::
template
MakeOacc4LdsBlockDescriptor
<
Problem
>();
auto
view
=
make_tensor_view
<
address_space_enum
::
lds
>
(
o_acc_4_lds_ptr
,
desc
);
return
make_tile_window
(
view
,
desc
.
get_lengths
(),
{
0
,
0
});
}();
store_tile
(
o_acc_4_lds_window
,
o_acc_4
);
}
auto
o_acc_dist
=
Policy
::
template
MakeOaccDramTileDistribution
<
Problem
>();
auto
o_acc_4_lds_window
=
[
&
]()
{
auto
desc
=
Policy
::
template
MakeOacc4LdsBlockDescriptor
<
Problem
>();
auto
view
=
make_tensor_view
<
address_space_enum
::
lds
>
(
o_acc_4_lds_ptr
,
desc
);
return
make_tile_window
(
view
,
desc
.
get_lengths
(),
{
0
,
0
},
o_acc_dist
);
}();
auto
o_acc
=
make_static_distributed_tensor
<
OaccDataType
>
(
o_acc_dist
);
clear_tile
(
o_acc
);
__builtin_amdgcn_sched_barrier
(
0
);
block_sync_lds
();
static_for
<
0
,
kNumWarps
,
1
>
{}([
&
](
auto
)
{
auto
o_acc_in
=
load_tile
(
o_acc_4_lds_window
);
{
constexpr
auto
spans
=
decltype
(
o_acc
)
::
get_distributed_spans
();
sweep_tile_span
(
spans
[
number
<
0
>
{}],
[
&
](
auto
idx0
)
{
sweep_tile_span
(
spans
[
number
<
1
>
{}],
[
&
](
auto
idx1
)
{
constexpr
auto
i_j_idx
=
make_tuple
(
idx0
,
idx1
);
o_acc
(
i_j_idx
)
+=
o_acc_in
(
i_j_idx
);
});
});
}
move_tile_window
(
o_acc_4_lds_window
,
{
kM0
,
0
});
});
o_acc
=
tile_elementwise_in
(
o_acc_element_func
,
o_acc
);
return
o_acc
;
...
...
@@ -316,7 +365,6 @@ struct BlockFmhaFwdSplitKVCombinePipeline
const
OaccDramBlockWindow
&
o_acc_dram_block_window
,
LSEDramBlockWindow
&
lse_dram_block_window
,
index_t
num_splits
,
index_t
seqlen_q
,
void
*
smem_ptr
)
const
{
return
operator
()(
lse_acc_dram_block_window
,
...
...
@@ -325,7 +373,6 @@ struct BlockFmhaFwdSplitKVCombinePipeline
identity
{},
identity
{},
num_splits
,
seqlen_q
,
smem_ptr
);
}
};
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp
View file @
c881136b
...
...
@@ -10,23 +10,38 @@ namespace ck_tile {
struct
BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
{
template
<
index_t
BlockSize
,
index_t
M
,
index_t
N
,
typename
DataType
>
template
<
index_t
NumWarps
,
index_t
M
,
index_t
N
,
typename
DataType
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetMaxNumWarpsForTile
()
{
static_assert
(
NumWarps
==
1
||
NumWarps
==
2
||
NumWarps
==
4
);
constexpr
index_t
ElemPerThread
=
(
M
*
N
)
/
(
NumWarps
*
get_warp_size
());
if
constexpr
(
0
<
ElemPerThread
)
{
return
NumWarps
;
}
else
{
// try dividing tile by smaller # of warps
return
GetMaxNumWarpsForTile
<
NumWarps
/
2
,
M
,
N
,
DataType
>
();
}
}
template
<
index_t
NumWarps
,
index_t
M
,
index_t
N
,
typename
DataType
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetVectorSizeForTile
()
{
constexpr
index_t
PixelsPerThread
=
(
M
*
N
)
/
BlockSize
;
static_assert
(
0
<
PixelsPerThread
);
constexpr
index_t
MaxNumWarps
=
GetMaxNumWarpsForTile
<
NumWarps
,
M
,
N
,
DataType
>
();
constexpr
index_t
MaxNPerThread
=
16
/
sizeof
(
DataType
);
constexpr
index_t
NPerThread
=
min
(
MaxNPerThread
,
PixelsPerThread
);
constexpr
index_t
ElemPerThread
=
(
M
*
N
)
/
(
MaxNumWarps
*
get_warp_size
());
return
NPerThread
;
constexpr
index_t
MaxNPerThread
=
16
/
sizeof
(
DataType
);
return
min
(
MaxNPerThread
,
ElemPerThread
);
}
// alignment for dram lse tile (shape=[kMaxSplits, kM0])
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetAlignmentLSE
()
{
return
GetVectorSizeForTile
<
Problem
::
k
BlockSize
,
return
GetVectorSizeForTile
<
Problem
::
k
NumWarps
,
Problem
::
kMaxSplits
,
Problem
::
kM0
,
typename
Problem
::
LSEDataType
>
();
...
...
@@ -56,40 +71,54 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSize
()
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSize
LSEacc
()
{
return
sizeof
(
typename
Problem
::
LSEDataType
)
*
MakeLSEaccLdsBlockDescriptor
<
Problem
>
().
get_element_space_size
();
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSizeOacc4
()
{
return
sizeof
(
typename
Problem
::
OaccDataType
)
*
MakeOacc4LdsBlockDescriptor
<
Problem
>
().
get_element_space_size
();
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSize
()
{
return
GetSmemSizeLSEacc
<
Problem
>
()
+
GetSmemSizeOacc4
<
Problem
>
();
}
// shape=[kMaxSplits, kM0]
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeLSEaccDramTileDistribution
()
{
using
LSEDataType
=
remove_cvref_t
<
typename
Problem
::
LSEDataType
>
;
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kNumWarps
=
Problem
::
kNumWarps
;
constexpr
index_t
kNPerBlock
=
Problem
::
kM0
;
constexpr
index_t
kMPerBlock
=
Problem
::
kMaxSplits
;
constexpr
index_t
kNPerBlock
=
Problem
::
kM0
;
constexpr
index_t
MaxNumWarps
=
GetMaxNumWarpsForTile
<
Problem
::
kNumWarps
,
kNPerBlock
,
kMPerBlock
,
LSEDataType
>
();
constexpr
index_t
Replicate
=
Problem
::
kNumWarps
/
MaxNumWarps
;
constexpr
index_t
NPerThread
=
GetVectorSizeForTile
<
kBlockSize
,
kMPerBlock
,
kNPerBlock
,
LSEDataType
>
();
GetVectorSizeForTile
<
MaxNumWarps
,
kMPerBlock
,
kNPerBlock
,
LSEDataType
>
();
constexpr
index_t
NThreads
=
kNPerBlock
/
NPerThread
;
constexpr
index_t
MThreadsPerWarp
=
get_warp_size
()
/
NThreads
;
constexpr
index_t
MPerThread
=
kMPerBlock
/
(
k
NumWarps
*
MThreadsPerWarp
);
constexpr
index_t
MPerThread
=
kMPerBlock
/
(
Max
NumWarps
*
MThreadsPerWarp
);
static_assert
(
MPerThread
*
MaxNumWarps
*
MThreadsPerWarp
==
kMPerBlock
);
static_assert
(
NThreads
*
NPerThread
==
kNPerBlock
);
static_assert
(
MPerThread
*
kNumWarps
*
MThreadsPerWarp
==
kMPerBlock
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
MPerThread
,
k
NumWarps
,
MThreadsPerWarp
>
,
tile_distribution_encoding
<
sequence
<
Replicate
>
,
tuple
<
sequence
<
MPerThread
,
Max
NumWarps
,
MThreadsPerWarp
>
,
sequence
<
NThreads
,
NPerThread
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
2
,
0
>>
,
tuple
<
sequence
<
0
,
1
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
0
,
1
>
,
sequence
<
2
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
1
>>
{});
}
...
...
@@ -100,17 +129,15 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
{
using
LSEDataType
=
remove_cvref_t
<
typename
Problem
::
LSEDataType
>
;
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kMPerBlock
=
Problem
::
kMaxSplits
;
constexpr
index_t
kNPerBlock
=
Problem
::
kM0
;
constexpr
index_t
kMPerBlock
=
Problem
::
kM0
;
constexpr
index_t
kNPerBlock
=
Problem
::
kMaxSplits
;
constexpr
index_t
NPack
=
GetVectorSizeForTile
<
kBlockSize
,
kMPerBlock
,
kNPerBlock
,
LSEDataType
>
();
GetVectorSizeForTile
<
Problem
::
kNumWarps
,
kMPerBlock
,
kNPerBlock
,
LSEDataType
>
();
constexpr
auto
lse_acc_lds_block_desc_0
=
make_naive_tensor_descriptor
(
make_tuple
(
number
<
kNPerBlock
/
NPack
>
{},
number
<
kMPerBlock
>
{},
number
<
NPack
>
{}),
make_tuple
(
number
<
(
kMPerBlock
+
1
)
*
NPack
>
{},
number
<
NPack
>
{},
number
<
1
>
{}),
number
<
8
>
{},
number
<
NPack
>
{},
number
<
1
>
{});
constexpr
auto
lse_acc_lds_block_desc
=
transform_tensor_descriptor
(
...
...
@@ -129,17 +156,15 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
{
using
LSEDataType
=
remove_cvref_t
<
typename
Problem
::
LSEDataType
>
;
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kMPerBlock
=
Problem
::
kMaxSplits
;
constexpr
index_t
kNPerBlock
=
Problem
::
kM0
;
constexpr
index_t
kMPerBlock
=
Problem
::
kM0
;
constexpr
index_t
kNPerBlock
=
Problem
::
kMaxSplits
;
constexpr
index_t
NPack
=
GetVectorSizeForTile
<
kBlockSize
,
kMPerBlock
,
kNPerBlock
,
LSEDataType
>
();
GetVectorSizeForTile
<
Problem
::
kNumWarps
,
kMPerBlock
,
kNPerBlock
,
LSEDataType
>
();
constexpr
auto
lse_acc_lds_block_desc_0
=
make_naive_tensor_descriptor
(
make_tuple
(
number
<
kNPerBlock
/
NPack
>
{},
number
<
kMPerBlock
>
{},
number
<
NPack
>
{}),
make_tuple
(
number
<
(
kMPerBlock
+
1
)
*
NPack
>
{},
number
<
NPack
>
{},
number
<
1
>
{}),
number
<
8
>
{},
number
<
NPack
>
{},
number
<
1
>
{});
constexpr
auto
lse_acc_t_lds_block_desc
=
transform_tensor_descriptor
(
...
...
@@ -152,33 +177,86 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
return
lse_acc_t_lds_block_desc
;
}
// 3d + padding, shape=[4 * kM0, kN1]
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
Make
LSEaccRegTileDistribution
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
Make
Oacc4LdsBlockDescriptor
()
{
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
using
LSEDataType
=
remove_cvref_t
<
typename
Problem
::
LSEDataType
>
;
constexpr
index_t
kNPerBlock
=
Problem
::
kMaxSplits
;
constexpr
index_t
kMPerBlock
=
4
*
Problem
::
kM0
;
constexpr
index_t
kNPerBlock
=
Problem
::
kN1
;
constexpr
index_t
NPack
=
GetVectorSizeForTile
<
Problem
::
kNumWarps
,
kMPerBlock
,
kNPerBlock
,
LSEDataType
>
();
constexpr
auto
o_acc_lds_block_desc_0
=
make_naive_tensor_descriptor
(
make_tuple
(
number
<
kNPerBlock
/
NPack
>
{},
number
<
kMPerBlock
>
{},
number
<
NPack
>
{}),
make_tuple
(
number
<
(
kMPerBlock
+
1
)
*
NPack
>
{},
number
<
NPack
>
{},
number
<
1
>
{}),
number
<
8
>
{},
number
<
1
>
{});
constexpr
auto
o_acc_t_lds_block_desc
=
transform_tensor_descriptor
(
o_acc_lds_block_desc_0
,
make_tuple
(
make_pass_through_transform
(
kMPerBlock
),
make_merge_transform
(
make_tuple
(
kNPerBlock
/
NPack
,
NPack
))),
make_tuple
(
sequence
<
1
>
{},
sequence
<
0
,
2
>
{}),
make_tuple
(
sequence
<
1
>
{},
sequence
<
0
>
{}));
return
o_acc_t_lds_block_desc
;
}
// shape=[kM0, kMaxSplits]
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeLSEaccRegTileDistribution
()
{
constexpr
index_t
kMPerBlock
=
Problem
::
kM0
;
constexpr
index_t
kNPerBlock
=
Problem
::
kMaxSplits
;
constexpr
index_t
NThreads
=
4
;
constexpr
index_t
NPerThread
=
kNPerBlock
/
NThreads
;
constexpr
index_t
MaxNThreads
=
8
;
constexpr
index_t
NThreads
=
min
(
kNPerBlock
,
MaxNThreads
);
constexpr
index_t
NPerThread
=
kNPerBlock
/
NThreads
;
constexpr
index_t
MThreads
=
kBlockSize
/
NThreads
;
constexpr
index_t
MPerThread
=
kMPerBlock
/
MThreads
;
constexpr
index_t
MWarps
=
kBlockSize
/
get_warp_size
();
constexpr
index_t
MPerThread
=
1
;
constexpr
index_t
MThreads
=
kMPerBlock
/
MPerThread
;
constexpr
index_t
MThreadPerWarp
=
get_warp_size
()
/
NThreads
;
constexpr
index_t
MaxNumWarps
=
(
MThreads
*
NThreads
)
/
get_warp_size
();
constexpr
index_t
Replicate
=
Problem
::
kNumWarps
/
MaxNumWarps
;
static_assert
(
MaxNumWarps
*
MThreadPerWarp
*
MPerThread
==
kMPerBlock
);
static_assert
(
NThreads
*
NPerThread
==
kNPerBlock
);
static_assert
(
MWarps
*
MThreadPerWarp
*
MPerThread
==
kMPerBlock
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
MWarps
,
MThreadPerWarp
,
MPerThread
>
,
sequence
<
NThreads
,
NPerThread
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
2
,
1
>>
,
tuple
<
sequence
<
0
>
,
sequence
<
0
,
1
>>
,
sequence
<
1
,
2
>
,
sequence
<
2
,
1
>>
{});
tile_distribution_encoding
<
sequence
<
Replicate
>
,
tuple
<
sequence
<
MaxNumWarps
,
MThreadPerWarp
,
MPerThread
>
,
sequence
<
NThreads
,
NPerThread
>>
,
tuple
<
sequence
<
0
,
1
>
,
sequence
<
2
,
1
>>
,
tuple
<
sequence
<
0
,
0
>
,
sequence
<
0
,
1
>>
,
sequence
<
1
,
2
>
,
sequence
<
2
,
1
>>
{});
}
// similar to MakeOaccDramTileDistribution(), but duplicate same 1-warp encoding 4 times on M
// direction
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeOacc4DramTileDistribution
()
{
constexpr
index_t
kMPerBlock
=
Problem
::
kM0
;
// real kMPerBlock we want is (4 * kM0)
constexpr
index_t
kNPerBlock
=
Problem
::
kN1
;
static_assert
(
get_warp_size
()
<=
kMPerBlock
*
kNPerBlock
);
constexpr
index_t
M1
=
1
;
// compose encoding base on 1 warp
constexpr
index_t
M2
=
min
(
kMPerBlock
/
M1
,
get_warp_size
());
constexpr
index_t
N0
=
get_warp_size
()
/
M2
;
constexpr
index_t
N1
=
kNPerBlock
/
N0
;
constexpr
index_t
M0
=
kMPerBlock
/
(
M2
*
M1
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
4
,
M0
,
M1
,
M2
>
,
sequence
<
N0
,
N1
>>
,
tuple
<
sequence
<
1
,
1
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
0
,
2
>
,
sequence
<
3
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
1
,
1
>>
{});
}
template
<
typename
Problem
>
...
...
@@ -187,6 +265,7 @@ struct BlockFmhaFwdSplitKVCombinePipelineDefaultPolicy
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kMPerBlock
=
Problem
::
kM0
;
constexpr
index_t
kNPerBlock
=
Problem
::
kN1
;
static_assert
(
kBlockSize
<=
kMPerBlock
*
kNPerBlock
);
constexpr
index_t
M1
=
kBlockSize
/
get_warp_size
();
constexpr
index_t
M2
=
min
(
kMPerBlock
/
M1
,
get_warp_size
());
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp
0 → 100644
View file @
c881136b
This diff is collapsed.
Click to expand it.
Prev
1
2
3
4
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment