Unverified Commit 5c3bae1a authored by ant-yy's avatar ant-yy Committed by GitHub
Browse files

[Fix] Remove divisibility requirement between num_kv_heads and tp_size in bailing_moe (#26876)


Signed-off-by: default avatarvito.yy <vito.yy@antgroup.com>
parent 5210dc39
...@@ -86,13 +86,12 @@ class BailingAttention(nn.Module): ...@@ -86,13 +86,12 @@ class BailingAttention(nn.Module):
tp_size = get_tensor_model_parallel_world_size() tp_size = get_tensor_model_parallel_world_size()
assert self.total_num_heads % tp_size == 0 assert self.total_num_heads % tp_size == 0
assert self.total_kv_heads % tp_size == 0
assert self.total_num_heads >= self.total_kv_heads assert self.total_num_heads >= self.total_kv_heads
self.num_heads = self.total_num_heads // tp_size self.num_heads = self.total_num_heads // tp_size
self.head_dim = config.head_dim or (self.hidden_size // self.total_num_heads) self.head_dim = config.head_dim or (self.hidden_size // self.total_num_heads)
self.q_size_per_rank = self.head_dim * self.num_heads self.q_size_per_rank = self.head_dim * self.num_heads
self.num_kv_heads = self.total_kv_heads // tp_size self.num_kv_heads = max(1, self.total_kv_heads // tp_size)
self.kv_size_per_rank = self.num_kv_heads * self.head_dim self.kv_size_per_rank = self.num_kv_heads * self.head_dim
self.scale = self.head_dim**-0.5 self.scale = self.head_dim**-0.5
self.use_qk_norm = getattr(config, "use_qk_norm", False) self.use_qk_norm = getattr(config, "use_qk_norm", False)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment