Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
0455c46e
Unverified
Commit
0455c46e
authored
Sep 21, 2024
by
Cyrus Leung
Committed by
GitHub
Sep 21, 2024
Browse files
[Core] Factor out common code in `SequenceData` and `Sequence` (#8675)
parent
d4bf085a
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
64 additions
and
97 deletions
+64
-97
tests/samplers/test_sampler.py
tests/samplers/test_sampler.py
+7
-20
tests/spec_decode/utils.py
tests/spec_decode/utils.py
+3
-9
tests/test_logits_processor.py
tests/test_logits_processor.py
+2
-6
tests/test_sequence.py
tests/test_sequence.py
+2
-5
tests/worker/test_encoder_decoder_model_runner.py
tests/worker/test_encoder_decoder_model_runner.py
+7
-15
tests/worker/test_model_runner.py
tests/worker/test_model_runner.py
+5
-11
vllm/inputs/registry.py
vllm/inputs/registry.py
+1
-7
vllm/sequence.py
vllm/sequence.py
+37
-24
No files found.
tests/samplers/test_sampler.py
View file @
0455c46e
import
itertools
import
random
from
array
import
array
from
typing
import
Dict
,
List
,
Optional
,
Tuple
from
unittest.mock
import
Mock
,
patch
...
...
@@ -12,8 +11,7 @@ import vllm.envs as envs
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.utils
import
set_random_seed
from
vllm.sequence
import
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
SamplingParams
,
SequenceData
,
SequenceGroupMetadata
)
from
vllm.sequence
import
SamplingParams
,
SequenceData
,
SequenceGroupMetadata
from
vllm.utils
import
Counter
,
is_pin_memory_available
...
...
@@ -59,9 +57,7 @@ def _do_sample(
SequenceGroupMetadata
(
request_id
=
f
"test_
{
i
}
"
,
is_prompt
=
True
,
seq_data
=
{
0
:
SequenceData
(
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
[
1
,
2
,
3
]))
},
seq_data
=
{
0
:
SequenceData
.
from_seqs
([
1
,
2
,
3
])},
sampling_params
=
sampling_params
,
block_tables
=
{
0
:
[
1
]},
))
...
...
@@ -205,9 +201,8 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
return
sampling_params
def
create_sequence_data
(
num_input
=
3
,
num_generated
=
0
):
seq_data
=
SequenceData
(
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
random
.
choices
(
range
(
0
,
VOCAB_SIZE
),
k
=
num_input
)))
seq_data
=
SequenceData
.
from_seqs
(
random
.
choices
(
range
(
0
,
VOCAB_SIZE
),
k
=
num_input
))
if
num_generated
>
0
:
seq_data
.
output_token_ids
=
random
.
choices
(
range
(
0
,
VOCAB_SIZE
),
k
=
num_generated
)
...
...
@@ -511,9 +506,7 @@ def test_sampler_mixed(seed: int, device: str):
SequenceGroupMetadata
(
request_id
=
f
"test_
{
i
}
"
,
is_prompt
=
True
,
seq_data
=
{
0
:
SequenceData
(
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
[
1
,
2
,
3
]))
},
seq_data
=
{
0
:
SequenceData
.
from_seqs
([
1
,
2
,
3
])},
sampling_params
=
sampling_params
,
block_tables
=
{
0
:
[
1
]},
))
...
...
@@ -613,9 +606,7 @@ def test_sampler_top_k_top_p(seed: int, device: str):
SequenceGroupMetadata
(
request_id
=
f
"test_
{
i
}
"
,
is_prompt
=
True
,
seq_data
=
{
0
:
SequenceData
(
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
[
1
,
2
,
3
]))
},
seq_data
=
{
0
:
SequenceData
.
from_seqs
([
1
,
2
,
3
])},
sampling_params
=
SamplingParams
(
temperature
=
1
,
top_k
=
top_k
,
...
...
@@ -699,11 +690,7 @@ def test_sampler_repetition_penalty_mixed(device: str):
SequenceGroupMetadata
(
request_id
=
f
"test_
{
i
}
"
,
is_prompt
=
True
,
seq_data
=
{
0
:
SequenceData
(
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
[
1
,
2
,
3
]))
},
seq_data
=
{
0
:
SequenceData
.
from_seqs
([
1
,
2
,
3
])},
sampling_params
=
sampling_params
[
i
],
block_tables
=
{
0
:
[
1
]},
))
...
...
tests/spec_decode/utils.py
View file @
0455c46e
from
array
import
array
from
itertools
import
count
from
typing
import
Callable
,
Dict
,
List
,
Optional
from
typing
import
Sequence
as
GenericSequence
...
...
@@ -11,8 +10,7 @@ from vllm.engine.arg_utils import EngineArgs
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.model_executor.utils
import
set_random_seed
from
vllm.sampling_params
import
SamplingParams
from
vllm.sequence
import
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
CompletionSequenceGroupOutput
,
Logprob
,
from
vllm.sequence
import
(
CompletionSequenceGroupOutput
,
Logprob
,
SequenceData
,
SequenceGroupMetadata
,
SequenceOutput
)
from
vllm.utils
import
get_distributed_init_method
,
get_ip
,
get_open_port
from
vllm.worker.cache_engine
import
CacheEngine
...
...
@@ -138,12 +136,8 @@ def create_seq_group_metadata_from_prompts(
request_id
=
str
(
i
),
is_prompt
=
len
(
cont_token_ids
)
==
0
,
seq_data
=
{
i
:
SequenceData
(
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
prompt_token_ids
[:]),
_output_token_ids
=
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
cont_token_ids
[:]),
),
i
:
SequenceData
.
from_seqs
(
prompt_token_ids
[:],
cont_token_ids
[:]),
},
sampling_params
=
SamplingParams
(
temperature
=
0.0
,
),
block_tables
=
{
i
:
block_allocations
[
i
][:]},
...
...
tests/test_logits_processor.py
View file @
0455c46e
import
random
from
array
import
array
from
typing
import
Tuple
from
unittest.mock
import
patch
...
...
@@ -9,8 +8,7 @@ import torch
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.utils
import
set_random_seed
from
vllm.sequence
import
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
SamplingParams
,
SequenceData
,
SequenceGroupMetadata
)
from
vllm.sequence
import
SamplingParams
,
SequenceData
,
SequenceGroupMetadata
from
vllm.utils
import
is_pin_memory_available
...
...
@@ -71,9 +69,7 @@ def test_logits_processors(seed: int, device: str):
SequenceGroupMetadata
(
request_id
=
f
"test_
{
i
}
"
,
is_prompt
=
True
,
seq_data
=
{
0
:
SequenceData
(
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
[
1
,
2
,
3
]))
},
seq_data
=
{
0
:
SequenceData
.
from_seqs
([
1
,
2
,
3
])},
sampling_params
=
SamplingParams
(
temperature
=
0
,
logits_processors
=
[
pick_ith
]),
block_tables
=
{
0
:
[
1
]},
...
...
tests/test_sequence.py
View file @
0455c46e
from
array
import
array
import
pytest
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.sequence
import
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
CompletionSequenceGroupOutput
,
SequenceData
,
from
vllm.sequence
import
(
CompletionSequenceGroupOutput
,
SequenceData
,
SequenceOutput
)
from
.core.utils
import
create_dummy_prompt
...
...
@@ -58,7 +55,7 @@ def test_sampler_output_eq(sample_outputs):
def
test_sequence_data_prefill
():
seq_data
=
SequenceData
(
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
[
1
,
2
,
3
,
4
])
)
seq_data
=
SequenceData
.
from_seqs
(
[
1
,
2
,
3
,
4
])
assert
seq_data
.
get_num_uncomputed_tokens
()
==
4
assert
seq_data
.
get_num_computed_tokens
()
==
0
# advance by 2
...
...
tests/worker/test_encoder_decoder_model_runner.py
View file @
0455c46e
import
itertools
from
array
import
array
from
typing
import
List
import
pytest
import
torch
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.sequence
import
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
SamplingParams
,
SequenceData
,
SequenceGroupMetadata
)
from
vllm.sequence
import
SamplingParams
,
SequenceData
,
SequenceGroupMetadata
from
vllm.utils
import
is_cpu
,
make_tensor_with_pad
from
vllm.worker.enc_dec_model_runner
import
EncoderDecoderModelRunner
from
vllm.worker.model_runner
import
_get_graph_batch_size
...
...
@@ -119,12 +117,10 @@ def test_prepare_prompt(batch_size):
# make sure all tokens fit into one block
seq_len
=
i
%
(
model_runner
.
block_size
-
1
)
+
1
seq_lens
.
append
(
seq_len
)
seq_data
=
SequenceData
(
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
range
(
seq_len
)))
seq_data
=
SequenceData
.
from_seqs
(
range
(
seq_len
))
encoder_seq_len
=
(
i
+
1
)
%
(
model_runner
.
block_size
-
1
)
+
1
encoder_seq_lens
.
append
(
encoder_seq_len
)
encoder_seq_data
=
SequenceData
(
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
range
(
encoder_seq_len
)))
encoder_seq_data
=
SequenceData
.
from_seqs
(
range
(
encoder_seq_len
))
seq_group_metadata
=
SequenceGroupMetadata
(
request_id
=
f
"test_
{
i
}
"
,
is_prompt
=
True
,
...
...
@@ -317,11 +313,9 @@ def test_prepare_decode(batch_size, multiple_seqs_per_seq_group):
for
i
in
range
(
batch_size
):
# make sure all tokens fit into one block
seq_len
=
i
%
(
model_runner
.
block_size
-
1
)
+
1
seq_data
=
SequenceData
(
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
(
range
(
seq_len
))))
seq_data
=
SequenceData
.
from_seqs
(
range
(
seq_len
))
encoder_seq_len
=
(
i
+
1
)
%
(
model_runner
.
block_size
-
1
)
+
1
encoder_seq_data
=
SequenceData
(
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
(
range
(
encoder_seq_len
))))
encoder_seq_data
=
SequenceData
.
from_seqs
(
range
(
encoder_seq_len
))
seq_group_metadata
=
SequenceGroupMetadata
(
request_id
=
f
"test_
{
i
}
"
,
...
...
@@ -523,11 +517,9 @@ def test_prepare_decode_cuda_graph(batch_size, multiple_seqs_per_seq_group):
for
i
in
range
(
batch_size
):
# make sure all tokens fit into one block
seq_len
=
i
%
(
model_runner
.
block_size
-
1
)
+
1
seq_data
=
SequenceData
(
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
(
range
(
seq_len
))))
seq_data
=
SequenceData
.
from_seqs
(
range
(
seq_len
))
encoder_seq_len
=
(
i
+
1
)
%
(
model_runner
.
block_size
-
1
)
+
1
encoder_seq_data
=
SequenceData
(
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
(
range
(
encoder_seq_len
))))
encoder_seq_data
=
SequenceData
.
from_seqs
(
range
(
encoder_seq_len
))
seq_group_metadata
=
SequenceGroupMetadata
(
request_id
=
f
"test_
{
i
}
"
,
is_prompt
=
False
,
...
...
tests/worker/test_model_runner.py
View file @
0455c46e
from
array
import
array
from
typing
import
List
import
pytest
...
...
@@ -8,8 +7,7 @@ from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
init_distributed_environment
)
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
SamplingParams
,
SequenceData
,
SequenceGroupMetadata
)
from
vllm.sequence
import
SamplingParams
,
SequenceData
,
SequenceGroupMetadata
from
vllm.utils
import
get_open_port
from
vllm.worker.model_runner
import
ModelRunner
,
_get_graph_batch_size
...
...
@@ -48,8 +46,7 @@ def test_prepare_prompt(batch_size):
# make sure all tokens fit into one block
seq_len
=
i
%
(
model_runner
.
block_size
-
1
)
+
1
seq_lens
.
append
(
seq_len
)
seq_data
=
SequenceData
(
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
range
(
seq_len
)))
seq_data
=
SequenceData
.
from_seqs
(
range
(
seq_len
))
seq_group_metadata
=
SequenceGroupMetadata
(
request_id
=
f
"test_
{
i
}
"
,
is_prompt
=
True
,
...
...
@@ -166,8 +163,7 @@ def test_prepare_decode_cuda_graph(batch_size):
# make sure all tokens fit into one block
context_len
=
i
%
(
model_runner
.
block_size
-
1
)
+
1
context_lens
.
append
(
context_len
)
seq_data
=
SequenceData
(
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
range
(
context_len
)))
seq_data
=
SequenceData
.
from_seqs
(
range
(
context_len
))
seq_data
.
update_num_computed_tokens
(
context_len
)
# Append one token ID since prefill is finished.
seq_data
.
append_token_id
(
1
,
0
)
...
...
@@ -326,8 +322,7 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
# make sure all tokens fit into one block
seq_len
=
i
%
(
model_runner
.
block_size
-
1
)
+
1
seq_lens
.
append
(
seq_len
)
seq_data
=
SequenceData
(
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
range
(
seq_len
)))
seq_data
=
SequenceData
.
from_seqs
(
range
(
seq_len
))
seq_group_metadata
=
SequenceGroupMetadata
(
request_id
=
f
"test_
{
i
}
"
,
is_prompt
=
True
,
...
...
@@ -343,8 +338,7 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
for
i
in
range
(
prefill_batch_size
,
batch_size
):
# make sure all tokens fit into one block
context_len
=
i
%
(
model_runner
.
block_size
-
1
)
+
1
prompt_toks
=
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
range
(
context_len
))
seq_data
=
SequenceData
(
prompt_toks
)
seq_data
=
SequenceData
.
from_seqs
(
range
(
context_len
))
seq_data
.
append_token_id
(
1
,
0
)
seq_data
.
update_num_computed_tokens
(
context_len
)
seq_group_metadata
=
SequenceGroupMetadata
(
...
...
vllm/inputs/registry.py
View file @
0455c46e
import
functools
from
array
import
array
from
collections
import
UserDict
from
dataclasses
import
dataclass
from
typing
import
(
TYPE_CHECKING
,
Any
,
Callable
,
Dict
,
Mapping
,
Optional
,
...
...
@@ -22,10 +21,6 @@ logger = init_logger(__name__)
C
=
TypeVar
(
"C"
,
bound
=
PretrainedConfig
,
default
=
PretrainedConfig
)
# NOTE: This has to match with sequence.py's VLLM_TOKEN_ID_ARRAY_TYPE.
# We cannot import it here because of circular dependencies.
VLLM_TOKEN_ID_ARRAY_TYPE
=
"l"
@
dataclass
(
frozen
=
True
)
class
InputContext
:
...
...
@@ -130,8 +125,7 @@ class InputRegistry:
# Avoid circular import
from
vllm.sequence
import
SequenceData
dummy_seq_data
=
SequenceData
(
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
[
0
])
*
seq_len
)
dummy_seq_data
=
SequenceData
.
from_counts
({
0
:
seq_len
})
dummy_multi_modal_data
=
None
return
dummy_seq_data
,
dummy_multi_modal_data
...
...
vllm/sequence.py
View file @
0455c46e
...
...
@@ -5,6 +5,7 @@ from abc import ABC, abstractmethod
from
array
import
array
from
collections
import
defaultdict
from
dataclasses
import
dataclass
from
functools
import
cached_property
,
reduce
from
typing
import
TYPE_CHECKING
,
Any
,
Callable
,
Dict
,
List
,
Mapping
,
Optional
from
typing
import
Sequence
as
GenericSequence
from
typing
import
Set
,
Tuple
,
Union
,
cast
...
...
@@ -169,6 +170,35 @@ class SequenceData(msgspec.Struct,
# It is used to compute mrope_position_ids.
_mrope_position_delta
:
Optional
[
int
]
=
None
@
staticmethod
def
from_counts
(
counts_by_token
:
Mapping
[
int
,
int
])
->
"SequenceData"
:
if
len
(
counts_by_token
)
==
0
:
return
SequenceData
.
from_seqs
([])
arrs
=
[
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
[
token_id
])
*
count
for
token_id
,
count
in
counts_by_token
.
items
()
]
return
SequenceData
(
reduce
(
array
.
__add__
,
arrs
))
@
staticmethod
def
from_seqs
(
prompt_token_ids
:
GenericSequence
[
int
],
output_token_ids
:
Optional
[
GenericSequence
[
int
]]
=
None
,
)
->
"SequenceData"
:
prompt_token_ids_arr
=
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
prompt_token_ids
)
if
output_token_ids
is
None
:
return
SequenceData
(
prompt_token_ids_arr
)
output_token_ids_arr
=
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
output_token_ids
)
return
SequenceData
(
prompt_token_ids_arr
,
_output_token_ids
=
output_token_ids_arr
)
def
__post_init__
(
self
)
->
None
:
assert
self
.
_prompt_token_ids
.
typecode
==
"l"
assert
self
.
_output_token_ids
.
typecode
==
"l"
...
...
@@ -370,8 +400,6 @@ class Sequence:
self
.
lora_request
=
lora_request
self
.
prompt_adapter_request
=
prompt_adapter_request
self
.
from_decoder_prompt
=
from_decoder_prompt
self
.
_prompt
:
Optional
[
str
]
=
None
self
.
_prompt_token_ids
:
Optional
[
List
[
int
]]
=
None
# For decoder-only models, a Sequence is constructed
# from an LLMInputs instance (the `inputs` arg.)
...
...
@@ -400,8 +428,7 @@ class Sequence:
f
"invalid input
{
inputs
}
; did you forget the "
"encoder input prompt fields?"
)
self
.
data
=
SequenceData
(
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
self
.
prompt_token_ids
))
self
.
data
=
SequenceData
.
from_seqs
(
self
.
prompt_token_ids
)
self
.
output_logprobs
:
SampleLogprobs
=
[]
self
.
output_text
=
""
...
...
@@ -422,37 +449,23 @@ class Sequence:
def
n_blocks
(
self
)
->
int
:
return
(
self
.
get_len
()
+
self
.
block_size
-
1
)
//
self
.
block_size
@
property
@
cached_
property
def
prompt
(
self
)
->
Optional
[
str
]:
if
self
.
_prompt
is
not
None
:
# Reuse precomputed prompt string
return
self
.
_prompt
# Select decoder or encoder input prompt str,
# as appropriate
# Select decoder or encoder input prompt str, as appropriate
prompt_key
:
str
=
(
"prompt"
if
self
.
from_decoder_prompt
else
"encoder_prompt"
)
# Cache prompt
self
.
_prompt
=
cast
(
Optional
[
str
],
self
.
inputs
.
get
(
prompt_key
))
return
self
.
_prompt
return
cast
(
Optional
[
str
],
self
.
inputs
.
get
(
prompt_key
))
@
property
@
cached_
property
def
prompt_token_ids
(
self
)
->
List
[
int
]:
if
self
.
_prompt_token_ids
is
not
None
:
# Reuse precomputed prompt token ids
return
self
.
_prompt_token_ids
# Select decoder or encoder input prompt
# token ids, as appropriate
# Select decoder or encoder input prompt token ids, as appropriate
prompt_token_ids_key
:
str
=
(
"prompt_token_ids"
if
self
.
from_decoder_prompt
else
"encoder_prompt_token_ids"
)
# Cache computed prompt token ids
self
.
_prompt_token_ids
=
cast
(
List
[
int
],
self
.
inputs
.
get
(
prompt_token_ids_key
))
return
self
.
_prompt_token_ids
return
cast
(
List
[
int
],
self
.
inputs
.
get
(
prompt_token_ids_key
))
@
property
def
multi_modal_data
(
self
)
->
"MultiModalDataDict"
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment