Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
7a985548
Commit
7a985548
authored
May 22, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.9.0' into v0.9.0-ori
parents
45d3785c
dc1440cf
Changes
486
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1884 additions
and
992 deletions
+1884
-992
benchmarks/kernels/benchmark_paged_attention.py
benchmarks/kernels/benchmark_paged_attention.py
+45
-52
benchmarks/kernels/benchmark_quant.py
benchmarks/kernels/benchmark_quant.py
+38
-33
benchmarks/kernels/benchmark_rmsnorm.py
benchmarks/kernels/benchmark_rmsnorm.py
+25
-34
benchmarks/kernels/benchmark_rope.py
benchmarks/kernels/benchmark_rope.py
+46
-37
benchmarks/kernels/benchmark_w8a8_block_fp8.py
benchmarks/kernels/benchmark_w8a8_block_fp8.py
+53
-60
benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py
...hmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py
+4
-2
benchmarks/kernels/graph_machete_bench.py
benchmarks/kernels/graph_machete_bench.py
+16
-17
benchmarks/kernels/utils.py
benchmarks/kernels/utils.py
+26
-26
benchmarks/overheads/benchmark_hashing.py
benchmarks/overheads/benchmark_hashing.py
+19
-17
benchmarks/pyproject.toml
benchmarks/pyproject.toml
+54
-0
benchmarks/run_structured_output_benchmark.sh
benchmarks/run_structured_output_benchmark.sh
+87
-23
cmake/cpu_extension.cmake
cmake/cpu_extension.cmake
+32
-1
cmake/utils.cmake
cmake/utils.cmake
+69
-20
csrc/activation_kernels.cu
csrc/activation_kernels.cu
+3
-0
csrc/attention/attention_kernels.cuh
csrc/attention/attention_kernels.cuh
+655
-655
csrc/attention/vertical_slash_index.cu
csrc/attention/vertical_slash_index.cu
+401
-0
csrc/core/math.hpp
csrc/core/math.hpp
+19
-0
csrc/core/scalar_type.hpp
csrc/core/scalar_type.hpp
+3
-0
csrc/cpu/cpu_types_vsx.hpp
csrc/cpu/cpu_types_vsx.hpp
+265
-0
csrc/cpu/pos_encoding.cpp
csrc/cpu/pos_encoding.cpp
+24
-15
No files found.
Too many changes to show.
To preserve performance only
486 of 486+
files are displayed.
Plain diff
Email patch
benchmarks/kernels/benchmark_paged_attention.py
View file @
7a985548
...
...
@@ -9,8 +9,11 @@ import torch
from
vllm
import
_custom_ops
as
ops
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
from
vllm.utils
import
(
STR_DTYPE_TO_TORCH_DTYPE
,
FlexibleArgumentParser
,
create_kv_caches_with_random
)
from
vllm.utils
import
(
STR_DTYPE_TO_TORCH_DTYPE
,
FlexibleArgumentParser
,
create_kv_caches_with_random
,
)
logger
=
init_logger
(
__name__
)
...
...
@@ -38,19 +41,15 @@ def main(
current_platform
.
seed_everything
(
seed
)
scale
=
float
(
1.0
/
(
head_size
**
0.5
))
query
=
torch
.
empty
(
num_seqs
,
num_query_heads
,
head_size
,
dtype
=
dtype
,
device
=
device
)
query
=
torch
.
empty
(
num_seqs
,
num_query_heads
,
head_size
,
dtype
=
dtype
,
device
=
device
)
query
.
uniform_
(
-
scale
,
scale
)
assert
num_query_heads
%
num_kv_heads
==
0
alibi_slopes
=
None
if
use_alibi
:
alibi_slopes
=
torch
.
randn
(
num_query_heads
,
dtype
=
torch
.
float
,
device
=
device
)
alibi_slopes
=
torch
.
randn
(
num_query_heads
,
dtype
=
torch
.
float
,
device
=
device
)
seq_lens
=
[
seq_len
for
_
in
range
(
num_seqs
)]
max_seq_len
=
max
(
seq_lens
)
...
...
@@ -61,24 +60,23 @@ def main(
block_tables_lst
:
list
[
list
[
int
]]
=
[]
for
_
in
range
(
num_seqs
):
block_table
=
[
random
.
randint
(
0
,
NUM_BLOCKS
-
1
)
for
_
in
range
(
max_num_blocks_per_seq
)
random
.
randint
(
0
,
NUM_BLOCKS
-
1
)
for
_
in
range
(
max_num_blocks_per_seq
)
]
block_tables_lst
.
append
(
block_table
)
block_tables
=
torch
.
tensor
(
block_tables_lst
,
dtype
=
torch
.
int
,
device
=
device
)
block_tables
=
torch
.
tensor
(
block_tables_lst
,
dtype
=
torch
.
int
,
device
=
device
)
# Create the KV cache.
key_caches
,
value_caches
=
create_kv_caches_with_random
(
NUM_BLOCKS
,
block_size
,
1
,
num_kv_heads
,
head_size
,
kv_cache_dtype
,
dtype
,
device
=
device
)
key_caches
,
value_caches
=
create_kv_caches_with_random
(
NUM_BLOCKS
,
block_size
,
1
,
num_kv_heads
,
head_size
,
kv_cache_dtype
,
dtype
,
device
=
device
,
)
key_cache
,
value_cache
=
key_caches
[
0
],
value_caches
[
0
]
# Prepare for the paged attention kernel.
...
...
@@ -86,11 +84,8 @@ def main(
if
version
==
"v2"
:
if
current_platform
.
is_rocm
():
global
PARTITION_SIZE
if
not
args
.
custom_paged_attn
:
PARTITION_SIZE
=
1024
else
:
PARTITION_SIZE
=
PARTITION_SIZE_ROCM
num_partitions
=
((
max_seq_len
+
PARTITION_SIZE
-
1
)
//
PARTITION_SIZE
)
PARTITION_SIZE
=
1024
if
not
args
.
custom_paged_attn
else
PARTITION_SIZE_ROCM
num_partitions
=
(
max_seq_len
+
PARTITION_SIZE
-
1
)
//
PARTITION_SIZE
tmp_output
=
torch
.
empty
(
size
=
(
num_seqs
,
num_query_heads
,
num_partitions
,
head_size
),
dtype
=
output
.
dtype
,
...
...
@@ -110,9 +105,7 @@ def main(
start_time
=
time
.
perf_counter
()
# Using default kv_scale
k_scale
=
v_scale
=
torch
.
tensor
(
1.0
,
dtype
=
torch
.
float32
,
device
=
device
)
k_scale
=
v_scale
=
torch
.
tensor
(
1.0
,
dtype
=
torch
.
float32
,
device
=
device
)
for
_
in
range
(
num_iters
):
if
version
==
"v1"
:
...
...
@@ -195,30 +188,29 @@ def main(
print
(
f
"Kernel running time:
{
latency
*
1000000
:.
3
f
}
us"
)
if
__name__
==
'__main__'
:
logger
.
warning
(
"This script benchmarks the paged attention kernel. "
"By default this is no longer used in vLLM inference."
)
if
__name__
==
"__main__"
:
logger
.
warning
(
"This script benchmarks the paged attention kernel. "
"By default this is no longer used in vLLM inference."
)
parser
=
FlexibleArgumentParser
(
description
=
"Benchmark the paged attention kernel."
)
parser
.
add_argument
(
"--version"
,
type
=
str
,
choices
=
[
"v1"
,
"v2"
],
default
=
"v2"
)
parser
=
FlexibleArgumentParser
(
description
=
"Benchmark the paged attention kernel."
)
parser
.
add_argument
(
"--version"
,
type
=
str
,
choices
=
[
"v1"
,
"v2"
],
default
=
"v2"
)
parser
.
add_argument
(
"--batch-size"
,
type
=
int
,
default
=
8
)
parser
.
add_argument
(
"--seq-len"
,
type
=
int
,
default
=
4096
)
parser
.
add_argument
(
"--num-query-heads"
,
type
=
int
,
default
=
64
)
parser
.
add_argument
(
"--num-kv-heads"
,
type
=
int
,
default
=
8
)
parser
.
add_argument
(
"--head-size"
,
type
=
int
,
choices
=
[
64
,
80
,
96
,
112
,
120
,
128
,
192
,
256
],
default
=
128
)
parser
.
add_argument
(
"--head-size"
,
type
=
int
,
choices
=
[
64
,
80
,
96
,
112
,
120
,
128
,
192
,
256
],
default
=
128
,
)
parser
.
add_argument
(
"--block-size"
,
type
=
int
,
choices
=
[
16
,
32
],
default
=
16
)
parser
.
add_argument
(
"--use-alibi"
,
action
=
"store_true"
)
parser
.
add_argument
(
"--dtype"
,
type
=
str
,
choices
=
[
"half"
,
"bfloat16"
,
"float"
],
default
=
"half"
)
parser
.
add_argument
(
"--dtype"
,
type
=
str
,
choices
=
[
"half"
,
"bfloat16"
,
"float"
],
default
=
"half"
)
parser
.
add_argument
(
"--seed"
,
type
=
int
,
default
=
0
)
parser
.
add_argument
(
"--profile"
,
action
=
"store_true"
)
parser
.
add_argument
(
...
...
@@ -228,10 +220,11 @@ if __name__ == '__main__':
default
=
"auto"
,
help
=
"Data type for kv cache storage. If 'auto', will use model "
"data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. "
"ROCm (AMD GPU) supports fp8 (=fp8_e4m3)"
)
parser
.
add_argument
(
"--custom-paged-attn"
,
action
=
"store_true"
,
help
=
"Use custom paged attention"
)
"ROCm (AMD GPU) supports fp8 (=fp8_e4m3)"
,
)
parser
.
add_argument
(
"--custom-paged-attn"
,
action
=
"store_true"
,
help
=
"Use custom paged attention"
)
args
=
parser
.
parse_args
()
print
(
args
)
...
...
benchmarks/kernels/benchmark_quant.py
View file @
7a985548
...
...
@@ -10,15 +10,17 @@ from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser
@
torch
.
inference_mode
()
def
main
(
num_tokens
:
int
,
hidden_size
:
int
,
static_scale
:
bool
,
quant_dtype
:
torch
.
dtype
,
dtype
:
torch
.
dtype
,
seed
:
int
=
0
,
do_profile
:
bool
=
False
,
num_warmup_iters
:
int
=
5
,
num_iters
:
int
=
100
)
->
None
:
def
main
(
num_tokens
:
int
,
hidden_size
:
int
,
static_scale
:
bool
,
quant_dtype
:
torch
.
dtype
,
dtype
:
torch
.
dtype
,
seed
:
int
=
0
,
do_profile
:
bool
=
False
,
num_warmup_iters
:
int
=
5
,
num_iters
:
int
=
100
,
)
->
None
:
current_platform
.
seed_everything
(
seed
)
torch
.
set_default_device
(
"cuda"
)
...
...
@@ -56,7 +58,7 @@ def main(num_tokens: int,
print
(
f
"Kernel running time:
{
latency
*
1000000
:.
3
f
}
us"
)
if
__name__
==
'
__main__
'
:
if
__name__
==
"
__main__
"
:
def
to_torch_dtype
(
dt
):
if
dt
==
"int8"
:
...
...
@@ -66,37 +68,40 @@ if __name__ == '__main__':
raise
ValueError
(
f
"Unsupported dtype:
{
dt
}
"
)
parser
=
FlexibleArgumentParser
(
description
=
"Benchmark the quantization (fp8 or int8) kernel."
)
description
=
"Benchmark the quantization (fp8 or int8) kernel."
)
parser
.
add_argument
(
"--num-tokens"
,
type
=
int
,
default
=
4096
)
parser
.
add_argument
(
"--hidden-size"
,
type
=
int
,
default
=
8192
)
parser
.
add_argument
(
"--static-scale"
,
action
=
"store_true"
)
parser
.
add_argument
(
"--quant-dtype"
,
type
=
str
,
choices
=
[
"fp8"
,
"int8"
],
default
=
"int8"
)
parser
.
add_argument
(
"--dtype"
,
type
=
str
,
choices
=
[
"half"
,
"bfloat16"
,
"float"
],
default
=
"half"
)
parser
.
add_argument
(
"--quant-dtype"
,
type
=
str
,
choices
=
[
"fp8"
,
"int8"
],
default
=
"int8"
)
parser
.
add_argument
(
"--dtype"
,
type
=
str
,
choices
=
[
"half"
,
"bfloat16"
,
"float"
],
default
=
"half"
)
parser
.
add_argument
(
"--seed"
,
type
=
int
,
default
=
0
)
parser
.
add_argument
(
"--profile"
,
action
=
"store_true"
)
parser
.
add_argument
(
"--num-warmup-iters"
,
type
=
int
,
default
=
5
)
parser
.
add_argument
(
"--num-iters"
,
type
=
int
,
default
=
100
,
help
=
"Number of benchmark iterations. "
"If --profile is set, this number is ignored"
)
parser
.
add_argument
(
"--num-iters"
,
type
=
int
,
default
=
100
,
help
=
"Number of benchmark iterations. "
"If --profile is set, this number is ignored"
,
)
args
=
parser
.
parse_args
()
print
(
args
)
main
(
num_tokens
=
args
.
num_tokens
,
hidden_size
=
args
.
hidden_size
,
static_scale
=
args
.
static_scale
,
quant_dtype
=
to_torch_dtype
(
args
.
quant_dtype
),
dtype
=
STR_DTYPE_TO_TORCH_DTYPE
[
args
.
dtype
],
seed
=
args
.
seed
,
do_profile
=
args
.
profile
,
num_warmup_iters
=
args
.
num_warmup_iters
,
num_iters
=
args
.
num_iters
)
main
(
num_tokens
=
args
.
num_tokens
,
hidden_size
=
args
.
hidden_size
,
static_scale
=
args
.
static_scale
,
quant_dtype
=
to_torch_dtype
(
args
.
quant_dtype
),
dtype
=
STR_DTYPE_TO_TORCH_DTYPE
[
args
.
dtype
],
seed
=
args
.
seed
,
do_profile
=
args
.
profile
,
num_warmup_iters
=
args
.
num_warmup_iters
,
num_iters
=
args
.
num_iters
,
)
benchmarks/kernels/benchmark_rmsnorm.py
View file @
7a985548
...
...
@@ -4,15 +4,14 @@ import itertools
from
typing
import
Optional
,
Union
import
torch
import
triton
from
flashinfer.norm
import
fused_add_rmsnorm
,
rmsnorm
from
torch
import
nn
from
vllm
import
_custom_ops
as
vllm_ops
from
vllm.triton_utils
import
triton
class
HuggingFaceRMSNorm
(
nn
.
Module
):
def
__init__
(
self
,
hidden_size
:
int
,
eps
:
float
=
1e-6
)
->
None
:
super
().
__init__
()
self
.
weight
=
nn
.
Parameter
(
torch
.
ones
(
hidden_size
))
...
...
@@ -114,23 +113,19 @@ def rmsnorm_vllm(
def
calculate_diff
(
batch_size
,
seq_len
,
hidden_size
,
use_residual
=
True
):
dtype
=
torch
.
bfloat16
x
=
torch
.
randn
(
batch_size
,
seq_len
,
hidden_size
,
dtype
=
dtype
,
device
=
"cuda"
)
x
=
torch
.
randn
(
batch_size
,
seq_len
,
hidden_size
,
dtype
=
dtype
,
device
=
"cuda"
)
weight
=
torch
.
ones
(
hidden_size
,
dtype
=
dtype
,
device
=
"cuda"
)
residual
=
torch
.
randn_like
(
x
)
if
use_residual
else
None
output_naive
=
rmsnorm_naive
(
x
.
clone
(),
weight
,
residual
.
clone
()
if
residual
is
not
None
else
None
)
x
.
clone
(),
weight
,
residual
.
clone
()
if
residual
is
not
None
else
None
)
output_flashinfer
=
rmsnorm_flashinfer
(
x
.
clone
(),
weight
,
residual
.
clone
()
if
residual
is
not
None
else
None
)
x
.
clone
(),
weight
,
residual
.
clone
()
if
residual
is
not
None
else
None
)
output_vllm
=
rmsnorm_vllm
(
x
.
clone
(),
weight
,
residual
.
clone
()
if
residual
is
not
None
else
None
)
x
.
clone
(),
weight
,
residual
.
clone
()
if
residual
is
not
None
else
None
)
if
use_residual
:
output_naive
=
output_naive
[
0
]
...
...
@@ -141,9 +136,9 @@ def calculate_diff(batch_size, seq_len, hidden_size, use_residual=True):
print
(
f
"FlashInfer output=
{
output_flashinfer
}
"
)
print
(
f
"vLLM output=
{
output_vllm
}
"
)
if
torch
.
allclose
(
output_naive
,
output_flashinfer
,
atol
=
1e-2
,
rtol
=
1e-2
)
and
torch
.
allclose
(
output_naive
,
output_vllm
,
atol
=
1e-2
,
rtol
=
1e-2
):
if
torch
.
allclose
(
output_naive
,
output_flashinfer
,
atol
=
1e-2
,
rtol
=
1e-2
)
and
torch
.
allclose
(
output_naive
,
output_vllm
,
atol
=
1e-2
,
rtol
=
1e-2
):
print
(
"✅ All implementations match"
)
else
:
print
(
"❌ Implementations differ"
)
...
...
@@ -152,12 +147,10 @@ def calculate_diff(batch_size, seq_len, hidden_size, use_residual=True):
batch_size_range
=
[
2
**
i
for
i
in
range
(
0
,
7
,
2
)]
seq_length_range
=
[
2
**
i
for
i
in
range
(
6
,
11
,
1
)]
head_num_range
=
[
32
,
48
]
configs
=
list
(
itertools
.
product
(
head_num_range
,
batch_size_range
,
seq_length_range
))
configs
=
list
(
itertools
.
product
(
head_num_range
,
batch_size_range
,
seq_length_range
))
def
get_benchmark
(
use_residual
):
@
triton
.
testing
.
perf_report
(
triton
.
testing
.
Benchmark
(
x_names
=
[
"head_num"
,
"batch_size"
,
"seq_len"
],
...
...
@@ -167,19 +160,15 @@ def get_benchmark(use_residual):
line_names
=
[
"HuggingFace"
,
"FlashInfer"
,
"vLLM"
],
styles
=
[(
"blue"
,
"-"
),
(
"green"
,
"-"
),
(
"red"
,
"-"
)],
ylabel
=
"us"
,
plot_name
=
f
"rmsnorm-perf-
{
'with'
if
use_residual
else
'without'
}
-residual"
,
plot_name
=
f
"rmsnorm-perf-
{
'with'
if
use_residual
else
'without'
}
-residual"
,
args
=
{},
))
)
)
def
benchmark
(
head_num
,
batch_size
,
seq_len
,
provider
):
dtype
=
torch
.
bfloat16
hidden_size
=
head_num
*
128
# assuming head_dim = 128
x
=
torch
.
randn
(
batch_size
,
seq_len
,
hidden_size
,
dtype
=
dtype
,
device
=
"cuda"
)
x
=
torch
.
randn
(
batch_size
,
seq_len
,
hidden_size
,
dtype
=
dtype
,
device
=
"cuda"
)
weight
=
torch
.
ones
(
hidden_size
,
dtype
=
dtype
,
device
=
"cuda"
)
residual
=
torch
.
randn_like
(
x
)
if
use_residual
else
None
...
...
@@ -240,9 +229,9 @@ if __name__ == "__main__":
default
=
4096
,
help
=
"Hidden size (2nd dimension) of the sequence"
,
)
parser
.
add_argument
(
"--use-residual"
,
action
=
"store_true"
,
help
=
"Whether to use residual connection"
)
parser
.
add_argument
(
"--use-residual"
,
action
=
"store_true"
,
help
=
"Whether to use residual connection"
)
parser
.
add_argument
(
"--save-path"
,
type
=
str
,
...
...
@@ -253,10 +242,12 @@ if __name__ == "__main__":
args
=
parser
.
parse_args
()
# Run correctness test
calculate_diff
(
batch_size
=
args
.
batch_size
,
seq_len
=
args
.
seq_len
,
hidden_size
=
args
.
hidden_size
,
use_residual
=
args
.
use_residual
)
calculate_diff
(
batch_size
=
args
.
batch_size
,
seq_len
=
args
.
seq_len
,
hidden_size
=
args
.
hidden_size
,
use_residual
=
args
.
use_residual
,
)
# Get the benchmark function with proper use_residual setting
benchmark
=
get_benchmark
(
args
.
use_residual
)
...
...
benchmarks/kernels/benchmark_rope.py
View file @
7a985548
...
...
@@ -6,8 +6,7 @@ from typing import Optional
import
nvtx
import
torch
from
vllm.model_executor.layers.rotary_embedding
import
(
RotaryEmbedding
,
get_rope
)
from
vllm.model_executor.layers.rotary_embedding
import
RotaryEmbedding
,
get_rope
from
vllm.platforms
import
current_platform
from
vllm.utils
import
FlexibleArgumentParser
...
...
@@ -32,40 +31,49 @@ def benchmark_rope_kernels_multi_lora(
# silulating serving 4 LoRAs
scaling_factors
=
[
1
,
2
,
4
,
8
]
# batched RoPE can take multiple scaling factors
batched_rope
=
get_rope
(
head_size
,
rotary_dim
,
max_position
,
base
,
is_neox_style
,
{
"rope_type"
:
"linear"
,
"factor"
:
tuple
(
scaling_factors
)
})
batched_rope
=
get_rope
(
head_size
,
rotary_dim
,
max_position
,
base
,
is_neox_style
,
{
"rope_type"
:
"linear"
,
"factor"
:
tuple
(
scaling_factors
)},
)
# non-batched RoPE takes only one scaling factor, we create multiple
# instances to simulate the same behavior
non_batched_ropes
:
list
[
RotaryEmbedding
]
=
[]
for
scaling_factor
in
scaling_factors
:
non_batched_ropes
.
append
(
get_rope
(
head_size
,
rotary_dim
,
max_position
,
base
,
is_neox_style
,
{
"rope_type"
:
"linear"
,
"factor"
:
(
scaling_factor
,
)
}))
get_rope
(
head_size
,
rotary_dim
,
max_position
,
base
,
is_neox_style
,
{
"rope_type"
:
"linear"
,
"factor"
:
(
scaling_factor
,)},
)
)
positions
=
torch
.
randint
(
0
,
max_position
,
(
batch_size
,
seq_len
))
query
=
torch
.
randn
(
batch_size
,
seq_len
,
num_heads
*
head_size
,
dtype
=
dtype
)
query
=
torch
.
randn
(
batch_size
,
seq_len
,
num_heads
*
head_size
,
dtype
=
dtype
)
key
=
torch
.
randn_like
(
query
)
# create query offsets for batched RoPE, we concat multiple kv cache
# together and each query needs to find the right kv cache of its type
offset_map
=
torch
.
tensor
(
list
(
accumulate
([
0
]
+
[
max_position
*
scaling_factor
*
2
for
scaling_factor
in
scaling_factors
[:
-
1
]
])))
query_types
=
torch
.
randint
(
0
,
len
(
scaling_factors
),
(
batch_size
,
seq_len
),
device
=
device
)
accumulate
(
[
0
]
+
[
max_position
*
scaling_factor
*
2
for
scaling_factor
in
scaling_factors
[:
-
1
]
]
)
)
)
query_types
=
torch
.
randint
(
0
,
len
(
scaling_factors
),
(
batch_size
,
seq_len
),
device
=
device
)
# map query types to offsets
query_offsets
=
offset_map
[
query_types
]
# the kernel takes flattened offsets
...
...
@@ -86,27 +94,28 @@ def benchmark_rope_kernels_multi_lora(
torch
.
cuda
.
synchronize
()
if
__name__
==
'
__main__
'
:
if
__name__
==
"
__main__
"
:
parser
=
FlexibleArgumentParser
(
description
=
"Benchmark the rotary embedding kernels."
)
description
=
"Benchmark the rotary embedding kernels."
)
parser
.
add_argument
(
"--is-neox-style"
,
type
=
bool
,
default
=
True
)
parser
.
add_argument
(
"--batch-size"
,
type
=
int
,
default
=
16
)
parser
.
add_argument
(
"--seq-len"
,
type
=
int
,
default
=
512
)
parser
.
add_argument
(
"--num-heads"
,
type
=
int
,
default
=
8
)
parser
.
add_argument
(
"--head-size"
,
type
=
int
,
choices
=
[
64
,
80
,
96
,
112
,
120
,
128
,
192
,
256
],
default
=
128
)
parser
.
add_argument
(
"--head-size"
,
type
=
int
,
choices
=
[
64
,
80
,
96
,
112
,
120
,
128
,
192
,
256
],
default
=
128
,
)
parser
.
add_argument
(
"--rotary-dim"
,
type
=
int
,
choices
=
[
16
,
32
],
default
=
32
)
parser
.
add_argument
(
"--dtype"
,
type
=
str
,
choices
=
[
"bfloat16"
,
"float"
],
default
=
"float"
)
parser
.
add_argument
(
"--dtype"
,
type
=
str
,
choices
=
[
"bfloat16"
,
"float"
],
default
=
"float"
)
parser
.
add_argument
(
"--seed"
,
type
=
int
,
default
=
0
)
parser
.
add_argument
(
"--device"
,
type
=
str
,
choices
=
[
"cuda:0"
,
"cuda:1"
],
default
=
"cuda:0"
)
parser
.
add_argument
(
"--device"
,
type
=
str
,
choices
=
[
"cuda:0"
,
"cuda:1"
],
default
=
"cuda:0"
)
args
=
parser
.
parse_args
()
print
(
args
)
...
...
benchmarks/kernels/benchmark_w8a8_block_fp8.py
View file @
7a985548
...
...
@@ -14,14 +14,16 @@ import tqdm
import
triton
from
vllm.model_executor.layers.quantization.utils.fp8_utils
import
(
_w8a8_block_fp8_matmul
)
_w8a8_block_fp8_matmul
,
)
from
vllm.platforms
import
current_platform
from
vllm.utils
import
FlexibleArgumentParser
mp
.
set_start_method
(
"spawn"
,
force
=
True
)
assert
current_platform
.
is_cuda
(
),
"Only support tune w8a8 block fp8 kernel on CUDA device."
assert
current_platform
.
is_cuda
(),
(
"Only support tune w8a8 block fp8 kernel on CUDA device."
)
DTYPE_MAP
=
{
"float32"
:
torch
.
float32
,
...
...
@@ -40,7 +42,7 @@ def w8a8_block_matmul(
config
:
dict
[
str
,
Any
],
output_dtype
:
torch
.
dtype
=
torch
.
float16
,
)
->
torch
.
Tensor
:
"""This function performs matrix multiplication with
"""This function performs matrix multiplication with
block-wise quantization.
It takes two input tensors `A` and `B` with scales `As` and `Bs`.
...
...
@@ -51,7 +53,7 @@ def w8a8_block_matmul(
B: The input tensor, e.g., weight.
As: The per-token-group quantization scale for `A`.
Bs: The per-block quantization scale for `B`.
block_size: The block size for per-block quantization.
block_size: The block size for per-block quantization.
It should be 2-dim, e.g., [128, 128].
output_dytpe: The dtype of the returned tensor.
...
...
@@ -71,18 +73,18 @@ def w8a8_block_matmul(
assert
triton
.
cdiv
(
N
,
block_n
)
==
Bs
.
shape
[
0
]
assert
triton
.
cdiv
(
K
,
block_k
)
==
Bs
.
shape
[
1
]
C_shape
=
A
.
shape
[:
-
1
]
+
(
N
,
)
C_shape
=
A
.
shape
[:
-
1
]
+
(
N
,)
C
=
A
.
new_empty
(
C_shape
,
dtype
=
output_dtype
)
def
grid
(
META
):
return
(
triton
.
cdiv
(
M
,
META
[
"BLOCK_SIZE_M"
])
*
triton
.
cdiv
(
N
,
META
[
"BLOCK_SIZE_N"
]),
)
return
(
triton
.
cdiv
(
M
,
META
[
"BLOCK_SIZE_M"
])
*
triton
.
cdiv
(
N
,
META
[
"BLOCK_SIZE_N"
]),
)
if
A
.
dtype
==
torch
.
float8_e4m3fn
:
kernel
=
_w8a8_block_fp8_matmul
else
:
raise
RuntimeError
(
"Currently, only support tune w8a8 block fp8 kernel."
)
raise
RuntimeError
(
"Currently, only support tune w8a8 block fp8 kernel."
)
kernel
[
grid
](
A
,
...
...
@@ -119,14 +121,16 @@ def get_configs_compute_bound():
for
block_n
in
[
32
,
64
,
128
,
256
]:
for
num_warps
in
[
4
,
8
]:
for
group_size
in
[
1
,
16
,
32
,
64
]:
configs
.
append
({
"BLOCK_SIZE_M"
:
block_m
,
"BLOCK_SIZE_N"
:
block_n
,
"BLOCK_SIZE_K"
:
block_k
,
"GROUP_SIZE_M"
:
group_size
,
"num_warps"
:
num_warps
,
"num_stages"
:
num_stages
,
})
configs
.
append
(
{
"BLOCK_SIZE_M"
:
block_m
,
"BLOCK_SIZE_N"
:
block_n
,
"BLOCK_SIZE_K"
:
block_k
,
"GROUP_SIZE_M"
:
group_size
,
"num_warps"
:
num_warps
,
"num_stages"
:
num_stages
,
}
)
return
configs
...
...
@@ -165,15 +169,9 @@ def get_weight_shapes(tp_size):
return
weight_shapes
def
benchmark_config
(
A
,
B
,
As
,
Bs
,
block_size
,
config
,
out_dtype
=
torch
.
float16
,
num_iters
=
10
):
def
benchmark_config
(
A
,
B
,
As
,
Bs
,
block_size
,
config
,
out_dtype
=
torch
.
float16
,
num_iters
=
10
):
def
run
():
w8a8_block_matmul
(
A
,
B
,
As
,
Bs
,
block_size
,
config
,
out_dtype
)
...
...
@@ -206,26 +204,26 @@ def tune(M, N, K, block_size, out_dtype, search_space, input_type):
fp8_max
,
fp8_min
=
fp8_info
.
max
,
fp8_info
.
min
A_fp32
=
(
(
torch
.
rand
(
M
,
K
,
dtype
=
torch
.
float32
,
device
=
"cuda"
)
-
0.5
)
*
2
*
fp8_max
)
(
torch
.
rand
(
M
,
K
,
dtype
=
torch
.
float32
,
device
=
"cuda"
)
-
0.5
)
*
2
*
fp8_max
)
A
=
A_fp32
.
clamp
(
min
=
fp8_min
,
max
=
fp8_max
).
to
(
torch
.
float8_e4m3fn
)
B_fp32
=
(
(
torch
.
rand
(
N
,
K
,
dtype
=
torch
.
float32
,
device
=
"cuda"
)
-
0.5
)
*
2
*
fp8_max
)
(
torch
.
rand
(
N
,
K
,
dtype
=
torch
.
float32
,
device
=
"cuda"
)
-
0.5
)
*
2
*
fp8_max
)
B
=
B_fp32
.
clamp
(
min
=
fp8_min
,
max
=
fp8_max
).
to
(
torch
.
float8_e4m3fn
)
else
:
raise
RuntimeError
(
"Currently, only support tune w8a8 block fp8 kernel."
)
raise
RuntimeError
(
"Currently, only support tune w8a8 block fp8 kernel."
)
block_n
,
block_k
=
block_size
[
0
],
block_size
[
1
]
n_tiles
=
(
N
+
block_n
-
1
)
//
block_n
k_tiles
=
(
K
+
block_k
-
1
)
//
block_k
As
=
torch
.
rand
(
M
,
k_tiles
,
dtype
=
torch
.
float32
,
device
=
"cuda"
)
*
factor_for_scale
Bs
=
(
torch
.
rand
(
n_tiles
,
k_tiles
,
dtype
=
torch
.
float32
,
device
=
"cuda"
)
*
factor_for_scale
)
As
=
torch
.
rand
(
M
,
k_tiles
,
dtype
=
torch
.
float32
,
device
=
"cuda"
)
*
factor_for_scale
Bs
=
(
torch
.
rand
(
n_tiles
,
k_tiles
,
dtype
=
torch
.
float32
,
device
=
"cuda"
)
*
factor_for_scale
)
best_config
=
None
best_time
=
float
(
"inf"
)
...
...
@@ -267,7 +265,8 @@ def save_configs(
device_name
=
current_platform
.
get_device_name
().
replace
(
" "
,
"_"
)
json_file_name
=
(
f
"N=
{
N
}
,K=
{
K
}
,device_name=
{
device_name
}
,dtype=
{
input_type
}
_w8a8,"
f
"block_shape=[
{
block_n
}
,
{
block_k
}
].json"
)
f
"block_shape=[
{
block_n
}
,
{
block_k
}
].json"
)
config_file_path
=
os
.
path
.
join
(
save_path
,
json_file_name
)
print
(
f
"Writing best config to
{
config_file_path
}
..."
)
...
...
@@ -295,8 +294,7 @@ def tune_on_gpu(args_dict):
search_space
=
get_configs_compute_bound
()
search_space
=
[
config
for
config
in
search_space
if
block_k
%
config
[
"BLOCK_SIZE_K"
]
==
0
config
for
config
in
search_space
if
block_k
%
config
[
"BLOCK_SIZE_K"
]
==
0
]
start
=
time
.
time
()
...
...
@@ -312,15 +310,11 @@ def tune_on_gpu(args_dict):
out_dtype
,
search_space
,
input_type
,
)
for
batch_size
in
tqdm
(
batch_sizes
,
desc
=
f
"GPU
{
gpu_id
}
- Batch sizes"
)
)
for
batch_size
in
tqdm
(
batch_sizes
,
desc
=
f
"GPU
{
gpu_id
}
- Batch sizes"
)
]
best_configs
=
{
M
:
config
for
M
,
config
in
zip
(
batch_sizes
,
benchmark_results
)
}
save_configs
(
N
,
K
,
block_n
,
block_k
,
best_configs
,
save_path
,
input_type
)
best_configs
=
{
M
:
config
for
M
,
config
in
zip
(
batch_sizes
,
benchmark_results
)}
save_configs
(
N
,
K
,
block_n
,
block_k
,
best_configs
,
save_path
,
input_type
)
end
=
time
.
time
()
print
(
f
"Tuning on GPU
{
gpu_id
}
took
{
end
-
start
:.
2
f
}
seconds"
)
...
...
@@ -376,13 +370,14 @@ def main(args):
process_args
=
[]
for
gpu_id
in
range
(
num_gpus
):
process_args
.
append
({
"gpu_id"
:
gpu_id
,
"batch_sizes"
:
batches_per_gpu
[
gpu_id
],
"weight_shapes"
:
weight_shapes
,
# Each GPU processes all weight shapes
"args"
:
args
,
})
process_args
.
append
(
{
"gpu_id"
:
gpu_id
,
"batch_sizes"
:
batches_per_gpu
[
gpu_id
],
"weight_shapes"
:
weight_shapes
,
# Each GPU processes all weight shapes
"args"
:
args
,
}
)
ctx
=
mp
.
get_context
(
"spawn"
)
with
ctx
.
Pool
(
num_gpus
)
as
pool
:
...
...
@@ -398,13 +393,11 @@ Tune triton w8a8 block fp8 for DeepSeek-V3/DeepSeek-R1:
python3 benchmark_w8a8_block_fp8.py --tp-size 8 --input-type fp8
Then copy to model_executor/layers/quantization/utils/configs
"""
,
formatter_class
=
argparse
.
RawTextHelpFormatter
)
formatter_class
=
argparse
.
RawTextHelpFormatter
,
)
parser
.
add_argument
(
"--tp-size"
,
"-tp"
,
type
=
int
,
default
=
8
)
parser
.
add_argument
(
"--input-type"
,
type
=
str
,
choices
=
[
"fp8"
],
default
=
"fp8"
)
parser
.
add_argument
(
"--input-type"
,
type
=
str
,
choices
=
[
"fp8"
],
default
=
"fp8"
)
parser
.
add_argument
(
"--out-dtype"
,
type
=
str
,
...
...
benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py
View file @
7a985548
...
...
@@ -6,13 +6,15 @@ import time
# Import DeepGEMM functions
import
deep_gemm
import
torch
import
triton
from
deep_gemm
import
calc_diff
,
ceil_div
,
get_col_major_tma_aligned_tensor
# Import vLLM functions
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.quantization.utils.fp8_utils
import
(
per_token_group_quant_fp8
,
w8a8_block_fp8_matmul
)
per_token_group_quant_fp8
,
w8a8_block_fp8_matmul
,
)
from
vllm.triton_utils
import
triton
# Copied from
...
...
benchmarks/kernels/graph_machete_bench.py
View file @
7a985548
...
...
@@ -14,13 +14,14 @@ from vllm.utils import FlexibleArgumentParser
if
__name__
==
"__main__"
:
parser
=
FlexibleArgumentParser
(
description
=
'Benchmark the latency of processing a single batch of '
'requests till completion.'
)
parser
.
add_argument
(
'filename'
,
type
=
str
)
description
=
"Benchmark the latency of processing a single batch of "
"requests till completion."
)
parser
.
add_argument
(
"filename"
,
type
=
str
)
args
=
parser
.
parse_args
()
with
open
(
args
.
filename
,
'
rb
'
)
as
f
:
with
open
(
args
.
filename
,
"
rb
"
)
as
f
:
data
=
pickle
.
load
(
f
)
raw_results
:
list
[
TMeasurement
]
=
data
[
"results"
]
...
...
@@ -38,11 +39,7 @@ if __name__ == "__main__":
raise
Exception
(
"MKN not found"
)
kernel
=
v
.
task_spec
.
description
results
[
KN
].
append
({
"kernel"
:
kernel
,
"batch_size"
:
M
,
"median"
:
v
.
median
})
results
[
KN
].
append
({
"kernel"
:
kernel
,
"batch_size"
:
M
,
"median"
:
v
.
median
})
rows
=
int
(
math
.
ceil
(
len
(
results
)
/
2
))
fig
,
axs
=
plt
.
subplots
(
rows
,
2
,
figsize
=
(
12
,
5
*
rows
))
...
...
@@ -50,14 +47,16 @@ if __name__ == "__main__":
for
axs_idx
,
(
shape
,
data
)
in
enumerate
(
results
.
items
()):
plt
.
sca
(
axs
[
axs_idx
])
df
=
pd
.
DataFrame
(
data
)
sns
.
lineplot
(
data
=
df
,
x
=
"batch_size"
,
y
=
"median"
,
hue
=
"kernel"
,
style
=
"kernel"
,
markers
=
True
,
dashes
=
False
,
palette
=
"Dark2"
)
sns
.
lineplot
(
data
=
df
,
x
=
"batch_size"
,
y
=
"median"
,
hue
=
"kernel"
,
style
=
"kernel"
,
markers
=
True
,
dashes
=
False
,
palette
=
"Dark2"
,
)
plt
.
title
(
f
"Shape:
{
shape
}
"
)
plt
.
ylabel
(
"time (median, s)"
)
plt
.
tight_layout
()
...
...
benchmarks/kernels/utils.py
View file @
7a985548
...
...
@@ -23,6 +23,7 @@ class ArgPool:
For every invocation during a benchmarking run, it will choose a
different value from the list.
"""
values
:
Iterable
[
Any
]
def
__getitem__
(
self
,
index
):
...
...
@@ -30,9 +31,7 @@ class ArgPool:
class
Bench
:
class
ArgsIterator
:
def
__init__
(
self
,
args_list
,
kwargs_list
):
assert
len
(
args_list
)
==
len
(
kwargs_list
)
self
.
args_list
=
args_list
...
...
@@ -53,10 +52,16 @@ class Bench:
def
n_args
(
self
):
return
self
.
n
def
__init__
(
self
,
cuda_graph_params
:
Optional
[
CudaGraphBenchParams
],
label
:
str
,
sub_label
:
str
,
description
:
str
,
fn
:
Callable
,
*
args
,
**
kwargs
):
def
__init__
(
self
,
cuda_graph_params
:
Optional
[
CudaGraphBenchParams
],
label
:
str
,
sub_label
:
str
,
description
:
str
,
fn
:
Callable
,
*
args
,
**
kwargs
,
):
self
.
cuda_graph_params
=
cuda_graph_params
self
.
use_cuda_graph
=
self
.
cuda_graph_params
is
not
None
self
.
label
=
label
...
...
@@ -67,10 +72,8 @@ class Bench:
# Process args
self
.
_args
=
args
self
.
_kwargs
=
kwargs
self
.
args_list
,
self
.
kwargs_list
=
self
.
collapse_argpool
(
*
args
,
**
kwargs
)
self
.
args_iterator
=
self
.
ArgsIterator
(
self
.
args_list
,
self
.
kwargs_list
)
self
.
args_list
,
self
.
kwargs_list
=
self
.
collapse_argpool
(
*
args
,
**
kwargs
)
self
.
args_iterator
=
self
.
ArgsIterator
(
self
.
args_list
,
self
.
kwargs_list
)
# Cudagraph runner
self
.
g
=
None
...
...
@@ -100,16 +103,13 @@ class Bench:
for
i
in
range
(
argpool_size
):
# collapse args; Just pick the ith value
args_list
[
i
]
=
tuple
([
arg
[
i
]
if
isinstance
(
arg
,
ArgPool
)
else
arg
for
arg
in
args_list
[
i
]
])
args_list
[
i
]
=
tuple
(
[
arg
[
i
]
if
isinstance
(
arg
,
ArgPool
)
else
arg
for
arg
in
args_list
[
i
]]
)
# collapse kwargs
kwargs_i
=
kwargs_list
[
i
]
arg_pool_keys
=
[
k
for
k
,
v
in
kwargs_i
.
items
()
if
isinstance
(
v
,
ArgPool
)
]
arg_pool_keys
=
[
k
for
k
,
v
in
kwargs_i
.
items
()
if
isinstance
(
v
,
ArgPool
)]
for
k
in
arg_pool_keys
:
# again just pick the ith value
kwargs_i
[
k
]
=
kwargs_i
[
k
][
i
]
...
...
@@ -142,7 +142,7 @@ class Bench:
def
run_cudagrah
(
self
)
->
TMeasurement
:
assert
self
.
use_cuda_graph
globals
=
{
'g'
:
self
.
g
}
globals
=
{
"g"
:
self
.
g
}
return
TBenchmark
.
Timer
(
stmt
=
"g.replay()"
,
...
...
@@ -162,15 +162,15 @@ class Bench:
has_arg_pool
=
self
.
args_iterator
.
n_args
>
1
if
has_arg_pool
:
setup
=
'''
setup
=
"""
args_iterator.reset()
args_it = args_iterator.__next__()
'''
stmt
=
'''
"""
stmt
=
"""
args, kwargs = next(args_it)
fn(*args, **kwargs)
'''
globals
=
{
'
fn
'
:
self
.
fn
,
'
args_iterator
'
:
self
.
args_iterator
}
"""
globals
=
{
"
fn
"
:
self
.
fn
,
"
args_iterator
"
:
self
.
args_iterator
}
else
:
# no arg pool. Just use the args and kwargs directly
self
.
args_iterator
.
reset
()
...
...
@@ -178,10 +178,10 @@ class Bench:
args
,
kwargs
=
next
(
args_it
)
setup
=
""
stmt
=
'''
stmt
=
"""
fn(*args, **kwargs)
'''
globals
=
{
'
fn
'
:
self
.
fn
,
'
args
'
:
args
,
'
kwargs
'
:
kwargs
}
"""
globals
=
{
"
fn
"
:
self
.
fn
,
"
args
"
:
args
,
"
kwargs
"
:
kwargs
}
return
TBenchmark
.
Timer
(
stmt
=
stmt
,
...
...
benchmarks/overheads/benchmark_hashing.py
View file @
7a985548
...
...
@@ -7,9 +7,8 @@ from vllm import LLM, SamplingParams
from
vllm.utils
import
FlexibleArgumentParser
# A very long prompt, total number of tokens is about 15k.
LONG_PROMPT
=
[
"You are an expert in large language models, aren't you?"
]
*
1000
LONG_PROMPT
=
' '
.
join
(
LONG_PROMPT
)
LONG_PROMPT
=
[
"You are an expert in large language models, aren't you?"
]
*
1000
LONG_PROMPT
=
" "
.
join
(
LONG_PROMPT
)
def
main
(
args
):
...
...
@@ -30,32 +29,35 @@ def main(args):
print
(
"------start generating------"
)
for
i
in
range
(
3
):
profiler
.
runctx
(
'llm.generate(LONG_PROMPT, sampling_params)'
,
globals
(),
locals
())
profiler
.
runctx
(
"llm.generate(LONG_PROMPT, sampling_params)"
,
globals
(),
locals
()
)
# analyze the runtime of hashing function
stats
=
pstats
.
Stats
(
profiler
)
stats
.
sort_stats
(
'
cumulative
'
)
stats
.
sort_stats
(
"
cumulative
"
)
total_time
=
0
total_calls
=
0
for
func
in
stats
.
stats
:
if
'
hash_of_block
'
in
func
[
2
]:
if
"
hash_of_block
"
in
func
[
2
]:
total_time
=
stats
.
stats
[
func
][
3
]
total_calls
=
stats
.
stats
[
func
][
0
]
percentage
=
(
total_time
/
stats
.
total_tt
)
*
100
print
(
f
"Hashing took
{
total_time
:.
2
f
}
seconds,"
f
"
{
percentage
:.
2
f
}
% of the total runtime."
)
print
(
f
"Hashing took
{
total_time
:.
2
f
}
seconds,
{
percentage
:.
2
f
}
% of the total runtime."
)
if
__name__
==
"__main__"
:
parser
=
FlexibleArgumentParser
(
description
=
'Benchmark the performance of hashing function in'
'automatic prefix caching.'
)
parser
.
add_argument
(
'--model'
,
type
=
str
,
default
=
'lmsys/longchat-7b-16k'
)
parser
.
add_argument
(
'--tensor-parallel-size'
,
'-tp'
,
type
=
int
,
default
=
1
)
parser
.
add_argument
(
'--output-len'
,
type
=
int
,
default
=
10
)
parser
.
add_argument
(
'--enable-prefix-caching'
,
action
=
'store_true'
,
help
=
'enable prefix caching'
)
description
=
"Benchmark the performance of hashing function in"
"automatic prefix caching."
)
parser
.
add_argument
(
"--model"
,
type
=
str
,
default
=
"lmsys/longchat-7b-16k"
)
parser
.
add_argument
(
"--tensor-parallel-size"
,
"-tp"
,
type
=
int
,
default
=
1
)
parser
.
add_argument
(
"--output-len"
,
type
=
int
,
default
=
10
)
parser
.
add_argument
(
"--enable-prefix-caching"
,
action
=
"store_true"
,
help
=
"enable prefix caching"
)
args
=
parser
.
parse_args
()
main
(
args
)
benchmarks/pyproject.toml
0 → 100644
View file @
7a985548
# This local pyproject file is part of the migration from yapf to ruff format.
# It uses the same core rules as the main pyproject.toml file, but with the
# following differences:
# - ruff line length is overridden to 88
# - deprecated typing ignores (UP006, UP035) have been removed
[tool.ruff]
line-length
=
88
exclude
=
[
# External file, leaving license intact
"examples/other/fp8/quantizer/quantize.py"
,
"vllm/vllm_flash_attn/flash_attn_interface.pyi"
]
[tool.ruff.lint.per-file-ignores]
"vllm/third_party/**"
=
["ALL"]
"vllm/version.py"
=
["F401"]
"vllm/_version.py"
=
["ALL"]
[tool.ruff.lint]
select
=
[
# pycodestyle
"E"
,
# Pyflakes
"F"
,
# pyupgrade
"UP"
,
# flake8-bugbear
"B"
,
# flake8-simplify
"SIM"
,
# isort
"I"
,
# flake8-logging-format
"G"
,
]
ignore
=
[
# star imports
"F405"
,
"F403"
,
# lambda expression assignment
"E731"
,
# Loop control variable not used within loop body
"B007"
,
# f-string format
"UP032"
,
# Can remove once 3.10+ is the minimum Python version
"UP007"
,
]
[tool.ruff.lint.isort]
known-first-party
=
["vllm"]
[tool.ruff.format]
docstring-code-format
=
true
\ No newline at end of file
benchmarks/run_structured_output_benchmark.sh
View file @
7a985548
#!/bin/bash
# Define the model to use
MODEL
=
${
1
:-
"Qwen/Qwen2.5-7B-Instruct"
}
# Define the backend to use
BACKEND
=
${
2
:-
"vllm"
}
# Define the dataset to use
DATASET
=
${
3
:-
"xgrammar_bench"
}
# Define the guided decoding backend
GUIDED_BACKEND
=
${
4
:-
"xgrammar"
}
# default values
MODEL
=
${
MODEL
:-
"Qwen/Qwen2.5-7B-Instruct"
}
BACKEND
=
${
BACKEND
:-
"vllm"
}
DATASET
=
${
DATASET
:-
"xgrammar_bench"
}
SCRIPT_DIR
=
"
$(
cd
"
$(
dirname
"
${
BASH_SOURCE
[0]
}
"
)
"
&&
pwd
)
"
OUTPUT_DIR
=
${
5
:-
"
$SCRIPT_DIR
/structured_output_benchmark_results"
}
GUIDED_RATIO
=
${
6
:-
0
.5
}
OUTPUT_DIR
=
${
OUTPUT_DIR
:-
"
$SCRIPT_DIR
/structured_output_benchmark_results"
}
PORT
=
${
PORT
:-
8000
}
STRUCTURED_OUTPUT_RATIO
=
${
STRUCTURED_OUTPUT_RATIO
:-
1
}
TOTAL_SECONDS
=
${
TOTAL_SECONDS
:-
90
}
MAX_NEW_TOKENS
=
${
MAX_NEW_TOKENS
:-
300
}
TOKENIZER_MODE
=
${
TOKENIZER_MODE
:-
"auto"
}
usage
()
{
echo
"Usage:
$0
[options]"
echo
"Options:"
echo
" --model MODEL Model to benchmark (default:
$MODEL
)"
echo
" --backend BACKEND Backend to use (default:
$BACKEND
)"
echo
" --dataset DATASET Dataset to use (default:
$DATASET
)"
echo
" --max-new-tokens N Maximum number of tokens to generate (default:
$MAX_NEW_TOKENS
)"
echo
" --output-dir DIR Output directory for results (default:
$OUTPUT_DIR
)"
echo
" --port PORT Port to use (default:
$PORT
)"
echo
" --structured-output-ratio N Ratio of structured outputs (default:
$STRUCTURED_OUTPUT_RATIO
)"
echo
" --tokenizer-mode MODE Tokenizer mode to use (default:
$TOKENIZER_MODE
)"
echo
" --total-seconds N Total seconds to run the benchmark (default:
$TOTAL_SECONDS
)"
echo
" -h, --help Show this help message and exit"
exit
0
}
# parse command line arguments
while
[[
$#
-gt
0
]]
;
do
case
$1
in
--model
)
MODEL
=
"
$2
"
shift
2
;;
--backend
)
BACKEND
=
"
$2
"
shift
2
;;
--dataset
)
DATASET
=
"
$2
"
shift
2
;;
--max-new-tokens
)
MAX_NEW_TOKENS
=
"
$2
"
shift
2
;;
--output-dir
)
OUTPUT_DIR
=
"
$2
"
shift
2
;;
--port
)
PORT
=
"
$2
"
shift
2
;;
--structured-output-ratio
)
STRUCTURED_OUTPUT_RATIO
=
"
$2
"
shift
2
;;
--tokenizer-mode
)
TOKENIZER_MODE
=
"
$2
"
shift
2
;;
--total-seconds
)
TOTAL_SECONDS
=
"
$2
"
shift
2
;;
-h
|
--help
)
usage
;;
*
)
echo
"Unknown argument:
$1
\n
"
usage
;;
esac
done
# Create output directory if it doesn't exist
mkdir
-p
"
$OUTPUT_DIR
"
# Define QPS values to test
QPS_VALUES
=(
70 60 50
25 20 15 10
)
QPS_VALUES
=(
25 20 15 10
5 1
)
# Common parameters
COMMON_PARAMS
=
"--backend
$BACKEND
\
--model
$MODEL
\
--dataset
$DATASET
\
--structured-output-backend
$GUIDED_BACKEND
\
--structured-output-ratio
$GUIDED_RATIO
\
--structured-output-ratio
$STRUCTURED_OUTPUT_RATIO
\
--save-results
\
--result-dir
$OUTPUT_DIR
"
--result-dir
$OUTPUT_DIR
\
--output-len
$MAX_NEW_TOKENS
\
--port
$PORT
\
--tokenizer-mode
$TOKENIZER_MODE
"
echo
"Starting structured output benchmark with model:
$MODEL
"
echo
"Backend:
$BACKEND
"
echo
"Dataset:
$DATASET
"
echo
"Structured output backend:
$GUIDED_BACKEND
"
echo
"Results will be saved to:
$OUTPUT_DIR
"
echo
"----------------------------------------"
...
...
@@ -48,14 +109,17 @@ for qps in "${QPS_VALUES[@]}"; do
GIT_BRANCH
=
$(
git rev-parse
--abbrev-ref
HEAD 2>/dev/null
||
echo
"unknown"
)
# Construct filename for this run
FILENAME
=
"
${
GUIDED_BACKEND
}
_
${
BACKEND
}
_
${
qps
}
qps_
$(
basename
$MODEL
)
_
${
DATASET
}
_
${
GIT_HASH
}
.json"
FILENAME
=
"
${
BACKEND
}
_
${
qps
}
qps_
$(
basename
$MODEL
)
_
${
DATASET
}
_
${
GIT_HASH
}
.json"
NUM_PROMPTS
=
$(
echo
"
$TOTAL_SECONDS
*
$qps
"
| bc
)
NUM_PROMPTS
=
${
NUM_PROMPTS
%.*
}
# Remove fractional part
echo
"Running benchmark with
$NUM_PROMPTS
prompts"
# Run the benchmark
python
"
$SCRIPT_DIR
/benchmark_serving_structured_output.py"
$COMMON_PARAMS
\
--request-rate
$qps
\
--result-filename
"
$FILENAME
"
\
--tokenizer-mode
${
TOKENIZER_MODE
:-
"auto"
}
\
--port
${
PORT
:-
8000
}
--num-prompts
$NUM_PROMPTS
echo
"Completed benchmark with QPS:
$qps
"
echo
"----------------------------------------"
...
...
cmake/cpu_extension.cmake
View file @
7a985548
...
...
@@ -167,6 +167,33 @@ if (AVX512_FOUND AND NOT AVX512_DISABLED)
FetchContent_MakeAvailable
(
oneDNN
)
list
(
APPEND LIBS dnnl
)
elseif
(
POWER10_FOUND
)
FetchContent_Declare
(
oneDNN
GIT_REPOSITORY https://github.com/oneapi-src/oneDNN.git
GIT_TAG v3.7.2
GIT_PROGRESS TRUE
GIT_SHALLOW TRUE
)
set
(
ONEDNN_LIBRARY_TYPE
"STATIC"
)
set
(
ONEDNN_BUILD_DOC
"OFF"
)
set
(
ONEDNN_BUILD_EXAMPLES
"OFF"
)
set
(
ONEDNN_BUILD_TESTS
"OFF"
)
set
(
ONEDNN_ENABLE_WORKLOAD
"INFERENCE"
)
set
(
ONEDNN_ENABLE_PRIMITIVE
"MATMUL;REORDER"
)
set
(
ONEDNN_BUILD_GRAPH
"OFF"
)
set
(
ONEDNN_ENABLE_JIT_PROFILING
"OFF"
)
set
(
ONEDNN_ENABLE_ITT_TASKS
"OFF"
)
set
(
ONEDNN_ENABLE_MAX_CPU_ISA
"OFF"
)
set
(
ONEDNN_ENABLE_CPU_ISA_HINTS
"OFF"
)
set
(
CMAKE_POLICY_DEFAULT_CMP0077 NEW
)
set
(
DNNL_CPU_RUNTIME
"OMP"
)
FetchContent_MakeAvailable
(
oneDNN
)
list
(
APPEND LIBS dnnl
)
endif
()
...
...
@@ -197,6 +224,10 @@ if (AVX512_FOUND AND NOT AVX512_DISABLED)
"csrc/cpu/quant.cpp"
"csrc/cpu/shm.cpp"
${
VLLM_EXT_SRC
}
)
elseif
(
POWER10_FOUND
)
set
(
VLLM_EXT_SRC
"csrc/cpu/quant.cpp"
${
VLLM_EXT_SRC
}
)
endif
()
#
...
...
@@ -214,4 +245,4 @@ define_gpu_extension_target(
WITH_SOABI
)
message
(
STATUS
"Enabling C extension."
)
\ No newline at end of file
message
(
STATUS
"Enabling C extension."
)
cmake/utils.cmake
View file @
7a985548
...
...
@@ -229,11 +229,26 @@ macro(set_gencode_flags_for_srcs)
"
${
multiValueArgs
}
"
${
ARGN
}
)
foreach
(
_ARCH
${
arg_CUDA_ARCHS
}
)
string
(
REPLACE
"."
""
_ARCH
"
${
_ARCH
}
"
)
set_gencode_flag_for_srcs
(
SRCS
${
arg_SRCS
}
ARCH
"compute_
${
_ARCH
}
"
CODE
"sm_
${
_ARCH
}
"
)
# handle +PTX suffix: generate both sm and ptx codes if requested
string
(
FIND
"
${
_ARCH
}
"
"+PTX"
_HAS_PTX
)
if
(
NOT _HAS_PTX EQUAL -1
)
string
(
REPLACE
"+PTX"
""
_BASE_ARCH
"
${
_ARCH
}
"
)
string
(
REPLACE
"."
""
_STRIPPED_ARCH
"
${
_BASE_ARCH
}
"
)
set_gencode_flag_for_srcs
(
SRCS
${
arg_SRCS
}
ARCH
"compute_
${
_STRIPPED_ARCH
}
"
CODE
"sm_
${
_STRIPPED_ARCH
}
"
)
set_gencode_flag_for_srcs
(
SRCS
${
arg_SRCS
}
ARCH
"compute_
${
_STRIPPED_ARCH
}
"
CODE
"compute_
${
_STRIPPED_ARCH
}
"
)
else
()
string
(
REPLACE
"."
""
_STRIPPED_ARCH
"
${
_ARCH
}
"
)
set_gencode_flag_for_srcs
(
SRCS
${
arg_SRCS
}
ARCH
"compute_
${
_STRIPPED_ARCH
}
"
CODE
"sm_
${
_STRIPPED_ARCH
}
"
)
endif
()
endforeach
()
if
(
${
arg_BUILD_PTX_FOR_ARCH
}
)
...
...
@@ -252,7 +267,10 @@ endmacro()
#
# For the given `SRC_CUDA_ARCHS` list of gencode versions in the form
# `<major>.<minor>[letter]` compute the "loose intersection" with the
# `TGT_CUDA_ARCHS` list of gencodes.
# `TGT_CUDA_ARCHS` list of gencodes. We also support the `+PTX` suffix in
# `SRC_CUDA_ARCHS` which indicates that the PTX code should be built when there
# is a CUDA_ARCH in `TGT_CUDA_ARCHS` that is equal to or larger than the
# architecture in `SRC_CUDA_ARCHS`.
# The loose intersection is defined as:
# { max{ x \in tgt | x <= y } | y \in src, { x \in tgt | x <= y } != {} }
# where `<=` is the version comparison operator.
...
...
@@ -269,44 +287,63 @@ endmacro()
# cuda_archs_loose_intersection(OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS)
# OUT_CUDA_ARCHS="8.0;8.6;9.0;9.0a"
#
# Example With PTX:
# SRC_CUDA_ARCHS="8.0+PTX"
# TGT_CUDA_ARCHS="9.0"
# cuda_archs_loose_intersection(OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS)
# OUT_CUDA_ARCHS="8.0+PTX"
#
function
(
cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS
)
list
(
REMOVE_DUPLICATES SRC_CUDA_ARCHS
)
set
(
TGT_CUDA_ARCHS_
${
TGT_CUDA_ARCHS
}
)
set
(
_SRC_CUDA_ARCHS
"
${
SRC_CUDA_ARCHS
}
"
)
set
(
_TGT_CUDA_ARCHS
${
TGT_CUDA_ARCHS
}
)
# handle +PTX suffix: separate base arch for matching, record PTX requests
set
(
_PTX_ARCHS
)
foreach
(
_arch
${
_SRC_CUDA_ARCHS
}
)
if
(
_arch MATCHES
"
\\
+PTX$"
)
string
(
REPLACE
"+PTX"
""
_base
"
${
_arch
}
"
)
list
(
APPEND _PTX_ARCHS
"
${
_base
}
"
)
list
(
REMOVE_ITEM _SRC_CUDA_ARCHS
"
${
_arch
}
"
)
list
(
APPEND _SRC_CUDA_ARCHS
"
${
_base
}
"
)
endif
()
endforeach
()
list
(
REMOVE_DUPLICATES _PTX_ARCHS
)
list
(
REMOVE_DUPLICATES _SRC_CUDA_ARCHS
)
# if x.0a is in SRC_CUDA_ARCHS and x.0 is in CUDA_ARCHS then we should
# remove x.0a from SRC_CUDA_ARCHS and add x.0a to _CUDA_ARCHS
set
(
_CUDA_ARCHS
)
if
(
"9.0a"
IN_LIST SRC_CUDA_ARCHS
)
list
(
REMOVE_ITEM SRC_CUDA_ARCHS
"9.0a"
)
if
(
"9.0"
IN_LIST TGT_CUDA_ARCHS
_
)
list
(
REMOVE_ITEM TGT_CUDA_ARCHS
_
"9.0"
)
if
(
"9.0a"
IN_LIST
_
SRC_CUDA_ARCHS
)
list
(
REMOVE_ITEM
_
SRC_CUDA_ARCHS
"9.0a"
)
if
(
"9.0"
IN_LIST TGT_CUDA_ARCHS
)
list
(
REMOVE_ITEM
_
TGT_CUDA_ARCHS
"9.0"
)
set
(
_CUDA_ARCHS
"9.0a"
)
endif
()
endif
()
if
(
"10.0a"
IN_LIST SRC_CUDA_ARCHS
)
list
(
REMOVE_ITEM SRC_CUDA_ARCHS
"10.0a"
)
if
(
"10.0a"
IN_LIST
_
SRC_CUDA_ARCHS
)
list
(
REMOVE_ITEM
_
SRC_CUDA_ARCHS
"10.0a"
)
if
(
"10.0"
IN_LIST TGT_CUDA_ARCHS
)
list
(
REMOVE_ITEM TGT_CUDA_ARCHS
_
"10.0"
)
list
(
REMOVE_ITEM
_
TGT_CUDA_ARCHS
"10.0"
)
set
(
_CUDA_ARCHS
"10.0a"
)
endif
()
endif
()
list
(
SORT SRC_CUDA_ARCHS COMPARE NATURAL ORDER ASCENDING
)
list
(
SORT
_
SRC_CUDA_ARCHS COMPARE NATURAL ORDER ASCENDING
)
# for each ARCH in TGT_CUDA_ARCHS find the highest arch in SRC_CUDA_ARCHS that
# is less or equal to ARCH (but has the same major version since SASS binary
# compatibility is only forward compatible within the same major version).
foreach
(
_ARCH
${
TGT_CUDA_ARCHS
_
}
)
foreach
(
_ARCH
${
_
TGT_CUDA_ARCHS
}
)
set
(
_TMP_ARCH
)
# Extract the major version of the target arch
string
(
REGEX REPLACE
"^([0-9]+)
\\
..*$"
"
\\
1"
TGT_ARCH_MAJOR
"
${
_ARCH
}
"
)
foreach
(
_SRC_ARCH
${
SRC_CUDA_ARCHS
}
)
foreach
(
_SRC_ARCH
${
_
SRC_CUDA_ARCHS
}
)
# Extract the major version of the source arch
string
(
REGEX REPLACE
"^([0-9]+)
\\
..*$"
"
\\
1"
SRC_ARCH_MAJOR
"
${
_SRC_ARCH
}
"
)
# Check
major-version match AND version-less-or-equal
# Check
version-less-or-equal, and allow PTX arches to match across majors
if
(
_SRC_ARCH VERSION_LESS_EQUAL _ARCH
)
if
(
SRC_ARCH_MAJOR STREQUAL TGT_ARCH_MAJOR
)
if
(
_SRC_ARCH IN_LIST _PTX_ARCHS OR
SRC_ARCH_MAJOR STREQUAL TGT_ARCH_MAJOR
)
set
(
_TMP_ARCH
"
${
_SRC_ARCH
}
"
)
endif
()
else
()
...
...
@@ -322,6 +359,18 @@ function(cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_AR
endforeach
()
list
(
REMOVE_DUPLICATES _CUDA_ARCHS
)
# reapply +PTX suffix to architectures that requested PTX
set
(
_FINAL_ARCHS
)
foreach
(
_arch
${
_CUDA_ARCHS
}
)
if
(
_arch IN_LIST _PTX_ARCHS
)
list
(
APPEND _FINAL_ARCHS
"
${
_arch
}
+PTX"
)
else
()
list
(
APPEND _FINAL_ARCHS
"
${
_arch
}
"
)
endif
()
endforeach
()
set
(
_CUDA_ARCHS
${
_FINAL_ARCHS
}
)
set
(
${
OUT_CUDA_ARCHS
}
${
_CUDA_ARCHS
}
PARENT_SCOPE
)
endfunction
()
...
...
csrc/activation_kernels.cu
View file @
7a985548
...
...
@@ -70,6 +70,9 @@ __device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
int64_t num_tokens = input.numel() / input.size(-1); \
dim3 grid(num_tokens); \
dim3 block(std::min(d, 1024)); \
if (num_tokens == 0) { \
return; \
} \
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
VLLM_DISPATCH_FLOATING_TYPES( \
...
...
csrc/attention/attention_kernels.cuh
View file @
7a985548
...
...
@@ -17,660 +17,660 @@
* limitations under the License.
*/
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <algorithm>
#include "attention_dtypes.h"
#include "attention_utils.cuh"
#ifdef USE_ROCM
#include <hip/hip_bf16.h>
#include "../quantization/fp8/amd/quant_utils.cuh"
typedef
__hip_bfloat16
__nv_bfloat16
;
#else
#include "../quantization/fp8/nvidia/quant_utils.cuh"
#endif
#ifndef USE_ROCM
#define WARP_SIZE 32
#else
#define WARP_SIZE warpSize
#endif
#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
namespace
vllm
{
// Utility function for attention softmax.
template
<
int
NUM_WARPS
>
inline
__device__
float
block_sum
(
float
*
red_smem
,
float
sum
)
{
// Decompose the thread index into warp / lane.
int
warp
=
threadIdx
.
x
/
WARP_SIZE
;
int
lane
=
threadIdx
.
x
%
WARP_SIZE
;
// Compute the sum per warp.
#pragma unroll
for
(
int
mask
=
WARP_SIZE
/
2
;
mask
>=
1
;
mask
/=
2
)
{
sum
+=
VLLM_SHFL_XOR_SYNC
(
sum
,
mask
);
}
// Warp leaders store the data to shared memory.
if
(
lane
==
0
)
{
red_smem
[
warp
]
=
sum
;
}
// Make sure the data is in shared memory.
__syncthreads
();
// The warps compute the final sums.
if
(
lane
<
NUM_WARPS
)
{
sum
=
red_smem
[
lane
];
}
// Parallel reduction inside the warp.
#pragma unroll
for
(
int
mask
=
NUM_WARPS
/
2
;
mask
>=
1
;
mask
/=
2
)
{
sum
+=
VLLM_SHFL_XOR_SYNC
(
sum
,
mask
);
}
// Broadcast to other threads.
return
VLLM_SHFL_SYNC
(
sum
,
0
);
}
// TODO(woosuk): Merge the last two dimensions of the grid.
// Grid: (num_heads, num_seqs, max_num_partitions).
template
<
typename
scalar_t
,
typename
cache_t
,
int
HEAD_SIZE
,
int
BLOCK_SIZE
,
int
NUM_THREADS
,
vllm
::
Fp8KVCacheDataType
KV_DTYPE
,
bool
IS_BLOCK_SPARSE
,
int
PARTITION_SIZE
=
0
>
// Zero means no partitioning.
__device__
void
paged_attention_kernel
(
float
*
__restrict__
exp_sums
,
// [num_seqs, num_heads, max_num_partitions]
float
*
__restrict__
max_logits
,
// [num_seqs, num_heads,
// max_num_partitions]
scalar_t
*
__restrict__
out
,
// [num_seqs, num_heads, max_num_partitions,
// head_size]
const
scalar_t
*
__restrict__
q
,
// [num_seqs, num_heads, head_size]
const
cache_t
*
__restrict__
k_cache
,
// [num_blocks, num_kv_heads,
// head_size/x, block_size, x]
const
cache_t
*
__restrict__
v_cache
,
// [num_blocks, num_kv_heads,
// head_size, block_size]
const
int
num_kv_heads
,
// [num_heads]
const
float
scale
,
const
int
*
__restrict__
block_tables
,
// [num_seqs, max_num_blocks_per_seq]
const
int
*
__restrict__
seq_lens
,
// [num_seqs]
const
int
max_num_blocks_per_seq
,
const
float
*
__restrict__
alibi_slopes
,
// [num_heads]
const
int
q_stride
,
const
int
kv_block_stride
,
const
int
kv_head_stride
,
const
float
*
k_scale
,
const
float
*
v_scale
,
const
int
tp_rank
,
const
int
blocksparse_local_blocks
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_block_size
,
const
int
blocksparse_head_sliding_step
)
{
const
int
seq_idx
=
blockIdx
.
y
;
const
int
partition_idx
=
blockIdx
.
z
;
const
int
max_num_partitions
=
gridDim
.
z
;
constexpr
bool
USE_PARTITIONING
=
PARTITION_SIZE
>
0
;
const
int
seq_len
=
seq_lens
[
seq_idx
];
if
(
USE_PARTITIONING
&&
partition_idx
*
PARTITION_SIZE
>=
seq_len
)
{
// No work to do. Terminate the thread block.
return
;
}
const
int
num_seq_blocks
=
DIVIDE_ROUND_UP
(
seq_len
,
BLOCK_SIZE
);
const
int
num_blocks_per_partition
=
USE_PARTITIONING
?
PARTITION_SIZE
/
BLOCK_SIZE
:
num_seq_blocks
;
// [start_block_idx, end_block_idx) is the range of blocks to process.
const
int
start_block_idx
=
USE_PARTITIONING
?
partition_idx
*
num_blocks_per_partition
:
0
;
const
int
end_block_idx
=
MIN
(
start_block_idx
+
num_blocks_per_partition
,
num_seq_blocks
);
const
int
num_blocks
=
end_block_idx
-
start_block_idx
;
// [start_token_idx, end_token_idx) is the range of tokens to process.
const
int
start_token_idx
=
start_block_idx
*
BLOCK_SIZE
;
const
int
end_token_idx
=
MIN
(
start_token_idx
+
num_blocks
*
BLOCK_SIZE
,
seq_len
);
const
int
num_tokens
=
end_token_idx
-
start_token_idx
;
constexpr
int
THREAD_GROUP_SIZE
=
MAX
(
WARP_SIZE
/
BLOCK_SIZE
,
1
);
constexpr
int
NUM_THREAD_GROUPS
=
NUM_THREADS
/
THREAD_GROUP_SIZE
;
// Note: This assumes THREAD_GROUP_SIZE
// divides NUM_THREADS
assert
(
NUM_THREADS
%
THREAD_GROUP_SIZE
==
0
);
constexpr
int
NUM_TOKENS_PER_THREAD_GROUP
=
DIVIDE_ROUND_UP
(
BLOCK_SIZE
,
WARP_SIZE
);
constexpr
int
NUM_WARPS
=
NUM_THREADS
/
WARP_SIZE
;
const
int
thread_idx
=
threadIdx
.
x
;
const
int
warp_idx
=
thread_idx
/
WARP_SIZE
;
const
int
lane
=
thread_idx
%
WARP_SIZE
;
const
int
head_idx
=
blockIdx
.
x
;
const
int
num_heads
=
gridDim
.
x
;
const
int
num_queries_per_kv
=
num_heads
/
num_kv_heads
;
const
int
kv_head_idx
=
head_idx
/
num_queries_per_kv
;
const
float
alibi_slope
=
alibi_slopes
==
nullptr
?
0.
f
:
alibi_slopes
[
head_idx
];
// A vector type to store a part of a key or a query.
// The vector size is configured in such a way that the threads in a thread
// group fetch or compute 16 bytes at a time. For example, if the size of a
// thread group is 4 and the data type is half, then the vector size is 16 /
// (4 * sizeof(half)) == 2.
constexpr
int
VEC_SIZE
=
MAX
(
16
/
(
THREAD_GROUP_SIZE
*
sizeof
(
scalar_t
)),
1
);
using
K_vec
=
typename
Vec
<
scalar_t
,
VEC_SIZE
>::
Type
;
using
Q_vec
=
typename
Vec
<
scalar_t
,
VEC_SIZE
>::
Type
;
using
Quant_vec
=
typename
Vec
<
cache_t
,
VEC_SIZE
>::
Type
;
constexpr
int
NUM_ELEMS_PER_THREAD
=
HEAD_SIZE
/
THREAD_GROUP_SIZE
;
constexpr
int
NUM_VECS_PER_THREAD
=
NUM_ELEMS_PER_THREAD
/
VEC_SIZE
;
const
int
thread_group_idx
=
thread_idx
/
THREAD_GROUP_SIZE
;
const
int
thread_group_offset
=
thread_idx
%
THREAD_GROUP_SIZE
;
// Load the query to registers.
// Each thread in a thread group has a different part of the query.
// For example, if the the thread group size is 4, then the first thread in
// the group has 0, 4, 8, ... th vectors of the query, and the second thread
// has 1, 5, 9, ... th vectors of the query, and so on. NOTE(woosuk): Because
// q is split from a qkv tensor, it may not be contiguous.
const
scalar_t
*
q_ptr
=
q
+
seq_idx
*
q_stride
+
head_idx
*
HEAD_SIZE
;
__shared__
Q_vec
q_vecs
[
THREAD_GROUP_SIZE
][
NUM_VECS_PER_THREAD
];
#pragma unroll
for
(
int
i
=
thread_group_idx
;
i
<
NUM_VECS_PER_THREAD
;
i
+=
NUM_THREAD_GROUPS
)
{
const
int
vec_idx
=
thread_group_offset
+
i
*
THREAD_GROUP_SIZE
;
q_vecs
[
thread_group_offset
][
i
]
=
*
reinterpret_cast
<
const
Q_vec
*>
(
q_ptr
+
vec_idx
*
VEC_SIZE
);
}
__syncthreads
();
// TODO(naed90): possible speedup if this is replaced with a
// memory wall right before we use q_vecs
// Memory planning.
extern
__shared__
char
shared_mem
[];
// NOTE(woosuk): We use FP32 for the softmax logits for better accuracy.
float
*
logits
=
reinterpret_cast
<
float
*>
(
shared_mem
);
// Workspace for reduction.
__shared__
float
red_smem
[
2
*
NUM_WARPS
];
// x == THREAD_GROUP_SIZE * VEC_SIZE
// Each thread group fetches x elements from the key at a time.
constexpr
int
x
=
16
/
sizeof
(
cache_t
);
float
qk_max
=
-
FLT_MAX
;
// Iterate over the key blocks.
// Each warp fetches a block of keys for each iteration.
// Each thread group in a warp fetches a key from the block, and computes
// dot product with the query.
const
int
*
block_table
=
block_tables
+
seq_idx
*
max_num_blocks_per_seq
;
// blocksparse specific vars
int
bs_block_offset
;
int
q_bs_block_id
;
if
constexpr
(
IS_BLOCK_SPARSE
)
{
// const int num_blocksparse_blocks = DIVIDE_ROUND_UP(seq_len,
// blocksparse_block_size);
q_bs_block_id
=
(
seq_len
-
1
)
/
blocksparse_block_size
;
if
(
blocksparse_head_sliding_step
>=
0
)
// sliding on q heads
bs_block_offset
=
(
tp_rank
*
num_heads
+
head_idx
)
*
blocksparse_head_sliding_step
+
1
;
else
// sliding on kv heads
bs_block_offset
=
(
tp_rank
*
num_kv_heads
+
kv_head_idx
)
*
(
-
blocksparse_head_sliding_step
)
+
1
;
}
for
(
int
block_idx
=
start_block_idx
+
warp_idx
;
block_idx
<
end_block_idx
;
block_idx
+=
NUM_WARPS
)
{
// NOTE(woosuk): The block number is stored in int32. However, we cast it to
// int64 because int32 can lead to overflow when this variable is multiplied
// by large numbers (e.g., kv_block_stride).
// For blocksparse attention: skip computation on blocks that are not
// attended
if
constexpr
(
IS_BLOCK_SPARSE
)
{
const
int
k_bs_block_id
=
block_idx
*
BLOCK_SIZE
/
blocksparse_block_size
;
const
bool
is_remote
=
((
k_bs_block_id
+
bs_block_offset
)
%
blocksparse_vert_stride
==
0
);
const
bool
is_local
=
(
k_bs_block_id
>
q_bs_block_id
-
blocksparse_local_blocks
);
if
(
!
is_remote
&&
!
is_local
)
{
for
(
int
i
=
0
;
i
<
NUM_TOKENS_PER_THREAD_GROUP
;
i
++
)
{
const
int
physical_block_offset
=
(
thread_group_idx
+
i
*
WARP_SIZE
)
%
BLOCK_SIZE
;
const
int
token_idx
=
block_idx
*
BLOCK_SIZE
+
physical_block_offset
;
if
(
thread_group_offset
==
0
)
{
// NOTE(linxihui): assign very large number to skipped tokens to
// avoid contribution to the sumexp softmax normalizer. This will
// not be used at computing sum(softmax*v) as the blocks will be
// skipped.
logits
[
token_idx
-
start_token_idx
]
=
-
FLT_MAX
;
}
}
continue
;
}
}
const
int64_t
physical_block_number
=
static_cast
<
int64_t
>
(
block_table
[
block_idx
]);
// Load a key to registers.
// Each thread in a thread group has a different part of the key.
// For example, if the the thread group size is 4, then the first thread in
// the group has 0, 4, 8, ... th vectors of the key, and the second thread
// has 1, 5, 9, ... th vectors of the key, and so on.
for
(
int
i
=
0
;
i
<
NUM_TOKENS_PER_THREAD_GROUP
;
i
++
)
{
const
int
physical_block_offset
=
(
thread_group_idx
+
i
*
WARP_SIZE
)
%
BLOCK_SIZE
;
const
int
token_idx
=
block_idx
*
BLOCK_SIZE
+
physical_block_offset
;
K_vec
k_vecs
[
NUM_VECS_PER_THREAD
];
#pragma unroll
for
(
int
j
=
0
;
j
<
NUM_VECS_PER_THREAD
;
j
++
)
{
const
cache_t
*
k_ptr
=
k_cache
+
physical_block_number
*
kv_block_stride
+
kv_head_idx
*
kv_head_stride
+
physical_block_offset
*
x
;
const
int
vec_idx
=
thread_group_offset
+
j
*
THREAD_GROUP_SIZE
;
const
int
offset1
=
(
vec_idx
*
VEC_SIZE
)
/
x
;
const
int
offset2
=
(
vec_idx
*
VEC_SIZE
)
%
x
;
if
constexpr
(
KV_DTYPE
==
Fp8KVCacheDataType
::
kAuto
)
{
k_vecs
[
j
]
=
*
reinterpret_cast
<
const
K_vec
*>
(
k_ptr
+
offset1
*
BLOCK_SIZE
*
x
+
offset2
);
}
else
{
// Vector conversion from Quant_vec to K_vec.
Quant_vec
k_vec_quant
=
*
reinterpret_cast
<
const
Quant_vec
*>
(
k_ptr
+
offset1
*
BLOCK_SIZE
*
x
+
offset2
);
k_vecs
[
j
]
=
fp8
::
scaled_convert
<
K_vec
,
Quant_vec
,
KV_DTYPE
>
(
k_vec_quant
,
*
k_scale
);
}
}
// Compute dot product.
// This includes a reduction across the threads in the same thread group.
float
qk
=
scale
*
Qk_dot
<
scalar_t
,
THREAD_GROUP_SIZE
>::
dot
(
q_vecs
[
thread_group_offset
],
k_vecs
);
// Add the ALiBi bias if slopes are given.
qk
+=
(
alibi_slope
!=
0
)
?
alibi_slope
*
(
token_idx
-
seq_len
+
1
)
:
0
;
if
(
thread_group_offset
==
0
)
{
// Store the partial reductions to shared memory.
// NOTE(woosuk): It is required to zero out the masked logits.
const
bool
mask
=
token_idx
>=
seq_len
;
logits
[
token_idx
-
start_token_idx
]
=
mask
?
0.
f
:
qk
;
// Update the max value.
qk_max
=
mask
?
qk_max
:
fmaxf
(
qk_max
,
qk
);
}
}
}
// Perform reduction across the threads in the same warp to get the
// max qk value for each "warp" (not across the thread block yet).
// The 0-th thread of each thread group already has its max qk value.
#pragma unroll
for
(
int
mask
=
WARP_SIZE
/
2
;
mask
>=
THREAD_GROUP_SIZE
;
mask
/=
2
)
{
qk_max
=
fmaxf
(
qk_max
,
VLLM_SHFL_XOR_SYNC
(
qk_max
,
mask
));
}
if
(
lane
==
0
)
{
red_smem
[
warp_idx
]
=
qk_max
;
}
__syncthreads
();
// TODO(woosuk): Refactor this part.
// Get the max qk value for the sequence.
qk_max
=
lane
<
NUM_WARPS
?
red_smem
[
lane
]
:
-
FLT_MAX
;
#pragma unroll
for
(
int
mask
=
NUM_WARPS
/
2
;
mask
>=
1
;
mask
/=
2
)
{
qk_max
=
fmaxf
(
qk_max
,
VLLM_SHFL_XOR_SYNC
(
qk_max
,
mask
));
}
// Broadcast the max qk value to all threads.
qk_max
=
VLLM_SHFL_SYNC
(
qk_max
,
0
);
// Get the sum of the exp values.
float
exp_sum
=
0.
f
;
for
(
int
i
=
thread_idx
;
i
<
num_tokens
;
i
+=
NUM_THREADS
)
{
float
val
=
__expf
(
logits
[
i
]
-
qk_max
);
logits
[
i
]
=
val
;
exp_sum
+=
val
;
}
exp_sum
=
block_sum
<
NUM_WARPS
>
(
&
red_smem
[
NUM_WARPS
],
exp_sum
);
// Compute softmax.
const
float
inv_sum
=
__fdividef
(
1.
f
,
exp_sum
+
1e-6
f
);
for
(
int
i
=
thread_idx
;
i
<
num_tokens
;
i
+=
NUM_THREADS
)
{
logits
[
i
]
*=
inv_sum
;
}
__syncthreads
();
// If partitioning is enabled, store the max logit and exp_sum.
if
(
USE_PARTITIONING
&&
thread_idx
==
0
)
{
float
*
max_logits_ptr
=
max_logits
+
seq_idx
*
num_heads
*
max_num_partitions
+
head_idx
*
max_num_partitions
+
partition_idx
;
*
max_logits_ptr
=
qk_max
;
float
*
exp_sums_ptr
=
exp_sums
+
seq_idx
*
num_heads
*
max_num_partitions
+
head_idx
*
max_num_partitions
+
partition_idx
;
*
exp_sums_ptr
=
exp_sum
;
}
// Each thread will fetch 16 bytes from the value cache at a time.
constexpr
int
V_VEC_SIZE
=
MIN
(
16
/
sizeof
(
scalar_t
),
BLOCK_SIZE
);
using
V_vec
=
typename
Vec
<
scalar_t
,
V_VEC_SIZE
>::
Type
;
using
L_vec
=
typename
Vec
<
scalar_t
,
V_VEC_SIZE
>::
Type
;
using
V_quant_vec
=
typename
Vec
<
cache_t
,
V_VEC_SIZE
>::
Type
;
using
Float_L_vec
=
typename
FloatVec
<
L_vec
>::
Type
;
constexpr
int
NUM_V_VECS_PER_ROW
=
BLOCK_SIZE
/
V_VEC_SIZE
;
constexpr
int
NUM_ROWS_PER_ITER
=
WARP_SIZE
/
NUM_V_VECS_PER_ROW
;
constexpr
int
NUM_ROWS_PER_THREAD
=
DIVIDE_ROUND_UP
(
HEAD_SIZE
,
NUM_ROWS_PER_ITER
);
// NOTE(woosuk): We use FP32 for the accumulator for better accuracy.
float
accs
[
NUM_ROWS_PER_THREAD
];
#pragma unroll
for
(
int
i
=
0
;
i
<
NUM_ROWS_PER_THREAD
;
i
++
)
{
accs
[
i
]
=
0.
f
;
}
scalar_t
zero_value
;
zero
(
zero_value
);
for
(
int
block_idx
=
start_block_idx
+
warp_idx
;
block_idx
<
end_block_idx
;
block_idx
+=
NUM_WARPS
)
{
// NOTE(woosuk): The block number is stored in int32. However, we cast it to
// int64 because int32 can lead to overflow when this variable is multiplied
// by large numbers (e.g., kv_block_stride).
// For blocksparse attention: skip computation on blocks that are not
// attended
if
constexpr
(
IS_BLOCK_SPARSE
)
{
int
v_bs_block_id
=
block_idx
*
BLOCK_SIZE
/
blocksparse_block_size
;
if
(
!
((
v_bs_block_id
+
bs_block_offset
)
%
blocksparse_vert_stride
==
0
)
&&
!
((
v_bs_block_id
>
q_bs_block_id
-
blocksparse_local_blocks
)))
{
continue
;
}
}
const
int64_t
physical_block_number
=
static_cast
<
int64_t
>
(
block_table
[
block_idx
]);
const
int
physical_block_offset
=
(
lane
%
NUM_V_VECS_PER_ROW
)
*
V_VEC_SIZE
;
const
int
token_idx
=
block_idx
*
BLOCK_SIZE
+
physical_block_offset
;
L_vec
logits_vec
;
from_float
(
logits_vec
,
*
reinterpret_cast
<
Float_L_vec
*>
(
logits
+
token_idx
-
start_token_idx
));
const
cache_t
*
v_ptr
=
v_cache
+
physical_block_number
*
kv_block_stride
+
kv_head_idx
*
kv_head_stride
;
#pragma unroll
for
(
int
i
=
0
;
i
<
NUM_ROWS_PER_THREAD
;
i
++
)
{
const
int
row_idx
=
lane
/
NUM_V_VECS_PER_ROW
+
i
*
NUM_ROWS_PER_ITER
;
if
(
row_idx
<
HEAD_SIZE
)
{
const
int
offset
=
row_idx
*
BLOCK_SIZE
+
physical_block_offset
;
V_vec
v_vec
;
if
constexpr
(
KV_DTYPE
==
Fp8KVCacheDataType
::
kAuto
)
{
v_vec
=
*
reinterpret_cast
<
const
V_vec
*>
(
v_ptr
+
offset
);
}
else
{
V_quant_vec
v_quant_vec
=
*
reinterpret_cast
<
const
V_quant_vec
*>
(
v_ptr
+
offset
);
// Vector conversion from V_quant_vec to V_vec.
v_vec
=
fp8
::
scaled_convert
<
V_vec
,
V_quant_vec
,
KV_DTYPE
>
(
v_quant_vec
,
*
v_scale
);
}
if
(
block_idx
==
num_seq_blocks
-
1
)
{
// NOTE(woosuk): When v_vec contains the tokens that are out of the
// context, we should explicitly zero out the values since they may
// contain NaNs. See
// https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472
scalar_t
*
v_vec_ptr
=
reinterpret_cast
<
scalar_t
*>
(
&
v_vec
);
#pragma unroll
for
(
int
j
=
0
;
j
<
V_VEC_SIZE
;
j
++
)
{
v_vec_ptr
[
j
]
=
token_idx
+
j
<
seq_len
?
v_vec_ptr
[
j
]
:
zero_value
;
}
}
accs
[
i
]
+=
dot
(
logits_vec
,
v_vec
);
}
}
}
// Perform reduction within each warp.
#pragma unroll
for
(
int
i
=
0
;
i
<
NUM_ROWS_PER_THREAD
;
i
++
)
{
float
acc
=
accs
[
i
];
#pragma unroll
for
(
int
mask
=
NUM_V_VECS_PER_ROW
/
2
;
mask
>=
1
;
mask
/=
2
)
{
acc
+=
VLLM_SHFL_XOR_SYNC
(
acc
,
mask
);
}
accs
[
i
]
=
acc
;
}
// NOTE(woosuk): A barrier is required because the shared memory space for
// logits is reused for the output.
__syncthreads
();
// Perform reduction across warps.
float
*
out_smem
=
reinterpret_cast
<
float
*>
(
shared_mem
);
#pragma unroll
for
(
int
i
=
NUM_WARPS
;
i
>
1
;
i
/=
2
)
{
int
mid
=
i
/
2
;
// Upper warps write to shared memory.
if
(
warp_idx
>=
mid
&&
warp_idx
<
i
)
{
float
*
dst
=
&
out_smem
[(
warp_idx
-
mid
)
*
HEAD_SIZE
];
#pragma unroll
for
(
int
i
=
0
;
i
<
NUM_ROWS_PER_THREAD
;
i
++
)
{
const
int
row_idx
=
lane
/
NUM_V_VECS_PER_ROW
+
i
*
NUM_ROWS_PER_ITER
;
if
(
row_idx
<
HEAD_SIZE
&&
lane
%
NUM_V_VECS_PER_ROW
==
0
)
{
dst
[
row_idx
]
=
accs
[
i
];
}
}
}
__syncthreads
();
// Lower warps update the output.
if
(
warp_idx
<
mid
)
{
const
float
*
src
=
&
out_smem
[
warp_idx
*
HEAD_SIZE
];
#pragma unroll
for
(
int
i
=
0
;
i
<
NUM_ROWS_PER_THREAD
;
i
++
)
{
const
int
row_idx
=
lane
/
NUM_V_VECS_PER_ROW
+
i
*
NUM_ROWS_PER_ITER
;
if
(
row_idx
<
HEAD_SIZE
&&
lane
%
NUM_V_VECS_PER_ROW
==
0
)
{
accs
[
i
]
+=
src
[
row_idx
];
}
}
}
__syncthreads
();
}
// Write the final output.
if
(
warp_idx
==
0
)
{
scalar_t
*
out_ptr
=
out
+
seq_idx
*
num_heads
*
max_num_partitions
*
HEAD_SIZE
+
head_idx
*
max_num_partitions
*
HEAD_SIZE
+
partition_idx
*
HEAD_SIZE
;
#pragma unroll
for
(
int
i
=
0
;
i
<
NUM_ROWS_PER_THREAD
;
i
++
)
{
const
int
row_idx
=
lane
/
NUM_V_VECS_PER_ROW
+
i
*
NUM_ROWS_PER_ITER
;
if
(
row_idx
<
HEAD_SIZE
&&
lane
%
NUM_V_VECS_PER_ROW
==
0
)
{
from_float
(
*
(
out_ptr
+
row_idx
),
accs
[
i
]);
}
}
}
}
// Grid: (num_heads, num_seqs, 1).
template
<
typename
scalar_t
,
typename
cache_t
,
int
HEAD_SIZE
,
int
BLOCK_SIZE
,
int
NUM_THREADS
,
vllm
::
Fp8KVCacheDataType
KV_DTYPE
,
bool
IS_BLOCK_SPARSE
>
__global__
void
paged_attention_v1_kernel
(
scalar_t
*
__restrict__
out
,
// [num_seqs, num_heads, head_size]
const
scalar_t
*
__restrict__
q
,
// [num_seqs, num_heads, head_size]
const
cache_t
*
__restrict__
k_cache
,
// [num_blocks, num_kv_heads,
// head_size/x, block_size, x]
const
cache_t
*
__restrict__
v_cache
,
// [num_blocks, num_kv_heads,
// head_size, block_size]
const
int
num_kv_heads
,
// [num_heads]
const
float
scale
,
const
int
*
__restrict__
block_tables
,
// [num_seqs, max_num_blocks_per_seq]
const
int
*
__restrict__
seq_lens
,
// [num_seqs]
const
int
max_num_blocks_per_seq
,
const
float
*
__restrict__
alibi_slopes
,
// [num_heads]
const
int
q_stride
,
const
int
kv_block_stride
,
const
int
kv_head_stride
,
const
float
*
k_scale
,
const
float
*
v_scale
,
const
int
tp_rank
,
const
int
blocksparse_local_blocks
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_block_size
,
const
int
blocksparse_head_sliding_step
)
{
paged_attention_kernel
<
scalar_t
,
cache_t
,
HEAD_SIZE
,
BLOCK_SIZE
,
NUM_THREADS
,
KV_DTYPE
,
IS_BLOCK_SPARSE
>
(
/* exp_sums */
nullptr
,
/* max_logits */
nullptr
,
out
,
q
,
k_cache
,
v_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
max_num_blocks_per_seq
,
alibi_slopes
,
q_stride
,
kv_block_stride
,
kv_head_stride
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_head_sliding_step
);
}
// Grid: (num_heads, num_seqs, max_num_partitions).
template
<
typename
scalar_t
,
typename
cache_t
,
int
HEAD_SIZE
,
int
BLOCK_SIZE
,
int
NUM_THREADS
,
vllm
::
Fp8KVCacheDataType
KV_DTYPE
,
bool
IS_BLOCK_SPARSE
,
int
PARTITION_SIZE
>
__global__
void
paged_attention_v2_kernel
(
float
*
__restrict__
exp_sums
,
// [num_seqs, num_heads, max_num_partitions]
float
*
__restrict__
max_logits
,
// [num_seqs, num_heads,
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <algorithm>
#include "attention_dtypes.h"
#include "attention_utils.cuh"
#ifdef USE_ROCM
#include <hip/hip_bf16.h>
#include "../quantization/fp8/amd/quant_utils.cuh"
typedef
__hip_bfloat16
__nv_bfloat16
;
#else
#include "../quantization/fp8/nvidia/quant_utils.cuh"
#endif
#ifndef USE_ROCM
#define WARP_SIZE 32
#else
#define WARP_SIZE warpSize
#endif
#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
namespace
vllm
{
// Utility function for attention softmax.
template
<
int
NUM_WARPS
>
inline
__device__
float
block_sum
(
float
*
red_smem
,
float
sum
)
{
// Decompose the thread index into warp / lane.
int
warp
=
threadIdx
.
x
/
WARP_SIZE
;
int
lane
=
threadIdx
.
x
%
WARP_SIZE
;
// Compute the sum per warp.
#pragma unroll
for
(
int
mask
=
WARP_SIZE
/
2
;
mask
>=
1
;
mask
/=
2
)
{
sum
+=
VLLM_SHFL_XOR_SYNC
(
sum
,
mask
);
}
// Warp leaders store the data to shared memory.
if
(
lane
==
0
)
{
red_smem
[
warp
]
=
sum
;
}
// Make sure the data is in shared memory.
__syncthreads
();
// The warps compute the final sums.
if
(
lane
<
NUM_WARPS
)
{
sum
=
red_smem
[
lane
];
}
// Parallel reduction inside the warp.
#pragma unroll
for
(
int
mask
=
NUM_WARPS
/
2
;
mask
>=
1
;
mask
/=
2
)
{
sum
+=
VLLM_SHFL_XOR_SYNC
(
sum
,
mask
);
}
// Broadcast to other threads.
return
VLLM_SHFL_SYNC
(
sum
,
0
);
}
// TODO(woosuk): Merge the last two dimensions of the grid.
// Grid: (num_heads, num_seqs, max_num_partitions).
template
<
typename
scalar_t
,
typename
cache_t
,
int
HEAD_SIZE
,
int
BLOCK_SIZE
,
int
NUM_THREADS
,
vllm
::
Fp8KVCacheDataType
KV_DTYPE
,
bool
IS_BLOCK_SPARSE
,
int
PARTITION_SIZE
=
0
>
// Zero means no partitioning.
__device__
void
paged_attention_kernel
(
float
*
__restrict__
exp_sums
,
// [num_seqs, num_heads, max_num_partitions]
float
*
__restrict__
max_logits
,
// [num_seqs, num_heads,
// max_num_partitions]
scalar_t
*
__restrict__
out
,
// [num_seqs, num_heads, max_num_partitions,
// head_size]
const
scalar_t
*
__restrict__
q
,
// [num_seqs, num_heads, head_size]
const
cache_t
*
__restrict__
k_cache
,
// [num_blocks, num_kv_heads,
// head_size/x, block_size, x]
const
cache_t
*
__restrict__
v_cache
,
// [num_blocks, num_kv_heads,
// head_size, block_size]
const
int
num_kv_heads
,
// [num_heads]
const
float
scale
,
const
int
*
__restrict__
block_tables
,
// [num_seqs, max_num_blocks_per_seq]
const
int
*
__restrict__
seq_lens
,
// [num_seqs]
const
int
max_num_blocks_per_seq
,
const
float
*
__restrict__
alibi_slopes
,
// [num_heads]
const
int
q_stride
,
const
int
kv_block_stride
,
const
int
kv_head_stride
,
const
float
*
k_scale
,
const
float
*
v_scale
,
const
int
tp_rank
,
const
int
blocksparse_local_blocks
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_block_size
,
const
int
blocksparse_head_sliding_step
)
{
const
int
seq_idx
=
blockIdx
.
y
;
const
int
partition_idx
=
blockIdx
.
z
;
const
int
max_num_partitions
=
gridDim
.
z
;
constexpr
bool
USE_PARTITIONING
=
PARTITION_SIZE
>
0
;
const
int
seq_len
=
seq_lens
[
seq_idx
];
if
(
USE_PARTITIONING
&&
partition_idx
*
PARTITION_SIZE
>=
seq_len
)
{
// No work to do. Terminate the thread block.
return
;
}
const
int
num_seq_blocks
=
DIVIDE_ROUND_UP
(
seq_len
,
BLOCK_SIZE
);
const
int
num_blocks_per_partition
=
USE_PARTITIONING
?
PARTITION_SIZE
/
BLOCK_SIZE
:
num_seq_blocks
;
// [start_block_idx, end_block_idx) is the range of blocks to process.
const
int
start_block_idx
=
USE_PARTITIONING
?
partition_idx
*
num_blocks_per_partition
:
0
;
const
int
end_block_idx
=
MIN
(
start_block_idx
+
num_blocks_per_partition
,
num_seq_blocks
);
const
int
num_blocks
=
end_block_idx
-
start_block_idx
;
// [start_token_idx, end_token_idx) is the range of tokens to process.
const
int
start_token_idx
=
start_block_idx
*
BLOCK_SIZE
;
const
int
end_token_idx
=
MIN
(
start_token_idx
+
num_blocks
*
BLOCK_SIZE
,
seq_len
);
const
int
num_tokens
=
end_token_idx
-
start_token_idx
;
constexpr
int
THREAD_GROUP_SIZE
=
MAX
(
WARP_SIZE
/
BLOCK_SIZE
,
1
);
constexpr
int
NUM_THREAD_GROUPS
=
NUM_THREADS
/
THREAD_GROUP_SIZE
;
// Note: This assumes THREAD_GROUP_SIZE
// divides NUM_THREADS
assert
(
NUM_THREADS
%
THREAD_GROUP_SIZE
==
0
);
constexpr
int
NUM_TOKENS_PER_THREAD_GROUP
=
DIVIDE_ROUND_UP
(
BLOCK_SIZE
,
WARP_SIZE
);
constexpr
int
NUM_WARPS
=
NUM_THREADS
/
WARP_SIZE
;
const
int
thread_idx
=
threadIdx
.
x
;
const
int
warp_idx
=
thread_idx
/
WARP_SIZE
;
const
int
lane
=
thread_idx
%
WARP_SIZE
;
const
int
head_idx
=
blockIdx
.
x
;
const
int
num_heads
=
gridDim
.
x
;
const
int
num_queries_per_kv
=
num_heads
/
num_kv_heads
;
const
int
kv_head_idx
=
head_idx
/
num_queries_per_kv
;
const
float
alibi_slope
=
alibi_slopes
==
nullptr
?
0.
f
:
alibi_slopes
[
head_idx
];
// A vector type to store a part of a key or a query.
// The vector size is configured in such a way that the threads in a thread
// group fetch or compute 16 bytes at a time. For example, if the size of a
// thread group is 4 and the data type is half, then the vector size is 16 /
// (4 * sizeof(half)) == 2.
constexpr
int
VEC_SIZE
=
MAX
(
16
/
(
THREAD_GROUP_SIZE
*
sizeof
(
scalar_t
)),
1
);
using
K_vec
=
typename
Vec
<
scalar_t
,
VEC_SIZE
>::
Type
;
using
Q_vec
=
typename
Vec
<
scalar_t
,
VEC_SIZE
>::
Type
;
using
Quant_vec
=
typename
Vec
<
cache_t
,
VEC_SIZE
>::
Type
;
constexpr
int
NUM_ELEMS_PER_THREAD
=
HEAD_SIZE
/
THREAD_GROUP_SIZE
;
constexpr
int
NUM_VECS_PER_THREAD
=
NUM_ELEMS_PER_THREAD
/
VEC_SIZE
;
const
int
thread_group_idx
=
thread_idx
/
THREAD_GROUP_SIZE
;
const
int
thread_group_offset
=
thread_idx
%
THREAD_GROUP_SIZE
;
// Load the query to registers.
// Each thread in a thread group has a different part of the query.
// For example, if the thread group size is 4, then the first thread in
// the group has 0, 4, 8, ... th vectors of the query, and the second thread
// has 1, 5, 9, ... th vectors of the query, and so on. NOTE(woosuk): Because
// q is split from a qkv tensor, it may not be contiguous.
const
scalar_t
*
q_ptr
=
q
+
seq_idx
*
q_stride
+
head_idx
*
HEAD_SIZE
;
__shared__
Q_vec
q_vecs
[
THREAD_GROUP_SIZE
][
NUM_VECS_PER_THREAD
];
#pragma unroll
for
(
int
i
=
thread_group_idx
;
i
<
NUM_VECS_PER_THREAD
;
i
+=
NUM_THREAD_GROUPS
)
{
const
int
vec_idx
=
thread_group_offset
+
i
*
THREAD_GROUP_SIZE
;
q_vecs
[
thread_group_offset
][
i
]
=
*
reinterpret_cast
<
const
Q_vec
*>
(
q_ptr
+
vec_idx
*
VEC_SIZE
);
}
__syncthreads
();
// TODO(naed90): possible speedup if this is replaced with a
// memory wall right before we use q_vecs
// Memory planning.
extern
__shared__
char
shared_mem
[];
// NOTE(woosuk): We use FP32 for the softmax logits for better accuracy.
float
*
logits
=
reinterpret_cast
<
float
*>
(
shared_mem
);
// Workspace for reduction.
__shared__
float
red_smem
[
2
*
NUM_WARPS
];
// x == THREAD_GROUP_SIZE * VEC_SIZE
// Each thread group fetches x elements from the key at a time.
constexpr
int
x
=
16
/
sizeof
(
cache_t
);
float
qk_max
=
-
FLT_MAX
;
// Iterate over the key blocks.
// Each warp fetches a block of keys for each iteration.
// Each thread group in a warp fetches a key from the block, and computes
// dot product with the query.
const
int
*
block_table
=
block_tables
+
seq_idx
*
max_num_blocks_per_seq
;
// blocksparse specific vars
int
bs_block_offset
;
int
q_bs_block_id
;
if
constexpr
(
IS_BLOCK_SPARSE
)
{
// const int num_blocksparse_blocks = DIVIDE_ROUND_UP(seq_len,
// blocksparse_block_size);
q_bs_block_id
=
(
seq_len
-
1
)
/
blocksparse_block_size
;
if
(
blocksparse_head_sliding_step
>=
0
)
// sliding on q heads
bs_block_offset
=
(
tp_rank
*
num_heads
+
head_idx
)
*
blocksparse_head_sliding_step
+
1
;
else
// sliding on kv heads
bs_block_offset
=
(
tp_rank
*
num_kv_heads
+
kv_head_idx
)
*
(
-
blocksparse_head_sliding_step
)
+
1
;
}
for
(
int
block_idx
=
start_block_idx
+
warp_idx
;
block_idx
<
end_block_idx
;
block_idx
+=
NUM_WARPS
)
{
// NOTE(woosuk): The block number is stored in int32. However, we cast it to
// int64 because int32 can lead to overflow when this variable is multiplied
// by large numbers (e.g., kv_block_stride).
// For blocksparse attention: skip computation on blocks that are not
// attended
if
constexpr
(
IS_BLOCK_SPARSE
)
{
const
int
k_bs_block_id
=
block_idx
*
BLOCK_SIZE
/
blocksparse_block_size
;
const
bool
is_remote
=
((
k_bs_block_id
+
bs_block_offset
)
%
blocksparse_vert_stride
==
0
);
const
bool
is_local
=
(
k_bs_block_id
>
q_bs_block_id
-
blocksparse_local_blocks
);
if
(
!
is_remote
&&
!
is_local
)
{
for
(
int
i
=
0
;
i
<
NUM_TOKENS_PER_THREAD_GROUP
;
i
++
)
{
const
int
physical_block_offset
=
(
thread_group_idx
+
i
*
WARP_SIZE
)
%
BLOCK_SIZE
;
const
int
token_idx
=
block_idx
*
BLOCK_SIZE
+
physical_block_offset
;
if
(
thread_group_offset
==
0
)
{
// NOTE(linxihui): assign very large number to skipped tokens to
// avoid contribution to the sumexp softmax normalizer. This will
// not be used at computing sum(softmax*v) as the blocks will be
// skipped.
logits
[
token_idx
-
start_token_idx
]
=
-
FLT_MAX
;
}
}
continue
;
}
}
const
int64_t
physical_block_number
=
static_cast
<
int64_t
>
(
block_table
[
block_idx
]);
// Load a key to registers.
// Each thread in a thread group has a different part of the key.
// For example, if the thread group size is 4, then the first thread in
// the group has 0, 4, 8, ... th vectors of the key, and the second thread
// has 1, 5, 9, ... th vectors of the key, and so on.
for
(
int
i
=
0
;
i
<
NUM_TOKENS_PER_THREAD_GROUP
;
i
++
)
{
const
int
physical_block_offset
=
(
thread_group_idx
+
i
*
WARP_SIZE
)
%
BLOCK_SIZE
;
const
int
token_idx
=
block_idx
*
BLOCK_SIZE
+
physical_block_offset
;
K_vec
k_vecs
[
NUM_VECS_PER_THREAD
];
#pragma unroll
for
(
int
j
=
0
;
j
<
NUM_VECS_PER_THREAD
;
j
++
)
{
const
cache_t
*
k_ptr
=
k_cache
+
physical_block_number
*
kv_block_stride
+
kv_head_idx
*
kv_head_stride
+
physical_block_offset
*
x
;
const
int
vec_idx
=
thread_group_offset
+
j
*
THREAD_GROUP_SIZE
;
const
int
offset1
=
(
vec_idx
*
VEC_SIZE
)
/
x
;
const
int
offset2
=
(
vec_idx
*
VEC_SIZE
)
%
x
;
if
constexpr
(
KV_DTYPE
==
Fp8KVCacheDataType
::
kAuto
)
{
k_vecs
[
j
]
=
*
reinterpret_cast
<
const
K_vec
*>
(
k_ptr
+
offset1
*
BLOCK_SIZE
*
x
+
offset2
);
}
else
{
// Vector conversion from Quant_vec to K_vec.
Quant_vec
k_vec_quant
=
*
reinterpret_cast
<
const
Quant_vec
*>
(
k_ptr
+
offset1
*
BLOCK_SIZE
*
x
+
offset2
);
k_vecs
[
j
]
=
fp8
::
scaled_convert
<
K_vec
,
Quant_vec
,
KV_DTYPE
>
(
k_vec_quant
,
*
k_scale
);
}
}
// Compute dot product.
// This includes a reduction across the threads in the same thread group.
float
qk
=
scale
*
Qk_dot
<
scalar_t
,
THREAD_GROUP_SIZE
>::
dot
(
q_vecs
[
thread_group_offset
],
k_vecs
);
// Add the ALiBi bias if slopes are given.
qk
+=
(
alibi_slope
!=
0
)
?
alibi_slope
*
(
token_idx
-
seq_len
+
1
)
:
0
;
if
(
thread_group_offset
==
0
)
{
// Store the partial reductions to shared memory.
// NOTE(woosuk): It is required to zero out the masked logits.
const
bool
mask
=
token_idx
>=
seq_len
;
logits
[
token_idx
-
start_token_idx
]
=
mask
?
0.
f
:
qk
;
// Update the max value.
qk_max
=
mask
?
qk_max
:
fmaxf
(
qk_max
,
qk
);
}
}
}
// Perform reduction across the threads in the same warp to get the
// max qk value for each "warp" (not across the thread block yet).
// The 0-th thread of each thread group already has its max qk value.
#pragma unroll
for
(
int
mask
=
WARP_SIZE
/
2
;
mask
>=
THREAD_GROUP_SIZE
;
mask
/=
2
)
{
qk_max
=
fmaxf
(
qk_max
,
VLLM_SHFL_XOR_SYNC
(
qk_max
,
mask
));
}
if
(
lane
==
0
)
{
red_smem
[
warp_idx
]
=
qk_max
;
}
__syncthreads
();
// TODO(woosuk): Refactor this part.
// Get the max qk value for the sequence.
qk_max
=
lane
<
NUM_WARPS
?
red_smem
[
lane
]
:
-
FLT_MAX
;
#pragma unroll
for
(
int
mask
=
NUM_WARPS
/
2
;
mask
>=
1
;
mask
/=
2
)
{
qk_max
=
fmaxf
(
qk_max
,
VLLM_SHFL_XOR_SYNC
(
qk_max
,
mask
));
}
// Broadcast the max qk value to all threads.
qk_max
=
VLLM_SHFL_SYNC
(
qk_max
,
0
);
// Get the sum of the exp values.
float
exp_sum
=
0.
f
;
for
(
int
i
=
thread_idx
;
i
<
num_tokens
;
i
+=
NUM_THREADS
)
{
float
val
=
__expf
(
logits
[
i
]
-
qk_max
);
logits
[
i
]
=
val
;
exp_sum
+=
val
;
}
exp_sum
=
block_sum
<
NUM_WARPS
>
(
&
red_smem
[
NUM_WARPS
],
exp_sum
);
// Compute softmax.
const
float
inv_sum
=
__fdividef
(
1.
f
,
exp_sum
+
1e-6
f
);
for
(
int
i
=
thread_idx
;
i
<
num_tokens
;
i
+=
NUM_THREADS
)
{
logits
[
i
]
*=
inv_sum
;
}
__syncthreads
();
// If partitioning is enabled, store the max logit and exp_sum.
if
(
USE_PARTITIONING
&&
thread_idx
==
0
)
{
float
*
max_logits_ptr
=
max_logits
+
seq_idx
*
num_heads
*
max_num_partitions
+
head_idx
*
max_num_partitions
+
partition_idx
;
*
max_logits_ptr
=
qk_max
;
float
*
exp_sums_ptr
=
exp_sums
+
seq_idx
*
num_heads
*
max_num_partitions
+
head_idx
*
max_num_partitions
+
partition_idx
;
*
exp_sums_ptr
=
exp_sum
;
}
// Each thread will fetch 16 bytes from the value cache at a time.
constexpr
int
V_VEC_SIZE
=
MIN
(
16
/
sizeof
(
scalar_t
),
BLOCK_SIZE
);
using
V_vec
=
typename
Vec
<
scalar_t
,
V_VEC_SIZE
>::
Type
;
using
L_vec
=
typename
Vec
<
scalar_t
,
V_VEC_SIZE
>::
Type
;
using
V_quant_vec
=
typename
Vec
<
cache_t
,
V_VEC_SIZE
>::
Type
;
using
Float_L_vec
=
typename
FloatVec
<
L_vec
>::
Type
;
constexpr
int
NUM_V_VECS_PER_ROW
=
BLOCK_SIZE
/
V_VEC_SIZE
;
constexpr
int
NUM_ROWS_PER_ITER
=
WARP_SIZE
/
NUM_V_VECS_PER_ROW
;
constexpr
int
NUM_ROWS_PER_THREAD
=
DIVIDE_ROUND_UP
(
HEAD_SIZE
,
NUM_ROWS_PER_ITER
);
// NOTE(woosuk): We use FP32 for the accumulator for better accuracy.
float
accs
[
NUM_ROWS_PER_THREAD
];
#pragma unroll
for
(
int
i
=
0
;
i
<
NUM_ROWS_PER_THREAD
;
i
++
)
{
accs
[
i
]
=
0.
f
;
}
scalar_t
zero_value
;
zero
(
zero_value
);
for
(
int
block_idx
=
start_block_idx
+
warp_idx
;
block_idx
<
end_block_idx
;
block_idx
+=
NUM_WARPS
)
{
// NOTE(woosuk): The block number is stored in int32. However, we cast it to
// int64 because int32 can lead to overflow when this variable is multiplied
// by large numbers (e.g., kv_block_stride).
// For blocksparse attention: skip computation on blocks that are not
// attended
if
constexpr
(
IS_BLOCK_SPARSE
)
{
int
v_bs_block_id
=
block_idx
*
BLOCK_SIZE
/
blocksparse_block_size
;
if
(
!
((
v_bs_block_id
+
bs_block_offset
)
%
blocksparse_vert_stride
==
0
)
&&
!
((
v_bs_block_id
>
q_bs_block_id
-
blocksparse_local_blocks
)))
{
continue
;
}
}
const
int64_t
physical_block_number
=
static_cast
<
int64_t
>
(
block_table
[
block_idx
]);
const
int
physical_block_offset
=
(
lane
%
NUM_V_VECS_PER_ROW
)
*
V_VEC_SIZE
;
const
int
token_idx
=
block_idx
*
BLOCK_SIZE
+
physical_block_offset
;
L_vec
logits_vec
;
from_float
(
logits_vec
,
*
reinterpret_cast
<
Float_L_vec
*>
(
logits
+
token_idx
-
start_token_idx
));
const
cache_t
*
v_ptr
=
v_cache
+
physical_block_number
*
kv_block_stride
+
kv_head_idx
*
kv_head_stride
;
#pragma unroll
for
(
int
i
=
0
;
i
<
NUM_ROWS_PER_THREAD
;
i
++
)
{
const
int
row_idx
=
lane
/
NUM_V_VECS_PER_ROW
+
i
*
NUM_ROWS_PER_ITER
;
if
(
row_idx
<
HEAD_SIZE
)
{
const
int
offset
=
row_idx
*
BLOCK_SIZE
+
physical_block_offset
;
V_vec
v_vec
;
if
constexpr
(
KV_DTYPE
==
Fp8KVCacheDataType
::
kAuto
)
{
v_vec
=
*
reinterpret_cast
<
const
V_vec
*>
(
v_ptr
+
offset
);
}
else
{
V_quant_vec
v_quant_vec
=
*
reinterpret_cast
<
const
V_quant_vec
*>
(
v_ptr
+
offset
);
// Vector conversion from V_quant_vec to V_vec.
v_vec
=
fp8
::
scaled_convert
<
V_vec
,
V_quant_vec
,
KV_DTYPE
>
(
v_quant_vec
,
*
v_scale
);
}
if
(
block_idx
==
num_seq_blocks
-
1
)
{
// NOTE(woosuk): When v_vec contains the tokens that are out of the
// context, we should explicitly zero out the values since they may
// contain NaNs. See
// https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472
scalar_t
*
v_vec_ptr
=
reinterpret_cast
<
scalar_t
*>
(
&
v_vec
);
#pragma unroll
for
(
int
j
=
0
;
j
<
V_VEC_SIZE
;
j
++
)
{
v_vec_ptr
[
j
]
=
token_idx
+
j
<
seq_len
?
v_vec_ptr
[
j
]
:
zero_value
;
}
}
accs
[
i
]
+=
dot
(
logits_vec
,
v_vec
);
}
}
}
// Perform reduction within each warp.
#pragma unroll
for
(
int
i
=
0
;
i
<
NUM_ROWS_PER_THREAD
;
i
++
)
{
float
acc
=
accs
[
i
];
#pragma unroll
for
(
int
mask
=
NUM_V_VECS_PER_ROW
/
2
;
mask
>=
1
;
mask
/=
2
)
{
acc
+=
VLLM_SHFL_XOR_SYNC
(
acc
,
mask
);
}
accs
[
i
]
=
acc
;
}
// NOTE(woosuk): A barrier is required because the shared memory space for
// logits is reused for the output.
__syncthreads
();
// Perform reduction across warps.
float
*
out_smem
=
reinterpret_cast
<
float
*>
(
shared_mem
);
#pragma unroll
for
(
int
i
=
NUM_WARPS
;
i
>
1
;
i
/=
2
)
{
int
mid
=
i
/
2
;
// Upper warps write to shared memory.
if
(
warp_idx
>=
mid
&&
warp_idx
<
i
)
{
float
*
dst
=
&
out_smem
[(
warp_idx
-
mid
)
*
HEAD_SIZE
];
#pragma unroll
for
(
int
i
=
0
;
i
<
NUM_ROWS_PER_THREAD
;
i
++
)
{
const
int
row_idx
=
lane
/
NUM_V_VECS_PER_ROW
+
i
*
NUM_ROWS_PER_ITER
;
if
(
row_idx
<
HEAD_SIZE
&&
lane
%
NUM_V_VECS_PER_ROW
==
0
)
{
dst
[
row_idx
]
=
accs
[
i
];
}
}
}
__syncthreads
();
// Lower warps update the output.
if
(
warp_idx
<
mid
)
{
const
float
*
src
=
&
out_smem
[
warp_idx
*
HEAD_SIZE
];
#pragma unroll
for
(
int
i
=
0
;
i
<
NUM_ROWS_PER_THREAD
;
i
++
)
{
const
int
row_idx
=
lane
/
NUM_V_VECS_PER_ROW
+
i
*
NUM_ROWS_PER_ITER
;
if
(
row_idx
<
HEAD_SIZE
&&
lane
%
NUM_V_VECS_PER_ROW
==
0
)
{
accs
[
i
]
+=
src
[
row_idx
];
}
}
}
__syncthreads
();
}
// Write the final output.
if
(
warp_idx
==
0
)
{
scalar_t
*
out_ptr
=
out
+
seq_idx
*
num_heads
*
max_num_partitions
*
HEAD_SIZE
+
head_idx
*
max_num_partitions
*
HEAD_SIZE
+
partition_idx
*
HEAD_SIZE
;
#pragma unroll
for
(
int
i
=
0
;
i
<
NUM_ROWS_PER_THREAD
;
i
++
)
{
const
int
row_idx
=
lane
/
NUM_V_VECS_PER_ROW
+
i
*
NUM_ROWS_PER_ITER
;
if
(
row_idx
<
HEAD_SIZE
&&
lane
%
NUM_V_VECS_PER_ROW
==
0
)
{
from_float
(
*
(
out_ptr
+
row_idx
),
accs
[
i
]);
}
}
}
}
// Grid: (num_heads, num_seqs, 1).
template
<
typename
scalar_t
,
typename
cache_t
,
int
HEAD_SIZE
,
int
BLOCK_SIZE
,
int
NUM_THREADS
,
vllm
::
Fp8KVCacheDataType
KV_DTYPE
,
bool
IS_BLOCK_SPARSE
>
__global__
void
paged_attention_v1_kernel
(
scalar_t
*
__restrict__
out
,
// [num_seqs, num_heads, head_size]
const
scalar_t
*
__restrict__
q
,
// [num_seqs, num_heads, head_size]
const
cache_t
*
__restrict__
k_cache
,
// [num_blocks, num_kv_heads,
// head_size/x, block_size, x]
const
cache_t
*
__restrict__
v_cache
,
// [num_blocks, num_kv_heads,
// head_size, block_size]
const
int
num_kv_heads
,
// [num_heads]
const
float
scale
,
const
int
*
__restrict__
block_tables
,
// [num_seqs, max_num_blocks_per_seq]
const
int
*
__restrict__
seq_lens
,
// [num_seqs]
const
int
max_num_blocks_per_seq
,
const
float
*
__restrict__
alibi_slopes
,
// [num_heads]
const
int
q_stride
,
const
int
kv_block_stride
,
const
int
kv_head_stride
,
const
float
*
k_scale
,
const
float
*
v_scale
,
const
int
tp_rank
,
const
int
blocksparse_local_blocks
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_block_size
,
const
int
blocksparse_head_sliding_step
)
{
paged_attention_kernel
<
scalar_t
,
cache_t
,
HEAD_SIZE
,
BLOCK_SIZE
,
NUM_THREADS
,
KV_DTYPE
,
IS_BLOCK_SPARSE
>
(
/* exp_sums */
nullptr
,
/* max_logits */
nullptr
,
out
,
q
,
k_cache
,
v_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
max_num_blocks_per_seq
,
alibi_slopes
,
q_stride
,
kv_block_stride
,
kv_head_stride
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_head_sliding_step
);
}
// Grid: (num_heads, num_seqs, max_num_partitions).
template
<
typename
scalar_t
,
typename
cache_t
,
int
HEAD_SIZE
,
int
BLOCK_SIZE
,
int
NUM_THREADS
,
vllm
::
Fp8KVCacheDataType
KV_DTYPE
,
bool
IS_BLOCK_SPARSE
,
int
PARTITION_SIZE
>
__global__
void
paged_attention_v2_kernel
(
float
*
__restrict__
exp_sums
,
// [num_seqs, num_heads, max_num_partitions]
float
*
__restrict__
max_logits
,
// [num_seqs, num_heads,
// max_num_partitions]
scalar_t
*
__restrict__
tmp_out
,
// [num_seqs, num_heads,
// max_num_partitions, head_size]
const
scalar_t
*
__restrict__
q
,
// [num_seqs, num_heads, head_size]
const
cache_t
*
__restrict__
k_cache
,
// [num_blocks, num_kv_heads,
// head_size/x, block_size, x]
const
cache_t
*
__restrict__
v_cache
,
// [num_blocks, num_kv_heads,
// head_size, block_size]
const
int
num_kv_heads
,
// [num_heads]
const
float
scale
,
const
int
*
__restrict__
block_tables
,
// [num_seqs, max_num_blocks_per_seq]
const
int
*
__restrict__
seq_lens
,
// [num_seqs]
const
int
max_num_blocks_per_seq
,
const
float
*
__restrict__
alibi_slopes
,
// [num_heads]
const
int
q_stride
,
const
int
kv_block_stride
,
const
int
kv_head_stride
,
const
float
*
k_scale
,
const
float
*
v_scale
,
const
int
tp_rank
,
const
int
blocksparse_local_blocks
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_block_size
,
const
int
blocksparse_head_sliding_step
)
{
paged_attention_kernel
<
scalar_t
,
cache_t
,
HEAD_SIZE
,
BLOCK_SIZE
,
NUM_THREADS
,
KV_DTYPE
,
IS_BLOCK_SPARSE
,
PARTITION_SIZE
>
(
exp_sums
,
max_logits
,
tmp_out
,
q
,
k_cache
,
v_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
max_num_blocks_per_seq
,
alibi_slopes
,
q_stride
,
kv_block_stride
,
kv_head_stride
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_head_sliding_step
);
}
// Grid: (num_heads, num_seqs).
template
<
typename
scalar_t
,
int
HEAD_SIZE
,
int
NUM_THREADS
,
int
PARTITION_SIZE
>
__global__
void
paged_attention_v2_reduce_kernel
(
scalar_t
*
__restrict__
out
,
// [num_seqs, num_heads, head_size]
const
float
*
__restrict__
exp_sums
,
// [num_seqs, num_heads,
// max_num_partitions]
scalar_t
*
__restrict__
tmp_out
,
// [num_seqs, num_heads,
const
float
*
__restrict__
max_logits
,
// [num_seqs, num_heads,
// max_num_partitions]
const
scalar_t
*
__restrict__
tmp_out
,
// [num_seqs, num_heads,
// max_num_partitions, head_size]
const
scalar_t
*
__restrict__
q
,
// [num_seqs, num_heads, head_size]
const
cache_t
*
__restrict__
k_cache
,
// [num_blocks, num_kv_heads,
// head_size/x, block_size, x]
const
cache_t
*
__restrict__
v_cache
,
// [num_blocks, num_kv_heads,
// head_size, block_size]
const
int
num_kv_heads
,
// [num_heads]
const
float
scale
,
const
int
*
__restrict__
block_tables
,
// [num_seqs, max_num_blocks_per_seq]
const
int
*
__restrict__
seq_lens
,
// [num_seqs]
const
int
max_num_blocks_per_seq
,
const
float
*
__restrict__
alibi_slopes
,
// [num_heads]
const
int
q_stride
,
const
int
kv_block_stride
,
const
int
kv_head_stride
,
const
float
*
k_scale
,
const
float
*
v_scale
,
const
int
tp_rank
,
const
int
blocksparse_local_blocks
,
const
int
blocksparse_vert_stride
,
const
int
blocksparse_block_size
,
const
int
blocksparse_head_sliding_step
)
{
paged_attention_kernel
<
scalar_t
,
cache_t
,
HEAD_SIZE
,
BLOCK_SIZE
,
NUM_THREADS
,
KV_DTYPE
,
IS_BLOCK_SPARSE
,
PARTITION_SIZE
>
(
exp_sums
,
max_logits
,
tmp_out
,
q
,
k_cache
,
v_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
max_num_blocks_per_seq
,
alibi_slopes
,
q_stride
,
kv_block_stride
,
kv_head_stride
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_head_sliding_step
);
}
// Grid: (num_heads, num_seqs).
template
<
typename
scalar_t
,
int
HEAD_SIZE
,
int
NUM_THREADS
,
int
PARTITION_SIZE
>
__global__
void
paged_attention_v2_reduce_kernel
(
scalar_t
*
__restrict__
out
,
// [num_seqs, num_heads, head_size]
const
float
*
__restrict__
exp_sums
,
// [num_seqs, num_heads,
// max_num_partitions]
const
float
*
__restrict__
max_logits
,
// [num_seqs, num_heads,
// max_num_partitions]
const
scalar_t
*
__restrict__
tmp_out
,
// [num_seqs, num_heads,
// max_num_partitions, head_size]
const
int
*
__restrict__
seq_lens
,
// [num_seqs]
const
int
max_num_partitions
)
{
const
int
num_heads
=
gridDim
.
x
;
const
int
head_idx
=
blockIdx
.
x
;
const
int
seq_idx
=
blockIdx
.
y
;
const
int
seq_len
=
seq_lens
[
seq_idx
];
const
int
num_partitions
=
DIVIDE_ROUND_UP
(
seq_len
,
PARTITION_SIZE
);
if
(
num_partitions
==
1
)
{
// No need to reduce. Only copy tmp_out to out.
scalar_t
*
out_ptr
=
out
+
seq_idx
*
num_heads
*
HEAD_SIZE
+
head_idx
*
HEAD_SIZE
;
const
scalar_t
*
tmp_out_ptr
=
tmp_out
+
seq_idx
*
num_heads
*
max_num_partitions
*
HEAD_SIZE
+
head_idx
*
max_num_partitions
*
HEAD_SIZE
;
for
(
int
i
=
threadIdx
.
x
;
i
<
HEAD_SIZE
;
i
+=
blockDim
.
x
)
{
out_ptr
[
i
]
=
tmp_out_ptr
[
i
];
}
// Terminate the thread block.
return
;
}
constexpr
int
NUM_WARPS
=
NUM_THREADS
/
WARP_SIZE
;
const
int
warp_idx
=
threadIdx
.
x
/
WARP_SIZE
;
const
int
lane
=
threadIdx
.
x
%
WARP_SIZE
;
// Size: 2 * num_partitions.
extern
__shared__
char
shared_mem
[];
// Workspace for reduction.
__shared__
float
red_smem
[
2
*
NUM_WARPS
];
// Load max logits to shared memory.
float
*
shared_max_logits
=
reinterpret_cast
<
float
*>
(
shared_mem
);
const
float
*
max_logits_ptr
=
max_logits
+
seq_idx
*
num_heads
*
max_num_partitions
+
head_idx
*
max_num_partitions
;
float
max_logit
=
-
FLT_MAX
;
for
(
int
i
=
threadIdx
.
x
;
i
<
num_partitions
;
i
+=
blockDim
.
x
)
{
const
float
l
=
max_logits_ptr
[
i
];
shared_max_logits
[
i
]
=
l
;
max_logit
=
fmaxf
(
max_logit
,
l
);
}
__syncthreads
();
// Get the global max logit.
// Reduce within the warp.
#pragma unroll
for
(
int
mask
=
WARP_SIZE
/
2
;
mask
>=
1
;
mask
/=
2
)
{
max_logit
=
fmaxf
(
max_logit
,
VLLM_SHFL_XOR_SYNC
(
max_logit
,
mask
));
}
if
(
lane
==
0
)
{
red_smem
[
warp_idx
]
=
max_logit
;
}
__syncthreads
();
// Reduce across warps.
max_logit
=
lane
<
NUM_WARPS
?
red_smem
[
lane
]
:
-
FLT_MAX
;
#pragma unroll
for
(
int
mask
=
NUM_WARPS
/
2
;
mask
>=
1
;
mask
/=
2
)
{
max_logit
=
fmaxf
(
max_logit
,
VLLM_SHFL_XOR_SYNC
(
max_logit
,
mask
));
}
// Broadcast the max value to all threads.
max_logit
=
VLLM_SHFL_SYNC
(
max_logit
,
0
);
// Load rescaled exp sums to shared memory.
float
*
shared_exp_sums
=
reinterpret_cast
<
float
*>
(
shared_mem
+
sizeof
(
float
)
*
num_partitions
);
const
float
*
exp_sums_ptr
=
exp_sums
+
seq_idx
*
num_heads
*
max_num_partitions
+
head_idx
*
max_num_partitions
;
float
global_exp_sum
=
0.0
f
;
for
(
int
i
=
threadIdx
.
x
;
i
<
num_partitions
;
i
+=
blockDim
.
x
)
{
float
l
=
shared_max_logits
[
i
];
float
rescaled_exp_sum
=
exp_sums_ptr
[
i
]
*
expf
(
l
-
max_logit
);
global_exp_sum
+=
rescaled_exp_sum
;
shared_exp_sums
[
i
]
=
rescaled_exp_sum
;
}
__syncthreads
();
global_exp_sum
=
block_sum
<
NUM_WARPS
>
(
&
red_smem
[
NUM_WARPS
],
global_exp_sum
);
const
float
inv_global_exp_sum
=
__fdividef
(
1.0
f
,
global_exp_sum
+
1e-6
f
);
// Aggregate tmp_out to out.
const
scalar_t
*
tmp_out_ptr
=
tmp_out
+
seq_idx
*
num_heads
*
max_num_partitions
*
HEAD_SIZE
+
head_idx
*
max_num_partitions
*
HEAD_SIZE
;
scalar_t
*
out_ptr
=
out
+
seq_idx
*
num_heads
*
HEAD_SIZE
+
head_idx
*
HEAD_SIZE
;
#pragma unroll
for
(
int
i
=
threadIdx
.
x
;
i
<
HEAD_SIZE
;
i
+=
NUM_THREADS
)
{
float
acc
=
0.0
f
;
for
(
int
j
=
0
;
j
<
num_partitions
;
++
j
)
{
acc
+=
to_float
(
tmp_out_ptr
[
j
*
HEAD_SIZE
+
i
])
*
shared_exp_sums
[
j
]
*
inv_global_exp_sum
;
}
from_float
(
out_ptr
[
i
],
acc
);
}
}
}
// namespace vllm
#undef WARP_SIZE
#undef MAX
#undef MIN
#undef DIVIDE_ROUND_UP
\ No newline at end of file
const
int
*
__restrict__
seq_lens
,
// [num_seqs]
const
int
max_num_partitions
)
{
const
int
num_heads
=
gridDim
.
x
;
const
int
head_idx
=
blockIdx
.
x
;
const
int
seq_idx
=
blockIdx
.
y
;
const
int
seq_len
=
seq_lens
[
seq_idx
];
const
int
num_partitions
=
DIVIDE_ROUND_UP
(
seq_len
,
PARTITION_SIZE
);
if
(
num_partitions
==
1
)
{
// No need to reduce. Only copy tmp_out to out.
scalar_t
*
out_ptr
=
out
+
seq_idx
*
num_heads
*
HEAD_SIZE
+
head_idx
*
HEAD_SIZE
;
const
scalar_t
*
tmp_out_ptr
=
tmp_out
+
seq_idx
*
num_heads
*
max_num_partitions
*
HEAD_SIZE
+
head_idx
*
max_num_partitions
*
HEAD_SIZE
;
for
(
int
i
=
threadIdx
.
x
;
i
<
HEAD_SIZE
;
i
+=
blockDim
.
x
)
{
out_ptr
[
i
]
=
tmp_out_ptr
[
i
];
}
// Terminate the thread block.
return
;
}
constexpr
int
NUM_WARPS
=
NUM_THREADS
/
WARP_SIZE
;
const
int
warp_idx
=
threadIdx
.
x
/
WARP_SIZE
;
const
int
lane
=
threadIdx
.
x
%
WARP_SIZE
;
// Size: 2 * num_partitions.
extern
__shared__
char
shared_mem
[];
// Workspace for reduction.
__shared__
float
red_smem
[
2
*
NUM_WARPS
];
// Load max logits to shared memory.
float
*
shared_max_logits
=
reinterpret_cast
<
float
*>
(
shared_mem
);
const
float
*
max_logits_ptr
=
max_logits
+
seq_idx
*
num_heads
*
max_num_partitions
+
head_idx
*
max_num_partitions
;
float
max_logit
=
-
FLT_MAX
;
for
(
int
i
=
threadIdx
.
x
;
i
<
num_partitions
;
i
+=
blockDim
.
x
)
{
const
float
l
=
max_logits_ptr
[
i
];
shared_max_logits
[
i
]
=
l
;
max_logit
=
fmaxf
(
max_logit
,
l
);
}
__syncthreads
();
// Get the global max logit.
// Reduce within the warp.
#pragma unroll
for
(
int
mask
=
WARP_SIZE
/
2
;
mask
>=
1
;
mask
/=
2
)
{
max_logit
=
fmaxf
(
max_logit
,
VLLM_SHFL_XOR_SYNC
(
max_logit
,
mask
));
}
if
(
lane
==
0
)
{
red_smem
[
warp_idx
]
=
max_logit
;
}
__syncthreads
();
// Reduce across warps.
max_logit
=
lane
<
NUM_WARPS
?
red_smem
[
lane
]
:
-
FLT_MAX
;
#pragma unroll
for
(
int
mask
=
NUM_WARPS
/
2
;
mask
>=
1
;
mask
/=
2
)
{
max_logit
=
fmaxf
(
max_logit
,
VLLM_SHFL_XOR_SYNC
(
max_logit
,
mask
));
}
// Broadcast the max value to all threads.
max_logit
=
VLLM_SHFL_SYNC
(
max_logit
,
0
);
// Load rescaled exp sums to shared memory.
float
*
shared_exp_sums
=
reinterpret_cast
<
float
*>
(
shared_mem
+
sizeof
(
float
)
*
num_partitions
);
const
float
*
exp_sums_ptr
=
exp_sums
+
seq_idx
*
num_heads
*
max_num_partitions
+
head_idx
*
max_num_partitions
;
float
global_exp_sum
=
0.0
f
;
for
(
int
i
=
threadIdx
.
x
;
i
<
num_partitions
;
i
+=
blockDim
.
x
)
{
float
l
=
shared_max_logits
[
i
];
float
rescaled_exp_sum
=
exp_sums_ptr
[
i
]
*
expf
(
l
-
max_logit
);
global_exp_sum
+=
rescaled_exp_sum
;
shared_exp_sums
[
i
]
=
rescaled_exp_sum
;
}
__syncthreads
();
global_exp_sum
=
block_sum
<
NUM_WARPS
>
(
&
red_smem
[
NUM_WARPS
],
global_exp_sum
);
const
float
inv_global_exp_sum
=
__fdividef
(
1.0
f
,
global_exp_sum
+
1e-6
f
);
// Aggregate tmp_out to out.
const
scalar_t
*
tmp_out_ptr
=
tmp_out
+
seq_idx
*
num_heads
*
max_num_partitions
*
HEAD_SIZE
+
head_idx
*
max_num_partitions
*
HEAD_SIZE
;
scalar_t
*
out_ptr
=
out
+
seq_idx
*
num_heads
*
HEAD_SIZE
+
head_idx
*
HEAD_SIZE
;
#pragma unroll
for
(
int
i
=
threadIdx
.
x
;
i
<
HEAD_SIZE
;
i
+=
NUM_THREADS
)
{
float
acc
=
0.0
f
;
for
(
int
j
=
0
;
j
<
num_partitions
;
++
j
)
{
acc
+=
to_float
(
tmp_out_ptr
[
j
*
HEAD_SIZE
+
i
])
*
shared_exp_sums
[
j
]
*
inv_global_exp_sum
;
}
from_float
(
out_ptr
[
i
],
acc
);
}
}
}
// namespace vllm
#undef WARP_SIZE
#undef MAX
#undef MIN
#undef DIVIDE_ROUND_UP
\ No newline at end of file
csrc/attention/vertical_slash_index.cu
0 → 100644
View file @
7a985548
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
#include <assert.h>
#include <cuda.h>
#include <torch/all.h>
__device__
int64_t
save_blocks
(
int
*
block_offset
,
int64_t
range_start
,
int64_t
range_end
,
int64_t
block_size
,
int64_t
input_block_count
,
int64_t
kv_seqlen
)
{
if
(
range_start
>=
kv_seqlen
)
{
return
input_block_count
;
}
if
(
range_end
>
kv_seqlen
)
{
range_end
=
kv_seqlen
;
}
int64_t
current_block_count
=
input_block_count
;
for
(
int
idx
=
range_start
;
idx
<
range_end
;
idx
+=
block_size
)
{
block_offset
[
current_block_count
++
]
=
idx
;
}
return
current_block_count
;
}
__global__
void
convert_vertical_slash_indexes_kernel
(
const
int
*
q_seqlens
,
// [BATCH, ]
const
int
*
kv_seqlens
,
// [BATCH, ]
const
int
*
vertical_indexes
,
// [BATCH, N_HEADS, NNZ_V]
const
int
*
slash_indexes
,
// [BATCH, N_HEADS, NNZ_S]
int
*
block_count
,
// [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
int
*
block_offset
,
// [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S]
int
*
column_count
,
// [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
int
*
column_index
,
// [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V]
int64_t
N_HEADS
,
int64_t
N_ROWS
,
int64_t
BLOCK_SIZE_M
,
int64_t
BLOCK_SIZE_N
,
int64_t
NNZ_V
,
int64_t
NNZ_S
,
bool
causal
// True for intra, False for succ
)
{
const
int
batch_idx
=
blockIdx
.
y
;
const
int
head_idx
=
blockIdx
.
x
;
const
int
group_idx
=
blockIdx
.
z
;
int64_t
q_seqlen
=
q_seqlens
[
batch_idx
];
int64_t
kv_seqlen
=
kv_seqlens
[
batch_idx
];
int64_t
block_idx_m
=
group_idx
*
blockDim
.
x
+
threadIdx
.
x
;
int64_t
start_m
=
block_idx_m
*
BLOCK_SIZE_M
;
if
(
start_m
>=
q_seqlen
)
{
return
;
}
int64_t
end_m
=
start_m
+
BLOCK_SIZE_M
;
vertical_indexes
+=
(
batch_idx
*
N_HEADS
+
head_idx
)
*
NNZ_V
;
slash_indexes
+=
(
batch_idx
*
N_HEADS
+
head_idx
)
*
NNZ_S
;
int64_t
row_offset
=
(
batch_idx
*
N_HEADS
+
head_idx
)
*
N_ROWS
+
block_idx_m
;
block_count
+=
row_offset
;
block_offset
+=
row_offset
*
NNZ_S
;
column_count
+=
row_offset
;
column_index
+=
row_offset
*
NNZ_V
;
bool
has_slash
=
true
;
int64_t
tmp_col_cnt
=
0
,
tmp_blk_cnt
=
0
;
int64_t
s
=
0
,
v
=
0
;
int64_t
v_idx
=
vertical_indexes
[
v
++
];
int64_t
s_idx
=
slash_indexes
[
s
++
];
if
(
causal
)
{
while
(
s_idx
>=
end_m
+
(
kv_seqlen
-
q_seqlen
)
&&
s
<
NNZ_S
)
{
s_idx
=
slash_indexes
[
s
++
];
}
if
(
s_idx
>
end_m
+
(
kv_seqlen
-
q_seqlen
))
has_slash
=
false
;
s_idx
=
max
((
kv_seqlen
-
q_seqlen
)
+
end_m
-
s_idx
,
BLOCK_SIZE_M
);
}
else
{
while
(
s_idx
>=
end_m
+
kv_seqlen
&&
s
<
NNZ_S
)
{
s_idx
=
slash_indexes
[
s
++
];
}
if
(
s_idx
>
end_m
+
kv_seqlen
)
has_slash
=
false
;
s_idx
=
max
(
kv_seqlen
+
end_m
-
s_idx
,
BLOCK_SIZE_M
);
}
int64_t
range_start
=
s_idx
-
BLOCK_SIZE_M
,
range_end
=
s_idx
;
if
(
!
has_slash
)
{
if
(
causal
)
{
range_start
=
(
kv_seqlen
-
q_seqlen
)
+
end_m
;
range_end
=
(
kv_seqlen
-
q_seqlen
)
+
end_m
+
BLOCK_SIZE_N
;
}
else
{
range_start
=
kv_seqlen
;
range_end
=
kv_seqlen
+
BLOCK_SIZE_N
;
}
}
bool
slash_finished
=
false
;
while
(
1
)
{
if
(
v_idx
<
range_end
)
{
if
(
v_idx
<
range_start
)
{
column_index
[
tmp_col_cnt
++
]
=
v_idx
;
}
if
(
v
<
NNZ_V
)
{
v_idx
=
vertical_indexes
[
v
++
];
}
else
{
if
(
causal
)
v_idx
=
end_m
+
BLOCK_SIZE_N
+
(
kv_seqlen
-
q_seqlen
);
else
v_idx
=
end_m
+
BLOCK_SIZE_N
+
kv_seqlen
;
}
}
else
{
if
((
s
<
NNZ_S
&&
causal
)
||
(
s
<
NNZ_S
&&
!
causal
&&
slash_indexes
[
s
]
>=
start_m
))
{
if
(
causal
)
s_idx
=
max
((
kv_seqlen
-
q_seqlen
)
+
end_m
-
slash_indexes
[
s
++
],
BLOCK_SIZE_M
);
else
s_idx
=
max
(
kv_seqlen
+
end_m
-
slash_indexes
[
s
++
],
BLOCK_SIZE_M
);
}
else
{
if
(
v
==
NNZ_V
||
(
v_idx
>
range_start
&&
causal
))
{
// add the last vertical if no more slash
if
(
v
==
NNZ_V
&&
!
causal
&&
v_idx
<
kv_seqlen
)
{
column_index
[
tmp_col_cnt
++
]
=
v_idx
;
}
tmp_blk_cnt
=
save_blocks
(
block_offset
,
range_start
,
range_end
,
BLOCK_SIZE_N
,
tmp_blk_cnt
,
kv_seqlen
);
break
;
}
else
{
if
(
causal
)
{
range_start
=
(
kv_seqlen
-
q_seqlen
)
+
end_m
;
range_end
=
(
kv_seqlen
-
q_seqlen
)
+
end_m
+
BLOCK_SIZE_N
;
}
else
{
// if slash_finished but there are vertical left, save current
// blocks
tmp_blk_cnt
=
save_blocks
(
block_offset
,
range_start
,
range_end
,
BLOCK_SIZE_N
,
tmp_blk_cnt
,
kv_seqlen
);
range_start
=
kv_seqlen
;
range_end
=
kv_seqlen
+
BLOCK_SIZE_N
;
}
slash_finished
=
true
;
}
}
if
(
!
slash_finished
)
{
if
(
s_idx
>
range_end
+
BLOCK_SIZE_M
)
{
tmp_blk_cnt
=
save_blocks
(
block_offset
,
range_start
,
range_end
,
BLOCK_SIZE_N
,
tmp_blk_cnt
,
kv_seqlen
);
range_start
=
s_idx
-
BLOCK_SIZE_M
;
range_end
=
s_idx
;
}
else
if
(
s_idx
>
range_end
)
{
range_end
+=
BLOCK_SIZE_M
;
}
}
}
}
block_count
[
0
]
=
tmp_blk_cnt
;
column_count
[
0
]
=
tmp_col_cnt
;
}
void
convert_vertical_slash_indexes_64x64
(
const
int
*
q_seqlens
,
// [BATCH, ]
const
int
*
kv_seqlens
,
// [BATCH, ]
const
int
*
vertical_indexes
,
// [BATCH, N_HEADS, NNZ_V]
const
int
*
slash_indexes
,
// [BATCH, N_HEADS, NNZ_S]
int
*
block_count
,
// [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
int
*
block_offset
,
// [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S]
int
*
column_count
,
// [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
int
*
column_index
,
// [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V]
int64_t
BATCH_SIZE
,
int64_t
N_HEADS
,
int64_t
N_ROWS
,
int64_t
BLOCK_SIZE_M
,
int64_t
BLOCK_SIZE_N
,
int64_t
NNZ_V
,
int64_t
NNZ_S
,
bool
causal
)
{
const
int
N_THREADS
=
64
;
const
dim3
dimBlock
(
N_THREADS
);
const
dim3
dimGrid
(
N_HEADS
,
BATCH_SIZE
,
(
N_ROWS
+
N_THREADS
-
1
)
/
N_THREADS
);
convert_vertical_slash_indexes_kernel
<<<
dimGrid
,
dimBlock
>>>
(
q_seqlens
,
kv_seqlens
,
vertical_indexes
,
slash_indexes
,
block_count
,
block_offset
,
column_count
,
column_index
,
N_HEADS
,
N_ROWS
,
BLOCK_SIZE_M
,
BLOCK_SIZE_N
,
NNZ_V
,
NNZ_S
,
causal
);
}
/**
* Implements the Algorithm 4 in paper https://arxiv.org/abs/2407.02490.
*
* This function builds the index of each row of blocks from vertical indices
* and slash indices. The vertical indices are treated as points, while the
* slash indices are converted as ranges. The output consists of the merged
* ranges and separate column indices, where the ranges are represented by
* block indices.
*
* The implementation is referenced from the original MInference repo:
* https://github.com/microsoft/MInference/blob/main/csrc/vertical_slash_index.cu.
*/
void
convert_vertical_slash_indexes
(
torch
::
Tensor
&
block_count
,
// [BATCH, N_HEADS, NUM_ROWS]
torch
::
Tensor
&
block_offset
,
// [BATCH, N_HEADS, NUM_ROWS, NNZ_S]
torch
::
Tensor
&
column_count
,
// [BATCH, N_HEADS, NUM_ROWS]
torch
::
Tensor
&
column_index
,
// [BATCH, N_HEADS, NUM_ROWS, NNZ_V]
torch
::
Tensor
q_seqlens
,
// [BATCH, ]
torch
::
Tensor
kv_seqlens
,
// [BATCH, ]
torch
::
Tensor
vertical_indexes
,
// [BATCH, N_HEADS, NNZ_V]
torch
::
Tensor
slash_indexes
,
// [BATCH, N_HEADS, NNZ_S]
int64_t
context_size
,
int64_t
block_size_M
,
int64_t
block_size_N
,
bool
causal
)
{
cudaSetDevice
(
q_seqlens
.
get_device
());
int
batch_size
=
slash_indexes
.
size
(
0
);
int
num_heads
=
slash_indexes
.
size
(
1
);
int
nnz_slash
=
slash_indexes
.
size
(
2
);
int
nnz_vertical
=
vertical_indexes
.
size
(
2
);
int
num_rows
=
(
context_size
+
block_size_M
-
1
)
/
block_size_M
;
convert_vertical_slash_indexes_64x64
(
q_seqlens
.
data_ptr
<
int
>
(),
kv_seqlens
.
data_ptr
<
int
>
(),
vertical_indexes
.
data_ptr
<
int
>
(),
slash_indexes
.
data_ptr
<
int
>
(),
block_count
.
data_ptr
<
int
>
(),
block_offset
.
data_ptr
<
int
>
(),
column_count
.
data_ptr
<
int
>
(),
column_index
.
data_ptr
<
int
>
(),
batch_size
,
num_heads
,
num_rows
,
block_size_M
,
block_size_N
,
nnz_vertical
,
nnz_slash
,
causal
);
}
__global__
void
convert_vertical_slash_indexes_kernel_mergehead
(
const
int
*
q_seqlens
,
// [BATCH, ]
const
int
*
kv_seqlens
,
// [BATCH, ]
const
int
*
vertical_indexes
,
// [BATCH, N_HEADS, NNZ_V]
const
int
*
slash_indexes
,
// [BATCH, N_HEADS, NNZ_S]
const
int
*
per_head_vertical_topkv
,
const
int
*
per_head_slash_topkv
,
int
*
block_count
,
// [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
int
*
block_offset
,
// [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S]
int
*
column_count
,
// [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
int
*
column_index
,
// [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V]
int64_t
N_HEADS
,
int64_t
N_ROWS
,
int64_t
BLOCK_SIZE_M
,
int64_t
BLOCK_SIZE_N
,
int64_t
NNZ_V
,
int64_t
NNZ_S
,
bool
causal
// True for intra, False for succ
)
{
const
int
batch_idx
=
blockIdx
.
y
;
const
int
head_idx
=
blockIdx
.
x
;
const
int
group_idx
=
blockIdx
.
z
;
int64_t
q_seqlen
=
q_seqlens
[
batch_idx
];
int64_t
kv_seqlen
=
kv_seqlens
[
batch_idx
];
int64_t
block_idx_m
=
group_idx
*
blockDim
.
x
+
threadIdx
.
x
;
int64_t
start_m
=
block_idx_m
*
BLOCK_SIZE_M
;
if
(
start_m
>=
q_seqlen
)
{
return
;
}
int64_t
end_m
=
start_m
+
BLOCK_SIZE_M
;
vertical_indexes
+=
(
batch_idx
*
N_HEADS
+
head_idx
)
*
NNZ_V
;
slash_indexes
+=
(
batch_idx
*
N_HEADS
+
head_idx
)
*
NNZ_S
;
int64_t
row_offset
=
(
batch_idx
*
N_HEADS
+
head_idx
)
*
N_ROWS
+
block_idx_m
;
block_count
+=
row_offset
;
block_offset
+=
row_offset
*
NNZ_S
;
column_count
+=
row_offset
;
column_index
+=
row_offset
*
NNZ_V
;
// MergeHead: each head has it's unique max topk NNZ_V,NNZ_S. (NNZ_V,NNZ_S
// above is buffer size, use to compute offset)
NNZ_S
=
per_head_slash_topkv
[
head_idx
];
NNZ_V
=
per_head_vertical_topkv
[
head_idx
];
bool
has_slash
=
true
;
int64_t
tmp_col_cnt
=
0
,
tmp_blk_cnt
=
0
;
int64_t
s
=
0
,
v
=
0
;
int64_t
v_idx
=
vertical_indexes
[
v
++
];
int64_t
s_idx
=
slash_indexes
[
s
++
];
if
(
causal
)
{
while
(
s_idx
>=
end_m
+
(
kv_seqlen
-
q_seqlen
)
&&
s
<
NNZ_S
)
{
s_idx
=
slash_indexes
[
s
++
];
}
if
(
s_idx
>
end_m
+
(
kv_seqlen
-
q_seqlen
))
has_slash
=
false
;
s_idx
=
max
((
kv_seqlen
-
q_seqlen
)
+
end_m
-
s_idx
,
BLOCK_SIZE_M
);
}
else
{
while
(
s_idx
>=
end_m
+
kv_seqlen
&&
s
<
NNZ_S
)
{
s_idx
=
slash_indexes
[
s
++
];
}
if
(
s_idx
>
end_m
+
kv_seqlen
)
has_slash
=
false
;
s_idx
=
max
(
kv_seqlen
+
end_m
-
s_idx
,
BLOCK_SIZE_M
);
}
int64_t
range_start
=
s_idx
-
BLOCK_SIZE_M
,
range_end
=
s_idx
;
if
(
!
has_slash
)
{
if
(
causal
)
{
range_start
=
(
kv_seqlen
-
q_seqlen
)
+
end_m
;
range_end
=
(
kv_seqlen
-
q_seqlen
)
+
end_m
+
BLOCK_SIZE_N
;
}
else
{
range_start
=
kv_seqlen
;
range_end
=
kv_seqlen
+
BLOCK_SIZE_N
;
}
}
bool
slash_finished
=
false
;
while
(
1
)
{
if
(
v_idx
<
range_end
)
{
if
(
v_idx
<
range_start
)
{
column_index
[
tmp_col_cnt
++
]
=
v_idx
;
}
if
(
v
<
NNZ_V
)
{
v_idx
=
vertical_indexes
[
v
++
];
}
else
{
if
(
causal
)
v_idx
=
end_m
+
BLOCK_SIZE_N
+
(
kv_seqlen
-
q_seqlen
);
else
v_idx
=
end_m
+
BLOCK_SIZE_N
+
kv_seqlen
;
}
}
else
{
if
((
s
<
NNZ_S
&&
causal
)
||
(
s
<
NNZ_S
&&
!
causal
&&
slash_indexes
[
s
]
>=
start_m
))
{
if
(
causal
)
s_idx
=
max
((
kv_seqlen
-
q_seqlen
)
+
end_m
-
slash_indexes
[
s
++
],
BLOCK_SIZE_M
);
else
s_idx
=
max
(
kv_seqlen
+
end_m
-
slash_indexes
[
s
++
],
BLOCK_SIZE_M
);
}
else
{
if
(
v
==
NNZ_V
||
(
v_idx
>
range_start
&&
causal
))
{
// add the last vertical if no more slash
if
(
v
==
NNZ_V
&&
!
causal
&&
v_idx
<
kv_seqlen
)
{
column_index
[
tmp_col_cnt
++
]
=
v_idx
;
}
tmp_blk_cnt
=
save_blocks
(
block_offset
,
range_start
,
range_end
,
BLOCK_SIZE_N
,
tmp_blk_cnt
,
kv_seqlen
);
break
;
}
else
{
if
(
causal
)
{
range_start
=
(
kv_seqlen
-
q_seqlen
)
+
end_m
;
range_end
=
(
kv_seqlen
-
q_seqlen
)
+
end_m
+
BLOCK_SIZE_N
;
}
else
{
// if slash_finished but there are vertical left, save current
// blocks
tmp_blk_cnt
=
save_blocks
(
block_offset
,
range_start
,
range_end
,
BLOCK_SIZE_N
,
tmp_blk_cnt
,
kv_seqlen
);
range_start
=
kv_seqlen
;
range_end
=
kv_seqlen
+
BLOCK_SIZE_N
;
}
slash_finished
=
true
;
}
}
if
(
!
slash_finished
)
{
if
(
s_idx
>
range_end
+
BLOCK_SIZE_M
)
{
tmp_blk_cnt
=
save_blocks
(
block_offset
,
range_start
,
range_end
,
BLOCK_SIZE_N
,
tmp_blk_cnt
,
kv_seqlen
);
range_start
=
s_idx
-
BLOCK_SIZE_M
;
range_end
=
s_idx
;
}
else
if
(
s_idx
>
range_end
)
{
range_end
+=
BLOCK_SIZE_M
;
}
}
}
}
block_count
[
0
]
=
tmp_blk_cnt
;
column_count
[
0
]
=
tmp_col_cnt
;
}
void
convert_vertical_slash_indexes_64x64_mergehead
(
const
int
*
q_seqlens
,
// [BATCH, ]
const
int
*
kv_seqlens
,
// [BATCH, ]
const
int
*
vertical_indexes
,
// [BATCH, N_HEADS, NNZ_V]
const
int
*
slash_indexes
,
// [BATCH, N_HEADS, NNZ_S]
int
*
per_head_vertical_topkv
,
int
*
per_head_slash_topkv
,
int
*
block_count
,
// [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
int
*
block_offset
,
// [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S]
int
*
column_count
,
// [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
int
*
column_index
,
// [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V]
int64_t
BATCH_SIZE
,
int64_t
N_HEADS
,
int64_t
N_ROWS
,
int64_t
BLOCK_SIZE_M
,
int64_t
BLOCK_SIZE_N
,
int64_t
NNZ_V
,
int64_t
NNZ_S
,
bool
causal
)
{
const
int
N_THREADS
=
64
;
const
dim3
dimBlock
(
N_THREADS
);
const
dim3
dimGrid
(
N_HEADS
,
BATCH_SIZE
,
(
N_ROWS
+
N_THREADS
-
1
)
/
N_THREADS
);
convert_vertical_slash_indexes_kernel_mergehead
<<<
dimGrid
,
dimBlock
>>>
(
q_seqlens
,
kv_seqlens
,
vertical_indexes
,
slash_indexes
,
per_head_vertical_topkv
,
per_head_slash_topkv
,
block_count
,
block_offset
,
column_count
,
column_index
,
N_HEADS
,
N_ROWS
,
BLOCK_SIZE_M
,
BLOCK_SIZE_N
,
NNZ_V
,
NNZ_S
,
causal
);
}
/**
* Implements the Algorithm 4 in paper https://arxiv.org/abs/2407.02490.
*
* Like the above convert_vertical_slash_indexes, but with
* pre-computed vertical and slash counts.
*/
void
convert_vertical_slash_indexes_mergehead
(
torch
::
Tensor
&
block_count
,
// [BATCH, N_HEADS, NUM_ROWS]
torch
::
Tensor
&
block_offset
,
// [BATCH, N_HEADS, NUM_ROWS, NNZ_S]
torch
::
Tensor
&
column_count
,
// [BATCH, N_HEADS, NUM_ROWS]
torch
::
Tensor
&
column_index
,
// [BATCH, N_HEADS, NUM_ROWS, NNZ_V]
torch
::
Tensor
q_seqlens
,
// [BATCH, ]
torch
::
Tensor
kv_seqlens
,
// [BATCH, ]
torch
::
Tensor
vertical_indexes
,
// [BATCH, N_HEADS, NNZ_V]
torch
::
Tensor
slash_indexes
,
// [BATCH, N_HEADS, NNZ_S]
torch
::
Tensor
vertical_indices_count
,
// [N_HEADS, ]
torch
::
Tensor
slash_indices_count
,
// [N_HEADS, ]
int64_t
context_size
,
int64_t
block_size_M
,
int64_t
block_size_N
,
bool
causal
)
{
cudaSetDevice
(
q_seqlens
.
get_device
());
int
batch_size
=
slash_indexes
.
size
(
0
);
int
num_heads
=
slash_indexes
.
size
(
1
);
int
nnz_slash
=
slash_indexes
.
size
(
2
);
int
nnz_vertical
=
vertical_indexes
.
size
(
2
);
int
num_rows
=
(
context_size
+
block_size_M
-
1
)
/
block_size_M
;
convert_vertical_slash_indexes_64x64_mergehead
(
q_seqlens
.
data_ptr
<
int
>
(),
kv_seqlens
.
data_ptr
<
int
>
(),
vertical_indexes
.
data_ptr
<
int
>
(),
slash_indexes
.
data_ptr
<
int
>
(),
vertical_indices_count
.
data_ptr
<
int
>
(),
slash_indices_count
.
data_ptr
<
int
>
(),
block_count
.
data_ptr
<
int
>
(),
block_offset
.
data_ptr
<
int
>
(),
column_count
.
data_ptr
<
int
>
(),
column_index
.
data_ptr
<
int
>
(),
batch_size
,
num_heads
,
num_rows
,
block_size_M
,
block_size_N
,
nnz_vertical
,
nnz_slash
,
causal
);
}
csrc/core/math.hpp
View file @
7a985548
...
...
@@ -7,3 +7,22 @@ inline constexpr uint32_t next_pow_2(uint32_t const num) {
if
(
num
<=
1
)
return
num
;
return
1
<<
(
CHAR_BIT
*
sizeof
(
num
)
-
__builtin_clz
(
num
-
1
));
}
template
<
typename
A
,
typename
B
>
static
inline
constexpr
auto
div_ceil
(
A
a
,
B
b
)
{
return
(
a
+
b
-
1
)
/
b
;
}
// Round a down to the next multiple of b. The caller is responsible for making
// sure that b is non-zero
template
<
typename
T
>
inline
constexpr
T
round_to_previous_multiple_of
(
T
a
,
T
b
)
{
return
a
%
b
==
0
?
a
:
(
a
/
b
)
*
b
;
}
// Round a up to the next multiple of b. The caller is responsible for making
// sure that b is non-zero
template
<
typename
T
>
inline
constexpr
T
round_to_next_multiple_of
(
T
a
,
T
b
)
{
return
a
%
b
==
0
?
a
:
((
a
/
b
)
+
1
)
*
b
;
}
csrc/core/scalar_type.hpp
View file @
7a985548
...
...
@@ -315,6 +315,8 @@ static inline constexpr auto kS8 = ScalarType::int_(8);
static
inline
constexpr
auto
kU8
=
ScalarType
::
uint
(
8
);
static
inline
constexpr
auto
kU8B128
=
ScalarType
::
uint
(
8
,
128
);
static
inline
constexpr
auto
kFE2M1f
=
ScalarType
::
float_
(
2
,
1
,
true
,
ScalarType
::
NAN_NONE
);
static
inline
constexpr
auto
kFE3M2f
=
ScalarType
::
float_
(
3
,
2
,
true
,
ScalarType
::
NAN_NONE
);
static
inline
constexpr
auto
kFE4M3fn
=
...
...
@@ -332,6 +334,7 @@ static inline constexpr auto kInt8 = kS8;
static
inline
constexpr
auto
kUint8
=
kU8
;
static
inline
constexpr
auto
kUint8b128
=
kU8B128
;
static
inline
constexpr
auto
kFloat4_e2m1f
=
kFE2M1f
;
static
inline
constexpr
auto
kFloat6_e3m2f
=
kFE3M2f
;
static
inline
constexpr
auto
kFloat8_e4m3fn
=
kFE4M3fn
;
static
inline
constexpr
auto
kFloat8_e5m2
=
kFE5M2
;
...
...
csrc/cpu/cpu_types_vsx.hpp
View file @
7a985548
...
...
@@ -4,6 +4,7 @@
#include <altivec.h>
#include <cmath>
#include <algorithm>
#include <torch/all.h>
namespace
vec_op
{
...
...
@@ -62,6 +63,10 @@ typedef struct f32x4x4_t {
__vector
float
val
[
4
];
}
f32x4x4_t
;
typedef
struct
i32x4x4_t
{
__vector
int32_t
val
[
4
];
}
i32x4x4_t
;
struct
FP32Vec8
;
struct
FP32Vec16
;
...
...
@@ -98,6 +103,28 @@ struct BF16Vec16 : public Vec<BF16Vec16> {
vec_xst
(
reg
.
val
[
0
],
0
,
(
signed
short
*
)
ptr
);
vec_xst
(
reg
.
val
[
1
],
16
,
(
signed
short
*
)
ptr
);
}
void
save
(
void
*
ptr
,
const
int
elem_num
)
const
{
const
int
clamped_elem
=
std
::
max
(
0
,
std
::
min
(
elem_num
,
16
));
// Calculate elements to store in each 128-bit part (8 elements each)
const
int
elements_val0
=
std
::
min
(
clamped_elem
,
8
);
const
int
elements_val1
=
std
::
max
(
clamped_elem
-
8
,
0
);
// Convert elements to bytes (2 bytes per element)
const
size_t
bytes_val0
=
elements_val0
*
sizeof
(
signed
short
);
const
size_t
bytes_val1
=
elements_val1
*
sizeof
(
signed
short
);
signed
short
*
dest
=
static_cast
<
signed
short
*>
(
ptr
);
// Store the first part using vec_xst_len
if
(
bytes_val0
>
0
)
{
vec_xst_len
(
reg
.
val
[
0
],
dest
,
bytes_val0
);
}
// Store the second part if needed
if
(
bytes_val1
>
0
)
{
vec_xst_len
(
reg
.
val
[
1
],
dest
+
elements_val0
,
bytes_val1
);
}
}
};
const
static
__vector
signed
short
zero
=
vec_splats
((
signed
short
)
0
);
...
...
@@ -257,6 +284,64 @@ struct FP32Vec8 : public Vec<FP32Vec8> {
}
};
struct
INT32Vec16
:
public
Vec
<
INT32Vec16
>
{
constexpr
static
int
VEC_ELEM_NUM
=
16
;
union
AliasReg
{
i32x4x4_t
reg
;
int32_t
values
[
VEC_ELEM_NUM
];
};
i32x4x4_t
reg
;
explicit
INT32Vec16
(
const
void
*
data_ptr
)
{
reg
.
val
[
0
]
=
vec_xl
(
0
,
reinterpret_cast
<
const
__vector
int32_t
*>
(
data_ptr
));
reg
.
val
[
1
]
=
vec_xl
(
16
,
reinterpret_cast
<
const
__vector
int32_t
*>
(
data_ptr
));
reg
.
val
[
2
]
=
vec_xl
(
32
,
reinterpret_cast
<
const
__vector
int32_t
*>
(
data_ptr
));
reg
.
val
[
3
]
=
vec_xl
(
48
,
reinterpret_cast
<
const
__vector
int32_t
*>
(
data_ptr
));
}
void
save
(
int32_t
*
ptr
)
const
{
vec_xst
(
reg
.
val
[
0
],
0
,
reinterpret_cast
<
__vector
int32_t
*>
(
ptr
));
vec_xst
(
reg
.
val
[
1
],
16
,
reinterpret_cast
<
__vector
int32_t
*>
(
ptr
));
vec_xst
(
reg
.
val
[
2
],
32
,
reinterpret_cast
<
__vector
int32_t
*>
(
ptr
));
vec_xst
(
reg
.
val
[
3
],
48
,
reinterpret_cast
<
__vector
int32_t
*>
(
ptr
));
}
void
save
(
int32_t
*
ptr
,
const
int
elem_num
)
const
{
const
int
elements_in_chunk1
=
(
elem_num
>=
0
)
?
((
elem_num
>=
4
)
?
4
:
elem_num
)
:
0
;
const
int
elements_in_chunk2
=
(
elem_num
>
4
)
?
((
elem_num
>=
8
)
?
4
:
elem_num
-
4
)
:
0
;
const
int
elements_in_chunk3
=
(
elem_num
>
8
)
?
((
elem_num
>=
12
)
?
4
:
elem_num
-
8
)
:
0
;
const
int
elements_in_chunk4
=
(
elem_num
>
12
)
?
((
elem_num
>=
16
)
?
4
:
elem_num
-
12
)
:
0
;
const
size_t
bytes_chunk1
=
static_cast
<
size_t
>
(
elements_in_chunk1
*
sizeof
(
int32_t
));
const
size_t
bytes_chunk2
=
static_cast
<
size_t
>
(
elements_in_chunk2
*
sizeof
(
int32_t
));
const
size_t
bytes_chunk3
=
static_cast
<
size_t
>
(
elements_in_chunk3
*
sizeof
(
int32_t
));
const
size_t
bytes_chunk4
=
static_cast
<
size_t
>
(
elements_in_chunk4
*
sizeof
(
int32_t
));
vec_xst_len
(
reg
.
val
[
0
],
reinterpret_cast
<
int32_t
*>
(
ptr
),
bytes_chunk1
);
vec_xst_len
(
reg
.
val
[
1
],
reinterpret_cast
<
int32_t
*>
(
reinterpret_cast
<
char
*>
(
ptr
)
+
16
),
bytes_chunk2
);
vec_xst_len
(
reg
.
val
[
2
],
reinterpret_cast
<
int32_t
*>
(
reinterpret_cast
<
char
*>
(
ptr
)
+
32
),
bytes_chunk3
);
vec_xst_len
(
reg
.
val
[
3
],
reinterpret_cast
<
int32_t
*>
(
reinterpret_cast
<
char
*>
(
ptr
)
+
48
),
bytes_chunk4
);
}
};
struct
FP32Vec16
:
public
Vec
<
FP32Vec16
>
{
constexpr
static
int
VEC_ELEM_NUM
=
16
;
union
AliasReg
{
...
...
@@ -319,6 +404,13 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
explicit
FP32Vec16
(
const
BF16Vec8
&
v
)
:
FP32Vec16
(
FP32Vec8
(
v
))
{}
explicit
FP32Vec16
(
const
INT32Vec16
&
v
)
{
reg
.
val
[
0
]
=
vec_ctf
(
v
.
reg
.
val
[
0
],
0
);
reg
.
val
[
1
]
=
vec_ctf
(
v
.
reg
.
val
[
1
],
0
);
reg
.
val
[
2
]
=
vec_ctf
(
v
.
reg
.
val
[
2
],
0
);
reg
.
val
[
3
]
=
vec_ctf
(
v
.
reg
.
val
[
3
],
0
);
}
FP32Vec16
operator
*
(
const
FP32Vec16
&
b
)
const
{
return
FP32Vec16
(
f32x4x4_t
({
vec_mul
(
reg
.
val
[
0
],
b
.
reg
.
val
[
0
]),
vec_mul
(
reg
.
val
[
1
],
b
.
reg
.
val
[
1
]),
...
...
@@ -347,6 +439,117 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
vec_div
(
reg
.
val
[
3
],
b
.
reg
.
val
[
3
])}));
}
FP32Vec16
clamp
(
const
FP32Vec16
&
min
,
const
FP32Vec16
&
max
)
const
{
return
FP32Vec16
(
f32x4x4_t
(
{
vec_min
(
max
.
reg
.
val
[
0
],
vec_max
(
min
.
reg
.
val
[
0
],
reg
.
val
[
0
])),
vec_min
(
max
.
reg
.
val
[
1
],
vec_max
(
min
.
reg
.
val
[
1
],
reg
.
val
[
1
])),
vec_min
(
max
.
reg
.
val
[
2
],
vec_max
(
min
.
reg
.
val
[
2
],
reg
.
val
[
2
])),
vec_min
(
max
.
reg
.
val
[
3
],
vec_max
(
min
.
reg
.
val
[
3
],
reg
.
val
[
3
]))}));
}
FP32Vec16
max
(
const
FP32Vec16
&
b
)
const
{
return
FP32Vec16
(
f32x4x4_t
({
vec_max
(
reg
.
val
[
0
],
b
.
reg
.
val
[
0
]),
vec_max
(
reg
.
val
[
1
],
b
.
reg
.
val
[
1
]),
vec_max
(
reg
.
val
[
2
],
b
.
reg
.
val
[
2
]),
vec_max
(
reg
.
val
[
3
],
b
.
reg
.
val
[
3
])}));
}
FP32Vec16
max
(
const
FP32Vec16
&
b
,
int
elem_num
)
const
{
FP32Vec16
result
;
// Create a vector of element indices for each chunk
__vector
unsigned
int
indices
=
{
0
,
1
,
2
,
3
};
__vector
unsigned
int
elem_num_vec
=
vec_splats
(
static_cast
<
unsigned
int
>
(
elem_num
));
// Compute masks for each chunk
__vector
unsigned
int
chunk_offset0
=
{
0
,
0
,
0
,
0
};
// Chunk 0: Elements 0-3
__vector
unsigned
int
chunk_offset1
=
{
4
,
4
,
4
,
4
};
// Chunk 1: Elements 4-7
__vector
unsigned
int
chunk_offset2
=
{
8
,
8
,
8
,
8
};
// Chunk 2: Elements 8-11
__vector
unsigned
int
chunk_offset3
=
{
12
,
12
,
12
,
12
};
// Chunk 3: Elements 12-15
// Compute masks for each chunk
__vector
bool
int
mask0
=
vec_cmplt
(
indices
+
chunk_offset0
,
elem_num_vec
);
__vector
bool
int
mask1
=
vec_cmplt
(
indices
+
chunk_offset1
,
elem_num_vec
);
__vector
bool
int
mask2
=
vec_cmplt
(
indices
+
chunk_offset2
,
elem_num_vec
);
__vector
bool
int
mask3
=
vec_cmplt
(
indices
+
chunk_offset3
,
elem_num_vec
);
// Apply masks to compute the result for each chunk
result
.
reg
.
val
[
0
]
=
vec_sel
(
this
->
reg
.
val
[
0
],
vec_max
(
this
->
reg
.
val
[
0
],
b
.
reg
.
val
[
0
]),
mask0
);
result
.
reg
.
val
[
1
]
=
vec_sel
(
this
->
reg
.
val
[
1
],
vec_max
(
this
->
reg
.
val
[
1
],
b
.
reg
.
val
[
1
]),
mask1
);
result
.
reg
.
val
[
2
]
=
vec_sel
(
this
->
reg
.
val
[
2
],
vec_max
(
this
->
reg
.
val
[
2
],
b
.
reg
.
val
[
2
]),
mask2
);
result
.
reg
.
val
[
3
]
=
vec_sel
(
this
->
reg
.
val
[
3
],
vec_max
(
this
->
reg
.
val
[
3
],
b
.
reg
.
val
[
3
]),
mask3
);
return
FP32Vec16
(
result
.
reg
);
}
FP32Vec16
min
(
const
FP32Vec16
&
b
)
const
{
return
FP32Vec16
(
f32x4x4_t
({
vec_min
(
reg
.
val
[
0
],
b
.
reg
.
val
[
0
]),
vec_min
(
reg
.
val
[
1
],
b
.
reg
.
val
[
1
]),
vec_min
(
reg
.
val
[
2
],
b
.
reg
.
val
[
2
]),
vec_min
(
reg
.
val
[
3
],
b
.
reg
.
val
[
3
])}));
}
FP32Vec16
min
(
const
FP32Vec16
&
b
,
int
elem_num
)
const
{
FP32Vec16
result
;
vector
unsigned
int
indices
=
{
0
,
1
,
2
,
3
};
vector
unsigned
int
elem_num_vec
=
vec_splats
(
static_cast
<
unsigned
int
>
(
elem_num
));
vector
unsigned
int
chunk_offset0
=
{
0
,
0
,
0
,
0
};
vector
unsigned
int
chunk_offset1
=
{
4
,
4
,
4
,
4
};
vector
unsigned
int
chunk_offset2
=
{
8
,
8
,
8
,
8
};
vector
unsigned
int
chunk_offset3
=
{
12
,
12
,
12
,
12
};
vector
bool
int
mask0
=
vec_cmplt
(
indices
+
chunk_offset0
,
elem_num_vec
);
vector
bool
int
mask1
=
vec_cmplt
(
indices
+
chunk_offset1
,
elem_num_vec
);
vector
bool
int
mask2
=
vec_cmplt
(
indices
+
chunk_offset2
,
elem_num_vec
);
vector
bool
int
mask3
=
vec_cmplt
(
indices
+
chunk_offset3
,
elem_num_vec
);
result
.
reg
.
val
[
0
]
=
vec_sel
(
this
->
reg
.
val
[
0
],
vec_min
(
this
->
reg
.
val
[
0
],
b
.
reg
.
val
[
0
]),
mask0
);
result
.
reg
.
val
[
1
]
=
vec_sel
(
this
->
reg
.
val
[
1
],
vec_min
(
this
->
reg
.
val
[
1
],
b
.
reg
.
val
[
1
]),
mask1
);
result
.
reg
.
val
[
2
]
=
vec_sel
(
this
->
reg
.
val
[
2
],
vec_min
(
this
->
reg
.
val
[
2
],
b
.
reg
.
val
[
2
]),
mask2
);
result
.
reg
.
val
[
3
]
=
vec_sel
(
this
->
reg
.
val
[
3
],
vec_min
(
this
->
reg
.
val
[
3
],
b
.
reg
.
val
[
3
]),
mask3
);
return
FP32Vec16
(
result
.
reg
);
}
FP32Vec16
abs
()
const
{
return
FP32Vec16
(
f32x4x4_t
({
vec_abs
(
reg
.
val
[
0
]),
vec_abs
(
reg
.
val
[
1
]),
vec_abs
(
reg
.
val
[
2
]),
vec_abs
(
reg
.
val
[
3
])}));
}
float
reduce_max
()
{
__vector
float
max01
=
vec_max
(
reg
.
val
[
0
],
reg
.
val
[
1
]);
__vector
float
max23
=
vec_max
(
reg
.
val
[
2
],
reg
.
val
[
3
]);
__vector
float
max_all
=
vec_max
(
max01
,
max23
);
__vector
float
temp
=
vec_max
(
max_all
,
vec_sld
(
max_all
,
max_all
,
8
));
temp
=
vec_max
(
temp
,
vec_sld
(
temp
,
temp
,
4
));
return
vec_extract
(
temp
,
0
);
}
float
reduce_min
()
{
__vector
float
min01
=
vec_min
(
reg
.
val
[
0
],
reg
.
val
[
1
]);
__vector
float
min23
=
vec_min
(
reg
.
val
[
2
],
reg
.
val
[
3
]);
__vector
float
min_all
=
vec_min
(
min01
,
min23
);
__vector
float
temp
=
vec_min
(
min_all
,
vec_sld
(
min_all
,
min_all
,
8
));
temp
=
vec_min
(
temp
,
vec_sld
(
temp
,
temp
,
4
));
return
vec_extract
(
temp
,
0
);
}
float
reduce_sum
()
const
{
AliasReg
ar
;
ar
.
reg
=
reg
;
...
...
@@ -377,6 +580,68 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
vec_xst
(
reg
.
val
[
2
],
32
,
ptr
);
vec_xst
(
reg
.
val
[
3
],
48
,
ptr
);
}
void
save
(
float
*
ptr
,
const
int
elem_num
)
const
{
const
int
elements_in_chunk1
=
(
elem_num
>=
0
)
?
((
elem_num
>=
4
)
?
4
:
elem_num
)
:
0
;
const
int
elements_in_chunk2
=
(
elem_num
>
4
)
?
((
elem_num
>=
8
)
?
4
:
elem_num
-
4
)
:
0
;
const
int
elements_in_chunk3
=
(
elem_num
>
8
)
?
((
elem_num
>=
12
)
?
4
:
elem_num
-
8
)
:
0
;
const
int
elements_in_chunk4
=
(
elem_num
>
12
)
?
((
elem_num
>=
16
)
?
4
:
elem_num
-
12
)
:
0
;
const
size_t
bytes_chunk1
=
static_cast
<
size_t
>
(
elements_in_chunk1
*
sizeof
(
float
));
const
size_t
bytes_chunk2
=
static_cast
<
size_t
>
(
elements_in_chunk2
*
sizeof
(
float
));
const
size_t
bytes_chunk3
=
static_cast
<
size_t
>
(
elements_in_chunk3
*
sizeof
(
float
));
const
size_t
bytes_chunk4
=
static_cast
<
size_t
>
(
elements_in_chunk4
*
sizeof
(
float
));
vec_xst_len
(
reg
.
val
[
0
],
ptr
,
bytes_chunk1
);
vec_xst_len
(
reg
.
val
[
1
],
reinterpret_cast
<
float
*>
(
reinterpret_cast
<
char
*>
(
ptr
)
+
16
),
bytes_chunk2
);
vec_xst_len
(
reg
.
val
[
2
],
reinterpret_cast
<
float
*>
(
reinterpret_cast
<
char
*>
(
ptr
)
+
32
),
bytes_chunk3
);
vec_xst_len
(
reg
.
val
[
3
],
reinterpret_cast
<
float
*>
(
reinterpret_cast
<
char
*>
(
ptr
)
+
48
),
bytes_chunk4
);
}
};
struct
INT8Vec16
:
public
Vec
<
INT8Vec16
>
{
constexpr
static
int
VEC_NUM_ELEM
=
16
;
// 128 bits / 8 bits = 16
union
AliasReg
{
__vector
signed
char
reg
;
int8_t
values
[
VEC_NUM_ELEM
];
};
__vector
signed
char
reg
;
explicit
INT8Vec16
(
const
FP32Vec16
&
vec
)
{
__vector
signed
int
ret
[
4
];
ret
[
0
]
=
vec_cts
(
vec
.
reg
.
val
[
0
],
0
);
ret
[
1
]
=
vec_cts
(
vec
.
reg
.
val
[
1
],
0
);
ret
[
2
]
=
vec_cts
(
vec
.
reg
.
val
[
2
],
0
);
ret
[
3
]
=
vec_cts
(
vec
.
reg
.
val
[
3
],
0
);
__vector
signed
short
packed1
=
vec_packs
(
ret
[
0
],
ret
[
1
]);
__vector
signed
short
packed2
=
vec_packs
(
ret
[
2
],
ret
[
3
]);
reg
=
vec_packs
(
packed1
,
packed2
);
}
void
save
(
void
*
ptr
)
const
{
*
reinterpret_cast
<
__vector
signed
char
*>
(
ptr
)
=
reg
;
}
void
save
(
signed
char
*
ptr
,
const
int
elem_num
)
{
vec_xst_len
(
reg
,
ptr
,
static_cast
<
size_t
>
(
elem_num
));
}
};
template
<
typename
T
>
...
...
csrc/cpu/pos_encoding.cpp
View file @
7a985548
...
...
@@ -9,7 +9,8 @@ void rotary_embedding_impl(
scalar_t
*
__restrict__
query
,
/// [batch_size, seq_len, num_heads,
/// head_size] or [num_tokens, num_heads,
/// head_size]
scalar_t
*
__restrict__
key
,
// [batch_size, seq_len, num_kv_heads,
scalar_t
*
__restrict__
key
,
// nullptr (optional) or
// [batch_size, seq_len, num_kv_heads,
// head_size] or [num_tokens, num_kv_heads,
// head_size]
const
scalar_t
*
__restrict__
cos_sin_cache
,
// [max_position, 2, rot_dim //
...
...
@@ -85,10 +86,13 @@ void rotary_embedding_impl(
compute_loop
(
token_head
,
cache_ptr
,
query
);
}
for
(
int
i
=
0
;
i
<
num_kv_heads
;
++
i
)
{
const
int
head_idx
=
i
;
const
int64_t
token_head
=
token_idx
*
key_stride
+
head_idx
*
head_size
;
compute_loop
(
token_head
,
cache_ptr
,
key
);
if
(
key
!=
nullptr
)
{
for
(
int
i
=
0
;
i
<
num_kv_heads
;
++
i
)
{
const
int
head_idx
=
i
;
const
int64_t
token_head
=
token_idx
*
key_stride
+
head_idx
*
head_size
;
compute_loop
(
token_head
,
cache_ptr
,
key
);
}
}
}
}
...
...
@@ -100,7 +104,8 @@ void rotary_embedding_gptj_impl(
scalar_t
*
__restrict__
query
,
/// [batch_size, seq_len, num_heads,
/// head_size] or [num_tokens, num_heads,
/// head_size]
scalar_t
*
__restrict__
key
,
// [batch_size, seq_len, num_kv_heads,
scalar_t
*
__restrict__
key
,
// nullptr (optional) or
// [batch_size, seq_len, num_kv_heads,
// head_size] or [num_tokens, num_kv_heads,
// head_size]
const
scalar_t
*
__restrict__
cos_sin_cache
,
// [max_position, 2, rot_dim //
...
...
@@ -138,6 +143,10 @@ void rotary_embedding_gptj_impl(
}
}
if
(
key
==
nullptr
)
{
return
;
}
#pragma omp parallel for collapse(2)
for
(
int
token_idx
=
0
;
token_idx
<
num_tokens
;
++
token_idx
)
{
for
(
int
i
=
0
;
i
<
num_kv_heads
;
++
i
)
{
...
...
@@ -168,13 +177,13 @@ void rotary_embedding_gptj_impl(
};
// namespace
void
rotary_embedding
(
torch
::
Tensor
&
positions
,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
key
,
int64_t
head_size
,
std
::
optional
<
torch
::
Tensor
>
key
,
int64_t
head_size
,
torch
::
Tensor
&
cos_sin_cache
,
bool
is_neox
)
{
int
num_tokens
=
positions
.
numel
();
int
rot_dim
=
cos_sin_cache
.
size
(
1
);
int
num_heads
=
query
.
size
(
-
1
)
/
head_size
;
int
num_kv_heads
=
key
.
size
(
-
1
)
/
head_size
;
int64_t
key_stride
=
key
.
stride
(
-
2
);
int
num_kv_heads
=
key
.
has_value
()
?
key
->
size
(
-
1
)
/
head_size
:
num_heads
;
int64_t
key_stride
=
key
.
has_value
()
?
key
->
stride
(
-
2
)
:
0
;
int64_t
query_stride
=
query
.
stride
(
-
2
);
VLLM_DISPATCH_FLOATING_TYPES
(
...
...
@@ -183,15 +192,15 @@ void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
if
(
is_neox
)
{
rotary_embedding_impl
(
positions
.
data_ptr
<
int64_t
>
(),
query
.
data_ptr
<
scalar_t
>
(),
key
.
data_ptr
<
scalar_t
>
(),
cos_sin_cache
.
data_ptr
<
scalar_t
>
(),
rot_dim
,
query_stride
,
key_stride
,
num_heads
,
num_kv_heads
,
head_size
,
num_tokens
);
key
.
has_value
()
?
key
->
data_ptr
<
scalar_t
>
()
:
nullptr
,
cos_sin_cache
.
data_ptr
<
scalar_t
>
(),
rot_dim
,
query_stride
,
key_stride
,
num_heads
,
num_kv_heads
,
head_size
,
num_tokens
);
}
else
{
rotary_embedding_gptj_impl
(
positions
.
data_ptr
<
int64_t
>
(),
query
.
data_ptr
<
scalar_t
>
(),
key
.
data_ptr
<
scalar_t
>
(),
cos_sin_cache
.
data_ptr
<
scalar_t
>
(),
rot_dim
,
query_stride
,
key_stride
,
num_heads
,
num_kv_heads
,
head_size
,
num_tokens
);
key
.
has_value
()
?
key
->
data_ptr
<
scalar_t
>
()
:
nullptr
,
cos_sin_cache
.
data_ptr
<
scalar_t
>
(),
rot_dim
,
query_stride
,
key_stride
,
num_heads
,
num_kv_heads
,
head_size
,
num_tokens
);
}
CPU_KERNEL_GUARD_OUT
(
rotary_embedding_impl
)
...
...
Prev
1
2
3
4
5
6
7
8
…
25
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment