OpenDAS / ColossalAI / Commits

Commit 02ae80bf (unverified)
Authored Mar 10, 2023 by Fazzie-Maqianli, committed by GitHub on Mar 10, 2023

[chatgpt] add flag of action mask in critic (#3086)

Parent: 95a36eae

Showing 5 changed files with 21 additions and 14 deletions (+21 -14):
  applications/ChatGPT/chatgpt/models/base/actor.py          +1  -1
  applications/ChatGPT/chatgpt/models/base/critic.py         +11 -7
  applications/ChatGPT/chatgpt/models/bloom/bloom_critic.py  +3  -2
  applications/ChatGPT/chatgpt/models/gpt/gpt_critic.py      +3  -2
  applications/ChatGPT/chatgpt/models/opt/opt_critic.py      +3  -2
applications/ChatGPT/chatgpt/models/base/actor.py

@@ -37,7 +37,7 @@ class Actor(LoRAModule):
         if pad_token_id is not None:
             attention_mask = sequences.not_equal(pad_token_id).to(dtype=torch.long, device=sequences.device)
         if not return_action_mask:
-            return sequences, attention_mask
+            return sequences, attention_mask, None
         input_len = input_ids.size(1)
         eos_token_id = kwargs.get('eos_token_id', None)
         if eos_token_id is None:
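With this change, Actor.generate returns a three-element tuple even when return_action_mask is False (the third element is simply None), so callers can unpack the result uniformly. A minimal calling sketch, assuming an already-constructed actor and a batch of prompt token ids (both are placeholders, not taken from the diff; the extra keyword arguments are standard generation options, not prescribed by the commit):

    # `actor` is any Actor subclass from this package; `input_ids` is a LongTensor
    # of prompt tokens. Both are illustrative assumptions.
    sequences, attention_mask, action_mask = actor.generate(
        input_ids,
        return_action_mask=False,   # after this commit the call still returns 3 values
        max_length=128,
        pad_token_id=0,
        eos_token_id=2,
    )
    assert action_mask is None      # no action mask is computed on this code path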
applications/ChatGPT/chatgpt/models/base/critic.py

@@ -18,15 +18,19 @@ class Critic(LoRAModule):
         lora_train_bias (str): LoRA bias training mode.
     """
 
-    def __init__(self,
-                 model: nn.Module,
-                 value_head: nn.Module,
-                 lora_rank: int = 0,
-                 lora_train_bias: str = 'none') -> None:
+    def __init__(
+        self,
+        model: nn.Module,
+        value_head: nn.Module,
+        lora_rank: int = 0,
+        lora_train_bias: str = 'none',
+        use_action_mask: bool = False,
+    ) -> None:
         super().__init__(lora_rank=lora_rank, lora_train_bias=lora_train_bias)
         self.model = model
         self.value_head = value_head
+        self.use_action_mask = use_action_mask
         self.convert_to_lora()
 
     def forward(self,

@@ -38,7 +42,7 @@ class Critic(LoRAModule):
         values = self.value_head(last_hidden_states).squeeze(-1)
 
-        if action_mask is not None:
+        if action_mask is not None and self.use_action_mask:
             num_actions = action_mask.size(1)
             prompt_mask = attention_mask[:, :-num_actions]
             values = values[:, :-num_actions]

@@ -46,5 +50,5 @@ class Critic(LoRAModule):
             return value
 
         values = values[:, :-1]
-        value = values.mean(dim=1).squeeze(1)
+        value = values.mean(dim=1)
         return value
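The base-class change is backward compatible: the masking branch in forward now runs only when an action_mask is passed and the critic was constructed with use_action_mask=True, so existing callers that never set the flag keep the old behaviour (the default code path also drops the erroneous squeeze(1) after the mean). A minimal construction sketch, assuming Critic is importable from chatgpt.models.base, the module this diff touches; the backbone and hidden size below are placeholders:

    import torch.nn as nn
    from chatgpt.models.base import Critic   # assumed export path

    backbone = nn.Identity()        # placeholder only; a real backbone is a transformer model
    value_head = nn.Linear(16, 1)   # arbitrary toy hidden size

    critic_default = Critic(backbone, value_head)                        # use_action_mask defaults to False
    critic_masked = Critic(backbone, value_head, use_action_mask=True)   # opt in to the action-mask branch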
applications/ChatGPT/chatgpt/models/bloom/bloom_critic.py

@@ -24,7 +24,8 @@ class BLOOMCritic(Critic):
                  config: Optional[BloomConfig] = None,
                  checkpoint: bool = False,
                  lora_rank: int = 0,
-                 lora_train_bias: str = 'none') -> None:
+                 lora_train_bias: str = 'none',
+                 **kwargs) -> None:
         if pretrained is not None:
             model = BloomModel.from_pretrained(pretrained)
         elif config is not None:

@@ -34,4 +35,4 @@ class BLOOMCritic(Critic):
         if checkpoint:
             model.gradient_checkpointing_enable()
         value_head = nn.Linear(model.config.hidden_size, 1)
-        super().__init__(model, value_head, lora_rank, lora_train_bias)
+        super().__init__(model, value_head, lora_rank, lora_train_bias, **kwargs)
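The model-specific critics are changed only to forward extra keyword arguments to the base class, so the new flag can be passed straight through their constructors; the GPT and OPT critics below follow the same pattern. A hedged sketch (the checkpoint name is illustrative and not prescribed by the commit; passing a BloomConfig via config= would avoid any download):

    from chatgpt.models.bloom import BLOOMCritic   # assumed export path

    # use_action_mask travels through **kwargs into Critic.__init__.
    critic = BLOOMCritic(
        pretrained='bigscience/bloom-560m',   # illustrative checkpoint
        lora_rank=0,
        use_action_mask=True,
    )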
applications/ChatGPT/chatgpt/models/gpt/gpt_critic.py

@@ -20,7 +20,8 @@ class GPTCritic(Critic):
     def __init__(self,
                  pretrained: Optional[str] = None,
                  config: Optional[GPT2Config] = None,
-                 checkpoint: bool = False) -> None:
+                 checkpoint: bool = False,
+                 **kwargs) -> None:
         if pretrained is not None:
             model = GPT2Model.from_pretrained(pretrained)
         elif config is not None:

@@ -30,4 +31,4 @@ class GPTCritic(Critic):
         if checkpoint:
             model.gradient_checkpointing_enable()
         value_head = nn.Linear(model.config.n_embd, 1)
-        super().__init__(model, value_head)
+        super().__init__(model, value_head, **kwargs)
applications/ChatGPT/chatgpt/models/opt/opt_critic.py

@@ -24,7 +24,8 @@ class OPTCritic(Critic):
                  config: Optional[OPTConfig] = None,
                  checkpoint: bool = False,
                  lora_rank: int = 0,
-                 lora_train_bias: str = 'none') -> None:
+                 lora_train_bias: str = 'none',
+                 **kwargs) -> None:
         if pretrained is not None:
             model = OPTModel.from_pretrained(pretrained)
         elif config is not None:

@@ -34,4 +35,4 @@ class OPTCritic(Critic):
         if checkpoint:
             model.gradient_checkpointing_enable()
         value_head = nn.Linear(model.config.hidden_size, 1)
-        super().__init__(model, value_head, lora_rank, lora_train_bias)
+        super().__init__(model, value_head, lora_rank, lora_train_bias, **kwargs)
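Putting the pieces together, a critic built from a small config can exercise the new flag end to end. A hedged sketch using GPTCritic (the tiny GPT2Config and all tensor shapes are invented for illustration; only the use_action_mask flag itself comes from this commit):

    import torch
    from transformers import GPT2Config
    from chatgpt.models.gpt import GPTCritic   # assumed export path

    # A tiny randomly initialised GPT-2 backbone so nothing has to be downloaded.
    config = GPT2Config(n_layer=2, n_head=2, n_embd=32, vocab_size=1000)
    critic = GPTCritic(config=config, use_action_mask=True)   # flag forwarded via **kwargs

    sequences = torch.randint(0, 1000, (2, 10))        # prompt + generated tokens
    attention_mask = torch.ones_like(sequences)
    action_mask = torch.ones(2, 4, dtype=torch.long)   # last 4 tokens were generated

    # With use_action_mask=True the prompt/action split is taken into account when
    # pooling values; with the default False the action_mask argument is ignored.
    values = critic(sequences, action_mask=action_mask, attention_mask=attention_mask)
    print(values.shape)   # expected: one value per sequence, e.g. torch.Size([2])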