Commit a4f41882 authored by Lawrence McAfee

copy model params to main params at very start

parent ae6cd9b5
......@@ -696,37 +696,37 @@ class Shard:
 class Float16DistributedOptimizer(BaseFloat16Optimizer):
     # >>>
-    @classmethod
-    def test_reduce_scatter(cls):
-        torch.manual_seed(mpu.get_data_parallel_rank())
-        size = (20,)
-        dtype = torch.float
-        device = torch.cuda.current_device()
-        data_parallel_world_size = mpu.get_data_parallel_world_size()
-        data_parallel_group = mpu.get_data_parallel_group()
-        input_list = [
-            # torch.randn(size, dtype = dtype, device = device)
-            5 * torch.randint(low = 1, high = 3, size = size, dtype = dtype, device = device)
-            for _ in range(data_parallel_world_size)
-        ]
-        output = torch.empty(size, dtype = dtype, device = device)
-        torch.distributed.reduce_scatter(
-            output,
-            input_list,
-            group = data_parallel_group,
-        )
-        if torch.distributed.get_rank() == 0:
-            print(output)
-        pax(0, {
-            "data_parallel_world_size" : data_parallel_world_size,
-            "data_parallel_group" : data_parallel_group,
-            "input_list" : input_list,
-            "output" : tp(output),
-        })
+    # @classmethod
+    # def test_reduce_scatter(cls):
+    #     torch.manual_seed(mpu.get_data_parallel_rank())
+    #     size = (20,)
+    #     dtype = torch.float
+    #     device = torch.cuda.current_device()
+    #     data_parallel_world_size = mpu.get_data_parallel_world_size()
+    #     data_parallel_group = mpu.get_data_parallel_group()
+    #     input_list = [
+    #         # torch.randn(size, dtype = dtype, device = device)
+    #         5 * torch.randint(low = 1, high = 3, size = size, dtype = dtype, device = device)
+    #         for _ in range(data_parallel_world_size)
+    #     ]
+    #     output = torch.empty(size, dtype = dtype, device = device)
+    #     torch.distributed.reduce_scatter(
+    #         output,
+    #         input_list,
+    #         group = data_parallel_group,
+    #     )
+    #     if torch.distributed.get_rank() == 0:
+    #         print(output)
+    #     pax(0, {
+    #         "data_parallel_world_size" : data_parallel_world_size,
+    #         "data_parallel_group" : data_parallel_group,
+    #         "input_list" : input_list,
+    #         "output" : tp(output),
+    #     })
     # <<<
     @classmethod
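
Note (not part of the commit): the commented-out test above exercises torch.distributed.reduce_scatter, where rank r ends up with the elementwise sum, over all ranks, of each rank's input_list[r]. The sketch below emulates that result with plain tensors and made-up values, so it needs no process group or GPUs.

import torch

world_size = 4
size = (20,)

# inputs_per_rank[src][dst]: the tensor rank `src` would pass as input_list[dst].
torch.manual_seed(0)
inputs_per_rank = [
    [5 * torch.randint(low=1, high=3, size=size, dtype=torch.float) for _ in range(world_size)]
    for _ in range(world_size)
]

# What reduce_scatter would leave in `output` on each rank:
# rank dst receives the sum over src of inputs_per_rank[src][dst].
expected_outputs = [
    torch.stack([inputs_per_rank[src][dst] for src in range(world_size)]).sum(dim=0)
    for dst in range(world_size)
]
print(expected_outputs[0])
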
......@@ -750,10 +750,17 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
             if param_local_end > param_local_start:
                 param_local_shard = Shard(param_local_start, param_local_end)
                 param_world_shard = param_local_shard.normalize(param_world_start)
+                sub_param_start = max(0, gbuf_world_shard.start-param_world_start)
+                sub_param_shard = param_local_shard.normalize(sub_param_start)
                 param_shard_map[param] = {
-                    "local" : param_local_shard,
-                    "world" : param_world_shard,
+                    "gbuf_world" : param_world_shard,
+                    "gbuf_local" : param_local_shard,
+                    "param" : sub_param_shard,
                 }
+                # >>>
+                if param_world_start < gbuf_world_shard.start:
+                    raise Exception("hi.")
+                # <<<
         # pax(0, {"param_shard_map": [ str((str(p.shape), s)) for p,s in param_shard_map.items() ]})
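
Note (not part of the commit): a toy illustration of how the "gbuf_local", "gbuf_world", and "param" shards above relate. It assumes Shard holds a half-open [start, end) range and that Shard.normalize(start) rebases the range to the given start while keeping its size; the local-range formulas and all numbers are invented for the example.

# Stand-in for the Shard class used in the diff (an assumption about its API).
class Shard:
    def __init__(self, start, end):
        self.start = start
        self.end = end
        self.size = end - start
    def normalize(self, start=0):
        # Same-size range, rebased to `start`.
        return Shard(start, start + self.size)
    def __repr__(self):
        return "[%d, %d)" % (self.start, self.end)

# This rank owns grad-buffer world elements [100, 200); one param occupies
# world elements [80, 130) of that buffer.
gbuf_world_shard = Shard(100, 200)
param_world_start, param_world_end = 80, 130

# Portion of the param that falls inside this rank's gbuf shard,
# in gbuf-shard-local coordinates ("gbuf_local") ...
param_local_start = max(0, param_world_start - gbuf_world_shard.start)
param_local_end = min(gbuf_world_shard.size, param_world_end - gbuf_world_shard.start)
param_local_shard = Shard(param_local_start, param_local_end)   # [0, 30)

# ... and in coordinates within the param itself ("param").
sub_param_start = max(0, gbuf_world_shard.start - param_world_start)
sub_param_shard = param_local_shard.normalize(sub_param_start)  # [20, 50)

print(param_local_shard, sub_param_shard)
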
......@@ -798,17 +805,11 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
     @classmethod
     def get_model_gbuf_shard_map(cls, model):
-        # shard_index_map = {
-        shard_map = {
+        return {
             dtype : cls.get_model_gbuf_shard(model, dtype)
             for dtype in model._grad_buffers
         }
-        # pax(0, {"shard_map": shard_map})
-        return shard_map
     @classmethod
     def get_param_gbuf_map(cls, model_gbuf_shards):
......@@ -855,7 +856,7 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
                     group_index = param_group_map[param]
                     group_shard = group_shards[group_index]
-                    param_size = gbuf_shard_map["param_map"][param]["local"].size
+                    param_size = gbuf_shard_map["param_map"][param]["param"].size
                     param_group_start = group_shard["size"]
                     param_group_end = param_group_start + param_size
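
Note (not part of the commit): a toy sketch of the bookkeeping this hunk touches, where each rank-owned param slice (now sized by its "param" shard) is packed into a contiguous per-group shard by accumulating sizes. Names and sizes are invented.

# Pack rank-owned param slices into one contiguous optimizer-group shard.
group_shard = {"size": 0, "param_map": {}}
owned_slice_sizes = {"weight_a": 30, "weight_b": 12}  # hypothetical "param" shard sizes

for name, param_size in owned_slice_sizes.items():
    param_group_start = group_shard["size"]
    param_group_end = param_group_start + param_size
    group_shard["param_map"][name] = (param_group_start, param_group_end)
    group_shard["size"] = param_group_end

print(group_shard)  # weight_a -> (0, 30), weight_b -> (30, 42), size 42
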
......@@ -1055,10 +1056,42 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
     def _collect_main_grad_data_for_unscaling(self):
         return [ p.grad.data for p in self.main_param_shards ]
+    def _copy_model_params_to_main_params(self):
+        for group_index, group_shard in enumerate(self.opt_group_shards):
+            main_param = self.main_param_shards[group_index]
+            for model_param, main_shard in group_shard["param_map"].items():
+                # Model shard.
+                model_index, dtype = self.param_gbuf_map[model_param]
+                model_shard = self.model_gbuf_shards \
+                    [model_index][dtype]["param_map"][model_param]["param"]
+                assert main_shard.size == model_shard.size
+                # Copy shard data.
+                main_view = main_param[main_shard.start:main_shard.end]
+                model_view = model_param[model_shard.start:model_shard.end].view(-1)
+                main_view.detach().copy_(model_view)
+                # pax(0, {
+                #     "main_param" : tp(main_param),
+                #     "model_param" : tp(model_param),
+                #     "main_view" : tp(main_view),
+                #     "model_view" : tp(model_view),
+                #     "main_shard" : str(main_shard),
+                #     "model_shard" : str(model_shard),
+                # })
+        pax(0, {
+            "opt_group_shards" : self.opt_group_shards,
+            "main_param_shards" : self.main_param_shards,
+        })
     def _copy_model_grads_to_main_grads(self):
         for group_index, group_shard in enumerate(self.opt_group_shards):
-            for param, main_shard in group_shard["param_map"].items():
+            for model_param, main_shard in group_shard["param_map"].items():
                 model_index, gbuf_dtype = self.param_gbuf_map[param]
                 model_shard = self.model_gbuf_shards \
......
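
Note (not part of the commit): a minimal sketch of the idea behind the new _copy_model_params_to_main_params, using plain tensors and invented ranges in place of the optimizer's shard bookkeeping. At startup, the slice of each fp16 model param owned by this rank is copied into its fp32 main shard, so the optimizer's master weights start from the exact model values.

import torch

model_param = torch.randn(6, 4, dtype=torch.half)  # fp16 model weight
main_param = torch.zeros(10, dtype=torch.float)     # this rank's fp32 main shard

# Hypothetical ranges: this rank owns flattened elements [8, 18) of the param,
# which land at [0, 10) of its main shard.
model_start, model_end = 8, 18
main_start, main_end = 0, 10

model_view = model_param.view(-1)[model_start:model_end]
main_view = main_param[main_start:main_end]
main_view.detach().copy_(model_view)  # in-place fp16 -> fp32 copy

assert torch.equal(main_view, model_view.float())
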