Unverified commit c76040e3, authored by Lianmin Zheng, committed by GitHub

Support page size > 1 (#4356)

parent 2f6bacee
......@@ -36,7 +36,7 @@ fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
_is_cuda = is_cuda()
if _is_cuda:
import deep_gemm
import deep_gemm # `pip install "sgl-kernel>=0.0.4.post3"`
from sgl_kernel import sgl_per_token_group_quant_fp8, sgl_per_token_quant_fp8
logger = logging.getLogger(__name__)
......
......@@ -77,7 +77,7 @@ class SchedulePolicy:
self,
policy: str,
tree_cache: BasePrefixCache,
enable_hierarchical_cache: bool = False,
enable_hierarchical_cache: bool,
):
self.policy = self._validate_and_adjust_policy(policy, tree_cache)
self.tree_cache = tree_cache
......@@ -85,10 +85,17 @@ class SchedulePolicy:
# It is used to find the matching prefix for in-batch prefix caching.
self.waiting_queue_radix_tree = RadixCache(
req_to_token_pool=None, token_to_kv_pool_allocator=None, disable=False
req_to_token_pool=None,
token_to_kv_pool_allocator=None,
page_size=1,
disable=False,
)
def calc_priority(self, waiting_queue: List[Req]) -> bool:
if self.policy == CacheAgnosticPolicy.FCFS:
# A shortcut for FCFS
return
policy = self._determine_active_policy(waiting_queue)
prefix_computed = False
......@@ -118,7 +125,7 @@ class SchedulePolicy:
return prefix_computed
def _determine_active_policy(self, waiting_queue: List[Req]) -> Policy:
if len(waiting_queue) > 128 and self.policy == CacheAwarePolicy.LPM:
if self.policy == CacheAwarePolicy.LPM and len(waiting_queue) > 128:
# Turn off the expensive prefix matching and sorting when the #queue is large.
return CacheAgnosticPolicy.FCFS
return self.policy
......@@ -442,7 +449,7 @@ class PrefillAdder:
def add_one_req(
self, req: Req, has_chunked_req: bool, enable_hierarchical_cache: bool = False
):
if req.sampling_params.ignore_eos and self.tree_cache.disable:
if req.sampling_params.ignore_eos and getattr(self.tree_cache, "disable", True):
return self.add_one_req_ignore_eos(req, has_chunked_req)
total_tokens = req.extend_input_len + min(
......
......@@ -93,7 +93,7 @@ from sglang.srt.mem_cache.chunk_cache import ChunkCache
from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
from sglang.srt.mem_cache.radix_cache import RadixCache
from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
......@@ -103,6 +103,7 @@ from sglang.srt.utils import (
crash_on_warnings,
get_bool_env_var,
get_zmq_socket,
kill_itself_when_parent_died,
pyspy_dump_schedulers,
set_gpu_proc_affinity,
set_random_seed,
......@@ -159,6 +160,7 @@ class Scheduler(SchedulerOutputProcessorMixin):
)
self.gpu_id = gpu_id
self.enable_hierarchical_cache = server_args.enable_hierarchical_cache
self.page_size = server_args.page_size
# Distributed rank info
self.dp_size = server_args.dp_size
......@@ -265,20 +267,23 @@ class Scheduler(SchedulerOutputProcessorMixin):
f"context_len={self.model_config.context_len}"
)
# Init memory pool and cache
self.init_memory_pool_and_cache()
# Init running status
self.waiting_queue: List[Req] = []
# The running decoding batch for continuous batching
self.running_batch: Optional[ScheduleBatch] = None
self.running_batch: ScheduleBatch = ScheduleBatch(reqs=[], batch_is_full=False)
# The current forward batch
self.cur_batch: Optional[ScheduleBatch] = None
# The current forward batch
# The last forward batch
self.last_batch: Optional[ScheduleBatch] = None
self.forward_ct = 0
self.forward_ct_decode = 0
self.num_generated_tokens = 0
self.num_prefill_tokens = 0
self.last_decode_stats_tic = time.time()
self.last_prefill_stats_tic = time.time()
self.return_health_check_ct = 0
self.current_stream = torch.get_device_module(self.device).current_stream()
if self.device == "cpu":
......@@ -307,7 +312,9 @@ class Scheduler(SchedulerOutputProcessorMixin):
# Init schedule policy and new token estimation
self.policy = SchedulePolicy(
self.schedule_policy, self.tree_cache, self.enable_hierarchical_cache
self.schedule_policy,
self.tree_cache,
self.enable_hierarchical_cache,
)
assert (
server_args.schedule_conservativeness >= 0
......@@ -327,11 +334,6 @@ class Scheduler(SchedulerOutputProcessorMixin):
) / global_config.default_new_token_ratio_decay_steps
self.new_token_ratio = self.init_new_token_ratio
# Tell whether the current running batch is full so that we can skip
# the check of whether to prefill new requests.
# This is an optimization to reduce the overhead of the prefill check.
self.batch_is_full = False
# Init watchdog thread
self.watchdog_timeout = server_args.watchdog_timeout
t = threading.Thread(target=self.watchdog_thread, daemon=True)
......@@ -437,6 +439,7 @@ class Scheduler(SchedulerOutputProcessorMixin):
self.tree_cache = RadixCache(
req_to_token_pool=self.req_to_token_pool,
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
page_size=self.page_size,
disable=server_args.disable_radix_cache,
)
......@@ -458,6 +461,7 @@ class Scheduler(SchedulerOutputProcessorMixin):
# The largest context length (prefill + generation) of a single request
self._largest_prefill_decode_len: int = 0
self.last_gen_throughput: float = 0.0
self.last_input_throughput: float = 0.0
self.step_time_dict = defaultdict(list) # Dict[batch size -> step time]
self.spec_num_total_accepted_tokens = 0
self.spec_num_total_forward_ct = 0
......@@ -487,7 +491,7 @@ class Scheduler(SchedulerOutputProcessorMixin):
result = self.run_batch(batch)
self.process_batch_result(batch, result)
else:
# When the server is idle, so self-check and re-init some states
# When the server is idle, do self-check and re-init some states
self.check_memory()
self.new_token_ratio = self.init_new_token_ratio
......@@ -527,7 +531,7 @@ class Scheduler(SchedulerOutputProcessorMixin):
)
self.process_batch_result(tmp_batch, tmp_result)
elif batch is None:
# When the server is idle, so self-check and re-init some states
# When the server is idle, do self-check and re-init some states
self.check_memory()
self.new_token_ratio = self.init_new_token_ratio
......@@ -588,7 +592,7 @@ class Scheduler(SchedulerOutputProcessorMixin):
for recv_req in recv_reqs:
# If it is a health check generation request and there are running requests, ignore it.
if is_health_check_generate_req(recv_req) and (
self.chunked_req is not None or self.running_batch is not None
self.chunked_req is not None or not self.running_batch.is_empty()
):
self.return_health_check_ct += 1
continue
......@@ -812,6 +816,11 @@ class Scheduler(SchedulerOutputProcessorMixin):
can_run_list: List[Req],
running_bs: int,
):
gap_latency = time.time() - self.last_prefill_stats_tic
self.last_prefill_stats_tic = time.time()
self.last_input_throughput = self.num_prefill_tokens / gap_latency
self.num_prefill_tokens = 0
num_used = self.max_total_num_tokens - (
self.token_to_kv_pool_allocator.available_size()
+ self.tree_cache.evictable_size()
......@@ -847,7 +856,7 @@ class Scheduler(SchedulerOutputProcessorMixin):
self.last_decode_stats_tic = time.time()
self.last_gen_throughput = self.num_generated_tokens / gap_latency
self.num_generated_tokens = 0
num_running_reqs = len(self.running_batch.reqs) if self.running_batch else 0
num_running_reqs = len(self.running_batch.reqs)
num_used = self.max_total_num_tokens - (
self.token_to_kv_pool_allocator.available_size()
+ self.tree_cache.evictable_size()
......@@ -911,8 +920,10 @@ class Scheduler(SchedulerOutputProcessorMixin):
)
if memory_leak:
msg = (
"KV cache pool leak detected!"
"KV cache pool leak detected! "
f"{available_size=}, {protected_size=}, {self.max_total_num_tokens=}\n"
f"{self.token_to_kv_pool_allocator.available_size()=}\n"
f"{self.tree_cache.evictable_size()=}\n"
)
warnings.warn(msg)
if crash_on_warnings():
......@@ -938,7 +949,7 @@ class Scheduler(SchedulerOutputProcessorMixin):
self.token_to_kv_pool_allocator.available_size()
+ self.tree_cache.evictable_size()
)
num_running_reqs = len(self.running_batch.reqs) if self.running_batch else 0
num_running_reqs = len(self.running_batch.reqs)
self.stats.num_running_reqs = num_running_reqs
self.stats.num_used_tokens = num_used
self.stats.token_usage = num_used / self.max_total_num_tokens
......@@ -956,20 +967,20 @@ class Scheduler(SchedulerOutputProcessorMixin):
self.tree_cache.cache_unfinished_req(self.chunked_req)
# chunked request keeps its rid but will get a new req_pool_idx
self.req_to_token_pool.free(self.chunked_req.req_pool_idx)
self.batch_is_full = False
self.running_batch.batch_is_full = False
# Filter batch
last_bs = self.last_batch.batch_size()
self.last_batch.filter_batch()
if self.last_batch.batch_size() < last_bs:
self.batch_is_full = False
self.running_batch.batch_is_full = False
# Merge the new batch into the running batch
if not self.last_batch.is_empty():
if self.running_batch is None:
if self.running_batch.is_empty():
self.running_batch = self.last_batch
else:
# merge running_batch with prefill batch
# Merge running_batch with prefill batch
self.running_batch.merge_batch(self.last_batch)
new_batch = self.get_new_batch_prefill()
......@@ -978,11 +989,11 @@ class Scheduler(SchedulerOutputProcessorMixin):
ret = new_batch
else:
# Run decode
if self.running_batch is None:
ret = None
else:
if not self.running_batch.is_empty():
self.running_batch = self.update_running_batch(self.running_batch)
ret = self.running_batch
ret = self.running_batch if not self.running_batch.is_empty() else None
else:
ret = None
# Handle DP attention
if self.server_args.enable_dp_attention:
......@@ -997,13 +1008,13 @@ class Scheduler(SchedulerOutputProcessorMixin):
# Handle the cases where prefill is not allowed
if (
self.batch_is_full or len(self.waiting_queue) == 0
self.running_batch.batch_is_full or len(self.waiting_queue) == 0
) and self.chunked_req is None:
return None
running_bs = len(self.running_batch.reqs) if self.running_batch else 0
running_bs = len(self.running_batch.reqs)
if running_bs >= self.max_running_requests:
self.batch_is_full = True
self.running_batch.batch_is_full = True
return None
if self.enable_hierarchical_cache:
......@@ -1025,17 +1036,13 @@ class Scheduler(SchedulerOutputProcessorMixin):
running_bs if self.is_mixed_chunk else 0,
)
is_chunked = self.chunked_req is not None
if is_chunked:
if self.chunked_req is not None:
self.chunked_req.init_next_round_input()
self.chunked_req = adder.add_chunked_req(self.chunked_req)
if self.lora_paths:
lora_set = (
set([req.lora_path for req in self.running_batch.reqs])
if self.running_batch is not None
else set([])
)
lora_set = set([req.lora_path for req in self.running_batch.reqs])
# Get requests from the waiting queue to a new prefill batch
for req in self.waiting_queue:
if (
......@@ -1047,11 +1054,11 @@ class Scheduler(SchedulerOutputProcessorMixin):
)
> self.max_loras_per_batch
):
self.batch_is_full = True
self.running_batch.batch_is_full = True
break
if running_bs + len(adder.can_run_list) >= self.max_running_requests:
self.batch_is_full = True
self.running_batch.batch_is_full = True
break
req.init_next_round_input(
......@@ -1066,12 +1073,14 @@ class Scheduler(SchedulerOutputProcessorMixin):
if res == AddReqResult.NO_TOKEN:
if self.enable_hierarchical_cache:
# Set batch_is_full after making sure there are requests that can be served
self.batch_is_full = len(adder.can_run_list) > 0 or (
self.running_batch.batch_is_full = len(
adder.can_run_list
) > 0 or (
self.running_batch is not None
and not self.running_batch.is_empty()
)
else:
self.batch_is_full = True
self.running_batch.batch_is_full = True
break
# Update waiting queue
......@@ -1112,7 +1121,7 @@ class Scheduler(SchedulerOutputProcessorMixin):
# Mixed-style chunked prefill
if (
self.is_mixed_chunk
and self.running_batch is not None
and not self.running_batch.is_empty()
and not (new_batch.return_logprob or self.running_batch.return_logprob)
):
# TODO (lianmin): support return_logprob + mixed chunked prefill
......@@ -1121,7 +1130,9 @@ class Scheduler(SchedulerOutputProcessorMixin):
self.running_batch.prepare_for_decode()
new_batch.mix_with_running(self.running_batch)
new_batch.decoding_reqs = self.running_batch.reqs
self.running_batch = None
self.running_batch = ScheduleBatch(
reqs=[], batch_is_full=self.running_batch.batch_is_full
)
else:
new_batch.decoding_reqs = None
......@@ -1133,8 +1144,8 @@ class Scheduler(SchedulerOutputProcessorMixin):
batch.filter_batch()
if batch.is_empty():
self.batch_is_full = False
return None
batch.batch_is_full = False
return batch
# Check if decode out of memory
if not batch.check_decode_mem(self.decode_mem_cache_buf_multiplier) or (
......@@ -1158,7 +1169,7 @@ class Scheduler(SchedulerOutputProcessorMixin):
)
if batch.batch_size() < initial_bs:
self.batch_is_full = False
batch.batch_is_full = False
# Update batch tensors
batch.prepare_for_decode()
......@@ -1233,8 +1244,6 @@ class Scheduler(SchedulerOutputProcessorMixin):
):
if batch.forward_mode.is_decode():
self.process_batch_result_decode(batch, result)
if batch.is_empty():
self.running_batch = None
elif batch.forward_mode.is_extend():
self.process_batch_result_prefill(batch, result)
elif batch.forward_mode.is_idle():
......@@ -1375,9 +1384,7 @@ class Scheduler(SchedulerOutputProcessorMixin):
def flush_cache(self):
"""Flush the memory pool and cache."""
if len(self.waiting_queue) == 0 and (
self.running_batch is None or len(self.running_batch.reqs) == 0
):
if len(self.waiting_queue) == 0 and self.running_batch.is_empty():
self.cur_batch = None
self.last_batch = None
self.tree_cache.reset()
......@@ -1403,7 +1410,7 @@ class Scheduler(SchedulerOutputProcessorMixin):
logging.warning(
f"Cache not flushed because there are pending requests. "
f"#queue-req: {len(self.waiting_queue)}, "
f"#running-req: {0 if self.running_batch is None else len(self.running_batch.reqs)}"
f"#running-req: {len(self.running_batch.reqs)}"
)
if_success = False
return if_success
......@@ -1453,24 +1460,24 @@ class Scheduler(SchedulerOutputProcessorMixin):
def abort_request(self, recv_req: AbortReq):
# Delete requests in the waiting queue
to_del = None
to_del = []
for i, req in enumerate(self.waiting_queue):
if req.rid == recv_req.rid:
to_del = i
if req.rid.startswith(recv_req.rid):
to_del.append(i)
break
if to_del is not None:
del self.waiting_queue[to_del]
# Sort in reverse order to avoid index issues when deleting
for i in sorted(to_del, reverse=True):
req = self.waiting_queue.pop(i)
logger.debug(f"Abort queued request. {req.rid=}")
return
# Delete requests in the running batch
if self.running_batch:
for req in self.running_batch.reqs:
if req.rid == recv_req.rid and not req.finished():
logger.debug(f"Abort running request. {req.rid=}")
req.to_abort = True
break
for req in self.running_batch.reqs:
if req.rid.startswith(recv_req.rid) and not req.finished():
logger.debug(f"Abort running request. {req.rid=}")
req.to_abort = True
return
def _pause_engine(self) -> Tuple[List[Req], int]:
raise NotImplementedError()
......
......@@ -204,8 +204,17 @@ class SchedulerOutputProcessorMixin:
continue
if self.enable_overlap and req.finished():
# Free the one delayed token
self.token_to_kv_pool_allocator.free(batch.out_cache_loc[i : i + 1])
# Free the one extra delayed token
if self.page_size == 1:
self.token_to_kv_pool_allocator.free(batch.out_cache_loc[i : i + 1])
else:
# Only free when the extra token is in a new page
if (
len(req.origin_input_ids) + len(req.output_ids) - 1
) % self.page_size == 0:
self.token_to_kv_pool_allocator.free(
batch.out_cache_loc[i : i + 1]
)
continue
if batch.spec_algorithm.is_none():
......
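To make the page-boundary check above concrete, here is a tiny illustrative computation with hypothetical request lengths and page_size = 4 (none of these values come from the diff):

# Illustrative only: with overlap scheduling, one extra KV slot is allocated for a
# finished request. It can be freed on its own only when it is the first slot of a
# brand-new page; otherwise it shares a page with slots the cache still owns.
page_size = 4
origin_input_ids = [101, 102, 103, 104, 105]  # 5 prompt tokens (hypothetical)
output_ids = [7, 8, 9, 10]                    # 4 generated tokens (hypothetical)

num_cached = len(origin_input_ids) + len(output_ids) - 1  # 8 slots stay cached
print((num_cached % page_size) == 0)  # True: the 9th (extra) slot opens a fresh page, so it is freed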
......@@ -103,6 +103,9 @@ class TpModelWorkerClient:
self.worker.model_runner.token_to_kv_pool_allocator,
)
def get_kv_cache(self):
return self.worker.model_runner.token_to_kv_pool
def forward_thread_func(self):
try:
with torch.get_device_module(self.device).stream(self.forward_stream):
......@@ -203,7 +206,7 @@ class TpModelWorkerClient:
-(self.future_token_ids_ct + 1),
-(self.future_token_ids_ct + 1 + bs),
-1,
dtype=torch.int32,
dtype=torch.int64,
device=self.device,
)
self.future_token_ids_ct = (
......
from abc import ABC, abstractmethod
from typing import Callable, List, Tuple
from typing import Any, List, Tuple
class BasePrefixCache(ABC):
......@@ -26,24 +26,22 @@ class BasePrefixCache(ABC):
pass
@abstractmethod
def evict(self, num_tokens: int, evict_callback: Callable):
def evict(self, num_tokens: int):
pass
@abstractmethod
def inc_lock_ref(self, node):
def inc_lock_ref(self, node: Any):
pass
@abstractmethod
def dec_lock_ref(self, node):
def dec_lock_ref(self, node: Any):
pass
@abstractmethod
def evictable_size(self):
pass
return 0
@abstractmethod
def protected_size(self):
raise NotImplementedError()
return 0
def total_size(self):
raise NotImplementedError()
......
from __future__ import annotations
"""Cache for chunked prefill, used when RadixCache is disabled."""
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple
from typing import TYPE_CHECKING, Any, Callable, List, Tuple
import torch
......@@ -24,73 +25,40 @@ class ChunkCache(BasePrefixCache):
req_to_token_pool: ReqToTokenPool,
token_to_kv_pool_allocator: TokenToKVPoolAllocator,
):
self.disable = True
self.req_to_token_pool = req_to_token_pool
self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
self.entries: Dict[str, ChunkCacheEntry] = {}
self.reset()
def reset(self):
self.entries = {}
def match_prefix(self, rid: int, key: List[int]) -> Tuple[List[int], int]:
if rid not in self.entries:
return [], None
entry = self.entries[rid]
max_prefix_len = len(key)
return entry.value[:max_prefix_len], entry
pass
def cache_finished_req(self, req: Req, token_ids: Optional[List[int]] = None):
if token_ids is None:
token_id_len = len(req.origin_input_ids) + len(req.output_ids) - 1
else:
token_id_len = len(token_ids)
def match_prefix(self, **unused_kwargs) -> Tuple[List[int], int]:
return [], None
def cache_finished_req(self, req: Req):
kv_indices = self.req_to_token_pool.req_to_token[
req.req_pool_idx, :token_id_len
req.req_pool_idx, : len(req.origin_input_ids) + len(req.output_ids) - 1
]
self.req_to_token_pool.free(req.req_pool_idx)
self.token_to_kv_pool_allocator.free(kv_indices)
if req.rid in self.entries:
del self.entries[req.rid]
def cache_unfinished_req(self, req: Req):
token_id_len = len(req.fill_ids)
kv_indices = self.req_to_token_pool.req_to_token[
req.req_pool_idx, :token_id_len
req.req_pool_idx, : len(req.fill_ids)
]
if req.rid not in self.entries:
self.entries[req.rid] = ChunkCacheEntry(req.rid, kv_indices)
entry = self.entries[req.rid]
entry.value = kv_indices
# `req.prefix_indices` will be used in `PrefillAdder::add_chunked_req` later
req.prefix_indices = kv_indices
req.last_node = entry
def insert(self):
raise NotImplementedError()
def evict(self, num_tokens: int, evict_callback: Callable):
def evict(self, num_tokens: int):
pass
def inc_lock_ref(self, node):
def inc_lock_ref(self, node: Any):
return 0
def dec_lock_ref(self, node):
return 0
def evictable_size(self):
return 0
def pretty_print(self):
return ""
def protected_size(self):
def dec_lock_ref(self, node: Any):
return 0
def pretty_print(self):
......
......@@ -7,13 +7,13 @@ from typing import List, Optional
import torch
from sglang.srt.managers.cache_controller import HiCacheController
from sglang.srt.managers.schedule_batch import Req
from sglang.srt.mem_cache.memory_pool import (
MHATokenToKVPoolHost,
ReqToTokenPool,
TokenToKVPoolAllocator,
)
from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode, _key_match
from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode
from sglang.srt.mem_cache.radix_cache import _key_match_page_size1 as _key_match
logger = logging.getLogger(__name__)
......@@ -122,7 +122,7 @@ class HiRadixCache(RadixCache):
def evictable_size(self):
return self.evictable_size_
def evict(self, num_tokens: int, evict_callback=None):
def evict(self, num_tokens: int):
leaves = self._collect_leaves_device()
heapq.heapify(leaves)
......
......@@ -129,6 +129,7 @@ class TokenToKVPoolAllocator:
self.size = size
self.dtype = dtype
self.device = device
self.page_size = 1
self.free_slots = None
self.is_not_in_free_group = True
......@@ -149,15 +150,14 @@ class TokenToKVPoolAllocator:
select_index = self.free_slots[:need_size]
self.free_slots = self.free_slots[need_size:]
return select_index.to(self.device, non_blocking=True)
return select_index
def free(self, free_index: torch.Tensor):
if free_index.numel() == 0:
return
if self.is_not_in_free_group:
self.free_slots = torch.concat((self.free_slots, free_index.cpu()))
self.free_slots = torch.concat((self.free_slots, free_index))
else:
self.free_group.append(free_index)
......@@ -172,7 +172,9 @@ class TokenToKVPoolAllocator:
def clear(self):
# The padded slot 0 is used for writing dummy outputs from padded tokens.
self.free_slots = torch.arange(1, self.size + 1, dtype=torch.int32)
self.free_slots = torch.arange(
1, self.size + 1, dtype=torch.int64, device=self.device
)
self.is_in_free_group = False
self.free_group = []
......@@ -182,6 +184,7 @@ class MHATokenToKVPool(KVCache):
def __init__(
self,
size: int,
page_size: int,
dtype: torch.dtype,
head_num: int,
head_dim: int,
......@@ -190,6 +193,7 @@ class MHATokenToKVPool(KVCache):
enable_memory_saver: bool,
):
self.size = size
self.page_size = page_size
self.dtype = dtype
self.device = device
if dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
......@@ -207,6 +211,8 @@ class MHATokenToKVPool(KVCache):
self._create_buffers()
self.layer_transfer_counter = None
self.capture_mode = False
self.alt_stream = torch.cuda.Stream()
k_size, v_size = self.get_kv_size_bytes()
logger.info(
......@@ -218,16 +224,16 @@ class MHATokenToKVPool(KVCache):
# [size, head_num, head_dim] for each layer
# The padded slot 0 is used for writing dummy outputs from padded tokens.
self.k_buffer = [
torch.empty(
(self.size + 1, self.head_num, self.head_dim),
torch.zeros(
(self.size + self.page_size, self.head_num, self.head_dim),
dtype=self.store_dtype,
device=self.device,
)
for _ in range(self.layer_num)
]
self.v_buffer = [
torch.empty(
(self.size + 1, self.head_num, self.head_dim),
torch.zeros(
(self.size + self.page_size, self.head_num, self.head_dim),
dtype=self.store_dtype,
device=self.device,
)
......@@ -315,14 +321,44 @@ class MHATokenToKVPool(KVCache):
cache_v.div_(v_scale)
cache_k = cache_k.to(self.dtype)
cache_v = cache_v.to(self.dtype)
if self.store_dtype != self.dtype:
self.k_buffer[layer_id][loc] = cache_k.view(self.store_dtype)
self.v_buffer[layer_id][loc] = cache_v.view(self.store_dtype)
cache_k = cache_k.view(self.store_dtype)
cache_v = cache_v.view(self.store_dtype)
if self.capture_mode:
self.alt_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(self.alt_stream):
self.k_buffer[layer_id][loc] = cache_k
self.v_buffer[layer_id][loc] = cache_v
torch.cuda.current_stream().wait_stream(self.alt_stream)
else:
self.k_buffer[layer_id][loc] = cache_k
self.v_buffer[layer_id][loc] = cache_v
@torch.compile
def fused_downcast(
cache_k: torch.Tensor,
cache_v: torch.Tensor,
k_scale: torch.Tensor,
v_scale: torch.Tensor,
dtype: torch.dtype,
store_dtype: torch.dtype,
max_fp8: float,
min_fp8: float,
):
cache_k = cache_k / k_scale
cache_k = torch.clamp(cache_k, min_fp8, max_fp8)
cache_v = cache_v / v_scale
cache_v = torch.clamp(cache_v, min_fp8, max_fp8)
cache_k = cache_k.to(dtype)
cache_v = cache_v.to(dtype)
cache_k = cache_k.view(store_dtype)
cache_v = cache_v.view(store_dtype)
return cache_k, cache_v
# This compiled version is slower in the unit test
# python3 -m unittest test_bench_serving.TestBenchServing.test_offline_throughput_non_stream_small_batch_size
@torch.compile(dynamic=True, backend=get_compiler_backend())
......@@ -335,6 +371,7 @@ class MLATokenToKVPool(KVCache):
def __init__(
self,
size: int,
page_size: int,
dtype: torch.dtype,
kv_lora_rank: int,
qk_rope_head_dim: int,
......@@ -359,8 +396,8 @@ class MLATokenToKVPool(KVCache):
with memory_saver_adapter.region():
# The padded slot 0 is used for writing dummy outputs from padded tokens.
self.kv_buffer = [
torch.empty(
(size + 1, 1, kv_lora_rank + qk_rope_head_dim),
torch.zeros(
(size + page_size, 1, kv_lora_rank + qk_rope_head_dim),
dtype=self.store_dtype,
device=device,
)
......@@ -400,6 +437,7 @@ class DoubleSparseTokenToKVPool(KVCache):
def __init__(
self,
size: int,
page_size: int,
dtype: torch.dtype,
head_num: int,
head_dim: int,
......@@ -409,6 +447,7 @@ class DoubleSparseTokenToKVPool(KVCache):
enable_memory_saver: bool,
):
self.size = size
self.page_size = page_size
self.dtype = dtype
self.device = device
if dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
......@@ -423,17 +462,21 @@ class DoubleSparseTokenToKVPool(KVCache):
with memory_saver_adapter.region():
# [size, head_num, head_dim] for each layer
self.k_buffer = [
torch.empty((size + 1, head_num, head_dim), dtype=dtype, device=device)
torch.zeros(
(size + page_size, head_num, head_dim), dtype=dtype, device=device
)
for _ in range(layer_num)
]
self.v_buffer = [
torch.empty((size + 1, head_num, head_dim), dtype=dtype, device=device)
torch.zeros(
(size + page_size, head_num, head_dim), dtype=dtype, device=device
)
for _ in range(layer_num)
]
# [size, head_num, heavy_channel_num] for each layer
self.label_buffer = [
torch.empty(
torch.zeros(
(size + 1, head_num, heavy_channel_num), dtype=dtype, device=device
)
for _ in range(layer_num)
......@@ -528,7 +571,7 @@ class MHATokenToKVPoolHost:
f"Allocating {requested_bytes / 1e9:.2f} GB host memory for hierarchical KV cache."
)
self.kv_buffer = torch.empty(
self.kv_buffer = torch.zeros(
(2, self.layer_num, self.size, self.head_num, self.head_dim),
dtype=self.dtype,
device=self.device,
......@@ -548,9 +591,6 @@ class MHATokenToKVPoolHost:
def get_flat_data(self, indices):
return self.kv_buffer[:, :, indices]
def get_flat_data_by_layer(self, indices, layer_id):
return self.kv_buffer[:, layer_id, indices]
def assign_flat_data(self, indices, flat_data):
self.kv_buffer[:, :, indices] = flat_data
......
"""
Copyright 2025 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
"""
Page-aligned memory pool.
"""
import torch
import triton
import triton.language as tl
from sglang.srt.mem_cache.memory_pool import KVCache
from sglang.srt.utils import get_bool_env_var, next_power_of_2
@triton.jit
def alloc_extend_kernel(
pre_lens_ptr,
seq_lens_ptr,
last_loc_ptr,
free_page_ptr,
out_indices,
ret_values,
bs_upper: tl.constexpr,
page_size: tl.constexpr,
max_num_extend_tokens: tl.constexpr,
):
pid = tl.program_id(0)
load_offset = tl.arange(0, bs_upper)
seq_lens = tl.load(seq_lens_ptr + load_offset, mask=load_offset <= pid)
pre_lens = tl.load(pre_lens_ptr + load_offset, mask=load_offset <= pid)
extend_lens = seq_lens - pre_lens
seq_len = tl.load(seq_lens_ptr + pid)
pre_len = tl.load(pre_lens_ptr + pid)
extend_len = seq_len - pre_len
sum_extend_lens = tl.sum(extend_lens)
output_start_loc = sum_extend_lens - extend_len
num_pages_after = (seq_lens + page_size - 1) // page_size
num_pages_before = (pre_lens + page_size - 1) // page_size
num_new_pages = num_pages_after - num_pages_before
num_page_start_loc_self = (seq_len + page_size - 1) // page_size - (
pre_len + page_size - 1
) // page_size
sum_num_new_pages = tl.sum(num_new_pages)
new_page_start_loc = sum_num_new_pages - num_page_start_loc_self
# Return value
if pid == tl.num_programs(0) - 1:
merged_value = (sum_num_new_pages.to(tl.int64)) << 32 | sum_extend_lens.to(
tl.int64
)
tl.store(ret_values, merged_value)
# Part 1: fill the old partial page
last_loc = tl.load(last_loc_ptr + pid)
num_part1 = (
min(seq_len, (pre_len + page_size - 1) // page_size * page_size) - pre_len
)
offset_one_page = tl.arange(0, page_size)
tl.store(
out_indices + output_start_loc + offset_one_page,
last_loc + 1 + offset_one_page,
mask=offset_one_page < num_part1,
)
if pre_len + num_part1 == seq_len:
return
# Part 2: fill the new full pages
num_part2 = (
seq_len // page_size * page_size
- (pre_len + page_size - 1) // page_size * page_size
)
offset_many_page = tl.arange(0, max_num_extend_tokens)
page_start = tl.load(
free_page_ptr + new_page_start_loc + offset_many_page // page_size,
mask=offset_many_page < num_part2,
)
tl.store(
out_indices + output_start_loc + num_part1 + offset_many_page,
page_start * page_size + offset_many_page % page_size,
mask=offset_many_page < num_part2,
)
if pre_len + num_part1 + num_part2 == seq_len:
return
# Part 3: fill the new partial page
num_part3 = seq_len - seq_len // page_size * page_size
start_loc = tl.load(
free_page_ptr + new_page_start_loc + num_page_start_loc_self - 1
)
tl.store(
out_indices + output_start_loc + num_part1 + num_part2 + offset_one_page,
start_loc * page_size + offset_one_page,
mask=offset_one_page < num_part3,
)
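For readers who prefer to follow the allocation logic outside Triton, below is a pure-Python, single-request sketch of the same three phases (fill the old partial page, take whole new pages, open a trailing partial page). The function name and arguments are illustrative, not part of the kernel's interface; the batched kernel additionally derives per-request output and free-page offsets from prefix sums over the batch.

def alloc_extend_reference(pre_len, seq_len, last_loc, free_pages, page_size):
    # Sketch only: mirrors the per-request logic of alloc_extend_kernel above.
    out = []
    # Part 1: fill the remainder of the last partially used page.
    num_part1 = min(seq_len, (pre_len + page_size - 1) // page_size * page_size) - pre_len
    out.extend(last_loc + 1 + i for i in range(num_part1))
    if pre_len + num_part1 == seq_len:
        return out
    # Part 2: fill whole new pages taken from the free-page list.
    num_part2 = seq_len // page_size * page_size - (pre_len + num_part1)
    for i in range(num_part2):
        out.append(free_pages[i // page_size] * page_size + i % page_size)
    if pre_len + num_part1 + num_part2 == seq_len:
        return out
    # Part 3: open one more page for the trailing partial page.
    num_part3 = seq_len - seq_len // page_size * page_size
    last_new_page = free_pages[num_part2 // page_size]
    out.extend(last_new_page * page_size + i for i in range(num_part3))
    return out

# Example (page_size=4): 5 cached tokens whose last KV index is 44, extended to
# 13 tokens with free pages [20, 21] -> [45, 46, 47, 80, 81, 82, 83, 84]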
@triton.jit
def alloc_decode_kernel(
seq_lens_ptr,
last_loc_ptr,
free_page_ptr,
out_indices,
ret_values,
bs_upper: tl.constexpr,
page_size: tl.constexpr,
):
pid = tl.program_id(0)
load_offset = tl.arange(0, bs_upper)
seq_lens = tl.load(seq_lens_ptr + load_offset, mask=load_offset <= pid)
pre_lens = tl.where(load_offset <= pid, seq_lens - 1, seq_lens)
seq_len = tl.load(seq_lens_ptr + pid)
pre_len = seq_len - 1
num_pages_after = (seq_lens + page_size - 1) // page_size
num_pages_before = (pre_lens + page_size - 1) // page_size
num_new_pages = num_pages_after - num_pages_before
num_page_start_loc_self = (seq_len + page_size - 1) // page_size - (
pre_len + page_size - 1
) // page_size
sum_num_new_pages = tl.sum(num_new_pages)
new_page_start_loc = sum_num_new_pages - num_page_start_loc_self
# Return value
if pid == tl.num_programs(0) - 1:
tl.store(ret_values, sum_num_new_pages)
if num_page_start_loc_self == 0:
last_loc = tl.load(last_loc_ptr + pid)
tl.store(out_indices + pid, last_loc + 1)
else:
page = tl.load(free_page_ptr + new_page_start_loc)
tl.store(out_indices + pid, page * page_size)
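A single-request, pure-Python restatement of the decode path (illustrative only): the new token either continues the request's current page or takes the first available free page.

def alloc_decode_reference(seq_len, last_loc, free_pages, page_size):
    # Sketch only: the new token is token number `seq_len`. It needs a fresh page
    # exactly when the previous length (seq_len - 1) filled a whole number of pages.
    if (seq_len - 1) % page_size == 0:
        return free_pages[0] * page_size
    return last_loc + 1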
class PagedTokenToKVPoolAllocator:
"""
An allocator managing the indices to kv cache data.
This class has the same interface as `TokenToKVPoolAllocator` but the output
of one request is always page-aligned.
TODO: fuse last_loc into the kernel.
"""
def __init__(
self,
size: int,
page_size: int,
dtype: torch.dtype,
device: str,
kvcache: KVCache,
):
self.size = size
self.dtype = dtype
self.device = device
self.page_size = page_size
self.num_pages = size // page_size
self.free_pages = None
self.is_not_in_free_group = True
self.free_group = []
self.clear()
self.debug_mode = get_bool_env_var("SGLANG_DEBUG_MEMORY_POOL")
self._kvcache = kvcache
self.ret_values = torch.empty((), dtype=torch.int64, device=self.device)
def available_size(self):
return len(self.free_pages) * self.page_size
def alloc_extend(
self,
prefix_lens: torch.Tensor,
seq_lens: torch.Tensor,
last_loc: torch.Tensor,
extend_num_tokens: int,
):
if self.debug_mode:
assert torch.all(
(last_loc + 1) % self.page_size == prefix_lens % self.page_size
)
bs = len(prefix_lens)
out_indices = torch.empty(
(extend_num_tokens,), dtype=torch.int64, device=self.device
)
alloc_extend_kernel[(bs,)](
prefix_lens,
seq_lens,
last_loc,
self.free_pages,
out_indices,
self.ret_values,
next_power_of_2(bs),
self.page_size,
next_power_of_2(extend_num_tokens),
)
merged_value = self.ret_values.item()
num_new_pages = merged_value >> 32
if num_new_pages > len(self.free_pages):
return None
self.free_pages = self.free_pages[num_new_pages:]
return out_indices
def alloc_decode(
self,
seq_lens: torch.Tensor,
last_loc: torch.Tensor,
):
if self.debug_mode:
assert torch.all(
(last_loc + 2) % self.page_size == seq_lens % self.page_size
)
bs = len(seq_lens)
out_indices = torch.empty((bs,), dtype=torch.int64, device=self.device)
alloc_decode_kernel[(bs,)](
seq_lens,
last_loc,
self.free_pages,
out_indices,
self.ret_values,
next_power_of_2(bs),
self.page_size,
)
num_new_pages = self.ret_values.item()
if num_new_pages > len(self.free_pages):
return None
self.free_pages = self.free_pages[num_new_pages:]
return out_indices
def free(self, free_index: torch.Tensor):
if free_index.numel() == 0:
return
if self.is_not_in_free_group:
free_page_indices = torch.unique(free_index // self.page_size)
self.free_pages = torch.cat((free_page_indices, self.free_pages))
else:
self.free_group.append(free_index)
def free_group_begin(self):
self.is_not_in_free_group = False
self.free_group = []
def free_group_end(self):
self.is_not_in_free_group = True
if self.free_group:
self.free(torch.concat(self.free_group))
def clear(self):
# The padded slot 0 is used for writing dummy outputs from padded tokens.
self.free_pages = torch.arange(
1, self.num_pages + 1, dtype=torch.int64, device=self.device
)
self.is_in_free_group = False
self.free_group = []
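A hypothetical usage sketch of the new allocator, assuming a CUDA device and sglang installed; `kvcache=None` is passed only to exercise the paths shown in this file (the scheduler passes a real KVCache pool):

import torch

from sglang.srt.mem_cache.paged_allocator import PagedTokenToKVPoolAllocator

allocator = PagedTokenToKVPoolAllocator(
    size=64, page_size=4, dtype=torch.float16, device="cuda", kvcache=None
)

# Prefill: one request goes from 0 cached tokens to 6 tokens.
prefix_lens = torch.tensor([0], dtype=torch.int64, device="cuda")
seq_lens = torch.tensor([6], dtype=torch.int64, device="cuda")
last_loc = torch.tensor([-1], dtype=torch.int64, device="cuda")
out = allocator.alloc_extend(prefix_lens, seq_lens, last_loc, extend_num_tokens=6)
print(out, allocator.available_size())  # 6 page-aligned KV indices; 56 free slots

# Decode: the same request appends its 7th token.
seq_lens = torch.tensor([7], dtype=torch.int64, device="cuda")
out_decode = allocator.alloc_decode(seq_lens, last_loc=out[-1:])

# Free everything; indices are grouped back into whole pages internally.
allocator.free(torch.cat([out, out_decode]))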
......@@ -22,7 +22,8 @@ The radix tree data structure for managing the KV cache.
import heapq
import time
from collections import defaultdict
from typing import TYPE_CHECKING, Callable, List, Optional, Tuple
from functools import partial
from typing import TYPE_CHECKING, List, Optional, Tuple
import torch
......@@ -67,7 +68,7 @@ class TreeNode:
return self.last_access_time < other.last_access_time
def _key_match(key0: List, key1: List):
def _key_match_page_size1(key0: List, key1: List):
i = 0
for k0, k1 in zip(key0, key1):
if k0 != k1:
......@@ -76,16 +77,42 @@ def _key_match(key0: List, key1: List):
return i
def _key_match_paged(key0: List, key1: List, page_size: int):
min_len = min(len(key0), len(key1))
i = 0
while i < min_len:
if key0[i : i + page_size] != key1[i : i + page_size]:
break
i += page_size
return i
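A quick illustration of the difference between the two matchers (token values are arbitrary):

# _key_match_paged advances only in whole pages, so a mismatch inside a page
# discards that entire page; the page_size==1 variant matches token by token.
assert _key_match_page_size1([1, 2, 3, 4, 5, 6, 7, 8], [1, 2, 3, 4, 5, 6, 9, 9]) == 6
assert _key_match_paged([1, 2, 3, 4, 5, 6, 7, 8], [1, 2, 3, 4, 5, 6, 9, 9], page_size=4) == 4
# Correspondingly, the RadixCache below keys children by the whole first page,
# e.g. tuple(key[:4]), instead of the single first token.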
class RadixCache(BasePrefixCache):
def __init__(
self,
req_to_token_pool: ReqToTokenPool,
token_to_kv_pool_allocator: TokenToKVPoolAllocator,
page_size: int,
disable: bool = False,
):
self.req_to_token_pool = req_to_token_pool
self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
self.page_size = page_size
self.disable = disable
if self.token_to_kv_pool_allocator:
self.device = self.token_to_kv_pool_allocator.device
else:
self.device = torch.device("cpu")
if self.page_size == 1:
self.key_match_fn = _key_match_page_size1
self.get_child_key_fn = lambda key: key[0]
else:
self.key_match_fn = partial(_key_match_paged, page_size=page_size)
self.get_child_key_fn = lambda key: tuple(key[:page_size])
self.reset()
##### Public API #####
......@@ -109,14 +136,25 @@ class RadixCache(BasePrefixCache):
The last node creates a new child if the prefix is shorter than the last node's value.
"""
if self.disable:
return [], self.root_node
if self.disable or len(key) == 0:
return (
torch.empty(
(0,),
dtype=torch.int32,
device=self.device,
),
self.root_node,
)
if self.page_size != 1:
page_aligned_len = len(key) // self.page_size * self.page_size
key = key[:page_aligned_len]
value, last_node = self._match_prefix_helper(self.root_node, key)
if value:
value = torch.concat(value)
else:
value = torch.tensor([], dtype=torch.int32)
value = torch.empty((0,), dtype=torch.int32, device=self.device)
return value, last_node
def insert(self, key: List, value=None):
......@@ -127,29 +165,33 @@ class RadixCache(BasePrefixCache):
value = [x for x in key]
return self._insert_helper(self.root_node, key, value)
def cache_finished_req(self, req: Req, token_ids: Optional[List[int]] = None):
def cache_finished_req(self, req: Req):
"""Cache request when it finishes."""
if self.disable:
if token_ids is None:
token_ids_len = len(req.origin_input_ids) + len(req.output_ids) - 1
else:
token_ids_len = len(token_ids)
kv_indices = self.req_to_token_pool.req_to_token[
req.req_pool_idx, :token_ids_len
req.req_pool_idx, : len(req.origin_input_ids) + len(req.output_ids) - 1
]
self.token_to_kv_pool_allocator.free(kv_indices)
self.req_to_token_pool.free(req.req_pool_idx)
return
if token_ids is None:
token_ids = (req.origin_input_ids + req.output_ids)[:-1]
token_ids = (req.origin_input_ids + req.output_ids)[:-1]
kv_indices = self.req_to_token_pool.req_to_token[
req.req_pool_idx, : len(token_ids)
]
if self.page_size != 1:
page_aligned_len = len(kv_indices) // self.page_size * self.page_size
page_aligned_kv_indices = kv_indices[:page_aligned_len].clone()
self.token_to_kv_pool_allocator.free(kv_indices[page_aligned_len:])
else:
page_aligned_len = len(kv_indices)
page_aligned_kv_indices = kv_indices.clone()
# Radix Cache takes one ref in memory pool
new_prefix_len = self.insert(token_ids, kv_indices.clone())
new_prefix_len = self.insert(
token_ids[:page_aligned_len], page_aligned_kv_indices
)
self.token_to_kv_pool_allocator.free(
kv_indices[len(req.prefix_indices) : new_prefix_len]
)
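As a worked example of the alignment above (hypothetical numbers, page_size = 4): a finished request that cached 11 tokens keeps only the 8 tokens of its full pages in the radix tree, and the 3 slots in the trailing partial page are freed right away:

page_size = 4
kv_len = 11                                           # len(origin_input_ids) + len(output_ids) - 1
page_aligned_len = kv_len // page_size * page_size    # 8 slots inserted into the tree
num_tail_freed = kv_len - page_aligned_len            # 3 slots returned to the allocator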
......@@ -158,27 +200,32 @@ class RadixCache(BasePrefixCache):
self.req_to_token_pool.free(req.req_pool_idx)
self.dec_lock_ref(req.last_node)
def cache_unfinished_req(self, req: Req, token_ids: Optional[List[int]] = None):
def cache_unfinished_req(self, req: Req):
"""Cache request when it is unfinished."""
if self.disable:
return
if token_ids is None:
token_ids = req.fill_ids
token_ids = req.fill_ids
kv_indices = self.req_to_token_pool.req_to_token[
req.req_pool_idx, : len(token_ids)
]
if self.page_size != 1:
page_aligned_len = len(kv_indices) // self.page_size * self.page_size
page_aligned_kv_indices = kv_indices[:page_aligned_len].clone()
else:
page_aligned_len = len(kv_indices)
page_aligned_kv_indices = kv_indices.clone()
page_aligned_token_ids = token_ids[:page_aligned_len]
# Radix Cache takes one ref in memory pool
new_prefix_len = self.insert(token_ids, kv_indices.clone())
new_prefix_len = self.insert(page_aligned_token_ids, page_aligned_kv_indices)
self.token_to_kv_pool_allocator.free(
kv_indices[len(req.prefix_indices) : new_prefix_len]
)
# The prefix indices could be updated, reuse them
new_indices, new_last_node = self.match_prefix(token_ids)
assert len(new_indices) == len(token_ids)
new_indices, new_last_node = self.match_prefix(page_aligned_token_ids)
self.req_to_token_pool.write(
(req.req_pool_idx, slice(len(req.prefix_indices), len(new_indices))),
new_indices[len(req.prefix_indices) :],
......@@ -186,7 +233,14 @@ class RadixCache(BasePrefixCache):
self.dec_lock_ref(req.last_node)
self.inc_lock_ref(new_last_node)
req.prefix_indices = new_indices
# `req.prefix_indices` will be used in `PrefillAdder::add_chunked_req` later
if self.page_size != 1:
req.prefix_indices = torch.cat(
[new_indices, kv_indices[len(new_indices) :]]
)
else:
req.prefix_indices = new_indices
req.last_node = new_last_node
def pretty_print(self):
......@@ -196,7 +250,7 @@ class RadixCache(BasePrefixCache):
def total_size(self):
return self._total_size_helper()
def evict(self, num_tokens: int, evict_callback: Callable):
def evict(self, num_tokens: int):
if self.disable:
return
......@@ -212,7 +266,7 @@ class RadixCache(BasePrefixCache):
if x.lock_ref > 0:
continue
evict_callback(x.value)
self.token_to_kv_pool_allocator.free(x.value)
num_evicted += len(x.value)
self._delete_leaf(x)
......@@ -254,15 +308,29 @@ class RadixCache(BasePrefixCache):
# protected size refers to the size of the cache that is locked
return self.protected_size_
def all_values_flatten(self):
values = []
def _dfs_helper(node: TreeNode):
for _, child in node.children.items():
values.append(child.value)
_dfs_helper(child)
_dfs_helper(self.root_node)
return torch.concat(values)
##### Internal Helper Functions #####
def _match_prefix_helper(self, node: TreeNode, key: List):
node.last_access_time = time.time()
child_key = self.get_child_key_fn(key)
value = []
while len(key) > 0 and key[0] in node.children.keys():
child = node.children[key[0]]
while len(key) > 0 and child_key in node.children.keys():
child = node.children[child_key]
child.last_access_time = time.time()
prefix_len = _key_match(child.key, key)
prefix_len = self.key_match_fn(child.key, key)
if prefix_len < len(child.key):
new_node = self._split_node(child.key, child, prefix_len)
value.append(new_node.value)
......@@ -272,12 +340,16 @@ class RadixCache(BasePrefixCache):
value.append(child.value)
node = child
key = key[prefix_len:]
if len(key):
child_key = self.get_child_key_fn(key)
return value, node
def _split_node(self, key, child: TreeNode, split_len: int):
# new_node -> child
new_node = TreeNode()
new_node.children = {key[split_len]: child}
new_node.children = {self.get_child_key_fn(key[split_len:]): child}
new_node.parent = child.parent
new_node.lock_ref = child.lock_ref
new_node.key = child.key[:split_len]
......@@ -285,7 +357,7 @@ class RadixCache(BasePrefixCache):
child.parent = new_node
child.key = child.key[split_len:]
child.value = child.value[split_len:]
new_node.parent.children[key[0]] = new_node
new_node.parent.children[self.get_child_key_fn(key)] = new_node
return new_node
def _insert_helper(self, node: TreeNode, key: List, value):
......@@ -293,11 +365,13 @@ class RadixCache(BasePrefixCache):
if len(key) == 0:
return 0
child_key = self.get_child_key_fn(key)
total_prefix_length = 0
while len(key) > 0 and key[0] in node.children.keys():
node = node.children[key[0]]
while len(key) > 0 and child_key in node.children.keys():
node = node.children[child_key]
node.last_access_time = time.time()
prefix_len = _key_match(node.key, key)
prefix_len = self.key_match_fn(node.key, key)
total_prefix_length += prefix_len
key = key[prefix_len:]
value = value[prefix_len:]
......@@ -306,12 +380,15 @@ class RadixCache(BasePrefixCache):
new_node = self._split_node(node.key, node, prefix_len)
node = new_node
if len(key):
child_key = self.get_child_key_fn(key)
if len(key):
new_node = TreeNode()
new_node.parent = node
new_node.key = key
new_node.value = value
node.children[key[0]] = new_node
node.children[child_key] = new_node
self.evictable_size_ += len(value)
return total_prefix_length
......@@ -326,9 +403,13 @@ class RadixCache(BasePrefixCache):
current_node.key[:10],
f"r={current_node.lock_ref}",
)
for _, child in current_node.children.items():
for key, child in current_node.children.items():
stack.append((child, current_indent + 2))
assert key == self.get_child_key_fn(
child.key
), f"{key=}, {self.get_child_key_fn(child.key)=}"
def _delete_leaf(self, node):
for k, v in node.parent.children.items():
if v == node:
......@@ -363,7 +444,7 @@ class RadixCache(BasePrefixCache):
if __name__ == "__main__":
tree = RadixCache(None, None, False)
tree = RadixCache(None, None, page_size=1, disable=False)
tree.insert("Hello")
tree.insert("Hello")
......
......@@ -264,11 +264,15 @@ class CudaGraphRunner:
def model_capture_mode(self):
if hasattr(self.model_runner.model, "capture_mode"):
self.model_runner.model.capture_mode = True
if hasattr(self.model_runner.token_to_kv_pool, "capture_mode"):
self.model_runner.token_to_kv_pool.capture_mode = True
yield
if hasattr(self.model_runner.model, "capture_mode"):
self.model_runner.model.capture_mode = False
if hasattr(self.model_runner.token_to_kv_pool, "capture_mode"):
self.model_runner.token_to_kv_pool.capture_mode = False
def can_run(self, forward_batch: ForwardBatch):
if self.enable_dp_attention:
......
......@@ -38,12 +38,12 @@ import triton
import triton.language as tl
from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
from sglang.srt.utils import get_compiler_backend
from sglang.srt.utils import get_compiler_backend, next_power_of_2
if TYPE_CHECKING:
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.managers.schedule_batch import ImageInputs, ModelWorkerBatch
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
from sglang.srt.mem_cache.memory_pool import KVCache, ReqToTokenPool
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
......@@ -51,9 +51,8 @@ if TYPE_CHECKING:
class ForwardMode(IntEnum):
# Prefill a new sequence. This is deprecated now. "EXTEND" covers this case.
PREFILL = auto()
# Extend a sequence. The KV cache of the beginning part of the sequence is already computed (e.g., system prompt).
# It is also called "prefill" in common terminology.
EXTEND = auto()
# Decode one token.
DECODE = auto()
......@@ -153,6 +152,12 @@ class ForwardBatch:
top_logprobs_nums: Optional[List[int]] = None
token_ids_logprobs: Optional[List[List[int]]] = None
# For logits and logprobs post processing
temp_scaled_logprobs: bool = False
temperature: torch.Tensor = None
top_p_normalized_logprobs: bool = False
top_p: torch.Tensor = None
# Position information
positions: torch.Tensor = None
......@@ -189,7 +194,7 @@ class ForwardBatch:
# Attention backend
req_to_token_pool: ReqToTokenPool = None
token_to_kv_pool: BaseTokenToKVPool = None
token_to_kv_pool: KVCache = None
attn_backend: AttentionBackend = None
# For DP attention
......@@ -229,7 +234,6 @@ class ForwardBatch:
extend_input_logprob_token_ids_gpu = (
batch.extend_input_logprob_token_ids.to(device, non_blocking=True)
)
ret = cls(
forward_mode=batch.forward_mode,
batch_size=len(batch.seq_lens),
......@@ -417,8 +421,8 @@ def compute_position_kernel(
prefix_len = tl.load(extend_prefix_lens + pid) if has_prefix else 0
seq_len = tl.load(extend_seq_lens + pid)
# TODO: optimize this?
cumsum_start = 0
# NOTE: This can be slow for large bs
cumsum_start = tl.cast(0, tl.int64)
for i in range(pid):
cumsum_start += tl.load(extend_seq_lens + i)
......
......@@ -53,6 +53,7 @@ from sglang.srt.mem_cache.memory_pool import (
ReqToTokenPool,
TokenToKVPoolAllocator,
)
from sglang.srt.mem_cache.paged_allocator import PagedTokenToKVPoolAllocator
from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader import get_model
......@@ -430,7 +431,7 @@ class ModelRunner:
self.model_config.model_path = model_path
load_config = LoadConfig(load_format=load_format)
# Only support the DefaultModelLoader for now
# Only support DefaultModelLoader for now
loader = get_model_loader(load_config)
if not isinstance(loader, DefaultModelLoader):
message = f"Failed to get model loader: {loader}."
......@@ -732,6 +733,7 @@ class ModelRunner:
):
self.token_to_kv_pool = MLATokenToKVPool(
self.max_total_num_tokens,
page_size=self.page_size,
dtype=self.kv_cache_dtype,
kv_lora_rank=self.model_config.kv_lora_rank,
qk_rope_head_dim=self.model_config.qk_rope_head_dim,
......@@ -742,6 +744,7 @@ class ModelRunner:
elif self.server_args.enable_double_sparsity:
self.token_to_kv_pool = DoubleSparseTokenToKVPool(
self.max_total_num_tokens,
page_size=self.page_size,
dtype=self.kv_cache_dtype,
head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()),
head_dim=self.model_config.head_dim,
......@@ -753,6 +756,7 @@ class ModelRunner:
else:
self.token_to_kv_pool = MHATokenToKVPool(
self.max_total_num_tokens,
page_size=self.page_size,
dtype=self.kv_cache_dtype,
head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()),
head_dim=self.model_config.head_dim,
......@@ -762,12 +766,21 @@ class ModelRunner:
)
if self.token_to_kv_pool_allocator is None:
self.token_to_kv_pool_allocator = TokenToKVPoolAllocator(
self.max_total_num_tokens,
dtype=self.kv_cache_dtype,
device=self.device,
kvcache=self.token_to_kv_pool,
)
if self.page_size == 1:
self.token_to_kv_pool_allocator = TokenToKVPoolAllocator(
self.max_total_num_tokens,
dtype=self.kv_cache_dtype,
device=self.device,
kvcache=self.token_to_kv_pool,
)
else:
self.token_to_kv_pool_allocator = PagedTokenToKVPoolAllocator(
self.max_total_num_tokens,
page_size=self.page_size,
dtype=self.kv_cache_dtype,
device=self.device,
kvcache=self.token_to_kv_pool,
)
else:
assert self.is_draft_worker
......
......@@ -220,6 +220,8 @@ class ServerArgs:
else:
self.chunked_prefill_size = 8192
assert self.chunked_prefill_size % self.page_size == 0
# Set cuda graph max batch size
if self.cuda_graph_max_bs is None:
# Based on detailed statistics, when serving TP1/TP2 models on lower-end GPUs with HBM<25G, you can either disable cuda graph or set `cuda_graph_max_bs` to a very small value to reduce the memory overhead of creating cuda graphs, with almost no impact on performance. However, when serving models with TP4 or TP8, we need to enable cuda graph to maintain high performance. In this case, we can set `cuda_graph_max_bs` to 80 (half of the default value 160) to reduce the memory overhead of creating cuda graphs. Looking at the logs from TP4 serving of qwen2-72b, a value of 80 is sufficient and can reduce the memory overhead of creating cuda graphs on lower-end GPUs compared to the original 160, avoiding OOM issues.
......
......@@ -1554,6 +1554,13 @@ def set_cuda_arch():
os.environ["TORCH_CUDA_ARCH_LIST"] = f"{arch}{'+PTX' if arch == '9.0' else ''}"
def next_power_of_2(n: int):
return 1 << (n - 1).bit_length() if n > 0 else 1
setattr(triton, "next_power_of_2", next_power_of_2)
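For reference, the helper rounds up to the next power of two and returns 1 for non-positive inputs; a quick illustrative check:

assert [next_power_of_2(n) for n in (0, 1, 6, 64, 65)] == [1, 1, 8, 64, 128]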
def add_prefix(name: str, prefix: str) -> str:
"""Add a weight path prefix to a module name.
......
......@@ -45,6 +45,7 @@ suites = {
TestFile("test_no_overlap_scheduler.py", 262),
TestFile("test_openai_server.py", 124),
TestFile("test_penalty.py", 41),
TestFile("test_page_size.py", 60),
TestFile("test_pytorch_sampling_backend.py", 66),
TestFile("test_radix_attention.py", 167),
TestFile("test_reasoning_content.py", 89),
......
......@@ -42,7 +42,8 @@ class TestDPAttention(unittest.TestCase):
)
metrics = run_eval(args)
assert metrics["score"] >= 0.5
print(f"{metrics=}")
self.assertGreater(metrics["score"], 0.5)
def test_mgsm_en(self):
args = SimpleNamespace(
......@@ -54,7 +55,8 @@ class TestDPAttention(unittest.TestCase):
)
metrics = run_eval(args)
assert metrics["score"] >= 0.8
print(f"{metrics=}")
self.assertGreater(metrics["score"], 0.8)
if __name__ == "__main__":
......
......@@ -184,6 +184,7 @@ class TestGPTQModelDynamicWithMarlin(unittest.TestCase):
"text": "The capital of France is",
"sampling_params": {
"max_new_tokens": max_new_tokens,
"temperature": 0.001,
},
},
)
......