Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
e6a132a4
"tests/git@developer.sourcefind.cn:OpenDAS/colossalai.git" did not exist on "307894f74dd63d71f4b95272fe149ca607e2aafa"
Unverified
Commit
e6a132a4
authored
Apr 11, 2023
by
zhang-yi-chi
Committed by
GitHub
Apr 11, 2023
Browse files
[chat]: add vf_coef argument for PPOTrainer (#3318)
parent
89fd10a1
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
5 additions
and
1 deletion
+5
-1
applications/Chat/coati/models/loss.py
applications/Chat/coati/models/loss.py
+1
-1
applications/Chat/coati/trainer/ppo.py
applications/Chat/coati/trainer/ppo.py
+4
-0
No files found.
applications/Chat/coati/models/loss.py
View file @
e6a132a4
...
@@ -65,7 +65,7 @@ class ValueLoss(nn.Module):
...
@@ -65,7 +65,7 @@ class ValueLoss(nn.Module):
surr2
=
(
values
-
reward
)
**
2
surr2
=
(
values
-
reward
)
**
2
loss
=
torch
.
max
(
surr1
,
surr2
)
loss
=
torch
.
max
(
surr1
,
surr2
)
loss
=
loss
.
mean
()
loss
=
loss
.
mean
()
return
loss
return
0.5
*
loss
class
PPOPtxActorLoss
(
nn
.
Module
):
class
PPOPtxActorLoss
(
nn
.
Module
):
...
...
applications/Chat/coati/trainer/ppo.py
View file @
e6a132a4
...
@@ -32,6 +32,7 @@ class PPOTrainer(Trainer):
...
@@ -32,6 +32,7 @@ class PPOTrainer(Trainer):
buffer_limit (int, defaults to 0): the max_size limitaiton of replay buffer
buffer_limit (int, defaults to 0): the max_size limitaiton of replay buffer
buffer_cpu_offload (bool, defaults to True): whether to offload replay buffer to cpu
buffer_cpu_offload (bool, defaults to True): whether to offload replay buffer to cpu
eps_clip (float, defaults to 0.2): the clip coefficient of policy loss
eps_clip (float, defaults to 0.2): the clip coefficient of policy loss
vf_coef (float, defaults to 1.0): the coefficient of value loss
value_clip (float, defaults to 0.4): the clip coefficient of value loss
value_clip (float, defaults to 0.4): the clip coefficient of value loss
experience_batch_size (int, defaults to 8): the batch size to use for experience generation
experience_batch_size (int, defaults to 8): the batch size to use for experience generation
max_epochs (int, defaults to 1): the number of epochs of training process
max_epochs (int, defaults to 1): the number of epochs of training process
...
@@ -56,6 +57,7 @@ class PPOTrainer(Trainer):
...
@@ -56,6 +57,7 @@ class PPOTrainer(Trainer):
buffer_limit
:
int
=
0
,
buffer_limit
:
int
=
0
,
buffer_cpu_offload
:
bool
=
True
,
buffer_cpu_offload
:
bool
=
True
,
eps_clip
:
float
=
0.2
,
eps_clip
:
float
=
0.2
,
vf_coef
:
float
=
1.0
,
value_clip
:
float
=
0.4
,
value_clip
:
float
=
0.4
,
experience_batch_size
:
int
=
8
,
experience_batch_size
:
int
=
8
,
max_epochs
:
int
=
1
,
max_epochs
:
int
=
1
,
...
@@ -74,6 +76,7 @@ class PPOTrainer(Trainer):
...
@@ -74,6 +76,7 @@ class PPOTrainer(Trainer):
self
.
actor_loss_fn
=
PolicyLoss
(
eps_clip
)
self
.
actor_loss_fn
=
PolicyLoss
(
eps_clip
)
self
.
critic_loss_fn
=
ValueLoss
(
value_clip
)
self
.
critic_loss_fn
=
ValueLoss
(
value_clip
)
self
.
vf_coef
=
vf_coef
self
.
ptx_loss_fn
=
nn
.
CrossEntropyLoss
(
ignore_index
=-
100
)
self
.
ptx_loss_fn
=
nn
.
CrossEntropyLoss
(
ignore_index
=-
100
)
self
.
ptx_coef
=
ptx_coef
self
.
ptx_coef
=
ptx_coef
self
.
actor_optim
=
actor_optim
self
.
actor_optim
=
actor_optim
...
@@ -112,6 +115,7 @@ class PPOTrainer(Trainer):
...
@@ -112,6 +115,7 @@ class PPOTrainer(Trainer):
experience
.
values
,
experience
.
values
,
experience
.
reward
,
experience
.
reward
,
action_mask
=
experience
.
action_mask
)
action_mask
=
experience
.
action_mask
)
critic_loss
=
critic_loss
*
self
.
vf_coef
self
.
strategy
.
backward
(
critic_loss
,
self
.
critic
,
self
.
critic_optim
)
self
.
strategy
.
backward
(
critic_loss
,
self
.
critic
,
self
.
critic_optim
)
self
.
strategy
.
optimizer_step
(
self
.
critic_optim
)
self
.
strategy
.
optimizer_step
(
self
.
critic_optim
)
self
.
critic_optim
.
zero_grad
()
self
.
critic_optim
.
zero_grad
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment