Unverified commit 1e4a503c authored by msbaines, committed by GitHub

[cleanup] clearly document how backward/forward are working (#700)

_SyncBatchNormFunction is a little complex in that it does the
full backward, including mean and var, but does not calculate
statistics in the forward path. Statistics are calculated outside,
in the SyncBatchNorm nn.Module.

This change does not impact functionality.
parent 25cebf85
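
Here is a minimal, single-process sketch of the pattern the message describes: batch statistics are computed outside the autograd function (under no_grad), while backward() still produces the full gradient, including the terms that flow through mean and var. The _StatsOutside name and the 2D (N, C) input are illustrative assumptions, not the fairscale API; the real function also handles the affine flag, total_count, and the process group.

# Illustrative sketch only -- _StatsOutside is a hypothetical stand-in, not fairscale code.
import torch
from torch import Tensor


class _StatsOutside(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input: Tensor, weight: Tensor, bias: Tensor, mean: Tensor, invstd: Tensor) -> Tensor:
        # mean/invstd arrive precomputed; the forward only normalizes.
        xhat = (input - mean) * invstd
        ctx.save_for_backward(xhat, weight, invstd)
        return xhat * weight + bias

    @staticmethod
    def backward(ctx, grad_out: Tensor):
        xhat, weight, invstd = ctx.saved_tensors
        n = grad_out.shape[0]
        grad_xhat = grad_out * weight
        # Full batch-norm input gradient: the second and third terms are the
        # contributions flowing through mean and var, even though both were
        # computed outside this function.
        grad_input = invstd / n * (n * grad_xhat - grad_xhat.sum(0) - xhat * (grad_xhat * xhat).sum(0))
        return grad_input, (grad_out * xhat).sum(0), grad_out.sum(0), None, None


x = torch.randn(8, 4, dtype=torch.double, requires_grad=True)
w = torch.rand(4, dtype=torch.double, requires_grad=True)
b = torch.rand(4, dtype=torch.double, requires_grad=True)
with torch.no_grad():  # statistics are plain values; their gradient path lives in backward()
    mean = x.mean(0)
    invstd = (x.var(0, unbiased=False) + 1e-5).rsqrt()
(_StatsOutside.apply(x, w, b, mean, invstd) ** 2).sum().backward()

# Reference: let autograd differentiate through the statistics itself.
x2 = x.detach().clone().requires_grad_()
w2 = w.detach().clone().requires_grad_()
b2 = b.detach().clone().requires_grad_()
ref = torch.nn.functional.batch_norm(x2, None, None, w2, b2, training=True, eps=1e-5)
(ref ** 2).sum().backward()
assert torch.allclose(x.grad, x2.grad) and torch.allclose(w.grad, w2.grad)

The final assert checks the hand-written backward against autograd differentiating through the statistics itself; in the distributed version, the only extra step is that mean and invstd come from an all_reduce across the process group.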
@@ -23,10 +23,9 @@ def _forward(input: Tensor, affine: bool, mean: Tensor, invstd: Tensor, weight:
 def _track_running_stats(
     running_mean: Tensor, running_var: Tensor, momentum: float, mean: Tensor, var: Tensor, total_count: Tensor
 ) -> None:
-    with torch.no_grad():
-        unbiased_var = var * (total_count / (total_count - 1))
-        running_mean += momentum * (mean.reshape(-1) - running_mean)
-        running_var += momentum * (unbiased_var.reshape(-1) - running_var)
+    unbiased_var = var * (total_count / (total_count - 1))
+    running_mean += momentum * (mean.reshape(-1) - running_mean)
+    running_var += momentum * (unbiased_var.reshape(-1) - running_var)


 def _calculate_stats(input: Tensor, eps: float, process_group: ProcessGroup) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
@@ -52,6 +51,14 @@ if torch.__version__.split(".")[:2] >= ["1", "7"]:

 class _SyncBatchNormFunction(torch.autograd.Function):
+    """
+    An autograd function used to avoid storing activations for intermediate results.
+
+    NOTE: Even though the mean and var are passed into this function, we do the entire
+    backward, including mean and var, here. We have to calculate statistics outside
+    this function in order to avoid multiple all_reduces when using checkpointing.
+    """
+
     @staticmethod
     # type: ignore
     def forward(ctx, input, weight, bias, affine, mean, invstd, total_count, process_group):
@@ -134,9 +141,11 @@ class SyncBatchNorm(torch.nn.BatchNorm2d):
         wrapped = is_checkpointing() or is_recomputing()
         if not wrapped or is_checkpointing():
-            mean, var, invstd, total_count = _calculate_stats(input, self.eps, self._process_group)
-            if self.track_running_stats:
-                _track_running_stats(self.running_mean, self.running_var, self.momentum, mean, var, total_count)
+            # NOTE The full backward, including mean and var, is done by _SyncBatchNormFunction.
+            with torch.no_grad():
+                mean, var, invstd, total_count = _calculate_stats(input, self.eps, self._process_group)
+                if self.track_running_stats:
+                    _track_running_stats(self.running_mean, self.running_var, self.momentum, mean, var, total_count)
             if is_checkpointing():
                 self.saved_for_2nd_fwd.append((mean, invstd, total_count))
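
The checkpointing branch above is what avoids the repeated all_reduce mentioned in the new docstring: under activation checkpointing the forward runs twice, and statistics gathered on the first pass (saved_for_2nd_fwd) are replayed on the recompute instead of being recollected. A rough sketch of that caching idea, with a hypothetical StatsCache standing in for fairscale's is_checkpointing()/is_recomputing() machinery:

from collections import deque
from typing import Callable, Deque, Tuple

Stats = Tuple[float, float, int]  # stand-in for (mean, invstd, total_count)


class StatsCache:
    """Hypothetical illustration of the saved_for_2nd_fwd pattern."""

    def __init__(self) -> None:
        self.saved_for_2nd_fwd: Deque[Stats] = deque()

    def get(self, compute_stats: Callable[[], Stats], checkpointing: bool, recomputing: bool) -> Stats:
        if recomputing:
            # Recompute pass inside backward: replay the cached stats and skip
            # the collective computation that compute_stats would repeat.
            return self.saved_for_2nd_fwd.popleft()
        stats = compute_stats()  # the one (possibly all_reduce-backed) computation
        if checkpointing:
            # First pass under checkpointing: remember stats for the replay.
            self.saved_for_2nd_fwd.append(stats)
        return stats


cache = StatsCache()
first = cache.get(lambda: (0.1, 0.9, 64), checkpointing=True, recomputing=False)
replay = cache.get(lambda: (-1.0, -1.0, -1), checkpointing=False, recomputing=True)
assert first == replay  # identical stats, but the expensive computation ran once

In the real module the cached tuple is (mean, invstd, total_count), and because the whole statistics block now runs under torch.no_grad(), the only gradient path through the statistics is the one _SyncBatchNormFunction.backward reconstructs by hand.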