"src/vscode:/vscode.git/clone" did not exist on "303efd2b8d9949758004158c89b6c29a816150a2"
Unverified Commit 05c9bc89 authored by Liangsheng Yin's avatar Liangsheng Yin Committed by GitHub
Browse files

[minor] simplify the `TokenToKVPoolAllocator` (#7414)

parent b7a2df0a
......@@ -21,13 +21,11 @@ Life cycle of a request in the decode server
from __future__ import annotations
import logging
import os
from collections import deque
from dataclasses import dataclass
from http import HTTPStatus
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
import numpy as np
import torch
from torch.distributed import ProcessGroup
......@@ -47,12 +45,9 @@ from sglang.srt.disaggregation.utils import (
prepare_abort,
)
from sglang.srt.managers.schedule_batch import FINISH_ABORT, ScheduleBatch
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.memory_pool import (
KVCache,
ReqToTokenPool,
TokenToKVPoolAllocator,
)
from sglang.srt.mem_cache.memory_pool import KVCache, ReqToTokenPool
from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
from sglang.srt.utils import require_mlp_sync
......@@ -141,7 +136,7 @@ class DecodePreallocQueue:
def __init__(
self,
req_to_token_pool: ReqToTokenPool,
token_to_kv_pool_allocator: TokenToKVPoolAllocator,
token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
draft_token_to_kv_pool: Optional[KVCache],
req_to_metadata_buffer_idx_allocator: ReqToMetadataIdxAllocator,
metadata_buffers: MetadataBuffers,
......
......@@ -25,7 +25,6 @@ from collections import deque
from http import HTTPStatus
from typing import TYPE_CHECKING, List, Optional
import numpy as np
import torch
from sglang.srt.disaggregation.base import BaseKVManager, KVPoll
......
......@@ -18,12 +18,13 @@ import logging
import math
import threading
from queue import Empty, Full, PriorityQueue, Queue
from typing import List, Optional
from typing import TYPE_CHECKING, List, Optional
import torch
from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator
from sglang.srt.mem_cache.memory_pool_host import HostKVCache
if TYPE_CHECKING:
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
from sglang.srt.mem_cache.memory_pool_host import HostKVCache
logger = logging.getLogger(__name__)
......@@ -163,7 +164,7 @@ class HiCacheController:
def __init__(
self,
token_to_kv_pool_allocator: TokenToKVPoolAllocator,
token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
mem_pool_host: HostKVCache,
page_size: int,
load_cache_event: threading.Event = None,
......
......@@ -54,9 +54,10 @@ from sglang.srt.disaggregation.decode_schedule_batch_mixin import (
)
from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_rank
from sglang.srt.layers.multimodal import gpu_tensor_hash
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.chunk_cache import ChunkCache
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
from sglang.srt.metrics.collector import TimeStats
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
......@@ -810,7 +811,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
# Request, memory pool, and cache
reqs: List[Req]
req_to_token_pool: ReqToTokenPool = None
token_to_kv_pool_allocator: TokenToKVPoolAllocator = None
token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator = None
tree_cache: BasePrefixCache = None
# Batch configs
......@@ -907,7 +908,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
cls,
reqs: List[Req],
req_to_token_pool: ReqToTokenPool,
token_to_kv_pool_allocator: TokenToKVPoolAllocator,
token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
tree_cache: BasePrefixCache,
model_config: ModelConfig,
enable_overlap: bool,
......
from __future__ import annotations
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -18,15 +20,17 @@ import random
from collections import defaultdict
from contextlib import contextmanager
from enum import Enum, auto
from typing import Dict, List, Optional, Set, Union
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Union
import torch
from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator
from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode
if TYPE_CHECKING:
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
# Clip the estimation of max_new_tokens for the request whose max_new_tokens is very large.
# This can prevent the server from being too conservative.
# Note that this only clips the estimation in the scheduler but does not change the stop
......@@ -265,7 +269,7 @@ class PrefillAdder:
self,
page_size: int,
tree_cache: BasePrefixCache,
token_to_kv_pool_allocator: TokenToKVPoolAllocator,
token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
running_batch: ScheduleBatch,
new_token_ratio: float,
rem_input_tokens: int,
......
......@@ -23,7 +23,6 @@ import time
from collections import defaultdict, deque
from concurrent import futures
from dataclasses import dataclass
from http import HTTPStatus
from pathlib import Path
from types import SimpleNamespace
from typing import Dict, List, Optional, Tuple, Union
......
......@@ -35,7 +35,8 @@ from sglang.srt.managers.io_struct import (
UpdateWeightsFromTensorReqInput,
)
from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_args_dict
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.server_args import ServerArgs
......@@ -57,7 +58,7 @@ class TpModelWorker:
nccl_port: int,
is_draft_worker: bool = False,
req_to_token_pool: Optional[ReqToTokenPool] = None,
token_to_kv_pool_allocator: Optional[TokenToKVPoolAllocator] = None,
token_to_kv_pool_allocator: Optional[BaseTokenToKVPoolAllocator] = None,
):
# Parse args
self.tp_size = server_args.tp_size
......
from __future__ import annotations
"""
Copyright 2025 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
......@@ -17,13 +19,132 @@ limitations under the License.
Page-aligned memory pool.
"""
import abc
from typing import TYPE_CHECKING
import torch
import triton
import triton.language as tl
from sglang.srt.mem_cache.memory_pool import KVCache
from sglang.srt.utils import get_bool_env_var, next_power_of_2
if TYPE_CHECKING:
from sglang.srt.mem_cache.memory_pool import KVCache
class BaseTokenToKVPoolAllocator(abc.ABC):
@abc.abstractmethod
def __init__(
self,
size: int,
page_size: int,
dtype: torch.dtype,
device: str,
kvcache: KVCache,
):
self.size = size
self.page_size = page_size
self.dtype = dtype
self.device = device
self._kvcache = kvcache
self.free_pages = None
self.is_not_in_free_group = True
self.free_group = []
def debug_print(self) -> str:
return ""
def available_size(self):
return len(self.free_pages) * self.page_size
def get_kvcache(self):
return self._kvcache
def restore_state(self, free_pages):
self.free_pages = free_pages
def backup_state(self):
return self.free_pages
def free_group_begin(self):
self.is_not_in_free_group = False
self.free_group = []
def free_group_end(self):
self.is_not_in_free_group = True
if self.free_group:
self.free(torch.cat(self.free_group))
def get_cpu_copy(self, *args, **kwargs):
# FIXME: reuse the get_cpu_copy after paged allocator is implemented
raise NotImplementedError()
def load_cpu_copy(self, *args, **kwargs):
# FIXME: reuse the load_cpu_copy after paged allocator is implemented
raise NotImplementedError()
def alloc_extend(self, *args, **kwargs):
raise NotImplementedError("alloc_extend is only for paged allocator")
def alloc_decode(self, *args, **kwargs):
raise NotImplementedError("alloc_decode is only for paged allocator")
@abc.abstractmethod
def clear(self):
raise NotImplementedError()
@abc.abstractmethod
def alloc(self, need_size: int):
raise NotImplementedError()
@abc.abstractmethod
def free(self, free_index: torch.Tensor):
raise NotImplementedError()
class TokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
    """Token-granularity allocator over the KV-cache index space.

    Uses a fixed page size of 1, so every "page" is a single token slot and
    allocation is a plain pop from the front of the free list.
    """

    def __init__(self, size: int, dtype: torch.dtype, device: str, kvcache: "KVCache"):
        super().__init__(size, 1, dtype, device, kvcache)
        self.clear()

    def clear(self):
        # Slot 0 is reserved as the dummy write target for padded tokens,
        # so the usable indices run 1..size inclusive.
        self.free_pages = torch.arange(
            1, self.size + 1, dtype=torch.int64, device=self.device
        )
        self.is_not_in_free_group = True
        self.free_group = []

    def available_size(self):
        # page_size == 1 here, so skip the base class's multiply.
        return len(self.free_pages)

    def alloc(self, need_size: int):
        """Pop `need_size` slot indices, or return None when short on space."""
        if need_size > len(self.free_pages):
            return None
        allocated, self.free_pages = (
            self.free_pages[:need_size],
            self.free_pages[need_size:],
        )
        return allocated

    def free(self, free_index: torch.Tensor):
        if free_index.numel() == 0:
            return
        if self.is_not_in_free_group:
            self.free_pages = torch.cat((self.free_pages, free_index))
        else:
            # A free-group is open: buffer until free_group_end().
            self.free_group.append(free_index)

    def get_cpu_copy(self, indices):
        return self._kvcache.get_cpu_copy(indices)

    def load_cpu_copy(self, kv_cache_cpu, indices):
        return self._kvcache.load_cpu_copy(kv_cache_cpu, indices)
@triton.jit
def alloc_extend_kernel(
......@@ -154,7 +275,7 @@ def alloc_decode_kernel(
tl.store(out_indices + pid, page * page_size)
class PagedTokenToKVPoolAllocator:
class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
"""
An allocator managing the indices to kv cache data.
......@@ -172,26 +293,11 @@ class PagedTokenToKVPoolAllocator:
device: str,
kvcache: KVCache,
):
self.size = size
self.dtype = dtype
self.device = device
self.page_size = page_size
super().__init__(size, page_size, dtype, device, kvcache)
self.num_pages = size // page_size
self.free_pages = None
self.is_not_in_free_group = True
self.free_group = []
self.clear()
self.debug_mode = get_bool_env_var("SGLANG_DEBUG_MEMORY_POOL")
self._kvcache = kvcache
self.ret_values = torch.empty((), dtype=torch.int64, device=self.device)
def available_size(self):
return len(self.free_pages) * self.page_size
def get_kvcache(self):
return self._kvcache
self.clear()
def alloc(self, need_size: int):
# page-aligned allocation, returning contiguous indices of pages
......@@ -298,21 +404,6 @@ class PagedTokenToKVPoolAllocator:
if self.debug_mode:
assert len(torch.unique(self.free_pages)) == len(self.free_pages)
def free_group_begin(self):
self.is_not_in_free_group = False
self.free_group = []
def free_group_end(self):
self.is_not_in_free_group = True
if self.free_group:
self.free(torch.cat(self.free_group))
def backup_state(self):
return self.free_pages
def restore_state(self, free_pages):
self.free_pages = free_pages
def clear(self):
# The padded slot 0 is used for writing dummy outputs from padded tokens.
self.free_pages = torch.arange(
......
......@@ -2,12 +2,13 @@ from __future__ import annotations
"""Cache for chunked prefill, used when RadixCache is disabled."""
from typing import TYPE_CHECKING, Any, Callable, List, Tuple
from typing import TYPE_CHECKING, Any
import torch
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchResult
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
if TYPE_CHECKING:
from sglang.srt.managers.schedule_batch import Req
......@@ -17,7 +18,7 @@ class ChunkCache(BasePrefixCache):
def __init__(
self,
req_to_token_pool: ReqToTokenPool,
token_to_kv_pool_allocator: TokenToKVPoolAllocator,
token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
page_size: int,
):
self.req_to_token_pool = req_to_token_pool
......
......@@ -7,12 +7,12 @@ from typing import List, Optional
import torch
from sglang.srt.managers.cache_controller import HiCacheController
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
from sglang.srt.mem_cache.base_prefix_cache import MatchResult
from sglang.srt.mem_cache.memory_pool import (
MHATokenToKVPool,
MLATokenToKVPool,
ReqToTokenPool,
TokenToKVPoolAllocator,
)
from sglang.srt.mem_cache.memory_pool_host import (
MHATokenToKVPoolHost,
......@@ -28,7 +28,7 @@ class HiRadixCache(RadixCache):
def __init__(
self,
req_to_token_pool: ReqToTokenPool,
token_to_kv_pool_allocator: TokenToKVPoolAllocator,
token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
tp_cache_group: torch.distributed.ProcessGroup,
page_size: int,
hicache_ratio: float,
......
......@@ -26,7 +26,6 @@ KVCache actually holds the physical kv cache.
import abc
import logging
import os
from contextlib import nullcontext
from typing import List, Optional, Tuple, Union
......@@ -167,84 +166,6 @@ class KVCache(abc.ABC):
raise NotImplementedError()
class TokenToKVPoolAllocator:
    """Allocator that manages free slot indices into the KV-cache data.

    Index 0 is reserved as a dummy write target for padded tokens, so the
    free list covers 1..size inclusive. Supports deferred batch frees via
    free_group_begin()/free_group_end().
    """

    def __init__(
        self,
        size: int,
        dtype: torch.dtype,
        device: str,
        kvcache: "KVCache",
    ):
        self.size = size
        self.dtype = dtype
        self.device = device
        # Token-level granularity: each slot maps to exactly one token.
        self.page_size = 1

        self.free_slots = None
        self.is_not_in_free_group = True
        self.free_group = []
        self.clear()

        self._kvcache = kvcache

    def available_size(self):
        return len(self.free_slots)

    def debug_print(self) -> str:
        return ""

    def get_kvcache(self):
        return self._kvcache

    def alloc(self, need_size: int):
        """Take the first `need_size` free slots; None if not enough remain."""
        if need_size > len(self.free_slots):
            return None
        picked, self.free_slots = (
            self.free_slots[:need_size],
            self.free_slots[need_size:],
        )
        return picked

    def free(self, free_index: torch.Tensor):
        if free_index.numel() == 0:
            return
        if self.is_not_in_free_group:
            self.free_slots = torch.cat((self.free_slots, free_index))
        else:
            # A free-group is open: defer until free_group_end().
            self.free_group.append(free_index)

    def free_group_begin(self):
        self.is_not_in_free_group = False
        self.free_group = []

    def free_group_end(self):
        self.is_not_in_free_group = True
        if self.free_group:
            self.free(torch.cat(self.free_group))

    def backup_state(self):
        return self.free_slots

    def restore_state(self, free_slots):
        self.free_slots = free_slots

    def clear(self):
        # The padded slot 0 is used for writing dummy outputs from padded tokens.
        self.free_slots = torch.arange(
            1, self.size + 1, dtype=torch.int64, device=self.device
        )
        self.is_not_in_free_group = True
        self.free_group = []

    def get_cpu_copy(self, indices):
        return self._kvcache.get_cpu_copy(indices)

    def load_cpu_copy(self, kv_cache_cpu, indices):
        return self._kvcache.load_cpu_copy(kv_cache_cpu, indices)
class MHATokenToKVPool(KVCache):
def __init__(
......
......@@ -23,7 +23,7 @@ import heapq
import time
from collections import defaultdict
from functools import partial
from typing import TYPE_CHECKING, List, Optional, Tuple
from typing import TYPE_CHECKING, List, Optional
import torch
......@@ -31,10 +31,10 @@ from sglang.srt.disaggregation.kv_events import (
AllBlocksCleared,
BlockRemoved,
BlockStored,
KVCacheEvent,
)
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchResult
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
if TYPE_CHECKING:
from sglang.srt.managers.schedule_batch import Req
......@@ -98,7 +98,7 @@ class RadixCache(BasePrefixCache):
def __init__(
self,
req_to_token_pool: ReqToTokenPool,
token_to_kv_pool_allocator: TokenToKVPoolAllocator,
token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
page_size: int,
disable: bool = False,
enable_kv_cache_events: bool = False,
......
......@@ -71,14 +71,17 @@ from sglang.srt.managers.schedule_batch import (
GLOBAL_SERVER_ARGS_KEYS,
global_server_args_dict,
)
from sglang.srt.mem_cache.allocator import (
BaseTokenToKVPoolAllocator,
PagedTokenToKVPoolAllocator,
TokenToKVPoolAllocator,
)
from sglang.srt.mem_cache.memory_pool import (
DoubleSparseTokenToKVPool,
MHATokenToKVPool,
MLATokenToKVPool,
ReqToTokenPool,
TokenToKVPoolAllocator,
)
from sglang.srt.mem_cache.paged_allocator import PagedTokenToKVPoolAllocator
from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
from sglang.srt.model_executor.expert_location_updater import ExpertLocationUpdater
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
......@@ -152,7 +155,7 @@ class ModelRunner:
server_args: ServerArgs,
is_draft_worker: bool = False,
req_to_token_pool: Optional[ReqToTokenPool] = None,
token_to_kv_pool_allocator: Optional[TokenToKVPoolAllocator] = None,
token_to_kv_pool_allocator: Optional[BaseTokenToKVPoolAllocator] = None,
):
# Parse args
self.model_config = model_config
......
......@@ -21,7 +21,7 @@ from sglang.srt.managers.schedule_batch import (
get_last_loc,
global_server_args_dict,
)
from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
from sglang.srt.utils import is_cuda, is_hip, next_power_of_2
......@@ -315,7 +315,7 @@ class EagleVerifyInput:
self,
batch: ScheduleBatch,
logits_output: torch.Tensor,
token_to_kv_pool_allocator: TokenToKVPoolAllocator,
token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
page_size: int,
vocab_mask: Optional[torch.Tensor] = None, # For grammar
) -> torch.Tensor:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.