Unverified commit d80c38f9 authored by Benjamin Lefaudeux, committed by GitHub

[chore] OSS doc (#101)

* Doc extensions to some APIs
* Fix the benchmark and tutorial
parent 63f7796a
@@ -3,9 +3,8 @@
import argparse
import math
import os
import time
from typing import Any, List, cast
from typing import Any, List, Optional, cast
import torch
import torch.distributed as dist
@@ -24,9 +23,9 @@ OPTIM = torch.optim.RMSprop
def dist_init(rank, world_size):
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "29501"
dist.init_process_group(backend=BACKEND, rank=rank, world_size=world_size)
dist.init_process_group(
backend=BACKEND, init_method="tcp://localhost:29501", rank=rank, world_size=world_size, store=None
)
def get_problem(rank, data_size, batch_size):
@@ -81,9 +80,11 @@ def train_oss_ddp(
model.zero_grad()
outputs = model(batch["inputs"])
loss = loss_fn(outputs, batch["label"])
dist.all_reduce(loss, op=dist.ReduceOp.SUM)
loss /= world_size
loss.backward()
dist.all_reduce(loss, op=dist.ReduceOp.SUM)
if dist.get_rank() == 0:
print(f"Loss: {loss.item()}")
@@ -146,6 +147,7 @@ def train(
model.train()
measurements = []
final_loss: Optional[float] = -1.0
for epoch in range(num_epochs):
epoch_start = time.monotonic()
@@ -156,12 +158,14 @@ def train(
model.zero_grad()
outputs = model(batch["inputs"])
loss = loss_fn(outputs, batch["label"])
dist.all_reduce(loss, op=dist.ReduceOp.SUM)
loss /= world_size
loss.backward()
dist.all_reduce(loss, op=dist.ReduceOp.SUM)
return loss
optimizer.step(closure)
final_loss = optimizer.step(closure)
epoch_end = time.monotonic()
@@ -176,7 +180,7 @@ def train(
measurements.append(data_size / (epoch_end - epoch_start))
if dist.get_rank() == 0:
print(f"Epoch {epoch} - processed {measurements[-1]:.2f} img per sec")
print(f"Epoch {epoch} - processed {measurements[-1]:.2f} img per sec. Loss {final_loss}")
torch.cuda.synchronize(rank)
training_stop = time.monotonic()
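The benchmark fix above folds the cross-rank loss reduction into the closure handed to the optimizer: the loss is summed across ranks, normalized by the world size, backpropagated, and returned, so that ``optimizer.step(closure)`` yields the value printed per epoch. A minimal sketch of that closure pattern, assuming ``model``, ``loss_fn``, ``batch`` and an already initialized process group (placeholder names outside the torch API):

.. code-block:: python

    import torch
    import torch.distributed as dist

    def make_closure(model, loss_fn, batch, world_size):
        def closure():
            model.zero_grad()
            outputs = model(batch["inputs"])
            loss = loss_fn(outputs, batch["label"])
            # Average the loss across ranks before the backward pass,
            # mirroring the change in the benchmark above.
            dist.all_reduce(loss, op=dist.ReduceOp.SUM)
            loss /= world_size
            loss.backward()
            return loss
        return closure

    # The closure is then handed to the (sharded) optimizer, e.g.:
    # final_loss = optimizer.step(make_closure(model, loss_fn, batch, world_size))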
@@ -2,9 +2,9 @@ Optimizer state sharding
========================
Using torch.nn.parallel.DistributedDataParallel leads to some wasted communication, but it is possible and makes OSS a drop-in solution in your existing torch distributed code.
Let's suppose that your trainer looks likemake html
Let's suppose that your trainer looks like
.. code-block:: default
.. code-block:: python
import torch
@@ -23,7 +23,9 @@ Let's suppose that your trainer looks likemake html
loss = myVeryRelevantLoss()
base_optimizer_arguments = {} # any optimizer specific arguments, LR, momentum, etc...
optimizer = torch.optim.SGD(params=model.parameters(), **base_optimizer_arguments)
optimizer = torch.optim.SGD(
params=model.parameters(),
**base_optimizer_arguments)
# Any relevant training loop, nothing specific to OSS. For example:
model.train()
@@ -33,18 +35,17 @@ Let's suppose that your trainer looks likemake html
model.zero_grad()
outputs = model(batch["inputs"])
loss = loss_fn(outputs, batch["label"])
torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.SUM)
loss /= world_size
loss.backward()
torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.SUM)
optimizer.step()
Then sharding the optimizer state is merely a matter of wrapping your optimizer in fairscale.optim.OSS, as follows
.. code-block:: default
.. code-block:: python
:emphasize-lines: 49, 65, 66
import torch
from fairscale.optim.oss import OSS
@@ -61,9 +62,14 @@ Then sharding the optimizer state is merely a matter of wrapping your optimizer
dataloader = mySuperFastDataloader()
loss = myVeryRelevantLoss()
base_optimizer_arguments = {} # pass any optimizer specific arguments here, or directly below when instantiating OSS
base_optimizer_arguments = {} # any optimizer specific arguments, LR, momentum, etc...
# ** NEW ** Wrap a base optimizer into OSS
base_optimizer = torch.optim.SGD # any pytorch compliant optimizer
optimizer = OSS(params=model.parameters(), optim=base_optimizer, **base_optimizer_arguments)
optimizer = OSS(
params=model.parameters(),
optim=base_optimizer,
**base_optimizer_arguments)
# Any relevant training loop, nothing specific to OSS. For example:
model.train()
@@ -73,8 +79,7 @@ Then sharding the optimizer state is merely a matter of wrapping your optimizer
model.zero_grad()
outputs = model(batch["inputs"])
loss = loss_fn(outputs, batch["label"])
torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.SUM)
loss /= world_size
loss.backward()
torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.SUM)
optimizer.step()
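Checkpointing a model whose optimizer is wrapped in OSS goes through the consolidation API documented further down in this commit: ``consolidate_state_dict`` must run on every rank, ``state_dict`` returns the full, consolidated state on the recipient rank only, and ``load_state_dict`` lets each rank pull its own shard back out of the global dict. A minimal sketch under those assumptions; the helper names and checkpoint path are placeholders:

.. code-block:: python

    import torch
    import torch.distributed as dist
    from fairscale.optim.oss import OSS

    def save_consolidated_checkpoint(optimizer: OSS, path: str = "oss_checkpoint.pt") -> None:
        # Every rank must take part in the consolidation.
        optimizer.consolidate_state_dict(recipient_rank=0)
        # Only the recipient rank holds the full state, so only it writes the file.
        if dist.get_rank() == 0:
            torch.save(optimizer.state_dict(), path)

    def load_consolidated_checkpoint(optimizer: OSS, path: str = "oss_checkpoint.pt") -> None:
        # Every rank loads the same global dict; OSS extracts the local shard.
        optimizer.load_state_dict(torch.load(path))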
@@ -25,8 +25,7 @@ else:
class OSS(Optimizer):
"""Wraps an arbitrary :class:`optim.Optimizer <torch.optim.Optimizer>`
optimizer and shards its state as described by ZeRO_.
::
opt = OSS(params, optim=torch.optim.Adam, lr=0.01)
:: opt = OSS(params, optim=torch.optim.Adam, lr=0.01)
.. _ZeRO: https://arxiv.org/abs/1910.02054
@@ -142,6 +141,14 @@ class OSS(Optimizer):
# NOTE(msb) We add a kwargs in order to support Optimizer sub-classes that support extra kwargs.
# For example, the apex library contains fused optimizers with a step that supports extra kwargs.
def step(self, closure: Optional[Callable[[], float]] = None, **kwargs: Any) -> Optional[float]:
"""Performs a single optimization step (parameter update).
Arguments:
closure (callable): A closure that reevaluates the model and
returns the loss. Optional for most optimizers.
.. note:: Any extra parameter is passed to the base optimizer as-is"""
# Sync oss param_groups attributes in case they've been updated by a scheduler.
self._sync_param_groups()
@@ -162,13 +169,22 @@ class OSS(Optimizer):
return loss
def local_state_dict(self) -> dict:
""" Gets this rank's state_dict. """
"""Gets this rank's state_dict.
Returns:
The state of the optimizer as a :class:`dict`.
It contains two entries:
* state - a dict holding current optimization state. Its content
differs between optimizer classes.
* param_groups - a dict containing all parameter groups
"""
return self.optim.state_dict()
def consolidate_state_dict(self, recipient_rank: int = 0) -> None:
""" Update the consolidated state_dict list, one per rank.
"""Update the consolidated state_dict list, one per rank.
This needs to be called on all replicas """
.. warning:: This needs to be called on all replicas"""
# Sync lr and other attributes in case it's been updated
self._sync_param_groups()
@@ -183,12 +199,13 @@ class OSS(Optimizer):
self._broadcast_state_dict()
def state_dict(self) -> Dict[str, Any]:
"""
Return the last known global optimizer state, which consist of a list of the shards.
"""Return the last known global optimizer state, which consist of a list of the shards.
.. warning:
If the state has not been consolidated, this returns a shard's worth, not the global state.
NOTE:
- If the state has not been consolidated, this returns a shard's worth, not the global state.
- Returning the global state is limited to the replica which was responsible for the consolidation.
.. warning:
Returning the global state is limited to the replica which was responsible for the consolidation.
The state may also not be up to date, depending on when `consolidate_state_dict` was last called.
"""
@@ -218,7 +235,10 @@
}
def load_local_state_dict(self, state_dict: dict) -> None:
""" Loads this rank's state_dict. """
"""Loads this rank's state_dict.
.. warning:: This is not meant to load the global state dict.
"""
self.optim.load_state_dict(state_dict)
@@ -242,7 +262,12 @@
global_group[k] = v
def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
""" Restore the global parameter groups as well as the shard """
"""Restore the global parameter groups as well as the shard.
Arguments:
state_dict (dict): optimizer state. Should be an object returned
from a call to :meth:`state_dict`
"""
# Check whether we got a local or global dict
if state_dict["local_state_dict"]:
@@ -256,6 +281,18 @@
self.load_local_state_dict({"state": state_dict["state"][self.rank], "param_groups": param_groups})
def add_param_group(self, param_group: dict) -> None:
"""Add a param group to the :class:`Optimizer` s `param_groups`.
This can be useful when fine tuning a pre-trained network as frozen layers can be made
trainable and added to the :class:`Optimizer` as training progresses.
Arguments:
param_group (dict): Specifies what Tensors should be optimized along with group
specific optimization options
.. warning:: This handles updating the shards on all partitions, but needs to be called on all ranks.
"""
super().add_param_group(param_group)
if not self.in_super_constructor:
self._partition_parameters.clear() # Force a re-partitioning
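A minimal sketch of the pattern this warning describes, e.g. making a previously frozen sub-module trainable during fine tuning; the helper name, module and learning rate are placeholders, and the call has to be issued on every rank:

.. code-block:: python

    from fairscale.optim.oss import OSS

    def unfreeze_and_register(module, optimizer: OSS, lr: float = 1e-4) -> None:
        # Make the frozen parameters trainable again.
        for param in module.parameters():
            param.requires_grad = True
        # Run identically on all ranks so that the new group is
        # re-partitioned consistently across the optimizer shards.
        optimizer.add_param_group({"params": module.parameters(), "lr": lr})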
@@ -273,9 +310,7 @@
local_group[k] = global_group[k]
def _collect_sharded_states(self) -> List[Dict[str, Any]]:
"""
Collect all the state shards, in CPU memory.
"""
"""Collect all the state shards, in CPU memory."""
empty_buffer = torch.tensor([0], dtype=torch.uint8, device=self._device)
all_states: List[Dict[str, Any]] = []
@@ -304,9 +339,7 @@
return all_states
def _broadcast_state_dict(self) -> None:
"""
Broadcast this rank's state shard, discard others
"""
"""Broadcast this rank's state shard, discard others"""
empty_buffer = torch.tensor([0], dtype=torch.uint8, device=self._device)
for rank in range(dist.get_world_size(group=self.group)):
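The ``local_state_dict`` / ``load_local_state_dict`` pair documented above operates on a single rank's shard and does not require any consolidation. A minimal sketch, assuming an initialized process group; the helper names and the per-rank path template are placeholders:

.. code-block:: python

    import torch
    import torch.distributed as dist
    from fairscale.optim.oss import OSS

    def save_local_shard(optimizer: OSS, path_template: str = "oss_shard_rank{rank}.pt") -> None:
        # Each rank saves only its own shard of the optimizer state.
        rank = dist.get_rank()
        torch.save(optimizer.local_state_dict(), path_template.format(rank=rank))

    def load_local_shard(optimizer: OSS, path_template: str = "oss_shard_rank{rank}.pt") -> None:
        # Not meant for the global state dict, as the docstring above notes.
        rank = dist.get_rank()
        optimizer.load_local_state_dict(torch.load(path_template.format(rank=rank)))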