Unverified commit f7e62e3d, authored by bhargav-patel-29 and committed by GitHub
Browse files

[Bugfix] Fix mismatch between global and local attention heads in...


[Bugfix] Fix mismatch between global and local attention heads in tensor-parallel mode for param2moe model (#39707)
Signed-off-by: bhargav-patel-29 <bhargav.patel@tihiitb.org>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
parent 18b1c772
......@@ -16,7 +16,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
from __future__ import annotations
from collections.abc import Iterable, Iterator
......@@ -202,10 +201,10 @@ class Param2MoEAttention(nn.Module):
)
self.attn = Attention(
num_heads=self.num_heads,
num_heads=self.num_local_heads,
head_size=self.head_dim,
scale=self.scaling,
num_kv_heads=self.num_kv_heads,
num_kv_heads=self.num_local_kv_heads,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.attn",
......@@ -216,15 +215,15 @@ class Param2MoEAttention(nn.Module):
positions: torch.Tensor,
hidden_states: torch.Tensor,
) -> torch.Tensor:
# 1. Fused QKV projection → split into local Q / K / V
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split(
[self.q_size_local, self.kv_size_local, self.kv_size_local],
dim=-1,
)
q = q.contiguous()
k = k.contiguous()
v = v.contiguous()
# 2. Optional per-head QK norms
# Reshape to (T, num_local_heads, head_dim), norm, reshape back.
if self.use_qk_norm:
T = q.shape[0]
q = self.q_layernorm(q.view(T, self.num_local_heads, self.head_dim)).view(
......@@ -234,13 +233,8 @@ class Param2MoEAttention(nn.Module):
k.view(T, self.num_local_kv_heads, self.head_dim)
).view(T, self.kv_size_local)
# 3. Rotary position embeddings
q, k = self.rotary_emb(positions, q, k)
# 4. Paged attention
attn_output = self.attn(q, k, v)
# 5. Output projection
output, _ = self.o_proj(attn_output)
return output
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment