Unverified commit 274478d0, authored by Vittorio Caggiano, committed by GitHub

Update oss.rst (#107)

parent 53553474
@@ -18,11 +18,12 @@ Let's suppose that your trainer looks like

     dist_init(rank, world_size)
     # Problem statement
-    model = myAwesomeModel()
+    model = myAwesomeModel().to(rank)
     dataloader = mySuperFastDataloader()
-    loss = myVeryRelevantLoss()
-    base_optimizer_arguments = {}  # any optimizer specific arguments, LR, momentum, etc...
+    loss_fn = myVeryRelevantLoss()
+    # optimizer specific arguments e.g. LR, momentum, etc...
+    base_optimizer_arguments = {"lr": 1e-4}
     optimizer = torch.optim.SGD(
         params=model.parameters(),
         **base_optimizer_arguments)
@@ -30,11 +31,12 @@ Let's suppose that your trainer looks like

     # Any relevant training loop, nothing specific to OSS. For example:
     model.train()
     for e in range(epochs):
-        for batch in dataloader:
+        for (data, target) in dataloader:
+            data, target = data.to(rank), target.to(rank)
             # Train
             model.zero_grad()
-            outputs = model(batch["inputs"])
-            loss = loss_fn(outputs, batch["label"])
+            outputs = model(data)
+            loss = loss_fn(outputs, target)
             loss /= world_size
             loss.backward()
             torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.SUM)
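
Both snippets assume a `dist_init(rank, world_size)` helper that sets up the process group; it is not shown in this diff. A minimal sketch, assuming an env-style rendezvous on localhost and the gloo backend (all of which are illustrative choices, not part of the tutorial), could look like:

.. code-block:: python

    # Hedged sketch of the dist_init helper assumed above; the rendezvous
    # settings (address, port, backend) are illustrative only.
    import os
    import torch.distributed as dist

    def dist_init(rank, world_size):
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = "29500"
        dist.init_process_group(backend="gloo", rank=rank, world_size=world_size)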
@@ -58,11 +60,12 @@ Then sharding the optimizer state is merely a matter of wrapping your optimizer

     dist_init(rank, world_size)
     # Problem statement
-    model = myAwesomeModel()
+    model = myAwesomeModel().to(rank)
     dataloader = mySuperFastDataloader()
-    loss = myVeryRelevantLoss()
-    base_optimizer_arguments = {}  # any optimizer specific arguments, LR, momentum, etc...
+    loss_fn = myVeryRelevantLoss()
+    # optimizer specific arguments e.g. LR, momentum, etc...
+    base_optimizer_arguments = {"lr": 1e-4}
     # ** NEW ** Wrap a base optimizer into OSS
     base_optimizer = torch.optim.SGD  # any pytorch compliant optimizer
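
The wrapping call itself falls in the lines elided between these hunks. With fairscale's `OSS` class it would look roughly like the sketch below; the keyword names are an assumption based on the OSS constructor, since the diff does not show them.

.. code-block:: python

    # Hedged sketch: shard the optimizer state by wrapping the base optimizer
    # class in fairscale's OSS. Keyword names are assumed, not taken from the diff.
    from fairscale.optim.oss import OSS

    optimizer = OSS(
        params=model.parameters(),
        optim=base_optimizer,          # the optimizer *class*, e.g. torch.optim.SGD
        **base_optimizer_arguments)    # forwarded to the wrapped optimizer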
@@ -74,12 +77,28 @@ Then sharding the optimizer state is merely a matter of wrapping your optimizer

     # Any relevant training loop, nothing specific to OSS. For example:
     model.train()
     for e in range(epochs):
-        for batch in dataloader:
+        for (data, target) in dataloader:
+            data, target = data.to(rank), target.to(rank)
             # Train
             model.zero_grad()
-            outputs = model(batch["inputs"])
-            loss = loss_fn(outputs, batch["label"])
+            outputs = model(data)
+            loss = loss_fn(outputs, target)
             loss /= world_size
             loss.backward()
             torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.SUM)
             optimizer.step()
The above `train` function will then need to be run via a `multiprocessing.spawn` function.

.. code-block:: python

    mp.spawn(
        train,
        args=(WORLD_SIZE, EPOCHS),
        nprocs=WORLD_SIZE,
        join=True
    )

To see it in action, you can test it with the following script: `tutorial_oss.py <../../../examples/tutorial_oss.py>`_
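
For completeness, a minimal launch script for the spawn call could look like the sketch below. The values of `WORLD_SIZE` and `EPOCHS` and the use of `torch.multiprocessing` are assumptions here; `mp.spawn` passes the process index (the rank) as the first argument to `train`, followed by the entries of `args`.

.. code-block:: python

    # Hedged sketch of an entry point; constants are illustrative only.
    import torch
    import torch.multiprocessing as mp

    WORLD_SIZE = torch.cuda.device_count()  # e.g. one process per GPU
    EPOCHS = 10

    if __name__ == "__main__":
        # mp.spawn invokes train(rank, WORLD_SIZE, EPOCHS) in each of the
        # WORLD_SIZE processes and waits for them to finish.
        mp.spawn(
            train,
            args=(WORLD_SIZE, EPOCHS),
            nprocs=WORLD_SIZE,
            join=True,
        )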