Unverified Commit e6a132a4 authored by zhang-yi-chi's avatar zhang-yi-chi Committed by GitHub
Browse files

[chat]: add vf_coef argument for PPOTrainer (#3318)

parent 89fd10a1
...@@ -65,7 +65,7 @@ class ValueLoss(nn.Module): ...@@ -65,7 +65,7 @@ class ValueLoss(nn.Module):
surr2 = (values - reward)**2 surr2 = (values - reward)**2
loss = torch.max(surr1, surr2) loss = torch.max(surr1, surr2)
loss = loss.mean() loss = loss.mean()
return loss return 0.5 * loss
class PPOPtxActorLoss(nn.Module): class PPOPtxActorLoss(nn.Module):
......
...@@ -32,6 +32,7 @@ class PPOTrainer(Trainer): ...@@ -32,6 +32,7 @@ class PPOTrainer(Trainer):
buffer_limit (int, defaults to 0): the max_size limitation of replay buffer buffer_limit (int, defaults to 0): the max_size limitation of replay buffer
buffer_cpu_offload (bool, defaults to True): whether to offload replay buffer to cpu buffer_cpu_offload (bool, defaults to True): whether to offload replay buffer to cpu
eps_clip (float, defaults to 0.2): the clip coefficient of policy loss eps_clip (float, defaults to 0.2): the clip coefficient of policy loss
vf_coef (float, defaults to 1.0): the coefficient of value loss
value_clip (float, defaults to 0.4): the clip coefficient of value loss value_clip (float, defaults to 0.4): the clip coefficient of value loss
experience_batch_size (int, defaults to 8): the batch size to use for experience generation experience_batch_size (int, defaults to 8): the batch size to use for experience generation
max_epochs (int, defaults to 1): the number of epochs of training process max_epochs (int, defaults to 1): the number of epochs of training process
...@@ -56,6 +57,7 @@ class PPOTrainer(Trainer): ...@@ -56,6 +57,7 @@ class PPOTrainer(Trainer):
buffer_limit: int = 0, buffer_limit: int = 0,
buffer_cpu_offload: bool = True, buffer_cpu_offload: bool = True,
eps_clip: float = 0.2, eps_clip: float = 0.2,
vf_coef: float = 1.0,
value_clip: float = 0.4, value_clip: float = 0.4,
experience_batch_size: int = 8, experience_batch_size: int = 8,
max_epochs: int = 1, max_epochs: int = 1,
...@@ -74,6 +76,7 @@ class PPOTrainer(Trainer): ...@@ -74,6 +76,7 @@ class PPOTrainer(Trainer):
self.actor_loss_fn = PolicyLoss(eps_clip) self.actor_loss_fn = PolicyLoss(eps_clip)
self.critic_loss_fn = ValueLoss(value_clip) self.critic_loss_fn = ValueLoss(value_clip)
self.vf_coef = vf_coef
self.ptx_loss_fn = nn.CrossEntropyLoss(ignore_index=-100) self.ptx_loss_fn = nn.CrossEntropyLoss(ignore_index=-100)
self.ptx_coef = ptx_coef self.ptx_coef = ptx_coef
self.actor_optim = actor_optim self.actor_optim = actor_optim
...@@ -112,6 +115,7 @@ class PPOTrainer(Trainer): ...@@ -112,6 +115,7 @@ class PPOTrainer(Trainer):
experience.values, experience.values,
experience.reward, experience.reward,
action_mask=experience.action_mask) action_mask=experience.action_mask)
critic_loss = critic_loss * self.vf_coef
self.strategy.backward(critic_loss, self.critic, self.critic_optim) self.strategy.backward(critic_loss, self.critic, self.critic_optim)
self.strategy.optimizer_step(self.critic_optim) self.strategy.optimizer_step(self.critic_optim)
self.critic_optim.zero_grad() self.critic_optim.zero_grad()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment