Unverified Commit b51bfec3 authored by wenjunyang, committed by GitHub

[chatgpt] change critic input as state (#3042)



* fix Critic

* fix Critic

* fix Critic

* fix neglect of attention mask

* fix neglect of attention mask

* fix neglect of attention mask

* add return

---------
Co-authored-by: yangwenjun <yangwenjun@soyoung.com>
Co-authored-by: yangwjd <yangwjd@chanjet.com>
parent 2ef855c7
@@ -36,12 +36,15 @@ class Critic(LoRAModule):
         outputs = self.model(sequences, attention_mask=attention_mask)
         last_hidden_states = outputs['last_hidden_state']
-        values = self.value_head(last_hidden_states).squeeze(-1)[:, :-1]
+        values = self.value_head(last_hidden_states).squeeze(-1)
         if action_mask is not None:
             num_actions = action_mask.size(1)
-            values = values[:, -num_actions:]
-            value = masked_mean(values, action_mask, dim=1)
+            prompt_mask = attention_mask[:, :-num_actions]
+            values = values[:, :-num_actions]
+            value = masked_mean(values, prompt_mask, dim=1)
             return value
+        values = values[:, :-1]
         value = values.mean(dim=1).squeeze(1)
         return value
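The change above makes the critic score the state (prompt) tokens instead of the action tokens: the action positions are sliced off the per-token values, and the mean is taken under the prompt portion of the attention mask so padding is ignored. A minimal sketch of that slicing and masked averaging, using dummy tensors and a local `masked_mean` helper (the shapes and values here are illustrative assumptions, not from the repository):

```python
import torch

def masked_mean(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch.Tensor:
    # Mean over `dim`, counting only positions where mask == 1.
    mask = mask.float()
    return (tensor * mask).sum(dim=dim) / mask.sum(dim=dim).clamp(min=1e-8)

# Hypothetical batch: 2 sequences of 6 tokens, the last 2 of which are actions.
values = torch.arange(12, dtype=torch.float32).reshape(2, 6)  # value head output per token
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1],
                               [0, 1, 1, 1, 1, 1]])           # second prompt is left-padded
action_mask = torch.ones(2, 2, dtype=torch.long)

num_actions = action_mask.size(1)
prompt_mask = attention_mask[:, :-num_actions]  # mask over the state (prompt) tokens only
state_values = values[:, :-num_actions]         # drop the action positions
value = masked_mean(state_values, prompt_mask, dim=1)
```

For the first sequence the prompt values are `[0, 1, 2, 3]` with all positions valid, so the masked mean is 1.5; for the second, the leading pad token is excluded, giving the mean of `[7, 8, 9]`, i.e. 8.0. Without the `prompt_mask`, padded positions would drag the state value toward zero, which is the attention-mask neglect the commit message refers to.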