Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
1ce997da
Commit
1ce997da
authored
Jul 18, 2023
by
Xu Kai
Committed by
binmakeswell
Jul 26, 2023
Browse files
[NFC] polish applications/Chat/examples/train_reward_model.py code style (#4271)
parent
a50d39a1
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
2 additions
and
6 deletions
+2
-6
applications/Chat/examples/train_reward_model.py
applications/Chat/examples/train_reward_model.py
+2
-6
No files found.
applications/Chat/examples/train_reward_model.py
View file @
1ce997da
...
@@ -150,9 +150,7 @@ def train(args):
...
@@ -150,9 +150,7 @@ def train(args):
pin_memory
=
True
)
pin_memory
=
True
)
lr_scheduler
=
CosineAnnealingLR
(
optim
,
train_dataloader
.
__len__
()
//
100
)
lr_scheduler
=
CosineAnnealingLR
(
optim
,
train_dataloader
.
__len__
()
//
100
)
strategy_dict
=
strategy
.
prepare
(
strategy_dict
=
strategy
.
prepare
(
dict
(
model
=
model
,
optimizer
=
optim
,
lr_scheduler
=
lr_scheduler
))
dict
(
model
=
model
,
optimizer
=
optim
,
lr_scheduler
=
lr_scheduler
)
)
model
=
strategy_dict
[
'model'
]
model
=
strategy_dict
[
'model'
]
optim
=
strategy_dict
[
'optimizer'
]
optim
=
strategy_dict
[
'optimizer'
]
lr_scheduler
=
strategy_dict
[
'lr_scheduler'
]
lr_scheduler
=
strategy_dict
[
'lr_scheduler'
]
...
@@ -163,9 +161,7 @@ def train(args):
...
@@ -163,9 +161,7 @@ def train(args):
loss_fn
=
loss_fn
,
loss_fn
=
loss_fn
,
max_epochs
=
args
.
max_epochs
)
max_epochs
=
args
.
max_epochs
)
trainer
.
fit
(
train_dataloader
=
train_dataloader
,
trainer
.
fit
(
train_dataloader
=
train_dataloader
,
valid_dataloader
=
valid_dataloader
,
eval_dataloader
=
eval_dataloader
)
valid_dataloader
=
valid_dataloader
,
eval_dataloader
=
eval_dataloader
)
# save model checkpoint after fitting on only rank0
# save model checkpoint after fitting on only rank0
strategy
.
save_model
(
model
,
args
.
save_path
,
only_rank0
=
True
)
strategy
.
save_model
(
model
,
args
.
save_path
,
only_rank0
=
True
)
# save optimizer checkpoint on all ranks
# save optimizer checkpoint on all ranks
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment