Commit 852a49c5 authored by maxiao

adapt to dsv32 on dcu

parent 8f7453e3
...@@ -393,23 +393,13 @@ class W8A8Int8LinearMethod(LinearMethodBase): ...@@ -393,23 +393,13 @@ class W8A8Int8LinearMethod(LinearMethodBase):
x.dtype, x.dtype,
True, # is_vnni True, # is_vnni
) )
x_q, x_scale = per_token_quant_int8(x)
x_q_2d = x_q.view(-1, x_q.shape[-1]) x_q, x_scale = per_token_quant_int8(x)
x_scale_2d = x_scale.view(-1, x_scale.shape[-1])
output_shape = [*x_q.shape[:-1], layer.weight.shape[1]]
output = int8_scaled_mm( return int8_scaled_mm(
x_q_2d, x_q, layer.weight, x_scale, layer.weight_scale, out_dtype=x.dtype, bias=bias
layer.weight,
x_scale_2d,
layer.weight_scale,
out_dtype=x.dtype,
bias=bias,
) )
return output.view(output_shape)
class W8A8Int8MoEMethod(FusedMoEMethodBase): class W8A8Int8MoEMethod(FusedMoEMethodBase):
"""MoE method for INT8. """MoE method for INT8.
...@@ -648,7 +638,6 @@ class NPU_W8A8LinearMethodImpl: ...@@ -648,7 +638,6 @@ class NPU_W8A8LinearMethodImpl:
layer.weight.data = layer.weight.data.transpose(0, 1).contiguous() layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
layer.weight_scale.data = torch.flatten(layer.weight_scale.data) layer.weight_scale.data = torch.flatten(layer.weight_scale.data)
layer.weight_offset.data = torch.flatten(layer.weight_offset.data) layer.weight_offset.data = torch.flatten(layer.weight_offset.data)
layer.weight.data = torch_npu.npu_format_cast(layer.weight.data, 29)
class NPU_W8A8LinearMethodMTImpl: class NPU_W8A8LinearMethodMTImpl:
...@@ -841,7 +830,6 @@ class NPU_W8A8DynamicLinearMethodImpl: ...@@ -841,7 +830,6 @@ class NPU_W8A8DynamicLinearMethodImpl:
layer.weight_scale.data = layer.weight_scale.data.flatten() layer.weight_scale.data = layer.weight_scale.data.flatten()
layer.weight_scale_fp32 = layer.weight_scale.data.to(torch.float32) layer.weight_scale_fp32 = layer.weight_scale.data.to(torch.float32)
layer.weight_offset.data = layer.weight_offset.data.flatten() layer.weight_offset.data = layer.weight_offset.data.flatten()
layer.weight.data = torch_npu.npu_format_cast(layer.weight.data, 29)
class NPU_W8A8DynamicLinearMethod(LinearMethodBase): class NPU_W8A8DynamicLinearMethod(LinearMethodBase):
......
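
Illustrative sketch of what the W8A8 path above computes: per-token symmetric INT8 quantization of the activations followed by a scaled matmul against an INT8 weight. The `_ref` helpers below are stand-ins written in plain PyTorch for this note, not the `per_token_quant_int8` / `int8_scaled_mm` kernels from sgl_kernel, and float32 accumulation is used only to keep the reference portable.

import torch

def per_token_quant_int8_ref(x: torch.Tensor):
    # Symmetric per-token quantization: one scale per row (token).
    scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-7) / 127.0
    x_q = torch.round(x / scale).clamp(-128, 127).to(torch.int8)
    return x_q, scale.float()

def int8_scaled_mm_ref(x_q, w_q, x_scale, w_scale, out_dtype, bias=None):
    # The real kernel accumulates the int8 x int8 product in int32; float32
    # here is only for a portable reference.
    acc = torch.matmul(x_q.float(), w_q.float())
    out = acc * x_scale * w_scale      # x_scale: [tokens, 1], w_scale: [out_features]
    if bias is not None:
        out = out + bias
    return out.to(out_dtype)

x = torch.randn(4, 64)
w_q = torch.randint(-128, 127, (64, 32), dtype=torch.int8)   # weight already int8, [in, out]
w_scale = torch.rand(32) * 0.01                              # per-output-channel scale
x_q, x_scale = per_token_quant_int8_ref(x)
y = int8_scaled_mm_ref(x_q, w_q, x_scale, w_scale, out_dtype=torch.float16)
print(y.shape)   # torch.Size([4, 32])
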
...@@ -12,7 +12,6 @@ from sglang.srt.custom_op import CustomOp ...@@ -12,7 +12,6 @@ from sglang.srt.custom_op import CustomOp
from sglang.srt.utils import ( from sglang.srt.utils import (
cpu_has_amx_support, cpu_has_amx_support,
get_bool_env_var, get_bool_env_var,
get_compiler_backend,
is_cpu, is_cpu,
is_cuda, is_cuda,
is_hip, is_hip,
...@@ -27,19 +26,13 @@ _is_cpu_amx_available = cpu_has_amx_support() ...@@ -27,19 +26,13 @@ _is_cpu_amx_available = cpu_has_amx_support()
_is_cpu = is_cpu() _is_cpu = is_cpu()
if _is_cuda: if _is_cuda:
from sgl_kernel import FusedSetKVBufferArg, apply_rope_with_cos_sin_cache_inplace from sgl_kernel import apply_rope_with_cos_sin_cache_inplace
else:
FusedSetKVBufferArg = None
if _use_aiter: if _use_aiter:
from aiter.rotary_embedding import get_rope as aiter_get_rope from aiter.rotary_embedding import get_rope as aiter_get_rope
if is_npu(): if is_npu():
import torch_npu import torch_npu
NPU_ROTARY_MUL_MAX_NUM_HEADS = 1000
NPU_ROTARY_MUL_MAX_HEAD_SIZE = 896
def _rotate_neox(x: torch.Tensor) -> torch.Tensor: def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
x1 = x[..., : x.shape[-1] // 2] x1 = x[..., : x.shape[-1] // 2]
...@@ -149,13 +142,8 @@ class RotaryEmbedding(CustomOp): ...@@ -149,13 +142,8 @@ class RotaryEmbedding(CustomOp):
query: torch.Tensor, query: torch.Tensor,
key: torch.Tensor, key: torch.Tensor,
offsets: Optional[torch.Tensor] = None, offsets: Optional[torch.Tensor] = None,
fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
"""A PyTorch-native implementation of forward().""" """A PyTorch-native implementation of forward()."""
assert (
fused_set_kv_buffer_arg is None
), "fused_set_kv_buffer_arg is not supported for native implementation"
if offsets is not None: if offsets is not None:
positions = positions + offsets positions = positions + offsets
positions = positions.flatten() positions = positions.flatten()
...@@ -184,17 +172,12 @@ class RotaryEmbedding(CustomOp): ...@@ -184,17 +172,12 @@ class RotaryEmbedding(CustomOp):
query: torch.Tensor, query: torch.Tensor,
key: torch.Tensor, key: torch.Tensor,
offsets: Optional[torch.Tensor] = None, offsets: Optional[torch.Tensor] = None,
fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
"""A PyTorch-npu implementation of forward().""" """A PyTorch-npu implementation of forward()."""
assert ( import os
fused_set_kv_buffer_arg is None
), "fused_set_kv_buffer_arg is not supported for npu implementation"
if get_bool_env_var("SGLANG_ENABLE_TORCH_COMPILE"): if get_bool_env_var("SGLANG_ENABLE_TORCH_COMPILE"):
return self.forward_native( return self.forward_native(positions, query, key, offsets)
positions, query, key, offsets, fused_set_kv_buffer_arg
)
else: else:
rotary_mode = "half" rotary_mode = "half"
if self.is_neox_style: if self.is_neox_style:
...@@ -219,12 +202,7 @@ class RotaryEmbedding(CustomOp): ...@@ -219,12 +202,7 @@ class RotaryEmbedding(CustomOp):
query: torch.Tensor, query: torch.Tensor,
key: torch.Tensor, key: torch.Tensor,
offsets: Optional[torch.Tensor] = None, offsets: Optional[torch.Tensor] = None,
fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
assert (
fused_set_kv_buffer_arg is None
), "fused_set_kv_buffer_arg is not supported for cpu implementation"
positions = torch.add(positions, offsets) if offsets is not None else positions positions = torch.add(positions, offsets) if offsets is not None else positions
if _is_cpu_amx_available: if _is_cpu_amx_available:
return torch.ops.sgl_kernel.rotary_embedding_cpu( return torch.ops.sgl_kernel.rotary_embedding_cpu(
...@@ -236,9 +214,7 @@ class RotaryEmbedding(CustomOp): ...@@ -236,9 +214,7 @@ class RotaryEmbedding(CustomOp):
self.is_neox_style, self.is_neox_style,
) )
else: else:
return self.forward_native( return self.forward_native(positions, query, key, offsets)
positions, query, key, offsets, fused_set_kv_buffer_arg
)
def forward_cuda( def forward_cuda(
self, self,
...@@ -246,7 +222,7 @@ class RotaryEmbedding(CustomOp): ...@@ -246,7 +222,7 @@ class RotaryEmbedding(CustomOp):
query: torch.Tensor, query: torch.Tensor,
key: torch.Tensor, key: torch.Tensor,
offsets: Optional[torch.Tensor] = None, offsets: Optional[torch.Tensor] = None,
fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None, fused_set_kv_buffer_arg=None, # Optional[FusedSetKVBufferArg]
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
if _is_cuda and (self.head_size in [64, 128, 256, 512]): if _is_cuda and (self.head_size in [64, 128, 256, 512]):
apply_rope_with_cos_sin_cache_inplace( apply_rope_with_cos_sin_cache_inplace(
...@@ -1059,7 +1035,7 @@ class MRotaryEmbedding(RotaryEmbedding): ...@@ -1059,7 +1035,7 @@ class MRotaryEmbedding(RotaryEmbedding):
f"Corrected mrope_section: {self.mrope_section} (sum={sum(self.mrope_section)})" f"Corrected mrope_section: {self.mrope_section} (sum={sum(self.mrope_section)})"
) )
@torch.compile(dynamic=True, backend=get_compiler_backend()) @torch.compile(dynamic=True)
def forward( def forward(
self, self,
positions: torch.Tensor, positions: torch.Tensor,
...@@ -1207,7 +1183,7 @@ class MRotaryEmbedding(RotaryEmbedding): ...@@ -1207,7 +1183,7 @@ class MRotaryEmbedding(RotaryEmbedding):
time_tensor_long = time_tensor.long() time_tensor_long = time_tensor.long()
t_index = time_tensor_long.flatten() t_index = time_tensor_long.flatten()
elif model_type in ("qwen2_vl", "qwen3_vl", "qwen3_vl_moe"): elif model_type == "qwen2_vl":
t_index = ( t_index = (
torch.arange(llm_grid_t) torch.arange(llm_grid_t)
.view(-1, 1) .view(-1, 1)
...@@ -1918,30 +1894,17 @@ def apply_rotary_pos_emb_npu( ...@@ -1918,30 +1894,17 @@ def apply_rotary_pos_emb_npu(
sin: torch.Tensor, sin: torch.Tensor,
unsqueeze_dim=1, unsqueeze_dim=1,
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
"""Ascend implementation equivalent to apply_rotary_pos_emb_native. if q.shape[1] != 128:
Args:
q: [num_tokens, num_heads, head_size]
k: [num_tokens, num_kv_heads, head_size]
cos: [num_tokens, head_size]
sin: [num_tokens, head_size]
"""
if (
cos.dim() != 2
or q.dim() != 3
or q.shape[1] >= NPU_ROTARY_MUL_MAX_NUM_HEADS
or q.shape[2] >= NPU_ROTARY_MUL_MAX_HEAD_SIZE
):
# Note: num_heads and head_size of q must be less than 1000 and 896, respectively
return apply_rotary_pos_emb_native(q, k, cos, sin, unsqueeze_dim) return apply_rotary_pos_emb_native(q, k, cos, sin, unsqueeze_dim)
cos = cos.unsqueeze(unsqueeze_dim).unsqueeze(0) cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim).unsqueeze(0) cos = torch.transpose(cos, 1, 2)
q = q.unsqueeze(0) sin = sin.unsqueeze(unsqueeze_dim)
k = k.unsqueeze(0) sin = torch.transpose(sin, 1, 2)
q_embed = torch_npu.npu_rotary_mul(q, cos, sin) q = torch.transpose(q, 1, 2)
k_embed = torch_npu.npu_rotary_mul(k, cos, sin) k = torch.transpose(k, 1, 2)
q_embed = q_embed.squeeze(0) q_embed, k_embed = torch_npu.npu_apply_rotary_pos_emb(q, k, cos, sin)
k_embed = k_embed.squeeze(0) q_embed = torch.transpose(q_embed, 1, 2)
k_embed = torch.transpose(k_embed, 1, 2)
return q_embed, k_embed return q_embed, k_embed
......
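
For reference, a minimal neox-style rotary embedding applied natively, following the half-rotation in `_rotate_neox` above and the shapes listed in the docstring that this hunk removes (q: [num_tokens, num_heads, head_size], cos/sin: [num_tokens, head_size]). This is an illustrative sketch, not the `apply_rotary_pos_emb_native` or NPU kernel path.

import torch

def rotate_neox(x: torch.Tensor) -> torch.Tensor:
    # Same idea as _rotate_neox above: rotate the two halves of the last dim.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rope_native(q, k, cos, sin, unsqueeze_dim=1):
    # q, k: [num_tokens, num_heads, head_size]; cos, sin: [num_tokens, head_size]
    cos = cos.unsqueeze(unsqueeze_dim)   # -> [num_tokens, 1, head_size]
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = q * cos + rotate_neox(q) * sin
    k_embed = k * cos + rotate_neox(k) * sin
    return q_embed, k_embed

q = torch.randn(8, 16, 128)
k = torch.randn(8, 4, 128)
t = torch.arange(8).unsqueeze(-1).float()
inv_freq = 1.0 / (10000 ** (torch.arange(0, 64).float() / 64))
freqs = t * inv_freq                              # [num_tokens, head_size // 2]
cos = torch.cat((freqs, freqs), dim=-1).cos()     # [num_tokens, head_size]
sin = torch.cat((freqs, freqs), dim=-1).sin()
q_out, k_out = apply_rope_native(q, k, cos, sin)
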
...@@ -15,29 +15,6 @@ def get_layer_id(weight_name): ...@@ -15,29 +15,6 @@ def get_layer_id(weight_name):
return None return None
def pad_or_narrow_weight(
loaded_weight: torch.Tensor, input_dim: int, start_idx: int, shard_size: int
) -> torch.Tensor:
# Padding with zeros for special case such as qwen2_5_VL's mlp which is not 8-aligned
valid_size = max(loaded_weight.shape[input_dim] - start_idx, 0)
if valid_size > 0:
loaded_slice = loaded_weight.narrow(input_dim, start_idx, valid_size)
pad_shape = list(loaded_weight.shape)
pad_shape[input_dim] = shard_size - valid_size
pad = torch.zeros(
pad_shape, dtype=loaded_weight.dtype, device=loaded_weight.device
)
return torch.cat([loaded_slice, pad], dim=input_dim)
# All padding
pad_shape = list(loaded_weight.shape)
pad_shape[input_dim] = shard_size
return torch.zeros(
pad_shape, dtype=loaded_weight.dtype, device=loaded_weight.device
)
class PPMissingLayer(torch.nn.Identity): class PPMissingLayer(torch.nn.Identity):
# Adapted from # Adapted from
# https://github.com/vllm-project/vllm/blob/18ed3132d2bfe1df9a74729457b69243955221e8/vllm/model_executor/models/utils.py#L468C1-L486C1 # https://github.com/vllm-project/vllm/blob/18ed3132d2bfe1df9a74729457b69243955221e8/vllm/model_executor/models/utils.py#L468C1-L486C1
......
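
The removed `pad_or_narrow_weight` zero-pads a weight shard when the checkpoint dimension is not shard-aligned (the qwen2_5_VL MLP case mentioned in its comment). A runnable restatement with a small example showing the last shard of a 10-row weight split into chunks of 4:

import torch

def pad_or_narrow_weight(loaded_weight, input_dim, start_idx, shard_size):
    # Take whatever rows remain past start_idx, then pad with zeros up to shard_size.
    valid_size = max(loaded_weight.shape[input_dim] - start_idx, 0)
    pad_shape = list(loaded_weight.shape)
    if valid_size > 0:
        loaded_slice = loaded_weight.narrow(input_dim, start_idx, valid_size)
        pad_shape[input_dim] = shard_size - valid_size
        pad = torch.zeros(pad_shape, dtype=loaded_weight.dtype, device=loaded_weight.device)
        return torch.cat([loaded_slice, pad], dim=input_dim)
    # All padding: nothing valid remains for this shard.
    pad_shape[input_dim] = shard_size
    return torch.zeros(pad_shape, dtype=loaded_weight.dtype, device=loaded_weight.device)

w = torch.arange(10.0).unsqueeze(-1).expand(10, 3).contiguous()   # 10 rows, not 4-aligned
last_shard = pad_or_narrow_weight(w, input_dim=0, start_idx=8, shard_size=4)
print(last_shard.shape)      # torch.Size([4, 3])
print(last_shard[:, 0])      # tensor([8., 9., 0., 0.])
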
...@@ -5,7 +5,7 @@ import triton ...@@ -5,7 +5,7 @@ import triton
import triton.language as tl import triton.language as tl
from sglang.srt.lora.utils import LoRABatchInfo from sglang.srt.lora.utils import LoRABatchInfo
from sglang.srt.utils import cached_triton_kernel from sglang.utils import cached_triton_kernel
@cached_triton_kernel(lambda _, kwargs: (kwargs["NUM_SLICES"], kwargs["BLOCK_M"])) @cached_triton_kernel(lambda _, kwargs: (kwargs["NUM_SLICES"], kwargs["BLOCK_M"]))
......
...@@ -3,7 +3,7 @@ import triton ...@@ -3,7 +3,7 @@ import triton
import triton.language as tl import triton.language as tl
from sglang.srt.lora.utils import LoRABatchInfo from sglang.srt.lora.utils import LoRABatchInfo
from sglang.srt.utils import cached_triton_kernel from sglang.utils import cached_triton_kernel
@cached_triton_kernel(lambda _, kwargs: (kwargs["NUM_SLICES"], kwargs["BLOCK_M"])) @cached_triton_kernel(lambda _, kwargs: (kwargs["NUM_SLICES"], kwargs["BLOCK_M"]))
......
...@@ -275,17 +275,43 @@ class HiCacheController: ...@@ -275,17 +275,43 @@ class HiCacheController:
and self.storage_config.tp_rank != 0 and self.storage_config.tp_rank != 0
) )
# Use storage backend factory for dynamic backend creation if storage_backend == "file":
from sglang.srt.mem_cache.storage import StorageBackendFactory from sglang.srt.mem_cache.hicache_storage import HiCacheFile
try: self.storage_backend = HiCacheFile(self.storage_config)
self.storage_backend = StorageBackendFactory.create_backend( elif storage_backend == "nixl":
storage_backend, self.storage_config, self.mem_pool_host from sglang.srt.mem_cache.storage.nixl.hicache_nixl import HiCacheNixl
self.storage_backend = HiCacheNixl()
elif storage_backend == "mooncake":
from sglang.srt.mem_cache.storage.mooncake_store.mooncake_store import (
MooncakeStore,
)
self.storage_backend = MooncakeStore(self.storage_config)
self.storage_backend.register_buffer(self.mem_pool_host.kv_buffer)
assert self.mem_pool_host.layout == "page_first"
elif storage_backend == "hf3fs":
from sglang.srt.mem_cache.storage.hf3fs.storage_hf3fs import (
HiCacheHF3FS,
) )
except ValueError as e:
raise ValueError(f"Failed to create storage backend: {e}") from e
self.storage_backend.register_mem_pool_host(self.mem_pool_host) if self.mem_pool_host.layout == "page_first":
bytes_per_page = (
mem_pool_host.get_ksize_per_token() * mem_pool_host.page_size
)
elif self.mem_pool_host.layout == "layer_first":
bytes_per_page = (
mem_pool_host.get_size_per_token() * mem_pool_host.page_size
)
dtype = mem_pool_host.dtype
self.storage_backend = HiCacheHF3FS.from_env_config(
bytes_per_page, dtype, self.storage_config
)
else:
raise NotImplementedError(
f"Unsupported storage backend: {storage_backend}"
)
self.enable_storage = True self.enable_storage = True
# todo: threshold policy for prefetching # todo: threshold policy for prefetching
...@@ -309,10 +335,18 @@ class HiCacheController: ...@@ -309,10 +335,18 @@ class HiCacheController:
# Select the get and set functions # Select the get and set functions
self.page_get_func = self._generic_page_get self.page_get_func = self._generic_page_get
self.page_set_func = self._generic_page_set self.page_set_func = self._generic_page_set
self.batch_exists_func = self.storage_backend.batch_exists
if self.storage_backend_type in ["hf3fs", "mooncake"]: self.is_3fs_zerocopy = (
self.page_get_func = self._page_get_zero_copy self.storage_backend_type == "hf3fs"
self.page_set_func = self._page_set_zero_copy and self.mem_pool_host.layout == "page_first"
)
if self.storage_backend_type == "mooncake":
self.page_get_func = self._mooncake_page_get
self.page_set_func = self._mooncake_page_set
elif self.is_3fs_zerocopy:
self.page_get_func = self._3fs_zero_copy_page_get
self.page_set_func = self._3fs_zero_copy_page_set
self.batch_exists_func = self._3fs_zero_copy_batch_exists
self.device = self.mem_pool_device.device self.device = self.mem_pool_device.device
self.layer_num = self.mem_pool_device.layer_num self.layer_num = self.mem_pool_device.layer_num
...@@ -436,6 +470,7 @@ class HiCacheController: ...@@ -436,6 +470,7 @@ class HiCacheController:
host_indices = self.mem_pool_host.alloc(len(device_indices)) host_indices = self.mem_pool_host.alloc(len(device_indices))
if host_indices is None: if host_indices is None:
return None return None
self.mem_pool_host.protect_write(host_indices)
self.write_queue.append( self.write_queue.append(
CacheOperation(host_indices, device_indices, node_id, priority) CacheOperation(host_indices, device_indices, node_id, priority)
) )
...@@ -459,6 +494,7 @@ class HiCacheController: ...@@ -459,6 +494,7 @@ class HiCacheController:
self.mem_pool_host.backup_from_device_all_layer( self.mem_pool_host.backup_from_device_all_layer(
self.mem_pool_device, host_indices, device_indices, self.io_backend self.mem_pool_device, host_indices, device_indices, self.io_backend
) )
self.mem_pool_host.complete_io(op.host_indices)
finish_event.record() finish_event.record()
# NOTE: We must save the host indices and device indices here, # NOTE: We must save the host indices and device indices here,
# this is because we need to guarantee that these tensors are # this is because we need to guarantee that these tensors are
...@@ -482,6 +518,7 @@ class HiCacheController: ...@@ -482,6 +518,7 @@ class HiCacheController:
device_indices = self.mem_pool_device_allocator.alloc(len(host_indices)) device_indices = self.mem_pool_device_allocator.alloc(len(host_indices))
if device_indices is None: if device_indices is None:
return None return None
self.mem_pool_host.protect_load(host_indices)
self.load_queue.append( self.load_queue.append(
CacheOperation(host_indices, device_indices, node_id, priority) CacheOperation(host_indices, device_indices, node_id, priority)
) )
...@@ -526,6 +563,7 @@ class HiCacheController: ...@@ -526,6 +563,7 @@ class HiCacheController:
self.io_backend, self.io_backend,
) )
producer_event.complete(i) producer_event.complete(i)
self.mem_pool_host.complete_io(op.host_indices)
# NOTE: We must save the host indices and device indices here, # NOTE: We must save the host indices and device indices here,
# this is because we need to guarantee that these tensors are # this is because we need to guarantee that these tensors are
# still alive when the load stream is executing. # still alive when the load stream is executing.
...@@ -543,16 +581,29 @@ class HiCacheController: ...@@ -543,16 +581,29 @@ class HiCacheController:
) )
return producer_id return producer_id
def evict_device(self, device_indices: torch.Tensor) -> int: def evict_device(
self.mem_pool_device_allocator.free(device_indices) self, device_indices: torch.Tensor, host_indices: torch.Tensor
return len(device_indices) ) -> int:
if self.mem_pool_host.is_synced(host_indices):
self.mem_pool_device_allocator.free(device_indices)
self.mem_pool_host.update_backup(host_indices)
return len(device_indices)
else:
raise ValueError(
f"Inconsistent states: {self.mem_pool_host.get_state(host_indices)}"
)
def evict_host(self, host_indices: torch.Tensor, backup_only: bool = True) -> int: def evict_host(self, host_indices: torch.Tensor, backup_only: bool = True) -> int:
if not backup_only: if not backup_only:
raise ValueError("Other eviction policies are not supported yet.") raise ValueError("Other eviction policies are not supported yet.")
self.mem_pool_host.free(host_indices) if self.mem_pool_host.is_backup(host_indices):
return len(host_indices) self.mem_pool_host.free(host_indices)
return len(host_indices)
else:
raise ValueError(
f"Inconsistent states: {self.mem_pool_host.get_state(host_indices)}"
)
def prefetch( def prefetch(
self, self,
...@@ -579,19 +630,42 @@ class HiCacheController: ...@@ -579,19 +630,42 @@ class HiCacheController:
for chunk in chunks: for chunk in chunks:
self.host_mem_release_queue.put(chunk) self.host_mem_release_queue.put(chunk)
def _page_get_zero_copy(self, operation, hash_values, host_indices): def _3fs_zero_copy_batch_exists(self, batch_hashes):
results = self.storage_backend.batch_get_v1(hash_values, host_indices) _batch_hashes, _, factor = self.mem_pool_host.get_buffer_with_hash(batch_hashes)
inc = 0 hit_page_num = self.storage_backend.batch_exists(_batch_hashes) // factor
for i in range(len(hash_values)): return hit_page_num
if not results[i]:
logger.warning( def _3fs_zero_copy_page_get(self, operation, hash_values, host_indices):
f"Prefetch operation {operation.request_id} failed to retrieve page {hash_values[i]}." hashes, dsts, factor = self.mem_pool_host.get_buffer_with_hash(
) hash_values, host_indices
break )
inc += self.page_size page_data = self.storage_backend.batch_get(hashes, dsts)
operation.increment(inc) if page_data:
inc = self.page_size * len(hashes) // factor
operation.increment(inc)
else:
logger.warning(
f"Prefetch operation {operation.request_id} failed to retrieve page {hashes}."
)
def _mooncake_page_get(self, operation, hash_values, host_indices):
key_strs, buffer_ptrs, buffer_sizes = self.mem_pool_host.get_buffer_meta(
hash_values,
host_indices,
self.storage_config.tp_rank,
)
get_result = self.storage_backend.batch_get(
key_strs,
target_locations=buffer_ptrs,
target_sizes=buffer_sizes,
)
if get_result != len(hash_values):
logger.warning(
f"Prefetch operation {operation.request_id} failed or partially failed."
)
if get_result != 0:
operation.increment(get_result * self.page_size)
# todo: deprecate
def _generic_page_get(self, operation, hash_values, host_indices): def _generic_page_get(self, operation, hash_values, host_indices):
dummy_page_dst = [ dummy_page_dst = [
self.mem_pool_host.get_dummy_flat_data_page() for _ in hash_values self.mem_pool_host.get_dummy_flat_data_page() for _ in hash_values
...@@ -681,7 +755,7 @@ class HiCacheController: ...@@ -681,7 +755,7 @@ class HiCacheController:
batch_tokens[i : i + self.page_size], last_hash batch_tokens[i : i + self.page_size], last_hash
) )
batch_hashes.append(last_hash) batch_hashes.append(last_hash)
hit_page_num = self.storage_backend.batch_exists(batch_hashes) hit_page_num = self.batch_exists_func(batch_hashes)
hash_value.extend(batch_hashes[:hit_page_num]) hash_value.extend(batch_hashes[:hit_page_num])
storage_query_count += hit_page_num * self.page_size storage_query_count += hit_page_num * self.page_size
if hit_page_num < len(batch_hashes): if hit_page_num < len(batch_hashes):
...@@ -750,16 +824,34 @@ class HiCacheController: ...@@ -750,16 +824,34 @@ class HiCacheController:
self.backup_queue.put(operation) self.backup_queue.put(operation)
return operation.id return operation.id
# todo: deprecate # non-zero copy
def _generic_page_set(self, hash_values, host_indices) -> bool: def _generic_page_set(self, hash_values, host_indices) -> bool:
data = [ data = [
self.mem_pool_host.get_data_page(host_indices[i * self.page_size]) self.mem_pool_host.get_flat_data_page(host_indices[i * self.page_size])
for i in range(len(hash_values)) for i in range(len(hash_values))
] ]
return self.storage_backend.batch_set(hash_values, data) return self.storage_backend.batch_set(hash_values, data)
def _page_set_zero_copy(self, hash_values, host_indices) -> bool: # zero copy
return all(self.storage_backend.batch_set_v1(hash_values, host_indices)) def _mooncake_page_set(self, hash_values, host_indices) -> bool:
key_strs, buffer_ptrs, buffer_sizes = self.mem_pool_host.get_buffer_meta(
hash_values,
host_indices,
self.storage_config.tp_rank,
)
success = self.storage_backend.batch_set(
key_strs,
target_locations=buffer_ptrs,
target_sizes=buffer_sizes,
)
return success
# zero copy
def _3fs_zero_copy_page_set(self, hash_values, host_indices) -> bool:
hashes, dsts, _ = self.mem_pool_host.get_buffer_with_hash(
hash_values, host_indices
)
return self.storage_backend.batch_set(hashes, dsts)
# Backup batch by batch # Backup batch by batch
def _page_backup(self, operation): def _page_backup(self, operation):
......
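
The prefetch loop above builds `batch_hashes` by chaining each page's hash to the previous page's hash (`last_hash`), so a page only hits in storage when its entire token prefix matches. A minimal standalone sketch of that chaining, using a hypothetical sha256-over-token-ids layout rather than the controller's actual hash function:

import hashlib
from typing import List, Optional

def chained_page_hashes(token_ids: List[int], page_size: int,
                        last_hash: Optional[str] = None) -> List[str]:
    # Each page hash covers its tokens plus the previous page's hash, mirroring
    # the batch_hashes loop in HiCacheController above. Only full pages are hashed.
    hashes = []
    for i in range(0, len(token_ids) - len(token_ids) % page_size, page_size):
        page_tokens = token_ids[i : i + page_size]
        h = hashlib.sha256()
        if last_hash is not None:
            h.update(last_hash.encode())
        h.update(",".join(map(str, page_tokens)).encode())
        last_hash = h.hexdigest()
        hashes.append(last_hash)
    return hashes

# Two requests sharing only the first 64-token page agree on the first hash.
a = chained_page_hashes(list(range(130)), page_size=64)
b = chained_page_hashes(list(range(64)) + [999] * 66, page_size=64)
assert a[0] == b[0] and a[1] != b[1]
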
...@@ -35,7 +35,6 @@ else: ...@@ -35,7 +35,6 @@ else:
Image = Any Image = Any
# Parameters for a session
@dataclass @dataclass
class SessionParams: class SessionParams:
id: Optional[str] = None id: Optional[str] = None
...@@ -133,23 +132,18 @@ class GenerateReqInput: ...@@ -133,23 +132,18 @@ class GenerateReqInput:
# Conversation id used for tracking requests # Conversation id used for tracking requests
conversation_id: Optional[str] = None conversation_id: Optional[str] = None
# Label for the request
label: Optional[str] = None
# Priority for the request # Priority for the request
priority: Optional[int] = None priority: Optional[int] = None
# Extra key for classifying the request (e.g. cache_salt) # Image gen grpc migration
extra_key: Optional[Union[List[str], str]] = None
# Whether to disallow logging for this request (e.g. due to ZDR)
no_logs: bool = False
# For custom metric labels
custom_labels: Optional[Dict[str, str]] = None
# (Deprecated, please use custom_labels) Label for the request
label: Optional[str] = None
# (Internal) Whether to return bytes for image generation
return_bytes: bool = False return_bytes: bool = False
# For customer metric labels
customer_labels: Optional[Dict[str, str]] = None
def contains_mm_input(self) -> bool: def contains_mm_input(self) -> bool:
return ( return (
has_valid_data(self.image_data) has_valid_data(self.image_data)
...@@ -548,11 +542,8 @@ class GenerateReqInput: ...@@ -548,11 +542,8 @@ class GenerateReqInput:
self.data_parallel_rank if self.data_parallel_rank is not None else None self.data_parallel_rank if self.data_parallel_rank is not None else None
), ),
conversation_id=self.conversation_id, conversation_id=self.conversation_id,
priority=self.priority,
extra_key=self.extra_key,
no_logs=self.no_logs,
custom_labels=self.custom_labels,
label=self.label, label=self.label,
priority=self.priority,
return_bytes=self.return_bytes, return_bytes=self.return_bytes,
) )
...@@ -609,23 +600,18 @@ class TokenizedGenerateReqInput: ...@@ -609,23 +600,18 @@ class TokenizedGenerateReqInput:
# For dp balance # For dp balance
dp_balance_id: int = -1 dp_balance_id: int = -1
# Label for the request
label: Optional[str] = None
# Priority for the request # Priority for the request
priority: Optional[int] = None priority: Optional[int] = None
# Extra key for classifying the request (e.g. cache_salt) # Image gen grpc migration
extra_key: Optional[str] = None return_bytes: bool = False
# Whether to disallow logging for this request (e.g. due to ZDR)
no_logs: bool = False
# tracing context # tracing context
trace_context: Optional[Dict] = None trace_context: Optional[Dict] = None
# (Deprecated, please use custom_labels) Label for the request
label: Optional[str] = None
# (Internal) Whether to return bytes for image generation
return_bytes: bool = False
@dataclass @dataclass
class BatchTokenizedGenerateReqInput: class BatchTokenizedGenerateReqInput:
......
...@@ -507,7 +507,6 @@ def embed_mm_inputs( ...@@ -507,7 +507,6 @@ def embed_mm_inputs(
Modality, Callable[[List[MultimodalDataItem]], torch.Tensor] Modality, Callable[[List[MultimodalDataItem]], torch.Tensor]
] = None, ] = None,
placeholder_tokens: dict[Modality, List[int]] = None, placeholder_tokens: dict[Modality, List[int]] = None,
use_deepstack: bool = False,
) -> Optional[torch.Tensor]: ) -> Optional[torch.Tensor]:
""" """
Embed multimodal inputs and integrate them with text token embeddings. Embed multimodal inputs and integrate them with text token embeddings.
...@@ -523,7 +522,7 @@ def embed_mm_inputs( ...@@ -523,7 +522,7 @@ def embed_mm_inputs(
Returns: Returns:
Combined embedding tensor with multimodal content integrated Combined embedding tensor with multimodal content integrated
""" """
other_info = {}
if mm_inputs_list is None: if mm_inputs_list is None:
return None return None
...@@ -533,7 +532,7 @@ def embed_mm_inputs( ...@@ -533,7 +532,7 @@ def embed_mm_inputs(
for mm_inputs in mm_inputs_list: for mm_inputs in mm_inputs_list:
item_flatten_list += [item for item in mm_inputs.mm_items if item is not None] item_flatten_list += [item for item in mm_inputs.mm_items if item is not None]
embeddings, masks, deepstack_embeddings = [], [], [] embeddings, masks = [], []
# 2. Get multimodal embedding separately # 2. Get multimodal embedding separately
# Try get mm embedding if any # Try get mm embedding if any
for modality in Modality.all(): for modality in Modality.all():
...@@ -579,12 +578,6 @@ def embed_mm_inputs( ...@@ -579,12 +578,6 @@ def embed_mm_inputs(
extend_length=extend_seq_lens, extend_length=extend_seq_lens,
items_offset_list=items_offsets, items_offset_list=items_offsets,
) )
if use_deepstack and embedding is not None:
embedding, deepstack_embedding = (
multimodal_model.separate_deepstack_embeds(embedding)
)
deepstack_embeddings += [deepstack_embedding]
embeddings += [embedding] embeddings += [embedding]
masks += [mask] masks += [mask]
...@@ -598,37 +591,13 @@ def embed_mm_inputs( ...@@ -598,37 +591,13 @@ def embed_mm_inputs(
inputs_embeds = input_embedding(input_ids) inputs_embeds = input_embedding(input_ids)
# 4. scatter embeddings into input embedding # 4. scatter embeddings into input embedding
for embedding, mask in zip(embeddings, masks):
# deepstack embedding
if use_deepstack:
num_deepstack_embeddings = (
len(multimodal_model.deepstack_visual_indexes) if use_deepstack else 0
)
deepstack_embedding_shape = inputs_embeds.shape[:-1] + (
inputs_embeds.shape[-1] * num_deepstack_embeddings,
)
input_deepstack_embeds = torch.zeros(
deepstack_embedding_shape,
device=inputs_embeds.device,
dtype=inputs_embeds.dtype,
)
other_info["input_deepstack_embeds"] = input_deepstack_embeds
for i, embedding, mask in zip(range(len(embeddings)), embeddings, masks):
if embedding is None or mask is None: if embedding is None or mask is None:
continue continue
# in-place update # in-place update
indices = torch.where(mask.squeeze(dim=-1))[0] indices = torch.where(mask.squeeze(dim=-1))[0]
inputs_embeds[indices] = embedding.to(inputs_embeds.device, inputs_embeds.dtype) inputs_embeds[indices] = embedding.to(inputs_embeds.device, inputs_embeds.dtype)
return inputs_embeds
if use_deepstack:
input_deepstack_embeds[indices] = deepstack_embeddings[i].to(
inputs_embeds.device, inputs_embeds.dtype
)
return inputs_embeds, other_info
def general_mm_embed_routine( def general_mm_embed_routine(
...@@ -640,7 +609,6 @@ def general_mm_embed_routine( ...@@ -640,7 +609,6 @@ def general_mm_embed_routine(
Modality, Callable[[List[MultimodalDataItem]], torch.Tensor] Modality, Callable[[List[MultimodalDataItem]], torch.Tensor]
] = None, ] = None,
placeholder_tokens: Optional[dict[Modality, List[int]]] = None, placeholder_tokens: Optional[dict[Modality, List[int]]] = None,
use_deepstack: bool = False,
**kwargs, **kwargs,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
...@@ -652,7 +620,6 @@ def general_mm_embed_routine( ...@@ -652,7 +620,6 @@ def general_mm_embed_routine(
language_model: Base language model to use language_model: Base language model to use
data_embedding_funcs: A dictionary mapping from modality type to the corresponding embedding function. data_embedding_funcs: A dictionary mapping from modality type to the corresponding embedding function.
placeholder_tokens: Token IDs for multimodal placeholders placeholder_tokens: Token IDs for multimodal placeholders
use_deepstack: Whether to use deepstack embeddings
**kwargs: Additional arguments passed to language model **kwargs: Additional arguments passed to language model
Returns: Returns:
...@@ -678,20 +645,16 @@ def general_mm_embed_routine( ...@@ -678,20 +645,16 @@ def general_mm_embed_routine(
for i, seq_len in enumerate(forward_batch.extend_seq_lens_cpu) for i, seq_len in enumerate(forward_batch.extend_seq_lens_cpu)
if forward_batch.mm_inputs[i] is not None if forward_batch.mm_inputs[i] is not None
] ]
inputs_embeds, other_info = embed_mm_inputs( inputs_embeds = embed_mm_inputs(
mm_inputs_list=mm_inputs_list, mm_inputs_list=mm_inputs_list,
extend_prefix_lens=extend_prefix_lens, extend_prefix_lens=extend_prefix_lens,
extend_seq_lens=extend_seq_lens, extend_seq_lens=extend_seq_lens,
input_ids=input_ids, input_ids=input_ids,
multimodal_model=multimodal_model,
input_embedding=embed_tokens, input_embedding=embed_tokens,
multimodal_model=multimodal_model,
data_embedding_func_mapping=data_embedding_funcs, data_embedding_func_mapping=data_embedding_funcs,
placeholder_tokens=placeholder_tokens, placeholder_tokens=placeholder_tokens,
use_deepstack=use_deepstack,
) )
# add for qwen3_vl deepstack
if use_deepstack:
kwargs["input_deepstack_embeds"] = other_info["input_deepstack_embeds"]
# once used, mm_inputs is useless, considering chunked-prefill is disabled for multimodal models # once used, mm_inputs is useless, considering chunked-prefill is disabled for multimodal models
# just being defensive here # just being defensive here
forward_batch.mm_inputs = None forward_batch.mm_inputs = None
......
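
The scatter step above places multimodal embeddings into the rows of the text embedding output selected by the placeholder mask. A small self-contained sketch of that masked scatter; the placeholder id, vocabulary size, and dimensions are made up for illustration, and this is not the `embed_mm_inputs` signature:

import torch

IMAGE_PLACEHOLDER_ID = 31999        # hypothetical placeholder token id
vocab, hidden = 32000, 16

input_ids = torch.tensor([11, IMAGE_PLACEHOLDER_ID, IMAGE_PLACEHOLDER_ID, 42, 7])
embedding_table = torch.randn(vocab, hidden)
inputs_embeds = embedding_table[input_ids]        # [seq_len, hidden], text embeddings

# One precomputed row per placeholder token, e.g. produced by a vision tower.
mm_embedding = torch.randn(2, hidden)

# Scatter the multimodal rows into the placeholder positions, as in
# embed_mm_inputs above: mask -> indices -> in-place row assignment.
mask = input_ids == IMAGE_PLACEHOLDER_ID
indices = torch.where(mask)[0]
inputs_embeds[indices] = mm_embedding.to(inputs_embeds.device, inputs_embeds.dtype)

print(inputs_embeds.shape)                         # torch.Size([5, 16])
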
...@@ -12,7 +12,8 @@ logger = logging.getLogger(__name__) ...@@ -12,7 +12,8 @@ logger = logging.getLogger(__name__)
PROCESSOR_MAPPING = {} PROCESSOR_MAPPING = {}
def import_processors(package_name: str): def import_processors():
package_name = "sglang.srt.multimodal.processors"
package = importlib.import_module(package_name) package = importlib.import_module(package_name)
for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."): for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."):
if not ispkg: if not ispkg:
......
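
`import_processors` above imports every non-package module under a package so that processors can register themselves at import time (filling `PROCESSOR_MAPPING`). A generic standard-library sketch of that dynamic-import pattern, demonstrated here on the `json` package purely as an example:

import importlib
import pkgutil

def import_all_modules(package_name: str) -> list:
    """Import every direct submodule of `package_name` and return the modules.

    Mirrors import_processors above: iterate the package path and import each
    non-package module, letting import-time side effects (such as filling a
    registry) run.
    """
    package = importlib.import_module(package_name)
    modules = []
    for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."):
        if not ispkg:
            modules.append(importlib.import_module(name))
    return modules

loaded = import_all_modules("json")
print([m.__name__ for m in loaded])   # e.g. ['json.decoder', 'json.encoder', ...]
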
import torch
from sglang.srt.managers.schedule_batch import ModelWorkerBatch
from sglang.srt.utils import get_compiler_backend
@torch.compile(dynamic=True, backend=get_compiler_backend())
def _resolve_future_token_ids(input_ids, future_token_ids_map):
input_ids[:] = torch.where(
input_ids < 0,
future_token_ids_map[torch.clamp(-input_ids, min=0)],
input_ids,
)
class FutureMap:
def __init__(
self,
max_running_requests: int,
device: torch.device,
):
self.future_ct = 0
# A factor of 3 is used to avoid collision in the circular buffer.
self.future_limit = max_running_requests * 3
# A factor of 5 is used to ensure the buffer is large enough.
self.future_buffer_len = max_running_requests * 5
self.device = device
self.token_ids_buf = torch.empty(
(self.future_buffer_len,), dtype=torch.int64, device=self.device
)
def update_ct(self, bs: int) -> int:
"""Update the circular buffer pointer and return the current pointer."""
cur_future_ct = self.future_ct
self.future_ct = (cur_future_ct + bs) % self.future_limit
return cur_future_ct
def resolve_future(self, model_worker_batch: ModelWorkerBatch):
input_ids = model_worker_batch.input_ids
_resolve_future_token_ids(input_ids, self.token_ids_buf)
def update_next_future(self, future_ct: int, bs: int):
return torch.arange(
-(future_ct + 1),
-(future_ct + 1 + bs),
-1,
dtype=torch.int64,
device=self.device,
)
def store_to_map(self, future_ct: int, bs: int, next_token_ids: torch.Tensor):
self.token_ids_buf[future_ct + 1 : future_ct + bs + 1] = next_token_ids
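
The `FutureMap` above hands out negative placeholder token ids and resolves them later through a circular buffer (3x the max running requests to avoid collisions, 5x for the buffer length, per its comments). A standalone restatement of that round trip, with the sglang imports and the `torch.compile` decorator dropped so it runs on its own:

import torch

class MiniFutureMap:
    """Standalone version of FutureMap above. Placeholder -i resolves to
    token_ids_buf[i] once the producing batch has stored its sampled tokens."""

    def __init__(self, max_running_requests: int):
        self.future_ct = 0
        self.future_limit = max_running_requests * 3        # avoid collisions
        self.buf = torch.zeros(max_running_requests * 5, dtype=torch.int64)

    def update_ct(self, bs: int) -> int:
        cur = self.future_ct
        self.future_ct = (cur + bs) % self.future_limit
        return cur

    def next_future_ids(self, ct: int, bs: int) -> torch.Tensor:
        return torch.arange(-(ct + 1), -(ct + 1 + bs), -1, dtype=torch.int64)

    def store(self, ct: int, next_token_ids: torch.Tensor):
        self.buf[ct + 1 : ct + 1 + len(next_token_ids)] = next_token_ids

    def resolve(self, input_ids: torch.Tensor) -> torch.Tensor:
        return torch.where(input_ids < 0, self.buf[torch.clamp(-input_ids, min=0)], input_ids)

fmap = MiniFutureMap(max_running_requests=8)
ct = fmap.update_ct(bs=3)
placeholders = fmap.next_future_ids(ct, 3)       # tensor([-1, -2, -3])
fmap.store(ct, torch.tensor([101, 102, 103]))    # sampled tokens arrive later
print(fmap.resolve(placeholders))                # tensor([101, 102, 103])
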
...@@ -67,14 +67,14 @@ from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache ...@@ -67,14 +67,14 @@ from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
from sglang.srt.metrics.collector import SchedulerMetricsCollector, TimeStats from sglang.srt.metrics.collector import SchedulerMetricsCollector, TimeStats
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.sampling.sampling_params import SamplingParams from sglang.srt.sampling.sampling_params import DEFAULT_SAMPLING_SEED, SamplingParams
from sglang.srt.server_args import ServerArgs from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import flatten_nested_list, support_triton from sglang.srt.utils import flatten_nested_list, support_triton
if TYPE_CHECKING: if TYPE_CHECKING:
from sglang.srt.configs.model_config import ModelConfig from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
from sglang.srt.speculative.ngram_utils import NgramVerifyInput from sglang.srt.speculative.lookahead_utils import LookaheadVerifyInput
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5 INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
...@@ -90,7 +90,6 @@ GLOBAL_SERVER_ARGS_KEYS = [ ...@@ -90,7 +90,6 @@ GLOBAL_SERVER_ARGS_KEYS = [
"disable_flashinfer_cutlass_moe_fp4_allgather", "disable_flashinfer_cutlass_moe_fp4_allgather",
"disable_radix_cache", "disable_radix_cache",
"enable_dp_lm_head", "enable_dp_lm_head",
"enable_fp32_lm_head",
"flashinfer_mxfp4_moe_precision", "flashinfer_mxfp4_moe_precision",
"enable_flashinfer_allreduce_fusion", "enable_flashinfer_allreduce_fusion",
"moe_dense_tp_size", "moe_dense_tp_size",
...@@ -113,6 +112,8 @@ GLOBAL_SERVER_ARGS_KEYS = [ ...@@ -113,6 +112,8 @@ GLOBAL_SERVER_ARGS_KEYS = [
"enable_custom_logit_processor", "enable_custom_logit_processor",
"disaggregation_mode", "disaggregation_mode",
"enable_deterministic_inference", "enable_deterministic_inference",
"nsa_prefill",
"nsa_decode",
] ]
# Put some global args for easy access # Put some global args for easy access
...@@ -492,7 +493,7 @@ class Req: ...@@ -492,7 +493,7 @@ class Req:
self.custom_logit_processor = custom_logit_processor self.custom_logit_processor = custom_logit_processor
self.return_hidden_states = return_hidden_states self.return_hidden_states = return_hidden_states
# extra key for classifying the request (e.g. cache_salt) # extra key for classifying the request (e.g. lora_id, cache_salt)
if lora_id is not None: if lora_id is not None:
extra_key = ( extra_key = (
extra_key or "" extra_key or ""
...@@ -608,8 +609,6 @@ class Req: ...@@ -608,8 +609,6 @@ class Req:
) = None ) = None
self.hidden_states: List[List[float]] = [] self.hidden_states: List[List[float]] = []
self.hidden_states_tensor = None # Note: use tensor instead of list to transfer hidden_states when PD + MTP self.hidden_states_tensor = None # Note: use tensor instead of list to transfer hidden_states when PD + MTP
self.output_topk_p = None
self.output_topk_index = None
# Embedding (return values) # Embedding (return values)
self.embedding = None self.embedding = None
...@@ -954,9 +953,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): ...@@ -954,9 +953,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
# Speculative decoding # Speculative decoding
spec_algorithm: SpeculativeAlgorithm = None spec_algorithm: SpeculativeAlgorithm = None
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]] = ( spec_info: Optional[
None Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
) ] = None
# Whether to return hidden states # Whether to return hidden states
return_hidden_states: bool = False return_hidden_states: bool = False
...@@ -1609,7 +1608,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): ...@@ -1609,7 +1608,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
if ( if (
self.spec_algorithm.is_eagle() self.spec_algorithm.is_eagle()
or self.spec_algorithm.is_standalone() or self.spec_algorithm.is_standalone()
or self.spec_algorithm.is_ngram() or self.spec_algorithm.is_lookahead()
): ):
# if spec decoding is used, the decode batch is prepared inside # if spec decoding is used, the decode batch is prepared inside
# `forward_batch_speculative_generation` after running draft models. # `forward_batch_speculative_generation` after running draft models.
...@@ -1736,14 +1735,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): ...@@ -1736,14 +1735,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
self.sampling_info.filter_batch(keep_indices, keep_indices_device) self.sampling_info.filter_batch(keep_indices, keep_indices_device)
if self.spec_info: if self.spec_info:
if chunked_req_to_exclude is not None and len(chunked_req_to_exclude) > 0: self.spec_info.filter_batch(keep_indices_device)
has_been_filtered = False
else:
has_been_filtered = True
self.spec_info.filter_batch(
new_indices=keep_indices_device,
has_been_filtered=has_been_filtered,
)
def merge_batch(self, other: "ScheduleBatch"): def merge_batch(self, other: "ScheduleBatch"):
# Penalizer orchestrator must be merged before Batch.reqs is merged. This is because # Penalizer orchestrator must be merged before Batch.reqs is merged. This is because
...@@ -1992,9 +1984,9 @@ class ModelWorkerBatch: ...@@ -1992,9 +1984,9 @@ class ModelWorkerBatch:
# Speculative decoding # Speculative decoding
spec_algorithm: SpeculativeAlgorithm = None spec_algorithm: SpeculativeAlgorithm = None
spec_info: Optional[Union[EagleVerifyInput, EagleDraftInput, NgramVerifyInput]] = ( spec_info: Optional[
None Union[EagleVerifyInput, EagleDraftInput, LookaheadVerifyInput]
) ] = None
# If set, the output of the batch contains the hidden states of the run. # If set, the output of the batch contains the hidden states of the run.
capture_hidden_mode: CaptureHiddenMode = None capture_hidden_mode: CaptureHiddenMode = None
hicache_consumer_index: int = -1 hicache_consumer_index: int = -1
......
...@@ -318,6 +318,7 @@ class PrefillAdder: ...@@ -318,6 +318,7 @@ class PrefillAdder:
new_token_ratio: float, new_token_ratio: float,
rem_input_tokens: int, rem_input_tokens: int,
rem_chunk_tokens: Optional[int], rem_chunk_tokens: Optional[int],
max_prefill_bs: Optional[int],
mixed_with_decode_tokens: int = 0, mixed_with_decode_tokens: int = 0,
priority_scheduling_preemption_threshold: int = 0, priority_scheduling_preemption_threshold: int = 0,
): ):
...@@ -358,6 +359,10 @@ class PrefillAdder: ...@@ -358,6 +359,10 @@ class PrefillAdder:
priority_scheduling_preemption_threshold priority_scheduling_preemption_threshold
) )
self.max_prefill_bs = (
max_prefill_bs if max_prefill_bs is not None else 2147483647
)
def _get_running_request_total_token_offset(self, req: Req) -> int: def _get_running_request_total_token_offset(self, req: Req) -> int:
return ( return (
min( min(
...@@ -549,6 +554,9 @@ class PrefillAdder: ...@@ -549,6 +554,9 @@ class PrefillAdder:
def add_one_req( def add_one_req(
self, req: Req, has_chunked_req: bool, truncation_align_size: Optional[int] self, req: Req, has_chunked_req: bool, truncation_align_size: Optional[int]
): ):
if len(self.can_run_list) >= self.max_prefill_bs:
return AddReqResult.OTHER
if req.sampling_params.ignore_eos and getattr(self.tree_cache, "disable", True): if req.sampling_params.ignore_eos and getattr(self.tree_cache, "disable", True):
return self.add_one_req_ignore_eos(req, has_chunked_req) return self.add_one_req_ignore_eos(req, has_chunked_req)
......
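
The `max_prefill_bs` change above caps how many requests one prefill batch may admit, falling back to a large sentinel when unset. A toy admission loop showing how the cap interacts with a token budget; the numbers and the helper name are hypothetical:

from typing import List, Optional

INT32_MAX = 2147483647   # same sentinel as the hunk above

def admit_prefill(requests: List[int], rem_input_tokens: int,
                  max_prefill_bs: Optional[int] = None) -> List[int]:
    """Admit requests (given as input-token counts) until either the token
    budget or the batch-size cap is hit, mirroring add_one_req above."""
    cap = max_prefill_bs if max_prefill_bs is not None else INT32_MAX
    can_run, used = [], 0
    for tokens in requests:
        if len(can_run) >= cap or used + tokens > rem_input_tokens:
            break
        can_run.append(tokens)
        used += tokens
    return can_run

print(admit_prefill([100, 200, 300, 50], rem_input_tokens=1000, max_prefill_bs=2))
# -> [100, 200]: the batch-size cap stops admission before the token budget does.
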
...@@ -44,9 +44,6 @@ from sglang.srt.disaggregation.decode import ( ...@@ -44,9 +44,6 @@ from sglang.srt.disaggregation.decode import (
DecodeTransferQueue, DecodeTransferQueue,
SchedulerDisaggregationDecodeMixin, SchedulerDisaggregationDecodeMixin,
) )
from sglang.srt.disaggregation.decode_kvcache_offload_manager import (
DecodeKVCacheOffloadManager,
)
from sglang.srt.disaggregation.prefill import ( from sglang.srt.disaggregation.prefill import (
PrefillBootstrapQueue, PrefillBootstrapQueue,
SchedulerDisaggregationPrefillMixin, SchedulerDisaggregationPrefillMixin,
...@@ -262,7 +259,7 @@ class Scheduler( ...@@ -262,7 +259,7 @@ class Scheduler(
self.enable_metrics_for_all_schedulers = ( self.enable_metrics_for_all_schedulers = (
server_args.enable_metrics_for_all_schedulers server_args.enable_metrics_for_all_schedulers
) )
self.enable_kv_cache_events = server_args.kv_events_config and tp_rank == 0 self.enable_kv_cache_events = server_args.kv_events_config is not None
self.stream_interval = server_args.stream_interval self.stream_interval = server_args.stream_interval
self.spec_algorithm = SpeculativeAlgorithm.from_string( self.spec_algorithm = SpeculativeAlgorithm.from_string(
server_args.speculative_algorithm server_args.speculative_algorithm
...@@ -388,10 +385,10 @@ class Scheduler( ...@@ -388,10 +385,10 @@ class Scheduler(
target_worker=self.tp_worker, target_worker=self.tp_worker,
dp_rank=dp_rank, dp_rank=dp_rank,
) )
elif self.spec_algorithm.is_ngram(): elif self.spec_algorithm.is_lookahead():
from sglang.srt.speculative.ngram_worker import NGRAMWorker from sglang.srt.speculative.lookahead_worker import LOOKAHEADWorker
self.draft_worker = NGRAMWorker( self.draft_worker = LOOKAHEADWorker(
gpu_id=gpu_id, gpu_id=gpu_id,
tp_rank=tp_rank, tp_rank=tp_rank,
moe_ep_rank=moe_ep_rank, moe_ep_rank=moe_ep_rank,
...@@ -556,11 +553,9 @@ class Scheduler( ...@@ -556,11 +553,9 @@ class Scheduler(
# Init metrics stats # Init metrics stats
self.init_metrics(tp_rank, pp_rank, dp_rank) self.init_metrics(tp_rank, pp_rank, dp_rank)
self.init_kv_events(server_args.kv_events_config)
self.init_dp_balance(dp_balance_meta) self.init_dp_balance(dp_balance_meta)
if self.enable_kv_cache_events:
self.init_kv_events(server_args.kv_events_config)
# Init disaggregation # Init disaggregation
self.disaggregation_mode = DisaggregationMode( self.disaggregation_mode = DisaggregationMode(
self.server_args.disaggregation_mode self.server_args.disaggregation_mode
...@@ -618,6 +613,8 @@ class Scheduler( ...@@ -618,6 +613,8 @@ class Scheduler(
] ]
) )
self.max_prefill_bs = server_args.max_prefill_bs
def init_deterministic_inference_config(self): def init_deterministic_inference_config(self):
"""Initialize deterministic inference configuration for different attention backends.""" """Initialize deterministic inference configuration for different attention backends."""
if not self.server_args.enable_deterministic_inference: if not self.server_args.enable_deterministic_inference:
...@@ -758,24 +755,6 @@ class Scheduler( ...@@ -758,24 +755,6 @@ class Scheduler(
eviction_policy=server_args.radix_eviction_policy, eviction_policy=server_args.radix_eviction_policy,
) )
if (
server_args.disaggregation_mode == "decode"
and server_args.disaggregation_decode_enable_offload_kvcache
):
self.decode_offload_manager = DecodeKVCacheOffloadManager(
req_to_token_pool=self.req_to_token_pool,
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
tp_group=(
self.attn_tp_cpu_group
if self.server_args.enable_dp_attention
else self.tp_cpu_group
),
tree_cache=self.tree_cache,
server_args=self.server_args,
)
else:
self.decode_offload_manager = None
self.decode_mem_cache_buf_multiplier = ( self.decode_mem_cache_buf_multiplier = (
1 1
if self.spec_algorithm.is_none() if self.spec_algorithm.is_none()
...@@ -806,7 +785,7 @@ class Scheduler( ...@@ -806,7 +785,7 @@ class Scheduler(
self.disagg_metadata_buffers = MetadataBuffers( self.disagg_metadata_buffers = MetadataBuffers(
buffer_size, buffer_size,
hidden_size=self.model_config.hf_text_config.hidden_size, hidden_size=self.model_config.hf_text_config.hidden_size,
hidden_states_dtype=self.model_config.dtype, dtype=self.model_config.dtype,
custom_mem_pool=self.token_to_kv_pool_allocator.get_kvcache().maybe_get_custom_mem_pool(), custom_mem_pool=self.token_to_kv_pool_allocator.get_kvcache().maybe_get_custom_mem_pool(),
) )
...@@ -826,7 +805,7 @@ class Scheduler( ...@@ -826,7 +805,7 @@ class Scheduler(
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator, token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
draft_token_to_kv_pool=( draft_token_to_kv_pool=(
None None
if self.draft_worker is None or self.spec_algorithm.is_ngram() if self.draft_worker is None or self.spec_algorithm.is_lookahead()
else self.draft_worker.model_runner.token_to_kv_pool else self.draft_worker.model_runner.token_to_kv_pool
), ),
req_to_metadata_buffer_idx_allocator=self.req_to_metadata_buffer_idx_allocator, req_to_metadata_buffer_idx_allocator=self.req_to_metadata_buffer_idx_allocator,
...@@ -855,7 +834,7 @@ class Scheduler( ...@@ -855,7 +834,7 @@ class Scheduler(
self.disagg_metadata_buffers = MetadataBuffers( self.disagg_metadata_buffers = MetadataBuffers(
buffer_size, buffer_size,
hidden_size=self.model_config.hf_text_config.hidden_size, hidden_size=self.model_config.hf_text_config.hidden_size,
hidden_states_dtype=self.model_config.dtype, dtype=self.model_config.dtype,
custom_mem_pool=self.token_to_kv_pool_allocator.get_kvcache().maybe_get_custom_mem_pool(), custom_mem_pool=self.token_to_kv_pool_allocator.get_kvcache().maybe_get_custom_mem_pool(),
) )
...@@ -863,7 +842,7 @@ class Scheduler( ...@@ -863,7 +842,7 @@ class Scheduler(
token_to_kv_pool=self.token_to_kv_pool_allocator.get_kvcache(), token_to_kv_pool=self.token_to_kv_pool_allocator.get_kvcache(),
draft_token_to_kv_pool=( draft_token_to_kv_pool=(
None None
if self.draft_worker is None or self.spec_algorithm.is_ngram() if self.draft_worker is None or self.spec_algorithm.is_lookahead()
else self.draft_worker.model_runner.token_to_kv_pool else self.draft_worker.model_runner.token_to_kv_pool
), ),
req_to_metadata_buffer_idx_allocator=self.req_to_metadata_buffer_idx_allocator, req_to_metadata_buffer_idx_allocator=self.req_to_metadata_buffer_idx_allocator,
...@@ -1832,6 +1811,7 @@ class Scheduler( ...@@ -1832,6 +1811,7 @@ class Scheduler(
self.new_token_ratio, self.new_token_ratio,
self.max_prefill_tokens, self.max_prefill_tokens,
self.chunked_prefill_size, self.chunked_prefill_size,
self.max_prefill_bs,
running_bs if self.is_mixed_chunk else 0, running_bs if self.is_mixed_chunk else 0,
self.priority_scheduling_preemption_threshold, self.priority_scheduling_preemption_threshold,
) )
......
...@@ -250,13 +250,7 @@ class SchedulerOutputProcessorMixin: ...@@ -250,13 +250,7 @@ class SchedulerOutputProcessorMixin:
req.check_finished() req.check_finished()
if req.finished(): if req.finished():
if self.server_args.disaggregation_decode_enable_offload_kvcache: self.tree_cache.cache_finished_req(req)
# Asynchronously offload KV cache; cache_finished_req will be called after Device->Host transfer completes
if not self.decode_offload_manager.offload_kv_cache(req):
self.tree_cache.cache_finished_req(req)
else:
self.tree_cache.cache_finished_req(req)
req.time_stats.completion_time = time.time() req.time_stats.completion_time = time.time()
if req.return_logprob and batch.spec_algorithm.is_none(): if req.return_logprob and batch.spec_algorithm.is_none():
......
...@@ -97,7 +97,7 @@ class SchedulerProfilerMixin: ...@@ -97,7 +97,7 @@ class SchedulerProfilerMixin:
def start_profile( def start_profile(
self, stage: Optional[ForwardMode] = None self, stage: Optional[ForwardMode] = None
) -> ProfileReqOutput | None: ) -> ProfileReqOutput | None:
stage_str = f" for {stage.name}" if stage else "" stage_str = f" for {stage.__str__()}" if stage else ""
logger.info( logger.info(
f"Profiling starts{stage_str}. Traces will be saved to: {self.torch_profiler_output_dir} (with profile id: {self.profile_id})", f"Profiling starts{stage_str}. Traces will be saved to: {self.torch_profiler_output_dir} (with profile id: {self.profile_id})",
) )
...@@ -181,7 +181,7 @@ class SchedulerProfilerMixin: ...@@ -181,7 +181,7 @@ class SchedulerProfilerMixin:
if not Path(self.torch_profiler_output_dir).exists(): if not Path(self.torch_profiler_output_dir).exists():
Path(self.torch_profiler_output_dir).mkdir(parents=True, exist_ok=True) Path(self.torch_profiler_output_dir).mkdir(parents=True, exist_ok=True)
stage_suffix = f"-{stage.name}" if stage else "" stage_suffix = f"-{stage.__str__()}" if stage else ""
logger.info("Stop profiling" + stage_suffix + "...") logger.info("Stop profiling" + stage_suffix + "...")
if self.torch_profiler is not None: if self.torch_profiler is not None:
self.torch_profiler.stop() self.torch_profiler.stop()
...@@ -247,7 +247,7 @@ class SchedulerProfilerMixin: ...@@ -247,7 +247,7 @@ class SchedulerProfilerMixin:
if self.profiler_decode_ct == 0: if self.profiler_decode_ct == 0:
if self.profile_in_progress: if self.profile_in_progress:
# force trace flush # force trace flush
self.stop_profile(stage=ForwardMode.EXTEND) self.stop_profile(ForwardMode.EXTEND)
self.start_profile(batch.forward_mode) self.start_profile(batch.forward_mode)
self.profiler_decode_ct += 1 self.profiler_decode_ct += 1
if self.profiler_decode_ct > self.profiler_target_decode_ct: if self.profiler_decode_ct > self.profiler_target_decode_ct:
...@@ -294,6 +294,6 @@ class SchedulerProfilerMixin: ...@@ -294,6 +294,6 @@ class SchedulerProfilerMixin:
recv_req.profile_by_stage, recv_req.profile_by_stage,
recv_req.profile_id, recv_req.profile_id,
) )
return self.start_profile() return self.start_profile(True)
else: else:
return self.stop_profile() return self.stop_profile()
...@@ -185,7 +185,7 @@ class TokenizerManager(TokenizerCommunicatorMixin): ...@@ -185,7 +185,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
) )
if self.model_config.is_multimodal: if self.model_config.is_multimodal:
import_processors("sglang.srt.multimodal.processors") import_processors()
try: try:
_processor = get_processor( _processor = get_processor(
server_args.tokenizer_path, server_args.tokenizer_path,
...@@ -320,8 +320,8 @@ class TokenizerManager(TokenizerCommunicatorMixin): ...@@ -320,8 +320,8 @@ class TokenizerManager(TokenizerCommunicatorMixin):
"model_name": self.server_args.served_model_name, "model_name": self.server_args.served_model_name,
# TODO: Add lora name/path in the future, # TODO: Add lora name/path in the future,
} }
if server_args.tokenizer_metrics_allowed_custom_labels: if server_args.tokenizer_metrics_allowed_customer_labels:
for label in server_args.tokenizer_metrics_allowed_custom_labels: for label in server_args.tokenizer_metrics_allowed_customer_labels:
labels[label] = "" labels[label] = ""
self.metrics_collector = TokenizerMetricsCollector( self.metrics_collector = TokenizerMetricsCollector(
server_args=server_args, server_args=server_args,
...@@ -750,7 +750,6 @@ class TokenizerManager(TokenizerCommunicatorMixin): ...@@ -750,7 +750,6 @@ class TokenizerManager(TokenizerCommunicatorMixin):
return_hidden_states=obj.return_hidden_states, return_hidden_states=obj.return_hidden_states,
data_parallel_rank=obj.data_parallel_rank, data_parallel_rank=obj.data_parallel_rank,
priority=obj.priority, priority=obj.priority,
extra_key=obj.extra_key,
) )
elif isinstance(obj, EmbeddingReqInput): elif isinstance(obj, EmbeddingReqInput):
tokenized_obj = TokenizedEmbeddingReqInput( tokenized_obj = TokenizedEmbeddingReqInput(
...@@ -1633,10 +1632,10 @@ class TokenizerManager(TokenizerCommunicatorMixin): ...@@ -1633,10 +1632,10 @@ class TokenizerManager(TokenizerCommunicatorMixin):
else 0 else 0
) )
custom_labels = getattr(state.obj, "custom_labels", None) customer_labels = getattr(state.obj, "customer_labels", None)
labels = ( labels = (
{**self.metrics_collector.labels, **custom_labels} {**self.metrics_collector.labels, **customer_labels}
if custom_labels if customer_labels
else self.metrics_collector.labels else self.metrics_collector.labels
) )
if ( if (
......
...@@ -91,6 +91,7 @@ class TpModelWorker: ...@@ -91,6 +91,7 @@ class TpModelWorker:
else server_args.speculative_draft_model_revision else server_args.speculative_draft_model_revision
), ),
is_draft_model=is_draft_worker, is_draft_model=is_draft_worker,
tp_rank=tp_rank,
) )
self.model_runner = ModelRunner( self.model_runner = ModelRunner(
......
...@@ -36,11 +36,10 @@ from sglang.srt.managers.io_struct import ( ...@@ -36,11 +36,10 @@ from sglang.srt.managers.io_struct import (
UpdateWeightsFromDistributedReqInput, UpdateWeightsFromDistributedReqInput,
UpdateWeightsFromTensorReqInput, UpdateWeightsFromTensorReqInput,
) )
from sglang.srt.managers.overlap_utils import FutureMap
from sglang.srt.managers.schedule_batch import ModelWorkerBatch from sglang.srt.managers.schedule_batch import ModelWorkerBatch
from sglang.srt.managers.tp_worker import TpModelWorker from sglang.srt.managers.tp_worker import TpModelWorker
from sglang.srt.server_args import ServerArgs from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import DynamicGradMode from sglang.srt.utils import DynamicGradMode, get_compiler_backend
from sglang.utils import get_exception_traceback from sglang.utils import get_exception_traceback
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -49,6 +48,15 @@ if TYPE_CHECKING: ...@@ -49,6 +48,15 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@torch.compile(dynamic=True, backend=get_compiler_backend())
def resolve_future_token_ids(input_ids, future_token_ids_map):
input_ids[:] = torch.where(
input_ids < 0,
future_token_ids_map[torch.clamp(-input_ids, min=0)],
input_ids,
)
class TpModelWorkerClient: class TpModelWorkerClient:
"""A tensor parallel model worker.""" """A tensor parallel model worker."""
...@@ -71,7 +79,11 @@ class TpModelWorkerClient: ...@@ -71,7 +79,11 @@ class TpModelWorkerClient:
self.gpu_id = gpu_id self.gpu_id = gpu_id
# Init future mappings # Init future mappings
self.future_map = FutureMap(self.max_running_requests, self.device) self.future_token_ids_ct = 0
self.future_token_ids_limit = self.max_running_requests * 3
self.future_token_ids_map = torch.empty(
(self.max_running_requests * 5,), dtype=torch.int64, device=self.device
)
# Launch threads # Launch threads
self.input_queue = Queue[Tuple[ModelWorkerBatch, int, torch.Event]]() self.input_queue = Queue[Tuple[ModelWorkerBatch, int, torch.Event]]()
...@@ -141,7 +153,7 @@ class TpModelWorkerClient: ...@@ -141,7 +153,7 @@ class TpModelWorkerClient:
batch_lists: List = [None] * 2 batch_lists: List = [None] * 2
while True: while True:
model_worker_batch, future_map_ct, sync_event = self.input_queue.get() model_worker_batch, future_token_ids_ct, sync_event = self.input_queue.get()
if not model_worker_batch: if not model_worker_batch:
break break
@@ -157,7 +169,8 @@ class TpModelWorkerClient:
             copy_done = torch.get_device_module(self.device).Event()

             # Resolve future tokens in the input
-            self.future_map.resolve_future(model_worker_batch)
+            input_ids = model_worker_batch.input_ids
+            resolve_future_token_ids(input_ids, self.future_token_ids_map)

             # Run forward
             logits_output, next_token_ids, can_run_cuda_graph = (
@@ -174,9 +187,9 @@ class TpModelWorkerClient:
             if model_worker_batch.is_prefill_only:
                 # For prefill-only requests, create dummy token IDs on CPU
                 next_token_ids = torch.zeros(bs, dtype=torch.long)
-
-            # store the future indices into future map
-            self.future_map.store_to_map(future_map_ct, bs, next_token_ids)
+            self.future_token_ids_map[
+                future_token_ids_ct + 1 : future_token_ids_ct + bs + 1
+            ] = next_token_ids

             # Copy results to the CPU
             if model_worker_batch.return_logprob:
@@ -242,14 +255,20 @@ class TpModelWorkerClient:
         sync_event.record(self.scheduler_stream)

         # Push a new batch to the queue
-        bs = len(model_worker_batch.seq_lens)
-        cur_future_map_ct = self.future_map.update_ct(bs)
-        self.input_queue.put((model_worker_batch, cur_future_map_ct, sync_event))
+        self.input_queue.put((model_worker_batch, self.future_token_ids_ct, sync_event))

-        # get this forward batch's future token ids
-        future_next_token_ids = self.future_map.update_next_future(
-            cur_future_map_ct, bs
+        # Allocate output future objects
+        bs = len(model_worker_batch.seq_lens)
+        future_next_token_ids = torch.arange(
+            -(self.future_token_ids_ct + 1),
+            -(self.future_token_ids_ct + 1 + bs),
+            -1,
+            dtype=torch.int64,
+            device=self.device,
         )
+        self.future_token_ids_ct = (
+            self.future_token_ids_ct + bs
+        ) % self.future_token_ids_limit
         return None, future_next_token_ids, False

     def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
......
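
Note: taken together, the hunks above drop the FutureMap helper in favour of an explicit counter (future_token_ids_ct), a wrap-around limit, and a preallocated id map; placeholders are handed out as negative ids and the real ids are later written at map[ct + 1 : ct + bs + 1]. A rough sketch of just that bookkeeping, using a toy class and an arbitrary max_running_requests (neither exists in the codebase):

import torch

class FutureIdAllocator:
    """Toy mirror of the counter arithmetic shown in the hunks above."""

    def __init__(self, max_running_requests: int = 4):
        self.ct = 0
        self.limit = max_running_requests * 3
        # Real token ids are later written at map[ct + 1 : ct + bs + 1].
        self.map = torch.zeros(max_running_requests * 5, dtype=torch.int64)

    def allocate(self, bs: int) -> torch.Tensor:
        # Negative placeholder ids for this batch: -(ct+1), ..., -(ct+bs).
        ids = torch.arange(-(self.ct + 1), -(self.ct + 1 + bs), -1, dtype=torch.int64)
        self.ct = (self.ct + bs) % self.limit
        return ids

alloc = FutureIdAllocator()
print(alloc.allocate(3).tolist())  # [-1, -2, -3]
print(alloc.allocate(2).tolist())  # [-4, -5]
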
@@ -79,37 +79,48 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
         )

         num_new_pages = (
-            (
-                (seq_lens + self.page_size - 1) // self.page_size
-                - (prefix_lens + self.page_size - 1) // self.page_size
-            )
-            .sum()
-            .item()
-        )
-        if self.need_sort and num_new_pages > len(self.free_pages):
+            (seq_lens + self.page_size - 1) // self.page_size
+            - (prefix_lens + self.page_size - 1) // self.page_size
+        ).sum()
+        num_new_pages_item = num_new_pages.item()
+        if self.need_sort and num_new_pages_item > len(self.free_pages):
             self.merge_and_sort_free()

-        if num_new_pages > len(self.free_pages):
+        if num_new_pages_item > len(self.free_pages):
             return None

         out_indices = torch.empty(
-            (extend_num_tokens,), dtype=torch.int32, device=self.device
+            (extend_num_tokens,), dtype=torch.int64, device=self.device
         )
-        alloc_extend_kernel_ascend(
-            prefix_lens,
-            seq_lens,
-            last_loc,
-            self.free_pages,
-            out_indices,
-            self.page_size,
-            self.device,
-        )
+
+        if num_new_pages_item < 200:
+            import sgl_kernel_npu
+
+            torch.ops.npu.alloc_extend(
+                prefix_lens,
+                seq_lens,
+                last_loc,
+                self.free_pages,
+                self.page_size,
+                out_indices,
+                num_new_pages,
+            )
+        else:
+            alloc_extend_kernel_ascend(
+                prefix_lens,
+                seq_lens,
+                last_loc,
+                self.free_pages,
+                out_indices,
+                self.page_size,
+                self.device,
+            )

         if self.debug_mode:
             assert len(torch.unique(out_indices)) == len(out_indices)

-        self.free_pages = self.free_pages[num_new_pages:]
+        self.free_pages = self.free_pages[num_new_pages_item:]
         return out_indices

     def alloc_decode(
......
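
Note: the allocator change above keeps num_new_pages as a tensor (it is passed directly to the NPU kernel) and only materializes num_new_pages_item for the host-side comparisons; batches needing fewer than 200 new pages are routed to torch.ops.npu.alloc_extend from sgl_kernel_npu, larger ones fall back to alloc_extend_kernel_ascend. The page count itself is a difference of ceiling divisions; a quick numeric check with made-up page size and lengths:

import torch

page_size = 128
prefix_lens = torch.tensor([0, 100, 256])  # tokens already backed by allocated pages
seq_lens = torch.tensor([64, 300, 257])    # total tokens after the extend

num_new_pages = (
    (seq_lens + page_size - 1) // page_size
    - (prefix_lens + page_size - 1) // page_size
).sum()
print(num_new_pages.item())  # 1 + 2 + 1 = 4 new pages
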
@@ -7,8 +7,6 @@ from typing import Any, List, Optional

 import torch

-from sglang.srt.mem_cache.memory_pool_host import HostKVCache
-
 logger = logging.getLogger(__name__)

@@ -34,46 +32,15 @@ class HiCacheStorageConfig:
     extra_config: Optional[dict] = None


-@dataclass
-class HiCacheStorageExtraInfo:
-    extra_info: Optional[dict] = None
-
-
 class HiCacheStorage(ABC):
     """
     HiCacheStorage is a class that provides a generic key-value interface for storing and retrieving KV cache.
     It abstracts the underlying storage mechanism, allowing different implementations to be used.
     """

-    # todo, potentially pass model and TP configs into storage backend
     # todo, the page size of storage backend does not have to be the same as the same as host memory pool

-    def register_mem_pool_host(self, mem_pool_host: HostKVCache):
-        self.mem_pool_host = mem_pool_host
-
-    def batch_get_v1(
-        self,
-        keys: List[str],
-        host_indices: torch.Tensor,
-        extra_info: Optional[HiCacheStorageExtraInfo] = None,
-    ) -> List[bool]:
-        """
-        Retrieve values for multiple keys.
-        Returns a list of tensors or None for each key.
-        """
-        pass
-
-    def batch_set_v1(
-        self,
-        keys: List[str],
-        host_indices: torch.Tensor,
-        extra_info: Optional[HiCacheStorageExtraInfo] = None,
-    ) -> List[bool]:
-        """
-        Retrieve values for multiple keys.
-        Returns a list of tensors or None for each key.
-        """
-        pass
-
     @abstractmethod
     def get(
         self,
@@ -87,7 +54,6 @@ class HiCacheStorage(ABC):
         """
         pass

-    # TODO: Deprecate
     @abstractmethod
     def batch_get(
         self,
@@ -115,7 +81,6 @@ class HiCacheStorage(ABC):
         """
         pass

-    # TODO: Deprecate
     @abstractmethod
     def batch_set(
         self,
@@ -138,7 +103,6 @@ class HiCacheStorage(ABC):
         """
         pass

-    # TODO: Use a finer-grained return type (e.g., List[bool])
     def batch_exists(self, keys: List[str]) -> int:
         """
         Check if the keys exist in the storage.
@@ -150,9 +114,6 @@ class HiCacheStorage(ABC):
             return i
         return len(keys)

-    def clear(self) -> None:
-        pass
-
     def get_stats(self):
         return None
......
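
Note: with batch_get_v1/batch_set_v1, register_mem_pool_host, HiCacheStorageExtraInfo, and clear() removed, the storage interface is reduced to get/set plus batch_get, batch_set, and batch_exists. A rough dictionary-backed sketch of a backend against that reduced surface; it is illustrative only and does not inherit from the real HiCacheStorage ABC, whose full method signatures are not shown in this diff:

from typing import Dict, List, Optional

import torch

class InMemoryKVStorage:
    """Toy key-value backend following the reduced interface left by this commit."""

    def __init__(self):
        self._store: Dict[str, torch.Tensor] = {}

    def get(self, key: str) -> Optional[torch.Tensor]:
        return self._store.get(key)

    def set(self, key: str, value: torch.Tensor) -> bool:
        self._store[key] = value.clone()
        return True

    def batch_get(self, keys: List[str]) -> List[Optional[torch.Tensor]]:
        return [self.get(k) for k in keys]

    def batch_set(self, keys: List[str], values: List[torch.Tensor]) -> bool:
        return all(self.set(k, v) for k, v in zip(keys, values))

    def batch_exists(self, keys: List[str]) -> int:
        # Same convention as above: number of leading keys that are present.
        for i, key in enumerate(keys):
            if key not in self._store:
                return i
        return len(keys)

storage = InMemoryKVStorage()
storage.batch_set(["page-0", "page-1"], [torch.zeros(4), torch.ones(4)])
print(storage.batch_exists(["page-0", "page-1", "page-2"]))  # 2
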