Unverified Commit 32eb6e96 authored by Lianmin Zheng's avatar Lianmin Zheng Committed by GitHub
Browse files

Organize sampling batch info better (#1562)

parent e0b5dbce
......@@ -96,7 +96,9 @@ class Scheduler:
if self.tp_rank == 0:
self.recv_from_tokenizer = context.socket(zmq.PULL)
self.recv_from_tokenizer.bind(f"tcp://127.0.0.1:{port_args.scheduler_port}")
self.recv_from_tokenizer.bind(
f"tcp://127.0.0.1:{port_args.scheduler_input_port}"
)
self.send_to_detokenizer = context.socket(zmq.PUSH)
self.send_to_detokenizer.connect(
......@@ -141,9 +143,6 @@ class Scheduler:
nccl_port=port_args.nccl_ports[0],
)
self.tp_cpu_group = self.tp_worker.model_runner.tp_group.cpu_group
self.pad_input_ids_func = getattr(
self.tp_worker.model_runner.model, "pad_input_ids", None
)
# Get token and memory info from the tp worker
(
......@@ -154,6 +153,9 @@ class Scheduler:
self.random_seed,
) = self.tp_worker.get_token_and_memory_info()
set_random_seed(self.random_seed)
self.pad_input_ids_func = getattr(
self.tp_worker.model_runner.model, "pad_input_ids", None
)
# Print debug info
logger.info(
......
......@@ -87,7 +87,9 @@ class TokenizerManager:
self.recv_from_detokenizer.bind(f"tcp://127.0.0.1:{port_args.tokenizer_port}")
self.send_to_scheduler = context.socket(zmq.PUSH)
self.send_to_scheduler.connect(f"tcp://127.0.0.1:{port_args.scheduler_port}")
self.send_to_scheduler.connect(
f"tcp://127.0.0.1:{port_args.scheduler_input_port}"
)
# Read model args
self.model_path = server_args.model_path
......
......@@ -30,6 +30,7 @@ class ReqToTokenPool:
def __init__(self, size: int, max_context_len: int, device: str):
self.size = size
self.max_context_len = max_context_len
self.free_slots = list(range(size))
self.req_to_token = torch.empty(
(size, max_context_len), dtype=torch.int32, device=device
......@@ -54,7 +55,7 @@ class ReqToTokenPool:
self.free_slots = list(range(self.size))
class BaseTokenToKVPool(ABC):
class BaseTokenToKVPool:
"""A memory pool that maps a token to its kv cache locations"""
def __init__(
......@@ -92,19 +93,15 @@ class BaseTokenToKVPool(ABC):
# The padded slot 0 is used for writing dummy outputs from padded tokens.
self.free_slots = np.arange(1, self.size + 1)
@abstractmethod
def get_key_buffer(self, layer_id: int) -> torch.Tensor:
raise NotImplementedError()
@abstractmethod
def get_value_buffer(self, layer_id: int) -> torch.Tensor:
raise NotImplementedError()
@abstractmethod
def get_kv_buffer(self, layer_id: int) -> Tuple[torch.Tensor, torch.Tensor]:
raise NotImplementedError()
@abstractmethod
def set_kv_buffer(
self,
layer_id: int,
......
......@@ -411,8 +411,8 @@ class ModelRunner:
device = "cuda"
self.req_to_token_pool = ReqToTokenPool(
max_num_reqs + 1,
self.model_config.context_len + 4,
size=max_num_reqs + 1,
max_context_len=self.model_config.context_len + 4,
device=device,
)
if (
......
......@@ -14,16 +14,17 @@ if TYPE_CHECKING:
@dataclasses.dataclass
class SamplingBatchInfo:
# Basic Info
vocab_size: int
# Batched sampling params
temperatures: torch.Tensor = None
top_ps: torch.Tensor = None
top_ks: torch.Tensor = None
min_ps: torch.Tensor = None
temperatures: torch.Tensor
top_ps: torch.Tensor
top_ks: torch.Tensor
min_ps: torch.Tensor
# Dispatch in CUDA graph
need_min_p_sampling: bool
# Bias Tensors
vocab_size: int
logit_bias: torch.Tensor = None
vocab_mask: torch.Tensor = None
......@@ -31,9 +32,6 @@ class SamplingBatchInfo:
regex_fsms: List[RegexGuide] = None
regex_fsm_states: List[int] = None
# Dispatch in CUDA graph
need_min_p_sampling: bool = False
# Penalizer
penalizer_orchestrator: penaltylib.BatchedPenalizerOrchestrator = None
linear_penalties: torch.Tensor = None
......@@ -42,25 +40,30 @@ class SamplingBatchInfo:
@classmethod
def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
reqs = batch.reqs
ret = cls(vocab_size=vocab_size)
with torch.device("cuda"):
ret.temperatures = torch.tensor(
temperatures = torch.tensor(
[r.sampling_params.temperature for r in reqs],
dtype=torch.float,
).view(-1, 1)
ret.top_ps = torch.tensor(
top_ps = torch.tensor(
[r.sampling_params.top_p for r in reqs], dtype=torch.float
)
ret.top_ks = torch.tensor(
top_ks = torch.tensor(
[r.sampling_params.top_k for r in reqs], dtype=torch.int
)
ret.min_ps = torch.tensor(
min_ps = torch.tensor(
[r.sampling_params.min_p for r in reqs], dtype=torch.float
)
ret = cls(
temperatures=temperatures,
top_ps=top_ps,
top_ks=top_ks,
min_ps=min_ps,
need_min_p_sampling=any(r.sampling_params.min_p > 0 for r in reqs),
vocab_size=vocab_size,
)
# TODO (lianmin): `need_min_p_sampling` needs to be updated in filter and merge.
ret.need_min_p_sampling = any(r.sampling_params.min_p > 0 for r in reqs)
# Each penalizers will do nothing if they evaluate themselves as not required by looking at
# the sampling_params of the requests (See {_is_required()} of each penalizers). So this
......
......@@ -118,6 +118,7 @@ async def health_generate(request: Request) -> Response:
@app.get("/get_model_info")
async def get_model_info():
"""Get the model information."""
result = {
"model_path": tokenizer_manager.model_path,
"is_generation": tokenizer_manager.is_generation,
......@@ -127,11 +128,13 @@ async def get_model_info():
@app.get("/get_server_args")
async def get_server_args():
"""Get the server arguments."""
return dataclasses.asdict(tokenizer_manager.server_args)
@app.get("/flush_cache")
async def flush_cache():
"""Flush the radix cache."""
tokenizer_manager.flush_cache()
return Response(
content="Cache flushed.\nPlease check backend logs for more details. "
......@@ -142,7 +145,7 @@ async def flush_cache():
@app.post("/update_weights")
async def update_weights(obj: UpdateWeightReqInput, request: Request):
"""Update the weights inplace without re-launching the server."""
success, message = await tokenizer_manager.update_weights(obj, request)
content = {"success": success, "message": message}
if success:
......@@ -205,7 +208,7 @@ app.put("/encode")(encode_request)
async def judge_request(obj: RewardReqInput, request: Request):
"""Handle an embedding request."""
"""Handle a reward model request."""
try:
ret = await tokenizer_manager.generate_request(obj, request).__anext__()
return ret
......@@ -307,7 +310,7 @@ def launch_server(
ports = server_args.additional_ports
port_args = PortArgs(
tokenizer_port=ports[0],
scheduler_port=ports[1],
scheduler_input_port=ports[1],
detokenizer_port=ports[2],
nccl_ports=ports[3:],
)
......
......@@ -627,10 +627,11 @@ def prepare_server_args(argv: List[str]) -> ServerArgs:
class PortArgs:
# The port for tokenizer to receive inputs from detokenizer (zmq)
tokenizer_port: int
# The port for scheduler to receive inputs from tokenizer (zmq)
scheduler_port: int
# The port for scheduler (rank 0) to receive inputs from tokenizer (zmq)
scheduler_input_port: int
# The port for detokenizer to receive inputs from scheduler (zmq)
detokenizer_port: int
# The port for nccl initialization for multiple TP groups (torch.dist)
nccl_ports: List[int]
......
kill -9 $(ps aux | grep 'sglang.launch_server' | grep -v 'grep' | awk '{print $2}')
kill -9 $(ps aux | grep 'multiprocessing.spawn' | grep -v 'grep' | awk '{print $2}')
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment