OpenDAS / ColossalAI · Commits · 0672b5af

Unverified commit 0672b5af, authored Mar 13, 2023 by BlueRum; committed by GitHub on Mar 13, 2023.

[chatgpt] fix lora support for gpt (#3113)

* fix gpt-actor
* fix gpt-critic
* fix opt-critic
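For context, this commit threads the two LoRA hyperparameters through to the shared base classes instead of dropping them. Both follow the loralib convention: lora_rank is the rank r of the low-rank update (0 disables LoRA and keeps full fine-tuning), and lora_train_bias selects which biases remain trainable ('none', 'all', or 'lora_only'). Below is a minimal sketch of what a LoRA-aware base class typically does with these arguments; the LoRAModule name and the convert_to_lora/_replace_linears helpers are illustrative assumptions, not this repository's exact API.

import torch.nn as nn
import loralib as lora

class LoRAModule(nn.Module):
    """Illustrative stand-in for the base class that Actor/Critic extend."""

    def __init__(self, lora_rank: int = 0, lora_train_bias: str = 'none') -> None:
        super().__init__()
        self.lora_rank = lora_rank
        self.lora_train_bias = lora_train_bias

    def convert_to_lora(self) -> None:
        # lora_rank == 0 is the "LoRA disabled" default: full fine-tuning.
        if self.lora_rank <= 0:
            return
        self._replace_linears(self)
        # Freeze everything except the LoRA A/B matrices; lora_train_bias
        # ('none', 'all', 'lora_only') controls which biases stay trainable.
        lora.mark_only_lora_as_trainable(self, bias=self.lora_train_bias)

    def _replace_linears(self, module: nn.Module) -> None:
        for name, child in module.named_children():
            if isinstance(child, nn.Linear):
                # Swap each Linear for a loralib Linear of rank r, keeping
                # the pretrained weight as the frozen base.
                new_layer = lora.Linear(child.in_features, child.out_features,
                                        r=self.lora_rank,
                                        bias=child.bias is not None)
                new_layer.weight.data = child.weight.data
                if child.bias is not None:
                    new_layer.bias.data = child.bias.data
                setattr(module, name, new_layer)
            else:
                self._replace_linears(child)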
Parent: 0aa92c04
Showing 3 changed files with 12 additions and 5 deletions.
applications/ChatGPT/chatgpt/models/gpt/gpt_actor.py    +6 −2
applications/ChatGPT/chatgpt/models/gpt/gpt_critic.py   +5 −2
applications/ChatGPT/chatgpt/models/opt/opt_critic.py   +1 −1
applications/ChatGPT/chatgpt/models/gpt/gpt_actor.py

@@ -14,12 +14,16 @@ class GPTActor(Actor):
         pretrained (str): Pretrained model name or path.
         config (GPT2Config): Model config.
         checkpoint (bool): Enable gradient checkpointing.
+        lora_rank (int): Rank of the LoRa layer.
+        lora_train_bias (str): Bias training strategy for the LoRa layer.
     """

     def __init__(self,
                  pretrained: Optional[str] = None,
                  config: Optional[GPT2Config] = None,
-                 checkpoint: bool = False) -> None:
+                 checkpoint: bool = False,
+                 lora_rank: int = 0,
+                 lora_train_bias: str = 'none') -> None:
         if pretrained is not None:
             model = GPT2LMHeadModel.from_pretrained(pretrained)
         elif config is not None:
@@ -28,4 +32,4 @@ class GPTActor(Actor):
             model = GPT2LMHeadModel(GPT2Config())
         if checkpoint:
             model.gradient_checkpointing_enable()
-        super().__init__(model)
+        super().__init__(model, lora_rank, lora_train_bias)
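With the fix applied, GPTActor actually accepts the LoRA arguments; under the old signature, passing lora_rank raised a TypeError because __init__ had no such parameter and no **kwargs. A usage sketch, assuming the package's usual import path and using 'gpt2' as a stand-in checkpoint name:

from chatgpt.models.gpt import GPTActor

# Rank-8 LoRA adapters on the GPT-2 policy; no biases are trained.
actor = GPTActor(pretrained='gpt2', lora_rank=8, lora_train_bias='none')

# The default lora_rank=0 preserves the old behaviour: full-parameter
# fine-tuning with no adapters injected.
full_ft_actor = GPTActor(pretrained='gpt2')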
applications/ChatGPT/chatgpt/models/gpt/gpt_critic.py

@@ -15,13 +15,16 @@ class GPTCritic(Critic):
         pretrained (str): Pretrained model name or path.
         config (GPT2Config): Model config.
         checkpoint (bool): Enable gradient checkpointing.
+        lora_rank (int): Rank of the LO-RA decomposition.
+        lora_train_bias (str): LoRA bias training mode.
     """

     def __init__(self,
                  pretrained: Optional[str] = None,
                  config: Optional[GPT2Config] = None,
                  checkpoint: bool = False,
-                 **kwargs) -> None:
+                 lora_rank: int = 0,
+                 lora_train_bias: str = 'none') -> None:
         if pretrained is not None:
             model = GPT2Model.from_pretrained(pretrained)
         elif config is not None:
@@ -31,4 +34,4 @@ class GPTCritic(Critic):
         if checkpoint:
             model.gradient_checkpointing_enable()
         value_head = nn.Linear(model.config.n_embd, 1)
-        super().__init__(model, value_head, **kwargs)
+        super().__init__(model, value_head, lora_rank, lora_train_bias)
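The critic gets the same threading plus one cleanup: the opaque **kwargs catch-all is replaced by explicit, documented, defaulted parameters, so a misspelled LoRA argument now fails loudly instead of being silently swallowed. A usage sketch under the same import-path assumption as above:

from chatgpt.models.gpt import GPTCritic

# 'lora_only' trains the biases inside the LoRA-adapted layers as well;
# the nn.Linear(n_embd, 1) value head maps hidden states to scalar scores.
critic = GPTCritic(pretrained='gpt2', lora_rank=8, lora_train_bias='lora_only')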
applications/ChatGPT/chatgpt/models/opt/opt_critic.py

@@ -34,5 +34,5 @@ class OPTCritic(Critic):
             model = OPTModel(OPTConfig())
         if checkpoint:
             model.gradient_checkpointing_enable()
-        value_head = nn.Linear(model.config.hidden_size, 1)
+        value_head = nn.Linear(model.config.word_embed_proj_dim, 1)
         super().__init__(model, value_head, lora_rank, lora_train_bias, **kwargs)
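The opt_critic change is independent of LoRA. Some OPT checkpoints project their decoder hidden states down to a narrower output width, so OPTModel's last_hidden_state is word_embed_proj_dim wide rather than hidden_size wide; a value head sized on hidden_size shape-mismatches on exactly those checkpoints. A quick check against the published facebook/opt-350m config illustrates the difference:

from transformers import OPTConfig

config = OPTConfig.from_pretrained('facebook/opt-350m')
print(config.hidden_size)          # 1024: width inside the decoder layers
print(config.word_embed_proj_dim)  # 512: width of last_hidden_state

# nn.Linear(config.hidden_size, 1) would hit a matmul shape error on the
# (batch, seq, word_embed_proj_dim) output above; sizing the head with
# word_embed_proj_dim works for every OPT variant, since the two dims are
# equal whenever no output projection is configured.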