Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
43c413ec
Unverified
Commit
43c413ec
authored
May 03, 2024
by
Lily Liu
Committed by
GitHub
May 03, 2024
Browse files
[Kernel] Use flashinfer for decoding (#4353)
Co-authored-by:
LiuXiaoxuanPKU
<
llilyliupku@gmail.com
>
parent
f8e7adda
Changes
15
Hide whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
600 additions
and
53 deletions
+600
-53
csrc/cache.h
csrc/cache.h
+8
-0
csrc/cache_kernels.cu
csrc/cache_kernels.cu
+80
-0
csrc/pybind.cpp
csrc/pybind.cpp
+4
-0
tests/basic_correctness/test_basic_correctness.py
tests/basic_correctness/test_basic_correctness.py
+11
-1
tests/distributed/test_basic_distributed_correctness.py
tests/distributed/test_basic_distributed_correctness.py
+9
-5
tests/kernels/conftest.py
tests/kernels/conftest.py
+7
-1
tests/kernels/test_cache.py
tests/kernels/test_cache.py
+77
-0
vllm/_custom_ops.py
vllm/_custom_ops.py
+12
-0
vllm/attention/backends/abstract.py
vllm/attention/backends/abstract.py
+9
-4
vllm/attention/backends/flashinfer.py
vllm/attention/backends/flashinfer.py
+220
-0
vllm/attention/selector.py
vllm/attention/selector.py
+6
-0
vllm/config.py
vllm/config.py
+5
-0
vllm/sequence.py
vllm/sequence.py
+3
-1
vllm/utils.py
vllm/utils.py
+52
-15
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+97
-26
No files found.
csrc/cache.h
View file @
43c413ec
...
...
@@ -24,6 +24,14 @@ void reshape_and_cache(
const
std
::
string
&
kv_cache_dtype
,
const
float
kv_scale
);
void
reshape_and_cache_flash
(
torch
::
Tensor
&
key
,
torch
::
Tensor
&
value
,
torch
::
Tensor
&
key_cache
,
torch
::
Tensor
&
value_cache
,
torch
::
Tensor
&
slot_mapping
,
const
std
::
string
&
kv_cache_dtype
);
// Just for unittest
void
convert_fp8
(
torch
::
Tensor
&
src_cache
,
...
...
csrc/cache_kernels.cu
View file @
43c413ec
...
...
@@ -215,6 +215,41 @@ __global__ void reshape_and_cache_kernel(
}
}
// Copies each token's key and value vectors into the paged KV cache.
//
// blockIdx.x selects the token; the threads of the block stride over the
// num_heads * head_size elements of that token in steps of blockDim.x.
// slot_mapping[token] gives the destination slot, which is split into a
// block index (slot / block_size) and an offset inside that block
// (slot % block_size). Tokens whose slot is negative are skipped.
template <typename scalar_t>
__global__ void reshape_and_cache_flash_kernel(
    const scalar_t* __restrict__ key,    // [num_tokens, num_heads, head_size]
    const scalar_t* __restrict__ value,  // [num_tokens, num_heads, head_size]
    scalar_t* __restrict__ k_cache,  // [num_blocks, block_size, num_heads, head_size]
    scalar_t* __restrict__ v_cache,  // [num_blocks, block_size, num_heads, head_size]
    const int64_t* __restrict__ slot_mapping,  // [num_tokens]
    const int block_stride, const int key_stride, const int value_stride,
    const int num_heads, const int head_size, const int block_size) {
  const int64_t token = blockIdx.x;
  const int64_t slot = slot_mapping[token];
  // NOTE: slot can be -1 if the token is padded; nothing is written then.
  if (slot < 0) {
    return;
  }

  // Loop-invariant bases: first element of this token in the source tensors
  // and first element of its slot in the caches (same offset for K and V).
  const int64_t key_base = token * key_stride;
  const int64_t value_base = token * value_stride;
  const int64_t cache_base = (slot / block_size) * block_stride +
                             (slot % block_size) * num_heads * head_size;

  const int elems_per_token = num_heads * head_size;
  for (int elem = threadIdx.x; elem < elems_per_token; elem += blockDim.x) {
    // elem == head_idx * head_size + head_offset, so it already is the
    // offset inside the [num_heads, head_size] slab of the slot.
    const int64_t cache_idx = cache_base + elem;
    k_cache[cache_idx] = key[key_base + elem];
    v_cache[cache_idx] = value[value_base + elem];
  }
}
}
// namespace vllm
#define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, IS_FP8_KV_CACHE) \
...
...
@@ -275,6 +310,51 @@ void reshape_and_cache(
}
}
// Host-side launcher for vllm::reshape_and_cache_flash_kernel.
//
// Writes every token of `key` / `value` into `k_cache` / `v_cache` at the
// slot given by `slot_mapping`. One CUDA block is launched per token with
// min(num_heads * head_size, 512) threads, on the current CUDA stream of the
// device that owns `key`.
//
// Only kv_cache_dtype == "auto" is accepted; any other value fails a
// TORCH_CHECK.
void reshape_and_cache_flash(
    torch::Tensor& key,      // [num_tokens, num_heads, head_size]
    torch::Tensor& value,    // [num_tokens, num_heads, head_size]
    torch::Tensor& k_cache,  // [num_blocks, block_size, num_heads, head_size]
    torch::Tensor& v_cache,  // [num_blocks, block_size, num_heads, head_size]
    torch::Tensor& slot_mapping,  // [num_tokens]
    const std::string& kv_cache_dtype) {
  // FIXME: only support auto datatype, does not support fp8
  TORCH_CHECK(kv_cache_dtype == "auto",
              "Unsupported data type of kv cache: ", kv_cache_dtype);

  int num_tokens = key.size(0);
  int num_heads = key.size(1);
  int head_size = key.size(2);
  int block_size = k_cache.size(1);

  int key_stride = key.stride(0);
  int value_stride = value.stride(0);
  int block_stride = k_cache.stride(0);
  // The kernel uses a single block stride for both caches.
  TORCH_CHECK(k_cache.stride(0) == v_cache.stride(0),
              "k_cache and v_cache must have the same block stride");

  // A launch with a zero-sized grid or block is an invalid configuration,
  // and there is nothing to copy anyway.
  if (num_tokens == 0 || num_heads * head_size == 0) {
    return;
  }

  dim3 grid(num_tokens);
  dim3 block(std::min(num_heads * head_size, 512));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_FLOATING_TYPES(
      key.scalar_type(), "reshape_and_cache_flash", [&] {
        vllm::reshape_and_cache_flash_kernel<scalar_t>
            <<<grid, block, 0, stream>>>(
                key.data_ptr<scalar_t>(), value.data_ptr<scalar_t>(),
                k_cache.data_ptr<scalar_t>(), v_cache.data_ptr<scalar_t>(),
                slot_mapping.data_ptr<int64_t>(), block_stride, key_stride,
                value_stride, num_heads, head_size, block_size);
      });
}
namespace
vllm
{
template
<
typename
Tout
,
typename
Tin
>
...
...
csrc/pybind.cpp
View file @
43c413ec
...
...
@@ -96,6 +96,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
"reshape_and_cache"
,
&
reshape_and_cache
,
"Reshape the key and value tensors and cache them"
);
cache_ops
.
def
(
"reshape_and_cache_flash"
,
&
reshape_and_cache_flash
,
"Reshape the key and value tensors and cache them"
);
cache_ops
.
def
(
"convert_fp8"
,
&
convert_fp8
,
...
...
tests/basic_correctness/test_basic_correctness.py
View file @
43c413ec
...
...
@@ -2,12 +2,15 @@
Run `pytest tests/basic_correctness/test_basic_correctness.py`.
"""
import
os
import
pytest
MODELS
=
[
"facebook/opt-125m"
,
"meta-llama/Llama-2-7b-hf"
,
]
VLLM_ATTENTION_BACKEND
=
"VLLM_ATTENTION_BACKEND"
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
...
...
@@ -23,11 +26,18 @@ def test_models(
max_tokens
:
int
,
enforce_eager
:
bool
,
)
->
None
:
backend_by_env_var
=
os
.
getenv
(
VLLM_ATTENTION_BACKEND
)
if
backend_by_env_var
==
"FLASHINFER"
and
enforce_eager
is
False
:
pytest
.
skip
(
"Skipping non-eager test for FlashInferBackend."
)
hf_model
=
hf_runner
(
model
,
dtype
=
dtype
)
hf_outputs
=
hf_model
.
generate_greedy
(
example_prompts
,
max_tokens
)
del
hf_model
vllm_model
=
vllm_runner
(
model
,
dtype
=
dtype
,
enforce_eager
=
enforce_eager
)
vllm_model
=
vllm_runner
(
model
,
dtype
=
dtype
,
enforce_eager
=
enforce_eager
,
gpu_memory_utilization
=
0.7
)
vllm_outputs
=
vllm_model
.
generate_greedy
(
example_prompts
,
max_tokens
)
del
vllm_model
...
...
tests/distributed/test_basic_distributed_correctness.py
View file @
43c413ec
...
...
@@ -18,6 +18,7 @@ import torch
MODELS
=
[
os
.
environ
[
"TEST_DIST_MODEL"
],
]
VLLM_ATTENTION_BACKEND
=
"VLLM_ATTENTION_BACKEND"
@
pytest
.
mark
.
skipif
(
torch
.
cuda
.
device_count
()
<
2
,
...
...
@@ -33,16 +34,19 @@ def test_models(
dtype
:
str
,
max_tokens
:
int
,
)
->
None
:
enforce_eager
=
False
backend_by_env_var
=
os
.
getenv
(
VLLM_ATTENTION_BACKEND
)
if
backend_by_env_var
==
"FLASHINFER"
:
enforce_eager
=
True
hf_model
=
hf_runner
(
model
,
dtype
=
dtype
)
hf_outputs
=
hf_model
.
generate_greedy
(
example_prompts
,
max_tokens
)
del
hf_model
vllm_model
=
vllm_runner
(
model
,
dtype
=
dtype
,
tensor_parallel_size
=
2
,
)
vllm_model
=
vllm_runner
(
model
,
dtype
=
dtype
,
tensor_parallel_size
=
2
,
enforce_eager
=
enforce_eager
)
vllm_outputs
=
vllm_model
.
generate_greedy
(
example_prompts
,
max_tokens
)
del
vllm_model
...
...
tests/kernels/conftest.py
View file @
43c413ec
import
pytest
from
vllm.utils
import
create_kv_caches_with_random
from
vllm.utils
import
(
create_kv_caches_with_random
,
create_kv_caches_with_random_flash
)
@pytest.fixture()
def kv_cache_factory():
    """Fixture exposing ``create_kv_caches_with_random`` to kernel tests."""
    return create_kv_caches_with_random
@pytest.fixture()
def kv_cache_factory_flashinfer():
    """Fixture exposing ``create_kv_caches_with_random_flash`` to kernel tests."""
    return create_kv_caches_with_random_flash
tests/kernels/test_cache.py
View file @
43c413ec
...
...
@@ -5,6 +5,7 @@ import pytest
import
torch
from
vllm
import
_custom_ops
as
ops
from
vllm._C
import
cache_ops
from
vllm.utils
import
is_hip
COPYING_DIRECTION
=
[(
'cuda'
,
'cpu'
),
(
'cuda'
,
'cuda'
),
(
'cpu'
,
'cuda'
)]
...
...
@@ -191,6 +192,82 @@ def test_reshape_and_cache(
assert
torch
.
allclose
(
value_cache
,
cloned_value_cache
)
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@torch.inference_mode()
def test_reshape_and_cache_flash(
    kv_cache_factory_flashinfer,
    num_tokens: int,
    num_heads: int,
    head_size: int,
    block_size: int,
    num_blocks: int,
    dtype: torch.dtype,
    seed: int,
    device: str,
    kv_cache_dtype: str,
) -> None:
    """Check ``reshape_and_cache_flash`` against a pure-PyTorch reference.

    Random key/value tensors are scattered into a randomly initialised paged
    cache through the custom op, and the result is compared with the same
    scatter done by indexing into clones of the caches.
    """
    if kv_cache_dtype == "fp8":
        pytest.skip("reshape_and_cache_flash does not support fp8 kv cache.")
    random.seed(seed)
    torch.random.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    # Create a random slot mapping. It must live on the same device as the
    # key/value tensors and the caches, so use the parametrized device.
    num_slots = block_size * num_blocks
    slot_mapping = random.sample(range(num_slots), num_tokens)
    slot_mapping = torch.tensor(slot_mapping, dtype=torch.long, device=device)

    qkv = torch.randn(num_tokens,
                      3,
                      num_heads,
                      head_size,
                      dtype=dtype,
                      device=device)
    _, key, value = qkv.unbind(dim=1)

    # Create the KV caches on the device under test.
    key_caches, value_caches = kv_cache_factory_flashinfer(
        num_blocks,
        block_size,
        1,
        num_heads,
        head_size,
        kv_cache_dtype,
        dtype,
        device=device,
    )
    key_cache, value_cache = key_caches[0], value_caches[0]

    # Clone the KV caches.
    cloned_key_cache = key_cache.clone()
    cloned_value_cache = value_cache.clone()

    # Call the reshape_and_cache_flash kernel.
    cache_ops.reshape_and_cache_flash(key, value, key_cache, value_cache,
                                      slot_mapping, kv_cache_dtype)

    # Run the reference implementation.
    block_indices = torch.div(slot_mapping, block_size, rounding_mode='floor')
    block_indices = block_indices.cpu().tolist()
    block_offsets = slot_mapping % block_size
    block_offsets = block_offsets.cpu().tolist()
    for i in range(num_tokens):
        block_idx = block_indices[i]
        block_offset = block_offsets[i]
        cloned_key_cache[block_idx, block_offset, :, :] = key[i]
        cloned_value_cache[block_idx, block_offset, :, :] = value[i]

    assert torch.allclose(key_cache, cloned_key_cache)
    assert torch.allclose(value_cache, cloned_value_cache)
@
pytest
.
mark
.
parametrize
(
"direction"
,
COPYING_DIRECTION
)
@
pytest
.
mark
.
parametrize
(
"num_mappings"
,
NUM_MAPPINGS
)
@
pytest
.
mark
.
parametrize
(
"num_heads"
,
NUM_HEADS
)
...
...
vllm/_custom_ops.py
View file @
43c413ec
...
...
@@ -222,6 +222,18 @@ def reshape_and_cache(
slot_mapping
,
kv_cache_dtype
,
kv_scale
)
def reshape_and_cache_flash(
    key: torch.Tensor,
    value: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
    kv_cache_dtype: str,
) -> None:
    """Forward all arguments unchanged to ``vllm_cache_ops.reshape_and_cache_flash``."""
    vllm_cache_ops.reshape_and_cache_flash(key, value, key_cache, value_cache,
                                           slot_mapping, kv_cache_dtype)
def copy_blocks(key_caches: torch.Tensor, value_caches: torch.Tensor,
                block_mapping: torch.Tensor) -> None:
    """Forward all arguments unchanged to ``vllm_cache_ops.copy_blocks``."""
    vllm_cache_ops.copy_blocks(key_caches, value_caches, block_mapping)
...
...
vllm/attention/backends/abstract.py
View file @
43c413ec
from
abc
import
ABC
,
abstractmethod
from
dataclasses
import
dataclass
,
fields
from
typing
import
Any
,
Dict
,
Generic
,
List
,
Optional
,
Tuple
,
Type
,
TypeVar
from
typing
import
(
Any
,
Dict
,
Generic
,
List
,
Optional
,
Set
,
Tuple
,
Type
,
TypeVar
)
import
torch
...
...
@@ -15,7 +16,7 @@ class AttentionBackend(ABC):
@
staticmethod
@
abstractmethod
def
make_metadata
(
*
args
,
**
kwargs
)
->
"AttentionMetadata"
:
def
make_metadata
(
*
args
,
**
kwargs
)
->
"AttentionMetadata
PerStage
"
:
raise
NotImplementedError
@
staticmethod
...
...
@@ -50,13 +51,17 @@ class AttentionBackend(ABC):
class
AttentionMetadataPerStage
:
"""Attention metadata for a specific stage. I.e., prefill or decode."""
def
asdict_zerocopy
(
self
)
->
Dict
[
str
,
Any
]:
def
asdict_zerocopy
(
self
,
skip_fields
:
Optional
[
Set
[
str
]]
=
None
)
->
Dict
[
str
,
Any
]:
"""Similar to dataclasses.asdict, but avoids deepcopying."""
if
skip_fields
is
None
:
skip_fields
=
set
()
# Note that if we add dataclasses as fields, they will need
# similar handling.
return
{
field
.
name
:
getattr
(
self
,
field
.
name
)
for
field
in
fields
(
self
)
for
field
in
fields
(
self
)
if
field
.
name
not
in
skip_fields
}
...
...
vllm/attention/backends/flashinfer.py
0 → 100644
View file @
43c413ec
from
dataclasses
import
dataclass
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Set
,
Tuple
,
Type
try
:
import
flashinfer
from
flash_attn
import
flash_attn_varlen_func
from
flashinfer
import
BatchDecodeWithPagedKVCacheWrapper
except
ImportError
:
flashinfer
=
None
flash_attn_varlen_func
=
None
BatchDecodeWithPagedKVCacheWrapper
=
None
import
torch
from
vllm
import
_custom_ops
as
ops
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
AttentionMetadata
,
AttentionMetadataPerStage
)
class FlashInferBackend(AttentionBackend):
    """Attention backend entry points for the FlashInfer implementation."""

    @staticmethod
    def get_impl_cls() -> Type["FlashInferImpl"]:
        """Return the attention implementation class for this backend."""
        return FlashInferImpl

    @staticmethod
    def make_metadata(*args, **kwargs) -> "FlashInferMetadata":
        """Build a ``FlashInferMetadata`` from the given arguments."""
        return FlashInferMetadata(*args, **kwargs)

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return the KV cache shape.

        The shape is ``(num_blocks, 2, block_size, num_kv_heads, head_size)``;
        ``FlashInferImpl.forward`` uses index 0 of the second dimension as the
        key cache and index 1 as the value cache.
        """
        return (num_blocks, 2, block_size, num_kv_heads, head_size)

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: Dict[int, int],
    ) -> None:
        """Not implemented for this backend."""
        raise NotImplementedError

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: Dict[int, List[int]],
    ) -> None:
        """Not implemented for this backend."""
        raise NotImplementedError

    @staticmethod
    def get_supported_head_sizes() -> List[int]:
        """Return the head sizes accepted by ``FlashInferMetadata``."""
        return [64, 128, 256]
@dataclass
class FlashInferMetadata(AttentionMetadataPerStage):
    """Per-stage attention metadata for the FlashInfer backend.

    For the decode stage (``is_prompt`` is False), ``__post_init__`` creates a
    ``BatchDecodeWithPagedKVCacheWrapper`` and calls ``begin_forward`` on it
    with the paged-KV fields below. For the prefill stage no wrapper is
    created.
    """

    is_prompt: bool

    use_cuda_graph: bool = False

    decode_wrapper: Optional[BatchDecodeWithPagedKVCacheWrapper] = None

    # Metadata for the prefill stage since we still
    # use flash attention for prefill.
    seq_start_loc: Optional[torch.Tensor] = None
    max_seq_len: Optional[int] = None
    block_tables: Optional[torch.Tensor] = None

    # Metadata for the decode stage
    # Workspace buffer required by the kernel, the buffer should not
    # be allocated/deallocated by the FlashInferMetadata object.
    workspace_buffer: Optional[torch.Tensor] = None
    # An example for paged_kv_indices, paged_kv_indptr:
    # request 1, page indices [0, 5, 8]
    # request 2, page indices [1, 6, 7]
    # request 3, page indices [3, 4]
    # paged_kv_indices is a concatenation of page indices of all requests:
    # [0, 5, 8, 1, 6, 7, 3, 4]
    # paged_kv_indptr is used to index into paged_kv_indices:
    # [0, 3, 6, 8]
    # The indptr of the paged kv cache, shape: [batch_size + 1]
    paged_kv_indptr: Optional[torch.Tensor] = None
    # The page indices of the paged kv cache
    paged_kv_indices: Optional[torch.Tensor] = None
    # The number of entries in the last page of each request in
    # the paged kv cache, shape: [batch_size]
    paged_kv_last_page_len: Optional[torch.Tensor] = None
    # The number of query/output heads
    num_qo_heads: Optional[int] = None
    # The number of key/value heads
    num_kv_heads: Optional[int] = None
    # The dimension of the attention heads
    head_dim: Optional[int] = None
    # Block size of vllm
    page_size: Optional[int] = None
    # The data type of the paged kv cache
    data_type: Optional[torch.dtype] = None

    def __post_init__(self):
        """Validate ``head_dim`` and set up the decode wrapper if decoding.

        Raises:
            ValueError: if ``head_dim`` is set and is not one of
                ``FlashInferBackend.get_supported_head_sizes()``.
        """
        # Refer to
        # https://github.com/flashinfer-ai/flashinfer/blob/3d55c71a62052c590c130897d3a3db49b14fcc34/include/flashinfer/utils.cuh#L157
        supported_head_sizes = FlashInferBackend.get_supported_head_sizes()
        if self.head_dim is not None and self.head_dim \
                not in supported_head_sizes:
            # The two f-strings must be concatenated (no comma) so the
            # exception carries a single message instead of a tuple.
            raise ValueError(
                f"Only {supported_head_sizes} are supported for head_dim, "
                f"received {self.head_dim}.")

        # FlashInferMetadata is also created for the prefill stage, which
        # still runs __post_init__; the decode wrapper is only needed (and
        # only has its inputs populated) for the decode stage.
        if not self.is_prompt:
            self.decode_wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
                self.workspace_buffer, "NHD")
            self.decode_wrapper.begin_forward(
                self.paged_kv_indptr,
                self.paged_kv_indices,
                self.paged_kv_last_page_len,
                self.num_qo_heads,
                self.num_kv_heads,
                self.head_dim,
                self.page_size,
                # Disable flashinfer's pos encoding and use vllm's rope.
                pos_encoding_mode="NONE",
                data_type=self.data_type)

    def asdict_zerocopy(self,
                        skip_fields: Optional[Set[str]] = None
                        ) -> Dict[str, Any]:
        """Return the fields as a dict, always omitting ``decode_wrapper``.

        Args:
            skip_fields: extra field names to omit. The set passed by the
                caller is not modified.
        """
        # We need to skip the decode_wrapper field since it cannot be
        # broadcasted with nccl when TP is enabled. Build a new set rather
        # than calling .add() on the caller's object.
        if skip_fields is None:
            merged_skip_fields = {'decode_wrapper'}
        else:
            merged_skip_fields = set(skip_fields) | {'decode_wrapper'}
        return super().asdict_zerocopy(merged_skip_fields)
class FlashInferImpl(AttentionImpl):
    """Attention implementation used by ``FlashInferBackend``.

    Prefill is computed with ``flash_attn_varlen_func``; decode is computed
    with the ``decode_wrapper`` stored in the decode metadata.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
        alibi_slopes: Optional[List[float]] = None,
        sliding_window: Optional[int] = None,
    ) -> None:
        """Store the attention hyper-parameters.

        Raises:
            ValueError: if ``sliding_window`` is not None.
        """
        if sliding_window is not None:
            raise ValueError("Sliding window is not supported in FlashInfer.")
        # (-1, -1) is passed as window_size to flash_attn_varlen_func.
        self.sliding_window = (-1, -1)
        # flash_attn_varlen_func takes ALiBi slopes as a float32 tensor, not
        # a Python list, so convert once here.
        # NOTE(review): the decode path below does not pass alibi_slopes to
        # the FlashInfer wrapper - confirm ALiBi models are handled upstream.
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.alibi_slopes = alibi_slopes
        self.scale = scale
        self.num_heads = num_heads
        self.head_size = head_size
        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads

    def forward(self, query: torch.Tensor, key: torch.Tensor,
                value: torch.Tensor, kv_cache: Optional[torch.Tensor],
                attn_metadata: AttentionMetadata[FlashInferMetadata],
                kv_scale: float):
        """Run attention for a prefill-only or decode-only batch.

        Args:
            query: 2-D tensor; viewed as [-1, num_heads, head_size].
            key: viewed as [-1, num_kv_heads, head_size].
            value: viewed as [-1, num_kv_heads, head_size].
            kv_cache: paged cache; ``kv_cache[:, 0]`` is used as the key cache
                and ``kv_cache[:, 1]`` as the value cache. May be None.
            attn_metadata: metadata holding the prefill/decode stage metadata.
            kv_scale: not used by this implementation.

        Returns:
            The attention output viewed with the original 2-D shape of
            ``query``.

        Raises:
            NotImplementedError: for prefill with a non-empty block table
                (prefix caching).
        """
        num_tokens, hidden_size = query.shape
        query = query.view(-1, self.num_heads, self.head_size)
        key = key.view(-1, self.num_kv_heads, self.head_size)
        value = value.view(-1, self.num_kv_heads, self.head_size)

        # A batch must be entirely prefill or entirely decode.
        if attn_metadata.num_prefill_tokens > 0:
            assert attn_metadata.num_decode_tokens == 0, (
                "Chunked prefill is not supported with flashinfer yet.")
        if attn_metadata.num_decode_tokens > 0:
            assert attn_metadata.num_prefill_tokens == 0, (
                "Chunked prefill is not supported with flashinfer yet.")

        if kv_cache is not None:
            # Use the same reshape and cache kernel as flash attention.
            ops.reshape_and_cache_flash(
                key,
                value,
                kv_cache[:, 0],
                kv_cache[:, 1],
                attn_metadata.slot_mapping.flatten(),
                attn_metadata.kv_cache_dtype,
            )

        if prefill_meta := attn_metadata.prefill_metadata:
            assert prefill_meta.block_tables is not None
            if kv_cache is None or prefill_meta.block_tables.numel() == 0:
                # Plain prompt: attend over the keys/values of this batch.
                output = flash_attn_varlen_func(
                    q=query,
                    k=key,
                    v=value,
                    cu_seqlens_q=prefill_meta.seq_start_loc,
                    cu_seqlens_k=prefill_meta.seq_start_loc,
                    max_seqlen_q=prefill_meta.max_seq_len,
                    max_seqlen_k=prefill_meta.max_seq_len,
                    softmax_scale=self.scale,
                    causal=True,
                    window_size=self.sliding_window,
                    alibi_slopes=self.alibi_slopes,
                )
            else:
                raise NotImplementedError(
                    "Prefix caching is not supported with flashinfer yet.")
        else:
            assert attn_metadata.decode_metadata is not None
            assert attn_metadata.decode_metadata.decode_wrapper is not None
            # Flashinfer requires query to be contiguous
            query = query.contiguous()
            output = attn_metadata.decode_metadata.decode_wrapper.forward(
                query,
                kv_cache,
                sm_scale=self.scale,
            )
        return output.view(num_tokens, hidden_size)
vllm/attention/selector.py
View file @
43c413ec
...
...
@@ -17,6 +17,7 @@ class _Backend(enum.Enum):
XFORMERS
=
enum
.
auto
()
ROCM_FLASH
=
enum
.
auto
()
TORCH_SDPA
=
enum
.
auto
()
FLASHINFER
=
enum
.
auto
()
@
lru_cache
(
maxsize
=
None
)
...
...
@@ -41,6 +42,11 @@ def get_attn_backend(dtype: torch.dtype) -> Type[AttentionBackend]:
logger
.
info
(
"Using Torch SDPA backend."
)
from
vllm.attention.backends.torch_sdpa
import
TorchSDPABackend
return
TorchSDPABackend
elif
backend
==
_Backend
.
FLASHINFER
:
logger
.
info
(
"Using Flashinfer backend."
)
logger
.
warning
(
"Eager mode is enforced for the Flashinfer backend. "
)
from
vllm.attention.backends.flashinfer
import
FlashInferBackend
return
FlashInferBackend
else
:
raise
ValueError
(
"Invalid attention backend."
)
...
...
vllm/config.py
View file @
43c413ec
...
...
@@ -298,6 +298,11 @@ class ModelConfig:
return
max
(
1
,
total_num_kv_heads
//
parallel_config
.
tensor_parallel_size
)
def get_num_attention_heads(self, parallel_config: "ParallelConfig") -> int:
    """Return ``hf_text_config.num_attention_heads`` floor-divided by
    ``parallel_config.tensor_parallel_size``."""
    total_heads = self.hf_text_config.num_attention_heads
    tp_size = parallel_config.tensor_parallel_size
    return total_heads // tp_size
def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
    """Return ``hf_text_config.num_hidden_layers`` floor-divided by
    ``parallel_config.pipeline_parallel_size``."""
    return (self.hf_text_config.num_hidden_layers //
            parallel_config.pipeline_parallel_size)
...
...
vllm/sequence.py
View file @
43c413ec
...
...
@@ -579,8 +579,10 @@ class SequenceGroupMetadata:
query tokens for prefill, we don't need sampling.
token_chunk_size: The number of tokens to be processed (per sequence).
None if chunking is not required.
state: Internal state tied to this sequence group.
lora_request: LoRA request.
computed_block_nums: The block numbers that are already computed,
used in prefix caching.
state: Internal state tied to this sequence group.
multi_modal_data: Multi modal data.
"""
...
...
vllm/utils.py
View file @
43c413ec
...
...
@@ -355,21 +355,9 @@ def _generate_random_fp8(
del
tensor_tmp
def
create_kv_caches_with_random
(
num_blocks
:
int
,
block_size
:
int
,
num_layers
:
int
,
num_heads
:
int
,
head_size
:
int
,
cache_dtype
:
Optional
[
Union
[
str
,
torch
.
dtype
]],
model_dtype
:
Optional
[
Union
[
str
,
torch
.
dtype
]]
=
None
,
seed
:
int
=
0
,
device
:
Optional
[
str
]
=
"cuda"
,
)
->
Tuple
[
List
[
torch
.
Tensor
],
List
[
torch
.
Tensor
]]:
torch
.
random
.
manual_seed
(
seed
)
if
torch
.
cuda
.
is_available
():
torch
.
cuda
.
manual_seed
(
seed
)
def
get_kv_cache_torch_dtype
(
cache_dtype
:
Optional
[
Union
[
str
,
torch
.
dtype
]],
model_dtype
:
Optional
[
Union
[
str
,
torch
.
dtype
]]
=
None
)
->
torch
.
dtype
:
if
isinstance
(
cache_dtype
,
str
):
if
cache_dtype
==
"auto"
:
if
isinstance
(
model_dtype
,
str
):
...
...
@@ -388,6 +376,55 @@ def create_kv_caches_with_random(
torch_dtype
=
cache_dtype
else
:
raise
ValueError
(
f
"Invalid kv cache dtype:
{
cache_dtype
}
"
)
return
torch_dtype
def create_kv_caches_with_random_flash(
    num_blocks: int,
    block_size: int,
    num_layers: int,
    num_heads: int,
    head_size: int,
    cache_dtype: Optional[Union[str, torch.dtype]],
    model_dtype: Optional[Union[str, torch.dtype]] = None,
    seed: int = 0,
    device: Optional[str] = "cuda",
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
    """Create randomly initialised per-layer key and value caches.

    For each layer one tensor of shape
    ``(num_blocks, 2, block_size, num_heads, head_size)`` is allocated and
    filled uniformly in ``[-head_size**-0.5, head_size**-0.5]``. The returned
    key cache is the ``[:, 0]`` view of that tensor and the value cache is the
    ``[:, 1]`` view. ``cache_dtype`` must not be ``"fp8"``.

    Returns:
        ``(key_caches, value_caches)``, each a list with ``num_layers``
        entries.
    """
    assert cache_dtype != "fp8"

    # Seed the CPU generator and, when present, the CUDA generator.
    torch.random.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)

    dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype)
    fused_shape = (num_blocks, 2, block_size, num_heads, head_size)
    bound = head_size**-0.5

    key_caches: List[torch.Tensor] = []
    value_caches: List[torch.Tensor] = []
    for _ in range(num_layers):
        fused_cache = torch.empty(size=fused_shape,
                                  dtype=dtype,
                                  device=device)
        fused_cache.uniform_(-bound, bound)
        # Both entries are views into the same fused allocation.
        key_caches.append(fused_cache[:, 0])
        value_caches.append(fused_cache[:, 1])
    return key_caches, value_caches
def
create_kv_caches_with_random
(
num_blocks
:
int
,
block_size
:
int
,
num_layers
:
int
,
num_heads
:
int
,
head_size
:
int
,
cache_dtype
:
Optional
[
Union
[
str
,
torch
.
dtype
]],
model_dtype
:
Optional
[
Union
[
str
,
torch
.
dtype
]]
=
None
,
seed
:
int
=
0
,
device
:
Optional
[
str
]
=
"cuda"
,
)
->
Tuple
[
List
[
torch
.
Tensor
],
List
[
torch
.
Tensor
]]:
torch
.
random
.
manual_seed
(
seed
)
if
torch
.
cuda
.
is_available
():
torch
.
cuda
.
manual_seed
(
seed
)
torch_dtype
=
get_kv_cache_torch_dtype
(
cache_dtype
,
model_dtype
)
scale
=
head_size
**-
0.5
x
=
16
//
torch
.
tensor
([],
dtype
=
torch_dtype
).
element_size
()
...
...
vllm/worker/model_runner.py
View file @
43c413ec
...
...
@@ -9,6 +9,7 @@ import torch.nn as nn
from
vllm.attention
import
(
AttentionMetadata
,
AttentionMetadataPerStage
,
get_attn_backend
)
from
vllm.attention.backends.flashinfer
import
FlashInferBackend
from
vllm.config
import
(
DeviceConfig
,
LoadConfig
,
LoRAConfig
,
ModelConfig
,
ParallelConfig
,
SchedulerConfig
,
VisionLanguageConfig
)
from
vllm.distributed
import
broadcast_tensor_dict
,
with_pynccl_for_all_reduce
...
...
@@ -23,8 +24,8 @@ from vllm.model_executor.model_loader import get_model
from
vllm.sampling_params
import
SamplingParams
from
vllm.sequence
import
(
MultiModalData
,
SamplerOutput
,
SequenceData
,
SequenceGroupMetadata
)
from
vllm.utils
import
(
CudaMemoryProfiler
,
is_hip
,
is_pin_memory_available
,
make_tensor_with_pad
)
from
vllm.utils
import
(
CudaMemoryProfiler
,
get_kv_cache_torch_dtype
,
is_hip
,
is_pin_memory_available
,
make_tensor_with_pad
)
logger
=
init_logger
(
__name__
)
...
...
@@ -155,6 +156,9 @@ class ModelRunner:
# (max batch size to capture, max context len to capture / block size).
self
.
graph_block_tables
:
torch
.
Tensor
# Set after initial profiling.
# Set if the backend is flashinfer.
self
.
flashinfer_workspace_buffer
:
torch
.
Tensor
def
load_model
(
self
)
->
None
:
with
CudaMemoryProfiler
()
as
m
:
self
.
model
=
get_model
(
...
...
@@ -315,6 +319,7 @@ class ModelRunner:
# Compute the slot mapping.
block_table
=
seq_group_metadata
.
block_tables
[
seq_id
]
# Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID,
# where start_idx is max(0, seq_len - sliding_window).
# For example, if the prompt len is 10, sliding window is 8, and
...
...
@@ -390,18 +395,26 @@ class ModelRunner:
dtype
=
seq_start_loc
.
dtype
,
out
=
seq_start_loc
[
1
:])
attn_metadata
=
self
.
attn_backend
.
make_metadata
(
is_prompt
=
True
,
seq_lens
=
seq_lens
,
seq_lens_tensor
=
seq_lens_tensor
,
max_query_len
=
max_query_len
,
max_seq_len
=
max_seq_len
,
subquery_start_loc
=
subquery_start_loc
,
seq_start_loc
=
seq_start_loc
,
context_lens_tensor
=
context_lens_tensor
,
block_tables
=
block_tables
,
use_cuda_graph
=
False
,
)
if
self
.
attn_backend
is
FlashInferBackend
:
attn_metadata
=
self
.
attn_backend
.
make_metadata
(
is_prompt
=
True
,
use_cuda_graph
=
False
,
seq_start_loc
=
seq_start_loc
,
max_seq_len
=
max_seq_len
,
block_tables
=
block_tables
)
else
:
attn_metadata
=
self
.
attn_backend
.
make_metadata
(
is_prompt
=
True
,
seq_lens
=
seq_lens
,
seq_lens_tensor
=
seq_lens_tensor
,
max_query_len
=
max_query_len
,
max_seq_len
=
max_seq_len
,
subquery_start_loc
=
subquery_start_loc
,
seq_start_loc
=
seq_start_loc
,
context_lens_tensor
=
context_lens_tensor
,
block_tables
=
block_tables
,
use_cuda_graph
=
False
,
)
return
PreparePromptMetadata
(
input_tokens
=
input_tokens
,
...
...
@@ -429,6 +442,24 @@ class ModelRunner:
lora_prompt_mapping
:
List
[
int
]
=
[]
lora_requests
:
Set
[
LoRARequest
]
=
set
()
# The following fields are only for flashinfer
# Please follow https://docs.flashinfer.ai/tutorials/kv_layout.html#page-layout
# for the precise definition of the following fields.
# An example:
# request 1, page indices [0, 5, 8]
# request 2, page indices [1, 6, 7]
# request 3, page indices [3, 4]
# paged_kv_indices is a concatenation of page indices of all requests:
# [0, 5, 8, 1, 6, 7, 3, 4]
# paged_kv_indptr is used to index into paged_kv_indices:
# [0, 3, 6, 8]
paged_kv_indices
:
List
[
int
]
=
[]
# 0 at the beginning of paged_kv_indptr indicates the start of the
# first request’s page indices in the paged_kv_indices list.
paged_kv_indptr
:
List
[
int
]
=
[
0
]
# paged_kv_last_page_len is the length of the last page of each request
paged_kv_last_page_len
:
List
[
int
]
=
[]
if
len
(
seq_group_metadata_list
)
==
0
:
return
PrepareDecodeMetadata
.
empty
()
...
...
@@ -469,6 +500,13 @@ class ModelRunner:
block_table
=
block_table
[
-
sliding_window_blocks
:]
block_tables
.
append
(
block_table
)
paged_kv_indices
.
extend
(
block_table
)
paged_kv_indptr
.
append
(
paged_kv_indptr
[
-
1
]
+
len
(
block_table
))
last_page_len
=
seq_data
.
get_len
()
%
self
.
block_size
if
last_page_len
==
0
:
last_page_len
=
self
.
block_size
paged_kv_last_page_len
.
append
(
last_page_len
)
# vLLM uses cuda graph only for decoding requests.
# See `capture_model` API for more details.
# For decoding requests, batch_size == input_tokens.
...
...
@@ -518,18 +556,51 @@ class ModelRunner:
device
=
self
.
device
,
)
attn_metadata
=
self
.
attn_backend
.
make_metadata
(
is_prompt
=
False
,
seq_lens
=
None
,
seq_lens_tensor
=
seq_lens_tensor
,
max_query_len
=
None
,
max_seq_len
=
max_seq_len
,
subquery_start_loc
=
None
,
seq_start_loc
=
None
,
context_lens_tensor
=
None
,
block_tables
=
block_tables
,
use_cuda_graph
=
use_captured_graph
,
)
if
self
.
attn_backend
is
FlashInferBackend
:
if
not
hasattr
(
self
,
"flashinfer_workspace_buffer"
):
# Allocate 16MB workspace buffer
# Follow the example of flashinfer: https://docs.flashinfer.ai/api/python/decode.html
self
.
flashinfer_workspace_buffer
=
torch
.
empty
(
16
*
1024
*
1024
,
dtype
=
torch
.
uint8
,
device
=
self
.
device
)
paged_kv_indptr
=
torch
.
tensor
(
paged_kv_indptr
,
dtype
=
torch
.
int
,
device
=
self
.
device
)
paged_kv_indices
=
torch
.
tensor
(
paged_kv_indices
,
dtype
=
torch
.
int
,
device
=
self
.
device
)
paged_kv_last_page_len
=
torch
.
tensor
(
paged_kv_last_page_len
,
dtype
=
torch
.
int
,
device
=
self
.
device
)
kv_cache_dtype
=
get_kv_cache_torch_dtype
(
self
.
kv_cache_dtype
,
self
.
model_config
.
dtype
)
attn_metadata
=
self
.
attn_backend
.
make_metadata
(
is_prompt
=
False
,
use_cuda_graph
=
False
,
workspace_buffer
=
self
.
flashinfer_workspace_buffer
,
paged_kv_indptr
=
paged_kv_indptr
,
paged_kv_indices
=
paged_kv_indices
,
paged_kv_last_page_len
=
paged_kv_last_page_len
,
num_qo_heads
=
self
.
model_config
.
get_num_attention_heads
(
self
.
parallel_config
),
num_kv_heads
=
self
.
model_config
.
get_num_kv_heads
(
self
.
parallel_config
),
head_dim
=
self
.
model_config
.
get_head_size
(),
page_size
=
self
.
block_size
,
data_type
=
kv_cache_dtype
)
else
:
attn_metadata
=
self
.
attn_backend
.
make_metadata
(
is_prompt
=
False
,
seq_lens
=
None
,
seq_lens_tensor
=
seq_lens_tensor
,
max_query_len
=
None
,
max_seq_len
=
max_seq_len
,
subquery_start_loc
=
None
,
seq_start_loc
=
None
,
context_lens_tensor
=
None
,
block_tables
=
block_tables
,
use_cuda_graph
=
use_captured_graph
,
)
return
PrepareDecodeMetadata
(
input_tokens
=
input_tokens
,
input_positions
=
input_positions
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment