Unverified Commit a6305c7d authored by Lianmin Zheng's avatar Lianmin Zheng Committed by GitHub
Browse files

Lianmin/simplify memory pool (#7202)

parent a023856b
......@@ -34,7 +34,7 @@ import triton
import triton.language as tl
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.utils import debug_timing, get_compiler_backend, is_cuda
from sglang.srt.utils import debug_timing, is_cuda, next_power_of_2
logger = logging.getLogger(__name__)
......@@ -182,6 +182,9 @@ class TokenToKVPoolAllocator:
def available_size(self):
return len(self.free_slots)
def debug_print(self) -> str:
return ""
def get_kvcache(self):
return self._kvcache
......@@ -314,17 +317,25 @@ class MHATokenToKVPool(KVCache):
# layer_num x [seq_len, head_num, head_dim]
# layer_num x [page_num, page_size, head_num, head_dim]
kv_data_ptrs = [
self.get_key_buffer(i).data_ptr() for i in range(self.layer_num)
] + [self.get_value_buffer(i).data_ptr() for i in range(self.layer_num)]
self.get_key_buffer(i).data_ptr()
for i in range(self.start_layer, self.start_layer + self.layer_num)
] + [
self.get_value_buffer(i).data_ptr()
for i in range(self.start_layer, self.start_layer + self.layer_num)
]
kv_data_lens = [
self.get_key_buffer(i).nbytes for i in range(self.layer_num)
] + [self.get_value_buffer(i).nbytes for i in range(self.layer_num)]
self.get_key_buffer(i).nbytes
for i in range(self.start_layer, self.start_layer + self.layer_num)
] + [
self.get_value_buffer(i).nbytes
for i in range(self.start_layer, self.start_layer + self.layer_num)
]
kv_item_lens = [
self.get_key_buffer(i)[0].nbytes * self.page_size
for i in range(self.layer_num)
for i in range(self.start_layer, self.start_layer + self.layer_num)
] + [
self.get_value_buffer(i)[0].nbytes * self.page_size
for i in range(self.layer_num)
for i in range(self.start_layer, self.start_layer + self.layer_num)
]
return kv_data_ptrs, kv_data_lens, kv_item_lens
......@@ -444,36 +455,6 @@ class MHATokenToKVPool(KVCache):
self.v_buffer[layer_id - self.start_layer][loc] = cache_v
@torch.compile
def fused_downcast(
cache_k: torch.Tensor,
cache_v: torch.Tensor,
k_scale: torch.Tensor,
v_scale: torch.Tensor,
dtype: torch.dtype,
store_dtype: torch.dtype,
max_fp8: float,
min_fp8: float,
):
cache_k = cache_k / k_scale
cache_k = torch.clamp(cache_k, min_fp8, max_fp8)
cache_v = cache_v / v_scale
cache_v = torch.clamp(cache_v, min_fp8, max_fp8)
cache_k = cache_k.to(dtype)
cache_v = cache_v.to(dtype)
cache_k = cache_k.view(store_dtype)
cache_v = cache_v.view(store_dtype)
return cache_k, cache_v
# This compiled version is slower in the unit test
# python3 -m unittest test_bench_serving.TestBenchServing.test_offline_throughput_non_stream_small_batch_size
@torch.compile(dynamic=True, backend=get_compiler_backend())
def copy_two_array(loc, dst_1, src_1, dst_2, src_2, dtype, store_dtype):
dst_1[loc] = src_1.to(dtype).view(store_dtype)
dst_2[loc] = src_2.to(dtype).view(store_dtype)
@triton.jit
def set_mla_kv_buffer_kernel(
kv_buffer_ptr,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment