Unverified Commit 479dbb73 authored by Xiaowei Ren, committed by GitHub

add flash implementation with context parallelism (#362)



* add flash implementation with context parallelism
Signed-off-by: xren <xren@nvidia.com>

* next more comments
Signed-off-by: xren <xren@nvidia.com>

* code comment fix
Signed-off-by: xren <xren@nvidia.com>

* comment fix
Signed-off-by: xren <xren@nvidia.com>

* add missing space
Signed-off-by: xren <xren@nvidia.com>

* fix docstrings
Signed-off-by: xren <xren@nvidia.com>

* try to add fa v2 api
Signed-off-by: xren <xren@nvidia.com>

* fix a comment
Signed-off-by: xren <xren@nvidia.com>

* fix padded kv return
Signed-off-by: xren <xren@nvidia.com>

* add docstrings of context parallelism
Signed-off-by: xren <xren@nvidia.com>

* minor fix
Signed-off-by: xren <xren@nvidia.com>

* minor docstring fix
Signed-off-by: xren <xren@nvidia.com>

* fix positional arguments
Signed-off-by: xren <xren@nvidia.com>

* make docstring line shorter
Signed-off-by: xren <xren@nvidia.com>

* add fa v2 backward api for flash_attn_with_cp
Signed-off-by: xren <xren@nvidia.com>

* remove redundant code
Signed-off-by: xren <xren@nvidia.com>

* make sure hidden size per attn head is a multiple of 8 for FA2
Signed-off-by: xren <xren@nvidia.com>

* remove an unnecessary assert check for FA2
Signed-off-by: xren <xren@nvidia.com>

* indentation fix
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* Update FA version
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Lint
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: xren <xren@nvidia.com>
Signed-off-by: Xiaowei Ren <xren@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent b95c1818
@@ -290,7 +290,7 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]:
     # Framework-specific requirements
     if "pytorch" in frameworks():
-        add_unique(install_reqs, ["torch", "flash-attn>=1.0.6, <=2.2.1"])
+        add_unique(install_reqs, ["torch", "flash-attn>=1.0.6, <=2.2.2"])
         add_unique(test_reqs, ["numpy", "onnxruntime", "torchvision"])
     if "jax" in frameworks():
         if not found_pybind11():
...
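The hunk above bumps the `flash-attn` upper pin and registers it through an `add_unique` helper. As a hedged sketch (this is not the repository's actual implementation, just an illustration of the name's apparent contract), such a helper would append requirement strings only when no entry for the same package is already present:

```python
def add_unique(reqs, new_items):
    """Append requirement strings whose package name is not yet in reqs.

    Illustrative sketch only; the real helper in the project's setup.py
    may differ. Package name = text before any version specifier.
    """
    def pkg_name(req):
        return req.split(">")[0].split("<")[0].split("=")[0].strip().rstrip(",")

    for item in new_items:
        if not any(pkg_name(existing) == pkg_name(item) for existing in reqs):
            reqs.append(item)

install_reqs = ["torch"]
# "torch" is already present, so only the flash-attn pin is added.
add_unique(install_reqs, ["torch", "flash-attn>=1.0.6, <=2.2.2"])
print(install_reqs)  # ['torch', 'flash-attn>=1.0.6, <=2.2.2']
```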
This diff is collapsed.
@@ -427,6 +427,20 @@ class TransformerLayer(torch.nn.Module):
             if hasattr(child, "set_tensor_parallel_group"):
                 child.set_tensor_parallel_group(tp_group)
 
+    def set_context_parallel_running(
+        self,
+        cp_group: Union[dist_group_type, None],
+        cp_global_ranks: List[int],
+        cp_stream: torch.cuda.Stream,
+    ) -> None:
+        """Set CP group and CP dual-stream running"""
+        # Deep iterate but skip self to avoid infinite recursion.
+        for index, child in enumerate(self.modules()):
+            if index == 0:
+                continue
+            if hasattr(child, "set_context_parallel_running"):
+                child.set_context_parallel_running(cp_group, cp_global_ranks, cp_stream)
+
     def forward(
         self,
         hidden_states: torch.Tensor,
...
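The new `set_context_parallel_running` hunk uses a common PyTorch propagation pattern: the layer walks `self.modules()` (which yields itself first, then all descendants), skips itself at index 0, and forwards the context-parallel settings to any child exposing the same hook. A minimal dependency-free sketch of that pattern, using illustrative `Module`/`Attention`/`Layer` classes rather than the actual TransformerEngine ones:

```python
class Module:
    """Toy stand-in for torch.nn.Module, just enough for modules()."""

    def __init__(self, children=()):
        self._children = list(children)

    def modules(self):
        # Yield self first, then all descendants, mirroring
        # torch.nn.Module.modules() traversal order.
        yield self
        for child in self._children:
            yield from child.modules()


class Attention(Module):
    """Leaf that records the context-parallel settings it receives."""

    def __init__(self):
        super().__init__()
        self.cp_settings = None

    def set_context_parallel_running(self, cp_group, cp_global_ranks, cp_stream):
        self.cp_settings = (cp_group, cp_global_ranks, cp_stream)


class Layer(Module):
    def set_context_parallel_running(self, cp_group, cp_global_ranks, cp_stream):
        # Deep iterate but skip self (index 0) to avoid infinite recursion,
        # exactly as in the diff above.
        for index, child in enumerate(self.modules()):
            if index == 0:
                continue
            if hasattr(child, "set_context_parallel_running"):
                child.set_context_parallel_running(cp_group, cp_global_ranks, cp_stream)


attn = Attention()
layer = Layer(children=[attn])
# Placeholder strings stand in for the real process group and CUDA stream.
layer.set_context_parallel_running("cp_group", [0, 1], "cp_stream")
print(attn.cp_settings)  # ('cp_group', [0, 1], 'cp_stream')
```

Skipping index 0 matters because the parent defines the same hook as its children; without the skip, the parent would dispatch to itself and recurse forever.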