OpenDAS / ColossalAI

Commit 798cb729
authored Jul 18, 2023 by shenggan, committed Jul 26, 2023 by binmakeswell

[NFC] polish applications/Chat/coati/trainer/base.py code style (#4260)
parent b2debdc0

Showing 1 changed file with 22 additions and 31 deletions:

applications/Chat/coati/trainer/base.py (+22 -31)
@@ -25,12 +25,13 @@ class SLTrainer(ABC):
        optim (Optimizer): the optimizer to use for training
    """

    def __init__(
        self,
        strategy: Strategy,
        max_epochs: int,
        model: nn.Module,
        optimizer: Optimizer,
    ) -> None:
        super().__init__()
        self.strategy = strategy
        self.max_epochs = max_epochs
@@ -50,10 +51,7 @@ class SLTrainer(ABC):
    def fit(self, *args, **kwargs):
        self._before_fit(*args, **kwargs)
        for epoch in tqdm.trange(self.max_epochs, desc="Epochs", disable=not is_rank_0() or self.no_epoch_bar):
            self._train(epoch)
            self._eval(epoch)
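The fit method above is a template method: the epoch loop lives in the base class, and concrete trainers supply the _before_fit, _train, and _eval hooks. A minimal, self-contained sketch of that pattern (illustrative only; MinimalSLTrainer and ToyTrainer are hypothetical stand-ins, not coati code):

from abc import ABC, abstractmethod

import torch
import torch.nn as nn
import tqdm
from torch.optim import Optimizer


class MinimalSLTrainer(ABC):
    """Stand-in for the base class in this diff: the epoch loop is fixed here,
    concrete trainers fill in the hooks."""

    def __init__(self, max_epochs: int, model: nn.Module, optimizer: Optimizer) -> None:
        self.max_epochs = max_epochs
        self.model = model
        self.optimizer = optimizer

    @abstractmethod
    def _train(self, epoch: int):
        ...

    @abstractmethod
    def _eval(self, epoch: int):
        ...

    def _before_fit(self, *args, **kwargs):
        pass

    def fit(self, *args, **kwargs):
        self._before_fit(*args, **kwargs)
        for epoch in tqdm.trange(self.max_epochs, desc="Epochs"):
            self._train(epoch)
            self._eval(epoch)


class ToyTrainer(MinimalSLTrainer):
    """Hypothetical subclass: one gradient step on random data per epoch."""

    def _train(self, epoch: int):
        x, y = torch.randn(8, 4), torch.randn(8, 1)
        loss = nn.functional.mse_loss(self.model(x), y)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def _eval(self, epoch: int):
        with torch.no_grad():
            _ = self.model(torch.randn(8, 4))  # evaluation placeholder


model = nn.Linear(4, 1)
trainer = ToyTrainer(max_epochs=2, model=model,
                     optimizer=torch.optim.SGD(model.parameters(), lr=0.1))
trainer.fit()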
@@ -75,8 +73,7 @@ class OnPolicyTrainer(ABC):
                 buffer: NaiveReplayBuffer,
                 sample_buffer: bool,
                 dataloader_pin_memory: bool,
                 callbacks: List[Callback] = []) -> None:
        super().__init__()
        self.strategy = strategy
        self.buffer = buffer
@@ -138,7 +135,7 @@ class OnPolicyTrainer(ABC):
    @abstractmethod
    def _learn(self, update_step: int):
        """
        Implement this method to learn from experience, either
        sample from buffer or transform buffer into dataloader.
        """
        raise NotImplementedError()
@@ -154,13 +151,14 @@ class OnPolicyTrainer(ABC):
            self._learn(update_step)
            self._on_learn_epoch_end(update_step)

    def fit(
        self,
        prompt_dataloader: DataLoader,
        pretrain_dataloader: DataLoader,
        num_episodes: int,
        num_collect_steps: int,
        num_update_steps: int,
    ):
        """
        The main training loop of on-policy rl trainers.
@@ -175,23 +173,16 @@ class OnPolicyTrainer(ABC):
        self.pretrain_dataloader = CycledDataLoader(pretrain_dataloader)

        with self._fit_ctx():
            for episode in tqdm.trange(num_episodes, desc="Episodes", disable=not is_rank_0()):
                with self._episode_ctx(episode):
                    for collect_step in tqdm.trange(num_collect_steps, desc="Collect steps", disable=not is_rank_0()):
                        self._collect_phase(collect_step)
                    if not self.sample_buffer:
                        # HACK(cwher): according to the design of boost API, dataloader should also be boosted,
                        # but it is impractical to adapt this pattern in RL training. Thus, I left dataloader unboosted.
                        # I only call strategy.setup_dataloader() to setup dataloader.
                        self.dataloader = self.strategy.setup_dataloader(self.buffer, self.dataloader_pin_memory)
                    for update_step in tqdm.trange(num_update_steps, desc="Update steps", disable=not is_rank_0()):
                        self._update_phase(update_step)
                    # NOTE: this is for on-policy algorithms
                    self.buffer.clear()
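The loop in this last hunk interleaves experience collection and learning: each episode runs num_collect_steps collection passes, then num_update_steps update passes, and finally clears the replay buffer, since on-policy algorithms should not learn from experience gathered under an older policy. A rough, self-contained sketch of that control flow (hypothetical class and hooks, not coati's actual API):

import tqdm


class MinimalOnPolicyLoop:
    """Illustrative skeleton mirroring the episode/collect/update structure above."""

    def __init__(self) -> None:
        self.buffer = []  # stand-in for the replay buffer

    def _collect_phase(self, collect_step: int):
        # In the real trainer this runs the experience maker and fills the buffer.
        self.buffer.append({"step": collect_step, "reward": 0.0})

    def _update_phase(self, update_step: int):
        # In the real trainer this consumes the buffer (directly or via a dataloader).
        _ = len(self.buffer)

    def fit(self, num_episodes: int, num_collect_steps: int, num_update_steps: int):
        for episode in tqdm.trange(num_episodes, desc="Episodes"):
            for collect_step in tqdm.trange(num_collect_steps, desc="Collect steps"):
                self._collect_phase(collect_step)
            for update_step in tqdm.trange(num_update_steps, desc="Update steps"):
                self._update_phase(update_step)
            # On-policy: discard this episode's experience after updating.
            self.buffer.clear()


MinimalOnPolicyLoop().fit(num_episodes=1, num_collect_steps=2, num_update_steps=2)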