"vscode:/vscode.git/clone" did not exist on "01ae83b5e57384a6cbeb1f7fb77fc568b6c8bd38"
Unverified commit 891d57d3, authored by eqy, committed by GitHub
Browse files

[Pipeline-Parallelism][TF32] Disable TF32 for Pipeline-Parallel numerical checks (#1382)



* check in

* fancy context style
Co-authored-by: Masaki Kozuki <mkozuki@nvidia.com>
parent 3490b9e1
...@@ -6,6 +6,7 @@ import unittest ...@@ -6,6 +6,7 @@ import unittest
import torch import torch
from torch.testing._internal import common_utils from torch.testing._internal import common_utils
from torch.testing._internal import common_cuda
logging.getLogger("torch").setLevel(logging.WARNING) logging.getLogger("torch").setLevel(logging.WARNING)
...@@ -138,6 +139,7 @@ class PipelineParallelForwardBackwardTestBase: ...@@ -138,6 +139,7 @@ class PipelineParallelForwardBackwardTestBase:
pp_utils.update_num_microbatches(0) pp_utils.update_num_microbatches(0)
with common_cuda.tf32_off():
loss = fwd_bwd_func( loss = fwd_bwd_func(
testing_utils.fwd_step_func, testing_utils.fwd_step_func,
batch, batch,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment