Unverified commit db95afeb authored by Sangkug Lym, committed by GitHub
Browse files

Async amax reduction (#118)



* async amax reduction

add env knob to enable async amax reduction
Signed-off-by: slym <slym@login-preos01.a51.clusters.nvidia.com>

* Style fixes
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* remove is_last_model
Signed-off-by: slym <slym@login-preos01.a51.clusters.nvidia.com>

* fix naming
Signed-off-by: slym <slym@login-preos01.a51.clusters.nvidia.com>

* revert var name
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* revert var name
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: slym <slym@login-preos01.a51.clusters.nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: slym <slym@login-preos01.a51.clusters.nvidia.com>
parent 96ad903c
...@@ -27,6 +27,7 @@ _fp8_tensors_recompute_buffer = [] ...@@ -27,6 +27,7 @@ _fp8_tensors_recompute_buffer = []
_amax_forward_global_reduce_func = None _amax_forward_global_reduce_func = None
_buffer_delete_key_fwd = None _buffer_delete_key_fwd = None
_buffer_delete_key_bwd = None _buffer_delete_key_bwd = None
_amax_reduce_handle_fwd = None
_is_fp8_available = None _is_fp8_available = None
_reason_for_no_fp8 = "" _reason_for_no_fp8 = ""
...@@ -73,6 +74,12 @@ def get_autocast_key(forward: bool = True) -> str: ...@@ -73,6 +74,12 @@ def get_autocast_key(forward: bool = True) -> str:
return "autocast_id_bwd" return "autocast_id_bwd"
def get_amax_reduce_handle_fwd() -> Union["torch.distributed.Work", None]:
    """Return the async AMAX-reduction wait handle of forward prop.

    Returns:
        The wait handle stored in ``_amax_reduce_handle_fwd`` by the
        forward amax reduction launched from ``fp8_autocast`` (an async
        ``torch.distributed`` work handle), or ``None`` when no async
        reduction is in flight. Callers ``wait()`` on a non-None handle
        before consuming the reduced amaxes.
    """
    # Fix: the previous annotation, Union[bool, None], was wrong -- this
    # returns a distributed wait handle (or None), never a bool.
    # Note: the `global` statement was dropped; it is redundant for a
    # read-only access of a module-level name.
    return _amax_reduce_handle_fwd
def get_global_fp8_buffer() -> Dict[str, List[torch.Tensor]]: def get_global_fp8_buffer() -> Dict[str, List[torch.Tensor]]:
"""Returns global fp8 buffer.""" """Returns global fp8 buffer."""
return _global_fp8_buffer return _global_fp8_buffer
...@@ -264,6 +271,7 @@ def fp8_autocast( ...@@ -264,6 +271,7 @@ def fp8_autocast(
global _FP8_ENABLED, _FP8_CALIBRATION, _FP8_RECIPE, _FP8_DISTRIBUTED_GROUP, _FP8_AUTOCAST_DEPTH global _FP8_ENABLED, _FP8_CALIBRATION, _FP8_RECIPE, _FP8_DISTRIBUTED_GROUP, _FP8_AUTOCAST_DEPTH
global _IS_FIRST_FP8_MODULE, _FP8_AUTOCAST_COUNTER global _IS_FIRST_FP8_MODULE, _FP8_AUTOCAST_COUNTER
global _global_fp8_buffer, _buffer_delete_key_fwd global _global_fp8_buffer, _buffer_delete_key_fwd
global _amax_reduce_handle_fwd
fp8_state = (_FP8_ENABLED, _FP8_CALIBRATION, _FP8_RECIPE, _FP8_DISTRIBUTED_GROUP) fp8_state = (_FP8_ENABLED, _FP8_CALIBRATION, _FP8_RECIPE, _FP8_DISTRIBUTED_GROUP)
try: try:
_FP8_ENABLED = enabled _FP8_ENABLED = enabled
...@@ -287,7 +295,7 @@ def fp8_autocast( ...@@ -287,7 +295,7 @@ def fp8_autocast(
if _FP8_AUTOCAST_DEPTH == 0: if _FP8_AUTOCAST_DEPTH == 0:
if callable(_amax_forward_global_reduce_func): if callable(_amax_forward_global_reduce_func):
_amax_forward_global_reduce_func() _amax_reduce_handle_fwd = _amax_forward_global_reduce_func()
delete_key_from_amax_buffer(forward=True) delete_key_from_amax_buffer(forward=True)
...@@ -521,16 +529,18 @@ def get_fp8_te_dtype( ...@@ -521,16 +529,18 @@ def get_fp8_te_dtype(
def reduce_tensor_across_group_op_max(
    tensor: torch.Tensor, group: "dist_group_type", async_op: bool = False
) -> Union["torch.distributed.Work", None]:
    """Max-reduce ``tensor`` in place across ``group``.

    Args:
        tensor: tensor all-reduced in place with ``ReduceOp.MAX``.
        group: process group to reduce across.
        async_op: launch the all-reduce asynchronously. Defaults to
            ``False`` so pre-existing two-argument callers keep their
            blocking behavior.

    Returns:
        The value returned by ``torch.distributed.all_reduce`` -- an async
        work handle when ``async_op`` is True (callers ``wait()`` on it),
        ``None`` for a blocking call -- or ``None`` when
        ``torch.distributed`` is not initialized.
    """
    # Fix: the declared return type was None even though the wait handle
    # from all_reduce is returned and waited on by callers.
    if not torch.distributed.is_initialized():
        return None
    return torch.distributed.all_reduce(
        tensor,
        op=torch.distributed.ReduceOp.MAX,
        group=group,
        async_op=async_op,
    )
def global_amax_reduction( def global_amax_reduction(
...@@ -543,14 +553,19 @@ def global_amax_reduction( ...@@ -543,14 +553,19 @@ def global_amax_reduction(
# Key already deleted. # Key already deleted.
if amax_buffer_key not in _global_fp8_buffer: if amax_buffer_key not in _global_fp8_buffer:
return return None
chunk_sizes = [x.numel() for x in _global_fp8_buffer[amax_buffer_key]] chunk_sizes = [x.numel() for x in _global_fp8_buffer[amax_buffer_key]]
contiguous_amax = torch.cat(_global_fp8_buffer[amax_buffer_key]) contiguous_amax = torch.cat(_global_fp8_buffer[amax_buffer_key])
reduce_tensor_across_group_op_max(contiguous_amax, fp8_meta["fp8_group"]) wait_handle = reduce_tensor_across_group_op_max(
contiguous_amax,
fp8_meta["fp8_group"],
fp8_meta["async_amax_reduction"],
)
_global_fp8_buffer[amax_buffer_key] = list(contiguous_amax.split(chunk_sizes)) _global_fp8_buffer[amax_buffer_key] = list(contiguous_amax.split(chunk_sizes))
return wait_handle
def delete_key_from_amax_buffer(forward: bool = True) -> None: def delete_key_from_amax_buffer(forward: bool = True) -> None:
......
...@@ -41,6 +41,7 @@ from .fp8 import ( ...@@ -41,6 +41,7 @@ from .fp8 import (
copy_forward_fp8_meta_tensors_for_recompute, copy_forward_fp8_meta_tensors_for_recompute,
get_old_fp8_meta_tensors_for_recompute, get_old_fp8_meta_tensors_for_recompute,
restore_fp8_meta_tensors, restore_fp8_meta_tensors,
get_amax_reduce_handle_fwd,
) )
from .jit import ( from .jit import (
bias_gelu_fused, bias_gelu_fused,
...@@ -84,6 +85,7 @@ _2X_ACC_FPROP = False ...@@ -84,6 +85,7 @@ _2X_ACC_FPROP = False
_2X_ACC_DGRAD = True _2X_ACC_DGRAD = True
_2X_ACC_WGRAD = True _2X_ACC_WGRAD = True
_cublas_workspace = None _cublas_workspace = None
_amax_reduce_handle_bwd = None
def get_cublas_workspace_size_bytes() -> None: def get_cublas_workspace_size_bytes() -> None:
...@@ -106,6 +108,11 @@ def get_workspace() -> torch.Tensor: ...@@ -106,6 +108,11 @@ def get_workspace() -> torch.Tensor:
def _prepare_backward(fp8: bool, fp8_meta: Dict[str, Any], name: str = "") -> None: def _prepare_backward(fp8: bool, fp8_meta: Dict[str, Any], name: str = "") -> None:
"""Checks and prep for BWD.""" """Checks and prep for BWD."""
if fp8: if fp8:
global _amax_reduce_handle_bwd
if _amax_reduce_handle_bwd is not None:
_amax_reduce_handle_bwd.wait()
_amax_reduce_handle_bwd = None
# Update amax and scale; Skip all setup for global amax reduction # Update amax and scale; Skip all setup for global amax reduction
if not fp8_meta["recipe"].reduce_amax: if not fp8_meta["recipe"].reduce_amax:
amax_and_scale_update(fp8_meta, False) amax_and_scale_update(fp8_meta, False)
...@@ -125,7 +132,7 @@ def _prepare_backward(fp8: bool, fp8_meta: Dict[str, Any], name: str = "") -> N ...@@ -125,7 +132,7 @@ def _prepare_backward(fp8: bool, fp8_meta: Dict[str, Any], name: str = "") -> N
if fp8 and fp8_meta["recipe"].reduce_amax: if fp8 and fp8_meta["recipe"].reduce_amax:
if fp8_meta["first_module"]: if fp8_meta["first_module"]:
global_amax_reduction(fp8_meta, forward=False) _amax_reduce_handle_bwd = global_amax_reduction(fp8_meta, forward=False)
delete_key_from_amax_buffer(forward=False) delete_key_from_amax_buffer(forward=False)
...@@ -184,6 +191,9 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -184,6 +191,9 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
self.sequence_parallel = False self.sequence_parallel = False
self.fp8_weight_shapes = [] self.fp8_weight_shapes = []
self.fp8_meta["autocast_id_fwd_stack"] = [] self.fp8_meta["autocast_id_fwd_stack"] = []
self.fp8_meta["async_amax_reduction"] = bool(
int(os.getenv("NVTE_ASYNC_AMAX_REDUCTION", "1"))
)
def set_meta_tensor(self, fwd: bool) -> None: def set_meta_tensor(self, fwd: bool) -> None:
"""Init scales and amaxes for fwd | bwd.""" """Init scales and amaxes for fwd | bwd."""
...@@ -497,6 +507,10 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -497,6 +507,10 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
if self.fp8_meta["recipe"].reduce_amax: if self.fp8_meta["recipe"].reduce_amax:
self.fp8_meta["first_module"] = is_first_fp8_module() self.fp8_meta["first_module"] = is_first_fp8_module()
if self.fp8_meta["first_module"]: if self.fp8_meta["first_module"]:
# Wait for the prior AMAX reduction to finish
amax_reduce_handle_fwd = get_amax_reduce_handle_fwd()
if amax_reduce_handle_fwd is not None:
amax_reduce_handle_fwd.wait()
self.fp8_meta["autocast_id_fwd"] = new_fp8_context_id() self.fp8_meta["autocast_id_fwd"] = new_fp8_context_id()
set_fp8_context_id(self.fp8_meta["autocast_id_fwd"]) set_fp8_context_id(self.fp8_meta["autocast_id_fwd"])
else: else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment