Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
norm
vllm
Commits
708e6c18
Unverified
Commit
708e6c18
authored
Nov 28, 2023
by
Zhuohan Li
Committed by
GitHub
Nov 28, 2023
Browse files
[FIX] Fix class naming (#1803)
parent
b9438904
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
17 additions
and
17 deletions
+17
-17
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+4
-4
vllm/model_executor/layers/sampler.py
vllm/model_executor/layers/sampler.py
+3
-3
vllm/sequence.py
vllm/sequence.py
+10
-10
No files found.
vllm/engine/llm_engine.py
View file @
708e6c18
...
@@ -12,8 +12,8 @@ from vllm.logger import init_logger
...
@@ -12,8 +12,8 @@ from vllm.logger import init_logger
from
vllm.outputs
import
RequestOutput
from
vllm.outputs
import
RequestOutput
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.sequence
import
(
SamplerOutput
,
Sequence
,
SequenceGroup
,
from
vllm.sequence
import
(
SamplerOutput
,
Sequence
,
SequenceGroup
,
SequenceGroupMetadata
,
SequenceGroupOutput
s
,
SequenceGroupMetadata
,
SequenceGroupOutput
,
SequenceOutput
s
,
SequenceStatus
)
SequenceOutput
,
SequenceStatus
)
from
vllm.transformers_utils.tokenizer
import
(
detokenize_incrementally
,
from
vllm.transformers_utils.tokenizer
import
(
detokenize_incrementally
,
get_tokenizer
)
get_tokenizer
)
from
vllm.utils
import
Counter
from
vllm.utils
import
Counter
...
@@ -363,7 +363,7 @@ class LLMEngine:
...
@@ -363,7 +363,7 @@ class LLMEngine:
return
current_worst_score
>=
highest_attainable_score
return
current_worst_score
>=
highest_attainable_score
def
_process_sequence_group_outputs
(
self
,
seq_group
:
SequenceGroup
,
def
_process_sequence_group_outputs
(
self
,
seq_group
:
SequenceGroup
,
outputs
:
SequenceGroupOutput
s
)
->
None
:
outputs
:
SequenceGroupOutput
)
->
None
:
# Process prompt logprobs
# Process prompt logprobs
prompt_logprobs
=
outputs
.
prompt_logprobs
prompt_logprobs
=
outputs
.
prompt_logprobs
if
prompt_logprobs
is
not
None
:
if
prompt_logprobs
is
not
None
:
...
@@ -384,7 +384,7 @@ class LLMEngine:
...
@@ -384,7 +384,7 @@ class LLMEngine:
# Process the child samples for each parent sequence
# Process the child samples for each parent sequence
for
parent
in
parent_seqs
:
for
parent
in
parent_seqs
:
child_samples
:
List
[
SequenceOutput
s
]
=
parent_child_dict
[
child_samples
:
List
[
SequenceOutput
]
=
parent_child_dict
[
parent
.
seq_id
]
parent
.
seq_id
]
if
len
(
child_samples
)
==
0
:
if
len
(
child_samples
)
==
0
:
# This parent sequence has no children samples. Remove
# This parent sequence has no children samples. Remove
...
...
vllm/model_executor/layers/sampler.py
View file @
708e6c18
...
@@ -9,7 +9,7 @@ from vllm.model_executor.parallel_utils.communication_op import (
...
@@ -9,7 +9,7 @@ from vllm.model_executor.parallel_utils.communication_op import (
tensor_model_parallel_all_gather
)
tensor_model_parallel_all_gather
)
from
vllm.sampling_params
import
SamplingParams
,
SamplingType
from
vllm.sampling_params
import
SamplingParams
,
SamplingType
from
vllm.sequence
import
(
PromptLogprobs
,
SampleLogprobs
,
SamplerOutput
,
from
vllm.sequence
import
(
PromptLogprobs
,
SampleLogprobs
,
SamplerOutput
,
SequenceData
,
SequenceGroupOutput
s
,
SequenceOutput
s
)
SequenceData
,
SequenceGroupOutput
,
SequenceOutput
)
_SAMPLING_EPS
=
1e-5
_SAMPLING_EPS
=
1e-5
...
@@ -641,7 +641,7 @@ def _build_sampler_output(
...
@@ -641,7 +641,7 @@ def _build_sampler_output(
next_token_ids
,
next_token_ids
,
group_sample_logprobs
):
group_sample_logprobs
):
seq_outputs
.
append
(
seq_outputs
.
append
(
SequenceOutput
s
(
seq_ids
[
parent_id
],
next_token_id
,
logprobs
))
SequenceOutput
(
seq_ids
[
parent_id
],
next_token_id
,
logprobs
))
sampler_output
.
append
(
sampler_output
.
append
(
SequenceGroupOutput
s
(
seq_outputs
,
group_prompt_logprobs
))
SequenceGroupOutput
(
seq_outputs
,
group_prompt_logprobs
))
return
sampler_output
return
sampler_output
vllm/sequence.py
View file @
708e6c18
...
@@ -352,7 +352,7 @@ class SequenceGroupMetadata:
...
@@ -352,7 +352,7 @@ class SequenceGroupMetadata:
self
.
block_tables
=
block_tables
self
.
block_tables
=
block_tables
class
SequenceOutput
s
:
class
SequenceOutput
:
"""The model output associated with a sequence.
"""The model output associated with a sequence.
Args:
Args:
...
@@ -374,40 +374,40 @@ class SequenceOutputs:
...
@@ -374,40 +374,40 @@ class SequenceOutputs:
self
.
logprobs
=
logprobs
self
.
logprobs
=
logprobs
def
__repr__
(
self
)
->
str
:
def
__repr__
(
self
)
->
str
:
return
(
f
"SequenceOutput
s
(parent_seq_id=
{
self
.
parent_seq_id
}
, "
return
(
f
"SequenceOutput(parent_seq_id=
{
self
.
parent_seq_id
}
, "
f
"output_token=
{
self
.
output_token
}
, "
f
"output_token=
{
self
.
output_token
}
, "
f
"logprobs=
{
self
.
logprobs
}
)"
)
f
"logprobs=
{
self
.
logprobs
}
)"
)
def
__eq__
(
self
,
other
:
object
)
->
bool
:
def
__eq__
(
self
,
other
:
object
)
->
bool
:
if
not
isinstance
(
other
,
SequenceOutput
s
):
if
not
isinstance
(
other
,
SequenceOutput
):
raise
NotImplementedError
()
raise
NotImplementedError
()
return
(
self
.
parent_seq_id
==
other
.
parent_seq_id
return
(
self
.
parent_seq_id
==
other
.
parent_seq_id
and
self
.
output_token
==
other
.
output_token
and
self
.
output_token
==
other
.
output_token
and
self
.
logprobs
==
other
.
logprobs
)
and
self
.
logprobs
==
other
.
logprobs
)
class
SequenceGroupOutput
s
:
class
SequenceGroupOutput
:
"""The model output
s
associated with a sequence group."""
"""The model output associated with a sequence group."""
def
__init__
(
def
__init__
(
self
,
self
,
samples
:
List
[
SequenceOutput
s
],
samples
:
List
[
SequenceOutput
],
prompt_logprobs
:
Optional
[
PromptLogprobs
],
prompt_logprobs
:
Optional
[
PromptLogprobs
],
)
->
None
:
)
->
None
:
self
.
samples
=
samples
self
.
samples
=
samples
self
.
prompt_logprobs
=
prompt_logprobs
self
.
prompt_logprobs
=
prompt_logprobs
def
__repr__
(
self
)
->
str
:
def
__repr__
(
self
)
->
str
:
return
(
f
"SequenceGroupOutput
s
(samples=
{
self
.
samples
}
, "
return
(
f
"SequenceGroupOutput(samples=
{
self
.
samples
}
, "
f
"prompt_logprobs=
{
self
.
prompt_logprobs
}
)"
)
f
"prompt_logprobs=
{
self
.
prompt_logprobs
}
)"
)
def
__eq__
(
self
,
other
:
object
)
->
bool
:
def
__eq__
(
self
,
other
:
object
)
->
bool
:
if
not
isinstance
(
other
,
SequenceGroupOutput
s
):
if
not
isinstance
(
other
,
SequenceGroupOutput
):
raise
NotImplementedError
()
raise
NotImplementedError
()
return
(
self
.
samples
==
other
.
samples
return
(
self
.
samples
==
other
.
samples
and
self
.
prompt_logprobs
==
other
.
prompt_logprobs
)
and
self
.
prompt_logprobs
==
other
.
prompt_logprobs
)
# For each sequence group, we generate a list of SequenceOutput
s
object,
# For each sequence group, we generate a list of SequenceOutput object,
# each of which contains one possible candidate for the next token.
# each of which contains one possible candidate for the next token.
SamplerOutput
=
List
[
SequenceGroupOutput
s
]
SamplerOutput
=
List
[
SequenceGroupOutput
]
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment