Commit 14e60427 authored by Lawrence McAfee

cleaned distributed.py.

parent 2c1660e7
@@ -15,9 +15,7 @@
from abc import ABC
from abc import abstractmethod
-# >>>
import math
-# <<<
import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
@@ -27,10 +25,8 @@ from megatron import mpu
from .module import MegatronModule


class MemoryBuffer:

-    # >>>
    def __init__(self, numel, numel_padded, dtype):
        self.numel = numel
        self.numel_padded = numel_padded
@@ -39,7 +35,6 @@ class MemoryBuffer:
                                dtype=self.dtype,
                                device=torch.cuda.current_device(),
                                requires_grad=False)
-    # <<<

    def zero(self):
        """Reset the buffer to zero."""
@@ -148,30 +143,17 @@ class DistributedDataParallel(DistributedDataParallelBase):
        # Allocate the buffer.
        for dtype, num_elements in type_num_elements.items():
-            # >>>
            # If using distributed optimizer, pad memory buffer to be
            # multiple of data_parallel_world_size. (This padding is done
            # due to a constraint with the reduce_scatter op, which requires
            # all tensors have equal size. See: optimizer.py.)
            num_elements_padded = data_parallel_world_size * \
                int(math.ceil(num_elements / data_parallel_world_size))
-            # <<<
            # Allocate grad buffer.
            self._grad_buffers[dtype] = MemoryBuffer(num_elements,
                                                     num_elements_padded,
                                                     dtype)
-            # >>>
-            # from lutil import pax
-            # if True or num_elements % data_parallel_world_size != 0:
-            #     pax(0, {
-            #         "data_parallel_world_size" : data_parallel_world_size,
-            #         "num_elements" : num_elements,
-            #         "num_elements_padded" : num_elements_padded,
-            #         "modulo" : num_elements % data_parallel_world_size,
-            #         "grad buffer" : self._grad_buffers[dtype],
-            #     })
-            # <<<
        # Assume the back prop order is reverse the params order,
        # store the start index for the gradients.
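The padding introduced in the hunk above only rounds each dtype's element count up to the next multiple of the data-parallel world size, so that reduce_scatter can split the grad buffer into equal shards. A small worked example with made-up sizes:

import math

def padded_numel(num_elements, data_parallel_world_size):
    # Round up to the next multiple of data_parallel_world_size,
    # mirroring the expression in the diff above.
    return data_parallel_world_size * int(
        math.ceil(num_elements / data_parallel_world_size))

# Hypothetical sizes, for illustration only.
dp_world_size = 8
num_elements = 1_000_003

num_elements_padded = padded_numel(num_elements, dp_world_size)
print(num_elements_padded)                    # 1000008
print(num_elements_padded % dp_world_size)    # 0 -> equal reduce_scatter shards
print(num_elements_padded // dp_world_size)   # 125001 elements per rank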
@@ -181,20 +163,12 @@ class DistributedDataParallel(DistributedDataParallelBase):
                type_num_elements[dtype] -= param.data.nelement()
                param.main_grad = self._grad_buffers[dtype].get(
                    param.data.shape, type_num_elements[dtype])
-                # >>>
-                # self._grad_buffer_param_offsets[dtype][param] = \
-                #     type_num_elements[dtype]
                if dtype not in self._grad_buffer_param_index_map:
                    self._grad_buffer_param_index_map[dtype] = {}
-                # self._grad_buffer_param_index_map[dtype][param] = {
-                #     "start" : type_num_elements[dtype],
-                #     "end" : type_num_elements[dtype] + param.data.nelement(),
-                # }
                self._grad_buffer_param_index_map[dtype][param] = (
                    type_num_elements[dtype],
                    type_num_elements[dtype] + param.data.nelement(),
                )
-                # <<<
        # Backward hook.
        # Accumalation function for the gradients. We need
...
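The (start, end) tuples recorded in _grad_buffer_param_index_map in the last hunk let a distributed optimizer locate each parameter's gradient inside the flat buffer without keeping the commented-out dict-of-dicts around. The sketch below shows how such a map can be consumed; the helper name and the toy parameters are made up for illustration and are not Megatron-LM API.

import torch


def grad_views_from_index_map(grad_buffer_data, param_index_map):
    # Given the flat grad buffer for one dtype and its {param: (start, end)}
    # map, return {param: 1-D gradient view}. Illustrative helper only.
    views = {}
    for param, (start, end) in param_index_map.items():
        assert end - start == param.data.nelement()
        views[param] = grad_buffer_data[start:end]  # view, no copy
    return views


# Toy "model": two parameters sharing one flat buffer, with offsets laid out
# back-to-front, loosely mirroring how the loop in the diff decrements
# type_num_elements before recording each parameter's slice.
p1 = torch.nn.Parameter(torch.randn(2, 3))
p2 = torch.nn.Parameter(torch.randn(4))
flat_buffer = torch.zeros(p1.numel() + p2.numel())
index_map = {p2: (0, p2.numel()), p1: (p2.numel(), p2.numel() + p1.numel())}

views = grad_views_from_index_map(flat_buffer, index_map)
views[p1].copy_(torch.ones(p1.numel()))
print(flat_buffer[p2.numel():].view(2, 3))  # all ones: same storage as views[p1]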