OpenDAS / ColossalAI · Commits

Commit 798cb729
Authored Jul 18, 2023 by shenggan; committed Jul 26, 2023 by binmakeswell
[NFC] polish applications/Chat/coati/trainer/base.py code style (#4260)
Parent: b2debdc0

Showing 1 changed file with 22 additions and 31 deletions
applications/Chat/coati/trainer/base.py  (+22, -31)
@@ -25,12 +25,13 @@ class SLTrainer(ABC):
         optim (Optimizer): the optimizer to use for training
     """
 
-    def __init__(self,
-                 strategy: Strategy,
-                 max_epochs: int,
-                 model: nn.Module,
-                 optimizer: Optimizer,
-                 ) -> None:
+    def __init__(
+        self,
+        strategy: Strategy,
+        max_epochs: int,
+        model: nn.Module,
+        optimizer: Optimizer,
+    ) -> None:
         super().__init__()
         self.strategy = strategy
         self.max_epochs = max_epochs
@@ -50,10 +51,7 @@ class SLTrainer(ABC):
     def fit(self, *args, **kwargs):
         self._before_fit(*args, **kwargs)
-        for epoch in tqdm.trange(self.max_epochs,
-                                 desc="Epochs",
-                                 disable=not is_rank_0() or self.no_epoch_bar):
+        for epoch in tqdm.trange(self.max_epochs, desc="Epochs", disable=not is_rank_0() or self.no_epoch_bar):
             self._train(epoch)
             self._eval(epoch)
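The fit loop above is a template method: SLTrainer owns the epoch loop and silences the tqdm bar on non-zero ranks via is_rank_0(), while subclasses supply the _before_fit, _train, and _eval hooks. A minimal sketch of how a hypothetical subclass would plug in, assuming a model whose output exposes .loss and assuming backward/optimizer_step helpers on coati's Strategy; none of this is part of the commit:

class MySFTTrainer(SLTrainer):  # hypothetical subclass, for illustration only
    def _before_fit(self, train_dataloader):
        # fit(*args, **kwargs) forwards its arguments here before the epoch loop starts.
        self.train_dataloader = train_dataloader
        self.no_epoch_bar = False  # keep the epoch progress bar enabled

    def _train(self, epoch: int):
        self.model.train()
        for batch in self.train_dataloader:
            loss = self.model(**batch).loss  # assumes an HF-style output with .loss
            self.strategy.backward(loss, self.model, self.optimizer)  # assumed Strategy helper
            self.strategy.optimizer_step(self.optimizer)  # assumed Strategy helper
            self.optimizer.zero_grad()

    def _eval(self, epoch: int):
        pass  # evaluation is optional in this sketch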
@@ -75,8 +73,7 @@ class OnPolicyTrainer(ABC):
                  buffer: NaiveReplayBuffer,
                  sample_buffer: bool,
                  dataloader_pin_memory: bool,
-                 callbacks: List[Callback] = []
-                 ) -> None:
+                 callbacks: List[Callback] = []) -> None:
         super().__init__()
         self.strategy = strategy
         self.buffer = buffer
@@ -138,7 +135,7 @@ class OnPolicyTrainer(ABC):
     @abstractmethod
     def _learn(self, update_step: int):
         """
         Implement this method to learn from experience, either
         sample from buffer or transform buffer into dataloader.
         """
         raise NotImplementedError()
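As the docstring notes, _learn may either sample directly from the replay buffer or consume the dataloader that fit() builds from it. A sketch of the dataloader path, assuming torch is imported; the to_device() call and the _training_step helper are assumptions, not part of this file:

    def _learn(self, update_step: int):
        # Illustrative only: iterate the dataloader that fit() built from
        # the replay buffer via strategy.setup_dataloader().
        for experience in self.dataloader:
            experience.to_device(torch.cuda.current_device())  # to_device() is assumed
            self._training_step(experience)  # hypothetical per-batch update helper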
@@ -154,13 +151,14 @@ class OnPolicyTrainer(ABC):
             self._learn(update_step)
             self._on_learn_epoch_end(update_step)
 
-    def fit(self,
-            prompt_dataloader: DataLoader,
-            pretrain_dataloader: DataLoader,
-            num_episodes: int,
-            num_collect_steps: int,
-            num_update_steps: int,
-            ):
+    def fit(
+        self,
+        prompt_dataloader: DataLoader,
+        pretrain_dataloader: DataLoader,
+        num_episodes: int,
+        num_collect_steps: int,
+        num_update_steps: int,
+    ):
         """
         The main training loop of on-policy rl trainers.
@@ -175,23 +173,16 @@ class OnPolicyTrainer(ABC):
         self.pretrain_dataloader = CycledDataLoader(pretrain_dataloader)
 
         with self._fit_ctx():
-            for episode in tqdm.trange(num_episodes,
-                                       desc="Episodes",
-                                       disable=not is_rank_0()):
+            for episode in tqdm.trange(num_episodes, desc="Episodes", disable=not is_rank_0()):
                 with self._episode_ctx(episode):
-                    for collect_step in tqdm.trange(num_collect_steps,
-                                                    desc="Collect steps",
-                                                    disable=not is_rank_0()):
+                    for collect_step in tqdm.trange(num_collect_steps, desc="Collect steps", disable=not is_rank_0()):
                         self._collect_phase(collect_step)
                     if not self.sample_buffer:
                         # HACK(cwher): according to the design of boost API, dataloader should also be boosted,
                         # but it is impractical to adapt this pattern in RL training. Thus, I left dataloader unboosted.
                         # I only call strategy.setup_dataloader() to setup dataloader.
-                        self.dataloader = self.strategy.setup_dataloader(self.buffer,
-                                                                         self.dataloader_pin_memory)
+                        self.dataloader = self.strategy.setup_dataloader(self.buffer, self.dataloader_pin_memory)
-                    for update_step in tqdm.trange(num_update_steps,
-                                                   desc="Update steps",
-                                                   disable=not is_rank_0()):
+                    for update_step in tqdm.trange(num_update_steps, desc="Update steps", disable=not is_rank_0()):
                         self._update_phase(update_step)
                     # NOTE: this is for on-policy algorithms
                     self.buffer.clear()
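Read together with the fit signature above, the schedule is: each episode runs num_collect_steps collect phases to fill the replay buffer, optionally wraps the buffer in a dataloader, runs num_update_steps update phases, and then clears the buffer, which is what keeps training strictly on-policy. A hedged driver sketch; the PPOTrainer constructor arguments and the step counts are assumptions about the surrounding coati API, not shown in this diff:

trainer = PPOTrainer(strategy, actor, critic, reward_model)  # hypothetical arguments
trainer.fit(
    prompt_dataloader=prompt_dataloader,      # prompts for experience collection
    pretrain_dataloader=pretrain_dataloader,  # pretraining batches, cycled via CycledDataLoader
    num_episodes=10,      # outer loop
    num_collect_steps=8,  # experience-collection steps per episode
    num_update_steps=4,   # learning steps per episode; buffer is cleared afterwards
)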