Commit 33626179 authored by Lawrence McAfee

allocated main params/grads

parent 525a8351
@@ -1109,27 +1109,34 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
             self.optimizer.param_groups,
             self.model_gbuf_shards)
-        pax(0, {"opt_group_shards": self.opt_group_shards})
+        # pax(0, {"opt_group_shards": self.opt_group_shards})
 
         # Allocate main param/grad shard.
+        # Shard allocator.
         # ** torch.nn.Parameter ??
         # ** MemoryBuffer ??
-        # allocate_shard = lambda shard_size, dtype : torch.empty(
-        #     (shard_size,),
-        #     dtype = dtype,
-        #     device = torch.cuda.current_device(),
-        #     requires_grad = True)
+        allocate_shard = lambda shard_size, dtype : torch.empty(
+            (shard_size,),
+            dtype = dtype,
+            device = torch.cuda.current_device(),
+            requires_grad = True)
 
-        # >>>
-        param_size_map = self.get_param_size_map(self.model_gbuf_shards)
-        pax(0, {
-            "model_gbuf_shards" : self.model_gbuf_shards,
-            "param_size_map" : param_size_map,
-        })
-        # <<<
+        self.main_param_shards = []
+        for group_index, group_shard in enumerate(self.opt_group_shards):
+
+            # pax(0, {
+            #     "group_index" : group_index,
+            #     "group_shard" : group_shard,
+            # })
+
+            group_size = group_shard["size"]
+
+            # for dtype in model_main_dtypes ........
+
+            # Allocate shard.
+            main_param = allocate_shard(group_size, torch.float)
+            main_param.grad = allocate_shard(group_size, torch.float)
+            self.main_param_shards.append(main_param)
+
+            # Update optimizer group.
+            self.optimizer.param_groups[group_index]["params"] = [ main_param ]
 
         # Leverage state_dict() and load_state_dict() to
         # recast preexisting per-param state tensors
@@ -1137,7 +1144,9 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
         # >>>
         pax(0, {
-            "world_shard_infos" : self.world_shard_infos,
+            "model_gbuf_shards" : self.model_gbuf_shards,
+            "opt_group_shards" : self.opt_group_shards,
+            "main_param_shards" : self.main_param_shards,
         })
         # <<<
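
For orientation, below is a minimal, self-contained sketch of the pattern this commit introduces: for each optimizer param group, allocate one flat fp32 "main" parameter with a matching flat fp32 grad buffer, then repoint the group's params list at that flat tensor so the optimizer steps on the fp32 copy. This is only an illustration of the idea, not the Megatron-LM implementation; helper names such as `allocate_shard` and `build_main_param_shards` are made up for the example, the buffers live on CPU instead of `torch.cuda.current_device()`, and the copy of updated fp32 values back into the fp16 model params after `optimizer.step()` is omitted.

```python
# Illustrative sketch only -- not the Megatron-LM distributed optimizer.
# Hypothetical helper names: allocate_shard, build_main_param_shards.
import torch

def allocate_shard(shard_size, dtype, device="cpu"):
    # Flat buffer sized to hold every element of a param group's model params.
    # (The commit allocates these on torch.cuda.current_device(); CPU is used
    # here so the sketch runs anywhere.)
    return torch.empty((shard_size,), dtype=dtype, device=device)

def build_main_param_shards(optimizer):
    """Give each param group a single flat fp32 main param + grad shard."""
    main_param_shards = []
    for group_index, group in enumerate(optimizer.param_groups):
        model_params = group["params"]
        group_size = sum(p.numel() for p in model_params)

        # fp32 main param and fp32 grad, both flat.
        main_param = torch.nn.Parameter(allocate_shard(group_size, torch.float))
        main_param.grad = allocate_shard(group_size, torch.float)

        # Seed the main param from the (possibly fp16) model params.
        offset = 0
        for p in model_params:
            n = p.numel()
            main_param.data[offset:offset + n].copy_(p.data.reshape(-1).float())
            offset += n

        main_param_shards.append(main_param)

        # The optimizer now steps on the flat fp32 shard, not the model params.
        optimizer.param_groups[group_index]["params"] = [main_param]
    return main_param_shards

# Usage: two fp16 "model" params in one group, optimizer rewired to fp32 shards.
model_params = [torch.nn.Parameter(torch.randn(4, 4).half()),
                torch.nn.Parameter(torch.randn(8).half())]
opt = torch.optim.SGD(model_params, lr=0.1)
shards = build_main_param_shards(opt)
shards[0].grad.fill_(1.0)   # grads would normally be reduced/copied in fp32
opt.step()                  # updates the flat fp32 main param shard
```

In the full distributed optimizer, each rank presumably owns only its shard of each group's buffer and copies the updated fp32 values back into the fp16 model params after the step; that machinery is outside this sketch.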