Commit 6640dc0b authored by zhuwenwen's avatar zhuwenwen
Browse files
parents 44d4d334 83e4e0fe
...@@ -5,7 +5,6 @@ from io import StringIO ...@@ -5,7 +5,6 @@ from io import StringIO
import aiohttp import aiohttp
import vllm
from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str
from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.protocol import (BatchRequestInput, from vllm.entrypoints.openai.protocol import (BatchRequestInput,
...@@ -15,6 +14,7 @@ from vllm.entrypoints.openai.serving_chat import OpenAIServingChat ...@@ -15,6 +14,7 @@ from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.usage.usage_lib import UsageContext from vllm.usage.usage_lib import UsageContext
from vllm.utils import random_uuid from vllm.utils import random_uuid
from vllm.version import __version__ as VLLM_VERSION
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -135,7 +135,7 @@ async def main(args): ...@@ -135,7 +135,7 @@ async def main(args):
if __name__ == "__main__": if __name__ == "__main__":
args = parse_args() args = parse_args()
logger.info("vLLM API server version %s", vllm.__version__) logger.info("vLLM API server version %s", VLLM_VERSION)
logger.info("args: %s", args) logger.info("args: %s", args)
asyncio.run(main(args)) asyncio.run(main(args))
...@@ -27,6 +27,7 @@ if TYPE_CHECKING: ...@@ -27,6 +27,7 @@ if TYPE_CHECKING:
VLLM_TRACE_FUNCTION: int = 0 VLLM_TRACE_FUNCTION: int = 0
VLLM_ATTENTION_BACKEND: Optional[str] = None VLLM_ATTENTION_BACKEND: Optional[str] = None
VLLM_CPU_KVCACHE_SPACE: int = 0 VLLM_CPU_KVCACHE_SPACE: int = 0
VLLM_XLA_CACHE_PATH: str = "~/.vllm/xla_cache/"
VLLM_USE_RAY_COMPILED_DAG: bool = False VLLM_USE_RAY_COMPILED_DAG: bool = False
VLLM_WORKER_MULTIPROC_METHOD: str = "spawn" VLLM_WORKER_MULTIPROC_METHOD: str = "spawn"
VLLM_IMAGE_FETCH_TIMEOUT: int = 5 VLLM_IMAGE_FETCH_TIMEOUT: int = 5
...@@ -217,6 +218,11 @@ environment_variables: Dict[str, Callable[[], Any]] = { ...@@ -217,6 +218,11 @@ environment_variables: Dict[str, Callable[[], Any]] = {
# Default is 5 seconds # Default is 5 seconds
"VLLM_IMAGE_FETCH_TIMEOUT": "VLLM_IMAGE_FETCH_TIMEOUT":
lambda: int(os.getenv("VLLM_IMAGE_FETCH_TIMEOUT", "5")), lambda: int(os.getenv("VLLM_IMAGE_FETCH_TIMEOUT", "5")),
# Path to the XLA persistent cache directory.
# Only used for XLA devices such as TPUs.
"VLLM_XLA_CACHE_PATH":
lambda: os.getenv("VLLM_XLA_CACHE_PATH", "~/.vllm/xla_cache/"),
} }
# end-env-vars-definition # end-env-vars-definition
......
...@@ -9,7 +9,8 @@ from vllm.executor.multiproc_worker_utils import (ProcessWorkerWrapper, ...@@ -9,7 +9,8 @@ from vllm.executor.multiproc_worker_utils import (ProcessWorkerWrapper,
ResultHandler, WorkerMonitor) ResultHandler, WorkerMonitor)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.sequence import ExecuteModelRequest, SamplerOutput from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port, from vllm.utils import (cuda_device_count_stateless,
get_distributed_init_method, get_ip, get_open_port,
get_vllm_instance_id, make_async) get_vllm_instance_id, make_async)
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -33,8 +34,7 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor): ...@@ -33,8 +34,7 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor):
# Disable torch async compiling which won't work with daemonic processes # Disable torch async compiling which won't work with daemonic processes
os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1" os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1"
from torch.cuda import device_count assert world_size <= cuda_device_count_stateless(), (
assert world_size <= device_count(), (
"please set tensor_parallel_size to less than max local gpu count") "please set tensor_parallel_size to less than max local gpu count")
distributed_init_method = get_distributed_init_method( distributed_init_method = get_distributed_init_method(
......
from typing import List, Set, Tuple
import torch
from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
make_async)
logger = init_logger(__name__)
class TPUExecutor(ExecutorBase):
    """Executor that drives a single ``TPUWorker`` in the current process."""

    def _init_executor(self) -> None:
        """Check the configuration, adjust the dtype and start the worker."""
        chunked_prefill = self.scheduler_config.chunked_prefill_enabled
        assert not chunked_prefill, (
            "Chunked prefill is not yet supported for TPU backend")
        assert not self.speculative_config, (
            "Speculative decoding is not yet supported for TPU backend")

        # float16/float32 requests are replaced by bfloat16 (with a warning).
        requested_dtype = self.model_config.dtype
        if requested_dtype in (torch.float16, torch.float32):
            logger.warning(
                "The TPU backend currently does not support %s. "
                "Using bfloat16 instead.", requested_dtype)
            self.model_config.dtype = torch.bfloat16

        # Create the worker and load the model onto the device.
        self._init_worker()

    def _init_worker(self):
        """Build the driver worker, initialise its device and load the model."""
        # Imported here rather than at module level, as in the original code.
        from vllm.worker.tpu_worker import TPUWorker

        assert self.parallel_config.world_size == 1, (
            "TPUExecutor currently only supports a single TPU chip.")

        host_ip = get_ip()
        free_port = get_open_port()
        init_method = get_distributed_init_method(host_ip, free_port)

        # The attribute is assigned before the worker is initialised so that
        # it is already set if init_device() or load_model() raises.
        self.driver_worker = TPUWorker(
            self.model_config,
            self.parallel_config,
            self.scheduler_config,
            self.device_config,
            self.cache_config,
            self.load_config,
            self.vision_language_config,
            local_rank=0,
            rank=0,
            distributed_init_method=init_method,
        )
        worker = self.driver_worker
        worker.init_device()
        worker.load_model()

    def initialize_cache(
        self,
        num_gpu_blocks: int,
        num_cpu_blocks: int,
    ) -> None:
        """Log the block counts and forward them to the driver worker."""
        # NOTE: The executor does the logging because other executors may
        # own more than one worker. Engine-level logging would need the
        # device to be abstracted away for non-GPU configurations first.
        logger.info("# TPU blocks: %d, # CPU blocks: %d", num_gpu_blocks,
                    num_cpu_blocks)
        self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)

    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Return whatever the driver worker reports as available KV blocks."""
        worker = self.driver_worker
        return worker.determine_num_available_blocks()

    def execute_model(
        self,
        execute_model_req: ExecuteModelRequest,
    ) -> List[SamplerOutput]:
        """Run the request on the driver worker and return its output."""
        return self.driver_worker.execute_model(execute_model_req)

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Always raises: LoRA is not available on this backend."""
        raise NotImplementedError("LoRA is not implemented for TPU backend.")

    def remove_lora(self, lora_id: int) -> bool:
        """Always raises: LoRA is not available on this backend."""
        raise NotImplementedError("LoRA is not implemented for TPU backend.")

    def list_loras(self) -> Set[int]:
        """Always raises: LoRA is not available on this backend."""
        raise NotImplementedError("LoRA is not implemented for TPU backend.")

    def check_health(self) -> None:
        """No-op: the executor is considered healthy while it is running."""
        return None
class TPUExecutorAsync(TPUExecutor, ExecutorAsyncBase):
    """Asynchronous variant of ``TPUExecutor``."""

    async def execute_model_async(
        self,
        execute_model_req: ExecuteModelRequest,
    ) -> List[SamplerOutput]:
        """Await the driver worker's ``execute_model`` for the given request.

        The parameter is named ``execute_model_req`` so that it matches the
        synchronous ``TPUExecutor.execute_model`` and can be passed by
        keyword. The value returned is the result of the same worker call
        that ``execute_model`` returns, hence the same return annotation.

        Args:
            execute_model_req: The request forwarded unchanged to the worker.

        Returns:
            The output produced by the worker's ``execute_model``.
        """
        # make_async comes from vllm.utils; presumably it runs the blocking
        # call off the event loop -- confirm against its definition.
        run_async = make_async(self.driver_worker.execute_model)
        return await run_async(execute_model_req)
...@@ -4,7 +4,7 @@ from typing import (TYPE_CHECKING, List, Literal, Optional, Sequence, ...@@ -4,7 +4,7 @@ from typing import (TYPE_CHECKING, List, Literal, Optional, Sequence,
from typing_extensions import NotRequired from typing_extensions import NotRequired
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.sequence import MultiModalData from vllm.multimodal import MultiModalData
class ParsedText(TypedDict): class ParsedText(TypedDict):
......
import torch.nn as nn import torch.nn as nn
from vllm.utils import is_cpu, is_hip from vllm.utils import is_cpu, is_hip, is_tpu
class CustomOp(nn.Module): class CustomOp(nn.Module):
...@@ -56,5 +56,7 @@ class CustomOp(nn.Module): ...@@ -56,5 +56,7 @@ class CustomOp(nn.Module):
return self.forward_hip return self.forward_hip
elif is_cpu(): elif is_cpu():
return self.forward_cpu return self.forward_cpu
elif is_tpu():
return self.forward_tpu
else: else:
return self.forward_cuda return self.forward_cuda
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 64,
"num_warps": 8,
"num_stages": 5
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 5
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"96": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"128": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"256": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 3
},
"1536": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"3072": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
}
}
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 5
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 2
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"96": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"128": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"256": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"512": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"1024": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 4
},
"1536": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 4
},
"2048": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 4
},
"3072": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 4
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 4
}
}
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"2": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 5
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 5
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 2
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 3
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"96": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"128": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"256": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 3
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"1536": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"2048": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"3072": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
}
}
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 4
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 4
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 4
},
"8": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 4
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 64,
"num_warps": 4,
"num_stages": 3
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 2
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"96": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"128": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"256": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 3
},
"1024": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 4
},
"1536": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
},
"2048": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 4
},
"3072": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 4
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
}
}
...@@ -7,8 +7,8 @@ from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase ...@@ -7,8 +7,8 @@ from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import ( # noqa: E501 from vllm.model_executor.layers.quantization.base_config import ( # noqa: E501
QuantizationConfig) QuantizationConfig)
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme, CompressedTensorsW8A8DynamicToken, CompressedTensorsScheme, CompressedTensorsW4A16,
CompressedTensorsW8A8StaticTensor) CompressedTensorsW8A8DynamicToken, CompressedTensorsW8A8StaticTensor)
from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
QuantizationArgs, QuantizationStrategy, find_first_name_or_class_match) QuantizationArgs, QuantizationStrategy, find_first_name_or_class_match)
...@@ -47,16 +47,27 @@ class CompressedTensorsConfig(QuantizationConfig): ...@@ -47,16 +47,27 @@ class CompressedTensorsConfig(QuantizationConfig):
layer_quant_details: Dict[str, Any] = dict() layer_quant_details: Dict[str, Any] = dict()
ignore: List[str] = config.get("ignore", None) ignore: List[str] = config.get("ignore", None)
# The quant_config has multiple config_groups, each containing
# an input_activations key with details about how the activations are
# quantized, a weights key indicating how the weights are quantized,
# and a list of targets under the `targets` key, dictating which
# layers are impacted by the quantization details. The quantization
# details follow the structure defined by the QuantizationArgs
# pydantic model, which is used to verify the structure of the
# quant_config and also store the details for later use.
for key, quant_config in config["config_groups"].items(): for key, quant_config in config["config_groups"].items():
targets = quant_config.get("targets") targets = quant_config.get("targets")
for target in targets: for target in targets:
layer_quant_details[target] = {} layer_quant_details[target] = {}
layer_quant_details[target][ layer_quant_details[target][
"weight"] = QuantizationArgs.parse_obj( "weights"] = QuantizationArgs.parse_obj(
quant_config.get("weights")) quant_config.get("weights"))
layer_quant_details[target][ try:
"input"] = QuantizationArgs.parse_obj( layer_quant_details[target][
quant_config.get("input_activations")) "input_activations"] = QuantizationArgs.parse_obj(
quant_config.get("input_activations"))
except Exception:
layer_quant_details[target]["input_activations"] = None
return cls(layer_quant_details=layer_quant_details, ignore=ignore) return cls(layer_quant_details=layer_quant_details, ignore=ignore)
...@@ -86,8 +97,23 @@ class CompressedTensorsConfig(QuantizationConfig): ...@@ -86,8 +97,23 @@ class CompressedTensorsConfig(QuantizationConfig):
return is_8_bits and is_token_tensor and is_symmetric and is_dynamic return is_8_bits and is_token_tensor and is_symmetric and is_dynamic
def _is_w4a16(self, weight_quant: BaseModel,
              input_quant: BaseModel) -> bool:
    """Tell whether the arguments describe the 4-bit weight-only scheme.

    The scheme matches when the weights use 4 bits, are symmetric and are
    not dynamic, and no input-activation quantization is given
    (``input_quant`` is ``None``).
    """
    # All fields are read up front; the checks below only combine them.
    four_bit_weights = weight_quant.num_bits == 4
    no_activation_quant = input_quant is None
    symmetric = weight_quant.symmetric
    static = not weight_quant.dynamic

    # Each guard returns the failing value itself, mirroring an `and` chain.
    if not four_bit_weights:
        return four_bit_weights
    if not no_activation_quant:
        return no_activation_quant
    if not symmetric:
        return symmetric
    return static
def _get_schema(self, weight_quant: BaseModel, def _get_schema(self, weight_quant: BaseModel,
input_quant: BaseModel) -> "CompressedTensorsScheme": input_quant: BaseModel) -> "CompressedTensorsScheme":
if self._is_w4a16(weight_quant, input_quant):
return CompressedTensorsW4A16(num_bits=weight_quant.num_bits,
strategy=weight_quant.strategy,
group_size=weight_quant.group_size)
if self._is_static_tensor_w8a8(weight_quant, input_quant): if self._is_static_tensor_w8a8(weight_quant, input_quant):
return CompressedTensorsW8A8StaticTensor() return CompressedTensorsW8A8StaticTensor()
...@@ -113,8 +139,9 @@ class CompressedTensorsConfig(QuantizationConfig): ...@@ -113,8 +139,9 @@ class CompressedTensorsConfig(QuantizationConfig):
raise ValueError( raise ValueError(
f"Could not find quantization details for {layer}.") f"Could not find quantization details for {layer}.")
return self._get_schema(weight_quant=layer_quant_details["weight"], return self._get_schema(
input_quant=layer_quant_details["input"]) weight_quant=layer_quant_details["weights"],
input_quant=layer_quant_details["input_activations"])
class CompressedTensorsLinearMethod(LinearMethodBase): class CompressedTensorsLinearMethod(LinearMethodBase):
...@@ -140,6 +167,7 @@ class CompressedTensorsLinearMethod(LinearMethodBase): ...@@ -140,6 +167,7 @@ class CompressedTensorsLinearMethod(LinearMethodBase):
layer=layer, layer=layer,
input_size_per_partition=input_size_per_partition, input_size_per_partition=input_size_per_partition,
output_partition_sizes=output_partition_sizes, output_partition_sizes=output_partition_sizes,
input_size=input_size,
output_size=output_size, output_size=output_size,
params_dtype=params_dtype, params_dtype=params_dtype,
weight_loader=weight_loader) weight_loader=weight_loader)
......
from .compressed_tensors_scheme import CompressedTensorsScheme # noqa: F401 from .compressed_tensors_scheme import CompressedTensorsScheme # noqa: F401
from .compressed_tensors_unquantized import ( # noqa: F401 from .compressed_tensors_unquantized import ( # noqa: F401
CompressedTensorsUnquantized) CompressedTensorsUnquantized)
from .compressed_tensors_w4a16 import CompressedTensorsW4A16 # noqa: F401
from .compressed_tensors_w8a8_dynamictoken import ( # noqa: F401, E501 from .compressed_tensors_w8a8_dynamictoken import ( # noqa: F401, E501
CompressedTensorsW8A8DynamicToken) CompressedTensorsW8A8DynamicToken)
from .compressed_tensors_w8a8_statictensor import ( # noqa: F401, E501 from .compressed_tensors_w8a8_statictensor import ( # noqa: F401, E501
......
from typing import Callable, List, Optional
import torch
from torch.nn import Parameter
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme)
from vllm.model_executor.layers.quantization.gptq_marlin import (
GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N, GPTQMarlinState,
marlin_permute_scales)
from vllm.model_executor.utils import set_weight_attrs
__all__ = ["CompressedTensorsW4A16"]
class CompressedTensorsW4A16(CompressedTensorsScheme):
    """Scheme for packed low-bit weights run through the GPTQ Marlin ops.

    ``create_weights`` registers the packed weight, scale and shape
    parameters on the layer. The first ``apply_weights`` call repacks the
    weight and permutes the scales in place, then every call performs the
    matrix multiply with ``ops.gptq_marlin_gemm``.
    """

    def __init__(self,
                 strategy: str,
                 num_bits: int,
                 group_size: Optional[int] = None):
        """Record the quantization settings.

        Raises:
            ValueError: If ``strategy`` is ``"group"`` and ``group_size`` is
                ``None``.
        """
        self.num_bits = num_bits
        self.strategy = strategy
        self.group_size = group_size

        if self.strategy == "group":
            if self.group_size is None:
                raise ValueError(
                    "group_size must be given when using strategy group")

    def create_weights(self, layer: torch.nn.Module, input_size: int,
                       output_partition_sizes: List[int],
                       input_size_per_partition: int,
                       params_dtype: torch.dtype, weight_loader: Callable,
                       **kwargs):
        """Register the quantized parameters and bookkeeping on ``layer``.

        Registers ``weight_packed`` (int32), ``weight_scale``
        (``params_dtype``) and ``weight_shape`` (int64, two entries), each
        tagged with ``weight_loader``. It also stores the sizes, the group
        size, the repack state and a zeroed workspace tensor as attributes of
        ``layer``.
        """
        # Number of num_bits-wide values that fit into one 32-bit word.
        values_per_word = 32 // self.num_bits
        out_features = sum(output_partition_sizes)

        # Without a configured group size, the full input size is used.
        effective_group = (input_size
                           if self.group_size is None else self.group_size)

        scale_input_dim = None
        scale_columns = input_size // effective_group
        input_is_split = input_size != input_size_per_partition
        if input_is_split and self.group_size is not None:
            # The scales follow the partitioned input dimension.
            scale_input_dim = 1
            scale_columns = input_size_per_partition // effective_group

        packed_storage = torch.empty(
            out_features,
            input_size_per_partition // values_per_word,
            dtype=torch.int32,
        )
        packed_weight = Parameter(packed_storage, requires_grad=False)
        packed_attrs = {
            "input_dim": 1,
            "output_dim": 0,
            "packed_dim": 1,
            "pack_factor": values_per_word
        }
        set_weight_attrs(packed_weight, packed_attrs)
        set_weight_attrs(packed_weight, {"weight_loader": weight_loader})
        layer.register_parameter("weight_packed", packed_weight)

        scale_storage = torch.empty(
            out_features,
            scale_columns,
            dtype=params_dtype,
        )
        scales = Parameter(scale_storage, requires_grad=False)
        set_weight_attrs(scales, {"weight_loader": weight_loader})
        scale_attrs = {"input_dim": scale_input_dim, "output_dim": 0}
        set_weight_attrs(scales, scale_attrs)
        layer.register_parameter("weight_scale", scales)

        # Two-entry tensor holding the shape of the weights before packing.
        original_shape = Parameter(torch.empty(2, dtype=torch.int64),
                                   requires_grad=False)
        layer.register_parameter("weight_shape", original_shape)
        set_weight_attrs(original_shape, {"weight_loader": weight_loader})

        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = out_features
        layer.input_size = input_size
        # The first apply_weights() call repacks while this state is REPACK.
        layer.marlin_state = GPTQMarlinState.REPACK
        layer.is_k_full = True
        layer.group_size = effective_group

        workspace_len = (out_features //
                         GPTQ_MARLIN_MIN_THREAD_N) * GPTQ_MARLIN_MAX_PARALLEL
        layer.workspace = torch.zeros(workspace_len,
                                      dtype=torch.int,
                                      requires_grad=False)

    def _repack_for_marlin(self, layer: torch.nn.Module, part_size_k: int,
                           part_size_n: int) -> None:
        """Repack the weight and permute the scales of ``layer`` in place."""
        # Mark the layer ready first so the conversion runs only once.
        layer.marlin_state = GPTQMarlinState.READY

        def overwrite_in_place(param_name, fresh):
            # The tensors are already registered as parameters and will not
            # be freed, so the existing one is resized and filled instead of
            # being rebound; resize_() ensures the same buffer is reused.
            target = getattr(layer, param_name)
            target.resize_(fresh.shape)
            target.copy_(fresh)

        device = layer.weight_packed.device

        # Reset the g_idx related tensors to empty ones.
        layer.g_idx = Parameter(torch.empty(0, dtype=torch.int,
                                            device=device),
                                requires_grad=False)
        layer.g_idx_sort_indices = Parameter(torch.empty(0,
                                                         dtype=torch.int,
                                                         device=device),
                                             requires_grad=False)

        # Repack the weights.
        repacked_weight = ops.gptq_marlin_repack(
            layer.weight_packed.t().contiguous(), layer.g_idx_sort_indices,
            part_size_k, part_size_n, self.num_bits)
        overwrite_in_place("weight_packed", repacked_weight)

        # Permute the scales.
        permuted_scales = marlin_permute_scales(
            layer.weight_scale.squeeze().t().contiguous(), part_size_k,
            part_size_n, layer.group_size, self.num_bits)
        overwrite_in_place("weight_scale", permuted_scales)

    def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor):
        """Multiply ``x`` by the layer's quantized weight.

        ``x`` is flattened to two dimensions for the kernel call and the
        result is reshaped to ``x.shape[:-1] + (output_size_per_partition,)``.
        """
        flat_x = x.reshape(-1, x.shape[-1])
        n_out = layer.output_size_per_partition
        k_in = layer.input_size_per_partition

        if layer.marlin_state == GPTQMarlinState.REPACK:
            self._repack_for_marlin(layer, k_in, n_out)

        rows = flat_x.shape[0]
        result = ops.gptq_marlin_gemm(flat_x, layer.weight_packed,
                                      layer.weight_scale, layer.g_idx,
                                      layer.g_idx_sort_indices,
                                      layer.workspace, self.num_bits, rows,
                                      n_out, k_in, layer.is_k_full)

        return result.reshape(x.shape[:-1] + (n_out, ))
...@@ -81,5 +81,5 @@ class CompressedTensorsW8A8DynamicToken(CompressedTensorsScheme): ...@@ -81,5 +81,5 @@ class CompressedTensorsW8A8DynamicToken(CompressedTensorsScheme):
weight_scale = layer.weight_scale weight_scale = layer.weight_scale
x_q, input_scales = custom_ops.scaled_int8_quant(x) x_q, input_scales = custom_ops.scaled_int8_quant(x)
return custom_ops.cutlass_scaled_mm_dq(x_q, weight.t(), input_scales, return custom_ops.cutlass_scaled_mm(x_q, weight.t(), input_scales,
weight_scale, x.dtype) weight_scale, x.dtype)
...@@ -99,5 +99,5 @@ class CompressedTensorsW8A8StaticTensor(CompressedTensorsScheme): ...@@ -99,5 +99,5 @@ class CompressedTensorsW8A8StaticTensor(CompressedTensorsScheme):
# Input quantize # Input quantize
x_q, _ = custom_ops.scaled_int8_quant(x, act_scale) x_q, _ = custom_ops.scaled_int8_quant(x, act_scale)
return custom_ops.cutlass_scaled_mm_dq(x_q, weight.t(), act_scale, return custom_ops.cutlass_scaled_mm(x_q, weight.t(), act_scale,
weight_scale, x.dtype) weight_scale, x.dtype)
...@@ -257,11 +257,13 @@ class Fp8LinearMethod(LinearMethodBase): ...@@ -257,11 +257,13 @@ class Fp8LinearMethod(LinearMethodBase):
# If dynamic, layer.input_scale is None and x_scale computed from x. # If dynamic, layer.input_scale is None and x_scale computed from x.
# If static, layer.input_scale is scalar and x_scale is input_scale. # If static, layer.input_scale is scalar and x_scale is input_scale.
if bias is None and self.cutlass_fp8_supported: # Temporarily disable CUTLASS kernels due to an illegal memory access
#if bias is None and self.cutlass_fp8_supported:
if False:
qinput, x_scale = ops.scaled_fp8_quant(x, layer.input_scale) qinput, x_scale = ops.scaled_fp8_quant(x, layer.input_scale)
# Fused GEMM_DQ # Fused GEMM_DQ
output = ops.cutlass_scaled_mm_dq( output = ops.cutlass_scaled_mm(
qinput, qinput,
layer.weight, layer.weight,
out_dtype=x.dtype, out_dtype=x.dtype,
......
...@@ -28,6 +28,7 @@ import torch ...@@ -28,6 +28,7 @@ import torch
import torch.nn as nn import torch.nn as nn
from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.custom_op import CustomOp
from vllm.utils import is_tpu
def _rotate_neox(x: torch.Tensor) -> torch.Tensor: def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
...@@ -43,6 +44,19 @@ def _rotate_gptj(x: torch.Tensor) -> torch.Tensor: ...@@ -43,6 +44,19 @@ def _rotate_gptj(x: torch.Tensor) -> torch.Tensor:
return x.flatten(-2) return x.flatten(-2)
def _apply_rotary_emb(
    x: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> torch.Tensor:
    """Rotate ``x`` by multiplying it with ``freqs_cis`` in complex form.

    ``x`` is a 4-D tensor. Its dimensions 1 and 2 are swapped for the
    multiplication and swapped back before returning. The first half of the
    last dimension is used as the real part and the second half as the
    imaginary part. The arithmetic is done in float32 and the result is cast
    back to the dtype of ``x``.

    Args:
        x: The tensor to rotate.
        freqs_cis: Complex factors that broadcast against the transposed
            ``x`` once its last dimension has been halved.

    Returns:
        A tensor with the same shape and dtype as ``x``.
    """
    swapped = x.transpose(1, 2).float()

    # Pair element i of the first half with element i of the second half.
    halves = torch.chunk(swapped, 2, dim=-1)
    as_complex = torch.view_as_complex(torch.stack(halves, dim=-1))

    rotated = torch.view_as_real(as_complex * freqs_cis).type_as(x)

    # Put all real parts first and all imaginary parts second again.
    real_and_imag = torch.chunk(rotated, 2, dim=-1)
    merged = torch.cat(real_and_imag, dim=-2)
    merged = merged.reshape(*merged.shape[:3], -1)
    return merged.transpose(1, 2)
class RotaryEmbedding(CustomOp): class RotaryEmbedding(CustomOp):
"""Original rotary positional embedding.""" """Original rotary positional embedding."""
...@@ -64,8 +78,14 @@ class RotaryEmbedding(CustomOp): ...@@ -64,8 +78,14 @@ class RotaryEmbedding(CustomOp):
self.dtype = dtype self.dtype = dtype
cache = self._compute_cos_sin_cache() cache = self._compute_cos_sin_cache()
cache = cache.to(dtype) self.use_native2 = is_tpu() and is_neox_style
self.register_buffer("cos_sin_cache", cache, persistent=False) if not self.use_native2:
cache = cache.to(dtype)
self.register_buffer("cos_sin_cache", cache, persistent=False)
else:
cos, sin = cache.chunk(2, dim=-1)
freqs_cis = cos + 1j * sin
self.register_buffer("freqs_cis", freqs_cis, persistent=False)
def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor: def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
"""Compute the inverse frequency.""" """Compute the inverse frequency."""
...@@ -100,7 +120,11 @@ class RotaryEmbedding(CustomOp): ...@@ -100,7 +120,11 @@ class RotaryEmbedding(CustomOp):
key: torch.Tensor, key: torch.Tensor,
offsets: Optional[torch.Tensor] = None, offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
"""PyTorch-native implementation equivalent to forward().""" """A PyTorch-native implementation equivalent to forward().
This method mimics the implementation of the custom CUDA kernel
used in `forward_cuda()`.
"""
query = query.view(*query.shape[:-1], -1, self.head_size) query = query.view(*query.shape[:-1], -1, self.head_size)
key = key.view(*key.shape[:-1], -1, self.head_size) key = key.view(*key.shape[:-1], -1, self.head_size)
...@@ -138,6 +162,42 @@ class RotaryEmbedding(CustomOp): ...@@ -138,6 +162,42 @@ class RotaryEmbedding(CustomOp):
key = key.flatten(-2) key = key.flatten(-2)
return query, key return query, key
def forward_native2(
    self,
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Another PyTorch-native implementation of forward().

    This method might perform better than `forward_native()` when compiled.

    Args:
        positions: Indices into ``self.freqs_cis``. A 1-D tensor is treated
            as a single sequence (batch size 1); otherwise it must unpack
            into ``(batch_size, seq_len)``.
        query: Query tensor; viewed internally as
            ``(batch_size, seq_len, -1, head_size)``.
        key: Key tensor; handled exactly like ``query``.
        offsets: Optional tensor added to ``positions`` before the lookup.

    Returns:
        ``(query, key)`` with the first ``rotary_dim`` entries of each head
        rotated, both reshaped back to their input shapes.
    """
    if positions.dim() == 1:
        batch_size, seq_len = 1, positions.shape[0]
    else:
        batch_size, seq_len = positions.shape
    if offsets is not None:
        positions = positions + offsets

    # One complex rotation factor per (position, frequency); the singleton
    # axis lets it broadcast over the head dimension.
    selected = self.freqs_cis.index_select(0, positions.flatten())
    freqs_cis = selected.view(batch_size, 1, seq_len, -1)

    def _rotate_heads(tensor: torch.Tensor) -> torch.Tensor:
        # Rotate the leading `rotary_dim` slice of every head and leave the
        # remainder untouched, then restore the caller's layout.
        original_shape = tensor.shape
        per_head = tensor.view(batch_size, seq_len, -1, self.head_size)
        to_rotate = per_head[..., :self.rotary_dim]
        passthrough = per_head[..., self.rotary_dim:]
        rotated = _apply_rotary_emb(to_rotate, freqs_cis)
        combined = torch.cat((rotated, passthrough), dim=-1)
        return combined.reshape(original_shape)

    query = _rotate_heads(query)
    key = _rotate_heads(key)
    return query, key
def forward_cuda( def forward_cuda(
self, self,
positions: torch.Tensor, positions: torch.Tensor,
...@@ -161,6 +221,17 @@ class RotaryEmbedding(CustomOp): ...@@ -161,6 +221,17 @@ class RotaryEmbedding(CustomOp):
self.cos_sin_cache, self.is_neox_style) self.cos_sin_cache, self.is_neox_style)
return query, key return query, key
def forward_tpu(
    self,
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Dispatch to the native implementation selected by ``use_native2``.

    Calls ``forward_native2`` when ``self.use_native2`` is truthy and
    ``forward_native`` otherwise, forwarding all four arguments unchanged
    and returning whatever the chosen implementation returns.
    """
    if self.use_native2:
        return self.forward_native2(positions, query, key, offsets)
    return self.forward_native(positions, query, key, offsets)
def extra_repr(self) -> str: def extra_repr(self) -> str:
s = f"head_size={self.head_size}, rotary_dim={self.rotary_dim}" s = f"head_size={self.head_size}, rotary_dim={self.rotary_dim}"
s += f", max_position_embeddings={self.max_position_embeddings}" s += f", max_position_embeddings={self.max_position_embeddings}"
......
...@@ -24,7 +24,7 @@ from vllm.model_executor.layers.quantization.base_config import ( ...@@ -24,7 +24,7 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig) QuantizationConfig)
from vllm.model_executor.model_loader.tensorizer import ( from vllm.model_executor.model_loader.tensorizer import (
TensorizerConfig, is_vllm_tensorized, load_with_tensorizer, TensorizerConfig, is_vllm_tensorized, load_with_tensorizer,
tensorizer_weights_iterator) serialize_vllm_model, tensorizer_weights_iterator)
from vllm.model_executor.model_loader.utils import (get_model_architecture, from vllm.model_executor.model_loader.utils import (get_model_architecture,
set_default_torch_dtype) set_default_torch_dtype)
from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.model_loader.weight_utils import (
...@@ -34,6 +34,7 @@ from vllm.model_executor.model_loader.weight_utils import ( ...@@ -34,6 +34,7 @@ from vllm.model_executor.model_loader.weight_utils import (
pt_weights_iterator, safetensors_weights_iterator) pt_weights_iterator, safetensors_weights_iterator)
from vllm.model_executor.models.vlm_base import VisionLanguageModelBase from vllm.model_executor.models.vlm_base import VisionLanguageModelBase
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.utils import is_tpu
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -230,12 +231,26 @@ class DefaultModelLoader(BaseModelLoader): ...@@ -230,12 +231,26 @@ class DefaultModelLoader(BaseModelLoader):
if self.load_config.load_format == LoadFormat.NPCACHE: if self.load_config.load_format == LoadFormat.NPCACHE:
# Currently np_cache only support *.bin checkpoints # Currently np_cache only support *.bin checkpoints
assert use_safetensors is False assert use_safetensors is False
return np_cache_weights_iterator(model_name_or_path, weights_iterator = np_cache_weights_iterator(
self.load_config.download_dir, model_name_or_path, self.load_config.download_dir, hf_folder,
hf_folder, hf_weights_files) hf_weights_files)
if use_safetensors: elif use_safetensors:
return safetensors_weights_iterator(hf_weights_files) weights_iterator = safetensors_weights_iterator(hf_weights_files)
return pt_weights_iterator(hf_weights_files) else:
weights_iterator = pt_weights_iterator(hf_weights_files)
if is_tpu():
# In PyTorch XLA, we should call `xm.mark_step` frequently so that
# not too many ops are accumulated in the XLA program.
import torch_xla.core.xla_model as xm
def _xla_weights_iterator(iterator: Generator):
for weights in iterator:
yield weights
xm.mark_step()
weights_iterator = _xla_weights_iterator(weights_iterator)
return weights_iterator
def load_model(self, *, model_config: ModelConfig, def load_model(self, *, model_config: ModelConfig,
device_config: DeviceConfig, device_config: DeviceConfig,
...@@ -380,6 +395,12 @@ class TensorizerLoader(BaseModelLoader): ...@@ -380,6 +395,12 @@ class TensorizerLoader(BaseModelLoader):
cache_config: CacheConfig) -> nn.Module: cache_config: CacheConfig) -> nn.Module:
self._verify_config(model_config, parallel_config) self._verify_config(model_config, parallel_config)
if parallel_config.tensor_parallel_size > 1:
from vllm.distributed import get_tensor_model_parallel_rank
self.tensorizer_config.tensorizer_uri = \
self.tensorizer_config.tensorizer_uri \
% get_tensor_model_parallel_rank()
if is_vllm_tensorized(self.tensorizer_config): if is_vllm_tensorized(self.tensorizer_config):
return self._load_model_serialized(model_config, device_config, return self._load_model_serialized(model_config, device_config,
lora_config, lora_config,
...@@ -390,6 +411,16 @@ class TensorizerLoader(BaseModelLoader): ...@@ -390,6 +411,16 @@ class TensorizerLoader(BaseModelLoader):
vision_language_config, vision_language_config,
cache_config) cache_config)
@staticmethod
def save_model(
    model: torch.nn.Module,
    tensorizer_config: TensorizerConfig,
) -> None:
    """Serialize ``model`` using the given Tensorizer configuration.

    Thin wrapper that forwards both arguments to ``serialize_vllm_model``;
    the value returned by that function is discarded.
    """
    serialize_vllm_model(model=model, tensorizer_config=tensorizer_config)
class ShardedStateLoader(BaseModelLoader): class ShardedStateLoader(BaseModelLoader):
""" """
......
...@@ -2,11 +2,11 @@ import argparse ...@@ -2,11 +2,11 @@ import argparse
import dataclasses import dataclasses
import io import io
import os import os
import re
import time import time
import typing
from dataclasses import dataclass from dataclasses import dataclass
from functools import partial from functools import partial
from typing import Generator, Optional, Tuple, Type, Union from typing import BinaryIO, Generator, Optional, Tuple, Type, Union
import torch import torch
from torch import nn from torch import nn
...@@ -14,6 +14,7 @@ from transformers import PretrainedConfig ...@@ -14,6 +14,7 @@ from transformers import PretrainedConfig
import vllm.envs as envs import vllm.envs as envs
from vllm.config import ModelConfig, ParallelConfig from vllm.config import ModelConfig, ParallelConfig
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine from vllm.engine.llm_engine import LLMEngine
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
...@@ -48,8 +49,7 @@ logger = init_logger(__name__) ...@@ -48,8 +49,7 @@ logger = init_logger(__name__)
@dataclass @dataclass
class TensorizerConfig: class TensorizerConfig:
tensorizer_uri: Union[io.BufferedIOBase, io.RawIOBase, typing.BinaryIO, tensorizer_uri: str
str, bytes, os.PathLike, int]
vllm_tensorized: Optional[bool] = False vllm_tensorized: Optional[bool] = False
verify_hash: Optional[bool] = False verify_hash: Optional[bool] = False
num_readers: Optional[int] = None num_readers: Optional[int] = None
...@@ -60,6 +60,12 @@ class TensorizerConfig: ...@@ -60,6 +60,12 @@ class TensorizerConfig:
model_class: Optional[Type[torch.nn.Module]] = None model_class: Optional[Type[torch.nn.Module]] = None
hf_config: Optional[PretrainedConfig] = None hf_config: Optional[PretrainedConfig] = None
dtype: Optional[Union[str, torch.dtype]] = None dtype: Optional[Union[str, torch.dtype]] = None
_is_sharded: bool = False
def __post_init__(self):
    """Record whether ``tensorizer_uri`` is a sharded-model template.

    The URI counts as sharded when it is a string containing a
    zero-padded, single-digit-width integer placeholder such as ``%04d``.
    """
    uri = self.tensorizer_uri
    if isinstance(uri, str):
        self._is_sharded = re.search(r'%0\dd', uri) is not None
    else:
        # Non-string URIs can never carry a format template.
        self._is_sharded = False
def _construct_tensorizer_args(self) -> "TensorizerArgs": def _construct_tensorizer_args(self) -> "TensorizerArgs":
tensorizer_args = { tensorizer_args = {
...@@ -78,13 +84,12 @@ class TensorizerConfig: ...@@ -78,13 +84,12 @@ class TensorizerConfig:
self, self,
parallel_config: "ParallelConfig", parallel_config: "ParallelConfig",
) -> None: ) -> None:
if (parallel_config.tensor_parallel_size > 1 if parallel_config.tensor_parallel_size > 1 \
and self.tensorizer_uri is not None): and not self._is_sharded:
raise ValueError( raise ValueError(
"Loading to multiple GPUs is not currently supported with " "For a sharded model, tensorizer_uri should include a"
"vLLM-serialized models. Please set tensor_parallel_size=1." " string format template like '%04d' to be formatted"
" or use a non-vLLM-serialized model, such as a " " with the rank of the shard")
"serialized Hugging Face `PretrainedModel`.")
def verify_with_model_config(self, model_config: "ModelConfig") -> None: def verify_with_model_config(self, model_config: "ModelConfig") -> None:
if (model_config.quantization is not None if (model_config.quantization is not None
...@@ -102,8 +107,8 @@ def load_with_tensorizer(tensorizer_config: TensorizerConfig, ...@@ -102,8 +107,8 @@ def load_with_tensorizer(tensorizer_config: TensorizerConfig,
@dataclass @dataclass
class TensorizerArgs: class TensorizerArgs:
tensorizer_uri: Union[io.BufferedIOBase, io.RawIOBase, typing.BinaryIO, tensorizer_uri: Union[io.BufferedIOBase, io.RawIOBase, BinaryIO, str,
str, bytes, os.PathLike, int] bytes, os.PathLike, int]
vllm_tensorized: Optional[bool] = False vllm_tensorized: Optional[bool] = False
verify_hash: Optional[bool] = False verify_hash: Optional[bool] = False
num_readers: Optional[int] = None num_readers: Optional[int] = None
...@@ -332,6 +337,7 @@ class TensorizerAgent: ...@@ -332,6 +337,7 @@ class TensorizerAgent:
) as stream, TensorDeserializer( ) as stream, TensorDeserializer(
stream, stream,
dtype=self.tensorizer_config.dtype, dtype=self.tensorizer_config.dtype,
device=f'cuda:{torch.cuda.current_device()}',
**self.tensorizer_args.deserializer_params) as deserializer: **self.tensorizer_args.deserializer_params) as deserializer:
deserializer.load_into_module(self.model) deserializer.load_into_module(self.model)
end = time.perf_counter() end = time.perf_counter()
...@@ -400,33 +406,70 @@ def is_vllm_tensorized(tensorizer_config: "TensorizerConfig") -> bool: ...@@ -400,33 +406,70 @@ def is_vllm_tensorized(tensorizer_config: "TensorizerConfig") -> bool:
return False return False
def get_pretensorized_vllm_model(engine: "LLMEngine") -> nn.Module: def serialize_vllm_model(
model = (engine.model_executor.driver_worker.model_runner.model) model: nn.Module,
tensorizer_config: TensorizerConfig,
) -> nn.Module:
model.register_parameter( model.register_parameter(
"vllm_tensorized_marker", "vllm_tensorized_marker",
nn.Parameter(torch.tensor((1, ), device="meta"), requires_grad=False)) nn.Parameter(torch.tensor((1, ), device="meta"), requires_grad=False))
return model
def serialize_vllm_model(engine: "LLMEngine",
tensorizer_config : TensorizerConfig,
encryption_key_path: Optional[str] = None) \
-> nn.Module:
model = get_pretensorized_vllm_model(engine)
tensorizer_args = tensorizer_config._construct_tensorizer_args() tensorizer_args = tensorizer_config._construct_tensorizer_args()
encryption_params = None encryption_params = None
if encryption_key_path is not None: if (keyfile := tensorizer_config.encryption_keyfile) is not None:
encryption_params = EncryptionParams.random() with open(keyfile, "rb") as f:
with _write_stream(encryption_key_path, key = f.read()
**tensorizer_args.stream_params) as stream: encryption_params = EncryptionParams(key=key)
stream.write(encryption_params.key)
with _write_stream(tensorizer_args.tensorizer_uri, output_file = tensorizer_args.tensorizer_uri
**tensorizer_args.stream_params) as stream: if tensorizer_config._is_sharded:
from vllm.distributed import get_tensor_model_parallel_rank
output_file = output_file % get_tensor_model_parallel_rank()
with _write_stream(output_file, **tensorizer_args.stream_params) as stream:
serializer = TensorSerializer(stream, encryption=encryption_params) serializer = TensorSerializer(stream, encryption=encryption_params)
serializer.write_module(model) serializer.write_module(model)
serializer.close() serializer.close()
logger.info("Successfully serialized model to %s", logger.info("Successfully serialized model to %s", str(output_file))
str(tensorizer_args.tensorizer_uri))
return model return model
def tensorize_vllm_model(engine_args: EngineArgs,
                         tensorizer_config: TensorizerConfig,
                         generate_keyfile: bool = True):
    """Load a model into a fresh engine and serialize it with Tensorizer.

    Intended to be used separately from running a vLLM server since it
    creates its own Engine instance.

    Args:
        engine_args: Arguments used both to build the engine configuration
            for validation and to create the ``LLMEngine``.
        tensorizer_config: Serialization settings; validated against the
            model and parallel configuration before any work is done.
        generate_keyfile: When true and
            ``tensorizer_config.encryption_keyfile`` is set, a random
            encryption key is generated and written to that location.
    """
    engine_config = engine_args.create_engine_config()
    tensorizer_config.verify_with_model_config(engine_config.model_config)
    tensorizer_config.verify_with_parallel_config(
        engine_config.parallel_config)

    # Generate the encryption key before creating the engine to support
    # sharding.
    keyfile = (tensorizer_config.encryption_keyfile
               if generate_keyfile else None)
    if keyfile is not None:
        encryption_params = EncryptionParams.random()
        with _write_stream(
                keyfile,
                s3_access_key_id=tensorizer_config.s3_access_key_id,
                s3_secret_access_key=tensorizer_config.s3_secret_access_key,
                s3_endpoint=tensorizer_config.s3_endpoint,
        ) as stream:
            stream.write(encryption_params.key)

    engine = LLMEngine.from_engine_args(engine_args)

    if not tensorizer_config._is_sharded:
        # With a single worker the underlying model is directly reachable.
        serialize_vllm_model(
            engine.model_executor.driver_worker.model_runner.model,
            tensorizer_config,
        )
        return

    # For a distributed (tensor-parallel) engine, each worker shard has to
    # serialize its own part of the model.
    engine.model_executor._run_workers(
        "save_tensorized_model",
        tensorizer_config=tensorizer_config,
    )
...@@ -227,7 +227,7 @@ class LlavaForConditionalGeneration(VisionLanguageModelBase): ...@@ -227,7 +227,7 @@ class LlavaForConditionalGeneration(VisionLanguageModelBase):
attn_metadata: AttentionMetadata, attn_metadata: AttentionMetadata,
**kwargs: object, **kwargs: object,
) -> SamplerOutput: ) -> SamplerOutput:
"""Run forward pass for Llava 1.5. """Run forward pass for LLaVA-1.5.
One key thing to understand is the `input_ids` already accounts for the One key thing to understand is the `input_ids` already accounts for the
positions of the to-be-inserted image embeddings. positions of the to-be-inserted image embeddings.
...@@ -247,22 +247,25 @@ class LlavaForConditionalGeneration(VisionLanguageModelBase): ...@@ -247,22 +247,25 @@ class LlavaForConditionalGeneration(VisionLanguageModelBase):
This way, the `positions` and `attn_metadata` are consistent This way, the `positions` and `attn_metadata` are consistent
with the `input_ids`. with the `input_ids`.
The model takes two types of image inputs: This model has two modes of image inputs:
PIXEL_VALUES and IMAGE_FEATURES. `PIXEL_VALUES` and `IMAGE_FEATURES`.
The following shows how each maps to huggingface implementation.
PIXEL_VALUES:
- https://github.com/huggingface/transformers/blob/07bdbeb/src/transformers/models/llava/modeling_llava.py#L353
IMAGE_FEATURES:
- https://github.com/huggingface/transformers/blob/07bdbeb/src/transformers/models/llava/modeling_llava.py#L430
before going through the multi modal projector.
Args: Args:
input_ids: Flattened (concatenated) input_ids corresponding to a input_ids: Flattened (concatenated) input_ids corresponding to a
batch. batch.
pixel_values: For PIXEL_VALUES, expects a batch with shape pixel_values: The pixels in each input image.
[1, 3, 336, 336]. Expects a batch with shape `[1, 3, 336, 336]`.
image_features: For IMAGE_FEATURES, expects a batch with shape (Only applicable to `PIXEL_VALUES` mode)
[1, 576, 1024]. image_features: The image features for each input image outputted by
the vision tower before passing to the multi-modal projector.
Expects a batch with shape `[1, 576, 1024]`.
(Only applicable to `IMAGE_FEATURES` mode)
See also:
Each input maps to huggingface implementation, as follows:
- `pixel_values`: https://github.com/huggingface/transformers/blob/v4.41.1/src/transformers/models/llava/modeling_llava.py#L360
- `image_features`: https://github.com/huggingface/transformers/blob/v4.41.1/src/transformers/models/llava/modeling_llava.py#L437
""" """
image_input = self._parse_and_validate_image_input(**kwargs) image_input = self._parse_and_validate_image_input(**kwargs)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment