norm / vllm · Commits

Commit fcffb7c8, authored Jan 16, 2024 by zhuwenwen

    Merge branch 'vllm-v0.2.7-dtk23.10'

Parents: eb181638, 4095d0db

Showing 20 of 56 changed files, with 253 additions and 226 deletions (+253 -226).
tests/kernels/test_cache.py                  +11  -6
tests/kernels/test_layernorm.py               +6  -3
tests/kernels/test_pos_encoding.py            +7  -4
tests/worker/test_model_runner.py             +3  -2
vllm/__init__.py                              +1  -1
vllm/config.py                                +0  -6
vllm/engine/async_llm_engine.py              +36 -33
vllm/engine/llm_engine.py                   +141 -116
vllm/engine/ray_utils.py                     +22 -15
vllm/entrypoints/api_server.py                +0  -1
vllm/model_executor/input_metadata.py         +4  -5
vllm/model_executor/layers/sampler.py        +14 -26
vllm/model_executor/models/aquila.py          +1  -1
vllm/model_executor/models/baichuan.py        +1  -1
vllm/model_executor/models/bloom.py           +1  -1
vllm/model_executor/models/chatglm.py         +1  -1
vllm/model_executor/models/falcon.py          +1  -1
vllm/model_executor/models/gpt2.py            +1  -1
vllm/model_executor/models/gpt_bigcode.py     +1  -1
vllm/model_executor/models/gpt_j.py           +1  -1
tests/kernels/test_cache.py  (view file @ fcffb7c8)

@@ -14,6 +14,7 @@ BLOCK_SIZES = [8, 16, 32]
 NUM_BLOCKS = [1024, 36000]  # Arbitrary values for testing
 NUM_MAPPINGS = [256]  # Arbitrary values for testing
 SEEDS = [0]
+DEVICES = [i for i in range(1 if torch.cuda.device_count() == 1 else 2)]

 @pytest.mark.parametrize("num_mappings", NUM_MAPPINGS)

@@ -24,6 +25,7 @@ SEEDS = [0]
 @pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("seed", SEEDS)
+@pytest.mark.parametrize("device", DEVICES)
 @torch.inference_mode()
 def test_copy_blocks(
     kv_cache_factory,

@@ -35,11 +37,12 @@ def test_copy_blocks(
     num_blocks: int,
     dtype: torch.dtype,
     seed: int,
+    device: int,
 ) -> None:
     random.seed(seed)
     torch.random.manual_seed(seed)
     torch.cuda.manual_seed(seed)
+    gpu_id = f"cuda:{device}"
     # Generate random block mappings where each source block is mapped to two
     # destination blocks.
     assert 2 * num_mappings <= num_blocks

@@ -56,7 +59,7 @@ def test_copy_blocks(
     # Create the KV caches.
     key_caches, value_caches = kv_cache_factory(num_blocks, block_size,
                                                 num_layers, num_heads,
-                                                head_size, dtype, seed)
+                                                head_size, dtype, seed, gpu_id)
     # Clone the KV caches.
     cloned_key_caches = [key_cache.clone() for key_cache in key_caches]

@@ -88,6 +91,7 @@ def test_copy_blocks(
 @pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("seed", SEEDS)
+@pytest.mark.parametrize("device", DEVICES)
 @torch.inference_mode()
 def test_reshape_and_cache(
     kv_cache_factory,

@@ -98,28 +102,29 @@ def test_reshape_and_cache(
     num_blocks: int,
     dtype: torch.dtype,
     seed: int,
+    device: int,
 ) -> None:
     random.seed(seed)
     torch.random.manual_seed(seed)
     torch.cuda.manual_seed(seed)
+    gpu_id = f"cuda:{device}"
     # Create a random slot mapping.
     num_slots = block_size * num_blocks
     slot_mapping = random.sample(range(num_slots), num_tokens)
-    slot_mapping = torch.tensor(slot_mapping, dtype=torch.long, device="cuda")
+    slot_mapping = torch.tensor(slot_mapping, dtype=torch.long, device=gpu_id)
     qkv = torch.randn(num_tokens,
                       3,
                       num_heads,
                       head_size,
                       dtype=dtype,
-                      device="cuda")
+                      device=gpu_id)
     _, key, value = qkv.unbind(dim=1)
     # Create the KV caches.
     key_caches, value_caches = kv_cache_factory(num_blocks, block_size, 1,
                                                 num_heads, head_size, dtype,
-                                                seed)
+                                                seed, gpu_id)
     key_cache, value_cache = key_caches[0], value_caches[0]
     # Clone the KV caches.
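
The change these kernel tests share is mechanical: a DEVICES list sized from torch.cuda.device_count() is added as a pytest parameter, and every tensor, cache, and module is created on an explicit f"cuda:{device}" string instead of the bare "cuda" default, so the same test exercises more than one GPU when one is available. A minimal sketch of that pattern, independent of the commit (the test name and tensor shape below are illustrative, and it needs at least one CUDA device to run):

    import pytest
    import torch

    # One device when only a single GPU is visible, otherwise the first two.
    DEVICES = [i for i in range(1 if torch.cuda.device_count() == 1 else 2)]


    @pytest.mark.parametrize("device", DEVICES)
    @torch.inference_mode()
    def test_tensor_lands_on_selected_gpu(device: int) -> None:
        gpu_id = f"cuda:{device}"                 # e.g. "cuda:0" or "cuda:1"
        x = torch.randn(8, 16, device=gpu_id)     # allocated on that exact GPU
        assert x.device == torch.device(gpu_id)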
tests/kernels/test_layernorm.py  (view file @ fcffb7c8)

@@ -8,6 +8,7 @@ NUM_TOKENS = [7, 83, 4096]  # Arbitrary values for testing
 HIDDEN_SIZES = [768, 5120, 8192]  # Arbitrary values for testing
 ADD_RESIDUAL = [False, True]
 SEEDS = [0]
+DEVICES = [i for i in range(1 if torch.cuda.device_count() == 1 else 2)]

 @pytest.mark.parametrize("num_tokens", NUM_TOKENS)

@@ -15,6 +16,7 @@ SEEDS = [0]
 @pytest.mark.parametrize("add_residual", ADD_RESIDUAL)
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("seed", SEEDS)
+@pytest.mark.parametrize("device", DEVICES)
 @torch.inference_mode()
 def test_rms_norm(
     num_tokens: int,

@@ -22,14 +24,15 @@ def test_rms_norm(
     add_residual: bool,
     dtype: torch.dtype,
     seed: int,
+    device: int,
 ) -> None:
     torch.random.manual_seed(seed)
     torch.cuda.manual_seed(seed)
-    layer = RMSNorm(hidden_size).to(dtype).cuda()
+    gpu_id = f"cuda:{device}"
+    layer = RMSNorm(hidden_size).to(dtype=dtype, device=gpu_id)
     layer.weight.data.normal_(mean=1.0, std=0.1)
     scale = 1 / (2 * hidden_size)
-    x = torch.randn(num_tokens, hidden_size, dtype=dtype, device="cuda")
+    x = torch.randn(num_tokens, hidden_size, dtype=dtype, device=gpu_id)
     x *= scale
     residual = torch.randn_like(x) * scale if add_residual else None
tests/kernels/test_pos_encoding.py  (view file @ fcffb7c8)

@@ -13,6 +13,7 @@ NUM_HEADS = [7, 17]  # Arbitrary values for testing
 BATCH_SIZES = [1, 5]  # Arbitrary values for testing
 SEQ_LENS = [11, 8192]  # Arbitrary values for testing
 SEEDS = [0]
+DEVICES = [i for i in range(1 if torch.cuda.device_count() == 1 else 2)]

 @pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)

@@ -23,6 +24,7 @@ SEEDS = [0]
 @pytest.mark.parametrize("rotary_dim", ROTARY_DIMS)
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("seed", SEEDS)
+@pytest.mark.parametrize("device", DEVICES)
 @torch.inference_mode()
 def test_rotary_embedding(
     is_neox_style: bool,

@@ -33,6 +35,7 @@ def test_rotary_embedding(
     rotary_dim: Optional[int],
     dtype: torch.dtype,
     seed: int,
+    device: int,
     max_position: int = 8192,
     base: int = 10000,
 ) -> None:

@@ -40,20 +43,20 @@ def test_rotary_embedding(
         rotary_dim = head_size
     torch.random.manual_seed(seed)
     torch.cuda.manual_seed(seed)
+    gpu_id = f"cuda:{device}"
     if rotary_dim is None:
         rotary_dim = head_size
     rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style)
-    rope = rope.to(dtype).cuda()
+    rope = rope.to(dtype=dtype, device=gpu_id)
     positions = torch.randint(0, max_position, (batch_size, seq_len),
-                              device="cuda")
+                              device=gpu_id)
     query = torch.randn(batch_size,
                         seq_len,
                         num_heads * head_size,
                         dtype=dtype,
-                        device="cuda")
+                        device=gpu_id)
     key = torch.randn_like(query)
     # NOTE(woosuk): The reference implementation should be executed first
tests/worker/test_model_runner.py  (view file @ fcffb7c8)

@@ -33,8 +33,9 @@ def test_prepare_prompt():
         expected_selected_token_indices.append(selected_token_start_idx +
                                                prompt_len - 1)
         selected_token_start_idx += max_seq_len
-    input_tokens, input_positions, _ = model_runner._prepare_prompt(
-        seq_group_metadata_list)
+    input_tokens, input_positions, _, return_prompt_lens = (
+        model_runner._prepare_prompt(seq_group_metadata_list))
+    assert return_prompt_lens == prompt_lens
     sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
                                                      prompt_lens)
     assert input_tokens.shape == (batch_size, max_seq_len)
vllm/__init__.py  (view file @ fcffb7c8)

@@ -9,7 +9,7 @@ from vllm.outputs import CompletionOutput, RequestOutput
 from vllm.sampling_params import SamplingParams
 from vllm.version import __dcu_version__

-__version__ = "0.2.6"
+__version__ = "0.2.7"

 __all__ = [
     "LLM",
vllm/config.py  (view file @ fcffb7c8)

@@ -181,12 +181,6 @@ class ModelConfig:
             self.max_context_len_to_capture = self.max_model_len
         self.max_context_len_to_capture = min(self.max_context_len_to_capture,
                                               self.max_model_len)
-        if (self.quantization in ["gptq", "squeezellm"]
-                and not self.enforce_eager):
-            # Related issue: https://github.com/vllm-project/vllm/issues/2147
-            logger.warning(f"{self.quantization} does not support CUDA graph "
-                           "yet. Disabling CUDA graph.")
-            self.enforce_eager = True

     def verify_with_parallel_config(
         self,
vllm/engine/async_llm_engine.py  (view file @ fcffb7c8)

@@ -183,49 +183,53 @@ class _AsyncLLMEngine(LLMEngine):
         and updates the scheduler with the model outputs. Finally, it decodes
         the sequences and returns the newly generated results.
         """
-        seq_group_metadata_list, scheduler_outputs, ignored = self._schedule()
-        if scheduler_outputs.is_empty():
-            return ignored
-
-        # Execute the model.
-        output = await self._run_workers_async(
-            "execute_model",
-            seq_group_metadata_list=seq_group_metadata_list,
-            blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
-            blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
-            blocks_to_copy=scheduler_outputs.blocks_to_copy,
-        )
+        seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()
+
+        if not scheduler_outputs.is_empty():
+            # Execute the model.
+            all_outputs = await self._run_workers_async(
+                "execute_model",
+                driver_kwargs={
+                    "seq_group_metadata_list": seq_group_metadata_list,
+                    "blocks_to_swap_in": scheduler_outputs.blocks_to_swap_in,
+                    "blocks_to_swap_out": scheduler_outputs.blocks_to_swap_out,
+                    "blocks_to_copy": scheduler_outputs.blocks_to_copy,
+                })
+            # Only the driver worker returns the sampling results.
+            output = all_outputs[0]
+        else:
+            output = []

-        return self._process_model_outputs(output, scheduler_outputs) + ignored
+        return self._process_model_outputs(output, scheduler_outputs)

     async def _run_workers_async(
         self,
         method: str,
         *args,
-        get_all_outputs: bool = False,
+        driver_args: Optional[List[Any]] = None,
+        driver_kwargs: Optional[Dict[str, Any]] = None,
         **kwargs,
     ) -> Any:
         """Runs the given method on all workers."""
         coros = []
-        for worker in self.workers:
-            if self.parallel_config.worker_use_ray:
-                coros.append(
-                    worker.execute_method.remote(method, *args, **kwargs))
-            else:
-                executor = getattr(worker, method)
-                coros.append(asyncio.get_event_loop().run_in_executor(
-                    None, partial(executor, *args, **kwargs)))
-
-        all_outputs = await asyncio.gather(*coros)

-        if get_all_outputs:
-            return all_outputs
+        if driver_args is None:
+            driver_args = args
+        if driver_kwargs is None:
+            driver_kwargs = kwargs

-        # Make sure all workers have the same results.
-        output = all_outputs[0]
-        for other_output in all_outputs[1:]:
-            assert output == other_output
-        return output
+        # Run the driver worker asynchronously.
+        driver_executor = getattr(self.driver_worker, method)
+        coros.append(asyncio.get_event_loop().run_in_executor(
+            None, partial(driver_executor, *driver_args, **driver_kwargs)))
+
+        # Run the ray workers asynchronously.
+        for worker in self.workers:
+            coros.append(
+                worker.execute_method.remote(method, *args, **kwargs))
+
+        all_outputs = await asyncio.gather(*coros)
+        return all_outputs

 class AsyncLLMEngine:

@@ -490,13 +494,12 @@ class AsyncLLMEngine:
         engine_configs = engine_args.create_engine_configs()
         parallel_config = engine_configs[2]
         # Initialize the cluster.
-        distributed_init_method, placement_group = initialize_cluster(
-            parallel_config, engine_args.engine_use_ray)
+        placement_group = initialize_cluster(parallel_config,
+                                             engine_args.engine_use_ray)
         # Create the async LLM engine.
         engine = cls(parallel_config.worker_use_ray,
                      engine_args.engine_use_ray,
                      *engine_configs,
-                     distributed_init_method,
                      placement_group,
                      log_requests=not engine_args.disable_log_requests,
                      log_stats=not engine_args.disable_log_stats,
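
The reworked _run_workers_async mixes two kinds of awaitables: the driver worker's blocking execute_model call, pushed onto the default executor, and the Ray workers' remote calls, all awaited together with asyncio.gather. A self-contained sketch of that shape without Ray (driver_execute, remote_execute, and run_workers_async here are stand-in names, not vLLM APIs):

    import asyncio
    from functools import partial


    def driver_execute(step: int) -> str:
        # Stands in for the driver Worker's blocking execute_model call.
        return f"driver step {step}"


    async def remote_execute(step: int) -> str:
        # Stands in for awaiting a Ray worker's remote call.
        await asyncio.sleep(0.01)
        return f"remote step {step}"


    async def run_workers_async(step: int) -> list:
        coros = [
            # The blocking driver call runs in the default thread-pool executor.
            asyncio.get_event_loop().run_in_executor(
                None, partial(driver_execute, step)),
            remote_execute(step),
        ]
        # Driver output comes first, matching all_outputs[0] in step_async().
        return await asyncio.gather(*coros)


    print(asyncio.run(run_workers_async(1)))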
vllm/engine/llm_engine.py  (view file @ fcffb7c8)

 import copy
+from collections import defaultdict
 import os
 import time
 from functools import partial
-from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Tuple, Union
+from typing import (TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple,
+                    Union)

 from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
                          SchedulerConfig)

@@ -14,14 +15,12 @@ from vllm.logger import init_logger
 from vllm.outputs import RequestOutput
 from vllm.sampling_params import SamplingParams
 from vllm.sequence import (SamplerOutput, Sequence, SequenceGroup,
                            SequenceGroupMetadata, SequenceGroupOutput,
                            SequenceOutput, SequenceStatus)
 from vllm.transformers_utils.tokenizer import (detokenize_incrementally,
                                                get_tokenizer)
-from vllm.utils import Counter
+from vllm.utils import Counter, set_cuda_visible_devices, get_ip, get_open_port

 if ray:
-    from ray.air.util.torch_dist import init_torch_dist_process_group
     from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy

 if TYPE_CHECKING:

@@ -54,8 +53,6 @@ class LLMEngine:
             management.
         parallel_config: The configuration related to distributed execution.
         scheduler_config: The configuration related to the request scheduler.
-        distributed_init_method: The initialization method for distributed
-            execution. See `torch.distributed.init_process_group` for details.
         placement_group: Ray placement group for distributed execution.
             Required for distributed execution.
         log_stats: Whether to log statistics.

@@ -67,7 +64,6 @@ class LLMEngine:
         cache_config: CacheConfig,
         parallel_config: ParallelConfig,
         scheduler_config: SchedulerConfig,
-        distributed_init_method: str,
         placement_group: Optional["PlacementGroup"],
         log_stats: bool,
     ) -> None:

@@ -112,7 +108,7 @@ class LLMEngine:
             os.environ["RAY_USAGE_STATS_ENABLED"] = "0"
             self._init_workers_ray(placement_group)
         else:
-            self._init_workers(distributed_init_method)
+            self._init_workers()

         # Profile the memory usage and initialize the cache.
         self._init_cache()

@@ -127,7 +123,7 @@ class LLMEngine:
         # List of (timestamp, num_tokens)
         self.num_generation_tokens: List[Tuple[float, int]] = []

-    def _init_workers(self, distributed_init_method: str):
+    def _init_workers(self):
         # Lazy import the Worker to avoid importing torch.cuda/xformers
         # before CUDA_VISIBLE_DEVICES is set in the Worker
         from vllm.worker.worker import Worker

@@ -136,70 +132,122 @@ class LLMEngine:
             "Ray is required if parallel_config.world_size > 1.")

-        self.workers: List[Worker] = []
-        worker = Worker(
+        distributed_init_method = f"tcp://{get_ip()}:{get_open_port()}"
+        self.driver_worker = Worker(
             self.model_config,
             self.parallel_config,
             self.scheduler_config,
-            0,
-            distributed_init_method,
+            local_rank=0,
+            rank=0,
+            distributed_init_method=distributed_init_method,
+            is_driver_worker=True,
         )
-        self.workers.append(worker)
-        self._run_workers(
-            "init_model",
-            get_all_outputs=True,
-        )
-        self._run_workers(
-            "load_model",
-            get_all_outputs=True,
-            max_concurrent_workers=self.parallel_config.
-            max_parallel_loading_workers,
-        )
+        self._run_workers("init_model")
+        self._run_workers("load_model")

     def _init_workers_ray(self, placement_group: "PlacementGroup",
                           **ray_remote_kwargs):
-        # Lazy import the Worker to avoid importing torch.cuda/xformers
-        # before CUDA_VISIBLE_DEVICES is set in the Worker
-        from vllm.worker.worker import Worker
-
-        if self.parallel_config.tensor_parallel_size == 1:
-            num_gpus = self.cache_config.gpu_memory_utilization
-        else:
-            num_gpus = 1
-
-        self.workers: List[Worker] = []
-        for bundle in placement_group.bundle_specs:
+        self.driver_dummy_worker: RayWorkerVllm = None
+        self.workers: List[RayWorkerVllm] = []
+
+        driver_ip = get_ip()
+        for bundle_id, bundle in enumerate(placement_group.bundle_specs):
             if not bundle.get("GPU", 0):
                 continue
+            if self.parallel_config.tensor_parallel_size == 1:
+                num_gpus = self.cache_config.gpu_memory_utilization
+            else:
+                num_gpus = 1
+            scheduling_strategy = PlacementGroupSchedulingStrategy(
+                placement_group=placement_group,
+                placement_group_capture_child_tasks=True,
+                placement_group_bundle_index=bundle_id,
+            )
             worker = ray.remote(
                 num_cpus=0,
                 num_gpus=num_gpus,
-                scheduling_strategy=PlacementGroupSchedulingStrategy(
-                    placement_group=placement_group,
-                    placement_group_capture_child_tasks=True),
+                scheduling_strategy=scheduling_strategy,
                 **ray_remote_kwargs,
             )(RayWorkerVllm).remote(self.model_config.trust_remote_code)
-            self.workers.append(worker)
+
+            worker_ip = ray.get(worker.get_node_ip.remote())
+            if worker_ip == driver_ip and self.driver_dummy_worker is None:
+                # If the worker is on the same node as the driver, we use it
+                # as the resource holder for the driver process.
+                self.driver_dummy_worker = worker
+            else:
+                self.workers.append(worker)
+
+        if self.driver_dummy_worker is None:
+            raise ValueError(
+                "Ray does not allocate any GPUs on the driver node. Consider "
+                "adjusting the Ray placement group or running the driver on a "
+                "GPU node.")
+
+        driver_node_id, driver_gpu_ids = ray.get(
+            self.driver_dummy_worker.get_node_and_gpu_ids.remote())
+        worker_node_and_gpu_ids = ray.get(
+            [worker.get_node_and_gpu_ids.remote() for worker in self.workers])
+
+        node_workers = defaultdict(list)
+        node_gpus = defaultdict(list)
+
+        node_workers[driver_node_id].append(0)
+        node_gpus[driver_node_id].extend(driver_gpu_ids)
+        for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids,
+                                               start=1):
+            node_workers[node_id].append(i)
+            node_gpus[node_id].extend(gpu_ids)
+        for node_id, gpu_ids in node_gpus.items():
+            node_gpus[node_id] = sorted(gpu_ids)
+
+        # Set CUDA_VISIBLE_DEVICES for the driver.
+        set_cuda_visible_devices(node_gpus[driver_node_id])
+        for worker, (node_id, _) in zip(self.workers, worker_node_and_gpu_ids):
+            worker.set_cuda_visible_devices.remote(node_gpus[node_id])
+
+        distributed_init_method = f"tcp://{driver_ip}:{get_open_port()}"
+
+        # Lazy import the Worker to avoid importing torch.cuda/xformers
+        # before CUDA_VISIBLE_DEVICES is set in the Worker
+        from vllm.worker.worker import Worker

-        # Initialize torch distributed process group for the workers.
-        init_torch_dist_process_group(self.workers, backend="nccl")
         model_config = copy.deepcopy(self.model_config)
         parallel_config = copy.deepcopy(self.parallel_config)
         scheduler_config = copy.deepcopy(self.scheduler_config)
-        self._run_workers("init_worker",
-                          get_all_outputs=True,
-                          worker_init_fn=lambda: Worker(
-                              model_config,
-                              parallel_config,
-                              scheduler_config,
-                              None,
-                              None,
-                          ))
-        self._run_workers(
-            "init_model",
-            get_all_outputs=True,
-        )
+
+        for rank, (worker, (node_id, _)) in enumerate(zip(
+                self.workers, worker_node_and_gpu_ids), start=1):
+            local_rank = node_workers[node_id].index(rank)
+            worker.init_worker.remote(
+                lambda rank=rank, local_rank=local_rank: Worker(
+                    model_config,
+                    parallel_config,
+                    scheduler_config,
+                    local_rank,
+                    rank,
+                    distributed_init_method,
+                ))
+
+        driver_rank = 0
+        driver_local_rank = node_workers[driver_node_id].index(driver_rank)
+        self.driver_worker = Worker(
+            model_config,
+            parallel_config,
+            scheduler_config,
+            driver_local_rank,
+            driver_rank,
+            distributed_init_method,
+            is_driver_worker=True,
+        )
+
+        self._run_workers("init_model")
         self._run_workers(
             "load_model",
-            get_all_outputs=True,
             max_concurrent_workers=self.parallel_config.
             max_parallel_loading_workers,
         )

@@ -213,7 +261,6 @@ class LLMEngine:
         # Get the maximum number of blocks that can be allocated on GPU and CPU.
         num_blocks = self._run_workers(
             "profile_num_available_blocks",
-            get_all_outputs=True,
             block_size=self.cache_config.block_size,
             gpu_memory_utilization=self.cache_config.gpu_memory_utilization,
             cpu_swap_space=self.cache_config.swap_space_bytes,

@@ -257,11 +304,9 @@ class LLMEngine:
         engine_configs = engine_args.create_engine_configs()
         parallel_config = engine_configs[2]
         # Initialize the cluster.
-        distributed_init_method, placement_group = initialize_cluster(
-            parallel_config)
+        placement_group = initialize_cluster(parallel_config)
         # Create the LLM engine.
         engine = cls(*engine_configs,
-                     distributed_init_method,
                      placement_group,
                      log_stats=not engine_args.disable_log_stats)
         return engine

@@ -328,16 +373,6 @@ class LLMEngine:
         """Returns True if there are unfinished requests."""
         return self.scheduler.has_unfinished_seqs()

-    def _schedule(
-        self
-    ) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs,
-               List[RequestOutput]]:
-        seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()
-        return seq_group_metadata_list, scheduler_outputs, [
-            RequestOutput.from_seq_group(seq_group)
-            for seq_group in scheduler_outputs.ignored_seq_groups
-        ]
-
     def _check_beam_search_early_stopping(
         self,
         early_stopping: Union[bool, str],

@@ -586,18 +621,23 @@ class LLMEngine:
         and updates the scheduler with the model outputs. Finally, it decodes
         the sequences and returns the newly generated results.
         """
-        seq_group_metadata_list, scheduler_outputs, ignored = self._schedule()
-        if scheduler_outputs.is_empty():
-            return ignored
-
-        # Execute the model.
-        output = self._run_workers(
-            "execute_model",
-            seq_group_metadata_list=seq_group_metadata_list,
-            blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
-            blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
-            blocks_to_copy=scheduler_outputs.blocks_to_copy,
-        )
+        seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()
+
+        if not scheduler_outputs.is_empty():
+            # Execute the model.
+            all_outputs = self._run_workers(
+                "execute_model",
+                driver_kwargs={
+                    "seq_group_metadata_list": seq_group_metadata_list,
+                    "blocks_to_swap_in": scheduler_outputs.blocks_to_swap_in,
+                    "blocks_to_swap_out": scheduler_outputs.blocks_to_swap_out,
+                    "blocks_to_copy": scheduler_outputs.blocks_to_copy,
+                })
+            # Only the driver worker returns the sampling results.
+            output = all_outputs[0]
+        else:
+            output = []

         return self._process_model_outputs(output, scheduler_outputs)

@@ -725,53 +765,38 @@ class LLMEngine:
                 seq.status = SequenceStatus.FINISHED_STOPPED
                 return

-    def _run_workers_in_batch(
-        self,
-        workers,
-        method: str,
-        *args,
-        **kwargs,
-    ):
-        all_outputs = []
-        for worker in workers:
-            if self.parallel_config.worker_use_ray:
-                executor = partial(worker.execute_method.remote, method)
-            else:
-                executor = getattr(worker, method)
-            output = executor(*args, **kwargs)
-            all_outputs.append(output)
-        if self.parallel_config.worker_use_ray:
-            all_outputs = ray.get(all_outputs)
-        return all_outputs
-
     def _run_workers(
         self,
         method: str,
         *args,
-        get_all_outputs: bool = False,
+        driver_args: Optional[List[Any]] = None,
+        driver_kwargs: Optional[Dict[str, Any]] = None,
        max_concurrent_workers: Optional[int] = None,
         **kwargs,
     ) -> Any:
         """Runs the given method on all workers."""
-        all_outputs = []
         if max_concurrent_workers:
-            work_groups = [
-                self.workers[i:i + max_concurrent_workers]
-                for i in range(0, len(self.workers), max_concurrent_workers)
-            ]
-        else:
-            work_groups = [self.workers]
-
-        for workers in work_groups:
-            all_outputs.extend(
-                self._run_workers_in_batch(workers, method, *args, **kwargs))
-
-        if get_all_outputs:
-            return all_outputs
-
-        # Make sure all workers have the same results.
-        output = all_outputs[0]
-        for other_output in all_outputs[1:]:
-            assert output == other_output
-        return output
+            raise NotImplementedError(
+                "max_concurrent_workers is not supported yet.")
+
+        # Start the ray workers first.
+        ray_worker_outputs = [
+            worker.execute_method.remote(method, *args, **kwargs)
+            for worker in self.workers
+        ]
+
+        if driver_args is None:
+            driver_args = args
+        if driver_kwargs is None:
+            driver_kwargs = kwargs
+
+        # Start the driver worker after all the ray workers.
+        driver_worker_output = getattr(self.driver_worker,
+                                       method)(*driver_args, **driver_kwargs)
+
+        # Get the results of the ray workers.
+        if self.workers:
+            ray_worker_outputs = ray.get(ray_worker_outputs)
+
+        return [driver_worker_output] + ray_worker_outputs
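
The new _run_workers contract above is the core of this merge: Ray workers are kicked off first, the driver worker then runs the same method in-process, and the combined result always puts the driver's output at index 0, which is why step() reads all_outputs[0]. A simplified, Ray-free sketch of that contract (FakeWorker and run_on_all are illustrative names, not vLLM code):

    from typing import Any, Dict, List, Optional


    class FakeWorker:
        """Stands in for vllm.worker.worker.Worker in this sketch."""

        def execute_model(self, batch: List[int]) -> List[int]:
            return [x * 2 for x in batch]


    def run_on_all(driver: FakeWorker,
                   workers: List[FakeWorker],
                   method: str,
                   *args: Any,
                   driver_kwargs: Optional[Dict[str, Any]] = None,
                   **kwargs: Any) -> List[Any]:
        # Kick off the (would-be remote) workers first, then run the driver
        # locally; only the driver's result carries the sampled tokens.
        worker_outputs = [getattr(w, method)(*args, **kwargs) for w in workers]
        driver_kwargs = kwargs if driver_kwargs is None else driver_kwargs
        driver_output = getattr(driver, method)(*args, **driver_kwargs)
        return [driver_output] + worker_outputs


    outputs = run_on_all(FakeWorker(), [FakeWorker()], "execute_model", [1, 2, 3])
    assert outputs[0] == [2, 4, 6]   # index 0 is always the driver's output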
vllm/engine/ray_utils.py  (view file @ fcffb7c8)

-from typing import Optional, Tuple, TYPE_CHECKING
+from typing import Optional, List, Tuple, TYPE_CHECKING

 from vllm.config import ParallelConfig
 from vllm.logger import init_logger
-from vllm.utils import get_open_port, is_hip
+from vllm.utils import is_hip, set_cuda_visible_devices, get_ip

 logger = init_logger(__name__)

 try:
     import ray
-    from ray.air.util.torch_dist import TorchDistributedWorker

-    class RayWorkerVllm(TorchDistributedWorker):
+    class RayWorkerVllm:
         """Ray wrapper for vllm.worker.Worker, allowing Worker to be
         lazliy initialized after Ray sets CUDA_VISIBLE_DEVICES."""

@@ -30,12 +29,22 @@ try:
             executor = getattr(self, method)
             return executor(*args, **kwargs)

+        def get_node_ip(self) -> str:
+            return get_ip()
+
+        def get_node_and_gpu_ids(self) -> Tuple[str, List[int]]:
+            node_id = ray.get_runtime_context().get_node_id()
+            gpu_ids = ray.get_gpu_ids()
+            return node_id, gpu_ids
+
+        def set_cuda_visible_devices(self, device_ids) -> None:
+            set_cuda_visible_devices(device_ids)
+
 except ImportError as e:
     logger.warning(f"Failed to import Ray with {e!r}. "
                    "For distributed inference, please install Ray with "
                    "`pip install ray pandas pyarrow`.")
     ray = None
-    TorchDistributedWorker = None
     RayWorkerVllm = None

 if TYPE_CHECKING:

@@ -75,13 +84,11 @@ def initialize_cluster(
         ray.init(address=ray_address, ignore_reinit_error=True)

     if not parallel_config.worker_use_ray:
-        # Initialize cluster locally.
-        port = get_open_port()
-        # We need to setup the distributed init method to make sure
-        # the distributed megatron code (e.g., get world size) works correctly.
-        distributed_init_method = f"tcp://localhost:{port}"
-        return distributed_init_method, None
+        assert parallel_config.world_size == 1, (
+            "Ray is required if parallel_config.world_size > 1.")
+        return None

     # Create placement group for worker processes
     current_placement_group = ray.util.get_current_placement_group()
     if current_placement_group:
         # We are in a placement group

@@ -106,12 +113,12 @@ def initialize_cluster(
             "The number of required GPUs exceeds the total number of "
             "available GPUs in the cluster.")
         # Create a new placement group
-        current_placement_group = ray.util.placement_group([{
-            "GPU": 1
-        }] * parallel_config.world_size)
+        placement_group_specs = ([{"GPU": 1}] * parallel_config.world_size)
+        current_placement_group = ray.util.placement_group(
+            placement_group_specs)
         # Wait until PG is ready - this will block until all
         # requested resources are available, and will timeout
         # if they cannot be provisioned.
         ray.get(current_placement_group.ready(), timeout=1800)

-    return None, current_placement_group
+    return current_placement_group
vllm/entrypoints/api_server.py  (view file @ fcffb7c8)

@@ -12,7 +12,6 @@ from vllm.sampling_params import SamplingParams
 from vllm.utils import random_uuid

 TIMEOUT_KEEP_ALIVE = 5  # seconds.
-TIMEOUT_TO_PREVENT_DEADLOCK = 1  # seconds.
 app = FastAPI()
 engine = None
vllm/model_executor/input_metadata.py  (view file @ fcffb7c8)

-from typing import List, Optional
+from typing import Optional

 import torch

@@ -16,28 +16,27 @@ class InputMetadata:
     def __init__(
         self,
-        prompt_lens: List[int],
+        is_prompt: bool,
         slot_mapping: torch.Tensor,
         max_context_len: Optional[int],
         context_lens: Optional[torch.Tensor],
         block_tables: Optional[torch.Tensor],
         use_cuda_graph: bool,
     ) -> None:
-        self.prompt_lens = prompt_lens
+        self.is_prompt = is_prompt
         self.max_context_len = max_context_len
         self.slot_mapping = slot_mapping
         self.context_lens = context_lens
         self.block_tables = block_tables
         self.use_cuda_graph = use_cuda_graph

-        self.is_prompt = len(prompt_lens) > 0
         # Set during the execution of the first attention op.
         # FIXME(woosuk): This is a hack.
         self.attn_bias = None

     def __repr__(self) -> str:
         return ("InputMetadata("
-                f"prompt_lens={self.prompt_lens}, "
+                f"is_prompt={self.is_prompt}, "
                 f"max_context_len={self.max_context_len}, "
                 f"slot_mapping={self.slot_mapping}, "
                 f"context_lens={self.context_lens}, "
vllm/model_executor/layers/sampler.py  (view file @ fcffb7c8)

@@ -5,7 +5,7 @@ import torch
 import torch.nn as nn

 from vllm.model_executor.parallel_utils.communication_op import (
-    tensor_model_parallel_all_gather)
+    tensor_model_parallel_gather)
 from vllm.model_executor.sampling_metadata import SamplingMetadata, SamplingTensors
 from vllm.sampling_params import SamplingParams, SamplingType
 from vllm.sequence import (PromptLogprobs, SampleLogprobs, SamplerOutput,

@@ -37,7 +37,7 @@ class Sampler(nn.Module):
         hidden_states: torch.Tensor,
         sampling_metadata: SamplingMetadata,
         embedding_bias: Optional[torch.Tensor] = None,
-    ) -> SamplerOutput:
+    ) -> Optional[SamplerOutput]:
         # Get the hidden states that we use for sampling.
         hidden_states = _prune_hidden_states(hidden_states, sampling_metadata)

@@ -45,6 +45,14 @@ class Sampler(nn.Module):
         logits = _get_logits(hidden_states, embedding, embedding_bias,
                              self.vocab_size)

+        # Only perform sampling in the driver worker.
+        # Note: `_get_logits` is still distributed across TP workers because
+        # the `embedding` weight is distributed across TP workers.
+        # TODO(zhuohan): Change the get_logits part to a separate stage.
+        if not sampling_metadata.perform_sampling:
+            return None
+
+        assert logits is not None
         _, vocab_size = logits.shape

         # Apply logits processors (if any).

@@ -92,14 +100,15 @@ class Sampler(nn.Module):
 def _get_logits(hidden_states: torch.Tensor, embedding: torch.Tensor,
                 embedding_bias: Optional[torch.Tensor],
-                vocab_size: int) -> torch.Tensor:
+                vocab_size: int) -> Optional[torch.Tensor]:
     # Get the logits for the next tokens.
     logits = torch.matmul(hidden_states, embedding.t())
     if embedding_bias is not None:
         logits += embedding_bias
-    logits = tensor_model_parallel_all_gather(logits)
+    logits = tensor_model_parallel_gather(logits)
     # Remove paddings in vocab (if any).
-    logits = logits[:, :vocab_size]
+    if logits is not None:
+        logits = logits[:, :vocab_size]
     return logits

@@ -112,27 +121,6 @@ def _prune_hidden_states(
         sampling_metadata.selected_token_indices)

-def _get_prompt_and_output_tokens(
-    sampling_metadata: SamplingMetadata,
-) -> Tuple[List[List[int]], List[List[int]]]:
-    prompt_tokens: List[List[int]] = []
-    output_tokens: List[List[int]] = []
-    for i, seq_group in enumerate(sampling_metadata.seq_groups):
-        seq_ids, sampling_params = seq_group
-        if (i < sampling_metadata.num_prompts
-                and sampling_params.prompt_logprobs is not None):
-            # NOTE: prompt token positions do not need output tokens to
-            # compute penalties.
-            prompt_len = sampling_metadata.prompt_lens[i]
-            prompt_tokens.extend([] for _ in range(prompt_len - 1))
-            output_tokens.extend([] for _ in range(prompt_len - 1))
-        for seq_id in seq_ids:
-            seq_data = sampling_metadata.seq_data[seq_id]
-            prompt_tokens.append(seq_data.prompt_token_ids)
-            output_tokens.append(seq_data.output_token_ids)
-    return prompt_tokens, output_tokens
-
 def _get_bin_counts_and_mask(
     tokens: torch.Tensor,
     vocab_size: int,
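
The switch from tensor_model_parallel_all_gather to tensor_model_parallel_gather follows directly from the driver-worker split: only the driver (rank 0) samples, so only rank 0 needs the full vocab logits, and the other tensor-parallel ranks can receive None and return early from Sampler.forward. A rough sketch of the difference using plain torch.distributed on a single-process gloo group (the port, tensor shape, and the concatenation along the last, vocab-sharded dimension are assumptions for illustration):

    import os

    import torch
    import torch.distributed as dist

    # Single-process "gloo" group, just to make the call signatures concrete.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29514")
    dist.init_process_group("gloo", rank=0, world_size=1)

    logits_shard = torch.randn(2, 8)  # this rank's slice of the vocab logits

    if dist.get_rank() == 0:
        # The driver collects every rank's shard and goes on to sample.
        gathered = [
            torch.empty_like(logits_shard) for _ in range(dist.get_world_size())
        ]
        dist.gather(logits_shard, gather_list=gathered, dst=0)
        full_logits = torch.cat(gathered, dim=-1)
    else:
        # Non-driver ranks only send; they get nothing back and skip sampling.
        dist.gather(logits_shard, dst=0)
        full_logits = None

    dist.destroy_process_group()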
vllm/model_executor/models/aquila.py  (view file @ fcffb7c8)

@@ -298,7 +298,7 @@ class AquilaForCausalLM(nn.Module):
         self,
         hidden_states: torch.Tensor,
         sampling_metadata: SamplingMetadata,
-    ) -> SamplerOutput:
+    ) -> Optional[SamplerOutput]:
         next_tokens = self.sampler(self.lm_head.weight, hidden_states,
                                    sampling_metadata)
         return next_tokens
vllm/model_executor/models/baichuan.py  (view file @ fcffb7c8)

@@ -313,7 +313,7 @@ class BaiChuanBaseForCausalLM(nn.Module):
         self,
         hidden_states: torch.Tensor,
         sampling_metadata: SamplingMetadata,
-    ) -> SamplerOutput:
+    ) -> Optional[SamplerOutput]:
         next_tokens = self.sampler(self.lm_head.weight, hidden_states,
                                    sampling_metadata)
         return next_tokens
vllm/model_executor/models/bloom.py  (view file @ fcffb7c8)

@@ -290,7 +290,7 @@ class BloomForCausalLM(nn.Module):
         self,
         hidden_states: torch.Tensor,
         sampling_metadata: SamplingMetadata,
-    ) -> SamplerOutput:
+    ) -> Optional[SamplerOutput]:
         next_tokens = self.sampler(self.lm_head_weight, hidden_states,
                                    sampling_metadata)
         return next_tokens
vllm/model_executor/models/chatglm.py  (view file @ fcffb7c8)

@@ -349,7 +349,7 @@ class ChatGLMForCausalLM(nn.Module):
         self,
         hidden_states: torch.Tensor,
         sampling_metadata: SamplingMetadata,
-    ) -> SamplerOutput:
+    ) -> Optional[SamplerOutput]:
         next_tokens = self.sampler(self.lm_head_weight, hidden_states,
                                    sampling_metadata)
         return next_tokens
vllm/model_executor/models/falcon.py  (view file @ fcffb7c8)

@@ -394,7 +394,7 @@ class FalconForCausalLM(nn.Module):
         self,
         hidden_states: torch.Tensor,
         sampling_metadata: SamplingMetadata,
-    ) -> SamplerOutput:
+    ) -> Optional[SamplerOutput]:
         next_tokens = self.sampler(self.lm_head.weight, hidden_states,
                                    sampling_metadata)
         return next_tokens
vllm/model_executor/models/gpt2.py  (view file @ fcffb7c8)

@@ -235,7 +235,7 @@ class GPT2LMHeadModel(nn.Module):
         self,
         hidden_states: torch.Tensor,
         sampling_metadata: SamplingMetadata,
-    ) -> SamplerOutput:
+    ) -> Optional[SamplerOutput]:
         next_tokens = self.sampler(self.lm_head_weight, hidden_states,
                                    sampling_metadata)
         return next_tokens
vllm/model_executor/models/gpt_bigcode.py  (view file @ fcffb7c8)

@@ -254,7 +254,7 @@ class GPTBigCodeForCausalLM(nn.Module):
         self,
         hidden_states: torch.Tensor,
         sampling_metadata: SamplingMetadata,
-    ) -> SamplerOutput:
+    ) -> Optional[SamplerOutput]:
         next_tokens = self.sampler(self.lm_head_weight, hidden_states,
                                    sampling_metadata)
         return next_tokens
vllm/model_executor/models/gpt_j.py  (view file @ fcffb7c8)

@@ -240,7 +240,7 @@ class GPTJForCausalLM(nn.Module):
         self,
         hidden_states: torch.Tensor,
         sampling_metadata: SamplingMetadata,
-    ) -> SamplerOutput:
+    ) -> Optional[SamplerOutput]:
         next_tokens = self.sampler(self.lm_head.weight, hidden_states,
                                    sampling_metadata, self.lm_head.bias)
         return next_tokens