Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
500b93c8
Commit
500b93c8
authored
Jul 25, 2024
by
zhuwenwen
Browse files
Merge tag 'v0.5.3.post1' into v0.5.3.post1-dtk24.04.1
parents
99426767
38c4b7e8
Changes
282
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
770 additions
and
231 deletions
+770
-231
vllm/executor/neuron_executor.py
vllm/executor/neuron_executor.py
+2
-0
vllm/executor/openvino_executor.py
vllm/executor/openvino_executor.py
+2
-0
vllm/executor/ray_gpu_executor.py
vllm/executor/ray_gpu_executor.py
+168
-99
vllm/executor/ray_utils.py
vllm/executor/ray_utils.py
+8
-6
vllm/executor/ray_xpu_executor.py
vllm/executor/ray_xpu_executor.py
+33
-47
vllm/executor/tpu_executor.py
vllm/executor/tpu_executor.py
+2
-0
vllm/executor/xpu_executor.py
vllm/executor/xpu_executor.py
+2
-0
vllm/inputs/__init__.py
vllm/inputs/__init__.py
+3
-4
vllm/inputs/data.py
vllm/inputs/data.py
+1
-23
vllm/lora/request.py
vllm/lora/request.py
+39
-3
vllm/lora/utils.py
vllm/lora/utils.py
+47
-0
vllm/lora/worker_manager.py
vllm/lora/worker_manager.py
+4
-3
vllm/model_executor/layers/fused_moe/layer.py
vllm/model_executor/layers/fused_moe/layer.py
+55
-14
vllm/model_executor/layers/fused_moe/moe_pallas.py
vllm/model_executor/layers/fused_moe/moe_pallas.py
+62
-0
vllm/model_executor/layers/linear.py
vllm/model_executor/layers/linear.py
+63
-26
vllm/model_executor/layers/quantization/__init__.py
vllm/model_executor/layers/quantization/__init__.py
+4
-0
vllm/model_executor/layers/quantization/aqlm.py
vllm/model_executor/layers/quantization/aqlm.py
+2
-2
vllm/model_executor/layers/quantization/awq.py
vllm/model_executor/layers/quantization/awq.py
+2
-2
vllm/model_executor/layers/quantization/awq_marlin.py
vllm/model_executor/layers/quantization/awq_marlin.py
+268
-0
vllm/model_executor/layers/quantization/base_config.py
vllm/model_executor/layers/quantization/base_config.py
+3
-2
No files found.
vllm/executor/neuron_executor.py
View file @
500b93c8
...
...
@@ -11,6 +11,8 @@ logger = init_logger(__name__)
class
NeuronExecutor
(
ExecutorBase
):
uses_ray
:
bool
=
False
def
_init_executor
(
self
)
->
None
:
assert
(
self
.
lora_config
is
None
),
"LoRA is not supported for Neuron backend."
...
...
vllm/executor/openvino_executor.py
View file @
500b93c8
...
...
@@ -18,6 +18,8 @@ logger = init_logger(__name__)
class
OpenVINOExecutor
(
ExecutorBase
):
uses_ray
:
bool
=
False
def
_init_executor
(
self
)
->
None
:
assert
self
.
device_config
.
device_type
==
"openvino"
assert
self
.
lora_config
is
None
,
"OpenVINO backend doesn't support LoRA"
...
...
vllm/executor/ray_gpu_executor.py
View file @
500b93c8
import
asyncio
import
os
import
pickle
from
collections
import
defaultdict
from
itertools
import
islice
,
repeat
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Optional
,
Tuple
...
...
@@ -11,7 +10,8 @@ from vllm.executor.distributed_gpu_executor import ( # yapf: disable
from
vllm.executor.ray_utils
import
RayWorkerWrapper
,
ray
from
vllm.logger
import
init_logger
from
vllm.sequence
import
ExecuteModelRequest
,
SamplerOutput
from
vllm.utils
import
(
error_on_invalid_device_count_status
,
from
vllm.utils
import
(
_run_task_with_lock
,
error_on_invalid_device_count_status
,
get_distributed_init_method
,
get_ip
,
get_open_port
,
get_vllm_instance_id
,
make_async
)
...
...
@@ -23,13 +23,33 @@ if TYPE_CHECKING:
logger
=
init_logger
(
__name__
)
USE_RAY_COMPILED_DAG
=
envs
.
VLLM_USE_RAY_COMPILED_DAG
class
RayGPUExecutor
(
DistributedGPUExecutor
):
uses_ray
:
bool
=
True
def
_init_executor
(
self
)
->
None
:
assert
self
.
parallel_config
.
distributed_executor_backend
==
"ray"
# If the env var is set, it uses the Ray's compiled DAG API
# which optimizes the control plane overhead.
# Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it.
# Currently, this requires USE_RAY_SPMD_WORKER=True.
self
.
use_ray_compiled_dag
=
envs
.
VLLM_USE_RAY_COMPILED_DAG
# If the env var is set, then we do not distinguish between the
# "driver worker" vs other workers. Also, the rank 0 worker will
# be executed in a remote Ray worker. Currently this requires
# USE_RAY_COMPILED_DAG=True.
self
.
use_ray_spmd_worker
=
envs
.
VLLM_USE_RAY_SPMD_WORKER
if
self
.
use_ray_compiled_dag
:
assert
self
.
use_ray_spmd_worker
,
(
"VLLM_USE_RAY_COMPILED_DAG=1 requires "
"VLLM_USE_RAY_SPMD_WORKER=1"
)
if
self
.
use_ray_spmd_worker
:
# TODO: Support SPMD worker for non-DAG Ray executor.
assert
self
.
use_ray_compiled_dag
,
(
"VLLM_USE_RAY_SPMD_WORKER=1 requires "
"VLLM_USE_RAY_COMPILED_DAG=1"
)
assert
self
.
uses_ray
placement_group
=
self
.
parallel_config
.
placement_group
# Disable Ray usage stats collection.
...
...
@@ -40,11 +60,7 @@ class RayGPUExecutor(DistributedGPUExecutor):
# Create the parallel GPU workers.
self
.
_init_workers_ray
(
placement_group
)
self
.
forward_dag
=
None
if
USE_RAY_COMPILED_DAG
:
self
.
forward_dag
=
self
.
_compiled_ray_dag
()
self
.
extra_execute_model_run_workers_kwargs
[
"use_ray_compiled_dag"
]
=
True
self
.
forward_dag
:
Optional
[
"ray.dag.CompiledDAG"
]
=
None
def
_configure_ray_workers_use_nsight
(
self
,
ray_remote_kwargs
)
->
Dict
[
str
,
Any
]:
...
...
@@ -61,6 +77,20 @@ class RayGPUExecutor(DistributedGPUExecutor):
return
ray_remote_kwargs
def
_get_worker_wrapper_args
(
self
)
->
Dict
[
str
,
Any
]:
if
self
.
speculative_config
is
not
None
:
worker_module_name
=
"vllm.spec_decode.spec_decode_worker"
worker_class_name
=
"create_spec_worker"
else
:
worker_module_name
=
"vllm.worker.worker"
worker_class_name
=
"Worker"
return
dict
(
worker_module_name
=
worker_module_name
,
worker_class_name
=
worker_class_name
,
trust_remote_code
=
self
.
model_config
.
trust_remote_code
,
)
def
_init_workers_ray
(
self
,
placement_group
:
"PlacementGroup"
,
**
ray_remote_kwargs
):
if
(
self
.
parallel_config
.
tensor_parallel_size
==
1
...
...
@@ -83,6 +113,7 @@ class RayGPUExecutor(DistributedGPUExecutor):
# Create the workers.
driver_ip
=
get_ip
()
worker_wrapper_kwargs
=
self
.
_get_worker_wrapper_args
()
for
bundle_id
,
bundle
in
enumerate
(
placement_group
.
bundle_specs
):
if
not
bundle
.
get
(
"GPU"
,
0
):
continue
...
...
@@ -92,39 +123,28 @@ class RayGPUExecutor(DistributedGPUExecutor):
placement_group_bundle_index
=
bundle_id
,
)
if
self
.
speculative_config
is
not
None
:
worker_module_name
=
"vllm.spec_decode.spec_decode_worker"
worker_class_name
=
"create_spec_worker"
else
:
worker_module_name
=
"vllm.worker.worker"
worker_class_name
=
"Worker"
worker
=
ray
.
remote
(
num_cpus
=
0
,
num_gpus
=
num_gpus
,
scheduling_strategy
=
scheduling_strategy
,
**
ray_remote_kwargs
,
)(
RayWorkerWrapper
).
remote
(
worker_module_name
=
worker_module_name
,
worker_class_name
=
worker_class_name
,
trust_remote_code
=
self
.
model_config
.
trust_remote_code
,
)
)(
RayWorkerWrapper
).
remote
(
**
worker_wrapper_kwargs
)
worker_ip
=
ray
.
get
(
worker
.
get_node_ip
.
remote
())
if
worker_ip
==
driver_ip
and
self
.
driver_dummy_worker
is
None
:
# If the worker is on the same node as the driver, we use it
# as the resource holder for the driver process.
self
.
driver_dummy_worker
=
worker
self
.
driver_worker
=
RayWorkerWrapper
(
worker_module_name
=
worker_module_name
,
worker_class_name
=
worker_class_name
,
trust_remote_code
=
self
.
model_config
.
trust_remote_code
,
)
else
:
# Else, added to the list of workers.
if
self
.
use_ray_spmd_worker
:
self
.
workers
.
append
(
worker
)
if
self
.
driver_dummy_worker
is
None
:
else
:
worker_ip
=
ray
.
get
(
worker
.
get_node_ip
.
remote
())
if
worker_ip
==
driver_ip
and
self
.
driver_dummy_worker
is
None
:
# If the worker is on the same node as the driver, we use it
# as the resource holder for the driver process.
self
.
driver_dummy_worker
=
worker
self
.
driver_worker
=
RayWorkerWrapper
(
**
worker_wrapper_kwargs
)
else
:
# Else, added to the list of workers.
self
.
workers
.
append
(
worker
)
if
not
self
.
use_ray_spmd_worker
and
self
.
driver_dummy_worker
is
None
:
raise
ValueError
(
"Ray does not allocate any GPUs on the driver node. Consider "
"adjusting the Ray placement group or running the driver on a "
...
...
@@ -224,13 +244,14 @@ class RayGPUExecutor(DistributedGPUExecutor):
# broadcasted to.
self
.
non_driver_workers
:
List
[
RayWorkerWrapper
]
=
[]
for
idx
,
rank
in
enumerate
(
worker_ranks
[
1
:]):
# Enforce rank order for correct rank to return final output.
for
rank
,
worker
in
sorted
(
zip
(
worker_ranks
[
1
:],
self
.
workers
)):
# We need to skip the driver worker, which we
# do by skipping worker_ranks[0] which is always 0.
if
rank
%
self
.
parallel_config
.
tensor_parallel_size
==
0
:
self
.
tp_driver_workers
.
append
(
self
.
worker
s
[
idx
]
)
self
.
tp_driver_workers
.
append
(
worker
)
else
:
self
.
non_driver_workers
.
append
(
self
.
worker
s
[
idx
]
)
self
.
non_driver_workers
.
append
(
worker
)
def
_driver_execute_model
(
self
,
execute_model_req
:
Optional
[
ExecuteModelRequest
]
...
...
@@ -240,9 +261,23 @@ class RayGPUExecutor(DistributedGPUExecutor):
Passing None will cause the driver to stop the model execution
loop running in each of the remote workers.
"""
assert
not
self
.
use_ray_spmd_worker
,
(
"driver_worker does not exist for VLLM_USE_RAY_SPMD_WORKER=1"
)
return
self
.
driver_worker
.
execute_method
(
"execute_model"
,
execute_model_req
)
def
execute_model
(
self
,
execute_model_req
:
ExecuteModelRequest
)
->
List
[
SamplerOutput
]:
if
not
self
.
use_ray_spmd_worker
:
return
super
().
execute_model
(
execute_model_req
)
if
self
.
forward_dag
is
None
:
self
.
forward_dag
=
self
.
_compiled_ray_dag
(
enable_asyncio
=
False
)
outputs
=
ray
.
get
(
self
.
forward_dag
.
execute
(
execute_model_req
))
return
outputs
[
0
]
def
_run_workers
(
self
,
method
:
str
,
...
...
@@ -252,7 +287,6 @@ class RayGPUExecutor(DistributedGPUExecutor):
all_kwargs
:
Optional
[
List
[
Dict
[
str
,
Any
]]]
=
None
,
use_dummy_driver
:
bool
=
False
,
max_concurrent_workers
:
Optional
[
int
]
=
None
,
use_ray_compiled_dag
:
bool
=
False
,
**
kwargs
,
)
->
Any
:
"""Runs the given method on all workers. Can be used in the following
...
...
@@ -267,6 +301,10 @@ class RayGPUExecutor(DistributedGPUExecutor):
- all_args/all_kwargs: args/kwargs for each worker are specified
individually
"""
if
self
.
use_ray_spmd_worker
:
assert
not
async_run_tensor_parallel_workers_only
,
(
"async_run_tensor_parallel_workers_only is not supported for "
"spmd mode."
)
if
max_concurrent_workers
:
raise
NotImplementedError
(
...
...
@@ -275,99 +313,125 @@ class RayGPUExecutor(DistributedGPUExecutor):
count
=
len
(
self
.
workers
)
if
not
\
async_run_tensor_parallel_workers_only
\
else
len
(
self
.
non_driver_workers
)
# If using SPMD worker, all workers are the same, so we should execute
# the args on all workers. Otherwise, we skip the first worker's args
# because those args will go to the driver worker.
first_worker_args_index
:
int
=
0
if
self
.
use_ray_spmd_worker
else
1
all_worker_args
=
repeat
(
args
,
count
)
if
all_args
is
None
\
else
islice
(
all_args
,
1
,
None
)
else
islice
(
all_args
,
first_worker_args_index
,
None
)
all_worker_kwargs
=
repeat
(
kwargs
,
count
)
if
all_kwargs
is
None
\
else
islice
(
all_kwargs
,
1
,
None
)
if
use_ray_compiled_dag
:
# Right now, compiled DAG can only accept a single
# input. TODO(sang): Fix it.
assert
self
.
forward_dag
is
not
None
output_channels
=
self
.
forward_dag
.
execute
(
1
)
ray_worker_outputs
=
[]
else
:
# Start the ray workers first.
ray_workers
=
self
.
workers
if
async_run_tensor_parallel_workers_only
:
ray_workers
=
self
.
non_driver_workers
ray_worker_outputs
=
[
worker
.
execute_method
.
remote
(
method
,
*
worker_args
,
**
worker_kwargs
)
for
(
worker
,
worker_args
,
worker_kwargs
)
in
zip
(
ray_workers
,
all_worker_args
,
all_worker_kwargs
)
]
else
islice
(
all_kwargs
,
first_worker_args_index
,
None
)
# Start the ray workers first.
ray_workers
=
self
.
workers
if
async_run_tensor_parallel_workers_only
:
ray_workers
=
self
.
non_driver_workers
ray_worker_outputs
=
[
worker
.
execute_method
.
remote
(
method
,
*
worker_args
,
**
worker_kwargs
)
for
(
worker
,
worker_args
,
worker_kwargs
)
in
zip
(
ray_workers
,
all_worker_args
,
all_worker_kwargs
)
]
if
async_run_tensor_parallel_workers_only
:
# Just return futures
return
ray_worker_outputs
driver_args
=
args
if
all_args
is
None
else
all_args
[
0
]
driver_kwargs
=
kwargs
if
all_kwargs
is
None
else
all_kwargs
[
0
]
driver_worker_output
=
[]
# In SPMD mode, the driver worker is the same as any other worker,
# so we only explicitly execute on the driver worker if using a
# non-SPMD worker class.
if
not
self
.
use_ray_spmd_worker
:
driver_args
=
args
if
all_args
is
None
else
all_args
[
0
]
driver_kwargs
=
kwargs
if
all_kwargs
is
None
else
all_kwargs
[
0
]
# Start the driver worker after all the ray workers.
if
not
use_dummy_driver
:
driver_worker_output
=
[
self
.
driver_worker
.
execute_method
(
method
,
*
driver_args
,
**
driver_kwargs
)
]
else
:
assert
self
.
driver_dummy_worker
is
not
None
driver_worker_output
=
[
ray
.
get
(
self
.
driver_dummy_worker
.
execute_method
.
remote
(
method
,
*
driver_args
,
**
driver_kwargs
))
]
# Start the driver worker after all the ray workers.
if
not
use_dummy_driver
:
driver_worker_output
=
self
.
driver_worker
.
execute_method
(
method
,
*
driver_args
,
**
driver_kwargs
)
else
:
assert
self
.
driver_dummy_worker
is
not
None
driver_worker_output
=
ray
.
get
(
self
.
driver_dummy_worker
.
execute_method
.
remote
(
method
,
*
driver_args
,
**
driver_kwargs
))
# Get the results of the ray workers.
if
self
.
workers
:
if
use_ray_compiled_dag
:
try
:
ray_worker_outputs
=
[
pickle
.
loads
(
chan
.
begin_read
())
for
chan
in
output_channels
]
finally
:
# Has to call end_read in order to reuse the DAG.
for
chan
in
output_channels
:
chan
.
end_read
()
else
:
ray_worker_outputs
=
ray
.
get
(
ray_worker_outputs
)
ray_worker_outputs
=
ray
.
get
(
ray_worker_outputs
)
return
[
driver_worker_output
]
+
ray_worker_outputs
return
driver_worker_output
+
ray_worker_outputs
def
_wait_for_tasks_completion
(
self
,
parallel_worker_tasks
:
Any
)
->
None
:
"""Wait for futures returned from _run_workers() with
async_run_remote_workers_only to complete."""
ray
.
get
(
parallel_worker_tasks
)
def
_compiled_ray_dag
(
self
):
def
_compiled_ray_dag
(
self
,
enable_asyncio
:
bool
):
import
pkg_resources
required_version
=
"2.9"
current_version
=
pkg_resources
.
get_distribution
(
"ray"
).
version
from
packaging
import
version
required_version
=
version
.
parse
(
"2.32"
)
current_version
=
version
.
parse
(
pkg_resources
.
get_distribution
(
"ray"
).
version
)
if
current_version
<
required_version
:
raise
ValueError
(
f
"Ray version
{
required_version
}
or greater is "
f
"required, but found
{
current_version
}
"
)
from
ray.dag
import
InputNode
,
MultiOutputNode
assert
self
.
parallel_config
.
distributed_executor_backend
==
"
ray
"
assert
self
.
parallel_config
.
use_
ray
# Right now, compiled DAG requires at least 1 arg. We send
# a dummy value for now. It will be fixed soon.
with
InputNode
()
as
input_data
:
forward_dag
=
MultiOutputNode
([
worker
.
execute_model_compiled_dag_remote
.
bind
(
# type: ignore[attr-defined]
worker
.
execute_model_spmd
.
bind
(
# type: ignore[attr-defined]
input_data
)
for
worker
in
self
.
workers
])
return
forward_dag
.
experimental_compile
()
return
forward_dag
.
experimental_compile
(
enable_asyncio
=
enable_asyncio
)
def
__del__
(
self
):
if
self
.
forward_dag
is
not
None
:
self
.
forward_dag
.
teardown
()
import
ray
for
worker
in
self
.
workers
:
ray
.
kill
(
worker
)
class
RayGPUExecutorAsync
(
RayGPUExecutor
,
DistributedGPUExecutorAsync
):
def
__init__
(
self
,
*
args
,
**
kwargs
):
super
().
__init__
(
*
args
,
**
kwargs
)
self
.
driver_exec_method
=
make_async
(
self
.
driver_worker
.
execute_method
)
self
.
pp_locks
:
Optional
[
List
[
asyncio
.
Lock
]]
=
None
self
.
use_ray_spmd_worker
=
envs
.
VLLM_USE_RAY_SPMD_WORKER
if
not
self
.
use_ray_compiled_dag
:
self
.
driver_exec_method
=
make_async
(
self
.
driver_worker
.
execute_method
)
async
def
execute_model_async
(
self
,
execute_model_req
:
ExecuteModelRequest
)
->
List
[
SamplerOutput
]:
if
not
self
.
use_ray_spmd_worker
:
return
await
super
().
execute_model_async
(
execute_model_req
)
if
self
.
forward_dag
is
None
:
self
.
forward_dag
=
self
.
_compiled_ray_dag
(
enable_asyncio
=
True
)
dag_future
=
await
self
.
forward_dag
.
execute_async
(
execute_model_req
)
outputs
=
await
dag_future
return
outputs
[
0
]
async
def
_driver_execute_model_async
(
self
,
execute_model_req
:
Optional
[
ExecuteModelRequest
]
=
None
)
->
List
[
SamplerOutput
]:
assert
not
self
.
use_ray_spmd_worker
,
(
"driver_worker does not exist for VLLM_USE_RAY_SPMD_WORKER=1"
)
if
not
self
.
tp_driver_workers
:
return
await
self
.
driver_exec_method
(
"execute_model"
,
execute_model_req
)
if
self
.
pp_locks
is
None
:
# This locks each pipeline parallel stage so multiple virtual
# engines can't execute on the same stage at the same time
...
...
@@ -378,15 +442,11 @@ class RayGPUExecutorAsync(RayGPUExecutor, DistributedGPUExecutorAsync):
for
_
in
range
(
self
.
parallel_config
.
pipeline_parallel_size
)
]
async
def
_run_task_with_lock
(
task
,
lock
,
*
args
,
**
kwargs
):
async
with
lock
:
return
await
task
(
*
args
,
**
kwargs
)
tasks
=
[]
tasks
.
append
(
tasks
=
[
asyncio
.
create_task
(
_run_task_with_lock
(
self
.
driver_exec_method
,
self
.
pp_locks
[
0
],
"execute_model"
,
execute_model_req
)))
"execute_model"
,
execute_model_req
))
]
for
pp_rank
,
driver_worker
in
enumerate
(
self
.
tp_driver_workers
,
start
=
1
):
tasks
.
append
(
...
...
@@ -401,8 +461,17 @@ class RayGPUExecutorAsync(RayGPUExecutor, DistributedGPUExecutorAsync):
return
results
[
-
1
]
async
def
_start_worker_execution_loop
(
self
):
assert
not
self
.
use_ray_spmd_worker
,
(
"worker loop is disabled for VLLM_USE_RAY_SPMD_WORKER=1"
)
coros
=
[
worker
.
execute_method
.
remote
(
"start_worker_execution_loop"
)
for
worker
in
self
.
non_driver_workers
]
return
await
asyncio
.
gather
(
*
coros
)
def
__del__
(
self
):
if
self
.
forward_dag
is
not
None
:
self
.
forward_dag
.
teardown
()
import
ray
for
worker
in
self
.
workers
:
ray
.
kill
(
worker
)
vllm/executor/ray_utils.py
View file @
500b93c8
import
pickle
from
typing
import
List
,
Optional
,
Tuple
from
vllm.config
import
ParallelConfig
from
vllm.logger
import
init_logger
from
vllm.sequence
import
ExecuteModelRequest
from
vllm.utils
import
get_ip
,
is_hip
,
is_xpu
from
vllm.worker.worker_base
import
WorkerWrapperBase
...
...
@@ -31,16 +31,18 @@ try:
gpu_ids
=
ray
.
get_gpu_ids
()
return
node_id
,
gpu_ids
def
execute_model_compiled_dag_remote
(
self
,
ignored
):
"""Used only when compiled DAG is enabled."""
def
execute_model_spmd
(
self
,
execute_model_req
:
ExecuteModelRequest
):
"""Used only when SPMD worker and compiled DAG are both
enabled."""
# TODO(swang): This is needed right now because Ray aDAG executes
# on a background thread, so we need to reset torch's current
# device.
import
torch
if
not
self
.
compiled_dag_cuda_device_set
:
torch
.
cuda
.
set_device
(
self
.
worker
.
device
)
self
.
compiled_dag_cuda_device_set
=
True
output
=
self
.
worker
.
execute_model
()
output
=
pickle
.
dumps
(
output
)
return
output
return
self
.
worker
.
_execute_model_spmd
(
execute_model_req
)
ray_import_err
=
None
...
...
vllm/executor/ray_xpu_executor.py
View file @
500b93c8
import
asyncio
import
os
import
pickle
from
collections
import
defaultdict
from
itertools
import
islice
,
repeat
from
typing
import
(
TYPE_CHECKING
,
Any
,
Awaitable
,
Dict
,
List
,
Optional
,
Set
,
Tuple
,
Union
)
import
vllm.envs
as
envs
from
vllm.config
import
(
CacheConfig
,
DeviceConfig
,
LoadConfig
,
LoRAConfig
,
ModelConfig
,
MultiModalConfig
,
ParallelConfig
,
PromptAdapterConfig
,
SchedulerConfig
,
...
...
@@ -30,11 +30,13 @@ logger = init_logger(__name__)
# If the env var is set, it uses the Ray's compiled DAG API
# which optimizes the control plane overhead.
# Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it.
USE_RAY_COMPILED_DAG
=
bool
(
os
.
get
env
(
"
VLLM_USE_RAY_COMPILED_DAG
"
,
0
))
USE_RAY_COMPILED_DAG
=
env
s
.
VLLM_USE_RAY_COMPILED_DAG
class
RayXPUExecutor
(
DistributedGPUExecutor
):
uses_ray
:
bool
=
True
def
__init__
(
self
,
model_config
:
ModelConfig
,
...
...
@@ -72,10 +74,9 @@ class RayXPUExecutor(DistributedGPUExecutor):
# Create the parallel GPU workers.
self
.
_init_workers_ray
(
placement_group
)
# Profile the memory usage and initialize the cache.
self
.
forward_dag
=
None
if
USE_RAY_COMPILED_DAG
:
self
.
forward_dag
=
self
.
_compiled_ray_dag
()
self
.
forward_dag
=
self
.
_compiled_ray_dag
(
enable_asyncio
=
False
)
# This is non-None when the execute model loop is running
# in the parallel workers. It's a coroutine in the AsyncLLMEngine case.
...
...
@@ -108,6 +109,13 @@ class RayXPUExecutor(DistributedGPUExecutor):
return
num_gpu_blocks
,
num_cpu_blocks
def
_get_worker_wrapper_args
(
self
)
->
Dict
[
str
,
Any
]:
return
dict
(
worker_module_name
=
"vllm.worker.xpu_worker"
,
worker_class_name
=
"XPUWorker"
,
trust_remote_code
=
self
.
model_config
.
trust_remote_code
,
)
def
_init_workers_ray
(
self
,
placement_group
:
"PlacementGroup"
,
**
ray_remote_kwargs
):
if
self
.
parallel_config
.
tensor_parallel_size
==
1
:
...
...
@@ -125,6 +133,7 @@ class RayXPUExecutor(DistributedGPUExecutor):
# Create the workers.
driver_ip
=
get_ip
()
worker_wrapper_kwargs
=
self
.
_get_worker_wrapper_args
()
for
bundle_id
,
bundle
in
enumerate
(
placement_group
.
bundle_specs
):
if
not
bundle
.
get
(
"GPU"
,
0
):
continue
...
...
@@ -138,22 +147,14 @@ class RayXPUExecutor(DistributedGPUExecutor):
num_gpus
=
num_gpus
,
scheduling_strategy
=
scheduling_strategy
,
**
ray_remote_kwargs
,
)(
RayWorkerWrapper
).
remote
(
worker_module_name
=
"vllm.worker.xpu_worker"
,
worker_class_name
=
"XPUWorker"
,
trust_remote_code
=
self
.
model_config
.
trust_remote_code
,
)
)(
RayWorkerWrapper
).
remote
(
**
worker_wrapper_kwargs
)
worker_ip
=
ray
.
get
(
worker
.
get_node_ip
.
remote
())
if
worker_ip
==
driver_ip
and
self
.
driver_dummy_worker
is
None
:
# If the worker is on the same node as the driver, we use it
# as the resource holder for the driver process.
self
.
driver_dummy_worker
=
worker
self
.
driver_worker
=
RayWorkerWrapper
(
worker_module_name
=
"vllm.worker.xpu_worker"
,
worker_class_name
=
"XPUWorker"
,
trust_remote_code
=
self
.
model_config
.
trust_remote_code
,
)
self
.
driver_worker
=
RayWorkerWrapper
(
**
worker_wrapper_kwargs
)
else
:
# Else, added to the list of workers.
self
.
workers
.
append
(
worker
)
...
...
@@ -270,7 +271,6 @@ class RayXPUExecutor(DistributedGPUExecutor):
all_kwargs
:
Optional
[
List
[
Dict
[
str
,
Any
]]]
=
None
,
use_dummy_driver
:
bool
=
False
,
max_concurrent_workers
:
Optional
[
int
]
=
None
,
use_ray_compiled_dag
:
bool
=
False
,
**
kwargs
,
)
->
Any
:
"""Runs the given method on all workers. Can be used in the following
...
...
@@ -293,26 +293,20 @@ class RayXPUExecutor(DistributedGPUExecutor):
all_worker_kwargs
=
repeat
(
kwargs
,
count
)
if
all_kwargs
is
None
\
else
islice
(
all_kwargs
,
1
,
None
)
if
use_ray_compiled_dag
:
# Right now, compiled DAG can only accept a single
# input. TODO(sang): Fix it.
assert
self
.
forward_dag
is
not
None
output_channels
=
self
.
forward_dag
.
execute
(
1
)
else
:
# Start the ray workers first.
ray_worker_outputs
=
[
worker
.
execute_method
.
remote
(
method
,
*
worker_args
,
**
worker_kwargs
)
for
(
worker
,
worker_args
,
worker_kwargs
)
in
zip
(
self
.
workers
,
all_worker_args
,
all_worker_kwargs
)
]
# Start the ray workers first.
ray_worker_outputs
=
[
worker
.
execute_method
.
remote
(
method
,
*
worker_args
,
**
worker_kwargs
)
for
(
worker
,
worker_args
,
worker_kwargs
)
in
zip
(
self
.
workers
,
all_worker_args
,
all_worker_kwargs
)
]
if
async_run_remote_workers_only
:
# Just return futures
return
ray_worker_outputs
driver_worker_output
=
[]
driver_args
=
args
if
all_args
is
None
else
all_args
[
0
]
driver_kwargs
=
kwargs
if
all_kwargs
is
None
else
all_kwargs
[
0
]
# Start the driver worker after all the ray workers.
if
not
use_dummy_driver
:
driver_worker_output
=
self
.
driver_worker
.
execute_method
(
...
...
@@ -324,36 +318,28 @@ class RayXPUExecutor(DistributedGPUExecutor):
method
,
*
driver_args
,
**
driver_kwargs
))
# Get the results of the ray workers.
if
self
.
workers
:
if
use_ray_compiled_dag
:
try
:
ray_worker_outputs
=
[
pickle
.
loads
(
chan
.
begin_read
())
for
chan
in
output_channels
]
finally
:
# Has to call end_read in order to reuse the DAG.
for
chan
in
output_channels
:
chan
.
end_read
()
else
:
ray_worker_outputs
=
ray
.
get
(
ray_worker_outputs
)
ray_worker_outputs
=
ray
.
get
(
ray_worker_outputs
)
return
[
driver_worker_output
]
+
ray_worker_outputs
return
driver_worker_output
+
ray_worker_outputs
def
_wait_for_tasks_completion
(
self
,
parallel_worker_tasks
:
Any
)
->
None
:
"""Wait for futures returned from _run_workers() with
async_run_remote_workers_only to complete."""
ray
.
get
(
parallel_worker_tasks
)
def
_compiled_ray_dag
(
self
):
def
_compiled_ray_dag
(
self
,
enable_asyncio
:
bool
):
import
pkg_resources
required_version
=
"2.9"
current_version
=
pkg_resources
.
get_distribution
(
"ray"
).
version
from
packaging
import
version
required_version
=
version
.
parse
(
"2.32"
)
current_version
=
version
.
parse
(
pkg_resources
.
get_distribution
(
"ray"
).
version
)
if
current_version
<
required_version
:
raise
ValueError
(
f
"Ray version
{
required_version
}
or greater is "
f
"required, but found
{
current_version
}
"
)
from
ray.dag
import
InputNode
,
MultiOutputNode
assert
self
.
parallel_config
.
worker_
use_ray
assert
self
.
parallel_config
.
use_ray
# Right now, compiled DAG requires at least 1 arg. We send
# a dummy value for now. It will be fixed soon.
...
...
@@ -363,7 +349,7 @@ class RayXPUExecutor(DistributedGPUExecutor):
bind
(
# type: ignore[attr-defined]
input_data
)
for
worker
in
self
.
workers
])
return
forward_dag
.
experimental_compile
()
return
forward_dag
.
experimental_compile
(
enable_asyncio
=
enable_asyncio
)
def
check_health
(
self
)
->
None
:
"""Raises an error if engine is unhealthy."""
...
...
vllm/executor/tpu_executor.py
View file @
500b93c8
...
...
@@ -14,6 +14,8 @@ logger = init_logger(__name__)
class
TPUExecutor
(
ExecutorBase
):
uses_ray
:
bool
=
False
def
_init_executor
(
self
)
->
None
:
assert
not
self
.
scheduler_config
.
chunked_prefill_enabled
,
(
"Chunked prefill is not yet supported for TPU backend"
)
...
...
vllm/executor/xpu_executor.py
View file @
500b93c8
...
...
@@ -18,6 +18,8 @@ logger = init_logger(__name__)
class
XPUExecutor
(
GPUExecutor
):
uses_ray
:
bool
=
False
def
__init__
(
self
,
model_config
:
ModelConfig
,
...
...
vllm/inputs/__init__.py
View file @
500b93c8
from
.data
import
(
LLMInputs
,
ParsedText
,
ParsedTokens
,
PromptInputs
,
PromptStrictInputs
,
TextPrompt
,
TextTokensPrompt
,
TokensPrompt
,
parse_and_batch_prompt
)
TextPrompt
,
TokensPrompt
,
parse_and_batch_prompt
)
from
.registry
import
InputContext
,
InputRegistry
INPUT_REGISTRY
=
InputRegistry
()
...
...
@@ -14,6 +13,6 @@ See also:
__all__
=
[
"ParsedText"
,
"ParsedTokens"
,
"parse_and_batch_prompt"
,
"TextPrompt"
,
"TokensPrompt"
,
"
TextTokensPrompt"
,
"PromptStrictInputs"
,
"PromptInputs
"
,
"LLMInputs"
,
"INPUT_REGISTRY"
,
"InputContext"
,
"InputRegistry"
"TokensPrompt"
,
"
PromptInputs"
,
"LLMInputs"
,
"INPUT_REGISTRY
"
,
"InputContext"
,
"InputRegistry"
]
vllm/inputs/data.py
View file @
500b93c8
...
...
@@ -92,25 +92,7 @@ class TokensPrompt(TypedDict):
"""
class
TextTokensPrompt
(
TypedDict
):
"""It is assumed that :attr:`prompt` is consistent with
:attr:`prompt_token_ids`. This is currently used in
:class:`AsyncLLMEngine` for logging both the text and token IDs."""
prompt
:
str
"""The prompt text."""
prompt_token_ids
:
List
[
int
]
"""The token IDs of the prompt."""
multi_modal_data
:
NotRequired
[
"MultiModalDataDict"
]
"""
Optional multi-modal data to pass to the model,
if the model supports it.
"""
PromptStrictInputs
=
Union
[
str
,
TextPrompt
,
TokensPrompt
]
PromptInputs
=
Union
[
str
,
TextPrompt
,
TokensPrompt
]
"""
The inputs to the LLM, which can take one of the following forms:
...
...
@@ -118,10 +100,6 @@ The inputs to the LLM, which can take one of the following forms:
- A tokenized prompt (:class:`TokensPrompt`)
"""
PromptInputs
=
Union
[
str
,
TextPrompt
,
TokensPrompt
,
TextTokensPrompt
]
"""Same as :const:`PromptStrictInputs` but additionally accepts
:class:`TextTokensPrompt`."""
class
LLMInputs
(
TypedDict
):
"""
...
...
vllm/lora/request.py
View file @
500b93c8
from
dataclasses
import
dataclass
import
warnings
from
dataclasses
import
dataclass
,
field
from
typing
import
Optional
from
vllm.adapter_commons.request
import
AdapterRequest
...
...
@@ -20,10 +21,25 @@ class LoRARequest(AdapterRequest):
lora_name
:
str
lora_int_id
:
int
lora_local_path
:
str
lora_path
:
str
=
""
lora_local_path
:
Optional
[
str
]
=
field
(
default
=
None
,
repr
=
False
)
long_lora_max_len
:
Optional
[
int
]
=
None
__hash__
=
AdapterRequest
.
__hash__
def
__post_init__
(
self
):
if
'lora_local_path'
in
self
.
__dict__
:
warnings
.
warn
(
"The 'lora_local_path' attribute is deprecated "
"and will be removed in a future version. "
"Please use 'lora_path' instead."
,
DeprecationWarning
,
stacklevel
=
2
)
if
not
self
.
lora_path
:
self
.
lora_path
=
self
.
lora_local_path
or
""
# Ensure lora_path is not empty
assert
self
.
lora_path
,
"lora_path cannot be empty"
@
property
def
adapter_id
(
self
):
return
self
.
lora_int_id
...
...
@@ -32,6 +48,26 @@ class LoRARequest(AdapterRequest):
def
name
(
self
):
return
self
.
lora_name
@
property
def
path
(
self
):
return
self
.
lora_path
@
property
def
local_path
(
self
):
return
self
.
lora_local_path
warnings
.
warn
(
"The 'local_path' attribute is deprecated "
"and will be removed in a future version. "
"Please use 'path' instead."
,
DeprecationWarning
,
stacklevel
=
2
)
return
self
.
lora_path
@
local_path
.
setter
def
local_path
(
self
,
value
):
warnings
.
warn
(
"The 'local_path' attribute is deprecated "
"and will be removed in a future version. "
"Please use 'path' instead."
,
DeprecationWarning
,
stacklevel
=
2
)
self
.
lora_path
=
value
vllm/lora/utils.py
View file @
500b93c8
import
os
from
typing
import
List
,
Optional
,
Set
,
Tuple
,
Type
import
huggingface_hub
from
huggingface_hub.utils
import
(
EntryNotFoundError
,
HfHubHTTPError
,
HFValidationError
,
RepositoryNotFoundError
)
from
torch
import
nn
from
transformers
import
PretrainedConfig
...
...
@@ -105,3 +109,46 @@ def parse_fine_tuned_lora_name(name: str) -> Tuple[str, bool]:
return
"."
.
join
(
parts
[
2
:
-
1
]),
parts
[
-
1
]
==
"lora_embedding_A"
raise
ValueError
(
f
"
{
name
}
is unsupported LoRA weight"
)
def
get_adapter_absolute_path
(
lora_path
:
str
)
->
str
:
"""
Resolves the given lora_path to an absolute local path.
If the lora_path is identified as a Hugging Face model identifier,
it will download the model and return the local snapshot path.
Otherwise, it treats the lora_path as a local file path and
converts it to an absolute path.
Parameters:
lora_path (str): The path to the lora model, which can be an absolute path,
a relative path, or a Hugging Face model identifier.
Returns:
str: The resolved absolute local path to the lora model.
"""
# Check if the path is an absolute path. Return it no matter exists or not.
if
os
.
path
.
isabs
(
lora_path
):
return
lora_path
# If the path starts with ~, expand the user home directory.
if
lora_path
.
startswith
(
'~'
):
return
os
.
path
.
expanduser
(
lora_path
)
# Check if the expanded relative path exists locally.
if
os
.
path
.
exists
(
lora_path
):
return
os
.
path
.
abspath
(
lora_path
)
# If the path does not exist locally, assume it's a Hugging Face repo.
try
:
local_snapshot_path
=
huggingface_hub
.
snapshot_download
(
repo_id
=
lora_path
)
except
(
HfHubHTTPError
,
RepositoryNotFoundError
,
EntryNotFoundError
,
HFValidationError
):
# Handle errors that may occur during the download
# Return original path instead instead of throwing error here
logger
.
exception
(
"Error downloading the HuggingFace model"
)
return
lora_path
return
local_snapshot_path
vllm/lora/worker_manager.py
View file @
500b93c8
...
...
@@ -13,6 +13,7 @@ from vllm.logger import init_logger
from
vllm.lora.models
import
(
LoRAModel
,
LoRAModelManager
,
LRUCacheLoRAModelManager
,
create_lora_manager
)
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.utils
import
get_adapter_absolute_path
logger
=
init_logger
(
__name__
)
...
...
@@ -89,8 +90,9 @@ class WorkerLoRAManager(AbstractWorkerManager):
packed_modules_mapping
[
module
])
else
:
expected_lora_modules
.
append
(
module
)
lora_path
=
get_adapter_absolute_path
(
lora_request
.
lora_path
)
lora
=
self
.
_lora_model_cls
.
from_local_checkpoint
(
lora_
request
.
lora_local_
path
,
lora_path
,
expected_lora_modules
,
max_position_embeddings
=
self
.
max_position_embeddings
,
lora_model_id
=
lora_request
.
lora_int_id
,
...
...
@@ -102,8 +104,7 @@ class WorkerLoRAManager(AbstractWorkerManager):
embedding_padding_modules
=
self
.
embedding_padding_modules
,
)
except
Exception
as
e
:
raise
RuntimeError
(
f
"Loading lora
{
lora_request
.
lora_local_path
}
failed"
)
from
e
raise
RuntimeError
(
f
"Loading lora
{
lora_path
}
failed"
)
from
e
if
lora
.
rank
>
self
.
lora_config
.
max_lora_rank
:
raise
ValueError
(
f
"LoRA rank
{
lora
.
rank
}
is greater than max_lora_rank "
...
...
vllm/model_executor/layers/fused_moe/layer.py
View file @
500b93c8
...
...
@@ -7,7 +7,7 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size
,
tensor_model_parallel_all_reduce
)
from
vllm.logger
import
init_logger
from
vllm.model_executor.
layers.fused_moe.fused_moe
import
f
us
ed_moe
from
vllm.model_executor.
custom_op
import
C
us
tomOp
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
,
QuantizeMethodBase
)
from
vllm.model_executor.utils
import
set_weight_attrs
...
...
@@ -36,7 +36,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
raise
NotImplementedError
class
UnquantizedFusedMoEMethod
(
FusedMoEMethodBase
):
class
UnquantizedFusedMoEMethod
(
FusedMoEMethodBase
,
CustomOp
):
"""MoE method without quantization."""
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
num_experts
:
int
,
...
...
@@ -61,19 +61,37 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase):
layer
.
register_parameter
(
"w2_weight"
,
w2_weight
)
set_weight_attrs
(
w2_weight
,
extra_weight_attrs
)
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
top_k
:
int
,
renormalize
:
bool
=
True
,
use_grouped_topk
:
bool
=
False
,
num_expert_group
:
Optional
[
int
]
=
None
,
topk_group
:
Optional
[
int
]
=
None
)
->
torch
.
Tensor
:
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
top_k
:
int
,
renormalize
:
bool
=
True
,
use_grouped_topk
:
bool
=
False
,
num_expert_group
:
Optional
[
int
]
=
None
,
topk_group
:
Optional
[
int
]
=
None
,
)
->
torch
.
Tensor
:
return
self
.
forward
(
x
,
layer
.
w13_weight
,
layer
.
w2_weight
,
router_logits
,
top_k
,
renormalize
,
use_grouped_topk
,
num_expert_group
,
topk_group
)
def
forward_cuda
(
self
,
x
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
top_k
:
int
,
renormalize
:
bool
,
use_grouped_topk
:
bool
,
num_expert_group
:
Optional
[
int
],
topk_group
:
Optional
[
int
],
)
->
torch
.
Tensor
:
from
vllm.model_executor.layers.fused_moe.fused_moe
import
fused_moe
return
fused_moe
(
x
,
layer
.
w13_weight
,
layer
.
w2_weight
,
w1
,
w2
,
router_logits
,
top_k
,
renormalize
=
renormalize
,
...
...
@@ -82,6 +100,28 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase):
num_expert_group
=
num_expert_group
,
topk_group
=
topk_group
)
def
forward_cpu
(
self
,
*
args
,
**
kwargs
):
raise
NotImplementedError
(
"The CPU backend currently does not support MoE."
)
def
forward_tpu
(
self
,
x
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
top_k
:
int
,
renormalize
:
bool
,
use_grouped_topk
:
bool
,
num_expert_group
:
Optional
[
int
],
topk_group
:
Optional
[
int
],
)
->
torch
.
Tensor
:
from
vllm.model_executor.layers.fused_moe.moe_pallas
import
fused_moe
assert
not
use_grouped_topk
assert
num_expert_group
is
None
assert
topk_group
is
None
return
fused_moe
(
x
,
w1
,
w2
,
router_logits
,
top_k
,
renormalize
)
class
FusedMoE
(
torch
.
nn
.
Module
):
"""FusedMoE layer for MoE models.
...
...
@@ -118,6 +158,7 @@ class FusedMoE(torch.nn.Module):
topk_group
:
Optional
[
int
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
tp_size
:
Optional
[
int
]
=
None
,
prefix
:
str
=
""
,
):
super
().
__init__
()
...
...
@@ -141,7 +182,7 @@ class FusedMoE(torch.nn.Module):
self
.
quant_method
:
Optional
[
QuantizeMethodBase
]
=
(
UnquantizedFusedMoEMethod
())
else
:
self
.
quant_method
=
quant_config
.
get_quant_method
(
self
)
self
.
quant_method
=
quant_config
.
get_quant_method
(
self
,
prefix
)
assert
self
.
quant_method
is
not
None
self
.
quant_method
.
create_weights
(
...
...
vllm/model_executor/layers/fused_moe/moe_pallas.py
0 → 100644
View file @
500b93c8
import
torch
import
torch.nn.functional
as
F
from
torch_xla.experimental.custom_kernel
import
_histogram
def
fused_moe
(
hidden_states
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
gating_output
:
torch
.
Tensor
,
topk
:
int
,
renormalize
:
bool
,
)
->
torch
.
Tensor
:
"""
Args:
hidden_states: [*, hidden_size]
w1: [num_experts, intermediate_size * 2, hidden_size]
w2: [num_experts, hidden_size, intermediate_size]
gating_output: [*, num_experts]
"""
orig_shape
=
hidden_states
.
shape
hidden_size
=
hidden_states
.
shape
[
-
1
]
num_tokens
=
hidden_states
.
shape
[:
-
1
].
numel
()
num_experts
=
w1
.
shape
[
0
]
intermediate_size
=
w2
.
shape
[
-
1
]
device
=
hidden_states
.
device
dtype
=
hidden_states
.
dtype
assert
(
num_tokens
*
topk
)
%
16
==
0
,
(
"The Pallas GMM kernel requires num_tokens * topk to be a multiple of "
f
"16 but got
{
num_tokens
*
topk
}
"
)
hidden_states
=
hidden_states
.
view
(
num_tokens
,
hidden_size
)
gating_output
=
gating_output
.
view
(
num_tokens
,
num_experts
)
topk_weights
=
gating_output
.
softmax
(
dim
=-
1
,
dtype
=
torch
.
float
)
topk_weights
,
topk_indices
=
topk_weights
.
topk
(
topk
,
dim
=-
1
)
if
renormalize
:
topk_weights
=
topk_weights
/
topk_weights
.
sum
(
dim
=-
1
,
keepdim
=
True
)
topk_weights
=
topk_weights
.
to
(
dtype
)
topk_indices
=
topk_indices
.
flatten
()
topk_argsort_indices
=
topk_indices
.
argsort
()
topk_argsort_revert_indices
=
topk_argsort_indices
.
argsort
()
token_indices
=
torch
.
arange
(
num_tokens
,
device
=
device
).
repeat_interleave
(
topk
)
token_indices
=
token_indices
[
topk_argsort_indices
]
group_sizes
=
_histogram
(
topk_indices
.
to
(
torch
.
int32
),
0
,
num_experts
-
1
)
# NOTE(woosuk): The GMM Pallas kernel requires a different weight layout
# from HF Transformers.
w1
=
w1
.
transpose
(
1
,
2
)
w2
=
w2
.
transpose
(
1
,
2
)
x
=
hidden_states
[
token_indices
]
x
=
torch
.
ops
.
xla
.
gmm
(
x
,
w1
,
group_sizes
)
x
=
F
.
silu
(
x
[...,
:
intermediate_size
])
*
x
[...,
intermediate_size
:]
x
=
torch
.
ops
.
xla
.
gmm
(
x
,
w2
,
group_sizes
)
x
=
x
[
topk_argsort_revert_indices
].
reshape
(
-
1
,
topk
,
hidden_size
)
x
=
x
*
topk_weights
.
unsqueeze_
(
dim
=-
1
)
x
=
x
.
sum
(
dim
=-
2
)
x
=
x
.
reshape
(
orig_shape
)
return
x
vllm/model_executor/layers/linear.py
View file @
500b93c8
...
...
@@ -160,6 +160,7 @@ class LinearBase(torch.nn.Module):
skip_bias_add
:
bool
=
False
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
):
super
().
__init__
()
...
...
@@ -174,7 +175,8 @@ class LinearBase(torch.nn.Module):
self
.
quant_method
:
Optional
[
QuantizeMethodBase
]
=
UnquantizedLinearMethod
()
else
:
self
.
quant_method
=
quant_config
.
get_quant_method
(
self
)
self
.
quant_method
=
quant_config
.
get_quant_method
(
self
,
prefix
=
prefix
)
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
raise
NotImplementedError
...
...
@@ -190,6 +192,8 @@ class ReplicatedLinear(LinearBase):
skip_bias_add: If true, skip adding bias but instead return it.
params_dtype: Data type for the parameters.
quant_config: Quantization configure.
prefix: The name of the layer in the state dict, including all parents
(e.g. model.layers.0.qkv_proj)
"""
def
__init__
(
self
,
...
...
@@ -198,15 +202,23 @@ class ReplicatedLinear(LinearBase):
bias
:
bool
=
True
,
skip_bias_add
:
bool
=
False
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
):
super
().
__init__
(
input_size
,
output_size
,
skip_bias_add
,
params_dtype
,
quant_config
)
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
):
super
().
__init__
(
input_size
,
output_size
,
skip_bias_add
,
params_dtype
,
quant_config
,
prefix
=
prefix
)
# All the linear layer supports quant method.
assert
self
.
quant_method
is
not
None
self
.
quant_method
.
create_weights
(
self
,
self
.
input_size
,
[
self
.
output_size
],
self
.
input_size
,
self
.
output_size
,
self
.
params_dtype
)
self
.
quant_method
.
create_weights
(
self
,
self
.
input_size
,
[
self
.
output_size
],
self
.
input_size
,
self
.
output_size
,
self
.
params_dtype
,
prefix
=
prefix
)
if
bias
:
self
.
bias
=
Parameter
(
...
...
@@ -215,6 +227,15 @@ class ReplicatedLinear(LinearBase):
else
:
self
.
register_parameter
(
"bias"
,
None
)
def
weight_loader
(
self
,
param
:
Parameter
,
loaded_weight
:
torch
.
Tensor
):
# If the weight on disk does not have a shape, give it one
# (such scales for AutoFp8).
if
len
(
loaded_weight
.
shape
)
==
0
:
loaded_weight
=
loaded_weight
.
reshape
(
1
)
assert
param
.
size
()
==
loaded_weight
.
size
()
param
.
data
.
copy_
(
loaded_weight
)
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
bias
=
self
.
bias
if
not
self
.
skip_bias_add
else
None
assert
self
.
quant_method
is
not
None
...
...
@@ -249,6 +270,8 @@ class ColumnParallelLinear(LinearBase):
quant_config: Quantization configure.
output_sizes: list of output sizes packed into one output, like for QKV
the list would be size 3.
prefix: The name of the layer in the state dict, including all parents
(e.g. model.layers.0.qkv_proj)
"""
def
__init__
(
self
,
...
...
@@ -259,9 +282,10 @@ class ColumnParallelLinear(LinearBase):
skip_bias_add
:
bool
=
False
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
output_sizes
:
Optional
[
List
[
int
]]
=
None
):
output_sizes
:
Optional
[
List
[
int
]]
=
None
,
prefix
:
str
=
""
):
super
().
__init__
(
input_size
,
output_size
,
skip_bias_add
,
params_dtype
,
quant_config
)
quant_config
,
prefix
)
self
.
gather_output
=
gather_output
...
...
@@ -286,7 +310,8 @@ class ColumnParallelLinear(LinearBase):
input_size
=
self
.
input_size
,
output_size
=
self
.
output_size
,
params_dtype
=
self
.
params_dtype
,
weight_loader
=
self
.
weight_loader
)
weight_loader
=
self
.
weight_loader
,
prefix
=
prefix
)
if
bias
:
self
.
bias
=
Parameter
(
torch
.
empty
(
self
.
output_size_per_partition
,
...
...
@@ -358,6 +383,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
skip adding bias but instead return it.
params_dtype: Data type for the parameters.
quant_config: Quantization configure.
prefix: The name of the layer in the state dict, including all parents
(e.g. model.layers.0.qkv_proj)
"""
def
__init__
(
self
,
...
...
@@ -367,7 +394,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
gather_output
:
bool
=
False
,
skip_bias_add
:
bool
=
False
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
):
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
):
self
.
output_sizes
=
output_sizes
tp_size
=
get_tensor_model_parallel_world_size
()
assert
all
(
output_size
%
tp_size
==
0
for
output_size
in
output_sizes
)
...
...
@@ -377,7 +405,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
gather_output
=
gather_output
,
skip_bias_add
=
skip_bias_add
,
params_dtype
=
params_dtype
,
quant_config
=
quant_config
)
quant_config
=
quant_config
,
prefix
=
prefix
)
def
weight_loader
(
self
,
param
:
Parameter
,
...
...
@@ -497,6 +526,8 @@ class QKVParallelLinear(ColumnParallelLinear):
skip adding bias but instead return it.
params_dtype: Data type for the parameters.
quant_config: Quantization configure.
prefix: The name of the layer in the state dict, including all parents
(e.g. model.layers.0.qkv_proj)
"""
def
__init__
(
self
,
...
...
@@ -507,7 +538,8 @@ class QKVParallelLinear(ColumnParallelLinear):
bias
:
bool
=
True
,
skip_bias_add
:
bool
=
False
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
):
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
):
self
.
hidden_size
=
hidden_size
self
.
head_size
=
head_size
self
.
total_num_heads
=
total_num_heads
...
...
@@ -539,7 +571,8 @@ class QKVParallelLinear(ColumnParallelLinear):
gather_output
=
False
,
skip_bias_add
=
skip_bias_add
,
params_dtype
=
params_dtype
,
quant_config
=
quant_config
)
quant_config
=
quant_config
,
prefix
=
prefix
)
def
weight_loader
(
self
,
param
:
Parameter
,
...
...
@@ -698,14 +731,16 @@ class RowParallelLinear(LinearBase):
skip_bias_add
:
bool
=
False
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
reduce_results
:
bool
=
True
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
):
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
):
super
().
__init__
(
input_size
,
output_size
,
skip_bias_add
,
params_dtype
,
quant_config
)
quant_config
,
prefix
)
self
.
input_is_parallel
=
input_is_parallel
self
.
reduce_results
=
reduce_results
# Divide the weight matrix along the last dimension.
self
.
tp_rank
=
get_tensor_model_parallel_rank
()
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
input_size_per_partition
=
divide
(
input_size
,
self
.
tp_size
)
assert
self
.
quant_method
is
not
None
...
...
@@ -716,7 +751,8 @@ class RowParallelLinear(LinearBase):
input_size
=
self
.
input_size
,
output_size
=
self
.
output_size
,
params_dtype
=
self
.
params_dtype
,
weight_loader
=
self
.
weight_loader
)
weight_loader
=
self
.
weight_loader
,
prefix
=
prefix
)
if
not
reduce_results
and
(
bias
and
not
skip_bias_add
):
raise
ValueError
(
"When not reduce the results, adding bias to the "
"results can lead to incorrect results"
)
...
...
@@ -760,18 +796,19 @@ class RowParallelLinear(LinearBase):
# Matrix multiply.
assert
self
.
quant_method
is
not
None
output_parallel
=
self
.
quant_method
.
apply
(
self
,
input_parallel
)
# Only fuse bias add into GEMM for rank 0 (this ensures that
# bias will not get added more than once in TP>1 case)
bias_
=
None
if
(
self
.
tp_rank
>
0
or
self
.
skip_bias_add
)
else
self
.
bias
output_parallel
=
self
.
quant_method
.
apply
(
self
,
input_parallel
,
bias
=
bias_
)
if
self
.
reduce_results
and
self
.
tp_size
>
1
:
output
_
=
tensor_model_parallel_all_reduce
(
output_parallel
)
output
=
tensor_model_parallel_all_reduce
(
output_parallel
)
else
:
output_
=
output_parallel
output
=
output_parallel
output_bias
=
self
.
bias
if
self
.
skip_bias_add
else
None
if
not
self
.
skip_bias_add
:
output
=
output_
+
self
.
bias
if
self
.
bias
is
not
None
else
output_
output_bias
=
None
else
:
output
=
output_
output_bias
=
self
.
bias
return
output
,
output_bias
def
extra_repr
(
self
)
->
str
:
...
...
vllm/model_executor/layers/quantization/__init__.py
View file @
500b93c8
...
...
@@ -2,6 +2,7 @@ from typing import Dict, Type
from
vllm.model_executor.layers.quantization.aqlm
import
AQLMConfig
from
vllm.model_executor.layers.quantization.awq
import
AWQConfig
from
vllm.model_executor.layers.quantization.awq_marlin
import
AWQMarlinConfig
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.quantization.bitsandbytes
import
(
...
...
@@ -10,6 +11,7 @@ from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tenso
CompressedTensorsConfig
)
from
vllm.model_executor.layers.quantization.deepspeedfp
import
(
DeepSpeedFPConfig
)
from
vllm.model_executor.layers.quantization.fbgemm_fp8
import
FBGEMMFp8Config
from
vllm.model_executor.layers.quantization.fp8
import
Fp8Config
from
vllm.model_executor.layers.quantization.gptq
import
GPTQConfig
from
vllm.model_executor.layers.quantization.gptq_marlin
import
(
...
...
@@ -24,11 +26,13 @@ QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
"awq"
:
AWQConfig
,
"deepspeedfp"
:
DeepSpeedFPConfig
,
"fp8"
:
Fp8Config
,
"fbgemm_fp8"
:
FBGEMMFp8Config
,
# The order of gptq methods is important for config.py iteration over
# override_quantization_method(..)
"marlin"
:
MarlinConfig
,
"gptq_marlin_24"
:
GPTQMarlin24Config
,
"gptq_marlin"
:
GPTQMarlinConfig
,
"awq_marlin"
:
AWQMarlinConfig
,
"gptq"
:
GPTQConfig
,
"squeezellm"
:
SqueezeLLMConfig
,
"compressed-tensors"
:
CompressedTensorsConfig
,
...
...
vllm/model_executor/layers/quantization/aqlm.py
View file @
500b93c8
...
...
@@ -207,8 +207,8 @@ class AQLMConfig(QuantizationConfig):
return
cls
(
in_group_size
,
nbits_per_codebook
,
num_code_books
,
out_group_size
)
def
get_quant_method
(
self
,
layer
:
torch
.
nn
.
Module
)
->
Optional
[
"AQLMLinearMethod"
]:
def
get_quant_method
(
self
,
layer
:
torch
.
nn
.
Module
,
prefix
:
str
)
->
Optional
[
"AQLMLinearMethod"
]:
if
isinstance
(
layer
,
LinearBase
):
return
AQLMLinearMethod
(
self
)
return
None
...
...
vllm/model_executor/layers/quantization/awq.py
View file @
500b93c8
...
...
@@ -63,8 +63,8 @@ class AWQConfig(QuantizationConfig):
zero_point
=
cls
.
get_from_keys
(
config
,
[
"zero_point"
])
return
cls
(
weight_bits
,
group_size
,
zero_point
)
def
get_quant_method
(
self
,
layer
:
torch
.
nn
.
Module
)
->
Optional
[
"AWQLinearMethod"
]:
def
get_quant_method
(
self
,
layer
:
torch
.
nn
.
Module
,
prefix
:
str
)
->
Optional
[
"AWQLinearMethod"
]:
if
isinstance
(
layer
,
LinearBase
):
return
AWQLinearMethod
(
self
)
return
None
...
...
vllm/model_executor/layers/quantization/awq_marlin.py
0 → 100644
View file @
500b93c8
from
typing
import
Any
,
Dict
,
List
,
Optional
import
torch
from
torch.nn.parameter
import
Parameter
from
vllm
import
_custom_ops
as
ops
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.linear
import
(
LinearBase
,
LinearMethodBase
,
set_weight_attrs
)
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils
import
(
apply_awq_marlin_linear
,
awq_to_marlin_zero_points
,
check_awq_marlin_supported
,
marlin_make_empty_g_idx
,
marlin_make_workspace
,
marlin_permute_scales
,
replace_tensor
,
verify_awq_marlin_supported
,
verify_marlin_supports_shape
)
from
vllm.model_executor.layers.vocab_parallel_embedding
import
ParallelLMHead
logger
=
init_logger
(
__name__
)
class
AWQMarlinConfig
(
QuantizationConfig
):
"""Config class for AWQ Marlin"""
def
__init__
(
self
,
weight_bits
:
int
,
group_size
:
int
,
has_zp
:
bool
,
lm_head_quantized
:
bool
)
->
None
:
self
.
weight_bits
=
weight_bits
self
.
pack_factor
=
32
//
self
.
weight_bits
# packed into int32
self
.
group_size
=
group_size
self
.
has_zp
=
has_zp
self
.
lm_head_quantized
=
lm_head_quantized
verify_awq_marlin_supported
(
num_bits
=
self
.
weight_bits
,
group_size
=
self
.
group_size
,
has_zp
=
self
.
has_zp
)
def
__repr__
(
self
)
->
str
:
return
(
f
"AWQMarlinConfig(weight_bits=
{
self
.
weight_bits
}
, "
f
"group_size=
{
self
.
group_size
}
, "
f
"has_zp=
{
self
.
has_zp
}
, "
f
"lm_head_quantized=
{
self
.
lm_head_quantized
}
)"
)
@
classmethod
def
get_name
(
cls
)
->
str
:
return
"awq_marlin"
@
classmethod
def
get_supported_act_dtypes
(
cls
)
->
List
[
torch
.
dtype
]:
return
[
torch
.
half
,
torch
.
bfloat16
]
@
classmethod
def
get_min_capability
(
cls
)
->
int
:
return
80
@
classmethod
def
get_config_filenames
(
cls
)
->
List
[
str
]:
return
[
"quantize_config.json"
]
@
classmethod
def
from_config
(
cls
,
config
:
Dict
[
str
,
Any
])
->
"AWQMarlinConfig"
:
weight_bits
=
cls
.
get_from_keys
(
config
,
[
"bits"
])
group_size
=
cls
.
get_from_keys
(
config
,
[
"group_size"
])
has_zp
=
cls
.
get_from_keys
(
config
,
[
"zero_point"
])
lm_head_quantized
=
cls
.
get_from_keys_or
(
config
,
[
"lm_head"
],
default
=
False
)
return
cls
(
weight_bits
,
group_size
,
has_zp
,
lm_head_quantized
)
@
classmethod
def
override_quantization_method
(
cls
,
hf_quant_cfg
,
user_quant
)
->
Optional
[
str
]:
can_convert
=
cls
.
is_awq_marlin_compatible
(
hf_quant_cfg
)
is_valid_user_quant
=
(
user_quant
is
None
or
user_quant
==
"marlin"
)
if
can_convert
and
is_valid_user_quant
:
msg
=
(
"The model is convertible to {} during runtime."
" Using {} kernel."
.
format
(
cls
.
get_name
(),
cls
.
get_name
()))
logger
.
info
(
msg
)
return
cls
.
get_name
()
if
can_convert
and
user_quant
==
"awq"
:
logger
.
info
(
"Detected that the model can run with awq_marlin"
", however you specified quantization=awq explicitly,"
" so forcing awq. Use quantization=awq_marlin for"
" faster inference"
)
return
None
def
get_quant_method
(
self
,
layer
:
torch
.
nn
.
Module
,
prefix
:
str
)
->
Optional
[
"AWQMarlinLinearMethod"
]:
if
(
isinstance
(
layer
,
LinearBase
)
or
(
isinstance
(
layer
,
ParallelLMHead
)
and
self
.
lm_head_quantized
)):
return
AWQMarlinLinearMethod
(
self
)
return
None
def
get_scaled_act_names
(
self
)
->
List
[
str
]:
return
[]
@
classmethod
def
is_awq_marlin_compatible
(
cls
,
quant_config
:
Dict
[
str
,
Any
]):
# Extract data from quant config.
quant_method
=
quant_config
.
get
(
"quant_method"
,
""
).
lower
()
num_bits
=
quant_config
.
get
(
"bits"
,
None
)
group_size
=
quant_config
.
get
(
"group_size"
,
None
)
has_zp
=
quant_config
.
get
(
"zero_point"
,
None
)
if
quant_method
!=
"awq"
:
return
False
# If we cannot find the info needed in the config, cannot convert.
if
(
num_bits
is
None
or
group_size
is
None
or
has_zp
is
None
):
return
False
return
check_awq_marlin_supported
(
num_bits
=
num_bits
,
group_size
=
group_size
,
has_zp
=
has_zp
,
min_capability
=
cls
.
get_min_capability
())
class
AWQMarlinLinearMethod
(
LinearMethodBase
):
"""Linear method for AWQ Marlin.
Args:
quant_config: The AWQ Marlin quantization config.
"""
def
__init__
(
self
,
quant_config
:
AWQMarlinConfig
)
->
None
:
self
.
quant_config
=
quant_config
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
input_size_per_partition
:
int
,
output_partition_sizes
:
List
[
int
],
input_size
:
int
,
output_size
:
int
,
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
,
)
->
None
:
del
output_size
output_size_per_partition
=
sum
(
output_partition_sizes
)
# Normalize group_size
if
self
.
quant_config
.
group_size
!=
-
1
:
group_size
=
self
.
quant_config
.
group_size
else
:
group_size
=
input_size
verify_marlin_supports_shape
(
output_size_per_partition
=
output_size_per_partition
,
input_size_per_partition
=
input_size_per_partition
,
input_size
=
input_size
,
group_size
=
group_size
)
qweight
=
Parameter
(
torch
.
empty
(
input_size_per_partition
,
output_size_per_partition
//
self
.
quant_config
.
pack_factor
,
dtype
=
torch
.
int32
,
),
requires_grad
=
False
,
)
set_weight_attrs
(
qweight
,
{
"input_dim"
:
0
,
"output_dim"
:
1
,
"packed_dim"
:
1
,
"pack_factor"
:
self
.
quant_config
.
pack_factor
,
})
num_groups
=
input_size_per_partition
//
group_size
qzeros
=
Parameter
(
torch
.
empty
(
num_groups
,
output_size_per_partition
//
self
.
quant_config
.
pack_factor
,
dtype
=
torch
.
int32
,
),
requires_grad
=
False
,
)
set_weight_attrs
(
qzeros
,
{
"input_dim"
:
0
,
"output_dim"
:
1
,
"packed_dim"
:
1
,
"pack_factor"
:
self
.
quant_config
.
pack_factor
,
})
scales
=
Parameter
(
torch
.
empty
(
num_groups
,
output_size_per_partition
,
dtype
=
params_dtype
,
),
requires_grad
=
False
,
)
set_weight_attrs
(
scales
,
{
"input_dim"
:
0
,
"output_dim"
:
1
,
})
layer
.
register_parameter
(
"qweight"
,
qweight
)
set_weight_attrs
(
qweight
,
extra_weight_attrs
)
layer
.
register_parameter
(
"qzeros"
,
qzeros
)
set_weight_attrs
(
qzeros
,
extra_weight_attrs
)
layer
.
register_parameter
(
"scales"
,
scales
)
set_weight_attrs
(
scales
,
extra_weight_attrs
)
layer
.
input_size_per_partition
=
input_size_per_partition
layer
.
output_size_per_partition
=
output_size_per_partition
layer
.
num_groups
=
num_groups
# TODO: Update this docs
# Checkpoints are serialized in AutoAWQ format, which is different from the
# marlin format. This function is called after the weights are loaded.
# Here, we handle the repacking
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
device
=
layer
.
qweight
.
device
# Allocate marlin workspace
layer
.
workspace
=
marlin_make_workspace
(
layer
.
output_size_per_partition
,
device
)
# Repack weights from AWQ format to marlin format.
marlin_qweight
=
ops
.
awq_marlin_repack
(
layer
.
qweight
,
size_k
=
layer
.
input_size_per_partition
,
size_n
=
layer
.
output_size_per_partition
,
num_bits
=
self
.
quant_config
.
weight_bits
)
replace_tensor
(
layer
,
"qweight"
,
marlin_qweight
)
# Permute scales from AWQ format to marlin format.
marlin_scales
=
marlin_permute_scales
(
layer
.
scales
,
size_k
=
layer
.
input_size_per_partition
,
size_n
=
layer
.
output_size_per_partition
,
group_size
=
self
.
quant_config
.
group_size
)
replace_tensor
(
layer
,
"scales"
,
marlin_scales
)
# Permute zero-points from AWQ format to marlin format.
marlin_zp
=
awq_to_marlin_zero_points
(
layer
.
qzeros
,
size_k
=
layer
.
num_groups
,
size_n
=
layer
.
output_size_per_partition
,
num_bits
=
self
.
quant_config
.
weight_bits
)
replace_tensor
(
layer
,
"qzeros"
,
marlin_zp
)
# Not-used
layer
.
g_idx
=
marlin_make_empty_g_idx
(
device
)
layer
.
g_idx_sort_indices
=
marlin_make_empty_g_idx
(
device
)
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
return
apply_awq_marlin_linear
(
input
=
x
,
weight
=
layer
.
qweight
,
weight_scale
=
layer
.
scales
,
weight_zp
=
layer
.
qzeros
,
g_idx
=
layer
.
g_idx
,
g_idx_sort_indices
=
layer
.
g_idx_sort_indices
,
workspace
=
layer
.
workspace
,
num_bits
=
self
.
quant_config
.
weight_bits
,
output_size_per_partition
=
layer
.
output_size_per_partition
,
input_size_per_partition
=
layer
.
input_size_per_partition
,
bias
=
bias
)
vllm/model_executor/layers/quantization/base_config.py
View file @
500b93c8
...
...
@@ -97,12 +97,13 @@ class QuantizationConfig(ABC):
return
default
@
abstractmethod
def
get_quant_method
(
self
,
layer
:
torch
.
nn
.
Module
)
->
Optional
[
QuantizeMethodBase
]:
def
get_quant_method
(
self
,
layer
:
torch
.
nn
.
Module
,
prefix
:
str
)
->
Optional
[
QuantizeMethodBase
]:
"""Get the quantize method to use for the quantized layer.
Args:
layer: The layer for the quant method.
prefix: The full name of the layer in the state dict
Returns:
The quantize method. None if the given layer doesn't support quant
method.
...
...
Prev
1
…
6
7
8
9
10
11
12
13
14
15
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment