Unverified commit b36afed4 authored by cctry, committed by GitHub

Separate allocation logic from scheduler (#11313)

parent 9aa4502d
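For orientation before the diff: this commit moves the allocation helpers out of `ScheduleBatch` into module-level functions in `sglang.srt.mem_cache.common`. The sketch below is a minimal illustration of the new call pattern based only on the hunks in this commit; the wrapper names (`prepare_extend_sketch`, `prepare_decode_sketch`) and the pre-populated `batch` argument are assumptions for illustration, not code from this PR.

```python
# Minimal sketch of the new allocation API in sglang.srt.mem_cache.common
# (introduced by this PR). Assumes an already-populated ScheduleBatch whose
# prefix_lens, extend_lens, seq_lens, seq_lens_cpu, and extend_num_tokens
# fields are set, as prepare_for_extend() now does before calling the helper.
from sglang.srt.mem_cache.common import alloc_for_decode, alloc_for_extend


def prepare_extend_sketch(batch):
    # Allocates request slots and KV cache for an extend/prefill batch and
    # writes the indices into batch.req_to_token_pool.
    out_cache_loc, req_pool_indices_device, req_pool_indices = alloc_for_extend(batch)
    batch.out_cache_loc = out_cache_loc
    batch.req_pool_indices = req_pool_indices_device
    return req_pool_indices


def prepare_decode_sketch(batch):
    # Allocates one token per request for a decode step; eviction and the
    # paged/non-paged split are handled inside the helper.
    batch.out_cache_loc = alloc_for_decode(batch, token_per_req=1)
```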
@@ -51,6 +51,7 @@ import logging
import multiprocessing
import os
import time
from types import SimpleNamespace
from typing import Tuple
import numpy as np
@@ -257,11 +258,18 @@ def prepare_synthetic_inputs_for_latency_test(
@torch.no_grad
def extend(reqs, model_runner):
# Create dummy tree_cache for benchmarks (no prefix caching, just allocation)
dummy_tree_cache = SimpleNamespace(
page_size=1,
device=model_runner.device,
token_to_kv_pool_allocator=model_runner.token_to_kv_pool_allocator,
)
batch = ScheduleBatch.init_new(
reqs=reqs,
req_to_token_pool=model_runner.req_to_token_pool,
token_to_kv_pool_allocator=model_runner.token_to_kv_pool_allocator,
- tree_cache=None,
+ tree_cache=dummy_tree_cache,
model_config=model_runner.model_config,
enable_overlap=False,
spec_algorithm=SpeculativeAlgorithm.NONE,
...
@@ -45,8 +45,6 @@ from typing import TYPE_CHECKING, Any, List, Optional, Set, Tuple, Union
import numpy as np
import torch
import triton
import triton.language as tl
from sglang.global_config import global_config
from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
@@ -62,6 +60,7 @@ from sglang.srt.mem_cache.allocator import (
)
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
from sglang.srt.mem_cache.common import alloc_for_decode, alloc_for_extend
from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool, ReqToTokenPool
from sglang.srt.mem_cache.radix_cache import RadixKey
from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
@@ -70,7 +69,7 @@ from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, Forw
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import ServerArgs
- from sglang.srt.utils import flatten_nested_list, support_triton
+ from sglang.srt.utils import flatten_nested_list
if TYPE_CHECKING:
from sglang.srt.configs.model_config import ModelConfig
@@ -1001,158 +1000,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
def is_empty(self):
return len(self.reqs) == 0
def alloc_req_slots(self, num_reqs: int, reqs: Optional[List[Req]] = None):
if isinstance(self.req_to_token_pool, HybridReqToTokenPool):
req_pool_indices = self.req_to_token_pool.alloc(num_reqs, reqs)
else:
req_pool_indices = self.req_to_token_pool.alloc(num_reqs)
if req_pool_indices is None:
raise RuntimeError(
"alloc_req_slots runs out of memory. "
"Please set a smaller number for `--max-running-requests`. "
f"{self.req_to_token_pool.available_size()=}, "
f"{num_reqs=}, "
)
return req_pool_indices
def alloc_token_slots(self, num_tokens: int, backup_state: bool = False):
self._evict_tree_cache_if_needed(num_tokens)
if backup_state:
state = self.token_to_kv_pool_allocator.backup_state()
out_cache_loc = self.token_to_kv_pool_allocator.alloc(num_tokens)
if out_cache_loc is None:
phase_str = "Prefill" if self.forward_mode.is_extend() else "Decode"
error_msg = (
f"{phase_str} out of memory. Try to lower your batch size.\n"
f"Try to allocate {num_tokens} tokens.\n"
f"{self._available_and_evictable_str()}"
)
logger.error(error_msg)
if self.tree_cache is not None:
self.tree_cache.pretty_print()
raise RuntimeError(error_msg)
if backup_state:
return out_cache_loc, state
else:
return out_cache_loc
def alloc_paged_token_slots_extend(
self,
prefix_lens: torch.Tensor,
prefix_lens_cpu: torch.Tensor,
seq_lens: torch.Tensor,
seq_lens_cpu: torch.Tensor,
last_loc: torch.Tensor,
extend_num_tokens: int,
backup_state: bool = False,
):
# Over estimate the number of tokens: assume each request needs a new page.
num_tokens = (
extend_num_tokens
+ len(seq_lens_cpu) * self.token_to_kv_pool_allocator.page_size
)
self._evict_tree_cache_if_needed(num_tokens)
if backup_state:
state = self.token_to_kv_pool_allocator.backup_state()
out_cache_loc = self.token_to_kv_pool_allocator.alloc_extend(
prefix_lens,
prefix_lens_cpu,
seq_lens,
seq_lens_cpu,
last_loc,
extend_num_tokens,
)
if out_cache_loc is None:
error_msg = (
f"Prefill out of memory. Try to lower your batch size.\n"
f"Try to allocate {extend_num_tokens} tokens.\n"
f"{self._available_and_evictable_str()}"
)
logger.error(error_msg)
raise RuntimeError(error_msg)
if backup_state:
return out_cache_loc, state
else:
return out_cache_loc
def alloc_paged_token_slots_decode(
self,
seq_lens: torch.Tensor,
seq_lens_cpu: torch.Tensor,
last_loc: torch.Tensor,
backup_state: bool = False,
):
# Over estimate the number of tokens: assume each request needs a new page.
num_tokens = len(seq_lens) * self.token_to_kv_pool_allocator.page_size
self._evict_tree_cache_if_needed(num_tokens)
if backup_state:
state = self.token_to_kv_pool_allocator.backup_state()
out_cache_loc = self.token_to_kv_pool_allocator.alloc_decode(
seq_lens, seq_lens_cpu, last_loc
)
if out_cache_loc is None:
error_msg = (
f"Decode out of memory. Try to lower your batch size.\n"
f"Try to allocate {len(seq_lens)} tokens.\n"
f"{self._available_and_evictable_str()}"
)
logger.error(error_msg)
raise RuntimeError(error_msg)
if backup_state:
return out_cache_loc, state
else:
return out_cache_loc
def write_cache_indices(
self,
req_pool_indices: List[int],
prefix_lens: List[int],
seq_lens: List[int],
extend_lens: List[int],
out_cache_loc: torch.Tensor,
req_pool_indices_tensor: torch.Tensor,
prefix_lens_tensor: torch.Tensor,
seq_lens_tensor: torch.Tensor,
extend_lens_tensor: torch.Tensor,
prefix_tensors: list[torch.Tensor],
):
if support_triton(global_server_args_dict.get("attention_backend")):
prefix_pointers = torch.tensor(
[t.data_ptr() for t in prefix_tensors], device=self.device
)
# TODO: some tensors can be reused for ForwardBatchInfo (e.g., extend_lens, cumsum_start)
write_req_to_token_pool_triton[(len(req_pool_indices),)](
self.req_to_token_pool.req_to_token,
req_pool_indices_tensor,
prefix_pointers,
prefix_lens_tensor,
seq_lens_tensor,
extend_lens_tensor,
out_cache_loc,
self.req_to_token_pool.req_to_token.shape[1],
)
else:
pt = 0
for i in range(len(req_pool_indices)):
self.req_to_token_pool.write(
(req_pool_indices[i], slice(0, prefix_lens[i])),
prefix_tensors[i],
)
self.req_to_token_pool.write(
(req_pool_indices[i], slice(prefix_lens[i], seq_lens[i])),
out_cache_loc[pt : pt + extend_lens[i]],
)
pt += extend_lens[i]
def prepare_encoder_info_extend(self, input_ids: List[int], seq_lens: List[int]):
self.encoder_lens_cpu = []
self.encoder_cached = []
@@ -1253,10 +1100,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
orig_seq_lens_tensor = torch.tensor(orig_seq_lens, dtype=torch.int32).to(
self.device, non_blocking=True
)
prefix_lens_tensor = torch.tensor(
prefix_lens, dtype=torch.int64, device=self.device
)
prefix_lens_cpu_tensor = torch.tensor(prefix_lens, dtype=torch.int64)
token_type_ids_tensor = None
if len(token_type_ids) > 0:
@@ -1264,48 +1107,16 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
sum(token_type_ids, []), dtype=torch.int64
).to(self.device, non_blocking=True)
- extend_lens_tensor = seq_lens_tensor - prefix_lens_tensor
- # Allocate req slots
- bs = len(self.reqs)
- req_pool_indices = self.alloc_req_slots(bs, self.reqs)
- req_pool_indices_tensor = torch.tensor(req_pool_indices, dtype=torch.int64).to(
- self.device, non_blocking=True
- )
+ # Set batch fields needed by alloc_for_extend
+ self.prefix_lens = prefix_lens
+ self.extend_lens = extend_lens
+ self.seq_lens = seq_lens_tensor
+ self.seq_lens_cpu = seq_lens_cpu
+ self.extend_num_tokens = extend_num_tokens
# Allocate memory
- if self.token_to_kv_pool_allocator.page_size == 1:
- out_cache_loc = self.alloc_token_slots(extend_num_tokens)
- else:
- last_loc = [
- (
- r.prefix_indices[-1:]
- if len(r.prefix_indices) > 0
- else torch.tensor([-1], device=self.device)
- )
- for r in self.reqs
- ]
- out_cache_loc = self.alloc_paged_token_slots_extend(
- prefix_lens_tensor,
- prefix_lens_cpu_tensor,
- seq_lens_tensor,
- seq_lens_cpu,
- torch.cat(last_loc),
- extend_num_tokens,
- )
- # Write allocated tokens to req_to_token_pool
- self.write_cache_indices(
- req_pool_indices,
- prefix_lens,
- seq_lens,
- extend_lens,
- out_cache_loc,
- req_pool_indices_tensor,
- prefix_lens_tensor,
- seq_lens_tensor,
- extend_lens_tensor,
- [r.prefix_indices for r in reqs],
- )
+ out_cache_loc, req_pool_indices_tensor, req_pool_indices = alloc_for_extend(
+ self
+ )
# Set fields
@@ -1317,12 +1128,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
req.req_pool_idx = req_pool_indices[i]
assert seq_len - pre_len == req.extend_input_len
if pre_len > 0:
if isinstance(self.tree_cache, SWAChunkCache):
self.tree_cache.evict_swa(
req, pre_len, self.model_config.attention_chunk_size
)
# If input_embeds are available, store them
if req.input_embeds is not None:
# If req.input_embeds is already a list, append its content directly
@@ -1414,8 +1219,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
self.input_ids = input_ids_tensor
self.req_pool_indices = req_pool_indices_tensor
self.seq_lens = seq_lens_tensor
self.seq_lens_cpu = seq_lens_cpu
self.orig_seq_lens = orig_seq_lens_tensor
self.out_cache_loc = out_cache_loc
self.input_embeds = (
@@ -1439,9 +1242,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
self.token_ids_logprobs = [r.token_ids_logprob for r in reqs]
self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]
self.extend_num_tokens = extend_num_tokens
self.prefix_lens = prefix_lens
self.extend_lens = extend_lens
self.extend_input_logprob_token_ids = extend_input_logprob_token_ids
if self.model_config.is_encoder_decoder:
@@ -1681,11 +1481,12 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
self.output_ids = None
if self.model_config.is_encoder_decoder:
- locs = self.encoder_lens + self.seq_lens
self.prepare_encoder_info_decode()
- else:
- locs = self.seq_lens.clone()
+ # Allocate memory
+ self.out_cache_loc = alloc_for_decode(self, token_per_req=1)
+ # Update seq_lens after allocation
if self.enable_overlap:
# Do not use in-place operations in the overlap mode
self.seq_lens = self.seq_lens + 1
@@ -1698,28 +1499,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
self.orig_seq_lens.add_(1)
self.seq_lens_sum += bs
# free memory
if isinstance(self.tree_cache, SWAChunkCache):
for req in self.reqs:
self.tree_cache.evict_swa(
req, req.seqlen - 1, self.model_config.attention_chunk_size
)
# Allocate memory
if self.token_to_kv_pool_allocator.page_size == 1:
self.out_cache_loc = self.alloc_token_slots(bs)
else:
last_loc = self.req_to_token_pool.req_to_token[
self.req_pool_indices, self.seq_lens - 2
]
self.out_cache_loc = self.alloc_paged_token_slots_decode(
self.seq_lens, self.seq_lens_cpu, last_loc
)
self.req_to_token_pool.write(
(self.req_pool_indices, locs), self.out_cache_loc.to(torch.int32)
)
def filter_batch(
self,
chunked_req_to_exclude: Optional[Union[Req, List[Req]]] = None,
@@ -1940,23 +1719,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
else:
return self.token_to_kv_pool_allocator.available_size() >= num_tokens
def _available_and_evictable_str(self) -> str:
if self.is_hybrid:
full_available_size = self.token_to_kv_pool_allocator.full_available_size()
swa_available_size = self.token_to_kv_pool_allocator.swa_available_size()
full_evictable_size = self.tree_cache.full_evictable_size()
swa_evictable_size = self.tree_cache.swa_evictable_size()
return (
f"Available full tokens: {full_available_size + full_evictable_size} ({full_available_size=} + {full_evictable_size=})\n"
f"Available swa tokens: {swa_available_size + swa_evictable_size} ({swa_available_size=} + {swa_evictable_size=})\n"
f"Full LRU list evictable size: {self.tree_cache.full_lru_list_evictable_size()}\n"
f"SWA LRU list evictable size: {self.tree_cache.swa_lru_list_evictable_size()}\n"
)
else:
available_size = self.token_to_kv_pool_allocator.available_size()
evictable_size = self.tree_cache.evictable_size()
return f"Available tokens: {available_size + evictable_size} ({available_size=} + {evictable_size=})\n"
def __str__(self):
return (
f"ScheduleBatch(forward_mode={self.forward_mode.name if self.forward_mode else 'None'}, "
@@ -2038,128 +1800,3 @@ class ModelWorkerBatch:
# Whether this batch is prefill-only (no token generation needed)
is_prefill_only: bool = False
@triton.jit
def write_req_to_token_pool_triton(
req_to_token_ptr, # [max_batch, max_context_len]
req_pool_indices,
prefix_tensors,
pre_lens,
seq_lens,
extend_lens,
out_cache_loc,
req_to_token_ptr_stride: tl.constexpr,
):
BLOCK_SIZE: tl.constexpr = 512
pid = tl.program_id(0)
req_pool_index = tl.load(req_pool_indices + pid)
pre_len = tl.load(pre_lens + pid)
seq_len = tl.load(seq_lens + pid)
prefix_tensor = tl.load(prefix_tensors + pid).to(tl.pointer_type(tl.int64))
# write prefix
num_loop = tl.cdiv(pre_len, BLOCK_SIZE)
for i in range(num_loop):
offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
mask = offset < pre_len
value = tl.load(prefix_tensor + offset, mask=mask)
tl.store(
req_to_token_ptr + req_pool_index * req_to_token_ptr_stride + offset,
value,
mask=mask,
)
# NOTE: This can be slow for large bs
cumsum_start = tl.cast(0, tl.int64)
for i in range(pid):
cumsum_start += tl.load(extend_lens + i)
num_loop = tl.cdiv(seq_len - pre_len, BLOCK_SIZE)
for i in range(num_loop):
offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
mask = offset < (seq_len - pre_len)
value = tl.load(out_cache_loc + cumsum_start + offset, mask=mask)
tl.store(
req_to_token_ptr
+ req_pool_index * req_to_token_ptr_stride
+ offset
+ pre_len,
value,
mask=mask,
)
def get_last_loc(
req_to_token: torch.Tensor,
req_pool_indices_tensor: torch.Tensor,
prefix_lens_tensor: torch.Tensor,
) -> torch.Tensor:
if (
global_server_args_dict["attention_backend"] != "ascend"
and global_server_args_dict["attention_backend"] != "torch_native"
):
impl = get_last_loc_triton
else:
impl = get_last_loc_torch
return impl(req_to_token, req_pool_indices_tensor, prefix_lens_tensor)
def get_last_loc_torch(
req_to_token: torch.Tensor,
req_pool_indices_tensor: torch.Tensor,
prefix_lens_tensor: torch.Tensor,
) -> torch.Tensor:
return torch.where(
prefix_lens_tensor > 0,
req_to_token[req_pool_indices_tensor, prefix_lens_tensor - 1],
torch.full_like(prefix_lens_tensor, -1),
)
@triton.jit
def get_last_loc_kernel(
req_to_token,
req_pool_indices_tensor,
prefix_lens_tensor,
result,
num_tokens,
req_to_token_stride,
BLOCK_SIZE: tl.constexpr,
):
pid = tl.program_id(0)
offset = tl.arange(0, BLOCK_SIZE) + pid * BLOCK_SIZE
mask = offset < num_tokens
prefix_lens = tl.load(prefix_lens_tensor + offset, mask=mask, other=0)
req_pool_indices = tl.load(req_pool_indices_tensor + offset, mask=mask, other=0)
token_mask = prefix_lens > 0
token_index = req_pool_indices * req_to_token_stride + (prefix_lens - 1)
tokens = tl.load(req_to_token + token_index, mask=token_mask, other=-1)
tl.store(result + offset, tokens, mask=mask)
def get_last_loc_triton(
req_to_token: torch.Tensor,
req_pool_indices_tensor: torch.Tensor,
prefix_lens_tensor: torch.Tensor,
) -> torch.Tensor:
BLOCK_SIZE = 256
num_tokens = prefix_lens_tensor.shape[0]
result = torch.empty_like(prefix_lens_tensor)
grid = (triton.cdiv(num_tokens, BLOCK_SIZE),)
get_last_loc_kernel[grid](
req_to_token,
req_pool_indices_tensor,
prefix_lens_tensor,
result,
num_tokens,
req_to_token.stride(0),
BLOCK_SIZE,
)
return result
from __future__ import annotations
import logging
from typing import TYPE_CHECKING
import torch
import triton
import triton.language as tl
from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool, ReqToTokenPool
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import support_triton
if TYPE_CHECKING:
from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
logger = logging.getLogger(__name__)
GLOBAL_SERVER_ARGS_KEYS = ["attention_backend"]
global_server_args_dict = {k: getattr(ServerArgs, k) for k in GLOBAL_SERVER_ARGS_KEYS}
@triton.jit
def write_req_to_token_pool_triton(
req_to_token_ptr, # [max_batch, max_context_len]
req_pool_indices,
prefix_tensors,
pre_lens,
seq_lens,
extend_lens,
out_cache_loc,
req_to_token_ptr_stride: tl.constexpr,
):
BLOCK_SIZE: tl.constexpr = 512
pid = tl.program_id(0)
req_pool_index = tl.load(req_pool_indices + pid)
pre_len = tl.load(pre_lens + pid)
seq_len = tl.load(seq_lens + pid)
prefix_tensor = tl.load(prefix_tensors + pid).to(tl.pointer_type(tl.int64))
# write prefix
num_loop = tl.cdiv(pre_len, BLOCK_SIZE)
for i in range(num_loop):
offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
mask = offset < pre_len
value = tl.load(prefix_tensor + offset, mask=mask)
tl.store(
req_to_token_ptr + req_pool_index * req_to_token_ptr_stride + offset,
value,
mask=mask,
)
# NOTE: This can be slow for large bs
cumsum_start = tl.cast(0, tl.int64)
for i in range(pid):
cumsum_start += tl.load(extend_lens + i)
num_loop = tl.cdiv(seq_len - pre_len, BLOCK_SIZE)
for i in range(num_loop):
offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
mask = offset < (seq_len - pre_len)
value = tl.load(out_cache_loc + cumsum_start + offset, mask=mask)
tl.store(
req_to_token_ptr
+ req_pool_index * req_to_token_ptr_stride
+ offset
+ pre_len,
value,
mask=mask,
)
def write_cache_indices(
out_cache_loc: torch.Tensor,
req_pool_indices_tensor: torch.Tensor,
req_pool_indices_cpu: torch.Tensor,
prefix_lens_tensor: torch.Tensor,
prefix_lens_cpu: torch.Tensor,
seq_lens_tensor: torch.Tensor,
seq_lens_cpu: torch.Tensor,
extend_lens_tensor: torch.Tensor,
extend_lens_cpu: torch.Tensor,
prefix_tensors: list[torch.Tensor],
req_to_token_pool: ReqToTokenPool,
):
if support_triton(global_server_args_dict.get("attention_backend")):
prefix_pointers = torch.tensor(
[t.data_ptr() for t in prefix_tensors],
device=req_to_token_pool.device,
)
# TODO: some tensors can be reused for ForwardBatchInfo (e.g., extend_lens, cumsum_start)
write_req_to_token_pool_triton[(req_pool_indices_tensor.shape[0],)](
req_to_token_pool.req_to_token,
req_pool_indices_tensor,
prefix_pointers,
prefix_lens_tensor,
seq_lens_tensor,
extend_lens_tensor,
out_cache_loc,
req_to_token_pool.req_to_token.shape[1],
)
else:
pt = 0
for i in range(req_pool_indices_cpu.shape[0]):
req_idx = req_pool_indices_cpu[i].item()
prefix_len = prefix_lens_cpu[i].item()
seq_len = seq_lens_cpu[i].item()
extend_len = extend_lens_cpu[i].item()
req_to_token_pool.write(
(req_idx, slice(0, prefix_len)),
prefix_tensors[i],
)
req_to_token_pool.write(
(req_idx, slice(prefix_len, seq_len)),
out_cache_loc[pt : pt + extend_len],
)
pt += extend_len
def get_last_loc(
req_to_token: torch.Tensor,
req_pool_indices_tensor: torch.Tensor,
prefix_lens_tensor: torch.Tensor,
) -> torch.Tensor:
if (
global_server_args_dict["attention_backend"] != "ascend"
and global_server_args_dict["attention_backend"] != "torch_native"
):
impl = get_last_loc_triton
else:
impl = get_last_loc_torch
return impl(req_to_token, req_pool_indices_tensor, prefix_lens_tensor)
def get_last_loc_torch(
req_to_token: torch.Tensor,
req_pool_indices_tensor: torch.Tensor,
prefix_lens_tensor: torch.Tensor,
) -> torch.Tensor:
return torch.where(
prefix_lens_tensor > 0,
req_to_token[req_pool_indices_tensor, prefix_lens_tensor - 1],
torch.full_like(prefix_lens_tensor, -1),
)
@triton.jit
def get_last_loc_kernel(
req_to_token,
req_pool_indices_tensor,
prefix_lens_tensor,
result,
num_tokens,
req_to_token_stride,
BLOCK_SIZE: tl.constexpr,
):
pid = tl.program_id(0)
offset = tl.arange(0, BLOCK_SIZE) + pid * BLOCK_SIZE
mask = offset < num_tokens
prefix_lens = tl.load(prefix_lens_tensor + offset, mask=mask, other=0)
req_pool_indices = tl.load(req_pool_indices_tensor + offset, mask=mask, other=0)
token_mask = prefix_lens > 0
token_index = req_pool_indices * req_to_token_stride + (prefix_lens - 1)
tokens = tl.load(req_to_token + token_index, mask=token_mask, other=-1)
tl.store(result + offset, tokens, mask=mask)
def get_last_loc_triton(
req_to_token: torch.Tensor,
req_pool_indices_tensor: torch.Tensor,
prefix_lens_tensor: torch.Tensor,
) -> torch.Tensor:
BLOCK_SIZE = 256
num_tokens = prefix_lens_tensor.shape[0]
result = torch.empty_like(prefix_lens_tensor)
grid = (triton.cdiv(num_tokens, BLOCK_SIZE),)
get_last_loc_kernel[grid](
req_to_token,
req_pool_indices_tensor,
prefix_lens_tensor,
result,
num_tokens,
req_to_token.stride(0),
BLOCK_SIZE,
)
return result
def alloc_token_slots(
tree_cache: BasePrefixCache,
num_tokens: int,
backup_state: bool = False,
):
allocator = tree_cache.token_to_kv_pool_allocator
evict_from_tree_cache(tree_cache, num_tokens)
state = None
if backup_state:
state = allocator.backup_state()
out_cache_loc = allocator.alloc(num_tokens)
if out_cache_loc is None:
error_msg = (
f"Out of memory. Try to lower your batch size.\n"
f"Try to allocate {num_tokens} tokens.\n"
f"{available_and_evictable_str(tree_cache)}"
)
logger.error(error_msg)
if tree_cache is not None:
tree_cache.pretty_print()
raise RuntimeError(error_msg)
return (out_cache_loc, state) if backup_state else out_cache_loc
def evict_from_tree_cache(tree_cache: BasePrefixCache | None, num_tokens: int):
if tree_cache is None:
return
if isinstance(tree_cache, (SWAChunkCache, ChunkCache)):
return
allocator = tree_cache.token_to_kv_pool_allocator
# Check if this is a hybrid allocator
if hasattr(allocator, "full_available_size"):
# Hybrid allocator
full_available_size = allocator.full_available_size()
swa_available_size = allocator.swa_available_size()
if full_available_size < num_tokens or swa_available_size < num_tokens:
full_num_tokens = max(0, num_tokens - full_available_size)
swa_num_tokens = max(0, num_tokens - swa_available_size)
tree_cache.evict(full_num_tokens, swa_num_tokens)
else:
# Standard allocator
if allocator.available_size() < num_tokens:
tree_cache.evict(num_tokens)
def alloc_paged_token_slots_extend(
tree_cache: BasePrefixCache,
prefix_lens: torch.Tensor,
prefix_lens_cpu: torch.Tensor,
seq_lens: torch.Tensor,
seq_lens_cpu: torch.Tensor,
last_loc: torch.Tensor,
extend_num_tokens: int,
backup_state: bool = False,
):
# Over estimate the number of tokens: assume each request needs a new page.
allocator = tree_cache.token_to_kv_pool_allocator
num_tokens = extend_num_tokens + len(seq_lens_cpu) * allocator.page_size
evict_from_tree_cache(tree_cache, num_tokens)
state = None
if backup_state:
state = allocator.backup_state()
out_cache_loc = allocator.alloc_extend(
prefix_lens,
prefix_lens_cpu,
seq_lens,
seq_lens_cpu,
last_loc,
extend_num_tokens,
)
if out_cache_loc is None:
error_msg = (
f"Prefill out of memory. Try to lower your batch size.\n"
f"Try to allocate {extend_num_tokens} tokens.\n"
f"{available_and_evictable_str(tree_cache)}"
)
logger.error(error_msg)
if tree_cache is not None:
tree_cache.pretty_print()
raise RuntimeError(error_msg)
return (out_cache_loc, state) if backup_state else out_cache_loc
def alloc_req_slots(
req_to_token_pool: ReqToTokenPool,
num_reqs: int,
reqs: list[Req] | None,
) -> list[int]:
"""Allocate request slots from the pool."""
if isinstance(req_to_token_pool, HybridReqToTokenPool):
req_pool_indices = req_to_token_pool.alloc(num_reqs, reqs)
else:
req_pool_indices = req_to_token_pool.alloc(num_reqs)
if req_pool_indices is None:
raise RuntimeError(
"alloc_req_slots runs out of memory. "
"Please set a smaller number for `--max-running-requests`. "
f"{req_to_token_pool.available_size()=}, "
f"{num_reqs=}, "
)
return req_pool_indices
def alloc_for_extend(
batch: ScheduleBatch,
) -> tuple[torch.Tensor, torch.Tensor, list[int]]:
"""
Allocate KV cache for extend batch and write to req_to_token_pool.
Returns:
out_cache_loc: allocated cache locations
req_pool_indices_device: request pool indices at a device tensor
req_pool_indices: request pool indices as list
"""
# free out-of-window swa tokens
if isinstance(batch.tree_cache, SWAChunkCache):
for req, pre_len in zip(batch.reqs, batch.prefix_lens):
batch.tree_cache.evict_swa(
req, pre_len, batch.model_config.attention_chunk_size
)
bs = len(batch.reqs)
prefix_tensors = [r.prefix_indices for r in batch.reqs]
# Create tensors for allocation
prefix_lens_cpu = torch.tensor(batch.prefix_lens, dtype=torch.int64)
extend_lens_cpu = torch.tensor(batch.extend_lens, dtype=torch.int64)
prefix_lens_device = prefix_lens_cpu.to(batch.device, non_blocking=True)
extend_lens_device = extend_lens_cpu.to(batch.device, non_blocking=True)
# Allocate req slots
req_pool_indices = alloc_req_slots(batch.req_to_token_pool, bs, batch.reqs)
req_pool_indices_cpu = torch.tensor(req_pool_indices, dtype=torch.int64)
req_pool_indices_device = req_pool_indices_cpu.to(batch.device, non_blocking=True)
# Allocate KV cache (throws exception on failure)
if batch.tree_cache.page_size == 1:
out_cache_loc = alloc_token_slots(batch.tree_cache, batch.extend_num_tokens)
else:
# Paged allocation - build last_loc
last_loc = [
(
t[-1:]
if len(t) > 0
else torch.tensor([-1], device=batch.tree_cache.device)
)
for t in prefix_tensors
]
out_cache_loc = alloc_paged_token_slots_extend(
tree_cache=batch.tree_cache,
prefix_lens=prefix_lens_device,
prefix_lens_cpu=prefix_lens_cpu,
seq_lens=batch.seq_lens,
seq_lens_cpu=batch.seq_lens_cpu,
last_loc=torch.cat(last_loc),
extend_num_tokens=batch.extend_num_tokens,
)
# Write to req_to_token_pool
write_cache_indices(
out_cache_loc,
req_pool_indices_device,
req_pool_indices_cpu,
prefix_lens_device,
prefix_lens_cpu,
batch.seq_lens,
batch.seq_lens_cpu,
extend_lens_device,
extend_lens_cpu,
prefix_tensors,
batch.req_to_token_pool,
)
return out_cache_loc, req_pool_indices_device, req_pool_indices
def alloc_paged_token_slots_decode(
tree_cache: BasePrefixCache,
seq_lens: torch.Tensor,
seq_lens_cpu: torch.Tensor,
last_loc: torch.Tensor,
token_per_req: int = 1,
) -> torch.Tensor:
"""Allocate paged KV cache for decode batch."""
allocator = tree_cache.token_to_kv_pool_allocator
# Over estimate the number of tokens: assume each request needs a new page.
num_tokens = len(seq_lens) * allocator.page_size
evict_from_tree_cache(tree_cache, num_tokens)
out_cache_loc = allocator.alloc_decode(seq_lens, seq_lens_cpu, last_loc)
if out_cache_loc is None:
error_msg = (
f"Decode out of memory. Try to lower your batch size.\n"
f"Try to allocate {len(seq_lens) * token_per_req} tokens.\n"
f"{available_and_evictable_str(tree_cache)}"
)
logger.error(error_msg)
if tree_cache is not None:
tree_cache.pretty_print()
raise RuntimeError(error_msg)
return out_cache_loc
def alloc_for_decode(batch: ScheduleBatch, token_per_req: int) -> torch.Tensor:
"""
Allocate KV cache for decode batch and write to req_to_token_pool.
Returns:
out_cache_loc: allocated cache locations
"""
if isinstance(batch.tree_cache, SWAChunkCache):
for req in batch.reqs:
batch.tree_cache.evict_swa(
req, req.seqlen - 1, batch.model_config.attention_chunk_size
)
bs = batch.seq_lens.shape[0]
if batch.tree_cache.page_size == 1:
# Non-paged allocation
out_cache_loc = alloc_token_slots(batch.tree_cache, bs * token_per_req)
else:
# Paged allocation
last_loc = batch.req_to_token_pool.req_to_token[
batch.req_pool_indices, batch.seq_lens - 1
]
seq_lens_next = batch.seq_lens + token_per_req
out_cache_loc = alloc_paged_token_slots_decode(
tree_cache=batch.tree_cache,
seq_lens=seq_lens_next,
seq_lens_cpu=batch.seq_lens_cpu + token_per_req,
last_loc=last_loc,
token_per_req=token_per_req,
)
# Write to req_to_token_pool
if batch.model_config.is_encoder_decoder:
locs = batch.encoder_lens + batch.seq_lens
else:
locs = batch.seq_lens.clone()
batch.req_to_token_pool.write(
(batch.req_pool_indices, locs), out_cache_loc.to(torch.int32)
)
return out_cache_loc
def available_and_evictable_str(tree_cache) -> str:
token_to_kv_pool_allocator = tree_cache.token_to_kv_pool_allocator
if isinstance(token_to_kv_pool_allocator, SWATokenToKVPoolAllocator):
full_available_size = token_to_kv_pool_allocator.full_available_size()
swa_available_size = token_to_kv_pool_allocator.swa_available_size()
full_evictable_size = tree_cache.full_evictable_size()
swa_evictable_size = tree_cache.swa_evictable_size()
return (
f"Available full tokens: {full_available_size + full_evictable_size} ({full_available_size=} + {full_evictable_size=})\n"
f"Available swa tokens: {swa_available_size + swa_evictable_size} ({swa_available_size=} + {swa_evictable_size=})\n"
f"Full LRU list evictable size: {tree_cache.full_lru_list_evictable_size()}\n"
f"SWA LRU list evictable size: {tree_cache.swa_lru_list_evictable_size()}\n"
)
else:
available_size = token_to_kv_pool_allocator.available_size()
evictable_size = tree_cache.evictable_size()
return f"Available tokens: {available_size + evictable_size} ({available_size=} + {evictable_size=})\n"
@@ -10,12 +10,13 @@ from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.sampler import apply_custom_logit_processor
- from sglang.srt.managers.schedule_batch import (
- ScheduleBatch,
- get_last_loc,
- global_server_args_dict,
- )
- from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
+ from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
+ from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
+ from sglang.srt.mem_cache.common import (
+ alloc_paged_token_slots_extend,
+ alloc_token_slots,
+ get_last_loc,
+ )
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode
from sglang.srt.speculative.spec_info import SpecInput, SpecInputType
from sglang.srt.speculative.spec_utils import (
@@ -100,7 +101,10 @@ class EagleVerifyInput(SpecInput):
batch.input_ids = self.draft_token
if page_size == 1:
- batch.out_cache_loc = batch.alloc_token_slots(len(batch.input_ids))
+ batch.out_cache_loc = alloc_token_slots(
+ batch.tree_cache,
+ len(batch.input_ids),
+ )
end_offset = batch.seq_lens + self.draft_token_num
else:
prefix_lens = batch.seq_lens
@@ -112,7 +116,8 @@ class EagleVerifyInput(SpecInput):
batch.req_pool_indices,
prefix_lens,
)
- batch.out_cache_loc = batch.alloc_paged_token_slots_extend(
+ batch.out_cache_loc = alloc_paged_token_slots_extend(
+ batch.tree_cache,
prefix_lens,
prefix_lens_cpu,
end_offset,
...
@@ -14,13 +14,14 @@ from sglang.srt.distributed import (
)
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.sampler import get_token_ids_logprobs, get_top_logprobs
- from sglang.srt.managers.schedule_batch import (
- ScheduleBatch,
- get_last_loc,
- global_server_args_dict,
- )
+ from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
from sglang.srt.managers.scheduler import GenerationBatchResult
from sglang.srt.managers.tp_worker import TpModelWorker
+ from sglang.srt.mem_cache.common import (
+ alloc_paged_token_slots_extend,
+ alloc_token_slots,
+ get_last_loc,
+ )
from sglang.srt.model_executor.forward_batch_info import (
CaptureHiddenMode,
ForwardBatch,
@@ -541,8 +542,10 @@ class EAGLEWorker(TpModelWorker):
# [ topk 0 ] [ topk 1 ]
# [iter=0, iter=1, iter=2] [iter=0, iter=1, iter=2]
if self.page_size == 1:
- out_cache_loc, token_to_kv_pool_state_backup = batch.alloc_token_slots(
- num_seqs * self.speculative_num_steps * self.topk, backup_state=True
+ out_cache_loc, token_to_kv_pool_state_backup = alloc_token_slots(
+ batch.tree_cache,
+ num_seqs * self.speculative_num_steps * self.topk,
+ backup_state=True,
)
else:
if self.topk == 1:
@@ -601,7 +604,8 @@ class EAGLEWorker(TpModelWorker):
extend_num_tokens = torch.sum((seq_lens_cpu - prefix_lens_cpu)).item()
out_cache_loc, token_to_kv_pool_state_backup = (
- batch.alloc_paged_token_slots_extend(
+ alloc_paged_token_slots_extend(
+ batch.tree_cache,
prefix_lens,
prefix_lens_cpu,
seq_lens,
...
@@ -16,10 +16,11 @@ import torch.nn.functional as F
from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.sampler import apply_custom_logit_processor
- from sglang.srt.managers.schedule_batch import (
- ScheduleBatch,
- get_last_loc,
- global_server_args_dict,
- )
+ from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
+ from sglang.srt.mem_cache.common import (
+ alloc_paged_token_slots_extend,
+ alloc_token_slots,
+ get_last_loc,
+ )
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.speculative.spec_info import SpecInput, SpecInputType
@@ -74,7 +75,10 @@ class NgramVerifyInput(SpecInput):
batch.input_ids = self.draft_token
if page_size == 1:
- batch.out_cache_loc = batch.alloc_token_slots(len(batch.input_ids))
+ batch.out_cache_loc = alloc_token_slots(
+ batch.tree_cache,
+ len(batch.input_ids),
+ )
end_offset = batch.seq_lens + self.draft_token_num
else:
# TODO(lsyin): add prefix lens cpu here to support page size > 1
@@ -87,7 +91,8 @@ class NgramVerifyInput(SpecInput):
batch.req_pool_indices,
prefix_lens,
)
- batch.out_cache_loc = batch.alloc_paged_token_slots_extend(
+ batch.out_cache_loc = alloc_paged_token_slots_extend(
+ batch.tree_cache,
prefix_lens,
prefix_lens_cpu,
end_offset,
...
@@ -8,6 +8,7 @@ python3 test_forward_split_prefill.py
"""
import unittest
from types import SimpleNamespace
import numpy as np
import torch
@@ -95,11 +96,18 @@ class TestForwardSplitPrefill(CustomTestCase):
req.logprob_start_len = len(req.origin_input_ids) - 1
reqs.append(req)
# Create dummy tree_cache for tests (no prefix caching, just allocation)
dummy_tree_cache = SimpleNamespace(
page_size=1,
device=self.model_runner.device,
token_to_kv_pool_allocator=self.model_runner.token_to_kv_pool_allocator,
)
batch = ScheduleBatch.init_new(
reqs=reqs,
req_to_token_pool=self.model_runner.req_to_token_pool,
token_to_kv_pool_allocator=self.model_runner.token_to_kv_pool_allocator,
- tree_cache=None,
+ tree_cache=dummy_tree_cache,
model_config=self.model_config,
enable_overlap=False,
spec_algorithm=SpeculativeAlgorithm.NONE,
...