Unverified Commit 32eb6e96 authored by Lianmin Zheng, committed by GitHub

Organize sampling batch info better (#1562)

parent e0b5dbce
@@ -96,7 +96,9 @@ class Scheduler:
         if self.tp_rank == 0:
             self.recv_from_tokenizer = context.socket(zmq.PULL)
-            self.recv_from_tokenizer.bind(f"tcp://127.0.0.1:{port_args.scheduler_port}")
+            self.recv_from_tokenizer.bind(
+                f"tcp://127.0.0.1:{port_args.scheduler_input_port}"
+            )

             self.send_to_detokenizer = context.socket(zmq.PUSH)
             self.send_to_detokenizer.connect(
@@ -141,9 +143,6 @@ class Scheduler:
             nccl_port=port_args.nccl_ports[0],
         )
         self.tp_cpu_group = self.tp_worker.model_runner.tp_group.cpu_group
-        self.pad_input_ids_func = getattr(
-            self.tp_worker.model_runner.model, "pad_input_ids", None
-        )

         # Get token and memory info from the tp worker
         (
@@ -154,6 +153,9 @@ class Scheduler:
             self.random_seed,
         ) = self.tp_worker.get_token_and_memory_info()
         set_random_seed(self.random_seed)
+        self.pad_input_ids_func = getattr(
+            self.tp_worker.model_runner.model, "pad_input_ids", None
+        )

         # Print debug info
         logger.info(
......
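
Note: the scheduler still looks up an optional `pad_input_ids` hook on the model with `getattr`, just later in `__init__`; models that do not define the hook get `None`. A minimal, self-contained sketch of that optional-hook pattern (the `DummyModel` class and the toy call below are illustrative, not part of this commit):

class DummyModel:
    def pad_input_ids(self, input_ids, image_inputs):
        # A real multimodal model would expand image placeholder tokens here.
        return input_ids

model = DummyModel()
pad_input_ids_func = getattr(model, "pad_input_ids", None)

input_ids = [1, 2, 3]
if pad_input_ids_func is not None:
    input_ids = pad_input_ids_func(input_ids, image_inputs=None)
print(input_ids)  # [1, 2, 3]; a model without the hook leaves input_ids unchanged
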
@@ -87,7 +87,9 @@ class TokenizerManager:
         self.recv_from_detokenizer.bind(f"tcp://127.0.0.1:{port_args.tokenizer_port}")

         self.send_to_scheduler = context.socket(zmq.PUSH)
-        self.send_to_scheduler.connect(f"tcp://127.0.0.1:{port_args.scheduler_port}")
+        self.send_to_scheduler.connect(
+            f"tcp://127.0.0.1:{port_args.scheduler_input_port}"
+        )

         # Read model args
         self.model_path = server_args.model_path
......
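
With the rename, the scheduler (rank 0) binds a ZMQ PULL socket on `scheduler_input_port` and the TokenizerManager connects a PUSH socket to the same port. A self-contained sketch of that pairing, assuming pyzmq is installed (the port number and payload are made up for illustration):

import zmq

PORT = 5555  # stand-in for port_args.scheduler_input_port
context = zmq.Context()

# Scheduler side: bind and receive.
recv_from_tokenizer = context.socket(zmq.PULL)
recv_from_tokenizer.bind(f"tcp://127.0.0.1:{PORT}")

# TokenizerManager side: connect and send.
send_to_scheduler = context.socket(zmq.PUSH)
send_to_scheduler.connect(f"tcp://127.0.0.1:{PORT}")

send_to_scheduler.send_pyobj({"rid": "req-0", "text": "hello"})
print(recv_from_tokenizer.recv_pyobj())  # {'rid': 'req-0', 'text': 'hello'}
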
@@ -30,6 +30,7 @@ class ReqToTokenPool:
     def __init__(self, size: int, max_context_len: int, device: str):
         self.size = size
+        self.max_context_len = max_context_len
         self.free_slots = list(range(size))
         self.req_to_token = torch.empty(
             (size, max_context_len), dtype=torch.int32, device=device
@@ -54,7 +55,7 @@ class ReqToTokenPool:
         self.free_slots = list(range(self.size))


-class BaseTokenToKVPool(ABC):
+class BaseTokenToKVPool:
     """A memory pool that maps a token to its kv cache locations"""

     def __init__(
@@ -92,19 +93,15 @@ class BaseTokenToKVPool(ABC):
         # The padded slot 0 is used for writing dummy outputs from padded tokens.
         self.free_slots = np.arange(1, self.size + 1)

-    @abstractmethod
     def get_key_buffer(self, layer_id: int) -> torch.Tensor:
         raise NotImplementedError()

-    @abstractmethod
     def get_value_buffer(self, layer_id: int) -> torch.Tensor:
         raise NotImplementedError()

-    @abstractmethod
     def get_kv_buffer(self, layer_id: int) -> Tuple[torch.Tensor, torch.Tensor]:
         raise NotImplementedError()

-    @abstractmethod
     def set_kv_buffer(
         self,
         layer_id: int,
......
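
`BaseTokenToKVPool` now relies on plain `NotImplementedError` methods instead of `abc.abstractmethod`; the interface is unchanged, but it is enforced at call time rather than at instantiation. A simplified sketch of how a subclass still fills in the accessors (the class names and shapes below are toy stand-ins, not the real pool):

from typing import List, Tuple
import torch

class BasePool:
    # Base accessors raise instead of being marked @abstractmethod.
    def get_key_buffer(self, layer_id: int) -> torch.Tensor:
        raise NotImplementedError()

    def get_value_buffer(self, layer_id: int) -> torch.Tensor:
        raise NotImplementedError()

    def get_kv_buffer(self, layer_id: int) -> Tuple[torch.Tensor, torch.Tensor]:
        raise NotImplementedError()

class TinyMHAPool(BasePool):
    def __init__(self, size: int, head_num: int, head_dim: int, layer_num: int):
        # One K and one V buffer per layer; slot 0 is reserved in the real pool.
        self.k_buffer: List[torch.Tensor] = [
            torch.zeros(size, head_num, head_dim) for _ in range(layer_num)
        ]
        self.v_buffer: List[torch.Tensor] = [
            torch.zeros(size, head_num, head_dim) for _ in range(layer_num)
        ]

    def get_key_buffer(self, layer_id: int) -> torch.Tensor:
        return self.k_buffer[layer_id]

    def get_value_buffer(self, layer_id: int) -> torch.Tensor:
        return self.v_buffer[layer_id]

    def get_kv_buffer(self, layer_id: int) -> Tuple[torch.Tensor, torch.Tensor]:
        return self.k_buffer[layer_id], self.v_buffer[layer_id]

pool = TinyMHAPool(size=8, head_num=2, head_dim=4, layer_num=3)
k, v = pool.get_kv_buffer(layer_id=1)
print(k.shape, v.shape)  # torch.Size([8, 2, 4]) torch.Size([8, 2, 4])
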
@@ -411,8 +411,8 @@ class ModelRunner:
             device = "cuda"

         self.req_to_token_pool = ReqToTokenPool(
-            max_num_reqs + 1,
-            self.model_config.context_len + 4,
+            size=max_num_reqs + 1,
+            max_context_len=self.model_config.context_len + 4,
             device=device,
         )
         if (
......
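
Passing `size=` and `max_context_len=` by keyword, together with the new `self.max_context_len` attribute, makes the pool's dimensions explicit to callers. A toy sketch of the construction (the class below mirrors only the lines shown in the diff; sizes are arbitrary):

import torch

class TinyReqToTokenPool:
    # One row of token-slot indices per request, as in the diff above.
    def __init__(self, size: int, max_context_len: int, device: str):
        self.size = size
        self.max_context_len = max_context_len
        self.free_slots = list(range(size))
        self.req_to_token = torch.empty(
            (size, max_context_len), dtype=torch.int32, device=device
        )

pool = TinyReqToTokenPool(size=4 + 1, max_context_len=2048 + 4, device="cpu")
print(pool.req_to_token.shape, pool.max_context_len)  # torch.Size([5, 2052]) 2052
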
@@ -14,16 +14,17 @@ if TYPE_CHECKING:

 @dataclasses.dataclass
 class SamplingBatchInfo:
-    # Basic Info
-    vocab_size: int
-
     # Batched sampling params
-    temperatures: torch.Tensor = None
-    top_ps: torch.Tensor = None
-    top_ks: torch.Tensor = None
-    min_ps: torch.Tensor = None
+    temperatures: torch.Tensor
+    top_ps: torch.Tensor
+    top_ks: torch.Tensor
+    min_ps: torch.Tensor
+
+    # Dispatch in CUDA graph
+    need_min_p_sampling: bool

     # Bias Tensors
+    vocab_size: int
     logit_bias: torch.Tensor = None
     vocab_mask: torch.Tensor = None
@@ -31,9 +32,6 @@ class SamplingBatchInfo:
     regex_fsms: List[RegexGuide] = None
     regex_fsm_states: List[int] = None

-    # Dispatch in CUDA graph
-    need_min_p_sampling: bool = False
-
     # Penalizer
     penalizer_orchestrator: penaltylib.BatchedPenalizerOrchestrator = None
     linear_penalties: torch.Tensor = None
@@ -42,25 +40,30 @@ class SamplingBatchInfo:
     @classmethod
     def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
         reqs = batch.reqs
-        ret = cls(vocab_size=vocab_size)
-
         with torch.device("cuda"):
-            ret.temperatures = torch.tensor(
+            temperatures = torch.tensor(
                 [r.sampling_params.temperature for r in reqs],
                 dtype=torch.float,
             ).view(-1, 1)
-            ret.top_ps = torch.tensor(
+            top_ps = torch.tensor(
                 [r.sampling_params.top_p for r in reqs], dtype=torch.float
             )
-            ret.top_ks = torch.tensor(
+            top_ks = torch.tensor(
                 [r.sampling_params.top_k for r in reqs], dtype=torch.int
             )
-            ret.min_ps = torch.tensor(
+            min_ps = torch.tensor(
                 [r.sampling_params.min_p for r in reqs], dtype=torch.float
             )

-        ret.need_min_p_sampling = any(r.sampling_params.min_p > 0 for r in reqs)
+        ret = cls(
+            temperatures=temperatures,
+            top_ps=top_ps,
+            top_ks=top_ks,
+            min_ps=min_ps,
+            need_min_p_sampling=any(r.sampling_params.min_p > 0 for r in reqs),
+            vocab_size=vocab_size,
+        )
+
+        # TODO (lianmin): `need_min_p_sampling` needs to be updated in filter and merge.

         # Each penalizers will do nothing if they evaluate themselves as not required by looking at
         # the sampling_params of the requests (See {_is_required()} of each penalizers). So this
......
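
The core of the change: the batched sampling tensors are built first, then `SamplingBatchInfo` is constructed once with every required field, instead of mutating a half-initialized object. A runnable sketch of the same pattern with stand-in types (it uses CPU tensors rather than `torch.device("cuda")`, and `Params`/`Info` are simplified stand-ins for the real request and batch-info classes):

import dataclasses
from typing import List
import torch

@dataclasses.dataclass
class Params:
    temperature: float
    top_p: float
    top_k: int
    min_p: float

@dataclasses.dataclass
class Info:
    temperatures: torch.Tensor
    top_ps: torch.Tensor
    top_ks: torch.Tensor
    min_ps: torch.Tensor
    need_min_p_sampling: bool
    vocab_size: int

def from_params(params: List[Params], vocab_size: int) -> Info:
    # Build every tensor first, then construct the dataclass once with all
    # required fields, mirroring the new from_schedule_batch flow.
    return Info(
        temperatures=torch.tensor(
            [p.temperature for p in params], dtype=torch.float
        ).view(-1, 1),
        top_ps=torch.tensor([p.top_p for p in params], dtype=torch.float),
        top_ks=torch.tensor([p.top_k for p in params], dtype=torch.int),
        min_ps=torch.tensor([p.min_p for p in params], dtype=torch.float),
        need_min_p_sampling=any(p.min_p > 0 for p in params),
        vocab_size=vocab_size,
    )

info = from_params([Params(0.7, 0.9, 40, 0.0), Params(1.0, 1.0, -1, 0.05)], vocab_size=32000)
print(info.temperatures.shape, info.need_min_p_sampling)  # torch.Size([2, 1]) True
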
@@ -118,6 +118,7 @@ async def health_generate(request: Request) -> Response:

 @app.get("/get_model_info")
 async def get_model_info():
+    """Get the model information."""
     result = {
         "model_path": tokenizer_manager.model_path,
         "is_generation": tokenizer_manager.is_generation,
@@ -127,11 +128,13 @@ async def get_model_info():

 @app.get("/get_server_args")
 async def get_server_args():
+    """Get the server arguments."""
     return dataclasses.asdict(tokenizer_manager.server_args)


 @app.get("/flush_cache")
 async def flush_cache():
+    """Flush the radix cache."""
     tokenizer_manager.flush_cache()
     return Response(
         content="Cache flushed.\nPlease check backend logs for more details. "
@@ -142,7 +145,7 @@ async def flush_cache():

 @app.post("/update_weights")
 async def update_weights(obj: UpdateWeightReqInput, request: Request):
+    """Update the weights inplace without re-launching the server."""
     success, message = await tokenizer_manager.update_weights(obj, request)
     content = {"success": success, "message": message}
     if success:
@@ -205,7 +208,7 @@ app.put("/encode")(encode_request)
 async def judge_request(obj: RewardReqInput, request: Request):
-    """Handle an embedding request."""
+    """Handle a reward model request."""
     try:
         ret = await tokenizer_manager.generate_request(obj, request).__anext__()
         return ret
@@ -307,7 +310,7 @@ def launch_server(
     ports = server_args.additional_ports
     port_args = PortArgs(
         tokenizer_port=ports[0],
-        scheduler_port=ports[1],
+        scheduler_input_port=ports[1],
         detokenizer_port=ports[2],
         nccl_ports=ports[3:],
     )
......
@@ -627,10 +627,11 @@ def prepare_server_args(argv: List[str]) -> ServerArgs:
 class PortArgs:
     # The port for tokenizer to receive inputs from detokenizer (zmq)
     tokenizer_port: int

-    # The port for scheduler to receive inputs from tokenizer (zmq)
-    scheduler_port: int
+    # The port for scheduler (rank 0) to receive inputs from tokenizer (zmq)
+    scheduler_input_port: int

     # The port for detokenizer to receive inputs from scheduler (zmq)
     detokenizer_port: int

     # The port for nccl initialization for multiple TP groups (torch.dist)
     nccl_ports: List[int]
......
-kill -9 $(ps aux | grep 'sglang.launch_server' | grep -v 'grep' | awk '{print $2}')
+kill -9 $(ps aux | grep 'multiprocessing.spawn' | grep -v 'grep' | awk '{print $2}')
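
The grep target changes, presumably because the worker processes are started through Python multiprocessing and show up in `ps` as `multiprocessing.spawn` rather than under the launcher module name. If a Python alternative to the shell one-liner is preferred, a rough sketch with `psutil` (an extra dependency, not used by the commit itself) would be:

import psutil

# Kill every process whose command line mentions multiprocessing.spawn.
# Coarse by design, like the grep above; it can also hit unrelated workers.
for proc in psutil.process_iter(["pid", "cmdline"]):
    cmdline = " ".join(proc.info["cmdline"] or [])
    if "multiprocessing.spawn" in cmdline:
        try:
            proc.kill()
        except (psutil.NoSuchProcess, psutil.AccessDenied):
            pass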