Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
388ee3de
Unverified
Commit
388ee3de
authored
Nov 20, 2024
by
youkaichao
Committed by
GitHub
Nov 20, 2024
Browse files
[torch.compile] limit inductor threads and lazy import quant (#10482)
Signed-off-by:
youkaichao
<
youkaichao@gmail.com
>
parent
2f77b6cf
Changes
11
Show whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
178 additions
and
64 deletions
+178
-64
.buildkite/test-pipeline.yaml
.buildkite/test-pipeline.yaml
+2
-0
tests/quantization/utils.py
tests/quantization/utils.py
+2
-2
tests/test_lazy_torch_compile.py
tests/test_lazy_torch_compile.py
+68
-0
vllm/_custom_ops.py
vllm/_custom_ops.py
+0
-3
vllm/config.py
vllm/config.py
+5
-3
vllm/model_executor/layers/quantization/__init__.py
vllm/model_executor/layers/quantization/__init__.py
+73
-51
vllm/model_executor/models/internvl.py
vllm/model_executor/models/internvl.py
+2
-2
vllm/model_executor/models/qwen2_vl.py
vllm/model_executor/models/qwen2_vl.py
+4
-3
vllm/platforms/cuda.py
vllm/platforms/cuda.py
+2
-0
vllm/platforms/rocm.py
vllm/platforms/rocm.py
+11
-0
vllm/plugins/__init__.py
vllm/plugins/__init__.py
+9
-0
No files found.
.buildkite/test-pipeline.yaml
View file @
388ee3de
...
@@ -50,7 +50,9 @@ steps:
...
@@ -50,7 +50,9 @@ steps:
-
tests/multimodal
-
tests/multimodal
-
tests/test_utils
-
tests/test_utils
-
tests/worker
-
tests/worker
-
tests/test_lazy_torch_compile.py
commands
:
commands
:
-
python3 test_lazy_torch_compile.py
-
pytest -v -s mq_llm_engine
# MQLLMEngine
-
pytest -v -s mq_llm_engine
# MQLLMEngine
-
pytest -v -s async_engine
# AsyncLLMEngine
-
pytest -v -s async_engine
# AsyncLLMEngine
-
NUM_SCHEDULER_STEPS=4 pytest -v -s async_engine/test_async_llm_engine.py
-
NUM_SCHEDULER_STEPS=4 pytest -v -s async_engine/test_async_llm_engine.py
...
...
tests/quantization/utils.py
View file @
388ee3de
from
vllm.model_executor.layers.quantization
import
QUANTIZATION_METHODS
from
vllm.model_executor.layers.quantization
import
get_quantization_config
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
...
@@ -10,6 +10,6 @@ def is_quant_method_supported(quant_method: str) -> bool:
...
@@ -10,6 +10,6 @@ def is_quant_method_supported(quant_method: str) -> bool:
capability
=
current_platform
.
get_device_capability
()
capability
=
current_platform
.
get_device_capability
()
assert
capability
is
not
None
assert
capability
is
not
None
min_capability
=
QUANTIZATION_METHODS
[
quant_method
]
.
get_min_capability
()
min_capability
=
get_quantization_config
(
quant_method
)
.
get_min_capability
()
return
capability
.
to_int
()
>=
min_capability
return
capability
.
to_int
()
>=
min_capability
tests/test_lazy_torch_compile.py
0 → 100644
View file @
388ee3de
# Description: Test the lazy import module
# The utility function cannot be placed in `vllm.utils`
# this needs to be a standalone script
import
contextlib
import
dataclasses
import
sys
import
traceback
from
typing
import
Callable
,
Generator
@
dataclasses
.
dataclass
class
BlameResult
:
found
:
bool
=
False
trace_stack
:
str
=
""
@
contextlib
.
contextmanager
def
blame
(
func
:
Callable
)
->
Generator
[
BlameResult
,
None
,
None
]:
"""
Trace the function calls to find the first function that satisfies the
condition. The trace stack will be stored in the result.
Usage:
```python
with blame(lambda: some_condition()) as result:
# do something
if result.found:
print(result.trace_stack)
"""
result
=
BlameResult
()
def
_trace_calls
(
frame
,
event
,
arg
=
None
):
nonlocal
result
if
event
in
[
'call'
,
'return'
]:
# for every function call or return
try
:
# Temporarily disable the trace function
sys
.
settrace
(
None
)
# check condition here
if
not
result
.
found
and
func
():
result
.
found
=
True
result
.
trace_stack
=
""
.
join
(
traceback
.
format_stack
())
# Re-enable the trace function
sys
.
settrace
(
_trace_calls
)
except
NameError
:
# modules are deleted during shutdown
pass
return
_trace_calls
sys
.
settrace
(
_trace_calls
)
yield
result
sys
.
settrace
(
None
)
module_name
=
"torch._inductor.async_compile"
with
blame
(
lambda
:
module_name
in
sys
.
modules
)
as
result
:
import
vllm
# noqa
assert
not
result
.
found
,
(
f
"Module
{
module_name
}
is already imported, the"
f
" first import location is:
\n
{
result
.
trace_stack
}
"
)
print
(
f
"Module
{
module_name
}
is not imported yet"
)
vllm/_custom_ops.py
View file @
388ee3de
...
@@ -19,9 +19,6 @@ if not current_platform.is_tpu() and not current_platform.is_hpu():
...
@@ -19,9 +19,6 @@ if not current_platform.is_tpu() and not current_platform.is_hpu():
except
ImportError
as
e
:
except
ImportError
as
e
:
logger
.
warning
(
"Failed to import from vllm._C with %r"
,
e
)
logger
.
warning
(
"Failed to import from vllm._C with %r"
,
e
)
if
current_platform
.
is_rocm
():
import
vllm._rocm_C
# noqa: F401
supports_moe_ops
=
False
supports_moe_ops
=
False
with
contextlib
.
suppress
(
ImportError
):
with
contextlib
.
suppress
(
ImportError
):
import
vllm._moe_C
# noqa: F401
import
vllm._moe_C
# noqa: F401
...
...
vllm/config.py
View file @
388ee3de
...
@@ -14,7 +14,8 @@ from transformers import PretrainedConfig
...
@@ -14,7 +14,8 @@ from transformers import PretrainedConfig
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.quantization
import
QUANTIZATION_METHODS
from
vllm.model_executor.layers.quantization
import
(
QUANTIZATION_METHODS
,
get_quantization_config
)
from
vllm.model_executor.models
import
ModelRegistry
from
vllm.model_executor.models
import
ModelRegistry
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.tracing
import
is_otel_available
,
otel_import_error_traceback
from
vllm.tracing
import
is_otel_available
,
otel_import_error_traceback
...
@@ -370,7 +371,7 @@ class ModelConfig:
...
@@ -370,7 +371,7 @@ class ModelConfig:
return
quant_cfg
return
quant_cfg
def
_verify_quantization
(
self
)
->
None
:
def
_verify_quantization
(
self
)
->
None
:
supported_quantization
=
[
*
QUANTIZATION_METHODS
]
supported_quantization
=
QUANTIZATION_METHODS
rocm_supported_quantization
=
[
rocm_supported_quantization
=
[
"awq"
,
"gptq"
,
"fp8"
,
"compressed_tensors"
,
"compressed-tensors"
,
"awq"
,
"gptq"
,
"fp8"
,
"compressed_tensors"
,
"compressed-tensors"
,
"fbgemm_fp8"
"fbgemm_fp8"
...
@@ -392,7 +393,8 @@ class ModelConfig:
...
@@ -392,7 +393,8 @@ class ModelConfig:
quant_method
=
quant_cfg
.
get
(
"quant_method"
,
""
).
lower
()
quant_method
=
quant_cfg
.
get
(
"quant_method"
,
""
).
lower
()
# Detect which checkpoint is it
# Detect which checkpoint is it
for
_
,
method
in
QUANTIZATION_METHODS
.
items
():
for
name
in
QUANTIZATION_METHODS
:
method
=
get_quantization_config
(
name
)
quantization_override
=
method
.
override_quantization_method
(
quantization_override
=
method
.
override_quantization_method
(
quant_cfg
,
self
.
quantization
)
quant_cfg
,
self
.
quantization
)
if
quantization_override
:
if
quantization_override
:
...
...
vllm/model_executor/layers/quantization/__init__.py
View file @
388ee3de
from
typing
import
Dict
,
Type
from
typing
import
Dict
,
List
,
Type
from
vllm.model_executor.layers.quantization.aqlm
import
AQLMConfig
from
vllm.model_executor.layers.quantization.awq
import
AWQConfig
from
vllm.model_executor.layers.quantization.awq_marlin
import
AWQMarlinConfig
from
vllm.model_executor.layers.quantization.base_config
import
(
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
QuantizationConfig
)
from
vllm.model_executor.layers.quantization.bitsandbytes
import
(
BitsAndBytesConfig
)
QUANTIZATION_METHODS
:
List
[
str
]
=
[
from
vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors
import
(
# noqa: E501
"aqlm"
,
"awq"
,
"deepspeedfp"
,
"tpu_int8"
,
"fp8"
,
"fbgemm_fp8"
,
"modelopt"
,
# The order of gptq methods is important for config.py iteration over
# override_quantization_method(..)
"marlin"
,
"gguf"
,
"gptq_marlin_24"
,
"gptq_marlin"
,
"awq_marlin"
,
"gptq"
,
"compressed-tensors"
,
"bitsandbytes"
,
"qqq"
,
"hqq"
,
"experts_int8"
,
"neuron_quant"
,
"ipex"
,
]
def
get_quantization_config
(
quantization
:
str
)
->
Type
[
QuantizationConfig
]:
if
quantization
not
in
QUANTIZATION_METHODS
:
raise
ValueError
(
f
"Invalid quantization method:
{
quantization
}
"
)
# lazy import to avoid triggering `torch.compile` too early
from
.aqlm
import
AQLMConfig
from
.awq
import
AWQConfig
from
.awq_marlin
import
AWQMarlinConfig
from
.bitsandbytes
import
BitsAndBytesConfig
from
.compressed_tensors.compressed_tensors
import
(
# noqa: E501
CompressedTensorsConfig
)
CompressedTensorsConfig
)
from
vllm.model_executor.layers.quantization.deepspeedfp
import
(
from
.deepspeedfp
import
DeepSpeedFPConfig
DeepSpeedFPConfig
)
from
.experts_int8
import
ExpertsInt8Config
from
vllm.model_executor.layers.quantization.experts_int8
import
(
from
.fbgemm_fp8
import
FBGEMMFp8Config
ExpertsInt8Config
)
from
.fp8
import
Fp8Config
from
vllm.model_executor.layers.quantization.fbgemm_fp8
import
FBGEMMFp8Config
from
.gguf
import
GGUFConfig
from
vllm.model_executor.layers.quantization.fp8
import
Fp8Config
from
.gptq
import
GPTQConfig
from
vllm.model_executor.layers.quantization.gguf
import
GGUFConfig
from
.gptq_marlin
import
GPTQMarlinConfig
from
vllm.model_executor.layers.quantization.gptq
import
GPTQConfig
from
.gptq_marlin_24
import
GPTQMarlin24Config
from
vllm.model_executor.layers.quantization.gptq_marlin
import
(
from
.hqq_marlin
import
HQQMarlinConfig
GPTQMarlinConfig
)
from
.ipex_quant
import
IPEXConfig
from
vllm.model_executor.layers.quantization.gptq_marlin_24
import
(
from
.marlin
import
MarlinConfig
GPTQMarlin24Config
)
from
.modelopt
import
ModelOptFp8Config
from
vllm.model_executor.layers.quantization.hqq_marlin
import
HQQMarlinConfig
from
.neuron_quant
import
NeuronQuantConfig
from
vllm.model_executor.layers.quantization.ipex_quant
import
IPEXConfig
from
.qqq
import
QQQConfig
from
vllm.model_executor.layers.quantization.marlin
import
MarlinConfig
from
.tpu_int8
import
Int8TpuConfig
from
vllm.model_executor.layers.quantization.modelopt
import
ModelOptFp8Config
from
vllm.model_executor.layers.quantization.neuron_quant
import
(
NeuronQuantConfig
)
from
vllm.model_executor.layers.quantization.qqq
import
QQQConfig
from
vllm.model_executor.layers.quantization.tpu_int8
import
Int8TpuConfig
QUANTIZATION_METHODS
:
Dict
[
str
,
Type
[
QuantizationConfig
]]
=
{
method_to_config
:
Dict
[
str
,
Type
[
QuantizationConfig
]]
=
{
"aqlm"
:
AQLMConfig
,
"aqlm"
:
AQLMConfig
,
"awq"
:
AWQConfig
,
"awq"
:
AWQConfig
,
"deepspeedfp"
:
DeepSpeedFPConfig
,
"deepspeedfp"
:
DeepSpeedFPConfig
,
...
@@ -53,13 +79,9 @@ QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
...
@@ -53,13 +79,9 @@ QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
"experts_int8"
:
ExpertsInt8Config
,
"experts_int8"
:
ExpertsInt8Config
,
"neuron_quant"
:
NeuronQuantConfig
,
"neuron_quant"
:
NeuronQuantConfig
,
"ipex"
:
IPEXConfig
,
"ipex"
:
IPEXConfig
,
}
}
def
get_quantization_config
(
quantization
:
str
)
->
Type
[
QuantizationConfig
]:
return
method_to_config
[
quantization
]
if
quantization
not
in
QUANTIZATION_METHODS
:
raise
ValueError
(
f
"Invalid quantization method:
{
quantization
}
"
)
return
QUANTIZATION_METHODS
[
quantization
]
__all__
=
[
__all__
=
[
...
...
vllm/model_executor/models/internvl.py
View file @
388ee3de
...
@@ -19,8 +19,8 @@ from vllm.attention import AttentionMetadata
...
@@ -19,8 +19,8 @@ from vllm.attention import AttentionMetadata
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.inputs
import
(
INPUT_REGISTRY
,
DecoderOnlyInputs
,
DummyData
,
from
vllm.inputs
import
(
INPUT_REGISTRY
,
DecoderOnlyInputs
,
DummyData
,
InputContext
,
token_inputs
)
InputContext
,
token_inputs
)
from
vllm.model_executor.layers.quantization
import
(
AWQ
Config
,
from
vllm.model_executor.layers.quantization
import
Quantization
Config
Quantization
Config
)
from
vllm.model_executor.layers.quantization.awq
import
AWQ
Config
from
vllm.model_executor.layers.sampler
import
SamplerOutput
,
get_sampler
from
vllm.model_executor.layers.sampler
import
SamplerOutput
,
get_sampler
from
vllm.model_executor.models.intern_vit
import
(
InternVisionModel
,
from
vllm.model_executor.models.intern_vit
import
(
InternVisionModel
,
InternVisionPatchModel
)
InternVisionPatchModel
)
...
...
vllm/model_executor/models/qwen2_vl.py
View file @
388ee3de
...
@@ -51,9 +51,10 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
...
@@ -51,9 +51,10 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear
)
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.pooler
import
Pooler
,
PoolingType
from
vllm.model_executor.layers.pooler
import
Pooler
,
PoolingType
from
vllm.model_executor.layers.quantization
import
(
GPTQConfig
,
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
GPTQMarlinConfig
,
from
vllm.model_executor.layers.quantization.gptq
import
GPTQConfig
QuantizationConfig
)
from
vllm.model_executor.layers.quantization.gptq_marlin
import
(
GPTQMarlinConfig
)
from
vllm.model_executor.layers.sampler
import
SamplerOutput
,
get_sampler
from
vllm.model_executor.layers.sampler
import
SamplerOutput
,
get_sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
ParallelLMHead
from
vllm.model_executor.layers.vocab_parallel_embedding
import
ParallelLMHead
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
...
...
vllm/platforms/cuda.py
View file @
388ee3de
...
@@ -10,6 +10,8 @@ import pynvml
...
@@ -10,6 +10,8 @@ import pynvml
import
torch
import
torch
from
typing_extensions
import
ParamSpec
from
typing_extensions
import
ParamSpec
# import custom ops, trigger op registration
import
vllm._C
# noqa
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
.interface
import
DeviceCapability
,
Platform
,
PlatformEnum
from
.interface
import
DeviceCapability
,
Platform
,
PlatformEnum
...
...
vllm/platforms/rocm.py
View file @
388ee3de
...
@@ -9,6 +9,17 @@ from .interface import DeviceCapability, Platform, PlatformEnum, _Backend
...
@@ -9,6 +9,17 @@ from .interface import DeviceCapability, Platform, PlatformEnum, _Backend
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
try
:
import
vllm._C
# noqa: F401
except
ImportError
as
e
:
logger
.
warning
(
"Failed to import from vllm._C with %r"
,
e
)
# import custom ops, trigger op registration
try
:
import
vllm._rocm_C
# noqa: F401
except
ImportError
as
e
:
logger
.
warning
(
"Failed to import from vllm._rocm_C with %r"
,
e
)
if
os
.
environ
.
get
(
"VLLM_WORKER_MULTIPROC_METHOD"
,
None
)
in
[
"fork"
,
None
]:
if
os
.
environ
.
get
(
"VLLM_WORKER_MULTIPROC_METHOD"
,
None
)
in
[
"fork"
,
None
]:
logger
.
warning
(
"`fork` method is not supported by ROCm. "
logger
.
warning
(
"`fork` method is not supported by ROCm. "
"VLLM_WORKER_MULTIPROC_METHOD is overridden to"
"VLLM_WORKER_MULTIPROC_METHOD is overridden to"
...
...
vllm/plugins/__init__.py
View file @
388ee3de
import
logging
import
logging
import
os
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
from
typing
import
TYPE_CHECKING
,
Optional
from
typing
import
TYPE_CHECKING
,
Optional
...
@@ -18,6 +19,14 @@ def load_general_plugins():
...
@@ -18,6 +19,14 @@ def load_general_plugins():
processes. They should be designed in a way that they can be loaded
processes. They should be designed in a way that they can be loaded
multiple times without causing issues.
multiple times without causing issues.
"""
"""
# all processes created by vllm will load plugins,
# and here we can inject some common environment variables
# for all processes.
# see https://github.com/vllm-project/vllm/issues/10480
os
.
environ
[
'TORCHINDUCTOR_COMPILE_THREADS'
]
=
'1'
global
plugins_loaded
global
plugins_loaded
if
plugins_loaded
:
if
plugins_loaded
:
return
return
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment