Unverified commit 69d19188, authored by Liangsheng Yin, committed by GitHub

Decouple kv (#679)

parent 4b4a67f8
@@ -99,7 +99,7 @@ class RadixAttention(nn.Module):
         else:
             o2, s2 = input_metadata.flashinfer_prefill_wrapper_paged.forward_return_lse(
                 q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
-                input_metadata.token_to_kv_pool.kv_data[self.layer_id],
+                input_metadata.token_to_kv_pool.get_kv_buffer(self.layer_id),
                 causal=False,
                 sm_scale=self.scaling,
                 logits_soft_cap=self.logit_cap,

@@ -119,7 +119,7 @@ class RadixAttention(nn.Module):
         o = input_metadata.flashinfer_decode_wrapper.forward(
             q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
-            input_metadata.token_to_kv_pool.kv_data[self.layer_id],
+            input_metadata.token_to_kv_pool.get_kv_buffer(self.layer_id),
             sm_scale=self.scaling,
            logits_soft_cap=self.logit_cap,
         )

@@ -136,33 +136,7 @@ class RadixAttention(nn.Module):
             return self.decode_forward(q, k, v, input_metadata)
 
     def store_kv_cache(self, cache_k, cache_v, input_metadata: InputMetadata):
-        kv_cache = input_metadata.token_to_kv_pool.kv_data[self.layer_id]
-        _store_kv_cache(cache_k, cache_v, kv_cache, input_metadata.out_cache_loc)
+        k_cache = input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id)
+        v_cache = input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id)
+        k_cache[input_metadata.out_cache_loc] = cache_k
+        v_cache[input_metadata.out_cache_loc] = cache_v
-
-
-try:
-
-    @torch.library.custom_op("mylib::store_kv_cache", mutates_args={"kv_cache"})
-    def _store_kv_cache(
-        k: torch.Tensor,
-        v: torch.Tensor,
-        kv_cache: torch.Tensor,
-        cache_loc: torch.Tensor,
-    ) -> None:
-        kv_cache[cache_loc, 0] = k
-        kv_cache[cache_loc, 1] = v
-
-    @_store_kv_cache.register_fake
-    def _(k, v, kv_cache, cache_loc):
-        pass
-
-except:
-
-    def _store_kv_cache(
-        k: torch.Tensor,
-        v: torch.Tensor,
-        kv_cache: torch.Tensor,
-        cache_loc: torch.Tensor,
-    ) -> None:
-        kv_cache[cache_loc, 0] = k
-        kv_cache[cache_loc, 1] = v

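With separate key and value buffers, storing a new token's entries reduces to two ordinary advanced-indexing writes, so the module-level `_store_kv_cache` helper (and its `torch.library.custom_op` registration) is removed. A self-contained sketch of what the new `store_kv_cache` body does; slot indices and shapes below are invented for illustration:

    import torch

    num_slots, head_num, head_dim = 16, 4, 8
    k_cache = torch.zeros(num_slots, head_num, head_dim)
    v_cache = torch.zeros(num_slots, head_num, head_dim)

    # Slots assigned to the new tokens by the allocator (hypothetical values).
    out_cache_loc = torch.tensor([3, 7, 12])
    cache_k = torch.randn(3, head_num, head_dim)
    cache_v = torch.randn(3, head_num, head_dim)

    # Equivalent to the new store_kv_cache body: plain scatter writes by row index.
    k_cache[out_cache_loc] = cache_k
    v_cache[out_cache_loc] = cache_v
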
@@ -57,9 +57,13 @@ class TokenToKVPool:
         # We also add one slot. This slot is used for writing dummy output from padded tokens.
         self.mem_state = torch.ones((self.size + 1,), dtype=torch.bool, device="cuda")
 
-        # [size, key/value, head_num, head_dim] for each layer
-        self.kv_data = [
-            torch.empty((size + 1, 2, head_num, head_dim), dtype=dtype, device="cuda")
+        # [size, head_num, head_dim] for each layer
+        self.k_buffer = [
+            torch.empty((size + 1, head_num, head_dim), dtype=dtype, device="cuda")
+            for _ in range(layer_num)
+        ]
+        self.v_buffer = [
+            torch.empty((size + 1, head_num, head_dim), dtype=dtype, device="cuda")
             for _ in range(layer_num)
         ]

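Splitting the pool into `k_buffer` and `v_buffer` does not change the total allocation: it is still one key entry and one value entry per slot, per layer. Rough arithmetic for a hypothetical configuration (the numbers below are made up, not taken from any real model or from this patch):

    import torch

    # Hypothetical pool parameters.
    size, head_num, head_dim, layer_num = 100_000, 8, 128, 32
    dtype = torch.float16

    bytes_per_entry = head_num * head_dim * torch.finfo(dtype).bits // 8
    # K buffer + V buffer, (size + 1) slots each, for every layer.
    total_bytes = 2 * layer_num * (size + 1) * bytes_per_entry
    print(f"KV pool size: {total_bytes / 1e9:.2f} GB")  # ~13 GB for these made-up numbers
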
@@ -71,10 +75,13 @@ class TokenToKVPool:
         self.clear()
 
     def get_key_buffer(self, layer_id: int):
-        return self.kv_data[layer_id][:, 0]
+        return self.k_buffer[layer_id]
 
     def get_value_buffer(self, layer_id: int):
-        return self.kv_data[layer_id][:, 1]
+        return self.v_buffer[layer_id]
+
+    def get_kv_buffer(self, layer_id: int):
+        return self.k_buffer[layer_id], self.v_buffer[layer_id]
 
     def available_size(self):
         return self.can_use_mem_size + len(self.prefetch_buffer)

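One practical consequence of the decoupling: the old accessors returned strided slices of the fused tensor, whereas the new ones return whole, contiguous tensors. A quick check, with shapes invented for illustration:

    import torch

    size, head_num, head_dim = 4, 2, 8

    # Old fused layout: slicing out K yields a non-contiguous view.
    kv_data = torch.empty(size + 1, 2, head_num, head_dim)
    print(kv_data[:, 0].is_contiguous())   # False

    # New split layout: the key buffer is a contiguous tensor in its own right.
    k_buffer = torch.empty(size + 1, head_num, head_dim)
    print(k_buffer.is_contiguous())        # True
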
@@ -182,7 +182,7 @@ def launch_server(
     if not server_args.disable_flashinfer:
         assert_pkg_version(
             "flashinfer",
-            "0.1.0",
+            "0.1.1",
             "Please uninstall the old version and "
             "reinstall the latest version by following the instructions "
             "at https://docs.flashinfer.ai/installation.html.",
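The required flashinfer version moves from 0.1.0 to 0.1.1 so the installed wheel matches the split-KV call sites above. For illustration only, a minimum-version guard in the spirit of `assert_pkg_version` could look like the sketch below; this is a hypothetical helper, not sglang's actual implementation, and it assumes the `packaging` package is available:

    from importlib.metadata import PackageNotFoundError, version

    from packaging.version import parse


    def check_min_version(name: str, min_version: str, hint: str) -> None:
        """Raise if `name` is missing or older than `min_version` (illustrative only)."""
        try:
            installed = version(name)
        except PackageNotFoundError:
            raise RuntimeError(f"{name} is not installed. {hint}")
        if parse(installed) < parse(min_version):
            raise RuntimeError(
                f"{name}=={installed} is too old (need >= {min_version}). {hint}"
            )


    # Example usage mirroring the call in launch_server:
    # check_min_version(
    #     "flashinfer",
    #     "0.1.1",
    #     "Please reinstall per https://docs.flashinfer.ai/installation.html.",
    # )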