Unverified Commit 5fbafbb8 authored by huangtingwei's avatar huangtingwei Committed by GitHub
Browse files

fix MLATokenToKVPoolHost get_size_per_token bug (#5161)


Co-authored-by: default avatarAniZpZ <zhuangsen.zp@antgroup.com>
parent a9499885
......@@ -879,7 +879,12 @@ class MLATokenToKVPoolHost(HostKVCache):
self.qk_rope_head_dim = self.device_pool.qk_rope_head_dim
self.layer_num = self.device_pool.layer_num
return (self.kv_lora_rank + self.qk_rope_head_dim) * 1 * self.dtype.itemsize
return (
(self.kv_lora_rank + self.qk_rope_head_dim)
* 1
* self.dtype.itemsize
* self.layer_num
)
def init_kv_buffer(self):
return torch.empty(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment