# Copyright 2023-2024 SGLang Team # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Common utilities.""" import base64 import ipaddress import itertools import json import logging import os import pickle import random import re import resource import shutil import signal import socket import subprocess import tempfile import time import warnings from importlib.metadata import PackageNotFoundError, version from io import BytesIO from typing import Any, Callable, Dict, List, Optional, Protocol, Tuple, Union import numpy as np import psutil import requests import torch import torch.distributed as dist import triton import zmq from fastapi.responses import ORJSONResponse from packaging import version as pkg_version from starlette.routing import Mount from torch import nn from torch.func import functional_call from torch.library import Library from torch.profiler import ProfilerActivity, profile, record_function from triton.runtime.cache import ( FileCacheManager, default_cache_dir, default_dump_dir, default_override_dir, ) logger = logging.getLogger(__name__) show_time_cost = False time_infos = {} def is_hip() -> bool: """Return whether it is HIP on the AMD ROCm platform.""" return torch.version.hip is not None def is_flashinfer_available(): """ Check whether flashinfer is available. As of Oct. 6, 2024, it is only available on NVIDIA GPUs. """ if os.environ.get("SGLANG_IS_FLASHINFER_AVAILABLE", "true") == "false": return False return torch.cuda.is_available() and not is_hip() def is_ipv6(address): try: ipaddress.IPv6Address(address) return True except ipaddress.AddressValueError: return False def enable_show_time_cost(): global show_time_cost show_time_cost = True class TimeInfo: def __init__(self, name, interval=0.1, color=0, indent=0): self.name = name self.interval = interval self.color = color self.indent = indent self.acc_time = 0 self.last_acc_time = 0 def check(self): if self.acc_time - self.last_acc_time > self.interval: self.last_acc_time = self.acc_time return True return False def pretty_print(self): print(f"\x1b[{self.color}m", end="") print("-" * self.indent * 2, end="") print(f"{self.name}: {self.acc_time:.3f}s\x1b[0m") def mark_start(name, interval=0.1, color=0, indent=0): global time_infos, show_time_cost if not show_time_cost: return torch.cuda.synchronize() if time_infos.get(name, None) is None: time_infos[name] = TimeInfo(name, interval, color, indent) time_infos[name].acc_time -= time.time() def mark_end(name): global time_infos, show_time_cost if not show_time_cost: return torch.cuda.synchronize() time_infos[name].acc_time += time.time() if time_infos[name].check(): time_infos[name].pretty_print() def calculate_time(show=False, min_cost_ms=0.0): def wrapper(func): def inner_func(*args, **kwargs): torch.cuda.synchronize() if show: start_time = time.time() result = func(*args, **kwargs) torch.cuda.synchronize() if show: cost_time = (time.time() - start_time) * 1000 if cost_time > min_cost_ms: print(f"Function {func.__name__} took {cost_time} ms to run.") return result return inner_func return wrapper def get_available_gpu_memory(device, gpu_id, distributed=False): """ Get available memory for cuda:gpu_id device. When distributed is True, the available memory is the minimum available memory of all GPUs. """ if device == "cuda": num_gpus = torch.cuda.device_count() assert gpu_id < num_gpus if torch.cuda.current_device() != gpu_id: print( f"WARNING: current device is not {gpu_id}, but {torch.cuda.current_device()}, ", "which may cause useless memory allocation for torch CUDA context.", ) torch.cuda.empty_cache() free_gpu_memory, _ = torch.cuda.mem_get_info(gpu_id) elif device == "xpu": num_gpus = torch.xpu.device_count() assert gpu_id < num_gpus if torch.xpu.current_device() != gpu_id: print( f"WARNING: current device is not {gpu_id}, but {torch.xpu.current_device()}, ", "which may cause useless memory allocation for torch XPU context.", ) torch.xpu.empty_cache() used_memory = torch.xpu.memory_allocated() total_gpu_memory = torch.xpu.get_device_properties(gpu_id).total_memory free_gpu_memory = total_gpu_memory - used_memory if distributed: tensor = torch.tensor(free_gpu_memory, dtype=torch.float32).to( torch.device(device, gpu_id) ) torch.distributed.all_reduce(tensor, op=torch.distributed.ReduceOp.MIN) free_gpu_memory = tensor.item() return free_gpu_memory / (1 << 30) def is_pin_memory_available() -> bool: return torch.cuda.is_available() _CPU_OFFLOAD_BYTES = 0 _CPU_OFFLOAD_MAX_BYTES = 0 def set_cpu_offload_max_bytes(max_bytes: int) -> None: global _CPU_OFFLOAD_MAX_BYTES, _CPU_OFFLOAD_BYTES _CPU_OFFLOAD_BYTES = 0 _CPU_OFFLOAD_MAX_BYTES = max_bytes def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module: device = next(module.parameters()).device if device == torch.device("cpu"): return module global _CPU_OFFLOAD_MAX_BYTES, _CPU_OFFLOAD_BYTES if _CPU_OFFLOAD_BYTES >= _CPU_OFFLOAD_MAX_BYTES: return module pin_memory = is_pin_memory_available() # offload parameters to CPU # use pin_memory if possible, which helps cudagraph capture speed offloaded_parameters = False for p in module.parameters(): if _CPU_OFFLOAD_BYTES >= _CPU_OFFLOAD_MAX_BYTES: # we use per-parameter offloading # one module might have some parameters offloaded and some not break # `torch.empty_like` does not support `pin_memory` argument cpu_data = torch.empty_strided( size=p.data.size(), stride=p.data.stride(), dtype=p.data.dtype, layout=p.data.layout, device="cpu", pin_memory=pin_memory, ) cpu_data.copy_(p.data) p.data = cpu_data _CPU_OFFLOAD_BYTES += p.data.numel() * p.data.element_size() offloaded_parameters = True if offloaded_parameters: original_forward = module.forward def forward(*args, **kwargs): module.forward = original_forward device_state = { # here we blindly call `to(device)` # if the parameter is already on the device, it will be a no-op k: v.to(device, non_blocking=True) for k, v in module.state_dict().items() } output = functional_call(module, device_state, args=args, kwargs=kwargs) module.forward = forward return output module.forward = forward return module class LayerFn(Protocol): def __call__(self, layer_id: int, prefix: str) -> torch.nn.Module: ... def make_layers( num_hidden_layers: int, layer_fn: LayerFn, prefix: str = "", ) -> Tuple[int, int, torch.nn.ModuleList]: """Make a list of layers with the given layer function""" modules = torch.nn.ModuleList( [ maybe_offload_to_cpu(layer_fn(idx=idx, prefix=f"{prefix}.{idx}")) for idx in range(num_hidden_layers) ] ) return modules def set_random_seed(seed: int) -> None: """Set the random seed for all libraries.""" random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed) def is_port_available(port): """Return whether a port is available.""" with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: try: s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) s.bind(("", port)) s.listen(1) return True except socket.error: return False def decode_video_base64(video_base64): from PIL import Image # Decode the base64 string video_bytes = base64.b64decode(video_base64) # Placeholder for the start indices of each PNG image img_starts = [] frame_format = "PNG" # str(os.getenv('FRAME_FORMAT', "JPEG")) assert frame_format in [ "PNG", "JPEG", ], "FRAME_FORMAT must be either 'PNG' or 'JPEG'" if frame_format == "PNG": # Find each PNG start signature to isolate images i = 0 while i < len(video_bytes) - 7: # Adjusted for the length of the PNG signature # Check if we found the start of a PNG file if ( video_bytes[i] == 0x89 and video_bytes[i + 1] == 0x50 and video_bytes[i + 2] == 0x4E and video_bytes[i + 3] == 0x47 and video_bytes[i + 4] == 0x0D and video_bytes[i + 5] == 0x0A and video_bytes[i + 6] == 0x1A and video_bytes[i + 7] == 0x0A ): img_starts.append(i) i += 8 # Skip the PNG signature else: i += 1 else: # Find each JPEG start (0xFFD8) to isolate images i = 0 while ( i < len(video_bytes) - 1 ): # Adjusted for the length of the JPEG SOI signature # Check if we found the start of a JPEG file if video_bytes[i] == 0xFF and video_bytes[i + 1] == 0xD8: img_starts.append(i) # Move to the next byte to continue searching for the next image start i += 2 else: i += 1 frames = [] for start_idx in img_starts: # Assuming each image is back-to-back, the end of one image is the start of another # The last image goes until the end of the byte string end_idx = ( img_starts[img_starts.index(start_idx) + 1] if img_starts.index(start_idx) + 1 < len(img_starts) else len(video_bytes) ) img_bytes = video_bytes[start_idx:end_idx] # Convert bytes to a PIL Image img = Image.open(BytesIO(img_bytes)) # Convert PIL Image to a NumPy array frame = np.array(img) # Append the frame to the list of frames frames.append(frame) # Ensure there's at least one frame to avoid errors with np.stack if frames: return np.stack(frames, axis=0), img.size else: return np.array([]), ( 0, 0, ) # Return an empty array and size tuple if no frames were found def load_image(image_file: Union[str, bytes]): from PIL import Image image = image_size = None if isinstance(image_file, bytes): image = Image.open(BytesIO(image_file)) elif image_file.startswith("http://") or image_file.startswith("https://"): timeout = int(os.getenv("REQUEST_TIMEOUT", "3")) response = requests.get(image_file, timeout=timeout) image = Image.open(BytesIO(response.content)) elif image_file.lower().endswith(("png", "jpg", "jpeg", "webp", "gif")): image = Image.open(image_file) elif image_file.startswith("data:"): image_file = image_file.split(",")[1] image = Image.open(BytesIO(base64.b64decode(image_file))) elif image_file.startswith("video:"): image_file = image_file.replace("video:", "") image, image_size = decode_video_base64(image_file) elif isinstance(image_file, str): image = Image.open(BytesIO(base64.b64decode(image_file))) else: raise ValueError(f"Invalid image: {image}") return image, image_size def suppress_other_loggers(): from vllm.logger import logger as vllm_default_logger vllm_default_logger.setLevel(logging.WARN) logging.getLogger("vllm.config").setLevel(logging.ERROR) logging.getLogger("vllm.distributed.device_communicators.pynccl").setLevel( logging.WARN ) logging.getLogger("vllm.distributed.device_communicators.shm_broadcast").setLevel( logging.WARN ) logging.getLogger("vllm.selector").setLevel(logging.WARN) logging.getLogger("vllm.utils").setLevel(logging.ERROR) logging.getLogger("vllm.model_executor.model_loader.loader").setLevel(logging.ERROR) warnings.filterwarnings( "ignore", category=UserWarning, message="The given NumPy array is not writable" ) def assert_pkg_version(pkg: str, min_version: str, message: str): try: installed_version = version(pkg) if pkg_version.parse(installed_version) < pkg_version.parse(min_version): raise Exception( f"{pkg} is installed with version {installed_version}, which " f"is less than the minimum required version {min_version}. " + message ) except PackageNotFoundError: raise Exception( f"{pkg} with minimum required version {min_version} is not installed. " + message ) def kill_parent_process(): """Kill the parent process and all children of the parent process.""" current_process = psutil.Process() parent_process = current_process.parent() kill_child_process( parent_process.pid, include_self=True, skip_pid=current_process.pid ) try: current_process.kill() except psutil.NoSuchProcess: pass def kill_child_process(pid=None, include_self=False, skip_pid=None): """Kill the process and all its children process.""" if pid is None: pid = os.getpid() try: itself = psutil.Process(pid) except psutil.NoSuchProcess: return children = itself.children(recursive=True) for child in children: if child.pid == skip_pid: continue try: child.kill() except psutil.NoSuchProcess: pass if include_self: try: itself.kill() # Sometime processes cannot be killed with SIGKILL (e.g, PID=1 launched by kubernetes), # so we send an additional signal to kill them. itself.send_signal(signal.SIGINT) except psutil.NoSuchProcess: pass def monkey_patch_vllm_model_config(): from vllm.config import ModelConfig if not hasattr(ModelConfig, "_resolve_task"): return def _resolve_task( self, task_option, hf_config, ): supported_tasks = { "generate": True, "embedding": False, } selected_task = "generate" return supported_tasks, selected_task setattr(ModelConfig, "_resolve_task", _resolve_task) def monkey_patch_vllm_p2p_access_check(gpu_id: int): """ Monkey patch the slow p2p access check in vllm. NOTE: We assume the p2p access is always allowed, which can be wrong for some setups. """ import vllm.distributed.device_communicators.custom_all_reduce_utils as tgt setattr(tgt, "gpu_p2p_access_check", lambda *arg, **kwargs: True) vllm_all_gather_backup = None def monkey_patch_vllm_all_gather(reverse: bool = False): """Monkey patch all-gather to remove in-place operations.""" from torch.distributed import _functional_collectives as funcol from vllm.distributed.parallel_state import GroupCoordinator global vllm_all_gather_backup if vllm_all_gather_backup is None: vllm_all_gather_backup = GroupCoordinator.all_gather def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor: world_size = self.world_size # Bypass the function if we are using only 1 GPU. if world_size == 1: return input_ assert ( -input_.dim() <= dim < input_.dim() ), f"Invalid dim ({dim}) for input tensor with shape {input_.size()}" if dim < 0: # Convert negative dim to positive. dim += input_.dim() input_size = input_.size() # Allocate output tensor. output_tensor = torch.empty( (world_size,) + input_size, dtype=input_.dtype, device=input_.device ) output_tensor = funcol.all_gather_tensor( input_, gather_dim=0, group=self.device_group ).view((world_size,) + input_size) # Reshape output_tensor = output_tensor.movedim(0, dim) output_tensor = output_tensor.reshape( input_size[:dim] + (world_size * input_size[dim],) + input_size[dim + 1 :] ) return output_tensor if reverse: setattr(GroupCoordinator, "all_gather", vllm_all_gather_backup) else: setattr(GroupCoordinator, "all_gather", all_gather) def maybe_set_triton_cache_manager() -> None: """Set environment variable to tell Triton to use a custom cache manager""" cache_manger = os.environ.get("TRITON_CACHE_MANAGER", None) if cache_manger is None: manager = "sglang.srt.utils:CustomCacheManager" logger.debug("Setting Triton cache manager to: %s", manager) os.environ["TRITON_CACHE_MANAGER"] = manager class CustomCacheManager(FileCacheManager): # Adapted from: https://github.com/tdoublep/vllm/blob/3307522289fdfefe323b6c00d0db696651989a2f/vllm/triton_utils/custom_cache_manager.py def __init__(self, key, override=False, dump=False): self.key = key self.lock_path = None if dump: self.cache_dir = default_dump_dir() self.cache_dir = os.path.join(self.cache_dir, self.key) self.lock_path = os.path.join(self.cache_dir, "lock") os.makedirs(self.cache_dir, exist_ok=True) elif override: self.cache_dir = default_override_dir() self.cache_dir = os.path.join(self.cache_dir, self.key) else: # create cache directory if it doesn't exist self.cache_dir = ( os.getenv("TRITON_CACHE_DIR", "").strip() or default_cache_dir() ) if self.cache_dir: self.cache_dir = f"{self.cache_dir}_{os.getpid()}" self.cache_dir = os.path.join(self.cache_dir, self.key) self.lock_path = os.path.join(self.cache_dir, "lock") os.makedirs(self.cache_dir, exist_ok=True) else: raise RuntimeError("Could not create or locate cache dir") def set_ulimit(target_soft_limit=65535): resource_type = resource.RLIMIT_NOFILE current_soft, current_hard = resource.getrlimit(resource_type) if current_soft < target_soft_limit: try: resource.setrlimit(resource_type, (target_soft_limit, current_hard)) except ValueError as e: logger.warning(f"Fail to set RLIMIT_NOFILE: {e}") def add_api_key_middleware(app, api_key: str): @app.middleware("http") async def authentication(request, call_next): if request.method == "OPTIONS": return await call_next(request) if request.url.path.startswith("/health"): return await call_next(request) if request.headers.get("Authorization") != "Bearer " + api_key: return ORJSONResponse(content={"error": "Unauthorized"}, status_code=401) return await call_next(request) def prepare_model_and_tokenizer(model_path: str, tokenizer_path: str): if "SGLANG_USE_MODELSCOPE" in os.environ: if not os.path.exists(model_path): from modelscope import snapshot_download model_path = snapshot_download(model_path) tokenizer_path = snapshot_download( tokenizer_path, ignore_patterns=["*.bin", "*.safetensors"] ) return model_path, tokenizer_path def configure_logger(server_args, prefix: str = ""): format = f"[%(asctime)s{prefix}] %(message)s" # format = f"[%(asctime)s.%(msecs)03d{prefix}] %(message)s" logging.basicConfig( level=getattr(logging, server_args.log_level.upper()), format=format, datefmt="%Y-%m-%d %H:%M:%S", force=True, ) # source: https://github.com/vllm-project/vllm/blob/93b38bea5dd03e1b140ca997dfaadef86f8f1855/vllm/lora/utils.py#L9 def replace_submodule( model: nn.Module, module_name: str, new_module: nn.Module ) -> nn.Module: """Replace a submodule in a model with a new module.""" parent = model.get_submodule(".".join(module_name.split(".")[:-1])) target_name = module_name.split(".")[-1] setattr(parent, target_name, new_module) return new_module def set_weight_attrs( weight: torch.Tensor, weight_attrs: Optional[Dict[str, Any]], ): """Set attributes on a weight tensor. This method is used to set attributes on a weight tensor. This method will not overwrite existing attributes. Args: weight: The weight tensor. weight_attrs: A dictionary of attributes to set on the weight tensor. """ if weight_attrs is None: return for key, value in weight_attrs.items(): assert not hasattr(weight, key), f"Overwriting existing tensor attribute: {key}" setattr(weight, key, value) def broadcast_pyobj( data: List[Any], rank: int, dist_group: Optional[torch.distributed.ProcessGroup] = None, ): """Broadcast inputs from rank=0 to all other ranks with torch.dist backend.""" if rank == 0: if len(data) == 0: tensor_size = torch.tensor([0], dtype=torch.long) dist.broadcast(tensor_size, src=0, group=dist_group) else: serialized_data = pickle.dumps(data) size = len(serialized_data) tensor_data = torch.ByteTensor( np.frombuffer(serialized_data, dtype=np.uint8) ) tensor_size = torch.tensor([size], dtype=torch.long) dist.broadcast(tensor_size, src=0, group=dist_group) dist.broadcast(tensor_data, src=0, group=dist_group) return data else: tensor_size = torch.tensor([0], dtype=torch.long) dist.broadcast(tensor_size, src=0, group=dist_group) size = tensor_size.item() if size == 0: return [] tensor_data = torch.empty(size, dtype=torch.uint8) dist.broadcast(tensor_data, src=0, group=dist_group) serialized_data = bytes(tensor_data.cpu().numpy()) data = pickle.loads(serialized_data) return data step_counter = 0 def pytorch_profile(name, func, *args, data_size=-1): """ Args: name (string): the name of recorded function. func: the function to be profiled. args: the arguments of the profiled function. data_size (int): some measurement of the computation complexity. Usually, it could be the batch size. """ global step_counter os.makedirs("trace", exist_ok=True) with profile( activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], # schedule=torch.profiler.schedule(wait=1, warmup=1, active=3, repeat=2), # on_trace_ready=tensorboard_trace_handler('./log_dir'), record_shapes=True, profile_memory=True, with_stack=True, ) as prof: with record_function(name): with open(f"trace/size_{step_counter}.json", "w") as f: json.dump({"size": data_size}, f) result = func(*args) prof.export_chrome_trace(f"trace/{name}_{step_counter}.json") step_counter += 1 return result def first_rank_print(*args, **kwargs): if torch.cuda.current_device() == 0: print(*args, **kwargs) else: pass def get_zmq_socket(context: zmq.Context, socket_type: zmq.SocketType, endpoint: str): mem = psutil.virtual_memory() total_mem = mem.total / 1024**3 available_mem = mem.available / 1024**3 if total_mem > 32 and available_mem > 16: buf_size = int(0.5 * 1024**3) else: buf_size = -1 socket = context.socket(socket_type) if socket_type == zmq.PUSH: socket.setsockopt(zmq.SNDHWM, 0) socket.setsockopt(zmq.SNDBUF, buf_size) socket.connect(f"ipc://{endpoint}") elif socket_type == zmq.PULL: socket.setsockopt(zmq.RCVHWM, 0) socket.setsockopt(zmq.RCVBUF, buf_size) socket.bind(f"ipc://{endpoint}") else: raise ValueError(f"Unsupported socket type: {socket_type}") return socket def dump_to_file(dirpath, name, value): from vllm.distributed import get_tensor_model_parallel_rank if get_tensor_model_parallel_rank() != 0: return os.makedirs(dirpath, exist_ok=True) if value.dtype is torch.bfloat16: value = value.float() value = value.cpu().numpy() output_filename = os.path.join(dirpath, f"pytorch_dump_{name}.npy") logger.info(f"Dump a tensor to {output_filename}. Shape = {value.shape}") np.save(output_filename, value) def is_triton_3(): return triton.__version__.startswith("3.") def maybe_torch_compile(*args, **kwargs): """ torch.compile does not work for triton 2.2.0, which is needed in xlm1's jax. Therefore, we disable it here. """ def decorator(func): if is_triton_3(): return torch.compile(*args, **kwargs)(func) return func return decorator def delete_directory(dirpath): try: # This will remove the directory and all its contents shutil.rmtree(dirpath) except OSError as e: print(f"Warning: {dirpath} : {e.strerror}") # Temporary directory for prometheus multiprocess mode # Cleaned up automatically when this object is garbage collected prometheus_multiproc_dir: tempfile.TemporaryDirectory def set_prometheus_multiproc_dir(): # Set prometheus multiprocess directory # sglang uses prometheus multiprocess mode # we need to set this before importing prometheus_client # https://prometheus.github.io/client_python/multiprocess/ global prometheus_multiproc_dir if "PROMETHEUS_MULTIPROC_DIR" in os.environ: logger.debug("User set PROMETHEUS_MULTIPROC_DIR detected.") prometheus_multiproc_dir = tempfile.TemporaryDirectory( dir=os.environ["PROMETHEUS_MULTIPROC_DIR"] ) else: prometheus_multiproc_dir = tempfile.TemporaryDirectory() os.environ["PROMETHEUS_MULTIPROC_DIR"] = prometheus_multiproc_dir.name logger.debug(f"PROMETHEUS_MULTIPROC_DIR: {os.environ['PROMETHEUS_MULTIPROC_DIR']}") def add_prometheus_middleware(app): # We need to import prometheus_client after setting the env variable `PROMETHEUS_MULTIPROC_DIR` from prometheus_client import CollectorRegistry, make_asgi_app, multiprocess registry = CollectorRegistry() multiprocess.MultiProcessCollector(registry) metrics_route = Mount("/metrics", make_asgi_app(registry=registry)) # Workaround for 307 Redirect for /metrics metrics_route.path_regex = re.compile("^/metrics(?P.*)$") app.routes.append(metrics_route) def bind_port(port): """Bind to a specific port, assuming it's available.""" sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) # Allows address reuse sock.bind(("", port)) sock.listen(1) return sock def get_amdgpu_memory_capacity(): try: # Run rocm-smi and capture the output result = subprocess.run( ["rocm-smi --showmeminfo vram | grep 'Total Memory' | awk '{print $NF}'"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True, text=True, ) if result.returncode != 0: raise RuntimeError(f"rocm-smi error: {result.stderr.strip()}") # Parse the output to extract memory values in MiB memory_values = [ float(mem) / 1024 / 1024 for mem in result.stdout.strip().split("\n") if re.match(r"^\d+(\.\d+)?$", mem.strip()) ] if not memory_values: raise ValueError("No GPU memory values found.") # Return the minimum memory value return min(memory_values) except FileNotFoundError: raise RuntimeError( "rocm-smi not found. Ensure AMD ROCm drivers are installed and accessible." ) def get_nvgpu_memory_capacity(): try: # Run nvidia-smi and capture the output result = subprocess.run( ["nvidia-smi", "--query-gpu=memory.total", "--format=csv,noheader,nounits"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, ) if result.returncode != 0: raise RuntimeError(f"nvidia-smi error: {result.stderr.strip()}") # Parse the output to extract memory values memory_values = [ float(mem) for mem in result.stdout.strip().split("\n") if re.match(r"^\d+(\.\d+)?$", mem.strip()) ] if not memory_values: raise ValueError("No GPU memory values found.") # Return the minimum memory value return min(memory_values) except FileNotFoundError: raise RuntimeError( "nvidia-smi not found. Ensure NVIDIA drivers are installed and accessible." ) def crash_on_warnings(): # Crash on warning if we are running CI tests return os.getenv("SGLANG_IS_IN_CI", "false").lower() == "true" def get_device_name(device_id: int = 0) -> str: if hasattr(torch, "cuda") and torch.cuda.is_available(): return torch.cuda.get_device_name(device_id) if hasattr(torch, "hip") and torch.hip.is_available(): return torch.hip.get_device_name(device_id) if hasattr(torch, "xpu") and torch.xpu.is_available(): return torch.xpu.get_device_name(device_id) if hasattr(torch, "hpu") and torch.hpu.is_available(): return torch.hpu.get_device_name(device_id) sglang_lib = Library("sglang", "FRAGMENT") # noqa def direct_register_custom_op( op_name: str, op_func: Callable, mutates_args: List[str], fake_impl: Optional[Callable] = None, target_lib: Optional[Library] = None, ): """ `torch.library.custom_op` can have significant overhead because it needs to consider complicated dispatching logic. This function directly registers a custom op and dispatches it to the CUDA backend. See https://gist.github.com/youkaichao/ecbea9ec9fc79a45d2adce1784d7a9a5 for more details. By default, the custom op is registered to the vLLM library. If you want to register it to a different library, you can pass the library object to the `target_lib` argument. IMPORTANT: the lifetime of the operator is tied to the lifetime of the library object. If you want to bind the operator to a different library, make sure the library object is alive when the operator is used. """ import torch.library if hasattr(torch.library, "infer_schema"): schema_str = torch.library.infer_schema(op_func, mutates_args=mutates_args) else: # for pytorch 2.4 import torch._custom_op.impl schema_str = torch._custom_op.impl.infer_schema(op_func, mutates_args) my_lib = target_lib or sglang_lib my_lib.define(op_name + schema_str) my_lib.impl(op_name, op_func, "CUDA") if fake_impl is not None: my_lib._register_fake(op_name, fake_impl) def gpu_proc_affinity( tp_size: int, nnodes: int, gpu_id: int, ): # current process pid = os.getpid() p = psutil.Process(pid) tp_size_per_node = tp_size // nnodes # total physical cores total_pcores = psutil.cpu_count(logical=False) # physical cores per TP (N.B. more Cores than GPUs on node) num_cores_bind = total_pcores // tp_size_per_node # able to handle multiple DP per node start_cpu_id = (gpu_id * num_cores_bind) % total_pcores end_cpu_id = start_cpu_id + num_cores_bind if psutil.cpu_count() != psutil.cpu_count(logical=False): # HT on upper_cpu_ids = [id for id in range(start_cpu_id, end_cpu_id)] lower_cpu_ids = [id + total_pcores for id in range(start_cpu_id, end_cpu_id)] bind_cpu_ids = list(itertools.chain(upper_cpu_ids, lower_cpu_ids)) else: # HT off bind_cpu_ids = [id for id in range(start_cpu_id, end_cpu_id)] # set cpu_affinity to current process p.cpu_affinity(bind_cpu_ids) logger.info(f"Process {pid} gpu_id {gpu_id} is running on CPUs: {p.cpu_affinity()}")