Unverified commit 9faad392, authored by Joshua Meier and committed by GitHub

[feat] Support model parallelism in OSS (#287)

* add additional unit test
* support model parallelism in oss
parent 53a912c3
@@ -216,7 +216,12 @@ class OSS(Optimizer):
         return loss

-    def clip_grad_norm(self, max_norm: Union[float, int], norm_type: Union[float, int] = 2.0) -> torch.Tensor:
+    def clip_grad_norm(
+        self,
+        max_norm: Union[float, int],
+        norm_type: Union[float, int] = 2.0,
+        filter_params_fn: Callable[[Any], Any] = None,
+    ) -> torch.Tensor:
         """
         Clip all gradients at this point in time. The norm is computed over all gradients together, as if they were
         concatenated into a single vector. Gradients are modified in-place.
@@ -237,9 +242,6 @@ class OSS(Optimizer):
         .. warning: Model paralelism -groups other than world- are not yet supported
         """
-        if self.group != dist.group.WORLD:
-            raise NotImplementedError("Clip norm not yet supported for model parallelism (coming soon!)")
-
         # Compute the max norm for this shards's worth of gradients
         max_norm = float(max_norm)
         norm_type = float(norm_type)
@@ -252,11 +254,19 @@ class OSS(Optimizer):
             ]
         )

+        # Option to filter parameters from the grad_norm calculation. This is useful for model parallelism.
+        # To avoid double counting, only consider parameters on rank zero + anything marked 'model_parallel'
+        # 'model_parallel' flag is set in Megatron-LM:
+        # https://github.com/NVIDIA/Megatron-LM/blob/19301985dd31c8b612095cbad15bd903e8ddd497/megatron/mpu/layers.py#L54
+        if filter_params_fn is not None:
+            local_params = filter_params_fn(local_params)
+
         # Compute the norm on this grad set,
         # then sync all the norms from all ranks
         if norm_type == inf:
             total_norm = max(p.grad.detach().abs().max().to(self._device) for p in local_params)  # type: ignore
-            dist.all_reduce(total_norm, op=torch.distributed.ReduceOp.MAX, group=self.group)
+            # all reduce over data parallel and model parallel workers
+            dist.all_reduce(total_norm, op=torch.distributed.ReduceOp.MAX, group=dist.group.WORLD)
         else:
             local_norm = torch.norm(
                 input=torch.stack([torch.norm(input=p.grad.detach(), p=norm_type, dtype=torch.float32).to(self._device) for p in local_params]),  # type: ignore
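The comment added in the hunk above points at Megatron-LM's model_parallel attribute but leaves the filter itself to the caller. A rough sketch of such a callback, assuming Megatron-style tensor parallelism; the helper name, the mpu import path and the exact policy are illustrative and not part of this commit:

# Assumption: tensor-parallel parameters carry a `model_parallel` attribute (set by Megatron-LM,
# see the link in the diff) and mpu.get_model_parallel_rank() returns the rank inside the
# model-parallel group. Replicated parameters are counted only on model-parallel rank 0.
from megatron import mpu  # import path varies with the Megatron-LM version

def filter_model_parallel_params(params):
    # keep every tensor-parallel shard (each rank owns a distinct slice),
    # but count replicated parameters only once, on model-parallel rank 0
    if mpu.get_model_parallel_rank() == 0:
        return params
    return [p for p in params if getattr(p, "model_parallel", False)]

# e.g. sharded_optimizer.clip_grad_norm(max_norm=1.0, filter_params_fn=filter_model_parallel_params)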
@@ -266,12 +276,12 @@ class OSS(Optimizer):
             # local norm result can be accumulated with the remote ones if put to the right power
             # n_i = sum_rank(a^p)^1/p
             # -> n_total = all_reduce(n_i^p)^(1/p) = sum_i(n_i^p)^1/p = sum_i(sum_rank(a^p))^1/p
+            # all reduce over data parallel and model parallel workers
             total_norm = local_norm ** norm_type
-            dist.all_reduce(total_norm, group=self.group)
+            dist.all_reduce(total_norm)
             total_norm = total_norm ** (1.0 / norm_type)

         clip_coef = torch.tensor(max_norm, dtype=total_norm.dtype, device=total_norm.device) / (total_norm + 1e-6)
         if clip_coef < 1:
             for device, device_params in self.per_device_params.items():
                 for p in filter(lambda x: x.grad is not None, device_params[self.rank]):
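The three-line comment in this hunk is the identity that makes sharded clipping work: each worker computes the norm of its own shard, raises it to the p-th power, the all_reduce sums those powers across every data-parallel and model-parallel rank, and the p-th root of that sum equals the norm of all gradients concatenated into one vector. A quick single-process sanity check of the identity (plain PyTorch, no process group; the shard tensors are made up):

import torch

p = 2.0
shards = [torch.randn(5), torch.randn(3), torch.randn(7)]  # stand-ins for each rank's local gradients

# per-rank local norms, as clip_grad_norm computes them
local_norms = [torch.norm(s, p=p) for s in shards]

# emulate dist.all_reduce(total_norm): sum the p-th powers, then take the p-th root
total_norm = sum(n ** p for n in local_norms) ** (1.0 / p)

# reference: norm of the gradients as if they were concatenated into a single vector
reference = torch.norm(torch.cat(shards), p=p)
assert torch.allclose(total_norm, reference)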
@@ -354,7 +364,6 @@ class OSS(Optimizer):
             rank (int): rank to get local_state_dict for
             state_dict (dict): global state_dict
         """
-        # Get this optimizer's param_groups shard
         param_groups = state_dict["param_groups"][state_dict["partition"][rank][0] : state_dict["partition"][rank][1]]
         return {"state": state_dict["state"][rank], "param_groups": param_groups}
@@ -402,7 +411,7 @@ class OSS(Optimizer):
         """
         # Check whether we got a local or global dict
-        if state_dict["local_state_dict"]:
+        if "local_state_dict" in state_dict and state_dict["local_state_dict"]:
             self.load_local_state_dict(state_dict)
         else:
             # Dispatch this rank's state dictionary to the wrapped shard optimizer
......
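The last hunk makes load_state_dict robust to state dicts that lack the "local_state_dict" key entirely: instead of raising a KeyError, such a dict is treated as a global (consolidated) one and dispatched to this rank's shard. For reference, the consolidated save / reload flow handled by that branch, mirroring the new unit test further down; the model, hyper-parameters and import path are placeholders, and a torch.distributed process group is assumed to be initialized already:

import torch
from fairscale.optim import OSS

model = torch.nn.Linear(4, 4)
optimizer = OSS(model.parameters(), lr=0.1, momentum=0.99)

optimizer.consolidate_state_dict()    # gather every rank's shard onto the recipient rank
checkpoint = optimizer.state_dict()   # global dict, safe to torch.save()

optimizer = OSS(model.parameters(), lr=0.1, momentum=0.99)
optimizer.load_state_dict(checkpoint)  # goes through the global path checked above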
@@ -28,4 +28,4 @@ use_parentheses = true
 skip_glob = ["build/*", "stubs/*"]
 # Don't split "import" and "from".
 force_sort_within_sections = true
-known_third_party = ["benchmark_dataset", "dataclasses", "helpers", "numpy", "pytest", "recommonmark", "setuptools", "torch", "torch_pg", "torchtext", "torchvision"]
+known_third_party = ["benchmark_dataset", "dataclasses", "datasets", "helpers", "models", "numpy", "pytest", "recommonmark", "setuptools", "torch", "torch_pg", "torchtext", "torchvision"]
@@ -604,6 +604,21 @@ def run_state_dict_distributed(rank, world_size, tempfile_name):
         optimizer.step()
         optimizer.zero_grad()

+    # save and reload without taking any steps
+    sharded_optimizer2.consolidate_state_dict()
+    state_dict2 = sharded_optimizer2.state_dict()
+    sharded_optimizer2 = optim.OSS(model_oss2.parameters(), lr=0.1, momentum=0.99)
+    sharded_optimizer2.load_state_dict(state_dict2)
+
+    # now take a step and check that parameters are equal
+    # take a step
+    run_grad_step(device, model_oss1, sharded_optimizer1)
+    run_grad_step(device, model_oss2, sharded_optimizer2)
+
+    # check that model parameters are equal
+    for param1, param2 in zip(model_oss1.parameters(), model_oss2.parameters()):
+        assert torch.allclose(param1, param2), "parameters of the two identical models have diverged (before any steps)"
+
     # take a step
     run_grad_step(device, model_oss1, sharded_optimizer1)
     run_grad_step(device, model_oss2, sharded_optimizer2)
......
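For readers of the hunk above: run_grad_step is a small helper defined earlier in the same test file and not shown in this diff. Something along these lines is presumably all it does (a sketch with placeholder shapes, not the actual helper):

import torch

def run_grad_step(device, model, optimizer):
    # one dummy forward / backward / step so both optimizers accumulate comparable state
    inputs = torch.ones((2, 4), device=device)  # placeholder shape, adjust to the test model
    optimizer.zero_grad()
    model(inputs).sum().backward()
    optimizer.step()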