Commit 7ed649ed authored by Lawrence McAfee

renamed: full_ -> model_.

parent 82e6730c
@@ -200,12 +200,14 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
             param_gbuf_map,
             opt_group_ranges):
-        # Three groups of parameters:
-        #   float16_groups: original float16 parameters
-        #   fp32_from_float16_groups: fp32 copy of float16 parameters
-        #   fp32_groups: original fp32 parameters
-        full_float16_groups = []
-        full_fp32_groups = []
+        # Parameter groups:
+        #   model_float16_groups: original float16 parameters
+        #   model_fp32_groups: original fp32 parameters
+        #   shard_float16_groups: shards of original float16 parameters
+        #   shard_fp32_groups: shards of original fp32 parameters
+        #   shard_fp32_from_float16_groups: fp32 copy of float16 parameters
+        model_float16_groups = []
+        model_fp32_groups = []
         shard_float16_groups = []
         shard_fp32_groups = []
         shard_fp32_from_float16_groups = []
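
Note on the renamed groups: the model_* lists hold the original, full parameters as they live in the model, while the shard_* lists hold only the slice of each parameter that this data-parallel rank owns; shard_fp32_from_float16_groups additionally holds an fp32 working copy of each float16 shard for the optimizer to update. A minimal, hypothetical sketch of that relationship for a single bf16 parameter (the tensor size, the [start, end) range, and the variable names are made up for illustration):

    import torch

    # One bf16 model parameter of 10 elements; this rank owns indices [2, 7).
    model_param = torch.nn.Parameter(torch.zeros(2, 5, dtype=torch.bfloat16))
    start, end = 2, 7

    shard_model_param = model_param.detach().view(-1)[start:end]  # view, no copy
    shard_main_param = shard_model_param.clone().float()          # fp32 working copy

    model_float16_group = [model_param]                 # full parameter
    shard_float16_group = [shard_model_param]           # owned slice (bf16 view)
    shard_fp32_from_float16_group = [shard_main_param]  # fp32 copy the optimizer steps on

    assert shard_main_param.nelement() == end - start
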
@@ -214,13 +216,13 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
         for group_index, group_range in enumerate(opt_group_ranges):
 
             # Params of this group.
-            full_float16_params_this_group = []
-            full_fp32_params_this_group = []
+            model_float16_params_this_group = []
+            model_fp32_params_this_group = []
             shard_float16_params_this_group = []
             shard_fp32_params_this_group = []
             shard_fp32_from_float16_params_this_group = []
-            full_float16_groups.append(full_float16_params_this_group)
-            full_fp32_groups.append(full_fp32_params_this_group)
+            model_float16_groups.append(model_float16_params_this_group)
+            model_fp32_groups.append(model_fp32_params_this_group)
             shard_float16_groups.append(shard_float16_params_this_group)
             shard_fp32_groups.append(shard_fp32_params_this_group)
             shard_fp32_from_float16_groups.append(
@@ -251,7 +253,7 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
                         shard_main_param.shared = model_param.shared
 
                     # Add to group.
-                    full_float16_params_this_group.append(model_param)
+                    model_float16_params_this_group.append(model_param)
                     shard_float16_params_this_group.append(shard_model_param)
                     shard_fp32_from_float16_params_this_group.append(shard_main_param)
@@ -259,7 +261,7 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
                 elif model_param.type() == 'torch.cuda.FloatTensor':
                     shard_model_param = model_param.view(-1) \
                         [param_range.start:param_range.end]
-                    full_fp32_params_this_group.append(model_param)
+                    model_fp32_params_this_group.append(model_param)
                     shard_fp32_params_this_group.append(shard_model_param)
                     mpu.copy_tensor_model_parallel_attributes(
                         shard_model_param, model_param)
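
For the fp32 branch just above, the shard is produced by model_param.view(-1)[start:end], so it shares storage with the full parameter and needs no extra memory; the copy routines further down pass shard_fp32_groups directly as the shard main params, whereas float16 params get the separate fp32 copy shown earlier. A small sketch of the aliasing, with a hypothetical tensor and range:

    import torch

    model_param = torch.arange(12, dtype=torch.float32).reshape(3, 4)
    start, end = 4, 9  # sub-range owned by this data-parallel rank

    shard_model_param = model_param.view(-1)[start:end]  # a view, not a copy
    shard_model_param.zero_()

    # Writes through the shard are visible in the full parameter.
    assert model_param.view(-1)[start:end].sum() == 0
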
@@ -280,8 +282,8 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
             ]
 
         return (
-            full_float16_groups,
-            full_fp32_groups,
+            model_float16_groups,
+            model_fp32_groups,
             shard_float16_groups,
             shard_fp32_groups,
             shard_fp32_from_float16_groups,
@@ -315,8 +317,8 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
         # Allocate main param shards.
         (
-            self.full_float16_groups,
-            self.full_fp32_groups,
+            self.model_float16_groups,
+            self.model_fp32_groups,
             self.shard_float16_groups,
             self.shard_fp32_groups,
             self.shard_fp32_from_float16_groups,
@@ -333,6 +335,10 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
     def get_model_param_range_map(self, param):
+        '''
+        Given a model param, get the index sub-range of the param that this
+        data-parallel rank owns.
+        '''
         model_index, dtype = self.model_param_gbuf_map[param]
         gbuf_range_map = self.model_gbuf_ranges[model_index][dtype]
         param_range_map = gbuf_range_map["param_map"][param]
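
The lookups above imply a nested mapping: model_param_gbuf_map takes a parameter to a (model_index, dtype) pair, model_gbuf_ranges[model_index][dtype] yields a grad-buffer range map, and its "param_map" entry maps each parameter to a per-param range map whose "param" entry carries start/end/size. A self-contained sketch of that assumed layout (the Range class, the dtype key, and the concrete numbers are illustrative stand-ins, not the exact types used in this file):

    from dataclasses import dataclass

    @dataclass
    class Range:
        # Assumed shape of the range objects stored in the param map.
        start: int
        end: int

        @property
        def size(self):
            return self.end - self.start

    param = object()  # stands in for a model parameter
    model_param_gbuf_map = {param: (0, "float16")}
    model_gbuf_ranges = [
        {"float16": {"param_map": {param: {"param": Range(128, 192)}}}},
    ]

    # Same two-step lookup as get_model_param_range_map above.
    model_index, dtype = model_param_gbuf_map[param]
    gbuf_range_map = model_gbuf_ranges[model_index][dtype]
    param_range_map = gbuf_range_map["param_map"][param]
    assert param_range_map["param"].size == 64
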
@@ -390,8 +396,8 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
         fragmentation; in the case of set_to_none==True, the space
         used by this field can be safely deallocated at this point."""
         for groups in (
-                self.full_float16_groups,
-                self.full_fp32_groups,
+                self.model_float16_groups,
+                self.model_fp32_groups,
                 self.shard_float16_groups, # grad empty/unused here?
                 self.shard_fp32_groups, # throws grad-access warning
                 self.shard_fp32_from_float16_groups):
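
The surrounding docstring motivates set_to_none: dropping .grad outright lets the allocation be reclaimed immediately, at the cost of reallocating it on the next backward pass. What zeroing one of these groups typically looks like, as an illustrative helper (a generic PyTorch-style sketch, not the exact helper this file calls):

    def zero_grad_group(group, set_to_none=True):
        # Clear the gradients of every param in one group.
        for param in group:
            if param.grad is None:
                continue
            if set_to_none:
                param.grad = None  # memory can be freed right away
            else:
                param.grad.detach_()
                param.grad.requires_grad_(False)
                param.grad.zero_()  # keep the allocation, overwrite with zeros
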
@@ -502,46 +508,46 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
     def _copy_model_grads_to_main_grads(self):
 
-        def copy_group_grads(full_model_groups, shard_main_groups):
-            for full_model_group, shard_main_group in zip(full_model_groups,
-                                                          shard_main_groups):
-                for full_model_param, shard_main_param in zip(full_model_group,
-                                                              shard_main_group):
+        def copy_group_grads(model_groups, shard_main_groups):
+            for model_group, shard_main_group in zip(model_groups,
+                                                     shard_main_groups):
+                for model_param, shard_main_param in zip(model_group,
+                                                         shard_main_group):
 
-                    param_range_map = self.get_model_param_range_map(full_model_param)
+                    param_range_map = self.get_model_param_range_map(model_param)
                     param_range = param_range_map["param"]
                     assert param_range.size == shard_main_param.nelement()
 
-                    full_model_grad = full_model_param.main_grad
-                    shard_model_grad = full_model_grad.view(-1) \
+                    model_grad = model_param.main_grad
+                    shard_model_grad = model_grad.view(-1) \
                         [param_range.start:param_range.end]
                     shard_main_param.grad = shard_model_grad.float()
 
-        copy_group_grads(self.full_float16_groups,
+        copy_group_grads(self.model_float16_groups,
                          self.shard_fp32_from_float16_groups)
-        copy_group_grads(self.full_fp32_groups,
+        copy_group_grads(self.model_fp32_groups,
                          self.shard_fp32_groups)
 
     def _copy_main_params_to_model_params(self):
 
-        def copy_group_params(shard_main_groups, full_model_groups):
-            for shard_main_group, full_model_group in zip(shard_main_groups,
-                                                          full_model_groups):
-                for shard_main_param, full_model_param in zip(shard_main_group,
-                                                              full_model_group):
+        def copy_group_params(shard_main_groups, model_groups):
+            for shard_main_group, model_group in zip(shard_main_groups,
+                                                     model_groups):
+                for shard_main_param, model_param in zip(shard_main_group,
+                                                         model_group):
 
-                    param_range_map = self.get_model_param_range_map(full_model_param)
+                    param_range_map = self.get_model_param_range_map(model_param)
                     param_range = param_range_map["param"]
                     assert param_range.size == shard_main_param.nelement()
 
-                    full_model_grad = full_model_param.main_grad
-                    shard_model_grad = full_model_grad.view(-1) \
+                    model_grad = model_param.main_grad
+                    shard_model_grad = model_grad.view(-1) \
                         [param_range.start:param_range.end]
                     shard_model_grad.data.copy_(shard_main_param)
 
         copy_group_params(self.shard_fp32_from_float16_groups,
-                          self.full_float16_groups)
+                          self.model_float16_groups)
         copy_group_params(self.shard_fp32_groups,
-                          self.full_fp32_groups)
+                          self.model_fp32_groups)
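
Taken together, the two methods above move data in opposite directions around the optimizer step: gradients are sliced out of each parameter's main_grad buffer by the owned [start, end) range and cast to fp32 onto the shard main params, and after the step the updated shard values are copied back into the same main_grad slice, which suggests the contiguous grad buffer is being reused as the staging area from which updated parameters are later gathered across data-parallel ranks. A standalone, per-parameter sketch with hypothetical tensors and range:

    import torch

    main_grad = torch.randn(16, dtype=torch.float32)  # stands in for model_param.main_grad
    start, end = 4, 12                                # sub-range owned by this rank
    shard_main_param = torch.zeros(end - start, dtype=torch.float32, requires_grad=True)

    # _copy_model_grads_to_main_grads direction: slice and cast to fp32.
    shard_main_param.grad = main_grad.view(-1)[start:end].float()

    # ... optimizer.step() would update shard_main_param here ...

    # _copy_main_params_to_model_params direction: write the updated shard
    # values back into the same slice of the buffer.
    main_grad.view(-1)[start:end].data.copy_(shard_main_param)
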