Commit f48e1f29 authored by Lawrence McAfee

studied float16 optimizer; more updates

parent 49cca4d9
...
@@ -643,102 +643,235 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):
 # >>>
-class Float16DistributedOptimizer(Float16OptimizerWithFloat16Params):
-
-    def __init__(self, *args):
-        super().__init__(*args)
-        self.initialized = False
-        # >>>
-        self.initialize()
-        # <<<
-
-    def initialize(self):
-
-        # >>>
-        import math
-        # <<<
-
-        if self.initialized:
-            raise Exception("initialization worked.")
-            return
-        self.initialized = True
-
-        data_parallel_rank = mpu.get_data_parallel_rank()
-        data_parallel_world_size = mpu.get_data_parallel_world_size()
-
-        total_param_size = sum(
-            p.numel()
-            for g in self.param_groups
-            for p in g["params"]
-        )
-
-        shard_size = int(math.ceil(total_param_size / data_parallel_world_size))
-        shard_start_index = data_parallel_rank * shard_size
-        shard_end_index = min(total_param_size, shard_start_index + shard_size)
-        self.shard_size = shard_end_index - shard_start_index
-
-        # allocate_shard = lambda dtype : torch.empty(
-        #     [self.shard_size],
-        #     dtype = dtype,
-        #     device = torch.cuda.current_device())
-        allocate_shard = lambda dtype : MemoryBuffer(self.shard_size, dtype)
-
-        self.main_param_shard = allocate_shard(torch.float)
-        self.main_grad_shard = allocate_shard(torch.float)
-        self.adam_m_shard = allocate_shard(torch.float)
-        self.adam_v_shard = allocate_shard(torch.float)
-
-    def reduce_gradients(self, model):
-
-        # >>>
-        # from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
-        from megatron import get_args
-        # from megatron import get_timers
-        # from megatron.model import DistributedDataParallel as LocalDDP
-        # from megatron.model import Float16Module
-        # from megatron.utils import unwrap_model
+import math
+# from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
+
+from megatron import get_args
+# from megatron import get_timers
+# from megatron.model import DistributedDataParallel as LocalDDP
+# from megatron.model import Float16Module
+# from megatron.utils import unwrap_model
+
+# >>>
+from lutil import pax, tp
+# <<<
+
+# class Float16DistributedOptimizer(Float16OptimizerWithFloat16Params):
+class Float16DistributedOptimizer(MegatronOptimizer):
+
+    # >>>
+    @classmethod
+    def test_reduce_scatter(cls):
+
+        torch.manual_seed(mpu.get_data_parallel_rank())
+
+        size = (20,)
+        dtype = torch.float
+        device = torch.cuda.current_device()
+        data_parallel_world_size = mpu.get_data_parallel_world_size()
+        data_parallel_group = mpu.get_data_parallel_group()
+
+        input_list = [
+            # torch.randn(size, dtype = dtype, device = device)
+            5 * torch.randint(low = 1, high = 3, size = size, dtype = dtype, device = device)
+            for _ in range(data_parallel_world_size)
+        ]
+        output = torch.empty(size, dtype = dtype, device = device)
+        torch.distributed.reduce_scatter(
+            output,
+            input_list,
+            group = data_parallel_group,
+        )
+
+        if torch.distributed.get_rank() == 0:
+            print(output)
+        pax(0, {
+            "data_parallel_world_size" : data_parallel_world_size,
+            "data_parallel_group" : data_parallel_group,
+            "input_list" : input_list,
+            "output" : tp(output),
+        })
+    # <<<
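
The test_reduce_scatter scratch method added above exercises torch.distributed.reduce_scatter: every rank passes a list with one tensor per rank, and afterwards each rank holds the elementwise sum of the tensors all ranks contributed at its index. The following standalone sketch emulates those semantics with plain tensor ops on a single process; the world size and chunk size are made up for illustration and are not taken from the commit.

import torch

# Emulate reduce_scatter without a process group: inputs[r][i] is the chunk
# rank r would contribute toward rank i's output.
world_size = 4
chunk_numel = 5
torch.manual_seed(0)
inputs = [[torch.randn(chunk_numel) for _ in range(world_size)]
          for _ in range(world_size)]

# After the collective, rank i holds the sum of every rank's i-th chunk.
outputs = [torch.stack([inputs[r][i] for r in range(world_size)]).sum(dim=0)
           for i in range(world_size)]

assert all(out.shape == (chunk_numel,) for out in outputs)
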
+    # def __init__(self, *_args):
+    #     super().__init__(*_args)
+    def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad,
+                 params_have_main_grad, use_contiguous_buffers_in_local_ddp,
+                 bf16, grad_scaler):
+
+        super().__init__(
+            optimizer, clip_grad, log_num_zeros_in_grad,
+            params_have_main_grad, use_contiguous_buffers_in_local_ddp)
+
+        # >>>
+        # self.test_reduce_scatter()
+        # <<<
+
+        # >>>
+        args = get_args()
+        # <<<
+
+        # Data parallel info.
+        self.data_parallel_group = mpu.get_data_parallel_group()
+        self.data_parallel_rank = mpu.get_data_parallel_rank()
+        self.data_parallel_world_size = mpu.get_data_parallel_world_size()
+
+        # Total trainable param count.
+        # self.total_param_size = sum(
+        #     p.numel()
+        #     for g in self.param_groups
+        #     for p in g["params"]
+        #     # if p .requires_grad ???
+        # )
+
+        # Model params: group sizes, group offset maps.
+        # self.model_params = []
+        # self.model_param_group_sizes = []
+        # self.model_param_group_offset_maps = []
+        self.model_param_groups = []
+        for param_group in self.optimizer.param_groups:
+            param_group_offset = 0
+            param_group_offset_map = {}
+            for param in param_group['params']:
+                if not param.requires_grad:
+                    continue
+                # self.model_params.append(param)
+                param_group_offset_map[param] = {
+                    "start" : param_group_offset,
+                    "end" : param_group_offset + param.numel(),
+                }
+                param_group_offset += param.numel()
+            # self.model_param_group_sizes.append(param_group_offset)
+            # self.model_param_group_offset_maps.append(param_group_offset_map)
+            self.model_param_groups.append({
+                "size" : param_group_offset,
+                "offset_map" : param_group_offset_map,
+            })
+
+        # pax(0, {
+        #     "model_params" : model_params,
+        #     "model_param_group_sizes" : model_param_group_sizes,
+        #     "model_param_group_offset_maps" : model_param_group_offset_maps,
+        # })
+
+        # Shard allocator.
+        allocate_shard = lambda shard_size, dtype : torch.empty(
+            (shard_size,),
+            dtype = dtype,
+            device = torch.cuda.current_device())
+        # allocate_shard = lambda dtype : MemoryBuffer(self.shard_size, dtype)
+
+        # Collect DP world shard infos, per group.
+        model_main_dtypes = set([ args.params_dtype, torch.float ])
+        self.world_shard_info_groups = [] # world_group_shard_infos ?
+        self.main_param_shard_groups = []
+        for model_param_group_size in model_param_group_sizes:
+
+            max_world_shard_size = int(math.ceil(model_param_group_size /
+                                                 self.data_parallel_world_size))
+
+            # Group shard infos.
+            shard_infos = []
+            for r in range(self.data_parallel_world_size):
+                shard_start_index = r * max_shard_size
+                shard_end_index = min(self.total_param_size,
+                                      shard_start_index + max_shard_size)
+                shard_infos.append({
+                    "start" : shard_start_index,
+                    "end" : shard_end_index,
+                    "size" : shard_end_index - shard_start_index,
+                })
+            self.world_shard_info_groups.append(shard_infos)
+
+            # Allocate shards.
+            local_shard_size = \
+                self.world_shard_infos[self.data_parallel_rank]["size"]
+            # # self.main_param_shard = allocate_shard(torch.float)
+            # # self.main_grad_shard = allocate_shard(torch.float)
+            # self.param_shard_map = {ty:allocate_shard(ty) for ty in dtypes}
+            # self.grad_shard_map = {ty:allocate_shard(ty) for ty in dtypes}
+            # self.adam_m_shard = allocate_shard(torch.float)
+            # self.adam_v_shard = allocate_shard(torch.float)
+            self.main_param_shard_groups.append({ty:allocate_shard(ty)
+                                                 for ty in model_main_dtypes})
+
+        # >>>
+        # pax(0, {
+        #     "total_param_size" : self.total_param_size,
+        #     "max_shard_size" : max_shard_size,
+        #     "shard_infos" : self.shard_infos,
+        #     "shard_size" : shard_size,
+        #     "param_shard_map" : self.param_shard_map,
+        # })
+        # <<<
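
The shard bookkeeping built in the new __init__ above is a plain ceil-division partition: each data-parallel rank owns a contiguous slice of at most ceil(group_size / world_size) elements of the flattened parameter group, and the trailing slices may be shorter or empty. A self-contained sketch of that partition logic follows; the helper name and example sizes are illustrative, not part of the commit.

import math

def build_shard_infos(group_numel, world_size):
    # Contiguous [start, end) slice per rank; the last slices absorb the
    # remainder and may be empty when world_size does not divide group_numel.
    max_shard = int(math.ceil(group_numel / world_size))
    infos = []
    for rank in range(world_size):
        start = rank * max_shard
        end = min(group_numel, start + max_shard)
        infos.append({"start": start, "end": end, "size": max(0, end - start)})
    return infos

# 10 parameters over 4 data-parallel ranks -> shard sizes 3, 3, 3, 1.
assert [s["size"] for s in build_shard_infos(10, 4)] == [3, 3, 3, 1]
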
+    def get_loss_scale(self):
+        raise Exception("hi.")
+    def load_state_dict(self):
+        raise Exception("hi.")
+    def reload_model_params(self):
+        raise Exception("hi.")
+    def state_dict(self):
+        raise Exception("hi.")
+    def zero_grad(self):
+        raise Exception("hi.")
+
+    def reduce_gradients(self, model):
+
+        # >>>
         args = get_args()
         # timers = get_timers()
         # <<<

-        # >>>
+        # >>> [ already checked in arguments.py ]
         assert args.use_contiguous_buffers_in_local_ddp
         # <<<

         # grad_buffers = [ m._grad_buffers for m in model ]
         for virtual_model in model:
-            grad_buffers = virtual_model._grad_buffers
-            for dtype, grad_buffer in grad_buffers.items():
+            grad_buffer_map = virtual_model._grad_buffers
+
+            # >>>
+            assert len(grad_buffer_map) == 1, \
+                "multiple param types not currently supported."
+            assert args.params_dtype in grad_buffer_map
+            assert self.total_param_size == grad_buffer_map[args.params_dtype].numel
+            # <<<
+
+            # pax(0, {
+            #     "total_param_size" : self.total_param_size,
+            #     "grad_buffer" : tp(grad_buffer_map[args.params_dtype]),
+            # })
+
+            for dtype, grad_buffer in grad_buffer_map.items():
                 dp_grad_buffers = [
-                    grad_buffer.get(self.shard_sizes[i],
-                                    self.shard_start_indexes[i])
-                    for i in self.data_parallel_world_size]
+                    grad_buffer.get(torch.Size((self.shard_infos[i]["size"],)),
+                                    self.shard_infos[i]["start"])
+                    for i in range(self.data_parallel_world_size)]
+                grad_shard = self.grad_shard_map[dtype]
+
+                pax(0, {"dp_grad_buffers": dp_grad_buffers})
+
                 torch.distributed.reduce_scatter(
-                    self.main_grad_shard,
-                    grad_buffer.data,
-                    group = mpu.get_data_parallel_group(),
+                    grad_shard,
+                    dp_grad_buffers,
+                    group = self.data_parallel_group,
                 )
                 # >>>
                 pax(0, {
                     "virtual_model" : virtual_model,
-                    "grad_buffers" : grad_buffers,
+                    "grad_buffer_map" : grad_buffer_map,
                     "dtype" : dtype,
-                    "grad_buffer / len" : grad_buffer.numel,
-                    "grad_buffer / data" : tp(grad_buffer.data),
-                    # "optimizer" : self.optimizer,
-                    "main_grad_shard" : tp(self.main_grad_shard),
+                    "grad_shard" : tp(grad_shard),
+                    "dp_grad_buffers" : dp_grad_buffers,
                 })
                 # <<<

         # >>>
-        from lutil import pax, tp
         pax(0, {
             "model" : model,
             "grad_buffers" : grad_buffers,
...
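
For reference, the data movement that the new reduce_gradients is working toward can be emulated without Megatron or a process group: each data-parallel rank's contiguous grad buffer is viewed as world_size shard-sized slices, and a reduce-scatter leaves rank r holding the fully reduced gradients for only the slice it owns. A minimal sketch under those assumptions, with illustrative tensor sizes and names rather than Megatron's API:

import torch

world_size = 4
shard_numel = 6
total_numel = world_size * shard_numel

# One flat grad buffer per data-parallel rank, faked here on a single process.
torch.manual_seed(0)
grad_buffers = [torch.randn(total_numel) for _ in range(world_size)]

# View each buffer as world_size contiguous shard-sized slices ...
views = [buf.view(world_size, shard_numel) for buf in grad_buffers]

# ... then "reduce_scatter": rank r keeps the sum of every rank's slice r.
local_grad_shards = [torch.stack([v[r] for v in views]).sum(dim=0)
                     for r in range(world_size)]

# Concatenating the shards reproduces the fully all-reduced gradient buffer.
full_reduce = torch.stack(grad_buffers).sum(dim=0)
assert torch.allclose(torch.cat(local_grad_shards), full_reduce)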