Unverified commit 9bc9e68d, authored by Kirthi Shankar Sivamani, committed by GitHub
Browse files

Turn off default async amax reduction (#148)


Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent b921c0d1
@@ -192,7 +192,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
         self.fp8_weight_shapes = []
         self.fp8_meta["autocast_id_fwd_stack"] = []
         self.fp8_meta["async_amax_reduction"] = bool(
-            int(os.getenv("NVTE_ASYNC_AMAX_REDUCTION", "1"))
+            int(os.getenv("NVTE_ASYNC_AMAX_REDUCTION", "0"))
         )
 
     def set_meta_tensor(self, fwd: bool) -> None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment