Unverified Commit 99236852 authored by yjz's avatar yjz Committed by GitHub
Browse files

[KVTransfer] Fix TpKVTopology.is_kv_replicated equality case (#38179)


Signed-off-by: default avatarJianDan0212 <zhangyj0212@gmail.com>
Co-authored-by: default avatarNicolò Lucchesi <nlucches@redhat.com>
parent 58ee6142
...@@ -457,9 +457,11 @@ class TpKVTopology: ...@@ -457,9 +457,11 @@ class TpKVTopology:
""" """
Whether the KV cache is replicated across TP workers due to the Whether the KV cache is replicated across TP workers due to the
number of TP workers being greater than the number of KV heads. number of TP workers being greater than the number of KV heads.
When they are equal, each TP rank still owns one distinct KV head,
so this is not considered replication.
""" """
tp_size = self.remote_tp_size[engine_id] tp_size = self.remote_tp_size[engine_id]
return tp_size // self.total_num_kv_heads >= 1 return tp_size > self.total_num_kv_heads
def replicates_kv_cache(self, remote_engine_id: EngineId) -> bool: def replicates_kv_cache(self, remote_engine_id: EngineId) -> bool:
# MLA is always replicated as the hidden dim can't be split. # MLA is always replicated as the hidden dim can't be split.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment