Unverified Commit b202804a authored by Benjamin Lefaudeux, committed by GitHub

[doc] Minor additions to ShardedDDP docs (#299)

parent 11beea69
...@@ -23,8 +23,7 @@ from fairscale.optim.utils import Workhandle
class ShardedDataParallel(nn.Module):
""" """ Wrap the model, and reduce the gradients to the right rank during the backward pass.
Wrap the model, and reduce the gradients to the right rank during the backward pass.
- the partition is given by the sharded optimizer
- wrap the base model with a model which knows where to reduce each gradient
...@@ -46,6 +45,21 @@ class ShardedDataParallel(nn.Module):
Synchronize the models in between the ranks when starting up. Not needed if each rank has the same seed,
or the training restarts from a saved state
.. warning::
    ShardedDDP implements gradient sharding, meaning that each rank only owns a unique shard of the model gradients
    after the backward pass, in order to save memory and some communication bandwidth.
.. warning::
    As a consequence of sharding, in case of gradient clipping, one has to use the `clip_grad_norm` exposed by
    the `optimizer state sharding wrapper <fairscale.optim.OSS>`
.. warning::
    As a consequence of sharding, after loss.backward() (or equivalent) each rank will have `None` in place of some param.grad
.. warning::
    As a consequence of sharding, Pytorch and Apex AMP implementations will hang when used in conjunction with `ShardedDDP`.
    One needs a `shard-aware grad scaler <ShardedGradScaler>`, which is proposed in `fairscale.optim.grad_scaler` and is compatible with Pytorch AMP.
""" """
def __init__( def __init__(
......
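A minimal usage sketch of the wrapping described above, assuming `torch.distributed` has already been initialized on every rank; the model, data and hyper-parameters are placeholders, and the exact constructor arguments may differ slightly between releases:

    import torch
    from fairscale.optim.oss import OSS
    from fairscale.optim.grad_scaler import ShardedGradScaler
    from fairscale.nn.data_parallel import ShardedDataParallel

    # Placeholder model and data; torch.distributed.init_process_group() is assumed to have run.
    model = torch.nn.Linear(16, 4).cuda()
    optimizer = OSS(model.parameters(), optim=torch.optim.SGD, lr=1e-3)  # shards the optimizer state
    model = ShardedDataParallel(model, optimizer)  # reduces each gradient to the rank which owns it
    scaler = ShardedGradScaler()  # shard-aware replacement for torch.cuda.amp.GradScaler

    inputs, target = torch.randn(8, 16).cuda(), torch.randn(8, 4).cuda()
    with torch.cuda.amp.autocast():
        loss = torch.nn.functional.mse_loss(model(inputs), target)

    scaler.scale(loss).backward()  # after this, some param.grad are None on this rank
    scaler.step(optimizer)
    scaler.update()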
...@@ -554,10 +554,11 @@ class AdaScale(Optimizer):
`set_scale` needs to be called to update the scale as well.
TODO (min): need a way of determining how much to increase the step size?
TODO (min): having both `set_scale` and `set_num_gradients_to_accumulate`
    is hard to use and easy to make mistakes. I think it is better
    to specify a `base_scale`. But more discussion is
    needed here.
Args:
    num_gradients_to_accumulate (int):
...
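A short sketch of the coupling the TODO describes, assuming `torch.distributed` is initialized and that the default scale is `world_size * num_gradients_to_accumulate`; the concrete values and the model are illustrative only:

    import torch
    import torch.distributed as dist
    from fairscale.optim import AdaScale

    # Placeholder model; torch.distributed.init_process_group() is assumed to have run.
    model = torch.nn.Linear(16, 4)
    optim = AdaScale(torch.optim.SGD(model.parameters(), lr=0.1), num_gradients_to_accumulate=2)

    # When the accumulation changes, the scale has to be updated in lockstep,
    # otherwise AdaScale keeps adapting the step size for the old effective batch size.
    world_size = dist.get_world_size()
    optim.set_num_gradients_to_accumulate(4)
    optim.set_scale(float(world_size * 4))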
...@@ -239,7 +239,6 @@ class OSS(Optimizer):
.. warning:: This needs to be called on all ranks, since synchronization primitives will be used
.. warning:: Model parallelism -groups other than world- are not yet supported
"""
# Compute the max norm for this shard's worth of gradients
...
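A sketch of a gradient clipping step with the sharded optimizer; it assumes `torch.distributed` is initialized and, per the warning above, the `clip_grad_norm` call must be reached by every rank since it relies on collectives. Model and data are placeholders:

    import torch
    from fairscale.optim.oss import OSS

    # Placeholder model and data; this code runs identically on every rank.
    model = torch.nn.Linear(16, 4)
    optimizer = OSS(model.parameters(), optim=torch.optim.SGD, lr=0.1)

    inputs, target = torch.randn(8, 16), torch.randn(8, 4)
    loss = torch.nn.functional.mse_loss(model(inputs), target)
    loss.backward()

    # All ranks must hit this line: the global gradient norm is computed collectively
    # across the shards before the local gradients are clipped.
    optimizer.clip_grad_norm(max_norm=1.0)

    optimizer.step()
    optimizer.zero_grad()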