"vscode:/vscode.git/clone" did not exist on "9fafa62db7d489f1b7910174373183cee9e65c3d"
Commit 852a49c5 authored by maxiao

adapt to dsv32 on dcu

parent 8f7453e3
......@@ -48,9 +48,9 @@ class HiRadixCache(RadixCache):
if hicache_io_backend == "direct":
if hicache_mem_layout == "page_first":
hicache_mem_layout = "page_first_direct"
hicache_mem_layout = "layer_first"
logger.warning(
"Page first layout is not supported with direct IO backend, switching to page first direct layout"
"Page first layout is not supported with direct IO backend, switching to layer first layout"
)
self.kv_cache = token_to_kv_pool_allocator.get_kvcache()
......@@ -305,7 +305,7 @@ class HiRadixCache(RadixCache):
def _evict_backuped(self, node: TreeNode):
# evict a node already written to host
num_evicted = self.cache_controller.evict_device(node.value)
num_evicted = self.cache_controller.evict_device(node.value, node.host_value)
assert num_evicted > 0
self.evictable_size_ -= num_evicted
node.value = None
......@@ -576,6 +576,8 @@ class HiRadixCache(RadixCache):
written_indices,
hash_value[: min_completed_tokens // self.page_size],
)
if len(written_indices):
self.cache_controller.mem_pool_host.update_prefetch(written_indices)
self.cache_controller.mem_pool_host.free(host_indices[:matched_length])
self.cache_controller.append_host_mem_release(
......@@ -773,6 +775,7 @@ class HiRadixCache(RadixCache):
# change the reference if the node is evicted
# this often happens in the case of KV cache recomputation
node.value = value[:prefix_len]
self.token_to_kv_pool_host.update_synced(node.host_value)
self.evictable_size_ += len(node.value)
else:
self._inc_hit_count(node, chunked)
......@@ -782,6 +785,7 @@ class HiRadixCache(RadixCache):
new_node = self._split_node(node.key, node, prefix_len)
if new_node.evicted:
new_node.value = value[:prefix_len]
self.token_to_kv_pool_host.update_synced(new_node.host_value)
self.evictable_size_ += len(new_node.value)
else:
self._inc_hit_count(new_node, chunked)
......
......@@ -15,6 +15,8 @@ limitations under the License.
from __future__ import annotations
from sglang.srt.layers.attention.nsa import index_buf_accessor
from sglang.srt.layers.attention.nsa.quant_k_cache import quantize_k_cache
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
"""
......@@ -37,6 +39,7 @@ import triton
import triton.language as tl
from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE
from sglang.srt.layers.attention.nsa.utils import NSA_KV_CACHE_STORE_FP8
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.utils import get_bool_env_var, is_cuda, is_npu, next_power_of_2
......@@ -1030,6 +1033,8 @@ class MLATokenToKVPool(KVCache):
enable_memory_saver: bool,
start_layer: Optional[int] = None,
end_layer: Optional[int] = None,
use_nsa: bool = False,
override_kv_cache_dim: Optional[int] = None,
):
super().__init__(
size,
......@@ -1044,6 +1049,16 @@ class MLATokenToKVPool(KVCache):
self.kv_lora_rank = kv_lora_rank
self.qk_rope_head_dim = qk_rope_head_dim
self.use_nsa = use_nsa
# TODO do not hardcode
self.kv_cache_dim = (
656
if use_nsa and NSA_KV_CACHE_STORE_FP8
else (kv_lora_rank + qk_rope_head_dim)
)
if use_nsa and NSA_KV_CACHE_STORE_FP8:
assert self.dtype == torch.float8_e4m3fn, f"{self.dtype=}"
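# A plausible breakdown of the hardcoded 656 (an assumption, not confirmed by this commit):
# with kv_lora_rank=512 and qk_rope_head_dim=64, an FP8 per-token layout would need
# 512 fp8 nope bytes + (512 / 128) * 4 scale bytes + 64 * 2 rope bytes = 512 + 16 + 128 = 656.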
# for disagg with nvlink
self.enable_custom_mem_pool = get_bool_env_var(
......@@ -1067,7 +1082,7 @@ class MLATokenToKVPool(KVCache):
# The padded slot 0 is used for writing dummy outputs from padded tokens.
self.kv_buffer = [
torch.zeros(
(size + page_size, 1, kv_lora_rank + qk_rope_head_dim),
(size + page_size, 1, self.kv_cache_dim),
dtype=self.store_dtype,
device=device,
)
......@@ -1130,6 +1145,7 @@ class MLATokenToKVPool(KVCache):
cache_v: torch.Tensor,
):
layer_id = layer.layer_id
assert not (self.use_nsa and NSA_KV_CACHE_STORE_FP8)
if cache_k.dtype != self.dtype:
cache_k = cache_k.to(self.dtype)
if self.store_dtype != self.dtype:
......@@ -1147,16 +1163,28 @@ class MLATokenToKVPool(KVCache):
cache_k_rope: torch.Tensor,
):
layer_id = layer.layer_id
if cache_k_nope.dtype != self.dtype:
cache_k_nope = cache_k_nope.to(self.dtype)
cache_k_rope = cache_k_rope.to(self.dtype)
if self.store_dtype != self.dtype:
cache_k_nope = cache_k_nope.view(self.store_dtype)
cache_k_rope = cache_k_rope.view(self.store_dtype)
set_mla_kv_buffer_triton(
self.kv_buffer[layer_id - self.start_layer], loc, cache_k_nope, cache_k_rope
)
if self.use_nsa and NSA_KV_CACHE_STORE_FP8:
# original cache_k shape: (num_tokens, num_heads=1, hidden=576); we unsqueeze the page_size=1 dim here
# TODO no need to cat
cache_k = torch.cat([cache_k_nope, cache_k_rope], dim=-1)
cache_k = quantize_k_cache(cache_k.unsqueeze(1)).squeeze(1)
cache_k = cache_k.view(self.store_dtype)
self.kv_buffer[layer_id - self.start_layer][loc] = cache_k
else:
if cache_k_nope.dtype != self.dtype:
cache_k_nope = cache_k_nope.to(self.dtype)
cache_k_rope = cache_k_rope.to(self.dtype)
if self.store_dtype != self.dtype:
cache_k_nope = cache_k_nope.view(self.store_dtype)
cache_k_rope = cache_k_rope.view(self.store_dtype)
set_mla_kv_buffer_triton(
self.kv_buffer[layer_id - self.start_layer],
loc,
cache_k_nope,
cache_k_rope,
)
def get_cpu_copy(self, indices):
torch.cuda.synchronize()
......@@ -1186,6 +1214,103 @@ class MLATokenToKVPool(KVCache):
torch.cuda.synchronize()
class NSATokenToKVPool(MLATokenToKVPool):
def __init__(
self,
size: int,
page_size: int,
kv_lora_rank: int,
dtype: torch.dtype,
qk_rope_head_dim: int,
layer_num: int,
device: str,
index_head_dim: int,
enable_memory_saver: bool,
start_layer: Optional[int] = None,
end_layer: Optional[int] = None,
):
super().__init__(
size,
page_size,
dtype,
kv_lora_rank,
qk_rope_head_dim,
layer_num,
device,
enable_memory_saver,
start_layer,
end_layer,
use_nsa=True,
)
# self.index_k_dtype = torch.float8_e4m3fn
# self.index_k_scale_dtype = torch.float32
self.index_head_dim = index_head_dim
# num head == 1 and head dim == 128 for index_k in NSA
assert index_head_dim == 128
self.quant_block_size = 128
assert self.page_size == 64
self.index_k_with_scale_buffer = [
torch.zeros(
# Layout:
# ref: test_attention.py :: kv_cache_cast_to_fp8
# shape: (num_pages, page_size 64 * head_dim 128 + page_size 64 * fp32_nbytes 4)
# data: for page i,
# * buf[i, :page_size * head_dim] for fp8 data
# * buf[i, page_size * head_dim:].view(float32) for scale
(
(size + page_size + 1) // self.page_size,
self.page_size
* (index_head_dim + index_head_dim // self.quant_block_size * 4),
),
dtype=torch.uint8,
device=device,
)
for _ in range(layer_num)
]
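# Illustration of the page layout documented above (a sketch assuming page_size=64,
# index_head_dim=128, quant_block_size=128, i.e. one float32 scale per token):
#   row i holds 64 * (128 + 4) = 8448 bytes
#   fp8 data : buf[i, : 64 * 128].view(torch.float8_e4m3fn).reshape(64, 128)
#   scales   : buf[i, 64 * 128 :].view(torch.float32).reshape(64)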
def get_index_k_with_scale_buffer(self, layer_id: int) -> torch.Tensor:
if self.layer_transfer_counter is not None:
self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
return self.index_k_with_scale_buffer[layer_id - self.start_layer]
def get_index_k_continuous(
self,
layer_id: int,
seq_len: int,
page_indices: torch.Tensor,
):
buf = self.index_k_with_scale_buffer[layer_id - self.start_layer]
return index_buf_accessor.GetK.execute(
self, buf, seq_len=seq_len, page_indices=page_indices
)
def get_index_k_scale_continuous(
self,
layer_id: int,
seq_len: int,
page_indices: torch.Tensor,
):
buf = self.index_k_with_scale_buffer[layer_id - self.start_layer]
return index_buf_accessor.GetS.execute(
self, buf, seq_len=seq_len, page_indices=page_indices
)
# TODO rename later (currently use diff name to avoid confusion)
def set_index_k_and_scale_buffer(
self,
layer_id: int,
loc: torch.Tensor,
index_k: torch.Tensor,
index_k_scale: torch.Tensor,
) -> None:
buf = self.index_k_with_scale_buffer[layer_id - self.start_layer]
index_buf_accessor.SetKAndS.execute(
pool=self, buf=buf, loc=loc, index_k=index_k, index_k_scale=index_k_scale
)
class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
def __init__(
self,
......@@ -1194,6 +1319,7 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
dtype: torch.dtype,
kv_lora_rank: int,
qk_rope_head_dim: int,
index_head_dim: Optional[int],
layer_num: int,
device: str,
enable_memory_saver: bool,
......@@ -1213,6 +1339,7 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
self.kv_lora_rank = kv_lora_rank
self.qk_rope_head_dim = qk_rope_head_dim
self.index_head_dim = index_head_dim
self.custom_mem_pool = None
......@@ -1240,6 +1367,18 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
dtype=self.store_dtype,
device=self.device,
)
if self.index_head_dim is not None:
self.index_k_buffer = torch.zeros(
(
layer_num,
self.size // self.page_size + 1,
self.page_size,
1,
self.index_head_dim,
),
dtype=self.store_dtype,
device=self.device,
)
self._finalize_allocation_log(size)
......@@ -1251,6 +1390,10 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
kv_size_bytes += get_tensor_size_bytes(k_cache)
for v_cache in self.v_buffer:
kv_size_bytes += get_tensor_size_bytes(v_cache)
if self.index_head_dim is not None:
assert hasattr(self, "index_k_buffer")
for index_k_cache in self.index_k_buffer:
kv_size_bytes += get_tensor_size_bytes(index_k_cache)
return kv_size_bytes
def get_kv_buffer(self, layer_id: int):
......@@ -1277,6 +1420,14 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
return self.v_buffer[layer_id - self.start_layer].view(self.dtype)
return self.v_buffer[layer_id - self.start_layer]
def get_index_k_buffer(self, layer_id: int):
if self.layer_transfer_counter is not None:
self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
if self.store_dtype != self.dtype:
return self.index_k_buffer[layer_id - self.start_layer].view(self.dtype)
return self.index_k_buffer[layer_id - self.start_layer]
# for disagg
def get_contiguous_buf_infos(self):
# MLA has only one kv_buffer, so only the information of this buffer needs to be returned.
......@@ -1289,6 +1440,16 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
kv_item_lens = [self.k_buffer[i][0].nbytes for i in range(self.layer_num)] + [
self.v_buffer[i][0].nbytes for i in range(self.layer_num)
]
if self.index_head_dim is not None:
kv_data_ptrs += [
self.index_k_buffer[i].data_ptr() for i in range(self.layer_num)
]
kv_data_lens += [
self.index_k_buffer[i].nbytes for i in range(self.layer_num)
]
kv_item_lens += [
self.index_k_buffer[i][0].nbytes for i in range(self.layer_num)
]
return kv_data_ptrs, kv_data_lens, kv_item_lens
def set_kv_buffer(
......@@ -1325,6 +1486,26 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
cache_v.view(-1, 1, self.qk_rope_head_dim),
)
def set_index_k_buffer(
self,
layer_id: int,
loc: torch.Tensor,
index_k: torch.Tensor,
):
if index_k.dtype != self.dtype:
index_k = index_k.to(self.dtype)
if self.store_dtype != self.dtype:
index_k = index_k.view(self.store_dtype)
torch_npu.npu_scatter_nd_update_(
self.index_k_buffer[layer_id - self.start_layer].view(
-1, 1, self.index_head_dim
),
loc.view(-1, 1),
index_k.view(-1, 1, self.index_head_dim),
)
class DoubleSparseTokenToKVPool(KVCache):
def __init__(
......
......@@ -31,13 +31,27 @@ if not (_is_npu or _is_xpu):
logger = logging.getLogger(__name__)
def synchronized(func):
@wraps(func)
def wrapper(self, *args, **kwargs):
with self.lock:
return func(self, *args, **kwargs)
class MemoryStateInt(IntEnum):
IDLE = 0
RESERVED = 1
PROTECTED = 2
SYNCED = 3
BACKUP = 4
def synchronized(debug_only=False):
def _decorator(func):
@wraps(func)
def wrapper(self, *args, **kwargs):
if (not debug_only) or self.debug:
with self.lock:
return func(self, *args, **kwargs)
else:
return True
return wrapper
return wrapper
return _decorator
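# Usage note: `@synchronized()` always acquires `self.lock`, while
# `@synchronized(debug_only=True)` only locks and runs the wrapped method when
# `self.debug` is set; otherwise the call is skipped and returns True, so the
# MemoryStateInt bookkeeping below is only active when debug logging is enabled.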
class HostKVCache(abc.ABC):
......@@ -96,6 +110,7 @@ class HostKVCache(abc.ABC):
# A lock for synchronized operations on memory allocation and state transitions.
self.lock = threading.RLock()
self.debug = logger.isEnabledFor(logging.DEBUG)
self.clear()
@abc.abstractmethod
......@@ -125,7 +140,7 @@ class HostKVCache(abc.ABC):
raise NotImplementedError()
@abc.abstractmethod
def get_data_page(self, index, flat: bool = True) -> torch.Tensor:
def get_flat_data_page(self, index) -> torch.Tensor:
"""
Get a flat data page from the host memory pool.
"""
......@@ -146,7 +161,7 @@ class HostKVCache(abc.ABC):
"""
raise NotImplementedError()
@synchronized
@synchronized()
def clear(self):
# Initialize memory states and tracking structures.
self.mem_state = torch.zeros(
......@@ -157,7 +172,7 @@ class HostKVCache(abc.ABC):
def available_size(self):
return len(self.free_slots)
@synchronized
@synchronized()
def alloc(self, need_size: int) -> Optional[torch.Tensor]:
assert (
need_size % self.page_size == 0
......@@ -168,13 +183,92 @@ class HostKVCache(abc.ABC):
select_index = self.free_slots[:need_size]
self.free_slots = self.free_slots[need_size:]
if self.debug:
self.mem_state[select_index] = MemoryStateInt.RESERVED
return select_index
@synchronized
@synchronized()
def free(self, indices: torch.Tensor) -> int:
self.free_slots = torch.cat([self.free_slots, indices])
if self.debug:
self.mem_state[indices] = MemoryStateInt.IDLE
return len(indices)
@synchronized(debug_only=True)
def get_state(self, indices: torch.Tensor) -> MemoryStateInt:
assert len(indices) > 0, "The indices should not be empty"
states = self.mem_state[indices]
assert (
states == states[0]
).all(), "The memory slots should have the same state {}".format(states)
return MemoryStateInt(states[0].item())
@synchronized(debug_only=True)
def is_reserved(self, indices: torch.Tensor) -> bool:
return self.get_state(indices) == MemoryStateInt.RESERVED
@synchronized(debug_only=True)
def is_protected(self, indices: torch.Tensor) -> bool:
return self.get_state(indices) == MemoryStateInt.PROTECTED
@synchronized(debug_only=True)
def is_synced(self, indices: torch.Tensor) -> bool:
return self.get_state(indices) == MemoryStateInt.SYNCED
@synchronized(debug_only=True)
def is_backup(self, indices: torch.Tensor) -> bool:
return self.get_state(indices) == MemoryStateInt.BACKUP
@synchronized(debug_only=True)
def update_backup(self, indices: torch.Tensor):
if not self.is_synced(indices):
raise ValueError(
f"The host memory slots should be in SYNCED state before turning into BACKUP. "
f"Current state: {self.get_state(indices)}"
)
self.mem_state[indices] = MemoryStateInt.BACKUP
@synchronized(debug_only=True)
def update_prefetch(self, indices: torch.Tensor):
if not self.is_reserved(indices):
raise ValueError(
f"The host memory slots should be in RESERVED state before turning into BACKUP. "
f"Current state: {self.get_state(indices)}"
)
self.mem_state[indices] = MemoryStateInt.BACKUP
@synchronized(debug_only=True)
def update_synced(self, indices: torch.Tensor):
self.mem_state[indices] = MemoryStateInt.SYNCED
@synchronized(debug_only=True)
def protect_write(self, indices: torch.Tensor):
if not self.is_reserved(indices):
raise ValueError(
f"The host memory slots should be RESERVED before write operations. "
f"Current state: {self.get_state(indices)}"
)
self.mem_state[indices] = MemoryStateInt.PROTECTED
@synchronized(debug_only=True)
def protect_load(self, indices: torch.Tensor):
if not self.is_backup(indices):
raise ValueError(
f"The host memory slots should be in BACKUP state before load operations. "
f"Current state: {self.get_state(indices)}"
)
self.mem_state[indices] = MemoryStateInt.PROTECTED
@synchronized(debug_only=True)
def complete_io(self, indices: torch.Tensor):
if not self.is_protected(indices):
raise ValueError(
f"The host memory slots should be PROTECTED during I/O operations. "
f"Current state: {self.get_state(indices)}"
)
self.mem_state[indices] = MemoryStateInt.SYNCED
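# State-transition summary of the checks above (derived from these methods):
#   IDLE --alloc()--> RESERVED --protect_write()--> PROTECTED --complete_io()--> SYNCED
#   SYNCED --update_backup()--> BACKUP --protect_load()--> PROTECTED --complete_io()--> SYNCED
#   RESERVED --update_prefetch()--> BACKUP   (prefetched pages are already backed up)
#   free() returns any slot to IDLE; update_synced() forces SYNCED without a precondition check.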
class MHATokenToKVPoolHost(HostKVCache):
device_pool: MHATokenToKVPool
......@@ -367,19 +461,16 @@ class MHATokenToKVPoolHost(HostKVCache):
else:
raise ValueError(f"Unsupported IO backend: {io_backend}")
def get_data_page(self, index, flat: bool = True) -> torch.Tensor:
def get_flat_data_page(self, index) -> torch.Tensor:
if self.layout == "layer_first":
data_page = self.kv_buffer[:, :, index : index + self.page_size, :, :]
return self.kv_buffer[:, :, index : index + self.page_size, :, :].flatten()
elif self.layout == "page_first":
data_page = self.kv_buffer[:, index : index + self.page_size, :, :, :]
return self.kv_buffer[:, index : index + self.page_size, :, :, :].flatten()
elif self.layout == "page_first_direct":
real_index = index // self.page_size
data_page = self.kv_buffer[:, real_index : real_index + 1, :, :, :, :]
return self.kv_buffer[:, real_index : real_index + 1, :, :, :, :].flatten()
else:
raise ValueError(f"Unsupported layout: {self.layout}")
if flat:
data_page = data_page.flatten()
return data_page
def get_dummy_flat_data_page(self) -> torch.Tensor:
return torch.zeros(
......@@ -416,12 +507,9 @@ class MHATokenToKVPoolHost(HostKVCache):
else:
raise ValueError(f"Unsupported layout: {self.layout}")
def get_page_buffer_meta(self, indices):
""" "
meta data for zero copy
"""
assert len(indices) % self.page_size == 0
def get_buffer_meta(self, keys, indices, local_rank):
ptr_list = []
key_list = []
kv_buffer_data_ptr = self.kv_buffer.data_ptr()
indices = indices.tolist()
v_offset = (
......@@ -431,52 +519,48 @@ class MHATokenToKVPoolHost(HostKVCache):
* self.head_dim
* self.dtype.itemsize
)
if self.layout == "layer_first":
for index in range(0, len(indices), self.page_size):
for layer_id in range(self.layer_num):
k_ptr = (
kv_buffer_data_ptr
+ indices[index]
* self.head_num
* self.head_dim
* self.dtype.itemsize
+ layer_id
* self.size
* self.head_num
* self.head_dim
* self.dtype.itemsize
)
v_ptr = k_ptr + v_offset
ptr_list.append(k_ptr)
ptr_list.append(v_ptr)
element_size = (
self.dtype.itemsize * self.page_size * self.head_num * self.head_dim
)
element_size_list = [element_size] * len(ptr_list)
elif self.layout in ["page_first", "page_first_direct"]:
for index in range(0, len(indices), self.page_size):
k_ptr = (
kv_buffer_data_ptr
+ indices[index]
* self.layer_num
* self.head_num
* self.head_dim
* self.dtype.itemsize
)
v_ptr = k_ptr + v_offset
ptr_list.append(k_ptr)
ptr_list.append(v_ptr)
element_size = (
self.layer_num
* self.dtype.itemsize
* self.page_size
for index in range(0, len(indices), self.page_size):
k_ptr = (
kv_buffer_data_ptr
+ indices[index]
* self.layer_num
* self.head_num
* self.head_dim
* self.dtype.itemsize
)
element_size_list = [element_size] * len(ptr_list)
else:
raise ValueError(f"Unsupported layout: {self.layout}")
return ptr_list, element_size_list
v_ptr = k_ptr + v_offset
ptr_list.append(k_ptr)
ptr_list.append(v_ptr)
key_ = keys[index // self.page_size]
key_list.append(f"{key_}_{local_rank}_k")
key_list.append(f"{key_}_{local_rank}_v")
element_size = (
self.layer_num
* self.dtype.itemsize
* self.page_size
* self.head_num
* self.head_dim
)
element_size_list = [element_size] * len(key_list)
return key_list, ptr_list, element_size_list
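# Worked example with illustrative numbers (layer_num=64, fp16, page_size=64,
# head_num=1, head_dim=128): each returned K pointer covers
# 64 * 2 * 64 * 1 * 128 = 1,048,576 bytes (1 MiB) per page, with the matching V
# buffer at k_ptr + v_offset and keys of the form "<hash>_<rank>_k" / "<hash>_<rank>_v".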
def get_buffer_with_hash(self, keys, indices=None):
assert self.layout == "page_first"
assert indices is None or (len(keys) == (len(indices) // self.page_size))
key_list = []
buf_list = []
for i in range(len(keys)):
key = keys[i]
key_list.append(f"{key}-k")
key_list.append(f"{key}-v")
if indices is not None:
index = indices[i * self.page_size]
buf_list.append(self.k_buffer[index : index + self.page_size])
buf_list.append(self.v_buffer[index : index + self.page_size])
return key_list, buf_list, 2
class MLATokenToKVPoolHost(HostKVCache):
......@@ -652,19 +736,16 @@ class MLATokenToKVPoolHost(HostKVCache):
else:
raise ValueError(f"Unsupported IO backend: {io_backend}")
def get_data_page(self, index, flat: bool = True) -> torch.Tensor:
def get_flat_data_page(self, index) -> torch.Tensor:
if self.layout == "layer_first":
data_page = self.kv_buffer[:, index : index + self.page_size, :, :]
return self.kv_buffer[:, index : index + self.page_size, :, :].flatten()
elif self.layout == "page_first":
data_page = self.kv_buffer[index : index + self.page_size, :, :, :]
return self.kv_buffer[index : index + self.page_size, :, :, :].flatten()
elif self.layout == "page_first_direct":
real_index = index // self.page_size
data_page = self.kv_buffer[real_index : real_index + 1, :, :, :, :]
return self.kv_buffer[real_index : real_index + 1, :, :, :, :].flatten()
else:
raise ValueError(f"Unsupported layout: {self.layout}")
if flat:
data_page = data_page.flatten()
return data_page
def get_dummy_flat_data_page(self) -> torch.Tensor:
return torch.zeros(
......@@ -706,51 +787,40 @@ class MLATokenToKVPoolHost(HostKVCache):
else:
raise ValueError(f"Unsupported layout: {self.layout}")
def get_page_buffer_meta(self, indices):
""" "
meta data for zero copy
"""
assert len(indices) % self.page_size == 0
def get_buffer_meta(self, keys, indices, local_rank):
ptr_list = []
key_list = []
kv_buffer_data_ptr = self.kv_buffer.data_ptr()
indices = indices.tolist()
if self.layout == "layer_first":
for index in range(0, len(indices), self.page_size):
for layer_id in range(self.layer_num):
k_ptr = (
kv_buffer_data_ptr
+ indices[index]
* (self.kv_lora_rank + self.qk_rope_head_dim)
* self.dtype.itemsize
+ layer_id
* self.size
* (self.kv_lora_rank + self.qk_rope_head_dim)
* self.dtype.itemsize
)
ptr_list.append(k_ptr)
element_size = (
self.dtype.itemsize
* self.page_size
for index in range(0, len(indices), self.page_size):
k_ptr = (
kv_buffer_data_ptr
+ indices[index]
* self.layer_num
* (self.kv_lora_rank + self.qk_rope_head_dim)
)
element_size_list = [element_size] * len(ptr_list)
elif self.layout in ["page_first", "page_first_direct"]:
for index in range(0, len(indices), self.page_size):
k_ptr = (
kv_buffer_data_ptr
+ indices[index]
* self.layer_num
* (self.kv_lora_rank + self.qk_rope_head_dim)
* self.dtype.itemsize
)
ptr_list.append(k_ptr)
element_size = (
self.layer_num
* self.dtype.itemsize
* self.page_size
* (self.kv_lora_rank + self.qk_rope_head_dim)
)
element_size_list = [element_size] * len(ptr_list)
else:
raise ValueError(f"Unsupported layout: {self.layout}")
return ptr_list, element_size_list
ptr_list.append(k_ptr)
key_ = keys[index // self.page_size]
key_list.append(f"{key_}_k")
element_size = (
self.layer_num
* self.dtype.itemsize
* self.page_size
* (self.kv_lora_rank + self.qk_rope_head_dim)
)
element_size_list = [element_size] * len(key_list)
return key_list, ptr_list, element_size_list
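# Analogous worked example for MLA (illustrative numbers: layer_num=64, fp16,
# page_size=64, kv_lora_rank=512, qk_rope_head_dim=64): each page spans
# 64 * 2 * 64 * 576 = 4,718,592 bytes, exposed under a single "<hash>_k" key.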
def get_buffer_with_hash(self, keys, indices=None):
assert self.layout == "page_first"
assert indices is None or (len(keys) == (len(indices) // self.page_size))
buf_list = []
if indices is not None:
for i in range(len(keys)):
index = indices[i * self.page_size]
buf_list.append(self.kv_buffer[index : index + self.page_size])
return keys, buf_list, 1
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to SGLang project
"""Storage backend module for SGLang HiCache."""
from .backend_factory import StorageBackendFactory
__all__ = [
"StorageBackendFactory",
]
# AIBrix KVCache as L3 KV Cache
This document provides brief instructions for setting up an AIBrix KVCache storage backend and an SGLang runtime environment from scratch, describing how to use AIBrix KVCache as the L3 KV cache for SGLang.
The process consists of three main steps:
## Step 1: Install AIBrix KVCache
Refer to the [AIBrix KVCache documentation](https://github.com/vllm-project/aibrix/blob/main/python/aibrix_kvcache/README.md) to install AIBrix KVCache.
## Step 2: Deploy AIBrix Distributed KVCache Storage
AIBrix KVCache currently supports multiple distributed KVCache backends, including ByteDance's open-source InfiniStore and PrisKV, a not-yet-open-source backend incubated by ByteDance's PrisDB, IAAS, and DMI teams.
For the InfiniStore installation process, please refer to [this link](https://github.com/bytedance/InfiniStore).
PrisKV for AIBrix KVCache is currently in the open-source preparation stage, and no public documentation is available yet.
## Step 3: Deploy Model Serving
For information on configuring a distributed KVCache backend for AIBrix KVCache, please refer to [this link](https://aibrix.readthedocs.io/latest/designs/aibrix-kvcache-offloading-framework.html).
Using PrisKV as an example, the startup command is as follows:
```bash
export AIBRIX_KV_CACHE_OL_L1_CACHE_ENABLED="0"
export AIBRIX_KV_CACHE_OL_L2_CACHE_BACKEND="PRIS"
export AIBRIX_KV_CACHE_OL_PRIS_REMOTE_ADDR="127.0.0.1"
export AIBRIX_KV_CACHE_OL_PRIS_REMOTE_PORT="6379"
export AIBRIX_KV_CACHE_OL_PRIS_PASSWORD="kvcache-redis"
MODEL_LENGTH=32768 NCCL_MIN_NCHANNELS=24 NCCL_IB_QPS_PER_CONNECTION=8 NCCL_DEBUG=INFO \
python3 -m sglang.launch_server \
--model-path /code/models/Qwen3-32B \
--host 0.0.0.0 --port 8080 \
--enable-hierarchical-cache \
--hicache-storage-backend aibrix \
--page-size 16 \
--hicache-write-policy write_back \
--enable-metrics --hicache-ratio=2
```
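Once the server is up, the deployment can be verified with a simple request. A minimal example using SGLang's native `/generate` endpoint is shown below; the host, port, and payload are illustrative and should be adjusted to your deployment:
```bash
curl -s http://127.0.0.1:8080/generate \
  -H "Content-Type: application/json" \
  -d '{"text": "The capital of France is", "sampling_params": {"max_new_tokens": 16}}'
```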
import logging
from typing import Any, List, Optional
import torch
from aibrix_kvcache import (
BaseKVCacheManager,
BlockHashes,
KVCacheBlockLayout,
KVCacheBlockSpec,
KVCacheConfig,
KVCacheTensorSpec,
ModelSpec,
)
from aibrix_kvcache.common.absl_logging import log_every_n_seconds
from sglang.srt.mem_cache.hicache_storage import HiCacheStorage, HiCacheStorageConfig
from sglang.srt.mem_cache.memory_pool_host import HostKVCache
logger = logging.getLogger(__name__)
class AibrixKVCacheStorage(HiCacheStorage):
def __init__(self, storage_config: HiCacheStorageConfig, mem_pool: HostKVCache):
if storage_config is not None:
self.is_mla_backend = storage_config.is_mla_model
self.local_rank = storage_config.tp_rank
else:
self.is_mla_backend = False
self.local_rank = 0
kv_cache = mem_pool.device_pool
self.page_size = mem_pool.page_size
self.kv_cache_dtype = kv_cache.dtype
self.layer_num = kv_cache.layer_num
self.kv_head_ids = [
self.local_rank * kv_cache.head_num + i for i in range(kv_cache.head_num)
]
if not self.is_mla_backend:
self.layer_ids = range(
kv_cache.start_layer, kv_cache.end_layer
) # for pipeline parallel
self.block_spec = KVCacheBlockSpec(
block_ntokens=self.page_size,
block_dtype=self.kv_cache_dtype,
block_layout=KVCacheBlockLayout(KVCacheBlockLayout.NCLD),
tensor_spec=KVCacheTensorSpec(
heads=self.kv_head_ids,
layers=self.layer_ids,
head_size=kv_cache.head_dim,
),
)
logger.info(self.block_spec)
config = KVCacheConfig(
block_spec=self.block_spec, model_spec=ModelSpec(102400)
)
self.kv_cache_manager = BaseKVCacheManager(config)
else:
raise NotImplementedError(
"MLA is not supported by AibrixKVCacheStorage yet."
)
def _aibrix_kvcache_metrics_report(self):
self.kv_cache_manager.metrics.summary()
self.kv_cache_manager.metrics.reset()
def batch_get(
self,
keys: List[str],
target_locations: List[torch.Tensor],
target_sizes: Optional[Any] = None,
) -> List[torch.Tensor | None]:
block_hash = BlockHashes(keys, self.page_size)
status = self.kv_cache_manager.acquire(None, block_hash)
log_every_n_seconds(
logger, logging.INFO, self._aibrix_kvcache_metrics_report(), 1
)
if status.is_ok():
num_fetched_tokens, handle = status.value
kv_blocks = handle.to_tensors()
assert len(kv_blocks) == len(target_locations)
for i in range(len(kv_blocks)):
assert (
target_locations[i].nbytes == kv_blocks[i].nbytes
), f"{target_locations[i].nbytes}, {kv_blocks[i].nbytes}"
target_locations[i].copy_(kv_blocks[i].flatten())
handle.release()
return target_locations
return [None] * len(keys)
def get(
self,
key: str,
target_location: Optional[Any] = None,
target_size: Optional[Any] = None,
) -> torch.Tensor | None:
return self.batch_get([key], [target_location], [target_size])[0]
def batch_set(
self,
keys: List[str],
values: Optional[Any] = None,
target_locations: Optional[Any] = None,
target_sizes: Optional[Any] = None,
) -> bool:
block_hash = BlockHashes(keys, self.page_size)
status = self.kv_cache_manager.allocate_for(None, block_hash)
if not status.is_ok():
logger.warning(
f"aibrix_kvcache set allocate failed, error_code {status.error_code}"
)
return False
handle = status.value
tensors = handle.to_tensors()
if len(tensors) != len(values):
logger.warning("aibrix_kvcache set allocate not enough")
return False
for i in range(len(tensors)):
assert (
tensors[i].nbytes == values[i].nbytes
), f"{tensors[i].nbytes}, {values[i].nbytes}"
tensors[i].reshape(values[i].shape).copy_(values[i]).reshape(
tensors[i].shape
)
status = self.kv_cache_manager.put(None, block_hash, handle)
if not status.is_ok():
logger.info(
f"AIBrix KVCache Storage set failed, error_code {status.error_code}"
)
return False
completed = status.value
return completed == len(keys) * self.page_size
def set(
self,
key: str,
value: Optional[Any] = None,
target_location: Optional[Any] = None,
target_size: Optional[Any] = None,
) -> bool:
return self.batch_set([key], [value], [target_location], [target_size])
def batch_exists(self, keys: List[str]) -> int:
block_hash = BlockHashes(keys, self.page_size)
status = self.kv_cache_manager.exists(None, block_hash)
if status.is_ok():
return status.value // self.page_size
return 0
def exists(self, key: str) -> bool | dict:
return self.batch_exists([key]) > 0
import logging
import os
import torch
import torch.distributed
from aibrix_kvcache import (
BaseKVCacheManager,
GroupAwareKVCacheManager,
KVCacheBlockLayout,
KVCacheBlockSpec,
KVCacheConfig,
KVCacheMetrics,
KVCacheTensorSpec,
ModelSpec,
TokenListView,
)
from aibrix_kvcache.common.absl_logging import getLogger, log_every_n_seconds, log_if
from aibrix_kvcache_storage import AibrixKVCacheStorage
from torch.distributed import Backend, ProcessGroup
from sglang.srt.mem_cache.hicache_storage import HiCacheStorageConfig
from sglang.srt.mem_cache.memory_pool import MHATokenToKVPool
from sglang.srt.mem_cache.memory_pool_host import MHATokenToKVPoolHost
logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)
def setup():
os.environ["RANK"] = "0"
os.environ["WORLD_SIZE"] = "1"
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = "63886"
class AIBrixKVCacheStorageTest:
def test_with_page_size(self):
config = HiCacheStorageConfig(
tp_rank=0,
tp_size=1,
is_mla_model=False,
is_page_first_layout=True,
model_name="test",
)
for page_size in range(1, 3):
logger.info(f"page_size: {page_size}")
batch_size = 2
head_num = 1
layer_num = 64
head_dim = 128
kv_cache = MHATokenToKVPool(
1024,
page_size,
torch.float16,
head_num,
head_dim,
layer_num,
"cpu",
False,
0,
layer_num,
)
mem_pool = MHATokenToKVPoolHost(kv_cache, 2, 0, page_size, "layer_first")
query_length = batch_size * 2
partial = batch_size
self.aibrix_kvcache = AibrixKVCacheStorage(config, mem_pool)
target_shape = (2, layer_num, page_size, head_num, head_dim)
rand_tensor = [
torch.rand(target_shape, dtype=torch.float16)
for _ in range(query_length)
]
keys = ["hash" + str(i) for i in range(query_length)]
partial_keys = keys[batch_size:query_length]
assert self.aibrix_kvcache.batch_exists(keys) == 0
assert self.aibrix_kvcache.batch_set(keys, rand_tensor)
get_tensor = [
torch.rand(target_shape, dtype=torch.float16).flatten()
for _ in range(query_length)
]
self.aibrix_kvcache.batch_get(keys, get_tensor)
for i in range(query_length):
assert torch.equal(get_tensor[i], rand_tensor[i].flatten())
ret = self.aibrix_kvcache.batch_exists(keys)
assert self.aibrix_kvcache.batch_exists(keys) == query_length
assert self.aibrix_kvcache.batch_exists(partial_keys) == partial
partial_get_tensor = [
torch.rand(target_shape, dtype=torch.float16).flatten()
for _ in range(partial)
]
self.aibrix_kvcache.batch_get(partial_keys, partial_get_tensor)
for i in range(partial):
assert torch.equal(
partial_get_tensor[i], rand_tensor[i + partial].flatten()
)
log_every_n_seconds(
logger,
logging.INFO,
self.aibrix_kvcache.kv_cache_manager.metrics.summary(),
1,
)
if __name__ == "__main__":
setup()
test = AIBrixKVCacheStorageTest()
test.test_with_page_size()
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to SGLang project
import importlib
import logging
from typing import TYPE_CHECKING, Any, Dict
from sglang.srt.mem_cache.hicache_storage import HiCacheStorage, HiCacheStorageConfig
if TYPE_CHECKING:
pass
logger = logging.getLogger(__name__)
class StorageBackendFactory:
"""Factory for creating storage backend instances with support for dynamic loading."""
_registry: Dict[str, Dict[str, Any]] = {}
@staticmethod
def _load_backend_class(
module_path: str, class_name: str, backend_name: str
) -> type[HiCacheStorage]:
"""Load and validate a backend class from module path."""
try:
module = importlib.import_module(module_path)
backend_class = getattr(module, class_name)
if not issubclass(backend_class, HiCacheStorage):
raise TypeError(
f"Backend class {class_name} must inherit from HiCacheStorage"
)
return backend_class
except ImportError as e:
raise ImportError(
f"Failed to import backend '{backend_name}' from '{module_path}': {e}"
) from e
except AttributeError as e:
raise AttributeError(
f"Class '{class_name}' not found in module '{module_path}': {e}"
) from e
@classmethod
def register_backend(cls, name: str, module_path: str, class_name: str) -> None:
"""Register a storage backend with lazy loading.
Args:
name: Backend identifier
module_path: Python module path containing the backend class
class_name: Name of the backend class
"""
if name in cls._registry:
logger.warning(f"Backend '{name}' is already registered, overwriting")
def loader() -> type[HiCacheStorage]:
"""Lazy loader function to import the backend class."""
return cls._load_backend_class(module_path, class_name, name)
cls._registry[name] = {
"loader": loader,
"module_path": module_path,
"class_name": class_name,
}
@classmethod
def create_backend(
cls,
backend_name: str,
storage_config: HiCacheStorageConfig,
mem_pool_host: Any,
**kwargs,
) -> HiCacheStorage:
"""Create a storage backend instance.
Args:
backend_name: Name of the backend to create
storage_config: Storage configuration
mem_pool_host: Memory pool host object
**kwargs: Additional arguments passed to external backends
Returns:
Initialized storage backend instance
Raises:
ValueError: If backend is not registered and cannot be dynamically loaded
ImportError: If backend module cannot be imported
Exception: If backend initialization fails
"""
# First check if backend is already registered
if backend_name in cls._registry:
registry_entry = cls._registry[backend_name]
backend_class = registry_entry["loader"]()
logger.info(
f"Creating storage backend '{backend_name}' "
f"({registry_entry['module_path']}.{registry_entry['class_name']})"
)
return cls._create_builtin_backend(
backend_name, backend_class, storage_config, mem_pool_host
)
# Try to dynamically load backend from extra_config
if backend_name == "dynamic" and storage_config.extra_config is not None:
backend_config = storage_config.extra_config
return cls._create_dynamic_backend(
backend_config, storage_config, mem_pool_host, **kwargs
)
# Backend not found
available_backends = list(cls._registry.keys())
raise ValueError(
f"Unknown storage backend '{backend_name}'. "
f"Registered backends: {available_backends}. "
)
@classmethod
def _create_dynamic_backend(
cls,
backend_config: Dict[str, Any],
storage_config: HiCacheStorageConfig,
mem_pool_host: Any,
**kwargs,
) -> HiCacheStorage:
"""Create a backend dynamically from configuration."""
required_fields = ["backend_name", "module_path", "class_name"]
for field in required_fields:
if field not in backend_config:
raise ValueError(
f"Missing required field '{field}' in backend config for 'dynamic' backend"
)
backend_name = backend_config["backend_name"]
module_path = backend_config["module_path"]
class_name = backend_config["class_name"]
try:
# Import the backend class
backend_class = cls._load_backend_class(
module_path, class_name, backend_name
)
logger.info(
f"Creating dynamic storage backend '{backend_name}' "
f"({module_path}.{class_name})"
)
# Create the backend instance with storage_config
return backend_class(storage_config, kwargs)
except Exception as e:
logger.error(
f"Failed to create dynamic storage backend '{backend_name}': {e}"
)
raise
@classmethod
def _create_builtin_backend(
cls,
backend_name: str,
backend_class: type[HiCacheStorage],
storage_config: HiCacheStorageConfig,
mem_pool_host: Any,
) -> HiCacheStorage:
"""Create built-in backend with original initialization logic."""
if backend_name == "file":
return backend_class(storage_config)
elif backend_name == "nixl":
return backend_class()
elif backend_name == "mooncake":
backend = backend_class(storage_config)
return backend
elif backend_name == "aibrix":
backend = backend_class(storage_config, mem_pool_host)
return backend
elif backend_name == "hf3fs":
# Calculate bytes_per_page based on memory pool layout
if mem_pool_host.layout == "page_first":
bytes_per_page = (
mem_pool_host.get_ksize_per_token() * mem_pool_host.page_size
)
elif mem_pool_host.layout == "layer_first":
bytes_per_page = (
mem_pool_host.get_size_per_token() * mem_pool_host.page_size
)
dtype = mem_pool_host.dtype
return backend_class.from_env_config(bytes_per_page, dtype, storage_config)
else:
raise ValueError(f"Unknown built-in backend: {backend_name}")
# Register built-in storage backends
StorageBackendFactory.register_backend(
"file", "sglang.srt.mem_cache.hicache_storage", "HiCacheFile"
)
StorageBackendFactory.register_backend(
"nixl",
"sglang.srt.mem_cache.storage.nixl.hicache_nixl",
"HiCacheNixl",
)
StorageBackendFactory.register_backend(
"mooncake",
"sglang.srt.mem_cache.storage.mooncake_store.mooncake_store",
"MooncakeStore",
)
StorageBackendFactory.register_backend(
"hf3fs",
"sglang.srt.mem_cache.storage.hf3fs.storage_hf3fs",
"HiCacheHF3FS",
)
StorageBackendFactory.register_backend(
"aibrix",
"sglang.srt.mem_cache.storage.aibrix_kvcache.aibrix_kvcache_storage",
"AibrixKVCacheStorage",
)
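# Example (sketch): an out-of-tree backend is loaded through the "dynamic" entry.
# The names below ("my_store", "my_pkg.my_store", "MyStoreBackend") are placeholders;
# the class is imported lazily and must inherit from HiCacheStorage.
#   storage_config.extra_config = {
#       "backend_name": "my_store",
#       "module_path": "my_pkg.my_store",
#       "class_name": "MyStoreBackend",
#   }
#   backend = StorageBackendFactory.create_backend("dynamic", storage_config, mem_pool_host)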
......@@ -12,12 +12,7 @@ from typing import Any, List, Optional, Tuple
import torch
from sglang.srt.mem_cache.hicache_storage import (
HiCacheStorage,
HiCacheStorageConfig,
HiCacheStorageExtraInfo,
)
from sglang.srt.mem_cache.memory_pool_host import HostKVCache
from sglang.srt.mem_cache.hicache_storage import HiCacheStorage, HiCacheStorageConfig
from sglang.srt.mem_cache.storage.hf3fs.hf3fs_client import Hf3fsClient
from sglang.srt.metrics.collector import StorageMetrics
......@@ -183,14 +178,11 @@ class HiCacheHF3FS(HiCacheStorage):
self.skip_backup = True
self.rank = 0
self.is_zero_copy = False
logger.info(
f"[Rank {self.rank}] HiCacheHF3FS Client Initializing: "
f"file_path={self.file_path}, "
f"file_size={self.file_size / (2 ** 30):.2f} GB, "
f"num_pages={self.num_pages}, "
f"is_mla_model={self.is_mla_model}"
f"num_pages={self.num_pages}"
)
self.ac = AtomicCounter(self.numjobs)
......@@ -331,12 +323,25 @@ class HiCacheHF3FS(HiCacheStorage):
use_mock_client=use_mock_client,
)
def get(
self,
key: str,
target_location: Optional[Any] = None,
target_sizes: Optional[Any] = None,
) -> torch.Tensor | None:
return self.batch_get(
[key],
[target_location] if target_location is not None else None,
[target_sizes] if target_sizes is not None else None,
)[0]
@synchronized()
def _batch_get(
def batch_get(
self,
keys: List[str],
values: List[torch.Tensor],
) -> List[bool]:
target_locations: Optional[Any] = None,
target_sizes: Optional[Any] = None,
) -> List[torch.Tensor | None]:
page_indices = self.metadata_client.get_page_indices(self.rank, keys)
batch_indices, file_offsets = [], []
......@@ -345,9 +350,15 @@ class HiCacheHF3FS(HiCacheStorage):
batch_indices.append(i)
file_offsets.append(page_index * self.bytes_per_page)
for target_location in values:
assert target_location.is_contiguous()
file_results = values
if target_locations is not None:
for target_location in target_locations:
assert target_location.is_contiguous()
file_results = target_locations
else:
file_results = [
torch.empty(self.numel, dtype=self.dtype)
for _ in range(len(batch_indices))
]
start_time = time.perf_counter()
......@@ -368,10 +379,12 @@ class HiCacheHF3FS(HiCacheStorage):
ionum / (end_time - start_time) * self.gb_per_page
)
results = [False] * len(keys)
for batch_index, read_result in zip(batch_indices, read_results):
results = [None] * len(keys)
for batch_index, file_result, read_result in zip(
batch_indices, file_results, read_results
):
if read_result == self.bytes_per_page:
results[batch_index] = True
results[batch_index] = file_result
else:
logger.error(
f"[Rank {self.rank}] HiCacheHF3FS get {keys[batch_index]} failed"
......@@ -379,12 +392,28 @@ class HiCacheHF3FS(HiCacheStorage):
return results
def set(
self,
key: str,
value: Optional[Any] = None,
target_location: Optional[Any] = None,
target_sizes: Optional[Any] = None,
) -> bool:
return self.batch_set(
[key],
[value] if value is not None else None,
[target_location] if target_location is not None else None,
[target_sizes] if target_sizes is not None else None,
)
@synchronized()
def _batch_set(
def batch_set(
self,
keys: List[str],
values: Optional[Any] = None,
) -> List[bool]:
target_locations: Optional[Any] = None,
target_sizes: Optional[Any] = None,
) -> bool:
# In MLA backend, only one rank needs to backup the KV cache
if self.skip_backup:
return True
......@@ -445,7 +474,7 @@ class HiCacheHF3FS(HiCacheStorage):
self.rank, written_keys_to_confirm, pages_to_release
)
return results
return all(results)
def delete(self, key: str) -> None:
self.metadata_client.delete_keys(self.rank, [key])
......@@ -455,25 +484,21 @@ class HiCacheHF3FS(HiCacheStorage):
return result[0] if result else False
def batch_exists(self, keys: List[str]) -> int:
factor = 1
if self.is_zero_copy and not self.is_mla_model:
keys = self._get_mha_zero_copy_keys(keys)
factor = 2
results = self.metadata_client.exists(self.rank, keys)
for i in range(len(keys)):
if not results[i]:
return i
i = 0
while i < len(keys) and results[i]:
i += 1
return i // factor
return len(keys)
def clear(self) -> None:
def clear(self) -> bool:
try:
self.metadata_client.clear(self.rank)
logger.info(f"Cleared HiCacheHF3FS for rank {self.rank}")
return True
except Exception as e:
logger.error(f"Failed to clear HiCacheHF3FS: {e}")
return False
def close(self) -> None:
try:
......@@ -496,139 +521,3 @@ class HiCacheHF3FS(HiCacheStorage):
self.prefetch_bandwidth.clear()
self.backup_bandwidth.clear()
return storage_metrics
def register_mem_pool_host(self, mem_pool_host: HostKVCache):
super().register_mem_pool_host(mem_pool_host)
self.is_zero_copy = self.mem_pool_host.layout == "page_first"
logger.info(f"{self.is_zero_copy=}")
def _get_mha_zero_copy_keys(self, keys: List[str]) -> List[str]:
_keys = []
for k in keys:
_keys.append(f"{k}-k")
_keys.append(f"{k}-v")
return _keys
def _get_mha_zero_copy_values(
self, values: List[torch.Tensor]
) -> List[torch.Tensor]:
_values = []
for value in values:
_values.append(value[0])
_values.append(value[1])
return _values
def _batch_get_preprocess(self, keys, host_indices):
page_num = len(host_indices) // self.mem_pool_host.page_size
# host_indices to kv_buffer
flat = not self.is_zero_copy
values = (
[
self.mem_pool_host.get_data_page(host_indices[i * page_num], flat=flat)
for i in range(page_num)
]
if self.is_zero_copy
else [
self.mem_pool_host.get_dummy_flat_data_page() for _ in range(page_num)
]
)
if self.is_zero_copy and not self.is_mla_model:
keys = self._get_mha_zero_copy_keys(keys)
values = self._get_mha_zero_copy_values(values)
return keys, values
def _batch_get_postprocess(self, host_indices, values, results):
page_num = len(host_indices) // self.mem_pool_host.page_size
if self.is_zero_copy:
if not self.is_mla_model:
results = [
(results[2 * i] and results[2 * i + 1]) for i in range(page_num)
]
results = results[:page_num]
return results
for i in range(page_num):
if not results[i]:
break
self.mem_pool_host.set_from_flat_data_page(
host_indices[i * self.mem_pool_host.page_size], values[i]
)
return results
def batch_get_v1(
self,
keys: List[str],
host_indices: torch.Tensor,
extra_info: Optional[HiCacheStorageExtraInfo] = None,
) -> List[bool]:
keys, values = self._batch_get_preprocess(keys, host_indices)
results = self._batch_get(keys, values)
return self._batch_get_postprocess(host_indices, values, results)
def _batch_set_preprocess(self, keys, host_indices):
page_num = len(host_indices) // self.mem_pool_host.page_size
# host_indices to kv_buffer
flat = not self.is_zero_copy
values = [
self.mem_pool_host.get_data_page(host_indices[i * page_num], flat=flat)
for i in range(page_num)
]
if self.is_zero_copy and not self.is_mla_model:
keys = self._get_mha_zero_copy_keys(keys)
values = self._get_mha_zero_copy_values(values)
return keys, values
def batch_set_v1(
self,
keys: List[str],
host_indices: torch.Tensor,
extra_info: Optional[HiCacheStorageExtraInfo] = None,
) -> List[bool]:
len_keys = len(keys)
keys, values = self._batch_set_preprocess(keys, host_indices)
results = self._batch_set(keys, values)
return results
# Deprecated
def get(
self,
key: str,
target_location: Optional[Any] = None,
target_sizes: Optional[Any] = None,
) -> torch.Tensor | None:
pass
# Deprecated
def batch_get(
self,
keys: List[str],
target_locations: Optional[Any] = None,
target_sizes: Optional[Any] = None,
) -> List[torch.Tensor | None] | int:
pass
# Deprecated
def set(
self,
key: str,
value: Optional[Any] = None,
target_location: Optional[Any] = None,
target_sizes: Optional[Any] = None,
) -> bool:
pass
# Deprecated
def batch_set(
self,
keys: List[str],
values: Optional[Any] = None,
target_locations: Optional[Any] = None,
target_sizes: Optional[Any] = None,
) -> bool:
pass
......@@ -7,12 +7,7 @@ from typing import Any, List, Optional
import torch
from sglang.srt.mem_cache.hicache_storage import (
HiCacheStorage,
HiCacheStorageConfig,
HiCacheStorageExtraInfo,
)
from sglang.srt.mem_cache.memory_pool_host import HostKVCache
from sglang.srt.mem_cache.hicache_storage import HiCacheStorage, HiCacheStorageConfig
DEFAULT_GLOBAL_SEGMENT_SIZE = 4 * 1024 * 1024 * 1024 # 4 GiB
DEFAULT_LOCAL_BUFFER_SIZE = 16 * 1024 * 1024 # 16 MB
......@@ -188,13 +183,7 @@ class MooncakeStore(HiCacheStorage):
assert self.store.is_exist(warmup_key) == 1
assert self.store.get(warmup_key) == warmup_value
def register_mem_pool_host(self, mem_pool_host: HostKVCache):
super().register_mem_pool_host(mem_pool_host)
assert self.mem_pool_host.layout in [
"page_first",
"page_first_direct",
], "mooncake store storage backend only support page first or page first direct layout"
buffer = self.mem_pool_host.kv_buffer
def register_buffer(self, buffer: torch.Tensor) -> None:
try:
buffer_ptr = buffer.data_ptr()
buffer_size = buffer.numel() * buffer.element_size()
......@@ -205,97 +194,6 @@ class MooncakeStore(HiCacheStorage):
logger.error("Failed to register buffer to Mooncake Store: %s", err)
raise TypeError("Mooncake Store Register Buffer Error.") from err
def _get_mha_buffer_meta(self, keys, indices):
ptr_list, element_size_list = self.mem_pool_host.get_page_buffer_meta(indices)
key_list = []
for key_ in keys:
key_list.append(f"{key_}_{self.local_rank}_k")
key_list.append(f"{key_}_{self.local_rank}_v")
assert len(key_list) == len(ptr_list)
return key_list, ptr_list, element_size_list
def _get_mla_buffer_meta(self, keys, indices):
ptr_list, element_size_list = self.mem_pool_host.get_page_buffer_meta(indices)
key_list = []
for key_ in keys:
key_list.append(f"{key_}_k")
assert len(key_list) == len(ptr_list)
return key_list, ptr_list, element_size_list
def _batch_preprocess(self, keys, host_indices):
assert len(keys) > 0
assert len(keys) == len(host_indices) // self.mem_pool_host.page_size
if self.is_mla_backend:
return self._get_mla_buffer_meta(keys, host_indices)
else:
return self._get_mha_buffer_meta(keys, host_indices)
def _batch_postprocess(self, results: List[int], is_set_operate=False):
"""
refer to https://github.com/kvcache-ai/Mooncake/blob/main/mooncake-store/include/pybind_client.h
for batch_get_into, results is Vector of integers,
where each element is the number of bytes read on success, or a negative value on error
for batch_put_from, results is Vector of integers,
where each element is 0 on success, or a negative value on error
"""
if self.is_mla_backend:
return [k_res == 0 if is_set_operate else k_res > 0 for k_res in results]
else:
kv_pairs = zip(results[::2], results[1::2])
return [
(
(k_res == 0 and v_res == 0)
if is_set_operate
else (k_res > 0 and v_res > 0)
)
for k_res, v_res in kv_pairs
]
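# Example of the mapping above (illustrative values): for an MHA batch_get with
# results = [1024, 1024, -1, 1024] (bytes read per K/V buffer), buffers are zipped
# into (k, v) pairs -> [(1024, 1024), (-1, 1024)] -> [True, False]; for a batch_set,
# 0 means success, so results = [0, 0, -5, 0] -> [True, False].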
def batch_get_v1(
self,
keys: List[str],
host_indices: torch.Tensor,
extra_info: Optional[HiCacheStorageExtraInfo] = None,
) -> List[bool]:
key_strs, buffer_ptrs, buffer_sizes = self._batch_preprocess(keys, host_indices)
get_results = self._get_batch_zero_copy_impl(
key_strs, buffer_ptrs, buffer_sizes
)
return self._batch_postprocess(get_results, is_set_operate=False)
def batch_set_v1(
self,
keys: List[str],
host_indices: torch.Tensor,
extra_info: Optional[HiCacheStorageExtraInfo] = None,
) -> List[bool]:
key_strs, buffer_ptrs, buffer_sizes = self._batch_preprocess(keys, host_indices)
exist_result = self._batch_exist(key_strs)
set_keys = []
set_buffer_ptrs = []
set_buffer_sizes = []
set_indices = []
set_results = [-1] * len(key_strs)
for i in range(len(key_strs)):
if exist_result[i] != 1:
set_keys.append(key_strs[i])
set_buffer_ptrs.append(buffer_ptrs[i])
set_buffer_sizes.append(buffer_sizes[i])
set_indices.append(i)
else:
set_results[i] = 0
# Only set non-existing keys to storage
if len(set_keys) > 0:
put_results = self._put_batch_zero_copy_impl(
set_keys, set_buffer_ptrs, set_buffer_sizes
)
for i in range(len(set_indices)):
set_results[set_indices[i]] = put_results[i]
return self._batch_postprocess(set_results, is_set_operate=True)
def set(
self,
key,
......
......@@ -44,7 +44,7 @@ def generate_buckets(
return two_sides_exponential_buckets(float(middle), float(base), int(count))
if rule == "default":
return sorted(set(default_buckets))
assert rule == "custom"
assert rule == "customer"
return sorted(set([float(x) for x in buckets_rule[1:]]))
......
......@@ -167,6 +167,29 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
server_args = model_runner.server_args
capture_bs = server_args.cuda_graph_bs
if capture_bs is None:
if server_args.speculative_algorithm is None:
if server_args.disable_cuda_graph_padding:
capture_bs = list(range(1, 33)) + list(range(48, 161, 16))
else:
capture_bs = [1, 2, 4, 8] + list(range(16, 161, 8))
else:
# Since speculative decoding requires more cuda graph memory, we
# capture less.
capture_bs = (
list(range(1, 9))
+ list(range(10, 33, 2))
+ list(range(40, 64, 8))
+ list(range(80, 161, 16))
)
gpu_mem = get_device_memory_capacity()
if gpu_mem is not None:
if gpu_mem > 90 * 1024: # H200, H20
capture_bs += list(range(160, 257, 8))
if gpu_mem > 160 * 1000: # B200, MI300
capture_bs += list(range(256, 513, 16))
if max(capture_bs) > model_runner.req_to_token_pool.size:
# In some cases (e.g., with a small GPU or --max-running-requests), the #max-running-requests
# is very small. We add more values here to make sure we capture the maximum bs.
......@@ -182,6 +205,12 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
capture_bs = [bs for bs in capture_bs if bs % mul_base == 0]
if server_args.cuda_graph_max_bs:
capture_bs = [bs for bs in capture_bs if bs <= server_args.cuda_graph_max_bs]
if max(capture_bs) < server_args.cuda_graph_max_bs:
capture_bs += list(
range(max(capture_bs), server_args.cuda_graph_max_bs + 1, 16)
)
capture_bs = [bs for bs in capture_bs if bs <= model_runner.req_to_token_pool.size]
capture_bs = list(sorted(set(capture_bs)))
assert len(capture_bs) > 0 and capture_bs[0] > 0, f"{capture_bs=}"
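# Illustration (assumed defaults: no speculative decoding, cuda graph padding enabled,
# no cuda_graph_max_bs cap): capture_bs starts as [1, 2, 4, 8] + [16, 24, ..., 160],
# is extended up to 256 when gpu_mem > 90 * 1024 and up to 512 when gpu_mem > 160 * 1000,
# and is then filtered against req_to_token_pool.size and deduplicated.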
......@@ -246,7 +275,7 @@ class CudaGraphRunner:
if (
model_runner.spec_algorithm.is_eagle()
or model_runner.spec_algorithm.is_standalone()
or model_runner.spec_algorithm.is_ngram()
or model_runner.spec_algorithm.is_lookahead()
):
if self.model_runner.is_draft_worker:
raise RuntimeError("This should not happen")
......@@ -413,12 +442,12 @@ class CudaGraphRunner:
forward_batch.can_run_tbo if self.enable_two_batch_overlap else True
)
is_ngram_supported = (
is_lookahead_supported = (
(
forward_batch.batch_size * self.num_tokens_per_bs
== forward_batch.input_ids.numel()
)
if self.model_runner.spec_algorithm.is_ngram()
if self.model_runner.spec_algorithm.is_lookahead()
else True
)
......@@ -427,7 +456,7 @@ class CudaGraphRunner:
and is_encoder_lens_supported
and is_tbo_supported
and capture_hidden_mode_matches
and is_ngram_supported
and is_lookahead_supported
)
def capture(self) -> None:
......@@ -437,7 +466,6 @@ class CudaGraphRunner:
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
record_shapes=True,
)
torch.cuda.memory._record_memory_history()
# Trigger CUDA graph capture for specific shapes.
# Capture the large shapes first so that the smaller shapes
......@@ -486,8 +514,6 @@ class CudaGraphRunner:
save_gemlite_cache()
if self.enable_profile_cuda_graph:
torch.cuda.memory._dump_snapshot(f"cuda_graph_runner_memory_usage.pickle")
torch.cuda.memory._record_memory_history(enabled=None)
log_message = (
"Sorted by CUDA Time:\n"
+ prof.key_averages(group_by_input_shape=True).table(
......@@ -497,7 +523,6 @@ class CudaGraphRunner:
+ prof.key_averages(group_by_input_shape=True).table(
sort_by="cpu_time_total", row_limit=10
)
+ "\n\nMemory Usage is saved to cuda_graph_runner_memory_usage.pickle\n"
)
logger.info(log_message)
......@@ -518,6 +543,9 @@ class CudaGraphRunner:
input_ids = self.input_ids[:num_tokens]
req_pool_indices = self.req_pool_indices[:bs]
seq_lens = self.seq_lens[:bs]
seq_lens_cpu = self.seq_lens_cpu[
:bs
] # TODO: Remove this after changing to real indexer
out_cache_loc = self.out_cache_loc[:num_tokens]
positions = self.positions[:num_tokens]
if self.is_encoder_decoder:
......@@ -588,6 +616,7 @@ class CudaGraphRunner:
input_ids=input_ids,
req_pool_indices=req_pool_indices,
seq_lens=seq_lens,
seq_lens_cpu=seq_lens_cpu, # TODO: Remove this after changing to real indexer
next_token_logits_buffer=next_token_logits_buffer,
orig_seq_lens=seq_lens,
req_to_token_pool=self.model_runner.req_to_token_pool,
......@@ -842,10 +871,10 @@ class CudaGraphRunner:
seq_lens_cpu=None,
)
elif self.model_runner.spec_algorithm.is_ngram():
from sglang.srt.speculative.ngram_utils import NgramVerifyInput
elif self.model_runner.spec_algorithm.is_lookahead():
from sglang.srt.speculative.lookahead_utils import LookaheadVerifyInput
spec_info = NgramVerifyInput(
spec_info = LookaheadVerifyInput(
draft_token=None,
tree_mask=self.custom_mask,
positions=None,
......
......@@ -45,13 +45,7 @@ from sglang.srt.layers.dp_attention import (
get_attention_tp_size,
set_dp_buffer_len,
)
from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
from sglang.srt.utils import (
flatten_nested_list,
get_compiler_backend,
is_npu,
support_triton,
)
from sglang.srt.utils import get_compiler_backend, is_npu, support_triton
if TYPE_CHECKING:
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
......@@ -300,6 +294,7 @@ class ForwardBatch:
# For padding
padded_static_len: int = -1 # -1 if not padded
num_token_non_padded: Optional[torch.Tensor] = None # scalar tensor
num_token_non_padded_cpu: int = None
# For Qwen2-VL
mrope_positions: torch.Tensor = None
......@@ -361,6 +356,7 @@ class ForwardBatch:
ret.num_token_non_padded = torch.tensor(
len(batch.input_ids), dtype=torch.int32
).to(device, non_blocking=True)
ret.num_token_non_padded_cpu = len(batch.input_ids)
# For MLP sync
if batch.global_num_tokens is not None:
......
......@@ -33,7 +33,12 @@ import torch.distributed as dist
from sglang.srt.configs.device_config import DeviceConfig
from sglang.srt.configs.load_config import LoadConfig, LoadFormat
from sglang.srt.configs.model_config import AttentionArch, ModelConfig
from sglang.srt.configs.model_config import (
AttentionArch,
ModelConfig,
get_nsa_index_head_dim,
is_deepseek_nsa,
)
from sglang.srt.configs.update_config import adjust_config_with_unaligned_cpu_tp
from sglang.srt.connector import ConnectorType
from sglang.srt.constants import GPU_MEMORY_TYPE_WEIGHTS
......@@ -60,10 +65,7 @@ from sglang.srt.eplb.expert_location import (
set_global_expert_location_metadata,
)
from sglang.srt.eplb.expert_location_updater import ExpertLocationUpdater
from sglang.srt.layers.attention.attention_registry import (
ATTENTION_BACKENDS,
attn_backend_wrapper,
)
from sglang.srt.layers.attention.attention_registry import ATTENTION_BACKENDS
from sglang.srt.layers.attention.tbo_backend import TboAttnBackend
from sglang.srt.layers.dp_attention import (
get_attention_tp_group,
......@@ -98,6 +100,7 @@ from sglang.srt.mem_cache.memory_pool import (
HybridReqToTokenPool,
MHATokenToKVPool,
MLATokenToKVPool,
NSATokenToKVPool,
ReqToTokenPool,
SWAKVPool,
)
......@@ -107,9 +110,6 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTe
from sglang.srt.model_executor.npu_graph_runner import NPUGraphRunner
from sglang.srt.model_loader import get_model
from sglang.srt.model_loader.loader import DefaultModelLoader, get_model_loader
from sglang.srt.model_loader.remote_instance_weight_loader_utils import (
trigger_init_weights_send_group_for_remote_instance_request,
)
from sglang.srt.model_loader.utils import set_default_torch_dtype
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.offloader import (
......@@ -118,6 +118,9 @@ from sglang.srt.offloader import (
set_offloader,
)
from sglang.srt.patch_torch import monkey_patch_torch_reductions
from sglang.srt.remote_instance_weight_loader_utils import (
trigger_init_weights_send_group_for_remote_instance_request,
)
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.server_args import ServerArgs
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
......@@ -160,6 +163,7 @@ MLA_ATTENTION_BACKENDS = [
"cutlass_mla",
"trtllm_mla",
"ascend",
"nsa",
]
......@@ -182,13 +186,6 @@ UNBALANCED_MODEL_LOADING_TIMEOUT_S = 300
logger = logging.getLogger(__name__)
if _is_npu:
import torch_npu
torch.npu.config.allow_internal_format = True
torch_npu.npu.set_compile_mode(jit_compile=False)
class RankZeroFilter(logging.Filter):
"""Filter that only allows INFO level logs from rank 0, but allows all other levels from any rank."""
......@@ -350,6 +347,7 @@ class ModelRunner:
if self.is_hybrid_gdn:
logger.warning("Hybrid GDN model detected, disable radix cache")
self.server_args.disable_radix_cache = True
self.server_args.attention_backend = "hybrid_linear_attn"
if self.server_args.max_mamba_cache_size is None:
if self.server_args.max_running_requests is not None:
self.server_args.max_mamba_cache_size = (
......@@ -745,10 +743,6 @@ class ModelRunner:
load_format=self.server_args.load_format,
download_dir=self.server_args.download_dir,
model_loader_extra_config=self.server_args.model_loader_extra_config,
tp_rank=self.tp_rank,
remote_instance_weight_loader_seed_instance_ip=self.server_args.remote_instance_weight_loader_seed_instance_ip,
remote_instance_weight_loader_seed_instance_service_port=self.server_args.remote_instance_weight_loader_seed_instance_service_port,
remote_instance_weight_loader_send_weights_group_ports=self.server_args.remote_instance_weight_loader_send_weights_group_ports,
)
if self.device == "cpu":
self.model_config = adjust_config_with_unaligned_cpu_tp(
......@@ -1484,8 +1478,7 @@ class ModelRunner:
if self.max_total_num_tokens <= 0:
raise RuntimeError(
f"Not enough memory. Please try to increase --mem-fraction-static. "
f"Current value: {self.server_args.mem_fraction_static=}"
"Not enough memory. Please try to increase --mem-fraction-static."
)
# Initialize req_to_token_pool
......@@ -1544,6 +1537,7 @@ class ModelRunner:
assert self.is_draft_worker
# Initialize token_to_kv_pool
is_nsa_model = is_deepseek_nsa(self.model_config.hf_config)
if self.server_args.attention_backend == "ascend":
if self.use_mla_backend:
self.token_to_kv_pool = AscendMLAPagedTokenToKVPool(
......@@ -1552,6 +1546,7 @@ class ModelRunner:
dtype=self.kv_cache_dtype,
kv_lora_rank=self.model_config.kv_lora_rank,
qk_rope_head_dim=self.model_config.qk_rope_head_dim,
index_head_dim=self.model_config.index_head_dim,
layer_num=self.num_effective_layers,
device=self.device,
enable_memory_saver=self.server_args.enable_memory_saver,
......@@ -1571,7 +1566,22 @@ class ModelRunner:
device=self.device,
enable_memory_saver=self.server_args.enable_memory_saver,
)
elif self.use_mla_backend and is_nsa_model:
self.token_to_kv_pool = NSATokenToKVPool(
self.max_total_num_tokens,
page_size=self.page_size,
dtype=self.kv_cache_dtype,
kv_lora_rank=self.model_config.kv_lora_rank,
qk_rope_head_dim=self.model_config.qk_rope_head_dim,
layer_num=self.num_effective_layers,
device=self.device,
enable_memory_saver=self.server_args.enable_memory_saver,
start_layer=self.start_layer,
end_layer=self.end_layer,
index_head_dim=get_nsa_index_head_dim(self.model_config.hf_config),
)
elif self.use_mla_backend:
assert not is_nsa_model
self.token_to_kv_pool = MLATokenToKVPool(
self.max_total_num_tokens,
page_size=self.page_size,
......@@ -1650,9 +1660,10 @@ class ModelRunner:
# Initialize token_to_kv_pool_allocator
need_sort = self.server_args.disaggregation_mode in ("decode", "prefill")
if self.token_to_kv_pool_allocator is None:
if _is_npu and (
self.server_args.attention_backend == "ascend" or self.is_hybrid_gdn
):
if _is_npu and self.server_args.attention_backend in [
"ascend",
"hybrid_linear_attn",
]:
self.token_to_kv_pool_allocator = AscendPagedTokenToKVPoolAllocator(
self.max_total_num_tokens,
page_size=self.page_size,
......@@ -1765,8 +1776,7 @@ class ModelRunner:
def _get_attention_backend_from_str(self, backend_str: str):
if backend_str not in ATTENTION_BACKENDS:
raise ValueError(f"Invalid attention backend: {backend_str}")
full_attention_backend = ATTENTION_BACKENDS[backend_str](self)
return attn_backend_wrapper(self, full_attention_backend)
return ATTENTION_BACKENDS[backend_str](self)
def init_double_sparsity_channel_config(self, selected_channel):
selected_channel = "." + selected_channel + "_proj"
......
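The KV-pool selection above keys off two new helpers imported from sglang.srt.configs.model_config. A hedged sketch of what they are assumed to check; the field names follow the published DeepSeek V3.2 Hugging Face config, but the real helpers may differ:
def is_deepseek_nsa(hf_config) -> bool:
    # Assumption: presence of the sparse-indexer fields marks an NSA model.
    return getattr(hf_config, "index_topk", None) is not None

def get_nsa_index_head_dim(hf_config) -> int:
    # e.g. 128 for DeepSeek V3.2 (assumption based on the published config)
    return hf_config.index_head_dim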
......@@ -19,10 +19,8 @@ import logging
import threading
from typing import TYPE_CHECKING, Optional, Union
import numpy as np
import torch
from sglang.srt.configs.model_config import AttentionArch
from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
logger = logging.getLogger(__name__)
......@@ -75,11 +73,16 @@ class NPUGraphRunner(CudaGraphRunner):
self.positions[: self.raw_num_token].copy_(forward_batch.positions)
# Replay
seq_lens = forward_batch.seq_lens.cpu().tolist() + [0] * (self.bs - self.raw_bs)
thread = threading.Thread(target=self._update_inputs, args=(seq_lens,))
thread.start()
self.graphs[self.bs].replay()
thread.join()
if self.model_runner.model_config.index_head_dim is None:
seq_lens = forward_batch.seq_lens.cpu().tolist() + [0] * (
self.bs - self.raw_bs
)
thread = threading.Thread(target=self._update_inputs, args=(seq_lens,))
thread.start()
self.graphs[self.bs].replay()
thread.join()
else:
self.graphs[self.bs].replay()
output = self.output_buffers[self.bs]
if isinstance(output, LogitsProcessorOutput):
......
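For reference, the non-sparse branch above keeps the existing overlap trick: host-side graph inputs are refreshed on a worker thread while the captured graph replays. A generic sketch of that pattern (names are illustrative, not the NPUGraphRunner API):
import threading

def replay_with_async_update(graph, update_inputs, seq_lens):
    # Overlap the host-side input update with graph replay, then join so
    # outputs are not read before both have finished.
    t = threading.Thread(target=update_inputs, args=(seq_lens,))
    t.start()
    graph.replay()
    t.join()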
......@@ -54,9 +54,6 @@ from sglang.srt.distributed import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from sglang.srt.model_loader.remote_instance_weight_loader_utils import (
trigger_transferring_weights_request,
)
from sglang.srt.model_loader.utils import (
get_model_architecture,
post_load_weights,
......@@ -80,6 +77,9 @@ from sglang.srt.model_loader.weight_utils import (
safetensors_weights_iterator,
set_runai_streamer_env,
)
from sglang.srt.remote_instance_weight_loader_utils import (
trigger_transferring_weights_request,
)
from sglang.srt.utils import (
get_bool_env_var,
get_device_capability,
......@@ -206,10 +206,7 @@ def _initialize_model(
if _is_npu:
packed_modules_mapping.update(
{
"visual": {
"qkv_proj": ["qkv"],
"gate_up_proj": ["gate_proj", "up_proj"],
},
"visual": {"qkv_proj": ["qkv"]},
"vision_model": {
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
"proj": ["out_proj"],
......@@ -1420,7 +1417,7 @@ class RemoteInstanceModelLoader(BaseModelLoader):
f"load format {load_config.load_format}"
)
model_weights = f"instance://{load_config.remote_instance_weight_loader_seed_instance_ip}:{load_config.remote_instance_weight_loader_send_weights_group_ports[load_config.tp_rank]}"
model_weights = f"instance://{model_config.remote_instance_weight_loader_seed_instance_ip}:{model_config.remote_instance_weight_loader_send_weights_group_ports[model_config.tp_rank]}"
with set_default_torch_dtype(model_config.dtype):
with torch.device(device_config.device):
......@@ -1442,12 +1439,11 @@ class RemoteInstanceModelLoader(BaseModelLoader):
def load_model_from_remote_instance(
self, model, client, model_config: ModelConfig, device_config: DeviceConfig
) -> nn.Module:
load_config = self.load_config
instance_ip = socket.gethostbyname(socket.gethostname())
start_build_group_tic = time.time()
client.build_group(
gpu_id=device_config.gpu_id,
tp_rank=load_config.tp_rank,
tp_rank=model_config.tp_rank,
instance_ip=instance_ip,
)
torch.cuda.synchronize()
......@@ -1456,13 +1452,13 @@ class RemoteInstanceModelLoader(BaseModelLoader):
f"finish building group for remote instance, time used: {(end_build_group_tic - start_build_group_tic):.4f}s"
)
if load_config.tp_rank == 0:
if model_config.tp_rank == 0:
t = threading.Thread(
target=trigger_transferring_weights_request,
args=(
load_config.remote_instance_weight_loader_seed_instance_ip,
load_config.remote_instance_weight_loader_seed_instance_service_port,
load_config.remote_instance_weight_loader_send_weights_group_ports,
model_config.remote_instance_weight_loader_seed_instance_ip,
model_config.remote_instance_weight_loader_seed_instance_service_port,
model_config.remote_instance_weight_loader_send_weights_group_ports,
instance_ip,
),
)
......
......@@ -8,6 +8,7 @@ import hashlib
import json
import logging
import os
import queue
import tempfile
from collections import defaultdict
from typing import (
......@@ -37,8 +38,7 @@ from sglang.srt.distributed import get_tensor_model_parallel_rank
from sglang.srt.layers.dp_attention import get_attention_tp_rank
from sglang.srt.layers.quantization import QuantizationConfig, get_quantization_config
from sglang.srt.layers.quantization.modelopt_quant import ModelOptFp4Config
from sglang.srt.utils import find_local_repo_dir, print_warning_once
from sglang.utils import is_in_ci
from sglang.srt.utils import print_warning_once
logger = logging.getLogger(__name__)
......@@ -236,89 +236,6 @@ def get_quant_config(
return quant_cls.from_config(config)
def find_local_hf_snapshot_dir(
model_name_or_path: str,
cache_dir: Optional[str],
allow_patterns: List[str],
revision: Optional[str] = None,
) -> Optional[str]:
"""If the weights are already local, skip downloading and returns the path."""
if os.path.isdir(model_name_or_path):
return None
found_local_snapshot_dir = None
# Check custom cache_dir (if provided)
if cache_dir:
try:
repo_folder = os.path.join(
cache_dir,
huggingface_hub.constants.REPO_ID_SEPARATOR.join(
["models", *model_name_or_path.split("/")]
),
)
rev_to_use = revision
if not rev_to_use:
ref_main = os.path.join(repo_folder, "refs", "main")
if os.path.isfile(ref_main):
with open(ref_main) as f:
rev_to_use = f.read().strip()
if rev_to_use:
rev_dir = os.path.join(repo_folder, "snapshots", rev_to_use)
if os.path.isdir(rev_dir):
found_local_snapshot_dir = rev_dir
except Exception as e:
logger.warning(
"Failed to find local snapshot in custom cache_dir %s: %s",
cache_dir,
e,
)
# Check default HF cache as well
if not found_local_snapshot_dir:
try:
rev_dir = find_local_repo_dir(model_name_or_path, revision)
if rev_dir and os.path.isdir(rev_dir):
found_local_snapshot_dir = rev_dir
except Exception as e:
logger.warning("Failed to find local snapshot in default HF cache: %s", e)
# If local snapshot exists, validate it contains at least one weight file
# matching allow_patterns before skipping download.
if found_local_snapshot_dir is None:
return None
local_weight_files: List[str] = []
try:
for pattern in allow_patterns:
local_weight_files.extend(
glob.glob(os.path.join(found_local_snapshot_dir, pattern))
)
except Exception as e:
logger.warning(
"Failed to scan local snapshot %s with patterns %s: %s",
found_local_snapshot_dir,
allow_patterns,
e,
)
local_weight_files = []
if len(local_weight_files) > 0:
logger.info(
"Found local HF snapshot for %s at %s; skipping download.",
model_name_or_path,
found_local_snapshot_dir,
)
return found_local_snapshot_dir
else:
logger.info(
"Local HF snapshot at %s has no files matching %s; will attempt download.",
found_local_snapshot_dir,
allow_patterns,
)
return None
def download_weights_from_hf(
model_name_or_path: str,
cache_dir: Optional[str],
......@@ -343,16 +260,6 @@ def download_weights_from_hf(
Returns:
str: The path to the downloaded model weights.
"""
if is_in_ci():
# If the weights are already local, skip downloading and return the path.
# This is used to avoid too many Hugging Face API calls in CI.
path = find_local_hf_snapshot_dir(
model_name_or_path, cache_dir, allow_patterns, revision
)
if path is not None:
return path
if not huggingface_hub.constants.HF_HUB_OFFLINE:
# Before we download, we check what is available:
fs = HfFileSystem()
......
......@@ -45,12 +45,12 @@ from sglang.srt.layers.dp_attention import (
get_attention_dp_size,
get_attention_tp_rank,
get_attention_tp_size,
is_dp_attention_enabled,
)
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import (
MergedColumnParallelLinear,
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear,
)
from sglang.srt.layers.logits_processor import LogitsProcessor
......@@ -72,10 +72,6 @@ from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.utils import (
create_fused_set_kv_buffer_arg,
enable_fused_set_kv_buffer,
)
from sglang.srt.utils import add_prefix, is_cuda, is_non_idle_and_non_empty, make_layers
LoraConfig = None
......@@ -559,27 +555,8 @@ class BailingMoEAttention(nn.Module):
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
if self.use_qk_norm:
q, k = self._apply_qk_norm(q, k)
q, k = self.rotary_emb(
positions,
q,
k,
fused_set_kv_buffer_arg=(
create_fused_set_kv_buffer_arg(
value=v,
layer=self.attn,
forward_batch=forward_batch,
)
if enable_fused_set_kv_buffer(forward_batch)
else None
),
)
context_layer = self.attn(
q,
k,
v,
forward_batch,
save_kv_cache=not enable_fused_set_kv_buffer(forward_batch),
)
q, k = self.rotary_emb(positions, q, k)
context_layer = self.attn(q, k, v, forward_batch)
attn_output, _ = self.dense(context_layer)
return attn_output
......@@ -725,7 +702,7 @@ class BailingMoEModel(nn.Module):
self.embed_dim,
quant_config=quant_config,
prefix=add_prefix("word_embeddings", prefix),
enable_tp=not is_dp_attention_enabled(),
use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
)
else:
self.word_embeddings = PPMissingLayer()
......
......@@ -15,6 +15,7 @@
# Adapted from:
# https://github.com/vllm-project/vllm/blob/fb6af8bc086328ca6659e72d11ffd4309ce4de22/vllm/model_executor/models/deepseek_v2.py
"""Inference-only DeepseekV2 model."""
from __future__ import annotations
import concurrent.futures
import logging
......@@ -25,9 +26,15 @@ from typing import Any, Dict, Iterable, Optional, Tuple, Union
import torch
import torch.nn.functional as F
from torch import nn
from tqdm import tqdm
from transformers import PretrainedConfig
from sglang.srt.configs.model_config import (
get_nsa_index_head_dim,
get_nsa_index_n_heads,
get_nsa_index_topk,
is_deepseek_nsa,
)
from sglang.srt.debug_utils.dumper import dumper
from sglang.srt.distributed import (
get_moe_expert_parallel_world_size,
get_pp_group,
......@@ -47,6 +54,7 @@ from sglang.srt.layers.attention.npu_ops.mla_preprocess import (
NPUFusedMLAPreprocess,
is_mla_preprocess_enabled,
)
from sglang.srt.layers.attention.nsa.nsa_indexer import Indexer
from sglang.srt.layers.communicator import (
LayerCommunicator,
LayerScatterModes,
......@@ -175,6 +183,11 @@ if _is_hip:
decode_attention_fwd_grouped_rope,
)
if _is_npu:
import custom_ops
import sgl_kernel_npu
import torch_npu
_is_flashinfer_available = is_flashinfer_available()
_is_sm100_supported = is_cuda() and is_sm100_supported()
......@@ -183,6 +196,7 @@ logger = logging.getLogger(__name__)
FORWARD_ABSORB_CORE_ATTENTION_BACKENDS = [
"fa3",
"nsa",
"flashinfer",
"cutlass_mla",
"trtllm_mla",
......@@ -203,6 +217,9 @@ class AttnForwardMethod(IntEnum):
# Use absorbed multi-latent attention
MLA = auto()
# Use Deepseek V3.2 sparse multi-latent attention
NPU_MLA_SPARSE = auto()
# Use multi-head attention, but with KV cache chunked.
# This method can avoid OOM when prefix lengths are long.
MHA_CHUNKED_KV = auto()
......@@ -245,9 +262,15 @@ def handle_ascend(attn, forward_batch):
and not forward_batch.forward_mode.is_target_verify()
and not forward_batch.forward_mode.is_draft_extend()
):
return AttnForwardMethod.MHA
if hasattr(attn, "indexer"):
return AttnForwardMethod.NPU_MLA_SPARSE
else:
return AttnForwardMethod.MHA
else:
return AttnForwardMethod.MLA
if hasattr(attn, "indexer"):
return AttnForwardMethod.NPU_MLA_SPARSE
else:
return AttnForwardMethod.MLA
def _get_sum_extend_prefix_lens(forward_batch):
......@@ -266,7 +289,7 @@ def _is_extend_without_speculative(forward_batch):
)
def _handle_backend(attn, forward_batch, backend_name):
def _handle_backend(attn: DeepseekV2AttentionMLA, forward_batch, backend_name):
sum_extend_prefix_lens = _get_sum_extend_prefix_lens(forward_batch)
disable_ragged = (
backend_name in ["flashinfer", "flashmla"]
......@@ -332,6 +355,10 @@ def handle_aiter(attn, forward_batch):
return AttnForwardMethod.MLA
def handle_nsa(attn, forward_batch):
return AttnForwardMethod.MLA
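# `handle_nsa` above routes the new backend onto the absorbed-MLA path; the
# per-batch dispatch is wired through the `BackendRegistry.register("nsa", handle_nsa)`
# call at the bottom of this file. A minimal, generic sketch of that registry
# pattern (illustrative names, not the real BackendRegistry API):
_BACKEND_HANDLERS = {}

def _register_handler(name, handler):
    _BACKEND_HANDLERS[name] = handler

def _dispatch_handler(name, attn, forward_batch):
    # e.g. _dispatch_handler("nsa", attn, batch) -> AttnForwardMethod.MLA
    return _BACKEND_HANDLERS[name](attn, forward_batch)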
def handle_triton(attn, forward_batch):
if (
_is_extend_without_speculative(forward_batch)
......@@ -996,6 +1023,10 @@ class DeepseekV2AttentionMLA(nn.Module):
self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
# NOTE: rope_scaling must be modified this early because components such as the Indexer depend on it.
if rope_scaling:
rope_scaling["rope_type"] = "deepseek_yarn"
# For tensor parallel attention
if self.q_lora_rank is not None:
self.fused_qkv_a_proj_with_mqa = ReplicatedLinear(
......@@ -1033,6 +1064,26 @@ class DeepseekV2AttentionMLA(nn.Module):
prefix=add_prefix("kv_a_proj_with_mqa", prefix),
)
self.use_nsa = is_deepseek_nsa(config)
if self.use_nsa:
self.indexer = Indexer(
hidden_size=hidden_size,
index_n_heads=get_nsa_index_n_heads(config),
index_head_dim=get_nsa_index_head_dim(config),
rope_head_dim=qk_rope_head_dim,
index_topk=get_nsa_index_topk(config),
q_lora_rank=q_lora_rank,
max_position_embeddings=max_position_embeddings,
rope_theta=rope_theta,
scale_fmt="ue8m0",
block_size=128,
rope_scaling=rope_scaling,
prefix=add_prefix("indexer", prefix),
quant_config=quant_config,
layer_id=layer_id,
alt_stream=alt_stream,
)
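# NOTE (assumption): the indexer's forward (called from forward_absorb_prepare and
# forward_npu_sparse_prepare below) returns per-token top-k KV indices; they are
# threaded through to attn_mqa as `topk_indices` so the sparse backend only
# attends over the selected KV slots.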
self.kv_b_proj = ColumnParallelLinear(
self.kv_lora_rank,
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
......@@ -1055,9 +1106,6 @@ class DeepseekV2AttentionMLA(nn.Module):
)
self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
if rope_scaling:
rope_scaling["rope_type"] = "deepseek_yarn"
self.rotary_emb = get_rope_wrapper(
qk_rope_head_dim,
rotary_dim=qk_rope_head_dim,
......@@ -1184,8 +1232,8 @@ class DeepseekV2AttentionMLA(nn.Module):
self.is_mla_preprocess_enabled = is_mla_preprocess_enabled()
if self.is_mla_preprocess_enabled:
assert (
quant_config.get_name() == "w8a8_int8"
), "MLA Preprocess only works with W8A8Int8"
quant_config is None or quant_config.get_name() == "w8a8_int8"
), "MLA Preprocess only works with Unquant or W8A8Int8"
self.mla_preprocess = None
def dispatch_attn_forward_method(
......@@ -1263,7 +1311,6 @@ class DeepseekV2AttentionMLA(nn.Module):
return hidden_states, None, forward_batch, None
attn_forward_method = self.dispatch_attn_forward_method(forward_batch)
if attn_forward_method == AttnForwardMethod.MHA:
inner_state = self.forward_normal_prepare(
positions, hidden_states, forward_batch, zero_allocator
......@@ -1295,6 +1342,10 @@ class DeepseekV2AttentionMLA(nn.Module):
inner_state = self.mla_preprocess.forward(
positions, hidden_states, forward_batch, zero_allocator
)
elif attn_forward_method == AttnForwardMethod.NPU_MLA_SPARSE:
inner_state = self.forward_npu_sparse_prepare(
positions, hidden_states, forward_batch, zero_allocator
)
elif attn_forward_method == AttnForwardMethod.MLA_FUSED_ROPE:
inner_state = self.forward_absorb_fused_mla_rope_prepare(
positions, hidden_states, forward_batch, zero_allocator
......@@ -1320,6 +1371,8 @@ class DeepseekV2AttentionMLA(nn.Module):
return self.forward_normal_chunked_kv_core(*inner_state)
elif attn_forward_method == AttnForwardMethod.MLA:
return self.forward_absorb_core(*inner_state)
elif attn_forward_method == AttnForwardMethod.NPU_MLA_SPARSE:
return self.forward_npu_sparse_core(*inner_state)
elif attn_forward_method == AttnForwardMethod.MLA_FUSED_ROPE:
return self.forward_absorb_fused_mla_rope_core(*inner_state)
elif attn_forward_method == AttnForwardMethod.MLA_FUSED_ROPE_CPU:
......@@ -1412,6 +1465,7 @@ class DeepseekV2AttentionMLA(nn.Module):
):
from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
q_lora = None
if self.q_lora_rank is not None:
if (
(not isinstance(hidden_states, tuple))
......@@ -1450,6 +1504,10 @@ class DeepseekV2AttentionMLA(nn.Module):
q = self.q_a_layernorm(q)
k_nope = self.kv_a_layernorm(k_nope)
# q_lora needed by indexer
if self.use_nsa:
q_lora = q
k_nope = k_nope.unsqueeze(1)
q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
else:
......@@ -1519,10 +1577,37 @@ class DeepseekV2AttentionMLA(nn.Module):
):
q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
return q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator, positions
topk_indices = None
if q_lora is not None:
topk_indices = self.indexer(
x=hidden_states,
q_lora=q_lora,
positions=positions,
forward_batch=forward_batch,
layer_id=self.layer_id,
)
return (
q_pe,
k_pe,
q_nope_out,
k_nope,
forward_batch,
zero_allocator,
positions,
topk_indices,
)
def forward_absorb_core(
self, q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator, positions
self,
q_pe,
k_pe,
q_nope_out,
k_nope,
forward_batch,
zero_allocator,
positions,
topk_indices,
):
if self.current_attention_backend in FORWARD_ABSORB_CORE_ATTENTION_BACKENDS:
extra_args = {}
......@@ -1531,6 +1616,7 @@ class DeepseekV2AttentionMLA(nn.Module):
"cos_sin_cache": self.rotary_emb.cos_sin_cache,
"is_neox": self.rotary_emb.is_neox_style,
}
attn_output = self.attn_mqa(
q_nope_out,
k_nope,
......@@ -1539,6 +1625,7 @@ class DeepseekV2AttentionMLA(nn.Module):
q_rope=q_pe,
k_rope=k_pe,
**extra_args,
**(dict(topk_indices=topk_indices) if topk_indices is not None else {}),
)
else:
if _use_aiter_gfx95:
......@@ -1558,7 +1645,13 @@ class DeepseekV2AttentionMLA(nn.Module):
q = torch.cat([q_nope_out, q_pe], dim=-1)
k = torch.cat([k_nope, k_pe], dim=-1)
attn_output = self.attn_mqa(q, k, k_nope, forward_batch)
attn_output = self.attn_mqa(
q,
k,
k_nope,
forward_batch,
**(dict(topk_indices=topk_indices) if topk_indices is not None else {}),
)
attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)
if self.use_deep_gemm_bmm:
......@@ -1640,6 +1733,221 @@ class DeepseekV2AttentionMLA(nn.Module):
return output
def forward_npu_sparse_prepare(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
forward_batch: ForwardBatch,
zero_allocator: BumpAllocator,
):
"""
Reuse the `self.q_lora_rank is not None` branch from forward_absorb_prepare.
"""
if self.is_mla_preprocess_enabled and forward_batch.forward_mode.is_decode():
if self.mla_preprocess is None:
self.mla_preprocess = NPUFusedMLAPreprocess(
self.fused_qkv_a_proj_with_mqa,
self.q_a_layernorm,
self.kv_a_layernorm,
self.q_b_proj,
self.w_kc,
self.rotary_emb,
self.layer_id,
self.num_local_heads,
self.qk_nope_head_dim,
self.qk_rope_head_dim,
)
(
q_pe,
k_pe,
q_nope_out,
k_nope,
forward_batch,
zero_allocator,
positions,
) = self.mla_preprocess.forward(
positions, hidden_states, forward_batch, zero_allocator
)
fused_qkv_a_proj_out = self.fused_qkv_a_proj_with_mqa(hidden_states)[0]
q, _ = fused_qkv_a_proj_out.split(
[self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
)
q_lora = self.q_a_layernorm(q)
else:
from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
if (
(not isinstance(hidden_states, tuple))
and hidden_states.shape[0] <= 16
and self.use_min_latency_fused_a_gemm
):
fused_qkv_a_proj_out = dsv3_fused_a_gemm(
hidden_states, self.fused_qkv_a_proj_with_mqa.weight.T
)
else:
fused_qkv_a_proj_out = self.fused_qkv_a_proj_with_mqa(hidden_states)[0]
q, latent_cache = fused_qkv_a_proj_out.split(
[self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
)
k_nope = latent_cache[..., : self.kv_lora_rank]
# overlap qk norm
if self.alt_stream is not None and get_is_capture_mode():
current_stream = torch.cuda.current_stream()
self.alt_stream.wait_stream(current_stream)
q = self.q_a_layernorm(q)
with torch.cuda.stream(self.alt_stream):
k_nope = self.kv_a_layernorm(k_nope)
current_stream.wait_stream(self.alt_stream)
else:
if _use_aiter_gfx95 and self.q_b_proj.weight.dtype == torch.uint8:
q, k_nope = fused_rms_mxfp4_quant(
q,
self.q_a_layernorm.weight,
self.q_a_layernorm.variance_epsilon,
k_nope,
self.kv_a_layernorm.weight,
self.kv_a_layernorm.variance_epsilon,
)
else:
q = self.q_a_layernorm(q)
k_nope = self.kv_a_layernorm(k_nope)
q_lora = q.clone() # required for topk_indices
k_nope = k_nope.unsqueeze(1)
q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
q_nope, q_pe = q.split(
[self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
)
k_pe = latent_cache[..., self.kv_lora_rank :].unsqueeze(1)
if self.use_deep_gemm_bmm:
q_nope_val, q_nope_scale, masked_m, expected_m, aligned_m = (
per_token_group_quant_mla_deep_gemm_masked_fp8(
q_nope.transpose(0, 1)
)
)
q_nope_out = q_nope.new_empty(
(self.num_local_heads, aligned_m, self.kv_lora_rank)
)
deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
(q_nope_val, q_nope_scale),
(self.w_kc, self.w_scale_k),
q_nope_out,
masked_m,
expected_m,
)
q_nope_out = q_nope_out[:, :expected_m, :]
elif _is_hip:
# TODO(haishaw): add bmm_fp8 to ROCm
if _use_aiter_gfx95 and self.w_kc.dtype == torch.uint8:
x = q_nope.transpose(0, 1)
q_nope_out = torch.empty(
x.shape[0],
x.shape[1],
self.w_kc.shape[2],
device=x.device,
dtype=torch.bfloat16,
)
batched_gemm_afp4wfp4_pre_quant(
x,
self.w_kc.transpose(-2, -1),
self.w_scale_k.transpose(-2, -1),
torch.bfloat16,
q_nope_out,
)
else:
q_nope_out = torch.bmm(
q_nope.to(torch.bfloat16).transpose(0, 1),
self.w_kc.to(torch.bfloat16) * self.w_scale,
)
elif self.w_kc.dtype == torch.float8_e4m3fn:
q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8(
q_nope.transpose(0, 1),
zero_allocator.allocate(1),
)
q_nope_out = bmm_fp8(
q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16
)
else:
q_nope_out = torch.bmm(q_nope.transpose(0, 1), self.w_kc)
q_nope_out = q_nope_out.transpose(0, 1)
if not self._fuse_rope_for_trtllm_mla(forward_batch) and (
not _use_aiter or not _is_gfx95_supported
):
q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
# TODO: multi-stream indexer
topk_indices = self.indexer(
hidden_states, q_lora, positions, forward_batch, self.layer_id
)
return (
q_pe,
k_pe,
q_nope_out,
k_nope,
topk_indices,
forward_batch,
zero_allocator,
positions,
)
def forward_npu_sparse_core(
self,
q_pe,
k_pe,
q_nope_out,
k_nope,
topk_indices,
forward_batch,
zero_allocator,
positions,
):
attn_output = self.attn_mqa(
q_nope_out.contiguous(),
k_nope.contiguous(),
k_nope.contiguous(),
forward_batch,
save_kv_cache=True,  # TODO: possibly skip the KV-cache write for extend-mode batches (was: False if is_extend() else True)
q_rope=q_pe.contiguous(),
k_rope=k_pe.contiguous(),
topk_indices=topk_indices,
)
attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)
attn_bmm_output = torch.empty(
(attn_output.shape[0], self.num_local_heads, self.v_head_dim),
dtype=attn_output.dtype,
device=attn_output.device,
)
if not forward_batch.forward_mode.is_decode():
attn_output = attn_output.transpose(0, 1)
torch.bmm(
attn_output,
self.w_vc,
out=attn_bmm_output.view(
-1, self.num_local_heads, self.v_head_dim
).transpose(0, 1),
)
else:
attn_output = attn_output.contiguous()
torch.ops.npu.batch_matmul_transpose(
attn_output, self.w_vc, attn_bmm_output
)
attn_bmm_output = attn_bmm_output.reshape(
-1, self.num_local_heads * self.v_head_dim
)
output, _ = self.o_proj(attn_bmm_output)
return output
def forward_absorb_fused_mla_rope_prepare(
self,
positions: torch.Tensor,
......@@ -2121,7 +2429,6 @@ class DeepseekV2DecoderLayer(nn.Module):
zero_allocator: BumpAllocator,
gemm_output_zero_allocator: BumpAllocator = None,
) -> torch.Tensor:
quant_format = (
"mxfp4"
if _is_gfx95_supported
......@@ -2704,7 +3011,7 @@ class DeepseekV2ForCausalLM(nn.Module):
self_attn.w_kc, w_kc.transpose(1, 2).contiguous().transpose(1, 2)
)
self_attn.w_vc = bind_or_assign(
self_attn.w_vc, w_vc.contiguous().transpose(1, 2)
self_attn.w_vc, w_vc.contiguous().transpose(1, 2).contiguous()
)
if (
hasattr(self_attn.kv_b_proj, "weight_scale")
......@@ -3086,6 +3393,7 @@ BackendRegistry.register("cutlass_mla", handle_cutlass_mla)
BackendRegistry.register("fa4", handle_fa4)
BackendRegistry.register("trtllm_mla", handle_trtllm_mla)
BackendRegistry.register("aiter", handle_aiter)
BackendRegistry.register("nsa", handle_nsa)
BackendRegistry.register("triton", handle_triton)
......@@ -3093,4 +3401,8 @@ class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM):
pass
EntryClass = [DeepseekV2ForCausalLM, DeepseekV3ForCausalLM]
class DeepseekV32ForCausalLM(DeepseekV2ForCausalLM):
pass
EntryClass = [DeepseekV2ForCausalLM, DeepseekV3ForCausalLM, DeepseekV32ForCausalLM]
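Adding `DeepseekV32ForCausalLM` to `EntryClass` is what lets the new checkpoint resolve to this file. A hedged sketch of how entry classes are typically matched against the checkpoint's config.json (the loader's real logic may differ):
def resolve_entry_class(hf_config, entry_classes):
    # config.json carries e.g. "architectures": ["DeepseekV32ForCausalLM"]
    arch = hf_config.architectures[0]
    for cls in entry_classes:
        if cls.__name__ == arch:
            return cls
    raise ValueError(f"Unsupported architecture: {arch}")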
......@@ -20,6 +20,7 @@ import torch.nn.functional as F
from torch import nn
from transformers import (
ROPE_INIT_FUNCTIONS,
AutoModel,
Gemma3TextConfig,
PretrainedConfig,
PreTrainedModel,
......@@ -760,3 +761,4 @@ class Gemma3ForCausalLM(PreTrainedModel):
EntryClass = Gemma3ForCausalLM
AutoModel.register(Gemma3TextConfig, Gemma3ForCausalLM, exist_ok=True)