Unverified Commit 4e2ce516 authored by Kirthi Shankar Sivamani, committed by GitHub

[PyTorch] Use dummy amax for Float8Tensor cast (#693)



* Avoid updating real amax during param cast

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Review comments
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 0bd84ed9
@@ -790,7 +790,8 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
         param = Float8Tensor.to_float8(
             param,
             fp8_meta=self.fp8_meta,
-            fp8_meta_index=fp8_meta_index
+            fp8_meta_index=fp8_meta_index,
+            amax=torch.empty(1, device="cuda"),  # Dummy amax to avoid overwriting history.
         )
         # Redo parameter wrap in case we broke it above
...
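The idea behind the change: the cast records the amax it computes into whatever buffer is supplied, and by default that buffer aliases a slot of the amax history in fp8_meta, so casting an existing parameter would pollute the history used for scale updates. Below is a minimal, self-contained sketch of that pattern, not the Transformer Engine implementation; cast_to_fp8_sketch and its signature are hypothetical stand-ins.

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

def cast_to_fp8_sketch(x: torch.Tensor, amax_out: torch.Tensor) -> torch.Tensor:
    # Hypothetical cast: a real FP8 cast records the abs-max of `x` into
    # `amax_out`, which normally points at a slot of the module's amax history.
    amax_out.copy_(x.abs().amax().reshape(1))
    return x  # the quantization step itself is omitted in this sketch

amax_history = torch.zeros(1, device=device)  # stand-in for the tracked history slot
param = torch.randn(16, device=device)

# Passing the real history slot would clobber it with the parameter's amax:
#   cast_to_fp8_sketch(param, amax_out=amax_history)

# Passing a throwaway buffer, as the diff above does, leaves the history untouched:
cast_to_fp8_sketch(param, amax_out=torch.empty(1, device=device))
assert amax_history.item() == 0.0  # history still holds its original value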