Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
500b93c8
Commit
500b93c8
authored
Jul 25, 2024
by
zhuwenwen
Browse files
Merge tag 'v0.5.3.post1' into v0.5.3.post1-dtk24.04.1
parents
99426767
38c4b7e8
Changes
282
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
770 additions
and
231 deletions
+770
-231
vllm/executor/neuron_executor.py
vllm/executor/neuron_executor.py
+2
-0
vllm/executor/openvino_executor.py
vllm/executor/openvino_executor.py
+2
-0
vllm/executor/ray_gpu_executor.py
vllm/executor/ray_gpu_executor.py
+168
-99
vllm/executor/ray_utils.py
vllm/executor/ray_utils.py
+8
-6
vllm/executor/ray_xpu_executor.py
vllm/executor/ray_xpu_executor.py
+33
-47
vllm/executor/tpu_executor.py
vllm/executor/tpu_executor.py
+2
-0
vllm/executor/xpu_executor.py
vllm/executor/xpu_executor.py
+2
-0
vllm/inputs/__init__.py
vllm/inputs/__init__.py
+3
-4
vllm/inputs/data.py
vllm/inputs/data.py
+1
-23
vllm/lora/request.py
vllm/lora/request.py
+39
-3
vllm/lora/utils.py
vllm/lora/utils.py
+47
-0
vllm/lora/worker_manager.py
vllm/lora/worker_manager.py
+4
-3
vllm/model_executor/layers/fused_moe/layer.py
vllm/model_executor/layers/fused_moe/layer.py
+55
-14
vllm/model_executor/layers/fused_moe/moe_pallas.py
vllm/model_executor/layers/fused_moe/moe_pallas.py
+62
-0
vllm/model_executor/layers/linear.py
vllm/model_executor/layers/linear.py
+63
-26
vllm/model_executor/layers/quantization/__init__.py
vllm/model_executor/layers/quantization/__init__.py
+4
-0
vllm/model_executor/layers/quantization/aqlm.py
vllm/model_executor/layers/quantization/aqlm.py
+2
-2
vllm/model_executor/layers/quantization/awq.py
vllm/model_executor/layers/quantization/awq.py
+2
-2
vllm/model_executor/layers/quantization/awq_marlin.py
vllm/model_executor/layers/quantization/awq_marlin.py
+268
-0
vllm/model_executor/layers/quantization/base_config.py
vllm/model_executor/layers/quantization/base_config.py
+3
-2
No files found.
vllm/executor/neuron_executor.py
View file @
500b93c8
...
@@ -11,6 +11,8 @@ logger = init_logger(__name__)
...
@@ -11,6 +11,8 @@ logger = init_logger(__name__)
class
NeuronExecutor
(
ExecutorBase
):
class
NeuronExecutor
(
ExecutorBase
):
uses_ray
:
bool
=
False
def
_init_executor
(
self
)
->
None
:
def
_init_executor
(
self
)
->
None
:
assert
(
self
.
lora_config
is
assert
(
self
.
lora_config
is
None
),
"LoRA is not supported for Neuron backend."
None
),
"LoRA is not supported for Neuron backend."
...
...
vllm/executor/openvino_executor.py
View file @
500b93c8
...
@@ -18,6 +18,8 @@ logger = init_logger(__name__)
...
@@ -18,6 +18,8 @@ logger = init_logger(__name__)
class
OpenVINOExecutor
(
ExecutorBase
):
class
OpenVINOExecutor
(
ExecutorBase
):
uses_ray
:
bool
=
False
def
_init_executor
(
self
)
->
None
:
def
_init_executor
(
self
)
->
None
:
assert
self
.
device_config
.
device_type
==
"openvino"
assert
self
.
device_config
.
device_type
==
"openvino"
assert
self
.
lora_config
is
None
,
"OpenVINO backend doesn't support LoRA"
assert
self
.
lora_config
is
None
,
"OpenVINO backend doesn't support LoRA"
...
...
vllm/executor/ray_gpu_executor.py
View file @
500b93c8
import
asyncio
import
asyncio
import
os
import
os
import
pickle
from
collections
import
defaultdict
from
collections
import
defaultdict
from
itertools
import
islice
,
repeat
from
itertools
import
islice
,
repeat
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Optional
,
Tuple
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Optional
,
Tuple
...
@@ -11,7 +10,8 @@ from vllm.executor.distributed_gpu_executor import ( # yapf: disable
...
@@ -11,7 +10,8 @@ from vllm.executor.distributed_gpu_executor import ( # yapf: disable
from
vllm.executor.ray_utils
import
RayWorkerWrapper
,
ray
from
vllm.executor.ray_utils
import
RayWorkerWrapper
,
ray
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.sequence
import
ExecuteModelRequest
,
SamplerOutput
from
vllm.sequence
import
ExecuteModelRequest
,
SamplerOutput
from
vllm.utils
import
(
error_on_invalid_device_count_status
,
from
vllm.utils
import
(
_run_task_with_lock
,
error_on_invalid_device_count_status
,
get_distributed_init_method
,
get_ip
,
get_open_port
,
get_distributed_init_method
,
get_ip
,
get_open_port
,
get_vllm_instance_id
,
make_async
)
get_vllm_instance_id
,
make_async
)
...
@@ -23,13 +23,33 @@ if TYPE_CHECKING:
...
@@ -23,13 +23,33 @@ if TYPE_CHECKING:
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
USE_RAY_COMPILED_DAG
=
envs
.
VLLM_USE_RAY_COMPILED_DAG
class
RayGPUExecutor
(
DistributedGPUExecutor
):
class
RayGPUExecutor
(
DistributedGPUExecutor
):
uses_ray
:
bool
=
True
def
_init_executor
(
self
)
->
None
:
def
_init_executor
(
self
)
->
None
:
assert
self
.
parallel_config
.
distributed_executor_backend
==
"ray"
# If the env var is set, it uses the Ray's compiled DAG API
# which optimizes the control plane overhead.
# Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it.
# Currently, this requires USE_RAY_SPMD_WORKER=True.
self
.
use_ray_compiled_dag
=
envs
.
VLLM_USE_RAY_COMPILED_DAG
# If the env var is set, then we do not distinguish between the
# "driver worker" vs other workers. Also, the rank 0 worker will
# be executed in a remote Ray worker. Currently this requires
# USE_RAY_COMPILED_DAG=True.
self
.
use_ray_spmd_worker
=
envs
.
VLLM_USE_RAY_SPMD_WORKER
if
self
.
use_ray_compiled_dag
:
assert
self
.
use_ray_spmd_worker
,
(
"VLLM_USE_RAY_COMPILED_DAG=1 requires "
"VLLM_USE_RAY_SPMD_WORKER=1"
)
if
self
.
use_ray_spmd_worker
:
# TODO: Support SPMD worker for non-DAG Ray executor.
assert
self
.
use_ray_compiled_dag
,
(
"VLLM_USE_RAY_SPMD_WORKER=1 requires "
"VLLM_USE_RAY_COMPILED_DAG=1"
)
assert
self
.
uses_ray
placement_group
=
self
.
parallel_config
.
placement_group
placement_group
=
self
.
parallel_config
.
placement_group
# Disable Ray usage stats collection.
# Disable Ray usage stats collection.
...
@@ -40,11 +60,7 @@ class RayGPUExecutor(DistributedGPUExecutor):
...
@@ -40,11 +60,7 @@ class RayGPUExecutor(DistributedGPUExecutor):
# Create the parallel GPU workers.
# Create the parallel GPU workers.
self
.
_init_workers_ray
(
placement_group
)
self
.
_init_workers_ray
(
placement_group
)
self
.
forward_dag
=
None
self
.
forward_dag
:
Optional
[
"ray.dag.CompiledDAG"
]
=
None
if
USE_RAY_COMPILED_DAG
:
self
.
forward_dag
=
self
.
_compiled_ray_dag
()
self
.
extra_execute_model_run_workers_kwargs
[
"use_ray_compiled_dag"
]
=
True
def
_configure_ray_workers_use_nsight
(
self
,
def
_configure_ray_workers_use_nsight
(
self
,
ray_remote_kwargs
)
->
Dict
[
str
,
Any
]:
ray_remote_kwargs
)
->
Dict
[
str
,
Any
]:
...
@@ -61,6 +77,20 @@ class RayGPUExecutor(DistributedGPUExecutor):
...
@@ -61,6 +77,20 @@ class RayGPUExecutor(DistributedGPUExecutor):
return
ray_remote_kwargs
return
ray_remote_kwargs
def
_get_worker_wrapper_args
(
self
)
->
Dict
[
str
,
Any
]:
if
self
.
speculative_config
is
not
None
:
worker_module_name
=
"vllm.spec_decode.spec_decode_worker"
worker_class_name
=
"create_spec_worker"
else
:
worker_module_name
=
"vllm.worker.worker"
worker_class_name
=
"Worker"
return
dict
(
worker_module_name
=
worker_module_name
,
worker_class_name
=
worker_class_name
,
trust_remote_code
=
self
.
model_config
.
trust_remote_code
,
)
def
_init_workers_ray
(
self
,
placement_group
:
"PlacementGroup"
,
def
_init_workers_ray
(
self
,
placement_group
:
"PlacementGroup"
,
**
ray_remote_kwargs
):
**
ray_remote_kwargs
):
if
(
self
.
parallel_config
.
tensor_parallel_size
==
1
if
(
self
.
parallel_config
.
tensor_parallel_size
==
1
...
@@ -83,6 +113,7 @@ class RayGPUExecutor(DistributedGPUExecutor):
...
@@ -83,6 +113,7 @@ class RayGPUExecutor(DistributedGPUExecutor):
# Create the workers.
# Create the workers.
driver_ip
=
get_ip
()
driver_ip
=
get_ip
()
worker_wrapper_kwargs
=
self
.
_get_worker_wrapper_args
()
for
bundle_id
,
bundle
in
enumerate
(
placement_group
.
bundle_specs
):
for
bundle_id
,
bundle
in
enumerate
(
placement_group
.
bundle_specs
):
if
not
bundle
.
get
(
"GPU"
,
0
):
if
not
bundle
.
get
(
"GPU"
,
0
):
continue
continue
...
@@ -92,39 +123,28 @@ class RayGPUExecutor(DistributedGPUExecutor):
...
@@ -92,39 +123,28 @@ class RayGPUExecutor(DistributedGPUExecutor):
placement_group_bundle_index
=
bundle_id
,
placement_group_bundle_index
=
bundle_id
,
)
)
if
self
.
speculative_config
is
not
None
:
worker_module_name
=
"vllm.spec_decode.spec_decode_worker"
worker_class_name
=
"create_spec_worker"
else
:
worker_module_name
=
"vllm.worker.worker"
worker_class_name
=
"Worker"
worker
=
ray
.
remote
(
worker
=
ray
.
remote
(
num_cpus
=
0
,
num_cpus
=
0
,
num_gpus
=
num_gpus
,
num_gpus
=
num_gpus
,
scheduling_strategy
=
scheduling_strategy
,
scheduling_strategy
=
scheduling_strategy
,
**
ray_remote_kwargs
,
**
ray_remote_kwargs
,
)(
RayWorkerWrapper
).
remote
(
)(
RayWorkerWrapper
).
remote
(
**
worker_wrapper_kwargs
)
worker_module_name
=
worker_module_name
,
worker_class_name
=
worker_class_name
,
trust_remote_code
=
self
.
model_config
.
trust_remote_code
,
)
worker_ip
=
ray
.
get
(
worker
.
get_node_ip
.
remote
())
if
self
.
use_ray_spmd_worker
:
if
worker_ip
==
driver_ip
and
self
.
driver_dummy_worker
is
None
:
# If the worker is on the same node as the driver, we use it
# as the resource holder for the driver process.
self
.
driver_dummy_worker
=
worker
self
.
driver_worker
=
RayWorkerWrapper
(
worker_module_name
=
worker_module_name
,
worker_class_name
=
worker_class_name
,
trust_remote_code
=
self
.
model_config
.
trust_remote_code
,
)
else
:
# Else, added to the list of workers.
self
.
workers
.
append
(
worker
)
self
.
workers
.
append
(
worker
)
else
:
if
self
.
driver_dummy_worker
is
None
:
worker_ip
=
ray
.
get
(
worker
.
get_node_ip
.
remote
())
if
worker_ip
==
driver_ip
and
self
.
driver_dummy_worker
is
None
:
# If the worker is on the same node as the driver, we use it
# as the resource holder for the driver process.
self
.
driver_dummy_worker
=
worker
self
.
driver_worker
=
RayWorkerWrapper
(
**
worker_wrapper_kwargs
)
else
:
# Else, added to the list of workers.
self
.
workers
.
append
(
worker
)
if
not
self
.
use_ray_spmd_worker
and
self
.
driver_dummy_worker
is
None
:
raise
ValueError
(
raise
ValueError
(
"Ray does not allocate any GPUs on the driver node. Consider "
"Ray does not allocate any GPUs on the driver node. Consider "
"adjusting the Ray placement group or running the driver on a "
"adjusting the Ray placement group or running the driver on a "
...
@@ -224,13 +244,14 @@ class RayGPUExecutor(DistributedGPUExecutor):
...
@@ -224,13 +244,14 @@ class RayGPUExecutor(DistributedGPUExecutor):
# broadcasted to.
# broadcasted to.
self
.
non_driver_workers
:
List
[
RayWorkerWrapper
]
=
[]
self
.
non_driver_workers
:
List
[
RayWorkerWrapper
]
=
[]
for
idx
,
rank
in
enumerate
(
worker_ranks
[
1
:]):
# Enforce rank order for correct rank to return final output.
for
rank
,
worker
in
sorted
(
zip
(
worker_ranks
[
1
:],
self
.
workers
)):
# We need to skip the driver worker, which we
# We need to skip the driver worker, which we
# do by skipping worker_ranks[0] which is always 0.
# do by skipping worker_ranks[0] which is always 0.
if
rank
%
self
.
parallel_config
.
tensor_parallel_size
==
0
:
if
rank
%
self
.
parallel_config
.
tensor_parallel_size
==
0
:
self
.
tp_driver_workers
.
append
(
self
.
worker
s
[
idx
]
)
self
.
tp_driver_workers
.
append
(
worker
)
else
:
else
:
self
.
non_driver_workers
.
append
(
self
.
worker
s
[
idx
]
)
self
.
non_driver_workers
.
append
(
worker
)
def
_driver_execute_model
(
def
_driver_execute_model
(
self
,
execute_model_req
:
Optional
[
ExecuteModelRequest
]
self
,
execute_model_req
:
Optional
[
ExecuteModelRequest
]
...
@@ -240,9 +261,23 @@ class RayGPUExecutor(DistributedGPUExecutor):
...
@@ -240,9 +261,23 @@ class RayGPUExecutor(DistributedGPUExecutor):
Passing None will cause the driver to stop the model execution
Passing None will cause the driver to stop the model execution
loop running in each of the remote workers.
loop running in each of the remote workers.
"""
"""
assert
not
self
.
use_ray_spmd_worker
,
(
"driver_worker does not exist for VLLM_USE_RAY_SPMD_WORKER=1"
)
return
self
.
driver_worker
.
execute_method
(
"execute_model"
,
return
self
.
driver_worker
.
execute_method
(
"execute_model"
,
execute_model_req
)
execute_model_req
)
def
execute_model
(
self
,
execute_model_req
:
ExecuteModelRequest
)
->
List
[
SamplerOutput
]:
if
not
self
.
use_ray_spmd_worker
:
return
super
().
execute_model
(
execute_model_req
)
if
self
.
forward_dag
is
None
:
self
.
forward_dag
=
self
.
_compiled_ray_dag
(
enable_asyncio
=
False
)
outputs
=
ray
.
get
(
self
.
forward_dag
.
execute
(
execute_model_req
))
return
outputs
[
0
]
def
_run_workers
(
def
_run_workers
(
self
,
self
,
method
:
str
,
method
:
str
,
...
@@ -252,7 +287,6 @@ class RayGPUExecutor(DistributedGPUExecutor):
...
@@ -252,7 +287,6 @@ class RayGPUExecutor(DistributedGPUExecutor):
all_kwargs
:
Optional
[
List
[
Dict
[
str
,
Any
]]]
=
None
,
all_kwargs
:
Optional
[
List
[
Dict
[
str
,
Any
]]]
=
None
,
use_dummy_driver
:
bool
=
False
,
use_dummy_driver
:
bool
=
False
,
max_concurrent_workers
:
Optional
[
int
]
=
None
,
max_concurrent_workers
:
Optional
[
int
]
=
None
,
use_ray_compiled_dag
:
bool
=
False
,
**
kwargs
,
**
kwargs
,
)
->
Any
:
)
->
Any
:
"""Runs the given method on all workers. Can be used in the following
"""Runs the given method on all workers. Can be used in the following
...
@@ -267,6 +301,10 @@ class RayGPUExecutor(DistributedGPUExecutor):
...
@@ -267,6 +301,10 @@ class RayGPUExecutor(DistributedGPUExecutor):
- all_args/all_kwargs: args/kwargs for each worker are specified
- all_args/all_kwargs: args/kwargs for each worker are specified
individually
individually
"""
"""
if
self
.
use_ray_spmd_worker
:
assert
not
async_run_tensor_parallel_workers_only
,
(
"async_run_tensor_parallel_workers_only is not supported for "
"spmd mode."
)
if
max_concurrent_workers
:
if
max_concurrent_workers
:
raise
NotImplementedError
(
raise
NotImplementedError
(
...
@@ -275,99 +313,125 @@ class RayGPUExecutor(DistributedGPUExecutor):
...
@@ -275,99 +313,125 @@ class RayGPUExecutor(DistributedGPUExecutor):
count
=
len
(
self
.
workers
)
if
not
\
count
=
len
(
self
.
workers
)
if
not
\
async_run_tensor_parallel_workers_only
\
async_run_tensor_parallel_workers_only
\
else
len
(
self
.
non_driver_workers
)
else
len
(
self
.
non_driver_workers
)
# If using SPMD worker, all workers are the same, so we should execute
# the args on all workers. Otherwise, we skip the first worker's args
# because those args will go to the driver worker.
first_worker_args_index
:
int
=
0
if
self
.
use_ray_spmd_worker
else
1
all_worker_args
=
repeat
(
args
,
count
)
if
all_args
is
None
\
all_worker_args
=
repeat
(
args
,
count
)
if
all_args
is
None
\
else
islice
(
all_args
,
1
,
None
)
else
islice
(
all_args
,
first_worker_args_index
,
None
)
all_worker_kwargs
=
repeat
(
kwargs
,
count
)
if
all_kwargs
is
None
\
all_worker_kwargs
=
repeat
(
kwargs
,
count
)
if
all_kwargs
is
None
\
else
islice
(
all_kwargs
,
1
,
None
)
else
islice
(
all_kwargs
,
first_worker_args_index
,
None
)
if
use_ray_compiled_dag
:
# Start the ray workers first.
# Right now, compiled DAG can only accept a single
ray_workers
=
self
.
workers
# input. TODO(sang): Fix it.
if
async_run_tensor_parallel_workers_only
:
assert
self
.
forward_dag
is
not
None
ray_workers
=
self
.
non_driver_workers
output_channels
=
self
.
forward_dag
.
execute
(
1
)
ray_worker_outputs
=
[
ray_worker_outputs
=
[]
worker
.
execute_method
.
remote
(
method
,
*
worker_args
,
**
worker_kwargs
)
else
:
for
(
worker
,
worker_args
,
worker_kwargs
# Start the ray workers first.
)
in
zip
(
ray_workers
,
all_worker_args
,
all_worker_kwargs
)
ray_workers
=
self
.
workers
]
if
async_run_tensor_parallel_workers_only
:
ray_workers
=
self
.
non_driver_workers
ray_worker_outputs
=
[
worker
.
execute_method
.
remote
(
method
,
*
worker_args
,
**
worker_kwargs
)
for
(
worker
,
worker_args
,
worker_kwargs
)
in
zip
(
ray_workers
,
all_worker_args
,
all_worker_kwargs
)
]
if
async_run_tensor_parallel_workers_only
:
if
async_run_tensor_parallel_workers_only
:
# Just return futures
# Just return futures
return
ray_worker_outputs
return
ray_worker_outputs
driver_args
=
args
if
all_args
is
None
else
all_args
[
0
]
driver_worker_output
=
[]
driver_kwargs
=
kwargs
if
all_kwargs
is
None
else
all_kwargs
[
0
]
# In SPMD mode, the driver worker is the same as any other worker,
# so we only explicitly execute on the driver worker if using a
# non-SPMD worker class.
if
not
self
.
use_ray_spmd_worker
:
driver_args
=
args
if
all_args
is
None
else
all_args
[
0
]
driver_kwargs
=
kwargs
if
all_kwargs
is
None
else
all_kwargs
[
0
]
# Start the driver worker after all the ray workers.
if
not
use_dummy_driver
:
driver_worker_output
=
[
self
.
driver_worker
.
execute_method
(
method
,
*
driver_args
,
**
driver_kwargs
)
]
else
:
assert
self
.
driver_dummy_worker
is
not
None
driver_worker_output
=
[
ray
.
get
(
self
.
driver_dummy_worker
.
execute_method
.
remote
(
method
,
*
driver_args
,
**
driver_kwargs
))
]
# Start the driver worker after all the ray workers.
if
not
use_dummy_driver
:
driver_worker_output
=
self
.
driver_worker
.
execute_method
(
method
,
*
driver_args
,
**
driver_kwargs
)
else
:
assert
self
.
driver_dummy_worker
is
not
None
driver_worker_output
=
ray
.
get
(
self
.
driver_dummy_worker
.
execute_method
.
remote
(
method
,
*
driver_args
,
**
driver_kwargs
))
# Get the results of the ray workers.
# Get the results of the ray workers.
if
self
.
workers
:
if
self
.
workers
:
if
use_ray_compiled_dag
:
ray_worker_outputs
=
ray
.
get
(
ray_worker_outputs
)
try
:
ray_worker_outputs
=
[
pickle
.
loads
(
chan
.
begin_read
())
for
chan
in
output_channels
]
finally
:
# Has to call end_read in order to reuse the DAG.
for
chan
in
output_channels
:
chan
.
end_read
()
else
:
ray_worker_outputs
=
ray
.
get
(
ray_worker_outputs
)
return
[
driver_worker_output
]
+
ray_worker_outputs
return
driver_worker_output
+
ray_worker_outputs
def
_wait_for_tasks_completion
(
self
,
parallel_worker_tasks
:
Any
)
->
None
:
def
_wait_for_tasks_completion
(
self
,
parallel_worker_tasks
:
Any
)
->
None
:
"""Wait for futures returned from _run_workers() with
"""Wait for futures returned from _run_workers() with
async_run_remote_workers_only to complete."""
async_run_remote_workers_only to complete."""
ray
.
get
(
parallel_worker_tasks
)
ray
.
get
(
parallel_worker_tasks
)
def
_compiled_ray_dag
(
self
):
def
_compiled_ray_dag
(
self
,
enable_asyncio
:
bool
):
import
pkg_resources
import
pkg_resources
required_version
=
"2.9"
from
packaging
import
version
current_version
=
pkg_resources
.
get_distribution
(
"ray"
).
version
required_version
=
version
.
parse
(
"2.32"
)
current_version
=
version
.
parse
(
pkg_resources
.
get_distribution
(
"ray"
).
version
)
if
current_version
<
required_version
:
if
current_version
<
required_version
:
raise
ValueError
(
f
"Ray version
{
required_version
}
or greater is "
raise
ValueError
(
f
"Ray version
{
required_version
}
or greater is "
f
"required, but found
{
current_version
}
"
)
f
"required, but found
{
current_version
}
"
)
from
ray.dag
import
InputNode
,
MultiOutputNode
from
ray.dag
import
InputNode
,
MultiOutputNode
assert
self
.
parallel_config
.
distributed_executor_backend
==
"
ray
"
assert
self
.
parallel_config
.
use_
ray
# Right now, compiled DAG requires at least 1 arg. We send
# Right now, compiled DAG requires at least 1 arg. We send
# a dummy value for now. It will be fixed soon.
# a dummy value for now. It will be fixed soon.
with
InputNode
()
as
input_data
:
with
InputNode
()
as
input_data
:
forward_dag
=
MultiOutputNode
([
forward_dag
=
MultiOutputNode
([
worker
.
execute_model_compiled_dag_remote
.
worker
.
execute_model_spmd
.
bind
(
# type: ignore[attr-defined]
bind
(
# type: ignore[attr-defined]
input_data
)
for
worker
in
self
.
workers
input_data
)
for
worker
in
self
.
workers
])
])
return
forward_dag
.
experimental_compile
()
return
forward_dag
.
experimental_compile
(
enable_asyncio
=
enable_asyncio
)
def
__del__
(
self
):
if
self
.
forward_dag
is
not
None
:
self
.
forward_dag
.
teardown
()
import
ray
for
worker
in
self
.
workers
:
ray
.
kill
(
worker
)
class
RayGPUExecutorAsync
(
RayGPUExecutor
,
DistributedGPUExecutorAsync
):
class
RayGPUExecutorAsync
(
RayGPUExecutor
,
DistributedGPUExecutorAsync
):
def
__init__
(
self
,
*
args
,
**
kwargs
):
def
__init__
(
self
,
*
args
,
**
kwargs
):
super
().
__init__
(
*
args
,
**
kwargs
)
super
().
__init__
(
*
args
,
**
kwargs
)
self
.
driver_exec_method
=
make_async
(
self
.
driver_worker
.
execute_method
)
self
.
pp_locks
:
Optional
[
List
[
asyncio
.
Lock
]]
=
None
self
.
use_ray_spmd_worker
=
envs
.
VLLM_USE_RAY_SPMD_WORKER
if
not
self
.
use_ray_compiled_dag
:
self
.
driver_exec_method
=
make_async
(
self
.
driver_worker
.
execute_method
)
async
def
execute_model_async
(
self
,
execute_model_req
:
ExecuteModelRequest
)
->
List
[
SamplerOutput
]:
if
not
self
.
use_ray_spmd_worker
:
return
await
super
().
execute_model_async
(
execute_model_req
)
if
self
.
forward_dag
is
None
:
self
.
forward_dag
=
self
.
_compiled_ray_dag
(
enable_asyncio
=
True
)
dag_future
=
await
self
.
forward_dag
.
execute_async
(
execute_model_req
)
outputs
=
await
dag_future
return
outputs
[
0
]
async
def
_driver_execute_model_async
(
async
def
_driver_execute_model_async
(
self
,
self
,
execute_model_req
:
Optional
[
ExecuteModelRequest
]
=
None
execute_model_req
:
Optional
[
ExecuteModelRequest
]
=
None
)
->
List
[
SamplerOutput
]:
)
->
List
[
SamplerOutput
]:
assert
not
self
.
use_ray_spmd_worker
,
(
"driver_worker does not exist for VLLM_USE_RAY_SPMD_WORKER=1"
)
if
not
self
.
tp_driver_workers
:
return
await
self
.
driver_exec_method
(
"execute_model"
,
execute_model_req
)
if
self
.
pp_locks
is
None
:
if
self
.
pp_locks
is
None
:
# This locks each pipeline parallel stage so multiple virtual
# This locks each pipeline parallel stage so multiple virtual
# engines can't execute on the same stage at the same time
# engines can't execute on the same stage at the same time
...
@@ -378,15 +442,11 @@ class RayGPUExecutorAsync(RayGPUExecutor, DistributedGPUExecutorAsync):
...
@@ -378,15 +442,11 @@ class RayGPUExecutorAsync(RayGPUExecutor, DistributedGPUExecutorAsync):
for
_
in
range
(
self
.
parallel_config
.
pipeline_parallel_size
)
for
_
in
range
(
self
.
parallel_config
.
pipeline_parallel_size
)
]
]
async
def
_run_task_with_lock
(
task
,
lock
,
*
args
,
**
kwargs
):
tasks
=
[
async
with
lock
:
return
await
task
(
*
args
,
**
kwargs
)
tasks
=
[]
tasks
.
append
(
asyncio
.
create_task
(
asyncio
.
create_task
(
_run_task_with_lock
(
self
.
driver_exec_method
,
self
.
pp_locks
[
0
],
_run_task_with_lock
(
self
.
driver_exec_method
,
self
.
pp_locks
[
0
],
"execute_model"
,
execute_model_req
)))
"execute_model"
,
execute_model_req
))
]
for
pp_rank
,
driver_worker
in
enumerate
(
self
.
tp_driver_workers
,
for
pp_rank
,
driver_worker
in
enumerate
(
self
.
tp_driver_workers
,
start
=
1
):
start
=
1
):
tasks
.
append
(
tasks
.
append
(
...
@@ -401,8 +461,17 @@ class RayGPUExecutorAsync(RayGPUExecutor, DistributedGPUExecutorAsync):
...
@@ -401,8 +461,17 @@ class RayGPUExecutorAsync(RayGPUExecutor, DistributedGPUExecutorAsync):
return
results
[
-
1
]
return
results
[
-
1
]
async
def
_start_worker_execution_loop
(
self
):
async
def
_start_worker_execution_loop
(
self
):
assert
not
self
.
use_ray_spmd_worker
,
(
"worker loop is disabled for VLLM_USE_RAY_SPMD_WORKER=1"
)
coros
=
[
coros
=
[
worker
.
execute_method
.
remote
(
"start_worker_execution_loop"
)
worker
.
execute_method
.
remote
(
"start_worker_execution_loop"
)
for
worker
in
self
.
non_driver_workers
for
worker
in
self
.
non_driver_workers
]
]
return
await
asyncio
.
gather
(
*
coros
)
return
await
asyncio
.
gather
(
*
coros
)
def
__del__
(
self
):
if
self
.
forward_dag
is
not
None
:
self
.
forward_dag
.
teardown
()
import
ray
for
worker
in
self
.
workers
:
ray
.
kill
(
worker
)
vllm/executor/ray_utils.py
View file @
500b93c8
import
pickle
from
typing
import
List
,
Optional
,
Tuple
from
typing
import
List
,
Optional
,
Tuple
from
vllm.config
import
ParallelConfig
from
vllm.config
import
ParallelConfig
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.sequence
import
ExecuteModelRequest
from
vllm.utils
import
get_ip
,
is_hip
,
is_xpu
from
vllm.utils
import
get_ip
,
is_hip
,
is_xpu
from
vllm.worker.worker_base
import
WorkerWrapperBase
from
vllm.worker.worker_base
import
WorkerWrapperBase
...
@@ -31,16 +31,18 @@ try:
...
@@ -31,16 +31,18 @@ try:
gpu_ids
=
ray
.
get_gpu_ids
()
gpu_ids
=
ray
.
get_gpu_ids
()
return
node_id
,
gpu_ids
return
node_id
,
gpu_ids
def
execute_model_compiled_dag_remote
(
self
,
ignored
):
def
execute_model_spmd
(
self
,
execute_model_req
:
ExecuteModelRequest
):
"""Used only when compiled DAG is enabled."""
"""Used only when SPMD worker and compiled DAG are both
enabled."""
# TODO(swang): This is needed right now because Ray aDAG executes
# on a background thread, so we need to reset torch's current
# device.
import
torch
import
torch
if
not
self
.
compiled_dag_cuda_device_set
:
if
not
self
.
compiled_dag_cuda_device_set
:
torch
.
cuda
.
set_device
(
self
.
worker
.
device
)
torch
.
cuda
.
set_device
(
self
.
worker
.
device
)
self
.
compiled_dag_cuda_device_set
=
True
self
.
compiled_dag_cuda_device_set
=
True
output
=
self
.
worker
.
execute_model
()
return
self
.
worker
.
_execute_model_spmd
(
execute_model_req
)
output
=
pickle
.
dumps
(
output
)
return
output
ray_import_err
=
None
ray_import_err
=
None
...
...
vllm/executor/ray_xpu_executor.py
View file @
500b93c8
import
asyncio
import
asyncio
import
os
import
os
import
pickle
from
collections
import
defaultdict
from
collections
import
defaultdict
from
itertools
import
islice
,
repeat
from
itertools
import
islice
,
repeat
from
typing
import
(
TYPE_CHECKING
,
Any
,
Awaitable
,
Dict
,
List
,
Optional
,
Set
,
from
typing
import
(
TYPE_CHECKING
,
Any
,
Awaitable
,
Dict
,
List
,
Optional
,
Set
,
Tuple
,
Union
)
Tuple
,
Union
)
import
vllm.envs
as
envs
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
LoadConfig
,
LoRAConfig
,
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
LoadConfig
,
LoRAConfig
,
ModelConfig
,
MultiModalConfig
,
ParallelConfig
,
ModelConfig
,
MultiModalConfig
,
ParallelConfig
,
PromptAdapterConfig
,
SchedulerConfig
,
PromptAdapterConfig
,
SchedulerConfig
,
...
@@ -30,11 +30,13 @@ logger = init_logger(__name__)
...
@@ -30,11 +30,13 @@ logger = init_logger(__name__)
# If the env var is set, it uses the Ray's compiled DAG API
# If the env var is set, it uses the Ray's compiled DAG API
# which optimizes the control plane overhead.
# which optimizes the control plane overhead.
# Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it.
# Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it.
USE_RAY_COMPILED_DAG
=
bool
(
os
.
get
env
(
"
VLLM_USE_RAY_COMPILED_DAG
"
,
0
))
USE_RAY_COMPILED_DAG
=
env
s
.
VLLM_USE_RAY_COMPILED_DAG
class
RayXPUExecutor
(
DistributedGPUExecutor
):
class
RayXPUExecutor
(
DistributedGPUExecutor
):
uses_ray
:
bool
=
True
def
__init__
(
def
__init__
(
self
,
self
,
model_config
:
ModelConfig
,
model_config
:
ModelConfig
,
...
@@ -72,10 +74,9 @@ class RayXPUExecutor(DistributedGPUExecutor):
...
@@ -72,10 +74,9 @@ class RayXPUExecutor(DistributedGPUExecutor):
# Create the parallel GPU workers.
# Create the parallel GPU workers.
self
.
_init_workers_ray
(
placement_group
)
self
.
_init_workers_ray
(
placement_group
)
# Profile the memory usage and initialize the cache.
self
.
forward_dag
=
None
self
.
forward_dag
=
None
if
USE_RAY_COMPILED_DAG
:
if
USE_RAY_COMPILED_DAG
:
self
.
forward_dag
=
self
.
_compiled_ray_dag
()
self
.
forward_dag
=
self
.
_compiled_ray_dag
(
enable_asyncio
=
False
)
# This is non-None when the execute model loop is running
# This is non-None when the execute model loop is running
# in the parallel workers. It's a coroutine in the AsyncLLMEngine case.
# in the parallel workers. It's a coroutine in the AsyncLLMEngine case.
...
@@ -108,6 +109,13 @@ class RayXPUExecutor(DistributedGPUExecutor):
...
@@ -108,6 +109,13 @@ class RayXPUExecutor(DistributedGPUExecutor):
return
num_gpu_blocks
,
num_cpu_blocks
return
num_gpu_blocks
,
num_cpu_blocks
def
_get_worker_wrapper_args
(
self
)
->
Dict
[
str
,
Any
]:
return
dict
(
worker_module_name
=
"vllm.worker.xpu_worker"
,
worker_class_name
=
"XPUWorker"
,
trust_remote_code
=
self
.
model_config
.
trust_remote_code
,
)
def
_init_workers_ray
(
self
,
placement_group
:
"PlacementGroup"
,
def
_init_workers_ray
(
self
,
placement_group
:
"PlacementGroup"
,
**
ray_remote_kwargs
):
**
ray_remote_kwargs
):
if
self
.
parallel_config
.
tensor_parallel_size
==
1
:
if
self
.
parallel_config
.
tensor_parallel_size
==
1
:
...
@@ -125,6 +133,7 @@ class RayXPUExecutor(DistributedGPUExecutor):
...
@@ -125,6 +133,7 @@ class RayXPUExecutor(DistributedGPUExecutor):
# Create the workers.
# Create the workers.
driver_ip
=
get_ip
()
driver_ip
=
get_ip
()
worker_wrapper_kwargs
=
self
.
_get_worker_wrapper_args
()
for
bundle_id
,
bundle
in
enumerate
(
placement_group
.
bundle_specs
):
for
bundle_id
,
bundle
in
enumerate
(
placement_group
.
bundle_specs
):
if
not
bundle
.
get
(
"GPU"
,
0
):
if
not
bundle
.
get
(
"GPU"
,
0
):
continue
continue
...
@@ -138,22 +147,14 @@ class RayXPUExecutor(DistributedGPUExecutor):
...
@@ -138,22 +147,14 @@ class RayXPUExecutor(DistributedGPUExecutor):
num_gpus
=
num_gpus
,
num_gpus
=
num_gpus
,
scheduling_strategy
=
scheduling_strategy
,
scheduling_strategy
=
scheduling_strategy
,
**
ray_remote_kwargs
,
**
ray_remote_kwargs
,
)(
RayWorkerWrapper
).
remote
(
)(
RayWorkerWrapper
).
remote
(
**
worker_wrapper_kwargs
)
worker_module_name
=
"vllm.worker.xpu_worker"
,
worker_class_name
=
"XPUWorker"
,
trust_remote_code
=
self
.
model_config
.
trust_remote_code
,
)
worker_ip
=
ray
.
get
(
worker
.
get_node_ip
.
remote
())
worker_ip
=
ray
.
get
(
worker
.
get_node_ip
.
remote
())
if
worker_ip
==
driver_ip
and
self
.
driver_dummy_worker
is
None
:
if
worker_ip
==
driver_ip
and
self
.
driver_dummy_worker
is
None
:
# If the worker is on the same node as the driver, we use it
# If the worker is on the same node as the driver, we use it
# as the resource holder for the driver process.
# as the resource holder for the driver process.
self
.
driver_dummy_worker
=
worker
self
.
driver_dummy_worker
=
worker
self
.
driver_worker
=
RayWorkerWrapper
(
self
.
driver_worker
=
RayWorkerWrapper
(
**
worker_wrapper_kwargs
)
worker_module_name
=
"vllm.worker.xpu_worker"
,
worker_class_name
=
"XPUWorker"
,
trust_remote_code
=
self
.
model_config
.
trust_remote_code
,
)
else
:
else
:
# Else, added to the list of workers.
# Else, added to the list of workers.
self
.
workers
.
append
(
worker
)
self
.
workers
.
append
(
worker
)
...
@@ -270,7 +271,6 @@ class RayXPUExecutor(DistributedGPUExecutor):
...
@@ -270,7 +271,6 @@ class RayXPUExecutor(DistributedGPUExecutor):
all_kwargs
:
Optional
[
List
[
Dict
[
str
,
Any
]]]
=
None
,
all_kwargs
:
Optional
[
List
[
Dict
[
str
,
Any
]]]
=
None
,
use_dummy_driver
:
bool
=
False
,
use_dummy_driver
:
bool
=
False
,
max_concurrent_workers
:
Optional
[
int
]
=
None
,
max_concurrent_workers
:
Optional
[
int
]
=
None
,
use_ray_compiled_dag
:
bool
=
False
,
**
kwargs
,
**
kwargs
,
)
->
Any
:
)
->
Any
:
"""Runs the given method on all workers. Can be used in the following
"""Runs the given method on all workers. Can be used in the following
...
@@ -293,26 +293,20 @@ class RayXPUExecutor(DistributedGPUExecutor):
...
@@ -293,26 +293,20 @@ class RayXPUExecutor(DistributedGPUExecutor):
all_worker_kwargs
=
repeat
(
kwargs
,
count
)
if
all_kwargs
is
None
\
all_worker_kwargs
=
repeat
(
kwargs
,
count
)
if
all_kwargs
is
None
\
else
islice
(
all_kwargs
,
1
,
None
)
else
islice
(
all_kwargs
,
1
,
None
)
if
use_ray_compiled_dag
:
# Start the ray workers first.
# Right now, compiled DAG can only accept a single
ray_worker_outputs
=
[
# input. TODO(sang): Fix it.
worker
.
execute_method
.
remote
(
method
,
*
worker_args
,
**
worker_kwargs
)
assert
self
.
forward_dag
is
not
None
for
(
worker
,
worker_args
,
worker_kwargs
output_channels
=
self
.
forward_dag
.
execute
(
1
)
)
in
zip
(
self
.
workers
,
all_worker_args
,
all_worker_kwargs
)
else
:
]
# Start the ray workers first.
ray_worker_outputs
=
[
worker
.
execute_method
.
remote
(
method
,
*
worker_args
,
**
worker_kwargs
)
for
(
worker
,
worker_args
,
worker_kwargs
)
in
zip
(
self
.
workers
,
all_worker_args
,
all_worker_kwargs
)
]
if
async_run_remote_workers_only
:
if
async_run_remote_workers_only
:
# Just return futures
# Just return futures
return
ray_worker_outputs
return
ray_worker_outputs
driver_worker_output
=
[]
driver_args
=
args
if
all_args
is
None
else
all_args
[
0
]
driver_args
=
args
if
all_args
is
None
else
all_args
[
0
]
driver_kwargs
=
kwargs
if
all_kwargs
is
None
else
all_kwargs
[
0
]
driver_kwargs
=
kwargs
if
all_kwargs
is
None
else
all_kwargs
[
0
]
# Start the driver worker after all the ray workers.
# Start the driver worker after all the ray workers.
if
not
use_dummy_driver
:
if
not
use_dummy_driver
:
driver_worker_output
=
self
.
driver_worker
.
execute_method
(
driver_worker_output
=
self
.
driver_worker
.
execute_method
(
...
@@ -324,36 +318,28 @@ class RayXPUExecutor(DistributedGPUExecutor):
...
@@ -324,36 +318,28 @@ class RayXPUExecutor(DistributedGPUExecutor):
method
,
*
driver_args
,
**
driver_kwargs
))
method
,
*
driver_args
,
**
driver_kwargs
))
# Get the results of the ray workers.
# Get the results of the ray workers.
if
self
.
workers
:
if
self
.
workers
:
if
use_ray_compiled_dag
:
ray_worker_outputs
=
ray
.
get
(
ray_worker_outputs
)
try
:
ray_worker_outputs
=
[
pickle
.
loads
(
chan
.
begin_read
())
for
chan
in
output_channels
]
finally
:
# Has to call end_read in order to reuse the DAG.
for
chan
in
output_channels
:
chan
.
end_read
()
else
:
ray_worker_outputs
=
ray
.
get
(
ray_worker_outputs
)
return
[
driver_worker_output
]
+
ray_worker_outputs
return
driver_worker_output
+
ray_worker_outputs
def
_wait_for_tasks_completion
(
self
,
parallel_worker_tasks
:
Any
)
->
None
:
def
_wait_for_tasks_completion
(
self
,
parallel_worker_tasks
:
Any
)
->
None
:
"""Wait for futures returned from _run_workers() with
"""Wait for futures returned from _run_workers() with
async_run_remote_workers_only to complete."""
async_run_remote_workers_only to complete."""
ray
.
get
(
parallel_worker_tasks
)
ray
.
get
(
parallel_worker_tasks
)
def
_compiled_ray_dag
(
self
):
def
_compiled_ray_dag
(
self
,
enable_asyncio
:
bool
):
import
pkg_resources
import
pkg_resources
required_version
=
"2.9"
from
packaging
import
version
current_version
=
pkg_resources
.
get_distribution
(
"ray"
).
version
required_version
=
version
.
parse
(
"2.32"
)
current_version
=
version
.
parse
(
pkg_resources
.
get_distribution
(
"ray"
).
version
)
if
current_version
<
required_version
:
if
current_version
<
required_version
:
raise
ValueError
(
f
"Ray version
{
required_version
}
or greater is "
raise
ValueError
(
f
"Ray version
{
required_version
}
or greater is "
f
"required, but found
{
current_version
}
"
)
f
"required, but found
{
current_version
}
"
)
from
ray.dag
import
InputNode
,
MultiOutputNode
from
ray.dag
import
InputNode
,
MultiOutputNode
assert
self
.
parallel_config
.
worker_
use_ray
assert
self
.
parallel_config
.
use_ray
# Right now, compiled DAG requires at least 1 arg. We send
# Right now, compiled DAG requires at least 1 arg. We send
# a dummy value for now. It will be fixed soon.
# a dummy value for now. It will be fixed soon.
...
@@ -363,7 +349,7 @@ class RayXPUExecutor(DistributedGPUExecutor):
...
@@ -363,7 +349,7 @@ class RayXPUExecutor(DistributedGPUExecutor):
bind
(
# type: ignore[attr-defined]
bind
(
# type: ignore[attr-defined]
input_data
)
for
worker
in
self
.
workers
input_data
)
for
worker
in
self
.
workers
])
])
return
forward_dag
.
experimental_compile
()
return
forward_dag
.
experimental_compile
(
enable_asyncio
=
enable_asyncio
)
def
check_health
(
self
)
->
None
:
def
check_health
(
self
)
->
None
:
"""Raises an error if engine is unhealthy."""
"""Raises an error if engine is unhealthy."""
...
...
vllm/executor/tpu_executor.py
View file @
500b93c8
...
@@ -14,6 +14,8 @@ logger = init_logger(__name__)
...
@@ -14,6 +14,8 @@ logger = init_logger(__name__)
class
TPUExecutor
(
ExecutorBase
):
class
TPUExecutor
(
ExecutorBase
):
uses_ray
:
bool
=
False
def
_init_executor
(
self
)
->
None
:
def
_init_executor
(
self
)
->
None
:
assert
not
self
.
scheduler_config
.
chunked_prefill_enabled
,
(
assert
not
self
.
scheduler_config
.
chunked_prefill_enabled
,
(
"Chunked prefill is not yet supported for TPU backend"
)
"Chunked prefill is not yet supported for TPU backend"
)
...
...
vllm/executor/xpu_executor.py
View file @
500b93c8
...
@@ -18,6 +18,8 @@ logger = init_logger(__name__)
...
@@ -18,6 +18,8 @@ logger = init_logger(__name__)
class
XPUExecutor
(
GPUExecutor
):
class
XPUExecutor
(
GPUExecutor
):
uses_ray
:
bool
=
False
def
__init__
(
def
__init__
(
self
,
self
,
model_config
:
ModelConfig
,
model_config
:
ModelConfig
,
...
...
vllm/inputs/__init__.py
View file @
500b93c8
from
.data
import
(
LLMInputs
,
ParsedText
,
ParsedTokens
,
PromptInputs
,
from
.data
import
(
LLMInputs
,
ParsedText
,
ParsedTokens
,
PromptInputs
,
PromptStrictInputs
,
TextPrompt
,
TextTokensPrompt
,
TextPrompt
,
TokensPrompt
,
parse_and_batch_prompt
)
TokensPrompt
,
parse_and_batch_prompt
)
from
.registry
import
InputContext
,
InputRegistry
from
.registry
import
InputContext
,
InputRegistry
INPUT_REGISTRY
=
InputRegistry
()
INPUT_REGISTRY
=
InputRegistry
()
...
@@ -14,6 +13,6 @@ See also:
...
@@ -14,6 +13,6 @@ See also:
__all__
=
[
__all__
=
[
"ParsedText"
,
"ParsedTokens"
,
"parse_and_batch_prompt"
,
"TextPrompt"
,
"ParsedText"
,
"ParsedTokens"
,
"parse_and_batch_prompt"
,
"TextPrompt"
,
"TokensPrompt"
,
"
TextTokensPrompt"
,
"PromptStrictInputs"
,
"PromptInputs
"
,
"TokensPrompt"
,
"
PromptInputs"
,
"LLMInputs"
,
"INPUT_REGISTRY
"
,
"LLMInputs"
,
"INPUT_REGISTRY"
,
"InputContext"
,
"InputRegistry"
"InputContext"
,
"InputRegistry"
]
]
vllm/inputs/data.py
View file @
500b93c8
...
@@ -92,25 +92,7 @@ class TokensPrompt(TypedDict):
...
@@ -92,25 +92,7 @@ class TokensPrompt(TypedDict):
"""
"""
class
TextTokensPrompt
(
TypedDict
):
PromptInputs
=
Union
[
str
,
TextPrompt
,
TokensPrompt
]
"""It is assumed that :attr:`prompt` is consistent with
:attr:`prompt_token_ids`. This is currently used in
:class:`AsyncLLMEngine` for logging both the text and token IDs."""
prompt
:
str
"""The prompt text."""
prompt_token_ids
:
List
[
int
]
"""The token IDs of the prompt."""
multi_modal_data
:
NotRequired
[
"MultiModalDataDict"
]
"""
Optional multi-modal data to pass to the model,
if the model supports it.
"""
PromptStrictInputs
=
Union
[
str
,
TextPrompt
,
TokensPrompt
]
"""
"""
The inputs to the LLM, which can take one of the following forms:
The inputs to the LLM, which can take one of the following forms:
...
@@ -118,10 +100,6 @@ The inputs to the LLM, which can take one of the following forms:
...
@@ -118,10 +100,6 @@ The inputs to the LLM, which can take one of the following forms:
- A tokenized prompt (:class:`TokensPrompt`)
- A tokenized prompt (:class:`TokensPrompt`)
"""
"""
PromptInputs
=
Union
[
str
,
TextPrompt
,
TokensPrompt
,
TextTokensPrompt
]
"""Same as :const:`PromptStrictInputs` but additionally accepts
:class:`TextTokensPrompt`."""
class
LLMInputs
(
TypedDict
):
class
LLMInputs
(
TypedDict
):
"""
"""
...
...
vllm/lora/request.py
View file @
500b93c8
from
dataclasses
import
dataclass
import
warnings
from
dataclasses
import
dataclass
,
field
from
typing
import
Optional
from
typing
import
Optional
from
vllm.adapter_commons.request
import
AdapterRequest
from
vllm.adapter_commons.request
import
AdapterRequest
...
@@ -20,10 +21,25 @@ class LoRARequest(AdapterRequest):
...
@@ -20,10 +21,25 @@ class LoRARequest(AdapterRequest):
lora_name
:
str
lora_name
:
str
lora_int_id
:
int
lora_int_id
:
int
lora_local_path
:
str
lora_path
:
str
=
""
lora_local_path
:
Optional
[
str
]
=
field
(
default
=
None
,
repr
=
False
)
long_lora_max_len
:
Optional
[
int
]
=
None
long_lora_max_len
:
Optional
[
int
]
=
None
__hash__
=
AdapterRequest
.
__hash__
__hash__
=
AdapterRequest
.
__hash__
def
__post_init__
(
self
):
if
'lora_local_path'
in
self
.
__dict__
:
warnings
.
warn
(
"The 'lora_local_path' attribute is deprecated "
"and will be removed in a future version. "
"Please use 'lora_path' instead."
,
DeprecationWarning
,
stacklevel
=
2
)
if
not
self
.
lora_path
:
self
.
lora_path
=
self
.
lora_local_path
or
""
# Ensure lora_path is not empty
assert
self
.
lora_path
,
"lora_path cannot be empty"
@
property
@
property
def
adapter_id
(
self
):
def
adapter_id
(
self
):
return
self
.
lora_int_id
return
self
.
lora_int_id
...
@@ -32,6 +48,26 @@ class LoRARequest(AdapterRequest):
...
@@ -32,6 +48,26 @@ class LoRARequest(AdapterRequest):
def
name
(
self
):
def
name
(
self
):
return
self
.
lora_name
return
self
.
lora_name
@
property
def
path
(
self
):
return
self
.
lora_path
@
property
@
property
def
local_path
(
self
):
def
local_path
(
self
):
return
self
.
lora_local_path
warnings
.
warn
(
"The 'local_path' attribute is deprecated "
"and will be removed in a future version. "
"Please use 'path' instead."
,
DeprecationWarning
,
stacklevel
=
2
)
return
self
.
lora_path
@
local_path
.
setter
def
local_path
(
self
,
value
):
warnings
.
warn
(
"The 'local_path' attribute is deprecated "
"and will be removed in a future version. "
"Please use 'path' instead."
,
DeprecationWarning
,
stacklevel
=
2
)
self
.
lora_path
=
value
vllm/lora/utils.py
View file @
500b93c8
import
os
from
typing
import
List
,
Optional
,
Set
,
Tuple
,
Type
from
typing
import
List
,
Optional
,
Set
,
Tuple
,
Type
import
huggingface_hub
from
huggingface_hub.utils
import
(
EntryNotFoundError
,
HfHubHTTPError
,
HFValidationError
,
RepositoryNotFoundError
)
from
torch
import
nn
from
torch
import
nn
from
transformers
import
PretrainedConfig
from
transformers
import
PretrainedConfig
...
@@ -105,3 +109,46 @@ def parse_fine_tuned_lora_name(name: str) -> Tuple[str, bool]:
...
@@ -105,3 +109,46 @@ def parse_fine_tuned_lora_name(name: str) -> Tuple[str, bool]:
return
"."
.
join
(
parts
[
2
:
-
1
]),
parts
[
-
1
]
==
"lora_embedding_A"
return
"."
.
join
(
parts
[
2
:
-
1
]),
parts
[
-
1
]
==
"lora_embedding_A"
raise
ValueError
(
f
"
{
name
}
is unsupported LoRA weight"
)
raise
ValueError
(
f
"
{
name
}
is unsupported LoRA weight"
)
def
get_adapter_absolute_path
(
lora_path
:
str
)
->
str
:
"""
Resolves the given lora_path to an absolute local path.
If the lora_path is identified as a Hugging Face model identifier,
it will download the model and return the local snapshot path.
Otherwise, it treats the lora_path as a local file path and
converts it to an absolute path.
Parameters:
lora_path (str): The path to the lora model, which can be an absolute path,
a relative path, or a Hugging Face model identifier.
Returns:
str: The resolved absolute local path to the lora model.
"""
# Check if the path is an absolute path. Return it no matter exists or not.
if
os
.
path
.
isabs
(
lora_path
):
return
lora_path
# If the path starts with ~, expand the user home directory.
if
lora_path
.
startswith
(
'~'
):
return
os
.
path
.
expanduser
(
lora_path
)
# Check if the expanded relative path exists locally.
if
os
.
path
.
exists
(
lora_path
):
return
os
.
path
.
abspath
(
lora_path
)
# If the path does not exist locally, assume it's a Hugging Face repo.
try
:
local_snapshot_path
=
huggingface_hub
.
snapshot_download
(
repo_id
=
lora_path
)
except
(
HfHubHTTPError
,
RepositoryNotFoundError
,
EntryNotFoundError
,
HFValidationError
):
# Handle errors that may occur during the download
# Return original path instead instead of throwing error here
logger
.
exception
(
"Error downloading the HuggingFace model"
)
return
lora_path
return
local_snapshot_path
vllm/lora/worker_manager.py
View file @
500b93c8
...
@@ -13,6 +13,7 @@ from vllm.logger import init_logger
...
@@ -13,6 +13,7 @@ from vllm.logger import init_logger
from
vllm.lora.models
import
(
LoRAModel
,
LoRAModelManager
,
from
vllm.lora.models
import
(
LoRAModel
,
LoRAModelManager
,
LRUCacheLoRAModelManager
,
create_lora_manager
)
LRUCacheLoRAModelManager
,
create_lora_manager
)
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.utils
import
get_adapter_absolute_path
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -89,8 +90,9 @@ class WorkerLoRAManager(AbstractWorkerManager):
...
@@ -89,8 +90,9 @@ class WorkerLoRAManager(AbstractWorkerManager):
packed_modules_mapping
[
module
])
packed_modules_mapping
[
module
])
else
:
else
:
expected_lora_modules
.
append
(
module
)
expected_lora_modules
.
append
(
module
)
lora_path
=
get_adapter_absolute_path
(
lora_request
.
lora_path
)
lora
=
self
.
_lora_model_cls
.
from_local_checkpoint
(
lora
=
self
.
_lora_model_cls
.
from_local_checkpoint
(
lora_
request
.
lora_local_
path
,
lora_path
,
expected_lora_modules
,
expected_lora_modules
,
max_position_embeddings
=
self
.
max_position_embeddings
,
max_position_embeddings
=
self
.
max_position_embeddings
,
lora_model_id
=
lora_request
.
lora_int_id
,
lora_model_id
=
lora_request
.
lora_int_id
,
...
@@ -102,8 +104,7 @@ class WorkerLoRAManager(AbstractWorkerManager):
...
@@ -102,8 +104,7 @@ class WorkerLoRAManager(AbstractWorkerManager):
embedding_padding_modules
=
self
.
embedding_padding_modules
,
embedding_padding_modules
=
self
.
embedding_padding_modules
,
)
)
except
Exception
as
e
:
except
Exception
as
e
:
raise
RuntimeError
(
raise
RuntimeError
(
f
"Loading lora
{
lora_path
}
failed"
)
from
e
f
"Loading lora
{
lora_request
.
lora_local_path
}
failed"
)
from
e
if
lora
.
rank
>
self
.
lora_config
.
max_lora_rank
:
if
lora
.
rank
>
self
.
lora_config
.
max_lora_rank
:
raise
ValueError
(
raise
ValueError
(
f
"LoRA rank
{
lora
.
rank
}
is greater than max_lora_rank "
f
"LoRA rank
{
lora
.
rank
}
is greater than max_lora_rank "
...
...
vllm/model_executor/layers/fused_moe/layer.py
View file @
500b93c8
...
@@ -7,7 +7,7 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
...
@@ -7,7 +7,7 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size
,
get_tensor_model_parallel_world_size
,
tensor_model_parallel_all_reduce
)
tensor_model_parallel_all_reduce
)
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.
layers.fused_moe.fused_moe
import
f
us
ed_moe
from
vllm.model_executor.
custom_op
import
C
us
tomOp
from
vllm.model_executor.layers.quantization.base_config
import
(
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
,
QuantizeMethodBase
)
QuantizationConfig
,
QuantizeMethodBase
)
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.model_executor.utils
import
set_weight_attrs
...
@@ -36,7 +36,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
...
@@ -36,7 +36,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
raise
NotImplementedError
raise
NotImplementedError
class
UnquantizedFusedMoEMethod
(
FusedMoEMethodBase
):
class
UnquantizedFusedMoEMethod
(
FusedMoEMethodBase
,
CustomOp
):
"""MoE method without quantization."""
"""MoE method without quantization."""
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
num_experts
:
int
,
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
num_experts
:
int
,
...
@@ -61,19 +61,37 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase):
...
@@ -61,19 +61,37 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase):
layer
.
register_parameter
(
"w2_weight"
,
w2_weight
)
layer
.
register_parameter
(
"w2_weight"
,
w2_weight
)
set_weight_attrs
(
w2_weight
,
extra_weight_attrs
)
set_weight_attrs
(
w2_weight
,
extra_weight_attrs
)
def
apply
(
self
,
def
apply
(
layer
:
torch
.
nn
.
Module
,
self
,
x
:
torch
.
Tensor
,
layer
:
torch
.
nn
.
Module
,
router_logits
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
top_k
:
int
,
router_logits
:
torch
.
Tensor
,
renormalize
:
bool
=
True
,
top_k
:
int
,
use_grouped_topk
:
bool
=
False
,
renormalize
:
bool
=
True
,
num_expert_group
:
Optional
[
int
]
=
None
,
use_grouped_topk
:
bool
=
False
,
topk_group
:
Optional
[
int
]
=
None
)
->
torch
.
Tensor
:
num_expert_group
:
Optional
[
int
]
=
None
,
topk_group
:
Optional
[
int
]
=
None
,
)
->
torch
.
Tensor
:
return
self
.
forward
(
x
,
layer
.
w13_weight
,
layer
.
w2_weight
,
router_logits
,
top_k
,
renormalize
,
use_grouped_topk
,
num_expert_group
,
topk_group
)
def
forward_cuda
(
self
,
x
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
top_k
:
int
,
renormalize
:
bool
,
use_grouped_topk
:
bool
,
num_expert_group
:
Optional
[
int
],
topk_group
:
Optional
[
int
],
)
->
torch
.
Tensor
:
from
vllm.model_executor.layers.fused_moe.fused_moe
import
fused_moe
return
fused_moe
(
x
,
return
fused_moe
(
x
,
layer
.
w13_weight
,
w1
,
layer
.
w2_weight
,
w2
,
router_logits
,
router_logits
,
top_k
,
top_k
,
renormalize
=
renormalize
,
renormalize
=
renormalize
,
...
@@ -82,6 +100,28 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase):
...
@@ -82,6 +100,28 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase):
num_expert_group
=
num_expert_group
,
num_expert_group
=
num_expert_group
,
topk_group
=
topk_group
)
topk_group
=
topk_group
)
def
forward_cpu
(
self
,
*
args
,
**
kwargs
):
raise
NotImplementedError
(
"The CPU backend currently does not support MoE."
)
def
forward_tpu
(
self
,
x
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
top_k
:
int
,
renormalize
:
bool
,
use_grouped_topk
:
bool
,
num_expert_group
:
Optional
[
int
],
topk_group
:
Optional
[
int
],
)
->
torch
.
Tensor
:
from
vllm.model_executor.layers.fused_moe.moe_pallas
import
fused_moe
assert
not
use_grouped_topk
assert
num_expert_group
is
None
assert
topk_group
is
None
return
fused_moe
(
x
,
w1
,
w2
,
router_logits
,
top_k
,
renormalize
)
class
FusedMoE
(
torch
.
nn
.
Module
):
class
FusedMoE
(
torch
.
nn
.
Module
):
"""FusedMoE layer for MoE models.
"""FusedMoE layer for MoE models.
...
@@ -118,6 +158,7 @@ class FusedMoE(torch.nn.Module):
...
@@ -118,6 +158,7 @@ class FusedMoE(torch.nn.Module):
topk_group
:
Optional
[
int
]
=
None
,
topk_group
:
Optional
[
int
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
tp_size
:
Optional
[
int
]
=
None
,
tp_size
:
Optional
[
int
]
=
None
,
prefix
:
str
=
""
,
):
):
super
().
__init__
()
super
().
__init__
()
...
@@ -141,7 +182,7 @@ class FusedMoE(torch.nn.Module):
...
@@ -141,7 +182,7 @@ class FusedMoE(torch.nn.Module):
self
.
quant_method
:
Optional
[
QuantizeMethodBase
]
=
(
self
.
quant_method
:
Optional
[
QuantizeMethodBase
]
=
(
UnquantizedFusedMoEMethod
())
UnquantizedFusedMoEMethod
())
else
:
else
:
self
.
quant_method
=
quant_config
.
get_quant_method
(
self
)
self
.
quant_method
=
quant_config
.
get_quant_method
(
self
,
prefix
)
assert
self
.
quant_method
is
not
None
assert
self
.
quant_method
is
not
None
self
.
quant_method
.
create_weights
(
self
.
quant_method
.
create_weights
(
...
...
vllm/model_executor/layers/fused_moe/moe_pallas.py
0 → 100644
View file @
500b93c8
import
torch
import
torch.nn.functional
as
F
from
torch_xla.experimental.custom_kernel
import
_histogram
def
fused_moe
(
hidden_states
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
gating_output
:
torch
.
Tensor
,
topk
:
int
,
renormalize
:
bool
,
)
->
torch
.
Tensor
:
"""
Args:
hidden_states: [*, hidden_size]
w1: [num_experts, intermediate_size * 2, hidden_size]
w2: [num_experts, hidden_size, intermediate_size]
gating_output: [*, num_experts]
"""
orig_shape
=
hidden_states
.
shape
hidden_size
=
hidden_states
.
shape
[
-
1
]
num_tokens
=
hidden_states
.
shape
[:
-
1
].
numel
()
num_experts
=
w1
.
shape
[
0
]
intermediate_size
=
w2
.
shape
[
-
1
]
device
=
hidden_states
.
device
dtype
=
hidden_states
.
dtype
assert
(
num_tokens
*
topk
)
%
16
==
0
,
(
"The Pallas GMM kernel requires num_tokens * topk to be a multiple of "
f
"16 but got
{
num_tokens
*
topk
}
"
)
hidden_states
=
hidden_states
.
view
(
num_tokens
,
hidden_size
)
gating_output
=
gating_output
.
view
(
num_tokens
,
num_experts
)
topk_weights
=
gating_output
.
softmax
(
dim
=-
1
,
dtype
=
torch
.
float
)
topk_weights
,
topk_indices
=
topk_weights
.
topk
(
topk
,
dim
=-
1
)
if
renormalize
:
topk_weights
=
topk_weights
/
topk_weights
.
sum
(
dim
=-
1
,
keepdim
=
True
)
topk_weights
=
topk_weights
.
to
(
dtype
)
topk_indices
=
topk_indices
.
flatten
()
topk_argsort_indices
=
topk_indices
.
argsort
()
topk_argsort_revert_indices
=
topk_argsort_indices
.
argsort
()
token_indices
=
torch
.
arange
(
num_tokens
,
device
=
device
).
repeat_interleave
(
topk
)
token_indices
=
token_indices
[
topk_argsort_indices
]
group_sizes
=
_histogram
(
topk_indices
.
to
(
torch
.
int32
),
0
,
num_experts
-
1
)
# NOTE(woosuk): The GMM Pallas kernel requires a different weight layout
# from HF Transformers.
w1
=
w1
.
transpose
(
1
,
2
)
w2
=
w2
.
transpose
(
1
,
2
)
x
=
hidden_states
[
token_indices
]
x
=
torch
.
ops
.
xla
.
gmm
(
x
,
w1
,
group_sizes
)
x
=
F
.
silu
(
x
[...,
:
intermediate_size
])
*
x
[...,
intermediate_size
:]
x
=
torch
.
ops
.
xla
.
gmm
(
x
,
w2
,
group_sizes
)
x
=
x
[
topk_argsort_revert_indices
].
reshape
(
-
1
,
topk
,
hidden_size
)
x
=
x
*
topk_weights
.
unsqueeze_
(
dim
=-
1
)
x
=
x
.
sum
(
dim
=-
2
)
x
=
x
.
reshape
(
orig_shape
)
return
x
vllm/model_executor/layers/linear.py
View file @
500b93c8
...
@@ -160,6 +160,7 @@ class LinearBase(torch.nn.Module):
...
@@ -160,6 +160,7 @@ class LinearBase(torch.nn.Module):
skip_bias_add
:
bool
=
False
,
skip_bias_add
:
bool
=
False
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
):
):
super
().
__init__
()
super
().
__init__
()
...
@@ -174,7 +175,8 @@ class LinearBase(torch.nn.Module):
...
@@ -174,7 +175,8 @@ class LinearBase(torch.nn.Module):
self
.
quant_method
:
Optional
[
self
.
quant_method
:
Optional
[
QuantizeMethodBase
]
=
UnquantizedLinearMethod
()
QuantizeMethodBase
]
=
UnquantizedLinearMethod
()
else
:
else
:
self
.
quant_method
=
quant_config
.
get_quant_method
(
self
)
self
.
quant_method
=
quant_config
.
get_quant_method
(
self
,
prefix
=
prefix
)
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
raise
NotImplementedError
raise
NotImplementedError
...
@@ -190,6 +192,8 @@ class ReplicatedLinear(LinearBase):
...
@@ -190,6 +192,8 @@ class ReplicatedLinear(LinearBase):
skip_bias_add: If true, skip adding bias but instead return it.
skip_bias_add: If true, skip adding bias but instead return it.
params_dtype: Data type for the parameters.
params_dtype: Data type for the parameters.
quant_config: Quantization configure.
quant_config: Quantization configure.
prefix: The name of the layer in the state dict, including all parents
(e.g. model.layers.0.qkv_proj)
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
...
@@ -198,15 +202,23 @@ class ReplicatedLinear(LinearBase):
...
@@ -198,15 +202,23 @@ class ReplicatedLinear(LinearBase):
bias
:
bool
=
True
,
bias
:
bool
=
True
,
skip_bias_add
:
bool
=
False
,
skip_bias_add
:
bool
=
False
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
):
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
super
().
__init__
(
input_size
,
output_size
,
skip_bias_add
,
params_dtype
,
prefix
:
str
=
""
):
quant_config
)
super
().
__init__
(
input_size
,
output_size
,
skip_bias_add
,
params_dtype
,
quant_config
,
prefix
=
prefix
)
# All the linear layer supports quant method.
# All the linear layer supports quant method.
assert
self
.
quant_method
is
not
None
assert
self
.
quant_method
is
not
None
self
.
quant_method
.
create_weights
(
self
,
self
.
input_size
,
self
.
quant_method
.
create_weights
(
self
,
[
self
.
output_size
],
self
.
input_size
,
self
.
input_size
,
[
self
.
output_size
],
self
.
output_size
,
self
.
params_dtype
)
self
.
input_size
,
self
.
output_size
,
self
.
params_dtype
,
prefix
=
prefix
)
if
bias
:
if
bias
:
self
.
bias
=
Parameter
(
self
.
bias
=
Parameter
(
...
@@ -215,6 +227,15 @@ class ReplicatedLinear(LinearBase):
...
@@ -215,6 +227,15 @@ class ReplicatedLinear(LinearBase):
else
:
else
:
self
.
register_parameter
(
"bias"
,
None
)
self
.
register_parameter
(
"bias"
,
None
)
def
weight_loader
(
self
,
param
:
Parameter
,
loaded_weight
:
torch
.
Tensor
):
# If the weight on disk does not have a shape, give it one
# (such scales for AutoFp8).
if
len
(
loaded_weight
.
shape
)
==
0
:
loaded_weight
=
loaded_weight
.
reshape
(
1
)
assert
param
.
size
()
==
loaded_weight
.
size
()
param
.
data
.
copy_
(
loaded_weight
)
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
bias
=
self
.
bias
if
not
self
.
skip_bias_add
else
None
bias
=
self
.
bias
if
not
self
.
skip_bias_add
else
None
assert
self
.
quant_method
is
not
None
assert
self
.
quant_method
is
not
None
...
@@ -249,6 +270,8 @@ class ColumnParallelLinear(LinearBase):
...
@@ -249,6 +270,8 @@ class ColumnParallelLinear(LinearBase):
quant_config: Quantization configure.
quant_config: Quantization configure.
output_sizes: list of output sizes packed into one output, like for QKV
output_sizes: list of output sizes packed into one output, like for QKV
the list would be size 3.
the list would be size 3.
prefix: The name of the layer in the state dict, including all parents
(e.g. model.layers.0.qkv_proj)
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
...
@@ -259,9 +282,10 @@ class ColumnParallelLinear(LinearBase):
...
@@ -259,9 +282,10 @@ class ColumnParallelLinear(LinearBase):
skip_bias_add
:
bool
=
False
,
skip_bias_add
:
bool
=
False
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
output_sizes
:
Optional
[
List
[
int
]]
=
None
):
output_sizes
:
Optional
[
List
[
int
]]
=
None
,
prefix
:
str
=
""
):
super
().
__init__
(
input_size
,
output_size
,
skip_bias_add
,
params_dtype
,
super
().
__init__
(
input_size
,
output_size
,
skip_bias_add
,
params_dtype
,
quant_config
)
quant_config
,
prefix
)
self
.
gather_output
=
gather_output
self
.
gather_output
=
gather_output
...
@@ -286,7 +310,8 @@ class ColumnParallelLinear(LinearBase):
...
@@ -286,7 +310,8 @@ class ColumnParallelLinear(LinearBase):
input_size
=
self
.
input_size
,
input_size
=
self
.
input_size
,
output_size
=
self
.
output_size
,
output_size
=
self
.
output_size
,
params_dtype
=
self
.
params_dtype
,
params_dtype
=
self
.
params_dtype
,
weight_loader
=
self
.
weight_loader
)
weight_loader
=
self
.
weight_loader
,
prefix
=
prefix
)
if
bias
:
if
bias
:
self
.
bias
=
Parameter
(
self
.
bias
=
Parameter
(
torch
.
empty
(
self
.
output_size_per_partition
,
torch
.
empty
(
self
.
output_size_per_partition
,
...
@@ -358,6 +383,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
...
@@ -358,6 +383,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
skip adding bias but instead return it.
skip adding bias but instead return it.
params_dtype: Data type for the parameters.
params_dtype: Data type for the parameters.
quant_config: Quantization configure.
quant_config: Quantization configure.
prefix: The name of the layer in the state dict, including all parents
(e.g. model.layers.0.qkv_proj)
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
...
@@ -367,7 +394,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
...
@@ -367,7 +394,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
gather_output
:
bool
=
False
,
gather_output
:
bool
=
False
,
skip_bias_add
:
bool
=
False
,
skip_bias_add
:
bool
=
False
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
):
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
):
self
.
output_sizes
=
output_sizes
self
.
output_sizes
=
output_sizes
tp_size
=
get_tensor_model_parallel_world_size
()
tp_size
=
get_tensor_model_parallel_world_size
()
assert
all
(
output_size
%
tp_size
==
0
for
output_size
in
output_sizes
)
assert
all
(
output_size
%
tp_size
==
0
for
output_size
in
output_sizes
)
...
@@ -377,7 +405,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
...
@@ -377,7 +405,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
gather_output
=
gather_output
,
gather_output
=
gather_output
,
skip_bias_add
=
skip_bias_add
,
skip_bias_add
=
skip_bias_add
,
params_dtype
=
params_dtype
,
params_dtype
=
params_dtype
,
quant_config
=
quant_config
)
quant_config
=
quant_config
,
prefix
=
prefix
)
def
weight_loader
(
self
,
def
weight_loader
(
self
,
param
:
Parameter
,
param
:
Parameter
,
...
@@ -497,6 +526,8 @@ class QKVParallelLinear(ColumnParallelLinear):
...
@@ -497,6 +526,8 @@ class QKVParallelLinear(ColumnParallelLinear):
skip adding bias but instead return it.
skip adding bias but instead return it.
params_dtype: Data type for the parameters.
params_dtype: Data type for the parameters.
quant_config: Quantization configure.
quant_config: Quantization configure.
prefix: The name of the layer in the state dict, including all parents
(e.g. model.layers.0.qkv_proj)
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
...
@@ -507,7 +538,8 @@ class QKVParallelLinear(ColumnParallelLinear):
...
@@ -507,7 +538,8 @@ class QKVParallelLinear(ColumnParallelLinear):
bias
:
bool
=
True
,
bias
:
bool
=
True
,
skip_bias_add
:
bool
=
False
,
skip_bias_add
:
bool
=
False
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
):
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
):
self
.
hidden_size
=
hidden_size
self
.
hidden_size
=
hidden_size
self
.
head_size
=
head_size
self
.
head_size
=
head_size
self
.
total_num_heads
=
total_num_heads
self
.
total_num_heads
=
total_num_heads
...
@@ -539,7 +571,8 @@ class QKVParallelLinear(ColumnParallelLinear):
...
@@ -539,7 +571,8 @@ class QKVParallelLinear(ColumnParallelLinear):
gather_output
=
False
,
gather_output
=
False
,
skip_bias_add
=
skip_bias_add
,
skip_bias_add
=
skip_bias_add
,
params_dtype
=
params_dtype
,
params_dtype
=
params_dtype
,
quant_config
=
quant_config
)
quant_config
=
quant_config
,
prefix
=
prefix
)
def
weight_loader
(
self
,
def
weight_loader
(
self
,
param
:
Parameter
,
param
:
Parameter
,
...
@@ -698,14 +731,16 @@ class RowParallelLinear(LinearBase):
...
@@ -698,14 +731,16 @@ class RowParallelLinear(LinearBase):
skip_bias_add
:
bool
=
False
,
skip_bias_add
:
bool
=
False
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
reduce_results
:
bool
=
True
,
reduce_results
:
bool
=
True
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
):
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
):
super
().
__init__
(
input_size
,
output_size
,
skip_bias_add
,
params_dtype
,
super
().
__init__
(
input_size
,
output_size
,
skip_bias_add
,
params_dtype
,
quant_config
)
quant_config
,
prefix
)
self
.
input_is_parallel
=
input_is_parallel
self
.
input_is_parallel
=
input_is_parallel
self
.
reduce_results
=
reduce_results
self
.
reduce_results
=
reduce_results
# Divide the weight matrix along the last dimension.
# Divide the weight matrix along the last dimension.
self
.
tp_rank
=
get_tensor_model_parallel_rank
()
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
input_size_per_partition
=
divide
(
input_size
,
self
.
tp_size
)
self
.
input_size_per_partition
=
divide
(
input_size
,
self
.
tp_size
)
assert
self
.
quant_method
is
not
None
assert
self
.
quant_method
is
not
None
...
@@ -716,7 +751,8 @@ class RowParallelLinear(LinearBase):
...
@@ -716,7 +751,8 @@ class RowParallelLinear(LinearBase):
input_size
=
self
.
input_size
,
input_size
=
self
.
input_size
,
output_size
=
self
.
output_size
,
output_size
=
self
.
output_size
,
params_dtype
=
self
.
params_dtype
,
params_dtype
=
self
.
params_dtype
,
weight_loader
=
self
.
weight_loader
)
weight_loader
=
self
.
weight_loader
,
prefix
=
prefix
)
if
not
reduce_results
and
(
bias
and
not
skip_bias_add
):
if
not
reduce_results
and
(
bias
and
not
skip_bias_add
):
raise
ValueError
(
"When not reduce the results, adding bias to the "
raise
ValueError
(
"When not reduce the results, adding bias to the "
"results can lead to incorrect results"
)
"results can lead to incorrect results"
)
...
@@ -760,18 +796,19 @@ class RowParallelLinear(LinearBase):
...
@@ -760,18 +796,19 @@ class RowParallelLinear(LinearBase):
# Matrix multiply.
# Matrix multiply.
assert
self
.
quant_method
is
not
None
assert
self
.
quant_method
is
not
None
output_parallel
=
self
.
quant_method
.
apply
(
self
,
input_parallel
)
# Only fuse bias add into GEMM for rank 0 (this ensures that
# bias will not get added more than once in TP>1 case)
bias_
=
None
if
(
self
.
tp_rank
>
0
or
self
.
skip_bias_add
)
else
self
.
bias
output_parallel
=
self
.
quant_method
.
apply
(
self
,
input_parallel
,
bias
=
bias_
)
if
self
.
reduce_results
and
self
.
tp_size
>
1
:
if
self
.
reduce_results
and
self
.
tp_size
>
1
:
output
_
=
tensor_model_parallel_all_reduce
(
output_parallel
)
output
=
tensor_model_parallel_all_reduce
(
output_parallel
)
else
:
else
:
output_
=
output_parallel
output
=
output_parallel
output_bias
=
self
.
bias
if
self
.
skip_bias_add
else
None
if
not
self
.
skip_bias_add
:
output
=
output_
+
self
.
bias
if
self
.
bias
is
not
None
else
output_
output_bias
=
None
else
:
output
=
output_
output_bias
=
self
.
bias
return
output
,
output_bias
return
output
,
output_bias
def
extra_repr
(
self
)
->
str
:
def
extra_repr
(
self
)
->
str
:
...
...
vllm/model_executor/layers/quantization/__init__.py
View file @
500b93c8
...
@@ -2,6 +2,7 @@ from typing import Dict, Type
...
@@ -2,6 +2,7 @@ from typing import Dict, Type
from
vllm.model_executor.layers.quantization.aqlm
import
AQLMConfig
from
vllm.model_executor.layers.quantization.aqlm
import
AQLMConfig
from
vllm.model_executor.layers.quantization.awq
import
AWQConfig
from
vllm.model_executor.layers.quantization.awq
import
AWQConfig
from
vllm.model_executor.layers.quantization.awq_marlin
import
AWQMarlinConfig
from
vllm.model_executor.layers.quantization.base_config
import
(
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
QuantizationConfig
)
from
vllm.model_executor.layers.quantization.bitsandbytes
import
(
from
vllm.model_executor.layers.quantization.bitsandbytes
import
(
...
@@ -10,6 +11,7 @@ from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tenso
...
@@ -10,6 +11,7 @@ from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tenso
CompressedTensorsConfig
)
CompressedTensorsConfig
)
from
vllm.model_executor.layers.quantization.deepspeedfp
import
(
from
vllm.model_executor.layers.quantization.deepspeedfp
import
(
DeepSpeedFPConfig
)
DeepSpeedFPConfig
)
from
vllm.model_executor.layers.quantization.fbgemm_fp8
import
FBGEMMFp8Config
from
vllm.model_executor.layers.quantization.fp8
import
Fp8Config
from
vllm.model_executor.layers.quantization.fp8
import
Fp8Config
from
vllm.model_executor.layers.quantization.gptq
import
GPTQConfig
from
vllm.model_executor.layers.quantization.gptq
import
GPTQConfig
from
vllm.model_executor.layers.quantization.gptq_marlin
import
(
from
vllm.model_executor.layers.quantization.gptq_marlin
import
(
...
@@ -24,11 +26,13 @@ QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
...
@@ -24,11 +26,13 @@ QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
"awq"
:
AWQConfig
,
"awq"
:
AWQConfig
,
"deepspeedfp"
:
DeepSpeedFPConfig
,
"deepspeedfp"
:
DeepSpeedFPConfig
,
"fp8"
:
Fp8Config
,
"fp8"
:
Fp8Config
,
"fbgemm_fp8"
:
FBGEMMFp8Config
,
# The order of gptq methods is important for config.py iteration over
# The order of gptq methods is important for config.py iteration over
# override_quantization_method(..)
# override_quantization_method(..)
"marlin"
:
MarlinConfig
,
"marlin"
:
MarlinConfig
,
"gptq_marlin_24"
:
GPTQMarlin24Config
,
"gptq_marlin_24"
:
GPTQMarlin24Config
,
"gptq_marlin"
:
GPTQMarlinConfig
,
"gptq_marlin"
:
GPTQMarlinConfig
,
"awq_marlin"
:
AWQMarlinConfig
,
"gptq"
:
GPTQConfig
,
"gptq"
:
GPTQConfig
,
"squeezellm"
:
SqueezeLLMConfig
,
"squeezellm"
:
SqueezeLLMConfig
,
"compressed-tensors"
:
CompressedTensorsConfig
,
"compressed-tensors"
:
CompressedTensorsConfig
,
...
...
vllm/model_executor/layers/quantization/aqlm.py
View file @
500b93c8
...
@@ -207,8 +207,8 @@ class AQLMConfig(QuantizationConfig):
...
@@ -207,8 +207,8 @@ class AQLMConfig(QuantizationConfig):
return
cls
(
in_group_size
,
nbits_per_codebook
,
num_code_books
,
return
cls
(
in_group_size
,
nbits_per_codebook
,
num_code_books
,
out_group_size
)
out_group_size
)
def
get_quant_method
(
def
get_quant_method
(
self
,
layer
:
torch
.
nn
.
Module
,
self
,
layer
:
torch
.
nn
.
Module
)
->
Optional
[
"AQLMLinearMethod"
]:
prefix
:
str
)
->
Optional
[
"AQLMLinearMethod"
]:
if
isinstance
(
layer
,
LinearBase
):
if
isinstance
(
layer
,
LinearBase
):
return
AQLMLinearMethod
(
self
)
return
AQLMLinearMethod
(
self
)
return
None
return
None
...
...
vllm/model_executor/layers/quantization/awq.py
View file @
500b93c8
...
@@ -63,8 +63,8 @@ class AWQConfig(QuantizationConfig):
...
@@ -63,8 +63,8 @@ class AWQConfig(QuantizationConfig):
zero_point
=
cls
.
get_from_keys
(
config
,
[
"zero_point"
])
zero_point
=
cls
.
get_from_keys
(
config
,
[
"zero_point"
])
return
cls
(
weight_bits
,
group_size
,
zero_point
)
return
cls
(
weight_bits
,
group_size
,
zero_point
)
def
get_quant_method
(
def
get_quant_method
(
self
,
layer
:
torch
.
nn
.
Module
,
self
,
layer
:
torch
.
nn
.
Module
)
->
Optional
[
"AWQLinearMethod"
]:
prefix
:
str
)
->
Optional
[
"AWQLinearMethod"
]:
if
isinstance
(
layer
,
LinearBase
):
if
isinstance
(
layer
,
LinearBase
):
return
AWQLinearMethod
(
self
)
return
AWQLinearMethod
(
self
)
return
None
return
None
...
...
vllm/model_executor/layers/quantization/awq_marlin.py
0 → 100644
View file @
500b93c8
from
typing
import
Any
,
Dict
,
List
,
Optional
import
torch
from
torch.nn.parameter
import
Parameter
from
vllm
import
_custom_ops
as
ops
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.linear
import
(
LinearBase
,
LinearMethodBase
,
set_weight_attrs
)
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils
import
(
apply_awq_marlin_linear
,
awq_to_marlin_zero_points
,
check_awq_marlin_supported
,
marlin_make_empty_g_idx
,
marlin_make_workspace
,
marlin_permute_scales
,
replace_tensor
,
verify_awq_marlin_supported
,
verify_marlin_supports_shape
)
from
vllm.model_executor.layers.vocab_parallel_embedding
import
ParallelLMHead
logger
=
init_logger
(
__name__
)
class
AWQMarlinConfig
(
QuantizationConfig
):
"""Config class for AWQ Marlin"""
def
__init__
(
self
,
weight_bits
:
int
,
group_size
:
int
,
has_zp
:
bool
,
lm_head_quantized
:
bool
)
->
None
:
self
.
weight_bits
=
weight_bits
self
.
pack_factor
=
32
//
self
.
weight_bits
# packed into int32
self
.
group_size
=
group_size
self
.
has_zp
=
has_zp
self
.
lm_head_quantized
=
lm_head_quantized
verify_awq_marlin_supported
(
num_bits
=
self
.
weight_bits
,
group_size
=
self
.
group_size
,
has_zp
=
self
.
has_zp
)
def
__repr__
(
self
)
->
str
:
return
(
f
"AWQMarlinConfig(weight_bits=
{
self
.
weight_bits
}
, "
f
"group_size=
{
self
.
group_size
}
, "
f
"has_zp=
{
self
.
has_zp
}
, "
f
"lm_head_quantized=
{
self
.
lm_head_quantized
}
)"
)
@
classmethod
def
get_name
(
cls
)
->
str
:
return
"awq_marlin"
@
classmethod
def
get_supported_act_dtypes
(
cls
)
->
List
[
torch
.
dtype
]:
return
[
torch
.
half
,
torch
.
bfloat16
]
@
classmethod
def
get_min_capability
(
cls
)
->
int
:
return
80
@
classmethod
def
get_config_filenames
(
cls
)
->
List
[
str
]:
return
[
"quantize_config.json"
]
@
classmethod
def
from_config
(
cls
,
config
:
Dict
[
str
,
Any
])
->
"AWQMarlinConfig"
:
weight_bits
=
cls
.
get_from_keys
(
config
,
[
"bits"
])
group_size
=
cls
.
get_from_keys
(
config
,
[
"group_size"
])
has_zp
=
cls
.
get_from_keys
(
config
,
[
"zero_point"
])
lm_head_quantized
=
cls
.
get_from_keys_or
(
config
,
[
"lm_head"
],
default
=
False
)
return
cls
(
weight_bits
,
group_size
,
has_zp
,
lm_head_quantized
)
@
classmethod
def
override_quantization_method
(
cls
,
hf_quant_cfg
,
user_quant
)
->
Optional
[
str
]:
can_convert
=
cls
.
is_awq_marlin_compatible
(
hf_quant_cfg
)
is_valid_user_quant
=
(
user_quant
is
None
or
user_quant
==
"marlin"
)
if
can_convert
and
is_valid_user_quant
:
msg
=
(
"The model is convertible to {} during runtime."
" Using {} kernel."
.
format
(
cls
.
get_name
(),
cls
.
get_name
()))
logger
.
info
(
msg
)
return
cls
.
get_name
()
if
can_convert
and
user_quant
==
"awq"
:
logger
.
info
(
"Detected that the model can run with awq_marlin"
", however you specified quantization=awq explicitly,"
" so forcing awq. Use quantization=awq_marlin for"
" faster inference"
)
return
None
def
get_quant_method
(
self
,
layer
:
torch
.
nn
.
Module
,
prefix
:
str
)
->
Optional
[
"AWQMarlinLinearMethod"
]:
if
(
isinstance
(
layer
,
LinearBase
)
or
(
isinstance
(
layer
,
ParallelLMHead
)
and
self
.
lm_head_quantized
)):
return
AWQMarlinLinearMethod
(
self
)
return
None
def
get_scaled_act_names
(
self
)
->
List
[
str
]:
return
[]
@
classmethod
def
is_awq_marlin_compatible
(
cls
,
quant_config
:
Dict
[
str
,
Any
]):
# Extract data from quant config.
quant_method
=
quant_config
.
get
(
"quant_method"
,
""
).
lower
()
num_bits
=
quant_config
.
get
(
"bits"
,
None
)
group_size
=
quant_config
.
get
(
"group_size"
,
None
)
has_zp
=
quant_config
.
get
(
"zero_point"
,
None
)
if
quant_method
!=
"awq"
:
return
False
# If we cannot find the info needed in the config, cannot convert.
if
(
num_bits
is
None
or
group_size
is
None
or
has_zp
is
None
):
return
False
return
check_awq_marlin_supported
(
num_bits
=
num_bits
,
group_size
=
group_size
,
has_zp
=
has_zp
,
min_capability
=
cls
.
get_min_capability
())
class
AWQMarlinLinearMethod
(
LinearMethodBase
):
"""Linear method for AWQ Marlin.
Args:
quant_config: The AWQ Marlin quantization config.
"""
def
__init__
(
self
,
quant_config
:
AWQMarlinConfig
)
->
None
:
self
.
quant_config
=
quant_config
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
input_size_per_partition
:
int
,
output_partition_sizes
:
List
[
int
],
input_size
:
int
,
output_size
:
int
,
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
,
)
->
None
:
del
output_size
output_size_per_partition
=
sum
(
output_partition_sizes
)
# Normalize group_size
if
self
.
quant_config
.
group_size
!=
-
1
:
group_size
=
self
.
quant_config
.
group_size
else
:
group_size
=
input_size
verify_marlin_supports_shape
(
output_size_per_partition
=
output_size_per_partition
,
input_size_per_partition
=
input_size_per_partition
,
input_size
=
input_size
,
group_size
=
group_size
)
qweight
=
Parameter
(
torch
.
empty
(
input_size_per_partition
,
output_size_per_partition
//
self
.
quant_config
.
pack_factor
,
dtype
=
torch
.
int32
,
),
requires_grad
=
False
,
)
set_weight_attrs
(
qweight
,
{
"input_dim"
:
0
,
"output_dim"
:
1
,
"packed_dim"
:
1
,
"pack_factor"
:
self
.
quant_config
.
pack_factor
,
})
num_groups
=
input_size_per_partition
//
group_size
qzeros
=
Parameter
(
torch
.
empty
(
num_groups
,
output_size_per_partition
//
self
.
quant_config
.
pack_factor
,
dtype
=
torch
.
int32
,
),
requires_grad
=
False
,
)
set_weight_attrs
(
qzeros
,
{
"input_dim"
:
0
,
"output_dim"
:
1
,
"packed_dim"
:
1
,
"pack_factor"
:
self
.
quant_config
.
pack_factor
,
})
scales
=
Parameter
(
torch
.
empty
(
num_groups
,
output_size_per_partition
,
dtype
=
params_dtype
,
),
requires_grad
=
False
,
)
set_weight_attrs
(
scales
,
{
"input_dim"
:
0
,
"output_dim"
:
1
,
})
layer
.
register_parameter
(
"qweight"
,
qweight
)
set_weight_attrs
(
qweight
,
extra_weight_attrs
)
layer
.
register_parameter
(
"qzeros"
,
qzeros
)
set_weight_attrs
(
qzeros
,
extra_weight_attrs
)
layer
.
register_parameter
(
"scales"
,
scales
)
set_weight_attrs
(
scales
,
extra_weight_attrs
)
layer
.
input_size_per_partition
=
input_size_per_partition
layer
.
output_size_per_partition
=
output_size_per_partition
layer
.
num_groups
=
num_groups
# TODO: Update this docs
# Checkpoints are serialized in AutoAWQ format, which is different from the
# marlin format. This function is called after the weights are loaded.
# Here, we handle the repacking
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
device
=
layer
.
qweight
.
device
# Allocate marlin workspace
layer
.
workspace
=
marlin_make_workspace
(
layer
.
output_size_per_partition
,
device
)
# Repack weights from AWQ format to marlin format.
marlin_qweight
=
ops
.
awq_marlin_repack
(
layer
.
qweight
,
size_k
=
layer
.
input_size_per_partition
,
size_n
=
layer
.
output_size_per_partition
,
num_bits
=
self
.
quant_config
.
weight_bits
)
replace_tensor
(
layer
,
"qweight"
,
marlin_qweight
)
# Permute scales from AWQ format to marlin format.
marlin_scales
=
marlin_permute_scales
(
layer
.
scales
,
size_k
=
layer
.
input_size_per_partition
,
size_n
=
layer
.
output_size_per_partition
,
group_size
=
self
.
quant_config
.
group_size
)
replace_tensor
(
layer
,
"scales"
,
marlin_scales
)
# Permute zero-points from AWQ format to marlin format.
marlin_zp
=
awq_to_marlin_zero_points
(
layer
.
qzeros
,
size_k
=
layer
.
num_groups
,
size_n
=
layer
.
output_size_per_partition
,
num_bits
=
self
.
quant_config
.
weight_bits
)
replace_tensor
(
layer
,
"qzeros"
,
marlin_zp
)
# Not-used
layer
.
g_idx
=
marlin_make_empty_g_idx
(
device
)
layer
.
g_idx_sort_indices
=
marlin_make_empty_g_idx
(
device
)
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
return
apply_awq_marlin_linear
(
input
=
x
,
weight
=
layer
.
qweight
,
weight_scale
=
layer
.
scales
,
weight_zp
=
layer
.
qzeros
,
g_idx
=
layer
.
g_idx
,
g_idx_sort_indices
=
layer
.
g_idx_sort_indices
,
workspace
=
layer
.
workspace
,
num_bits
=
self
.
quant_config
.
weight_bits
,
output_size_per_partition
=
layer
.
output_size_per_partition
,
input_size_per_partition
=
layer
.
input_size_per_partition
,
bias
=
bias
)
vllm/model_executor/layers/quantization/base_config.py
View file @
500b93c8
...
@@ -97,12 +97,13 @@ class QuantizationConfig(ABC):
...
@@ -97,12 +97,13 @@ class QuantizationConfig(ABC):
return
default
return
default
@
abstractmethod
@
abstractmethod
def
get_quant_method
(
def
get_quant_method
(
self
,
layer
:
torch
.
nn
.
Module
,
self
,
layer
:
torch
.
nn
.
Module
)
->
Optional
[
QuantizeMethodBase
]:
prefix
:
str
)
->
Optional
[
QuantizeMethodBase
]:
"""Get the quantize method to use for the quantized layer.
"""Get the quantize method to use for the quantized layer.
Args:
Args:
layer: The layer for the quant method.
layer: The layer for the quant method.
prefix: The full name of the layer in the state dict
Returns:
Returns:
The quantize method. None if the given layer doesn't support quant
The quantize method. None if the given layer doesn't support quant
method.
method.
...
...
Prev
1
…
6
7
8
9
10
11
12
13
14
15
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment