Commit 852a49c5 authored by maxiao

adapt to dsv32 on dcu

parent 8f7453e3
@@ -393,23 +393,13 @@ class W8A8Int8LinearMethod(LinearMethodBase):
x.dtype,
True, # is_vnni
)
x_q, x_scale = per_token_quant_int8(x)
x_q_2d = x_q.view(-1, x_q.shape[-1])
x_scale_2d = x_scale.view(-1, x_scale.shape[-1])
output_shape = [*x_q.shape[:-1], layer.weight.shape[1]]
x_q, x_scale = per_token_quant_int8(x)
output = int8_scaled_mm(
x_q_2d,
layer.weight,
x_scale_2d,
layer.weight_scale,
out_dtype=x.dtype,
bias=bias,
return int8_scaled_mm(
x_q, layer.weight, x_scale, layer.weight_scale, out_dtype=x.dtype, bias=bias
)
return output.view(output_shape)
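
For readers unfamiliar with the W8A8 path touched in this hunk, the following is a minimal pure-PyTorch sketch of per-token INT8 activation quantization followed by a scale-corrected matmul, including the flatten-to-2D / reshape-back handling seen above. The `_ref` helpers are hypothetical illustrations of the math only, not the `sgl_kernel` implementations of `per_token_quant_int8` / `int8_scaled_mm`.

```python
import torch

def per_token_quant_int8_ref(x: torch.Tensor):
    # One symmetric scale per token (row of the last dimension).
    scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-7) / 127.0
    x_q = torch.clamp(torch.round(x / scale), -128, 127).to(torch.int8)
    return x_q, scale

def per_channel_quant_int8_ref(w: torch.Tensor):
    # w: [in_features, out_features]; one scale per output channel.
    scale = w.abs().amax(dim=0, keepdim=True).clamp(min=1e-7) / 127.0
    w_q = torch.clamp(torch.round(w / scale), -128, 127).to(torch.int8)
    return w_q, scale

def int8_scaled_mm_ref(x_q, w_q, x_scale, w_scale, out_dtype, bias=None):
    # Real kernels accumulate in int32; float32 is used here only to keep
    # the reference portable.
    acc = x_q.to(torch.float32) @ w_q.to(torch.float32)
    out = acc * x_scale * w_scale
    if bias is not None:
        out = out + bias
    return out.to(out_dtype)

x = torch.randn(2, 3, 16)                                     # [batch, seq, hidden]
w_q, w_scale = per_channel_quant_int8_ref(torch.randn(16, 32))
x_q, x_scale = per_token_quant_int8_ref(x)
out = int8_scaled_mm_ref(
    x_q.view(-1, x_q.shape[-1]),          # flatten to 2D, as in the hunk above
    w_q,
    x_scale.view(-1, x_scale.shape[-1]),
    w_scale,
    out_dtype=x.dtype,
).view(*x.shape[:-1], w_q.shape[1])       # restore [batch, seq, out_features]
```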
class W8A8Int8MoEMethod(FusedMoEMethodBase):
"""MoE method for INT8.
@@ -648,7 +638,6 @@ class NPU_W8A8LinearMethodImpl:
layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
layer.weight_scale.data = torch.flatten(layer.weight_scale.data)
layer.weight_offset.data = torch.flatten(layer.weight_offset.data)
layer.weight.data = torch_npu.npu_format_cast(layer.weight.data, 29)
class NPU_W8A8LinearMethodMTImpl:
@@ -841,7 +830,6 @@ class NPU_W8A8DynamicLinearMethodImpl:
layer.weight_scale.data = layer.weight_scale.data.flatten()
layer.weight_scale_fp32 = layer.weight_scale.data.to(torch.float32)
layer.weight_offset.data = layer.weight_offset.data.flatten()
layer.weight.data = torch_npu.npu_format_cast(layer.weight.data, 29)
class NPU_W8A8DynamicLinearMethod(LinearMethodBase):
@@ -12,7 +12,6 @@ from sglang.srt.custom_op import CustomOp
from sglang.srt.utils import (
cpu_has_amx_support,
get_bool_env_var,
get_compiler_backend,
is_cpu,
is_cuda,
is_hip,
@@ -27,19 +26,13 @@ _is_cpu_amx_available = cpu_has_amx_support()
_is_cpu = is_cpu()
if _is_cuda:
from sgl_kernel import FusedSetKVBufferArg, apply_rope_with_cos_sin_cache_inplace
else:
FusedSetKVBufferArg = None
from sgl_kernel import apply_rope_with_cos_sin_cache_inplace
if _use_aiter:
from aiter.rotary_embedding import get_rope as aiter_get_rope
if is_npu():
import torch_npu
NPU_ROTARY_MUL_MAX_NUM_HEADS = 1000
NPU_ROTARY_MUL_MAX_HEAD_SIZE = 896
def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
x1 = x[..., : x.shape[-1] // 2]
@@ -149,13 +142,8 @@ class RotaryEmbedding(CustomOp):
query: torch.Tensor,
key: torch.Tensor,
offsets: Optional[torch.Tensor] = None,
fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""A PyTorch-native implementation of forward()."""
assert (
fused_set_kv_buffer_arg is None
), "fused_set_kv_buffer_arg is not supported for native implementation"
if offsets is not None:
positions = positions + offsets
positions = positions.flatten()
@@ -184,17 +172,12 @@ class RotaryEmbedding(CustomOp):
query: torch.Tensor,
key: torch.Tensor,
offsets: Optional[torch.Tensor] = None,
fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""A PyTorch-npu implementation of forward()."""
assert (
fused_set_kv_buffer_arg is None
), "fused_set_kv_buffer_arg is not supported for npu implementation"
import os
if get_bool_env_var("SGLANG_ENABLE_TORCH_COMPILE"):
return self.forward_native(
positions, query, key, offsets, fused_set_kv_buffer_arg
)
return self.forward_native(positions, query, key, offsets)
else:
rotary_mode = "half"
if self.is_neox_style:
@@ -219,12 +202,7 @@ class RotaryEmbedding(CustomOp):
query: torch.Tensor,
key: torch.Tensor,
offsets: Optional[torch.Tensor] = None,
fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
assert (
fused_set_kv_buffer_arg is None
), "fused_set_kv_buffer_arg is not supported for cpu implementation"
positions = torch.add(positions, offsets) if offsets is not None else positions
if _is_cpu_amx_available:
return torch.ops.sgl_kernel.rotary_embedding_cpu(
@@ -236,9 +214,7 @@ class RotaryEmbedding(CustomOp):
self.is_neox_style,
)
else:
return self.forward_native(
positions, query, key, offsets, fused_set_kv_buffer_arg
)
return self.forward_native(positions, query, key, offsets)
def forward_cuda(
self,
@@ -246,7 +222,7 @@ class RotaryEmbedding(CustomOp):
query: torch.Tensor,
key: torch.Tensor,
offsets: Optional[torch.Tensor] = None,
fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
fused_set_kv_buffer_arg=None, # Optional[FusedSetKVBufferArg]
) -> Tuple[torch.Tensor, torch.Tensor]:
if _is_cuda and (self.head_size in [64, 128, 256, 512]):
apply_rope_with_cos_sin_cache_inplace(
@@ -1059,7 +1035,7 @@ class MRotaryEmbedding(RotaryEmbedding):
f"Corrected mrope_section: {self.mrope_section} (sum={sum(self.mrope_section)})"
)
@torch.compile(dynamic=True, backend=get_compiler_backend())
@torch.compile(dynamic=True)
def forward(
self,
positions: torch.Tensor,
@@ -1207,7 +1183,7 @@ class MRotaryEmbedding(RotaryEmbedding):
time_tensor_long = time_tensor.long()
t_index = time_tensor_long.flatten()
elif model_type in ("qwen2_vl", "qwen3_vl", "qwen3_vl_moe"):
elif model_type == "qwen2_vl":
t_index = (
torch.arange(llm_grid_t)
.view(-1, 1)
@@ -1918,30 +1894,17 @@ def apply_rotary_pos_emb_npu(
sin: torch.Tensor,
unsqueeze_dim=1,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Ascend implementation equivalent to apply_rotary_pos_emb_native.
Args:
q: [num_tokens, num_heads, head_size]
k: [num_tokens, num_kv_heads, head_size]
cos: [num_tokens, head_size]
sin: [num_tokens, head_size]
"""
if (
cos.dim() != 2
or q.dim() != 3
or q.shape[1] >= NPU_ROTARY_MUL_MAX_NUM_HEADS
or q.shape[2] >= NPU_ROTARY_MUL_MAX_HEAD_SIZE
):
# Note: num_heads and head_size of q must be less than 1000 and 896, respectively
if q.shape[1] != 128:
return apply_rotary_pos_emb_native(q, k, cos, sin, unsqueeze_dim)
cos = cos.unsqueeze(unsqueeze_dim).unsqueeze(0)
sin = sin.unsqueeze(unsqueeze_dim).unsqueeze(0)
q = q.unsqueeze(0)
k = k.unsqueeze(0)
q_embed = torch_npu.npu_rotary_mul(q, cos, sin)
k_embed = torch_npu.npu_rotary_mul(k, cos, sin)
q_embed = q_embed.squeeze(0)
k_embed = k_embed.squeeze(0)
cos = cos.unsqueeze(unsqueeze_dim)
cos = torch.transpose(cos, 1, 2)
sin = sin.unsqueeze(unsqueeze_dim)
sin = torch.transpose(sin, 1, 2)
q = torch.transpose(q, 1, 2)
k = torch.transpose(k, 1, 2)
q_embed, k_embed = torch_npu.npu_apply_rotary_pos_emb(q, k, cos, sin)
q_embed = torch.transpose(q_embed, 1, 2)
k_embed = torch.transpose(k_embed, 1, 2)
return q_embed, k_embed
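
For reference, the fallback `apply_rotary_pos_emb_native` that the hunk above dispatches to is, in the usual formulation, the "rotate half" application of cos/sin. A minimal sketch under that assumption (not the torch_npu kernels):

```python
import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb_ref(q, k, cos, sin, unsqueeze_dim=1):
    # q, k: [num_tokens, num_heads, head_size]; cos, sin: [num_tokens, head_size]
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = q * cos + rotate_half(q) * sin
    k_embed = k * cos + rotate_half(k) * sin
    return q_embed, k_embed
```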
@@ -15,29 +15,6 @@ def get_layer_id(weight_name):
return None
def pad_or_narrow_weight(
loaded_weight: torch.Tensor, input_dim: int, start_idx: int, shard_size: int
) -> torch.Tensor:
# Padding with zeros for special case such as qwen2_5_VL's mlp which is not 8-aligned
valid_size = max(loaded_weight.shape[input_dim] - start_idx, 0)
if valid_size > 0:
loaded_slice = loaded_weight.narrow(input_dim, start_idx, valid_size)
pad_shape = list(loaded_weight.shape)
pad_shape[input_dim] = shard_size - valid_size
pad = torch.zeros(
pad_shape, dtype=loaded_weight.dtype, device=loaded_weight.device
)
return torch.cat([loaded_slice, pad], dim=input_dim)
# All padding
pad_shape = list(loaded_weight.shape)
pad_shape[input_dim] = shard_size
return torch.zeros(
pad_shape, dtype=loaded_weight.dtype, device=loaded_weight.device
)
class PPMissingLayer(torch.nn.Identity):
# Adapted from
# https://github.com/vllm-project/vllm/blob/18ed3132d2bfe1df9a74729457b69243955221e8/vllm/model_executor/models/utils.py#L468C1-L486C1
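
A small worked example of `pad_or_narrow_weight` from the hunk above (toy shapes; the helper is assumed importable): requesting a 4-wide shard starting at column 8 of a 10-column tensor leaves two valid columns, so the result is those columns followed by zero padding.

```python
import torch

loaded = torch.arange(10.0).view(1, 10)   # pretend checkpoint tensor, input_dim=1
shard = pad_or_narrow_weight(loaded, input_dim=1, start_idx=8, shard_size=4)
print(shard)  # tensor([[8., 9., 0., 0.]])
```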
@@ -5,7 +5,7 @@ import triton
import triton.language as tl
from sglang.srt.lora.utils import LoRABatchInfo
from sglang.srt.utils import cached_triton_kernel
from sglang.utils import cached_triton_kernel
@cached_triton_kernel(lambda _, kwargs: (kwargs["NUM_SLICES"], kwargs["BLOCK_M"]))
@@ -3,7 +3,7 @@ import triton
import triton.language as tl
from sglang.srt.lora.utils import LoRABatchInfo
from sglang.srt.utils import cached_triton_kernel
from sglang.utils import cached_triton_kernel
@cached_triton_kernel(lambda _, kwargs: (kwargs["NUM_SLICES"], kwargs["BLOCK_M"]))
@@ -275,17 +275,43 @@ class HiCacheController:
and self.storage_config.tp_rank != 0
)
# Use storage backend factory for dynamic backend creation
from sglang.srt.mem_cache.storage import StorageBackendFactory
if storage_backend == "file":
from sglang.srt.mem_cache.hicache_storage import HiCacheFile
try:
self.storage_backend = StorageBackendFactory.create_backend(
storage_backend, self.storage_config, self.mem_pool_host
self.storage_backend = HiCacheFile(self.storage_config)
elif storage_backend == "nixl":
from sglang.srt.mem_cache.storage.nixl.hicache_nixl import HiCacheNixl
self.storage_backend = HiCacheNixl()
elif storage_backend == "mooncake":
from sglang.srt.mem_cache.storage.mooncake_store.mooncake_store import (
MooncakeStore,
)
self.storage_backend = MooncakeStore(self.storage_config)
self.storage_backend.register_buffer(self.mem_pool_host.kv_buffer)
assert self.mem_pool_host.layout == "page_first"
elif storage_backend == "hf3fs":
from sglang.srt.mem_cache.storage.hf3fs.storage_hf3fs import (
HiCacheHF3FS,
)
except ValueError as e:
raise ValueError(f"Failed to create storage backend: {e}") from e
self.storage_backend.register_mem_pool_host(self.mem_pool_host)
if self.mem_pool_host.layout == "page_first":
bytes_per_page = (
mem_pool_host.get_ksize_per_token() * mem_pool_host.page_size
)
elif self.mem_pool_host.layout == "layer_first":
bytes_per_page = (
mem_pool_host.get_size_per_token() * mem_pool_host.page_size
)
dtype = mem_pool_host.dtype
self.storage_backend = HiCacheHF3FS.from_env_config(
bytes_per_page, dtype, self.storage_config
)
else:
raise NotImplementedError(
f"Unsupported storage backend: {storage_backend}"
)
self.enable_storage = True
# todo: threshold policy for prefetching
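
The two variants above differ in how the storage backend is constructed: one goes through `StorageBackendFactory.create_backend`, the other through an explicit if/elif chain. A minimal registry-based factory, for illustration only (the names below are hypothetical and not the real `StorageBackendFactory` API):

```python
from typing import Callable, Dict

_BACKENDS: Dict[str, Callable[..., object]] = {}

def register_backend(name: str):
    def deco(cls):
        _BACKENDS[name] = cls
        return cls
    return deco

def create_backend(name: str, *args, **kwargs):
    try:
        return _BACKENDS[name](*args, **kwargs)
    except KeyError as e:
        raise ValueError(f"Unsupported storage backend: {name}") from e

@register_backend("file")
class FileBackend:
    def __init__(self, config=None):
        self.config = config

backend = create_backend("file", config={"path": "/tmp/hicache"})
```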
@@ -309,10 +335,18 @@ class HiCacheController:
# Select the get and set functions
self.page_get_func = self._generic_page_get
self.page_set_func = self._generic_page_set
if self.storage_backend_type in ["hf3fs", "mooncake"]:
self.page_get_func = self._page_get_zero_copy
self.page_set_func = self._page_set_zero_copy
self.batch_exists_func = self.storage_backend.batch_exists
self.is_3fs_zerocopy = (
self.storage_backend_type == "hf3fs"
and self.mem_pool_host.layout == "page_first"
)
if self.storage_backend_type == "mooncake":
self.page_get_func = self._mooncake_page_get
self.page_set_func = self._mooncake_page_set
elif self.is_3fs_zerocopy:
self.page_get_func = self._3fs_zero_copy_page_get
self.page_set_func = self._3fs_zero_copy_page_set
self.batch_exists_func = self._3fs_zero_copy_batch_exists
self.device = self.mem_pool_device.device
self.layer_num = self.mem_pool_device.layer_num
@@ -436,6 +470,7 @@ class HiCacheController:
host_indices = self.mem_pool_host.alloc(len(device_indices))
if host_indices is None:
return None
self.mem_pool_host.protect_write(host_indices)
self.write_queue.append(
CacheOperation(host_indices, device_indices, node_id, priority)
)
@@ -459,6 +494,7 @@ class HiCacheController:
self.mem_pool_host.backup_from_device_all_layer(
self.mem_pool_device, host_indices, device_indices, self.io_backend
)
self.mem_pool_host.complete_io(op.host_indices)
finish_event.record()
# NOTE: We must save the host indices and device indices here,
# this is because we need to guarantee that these tensors are
@@ -482,6 +518,7 @@ class HiCacheController:
device_indices = self.mem_pool_device_allocator.alloc(len(host_indices))
if device_indices is None:
return None
self.mem_pool_host.protect_load(host_indices)
self.load_queue.append(
CacheOperation(host_indices, device_indices, node_id, priority)
)
@@ -526,6 +563,7 @@ class HiCacheController:
self.io_backend,
)
producer_event.complete(i)
self.mem_pool_host.complete_io(op.host_indices)
# NOTE: We must save the host indices and device indices here,
# this is because we need to guarantee that these tensors are
# still alive when the load stream is executing.
@@ -543,16 +581,29 @@ class HiCacheController:
)
return producer_id
def evict_device(self, device_indices: torch.Tensor) -> int:
self.mem_pool_device_allocator.free(device_indices)
return len(device_indices)
def evict_device(
self, device_indices: torch.Tensor, host_indices: torch.Tensor
) -> int:
if self.mem_pool_host.is_synced(host_indices):
self.mem_pool_device_allocator.free(device_indices)
self.mem_pool_host.update_backup(host_indices)
return len(device_indices)
else:
raise ValueError(
f"Inconsistent states: {self.mem_pool_host.get_state(host_indices)}"
)
def evict_host(self, host_indices: torch.Tensor, backup_only: bool = True) -> int:
if not backup_only:
raise ValueError("Other eviction policies are not supported yet.")
self.mem_pool_host.free(host_indices)
return len(host_indices)
if self.mem_pool_host.is_backup(host_indices):
self.mem_pool_host.free(host_indices)
return len(host_indices)
else:
raise ValueError(
f"Inconsistent states: {self.mem_pool_host.get_state(host_indices)}"
)
def prefetch(
self,
@@ -579,19 +630,42 @@ class HiCacheController:
for chunk in chunks:
self.host_mem_release_queue.put(chunk)
def _page_get_zero_copy(self, operation, hash_values, host_indices):
results = self.storage_backend.batch_get_v1(hash_values, host_indices)
inc = 0
for i in range(len(hash_values)):
if not results[i]:
logger.warning(
f"Prefetch operation {operation.request_id} failed to retrieve page {hash_values[i]}."
)
break
inc += self.page_size
operation.increment(inc)
def _3fs_zero_copy_batch_exists(self, batch_hashes):
_batch_hashes, _, factor = self.mem_pool_host.get_buffer_with_hash(batch_hashes)
hit_page_num = self.storage_backend.batch_exists(_batch_hashes) // factor
return hit_page_num
def _3fs_zero_copy_page_get(self, operation, hash_values, host_indices):
hashes, dsts, factor = self.mem_pool_host.get_buffer_with_hash(
hash_values, host_indices
)
page_data = self.storage_backend.batch_get(hashes, dsts)
if page_data:
inc = self.page_size * len(hashes) // factor
operation.increment(inc)
else:
logger.warning(
f"Prefetch operation {operation.request_id} failed to retrieve page {hashes}."
)
def _mooncake_page_get(self, operation, hash_values, host_indices):
key_strs, buffer_ptrs, buffer_sizes = self.mem_pool_host.get_buffer_meta(
hash_values,
host_indices,
self.storage_config.tp_rank,
)
get_result = self.storage_backend.batch_get(
key_strs,
target_locations=buffer_ptrs,
target_sizes=buffer_sizes,
)
if get_result != len(hash_values):
logger.warning(
f"Prefetch operation {operation.request_id} failed or partially failed."
)
if get_result != 0:
operation.increment(get_result * self.page_size)
# todo: deprecate
def _generic_page_get(self, operation, hash_values, host_indices):
dummy_page_dst = [
self.mem_pool_host.get_dummy_flat_data_page() for _ in hash_values
@@ -681,7 +755,7 @@ class HiCacheController:
batch_tokens[i : i + self.page_size], last_hash
)
batch_hashes.append(last_hash)
hit_page_num = self.storage_backend.batch_exists(batch_hashes)
hit_page_num = self.batch_exists_func(batch_hashes)
hash_value.extend(batch_hashes[:hit_page_num])
storage_query_count += hit_page_num * self.page_size
if hit_page_num < len(batch_hashes):
@@ -750,16 +824,34 @@ class HiCacheController:
self.backup_queue.put(operation)
return operation.id
# todo: deprecate
# non-zero copy
def _generic_page_set(self, hash_values, host_indices) -> bool:
data = [
self.mem_pool_host.get_data_page(host_indices[i * self.page_size])
self.mem_pool_host.get_flat_data_page(host_indices[i * self.page_size])
for i in range(len(hash_values))
]
return self.storage_backend.batch_set(hash_values, data)
def _page_set_zero_copy(self, hash_values, host_indices) -> bool:
return all(self.storage_backend.batch_set_v1(hash_values, host_indices))
# zero copy
def _mooncake_page_set(self, hash_values, host_indices) -> bool:
key_strs, buffer_ptrs, buffer_sizes = self.mem_pool_host.get_buffer_meta(
hash_values,
host_indices,
self.storage_config.tp_rank,
)
success = self.storage_backend.batch_set(
key_strs,
target_locations=buffer_ptrs,
target_sizes=buffer_sizes,
)
return success
# zero copy
def _3fs_zero_copy_page_set(self, hash_values, host_indices) -> bool:
hashes, dsts, _ = self.mem_pool_host.get_buffer_with_hash(
hash_values, host_indices
)
return self.storage_backend.batch_set(hashes, dsts)
# Backup batch by batch
def _page_backup(self, operation):
@@ -35,7 +35,6 @@ else:
Image = Any
# Parameters for a session
@dataclass
class SessionParams:
id: Optional[str] = None
@@ -133,23 +132,18 @@ class GenerateReqInput:
# Conversation id used for tracking requests
conversation_id: Optional[str] = None
# Label for the request
label: Optional[str] = None
# Priority for the request
priority: Optional[int] = None
# Extra key for classifying the request (e.g. cache_salt)
extra_key: Optional[Union[List[str], str]] = None
# Whether to disallow logging for this request (e.g. due to ZDR)
no_logs: bool = False
# For custom metric labels
custom_labels: Optional[Dict[str, str]] = None
# (Deprecated, please use custom_labels) Label for the request
label: Optional[str] = None
# (Internal) Whether to return bytes for image generation
# Image gen grpc migration
return_bytes: bool = False
# For customer metric labels
customer_labels: Optional[Dict[str, str]] = None
def contains_mm_input(self) -> bool:
return (
has_valid_data(self.image_data)
@@ -548,11 +542,8 @@ class GenerateReqInput:
self.data_parallel_rank if self.data_parallel_rank is not None else None
),
conversation_id=self.conversation_id,
priority=self.priority,
extra_key=self.extra_key,
no_logs=self.no_logs,
custom_labels=self.custom_labels,
label=self.label,
priority=self.priority,
return_bytes=self.return_bytes,
)
@@ -609,23 +600,18 @@ class TokenizedGenerateReqInput:
# For dp balance
dp_balance_id: int = -1
# Label for the request
label: Optional[str] = None
# Priority for the request
priority: Optional[int] = None
# Extra key for classifying the request (e.g. cache_salt)
extra_key: Optional[str] = None
# Whether to disallow logging for this request (e.g. due to ZDR)
no_logs: bool = False
# Image gen grpc migration
return_bytes: bool = False
# tracing context
trace_context: Optional[Dict] = None
# (Deprecated, please use custom_labels) Label for the request
label: Optional[str] = None
# (Internal) Whether to return bytes for image generation
return_bytes: bool = False
@dataclass
class BatchTokenizedGenerateReqInput:
@@ -507,7 +507,6 @@ def embed_mm_inputs(
Modality, Callable[[List[MultimodalDataItem]], torch.Tensor]
] = None,
placeholder_tokens: dict[Modality, List[int]] = None,
use_deepstack: bool = False,
) -> Optional[torch.Tensor]:
"""
Embed multimodal inputs and integrate them with text token embeddings.
@@ -523,7 +522,7 @@ def embed_mm_inputs(
Returns:
Combined embedding tensor with multimodal content integrated
"""
other_info = {}
if mm_inputs_list is None:
return None
@@ -533,7 +532,7 @@ def embed_mm_inputs(
for mm_inputs in mm_inputs_list:
item_flatten_list += [item for item in mm_inputs.mm_items if item is not None]
embeddings, masks, deepstack_embeddings = [], [], []
embeddings, masks = [], []
# 2. Get multimodal embedding separately
# Try get mm embedding if any
for modality in Modality.all():
@@ -579,12 +578,6 @@ def embed_mm_inputs(
extend_length=extend_seq_lens,
items_offset_list=items_offsets,
)
if use_deepstack and embedding is not None:
embedding, deepstack_embedding = (
multimodal_model.separate_deepstack_embeds(embedding)
)
deepstack_embeddings += [deepstack_embedding]
embeddings += [embedding]
masks += [mask]
@@ -598,37 +591,13 @@ def embed_mm_inputs(
inputs_embeds = input_embedding(input_ids)
# 4. scatter embeddings into input embedding
# deepstack embedding
if use_deepstack:
num_deepstack_embeddings = (
len(multimodal_model.deepstack_visual_indexes) if use_deepstack else 0
)
deepstack_embedding_shape = inputs_embeds.shape[:-1] + (
inputs_embeds.shape[-1] * num_deepstack_embeddings,
)
input_deepstack_embeds = torch.zeros(
deepstack_embedding_shape,
device=inputs_embeds.device,
dtype=inputs_embeds.dtype,
)
other_info["input_deepstack_embeds"] = input_deepstack_embeds
for i, embedding, mask in zip(range(len(embeddings)), embeddings, masks):
for embedding, mask in zip(embeddings, masks):
if embedding is None or mask is None:
continue
# in-place update
indices = torch.where(mask.squeeze(dim=-1))[0]
inputs_embeds[indices] = embedding.to(inputs_embeds.device, inputs_embeds.dtype)
if use_deepstack:
input_deepstack_embeds[indices] = deepstack_embeddings[i].to(
inputs_embeds.device, inputs_embeds.dtype
)
return inputs_embeds, other_info
return inputs_embeds
def general_mm_embed_routine(
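
The scatter step in the hunk above boils down to writing each modality's embedding rows at the positions selected by its boolean mask. A toy, self-contained illustration:

```python
import torch

hidden = 4
inputs_embeds = torch.zeros(6, hidden)               # text-token embeddings
mask = torch.tensor([0, 1, 1, 0, 1, 0], dtype=torch.bool)
mm_embedding = torch.ones(int(mask.sum()), hidden)   # one row per multimodal token

indices = torch.where(mask)[0]
inputs_embeds[indices] = mm_embedding.to(inputs_embeds.device, inputs_embeds.dtype)
```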
@@ -640,7 +609,6 @@ def general_mm_embed_routine(
Modality, Callable[[List[MultimodalDataItem]], torch.Tensor]
] = None,
placeholder_tokens: Optional[dict[Modality, List[int]]] = None,
use_deepstack: bool = False,
**kwargs,
) -> torch.Tensor:
"""
@@ -652,7 +620,6 @@ def general_mm_embed_routine(
language_model: Base language model to use
data_embedding_funcs: A dictionary mapping from modality type to the corresponding embedding function.
placeholder_tokens: Token IDs for multimodal placeholders
use_deepstack: Whether to use deepstack embeddings
**kwargs: Additional arguments passed to language model
Returns:
@@ -678,20 +645,16 @@ def general_mm_embed_routine(
for i, seq_len in enumerate(forward_batch.extend_seq_lens_cpu)
if forward_batch.mm_inputs[i] is not None
]
inputs_embeds, other_info = embed_mm_inputs(
inputs_embeds = embed_mm_inputs(
mm_inputs_list=mm_inputs_list,
extend_prefix_lens=extend_prefix_lens,
extend_seq_lens=extend_seq_lens,
input_ids=input_ids,
multimodal_model=multimodal_model,
input_embedding=embed_tokens,
multimodal_model=multimodal_model,
data_embedding_func_mapping=data_embedding_funcs,
placeholder_tokens=placeholder_tokens,
use_deepstack=use_deepstack,
)
# add for qwen3_vl deepstack
if use_deepstack:
kwargs["input_deepstack_embeds"] = other_info["input_deepstack_embeds"]
# once used, mm_inputs is useless, considering chunked-prefill is disabled for multimodal models
# just being defensive here
forward_batch.mm_inputs = None
@@ -12,7 +12,8 @@ logger = logging.getLogger(__name__)
PROCESSOR_MAPPING = {}
def import_processors(package_name: str):
def import_processors():
package_name = "sglang.srt.multimodal.processors"
package = importlib.import_module(package_name)
for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."):
if not ispkg:
import torch
from sglang.srt.managers.schedule_batch import ModelWorkerBatch
from sglang.srt.utils import get_compiler_backend
@torch.compile(dynamic=True, backend=get_compiler_backend())
def _resolve_future_token_ids(input_ids, future_token_ids_map):
input_ids[:] = torch.where(
input_ids < 0,
future_token_ids_map[torch.clamp(-input_ids, min=0)],
input_ids,
)
class FutureMap:
def __init__(
self,
max_running_requests: int,
device: torch.device,
):
self.future_ct = 0
# A factor of 3 is used to avoid collision in the circular buffer.
self.future_limit = max_running_requests * 3
# A factor of 5 is used to ensure the buffer is large enough.
self.future_buffer_len = max_running_requests * 5
self.device = device
self.token_ids_buf = torch.empty(
(self.future_buffer_len,), dtype=torch.int64, device=self.device
)
def update_ct(self, bs: int) -> int:
"""Update the circular buffer pointer and return the current pointer."""
cur_future_ct = self.future_ct
self.future_ct = (cur_future_ct + bs) % self.future_limit
return cur_future_ct
def resolve_future(self, model_worker_batch: ModelWorkerBatch):
input_ids = model_worker_batch.input_ids
_resolve_future_token_ids(input_ids, self.token_ids_buf)
def update_next_future(self, future_ct: int, bs: int):
return torch.arange(
-(future_ct + 1),
-(future_ct + 1 + bs),
-1,
dtype=torch.int64,
device=self.device,
)
def store_to_map(self, future_ct: int, bs: int, next_token_ids: torch.Tensor):
self.token_ids_buf[future_ct + 1 : future_ct + bs + 1] = next_token_ids
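
A short usage sketch of the `FutureMap` circular buffer defined above, with toy sizes and plain tensors in place of a real `ModelWorkerBatch` (assumes the definitions above are importable; `torch.compile` will trace `_resolve_future_token_ids` on first call):

```python
import torch

fm = FutureMap(max_running_requests=4, device=torch.device("cpu"))

bs = 2
ct = fm.update_ct(bs)                         # reserve bs slots, returns old pointer
placeholders = fm.update_next_future(ct, bs)  # negative ids: [-(ct+1), -(ct+2)]

# Later, once the real next tokens are known, store them ...
fm.store_to_map(ct, bs, torch.tensor([11, 22]))

# ... and any input_ids still holding the negative placeholders resolve in place.
input_ids = torch.cat([placeholders, torch.tensor([7])])
_resolve_future_token_ids(input_ids, fm.token_ids_buf)
print(input_ids)  # tensor([11, 22,  7])
```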
@@ -67,14 +67,14 @@ from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
from sglang.srt.metrics.collector import SchedulerMetricsCollector, TimeStats
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.sampling.sampling_params import DEFAULT_SAMPLING_SEED, SamplingParams
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import flatten_nested_list, support_triton
if TYPE_CHECKING:
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
from sglang.srt.speculative.ngram_utils import NgramVerifyInput
from sglang.srt.speculative.lookahead_utils import LookaheadVerifyInput
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
@@ -90,7 +90,6 @@ GLOBAL_SERVER_ARGS_KEYS = [
"disable_flashinfer_cutlass_moe_fp4_allgather",
"disable_radix_cache",
"enable_dp_lm_head",
"enable_fp32_lm_head",
"flashinfer_mxfp4_moe_precision",
"enable_flashinfer_allreduce_fusion",
"moe_dense_tp_size",
@@ -113,6 +112,8 @@ GLOBAL_SERVER_ARGS_KEYS = [
"enable_custom_logit_processor",
"disaggregation_mode",
"enable_deterministic_inference",
"nsa_prefill",
"nsa_decode",
]
# Put some global args for easy access
@@ -492,7 +493,7 @@ class Req:
self.custom_logit_processor = custom_logit_processor
self.return_hidden_states = return_hidden_states
# extra key for classifying the request (e.g. cache_salt)
# extra key for classifying the request (e.g. lora_id, cache_salt)
if lora_id is not None:
extra_key = (
extra_key or ""
@@ -608,8 +609,6 @@ class Req:
) = None
self.hidden_states: List[List[float]] = []
self.hidden_states_tensor = None # Note: use tensor instead of list to transfer hidden_states when PD + MTP
self.output_topk_p = None
self.output_topk_index = None
# Embedding (return values)
self.embedding = None
@@ -954,9 +953,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
# Speculative decoding
spec_algorithm: SpeculativeAlgorithm = None
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]] = (
None
)
spec_info: Optional[
Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
] = None
# Whether to return hidden states
return_hidden_states: bool = False
@@ -1609,7 +1608,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
if (
self.spec_algorithm.is_eagle()
or self.spec_algorithm.is_standalone()
or self.spec_algorithm.is_ngram()
or self.spec_algorithm.is_lookahead()
):
# if spec decoding is used, the decode batch is prepared inside
# `forward_batch_speculative_generation` after running draft models.
@@ -1736,14 +1735,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
self.sampling_info.filter_batch(keep_indices, keep_indices_device)
if self.spec_info:
if chunked_req_to_exclude is not None and len(chunked_req_to_exclude) > 0:
has_been_filtered = False
else:
has_been_filtered = True
self.spec_info.filter_batch(
new_indices=keep_indices_device,
has_been_filtered=has_been_filtered,
)
self.spec_info.filter_batch(keep_indices_device)
def merge_batch(self, other: "ScheduleBatch"):
# Penalizer orchestrator must be merged before Batch.reqs is merged. This is because
@@ -1992,9 +1984,9 @@ class ModelWorkerBatch:
# Speculative decoding
spec_algorithm: SpeculativeAlgorithm = None
spec_info: Optional[Union[EagleVerifyInput, EagleDraftInput, NgramVerifyInput]] = (
None
)
spec_info: Optional[
Union[EagleVerifyInput, EagleDraftInput, LookaheadVerifyInput]
] = None
# If set, the output of the batch contains the hidden states of the run.
capture_hidden_mode: CaptureHiddenMode = None
hicache_consumer_index: int = -1
@@ -318,6 +318,7 @@ class PrefillAdder:
new_token_ratio: float,
rem_input_tokens: int,
rem_chunk_tokens: Optional[int],
max_prefill_bs: Optional[int],
mixed_with_decode_tokens: int = 0,
priority_scheduling_preemption_threshold: int = 0,
):
@@ -358,6 +359,10 @@ class PrefillAdder:
priority_scheduling_preemption_threshold
)
self.max_prefill_bs = (
max_prefill_bs if max_prefill_bs is not None else 2147483647
)
def _get_running_request_total_token_offset(self, req: Req) -> int:
return (
min(
@@ -549,6 +554,9 @@ class PrefillAdder:
def add_one_req(
self, req: Req, has_chunked_req: bool, truncation_align_size: Optional[int]
):
if len(self.can_run_list) >= self.max_prefill_bs:
return AddReqResult.OTHER
if req.sampling_params.ignore_eos and getattr(self.tree_cache, "disable", True):
return self.add_one_req_ignore_eos(req, has_chunked_req)
@@ -44,9 +44,6 @@ from sglang.srt.disaggregation.decode import (
DecodeTransferQueue,
SchedulerDisaggregationDecodeMixin,
)
from sglang.srt.disaggregation.decode_kvcache_offload_manager import (
DecodeKVCacheOffloadManager,
)
from sglang.srt.disaggregation.prefill import (
PrefillBootstrapQueue,
SchedulerDisaggregationPrefillMixin,
@@ -262,7 +259,7 @@ class Scheduler(
self.enable_metrics_for_all_schedulers = (
server_args.enable_metrics_for_all_schedulers
)
self.enable_kv_cache_events = server_args.kv_events_config and tp_rank == 0
self.enable_kv_cache_events = server_args.kv_events_config is not None
self.stream_interval = server_args.stream_interval
self.spec_algorithm = SpeculativeAlgorithm.from_string(
server_args.speculative_algorithm
@@ -388,10 +385,10 @@ class Scheduler(
target_worker=self.tp_worker,
dp_rank=dp_rank,
)
elif self.spec_algorithm.is_ngram():
from sglang.srt.speculative.ngram_worker import NGRAMWorker
elif self.spec_algorithm.is_lookahead():
from sglang.srt.speculative.lookahead_worker import LOOKAHEADWorker
self.draft_worker = NGRAMWorker(
self.draft_worker = LOOKAHEADWorker(
gpu_id=gpu_id,
tp_rank=tp_rank,
moe_ep_rank=moe_ep_rank,
@@ -556,11 +553,9 @@ class Scheduler(
# Init metrics stats
self.init_metrics(tp_rank, pp_rank, dp_rank)
self.init_kv_events(server_args.kv_events_config)
self.init_dp_balance(dp_balance_meta)
if self.enable_kv_cache_events:
self.init_kv_events(server_args.kv_events_config)
# Init disaggregation
self.disaggregation_mode = DisaggregationMode(
self.server_args.disaggregation_mode
@@ -618,6 +613,8 @@ class Scheduler(
]
)
self.max_prefill_bs = server_args.max_prefill_bs
def init_deterministic_inference_config(self):
"""Initialize deterministic inference configuration for different attention backends."""
if not self.server_args.enable_deterministic_inference:
@@ -758,24 +755,6 @@ class Scheduler(
eviction_policy=server_args.radix_eviction_policy,
)
if (
server_args.disaggregation_mode == "decode"
and server_args.disaggregation_decode_enable_offload_kvcache
):
self.decode_offload_manager = DecodeKVCacheOffloadManager(
req_to_token_pool=self.req_to_token_pool,
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
tp_group=(
self.attn_tp_cpu_group
if self.server_args.enable_dp_attention
else self.tp_cpu_group
),
tree_cache=self.tree_cache,
server_args=self.server_args,
)
else:
self.decode_offload_manager = None
self.decode_mem_cache_buf_multiplier = (
1
if self.spec_algorithm.is_none()
@@ -806,7 +785,7 @@ class Scheduler(
self.disagg_metadata_buffers = MetadataBuffers(
buffer_size,
hidden_size=self.model_config.hf_text_config.hidden_size,
hidden_states_dtype=self.model_config.dtype,
dtype=self.model_config.dtype,
custom_mem_pool=self.token_to_kv_pool_allocator.get_kvcache().maybe_get_custom_mem_pool(),
)
@@ -826,7 +805,7 @@ class Scheduler(
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
draft_token_to_kv_pool=(
None
if self.draft_worker is None or self.spec_algorithm.is_ngram()
if self.draft_worker is None or self.spec_algorithm.is_lookahead()
else self.draft_worker.model_runner.token_to_kv_pool
),
req_to_metadata_buffer_idx_allocator=self.req_to_metadata_buffer_idx_allocator,
@@ -855,7 +834,7 @@ class Scheduler(
self.disagg_metadata_buffers = MetadataBuffers(
buffer_size,
hidden_size=self.model_config.hf_text_config.hidden_size,
hidden_states_dtype=self.model_config.dtype,
dtype=self.model_config.dtype,
custom_mem_pool=self.token_to_kv_pool_allocator.get_kvcache().maybe_get_custom_mem_pool(),
)
@@ -863,7 +842,7 @@ class Scheduler(
token_to_kv_pool=self.token_to_kv_pool_allocator.get_kvcache(),
draft_token_to_kv_pool=(
None
if self.draft_worker is None or self.spec_algorithm.is_ngram()
if self.draft_worker is None or self.spec_algorithm.is_lookahead()
else self.draft_worker.model_runner.token_to_kv_pool
),
req_to_metadata_buffer_idx_allocator=self.req_to_metadata_buffer_idx_allocator,
@@ -1832,6 +1811,7 @@ class Scheduler(
self.new_token_ratio,
self.max_prefill_tokens,
self.chunked_prefill_size,
self.max_prefill_bs,
running_bs if self.is_mixed_chunk else 0,
self.priority_scheduling_preemption_threshold,
)
@@ -250,13 +250,7 @@ class SchedulerOutputProcessorMixin:
req.check_finished()
if req.finished():
if self.server_args.disaggregation_decode_enable_offload_kvcache:
# Asynchronously offload KV cache; cache_finished_req will be called after Device->Host transfer completes
if not self.decode_offload_manager.offload_kv_cache(req):
self.tree_cache.cache_finished_req(req)
else:
self.tree_cache.cache_finished_req(req)
self.tree_cache.cache_finished_req(req)
req.time_stats.completion_time = time.time()
if req.return_logprob and batch.spec_algorithm.is_none():
@@ -97,7 +97,7 @@ class SchedulerProfilerMixin:
def start_profile(
self, stage: Optional[ForwardMode] = None
) -> ProfileReqOutput | None:
stage_str = f" for {stage.name}" if stage else ""
stage_str = f" for {stage.__str__()}" if stage else ""
logger.info(
f"Profiling starts{stage_str}. Traces will be saved to: {self.torch_profiler_output_dir} (with profile id: {self.profile_id})",
)
@@ -181,7 +181,7 @@ class SchedulerProfilerMixin:
if not Path(self.torch_profiler_output_dir).exists():
Path(self.torch_profiler_output_dir).mkdir(parents=True, exist_ok=True)
stage_suffix = f"-{stage.name}" if stage else ""
stage_suffix = f"-{stage.__str__()}" if stage else ""
logger.info("Stop profiling" + stage_suffix + "...")
if self.torch_profiler is not None:
self.torch_profiler.stop()
@@ -247,7 +247,7 @@ class SchedulerProfilerMixin:
if self.profiler_decode_ct == 0:
if self.profile_in_progress:
# force trace flush
self.stop_profile(stage=ForwardMode.EXTEND)
self.stop_profile(ForwardMode.EXTEND)
self.start_profile(batch.forward_mode)
self.profiler_decode_ct += 1
if self.profiler_decode_ct > self.profiler_target_decode_ct:
@@ -294,6 +294,6 @@ class SchedulerProfilerMixin:
recv_req.profile_by_stage,
recv_req.profile_id,
)
return self.start_profile()
return self.start_profile(True)
else:
return self.stop_profile()
@@ -185,7 +185,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
)
if self.model_config.is_multimodal:
import_processors("sglang.srt.multimodal.processors")
import_processors()
try:
_processor = get_processor(
server_args.tokenizer_path,
@@ -320,8 +320,8 @@ class TokenizerManager(TokenizerCommunicatorMixin):
"model_name": self.server_args.served_model_name,
# TODO: Add lora name/path in the future,
}
if server_args.tokenizer_metrics_allowed_custom_labels:
for label in server_args.tokenizer_metrics_allowed_custom_labels:
if server_args.tokenizer_metrics_allowed_customer_labels:
for label in server_args.tokenizer_metrics_allowed_customer_labels:
labels[label] = ""
self.metrics_collector = TokenizerMetricsCollector(
server_args=server_args,
@@ -750,7 +750,6 @@ class TokenizerManager(TokenizerCommunicatorMixin):
return_hidden_states=obj.return_hidden_states,
data_parallel_rank=obj.data_parallel_rank,
priority=obj.priority,
extra_key=obj.extra_key,
)
elif isinstance(obj, EmbeddingReqInput):
tokenized_obj = TokenizedEmbeddingReqInput(
@@ -1633,10 +1632,10 @@ class TokenizerManager(TokenizerCommunicatorMixin):
else 0
)
custom_labels = getattr(state.obj, "custom_labels", None)
customer_labels = getattr(state.obj, "customer_labels", None)
labels = (
{**self.metrics_collector.labels, **custom_labels}
if custom_labels
{**self.metrics_collector.labels, **customer_labels}
if customer_labels
else self.metrics_collector.labels
)
if (
@@ -91,6 +91,7 @@ class TpModelWorker:
else server_args.speculative_draft_model_revision
),
is_draft_model=is_draft_worker,
tp_rank=tp_rank,
)
self.model_runner = ModelRunner(
@@ -36,11 +36,10 @@ from sglang.srt.managers.io_struct import (
UpdateWeightsFromDistributedReqInput,
UpdateWeightsFromTensorReqInput,
)
from sglang.srt.managers.overlap_utils import FutureMap
from sglang.srt.managers.schedule_batch import ModelWorkerBatch
from sglang.srt.managers.tp_worker import TpModelWorker
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import DynamicGradMode
from sglang.srt.utils import DynamicGradMode, get_compiler_backend
from sglang.utils import get_exception_traceback
if TYPE_CHECKING:
@@ -49,6 +48,15 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
@torch.compile(dynamic=True, backend=get_compiler_backend())
def resolve_future_token_ids(input_ids, future_token_ids_map):
input_ids[:] = torch.where(
input_ids < 0,
future_token_ids_map[torch.clamp(-input_ids, min=0)],
input_ids,
)
class TpModelWorkerClient:
"""A tensor parallel model worker."""
@@ -71,7 +79,11 @@ class TpModelWorkerClient:
self.gpu_id = gpu_id
# Init future mappings
self.future_map = FutureMap(self.max_running_requests, self.device)
self.future_token_ids_ct = 0
self.future_token_ids_limit = self.max_running_requests * 3
self.future_token_ids_map = torch.empty(
(self.max_running_requests * 5,), dtype=torch.int64, device=self.device
)
# Launch threads
self.input_queue = Queue[Tuple[ModelWorkerBatch, int, torch.Event]]()
@@ -141,7 +153,7 @@ class TpModelWorkerClient:
batch_lists: List = [None] * 2
while True:
model_worker_batch, future_map_ct, sync_event = self.input_queue.get()
model_worker_batch, future_token_ids_ct, sync_event = self.input_queue.get()
if not model_worker_batch:
break
@@ -157,7 +169,8 @@ class TpModelWorkerClient:
copy_done = torch.get_device_module(self.device).Event()
# Resolve future tokens in the input
self.future_map.resolve_future(model_worker_batch)
input_ids = model_worker_batch.input_ids
resolve_future_token_ids(input_ids, self.future_token_ids_map)
# Run forward
logits_output, next_token_ids, can_run_cuda_graph = (
@@ -174,9 +187,9 @@ class TpModelWorkerClient:
if model_worker_batch.is_prefill_only:
# For prefill-only requests, create dummy token IDs on CPU
next_token_ids = torch.zeros(bs, dtype=torch.long)
# store the future indices into future map
self.future_map.store_to_map(future_map_ct, bs, next_token_ids)
self.future_token_ids_map[
future_token_ids_ct + 1 : future_token_ids_ct + bs + 1
] = next_token_ids
# Copy results to the CPU
if model_worker_batch.return_logprob:
@@ -242,14 +255,20 @@ class TpModelWorkerClient:
sync_event.record(self.scheduler_stream)
# Push a new batch to the queue
bs = len(model_worker_batch.seq_lens)
cur_future_map_ct = self.future_map.update_ct(bs)
self.input_queue.put((model_worker_batch, cur_future_map_ct, sync_event))
self.input_queue.put((model_worker_batch, self.future_token_ids_ct, sync_event))
# get this forward batch's future token ids
future_next_token_ids = self.future_map.update_next_future(
cur_future_map_ct, bs
# Allocate output future objects
bs = len(model_worker_batch.seq_lens)
future_next_token_ids = torch.arange(
-(self.future_token_ids_ct + 1),
-(self.future_token_ids_ct + 1 + bs),
-1,
dtype=torch.int64,
device=self.device,
)
self.future_token_ids_ct = (
self.future_token_ids_ct + bs
) % self.future_token_ids_limit
return None, future_next_token_ids, False
def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
@@ -79,37 +79,48 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
)
num_new_pages = (
(
(seq_lens + self.page_size - 1) // self.page_size
- (prefix_lens + self.page_size - 1) // self.page_size
)
.sum()
.item()
)
if self.need_sort and num_new_pages > len(self.free_pages):
(seq_lens + self.page_size - 1) // self.page_size
- (prefix_lens + self.page_size - 1) // self.page_size
).sum()
num_new_pages_item = num_new_pages.item()
if self.need_sort and num_new_pages_item > len(self.free_pages):
self.merge_and_sort_free()
if num_new_pages > len(self.free_pages):
if num_new_pages_item > len(self.free_pages):
return None
out_indices = torch.empty(
(extend_num_tokens,), dtype=torch.int32, device=self.device
(extend_num_tokens,), dtype=torch.int64, device=self.device
)
alloc_extend_kernel_ascend(
prefix_lens,
seq_lens,
last_loc,
self.free_pages,
out_indices,
self.page_size,
self.device,
)
if num_new_pages_item < 200:
import sgl_kernel_npu
torch.ops.npu.alloc_extend(
prefix_lens,
seq_lens,
last_loc,
self.free_pages,
self.page_size,
out_indices,
num_new_pages,
)
else:
alloc_extend_kernel_ascend(
prefix_lens,
seq_lens,
last_loc,
self.free_pages,
out_indices,
self.page_size,
self.device,
)
if self.debug_mode:
assert len(torch.unique(out_indices)) == len(out_indices)
self.free_pages = self.free_pages[num_new_pages:]
self.free_pages = self.free_pages[num_new_pages_item:]
return out_indices
def alloc_decode(
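
The page accounting in the hunk above is plain ceil-division per sequence: new pages = ceil(seq_len / page_size) - ceil(prefix_len / page_size), summed over the batch. A tiny worked example:

```python
import torch

page_size = 4
prefix_lens = torch.tensor([5, 8])   # tokens already backed by pages
seq_lens = torch.tensor([9, 13])     # tokens after the extension

pages_needed = (seq_lens + page_size - 1) // page_size      # [3, 4]
pages_cached = (prefix_lens + page_size - 1) // page_size   # [2, 2]
num_new_pages = (pages_needed - pages_cached).sum()
print(int(num_new_pages))  # 3
```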
@@ -7,8 +7,6 @@ from typing import Any, List, Optional
import torch
from sglang.srt.mem_cache.memory_pool_host import HostKVCache
logger = logging.getLogger(__name__)
@@ -34,46 +32,15 @@ class HiCacheStorageConfig:
extra_config: Optional[dict] = None
@dataclass
class HiCacheStorageExtraInfo:
extra_info: Optional[dict] = None
class HiCacheStorage(ABC):
"""
HiCacheStorage is a class that provides a generic key-value interface for storing and retrieving KV cache.
It abstracts the underlying storage mechanism, allowing different implementations to be used.
"""
# todo, potentially pass model and TP configs into storage backend
# todo, the page size of storage backend does not have to be the same as the same as host memory pool
def register_mem_pool_host(self, mem_pool_host: HostKVCache):
self.mem_pool_host = mem_pool_host
def batch_get_v1(
self,
keys: List[str],
host_indices: torch.Tensor,
extra_info: Optional[HiCacheStorageExtraInfo] = None,
) -> List[bool]:
"""
Retrieve values for multiple keys.
Returns a list of tensors or None for each key.
"""
pass
def batch_set_v1(
self,
keys: List[str],
host_indices: torch.Tensor,
extra_info: Optional[HiCacheStorageExtraInfo] = None,
) -> List[bool]:
"""
Retrieve values for multiple keys.
Returns a list of tensors or None for each key.
"""
pass
@abstractmethod
def get(
self,
@@ -87,7 +54,6 @@ class HiCacheStorage(ABC):
"""
pass
# TODO: Deprecate
@abstractmethod
def batch_get(
self,
@@ -115,7 +81,6 @@ class HiCacheStorage(ABC):
"""
pass
# TODO: Deprecate
@abstractmethod
def batch_set(
self,
@@ -138,7 +103,6 @@ class HiCacheStorage(ABC):
"""
pass
# TODO: Use a finer-grained return type (e.g., List[bool])
def batch_exists(self, keys: List[str]) -> int:
"""
Check if the keys exist in the storage.
@@ -150,9 +114,6 @@ class HiCacheStorage(ABC):
return i
return len(keys)
def clear(self) -> None:
pass
def get_stats(self):
return None
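
To make the key-value interface above concrete, here is a toy dict-backed backend with the same get/set/exists surface. It is an illustration only, not a real HiCache backend, and keeps everything in process memory:

```python
from typing import Dict, List, Optional

import torch

class InMemoryStorage:
    """Toy dict-backed key-value store for illustration."""

    def __init__(self) -> None:
        self._store: Dict[str, torch.Tensor] = {}

    def get(self, key: str, *args, **kwargs) -> Optional[torch.Tensor]:
        return self._store.get(key)

    def batch_get(self, keys: List[str], *args, **kwargs) -> List[Optional[torch.Tensor]]:
        return [self._store.get(k) for k in keys]

    def set(self, key: str, value: torch.Tensor, *args, **kwargs) -> bool:
        self._store[key] = value.clone()
        return True

    def batch_set(self, keys: List[str], values: List[torch.Tensor], *args, **kwargs) -> bool:
        return all(self.set(k, v) for k, v in zip(keys, values))

    def batch_exists(self, keys: List[str]) -> int:
        # Count of leading keys that are present, mirroring the default above.
        for i, key in enumerate(keys):
            if key not in self._store:
                return i
        return len(keys)
```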