norm / vllm / Commits

Commit 002800f0 (unverified)
Align vLLM's beam search implementation with HF generate (#857)
Authored by Zhuohan Li on Sep 04, 2023; committed by GitHub on Sep 04, 2023.
Parent: e15932bb

Showing 20 changed files with 480 additions and 232 deletions (+480, -232):
  docs/source/models/adding_model.rst        (+1,   -1)
  tests/conftest.py                          (+56,  -10)
  tests/samplers/test_beam_search.py         (+46,  -0)
  vllm/core/block_manager.py                 (+2,   -6)
  vllm/core/scheduler.py                     (+45,  -69)
  vllm/engine/llm_engine.py                  (+260, -65)
  vllm/model_executor/layers/sampler.py      (+31,  -42)
  vllm/model_executor/models/aquila.py       (+3,   -3)
  vllm/model_executor/models/baichuan.py     (+3,   -3)
  vllm/model_executor/models/bloom.py        (+3,   -3)
  vllm/model_executor/models/falcon.py       (+3,   -3)
  vllm/model_executor/models/gpt2.py         (+3,   -3)
  vllm/model_executor/models/gpt_bigcode.py  (+3,   -3)
  vllm/model_executor/models/gpt_j.py        (+3,   -3)
  vllm/model_executor/models/gpt_neox.py     (+3,   -3)
  vllm/model_executor/models/internlm.py     (+3,   -3)
  vllm/model_executor/models/llama.py        (+3,   -3)
  vllm/model_executor/models/mpt.py          (+3,   -3)
  vllm/model_executor/models/opt.py          (+3,   -3)
  vllm/model_executor/models/qwen.py         (+3,   -3)
docs/source/models/adding_model.rst  (view file @ 002800f0)
...
@@ -59,7 +59,7 @@ Next, you need to rewrite the :code:`forward` methods of your model by following
     +    kv_caches: List[KVCache],
     +    input_metadata: InputMetadata,
     +    cache_events: Optional[List[torch.cuda.Event]],
-    +) -> Dict[int, SequenceOutputs]:
+    +) -> SamplerOutput:

 3. Update the code by considering that :code:`input_ids` and :code:`positions` are now flattened tensors.
 4. Replace the attention operation with either :code:`GPTPagedAttention` or :code:`GPTNeoXPagedAttention`, depending on the model's architecture.
...
tests/conftest.py  (view file @ 002800f0)
...
@@ -67,8 +67,8 @@ class HfRunner:
                 output_ids,
                 skip_special_tokens=True,
                 clean_up_tokenization_spaces=False,
-            )[0]
-            output_ids = output_ids[0].cpu().tolist()
+            )
+            output_ids = output_ids.cpu().tolist()
             outputs.append((output_ids, output_str))
         return outputs
...
@@ -77,8 +77,34 @@ class HfRunner:
         prompts: List[str],
         max_tokens: int,
     ) -> List[Tuple[List[int], str]]:
-        return self.generate(prompts, do_sample=False,
-                             max_new_tokens=max_tokens)
+        outputs = self.generate(prompts,
+                                do_sample=False,
+                                max_new_tokens=max_tokens)
+        for i in range(len(outputs)):
+            output_ids, output_str = outputs[i]
+            outputs[i] = (output_ids[0], output_str[0])
+        return outputs
+
+    def generate_beam_search(
+        self,
+        prompts: List[str],
+        beam_width: int,
+        max_tokens: int,
+    ) -> List[Tuple[List[int], str]]:
+        outputs = self.generate(prompts,
+                                do_sample=False,
+                                max_new_tokens=max_tokens,
+                                num_beams=beam_width,
+                                num_return_sequences=beam_width)
+        for i in range(len(outputs)):
+            output_ids, output_str = outputs[i]
+            for j in range(len(output_ids)):
+                output_ids[j] = [
+                    x for x in output_ids[j]
+                    if x != self.tokenizer.pad_token_id
+                ]
+            outputs[i] = (output_ids, output_str)
+        return outputs

 @pytest.fixture
...
@@ -107,15 +133,20 @@ class VllmRunner:
         prompts: List[str],
         sampling_params: SamplingParams,
     ) -> List[Tuple[List[int], str]]:
         req_outputs = self.model.generate(prompts,
                                           sampling_params=sampling_params)
         outputs = []
         for req_output in req_outputs:
             prompt_str = req_output.prompt
             prompt_ids = req_output.prompt_token_ids
-            output_str = req_output.outputs[0].text
-            output_ids = req_output.outputs[0].token_ids
-            outputs.append((prompt_ids + output_ids, prompt_str + output_str))
+            req_sample_output_ids = []
+            req_sample_output_strs = []
+            for sample in req_output.outputs:
+                output_str = sample.text
+                output_ids = sample.token_ids
+                req_sample_output_ids.append(prompt_ids + output_ids)
+                req_sample_output_strs.append(prompt_str + output_str)
+            outputs.append((req_sample_output_ids, req_sample_output_strs))
         return outputs

     def generate_greedy(
...
@@ -124,7 +155,22 @@ class VllmRunner:
         max_tokens: int,
     ) -> List[Tuple[List[int], str]]:
         greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
-        return self.generate(prompts, greedy_params)
+        outputs = self.generate(prompts, greedy_params)
+        return [(output_ids[0], output_str[0])
+                for output_ids, output_str in outputs]
+
+    def generate_beam_search(
+        self,
+        prompts: List[str],
+        beam_width: int,
+        max_tokens: int,
+    ) -> List[Tuple[List[int], str]]:
+        beam_search_params = SamplingParams(n=beam_width,
+                                            use_beam_search=True,
+                                            temperature=0.0,
+                                            max_tokens=max_tokens)
+        outputs = self.generate(prompts, beam_search_params)
+        return outputs

 @pytest.fixture
...
tests/samplers/test_beam_search.py  (new file, 0 → 100644, view file @ 002800f0)

"""Compare the outputs of HF and vLLM when using beam search.

Run `pytest tests/samplers/test_beam_search.py --forked`.
"""
import pytest

# FIXME(zhuohan): The test can not pass if we:
#   1. Increase max_tokens to 256.
#   2. Increase beam_width to 8.
#   3. Use the model "huggyllama/llama-7b".
MAX_TOKENS = [128]
BEAM_WIDTHS = [4]
MODELS = ["facebook/opt-125m"]


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", MAX_TOKENS)
@pytest.mark.parametrize("beam_width", BEAM_WIDTHS)
def test_beam_search_single_input(
    hf_runner,
    vllm_runner,
    example_prompts,
    model: str,
    dtype: str,
    max_tokens: int,
    beam_width: int,
) -> None:
    hf_model = hf_runner(model, dtype=dtype)
    hf_outputs = hf_model.generate_beam_search(example_prompts, beam_width,
                                               max_tokens)
    del hf_model

    vllm_model = vllm_runner(model, dtype=dtype)
    vllm_outputs = vllm_model.generate_beam_search(example_prompts, beam_width,
                                                   max_tokens)
    del vllm_model

    for i in range(len(example_prompts)):
        hf_output_ids, _ = hf_outputs[i]
        vllm_output_ids, _ = vllm_outputs[i]
        assert len(hf_output_ids) == len(vllm_output_ids)
        for j in range(len(hf_output_ids)):
            assert hf_output_ids[j] == vllm_output_ids[j], (
                f"Test{i} output{j}:\nHF: {hf_output_ids}\n"
                f"vLLM: {vllm_output_ids}")
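For orientation, a hedged usage sketch (not part of the commit; the model name and prompt are illustrative) of driving the same beam-search configuration through vLLM's Python API, mirroring what VllmRunner.generate_beam_search builds above:

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m", dtype="half")
# n=beam_width plus use_beam_search=True is the vLLM-side equivalent of
# HF generate(num_beams=..., num_return_sequences=...).
beam_params = SamplingParams(n=4, use_beam_search=True, temperature=0.0,
                             max_tokens=128)
outputs = llm.generate(["The future of AI is"], beam_params)
for beam in outputs[0].outputs:   # one completion per returned beam
    print(beam.text)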
vllm/core/block_manager.py  (view file @ 002800f0)
...
@@ -172,9 +172,7 @@ class BlockSpaceManager:
     def swap_in(self, seq_group: SequenceGroup) -> Dict[int, int]:
         # CPU block -> GPU block.
         mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
-        for seq in seq_group.get_seqs():
-            if seq.is_finished():
-                continue
+        for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED):
             new_block_table: BlockTable = []
             block_table = self.block_tables[seq.seq_id]
...
@@ -203,9 +201,7 @@ class BlockSpaceManager:
     def swap_out(self, seq_group: SequenceGroup) -> Dict[int, int]:
         # GPU block -> CPU block.
         mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
-        for seq in seq_group.get_seqs():
-            if seq.is_finished():
-                continue
+        for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
             new_block_table: BlockTable = []
             block_table = self.block_tables[seq.seq_id]
...
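Note (illustration only, an assumption about BlockSpaceManager.fork rather than code shown in this diff): forking a sequence lets beam candidates share their common prefix copy-on-write, so only reference counts change when the scheduler's fork_seq is called.

from dataclasses import dataclass
from typing import Dict, List

@dataclass
class Block:
    block_number: int
    ref_count: int = 1

def fork(block_tables: Dict[int, List[Block]],
         parent_seq_id: int, child_seq_id: int) -> None:
    src_block_table = block_tables[parent_seq_id]
    # The child points at the same physical blocks as the parent.
    block_tables[child_seq_id] = list(src_block_table)
    for block in src_block_table:
        block.ref_count += 1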
vllm/core/scheduler.py  (view file @ 002800f0)
...
@@ -7,8 +7,7 @@ from vllm.core.block_manager import BlockSpaceManager
 from vllm.core.policy import PolicyFactory
 from vllm.logger import init_logger
-from vllm.sequence import (Sequence, SequenceData, SequenceGroup,
-                           SequenceGroupMetadata, SequenceOutputs,
-                           SequenceStatus)
+from vllm.sequence import (Sequence, SequenceData, SequenceGroup,
+                           SequenceGroupMetadata, SequenceStatus)

 logger = init_logger(__name__)
...
@@ -76,6 +75,7 @@ class Scheduler:
             num_cpu_blocks=self.cache_config.num_cpu_blocks,
         )

+        # TODO(zhuohan): Use deque instead of list for better performance.
         # Sequence groups in the WAITING state.
         self.waiting: List[SequenceGroup] = []
         # Sequence groups in the RUNNING state.
...
@@ -96,10 +96,11 @@ class Scheduler:
             if seq_group.request_id in request_ids:
                 # Remove the sequence group from the state queue.
                 state_queue.remove(seq_group)
-                for seq in seq_group.seqs:
+                for seq in seq_group.get_seqs():
                     if seq.is_finished():
                         continue
-                    self.free_seq(seq, SequenceStatus.FINISHED_ABORTED)
+                    seq.status = SequenceStatus.FINISHED_ABORTED
+                    self.free_seq(seq)
                 request_ids.remove(seq_group.request_id)
                 if not request_ids:
                     return
...
@@ -123,6 +124,10 @@ class Scheduler:
         if not self.swapped:
             ignored_seq_groups: List[SequenceGroup] = []
             scheduled: List[SequenceGroup] = []
+            # The total number of sequences on the fly, including the
+            # requests in the generation phase.
+            num_curr_seqs = sum(seq_group.get_max_num_running_seqs()
+                                for seq_group in self.running)
             num_batched_tokens = 0
             # Optimization: We do not sort the waiting queue since the preempted
             # sequence groups are added to the front and the new sequence groups
...
@@ -130,6 +135,9 @@ class Scheduler:
             while self.waiting:
                 seq_group = self.waiting[0]

+                assert seq_group.num_seqs() == 1, (
+                    "Waiting sequence group should have only one prompt "
+                    "sequence.")
                 num_prompt_tokens = seq_group.get_seqs()[0].get_len()
                 if num_prompt_tokens > self.prompt_limit:
                     logger.warning(
...
@@ -152,11 +160,7 @@ class Scheduler:
                 # The total number of sequences in the RUNNING state should not
                 # exceed the maximum number of sequences.
-                num_new_seqs = seq_group.num_seqs(
-                    status=SequenceStatus.WAITING)
-                num_curr_seqs = sum(
-                    seq_group.num_seqs(status=SequenceStatus.RUNNING)
-                    for seq_group in self.running)
+                num_new_seqs = seq_group.get_max_num_running_seqs()
                 if (num_curr_seqs + num_new_seqs >
                         self.scheduler_config.max_num_seqs):
                     break
...
@@ -165,6 +169,7 @@ class Scheduler:
                 self._allocate(seq_group)
                 self.running.append(seq_group)
                 num_batched_tokens += num_prompt_tokens
+                num_curr_seqs += num_new_seqs
                 scheduled.append(seq_group)

             if scheduled:
...
@@ -210,30 +215,32 @@ class Scheduler:
         # Swap in the sequence groups in the SWAPPED state if possible.
         self.swapped = self.policy.sort_by_priority(now, self.swapped)
-        while self.swapped and not blocks_to_swap_out:
-            seq_group = self.swapped[0]
-            # If the sequence group has been preempted in this step, stop.
-            if seq_group in preempted:
-                break
-            # If the sequence group cannot be swapped in, stop.
-            if not self.block_manager.can_swap_in(seq_group):
-                break
-
-            # The total number of sequences in the RUNNING state should not
-            # exceed the maximum number of sequences.
-            num_new_seqs = seq_group.num_seqs(status=SequenceStatus.SWAPPED)
-            num_curr_seqs = sum(
-                seq_group.num_seqs(status=SequenceStatus.RUNNING)
-                for seq_group in self.running)
-            if (num_curr_seqs + num_new_seqs >
-                    self.scheduler_config.max_num_seqs):
-                break
-
-            seq_group = self.swapped.pop(0)
-            self._swap_in(seq_group, blocks_to_swap_in)
-            self._append_slot(seq_group, blocks_to_copy)
-            self.running.append(seq_group)
+        if not preempted:
+            num_curr_seqs = sum(seq_group.get_max_num_running_seqs()
+                                for seq_group in self.running)
+
+            while self.swapped:
+                seq_group = self.swapped[0]
+                # If the sequence group cannot be swapped in, stop.
+                if not self.block_manager.can_swap_in(seq_group):
+                    break
+
+                # The total number of sequences in the RUNNING state should not
+                # exceed the maximum number of sequences.
+                num_new_seqs = seq_group.get_max_num_running_seqs()
+                if (num_curr_seqs + num_new_seqs >
+                        self.scheduler_config.max_num_seqs):
+                    break
+
+                seq_group = self.swapped.pop(0)
+                self._swap_in(seq_group, blocks_to_swap_in)
+                self._append_slot(seq_group, blocks_to_copy)
+                num_curr_seqs += num_new_seqs
+                self.running.append(seq_group)

+        # Each sequence in the generation phase only takes one token slot.
+        # Therefore, the number of batched tokens is equal to the number of
+        # sequences in the RUNNING state.
         num_batched_tokens = sum(
             seq_group.num_seqs(status=SequenceStatus.RUNNING)
             for seq_group in self.running)
...
@@ -275,40 +282,10 @@ class Scheduler:
             seq_group_metadata_list.append(seq_group_metadata)
         return seq_group_metadata_list, scheduler_outputs

-    def update(
-        self,
-        seq_outputs: Dict[int, SequenceOutputs],
-    ) -> List[SequenceGroup]:
-        scheduled: List[SequenceGroup] = []
-        for seq_group in self.running:
-            for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
-                if seq.seq_id in seq_outputs:
-                    scheduled.append(seq_group)
-                    break
-
-        # Update the scheduled sequences and free blocks.
-        for seq_group in scheduled:
-            # Process beam search results before processing the new tokens.
-            for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
-                output = seq_outputs[seq.seq_id]
-                if seq.seq_id != output.parent_seq_id:
-                    # The sequence is a fork of the parent sequence (beam
-                    # search). Free the current sequence.
-                    self.block_manager.free(seq)
-                    # Fork the parent sequence.
-                    parent_seq = seq_group.find(output.parent_seq_id)
-                    parent_seq.fork(seq)
-                    self.block_manager.fork(parent_seq, seq)
-
-            # Process the new tokens.
-            for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
-                # Append a new token to the sequence.
-                output = seq_outputs[seq.seq_id]
-                seq.append_token_id(output.output_token, output.logprobs)
-        return scheduled
+    def fork_seq(self, parent_seq: Sequence, child_seq: Sequence) -> None:
+        self.block_manager.fork(parent_seq, child_seq)

-    def free_seq(self, seq: Sequence, finish_status: SequenceStatus) -> None:
-        seq.status = finish_status
+    def free_seq(self, seq: Sequence) -> None:
         self.block_manager.free(seq)

     def free_finished_seq_groups(self) -> None:
...
@@ -345,8 +322,8 @@ class Scheduler:
         # If preemption mode is not specified, we determine the mode as follows:
         # We use recomputation by default since it incurs lower overhead than
         # swapping. However, when the sequence group has multiple sequences
-        # (e.g., beam search), recomputation is not supported. In such a case,
-        # we use swapping instead.
+        # (e.g., beam search), recomputation is not currently supported. In
+        # such a case, we use swapping instead.
         # FIXME(woosuk): This makes our scheduling policy a bit bizarre.
         # As swapped sequences are prioritized over waiting sequences,
         # sequence groups with multiple sequences are implicitly prioritized
@@ -354,8 +331,7 @@ class Scheduler:
...
@@ -354,8 +331,7 @@ class Scheduler:
# TODO(woosuk): Support recomputation for sequence groups with multiple
# TODO(woosuk): Support recomputation for sequence groups with multiple
# sequences. This may require a more sophisticated CUDA kernel.
# sequences. This may require a more sophisticated CUDA kernel.
if
preemption_mode
is
None
:
if
preemption_mode
is
None
:
seqs
=
seq_group
.
get_seqs
(
status
=
SequenceStatus
.
RUNNING
)
if
seq_group
.
get_max_num_running_seqs
()
==
1
:
if
len
(
seqs
)
==
1
:
preemption_mode
=
PreemptionMode
.
RECOMPUTE
preemption_mode
=
PreemptionMode
.
RECOMPUTE
else
:
else
:
preemption_mode
=
PreemptionMode
.
SWAP
preemption_mode
=
PreemptionMode
.
SWAP
...
...
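Note (illustration only, an assumption rather than the vllm.sequence source): the get_max_num_running_seqs() call the scheduler now budgets against can be thought of as the upper bound on sequences a request may run in parallel for the rest of its lifetime, roughly:

def get_max_num_running_seqs(seq_group) -> int:
    params = seq_group.sampling_params
    if params.use_beam_search:
        # Beam search keeps exactly `best_of` candidate beams alive.
        return params.best_of
    if params.best_of > seq_group.num_seqs():
        # The single prompt sequence has not been forked into its
        # `best_of` samples yet.
        return params.best_of
    # Otherwise, only the not-yet-finished sequences keep running.
    return seq_group.num_unfinished_seqs()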
vllm/engine/llm_engine.py  (view file @ 002800f0)
...
@@ -11,7 +11,8 @@ from vllm.engine.ray_utils import RayWorker, initialize_cluster, ray
 from vllm.logger import init_logger
 from vllm.outputs import RequestOutput
 from vllm.sampling_params import SamplingParams
-from vllm.sequence import (Sequence, SequenceGroup, SequenceGroupMetadata,
-                           SequenceOutputs, SequenceStatus)
+from vllm.sequence import (SamplerOutput, Sequence, SequenceGroup,
+                           SequenceGroupMetadata, SequenceOutputs,
+                           SequenceStatus)
 from vllm.transformers_utils.tokenizer import (detokenize_incrementally,
                                                get_tokenizer)
...
@@ -258,14 +259,11 @@ class LLMEngine:
         # Create the sequences.
         block_size = self.cache_config.block_size
-        seqs: List[Sequence] = []
-        for _ in range(sampling_params.best_of):
-            seq_id = next(self.seq_counter)
-            seq = Sequence(seq_id, prompt, prompt_token_ids, block_size)
-            seqs.append(seq)
+        seq_id = next(self.seq_counter)
+        seq = Sequence(seq_id, prompt, prompt_token_ids, block_size)

         # Create the sequence group.
-        seq_group = SequenceGroup(request_id, seqs, sampling_params,
-                                  arrival_time)
+        seq_group = SequenceGroup(request_id, [seq], sampling_params,
+                                  arrival_time)

         # Add the sequence group to the scheduler.
...
@@ -303,22 +301,230 @@ class LLMEngine:
         ]
         return seq_group_metadata_list, scheduler_outputs, None

-    def _process_worker_outputs(
-            self, output,
+    def _check_beam_search_early_stopping(
+        self,
+        early_stopping: Union[bool, str],
+        sampling_params: SamplingParams,
+        best_running_seq: Sequence,
+        current_worst_seq: Sequence,
+    ) -> bool:
+        assert sampling_params.use_beam_search
+        length_penalty = sampling_params.length_penalty
+        if early_stopping is True:
+            return True
+
+        current_worst_score = (current_worst_seq.get_beam_search_score(
+            length_penalty=length_penalty,
+            eos_token_id=self.tokenizer.eos_token_id))
+        if early_stopping is False:
+            highest_attainable_score = (best_running_seq.get_beam_search_score(
+                length_penalty=length_penalty,
+                eos_token_id=self.tokenizer.eos_token_id))
+        else:
+            assert early_stopping == "never"
+            if length_penalty > 0.0:
+                # If length_penalty > 0.0, beam search will prefer longer
+                # sequences. The highest attainable score calculation is
+                # based on the longest possible sequence length in this case.
+                max_possible_length = max(
+                    best_running_seq.get_prompt_len() +
+                    sampling_params.max_tokens,
+                    self.scheduler_config.max_model_len)
+                highest_attainable_score = (
+                    best_running_seq.get_beam_search_score(
+                        length_penalty=length_penalty,
+                        eos_token_id=self.tokenizer.eos_token_id,
+                        seq_len=max_possible_length))
+            else:
+                # Otherwise, beam search will prefer shorter sequences. The
+                # highest attainable score calculation is based on the current
+                # sequence length.
+                highest_attainable_score = (
+                    best_running_seq.get_beam_search_score(
+                        length_penalty=length_penalty,
+                        eos_token_id=self.tokenizer.eos_token_id))
+        return current_worst_score >= highest_attainable_score
+
+    def _process_sequence_group_samples(
+            self, seq_group: SequenceGroup,
+            samples: List[SequenceOutputs]) -> None:
+        parent_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
+        existing_finished_seqs = seq_group.get_finished_seqs()
+        parent_child_dict = {
+            parent_seq.seq_id: []
+            for parent_seq in parent_seqs
+        }
+        for sample in samples:
+            parent_child_dict[sample.parent_seq_id].append(sample)
+        # List of (child, parent)
+        child_seqs: List[Tuple[Sequence, Sequence]] = []
+
+        # Process the child samples for each parent sequence
+        for parent in parent_seqs:
+            child_samples: List[SequenceOutputs] = parent_child_dict[
+                parent.seq_id]
+            if len(child_samples) == 0:
+                # This parent sequence has no children samples. Remove
+                # the parent sequence from the sequence group since it will
+                # not be used in the future iterations.
+                parent.status = SequenceStatus.FINISHED_ABORTED
+                seq_group.remove(parent.seq_id)
+                self.scheduler.free_seq(parent)
+                continue
+            # Fork the parent sequence if there are multiple child samples.
+            for child_sample in child_samples[:-1]:
+                new_child_seq_id = next(self.seq_counter)
+                child = parent.fork(new_child_seq_id)
+                child.append_token_id(child_sample.output_token,
+                                      child_sample.logprobs)
+                child_seqs.append((child, parent))
+            # Continue the parent sequence for the last child sample.
+            # We reuse the parent sequence here to reduce redundant memory
+            # copies, especially when using non-beam search sampling methods.
+            last_child_sample = child_samples[-1]
+            parent.append_token_id(last_child_sample.output_token,
+                                   last_child_sample.logprobs)
+            child_seqs.append((parent, parent))
+
+        for seq, _ in child_seqs:
+            self._decode_sequence(seq)
+            self._check_stop(seq, seq_group.sampling_params)
+
+        # Non-beam search case
+        if not seq_group.sampling_params.use_beam_search:
+            # For newly created child sequences, add them to the sequence group
+            # and fork them in block manager if they are not finished.
+            for seq, parent in child_seqs:
+                if seq is not parent:
+                    seq_group.add(seq)
+                    if not seq.is_finished():
+                        self.scheduler.fork_seq(parent, seq)
+
+            # Free the finished and selected parent sequences' memory in block
+            # manager. Keep them in the sequence group as candidate output.
+            # NOTE: we need to fork the new sequences before freeing the
+            # old sequences.
+            for seq, parent in child_seqs:
+                if seq is parent and seq.is_finished():
+                    self.scheduler.free_seq(seq)
+            return
+
+        # Beam search case
+        # Select the child sequences to keep in the sequence group.
+        selected_child_seqs = []
+        unselected_child_seqs = []
+        beam_width = seq_group.sampling_params.best_of
+        length_penalty = seq_group.sampling_params.length_penalty
+
+        # Select the newly finished sequences with the highest scores
+        # to replace existing finished sequences.
+        # Tuple of (seq, parent, is_new)
+        existing_finished_seqs = [(seq, None, False)
+                                  for seq in existing_finished_seqs]
+        new_finished_seqs = [(seq, parent, True) for seq, parent in child_seqs
+                             if seq.is_finished()]
+        all_finished_seqs = existing_finished_seqs + new_finished_seqs
+        # Sort the finished sequences by their scores.
+        all_finished_seqs.sort(key=lambda x: x[0].get_beam_search_score(
+            length_penalty=length_penalty,
+            eos_token_id=self.tokenizer.eos_token_id),
+                               reverse=True)
+        for seq, parent, is_new in all_finished_seqs[:beam_width]:
+            if is_new:
+                # A newly generated child sequence finishes and has a high
+                # score, so we will add it into the sequence group.
+                selected_child_seqs.append((seq, parent))
+        for seq, parent, is_new in all_finished_seqs[beam_width:]:
+            if is_new:
+                # A newly generated child sequence finishes but has a low
+                # score, so we will not add it into the sequence group.
+                # Additionally, if this sequence is a continuation of a
+                # parent sequence, we will need remove the parent sequence
+                # from the sequence group.
+                unselected_child_seqs.append((seq, parent))
+            else:
+                # An existing finished sequence has a low score, so we will
+                # remove it from the sequence group.
+                seq_group.remove(seq.seq_id)
+
+        # select the top beam_width sequences from the running
+        # sequences for the next iteration to continue the beam
+        # search.
+        running_child_seqs = [(seq, parent) for seq, parent in child_seqs
+                              if not seq.is_finished()]
+        # Sort the running sequences by their scores.
+        running_child_seqs.sort(key=lambda x: x[0].get_beam_search_score(
+            length_penalty=length_penalty,
+            eos_token_id=self.tokenizer.eos_token_id),
+                                reverse=True)
+
+        # Check if we can stop the beam search.
+        if len(running_child_seqs) == 0:
+            # No running sequences, stop the beam search.
+            stop_beam_search = True
+        elif len(all_finished_seqs) < beam_width:
+            # Not enough finished sequences, continue the beam search.
+            stop_beam_search = False
+        else:
+            # Check the early stopping criteria
+            best_running_seq = running_child_seqs[0][0]
+            current_worst_seq = all_finished_seqs[beam_width - 1][0]
+            stop_beam_search = self._check_beam_search_early_stopping(
+                seq_group.sampling_params.early_stopping,
+                seq_group.sampling_params, best_running_seq,
+                current_worst_seq)
+
+        if stop_beam_search:
+            # Stop the beam search and remove all the running sequences from
+            # the sequence group.
+            unselected_child_seqs.extend(running_child_seqs)
+        else:
+            # Continue the beam search and select the top beam_width sequences
+            # to continue the beam search.
+            selected_child_seqs.extend(running_child_seqs[:beam_width])
+            # The remaining running sequences will not be used in the next
+            # iteration. Again, if these sequences are continuations of
+            # parent sequences, we will need to remove the parent sequences
+            # from the sequence group.
+            unselected_child_seqs.extend(running_child_seqs[beam_width:])
+
+        # For newly created child sequences, add them to the sequence group
+        # and fork them in block manager if they are not finished.
+        for seq, parent in selected_child_seqs:
+            if seq is not parent:
+                seq_group.add(seq)
+                if not seq.is_finished():
+                    self.scheduler.fork_seq(parent, seq)
+
+        # Free the finished and selected parent sequences' memory in block
+        # manager. Keep them in the sequence group as candidate output.
+        for seq, parent in selected_child_seqs:
+            if seq is parent and seq.is_finished():
+                self.scheduler.free_seq(seq)
+
+        # Remove the unselected parent sequences from the sequence group and
+        # free their memory in block manager.
+        for seq, parent in unselected_child_seqs:
+            if seq is parent:
+                # Remove the parent sequence if it is not selected for next
+                # iteration
+                seq_group.remove(seq.seq_id)
+                self.scheduler.free_seq(seq)
+
+    def _process_model_outputs(
+            self, output: SamplerOutput,
             scheduler_outputs: SchedulerOutputs) -> List[RequestOutput]:
-        # Update the scheduler with the model outputs.
-        seq_groups = self.scheduler.update(output)
+        # Update the scheduled sequence groups with the model outputs.
+        scheduled_seq_groups = scheduler_outputs.scheduled_seq_groups
+        for seq_group, samples in zip(scheduled_seq_groups, output):
+            self._process_sequence_group_samples(seq_group, samples)

-        # Decode the sequences.
-        self._decode_sequences(seq_groups)
-        # Stop the sequences that meet the stopping criteria.
-        self._stop_sequences(seq_groups)
         # Free the finished sequence groups.
         self.scheduler.free_finished_seq_groups()

         # Create the outputs.
         request_outputs: List[RequestOutput] = []
-        for seq_group in seq_groups + scheduler_outputs.ignored_seq_groups:
+        for seq_group in (scheduled_seq_groups +
+                          scheduler_outputs.ignored_seq_groups):
             request_output = RequestOutput.from_seq_group(seq_group)
             request_outputs.append(request_output)
...
@@ -351,7 +557,7 @@ class LLMEngine:
             blocks_to_copy=scheduler_outputs.blocks_to_copy,
         )
-        return self._process_worker_outputs(output, scheduler_outputs)
+        return self._process_model_outputs(output, scheduler_outputs)

     def _log_system_stats(
         self,
...
@@ -416,55 +622,44 @@ class LLMEngine:
                     f"CPU KV cache usage: {cpu_cache_usage * 100:.1f}%")
         self.last_logging_time = now

-    def _decode_sequences(self, seq_groups: List[SequenceGroup]) -> None:
-        """Decodes the sequence outputs."""
-        for seq_group in seq_groups:
-            for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
-                new_token, new_output_text = detokenize_incrementally(
-                    self.tokenizer,
-                    seq.output_tokens,
-                    seq.get_last_token_id(),
-                    skip_special_tokens=True,
-                )
-                if new_token is not None:
-                    seq.output_tokens.append(new_token)
-                    seq.output_text = new_output_text
+    def _decode_sequence(self, seq: Sequence) -> None:
+        """Decodes the new token for a sequence."""
+        new_token, new_output_text = detokenize_incrementally(
+            self.tokenizer,
+            seq.output_tokens,
+            seq.get_last_token_id(),
+            skip_special_tokens=True,
+        )
+        if new_token is not None:
+            seq.output_tokens.append(new_token)
+            seq.output_text = new_output_text

-    def _stop_sequences(self, seq_groups: List[SequenceGroup]) -> None:
+    def _check_stop(self, seq: Sequence,
+                    sampling_params: SamplingParams) -> None:
         """Stop the finished sequences."""
-        for seq_group in seq_groups:
-            sampling_params = seq_group.sampling_params
-            for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
-                # Check if the sequence has generated a stop string.
-                stopped = False
-                for stop_str in sampling_params.stop:
-                    if seq.output_text.endswith(stop_str):
-                        # Truncate the output text so that the stop string is
-                        # not included in the output.
-                        seq.output_text = seq.output_text[:-len(stop_str)]
-                        self.scheduler.free_seq(
-                            seq, SequenceStatus.FINISHED_STOPPED)
-                        stopped = True
-                        break
-                if stopped:
-                    continue
-
-                # Check if the sequence has reached max_model_len.
-                if seq.get_len() > self.scheduler_config.max_model_len:
-                    self.scheduler.free_seq(
-                        seq, SequenceStatus.FINISHED_LENGTH_CAPPED)
-                    continue
-                # Check if the sequence has reached max_tokens.
-                if seq.get_output_len() == sampling_params.max_tokens:
-                    self.scheduler.free_seq(
-                        seq, SequenceStatus.FINISHED_LENGTH_CAPPED)
-                    continue
-                # Check if the sequence has generated the EOS token.
-                if not sampling_params.ignore_eos:
-                    if seq.get_last_token_id() == self.tokenizer.eos_token_id:
-                        self.scheduler.free_seq(
-                            seq, SequenceStatus.FINISHED_STOPPED)
-                        continue
+        for stop_str in sampling_params.stop:
+            if seq.output_text.endswith(stop_str):
+                # Truncate the output text so that the stop string is
+                # not included in the output.
+                seq.output_text = seq.output_text[:-len(stop_str)]
+                seq.status = SequenceStatus.FINISHED_STOPPED
+                return
+
+        # Check if the sequence has reached max_model_len.
+        if seq.get_len() > self.scheduler_config.max_model_len:
+            seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
+            return
+        # Check if the sequence has reached max_tokens.
+        if seq.get_output_len() == sampling_params.max_tokens:
+            seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
+            return
+        # Check if the sequence has generated the EOS token.
+        if ((not sampling_params.ignore_eos)
+                and seq.get_last_token_id() == self.tokenizer.eos_token_id):
+            seq.status = SequenceStatus.FINISHED_STOPPED
+            return

     def _run_workers(
         self,
...
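Note (illustration only, an assumption about Sequence.get_beam_search_score rather than code shown in this diff): the length-penalized score compared by _check_beam_search_early_stopping and the sort calls above is, in the spirit of HF's BeamHypotheses, the cumulative log-probability normalized by sequence length.

def beam_search_score(cumulative_logprob: float,
                      seq_len: int,
                      length_penalty: float = 1.0) -> float:
    # length_penalty > 0.0 favors longer sequences; < 0.0 favors shorter ones.
    return cumulative_logprob / (seq_len ** length_penalty)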
vllm/model_executor/layers/sampler.py  (view file @ 002800f0)
...
@@ -9,7 +9,7 @@ from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.parallel_utils.tensor_parallel import (
     gather_from_tensor_model_parallel_region)
 from vllm.sampling_params import SamplingParams
-from vllm.sequence import SequenceOutputs
+from vllm.sequence import SamplerOutput, SequenceOutputs

 _SAMPLING_EPS = 1e-5
...
@@ -39,7 +39,7 @@ class Sampler(nn.Module):
         hidden_states: torch.Tensor,
         input_metadata: InputMetadata,
         embedding_bias: Optional[torch.Tensor] = None,
-    ) -> Dict[int, SequenceOutputs]:
+    ) -> SamplerOutput:
         # Get the hidden states that we use for sampling.
         hidden_states = _prune_hidden_states(hidden_states, input_metadata)
...
@@ -292,7 +292,13 @@ def _sample_from_prompt(
     if sampling_params.use_beam_search:
         # Beam search.
         beam_width = sampling_params.best_of
-        _, next_token_ids = torch.topk(prob, beam_width)
+        # Sample 2 * beam_width candidates to make sure that with high
+        # probability we can get `beam_width` candidates in addition to
+        # the finished sequences for the next iteration. See
+        # https://github.com/tensorflow/tensor2tensor/blob/bafdc1b67730430d38d6ab802cbd51f9d053ba2e/tensor2tensor/utils/beam_search.py#L557-L563
+        # for details. See also HF reference:
+        # https://github.com/huggingface/transformers/blob/a4dd53d88e4852f023332d284ff07a01afcd5681/src/transformers/generation/utils.py#L3063-L3065
+        _, next_token_ids = torch.topk(prob, 2 * beam_width)
         next_token_ids = next_token_ids.tolist()
     elif sampling_params.temperature < _SAMPLING_EPS:
         # Greedy sampling.
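Note (illustration only, not part of the diff): why 2 * beam_width candidates suffice. In the flattened (beam_width x vocab_size) scores used for the generation phase below, at most one entry per parent beam is the EOS token, so at most beam_width of the selected candidates can finish a hypothesis and at least beam_width remain to continue the search; the single-distribution prompt case is even more forgiving, with at most one EOS candidate.

import torch

beam_width, vocab_size, eos_token_id = 4, 50, 2   # toy values for illustration
logprobs = torch.randn(beam_width, vocab_size)
_, topk_ids = torch.topk(logprobs.flatten(), 2 * beam_width)
token_ids = [i % vocab_size for i in topk_ids.tolist()]
non_eos = [t for t in token_ids if t != eos_token_id]
assert len(non_eos) >= beam_width   # always holds: at most one EOS per beam row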
...
@@ -330,29 +336,11 @@ def _sample_from_generation_tokens(
         vocab_size = logprobs.size(-1)
         beam_width = len(seq_ids)
-        _, topk_ids = torch.topk(logprobs.flatten(), beam_width)
+        _, topk_ids = torch.topk(logprobs.flatten(), 2 * beam_width)
         topk_ids = topk_ids.tolist()
         seq_idx = [i // vocab_size for i in topk_ids]
-        beam_seq_ids = [seq_ids[i] for i in seq_idx]
-        token_ids = [i % vocab_size for i in topk_ids]
-
-        beam_outputs: Dict[int, Tuple[int, int]] = {}
-        outstanding_beams: List[Tuple[int, int]] = []
-        # If a beam survives, continue with it.
-        for seq_id, token_id in zip(beam_seq_ids, token_ids):
-            if seq_id not in beam_outputs:
-                beam_outputs[seq_id] = (seq_id, token_id)
-            else:
-                outstanding_beams.append((seq_id, token_id))
-
-        # If a beam is discarded, fork another beam.
-        for seq_id in seq_ids:
-            if seq_id not in beam_outputs:
-                beam_outputs[seq_id] = outstanding_beams.pop()
-        assert not outstanding_beams
-
-        parent_seq_ids = [beam_outputs[seq_id][0] for seq_id in seq_ids]
-        next_token_ids = [beam_outputs[seq_id][1] for seq_id in seq_ids]
+        parent_seq_ids = [seq_ids[i] for i in seq_idx]
+        next_token_ids = [i % vocab_size for i in topk_ids]
     elif sampling_params.temperature < _SAMPLING_EPS:
         # Greedy sampling.
         assert len(seq_ids) == 1
...
@@ -374,16 +362,18 @@ def _sample(
     probs: torch.Tensor,
     logprobs: torch.Tensor,
     input_metadata: InputMetadata,
-) -> Dict[int, SequenceOutputs]:
-    seq_outputs: Dict[int, SequenceOutputs] = {}
+) -> SamplerOutput:
+    seq_outputs: SamplerOutput = []

     # TODO(woosuk): Optimize.
     idx = 0
     for i, seq_group in enumerate(input_metadata.seq_groups):
+        seq_group_outputs: List[SequenceOutputs] = []
         seq_ids, sampling_params = seq_group
         if i < input_metadata.num_prompts:
             # Generate the next tokens for a prompt input.
-            assert len(seq_ids) == sampling_params.best_of
+            assert len(seq_ids) == 1, "Prompt input should have only one seq."
+            parent_seq_id = seq_ids[0]
             prob = probs[idx]
             logprob = logprobs[idx]
             idx += 1
...
@@ -395,17 +385,18 @@ def _sample(
                 sampling_params.logprobs)

             # Build the output.
-            for seq_id, next_token_id in zip(seq_ids, next_token_ids):
+            for next_token_id in next_token_ids:
                 output_logprobs = next_logprobs.copy()
                 output_logprobs[next_token_id] = logprob[next_token_id].item()
-                seq_outputs[seq_id] = SequenceOutputs(seq_id, seq_id,
-                                                      next_token_id,
-                                                      output_logprobs)
+                seq_group_outputs.append(
+                    SequenceOutputs(parent_seq_id, next_token_id,
+                                    output_logprobs))
         else:
             # Generate the next tokens for generation tokens.
-            prob = probs[idx:idx + len(seq_ids)]
-            logprob = logprobs[idx:idx + len(seq_ids)]
-            idx += len(seq_ids)
+            num_parent_seqs = len(seq_ids)
+            prob = probs[idx:idx + num_parent_seqs]
+            logprob = logprobs[idx:idx + num_parent_seqs]
+            idx += num_parent_seqs

             # Sample the next tokens.
             seq_logprobs = [
...
@@ -422,17 +413,15 @@ def _sample(
                     logprob[j], sampling_params.logprobs)

             # Build the output.
-            for seq_id, parent_seq_id, next_token_id in zip(
-                    seq_ids, parent_seq_ids, next_token_ids):
+            for parent_seq_id, next_token_id in zip(parent_seq_ids,
+                                                    next_token_ids):
                 j = seq_ids.index(parent_seq_id)
                 output_logprobs = next_logprobs[parent_seq_id].copy()
                 output_logprobs[next_token_id] = logprob[j,
                                                          next_token_id].item()
-                seq_outputs[seq_id] = SequenceOutputs(seq_id, parent_seq_id,
-                                                      next_token_id,
-                                                      output_logprobs)
+                seq_group_outputs.append(
+                    SequenceOutputs(parent_seq_id, next_token_id,
+                                    output_logprobs))
+        seq_outputs.append(seq_group_outputs)

     return seq_outputs
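Note (illustration only, an assumption about the SamplerOutput alias referenced throughout this commit; the alias itself is not shown in the diff): the sampler now returns one list of SequenceOutputs per scheduled sequence group, ordered to match scheduler_outputs.scheduled_seq_groups, which is what lets LLMEngine._process_model_outputs zip the two together.

from typing import List
from vllm.sequence import SequenceOutputs

SamplerOutput = List[List[SequenceOutputs]]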
vllm/model_executor/models/aquila.py  (view file @ 002800f0)
...
@@ -25,7 +25,7 @@
 The input of the model is flattened to a 1D tensor of tokens. The model uses
 InputMetadata to extract the original 2D shape of the input.
 """
-from typing import Dict, List, Optional, Tuple
+from typing import List, Optional, Tuple

 import torch
 from torch import nn
...
@@ -41,7 +41,7 @@ from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
 from vllm.model_executor.parallel_utils.tensor_parallel import (
     VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
-from vllm.sequence import SequenceOutputs
+from vllm.sequence import SamplerOutput
 from vllm.transformers_utils.configs.aquila import AquilaConfig

 KVCache = Tuple[torch.Tensor, torch.Tensor]
...
@@ -273,7 +273,7 @@ class AquilaForCausalLM(nn.Module):
         kv_caches: List[KVCache],
         input_metadata: InputMetadata,
         cache_events: Optional[List[torch.cuda.Event]],
-    ) -> Dict[int, SequenceOutputs]:
+    ) -> SamplerOutput:
         hidden_states = self.model(input_ids, positions, kv_caches,
                                    input_metadata, cache_events)
         next_tokens = self.sampler(self.lm_head.weight, hidden_states,
...
vllm/model_executor/models/baichuan.py  (view file @ 002800f0)
...
@@ -23,12 +23,11 @@ The input of the model is flattened to a 1D tensor of tokens. The model uses
 InputMetadata to extract the original 2D shape of the input.
 """
 import math
-from typing import Dict, List, Optional, Tuple
+from typing import List, Optional, Tuple

 import torch
 from torch import nn

-from vllm.sequence import SequenceOutputs
 from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
...
@@ -42,6 +41,7 @@ from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
 from vllm.model_executor.parallel_utils.tensor_parallel import (
     VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
+from vllm.sequence import SamplerOutput
 from vllm.transformers_utils.configs.baichuan import BaiChuanConfig

 KVCache = Tuple[torch.Tensor, torch.Tensor]
...
@@ -290,7 +290,7 @@ class BaiChuanBaseForCausalLM(nn.Module):
         kv_caches: List[KVCache],
         input_metadata: InputMetadata,
         cache_events: Optional[List[torch.cuda.Event]],
-    ) -> Dict[int, SequenceOutputs]:
+    ) -> SamplerOutput:
         hidden_states = self.model(input_ids, positions, kv_caches,
                                    input_metadata, cache_events)
         next_tokens = self.sampler(self.lm_head.weight, hidden_states,
...
vllm/model_executor/models/bloom.py  (view file @ 002800f0)
...
@@ -21,7 +21,7 @@ The input of the model is flattened to a 1D tensor of tokens. The model uses
 InputMetadata to extract the original 2D shape of the input.
 """
 import math
-from typing import Dict, List, Optional, Tuple
+from typing import List, Optional, Tuple

 import torch
 from torch import nn
...
@@ -37,7 +37,7 @@ from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
 from vllm.model_executor.parallel_utils.tensor_parallel import (
     VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
-from vllm.sequence import SequenceOutputs
+from vllm.sequence import SamplerOutput

 KVCache = Tuple[torch.Tensor, torch.Tensor]
...
@@ -264,7 +264,7 @@ class BloomForCausalLM(nn.Module):
         kv_caches: List[KVCache],
         input_metadata: InputMetadata,
         cache_events: Optional[List[torch.cuda.Event]],
-    ) -> Dict[int, SequenceOutputs]:
+    ) -> SamplerOutput:
         hidden_states = self.transformer(input_ids, positions, kv_caches,
                                          input_metadata, cache_events)
         next_tokens = self.sampler(self.lm_head_weight, hidden_states,
...
vllm/model_executor/models/falcon.py  (view file @ 002800f0)
...
@@ -19,7 +19,7 @@
 """PyTorch Falcon model."""
 import math
-from typing import Dict, List, Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union

 import torch
 from torch import nn
...
@@ -38,7 +38,7 @@ from vllm.model_executor.parallel_utils.parallel_state import (
 from vllm.model_executor.parallel_utils.tensor_parallel import (
     VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear,
     reduce_from_tensor_model_parallel_region)
-from vllm.sequence import SequenceOutputs
+from vllm.sequence import SamplerOutput
 from vllm.transformers_utils.configs import RWConfig

 KVCache = Tuple[torch.Tensor, torch.Tensor]
...
@@ -397,7 +397,7 @@ class FalconForCausalLM(nn.Module):
         kv_caches: List[KVCache],
         input_metadata: InputMetadata,
         cache_events: Optional[List[torch.cuda.Event]],
-    ) -> Dict[int, SequenceOutputs]:
+    ) -> SamplerOutput:
         hidden_states = self.transformer(
             input_ids,
             positions,
...
vllm/model_executor/models/gpt2.py  (view file @ 002800f0)
...
@@ -21,7 +21,7 @@
 The input of the model is flattened to a 1D tensor of tokens. The model uses
 InputMetadata to extract the original 2D shape of the input.
 """
-from typing import Dict, List, Optional, Tuple
+from typing import List, Optional, Tuple

 import torch
 from torch import nn
...
@@ -38,7 +38,7 @@ from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
 from vllm.model_executor.parallel_utils.tensor_parallel import (
     VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
-from vllm.sequence import SequenceOutputs
+from vllm.sequence import SamplerOutput

 KVCache = Tuple[torch.Tensor, torch.Tensor]
...
@@ -218,7 +218,7 @@ class GPT2LMHeadModel(nn.Module):
         kv_caches: List[KVCache],
         input_metadata: InputMetadata,
         cache_events: Optional[List[torch.cuda.Event]],
-    ) -> Dict[int, SequenceOutputs]:
+    ) -> SamplerOutput:
         hidden_states = self.transformer(input_ids, positions, kv_caches,
                                          input_metadata, cache_events)
         next_tokens = self.sampler(self.lm_head_weight, hidden_states,
...
vllm/model_executor/models/gpt_bigcode.py  (view file @ 002800f0)
...
@@ -22,7 +22,7 @@
 The input of the model is flattened to a 1D tensor of tokens. The model uses
 InputMetadata to extract the original 2D shape of the input.
 """
-from typing import Dict, List, Optional, Tuple
+from typing import List, Optional, Tuple

 import torch
 from torch import nn
...
@@ -39,7 +39,7 @@ from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
 from vllm.model_executor.parallel_utils.tensor_parallel import (
     VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
-from vllm.sequence import SequenceOutputs
+from vllm.sequence import SamplerOutput

 KVCache = Tuple[torch.Tensor, torch.Tensor]
...
@@ -246,7 +246,7 @@ class GPTBigCodeForCausalLM(nn.Module):
         kv_caches: List[KVCache],
         input_metadata: InputMetadata,
         cache_events: Optional[List[torch.cuda.Event]],
-    ) -> Dict[int, SequenceOutputs]:
+    ) -> SamplerOutput:
         hidden_states = self.transformer(input_ids, positions, kv_caches,
                                          input_metadata, cache_events)
         next_tokens = self.sampler(self.lm_head_weight, hidden_states,
...
vllm/model_executor/models/gpt_j.py  (view file @ 002800f0)
...
@@ -20,7 +20,7 @@
 The input of the model is flattened to a 1D tensor of tokens. The model uses
 InputMetadata to extract the original 2D shape of the input.
 """
-from typing import Dict, List, Optional, Tuple
+from typing import List, Optional, Tuple

 import torch
 from torch import nn
...
@@ -36,7 +36,7 @@ from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
 from vllm.model_executor.parallel_utils.tensor_parallel import (
     VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
-from vllm.sequence import SequenceOutputs
+from vllm.sequence import SamplerOutput

 KVCache = Tuple[torch.Tensor, torch.Tensor]
...
@@ -203,7 +203,7 @@ class GPTJForCausalLM(nn.Module):
         kv_caches: List[KVCache],
         input_metadata: InputMetadata,
         cache_events: Optional[List[torch.cuda.Event]],
-    ) -> Dict[int, SequenceOutputs]:
+    ) -> SamplerOutput:
         hidden_states = self.transformer(input_ids, positions, kv_caches,
                                          input_metadata, cache_events)
         next_tokens = self.sampler(self.lm_head.weight, hidden_states,
...
vllm/model_executor/models/gpt_neox.py View file @ 002800f0
...
@@ -20,7 +20,7 @@
The input of the model is flattened to a 1D tensor of tokens. The model uses
InputMetadata to extract the original 2D shape of the input.
"""
-from typing import Dict, List, Optional, Tuple
+from typing import List, Optional, Tuple
import torch
from torch import nn
...
@@ -36,7 +36,7 @@ from vllm.model_executor.parallel_utils.parallel_state import (
    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.tensor_parallel import (
    VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
-from vllm.sequence import SequenceOutputs
+from vllm.sequence import SamplerOutput
KVCache = Tuple[torch.Tensor, torch.Tensor]
...
@@ -215,7 +215,7 @@ class GPTNeoXForCausalLM(nn.Module):
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
-    ) -> Dict[int, SequenceOutputs]:
+    ) -> SamplerOutput:
        hidden_states = self.gpt_neox(input_ids, positions, kv_caches,
                                      input_metadata, cache_events)
        next_tokens = self.sampler(self.embed_out.weight, hidden_states,
...
vllm/model_executor/models/internlm.py View file @ 002800f0
# -*- coding: utf-8 -*-
-from typing import Dict, List, Optional, Tuple
+from typing import List, Optional, Tuple
import torch
from torch import nn
...
@@ -17,7 +17,7 @@ from vllm.model_executor.parallel_utils.tensor_parallel import (
from vllm.model_executor.weight_utils import (
    hf_model_weights_iterator, load_padded_tensor_parallel_vocab,
    load_tensor_parallel_weights)
-from vllm.sequence import SequenceOutputs
+from vllm.sequence import SamplerOutput
KVCache = Tuple[torch.Tensor, torch.Tensor]
...
@@ -218,7 +218,7 @@ class InternLMForCausalLM(nn.Module):
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
-    ) -> Dict[int, SequenceOutputs]:
+    ) -> SamplerOutput:
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   input_metadata, cache_events)
        next_tokens = self.sampler(self.lm_head.weight, hidden_states,
...
vllm/model_executor/models/llama.py View file @ 002800f0
...
@@ -25,7 +25,7 @@
The input of the model is flattened to a 1D tensor of tokens. The model uses
InputMetadata to extract the original 2D shape of the input.
"""
-from typing import Dict, List, Optional, Tuple
+from typing import List, Optional, Tuple
import torch
from torch import nn
...
@@ -43,7 +43,7 @@ from vllm.model_executor.parallel_utils.parallel_state import (
    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.tensor_parallel import (
    VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
-from vllm.sequence import SequenceOutputs
+from vllm.sequence import SamplerOutput
KVCache = Tuple[torch.Tensor, torch.Tensor]
...
@@ -256,7 +256,7 @@ class LlamaForCausalLM(nn.Module):
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
-    ) -> Dict[int, SequenceOutputs]:
+    ) -> SamplerOutput:
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   input_metadata, cache_events)
        next_tokens = self.sampler(self.lm_head.weight, hidden_states,
...
vllm/model_executor/models/mpt.py View file @ 002800f0
# coding=utf-8
# Adapted from https://huggingface.co/mosaicml/mpt-7b/tree/main
import math
-from typing import Dict, List, Optional, Tuple
+from typing import List, Optional, Tuple
import torch
import torch.nn as nn
...
@@ -16,7 +16,7 @@ from vllm.model_executor.parallel_utils.parallel_state import (
    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.tensor_parallel import (
    VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
-from vllm.sequence import SequenceOutputs
+from vllm.sequence import SamplerOutput
from vllm.transformers_utils.configs.mpt import MPTConfig
KVCache = Tuple[torch.Tensor, torch.Tensor]
...
@@ -230,7 +230,7 @@ class MPTForCausalLM(nn.Module):
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
-    ) -> Dict[int, SequenceOutputs]:
+    ) -> SamplerOutput:
        hidden_states = self.transformer(input_ids, positions, kv_caches,
                                         input_metadata, cache_events)
        next_tokens = self.sampler(self.lm_head_weight, hidden_states,
...
vllm/model_executor/models/opt.py View file @ 002800f0
...
@@ -21,7 +21,7 @@
The input of the model is flattened to a 1D tensor of tokens. The model uses
InputMetadata to extract the original 2D shape of the input.
"""
-from typing import Dict, List, Optional, Tuple
+from typing import List, Optional, Tuple
import torch
from torch import nn
...
@@ -37,7 +37,7 @@ from vllm.model_executor.parallel_utils.parallel_state import (
    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.tensor_parallel import (
    VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
-from vllm.sequence import SequenceOutputs
+from vllm.sequence import SamplerOutput
KVCache = Tuple[torch.Tensor, torch.Tensor]
...
@@ -282,7 +282,7 @@ class OPTForCausalLM(nn.Module):
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
-    ) -> Dict[int, SequenceOutputs]:
+    ) -> SamplerOutput:
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   input_metadata, cache_events)
        next_tokens = self.sampler(self.lm_head_weight, hidden_states,
...
vllm/model_executor/models/qwen.py View file @ 002800f0
...
@@ -8,7 +8,7 @@
The input of the model is flattened to a 1D tensor of tokens. The model uses
InputMetadata to extract the original 2D shape of the input.
"""
-from typing import Dict, List, Optional, Tuple
+from typing import List, Optional, Tuple
import torch
from torch import nn
...
@@ -32,7 +32,7 @@ from vllm.model_executor.parallel_utils.tensor_parallel import (
    ColumnParallelLinear,
    RowParallelLinear,
)
-from vllm.sequence import SequenceOutputs
+from vllm.sequence import SamplerOutput
from vllm.transformers_utils.configs.qwen import QWenConfig
KVCache = Tuple[torch.Tensor, torch.Tensor]
...
@@ -235,7 +235,7 @@ class QWenLMHeadModel(nn.Module):
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
-    ) -> Dict[int, SequenceOutputs]:
+    ) -> SamplerOutput:
        hidden_states = self.transformer(input_ids, positions, kv_caches,
                                         input_metadata, cache_events)
        next_tokens = self.sampler(self.lm_head.weight, hidden_states,
...
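For completeness, a quick exercise of the hypothetical stubs sketched after the gpt2.py hunk above; it only checks that the toy forward() hands back one batched object rather than a dict keyed by sequence-group id, and touches no real vLLM code.

# Usage of the toy sketch above (hypothetical shapes, no CUDA needed).
model = ToyForCausalLM()
out = model(
    input_ids=torch.tensor([1, 2, 3]),
    positions=torch.tensor([0, 1, 2]),
    kv_caches=[],
    input_metadata=None,
    cache_events=None,
)
assert isinstance(out, SamplerOutput)  # the list stand-in, not Dict[int, ...]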