OpenDAS / ColossalAI · Commits · f5c425c8

Unverified commit f5c425c8, authored May 22, 2023 by Frank Lee, committed by GitHub on May 22, 2023

fixed the example docstring for booster (#3795)

parent 788e07db
Showing 1 changed file with 22 additions and 21 deletions.
colossalai/booster/booster.py (+22 −21)
@@ -23,27 +23,28 @@ class Booster:
     training with different precision, accelerator, and plugin.
 
     Examples:
-        >>> colossalai.launch(...)
-        >>> plugin = GeminiPlugin(stage=3, ...)
-        >>> booster = Booster(precision='fp16', plugin=plugin)
-        >>>
-        >>> model = GPT2()
-        >>> optimizer = Adam(model.parameters())
-        >>> dataloader = Dataloader(Dataset)
-        >>> lr_scheduler = LinearWarmupScheduler()
-        >>> criterion = GPTLMLoss()
-        >>>
-        >>> model, optimizer, lr_scheduler, dataloader = booster.boost(model, optimizer, lr_scheduler, dataloader)
-        >>>
-        >>> for epoch in range(max_epochs):
-        >>>     for input_ids, attention_mask in dataloader:
-        >>>         outputs = model(input_ids, attention_mask)
-        >>>         loss = criterion(outputs.logits, input_ids)
-        >>>         booster.backward(loss, optimizer)
-        >>>         optimizer.step()
-        >>>         lr_scheduler.step()
-        >>>         optimizer.zero_grad()
+        ```python
+        colossalai.launch(...)
+        plugin = GeminiPlugin(stage=3, ...)
+        booster = Booster(precision='fp16', plugin=plugin)
+
+        model = GPT2()
+        optimizer = Adam(model.parameters())
+        dataloader = Dataloader(Dataset)
+        lr_scheduler = LinearWarmupScheduler()
+        criterion = GPTLMLoss()
+
+        model, optimizer, lr_scheduler, dataloader = booster.boost(model, optimizer, lr_scheduler, dataloader)
+
+        for epoch in range(max_epochs):
+            for input_ids, attention_mask in dataloader:
+                outputs = model(input_ids, attention_mask)
+                loss = criterion(outputs.logits, input_ids)
+                booster.backward(loss, optimizer)
+                optimizer.step()
+                lr_scheduler.step()
+                optimizer.zero_grad()
+        ```
 
     Args:
         device (str or torch.device): The device to run the training. Default: 'cuda'.
...
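Beyond rendering, the switch from `>>>` prompts to a fenced block has a practical side effect (an inference about the motivation, not stated in the commit message): Python's `doctest` machinery collects `>>>` lines as executable tests, so pseudocode such as `GPT2()` or `GeminiPlugin(stage=3, ...)` would fail if doctests were ever run over the module, while a fenced block is invisible to doctest. A minimal sketch, using hypothetical `old_style`/`new_style` functions to illustrate:

```python
import doctest

def old_style():
    """Doctest-style example: the ">>>" prompt makes this a collectable test.

    >>> 1 + 1
    2
    """

def new_style():
    """Markdown-style example: doctest ignores fenced code blocks.

    ```python
    1 + 1
    ```
    """

# exclude_empty=False keeps docstrings that contain no examples,
# so both functions return a DocTest object we can inspect.
finder = doctest.DocTestFinder(exclude_empty=False)

# old_style's docstring yields one runnable example ...
print(len(finder.find(old_style)[0].examples))
# ... while new_style's fenced block yields none.
print(len(finder.find(new_style)[0].examples))
```

This is why a `>>>` example that only sketches an API (elided arguments, undefined names) is safer written as a fenced block.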