Unverified Commit d80c38f9 authored by Benjamin Lefaudeux, committed by GitHub

[chore] OSS doc (#101)

* Doc extensions to some APIs
* Fix the benchmark and tutorial
parent 63f7796a
@@ -3,9 +3,8 @@
 import argparse
 import math
-import os
 import time
-from typing import Any, List, cast
+from typing import Any, List, Optional, cast

 import torch
 import torch.distributed as dist

@@ -24,9 +23,9 @@ OPTIM = torch.optim.RMSprop
 def dist_init(rank, world_size):
-    os.environ["MASTER_ADDR"] = "localhost"
-    os.environ["MASTER_PORT"] = "29501"
-    dist.init_process_group(backend=BACKEND, rank=rank, world_size=world_size)
+    dist.init_process_group(
+        backend=BACKEND, init_method="tcp://localhost:29501", rank=rank, world_size=world_size, store=None
+    )


 def get_problem(rank, data_size, batch_size):
@@ -81,9 +80,11 @@ def train_oss_ddp(
         model.zero_grad()
         outputs = model(batch["inputs"])
         loss = loss_fn(outputs, batch["label"])
-        dist.all_reduce(loss, op=dist.ReduceOp.SUM)
         loss /= world_size
         loss.backward()
+        dist.all_reduce(loss, op=dist.ReduceOp.SUM)

         if dist.get_rank() == 0:
             print(f"Loss: {loss.item()}")

@@ -146,6 +147,7 @@ def train(
     model.train()
     measurements = []
+    final_loss: Optional[float] = -1.0

     for epoch in range(num_epochs):
         epoch_start = time.monotonic()
@@ -156,12 +158,14 @@ def train(
                 model.zero_grad()
                 outputs = model(batch["inputs"])
                 loss = loss_fn(outputs, batch["label"])
-                dist.all_reduce(loss, op=dist.ReduceOp.SUM)
                 loss /= world_size
                 loss.backward()
+                dist.all_reduce(loss, op=dist.ReduceOp.SUM)
                 return loss

-            optimizer.step(closure)
+            final_loss = optimizer.step(closure)

            epoch_end = time.monotonic()

@@ -176,7 +180,7 @@ def train(
         measurements.append(data_size / (epoch_end - epoch_start))
         if dist.get_rank() == 0:
-            print(f"Epoch {epoch} - processed {measurements[-1]:.2f} img per sec")
+            print(f"Epoch {epoch} - processed {measurements[-1]:.2f} img per sec. Loss {final_loss}")

     torch.cuda.synchronize(rank)
     training_stop = time.monotonic()
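
Note on the two hunks above: the all_reduce now runs after backward(), so it only affects the reported loss value, not the gradients. A minimal sketch of that reporting pattern, assuming the process group is already initialized and the model is wrapped in DistributedDataParallel (which averages gradients during backward); `model`, `loss_fn`, and `batch` are placeholders:

    import torch.distributed as dist

    def training_step(model, loss_fn, batch, world_size):
        model.zero_grad()
        outputs = model(batch["inputs"])
        loss = loss_fn(outputs, batch["label"])
        loss /= world_size  # scale so the summed value is comparable across world sizes
        loss.backward()     # gradient averaging is handled by DDP here
        # Runs after backward(), so it only changes the value that gets logged.
        dist.all_reduce(loss, op=dist.ReduceOp.SUM)
        return loss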
......
@@ -2,9 +2,9 @@ Optimizer state sharding
 ========================

 Using torch.nn.parallel.DistributedDataParallel leads to some wasted communications, but it is possible and makes OSS a drop in solution in your existing torch distributed code.
-Let's suppose that your trainer looks likemake html
+Let's suppose that your trainer looks like

-.. code-block:: default
+.. code-block:: python

     import torch

@@ -23,7 +23,9 @@ Let's suppose that your trainer looks likemake html
     loss = myVeryRelevantLoss()

     base_optimizer_arguments = {}  # any optimizer specific arguments, LR, momentum, etc...
-    optimizer = torch.optim.SGD(params=model.parameters(), **base_optimizer_arguments)
+    optimizer = torch.optim.SGD(
+        params=model.parameters(),
+        **base_optimizer_arguments)

     # Any relevant training loop, nothing specific to OSS. For example:
     model.train()
@@ -33,18 +35,17 @@ Let's suppose that your trainer looks likemake html
         model.zero_grad()
         outputs = model(batch["inputs"])
         loss = loss_fn(outputs, batch["label"])
-        torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.SUM)
         loss /= world_size
         loss.backward()
+        torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.SUM)
         optimizer.step()


 Then sharding the optimizer state is merely a matter of wrapping your optimizer in fairscale.optim.OSS, as follows

-.. code-block:: default
-   :emphasize-lines: 49, 65, 66
+.. code-block:: python

     import torch
     from fairscale.optim.oss import OSS
@@ -61,9 +62,14 @@ Then sharding the optimizer state is merely a matter of wrapping your optimizer
     dataloader = mySuperFastDataloader()
     loss = myVeryRelevantLoss()

-    base_optimizer_arguments = {}  # pass any optimizer specific arguments here, or directly below when instantiating OSS
+    base_optimizer_arguments = {}  # any optimizer specific arguments, LR, momentum, etc...
+
+    # ** NEW ** Wrap a base optimizer into OSS
     base_optimizer = torch.optim.SGD  # any pytorch compliant optimizer
-    optimizer = OSS(params=model.parameters(), optim=base_optimizer, **base_optimizer_arguments)
+    optimizer = OSS(
+        params=model.parameters(),
+        optim=base_optimizer,
+        **base_optimizer_arguments)

     # Any relevant training loop, nothing specific to OSS. For example:
     model.train()

@@ -73,8 +79,7 @@ Then sharding the optimizer state is merely a matter of wrapping your optimizer
         model.zero_grad()
         outputs = model(batch["inputs"])
         loss = loss_fn(outputs, batch["label"])
-        torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.SUM)
         loss /= world_size
         loss.backward()
+        torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.SUM)
         optimizer.step()
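
Both tutorial snippets assume torch.distributed has already been initialized. A hypothetical launcher for trying them out locally could look like the sketch below; the gloo backend, the port, and the `run_training` name are illustrative rather than part of the tutorial (the tcp:// init mirrors the benchmark change above):

    import torch.distributed as dist
    import torch.multiprocessing as mp

    WORLD_SIZE = 2

    def run_training(rank, world_size):
        # Illustrative initialization, one process per rank.
        dist.init_process_group(
            backend="gloo", init_method="tcp://localhost:29501", rank=rank, world_size=world_size
        )
        # ... build the model and the OSS-wrapped optimizer, then run the loop shown above ...
        dist.destroy_process_group()

    if __name__ == "__main__":
        mp.spawn(run_training, args=(WORLD_SIZE,), nprocs=WORLD_SIZE, join=True)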
@@ -25,8 +25,7 @@ else:
 class OSS(Optimizer):
     """Wraps an arbitrary :class:`optim.Optimizer <torch.optim.Optimizer>`
     optimizer and shards its state as described by ZeRO_.
     ::
         opt = OSS(params, optim=torch.optim.Adam, lr=0.01)

     .. _ZeRO: https://arxiv.org/abs/1910.02054
@@ -142,6 +141,14 @@ class OSS(Optimizer):
     # NOTE(msb) We add a kwargs in order to support Optimizer sub-classes that support extra kwargs.
     # For example, the apex library contains fused optimizers with a step that supports extra kwargs.
     def step(self, closure: Optional[Callable[[], float]] = None, **kwargs: Any) -> Optional[float]:
+        """Performs a single optimization step (parameter update).
+
+        Arguments:
+            closure (callable): A closure that reevaluates the model and
+                returns the loss. Optional for most optimizers.
+
+        .. note: Any extra parameter is passed to the base optimizer as-is"""
+
         # Sync oss param_groups attributes in case they've been updated by a scheduler.
         self._sync_param_groups()
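
The new docstring documents the closure form of step(). A self-contained, single-process sketch of that usage; the gloo backend, the free port 29501 (which mirrors the benchmark above), and the tiny model are assumptions for illustration only:

    import torch
    import torch.distributed as dist
    from fairscale.optim.oss import OSS

    if not dist.is_initialized():
        dist.init_process_group(backend="gloo", init_method="tcp://localhost:29501", rank=0, world_size=1)

    model = torch.nn.Linear(4, 2)
    loss_fn = torch.nn.CrossEntropyLoss()
    inputs, labels = torch.randn(8, 4), torch.randint(0, 2, (8,))
    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=0.1)

    def closure():
        # Re-evaluate the model and return the loss, as the docstring describes.
        model.zero_grad()
        loss = loss_fn(model(inputs), labels)
        loss.backward()
        return loss

    final_loss = optimizer.step(closure)  # step() returns the value computed by the closure
    print(final_loss)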
@@ -162,13 +169,22 @@ class OSS(Optimizer):
         return loss

     def local_state_dict(self) -> dict:
-        """ Gets this rank's state_dict. """
+        """Gets this rank's state_dict.
+
+        Returns:
+            The state of the optimizer as a :class:`dict`.
+            It contains two entries:
+
+            * state - a dict holding current optimization state. Its content
+                differs between optimizer classes.
+            * param_groups - a dict containing all parameter groups
+        """
         return self.optim.state_dict()

     def consolidate_state_dict(self, recipient_rank: int = 0) -> None:
-        """ Update the consolidated state_dict list, one per rank.
-        This needs to be called on all replicas """
+        """Update the consolidated state_dict list, one per rank.
+
+        .. warning: This needs to be called on all replicas"""

         # Sync lr and other attributes in case its been updated
         self._sync_param_groups()
@@ -183,12 +199,13 @@ class OSS(Optimizer):
         self._broadcast_state_dict()

     def state_dict(self) -> Dict[str, Any]:
-        """
-        Return the last known global optimizer state, which consist of a list of the shards.
-
-        NOTE:
-        - If the state has not been consolidated, this returns a shard's worth, not the global state.
-        - Returning the global state is limited to the replica which was responsible for the consolidation.
-        The state may also not be up to date, depending on when `consolidate_state_dict` was last called.
+        """Return the last known global optimizer state, which consist of a list of the shards.
+
+        .. warning:
+            If the state has not been consolidated, this returns a shard's worth, not the global state.
+
+        .. warning:
+            Returning the global state is limited to the replica which was responsible for the consolidation.
+            The state may also not be up to date, depending on when `consolidate_state_dict` was last called.
         """
@@ -218,7 +235,10 @@ class OSS(Optimizer):
         }

     def load_local_state_dict(self, state_dict: dict) -> None:
-        """ Loads this rank's state_dict. """
+        """Loads this rank's state_dict.
+
+        .. warning: This is not meant to load the global state dict.
+        """
         self.optim.load_state_dict(state_dict)
@@ -242,7 +262,12 @@ class OSS(Optimizer):
             global_group[k] = v

     def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
-        """ Restore the global parameter groups as well as the shard """
+        """Restore the global parameter groups as well as the shard.
+
+        Arguments:
+            state_dict (dict): optimizer state. Should be an object returned
+                from a call to :meth:`state_dict`
+        """
+
         # Check whether we got a local or global dict
         if state_dict["local_state_dict"]:
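
Continuing the same sketch, restoring from the illustrative checkpoint saved earlier; every rank loads the global dict and keeps only its own shard:

    checkpoint = torch.load("oss_checkpoint.pt")
    model.load_state_dict(checkpoint["model"])
    optimizer.load_state_dict(checkpoint["optim"])  # each rank restores its shard from the global state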
@@ -256,6 +281,18 @@ class OSS(Optimizer):
         self.load_local_state_dict({"state": state_dict["state"][self.rank], "param_groups": param_groups})

     def add_param_group(self, param_group: dict) -> None:
+        """Add a param group to the :class:`Optimizer` s `param_groups`.
+
+        This can be useful when fine tuning a pre-trained network as frozen layers can be made
+        trainable and added to the :class:`Optimizer` as training progresses.
+
+        Arguments:
+            param_group (dict): Specifies what Tensors should be optimized along with group
+            specific optimization options
+
+        .. warning: This handles updating the shards on all partitions, but needs to be called on all ranks.
+        """
+
         super().add_param_group(param_group)
         if not self.in_super_constructor:
             self._partition_parameters.clear()  # Force a re-partitioning
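
Continuing the sketch, a hypothetical example of unfreezing an extra layer; as the new warning states, the call has to be made on every rank so the shards are re-partitioned consistently:

    extra_layer = torch.nn.Linear(2, 2)  # illustrative: e.g. a layer that was frozen until now
    optimizer.add_param_group({"params": extra_layer.parameters(), "lr": 0.01})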
@@ -273,9 +310,7 @@ class OSS(Optimizer):
             local_group[k] = global_group[k]

     def _collect_sharded_states(self) -> List[Dict[str, Any]]:
-        """
-        Collect all the state shards, in CPU memory.
-        """
+        """Collect all the state shards, in CPU memory."""
         empty_buffer = torch.tensor([0], dtype=torch.uint8, device=self._device)
         all_states: List[Dict[str, Any]] = []
@@ -304,9 +339,7 @@ class OSS(Optimizer):
         return all_states

     def _broadcast_state_dict(self) -> None:
-        """
-        Broadcast this rank's state shard, discard others
-        """
+        """Broadcast this rank's state shard, discard others"""
         empty_buffer = torch.tensor([0], dtype=torch.uint8, device=self._device)
         for rank in range(dist.get_world_size(group=self.group)):
......