Unverified commit eb64ec2a, authored by Sangkug Lym, committed by GitHub
Browse files

Add an option for TP-only AMAX reduction (#431)


Signed-off-by: Sangkug Lym <slym@nvidia.com>
parent a150d286
......@@ -326,9 +326,14 @@ class FP8GlobalStateManager:
return None
# Reduce AMAX in the DP domain at an interval.
# `NVTE_DP_AMAX_REDUCE_INTERVAL` must be a non-negative integer (default: 1). A value
# larger than 0 is the DP-domain reduction interval; a value of 0 means AMAX is
# reduced only in the TP domain.
if cls.dp_amax_reduce_interval is None:
cls.dp_amax_reduce_interval = int(os.getenv("NVTE_DP_AMAX_REDUCE_INTERVAL", "1"))
if cls.dp_amax_reduce_interval == 0:
tp_amax_reduce = True
else:
tp_amax_reduce = False
if forward:
if cls.dp_amax_reduce_forward_idx == 0:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment