"git@developer.sourcefind.cn:kecinstone/2024-pra-vllm.git" did not exist on "c1c0d00b88320f97e00a3175fac235a232893da5"
Unverified commit 2574a1ca, authored by Xiaowei Ren and committed by GitHub
Browse files

move cp_group setting to DotProductAttention (#468)



* rename set_context_parallel_running to set_context_parallel_group
Signed-off-by: xren <xren@nvidia.com>

* bug fix
Signed-off-by: xren <xren@nvidia.com>

---------
Signed-off-by: xren <xren@nvidia.com>
parent d7511ec4
......@@ -1914,6 +1914,17 @@ class DotProductAttention(torch.nn.Module):
return hidden_states
def set_context_parallel_group(
    self,
    cp_group: Union[dist_group_type, None],
    cp_global_ranks: List[int],
    cp_stream: torch.cuda.Stream,
) -> None:
    """Set CP group.

    Stores the three context-parallel (CP) settings on this module as
    ``self.cp_group``, ``self.cp_global_ranks`` and ``self.cp_stream``.
    The values are saved exactly as given; nothing is validated or copied
    here.

    Args:
        cp_group: Process group handle used for context parallelism, or
            ``None``.
        cp_global_ranks: Integer ranks that go with ``cp_group``
            (presumably the global ranks of its members -- confirm
            against the caller).
        cp_stream: CUDA stream kept for CP work; how it is consumed is
            not visible in this method.
    """
    self.cp_group = cp_group
    self.cp_global_ranks = cp_global_ranks
    self.cp_stream = cp_stream
def forward(
self,
query_layer: torch.Tensor,
......@@ -2549,16 +2560,19 @@ class MultiheadAttention(torch.nn.Module):
"""Set TP group"""
self.tp_group = tp_group
def set_context_parallel_group(
    self,
    cp_group: Union[dist_group_type, None],
    cp_global_ranks: List[int],
    cp_stream: torch.cuda.Stream,
) -> None:
    """Set CP group.

    Forwards the context-parallel (CP) settings to every descendant
    module that defines ``set_context_parallel_group``. This method
    stores nothing on ``self``; descendants without that method are
    left untouched.

    Args:
        cp_group: Process group handle used for context parallelism, or
            ``None``.
        cp_global_ranks: Integer ranks that go with ``cp_group``.
        cp_stream: CUDA stream passed through unchanged to each
            descendant.
    """
    # Deep iterate but skip self to avoid infinite recursion: the first
    # item produced by ``self.modules()`` is this module itself, which
    # also has ``set_context_parallel_group``.
    for index, child in enumerate(self.modules()):
        if index == 0:
            continue
        if hasattr(child, "set_context_parallel_group"):
            child.set_context_parallel_group(cp_group, cp_global_ranks, cp_stream)
def forward(
self,
......
......@@ -433,19 +433,19 @@ class TransformerLayer(torch.nn.Module):
if hasattr(child, "set_tensor_parallel_group"):
child.set_tensor_parallel_group(tp_group)
def set_context_parallel_group(
    self,
    cp_group: Union[dist_group_type, None],
    cp_global_ranks: List[int],
    cp_stream: torch.cuda.Stream,
) -> None:
    """Set CP group.

    Passes the context-parallel (CP) settings down to every descendant
    module that exposes ``set_context_parallel_group``. Descendants
    that do not define the method are skipped, and nothing is stored on
    ``self``.

    Args:
        cp_group: Process group handle used for context parallelism, or
            ``None``.
        cp_global_ranks: Integer ranks that go with ``cp_group``.
        cp_stream: CUDA stream handed unchanged to each descendant.
    """
    # Deep iterate but skip self to avoid infinite recursion: index 0 of
    # ``self.modules()`` is this module, which defines the same method.
    for index, child in enumerate(self.modules()):
        if index == 0:
            continue
        if hasattr(child, "set_context_parallel_group"):
            child.set_context_parallel_group(cp_group, cp_global_ranks, cp_stream)
def forward(
self,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment