norm / vllm · Commits · 27feead2
"tests/vscode:/vscode.git/clone" did not exist on "1b9bc16b1a57c6e7957c3fd74d89d1b206fc5bb8"
Unverified commit 27feead2, authored Nov 29, 2023 by Woosuk Kwon, committed via GitHub on Nov 29, 2023.

Refactor Worker & InputMetadata (#1843)

Parent: c7821956
Changes: 27 — this page shows 7 changed files with 447 additions and 280 deletions (+447, -280).
Files shown on this page:

vllm/model_executor/models/opt.py         +10   -2
vllm/model_executor/models/phi_1_5.py     +27   -26
vllm/model_executor/models/qwen.py        +10   -2
vllm/model_executor/models/yi.py          +10   -2
vllm/model_executor/sampling_metadata.py  +43   -0    (new file)
vllm/worker/model_runner.py               +334  -0    (new file)
vllm/worker/worker.py                     +13   -248
vllm/model_executor/models/opt.py

@@ -36,6 +36,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
 from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_world_size)
+from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.weight_utils import (default_weight_loader,
                                               hf_model_weights_iterator)
 from vllm.sequence import SamplerOutput
@@ -308,11 +309,18 @@ class OPTForCausalLM(nn.Module):
         kv_caches: List[KVCache],
         input_metadata: InputMetadata,
         cache_events: Optional[List[torch.cuda.Event]],
-    ) -> SamplerOutput:
+    ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, kv_caches,
                                    input_metadata, cache_events)
+        return hidden_states
+
+    def sample(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> SamplerOutput:
         next_tokens = self.sampler(self.lm_head_weight, hidden_states,
-                                   input_metadata)
+                                   sampling_metadata)
         return next_tokens

     def load_weights(self,
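The same forward/sample split is applied to each model file below. As an aside (not part of the commit), here is a minimal sketch of the new calling convention, mirroring what ModelRunner.execute_model does later in this diff; the helper name is illustrative and all of its arguments are assumed to have been prepared by the caller:

def forward_then_sample(model, input_tokens, input_positions, kv_caches,
                        input_metadata, sampling_metadata, cache_events=None):
    # Step 1: forward() now returns raw hidden states instead of sampled tokens.
    hidden_states = model(
        input_ids=input_tokens,
        positions=input_positions,
        kv_caches=kv_caches,
        input_metadata=input_metadata,
        cache_events=cache_events,
    )
    # Step 2: token selection is driven separately by SamplingMetadata.
    return model.sample(hidden_states=hidden_states,
                        sampling_metadata=sampling_metadata)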
vllm/model_executor/models/phi_1_5.py

@@ -54,6 +54,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding, ParallelLMHead)
 from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_world_size)
+from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.weight_utils import (default_weight_loader,
                                               hf_model_weights_iterator)
 from vllm.sequence import SamplerOutput
@@ -210,28 +211,6 @@ class PhiLayer(nn.Module):
         return hidden_states


-class PhiCausalLMHead(nn.Module):
-
-    def __init__(self, config: PretrainedConfig):
-        super().__init__()
-        self.ln = nn.LayerNorm(config.hidden_size,
-                               eps=config.layer_norm_epsilon)
-        self.linear = ParallelLMHead(config.vocab_size,
-                                     config.hidden_size,
-                                     bias=True)
-        self.sampler = Sampler(config.vocab_size)
-
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        input_metadata: InputMetadata,
-    ):
-        hidden_states = self.ln(hidden_states)
-        next_tokens = self.sampler(self.linear.weight, hidden_states,
-                                   input_metadata, self.linear.bias)
-        return next_tokens
-
-
 class PhiModel(nn.Module):

     def __init__(self,
@@ -253,7 +232,7 @@ class PhiModel(nn.Module):
         kv_caches: List[KVCache],
         input_metadata: InputMetadata,
         cache_events: Optional[List[torch.cuda.Event]],
-    ) -> SamplerOutput:
+    ) -> torch.Tensor:
         hidden_states = self.embd(input_ids)
         for i in range(self.config.num_hidden_layers):
             cache_event = None if cache_events is None else cache_events[i]
@@ -268,6 +247,17 @@ class PhiModel(nn.Module):
         return hidden_states


+class PhiCausalLMHead(nn.Module):
+
+    def __init__(self, config: PretrainedConfig):
+        super().__init__()
+        self.ln = nn.LayerNorm(config.hidden_size,
+                               eps=config.layer_norm_epsilon)
+        self.linear = ParallelLMHead(config.vocab_size,
+                                     config.hidden_size,
+                                     bias=True)
+
+
 class PhiForCausalLM(nn.Module):

     def __init__(self,
@@ -279,6 +269,7 @@ class PhiForCausalLM(nn.Module):
         self.transformer = PhiModel(config, linear_method)
         self.lm_head = PhiCausalLMHead(config)
+        self.sampler = Sampler(config.vocab_size)

     def forward(
         self,
@@ -287,11 +278,21 @@ class PhiForCausalLM(nn.Module):
         kv_caches: List[KVCache],
         input_metadata: InputMetadata,
         cache_events: Optional[List[torch.cuda.Event]],
-    ) -> SamplerOutput:
+    ) -> torch.Tensor:
         hidden_states = self.transformer(input_ids, positions, kv_caches,
                                          input_metadata, cache_events)
-        lm_logits = self.lm_head(hidden_states, input_metadata)
-        return lm_logits
+        hidden_states = self.lm_head.ln(hidden_states)
+        return hidden_states
+
+    def sample(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> SamplerOutput:
+        head = self.lm_head.linear
+        next_tokens = self.sampler(head.weight, hidden_states,
+                                   sampling_metadata, head.bias)
+        return next_tokens

     def load_weights(self,
                      model_name_or_path: str,
vllm/model_executor/models/qwen.py

@@ -23,6 +23,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding, ParallelLMHead)
 from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_world_size)
+from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.weight_utils import (default_weight_loader,
                                               hf_model_weights_iterator)
 from vllm.sequence import SamplerOutput
@@ -246,11 +247,18 @@ class QWenLMHeadModel(nn.Module):
         kv_caches: List[KVCache],
         input_metadata: InputMetadata,
         cache_events: Optional[List[torch.cuda.Event]],
-    ) -> SamplerOutput:
+    ) -> torch.Tensor:
         hidden_states = self.transformer(input_ids, positions, kv_caches,
                                          input_metadata, cache_events)
+        return hidden_states
+
+    def sample(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> SamplerOutput:
         next_tokens = self.sampler(self.lm_head.weight, hidden_states,
-                                   input_metadata)
+                                   sampling_metadata)
         return next_tokens

     def load_weights(self,
vllm/model_executor/models/yi.py

@@ -41,6 +41,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding, ParallelLMHead)
 from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_world_size)
+from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.weight_utils import (default_weight_loader,
                                               hf_model_weights_iterator)
 from vllm.sequence import SamplerOutput
@@ -284,11 +285,18 @@ class YiForCausalLM(nn.Module):
         kv_caches: List[KVCache],
         input_metadata: InputMetadata,
         cache_events: Optional[List[torch.cuda.Event]],
-    ) -> SamplerOutput:
+    ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, kv_caches,
                                    input_metadata, cache_events)
+        return hidden_states
+
+    def sample(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> SamplerOutput:
         next_tokens = self.sampler(self.lm_head.weight, hidden_states,
-                                   input_metadata)
+                                   sampling_metadata)
         return next_tokens

     def load_weights(self,
vllm/model_executor/sampling_metadata.py  (new file, 0 → 100644)

from typing import Dict, List, Tuple

import torch

from vllm.sampling_params import SamplingParams, SamplingType
from vllm.sequence import SequenceData


class SamplingMetadata:
    """Metadata for input sequences. Used in sampler.

    Args:
        seq_groups: List of (seq_ids, sampling_params).
        seq_data: Seq_id -> SequenceData.
        prompt_lens: Lengths of prompts.
        selected_token_indices: Token indices selected for sampling.
        categorized_sample_indices: SamplingType -> token indices to sample.
    """

    def __init__(
        self,
        seq_groups: List[Tuple[List[int], SamplingParams]],
        seq_data: Dict[int, SequenceData],
        prompt_lens: List[int],
        selected_token_indices: torch.Tensor,
        categorized_sample_indices: Dict[SamplingType, torch.Tensor],
    ) -> None:
        self.seq_groups = seq_groups
        self.seq_data = seq_data
        self.prompt_lens = prompt_lens
        self.selected_token_indices = selected_token_indices
        self.categorized_sample_indices = categorized_sample_indices

        self.num_prompts = len(prompt_lens)

    def __repr__(self) -> str:
        return (
            "SamplingMetadata("
            f"seq_groups={self.seq_groups}, "
            f"seq_data={self.seq_data}, "
            f"prompt_lens={self.prompt_lens}, "
            f"selected_token_indices={self.selected_token_indices}, "
            f"categorized_sample_indices={self.categorized_sample_indices})")
0 → 100644
View file @
27feead2
from
typing
import
Dict
,
List
,
Optional
,
Tuple
import
torch
from
vllm.config
import
ModelConfig
,
ParallelConfig
,
SchedulerConfig
from
vllm.logger
import
init_logger
from
vllm.model_executor
import
get_model
,
InputMetadata
,
SamplingMetadata
from
vllm.sampling_params
import
SamplingParams
,
SamplingType
from
vllm.sequence
import
SamplerOutput
,
SequenceData
,
SequenceGroupMetadata
logger
=
init_logger
(
__name__
)
_PAD_SLOT_ID
=
-
1
class
ModelRunner
:
def
__init__
(
self
,
model_config
:
ModelConfig
,
parallel_config
:
ParallelConfig
,
scheduler_config
:
SchedulerConfig
,
):
self
.
model_config
=
model_config
self
.
parallel_config
=
parallel_config
self
.
scheduler_config
=
scheduler_config
self
.
sliding_window
=
model_config
.
get_sliding_window
()
self
.
model
=
None
self
.
block_size
=
None
# Set after initial profiling.
def
load_model
(
self
)
->
None
:
self
.
model
=
get_model
(
self
.
model_config
)
def
set_block_size
(
self
,
block_size
:
int
)
->
None
:
self
.
block_size
=
block_size
def
_prepare_prompt
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
InputMetadata
]:
assert
len
(
seq_group_metadata_list
)
>
0
input_tokens
:
List
[
List
[
int
]]
=
[]
input_positions
:
List
[
List
[
int
]]
=
[]
slot_mapping
:
List
[
List
[
int
]]
=
[]
prompt_lens
:
List
[
int
]
=
[]
for
seq_group_metadata
in
seq_group_metadata_list
:
assert
seq_group_metadata
.
is_prompt
seq_ids
=
list
(
seq_group_metadata
.
seq_data
.
keys
())
assert
len
(
seq_ids
)
==
1
seq_id
=
seq_ids
[
0
]
seq_data
=
seq_group_metadata
.
seq_data
[
seq_id
]
prompt_tokens
=
seq_data
.
get_token_ids
()
prompt_len
=
len
(
prompt_tokens
)
prompt_lens
.
append
(
prompt_len
)
input_tokens
.
append
(
prompt_tokens
)
# NOTE(woosuk): Here we assume that the first token in the prompt
# is always the first token in the sequence.
input_positions
.
append
(
list
(
range
(
prompt_len
)))
if
seq_group_metadata
.
block_tables
is
None
:
# During memory profiling, the block tables are not initialized
# yet. In this case, we just use a dummy slot mapping.
slot_mapping
.
append
([
_PAD_SLOT_ID
]
*
prompt_len
)
continue
# Compute the slot mapping.
slot_mapping
.
append
([])
block_table
=
seq_group_metadata
.
block_tables
[
seq_id
]
# Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID,
# where start_idx is max(0, prompt_len - sliding_window).
# For example, if the prompt len is 10, sliding window is 8, and
# block size is 4, the first two tokens are masked and the slot
# mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
start_idx
=
0
if
self
.
sliding_window
is
not
None
:
start_idx
=
max
(
0
,
prompt_len
-
self
.
sliding_window
)
for
i
in
range
(
prompt_len
):
if
i
<
start_idx
:
slot_mapping
[
-
1
].
append
(
_PAD_SLOT_ID
)
continue
block_number
=
block_table
[
i
//
self
.
block_size
]
block_offset
=
i
%
self
.
block_size
slot
=
block_number
*
self
.
block_size
+
block_offset
slot_mapping
[
-
1
].
append
(
slot
)
max_prompt_len
=
max
(
prompt_lens
)
input_tokens
=
_make_tensor_with_pad
(
input_tokens
,
max_prompt_len
,
pad
=
0
,
dtype
=
torch
.
long
)
input_positions
=
_make_tensor_with_pad
(
input_positions
,
max_prompt_len
,
pad
=
0
,
dtype
=
torch
.
long
)
slot_mapping
=
_make_tensor_with_pad
(
slot_mapping
,
max_prompt_len
,
pad
=
_PAD_SLOT_ID
,
dtype
=
torch
.
long
)
input_metadata
=
InputMetadata
(
prompt_lens
=
prompt_lens
,
slot_mapping
=
slot_mapping
,
max_context_len
=
None
,
context_lens
=
None
,
block_tables
=
None
,
)
return
input_tokens
,
input_positions
,
input_metadata
def
_prepare_decode
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
InputMetadata
]:
assert
len
(
seq_group_metadata_list
)
>
0
input_tokens
:
List
[
List
[
int
]]
=
[]
input_positions
:
List
[
List
[
int
]]
=
[]
slot_mapping
:
List
[
List
[
int
]]
=
[]
context_lens
:
List
[
int
]
=
[]
block_tables
:
List
[
List
[
int
]]
=
[]
for
seq_group_metadata
in
seq_group_metadata_list
:
assert
not
seq_group_metadata
.
is_prompt
seq_ids
=
list
(
seq_group_metadata
.
seq_data
.
keys
())
for
seq_id
in
seq_ids
:
seq_data
=
seq_group_metadata
.
seq_data
[
seq_id
]
generation_token
=
seq_data
.
get_last_token_id
()
input_tokens
.
append
([
generation_token
])
context_len
=
seq_data
.
get_len
()
if
self
.
sliding_window
is
not
None
:
context_len
=
min
(
context_len
,
self
.
sliding_window
)
context_lens
.
append
(
context_len
)
position
=
context_len
-
1
input_positions
.
append
([
position
])
block_table
=
seq_group_metadata
.
block_tables
[
seq_id
]
block_number
=
block_table
[
position
//
self
.
block_size
]
block_offset
=
position
%
self
.
block_size
slot
=
block_number
*
self
.
block_size
+
block_offset
slot_mapping
.
append
([
slot
])
if
self
.
sliding_window
is
not
None
:
sliding_window_blocks
=
(
self
.
sliding_window
//
self
.
block_size
)
block_table
=
block_table
[
-
sliding_window_blocks
:]
block_tables
.
append
(
block_table
)
input_tokens
=
_make_tensor_with_pad
(
input_tokens
,
max_len
=
1
,
pad
=
0
,
dtype
=
torch
.
long
)
input_positions
=
_make_tensor_with_pad
(
input_positions
,
max_len
=
1
,
pad
=
0
,
dtype
=
torch
.
long
)
slot_mapping
=
_make_tensor_with_pad
(
slot_mapping
,
max_len
=
1
,
pad
=
_PAD_SLOT_ID
,
dtype
=
torch
.
long
)
max_context_len
=
max
(
context_lens
)
context_lens
=
torch
.
tensor
(
context_lens
,
dtype
=
torch
.
int
,
device
=
"cuda"
)
max_block_table_len
=
max
([
len
(
t
)
for
t
in
block_tables
])
block_tables
=
_make_tensor_with_pad
(
block_tables
,
max_len
=
max_block_table_len
,
pad
=
0
,
dtype
=
torch
.
int
)
input_metadata
=
InputMetadata
(
prompt_lens
=
[],
slot_mapping
=
slot_mapping
,
max_context_len
=
max_context_len
,
context_lens
=
context_lens
,
block_tables
=
block_tables
,
)
return
input_tokens
,
input_positions
,
input_metadata
def
_prepare_sample
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
prompt_lens
:
List
[
int
],
)
->
SamplingMetadata
:
seq_groups
:
List
[
Tuple
[
List
[
int
],
SamplingParams
]]
=
[]
selected_token_indices
:
List
[
int
]
=
[]
selected_token_start_idx
=
0
categorized_sample_indices
=
{
t
:
[]
for
t
in
SamplingType
}
categorized_sample_indices_start_idx
=
0
max_prompt_len
=
max
(
prompt_lens
)
if
prompt_lens
else
1
for
i
,
seq_group_metadata
in
enumerate
(
seq_group_metadata_list
):
seq_ids
=
list
(
seq_group_metadata
.
seq_data
.
keys
())
sampling_params
=
seq_group_metadata
.
sampling_params
seq_groups
.
append
((
seq_ids
,
sampling_params
))
if
seq_group_metadata
.
is_prompt
:
assert
len
(
seq_ids
)
==
1
prompt_len
=
prompt_lens
[
i
]
if
sampling_params
.
prompt_logprobs
is
not
None
:
# NOTE: prompt token positions do not need sample, skip
categorized_sample_indices_start_idx
+=
prompt_len
-
1
categorized_sample_indices
[
sampling_params
.
sampling_type
].
append
(
categorized_sample_indices_start_idx
)
categorized_sample_indices_start_idx
+=
1
if
sampling_params
.
prompt_logprobs
is
not
None
:
selected_token_indices
.
extend
(
range
(
selected_token_start_idx
,
selected_token_start_idx
+
prompt_len
-
1
))
selected_token_indices
.
append
(
selected_token_start_idx
+
prompt_len
-
1
)
selected_token_start_idx
+=
max_prompt_len
else
:
num_seqs
=
len
(
seq_ids
)
selected_token_indices
.
extend
(
range
(
selected_token_start_idx
,
selected_token_start_idx
+
num_seqs
))
selected_token_start_idx
+=
num_seqs
categorized_sample_indices
[
sampling_params
.
sampling_type
].
extend
(
range
(
categorized_sample_indices_start_idx
,
categorized_sample_indices_start_idx
+
num_seqs
))
categorized_sample_indices_start_idx
+=
num_seqs
selected_token_indices
=
torch
.
tensor
(
selected_token_indices
,
dtype
=
torch
.
long
,
device
=
"cuda"
)
categorized_sample_indices
=
{
t
:
torch
.
tensor
(
seq_ids
,
dtype
=
torch
.
int
,
device
=
"cuda"
)
for
t
,
seq_ids
in
categorized_sample_indices
.
items
()
}
seq_data
:
Dict
[
int
,
SequenceData
]
=
{}
for
seq_group_metadata
in
seq_group_metadata_list
:
seq_data
.
update
(
seq_group_metadata
.
seq_data
)
sampling_metadata
=
SamplingMetadata
(
seq_groups
=
seq_groups
,
seq_data
=
seq_data
,
prompt_lens
=
prompt_lens
,
selected_token_indices
=
selected_token_indices
,
categorized_sample_indices
=
categorized_sample_indices
,
)
return
sampling_metadata
@
torch
.
inference_mode
()
def
execute_model
(
self
,
seq_group_metadata_list
:
List
[
SequenceGroupMetadata
],
kv_caches
:
List
[
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]],
cache_events
:
Optional
[
List
[
torch
.
cuda
.
Event
]]
=
None
,
)
->
SamplerOutput
:
# NOTE: We assume that all sequences in the group are all prompts or
# all decodes.
# Prepare input tensors.
is_prompt
=
seq_group_metadata_list
[
0
].
is_prompt
if
is_prompt
:
inputs
=
self
.
_prepare_prompt
(
seq_group_metadata_list
)
input_tokens
,
input_positions
,
input_metadata
=
inputs
else
:
inputs
=
self
.
_prepare_decode
(
seq_group_metadata_list
)
input_tokens
,
input_positions
,
input_metadata
=
inputs
sampling_metadata
=
self
.
_prepare_sample
(
seq_group_metadata_list
,
input_metadata
.
prompt_lens
)
# Execute the model.
hidden_states
=
self
.
model
(
input_ids
=
input_tokens
,
positions
=
input_positions
,
kv_caches
=
kv_caches
,
input_metadata
=
input_metadata
,
cache_events
=
cache_events
,
)
# Sample the next token.
output
=
self
.
model
.
sample
(
hidden_states
=
hidden_states
,
sampling_metadata
=
sampling_metadata
,
)
return
output
@
torch
.
inference_mode
()
def
profile_run
(
self
)
->
None
:
# Enable top-k sampling to reflect the accurate memory usage.
vocab_size
=
self
.
model_config
.
get_vocab_size
()
sampling_params
=
SamplingParams
(
top_p
=
0.99
,
top_k
=
vocab_size
-
1
)
max_num_batched_tokens
=
self
.
scheduler_config
.
max_num_batched_tokens
max_num_seqs
=
self
.
scheduler_config
.
max_num_seqs
# Profile memory usage with max_num_sequences sequences and the total
# number of tokens equal to max_num_batched_tokens.
seqs
:
List
[
SequenceGroupMetadata
]
=
[]
for
group_id
in
range
(
max_num_seqs
):
seq_len
=
(
max_num_batched_tokens
//
max_num_seqs
+
(
group_id
<
max_num_batched_tokens
%
max_num_seqs
))
seq_data
=
SequenceData
([
0
]
*
seq_len
)
seq
=
SequenceGroupMetadata
(
request_id
=
str
(
group_id
),
is_prompt
=
True
,
seq_data
=
{
group_id
:
seq_data
},
sampling_params
=
sampling_params
,
block_tables
=
None
,
)
seqs
.
append
(
seq
)
# Run the model with the dummy inputs.
num_layers
=
self
.
model_config
.
get_num_layers
(
self
.
parallel_config
)
kv_caches
=
[(
None
,
None
)]
*
num_layers
self
.
execute_model
(
seqs
,
kv_caches
)
return
def
_pad_to_max
(
x
:
List
[
int
],
max_len
:
int
,
pad
:
int
)
->
List
[
int
]:
assert
len
(
x
)
<=
max_len
return
x
+
[
pad
]
*
(
max_len
-
len
(
x
))
def
_make_tensor_with_pad
(
x
:
List
[
List
[
int
]],
max_len
:
int
,
pad
:
int
,
dtype
:
torch
.
dtype
,
)
->
torch
.
Tensor
:
padded_x
=
[
_pad_to_max
(
x_i
,
max_len
,
pad
)
for
x_i
in
x
]
return
torch
.
tensor
(
padded_x
,
dtype
=
dtype
,
device
=
"cuda"
)
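As a side note (not part of the commit), the slot-mapping arithmetic in _prepare_prompt above can be checked in isolation. The sketch below reproduces the worked example from its comment (prompt_len=10, sliding_window=8, block_size=4), with an assumed block table [0, 1, 0] whose last entry reuses block 0 once the sliding window allows blocks to be recycled:

_PAD_SLOT_ID = -1

def prompt_slot_mapping(block_table, prompt_len, block_size, sliding_window=None):
    # Tokens before start_idx fall outside the sliding window and are masked.
    start_idx = 0
    if sliding_window is not None:
        start_idx = max(0, prompt_len - sliding_window)
    slots = []
    for i in range(prompt_len):
        if i < start_idx:
            slots.append(_PAD_SLOT_ID)
            continue
        block_number = block_table[i // block_size]
        block_offset = i % block_size
        slots.append(block_number * block_size + block_offset)
    return slots

print(prompt_slot_mapping([0, 1, 0], prompt_len=10, block_size=4, sliding_window=8))
# -> [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1]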
vllm/worker/worker.py

 """A GPU worker class."""
 import os
-from typing import Dict, List, Tuple, Optional
+from typing import Dict, List, Optional, Tuple

 import torch
 import torch.distributed

 from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
                          SchedulerConfig)
-from vllm.model_executor import get_model, InputMetadata, set_random_seed
+from vllm.model_executor import set_random_seed
 from vllm.model_executor.parallel_utils.parallel_state import (
     initialize_model_parallel)
-from vllm.sampling_params import SamplingParams, SamplingType
-from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
+from vllm.sequence import SamplerOutput, SequenceGroupMetadata
 from vllm.worker.cache_engine import CacheEngine
+from vllm.worker.model_runner import ModelRunner
 from vllm.utils import get_gpu_memory
@@ -38,11 +38,11 @@ class Worker:
         self.rank = rank
         self.distributed_init_method = distributed_init_method

+        self.model_runner = ModelRunner(model_config, parallel_config,
+                                        scheduler_config)
         # Uninitialized cache engine. Will be initialized by
         # self.init_cache_engine().
         self.cache_config = None
-        self.block_size = None
-        self.sliding_window = None
         self.cache_engine = None
         self.cache_events = None
         self.gpu_cache = None
@@ -69,7 +69,7 @@ class Worker:
         set_random_seed(self.model_config.seed)

     def load_model(self):
-        self.model = get_model(self.model_config)
+        self.model_runner.load_model()

     @torch.inference_mode()
     def profile_num_available_blocks(
@@ -83,40 +83,9 @@ class Worker:
         torch.cuda.empty_cache()
         torch.cuda.reset_peak_memory_stats()

-        # Profile memory usage with max_num_sequences sequences and the total
-        # number of tokens equal to max_num_batched_tokens.
-
-        # Enable top-k sampling to reflect the accurate memory usage.
-        vocab_size = self.model.config.vocab_size
-        sampling_params = SamplingParams(top_p=0.99, top_k=vocab_size - 1)
-        max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
-        max_num_seqs = self.scheduler_config.max_num_seqs
-        seqs = []
-        for group_id in range(max_num_seqs):
-            seq_len = (max_num_batched_tokens // max_num_seqs +
-                       (group_id < max_num_batched_tokens % max_num_seqs))
-            seq_data = SequenceData([0] * seq_len)
-            seq = SequenceGroupMetadata(
-                request_id=str(group_id),
-                is_prompt=True,
-                seq_data={group_id: seq_data},
-                sampling_params=sampling_params,
-                block_tables=None,
-            )
-            seqs.append(seq)
-
-        input_tokens, input_positions, input_metadata = self._prepare_inputs(
-            seqs)
-
-        # Execute the model.
-        num_layers = self.model_config.get_num_layers(self.parallel_config)
-        self.model(
-            input_ids=input_tokens,
-            positions=input_positions,
-            kv_caches=[(None, None)] * num_layers,
-            input_metadata=input_metadata,
-            cache_events=None,
-        )
+        # Execute a forward pass with dummy inputs to profile the memory usage
+        # of the model.
+        self.model_runner.profile_run()

         # Calculate the number of blocks that can be allocated with the
         # profiled peak memory.
@@ -140,197 +109,11 @@ class Worker:
     def init_cache_engine(self, cache_config: CacheConfig) -> None:
         self.cache_config = cache_config
-        self.block_size = cache_config.block_size
-        self.sliding_window = cache_config.sliding_window
         self.cache_engine = CacheEngine(self.cache_config, self.model_config,
                                         self.parallel_config)
         self.cache_events = self.cache_engine.events
         self.gpu_cache = self.cache_engine.gpu_cache
+        self.model_runner.set_block_size(self.cache_engine.block_size)

-    def _prepare_inputs(
-        self,
-        seq_group_metadata_list: List[SequenceGroupMetadata],
-    ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata]:
-        seq_groups: List[Tuple[List[int], SamplingParams]] = []
-        input_tokens: List[List[int]] = []
-        input_positions: List[List[int]] = []
-        slot_mapping: List[List[int]] = []
-        selected_token_indices: List[int] = []
-        selected_token_start_idx = 0
-        categorized_sample_indices = {t: [] for t in SamplingType}
-        categorized_sample_indices_start_idx = 0
-
-        # Add prompt tokens.
-        prompt_lens: List[int] = []
-        for seq_group_metadata in seq_group_metadata_list:
-            if not seq_group_metadata.is_prompt:
-                continue
-
-            seq_ids = list(seq_group_metadata.seq_data.keys())
-            sampling_params = seq_group_metadata.sampling_params
-            seq_groups.append((seq_ids, sampling_params))
-
-            # Use any sequence in the group.
-            seq_id = seq_ids[0]
-
-            seq_data = seq_group_metadata.seq_data[seq_id]
-            prompt_tokens = seq_data.get_token_ids()
-            prompt_len = len(prompt_tokens)
-            prompt_lens.append(prompt_len)
-
-            if sampling_params.prompt_logprobs is not None:
-                # NOTE: prompt token positions do not need sample, skip
-                categorized_sample_indices_start_idx += prompt_len - 1
-
-            categorized_sample_indices[
-                sampling_params.sampling_type].append(
-                    categorized_sample_indices_start_idx)
-            categorized_sample_indices_start_idx += 1
-
-            input_tokens.append(prompt_tokens)
-            # NOTE(woosuk): Here we assume that the first token in the prompt
-            # is always the first token in the sequence.
-            input_positions.append(list(range(prompt_len)))
-
-            if seq_group_metadata.block_tables is None:
-                # During memory profiling, the block tables are not initialized
-                # yet. In this case, we just use a dummy slot mapping.
-                slot_mapping.append([0] * prompt_len)
-                continue
-
-            # Compute the slot mapping.
-            slot_mapping.append([])
-            block_table = seq_group_metadata.block_tables[seq_id]
-            for i in range(prompt_len):
-                block_number = block_table[i // self.block_size]
-                block_offset = i % self.block_size
-                slot = block_number * self.block_size + block_offset
-                slot_mapping[-1].append(slot)
-
-        # Add generation tokens.
-        max_context_len = 0
-        max_num_blocks_per_seq = 0
-        context_lens: List[int] = []
-        generation_block_tables: List[List[int]] = []
-        max_seq_len = max(prompt_lens) if prompt_lens else 1
-        for i, seq_group_metadata in enumerate(seq_group_metadata_list):
-            if seq_group_metadata.is_prompt:
-                # We need to do this in this loop as we need to know max_seq_len
-                assert len(seq_ids) == 1, "Prompt input should have only one seq."
-                sampling_params = seq_group_metadata.sampling_params
-                assert len(prompt_lens) == len(seq_group_metadata_list)
-                prompt_len = prompt_lens[i]
-                if sampling_params.prompt_logprobs is not None:
-                    selected_token_indices.extend(
-                        range(selected_token_start_idx,
-                              selected_token_start_idx + prompt_len - 1))
-                selected_token_indices.append(selected_token_start_idx +
-                                              prompt_len - 1)
-                selected_token_start_idx += max_seq_len
-                continue
-
-            seq_ids = list(seq_group_metadata.seq_data.keys())
-            sampling_params = seq_group_metadata.sampling_params
-            seq_groups.append((seq_ids, sampling_params))
-
-            num_seqs = len(seq_ids)
-            selected_token_indices.extend(
-                range(selected_token_start_idx,
-                      selected_token_start_idx + num_seqs))
-            selected_token_start_idx += num_seqs
-
-            categorized_sample_indices[
-                sampling_params.sampling_type].extend(
-                    range(categorized_sample_indices_start_idx,
-                          categorized_sample_indices_start_idx + num_seqs))
-            categorized_sample_indices_start_idx += num_seqs
-
-            for seq_id in seq_ids:
-                seq_data = seq_group_metadata.seq_data[seq_id]
-                generation_token = seq_data.get_last_token_id()
-                input_tokens.append([generation_token])
-
-                context_len = seq_data.get_len()
-                position = context_len - 1
-                if self.sliding_window is not None:
-                    context_len = min(context_len, self.sliding_window)
-                input_positions.append([position])
-
-                block_table = seq_group_metadata.block_tables[seq_id]
-
-                max_context_len = max(max_context_len, context_len)
-                max_num_blocks_per_seq = max(max_num_blocks_per_seq,
-                                             len(block_table))
-                context_lens.append(context_len)
-
-                block_number = block_table[position // self.block_size]
-                block_offset = position % self.block_size
-                slot = block_number * self.block_size + block_offset
-                slot_mapping.append([slot])
-
-                if self.sliding_window is not None:
-                    sliding_window_blocks = (self.sliding_window //
-                                             self.block_size)
-                    block_table = block_table[-sliding_window_blocks:]
-                generation_block_tables.append(block_table)
-
-        padded_input_tokens = [
-            _pad_to_max(tokens, max_seq_len, pad=0) for tokens in input_tokens
-        ]
-        padded_input_positions = [
-            _pad_to_max(positions, max_seq_len, pad=0)
-            for positions in input_positions
-        ]
-        padded_slot_mapping = [
-            _pad_to_max(mapping, max_seq_len, pad=-1)
-            for mapping in slot_mapping
-        ]
-        padded_block_tables = [
-            _pad_to_max(block_table, max_num_blocks_per_seq, pad=0)
-            for block_table in generation_block_tables
-        ]
-
-        # Convert to tensors.
-        tokens_tensor = torch.tensor(padded_input_tokens,
-                                     dtype=torch.long,
-                                     device="cuda")
-        positions_tensor = torch.tensor(padded_input_positions,
-                                        dtype=torch.long,
-                                        device="cuda")
-        slot_mapping_tensor = torch.tensor(padded_slot_mapping,
-                                           dtype=torch.long,
-                                           device="cuda")
-        context_lens_tensor = torch.tensor(context_lens,
-                                           dtype=torch.int,
-                                           device="cuda")
-        selected_token_indices = torch.tensor(selected_token_indices,
-                                              dtype=torch.long,
-                                              device="cuda")
-        categorized_sample_indices = {
-            t: torch.tensor(seq_ids, dtype=torch.int, device="cuda")
-            for t, seq_ids in categorized_sample_indices.items()
-        }
-        block_tables_tensor = torch.tensor(padded_block_tables,
-                                           dtype=torch.int,
-                                           device="cuda")
-
-        seq_data: Dict[int, SequenceData] = {}
-        for seq_group_metadata in seq_group_metadata_list:
-            seq_data.update(seq_group_metadata.seq_data)
-
-        input_metadata = InputMetadata(
-            seq_groups=seq_groups,
-            seq_data=seq_data,
-            prompt_lens=prompt_lens,
-            slot_mapping=slot_mapping_tensor,
-            context_lens=context_lens_tensor,
-            max_context_len=max_context_len,
-            block_tables=block_tables_tensor,
-            selected_token_indices=selected_token_indices,
-            categorized_sample_indices=categorized_sample_indices,
-            sliding_window=self.sliding_window,
-        )
-        return tokens_tensor, positions_tensor, input_metadata
-
     @torch.inference_mode()
     def execute_model(
@@ -361,18 +144,8 @@ class Worker:
             event.wait()
             return {}

-        # Prepare input tensors.
-        input_tokens, input_positions, input_metadata = self._prepare_inputs(
-            seq_group_metadata_list)
-
-        # Execute the model.
-        output = self.model(
-            input_ids=input_tokens,
-            positions=input_positions,
-            kv_caches=self.gpu_cache,
-            input_metadata=input_metadata,
-            cache_events=cache_events,
-        )
+        output = self.model_runner.execute_model(seq_group_metadata_list,
+                                                 self.gpu_cache, cache_events)
         return output
@@ -407,14 +180,6 @@ def _init_distributed_environment(
         parallel_config.pipeline_parallel_size)


-def _pad_to_alignment(x: List[int], multiple_of: int, pad: int) -> List[int]:
-    return x + [pad] * ((-len(x)) % multiple_of)
-
-
-def _pad_to_max(x: List[int], max_len: int, pad: int) -> List[int]:
-    return x + [pad] * (max_len - len(x))
-
-
 def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
     # Check if the GPU supports the dtype.
     if torch_dtype == torch.bfloat16:
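Taken together, the Worker now delegates all model execution to ModelRunner. As an illustration (not part of the commit), a minimal sketch of the delegation order the diff above establishes; the config objects, block size, sequence metadata, and GPU cache are assumed to come from the engine and cache profiling, exactly as Worker obtains them:

from vllm.worker.model_runner import ModelRunner

def run_step(model_config, parallel_config, scheduler_config, block_size,
             seq_group_metadata_list, gpu_cache, cache_events=None):
    # Mirrors Worker.__init__ / load_model / init_cache_engine / execute_model.
    runner = ModelRunner(model_config, parallel_config, scheduler_config)
    runner.load_model()                # was: self.model = get_model(...)
    runner.set_block_size(block_size)  # was: self.block_size = cache_config.block_size
    return runner.execute_model(seq_group_metadata_list, gpu_cache, cache_events)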