Unverified Commit 88c319aa authored by Olatunji Ruwase, committed by GitHub

Handle parameter groups smaller than DP (#273)

* Load non-DeepSpeed checkpoints into ZeRO optimizer

* Handle parameters smaller than DP

* Formatting fixes
parent 664fa30c
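
Editorial note: the core of this change is the rewritten flatten_dense_tensors_sub_partition_aligned further below. The per-rank sub-partition size is now derived from math.ceil(num_elements / dp), so a parameter group with fewer elements than data-parallel ranks simply gets zero padding instead of falling through the old modulo-based alignment. A minimal standalone sketch of that arithmetic follows; the function name and return value are illustrative only, not part of the DeepSpeed API.

import math

def compute_sub_partition_layout(num_elements, dp, max_elements_per_comm):
    # Illustrative re-derivation of the alignment math introduced in this commit.
    aligned_param_partition_size = math.ceil(num_elements / dp)
    aligned_comm_partition_size = max_elements_per_comm // dp
    if aligned_param_partition_size <= aligned_comm_partition_size:
        sub_partition_count = 1
        sub_partition_size = aligned_param_partition_size
    else:
        sub_partition_count = math.ceil(aligned_param_partition_size /
                                        aligned_comm_partition_size)
        sub_partition_size = aligned_comm_partition_size
    # Pad so every rank owns sub_partition_count equally sized sub-partitions.
    padding = sub_partition_count * sub_partition_size * dp - num_elements
    return sub_partition_count, sub_partition_size, padding

# 10 elements, DP=4, at most 6 elements per communication interval:
# ceil(10 / 4) = 3 > 6 // 4 = 1, so each rank owns 3 sub-partitions of size 1
# and 3 * 1 * 4 - 10 = 2 elements of zero padding are appended.
print(compute_sub_partition_layout(10, 4, 6))  # (3, 1, 2)
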
@@ -561,7 +561,6 @@ class DeepSpeedLight(Module):
         if zero_stage == ZERO_OPTIMIZATION_OPTIMIZER_STATES:
             assert self.zero_reduce_scatter(), 'Stage 1 only supports reduce scatter mode'
-            logger.info('Creating fp16 ZeRO Optimizer Stage 1')
             optimizer = FP16_DeepSpeedZeroOptimizer_Stage1(
                 optimizer,
                 static_loss_scale=self.loss_scale(),
@@ -593,7 +592,6 @@ class DeepSpeedLight(Module):
                 gradient_predivide_factor=self.gradient_predivide_factor())
         else:
             raise NotImplementedError("ZeRO stage {} not implemented".format(zero_stage))
-        logger.info('Creating fp16 zero stage {} optimizer'.format(zero_stage))
         return optimizer
@@ -1355,6 +1355,12 @@ class FP16_DeepSpeedZeroOptimizer(object):
         return state_dict

+    # Refresh the fp32 master params from the fp16 copies.
+    def refresh_fp32_params(self):
+        partition_id = dist.get_rank(group=self.dp_process_group)
+        for fp16_partitions, fp32_partition in zip(self.parallel_partitioned_fp16_groups, self.single_partition_of_fp32_groups):
+            fp32_partition.data.copy_(fp16_partitions[partition_id].data)
+
     def load_state_dict(self, state_dict, load_optimizer_states=True):
         """
         Loads a state_dict created by an earlier call to state_dict().
@@ -353,6 +353,7 @@ class FP16_Optimizer(object):
         state_dict['clip_grad'] = self.clip_grad
         return state_dict

+    # Refresh fp32 master params from fp16 copies
     def refresh_fp32_params(self):
         for current, saved in zip(self.fp32_groups_flat, self.fp16_groups_flat):
             current.data.copy_(saved.data)
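
Editorial note: the refresh_fp32_params hooks added in the hunks above and below exist so that fp16 weights loaded from a plain (non-DeepSpeed) checkpoint can be propagated into the optimizer's fp32 master copies. A hypothetical usage sketch, assuming a model engine returned by deepspeed.initialize() that exposes .module and .optimizer:

import torch

def load_plain_fp16_checkpoint(engine, ckpt_path):
    """Hypothetical helper: restore raw fp16 module weights saved outside
    DeepSpeed, then re-sync the fp32 master params from the fp16 copies."""
    state = torch.load(ckpt_path, map_location="cpu")
    engine.module.load_state_dict(state)      # fp16 model weights only
    engine.optimizer.refresh_fp32_params()    # method added in this commit
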
@@ -14,87 +14,48 @@ def flatten_dense_tensors_sub_partition_aligned(tensor_list,
                                                  dp,
                                                  max_elements_per_comm,
                                                  pg):
-    num_elements = 0
-    for tensor in tensor_list:
-        num_elements = num_elements + tensor.numel()
+    assert (max_elements_per_comm >= dp,
+            f"max_elements_per_comm {max_elements_per_comm} < dp {dp}")
+
+    num_elements = sum(t.numel() for t in tensor_list)

     log_dist("Total number of elements in model: {}, max elements per com: {}".format(
         num_elements,
         max_elements_per_comm),
              ranks=[0])

-    max_elements_per_comm = min(max_elements_per_comm, num_elements)
-    sub_partition_size = int(max_elements_per_comm // dp)
-    alignment = sub_partition_size
+    # Compute aligned partition size based on parameter count
+    aligned_param_partition_size = math.ceil(num_elements / dp)

-    # if alignment == 0:
-    #     # number of elements not divisible by dp, outside range and small model must pad with zeroes
-    #     pad_tensor = torch.zeros(max_elements_per_comm,
-    #                              device=tensor_list[0].device,
-    #                              dtype=tensor_list[0].dtype)
-    #     return _flatten_dense_tensors(pad_tensor)
+    # Compute aligned partition size based on communication size
+    aligned_comm_partition_size = int(max_elements_per_comm // dp)

-    remaining = int(num_elements % alignment)
+    if aligned_param_partition_size <= aligned_comm_partition_size:
+        sub_partition_count = 1
+        sub_partition_size = aligned_param_partition_size
+    else:
+        sub_partition_count = math.ceil(aligned_param_partition_size /
+                                        aligned_comm_partition_size)
+        sub_partition_size = aligned_comm_partition_size

-    # ensure we have equal sized sub-partitions
-    elements_to_add = 0
-    if remaining:
-        elements_to_add = alignment - remaining
-        # adding padded tensor later after we check comm alignment
-        log_dist("adding pad tensor for alignment, {} + {}->{}".format(
-            num_elements,
-            elements_to_add,
-            num_elements + elements_to_add),
-                 ranks=[0])
-        #num_elements = num_elements + elements_to_add
-    else:
-        padded_tensor_list = tensor_list
+    # Compute required padding for alignment to dp and max_elements_per_comm
+    padding = (sub_partition_count * sub_partition_size * dp) - num_elements

-    num_partitions = int((num_elements + elements_to_add) // sub_partition_size)
-    assert (num_elements + elements_to_add) % sub_partition_size == 0, "num elements should be " \
-                                                                       "aligned by sub partition " \
-                                                                       "size"
-    num_comm_intervals = int(num_partitions // dp)
-    partition_remaining = int(num_partitions % dp)
-    log_dist("num_comm_intervals={}, partition_remaining={}".format(
-        num_comm_intervals,
-        partition_remaining),
+    log_dist(
+        f"sub_partition_count: {sub_partition_count}, sub_partition_size: {sub_partition_size}, padding: {padding}",
         ranks=[0])
-    if partition_remaining != 0:
-        log_dist("adding pad tensor and/or extra sub partition", ranks=[0])
-        # add pad tensor for alignment of comm interval, this overrules previous possibly sub-partition alignment
-        num_comm_intervals += 1
-        aligned_comm_elements = num_comm_intervals * sub_partition_size * dp
-        elements_to_add = aligned_comm_elements - num_elements
-        pad_tensor = torch.zeros(elements_to_add,
-                                 device=tensor_list[0].device,
-                                 dtype=tensor_list[0].dtype)
-        padded_tensor_list = tensor_list + [pad_tensor]
-        log_dist("adding pad tensor and/or extra sub partition, {} + {}->{}".format(
-            num_elements,
-            elements_to_add,
-            num_elements + elements_to_add),
+    log_dist(
+        f"number of elements with padding: {num_elements} + {padding} = {num_elements + padding}",
         ranks=[0])
-        num_elements += elements_to_add
-    elif elements_to_add > 0:
-        # add pad tensor for just alignment of sub-partition
-        pad_tensor = torch.zeros(elements_to_add,
+    if padding == 0:
+        aligned_tensor_list = tensor_list
+    else:
+        pad_tensor = torch.zeros(padding,
                                  device=tensor_list[0].device,
                                  dtype=tensor_list[0].dtype)
-        padded_tensor_list = tensor_list + [pad_tensor]
-        num_elements += elements_to_add
-    if pg is None or dist.get_rank(group=pg) == 0:
-        logger.info("Number of Elements (w. padding) is %s", num_elements)
+        aligned_tensor_list = tensor_list + [pad_tensor]

-    padded_num_elems = 0
-    for p in padded_tensor_list:
-        padded_num_elems += p.numel()
-    assert num_elements == padded_num_elems, "{} != {}, rank={}".format(num_elements, padded_num_elems, dist.get_rank())
-    return _flatten_dense_tensors(padded_tensor_list)
+    return _flatten_dense_tensors(aligned_tensor_list)


 def _single_range_check(current_index, start_index, end_index, tensor_size):
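
Editorial note: to see the padded flat buffer concretely, here is a small torch-only illustration with assumed values; it mirrors, but does not call, the DeepSpeed helper above. Two single-element fp16 parameters on DP=3 need one zero-pad element so the flat buffer splits into three equal sub-partitions, the last of which is entirely padding.

import torch
from torch._utils import _flatten_dense_tensors

params = [torch.ones(1, dtype=torch.half), torch.full((1,), 2.0, dtype=torch.half)]
pad = torch.zeros(1, dtype=torch.half)            # padding = 1 * 1 * 3 - 2 = 1
flat = _flatten_dense_tensors(params + [pad])     # tensor([1., 2., 0.], dtype=torch.float16)
sub_partitions = list(flat.chunk(3))              # one per rank; the last is all padding
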
@@ -780,6 +741,14 @@ class FP16_DeepSpeedZeroOptimizer_Stage1(object):
             'local_sub_partitions_of_fp32_groups'] = self.local_sub_partitions_of_fp32_groups
         return state_dict

+    # Refresh the fp32 master params from the fp16 copies.
+    def refresh_fp32_params(self):
+        partition_id = dist.get_rank(group=self.dp_process_group)
+        for fp16_all_sub_partitions, fp32_local_sub_partitions in zip(self.parallel_sub_partitioned_fp16_groups, self.local_sub_partitions_of_fp32_groups):
+            for local_sub_partition_param_fp16, local_sub_partition_param_fp32 in zip(fp16_all_sub_partitions[partition_id], fp32_local_sub_partitions):
+                local_sub_partition_param_fp32.data.copy_(
+                    local_sub_partition_param_fp16.data)
+
     def load_state_dict(self, state_dict, load_optimizer_states=True):
         """
         Loads a state_dict created by an earlier call to state_dict().
@@ -353,34 +353,45 @@ def test_zero_allow_untested_optimizer(tmpdir, zero_stage):
     _test_zero_allow_untested_optimizer(args)

-# @pytest.mark.parametrize("zero_stage", [1])
-# def test_zero_empty_partition(tmpdir, zero_stage):
-#     config_dict = {
-#         "train_batch_size": 3,
-#         "fp16": {
-#             "enabled": True
-#         },
-#         "optimizer": {
-#             "type": "Adam",
-#             "params": {
-#                 "lr": 0.00015
-#             }
-#         },
-#         "zero_optimization": {
-#             "stage": zero_stage
-#         }
-#     }
-#     args = args_from_dict(tmpdir, config_dict)
-
-#     @distributed_test(world_size=[3])
-#     def _test_zero_empty_partition(args):
-#         hidden_dim = 1
-#         model = SimpleModel(hidden_dim)
-#         # Ensure model has 2 parameters, to cause empty partition with DP=3
-#         assert len(list(model.parameters())) == 2
-#         model, _, _, _ = deepspeed.initialize(args=args,
-#                                               model=model,
-#                                               model_parameters=model.parameters())
-#         model.step()
-
-#     _test_zero_empty_partition(args)
+@pytest.mark.parametrize("zero_stage", [1])
+def test_zero_empty_partition(tmpdir, zero_stage):
+    config_dict = {
+        "train_micro_batch_size_per_gpu": 1,
+        "gradient_accumulation_steps": 1,
+        "fp16": {
+            "enabled": True,
+            "initial_scale_power": 8
+        },
+        "optimizer": {
+            "type": "Adam",
+            "params": {
+                "lr": 0.00015
+            }
+        },
+        "zero_optimization": {
+            "stage": zero_stage
+        }
+    }
+    args = args_from_dict(tmpdir, config_dict)
+
+    @distributed_test(world_size=[3])
+    def _test_zero_empty_partition(args):
+        hidden_dim = 1
+        model = SimpleModel(hidden_dim)
+        # Ensure model has 2 parameters, to cause empty partition with DP=3
+        assert len(list(model.parameters())) == 2
+        model, _, _, _ = deepspeed.initialize(args=args,
+                                              model=model,
+                                              model_parameters=model.parameters())
+
+        # Now make sure things work..
+        data_loader = random_dataloader(model=model,
+                                        total_samples=1,
+                                        hidden_dim=hidden_dim,
+                                        device=model.device)
+        for n, batch in enumerate(data_loader):
+            loss = model(batch[0], batch[1])
+            model.backward(loss)
+            model.step()
+
+    _test_zero_empty_partition(args)
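
Editorial note: for reference, assuming SimpleModel(hidden_dim=1) holds two single-element parameters (a 1x1 weight and a bias), the three-rank run gives ceil(2 / 3) = 1 element per sub-partition and 1 * 1 * 3 - 2 = 1 element of zero padding, so the third rank's partition consists entirely of padding. That empty partition is exactly the case this re-enabled test exercises end to end.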