Commit 14e60427 authored by Lawrence McAfee

cleaned distributed.py.

parent 2c1660e7
@@ -15,9 +15,7 @@
 from abc import ABC
 from abc import abstractmethod
-# >>>
 import math
-# <<<
 import torch
 from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
@@ -27,10 +25,8 @@ from megatron import mpu
 from .module import MegatronModule


 class MemoryBuffer:

-    # >>>
     def __init__(self, numel, numel_padded, dtype):
         self.numel = numel
         self.numel_padded = numel_padded
@@ -39,7 +35,6 @@ class MemoryBuffer:
                                  dtype=self.dtype,
                                  device=torch.cuda.current_device(),
                                  requires_grad=False)
-    # <<<

     def zero(self):
         """Reset the buffer to zero."""
@@ -148,30 +143,17 @@ class DistributedDataParallel(DistributedDataParallelBase):
         # Allocate the buffer.
         for dtype, num_elements in type_num_elements.items():

-            # >>>
             # If using distributed optimizer, pad memory buffer to be
             # multiple of data_parallel_world_size. (This padding is done
             # due to a constraint with the reduce_scatter op, which requires
             # all tensors have equal size. See: optimizer.py.)
             num_elements_padded = data_parallel_world_size * \
                 int(math.ceil(num_elements / data_parallel_world_size))
-            # <<<

             # Allocate grad buffer.
             self._grad_buffers[dtype] = MemoryBuffer(num_elements,
                                                      num_elements_padded,
                                                      dtype)
-            # >>>
-            # from lutil import pax
-            # if True or num_elements % data_parallel_world_size != 0:
-            #     pax(0, {
-            #         "data_parallel_world_size" : data_parallel_world_size,
-            #         "num_elements" : num_elements,
-            #         "num_elements_padded" : num_elements_padded,
-            #         "modulo" : num_elements % data_parallel_world_size,
-            #         "grad buffer" : self._grad_buffers[dtype],
-            #     })
-            # <<<

             # Assume the back prop order is reverse the params order,
             # store the start index for the gradients.
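The padding above rounds each dtype's element count up to a multiple of data_parallel_world_size, so the grad buffer can later be split into equal reduce-scatter shards (see optimizer.py). A small worked example of that arithmetic; the function name and numbers are illustrative:

import math


def pad_to_multiple(num_elements, data_parallel_world_size):
    """Round num_elements up to a multiple of data_parallel_world_size,
    mirroring the num_elements_padded computation above."""
    return data_parallel_world_size * int(
        math.ceil(num_elements / data_parallel_world_size))


# Illustrative numbers: 10 gradient elements across 4 data-parallel ranks.
assert pad_to_multiple(10, 4) == 12  # each reduce-scatter shard: 12 // 4 == 3 elements
assert pad_to_multiple(12, 4) == 12  # already a multiple, so no extra padding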
@@ -181,20 +163,12 @@ class DistributedDataParallel(DistributedDataParallelBase):
                     type_num_elements[dtype] -= param.data.nelement()
                     param.main_grad = self._grad_buffers[dtype].get(
                         param.data.shape, type_num_elements[dtype])
-                    # >>>
-                    # self._grad_buffer_param_offsets[dtype][param] = \
-                    #     type_num_elements[dtype]
                     if dtype not in self._grad_buffer_param_index_map:
                         self._grad_buffer_param_index_map[dtype] = {}
-                    # self._grad_buffer_param_index_map[dtype][param] = {
-                    #     "start" : type_num_elements[dtype],
-                    #     "end" : type_num_elements[dtype] + param.data.nelement(),
-                    # }
                     self._grad_buffer_param_index_map[dtype][param] = (
                         type_num_elements[dtype],
                         type_num_elements[dtype] + param.data.nelement(),
                     )
-                    # <<<

         # Backward hook.
         # Accumalation function for the gradients. We need
......
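_grad_buffer_param_index_map records, per dtype, a (start, end) pair locating each parameter's gradient inside the flat buffer; type_num_elements[dtype] is decremented before each assignment because parameters are packed from the end of the buffer toward the front. A standalone sketch of that bookkeeping; the parameter shapes and integer keys are made up:

import torch

# Made-up parameter shapes for a single dtype bucket.
params = [torch.empty(4, 3), torch.empty(5), torch.empty(2, 2)]

num_elements = sum(p.nelement() for p in params)  # 21
index_map = {}

# Walk the parameters in declaration order while filling the buffer from the
# end toward the front, as the loop above does with type_num_elements.
offset = num_elements
for i, p in enumerate(params):
    offset -= p.nelement()
    index_map[i] = (offset, offset + p.nelement())

print(index_map)  # {0: (9, 21), 1: (4, 9), 2: (0, 4)}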