Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
661175bc
Unverified
Commit
661175bc
authored
Nov 29, 2024
by
wangxiyuan
Committed by
GitHub
Nov 29, 2024
Browse files
[platform] Add verify_quantization in platform. (#10757)
Signed-off-by:
wangxiyuan
<
wangxiyuan1007@gmail.com
>
parent
3132aac0
Changes
10
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
38 additions
and
27 deletions
+38
-27
vllm/config.py
vllm/config.py
+1
-27
vllm/platforms/cpu.py
vllm/platforms/cpu.py
+1
-0
vllm/platforms/cuda.py
vllm/platforms/cuda.py
+1
-0
vllm/platforms/hpu.py
vllm/platforms/hpu.py
+1
-0
vllm/platforms/interface.py
vllm/platforms/interface.py
+13
-0
vllm/platforms/neuron.py
vllm/platforms/neuron.py
+2
-0
vllm/platforms/openvino.py
vllm/platforms/openvino.py
+1
-0
vllm/platforms/rocm.py
vllm/platforms/rocm.py
+15
-0
vllm/platforms/tpu.py
vllm/platforms/tpu.py
+2
-0
vllm/platforms/xpu.py
vllm/platforms/xpu.py
+1
-0
No files found.
vllm/config.py
View file @
661175bc
...
...
@@ -393,17 +393,11 @@ class ModelConfig:
def
_verify_quantization
(
self
)
->
None
:
supported_quantization
=
QUANTIZATION_METHODS
rocm_supported_quantization
=
[
"awq"
,
"gptq"
,
"fp8"
,
"compressed_tensors"
,
"compressed-tensors"
,
"fbgemm_fp8"
,
"gguf"
]
optimized_quantization_methods
=
[
"fp8"
,
"marlin"
,
"modelopt"
,
"gptq_marlin_24"
,
"gptq_marlin"
,
"awq_marlin"
,
"fbgemm_fp8"
,
"compressed_tensors"
,
"compressed-tensors"
,
"experts_int8"
]
tpu_supported_quantization
=
[
"tpu_int8"
]
neuron_supported_quantization
=
[
"neuron_quant"
]
if
self
.
quantization
is
not
None
:
self
.
quantization
=
self
.
quantization
.
lower
()
...
...
@@ -438,32 +432,12 @@ class ModelConfig:
raise
ValueError
(
f
"Unknown quantization method:
{
self
.
quantization
}
. Must "
f
"be one of
{
supported_quantization
}
."
)
if
current_platform
.
is_rocm
(
)
and
self
.
quantization
not
in
rocm_supported_quantization
:
raise
ValueError
(
f
"
{
self
.
quantization
}
quantization is currently not "
f
"supported in ROCm."
)
if
current_platform
.
is_tpu
(
)
and
self
.
quantization
not
in
tpu_supported_quantization
:
raise
ValueError
(
f
"
{
self
.
quantization
}
quantization is currently not "
f
"supported in TPU Backend."
)
current_platform
.
verify_quantization
(
self
.
quantization
)
if
self
.
quantization
not
in
optimized_quantization_methods
:
logger
.
warning
(
"%s quantization is not fully "
"optimized yet. The speed can be slower than "
"non-quantized models."
,
self
.
quantization
)
if
(
self
.
quantization
==
"awq"
and
current_platform
.
is_rocm
()
and
not
envs
.
VLLM_USE_TRITON_AWQ
):
logger
.
warning
(
"Using AWQ quantization with ROCm, but VLLM_USE_TRITON_AWQ"
" is not set, enabling VLLM_USE_TRITON_AWQ."
)
envs
.
VLLM_USE_TRITON_AWQ
=
True
if
current_platform
.
is_neuron
(
)
and
self
.
quantization
not
in
neuron_supported_quantization
:
raise
ValueError
(
f
"
{
self
.
quantization
}
quantization is currently not "
f
"supported in Neuron Backend."
)
def
_verify_cuda_graph
(
self
)
->
None
:
if
self
.
max_seq_len_to_capture
is
None
:
...
...
vllm/platforms/cpu.py
View file @
661175bc
...
...
@@ -19,6 +19,7 @@ logger = init_logger(__name__)
class
CpuPlatform
(
Platform
):
_enum
=
PlatformEnum
.
CPU
device_name
:
str
=
"cpu"
device_type
:
str
=
"cpu"
dispatch_key
:
str
=
"CPU"
...
...
vllm/platforms/cuda.py
View file @
661175bc
...
...
@@ -72,6 +72,7 @@ def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]:
class
CudaPlatformBase
(
Platform
):
_enum
=
PlatformEnum
.
CUDA
device_name
:
str
=
"cuda"
device_type
:
str
=
"cuda"
dispatch_key
:
str
=
"CUDA"
...
...
vllm/platforms/hpu.py
View file @
661175bc
...
...
@@ -12,6 +12,7 @@ else:
class
HpuPlatform
(
Platform
):
_enum
=
PlatformEnum
.
HPU
device_name
:
str
=
"hpu"
device_type
:
str
=
"hpu"
dispatch_key
:
str
=
"HPU"
...
...
vllm/platforms/interface.py
View file @
661175bc
...
...
@@ -56,11 +56,13 @@ class DeviceCapability(NamedTuple):
class
Platform
:
_enum
:
PlatformEnum
device_name
:
str
device_type
:
str
# available dispatch keys:
# check https://github.com/pytorch/pytorch/blob/313dac6c1ca0fa0cde32477509cce32089f8532a/torchgen/model.py#L134 # noqa
# use "CPU" as a fallback for platforms not registered in PyTorch
dispatch_key
:
str
=
"CPU"
supported_quantization
:
list
[
str
]
=
[]
def
is_cuda
(
self
)
->
bool
:
return
self
.
_enum
==
PlatformEnum
.
CUDA
...
...
@@ -171,6 +173,17 @@ class Platform:
"""
pass
@
classmethod
def
verify_quantization
(
cls
,
quant
:
str
)
->
None
:
"""
Verify whether the quantization is supported by the current platform.
"""
if
cls
.
supported_quantization
and
\
quant
not
in
cls
.
supported_quantization
:
raise
ValueError
(
f
"
{
quant
}
quantization is currently not supported in "
f
"
{
cls
.
device_name
}
."
)
class
UnspecifiedPlatform
(
Platform
):
_enum
=
PlatformEnum
.
UNSPECIFIED
...
...
vllm/platforms/neuron.py
View file @
661175bc
...
...
@@ -10,7 +10,9 @@ else:
class
NeuronPlatform
(
Platform
):
_enum
=
PlatformEnum
.
NEURON
device_name
:
str
=
"neuron"
device_type
:
str
=
"neuron"
supported_quantization
:
list
[
str
]
=
[
"neuron_quant"
]
@
classmethod
def
get_device_name
(
cls
,
device_id
:
int
=
0
)
->
str
:
...
...
vllm/platforms/openvino.py
View file @
661175bc
...
...
@@ -23,6 +23,7 @@ except ImportError as e:
class
OpenVinoPlatform
(
Platform
):
_enum
=
PlatformEnum
.
OPENVINO
device_name
:
str
=
"openvino"
device_type
:
str
=
"openvino"
dispatch_key
:
str
=
"CPU"
...
...
vllm/platforms/rocm.py
View file @
661175bc
...
...
@@ -4,6 +4,7 @@ from typing import TYPE_CHECKING
import
torch
import
vllm.envs
as
envs
from
vllm.logger
import
init_logger
from
.interface
import
DeviceCapability
,
Platform
,
PlatformEnum
,
_Backend
...
...
@@ -35,8 +36,13 @@ if os.environ.get("VLLM_WORKER_MULTIPROC_METHOD", None) in ["fork", None]:
class
RocmPlatform
(
Platform
):
_enum
=
PlatformEnum
.
ROCM
device_name
:
str
=
"rocm"
device_type
:
str
=
"cuda"
dispatch_key
:
str
=
"CUDA"
supported_quantization
:
list
[
str
]
=
[
"awq"
,
"gptq"
,
"fp8"
,
"compressed_tensors"
,
"compressed-tensors"
,
"fbgemm_fp8"
,
"gguf"
]
@
classmethod
def
get_default_attn_backend
(
cls
,
selected_backend
:
_Backend
)
->
_Backend
:
...
...
@@ -79,3 +85,12 @@ class RocmPlatform(Platform):
"vllm.spec_decode.spec_decode_worker.create_spec_worker"
else
:
parallel_config
.
worker_cls
=
"vllm.worker.worker.Worker"
@
classmethod
def
verify_quantization
(
cls
,
quant
:
str
)
->
None
:
super
().
verify_quantization
(
quant
)
if
quant
==
"awq"
and
not
envs
.
VLLM_USE_TRITON_AWQ
:
logger
.
warning
(
"Using AWQ quantization with ROCm, but VLLM_USE_TRITON_AWQ"
" is not set, enabling VLLM_USE_TRITON_AWQ."
)
envs
.
VLLM_USE_TRITON_AWQ
=
True
vllm/platforms/tpu.py
View file @
661175bc
...
...
@@ -16,8 +16,10 @@ logger = init_logger(__name__)
class
TpuPlatform
(
Platform
):
_enum
=
PlatformEnum
.
TPU
device_name
:
str
=
"tpu"
device_type
:
str
=
"tpu"
dispatch_key
:
str
=
"XLA"
supported_quantization
:
list
[
str
]
=
[
"tpu_int8"
]
@
classmethod
def
get_default_attn_backend
(
cls
,
selected_backend
:
_Backend
)
->
_Backend
:
...
...
vllm/platforms/xpu.py
View file @
661175bc
...
...
@@ -16,6 +16,7 @@ logger = init_logger(__name__)
class
XPUPlatform
(
Platform
):
_enum
=
PlatformEnum
.
XPU
device_name
:
str
=
"xpu"
device_type
:
str
=
"xpu"
dispatch_key
:
str
=
"XPU"
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment