Commit 294cc23a authored by xiabo
Browse files

解决pd分离非对称切分通信组过多问题

parent 84e5aba2
......@@ -9,15 +9,82 @@ import uuid
import aiohttp
import msgpack
import zmq
from typing import Any
from quart import Quart, make_response, request
from dataclasses import dataclass, field
from vllm.distributed.device_communicators.pynccl_wrapper import NCCLLibrary
from collections import deque, defaultdict
import logging
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)
# @dataclass
# class Request:
# request_id: str
# p_http_address: str = ""
# p_dp_rank: int = -1
# d_http_address: str = ""
# d_dp_rank: int = -1
@dataclass
class Instance:
    """Registration state of one prefill ("P") or decode ("D") instance.

    The sizes describe the instance's parallel layout; the rank tables are
    filled in one entry at a time as individual ranks register.
    """

    ins_type: str = "P"
    http_address: str = ""
    zmq_address: str = ""
    p_unique_id: bytes = b""
    dp_size: int = 0
    pp_size: int = 0
    tp_size: int = 0
    # [dp, pp, tp] : zmq_address
    rank_table: dict[int, dict[int, dict[int, str]]] = field(
        default_factory=lambda: defaultdict(lambda: defaultdict(dict))
    )
    # [dp, pp, tp] : global rank
    comm_rank_table: dict[int, dict[int, dict[int, int]]] = field(
        default_factory=lambda: defaultdict(lambda: defaultdict(dict))
    )

    def count_rank_table_elements(self) -> int:
        """Return how many (dp, pp, tp) entries ``rank_table`` holds."""
        return sum(
            len(tp_entries)
            for pp_entries in self.rank_table.values()
            for tp_entries in pp_entries.values()
        )

    def is_ready(self) -> bool:
        """Return True once every rank of the instance has registered.

        The instance is ready when the sizes are known (their product is
        non-zero) and ``rank_table`` holds exactly ``dp * pp * tp`` entries.
        Always returns a real ``bool`` (the previous expression could leak
        the integer ``0``).
        """
        world_size = self.dp_size * self.pp_size * self.tp_size
        inited_rank = self.count_rank_table_elements()
        # Anything that is not explicitly "P" is reported as "D".
        role = "P" if self.ins_type == "P" else "D"
        logger.info(
            "[Router] %s is_ready? : %s world_size = %s inited_rank = %s",
            role, self.http_address, world_size, inited_rank,
        )
        return world_size > 0 and inited_rank == world_size
count = 0
prefill_instances: dict[str, str] = {} # http_address: zmq_address
decode_instances: dict[str, str] = {} # http_address: zmq_address
# prefill_instances: dict[str, str] = {} # http_address: zmq_address
# decode_instances: dict[str, str] = {} # http_address: zmq_address
# Registered instances keyed by their HTTP address.
prefill_instances: dict[str, Instance] = {}
decode_instances: dict[str, Instance] = {}
# HTTP addresses of instances whose ranks have all registered but which have
# not yet been paired by pd_pair_init.
pending_prefill_ins: list[str] = []
pending_decode_ins: list[str] = []
# HTTP addresses of instances that pd_pair_init has already processed.
ready_prefill_ins: list[str] = []
ready_decode_ins: list[str] = []
# "<prefill zmq_address>_<decode zmq_address>" -> NCCL unique id bytes that
# were dispatched for that pair.
pd_pair : dict[str, bytes] = {}
# Used by unique_id_dispatch to generate NCCL unique ids.
router_nccl = NCCLLibrary()
# Held by handle_request while it snapshots the instance tables.
prefill_cv = threading.Condition()
decode_cv = threading.Condition()
# Guards instance registration and wakes pd_pair_init when an instance
# becomes pending.
instance_cv = threading.Condition()
# Rank zmq_address -> connected DEALER socket, reused across pd pairs.
sock_cache : dict[str, Any] = {}
def _listen_for_register(poller, router_socket):
while True:
......@@ -27,16 +94,61 @@ def _listen_for_register(poller, router_socket):
# data: {"type": "P", "http_address": "ip:port",
# "zmq_address": "ip:port"}
data = msgpack.loads(message)
if data["type"] == "P":
global prefill_instances
global prefill_cv
with prefill_cv:
prefill_instances[data["http_address"]] = data["zmq_address"]
elif data["type"] == "D":
global instance_cv
global decode_instances
global decode_cv
with decode_cv:
decode_instances[data["http_address"]] = data["zmq_address"]
if data["type"] == "P":
with instance_cv:
if data["http_address"] not in prefill_instances:
prefill_instances[data["http_address"]] = Instance(http_address=data["http_address"])
p_instance = prefill_instances[data["http_address"]]
p_instance.rank_table[int(data["dp_rank"])][int(data["pp_rank"])][int(data["tp_rank"])] = data["zmq_address"]
if p_instance.is_ready():
pending_prefill_ins.append(p_instance.http_address)
logger.info(f"""[Router] pending_prefill_ins appended {p_instance.http_address} ZMQ:{p_instance.zmq_address}""")
instance_cv.notify()
logger.info(f"""[Router] add P rank [{data["dp_rank"]}, {data["pp_rank"]}, {data["tp_rank"]}] : {data["zmq_address"]}""")
elif data["type"] == "D":
with instance_cv:
if data["http_address"] not in decode_instances:
decode_instances[data["http_address"]] = Instance(ins_type="D", http_address=data["http_address"])
d_instance = decode_instances[data["http_address"]]
d_instance.rank_table[int(data["dp_rank"])][int(data["pp_rank"])][int(data["tp_rank"])] = data["zmq_address"]
if d_instance.is_ready():
pending_decode_ins.append(d_instance.http_address)
logger.info(f"""[Router] pending_decode_ins appended {d_instance.http_address} ZMQ:{d_instance.zmq_address}""")
instance_cv.notify()
logger.info(f"""[Router] add D rank [{data["dp_rank"]}, {data["pp_rank"]}, {data["tp_rank"]}] : {data["zmq_address"]}""")
elif data["type"] == "P_init":
with instance_cv:
if data["http_address"] not in prefill_instances:
prefill_instances[data["http_address"]] = Instance(http_address=data["http_address"], dp_size=int(data["dp_size"]), pp_size=int(data["pp_size"]), tp_size=int(data["tp_size"]))
prefill_instances[data["http_address"]].zmq_address = data["zmq_address"]
continue
p_instance = prefill_instances[data["http_address"]]
p_instance.dp_size=int(data["dp_size"])
p_instance.pp_size=int(data["pp_size"])
p_instance.tp_size=int(data["tp_size"])
p_instance.zmq_address=data["zmq_address"]
if p_instance.is_ready():
pending_prefill_ins.append(p_instance.http_address)
logger.info(f"""[Router] pending_prefill_ins appended {p_instance.http_address} ZMQ:{p_instance.zmq_address}""")
instance_cv.notify()
elif data["type"] == "D_init":
with instance_cv:
if data["http_address"] not in decode_instances:
decode_instances[data["http_address"]] = Instance(ins_type="D", http_address=data["http_address"], dp_size=int(data["dp_size"]), pp_size=int(data["pp_size"]), tp_size=int(data["tp_size"]))
decode_instances[data["http_address"]].zmq_address = data["zmq_address"]
continue
d_instance = decode_instances[data["http_address"]]
d_instance.dp_size=int(data["dp_size"])
d_instance.pp_size=int(data["pp_size"])
d_instance.tp_size=int(data["tp_size"])
d_instance.zmq_address=data["zmq_address"]
if d_instance.is_ready():
pending_decode_ins.append(d_instance.http_address)
logger.info(f"""[Router] pending_decode_ins appended {d_instance.http_address} ZMQ:{d_instance.zmq_address}""")
instance_cv.notify()
else:
print(
"Unexpected, Received message from %s, data: %s",
......@@ -44,6 +156,7 @@ def _listen_for_register(poller, router_socket):
data,
)
zmq_context = None
def start_service_discovery(hostname, port):
if not hostname:
......@@ -51,8 +164,11 @@ def start_service_discovery(hostname, port):
if port == 0:
raise ValueError("Port cannot be 0")
context = zmq.Context()
router_socket = context.socket(zmq.ROUTER)
# context = zmq.Context()
# router_socket = context.socket(zmq.ROUTER)
global zmq_context
zmq_context = zmq.Context()
router_socket = zmq_context.socket(zmq.ROUTER)
router_socket.bind(f"tcp://{hostname}:{port}")
poller = zmq.Poller()
......@@ -90,6 +206,109 @@ async def forward_request(url, data, request_id):
yield content
def _dispatch_comm_init(instance: "Instance", role: str, pd_pair_id: str,
                        unique_id: bytes, world_size: int,
                        first_rank: int) -> int:
    """Send a ``comm_init`` command to every rank of ``instance``.

    Ranks are visited in dp-major, then pp, then tp order and receive
    consecutive communicator ranks starting at ``first_rank``. Each assigned
    rank is recorded in ``instance.comm_rank_table``.

    Returns:
        The next unassigned communicator rank.
    """
    rank = first_rank
    for dp_rank in range(instance.dp_size):
        for pp_rank in range(instance.pp_size):
            for tp_rank in range(instance.tp_size):
                address = instance.rank_table[dp_rank][pp_rank][tp_rank]
                # One DEALER socket per rank address, shared by all pd pairs.
                sock = sock_cache.get(address)
                if sock is None:
                    sock = zmq_context.socket(zmq.DEALER)
                    sock.setsockopt_string(zmq.IDENTITY, "router")
                    sock.connect(f"tcp://{address}")
                    sock_cache[address] = sock
                data = {
                    "cmd": "comm_init",
                    "pd_pair_id": pd_pair_id,
                    "unique_id": unique_id,
                    "world_size": world_size,
                    "rank": rank,
                }
                sock.send(msgpack.dumps(data))
                instance.comm_rank_table[dp_rank][pp_rank][tp_rank] = rank
                rank += 1
                logger.info(
                    "[Router] dispatch unique_id of pd pair %s to [%s] "
                    "[%s, %s, %s]",
                    pd_pair_id, role, dp_rank, pp_rank, tp_rank,
                )
    return rank


def unique_id_dispatch(prefill_instance: "Instance",
                       decode_instance: "Instance"):
    """Create one NCCL unique id for a P/D pair and send it to all ranks.

    The pair is identified by ``"<prefill zmq_address>_<decode
    zmq_address>"``. If that pair was already initialised the call is a
    no-op. Otherwise a fresh unique id is generated, every prefill rank
    (communicator ranks ``0 .. P-1``) and then every decode rank
    (``P .. P+D-1``) is sent a ``comm_init`` command, and the id is stored
    in ``pd_pair``.

    Args:
        prefill_instance: Fully registered prefill instance.
        decode_instance: Fully registered decode instance.
    """
    pd_pair_id = prefill_instance.zmq_address + "_" + decode_instance.zmq_address
    if pd_pair_id in pd_pair:
        logger.info("[Router] pd pair %s already exist", pd_pair_id)
        return
    logger.info("[Router] initing pd pair %s", pd_pair_id)

    unique_id = bytes(router_nccl.ncclGetUniqueId().internal)
    p_rank_num = (prefill_instance.dp_size * prefill_instance.pp_size
                  * prefill_instance.tp_size)
    d_rank_num = (decode_instance.dp_size * decode_instance.pp_size
                  * decode_instance.tp_size)
    world_size = p_rank_num + d_rank_num

    # Prefill ranks come first; decode ranks continue the numbering.
    next_rank = _dispatch_comm_init(prefill_instance, "P", pd_pair_id,
                                    unique_id, world_size, 0)
    _dispatch_comm_init(decode_instance, "D", pd_pair_id, unique_id,
                        world_size, next_rank)
    # Recorded only after every rank has been sent its command.
    pd_pair[pd_pair_id] = unique_id
# Worker loop that pairs newly completed instances: every pending prefill
# instance is dispatched against each already-ready decode instance, and every
# pending decode instance against each already-ready prefill instance, via
# unique_id_dispatch. Never returns.
# NOTE(review): indentation is not visible in this view, so the exact nesting
# (e.g. whether the "finished waiting" log sits inside the wait loop) should be
# confirmed against the real file before restructuring.
def pd_pair_init():
global prefill_instances
global decode_instances
global pending_prefill_ins
global pending_decode_ins
global ready_prefill_ins
global ready_decode_ins
global instance_cv
while True:
with instance_cv:
# Wait on instance_cv while both pending lists are empty.
while len(pending_prefill_ins) == 0 and len(pending_decode_ins) == 0:
logger.info(f"""[Router] pd_pair_init: waiting for instance_cv""")
instance_cv.wait()
logger.info(f"""[Router] pd_pair_init: instance_cv finished waiting""")
# Drain pending prefill instances, oldest first.
while pending_prefill_ins:
p_ins = pending_prefill_ins[0]
logger.info(f"""[Router] pd_pair_init: processing {p_ins} from pending_prefill_ins""")
for d_ins in ready_decode_ins:
unique_id_dispatch(prefill_instances[p_ins], decode_instances[d_ins])
ready_prefill_ins.append(p_ins)
# p_ins was read from index 0, so this removes the entry just processed.
pending_prefill_ins.remove(p_ins)
# Drain pending decode instances, oldest first.
while pending_decode_ins:
d_ins = pending_decode_ins[0]
logger.info(f"""[Router] pd_pair_init: processing {d_ins} from pending_decode_ins""")
for p_ins in ready_prefill_ins:
unique_id_dispatch(prefill_instances[p_ins], decode_instances[d_ins])
ready_decode_ins.append(d_ins)
# d_ins was read from index 0, so this removes the entry just processed.
pending_decode_ins.remove(d_ins)
def start_pd_pair_init():
    """Start ``pd_pair_init`` on a daemon thread and return that thread."""
    pairing_thread = threading.Thread(target=pd_pair_init, daemon=True)
    pairing_thread.start()
    return pairing_thread
@app.route("/v1/completions", methods=["POST"])
async def handle_request():
try:
......@@ -104,24 +323,25 @@ async def handle_request():
global prefill_cv
with prefill_cv:
prefill_list = list(prefill_instances.items())
prefill_addr, prefill_zmq_addr = prefill_list[count % len(prefill_list)]
prefill_addr, prefill_instance = prefill_list[count % len(prefill_list)]
global decode_instances
global decode_cv
with decode_cv:
decode_list = list(decode_instances.items())
decode_addr, decode_zmq_addr = decode_list[count % len(decode_list)]
decode_addr, decode_instance = decode_list[count % len(decode_list)]
print(
f"handle_request count: {count}, [HTTP:{prefill_addr}, "
f"ZMQ:{prefill_zmq_addr}] 👉 [HTTP:{decode_addr}, "
f"ZMQ:{decode_zmq_addr}]"
f"ZMQ:{prefill_instance.zmq_address}] 👉 [HTTP:{decode_addr}, "
f"ZMQ:{decode_instance.zmq_address}]"
)
count += 1
request_id = (
f"___prefill_addr_{prefill_zmq_addr}___decode_addr_"
f"{decode_zmq_addr}_{random_uuid()}"
f"___prefill_addr_{prefill_instance.zmq_address}___decode_addr_"
f"{decode_instance.zmq_address}_{random_uuid()}"
)
# finish prefill
......@@ -151,5 +371,7 @@ async def handle_request():
if __name__ == "__main__":
    # Bring up the registration listener and the P/D pairing worker before
    # the HTTP server starts accepting requests.
    discovery_thread = start_service_discovery("0.0.0.0", 30001)
    pairing_thread = start_pd_pair_init()
    app.run(host="0.0.0.0", port=10001)
    # Only reached once the HTTP server has returned.
    discovery_thread.join()
    pairing_thread.join()
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
import socket
import threading
import uuid
import aiohttp
import msgpack
import zmq
from quart import Quart, make_response, request
# Request counter; handle_request uses it modulo the number of registered
# instances to pick a prefill/decode pair and increments it per request.
count = 0
# Registered instances, written by _listen_for_register and read by
# handle_request.
prefill_instances: dict[str, str] = {} # http_address: zmq_address
decode_instances: dict[str, str] = {} # http_address: zmq_address
# Held while the corresponding instance table is read or written.
prefill_cv = threading.Condition()
decode_cv = threading.Condition()
def _listen_for_register(poller, router_socket):
    """Receive registration messages forever and record the instances.

    Each message is a msgpack-encoded dict of the form
    ``{"type": "P" | "D", "http_address": "ip:port",
    "zmq_address": "ip:port"}``. Type ``"P"`` updates ``prefill_instances``
    and type ``"D"`` updates ``decode_instances``; any other type is reported
    on stdout and ignored.

    Args:
        poller: Poller on which ``router_socket`` is registered.
        router_socket: Socket the registration messages arrive on.
    """
    while True:
        socks = dict(poller.poll())
        if router_socket not in socks:
            continue

        remote_address, message = router_socket.recv_multipart()
        data = msgpack.loads(message)
        if data["type"] == "P":
            with prefill_cv:
                prefill_instances[data["http_address"]] = data["zmq_address"]
        elif data["type"] == "D":
            with decode_cv:
                decode_instances[data["http_address"]] = data["zmq_address"]
        else:
            # print() does not apply %-style arguments, so format explicitly.
            print(
                f"Unexpected, Received message from {remote_address}, "
                f"data: {data}"
            )
def start_service_discovery(hostname, port):
    """Bind the registration socket and start the listener thread.

    Args:
        hostname: Host to bind to; a falsy value falls back to
            ``socket.gethostname()``.
        port: TCP port to bind to.

    Returns:
        The started daemon thread running ``_listen_for_register``.

    Raises:
        ValueError: If ``port`` is 0.
    """
    hostname = hostname or socket.gethostname()
    if port == 0:
        raise ValueError("Port cannot be 0")

    zmq_ctx = zmq.Context()
    router_socket = zmq_ctx.socket(zmq.ROUTER)
    endpoint = f"tcp://{hostname}:{port}"
    router_socket.bind(endpoint)

    poller = zmq.Poller()
    poller.register(router_socket, zmq.POLLIN)

    listener = threading.Thread(
        target=_listen_for_register,
        args=(poller, router_socket),
        daemon=True,
    )
    listener.start()
    return listener
# Client timeout passed to every aiohttp session in forward_request
# (6 * 60 * 60 = 21600, i.e. six hours if the unit is seconds).
AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)
# Quart application that the proxy routes are registered on.
app = Quart(__name__)
def random_uuid() -> str:
    """Return a random UUID4 as a 32-character hexadecimal string."""
    # UUID.hex is already a str, so no extra str() conversion is needed.
    return uuid.uuid4().hex
async def forward_request(url, data, request_id):
    """POST ``data`` as JSON to ``url`` and stream the response body.

    The request carries a bearer token taken from the ``OPENAI_API_KEY``
    environment variable and ``request_id`` in the ``X-Request-Id`` header.

    Args:
        url: Target endpoint.
        data: JSON-serialisable request payload.
        request_id: Value for the ``X-Request-Id`` header.

    Yields:
        Chunks of the response body (up to 1024 bytes each) when the response
        status is 200. For any other status nothing is yielded and the status
        is reported on stdout.
    """
    headers = {
        "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
        "X-Request-Id": request_id,
    }
    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
        async with session.post(url=url, json=data, headers=headers) as response:
            if response.status != 200:
                # Previously dropped silently; still yields nothing, but the
                # failure is now visible in the proxy output.
                print(
                    f"forward_request to {url} returned HTTP status "
                    f"{response.status}"
                )
                return
            async for chunk_bytes in response.content.iter_chunked(1024):
                yield chunk_bytes
@app.route("/v1/completions", methods=["POST"])
async def handle_request():
    """Run a completion request through a prefill and a decode instance.

    One prefill and one decode instance are chosen from the registered
    instances using the global request counter. The request is first sent to
    the prefill instance with ``max_tokens`` forced to 1 (its output is
    discarded), then the unmodified request is sent to the decode instance
    and its response is streamed back to the caller.

    Returns:
        The streamed decode response; a 503 response when no prefill or no
        decode instance is registered; a 500 response if anything fails.
    """
    global count
    try:
        original_request_data = await request.get_json()

        prefill_request = original_request_data.copy()
        # change max_tokens = 1 to let it only do prefill
        prefill_request["max_tokens"] = 1

        # Snapshot the tables under their locks; selection works on copies.
        with prefill_cv:
            prefill_list = list(prefill_instances.items())
        with decode_cv:
            decode_list = list(decode_instances.items())

        # Without this guard the modulo below would divide by zero.
        if not prefill_list or not decode_list:
            return (
                "No prefill or decode instance is registered with the proxy",
                503,
            )

        prefill_addr, prefill_zmq_addr = prefill_list[count % len(prefill_list)]
        decode_addr, decode_zmq_addr = decode_list[count % len(decode_list)]

        print(
            f"handle_request count: {count}, [HTTP:{prefill_addr}, "
            f"ZMQ:{prefill_zmq_addr}] 👉 [HTTP:{decode_addr}, "
            f"ZMQ:{decode_zmq_addr}]"
        )
        count += 1

        request_id = (
            f"___prefill_addr_{prefill_zmq_addr}___decode_addr_"
            f"{decode_zmq_addr}_{random_uuid()}"
        )

        # finish prefill
        async for _ in forward_request(
            f"http://{prefill_addr}/v1/completions", prefill_request, request_id
        ):
            continue

        # return decode
        generator = forward_request(
            f"http://{decode_addr}/v1/completions", original_request_data, request_id
        )
        response = await make_response(generator)
        response.timeout = None

        return response

    except Exception as e:
        import traceback

        print("Error occurred in disagg prefill proxy server")
        print(e)
        print(traceback.format_exc())
        # Return an explicit error instead of falling through with None.
        return "Error occurred in disagg prefill proxy server", 500
if __name__ == "__main__":
    # Start the registration listener before the HTTP server.
    discovery_thread = start_service_discovery("0.0.0.0", 30001)
    app.run(host="0.0.0.0", port=10001)
    # Only reached once the HTTP server has returned.
    discovery_thread.join()
......@@ -123,7 +123,8 @@ class P2pNcclEngine:
if self.remote_tp_size % self.tp_size != 0:
logger.error(" the Prefill TP size must be less than or equal to the Decode TP size!!!!")
self.multp = int(self.remote_tp_size / self.tp_size)
self.multiple_machines = self.config.get_from_extra_config(
"enable_multiple_machines", False)
port = int(self.config.kv_port) + port_offset
if port == 0:
raise ValueError("Port cannot be 0")
......@@ -203,12 +204,16 @@ class P2pNcclEngine:
self._listener_thread.start()
self._ping_thread = None
# if port_offset == 0 and self.proxy_address != "":
if self.proxy_address != "":
if self.multiple_machines:
if port_offset == 0 and self.proxy_address != "":
self._ping_thread = threading.Thread(target=self._ping,
daemon=True)
self._ping_thread.start()
else:
if self.proxy_address != "":
self._ping_thread = threading.Thread(target=self._ping_new,
daemon=True)
self._ping_thread.start()
logger.info(
"💯P2pNcclEngine init, rank:%d, local_rank:%d, http_address:%s, "
"zmq_address:%s, proxy_address:%s, send_type:%s, buffer_size_"
......@@ -267,7 +272,7 @@ class P2pNcclEngine:
p_ip, p_port = self.parse_request_id(request_id, False)
pd_pair_id = p_ip + ":" + str(p_port) + "_" + remote_ip + ":" + str(remote_port)
# remote_port = 22001
if not self.enable_asymmetric_p2p:
remote_address = remote_ip + ":" + str(remote_port + self.rank)
remote_addr = RemoteAddr(pd_pair_id, remote_address, self.rank + self.pp_size * self.tp_size)
......@@ -279,7 +284,7 @@ class P2pNcclEngine:
remote_pp_rank = self.compute_remote_pp_rank(layer_name)
items: list[Any] = []
# remote_tp_rank = self.tp_rank * self.multp
for d_tp_rank in range(self.remote_tp_size):
for mul_tp in range(self.multp):
if self.tp_rank + mul_tp * self.tp_size == d_tp_rank:
......@@ -306,7 +311,7 @@ class P2pNcclEngine:
if self.send_type == "PUT":
return all(
self.send_sync(item) for item in self.get_send_queue_items(
self._send_sync_new(item) for item in self.get_send_queue_items(
request_id, layer_name, tensor, is_mla))
if self.send_type == "PUT_ASYNC":
......@@ -627,6 +632,9 @@ class P2pNcclEngine:
if (envs.VLLM_ENABLE_TBO or envs.VLLM_P2P_ASYNC) and tbo_evt is not None:
self.send_stream.wait_event(tbo_evt)
self._send_kv_p2p_sync(tensor_id, kv_layer, slot_mapping, remote_address)
else:
if self.multiple_machines:
self._send_sync(tensor_id, tensor, remote_address)
else:
# logger.info(f"""=============xiabo tensor_id:{tensor_id} remote_address:{remote_address}""")
self._send_sync_new(tensor_id, tensor, remote_address)
......@@ -734,7 +742,7 @@ class P2pNcclEngine:
"pd_pair_id": remote_address.pd_pair_id,
"comm_rank": rank
}
# logger.info(f"""_send_sync_new:{data}""")
logger.info(f"""_send_sync_new:{data}""")
sock.send(msgpack.dumps(data))
response = sock.recv()
......@@ -830,18 +838,20 @@ class P2pNcclEngine:
return finished_sending or None, finished_recving or None
def _ping(self):
    """Announce this engine to the proxy, repeating every 3 seconds.

    Sends a msgpack-encoded dict with the engine's role ("P" for a KV
    producer, otherwise "D"), HTTP address and ZMQ address over a DEALER
    socket connected to ``self.proxy_address``. Never returns.
    """
    sock = self.context.socket(zmq.DEALER)
    sock.setsockopt_string(zmq.IDENTITY, self.zmq_address)
    logger.debug("ping start, zmq_address:%s", self.zmq_address)
    sock.connect(f"tcp://{self.proxy_address}")
    role = "P" if self.config.is_kv_producer else "D"
    # The message never changes, so it is encoded once and re-sent.
    payload = msgpack.dumps({
        "type": role,
        "http_address": self.http_address,
        "zmq_address": self.zmq_address
    })
    while True:
        sock.send(payload)
        time.sleep(3)
def _ping_new(self):
sock = self.context.socket(zmq.DEALER)
sock.setsockopt_string(zmq.IDENTITY, self.zmq_address)
logger.debug("ping start, zmq_address:%s", self.zmq_address)
......@@ -856,7 +866,7 @@ class P2pNcclEngine:
"pp_size" : self.pp_size,
"tp_size" : self.tp_size
}
logger.info(f"""_ping data:{data}""")
# logger.info(f"""_ping data:{data}""")
sock.send(msgpack.dumps(data))
data = {
"type": "P" if self.config.is_kv_producer else "D",
......@@ -867,7 +877,7 @@ class P2pNcclEngine:
"zmq_address": self.zmq_address
}
# while True:
logger.info(f"""_ping data:{data}""")
# logger.info(f"""_ping data:{data}""")
sock.send(msgpack.dumps(data))
# time.sleep(3)
......@@ -904,10 +914,44 @@ class P2pNcclEngine:
if self._ping_thread is not None:
self._ping_thread.join()
def get_pp_indices_d(self, num_hidden_layers: int, pp_rank: int,
                     pp_size: int) -> tuple[int, int]:
    """Return the ``(start, end)`` layer range of stage ``pp_rank``.

    If ``envs.VLLM_PP_LAYER_PARTITION_D`` is set it must be a comma-separated
    list of ``pp_size`` integers summing to ``num_hidden_layers``. Otherwise
    layers are split evenly and any remainder is handed out one layer per
    stage, starting from the second-to-last stage and moving backwards.

    Raises:
        ValueError: If the override string is not a list of integers, has
            the wrong number of entries, or does not sum to
            ``num_hidden_layers``.
    """
    partition_list_str = envs.VLLM_PP_LAYER_PARTITION_D
    if partition_list_str is None:
        partitions = [num_hidden_layers // pp_size] * pp_size
        remaining_layers = num_hidden_layers % pp_size
        if remaining_layers:
            # Offsets 2, 3, ... index from the end, skipping the last stage.
            for offset in range(2, remaining_layers + 2):
                partitions[-offset] += 1
            logger.info(
                "Hidden layers were unevenly partitioned: [%s]. "
                "This can be manually overridden using the "
                "VLLM_PP_LAYER_PARTITION_D environment variable",
                ",".join(str(p) for p in partitions))
    else:
        try:
            partitions = [
                int(layer) for layer in partition_list_str.split(",")
            ]
        except ValueError as err:
            raise ValueError("Invalid partition string: {}".format(
                partition_list_str)) from err
        if len(partitions) != pp_size:
            raise ValueError(f"{len(partitions)=} does not match {pp_size=}.")
        if sum(partitions) != num_hidden_layers:
            raise ValueError(
                f"{sum(partitions)=} does not match {num_hidden_layers=}.")

    start_layer = sum(partitions[:pp_rank])
    return (start_layer, start_layer + partitions[pp_rank])
def compute_remote_pp_rank(self, layer_name: str) -> int:
current_layer_idx = extract_layer_index(layer_name)
for d_pp_rank in range(self.remote_pp_size):
start, end = get_pp_indices(self.total_num_hidden_layers, d_pp_rank, self.remote_pp_size)
start, end = self.get_pp_indices_d(self.total_num_hidden_layers, d_pp_rank, self.remote_pp_size)
# logger.info(f"""compute_remote_pp_rank : current_layer_idx:{current_layer_idx} start:{start} end:{end}""")
if (current_layer_idx == self.total_num_hidden_layers):
return self.remote_pp_size - 1
......
......@@ -42,6 +42,7 @@ if TYPE_CHECKING:
VLLM_USE_FLASHINFER_SAMPLER: Optional[bool] = None
VLLM_FLASHINFER_FORCE_TENSOR_CORES: bool = False
VLLM_PP_LAYER_PARTITION: Optional[str] = None
VLLM_PP_LAYER_PARTITION_D: Optional[str] = None
VLLM_CPU_KVCACHE_SPACE: int = 0
VLLM_CPU_OMP_THREADS_BIND: str = ""
VLLM_CPU_NUM_OF_RESERVED_CPU: int = 0
......@@ -487,6 +488,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_PP_LAYER_PARTITION":
lambda: os.getenv("VLLM_PP_LAYER_PARTITION", None),
# Pipeline stage partition strategy
"VLLM_PP_LAYER_PARTITION_D":
lambda: os.getenv("VLLM_PP_LAYER_PARTITION_D", None),
# (CPU backend only) CPU key-value cache space.
# default is 4 GiB
"VLLM_CPU_KVCACHE_SPACE":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment