gaoqiong / composable_kernel_ROCM / Commits / b102083b

Commit b102083b
authored Jan 01, 2025 by Po Yen Chen

    Only launch splitkv combine kernel when necessary

parent ab1b16ac

Showing 2 changed files with 206 additions and 162 deletions (+206 -162)

    example/ck_tile/01_fmha/fmha_fwd.cpp   +84  -80
    example/ck_tile/01_fmha/fmha_fwd.hpp   +122 -82
example/ck_tile/01_fmha/fmha_fwd.cpp  (view file @ b102083b)
...
@@ -11,6 +11,7 @@
 #include <array>
 #include <cstring>
 #include <functional>
+#include <map>
 #include <numeric>
 #include <ostream>
 #include <string>
...
@@ -176,61 +177,14 @@ auto get_elimit<FmhaFwdFp8>(std::string init_method)
     }
 }
 
-int num_splits_heuristic(int batch_nhead_mblocks, int num_SMs, int num_n_blocks, int max_splits)
-{
-    // If we have enough to almost fill the SMs, then just use 1 split
-    if(batch_nhead_mblocks >= 0.8f * num_SMs)
-    {
-        return 1;
-    }
-    max_splits = std::min({max_splits, num_SMs, num_n_blocks});
-    float max_efficiency = 0.f;
-    std::vector<float> efficiency;
-    efficiency.reserve(max_splits);
-    auto ceildiv = [](int a, int b) { return (a + b - 1) / b; };
-    // Some splits are not eligible. For example, if we have 64 blocks and choose 11 splits,
-    // we'll have 6 * 10 + 4 blocks. If we choose 12 splits, we'll have 6 * 11 + (-2) blocks
-    // (i.e. it's 11 splits anyway).
-    // So we check if the number of blocks per split is the same as the previous num_splits.
-    auto is_split_eligible = [&ceildiv, &num_n_blocks](int num_splits) {
-        return num_splits == 1 ||
-               ceildiv(num_n_blocks, num_splits) != ceildiv(num_n_blocks, num_splits - 1);
-    };
-    for(int num_splits = 1; num_splits <= max_splits; num_splits++)
-    {
-        if(!is_split_eligible(num_splits))
-        {
-            efficiency.push_back(0.f);
-        }
-        else
-        {
-            float n_waves = float(batch_nhead_mblocks * num_splits) / num_SMs;
-            float eff     = n_waves / ceil(n_waves);
-            // printf("num_splits = %d, eff = %f\n", num_splits, eff);
-            if(eff > max_efficiency)
-            {
-                max_efficiency = eff;
-            }
-            efficiency.push_back(eff);
-        }
-    }
-    for(int num_splits = 1; num_splits <= max_splits; num_splits++)
-    {
-        if(!is_split_eligible(num_splits))
-        {
-            continue;
-        }
-        if(efficiency[num_splits - 1] >= 0.85 * max_efficiency)
-        {
-            // printf("num_splits chosen = %d\n", num_splits);
-            return num_splits;
-        }
-    }
-    return 1;
-}
-
-int override_num_splits_if_necessary(
-    int batch, int nhead, int max_seqlen_q, int hdim_v, float p_drop, int num_splits)
+int override_num_splits_if_necessary(int batch,
+                                     int nhead,
+                                     int max_seqlen_q,
+                                     int hdim_q,
+                                     int hdim_v,
+                                     float p_drop,
+                                     bool is_prefill,
+                                     int num_splits)
 {
     int device;
     auto status = hipGetDevice(&device);
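The eligibility check in the removed helper (now relocated to fmha_fwd.hpp, see below) is easiest to see with concrete numbers. A standalone sketch of the same ceildiv reasoning, reusing the 64-block example from the comment above; it is an illustration only, not code from this commit:

    #include <cstdio>

    // With 64 n-blocks, 11 splits give ceildiv(64, 11) = 6 blocks per split,
    // and 12 splits also give ceildiv(64, 12) = 6 -- so 12 splits do no extra
    // useful work and the heuristic marks them ineligible.
    int main()
    {
        auto ceildiv = [](int a, int b) { return (a + b - 1) / b; };
        const int num_n_blocks = 64;
        for(int num_splits = 10; num_splits <= 12; ++num_splits)
        {
            const bool eligible = num_splits == 1 ||
                                  ceildiv(num_n_blocks, num_splits) !=
                                      ceildiv(num_n_blocks, num_splits - 1);
            std::printf("num_splits = %d, blocks/split = %d, eligible = %d\n",
                        num_splits, ceildiv(num_n_blocks, num_splits), eligible);
        }
        return 0;
    }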
...
@@ -246,17 +200,41 @@ int override_num_splits_if_necessary(
         return num_splits;
     }
 
     // tile size should match the generate.py
-    const int kM0 = 64;
-    const int kN1 = hdim_v;
+    const int kM0 = [&] {
+        // get kM0 for prefill phase
+        if(is_prefill)
+        {
+            return 128;
+        }
+
+        // get kM0 for decode phase
+        /// TODO: take dtype=fp8/bf8 into consideration
+        const std::map<int, int> hdim_to_m0 = {
+            {32, 32},
+            {64, 64},
+            // {96, 64},
+            {128, 64},
+            {256, 64},
+        };
+        for(auto [hdim, m0] : hdim_to_m0)
+        {
+            if(hdim_q <= hdim && hdim_v <= hdim)
+            {
+                return m0;
+            }
+        }
+        return 64; // meet unsupported hdim_q/hdim_v
+    }();
+    // const int kN1 = hdim_v;
 
     const int num_m_blocks = ck_tile::integer_divide_ceil(max_seqlen_q, kM0);
-    const int num_n_blocks = ck_tile::integer_divide_ceil(hdim_v, kN1);
+    // const int num_n_blocks = ck_tile::integer_divide_ceil(hdim_v, kN1); // always 1
 
     if(num_splits < 1 && p_drop == 0.0f)
     {
         return num_splits_heuristic(
-            batch * nhead * num_m_blocks, props.multiProcessorCount * 2, num_n_blocks, 128);
+            batch * nhead * num_m_blocks, props.multiProcessorCount * 2, 8);
     }
 
     return num_splits;
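To make the new decode-phase tile selection concrete, here is a small standalone restatement of the kM0 lookup and the num_m_blocks arithmetic it feeds. The pick_m0 helper name and the example shape are illustrative, not part of the commit:

    #include <cstdio>
    #include <map>

    // Same table as above; head dims outside the table fall back to 64.
    static int pick_m0(int hdim_q, int hdim_v)
    {
        const std::map<int, int> hdim_to_m0 = {{32, 32}, {64, 64}, {128, 64}, {256, 64}};
        for(auto [hdim, m0] : hdim_to_m0)
            if(hdim_q <= hdim && hdim_v <= hdim)
                return m0;
        return 64;
    }

    int main()
    {
        // e.g. a decode step: one query token, head dims 128/128
        const int kM0          = pick_m0(128, 128);   // -> 64
        const int num_m_blocks = (1 + kM0 - 1) / kM0; // ceil(1 / 64) = 1
        std::printf("kM0 = %d, num_m_blocks = %d\n", kM0, num_m_blocks);
        return 0;
    }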
...
@@ -556,8 +534,15 @@ bool run(const ck_tile::ArgParser& arg_parser)
     // legalize num_splits according to other options
     if(num_splits < 1)
     {
         num_splits = override_num_splits_if_necessary(
-            batch, nhead, max_seqlen_q, hdim_v, p_drop, num_splits);
+            batch,
+            nhead,
+            max_seqlen_q,
+            hdim_q,
+            hdim_v,
+            p_drop,
+            /*is_prefill=*/mode == mode_enum::group && 0 < page_block_size,
+            num_splits);
     }
 
     if(128 < num_splits)
     {
...
@@ -632,17 +617,18 @@ bool run(const ck_tile::ArgParser& arg_parser)
     auto [rotary_cos_host, rotary_sin_host] = generate_rotary_cos_sin<KDataType>(
         std::max(shape_seqlen_q, shape_seqlen_k), rotary_dim, seed);
 
+    // lse_acc_host & o_acc_host are only used when 1 < num_spilts
     ck_tile::HostTensor<LSEDataType> lse_acc_host(
-        1 < num_splits || use_kvcache
+        1 < num_splits
             ? std::array<ck_tile::index_t, 4>{shape_batch, nhead, num_splits, shape_seqlen_q}
             : std::array<ck_tile::index_t, 4>{1, 1, 1, 1});
     ck_tile::HostTensor<OaccDataType> o_acc_host(
-        1 < num_splits || use_kvcache
+        1 < num_splits
            ? std::array<ck_tile::index_t, 5>{
                  shape_batch, nhead, num_splits, shape_seqlen_q, hdim_v}
            : std::array<ck_tile::index_t, 5>{1, 1, 1, 1, 1});
 
     // batch mode of lse data layout is [batch, nhead, seqlen_q]
     // group mode of lse data layout is [nhead, total_seqlen_q]
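For orientation, these per-split buffers hold one partial output and one partial log-sum-exp per split; the combine kernel reduces them into the final o/lse. A minimal host-side sketch of that reduction for a single output element, written from the standard split-softmax identity rather than from this repository's kernel sources (function name and scalar simplification are mine):

    #include <algorithm>
    #include <cmath>
    #include <cstddef>
    #include <vector>

    // lse = log(sum_s exp(lse_s));  o = sum_s exp(lse_s - lse) * o_s
    // With a single split, lse == lse_0 and o == o_0, which is why the combine
    // launch can be skipped entirely when num_splits == 1.
    float combine_splits(const std::vector<float>& lse_acc,
                         const std::vector<float>& o_acc,
                         float& lse_out)
    {
        const float m = *std::max_element(lse_acc.begin(), lse_acc.end());
        float sum     = 0.f;
        for(float l : lse_acc)
            sum += std::exp(l - m);
        lse_out = m + std::log(sum); // combined log-sum-exp

        float o = 0.f;
        for(std::size_t s = 0; s < lse_acc.size(); ++s)
            o += std::exp(lse_acc[s] - lse_out) * o_acc[s]; // re-weight each partial output
        return o; // combined output element
    }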
...
@@ -1043,9 +1029,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
         }
         else if constexpr(std::is_same_v<fmha_fwd_splitkv_args, std::decay_t<decltype(args)>>)
         {
-            args.lse_acc_ptr = lse_acc_buf.GetDeviceBuffer();
-            args.o_acc_ptr   = o_acc_buf.GetDeviceBuffer();
+            // lse_acc_buf & o_acc_buf are only used when 1 < num_spilts
             args.block_table_ptr =
                 (0 < page_block_size ? block_table_buf.GetDeviceBuffer() : nullptr);
             args.batch_stride_block_table = batch_stride_block_table;
...
@@ -1057,13 +1041,33 @@ bool run(const ck_tile::ArgParser& arg_parser)
             args.num_splits = num_splits;
 
-            args.stride_o_acc         = stride_o_acc;
-            args.nhead_stride_lse_acc = nhead_stride_lse_acc;
-            args.nhead_stride_o_acc   = nhead_stride_o_acc;
-            args.batch_stride_lse_acc = batch_stride_lse_acc;
-            args.batch_stride_o_acc   = batch_stride_o_acc;
-            args.split_stride_lse_acc = split_stride_lse_acc;
-            args.split_stride_o_acc   = split_stride_o_acc;
+            if(1 < num_splits)
+            {
+                args.lse_acc_ptr = lse_acc_buf.GetDeviceBuffer();
+                args.o_acc_ptr   = o_acc_buf.GetDeviceBuffer();
+
+                args.stride_o_acc         = stride_o_acc;
+                args.nhead_stride_lse_acc = nhead_stride_lse_acc;
+                args.nhead_stride_o_acc   = nhead_stride_o_acc;
+                args.batch_stride_lse_acc = batch_stride_lse_acc;
+                args.batch_stride_o_acc   = batch_stride_o_acc;
+                args.split_stride_lse_acc = split_stride_lse_acc;
+                args.split_stride_o_acc   = split_stride_o_acc;
+            }
+            else
+            {
+                // following attribues are ignored by fmha_fwd_splitkv()
+                args.lse_acc_ptr = nullptr;
+                args.o_acc_ptr   = nullptr;
+
+                args.stride_o_acc         = 0;
+                args.nhead_stride_lse_acc = 0;
+                args.nhead_stride_o_acc   = 0;
+                args.batch_stride_lse_acc = 0;
+                args.batch_stride_o_acc   = 0;
+                args.split_stride_lse_acc = 0;
+                args.split_stride_o_acc   = 0;
+            }
         }
     };
...
example/ck_tile/01_fmha/fmha_fwd.hpp  (view file @ b102083b)
...
@@ -12,6 +12,7 @@
 #include "mask.hpp"
 #include "rotary.hpp"
+#include <array>
 #include <type_traits>
 #include <utility>
 #include <variant>
...
@@ -422,91 +423,93 @@ auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args)
         // create group mode kernel arguments
         if constexpr(Kernel::kIsGroupMode)
         {
             return Kernel::MakeKargs(
                 args.q_ptr,
                 args.k_ptr,
                 args.v_ptr,
                 args.bias_ptr,
-                args.lse_acc_ptr,
-                args.o_acc_ptr,
+                (1 < args.num_splits ? args.lse_acc_ptr : args.lse_ptr),
+                (1 < args.num_splits ? args.o_acc_ptr : args.o_ptr),
                 args.batch,
                 args.seqstart_q_ptr,
                 args.seqstart_k_ptr,
                 args.seqlen_k_ptr,
                 args.hdim_q,
                 args.hdim_v,
                 args.nhead_q,
                 args.nhead_q / args.nhead_k,
                 args.num_splits,
                 args.block_table_ptr,
                 args.batch_stride_block_table,
                 args.page_block_size,
                 args.is_gappy,
                 args.scale_s,
                 args.scale_p,
                 args.stride_q,
                 args.stride_k,
                 args.stride_v,
                 args.stride_bias,
-                args.stride_o_acc,
+                (1 < args.num_splits ? args.stride_o_acc : args.stride_o),
                 args.nhead_stride_q,
                 args.nhead_stride_k,
                 args.nhead_stride_v,
                 args.nhead_stride_bias,
-                args.nhead_stride_lse_acc,
-                args.nhead_stride_o_acc,
+                (1 < args.num_splits ? args.nhead_stride_lse_acc : args.nhead_stride_lse),
+                (1 < args.num_splits ? args.nhead_stride_o_acc : args.nhead_stride_o),
                 args.batch_stride_k, // only used for paged-kvcache
                 args.batch_stride_v, // only used for paged-kvcache
-                args.split_stride_lse_acc,
-                args.split_stride_o_acc,
+                (1 < args.num_splits ? args.split_stride_lse_acc : 0),
+                (1 < args.num_splits ? args.split_stride_o_acc : 0),
                 args.window_size_left,
                 args.window_size_right,
                 args.mask_type);
         }
         else
         { // create batch mode kernel arguments
             return Kernel::MakeKargs(
                 args.q_ptr,
                 args.k_ptr,
                 args.v_ptr,
                 args.bias_ptr,
-                args.lse_acc_ptr,
-                args.o_acc_ptr,
+                (1 < args.num_splits ? args.lse_acc_ptr : args.lse_ptr),
+                (1 < args.num_splits ? args.o_acc_ptr : args.o_ptr),
                 args.batch,
                 args.seqlen_q,
                 args.seqlen_k,
                 args.seqlen_k_ptr,
                 args.hdim_q,
                 args.hdim_v,
                 args.nhead_q,
                 args.nhead_q / args.nhead_k,
                 args.num_splits,
                 args.block_table_ptr,
                 args.batch_stride_block_table,
                 args.page_block_size,
                 args.cache_batch_idx,
                 args.scale_s,
                 args.scale_p,
                 args.stride_q,
                 args.stride_k,
                 args.stride_v,
                 args.stride_bias,
-                args.stride_o_acc,
+                (1 < args.num_splits ? args.stride_o_acc : args.stride_o),
                 args.nhead_stride_q,
                 args.nhead_stride_k,
                 args.nhead_stride_v,
                 args.nhead_stride_bias,
-                args.nhead_stride_lse_acc,
-                args.nhead_stride_o_acc,
+                (1 < args.num_splits ? args.nhead_stride_lse_acc : args.nhead_stride_lse),
+                (1 < args.num_splits ? args.nhead_stride_o_acc : args.nhead_stride_o),
                 args.batch_stride_q,
                 args.batch_stride_k,
                 args.batch_stride_v,
                 args.batch_stride_bias,
-                args.batch_stride_lse_acc,
-                args.batch_stride_o_acc,
-                args.split_stride_lse_acc,
-                args.split_stride_o_acc,
+                (1 < args.num_splits ? args.batch_stride_lse_acc : args.batch_stride_lse),
+                (1 < args.num_splits ? args.batch_stride_o_acc : args.batch_stride_o),
+                (1 < args.num_splits ? args.split_stride_lse_acc : 0),
+                (1 < args.num_splits ? args.split_stride_o_acc : 0),
                 args.window_size_left,
                 args.window_size_right,
                 args.mask_type);
         }
     }();
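Together with the fmha_fwd.cpp changes above, this routing means the splitkv kernel already writes the final o/lse when num_splits == 1, so the combine kernel only needs to be launched for a real split. A hedged sketch of that gate; the two callables are placeholders for the generated launchers, not functions defined by this commit:

    #include <functional>

    // ArgsT stands in for fmha_fwd_splitkv_args; only the num_splits member is used here.
    template <typename ArgsT>
    void run_splitkv_sketch(const ArgsT& args,
                            const std::function<void(const ArgsT&)>& launch_splitkv,
                            const std::function<void(const ArgsT&)>& launch_combine)
    {
        launch_splitkv(args);     // with num_splits == 1 this already wrote o/lse
        if(1 < args.num_splits)
            launch_combine(args); // otherwise reduce o_acc/lse_acc into o/lse
    }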
...
@@ -821,3 +824,40 @@ struct fmha_fwd_appendkv_traits
 float fmha_fwd_appendkv(fmha_fwd_appendkv_traits,
                         fmha_fwd_appendkv_args,
                         const ck_tile::stream_config&);
+
+template <typename Int = int>
+Int num_splits_heuristic(Int batch_nhead_mblocks, Int num_SMs, Int max_splits)
+{
+    // If we have enough to almost fill the SMs, then just use 1 split
+    if(batch_nhead_mblocks >= 0.8f * num_SMs)
+    {
+        return 1;
+    }
+    max_splits = std::min({max_splits, num_SMs});
+
+    constexpr std::array<Int, 5> num_splits_array = {1, 2, 4, 8, 16};
+
+    float max_efficiency = 0.f;
+    std::array<float, num_splits_array.size()> efficiency;
+
+    for(size_t idx = 0; idx < num_splits_array.size() && num_splits_array[idx] <= max_splits;
+        ++idx)
+    {
+        float n_blocks = float(batch_nhead_mblocks * num_splits_array[idx]) / num_SMs;
+        float eff      = n_blocks / std::ceil(n_blocks);
+
+        if(eff > max_efficiency)
+        {
+            max_efficiency = eff;
+        }
+        efficiency[idx] = eff;
+    }
+    for(size_t idx = 0; idx < num_splits_array.size() && num_splits_array[idx] <= max_splits;
+        ++idx)
+    {
+        if(efficiency[idx] >= 0.85 * max_efficiency)
+        {
+            return num_splits_array[idx];
+        }
+    }
+    return 1;
+}
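A quick usage sketch for the relocated heuristic; the shape numbers are made up for illustration, while the argument pattern mirrors the fmha_fwd.cpp call site (batch * nhead * num_m_blocks, multiProcessorCount * 2, max_splits = 8):

    #include <cstdio>

    // Assumes the num_splits_heuristic template shown above is visible (e.g. via fmha_fwd.hpp).
    int main()
    {
        // Decode-like shape: batch = 2, nhead = 8, num_m_blocks = 1, on an 80-CU GPU.
        const int num_splits = num_splits_heuristic(/*batch_nhead_mblocks=*/2 * 8 * 1,
                                                    /*num_SMs=*/80 * 2,
                                                    /*max_splits=*/8);
        // 16 blocks over 160 compute slots: wave efficiency is 0.1, 0.2, 0.4, 0.8 for
        // 1, 2, 4, 8 splits, so 8 splits are selected for these numbers.
        std::printf("num_splits = %d\n", num_splits);
        return 0;
    }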