Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
74fa1d12
Unverified
Commit
74fa1d12
authored
Dec 30, 2024
by
Michael Goin
Committed by
GitHub
Dec 31, 2024
Browse files
[Bugfix] Fix OpenAI parallel sampling when using xgrammar (#11637)
Signed-off-by:
mgoin
<
michael@neuralmagic.com
>
parent
a2a40bcd
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
17 additions
and
13 deletions
+17
-13
tests/entrypoints/openai/test_completion.py
tests/entrypoints/openai/test_completion.py
+6
-8
vllm/model_executor/guided_decoding/xgrammar_decoding.py
vllm/model_executor/guided_decoding/xgrammar_decoding.py
+5
-0
vllm/sampling_params.py
vllm/sampling_params.py
+5
-4
vllm/sequence.py
vllm/sequence.py
+1
-1
No files found.
tests/entrypoints/openai/test_completion.py
View file @
74fa1d12
...
...
@@ -28,6 +28,8 @@ PA_NAME = "swapnilbp/llama_tweet_ptune"
# need to change to match the prompt adapter
PA_NUM_VIRTUAL_TOKENS
=
8
GUIDED_DECODING_BACKENDS
=
[
"outlines"
,
"lm-format-enforcer"
,
"xgrammar"
]
@
pytest
.
fixture
(
scope
=
"module"
)
def
zephyr_lora_files
():
...
...
@@ -635,8 +637,7 @@ async def test_allowed_token_ids(client: openai.AsyncOpenAI):
@
pytest
.
mark
.
asyncio
@
pytest
.
mark
.
parametrize
(
"guided_decoding_backend"
,
[
"outlines"
,
"lm-format-enforcer"
])
@
pytest
.
mark
.
parametrize
(
"guided_decoding_backend"
,
GUIDED_DECODING_BACKENDS
)
async
def
test_guided_json_completion
(
client
:
openai
.
AsyncOpenAI
,
guided_decoding_backend
:
str
,
sample_json_schema
):
...
...
@@ -658,8 +659,7 @@ async def test_guided_json_completion(client: openai.AsyncOpenAI,
@
pytest
.
mark
.
asyncio
@
pytest
.
mark
.
parametrize
(
"guided_decoding_backend"
,
[
"outlines"
,
"lm-format-enforcer"
])
@
pytest
.
mark
.
parametrize
(
"guided_decoding_backend"
,
GUIDED_DECODING_BACKENDS
)
async
def
test_guided_regex_completion
(
client
:
openai
.
AsyncOpenAI
,
guided_decoding_backend
:
str
,
sample_regex
):
...
...
@@ -680,8 +680,7 @@ async def test_guided_regex_completion(client: openai.AsyncOpenAI,
@
pytest
.
mark
.
asyncio
@
pytest
.
mark
.
parametrize
(
"guided_decoding_backend"
,
[
"outlines"
,
"lm-format-enforcer"
])
@
pytest
.
mark
.
parametrize
(
"guided_decoding_backend"
,
GUIDED_DECODING_BACKENDS
)
async
def
test_guided_choice_completion
(
client
:
openai
.
AsyncOpenAI
,
guided_decoding_backend
:
str
,
sample_guided_choice
):
...
...
@@ -761,8 +760,7 @@ async def test_echo_logprob_completion(client: openai.AsyncOpenAI,
@
pytest
.
mark
.
asyncio
@
pytest
.
mark
.
parametrize
(
"guided_decoding_backend"
,
[
"outlines"
,
"lm-format-enforcer"
])
@
pytest
.
mark
.
parametrize
(
"guided_decoding_backend"
,
GUIDED_DECODING_BACKENDS
)
async
def
test_guided_decoding_type_error
(
client
:
openai
.
AsyncOpenAI
,
guided_decoding_backend
:
str
,
sample_json_schema
,
sample_regex
):
...
...
vllm/model_executor/guided_decoding/xgrammar_decoding.py
View file @
74fa1d12
# noqa: UP007
from
__future__
import
annotations
import
copy
import
json
from
dataclasses
import
dataclass
,
field
from
typing
import
TYPE_CHECKING
,
Any
...
...
@@ -309,3 +310,7 @@ class XGrammarLogitsProcessor:
scores
=
scores
.
to
(
device_type
).
squeeze
()
return
scores
def
clone
(
self
)
->
XGrammarLogitsProcessor
:
"""Deepcopy due to per-sequence state in the matchers"""
return
copy
.
deepcopy
(
self
)
vllm/sampling_params.py
View file @
74fa1d12
...
...
@@ -450,15 +450,16 @@ class SamplingParams(
return
self
.
_all_stop_token_ids
def
clone
(
self
)
->
"SamplingParams"
:
"""Deep copy
excluding
LogitsProcessor objects.
"""Deep copy
, but maybe not the
LogitsProcessor objects.
LogitsProcessor objects are excluded because they may contain an
arbitrary, nontrivial amount of data.
LogitsProcessor objects may contain an arbitrary, nontrivial amount of
data that is expensive to copy. However, if not copied, the processor
needs to support parallel decoding for multiple sequences
See https://github.com/vllm-project/vllm/issues/3087
"""
logit_processor_refs
=
None
if
self
.
logits_processors
is
None
else
{
id
(
lp
):
lp
id
(
lp
):
lp
.
clone
()
if
hasattr
(
lp
,
'clone'
)
else
lp
for
lp
in
self
.
logits_processors
}
return
copy
.
deepcopy
(
self
,
memo
=
logit_processor_refs
)
...
...
vllm/sequence.py
View file @
74fa1d12
...
...
@@ -1372,7 +1372,7 @@ class ParallelSampleSequenceGroup(SequenceGroupBase):
@
staticmethod
def
add_request
(
request_id
:
str
,
engine
,
params
,
**
kwargs
):
original_params
=
params
params
=
copy
.
deepcopy
(
original_params
)
params
=
original_params
.
clone
(
)
params
.
n
=
1
group
=
ParallelSampleSequenceGroup
(
request_id
)
seqs
=
[]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment