Unverified Commit 2f6af1a3 authored by Yuhong Guo's avatar Yuhong Guo Committed by GitHub
Browse files

Enable bailing_moe to support TP=16 (#12369)

parent 50b6842b
...@@ -420,14 +420,21 @@ class BailingMoEAttention(nn.Module): ...@@ -420,14 +420,21 @@ class BailingMoEAttention(nn.Module):
attn_tp_size = get_attention_tp_size() attn_tp_size = get_attention_tp_size()
assert self.total_num_heads % attn_tp_size == 0 assert self.total_num_heads % attn_tp_size == 0
assert self.total_kv_heads % attn_tp_size == 0 if self.total_kv_heads >= attn_tp_size:
# Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert self.total_kv_heads % attn_tp_size == 0
else:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert attn_tp_size % self.total_kv_heads == 0
assert self.total_num_heads >= self.total_kv_heads assert self.total_num_heads >= self.total_kv_heads
self.num_heads = self.total_num_heads // attn_tp_size self.num_heads = self.total_num_heads // attn_tp_size
self.head_dim = config.head_dim or (self.hidden_size // self.total_num_heads) self.head_dim = config.head_dim or (self.hidden_size // self.total_num_heads)
self.q_size = self.head_dim * self.num_heads self.q_size = self.head_dim * self.num_heads
self.num_kv_heads = self.total_kv_heads // attn_tp_size self.num_kv_heads = max(1, self.total_kv_heads // attn_tp_size)
self.kv_size = max(1, self.num_kv_heads * self.head_dim) self.kv_size = max(1, self.num_kv_heads * self.head_dim)
self.scale = self.head_dim**-0.5 self.scale = self.head_dim**-0.5
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment