fairscale commit 274478d0 (unverified)
Authored Sep 24, 2020 by Vittorio Caggiano, committed via GitHub on Sep 24, 2020

Update oss.rst (#107)
parent 53553474

Showing 1 changed file with 31 additions and 12 deletions:

docs/source/tutorials/oss.rst (+31, -12)
@@ -18,11 +18,12 @@ Let's suppose that your trainer looks like

     dist_init(rank, world_size)

     # Problem statement
-    model = myAwesomeModel()
+    model = myAwesomeModel().to(rank)
     dataloader = mySuperFastDataloader()
-    loss = myVeryRelevantLoss()
-    base_optimizer_arguments = {} # any optimizer specific arguments, LR, momentum, etc...
+    loss_fn = myVeryRelevantLoss()

+    # optimizer specific arguments e.g. LR, momentum, etc...
+    base_optimizer_arguments = { "lr": 1e-4}
     optimizer = torch.optim.SGD(
         params=model.parameters(),
         **base_optimizer_arguments)
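The code in this hunk (and throughout the tutorial it patches) assumes a `dist_init(rank, world_size)` helper that the diff itself never shows. As a point of reference only, a minimal sketch of such a helper, assuming a single-node run and the Gloo backend (the address and port values below are illustrative, not taken from the commit):

.. code-block:: python

    import os

    import torch.distributed as dist


    def dist_init(rank: int, world_size: int) -> None:
        # Illustrative rendezvous settings: every spawned process must agree on these.
        os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
        os.environ.setdefault("MASTER_PORT", "29500")
        # Gloo also works on CPU-only machines; NCCL is the usual choice for multi-GPU jobs.
        dist.init_process_group(backend="gloo", rank=rank, world_size=world_size)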
@@ -30,11 +31,12 @@ Let's suppose that your trainer looks like

     # Any relevant training loop, nothing specific to OSS. For example:
     model.train()
     for e in range(epochs):
-        for batch in dataloader:
+        for (data, target) in dataloader:
+            data, target = data.to(rank), target.to(rank)
             # Train
             model.zero_grad()
-            outputs = model(batch["inputs"])
-            loss = loss_fn(outputs, batch["label"])
+            outputs = model(data)
+            loss = loss_fn(outputs, target)
             loss /= world_size
             loss.backward()
             torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.SUM)
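With this change the loop unpacks `(data, target)` pairs and moves them to the local device, so the placeholder factories are assumed to return objects compatible with that pattern. A toy sketch of such stand-ins (a small linear model, a random tensor dataset, and a cross-entropy loss; none of this comes from the commit beyond the factory names themselves):

.. code-block:: python

    import torch
    from torch.utils.data import DataLoader, TensorDataset


    def myAwesomeModel() -> torch.nn.Module:
        # Toy stand-in: classify 32-dimensional inputs into 10 classes.
        return torch.nn.Linear(32, 10)


    def mySuperFastDataloader() -> DataLoader:
        # Toy stand-in: random (data, target) batches matching the loop's unpacking.
        dataset = TensorDataset(torch.randn(512, 32), torch.randint(0, 10, (512,)))
        return DataLoader(dataset, batch_size=64, shuffle=True)


    def myVeryRelevantLoss() -> torch.nn.Module:
        # Toy stand-in matching the loss_fn(outputs, target) call in the loop.
        return torch.nn.CrossEntropyLoss()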
@@ -58,11 +60,12 @@ Then sharding the optimizer state is merely a matter of wrapping your optimizer

     dist_init(rank, world_size)

     # Problem statement
-    model = myAwesomeModel()
+    model = myAwesomeModel().to(rank)
     dataloader = mySuperFastDataloader()
-    loss = myVeryRelevantLoss()
-    base_optimizer_arguments = {} # any optimizer specific arguments, LR, momentum, etc...
+    loss_fn = myVeryRelevantLoss()

+    # optimizer specific arguments e.g. LR, momentum, etc...
+    base_optimizer_arguments = { "lr": 1e-4}

     # ** NEW ** Wrap a base optimizer into OSS
     base_optimizer = torch.optim.SGD # any pytorch compliant optimizer
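The hunk is cut off right after the `base_optimizer` line, so the wrapping step that the `** NEW **` comment announces is not visible here. For context, with fairscale's `OSS` class (assuming the `fairscale.optim.oss` import path) the wrapping typically looks like the sketch below; the training loop then calls `optimizer.step()` exactly as before, as the unchanged lines in the next hunk show.

.. code-block:: python

    from fairscale.optim.oss import OSS

    # ** NEW ** Wrap the base optimizer into OSS so its state is sharded across ranks.
    optimizer = OSS(
        params=model.parameters(),
        optim=base_optimizer,        # e.g. torch.optim.SGD
        **base_optimizer_arguments,  # e.g. {"lr": 1e-4}
    )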
@@ -74,12 +77,28 @@ Then sharding the optimizer state is merely a matter of wrapping your optimizer

     # Any relevant training loop, nothing specific to OSS. For example:
     model.train()
     for e in range(epochs):
-        for batch in dataloader:
+        for (data, target) in dataloader:
+            data, target = data.to(rank), target.to(rank)
             # Train
             model.zero_grad()
-            outputs = model(batch["inputs"])
-            loss = loss_fn(outputs, batch["label"])
+            outputs = model(data)
+            loss = loss_fn(outputs, target)
             loss /= world_size
             loss.backward()
             torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.SUM)
             optimizer.step()

+The above `train` function will then need to be run via a `multiprocessing.spawn` function.
+
+.. code-block:: python
+
+    mp.spawn(
+        train,
+        args=(WORLD_SIZE, EPOCHS),
+        nprocs=WORLD_SIZE,
+        join=True
+    )
+
+to see it in action, you can test it with the following script _`tutorial_oss.py <../../../examples/tutorial_oss.py>`_
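Putting the additions together, the launch side can be sketched end to end as below. The `WORLD_SIZE` and `EPOCHS` values are illustrative, and the `train` signature simply mirrors the `args=(WORLD_SIZE, EPOCHS)` call added above (`mp.spawn` prepends the process rank as the first argument).

.. code-block:: python

    import os

    import torch.distributed as dist
    import torch.multiprocessing as mp

    WORLD_SIZE = 2  # illustrative: number of worker processes to launch
    EPOCHS = 5      # illustrative: forwarded to train()


    def train(rank: int, world_size: int, epochs: int) -> None:
        # mp.spawn passes the rank first, then everything in `args`.
        os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
        os.environ.setdefault("MASTER_PORT", "29500")
        dist.init_process_group(backend="gloo", rank=rank, world_size=world_size)
        # ... problem statement, OSS wrapping, and training loop from the diff above ...
        dist.destroy_process_group()


    if __name__ == "__main__":
        mp.spawn(
            train,
            args=(WORLD_SIZE, EPOCHS),
            nprocs=WORLD_SIZE,
            join=True,
        )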