Unverified Commit beffb297 authored by Xin Yao, committed by GitHub

[PyTorch] Get `skip_fp8_weight_update` only in CUDA Graph Capturing (#1854)



Only get `skip_fp8_weight_update` when `fp8_graph_capturing()` is true.
Signed-off-by: Xin Yao <xiny@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 05f3b573
@@ -668,7 +668,10 @@ class GroupedLinear(TransformerEngineBaseModule):
         ), "GroupedLinear doesn't support input tensor in FP8."
         assert len(m_splits) == self.num_gemms, "Number of splits should match number of GEMMs."
-        skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor()
+        if FP8GlobalStateManager.fp8_graph_capturing():
+            skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor()
+        else:
+            skip_fp8_weight_update = None
         if skip_fp8_weight_update is not None:
             is_first_microbatch = False
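The control flow introduced by this commit can be sketched in isolation. The snippet below is illustrative only: `FakeFP8State` is a hypothetical stand-in for Transformer Engine's `FP8GlobalStateManager`, and `resolve_microbatch_flags` is a made-up helper that mirrors the patched logic, namely that the skip-weight-update tensor is queried only while a CUDA graph is being captured, so eager-mode forward passes leave `is_first_microbatch` untouched.

```python
# Hypothetical stand-in for FP8GlobalStateManager (not the real
# Transformer Engine API) used to illustrate the patched control flow.
class FakeFP8State:
    graph_capturing = False
    skip_weight_update_tensor = "skip-flag"  # stands in for a CUDA tensor

    @classmethod
    def fp8_graph_capturing(cls):
        return cls.graph_capturing

    @classmethod
    def get_skip_fp8_weight_update_tensor(cls):
        return cls.skip_weight_update_tensor


def resolve_microbatch_flags(is_first_microbatch):
    # Mirror the patched logic: only fetch the skip flag during graph
    # capture; otherwise leave it as None so the flag below is unchanged.
    if FakeFP8State.fp8_graph_capturing():
        skip_fp8_weight_update = FakeFP8State.get_skip_fp8_weight_update_tensor()
    else:
        skip_fp8_weight_update = None
    if skip_fp8_weight_update is not None:
        is_first_microbatch = False
    return skip_fp8_weight_update, is_first_microbatch
```

Outside capture, `resolve_microbatch_flags(True)` returns `(None, True)`; once `FakeFP8State.graph_capturing` is set, the skip tensor is fetched and `is_first_microbatch` is forced to `False`.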