norm / vllm · Commits

Commit fcffb7c8 · authored Jan 16, 2024 by zhuwenwen

    Merge branch 'vllm-v0.2.7-dtk23.10'

Parents: eb181638, 4095d0db

Changes: 56 files in the full commit. Showing 16 changed files on this page, with 277 additions and 71 deletions (+277, -71).
Files changed on this page:

    vllm/model_executor/models/gpt_neox.py                     +4   -1
    vllm/model_executor/models/internlm.py                     +1   -1
    vllm/model_executor/models/llama.py                        +1   -1
    vllm/model_executor/models/mistral.py                      +1   -1
    vllm/model_executor/models/mixtral.py                      +2   -2
    vllm/model_executor/models/mpt.py                          +1   -1
    vllm/model_executor/models/opt.py                          +1   -1
    vllm/model_executor/models/phi_1_5.py                      +1   -1
    vllm/model_executor/models/qwen.py                         +1   -1
    vllm/model_executor/models/yi.py                           +1   -1
    vllm/model_executor/parallel_utils/communication_op.py     +59  -0
    vllm/model_executor/sampling_metadata.py                   +14  -8
    vllm/sampling_params.py                                    +1   -1
    vllm/utils.py                                              +11  -1
    vllm/worker/model_runner.py                                +133 -35
    vllm/worker/worker.py                                      +45  -15
vllm/model_executor/models/gpt_neox.py

@@ -54,6 +54,7 @@ class GPTNeoXAttention(nn.Module):
         self.total_num_heads = config.num_attention_heads
         self.hidden_size = config.hidden_size
         self.head_size = self.hidden_size // self.total_num_heads
+        self.bias = getattr(config, "attention_bias", True)

         tensor_model_parallel_world_size = (
             get_tensor_model_parallel_world_size())
@@ -65,11 +66,13 @@ class GPTNeoXAttention(nn.Module):
             config.hidden_size,
             self.head_size,
             self.total_num_heads,
+            bias=self.bias,
             linear_method=linear_method,
         )
         self.dense = RowParallelLinear(
             config.hidden_size,
             config.hidden_size,
+            bias=self.bias,
             linear_method=linear_method,
         )
         scaling = self.head_size**-0.5
@@ -252,7 +255,7 @@ class GPTNeoXForCausalLM(nn.Module):
         self,
         hidden_states: torch.Tensor,
         sampling_metadata: SamplingMetadata,
-    ) -> SamplerOutput:
+    ) -> Optional[SamplerOutput]:
         next_tokens = self.sampler(self.embed_out.weight, hidden_states,
                                    sampling_metadata)
         return next_tokens
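The `SamplerOutput` → `Optional[SamplerOutput]` change to `sample()` recurs in every model file below: with the driver/worker split introduced in `vllm/worker/worker.py` and `vllm/worker/model_runner.py`, only the driver worker samples, so other ranks return `None`. A minimal sketch of the calling convention this implies, using hypothetical stand-ins rather than the real vllm types:

```python
from typing import List, Optional


class SamplerOutput:
    """Hypothetical stand-in for vllm's SamplerOutput."""

    def __init__(self, token_ids: List[int]) -> None:
        self.token_ids = token_ids


def sample(is_driver_worker: bool) -> Optional[SamplerOutput]:
    # Mirrors the new contract: non-driver ranks skip sampling entirely.
    if not is_driver_worker:
        return None
    return SamplerOutput(token_ids=[42])


output = sample(is_driver_worker=False)
if output is None:
    # This rank only contributed to the distributed forward pass.
    pass
else:
    print(output.token_ids)
```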
vllm/model_executor/models/internlm.py

@@ -255,7 +255,7 @@ class InternLMForCausalLM(nn.Module):
         self,
         hidden_states: torch.Tensor,
         sampling_metadata: SamplingMetadata,
-    ) -> SamplerOutput:
+    ) -> Optional[SamplerOutput]:
         next_tokens = self.sampler(self.lm_head.weight, hidden_states,
                                    sampling_metadata)
         return next_tokens
vllm/model_executor/models/llama.py

@@ -291,7 +291,7 @@ class LlamaForCausalLM(nn.Module):
         self,
         hidden_states: torch.Tensor,
         sampling_metadata: SamplingMetadata,
-    ) -> SamplerOutput:
+    ) -> Optional[SamplerOutput]:
         next_tokens = self.sampler(self.lm_head.weight, hidden_states,
                                    sampling_metadata)
         return next_tokens
vllm/model_executor/models/mistral.py

@@ -287,7 +287,7 @@ class MistralForCausalLM(nn.Module):
         self,
         hidden_states: torch.Tensor,
         sampling_metadata: SamplingMetadata,
-    ) -> SamplerOutput:
+    ) -> Optional[SamplerOutput]:
         next_tokens = self.sampler(self.lm_head.weight, hidden_states,
                                    sampling_metadata)
         return next_tokens
vllm/model_executor/models/mixtral.py

@@ -320,7 +320,7 @@ class MixtralModel(nn.Module):
         positions: torch.Tensor,
         kv_caches: List[KVCache],
         input_metadata: InputMetadata,
-    ) -> SamplerOutput:
+    ) -> torch.Tensor:
         hidden_states = self.embed_tokens(input_ids)
         residual = None
         for i in range(len(self.layers)):
@@ -361,7 +361,7 @@ class MixtralForCausalLM(nn.Module):
         self,
         hidden_states: Optional[torch.Tensor],
         sampling_metadata: SamplingMetadata,
-    ) -> SamplerOutput:
+    ) -> Optional[SamplerOutput]:
         next_tokens = self.sampler(self.lm_head.weight, hidden_states,
                                    sampling_metadata)
         return next_tokens
vllm/model_executor/models/mpt.py

@@ -276,7 +276,7 @@ class MPTForCausalLM(nn.Module):
         self,
         hidden_states: torch.Tensor,
         sampling_metadata: SamplingMetadata,
-    ) -> SamplerOutput:
+    ) -> Optional[SamplerOutput]:
         next_tokens = self.sampler(self.lm_head_weight, hidden_states,
                                    sampling_metadata)
         return next_tokens
vllm/model_executor/models/opt.py

@@ -309,7 +309,7 @@ class OPTForCausalLM(nn.Module):
         self,
         hidden_states: torch.Tensor,
         sampling_metadata: SamplingMetadata,
-    ) -> SamplerOutput:
+    ) -> Optional[SamplerOutput]:
         next_tokens = self.sampler(self.lm_head_weight, hidden_states,
                                    sampling_metadata)
         return next_tokens
vllm/model_executor/models/phi_1_5.py

@@ -280,7 +280,7 @@ class PhiForCausalLM(nn.Module):
         self,
         hidden_states: torch.Tensor,
         sampling_metadata: SamplingMetadata,
-    ) -> SamplerOutput:
+    ) -> Optional[SamplerOutput]:
         head = self.lm_head.linear
         next_tokens = self.sampler(head.weight, hidden_states,
                                    sampling_metadata, head.bias)
vllm/model_executor/models/qwen.py

@@ -247,7 +247,7 @@ class QWenLMHeadModel(nn.Module):
         self,
         hidden_states: torch.Tensor,
         sampling_metadata: SamplingMetadata,
-    ) -> SamplerOutput:
+    ) -> Optional[SamplerOutput]:
         next_tokens = self.sampler(self.lm_head.weight, hidden_states,
                                    sampling_metadata)
         return next_tokens
vllm/model_executor/models/yi.py

@@ -286,7 +286,7 @@ class YiForCausalLM(nn.Module):
         self,
         hidden_states: torch.Tensor,
         sampling_metadata: SamplingMetadata,
-    ) -> SamplerOutput:
+    ) -> Optional[SamplerOutput]:
         next_tokens = self.sampler(self.lm_head.weight, hidden_states,
                                    sampling_metadata)
         return next_tokens
vllm/model_executor/parallel_utils/communication_op.py

 import torch

 from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
     get_tensor_model_parallel_group,
 )
@@ -45,3 +46,61 @@ def tensor_model_parallel_all_gather(input_, dim=-1):
                                           (world_size * input_size[dim], ) +
                                           input_size[dim + 1:])
     return output_tensor
+
+
+def tensor_model_parallel_gather(input_, dst=0, dim=-1):
+    """Gather the input tensor across model parallel group.
+
+    NOTE: We assume that the input tensor is on the same device across
+    all the ranks.
+    """
+    world_size = get_tensor_model_parallel_world_size()
+    # Bypass the function if we are using only 1 GPU.
+    if world_size == 1:
+        return input_
+    assert -input_.dim() <= dim < input_.dim(), (
+        f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
+    if dim < 0:
+        # Convert negative dim to positive.
+        dim += input_.dim()
+    # Allocate output tensor.
+    if get_tensor_model_parallel_rank() == dst:
+        gather_list = [torch.empty_like(input_) for _ in range(world_size)]
+    else:
+        gather_list = None
+    # Gather.
+    torch.distributed.gather(input_,
+                             gather_list,
+                             dst=dst,
+                             group=get_tensor_model_parallel_group())
+    if get_tensor_model_parallel_rank() == dst:
+        output_tensor = torch.cat(gather_list, dim=dim)
+    else:
+        output_tensor = None
+    return output_tensor
+
+
+def broadcast(input_, src=0):
+    """Broadcast the input tensor."""
+    world_size = torch.distributed.get_world_size()
+    assert 0 <= src < world_size, f"Invalid src rank ({src})"
+
+    # Bypass the function if we are using only 1 GPU.
+    if world_size == 1:
+        return input_
+    # Broadcast.
+    torch.distributed.broadcast(input_, src=src)
+    return input_
+
+
+def broadcast_object_list(obj_list, src=0):
+    """Broadcast the input object list."""
+    world_size = torch.distributed.get_world_size()
+    assert 0 <= src < world_size, f"Invalid src rank ({src})"
+
+    # Bypass the function if we are using only 1 GPU.
+    if world_size == 1:
+        return obj_list
+    # Broadcast.
+    torch.distributed.broadcast_object_list(obj_list, src=src)
+    return obj_list
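The three new helpers are thin wrappers over `torch.distributed`. A self-contained sketch of the gather pattern used by `tensor_model_parallel_gather()`, run on two CPU processes with the `gloo` backend; the port and tensor contents are illustrative, and the sketch uses the default process group rather than vllm's tensor-parallel group:

```python
import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def run(rank: int, world_size: int) -> None:
    # Illustrative rendezvous; vllm sets up its own process groups.
    dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29500",
                            rank=rank, world_size=world_size)

    # Each rank holds one shard; only dst=0 provides a gather_list,
    # exactly as tensor_model_parallel_gather() does.
    shard = torch.full((2,), float(rank))
    if dist.get_rank() == 0:
        gather_list = [torch.empty_like(shard) for _ in range(world_size)]
    else:
        gather_list = None
    dist.gather(shard, gather_list, dst=0)

    if dist.get_rank() == 0:
        print(torch.cat(gather_list))  # tensor([0., 0., 1., 1.])
    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(run, args=(2,), nprocs=2)
```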
vllm/model_executor/sampling_metadata.py

 from dataclasses import dataclass
-from typing import Dict, List, Tuple
+from typing import Dict, List, Optional, Tuple

 import torch
@@ -18,24 +18,29 @@ class SamplingMetadata:
         seq_data: Seq_id -> SequenceData.
         prompt_lens: Lengths of prompts.
         selected_token_indices: Token indices selected for sampling.
-        categorized_sample_indices: SamplingType -> token indicies to sample.
+        categorized_sample_indices: SamplingType -> token indices to sample.
+        perform_sampling: Whether to perform sampling. This option is used to
+            make the sampling only happens in the driver worker, and disable
+            sampling in other worker processes.
     """

     def __init__(
         self,
-        seq_groups: List[Tuple[List[int], SamplingParams]],
-        seq_data: Dict[int, SequenceData],
-        prompt_lens: List[int],
+        seq_groups: Optional[List[Tuple[List[int], SamplingParams]]],
+        seq_data: Optional[Dict[int, SequenceData]],
+        prompt_lens: Optional[List[int]],
         selected_token_indices: torch.Tensor,
-        categorized_sample_indices: Dict[SamplingType, torch.Tensor],
+        categorized_sample_indices: Optional[Dict[SamplingType, torch.Tensor]],
+        perform_sampling: bool = True,
     ) -> None:
         self.seq_groups = seq_groups
         self.seq_data = seq_data
         self.prompt_lens = prompt_lens
         self.selected_token_indices = selected_token_indices
         self.categorized_sample_indices = categorized_sample_indices
+        self.perform_sampling = perform_sampling

-        self.num_prompts = len(prompt_lens)
+        self.num_prompts = len(prompt_lens) if prompt_lens is not None else 0

     def __repr__(self) -> str:
         return (
@@ -44,7 +49,8 @@ class SamplingMetadata:
             f"seq_data={self.seq_data}, "
             f"prompt_lens={self.prompt_lens}, "
             f"selected_token_indices={self.selected_token_indices}, "
-            f"categorized_sample_indices={self.categorized_sample_indices})")
+            f"categorized_sample_indices={self.categorized_sample_indices}), "
+            f"perform_sampling={self.perform_sampling})")


 @dataclass
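For reference, a short sketch of how a non-driver worker can now build a placeholder `SamplingMetadata`; this mirrors the construction added to `prepare_input_tensors()` in `vllm/worker/model_runner.py`. The empty index tensor is illustrative, and the sketch assumes vllm v0.2.7 is importable:

```python
import torch

from vllm.model_executor.sampling_metadata import SamplingMetadata

# All Optional fields are None and sampling is disabled on non-driver ranks.
worker_metadata = SamplingMetadata(
    seq_groups=None,
    seq_data=None,
    prompt_lens=None,
    selected_token_indices=torch.empty(0, dtype=torch.long),
    categorized_sample_indices=None,
    perform_sampling=False,
)
print(worker_metadata)                   # __repr__ now also reports perform_sampling
assert worker_metadata.num_prompts == 0  # prompt_lens is None -> 0
```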
vllm/sampling_params.py

@@ -100,7 +100,7 @@ class SamplingParams:
         temperature: float = 1.0,
         top_p: float = 1.0,
         top_k: int = -1,
-        min_p: int = 0.0,
+        min_p: float = 0.0,
         use_beam_search: bool = False,
         length_penalty: float = 1.0,
         early_stopping: Union[bool, str] = False,
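The `min_p` change is purely a type-annotation fix (`int` → `float`); runtime behaviour is unchanged. For context, a small usage example, assuming vllm v0.2.7 is installed:

```python
from vllm import SamplingParams

# min_p keeps only tokens whose probability is at least min_p times that of
# the most likely token; 0.0 disables the filter.
params = SamplingParams(temperature=0.8, top_p=0.95, min_p=0.05)
print(params)
```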
vllm/utils.py

 import enum
 import os
 import socket
 import uuid
 from platform import uname
 from typing import List

 import psutil
 import torch
@@ -55,7 +57,15 @@ def in_wsl() -> bool:
     return "microsoft" in " ".join(uname()).lower()


-def get_open_port():
+def get_ip() -> str:
+    return socket.gethostbyname(socket.gethostname())
+
+
+def get_open_port() -> int:
     with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
         s.bind(("", 0))
         return s.getsockname()[1]
+
+
+def set_cuda_visible_devices(device_ids: List[int]) -> None:
+    os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, device_ids))
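The new helpers are thin wrappers over the standard library. A hedged sketch of how they might be combined to build a `distributed_init_method` and pin a process to specific GPUs; the device ids are illustrative, and the helper bodies are repeated so the sketch runs without vllm installed:

```python
import os
import socket
from typing import List


# Same logic as the helpers added in this diff.
def get_ip() -> str:
    return socket.gethostbyname(socket.gethostname())


def get_open_port() -> int:
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("", 0))
        return s.getsockname()[1]


def set_cuda_visible_devices(device_ids: List[int]) -> None:
    os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, device_ids))


# e.g. restrict this process to GPUs 0 and 1 and advertise a rendezvous address.
set_cuda_visible_devices([0, 1])
distributed_init_method = f"tcp://{get_ip()}:{get_open_port()}"
print(distributed_init_method)
```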
vllm/worker/model_runner.py

 import time
-from typing import Dict, List, Tuple, Union
+from typing import Dict, List, Optional, Tuple, Union

 import numpy as np
 import torch
@@ -8,6 +8,8 @@ import torch.nn as nn
 from vllm.config import ModelConfig, ParallelConfig, SchedulerConfig
 from vllm.logger import init_logger
 from vllm.model_executor import get_model, InputMetadata, SamplingMetadata
+from vllm.model_executor.parallel_utils.communication_op import (
+    broadcast, broadcast_object_list)
 from vllm.sampling_params import SamplingParams, SamplingType
 from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
 from vllm.utils import in_wsl
@@ -28,10 +30,12 @@ class ModelRunner:
         model_config: ModelConfig,
         parallel_config: ParallelConfig,
         scheduler_config: SchedulerConfig,
+        is_driver_worker: bool = False,
     ):
         self.model_config = model_config
         self.parallel_config = parallel_config
         self.scheduler_config = scheduler_config
+        self.is_driver_worker = is_driver_worker

         # model_config can be None in tests/samplers/test_sampler.py.
         # FIXME(woosuk): This is a hack to make the tests work. Refactor this.
@@ -70,7 +74,7 @@ class ModelRunner:
     def _prepare_prompt(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
-    ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata]:
+    ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, List[int]]:
         assert len(seq_group_metadata_list) > 0
         input_tokens: List[List[int]] = []
         input_positions: List[List[int]] = []
@@ -135,14 +139,14 @@ class ModelRunner:
                                              dtype=torch.long)

         input_metadata = InputMetadata(
-            prompt_lens=prompt_lens,
+            is_prompt=True,
             slot_mapping=slot_mapping,
             max_context_len=None,
             context_lens=None,
             block_tables=None,
             use_cuda_graph=False,
         )
-        return input_tokens, input_positions, input_metadata
+        return input_tokens, input_positions, input_metadata, prompt_lens

     def _prepare_decode(
         self,
@@ -203,32 +207,24 @@ class ModelRunner:
                 block_tables.append([])
             batch_size = graph_batch_size

-        # When using CUDA graph, we don't need to make the tensors on the GPU
-        # because they will be eventually copied to the designated GPU buffer.
-        device = "cpu" if use_captured_graph else "cuda"
-        pin_memory = use_captured_graph and not self.in_wsl
         input_tokens = _make_tensor_with_pad(input_tokens,
                                              max_len=1,
                                              pad=0,
                                              dtype=torch.long,
-                                             device=device,
-                                             pin_memory=pin_memory)
+                                             device="cuda")
         input_positions = _make_tensor_with_pad(input_positions,
                                                 max_len=1,
                                                 pad=0,
                                                 dtype=torch.long,
-                                                device=device,
-                                                pin_memory=pin_memory)
+                                                device="cuda")
         slot_mapping = _make_tensor_with_pad(slot_mapping,
                                              max_len=1,
                                              pad=_PAD_SLOT_ID,
                                              dtype=torch.long,
-                                             device=device,
-                                             pin_memory=pin_memory)
+                                             device="cuda")
         context_lens = torch.tensor(context_lens,
                                     dtype=torch.int,
-                                    device=device,
-                                    pin_memory=pin_memory)
+                                    device="cuda")

         if use_captured_graph:
             # The shape of graph_block_tables is
@@ -237,17 +233,18 @@ class ModelRunner:
             for i, block_table in enumerate(block_tables):
                 if block_table:
                     input_block_tables[i, :len(block_table)] = block_table
-            block_tables = torch.tensor(input_block_tables, device=device)
+            block_tables = torch.tensor(input_block_tables, device="cuda")
         else:
             block_tables = _make_tensor_with_pad(
                 block_tables,
                 max_len=max_context_len,
                 pad=0,
                 dtype=torch.int,
+                device="cuda",
             )

         input_metadata = InputMetadata(
-            prompt_lens=[],
+            is_prompt=False,
             slot_mapping=slot_mapping,
             max_context_len=max_context_len,
             context_lens=context_lens,
@@ -326,23 +323,127 @@ class ModelRunner:
         )
         return sampling_metadata

+    def prepare_input_tensors(
+        self,
+        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
+    ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, SamplingMetadata]:
+        if self.is_driver_worker:
+            # NOTE: We assume that all sequences in the group are all prompts or
+            # all decodes.
+            is_prompt = seq_group_metadata_list[0].is_prompt
+            # Prepare input tensors.
+            if is_prompt:
+                (input_tokens, input_positions, input_metadata,
+                 prompt_lens) = self._prepare_prompt(seq_group_metadata_list)
+            else:
+                (input_tokens, input_positions,
+                 input_metadata) = self._prepare_decode(seq_group_metadata_list)
+                prompt_lens = []
+            sampling_metadata = self._prepare_sample(seq_group_metadata_list,
+                                                     prompt_lens)
+
+            def get_size_or_none(x: Optional[torch.Tensor]):
+                return x.size() if x is not None else None
+
+            # Broadcast the input data. For input tensors, we first broadcast
+            # its shape and then broadcast the tensor to avoid high
+            # serialization cost.
+            py_data = {
+                "input_tokens_size": input_tokens.size(),
+                "input_positions_size": input_positions.size(),
+                "is_prompt": input_metadata.is_prompt,
+                "slot_mapping_size": get_size_or_none(
+                    input_metadata.slot_mapping),
+                "max_context_len": input_metadata.max_context_len,
+                "context_lens_size": get_size_or_none(
+                    input_metadata.context_lens),
+                "block_tables_size": get_size_or_none(
+                    input_metadata.block_tables),
+                "use_cuda_graph": input_metadata.use_cuda_graph,
+                "selected_token_indices_size":
+                sampling_metadata.selected_token_indices.size(),
+            }
+            broadcast_object_list([py_data], src=0)
+            # TODO(zhuohan): Combine the broadcasts or set async_op=True.
+            broadcast(input_tokens, src=0)
+            broadcast(input_positions, src=0)
+            if input_metadata.slot_mapping is not None:
+                broadcast(input_metadata.slot_mapping, src=0)
+            if input_metadata.context_lens is not None:
+                broadcast(input_metadata.context_lens, src=0)
+            if input_metadata.block_tables is not None:
+                broadcast(input_metadata.block_tables, src=0)
+            broadcast(sampling_metadata.selected_token_indices, src=0)
+        else:
+            receving_list = [None]
+            broadcast_object_list(receving_list, src=0)
+            py_data = receving_list[0]
+            input_tokens = torch.empty(*py_data["input_tokens_size"],
+                                       dtype=torch.long,
+                                       device="cuda")
+            broadcast(input_tokens, src=0)
+            input_positions = torch.empty(*py_data["input_positions_size"],
+                                          dtype=torch.long,
+                                          device="cuda")
+            broadcast(input_positions, src=0)
+            if py_data["slot_mapping_size"] is not None:
+                slot_mapping = torch.empty(*py_data["slot_mapping_size"],
+                                           dtype=torch.long,
+                                           device="cuda")
+                broadcast(slot_mapping, src=0)
+            else:
+                slot_mapping = None
+            if py_data["context_lens_size"] is not None:
+                context_lens = torch.empty(*py_data["context_lens_size"],
+                                           dtype=torch.int,
+                                           device="cuda")
+                broadcast(context_lens, src=0)
+            else:
+                context_lens = None
+            if py_data["block_tables_size"] is not None:
+                block_tables = torch.empty(*py_data["block_tables_size"],
+                                           dtype=torch.int,
+                                           device="cuda")
+                broadcast(block_tables, src=0)
+            else:
+                block_tables = None
+            selected_token_indices = torch.empty(
+                *py_data["selected_token_indices_size"],
+                dtype=torch.long,
+                device="cuda")
+            broadcast(selected_token_indices, src=0)
+            input_metadata = InputMetadata(
+                is_prompt=py_data["is_prompt"],
+                slot_mapping=slot_mapping,
+                max_context_len=py_data["max_context_len"],
+                context_lens=context_lens,
+                block_tables=block_tables,
+                use_cuda_graph=py_data["use_cuda_graph"],
+            )
+            sampling_metadata = SamplingMetadata(
+                seq_groups=None,
+                seq_data=None,
+                prompt_lens=None,
+                selected_token_indices=selected_token_indices,
+                categorized_sample_indices=None,
+                perform_sampling=False,
+            )
+
+        return input_tokens, input_positions, input_metadata, sampling_metadata
+
     @torch.inference_mode()
     def execute_model(
         self,
-        seq_group_metadata_list: List[SequenceGroupMetadata],
+        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
         kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
-    ) -> SamplerOutput:
-        # NOTE: We assume that all sequences in the group are all prompts or
-        # all decodes.
-        is_prompt = seq_group_metadata_list[0].is_prompt
-        # Prepare input tensors.
-        if is_prompt:
-            inputs = self._prepare_prompt(seq_group_metadata_list)
-            input_tokens, input_positions, input_metadata = inputs
-        else:
-            inputs = self._prepare_decode(seq_group_metadata_list)
-            input_tokens, input_positions, input_metadata = inputs
+    ) -> Optional[SamplerOutput]:
+        input_tokens, input_positions, input_metadata, sampling_metadata = (
+            self.prepare_input_tensors(seq_group_metadata_list))

         # Execute the model.
         if input_metadata.use_cuda_graph:
             graph_batch_size = input_tokens.shape[0]
@@ -356,9 +457,6 @@ class ModelRunner:
             input_metadata=input_metadata,
         )

-        sampling_metadata = self._prepare_sample(seq_group_metadata_list,
-                                                 input_metadata.prompt_lens)
-
         # Sample the next token.
         output = self.model.sample(
             hidden_states=hidden_states,
@@ -424,7 +522,7 @@ class ModelRunner:
         for batch_size in reversed(_BATCH_SIZES_TO_CAPTURE):
             # Create dummy input_metadata.
             input_metadata = InputMetadata(
-                prompt_lens=[],
+                is_prompt=False,
                 slot_mapping=slot_mapping[:batch_size],
                 max_context_len=self.max_context_len_to_capture,
                 context_lens=context_lens[:batch_size],
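`prepare_input_tensors()` follows a two-step protocol: the driver broadcasts a small Python dict of shapes and flags first, then broadcasts each tensor into a buffer the other ranks pre-allocate from those shapes, with missing fields signalled by a `None` size. A self-contained sketch of that size-or-None pattern on two CPU processes with the `gloo` backend; names, port, and the example tensor are illustrative:

```python
import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def get_size_or_none(x):
    # Shapes travel as Python objects; tensors travel as raw broadcasts.
    return x.size() if x is not None else None


def run(rank: int, world_size: int) -> None:
    # Illustrative rendezvous; vllm reuses its already-initialized process group.
    dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29501",
                            rank=rank, world_size=world_size)
    if rank == 0:
        # Driver: it owns the real data (contents are illustrative).
        block_tables = torch.zeros(4, 8, dtype=torch.int)
        meta = [{"block_tables_size": get_size_or_none(block_tables)}]
    else:
        block_tables = None
        meta = [None]

    # Step 1: broadcast the cheap Python metadata (shape or None).
    dist.broadcast_object_list(meta, src=0)
    # Step 2: non-driver ranks allocate a buffer of that shape, then all ranks
    # take part in the tensor broadcast; a None size means the field is absent.
    size = meta[0]["block_tables_size"]
    if size is not None:
        if block_tables is None:
            block_tables = torch.empty(*size, dtype=torch.int)
        dist.broadcast(block_tables, src=0)

    assert block_tables is not None and block_tables.shape == (4, 8)
    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(run, args=(2,), nprocs=2)
```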
vllm/worker/worker.py

@@ -8,6 +8,8 @@ import torch.distributed
 from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
                          SchedulerConfig)
 from vllm.model_executor import set_random_seed
+from vllm.model_executor.parallel_utils.communication_op import (
+    broadcast_object_list)
 from vllm.model_executor.parallel_utils.parallel_state import (
     initialize_model_parallel)
 from vllm.sequence import SamplerOutput, SequenceGroupMetadata
@@ -28,17 +30,23 @@ class Worker:
         model_config: ModelConfig,
         parallel_config: ParallelConfig,
         scheduler_config: SchedulerConfig,
-        rank: Optional[int] = None,
-        distributed_init_method: Optional[str] = None,
+        local_rank: int,
+        rank: int,
+        distributed_init_method: str,
+        is_driver_worker: bool = False,
     ) -> None:
         self.model_config = model_config
         self.parallel_config = parallel_config
         self.scheduler_config = scheduler_config
+        self.local_rank = local_rank
         self.rank = rank
         self.distributed_init_method = distributed_init_method
+        self.is_driver_worker = is_driver_worker
+        if self.is_driver_worker:
+            assert self.rank == 0, "The driver worker must have rank 0."

         self.model_runner = ModelRunner(model_config, parallel_config,
-                                        scheduler_config)
+                                        scheduler_config, is_driver_worker)
         # Uninitialized cache engine. Will be initialized by
         # self.init_cache_engine().
         self.cache_config = None
@@ -57,13 +65,7 @@ class Worker:
         # This env var set by Ray causes exceptions with graph building.
         os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
-        # Env vars will be set by Ray.
-        self.rank = self.rank if self.rank is not None else int(
-            os.getenv("RANK", "-1"))
-        local_rank = int(os.getenv("LOCAL_RANK", "0"))
-        self.device = torch.device(f"cuda:{local_rank}")
-        if self.rank < 0:
-            raise ValueError("Invalid or unspecified rank.")
+        self.device = torch.device(f"cuda:{self.local_rank}")
         torch.cuda.set_device(self.device)

         _check_if_gpu_supports_dtype(self.model_config.dtype)
@@ -125,14 +127,12 @@ class Worker:
         # the model initialization and profiling.
         set_random_seed(self.model_config.seed)

-    @torch.inference_mode()
-    def execute_model(
+    def cache_swap(
         self,
-        seq_group_metadata_list: List[SequenceGroupMetadata],
         blocks_to_swap_in: Dict[int, int],
         blocks_to_swap_out: Dict[int, int],
         blocks_to_copy: Dict[int, List[int]],
-    ) -> SamplerOutput:
+    ) -> None:
         # Issue cache operations.
         issued_cache_op = False
         if blocks_to_swap_in:
@@ -152,8 +152,38 @@ class Worker:
         if cache_events is not None:
             for event in cache_events:
                 event.wait()

+    @torch.inference_mode()
+    def execute_model(
+        self,
+        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None,
+        blocks_to_swap_in: Optional[Dict[int, int]] = None,
+        blocks_to_swap_out: Optional[Dict[int, int]] = None,
+        blocks_to_copy: Optional[Dict[int, List[int]]] = None,
+    ) -> Optional[SamplerOutput]:
+        if self.is_driver_worker:
+            assert seq_group_metadata_list is not None
+            num_seq_groups = len(seq_group_metadata_list)
+            assert blocks_to_swap_in is not None
+            assert blocks_to_swap_out is not None
+            assert blocks_to_copy is not None
+            block_swapping_info = [
+                blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy
+            ]
+            broadcast_object_list([num_seq_groups] + block_swapping_info,
+                                  src=0)
+        else:
+            # num_seq_groups, blocks_to_swap_in, blocks_to_swap_out,
+            # blocks_to_copy (4 elements)
+            recv_data = [None] * 4
+            broadcast_object_list(recv_data, src=0)
+            num_seq_groups = recv_data[0]
+            block_swapping_info = recv_data[1:]
+
+        self.cache_swap(*block_swapping_info)
+
         # If there is no input, we don't need to execute the model.
-        if not seq_group_metadata_list:
+        if num_seq_groups == 0:
             return {}

         output = self.model_runner.execute_model(seq_group_metadata_list,
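`Worker.execute_model()` now has two call patterns: the driver receives the scheduler output as arguments and broadcasts it, while the other ranks call it with no arguments and get everything from the broadcast. A purely schematic sketch of that control flow, with plain Python stand-ins instead of vllm classes and a shared dict instead of the real broadcast:

```python
from typing import Optional


class TinyWorker:
    """Schematic stand-in for Worker; a shared dict replaces broadcast_object_list()."""

    def __init__(self, is_driver_worker: bool, mailbox: dict) -> None:
        self.is_driver_worker = is_driver_worker
        self.mailbox = mailbox

    def execute_model(self, seq_group_metadata_list=None) -> Optional[list]:
        if self.is_driver_worker:
            assert seq_group_metadata_list is not None
            # The driver publishes how much work there is (and, in vllm,
            # the block-swapping info) before anyone runs the model.
            self.mailbox["num_seq_groups"] = len(seq_group_metadata_list)
        num_seq_groups = self.mailbox["num_seq_groups"]
        if num_seq_groups == 0:
            return None  # nothing to do on any rank
        # ... all ranks would run the forward pass here ...
        # Only the driver samples and returns tokens; other ranks return None.
        return ["sampled-token"] if self.is_driver_worker else None


mailbox: dict = {}
driver = TinyWorker(is_driver_worker=True, mailbox=mailbox)
follower = TinyWorker(is_driver_worker=False, mailbox=mailbox)

print(driver.execute_model(seq_group_metadata_list=[object()]))  # ['sampled-token']
print(follower.execute_model())                                  # None
```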