Unverified Commit 7e759174 authored by Przemyslaw Tredak's avatar Przemyslaw Tredak Committed by GitHub
Browse files

[pyTorch] Enable the model to change precision between iterations (#414)



* Enable the model to change precision between iterations
Signed-off-by: default avatarPrzemek Tredak <ptredak@nvidia.com>

* Add test
Signed-off-by: default avatarPrzemek Tredak <ptredak@nvidia.com>

* Fix for the test
Signed-off-by: default avatarPrzemek Tredak <ptredak@nvidia.com>

---------
Signed-off-by: default avatarPrzemek Tredak <ptredak@nvidia.com>
Co-authored-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
parent e7eff4a3
...@@ -788,3 +788,16 @@ def test_gpt_cuda_graph(dtype, bs, fp8_recipe, model, skip_wgrad, zero_centered_ ...@@ -788,3 +788,16 @@ def test_gpt_cuda_graph(dtype, bs, fp8_recipe, model, skip_wgrad, zero_centered_
) )
_test_sanity_e2e_cuda_graph(block, bs, dtype, config, fp8_recipe, skip_wgrad) _test_sanity_e2e_cuda_graph(block, bs, dtype, config, fp8_recipe, skip_wgrad)
def test_model_multiple_cast():
    """Verify that a module can switch precision between iterations.

    Runs a forward pass in fp32, casts both the module and the input to
    fp16, and checks that the second forward pass produces fp16 output
    (i.e. the cached activation dtype is refreshed on dtype change).
    """
    inp = torch.zeros((16, 16)).cuda()
    model = Linear(16, 32)

    # First iteration: everything in default (fp32) precision.
    out_fp32 = model(inp)
    assert out_fp32.dtype == torch.float32

    # Cast parameters and input to half precision between iterations.
    model.half()
    inp = inp.half()

    # Second iteration must pick up the new dtype, not the cached one.
    out_fp16 = model(inp)
    assert out_fp16.dtype == torch.float16
...@@ -445,8 +445,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC): ...@@ -445,8 +445,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
return return
# All checks after this have already been performed once, thus skip # All checks after this have already been performed once, thus skip
# We assume that user doesn't change input types across iterations if hasattr(self, "activation_dtype") and self.activation_dtype == inp.dtype:
if hasattr(self, "activation_dtype"):
return return
dtype = inp.dtype dtype = inp.dtype
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment