OpenDAS / ColossalAI · Commits

Commit 1b347010 (unverified)
Authored Feb 14, 2023 by ver217, committed via GitHub on Feb 14, 2023
Parent: c3abdd08

[app] add chatgpt application (#2698)

This commit changes 64 files in total; this page (1 of 4) shows 20 changed files with 1179 additions and 0 deletions (+1179 -0).
Changed files on this page:

applications/ChatGPT/chatgpt/nn/generation_utils.py                       +92  -0
applications/ChatGPT/chatgpt/nn/gpt_actor.py                              +31  -0
applications/ChatGPT/chatgpt/nn/gpt_critic.py                             +33  -0
applications/ChatGPT/chatgpt/nn/gpt_rm.py                                 +33  -0
applications/ChatGPT/chatgpt/nn/lora.py                                   +127 -0
applications/ChatGPT/chatgpt/nn/loss.py                                   +105 -0
applications/ChatGPT/chatgpt/nn/opt_actor.py                              +35  -0
applications/ChatGPT/chatgpt/nn/opt_critic.py                             +37  -0
applications/ChatGPT/chatgpt/nn/opt_rm.py                                 +33  -0
applications/ChatGPT/chatgpt/nn/reward_model.py                           +41  -0
applications/ChatGPT/chatgpt/nn/utils.py                                  +92  -0
applications/ChatGPT/chatgpt/replay_buffer/__init__.py                    +4   -0
applications/ChatGPT/chatgpt/replay_buffer/base.py                        +43  -0
applications/ChatGPT/chatgpt/replay_buffer/naive.py                       +57  -0
applications/ChatGPT/chatgpt/replay_buffer/utils.py                       +73  -0
applications/ChatGPT/chatgpt/trainer/__init__.py                          +5   -0
applications/ChatGPT/chatgpt/trainer/base.py                              +162 -0
applications/ChatGPT/chatgpt/trainer/callbacks/__init__.py                +4   -0
applications/ChatGPT/chatgpt/trainer/callbacks/base.py                    +39  -0
applications/ChatGPT/chatgpt/trainer/callbacks/performance_evaluator.py   +133 -0

applications/ChatGPT/chatgpt/nn/generation_utils.py (new file, mode 100644)

```python
from typing import Optional

import torch


def gpt_prepare_inputs_fn(input_ids: torch.Tensor, past: Optional[torch.Tensor] = None, **kwargs) -> dict:
    token_type_ids = kwargs.get("token_type_ids", None)
    # only last token for input_ids if past is defined in kwargs
    if past:
        input_ids = input_ids[:, -1].unsqueeze(-1)
        if token_type_ids is not None:
            token_type_ids = token_type_ids[:, -1].unsqueeze(-1)

    attention_mask = kwargs.get("attention_mask", None)
    position_ids = kwargs.get("position_ids", None)

    if attention_mask is not None and position_ids is None:
        # create position_ids on the fly for batch generation
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 1)
        if past:
            position_ids = position_ids[:, -1].unsqueeze(-1)
    else:
        position_ids = None
    return {
        "input_ids": input_ids,
        "past_key_values": past,
        "use_cache": kwargs.get("use_cache"),
        "position_ids": position_ids,
        "attention_mask": attention_mask,
        "token_type_ids": token_type_ids,
    }


def update_model_kwargs_fn(outputs: dict, **model_kwargs) -> dict:
    if "past_key_values" in outputs:
        model_kwargs["past"] = outputs["past_key_values"]
    else:
        model_kwargs["past"] = None

    # update token_type_ids with last value
    if "token_type_ids" in model_kwargs:
        token_type_ids = model_kwargs["token_type_ids"]
        model_kwargs["token_type_ids"] = torch.cat([token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1)

    # update attention mask
    if "attention_mask" in model_kwargs:
        attention_mask = model_kwargs["attention_mask"]
        model_kwargs["attention_mask"] = torch.cat(
            [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1)

    return model_kwargs


def opt_prepare_inputs_fn(input_ids: torch.Tensor,
                          past: Optional[torch.Tensor] = None,
                          attention_mask: Optional[torch.Tensor] = None,
                          use_cache: Optional[bool] = None,
                          **kwargs) -> dict:
    # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
    if attention_mask is None:
        attention_mask = input_ids.new_ones(input_ids.shape)

    if past:
        input_ids = input_ids[:, -1:]
    # first step, decoder_cached_states are empty
    return {
        "input_ids": input_ids,    # encoder_outputs is defined. input_ids not needed
        "attention_mask": attention_mask,
        "past_key_values": past,
        "use_cache": use_cache,
    }


def bloom_prepare_inputs_fn(input_ids: torch.Tensor,
                            past: Optional[torch.Tensor] = None,
                            attention_mask: Optional[torch.Tensor] = None,
                            use_cache: Optional[bool] = None,
                            **kwargs) -> dict:
    # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
    if attention_mask is None:
        attention_mask = input_ids.new_ones(input_ids.shape)

    if past:
        input_ids = input_ids[:, -1:]
    # first step, decoder_cached_states are empty
    return {
        "input_ids": input_ids,    # encoder_outputs is defined. input_ids not needed
        "attention_mask": attention_mask,
        "past_key_values": past,
        "use_cache": use_cache,
    }
```
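
These hooks mirror the `prepare_inputs_for_generation` contract from Hugging Face Transformers: on the first decoding step the full prompt is processed, and once a key/value cache (`past`) exists, only the newest token is fed in, with `position_ids` rebuilt from the cumulative attention mask so left-padded batches decode correctly. A minimal sketch of that behavior with toy tensors (the shapes and placeholder cache below are illustrative, not from this commit):

```python
import torch

# Toy batch: 2 sequences of 5 tokens (values are arbitrary).
input_ids = torch.randint(0, 100, (2, 5))
attention_mask = torch.ones(2, 5, dtype=torch.long)

# First step: no cache, so the full sequence is returned.
step0 = gpt_prepare_inputs_fn(input_ids, past=None, attention_mask=attention_mask)
assert step0['input_ids'].shape == (2, 5)

# Later steps: any truthy `past` (a real key/value cache in practice; a
# placeholder here) trims the input to the newest token and derives a
# matching position id from the cumulative attention mask.
placeholder_past = [object()]
step1 = gpt_prepare_inputs_fn(input_ids, past=placeholder_past, attention_mask=attention_mask)
assert step1['input_ids'].shape == (2, 1)
assert step1['position_ids'].shape == (2, 1)
```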

applications/ChatGPT/chatgpt/nn/gpt_actor.py (new file, mode 100644)

```python
from typing import Optional

from transformers.models.gpt2.configuration_gpt2 import GPT2Config
from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel

from .actor import Actor


class GPTActor(Actor):
    """
    GPT Actor model.

    Args:
        pretrained (str): Pretrained model name or path.
        config (GPT2Config): Model config.
        checkpoint (bool): Enable gradient checkpointing.
    """

    def __init__(self,
                 pretrained: Optional[str] = None,
                 config: Optional[GPT2Config] = None,
                 checkpoint: bool = False) -> None:
        if pretrained is not None:
            model = GPT2LMHeadModel.from_pretrained(pretrained)
        elif config is not None:
            model = GPT2LMHeadModel(config)
        else:
            model = GPT2LMHeadModel(GPT2Config())
        if checkpoint:
            model.gradient_checkpointing_enable()
        super().__init__(model)
```

applications/ChatGPT/chatgpt/nn/gpt_critic.py (new file, mode 100644)

```python
from typing import Optional

import torch.nn as nn
from transformers.models.gpt2.configuration_gpt2 import GPT2Config
from transformers.models.gpt2.modeling_gpt2 import GPT2Model

from .critic import Critic


class GPTCritic(Critic):
    """
    GPT Critic model.

    Args:
        pretrained (str): Pretrained model name or path.
        config (GPT2Config): Model config.
        checkpoint (bool): Enable gradient checkpointing.
    """

    def __init__(self,
                 pretrained: Optional[str] = None,
                 config: Optional[GPT2Config] = None,
                 checkpoint: bool = False) -> None:
        if pretrained is not None:
            model = GPT2Model.from_pretrained(pretrained)
        elif config is not None:
            model = GPT2Model(config)
        else:
            model = GPT2Model(GPT2Config())
        if checkpoint:
            model.gradient_checkpointing_enable()
        value_head = nn.Linear(model.config.n_embd, 1)
        super().__init__(model, value_head)
```

applications/ChatGPT/chatgpt/nn/gpt_rm.py (new file, mode 100644)

```python
from typing import Optional

import torch.nn as nn
from transformers.models.gpt2.configuration_gpt2 import GPT2Config
from transformers.models.gpt2.modeling_gpt2 import GPT2Model

from .reward_model import RewardModel


class GPTRM(RewardModel):
    """
    GPT Reward model.

    Args:
        pretrained (str): Pretrained model name or path.
        config (GPT2Config): Model config.
        checkpoint (bool): Enable gradient checkpointing.
    """

    def __init__(self,
                 pretrained: Optional[str] = None,
                 config: Optional[GPT2Config] = None,
                 checkpoint: bool = False) -> None:
        if pretrained is not None:
            model = GPT2Model.from_pretrained(pretrained)
        elif config is not None:
            model = GPT2Model(config)
        else:
            model = GPT2Model(GPT2Config())
        if checkpoint:
            model.gradient_checkpointing_enable()
        value_head = nn.Linear(model.config.n_embd, 1)
        super().__init__(model, value_head)
```
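
The three GPT wrappers differ only in backbone and head: the actor keeps `GPT2LMHeadModel` for token logits, while the critic and reward model wrap a bare `GPT2Model` and bolt on a scalar `nn.Linear(n_embd, 1)` value head. A hedged instantiation sketch (it assumes `transformers` is installed and that the `Actor`/`Critic` base classes, which are not shown on this page, accept these arguments; the tiny config is purely illustrative):

```python
from transformers.models.gpt2.configuration_gpt2 import GPT2Config

# A deliberately tiny config so the sketch runs quickly on CPU.
cfg = GPT2Config(n_layer=2, n_head=2, n_embd=64)

actor = GPTActor(config=cfg, checkpoint=True)    # logits over vocab for generation
critic = GPTCritic(config=cfg)                   # scalar value estimates
rm = GPTRM(config=cfg)                           # scalar reward scores

# Passing pretrained='gpt2' instead would load real weights from the Hub.
```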

applications/ChatGPT/chatgpt/nn/lora.py (new file, mode 100644)

```python
import math
from typing import Optional

import loralib as lora
import torch
import torch.nn as nn
import torch.nn.functional as F


class LoraLinear(lora.LoRALayer, nn.Module):
    """Replace in-place ops with out-of-place ops to fit Gemini. Convert a torch.nn.Linear to LoraLinear.
    """

    def __init__(
        self,
        weight: nn.Parameter,
        bias: Optional[nn.Parameter],
        r: int = 0,
        lora_alpha: int = 1,
        lora_dropout: float = 0.,
        fan_in_fan_out: bool = False,    # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
        merge_weights: bool = True,
    ):
        nn.Module.__init__(self)
        lora.LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=merge_weights)
        self.weight = weight
        self.bias = bias

        out_features, in_features = weight.shape
        self.in_features = in_features
        self.out_features = out_features

        self.fan_in_fan_out = fan_in_fan_out
        # Actual trainable parameters
        if r > 0:
            self.lora_A = nn.Parameter(self.weight.new_zeros((r, in_features)))
            self.lora_B = nn.Parameter(self.weight.new_zeros((out_features, r)))
            self.scaling = self.lora_alpha / self.r
            # Freezing the pre-trained weight matrix
            self.weight.requires_grad = False
        self.reset_parameters()
        if fan_in_fan_out:
            self.weight.data = self.weight.data.T

    def reset_parameters(self):
        if hasattr(self, 'lora_A'):
            # initialize A the same way as the default for nn.Linear and B to zero
            nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
            nn.init.zeros_(self.lora_B)

    def train(self, mode: bool = True):

        def T(w):
            return w.T if self.fan_in_fan_out else w

        nn.Module.train(self, mode)
        if self.merge_weights and self.merged:
            # Make sure that the weights are not merged
            if self.r > 0:
                self.weight.data -= T(self.lora_B @ self.lora_A) * self.scaling
            self.merged = False

    def eval(self):

        def T(w):
            return w.T if self.fan_in_fan_out else w

        nn.Module.eval(self)
        if self.merge_weights and not self.merged:
            # Merge the weights and mark it
            if self.r > 0:
                self.weight.data += T(self.lora_B @ self.lora_A) * self.scaling
            self.merged = True

    def forward(self, x: torch.Tensor):

        def T(w):
            return w.T if self.fan_in_fan_out else w

        if self.r > 0 and not self.merged:
            result = F.linear(x, T(self.weight), bias=self.bias)
            if self.r > 0:
                result = result + (self.lora_dropout(x) @ self.lora_A.t() @ self.lora_B.t()) * self.scaling
            return result
        else:
            return F.linear(x, T(self.weight), bias=self.bias)


def lora_linear_wrapper(linear: nn.Linear, lora_rank: int) -> LoraLinear:
    assert lora_rank <= linear.in_features, \
        f'LoRA rank ({lora_rank}) must be less than or equal to in features ({linear.in_features})'
    lora_linear = LoraLinear(linear.weight, linear.bias, r=lora_rank, merge_weights=False)
    return lora_linear


def convert_to_lora_recursively(module: nn.Module, lora_rank: int) -> None:
    for name, child in module.named_children():
        if isinstance(child, nn.Linear):
            setattr(module, name, lora_linear_wrapper(child, lora_rank))
        else:
            convert_to_lora_recursively(child, lora_rank)


class LoRAModule(nn.Module):
    """A LoRA module base class. All derived classes should call `convert_to_lora()` at the bottom of `__init__()`.
    This class will convert all torch.nn.Linear layers to LoraLinear layers.

    Args:
        lora_rank (int, optional): LoRA rank. 0 means LoRA is not applied. Defaults to 0.
        lora_train_bias (str, optional): Whether LoRA trains biases.
            'none' means it doesn't train biases. 'all' means it trains all biases.
            'lora_only' means it only trains biases of LoRA layers. Defaults to 'none'.
    """

    def __init__(self, lora_rank: int = 0, lora_train_bias: str = 'none') -> None:
        super().__init__()
        self.lora_rank = lora_rank
        self.lora_train_bias = lora_train_bias

    def convert_to_lora(self) -> None:
        if self.lora_rank <= 0:
            return
        convert_to_lora_recursively(self, self.lora_rank)
        lora.mark_only_lora_as_trainable(self, self.lora_train_bias)
```
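
The update implemented by `LoraLinear` is the standard low-rank decomposition W' = W + (alpha/r) * B @ A, with W frozen and only A (r x in_features) and B (out_features x r) trained. A minimal sketch of the `LoRAModule` pattern on a toy model (the `TinyMLP` class is hypothetical, invented for illustration; it assumes `loralib` is installed and the definitions above are in scope):

```python
import torch.nn as nn


class TinyMLP(LoRAModule):

    def __init__(self, lora_rank: int = 4) -> None:
        super().__init__(lora_rank=lora_rank, lora_train_bias='none')
        self.fc1 = nn.Linear(16, 32)
        self.fc2 = nn.Linear(32, 16)
        # per the class docstring, convert at the bottom of __init__
        self.convert_to_lora()


mlp = TinyMLP()
trainable = [name for name, p in mlp.named_parameters() if p.requires_grad]
# Only the low-rank factors remain trainable; the frozen base weights drop out:
# ['fc1.lora_A', 'fc1.lora_B', 'fc2.lora_A', 'fc2.lora_B']
print(trainable)
```

Note that `lora_linear_wrapper` passes `merge_weights=False`, so the low-rank factors stay separate rather than being folded back into `weight` when `eval()` is called.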

applications/ChatGPT/chatgpt/nn/loss.py (new file, mode 100644)

```python
from typing import Optional

import torch
import torch.nn as nn

from .utils import masked_mean


class GPTLMLoss(nn.Module):
    """
    GPT Language Model Loss
    """

    def __init__(self):
        super().__init__()
        self.loss = nn.CrossEntropyLoss()

    def forward(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        # Flatten the tokens
        return self.loss(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))


class PolicyLoss(nn.Module):
    """
    Policy Loss for PPO
    """

    def __init__(self, clip_eps: float = 0.2) -> None:
        super().__init__()
        self.clip_eps = clip_eps

    def forward(self,
                log_probs: torch.Tensor,
                old_log_probs: torch.Tensor,
                advantages: torch.Tensor,
                action_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        ratio = (log_probs - old_log_probs).exp()
        surr1 = ratio * advantages
        surr2 = ratio.clamp(1 - self.clip_eps, 1 + self.clip_eps) * advantages
        loss = -torch.min(surr1, surr2)
        if action_mask is not None:
            loss = masked_mean(loss, action_mask)
        loss = loss.mean()
        return loss


class ValueLoss(nn.Module):
    """
    Value Loss for PPO
    """

    def __init__(self, clip_eps: float = 0.4) -> None:
        super().__init__()
        self.clip_eps = clip_eps

    def forward(self,
                values: torch.Tensor,
                old_values: torch.Tensor,
                reward: torch.Tensor,
                action_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        values_clipped = old_values + (values - old_values).clamp(-self.clip_eps, self.clip_eps)
        surr1 = (values_clipped - reward)**2
        surr2 = (values - reward)**2
        loss = torch.max(surr1, surr2)
        loss = loss.mean()
        return loss


class PPOPtxActorLoss(nn.Module):
    """
    To Do:

    PPO-ptx Actor Loss
    """

    def __init__(self,
                 policy_clip_eps: float = 0.2,
                 pretrain_coef: float = 0.0,
                 pretrain_loss_fn=GPTLMLoss()) -> None:
        super().__init__()
        self.pretrain_coef = pretrain_coef
        self.policy_loss_fn = PolicyLoss(clip_eps=policy_clip_eps)
        self.pretrain_loss_fn = pretrain_loss_fn

    def forward(self,
                log_probs: torch.Tensor,
                old_log_probs: torch.Tensor,
                advantages: torch.Tensor,
                lm_logits: torch.Tensor,
                lm_input_ids: torch.Tensor,
                action_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        policy_loss = self.policy_loss_fn(log_probs, old_log_probs, advantages, action_mask=action_mask)
        lm_loss = self.pretrain_loss_fn(lm_logits, lm_input_ids)
        return policy_loss + self.pretrain_coef * lm_loss


class PairWiseLoss(nn.Module):
    """
    Pairwise Loss for Reward Model
    """

    def forward(self, chosen_reward: torch.Tensor, reject_reward: torch.Tensor) -> torch.Tensor:
        probs = torch.sigmoid(chosen_reward - reject_reward)
        log_probs = torch.log(probs)
        loss = -log_probs.mean()
        return loss
```
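
`PolicyLoss` is the clipped PPO surrogate, L = -E[min(r * A, clip(r, 1 - eps, 1 + eps) * A)] with ratio r = exp(log pi_new - log pi_old). A hand-computed check of the clipping (the numbers are made up for illustration):

```python
import torch

loss_fn = PolicyLoss(clip_eps=0.2)
log_probs = torch.log(torch.tensor([[0.9]]))        # new policy
old_log_probs = torch.log(torch.tensor([[0.5]]))    # behavior policy
advantages = torch.tensor([[1.0]])

# ratio = 0.9 / 0.5 = 1.8, clipped to 1.2;
# min(1.8 * 1.0, 1.2 * 1.0) = 1.2, negated -> -1.2
out = loss_fn(log_probs, old_log_probs, advantages)
assert torch.isclose(out, torch.tensor(-1.2))
```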

applications/ChatGPT/chatgpt/nn/opt_actor.py (new file, mode 100644)

```python
from typing import Optional

from transformers.models.opt.configuration_opt import OPTConfig
from transformers.models.opt.modeling_opt import OPTForCausalLM

from .actor import Actor


class OPTActor(Actor):
    """
    OPT Actor model.

    Args:
        pretrained (str): Pretrained model name or path.
        config (OPTConfig): Model config.
        checkpoint (bool): Enable gradient checkpointing.
        lora_rank (int): Rank of the low-rank approximation.
        lora_train_bias (str): LoRA bias training mode.
    """

    def __init__(self,
                 pretrained: Optional[str] = None,
                 config: Optional[OPTConfig] = None,
                 checkpoint: bool = False,
                 lora_rank: int = 0,
                 lora_train_bias: str = 'none') -> None:
        if pretrained is not None:
            model = OPTForCausalLM.from_pretrained(pretrained)
        elif config is not None:
            model = OPTForCausalLM(config)
        else:
            model = OPTForCausalLM(OPTConfig())
        if checkpoint:
            model.gradient_checkpointing_enable()
        super().__init__(model, lora_rank, lora_train_bias)
```

applications/ChatGPT/chatgpt/nn/opt_critic.py (new file, mode 100644)

```python
from typing import Optional

import torch.nn as nn
from transformers.models.opt.configuration_opt import OPTConfig
from transformers.models.opt.modeling_opt import OPTModel

from .critic import Critic


class OPTCritic(Critic):
    """
    OPT Critic model.

    Args:
        pretrained (str): Pretrained model name or path.
        config (OPTConfig): Model config.
        checkpoint (bool): Enable gradient checkpointing.
        lora_rank (int): Rank of the low-rank approximation.
        lora_train_bias (str): LoRA bias training mode.
    """

    def __init__(self,
                 pretrained: Optional[str] = None,
                 config: Optional[OPTConfig] = None,
                 checkpoint: bool = False,
                 lora_rank: int = 0,
                 lora_train_bias: str = 'none') -> None:
        if pretrained is not None:
            model = OPTModel.from_pretrained(pretrained)
        elif config is not None:
            model = OPTModel(config)
        else:
            model = OPTModel(OPTConfig())
        if checkpoint:
            model.gradient_checkpointing_enable()
        value_head = nn.Linear(model.config.hidden_size, 1)
        super().__init__(model, value_head, lora_rank, lora_train_bias)
```

applications/ChatGPT/chatgpt/nn/opt_rm.py (new file, mode 100644)

```python
from typing import Optional

import torch.nn as nn
from transformers.models.opt.configuration_opt import OPTConfig
from transformers.models.opt.modeling_opt import OPTModel

from .reward_model import RewardModel


class OPTRM(RewardModel):
    """
    OPT Reward model.

    Args:
        pretrained (str): Pretrained model name or path.
        config (OPTConfig): Model config.
        lora_rank (int): Rank of the low-rank approximation.
        lora_train_bias (str): LoRA bias training mode.
    """

    def __init__(self,
                 pretrained: Optional[str] = None,
                 config: Optional[OPTConfig] = None,
                 lora_rank: int = 0,
                 lora_train_bias: str = 'none') -> None:
        if pretrained is not None:
            model = OPTModel.from_pretrained(pretrained)
        elif config is not None:
            model = OPTModel(config)
        else:
            model = OPTModel(OPTConfig())
        value_head = nn.Linear(model.config.hidden_size, 1)
        super().__init__(model, value_head, lora_rank, lora_train_bias)
```

applications/ChatGPT/chatgpt/nn/reward_model.py (new file, mode 100644)

```python
from typing import Optional

import torch
import torch.nn as nn

from .lora import LoRAModule


class RewardModel(LoRAModule):
    """
    Reward model base class.

    Args:
        model (nn.Module): Reward model.
        value_head (nn.Module): Value head to get reward score.
        lora_rank (int): LoRA rank.
        lora_train_bias (str): LoRA bias training mode.
    """

    def __init__(self,
                 model: nn.Module,
                 value_head: Optional[nn.Module] = None,
                 lora_rank: int = 0,
                 lora_train_bias: str = 'none') -> None:
        super().__init__(lora_rank=lora_rank, lora_train_bias=lora_train_bias)
        self.model = model
        if value_head is not None:
            if value_head.out_features != 1:
                raise ValueError("The value head of reward model's output dim should be 1!")
            self.value_head = value_head
        else:
            self.value_head = nn.Linear(model.config.n_embd, 1)
        self.convert_to_lora()

    def forward(self, sequences: torch.LongTensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        outputs = self.model(sequences, attention_mask=attention_mask)
        last_hidden_states = outputs['last_hidden_state']
        values = self.value_head(last_hidden_states)[:, :-1]
        value = values.mean(dim=1).squeeze(1)    # ensure shape is (B)
        return value
```
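
Shape-wise, `value_head` maps `(B, S, n_embd)` hidden states to `(B, S, 1)`; dropping the final position and averaging over the sequence yields one scalar per sample. A smoke-test sketch (assumes `transformers` is installed; the tiny random-weight config is illustrative, not from the repo):

```python
import torch
from transformers.models.gpt2.configuration_gpt2 import GPT2Config
from transformers.models.gpt2.modeling_gpt2 import GPT2Model

rm = RewardModel(GPT2Model(GPT2Config(n_layer=2, n_head=2, n_embd=64)))
sequences = torch.randint(0, 100, (4, 12))
attention_mask = torch.ones(4, 12, dtype=torch.long)

reward = rm(sequences, attention_mask=attention_mask)
assert reward.shape == (4,)    # one reward score per sequence
```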

applications/ChatGPT/chatgpt/nn/utils.py (new file, mode 100644)

```python
from typing import Optional, Union

import loralib as lora
import torch
import torch.nn as nn
import torch.nn.functional as F


def compute_approx_kl(log_probs: torch.Tensor,
                      log_probs_base: torch.Tensor,
                      action_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
    """
    Compute the approximate KL divergence between two distributions.
    Schulman blog: http://joschu.net/blog/kl-approx.html

    Args:
        log_probs: Log probabilities of the new distribution.
        log_probs_base: Log probabilities of the base distribution.
        action_mask: Mask for actions.
    """

    log_ratio = log_probs - log_probs_base
    approx_kl = (log_ratio.exp() - 1) - log_ratio
    if action_mask is not None:
        approx_kl = masked_mean(approx_kl, action_mask, dim=1)
        return approx_kl
    approx_kl = approx_kl.mean(dim=1)
    return approx_kl


def compute_reward(r: Union[torch.Tensor, float],
                   kl_coef: float,
                   log_probs: torch.Tensor,
                   log_probs_base: torch.Tensor,
                   action_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
    if kl_coef <= 0.0:
        return r
    kl = compute_approx_kl(log_probs, log_probs_base, action_mask=action_mask)
    reward = r - kl_coef * kl
    return reward


def log_probs_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    log_probs = F.log_softmax(logits, dim=-1)
    log_probs_labels = log_probs.gather(dim=-1, index=labels.unsqueeze(-1))
    return log_probs_labels.squeeze(-1)


def masked_mean(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch.Tensor:
    tensor = tensor * mask
    tensor = tensor.sum(dim=dim)
    mask_sum = mask.sum(dim=dim)
    mean = tensor / (mask_sum + 1e-8)
    return mean


def masked_normalize(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1, eps: float = 1e-8) -> torch.Tensor:
    tensor = tensor * mask
    mean = masked_mean(tensor, mask, dim=dim)
    mean_centered = tensor - mean
    var = masked_mean(mean_centered**2, mask, dim=dim)
    return mean_centered * var.clamp(min=eps).rsqrt()


def normalize(tensor: torch.Tensor, dim: int = 0, eps: float = 1e-8) -> torch.Tensor:
    mean = tensor.mean(dim)
    mean_centered = tensor - mean
    var = (mean_centered**2).mean(dim)
    norm = mean_centered * var.clamp(min=eps).rsqrt()
    return norm


def convert_to_lora(model: nn.Module,
                    input_size: int,
                    output_size: int,
                    lora_rank: int = 16,
                    lora_alpha: int = 1,
                    lora_dropout: float = 0.,
                    fan_in_fan_out: bool = False,
                    merge_weights: bool = True):
    if lora_rank > min(input_size, output_size):
        raise ValueError(f"LoRA rank {lora_rank} must be less than or equal to {min(input_size, output_size)}")

    for name, module in model.named_modules():
        if isinstance(module, nn.Linear):
            module._modules[name] = lora.Linear(input_size,
                                                output_size,
                                                r=lora_rank,
                                                lora_alpha=lora_alpha,
                                                lora_dropout=lora_dropout,
                                                fan_in_fan_out=fan_in_fan_out,
                                                merge_weights=merge_weights)
```
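
`compute_approx_kl` is the k3 estimator from the Schulman blog linked in the docstring: for log-ratio l, the per-token estimate (e^l - 1) - l is nonnegative for every l, so it never produces the negative KL estimates the naive -l estimator can. A quick numeric check (toy probabilities, invented for illustration):

```python
import torch

log_probs = torch.log(torch.tensor([[0.5, 0.4]]))         # new policy
log_probs_base = torch.log(torch.tensor([[0.25, 0.8]]))   # base policy
action_mask = torch.tensor([[1.0, 1.0]])

kl = compute_approx_kl(log_probs, log_probs_base, action_mask=action_mask)
# Each per-token term (e^l - 1) - l is >= 0 for any log-ratio l,
# so the masked mean is nonnegative as well.
assert (kl >= 0).all()

# compute_reward then penalizes divergence from the frozen base model:
reward = compute_reward(1.0, kl_coef=0.1, log_probs=log_probs,
                        log_probs_base=log_probs_base, action_mask=action_mask)
```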

applications/ChatGPT/chatgpt/replay_buffer/__init__.py (new file, mode 100644)

```python
from .base import ReplayBuffer
from .naive import NaiveReplayBuffer

__all__ = ['ReplayBuffer', 'NaiveReplayBuffer']
```

applications/ChatGPT/chatgpt/replay_buffer/base.py (new file, mode 100644)

```python
from abc import ABC, abstractmethod
from typing import Any

from chatgpt.experience_maker.base import Experience


class ReplayBuffer(ABC):
    """Replay buffer base class. It stores experience.

    Args:
        sample_batch_size (int): Batch size when sampling.
        limit (int, optional): Limit of number of experience samples. A number <= 0 means unlimited. Defaults to 0.
    """

    def __init__(self, sample_batch_size: int, limit: int = 0) -> None:
        super().__init__()
        self.sample_batch_size = sample_batch_size
        # limit <= 0 means unlimited
        self.limit = limit

    @abstractmethod
    def append(self, experience: Experience) -> None:
        pass

    @abstractmethod
    def clear(self) -> None:
        pass

    @abstractmethod
    def sample(self) -> Experience:
        pass

    @abstractmethod
    def __len__(self) -> int:
        pass

    @abstractmethod
    def __getitem__(self, idx: int) -> Any:
        pass

    @abstractmethod
    def collate_fn(self, batch: Any) -> Experience:
        pass
```

applications/ChatGPT/chatgpt/replay_buffer/naive.py (new file, mode 100644)

```python
import random
from typing import List

import torch
from chatgpt.experience_maker.base import Experience

from .base import ReplayBuffer
from .utils import BufferItem, make_experience_batch, split_experience_batch


class NaiveReplayBuffer(ReplayBuffer):
    """Naive replay buffer class. It stores experience.

    Args:
        sample_batch_size (int): Batch size when sampling.
        limit (int, optional): Limit of number of experience samples. A number <= 0 means unlimited. Defaults to 0.
        cpu_offload (bool, optional): Whether to offload experience to cpu when sampling. Defaults to True.
    """

    def __init__(self, sample_batch_size: int, limit: int = 0, cpu_offload: bool = True) -> None:
        super().__init__(sample_batch_size, limit)
        self.cpu_offload = cpu_offload
        self.target_device = torch.device(f'cuda:{torch.cuda.current_device()}')
        # TODO(ver217): add prefetch
        self.items: List[BufferItem] = []

    @torch.no_grad()
    def append(self, experience: Experience) -> None:
        if self.cpu_offload:
            experience.to_device(torch.device('cpu'))
        items = split_experience_batch(experience)
        self.items.extend(items)
        if self.limit > 0:
            samples_to_remove = len(self.items) - self.limit
            if samples_to_remove > 0:
                self.items = self.items[samples_to_remove:]

    def clear(self) -> None:
        self.items.clear()

    @torch.no_grad()
    def sample(self) -> Experience:
        items = random.sample(self.items, self.sample_batch_size)
        experience = make_experience_batch(items)
        if self.cpu_offload:
            experience.to_device(self.target_device)
        return experience

    def __len__(self) -> int:
        return len(self.items)

    def __getitem__(self, idx: int) -> BufferItem:
        return self.items[idx]

    def collate_fn(self, batch) -> Experience:
        experience = make_experience_batch(batch)
        return experience
```

applications/ChatGPT/chatgpt/replay_buffer/utils.py (new file, mode 100644)

```python
from dataclasses import dataclass
from typing import List, Optional

import torch
import torch.nn.functional as F
from chatgpt.experience_maker.base import Experience


@dataclass
class BufferItem:
    """BufferItem is an item of experience data.

    Shapes of each tensor:
        sequences: (S)
        action_log_probs: (A)
        values: (1)
        reward: (1)
        advantages: (1)
        attention_mask: (S)
        action_mask: (A)

    "A" is the number of actions.
    """
    sequences: torch.Tensor
    action_log_probs: torch.Tensor
    values: torch.Tensor
    reward: torch.Tensor
    advantages: torch.Tensor
    attention_mask: Optional[torch.LongTensor]
    action_mask: Optional[torch.BoolTensor]


def split_experience_batch(experience: Experience) -> List[BufferItem]:
    batch_size = experience.sequences.size(0)
    batch_kwargs = [{} for _ in range(batch_size)]
    keys = ('sequences', 'action_log_probs', 'values', 'reward', 'advantages', 'attention_mask', 'action_mask')
    for key in keys:
        value = getattr(experience, key)
        if isinstance(value, torch.Tensor):
            vals = torch.unbind(value)
        else:
            # None
            vals = [value for _ in range(batch_size)]
        assert batch_size == len(vals)
        for i, v in enumerate(vals):
            batch_kwargs[i][key] = v
    items = [BufferItem(**kwargs) for kwargs in batch_kwargs]
    return items


def zero_pad_sequences(sequences: List[torch.Tensor], side: str = 'left') -> torch.Tensor:
    assert side in ('left', 'right')
    max_len = max(seq.size(0) for seq in sequences)
    padded_sequences = []
    for seq in sequences:
        pad_len = max_len - seq.size(0)
        padding = (pad_len, 0) if side == 'left' else (0, pad_len)
        padded_sequences.append(F.pad(seq, padding))
    return torch.stack(padded_sequences, dim=0)


def make_experience_batch(items: List[BufferItem]) -> Experience:
    kwargs = {}
    to_pad_keys = set(('action_log_probs', 'action_mask'))
    keys = ('sequences', 'action_log_probs', 'values', 'reward', 'advantages', 'attention_mask', 'action_mask')
    for key in keys:
        vals = [getattr(item, key) for item in items]
        if key in to_pad_keys:
            batch_data = zero_pad_sequences(vals)
        else:
            batch_data = torch.stack(vals, dim=0)
        kwargs[key] = batch_data
    return Experience(**kwargs)
```
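
Only the action-aligned fields (`action_log_probs`, `action_mask`) can be ragged across items, since different rollouts generate different numbers of tokens; `zero_pad_sequences` left-pads them to a common length before stacking. A toy demonstration (tensors invented for illustration):

```python
import torch

a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5])

padded = zero_pad_sequences([a, b], side='left')
# tensor([[1, 2, 3],
#         [0, 4, 5]])
assert padded.shape == (2, 3)
```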

applications/ChatGPT/chatgpt/trainer/__init__.py (new file, mode 100644)

```python
from .base import Trainer
from .ppo import PPOTrainer
from .rm import RewardModelTrainer

__all__ = ['Trainer', 'PPOTrainer', 'RewardModelTrainer']
```

applications/ChatGPT/chatgpt/trainer/base.py (new file, mode 100644)

```python
import random
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List, Optional, Union

import torch
from chatgpt.experience_maker import Experience, ExperienceMaker
from chatgpt.replay_buffer import ReplayBuffer
from torch import Tensor
from torch.utils.data import DistributedSampler
from tqdm import tqdm

from .callbacks import Callback
from .strategies import Strategy
from .utils import is_rank_0


class Trainer(ABC):
    """
    Base class for RLHF trainers.

    Args:
        strategy (Strategy): the strategy to use for training
        experience_maker (ExperienceMaker): the experience maker used to produce experience to fill the replay buffer
        replay_buffer (ReplayBuffer): the replay buffer to use for training
        experience_batch_size (int, defaults to 8): the batch size to use for experience generation
        max_epochs (int, defaults to 1): the number of epochs of the training process
        tokenizer (Callable, optional): the tokenizer to use for tokenizing the input
        sample_replay_buffer (bool, defaults to False): whether to sample from the replay buffer
        dataloader_pin_memory (bool, defaults to True): whether to pin memory for the data loader
        callbacks (List[Callback], defaults to []): the callbacks to call during the training process
        generate_kwargs (dict, optional): the kwargs to use while the model is generating
    """

    def __init__(self,
                 strategy: Strategy,
                 experience_maker: ExperienceMaker,
                 replay_buffer: ReplayBuffer,
                 experience_batch_size: int = 8,
                 max_epochs: int = 1,
                 tokenizer: Optional[Callable[[Any], dict]] = None,
                 sample_replay_buffer: bool = False,
                 dataloader_pin_memory: bool = True,
                 callbacks: List[Callback] = [],
                 **generate_kwargs) -> None:
        super().__init__()
        self.strategy = strategy
        self.experience_maker = experience_maker
        self.replay_buffer = replay_buffer
        self.experience_batch_size = experience_batch_size
        self.max_epochs = max_epochs
        self.tokenizer = tokenizer
        self.generate_kwargs = generate_kwargs
        self.sample_replay_buffer = sample_replay_buffer
        self.dataloader_pin_memory = dataloader_pin_memory
        self.callbacks = callbacks

    @abstractmethod
    def training_step(self, experience: Experience) -> Dict[str, Any]:
        pass

    def _make_experience(self, inputs: Union[Tensor, Dict[str, Tensor]]) -> Experience:
        if isinstance(inputs, Tensor):
            return self.experience_maker.make_experience(inputs, **self.generate_kwargs)
        elif isinstance(inputs, dict):
            return self.experience_maker.make_experience(**inputs, **self.generate_kwargs)
        else:
            raise ValueError(f'Unsupported input type "{type(inputs)}"')

    def _sample_prompts(self, prompts) -> list:
        indices = list(range(len(prompts)))
        sampled_indices = random.sample(indices, self.experience_batch_size)
        return [prompts[i] for i in sampled_indices]

    def _learn(self):
        # replay buffer may be empty at first, we should rebuild at each training
        if not self.sample_replay_buffer:
            dataloader = self.strategy.setup_dataloader(self.replay_buffer, self.dataloader_pin_memory)
        device = torch.cuda.current_device()
        if self.sample_replay_buffer:
            pbar = tqdm(range(self.max_epochs), desc='Train epoch', disable=not is_rank_0())
            for _ in pbar:
                experience = self.replay_buffer.sample()
                metrics = self.training_step(experience)
                pbar.set_postfix(metrics)
        else:
            for epoch in range(self.max_epochs):
                self._on_learn_epoch_start(epoch)
                if isinstance(dataloader.sampler, DistributedSampler):
                    dataloader.sampler.set_epoch(epoch)
                pbar = tqdm(dataloader, desc=f'Train epoch [{epoch + 1}/{self.max_epochs}]', disable=not is_rank_0())
                for experience in pbar:
                    self._on_learn_batch_start()
                    experience.to_device(device)
                    metrics = self.training_step(experience)
                    self._on_learn_batch_end(metrics, experience)
                    pbar.set_postfix(metrics)
                self._on_learn_epoch_end(epoch)

    def fit(self, prompts, num_episodes: int = 50000, max_timesteps: int = 500, update_timesteps: int = 5000) -> None:
        time = 0
        self._on_fit_start()
        for episode in range(num_episodes):
            self._on_episode_start(episode)
            for timestep in tqdm(range(max_timesteps),
                                 desc=f'Episode [{episode + 1}/{num_episodes}]',
                                 disable=not is_rank_0()):
                time += 1
                rand_prompts = self._sample_prompts(prompts)
                if self.tokenizer is not None:
                    inputs = self.tokenizer(rand_prompts)
                else:
                    inputs = rand_prompts
                self._on_make_experience_start()
                experience = self._make_experience(inputs)
                self._on_make_experience_end(experience)
                self.replay_buffer.append(experience)
                if time % update_timesteps == 0:
                    self._learn()
                    self.replay_buffer.clear()
            self._on_episode_end(episode)
        self._on_fit_end()

    # TODO(ver217): maybe simplify this code using a context manager
    def _on_fit_start(self) -> None:
        for callback in self.callbacks:
            callback.on_fit_start()

    def _on_fit_end(self) -> None:
        for callback in self.callbacks:
            callback.on_fit_end()

    def _on_episode_start(self, episode: int) -> None:
        for callback in self.callbacks:
            callback.on_episode_start(episode)

    def _on_episode_end(self, episode: int) -> None:
        for callback in self.callbacks:
            callback.on_episode_end(episode)

    def _on_make_experience_start(self) -> None:
        for callback in self.callbacks:
            callback.on_make_experience_start()

    def _on_make_experience_end(self, experience: Experience) -> None:
        for callback in self.callbacks:
            callback.on_make_experience_end(experience)

    def _on_learn_epoch_start(self, epoch: int) -> None:
        for callback in self.callbacks:
            callback.on_learn_epoch_start(epoch)

    def _on_learn_epoch_end(self, epoch: int) -> None:
        for callback in self.callbacks:
            callback.on_learn_epoch_end(epoch)

    def _on_learn_batch_start(self) -> None:
        for callback in self.callbacks:
            callback.on_learn_batch_start()

    def _on_learn_batch_end(self, metrics: dict, experience: Experience) -> None:
        for callback in self.callbacks:
            callback.on_learn_batch_end(metrics, experience)
```
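
With the defaults above, `fit` interleaves collection and learning on a fixed schedule: the global counter `time` advances once per experience batch, so `_learn()` fires every `update_timesteps` steps (every 10 episodes at the default `max_timesteps=500`, `update_timesteps=5000`), after which the buffer is cleared. A minimal concrete subclass is sketched below (hypothetical; the real subclasses live in `ppo.py` and `rm.py`, which are not shown on this page):

```python
class MinimalTrainer(Trainer):

    def training_step(self, experience: Experience) -> Dict[str, Any]:
        # A real implementation computes losses from `experience` and steps
        # the optimizer via the strategy; the returned metrics are shown in
        # the tqdm progress bar through `pbar.set_postfix(metrics)`.
        return {'dummy_loss': 0.0}
```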

applications/ChatGPT/chatgpt/trainer/callbacks/__init__.py (new file, mode 100644)

```python
from .base import Callback
from .performance_evaluator import PerformanceEvaluator

__all__ = ['Callback', 'PerformanceEvaluator']
```

applications/ChatGPT/chatgpt/trainer/callbacks/base.py (new file, mode 100644)

```python
from abc import ABC

from chatgpt.experience_maker import Experience


class Callback(ABC):
    """
    Base callback class. It defines the interface for callbacks.
    """

    def on_fit_start(self) -> None:
        pass

    def on_fit_end(self) -> None:
        pass

    def on_episode_start(self, episode: int) -> None:
        pass

    def on_episode_end(self, episode: int) -> None:
        pass

    def on_make_experience_start(self) -> None:
        pass

    def on_make_experience_end(self, experience: Experience) -> None:
        pass

    def on_learn_epoch_start(self, epoch: int) -> None:
        pass

    def on_learn_epoch_end(self, epoch: int) -> None:
        pass

    def on_learn_batch_start(self) -> None:
        pass

    def on_learn_batch_end(self, metrics: dict, experience: Experience) -> None:
        pass
```
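
Every hook is a no-op by default, so concrete callbacks override only what they need, as `PerformanceEvaluator` below does. A hypothetical minimal example that records per-batch metrics:

```python
class MetricsRecorder(Callback):
    """Collects the metrics dict emitted after every learning batch."""

    def __init__(self) -> None:
        self.history = []

    def on_learn_batch_end(self, metrics: dict, experience: Experience) -> None:
        self.history.append(dict(metrics))
```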

applications/ChatGPT/chatgpt/trainer/callbacks/performance_evaluator.py (new file, mode 100644)

```python
from time import time
from typing import Optional

import torch
import torch.distributed as dist
from chatgpt.experience_maker import Experience

from .base import Callback


def get_world_size() -> int:
    if dist.is_initialized():
        return dist.get_world_size()
    return 1


def print_rank_0(*args, **kwargs) -> None:
    if not dist.is_initialized() or dist.get_rank() == 0:
        print(*args, **kwargs)


@torch.no_grad()
def all_reduce_mean(x: float, world_size: int) -> float:
    if world_size == 1:
        return x
    tensor = torch.tensor([x], device=torch.cuda.current_device())
    dist.all_reduce(tensor)
    tensor = tensor / world_size
    return tensor.item()


class PerformanceEvaluator(Callback):
    """
    Callback for evaluating the performance of the model.

    Args:
        actor_num_params: The number of parameters of the actor model.
        critic_num_params: The number of parameters of the critic model.
        initial_model_num_params: The number of parameters of the initial model.
        reward_model_num_params: The number of parameters of the reward model.
        enable_grad_checkpoint: Whether gradient checkpointing is enabled.
        ignore_episodes: The number of episodes to ignore when calculating the performance.
    """

    def __init__(self,
                 actor_num_params: int,
                 critic_num_params: int,
                 initial_model_num_params: int,
                 reward_model_num_params: int,
                 enable_grad_checkpoint: bool = False,
                 ignore_episodes: int = 0) -> None:
        super().__init__()
        self.world_size = get_world_size()
        self.actor_num_params = actor_num_params
        self.critic_num_params = critic_num_params
        self.initial_model_num_params = initial_model_num_params
        self.reward_model_num_params = reward_model_num_params
        self.enable_grad_checkpoint = enable_grad_checkpoint
        self.ignore_episodes = ignore_episodes
        self.disable: bool = False

        self.make_experience_duration: float = 0.
        self.make_experience_start_time: Optional[float] = None
        self.make_experience_num_samples: int = 0
        self.make_experience_flop: int = 0
        self.learn_duration: float = 0.
        self.learn_start_time: Optional[float] = None
        self.learn_num_samples: int = 0
        self.learn_flop: int = 0

    def on_episode_start(self, episode: int) -> None:
        self.disable = self.ignore_episodes > 0 and episode < self.ignore_episodes

    def on_make_experience_start(self) -> None:
        if self.disable:
            return
        self.make_experience_start_time = time()

    def on_make_experience_end(self, experience: Experience) -> None:
        if self.disable:
            return
        self.make_experience_duration += time() - self.make_experience_start_time
        batch_size, seq_len = experience.sequences.shape
        self.make_experience_num_samples += batch_size
        # actor generate
        num_actions = experience.action_mask.size(1)
        input_len = seq_len - num_actions
        total_seq_len = (input_len + seq_len - 1) * num_actions / 2
        self.make_experience_flop += self.actor_num_params * batch_size * total_seq_len * 2
        # actor forward
        self.make_experience_flop += self.actor_num_params * batch_size * seq_len * 2
        # critic forward
        self.make_experience_flop += self.critic_num_params * batch_size * seq_len * 2
        # initial model forward
        self.make_experience_flop += self.initial_model_num_params * batch_size * seq_len * 2
        # reward model forward
        self.make_experience_flop += self.reward_model_num_params * batch_size * seq_len * 2

    def on_learn_batch_start(self) -> None:
        if self.disable:
            return
        self.learn_start_time = time()

    def on_learn_batch_end(self, metrics: dict, experience: Experience) -> None:
        if self.disable:
            return
        self.learn_duration += time() - self.learn_start_time
        batch_size, seq_len = experience.sequences.shape
        self.learn_num_samples += batch_size
        # actor forward-backward, 3 means forward(1) + backward(2)
        self.learn_flop += self.actor_num_params * batch_size * seq_len * 2 * (3 + int(self.enable_grad_checkpoint))
        # critic forward-backward
        self.learn_flop += self.critic_num_params * batch_size * seq_len * 2 * (3 + int(self.enable_grad_checkpoint))

    def on_fit_end(self) -> None:
        avg_make_experience_duration = all_reduce_mean(self.make_experience_duration, self.world_size)
        avg_learn_duration = all_reduce_mean(self.learn_duration, self.world_size)
        avg_make_experience_throughput = self.make_experience_num_samples / (avg_make_experience_duration + 1e-12)
        avg_make_experience_tflops = self.make_experience_flop / 1e12 / (avg_make_experience_duration + 1e-12)
        avg_learn_throughput = self.learn_num_samples / (avg_learn_duration + 1e-12)
        avg_learn_tflops = self.learn_flop / 1e12 / (avg_learn_duration + 1e-12)
        print_rank_0(f'Making experience throughput: {avg_make_experience_throughput:.3f} samples/sec, '
                     f'TFLOPS: {avg_make_experience_tflops:.3f}')
        print_rank_0(f'Learning throughput: {avg_learn_throughput:.3f} samples/sec, '
                     f'TFLOPS: {avg_learn_tflops:.3f}')
```
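
The FLOP bookkeeping uses the usual 2N FLOPs per parameter per token for a forward pass, 2x that for backward (hence the `3 + int(grad_ckpt)` factor, since activation checkpointing re-runs the forward), and an arithmetic series for incremental decoding, where context length grows by one per generated token. A back-of-the-envelope check with invented numbers:

```python
# Illustrative numbers, not measurements.
actor_num_params = 1_000_000_000          # 1B-parameter actor
batch_size, seq_len, num_actions = 8, 512, 128
input_len = seq_len - num_actions         # 384 prompt tokens

# Context lengths during generation run from input_len up to seq_len - 1,
# so the total processed tokens form an arithmetic series:
total_seq_len = (input_len + seq_len - 1) * num_actions / 2    # 57,280 tokens
gen_flop = actor_num_params * batch_size * total_seq_len * 2
print(f'{gen_flop / 1e12:.1f} TFLOPs for one generation batch')   # ~916.5
```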