"G" did not exist on "b89700812254c0d6a68e848362a13ff15a28ade2"
Unverified commit 9dae4078, authored by Lianmin Zheng, committed by GitHub

Improve type annotation (#1029)

parent fcc0f5ed
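
The diff below is mechanical: parameter type hints are added to existing signatures in the scheduler, TP worker, and mem_cache modules without changing runtime behavior, so static checkers and IDEs can validate call sites. As a minimal sketch of the before/after pattern (the function here is made up for illustration, not taken from the patch):

    from typing import Dict, List

    # Before: the checker knows nothing about the parameters.
    # def schedule(reqs, priorities):
    #     ...

    # After: same body, but misuse (e.g. passing a dict where a list is
    # expected) is now flagged by a type checker.
    def schedule(reqs: List[str], priorities: Dict[str, int]) -> List[str]:
        return sorted(reqs, key=lambda r: priorities.get(r, 0))

    print(schedule(["a", "b"], {"b": 1}))  # ['a', 'b']
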
@@ -18,13 +18,15 @@ limitations under the License.
 import random
 from collections import defaultdict
 from contextlib import contextmanager
-from typing import List
+from typing import Dict, List

 from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
+from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
+from sglang.srt.mem_cache.radix_cache import TreeNode


 class PolicyScheduler:
-    def __init__(self, policy, tree_cache):
+    def __init__(self, policy: str, tree_cache: BasePrefixCache):
         if tree_cache.disable and policy in ["lpm", "dfs-weight"]:
             # LPM and DFS-weight is meaningless when the tree cache is disabled.
             policy = "fcfs"
@@ -72,12 +74,18 @@ class PolicyScheduler:
         else:
             raise ValueError(f"Unknown schedule_policy: {self.policy}")

-    def calc_weight(self, cur_node, node_to_weight):
+    def calc_weight(self, cur_node: TreeNode, node_to_weight: Dict):
         for child in cur_node.children.values():
             self.calc_weight(child, node_to_weight)
             node_to_weight[cur_node] += node_to_weight[child]

-    def get_dfs_priority(self, cur_node, node_to_priority, last_node_to_reqs, q):
+    def get_dfs_priority(
+        self,
+        cur_node: TreeNode,
+        node_to_priority: Dict,
+        last_node_to_reqs: Dict,
+        q: List,
+    ):
         childs = [child for child in cur_node.children.values()]
         childs.sort(key=lambda x: -node_to_priority[x])
         for child in childs:
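
For readers unfamiliar with the DFS-weight policy, the newly annotated `node_to_weight: Dict` is a bottom-up accumulator over the radix tree: each node's weight is increased by the weights of its children. A small self-contained sketch of that accumulation (the toy `Node` class and the initial weights are illustrative, not the sglang types):

    from collections import defaultdict

    class Node:
        # Toy stand-in for sglang's TreeNode: only a name and a children dict.
        def __init__(self, name):
            self.name = name
            self.children = {}

    def calc_weight(cur_node, node_to_weight):
        # Recurse first, then fold each child's weight into the parent.
        for child in cur_node.children.values():
            calc_weight(child, node_to_weight)
            node_to_weight[cur_node] += node_to_weight[child]

    root = Node("root")
    a, b = Node("a"), Node("b")
    root.children = {"a": a, "b": b}

    # Every node starts with weight 1 (e.g., one queued request hanging off it).
    weights = defaultdict(lambda: 1)
    calc_weight(root, weights)
    print(weights[root])  # 3: root's own 1 plus the two leaves
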
@@ -88,10 +96,10 @@ class PolicyScheduler:
 class PrefillAdder:
     def __init__(
         self,
-        tree_cache,
-        rem_total_tokens,
-        rem_input_tokens,
-        rem_chunk_tokens,
+        tree_cache: BasePrefixCache,
+        rem_total_tokens: int,
+        rem_input_tokens: int,
+        rem_chunk_tokens: int,
     ):
         self.tree_cache = tree_cache
         self.rem_total_tokens = rem_total_tokens
@@ -151,7 +159,7 @@ class PrefillAdder:
         return req if truncated else None

     @contextmanager
-    def _lock_node(self, last_node):
+    def _lock_node(self, last_node: TreeNode):
         try:
             delta = self.tree_cache.inc_lock_ref(last_node)
             self.rem_total_tokens += delta
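
The hunk is cut off before the `yield`/`finally` part of `_lock_node`. As a reminder of how a `@contextmanager` guard of this kind is usually structured, here is a generic sketch with a hypothetical `Cache` interface mirroring `inc_lock_ref`/`dec_lock_ref`; it is not the exact body of this method:

    from contextlib import contextmanager

    class Cache:
        # Hypothetical lock-ref interface for the sketch.
        def __init__(self):
            self.lock_ref = 0
        def inc_lock_ref(self, node):
            self.lock_ref += 1
            return 0
        def dec_lock_ref(self, node):
            self.lock_ref -= 1
            return 0

    @contextmanager
    def lock_node(cache, node):
        try:
            cache.inc_lock_ref(node)   # pin the node while the caller works with it
            yield
        finally:
            cache.dec_lock_ref(node)   # always release, even if the body raises

    cache = Cache()
    with lock_node(cache, node=object()):
        assert cache.lock_ref == 1
    assert cache.lock_ref == 0
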
......
@@ -21,15 +21,17 @@ import os
 import pickle
 import time
 import warnings
-from typing import List, Optional, Union
+from typing import Any, List, Optional, Union

 import torch
 import torch.distributed
+import torch.distributed as dist

 from sglang.global_config import global_config
 from sglang.srt.constrained.fsm_cache import FSMCache
 from sglang.srt.constrained.jump_forward import JumpForwardCache
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
+from sglang.srt.layers.logits_processor import LogitProcessorOutput
 from sglang.srt.managers.io_struct import (
     AbortReq,
     BatchEmbeddingOut,
@@ -62,6 +64,10 @@ from sglang.utils import get_exception_traceback
 logger = logging.getLogger(__name__)

+# TODO: Rename "CI" to "SGLANG_IS_IN_CI".
+crash_on_warning = os.getenv("CI", "false") == "true"


 class ModelTpServer:
     def __init__(
         self,
@@ -198,7 +204,7 @@ class ModelTpServer:
         self.new_token_ratio = self.min_new_token_ratio
         self.new_token_ratio_decay = global_config.new_token_ratio_decay

-    def exposed_step(self, recv_reqs):
+    def exposed_step(self, recv_reqs: List):
         try:
             # Recv requests
             for recv_req in recv_reqs:
@@ -247,7 +253,7 @@
                     # Print stats
                     if self.tp_rank == 0 and self.decode_forward_ct % 40 == 0:
-                        self.print_stats()
+                        self.print_decode_stats()

                     if self.running_batch.is_empty():
                         self.running_batch = None
@@ -259,7 +265,7 @@
                 self.check_memory()
                 self.new_token_ratio = global_config.init_new_token_ratio

-    def print_stats(self):
+    def print_decode_stats(self):
         num_used = self.max_total_num_tokens - (
             self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
         )
@@ -276,7 +282,6 @@
         )

     def check_memory(self):
-        crash = os.getenv("CI", "false") == "true"
         available_size = (
             self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
         )
@@ -286,7 +291,7 @@
                 f"available_size={available_size}, max_total_num_tokens={self.max_total_num_tokens}\n"
                 "KV cache pool leak detected!"
             )
-            exit(1) if crash else None
+            exit(1) if crash_on_warning else None

         if len(self.req_to_token_pool.free_slots) != self.req_to_token_pool.size:
             warnings.warn(
@@ -295,7 +300,7 @@
                 f"total slots={self.req_to_token_pool.size}\n"
                 "Memory pool leak detected!"
             )
-            exit(1) if crash else None
+            exit(1) if crash_on_warning else None

     def handle_generate_request(
         self,
@@ -511,7 +516,14 @@ class ModelTpServer:
         self.handle_finished_requests(batch)

-    def add_logprob_return_values(self, i, req: Req, pt, next_token_ids, output):
+    def add_logprob_return_values(
+        self,
+        i,
+        req: Req,
+        pt: int,
+        next_token_ids: List[int],
+        output: LogitProcessorOutput,
+    ):
         if req.normalized_prompt_logprob is None:
             req.normalized_prompt_logprob = output.normalized_prompt_logprobs[i]
@@ -786,7 +798,11 @@ def run_tp_server(
 def launch_tp_servers(
-    gpu_ids, tp_rank_range, server_args, nccl_port, model_overide_args
+    gpu_ids: List[int],
+    tp_rank_range: List[int],
+    server_args: ServerArgs,
+    nccl_port: int,
+    model_overide_args: dict,
 ):
     """Launch multiple tensor parallel servers."""
     procs = []
@@ -801,7 +817,9 @@ def launch_tp_servers(
     return procs


-def broadcast_recv_input(data, rank, dist_group):
+def broadcast_recv_input(
+    data: Any, rank: int, dist_group: torch.distributed.ProcessGroup
+):
     """Broadcast inputs from rank=0 to all other ranks with torch.dist backend."""
     if rank == 0:
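
The rest of `broadcast_recv_input` is not shown in this hunk. For the idea behind the new `dist_group: torch.distributed.ProcessGroup` annotation, here is a simplified stand-in that broadcasts an arbitrary Python object from rank 0; it uses `torch.distributed.broadcast_object_list` for brevity and is not the actual sglang implementation:

    from typing import Any

    import torch.distributed as dist

    def broadcast_recv_input_sketch(
        data: Any, rank: int, dist_group: dist.ProcessGroup
    ) -> Any:
        # Requires dist.init_process_group(...) to have been called first.
        # Rank 0 owns the data; every other rank passes a placeholder and
        # receives rank 0's object once the collective completes.
        obj = [data] if rank == 0 else [None]
        dist.broadcast_object_list(obj, src=0, group=dist_group)
        return obj[0]
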
......
 from abc import ABC, abstractmethod
+from typing import Callable


 class BasePrefixCache(ABC):
@@ -25,7 +26,7 @@ class BasePrefixCache(ABC):
         pass

     @abstractmethod
-    def evict(self, num_tokens, evict_callback):
+    def evict(self, num_tokens: int, evict_callback: Callable):
         pass

     @abstractmethod
@@ -41,7 +42,7 @@
         pass

     def total_size(self):
-        raise NotImplementedError
+        raise NotImplementedError()

     def pretty_print(self):
-        raise NotImplementedError
+        raise NotImplementedError()
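
Since `evict` now advertises `evict_callback: Callable`, a concrete cache is expected to hand each evicted entry to that callback (in sglang the callback frees the corresponding KV-cache slots). A rough, hypothetical illustration of the contract, assuming the callback simply receives the evicted token ids:

    from typing import Callable, List

    class ToyCache:
        # Hypothetical concrete cache holding evictable entries as token-id lists.
        def __init__(self):
            self.entries: List[List[int]] = [[1, 2, 3], [4, 5], [6]]

        def evict(self, num_tokens: int, evict_callback: Callable[[List[int]], None]):
            freed = 0
            while self.entries and freed < num_tokens:
                entry = self.entries.pop()   # pick a victim entry
                evict_callback(entry)        # let the caller release its slots
                freed += len(entry)

    cache = ToyCache()
    cache.evict(3, evict_callback=lambda ids: print("freeing", ids))
    # freeing [6]
    # freeing [4, 5]
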
+from __future__ import annotations

 """Cache for chunked prefill, used when RadixCache is disabled."""

-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Callable

 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
+from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool

 if TYPE_CHECKING:
     from sglang.srt.managers.schedule_batch import Req
@@ -15,7 +18,9 @@ class ChunkCacheEntry:
 class ChunkCache(BasePrefixCache):
-    def __init__(self, req_to_token_pool, token_to_kv_pool):
+    def __init__(
+        self, req_to_token_pool: ReqToTokenPool, token_to_kv_pool: BaseTokenToKVPool
+    ):
         self.disable = True
         self.req_to_token_pool = req_to_token_pool
         self.token_to_kv_pool = token_to_kv_pool
@@ -32,7 +37,7 @@ class ChunkCache(BasePrefixCache):
         entry = self.entries[rid]
         return entry.value, entry

-    def cache_finished_req(self, req: "Req", token_ids=None):
+    def cache_finished_req(self, req: Req, token_ids: Optional[List[int]] = None):
         if token_ids is None:
             token_ids = (req.origin_input_ids + req.output_ids)[:-1]
@@ -45,7 +50,7 @@
         if req.rid in self.entries:
             del self.entries[req.rid]

-    def cache_unfinished_req(self, req: "Req", token_ids=None):
+    def cache_unfinished_req(self, req: Req, token_ids: Optional[List[int]] = None):
         if token_ids is None:
             token_ids = req.fill_ids
@@ -64,7 +69,7 @@
     def insert(self):
         raise NotImplementedError

-    def evict(self, num_tokens, evict_callback):
+    def evict(self, num_tokens: int, evict_callback: Callable):
         pass

     def inc_lock_ref(self, node):
......
@@ -16,7 +16,7 @@ limitations under the License.
 """Memory pool."""

 import logging
-from typing import List
+from typing import List, Union

 import torch
@@ -42,7 +42,7 @@ class ReqToTokenPool:
         return select_index

-    def free(self, free_index):
+    def free(self, free_index: Union[int, List[int]]):
         if isinstance(free_index, (int,)):
             self.free_slots.append(free_index)
         else:
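
The `Union[int, List[int]]` hint matches the runtime `isinstance` dispatch: a single slot index goes back on the free list with `append`, while a batch of indices goes back in one call. A small self-contained sketch of that pattern; the `extend` branch is my assumption of how the elided `else` continues, and `FreeSlotPool` is a toy stand-in, not the real `ReqToTokenPool`:

    from typing import List, Union

    class FreeSlotPool:
        # Toy free-list, analogous in spirit to ReqToTokenPool.free_slots.
        def __init__(self, size: int):
            self.free_slots: List[int] = list(range(size))

        def alloc(self) -> int:
            return self.free_slots.pop()

        def free(self, free_index: Union[int, List[int]]) -> None:
            if isinstance(free_index, int):
                self.free_slots.append(free_index)   # single slot
            else:
                self.free_slots.extend(free_index)   # batch of slots

    pool = FreeSlotPool(4)
    a, b = pool.alloc(), pool.alloc()
    pool.free(a)        # int path
    pool.free([b])      # list path
    assert sorted(pool.free_slots) == [0, 1, 2, 3]
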
......
+from __future__ import annotations

 """
 Copyright 2023-2024 SGLang Team

 Licensed under the Apache License, Version 2.0 (the "License");
@@ -25,6 +27,7 @@ from typing import TYPE_CHECKING
 import torch

 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
+from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool

 if TYPE_CHECKING:
     from sglang.srt.managers.schedule_batch import Req
@@ -43,7 +46,7 @@ class TreeNode:
         return self.last_access_time < other.last_access_time


-def _key_match(key0, key1):
+def _key_match(key0: List, key1: List):
     i = 0
     for k0, k1 in zip(key0, key1):
         if k0 != k1:
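
The hunk truncates the body of `_key_match`. What it computes is the length of the shared prefix of two token-id lists, the core primitive behind radix-tree prefix matching. A self-contained sketch of that primitive, equivalent in intent but not a copy of the elided lines:

    from typing import List

    def shared_prefix_len(key0: List[int], key1: List[int]) -> int:
        # Walk both keys in lockstep and stop at the first mismatch.
        i = 0
        for k0, k1 in zip(key0, key1):
            if k0 != k1:
                break
            i += 1
        return i

    assert shared_prefix_len([1, 2, 3, 9], [1, 2, 3, 4, 5]) == 3
    assert shared_prefix_len([7], [1, 2]) == 0
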
@@ -53,7 +56,12 @@ def _key_match(key0, key1):
 class RadixCache(BasePrefixCache):
-    def __init__(self, req_to_token_pool, token_to_kv_pool, disable: bool = False):
+    def __init__(
+        self,
+        req_to_token_pool: ReqToTokenPool,
+        token_to_kv_pool: BaseTokenToKVPool,
+        disable: bool = False,
+    ):
         self.req_to_token_pool = req_to_token_pool
         self.token_to_kv_pool = token_to_kv_pool
         self.disable = disable
@@ -68,7 +76,7 @@ class RadixCache(BasePrefixCache):
         self.root_node.lock_ref = 1
         self.evictable_size_ = 0

-    def match_prefix(self, key, **kwargs):
+    def match_prefix(self, key: List, **kwargs):
         if self.disable:
             return [], self.root_node
@@ -81,7 +89,7 @@
             value = torch.tensor([], dtype=torch.int32)
         return value, last_node[0]

-    def insert(self, key, value=None):
+    def insert(self, key: List, value=None):
         if self.disable:
             return 0
@@ -89,7 +97,7 @@
             value = [x for x in key]
         return self._insert_helper(self.root_node, key, value)

-    def cache_finished_req(self, req: "Req", token_ids=None):
+    def cache_finished_req(self, req: Req, token_ids: Optional[List[int]] = None):
         """Cache request when it finishes."""
         if token_ids is None:
             token_ids = (req.origin_input_ids + req.output_ids)[:-1]
@@ -110,7 +118,7 @@
         self.req_to_token_pool.free(req.req_pool_idx)
         self.dec_lock_ref(req.last_node)

-    def cache_unfinished_req(self, req: "Req", token_ids=None):
+    def cache_unfinished_req(self, req: Req, token_ids: Optional[List[int]] = None):
         """Cache request when it is unfinished."""
         if self.disable:
             return
@@ -145,7 +153,7 @@
     def total_size(self):
         return self._total_size_helper(self.root_node)

-    def evict(self, num_tokens, evict_callback):
+    def evict(self, num_tokens: int, evict_callback: Callable):
         if self.disable:
             return
@@ -199,7 +207,9 @@
     ##### Internal Helper Functions #####

-    def _match_prefix_helper(self, node, key, value, last_node):
+    def _match_prefix_helper(
+        self, node: TreeNode, key: List, value, last_node: TreeNode
+    ):
         node.last_access_time = time.time()
         if len(key) == 0:
             return
@@ -216,7 +226,7 @@
                 last_node[0] = child
                 self._match_prefix_helper(child, key[prefix_len:], value, last_node)

-    def _split_node(self, key, child: TreeNode, split_len):
+    def _split_node(self, key, child: TreeNode, split_len: int):
         # new_node -> child
         new_node = TreeNode()
         new_node.children = {key[split_len:][0]: child}
@@ -230,7 +240,7 @@
         new_node.parent.children[key[:split_len][0]] = new_node
         return new_node

-    def _insert_helper(self, node, key, value):
+    def _insert_helper(self, node: TreeNode, key: List, value):
         node.last_access_time = time.time()
         if len(key) == 0:
             return 0
@@ -261,7 +271,7 @@
             self.evictable_size_ += len(value)
         return 0

-    def _print_helper(self, node: TreeNode, indent):
+    def _print_helper(self, node: TreeNode, indent: int):
         for _, child in node.children.items():
             print(" " * indent, len(child.key), child.key[:10], f"r={child.lock_ref}")
             self._print_helper(child, indent=indent + 2)
@@ -273,7 +283,7 @@
             del node.parent.children[k]
         self.evictable_size_ -= len(node.key)

-    def _total_size_helper(self, node):
+    def _total_size_helper(self, node: TreeNode):
         x = len(node.value)
         for child in node.children.values():
             x += self._total_size_helper(child)
......