Unverified commit 891d57d3, authored by eqy, committed by GitHub

[Pipeline-Parallelism][TF32] Disable TF32 for Pipeline-Parallel numerical checks (#1382)



* check in

* fancy context style
Co-authored-by: Masaki Kozuki <mkozuki@nvidia.com>
parent 3490b9e1
@@ -6,6 +6,7 @@ import unittest
 import torch
 from torch.testing._internal import common_utils
+from torch.testing._internal import common_cuda
 logging.getLogger("torch").setLevel(logging.WARNING)
@@ -137,23 +138,24 @@ class PipelineParallelForwardBackwardTestBase:
         optimizer = torch.optim.Adam(_param_groups, lr=1e-3)
         pp_utils.update_num_microbatches(0)
-        loss = fwd_bwd_func(
-            testing_utils.fwd_step_func,
-            batch,
-            model,
-            forward_only=forward_only,
-            # `tensor_shape` is the shape of micro batch.
-            tensor_shape=(
-                self.MICRO_BATCH_SIZE,
-                self.HIDDEN_SIZE,
-                self.HIDDEN_SIZE,
-            ),
-            dtype=dtype,
-            async_comm=async_comm,
-            grad_scaler=grad_scaler,
-            deallocate_pipeline_output=deallocate_pipeline_outputs,
-        )
+        with common_cuda.tf32_off():
+            loss = fwd_bwd_func(
+                testing_utils.fwd_step_func,
+                batch,
+                model,
+                forward_only=forward_only,
+                # `tensor_shape` is the shape of micro batch.
+                tensor_shape=(
+                    self.MICRO_BATCH_SIZE,
+                    self.HIDDEN_SIZE,
+                    self.HIDDEN_SIZE,
+                ),
+                dtype=dtype,
+                async_comm=async_comm,
+                grad_scaler=grad_scaler,
+                deallocate_pipeline_output=deallocate_pipeline_outputs,
+            )
         if dtype == torch.float32:
             hidden_size = self.HIDDEN_SIZE
...
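For context: on Ampere-and-newer GPUs, TF32 runs float32 matmuls with a reduced (~10-bit) mantissa, which can make fp32 results drift enough to fail tight numerical checks; that is why the patch wraps the forward/backward pass in the `tf32_off()` context manager from `torch.testing._internal.common_cuda`. Below is a minimal, self-contained sketch of the effect (assuming a TF32-capable CUDA GPU; the tensor sizes are illustrative and not taken from the test, and it toggles the global `torch.backends.cuda.matmul.allow_tf32` flag directly rather than using the test helper):

```python
# Minimal sketch: compare float32 matmul error against an fp64 reference
# with TF32 enabled vs. disabled. Requires a CUDA GPU with TF32 support.
import torch

a = torch.randn(1024, 1024, device="cuda")
b = torch.randn(1024, 1024, device="cuda")
reference = (a.double() @ b.double()).float()

# With TF32 enabled (the default on Ampere+), fp32 matmul inputs are rounded
# to TF32 precision, so the result deviates more from the fp64 reference.
torch.backends.cuda.matmul.allow_tf32 = True
tf32_err = (a @ b - reference).abs().max().item()

# With TF32 disabled, the matmul runs in full fp32, which is what the
# tightened numerical checks in this test expect.
torch.backends.cuda.matmul.allow_tf32 = False
fp32_err = (a @ b - reference).abs().max().item()

print(f"max abs error with TF32: {tf32_err:.6f}")
print(f"max abs error without TF32: {fp32_err:.6f}")
```

Using the `tf32_off()` context manager, as the patch does, scopes the change to the wrapped call and restores the previous TF32 setting afterwards, rather than mutating the global flag for the rest of the test run.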