Commit fcffb7c8
Authored Jan 16, 2024 by zhuwenwen

    Merge branch 'vllm-v0.2.7-dtk23.10'

Parents: eb181638, 4095d0db
Changes: 56 files in total. Showing 16 changed files on this page, with 277 additions and 71 deletions (+277, -71).
vllm/model_executor/models/gpt_neox.py                    +4   -1
vllm/model_executor/models/internlm.py                    +1   -1
vllm/model_executor/models/llama.py                       +1   -1
vllm/model_executor/models/mistral.py                     +1   -1
vllm/model_executor/models/mixtral.py                     +2   -2
vllm/model_executor/models/mpt.py                         +1   -1
vllm/model_executor/models/opt.py                         +1   -1
vllm/model_executor/models/phi_1_5.py                     +1   -1
vllm/model_executor/models/qwen.py                        +1   -1
vllm/model_executor/models/yi.py                          +1   -1
vllm/model_executor/parallel_utils/communication_op.py    +59  -0
vllm/model_executor/sampling_metadata.py                  +14  -8
vllm/sampling_params.py                                   +1   -1
vllm/utils.py                                             +11  -1
vllm/worker/model_runner.py                               +133 -35
vllm/worker/worker.py                                     +45  -15
vllm/model_executor/models/gpt_neox.py

@@ -54,6 +54,7 @@ class GPTNeoXAttention(nn.Module):
         self.total_num_heads = config.num_attention_heads
         self.hidden_size = config.hidden_size
         self.head_size = self.hidden_size // self.total_num_heads
+        self.bias = getattr(config, "attention_bias", True)
         tensor_model_parallel_world_size = (
             get_tensor_model_parallel_world_size())

@@ -65,11 +66,13 @@ class GPTNeoXAttention(nn.Module):
             config.hidden_size,
             self.head_size,
             self.total_num_heads,
+            bias=self.bias,
             linear_method=linear_method,
         )
         self.dense = RowParallelLinear(
             config.hidden_size,
             config.hidden_size,
+            bias=self.bias,
             linear_method=linear_method,
         )
         scaling = self.head_size**-0.5

@@ -252,7 +255,7 @@ class GPTNeoXForCausalLM(nn.Module):
         self,
         hidden_states: torch.Tensor,
         sampling_metadata: SamplingMetadata,
-    ) -> SamplerOutput:
+    ) -> Optional[SamplerOutput]:
         next_tokens = self.sampler(self.embed_out.weight, hidden_states,
                                    sampling_metadata)
         return next_tokens
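Note: the `Optional[SamplerOutput]` return type recurs in every model file below. It exists because, with the new `perform_sampling` flag (see vllm/model_executor/sampling_metadata.py further down), only the driver worker actually samples; other ranks get `None` back. A minimal, self-contained sketch of that contract, using a hypothetical `sample` helper rather than the vLLM sampler itself:

from typing import List, Optional

def sample(logits: List[List[float]], perform_sampling: bool) -> Optional[List[int]]:
    # Non-driver workers run the model only to keep their KV caches in sync;
    # they skip sampling entirely and return None.
    if not perform_sampling:
        return None
    # Driver worker: greedy pick per sequence, just for illustration.
    return [max(range(len(row)), key=row.__getitem__) for row in logits]

assert sample([[0.1, 0.7, 0.2]], perform_sampling=True) == [1]
assert sample([[0.1, 0.7, 0.2]], perform_sampling=False) is None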
vllm/model_executor/models/internlm.py

@@ -255,7 +255,7 @@ class InternLMForCausalLM(nn.Module):
         self,
         hidden_states: torch.Tensor,
         sampling_metadata: SamplingMetadata,
-    ) -> SamplerOutput:
+    ) -> Optional[SamplerOutput]:
         next_tokens = self.sampler(self.lm_head.weight, hidden_states,
                                    sampling_metadata)
         return next_tokens
vllm/model_executor/models/llama.py

@@ -291,7 +291,7 @@ class LlamaForCausalLM(nn.Module):
         self,
         hidden_states: torch.Tensor,
         sampling_metadata: SamplingMetadata,
-    ) -> SamplerOutput:
+    ) -> Optional[SamplerOutput]:
         next_tokens = self.sampler(self.lm_head.weight, hidden_states,
                                    sampling_metadata)
         return next_tokens
vllm/model_executor/models/mistral.py

@@ -287,7 +287,7 @@ class MistralForCausalLM(nn.Module):
         self,
         hidden_states: torch.Tensor,
         sampling_metadata: SamplingMetadata,
-    ) -> SamplerOutput:
+    ) -> Optional[SamplerOutput]:
         next_tokens = self.sampler(self.lm_head.weight, hidden_states,
                                    sampling_metadata)
         return next_tokens
vllm/model_executor/models/mixtral.py

@@ -320,7 +320,7 @@ class MixtralModel(nn.Module):
         positions: torch.Tensor,
         kv_caches: List[KVCache],
         input_metadata: InputMetadata,
-    ) -> SamplerOutput:
+    ) -> torch.Tensor:
         hidden_states = self.embed_tokens(input_ids)
         residual = None
         for i in range(len(self.layers)):

@@ -361,7 +361,7 @@ class MixtralForCausalLM(nn.Module):
         self,
         hidden_states: Optional[torch.Tensor],
         sampling_metadata: SamplingMetadata,
-    ) -> SamplerOutput:
+    ) -> Optional[SamplerOutput]:
         next_tokens = self.sampler(self.lm_head.weight, hidden_states,
                                    sampling_metadata)
         return next_tokens
vllm/model_executor/models/mpt.py

@@ -276,7 +276,7 @@ class MPTForCausalLM(nn.Module):
         self,
         hidden_states: torch.Tensor,
         sampling_metadata: SamplingMetadata,
-    ) -> SamplerOutput:
+    ) -> Optional[SamplerOutput]:
         next_tokens = self.sampler(self.lm_head_weight, hidden_states,
                                    sampling_metadata)
         return next_tokens
vllm/model_executor/models/opt.py

@@ -309,7 +309,7 @@ class OPTForCausalLM(nn.Module):
         self,
         hidden_states: torch.Tensor,
         sampling_metadata: SamplingMetadata,
-    ) -> SamplerOutput:
+    ) -> Optional[SamplerOutput]:
         next_tokens = self.sampler(self.lm_head_weight, hidden_states,
                                    sampling_metadata)
         return next_tokens
vllm/model_executor/models/phi_1_5.py

@@ -280,7 +280,7 @@ class PhiForCausalLM(nn.Module):
         self,
         hidden_states: torch.Tensor,
         sampling_metadata: SamplingMetadata,
-    ) -> SamplerOutput:
+    ) -> Optional[SamplerOutput]:
         head = self.lm_head.linear
         next_tokens = self.sampler(head.weight, hidden_states,
                                    sampling_metadata, head.bias)
vllm/model_executor/models/qwen.py

@@ -247,7 +247,7 @@ class QWenLMHeadModel(nn.Module):
         self,
         hidden_states: torch.Tensor,
         sampling_metadata: SamplingMetadata,
-    ) -> SamplerOutput:
+    ) -> Optional[SamplerOutput]:
         next_tokens = self.sampler(self.lm_head.weight, hidden_states,
                                    sampling_metadata)
         return next_tokens
vllm/model_executor/models/yi.py

@@ -286,7 +286,7 @@ class YiForCausalLM(nn.Module):
         self,
         hidden_states: torch.Tensor,
         sampling_metadata: SamplingMetadata,
-    ) -> SamplerOutput:
+    ) -> Optional[SamplerOutput]:
         next_tokens = self.sampler(self.lm_head.weight, hidden_states,
                                    sampling_metadata)
         return next_tokens
vllm/model_executor/parallel_utils/communication_op.py

 import torch

 from vllm.model_executor.parallel_utils.parallel_state import (
+    get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
     get_tensor_model_parallel_group,
 )

@@ -45,3 +46,61 @@ def tensor_model_parallel_all_gather(input_, dim=-1):
                                        (world_size * input_size[dim], ) +
                                        input_size[dim + 1:])
     return output_tensor
+
+
+def tensor_model_parallel_gather(input_, dst=0, dim=-1):
+    """Gather the input tensor across model parallel group.
+
+    NOTE: We assume that the input tensor is on the same device across
+    all the ranks.
+    """
+    world_size = get_tensor_model_parallel_world_size()
+    # Bypass the function if we are using only 1 GPU.
+    if world_size == 1:
+        return input_
+    assert -input_.dim() <= dim < input_.dim(), (
+        f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
+    if dim < 0:
+        # Convert negative dim to positive.
+        dim += input_.dim()
+    # Allocate output tensor.
+    if get_tensor_model_parallel_rank() == dst:
+        gather_list = [torch.empty_like(input_) for _ in range(world_size)]
+    else:
+        gather_list = None
+    # Gather.
+    torch.distributed.gather(input_,
+                             gather_list,
+                             dst=dst,
+                             group=get_tensor_model_parallel_group())
+    if get_tensor_model_parallel_rank() == dst:
+        output_tensor = torch.cat(gather_list, dim=dim)
+    else:
+        output_tensor = None
+    return output_tensor
+
+
+def broadcast(input_, src=0):
+    """Broadcast the input tensor."""
+    world_size = torch.distributed.get_world_size()
+    assert 0 <= src < world_size, f"Invalid src rank ({src})"
+    # Bypass the function if we are using only 1 GPU.
+    if world_size == 1:
+        return input_
+    # Broadcast.
+    torch.distributed.broadcast(input_, src=src)
+    return input_
+
+
+def broadcast_object_list(obj_list, src=0):
+    """Broadcast the input object list."""
+    world_size = torch.distributed.get_world_size()
+    assert 0 <= src < world_size, f"Invalid src rank ({src})"
+    # Bypass the function if we are using only 1 GPU.
+    if world_size == 1:
+        return obj_list
+    # Broadcast.
+    torch.distributed.broadcast_object_list(obj_list, src=src)
+    return obj_list
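The three new helpers are meant to be used in matched pairs across ranks: the driver first broadcasts cheap Python metadata (shapes, flags) with broadcast_object_list, then the tensors themselves with broadcast, while the other ranks pre-allocate buffers from the received shapes. A rough usage sketch, assuming torch.distributed is already initialized with one GPU per rank and rank 0 is the driver; the dtype of the receive buffer must match what the driver sends (float32 assumed here):

import torch
from vllm.model_executor.parallel_utils.communication_op import (
    broadcast, broadcast_object_list)

def sync_tensor(rank: int, data: torch.Tensor = None) -> torch.Tensor:
    if rank == 0:
        # Driver: publish the shape first (cheap Python object)...
        broadcast_object_list([{"size": tuple(data.size())}], src=0)
        # ...then the tensor itself.
        return broadcast(data, src=0)
    # Other ranks: learn the shape, allocate a buffer, receive into it.
    recv = [None]
    broadcast_object_list(recv, src=0)
    buf = torch.empty(*recv[0]["size"], device="cuda")
    return broadcast(buf, src=0)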
vllm/model_executor/sampling_metadata.py

 from dataclasses import dataclass
-from typing import Dict, List, Tuple
+from typing import Dict, List, Optional, Tuple

 import torch

@@ -18,24 +18,29 @@ class SamplingMetadata:
         seq_data: Seq_id -> SequenceData.
         prompt_lens: Lengths of prompts.
         selected_token_indices: Token indices selected for sampling.
-        categorized_sample_indices: SamplingType -> token indicies to sample.
+        categorized_sample_indices: SamplingType -> token indices to sample.
+        perform_sampling: Whether to perform sampling. This option is used to
+            make the sampling only happens in the driver worker, and disable
+            sampling in other worker processes.
     """

     def __init__(
         self,
-        seq_groups: List[Tuple[List[int], SamplingParams]],
-        seq_data: Dict[int, SequenceData],
-        prompt_lens: List[int],
+        seq_groups: Optional[List[Tuple[List[int], SamplingParams]]],
+        seq_data: Optional[Dict[int, SequenceData]],
+        prompt_lens: Optional[List[int]],
         selected_token_indices: torch.Tensor,
-        categorized_sample_indices: Dict[SamplingType, torch.Tensor],
+        categorized_sample_indices: Optional[Dict[SamplingType, torch.Tensor]],
+        perform_sampling: bool = True,
     ) -> None:
         self.seq_groups = seq_groups
         self.seq_data = seq_data
         self.prompt_lens = prompt_lens
         self.selected_token_indices = selected_token_indices
         self.categorized_sample_indices = categorized_sample_indices
+        self.perform_sampling = perform_sampling

-        self.num_prompts = len(prompt_lens)
+        self.num_prompts = len(prompt_lens) if prompt_lens is not None else 0

     def __repr__(self) -> str:
         return (

@@ -44,7 +49,8 @@ class SamplingMetadata:
             f"seq_data={self.seq_data}, "
             f"prompt_lens={self.prompt_lens}, "
             f"selected_token_indices={self.selected_token_indices}, "
-            f"categorized_sample_indices={self.categorized_sample_indices})")
+            f"categorized_sample_indices={self.categorized_sample_indices}), "
+            f"perform_sampling={self.perform_sampling})")


 @dataclass
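With the new optional fields, a non-driver worker can build a stripped-down SamplingMetadata that carries only the broadcast token indices and tells the sampler to stand down; this mirrors the prepare_input_tensors() branch added to vllm/worker/model_runner.py below. A small sketch (example values only):

import torch
from vllm.model_executor.sampling_metadata import SamplingMetadata

# Indices of the hidden states that would feed the sampler on the driver.
selected_token_indices = torch.tensor([0, 5, 9], dtype=torch.long)

worker_side_metadata = SamplingMetadata(
    seq_groups=None,                  # per-request details stay on the driver
    seq_data=None,
    prompt_lens=None,                 # num_prompts then falls back to 0
    selected_token_indices=selected_token_indices,
    categorized_sample_indices=None,
    perform_sampling=False,           # the sampler returns None on this rank
)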
vllm/sampling_params.py

@@ -100,7 +100,7 @@ class SamplingParams:
         temperature: float = 1.0,
         top_p: float = 1.0,
         top_k: int = -1,
-        min_p: int = 0.0,
+        min_p: float = 0.0,
         use_beam_search: bool = False,
         length_penalty: float = 1.0,
         early_stopping: Union[bool, str] = False,
vllm/utils.py

 import enum
+import os
 import socket
 import uuid
 from platform import uname
+from typing import List

 import psutil
 import torch

@@ -55,7 +57,15 @@ def in_wsl() -> bool:
     return "microsoft" in " ".join(uname()).lower()


-def get_open_port():
+def get_ip() -> str:
+    return socket.gethostbyname(socket.gethostname())
+
+
+def get_open_port() -> int:
     with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
         s.bind(("", 0))
         return s.getsockname()[1]
+
+
+def set_cuda_visible_devices(device_ids: List[int]) -> None:
+    os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, device_ids))
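get_ip, get_open_port, and set_cuda_visible_devices are small process-setup helpers. One plausible way they combine when a driver process prepares workers, shown as an illustrative sketch only (the tcp:// init-method string is an assumption, not code from this commit):

from vllm.utils import get_ip, get_open_port, set_cuda_visible_devices

# Pin this process to two devices and build a rendezvous address that
# torch.distributed-style init methods can consume.
set_cuda_visible_devices([0, 1])
distributed_init_method = f"tcp://{get_ip()}:{get_open_port()}"
print(distributed_init_method)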
vllm/worker/model_runner.py

 import time
-from typing import Dict, List, Tuple, Union
+from typing import Dict, List, Optional, Tuple, Union

 import numpy as np
 import torch

@@ -8,6 +8,8 @@ import torch.nn as nn
 from vllm.config import ModelConfig, ParallelConfig, SchedulerConfig
 from vllm.logger import init_logger
 from vllm.model_executor import get_model, InputMetadata, SamplingMetadata
+from vllm.model_executor.parallel_utils.communication_op import (
+    broadcast, broadcast_object_list)
 from vllm.sampling_params import SamplingParams, SamplingType
 from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
 from vllm.utils import in_wsl

@@ -28,10 +30,12 @@ class ModelRunner:
         model_config: ModelConfig,
         parallel_config: ParallelConfig,
         scheduler_config: SchedulerConfig,
+        is_driver_worker: bool = False,
     ):
         self.model_config = model_config
         self.parallel_config = parallel_config
         self.scheduler_config = scheduler_config
+        self.is_driver_worker = is_driver_worker

         # model_config can be None in tests/samplers/test_sampler.py.
         # FIXME(woosuk): This is a hack to make the tests work. Refactor this.

@@ -70,7 +74,7 @@ class ModelRunner:
     def _prepare_prompt(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
-    ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata]:
+    ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, List[int]]:
         assert len(seq_group_metadata_list) > 0
         input_tokens: List[List[int]] = []
         input_positions: List[List[int]] = []

@@ -135,14 +139,14 @@ class ModelRunner:
                                              dtype=torch.long)

         input_metadata = InputMetadata(
-            prompt_lens=prompt_lens,
+            is_prompt=True,
             slot_mapping=slot_mapping,
             max_context_len=None,
             context_lens=None,
             block_tables=None,
             use_cuda_graph=False,
         )
-        return input_tokens, input_positions, input_metadata
+        return input_tokens, input_positions, input_metadata, prompt_lens

     def _prepare_decode(
         self,

@@ -203,32 +207,24 @@ class ModelRunner:
                 block_tables.append([])
             batch_size = graph_batch_size

-        # When using CUDA graph, we don't need to make the tensors on the GPU
-        # because they will be eventually copied to the designated GPU buffer.
-        device = "cpu" if use_captured_graph else "cuda"
-        pin_memory = use_captured_graph and not self.in_wsl
         input_tokens = _make_tensor_with_pad(input_tokens,
                                              max_len=1,
                                              pad=0,
                                              dtype=torch.long,
-                                             device=device,
-                                             pin_memory=pin_memory)
+                                             device="cuda")
         input_positions = _make_tensor_with_pad(input_positions,
                                                 max_len=1,
                                                 pad=0,
                                                 dtype=torch.long,
-                                                device=device,
-                                                pin_memory=pin_memory)
+                                                device="cuda")
         slot_mapping = _make_tensor_with_pad(slot_mapping,
                                              max_len=1,
                                              pad=_PAD_SLOT_ID,
                                              dtype=torch.long,
-                                             device=device,
-                                             pin_memory=pin_memory)
+                                             device="cuda")
         context_lens = torch.tensor(context_lens,
                                     dtype=torch.int,
-                                    device=device,
-                                    pin_memory=pin_memory)
+                                    device="cuda")

         if use_captured_graph:
             # The shape of graph_block_tables is

@@ -237,17 +233,18 @@ class ModelRunner:
             for i, block_table in enumerate(block_tables):
                 if block_table:
                     input_block_tables[i, :len(block_table)] = block_table
-            block_tables = torch.tensor(input_block_tables, device=device)
+            block_tables = torch.tensor(input_block_tables, device="cuda")
         else:
             block_tables = _make_tensor_with_pad(
                 block_tables,
                 max_len=max_context_len,
                 pad=0,
                 dtype=torch.int,
+                device="cuda",
             )

         input_metadata = InputMetadata(
-            prompt_lens=[],
+            is_prompt=False,
             slot_mapping=slot_mapping,
             max_context_len=max_context_len,
             context_lens=context_lens,

@@ -326,23 +323,127 @@ class ModelRunner:
         )
         return sampling_metadata

-    @torch.inference_mode()
-    def execute_model(
-        self,
-        seq_group_metadata_list: List[SequenceGroupMetadata],
-        kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
-    ) -> SamplerOutput:
-        # NOTE: We assume that all sequences in the group are all prompts or
-        # all decodes.
-        is_prompt = seq_group_metadata_list[0].is_prompt
-        # Prepare input tensors.
-        if is_prompt:
-            inputs = self._prepare_prompt(seq_group_metadata_list)
-            input_tokens, input_positions, input_metadata = inputs
-        else:
-            inputs = self._prepare_decode(seq_group_metadata_list)
-            input_tokens, input_positions, input_metadata = inputs
+    def prepare_input_tensors(
+        self,
+        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
+    ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, SamplingMetadata]:
+        if self.is_driver_worker:
+            # NOTE: We assume that all sequences in the group are all prompts or
+            # all decodes.
+            is_prompt = seq_group_metadata_list[0].is_prompt
+            # Prepare input tensors.
+            if is_prompt:
+                (input_tokens, input_positions, input_metadata,
+                 prompt_lens) = self._prepare_prompt(seq_group_metadata_list)
+            else:
+                (input_tokens, input_positions,
+                 input_metadata) = self._prepare_decode(seq_group_metadata_list)
+                prompt_lens = []
+            sampling_metadata = self._prepare_sample(seq_group_metadata_list,
+                                                     prompt_lens)
+
+            def get_size_or_none(x: Optional[torch.Tensor]):
+                return x.size() if x is not None else None
+
+            # Broadcast the input data. For input tensors, we first broadcast
+            # its shape and then broadcast the tensor to avoid high
+            # serialization cost.
+            py_data = {
+                "input_tokens_size": input_tokens.size(),
+                "input_positions_size": input_positions.size(),
+                "is_prompt": input_metadata.is_prompt,
+                "slot_mapping_size": get_size_or_none(input_metadata.slot_mapping),
+                "max_context_len": input_metadata.max_context_len,
+                "context_lens_size": get_size_or_none(input_metadata.context_lens),
+                "block_tables_size": get_size_or_none(input_metadata.block_tables),
+                "use_cuda_graph": input_metadata.use_cuda_graph,
+                "selected_token_indices_size":
+                    sampling_metadata.selected_token_indices.size(),
+            }
+            broadcast_object_list([py_data], src=0)
+            # TODO(zhuohan): Combine the broadcasts or set async_op=True.
+            broadcast(input_tokens, src=0)
+            broadcast(input_positions, src=0)
+            if input_metadata.slot_mapping is not None:
+                broadcast(input_metadata.slot_mapping, src=0)
+            if input_metadata.context_lens is not None:
+                broadcast(input_metadata.context_lens, src=0)
+            if input_metadata.block_tables is not None:
+                broadcast(input_metadata.block_tables, src=0)
+            broadcast(sampling_metadata.selected_token_indices, src=0)
+        else:
+            receving_list = [None]
+            broadcast_object_list(receving_list, src=0)
+            py_data = receving_list[0]
+            input_tokens = torch.empty(*py_data["input_tokens_size"],
+                                       dtype=torch.long,
+                                       device="cuda")
+            broadcast(input_tokens, src=0)
+            input_positions = torch.empty(*py_data["input_positions_size"],
+                                          dtype=torch.long,
+                                          device="cuda")
+            broadcast(input_positions, src=0)
+            if py_data["slot_mapping_size"] is not None:
+                slot_mapping = torch.empty(*py_data["slot_mapping_size"],
+                                           dtype=torch.long,
+                                           device="cuda")
+                broadcast(slot_mapping, src=0)
+            else:
+                slot_mapping = None
+            if py_data["context_lens_size"] is not None:
+                context_lens = torch.empty(*py_data["context_lens_size"],
+                                           dtype=torch.int,
+                                           device="cuda")
+                broadcast(context_lens, src=0)
+            else:
+                context_lens = None
+            if py_data["block_tables_size"] is not None:
+                block_tables = torch.empty(*py_data["block_tables_size"],
+                                           dtype=torch.int,
+                                           device="cuda")
+                broadcast(block_tables, src=0)
+            else:
+                block_tables = None
+            selected_token_indices = torch.empty(
+                *py_data["selected_token_indices_size"],
+                dtype=torch.long,
+                device="cuda")
+            broadcast(selected_token_indices, src=0)
+            input_metadata = InputMetadata(
+                is_prompt=py_data["is_prompt"],
+                slot_mapping=slot_mapping,
+                max_context_len=py_data["max_context_len"],
+                context_lens=context_lens,
+                block_tables=block_tables,
+                use_cuda_graph=py_data["use_cuda_graph"],
+            )
+            sampling_metadata = SamplingMetadata(
+                seq_groups=None,
+                seq_data=None,
+                prompt_lens=None,
+                selected_token_indices=selected_token_indices,
+                categorized_sample_indices=None,
+                perform_sampling=False,
+            )
+
+        return input_tokens, input_positions, input_metadata, sampling_metadata
+
+    @torch.inference_mode()
+    def execute_model(
+        self,
+        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
+        kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
+    ) -> Optional[SamplerOutput]:
+        input_tokens, input_positions, input_metadata, sampling_metadata = (
+            self.prepare_input_tensors(seq_group_metadata_list))

         # Execute the model.
         if input_metadata.use_cuda_graph:
             graph_batch_size = input_tokens.shape[0]

@@ -356,9 +457,6 @@ class ModelRunner:
             input_metadata=input_metadata,
         )

-        sampling_metadata = self._prepare_sample(seq_group_metadata_list,
-                                                 input_metadata.prompt_lens)
-
         # Sample the next token.
         output = self.model.sample(
             hidden_states=hidden_states,

@@ -424,7 +522,7 @@ class ModelRunner:
         for batch_size in reversed(_BATCH_SIZES_TO_CAPTURE):
             # Create dummy input_metadata.
             input_metadata = InputMetadata(
-                prompt_lens=[],
+                is_prompt=False,
                 slot_mapping=slot_mapping[:batch_size],
                 max_context_len=self.max_context_len_to_capture,
                 context_lens=context_lens[:batch_size],
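A detail worth calling out in prepare_input_tensors() is the "size-or-None" handshake for tensors that may be absent on a given step (slot_mapping, context_lens, block_tables): the driver publishes either a size or None, and receivers allocate a buffer only when there is something to receive. The following hypothetical helpers distill that pattern; they assume an initialized process group and that the caller passes the same dtype the driver sends:

from typing import Optional
import torch
from vllm.model_executor.parallel_utils.communication_op import broadcast

def describe_optional(t: Optional[torch.Tensor]):
    # Driver side: publish the size, or None if the tensor is absent
    # on this step.
    return t.size() if t is not None else None

def receive_optional(size, dtype) -> Optional[torch.Tensor]:
    # Worker side: allocate only if the driver actually has the tensor,
    # then receive the broadcast into the fresh buffer.
    if size is None:
        return None
    buf = torch.empty(*size, dtype=dtype, device="cuda")
    broadcast(buf, src=0)
    return buf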
vllm/worker/worker.py

@@ -8,6 +8,8 @@ import torch.distributed
 from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
                          SchedulerConfig)
 from vllm.model_executor import set_random_seed
+from vllm.model_executor.parallel_utils.communication_op import (
+    broadcast_object_list)
 from vllm.model_executor.parallel_utils.parallel_state import (
     initialize_model_parallel)
 from vllm.sequence import SamplerOutput, SequenceGroupMetadata

@@ -28,17 +30,23 @@ class Worker:
         model_config: ModelConfig,
         parallel_config: ParallelConfig,
         scheduler_config: SchedulerConfig,
-        rank: Optional[int] = None,
-        distributed_init_method: Optional[str] = None,
+        local_rank: int,
+        rank: int,
+        distributed_init_method: str,
+        is_driver_worker: bool = False,
     ) -> None:
         self.model_config = model_config
         self.parallel_config = parallel_config
         self.scheduler_config = scheduler_config
+        self.local_rank = local_rank
         self.rank = rank
         self.distributed_init_method = distributed_init_method
+        self.is_driver_worker = is_driver_worker
+        if self.is_driver_worker:
+            assert self.rank == 0, "The driver worker must have rank 0."

         self.model_runner = ModelRunner(model_config, parallel_config,
-                                        scheduler_config)
+                                        scheduler_config, is_driver_worker)
         # Uninitialized cache engine. Will be initialized by
         # self.init_cache_engine().
         self.cache_config = None

@@ -57,13 +65,7 @@ class Worker:
         # This env var set by Ray causes exceptions with graph building.
         os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
-        # Env vars will be set by Ray.
-        self.rank = self.rank if self.rank is not None else int(
-            os.getenv("RANK", "-1"))
-        local_rank = int(os.getenv("LOCAL_RANK", "0"))
-        self.device = torch.device(f"cuda:{local_rank}")
-        if self.rank < 0:
-            raise ValueError("Invalid or unspecified rank.")
+        self.device = torch.device(f"cuda:{self.local_rank}")
         torch.cuda.set_device(self.device)

         _check_if_gpu_supports_dtype(self.model_config.dtype)

@@ -125,14 +127,12 @@ class Worker:
         # the model initialization and profiling.
         set_random_seed(self.model_config.seed)

-    @torch.inference_mode()
-    def execute_model(
+    def cache_swap(
         self,
-        seq_group_metadata_list: List[SequenceGroupMetadata],
         blocks_to_swap_in: Dict[int, int],
         blocks_to_swap_out: Dict[int, int],
         blocks_to_copy: Dict[int, List[int]],
-    ) -> SamplerOutput:
+    ) -> None:
         # Issue cache operations.
         issued_cache_op = False
         if blocks_to_swap_in:

@@ -152,8 +152,38 @@ class Worker:
         if cache_events is not None:
             for event in cache_events:
                 event.wait()
+
+    @torch.inference_mode()
+    def execute_model(
+        self,
+        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None,
+        blocks_to_swap_in: Optional[Dict[int, int]] = None,
+        blocks_to_swap_out: Optional[Dict[int, int]] = None,
+        blocks_to_copy: Optional[Dict[int, List[int]]] = None,
+    ) -> Optional[SamplerOutput]:
+        if self.is_driver_worker:
+            assert seq_group_metadata_list is not None
+            num_seq_groups = len(seq_group_metadata_list)
+            assert blocks_to_swap_in is not None
+            assert blocks_to_swap_out is not None
+            assert blocks_to_copy is not None
+            block_swapping_info = [
+                blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy
+            ]
+            broadcast_object_list([num_seq_groups] + block_swapping_info,
+                                  src=0)
+        else:
+            # num_seq_groups, blocks_to_swap_in, blocks_to_swap_out,
+            # blocks_to_copy (4 elements)
+            recv_data = [None] * 4
+            broadcast_object_list(recv_data, src=0)
+            num_seq_groups = recv_data[0]
+            block_swapping_info = recv_data[1:]
+
+        self.cache_swap(*block_swapping_info)

         # If there is no input, we don't need to execute the model.
-        if not seq_group_metadata_list:
+        if num_seq_groups == 0:
             return {}

         output = self.model_runner.execute_model(seq_group_metadata_list,
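The new Worker.execute_model signature, with every argument optional and an Optional[SamplerOutput] return, implies a calling convention along these lines: the driver passes the real scheduler output, while every other rank calls with no arguments and blocks inside broadcast_object_list until the driver publishes the swap info and the number of sequence groups. A hypothetical driver-side and worker-side call, assuming `worker` is an already-initialized Worker instance:

# Driver rank (rank 0)
output = worker.execute_model(
    seq_group_metadata_list=seq_group_metadata_list,
    blocks_to_swap_in={},
    blocks_to_swap_out={},
    blocks_to_copy={},
)

# Any non-driver rank: everything arrives via broadcast.
output = worker.execute_model()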