Commit ec2e17d8 authored by zhuwenwen
Browse files

support numa bind

parent 1d57ec3d
from typing import Any, Dict, List, Optional, Set, Tuple, Union
from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import ExecuteModelRequest, PoolerOutput
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
make_async)
from vllm.worker.worker_base import WorkerWrapperBase
import numa,os
# Bind the current process to a NUMA node.
def bind_to_numa(local_rank):
    """Bind the current process to the NUMA node configured for ``local_rank``.

    The target node is read from the ``VLLM_RANK{local_rank}_NUMA``
    environment variable. When that variable is unset, negative, or not an
    integer, a warning is logged and no binding is performed.

    Args:
        local_rank: Rank used to build the environment variable name.

    Raises:
        ValueError: If the configured node index is greater than
            ``numa.get_max_node()``.
    """
    env_str = f"VLLM_RANK{local_rank}_NUMA"
    # Default to the string "-1" so an unset variable is reported as -1.
    raw_value = os.getenv(env_str, "-1")
    try:
        numa_node = int(raw_value)
    except ValueError:
        # A non-integer value is a misconfiguration: treat it like "unset"
        # instead of crashing worker creation.
        numa_node = -1
    # Unset or misconfigured variable: skip binding.
    # TODO: bind automatically based on the hardware topology.
    if numa_node < 0:
        logger.warning(
            "%s is unset or set incorrectly, vllm will not bind to numa! "
            "%s = %s", env_str, env_str, raw_value)
        return
    # libnuma is only queried once a node has actually been configured.
    max_node = numa.get_max_node()
    if numa_node > max_node:
        raise ValueError(f"NUMA node {numa_node} is not available.")
    numa.bind([numa_node])
logger = init_logger(__name__)
def create_worker(**kwargs):
    """Optionally bind this process to a NUMA node, then build a worker.

    NUMA binding is attempted when the ``VLLM_NUMA_BIND`` environment
    variable parses to an integer greater than 0 (it defaults to 1). The
    worker is created through ``WorkerWrapperBase`` and initialized with
    the same keyword arguments that were passed in.

    Args:
        **kwargs: Worker init arguments. ``vllm_config`` is read with
            ``kwargs.get``; ``local_rank`` is required when NUMA binding
            is enabled.

    Returns:
        The ``worker`` attribute of the initialized wrapper.
    """
    numa_bind_flag = int(os.getenv("VLLM_NUMA_BIND", 1))
    # NOTE(review): the extracted source lost its indentation; the affinity
    # logging below is assumed to belong to this branch - confirm.
    if numa_bind_flag > 0:
        local_rank = kwargs['local_rank']
        # Bind the current process to the configured NUMA node.
        bind_to_numa(local_rank)
        current_pid = os.getpid()
        logger.info(
            "########## %d process(rank%s) is running on CPU(s): %s",
            current_pid, str(local_rank),
            str(os.sched_getaffinity(current_pid)))
        logger.info(
            "########## %d process(rank%s) is running on memnode(s): %s",
            current_pid, str(local_rank), str(numa.get_membind()))

    worker_wrapper = WorkerWrapperBase(vllm_config=kwargs.get("vllm_config"))
    worker_wrapper.init_worker(**kwargs)
    return worker_wrapper.worker
class GPUExecutor(ExecutorBase):
    """Executor that creates one driver worker and forwards calls to it.

    ``_init_executor`` asserts a world size of 1, builds the worker via
    ``create_worker`` and stores it as ``self.driver_worker``; every other
    method delegates to that worker.
    """

    uses_ray: bool = False

    def _init_executor(self) -> None:
        """Create the driver worker, initialize its device, load the model."""
        assert self.parallel_config.world_size == 1, (
            "GPUExecutor only supports single GPU.")

        self.driver_worker = self._create_worker()
        worker = self.driver_worker
        worker.init_device()
        worker.load_model()

    def _get_worker_kwargs(
            self,
            local_rank: int = 0,
            rank: int = 0,
            distributed_init_method: Optional[str] = None) -> Dict[str, Any]:
        """Build the keyword arguments used to initialize a worker.

        When ``distributed_init_method`` is None, one is generated from
        ``get_ip()`` and ``get_open_port()``.
        """
        init_method = distributed_init_method
        if init_method is None:
            init_method = get_distributed_init_method(get_ip(),
                                                      get_open_port())

        worker_kwargs: Dict[str, Any] = {
            "vllm_config": self.vllm_config,
            "local_rank": local_rank,
            "rank": rank,
            "distributed_init_method": init_method,
        }
        # Driver if there is no parallel config, or if the rank is a
        # multiple of the tensor-parallel size.
        worker_kwargs["is_driver_worker"] = (
            (not self.parallel_config)
            or (rank % self.parallel_config.tensor_parallel_size == 0))
        return worker_kwargs

    def _create_worker(self,
                       local_rank: int = 0,
                       rank: int = 0,
                       distributed_init_method: Optional[str] = None):
        """Create a worker from the kwargs produced for the given rank."""
        worker_kwargs = self._get_worker_kwargs(
            local_rank=local_rank,
            rank=rank,
            distributed_init_method=distributed_init_method)
        return create_worker(**worker_kwargs)

    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Ask the driver worker how many KV blocks are available."""
        worker = self.driver_worker
        return worker.determine_num_available_blocks()

    def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks) -> None:
        """Log the block counts, then initialize the worker's KV cache."""
        # NOTE: Logged here rather than in the worker because other
        # executors can have more than one worker.
        logger.info("# GPU blocks: %d, # CPU blocks: %d", num_gpu_blocks,
                    num_cpu_blocks)

        total_token_slots = num_gpu_blocks * self.cache_config.block_size
        max_model_len = self.model_config.max_model_len
        max_concurrency = total_token_slots / max_model_len
        logger.info("Maximum concurrency for %s tokens per request: %.2fx",
                    max_model_len, max_concurrency)

        self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)

    def execute_model(
        self, execute_model_req: ExecuteModelRequest
    ) -> Optional[List[Union[SamplerOutput, PoolerOutput]]]:
        """Run one model step on the driver worker and return its output."""
        return self.driver_worker.execute_model(execute_model_req)

    def add_lora(self, lora_request: LoRARequest) -> bool:
        """Forward a LoRA add request to the driver worker."""
        requested_id = lora_request.lora_int_id
        assert requested_id > 0, "lora_id must be greater than 0."
        return self.driver_worker.add_lora(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        """Forward a LoRA removal to the driver worker."""
        assert lora_id > 0, "lora_id must be greater than 0."
        worker = self.driver_worker
        return worker.remove_lora(lora_id)

    def pin_lora(self, lora_id: int) -> bool:
        """Forward a LoRA pin request to the driver worker."""
        assert lora_id > 0, "lora_id must be greater than 0."
        worker = self.driver_worker
        return worker.pin_lora(lora_id)

    def list_loras(self) -> Set[int]:
        """Return the LoRA ids reported by the driver worker."""
        worker = self.driver_worker
        return worker.list_loras()

    def add_prompt_adapter(
            self, prompt_adapter_request: PromptAdapterRequest) -> bool:
        """Forward a prompt-adapter add request to the driver worker."""
        adapter_id = prompt_adapter_request.prompt_adapter_id
        assert adapter_id > 0, "prompt_adapter_id must be greater than 0."
        return self.driver_worker.add_prompt_adapter(prompt_adapter_request)

    def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
        """Forward a prompt-adapter removal to the driver worker."""
        assert prompt_adapter_id > 0, (
            "prompt_adapter_id must be greater than 0.")
        worker = self.driver_worker
        return worker.remove_prompt_adapter(prompt_adapter_id)

    def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool:
        """Forward a prompt-adapter pin request to the driver worker."""
        assert prompt_adapter_id > 0, (
            "prompt_adapter_id must be greater than 0.")
        worker = self.driver_worker
        return worker.pin_prompt_adapter(prompt_adapter_id)

    def list_prompt_adapters(self) -> Set[int]:
        """Return the prompt-adapter ids reported by the driver worker."""
        worker = self.driver_worker
        return worker.list_prompt_adapters()

    def check_health(self) -> None:
        """No-op health check: nothing is inspected here."""
        # GPUExecutor is considered healthy for as long as it is running.
        return None

    def start_profile(self) -> None:
        """Start profiling on the driver worker."""
        worker = self.driver_worker
        worker.start_profile()

    def stop_profile(self) -> None:
        """Stop profiling on the driver worker."""
        worker = self.driver_worker
        worker.stop_profile()
class GPUExecutorAsync(GPUExecutor, ExecutorAsyncBase):
    """Async variant of ``GPUExecutor``."""

    async def execute_model_async(
        self,
        execute_model_req: ExecuteModelRequest,
    ) -> List[Union[SamplerOutput, PoolerOutput]]:
        """Await the driver worker's ``execute_model`` via ``make_async``."""
        async_execute_model = make_async(self.driver_worker.execute_model)
        return await async_execute_model(execute_model_req=execute_model_req)
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
import dataclasses import dataclasses
import os import os
import numa
import time import time
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union
...@@ -28,6 +29,23 @@ from vllm.worker.model_runner_base import (BroadcastableModelInput, ...@@ -28,6 +29,23 @@ from vllm.worker.model_runner_base import (BroadcastableModelInput,
logger = init_logger(__name__) logger = init_logger(__name__)
# Bind the current process to a NUMA node.
def bind_to_numa(local_rank):
    """Bind the current process to the NUMA node configured for ``local_rank``.

    The target node is read from the ``VLLM_RANK{local_rank}_NUMA``
    environment variable. When that variable is unset, negative, or not an
    integer, a warning is logged and no binding is performed.

    Args:
        local_rank: Rank used to build the environment variable name.

    Raises:
        ValueError: If the configured node index is greater than
            ``numa.get_max_node()``.
    """
    env_str = f"VLLM_RANK{local_rank}_NUMA"
    # Default to the string "-1" so an unset variable is reported as -1.
    raw_value = os.getenv(env_str, "-1")
    try:
        numa_node = int(raw_value)
    except ValueError:
        # A non-integer value is a misconfiguration: treat it like "unset"
        # instead of crashing worker creation.
        numa_node = -1
    # Unset or misconfigured variable: skip binding.
    # TODO: bind automatically based on the hardware topology.
    if numa_node < 0:
        logger.warning(
            "%s is unset or set incorrectly, vllm will not bind to numa! "
            "%s = %s", env_str, env_str, raw_value)
        return
    # libnuma is only queried once a node has actually been configured.
    max_node = numa.get_max_node()
    if numa_node > max_node:
        raise ValueError(f"NUMA node {numa_node} is not available.")
    numa.bind([numa_node])
class WorkerBase(ABC): class WorkerBase(ABC):
"""Worker interface that allows vLLM to cleanly separate implementations for """Worker interface that allows vLLM to cleanly separate implementations for
different hardware. Also abstracts control plane communication, e.g., to different hardware. Also abstracts control plane communication, e.g., to
...@@ -594,6 +612,16 @@ class WorkerWrapperBase: ...@@ -594,6 +612,16 @@ class WorkerWrapperBase:
# To make vLLM config available during worker initialization # To make vLLM config available during worker initialization
self.worker = worker_class(**kwargs) self.worker = worker_class(**kwargs)
assert self.worker is not None assert self.worker is not None
VLLM_NUMA_BIND = int(os.getenv("VLLM_NUMA_BIND", 1))
if VLLM_NUMA_BIND > 0:
# 绑定当前进程到指定 NUMA 节点
bind_to_numa(kwargs['local_rank'])
pid = os.getpid()
logger.info("########## %d process(rank%s) is running on CPU(s): %s", pid, str(kwargs['local_rank']), str(os.sched_getaffinity(pid)))
logger.info("########## %d process(rank%s) is running on memnode(s): %s", pid, str(kwargs['local_rank']), str(numa.get_membind()))
def execute_method(self, method: Union[str, bytes], *args, **kwargs): def execute_method(self, method: Union[str, bytes], *args, **kwargs):
try: try:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment