Unverified commit 9e426466 authored by Lianmin Zheng, committed by GitHub
parent 2f20f430
......@@ -63,12 +63,12 @@ You can find additional accuracy eval examples in:
## Benchmark the speed
Refer to [Benchmark and Profiling](../developer_guide/benchmark_and_profiling.md).
## Request a Review
## Request a review
You can identify potential reviewers for your code by checking the [code owners](https://github.com/sgl-project/sglang/blob/main/.github/CODEOWNERS) and [reviewers](https://github.com/sgl-project/sglang/blob/main/.github/REVIEWERS.md) files.
Another effective strategy is to review the file modification history and contact individuals who have frequently edited the files.
If you modify files protected by code owners, their approval is required to merge the code.
## General Code Style
## General code style
- Avoid code duplication. If the same code snippet (more than five lines) appears multiple times, extract it into a shared function.
- Minimize device synchronization. Reduce expensive CPU-GPU synchronization operations, such as `tensor.item()` or `tensor.cpu()`, whenever possible, and prefer vectorized code (see the sketch after this list).
- Keep files concise. If a file exceeds 2,000 lines of code, split it into multiple smaller files.
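A minimal sketch of the synchronization guideline, using hypothetical tensor names; the pattern mirrors the Triton backend change later in this commit, where a host-side `extend_seq_lens_cpu` list replaces a `.item()` call:

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

# Hypothetical batch state: per-request lengths kept both on the device (for kernels)
# and as a plain Python list (for host-side bookkeeping).
extend_seq_lens = torch.tensor([5, 9, 3], device=device)
extend_seq_lens_cpu = [5, 9, 3]

# Avoid: forces a device-to-host synchronization on every call.
max_extend_len = torch.max(extend_seq_lens).item()

# Prefer: pure host-side computation, no synchronization.
max_extend_len = max(extend_seq_lens_cpu)
```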
......
......@@ -267,7 +267,6 @@ def extend(reqs, model_runner):
model_config=model_runner.model_config,
enable_overlap=False,
spec_algorithm=SpeculativeAlgorithm.NONE,
enable_custom_logit_processor=False,
)
batch.prepare_for_extend()
_maybe_prepare_mlp_sync_batch(batch, model_runner)
......
......@@ -864,7 +864,6 @@ class SchedulerDisaggregationDecodeMixin:
self.model_config,
self.enable_overlap,
self.spec_algorithm,
self.server_args.enable_custom_logit_processor,
)
# construct fake completed prefill
......
......@@ -870,6 +870,8 @@ class FlashInferIndicesUpdaterPrefill:
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
):
if use_ragged:
# TODO: remove this device sync; we can use forward_batch.extend_prefix_lens_cpu
# and forward_batch.extend_seq_lens_cpu instead
paged_kernel_lens = prefix_lens
paged_kernel_lens_sum = paged_kernel_lens.sum().item()
else:
......
......@@ -57,16 +57,36 @@ class TritonAttnBackend(AttentionBackend):
self.decode_attention_fwd = torch.compiler.disable(decode_attention_fwd)
self.extend_attention_fwd = torch.compiler.disable(extend_attention_fwd)
# Parse args
self.skip_prefill = skip_prefill
max_bs = model_runner.req_to_token_pool.size
self.sliding_window_size = model_runner.sliding_window_size
self.req_to_token = model_runner.req_to_token_pool.req_to_token
self.token_to_kv_pool_allocator = model_runner.token_to_kv_pool_allocator
self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens
self.speculative_num_steps = model_runner.server_args.speculative_num_steps
self.num_head = (
model_runner.model_config.num_attention_heads // get_attention_tp_size()
)
self.num_kv_head = model_runner.model_config.get_num_kv_heads(
get_attention_tp_size()
)
self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(0).shape[-1]
self.max_context_len = model_runner.model_config.context_len
self.device = model_runner.device
self.device_core_count = get_device_core_count(model_runner.gpu_id)
self.static_kv_splits = get_bool_env_var(
"SGLANG_TRITON_DECODE_ATTN_STATIC_KV_SPLITS", "false"
)
self.max_kv_splits = model_runner.server_args.triton_attention_num_kv_splits
# Check arguments
assert not (
model_runner.sliding_window_size is not None
and model_runner.model_config.is_encoder_decoder
), "Sliding window and cross attention are not supported together"
self.sliding_window_size = model_runner.sliding_window_size
# Initialize buffers
# TODO(Jianan Ji): Make sure it behaves as expected when kv_indptr_buf is provided and sliding window is enabled
if kv_indptr_buf is None:
self.kv_indptr = torch.zeros(
......@@ -87,9 +107,6 @@ class TritonAttnBackend(AttentionBackend):
# When provided a buffer, create a clone for the second buffer
self.window_kv_indptr = torch.zeros_like(kv_indptr_buf)
self.req_to_token = model_runner.req_to_token_pool.req_to_token
self.token_to_kv_pool_allocator = model_runner.token_to_kv_pool_allocator
if not self.skip_prefill:
self.qo_indptr = torch.zeros(
(max_bs + 1,), dtype=torch.int32, device=model_runner.device
......@@ -99,29 +116,9 @@ class TritonAttnBackend(AttentionBackend):
(max_bs + 1,), dtype=torch.int64, device=model_runner.device
)
self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens
self.speculative_num_steps = model_runner.server_args.speculative_num_steps
self.num_head = (
model_runner.model_config.num_attention_heads // get_attention_tp_size()
)
self.num_kv_head = model_runner.model_config.get_num_kv_heads(
get_attention_tp_size()
)
self.static_kv_splits = get_bool_env_var(
"SGLANG_TRITON_DECODE_ATTN_STATIC_KV_SPLITS", "false"
)
self.max_kv_splits = model_runner.server_args.triton_attention_num_kv_splits
self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(0).shape[-1]
# Initialize forward metadata
self.forward_metadata: ForwardMetadata = None
self.max_context_len = model_runner.model_config.context_len
self.device = model_runner.device
self.device_core_count = get_device_core_count(model_runner.gpu_id)
def get_num_kv_splits(
self,
num_kv_splits: torch.Tensor,
......@@ -333,7 +330,7 @@ class TritonAttnBackend(AttentionBackend):
mask_indptr = None
attn_logits = None
attn_lse = None
max_extend_len = torch.max(forward_batch.extend_seq_lens).item()
max_extend_len = max(forward_batch.extend_seq_lens_cpu)
num_kv_splits = None
self.forward_metadata = ForwardMetadata(
......
......@@ -113,6 +113,7 @@ GLOBAL_SERVER_ARGS_KEYS = [
"enable_multimodal",
"enable_symm_mem",
"quantization",
"enable_custom_logit_processor",
]
# Put some global args for easy access
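For context, each key listed in `GLOBAL_SERVER_ARGS_KEYS` is copied from the server arguments into `global_server_args_dict` when the model runner starts (see the `global_server_args_dict.update(...)` call in the model_runner hunk further down); that is how the sampling code below can read `enable_custom_logit_processor` without threading it through every batch. A minimal, self-contained sketch of that flow, with a simplified stand-in for the real `ServerArgs`:

```python
from dataclasses import dataclass
from typing import Optional

GLOBAL_SERVER_ARGS_KEYS = ["quantization", "enable_custom_logit_processor"]
global_server_args_dict = {}

@dataclass
class ServerArgs:
    # Simplified stand-in; the real class has many more fields.
    quantization: Optional[str] = None
    enable_custom_logit_processor: bool = False

def publish_global_args(server_args: ServerArgs) -> None:
    # Copy the whitelisted fields into the process-wide dict so other modules
    # can read them without holding a ServerArgs reference.
    global_server_args_dict.update(
        {k: getattr(server_args, k) for k in GLOBAL_SERVER_ARGS_KEYS}
    )

publish_global_args(ServerArgs(enable_custom_logit_processor=True))
assert global_server_args_dict["enable_custom_logit_processor"] is True
```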
......@@ -909,9 +910,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
spec_algorithm: SpeculativeAlgorithm = None
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None
# Enable custom logit processor
enable_custom_logit_processor: bool = False
# Whether to return hidden states
return_hidden_states: bool = False
......@@ -928,7 +926,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
model_config: ModelConfig,
enable_overlap: bool,
spec_algorithm: SpeculativeAlgorithm,
enable_custom_logit_processor: bool,
chunked_req: Optional[Req] = None,
):
return_logprob = any(req.return_logprob for req in reqs)
......@@ -955,7 +952,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
has_grammar=any(req.grammar for req in reqs),
device=req_to_token_pool.device,
spec_algorithm=spec_algorithm,
enable_custom_logit_processor=enable_custom_logit_processor,
return_hidden_states=any(req.return_hidden_states for req in reqs),
chunked_req=chunked_req,
)
......@@ -1009,6 +1005,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
extend_num_tokens: int,
backup_state: bool = False,
):
# Overestimate the number of tokens: assume each request needs a new page.
num_tokens = (
extend_num_tokens
+ len(seq_lens) * self.token_to_kv_pool_allocator.page_size
......@@ -1041,8 +1038,8 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
last_loc: torch.Tensor,
backup_state: bool = False,
):
# Overestimate the number of tokens: assume each request needs a new page.
num_tokens = len(seq_lens) * self.token_to_kv_pool_allocator.page_size
self._evict_tree_cache_if_needed(num_tokens)
if backup_state:
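A toy check of the over-estimate above, with made-up lengths and a page size of 16: the extend path budgets `extend_num_tokens + bs * page_size` token slots for eviction, which always covers the pages the batch can actually touch because each request opens at most one extra partially-filled page.

```python
import math

page_size = 16
# Hypothetical batch: (prefix_len, extend_len) per request.
batch = [(5, 20), (16, 1), (0, 33)]

bs = len(batch)
extend_num_tokens = sum(extend for _, extend in batch)  # 54

# Token slots actually consumed: every newly touched page costs page_size slots.
new_pages = sum(
    math.ceil((prefix + extend) / page_size) - math.ceil(prefix / page_size)
    for prefix, extend in batch
)
exact_tokens = new_pages * page_size                 # 5 pages -> 80 slots
over_estimate = extend_num_tokens + bs * page_size   # 54 + 48 = 102 slots

assert over_estimate >= exact_tokens
```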
......@@ -1721,38 +1718,18 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
extend_prefix_lens = self.prefix_lens
extend_logprob_start_lens = self.extend_logprob_start_lens
if self.forward_mode.is_decode_or_idle():
attention_backend_str = global_server_args_dict["decode_attention_backend"]
else:
attention_backend_str = global_server_args_dict["prefill_attention_backend"]
# Create seq_lens_cpu when needed
if (
attention_backend_str
in [
"fa3",
"flashinfer",
"flashmla",
"cutlass_mla",
"ascend",
"trtllm_mha",
"aiter",
]
or global_server_args_dict["enable_two_batch_overlap"]
):
seq_lens_cpu = (
seq_lens_cpu_cache
if seq_lens_cpu_cache is not None
else self.seq_lens.cpu()
)
else:
seq_lens_cpu = None
if self.sampling_info:
if self.has_grammar:
self.sampling_info.grammars = [req.grammar for req in self.reqs]
else:
self.sampling_info.grammars = None
seq_lens_cpu = (
seq_lens_cpu_cache
if seq_lens_cpu_cache is not None
else self.seq_lens.cpu()
)
global bid
bid += 1
return ModelWorkerBatch(
......@@ -1815,18 +1792,14 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
return_logprob=self.return_logprob,
decoding_reqs=self.decoding_reqs,
spec_algorithm=self.spec_algorithm,
enable_custom_logit_processor=self.enable_custom_logit_processor,
global_num_tokens=self.global_num_tokens,
global_num_tokens_for_logprob=self.global_num_tokens_for_logprob,
can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
is_extend_in_batch=self.is_extend_in_batch,
)
def _evict_tree_cache_if_needed(
self,
num_tokens: int,
) -> None:
if isinstance(self.tree_cache, SWAChunkCache):
def _evict_tree_cache_if_needed(self, num_tokens: int):
if isinstance(self.tree_cache, (SWAChunkCache, ChunkCache)):
return
if self.is_hybrid:
......
......@@ -1634,7 +1634,6 @@ class Scheduler(
self.model_config,
self.enable_overlap,
self.spec_algorithm,
self.server_args.enable_custom_logit_processor,
chunked_req=self.chunked_req,
)
if self.enable_hierarchical_cache:
......@@ -2031,7 +2030,6 @@ class Scheduler(
self.model_config,
self.enable_overlap,
self.spec_algorithm,
self.server_args.enable_custom_logit_processor,
)
idle_batch.prepare_for_idle()
return idle_batch
......
......@@ -20,7 +20,6 @@ Page-aligned memory pool.
"""
import abc
import weakref
from typing import TYPE_CHECKING
import torch
......@@ -81,9 +80,6 @@ class BaseTokenToKVPoolAllocator(abc.ABC):
if self.free_group:
self.free(torch.cat(self.free_group))
def estimated_num_new_pages(self, bs, extend_num_tokens):
return bs * ((extend_num_tokens + self.page_size - 1) // self.page_size)
def merge_and_sort_free(self):
if len(self.release_pages) > 0:
self.free_pages = torch.cat((self.free_pages, self.release_pages))
......@@ -149,6 +145,7 @@ class TokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
def alloc(self, need_size: int):
if self.need_sort and need_size > len(self.free_pages):
self.merge_and_sort_free()
if need_size > len(self.free_pages):
return None
......@@ -437,9 +434,13 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
device: str,
kvcache: KVCache,
need_sort: bool,
max_num_extend_tokens: int,
):
super().__init__(size, page_size, dtype, device, kvcache, need_sort)
self.num_pages = size // page_size
self.max_num_extend_tokens_next_power_of_2 = next_power_of_2(
max_num_extend_tokens
)
self.debug_mode = get_bool_env_var("SGLANG_DEBUG_MEMORY_POOL")
self.ret_values = torch.empty((), dtype=torch.int64, device=self.device)
self.clear()
......@@ -480,7 +481,7 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
)
bs = len(prefix_lens)
if self.need_sort and self.estimated_num_new_pages(bs, extend_num_tokens) > len(
if self.need_sort and extend_num_tokens // self.page_size + bs + 1 > len(
self.free_pages
):
self.merge_and_sort_free()
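The inline check above replaces the removed `estimated_num_new_pages` helper with a coarser but still safe bound: each request can open at most one partially-filled page beyond its full pages, so `extend_num_tokens // page_size + bs + 1` never undercounts the new pages a batch may need. A quick sanity sketch with made-up lengths:

```python
import math

page_size = 64
# Hypothetical batch: (prefix_len, extend_len) per request.
batch = [(10, 100), (64, 7), (0, 130)]

bs = len(batch)
extend_num_tokens = sum(extend for _, extend in batch)  # 237

exact_new_pages = sum(
    math.ceil((prefix + extend) / page_size) - math.ceil(prefix / page_size)
    for prefix, extend in batch
)
page_bound = extend_num_tokens // page_size + bs + 1  # 3 + 3 + 1 = 7

assert page_bound >= exact_new_pages  # 7 >= 5 for this batch
```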
......@@ -497,7 +498,7 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
self.ret_values,
next_power_of_2(bs),
self.page_size,
next_power_of_2(extend_num_tokens),
self.max_num_extend_tokens_next_power_of_2,
)
if self.debug_mode:
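The hunk above swaps the per-call `next_power_of_2(extend_num_tokens)` for the bound precomputed in `__init__`, presumably so the kernel sees a stable value instead of one that changes with every batch. A sketch of the helper's assumed semantics (the real `next_power_of_2` is sglang's utility; this is only an illustration):

```python
def next_power_of_2(n: int) -> int:
    # Smallest power of two greater than or equal to n (1 for n <= 1).
    return 1 if n <= 1 else 1 << (n - 1).bit_length()

assert [next_power_of_2(n) for n in (1, 3, 16, 17)] == [1, 4, 16, 32]
```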
......@@ -522,9 +523,7 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
)
bs = len(seq_lens)
if self.need_sort and self.estimated_num_new_pages(bs, 1) > len(
self.free_pages
):
if self.need_sort and bs > len(self.free_pages):
self.merge_and_sort_free()
out_indices = torch.empty((bs,), dtype=torch.int64, device=self.device)
......@@ -578,151 +577,3 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
def load_cpu_copy(self, kv_cache_cpu, indices):
return self._kvcache.load_cpu_copy(kv_cache_cpu, indices)
def alloc_extend_kernel_ascend(
prefix_lens,
seq_lens,
last_loc,
free_pages,
out_indices,
page_size,
device,
):
extend_lens = seq_lens - prefix_lens
end_pos = torch.cumsum(extend_lens, 0)
start_pos = end_pos - extend_lens
num_new_pages = (seq_lens + page_size - 1) // page_size - (
prefix_lens + page_size - 1
) // page_size
num_full_new_pages = (seq_lens) // page_size - (
prefix_lens + page_size - 1
) // page_size
need_page = num_new_pages - num_full_new_pages
end_new_pages = torch.cumsum(num_new_pages, 0)
start_new_pages = end_new_pages - num_new_pages
pos_in_page = torch.arange(page_size, device=device, dtype=torch.int32)
for i in range(len(prefix_lens)):
num1 = (
min(
seq_lens[i],
(prefix_lens[i] + page_size - 1) // page_size * page_size,
)
- prefix_lens[i]
)
if num1:
out_indices[start_pos[i] : start_pos[i] + num1] = (
last_loc[i] + 1 + pos_in_page[:num1].view(-1)
)
num2 = (
seq_lens[i] // page_size - (prefix_lens[i] + page_size - 1) // page_size
) * page_size
if num2:
pages = (
free_pages[start_new_pages[i] : end_new_pages[i] - need_page[i]]
* page_size
)
out_indices[start_pos[i] + num1 : start_pos[i] + num1 + num2] = (
pages.view(-1, 1) + pos_in_page.view(1, -1)
).view(-1)
num3 = seq_lens[i] - seq_lens[i] // page_size * page_size
if num3:
out_indices[end_pos[i] - num3 : end_pos[i]] = (
free_pages[end_new_pages[i] - 1] * page_size + pos_in_page[:num3]
).view(-1)
class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
def __init__(
self,
size: int,
page_size: int,
dtype: torch.dtype,
device: str,
kvcache: KVCache,
need_sort: bool,
):
super().__init__(size, page_size, dtype, device, kvcache, need_sort)
def alloc_extend(
self,
prefix_lens: torch.Tensor,
seq_lens: torch.Tensor,
last_loc: torch.Tensor,
extend_num_tokens: int,
):
if self.debug_mode:
assert torch.all(
(last_loc + 1) % self.page_size == prefix_lens % self.page_size
)
estimated_num_new_pages = (
(
(seq_lens + self.page_size - 1) // self.page_size
- (prefix_lens + self.page_size - 1) // self.page_size
)
.sum()
.item()
)
if self.need_sort and estimated_num_new_pages > len(self.free_pages):
self.merge_and_sort_free()
if estimated_num_new_pages > len(self.free_pages):
return None
out_indices = torch.empty(
(extend_num_tokens,), dtype=torch.int32, device=self.device
)
alloc_extend_kernel_ascend(
prefix_lens,
seq_lens,
last_loc,
self.free_pages,
out_indices,
self.page_size,
self.device,
)
if self.debug_mode:
assert len(torch.unique(out_indices)) == len(out_indices)
self.free_pages = self.free_pages[estimated_num_new_pages:]
return out_indices
def alloc_decode(
self,
seq_lens: torch.Tensor,
last_loc: torch.Tensor,
):
if self.debug_mode:
assert torch.all(
(last_loc + 2) % self.page_size == seq_lens % self.page_size
)
need_new_pages = (seq_lens % self.page_size == 1).int()
num_new_pages = need_new_pages.sum().item()
if num_new_pages > len(self.free_pages):
self.merge_and_sort_free()
if num_new_pages > len(self.free_pages):
return None
end_new_pages = torch.cumsum(need_new_pages, 0)
start_new_pages = end_new_pages - need_new_pages
if num_new_pages == 0:
out_indices = last_loc + 1
else:
out_indices = (last_loc + 1) * (1 - need_new_pages) + self.free_pages[
start_new_pages
] * self.page_size * need_new_pages
if self.debug_mode:
assert len(torch.unique(out_indices)) == len(out_indices)
self.free_pages = self.free_pages[num_new_pages:]
return out_indices.int()
from __future__ import annotations
from typing import TYPE_CHECKING
import torch
from sglang.srt.mem_cache.allocator import PagedTokenToKVPoolAllocator
if TYPE_CHECKING:
from sglang.srt.mem_cache.memory_pool import KVCache
def alloc_extend_kernel_ascend(
prefix_lens,
seq_lens,
last_loc,
free_pages,
out_indices,
page_size,
device,
):
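# The newly extended tokens of each request are written in up to three segments:
#   1) num1: the tail of the last, partially-filled prefix page (right after last_loc),
#   2) num2: whole new pages taken in order from free_pages,
#   3) num3: the head of one final, partially-filled new page.
# The caller is expected to advance free_pages by the total number of new pages.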
extend_lens = seq_lens - prefix_lens
end_pos = torch.cumsum(extend_lens, 0)
start_pos = end_pos - extend_lens
num_new_pages = (seq_lens + page_size - 1) // page_size - (
prefix_lens + page_size - 1
) // page_size
num_full_new_pages = (seq_lens) // page_size - (
prefix_lens + page_size - 1
) // page_size
need_page = num_new_pages - num_full_new_pages
end_new_pages = torch.cumsum(num_new_pages, 0)
start_new_pages = end_new_pages - num_new_pages
pos_in_page = torch.arange(page_size, device=device, dtype=torch.int32)
for i in range(len(prefix_lens)):
num1 = (
min(
seq_lens[i],
(prefix_lens[i] + page_size - 1) // page_size * page_size,
)
- prefix_lens[i]
)
if num1:
out_indices[start_pos[i] : start_pos[i] + num1] = (
last_loc[i] + 1 + pos_in_page[:num1].view(-1)
)
num2 = (
seq_lens[i] // page_size - (prefix_lens[i] + page_size - 1) // page_size
) * page_size
if num2:
pages = (
free_pages[start_new_pages[i] : end_new_pages[i] - need_page[i]]
* page_size
)
out_indices[start_pos[i] + num1 : start_pos[i] + num1 + num2] = (
pages.view(-1, 1) + pos_in_page.view(1, -1)
).view(-1)
num3 = seq_lens[i] - seq_lens[i] // page_size * page_size
if num3:
out_indices[end_pos[i] - num3 : end_pos[i]] = (
free_pages[end_new_pages[i] - 1] * page_size + pos_in_page[:num3]
).view(-1)
class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
def __init__(
self,
size: int,
page_size: int,
dtype: torch.dtype,
device: str,
kvcache: KVCache,
need_sort: bool,
):
super().__init__(size, page_size, dtype, device, kvcache, need_sort, 1)
def alloc_extend(
self,
prefix_lens: torch.Tensor,
seq_lens: torch.Tensor,
last_loc: torch.Tensor,
extend_num_tokens: int,
):
if self.debug_mode:
assert torch.all(
(last_loc + 1) % self.page_size == prefix_lens % self.page_size
)
num_new_pages = (
(
(seq_lens + self.page_size - 1) // self.page_size
- (prefix_lens + self.page_size - 1) // self.page_size
)
.sum()
.item()
)
if self.need_sort and num_new_pages > len(self.free_pages):
self.merge_and_sort_free()
if num_new_pages > len(self.free_pages):
return None
out_indices = torch.empty(
(extend_num_tokens,), dtype=torch.int32, device=self.device
)
alloc_extend_kernel_ascend(
prefix_lens,
seq_lens,
last_loc,
self.free_pages,
out_indices,
self.page_size,
self.device,
)
if self.debug_mode:
assert len(torch.unique(out_indices)) == len(out_indices)
self.free_pages = self.free_pages[num_new_pages:]
return out_indices
def alloc_decode(
self,
seq_lens: torch.Tensor,
last_loc: torch.Tensor,
):
if self.debug_mode:
assert torch.all(
(last_loc + 2) % self.page_size == seq_lens % self.page_size
)
need_new_pages = (seq_lens % self.page_size == 1).int()
num_new_pages = need_new_pages.sum().item()
if num_new_pages > len(self.free_pages):
self.merge_and_sort_free()
if num_new_pages > len(self.free_pages):
return None
end_new_pages = torch.cumsum(need_new_pages, 0)
start_new_pages = end_new_pages - need_new_pages
if num_new_pages == 0:
out_indices = last_loc + 1
else:
out_indices = (last_loc + 1) * (1 - need_new_pages) + self.free_pages[
start_new_pages
] * self.page_size * need_new_pages
if self.debug_mode:
assert len(torch.unique(out_indices)) == len(out_indices)
self.free_pages = self.free_pages[num_new_pages:]
return out_indices.int()
......@@ -2,7 +2,7 @@ from __future__ import annotations
"""Cache for chunked prefill, used when RadixCache is disabled."""
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple
from typing import TYPE_CHECKING, Any, Optional
import torch
......
......@@ -75,12 +75,12 @@ from sglang.srt.managers.schedule_batch import (
global_server_args_dict,
)
from sglang.srt.mem_cache.allocator import (
AscendPagedTokenToKVPoolAllocator,
BaseTokenToKVPoolAllocator,
PagedTokenToKVPoolAllocator,
SWATokenToKVPoolAllocator,
TokenToKVPoolAllocator,
)
from sglang.srt.mem_cache.allocator_ascend import AscendPagedTokenToKVPoolAllocator
from sglang.srt.mem_cache.memory_pool import (
AscendMLAPagedTokenToKVPool,
AscendTokenToKVPool,
......@@ -176,10 +176,6 @@ class ModelRunner:
self.mem_fraction_static = mem_fraction_static
self.device = server_args.device
self.gpu_id = gpu_id
# Apply the rank zero filter to logger
if not any(isinstance(f, RankZeroFilter) for f in logger.filters):
logger.addFilter(RankZeroFilter(tp_rank == 0))
self.tp_rank = tp_rank
self.tp_size = tp_size
self.moe_ep_rank = moe_ep_rank
......@@ -205,15 +201,17 @@ class ModelRunner:
self.is_hybrid = model_config.is_hybrid
self.use_mla_backend = self.model_config.attention_arch == AttentionArch.MLA
self.attention_chunk_size = model_config.attention_chunk_size
self.forward_pass_id = 0
# Model-specific adjustment
self.model_specific_adjustment()
# Apply the rank zero filter to logger
if not any(isinstance(f, RankZeroFilter) for f in logger.filters):
logger.addFilter(RankZeroFilter(tp_rank == 0))
if server_args.show_time_cost:
enable_show_time_cost()
# Model-specific adjustment
self.model_specific_adjustment()
# Global vars
global_server_args_dict.update(
{k: getattr(server_args, k) for k in GLOBAL_SERVER_ARGS_KEYS}
......@@ -221,8 +219,6 @@ class ModelRunner:
# TODO it is indeed not a "server args"
"use_mla_backend": self.use_mla_backend,
"speculative_algorithm": self.spec_algorithm,
}
| {
"moe_a2a_backend": MoeA2ABackend(server_args.moe_a2a_backend),
"deepep_mode": DeepEPMode(server_args.deepep_mode),
}
......@@ -242,13 +238,15 @@ class ModelRunner:
if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
deep_gemm_wrapper.update_deep_gemm_config(gpu_id, server_args)
# If it is a draft model, tp_group can be different
# Initialize the model runner
self.initialize(min_per_gpu_memory)
# temporary cached values
# Temporary cached values
self.support_pp = (
"pp_proxy_tensors" in inspect.signature(self.model.forward).parameters
)
# For weight updates
self._model_update_group = {}
def initialize(self, min_per_gpu_memory: float):
......@@ -277,6 +275,7 @@ class ModelRunner:
)
)
# Expert parallelism
self.eplb_manager = (
EPLBManager(self)
if self.server_args.enable_eplb and (not self.is_draft_worker)
......@@ -1160,6 +1159,7 @@ class ModelRunner:
max_num_reqs: Optional[int] = None,
max_total_tokens: Optional[int] = None,
):
# Determine the kv cache dtype
if self.server_args.kv_cache_dtype == "auto":
self.kv_cache_dtype = self.dtype
elif self.server_args.kv_cache_dtype == "fp8_e5m2":
......@@ -1178,6 +1178,8 @@ class ModelRunner:
)
self.max_total_num_tokens = self.profile_max_num_token(total_gpu_memory)
if SGLANG_CI_SMALL_KV_SIZE:
self.max_total_num_tokens = int(SGLANG_CI_SMALL_KV_SIZE)
if max_num_reqs is None:
max_num_reqs = min(
......@@ -1190,9 +1192,6 @@ class ModelRunner:
4096,
)
if SGLANG_CI_SMALL_KV_SIZE:
self.max_total_num_tokens = int(SGLANG_CI_SMALL_KV_SIZE)
if not self.spec_algorithm.is_none():
if self.is_draft_worker:
self.max_total_num_tokens = self.server_args.draft_runner_cache_size
......@@ -1239,6 +1238,7 @@ class ModelRunner:
"Not enough memory. Please try to increase --mem-fraction-static."
)
# Initialize req_to_token_pool
if self.req_to_token_pool is None:
if self.server_args.disaggregation_mode == "decode":
from sglang.srt.disaggregation.decode import DecodeReqToTokenPool
......@@ -1264,6 +1264,7 @@ class ModelRunner:
# Draft worker shares req_to_token_pool with the target worker.
assert self.is_draft_worker
# Initialize token_to_kv_pool
if self.server_args.attention_backend == "ascend":
if self.use_mla_backend:
self.token_to_kv_pool = AscendMLAPagedTokenToKVPool(
......@@ -1349,44 +1350,52 @@ class ModelRunner:
end_layer=self.end_layer,
)
# Initialize token_to_kv_pool_allocator
need_sort = self.server_args.disaggregation_mode in ("decode", "prefill")
max_num_extend_tokens = (
self.server_args.chunked_prefill_size
if self.server_args.chunked_prefill_size > 0
else self.server_args.max_prefill_tokens
)
if self.token_to_kv_pool_allocator is None:
if self.page_size == 1:
if self.is_hybrid:
self.token_to_kv_pool_allocator = SWATokenToKVPoolAllocator(
self.full_max_total_num_tokens,
self.swa_max_total_num_tokens,
dtype=self.kv_cache_dtype,
device=self.device,
kvcache=self.token_to_kv_pool,
need_sort=need_sort,
)
else:
self.token_to_kv_pool_allocator = TokenToKVPoolAllocator(
self.max_total_num_tokens,
dtype=self.kv_cache_dtype,
device=self.device,
kvcache=self.token_to_kv_pool,
need_sort=need_sort,
)
if self.server_args.attention_backend == "ascend":
self.token_to_kv_pool_allocator = AscendPagedTokenToKVPoolAllocator(
self.max_total_num_tokens,
page_size=self.page_size,
dtype=self.kv_cache_dtype,
device=self.device,
kvcache=self.token_to_kv_pool,
need_sort=need_sort,
)
else:
if not _is_npu:
self.token_to_kv_pool_allocator = PagedTokenToKVPoolAllocator(
self.max_total_num_tokens,
page_size=self.page_size,
dtype=self.kv_cache_dtype,
device=self.device,
kvcache=self.token_to_kv_pool,
need_sort=need_sort,
)
if self.page_size == 1:
if self.is_hybrid:
self.token_to_kv_pool_allocator = SWATokenToKVPoolAllocator(
self.full_max_total_num_tokens,
self.swa_max_total_num_tokens,
dtype=self.kv_cache_dtype,
device=self.device,
kvcache=self.token_to_kv_pool,
need_sort=need_sort,
)
else:
self.token_to_kv_pool_allocator = TokenToKVPoolAllocator(
self.max_total_num_tokens,
dtype=self.kv_cache_dtype,
device=self.device,
kvcache=self.token_to_kv_pool,
need_sort=need_sort,
)
else:
self.token_to_kv_pool_allocator = AscendPagedTokenToKVPoolAllocator(
assert not self.is_hybrid
self.token_to_kv_pool_allocator = PagedTokenToKVPoolAllocator(
self.max_total_num_tokens,
page_size=self.page_size,
dtype=self.kv_cache_dtype,
device=self.device,
kvcache=self.token_to_kv_pool,
need_sort=need_sort,
max_num_extend_tokens=max_num_extend_tokens,
)
else:
assert self.is_draft_worker
......@@ -1554,15 +1563,13 @@ class ModelRunner:
)
return TRTLLMHAAttnBackend(self)
elif backend_str == "intel_amx":
from sglang.srt.layers.attention.intel_amx_backend import (
IntelAMXAttnBackend,
)
logger.info(f"Intel AMX attention backend is enabled.")
return IntelAMXAttnBackend(self)
elif self.server_args.attention_backend == "dual_chunk_flash_attn":
elif backend_str == "dual_chunk_flash_attn":
from sglang.srt.layers.attention.dual_chunk_flashattention_backend import (
DualChunkFlashAttentionBackend,
)
......@@ -1606,6 +1613,7 @@ class ModelRunner:
f"Capture cuda graph begin. This can take up to several minutes. avail mem={before_mem:.2f} GB"
)
self.cuda_graph_runner = CudaGraphRunner(self)
after_mem = get_available_gpu_memory(self.device, self.gpu_id)
self.cuda_graph_mem_usage = before_mem - after_mem
logger.info(
......
......@@ -68,6 +68,8 @@ class SamplingBatchInfo:
@classmethod
def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
from sglang.srt.managers.schedule_batch import global_server_args_dict
reqs = batch.reqs
device = batch.device
temperatures = (
......@@ -97,10 +99,11 @@ class SamplingBatchInfo:
logit_bias[i, int(key)] = value
# Check if any request has custom logit processor
has_custom_logit_processor = (
batch.enable_custom_logit_processor # check the flag first.
and any(r.custom_logit_processor for r in reqs) # then check the requests.
)
has_custom_logit_processor = (
    # Check the global flag first, then the individual requests.
    global_server_args_dict["enable_custom_logit_processor"]
    and any(r.custom_logit_processor for r in reqs)
)
if has_custom_logit_processor:
# Merge the same type of custom logit processors together
......
......@@ -575,6 +575,7 @@ class ServerArgs:
"Pipeline parallelism is incompatible with overlap schedule."
)
# Hicache
if self.hicache_storage_backend == "mooncake":
# To use the mooncake storage backend, the following conditions must be met:
self.hicache_io_backend = "kernel"
......@@ -1316,19 +1317,23 @@ class ServerArgs:
# Kernel backend
ATTN_BACKENDS = [
"aiter",
# Common
"triton",
"torch_native",
# NVIDIA specific
"cutlass_mla",
"fa3",
"flashinfer",
"flashmla",
"intel_amx",
"torch_native",
"ascend",
"triton",
"trtllm_mla",
"trtllm_mha",
"dual_chunk_flash_attn",
# AMD specific
"aiter",
"wave",
# Other platforms
"intel_amx",
"ascend",
]
parser.add_argument(
"--attention-backend",
......
......@@ -21,7 +21,6 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
/*
* From csrc/allreduce
*/
m.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta);
m.def("register_graph_buffers", &register_graph_buffers);
m.def("dispose", &dispose);
......@@ -46,6 +45,7 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
m.def("mscclpp_allreduce(int context, Tensor inp, Tensor! out, int nthreads, int nblocks) -> ()");
m.impl("mscclpp_allreduce", torch::kCUDA, &mscclpp_allreduce);
/*
* From csrc/attention
*/
......@@ -284,6 +284,12 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
"page_size) -> ()");
m.impl("transfer_kv_direct", torch::kCUDA, &transfer_kv_direct);
/*
* From csrc/memory
*/
m.def("store_kv_cache(Tensor k_cache, Tensor v_cache, Tensor out_loc, Tensor k, Tensor v) -> ()");
m.impl("store_kv_cache", &store_kv_cache);
/*
* From csrc/moe/cutlass_moe/w4a8
*/
......@@ -390,13 +396,13 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
m.impl("convert_vertical_slash_indexes_mergehead", torch::kCUDA, &convert_vertical_slash_indexes_mergehead);
/*
* From XGrammar
* From csrc/grammar
*/
m.def("apply_token_bitmask_inplace_cuda(Tensor logits, Tensor bitmask, Tensor? indices=None) -> ()");
m.impl("apply_token_bitmask_inplace_cuda", &ApplyTokenBitmaskInplace);
/*
* From QServe
* From csrc/gemm (QServe)
*/
m.def(
"qserve_w4a8_per_chn_gemm(Tensor _in_feats, Tensor _kernel, Tensor _wscales, Tensor _ascales, Tensor _w_szs, "
......@@ -413,12 +419,6 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
*/
m.def("create_greenctx_stream_by_value(int smA, int smB, int device) -> int[]");
m.impl("create_greenctx_stream_by_value", &create_greenctx_stream_by_value);
/*
* From csrc/memory
*/
m.def("store_kv_cache(Tensor k_cache, Tensor v_cache, Tensor out_loc, Tensor k, Tensor v) -> ()");
m.impl("store_kv_cache", &store_kv_cache);
}
REGISTER_EXTENSION(common_ops)
......@@ -47,7 +47,7 @@ sources = [
"csrc/moe/moe_align_kernel.cu",
"csrc/moe/moe_topk_softmax_kernels.cu",
"csrc/speculative/eagle_utils.cu",
"csrc/torch_extension_rocm.cc",
"csrc/common_extension_rocm.cc",
]
cxx_flags = ["-O3"]
......