"vscode:/vscode.git/clone" did not exist on "2e308484d9693f8251748c295f6ed7ed25d767eb"
Unverified Commit eda8f461 authored by Kirthi Shankar Sivamani's avatar Kirthi Shankar Sivamani Committed by GitHub
Browse files

Fix gradients when using AMP (#70)



retain grad related attrs while casting
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 04490337
......@@ -169,4 +169,5 @@ def safely_set_viewless_tensor_data(
def cast_if_needed(tensor: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
"""Cast tensor to dtype"""
return tensor if tensor is None or tensor.dtype == dtype else tensor.to(dtype)
with torch.enable_grad():
return tensor if tensor is None or tensor.dtype == dtype else tensor.to(dtype)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment