Unverified Commit db95afeb authored by Sangkug Lym, committed by GitHub

Async amax reduction (#118)



* async amax reduction

add env knob to enable async amax reduction
Signed-off-by: slym <slym@login-preos01.a51.clusters.nvidia.com>

* Style fixes
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* remove is_last_model
Signed-off-by: slym <slym@login-preos01.a51.clusters.nvidia.com>

* fix naming
Signed-off-by: slym <slym@login-preos01.a51.clusters.nvidia.com>

* revert var name
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* revert var name
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: slym <slym@login-preos01.a51.clusters.nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: slym <slym@login-preos01.a51.clusters.nvidia.com>
parent 96ad903c
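
For context, the new NVTE_ASYNC_AMAX_REDUCTION knob is read via os.getenv when a module is constructed, so it must be set in the environment before the model is built. A minimal usage sketch for toggling it, assuming a typical FP8 setup; the te.Linear layer and DelayedScaling recipe below are illustrative and not part of this commit:

import os

# Disable async amax reduction for this run; the flag is read at module
# construction time, so set it before building the model.
os.environ["NVTE_ASYNC_AMAX_REDUCTION"] = "0"

import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe

fp8_recipe = recipe.DelayedScaling(margin=0, interval=1)  # illustrative recipe
layer = te.Linear(1024, 1024)                             # illustrative module

inp = torch.randn(16, 1024, device="cuda")
with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    out = layer(inp)
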
@@ -27,6 +27,7 @@ _fp8_tensors_recompute_buffer = []
 _amax_forward_global_reduce_func = None
 _buffer_delete_key_fwd = None
 _buffer_delete_key_bwd = None
+_amax_reduce_handle_fwd = None
 _is_fp8_available = None
 _reason_for_no_fp8 = ""
@@ -73,6 +74,12 @@ def get_autocast_key(forward: bool = True) -> str:
     return "autocast_id_bwd"


+def get_amax_reduce_handle_fwd() -> Union[bool, None]:
+    """Return AMAX reduction wait handle of forward prop."""
+    global _amax_reduce_handle_fwd
+    return _amax_reduce_handle_fwd
+
+
 def get_global_fp8_buffer() -> Dict[str, List[torch.Tensor]]:
     """Returns global fp8 buffer."""
     return _global_fp8_buffer
@@ -264,6 +271,7 @@ def fp8_autocast(
     global _FP8_ENABLED, _FP8_CALIBRATION, _FP8_RECIPE, _FP8_DISTRIBUTED_GROUP, _FP8_AUTOCAST_DEPTH
     global _IS_FIRST_FP8_MODULE, _FP8_AUTOCAST_COUNTER
     global _global_fp8_buffer, _buffer_delete_key_fwd
+    global _amax_reduce_handle_fwd
     fp8_state = (_FP8_ENABLED, _FP8_CALIBRATION, _FP8_RECIPE, _FP8_DISTRIBUTED_GROUP)
     try:
         _FP8_ENABLED = enabled
@@ -287,7 +295,7 @@
         if _FP8_AUTOCAST_DEPTH == 0:
             if callable(_amax_forward_global_reduce_func):
-                _amax_forward_global_reduce_func()
+                _amax_reduce_handle_fwd = _amax_forward_global_reduce_func()
             delete_key_from_amax_buffer(forward=True)
@@ -521,16 +529,18 @@ def get_fp8_te_dtype(
 def reduce_tensor_across_group_op_max(
-    tensor: torch.Tensor, group: dist_group_type
+    tensor: torch.Tensor, group: dist_group_type, async_op: bool
 ) -> None:
     """Reduce tensor across given group."""
     if torch.distributed.is_initialized():
-        torch.distributed.all_reduce(
+        wait_handle = torch.distributed.all_reduce(
             tensor,
             op=torch.distributed.ReduceOp.MAX,
             group=group,
-            async_op=False,
+            async_op=async_op,
         )
+        return wait_handle
+    return None


 def global_amax_reduction(
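
For reference, the async path relies on standard torch.distributed behavior: with async_op=True, all_reduce returns a Work handle immediately and the reduction proceeds in the background; the tensor only holds the reduced MAX once wait() returns. A minimal standalone sketch of that behavior (the process-group setup here is illustrative):

import torch
import torch.distributed as dist

dist.init_process_group(backend="nccl")  # illustrative setup, one GPU per rank
amax = torch.rand(8, device="cuda")

# Launch the reduction without blocking the caller.
handle = dist.all_reduce(amax, op=dist.ReduceOp.MAX, async_op=True)

# ... other work can overlap with the communication here ...

handle.wait()  # after this, amax holds the element-wise MAX across all ranks
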
@@ -543,14 +553,19 @@ def global_amax_reduction(
     # Key already deleted.
     if amax_buffer_key not in _global_fp8_buffer:
-        return
+        return None

     chunk_sizes = [x.numel() for x in _global_fp8_buffer[amax_buffer_key]]
     contiguous_amax = torch.cat(_global_fp8_buffer[amax_buffer_key])
-    reduce_tensor_across_group_op_max(contiguous_amax, fp8_meta["fp8_group"])
+    wait_handle = reduce_tensor_across_group_op_max(
+        contiguous_amax,
+        fp8_meta["fp8_group"],
+        fp8_meta["async_amax_reduction"],
+    )

     _global_fp8_buffer[amax_buffer_key] = list(contiguous_amax.split(chunk_sizes))
+    return wait_handle


 def delete_key_from_amax_buffer(forward: bool = True) -> None:
......
@@ -41,6 +41,7 @@ from .fp8 import (
     copy_forward_fp8_meta_tensors_for_recompute,
     get_old_fp8_meta_tensors_for_recompute,
     restore_fp8_meta_tensors,
+    get_amax_reduce_handle_fwd,
 )
 from .jit import (
     bias_gelu_fused,
@@ -84,6 +85,7 @@ _2X_ACC_FPROP = False
 _2X_ACC_DGRAD = True
 _2X_ACC_WGRAD = True
 _cublas_workspace = None
+_amax_reduce_handle_bwd = None


 def get_cublas_workspace_size_bytes() -> None:
@@ -106,6 +108,11 @@ def get_workspace() -> torch.Tensor:
 def _prepare_backward(fp8: bool, fp8_meta: Dict[str, Any], name: str = "") -> None:
     """Checks and prep for BWD."""
     if fp8:
+        global _amax_reduce_handle_bwd
+        if _amax_reduce_handle_bwd is not None:
+            _amax_reduce_handle_bwd.wait()
+            _amax_reduce_handle_bwd = None
+
         # Update amax and scale; Skip all setup for global amax reduction
         if not fp8_meta["recipe"].reduce_amax:
             amax_and_scale_update(fp8_meta, False)
@@ -125,7 +132,7 @@ def _prepare_backward(fp8: bool, fp8_meta: Dict[str, Any], name: str = "") -> None:
     if fp8 and fp8_meta["recipe"].reduce_amax:
         if fp8_meta["first_module"]:
-            global_amax_reduction(fp8_meta, forward=False)
+            _amax_reduce_handle_bwd = global_amax_reduction(fp8_meta, forward=False)
         delete_key_from_amax_buffer(forward=False)
@@ -184,6 +191,9 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
         self.sequence_parallel = False
         self.fp8_weight_shapes = []
         self.fp8_meta["autocast_id_fwd_stack"] = []
+        self.fp8_meta["async_amax_reduction"] = bool(
+            int(os.getenv("NVTE_ASYNC_AMAX_REDUCTION", "1"))
+        )

     def set_meta_tensor(self, fwd: bool) -> None:
         """Init scales and amaxes for fwd | bwd."""
@@ -497,6 +507,10 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
         if self.fp8_meta["recipe"].reduce_amax:
             self.fp8_meta["first_module"] = is_first_fp8_module()
             if self.fp8_meta["first_module"]:
+                # Wait for the prior AMAX reduction to finish
+                amax_reduce_handle_fwd = get_amax_reduce_handle_fwd()
+                if amax_reduce_handle_fwd is not None:
+                    amax_reduce_handle_fwd.wait()
                 self.fp8_meta["autocast_id_fwd"] = new_fp8_context_id()
                 set_fp8_context_id(self.fp8_meta["autocast_id_fwd"])
             else:
......
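
Taken together, the two files follow a launch-now, wait-later pattern around a module-level handle: the reduction is kicked off asynchronously at the end of one step (forward autocast exit, or the first module seen in backward) and waited on only when the amax buffer is next needed, so the all-reduce can overlap with other work in between. A stripped-down sketch of that pattern, using illustrative names rather than the Transformer Engine API:

import torch
import torch.distributed as dist

_pending_handle = None  # mirrors the role of _amax_reduce_handle_fwd/_bwd

def launch_amax_reduction(amax: torch.Tensor, async_op: bool = True):
    """Producer side: start the all-reduce and stash the handle (or None if sync)."""
    global _pending_handle
    _pending_handle = dist.all_reduce(amax, op=dist.ReduceOp.MAX, async_op=async_op)

def wait_for_amax_reduction():
    """Consumer side: block only when the reduced amaxes are actually needed."""
    global _pending_handle
    if _pending_handle is not None:
        _pending_handle.wait()
        _pending_handle = None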