Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
b0ce5a10
Unverified
Commit
b0ce5a10
authored
Mar 28, 2023
by
Fazzie-Maqianli
Committed by
GitHub
Mar 28, 2023
Browse files
[Coati] first commit (#3283)
parent
fd6add57
Changes
106
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1012 additions
and
0 deletions
+1012
-0
applications/Chat/coati/models/llama/__init__.py
applications/Chat/coati/models/llama/__init__.py
+6
-0
applications/Chat/coati/models/llama/llama_actor.py
applications/Chat/coati/models/llama/llama_actor.py
+38
-0
applications/Chat/coati/models/llama/llama_critic.py
applications/Chat/coati/models/llama/llama_critic.py
+42
-0
applications/Chat/coati/models/llama/llama_lm.py
applications/Chat/coati/models/llama/llama_lm.py
+40
-0
applications/Chat/coati/models/llama/llama_rm.py
applications/Chat/coati/models/llama/llama_rm.py
+40
-0
applications/Chat/coati/models/lora.py
applications/Chat/coati/models/lora.py
+129
-0
applications/Chat/coati/models/loss.py
applications/Chat/coati/models/loss.py
+117
-0
applications/Chat/coati/models/opt/__init__.py
applications/Chat/coati/models/opt/__init__.py
+6
-0
applications/Chat/coati/models/opt/opt_actor.py
applications/Chat/coati/models/opt/opt_actor.py
+35
-0
applications/Chat/coati/models/opt/opt_critic.py
applications/Chat/coati/models/opt/opt_critic.py
+38
-0
applications/Chat/coati/models/opt/opt_lm.py
applications/Chat/coati/models/opt/opt_lm.py
+35
-0
applications/Chat/coati/models/opt/opt_rm.py
applications/Chat/coati/models/opt/opt_rm.py
+38
-0
applications/Chat/coati/models/utils.py
applications/Chat/coati/models/utils.py
+92
-0
applications/Chat/coati/replay_buffer/__init__.py
applications/Chat/coati/replay_buffer/__init__.py
+4
-0
applications/Chat/coati/replay_buffer/base.py
applications/Chat/coati/replay_buffer/base.py
+43
-0
applications/Chat/coati/replay_buffer/naive.py
applications/Chat/coati/replay_buffer/naive.py
+57
-0
applications/Chat/coati/replay_buffer/utils.py
applications/Chat/coati/replay_buffer/utils.py
+73
-0
applications/Chat/coati/trainer/__init__.py
applications/Chat/coati/trainer/__init__.py
+6
-0
applications/Chat/coati/trainer/base.py
applications/Chat/coati/trainer/base.py
+168
-0
applications/Chat/coati/trainer/callbacks/__init__.py
applications/Chat/coati/trainer/callbacks/__init__.py
+5
-0
No files found.
applications/Chat/coati/models/llama/__init__.py
0 → 100644
View file @
b0ce5a10
from
.llama_actor
import
LlamaActor
from
.llama_critic
import
LlamaCritic
from
.llama_lm
import
LlamaLM
from
.llama_rm
import
LlamaRM
__all__
=
[
'LlamaActor'
,
'LlamaCritic'
,
'LlamaRM'
,
'LlamaLM'
]
applications/Chat/coati/models/llama/llama_actor.py
0 → 100644
View file @
b0ce5a10
from
typing
import
Optional
import
torch
from
transformers
import
AutoModelForCausalLM
,
LlamaConfig
,
LlamaForCausalLM
from
..base
import
Actor
class LlamaActor(Actor):
    """
    Llama Actor model.

    Args:
        pretrained (str): Pretrained model name or path.
        config (LlamaConfig): Model config.
        checkpoint (bool): Enable gradient checkpointing.
        lora_rank (int): LoRA rank.
        lora_train_bias (str): LoRA bias training mode.
    """

    def __init__(self,
                 pretrained: Optional[str] = None,
                 config: Optional[LlamaConfig] = None,
                 checkpoint: bool = False,
                 lora_rank: int = 0,
                 lora_train_bias: str = 'none') -> None:
        # Resolve the backbone: pretrained weights take precedence over an
        # explicit config; otherwise fall back to a default LlamaConfig.
        if pretrained is not None:
            model = LlamaForCausalLM.from_pretrained(pretrained)
        else:
            model = LlamaForCausalLM(config if config is not None else LlamaConfig())

        if checkpoint:
            model.gradient_checkpointing_enable()

        super().__init__(model, lora_rank, lora_train_bias)
applications/Chat/coati/models/llama/llama_critic.py
0 → 100644
View file @
b0ce5a10
from
typing
import
Optional
import
torch
import
torch.nn
as
nn
from
transformers
import
AutoModelForCausalLM
,
LlamaConfig
,
LlamaForCausalLM
from
..base
import
Critic
class LlamaCritic(Critic):
    """
    Llama Critic model.

    Args:
        pretrained (str): Pretrained model name or path.
        config (LlamaConfig): Model config.
        checkpoint (bool): Enable gradient checkpointing.
        lora_rank (int): LoRA rank.
        lora_train_bias (str): LoRA bias training mode.
    """

    def __init__(self,
                 pretrained: Optional[str] = None,
                 config: Optional[LlamaConfig] = None,
                 checkpoint: bool = False,
                 lora_rank: int = 0,
                 lora_train_bias: str = 'none',
                 **kwargs) -> None:
        # Pretrained path wins; then explicit config; then the library default.
        if pretrained is not None:
            model = LlamaForCausalLM.from_pretrained(pretrained)
        else:
            model = LlamaForCausalLM(config if config is not None else LlamaConfig())

        if checkpoint:
            model.gradient_checkpointing_enable()

        # Scalar value head on top of the backbone hidden states.
        value_head = nn.Linear(model.config.hidden_size, 1)
        super().__init__(model, value_head, lora_rank, lora_train_bias, **kwargs)
applications/Chat/coati/models/llama/llama_lm.py
0 → 100644
View file @
b0ce5a10
from
typing
import
Optional
from
transformers
import
LlamaConfig
,
LlamaForCausalLM
from
..base
import
LM
class LlamaLM(LM):
    """
    Llama language model.

    Args:
        pretrained (str): Pretrained model name or path.
        config (LlamaConfig): Model config.
        checkpoint (bool): Enable gradient checkpointing.
        lora_rank (int): LoRA rank.
        lora_train_bias (str): LoRA bias training mode.
    """

    def __init__(self,
                 pretrained: Optional[str] = None,
                 config: Optional[LlamaConfig] = None,
                 checkpoint: bool = False,
                 lora_rank: int = 0,
                 lora_train_bias: str = 'none') -> None:
        # Pretrained path wins; then explicit config; then the library default.
        if pretrained is not None:
            model = LlamaForCausalLM.from_pretrained(pretrained)
        else:
            model = LlamaForCausalLM(config if config is not None else LlamaConfig())

        if checkpoint:
            model.gradient_checkpointing_enable()

        super().__init__(model, lora_rank, lora_train_bias)

    def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
        """Delegate straight to the wrapped HF model (supplying ``labels`` makes
        the HF model compute its own LM loss)."""
        return self.model(input_ids, attention_mask=attention_mask, labels=labels, **kwargs)
applications/Chat/coati/models/llama/llama_rm.py
0 → 100644
View file @
b0ce5a10
from
typing
import
Optional
import
torch.nn
as
nn
from
transformers
import
LlamaConfig
,
LlamaForCausalLM
,
LlamaModel
from
..base
import
RewardModel
class LlamaRM(RewardModel):
    """
    Llama Reward model.

    Args:
        pretrained (str): Pretrained model name or path.
        config (LlamaConfig): Model config.
        checkpoint (bool): Enable gradient checkpointing.
        lora_rank (int): LoRA rank.
        lora_train_bias (str): LoRA bias training mode.
    """

    def __init__(self,
                 pretrained: Optional[str] = None,
                 config: Optional[LlamaConfig] = None,
                 checkpoint: bool = False,
                 lora_rank: int = 0,
                 lora_train_bias: str = 'none') -> None:
        # The reward model uses the bare backbone (LlamaModel), not the CausalLM head.
        if pretrained is not None:
            model = LlamaModel.from_pretrained(pretrained)
        else:
            model = LlamaModel(config if config is not None else LlamaConfig())

        if checkpoint:
            model.gradient_checkpointing_enable()

        # Scalar reward head, initialized with std = 1 / (hidden_size + 1).
        value_head = nn.Linear(model.config.hidden_size, 1)
        value_head.weight.data.normal_(mean=0.0, std=1 / (model.config.hidden_size + 1))
        super().__init__(model, value_head, lora_rank, lora_train_bias)
applications/Chat/coati/models/lora.py
0 → 100644
View file @
b0ce5a10
import
math
from
typing
import
Optional
import
loralib
as
lora
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
class LoraLinear(lora.LoRALayer, nn.Module):
    """LoRA-augmented linear layer built on an existing weight/bias pair.

    Replace in-place ops to out-of-place ops to fit gemini. Convert a
    torch.nn.Linear to LoraLinear.

    Args:
        weight: The base (frozen when ``r > 0``) weight of shape (out_features, in_features).
        bias: Optional base bias.
        r: LoRA rank; 0 disables the LoRA branch entirely.
        lora_alpha: Numerator of the LoRA scaling factor (scale = alpha / r).
        lora_dropout: Dropout applied to the LoRA branch input.
        fan_in_fan_out: Set True if the layer to replace stores weight like (fan_in, fan_out).
        merge_weights: Merge the LoRA delta into ``weight`` when entering eval mode.
    """

    def __init__(
        self,
        weight: nn.Parameter,
        bias: Optional[nn.Parameter],
        r: int = 0,
        lora_alpha: int = 1,
        lora_dropout: float = 0.,
        fan_in_fan_out: bool = False,
        merge_weights: bool = True,
    ):
        nn.Module.__init__(self)
        lora.LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout,
                                merge_weights=merge_weights)
        self.weight = weight
        self.bias = bias

        out_features, in_features = weight.shape
        self.in_features = in_features
        self.out_features = out_features

        self.fan_in_fan_out = fan_in_fan_out
        # Actual trainable parameters
        if r > 0:
            self.lora_A = nn.Parameter(self.weight.new_zeros((r, in_features)))
            self.lora_B = nn.Parameter(self.weight.new_zeros((out_features, r)))
            self.scaling = self.lora_alpha / self.r
            # Freezing the pre-trained weight matrix
            self.weight.requires_grad = False
        self.reset_parameters()
        if fan_in_fan_out:
            self.weight.data = self.weight.data.T

    def reset_parameters(self):
        """Re-initialize the LoRA factors (A random, B zero, so the initial delta is 0)."""
        if hasattr(self, 'lora_A'):
            # initialize A the same way as the default for nn.Linear and B to zero
            nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
            nn.init.zeros_(self.lora_B)

    def train(self, mode: bool = True):
        """Enter train mode; un-merge the LoRA delta from ``weight`` if it was merged."""

        def T(w):
            return w.T if self.fan_in_fan_out else w

        nn.Module.train(self, mode)
        if self.merge_weights and self.merged:
            # Make sure that the weights are not merged
            if self.r > 0:
                self.weight.data -= T(self.lora_B @ self.lora_A) * self.scaling
            self.merged = False

    def eval(self):
        """Enter eval mode; merge the LoRA delta into ``weight`` for faster inference."""

        def T(w):
            return w.T if self.fan_in_fan_out else w

        nn.Module.eval(self)
        if self.merge_weights and not self.merged:
            # Merge the weights and mark it
            if self.r > 0:
                self.weight.data += T(self.lora_B @ self.lora_A) * self.scaling
            # BUG FIX: keep lora_A/lora_B instead of delattr-ing them. The old
            # code deleted both after merging, so a subsequent .train() crashed
            # with AttributeError when trying to un-merge.
            self.merged = True

    def forward(self, x: torch.Tensor):
        def T(w):
            return w.T if self.fan_in_fan_out else w

        if self.r > 0 and not self.merged:
            result = F.linear(x, T(self.weight), bias=self.bias)
            # Out-of-place add (not +=) to stay gemini-compatible. The inner
            # `if self.r > 0` of the original was redundant here.
            result = result + (self.lora_dropout(x) @ self.lora_A.t() @ self.lora_B.t()) * self.scaling
            return result
        else:
            # Either LoRA is disabled (r == 0) or the delta is already merged.
            return F.linear(x, T(self.weight), bias=self.bias)
def lora_linear_wrapper(linear: nn.Linear, lora_rank: int) -> LoraLinear:
    """Wrap an existing ``nn.Linear`` as a ``LoraLinear`` sharing its weight and bias."""
    assert lora_rank <= linear.in_features, \
        f'LoRA rank ({lora_rank}) must be less than or equal to in features ({linear.in_features})'
    # merge_weights=False keeps the base weight and the LoRA delta separate.
    return LoraLinear(linear.weight, linear.bias, r=lora_rank, merge_weights=False)
def convert_to_lora_recursively(module: nn.Module, lora_rank: int) -> None:
    """Depth-first walk over ``module``, replacing every ``nn.Linear`` child in place
    with a ``LoraLinear`` of the given rank."""
    for name, child in module.named_children():
        if not isinstance(child, nn.Linear):
            convert_to_lora_recursively(child, lora_rank)
        else:
            setattr(module, name, lora_linear_wrapper(child, lora_rank))
class LoRAModule(nn.Module):
    """Base class for modules that support LoRA conversion.

    Derived classes must call ``convert_to_lora()`` at the bottom of ``__init__()``;
    it swaps every ``torch.nn.Linear`` in the module tree for a ``LoraLinear``.

    Args:
        lora_rank (int, optional): LoRA rank. 0 means LoRA is not applied. Defaults to 0.
        lora_train_bias (str, optional): Which biases remain trainable: 'none' trains
            no biases, 'all' trains every bias, 'lora_only' trains only biases of
            LoRA layers. Defaults to 'none'.
    """

    def __init__(self, lora_rank: int = 0, lora_train_bias: str = 'none') -> None:
        super().__init__()
        self.lora_rank = lora_rank
        self.lora_train_bias = lora_train_bias

    def convert_to_lora(self) -> None:
        # Rank <= 0 disables LoRA entirely.
        if self.lora_rank <= 0:
            return
        convert_to_lora_recursively(self, self.lora_rank)
        lora.mark_only_lora_as_trainable(self, self.lora_train_bias)
applications/Chat/coati/models/loss.py
0 → 100644
View file @
b0ce5a10
from
typing
import
Optional
import
torch
import
torch.nn
as
nn
from
.utils
import
masked_mean
class GPTLMLoss(nn.Module):
    """
    GPT Language Model Loss: next-token cross-entropy over shifted logits/labels.
    """

    def __init__(self):
        super().__init__()
        self.loss = nn.CrossEntropyLoss()

    def forward(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        # Align the prediction at position t with the token at position t + 1.
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        # Flatten the tokens
        flat_logits = shift_logits.view(-1, shift_logits.size(-1))
        flat_labels = shift_labels.view(-1)
        return self.loss(flat_logits, flat_labels)
class PolicyLoss(nn.Module):
    """
    Policy Loss for PPO: clipped surrogate objective.
    """

    def __init__(self, clip_eps: float = 0.2) -> None:
        super().__init__()
        self.clip_eps = clip_eps

    def forward(self,
                log_probs: torch.Tensor,
                old_log_probs: torch.Tensor,
                advantages: torch.Tensor,
                action_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        # Importance ratio pi_new / pi_old.
        ratio = torch.exp(log_probs - old_log_probs)
        unclipped = ratio * advantages
        clipped = torch.clamp(ratio, 1 - self.clip_eps, 1 + self.clip_eps) * advantages
        # Pessimistic (min) surrogate, negated for gradient descent.
        loss = -torch.min(unclipped, clipped)
        if action_mask is not None:
            loss = masked_mean(loss, action_mask)
        return loss.mean()
class ValueLoss(nn.Module):
    """
    Value Loss for PPO: max of clipped and unclipped squared error.
    """

    def __init__(self, clip_eps: float = 0.4) -> None:
        super().__init__()
        self.clip_eps = clip_eps

    def forward(self,
                values: torch.Tensor,
                old_values: torch.Tensor,
                reward: torch.Tensor,
                action_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        # NOTE(review): action_mask is accepted but never used here — confirm intended.
        # Limit how far the new value estimate may move from the old one.
        clipped_values = old_values + torch.clamp(values - old_values, -self.clip_eps, self.clip_eps)
        clipped_err = (clipped_values - reward) ** 2
        plain_err = (values - reward) ** 2
        # Pessimistic (max) of the two squared errors, averaged over the batch.
        return torch.max(clipped_err, plain_err).mean()
class PPOPtxActorLoss(nn.Module):
    """
    To Do:

    PPO-ptx Actor Loss: PPO policy loss plus a weighted pretraining LM loss.

    Args:
        policy_clip_eps: Clip epsilon for the PPO policy loss.
        pretrain_coef: Weight of the pretraining LM loss term.
        pretrain_loss_fn: Loss applied to ``lm_logits``/``lm_input_ids``.
            Defaults to a fresh ``GPTLMLoss`` per instance.
    """

    def __init__(self,
                 policy_clip_eps: float = 0.2,
                 pretrain_coef: float = 0.0,
                 pretrain_loss_fn=None) -> None:
        super().__init__()
        self.pretrain_coef = pretrain_coef
        self.policy_loss_fn = PolicyLoss(clip_eps=policy_clip_eps)
        # BUG FIX: the default used to be ``pretrain_loss_fn=GPTLMLoss()`` — a
        # mutable default argument, so every PPOPtxActorLoss instance shared one
        # GPTLMLoss module. Build a fresh instance per object instead.
        self.pretrain_loss_fn = GPTLMLoss() if pretrain_loss_fn is None else pretrain_loss_fn

    def forward(self,
                log_probs: torch.Tensor,
                old_log_probs: torch.Tensor,
                advantages: torch.Tensor,
                lm_logits: torch.Tensor,
                lm_input_ids: torch.Tensor,
                action_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        policy_loss = self.policy_loss_fn(log_probs, old_log_probs, advantages, action_mask=action_mask)
        lm_loss = self.pretrain_loss_fn(lm_logits, lm_input_ids)
        return policy_loss + self.pretrain_coef * lm_loss
class LogSigLoss(nn.Module):
    """
    Pairwise Loss for Reward Model
    Details: https://arxiv.org/abs/2203.02155
    """

    def forward(self, chosen_reward: torch.Tensor, reject_reward: torch.Tensor) -> torch.Tensor:
        # NUMERICAL FIX: torch.log(torch.sigmoid(x)) underflows to -inf for large
        # negative margins; logsigmoid computes log(sigmoid(x)) stably.
        return -torch.nn.functional.logsigmoid(chosen_reward - reject_reward).mean()
class LogExpLoss(nn.Module):
    """
    Pairwise Loss for Reward Model
    Details: https://arxiv.org/abs/2204.05862
    """

    def forward(self, chosen_reward: torch.Tensor, reject_reward: torch.Tensor) -> torch.Tensor:
        # NUMERICAL FIX: log(1 + exp(x)) overflows for large x; softplus computes
        # the same quantity stably.
        return torch.nn.functional.softplus(reject_reward - chosen_reward).mean()
applications/Chat/coati/models/opt/__init__.py
0 → 100644
View file @
b0ce5a10
from
.opt_actor
import
OPTActor
from
.opt_critic
import
OPTCritic
from
.opt_lm
import
OPTLM
from
.opt_rm
import
OPTRM
__all__
=
[
'OPTActor'
,
'OPTCritic'
,
'OPTRM'
,
'OPTLM'
]
applications/Chat/coati/models/opt/opt_actor.py
0 → 100644
View file @
b0ce5a10
from
typing
import
Optional
from
transformers.models.opt.configuration_opt
import
OPTConfig
from
transformers.models.opt.modeling_opt
import
OPTForCausalLM
from
..base
import
Actor
class OPTActor(Actor):
    """
    OPT Actor model.

    Args:
        pretrained (str): Pretrained model name or path.
        config (OPTConfig): Model config.
        checkpoint (bool): Enable gradient checkpointing.
        lora_rank (int): Rank of the low-rank approximation.
        lora_train_bias (str): LoRA bias training mode.
    """

    def __init__(self,
                 pretrained: Optional[str] = None,
                 config: Optional[OPTConfig] = None,
                 checkpoint: bool = False,
                 lora_rank: int = 0,
                 lora_train_bias: str = 'none') -> None:
        # Pretrained path wins; then explicit config; then the library default.
        if pretrained is not None:
            model = OPTForCausalLM.from_pretrained(pretrained)
        else:
            model = OPTForCausalLM(config if config is not None else OPTConfig())

        if checkpoint:
            model.gradient_checkpointing_enable()

        super().__init__(model, lora_rank, lora_train_bias)
applications/Chat/coati/models/opt/opt_critic.py
0 → 100644
View file @
b0ce5a10
from
typing
import
Optional
import
torch.nn
as
nn
from
transformers.models.opt.configuration_opt
import
OPTConfig
from
transformers.models.opt.modeling_opt
import
OPTModel
from
..base
import
Critic
class OPTCritic(Critic):
    """
    OPT Critic model.

    Args:
        pretrained (str): Pretrained model name or path.
        config (OPTConfig): Model config.
        checkpoint (bool): Enable gradient checkpointing.
        lora_rank (int): Rank of the low-rank approximation.
        lora_train_bias (str): LoRA bias training mode.
    """

    def __init__(self,
                 pretrained: Optional[str] = None,
                 config: Optional[OPTConfig] = None,
                 checkpoint: bool = False,
                 lora_rank: int = 0,
                 lora_train_bias: str = 'none',
                 **kwargs) -> None:
        # The critic uses the bare backbone (OPTModel), not the CausalLM head.
        if pretrained is not None:
            model = OPTModel.from_pretrained(pretrained)
        else:
            model = OPTModel(config if config is not None else OPTConfig())

        if checkpoint:
            model.gradient_checkpointing_enable()

        # Scalar value head sized by OPT's output projection dimension.
        value_head = nn.Linear(model.config.word_embed_proj_dim, 1)
        super().__init__(model, value_head, lora_rank, lora_train_bias, **kwargs)
applications/Chat/coati/models/opt/opt_lm.py
0 → 100644
View file @
b0ce5a10
from
typing
import
Optional
from
transformers.models.opt.configuration_opt
import
OPTConfig
from
transformers.models.opt.modeling_opt
import
OPTForCausalLM
from
..base
import
LM
class OPTLM(LM):
    """
    OPT language model.

    Args:
        pretrained (str): Pretrained model name or path.
        config (OPTConfig): Model config.
        checkpoint (bool): Enable gradient checkpointing.
        lora_rank (int): Rank of the low-rank approximation.
        lora_train_bias (str): LoRA bias training mode.
    """

    def __init__(self,
                 pretrained: Optional[str] = None,
                 config: Optional[OPTConfig] = None,
                 checkpoint: bool = False,
                 lora_rank: int = 0,
                 lora_train_bias: str = 'none') -> None:
        # Pretrained path wins; then explicit config; then the library default.
        if pretrained is not None:
            model = OPTForCausalLM.from_pretrained(pretrained)
        else:
            model = OPTForCausalLM(config if config is not None else OPTConfig())

        if checkpoint:
            model.gradient_checkpointing_enable()

        super().__init__(model, lora_rank, lora_train_bias)
applications/Chat/coati/models/opt/opt_rm.py
0 → 100644
View file @
b0ce5a10
from
typing
import
Optional
import
torch.nn
as
nn
from
transformers
import
OPTConfig
,
OPTModel
from
..base
import
RewardModel
class OPTRM(RewardModel):
    """
    OPT Reward model.

    Args:
        pretrained (str): Pretrained model name or path.
        config (OPTConfig): Model config.
        checkpoint (bool): Enable gradient checkpointing.
        lora_rank (int): Rank of the low-rank approximation.
        lora_train_bias (str): LoRA bias training mode.
    """

    def __init__(self,
                 pretrained: Optional[str] = None,
                 config: Optional[OPTConfig] = None,
                 checkpoint: bool = False,
                 lora_rank: int = 0,
                 lora_train_bias: str = 'none') -> None:
        # The reward model uses the bare backbone (OPTModel), not the CausalLM head.
        if pretrained is not None:
            model = OPTModel.from_pretrained(pretrained)
        else:
            model = OPTModel(config if config is not None else OPTConfig())

        if checkpoint:
            model.gradient_checkpointing_enable()

        # Scalar reward head, initialized with std = 1 / (proj_dim + 1).
        value_head = nn.Linear(model.config.word_embed_proj_dim, 1)
        value_head.weight.data.normal_(mean=0.0, std=1 / (model.config.word_embed_proj_dim + 1))
        super().__init__(model, value_head, lora_rank, lora_train_bias)
applications/Chat/coati/models/utils.py
0 → 100644
View file @
b0ce5a10
from
typing
import
Optional
,
Union
import
loralib
as
lora
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
def compute_approx_kl(log_probs: torch.Tensor,
                      log_probs_base: torch.Tensor,
                      action_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
    """
    Compute the approximate KL divergence between two distributions.
    Schulman blog: http://joschu.net/blog/kl-approx.html

    Args:
        log_probs: Log probabilities of the new distribution.
        log_probs_base: Log probabilities of the base distribution.
        action_mask: Mask for actions.
    """
    log_ratio = log_probs - log_probs_base
    # k3 estimator: (r - 1) - log r, with r = exp(log_ratio).
    ratio = log_ratio.exp()
    approx_kl = (ratio - 1) - log_ratio
    if action_mask is not None:
        return masked_mean(approx_kl, action_mask, dim=1)
    return approx_kl.mean(dim=1)
def compute_reward(r: Union[torch.Tensor, float],
                   kl_coef: float,
                   log_probs: torch.Tensor,
                   log_probs_base: torch.Tensor,
                   action_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Return the raw reward penalized by the approximate KL to the base policy.

    A non-positive ``kl_coef`` disables the penalty entirely.
    """
    if kl_coef <= 0.0:
        return r
    penalty = kl_coef * compute_approx_kl(log_probs, log_probs_base, action_mask=action_mask)
    return r - penalty
def log_probs_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Return the log-probability of each label token under ``logits``."""
    log_probs = F.log_softmax(logits, dim=-1)
    # Pick out the entry at each label index along the vocab dimension.
    per_label = torch.gather(log_probs, -1, labels.unsqueeze(-1))
    return per_label.squeeze(-1)
def masked_mean(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch.Tensor:
    """Mean of ``tensor`` over ``dim``, counting only positions where ``mask`` is set.

    A small epsilon in the denominator guards against an all-zero mask.
    """
    masked = tensor * mask
    total = masked.sum(dim=dim)
    count = mask.sum(dim=dim)
    return total / (count + 1e-8)
def masked_normalize(tensor: torch.Tensor,
                     mask: torch.Tensor,
                     dim: int = 1,
                     eps: float = 1e-8) -> torch.Tensor:
    """Standardize ``tensor`` over ``dim`` using only masked positions.

    Returns (tensor - mean) / sqrt(var), with the variance clamped to ``eps``.
    """
    tensor = tensor * mask
    mean = masked_mean(tensor, mask, dim=dim)
    centered = tensor - mean
    var = masked_mean(centered ** 2, mask, dim=dim)
    return centered * var.clamp(min=eps).rsqrt()
def normalize(tensor: torch.Tensor, dim: int = 0, eps: float = 1e-8) -> torch.Tensor:
    """Standardize ``tensor`` over ``dim``: (x - mean) / sqrt(var), var clamped to ``eps``."""
    mean = tensor.mean(dim)
    centered = tensor - mean
    var = (centered ** 2).mean(dim)
    return centered * var.clamp(min=eps).rsqrt()
def convert_to_lora(model: nn.Module,
                    input_size: int,
                    output_size: int,
                    lora_rank: int = 16,
                    lora_alpha: int = 1,
                    lora_dropout: float = 0.,
                    fan_in_fan_out: bool = False,
                    merge_weights: bool = True):
    """Replace every ``nn.Linear`` in ``model`` with a fresh ``loralib.Linear``.

    Args:
        model: Module tree to convert in place.
        input_size: Input features of the replacement LoRA linear layers.
        output_size: Output features of the replacement LoRA linear layers.
        lora_rank: LoRA rank; must not exceed min(input_size, output_size).
        lora_alpha: LoRA scaling numerator.
        lora_dropout: Dropout applied to the LoRA branch.
        fan_in_fan_out: Set True if the replaced layer stores weight like (fan_in, fan_out).
        merge_weights: Merge the LoRA delta into the base weight at eval time.

    Raises:
        ValueError: If ``lora_rank`` exceeds ``min(input_size, output_size)``.
    """
    if lora_rank > min(input_size, output_size):
        raise ValueError(f"LoRA rank {lora_rank} must be less or equal than {min(input_size, output_size)}")
    # BUG FIX: the previous version did ``module._modules[name] = ...`` on the
    # Linear module itself, keyed by its dotted path — that merely attached a
    # child to the Linear and never replaced anything. Replace each Linear on
    # its parent instead.
    for module in model.modules():
        for child_name, child in module.named_children():
            if isinstance(child, nn.Linear):
                setattr(
                    module, child_name,
                    lora.Linear(input_size,
                                output_size,
                                r=lora_rank,
                                lora_alpha=lora_alpha,
                                lora_dropout=lora_dropout,
                                fan_in_fan_out=fan_in_fan_out,
                                merge_weights=merge_weights))
applications/Chat/coati/replay_buffer/__init__.py
0 → 100644
View file @
b0ce5a10
from
.base
import
ReplayBuffer
from
.naive
import
NaiveReplayBuffer
__all__
=
[
'ReplayBuffer'
,
'NaiveReplayBuffer'
]
applications/Chat/coati/replay_buffer/base.py
0 → 100644
View file @
b0ce5a10
from
abc
import
ABC
,
abstractmethod
from
typing
import
Any
from
coati.experience_maker.base
import
Experience
class ReplayBuffer(ABC):
    """Replay buffer base class. It stores experience.

    Args:
        sample_batch_size (int): Batch size when sampling.
        limit (int, optional): Limit of number of experience samples. A number <= 0 means unlimited. Defaults to 0.
    """

    def __init__(self, sample_batch_size: int, limit: int = 0) -> None:
        super().__init__()
        self.sample_batch_size = sample_batch_size
        # limit <= 0 means unlimited
        self.limit = limit

    @abstractmethod
    def append(self, experience: Experience) -> None:
        """Store one batch of experience."""

    @abstractmethod
    def clear(self) -> None:
        """Drop all stored experience."""

    @abstractmethod
    def sample(self) -> Experience:
        """Return one randomly sampled batch of experience."""

    @abstractmethod
    def __len__(self) -> int:
        """Number of stored experience items."""

    @abstractmethod
    def __getitem__(self, idx: int) -> Any:
        """Return the stored item at ``idx``."""

    @abstractmethod
    def collate_fn(self, batch: Any) -> Experience:
        """Collate a list of stored items into a batched Experience."""
applications/Chat/coati/replay_buffer/naive.py
0 → 100644
View file @
b0ce5a10
import
random
from
typing
import
List
import
torch
from
coati.experience_maker.base
import
Experience
from
.base
import
ReplayBuffer
from
.utils
import
BufferItem
,
make_experience_batch
,
split_experience_batch
class NaiveReplayBuffer(ReplayBuffer):
    """Naive replay buffer class. It stores experience.

    Args:
        sample_batch_size (int): Batch size when sampling.
        limit (int, optional): Limit of number of experience samples. A number <= 0 means unlimited. Defaults to 0.
        cpu_offload (bool, optional): Whether to offload experience to cpu when sampling. Defaults to True.
    """

    def __init__(self, sample_batch_size: int, limit: int = 0, cpu_offload: bool = True) -> None:
        super().__init__(sample_batch_size, limit)
        self.cpu_offload = cpu_offload
        self.target_device = torch.device(f'cuda:{torch.cuda.current_device()}')
        # TODO(ver217): add prefetch
        self.items: List[BufferItem] = []

    @torch.no_grad()
    def append(self, experience: Experience) -> None:
        """Split a batched experience into items and store them, evicting the oldest
        entries when the buffer exceeds its limit."""
        if self.cpu_offload:
            experience.to_device(torch.device('cpu'))
        self.items.extend(split_experience_batch(experience))
        if self.limit > 0:
            overflow = len(self.items) - self.limit
            if overflow > 0:
                # Drop the oldest items first.
                self.items = self.items[overflow:]

    def clear(self) -> None:
        self.items.clear()

    @torch.no_grad()
    def sample(self) -> Experience:
        """Draw ``sample_batch_size`` items uniformly and collate them into a batch."""
        chosen = random.sample(self.items, self.sample_batch_size)
        experience = make_experience_batch(chosen)
        if self.cpu_offload:
            experience.to_device(self.target_device)
        return experience

    def __len__(self) -> int:
        return len(self.items)

    def __getitem__(self, idx: int) -> BufferItem:
        return self.items[idx]

    def collate_fn(self, batch) -> Experience:
        return make_experience_batch(batch)
applications/Chat/coati/replay_buffer/utils.py
0 → 100644
View file @
b0ce5a10
from
dataclasses
import
dataclass
from
typing
import
List
,
Optional
import
torch
import
torch.nn.functional
as
F
from
coati.experience_maker.base
import
Experience
@dataclass
class BufferItem:
    """BufferItem is an item of experience data.

    Shapes of each tensor:
        sequences: (S)
        action_log_probs: (A)
        values: (1)
        reward: (1)
        advantages: (1)
        attention_mask: (S)
        action_mask: (A)

    "A" is the number of actions.
    """
    sequences: torch.Tensor
    action_log_probs: torch.Tensor
    values: torch.Tensor
    reward: torch.Tensor
    advantages: torch.Tensor
    attention_mask: Optional[torch.LongTensor]
    action_mask: Optional[torch.BoolTensor]
def split_experience_batch(experience: Experience) -> List[BufferItem]:
    """Split a batched Experience (batch dim first) into one BufferItem per sample."""
    batch_size = experience.sequences.size(0)
    batch_kwargs = [{} for _ in range(batch_size)]
    keys = ('sequences', 'action_log_probs', 'values', 'reward', 'advantages', 'attention_mask', 'action_mask')
    for key in keys:
        value = getattr(experience, key)
        if isinstance(value, torch.Tensor):
            # Slice the batch dimension away.
            unbound = torch.unbind(value)
        else:
            # A None field is replicated for every sample.
            unbound = [value for _ in range(batch_size)]
        assert batch_size == len(unbound)
        for kwargs, v in zip(batch_kwargs, unbound):
            kwargs[key] = v
    return [BufferItem(**kwargs) for kwargs in batch_kwargs]
def zero_pad_sequences(sequences: List[torch.Tensor], side: str = 'left') -> torch.Tensor:
    """Zero-pad 1-D tensors to a common length and stack them along a new batch dim.

    Args:
        sequences: Tensors of possibly different lengths along dim 0.
        side: 'left' or 'right' — which end receives the padding.
    """
    assert side in ('left', 'right')
    max_len = max(seq.size(0) for seq in sequences)

    def _pad(seq: torch.Tensor) -> torch.Tensor:
        gap = seq.size(0)
        gap = max_len - gap
        spec = (gap, 0) if side == 'left' else (0, gap)
        return F.pad(seq, spec)

    return torch.stack([_pad(seq) for seq in sequences], dim=0)
def make_experience_batch(items: List[BufferItem]) -> Experience:
    """Collate per-sample BufferItems back into one batched Experience."""
    to_pad_keys = {'action_log_probs', 'action_mask'}
    keys = ('sequences', 'action_log_probs', 'values', 'reward', 'advantages', 'attention_mask', 'action_mask')
    kwargs = {}
    for key in keys:
        vals = [getattr(item, key) for item in items]
        # Variable-length action fields are zero-padded; everything else stacks directly.
        if key in to_pad_keys:
            kwargs[key] = zero_pad_sequences(vals)
        else:
            kwargs[key] = torch.stack(vals, dim=0)
    return Experience(**kwargs)
applications/Chat/coati/trainer/__init__.py
0 → 100644
View file @
b0ce5a10
from
.base
import
Trainer
from
.ppo
import
PPOTrainer
from
.rm
import
RewardModelTrainer
from
.sft
import
SFTTrainer
__all__
=
[
'Trainer'
,
'PPOTrainer'
,
'RewardModelTrainer'
,
'SFTTrainer'
]
applications/Chat/coati/trainer/base.py
0 → 100644
View file @
b0ce5a10
from
abc
import
ABC
,
abstractmethod
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
,
Union
import
torch
from
coati.experience_maker
import
Experience
,
ExperienceMaker
from
coati.replay_buffer
import
ReplayBuffer
from
torch
import
Tensor
from
torch.utils.data
import
DistributedSampler
from
tqdm
import
tqdm
from
.callbacks
import
Callback
from
.strategies
import
Strategy
from
.utils
import
is_rank_0
class Trainer(ABC):
    """
    Base class for rlhf trainers.

    Args:
        strategy (Strategy): the strategy to use for training
        experience_maker (ExperienceMaker): the experience maker used to produce experience to fulfill the replay buffer
        replay_buffer (ReplayBuffer): the replay buffer to use for training
        experience_batch_size (int, defaults to 8): the batch size to use for experience generation
        max_epochs (int, defaults to 1): the number of epochs of training process
        tokenizer (Callable, optional): the tokenizer to use for tokenizing the input
        sample_replay_buffer (bool, defaults to False): whether to sample from replay buffer
        dataloader_pin_memory (bool, defaults to True): whether to pin memory for data loader
        callbacks (List[Callback], defaults to []): the callbacks to call during training process
        generate_kwargs (dict, optional): the kwargs to use while model generating
    """

    def __init__(self,
                 strategy: Strategy,
                 experience_maker: ExperienceMaker,
                 replay_buffer: ReplayBuffer,
                 experience_batch_size: int = 8,
                 max_epochs: int = 1,
                 tokenizer: Optional[Callable[[Any], dict]] = None,
                 sample_replay_buffer: bool = False,
                 dataloader_pin_memory: bool = True,
                 callbacks: List[Callback] = [],
                 **generate_kwargs) -> None:
        # NOTE(review): ``callbacks=[]`` is a mutable default argument. It is
        # only iterated (never mutated) in this class, but a ``None`` default
        # with a per-instance list would be safer.
        super().__init__()
        self.strategy = strategy
        self.experience_maker = experience_maker
        self.replay_buffer = replay_buffer
        self.experience_batch_size = experience_batch_size
        self.max_epochs = max_epochs
        self.tokenizer = tokenizer
        # Remaining keyword args are forwarded verbatim to make_experience().
        self.generate_kwargs = generate_kwargs
        self.sample_replay_buffer = sample_replay_buffer
        self.dataloader_pin_memory = dataloader_pin_memory
        self.callbacks = callbacks

    @abstractmethod
    def training_step(self, experience: Experience) -> Dict[str, Any]:
        """Run one optimization step on a batch of experience; return metric values by name."""
        pass

    def _make_experience(self, inputs: Union[Tensor, Dict[str, Tensor]]) -> Experience:
        """Produce experience from prompts; a bare tensor is passed positionally,
        a dict is splatted as keyword arguments."""
        if isinstance(inputs, Tensor):
            return self.experience_maker.make_experience(inputs, **self.generate_kwargs)
        elif isinstance(inputs, dict):
            return self.experience_maker.make_experience(**inputs, **self.generate_kwargs)
        else:
            raise ValueError(f'Unsupported input type "{type(inputs)}"')

    def _sample_prompts(self, prompts) -> list:
        """Sample ``experience_batch_size`` distinct prompts via the strategy's sampler."""
        indices = list(range(len(prompts)))
        sampled_indices = self.strategy.experience_sampler.choice(indices, self.experience_batch_size, replace=False)
        return [prompts[i] for i in sampled_indices]

    def _learn(self):
        """Run ``max_epochs`` of optimization over the current replay buffer contents."""
        # replay buffer may be empty at first, we should rebuild at each training
        if not self.sample_replay_buffer:
            dataloader = self.strategy.setup_dataloader(self.replay_buffer, self.dataloader_pin_memory)
        device = torch.cuda.current_device()
        if self.sample_replay_buffer:
            # Random-sampling path: one sampled batch per epoch; epoch/batch
            # callbacks are not invoked on this path.
            pbar = tqdm(range(self.max_epochs), desc='Train epoch', disable=not is_rank_0())
            for _ in pbar:
                experience = self.replay_buffer.sample()
                metrics = self.training_step(experience)
                pbar.set_postfix(metrics)
        else:
            for epoch in range(self.max_epochs):
                self._on_learn_epoch_start(epoch)
                if isinstance(dataloader.sampler, DistributedSampler):
                    # Reshuffle the distributed shards each epoch.
                    dataloader.sampler.set_epoch(epoch)
                pbar = tqdm(dataloader, desc=f'Train epoch [{epoch + 1}/{self.max_epochs}]', disable=not is_rank_0())
                for experience in pbar:
                    self._on_learn_batch_start()
                    # Move the (possibly CPU-offloaded) batch onto the training device.
                    experience.to_device(device)
                    metrics = self.training_step(experience)
                    self._on_learn_batch_end(metrics, experience)
                    pbar.set_postfix(metrics)
                self._on_learn_epoch_end(epoch)

    def fit(self,
            prompt_dataloader,
            pretrain_dataloader,
            num_episodes: int = 50000,
            max_timesteps: int = 500,
            update_timesteps: int = 5000) -> None:
        """Main RLHF loop: gather experience every timestep and run a learning
        phase every ``update_timesteps`` timesteps.

        Args:
            prompt_dataloader: yields prompt batches for experience generation.
            pretrain_dataloader: stored on ``self`` for subclasses (e.g. ptx loss).
            num_episodes: number of outer episodes.
            max_timesteps: experience-collection steps per episode.
            update_timesteps: collect this many steps between learning phases.
        """
        time = 0
        self.pretrain_dataloader = pretrain_dataloader
        self.prompt_dataloader = prompt_dataloader
        self._on_fit_start()
        for episode in range(num_episodes):
            self._on_episode_start(episode)
            for timestep in tqdm(range(max_timesteps),
                                 desc=f'Episode [{episode + 1}/{num_episodes}]',
                                 disable=not is_rank_0()):
                time += 1
                # NOTE(review): next(iter(...)) restarts the dataloader every
                # step — presumably intentional for endless prompt cycling; verify.
                prompts = next(iter(self.prompt_dataloader))
                self._on_make_experience_start()
                # Move the frozen reference models onto the GPU only while
                # generating experience; they are pushed back to CPU below.
                self.experience_maker.initial_model.to(torch.cuda.current_device())
                self.experience_maker.reward_model.to(torch.cuda.current_device())
                experience = self._make_experience(prompts)
                self._on_make_experience_end(experience)
                self.replay_buffer.append(experience)
                if time % update_timesteps == 0:
                    # Free GPU memory for training, learn, then start a fresh buffer.
                    self.experience_maker.initial_model.to('cpu')
                    self.experience_maker.reward_model.to('cpu')
                    self._learn()
                    self.replay_buffer.clear()
            self._on_episode_end(episode)
        self._on_fit_end()

    # TODO(ver217): maybe simplify these code using context
    def _on_fit_start(self) -> None:
        for callback in self.callbacks:
            callback.on_fit_start()

    def _on_fit_end(self) -> None:
        for callback in self.callbacks:
            callback.on_fit_end()

    def _on_episode_start(self, episode: int) -> None:
        for callback in self.callbacks:
            callback.on_episode_start(episode)

    def _on_episode_end(self, episode: int) -> None:
        for callback in self.callbacks:
            callback.on_episode_end(episode)

    def _on_make_experience_start(self) -> None:
        for callback in self.callbacks:
            callback.on_make_experience_start()

    def _on_make_experience_end(self, experience: Experience) -> None:
        for callback in self.callbacks:
            callback.on_make_experience_end(experience)

    def _on_learn_epoch_start(self, epoch: int) -> None:
        for callback in self.callbacks:
            callback.on_learn_epoch_start(epoch)

    def _on_learn_epoch_end(self, epoch: int) -> None:
        for callback in self.callbacks:
            callback.on_learn_epoch_end(epoch)

    def _on_learn_batch_start(self) -> None:
        for callback in self.callbacks:
            callback.on_learn_batch_start()

    def _on_learn_batch_end(self, metrics: dict, experience: Experience) -> None:
        for callback in self.callbacks:
            callback.on_learn_batch_end(metrics, experience)
applications/Chat/coati/trainer/callbacks/__init__.py
0 → 100644
View file @
b0ce5a10
from
.base
import
Callback
from
.performance_evaluator
import
PerformanceEvaluator
from
.save_checkpoint
import
SaveCheckpoint
__all__
=
[
'Callback'
,
'PerformanceEvaluator'
,
'SaveCheckpoint'
]
Prev
1
2
3
4
5
6
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment