Commit 33626179 authored by Lawrence McAfee

allocated main params/grads

parent 525a8351
@@ -1109,27 +1109,34 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
            self.optimizer.param_groups,
            self.model_gbuf_shards)
        pax(0, {"opt_group_shards": self.opt_group_shards})
        # pax(0, {"opt_group_shards": self.opt_group_shards})
        # Allocate main param/grad shard.
        # Shard allocator.
        # ** torch.nn.Parameter ??
        # ** MemoryBuffer ??
        # allocate_shard = lambda shard_size, dtype : torch.empty(
        #     (shard_size,),
        #     dtype = dtype,
        #     device = torch.cuda.current_device(),
        #     requires_grad = True)
        allocate_shard = lambda shard_size, dtype : torch.empty(
            (shard_size,),
            dtype = dtype,
            device = torch.cuda.current_device(),
            requires_grad = True)
        self.main_param_shards = []
        for group_index, group_shard in enumerate(self.opt_group_shards):
            # pax(0, {
            #     "group_index" : group_index,
            #     "group_shard" : group_shard,
            # })
            group_size = group_shard["size"]
            # for dtype in model_main_dtypes ........
            # >>>
            param_size_map = self.get_param_size_map(self.model_gbuf_shards)
            pax(0, {
                "model_gbuf_shards" : self.model_gbuf_shards,
                "param_size_map" : param_size_map,
            })
            # <<<
            # Allocate shard.
            main_param = allocate_shard(group_size, torch.float)
            main_param.grad = allocate_shard(group_size, torch.float)
            self.main_param_shards.append(main_param)
            # Update optimizer group.
            self.optimizer.param_groups[group_index]["params"] = [ main_param ]
        # Leverage state_dict() and load_state_dict() to
        # recast preexisting per-param state tensors
@@ -1137,7 +1144,9 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
        # >>>
        pax(0, {
            "world_shard_infos" : self.world_shard_infos,
            "model_gbuf_shards" : self.model_gbuf_shards,
            "opt_group_shards" : self.opt_group_shards,
            "main_param_shards" : self.main_param_shards,
        })
        # <<<
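Below is a minimal, self-contained sketch of the allocation pattern this commit introduces: for each optimizer group, allocate a contiguous fp32 main parameter shard with a matching fp32 grad buffer, then re-point the optimizer's param group at it. The group sizes, the `torch.optim.Adam` placeholder, and the `requires_grad` split between the param and grad buffers are assumptions for illustration (they stand in for `self.opt_group_shards` and `self.optimizer`); a CUDA device is assumed to be available.

```python
import torch

def allocate_shard(shard_size, dtype, requires_grad):
    # One contiguous 1-D buffer on the current CUDA device.
    return torch.empty(
        (shard_size,),
        dtype=dtype,
        device=torch.cuda.current_device(),
        requires_grad=requires_grad)

# Hypothetical per-group shard sizes, standing in for opt_group_shards.
group_sizes = [1024, 4096]

# Placeholder optimizer; the commit reuses the existing self.optimizer groups.
optimizer = torch.optim.Adam(
    [torch.nn.Parameter(torch.empty(size, device="cuda")) for size in group_sizes])

main_param_shards = []
for group_index, group_size in enumerate(group_sizes):
    # fp32 main weights, with a matching fp32 grad buffer attached.
    main_param = allocate_shard(group_size, torch.float, requires_grad=True)
    main_param.grad = allocate_shard(group_size, torch.float, requires_grad=False)
    main_param_shards.append(main_param)

    # Re-point the optimizer group at the freshly allocated fp32 shard.
    optimizer.param_groups[group_index]["params"] = [main_param]
```

As the trailing comment in the first hunk notes, any preexisting per-param optimizer state is still keyed to the old parameters and has to be recast against the new main params; the commit leans on `state_dict()` / `load_state_dict()` for that step.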