"src/vscode:/vscode.git/clone" did not exist on "2fa1d64841fdd0290d6abf5f0c4129643c089441"
Unverified commit c76040e3, authored by Lianmin Zheng, committed by GitHub

Support page size > 1 (#4356)

parent 2f6bacee
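For orientation, a minimal sketch (not part of the patch) of the page bookkeeping this change introduces across the allocator and the KV buffers; the concrete numbers below are made up:

page_size = 4          # tokens per page
pool_size = 1024       # total token slots in the KV pool

num_pages = pool_size // page_size      # pages managed by the paged allocator
available = num_pages * page_size       # tokens still allocatable
padded_rows = pool_size + page_size     # KV buffers now pad a whole page
                                        # (previously pool_size + 1)
assert (num_pages, available, padded_rows) == (256, 1024, 1028)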
...@@ -36,7 +36,7 @@ fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn ...@@ -36,7 +36,7 @@ fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
_is_cuda = is_cuda() _is_cuda = is_cuda()
if _is_cuda: if _is_cuda:
import deep_gemm import deep_gemm # `pip install "sgl-kernel>=0.0.4.post3"`
from sgl_kernel import sgl_per_token_group_quant_fp8, sgl_per_token_quant_fp8 from sgl_kernel import sgl_per_token_group_quant_fp8, sgl_per_token_quant_fp8
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
......
...@@ -77,7 +77,7 @@ class SchedulePolicy: ...@@ -77,7 +77,7 @@ class SchedulePolicy:
self, self,
policy: str, policy: str,
tree_cache: BasePrefixCache, tree_cache: BasePrefixCache,
enable_hierarchical_cache: bool = False, enable_hierarchical_cache: bool,
): ):
self.policy = self._validate_and_adjust_policy(policy, tree_cache) self.policy = self._validate_and_adjust_policy(policy, tree_cache)
self.tree_cache = tree_cache self.tree_cache = tree_cache
...@@ -85,10 +85,17 @@ class SchedulePolicy: ...@@ -85,10 +85,17 @@ class SchedulePolicy:
# It is used to find the matching prefix for in-batch prefix caching. # It is used to find the matching prefix for in-batch prefix caching.
self.waiting_queue_radix_tree = RadixCache( self.waiting_queue_radix_tree = RadixCache(
req_to_token_pool=None, token_to_kv_pool_allocator=None, disable=False req_to_token_pool=None,
token_to_kv_pool_allocator=None,
page_size=1,
disable=False,
) )
def calc_priority(self, waiting_queue: List[Req]) -> bool: def calc_priority(self, waiting_queue: List[Req]) -> bool:
if self.policy == CacheAgnosticPolicy.FCFS:
# A shortcut for FCFS
return
policy = self._determine_active_policy(waiting_queue) policy = self._determine_active_policy(waiting_queue)
prefix_computed = False prefix_computed = False
...@@ -118,7 +125,7 @@ class SchedulePolicy: ...@@ -118,7 +125,7 @@ class SchedulePolicy:
return prefix_computed return prefix_computed
def _determine_active_policy(self, waiting_queue: List[Req]) -> Policy: def _determine_active_policy(self, waiting_queue: List[Req]) -> Policy:
if len(waiting_queue) > 128 and self.policy == CacheAwarePolicy.LPM: if self.policy == CacheAwarePolicy.LPM and len(waiting_queue) > 128:
# Turn off the expensive prefix matching and sorting when the #queue is large. # Turn off the expensive prefix matching and sorting when the #queue is large.
return CacheAgnosticPolicy.FCFS return CacheAgnosticPolicy.FCFS
return self.policy return self.policy
...@@ -442,7 +449,7 @@ class PrefillAdder: ...@@ -442,7 +449,7 @@ class PrefillAdder:
def add_one_req( def add_one_req(
self, req: Req, has_chunked_req: bool, enable_hierarchical_cache: bool = False self, req: Req, has_chunked_req: bool, enable_hierarchical_cache: bool = False
): ):
if req.sampling_params.ignore_eos and self.tree_cache.disable: if req.sampling_params.ignore_eos and getattr(self.tree_cache, "disable", True):
return self.add_one_req_ignore_eos(req, has_chunked_req) return self.add_one_req_ignore_eos(req, has_chunked_req)
total_tokens = req.extend_input_len + min( total_tokens = req.extend_input_len + min(
......
...@@ -93,7 +93,7 @@ from sglang.srt.mem_cache.chunk_cache import ChunkCache ...@@ -93,7 +93,7 @@ from sglang.srt.mem_cache.chunk_cache import ChunkCache
from sglang.srt.mem_cache.hiradix_cache import HiRadixCache from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
from sglang.srt.mem_cache.radix_cache import RadixCache from sglang.srt.mem_cache.radix_cache import RadixCache
from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
from sglang.srt.model_executor.forward_batch_info import ForwardMode from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
...@@ -103,6 +103,7 @@ from sglang.srt.utils import ( ...@@ -103,6 +103,7 @@ from sglang.srt.utils import (
crash_on_warnings, crash_on_warnings,
get_bool_env_var, get_bool_env_var,
get_zmq_socket, get_zmq_socket,
kill_itself_when_parent_died,
pyspy_dump_schedulers, pyspy_dump_schedulers,
set_gpu_proc_affinity, set_gpu_proc_affinity,
set_random_seed, set_random_seed,
...@@ -159,6 +160,7 @@ class Scheduler(SchedulerOutputProcessorMixin): ...@@ -159,6 +160,7 @@ class Scheduler(SchedulerOutputProcessorMixin):
) )
self.gpu_id = gpu_id self.gpu_id = gpu_id
self.enable_hierarchical_cache = server_args.enable_hierarchical_cache self.enable_hierarchical_cache = server_args.enable_hierarchical_cache
self.page_size = server_args.page_size
# Distributed rank info # Distributed rank info
self.dp_size = server_args.dp_size self.dp_size = server_args.dp_size
...@@ -265,20 +267,23 @@ class Scheduler(SchedulerOutputProcessorMixin): ...@@ -265,20 +267,23 @@ class Scheduler(SchedulerOutputProcessorMixin):
f"context_len={self.model_config.context_len}" f"context_len={self.model_config.context_len}"
) )
# Init memory pool and cache
self.init_memory_pool_and_cache() self.init_memory_pool_and_cache()
# Init running status # Init running status
self.waiting_queue: List[Req] = [] self.waiting_queue: List[Req] = []
# The running decoding batch for continuous batching # The running decoding batch for continuous batching
self.running_batch: Optional[ScheduleBatch] = None self.running_batch: ScheduleBatch = ScheduleBatch(reqs=[], batch_is_full=False)
# The current forward batch # The current forward batch
self.cur_batch: Optional[ScheduleBatch] = None self.cur_batch: Optional[ScheduleBatch] = None
# The current forward batch # The last forward batch
self.last_batch: Optional[ScheduleBatch] = None self.last_batch: Optional[ScheduleBatch] = None
self.forward_ct = 0 self.forward_ct = 0
self.forward_ct_decode = 0 self.forward_ct_decode = 0
self.num_generated_tokens = 0 self.num_generated_tokens = 0
self.num_prefill_tokens = 0
self.last_decode_stats_tic = time.time() self.last_decode_stats_tic = time.time()
self.last_prefill_stats_tic = time.time()
self.return_health_check_ct = 0 self.return_health_check_ct = 0
self.current_stream = torch.get_device_module(self.device).current_stream() self.current_stream = torch.get_device_module(self.device).current_stream()
if self.device == "cpu": if self.device == "cpu":
...@@ -307,7 +312,9 @@ class Scheduler(SchedulerOutputProcessorMixin): ...@@ -307,7 +312,9 @@ class Scheduler(SchedulerOutputProcessorMixin):
# Init schedule policy and new token estimation # Init schedule policy and new token estimation
self.policy = SchedulePolicy( self.policy = SchedulePolicy(
self.schedule_policy, self.tree_cache, self.enable_hierarchical_cache self.schedule_policy,
self.tree_cache,
self.enable_hierarchical_cache,
) )
assert ( assert (
server_args.schedule_conservativeness >= 0 server_args.schedule_conservativeness >= 0
...@@ -327,11 +334,6 @@ class Scheduler(SchedulerOutputProcessorMixin): ...@@ -327,11 +334,6 @@ class Scheduler(SchedulerOutputProcessorMixin):
) / global_config.default_new_token_ratio_decay_steps ) / global_config.default_new_token_ratio_decay_steps
self.new_token_ratio = self.init_new_token_ratio self.new_token_ratio = self.init_new_token_ratio
# Tell whether the current running batch is full so that we can skip
# the check of whether to prefill new requests.
# This is an optimization to reduce the overhead of the prefill check.
self.batch_is_full = False
# Init watchdog thread # Init watchdog thread
self.watchdog_timeout = server_args.watchdog_timeout self.watchdog_timeout = server_args.watchdog_timeout
t = threading.Thread(target=self.watchdog_thread, daemon=True) t = threading.Thread(target=self.watchdog_thread, daemon=True)
...@@ -437,6 +439,7 @@ class Scheduler(SchedulerOutputProcessorMixin): ...@@ -437,6 +439,7 @@ class Scheduler(SchedulerOutputProcessorMixin):
self.tree_cache = RadixCache( self.tree_cache = RadixCache(
req_to_token_pool=self.req_to_token_pool, req_to_token_pool=self.req_to_token_pool,
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator, token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
page_size=self.page_size,
disable=server_args.disable_radix_cache, disable=server_args.disable_radix_cache,
) )
...@@ -458,6 +461,7 @@ class Scheduler(SchedulerOutputProcessorMixin): ...@@ -458,6 +461,7 @@ class Scheduler(SchedulerOutputProcessorMixin):
# The largest context length (prefill + generation) of a single request # The largest context length (prefill + generation) of a single request
self._largest_prefill_decode_len: int = 0 self._largest_prefill_decode_len: int = 0
self.last_gen_throughput: float = 0.0 self.last_gen_throughput: float = 0.0
self.last_input_throughput: float = 0.0
self.step_time_dict = defaultdict(list) # Dict[batch size -> step time] self.step_time_dict = defaultdict(list) # Dict[batch size -> step time]
self.spec_num_total_accepted_tokens = 0 self.spec_num_total_accepted_tokens = 0
self.spec_num_total_forward_ct = 0 self.spec_num_total_forward_ct = 0
...@@ -487,7 +491,7 @@ class Scheduler(SchedulerOutputProcessorMixin): ...@@ -487,7 +491,7 @@ class Scheduler(SchedulerOutputProcessorMixin):
result = self.run_batch(batch) result = self.run_batch(batch)
self.process_batch_result(batch, result) self.process_batch_result(batch, result)
else: else:
# When the server is idle, so self-check and re-init some states # When the server is idle, do self-check and re-init some states
self.check_memory() self.check_memory()
self.new_token_ratio = self.init_new_token_ratio self.new_token_ratio = self.init_new_token_ratio
...@@ -527,7 +531,7 @@ class Scheduler(SchedulerOutputProcessorMixin): ...@@ -527,7 +531,7 @@ class Scheduler(SchedulerOutputProcessorMixin):
) )
self.process_batch_result(tmp_batch, tmp_result) self.process_batch_result(tmp_batch, tmp_result)
elif batch is None: elif batch is None:
# When the server is idle, so self-check and re-init some states # When the server is idle, do self-check and re-init some states
self.check_memory() self.check_memory()
self.new_token_ratio = self.init_new_token_ratio self.new_token_ratio = self.init_new_token_ratio
...@@ -588,7 +592,7 @@ class Scheduler(SchedulerOutputProcessorMixin): ...@@ -588,7 +592,7 @@ class Scheduler(SchedulerOutputProcessorMixin):
for recv_req in recv_reqs: for recv_req in recv_reqs:
# If it is a health check generation request and there are running requests, ignore it. # If it is a health check generation request and there are running requests, ignore it.
if is_health_check_generate_req(recv_req) and ( if is_health_check_generate_req(recv_req) and (
self.chunked_req is not None or self.running_batch is not None self.chunked_req is not None or not self.running_batch.is_empty()
): ):
self.return_health_check_ct += 1 self.return_health_check_ct += 1
continue continue
...@@ -812,6 +816,11 @@ class Scheduler(SchedulerOutputProcessorMixin): ...@@ -812,6 +816,11 @@ class Scheduler(SchedulerOutputProcessorMixin):
can_run_list: List[Req], can_run_list: List[Req],
running_bs: int, running_bs: int,
): ):
gap_latency = time.time() - self.last_prefill_stats_tic
self.last_prefill_stats_tic = time.time()
self.last_input_throughput = self.num_prefill_tokens / gap_latency
self.num_prefill_tokens = 0
num_used = self.max_total_num_tokens - ( num_used = self.max_total_num_tokens - (
self.token_to_kv_pool_allocator.available_size() self.token_to_kv_pool_allocator.available_size()
+ self.tree_cache.evictable_size() + self.tree_cache.evictable_size()
...@@ -847,7 +856,7 @@ class Scheduler(SchedulerOutputProcessorMixin): ...@@ -847,7 +856,7 @@ class Scheduler(SchedulerOutputProcessorMixin):
self.last_decode_stats_tic = time.time() self.last_decode_stats_tic = time.time()
self.last_gen_throughput = self.num_generated_tokens / gap_latency self.last_gen_throughput = self.num_generated_tokens / gap_latency
self.num_generated_tokens = 0 self.num_generated_tokens = 0
num_running_reqs = len(self.running_batch.reqs) if self.running_batch else 0 num_running_reqs = len(self.running_batch.reqs)
num_used = self.max_total_num_tokens - ( num_used = self.max_total_num_tokens - (
self.token_to_kv_pool_allocator.available_size() self.token_to_kv_pool_allocator.available_size()
+ self.tree_cache.evictable_size() + self.tree_cache.evictable_size()
...@@ -911,8 +920,10 @@ class Scheduler(SchedulerOutputProcessorMixin): ...@@ -911,8 +920,10 @@ class Scheduler(SchedulerOutputProcessorMixin):
) )
if memory_leak: if memory_leak:
msg = ( msg = (
"KV cache pool leak detected!" "KV cache pool leak detected! "
f"{available_size=}, {protected_size=}, {self.max_total_num_tokens=}\n" f"{available_size=}, {protected_size=}, {self.max_total_num_tokens=}\n"
f"{self.token_to_kv_pool_allocator.available_size()=}\n"
f"{self.tree_cache.evictable_size()=}\n"
) )
warnings.warn(msg) warnings.warn(msg)
if crash_on_warnings(): if crash_on_warnings():
...@@ -938,7 +949,7 @@ class Scheduler(SchedulerOutputProcessorMixin): ...@@ -938,7 +949,7 @@ class Scheduler(SchedulerOutputProcessorMixin):
self.token_to_kv_pool_allocator.available_size() self.token_to_kv_pool_allocator.available_size()
+ self.tree_cache.evictable_size() + self.tree_cache.evictable_size()
) )
num_running_reqs = len(self.running_batch.reqs) if self.running_batch else 0 num_running_reqs = len(self.running_batch.reqs)
self.stats.num_running_reqs = num_running_reqs self.stats.num_running_reqs = num_running_reqs
self.stats.num_used_tokens = num_used self.stats.num_used_tokens = num_used
self.stats.token_usage = num_used / self.max_total_num_tokens self.stats.token_usage = num_used / self.max_total_num_tokens
...@@ -956,20 +967,20 @@ class Scheduler(SchedulerOutputProcessorMixin): ...@@ -956,20 +967,20 @@ class Scheduler(SchedulerOutputProcessorMixin):
self.tree_cache.cache_unfinished_req(self.chunked_req) self.tree_cache.cache_unfinished_req(self.chunked_req)
# chunked request keeps its rid but will get a new req_pool_idx # chunked request keeps its rid but will get a new req_pool_idx
self.req_to_token_pool.free(self.chunked_req.req_pool_idx) self.req_to_token_pool.free(self.chunked_req.req_pool_idx)
self.batch_is_full = False self.running_batch.batch_is_full = False
# Filter batch # Filter batch
last_bs = self.last_batch.batch_size() last_bs = self.last_batch.batch_size()
self.last_batch.filter_batch() self.last_batch.filter_batch()
if self.last_batch.batch_size() < last_bs: if self.last_batch.batch_size() < last_bs:
self.batch_is_full = False self.running_batch.batch_is_full = False
# Merge the new batch into the running batch # Merge the new batch into the running batch
if not self.last_batch.is_empty(): if not self.last_batch.is_empty():
if self.running_batch is None: if self.running_batch.is_empty():
self.running_batch = self.last_batch self.running_batch = self.last_batch
else: else:
# merge running_batch with prefill batch # Merge running_batch with prefill batch
self.running_batch.merge_batch(self.last_batch) self.running_batch.merge_batch(self.last_batch)
new_batch = self.get_new_batch_prefill() new_batch = self.get_new_batch_prefill()
...@@ -978,11 +989,11 @@ class Scheduler(SchedulerOutputProcessorMixin): ...@@ -978,11 +989,11 @@ class Scheduler(SchedulerOutputProcessorMixin):
ret = new_batch ret = new_batch
else: else:
# Run decode # Run decode
if self.running_batch is None: if not self.running_batch.is_empty():
ret = None
else:
self.running_batch = self.update_running_batch(self.running_batch) self.running_batch = self.update_running_batch(self.running_batch)
ret = self.running_batch ret = self.running_batch if not self.running_batch.is_empty() else None
else:
ret = None
# Handle DP attention # Handle DP attention
if self.server_args.enable_dp_attention: if self.server_args.enable_dp_attention:
...@@ -997,13 +1008,13 @@ class Scheduler(SchedulerOutputProcessorMixin): ...@@ -997,13 +1008,13 @@ class Scheduler(SchedulerOutputProcessorMixin):
# Handle the cases where prefill is not allowed # Handle the cases where prefill is not allowed
if ( if (
self.batch_is_full or len(self.waiting_queue) == 0 self.running_batch.batch_is_full or len(self.waiting_queue) == 0
) and self.chunked_req is None: ) and self.chunked_req is None:
return None return None
running_bs = len(self.running_batch.reqs) if self.running_batch else 0 running_bs = len(self.running_batch.reqs)
if running_bs >= self.max_running_requests: if running_bs >= self.max_running_requests:
self.batch_is_full = True self.running_batch.batch_is_full = True
return None return None
if self.enable_hierarchical_cache: if self.enable_hierarchical_cache:
...@@ -1025,17 +1036,13 @@ class Scheduler(SchedulerOutputProcessorMixin): ...@@ -1025,17 +1036,13 @@ class Scheduler(SchedulerOutputProcessorMixin):
running_bs if self.is_mixed_chunk else 0, running_bs if self.is_mixed_chunk else 0,
) )
is_chunked = self.chunked_req is not None if self.chunked_req is not None:
if is_chunked:
self.chunked_req.init_next_round_input() self.chunked_req.init_next_round_input()
self.chunked_req = adder.add_chunked_req(self.chunked_req) self.chunked_req = adder.add_chunked_req(self.chunked_req)
if self.lora_paths: if self.lora_paths:
lora_set = ( lora_set = set([req.lora_path for req in self.running_batch.reqs])
set([req.lora_path for req in self.running_batch.reqs])
if self.running_batch is not None
else set([])
)
# Get requests from the waiting queue to a new prefill batch # Get requests from the waiting queue to a new prefill batch
for req in self.waiting_queue: for req in self.waiting_queue:
if ( if (
...@@ -1047,11 +1054,11 @@ class Scheduler(SchedulerOutputProcessorMixin): ...@@ -1047,11 +1054,11 @@ class Scheduler(SchedulerOutputProcessorMixin):
) )
> self.max_loras_per_batch > self.max_loras_per_batch
): ):
self.batch_is_full = True self.running_batch.batch_is_full = True
break break
if running_bs + len(adder.can_run_list) >= self.max_running_requests: if running_bs + len(adder.can_run_list) >= self.max_running_requests:
self.batch_is_full = True self.running_batch.batch_is_full = True
break break
req.init_next_round_input( req.init_next_round_input(
...@@ -1066,12 +1073,14 @@ class Scheduler(SchedulerOutputProcessorMixin): ...@@ -1066,12 +1073,14 @@ class Scheduler(SchedulerOutputProcessorMixin):
if res == AddReqResult.NO_TOKEN: if res == AddReqResult.NO_TOKEN:
if self.enable_hierarchical_cache: if self.enable_hierarchical_cache:
# Set batch_is_full after making sure there are requests that can be served # Set batch_is_full after making sure there are requests that can be served
self.batch_is_full = len(adder.can_run_list) > 0 or ( self.running_batch.batch_is_full = len(
adder.can_run_list
) > 0 or (
self.running_batch is not None self.running_batch is not None
and not self.running_batch.is_empty() and not self.running_batch.is_empty()
) )
else: else:
self.batch_is_full = True self.running_batch.batch_is_full = True
break break
# Update waiting queue # Update waiting queue
...@@ -1112,7 +1121,7 @@ class Scheduler(SchedulerOutputProcessorMixin): ...@@ -1112,7 +1121,7 @@ class Scheduler(SchedulerOutputProcessorMixin):
# Mixed-style chunked prefill # Mixed-style chunked prefill
if ( if (
self.is_mixed_chunk self.is_mixed_chunk
and self.running_batch is not None and not self.running_batch.is_empty()
and not (new_batch.return_logprob or self.running_batch.return_logprob) and not (new_batch.return_logprob or self.running_batch.return_logprob)
): ):
# TODO (lianmin): support return_logprob + mixed chunked prefill # TODO (lianmin): support return_logprob + mixed chunked prefill
...@@ -1121,7 +1130,9 @@ class Scheduler(SchedulerOutputProcessorMixin): ...@@ -1121,7 +1130,9 @@ class Scheduler(SchedulerOutputProcessorMixin):
self.running_batch.prepare_for_decode() self.running_batch.prepare_for_decode()
new_batch.mix_with_running(self.running_batch) new_batch.mix_with_running(self.running_batch)
new_batch.decoding_reqs = self.running_batch.reqs new_batch.decoding_reqs = self.running_batch.reqs
self.running_batch = None self.running_batch = ScheduleBatch(
reqs=[], batch_is_full=self.running_batch.batch_is_full
)
else: else:
new_batch.decoding_reqs = None new_batch.decoding_reqs = None
...@@ -1133,8 +1144,8 @@ class Scheduler(SchedulerOutputProcessorMixin): ...@@ -1133,8 +1144,8 @@ class Scheduler(SchedulerOutputProcessorMixin):
batch.filter_batch() batch.filter_batch()
if batch.is_empty(): if batch.is_empty():
self.batch_is_full = False batch.batch_is_full = False
return None return batch
# Check if decode out of memory # Check if decode out of memory
if not batch.check_decode_mem(self.decode_mem_cache_buf_multiplier) or ( if not batch.check_decode_mem(self.decode_mem_cache_buf_multiplier) or (
...@@ -1158,7 +1169,7 @@ class Scheduler(SchedulerOutputProcessorMixin): ...@@ -1158,7 +1169,7 @@ class Scheduler(SchedulerOutputProcessorMixin):
) )
if batch.batch_size() < initial_bs: if batch.batch_size() < initial_bs:
self.batch_is_full = False batch.batch_is_full = False
# Update batch tensors # Update batch tensors
batch.prepare_for_decode() batch.prepare_for_decode()
...@@ -1233,8 +1244,6 @@ class Scheduler(SchedulerOutputProcessorMixin): ...@@ -1233,8 +1244,6 @@ class Scheduler(SchedulerOutputProcessorMixin):
): ):
if batch.forward_mode.is_decode(): if batch.forward_mode.is_decode():
self.process_batch_result_decode(batch, result) self.process_batch_result_decode(batch, result)
if batch.is_empty():
self.running_batch = None
elif batch.forward_mode.is_extend(): elif batch.forward_mode.is_extend():
self.process_batch_result_prefill(batch, result) self.process_batch_result_prefill(batch, result)
elif batch.forward_mode.is_idle(): elif batch.forward_mode.is_idle():
...@@ -1375,9 +1384,7 @@ class Scheduler(SchedulerOutputProcessorMixin): ...@@ -1375,9 +1384,7 @@ class Scheduler(SchedulerOutputProcessorMixin):
def flush_cache(self): def flush_cache(self):
"""Flush the memory pool and cache.""" """Flush the memory pool and cache."""
if len(self.waiting_queue) == 0 and ( if len(self.waiting_queue) == 0 and self.running_batch.is_empty():
self.running_batch is None or len(self.running_batch.reqs) == 0
):
self.cur_batch = None self.cur_batch = None
self.last_batch = None self.last_batch = None
self.tree_cache.reset() self.tree_cache.reset()
...@@ -1403,7 +1410,7 @@ class Scheduler(SchedulerOutputProcessorMixin): ...@@ -1403,7 +1410,7 @@ class Scheduler(SchedulerOutputProcessorMixin):
logging.warning( logging.warning(
f"Cache not flushed because there are pending requests. " f"Cache not flushed because there are pending requests. "
f"#queue-req: {len(self.waiting_queue)}, " f"#queue-req: {len(self.waiting_queue)}, "
f"#running-req: {0 if self.running_batch is None else len(self.running_batch.reqs)}" f"#running-req: {len(self.running_batch.reqs)}"
) )
if_success = False if_success = False
return if_success return if_success
...@@ -1453,24 +1460,24 @@ class Scheduler(SchedulerOutputProcessorMixin): ...@@ -1453,24 +1460,24 @@ class Scheduler(SchedulerOutputProcessorMixin):
def abort_request(self, recv_req: AbortReq): def abort_request(self, recv_req: AbortReq):
# Delete requests in the waiting queue # Delete requests in the waiting queue
to_del = None to_del = []
for i, req in enumerate(self.waiting_queue): for i, req in enumerate(self.waiting_queue):
if req.rid == recv_req.rid: if req.rid.startswith(recv_req.rid):
to_del = i to_del.append(i)
break break
if to_del is not None: # Sort in reverse order to avoid index issues when deleting
del self.waiting_queue[to_del] for i in sorted(to_del, reverse=True):
req = self.waiting_queue.pop(i)
logger.debug(f"Abort queued request. {req.rid=}") logger.debug(f"Abort queued request. {req.rid=}")
return return
# Delete requests in the running batch # Delete requests in the running batch
if self.running_batch: for req in self.running_batch.reqs:
for req in self.running_batch.reqs: if req.rid.startswith(recv_req.rid) and not req.finished():
if req.rid == recv_req.rid and not req.finished(): logger.debug(f"Abort running request. {req.rid=}")
logger.debug(f"Abort running request. {req.rid=}") req.to_abort = True
req.to_abort = True return
break
def _pause_engine(self) -> Tuple[List[Req], int]: def _pause_engine(self) -> Tuple[List[Req], int]:
raise NotImplementedError() raise NotImplementedError()
......
...@@ -204,8 +204,17 @@ class SchedulerOutputProcessorMixin: ...@@ -204,8 +204,17 @@ class SchedulerOutputProcessorMixin:
continue continue
if self.enable_overlap and req.finished(): if self.enable_overlap and req.finished():
# Free the one delayed token # Free the one extra delayed token
self.token_to_kv_pool_allocator.free(batch.out_cache_loc[i : i + 1]) if self.page_size == 1:
self.token_to_kv_pool_allocator.free(batch.out_cache_loc[i : i + 1])
else:
# Only free when the extra token is in a new page
if (
len(req.origin_input_ids) + len(req.output_ids) - 1
) % self.page_size == 0:
self.token_to_kv_pool_allocator.free(
batch.out_cache_loc[i : i + 1]
)
continue continue
if batch.spec_algorithm.is_none(): if batch.spec_algorithm.is_none():
......
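The branch above frees the one delayed KV slot used by the overlap scheduler only when paging makes that slot land at the start of a new page. A tiny stand-alone check of that condition (the helper name is invented for illustration):

def extra_slot_opens_new_page(num_input_tokens, num_output_tokens, page_size):
    # Mirrors the condition above: the delayed token sits at position
    # (num_input_tokens + num_output_tokens - 1); it starts a fresh page
    # exactly when that position is a multiple of page_size.
    return (num_input_tokens + num_output_tokens - 1) % page_size == 0

# With page_size=4, a request with 6 prompt tokens and 3 generated tokens
# has its delayed token at position 8, the start of a new page.
assert extra_slot_opens_new_page(6, 3, 4) is True
assert extra_slot_opens_new_page(6, 2, 4) is False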
...@@ -103,6 +103,9 @@ class TpModelWorkerClient: ...@@ -103,6 +103,9 @@ class TpModelWorkerClient:
self.worker.model_runner.token_to_kv_pool_allocator, self.worker.model_runner.token_to_kv_pool_allocator,
) )
def get_kv_cache(self):
return self.worker.model_runner.token_to_kv_pool
def forward_thread_func(self): def forward_thread_func(self):
try: try:
with torch.get_device_module(self.device).stream(self.forward_stream): with torch.get_device_module(self.device).stream(self.forward_stream):
...@@ -203,7 +206,7 @@ class TpModelWorkerClient: ...@@ -203,7 +206,7 @@ class TpModelWorkerClient:
-(self.future_token_ids_ct + 1), -(self.future_token_ids_ct + 1),
-(self.future_token_ids_ct + 1 + bs), -(self.future_token_ids_ct + 1 + bs),
-1, -1,
dtype=torch.int32, dtype=torch.int64,
device=self.device, device=self.device,
) )
self.future_token_ids_ct = ( self.future_token_ids_ct = (
......
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Callable, List, Tuple from typing import Any, List, Tuple
class BasePrefixCache(ABC): class BasePrefixCache(ABC):
...@@ -26,24 +26,22 @@ class BasePrefixCache(ABC): ...@@ -26,24 +26,22 @@ class BasePrefixCache(ABC):
pass pass
@abstractmethod @abstractmethod
def evict(self, num_tokens: int, evict_callback: Callable): def evict(self, num_tokens: int):
pass pass
@abstractmethod @abstractmethod
def inc_lock_ref(self, node): def inc_lock_ref(self, node: Any):
pass pass
@abstractmethod @abstractmethod
def dec_lock_ref(self, node): def dec_lock_ref(self, node: Any):
pass pass
@abstractmethod
def evictable_size(self): def evictable_size(self):
pass return 0
@abstractmethod
def protected_size(self): def protected_size(self):
raise NotImplementedError() return 0
def total_size(self): def total_size(self):
raise NotImplementedError() raise NotImplementedError()
......
from __future__ import annotations from __future__ import annotations
"""Cache for chunked prefill, used when RadixCache is disabled.""" """Cache for chunked prefill, used when RadixCache is disabled."""
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple
from typing import TYPE_CHECKING, Any, Callable, List, Tuple
import torch import torch
...@@ -24,73 +25,40 @@ class ChunkCache(BasePrefixCache): ...@@ -24,73 +25,40 @@ class ChunkCache(BasePrefixCache):
req_to_token_pool: ReqToTokenPool, req_to_token_pool: ReqToTokenPool,
token_to_kv_pool_allocator: TokenToKVPoolAllocator, token_to_kv_pool_allocator: TokenToKVPoolAllocator,
): ):
self.disable = True
self.req_to_token_pool = req_to_token_pool self.req_to_token_pool = req_to_token_pool
self.token_to_kv_pool_allocator = token_to_kv_pool_allocator self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
self.entries: Dict[str, ChunkCacheEntry] = {}
self.reset()
def reset(self): def reset(self):
self.entries = {} pass
def match_prefix(self, rid: int, key: List[int]) -> Tuple[List[int], int]:
if rid not in self.entries:
return [], None
entry = self.entries[rid]
max_prefix_len = len(key)
return entry.value[:max_prefix_len], entry
def cache_finished_req(self, req: Req, token_ids: Optional[List[int]] = None): def match_prefix(self, **unused_kwargs) -> Tuple[List[int], int]:
if token_ids is None: return [], None
token_id_len = len(req.origin_input_ids) + len(req.output_ids) - 1
else:
token_id_len = len(token_ids)
def cache_finished_req(self, req: Req):
kv_indices = self.req_to_token_pool.req_to_token[ kv_indices = self.req_to_token_pool.req_to_token[
req.req_pool_idx, :token_id_len req.req_pool_idx, : len(req.origin_input_ids) + len(req.output_ids) - 1
] ]
self.req_to_token_pool.free(req.req_pool_idx) self.req_to_token_pool.free(req.req_pool_idx)
self.token_to_kv_pool_allocator.free(kv_indices) self.token_to_kv_pool_allocator.free(kv_indices)
if req.rid in self.entries:
del self.entries[req.rid]
def cache_unfinished_req(self, req: Req): def cache_unfinished_req(self, req: Req):
token_id_len = len(req.fill_ids)
kv_indices = self.req_to_token_pool.req_to_token[ kv_indices = self.req_to_token_pool.req_to_token[
req.req_pool_idx, :token_id_len req.req_pool_idx, : len(req.fill_ids)
] ]
if req.rid not in self.entries: # `req.prefix_indices` will be used in `PrefillAdder::add_chunked_req` later
self.entries[req.rid] = ChunkCacheEntry(req.rid, kv_indices)
entry = self.entries[req.rid]
entry.value = kv_indices
req.prefix_indices = kv_indices req.prefix_indices = kv_indices
req.last_node = entry
def insert(self): def insert(self):
raise NotImplementedError() raise NotImplementedError()
def evict(self, num_tokens: int, evict_callback: Callable): def evict(self, num_tokens: int):
pass pass
def inc_lock_ref(self, node): def inc_lock_ref(self, node: Any):
return 0 return 0
def dec_lock_ref(self, node): def dec_lock_ref(self, node: Any):
return 0
def evictable_size(self):
return 0
def pretty_print(self):
return ""
def protected_size(self):
return 0 return 0
def pretty_print(self): def pretty_print(self):
......
...@@ -7,13 +7,13 @@ from typing import List, Optional ...@@ -7,13 +7,13 @@ from typing import List, Optional
import torch import torch
from sglang.srt.managers.cache_controller import HiCacheController from sglang.srt.managers.cache_controller import HiCacheController
from sglang.srt.managers.schedule_batch import Req
from sglang.srt.mem_cache.memory_pool import ( from sglang.srt.mem_cache.memory_pool import (
MHATokenToKVPoolHost, MHATokenToKVPoolHost,
ReqToTokenPool, ReqToTokenPool,
TokenToKVPoolAllocator, TokenToKVPoolAllocator,
) )
from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode, _key_match from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode
from sglang.srt.mem_cache.radix_cache import _key_match_page_size1 as _key_match
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -122,7 +122,7 @@ class HiRadixCache(RadixCache): ...@@ -122,7 +122,7 @@ class HiRadixCache(RadixCache):
def evictable_size(self): def evictable_size(self):
return self.evictable_size_ return self.evictable_size_
def evict(self, num_tokens: int, evict_callback=None): def evict(self, num_tokens: int):
leaves = self._collect_leaves_device() leaves = self._collect_leaves_device()
heapq.heapify(leaves) heapq.heapify(leaves)
......
...@@ -129,6 +129,7 @@ class TokenToKVPoolAllocator: ...@@ -129,6 +129,7 @@ class TokenToKVPoolAllocator:
self.size = size self.size = size
self.dtype = dtype self.dtype = dtype
self.device = device self.device = device
self.page_size = 1
self.free_slots = None self.free_slots = None
self.is_not_in_free_group = True self.is_not_in_free_group = True
...@@ -149,15 +150,14 @@ class TokenToKVPoolAllocator: ...@@ -149,15 +150,14 @@ class TokenToKVPoolAllocator:
select_index = self.free_slots[:need_size] select_index = self.free_slots[:need_size]
self.free_slots = self.free_slots[need_size:] self.free_slots = self.free_slots[need_size:]
return select_index
return select_index.to(self.device, non_blocking=True)
def free(self, free_index: torch.Tensor): def free(self, free_index: torch.Tensor):
if free_index.numel() == 0: if free_index.numel() == 0:
return return
if self.is_not_in_free_group: if self.is_not_in_free_group:
self.free_slots = torch.concat((self.free_slots, free_index.cpu())) self.free_slots = torch.concat((self.free_slots, free_index))
else: else:
self.free_group.append(free_index) self.free_group.append(free_index)
...@@ -172,7 +172,9 @@ class TokenToKVPoolAllocator: ...@@ -172,7 +172,9 @@ class TokenToKVPoolAllocator:
def clear(self): def clear(self):
# The padded slot 0 is used for writing dummy outputs from padded tokens. # The padded slot 0 is used for writing dummy outputs from padded tokens.
self.free_slots = torch.arange(1, self.size + 1, dtype=torch.int32) self.free_slots = torch.arange(
1, self.size + 1, dtype=torch.int64, device=self.device
)
self.is_in_free_group = False self.is_in_free_group = False
self.free_group = [] self.free_group = []
...@@ -182,6 +184,7 @@ class MHATokenToKVPool(KVCache): ...@@ -182,6 +184,7 @@ class MHATokenToKVPool(KVCache):
def __init__( def __init__(
self, self,
size: int, size: int,
page_size: int,
dtype: torch.dtype, dtype: torch.dtype,
head_num: int, head_num: int,
head_dim: int, head_dim: int,
...@@ -190,6 +193,7 @@ class MHATokenToKVPool(KVCache): ...@@ -190,6 +193,7 @@ class MHATokenToKVPool(KVCache):
enable_memory_saver: bool, enable_memory_saver: bool,
): ):
self.size = size self.size = size
self.page_size = page_size
self.dtype = dtype self.dtype = dtype
self.device = device self.device = device
if dtype in (torch.float8_e5m2, torch.float8_e4m3fn): if dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
...@@ -207,6 +211,8 @@ class MHATokenToKVPool(KVCache): ...@@ -207,6 +211,8 @@ class MHATokenToKVPool(KVCache):
self._create_buffers() self._create_buffers()
self.layer_transfer_counter = None self.layer_transfer_counter = None
self.capture_mode = False
self.alt_stream = torch.cuda.Stream()
k_size, v_size = self.get_kv_size_bytes() k_size, v_size = self.get_kv_size_bytes()
logger.info( logger.info(
...@@ -218,16 +224,16 @@ class MHATokenToKVPool(KVCache): ...@@ -218,16 +224,16 @@ class MHATokenToKVPool(KVCache):
# [size, head_num, head_dim] for each layer # [size, head_num, head_dim] for each layer
# The padded slot 0 is used for writing dummy outputs from padded tokens. # The padded slot 0 is used for writing dummy outputs from padded tokens.
self.k_buffer = [ self.k_buffer = [
torch.empty( torch.zeros(
(self.size + 1, self.head_num, self.head_dim), (self.size + self.page_size, self.head_num, self.head_dim),
dtype=self.store_dtype, dtype=self.store_dtype,
device=self.device, device=self.device,
) )
for _ in range(self.layer_num) for _ in range(self.layer_num)
] ]
self.v_buffer = [ self.v_buffer = [
torch.empty( torch.zeros(
(self.size + 1, self.head_num, self.head_dim), (self.size + self.page_size, self.head_num, self.head_dim),
dtype=self.store_dtype, dtype=self.store_dtype,
device=self.device, device=self.device,
) )
...@@ -315,14 +321,44 @@ class MHATokenToKVPool(KVCache): ...@@ -315,14 +321,44 @@ class MHATokenToKVPool(KVCache):
cache_v.div_(v_scale) cache_v.div_(v_scale)
cache_k = cache_k.to(self.dtype) cache_k = cache_k.to(self.dtype)
cache_v = cache_v.to(self.dtype) cache_v = cache_v.to(self.dtype)
if self.store_dtype != self.dtype: if self.store_dtype != self.dtype:
self.k_buffer[layer_id][loc] = cache_k.view(self.store_dtype) cache_k = cache_k.view(self.store_dtype)
self.v_buffer[layer_id][loc] = cache_v.view(self.store_dtype) cache_v = cache_v.view(self.store_dtype)
if self.capture_mode:
self.alt_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(self.alt_stream):
self.k_buffer[layer_id][loc] = cache_k
self.v_buffer[layer_id][loc] = cache_v
torch.cuda.current_stream().wait_stream(self.alt_stream)
else: else:
self.k_buffer[layer_id][loc] = cache_k self.k_buffer[layer_id][loc] = cache_k
self.v_buffer[layer_id][loc] = cache_v self.v_buffer[layer_id][loc] = cache_v
@torch.compile
def fused_downcast(
cache_k: torch.Tensor,
cache_v: torch.Tensor,
k_scale: torch.Tensor,
v_scale: torch.Tensor,
dtype: torch.dtype,
store_dtype: torch.dtype,
max_fp8: float,
min_fp8: float,
):
cache_k = cache_k / k_scale
cache_k = torch.clamp(cache_k, min_fp8, max_fp8)
cache_v = cache_v / v_scale
cache_v = torch.clamp(cache_v, min_fp8, max_fp8)
cache_k = cache_k.to(dtype)
cache_v = cache_v.to(dtype)
cache_k = cache_k.view(store_dtype)
cache_v = cache_v.view(store_dtype)
return cache_k, cache_v
# This compiled version is slower in the unit test # This compiled version is slower in the unit test
# python3 -m unittest test_bench_serving.TestBenchServing.test_offline_throughput_non_stream_small_batch_size # python3 -m unittest test_bench_serving.TestBenchServing.test_offline_throughput_non_stream_small_batch_size
@torch.compile(dynamic=True, backend=get_compiler_backend()) @torch.compile(dynamic=True, backend=get_compiler_backend())
...@@ -335,6 +371,7 @@ class MLATokenToKVPool(KVCache): ...@@ -335,6 +371,7 @@ class MLATokenToKVPool(KVCache):
def __init__( def __init__(
self, self,
size: int, size: int,
page_size: int,
dtype: torch.dtype, dtype: torch.dtype,
kv_lora_rank: int, kv_lora_rank: int,
qk_rope_head_dim: int, qk_rope_head_dim: int,
...@@ -359,8 +396,8 @@ class MLATokenToKVPool(KVCache): ...@@ -359,8 +396,8 @@ class MLATokenToKVPool(KVCache):
with memory_saver_adapter.region(): with memory_saver_adapter.region():
# The padded slot 0 is used for writing dummy outputs from padded tokens. # The padded slot 0 is used for writing dummy outputs from padded tokens.
self.kv_buffer = [ self.kv_buffer = [
torch.empty( torch.zeros(
(size + 1, 1, kv_lora_rank + qk_rope_head_dim), (size + page_size, 1, kv_lora_rank + qk_rope_head_dim),
dtype=self.store_dtype, dtype=self.store_dtype,
device=device, device=device,
) )
...@@ -400,6 +437,7 @@ class DoubleSparseTokenToKVPool(KVCache): ...@@ -400,6 +437,7 @@ class DoubleSparseTokenToKVPool(KVCache):
def __init__( def __init__(
self, self,
size: int, size: int,
page_size: int,
dtype: torch.dtype, dtype: torch.dtype,
head_num: int, head_num: int,
head_dim: int, head_dim: int,
...@@ -409,6 +447,7 @@ class DoubleSparseTokenToKVPool(KVCache): ...@@ -409,6 +447,7 @@ class DoubleSparseTokenToKVPool(KVCache):
enable_memory_saver: bool, enable_memory_saver: bool,
): ):
self.size = size self.size = size
self.page_size = page_size
self.dtype = dtype self.dtype = dtype
self.device = device self.device = device
if dtype in (torch.float8_e5m2, torch.float8_e4m3fn): if dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
...@@ -423,17 +462,21 @@ class DoubleSparseTokenToKVPool(KVCache): ...@@ -423,17 +462,21 @@ class DoubleSparseTokenToKVPool(KVCache):
with memory_saver_adapter.region(): with memory_saver_adapter.region():
# [size, head_num, head_dim] for each layer # [size, head_num, head_dim] for each layer
self.k_buffer = [ self.k_buffer = [
torch.empty((size + 1, head_num, head_dim), dtype=dtype, device=device) torch.zeros(
(size + page_size, head_num, head_dim), dtype=dtype, device=device
)
for _ in range(layer_num) for _ in range(layer_num)
] ]
self.v_buffer = [ self.v_buffer = [
torch.empty((size + 1, head_num, head_dim), dtype=dtype, device=device) torch.zeros(
(size + page_size, head_num, head_dim), dtype=dtype, device=device
)
for _ in range(layer_num) for _ in range(layer_num)
] ]
# [size, head_num, heavy_channel_num] for each layer # [size, head_num, heavy_channel_num] for each layer
self.label_buffer = [ self.label_buffer = [
torch.empty( torch.zeros(
(size + 1, head_num, heavy_channel_num), dtype=dtype, device=device (size + 1, head_num, heavy_channel_num), dtype=dtype, device=device
) )
for _ in range(layer_num) for _ in range(layer_num)
...@@ -528,7 +571,7 @@ class MHATokenToKVPoolHost: ...@@ -528,7 +571,7 @@ class MHATokenToKVPoolHost:
f"Allocating {requested_bytes / 1e9:.2f} GB host memory for hierarchical KV cache." f"Allocating {requested_bytes / 1e9:.2f} GB host memory for hierarchical KV cache."
) )
self.kv_buffer = torch.empty( self.kv_buffer = torch.zeros(
(2, self.layer_num, self.size, self.head_num, self.head_dim), (2, self.layer_num, self.size, self.head_num, self.head_dim),
dtype=self.dtype, dtype=self.dtype,
device=self.device, device=self.device,
...@@ -548,9 +591,6 @@ class MHATokenToKVPoolHost: ...@@ -548,9 +591,6 @@ class MHATokenToKVPoolHost:
def get_flat_data(self, indices): def get_flat_data(self, indices):
return self.kv_buffer[:, :, indices] return self.kv_buffer[:, :, indices]
def get_flat_data_by_layer(self, indices, layer_id):
return self.kv_buffer[:, layer_id, indices]
def assign_flat_data(self, indices, flat_data): def assign_flat_data(self, indices, flat_data):
self.kv_buffer[:, :, indices] = flat_data self.kv_buffer[:, :, indices] = flat_data
......
"""
Copyright 2025 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
"""
Page-aligned memory pool.
"""
import torch
import triton
import triton.language as tl
from sglang.srt.mem_cache.memory_pool import KVCache
from sglang.srt.utils import get_bool_env_var, next_power_of_2
@triton.jit
def alloc_extend_kernel(
pre_lens_ptr,
seq_lens_ptr,
last_loc_ptr,
free_page_ptr,
out_indices,
ret_values,
bs_upper: tl.constexpr,
page_size: tl.constexpr,
max_num_extend_tokens: tl.constexpr,
):
pid = tl.program_id(0)
load_offset = tl.arange(0, bs_upper)
seq_lens = tl.load(seq_lens_ptr + load_offset, mask=load_offset <= pid)
pre_lens = tl.load(pre_lens_ptr + load_offset, mask=load_offset <= pid)
extend_lens = seq_lens - pre_lens
seq_len = tl.load(seq_lens_ptr + pid)
pre_len = tl.load(pre_lens_ptr + pid)
extend_len = seq_len - pre_len
sum_extend_lens = tl.sum(extend_lens)
output_start_loc = sum_extend_lens - extend_len
num_pages_after = (seq_lens + page_size - 1) // page_size
num_pages_before = (pre_lens + page_size - 1) // page_size
num_new_pages = num_pages_after - num_pages_before
num_page_start_loc_self = (seq_len + page_size - 1) // page_size - (
pre_len + page_size - 1
) // page_size
sum_num_new_pages = tl.sum(num_new_pages)
new_page_start_loc = sum_num_new_pages - num_page_start_loc_self
# Return value
if pid == tl.num_programs(0) - 1:
merged_value = (sum_num_new_pages.to(tl.int64)) << 32 | sum_extend_lens.to(
tl.int64
)
tl.store(ret_values, merged_value)
# Part 1: fill the old partial page
last_loc = tl.load(last_loc_ptr + pid)
num_part1 = (
min(seq_len, (pre_len + page_size - 1) // page_size * page_size) - pre_len
)
offset_one_page = tl.arange(0, page_size)
tl.store(
out_indices + output_start_loc + offset_one_page,
last_loc + 1 + offset_one_page,
mask=offset_one_page < num_part1,
)
if pre_len + num_part1 == seq_len:
return
# Part 2: fill the new full pages
num_part2 = (
seq_len // page_size * page_size
- (pre_len + page_size - 1) // page_size * page_size
)
offset_many_page = tl.arange(0, max_num_extend_tokens)
page_start = tl.load(
free_page_ptr + new_page_start_loc + offset_many_page // page_size,
mask=offset_many_page < num_part2,
)
tl.store(
out_indices + output_start_loc + num_part1 + offset_many_page,
page_start * page_size + offset_many_page % page_size,
mask=offset_many_page < num_part2,
)
if pre_len + num_part1 + num_part2 == seq_len:
return
# Part 3: fill the new partial page
num_part3 = seq_len - seq_len // page_size * page_size
start_loc = tl.load(
free_page_ptr + new_page_start_loc + num_page_start_loc_self - 1
)
tl.store(
out_indices + output_start_loc + num_part1 + num_part2 + offset_one_page,
start_loc * page_size + offset_one_page,
mask=offset_one_page < num_part3,
)
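A rough single-request, pure-Python analogue of the kernel above, covering the same three parts (finish the old partial page, fill whole new pages, start a trailing partial page); the numbers in the final check are invented:

def alloc_extend_reference(pre_len, seq_len, last_loc, free_pages, page_size):
    # Token slots assigned to positions [pre_len, seq_len) of one request.
    out = []
    # Part 1: finish the partially filled page the prefix ends in.
    part1_end = min(seq_len, (pre_len + page_size - 1) // page_size * page_size)
    out.extend(range(last_loc + 1, last_loc + 1 + (part1_end - pre_len)))
    if part1_end == seq_len:
        return out
    # Part 2: whole new pages.
    full_end = seq_len // page_size * page_size
    pages = iter(free_pages)
    for _ in range((full_end - part1_end) // page_size):
        p = next(pages)
        out.extend(range(p * page_size, (p + 1) * page_size))
    # Part 3: a trailing partial page, if the sequence ends mid-page.
    rem = seq_len - full_end
    if rem:
        p = next(pages)
        out.extend(range(p * page_size, p * page_size + rem))
    return out

# page_size=4: a prefix of 6 tokens extended to 13 tokens fills slots 6-7 of
# the prefix's last page, all of free page 9, and the first slot of page 10.
assert alloc_extend_reference(6, 13, 5, [9, 10], 4) == [6, 7, 36, 37, 38, 39, 40]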
@triton.jit
def alloc_decode_kernel(
seq_lens_ptr,
last_loc_ptr,
free_page_ptr,
out_indices,
ret_values,
bs_upper: tl.constexpr,
page_size: tl.constexpr,
):
pid = tl.program_id(0)
load_offset = tl.arange(0, bs_upper)
seq_lens = tl.load(seq_lens_ptr + load_offset, mask=load_offset <= pid)
pre_lens = tl.where(load_offset <= pid, seq_lens - 1, seq_lens)
seq_len = tl.load(seq_lens_ptr + pid)
pre_len = seq_len - 1
num_pages_after = (seq_lens + page_size - 1) // page_size
num_pages_before = (pre_lens + page_size - 1) // page_size
num_new_pages = num_pages_after - num_pages_before
num_page_start_loc_self = (seq_len + page_size - 1) // page_size - (
pre_len + page_size - 1
) // page_size
sum_num_new_pages = tl.sum(num_new_pages)
new_page_start_loc = sum_num_new_pages - num_page_start_loc_self
# Return value
if pid == tl.num_programs(0) - 1:
tl.store(ret_values, sum_num_new_pages)
if num_page_start_loc_self == 0:
last_loc = tl.load(last_loc_ptr + pid)
tl.store(out_indices + pid, last_loc + 1)
else:
page = tl.load(free_page_ptr + new_page_start_loc)
tl.store(out_indices + pid, page * page_size)
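The decode path is simpler: each request gains exactly one token, which either continues the current page or opens a new one. A single-request sketch with made-up numbers:

def alloc_decode_reference(seq_len, last_loc, free_pages, page_size):
    # Slot for the newly decoded token (position seq_len - 1 of the request).
    if (seq_len - 1) % page_size != 0:
        # The current page still has room: continue right after the last slot.
        return last_loc + 1
    # Position seq_len - 1 is page-aligned, so the token opens a new page.
    return free_pages[0] * page_size

# page_size=4: growing from 8 to 9 tokens needs a new page (page 12 here);
# growing from 9 to 10 tokens just uses the next slot of that page.
assert alloc_decode_reference(9, 31, [12], 4) == 48
assert alloc_decode_reference(10, 48, [13], 4) == 49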
class PagedTokenToKVPoolAllocator:
"""
An allocator managing the indices to kv cache data.
This class has the same interface as `TokenToKVPoolAllocator` but the output
of one request is always page-aligned.
TODO: fuse last_loc into the kernel.
"""
def __init__(
self,
size: int,
page_size: int,
dtype: torch.dtype,
device: str,
kvcache: KVCache,
):
self.size = size
self.dtype = dtype
self.device = device
self.page_size = page_size
self.num_pages = size // page_size
self.free_pages = None
self.is_not_in_free_group = True
self.free_group = []
self.clear()
self.debug_mode = get_bool_env_var("SGLANG_DEBUG_MEMORY_POOL")
self._kvcache = kvcache
self.ret_values = torch.empty((), dtype=torch.int64, device=self.device)
def available_size(self):
return len(self.free_pages) * self.page_size
def alloc_extend(
self,
prefix_lens: torch.Tensor,
seq_lens: torch.Tensor,
last_loc: torch.Tensor,
extend_num_tokens: int,
):
if self.debug_mode:
assert torch.all(
(last_loc + 1) % self.page_size == prefix_lens % self.page_size
)
bs = len(prefix_lens)
out_indices = torch.empty(
(extend_num_tokens,), dtype=torch.int64, device=self.device
)
alloc_extend_kernel[(bs,)](
prefix_lens,
seq_lens,
last_loc,
self.free_pages,
out_indices,
self.ret_values,
next_power_of_2(bs),
self.page_size,
next_power_of_2(extend_num_tokens),
)
merged_value = self.ret_values.item()
num_new_pages = merged_value >> 32
if num_new_pages > len(self.free_pages):
return None
self.free_pages = self.free_pages[num_new_pages:]
return out_indices
def alloc_decode(
self,
seq_lens: torch.Tensor,
last_loc: torch.Tensor,
):
if self.debug_mode:
assert torch.all(
(last_loc + 2) % self.page_size == seq_lens % self.page_size
)
bs = len(seq_lens)
out_indices = torch.empty((bs,), dtype=torch.int64, device=self.device)
alloc_decode_kernel[(bs,)](
seq_lens,
last_loc,
self.free_pages,
out_indices,
self.ret_values,
next_power_of_2(bs),
self.page_size,
)
num_new_pages = self.ret_values.item()
if num_new_pages > len(self.free_pages):
return None
self.free_pages = self.free_pages[num_new_pages:]
return out_indices
def free(self, free_index: torch.Tensor):
if free_index.numel() == 0:
return
if self.is_not_in_free_group:
free_page_indices = torch.unique(free_index // self.page_size)
self.free_pages = torch.cat((free_page_indices, self.free_pages))
else:
self.free_group.append(free_index)
def free_group_begin(self):
self.is_not_in_free_group = False
self.free_group = []
def free_group_end(self):
self.is_not_in_free_group = True
if self.free_group:
self.free(torch.concat(self.free_group))
def clear(self):
# The padded slot 0 is used for writing dummy outputs from padded tokens.
self.free_pages = torch.arange(
1, self.num_pages + 1, dtype=torch.int64, device=self.device
)
self.is_in_free_group = False
self.free_group = []
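One design note, illustrated with a small stand-in: `free` receives token indices, so it first collapses them to unique page indices before returning the pages to the free list. The helper below is illustrative only, not part of the class:

def tokens_to_pages(free_index, page_size):
    # Pure-Python analogue of torch.unique(free_index // page_size).
    return sorted({idx // page_size for idx in free_index})

# Token slots 8..15 with page_size=4 belong to pages 2 and 3, which get
# prepended to the free page list.
assert tokens_to_pages(range(8, 16), 4) == [2, 3]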
...@@ -22,7 +22,8 @@ The radix tree data structure for managing the KV cache. ...@@ -22,7 +22,8 @@ The radix tree data structure for managing the KV cache.
import heapq import heapq
import time import time
from collections import defaultdict from collections import defaultdict
from typing import TYPE_CHECKING, Callable, List, Optional, Tuple from functools import partial
from typing import TYPE_CHECKING, List, Optional, Tuple
import torch import torch
...@@ -67,7 +68,7 @@ class TreeNode: ...@@ -67,7 +68,7 @@ class TreeNode:
return self.last_access_time < other.last_access_time return self.last_access_time < other.last_access_time
def _key_match(key0: List, key1: List): def _key_match_page_size1(key0: List, key1: List):
i = 0 i = 0
for k0, k1 in zip(key0, key1): for k0, k1 in zip(key0, key1):
if k0 != k1: if k0 != k1:
...@@ -76,16 +77,42 @@ def _key_match(key0: List, key1: List): ...@@ -76,16 +77,42 @@ def _key_match(key0: List, key1: List):
return i return i
def _key_match_paged(key0: List, key1: List, page_size: int):
min_len = min(len(key0), len(key1))
i = 0
while i < min_len:
if key0[i : i + page_size] != key1[i : i + page_size]:
break
i += page_size
return i
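A quick comparison of the two matching modes, assuming the two helpers defined above are in scope and a page size of 4: per-token matching counts every leading equal token, while paged matching only credits whole matching pages.

key0 = [1, 2, 3, 4, 5, 6, 7, 8]
key1 = [1, 2, 3, 4, 5, 6, 9, 9]
assert _key_match_page_size1(key0, key1) == 6          # stops at first mismatch
assert _key_match_paged(key0, key1, page_size=4) == 4  # last fully matching page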
class RadixCache(BasePrefixCache): class RadixCache(BasePrefixCache):
def __init__( def __init__(
self, self,
req_to_token_pool: ReqToTokenPool, req_to_token_pool: ReqToTokenPool,
token_to_kv_pool_allocator: TokenToKVPoolAllocator, token_to_kv_pool_allocator: TokenToKVPoolAllocator,
page_size: int,
disable: bool = False, disable: bool = False,
): ):
self.req_to_token_pool = req_to_token_pool self.req_to_token_pool = req_to_token_pool
self.token_to_kv_pool_allocator = token_to_kv_pool_allocator self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
self.page_size = page_size
self.disable = disable self.disable = disable
if self.token_to_kv_pool_allocator:
self.device = self.token_to_kv_pool_allocator.device
else:
self.device = torch.device("cpu")
if self.page_size == 1:
self.key_match_fn = _key_match_page_size1
self.get_child_key_fn = lambda key: key[0]
else:
self.key_match_fn = partial(_key_match_paged, page_size=page_size)
self.get_child_key_fn = lambda key: tuple(key[:page_size])
self.reset() self.reset()
##### Public API ##### ##### Public API #####
...@@ -109,14 +136,25 @@ class RadixCache(BasePrefixCache): ...@@ -109,14 +136,25 @@ class RadixCache(BasePrefixCache):
The last node creates a new child if the prefix is shorter than the last node's value.
""" """
if self.disable: if self.disable or len(key) == 0:
return [], self.root_node return (
torch.empty(
(0,),
dtype=torch.int32,
device=self.device,
),
self.root_node,
)
if self.page_size != 1:
page_aligned_len = len(key) // self.page_size * self.page_size
key = key[:page_aligned_len]
value, last_node = self._match_prefix_helper(self.root_node, key) value, last_node = self._match_prefix_helper(self.root_node, key)
if value: if value:
value = torch.concat(value) value = torch.concat(value)
else: else:
value = torch.tensor([], dtype=torch.int32) value = torch.empty((0,), dtype=torch.int32, device=self.device)
return value, last_node return value, last_node
def insert(self, key: List, value=None): def insert(self, key: List, value=None):
...@@ -127,29 +165,33 @@ class RadixCache(BasePrefixCache): ...@@ -127,29 +165,33 @@ class RadixCache(BasePrefixCache):
value = [x for x in key] value = [x for x in key]
return self._insert_helper(self.root_node, key, value) return self._insert_helper(self.root_node, key, value)
def cache_finished_req(self, req: Req, token_ids: Optional[List[int]] = None): def cache_finished_req(self, req: Req):
"""Cache request when it finishes.""" """Cache request when it finishes."""
if self.disable: if self.disable:
if token_ids is None:
token_ids_len = len(req.origin_input_ids) + len(req.output_ids) - 1
else:
token_ids_len = len(token_ids)
kv_indices = self.req_to_token_pool.req_to_token[ kv_indices = self.req_to_token_pool.req_to_token[
req.req_pool_idx, :token_ids_len req.req_pool_idx, : len(req.origin_input_ids) + len(req.output_ids) - 1
] ]
self.token_to_kv_pool_allocator.free(kv_indices) self.token_to_kv_pool_allocator.free(kv_indices)
self.req_to_token_pool.free(req.req_pool_idx) self.req_to_token_pool.free(req.req_pool_idx)
return return
if token_ids is None: token_ids = (req.origin_input_ids + req.output_ids)[:-1]
token_ids = (req.origin_input_ids + req.output_ids)[:-1]
kv_indices = self.req_to_token_pool.req_to_token[ kv_indices = self.req_to_token_pool.req_to_token[
req.req_pool_idx, : len(token_ids) req.req_pool_idx, : len(token_ids)
] ]
if self.page_size != 1:
page_aligned_len = len(kv_indices) // self.page_size * self.page_size
page_aligned_kv_indices = kv_indices[:page_aligned_len].clone()
self.token_to_kv_pool_allocator.free(kv_indices[page_aligned_len:])
else:
page_aligned_len = len(kv_indices)
page_aligned_kv_indices = kv_indices.clone()
# Radix Cache takes one ref in memory pool # Radix Cache takes one ref in memory pool
new_prefix_len = self.insert(token_ids, kv_indices.clone()) new_prefix_len = self.insert(
token_ids[:page_aligned_len], page_aligned_kv_indices
)
self.token_to_kv_pool_allocator.free( self.token_to_kv_pool_allocator.free(
kv_indices[len(req.prefix_indices) : new_prefix_len] kv_indices[len(req.prefix_indices) : new_prefix_len]
) )
...@@ -158,27 +200,32 @@ class RadixCache(BasePrefixCache): ...@@ -158,27 +200,32 @@ class RadixCache(BasePrefixCache):
self.req_to_token_pool.free(req.req_pool_idx) self.req_to_token_pool.free(req.req_pool_idx)
self.dec_lock_ref(req.last_node) self.dec_lock_ref(req.last_node)
def cache_unfinished_req(self, req: Req, token_ids: Optional[List[int]] = None): def cache_unfinished_req(self, req: Req):
"""Cache request when it is unfinished.""" """Cache request when it is unfinished."""
if self.disable: if self.disable:
return return
if token_ids is None: token_ids = req.fill_ids
token_ids = req.fill_ids
kv_indices = self.req_to_token_pool.req_to_token[ kv_indices = self.req_to_token_pool.req_to_token[
req.req_pool_idx, : len(token_ids) req.req_pool_idx, : len(token_ids)
] ]
if self.page_size != 1:
page_aligned_len = len(kv_indices) // self.page_size * self.page_size
page_aligned_kv_indices = kv_indices[:page_aligned_len].clone()
else:
page_aligned_len = len(kv_indices)
page_aligned_kv_indices = kv_indices.clone()
page_aligned_token_ids = token_ids[:page_aligned_len]
# Radix Cache takes one ref in memory pool # Radix Cache takes one ref in memory pool
new_prefix_len = self.insert(token_ids, kv_indices.clone()) new_prefix_len = self.insert(page_aligned_token_ids, page_aligned_kv_indices)
self.token_to_kv_pool_allocator.free( self.token_to_kv_pool_allocator.free(
kv_indices[len(req.prefix_indices) : new_prefix_len] kv_indices[len(req.prefix_indices) : new_prefix_len]
) )
# The prefix indices could be updated, reuse it # The prefix indices could be updated, reuse it
new_indices, new_last_node = self.match_prefix(token_ids) new_indices, new_last_node = self.match_prefix(page_aligned_token_ids)
assert len(new_indices) == len(token_ids)
self.req_to_token_pool.write( self.req_to_token_pool.write(
(req.req_pool_idx, slice(len(req.prefix_indices), len(new_indices))), (req.req_pool_idx, slice(len(req.prefix_indices), len(new_indices))),
new_indices[len(req.prefix_indices) :], new_indices[len(req.prefix_indices) :],
...@@ -186,7 +233,14 @@ class RadixCache(BasePrefixCache): ...@@ -186,7 +233,14 @@ class RadixCache(BasePrefixCache):
self.dec_lock_ref(req.last_node) self.dec_lock_ref(req.last_node)
self.inc_lock_ref(new_last_node) self.inc_lock_ref(new_last_node)
req.prefix_indices = new_indices
# `req.prefix_indices` will be used in `PrefillAdder::add_chunked_req` later
if self.page_size != 1:
req.prefix_indices = torch.cat(
[new_indices, kv_indices[len(new_indices) :]]
)
else:
req.prefix_indices = new_indices
req.last_node = new_last_node req.last_node = new_last_node
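For `page_size > 1`, only the page-aligned prefix of a request ever enters the radix tree; the trailing partial page is freed for finished requests and carried along in `req.prefix_indices` for unfinished ones. A toy illustration of the alignment arithmetic (numbers invented):

page_size = 4
token_ids = list(range(10))                                  # 10 tokens
page_aligned_len = len(token_ids) // page_size * page_size   # -> 8
cached_part = token_ids[:page_aligned_len]   # inserted into the radix tree
tail = token_ids[page_aligned_len:]          # freed (finished request) or kept
                                             # at the end of req.prefix_indices
                                             # (unfinished request)
assert (page_aligned_len, len(tail)) == (8, 2)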
def pretty_print(self): def pretty_print(self):
...@@ -196,7 +250,7 @@ class RadixCache(BasePrefixCache): ...@@ -196,7 +250,7 @@ class RadixCache(BasePrefixCache):
def total_size(self): def total_size(self):
return self._total_size_helper() return self._total_size_helper()
def evict(self, num_tokens: int, evict_callback: Callable): def evict(self, num_tokens: int):
if self.disable: if self.disable:
return return
...@@ -212,7 +266,7 @@ class RadixCache(BasePrefixCache): ...@@ -212,7 +266,7 @@ class RadixCache(BasePrefixCache):
if x.lock_ref > 0: if x.lock_ref > 0:
continue continue
evict_callback(x.value) self.token_to_kv_pool_allocator.free(x.value)
num_evicted += len(x.value) num_evicted += len(x.value)
self._delete_leaf(x) self._delete_leaf(x)
...@@ -254,15 +308,29 @@ class RadixCache(BasePrefixCache): ...@@ -254,15 +308,29 @@ class RadixCache(BasePrefixCache):
# protected size refers to the size of the cache that is locked # protected size refers to the size of the cache that is locked
return self.protected_size_ return self.protected_size_
def all_values_flatten(self):
values = []
def _dfs_helper(node: TreeNode):
for _, child in node.children.items():
values.append(child.value)
_dfs_helper(child)
_dfs_helper(self.root_node)
return torch.concat(values)
##### Internal Helper Functions ##### ##### Internal Helper Functions #####
def _match_prefix_helper(self, node: TreeNode, key: List): def _match_prefix_helper(self, node: TreeNode, key: List):
node.last_access_time = time.time() node.last_access_time = time.time()
child_key = self.get_child_key_fn(key)
value = [] value = []
while len(key) > 0 and key[0] in node.children.keys(): while len(key) > 0 and child_key in node.children.keys():
child = node.children[key[0]] child = node.children[child_key]
child.last_access_time = time.time() child.last_access_time = time.time()
prefix_len = _key_match(child.key, key) prefix_len = self.key_match_fn(child.key, key)
if prefix_len < len(child.key): if prefix_len < len(child.key):
new_node = self._split_node(child.key, child, prefix_len) new_node = self._split_node(child.key, child, prefix_len)
value.append(new_node.value) value.append(new_node.value)
...@@ -272,12 +340,16 @@ class RadixCache(BasePrefixCache): ...@@ -272,12 +340,16 @@ class RadixCache(BasePrefixCache):
value.append(child.value) value.append(child.value)
node = child node = child
key = key[prefix_len:] key = key[prefix_len:]
if len(key):
child_key = self.get_child_key_fn(key)
return value, node return value, node
def _split_node(self, key, child: TreeNode, split_len: int): def _split_node(self, key, child: TreeNode, split_len: int):
# new_node -> child # new_node -> child
new_node = TreeNode() new_node = TreeNode()
new_node.children = {key[split_len]: child} new_node.children = {self.get_child_key_fn(key[split_len:]): child}
new_node.parent = child.parent new_node.parent = child.parent
new_node.lock_ref = child.lock_ref new_node.lock_ref = child.lock_ref
new_node.key = child.key[:split_len] new_node.key = child.key[:split_len]
...@@ -285,7 +357,7 @@ class RadixCache(BasePrefixCache): ...@@ -285,7 +357,7 @@ class RadixCache(BasePrefixCache):
child.parent = new_node child.parent = new_node
child.key = child.key[split_len:] child.key = child.key[split_len:]
child.value = child.value[split_len:] child.value = child.value[split_len:]
new_node.parent.children[key[0]] = new_node new_node.parent.children[self.get_child_key_fn(key)] = new_node
return new_node return new_node
def _insert_helper(self, node: TreeNode, key: List, value): def _insert_helper(self, node: TreeNode, key: List, value):
...@@ -293,11 +365,13 @@ class RadixCache(BasePrefixCache): ...@@ -293,11 +365,13 @@ class RadixCache(BasePrefixCache):
if len(key) == 0: if len(key) == 0:
return 0 return 0
child_key = self.get_child_key_fn(key)
total_prefix_length = 0 total_prefix_length = 0
while len(key) > 0 and key[0] in node.children.keys(): while len(key) > 0 and child_key in node.children.keys():
node = node.children[key[0]] node = node.children[child_key]
node.last_access_time = time.time() node.last_access_time = time.time()
prefix_len = _key_match(node.key, key) prefix_len = self.key_match_fn(node.key, key)
total_prefix_length += prefix_len total_prefix_length += prefix_len
key = key[prefix_len:] key = key[prefix_len:]
value = value[prefix_len:] value = value[prefix_len:]
...@@ -306,12 +380,15 @@ class RadixCache(BasePrefixCache): ...@@ -306,12 +380,15 @@ class RadixCache(BasePrefixCache):
new_node = self._split_node(node.key, node, prefix_len) new_node = self._split_node(node.key, node, prefix_len)
node = new_node node = new_node
if len(key):
child_key = self.get_child_key_fn(key)
if len(key): if len(key):
new_node = TreeNode() new_node = TreeNode()
new_node.parent = node new_node.parent = node
new_node.key = key new_node.key = key
new_node.value = value new_node.value = value
node.children[key[0]] = new_node node.children[child_key] = new_node
self.evictable_size_ += len(value) self.evictable_size_ += len(value)
return total_prefix_length return total_prefix_length
...@@ -326,9 +403,13 @@ class RadixCache(BasePrefixCache): ...@@ -326,9 +403,13 @@ class RadixCache(BasePrefixCache):
current_node.key[:10], current_node.key[:10],
f"r={current_node.lock_ref}", f"r={current_node.lock_ref}",
) )
for _, child in current_node.children.items(): for key, child in current_node.children.items():
stack.append((child, current_indent + 2)) stack.append((child, current_indent + 2))
assert key == self.get_child_key_fn(
child.key
), f"{key=}, {self.get_child_key_fn(child.key)=}"
def _delete_leaf(self, node): def _delete_leaf(self, node):
for k, v in node.parent.children.items(): for k, v in node.parent.children.items():
if v == node: if v == node:
...@@ -363,7 +444,7 @@ class RadixCache(BasePrefixCache): ...@@ -363,7 +444,7 @@ class RadixCache(BasePrefixCache):
if __name__ == "__main__": if __name__ == "__main__":
tree = RadixCache(None, None, False) tree = RadixCache(None, None, page_size=1, disable=False)
tree.insert("Hello") tree.insert("Hello")
tree.insert("Hello") tree.insert("Hello")
......
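Reviewer note on the radix_cache.py hunks above: the tree now indexes children through self.get_child_key_fn and matches prefixes through self.key_match_fn instead of the hard-coded key[0] / _key_match, and only page-aligned spans are inserted into the tree (the unaligned tail stays on the request via the torch.cat in cache_unfinished_req). The commit excerpt only shows the call sites, so the following is a minimal sketch, assuming helpers along these lines; make_page_helpers and page_aligned_len are illustrative names, not code from this PR.

# Minimal sketch (not the exact sglang implementation) of the two
# page-size-dependent helpers the RadixCache diff above relies on:
#   - get_child_key_fn: which key indexes a child edge
#   - key_match_fn:     how many leading tokens of two keys match
from typing import Callable, Sequence, Tuple


def make_page_helpers(page_size: int) -> Tuple[Callable, Callable]:
    if page_size == 1:
        def get_child_key(key: Sequence[int]):
            # One token per edge: the child key is just the first token id.
            return key[0]

        def key_match(k0: Sequence[int], k1: Sequence[int]) -> int:
            i = 0
            for a, b in zip(k0, k1):
                if a != b:
                    break
                i += 1
            return i
    else:
        def get_child_key(key: Sequence[int]):
            # One page per edge: the child key is the first full page of tokens.
            return tuple(key[:page_size])

        def key_match(k0: Sequence[int], k1: Sequence[int]) -> int:
            # Only whole pages count as a match; a partially equal page is dropped.
            min_len = min(len(k0), len(k1))
            i = 0
            while i + page_size <= min_len and tuple(k0[i:i + page_size]) == tuple(k1[i:i + page_size]):
                i += page_size
            return i
    return get_child_key, key_match


def page_aligned_len(num_tokens: int, page_size: int) -> int:
    # Mirrors `len(kv_indices) // self.page_size * self.page_size` in the diff:
    # only whole pages go into the tree; the unaligned tail stays with the request.
    return num_tokens // page_size * page_size


if __name__ == "__main__":
    get_key, match = make_page_helpers(page_size=4)
    assert get_key([1, 2, 3, 4, 5]) == (1, 2, 3, 4)
    assert match([1, 2, 3, 4, 9, 9], [1, 2, 3, 4, 9, 8]) == 4  # only one full page matches
    assert page_aligned_len(11, 4) == 8

With page_size > 1, every tree edge then holds a whole number of pages, which keeps cached KV spans compatible with page-granular allocation later in this diff.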
...@@ -264,11 +264,15 @@ class CudaGraphRunner: ...@@ -264,11 +264,15 @@ class CudaGraphRunner:
def model_capture_mode(self): def model_capture_mode(self):
if hasattr(self.model_runner.model, "capture_mode"): if hasattr(self.model_runner.model, "capture_mode"):
self.model_runner.model.capture_mode = True self.model_runner.model.capture_mode = True
if hasattr(self.model_runner.token_to_kv_pool, "capture_mode"):
self.model_runner.token_to_kv_pool.capture_mode = True
yield yield
if hasattr(self.model_runner.model, "capture_mode"): if hasattr(self.model_runner.model, "capture_mode"):
self.model_runner.model.capture_mode = False self.model_runner.model.capture_mode = False
if hasattr(self.model_runner.token_to_kv_pool, "capture_mode"):
self.model_runner.token_to_kv_pool.capture_mode = False
def can_run(self, forward_batch: ForwardBatch): def can_run(self, forward_batch: ForwardBatch):
if self.enable_dp_attention: if self.enable_dp_attention:
......
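The cuda_graph_runner.py hunk mirrors the existing capture_mode flag on the model onto token_to_kv_pool, so the KV pool can also adjust its behavior while CUDA graphs are being captured. A condensed sketch of the pattern follows; the real model_capture_mode is a method of CudaGraphRunner and does not use try/finally, so the standalone context manager below is a simplification, not the PR's code.

# Simplified sketch of the capture-mode toggle shown above: flip `capture_mode`
# on both the model and the KV pool around graph capture, when they expose it.
from contextlib import contextmanager


@contextmanager
def model_capture_mode(model_runner):
    targets = [model_runner.model, model_runner.token_to_kv_pool]
    for t in targets:
        if hasattr(t, "capture_mode"):
            t.capture_mode = True
    try:
        yield
    finally:
        for t in targets:
            if hasattr(t, "capture_mode"):
                t.capture_mode = False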
...@@ -38,12 +38,12 @@ import triton ...@@ -38,12 +38,12 @@ import triton
import triton.language as tl import triton.language as tl
from sglang.srt.layers.rotary_embedding import MRotaryEmbedding from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
from sglang.srt.utils import get_compiler_backend from sglang.srt.utils import get_compiler_backend, next_power_of_2
if TYPE_CHECKING: if TYPE_CHECKING:
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.managers.schedule_batch import ImageInputs, ModelWorkerBatch from sglang.srt.managers.schedule_batch import ImageInputs, ModelWorkerBatch
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool from sglang.srt.mem_cache.memory_pool import KVCache, ReqToTokenPool
from sglang.srt.model_executor.model_runner import ModelRunner from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
...@@ -51,9 +51,8 @@ if TYPE_CHECKING: ...@@ -51,9 +51,8 @@ if TYPE_CHECKING:
class ForwardMode(IntEnum): class ForwardMode(IntEnum):
# Prefill a new sequence. This is deprecated now. "EXTEND" covers this case.
PREFILL = auto()
# Extend a sequence. The KV cache of the beginning part of the sequence is already computed (e.g., system prompt). # Extend a sequence. The KV cache of the beginning part of the sequence is already computed (e.g., system prompt).
# It is also called "prefill" in common terminology.
EXTEND = auto() EXTEND = auto()
# Decode one token. # Decode one token.
DECODE = auto() DECODE = auto()
...@@ -153,6 +152,12 @@ class ForwardBatch: ...@@ -153,6 +152,12 @@ class ForwardBatch:
top_logprobs_nums: Optional[List[int]] = None top_logprobs_nums: Optional[List[int]] = None
token_ids_logprobs: Optional[List[List[int]]] = None token_ids_logprobs: Optional[List[List[int]]] = None
# For logits and logprobs post processing
temp_scaled_logprobs: bool = False
temperature: torch.Tensor = None
top_p_normalized_logprobs: bool = False
top_p: torch.Tensor = None
# Position information # Position information
positions: torch.Tensor = None positions: torch.Tensor = None
...@@ -189,7 +194,7 @@ class ForwardBatch: ...@@ -189,7 +194,7 @@ class ForwardBatch:
# Attention backend # Attention backend
req_to_token_pool: ReqToTokenPool = None req_to_token_pool: ReqToTokenPool = None
token_to_kv_pool: BaseTokenToKVPool = None token_to_kv_pool: KVCache = None
attn_backend: AttentionBackend = None attn_backend: AttentionBackend = None
# For DP attention # For DP attention
...@@ -229,7 +234,6 @@ class ForwardBatch: ...@@ -229,7 +234,6 @@ class ForwardBatch:
extend_input_logprob_token_ids_gpu = ( extend_input_logprob_token_ids_gpu = (
batch.extend_input_logprob_token_ids.to(device, non_blocking=True) batch.extend_input_logprob_token_ids.to(device, non_blocking=True)
) )
ret = cls( ret = cls(
forward_mode=batch.forward_mode, forward_mode=batch.forward_mode,
batch_size=len(batch.seq_lens), batch_size=len(batch.seq_lens),
...@@ -417,8 +421,8 @@ def compute_position_kernel( ...@@ -417,8 +421,8 @@ def compute_position_kernel(
prefix_len = tl.load(extend_prefix_lens + pid) if has_prefix else 0 prefix_len = tl.load(extend_prefix_lens + pid) if has_prefix else 0
seq_len = tl.load(extend_seq_lens + pid) seq_len = tl.load(extend_seq_lens + pid)
# TODO: optimize this? # NOTE: This can be slow for large bs
cumsum_start = 0 cumsum_start = tl.cast(0, tl.int64)
for i in range(pid): for i in range(pid):
cumsum_start += tl.load(extend_seq_lens + i) cumsum_start += tl.load(extend_seq_lens + i)
......
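In forward_batch_info.py, compute_position_kernel now accumulates cumsum_start in int64 (a large batch's total extend length can overflow int32) and the old TODO is downgraded to a note about the per-request cumulative-sum loop. A plain-PyTorch sketch of what the kernel computes for an extend batch, assuming each request's positions are prefix_len + arange(seq_len) written at its cumulative offset; the Triton kernel itself is not reproduced here.

# Hedged reference sketch (plain PyTorch, not the Triton kernel) of the
# position computation. The int64 accumulator mirrors the
# `tl.cast(0, tl.int64)` change in the diff above.
import torch


def compute_positions_torch(
    extend_prefix_lens: torch.Tensor,  # (bs,) ints
    extend_seq_lens: torch.Tensor,     # (bs,) ints
) -> torch.Tensor:
    total = int(extend_seq_lens.sum().item())
    positions = torch.empty(total, dtype=torch.int64)
    cumsum_start = 0  # unbounded Python int; the kernel needs an int64 register
    for prefix_len, seq_len in zip(extend_prefix_lens.tolist(), extend_seq_lens.tolist()):
        positions[cumsum_start : cumsum_start + seq_len] = torch.arange(
            prefix_len, prefix_len + seq_len, dtype=torch.int64
        )
        cumsum_start += seq_len
    return positions


# e.g. prefix lens [4, 0], seq lens [2, 3] -> positions [4, 5, 0, 1, 2]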
...@@ -53,6 +53,7 @@ from sglang.srt.mem_cache.memory_pool import ( ...@@ -53,6 +53,7 @@ from sglang.srt.mem_cache.memory_pool import (
ReqToTokenPool, ReqToTokenPool,
TokenToKVPoolAllocator, TokenToKVPoolAllocator,
) )
from sglang.srt.mem_cache.paged_allocator import PagedTokenToKVPoolAllocator
from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader import get_model from sglang.srt.model_loader import get_model
...@@ -430,7 +431,7 @@ class ModelRunner: ...@@ -430,7 +431,7 @@ class ModelRunner:
self.model_config.model_path = model_path self.model_config.model_path = model_path
load_config = LoadConfig(load_format=load_format) load_config = LoadConfig(load_format=load_format)
# Only support the DefaultModelLoader for now # Only support DefaultModelLoader for now
loader = get_model_loader(load_config) loader = get_model_loader(load_config)
if not isinstance(loader, DefaultModelLoader): if not isinstance(loader, DefaultModelLoader):
message = f"Failed to get model loader: {loader}." message = f"Failed to get model loader: {loader}."
...@@ -732,6 +733,7 @@ class ModelRunner: ...@@ -732,6 +733,7 @@ class ModelRunner:
): ):
self.token_to_kv_pool = MLATokenToKVPool( self.token_to_kv_pool = MLATokenToKVPool(
self.max_total_num_tokens, self.max_total_num_tokens,
page_size=self.page_size,
dtype=self.kv_cache_dtype, dtype=self.kv_cache_dtype,
kv_lora_rank=self.model_config.kv_lora_rank, kv_lora_rank=self.model_config.kv_lora_rank,
qk_rope_head_dim=self.model_config.qk_rope_head_dim, qk_rope_head_dim=self.model_config.qk_rope_head_dim,
...@@ -742,6 +744,7 @@ class ModelRunner: ...@@ -742,6 +744,7 @@ class ModelRunner:
elif self.server_args.enable_double_sparsity: elif self.server_args.enable_double_sparsity:
self.token_to_kv_pool = DoubleSparseTokenToKVPool( self.token_to_kv_pool = DoubleSparseTokenToKVPool(
self.max_total_num_tokens, self.max_total_num_tokens,
page_size=self.page_size,
dtype=self.kv_cache_dtype, dtype=self.kv_cache_dtype,
head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()), head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()),
head_dim=self.model_config.head_dim, head_dim=self.model_config.head_dim,
...@@ -753,6 +756,7 @@ class ModelRunner: ...@@ -753,6 +756,7 @@ class ModelRunner:
else: else:
self.token_to_kv_pool = MHATokenToKVPool( self.token_to_kv_pool = MHATokenToKVPool(
self.max_total_num_tokens, self.max_total_num_tokens,
page_size=self.page_size,
dtype=self.kv_cache_dtype, dtype=self.kv_cache_dtype,
head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()), head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()),
head_dim=self.model_config.head_dim, head_dim=self.model_config.head_dim,
...@@ -762,12 +766,21 @@ class ModelRunner: ...@@ -762,12 +766,21 @@ class ModelRunner:
) )
if self.token_to_kv_pool_allocator is None: if self.token_to_kv_pool_allocator is None:
self.token_to_kv_pool_allocator = TokenToKVPoolAllocator( if self.page_size == 1:
self.max_total_num_tokens, self.token_to_kv_pool_allocator = TokenToKVPoolAllocator(
dtype=self.kv_cache_dtype, self.max_total_num_tokens,
device=self.device, dtype=self.kv_cache_dtype,
kvcache=self.token_to_kv_pool, device=self.device,
) kvcache=self.token_to_kv_pool,
)
else:
self.token_to_kv_pool_allocator = PagedTokenToKVPoolAllocator(
self.max_total_num_tokens,
page_size=self.page_size,
dtype=self.kv_cache_dtype,
device=self.device,
kvcache=self.token_to_kv_pool,
)
else: else:
assert self.is_draft_worker assert self.is_draft_worker
......
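The model_runner.py hunks thread page_size into every KV-pool constructor and pick the allocator by page size: the original TokenToKVPoolAllocator when page_size == 1, PagedTokenToKVPoolAllocator otherwise. The paged allocator's implementation is not part of this excerpt, so the toy below only illustrates the idea of handing out and reclaiming whole pages of token slots; the class name ToyPagedAllocator and its behavior are assumptions, not the sglang implementation.

# Conceptual sketch of page-granular KV slot allocation, to show why a
# separate allocator is needed once page_size > 1. NOT the real
# PagedTokenToKVPoolAllocator from this PR.
import torch


class ToyPagedAllocator:
    """Hands out KV slots in whole pages; page p owns slots [p*page_size, (p+1)*page_size)."""

    def __init__(self, max_total_num_tokens: int, page_size: int):
        assert max_total_num_tokens % page_size == 0
        self.page_size = page_size
        self.free_pages = list(range(max_total_num_tokens // page_size))

    def alloc_pages(self, need_tokens: int) -> torch.Tensor:
        # A real allocator must also extend a request's partially filled last
        # page; this toy always grabs fresh whole pages.
        need_pages = -(-need_tokens // self.page_size)  # ceil division
        if need_pages > len(self.free_pages):
            raise RuntimeError("KV cache out of memory, cannot allocate pages")
        pages = [self.free_pages.pop() for _ in range(need_pages)]
        return torch.cat(
            [torch.arange(p * self.page_size, (p + 1) * self.page_size) for p in pages]
        )

    def free(self, slots: torch.Tensor):
        # Callers keep freed spans page-aligned, so page ids can be recovered directly.
        for p in torch.unique(slots // self.page_size).tolist():
            self.free_pages.append(int(p))


if __name__ == "__main__":
    alloc = ToyPagedAllocator(max_total_num_tokens=64, page_size=4)
    slots = alloc.alloc_pages(10)  # 3 pages -> 12 slots reserved for 10 tokens
    alloc.free(slots)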
...@@ -220,6 +220,8 @@ class ServerArgs: ...@@ -220,6 +220,8 @@ class ServerArgs:
else: else:
self.chunked_prefill_size = 8192 self.chunked_prefill_size = 8192
assert self.chunked_prefill_size % self.page_size == 0
# Set cuda graph max batch size # Set cuda graph max batch size
if self.cuda_graph_max_bs is None: if self.cuda_graph_max_bs is None:
# Based on detailed statistics, when serving TP1/TP2 models on lower-end GPUs with HBM<25G, you can either disable cuda graph or set `cuda_graph_max_bs` to a very small value to reduce the memory overhead of creating cuda graphs, with almost no impact on performance. However, when serving models with TP4 or TP8, we need to enable cuda graph to maintain high performance. In this case, we can set `cuda_graph_max_bs` to 80 (half of the default value 160) to reduce the memory overhead of creating cuda graphs. Looking at the logs from TP4 serving of qwen2-72b, a value of 80 is sufficient and can reduce the memory overhead of creating cuda graphs on lower-end GPUs compared to the original 160, avoiding OOM issues. # Based on detailed statistics, when serving TP1/TP2 models on lower-end GPUs with HBM<25G, you can either disable cuda graph or set `cuda_graph_max_bs` to a very small value to reduce the memory overhead of creating cuda graphs, with almost no impact on performance. However, when serving models with TP4 or TP8, we need to enable cuda graph to maintain high performance. In this case, we can set `cuda_graph_max_bs` to 80 (half of the default value 160) to reduce the memory overhead of creating cuda graphs. Looking at the logs from TP4 serving of qwen2-72b, a value of 80 is sufficient and can reduce the memory overhead of creating cuda graphs on lower-end GPUs compared to the original 160, avoiding OOM issues.
......
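The new assertion in server_args.py ties chunked prefill to paging: presumably every chunk boundary has to fall on a page boundary so a chunked request never leaves a partially filled page in the middle of its KV span. A quick illustrative check of the constraint (the 8192 default comes from the hunk above; the page sizes are examples):

chunked_prefill_size, page_size = 8192, 16
assert chunked_prefill_size % page_size == 0  # ok: 8192 = 512 * 16
# page_size = 24 would trip the assert: 8192 % 24 == 8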
...@@ -1554,6 +1554,13 @@ def set_cuda_arch(): ...@@ -1554,6 +1554,13 @@ def set_cuda_arch():
os.environ["TORCH_CUDA_ARCH_LIST"] = f"{arch}{'+PTX' if arch == '9.0' else ''}" os.environ["TORCH_CUDA_ARCH_LIST"] = f"{arch}{'+PTX' if arch == '9.0' else ''}"
def next_power_of_2(n: int):
return 1 << (n - 1).bit_length() if n > 0 else 1
setattr(triton, "next_power_of_2", next_power_of_2)
def add_prefix(name: str, prefix: str) -> str: def add_prefix(name: str, prefix: str) -> str:
"""Add a weight path prefix to a module name. """Add a weight path prefix to a module name.
......
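next_power_of_2 above rounds n up to the nearest power of two via bit_length (mapping 0 to 1) and is then attached to the triton module via setattr. A quick runnable check of the expected values, with the definition copied from the hunk above so the snippet is self-contained:

def next_power_of_2(n: int):
    return 1 << (n - 1).bit_length() if n > 0 else 1

assert next_power_of_2(0) == 1
assert next_power_of_2(1) == 1
assert next_power_of_2(5) == 8
assert next_power_of_2(8) == 8
assert next_power_of_2(1000) == 1024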
...@@ -45,6 +45,7 @@ suites = { ...@@ -45,6 +45,7 @@ suites = {
TestFile("test_no_overlap_scheduler.py", 262), TestFile("test_no_overlap_scheduler.py", 262),
TestFile("test_openai_server.py", 124), TestFile("test_openai_server.py", 124),
TestFile("test_penalty.py", 41), TestFile("test_penalty.py", 41),
TestFile("test_page_size.py", 60),
TestFile("test_pytorch_sampling_backend.py", 66), TestFile("test_pytorch_sampling_backend.py", 66),
TestFile("test_radix_attention.py", 167), TestFile("test_radix_attention.py", 167),
TestFile("test_reasoning_content.py", 89), TestFile("test_reasoning_content.py", 89),
......
...@@ -42,7 +42,8 @@ class TestDPAttention(unittest.TestCase): ...@@ -42,7 +42,8 @@ class TestDPAttention(unittest.TestCase):
) )
metrics = run_eval(args) metrics = run_eval(args)
assert metrics["score"] >= 0.5 print(f"{metrics=}")
self.assertGreater(metrics["score"], 0.5)
def test_mgsm_en(self): def test_mgsm_en(self):
args = SimpleNamespace( args = SimpleNamespace(
...@@ -54,7 +55,8 @@ class TestDPAttention(unittest.TestCase): ...@@ -54,7 +55,8 @@ class TestDPAttention(unittest.TestCase):
) )
metrics = run_eval(args) metrics = run_eval(args)
assert metrics["score"] >= 0.8 print(f"{metrics=}")
self.assertGreater(metrics["score"], 0.8)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -184,6 +184,7 @@ class TestGPTQModelDynamicWithMarlin(unittest.TestCase): ...@@ -184,6 +184,7 @@ class TestGPTQModelDynamicWithMarlin(unittest.TestCase):
"text": "The capital of France is", "text": "The capital of France is",
"sampling_params": { "sampling_params": {
"max_new_tokens": max_new_tokens, "max_new_tokens": max_new_tokens,
"temperature": 0.001,
}, },
}, },
) )
......