Unverified commit f5db6ce7 authored by Steven Liu, committed by GitHub

Fix code format for Accelerate doc (#15335)

* 🖍 fix code syntax to external libraries and replace image

* 🔄 revert code formatting, replace image with code block

* 🖍 apply feedback
parent 0b072304
@@ -22,7 +22,7 @@ Get started by installing 🤗 Accelerate:
pip install accelerate
```
- Then import and create an [`Accelerator`](https://huggingface.co/docs/accelerate/accelerator.html#accelerate.Accelerator) object. [`Accelerator`] will automatically detect your type of distributed setup and initialize all the necessary components for training. You don't need to explicitly place your model on a device.
+ Then import and create an [`Accelerator`](https://huggingface.co/docs/accelerate/accelerator.html#accelerate.Accelerator) object. `Accelerator` will automatically detect your type of distributed setup and initialize all the necessary components for training. You don't need to explicitly place your model on a device.
```py
>>> from accelerate import Accelerator
@@ -32,7 +32,7 @@ Then import and create an [`Accelerator`](https://huggingface.co/docs/accelerate
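A minimal sketch of this step, assuming a standalone script; the instantiation mirrors the `accelerator = Accelerator()` line shown in the full code block further down:

```py
>>> from accelerate import Accelerator

>>> # Detects the distributed setup and initializes the components needed for training
>>> accelerator = Accelerator()
```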
## Prepare to accelerate
- The next step is to pass all the relevant training objects to [`prepare`](https://huggingface.co/docs/accelerate/accelerator.html#accelerate.Accelerator.prepare). This includes your training and evaluation DataLoaders, a model and an optimizer:
+ The next step is to pass all the relevant training objects to the [`prepare`](https://huggingface.co/docs/accelerate/accelerator.html#accelerate.Accelerator.prepare) method. This includes your training and evaluation DataLoaders, a model and an optimizer:
```py
>>> train_dataloader, eval_dataloader, model, optimizer = accelerator.prepare(
@@ -42,7 +42,7 @@ The next step is to pass all the relevant training objects to [`prepare`](https:
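A minimal sketch of the `prepare` call, assuming the dataloaders, model, and optimizer have already been created; the argument list mirrors the one in the full code block further down:

```py
>>> # Each object is wrapped for the detected setup, so no manual model.to(device) is needed
>>> train_dataloader, eval_dataloader, model, optimizer = accelerator.prepare(
...     train_dataloader, eval_dataloader, model, optimizer
... )
```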
## Backward
- The last addition is to replace the typical `loss.backward()` in your training loop with 🤗 Accelerate's [`backward`](https://huggingface.co/docs/accelerate/accelerator.html#accelerate.Accelerator.backward):
+ The last addition is to replace the typical `loss.backward()` in your training loop with 🤗 Accelerate's [`backward`](https://huggingface.co/docs/accelerate/accelerator.html#accelerate.Accelerator.backward) method:
```py
>>> for epoch in range(num_epochs):
@@ -57,9 +57,49 @@ The last addition is to replace the typical `loss.backward()` in your training l
... progress_bar.update(1)
```
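A minimal sketch of the modified training loop, assuming `num_epochs`, `train_dataloader`, `model`, `optimizer`, `lr_scheduler`, and `accelerator` come from the steps above; it mirrors the loop in the full code block further down:

```py
>>> for epoch in range(num_epochs):
...     for batch in train_dataloader:
...         outputs = model(**batch)
...         loss = outputs.loss
...         accelerator.backward(loss)  # replaces loss.backward()
...         optimizer.step()
...         lr_scheduler.step()
...         optimizer.zero_grad()
```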
- As you can see in the following image, you only need to add four additional lines of code to your training loop to enable distributed training!
+ As you can see in the following code, you only need to add four additional lines of code to your training loop to enable distributed training!
- ![accelerate](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/accelerate.png)
```diff
+ from accelerate import Accelerator
from transformers import AdamW, AutoModelForSequenceClassification, get_scheduler
+ accelerator = Accelerator()
model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=2)
optimizer = AdamW(model.parameters(), lr=3e-5)
- device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
- model.to(device)
+ train_dataloader, eval_dataloader, model, optimizer = accelerator.prepare(
+ train_dataloader, eval_dataloader, model, optimizer
+ )
num_epochs = 3
num_training_steps = num_epochs * len(train_dataloader)
lr_scheduler = get_scheduler(
"linear",
optimizer=optimizer,
num_warmup_steps=0,
num_training_steps=num_training_steps
)
progress_bar = tqdm(range(num_training_steps))
model.train()
for epoch in range(num_epochs):
for batch in train_dataloader:
- batch = {k: v.to(device) for k, v in batch.items()}
outputs = model(**batch)
loss = outputs.loss
- loss.backward()
+ accelerator.backward(loss)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
progress_bar.update(1)
```
## Train
......