Commit fd4ea8ef (unverified), authored Jan 04, 2024 by Zhuohan Li, committed by GitHub on Jan 03, 2024
Use NCCL instead of ray for control-plane communication to remove serialization overhead (#2221)
Parent: 1066cbd1
Changes: 34. Showing 14 changed files with 272 additions and 69 deletions (+272 / -69).
vllm/model_executor/models/internlm.py (+1, -1)
vllm/model_executor/models/llama.py (+1, -1)
vllm/model_executor/models/mistral.py (+1, -1)
vllm/model_executor/models/mixtral.py (+2, -2)
vllm/model_executor/models/mpt.py (+1, -1)
vllm/model_executor/models/opt.py (+1, -1)
vllm/model_executor/models/phi_1_5.py (+1, -1)
vllm/model_executor/models/qwen.py (+1, -1)
vllm/model_executor/models/yi.py (+1, -1)
vllm/model_executor/parallel_utils/communication_op.py (+59, -0)
vllm/model_executor/sampling_metadata.py (+14, -8)
vllm/utils.py (+11, -1)
vllm/worker/model_runner.py (+133, -35)
vllm/worker/worker.py (+45, -15)
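Before the per-file diffs: the heart of the change is that the driver worker now pushes each step's inputs to the other workers with torch.distributed (NCCL for tensors, broadcast_object_list for small Python metadata) instead of serializing them through Ray. A minimal, self-contained sketch of that pattern, runnable with `torchrun --nproc_per_node=2` on CPU using the gloo backend (vLLM itself uses NCCL on GPU workers, and the metadata keys below are illustrative, not vLLM's actual message format):

import torch
import torch.distributed as dist


def main() -> None:
    # torchrun provides RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT.
    dist.init_process_group(backend="gloo")  # vLLM would use NCCL on GPUs
    if dist.get_rank() == 0:
        # Driver: send the tensor's shape first (a cheap Python object),
        # then broadcast the tensor itself.
        tokens = torch.arange(8, dtype=torch.long)
        meta = [{"input_tokens_size": tuple(tokens.size()), "is_prompt": True}]
        dist.broadcast_object_list(meta, src=0)
        dist.broadcast(tokens, src=0)
    else:
        # Worker: receive the shape, allocate a buffer, then receive the payload.
        meta = [None]
        dist.broadcast_object_list(meta, src=0)
        tokens = torch.empty(*meta[0]["input_tokens_size"], dtype=torch.long)
        dist.broadcast(tokens, src=0)
    print(dist.get_rank(), tokens.tolist())
    dist.destroy_process_group()


if __name__ == "__main__":
    main()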
vllm/model_executor/models/internlm.py
@@ -255,7 +255,7 @@ class InternLMForCausalLM(nn.Module):
         self,
         hidden_states: torch.Tensor,
         sampling_metadata: SamplingMetadata,
-    ) -> SamplerOutput:
+    ) -> Optional[SamplerOutput]:
         next_tokens = self.sampler(self.lm_head.weight, hidden_states,
                                    sampling_metadata)
         return next_tokens
vllm/model_executor/models/llama.py
@@ -291,7 +291,7 @@ class LlamaForCausalLM(nn.Module):
         self,
         hidden_states: torch.Tensor,
         sampling_metadata: SamplingMetadata,
-    ) -> SamplerOutput:
+    ) -> Optional[SamplerOutput]:
         next_tokens = self.sampler(self.lm_head.weight, hidden_states,
                                    sampling_metadata)
         return next_tokens
vllm/model_executor/models/mistral.py
@@ -287,7 +287,7 @@ class MistralForCausalLM(nn.Module):
         self,
         hidden_states: torch.Tensor,
         sampling_metadata: SamplingMetadata,
-    ) -> SamplerOutput:
+    ) -> Optional[SamplerOutput]:
         next_tokens = self.sampler(self.lm_head.weight, hidden_states,
                                    sampling_metadata)
         return next_tokens
vllm/model_executor/models/mixtral.py
@@ -320,7 +320,7 @@ class MixtralModel(nn.Module):
         positions: torch.Tensor,
         kv_caches: List[KVCache],
         input_metadata: InputMetadata,
-    ) -> SamplerOutput:
+    ) -> torch.Tensor:
         hidden_states = self.embed_tokens(input_ids)
         residual = None
         for i in range(len(self.layers)):
@@ -361,7 +361,7 @@ class MixtralForCausalLM(nn.Module):
         self,
         hidden_states: Optional[torch.Tensor],
         sampling_metadata: SamplingMetadata,
-    ) -> SamplerOutput:
+    ) -> Optional[SamplerOutput]:
         next_tokens = self.sampler(self.lm_head.weight, hidden_states,
                                    sampling_metadata)
         return next_tokens
vllm/model_executor/models/mpt.py
@@ -276,7 +276,7 @@ class MPTForCausalLM(nn.Module):
         self,
         hidden_states: torch.Tensor,
         sampling_metadata: SamplingMetadata,
-    ) -> SamplerOutput:
+    ) -> Optional[SamplerOutput]:
         next_tokens = self.sampler(self.lm_head_weight, hidden_states,
                                    sampling_metadata)
         return next_tokens
vllm/model_executor/models/opt.py
@@ -309,7 +309,7 @@ class OPTForCausalLM(nn.Module):
         self,
         hidden_states: torch.Tensor,
         sampling_metadata: SamplingMetadata,
-    ) -> SamplerOutput:
+    ) -> Optional[SamplerOutput]:
         next_tokens = self.sampler(self.lm_head_weight, hidden_states,
                                    sampling_metadata)
         return next_tokens
vllm/model_executor/models/phi_1_5.py
@@ -280,7 +280,7 @@ class PhiForCausalLM(nn.Module):
         self,
         hidden_states: torch.Tensor,
         sampling_metadata: SamplingMetadata,
-    ) -> SamplerOutput:
+    ) -> Optional[SamplerOutput]:
         head = self.lm_head.linear
         next_tokens = self.sampler(head.weight, hidden_states,
                                    sampling_metadata, head.bias)
vllm/model_executor/models/qwen.py
@@ -247,7 +247,7 @@ class QWenLMHeadModel(nn.Module):
         self,
         hidden_states: torch.Tensor,
         sampling_metadata: SamplingMetadata,
-    ) -> SamplerOutput:
+    ) -> Optional[SamplerOutput]:
         next_tokens = self.sampler(self.lm_head.weight, hidden_states,
                                    sampling_metadata)
         return next_tokens
vllm/model_executor/models/yi.py
@@ -286,7 +286,7 @@ class YiForCausalLM(nn.Module):
         self,
         hidden_states: torch.Tensor,
         sampling_metadata: SamplingMetadata,
-    ) -> SamplerOutput:
+    ) -> Optional[SamplerOutput]:
         next_tokens = self.sampler(self.lm_head.weight, hidden_states,
                                    sampling_metadata)
         return next_tokens
vllm/model_executor/parallel_utils/communication_op.py
 import torch

 from vllm.model_executor.parallel_utils.parallel_state import (
+    get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
     get_tensor_model_parallel_group,
 )
@@ -45,3 +46,61 @@ def tensor_model_parallel_all_gather(input_, dim=-1):
                                           (world_size * input_size[dim], ) +
                                           input_size[dim + 1:])
     return output_tensor
+
+
+def tensor_model_parallel_gather(input_, dst=0, dim=-1):
+    """Gather the input tensor across model parallel group.
+
+    NOTE: We assume that the input tensor is on the same device across
+    all the ranks.
+    """
+    world_size = get_tensor_model_parallel_world_size()
+    # Bypass the function if we are using only 1 GPU.
+    if world_size == 1:
+        return input_
+    assert -input_.dim() <= dim < input_.dim(), (
+        f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
+    if dim < 0:
+        # Convert negative dim to positive.
+        dim += input_.dim()
+    # Allocate output tensor.
+    if get_tensor_model_parallel_rank() == dst:
+        gather_list = [torch.empty_like(input_) for _ in range(world_size)]
+    else:
+        gather_list = None
+    # Gather.
+    torch.distributed.gather(input_,
+                             gather_list,
+                             dst=dst,
+                             group=get_tensor_model_parallel_group())
+    if get_tensor_model_parallel_rank() == dst:
+        output_tensor = torch.cat(gather_list, dim=dim)
+    else:
+        output_tensor = None
+    return output_tensor
+
+
+def broadcast(input_, src=0):
+    """Broadcast the input tensor."""
+    world_size = torch.distributed.get_world_size()
+    assert 0 <= src < world_size, f"Invalid src rank ({src})"
+    # Bypass the function if we are using only 1 GPU.
+    if world_size == 1:
+        return input_
+    # Broadcast.
+    torch.distributed.broadcast(input_, src=src)
+    return input_
+
+
+def broadcast_object_list(obj_list, src=0):
+    """Broadcast the input object list."""
+    world_size = torch.distributed.get_world_size()
+    assert 0 <= src < world_size, f"Invalid src rank ({src})"
+    # Bypass the function if we are using only 1 GPU.
+    if world_size == 1:
+        return obj_list
+    # Broadcast.
+    torch.distributed.broadcast_object_list(obj_list, src=src)
+    return obj_list
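For the new gather helper, here is a standalone sketch of the same gather-to-one-rank pattern written against the default torch.distributed group so it can run with `torchrun --nproc_per_node=2` and the gloo backend (vLLM's version above targets its tensor-parallel group and NCCL; gather_concat is an illustrative name, not a vLLM API):

import torch
import torch.distributed as dist


def gather_concat(input_: torch.Tensor, dst: int = 0, dim: int = -1):
    """Gather `input_` from every rank onto `dst` and concatenate along `dim`."""
    world_size = dist.get_world_size()
    if world_size == 1:
        return input_
    if dim < 0:
        dim += input_.dim()
    # Only the destination rank needs receive buffers.
    if dist.get_rank() == dst:
        gather_list = [torch.empty_like(input_) for _ in range(world_size)]
    else:
        gather_list = None
    dist.gather(input_, gather_list, dst=dst)
    # Every other rank gets None back, mirroring how non-driver ranks skip sampling.
    return torch.cat(gather_list, dim=dim) if dist.get_rank() == dst else None


if __name__ == "__main__":
    dist.init_process_group(backend="gloo")
    part = torch.full((2, 3), float(dist.get_rank()))
    print(dist.get_rank(), gather_concat(part, dst=0, dim=-1))
    dist.destroy_process_group()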
vllm/model_executor/sampling_metadata.py
 from dataclasses import dataclass
-from typing import Dict, List, Tuple
+from typing import Dict, List, Optional, Tuple

 import torch
@@ -18,24 +18,29 @@ class SamplingMetadata:
         seq_data: Seq_id -> SequenceData.
         prompt_lens: Lengths of prompts.
         selected_token_indices: Token indices selected for sampling.
-        categorized_sample_indices: SamplingType -> token indicies to sample.
+        categorized_sample_indices: SamplingType -> token indices to sample.
+        perform_sampling: Whether to perform sampling. This option is used to
+            make the sampling only happens in the driver worker, and disable
+            sampling in other worker processes.
     """

     def __init__(
         self,
-        seq_groups: List[Tuple[List[int], SamplingParams]],
-        seq_data: Dict[int, SequenceData],
-        prompt_lens: List[int],
+        seq_groups: Optional[List[Tuple[List[int], SamplingParams]]],
+        seq_data: Optional[Dict[int, SequenceData]],
+        prompt_lens: Optional[List[int]],
         selected_token_indices: torch.Tensor,
-        categorized_sample_indices: Dict[SamplingType, torch.Tensor],
+        categorized_sample_indices: Optional[Dict[SamplingType, torch.Tensor]],
+        perform_sampling: bool = True,
     ) -> None:
         self.seq_groups = seq_groups
         self.seq_data = seq_data
         self.prompt_lens = prompt_lens
         self.selected_token_indices = selected_token_indices
         self.categorized_sample_indices = categorized_sample_indices
+        self.perform_sampling = perform_sampling

-        self.num_prompts = len(prompt_lens)
+        self.num_prompts = len(prompt_lens) if prompt_lens is not None else 0

     def __repr__(self) -> str:
         return (
@@ -44,7 +49,8 @@ class SamplingMetadata:
             f"seq_data={self.seq_data}, "
             f"prompt_lens={self.prompt_lens}, "
             f"selected_token_indices={self.selected_token_indices}, "
-            f"categorized_sample_indices={self.categorized_sample_indices})")
+            f"categorized_sample_indices={self.categorized_sample_indices}), "
+            f"perform_sampling={self.perform_sampling})")

 @dataclass
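The new perform_sampling flag exists so that only the driver worker actually samples: the other ranks still run the forward pass but hand back nothing, which is why the model files above now return Optional[SamplerOutput]. A simplified sketch of that contract (an assumption about how the flag is consumed; the actual Sampler code is not among the files shown on this page):

from typing import Any, Optional


def maybe_sample(sampler, lm_head_weight, hidden_states,
                 sampling_metadata) -> Optional[Any]:
    # Non-driver ranks build SamplingMetadata(..., perform_sampling=False)
    # and therefore receive None instead of a SamplerOutput.
    if not sampling_metadata.perform_sampling:
        return None
    return sampler(lm_head_weight, hidden_states, sampling_metadata)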
vllm/utils.py
 import enum
+import os
 import socket
 import uuid
 from platform import uname
+from typing import List

 import psutil
 import torch
@@ -55,7 +57,15 @@ def in_wsl() -> bool:
     return "microsoft" in " ".join(uname()).lower()


-def get_open_port():
+def get_ip() -> str:
+    return socket.gethostbyname(socket.gethostname())
+
+
+def get_open_port() -> int:
     with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
         s.bind(("", 0))
         return s.getsockname()[1]
+
+
+def set_cuda_visible_devices(device_ids: List[int]) -> None:
+    os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, device_ids))
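A small usage sketch for the new helpers (a hypothetical call site; in this commit the actual consumers of these utilities live outside vllm/utils.py):

from vllm.utils import get_ip, get_open_port, set_cuda_visible_devices

# Build a torch.distributed init URL on the driver node and pin this process
# to two GPUs. Both values are examples, not something this diff itself does.
distributed_init_method = f"tcp://{get_ip()}:{get_open_port()}"
set_cuda_visible_devices([0, 1])
print(distributed_init_method)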
vllm/worker/model_runner.py
 import time
-from typing import Dict, List, Tuple, Union
+from typing import Dict, List, Optional, Tuple, Union

 import numpy as np
 import torch
@@ -8,6 +8,8 @@ import torch.nn as nn
 from vllm.config import ModelConfig, ParallelConfig, SchedulerConfig
 from vllm.logger import init_logger
 from vllm.model_executor import get_model, InputMetadata, SamplingMetadata
+from vllm.model_executor.parallel_utils.communication_op import (
+    broadcast, broadcast_object_list)
 from vllm.sampling_params import SamplingParams, SamplingType
 from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
 from vllm.utils import in_wsl
@@ -28,10 +30,12 @@ class ModelRunner:
         model_config: ModelConfig,
         parallel_config: ParallelConfig,
         scheduler_config: SchedulerConfig,
+        is_driver_worker: bool = False,
     ):
         self.model_config = model_config
         self.parallel_config = parallel_config
         self.scheduler_config = scheduler_config
+        self.is_driver_worker = is_driver_worker

         # model_config can be None in tests/samplers/test_sampler.py.
         # FIXME(woosuk): This is a hack to make the tests work. Refactor this.
@@ -70,7 +74,7 @@ class ModelRunner:
     def _prepare_prompt(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
-    ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata]:
+    ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, List[int]]:
         assert len(seq_group_metadata_list) > 0
         input_tokens: List[List[int]] = []
         input_positions: List[List[int]] = []
@@ -135,14 +139,14 @@ class ModelRunner:
             dtype=torch.long)

         input_metadata = InputMetadata(
-            prompt_lens=prompt_lens,
+            is_prompt=True,
             slot_mapping=slot_mapping,
             max_context_len=None,
             context_lens=None,
             block_tables=None,
             use_cuda_graph=False,
         )
-        return input_tokens, input_positions, input_metadata
+        return input_tokens, input_positions, input_metadata, prompt_lens

     def _prepare_decode(
         self,
@@ -203,32 +207,24 @@ class ModelRunner:
                 block_tables.append([])
             batch_size = graph_batch_size

-        # When using CUDA graph, we don't need to make the tensors on the GPU
-        # because they will be eventually copied to the designated GPU buffer.
-        device = "cpu" if use_captured_graph else "cuda"
-        pin_memory = use_captured_graph and not self.in_wsl
         input_tokens = _make_tensor_with_pad(input_tokens,
                                              max_len=1,
                                              pad=0,
                                              dtype=torch.long,
-                                             device=device,
-                                             pin_memory=pin_memory)
+                                             device="cuda")
         input_positions = _make_tensor_with_pad(input_positions,
                                                 max_len=1,
                                                 pad=0,
                                                 dtype=torch.long,
-                                                device=device,
-                                                pin_memory=pin_memory)
+                                                device="cuda")
         slot_mapping = _make_tensor_with_pad(slot_mapping,
                                              max_len=1,
                                              pad=_PAD_SLOT_ID,
                                              dtype=torch.long,
-                                             device=device,
-                                             pin_memory=pin_memory)
+                                             device="cuda")
         context_lens = torch.tensor(context_lens,
                                     dtype=torch.int,
-                                    device=device,
-                                    pin_memory=pin_memory)
+                                    device="cuda")
         if use_captured_graph:
             # The shape of graph_block_tables is
@@ -237,17 +233,18 @@ class ModelRunner:
             for i, block_table in enumerate(block_tables):
                 if block_table:
                     input_block_tables[i, :len(block_table)] = block_table
-            block_tables = torch.tensor(input_block_tables, device=device)
+            block_tables = torch.tensor(input_block_tables, device="cuda")
         else:
             block_tables = _make_tensor_with_pad(
                 block_tables,
                 max_len=max_context_len,
                 pad=0,
                 dtype=torch.int,
+                device="cuda",
             )

         input_metadata = InputMetadata(
-            prompt_lens=[],
+            is_prompt=False,
             slot_mapping=slot_mapping,
             max_context_len=max_context_len,
             context_lens=context_lens,
@@ -326,23 +323,127 @@ class ModelRunner:
         )
         return sampling_metadata

+    def prepare_input_tensors(
+        self,
+        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
+    ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, SamplingMetadata]:
+        if self.is_driver_worker:
+            # NOTE: We assume that all sequences in the group are all prompts or
+            # all decodes.
+            is_prompt = seq_group_metadata_list[0].is_prompt
+            # Prepare input tensors.
+            if is_prompt:
+                (input_tokens, input_positions, input_metadata,
+                 prompt_lens) = self._prepare_prompt(seq_group_metadata_list)
+            else:
+                (input_tokens, input_positions,
+                 input_metadata) = self._prepare_decode(seq_group_metadata_list)
+                prompt_lens = []
+            sampling_metadata = self._prepare_sample(seq_group_metadata_list,
+                                                     prompt_lens)
+
+            def get_size_or_none(x: Optional[torch.Tensor]):
+                return x.size() if x is not None else None
+
+            # Broadcast the input data. For input tensors, we first broadcast
+            # its shape and then broadcast the tensor to avoid high
+            # serialization cost.
+            py_data = {
+                "input_tokens_size": input_tokens.size(),
+                "input_positions_size": input_positions.size(),
+                "is_prompt": input_metadata.is_prompt,
+                "slot_mapping_size": get_size_or_none(
+                    input_metadata.slot_mapping),
+                "max_context_len": input_metadata.max_context_len,
+                "context_lens_size": get_size_or_none(
+                    input_metadata.context_lens),
+                "block_tables_size": get_size_or_none(
+                    input_metadata.block_tables),
+                "use_cuda_graph": input_metadata.use_cuda_graph,
+                "selected_token_indices_size":
+                    sampling_metadata.selected_token_indices.size(),
+            }
+            broadcast_object_list([py_data], src=0)
+            # TODO(zhuohan): Combine the broadcasts or set async_op=True.
+            broadcast(input_tokens, src=0)
+            broadcast(input_positions, src=0)
+            if input_metadata.slot_mapping is not None:
+                broadcast(input_metadata.slot_mapping, src=0)
+            if input_metadata.context_lens is not None:
+                broadcast(input_metadata.context_lens, src=0)
+            if input_metadata.block_tables is not None:
+                broadcast(input_metadata.block_tables, src=0)
+            broadcast(sampling_metadata.selected_token_indices, src=0)
+        else:
+            receving_list = [None]
+            broadcast_object_list(receving_list, src=0)
+            py_data = receving_list[0]
+            input_tokens = torch.empty(*py_data["input_tokens_size"],
                                       dtype=torch.long,
                                       device="cuda")
+            broadcast(input_tokens, src=0)
+            input_positions = torch.empty(*py_data["input_positions_size"],
                                          dtype=torch.long,
                                          device="cuda")
+            broadcast(input_positions, src=0)
+            if py_data["slot_mapping_size"] is not None:
+                slot_mapping = torch.empty(*py_data["slot_mapping_size"],
                                           dtype=torch.long,
                                           device="cuda")
+                broadcast(slot_mapping, src=0)
+            else:
+                slot_mapping = None
+            if py_data["context_lens_size"] is not None:
+                context_lens = torch.empty(*py_data["context_lens_size"],
                                           dtype=torch.int,
                                           device="cuda")
+                broadcast(context_lens, src=0)
+            else:
+                context_lens = None
+            if py_data["block_tables_size"] is not None:
+                block_tables = torch.empty(*py_data["block_tables_size"],
                                           dtype=torch.int,
                                           device="cuda")
+                broadcast(block_tables, src=0)
+            else:
+                block_tables = None
+            selected_token_indices = torch.empty(
+                *py_data["selected_token_indices_size"],
+                dtype=torch.long,
+                device="cuda")
+            broadcast(selected_token_indices, src=0)
+            input_metadata = InputMetadata(
+                is_prompt=py_data["is_prompt"],
+                slot_mapping=slot_mapping,
+                max_context_len=py_data["max_context_len"],
+                context_lens=context_lens,
+                block_tables=block_tables,
+                use_cuda_graph=py_data["use_cuda_graph"],
+            )
+            sampling_metadata = SamplingMetadata(
+                seq_groups=None,
+                seq_data=None,
+                prompt_lens=None,
+                selected_token_indices=selected_token_indices,
+                categorized_sample_indices=None,
+                perform_sampling=False,
+            )
+
+        return input_tokens, input_positions, input_metadata, sampling_metadata
+
     @torch.inference_mode()
     def execute_model(
         self,
-        seq_group_metadata_list: List[SequenceGroupMetadata],
+        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
         kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
-    ) -> SamplerOutput:
-        # NOTE: We assume that all sequences in the group are all prompts or
-        # all decodes.
-        is_prompt = seq_group_metadata_list[0].is_prompt
-        # Prepare input tensors.
-        if is_prompt:
-            inputs = self._prepare_prompt(seq_group_metadata_list)
-            input_tokens, input_positions, input_metadata = inputs
-        else:
-            inputs = self._prepare_decode(seq_group_metadata_list)
-            input_tokens, input_positions, input_metadata = inputs
+    ) -> Optional[SamplerOutput]:
+        input_tokens, input_positions, input_metadata, sampling_metadata = (
+            self.prepare_input_tensors(seq_group_metadata_list))

         # Execute the model.
         if input_metadata.use_cuda_graph:
             graph_batch_size = input_tokens.shape[0]
@@ -356,9 +457,6 @@ class ModelRunner:
             input_metadata=input_metadata,
         )

-        sampling_metadata = self._prepare_sample(seq_group_metadata_list,
-                                                 input_metadata.prompt_lens)
-
         # Sample the next token.
         output = self.model.sample(
             hidden_states=hidden_states,
@@ -424,7 +522,7 @@ class ModelRunner:
         for batch_size in reversed(_BATCH_SIZES_TO_CAPTURE):
             # Create dummy input_metadata.
             input_metadata = InputMetadata(
-                prompt_lens=[],
+                is_prompt=False,
                 slot_mapping=slot_mapping[:batch_size],
                 max_context_len=self.max_context_len_to_capture,
                 context_lens=context_lens[:batch_size],
vllm/worker/worker.py
@@ -8,6 +8,8 @@ import torch.distributed
 from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
                          SchedulerConfig)
 from vllm.model_executor import set_random_seed
+from vllm.model_executor.parallel_utils.communication_op import (
+    broadcast_object_list)
 from vllm.model_executor.parallel_utils.parallel_state import (
     initialize_model_parallel)
 from vllm.sequence import SamplerOutput, SequenceGroupMetadata
@@ -28,17 +30,23 @@ class Worker:
         model_config: ModelConfig,
         parallel_config: ParallelConfig,
         scheduler_config: SchedulerConfig,
-        rank: Optional[int] = None,
-        distributed_init_method: Optional[str] = None,
+        local_rank: int,
+        rank: int,
+        distributed_init_method: str,
+        is_driver_worker: bool = False,
     ) -> None:
         self.model_config = model_config
         self.parallel_config = parallel_config
         self.scheduler_config = scheduler_config
+        self.local_rank = local_rank
         self.rank = rank
         self.distributed_init_method = distributed_init_method
+        self.is_driver_worker = is_driver_worker
+        if self.is_driver_worker:
+            assert self.rank == 0, "The driver worker must have rank 0."

         self.model_runner = ModelRunner(model_config, parallel_config,
-                                        scheduler_config)
+                                        scheduler_config, is_driver_worker)
         # Uninitialized cache engine. Will be initialized by
         # self.init_cache_engine().
         self.cache_config = None
@@ -57,13 +65,7 @@ class Worker:
         # This env var set by Ray causes exceptions with graph building.
         os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
-        # Env vars will be set by Ray.
-        self.rank = self.rank if self.rank is not None else int(
-            os.getenv("RANK", "-1"))
-        local_rank = int(os.getenv("LOCAL_RANK", "0"))
-        self.device = torch.device(f"cuda:{local_rank}")
-        if self.rank < 0:
-            raise ValueError("Invalid or unspecified rank.")
+        self.device = torch.device(f"cuda:{self.local_rank}")
         torch.cuda.set_device(self.device)

         _check_if_gpu_supports_dtype(self.model_config.dtype)
@@ -125,14 +127,12 @@ class Worker:
         # the model initialization and profiling.
         set_random_seed(self.model_config.seed)

-    @torch.inference_mode()
-    def execute_model(
+    def cache_swap(
         self,
-        seq_group_metadata_list: List[SequenceGroupMetadata],
         blocks_to_swap_in: Dict[int, int],
         blocks_to_swap_out: Dict[int, int],
         blocks_to_copy: Dict[int, List[int]],
-    ) -> SamplerOutput:
+    ) -> None:
         # Issue cache operations.
         issued_cache_op = False
         if blocks_to_swap_in:
@@ -152,8 +152,38 @@ class Worker:
         if cache_events is not None:
             for event in cache_events:
                 event.wait()

+    @torch.inference_mode()
+    def execute_model(
+        self,
+        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None,
+        blocks_to_swap_in: Optional[Dict[int, int]] = None,
+        blocks_to_swap_out: Optional[Dict[int, int]] = None,
+        blocks_to_copy: Optional[Dict[int, List[int]]] = None,
+    ) -> Optional[SamplerOutput]:
+        if self.is_driver_worker:
+            assert seq_group_metadata_list is not None
+            num_seq_groups = len(seq_group_metadata_list)
+            assert blocks_to_swap_in is not None
+            assert blocks_to_swap_out is not None
+            assert blocks_to_copy is not None
+            block_swapping_info = [
+                blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy
+            ]
+            broadcast_object_list([num_seq_groups] + block_swapping_info,
+                                  src=0)
+        else:
+            # num_seq_groups, blocks_to_swap_in, blocks_to_swap_out,
+            # blocks_to_copy (4 elements)
+            recv_data = [None] * 4
+            broadcast_object_list(recv_data, src=0)
+            num_seq_groups = recv_data[0]
+            block_swapping_info = recv_data[1:]
+
+        self.cache_swap(*block_swapping_info)
+
         # If there is no input, we don't need to execute the model.
-        if not seq_group_metadata_list:
+        if num_seq_groups == 0:
             return {}

         output = self.model_runner.execute_model(seq_group_metadata_list,
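Putting the Worker changes together: rank 0 (the driver) calls execute_model with the scheduler's output, every other rank calls it with no arguments, and a four-element object-list broadcast carries num_seq_groups plus the three block-swap dicts. A runnable stand-in for that handshake, using the default process group and gloo under `torchrun --nproc_per_node=2` (the dictionaries are made-up swap requests, not vLLM data, and exchange_step_control is an illustrative name):

import torch.distributed as dist


def exchange_step_control(num_seq_groups=None, swap_in=None, swap_out=None,
                          copy=None):
    if dist.get_rank() == 0:
        # Driver: publish this step's control data in one small broadcast.
        msg = [num_seq_groups, swap_in, swap_out, copy]
        dist.broadcast_object_list(msg, src=0)
    else:
        # Workers: receive num_seq_groups and the three block-swap dicts.
        msg = [None] * 4
        dist.broadcast_object_list(msg, src=0)
    return msg


if __name__ == "__main__":
    dist.init_process_group(backend="gloo")
    print(dist.get_rank(), exchange_step_control(2, {}, {}, {1: [2]}))
    dist.destroy_process_group()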