Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
0455c46e
Unverified
Commit
0455c46e
authored
Sep 21, 2024
by
Cyrus Leung
Committed by
GitHub
Sep 21, 2024
Browse files
[Core] Factor out common code in `SequenceData` and `Sequence` (#8675)
parent
d4bf085a
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
64 additions
and
97 deletions
+64
-97
tests/samplers/test_sampler.py
tests/samplers/test_sampler.py
+7
-20
tests/spec_decode/utils.py
tests/spec_decode/utils.py
+3
-9
tests/test_logits_processor.py
tests/test_logits_processor.py
+2
-6
tests/test_sequence.py
tests/test_sequence.py
+2
-5
tests/worker/test_encoder_decoder_model_runner.py
tests/worker/test_encoder_decoder_model_runner.py
+7
-15
tests/worker/test_model_runner.py
tests/worker/test_model_runner.py
+5
-11
vllm/inputs/registry.py
vllm/inputs/registry.py
+1
-7
vllm/sequence.py
vllm/sequence.py
+37
-24
No files found.
tests/samplers/test_sampler.py
View file @
0455c46e
import
itertools
import
itertools
import
random
import
random
from
array
import
array
from
typing
import
Dict
,
List
,
Optional
,
Tuple
from
typing
import
Dict
,
List
,
Optional
,
Tuple
from
unittest.mock
import
Mock
,
patch
from
unittest.mock
import
Mock
,
patch
...
@@ -12,8 +11,7 @@ import vllm.envs as envs
...
@@ -12,8 +11,7 @@ import vllm.envs as envs
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.utils
import
set_random_seed
from
vllm.model_executor.utils
import
set_random_seed
from
vllm.sequence
import
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
SamplingParams
,
from
vllm.sequence
import
SamplingParams
,
SequenceData
,
SequenceGroupMetadata
SequenceData
,
SequenceGroupMetadata
)
from
vllm.utils
import
Counter
,
is_pin_memory_available
from
vllm.utils
import
Counter
,
is_pin_memory_available
...
@@ -59,9 +57,7 @@ def _do_sample(
...
@@ -59,9 +57,7 @@ def _do_sample(
SequenceGroupMetadata
(
SequenceGroupMetadata
(
request_id
=
f
"test_
{
i
}
"
,
request_id
=
f
"test_
{
i
}
"
,
is_prompt
=
True
,
is_prompt
=
True
,
seq_data
=
{
seq_data
=
{
0
:
SequenceData
.
from_seqs
([
1
,
2
,
3
])},
0
:
SequenceData
(
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
[
1
,
2
,
3
]))
},
sampling_params
=
sampling_params
,
sampling_params
=
sampling_params
,
block_tables
=
{
0
:
[
1
]},
block_tables
=
{
0
:
[
1
]},
))
))
...
@@ -205,9 +201,8 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
...
@@ -205,9 +201,8 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
return
sampling_params
return
sampling_params
def
create_sequence_data
(
num_input
=
3
,
num_generated
=
0
):
def
create_sequence_data
(
num_input
=
3
,
num_generated
=
0
):
seq_data
=
SequenceData
(
seq_data
=
SequenceData
.
from_seqs
(
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
random
.
choices
(
range
(
0
,
VOCAB_SIZE
),
k
=
num_input
))
random
.
choices
(
range
(
0
,
VOCAB_SIZE
),
k
=
num_input
)))
if
num_generated
>
0
:
if
num_generated
>
0
:
seq_data
.
output_token_ids
=
random
.
choices
(
range
(
0
,
VOCAB_SIZE
),
seq_data
.
output_token_ids
=
random
.
choices
(
range
(
0
,
VOCAB_SIZE
),
k
=
num_generated
)
k
=
num_generated
)
...
@@ -511,9 +506,7 @@ def test_sampler_mixed(seed: int, device: str):
...
@@ -511,9 +506,7 @@ def test_sampler_mixed(seed: int, device: str):
SequenceGroupMetadata
(
SequenceGroupMetadata
(
request_id
=
f
"test_
{
i
}
"
,
request_id
=
f
"test_
{
i
}
"
,
is_prompt
=
True
,
is_prompt
=
True
,
seq_data
=
{
seq_data
=
{
0
:
SequenceData
.
from_seqs
([
1
,
2
,
3
])},
0
:
SequenceData
(
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
[
1
,
2
,
3
]))
},
sampling_params
=
sampling_params
,
sampling_params
=
sampling_params
,
block_tables
=
{
0
:
[
1
]},
block_tables
=
{
0
:
[
1
]},
))
))
...
@@ -613,9 +606,7 @@ def test_sampler_top_k_top_p(seed: int, device: str):
...
@@ -613,9 +606,7 @@ def test_sampler_top_k_top_p(seed: int, device: str):
SequenceGroupMetadata
(
SequenceGroupMetadata
(
request_id
=
f
"test_
{
i
}
"
,
request_id
=
f
"test_
{
i
}
"
,
is_prompt
=
True
,
is_prompt
=
True
,
seq_data
=
{
seq_data
=
{
0
:
SequenceData
.
from_seqs
([
1
,
2
,
3
])},
0
:
SequenceData
(
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
[
1
,
2
,
3
]))
},
sampling_params
=
SamplingParams
(
sampling_params
=
SamplingParams
(
temperature
=
1
,
temperature
=
1
,
top_k
=
top_k
,
top_k
=
top_k
,
...
@@ -699,11 +690,7 @@ def test_sampler_repetition_penalty_mixed(device: str):
...
@@ -699,11 +690,7 @@ def test_sampler_repetition_penalty_mixed(device: str):
SequenceGroupMetadata
(
SequenceGroupMetadata
(
request_id
=
f
"test_
{
i
}
"
,
request_id
=
f
"test_
{
i
}
"
,
is_prompt
=
True
,
is_prompt
=
True
,
seq_data
=
{
seq_data
=
{
0
:
SequenceData
.
from_seqs
([
1
,
2
,
3
])},
0
:
SequenceData
(
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
[
1
,
2
,
3
]))
},
sampling_params
=
sampling_params
[
i
],
sampling_params
=
sampling_params
[
i
],
block_tables
=
{
0
:
[
1
]},
block_tables
=
{
0
:
[
1
]},
))
))
...
...
tests/spec_decode/utils.py
View file @
0455c46e
from
array
import
array
from
itertools
import
count
from
itertools
import
count
from
typing
import
Callable
,
Dict
,
List
,
Optional
from
typing
import
Callable
,
Dict
,
List
,
Optional
from
typing
import
Sequence
as
GenericSequence
from
typing
import
Sequence
as
GenericSequence
...
@@ -11,8 +10,7 @@ from vllm.engine.arg_utils import EngineArgs
...
@@ -11,8 +10,7 @@ from vllm.engine.arg_utils import EngineArgs
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.model_executor.utils
import
set_random_seed
from
vllm.model_executor.utils
import
set_random_seed
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.sequence
import
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
from
vllm.sequence
import
(
CompletionSequenceGroupOutput
,
Logprob
,
CompletionSequenceGroupOutput
,
Logprob
,
SequenceData
,
SequenceGroupMetadata
,
SequenceOutput
)
SequenceData
,
SequenceGroupMetadata
,
SequenceOutput
)
from
vllm.utils
import
get_distributed_init_method
,
get_ip
,
get_open_port
from
vllm.utils
import
get_distributed_init_method
,
get_ip
,
get_open_port
from
vllm.worker.cache_engine
import
CacheEngine
from
vllm.worker.cache_engine
import
CacheEngine
...
@@ -138,12 +136,8 @@ def create_seq_group_metadata_from_prompts(
...
@@ -138,12 +136,8 @@ def create_seq_group_metadata_from_prompts(
request_id
=
str
(
i
),
request_id
=
str
(
i
),
is_prompt
=
len
(
cont_token_ids
)
==
0
,
is_prompt
=
len
(
cont_token_ids
)
==
0
,
seq_data
=
{
seq_data
=
{
i
:
i
:
SequenceData
.
from_seqs
(
prompt_token_ids
[:],
SequenceData
(
cont_token_ids
[:]),
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
prompt_token_ids
[:]),
_output_token_ids
=
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
cont_token_ids
[:]),
),
},
},
sampling_params
=
SamplingParams
(
temperature
=
0.0
,
),
sampling_params
=
SamplingParams
(
temperature
=
0.0
,
),
block_tables
=
{
i
:
block_allocations
[
i
][:]},
block_tables
=
{
i
:
block_allocations
[
i
][:]},
...
...
tests/test_logits_processor.py
View file @
0455c46e
import
random
import
random
from
array
import
array
from
typing
import
Tuple
from
typing
import
Tuple
from
unittest.mock
import
patch
from
unittest.mock
import
patch
...
@@ -9,8 +8,7 @@ import torch
...
@@ -9,8 +8,7 @@ import torch
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.utils
import
set_random_seed
from
vllm.model_executor.utils
import
set_random_seed
from
vllm.sequence
import
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
SamplingParams
,
from
vllm.sequence
import
SamplingParams
,
SequenceData
,
SequenceGroupMetadata
SequenceData
,
SequenceGroupMetadata
)
from
vllm.utils
import
is_pin_memory_available
from
vllm.utils
import
is_pin_memory_available
...
@@ -71,9 +69,7 @@ def test_logits_processors(seed: int, device: str):
...
@@ -71,9 +69,7 @@ def test_logits_processors(seed: int, device: str):
SequenceGroupMetadata
(
SequenceGroupMetadata
(
request_id
=
f
"test_
{
i
}
"
,
request_id
=
f
"test_
{
i
}
"
,
is_prompt
=
True
,
is_prompt
=
True
,
seq_data
=
{
seq_data
=
{
0
:
SequenceData
.
from_seqs
([
1
,
2
,
3
])},
0
:
SequenceData
(
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
[
1
,
2
,
3
]))
},
sampling_params
=
SamplingParams
(
temperature
=
0
,
sampling_params
=
SamplingParams
(
temperature
=
0
,
logits_processors
=
[
pick_ith
]),
logits_processors
=
[
pick_ith
]),
block_tables
=
{
0
:
[
1
]},
block_tables
=
{
0
:
[
1
]},
...
...
tests/test_sequence.py
View file @
0455c46e
from
array
import
array
import
pytest
import
pytest
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.sequence
import
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
from
vllm.sequence
import
(
CompletionSequenceGroupOutput
,
SequenceData
,
CompletionSequenceGroupOutput
,
SequenceData
,
SequenceOutput
)
SequenceOutput
)
from
.core.utils
import
create_dummy_prompt
from
.core.utils
import
create_dummy_prompt
...
@@ -58,7 +55,7 @@ def test_sampler_output_eq(sample_outputs):
...
@@ -58,7 +55,7 @@ def test_sampler_output_eq(sample_outputs):
def
test_sequence_data_prefill
():
def
test_sequence_data_prefill
():
seq_data
=
SequenceData
(
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
[
1
,
2
,
3
,
4
])
)
seq_data
=
SequenceData
.
from_seqs
(
[
1
,
2
,
3
,
4
])
assert
seq_data
.
get_num_uncomputed_tokens
()
==
4
assert
seq_data
.
get_num_uncomputed_tokens
()
==
4
assert
seq_data
.
get_num_computed_tokens
()
==
0
assert
seq_data
.
get_num_computed_tokens
()
==
0
# advance by 2
# advance by 2
...
...
tests/worker/test_encoder_decoder_model_runner.py
View file @
0455c46e
import
itertools
import
itertools
from
array
import
array
from
typing
import
List
from
typing
import
List
import
pytest
import
pytest
import
torch
import
torch
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.sequence
import
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
SamplingParams
,
from
vllm.sequence
import
SamplingParams
,
SequenceData
,
SequenceGroupMetadata
SequenceData
,
SequenceGroupMetadata
)
from
vllm.utils
import
is_cpu
,
make_tensor_with_pad
from
vllm.utils
import
is_cpu
,
make_tensor_with_pad
from
vllm.worker.enc_dec_model_runner
import
EncoderDecoderModelRunner
from
vllm.worker.enc_dec_model_runner
import
EncoderDecoderModelRunner
from
vllm.worker.model_runner
import
_get_graph_batch_size
from
vllm.worker.model_runner
import
_get_graph_batch_size
...
@@ -119,12 +117,10 @@ def test_prepare_prompt(batch_size):
...
@@ -119,12 +117,10 @@ def test_prepare_prompt(batch_size):
# make sure all tokens fit into one block
# make sure all tokens fit into one block
seq_len
=
i
%
(
model_runner
.
block_size
-
1
)
+
1
seq_len
=
i
%
(
model_runner
.
block_size
-
1
)
+
1
seq_lens
.
append
(
seq_len
)
seq_lens
.
append
(
seq_len
)
seq_data
=
SequenceData
(
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
seq_data
=
SequenceData
.
from_seqs
(
range
(
seq_len
))
range
(
seq_len
)))
encoder_seq_len
=
(
i
+
1
)
%
(
model_runner
.
block_size
-
1
)
+
1
encoder_seq_len
=
(
i
+
1
)
%
(
model_runner
.
block_size
-
1
)
+
1
encoder_seq_lens
.
append
(
encoder_seq_len
)
encoder_seq_lens
.
append
(
encoder_seq_len
)
encoder_seq_data
=
SequenceData
(
encoder_seq_data
=
SequenceData
.
from_seqs
(
range
(
encoder_seq_len
))
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
range
(
encoder_seq_len
)))
seq_group_metadata
=
SequenceGroupMetadata
(
seq_group_metadata
=
SequenceGroupMetadata
(
request_id
=
f
"test_
{
i
}
"
,
request_id
=
f
"test_
{
i
}
"
,
is_prompt
=
True
,
is_prompt
=
True
,
...
@@ -317,11 +313,9 @@ def test_prepare_decode(batch_size, multiple_seqs_per_seq_group):
...
@@ -317,11 +313,9 @@ def test_prepare_decode(batch_size, multiple_seqs_per_seq_group):
for
i
in
range
(
batch_size
):
for
i
in
range
(
batch_size
):
# make sure all tokens fit into one block
# make sure all tokens fit into one block
seq_len
=
i
%
(
model_runner
.
block_size
-
1
)
+
1
seq_len
=
i
%
(
model_runner
.
block_size
-
1
)
+
1
seq_data
=
SequenceData
(
seq_data
=
SequenceData
.
from_seqs
(
range
(
seq_len
))
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
(
range
(
seq_len
))))
encoder_seq_len
=
(
i
+
1
)
%
(
model_runner
.
block_size
-
1
)
+
1
encoder_seq_len
=
(
i
+
1
)
%
(
model_runner
.
block_size
-
1
)
+
1
encoder_seq_data
=
SequenceData
(
encoder_seq_data
=
SequenceData
.
from_seqs
(
range
(
encoder_seq_len
))
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
(
range
(
encoder_seq_len
))))
seq_group_metadata
=
SequenceGroupMetadata
(
seq_group_metadata
=
SequenceGroupMetadata
(
request_id
=
f
"test_
{
i
}
"
,
request_id
=
f
"test_
{
i
}
"
,
...
@@ -523,11 +517,9 @@ def test_prepare_decode_cuda_graph(batch_size, multiple_seqs_per_seq_group):
...
@@ -523,11 +517,9 @@ def test_prepare_decode_cuda_graph(batch_size, multiple_seqs_per_seq_group):
for
i
in
range
(
batch_size
):
for
i
in
range
(
batch_size
):
# make sure all tokens fit into one block
# make sure all tokens fit into one block
seq_len
=
i
%
(
model_runner
.
block_size
-
1
)
+
1
seq_len
=
i
%
(
model_runner
.
block_size
-
1
)
+
1
seq_data
=
SequenceData
(
seq_data
=
SequenceData
.
from_seqs
(
range
(
seq_len
))
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
(
range
(
seq_len
))))
encoder_seq_len
=
(
i
+
1
)
%
(
model_runner
.
block_size
-
1
)
+
1
encoder_seq_len
=
(
i
+
1
)
%
(
model_runner
.
block_size
-
1
)
+
1
encoder_seq_data
=
SequenceData
(
encoder_seq_data
=
SequenceData
.
from_seqs
(
range
(
encoder_seq_len
))
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
(
range
(
encoder_seq_len
))))
seq_group_metadata
=
SequenceGroupMetadata
(
seq_group_metadata
=
SequenceGroupMetadata
(
request_id
=
f
"test_
{
i
}
"
,
request_id
=
f
"test_
{
i
}
"
,
is_prompt
=
False
,
is_prompt
=
False
,
...
...
tests/worker/test_model_runner.py
View file @
0455c46e
from
array
import
array
from
typing
import
List
from
typing
import
List
import
pytest
import
pytest
...
@@ -8,8 +7,7 @@ from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
...
@@ -8,8 +7,7 @@ from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
init_distributed_environment
)
init_distributed_environment
)
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
SamplingParams
,
from
vllm.sequence
import
SamplingParams
,
SequenceData
,
SequenceGroupMetadata
SequenceData
,
SequenceGroupMetadata
)
from
vllm.utils
import
get_open_port
from
vllm.utils
import
get_open_port
from
vllm.worker.model_runner
import
ModelRunner
,
_get_graph_batch_size
from
vllm.worker.model_runner
import
ModelRunner
,
_get_graph_batch_size
...
@@ -48,8 +46,7 @@ def test_prepare_prompt(batch_size):
...
@@ -48,8 +46,7 @@ def test_prepare_prompt(batch_size):
# make sure all tokens fit into one block
# make sure all tokens fit into one block
seq_len
=
i
%
(
model_runner
.
block_size
-
1
)
+
1
seq_len
=
i
%
(
model_runner
.
block_size
-
1
)
+
1
seq_lens
.
append
(
seq_len
)
seq_lens
.
append
(
seq_len
)
seq_data
=
SequenceData
(
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
seq_data
=
SequenceData
.
from_seqs
(
range
(
seq_len
))
range
(
seq_len
)))
seq_group_metadata
=
SequenceGroupMetadata
(
seq_group_metadata
=
SequenceGroupMetadata
(
request_id
=
f
"test_
{
i
}
"
,
request_id
=
f
"test_
{
i
}
"
,
is_prompt
=
True
,
is_prompt
=
True
,
...
@@ -166,8 +163,7 @@ def test_prepare_decode_cuda_graph(batch_size):
...
@@ -166,8 +163,7 @@ def test_prepare_decode_cuda_graph(batch_size):
# make sure all tokens fit into one block
# make sure all tokens fit into one block
context_len
=
i
%
(
model_runner
.
block_size
-
1
)
+
1
context_len
=
i
%
(
model_runner
.
block_size
-
1
)
+
1
context_lens
.
append
(
context_len
)
context_lens
.
append
(
context_len
)
seq_data
=
SequenceData
(
seq_data
=
SequenceData
.
from_seqs
(
range
(
context_len
))
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
range
(
context_len
)))
seq_data
.
update_num_computed_tokens
(
context_len
)
seq_data
.
update_num_computed_tokens
(
context_len
)
# Append one token ID since prefill is finished.
# Append one token ID since prefill is finished.
seq_data
.
append_token_id
(
1
,
0
)
seq_data
.
append_token_id
(
1
,
0
)
...
@@ -326,8 +322,7 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
...
@@ -326,8 +322,7 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
# make sure all tokens fit into one block
# make sure all tokens fit into one block
seq_len
=
i
%
(
model_runner
.
block_size
-
1
)
+
1
seq_len
=
i
%
(
model_runner
.
block_size
-
1
)
+
1
seq_lens
.
append
(
seq_len
)
seq_lens
.
append
(
seq_len
)
seq_data
=
SequenceData
(
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
seq_data
=
SequenceData
.
from_seqs
(
range
(
seq_len
))
range
(
seq_len
)))
seq_group_metadata
=
SequenceGroupMetadata
(
seq_group_metadata
=
SequenceGroupMetadata
(
request_id
=
f
"test_
{
i
}
"
,
request_id
=
f
"test_
{
i
}
"
,
is_prompt
=
True
,
is_prompt
=
True
,
...
@@ -343,8 +338,7 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
...
@@ -343,8 +338,7 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
for
i
in
range
(
prefill_batch_size
,
batch_size
):
for
i
in
range
(
prefill_batch_size
,
batch_size
):
# make sure all tokens fit into one block
# make sure all tokens fit into one block
context_len
=
i
%
(
model_runner
.
block_size
-
1
)
+
1
context_len
=
i
%
(
model_runner
.
block_size
-
1
)
+
1
prompt_toks
=
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
range
(
context_len
))
seq_data
=
SequenceData
.
from_seqs
(
range
(
context_len
))
seq_data
=
SequenceData
(
prompt_toks
)
seq_data
.
append_token_id
(
1
,
0
)
seq_data
.
append_token_id
(
1
,
0
)
seq_data
.
update_num_computed_tokens
(
context_len
)
seq_data
.
update_num_computed_tokens
(
context_len
)
seq_group_metadata
=
SequenceGroupMetadata
(
seq_group_metadata
=
SequenceGroupMetadata
(
...
...
vllm/inputs/registry.py
View file @
0455c46e
import
functools
import
functools
from
array
import
array
from
collections
import
UserDict
from
collections
import
UserDict
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
(
TYPE_CHECKING
,
Any
,
Callable
,
Dict
,
Mapping
,
Optional
,
from
typing
import
(
TYPE_CHECKING
,
Any
,
Callable
,
Dict
,
Mapping
,
Optional
,
...
@@ -22,10 +21,6 @@ logger = init_logger(__name__)
...
@@ -22,10 +21,6 @@ logger = init_logger(__name__)
C
=
TypeVar
(
"C"
,
bound
=
PretrainedConfig
,
default
=
PretrainedConfig
)
C
=
TypeVar
(
"C"
,
bound
=
PretrainedConfig
,
default
=
PretrainedConfig
)
# NOTE: This has to match with sequence.py's VLLM_TOKEN_ID_ARRAY_TYPE.
# We cannot import it here because of circular dependencies.
VLLM_TOKEN_ID_ARRAY_TYPE
=
"l"
@
dataclass
(
frozen
=
True
)
@
dataclass
(
frozen
=
True
)
class
InputContext
:
class
InputContext
:
...
@@ -130,8 +125,7 @@ class InputRegistry:
...
@@ -130,8 +125,7 @@ class InputRegistry:
# Avoid circular import
# Avoid circular import
from
vllm.sequence
import
SequenceData
from
vllm.sequence
import
SequenceData
dummy_seq_data
=
SequenceData
(
dummy_seq_data
=
SequenceData
.
from_counts
({
0
:
seq_len
})
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
[
0
])
*
seq_len
)
dummy_multi_modal_data
=
None
dummy_multi_modal_data
=
None
return
dummy_seq_data
,
dummy_multi_modal_data
return
dummy_seq_data
,
dummy_multi_modal_data
...
...
vllm/sequence.py
View file @
0455c46e
...
@@ -5,6 +5,7 @@ from abc import ABC, abstractmethod
...
@@ -5,6 +5,7 @@ from abc import ABC, abstractmethod
from
array
import
array
from
array
import
array
from
collections
import
defaultdict
from
collections
import
defaultdict
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
functools
import
cached_property
,
reduce
from
typing
import
TYPE_CHECKING
,
Any
,
Callable
,
Dict
,
List
,
Mapping
,
Optional
from
typing
import
TYPE_CHECKING
,
Any
,
Callable
,
Dict
,
List
,
Mapping
,
Optional
from
typing
import
Sequence
as
GenericSequence
from
typing
import
Sequence
as
GenericSequence
from
typing
import
Set
,
Tuple
,
Union
,
cast
from
typing
import
Set
,
Tuple
,
Union
,
cast
...
@@ -169,6 +170,35 @@ class SequenceData(msgspec.Struct,
...
@@ -169,6 +170,35 @@ class SequenceData(msgspec.Struct,
# It is used to compute mrope_position_ids.
# It is used to compute mrope_position_ids.
_mrope_position_delta
:
Optional
[
int
]
=
None
_mrope_position_delta
:
Optional
[
int
]
=
None
@
staticmethod
def
from_counts
(
counts_by_token
:
Mapping
[
int
,
int
])
->
"SequenceData"
:
if
len
(
counts_by_token
)
==
0
:
return
SequenceData
.
from_seqs
([])
arrs
=
[
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
[
token_id
])
*
count
for
token_id
,
count
in
counts_by_token
.
items
()
]
return
SequenceData
(
reduce
(
array
.
__add__
,
arrs
))
@
staticmethod
def
from_seqs
(
prompt_token_ids
:
GenericSequence
[
int
],
output_token_ids
:
Optional
[
GenericSequence
[
int
]]
=
None
,
)
->
"SequenceData"
:
prompt_token_ids_arr
=
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
prompt_token_ids
)
if
output_token_ids
is
None
:
return
SequenceData
(
prompt_token_ids_arr
)
output_token_ids_arr
=
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
output_token_ids
)
return
SequenceData
(
prompt_token_ids_arr
,
_output_token_ids
=
output_token_ids_arr
)
def
__post_init__
(
self
)
->
None
:
def
__post_init__
(
self
)
->
None
:
assert
self
.
_prompt_token_ids
.
typecode
==
"l"
assert
self
.
_prompt_token_ids
.
typecode
==
"l"
assert
self
.
_output_token_ids
.
typecode
==
"l"
assert
self
.
_output_token_ids
.
typecode
==
"l"
...
@@ -370,8 +400,6 @@ class Sequence:
...
@@ -370,8 +400,6 @@ class Sequence:
self
.
lora_request
=
lora_request
self
.
lora_request
=
lora_request
self
.
prompt_adapter_request
=
prompt_adapter_request
self
.
prompt_adapter_request
=
prompt_adapter_request
self
.
from_decoder_prompt
=
from_decoder_prompt
self
.
from_decoder_prompt
=
from_decoder_prompt
self
.
_prompt
:
Optional
[
str
]
=
None
self
.
_prompt_token_ids
:
Optional
[
List
[
int
]]
=
None
# For decoder-only models, a Sequence is constructed
# For decoder-only models, a Sequence is constructed
# from an LLMInputs instance (the `inputs` arg.)
# from an LLMInputs instance (the `inputs` arg.)
...
@@ -400,8 +428,7 @@ class Sequence:
...
@@ -400,8 +428,7 @@ class Sequence:
f
"invalid input
{
inputs
}
; did you forget the "
f
"invalid input
{
inputs
}
; did you forget the "
"encoder input prompt fields?"
)
"encoder input prompt fields?"
)
self
.
data
=
SequenceData
(
self
.
data
=
SequenceData
.
from_seqs
(
self
.
prompt_token_ids
)
array
(
VLLM_TOKEN_ID_ARRAY_TYPE
,
self
.
prompt_token_ids
))
self
.
output_logprobs
:
SampleLogprobs
=
[]
self
.
output_logprobs
:
SampleLogprobs
=
[]
self
.
output_text
=
""
self
.
output_text
=
""
...
@@ -422,37 +449,23 @@ class Sequence:
...
@@ -422,37 +449,23 @@ class Sequence:
def
n_blocks
(
self
)
->
int
:
def
n_blocks
(
self
)
->
int
:
return
(
self
.
get_len
()
+
self
.
block_size
-
1
)
//
self
.
block_size
return
(
self
.
get_len
()
+
self
.
block_size
-
1
)
//
self
.
block_size
@
property
@
cached_
property
def
prompt
(
self
)
->
Optional
[
str
]:
def
prompt
(
self
)
->
Optional
[
str
]:
if
self
.
_prompt
is
not
None
:
# Select decoder or encoder input prompt str, as appropriate
# Reuse precomputed prompt string
return
self
.
_prompt
# Select decoder or encoder input prompt str,
# as appropriate
prompt_key
:
str
=
(
"prompt"
prompt_key
:
str
=
(
"prompt"
if
self
.
from_decoder_prompt
else
"encoder_prompt"
)
if
self
.
from_decoder_prompt
else
"encoder_prompt"
)
# Cache prompt
return
cast
(
Optional
[
str
],
self
.
inputs
.
get
(
prompt_key
))
self
.
_prompt
=
cast
(
Optional
[
str
],
self
.
inputs
.
get
(
prompt_key
))
return
self
.
_prompt
@
property
@
cached_
property
def
prompt_token_ids
(
self
)
->
List
[
int
]:
def
prompt_token_ids
(
self
)
->
List
[
int
]:
if
self
.
_prompt_token_ids
is
not
None
:
# Select decoder or encoder input prompt token ids, as appropriate
# Reuse precomputed prompt token ids
return
self
.
_prompt_token_ids
# Select decoder or encoder input prompt
# token ids, as appropriate
prompt_token_ids_key
:
str
=
(
"prompt_token_ids"
prompt_token_ids_key
:
str
=
(
"prompt_token_ids"
if
self
.
from_decoder_prompt
else
if
self
.
from_decoder_prompt
else
"encoder_prompt_token_ids"
)
"encoder_prompt_token_ids"
)
# Cache computed prompt token ids
# Cache computed prompt token ids
self
.
_prompt_token_ids
=
cast
(
List
[
int
],
return
cast
(
List
[
int
],
self
.
inputs
.
get
(
prompt_token_ids_key
))
self
.
inputs
.
get
(
prompt_token_ids_key
))
return
self
.
_prompt_token_ids
@
property
@
property
def
multi_modal_data
(
self
)
->
"MultiModalDataDict"
:
def
multi_modal_data
(
self
)
->
"MultiModalDataDict"
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment