"docs/vscode:/vscode.git/clone" did not exist on "40d0e7411dbeb276befd33c4485115ac3d4d7f2a"
Unverified Commit f7e62e3d authored by bhargav-patel-29's avatar bhargav-patel-29 Committed by GitHub
Browse files

[Bugfix] Fix mismatch between global and local attention heads in...


[Bugfix] Fix mismatch between global and local attention heads in tensor-parallel mode for param2moe model (#39707)
Signed-off-by: default avatarbhargav-patel-29 <bhargav.patel@tihiitb.org>
Co-authored-by: default avatargemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
parent 18b1c772
...@@ -16,7 +16,6 @@ ...@@ -16,7 +16,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# limitations under the License.
from __future__ import annotations from __future__ import annotations
from collections.abc import Iterable, Iterator from collections.abc import Iterable, Iterator
...@@ -202,10 +201,10 @@ class Param2MoEAttention(nn.Module): ...@@ -202,10 +201,10 @@ class Param2MoEAttention(nn.Module):
) )
self.attn = Attention( self.attn = Attention(
num_heads=self.num_heads, num_heads=self.num_local_heads,
head_size=self.head_dim, head_size=self.head_dim,
scale=self.scaling, scale=self.scaling,
num_kv_heads=self.num_kv_heads, num_kv_heads=self.num_local_kv_heads,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.attn", prefix=f"{prefix}.attn",
...@@ -216,15 +215,15 @@ class Param2MoEAttention(nn.Module): ...@@ -216,15 +215,15 @@ class Param2MoEAttention(nn.Module):
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
) -> torch.Tensor: ) -> torch.Tensor:
# 1. Fused QKV projection → split into local Q / K / V
qkv, _ = self.qkv_proj(hidden_states) qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split( q, k, v = qkv.split(
[self.q_size_local, self.kv_size_local, self.kv_size_local], [self.q_size_local, self.kv_size_local, self.kv_size_local],
dim=-1, dim=-1,
) )
q = q.contiguous()
k = k.contiguous()
v = v.contiguous()
# 2. Optional per-head QK norms
# Reshape to (T, num_local_heads, head_dim), norm, reshape back.
if self.use_qk_norm: if self.use_qk_norm:
T = q.shape[0] T = q.shape[0]
q = self.q_layernorm(q.view(T, self.num_local_heads, self.head_dim)).view( q = self.q_layernorm(q.view(T, self.num_local_heads, self.head_dim)).view(
...@@ -234,13 +233,8 @@ class Param2MoEAttention(nn.Module): ...@@ -234,13 +233,8 @@ class Param2MoEAttention(nn.Module):
k.view(T, self.num_local_kv_heads, self.head_dim) k.view(T, self.num_local_kv_heads, self.head_dim)
).view(T, self.kv_size_local) ).view(T, self.kv_size_local)
# 3. Rotary position embeddings
q, k = self.rotary_emb(positions, q, k) q, k = self.rotary_emb(positions, q, k)
# 4. Paged attention
attn_output = self.attn(q, k, v) attn_output = self.attn(q, k, v)
# 5. Output projection
output, _ = self.o_proj(attn_output) output, _ = self.o_proj(attn_output)
return output return output
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment