Commit 22890a8e authored by zhangqha's avatar zhangqha
Browse files

Merge branch 'v0.15.1-dev-pd' into 'v0.15.1-dev'

Merge v0.15.1-dev-pd into v0.15.1-dev

See merge request dcutoolkit/deeplearing/vllm!506
parents b5ca585e be81eaf6
...@@ -4,35 +4,87 @@ ...@@ -4,35 +4,87 @@
import os import os
import socket import socket
import threading import threading
import time
import uuid import uuid
from typing import Any
import aiohttp import aiohttp
import msgpack import msgpack
import zmq import zmq
from typing import Any
from quart import Quart, make_response, request from quart import Quart, make_response, request
from dataclasses import dataclass, field
from vllm.distributed.device_communicators.pynccl_wrapper import NCCLLibrary
from collections import deque, defaultdict
import logging
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)
# @dataclass
# class Request:
# request_id: str
# p_http_address: str = ""
# p_dp_rank: int = -1
# d_http_address: str = ""
# d_dp_rank: int = -1
@dataclass
class Instance:
ins_type: str = "P"
http_address: str = ""
zmq_address: str = ""
p_unique_id: bytes = b""
dp_size: int = 0
pp_size: int = 0
tp_size: int = 0
# [dp, pp, tp] : zmq_address
rank_table: dict[int, dict[int, dict[int, str]]] = field(
default_factory=lambda: defaultdict(lambda: defaultdict(dict))
)
# [dp, pp, tp] : global rank
comm_rank_table: dict[int, dict[int, dict[int, int]]] = field(
default_factory=lambda: defaultdict(lambda: defaultdict(dict))
)
count = 0 def count_rank_table_elements(self):
prefill_instances: dict[str, Any] = {} # http_address: (zmq_address, stamp) count = 0
decode_instances: dict[str, Any] = {} # http_address: (zmq_address, stamp) for first_dict in self.rank_table.values():
for second_dict in first_dict.values():
count += len(second_dict)
return count
def is_ready(self):
world_size = self.dp_size * self.pp_size * self.tp_size
inited_rank = self.count_rank_table_elements()
all_ranks_ready = world_size and inited_rank == world_size
if self.ins_type == "P" :
logger.info(f"""[Router] P is_ready? : {self.http_address} world_size = {world_size} inited_rank = {inited_rank}""")
# return all_ranks_ready and self.p_unique_id != b""
return all_ranks_ready
else :
logger.info(f"""[Router] D is_ready? : {self.http_address} world_size = {world_size} inited_rank = {inited_rank}""")
return all_ranks_ready
prefill_cv = threading.Condition() count = 0
decode_cv = threading.Condition() # prefill_instances: dict[str, str] = {} # http_address: zmq_address
# decode_instances: dict[str, str] = {} # http_address: zmq_address
prefill_instances: dict[str, Instance] = {}
decode_instances: dict[str, Instance] = {}
DEFAULT_PING_SECONDS = 5 pending_prefill_ins: list[str] = []
pending_decode_ins: list[str] = []
ready_prefill_ins: list[str] = []
ready_decode_ins: list[str] = []
pd_pair : dict[str, bytes] = {}
router_nccl = NCCLLibrary()
def _remove_oldest_instances(instances: dict[str, Any]) -> None: prefill_cv = threading.Condition()
oldest_key = next(iter(instances), None) decode_cv = threading.Condition()
while oldest_key is not None: instance_cv = threading.Condition()
value = instances[oldest_key]
if value[1] > time.time():
break
print(f"🔴Remove [HTTP:{oldest_key}, ZMQ:{value[0]}, stamp:{value[1]}]")
instances.pop(oldest_key, None)
oldest_key = next(iter(instances), None)
sock_cache : dict[str, Any] = {}
def _listen_for_register(poller, router_socket): def _listen_for_register(poller, router_socket):
while True: while True:
...@@ -42,47 +94,81 @@ def _listen_for_register(poller, router_socket): ...@@ -42,47 +94,81 @@ def _listen_for_register(poller, router_socket):
# data: {"type": "P", "http_address": "ip:port", # data: {"type": "P", "http_address": "ip:port",
# "zmq_address": "ip:port"} # "zmq_address": "ip:port"}
data = msgpack.loads(message) data = msgpack.loads(message)
global prefill_instances
global instance_cv
global decode_instances
if data["type"] == "P": if data["type"] == "P":
global prefill_instances with instance_cv:
global prefill_cv if data["http_address"] not in prefill_instances:
with prefill_cv: prefill_instances[data["http_address"]] = Instance(http_address=data["http_address"])
node = prefill_instances.get(data["http_address"], None) p_instance = prefill_instances[data["http_address"]]
prefill_instances[data["http_address"]] = ( p_instance.rank_table[int(data["dp_rank"])][int(data["pp_rank"])][int(data["tp_rank"])] = data["zmq_address"]
data["zmq_address"], if p_instance.is_ready():
time.time() + DEFAULT_PING_SECONDS, pending_prefill_ins.append(p_instance.http_address)
) logger.info(f"""[Router] pending_prefill_ins appended {p_instance.http_address} ZMQ:{p_instance.zmq_address}""")
_remove_oldest_instances(prefill_instances) instance_cv.notify()
logger.info(f"""[Router] add P rank [{data["dp_rank"]}, {data["pp_rank"]}, {data["tp_rank"]}] : {data["zmq_address"]}""")
elif data["type"] == "D": elif data["type"] == "D":
global decode_instances with instance_cv:
global decode_cv if data["http_address"] not in decode_instances:
with decode_cv: decode_instances[data["http_address"]] = Instance(ins_type="D", http_address=data["http_address"])
node = decode_instances.get(data["http_address"], None) d_instance = decode_instances[data["http_address"]]
decode_instances[data["http_address"]] = ( d_instance.rank_table[int(data["dp_rank"])][int(data["pp_rank"])][int(data["tp_rank"])] = data["zmq_address"]
data["zmq_address"], if d_instance.is_ready():
time.time() + DEFAULT_PING_SECONDS, pending_decode_ins.append(d_instance.http_address)
) logger.info(f"""[Router] pending_decode_ins appended {d_instance.http_address} ZMQ:{d_instance.zmq_address}""")
_remove_oldest_instances(decode_instances) instance_cv.notify()
logger.info(f"""[Router] add D rank [{data["dp_rank"]}, {data["pp_rank"]}, {data["tp_rank"]}] : {data["zmq_address"]}""")
elif data["type"] == "P_init":
with instance_cv:
if data["http_address"] not in prefill_instances:
prefill_instances[data["http_address"]] = Instance(http_address=data["http_address"], dp_size=int(data["dp_size"]), pp_size=int(data["pp_size"]), tp_size=int(data["tp_size"]))
prefill_instances[data["http_address"]].zmq_address = data["zmq_address"]
continue
p_instance = prefill_instances[data["http_address"]]
p_instance.dp_size=int(data["dp_size"])
p_instance.pp_size=int(data["pp_size"])
p_instance.tp_size=int(data["tp_size"])
p_instance.zmq_address=data["zmq_address"]
if p_instance.is_ready():
pending_prefill_ins.append(p_instance.http_address)
logger.info(f"""[Router] pending_prefill_ins appended {p_instance.http_address} ZMQ:{p_instance.zmq_address}""")
instance_cv.notify()
elif data["type"] == "D_init":
with instance_cv:
if data["http_address"] not in decode_instances:
decode_instances[data["http_address"]] = Instance(ins_type="D", http_address=data["http_address"], dp_size=int(data["dp_size"]), pp_size=int(data["pp_size"]), tp_size=int(data["tp_size"]))
decode_instances[data["http_address"]].zmq_address = data["zmq_address"]
continue
d_instance = decode_instances[data["http_address"]]
d_instance.dp_size=int(data["dp_size"])
d_instance.pp_size=int(data["pp_size"])
d_instance.tp_size=int(data["tp_size"])
d_instance.zmq_address=data["zmq_address"]
if d_instance.is_ready():
pending_decode_ins.append(d_instance.http_address)
logger.info(f"""[Router] pending_decode_ins appended {d_instance.http_address} ZMQ:{d_instance.zmq_address}""")
instance_cv.notify()
else: else:
print( print(
"Unexpected, Received message from %s, data: %s", "Unexpected, Received message from %s, data: %s",
remote_address, remote_address,
data, data,
) )
return
if node is None:
print(f"🔵Add [HTTP:{data['http_address']}, ZMQ:{data['zmq_address']}]")
zmq_context = None
def start_service_discovery(hostname, port): def start_service_discovery(hostname, port):
if not hostname: if not hostname:
hostname = socket.gethostname() hostname = socket.gethostname()
if port == 0: if port == 0:
raise ValueError("Port cannot be 0") raise ValueError("Port cannot be 0")
context = zmq.Context() # context = zmq.Context()
router_socket = context.socket(zmq.ROUTER) # router_socket = context.socket(zmq.ROUTER)
global zmq_context
zmq_context = zmq.Context()
router_socket = zmq_context.socket(zmq.ROUTER)
router_socket.bind(f"tcp://{hostname}:{port}") router_socket.bind(f"tcp://{hostname}:{port}")
poller = zmq.Poller() poller = zmq.Poller()
...@@ -120,8 +206,110 @@ async def forward_request(url, data, request_id): ...@@ -120,8 +206,110 @@ async def forward_request(url, data, request_id):
yield content yield content
def unique_id_dispatch(prefill_instance : str,
decode_instance : str) :
global zmq_context
global sock_cache
global router_nccl
global pd_pair
pd_pair_id = prefill_instance.zmq_address + "_" + decode_instance.zmq_address
if pd_pair_id in pd_pair:
logger.info(f"""[Router] pd pair {pd_pair_id} already exist""")
return
logger.info(f"""[Router] initing pd pair {pd_pair_id}""")
unique_id = router_nccl.ncclGetUniqueId()
unique_id = bytes(unique_id.internal)
rank = 0
p_rank_num = prefill_instance.dp_size * prefill_instance.pp_size * prefill_instance.tp_size
d_rank_num = decode_instance.dp_size * decode_instance.pp_size * decode_instance.tp_size
world_size = p_rank_num + d_rank_num
for dp_rank in range(prefill_instance.dp_size):
for pp_rank in range(prefill_instance.pp_size):
for tp_rank in range(prefill_instance.tp_size):
if prefill_instance.rank_table[dp_rank][pp_rank][tp_rank] not in sock_cache:
sock = zmq_context.socket(zmq.DEALER)
sock.setsockopt_string(zmq.IDENTITY, "router")
sock.connect(f"tcp://{prefill_instance.rank_table[dp_rank][pp_rank][tp_rank]}")
sock_cache[prefill_instance.rank_table[dp_rank][pp_rank][tp_rank]] = sock
data = {
"cmd": "comm_init",
"pd_pair_id": pd_pair_id,
"unique_id" : unique_id,
"world_size": world_size,
"rank": rank
}
sock_cache[prefill_instance.rank_table[dp_rank][pp_rank][tp_rank]].send(msgpack.dumps(data))
prefill_instance.comm_rank_table[dp_rank][pp_rank][tp_rank] = rank
rank += 1
logger.info(f"""[Router] dispatch unique_id of pd pair {pd_pair_id} to [P] [{dp_rank}, {pp_rank}, {tp_rank}]""")
for dp_rank in range(decode_instance.dp_size):
for pp_rank in range(decode_instance.pp_size):
for tp_rank in range(decode_instance.tp_size):
if decode_instance.rank_table[dp_rank][pp_rank][tp_rank] not in sock_cache:
sock = zmq_context.socket(zmq.DEALER)
sock.setsockopt_string(zmq.IDENTITY, "router")
sock.connect(f"tcp://{decode_instance.rank_table[dp_rank][pp_rank][tp_rank]}")
sock_cache[decode_instance.rank_table[dp_rank][pp_rank][tp_rank]] = sock
data = {
"cmd": "comm_init",
"pd_pair_id": pd_pair_id,
"unique_id" : unique_id,
"world_size": world_size,
"rank": rank
}
sock_cache[decode_instance.rank_table[dp_rank][pp_rank][tp_rank]].send(msgpack.dumps(data))
decode_instance.comm_rank_table[dp_rank][pp_rank][tp_rank] = rank
rank += 1
logger.info(f"""[Router] dispatch unique_id of pd pair {pd_pair_id} to [D] [{dp_rank}, {pp_rank}, {tp_rank}]""")
pd_pair[pd_pair_id] = unique_id
def pd_pair_init():
global prefill_instances
global decode_instances
global pending_prefill_ins
global pending_decode_ins
global ready_prefill_ins
global ready_decode_ins
global instance_cv
while True:
with instance_cv:
while len(pending_prefill_ins) == 0 and len(pending_decode_ins) == 0:
logger.info(f"""[Router] pd_pair_init: waiting for instance_cv""")
instance_cv.wait()
logger.info(f"""[Router] pd_pair_init: instance_cv finished waiting""")
while pending_prefill_ins:
p_ins = pending_prefill_ins[0]
logger.info(f"""[Router] pd_pair_init: processing {p_ins} from pending_prefill_ins""")
for d_ins in ready_decode_ins:
unique_id_dispatch(prefill_instances[p_ins], decode_instances[d_ins])
ready_prefill_ins.append(p_ins)
pending_prefill_ins.remove(p_ins)
while pending_decode_ins:
d_ins = pending_decode_ins[0]
logger.info(f"""[Router] pd_pair_init: processing {d_ins} from pending_decode_ins""")
for p_ins in ready_prefill_ins:
unique_id_dispatch(prefill_instances[p_ins], decode_instances[d_ins])
ready_decode_ins.append(d_ins)
pending_decode_ins.remove(d_ins)
def start_pd_pair_init():
_thread = threading.Thread(
target=pd_pair_init, daemon=True
)
_thread.start()
return _thread
@app.route("/v1/completions", methods=["POST"]) @app.route("/v1/completions", methods=["POST"])
@app.route("/v1/chat/completions", methods=["POST"])
async def handle_request(): async def handle_request():
try: try:
original_request_data = await request.get_json() original_request_data = await request.get_json()
...@@ -129,45 +317,42 @@ async def handle_request(): ...@@ -129,45 +317,42 @@ async def handle_request():
prefill_request = original_request_data.copy() prefill_request = original_request_data.copy()
# change max_tokens = 1 to let it only do prefill # change max_tokens = 1 to let it only do prefill
prefill_request["max_tokens"] = 1 prefill_request["max_tokens"] = 1
if "max_completion_tokens" in prefill_request:
prefill_request["max_completion_tokens"] = 1
global count global count
global prefill_instances global prefill_instances
global prefill_cv global prefill_cv
with prefill_cv: with prefill_cv:
prefill_list = list(prefill_instances.items()) prefill_list = list(prefill_instances.items())
prefill_addr, prefill_zmq_addr = prefill_list[count % len(prefill_list)] prefill_addr, prefill_instance = prefill_list[count % len(prefill_list)]
prefill_zmq_addr = prefill_zmq_addr[0]
global decode_instances global decode_instances
global decode_cv global decode_cv
with decode_cv: with decode_cv:
decode_list = list(decode_instances.items()) decode_list = list(decode_instances.items())
decode_addr, decode_zmq_addr = decode_list[count % len(decode_list)] decode_addr, decode_instance = decode_list[count % len(decode_list)]
decode_zmq_addr = decode_zmq_addr[0]
print( print(
f"handle_request count: {count}, [HTTP:{prefill_addr}, " f"handle_request count: {count}, [HTTP:{prefill_addr}, "
f"ZMQ:{prefill_zmq_addr}] 👉 [HTTP:{decode_addr}, " f"ZMQ:{prefill_instance.zmq_address}] 👉 [HTTP:{decode_addr}, "
f"ZMQ:{decode_zmq_addr}]" f"ZMQ:{decode_instance.zmq_address}]"
) )
count += 1 count += 1
request_id = ( request_id = (
f"___prefill_addr_{prefill_zmq_addr}___decode_addr_" f"___prefill_addr_{prefill_instance.zmq_address}___decode_addr_"
f"{decode_zmq_addr}_{random_uuid()}" f"{decode_instance.zmq_address}_{random_uuid()}"
) )
# finish prefill # finish prefill
async for _ in forward_request( async for _ in forward_request(
f"http://{prefill_addr}{request.path}", prefill_request, request_id f"http://{prefill_addr}/v1/completions", prefill_request, request_id
): ):
continue continue
# return decode # return decode
generator = forward_request( generator = forward_request(
f"http://{decode_addr}{request.path}", original_request_data, request_id f"http://{decode_addr}/v1/completions", original_request_data, request_id
) )
response = await make_response(generator) response = await make_response(generator)
response.timeout = None response.timeout = None
...@@ -186,5 +371,7 @@ async def handle_request(): ...@@ -186,5 +371,7 @@ async def handle_request():
if __name__ == "__main__": if __name__ == "__main__":
t = start_service_discovery("0.0.0.0", 30001) t = start_service_discovery("0.0.0.0", 30001)
t_1 = start_pd_pair_init()
app.run(host="0.0.0.0", port=10001) app.run(host="0.0.0.0", port=10001)
t.join() t.join()
t_1.join()
...@@ -155,6 +155,11 @@ KVConnectorFactory.register_connector( ...@@ -155,6 +155,11 @@ KVConnectorFactory.register_connector(
"P2pNcclConnector", "P2pNcclConnector",
) )
KVConnectorFactory.register_connector(
"DuSwiftConnector",
"vllm.distributed.kv_transfer.kv_connector.v1.du.du_swift_connector",
"DuSwiftConnector")
KVConnectorFactory.register_connector( KVConnectorFactory.register_connector(
"LMCacheConnectorV1", "LMCacheConnectorV1",
"vllm.distributed.kv_transfer.kv_connector.v1.lmcache_connector", "vllm.distributed.kv_transfer.kv_connector.v1.lmcache_connector",
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import atexit
import ctypes
import math
from dataclasses import dataclass
import torch
from vllm.logger import init_logger
logger = init_logger(__name__)
@dataclass
class MemoryBlock:
size: int
addr: int
"""A memory pool for managing pinned host memory allocations for tensors.
This class implements a buddy allocation system to efficiently manage pinned
host memory for tensor storage. It supports allocation, deallocation, and
tensor storage/retrieval operations.
Key Features:
- Uses power-of-two block sizes for efficient buddy allocation
- Supports splitting and merging of memory blocks
- Provides methods to store CUDA tensors in pinned host memory
- Allows loading tensors from pinned memory back to device
- Automatically cleans up memory on destruction
Attributes:
max_block_size (int): Maximum block size (rounded to nearest power of two)
min_block_size (int): Minimum block size (rounded to nearest power of two)
free_lists (dict): Dictionary of free memory blocks by size
allocated_blocks (dict): Dictionary of currently allocated blocks
base_tensor (torch.Tensor): Base pinned memory tensor
base_address (int): Base memory address of the pinned memory region
Example:
>>> pool = TensorMemoryPool(max_block_size=1024*1024)
>>> tensor = torch.randn(100, device='cuda')
>>> addr = pool.store_tensor(tensor)
>>> loaded_tensor = pool.load_tensor(addr, tensor.dtype,
... tensor.shape, 'cuda')
>>> pool.free(addr)
"""
class TensorMemoryPool:
"""Initializes the memory pool with given size constraints.
Args:
max_block_size (int): Maximum size of memory blocks to manage
min_block_size (int, optional): Minimum size of memory blocks
to manage. Defaults to 512.
Raises:
ValueError: If block sizes are invalid or max_block_size is less
than min_block_size
"""
def __init__(self, max_block_size: int, min_block_size: int = 128):
if max_block_size <= 0 or min_block_size <= 0:
raise ValueError("Block sizes must be positive")
if max_block_size < min_block_size:
raise ValueError(
"Max block size must be greater than min block size")
self.max_block_size = self._round_to_power_of_two(max_block_size)
self.min_block_size = self._round_to_power_of_two(min_block_size)
self.free_lists: dict[int, dict[int, MemoryBlock]] = {}
self.allocated_blocks: dict[int, MemoryBlock] = {}
self._initialize_free_lists()
self._allocate_pinned_memory()
atexit.register(self.cleanup)
def _round_to_power_of_two(self, size: int) -> int:
return 1 << (size - 1).bit_length()
def _initialize_free_lists(self):
size = self.max_block_size
while size >= self.min_block_size:
self.free_lists[size] = {}
size //= 2
def _allocate_pinned_memory(self):
self.base_tensor = torch.empty(self.max_block_size // 4,
dtype=torch.float32,
pin_memory=True)
self.base_address = self.base_tensor.data_ptr()
initial_block = MemoryBlock(size=self.max_block_size,
addr=self.base_address)
self.free_lists[self.max_block_size][
initial_block.addr] = initial_block
logger.debug("TensorMemoryPool, base_address:", self.base_address,
self.base_address % self.max_block_size)
def allocate(self, size: int) -> int:
"""Allocates a memory block of at least the requested size.
Args:
size (int): Minimum size of memory to allocate
Returns:
int: Address of the allocated memory block
Raises:
ValueError: If size is invalid or insufficient memory is available
"""
if size <= 0:
raise ValueError("Allocation size must be positive")
required_size = self._round_to_power_of_two(
max(size, self.min_block_size))
if required_size > self.max_block_size:
raise ValueError("Requested size exceeds maximum block size")
current_size = required_size
while current_size <= self.max_block_size:
if self.free_lists[current_size]:
_, block = self.free_lists[current_size].popitem()
self._split_block(block, required_size)
self.allocated_blocks[block.addr] = block
return block.addr
current_size *= 2
raise ValueError("Insufficient memory")
def _split_block(self, block: MemoryBlock, required_size: int):
while (block.size > required_size
and block.size // 2 >= self.min_block_size):
buddy_size = block.size // 2
buddy_addr = block.addr + buddy_size
buddy = MemoryBlock(size=buddy_size, addr=buddy_addr)
block.size = buddy_size
self.free_lists[buddy_size][buddy.addr] = buddy
def free(self, addr: int):
"""Frees an allocated memory block.
Args:
addr (int): Address of the block to free
Raises:
ValueError: If address is invalid or not allocated
"""
if addr not in self.allocated_blocks:
raise ValueError("Invalid address to free")
block = self.allocated_blocks.pop(addr)
self._merge_buddies(block)
def _merge_buddies(self, block: MemoryBlock):
MAX_MERGE_DEPTH = 30
depth = 0
while depth < MAX_MERGE_DEPTH:
buddy_offset = block.size if (block.addr - self.base_address) % (
2 * block.size) == 0 else -block.size
buddy_addr = block.addr + buddy_offset
buddy = self.free_lists[block.size].get(buddy_addr)
if buddy:
del self.free_lists[buddy.size][buddy.addr]
merged_addr = min(block.addr, buddy.addr)
merged_size = block.size * 2
block = MemoryBlock(size=merged_size, addr=merged_addr)
depth += 1
else:
break
self.free_lists[block.size][block.addr] = block
def store_tensor(self, tensor: torch.Tensor) -> int:
"""Stores a CUDA tensor in pinned host memory.
Args:
tensor (torch.Tensor): CUDA tensor to store
Returns:
int: Address where the tensor is stored
Raises:
ValueError: If tensor is not on CUDA or allocation fails
"""
if not tensor.is_cuda:
raise ValueError("Only CUDA tensors can be stored")
size = tensor.element_size() * tensor.numel()
addr = self.allocate(size)
block = self.allocated_blocks[addr]
if block.size < size:
self.free(addr)
raise ValueError(
f"Allocated block size {block.size} is smaller than "
f"required size {size}")
try:
buffer = (ctypes.c_byte * block.size).from_address(block.addr)
cpu_tensor = torch.frombuffer(buffer,
dtype=tensor.dtype,
count=tensor.numel()).reshape(
tensor.shape)
except ValueError as err:
self.free(addr)
raise ValueError(f"Failed to create tensor view: {err}") from err
cpu_tensor.copy_(tensor)
return addr
def load_tensor(self, addr: int, dtype: torch.dtype,
shape: tuple[int, ...], device) -> torch.Tensor:
"""Loads a tensor from pinned host memory to the specified device.
Args:
addr (int): Address where tensor is stored
dtype (torch.dtype): Data type of the tensor
shape (tuple[int, ...]): Shape of the tensor
device: Target device for the loaded tensor
Returns:
torch.Tensor: The loaded tensor on the specified device
Raises:
ValueError: If address is invalid or sizes don't match
"""
if addr not in self.allocated_blocks:
raise ValueError("Invalid address to load")
block = self.allocated_blocks[addr]
num_elements = math.prod(shape)
dtype_size = torch.tensor([], dtype=dtype).element_size()
required_size = num_elements * dtype_size
if required_size > block.size:
raise ValueError("Requested tensor size exceeds block size")
buffer = (ctypes.c_byte * block.size).from_address(block.addr)
cpu_tensor = torch.frombuffer(buffer, dtype=dtype,
count=num_elements).reshape(shape)
cuda_tensor = torch.empty(shape, dtype=dtype, device=device)
cuda_tensor.copy_(cpu_tensor)
return cuda_tensor
def cleanup(self):
"""Cleans up all memory resources and resets the pool state."""
self.free_lists.clear()
self.allocated_blocks.clear()
if hasattr(self, 'base_tensor'):
del self.base_tensor
def __del__(self):
self.cleanup()
...@@ -47,6 +47,7 @@ if TYPE_CHECKING: ...@@ -47,6 +47,7 @@ if TYPE_CHECKING:
VLLM_TRACE_FUNCTION: int = 0 VLLM_TRACE_FUNCTION: int = 0
VLLM_USE_FLASHINFER_SAMPLER: bool | None = None VLLM_USE_FLASHINFER_SAMPLER: bool | None = None
VLLM_PP_LAYER_PARTITION: str | None = None VLLM_PP_LAYER_PARTITION: str | None = None
VLLM_PP_LAYER_PARTITION_D: Optional[str] = None
VLLM_CPU_KVCACHE_SPACE: int | None = 0 VLLM_CPU_KVCACHE_SPACE: int | None = 0
VLLM_CPU_OMP_THREADS_BIND: str = "" VLLM_CPU_OMP_THREADS_BIND: str = ""
VLLM_CPU_NUM_OF_RESERVED_CPU: int | None = None VLLM_CPU_NUM_OF_RESERVED_CPU: int | None = None
...@@ -181,6 +182,7 @@ if TYPE_CHECKING: ...@@ -181,6 +182,7 @@ if TYPE_CHECKING:
VLLM_XGRAMMAR_CACHE_MB: int = 0 VLLM_XGRAMMAR_CACHE_MB: int = 0
VLLM_MSGPACK_ZERO_COPY_THRESHOLD: int = 256 VLLM_MSGPACK_ZERO_COPY_THRESHOLD: int = 256
VLLM_ALLOW_INSECURE_SERIALIZATION: bool = False VLLM_ALLOW_INSECURE_SERIALIZATION: bool = False
VLLM_DISABLE_REQUEST_ID_RANDOMIZATION: bool = False
VLLM_NIXL_SIDE_CHANNEL_HOST: str = "localhost" VLLM_NIXL_SIDE_CHANNEL_HOST: str = "localhost"
VLLM_NIXL_SIDE_CHANNEL_PORT: int = 5600 VLLM_NIXL_SIDE_CHANNEL_PORT: int = 5600
VLLM_MOONCAKE_BOOTSTRAP_PORT: int = 8998 VLLM_MOONCAKE_BOOTSTRAP_PORT: int = 8998
...@@ -282,6 +284,8 @@ if TYPE_CHECKING: ...@@ -282,6 +284,8 @@ if TYPE_CHECKING:
VLLM_USE_LIGHTOP_MOE_ALIGN: bool = False VLLM_USE_LIGHTOP_MOE_ALIGN: bool = False
VLLM_USE_MERGE_ATTN_STATES_OPT: bool = False VLLM_USE_MERGE_ATTN_STATES_OPT: bool = False
USE_FUSED_RMS_QUANT: bool = False USE_FUSED_RMS_QUANT: bool = False
VLLM_P2P_ASYNC: bool = False
VLLM_P2P_BUF_TOKENS: int = 30000
USE_FUSED_SILU_MUL_QUANT: bool = False USE_FUSED_SILU_MUL_QUANT: bool = False
VLLM_USE_PD_SPLIT: bool = False VLLM_USE_PD_SPLIT: bool = False
VLLM_USE_PP_SYNC: bool = False VLLM_USE_PP_SYNC: bool = False
...@@ -759,6 +763,11 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -759,6 +763,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
else None, else None,
# Pipeline stage partition strategy # Pipeline stage partition strategy
"VLLM_PP_LAYER_PARTITION": lambda: os.getenv("VLLM_PP_LAYER_PARTITION", None), "VLLM_PP_LAYER_PARTITION": lambda: os.getenv("VLLM_PP_LAYER_PARTITION", None),
# Pipeline stage partition strategy
"VLLM_PP_LAYER_PARTITION_D":
lambda: os.getenv("VLLM_PP_LAYER_PARTITION_D", None),
# (CPU backend only) CPU key-value cache space. # (CPU backend only) CPU key-value cache space.
# default is None and will be set as 4 GB # default is None and will be set as 4 GB
"VLLM_CPU_KVCACHE_SPACE": lambda: int(os.getenv("VLLM_CPU_KVCACHE_SPACE", "0")) "VLLM_CPU_KVCACHE_SPACE": lambda: int(os.getenv("VLLM_CPU_KVCACHE_SPACE", "0"))
...@@ -1350,6 +1359,11 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1350,6 +1359,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_ALLOW_INSECURE_SERIALIZATION": lambda: bool( "VLLM_ALLOW_INSECURE_SERIALIZATION": lambda: bool(
int(os.getenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "0")) int(os.getenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "0"))
), ),
# Temporary: skip adding random suffix to internal request IDs. May be
# needed for KV connectors that match request IDs across instances.
"VLLM_DISABLE_REQUEST_ID_RANDOMIZATION": lambda: bool(
int(os.getenv("VLLM_DISABLE_REQUEST_ID_RANDOMIZATION", "1"))
),
# IP address used for NIXL handshake between remote agents. # IP address used for NIXL handshake between remote agents.
"VLLM_NIXL_SIDE_CHANNEL_HOST": lambda: os.getenv( "VLLM_NIXL_SIDE_CHANNEL_HOST": lambda: os.getenv(
"VLLM_NIXL_SIDE_CHANNEL_HOST", "localhost" "VLLM_NIXL_SIDE_CHANNEL_HOST", "localhost"
...@@ -1813,7 +1827,12 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1813,7 +1827,12 @@ environment_variables: dict[str, Callable[[], Any]] = {
# vllm will use rmsquant fused op # vllm will use rmsquant fused op
"USE_FUSED_RMS_QUANT": "USE_FUSED_RMS_QUANT":
lambda: bool(int(os.getenv("USE_FUSED_RMS_QUANT", "0"))), lambda: bool(int(os.getenv("USE_FUSED_RMS_QUANT", "0"))),
# vllm pd separation will be used async
"VLLM_P2P_ASYNC":
lambda: bool(int(os.getenv("VLLM_P2P_ASYNC", "0"))),
# pd separation p2p async buf tokens
"VLLM_P2P_BUF_TOKENS":
lambda: int(os.getenv("VLLM_P2P_BUF_TOKENS", "30000")),
# vllm will use silu_mul_quant fused op # vllm will use silu_mul_quant fused op
"USE_FUSED_SILU_MUL_QUANT": "USE_FUSED_SILU_MUL_QUANT":
lambda: (os.getenv("USE_FUSED_SILU_MUL_QUANT", "False").lower() in lambda: (os.getenv("USE_FUSED_SILU_MUL_QUANT", "False").lower() in
......
...@@ -6,6 +6,7 @@ import time ...@@ -6,6 +6,7 @@ import time
from collections.abc import Mapping from collections.abc import Mapping
from typing import Any, Literal, cast from typing import Any, Literal, cast
import vllm.envs as envs
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.exceptions import VLLMValidationError from vllm.exceptions import VLLMValidationError
from vllm.inputs import ( from vllm.inputs import (
...@@ -474,7 +475,14 @@ class InputProcessor: ...@@ -474,7 +475,14 @@ class InputProcessor:
" passed to vLLM; use the request_id field." " passed to vLLM; use the request_id field."
) )
request.external_req_id = request.request_id request.external_req_id = request.request_id
request.request_id = f"{request.external_req_id}-{random_uuid():.8}" if envs.VLLM_DISABLE_REQUEST_ID_RANDOMIZATION:
logger.warning_once(
"VLLM_DISABLE_REQUEST_ID_RANDOMIZATION is set and will be "
"removed in a future release. Duplicate externally-provided "
"request IDs may cause failures and/or subtle correctness errors."
)
else:
request.request_id = f"{request.external_req_id}-{random_uuid():.8}"
def process_inputs( def process_inputs(
self, self,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment