"vscode:/vscode.git/clone" did not exist on "7396d8f6818bbbbadedc0b034e60726dddee0b4e"
Commit d1f5776a authored by Lawrence McAfee

small changes

parent ff4c4273
@@ -1302,8 +1302,19 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
         # timers = get_timers()
         # <<<
+        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+        # Sync word embedding params.
+        # ... todo ...
+        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+        # Sync T5 position embedding params.
+        # ... todo ...
+        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
         # Reduce-scatter.
         # ** only contiguous grad buffer supported, for now [ TEMPORARY ] **
         assert args.use_contiguous_buffers_in_local_ddp
@@ -1334,64 +1345,10 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
         # })
         # >>>
-        torch.distributed.barrier()
-        raise Exception("hi.")
+        # torch.distributed.barrier()
+        # raise Exception("hi.")
         # <<<
-        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-        # grad_buffers = [ m._grad_buffers for m in model ]
-        for virtual_model in model:
-            grad_buffer_map = virtual_model._grad_buffers
-            # >>>
-            assert len(grad_buffer_map) == 1, \
-                "multiple param types not currently supported."
-            assert args.params_dtype in grad_buffer_map
-            assert self.total_param_size == grad_buffer_map[args.params_dtype].numel
-            # <<<
-            # pax(0, {
-            #     "total_param_size" : self.total_param_size,
-            #     "grad_buffer" : tp(grad_buffer_map[args.params_dtype]),
-            # })
-            for dtype, grad_buffer in grad_buffer_map.items():
-                dp_grad_buffers = [
-                    grad_buffer.get(torch.Size((self.shard_infos[i]["size"],)),
-                                    self.shard_infos[i]["start"])
-                    for i in range(self.data_parallel_world_size)]
-                grad_shard = self.grad_shard_map[dtype]
-                torch.distributed.reduce_scatter(
-                    grad_shard,
-                    dp_grad_buffers,
-                    group = self.data_parallel_group,
-                )
-                # >>>
-                pax(0, {
-                    "virtual_model" : virtual_model,
-                    "grad_buffer_map" : grad_buffer_map,
-                    "dtype" : dtype,
-                    "grad_shard" : tp(grad_shard),
-                    "dp_grad_buffers" : dp_grad_buffers,
-                })
-                # <<<
-        # >>>
-        pax(0, {
-            "model" : model,
-            "grad_buffers" : grad_buffers,
-            "grad_buffers / 0" : grad_buffers[0],
-            "grad_buffers / 0 / data" : tp(list(grad_buffers[0].values())[0].data),
-        })
-        # <<<
     def step(self):
         raise Exception("step.")
...
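For reference, below is a minimal standalone sketch (not the repository's code) of the reduce-scatter pattern in the block removed above: every data-parallel rank contributes per-rank views of the full, contiguous grad buffer and receives back only the reduced shard it owns. The function name, the equal-sized split, and the freshly allocated output tensor are illustrative assumptions; the actual code slices the buffer with self.shard_infos[i]["start"]/["size"] and writes into a preallocated self.grad_shard_map[dtype]. It assumes torch.distributed is already initialized with a data-parallel process group.

import torch
import torch.distributed as dist

def reduce_scatter_contiguous_grads(grad_buffer, data_parallel_group):
    """Reduce-scatter a flat, contiguous grad buffer across data-parallel ranks.

    Hypothetical helper for illustration only; it mirrors the per-dtype loop in
    the removed block under the simplifying assumption of equal-sized shards.
    """
    world_size = dist.get_world_size(group=data_parallel_group)
    rank = dist.get_rank(group=data_parallel_group)

    # One equal-sized shard per data-parallel rank.
    assert grad_buffer.numel() % world_size == 0
    shard_size = grad_buffer.numel() // world_size

    # Per-rank views into the full buffer (no copies); analogous to dp_grad_buffers.
    dp_grad_buffers = [
        grad_buffer.narrow(0, r * shard_size, shard_size)
        for r in range(world_size)
    ]

    # Each rank receives the summed gradients for the shard it owns.
    grad_shard = torch.empty_like(dp_grad_buffers[rank])
    dist.reduce_scatter(grad_shard, dp_grad_buffers, group=data_parallel_group)
    return grad_shard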