Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
a5753ff5
Commit
a5753ff5
authored
Jun 19, 2024
by
zhuwenwen
Browse files
v0.5.0.post1
parents
21c06ecb
0f0d8bc0
Changes
108
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1082 additions
and
613 deletions
+1082
-613
vllm/_custom_ops.py
vllm/_custom_ops.py
+50
-7
vllm/attention/backends/flash_attn.py
vllm/attention/backends/flash_attn.py
+6
-7
vllm/attention/backends/pallas.py
vllm/attention/backends/pallas.py
+232
-0
vllm/attention/backends/torch_sdpa.py
vllm/attention/backends/torch_sdpa.py
+16
-7
vllm/attention/ops/ipex_attn.py
vllm/attention/ops/ipex_attn.py
+120
-0
vllm/attention/selector.py
vllm/attention/selector.py
+14
-2
vllm/config.py
vllm/config.py
+8
-4
vllm/core/scheduler.py
vllm/core/scheduler.py
+10
-10
vllm/distributed/communication_op.py
vllm/distributed/communication_op.py
+13
-298
vllm/distributed/device_communicators/custom_all_reduce.py
vllm/distributed/device_communicators/custom_all_reduce.py
+6
-10
vllm/distributed/device_communicators/custom_all_reduce_utils.py
...stributed/device_communicators/custom_all_reduce_utils.py
+5
-5
vllm/distributed/device_communicators/pynccl.py
vllm/distributed/device_communicators/pynccl.py
+3
-8
vllm/distributed/parallel_state.py
vllm/distributed/parallel_state.py
+567
-242
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+1
-1
vllm/engine/async_llm_engine.py
vllm/engine/async_llm_engine.py
+3
-0
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+5
-2
vllm/entrypoints/llm.py
vllm/entrypoints/llm.py
+12
-5
vllm/entrypoints/openai/api_server.py
vllm/entrypoints/openai/api_server.py
+3
-3
vllm/entrypoints/openai/run_batch.py
vllm/entrypoints/openai/run_batch.py
+2
-2
vllm/envs.py
vllm/envs.py
+6
-0
No files found.
vllm/_custom_ops.py
View file @
a5753ff5
import
contextlib
import
functools
from
typing
import
List
,
Optional
,
Tuple
,
Type
import
torch
from
vllm.logger
import
init_logger
logger
=
init_logger
(
__name__
)
try
:
import
vllm._C
except
ImportError
as
e
:
from
vllm.logger
import
init_logger
logger
=
init_logger
(
__name__
)
logger
.
warning
(
"Failed to import from vllm._C with %r"
,
e
)
with
contextlib
.
suppress
(
ImportError
):
...
...
@@ -23,6 +26,25 @@ def is_custom_op_supported(op_name: str) -> bool:
return
op
is
not
None
def
hint_on_error
(
fn
):
@
functools
.
wraps
(
fn
)
def
wrapper
(
*
args
,
**
kwargs
):
try
:
return
fn
(
*
args
,
**
kwargs
)
except
AttributeError
as
e
:
msg
=
(
"Error in calling custom op %s: %s
\n
"
"Possibly you have built or installed an obsolete version of vllm.
\n
"
"Please try a clean build and install of vllm,"
"or remove old built files such as vllm/*cpython*.so and build/ ."
)
logger
.
error
(
msg
,
fn
.
__name__
,
e
)
raise
e
return
wrapper
# activation ops
def
silu_and_mul
(
out
:
torch
.
Tensor
,
x
:
torch
.
Tensor
)
->
None
:
torch
.
ops
.
_C
.
silu_and_mul
(
out
,
x
)
...
...
@@ -190,8 +212,8 @@ def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
# cutlass
def
cutlass_scaled_mm
_dq
(
a
:
torch
.
Tensor
,
b
:
torch
.
Tensor
,
scale_a
:
torch
.
Tensor
,
scale_b
:
torch
.
Tensor
,
def
cutlass_scaled_mm
(
a
:
torch
.
Tensor
,
b
:
torch
.
Tensor
,
scale_a
:
torch
.
Tensor
,
scale_b
:
torch
.
Tensor
,
out_dtype
:
Type
[
torch
.
dtype
])
->
torch
.
Tensor
:
assert
(
b
.
shape
[
0
]
%
16
==
0
and
b
.
shape
[
1
]
%
16
==
0
)
assert
(
out_dtype
is
torch
.
bfloat16
or
out_dtype
is
torch
.
float16
)
...
...
@@ -200,8 +222,7 @@ def cutlass_scaled_mm_dq(a: torch.Tensor, b: torch.Tensor,
n
=
b
.
shape
[
1
]
out
=
torch
.
empty
((
m
,
n
),
dtype
=
out_dtype
,
device
=
a
.
device
)
torch
.
ops
.
_C
.
cutlass_scaled_mm_dq
(
out
,
a
,
b
,
scale_a
,
scale_b
)
torch
.
ops
.
_C
.
cutlass_scaled_mm
(
out
,
a
,
b
,
scale_a
,
scale_b
)
return
out
...
...
@@ -459,3 +480,25 @@ def dispatch_bgmv_low_level(
h_out
,
y_offset
,
)
# temporary fix for https://github.com/vllm-project/vllm/issues/5456
# TODO: remove this in v0.6.0
names_and_values
=
globals
()
names_and_values_to_update
=
{}
# prepare variables to avoid dict size change during iteration
k
,
v
,
arg
=
None
,
None
,
None
fn_type
=
type
(
lambda
x
:
x
)
for
k
,
v
in
names_and_values
.
items
():
# find functions that are defined in this file and have torch.Tensor
# in their annotations. `arg == "torch.Tensor"` is used to handle
# the case when users use `import __annotations__` to turn type
# hints into strings.
if
isinstance
(
v
,
fn_type
)
\
and
v
.
__code__
.
co_filename
==
__file__
\
and
any
(
arg
is
torch
.
Tensor
or
arg
==
"torch.Tensor"
for
arg
in
v
.
__annotations__
.
values
()):
names_and_values_to_update
[
k
]
=
hint_on_error
(
v
)
names_and_values
.
update
(
names_and_values_to_update
)
del
names_and_values_to_update
,
names_and_values
,
v
,
k
,
fn_type
vllm/attention/backends/flash_attn.py
View file @
a5753ff5
...
...
@@ -317,7 +317,7 @@ class FlashAttentionImpl(AttentionImpl):
# normal attention
# When block_tables are not filled, it means q and k are the
# prompt, and they have the same length.
flash_attn_varlen_func
(
out
=
flash_attn_varlen_func
(
q
=
query
,
k
=
key
,
v
=
value
,
...
...
@@ -329,13 +329,14 @@ class FlashAttentionImpl(AttentionImpl):
causal
=
True
,
window_size
=
self
.
sliding_window
,
alibi_slopes
=
self
.
alibi_slopes
,
out
=
output
[:
num_prefill_tokens
],
)
assert
output
[:
num_prefill_tokens
].
shape
==
out
.
shape
output
[:
num_prefill_tokens
]
=
out
else
:
# prefix-enabled attention
assert
prefill_meta
.
seq_lens
is
not
None
max_seq_len
=
max
(
prefill_meta
.
seq_lens
)
flash_attn_varlen_func
(
output
[:
num_prefill_tokens
]
=
flash_attn_varlen_func
(
q
=
query
,
k
=
key_cache
,
v
=
value_cache
,
...
...
@@ -347,12 +348,11 @@ class FlashAttentionImpl(AttentionImpl):
causal
=
True
,
alibi_slopes
=
self
.
alibi_slopes
,
block_table
=
prefill_meta
.
block_tables
,
out
=
output
[:
num_prefill_tokens
],
)
if
decode_meta
:
=
attn_metadata
.
decode_metadata
:
# Decoding run.
flash_attn_with_kvcache
(
output
[
num_prefill_tokens
:]
=
flash_attn_with_kvcache
(
decode_query
.
unsqueeze
(
1
),
key_cache
,
value_cache
,
...
...
@@ -361,8 +361,7 @@ class FlashAttentionImpl(AttentionImpl):
softmax_scale
=
self
.
scale
,
causal
=
True
,
alibi_slopes
=
self
.
alibi_slopes
,
out
=
output
[
num_prefill_tokens
:].
unsqueeze
(
1
),
)
).
squeeze
(
1
)
# Reshape the output tensor.
return
output
.
view
(
num_tokens
,
hidden_size
)
vllm/attention/backends/pallas.py
0 → 100644
View file @
a5753ff5
from
dataclasses
import
dataclass
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Type
import
torch
import
torch_xla.experimental.custom_kernel
# Required to register custom ops.
import
torch_xla.experimental.dynamo_set_buffer_donor
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
AttentionMetadata
)
class
PallasAttentionBackend
(
AttentionBackend
):
@
staticmethod
def
get_impl_cls
()
->
Type
[
"PallasAttentionBackendImpl"
]:
return
PallasAttentionBackendImpl
@
staticmethod
def
make_metadata
(
*
args
,
**
kwargs
)
->
"PallasMetadata"
:
return
PallasMetadata
(
*
args
,
**
kwargs
)
@
staticmethod
def
get_kv_cache_shape
(
num_blocks
:
int
,
block_size
:
int
,
num_kv_heads
:
int
,
head_size
:
int
,
)
->
Tuple
[
int
,
...]:
return
(
num_kv_heads
,
num_blocks
,
block_size
,
head_size
)
@
staticmethod
def
swap_blocks
(
src_kv_cache
:
torch
.
Tensor
,
dst_kv_cache
:
torch
.
Tensor
,
src_to_dst
:
Dict
[
int
,
int
],
)
->
None
:
raise
NotImplementedError
(
"swap_blocks is not implemented."
)
@
staticmethod
def
copy_blocks
(
kv_caches
:
List
[
torch
.
Tensor
],
src_to_dists
:
Dict
[
int
,
List
[
int
]],
)
->
None
:
# TODO(woosuk): Implement this.
raise
NotImplementedError
(
"copy_blocks is not implemented."
)
@
dataclass
class
PallasMetadata
(
AttentionMetadata
):
# Currently, input sequences can only contain all prefills
# or all decoding.
block_tables
:
Optional
[
torch
.
Tensor
]
context_lens
:
Optional
[
torch
.
Tensor
]
@
property
def
prefill_metadata
(
self
)
->
Optional
[
"PallasMetadata"
]:
if
self
.
num_prefills
==
0
:
return
None
assert
self
.
num_decode_tokens
==
0
assert
self
.
block_tables
is
None
assert
self
.
context_lens
is
None
return
self
@
property
def
decode_metadata
(
self
)
->
Optional
[
"PallasMetadata"
]:
if
self
.
num_decode_tokens
==
0
:
return
None
assert
self
.
num_prefills
==
0
assert
self
.
num_prefill_tokens
==
0
assert
self
.
block_tables
is
not
None
assert
self
.
context_lens
is
not
None
return
self
class
PallasAttentionBackendImpl
(
AttentionImpl
):
def
__init__
(
self
,
num_heads
:
int
,
head_size
:
int
,
scale
:
float
,
num_kv_heads
:
int
,
alibi_slopes
:
Optional
[
List
[
float
]],
sliding_window
:
Optional
[
int
],
kv_cache_dtype
:
str
,
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
)
->
None
:
self
.
num_heads
=
num_heads
self
.
head_size
=
head_size
self
.
scale
=
float
(
scale
)
self
.
num_kv_heads
=
num_heads
if
num_kv_heads
is
None
else
num_kv_heads
assert
self
.
num_heads
%
self
.
num_kv_heads
==
0
self
.
num_queries_per_kv
=
self
.
num_heads
//
self
.
num_kv_heads
if
head_size
%
128
!=
0
:
raise
NotImplementedError
(
"Head size must be a multiple of 128."
)
if
alibi_slopes
is
not
None
:
raise
NotImplementedError
(
"Alibi slopes is not supported."
)
if
sliding_window
is
not
None
:
raise
NotImplementedError
(
"Sliding window is not supported."
)
if
kv_cache_dtype
!=
"auto"
:
raise
NotImplementedError
(
"FP8 KV cache dtype is not supported."
)
if
blocksparse_params
is
not
None
:
raise
NotImplementedError
(
"Blocksparse is not supported."
)
if
torch_xla
.
tpu
.
version
()
<
4
:
raise
NotImplementedError
(
"TPU version must be 4 or higher."
)
self
.
megacore_mode
=
None
tpu_type
=
torch_xla
.
tpu
.
get_tp_groupu_env
()[
"TYPE"
].
lower
()
if
not
tpu_type
.
endswith
(
"lite"
):
if
self
.
num_kv_heads
%
2
==
0
:
self
.
megacore_mode
=
"kv_head"
else
:
# NOTE(woosuk): If the batch size is not a multiple of 2, the
# megacore mode will be None.
self
.
megacore_mode
=
"batch"
def
forward
(
self
,
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
kv_cache
:
Tuple
[
Optional
[
torch
.
Tensor
],
Optional
[
torch
.
Tensor
]],
attn_metadata
:
PallasMetadata
,
kv_scale
:
float
=
1.0
,
)
->
torch
.
Tensor
:
"""Forward pass with Pallas attention.
Args:
query: shape = [batch_size, seq_len, num_heads * head_size]
key: shape = [batch_size, seq_len, num_kv_heads * head_size]
value: shape = [batch_size, seq_len, num_kv_heads * head_size]
key_cache = [num_kv_heads, num_blocks, block_size, head_size]
value_cache = [num_kv_heads, num_blocks, block_size, head_size]
attn_metadata: Metadata for attention.
Returns:
shape = [batch_size, seq_len, num_heads * head_size]
"""
assert
kv_scale
==
1.0
batch_size
,
seq_len
,
hidden_size
=
query
.
shape
query
=
query
.
view
(
batch_size
,
seq_len
,
self
.
num_heads
,
self
.
head_size
)
key
=
key
.
view
(
batch_size
,
seq_len
,
self
.
num_kv_heads
,
self
.
head_size
)
value
=
value
.
view
(
batch_size
,
seq_len
,
self
.
num_kv_heads
,
self
.
head_size
)
if
kv_cache
[
0
]
is
not
None
:
slot_mapping
=
attn_metadata
.
slot_mapping
key_cache
,
value_cache
=
kv_cache
write_to_kv_cache
(
key
,
value
,
key_cache
,
value_cache
,
slot_mapping
)
query
=
query
*
self
.
scale
if
attn_metadata
.
num_prefills
>
0
:
assert
seq_len
%
16
==
0
,
(
"Pallas FlashAttention kernel requires seq_len to be a "
f
"multiple of 16 but got
{
seq_len
}
"
)
# Handle GQA/MQA.
if
self
.
num_kv_heads
!=
self
.
num_heads
:
key
=
key
.
repeat_interleave
(
self
.
num_queries_per_kv
,
dim
=-
2
)
key
=
key
.
view
(
batch_size
,
seq_len
,
self
.
num_heads
,
self
.
head_size
)
value
=
value
.
repeat_interleave
(
self
.
num_queries_per_kv
,
dim
=-
2
)
value
=
value
.
view
(
batch_size
,
seq_len
,
self
.
num_heads
,
self
.
head_size
)
# FlashAttention requires [batch_size, num_heads, seq_len, d_model]
# while the input is [batch_size, seq_len, num_heads, d_model].
# Permute the input to match the required format.
output
=
torch
.
ops
.
xla
.
flash_attention
(
query
.
permute
(
0
,
2
,
1
,
3
),
key
.
permute
(
0
,
2
,
1
,
3
),
value
.
permute
(
0
,
2
,
1
,
3
),
True
,
)
output
=
output
.
permute
(
0
,
2
,
1
,
3
)
else
:
# Decoding run.
assert
kv_cache
is
not
None
pages_per_compute_block
=
16
# TODO(woosuk): Tune this value.
if
self
.
megacore_mode
==
"batch"
and
batch_size
%
2
!=
0
:
megacore_mode
=
None
else
:
megacore_mode
=
self
.
megacore_mode
# NOTE(woosuk): A temporary workaround to avoid the error:
# "xla::paged_attention() Expected a value of type 'str' for
# argument 'megacore_mode' but instead found type 'NoneType'."
if
megacore_mode
is
not
None
:
output
=
torch
.
ops
.
xla
.
paged_attention
(
query
.
squeeze
(
dim
=
1
),
key_cache
,
value_cache
,
attn_metadata
.
context_lens
,
attn_metadata
.
block_tables
,
pages_per_compute_block
,
megacore_mode
=
megacore_mode
,
)
else
:
output
=
torch
.
ops
.
xla
.
paged_attention
(
query
.
squeeze
(
dim
=
1
),
key_cache
,
value_cache
,
attn_metadata
.
context_lens
,
attn_metadata
.
block_tables
,
pages_per_compute_block
,
)
# Reshape the output tensor.
return
output
.
reshape
(
batch_size
,
seq_len
,
hidden_size
)
def
write_to_kv_cache
(
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
key_cache
:
torch
.
Tensor
,
value_cache
:
torch
.
Tensor
,
slot_mapping
:
torch
.
Tensor
,
)
->
None
:
torch
.
ops
.
xla
.
dynamo_set_buffer_donor_
(
key_cache
,
True
)
torch
.
ops
.
xla
.
dynamo_set_buffer_donor_
(
value_cache
,
True
)
key
=
key
.
flatten
(
0
,
2
)
value
=
value
.
flatten
(
0
,
2
)
key_cache
=
key_cache
.
flatten
(
0
,
2
)
value_cache
=
value_cache
.
flatten
(
0
,
2
)
key_cache
.
index_copy_
(
0
,
slot_mapping
,
key
)
value_cache
.
index_copy_
(
0
,
slot_mapping
,
value
)
vllm/attention/backends/torch_sdpa.py
View file @
a5753ff5
...
...
@@ -8,8 +8,16 @@ from torch.nn.functional import scaled_dot_product_attention
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
AttentionMetadata
)
from
vllm.attention.ops.paged_attn
import
(
PagedAttention
,
PagedAttentionMetadata
)
from
vllm.attention.ops.paged_attn
import
PagedAttentionMetadata
from
vllm.utils
import
is_cpu
if
is_cpu
():
try
:
from
vllm.attention.ops.ipex_attn
import
PagedAttention
except
ImportError
:
from
vllm.attention.ops.paged_attn
import
PagedAttention
else
:
from
vllm.attention.ops.paged_attn
import
PagedAttention
class
TorchSDPABackend
(
AttentionBackend
):
...
...
@@ -197,13 +205,14 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
attn_metadata
.
attn_bias
):
end
=
start
+
seq_len
sub_out
=
scaled_dot_product_attention
(
query
[:,
start
:
end
,
:],
key
[:,
start
:
end
,
:],
value
[:,
start
:
end
,
:],
query
[
None
,
:,
start
:
end
,
:],
key
[
None
,
:,
start
:
end
,
:],
value
[
None
,
:,
start
:
end
,
:],
attn_mask
=
mask
,
dropout_p
=
0.0
,
is_causal
=
not
self
.
need_mask
,
scale
=
self
.
scale
).
movedim
(
query
.
dim
()
-
2
,
0
)
scale
=
self
.
scale
).
squeeze
(
0
).
movedim
(
query
.
dim
()
-
2
,
0
)
output
[
start
:
end
,
:,
:]
=
sub_out
start
=
end
else
:
...
...
@@ -248,7 +257,7 @@ def _make_alibi_bias(
num_heads
=
alibi_slopes
.
shape
[
0
]
bias
=
bias
[
None
,
:].
repeat
((
num_heads
,
1
,
1
))
bias
.
mul_
(
alibi_slopes
[:,
None
,
None
])
bias
.
mul_
(
alibi_slopes
[:,
None
,
None
])
.
unsqueeze_
(
0
)
inf_mask
=
torch
.
empty
(
(
1
,
seq_len
,
seq_len
),
dtype
=
bias
.
dtype
).
fill_
(
-
torch
.
inf
).
triu_
(
diagonal
=
1
)
...
...
vllm/attention/ops/ipex_attn.py
0 → 100644
View file @
a5753ff5
from
typing
import
Dict
,
List
,
Optional
,
Tuple
import
intel_extension_for_pytorch.llm.modules
as
ipex_modules
import
torch
from
vllm
import
_custom_ops
as
ops
class
PagedAttention
:
@
staticmethod
def
get_supported_head_sizes
()
->
List
[
int
]:
return
[
64
,
80
,
96
,
112
,
128
,
256
]
@
staticmethod
def
get_kv_cache_shape
(
num_blocks
:
int
,
block_size
:
int
,
num_kv_heads
:
int
,
head_size
:
int
,
*
args
,
)
->
Tuple
[
int
,
...]:
return
(
2
,
num_blocks
,
block_size
*
num_kv_heads
*
head_size
)
@
staticmethod
def
split_kv_cache
(
kv_cache
:
torch
.
Tensor
,
num_kv_heads
:
int
,
head_size
:
int
,
*
args
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
num_blocks
=
kv_cache
.
shape
[
1
]
key_cache
=
kv_cache
[
0
]
key_cache
=
key_cache
.
view
(
num_blocks
,
num_kv_heads
,
-
1
,
head_size
)
value_cache
=
kv_cache
[
1
]
value_cache
=
value_cache
.
view
(
num_blocks
,
num_kv_heads
,
-
1
,
head_size
)
return
key_cache
,
value_cache
@
staticmethod
def
write_to_paged_cache
(
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
key_cache
:
torch
.
Tensor
,
value_cache
:
torch
.
Tensor
,
slot_mapping
:
torch
.
Tensor
,
kv_cache_dtype
:
str
,
kv_scale
:
float
,
*
args
,
)
->
None
:
ipex_modules
.
PagedAttention
.
reshape_and_cache
(
key
,
value
,
key_cache
,
value_cache
,
slot_mapping
.
flatten
().
int
())
@
staticmethod
def
forward_decode
(
query
:
torch
.
Tensor
,
key_cache
:
torch
.
Tensor
,
value_cache
:
torch
.
Tensor
,
block_tables
:
torch
.
Tensor
,
context_lens
:
torch
.
Tensor
,
max_context_len
:
int
,
kv_cache_dtype
:
str
,
num_kv_heads
:
int
,
scale
:
float
,
alibi_slopes
:
Optional
[
torch
.
Tensor
],
kv_scale
:
float
,
*
args
,
)
->
torch
.
Tensor
:
output
=
torch
.
empty_like
(
query
)
block_size
=
value_cache
.
shape
[
2
]
head_mapping
=
torch
.
arange
(
0
,
num_kv_heads
,
device
=
"cpu"
,
dtype
=
torch
.
int32
,
).
view
(
num_kv_heads
,
1
).
repeat_interleave
(
query
.
size
(
1
)
//
num_kv_heads
).
flatten
()
ipex_modules
.
PagedAttention
.
single_query_cached_kv_attention
(
output
,
query
.
contiguous
(),
key_cache
,
value_cache
,
head_mapping
,
scale
,
block_tables
,
context_lens
,
block_size
,
max_context_len
,
alibi_slopes
)
return
output
@
staticmethod
def
forward_prefix
(
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
key_cache
:
torch
.
Tensor
,
value_cache
:
torch
.
Tensor
,
block_tables
:
torch
.
Tensor
,
subquery_start_loc
:
torch
.
Tensor
,
prompt_lens_tensor
:
torch
.
Tensor
,
context_lens
:
torch
.
Tensor
,
max_subquery_len
:
int
,
alibi_slopes
:
Optional
[
torch
.
Tensor
],
*
args
,
)
->
torch
.
Tensor
:
raise
NotImplementedError
@
staticmethod
def
swap_blocks
(
src_kv_cache
:
torch
.
Tensor
,
dst_kv_cache
:
torch
.
Tensor
,
src_to_dst
:
Dict
[
int
,
int
],
*
args
,
)
->
None
:
raise
NotImplementedError
@
staticmethod
def
copy_blocks
(
kv_caches
:
List
[
torch
.
Tensor
],
src_to_dists
:
Dict
[
int
,
List
[
int
]],
*
args
,
)
->
None
:
key_caches
=
[
kv_cache
[
0
]
for
kv_cache
in
kv_caches
]
value_caches
=
[
kv_cache
[
1
]
for
kv_cache
in
kv_caches
]
ops
.
copy_blocks
(
key_caches
,
value_caches
,
src_to_dists
)
vllm/attention/selector.py
View file @
a5753ff5
...
...
@@ -7,7 +7,7 @@ import torch
import
vllm.envs
as
envs
from
vllm.attention.backends.abstract
import
AttentionBackend
from
vllm.logger
import
init_logger
from
vllm.utils
import
is_cpu
,
is_hip
from
vllm.utils
import
is_cpu
,
is_hip
,
is_tpu
logger
=
init_logger
(
__name__
)
...
...
@@ -18,6 +18,7 @@ class _Backend(enum.Enum):
ROCM_FLASH
=
enum
.
auto
()
TORCH_SDPA
=
enum
.
auto
()
FLASHINFER
=
enum
.
auto
()
PALLAS
=
enum
.
auto
()
@
lru_cache
(
maxsize
=
None
)
...
...
@@ -57,6 +58,9 @@ def get_attn_backend(
ROCmFlashAttentionBackend
)
return
ROCmFlashAttentionBackend
elif
backend
==
_Backend
.
TORCH_SDPA
:
# TODO: make XPU backend available here.
assert
is_cpu
(),
RuntimeError
(
"Torch SDPA backend is only used for the CPU device."
)
logger
.
info
(
"Using Torch SDPA backend."
)
from
vllm.attention.backends.torch_sdpa
import
TorchSDPABackend
return
TorchSDPABackend
...
...
@@ -66,6 +70,10 @@ def get_attn_backend(
"Please make sure --enforce-eager is set."
)
from
vllm.attention.backends.flashinfer
import
FlashInferBackend
return
FlashInferBackend
elif
backend
==
_Backend
.
PALLAS
:
logger
.
info
(
"Using Pallas backend."
)
from
vllm.attention.backends.pallas
import
PallasAttentionBackend
return
PallasAttentionBackend
else
:
raise
ValueError
(
"Invalid attention backend."
)
...
...
@@ -80,7 +88,6 @@ def which_attn_to_use(
block_size
:
int
,
)
->
_Backend
:
"""Returns which flash attention backend to use."""
# Default case.
selected_backend
=
_Backend
.
FLASH_ATTN
...
...
@@ -100,6 +107,11 @@ def which_attn_to_use(
logger
.
info
(
"Cannot use %s backend on CPU."
,
selected_backend
)
return
_Backend
.
TORCH_SDPA
if
is_tpu
():
if
selected_backend
!=
_Backend
.
PALLAS
:
logger
.
info
(
"Cannot use %s backend on TPU."
,
selected_backend
)
return
_Backend
.
PALLAS
if
is_hip
():
# AMD GPUs.
selected_backend
=
(
_Backend
.
ROCM_FLASH
if
selected_backend
...
...
vllm/config.py
View file @
a5753ff5
...
...
@@ -11,7 +11,8 @@ from vllm.logger import init_logger
from
vllm.model_executor.layers.quantization
import
QUANTIZATION_METHODS
from
vllm.model_executor.models
import
ModelRegistry
from
vllm.transformers_utils.config
import
get_config
,
get_hf_text_config
from
vllm.utils
import
get_cpu_memory
,
is_cpu
,
is_hip
,
is_neuron
from
vllm.utils
import
(
cuda_device_count_stateless
,
get_cpu_memory
,
is_cpu
,
is_hip
,
is_neuron
,
is_tpu
)
if
TYPE_CHECKING
:
from
ray.util.placement_group
import
PlacementGroup
...
...
@@ -212,7 +213,7 @@ class ModelConfig:
f
"
{
self
.
quantization
}
quantization is currently not "
f
"supported in ROCm."
)
if
(
self
.
quantization
not
in
[
"marlin"
,
"gptq_marlin_24"
,
"gptq_marlin"
]
):
not
in
(
"fp8"
,
"marlin"
,
"gptq_marlin_24"
,
"gptq_marlin"
)
):
logger
.
warning
(
"%s quantization is not fully "
"optimized yet. The speed can be slower than "
...
...
@@ -605,12 +606,11 @@ class ParallelConfig:
if
self
.
distributed_executor_backend
is
None
and
self
.
world_size
>
1
:
# We use multiprocessing by default if world_size fits on the
# current node and we aren't in a ray placement group.
from
torch.cuda
import
device_count
from
vllm.executor
import
ray_utils
backend
=
"mp"
ray_found
=
ray_utils
.
ray
is
not
None
if
device_count
()
<
self
.
world_size
:
if
cuda_
device_count
_stateless
()
<
self
.
world_size
:
if
not
ray_found
:
raise
ValueError
(
"Unable to load Ray which is "
"required for multi-node inference"
)
...
...
@@ -748,6 +748,8 @@ class DeviceConfig:
# Automated device type detection
if
is_neuron
():
self
.
device_type
=
"neuron"
elif
is_tpu
():
self
.
device_type
=
"tpu"
elif
is_cpu
():
self
.
device_type
=
"cpu"
else
:
...
...
@@ -761,6 +763,8 @@ class DeviceConfig:
# Some device types require processing inputs on CPU
if
self
.
device_type
in
[
"neuron"
]:
self
.
device
=
torch
.
device
(
"cpu"
)
elif
self
.
device_type
in
[
"tpu"
]:
self
.
device
=
None
else
:
# Set device with device type
self
.
device
=
torch
.
device
(
self
.
device_type
)
...
...
vllm/core/scheduler.py
View file @
a5753ff5
...
...
@@ -50,8 +50,8 @@ class SchedulingBudget:
"""
token_budget
:
int
max_num_seqs
:
int
_reques
e
t_ids_num_batched_tokens
:
Set
[
str
]
=
field
(
default_factory
=
set
)
_reques
e
t_ids_num_curr_seqs
:
Set
[
str
]
=
field
(
default_factory
=
set
)
_request_ids_num_batched_tokens
:
Set
[
str
]
=
field
(
default_factory
=
set
)
_request_ids_num_curr_seqs
:
Set
[
str
]
=
field
(
default_factory
=
set
)
_num_batched_tokens
:
int
=
0
_num_curr_seqs
:
int
=
0
...
...
@@ -65,28 +65,28 @@ class SchedulingBudget:
return
self
.
token_budget
-
self
.
num_batched_tokens
def
add_num_batched_tokens
(
self
,
req_id
:
str
,
num_batched_tokens
:
int
):
if
req_id
in
self
.
_reques
e
t_ids_num_batched_tokens
:
if
req_id
in
self
.
_request_ids_num_batched_tokens
:
return
self
.
_reques
e
t_ids_num_batched_tokens
.
add
(
req_id
)
self
.
_request_ids_num_batched_tokens
.
add
(
req_id
)
self
.
_num_batched_tokens
+=
num_batched_tokens
def
subtract_num_batched_tokens
(
self
,
req_id
:
str
,
num_batched_tokens
:
int
):
if
req_id
in
self
.
_reques
e
t_ids_num_batched_tokens
:
self
.
_reques
e
t_ids_num_batched_tokens
.
remove
(
req_id
)
if
req_id
in
self
.
_request_ids_num_batched_tokens
:
self
.
_request_ids_num_batched_tokens
.
remove
(
req_id
)
self
.
_num_batched_tokens
-=
num_batched_tokens
def
add_num_seqs
(
self
,
req_id
:
str
,
num_curr_seqs
:
int
):
if
req_id
in
self
.
_reques
e
t_ids_num_curr_seqs
:
if
req_id
in
self
.
_request_ids_num_curr_seqs
:
return
self
.
_reques
e
t_ids_num_curr_seqs
.
add
(
req_id
)
self
.
_request_ids_num_curr_seqs
.
add
(
req_id
)
self
.
_num_curr_seqs
+=
num_curr_seqs
def
subtract_num_seqs
(
self
,
req_id
:
str
,
num_curr_seqs
:
int
):
if
req_id
in
self
.
_reques
e
t_ids_num_curr_seqs
:
self
.
_reques
e
t_ids_num_curr_seqs
.
remove
(
req_id
)
if
req_id
in
self
.
_request_ids_num_curr_seqs
:
self
.
_request_ids_num_curr_seqs
.
remove
(
req_id
)
self
.
_num_curr_seqs
-=
num_curr_seqs
@
property
...
...
vllm/distributed/communication_op.py
View file @
a5753ff5
from
collections
import
namedtuple
from
contextlib
import
contextmanager
,
nullcontext
from
dataclasses
import
dataclass
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Union
from
typing
import
Any
,
Dict
,
Optional
,
Union
import
torch
from
torch.distributed
import
ProcessGroup
import
torch.distributed
from
.parallel_state
import
(
get_cpu_world_group
,
get_pp_pynccl_communicator
,
get_tensor_model_parallel_group
,
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
,
get_tp_ca_communicator
,
get_tp_pynccl_communicator
)
@
dataclass
class
GraphCaptureContext
:
stream
:
torch
.
cuda
.
Stream
@
contextmanager
def
graph_capture
():
"""
`graph_capture` is a context manager which should surround the code that
is capturing the CUDA graph. Its main purpose is to ensure that the
some operations will be run after the graph is captured, before the graph
is replayed. It returns a `GraphCaptureContext` object which contains the
necessary data for the graph capture. Currently, it only contains the
stream that the graph capture is running on. This stream is set to the
current CUDA stream when the context manager is entered and reset to the
default stream when the context manager is exited. This is to ensure that
the graph capture is running on a separate stream from the default stream,
in order to explicitly distinguish the kernels to capture
from other kernels possibly launched on background in the default stream.
"""
stream
=
torch
.
cuda
.
Stream
()
graph_capture_context
=
GraphCaptureContext
(
stream
)
ca_comm
=
get_tp_ca_communicator
()
maybe_ca_context
=
nullcontext
()
if
ca_comm
is
None
else
ca_comm
.
capture
()
with
torch
.
cuda
.
stream
(
stream
),
maybe_ca_context
:
# In graph mode, we have to be very careful about the collective
# operations. The current status is:
# allreduce \ Mode | Eager | Graph |
# --------------------------------------------
# custom allreduce | enabled | enabled |
# PyNccl | disabled| enabled |
# torch.distributed | enabled | disabled|
#
# Note that custom allreduce will have a runtime check, if the tensor
# size is too large, it will fallback to the next available option.
# In summary: When using CUDA graph, we use
# either custom all-reduce kernel or pynccl. When not using CUDA
# graph, we use either custom all-reduce kernel or PyTorch NCCL.
# We always prioritize using custom all-reduce kernel but fall back
# to PyTorch or pynccl if it is disabled or not supported.
tp_pynccl_comm
=
get_tp_pynccl_communicator
()
pp_pynccl_comm
=
get_pp_pynccl_communicator
()
if
not
tp_pynccl_comm
:
maybe_tp_pynccl_context
=
nullcontext
()
else
:
maybe_tp_pynccl_context
=
tp_pynccl_comm
.
change_state
(
enable
=
True
,
stream
=
torch
.
cuda
.
current_stream
())
if
not
pp_pynccl_comm
:
maybe_pp_pynccl_context
=
nullcontext
()
else
:
maybe_pp_pynccl_context
=
pp_pynccl_comm
.
change_state
(
enable
=
True
,
stream
=
torch
.
cuda
.
current_stream
())
with
maybe_tp_pynccl_context
,
maybe_pp_pynccl_context
:
yield
graph_capture_context
from
.parallel_state
import
get_tp_group
def
tensor_model_parallel_all_reduce
(
input_
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""All-reduce the input tensor across model parallel group.
NOTE: This operation will be applied in-place on the input tensor if
disable_custom_all_reduce is set to True. Otherwise, this operation may or
may not be applied in place depending on whether custom all reduce is
invoked for a particular tensor, which further depends on the tensor size
and GPU topology.
TLDR: always assume this function modifies its input, but use the return
value as the output.
"""
ca_comm
=
get_tp_ca_communicator
()
# Bypass the function if we are using only 1 GPU.
if
get_tensor_model_parallel_world_size
()
==
1
:
return
input_
if
ca_comm
is
not
None
:
out
=
ca_comm
.
custom_all_reduce
(
input_
)
if
out
is
not
None
:
return
out
pynccl_comm
=
get_tp_pynccl_communicator
()
if
(
pynccl_comm
is
not
None
and
not
pynccl_comm
.
disabled
):
pynccl_comm
.
all_reduce
(
input_
)
else
:
torch
.
distributed
.
all_reduce
(
input_
,
group
=
get_tensor_model_parallel_group
())
return
input_
"""All-reduce the input tensor across model parallel group."""
return
get_tp_group
().
all_reduce
(
input_
)
def
tensor_model_parallel_all_gather
(
input_
:
torch
.
Tensor
,
dim
:
int
=
-
1
)
->
torch
.
Tensor
:
"""All-gather the input tensor across model parallel group."""
world_size
=
get_tensor_model_parallel_world_size
()
# Bypass the function if we are using only 1 GPU.
if
world_size
==
1
:
return
input_
assert
-
input_
.
dim
()
<=
dim
<
input_
.
dim
(),
(
f
"Invalid dim (
{
dim
}
) for input tensor with shape
{
input_
.
size
()
}
"
)
if
dim
<
0
:
# Convert negative dim to positive.
dim
+=
input_
.
dim
()
input_size
=
input_
.
size
()
# Allocate output tensor.
output_tensor
=
torch
.
empty
((
world_size
,
)
+
input_size
,
dtype
=
input_
.
dtype
,
device
=
input_
.
device
)
# All-gather.
torch
.
distributed
.
all_gather_into_tensor
(
output_tensor
,
input_
,
group
=
get_tensor_model_parallel_group
())
# Reshape
output_tensor
=
output_tensor
.
movedim
(
0
,
dim
)
output_tensor
=
output_tensor
.
reshape
(
input_size
[:
dim
]
+
(
world_size
*
input_size
[
dim
],
)
+
input_size
[
dim
+
1
:])
return
output_tensor
return
get_tp_group
().
all_gather
(
input_
,
dim
)
def
tensor_model_parallel_gather
(
input_
:
torch
.
Tensor
,
dst
:
int
=
0
,
dim
:
int
=
-
1
)
->
torch
.
Tensor
:
"""Gather the input tensor across model parallel group.
NOTE: We assume that the input tensor is on the same device across
all the ranks.
"""
world_size
=
get_tensor_model_parallel_world_size
()
# Bypass the function if we are using only 1 GPU.
if
world_size
==
1
:
return
input_
assert
-
input_
.
dim
()
<=
dim
<
input_
.
dim
(),
(
f
"Invalid dim (
{
dim
}
) for input tensor with shape
{
input_
.
size
()
}
"
)
if
dim
<
0
:
# Convert negative dim to positive.
dim
+=
input_
.
dim
()
# Allocate output tensor.
if
get_tensor_model_parallel_rank
()
==
dst
:
gather_list
=
[
torch
.
empty_like
(
input_
)
for
_
in
range
(
world_size
)]
else
:
gather_list
=
None
# Gather.
torch
.
distributed
.
gather
(
input_
,
gather_list
,
dst
=
dst
,
group
=
get_tensor_model_parallel_group
())
if
get_tensor_model_parallel_rank
()
==
dst
:
output_tensor
=
torch
.
cat
(
gather_list
,
dim
=
dim
)
else
:
output_tensor
=
None
return
output_tensor
def
broadcast
(
input_
:
torch
.
Tensor
,
src
:
int
=
0
,
group
:
Optional
[
ProcessGroup
]
=
None
):
"""Broadcast the input tensor."""
group
=
group
or
torch
.
distributed
.
group
.
WORLD
ranks
=
torch
.
distributed
.
get_process_group_ranks
(
group
)
assert
src
in
ranks
,
f
"Invalid src rank (
{
src
}
)"
# Bypass the function if we are using only 1 GPU.
world_size
=
torch
.
distributed
.
get_world_size
(
group
=
group
)
if
world_size
==
1
:
return
input_
# Broadcast.
torch
.
distributed
.
broadcast
(
input_
,
src
=
src
,
group
=
group
)
return
input_
def
broadcast_object_list
(
obj_list
:
List
[
Any
],
src
:
int
=
0
,
group
:
Optional
[
ProcessGroup
]
=
None
):
"""Broadcast the input object list."""
group
=
group
or
torch
.
distributed
.
group
.
WORLD
ranks
=
torch
.
distributed
.
get_process_group_ranks
(
group
)
assert
src
in
ranks
,
f
"Invalid src rank (
{
src
}
)"
# Bypass the function if we are using only 1 GPU.
world_size
=
torch
.
distributed
.
get_world_size
(
group
=
group
)
if
world_size
==
1
:
return
obj_list
# Broadcast.
torch
.
distributed
.
broadcast_object_list
(
obj_list
,
src
=
src
,
group
=
group
)
return
obj_list
TensorMetadata
=
namedtuple
(
"TensorMetadata"
,
[
"device"
,
"dtype"
,
"size"
])
def
_split_tensor_dict
(
tensor_dict
:
Dict
[
Any
,
Union
[
torch
.
Tensor
,
Any
]]
)
->
Tuple
[
List
[
Tuple
[
str
,
Any
]],
List
[
torch
.
Tensor
]]:
"""Split the tensor dictionary into two parts:
1. A list of (key, value) pairs. If the value is a tensor, it is replaced
by its metadata.
2. A list of tensors.
"""
metadata_list
=
[]
tensor_list
=
[]
for
key
,
value
in
tensor_dict
.
items
():
if
isinstance
(
value
,
torch
.
Tensor
):
# Note: we cannot use `value.device` here,
# because it contains not only the device type but also the device
# index (e.g. "cuda:0"). We only need the device type.
# receiving side will set the device index.
device
=
"cpu"
if
value
.
is_cpu
else
"cuda"
metadata_list
.
append
(
(
key
,
TensorMetadata
(
device
,
value
.
dtype
,
value
.
size
())))
tensor_list
.
append
(
value
)
else
:
metadata_list
.
append
((
key
,
value
))
return
metadata_list
,
tensor_list
def
broadcast_tensor_dict
(
tensor_dict
:
Optional
[
Dict
[
Any
,
Union
[
torch
.
Tensor
,
Any
]]]
=
None
,
src
:
int
=
0
,
group
:
Optional
[
ProcessGroup
]
=
None
,
metadata_group
:
Optional
[
ProcessGroup
]
=
None
)
->
Optional
[
Dict
[
Any
,
Union
[
torch
.
Tensor
,
Any
]]]:
"""Broadcast the input tensor dictionary.
`group` is used to broadcast the tensors, while `metadata_group` is used
to broadcast the metadata of the dict (e.g. dict structure, tensor sizes,
dtypes).
"""
# Bypass the function if we are using only 1 GPU.
if
(
not
torch
.
distributed
.
is_initialized
()
or
torch
.
distributed
.
get_world_size
(
group
=
group
)
==
1
):
return
tensor_dict
group
=
group
or
torch
.
distributed
.
group
.
WORLD
metadata_group
=
metadata_group
or
get_cpu_world_group
()
ranks
=
torch
.
distributed
.
get_process_group_ranks
(
group
)
assert
src
in
ranks
,
f
"Invalid src rank (
{
src
}
)"
"""Gather the input tensor across model parallel group."""
return
get_tp_group
().
gather
(
input_
,
dst
,
dim
)
rank
=
torch
.
distributed
.
get_rank
()
if
rank
==
src
:
metadata_list
:
List
[
Tuple
[
Any
,
Any
]]
=
[]
assert
isinstance
(
tensor_dict
,
dict
),
(
f
"Expecting a dictionary, got
{
type
(
tensor_dict
)
}
"
)
metadata_list
,
tensor_list
=
_split_tensor_dict
(
tensor_dict
)
# `metadata_list` lives in CPU memory.
# `broadcast_object_list` involves serialization and deserialization,
# all happening on CPU. Therefore, we can use the CPU group.
torch
.
distributed
.
broadcast_object_list
([
metadata_list
],
src
=
src
,
group
=
metadata_group
)
async_handles
=
[]
for
tensor
in
tensor_list
:
if
tensor
.
numel
()
==
0
:
# Skip broadcasting empty tensors.
continue
if
tensor
.
is_cpu
:
# use metadata_group for CPU tensors
handle
=
torch
.
distributed
.
broadcast
(
tensor
,
src
=
src
,
group
=
metadata_group
,
async_op
=
True
)
else
:
# use group for GPU tensors
handle
=
torch
.
distributed
.
broadcast
(
tensor
,
src
=
src
,
group
=
group
,
async_op
=
True
)
async_handles
.
append
(
handle
)
for
async_handle
in
async_handles
:
async_handle
.
wait
()
else
:
recv_metadata_list
=
[
None
]
torch
.
distributed
.
broadcast_object_list
(
recv_metadata_list
,
src
=
src
,
group
=
metadata_group
)
assert
recv_metadata_list
[
0
]
is
not
None
tensor_dict
=
{}
async_handles
=
[]
for
key
,
value
in
recv_metadata_list
[
0
]:
if
isinstance
(
value
,
TensorMetadata
):
tensor
=
torch
.
empty
(
value
.
size
,
dtype
=
value
.
dtype
,
device
=
value
.
device
)
if
tensor
.
numel
()
==
0
:
# Skip broadcasting empty tensors.
tensor_dict
[
key
]
=
tensor
continue
if
tensor
.
is_cpu
:
# use metadata_group for CPU tensors
handle
=
torch
.
distributed
.
broadcast
(
tensor
,
src
=
src
,
group
=
metadata_group
,
async_op
=
True
)
else
:
# use group for GPU tensors
handle
=
torch
.
distributed
.
broadcast
(
tensor
,
src
=
src
,
group
=
group
,
async_op
=
True
)
async_handles
.
append
(
handle
)
tensor_dict
[
key
]
=
tensor
else
:
tensor_dict
[
key
]
=
value
for
async_handle
in
async_handles
:
async_handle
.
wait
()
def
broadcast_tensor_dict
(
tensor_dict
:
Optional
[
Dict
[
Any
,
Union
[
torch
.
Tensor
,
Any
]]]
=
None
,
src
:
int
=
0
):
if
not
torch
.
distributed
.
is_initialized
():
return
tensor_dict
return
get_tp_group
().
broadcast_tensor_dict
(
tensor_dict
,
src
)
vllm/distributed/device_communicators/custom_all_reduce.py
View file @
a5753ff5
...
...
@@ -9,9 +9,9 @@ import vllm.envs as envs
from
vllm
import
_custom_ops
as
ops
from
vllm.distributed.device_communicators.custom_all_reduce_utils
import
(
gpu_p2p_access_check
)
from
vllm.distributed.parallel_state
import
(
get_local_rank
,
get_tensor_model_parallel_cpu_group
,
is_in_the_same_node
)
from
vllm.distributed.parallel_state
import
is_in_the_same_node
from
vllm.logger
import
init_logger
from
vllm.utils
import
cuda_device_count_stateless
try
:
import
pynvml
...
...
@@ -86,8 +86,8 @@ class CustomAllreduce:
# max_size: max supported allreduce size
def
__init__
(
self
,
group
:
Optional
[
ProcessGroup
]
=
None
,
device
:
Optional
[
Union
[
int
,
str
,
torch
.
device
]
]
=
None
,
group
:
ProcessGroup
,
device
:
Union
[
int
,
str
,
torch
.
device
],
max_size
=
8192
*
1024
)
->
None
:
"""
Args:
...
...
@@ -107,7 +107,6 @@ class CustomAllreduce:
# e.g. in a non-cuda environment
return
group
=
group
or
get_tensor_model_parallel_cpu_group
()
self
.
group
=
group
assert
dist
.
get_backend
(
group
)
!=
dist
.
Backend
.
NCCL
,
(
...
...
@@ -134,10 +133,7 @@ class CustomAllreduce:
world_size
,
str
(
CustomAllreduce
.
_SUPPORTED_WORLD_SIZES
))
return
if
device
is
None
:
local_rank
=
get_local_rank
()
device
=
torch
.
device
(
f
"cuda:
{
local_rank
}
"
)
elif
isinstance
(
device
,
int
):
if
isinstance
(
device
,
int
):
device
=
torch
.
device
(
f
"cuda:
{
device
}
"
)
elif
isinstance
(
device
,
str
):
device
=
torch
.
device
(
device
)
...
...
@@ -149,7 +145,7 @@ class CustomAllreduce:
if
cuda_visible_devices
:
device_ids
=
list
(
map
(
int
,
cuda_visible_devices
.
split
(
","
)))
else
:
device_ids
=
list
(
range
(
torch
.
cuda
.
device_count
()))
device_ids
=
list
(
range
(
cuda
_
device_count
_stateless
()))
physical_device_id
=
device_ids
[
device
.
index
]
tensor
=
torch
.
tensor
([
physical_device_id
],
...
...
vllm/distributed/device_communicators/custom_all_reduce_utils.py
View file @
a5753ff5
...
...
@@ -11,8 +11,8 @@ import torch.distributed as dist
import
torch.multiprocessing
as
mp
import
vllm.envs
as
envs
from
vllm.distributed.parallel_state
import
get_cpu_world_group
,
get_local_rank
from
vllm.logger
import
init_logger
from
vllm.utils
import
cuda_device_count_stateless
logger
=
init_logger
(
__name__
)
...
...
@@ -153,7 +153,7 @@ def gpu_p2p_access_check(i: int, j: int) -> bool:
is_distributed
=
dist
.
is_initialized
()
num_dev
=
torch
.
cuda
.
device_count
()
num_dev
=
cuda
_
device_count
_stateless
()
cuda_visible_devices
=
envs
.
CUDA_VISIBLE_DEVICES
if
cuda_visible_devices
is
None
:
cuda_visible_devices
=
","
.
join
(
str
(
i
)
for
i
in
range
(
num_dev
))
...
...
@@ -162,7 +162,8 @@ def gpu_p2p_access_check(i: int, j: int) -> bool:
f
"
{
VLLM_CONFIG_ROOT
}
/vllm/gpu_p2p_access_cache_for_
{
cuda_visible_devices
}
.json"
)
os
.
makedirs
(
os
.
path
.
dirname
(
path
),
exist_ok
=
True
)
if
((
not
is_distributed
or
get_local_rank
()
==
0
)
from
vllm.distributed.parallel_state
import
get_world_group
if
((
not
is_distributed
or
get_world_group
().
local_rank
==
0
)
and
(
not
os
.
path
.
exists
(
path
))):
# only the local master process (with local_rank == 0) can
# enter this block to calculate the cache
...
...
@@ -174,8 +175,7 @@ def gpu_p2p_access_check(i: int, j: int) -> bool:
with
open
(
path
,
"w"
)
as
f
:
json
.
dump
(
cache
,
f
,
indent
=
4
)
if
is_distributed
:
cpu_world_group
=
get_cpu_world_group
()
dist
.
barrier
(
cpu_world_group
)
get_world_group
().
barrier
()
logger
.
info
(
"reading GPU P2P access cache from %s"
,
path
)
with
open
(
path
,
"r"
)
as
f
:
cache
=
json
.
load
(
f
)
...
...
vllm/distributed/device_communicators/pynccl.py
View file @
a5753ff5
...
...
@@ -9,7 +9,6 @@ from torch.distributed import ProcessGroup, ReduceOp
from
vllm.distributed.device_communicators.pynccl_wrapper
import
(
NCCLLibrary
,
buffer_type
,
cudaStream_t
,
ncclComm_t
,
ncclDataTypeEnum
,
ncclRedOpTypeEnum
,
ncclUniqueId
)
from
vllm.distributed.parallel_state
import
get_cpu_world_group
,
get_local_rank
from
vllm.logger
import
init_logger
logger
=
init_logger
(
__name__
)
...
...
@@ -19,8 +18,8 @@ class PyNcclCommunicator:
def
__init__
(
self
,
group
:
Optional
[
ProcessGroup
]
=
None
,
device
:
Optional
[
Union
[
int
,
str
,
torch
.
device
]
]
=
None
,
group
:
ProcessGroup
,
device
:
Union
[
int
,
str
,
torch
.
device
],
library_path
:
Optional
[
str
]
=
None
,
):
"""
...
...
@@ -35,7 +34,6 @@ class PyNcclCommunicator:
is bind to a unique device.
"""
assert
dist
.
is_initialized
()
group
=
get_cpu_world_group
()
if
group
is
None
else
group
assert
dist
.
get_backend
(
group
)
!=
dist
.
Backend
.
NCCL
,
(
"PyNcclCommunicator should be attached to a non-NCCL group."
)
self
.
group
=
group
...
...
@@ -77,10 +75,7 @@ class PyNcclCommunicator:
byte_list
=
tensor
.
tolist
()
for
i
,
byte
in
enumerate
(
byte_list
):
self
.
unique_id
.
internal
[
i
]
=
byte
if
device
is
None
:
local_rank
=
get_local_rank
()
device
=
torch
.
device
(
f
"cuda:
{
local_rank
}
"
)
elif
isinstance
(
device
,
int
):
if
isinstance
(
device
,
int
):
device
=
torch
.
device
(
f
"cuda:
{
device
}
"
)
elif
isinstance
(
device
,
str
):
device
=
torch
.
device
(
device
)
...
...
vllm/distributed/parallel_state.py
View file @
a5753ff5
This diff is collapsed.
Click to expand it.
vllm/engine/arg_utils.py
View file @
a5753ff5
...
...
@@ -504,7 +504,7 @@ class EngineArgs:
parser
.
add_argument
(
"--device"
,
type
=
str
,
default
=
EngineArgs
.
device
,
choices
=
[
"auto"
,
"cuda"
,
"neuron"
,
"cpu"
],
choices
=
[
"auto"
,
"cuda"
,
"neuron"
,
"cpu"
,
"tpu"
],
help
=
'Device type for vLLM execution.'
)
# Related to Vision-language models such as llava
...
...
vllm/engine/async_llm_engine.py
View file @
a5753ff5
...
...
@@ -375,6 +375,9 @@ class AsyncLLMEngine:
if
engine_config
.
device_config
.
device_type
==
"neuron"
:
from
vllm.executor.neuron_executor
import
NeuronExecutorAsync
executor_class
=
NeuronExecutorAsync
elif
engine_config
.
device_config
.
device_type
==
"tpu"
:
from
vllm.executor.tpu_executor
import
TPUExecutorAsync
executor_class
=
TPUExecutorAsync
elif
engine_config
.
device_config
.
device_type
==
"cpu"
:
assert
distributed_executor_backend
is
None
,
(
"Distributed execution is not supported with the CPU backend."
)
...
...
vllm/engine/llm_engine.py
View file @
a5753ff5
...
...
@@ -6,7 +6,6 @@ from typing import Type, TypeVar, Union
from
transformers
import
GenerationConfig
,
PreTrainedTokenizer
import
vllm
from
vllm.config
import
(
CacheConfig
,
DecodingConfig
,
DeviceConfig
,
LoadConfig
,
LoRAConfig
,
ModelConfig
,
ParallelConfig
,
SchedulerConfig
,
SpeculativeConfig
,
...
...
@@ -38,6 +37,7 @@ from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup,
from
vllm.usage.usage_lib
import
(
UsageContext
,
is_usage_stats_enabled
,
usage_message
)
from
vllm.utils
import
Counter
from
vllm.version
import
__version__
as
VLLM_VERSION
logger
=
init_logger
(
__name__
)
_LOCAL_LOGGING_INTERVAL_SEC
=
5
...
...
@@ -169,7 +169,7 @@ class LLMEngine:
"enforce_eager=%s, kv_cache_dtype=%s, "
"quantization_param_path=%s, device_config=%s, "
"decoding_config=%r, seed=%d, served_model_name=%s)"
,
vllm
.
__version__
,
VLLM_VERSION
,
model_config
.
model
,
speculative_config
,
model_config
.
tokenizer
,
...
...
@@ -341,6 +341,9 @@ class LLMEngine:
if
engine_config
.
device_config
.
device_type
==
"neuron"
:
from
vllm.executor.neuron_executor
import
NeuronExecutor
executor_class
=
NeuronExecutor
elif
engine_config
.
device_config
.
device_type
==
"tpu"
:
from
vllm.executor.tpu_executor
import
TPUExecutor
executor_class
=
TPUExecutor
elif
engine_config
.
device_config
.
device_type
==
"cpu"
:
from
vllm.executor.cpu_executor
import
CPUExecutor
executor_class
=
CPUExecutor
...
...
vllm/entrypoints/llm.py
View file @
a5753ff5
...
...
@@ -545,11 +545,13 @@ class LLM:
total
=
num_requests
,
desc
=
"Processed prompts"
,
dynamic_ncols
=
True
,
postfix
=
f
"Generation Speed:
{
0
:.
2
f
}
toks/s"
,
postfix
=
(
f
"est. speed input:
{
0
:.
2
f
}
toks/s, "
f
"output:
{
0
:.
2
f
}
toks/s"
),
)
# Run the engine.
outputs
:
List
[
Union
[
RequestOutput
,
EmbeddingRequestOutput
]]
=
[]
total_toks
=
0
total_in_toks
=
0
total_out_toks
=
0
while
self
.
llm_engine
.
has_unfinished_requests
():
step_outputs
=
self
.
llm_engine
.
step
()
for
output
in
step_outputs
:
...
...
@@ -558,10 +560,15 @@ class LLM:
if
use_tqdm
:
if
isinstance
(
output
,
RequestOutput
):
# Calculate tokens only for RequestOutput
total_toks
+=
sum
(
total_in_toks
+=
len
(
output
.
prompt_token_ids
)
in_spd
=
total_in_toks
/
pbar
.
format_dict
[
"elapsed"
]
total_out_toks
+=
sum
(
len
(
stp
.
token_ids
)
for
stp
in
output
.
outputs
)
spd
=
total_toks
/
pbar
.
format_dict
[
"elapsed"
]
pbar
.
postfix
=
f
"Generation Speed:
{
spd
:.
2
f
}
toks/s"
out_spd
=
total_out_toks
/
pbar
.
format_dict
[
"elapsed"
]
pbar
.
postfix
=
(
f
"est. speed input:
{
in_spd
:.
2
f
}
toks/s, "
f
"output:
{
out_spd
:.
2
f
}
toks/s"
)
pbar
.
update
(
1
)
if
use_tqdm
:
pbar
.
close
()
...
...
vllm/entrypoints/openai/api_server.py
View file @
a5753ff5
...
...
@@ -15,7 +15,6 @@ from fastapi.responses import JSONResponse, Response, StreamingResponse
from
prometheus_client
import
make_asgi_app
from
starlette.routing
import
Mount
import
vllm
import
vllm.envs
as
envs
from
vllm.engine.arg_utils
import
AsyncEngineArgs
from
vllm.engine.async_llm_engine
import
AsyncLLMEngine
...
...
@@ -29,6 +28,7 @@ from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
from
vllm.entrypoints.openai.serving_embedding
import
OpenAIServingEmbedding
from
vllm.logger
import
init_logger
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.version
import
__version__
as
VLLM_VERSION
TIMEOUT_KEEP_ALIVE
=
5
# seconds
...
...
@@ -93,7 +93,7 @@ async def show_available_models():
@
app
.
get
(
"/version"
)
async
def
show_version
():
ver
=
{
"version"
:
vllm
.
__version__
}
ver
=
{
"version"
:
VLLM_VERSION
}
return
JSONResponse
(
content
=
ver
)
...
...
@@ -174,7 +174,7 @@ if __name__ == "__main__":
raise
ValueError
(
f
"Invalid middleware
{
middleware
}
. "
f
"Must be a function or a class."
)
logger
.
info
(
"vLLM API server version %s"
,
vllm
.
__version__
)
logger
.
info
(
"vLLM API server version %s"
,
VLLM_VERSION
)
logger
.
info
(
"args: %s"
,
args
)
if
args
.
served_model_name
is
not
None
:
...
...
vllm/entrypoints/openai/run_batch.py
View file @
a5753ff5
...
...
@@ -5,7 +5,6 @@ from io import StringIO
import
aiohttp
import
vllm
from
vllm.engine.arg_utils
import
AsyncEngineArgs
,
nullable_str
from
vllm.engine.async_llm_engine
import
AsyncLLMEngine
from
vllm.entrypoints.openai.protocol
import
(
BatchRequestInput
,
...
...
@@ -15,6 +14,7 @@ from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from
vllm.logger
import
init_logger
from
vllm.usage.usage_lib
import
UsageContext
from
vllm.utils
import
random_uuid
from
vllm.version
import
__version__
as
VLLM_VERSION
logger
=
init_logger
(
__name__
)
...
...
@@ -135,7 +135,7 @@ async def main(args):
if
__name__
==
"__main__"
:
args
=
parse_args
()
logger
.
info
(
"vLLM API server version %s"
,
vllm
.
__version__
)
logger
.
info
(
"vLLM API server version %s"
,
VLLM_VERSION
)
logger
.
info
(
"args: %s"
,
args
)
asyncio
.
run
(
main
(
args
))
vllm/envs.py
View file @
a5753ff5
...
...
@@ -27,6 +27,7 @@ if TYPE_CHECKING:
VLLM_TRACE_FUNCTION
:
int
=
0
VLLM_ATTENTION_BACKEND
:
Optional
[
str
]
=
None
VLLM_CPU_KVCACHE_SPACE
:
int
=
0
VLLM_XLA_CACHE_PATH
:
str
=
"~/.vllm/xla_cache/"
VLLM_USE_RAY_COMPILED_DAG
:
bool
=
False
VLLM_WORKER_MULTIPROC_METHOD
:
str
=
"spawn"
VLLM_IMAGE_FETCH_TIMEOUT
:
int
=
5
...
...
@@ -217,6 +218,11 @@ environment_variables: Dict[str, Callable[[], Any]] = {
# Default is 5 seconds
"VLLM_IMAGE_FETCH_TIMEOUT"
:
lambda
:
int
(
os
.
getenv
(
"VLLM_IMAGE_FETCH_TIMEOUT"
,
"5"
)),
# Path to the XLA persistent cache directory.
# Only used for XLA devices such as TPUs.
"VLLM_XLA_CACHE_PATH"
:
lambda
:
os
.
getenv
(
"VLLM_XLA_CACHE_PATH"
,
"~/.vllm/xla_cache/"
),
}
# end-env-vars-definition
...
...
Prev
1
2
3
4
5
6
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment