Commit 500b93c8 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.5.3.post1' into v0.5.3.post1-dtk24.04.1

parents 99426767 38c4b7e8
...@@ -11,6 +11,8 @@ logger = init_logger(__name__) ...@@ -11,6 +11,8 @@ logger = init_logger(__name__)
class NeuronExecutor(ExecutorBase): class NeuronExecutor(ExecutorBase):
uses_ray: bool = False
def _init_executor(self) -> None: def _init_executor(self) -> None:
assert (self.lora_config is assert (self.lora_config is
None), "LoRA is not supported for Neuron backend." None), "LoRA is not supported for Neuron backend."
......
...@@ -18,6 +18,8 @@ logger = init_logger(__name__) ...@@ -18,6 +18,8 @@ logger = init_logger(__name__)
class OpenVINOExecutor(ExecutorBase): class OpenVINOExecutor(ExecutorBase):
uses_ray: bool = False
def _init_executor(self) -> None: def _init_executor(self) -> None:
assert self.device_config.device_type == "openvino" assert self.device_config.device_type == "openvino"
assert self.lora_config is None, "OpenVINO backend doesn't support LoRA" assert self.lora_config is None, "OpenVINO backend doesn't support LoRA"
......
import asyncio import asyncio
import os import os
import pickle
from collections import defaultdict from collections import defaultdict
from itertools import islice, repeat from itertools import islice, repeat
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
...@@ -11,7 +10,8 @@ from vllm.executor.distributed_gpu_executor import ( # yapf: disable ...@@ -11,7 +10,8 @@ from vllm.executor.distributed_gpu_executor import ( # yapf: disable
from vllm.executor.ray_utils import RayWorkerWrapper, ray from vllm.executor.ray_utils import RayWorkerWrapper, ray
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.sequence import ExecuteModelRequest, SamplerOutput from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.utils import (error_on_invalid_device_count_status, from vllm.utils import (_run_task_with_lock,
error_on_invalid_device_count_status,
get_distributed_init_method, get_ip, get_open_port, get_distributed_init_method, get_ip, get_open_port,
get_vllm_instance_id, make_async) get_vllm_instance_id, make_async)
...@@ -23,13 +23,33 @@ if TYPE_CHECKING: ...@@ -23,13 +23,33 @@ if TYPE_CHECKING:
logger = init_logger(__name__) logger = init_logger(__name__)
USE_RAY_COMPILED_DAG = envs.VLLM_USE_RAY_COMPILED_DAG
class RayGPUExecutor(DistributedGPUExecutor): class RayGPUExecutor(DistributedGPUExecutor):
uses_ray: bool = True
def _init_executor(self) -> None: def _init_executor(self) -> None:
assert self.parallel_config.distributed_executor_backend == "ray" # If the env var is set, it uses the Ray's compiled DAG API
# which optimizes the control plane overhead.
# Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it.
# Currently, this requires USE_RAY_SPMD_WORKER=True.
self.use_ray_compiled_dag = envs.VLLM_USE_RAY_COMPILED_DAG
# If the env var is set, then we do not distinguish between the
# "driver worker" vs other workers. Also, the rank 0 worker will
# be executed in a remote Ray worker. Currently this requires
# USE_RAY_COMPILED_DAG=True.
self.use_ray_spmd_worker = envs.VLLM_USE_RAY_SPMD_WORKER
if self.use_ray_compiled_dag:
assert self.use_ray_spmd_worker, (
"VLLM_USE_RAY_COMPILED_DAG=1 requires "
"VLLM_USE_RAY_SPMD_WORKER=1")
if self.use_ray_spmd_worker:
# TODO: Support SPMD worker for non-DAG Ray executor.
assert self.use_ray_compiled_dag, (
"VLLM_USE_RAY_SPMD_WORKER=1 requires "
"VLLM_USE_RAY_COMPILED_DAG=1")
assert self.uses_ray
placement_group = self.parallel_config.placement_group placement_group = self.parallel_config.placement_group
# Disable Ray usage stats collection. # Disable Ray usage stats collection.
...@@ -40,11 +60,7 @@ class RayGPUExecutor(DistributedGPUExecutor): ...@@ -40,11 +60,7 @@ class RayGPUExecutor(DistributedGPUExecutor):
# Create the parallel GPU workers. # Create the parallel GPU workers.
self._init_workers_ray(placement_group) self._init_workers_ray(placement_group)
self.forward_dag = None self.forward_dag: Optional["ray.dag.CompiledDAG"] = None
if USE_RAY_COMPILED_DAG:
self.forward_dag = self._compiled_ray_dag()
self.extra_execute_model_run_workers_kwargs[
"use_ray_compiled_dag"] = True
def _configure_ray_workers_use_nsight(self, def _configure_ray_workers_use_nsight(self,
ray_remote_kwargs) -> Dict[str, Any]: ray_remote_kwargs) -> Dict[str, Any]:
...@@ -61,6 +77,20 @@ class RayGPUExecutor(DistributedGPUExecutor): ...@@ -61,6 +77,20 @@ class RayGPUExecutor(DistributedGPUExecutor):
return ray_remote_kwargs return ray_remote_kwargs
def _get_worker_wrapper_args(self) -> Dict[str, Any]:
if self.speculative_config is not None:
worker_module_name = "vllm.spec_decode.spec_decode_worker"
worker_class_name = "create_spec_worker"
else:
worker_module_name = "vllm.worker.worker"
worker_class_name = "Worker"
return dict(
worker_module_name=worker_module_name,
worker_class_name=worker_class_name,
trust_remote_code=self.model_config.trust_remote_code,
)
def _init_workers_ray(self, placement_group: "PlacementGroup", def _init_workers_ray(self, placement_group: "PlacementGroup",
**ray_remote_kwargs): **ray_remote_kwargs):
if (self.parallel_config.tensor_parallel_size == 1 if (self.parallel_config.tensor_parallel_size == 1
...@@ -83,6 +113,7 @@ class RayGPUExecutor(DistributedGPUExecutor): ...@@ -83,6 +113,7 @@ class RayGPUExecutor(DistributedGPUExecutor):
# Create the workers. # Create the workers.
driver_ip = get_ip() driver_ip = get_ip()
worker_wrapper_kwargs = self._get_worker_wrapper_args()
for bundle_id, bundle in enumerate(placement_group.bundle_specs): for bundle_id, bundle in enumerate(placement_group.bundle_specs):
if not bundle.get("GPU", 0): if not bundle.get("GPU", 0):
continue continue
...@@ -92,39 +123,28 @@ class RayGPUExecutor(DistributedGPUExecutor): ...@@ -92,39 +123,28 @@ class RayGPUExecutor(DistributedGPUExecutor):
placement_group_bundle_index=bundle_id, placement_group_bundle_index=bundle_id,
) )
if self.speculative_config is not None:
worker_module_name = "vllm.spec_decode.spec_decode_worker"
worker_class_name = "create_spec_worker"
else:
worker_module_name = "vllm.worker.worker"
worker_class_name = "Worker"
worker = ray.remote( worker = ray.remote(
num_cpus=0, num_cpus=0,
num_gpus=num_gpus, num_gpus=num_gpus,
scheduling_strategy=scheduling_strategy, scheduling_strategy=scheduling_strategy,
**ray_remote_kwargs, **ray_remote_kwargs,
)(RayWorkerWrapper).remote( )(RayWorkerWrapper).remote(**worker_wrapper_kwargs)
worker_module_name=worker_module_name,
worker_class_name=worker_class_name,
trust_remote_code=self.model_config.trust_remote_code,
)
worker_ip = ray.get(worker.get_node_ip.remote()) if self.use_ray_spmd_worker:
if worker_ip == driver_ip and self.driver_dummy_worker is None:
# If the worker is on the same node as the driver, we use it
# as the resource holder for the driver process.
self.driver_dummy_worker = worker
self.driver_worker = RayWorkerWrapper(
worker_module_name=worker_module_name,
worker_class_name=worker_class_name,
trust_remote_code=self.model_config.trust_remote_code,
)
else:
# Else, added to the list of workers.
self.workers.append(worker) self.workers.append(worker)
else:
if self.driver_dummy_worker is None: worker_ip = ray.get(worker.get_node_ip.remote())
if worker_ip == driver_ip and self.driver_dummy_worker is None:
# If the worker is on the same node as the driver, we use it
# as the resource holder for the driver process.
self.driver_dummy_worker = worker
self.driver_worker = RayWorkerWrapper(
**worker_wrapper_kwargs)
else:
# Else, added to the list of workers.
self.workers.append(worker)
if not self.use_ray_spmd_worker and self.driver_dummy_worker is None:
raise ValueError( raise ValueError(
"Ray does not allocate any GPUs on the driver node. Consider " "Ray does not allocate any GPUs on the driver node. Consider "
"adjusting the Ray placement group or running the driver on a " "adjusting the Ray placement group or running the driver on a "
...@@ -224,13 +244,14 @@ class RayGPUExecutor(DistributedGPUExecutor): ...@@ -224,13 +244,14 @@ class RayGPUExecutor(DistributedGPUExecutor):
# broadcasted to. # broadcasted to.
self.non_driver_workers: List[RayWorkerWrapper] = [] self.non_driver_workers: List[RayWorkerWrapper] = []
for idx, rank in enumerate(worker_ranks[1:]): # Enforce rank order for correct rank to return final output.
for rank, worker in sorted(zip(worker_ranks[1:], self.workers)):
# We need to skip the driver worker, which we # We need to skip the driver worker, which we
# do by skipping worker_ranks[0] which is always 0. # do by skipping worker_ranks[0] which is always 0.
if rank % self.parallel_config.tensor_parallel_size == 0: if rank % self.parallel_config.tensor_parallel_size == 0:
self.tp_driver_workers.append(self.workers[idx]) self.tp_driver_workers.append(worker)
else: else:
self.non_driver_workers.append(self.workers[idx]) self.non_driver_workers.append(worker)
def _driver_execute_model( def _driver_execute_model(
self, execute_model_req: Optional[ExecuteModelRequest] self, execute_model_req: Optional[ExecuteModelRequest]
...@@ -240,9 +261,23 @@ class RayGPUExecutor(DistributedGPUExecutor): ...@@ -240,9 +261,23 @@ class RayGPUExecutor(DistributedGPUExecutor):
Passing None will cause the driver to stop the model execution Passing None will cause the driver to stop the model execution
loop running in each of the remote workers. loop running in each of the remote workers.
""" """
assert not self.use_ray_spmd_worker, (
"driver_worker does not exist for VLLM_USE_RAY_SPMD_WORKER=1")
return self.driver_worker.execute_method("execute_model", return self.driver_worker.execute_method("execute_model",
execute_model_req) execute_model_req)
def execute_model(
self,
execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
if not self.use_ray_spmd_worker:
return super().execute_model(execute_model_req)
if self.forward_dag is None:
self.forward_dag = self._compiled_ray_dag(enable_asyncio=False)
outputs = ray.get(self.forward_dag.execute(execute_model_req))
return outputs[0]
def _run_workers( def _run_workers(
self, self,
method: str, method: str,
...@@ -252,7 +287,6 @@ class RayGPUExecutor(DistributedGPUExecutor): ...@@ -252,7 +287,6 @@ class RayGPUExecutor(DistributedGPUExecutor):
all_kwargs: Optional[List[Dict[str, Any]]] = None, all_kwargs: Optional[List[Dict[str, Any]]] = None,
use_dummy_driver: bool = False, use_dummy_driver: bool = False,
max_concurrent_workers: Optional[int] = None, max_concurrent_workers: Optional[int] = None,
use_ray_compiled_dag: bool = False,
**kwargs, **kwargs,
) -> Any: ) -> Any:
"""Runs the given method on all workers. Can be used in the following """Runs the given method on all workers. Can be used in the following
...@@ -267,6 +301,10 @@ class RayGPUExecutor(DistributedGPUExecutor): ...@@ -267,6 +301,10 @@ class RayGPUExecutor(DistributedGPUExecutor):
- all_args/all_kwargs: args/kwargs for each worker are specified - all_args/all_kwargs: args/kwargs for each worker are specified
individually individually
""" """
if self.use_ray_spmd_worker:
assert not async_run_tensor_parallel_workers_only, (
"async_run_tensor_parallel_workers_only is not supported for "
"spmd mode.")
if max_concurrent_workers: if max_concurrent_workers:
raise NotImplementedError( raise NotImplementedError(
...@@ -275,99 +313,125 @@ class RayGPUExecutor(DistributedGPUExecutor): ...@@ -275,99 +313,125 @@ class RayGPUExecutor(DistributedGPUExecutor):
count = len(self.workers) if not \ count = len(self.workers) if not \
async_run_tensor_parallel_workers_only \ async_run_tensor_parallel_workers_only \
else len(self.non_driver_workers) else len(self.non_driver_workers)
# If using SPMD worker, all workers are the same, so we should execute
# the args on all workers. Otherwise, we skip the first worker's args
# because those args will go to the driver worker.
first_worker_args_index: int = 0 if self.use_ray_spmd_worker else 1
all_worker_args = repeat(args, count) if all_args is None \ all_worker_args = repeat(args, count) if all_args is None \
else islice(all_args, 1, None) else islice(all_args, first_worker_args_index, None)
all_worker_kwargs = repeat(kwargs, count) if all_kwargs is None \ all_worker_kwargs = repeat(kwargs, count) if all_kwargs is None \
else islice(all_kwargs, 1, None) else islice(all_kwargs, first_worker_args_index, None)
if use_ray_compiled_dag: # Start the ray workers first.
# Right now, compiled DAG can only accept a single ray_workers = self.workers
# input. TODO(sang): Fix it. if async_run_tensor_parallel_workers_only:
assert self.forward_dag is not None ray_workers = self.non_driver_workers
output_channels = self.forward_dag.execute(1) ray_worker_outputs = [
ray_worker_outputs = [] worker.execute_method.remote(method, *worker_args, **worker_kwargs)
else: for (worker, worker_args, worker_kwargs
# Start the ray workers first. ) in zip(ray_workers, all_worker_args, all_worker_kwargs)
ray_workers = self.workers ]
if async_run_tensor_parallel_workers_only:
ray_workers = self.non_driver_workers
ray_worker_outputs = [
worker.execute_method.remote(method, *worker_args,
**worker_kwargs)
for (worker, worker_args, worker_kwargs
) in zip(ray_workers, all_worker_args, all_worker_kwargs)
]
if async_run_tensor_parallel_workers_only: if async_run_tensor_parallel_workers_only:
# Just return futures # Just return futures
return ray_worker_outputs return ray_worker_outputs
driver_args = args if all_args is None else all_args[0] driver_worker_output = []
driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0] # In SPMD mode, the driver worker is the same as any other worker,
# so we only explicitly execute on the driver worker if using a
# non-SPMD worker class.
if not self.use_ray_spmd_worker:
driver_args = args if all_args is None else all_args[0]
driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0]
# Start the driver worker after all the ray workers.
if not use_dummy_driver:
driver_worker_output = [
self.driver_worker.execute_method(method, *driver_args,
**driver_kwargs)
]
else:
assert self.driver_dummy_worker is not None
driver_worker_output = [
ray.get(
self.driver_dummy_worker.execute_method.remote(
method, *driver_args, **driver_kwargs))
]
# Start the driver worker after all the ray workers.
if not use_dummy_driver:
driver_worker_output = self.driver_worker.execute_method(
method, *driver_args, **driver_kwargs)
else:
assert self.driver_dummy_worker is not None
driver_worker_output = ray.get(
self.driver_dummy_worker.execute_method.remote(
method, *driver_args, **driver_kwargs))
# Get the results of the ray workers. # Get the results of the ray workers.
if self.workers: if self.workers:
if use_ray_compiled_dag: ray_worker_outputs = ray.get(ray_worker_outputs)
try:
ray_worker_outputs = [
pickle.loads(chan.begin_read())
for chan in output_channels
]
finally:
# Has to call end_read in order to reuse the DAG.
for chan in output_channels:
chan.end_read()
else:
ray_worker_outputs = ray.get(ray_worker_outputs)
return [driver_worker_output] + ray_worker_outputs return driver_worker_output + ray_worker_outputs
def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None: def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
"""Wait for futures returned from _run_workers() with """Wait for futures returned from _run_workers() with
async_run_remote_workers_only to complete.""" async_run_remote_workers_only to complete."""
ray.get(parallel_worker_tasks) ray.get(parallel_worker_tasks)
def _compiled_ray_dag(self): def _compiled_ray_dag(self, enable_asyncio: bool):
import pkg_resources import pkg_resources
required_version = "2.9" from packaging import version
current_version = pkg_resources.get_distribution("ray").version
required_version = version.parse("2.32")
current_version = version.parse(
pkg_resources.get_distribution("ray").version)
if current_version < required_version: if current_version < required_version:
raise ValueError(f"Ray version {required_version} or greater is " raise ValueError(f"Ray version {required_version} or greater is "
f"required, but found {current_version}") f"required, but found {current_version}")
from ray.dag import InputNode, MultiOutputNode from ray.dag import InputNode, MultiOutputNode
assert self.parallel_config.distributed_executor_backend == "ray" assert self.parallel_config.use_ray
# Right now, compiled DAG requires at least 1 arg. We send # Right now, compiled DAG requires at least 1 arg. We send
# a dummy value for now. It will be fixed soon. # a dummy value for now. It will be fixed soon.
with InputNode() as input_data: with InputNode() as input_data:
forward_dag = MultiOutputNode([ forward_dag = MultiOutputNode([
worker.execute_model_compiled_dag_remote. worker.execute_model_spmd.bind( # type: ignore[attr-defined]
bind( # type: ignore[attr-defined]
input_data) for worker in self.workers input_data) for worker in self.workers
]) ])
return forward_dag.experimental_compile() return forward_dag.experimental_compile(enable_asyncio=enable_asyncio)
def __del__(self):
if self.forward_dag is not None:
self.forward_dag.teardown()
import ray
for worker in self.workers:
ray.kill(worker)
class RayGPUExecutorAsync(RayGPUExecutor, DistributedGPUExecutorAsync): class RayGPUExecutorAsync(RayGPUExecutor, DistributedGPUExecutorAsync):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.driver_exec_method = make_async(self.driver_worker.execute_method) self.pp_locks: Optional[List[asyncio.Lock]] = None
self.use_ray_spmd_worker = envs.VLLM_USE_RAY_SPMD_WORKER
if not self.use_ray_compiled_dag:
self.driver_exec_method = make_async(
self.driver_worker.execute_method)
async def execute_model_async(
self,
execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
if not self.use_ray_spmd_worker:
return await super().execute_model_async(execute_model_req)
if self.forward_dag is None:
self.forward_dag = self._compiled_ray_dag(enable_asyncio=True)
dag_future = await self.forward_dag.execute_async(execute_model_req)
outputs = await dag_future
return outputs[0]
async def _driver_execute_model_async( async def _driver_execute_model_async(
self, self,
execute_model_req: Optional[ExecuteModelRequest] = None execute_model_req: Optional[ExecuteModelRequest] = None
) -> List[SamplerOutput]: ) -> List[SamplerOutput]:
assert not self.use_ray_spmd_worker, (
"driver_worker does not exist for VLLM_USE_RAY_SPMD_WORKER=1")
if not self.tp_driver_workers:
return await self.driver_exec_method("execute_model",
execute_model_req)
if self.pp_locks is None: if self.pp_locks is None:
# This locks each pipeline parallel stage so multiple virtual # This locks each pipeline parallel stage so multiple virtual
# engines can't execute on the same stage at the same time # engines can't execute on the same stage at the same time
...@@ -378,15 +442,11 @@ class RayGPUExecutorAsync(RayGPUExecutor, DistributedGPUExecutorAsync): ...@@ -378,15 +442,11 @@ class RayGPUExecutorAsync(RayGPUExecutor, DistributedGPUExecutorAsync):
for _ in range(self.parallel_config.pipeline_parallel_size) for _ in range(self.parallel_config.pipeline_parallel_size)
] ]
async def _run_task_with_lock(task, lock, *args, **kwargs): tasks = [
async with lock:
return await task(*args, **kwargs)
tasks = []
tasks.append(
asyncio.create_task( asyncio.create_task(
_run_task_with_lock(self.driver_exec_method, self.pp_locks[0], _run_task_with_lock(self.driver_exec_method, self.pp_locks[0],
"execute_model", execute_model_req))) "execute_model", execute_model_req))
]
for pp_rank, driver_worker in enumerate(self.tp_driver_workers, for pp_rank, driver_worker in enumerate(self.tp_driver_workers,
start=1): start=1):
tasks.append( tasks.append(
...@@ -401,8 +461,17 @@ class RayGPUExecutorAsync(RayGPUExecutor, DistributedGPUExecutorAsync): ...@@ -401,8 +461,17 @@ class RayGPUExecutorAsync(RayGPUExecutor, DistributedGPUExecutorAsync):
return results[-1] return results[-1]
async def _start_worker_execution_loop(self): async def _start_worker_execution_loop(self):
assert not self.use_ray_spmd_worker, (
"worker loop is disabled for VLLM_USE_RAY_SPMD_WORKER=1")
coros = [ coros = [
worker.execute_method.remote("start_worker_execution_loop") worker.execute_method.remote("start_worker_execution_loop")
for worker in self.non_driver_workers for worker in self.non_driver_workers
] ]
return await asyncio.gather(*coros) return await asyncio.gather(*coros)
def __del__(self):
if self.forward_dag is not None:
self.forward_dag.teardown()
import ray
for worker in self.workers:
ray.kill(worker)
import pickle
from typing import List, Optional, Tuple from typing import List, Optional, Tuple
from vllm.config import ParallelConfig from vllm.config import ParallelConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.sequence import ExecuteModelRequest
from vllm.utils import get_ip, is_hip, is_xpu from vllm.utils import get_ip, is_hip, is_xpu
from vllm.worker.worker_base import WorkerWrapperBase from vllm.worker.worker_base import WorkerWrapperBase
...@@ -31,16 +31,18 @@ try: ...@@ -31,16 +31,18 @@ try:
gpu_ids = ray.get_gpu_ids() gpu_ids = ray.get_gpu_ids()
return node_id, gpu_ids return node_id, gpu_ids
def execute_model_compiled_dag_remote(self, ignored): def execute_model_spmd(self, execute_model_req: ExecuteModelRequest):
"""Used only when compiled DAG is enabled.""" """Used only when SPMD worker and compiled DAG are both
enabled."""
# TODO(swang): This is needed right now because Ray aDAG executes
# on a background thread, so we need to reset torch's current
# device.
import torch import torch
if not self.compiled_dag_cuda_device_set: if not self.compiled_dag_cuda_device_set:
torch.cuda.set_device(self.worker.device) torch.cuda.set_device(self.worker.device)
self.compiled_dag_cuda_device_set = True self.compiled_dag_cuda_device_set = True
output = self.worker.execute_model() return self.worker._execute_model_spmd(execute_model_req)
output = pickle.dumps(output)
return output
ray_import_err = None ray_import_err = None
......
import asyncio import asyncio
import os import os
import pickle
from collections import defaultdict from collections import defaultdict
from itertools import islice, repeat from itertools import islice, repeat
from typing import (TYPE_CHECKING, Any, Awaitable, Dict, List, Optional, Set, from typing import (TYPE_CHECKING, Any, Awaitable, Dict, List, Optional, Set,
Tuple, Union) Tuple, Union)
import vllm.envs as envs
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, MultiModalConfig, ParallelConfig, ModelConfig, MultiModalConfig, ParallelConfig,
PromptAdapterConfig, SchedulerConfig, PromptAdapterConfig, SchedulerConfig,
...@@ -30,11 +30,13 @@ logger = init_logger(__name__) ...@@ -30,11 +30,13 @@ logger = init_logger(__name__)
# If the env var is set, it uses the Ray's compiled DAG API # If the env var is set, it uses the Ray's compiled DAG API
# which optimizes the control plane overhead. # which optimizes the control plane overhead.
# Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it. # Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it.
USE_RAY_COMPILED_DAG = bool(os.getenv("VLLM_USE_RAY_COMPILED_DAG", 0)) USE_RAY_COMPILED_DAG = envs.VLLM_USE_RAY_COMPILED_DAG
class RayXPUExecutor(DistributedGPUExecutor): class RayXPUExecutor(DistributedGPUExecutor):
uses_ray: bool = True
def __init__( def __init__(
self, self,
model_config: ModelConfig, model_config: ModelConfig,
...@@ -72,10 +74,9 @@ class RayXPUExecutor(DistributedGPUExecutor): ...@@ -72,10 +74,9 @@ class RayXPUExecutor(DistributedGPUExecutor):
# Create the parallel GPU workers. # Create the parallel GPU workers.
self._init_workers_ray(placement_group) self._init_workers_ray(placement_group)
# Profile the memory usage and initialize the cache.
self.forward_dag = None self.forward_dag = None
if USE_RAY_COMPILED_DAG: if USE_RAY_COMPILED_DAG:
self.forward_dag = self._compiled_ray_dag() self.forward_dag = self._compiled_ray_dag(enable_asyncio=False)
# This is non-None when the execute model loop is running # This is non-None when the execute model loop is running
# in the parallel workers. It's a coroutine in the AsyncLLMEngine case. # in the parallel workers. It's a coroutine in the AsyncLLMEngine case.
...@@ -108,6 +109,13 @@ class RayXPUExecutor(DistributedGPUExecutor): ...@@ -108,6 +109,13 @@ class RayXPUExecutor(DistributedGPUExecutor):
return num_gpu_blocks, num_cpu_blocks return num_gpu_blocks, num_cpu_blocks
def _get_worker_wrapper_args(self) -> Dict[str, Any]:
return dict(
worker_module_name="vllm.worker.xpu_worker",
worker_class_name="XPUWorker",
trust_remote_code=self.model_config.trust_remote_code,
)
def _init_workers_ray(self, placement_group: "PlacementGroup", def _init_workers_ray(self, placement_group: "PlacementGroup",
**ray_remote_kwargs): **ray_remote_kwargs):
if self.parallel_config.tensor_parallel_size == 1: if self.parallel_config.tensor_parallel_size == 1:
...@@ -125,6 +133,7 @@ class RayXPUExecutor(DistributedGPUExecutor): ...@@ -125,6 +133,7 @@ class RayXPUExecutor(DistributedGPUExecutor):
# Create the workers. # Create the workers.
driver_ip = get_ip() driver_ip = get_ip()
worker_wrapper_kwargs = self._get_worker_wrapper_args()
for bundle_id, bundle in enumerate(placement_group.bundle_specs): for bundle_id, bundle in enumerate(placement_group.bundle_specs):
if not bundle.get("GPU", 0): if not bundle.get("GPU", 0):
continue continue
...@@ -138,22 +147,14 @@ class RayXPUExecutor(DistributedGPUExecutor): ...@@ -138,22 +147,14 @@ class RayXPUExecutor(DistributedGPUExecutor):
num_gpus=num_gpus, num_gpus=num_gpus,
scheduling_strategy=scheduling_strategy, scheduling_strategy=scheduling_strategy,
**ray_remote_kwargs, **ray_remote_kwargs,
)(RayWorkerWrapper).remote( )(RayWorkerWrapper).remote(**worker_wrapper_kwargs)
worker_module_name="vllm.worker.xpu_worker",
worker_class_name="XPUWorker",
trust_remote_code=self.model_config.trust_remote_code,
)
worker_ip = ray.get(worker.get_node_ip.remote()) worker_ip = ray.get(worker.get_node_ip.remote())
if worker_ip == driver_ip and self.driver_dummy_worker is None: if worker_ip == driver_ip and self.driver_dummy_worker is None:
# If the worker is on the same node as the driver, we use it # If the worker is on the same node as the driver, we use it
# as the resource holder for the driver process. # as the resource holder for the driver process.
self.driver_dummy_worker = worker self.driver_dummy_worker = worker
self.driver_worker = RayWorkerWrapper( self.driver_worker = RayWorkerWrapper(**worker_wrapper_kwargs)
worker_module_name="vllm.worker.xpu_worker",
worker_class_name="XPUWorker",
trust_remote_code=self.model_config.trust_remote_code,
)
else: else:
# Else, added to the list of workers. # Else, added to the list of workers.
self.workers.append(worker) self.workers.append(worker)
...@@ -270,7 +271,6 @@ class RayXPUExecutor(DistributedGPUExecutor): ...@@ -270,7 +271,6 @@ class RayXPUExecutor(DistributedGPUExecutor):
all_kwargs: Optional[List[Dict[str, Any]]] = None, all_kwargs: Optional[List[Dict[str, Any]]] = None,
use_dummy_driver: bool = False, use_dummy_driver: bool = False,
max_concurrent_workers: Optional[int] = None, max_concurrent_workers: Optional[int] = None,
use_ray_compiled_dag: bool = False,
**kwargs, **kwargs,
) -> Any: ) -> Any:
"""Runs the given method on all workers. Can be used in the following """Runs the given method on all workers. Can be used in the following
...@@ -293,26 +293,20 @@ class RayXPUExecutor(DistributedGPUExecutor): ...@@ -293,26 +293,20 @@ class RayXPUExecutor(DistributedGPUExecutor):
all_worker_kwargs = repeat(kwargs, count) if all_kwargs is None \ all_worker_kwargs = repeat(kwargs, count) if all_kwargs is None \
else islice(all_kwargs, 1, None) else islice(all_kwargs, 1, None)
if use_ray_compiled_dag: # Start the ray workers first.
# Right now, compiled DAG can only accept a single ray_worker_outputs = [
# input. TODO(sang): Fix it. worker.execute_method.remote(method, *worker_args, **worker_kwargs)
assert self.forward_dag is not None for (worker, worker_args, worker_kwargs
output_channels = self.forward_dag.execute(1) ) in zip(self.workers, all_worker_args, all_worker_kwargs)
else: ]
# Start the ray workers first.
ray_worker_outputs = [
worker.execute_method.remote(method, *worker_args,
**worker_kwargs)
for (worker, worker_args, worker_kwargs
) in zip(self.workers, all_worker_args, all_worker_kwargs)
]
if async_run_remote_workers_only: if async_run_remote_workers_only:
# Just return futures # Just return futures
return ray_worker_outputs return ray_worker_outputs
driver_worker_output = []
driver_args = args if all_args is None else all_args[0] driver_args = args if all_args is None else all_args[0]
driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0] driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0]
# Start the driver worker after all the ray workers. # Start the driver worker after all the ray workers.
if not use_dummy_driver: if not use_dummy_driver:
driver_worker_output = self.driver_worker.execute_method( driver_worker_output = self.driver_worker.execute_method(
...@@ -324,36 +318,28 @@ class RayXPUExecutor(DistributedGPUExecutor): ...@@ -324,36 +318,28 @@ class RayXPUExecutor(DistributedGPUExecutor):
method, *driver_args, **driver_kwargs)) method, *driver_args, **driver_kwargs))
# Get the results of the ray workers. # Get the results of the ray workers.
if self.workers: if self.workers:
if use_ray_compiled_dag: ray_worker_outputs = ray.get(ray_worker_outputs)
try:
ray_worker_outputs = [
pickle.loads(chan.begin_read())
for chan in output_channels
]
finally:
# Has to call end_read in order to reuse the DAG.
for chan in output_channels:
chan.end_read()
else:
ray_worker_outputs = ray.get(ray_worker_outputs)
return [driver_worker_output] + ray_worker_outputs return driver_worker_output + ray_worker_outputs
def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None: def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
"""Wait for futures returned from _run_workers() with """Wait for futures returned from _run_workers() with
async_run_remote_workers_only to complete.""" async_run_remote_workers_only to complete."""
ray.get(parallel_worker_tasks) ray.get(parallel_worker_tasks)
def _compiled_ray_dag(self): def _compiled_ray_dag(self, enable_asyncio: bool):
import pkg_resources import pkg_resources
required_version = "2.9" from packaging import version
current_version = pkg_resources.get_distribution("ray").version
required_version = version.parse("2.32")
current_version = version.parse(
pkg_resources.get_distribution("ray").version)
if current_version < required_version: if current_version < required_version:
raise ValueError(f"Ray version {required_version} or greater is " raise ValueError(f"Ray version {required_version} or greater is "
f"required, but found {current_version}") f"required, but found {current_version}")
from ray.dag import InputNode, MultiOutputNode from ray.dag import InputNode, MultiOutputNode
assert self.parallel_config.worker_use_ray assert self.parallel_config.use_ray
# Right now, compiled DAG requires at least 1 arg. We send # Right now, compiled DAG requires at least 1 arg. We send
# a dummy value for now. It will be fixed soon. # a dummy value for now. It will be fixed soon.
...@@ -363,7 +349,7 @@ class RayXPUExecutor(DistributedGPUExecutor): ...@@ -363,7 +349,7 @@ class RayXPUExecutor(DistributedGPUExecutor):
bind( # type: ignore[attr-defined] bind( # type: ignore[attr-defined]
input_data) for worker in self.workers input_data) for worker in self.workers
]) ])
return forward_dag.experimental_compile() return forward_dag.experimental_compile(enable_asyncio=enable_asyncio)
def check_health(self) -> None: def check_health(self) -> None:
"""Raises an error if engine is unhealthy.""" """Raises an error if engine is unhealthy."""
......
...@@ -14,6 +14,8 @@ logger = init_logger(__name__) ...@@ -14,6 +14,8 @@ logger = init_logger(__name__)
class TPUExecutor(ExecutorBase): class TPUExecutor(ExecutorBase):
uses_ray: bool = False
def _init_executor(self) -> None: def _init_executor(self) -> None:
assert not self.scheduler_config.chunked_prefill_enabled, ( assert not self.scheduler_config.chunked_prefill_enabled, (
"Chunked prefill is not yet supported for TPU backend") "Chunked prefill is not yet supported for TPU backend")
......
...@@ -18,6 +18,8 @@ logger = init_logger(__name__) ...@@ -18,6 +18,8 @@ logger = init_logger(__name__)
class XPUExecutor(GPUExecutor): class XPUExecutor(GPUExecutor):
uses_ray: bool = False
def __init__( def __init__(
self, self,
model_config: ModelConfig, model_config: ModelConfig,
......
from .data import (LLMInputs, ParsedText, ParsedTokens, PromptInputs, from .data import (LLMInputs, ParsedText, ParsedTokens, PromptInputs,
PromptStrictInputs, TextPrompt, TextTokensPrompt, TextPrompt, TokensPrompt, parse_and_batch_prompt)
TokensPrompt, parse_and_batch_prompt)
from .registry import InputContext, InputRegistry from .registry import InputContext, InputRegistry
INPUT_REGISTRY = InputRegistry() INPUT_REGISTRY = InputRegistry()
...@@ -14,6 +13,6 @@ See also: ...@@ -14,6 +13,6 @@ See also:
__all__ = [ __all__ = [
"ParsedText", "ParsedTokens", "parse_and_batch_prompt", "TextPrompt", "ParsedText", "ParsedTokens", "parse_and_batch_prompt", "TextPrompt",
"TokensPrompt", "TextTokensPrompt", "PromptStrictInputs", "PromptInputs", "TokensPrompt", "PromptInputs", "LLMInputs", "INPUT_REGISTRY",
"LLMInputs", "INPUT_REGISTRY", "InputContext", "InputRegistry" "InputContext", "InputRegistry"
] ]
...@@ -92,25 +92,7 @@ class TokensPrompt(TypedDict): ...@@ -92,25 +92,7 @@ class TokensPrompt(TypedDict):
""" """
class TextTokensPrompt(TypedDict): PromptInputs = Union[str, TextPrompt, TokensPrompt]
"""It is assumed that :attr:`prompt` is consistent with
:attr:`prompt_token_ids`. This is currently used in
:class:`AsyncLLMEngine` for logging both the text and token IDs."""
prompt: str
"""The prompt text."""
prompt_token_ids: List[int]
"""The token IDs of the prompt."""
multi_modal_data: NotRequired["MultiModalDataDict"]
"""
Optional multi-modal data to pass to the model,
if the model supports it.
"""
PromptStrictInputs = Union[str, TextPrompt, TokensPrompt]
""" """
The inputs to the LLM, which can take one of the following forms: The inputs to the LLM, which can take one of the following forms:
...@@ -118,10 +100,6 @@ The inputs to the LLM, which can take one of the following forms: ...@@ -118,10 +100,6 @@ The inputs to the LLM, which can take one of the following forms:
- A tokenized prompt (:class:`TokensPrompt`) - A tokenized prompt (:class:`TokensPrompt`)
""" """
PromptInputs = Union[str, TextPrompt, TokensPrompt, TextTokensPrompt]
"""Same as :const:`PromptStrictInputs` but additionally accepts
:class:`TextTokensPrompt`."""
class LLMInputs(TypedDict): class LLMInputs(TypedDict):
""" """
......
from dataclasses import dataclass import warnings
from dataclasses import dataclass, field
from typing import Optional from typing import Optional
from vllm.adapter_commons.request import AdapterRequest from vllm.adapter_commons.request import AdapterRequest
...@@ -20,10 +21,25 @@ class LoRARequest(AdapterRequest): ...@@ -20,10 +21,25 @@ class LoRARequest(AdapterRequest):
lora_name: str lora_name: str
lora_int_id: int lora_int_id: int
lora_local_path: str lora_path: str = ""
lora_local_path: Optional[str] = field(default=None, repr=False)
long_lora_max_len: Optional[int] = None long_lora_max_len: Optional[int] = None
__hash__ = AdapterRequest.__hash__ __hash__ = AdapterRequest.__hash__
def __post_init__(self):
    """Reconcile the deprecated ``lora_local_path`` field with ``lora_path``.

    When the caller supplied ``lora_local_path``, a ``DeprecationWarning``
    is emitted and, if ``lora_path`` is empty, the legacy value is copied
    into ``lora_path``.

    Raises:
        AssertionError: If ``lora_path`` is still empty afterwards.
    """
    # `lora_local_path` is declared as a field with a default of None, so
    # the attribute is assigned on every instance. Checking membership in
    # `self.__dict__` would therefore warn on every construction; check the
    # value instead so only callers that used the deprecated field are
    # warned.
    if self.lora_local_path is not None:
        warnings.warn(
            "The 'lora_local_path' attribute is deprecated "
            "and will be removed in a future version. "
            "Please use 'lora_path' instead.",
            DeprecationWarning,
            stacklevel=2)
        # An explicitly provided `lora_path` takes precedence over the
        # legacy field.
        if not self.lora_path:
            self.lora_path = self.lora_local_path or ""

    # Ensure lora_path is not empty
    assert self.lora_path, "lora_path cannot be empty"
@property @property
def adapter_id(self): def adapter_id(self):
return self.lora_int_id return self.lora_int_id
...@@ -32,6 +48,26 @@ class LoRARequest(AdapterRequest): ...@@ -32,6 +48,26 @@ class LoRARequest(AdapterRequest):
def name(self): def name(self):
return self.lora_name return self.lora_name
@property
def path(self):
    """Return the adapter location stored in ``lora_path``."""
    adapter_location = self.lora_path
    return adapter_location
@property @property
def local_path(self): def local_path(self):
return self.lora_local_path warnings.warn(
"The 'local_path' attribute is deprecated "
"and will be removed in a future version. "
"Please use 'path' instead.",
DeprecationWarning,
stacklevel=2)
return self.lora_path
@local_path.setter
def local_path(self, value):
    """Deprecated setter: warn, then store ``value`` into ``lora_path``."""
    deprecation_text = ("The 'local_path' attribute is deprecated "
                        "and will be removed in a future version. "
                        "Please use 'path' instead.")
    warnings.warn(deprecation_text, DeprecationWarning, stacklevel=2)
    self.lora_path = value
import os
from typing import List, Optional, Set, Tuple, Type from typing import List, Optional, Set, Tuple, Type
import huggingface_hub
from huggingface_hub.utils import (EntryNotFoundError, HfHubHTTPError,
HFValidationError, RepositoryNotFoundError)
from torch import nn from torch import nn
from transformers import PretrainedConfig from transformers import PretrainedConfig
...@@ -105,3 +109,46 @@ def parse_fine_tuned_lora_name(name: str) -> Tuple[str, bool]: ...@@ -105,3 +109,46 @@ def parse_fine_tuned_lora_name(name: str) -> Tuple[str, bool]:
return ".".join(parts[2:-1]), parts[-1] == "lora_embedding_A" return ".".join(parts[2:-1]), parts[-1] == "lora_embedding_A"
raise ValueError(f"{name} is unsupported LoRA weight") raise ValueError(f"{name} is unsupported LoRA weight")
def get_adapter_absolute_path(lora_path: str) -> str:
    """Resolve ``lora_path`` to a local path usable for loading an adapter.

    The candidates are tried in this order:

    1. An absolute path is returned unchanged, whether or not it exists.
    2. A path starting with ``~`` is returned with the home directory
       expanded, whether or not it exists.
    3. A relative path that exists is returned as an absolute path.
    4. Anything else is treated as a Hugging Face repo id and downloaded;
       the local snapshot directory is returned.

    Parameters:
    lora_path (str): An absolute path, a relative path, or a Hugging Face
        model identifier.

    Returns:
    str: The resolved path. If the Hugging Face download fails with one of
        the handled hub errors, the error is logged and the original
        ``lora_path`` is returned unchanged.
    """
    if os.path.isabs(lora_path):
        # Already absolute: hand it back without checking existence.
        resolved = lora_path
    elif lora_path.startswith('~'):
        # Home-relative: expand the user directory, again without checking.
        resolved = os.path.expanduser(lora_path)
    elif os.path.exists(lora_path):
        # A relative path that exists locally.
        resolved = os.path.abspath(lora_path)
    else:
        # Nothing local matched, so assume it names a Hugging Face repo.
        try:
            resolved = huggingface_hub.snapshot_download(repo_id=lora_path)
        except (HfHubHTTPError, RepositoryNotFoundError, EntryNotFoundError,
                HFValidationError):
            # Log the download failure and fall back to the original path
            # instead of raising here.
            logger.exception("Error downloading the HuggingFace model")
            resolved = lora_path
    return resolved
...@@ -13,6 +13,7 @@ from vllm.logger import init_logger ...@@ -13,6 +13,7 @@ from vllm.logger import init_logger
from vllm.lora.models import (LoRAModel, LoRAModelManager, from vllm.lora.models import (LoRAModel, LoRAModelManager,
LRUCacheLoRAModelManager, create_lora_manager) LRUCacheLoRAModelManager, create_lora_manager)
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.lora.utils import get_adapter_absolute_path
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -89,8 +90,9 @@ class WorkerLoRAManager(AbstractWorkerManager): ...@@ -89,8 +90,9 @@ class WorkerLoRAManager(AbstractWorkerManager):
packed_modules_mapping[module]) packed_modules_mapping[module])
else: else:
expected_lora_modules.append(module) expected_lora_modules.append(module)
lora_path = get_adapter_absolute_path(lora_request.lora_path)
lora = self._lora_model_cls.from_local_checkpoint( lora = self._lora_model_cls.from_local_checkpoint(
lora_request.lora_local_path, lora_path,
expected_lora_modules, expected_lora_modules,
max_position_embeddings=self.max_position_embeddings, max_position_embeddings=self.max_position_embeddings,
lora_model_id=lora_request.lora_int_id, lora_model_id=lora_request.lora_int_id,
...@@ -102,8 +104,7 @@ class WorkerLoRAManager(AbstractWorkerManager): ...@@ -102,8 +104,7 @@ class WorkerLoRAManager(AbstractWorkerManager):
embedding_padding_modules=self.embedding_padding_modules, embedding_padding_modules=self.embedding_padding_modules,
) )
except Exception as e: except Exception as e:
raise RuntimeError( raise RuntimeError(f"Loading lora {lora_path} failed") from e
f"Loading lora {lora_request.lora_local_path} failed") from e
if lora.rank > self.lora_config.max_lora_rank: if lora.rank > self.lora_config.max_lora_rank:
raise ValueError( raise ValueError(
f"LoRA rank {lora.rank} is greater than max_lora_rank " f"LoRA rank {lora.rank} is greater than max_lora_rank "
......
...@@ -7,7 +7,7 @@ from vllm.distributed import (get_tensor_model_parallel_rank, ...@@ -7,7 +7,7 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce) tensor_model_parallel_all_reduce)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.fused_moe import fused_moe from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase) QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
...@@ -36,7 +36,7 @@ class FusedMoEMethodBase(QuantizeMethodBase): ...@@ -36,7 +36,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
raise NotImplementedError raise NotImplementedError
class UnquantizedFusedMoEMethod(FusedMoEMethodBase): class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
"""MoE method without quantization.""" """MoE method without quantization."""
def create_weights(self, layer: torch.nn.Module, num_experts: int, def create_weights(self, layer: torch.nn.Module, num_experts: int,
...@@ -61,19 +61,37 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase): ...@@ -61,19 +61,37 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase):
layer.register_parameter("w2_weight", w2_weight) layer.register_parameter("w2_weight", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs) set_weight_attrs(w2_weight, extra_weight_attrs)
def apply(self, def apply(
layer: torch.nn.Module, self,
x: torch.Tensor, layer: torch.nn.Module,
router_logits: torch.Tensor, x: torch.Tensor,
top_k: int, router_logits: torch.Tensor,
renormalize: bool = True, top_k: int,
use_grouped_topk: bool = False, renormalize: bool = True,
num_expert_group: Optional[int] = None, use_grouped_topk: bool = False,
topk_group: Optional[int] = None) -> torch.Tensor: num_expert_group: Optional[int] = None,
topk_group: Optional[int] = None,
) -> torch.Tensor:
return self.forward(x, layer.w13_weight, layer.w2_weight,
router_logits, top_k, renormalize,
use_grouped_topk, num_expert_group, topk_group)
def forward_cuda(
self,
x: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool,
num_expert_group: Optional[int],
topk_group: Optional[int],
) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe.fused_moe import fused_moe
return fused_moe(x, return fused_moe(x,
layer.w13_weight, w1,
layer.w2_weight, w2,
router_logits, router_logits,
top_k, top_k,
renormalize=renormalize, renormalize=renormalize,
...@@ -82,6 +100,28 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase): ...@@ -82,6 +100,28 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase):
num_expert_group=num_expert_group, num_expert_group=num_expert_group,
topk_group=topk_group) topk_group=topk_group)
def forward_cpu(self, *args, **kwargs):
    """Reject CPU execution; all arguments are ignored.

    Raises:
        NotImplementedError: Always.
    """
    unsupported_msg = "The CPU backend currently does not support MoE."
    raise NotImplementedError(unsupported_msg)
def forward_tpu(
    self,
    x: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    router_logits: torch.Tensor,
    top_k: int,
    renormalize: bool,
    use_grouped_topk: bool,
    num_expert_group: Optional[int],
    topk_group: Optional[int],
) -> torch.Tensor:
    """Run the MoE computation through the Pallas ``fused_moe`` function.

    Grouped top-k routing is not accepted on this path: ``use_grouped_topk``
    must be false and both ``num_expert_group`` and ``topk_group`` must be
    ``None``, otherwise an ``AssertionError`` is raised.

    Returns:
        The result of ``moe_pallas.fused_moe`` for the given inputs.
    """
    # Imported lazily so the Pallas module is only loaded when this path is
    # actually taken.
    from vllm.model_executor.layers.fused_moe.moe_pallas import fused_moe

    # The grouped-routing arguments are rejected; they are not forwarded to
    # the Pallas implementation below.
    assert not use_grouped_topk
    assert num_expert_group is None
    assert topk_group is None

    return fused_moe(hidden_states=x,
                     w1=w1,
                     w2=w2,
                     gating_output=router_logits,
                     topk=top_k,
                     renormalize=renormalize)
class FusedMoE(torch.nn.Module): class FusedMoE(torch.nn.Module):
"""FusedMoE layer for MoE models. """FusedMoE layer for MoE models.
...@@ -118,6 +158,7 @@ class FusedMoE(torch.nn.Module): ...@@ -118,6 +158,7 @@ class FusedMoE(torch.nn.Module):
topk_group: Optional[int] = None, topk_group: Optional[int] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
tp_size: Optional[int] = None, tp_size: Optional[int] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
...@@ -141,7 +182,7 @@ class FusedMoE(torch.nn.Module): ...@@ -141,7 +182,7 @@ class FusedMoE(torch.nn.Module):
self.quant_method: Optional[QuantizeMethodBase] = ( self.quant_method: Optional[QuantizeMethodBase] = (
UnquantizedFusedMoEMethod()) UnquantizedFusedMoEMethod())
else: else:
self.quant_method = quant_config.get_quant_method(self) self.quant_method = quant_config.get_quant_method(self, prefix)
assert self.quant_method is not None assert self.quant_method is not None
self.quant_method.create_weights( self.quant_method.create_weights(
......
import torch
import torch.nn.functional as F
from torch_xla.experimental.custom_kernel import _histogram
def fused_moe(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
) -> torch.Tensor:
    """Top-k routed mixture-of-experts MLP built on ``torch.ops.xla.gmm``.

    Args:
        hidden_states: [*, hidden_size]
        w1: [num_experts, intermediate_size * 2, hidden_size]
        w2: [num_experts, hidden_size, intermediate_size]
        gating_output: [*, num_experts]
        topk: Number of experts selected per token.
        renormalize: If true, divide each token's selected routing weights
            by their sum.

    Returns:
        A tensor with the same shape as ``hidden_states``.

    Raises:
        AssertionError: If ``num_tokens * topk`` is not a multiple of 16.
    """
    input_shape = hidden_states.shape
    hidden_size = input_shape[-1]
    num_tokens = input_shape[:-1].numel()
    num_experts = w1.shape[0]
    intermediate_size = w2.shape[-1]
    out_dtype = hidden_states.dtype
    assert (num_tokens * topk) % 16 == 0, (
        "The Pallas GMM kernel requires num_tokens * topk to be a multiple of "
        f"16 but got {num_tokens * topk}")

    # Flatten all leading dimensions into a single token dimension.
    tokens = hidden_states.view(num_tokens, hidden_size)
    router_logits = gating_output.view(num_tokens, num_experts)

    # Routing: softmax in float32, keep the top-k experts per token.
    routing_probs = router_logits.softmax(dim=-1, dtype=torch.float)
    expert_weights, expert_ids = routing_probs.topk(topk, dim=-1)
    if renormalize:
        weight_totals = expert_weights.sum(dim=-1, keepdim=True)
        expert_weights = expert_weights / weight_totals
    expert_weights = expert_weights.to(out_dtype)

    # Sort the (token, expert) slots by expert id; the argsort of that
    # permutation restores the original slot order later.
    expert_ids = expert_ids.flatten()
    sort_order = expert_ids.argsort()
    unsort_order = sort_order.argsort()

    # Row index into `tokens` for every slot, in expert-sorted order.
    slot_to_token = torch.arange(
        num_tokens, device=hidden_states.device).repeat_interleave(topk)
    gathered_rows = slot_to_token[sort_order]

    # NOTE(review): presumably the number of slots routed to each expert --
    # confirm against the torch_xla `_histogram` helper.
    group_sizes = _histogram(expert_ids.to(torch.int32), 0, num_experts - 1)

    # NOTE(woosuk): The GMM Pallas kernel requires a different weight layout
    # from HF Transformers.
    w1_t = w1.transpose(1, 2)
    w2_t = w2.transpose(1, 2)

    out = tokens[gathered_rows]
    out = torch.ops.xla.gmm(out, w1_t, group_sizes)
    # SiLU of the first `intermediate_size` columns gates the remaining ones.
    first_half = out[..., :intermediate_size]
    second_half = out[..., intermediate_size:]
    out = F.silu(first_half) * second_half
    out = torch.ops.xla.gmm(out, w2_t, group_sizes)

    # Undo the expert sort, then combine each token's top-k expert outputs
    # weighted by the routing weights.
    out = out[unsort_order].reshape(-1, topk, hidden_size)
    out = out * expert_weights.unsqueeze_(dim=-1)
    return out.sum(dim=-2).reshape(input_shape)
...@@ -160,6 +160,7 @@ class LinearBase(torch.nn.Module): ...@@ -160,6 +160,7 @@ class LinearBase(torch.nn.Module):
skip_bias_add: bool = False, skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None, params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
): ):
super().__init__() super().__init__()
...@@ -174,7 +175,8 @@ class LinearBase(torch.nn.Module): ...@@ -174,7 +175,8 @@ class LinearBase(torch.nn.Module):
self.quant_method: Optional[ self.quant_method: Optional[
QuantizeMethodBase] = UnquantizedLinearMethod() QuantizeMethodBase] = UnquantizedLinearMethod()
else: else:
self.quant_method = quant_config.get_quant_method(self) self.quant_method = quant_config.get_quant_method(self,
prefix=prefix)
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
raise NotImplementedError raise NotImplementedError
...@@ -190,6 +192,8 @@ class ReplicatedLinear(LinearBase): ...@@ -190,6 +192,8 @@ class ReplicatedLinear(LinearBase):
skip_bias_add: If true, skip adding bias but instead return it. skip_bias_add: If true, skip adding bias but instead return it.
params_dtype: Data type for the parameters. params_dtype: Data type for the parameters.
quant_config: Quantization configure. quant_config: Quantization configure.
prefix: The name of the layer in the state dict, including all parents
(e.g. model.layers.0.qkv_proj)
""" """
def __init__(self, def __init__(self,
...@@ -198,15 +202,23 @@ class ReplicatedLinear(LinearBase): ...@@ -198,15 +202,23 @@ class ReplicatedLinear(LinearBase):
bias: bool = True, bias: bool = True,
skip_bias_add: bool = False, skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None, params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None): quant_config: Optional[QuantizationConfig] = None,
super().__init__(input_size, output_size, skip_bias_add, params_dtype, prefix: str = ""):
quant_config) super().__init__(input_size,
output_size,
skip_bias_add,
params_dtype,
quant_config,
prefix=prefix)
# All the linear layer supports quant method. # All the linear layer supports quant method.
assert self.quant_method is not None assert self.quant_method is not None
self.quant_method.create_weights(self, self.input_size, self.quant_method.create_weights(self,
[self.output_size], self.input_size, self.input_size, [self.output_size],
self.output_size, self.params_dtype) self.input_size,
self.output_size,
self.params_dtype,
prefix=prefix)
if bias: if bias:
self.bias = Parameter( self.bias = Parameter(
...@@ -215,6 +227,15 @@ class ReplicatedLinear(LinearBase): ...@@ -215,6 +227,15 @@ class ReplicatedLinear(LinearBase):
else: else:
self.register_parameter("bias", None) self.register_parameter("bias", None)
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
    """Copy ``loaded_weight`` into ``param`` in place.

    A zero-dimensional ``loaded_weight`` (such as the scales for AutoFp8)
    is first reshaped to shape ``(1,)``.

    Raises:
        AssertionError: If the sizes of ``param`` and the (possibly
            reshaped) ``loaded_weight`` differ.
    """
    is_scalar = len(loaded_weight.shape) == 0
    if is_scalar:
        # The weight on disk has no shape, so give it one.
        loaded_weight = loaded_weight.reshape(1)

    assert param.size() == loaded_weight.size()
    param.data.copy_(loaded_weight)
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
bias = self.bias if not self.skip_bias_add else None bias = self.bias if not self.skip_bias_add else None
assert self.quant_method is not None assert self.quant_method is not None
...@@ -249,6 +270,8 @@ class ColumnParallelLinear(LinearBase): ...@@ -249,6 +270,8 @@ class ColumnParallelLinear(LinearBase):
quant_config: Quantization configure. quant_config: Quantization configure.
output_sizes: list of output sizes packed into one output, like for QKV output_sizes: list of output sizes packed into one output, like for QKV
the list would be size 3. the list would be size 3.
prefix: The name of the layer in the state dict, including all parents
(e.g. model.layers.0.qkv_proj)
""" """
def __init__(self, def __init__(self,
...@@ -259,9 +282,10 @@ class ColumnParallelLinear(LinearBase): ...@@ -259,9 +282,10 @@ class ColumnParallelLinear(LinearBase):
skip_bias_add: bool = False, skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None, params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
output_sizes: Optional[List[int]] = None): output_sizes: Optional[List[int]] = None,
prefix: str = ""):
super().__init__(input_size, output_size, skip_bias_add, params_dtype, super().__init__(input_size, output_size, skip_bias_add, params_dtype,
quant_config) quant_config, prefix)
self.gather_output = gather_output self.gather_output = gather_output
...@@ -286,7 +310,8 @@ class ColumnParallelLinear(LinearBase): ...@@ -286,7 +310,8 @@ class ColumnParallelLinear(LinearBase):
input_size=self.input_size, input_size=self.input_size,
output_size=self.output_size, output_size=self.output_size,
params_dtype=self.params_dtype, params_dtype=self.params_dtype,
weight_loader=self.weight_loader) weight_loader=self.weight_loader,
prefix=prefix)
if bias: if bias:
self.bias = Parameter( self.bias = Parameter(
torch.empty(self.output_size_per_partition, torch.empty(self.output_size_per_partition,
...@@ -358,6 +383,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -358,6 +383,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
skip adding bias but instead return it. skip adding bias but instead return it.
params_dtype: Data type for the parameters. params_dtype: Data type for the parameters.
quant_config: Quantization configure. quant_config: Quantization configure.
prefix: The name of the layer in the state dict, including all parents
(e.g. model.layers.0.qkv_proj)
""" """
def __init__(self, def __init__(self,
...@@ -367,7 +394,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -367,7 +394,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
gather_output: bool = False, gather_output: bool = False,
skip_bias_add: bool = False, skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None, params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None): quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
self.output_sizes = output_sizes self.output_sizes = output_sizes
tp_size = get_tensor_model_parallel_world_size() tp_size = get_tensor_model_parallel_world_size()
assert all(output_size % tp_size == 0 for output_size in output_sizes) assert all(output_size % tp_size == 0 for output_size in output_sizes)
...@@ -377,7 +405,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -377,7 +405,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
gather_output=gather_output, gather_output=gather_output,
skip_bias_add=skip_bias_add, skip_bias_add=skip_bias_add,
params_dtype=params_dtype, params_dtype=params_dtype,
quant_config=quant_config) quant_config=quant_config,
prefix=prefix)
def weight_loader(self, def weight_loader(self,
param: Parameter, param: Parameter,
...@@ -497,6 +526,8 @@ class QKVParallelLinear(ColumnParallelLinear): ...@@ -497,6 +526,8 @@ class QKVParallelLinear(ColumnParallelLinear):
skip adding bias but instead return it. skip adding bias but instead return it.
params_dtype: Data type for the parameters. params_dtype: Data type for the parameters.
quant_config: Quantization configure. quant_config: Quantization configure.
prefix: The name of the layer in the state dict, including all parents
(e.g. model.layers.0.qkv_proj)
""" """
def __init__(self, def __init__(self,
...@@ -507,7 +538,8 @@ class QKVParallelLinear(ColumnParallelLinear): ...@@ -507,7 +538,8 @@ class QKVParallelLinear(ColumnParallelLinear):
bias: bool = True, bias: bool = True,
skip_bias_add: bool = False, skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None, params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None): quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
self.hidden_size = hidden_size self.hidden_size = hidden_size
self.head_size = head_size self.head_size = head_size
self.total_num_heads = total_num_heads self.total_num_heads = total_num_heads
...@@ -539,7 +571,8 @@ class QKVParallelLinear(ColumnParallelLinear): ...@@ -539,7 +571,8 @@ class QKVParallelLinear(ColumnParallelLinear):
gather_output=False, gather_output=False,
skip_bias_add=skip_bias_add, skip_bias_add=skip_bias_add,
params_dtype=params_dtype, params_dtype=params_dtype,
quant_config=quant_config) quant_config=quant_config,
prefix=prefix)
def weight_loader(self, def weight_loader(self,
param: Parameter, param: Parameter,
...@@ -698,14 +731,16 @@ class RowParallelLinear(LinearBase): ...@@ -698,14 +731,16 @@ class RowParallelLinear(LinearBase):
skip_bias_add: bool = False, skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None, params_dtype: Optional[torch.dtype] = None,
reduce_results: bool = True, reduce_results: bool = True,
quant_config: Optional[QuantizationConfig] = None): quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
super().__init__(input_size, output_size, skip_bias_add, params_dtype, super().__init__(input_size, output_size, skip_bias_add, params_dtype,
quant_config) quant_config, prefix)
self.input_is_parallel = input_is_parallel self.input_is_parallel = input_is_parallel
self.reduce_results = reduce_results self.reduce_results = reduce_results
# Divide the weight matrix along the last dimension. # Divide the weight matrix along the last dimension.
self.tp_rank = get_tensor_model_parallel_rank()
self.tp_size = get_tensor_model_parallel_world_size() self.tp_size = get_tensor_model_parallel_world_size()
self.input_size_per_partition = divide(input_size, self.tp_size) self.input_size_per_partition = divide(input_size, self.tp_size)
assert self.quant_method is not None assert self.quant_method is not None
...@@ -716,7 +751,8 @@ class RowParallelLinear(LinearBase): ...@@ -716,7 +751,8 @@ class RowParallelLinear(LinearBase):
input_size=self.input_size, input_size=self.input_size,
output_size=self.output_size, output_size=self.output_size,
params_dtype=self.params_dtype, params_dtype=self.params_dtype,
weight_loader=self.weight_loader) weight_loader=self.weight_loader,
prefix=prefix)
if not reduce_results and (bias and not skip_bias_add): if not reduce_results and (bias and not skip_bias_add):
raise ValueError("When not reduce the results, adding bias to the " raise ValueError("When not reduce the results, adding bias to the "
"results can lead to incorrect results") "results can lead to incorrect results")
...@@ -760,18 +796,19 @@ class RowParallelLinear(LinearBase): ...@@ -760,18 +796,19 @@ class RowParallelLinear(LinearBase):
# Matrix multiply. # Matrix multiply.
assert self.quant_method is not None assert self.quant_method is not None
output_parallel = self.quant_method.apply(self, input_parallel) # Only fuse bias add into GEMM for rank 0 (this ensures that
# bias will not get added more than once in TP>1 case)
bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
output_parallel = self.quant_method.apply(self,
input_parallel,
bias=bias_)
if self.reduce_results and self.tp_size > 1: if self.reduce_results and self.tp_size > 1:
output_ = tensor_model_parallel_all_reduce(output_parallel) output = tensor_model_parallel_all_reduce(output_parallel)
else: else:
output_ = output_parallel output = output_parallel
output_bias = self.bias if self.skip_bias_add else None
if not self.skip_bias_add:
output = output_ + self.bias if self.bias is not None else output_
output_bias = None
else:
output = output_
output_bias = self.bias
return output, output_bias return output, output_bias
def extra_repr(self) -> str: def extra_repr(self) -> str:
......
...@@ -2,6 +2,7 @@ from typing import Dict, Type ...@@ -2,6 +2,7 @@ from typing import Dict, Type
from vllm.model_executor.layers.quantization.aqlm import AQLMConfig from vllm.model_executor.layers.quantization.aqlm import AQLMConfig
from vllm.model_executor.layers.quantization.awq import AWQConfig from vllm.model_executor.layers.quantization.awq import AWQConfig
from vllm.model_executor.layers.quantization.awq_marlin import AWQMarlinConfig
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig) QuantizationConfig)
from vllm.model_executor.layers.quantization.bitsandbytes import ( from vllm.model_executor.layers.quantization.bitsandbytes import (
...@@ -10,6 +11,7 @@ from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tenso ...@@ -10,6 +11,7 @@ from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tenso
CompressedTensorsConfig) CompressedTensorsConfig)
from vllm.model_executor.layers.quantization.deepspeedfp import ( from vllm.model_executor.layers.quantization.deepspeedfp import (
DeepSpeedFPConfig) DeepSpeedFPConfig)
from vllm.model_executor.layers.quantization.fbgemm_fp8 import FBGEMMFp8Config
from vllm.model_executor.layers.quantization.fp8 import Fp8Config from vllm.model_executor.layers.quantization.fp8 import Fp8Config
from vllm.model_executor.layers.quantization.gptq import GPTQConfig from vllm.model_executor.layers.quantization.gptq import GPTQConfig
from vllm.model_executor.layers.quantization.gptq_marlin import ( from vllm.model_executor.layers.quantization.gptq_marlin import (
...@@ -24,11 +26,13 @@ QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = { ...@@ -24,11 +26,13 @@ QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
"awq": AWQConfig, "awq": AWQConfig,
"deepspeedfp": DeepSpeedFPConfig, "deepspeedfp": DeepSpeedFPConfig,
"fp8": Fp8Config, "fp8": Fp8Config,
"fbgemm_fp8": FBGEMMFp8Config,
# The order of gptq methods is important for config.py iteration over # The order of gptq methods is important for config.py iteration over
# override_quantization_method(..) # override_quantization_method(..)
"marlin": MarlinConfig, "marlin": MarlinConfig,
"gptq_marlin_24": GPTQMarlin24Config, "gptq_marlin_24": GPTQMarlin24Config,
"gptq_marlin": GPTQMarlinConfig, "gptq_marlin": GPTQMarlinConfig,
"awq_marlin": AWQMarlinConfig,
"gptq": GPTQConfig, "gptq": GPTQConfig,
"squeezellm": SqueezeLLMConfig, "squeezellm": SqueezeLLMConfig,
"compressed-tensors": CompressedTensorsConfig, "compressed-tensors": CompressedTensorsConfig,
......
...@@ -207,8 +207,8 @@ class AQLMConfig(QuantizationConfig): ...@@ -207,8 +207,8 @@ class AQLMConfig(QuantizationConfig):
return cls(in_group_size, nbits_per_codebook, num_code_books, return cls(in_group_size, nbits_per_codebook, num_code_books,
out_group_size) out_group_size)
def get_quant_method( def get_quant_method(self, layer: torch.nn.Module,
self, layer: torch.nn.Module) -> Optional["AQLMLinearMethod"]: prefix: str) -> Optional["AQLMLinearMethod"]:
if isinstance(layer, LinearBase): if isinstance(layer, LinearBase):
return AQLMLinearMethod(self) return AQLMLinearMethod(self)
return None return None
......
...@@ -63,8 +63,8 @@ class AWQConfig(QuantizationConfig): ...@@ -63,8 +63,8 @@ class AWQConfig(QuantizationConfig):
zero_point = cls.get_from_keys(config, ["zero_point"]) zero_point = cls.get_from_keys(config, ["zero_point"])
return cls(weight_bits, group_size, zero_point) return cls(weight_bits, group_size, zero_point)
def get_quant_method( def get_quant_method(self, layer: torch.nn.Module,
self, layer: torch.nn.Module) -> Optional["AWQLinearMethod"]: prefix: str) -> Optional["AWQLinearMethod"]:
if isinstance(layer, LinearBase): if isinstance(layer, LinearBase):
return AWQLinearMethod(self) return AWQLinearMethod(self)
return None return None
......
from typing import Any, Dict, List, Optional
import torch
from torch.nn.parameter import Parameter
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
set_weight_attrs)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
apply_awq_marlin_linear, awq_to_marlin_zero_points,
check_awq_marlin_supported, marlin_make_empty_g_idx, marlin_make_workspace,
marlin_permute_scales, replace_tensor, verify_awq_marlin_supported,
verify_marlin_supports_shape)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
logger = init_logger(__name__)
class AWQMarlinConfig(QuantizationConfig):
    """Config class for AWQ Marlin.

    Stores the AWQ quantization parameters (bit width, group size,
    zero-point flag, whether the LM head is quantized) and returns an
    ``AWQMarlinLinearMethod`` for the layers this config applies to.
    """

    def __init__(self, weight_bits: int, group_size: int, has_zp: bool,
                 lm_head_quantized: bool) -> None:
        """Store the quantization parameters and validate them.

        Args:
            weight_bits: Number of bits per quantized weight.
            group_size: Quantization group size; ``-1`` is treated by
                the linear method as "one group spanning the input size".
            has_zp: Whether the checkpoint carries zero points.
            lm_head_quantized: Whether ``ParallelLMHead`` layers should
                also receive the quantized linear method.

        The combination is checked with ``verify_awq_marlin_supported``
        (presumably raising on unsupported settings).
        """
        self.weight_bits = weight_bits
        self.pack_factor = 32 // self.weight_bits  # packed into int32
        self.group_size = group_size
        self.has_zp = has_zp
        self.lm_head_quantized = lm_head_quantized

        verify_awq_marlin_supported(num_bits=self.weight_bits,
                                    group_size=self.group_size,
                                    has_zp=self.has_zp)

    def __repr__(self) -> str:
        return (f"AWQMarlinConfig(weight_bits={self.weight_bits}, "
                f"group_size={self.group_size}, "
                f"has_zp={self.has_zp}, "
                f"lm_head_quantized={self.lm_head_quantized})")

    @classmethod
    def get_name(cls) -> str:
        """Return the registry name of this quantization method."""
        return "awq_marlin"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        """Return the activation dtypes this method accepts."""
        return [torch.half, torch.bfloat16]

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum device capability value for this method."""
        return 80

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        """Return the file names searched for the quantization config."""
        return ["quantize_config.json"]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "AWQMarlinConfig":
        """Build a config from a quantization config dictionary.

        Reads ``bits``, ``group_size`` and ``zero_point``; ``lm_head``
        is optional and defaults to ``False``.
        """
        weight_bits = cls.get_from_keys(config, ["bits"])
        group_size = cls.get_from_keys(config, ["group_size"])
        has_zp = cls.get_from_keys(config, ["zero_point"])
        lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
                                                 default=False)
        return cls(weight_bits, group_size, has_zp, lm_head_quantized)

    @classmethod
    def override_quantization_method(cls, hf_quant_cfg,
                                     user_quant) -> Optional[str]:
        """Decide whether this method should replace the checkpoint's one.

        Args:
            hf_quant_cfg: Quantization config dictionary of the model.
            user_quant: Quantization method requested by the user, or
                ``None`` when nothing was requested.

        Returns:
            ``"awq_marlin"`` when the checkpoint is convertible and the
            user request permits it, otherwise ``None``.
        """
        can_convert = cls.is_awq_marlin_compatible(hf_quant_cfg)
        # An explicit "awq_marlin" request must select this method too:
        # the hint logged below tells users to pass exactly that value.
        is_valid_user_quant = user_quant in (None, "marlin", "awq_marlin")

        if can_convert and is_valid_user_quant:
            name = cls.get_name()
            logger.info(
                "The model is convertible to %s during runtime."
                " Using %s kernel.", name, name)
            return name

        if can_convert and user_quant == "awq":
            logger.info("Detected that the model can run with awq_marlin"
                        ", however you specified quantization=awq explicitly,"
                        " so forcing awq. Use quantization=awq_marlin for"
                        " faster inference")
        return None

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["AWQMarlinLinearMethod"]:
        """Return the quantized linear method for ``layer``, if any.

        ``LinearBase`` layers always get one; ``ParallelLMHead`` layers
        only when ``lm_head_quantized`` is set. ``prefix`` is not used
        here.
        """
        if (isinstance(layer, LinearBase) or
            (isinstance(layer, ParallelLMHead) and self.lm_head_quantized)):
            return AWQMarlinLinearMethod(self)
        return None

    def get_scaled_act_names(self) -> List[str]:
        """Return the names of scaled activations (none for this method)."""
        return []

    @classmethod
    def is_awq_marlin_compatible(cls, quant_config: Dict[str, Any]) -> bool:
        """Report whether an AWQ quantization config can use this method."""
        # Extract data from quant config.
        quant_method = quant_config.get("quant_method", "").lower()
        num_bits = quant_config.get("bits", None)
        group_size = quant_config.get("group_size", None)
        has_zp = quant_config.get("zero_point", None)

        if quant_method != "awq":
            return False

        # If we cannot find the info needed in the config, cannot convert.
        if (num_bits is None or group_size is None or has_zp is None):
            return False

        return check_awq_marlin_supported(
            num_bits=num_bits,
            group_size=group_size,
            has_zp=has_zp,
            min_capability=cls.get_min_capability())
class AWQMarlinLinearMethod(LinearMethodBase):
    """Linear method for AWQ Marlin.

    Creates the quantized parameters on a layer, converts them to the
    marlin layout once loading has finished, and runs the quantized
    matmul.

    Args:
        quant_config: The AWQ Marlin quantization config.
    """

    def __init__(self, quant_config: AWQMarlinConfig) -> None:
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ) -> None:
        """Register ``qweight``, ``qzeros`` and ``scales`` on ``layer``.

        Also records ``input_size_per_partition``,
        ``output_size_per_partition`` and ``num_groups`` on the layer for
        use after loading. ``output_size`` is ignored.
        """
        del output_size
        pack_factor = self.quant_config.pack_factor
        output_size_per_partition = sum(output_partition_sizes)

        # A configured group size of -1 means one group over the full input.
        configured_group_size = self.quant_config.group_size
        group_size = (input_size if configured_group_size == -1 else
                      configured_group_size)

        verify_marlin_supports_shape(
            output_size_per_partition=output_size_per_partition,
            input_size_per_partition=input_size_per_partition,
            input_size=input_size,
            group_size=group_size)

        # Attributes shared by the two int32-packed tensors.
        packed_attrs = {
            "input_dim": 0,
            "output_dim": 1,
            "packed_dim": 1,
            "pack_factor": pack_factor,
        }
        packed_cols = output_size_per_partition // pack_factor

        qweight = Parameter(
            torch.empty(input_size_per_partition,
                        packed_cols,
                        dtype=torch.int32),
            requires_grad=False,
        )
        set_weight_attrs(qweight, packed_attrs)

        num_groups = input_size_per_partition // group_size

        qzeros = Parameter(
            torch.empty(num_groups, packed_cols, dtype=torch.int32),
            requires_grad=False,
        )
        set_weight_attrs(qzeros, packed_attrs)

        scales = Parameter(
            torch.empty(num_groups,
                        output_size_per_partition,
                        dtype=params_dtype),
            requires_grad=False,
        )
        set_weight_attrs(scales, {"input_dim": 0, "output_dim": 1})

        # Register each parameter, then attach the caller-supplied attrs.
        for name, param in (("qweight", qweight), ("qzeros", qzeros),
                            ("scales", scales)):
            layer.register_parameter(name, param)
            set_weight_attrs(param, extra_weight_attrs)

        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
        layer.num_groups = num_groups

    # TODO: Update these docs.
    # Checkpoints are serialized in AutoAWQ format, which differs from the
    # marlin format. This hook runs after the weights have been loaded and
    # performs the repacking.
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Convert the loaded AWQ tensors on ``layer`` to marlin format."""
        device = layer.qweight.device
        num_bits = self.quant_config.weight_bits
        size_k = layer.input_size_per_partition
        size_n = layer.output_size_per_partition

        # Allocate marlin workspace
        layer.workspace = marlin_make_workspace(size_n, device)

        # Weights: AWQ layout -> marlin layout.
        repacked_qweight = ops.awq_marlin_repack(layer.qweight,
                                                 size_k=size_k,
                                                 size_n=size_n,
                                                 num_bits=num_bits)
        replace_tensor(layer, "qweight", repacked_qweight)

        # Scales: AWQ layout -> marlin layout.
        permuted_scales = marlin_permute_scales(
            layer.scales,
            size_k=size_k,
            size_n=size_n,
            group_size=self.quant_config.group_size)
        replace_tensor(layer, "scales", permuted_scales)

        # Zero points: AWQ layout -> marlin layout.
        converted_zp = awq_to_marlin_zero_points(layer.qzeros,
                                                 size_k=layer.num_groups,
                                                 size_n=size_n,
                                                 num_bits=num_bits)
        replace_tensor(layer, "qzeros", converted_zp)

        # Not used, but set to empty placeholders.
        layer.g_idx = marlin_make_empty_g_idx(device)
        layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the AWQ marlin linear op on ``x`` using ``layer``'s tensors."""
        output = apply_awq_marlin_linear(
            input=x,
            weight=layer.qweight,
            weight_scale=layer.scales,
            weight_zp=layer.qzeros,
            g_idx=layer.g_idx,
            g_idx_sort_indices=layer.g_idx_sort_indices,
            workspace=layer.workspace,
            num_bits=self.quant_config.weight_bits,
            output_size_per_partition=layer.output_size_per_partition,
            input_size_per_partition=layer.input_size_per_partition,
            bias=bias)
        return output
...@@ -97,12 +97,13 @@ class QuantizationConfig(ABC): ...@@ -97,12 +97,13 @@ class QuantizationConfig(ABC):
return default return default
@abstractmethod @abstractmethod
def get_quant_method( def get_quant_method(self, layer: torch.nn.Module,
self, layer: torch.nn.Module) -> Optional[QuantizeMethodBase]: prefix: str) -> Optional[QuantizeMethodBase]:
"""Get the quantize method to use for the quantized layer. """Get the quantize method to use for the quantized layer.
Args: Args:
layer: The layer for the quant method. layer: The layer for the quant method.
prefix: The full name of the layer in the state dict
Returns: Returns:
The quantize method. None if the given layer doesn't support quant The quantize method. None if the given layer doesn't support quant
method. method.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment