Commit 525a8351 authored by Lawrence McAfee

created optimizer group shards

parent f498a6fe
@@ -1016,6 +1016,68 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
        return shard_map

    # @classmethod
    # def get_param_size_map(cls, model_gbuf_shards):
    #     param_size_map = {}
    #     for model_gbuf_shard_map in model_gbuf_shards:
    #         for dtype, gbuf_shard_map in model_gbuf_shard_map.items():
    #             for param, param_shard_map in gbuf_shard_map["param_map"].items():
    #                 assert param not in param_size_map
    #                 param_size_map[param] = param_shard_map["local"].size
    #                 # pax(0, {
    #                 #     "dtype" : dtype,
    #                 #     "gbuf_shard_map" : gbuf_shard_map,
    #                 #     "param" : tp(param),
    #                 #     "param_shard_map" : param_shard_map,
    #                 # })
    #     pax(0, {
    #         "model_gbuf_shards" : model_gbuf_shards,
    #         "param_size_map" : [ (str(p.shape), s) for p, s in param_size_map.items() ],
    #     })
    #     return param_size_map

    @classmethod
    def get_optimizer_group_shards(cls, param_groups, model_gbuf_shards):

        num_groups = len(param_groups)

        # Param group map.
        param_group_map = {}
        for group_index, group in enumerate(param_groups):
            for param in group["params"]:
                assert param.requires_grad
                param_group_map[param] = group_index

        # Optimizer group shards.
        group_shards = [ {"size": 0, "param_map": {}} for _ in param_groups ]
        for model_gbuf_shard_map in model_gbuf_shards:
            for dtype, gbuf_shard_map in model_gbuf_shard_map.items():
                for param in gbuf_shard_map["param_map"]:
                    group_index = param_group_map[param]
                    group_shard = group_shards[group_index]
                    param_size = gbuf_shard_map["param_map"][param]["local"].size
                    param_group_start = group_shard["size"]
                    param_group_end = param_group_start + param_size
                    param_group_shard = Shard(param_group_start, param_group_end)
                    group_shard["size"] += param_size
                    group_shard["param_map"][param] = param_group_shard

        # raise Exception("hi.")
        # pax(0, {"param_group_map": [
        #     (g, str(p.shape))
        #     for p, g in param_group_map.items()
        # ]})
        # pax(0, {"group_shards": group_shards})

        return group_shards

    def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad,
                 params_have_main_grad, use_contiguous_buffers_in_local_ddp,
                 bf16, grad_scaler, models):
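
A minimal, self-contained sketch of the bookkeeping get_optimizer_group_shards performs: each parameter's local grad-buffer shard size becomes a contiguous [start, end) range inside its optimizer param group, with params packed back to back. The Shard namedtuple, the toy parameter names, and the integer sizes below are stand-ins for the file's real Shard class and model_gbuf_shards structures, not part of the commit.

# Sketch only; not part of the commit. "Shard" here is a plain (start, end)
# range container, and per-param local sizes are given directly as integers.
from collections import namedtuple

Shard = namedtuple("Shard", ["start", "end"])

def build_group_shards(param_groups, local_sizes):
    # Map each param name to the index of its optimizer param group.
    param_group_map = {p: gi
                       for gi, group in enumerate(param_groups)
                       for p in group}

    # One running shard per group: params are packed back to back, so each
    # param owns a contiguous [start, end) range inside its group.
    group_shards = [{"size": 0, "param_map": {}} for _ in param_groups]
    for param, size in local_sizes.items():
        shard = group_shards[param_group_map[param]]
        shard["param_map"][param] = Shard(shard["size"], shard["size"] + size)
        shard["size"] += size
    return group_shards

# Two param groups (e.g. decayed vs. non-decayed params), three params:
print(build_group_shards([["w1", "w2"], ["b1"]],
                         {"w1": 6, "w2": 4, "b1": 3}))
# [{'size': 10, 'param_map': {'w1': Shard(start=0, end=6),
#                             'w2': Shard(start=6, end=10)}},
#  {'size': 3, 'param_map': {'b1': Shard(start=0, end=3)}}]
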
@@ -1037,17 +1099,19 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
        # self.data_parallel_rank = mpu.get_data_parallel_rank()
        # self.data_parallel_world_size = mpu.get_data_parallel_world_size()

        # Param group map.
        self.param_group_map = {}
        for group_index, group in enumerate(self.optimizer.param_groups):
            for param in group["params"]:
                assert param.requires_grad
                self.param_group_map[param] = group_index

        # Model grad buffer shards.
        self.model_gbuf_shards = []
        for model_index, model in enumerate(self.models):
            self.model_gbuf_shards.append(self.get_model_gbuf_shard_map(model))

        # pax(0, {"param_group_map": [
        #     (g, str(p.shape))
        #     for p, g in self.param_group_map.items()
        # ]})

        # Optimizer shards.
        self.opt_group_shards = self.get_optimizer_group_shards(
            self.optimizer.param_groups,
            self.model_gbuf_shards)

        pax(0, {"opt_group_shards": self.opt_group_shards})

        # Allocate main param/grad shard.
        # Shard allocator.
        # ** torch.nn.Parameter ??
@@ -1058,18 +1122,14 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
        #     device = torch.cuda.current_device(),
        #     requires_grad = True)

        # Model grad buffer shards.
        self.model_gbuf_shards = []
        for model_index, model in enumerate(self.models):
            self.model_gbuf_shards.append(self.get_model_gbuf_shard_map(model))

        # Allocate main param/grad shard.
        param_shard_map = self.get_param_shard_map(self.model_gbuf_shards)

        # >>>
        param_size_map = self.get_param_size_map(self.model_gbuf_shards)
        pax(0, {
            "model_gbuf_shards" : self.model_gbuf_shards,
            "param_shard_map" : param_shard_map,
            "param_size_map" : param_size_map,
        })
        # <<<

        # Leverage state_dict() and load_state_dict() to
        # recast preexisting per-param state tensors
......
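
The actual allocation of the main param/grad shard is still only sketched in comments in this commit ("Allocate main param/grad shard.", "Shard allocator.", and the commented-out torch.cuda.current_device() / requires_grad lines). A hedged guess at what one such allocation could look like, one flat fp32 buffer per optimizer group shard; the helper name below is hypothetical and not from the diff.

# Hypothetical sketch, not part of the commit: one contiguous fp32 buffer
# per optimizer group shard, sized by group_shard["size"].
import torch

def allocate_main_param_shard(group_shard):
    # Whether to wrap this in torch.nn.Parameter is the open question the
    # "** torch.nn.Parameter ??" comment above raises.
    return torch.empty((group_shard["size"],),
                       dtype=torch.float32,
                       device=torch.cuda.current_device(),
                       requires_grad=True)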