OpenDAS / ColossalAI · Commits

Unverified commit 1e1b9d2f, authored Mar 22, 2023 by Fazzie-Maqianli, committed by GitHub on Mar 22, 2023

[chatgpt]support llama (#3070)
Parent: e3ad88fb

Showing 4 changed files with 126 additions and 0 deletions (+126, -0)
applications/ChatGPT/chatgpt/models/llama/__init__.py      +5   -0
applications/ChatGPT/chatgpt/models/llama/llama_actor.py   +38  -0
applications/ChatGPT/chatgpt/models/llama/llama_critic.py  +42  -0
applications/ChatGPT/chatgpt/models/llama/llama_rm.py      +41  -0
applications/ChatGPT/chatgpt/models/llama/__init__.py (new file, mode 100644)

```python
from .llama_actor import LlamaActor
from .llama_critic import LlamaCritic
from .llama_rm import LlamaRM

__all__ = ['LlamaActor', 'LlamaCritic', 'LlamaRM']
```
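With this `__init__.py`, the three wrappers are importable directly from the package. A one-line usage sketch, assuming the `chatgpt` package root (under `applications/ChatGPT`) is on `PYTHONPATH`:

```python
from chatgpt.models.llama import LlamaActor, LlamaCritic, LlamaRM
```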
applications/ChatGPT/chatgpt/models/llama/llama_actor.py (new file, mode 100644)

```python
from typing import Optional

import torch
from transformers import AutoModelForCausalLM, LlamaConfig, LlamaForCausalLM

from ..base import Actor


class LlamaActor(Actor):
    """
    Llama Actor model.

    Args:
        pretrained (str): Pretrained model name or path.
        config (LlamaConfig): Model config.
        checkpoint (bool): Enable gradient checkpointing.
        lora_rank (int): LoRA rank.
        lora_train_bias (str): LoRA bias training mode.
    """

    def __init__(self,
                 pretrained: Optional[str] = None,
                 config: Optional[LlamaConfig] = None,
                 checkpoint: bool = False,
                 lora_rank: int = 0,
                 lora_train_bias: str = 'none') -> None:
        if pretrained is not None:
            model = LlamaForCausalLM.from_pretrained(pretrained)
        elif config is not None:
            model = LlamaForCausalLM(config)
        else:
            model = LlamaForCausalLM(LlamaConfig())
        if checkpoint:
            model.gradient_checkpointing_enable()
        super().__init__(model, lora_rank, lora_train_bias)
```
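For orientation, a hypothetical construction sketch (not part of this commit) exercising the three loading paths the constructor supports; the pretrained path string is a placeholder, and the tiny config values are chosen only so the example is cheap to run:

```python
from transformers import LlamaConfig

from chatgpt.models.llama import LlamaActor

# 1. Load converted LLaMA weights from a local path (placeholder path).
actor = LlamaActor(pretrained='path/to/llama-7b-hf')

# 2. Build from an explicit config, with gradient checkpointing and LoRA rank 8.
tiny_config = LlamaConfig(hidden_size=256, num_hidden_layers=2, num_attention_heads=4)
actor = LlamaActor(config=tiny_config, checkpoint=True, lora_rank=8)

# 3. No arguments: falls back to a default LlamaConfig().
actor = LlamaActor()
```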
applications/ChatGPT/chatgpt/models/llama/llama_critic.py (new file, mode 100644)

```python
from typing import Optional

import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM, LlamaConfig, LlamaForCausalLM

from ..base import Critic


class LlamaCritic(Critic):
    """
    Llama Critic model.

    Args:
        pretrained (str): Pretrained model name or path.
        config (LlamaConfig): Model config.
        checkpoint (bool): Enable gradient checkpointing.
        lora_rank (int): LoRA rank.
        lora_train_bias (str): LoRA bias training mode.
    """

    def __init__(self,
                 pretrained: Optional[str] = None,
                 config: Optional[LlamaConfig] = None,
                 checkpoint: bool = False,
                 lora_rank: int = 0,
                 lora_train_bias: str = 'none',
                 **kwargs) -> None:
        if pretrained is not None:
            model = LlamaForCausalLM.from_pretrained(pretrained)
        elif config is not None:
            model = LlamaForCausalLM(config)
        else:
            model = LlamaForCausalLM(LlamaConfig())
        if checkpoint:
            model.gradient_checkpointing_enable()
        value_head = nn.Linear(model.config.hidden_size, 1)
        super().__init__(model, value_head, lora_rank, lora_train_bias, **kwargs)
```
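The critic is the same backbone plus a scalar value head, `nn.Linear(hidden_size, 1)`. A standalone sketch of what that head computes, assuming the base `Critic` (not shown in this diff) applies it to the backbone's last hidden states:

```python
import torch
import torch.nn as nn

hidden_size = 4096  # e.g. LLaMA-7B; illustrative only
value_head = nn.Linear(hidden_size, 1)

# Hidden states as they would come out of the backbone: (batch, seq_len, hidden).
hidden_states = torch.randn(2, 16, hidden_size)
values = value_head(hidden_states).squeeze(-1)  # (batch, seq_len): one value per token
```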
applications/ChatGPT/chatgpt/models/llama/llama_rm.py (new file, mode 100644)

```python
from typing import Optional

import torch.nn as nn
from transformers import LlamaConfig, LlamaForCausalLM

from ..base import RewardModel


class LlamaRM(RewardModel):
    """
    Llama Reward model.

    Args:
        pretrained (str): Pretrained model name or path.
        config (LlamaConfig): Model config.
        checkpoint (bool): Enable gradient checkpointing.
        lora_rank (int): LoRA rank.
        lora_train_bias (str): LoRA bias training mode.
    """

    def __init__(self,
                 pretrained: Optional[str] = None,
                 config: Optional[LlamaConfig] = None,
                 checkpoint: bool = False,
                 lora_rank: int = 0,
                 lora_train_bias: str = 'none') -> None:
        if pretrained is not None:
            model = LlamaForCausalLM.from_pretrained(pretrained)
        elif config is not None:
            model = LlamaForCausalLM(config)
        else:
            model = LlamaForCausalLM(LlamaConfig())
        if checkpoint:
            model.gradient_checkpointing_enable()
        value_head = nn.Linear(model.config.hidden_size, 1)
        value_head.weight.data.normal_(mean=0.0, std=1 / (model.config.hidden_size + 1))
        # Hand the model and its initialized value head to the RewardModel base
        # together with the LoRA settings, matching the Critic above.
        super().__init__(model, value_head, lora_rank, lora_train_bias)
```
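Unlike the critic, the reward model re-initializes its value head with a small standard deviation of 1 / (hidden_size + 1), which keeps the untrained reward output near zero. A quick numerical check under an assumed hidden size:

```python
import torch
import torch.nn as nn

hidden_size = 4096  # assumed for illustration
value_head = nn.Linear(hidden_size, 1)
value_head.weight.data.normal_(mean=0.0, std=1 / (hidden_size + 1))

# For roughly unit-scale inputs, the output std is about
# sqrt(hidden_size) / (hidden_size + 1) ~= 1 / sqrt(hidden_size) ~= 0.016,
# so initial rewards start close to zero.
x = torch.randn(8, hidden_size)
print(value_head(x).std().item())
```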