Unverified Commit 2574a1ca authored by Xiaowei Ren, committed by GitHub

move cp_group setting to DotProductAttention (#468)



* rename set_context_parallel_running to set_context_parallel_group
Signed-off-by: xren <xren@nvidia.com>

* bug fix
Signed-off-by: xren <xren@nvidia.com>

---------
Signed-off-by: xren <xren@nvidia.com>
parent d7511ec4
@@ -1914,6 +1914,17 @@ class DotProductAttention(torch.nn.Module):
         return hidden_states
 
+    def set_context_parallel_group(
+        self,
+        cp_group: Union[dist_group_type, None],
+        cp_global_ranks: List[int],
+        cp_stream: torch.cuda.Stream,
+    ) -> None:
+        """Set CP group"""
+        self.cp_group = cp_group
+        self.cp_global_ranks = cp_global_ranks
+        self.cp_stream = cp_stream
+
     def forward(
         self,
         query_layer: torch.Tensor,
@@ -2549,16 +2560,19 @@ class MultiheadAttention(torch.nn.Module):
         """Set TP group"""
         self.tp_group = tp_group
 
-    def set_context_parallel_running(
+    def set_context_parallel_group(
         self,
         cp_group: Union[dist_group_type, None],
         cp_global_ranks: List[int],
         cp_stream: torch.cuda.Stream,
     ) -> None:
-        """Set CP group and CP dual-stream running"""
-        self.core_attention.cp_group = cp_group
-        self.core_attention.cp_global_ranks = cp_global_ranks
-        self.core_attention.cp_stream = cp_stream
+        """Set CP group"""
+        # Deep iterate but skip self to avoid infinite recursion.
+        for index, child in enumerate(self.modules()):
+            if index == 0:
+                continue
+            if hasattr(child, "set_context_parallel_group"):
+                child.set_context_parallel_group(cp_group, cp_global_ranks, cp_stream)
 
     def forward(
         self,
...
@@ -433,19 +433,19 @@ class TransformerLayer(torch.nn.Module):
             if hasattr(child, "set_tensor_parallel_group"):
                 child.set_tensor_parallel_group(tp_group)
 
-    def set_context_parallel_running(
+    def set_context_parallel_group(
         self,
         cp_group: Union[dist_group_type, None],
         cp_global_ranks: List[int],
         cp_stream: torch.cuda.Stream,
     ) -> None:
-        """Set CP group and CP dual-stream running"""
+        """Set CP group"""
         # Deep iterate but skip self to avoid infinite recursion.
         for index, child in enumerate(self.modules()):
             if index == 0:
                 continue
-            if hasattr(child, "set_context_parallel_running"):
-                child.set_context_parallel_running(cp_group, cp_global_ranks, cp_stream)
+            if hasattr(child, "set_context_parallel_group"):
+                child.set_context_parallel_group(cp_group, cp_global_ranks, cp_stream)
 
     def forward(
         self,
...
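
For reference, a minimal standalone sketch of the propagation pattern this commit adopts: a parent module iterates over its submodules and forwards the CP settings to any child that defines set_context_parallel_group, skipping itself to avoid infinite recursion. The module names (CoreAttention, AttentionBlock) are illustrative only and are not part of Transformer Engine.

from typing import List, Optional

import torch


class CoreAttention(torch.nn.Module):
    """Leaf module that actually stores the context-parallel (CP) settings."""

    def set_context_parallel_group(
        self,
        cp_group: Optional[object],           # a torch.distributed ProcessGroup in practice
        cp_global_ranks: List[int],
        cp_stream: Optional["torch.cuda.Stream"],
    ) -> None:
        self.cp_group = cp_group
        self.cp_global_ranks = cp_global_ranks
        self.cp_stream = cp_stream


class AttentionBlock(torch.nn.Module):
    """Parent that forwards CP settings to every child exposing the hook."""

    def __init__(self) -> None:
        super().__init__()
        self.core_attention = CoreAttention()

    def set_context_parallel_group(self, cp_group, cp_global_ranks, cp_stream) -> None:
        # Deep iterate but skip self (index 0) to avoid infinite recursion.
        for index, child in enumerate(self.modules()):
            if index == 0:
                continue
            if hasattr(child, "set_context_parallel_group"):
                child.set_context_parallel_group(cp_group, cp_global_ranks, cp_stream)


# Hypothetical usage: no CUDA or process group is required for this sketch.
block = AttentionBlock()
block.set_context_parallel_group(cp_group=None, cp_global_ranks=[0, 1], cp_stream=None)
assert block.core_attention.cp_global_ranks == [0, 1]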