Unverified commit ed1e85c4, authored by Xiaowei Ren and committed by GitHub

Add missed arguments of apply_rotary_pos_emb in MHA (#1296)



* add missed arguments of apply_rotary_pos_emb in MHA
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* remove an unnecessary f
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* add one more assert for cp_group len
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

---------
Signed-off-by: Xiaowei Ren <xren@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 8bdb54fe
@@ -8495,6 +8495,8 @@ class MultiheadAttention(torch.nn.Module):
         self.params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype
         self.num_attention_heads = num_attention_heads
         self.return_bias = return_bias
+        self.cp_size = 1
+        self.cp_rank = 0
         kv_channels = kv_channels if kv_channels else (hidden_size // num_attention_heads)
@@ -8713,6 +8715,21 @@ class MultiheadAttention(torch.nn.Module):
         across each CP sub-group (e.g., via NVLink), then exchanging KV with
         p2p between sub-groups (e.g., via IBLink).
         """
+        if isinstance(cp_group, dist_group_type):
+            self.cp_size = get_distributed_world_size(cp_group)
+            self.cp_rank = get_distributed_rank(cp_group)
+        elif isinstance(cp_group, list):
+            assert len(cp_group) == 2, "Current implementation only supports two-level CP groups!"
+            assert (
+                cp_comm_type == "a2a+p2p"
+            ), "Only cp_comm_type of a2a+p2p requires hierarchical CP groups!"
+            cp_size_a2a = get_distributed_world_size(cp_group[0])
+            cp_rank_a2a = get_distributed_rank(cp_group[0])
+            cp_size_p2p = get_distributed_world_size(cp_group[1])
+            cp_rank_p2p = get_distributed_rank(cp_group[1])
+            self.cp_size = cp_size_a2a * cp_size_p2p
+            self.cp_rank = cp_size_a2a * cp_rank_p2p + cp_rank_a2a
+
         # Deep iterate but skip self to avoid infinite recursion.
         for index, child in enumerate(self.modules()):
             if index == 0:
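A note on the hunk above: when a two-element list of process groups is passed (the hierarchical "a2a+p2p" scheme), the global CP size is the product of the two sub-group sizes and the global CP rank is composed as `cp_rank = cp_size_a2a * cp_rank_p2p + cp_rank_a2a`, so ranks within one a2a sub-group are numbered contiguously. The snippet below is only a standalone sanity check of that arithmetic with made-up sub-group sizes; it does not use Transformer Engine or torch.distributed.

```python
# Sanity-check the composed CP rank/size from the diff above, using
# hypothetical sub-group sizes: 2 ranks per a2a (NVLink) sub-group,
# 4 sub-groups connected via p2p (IBLink).
cp_size_a2a, cp_size_p2p = 2, 4
cp_size = cp_size_a2a * cp_size_p2p

composed = [
    cp_size_a2a * cp_rank_p2p + cp_rank_a2a
    for cp_rank_p2p in range(cp_size_p2p)
    for cp_rank_a2a in range(cp_size_a2a)
]

# Every (cp_rank_a2a, cp_rank_p2p) pair maps to a unique global CP rank.
assert sorted(composed) == list(range(cp_size))
print(composed)  # [0, 1, 2, 3, 4, 5, 6, 7]
```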
@@ -9047,8 +9064,24 @@ class MultiheadAttention(torch.nn.Module):
                 q_pos_emb = q_pos_emb[sequence_start:sequence_end, ...]
                 k_pos_emb = k_pos_emb[sequence_start:sequence_end, ...]
-            query_layer = apply_rotary_pos_emb(query_layer, q_pos_emb, self.qkv_format, fused=True)
-            key_layer = apply_rotary_pos_emb(key_layer, k_pos_emb, self.qkv_format, fused=True)
+            query_layer = apply_rotary_pos_emb(
+                query_layer,
+                q_pos_emb,
+                self.qkv_format,
+                fused=True,
+                cu_seqlens=cu_seqlens_q,
+                cp_size=self.cp_size,
+                cp_rank=self.cp_rank,
+            )
+            key_layer = apply_rotary_pos_emb(
+                key_layer,
+                k_pos_emb,
+                self.qkv_format,
+                fused=True,
+                cu_seqlens=cu_seqlens_kv,
+                cp_size=self.cp_size,
+                cp_rank=self.cp_rank,
+            )

         # ===========================
         # Core attention computation
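Why the extra arguments matter: under context parallelism each rank holds only a slice of the full sequence, so rotary embeddings must be applied with that slice's global positions rather than positions starting at 0; `cp_size` and `cp_rank` (plus `cu_seqlens` for the THD layout) let `apply_rotary_pos_emb` pick the right offsets. The sketch below only illustrates position selection under the load-balanced layout commonly used for context-parallel attention, where the sequence is split into `2*cp_size` chunks and rank `i` keeps chunks `i` and `2*cp_size-1-i`; that layout is an assumption here, and the real indexing happens inside the fused kernel, not in user code.

```python
# Illustration only (not Transformer Engine code): which global RoPE positions
# a CP rank's local tokens correspond to, assuming the load-balanced layout
# where the sequence is split into 2*cp_size chunks and rank i keeps chunks
# i and 2*cp_size-1-i.
def rope_positions_for_cp_rank(total_seqlen, cp_size, cp_rank):
    chunk = total_seqlen // (2 * cp_size)
    first = range(cp_rank * chunk, (cp_rank + 1) * chunk)
    second = range((2 * cp_size - 1 - cp_rank) * chunk, (2 * cp_size - cp_rank) * chunk)
    return list(first) + list(second)

# With cp_size=2 and a length-8 sequence:
print(rope_positions_for_cp_rank(8, cp_size=2, cp_rank=0))  # [0, 1, 6, 7]
print(rope_positions_for_cp_rank(8, cp_size=2, cp_rank=1))  # [2, 3, 4, 5]
# Without cp_size/cp_rank, every rank would rotate its tokens with positions
# [0, 1, 2, 3], which is only correct when cp_size == 1.
```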