Commit 7ed649ed authored by Lawrence McAfee

renamed: full_ -> model_.

parent 82e6730c
@@ -200,12 +200,14 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
                                           param_gbuf_map,
                                           opt_group_ranges):
-        # Three groups of parameters:
-        #   float16_groups: original float16 parameters
-        #   fp32_from_float16_groups: fp32 copy of float16 parameters
-        #   fp32_groups: original fp32 parameters
-        full_float16_groups = []
-        full_fp32_groups = []
+        # Parameter groups:
+        #   model_float16_groups: original float16 parameters
+        #   model_fp32_groups: original fp32 parameters
+        #   shard_float16_groups: shards of original float16 parameters
+        #   shard_fp32_groups: shards of original fp32 parameters
+        #   shard_fp32_from_float16_groups: fp32 copy of float16 parameters
+        model_float16_groups = []
+        model_fp32_groups = []
         shard_float16_groups = []
         shard_fp32_groups = []
         shard_fp32_from_float16_groups = []
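To make the naming concrete: each original ("model") parameter has a flat slice that this data-parallel rank owns (its "shard"), and for float16 params an extra fp32 copy of that shard is kept for the optimizer to update. A minimal standalone sketch of that relationship, with a made-up size and range (not taken from the patch):

    import torch

    # Hypothetical parameter and owned range; in the real code the range comes
    # from the grad buffer's per-rank partitioning.
    model_param = torch.zeros(4, 4, dtype=torch.float16)
    start, end = 4, 12

    # "Shard" of the model param: a flat, no-copy view of the slice this
    # data-parallel rank owns.
    shard_model_param = model_param.view(-1)[start:end]

    # fp32 "main" copy of that shard, which the optimizer actually updates.
    shard_main_param = shard_model_param.clone().float()

    assert shard_main_param.nelement() == end - start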
@@ -214,13 +216,13 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
         for group_index, group_range in enumerate(opt_group_ranges):
             # Params of this group.
-            full_float16_params_this_group = []
-            full_fp32_params_this_group = []
+            model_float16_params_this_group = []
+            model_fp32_params_this_group = []
             shard_float16_params_this_group = []
             shard_fp32_params_this_group = []
             shard_fp32_from_float16_params_this_group = []
-            full_float16_groups.append(full_float16_params_this_group)
-            full_fp32_groups.append(full_fp32_params_this_group)
+            model_float16_groups.append(model_float16_params_this_group)
+            model_fp32_groups.append(model_fp32_params_this_group)
             shard_float16_groups.append(shard_float16_params_this_group)
             shard_fp32_groups.append(shard_fp32_params_this_group)
             shard_fp32_from_float16_groups.append(
@@ -251,7 +253,7 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
                     shard_main_param.shared = model_param.shared
                     # Add to group.
-                    full_float16_params_this_group.append(model_param)
+                    model_float16_params_this_group.append(model_param)
                     shard_float16_params_this_group.append(shard_model_param)
                     shard_fp32_from_float16_params_this_group.append(shard_main_param)
@@ -259,7 +261,7 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
                 elif model_param.type() == 'torch.cuda.FloatTensor':
                     shard_model_param = model_param.view(-1) \
                         [param_range.start:param_range.end]
-                    full_fp32_params_this_group.append(model_param)
+                    model_fp32_params_this_group.append(model_param)
                     shard_fp32_params_this_group.append(shard_model_param)
                     mpu.copy_tensor_model_parallel_attributes(
                         shard_model_param, model_param)
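Note the asymmetry with the float16 path above: for fp32 params the shard is only a view into the original tensor, so the same storage serves as both the model param and the optimizer's main param, and no separate fp32 copy is allocated. A small sketch with a hypothetical parameter and range (not from the patch):

    import torch

    # Hypothetical fp32 parameter: its shard aliases the original storage.
    model_param = torch.zeros(8, dtype=torch.float32)
    shard_model_param = model_param.view(-1)[2:6]

    shard_model_param.fill_(1.0)           # updating the shard...
    assert model_param[2:6].eq(1.0).all()  # ...updates the original parameter too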
@@ -280,8 +282,8 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
         ]
         return (
-            full_float16_groups,
-            full_fp32_groups,
+            model_float16_groups,
+            model_fp32_groups,
             shard_float16_groups,
             shard_fp32_groups,
             shard_fp32_from_float16_groups,
@@ -315,8 +317,8 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
         # Allocate main param shards.
         (
-            self.full_float16_groups,
-            self.full_fp32_groups,
+            self.model_float16_groups,
+            self.model_fp32_groups,
             self.shard_float16_groups,
             self.shard_fp32_groups,
             self.shard_fp32_from_float16_groups,
@@ -333,6 +335,10 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
     def get_model_param_range_map(self, param):
+        '''
+        Given a model param, get the index sub-range of the param that this
+        data-parallel rank owns.
+        '''
         model_index, dtype = self.model_param_gbuf_map[param]
         gbuf_range_map = self.model_gbuf_ranges[model_index][dtype]
         param_range_map = gbuf_range_map["param_map"][param]
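The rest of the diff only relies on the "param" entry of this map and its start/end/size fields. A self-contained stand-in illustrating just that shape (the real map carries more keys, omitted here):

    from types import SimpleNamespace

    # Stand-in for the range object consumed below: the slice this
    # data-parallel rank owns within the flattened parameter.
    param_range_map = {"param": SimpleNamespace(start=4, end=12, size=8)}

    param_range = param_range_map["param"]
    assert param_range.size == param_range.end - param_range.start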
@@ -390,8 +396,8 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
         fragmentation; in the case of set_to_none==True, the space
         used by this field can be safely deallocated at this point."""
         for groups in (
-                self.full_float16_groups,
-                self.full_fp32_groups,
+                self.model_float16_groups,
+                self.model_fp32_groups,
                 self.shard_float16_groups,            # grad empty/unused here?
                 self.shard_fp32_groups,               # throws grad-access warning
                 self.shard_fp32_from_float16_groups):
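For context on the docstring above, a rough sketch of the two zeroing modes it distinguishes; this helper is illustrative only and not part of the patch:

    # Illustrative helper (not from the patch): zeroing a group's grads.
    def zero_grad_group(group, set_to_none=True):
        for param in group:
            if param.grad is None:
                continue
            if set_to_none:
                param.grad = None      # drop the reference so the memory can be freed
            else:
                param.grad.detach_()
                param.grad.zero_()     # keep the allocation, just clear the values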
@@ -502,46 +508,46 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
     def _copy_model_grads_to_main_grads(self):

-        def copy_group_grads(full_model_groups, shard_main_groups):
-            for full_model_group, shard_main_group in zip(full_model_groups,
-                                                          shard_main_groups):
-                for full_model_param, shard_main_param in zip(full_model_group,
-                                                              shard_main_group):
-                    param_range_map = self.get_model_param_range_map(full_model_param)
+        def copy_group_grads(model_groups, shard_main_groups):
+            for model_group, shard_main_group in zip(model_groups,
+                                                     shard_main_groups):
+                for model_param, shard_main_param in zip(model_group,
+                                                         shard_main_group):
+                    param_range_map = self.get_model_param_range_map(model_param)
                     param_range = param_range_map["param"]
                     assert param_range.size == shard_main_param.nelement()
-                    full_model_grad = full_model_param.main_grad
-                    shard_model_grad = full_model_grad.view(-1) \
+                    model_grad = model_param.main_grad
+                    shard_model_grad = model_grad.view(-1) \
                         [param_range.start:param_range.end]
                     shard_main_param.grad = shard_model_grad.float()

-        copy_group_grads(self.full_float16_groups,
+        copy_group_grads(self.model_float16_groups,
                          self.shard_fp32_from_float16_groups)
-        copy_group_grads(self.full_fp32_groups,
+        copy_group_grads(self.model_fp32_groups,
                          self.shard_fp32_groups)
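A standalone restatement of the copy performed above, using a made-up parameter, .main_grad buffer, and owned range: the full model grad is viewed flat, sliced to this rank's range, cast to fp32, and attached to the shard main param as its .grad:

    import torch

    # Hypothetical model param with a fused .main_grad buffer.
    model_param = torch.zeros(16, dtype=torch.float16)
    model_param.main_grad = torch.randn(16).half()
    start, end = 4, 12

    shard_main_param = model_param.view(-1)[start:end].clone().float()
    shard_model_grad = model_param.main_grad.view(-1)[start:end]
    shard_main_param.grad = shard_model_grad.float()

    assert shard_main_param.grad.dtype == torch.float32
    assert shard_main_param.grad.nelement() == end - start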
     def _copy_main_params_to_model_params(self):

-        def copy_group_params(shard_main_groups, full_model_groups):
-            for shard_main_group, full_model_group in zip(shard_main_groups,
-                                                          full_model_groups):
-                for shard_main_param, full_model_param in zip(shard_main_group,
-                                                              full_model_group):
-                    param_range_map = self.get_model_param_range_map(full_model_param)
+        def copy_group_params(shard_main_groups, model_groups):
+            for shard_main_group, model_group in zip(shard_main_groups,
+                                                     model_groups):
+                for shard_main_param, model_param in zip(shard_main_group,
+                                                         model_group):
+                    param_range_map = self.get_model_param_range_map(model_param)
                     param_range = param_range_map["param"]
                     assert param_range.size == shard_main_param.nelement()
-                    full_model_grad = full_model_param.main_grad
-                    shard_model_grad = full_model_grad.view(-1) \
+                    model_grad = model_param.main_grad
+                    shard_model_grad = model_grad.view(-1) \
                         [param_range.start:param_range.end]
                     shard_model_grad.data.copy_(shard_main_param)

         copy_group_params(self.shard_fp32_from_float16_groups,
-                          self.full_float16_groups)
+                          self.model_float16_groups)
         copy_group_params(self.shard_fp32_groups,
-                          self.full_fp32_groups)
+                          self.model_fp32_groups)
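The copy-back goes through the same slice: the updated fp32 shard values are written into this rank's portion of the param's .main_grad buffer. A standalone sketch with made-up tensors (the surrounding gather/update machinery is out of scope here):

    import torch

    model_param = torch.zeros(16, dtype=torch.float16)
    model_param.main_grad = torch.zeros(16, dtype=torch.float16)
    start, end = 4, 12

    # Pretend this is the fp32 "main" shard after the optimizer step.
    shard_main_param = torch.randn(end - start)
    shard_model_grad = model_param.main_grad.view(-1)[start:end]
    shard_model_grad.data.copy_(shard_main_param)   # implicit fp32 -> fp16 cast

    assert torch.equal(model_param.main_grad[start:end], shard_main_param.half())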