Unverified Commit 72bd7c69 authored by Frank Lee's avatar Frank Lee Committed by GitHub
Browse files

[amp] included dict for type casting of model output (#1102)

parent 5a9d8ef4
......@@ -149,4 +149,6 @@ class NaiveAMPModel(nn.Module):
out = self._convert_to_fp32(out)
elif isinstance(out, (tuple, list)):
out = [self._convert_to_fp32(val) for val in out]
elif isinstance(out, dict):
out = {key: self._convert_to_fp32(val) for key, val in out.items()}
return out
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment