Unverified Commit 891d57d3 authored by eqy's avatar eqy Committed by GitHub
Browse files

[Pipeline-Parallelism][TF32] Disable TF32 for Pipeline-Parallel numerical checks (#1382)



* check in

* fancy context style
Co-authored-by: default avatarMasaki Kozuki <mkozuki@nvidia.com>
parent 3490b9e1
......@@ -6,6 +6,7 @@ import unittest
import torch
from torch.testing._internal import common_utils
from torch.testing._internal import common_cuda
logging.getLogger("torch").setLevel(logging.WARNING)
......@@ -138,6 +139,7 @@ class PipelineParallelForwardBackwardTestBase:
pp_utils.update_num_microbatches(0)
with common_cuda.tf32_off():
loss = fwd_bwd_func(
testing_utils.fwd_step_func,
batch,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment