Unverified Commit 05c9bc89 authored by Liangsheng Yin, committed by GitHub

[minor] simplify the `TokenToKVPoolAllocator` (#7414)

parent b7a2df0a
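
As context for the diff below, a minimal sketch (not part of this commit) of the relocated allocator API. It assumes this change is applied; kvcache=None is passed only because alloc/free never touch the backing KVCache:

    import torch

    from sglang.srt.mem_cache.allocator import TokenToKVPoolAllocator

    # Token-level allocator: page_size is fixed to 1 by the subclass.
    allocator = TokenToKVPoolAllocator(
        size=16, dtype=torch.float16, device="cpu", kvcache=None
    )
    idx = allocator.alloc(4)           # tensor([1, 2, 3, 4]); slot 0 stays reserved
    print(allocator.available_size())  # 12

    # free_group_begin()/free_group_end() buffer frees and coalesce them
    # into one tensor instead of concatenating on every free() call.
    allocator.free_group_begin()
    allocator.free(idx[:2])
    allocator.free(idx[2:])
    allocator.free_group_end()
    print(allocator.available_size())  # 16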
@@ -21,13 +21,11 @@ Life cycle of a request in the decode server
 from __future__ import annotations

 import logging
-import os
 from collections import deque
 from dataclasses import dataclass
 from http import HTTPStatus
 from typing import TYPE_CHECKING, List, Optional, Tuple, Union

-import numpy as np
 import torch
 from torch.distributed import ProcessGroup

@@ -47,12 +45,9 @@ from sglang.srt.disaggregation.utils import (
     prepare_abort,
 )
 from sglang.srt.managers.schedule_batch import FINISH_ABORT, ScheduleBatch
+from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
-from sglang.srt.mem_cache.memory_pool import (
-    KVCache,
-    ReqToTokenPool,
-    TokenToKVPoolAllocator,
-)
+from sglang.srt.mem_cache.memory_pool import KVCache, ReqToTokenPool
 from sglang.srt.model_executor.forward_batch_info import ForwardMode
 from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
 from sglang.srt.utils import require_mlp_sync
@@ -141,7 +136,7 @@ class DecodePreallocQueue:
     def __init__(
         self,
         req_to_token_pool: ReqToTokenPool,
-        token_to_kv_pool_allocator: TokenToKVPoolAllocator,
+        token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
         draft_token_to_kv_pool: Optional[KVCache],
         req_to_metadata_buffer_idx_allocator: ReqToMetadataIdxAllocator,
         metadata_buffers: MetadataBuffers,
...
@@ -25,7 +25,6 @@ from collections import deque
 from http import HTTPStatus
 from typing import TYPE_CHECKING, List, Optional

-import numpy as np
 import torch

 from sglang.srt.disaggregation.base import BaseKVManager, KVPoll
...
@@ -18,12 +18,13 @@ import logging
 import math
 import threading
 from queue import Empty, Full, PriorityQueue, Queue
-from typing import List, Optional
+from typing import TYPE_CHECKING, List, Optional

 import torch

-from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator
-from sglang.srt.mem_cache.memory_pool_host import HostKVCache
+if TYPE_CHECKING:
+    from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
+    from sglang.srt.mem_cache.memory_pool_host import HostKVCache

 logger = logging.getLogger(__name__)
@@ -163,7 +164,7 @@ class HiCacheController:
     def __init__(
         self,
-        token_to_kv_pool_allocator: TokenToKVPoolAllocator,
+        token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
         mem_pool_host: HostKVCache,
         page_size: int,
         load_cache_event: threading.Event = None,
...
@@ -54,9 +54,10 @@ from sglang.srt.disaggregation.decode_schedule_batch_mixin import (
 )
 from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_rank
 from sglang.srt.layers.multimodal import gpu_tensor_hash
+from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
 from sglang.srt.mem_cache.chunk_cache import ChunkCache
-from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
+from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
 from sglang.srt.metrics.collector import TimeStats
 from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
@@ -810,7 +811,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     # Request, memory pool, and cache
     reqs: List[Req]
     req_to_token_pool: ReqToTokenPool = None
-    token_to_kv_pool_allocator: TokenToKVPoolAllocator = None
+    token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator = None
     tree_cache: BasePrefixCache = None

     # Batch configs
@@ -907,7 +908,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         cls,
         reqs: List[Req],
         req_to_token_pool: ReqToTokenPool,
-        token_to_kv_pool_allocator: TokenToKVPoolAllocator,
+        token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
         tree_cache: BasePrefixCache,
         model_config: ModelConfig,
         enable_overlap: bool,
...
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 # Copyright 2023-2024 SGLang Team
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -18,15 +20,17 @@ import random
 from collections import defaultdict
 from contextlib import contextmanager
 from enum import Enum, auto
-from typing import Dict, List, Optional, Set, Union
+from typing import TYPE_CHECKING, Dict, List, Optional, Set, Union

 import torch

 from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
-from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator
 from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode

+if TYPE_CHECKING:
+    from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
+
 # Clip the estimation of max_new_tokens for the request whose max_new_tokens is very large.
 # This can prevent the server from being too conservative.
 # Note that this only clips the estimation in the scheduler but does not change the stop
@@ -265,7 +269,7 @@ class PrefillAdder:
         self,
         page_size: int,
         tree_cache: BasePrefixCache,
-        token_to_kv_pool_allocator: TokenToKVPoolAllocator,
+        token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
         running_batch: ScheduleBatch,
         new_token_ratio: float,
         rem_input_tokens: int,
...
@@ -23,7 +23,6 @@ import time
 from collections import defaultdict, deque
 from concurrent import futures
 from dataclasses import dataclass
-from http import HTTPStatus
 from pathlib import Path
 from types import SimpleNamespace
 from typing import Dict, List, Optional, Tuple, Union
...
@@ -35,7 +35,8 @@ from sglang.srt.managers.io_struct import (
     UpdateWeightsFromTensorReqInput,
 )
 from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_args_dict
-from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
+from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
+from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.server_args import ServerArgs
@@ -57,7 +58,7 @@ class TpModelWorker:
         nccl_port: int,
         is_draft_worker: bool = False,
         req_to_token_pool: Optional[ReqToTokenPool] = None,
-        token_to_kv_pool_allocator: Optional[TokenToKVPoolAllocator] = None,
+        token_to_kv_pool_allocator: Optional[BaseTokenToKVPoolAllocator] = None,
     ):
         # Parse args
         self.tp_size = server_args.tp_size
...
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 """
 Copyright 2025 SGLang Team
 Licensed under the Apache License, Version 2.0 (the "License");
@@ -17,13 +19,132 @@ limitations under the License.
 Page-aligned memory pool.
 """

+import abc
+from typing import TYPE_CHECKING
+
 import torch
 import triton
 import triton.language as tl

-from sglang.srt.mem_cache.memory_pool import KVCache
 from sglang.srt.utils import get_bool_env_var, next_power_of_2

+if TYPE_CHECKING:
+    from sglang.srt.mem_cache.memory_pool import KVCache
+
+
+class BaseTokenToKVPoolAllocator(abc.ABC):
+    @abc.abstractmethod
+    def __init__(
+        self,
+        size: int,
+        page_size: int,
+        dtype: torch.dtype,
+        device: str,
+        kvcache: KVCache,
+    ):
+        self.size = size
+        self.page_size = page_size
+        self.dtype = dtype
+        self.device = device
+        self._kvcache = kvcache
+
+        self.free_pages = None
+        self.is_not_in_free_group = True
+        self.free_group = []
+
+    def debug_print(self) -> str:
+        return ""
+
+    def available_size(self):
+        return len(self.free_pages) * self.page_size
+
+    def get_kvcache(self):
+        return self._kvcache
+
+    def restore_state(self, free_pages):
+        self.free_pages = free_pages
+
+    def backup_state(self):
+        return self.free_pages
+
+    def free_group_begin(self):
+        self.is_not_in_free_group = False
+        self.free_group = []
+
+    def free_group_end(self):
+        self.is_not_in_free_group = True
+        if self.free_group:
+            self.free(torch.cat(self.free_group))
+
+    def get_cpu_copy(self, *args, **kwargs):
+        # FIXME: reuse the get_cpu_copy after paged allocator is implemented
+        raise NotImplementedError()
+
+    def load_cpu_copy(self, *args, **kwargs):
+        # FIXME: reuse the load_cpu_copy after paged allocator is implemented
+        raise NotImplementedError()
+
+    def alloc_extend(self, *args, **kwargs):
+        raise NotImplementedError("alloc_extend is only for paged allocator")
+
+    def alloc_decode(self, *args, **kwargs):
+        raise NotImplementedError("alloc_decode is only for paged allocator")
+
+    @abc.abstractmethod
+    def clear(self):
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def alloc(self, need_size: int):
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def free(self, free_index: torch.Tensor):
+        raise NotImplementedError()
+
+
+class TokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
+    """An allocator managing the indices to kv cache data."""
+
+    def __init__(self, size: int, dtype: torch.dtype, device: str, kvcache: KVCache):
+        super().__init__(size, 1, dtype, device, kvcache)
+        self.clear()
+
+    def clear(self):
+        # The padded slot 0 is used for writing dummy outputs from padded tokens.
+        self.free_pages = torch.arange(
+            1, self.size + 1, dtype=torch.int64, device=self.device
+        )
+        self.is_not_in_free_group = True
+        self.free_group = []
+
+    def available_size(self):
+        # To avoid minor "len(free_pages) * 1" overhead
+        return len(self.free_pages)
+
+    def alloc(self, need_size: int):
+        if need_size > len(self.free_pages):
+            return None
+
+        select_index = self.free_pages[:need_size]
+        self.free_pages = self.free_pages[need_size:]
+        return select_index
+
+    def free(self, free_index: torch.Tensor):
+        if free_index.numel() == 0:
+            return
+
+        if self.is_not_in_free_group:
+            self.free_pages = torch.cat((self.free_pages, free_index))
+        else:
+            self.free_group.append(free_index)
+
+    def get_cpu_copy(self, indices):
+        return self._kvcache.get_cpu_copy(indices)
+
+    def load_cpu_copy(self, kv_cache_cpu, indices):
+        return self._kvcache.load_cpu_copy(kv_cache_cpu, indices)
+

 @triton.jit
 def alloc_extend_kernel(
@@ -154,7 +275,7 @@ def alloc_decode_kernel(
     tl.store(out_indices + pid, page * page_size)


-class PagedTokenToKVPoolAllocator:
+class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
     """
     An allocator managing the indices to kv cache data.

@@ -172,26 +293,11 @@ class PagedTokenToKVPoolAllocator:
         device: str,
         kvcache: KVCache,
     ):
-        self.size = size
-        self.dtype = dtype
-        self.device = device
-        self.page_size = page_size
+        super().__init__(size, page_size, dtype, device, kvcache)
         self.num_pages = size // page_size
-        self.free_pages = None
-        self.is_not_in_free_group = True
-        self.free_group = []
-        self.clear()

         self.debug_mode = get_bool_env_var("SGLANG_DEBUG_MEMORY_POOL")
-        self._kvcache = kvcache
         self.ret_values = torch.empty((), dtype=torch.int64, device=self.device)
+        self.clear()

     def alloc(self, need_size: int):
         # page-aligned allocation, returning contiguous indices of pages
@@ -298,21 +404,6 @@ class PagedTokenToKVPoolAllocator:
         if self.debug_mode:
             assert len(torch.unique(self.free_pages)) == len(self.free_pages)

-    def free_group_begin(self):
-        self.is_not_in_free_group = False
-        self.free_group = []
-
-    def free_group_end(self):
-        self.is_not_in_free_group = True
-        if self.free_group:
-            self.free(torch.cat(self.free_group))
-
-    def backup_state(self):
-        return self.free_pages
-
-    def restore_state(self, free_pages):
-        self.free_pages = free_pages
-
     def clear(self):
         # The padded slot 0 is used for writing dummy outputs from padded tokens.
         self.free_pages = torch.arange(
...
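
The base class now owns the free-group and state-snapshot bookkeeping shared by both allocators. A small sketch (same assumptions as the example above) of the inherited backup_state/restore_state pair, which snapshots the free list and rolls back a speculative allocation:

    import torch

    from sglang.srt.mem_cache.allocator import TokenToKVPoolAllocator

    allocator = TokenToKVPoolAllocator(
        size=8, dtype=torch.float16, device="cpu", kvcache=None
    )
    saved = allocator.backup_state()  # reference to the current free_pages tensor
    allocator.alloc(5)                # slicing builds a new tensor; `saved` is untouched
    assert allocator.available_size() == 3
    allocator.restore_state(saved)    # roll the allocation back
    assert allocator.available_size() == 8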
@@ -2,12 +2,13 @@ from __future__ import annotations

 """Cache for chunked prefill, used when RadixCache is disabled."""

-from typing import TYPE_CHECKING, Any, Callable, List, Tuple
+from typing import TYPE_CHECKING, Any

 import torch

+from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchResult
-from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
+from sglang.srt.mem_cache.memory_pool import ReqToTokenPool

 if TYPE_CHECKING:
     from sglang.srt.managers.schedule_batch import Req
@@ -17,7 +18,7 @@ class ChunkCache(BasePrefixCache):
     def __init__(
         self,
         req_to_token_pool: ReqToTokenPool,
-        token_to_kv_pool_allocator: TokenToKVPoolAllocator,
+        token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
         page_size: int,
     ):
         self.req_to_token_pool = req_to_token_pool
...
@@ -7,12 +7,12 @@ from typing import List, Optional

 import torch

 from sglang.srt.managers.cache_controller import HiCacheController
+from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
 from sglang.srt.mem_cache.base_prefix_cache import MatchResult
 from sglang.srt.mem_cache.memory_pool import (
     MHATokenToKVPool,
     MLATokenToKVPool,
     ReqToTokenPool,
-    TokenToKVPoolAllocator,
 )
 from sglang.srt.mem_cache.memory_pool_host import (
     MHATokenToKVPoolHost,
@@ -28,7 +28,7 @@ class HiRadixCache(RadixCache):
     def __init__(
         self,
         req_to_token_pool: ReqToTokenPool,
-        token_to_kv_pool_allocator: TokenToKVPoolAllocator,
+        token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
         tp_cache_group: torch.distributed.ProcessGroup,
         page_size: int,
         hicache_ratio: float,
...
@@ -26,7 +26,6 @@ KVCache actually holds the physical kv cache.

 import abc
 import logging
-import os
 from contextlib import nullcontext
 from typing import List, Optional, Tuple, Union

@@ -167,84 +166,6 @@ class KVCache(abc.ABC):
         raise NotImplementedError()


-class TokenToKVPoolAllocator:
-    """An allocator managing the indices to kv cache data."""
-
-    def __init__(
-        self,
-        size: int,
-        dtype: torch.dtype,
-        device: str,
-        kvcache: KVCache,
-    ):
-        self.size = size
-        self.dtype = dtype
-        self.device = device
-        self.page_size = 1
-
-        self.free_slots = None
-        self.is_not_in_free_group = True
-        self.free_group = []
-        self.clear()
-
-        self._kvcache = kvcache
-
-    def available_size(self):
-        return len(self.free_slots)
-
-    def debug_print(self) -> str:
-        return ""
-
-    def get_kvcache(self):
-        return self._kvcache
-
-    def alloc(self, need_size: int):
-        if need_size > len(self.free_slots):
-            return None
-
-        select_index = self.free_slots[:need_size]
-        self.free_slots = self.free_slots[need_size:]
-        return select_index
-
-    def free(self, free_index: torch.Tensor):
-        if free_index.numel() == 0:
-            return
-
-        if self.is_not_in_free_group:
-            self.free_slots = torch.cat((self.free_slots, free_index))
-        else:
-            self.free_group.append(free_index)
-
-    def free_group_begin(self):
-        self.is_not_in_free_group = False
-        self.free_group = []
-
-    def free_group_end(self):
-        self.is_not_in_free_group = True
-        if self.free_group:
-            self.free(torch.cat(self.free_group))
-
-    def backup_state(self):
-        return self.free_slots
-
-    def restore_state(self, free_slots):
-        self.free_slots = free_slots
-
-    def clear(self):
-        # The padded slot 0 is used for writing dummy outputs from padded tokens.
-        self.free_slots = torch.arange(
-            1, self.size + 1, dtype=torch.int64, device=self.device
-        )
-        self.is_not_in_free_group = True
-        self.free_group = []
-
-    def get_cpu_copy(self, indices):
-        return self._kvcache.get_cpu_copy(indices)
-
-    def load_cpu_copy(self, kv_cache_cpu, indices):
-        return self._kvcache.load_cpu_copy(kv_cache_cpu, indices)
-
-
 class MHATokenToKVPool(KVCache):
     def __init__(
...
@@ -23,7 +23,7 @@ import heapq
 import time
 from collections import defaultdict
 from functools import partial
-from typing import TYPE_CHECKING, List, Optional, Tuple
+from typing import TYPE_CHECKING, List, Optional

 import torch

@@ -31,10 +31,10 @@ from sglang.srt.disaggregation.kv_events import (
     AllBlocksCleared,
     BlockRemoved,
     BlockStored,
-    KVCacheEvent,
 )
+from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchResult
-from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
+from sglang.srt.mem_cache.memory_pool import ReqToTokenPool

 if TYPE_CHECKING:
     from sglang.srt.managers.schedule_batch import Req
@@ -98,7 +98,7 @@ class RadixCache(BasePrefixCache):
     def __init__(
         self,
         req_to_token_pool: ReqToTokenPool,
-        token_to_kv_pool_allocator: TokenToKVPoolAllocator,
+        token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
         page_size: int,
         disable: bool = False,
         enable_kv_cache_events: bool = False,
...
@@ -71,14 +71,17 @@ from sglang.srt.managers.schedule_batch import (
     GLOBAL_SERVER_ARGS_KEYS,
     global_server_args_dict,
 )
+from sglang.srt.mem_cache.allocator import (
+    BaseTokenToKVPoolAllocator,
+    PagedTokenToKVPoolAllocator,
+    TokenToKVPoolAllocator,
+)
 from sglang.srt.mem_cache.memory_pool import (
     DoubleSparseTokenToKVPool,
     MHATokenToKVPool,
     MLATokenToKVPool,
     ReqToTokenPool,
-    TokenToKVPoolAllocator,
 )
-from sglang.srt.mem_cache.paged_allocator import PagedTokenToKVPoolAllocator
 from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
 from sglang.srt.model_executor.expert_location_updater import ExpertLocationUpdater
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
@@ -152,7 +155,7 @@ class ModelRunner:
         server_args: ServerArgs,
         is_draft_worker: bool = False,
         req_to_token_pool: Optional[ReqToTokenPool] = None,
-        token_to_kv_pool_allocator: Optional[TokenToKVPoolAllocator] = None,
+        token_to_kv_pool_allocator: Optional[BaseTokenToKVPoolAllocator] = None,
     ):
         # Parse args
         self.model_config = model_config
...
@@ -21,7 +21,7 @@ from sglang.srt.managers.schedule_batch import (
     get_last_loc,
     global_server_args_dict,
 )
-from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator
+from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
 from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
 from sglang.srt.utils import is_cuda, is_hip, next_power_of_2
@@ -315,7 +315,7 @@ class EagleVerifyInput:
         self,
         batch: ScheduleBatch,
         logits_output: torch.Tensor,
-        token_to_kv_pool_allocator: TokenToKVPoolAllocator,
+        token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
         page_size: int,
         vocab_mask: Optional[torch.Tensor] = None,  # For grammar
     ) -> torch.Tensor:
...