"include/git@developer.sourcefind.cn:yangql/googletest.git" did not exist on "b0d1c08ff2dc98a1f2ed395ccfed5585ff048d2b"
Unverified commit 0828aa86, authored by Xiaowei Ren, committed by GitHub

Remove CPU overheads of torch.cuda.get_device_properties() by caching it (#1722)



* build pybind of sm_arch in TE-Pytorch
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* check sm_arch for batch_p2p_comm in CP+P2P
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix device compute capability of pytorch tests
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* bug fix
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Revert "fix device compute capability of pytorch tests"

This reverts commit 85886eb35dcf57a37ddc98a13d283f7a6d8f8e32.

* revert changes and resolve conflict
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* Revert "bug fix"

This reverts commit dd75c64c62e882ee5e3b54591b86f89c349ad3b0.

* manually revert changes
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* cache torch.cuda.get_device_properties
Signed-off-by: Xiaowei Ren <xren@nvidia.com>

---------
Signed-off-by: Xiaowei Ren <xren@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 02096f61
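
For context on the commit title: every call to torch.cuda.get_device_properties() crosses into the CUDA driver, so calling it per layer or per step adds avoidable CPU time even though the answer never changes for a given device. Below is a minimal sketch of the caching idea, not part of the commit itself; the helper name cached_props and the benchmark loop are illustrative only.

import functools
import time

import torch


@functools.lru_cache
def cached_props(device: int):
    # First call per device queries the CUDA driver; later calls are
    # plain dictionary lookups inside lru_cache.
    return torch.cuda.get_device_properties(device)


def bench(fn, n: int = 10_000) -> float:
    # Time n repeated property lookups for the current device.
    dev = torch.cuda.current_device()
    start = time.perf_counter()
    for _ in range(n):
        fn(dev)
    return time.perf_counter() - start


if torch.cuda.is_available():
    print(f"uncached: {bench(torch.cuda.get_device_properties):.3f}s")
    print(f"cached:   {bench(cached_props):.3f}s")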
@@ -13,6 +13,7 @@ from transformer_engine.pytorch.utils import (
     get_cudnn_version,
     nvtx_range_pop,
     nvtx_range_push,
+    get_device_compute_capability,
 )
 from transformer_engine.pytorch.cpp_extensions.fused_attn import (
     fused_attn_fwd,
@@ -476,7 +477,10 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
         rank = get_distributed_rank(cp_group)
         send_dst = cp_global_ranks[(rank + 1) % cp_size * cp_size_a2a + rank_a2a]
         recv_src = cp_global_ranks[(rank - 1) % cp_size * cp_size_a2a + rank_a2a]
-        batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0"))
+        device_compute_capability = get_device_compute_capability()
+        batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0")) or (
+            device_compute_capability < (10, 0) and cp_size == 2
+        )
         causal = "causal" in attn_mask_type
         padding = "padding" in attn_mask_type
@@ -1436,7 +1440,10 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
         rank = get_distributed_rank(ctx.cp_group)
         send_dst = ctx.cp_global_ranks[(rank - 1) % cp_size * cp_size_a2a + rank_a2a]
         recv_src = ctx.cp_global_ranks[(rank + 1) % cp_size * cp_size_a2a + rank_a2a]
-        batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0"))
+        device_compute_capability = get_device_compute_capability()
+        batch_p2p_comm = int(os.getenv("NVTE_BATCH_MHA_P2P_COMM", "0")) or (
+            device_compute_capability < (10, 0) and cp_size == 2
+        )
         q, kv, out, softmax_lse, cu_seqlens_q_padded, cu_seqlens_kv_padded, *other_tensors = (
             restore_from_saved(ctx.tensor_objects, ctx.saved_tensors)
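
In both hunks above, batch_p2p_comm now turns on by default when the device's compute capability is below (10, 0) and the context-parallel ring has exactly two ranks, in addition to the explicit NVTE_BATCH_MHA_P2P_COMM override. Downstream, the flag chooses between torch.distributed.batch_isend_irecv and separate isend/irecv calls. The sketch below illustrates the two paths; the exchange() helper, buffer names, and use of the default process group are illustrative, not the commit's code, and it assumes torch.distributed has been initialized (e.g. with the NCCL backend).

import torch
import torch.distributed as dist


def exchange(send_buf, recv_buf, send_dst, recv_src, batch_p2p_comm):
    if batch_p2p_comm:
        # Batched flavor: both transfers are handed to the backend in one
        # grouped call, saving CPU launch overhead and sidestepping
        # send/recv ordering hazards when two ranks exchange with each other.
        ops = [
            dist.P2POp(dist.isend, send_buf, send_dst),
            dist.P2POp(dist.irecv, recv_buf, recv_src),
        ]
        reqs = dist.batch_isend_irecv(ops)
    else:
        # Unbatched flavor: independent non-blocking send and recv.
        reqs = [dist.isend(send_buf, send_dst), dist.irecv(recv_buf, recv_src)]
    for req in reqs:
        req.wait()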
@@ -41,10 +41,15 @@ def clear_tensor_data(*tensors: Tuple[Optional[torch.Tensor], ...]) -> None:
         del t


+@functools.lru_cache
+def _get_device_compute_capability(device: torch.device) -> Tuple[int, int]:
+    props = torch.cuda.get_device_properties(device)
+    return (props.major, props.minor)
+
+
 def get_device_compute_capability() -> Tuple[int, int]:
     """CUDA compute capability of current GPU"""
-    props = torch.cuda.get_device_properties(torch.cuda.current_device())
-    return (props.major, props.minor)
+    return _get_device_compute_capability(torch.cuda.current_device())


 def attention_mask_func(
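
Because functools.lru_cache keys on its argument, the cache holds one entry per device index, so each GPU triggers exactly one driver query and all later lookups are cache hits. A standalone sanity check of the cached helper from the hunk above, assuming a CUDA-enabled PyTorch build (the script itself is not part of the commit):

import functools
from typing import Tuple

import torch


@functools.lru_cache
def _get_device_compute_capability(device) -> Tuple[int, int]:
    # Same shape as the helper in the diff; torch.cuda.current_device()
    # returns an int index, which is hashable and cacheable.
    props = torch.cuda.get_device_properties(device)
    return (props.major, props.minor)


if torch.cuda.is_available():
    dev = torch.cuda.current_device()
    print(_get_device_compute_capability(dev))  # e.g. (9, 0) on Hopper
    _get_device_compute_capability(dev)         # served from the cache
    print(_get_device_compute_capability.cache_info())  # hits=1, misses=1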