Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
cb391d85
Unverified
Commit
cb391d85
authored
Apr 09, 2025
by
Joe Runde
Committed by
GitHub
Apr 09, 2025
Browse files
[Hardware] add platform-specific request validation api (#16291)
Signed-off-by:
Joe Runde
<
Joseph.Runde@ibm.com
>
parent
fee5b8d3
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
38 additions
and
41 deletions
+38
-41
vllm/platforms/cpu.py
vllm/platforms/cpu.py
+0
-4
vllm/platforms/cuda.py
vllm/platforms/cuda.py
+0
-4
vllm/platforms/hpu.py
vllm/platforms/hpu.py
+0
-4
vllm/platforms/interface.py
vllm/platforms/interface.py
+15
-8
vllm/platforms/neuron.py
vllm/platforms/neuron.py
+0
-4
vllm/platforms/rocm.py
vllm/platforms/rocm.py
+0
-4
vllm/platforms/tpu.py
vllm/platforms/tpu.py
+18
-4
vllm/platforms/xpu.py
vllm/platforms/xpu.py
+0
-4
vllm/v1/engine/processor.py
vllm/v1/engine/processor.py
+5
-5
No files found.
vllm/platforms/cpu.py
View file @
cb391d85
...
...
@@ -180,7 +180,3 @@ class CpuPlatform(Platform):
Get device specific communicator class for distributed communication.
"""
return
"vllm.distributed.device_communicators.cpu_communicator.CpuCommunicator"
# noqa
@
classmethod
def
supports_structured_output
(
cls
)
->
bool
:
return
True
vllm/platforms/cuda.py
View file @
cb391d85
...
...
@@ -308,10 +308,6 @@ class CudaPlatformBase(Platform):
def
supports_v1
(
cls
,
model_config
:
ModelConfig
)
->
bool
:
return
True
@
classmethod
def
supports_structured_output
(
cls
)
->
bool
:
return
True
@
classmethod
def
use_custom_allreduce
(
cls
)
->
bool
:
return
True
...
...
vllm/platforms/hpu.py
View file @
cb391d85
...
...
@@ -92,7 +92,3 @@ class HpuPlatform(Platform):
@
classmethod
def
get_device_communicator_cls
(
cls
)
->
str
:
return
"vllm.distributed.device_communicators.hpu_communicator.HpuCommunicator"
# noqa
@
classmethod
def
supports_structured_output
(
cls
)
->
bool
:
return
True
vllm/platforms/interface.py
View file @
cb391d85
# SPDX-License-Identifier: Apache-2.0
import
enum
import
platform
import
random
...
...
@@ -9,14 +8,21 @@ from typing import TYPE_CHECKING, NamedTuple, Optional, Tuple, Union
import
numpy
as
np
import
torch
from
vllm.inputs
import
PromptType
from
vllm.logger
import
init_logger
if
TYPE_CHECKING
:
from
vllm.config
import
ModelConfig
,
VllmConfig
from
vllm.lora.request
import
LoRARequest
from
vllm.pooling_params
import
PoolingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.utils
import
FlexibleArgumentParser
else
:
ModelConfig
=
None
VllmConfig
=
None
LoRARequest
=
None
PoolingParams
=
None
SamplingParams
=
None
FlexibleArgumentParser
=
None
logger
=
init_logger
(
__name__
)
...
...
@@ -379,13 +385,6 @@ class Platform:
"""
return
False
@
classmethod
def
supports_structured_output
(
cls
)
->
bool
:
"""
Returns whether the current platform can support structured output.
"""
return
False
@
classmethod
def
use_custom_allreduce
(
cls
)
->
bool
:
"""
...
...
@@ -393,6 +392,14 @@ class Platform:
"""
return
False
@
classmethod
def
validate_request
(
cls
,
prompt
:
PromptType
,
params
:
Union
[
SamplingParams
,
PoolingParams
],
)
->
None
:
"""Raises if this request is unsupported on this platform"""
class
UnspecifiedPlatform
(
Platform
):
_enum
=
PlatformEnum
.
UNSPECIFIED
...
...
vllm/platforms/neuron.py
View file @
cb391d85
...
...
@@ -67,7 +67,3 @@ class NeuronPlatform(Platform):
@
classmethod
def
use_all_gather
(
cls
)
->
bool
:
return
True
@
classmethod
def
supports_structured_output
(
cls
)
->
bool
:
return
True
vllm/platforms/rocm.py
View file @
cb391d85
...
...
@@ -303,10 +303,6 @@ class RocmPlatform(Platform):
# V1 support on AMD gpus is experimental
return
True
@
classmethod
def
supports_structured_output
(
cls
)
->
bool
:
return
True
@
classmethod
def
use_custom_allreduce
(
cls
)
->
bool
:
# We only enable custom allreduce for MI300 series
...
...
vllm/platforms/tpu.py
View file @
cb391d85
# SPDX-License-Identifier: Apache-2.0
from
typing
import
TYPE_CHECKING
,
Optional
from
typing
import
TYPE_CHECKING
,
Optional
,
Union
import
torch
import
vllm.envs
as
envs
from
vllm.inputs
import
PromptType
from
vllm.logger
import
init_logger
from
.interface
import
Platform
,
PlatformEnum
,
_Backend
if
TYPE_CHECKING
:
from
vllm.config
import
ModelConfig
,
VllmConfig
from
vllm.lora.request
import
LoRARequest
from
vllm.pooling_params
import
PoolingParams
from
vllm.sampling_params
import
SamplingParams
else
:
ModelConfig
=
None
VllmConfig
=
None
LoRARequest
=
None
PoolingParams
=
None
SamplingParams
=
None
logger
=
init_logger
(
__name__
)
...
...
@@ -135,6 +142,13 @@ class TpuPlatform(Platform):
return
True
@
classmethod
def
supports_structured_output
(
cls
)
->
bool
:
# Structured output is not supported on TPU.
return
False
def
validate_request
(
cls
,
prompt
:
PromptType
,
params
:
Union
[
SamplingParams
,
PoolingParams
],
)
->
None
:
"""Raises if this request is unsupported on this platform"""
if
isinstance
(
params
,
SamplingParams
)
and
params
.
guided_decoding
is
not
None
:
raise
ValueError
(
"Structured output is not supported on "
f
"
{
cls
.
device_name
}
."
)
vllm/platforms/xpu.py
View file @
cb391d85
...
...
@@ -140,7 +140,3 @@ class XPUPlatform(Platform):
@
classmethod
def
get_device_communicator_cls
(
cls
)
->
str
:
return
"vllm.distributed.device_communicators.xpu_communicator.XpuCommunicator"
# noqa
@
classmethod
def
supports_structured_output
(
cls
)
->
bool
:
return
True
vllm/v1/engine/processor.py
View file @
cb391d85
...
...
@@ -141,11 +141,6 @@ class Processor:
else
:
params
.
guided_decoding
.
backend
=
engine_level_backend
from
vllm.platforms
import
current_platform
if
not
current_platform
.
supports_structured_output
():
raise
ValueError
(
"Structured output is not supported on "
f
"
{
current_platform
.
device_name
}
."
)
# Request content validation
if
engine_level_backend
.
startswith
(
"xgrammar"
):
# xgrammar with no fallback
...
...
@@ -187,6 +182,11 @@ class Processor:
# TODO(woosuk): Support pooling models.
# TODO(woosuk): Support encoder-decoder models.
from
vllm.platforms
import
current_platform
current_platform
.
validate_request
(
prompt
=
prompt
,
params
=
params
,
)
self
.
_validate_lora
(
lora_request
)
self
.
_validate_params
(
params
)
if
priority
!=
0
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment