Unverified Commit 18da4e88 authored by Kirthi Shankar Sivamani, committed by GitHub

TP communication overlap with userbuffers (#147)



* Port initial changes
Co-authored-by: Sangkug Lym <slym@nvidia.com>
Co-authored-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* readd FA include for PyTorch
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Re-enable sm_70 + cleanup
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* LICENSE, cleanup header
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* 5k -> 173 errors
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* license and fixes in userbuffers-host
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* next round fixes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* final cpp cleanup
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* pylinting
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix from linting
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Turn off default async amax reduction (#148)
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* remove unused code path
Signed-off-by: Sangkug Lym <slym@nvidia.com>

* cleanup Macros
Signed-off-by: Sangkug Lym <slym@nvidia.com>

* fix conflict resolution bug
Signed-off-by: Sangkug Lym <slym@nvidia.com>

* Fix gencode flags in setup (#145)

* Fix gencode flags based on cuda version
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* review suggestions
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* revert append_nvcc_threads change
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Change overlap config dict error message
Signed-off-by: Sangkug Lym <slym@nvidia.com>

* simplify ub initialization
Signed-off-by: Sangkug Lym <slym@nvidia.com>

* lint
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix sanity imports
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* cpplint
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix TensorFlow build
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix TE macros in public header
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix lint
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* More fixes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* compiles with and w/o MPI
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fixes for python side annotations for conditional compile
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* link gdrAPI only when MPI found
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix comments for dummy var
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix linking
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Review comments
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* load MPI before TE
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Add Py side argument checks
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* remove unused code and catch silent failures
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix cpp tests
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix find_lib path for tests
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Sangkug Lym <slym@nvidia.com>
Co-authored-by: Sangkug Lym <slym@nvidia.com>
Co-authored-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>
parent 7bb2af35
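
Note: a minimal usage sketch of the new overlap flag on the PyTorch side, for orientation before the diff. Hyperparameter values and the process-group setup below are illustrative, not taken from this commit; only arguments that appear in the diff or in the existing TransformerLayer API are used. Depending on the build, additional userbuffer/environment initialization may be required.

```python
# Illustrative sketch (assumed launch via torchrun, one process per GPU):
# enabling tensor-parallel communication overlap via the new
# ub_tp_comm_overlap flag added in this PR.
import torch
import torch.distributed as dist
import transformer_engine.pytorch as te

dist.init_process_group(backend="nccl")
tp_group = dist.new_group(ranks=list(range(dist.get_world_size())))

layer = te.TransformerLayer(
    hidden_size=4096,            # illustrative sizes, not from this commit
    ffn_hidden_size=16384,
    num_attention_heads=32,
    set_parallel_mode=True,      # enable tensor parallelism
    ub_tp_comm_overlap=True,     # new flag; the constructor asserts that
                                 # tex.userbuf_comm_available() is true
).cuda()

# Supply the TP group before the first forward pass, as the docstring requires.
layer.set_tensor_parallel_group(tp_group)
```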
@@ -121,7 +121,8 @@ at::Tensor te_gemm_ts(at::Tensor A,
           workspace,
           workspaceSize_arg,
           accumulate_arg,
-          use_split_accumulator_arg);
+          use_split_accumulator_arg,
+          0);
   return D;
 }
......
@@ -15,6 +15,7 @@ import torch
 from flash_attn.flash_attn_interface import flash_attn_unpadded_func
+import transformer_engine_extensions as tex
 from transformer_engine.pytorch.module import LayerNormLinear, Linear, LayerNormMLP, LayerNorm
 from transformer_engine.pytorch.jit import (
     set_jit_fusion_options,
@@ -495,6 +496,10 @@ class MultiHeadAttention(torch.nn.Module):
         fuse_qkv_params: bool = False,
         zero_centered_gamma: bool = False,
         qkv_weight_interleaved: bool = True,
+        ub_bulk_wgrad: bool = False,
+        ub_bulk_dgrad: bool = False,
+        ub_split_rs: bool = False,
+        ub_split_ag: bool = False,
         bias: bool = True,
     ) -> None:
         super().__init__()
@@ -547,6 +552,9 @@ class MultiHeadAttention(torch.nn.Module):
                 return_layernorm_output=return_layernorm_output,
                 parameters_split=("query_", "key_", "value_") if not fuse_qkv_params else None,
                 zero_centered_gamma=zero_centered_gamma,
+                ub_bulk_wgrad=ub_bulk_wgrad,
+                ub_bulk_dgrad=ub_bulk_dgrad,
+                ub_split_ag=ub_split_ag,
                 **common_gemm_kwargs,
             )
         else:
@@ -572,6 +580,9 @@ class MultiHeadAttention(torch.nn.Module):
                 parallel_mode=qkv_parallel_mode,
                 return_layernorm_output=return_layernorm_output,
                 zero_centered_gamma=zero_centered_gamma,
+                ub_bulk_wgrad=ub_bulk_wgrad,
+                ub_bulk_dgrad=ub_bulk_dgrad,
+                ub_split_ag=ub_split_ag,
                 **common_gemm_kwargs,
             )
         else:
@@ -616,6 +627,8 @@ class MultiHeadAttention(torch.nn.Module):
             bias=bias,
             return_bias=True,
             parallel_mode="row" if set_parallel_mode else None,
+            ub_split_rs=ub_split_rs,
+            ub_split_ag=ub_split_ag,
             **common_gemm_kwargs,
         )
@@ -911,6 +924,12 @@ class TransformerLayer(torch.nn.Module):
                  `set_tensor_parallel_group(tp_group)` method on the initialized module before the
                  forward pass to supply the tensor parallel group needed for tensor and sequence
                  parallel collectives.
+    ub_bulk_wgrad: bool, default = False
+                 Bulk overlap UserBuffer ReduceScatter | WGRAD GEMM
+    ub_bulk_dgrad: bool, default = False
+                 Bulk overlap UserBuffer AllGather | DGRAD GEMM
+    ub_split_ag: bool, default = False
+                 Split pipelined overlap UserBuffer AllGather -> GEMM
     Optimization parameters
     -----------------------
@@ -970,6 +989,7 @@ class TransformerLayer(torch.nn.Module):
         fuse_qkv_params: bool = False,
         zero_centered_gamma: bool = False,
         qkv_weight_interleaved: bool = True,
+        ub_tp_comm_overlap: bool = False,
         bias: bool = True,
     ) -> None:
         super().__init__()
@@ -980,6 +1000,16 @@ class TransformerLayer(torch.nn.Module):
                 category=DeprecationWarning,
             )
+        if ub_tp_comm_overlap:
+            assert (
+                tex.userbuf_comm_available()
+            ), "Userbuffer communication backend not available."
+        ub_tp_comm_overlap = ub_tp_comm_overlap and bool(int(os.getenv("NVTE_UB_OVERLAP", "1")))
+        ub_bulk_wgrad = ub_tp_comm_overlap and bool(int(os.getenv("NVTE_UB_BULK_WGRAD", "1")))
+        ub_bulk_dgrad = ub_tp_comm_overlap and bool(int(os.getenv("NVTE_UB_BULK_DGRAD", "1")))
+        ub_split_ag = ub_tp_comm_overlap and bool(int(os.getenv("NVTE_UB_SPLIT_AG", "1")))
+        ub_split_rs = ub_tp_comm_overlap and bool(int(os.getenv("NVTE_UB_SPLIT_RS", "1")))
         bias_dropout_fusion = bool(int(os.getenv("NVTE_BIAS_DROPOUT_FUSION", "1")))
         self.layer_number = layer_number
         self.output_layernorm = output_layernorm
@@ -1037,6 +1067,10 @@ class TransformerLayer(torch.nn.Module):
             "fuse_qkv_params": fuse_qkv_params,
             "zero_centered_gamma": zero_centered_gamma,
             "qkv_weight_interleaved" : qkv_weight_interleaved,
+            "ub_bulk_wgrad" : ub_bulk_wgrad,
+            "ub_bulk_dgrad" : ub_bulk_dgrad,
+            "ub_split_ag" : ub_split_ag,
+            "ub_split_rs" : ub_split_rs,
         }
         self.self_attention = MultiHeadAttention(
@@ -1080,6 +1114,10 @@ class TransformerLayer(torch.nn.Module):
             micro_batch_size=micro_batch_size,
             set_parallel_mode=set_parallel_mode,
             zero_centered_gamma=zero_centered_gamma,
+            ub_bulk_wgrad=ub_bulk_wgrad,
+            ub_bulk_dgrad=ub_bulk_dgrad,
+            ub_split_rs=ub_split_rs,
+            ub_split_ag=ub_split_ag,
         )
         self.hidden_dropout = hidden_dropout
......
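
As a side note on the gating logic added in the TransformerLayer hunk above: each overlap is active only when both the constructor flag and the corresponding environment switch allow it. A small self-contained restatement of that logic (the helper name `_env_flag` is ours; the NVTE_UB_* variable names come from the diff):

```python
import os

def _env_flag(name: str, default: str = "1") -> bool:
    """Read an NVTE_UB_* switch the same way the hunk above does."""
    return bool(int(os.getenv(name, default)))

# All switches default to on; exporting e.g. NVTE_UB_BULK_WGRAD=0 disables only
# the WGRAD bulk overlap, while NVTE_UB_OVERLAP=0 disables every overlap path.
ub_tp_comm_overlap = True  # the constructor argument
ub_tp_comm_overlap = ub_tp_comm_overlap and _env_flag("NVTE_UB_OVERLAP")
ub_bulk_wgrad = ub_tp_comm_overlap and _env_flag("NVTE_UB_BULK_WGRAD")
ub_bulk_dgrad = ub_tp_comm_overlap and _env_flag("NVTE_UB_BULK_DGRAD")
ub_split_ag = ub_tp_comm_overlap and _env_flag("NVTE_UB_SPLIT_AG")
ub_split_rs = ub_tp_comm_overlap and _env_flag("NVTE_UB_SPLIT_RS")
```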
@@ -568,7 +568,7 @@ py::object TFE_Py_TeGemm_wrapper(
   nvte_cublas_gemm(a_tensor.data(), b_tensor.data(), d_tensor.data(),
                    bias_tensor.data(), gelu_input_tensor.data(), transa,
                    transb, grad, workspace_tensor.data(), accumulate,
-                   use_split_accumulate, stream);
+                   use_split_accumulate, 0, stream);
   auto d_eager = CreateTensor(d_ptr, d_shape, otype);
   if (use_gelu && !grad) {
......