Unverified Commit a6305c7d authored by Lianmin Zheng's avatar Lianmin Zheng Committed by GitHub
Browse files

Lianmin/simplify memory pool (#7202)

parent a023856b
...@@ -34,7 +34,7 @@ import triton ...@@ -34,7 +34,7 @@ import triton
import triton.language as tl import triton.language as tl
from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.utils import debug_timing, get_compiler_backend, is_cuda from sglang.srt.utils import debug_timing, is_cuda, next_power_of_2
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -182,6 +182,9 @@ class TokenToKVPoolAllocator: ...@@ -182,6 +182,9 @@ class TokenToKVPoolAllocator:
def available_size(self): def available_size(self):
return len(self.free_slots) return len(self.free_slots)
def debug_print(self) -> str:
return ""
def get_kvcache(self): def get_kvcache(self):
return self._kvcache return self._kvcache
...@@ -314,17 +317,25 @@ class MHATokenToKVPool(KVCache): ...@@ -314,17 +317,25 @@ class MHATokenToKVPool(KVCache):
# layer_num x [seq_len, head_num, head_dim] # layer_num x [seq_len, head_num, head_dim]
# layer_num x [page_num, page_size, head_num, head_dim] # layer_num x [page_num, page_size, head_num, head_dim]
kv_data_ptrs = [ kv_data_ptrs = [
self.get_key_buffer(i).data_ptr() for i in range(self.layer_num) self.get_key_buffer(i).data_ptr()
] + [self.get_value_buffer(i).data_ptr() for i in range(self.layer_num)] for i in range(self.start_layer, self.start_layer + self.layer_num)
] + [
self.get_value_buffer(i).data_ptr()
for i in range(self.start_layer, self.start_layer + self.layer_num)
]
kv_data_lens = [ kv_data_lens = [
self.get_key_buffer(i).nbytes for i in range(self.layer_num) self.get_key_buffer(i).nbytes
] + [self.get_value_buffer(i).nbytes for i in range(self.layer_num)] for i in range(self.start_layer, self.start_layer + self.layer_num)
] + [
self.get_value_buffer(i).nbytes
for i in range(self.start_layer, self.start_layer + self.layer_num)
]
kv_item_lens = [ kv_item_lens = [
self.get_key_buffer(i)[0].nbytes * self.page_size self.get_key_buffer(i)[0].nbytes * self.page_size
for i in range(self.layer_num) for i in range(self.start_layer, self.start_layer + self.layer_num)
] + [ ] + [
self.get_value_buffer(i)[0].nbytes * self.page_size self.get_value_buffer(i)[0].nbytes * self.page_size
for i in range(self.layer_num) for i in range(self.start_layer, self.start_layer + self.layer_num)
] ]
return kv_data_ptrs, kv_data_lens, kv_item_lens return kv_data_ptrs, kv_data_lens, kv_item_lens
...@@ -444,36 +455,6 @@ class MHATokenToKVPool(KVCache): ...@@ -444,36 +455,6 @@ class MHATokenToKVPool(KVCache):
self.v_buffer[layer_id - self.start_layer][loc] = cache_v self.v_buffer[layer_id - self.start_layer][loc] = cache_v
@torch.compile
def fused_downcast(
cache_k: torch.Tensor,
cache_v: torch.Tensor,
k_scale: torch.Tensor,
v_scale: torch.Tensor,
dtype: torch.dtype,
store_dtype: torch.dtype,
max_fp8: float,
min_fp8: float,
):
cache_k = cache_k / k_scale
cache_k = torch.clamp(cache_k, min_fp8, max_fp8)
cache_v = cache_v / v_scale
cache_v = torch.clamp(cache_v, min_fp8, max_fp8)
cache_k = cache_k.to(dtype)
cache_v = cache_v.to(dtype)
cache_k = cache_k.view(store_dtype)
cache_v = cache_v.view(store_dtype)
return cache_k, cache_v
# This compiled version is slower in the unit test
# python3 -m unittest test_bench_serving.TestBenchServing.test_offline_throughput_non_stream_small_batch_size
@torch.compile(dynamic=True, backend=get_compiler_backend())
def copy_two_array(loc, dst_1, src_1, dst_2, src_2, dtype, store_dtype):
dst_1[loc] = src_1.to(dtype).view(store_dtype)
dst_2[loc] = src_2.to(dtype).view(store_dtype)
@triton.jit @triton.jit
def set_mla_kv_buffer_kernel( def set_mla_kv_buffer_kernel(
kv_buffer_ptr, kv_buffer_ptr,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment