Unverified commit 4c31ae9f authored by Teng Ma, committed by GitHub

[PD] Support KV transfer with mooncake (#4880)


Signed-off-by: Shangming Cai <caishangming@linux.alibaba.com>
Co-authored-by: Shangming Cai <caishangming@linux.alibaba.com>
Co-authored-by: Xuchun Shang <xuchun.shang@linux.alibaba.com>
Co-authored-by: shangmingc <csmthu@gmail.com>
parent f730362e
from __future__ import annotations

import asyncio
import logging
import struct
import threading
from functools import cache
from typing import Dict, List, Optional, Tuple

import numpy as np
import numpy.typing as npt
import zmq
from aiohttp import web

from sglang.srt.disaggregation.transfer_engine.mooncake import MooncakeTransferEngine
from sglang.srt.disaggregation.utils import DisaggregationMode

logger = logging.getLogger(__name__)
def group_concurrent_contiguous(
src_indices: npt.NDArray[np.int64], dst_indices: npt.NDArray[np.int64]
) -> Tuple[List[npt.NDArray[np.int64]], List[npt.NDArray[np.int64]]]:
src_groups = []
dst_groups = []
current_src = [src_indices[0]]
current_dst = [dst_indices[0]]
for i in range(1, len(src_indices)):
src_contiguous = src_indices[i] == src_indices[i - 1] + 1
dst_contiguous = dst_indices[i] == dst_indices[i - 1] + 1
if src_contiguous and dst_contiguous:
current_src.append(src_indices[i])
current_dst.append(dst_indices[i])
else:
src_groups.append(current_src)
dst_groups.append(current_dst)
current_src = [src_indices[i]]
current_dst = [dst_indices[i]]
src_groups.append(current_src)
dst_groups.append(current_dst)
return src_groups, dst_groups
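For illustration, a minimal sketch of how group_concurrent_contiguous batches index pairs into contiguous runs; the sample indices below are made up for the example:

    import numpy as np

    # Hypothetical page indices: a run of three contiguous pairs, then a run of two.
    src = np.array([0, 1, 2, 7, 8], dtype=np.int64)
    dst = np.array([4, 5, 6, 1, 2], dtype=np.int64)

    src_groups, dst_groups = group_concurrent_contiguous(src, dst)
    # src_groups == [[0, 1, 2], [7, 8]]
    # dst_groups == [[4, 5, 6], [1, 2]]
    # Each group can then be moved as one contiguous transfer instead of one
    # transfer per index.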
class KVArgs:
engine_rank: int
kv_data_ptrs: list[int]
...@@ -21,10 +55,6 @@ class KVArgs:
ib_device: str

class KVPoll:
Failed = 0
Bootstrapping = 1
...@@ -33,49 +63,434 @@ class KVPoll:
Success = 4
RequestPoolType = Dict[int, Tuple[npt.NDArray[np.int64], Optional[int]]]
WaitingPoolType = Dict[
int, Tuple[str, str, list[int], npt.NDArray[np.int64], list[int], int]
]
KVSENDER_POLLING_PORT = 17788
KVRECEIVER_POLLING_PORT = 27788
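# Note: each engine rank uses its own ZMQ port; the prefill side binds
# KVSENDER_POLLING_PORT + engine_rank and the decode side binds
# KVRECEIVER_POLLING_PORT + engine_rank (see start_prefill_thread / start_decode_thread).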
class KVManager:
# TODO: make this general and support multiple transfer backends before merging
def __init__(self, args: KVArgs, disaggregation_mode: DisaggregationMode):
self.engine = MooncakeTransferEngine()
self.kv_args = args
self.disaggregation_mode = disaggregation_mode
self.request_pool: RequestPoolType = {}
self.request_status: Dict[int, KVPoll] = {}
self.server_socket = zmq.Context().socket(zmq.PULL)
self.register_buffer_to_engine()
if self.disaggregation_mode == DisaggregationMode.PREFILL:
self.waiting_pool: WaitingPoolType = {}
self.transfer_event = threading.Event()
self.start_prefill_thread()
elif self.disaggregation_mode == DisaggregationMode.DECODE:
self.start_decode_thread()
else:
raise ValueError(
f"Unsupported DisaggregationMode: {self.disaggregation_mode}"
)
def register_buffer_to_engine(self):
for kv_data_ptr, kv_data_len in zip(
self.kv_args.kv_data_ptrs, self.kv_args.kv_data_lens
):
self.engine.register(kv_data_ptr, kv_data_len)
for aux_data_ptr, aux_data_len in zip(
self.kv_args.aux_data_ptrs, self.kv_args.aux_data_lens
):
self.engine.register(aux_data_ptr, aux_data_len)
@cache
def _connect(self, endpoint: str):
socket = zmq.Context().socket(zmq.PUSH)
socket.connect(endpoint)
return socket
def send_kvcache(
self,
mooncake_session_id: str,
prefill_kv_indices: npt.NDArray[np.int64],
dst_ptrs: list[int],
dst_kv_indices: npt.NDArray[np.int64],
):
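# Per the indexing below, kv_data_ptrs holds the K buffers of every layer
# followed by the V buffers: kv_data_ptrs[layer_id] addresses K and
# kv_data_ptrs[layer_num + layer_id] addresses V for the same layer.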
layer_num = int(len(self.kv_args.kv_data_ptrs) / 2)
prefill_kv_blocks, dst_kv_blocks = group_concurrent_contiguous(
prefill_kv_indices, dst_kv_indices
)
for layer_id in range(layer_num):
prefill_key_layer_ptr = self.kv_args.kv_data_ptrs[layer_id]
key_item_len = self.kv_args.kv_item_lens[layer_id]
prefill_value_layer_ptr = self.kv_args.kv_data_ptrs[layer_num + layer_id]
value_item_len = self.kv_args.kv_item_lens[layer_num + layer_id]
decode_key_layer_ptr = dst_ptrs[layer_id]
decode_value_layer_ptr = dst_ptrs[layer_num + layer_id]
for prefill_index, decode_index in zip(prefill_kv_blocks, dst_kv_blocks):
prefill_key_addr = (
prefill_key_layer_ptr + int(prefill_index[0]) * key_item_len
)
decode_key_addr = (
decode_key_layer_ptr + int(decode_index[0]) * key_item_len
)
# TODO: the mooncake transfer engine supports async transfers; switch to async later
status = self.engine.transfer_sync(
mooncake_session_id,
prefill_key_addr,
decode_key_addr,
key_item_len * len(prefill_index),
)
if status != 0:
return status
prefill_value_addr = (
prefill_value_layer_ptr + int(prefill_index[0]) * value_item_len
)
decode_value_addr = (
decode_value_layer_ptr + int(decode_index[0]) * value_item_len
)
# TODO: the mooncake transfer engine supports async transfers; switch to async later
status = self.engine.transfer_sync(
mooncake_session_id,
prefill_value_addr,
decode_value_addr,
value_item_len * len(prefill_index),
)
if status != 0:
return status
return 0
def send_aux(
self,
mooncake_session_id: str,
prefill_aux_index: int,
dst_aux_ptrs: list[int],
dst_aux_index: int,
):
aux_item_len = self.kv_args.aux_item_lens[0]
prefill_aux_addr = (
self.kv_args.aux_data_ptrs[0] + prefill_aux_index * aux_item_len
)
decode_aux_addr = dst_aux_ptrs[0] + dst_aux_index * aux_item_len
# TODO: the mooncake transfer engine supports async transfers; switch to async later
# Not sure how large the aux data is; transferring it over ZMQ might be more efficient
status = self.engine.transfer_sync(
mooncake_session_id, prefill_aux_addr, decode_aux_addr, aux_item_len
)
return status
def sync_status_to_decode_endpoint(self, remote: str, room: int):
if ":" in remote:
remote = remote.split(":")[0]
self._connect(
"tcp://"
+ remote
+ ":"
+ str(KVRECEIVER_POLLING_PORT + self.kv_args.engine_rank)
).send_multipart(
[
str(room).encode("ascii"),
str(self.request_status[room]).encode("ascii"),
]
)
def start_prefill_thread(self):
sender_rank_port = KVSENDER_POLLING_PORT + self.kv_args.engine_rank
self.server_socket.bind("tcp://*:" + str(sender_rank_port))
def prefill_thread():
while True:
(
endpoint,
mooncake_session_id,
bootstrap_room,
dst_ptrs,
dst_kv_indices,
dst_aux_ptrs,
dst_aux_index,
) = self.server_socket.recv_multipart()
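# A bootstrap_room of "None" marks a warmup request that has no
# prefill/decode pair yet (see the note in Req.__init__), so skip it.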
if bootstrap_room.decode("ascii") == "None":
continue
endpoint = endpoint.decode("ascii")
mooncake_session_id = mooncake_session_id.decode("ascii")
bootstrap_room = int(bootstrap_room.decode("ascii"))
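# The pointer frames were packed as 64-bit unsigned ints ("Q", 8 bytes each)
# by KVReceiver.init, hence len(buffer) // 8 pointers per frame.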
dst_ptrs = list(struct.unpack(f"{len(dst_ptrs)//8}Q", dst_ptrs))
dst_kv_indices = np.frombuffer(dst_kv_indices, dtype=np.int64)
dst_aux_ptrs = list(
struct.unpack(f"{len(dst_aux_ptrs)//8}Q", dst_aux_ptrs)
)
dst_aux_index = int(dst_aux_index.decode("ascii"))
self.waiting_pool[bootstrap_room] = (
endpoint,
mooncake_session_id,
dst_ptrs,
dst_kv_indices,
dst_aux_ptrs,
dst_aux_index,
)
self.transfer_event.set()
threading.Thread(target=prefill_thread).start()
def transfer_thread():
while True:
self.transfer_event.wait()
self.transfer_event.clear()
bootstrap_room_ready = self.request_pool.keys()
bootstrap_room_request = self.waiting_pool.keys()
for room in list(bootstrap_room_request):
if room not in list(bootstrap_room_ready):
continue
status = KVPoll.Transferring
self.request_status[room] = status
(
endpoint,
mooncake_session_id,
dst_ptrs,
dst_kv_indices,
dst_aux_ptrs,
dst_aux_index,
) = self.waiting_pool.pop(room)
self.sync_status_to_decode_endpoint(endpoint, room)
(
prefill_kv_indices,
prefill_aux_index,
) = self.request_pool.pop(room)
ret = self.send_kvcache(
mooncake_session_id,
prefill_kv_indices,
dst_ptrs,
dst_kv_indices,
)
if ret != 0:
status = KVPoll.Failed
self.sync_status_to_decode_endpoint(endpoint, room)
continue
ret = self.send_aux(
mooncake_session_id,
prefill_aux_index,
dst_aux_ptrs,
dst_aux_index,
)
if ret != 0:
status = KVPoll.Failed
else:
status = KVPoll.Success
self.request_status[room] = status
self.sync_status_to_decode_endpoint(endpoint, room)
threading.Thread(target=transfer_thread).start()
def start_decode_thread(self):
receiver_rank_port = KVRECEIVER_POLLING_PORT + self.kv_args.engine_rank
self.server_socket.bind("tcp://*:" + str(receiver_rank_port))
def decode_thread():
while True:
(bootstrap_room, status) = self.server_socket.recv_multipart()
status = int(status.decode("ascii"))
bootstrap_room = int(bootstrap_room.decode("ascii"))
self.request_status[bootstrap_room] = status
threading.Thread(target=decode_thread).start()
def enqueue_request(
self,
bootstrap_room: int,
kv_indices: npt.NDArray[np.int64],
aux_index: Optional[int],
):
self.request_pool[bootstrap_room] = (kv_indices, aux_index)
self.request_status[bootstrap_room] = KVPoll.WaitingForInput
if self.disaggregation_mode == DisaggregationMode.PREFILL:
self.transfer_event.set()
def check_status(self, bootstrap_room: int):
if (
self.disaggregation_mode == DisaggregationMode.DECODE
and self.request_status[bootstrap_room] == KVPoll.Success
):
if bootstrap_room in self.request_pool:
self.request_pool.pop(bootstrap_room)
return self.request_status[bootstrap_room]
def set_status(self, bootstrap_room: int, status: KVPoll):
self.request_status[bootstrap_room] = status
def get_localhost(self):
return self.engine.get_localhost()
def get_session_id(self):
return self.engine.get_session_id()
class KVSender:
def __init__(self, mgr: KVManager, bootstrap_addr: str, bootstrap_room: int):
self.kv_mgr = mgr
self.bootstrap_room = bootstrap_room
self.kv_mgr.set_status(bootstrap_room, KVPoll.WaitingForInput)
self.aux_index = None

def init(self, num_kv_indices: int, aux_index: Optional[int] = None):
self.aux_index = aux_index
self.num_kv_indices = num_kv_indices

def send(self, kv_indices: npt.NDArray[np.int64]):
self.kv_mgr.enqueue_request(self.bootstrap_room, kv_indices, self.aux_index)

def poll(self) -> KVPoll:
return self.kv_mgr.check_status(self.bootstrap_room)

def failure_exception(self):
raise Exception("Fake KVSender Exception")
class KVReceiver:
def __init__(
self, mgr: KVManager, bootstrap_addr: str, bootstrap_room: Optional[int] = None
):
self.bootstrap_room = bootstrap_room
self.bootstrap_addr = bootstrap_addr
self.kv_mgr = mgr
self.prefill_server_url = (
bootstrap_addr.split(":")[0]
+ ":"
+ str(KVSENDER_POLLING_PORT + self.kv_mgr.kv_args.engine_rank)
)
self.decode_ip = self.kv_mgr.get_localhost()
self.session_id = self.kv_mgr.get_session_id()
self.kv_mgr.set_status(bootstrap_room, KVPoll.WaitingForInput)
@cache
def _connect(self, endpoint: str):
socket = zmq.Context().socket(zmq.PUSH)
socket.connect(endpoint)
return socket
def init(self, kv_indices: npt.NDArray[np.int64], aux_index: Optional[int] = None):
self.kv_mgr.enqueue_request(self.bootstrap_room, kv_indices, aux_index)
packed_kv_data_ptrs = b"".join(
struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.kv_data_ptrs
)
packed_aux_data_ptrs = b"".join(
struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.aux_data_ptrs
)
self._connect("tcp://" + self.prefill_server_url).send_multipart(
[
self.decode_ip.encode("ascii"),
self.session_id.encode("ascii"),
str(self.bootstrap_room).encode("ascii"),
packed_kv_data_ptrs,
kv_indices.tobytes(),
packed_aux_data_ptrs,
str(aux_index).encode("ascii"),
]
)
def poll(self) -> KVPoll:
return self.kv_mgr.check_status(self.bootstrap_room)

def failure_exception(self):
raise Exception("Fake KVReceiver Exception")
class KVBootstrapServer:
def __init__(self, port: int):
self.port = port
self.app = web.Application()
self.store = dict()
self.lock = asyncio.Lock()
self._setup_routes()
# Start bootstrap server
self.thread = threading.Thread(target=self._run_server, daemon=True)
self.run()
def run(self):
self.thread.start()
def _setup_routes(self):
self.app.router.add_route("*", "/metadata", self._handle_metadata)
async def _handle_metadata(self, request: web.Request):
key = request.query.get("key", "")
if request.method == "GET":
return await self._handle_get(key)
elif request.method == "PUT":
return await self._handle_put(key, request)
elif request.method == "DELETE":
return await self._handle_delete(key)
return web.Response(
text="Method not allowed", status=405, content_type="application/json"
)
async def _handle_get(self, key):
async with self.lock:
value = self.store.get(key)
if value is None:
return web.Response(
text="metadata not found", status=404, content_type="application/json"
)
return web.Response(body=value, status=200, content_type="application/json")
async def _handle_put(self, key, request):
data = await request.read()
async with self.lock:
self.store[key] = data
return web.Response(
text="metadata updated", status=200, content_type="application/json"
)
async def _handle_delete(self, key):
async with self.lock:
if key not in self.store:
return web.Response(
text="metadata not found",
status=404,
content_type="application/json",
)
del self.store[key]
return web.Response(
text="metadata deleted", status=200, content_type="application/json"
)
def _run_server(self):
try:
# Event Loop
self._loop = asyncio.new_event_loop()
asyncio.set_event_loop(self._loop)
self._runner = web.AppRunner(self.app)
self._loop.run_until_complete(self._runner.setup())
site = web.TCPSite(self._runner, port=self.port)
self._loop.run_until_complete(site.start())
self._loop.run_forever()
except Exception as e:
logger.error(f"Server error: {str(e)}")
finally:
# Cleanup
self._loop.run_until_complete(self._runner.cleanup())
self._loop.close()
def close(self):
"""Shutdown"""
if self._loop is not None and self._loop.is_running():
self._loop.call_soon_threadsafe(self._loop.stop)
logger.info("Stopping server loop...")
if self.thread.is_alive():
self.thread.join(timeout=2)
logger.info("Server thread stopped")
def poll(self) -> KVPoll: ...
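As a quick smoke test of the /metadata route, a hedged example against a locally started server; the port, key, and payload are arbitrary choices for the example, and the sleep is not a robust readiness check:

    import time
    import urllib.request

    server = KVBootstrapServer(port=8998)  # assumed free port
    time.sleep(0.5)  # give the aiohttp site a moment to start

    # PUT stores raw bytes under a key, GET returns them, DELETE removes them.
    req = urllib.request.Request(
        "http://127.0.0.1:8998/metadata?key=demo", data=b'{"rank": 0}', method="PUT"
    )
    urllib.request.urlopen(req)

    with urllib.request.urlopen("http://127.0.0.1:8998/metadata?key=demo") as resp:
        assert resp.read() == b'{"rank": 0}'

    server.close()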
...@@ -24,11 +24,13 @@ import logging
from dataclasses import dataclass
from typing import TYPE_CHECKING, List, Optional, Tuple

import numpy as np
import torch
from torch.distributed import ProcessGroup

from sglang.srt.disaggregation.conn import KVArgs, KVManager, KVPoll, KVReceiver
from sglang.srt.disaggregation.utils import (
DisaggregationMode,
ReqToMetadataIdxAllocator,
poll_and_all_reduce,
)
...@@ -115,7 +117,7 @@ class DecodePreallocQueue:
metadata_buffer[0].nbytes for metadata_buffer in self.metadata_buffers
]
kv_args.ib_device = "mock-ib-device"
kv_manager = KVManager(kv_args, DisaggregationMode("decode"))
return kv_manager

def add(self, req: Req) -> None:
...@@ -186,6 +188,7 @@ class DecodePreallocQueue:
]
.cpu()
.numpy()
.astype(np.int64)
)
decode_req.metadata_buffer_index = (
...
...@@ -26,6 +26,7 @@ import torch
from sglang.srt.disaggregation.conn import KVArgs, KVManager, KVPoll, KVSender
from sglang.srt.disaggregation.utils import (
DisaggregationMode,
ReqToMetadataIdxAllocator,
poll_and_all_reduce,
)
...@@ -95,7 +96,7 @@ class PrefillBootstrapQueue:
metadata_buffer[0].nbytes for metadata_buffer in self.metadata_buffers
]
kv_args.ib_device = "mock-ib-device"
kv_manager = KVManager(kv_args, DisaggregationMode("prefill"))
return kv_manager

def add(self, req: Req) -> None:
...
import json
import logging
import os
import uuid
from dataclasses import dataclass
logger = logging.getLogger(__name__)
@dataclass
class MooncakeTransferEngineConfig:
local_hostname: str
metadata_server: str
protocol: str
device_name: str
@staticmethod
def from_file(file_path: str) -> "MooncakeTransferEngineConfig":
"""Load the config from a JSON file."""
with open(file_path) as fin:
config = json.load(fin)
return MooncakeTransferEngineConfig(
local_hostname=config.get("local_hostname", None),
metadata_server=config.get("metadata_server"),
protocol=config.get("protocol", "rdma"),
device_name=config.get("device_name", ""),
)
@staticmethod
def load_from_env() -> "MooncakeTransferEngineConfig":
"""Load config from a file specified in the environment variable."""
config_file_path = os.getenv("MOONCAKE_CONFIG_PATH")
if config_file_path is None:
raise ValueError(
"The environment variable 'MOONCAKE_CONFIG_PATH' is not set."
)
return MooncakeTransferEngineConfig.from_file(config_file_path)
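The loader above reads a JSON file pointed to by MOONCAKE_CONFIG_PATH; a hedged example of producing such a file, with placeholder host, metadata server, and RDMA device values:

    import json
    import os

    # Placeholder values; substitute the real host, metadata service, and NIC.
    config = {
        "local_hostname": "192.168.0.137",
        "metadata_server": "192.168.0.139:2379",
        "protocol": "rdma",
        "device_name": "erdma_0",
    }
    with open("/tmp/mooncake_config.json", "w") as fout:
        json.dump(config, fout)

    os.environ["MOONCAKE_CONFIG_PATH"] = "/tmp/mooncake_config.json"
    cfg = MooncakeTransferEngineConfig.load_from_env()
    assert cfg.protocol == "rdma"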
class MooncakeTransferEngine:
def __init__(self):
try:
from mooncake.engine import TransferEngine
except ImportError as e:
raise ImportError(
"Please install mooncake by following the instructions at "
"https://github.com/kvcache-ai/Mooncake/blob/main/doc/en/build.md " # noqa: E501
"to run SGLang with MooncakeTransferEngine."
) from e
self.engine = TransferEngine()
try:
self.config = MooncakeTransferEngineConfig.load_from_env()
logger.info("Mooncake Configuration loaded successfully.")
except ValueError as e:
logger.error(e)
raise
except Exception as exc:
logger.error("An error occurred while loading the configuration: %s", exc)
raise
session_suffix = "_" + str(uuid.uuid4())
self.session_id = self.config.local_hostname + session_suffix
self.initialize(
self.session_id,
self.config.metadata_server,
self.config.protocol,
self.config.device_name,
)
def register(self, ptr, length):
self.engine.register_memory(ptr, length)
def deregister(self, ptr):
self.engine.unregister_memory(ptr)
def initialize(
self,
local_hostname: str,
metadata_server: str,
protocol: str,
device_name: str,
) -> None:
"""Initialize the mooncake instance."""
self.engine.initialize(local_hostname, metadata_server, protocol, device_name)
def transfer_sync(
self, session_id: str, buffer: int, peer_buffer_address: int, length: int
) -> int:
"""Synchronously transfer data to the specified address."""
ret = self.engine.transfer_sync_write(
session_id, buffer, peer_buffer_address, length
)
if ret < 0:
logger.error("Transfer Return Error")
raise Exception("Transfer Return Error")
return ret
def get_localhost(self):
return self.config.local_hostname
def get_session_id(self):
return self.session_id
...@@ -95,6 +95,10 @@ class GenerateReqInput:
# Whether to return hidden states
return_hidden_states: bool = False

# For disaggregated inference
bootstrap_host: Optional[str] = None
bootstrap_room: Optional[int] = None

def normalize_batch_and_arguments(self):
"""
Normalize the batch size and arguments for the request.
...@@ -435,6 +439,10 @@ class TokenizedGenerateReqInput:
# Whether to return hidden states
return_hidden_states: bool = False

# For disaggregated inference
bootstrap_host: Optional[str] = None
bootstrap_room: Optional[int] = None

@dataclass
class EmbeddingReqInput:
...
...@@ -390,6 +390,8 @@ class Req:
custom_logit_processor: Optional[str] = None,
return_hidden_states: bool = False,
eos_token_ids: Optional[Set[int]] = None,
bootstrap_host: Optional[str] = None,
bootstrap_room: Optional[int] = None,
):
# Input and output info
self.rid = rid
...@@ -523,8 +525,8 @@ class Req:
self.lora_path = lora_path

# For disaggregation
self.bootstrap_host: str = bootstrap_host
self.bootstrap_room: Optional[int] = bootstrap_room
self.disagg_kv_sender: Optional[KVSender] = None

# used for warmup because we don't have a pair yet at init time
...
...@@ -836,6 +836,8 @@ class Scheduler(
custom_logit_processor=custom_logit_processor,
return_hidden_states=recv_req.return_hidden_states,
eos_token_ids=self.model_config.hf_eos_token_id,
bootstrap_host=recv_req.bootstrap_host,
bootstrap_room=recv_req.bootstrap_room,
)
req.tokenizer = self.tokenizer
req.queue_time_start = time.time()
...
...@@ -452,6 +452,8 @@ class TokenizerManager:
top_logprobs_num,
token_ids_logprob,
obj.stream,
bootstrap_host=obj.bootstrap_host,
bootstrap_room=obj.bootstrap_room,
lora_path=obj.lora_path,
input_embeds=input_embeds,
session_params=session_params,
...