Unverified Commit 9faad392 authored by Joshua Meier, committed by GitHub

[feat] Support model parallelism in OSS (#287)

* add an additional unit test
* support model parallelism in OSS
parent 53a912c3
@@ -216,7 +216,12 @@ class OSS(Optimizer):
        return loss

-    def clip_grad_norm(self, max_norm: Union[float, int], norm_type: Union[float, int] = 2.0) -> torch.Tensor:
+    def clip_grad_norm(
+        self,
+        max_norm: Union[float, int],
+        norm_type: Union[float, int] = 2.0,
+        filter_params_fn: Callable[[Any], Any] = None,
+    ) -> torch.Tensor:
        """
        Clip all gradients at this point in time. The norm is computed over all gradients together, as if they were
        concatenated into a single vector. Gradients are modified in-place.
@@ -237,9 +242,6 @@ class OSS(Optimizer):
        .. warning: Model parallelism -groups other than world- are not yet supported
        """
-        if self.group != dist.group.WORLD:
-            raise NotImplementedError("Clip norm not yet supported for model parallelism (coming soon!)")

        # Compute the max norm for this shard's worth of gradients
        max_norm = float(max_norm)
        norm_type = float(norm_type)
@@ -252,11 +254,19 @@ class OSS(Optimizer):
            ]
        )

+        # Option to filter parameters from the grad_norm calculation. This is useful for model parallelism.
+        # To avoid double counting, only consider parameters on rank zero + anything marked 'model_parallel'
+        # 'model_parallel' flag is set in Megatron-LM:
+        # https://github.com/NVIDIA/Megatron-LM/blob/19301985dd31c8b612095cbad15bd903e8ddd497/megatron/mpu/layers.py#L54
+        if filter_params_fn is not None:
+            local_params = filter_params_fn(local_params)

        # Compute the norm on this grad set,
        # then sync all the norms from all ranks
        if norm_type == inf:
            total_norm = max(p.grad.detach().abs().max().to(self._device) for p in local_params)  # type: ignore
-            dist.all_reduce(total_norm, op=torch.distributed.ReduceOp.MAX, group=self.group)
+            # all reduce over data parallel and model parallel workers
+            dist.all_reduce(total_norm, op=torch.distributed.ReduceOp.MAX, group=dist.group.WORLD)
        else:
            local_norm = torch.norm(
                input=torch.stack([torch.norm(input=p.grad.detach(), p=norm_type, dtype=torch.float32).to(self._device) for p in local_params]),  # type: ignore
@@ -266,12 +276,12 @@ class OSS(Optimizer):
            # local norm result can be accumulated with the remote ones if put to the right power
            # n_i = sum_rank(a^p)^1/p
            # -> n_total = all_reduce(n_i^p)^(1/p) = sum_i(n_i^p)^1/p = sum_i(sum_rank(a^p))^1/p
+            # all reduce over data parallel and model parallel workers
            total_norm = local_norm ** norm_type
-            dist.all_reduce(total_norm, group=self.group)
+            dist.all_reduce(total_norm)
            total_norm = total_norm ** (1.0 / norm_type)

        clip_coef = torch.tensor(max_norm, dtype=total_norm.dtype, device=total_norm.device) / (total_norm + 1e-6)
        if clip_coef < 1:
            for device, device_params in self.per_device_params.items():
                for p in filter(lambda x: x.grad is not None, device_params[self.rank]):
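To make the new `filter_params_fn` hook concrete, here is a minimal sketch of the kind of filter the comment in the diff describes for a Megatron-LM style setup: count a parameter only if this worker is model-parallel rank zero or the parameter carries the `model_parallel` flag, so replicated parameters are not counted once per model-parallel worker. The helper name and the `mp_rank` argument are illustrative and not part of this PR.

```python
def make_megatron_grad_norm_filter(mp_rank: int):
    """Build a filter usable as OSS.clip_grad_norm(filter_params_fn=...).

    Hypothetical helper, not part of this PR. Replicated (non-model-parallel)
    parameters exist on every model-parallel worker, so only model-parallel
    rank 0 counts them; sharded parameters carry a `model_parallel` attribute
    (set by Megatron-LM) and are counted on every worker.
    """

    def filter_params_fn(params):
        return [
            p
            for p in params
            if mp_rank == 0 or getattr(p, "model_parallel", False)
        ]

    return filter_params_fn
```

A caller would then do something like `sharded_optimizer.clip_grad_norm(max_norm=1.0, filter_params_fn=make_megatron_grad_norm_filter(mp_rank))`, where `mp_rank` is this worker's model-parallel rank.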
@@ -354,7 +364,6 @@ class OSS(Optimizer):
            rank (int): rank to get local_state_dict for
            state_dict (dict): global state_dict
        """
        # Get this optimizer's param_groups shard
        param_groups = state_dict["param_groups"][state_dict["partition"][rank][0] : state_dict["partition"][rank][1]]
        return {"state": state_dict["state"][rank], "param_groups": param_groups}
@@ -402,7 +411,7 @@ class OSS(Optimizer):
        """
        # Check whether we got a local or global dict
-        if state_dict["local_state_dict"]:
+        if "local_state_dict" in state_dict and state_dict["local_state_dict"]:
            self.load_local_state_dict(state_dict)
        else:
            # Dispatch this rank's state dictionary to the wrapped shard optimizer
@@ -28,4 +28,4 @@ use_parentheses = true
skip_glob = ["build/*", "stubs/*"]
# Don't split "import" and "from".
force_sort_within_sections = true
-known_third_party = ["benchmark_dataset", "dataclasses", "helpers", "numpy", "pytest", "recommonmark", "setuptools", "torch", "torch_pg", "torchtext", "torchvision"]
+known_third_party = ["benchmark_dataset", "dataclasses", "datasets", "helpers", "models", "numpy", "pytest", "recommonmark", "setuptools", "torch", "torch_pg", "torchtext", "torchvision"]
@@ -604,6 +604,21 @@ def run_state_dict_distributed(rank, world_size, tempfile_name):
    optimizer.step()
    optimizer.zero_grad()

+    # save and reload without taking any steps
+    sharded_optimizer2.consolidate_state_dict()
+    state_dict2 = sharded_optimizer2.state_dict()
+    sharded_optimizer2 = optim.OSS(model_oss2.parameters(), lr=0.1, momentum=0.99)
+    sharded_optimizer2.load_state_dict(state_dict2)
+
+    # now take a step and check that parameters are equal
+    run_grad_step(device, model_oss1, sharded_optimizer1)
+    run_grad_step(device, model_oss2, sharded_optimizer2)
+
+    # check that model parameters are equal
+    for param1, param2 in zip(model_oss1.parameters(), model_oss2.parameters()):
+        assert torch.allclose(param1, param2), "parameters of the two identical models have diverged (before any steps)"

    # take a step
    run_grad_step(device, model_oss1, sharded_optimizer1)
    run_grad_step(device, model_oss2, sharded_optimizer2)
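The new test round-trips the sharded optimizer state in memory; the same calls extend naturally to checkpointing to disk. A rough sketch, assuming a shared filesystem and that rank 0 is the consolidation target (the helper names and checkpoint layout are illustrative, not part of this PR):

```python
import torch
import torch.distributed as dist
from fairscale.optim import OSS


def save_oss_checkpoint(model, sharded_optimizer: OSS, path: str) -> None:
    # Pull every shard into a consolidated state_dict (collective call,
    # every rank must participate), then write it out from rank 0 only.
    sharded_optimizer.consolidate_state_dict()
    if dist.get_rank() == 0:
        torch.save({"model": model.state_dict(), "optim": sharded_optimizer.state_dict()}, path)
    dist.barrier()


def load_oss_checkpoint(model, sharded_optimizer: OSS, path: str) -> None:
    checkpoint = torch.load(path, map_location="cpu")
    model.load_state_dict(checkpoint["model"])
    # OSS.load_state_dict accepts the consolidated (global) dict and
    # re-shards it, which is what the new test round-trips in memory.
    sharded_optimizer.load_state_dict(checkpoint["optim"])
```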