Unverified commit 9dae4078, authored by Lianmin Zheng, committed by GitHub

Improve type annotation (#1029)

parent fcc0f5ed
@@ -18,13 +18,15 @@ limitations under the License.
 import random
 from collections import defaultdict
 from contextlib import contextmanager
-from typing import List
+from typing import Dict, List

 from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
+from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
+from sglang.srt.mem_cache.radix_cache import TreeNode


 class PolicyScheduler:
-    def __init__(self, policy, tree_cache):
+    def __init__(self, policy: str, tree_cache: BasePrefixCache):
         if tree_cache.disable and policy in ["lpm", "dfs-weight"]:
             # LPM and DFS-weight are meaningless when the tree cache is disabled.
             policy = "fcfs"
@@ -72,12 +74,18 @@ class PolicyScheduler:
         else:
             raise ValueError(f"Unknown schedule_policy: {self.policy}")

-    def calc_weight(self, cur_node, node_to_weight):
+    def calc_weight(self, cur_node: TreeNode, node_to_weight: Dict):
         for child in cur_node.children.values():
             self.calc_weight(child, node_to_weight)
             node_to_weight[cur_node] += node_to_weight[child]

-    def get_dfs_priority(self, cur_node, node_to_priority, last_node_to_reqs, q):
+    def get_dfs_priority(
+        self,
+        cur_node: TreeNode,
+        node_to_priority: Dict,
+        last_node_to_reqs: Dict,
+        q: List,
+    ):
         childs = [child for child in cur_node.children.values()]
         childs.sort(key=lambda x: -node_to_priority[x])
         for child in childs:
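For intuition about the two methods above: `calc_weight` pushes weights up the tree bottom-up, and `get_dfs_priority` then visits heavier subtrees first. A toy rerun of that logic over a hand-built dict tree (node names and weights are invented for illustration):

```python
from typing import Dict, List

# Toy tree: node -> children. Leaf weights seed the accumulation.
children: Dict[str, List[str]] = {"root": ["a", "b"], "a": [], "b": ["c"], "c": []}
weight: Dict[str, int] = {"root": 0, "a": 1, "b": 0, "c": 5}

def calc_weight(node: str):
    for ch in children[node]:
        calc_weight(ch)
        weight[node] += weight[ch]  # bottom-up accumulation, as in the diff

def dfs_priority(node: str, order: List[str]):
    kids = sorted(children[node], key=lambda x: -weight[x])  # heaviest first
    for ch in kids:
        order.append(ch)
        dfs_priority(ch, order)

calc_weight("root")
order: List[str] = []
dfs_priority("root", order)
print(order)  # ['b', 'c', 'a'] -- subtree b carries weight 5
```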
@@ -88,10 +96,10 @@ class PolicyScheduler:
 class PrefillAdder:
     def __init__(
         self,
-        tree_cache,
-        rem_total_tokens,
-        rem_input_tokens,
-        rem_chunk_tokens,
+        tree_cache: BasePrefixCache,
+        rem_total_tokens: int,
+        rem_input_tokens: int,
+        rem_chunk_tokens: int,
     ):
         self.tree_cache = tree_cache
         self.rem_total_tokens = rem_total_tokens
@@ -151,7 +159,7 @@ class PrefillAdder:
         return req if truncated else None

     @contextmanager
-    def _lock_node(self, last_node):
+    def _lock_node(self, last_node: TreeNode):
         try:
             delta = self.tree_cache.inc_lock_ref(last_node)
             self.rem_total_tokens += delta
...
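The `_lock_node` context manager above pins a radix-tree node via `inc_lock_ref` for the duration of a scheduling decision and releases it in the (truncated) `finally` branch. A generic sketch of the same try/finally pattern with a toy cache (all names here are ours, not the commit's):

```python
from contextlib import contextmanager

class ToyCache:
    """Toy lock-reference accounting, standing in for the tree cache."""
    def __init__(self):
        self.lock_ref = 0

    def inc_lock_ref(self, node):
        self.lock_ref += 1

    def dec_lock_ref(self, node):
        self.lock_ref -= 1

@contextmanager
def lock_node(cache: ToyCache, node):
    try:
        cache.inc_lock_ref(node)  # pin the node while the caller inspects it
        yield node
    finally:
        cache.dec_lock_ref(node)  # always released, even if the body raises

cache = ToyCache()
with lock_node(cache, node="n0"):
    assert cache.lock_ref == 1
assert cache.lock_ref == 0
```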
@@ -21,15 +21,17 @@ import os
 import pickle
 import time
 import warnings
-from typing import List, Optional, Union
+from typing import Any, List, Optional, Union

 import torch
+import torch.distributed
 import torch.distributed as dist

 from sglang.global_config import global_config
 from sglang.srt.constrained.fsm_cache import FSMCache
 from sglang.srt.constrained.jump_forward import JumpForwardCache
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
+from sglang.srt.layers.logits_processor import LogitProcessorOutput
 from sglang.srt.managers.io_struct import (
     AbortReq,
     BatchEmbeddingOut,
@@ -62,6 +64,10 @@ from sglang.utils import get_exception_traceback
 logger = logging.getLogger(__name__)

+# TODO: Rename "CI" to "SGLANG_IS_IN_CI".
+crash_on_warning = os.getenv("CI", "false") == "true"
+

 class ModelTpServer:
     def __init__(
         self,
@@ -198,7 +204,7 @@ class ModelTpServer:
         self.new_token_ratio = self.min_new_token_ratio
         self.new_token_ratio_decay = global_config.new_token_ratio_decay

-    def exposed_step(self, recv_reqs):
+    def exposed_step(self, recv_reqs: List):
         try:
             # Recv requests
             for recv_req in recv_reqs:
@@ -247,7 +253,7 @@ class ModelTpServer:
                 # Print stats
                 if self.tp_rank == 0 and self.decode_forward_ct % 40 == 0:
-                    self.print_stats()
+                    self.print_decode_stats()

                 if self.running_batch.is_empty():
                     self.running_batch = None
@@ -259,7 +265,7 @@ class ModelTpServer:
             self.check_memory()
             self.new_token_ratio = global_config.init_new_token_ratio

-    def print_stats(self):
+    def print_decode_stats(self):
         num_used = self.max_total_num_tokens - (
             self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
         )
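The renamed `print_decode_stats` derives usage by subtraction: tokens that are neither free in the KV pool nor evictable from the tree cache must be pinned by in-flight requests. A worked instance with invented numbers:

```python
max_total_num_tokens = 10_000
available_size = 6_400  # free slots in token_to_kv_pool
evictable_size = 1_600  # cached but currently unreferenced tokens in the tree cache
num_used = max_total_num_tokens - (available_size + evictable_size)
print(num_used)         # 2000 tokens held by running requests
```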
@@ -276,7 +282,6 @@ class ModelTpServer:
         )

     def check_memory(self):
-        crash = os.getenv("CI", "false") == "true"
         available_size = (
             self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
         )
@@ -286,7 +291,7 @@ class ModelTpServer:
                 f"available_size={available_size}, max_total_num_tokens={self.max_total_num_tokens}\n"
                 "KV cache pool leak detected!"
             )
-            exit(1) if crash else None
+            exit(1) if crash_on_warning else None

         if len(self.req_to_token_pool.free_slots) != self.req_to_token_pool.size:
             warnings.warn(
@@ -295,7 +300,7 @@ class ModelTpServer:
                 f"total slots={self.req_to_token_pool.size}\n"
                 "Memory pool leak detected!"
             )
-            exit(1) if crash else None
+            exit(1) if crash_on_warning else None

     def handle_generate_request(
         self,
@@ -511,7 +516,14 @@ class ModelTpServer:
             self.handle_finished_requests(batch)

-    def add_logprob_return_values(self, i, req: Req, pt, next_token_ids, output):
+    def add_logprob_return_values(
+        self,
+        i,
+        req: Req,
+        pt: int,
+        next_token_ids: List[int],
+        output: LogitProcessorOutput,
+    ):
         if req.normalized_prompt_logprob is None:
             req.normalized_prompt_logprob = output.normalized_prompt_logprobs[i]
@@ -786,7 +798,11 @@ def run_tp_server(

 def launch_tp_servers(
-    gpu_ids, tp_rank_range, server_args, nccl_port, model_overide_args
+    gpu_ids: List[int],
+    tp_rank_range: List[int],
+    server_args: ServerArgs,
+    nccl_port: int,
+    model_overide_args: dict,
 ):
     """Launch multiple tensor parallel servers."""
     procs = []
@@ -801,7 +817,9 @@ def launch_tp_servers(
     return procs


-def broadcast_recv_input(data, rank, dist_group):
+def broadcast_recv_input(
+    data: Any, rank: int, dist_group: torch.distributed.ProcessGroup
+):
     """Broadcast inputs from rank=0 to all other ranks with torch.dist backend."""
     if rank == 0:
...
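The newly annotated `broadcast_recv_input` ships arbitrary Python objects from rank 0 to all ranks; its body is truncated here. One common way to implement that contract is to pickle the object and issue two broadcasts, size first, then payload. This is a hedged sketch of the pattern under that assumption, not the commit's exact body:

```python
import pickle
from typing import Any

import torch
import torch.distributed as dist

def broadcast_pyobj_sketch(
    data: Any, rank: int, dist_group: torch.distributed.ProcessGroup
) -> Any:
    """Send a picklable object from rank 0 to every rank (two broadcasts)."""
    if rank == 0:
        payload = torch.tensor(bytearray(pickle.dumps(data)), dtype=torch.uint8)
        size = torch.tensor([payload.numel()], dtype=torch.long)
        dist.broadcast(size, src=0, group=dist_group)     # 1) announce length
        dist.broadcast(payload, src=0, group=dist_group)  # 2) send the bytes
        return data
    size = torch.tensor([0], dtype=torch.long)
    dist.broadcast(size, src=0, group=dist_group)
    payload = torch.empty(size.item(), dtype=torch.uint8)
    dist.broadcast(payload, src=0, group=dist_group)
    return pickle.loads(bytes(payload.tolist()))
```

Actually running this requires an initialized process group, e.g. `dist.init_process_group("gloo", rank=..., world_size=...)`.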
 from abc import ABC, abstractmethod
+from typing import Callable


 class BasePrefixCache(ABC):
@@ -25,7 +26,7 @@ class BasePrefixCache(ABC):
         pass

     @abstractmethod
-    def evict(self, num_tokens, evict_callback):
+    def evict(self, num_tokens: int, evict_callback: Callable):
         pass

     @abstractmethod
@@ -41,7 +42,7 @@ class BasePrefixCache(ABC):
         pass

     def total_size(self):
-        raise NotImplementedError
+        raise NotImplementedError()

     def pretty_print(self):
-        raise NotImplementedError
+        raise NotImplementedError()
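The loose `Callable` annotation on `evict` reflects that concrete caches hand each evicted entry straight to a free function (RadixCache below passes token slots to `evict_callback`). A hypothetical, self-contained illustration of that contract (the `DummyCache` class and `freed` list are ours):

```python
from typing import Callable, List

class DummyCache:
    """Toy cache illustrating the evict(num_tokens, evict_callback) contract."""
    def __init__(self):
        self.stored: List[int] = [10, 11, 12, 13]

    def evict(self, num_tokens: int, evict_callback: Callable):
        # Hand evicted entries to the callback, e.g. token_to_kv_pool.free.
        for _ in range(min(num_tokens, len(self.stored))):
            evict_callback(self.stored.pop())

freed: List[int] = []
cache = DummyCache()
cache.evict(2, freed.append)
print(freed)  # [13, 12]
```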
+from __future__ import annotations
+
 """Cache for chunked prefill, used when RadixCache is disabled."""

-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Callable

 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
+from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool

 if TYPE_CHECKING:
     from sglang.srt.managers.schedule_batch import Req
@@ -15,7 +18,9 @@ class ChunkCacheEntry:
 class ChunkCache(BasePrefixCache):
-    def __init__(self, req_to_token_pool, token_to_kv_pool):
+    def __init__(
+        self, req_to_token_pool: ReqToTokenPool, token_to_kv_pool: BaseTokenToKVPool
+    ):
         self.disable = True
         self.req_to_token_pool = req_to_token_pool
         self.token_to_kv_pool = token_to_kv_pool
@@ -32,7 +37,7 @@ class ChunkCache(BasePrefixCache):
         entry = self.entries[rid]
         return entry.value, entry

-    def cache_finished_req(self, req: "Req", token_ids=None):
+    def cache_finished_req(self, req: Req, token_ids: Optional[List[int]] = None):
         if token_ids is None:
             token_ids = (req.origin_input_ids + req.output_ids)[:-1]
@@ -45,7 +50,7 @@ class ChunkCache(BasePrefixCache):
         if req.rid in self.entries:
             del self.entries[req.rid]

-    def cache_unfinished_req(self, req: "Req", token_ids=None):
+    def cache_unfinished_req(self, req: Req, token_ids: Optional[List[int]] = None):
         if token_ids is None:
             token_ids = req.fill_ids
@@ -64,7 +69,7 @@ class ChunkCache(BasePrefixCache):
     def insert(self):
         raise NotImplementedError

-    def evict(self, num_tokens, evict_callback):
+    def evict(self, num_tokens: int, evict_callback: Callable):
         pass

     def inc_lock_ref(self, node):
...
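Both `cache_finished_req` implementations (here and in RadixCache below) default `token_ids` to all of the request's tokens except the last one, presumably because the final sampled token was never forwarded through the model and so has no KV entry to cache yet. The slicing, concretely:

```python
origin_input_ids = [1, 2, 3]  # prompt token ids
output_ids = [4, 5, 6]        # generated token ids; 6 was sampled last
token_ids = (origin_input_ids + output_ids)[:-1]
print(token_ids)              # [1, 2, 3, 4, 5]
```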
@@ -16,7 +16,7 @@ limitations under the License.
 """Memory pool."""

 import logging
-from typing import List
+from typing import List, Union

 import torch
@@ -42,7 +42,7 @@ class ReqToTokenPool:
         return select_index

-    def free(self, free_index):
+    def free(self, free_index: Union[int, List[int]]):
         if isinstance(free_index, (int,)):
             self.free_slots.append(free_index)
         else:
...
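The `Union[int, List[int]]` annotation documents that `free` accepts either a single slot or a batch of slots. A minimal standalone mock of that behavior (the `size`/`free_slots` fields mirror the diff; the rest of the class body is our sketch):

```python
from typing import List, Union

class MiniReqToTokenPool:
    def __init__(self, size: int):
        self.size = size
        self.free_slots: List[int] = list(range(size))

    def alloc(self) -> int:
        return self.free_slots.pop()

    def free(self, free_index: Union[int, List[int]]):
        if isinstance(free_index, int):
            self.free_slots.append(free_index)   # single slot
        else:
            self.free_slots.extend(free_index)   # batch of slots

pool = MiniReqToTokenPool(4)
a, b = pool.alloc(), pool.alloc()
pool.free(a)    # int branch
pool.free([b])  # list branch
assert len(pool.free_slots) == pool.size
```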
+from __future__ import annotations
+
 """
 Copyright 2023-2024 SGLang Team
 Licensed under the Apache License, Version 2.0 (the "License");
@@ -25,6 +27,7 @@ from typing import TYPE_CHECKING
 import torch

 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
+from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool

 if TYPE_CHECKING:
     from sglang.srt.managers.schedule_batch import Req
@@ -43,7 +46,7 @@ class TreeNode:
         return self.last_access_time < other.last_access_time


-def _key_match(key0, key1):
+def _key_match(key0: List, key1: List):
     i = 0
     for k0, k1 in zip(key0, key1):
         if k0 != k1:
@@ -53,7 +56,12 @@ def _key_match(key0, key1):

 class RadixCache(BasePrefixCache):
-    def __init__(self, req_to_token_pool, token_to_kv_pool, disable: bool = False):
+    def __init__(
+        self,
+        req_to_token_pool: ReqToTokenPool,
+        token_to_kv_pool: BaseTokenToKVPool,
+        disable: bool = False,
+    ):
         self.req_to_token_pool = req_to_token_pool
         self.token_to_kv_pool = token_to_kv_pool
         self.disable = disable
@@ -68,7 +76,7 @@ class RadixCache(BasePrefixCache):
         self.root_node.lock_ref = 1
         self.evictable_size_ = 0

-    def match_prefix(self, key, **kwargs):
+    def match_prefix(self, key: List, **kwargs):
         if self.disable:
             return [], self.root_node
@@ -81,7 +89,7 @@ class RadixCache(BasePrefixCache):
             value = torch.tensor([], dtype=torch.int32)
         return value, last_node[0]

-    def insert(self, key, value=None):
+    def insert(self, key: List, value=None):
         if self.disable:
             return 0
@@ -89,7 +97,7 @@ class RadixCache(BasePrefixCache):
             value = [x for x in key]
         return self._insert_helper(self.root_node, key, value)

-    def cache_finished_req(self, req: "Req", token_ids=None):
+    def cache_finished_req(self, req: Req, token_ids: Optional[List[int]] = None):
         """Cache request when it finishes."""
         if token_ids is None:
             token_ids = (req.origin_input_ids + req.output_ids)[:-1]
@@ -110,7 +118,7 @@ class RadixCache(BasePrefixCache):
         self.req_to_token_pool.free(req.req_pool_idx)
         self.dec_lock_ref(req.last_node)

-    def cache_unfinished_req(self, req: "Req", token_ids=None):
+    def cache_unfinished_req(self, req: Req, token_ids: Optional[List[int]] = None):
         """Cache request when it is unfinished."""
         if self.disable:
             return
@@ -145,7 +153,7 @@ class RadixCache(BasePrefixCache):
     def total_size(self):
         return self._total_size_helper(self.root_node)

-    def evict(self, num_tokens, evict_callback):
+    def evict(self, num_tokens: int, evict_callback: Callable):
         if self.disable:
             return
@@ -199,7 +207,9 @@ class RadixCache(BasePrefixCache):
     ##### Internal Helper Functions #####

-    def _match_prefix_helper(self, node, key, value, last_node):
+    def _match_prefix_helper(
+        self, node: TreeNode, key: List, value, last_node: TreeNode
+    ):
         node.last_access_time = time.time()
         if len(key) == 0:
             return
@@ -216,7 +226,7 @@ class RadixCache(BasePrefixCache):
             last_node[0] = child
         self._match_prefix_helper(child, key[prefix_len:], value, last_node)

-    def _split_node(self, key, child: TreeNode, split_len):
+    def _split_node(self, key, child: TreeNode, split_len: int):
         # new_node -> child
         new_node = TreeNode()
         new_node.children = {key[split_len:][0]: child}
@@ -230,7 +240,7 @@ class RadixCache(BasePrefixCache):
         new_node.parent.children[key[:split_len][0]] = new_node
         return new_node

-    def _insert_helper(self, node, key, value):
+    def _insert_helper(self, node: TreeNode, key: List, value):
         node.last_access_time = time.time()
         if len(key) == 0:
             return 0
@@ -261,7 +271,7 @@ class RadixCache(BasePrefixCache):
             self.evictable_size_ += len(value)
         return 0

-    def _print_helper(self, node: TreeNode, indent):
+    def _print_helper(self, node: TreeNode, indent: int):
         for _, child in node.children.items():
             print(" " * indent, len(child.key), child.key[:10], f"r={child.lock_ref}")
             self._print_helper(child, indent=indent + 2)
@@ -273,7 +283,7 @@ class RadixCache(BasePrefixCache):
             del node.parent.children[k]
         self.evictable_size_ -= len(node.key)

-    def _total_size_helper(self, node):
+    def _total_size_helper(self, node: TreeNode):
         x = len(node.value)
         for child in node.children.values():
             x += self._total_size_helper(child)
...
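The `_key_match` hunk above is truncated just after the mismatch test. Completed, it is a plain shared-prefix counter, the primitive the whole radix cache is built on; the `break`/`i += 1`/`return i` tail here is our reconstruction of the truncated body:

```python
from typing import List

def _key_match(key0: List, key1: List) -> int:
    i = 0
    for k0, k1 in zip(key0, key1):
        if k0 != k1:
            break  # stop at the first divergence
        i += 1
    return i       # number of leading token ids the two keys share

print(_key_match([1, 2, 3, 9], [1, 2, 3, 4]))  # 3
print(_key_match([7], [1, 2]))                 # 0
```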