OpenDAS / ColossalAI · Commit 9e768b59

Commit 9e768b59, authored Oct 10, 2023 by zhuwenwen
Merge branch 'main' of https://github.com/hpcaitech/ColossalAI
Parents: 7bc5a8e3, 8aed02b9

Changes: 436 in the full commit; showing 20 changed files with 1016 additions and 839 deletions (+1016 / -839).
Changed files shown below:

  applications/Chat/coati/ray/utils.py                                 +140  -0
  applications/Chat/coati/replay_buffer/__init__.py                    +0    -4
  applications/Chat/coati/trainer/__init__.py                          +2    -2
  applications/Chat/coati/trainer/base.py                              +140  -27
  applications/Chat/coati/trainer/callbacks/__init__.py                +1    -1
  applications/Chat/coati/trainer/callbacks/base.py                    +2    -2
  applications/Chat/coati/trainer/callbacks/performance_evaluator.py   +25   -25
  applications/Chat/coati/trainer/callbacks/save_checkpoint.py         +17   -16
  applications/Chat/coati/trainer/ppo.py                               +138  -149
  applications/Chat/coati/trainer/rm.py                                +96   -96
  applications/Chat/coati/trainer/sft.py                               +100  -105
  applications/Chat/coati/trainer/strategies/__init__.py               +2    -3
  applications/Chat/coati/trainer/strategies/base.py                   +67   -62
  applications/Chat/coati/trainer/strategies/colossalai.py             +150  -138
  applications/Chat/coati/trainer/strategies/ddp.py                    +101  -58
  applications/Chat/coati/trainer/strategies/naive.py                  +0    -70
  applications/Chat/coati/trainer/strategies/sampler.py                +3    -4
  applications/Chat/coati/trainer/utils.py                             +32   -1
  applications/Chat/coati/utils/__init__.py                            +0    -3
  applications/Chat/coati/utils/tokenizer_utils.py                     +0    -73

Too many changes to show: to preserve performance, only 436 of 436+ files are displayed.
applications/Chat/coati/ray/utils.py  (new file, mode 0 → 100644)

import os
from collections import OrderedDict
from typing import Any, Dict

import torch
import torch.distributed as dist
import torch.nn as nn
from coati.models.bloom import BLOOMRM, BLOOMActor, BLOOMCritic
from coati.models.gpt import GPTRM, GPTActor, GPTCritic
from coati.models.llama import LlamaActor, LlamaCritic, LlamaRM
from coati.models.opt import OPTRM, OPTActor, OPTCritic
from coati.trainer.strategies import DDPStrategy, GeminiStrategy, LowLevelZeroStrategy
from transformers import AutoTokenizer, BloomTokenizerFast, GPT2Tokenizer


def is_rank_0() -> bool:
    return not dist.is_initialized() or dist.get_rank() == 0


def get_rank() -> int:
    return dist.get_rank() if dist.is_initialized() else 0


def get_world_size() -> int:
    return dist.get_world_size() if dist.is_initialized() else 1


def get_actor_from_args(model: str, pretrained: str = None, config=None, lora_rank=0):
    if model == "gpt2":
        actor = GPTActor(pretrained=pretrained, config=config, lora_rank=lora_rank)
    elif model == "bloom":
        actor = BLOOMActor(pretrained=pretrained, config=config, lora_rank=lora_rank)
    elif model == "opt":
        actor = OPTActor(pretrained=pretrained, config=config, lora_rank=lora_rank)
    elif model == "llama":
        actor = LlamaActor(pretrained=pretrained, config=config, lora_rank=lora_rank)
    else:
        raise ValueError(f'Unsupported actor model "{model}"')
    return actor


def get_critic_from_args(model: str, pretrained: str = None, config=None, lora_rank=0):
    if model == "gpt2":
        critic = GPTCritic(pretrained=pretrained, lora_rank=lora_rank, config=config)
    elif model == "bloom":
        critic = BLOOMCritic(pretrained=pretrained, lora_rank=lora_rank, config=config)
    elif model == "opt":
        critic = OPTCritic(pretrained=pretrained, lora_rank=lora_rank, config=config)
    elif model == "llama":
        critic = LlamaCritic(pretrained=pretrained, lora_rank=lora_rank, config=config)
    else:
        raise ValueError(f'Unsupported reward model "{model}"')
    return critic


def get_reward_model_from_args(model: str, pretrained: str = None, config=None):
    if model == "gpt2":
        reward_model = GPTRM(pretrained=pretrained, config=config)
    elif model == "bloom":
        reward_model = BLOOMRM(pretrained=pretrained, config=config)
    elif model == "opt":
        reward_model = OPTRM(pretrained=pretrained, config=config)
    elif model == "llama":
        reward_model = LlamaRM(pretrained=pretrained, config=config)
    else:
        raise ValueError(f'Unsupported reward model "{model}"')
    return reward_model


def get_strategy_from_args(strategy: str):
    if strategy == "ddp":
        strategy_ = DDPStrategy()
    elif strategy == "colossalai_gemini":
        strategy_ = GeminiStrategy(placement_policy="static", initial_scale=2**5)
    elif strategy == "colossalai_zero2":
        strategy_ = LowLevelZeroStrategy(stage=2, placement_policy="cuda")
    elif strategy == "colossalai_gemini_cpu":
        strategy_ = GeminiStrategy(
            placement_policy="static", offload_optim_frac=1.0, offload_param_frac=1.0, initial_scale=2**5
        )
    elif strategy == "colossalai_zero2_cpu":
        strategy_ = LowLevelZeroStrategy(stage=2, placement_policy="cpu")
    else:
        raise ValueError(f'Unsupported strategy "{strategy}"')
    return strategy_


def get_tokenizer_from_args(model: str, **kwargs):
    if model == "gpt2":
        tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    elif model == "bloom":
        tokenizer = BloomTokenizerFast.from_pretrained("bigscience/bloom-560m")
    elif model == "opt":
        tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
    elif model == "llama":
        pretrain_path = kwargs["pretrain"]
        tokenizer = AutoTokenizer.from_pretrained(pretrain_path)
    else:
        raise ValueError(f'Unsupported model "{model}"')

    tokenizer.pad_token = tokenizer.eos_token
    return tokenizer


def set_dist_env(env_info: Dict[str, str]):
    os.environ["RANK"] = env_info["rank"]
    os.environ["LOCAL_RANK"] = env_info["local_rank"]
    os.environ["WORLD_SIZE"] = env_info["world_size"]
    os.environ["MASTER_PORT"] = env_info["master_port"]
    os.environ["MASTER_ADDR"] = env_info["master_addr"]


def get_model_numel(model: nn.Module) -> int:
    numel = sum(p.numel() for p in model.parameters())
    return numel


def get_receivers_per_sender(sender_idx: int, num_senders: int, num_receivers: int, allow_idle_sender: bool) -> list:
    target_receivers = []
    if num_senders <= num_receivers or allow_idle_sender:
        # a sender will send data to one or more receivers
        # a receiver only has one sender
        for i in range(num_receivers):
            if i % num_senders == sender_idx:
                target_receivers.append(i)
    else:
        # a sender will send data to one receiver
        # a receiver may have more than one sender
        target_receivers.append(sender_idx % num_receivers)
    return target_receivers


def state_dict_to(state_dict: Dict[str, Any],
                  dtype: torch.dtype = torch.float16,
                  device: torch.device = torch.device("cpu")):
    """
    keep state_dict intact
    """
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        new_state_dict[k] = v.to(dtype=dtype, device=device)
    return new_state_dict
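Taken together, these helpers let the Ray-based scripts assemble a strategy, model, and tokenizer from plain string arguments. The following is only an illustrative sketch, not code from this commit: the argument values ("ddp", "gpt2") are examples, and it assumes it runs inside a distributed launch so that DDPStrategy can initialize the process group.

# Hedged usage sketch of the helpers above; argument values are examples only.
from coati.ray.utils import (get_actor_from_args, get_critic_from_args,
                             get_strategy_from_args, get_tokenizer_from_args)

strategy = get_strategy_from_args("ddp")          # or "colossalai_gemini", "colossalai_zero2", ...
with strategy.model_init_context():               # strategy-aware model construction context
    actor = get_actor_from_args("gpt2", pretrained="gpt2", lora_rank=0)
    critic = get_critic_from_args("gpt2", pretrained="gpt2", lora_rank=0)
tokenizer = get_tokenizer_from_args("gpt2")       # pad_token is set to eos_token by the helper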
applications/Chat/coati/replay_buffer/__init__.py  (deleted, mode 100644 → 0)

from .base import ReplayBuffer
from .naive import NaiveReplayBuffer

__all__ = ['ReplayBuffer', 'NaiveReplayBuffer']
applications/Chat/coati/trainer/__init__.py

-from .base import Trainer
+from .base import OnPolicyTrainer, SLTrainer
 from .ppo import PPOTrainer
 from .rm import RewardModelTrainer
 from .sft import SFTTrainer

-__all__ = ['Trainer', 'PPOTrainer', 'RewardModelTrainer', 'SFTTrainer']
+__all__ = ["SLTrainer", "OnPolicyTrainer", "RewardModelTrainer", "SFTTrainer", "PPOTrainer"]
applications/Chat/coati/trainer/base.py

 from abc import ABC, abstractmethod
-from typing import Any, Callable, Dict, List, Optional, Union
+from contextlib import contextmanager
+from typing import List

 import torch
 import torch.nn as nn
 import tqdm
+from coati.experience_buffer import NaiveExperienceBuffer
 from coati.experience_maker import Experience
+from torch.optim import Optimizer

 from .callbacks import Callback
 from .strategies import Strategy
 from .utils import is_rank_0


-class Trainer(ABC):
+class SLTrainer(ABC):
     """
-    Base class for rlhf trainers.
+    Base class for supervised learning trainers.

     Args:
         strategy (Strategy):the strategy to use for training
         max_epochs (int, defaults to 1): the number of epochs of training process
+        model (nn.Module): the model to train
+        optim (Optimizer): the optimizer to use for training
     """

+    def __init__(
+        self,
+        strategy: Strategy,
+        max_epochs: int,
+        model: nn.Module,
+        optimizer: Optimizer,
+    ) -> None:
+        super().__init__()
+        self.strategy = strategy
+        self.max_epochs = max_epochs
+        self.model = model
+        self.optimizer = optimizer
+
+    @abstractmethod
+    def _train(self, epoch):
+        raise NotImplementedError()
+
+    @abstractmethod
+    def _eval(self, epoch):
+        raise NotImplementedError()
+
+    def _before_fit(self):
+        raise NotImplementedError()
+
+    def fit(self, *args, **kwargs):
+        self._before_fit(*args, **kwargs)
+        for epoch in tqdm.trange(self.max_epochs, desc="Epochs", disable=not is_rank_0()):
+            self._train(epoch)
+            self._eval(epoch)
+
+
+class OnPolicyTrainer(ABC):
+    """
+    Base class for on-policy rl trainers, e.g. PPO.
+
+    Args:
+        strategy (Strategy):the strategy to use for training
+        data_buffer (NaiveExperienceBuffer): the buffer to collect experiences
+        sample_buffer (bool, defaults to False): whether to sample from buffer
         dataloader_pin_memory (bool, defaults to True): whether to pin memory for data loader
         callbacks (List[Callback], defaults to []): the callbacks to call during training process
-        generate_kwargs (dict, optional): the kwargs to use while model generating
     """

-    def __init__(self,
-                 strategy: Strategy,
-                 max_epochs: int = 1,
-                 dataloader_pin_memory: bool = True,
-                 callbacks: List[Callback] = [],
-                 **generate_kwargs) -> None:
+    def __init__(
+        self,
+        strategy: Strategy,
+        data_buffer: NaiveExperienceBuffer,
+        sample_buffer: bool,
+        dataloader_pin_memory: bool,
+        callbacks: List[Callback] = [],
+    ) -> None:
         super().__init__()
         self.strategy = strategy
-        self.max_epochs = max_epochs
-        self.generate_kwargs = generate_kwargs
+        self.data_buffer = data_buffer
+        self.sample_buffer = sample_buffer
         self.dataloader_pin_memory = dataloader_pin_memory
         self.callbacks = callbacks

-    # TODO(ver217): maybe simplify these code using context
-    def _on_fit_start(self) -> None:
+    @contextmanager
+    def _fit_ctx(self) -> None:
         for callback in self.callbacks:
             callback.on_fit_start()
-
-    def _on_fit_end(self) -> None:
-        for callback in self.callbacks:
-            callback.on_fit_end()
-
-    def _on_episode_start(self, episode: int) -> None:
+        try:
+            yield
+        finally:
+            for callback in self.callbacks:
+                callback.on_fit_end()
+
+    @contextmanager
+    def _episode_ctx(self, episode: int) -> None:
         for callback in self.callbacks:
             callback.on_episode_start(episode)
-
-    def _on_episode_end(self, episode: int) -> None:
-        for callback in self.callbacks:
-            callback.on_episode_end(episode)
+        try:
+            yield
+        finally:
+            for callback in self.callbacks:
+                callback.on_episode_end(episode)

     def _on_make_experience_start(self) -> None:
         for callback in self.callbacks:

@@ -70,6 +122,67 @@ class Trainer(ABC):
         for callback in self.callbacks:
             callback.on_learn_batch_start()

-    def _on_learn_batch_end(self, metrics: dict, experience: Experience) -> None:
+    def _on_learn_batch_end(self, experience: Experience) -> None:
         for callback in self.callbacks:
-            callback.on_learn_batch_end(metrics, experience)
+            callback.on_learn_batch_end(experience)
+
+    @abstractmethod
+    def _make_experience(self, collect_step: int):
+        """
+        Implement this method to make experience.
+        """
+        raise NotImplementedError()
+
+    @abstractmethod
+    def _learn(self, update_step: int):
+        """
+        Implement this method to learn from experience, either
+        sample from buffer or transform buffer into dataloader.
+        """
+        raise NotImplementedError()
+
+    def _collect_phase(self, collect_step: int):
+        self._on_make_experience_start()
+        experience = self._make_experience(collect_step)
+        self._on_make_experience_end(experience)
+        self.data_buffer.append(experience)
+
+    def _update_phase(self, update_step: int):
+        self._on_learn_epoch_start(update_step)
+        self._learn(update_step)
+        self._on_learn_epoch_end(update_step)
+
+    def _before_fit(self, *args, **kwargs):
+        raise NotImplementedError()
+
+    def fit(
+        self,
+        num_episodes: int,
+        num_collect_steps: int,
+        num_update_steps: int,
+        *args,
+        **kwargs,
+    ):
+        """
+        The main training loop of on-policy rl trainers.
+
+        Args:
+            num_episodes (int): the number of episodes to train
+            num_collect_steps (int): the number of collect steps per episode
+            num_update_steps (int): the number of update steps per episode
+        """
+        self._before_fit(*args, **kwargs)
+        with self._fit_ctx():
+            for episode in tqdm.trange(num_episodes, desc="Episodes", disable=not is_rank_0()):
+                with self._episode_ctx(episode):
+                    for collect_step in tqdm.trange(num_collect_steps, desc="Collect steps", disable=not is_rank_0()):
+                        self._collect_phase(collect_step)
+                    if not self.sample_buffer:
+                        # HACK(cwher): according to the design of boost API, dataloader should also be boosted,
+                        # but it is impractical to adapt this pattern in RL training. Thus, I left dataloader unboosted.
+                        # I only call strategy.setup_dataloader() to setup dataloader.
+                        self.dataloader = self.strategy.setup_dataloader(self.data_buffer, self.dataloader_pin_memory)
+                    for update_step in tqdm.trange(num_update_steps, desc="Update steps", disable=not is_rank_0()):
+                        self._update_phase(update_step)
+                    # NOTE: this is for on-policy algorithms
+                    self.data_buffer.clear()
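The notable design change above is that the paired _on_*_start/_on_*_end callback methods became contextmanager-based wrappers (_fit_ctx, _episode_ctx), so the "end" hooks run even if the loop body raises. Below is a standalone, hedged sketch of that pattern with toy names; it is not the trainer code itself, just an isolated illustration of the idiom.

# Standalone sketch of the context-manager callback idiom used by OnPolicyTrainer.
# Class and method names here are illustrative, not part of the commit.
from contextlib import contextmanager


class PrintCallback:
    def on_episode_start(self, episode: int) -> None:
        print(f"episode {episode} started")

    def on_episode_end(self, episode: int) -> None:
        print(f"episode {episode} finished")


class Loop:
    def __init__(self, callbacks):
        self.callbacks = callbacks

    @contextmanager
    def _episode_ctx(self, episode: int):
        for cb in self.callbacks:
            cb.on_episode_start(episode)
        try:
            yield
        finally:
            # end hooks fire even if the body of the with-block raises
            for cb in self.callbacks:
                cb.on_episode_end(episode)


loop = Loop([PrintCallback()])
for episode in range(2):
    with loop._episode_ctx(episode):
        pass  # collect/update phases would run here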
applications/Chat/coati/trainer/callbacks/__init__.py

@@ -2,4 +2,4 @@ from .base import Callback
 from .performance_evaluator import PerformanceEvaluator
 from .save_checkpoint import SaveCheckpoint

-__all__ = ['Callback', 'PerformanceEvaluator', 'SaveCheckpoint']
+__all__ = ["Callback", "PerformanceEvaluator", "SaveCheckpoint"]
applications/Chat/coati/trainer/callbacks/base.py

@@ -5,7 +5,7 @@ from coati.experience_maker import Experience
 class Callback(ABC):
     """
-    Base callback class. It defines the interface for callbacks.
+    Base callback class. It defines the interface for callbacks.
     """

     def on_fit_start(self) -> None:

@@ -35,5 +35,5 @@ class Callback(ABC):
     def on_learn_batch_start(self) -> None:
         pass

-    def on_learn_batch_end(self, metrics: dict, experience: Experience) -> None:
+    def on_learn_batch_end(self, experience: Experience) -> None:
         pass
applications/Chat/coati/trainer/callbacks/performance_evaluator.py

@@ -21,9 +21,9 @@ def print_rank_0(*args, **kwargs) -> None:
 def divide(x: float, y: float) -> float:
     if y == 0:
-        return float('inf')
-    elif y == float('inf'):
-        return float('nan')
+        return float("inf")
+    elif y == float("inf"):
+        return float("nan")
     return x / y

@@ -38,10 +38,9 @@ def all_reduce_mean(x: float, world_size: int) -> float:
 class Timer:
     def __init__(self) -> None:
         self.start_time: Optional[float] = None
-        self.duration: float = 0.
+        self.duration: float = 0.0

     def start(self) -> None:
         self.start_time = time()

@@ -52,7 +51,7 @@ class Timer:
         self.start_time = None

     def reset(self) -> None:
-        self.duration = 0.
+        self.duration = 0.0


 class PerformanceEvaluator(Callback):

@@ -67,13 +66,15 @@ class PerformanceEvaluator(Callback):
         ignore_episodes: The number of episodes to ignore when calculating the performance.
     """

-    def __init__(self,
-                 actor_num_params: int,
-                 critic_num_params: int,
-                 initial_model_num_params: int,
-                 reward_model_num_params: int,
-                 enable_grad_checkpoint: bool = False,
-                 ignore_episodes: int = 0) -> None:
+    def __init__(
+        self,
+        actor_num_params: int,
+        critic_num_params: int,
+        initial_model_num_params: int,
+        reward_model_num_params: int,
+        enable_grad_checkpoint: bool = False,
+        ignore_episodes: int = 0,
+    ) -> None:
         super().__init__()
         self.world_size = get_world_size()
         self.actor_num_params = actor_num_params

@@ -136,7 +137,7 @@ class PerformanceEvaluator(Callback):
             return
         self.learn_timer.start()

-    def on_learn_batch_end(self, metrics: dict, experience: Experience) -> None:
+    def on_learn_batch_end(self, experience: Experience) -> None:
         if self.disable:
             return
         self.learn_timer.end()

@@ -155,8 +156,9 @@ class PerformanceEvaluator(Callback):
         avg_learn_duration = all_reduce_mean(self.learn_timer.duration, self.world_size)
         avg_overall_duration = all_reduce_mean(self.overall_timer.duration, self.world_size)

-        avg_make_experience_throughput = self.make_experience_num_samples * \
-            self.world_size / (avg_make_experience_duration + 1e-12)
+        avg_make_experience_throughput = (
+            self.make_experience_num_samples * self.world_size / (avg_make_experience_duration + 1e-12)
+        )
         avg_make_experience_tflops = self.make_experience_flop / 1e12 / (avg_make_experience_duration + 1e-12)

         avg_learn_throughput = self.learn_num_samples * self.world_size / (avg_learn_duration + 1e-12)

@@ -171,13 +173,11 @@ class PerformanceEvaluator(Callback):
         learn_time_per_sample = divide(avg_learn_duration, num_effective_samples)

         print_rank_0(
-            f'Performance summary:\n' +
-            f'Generate {self.make_experience_num_samples * self.world_size} samples, throughput: {avg_make_experience_throughput:.2f} samples/s, TFLOPS per GPU: {avg_make_experience_tflops:.2f}\n' +
-            f'Train {self.learn_num_samples * self.world_size} samples, throughput: {avg_learn_throughput:.2f} samples/s, TFLOPS per GPU: {avg_learn_tflops:.2f}\n' +
-            f'Overall throughput: {avg_overall_throughput:.2f} samples/s\n' +
-            f'Overall time per sample: {overall_time_per_sample:.2f} s\n' +
-            f'Make experience time per sample: {make_experience_time_per_sample:.2f} s, {make_experience_time_per_sample / overall_time_per_sample * 100:.2f}%\n' +
-            f'Learn time per sample: {learn_time_per_sample:.2f} s, {learn_time_per_sample / overall_time_per_sample * 100:.2f}%'
+            f"Performance summary:\n"
+            + f"Generate {self.make_experience_num_samples * self.world_size} samples, throughput: {avg_make_experience_throughput:.2f} samples/s, TFLOPS per GPU: {avg_make_experience_tflops:.2f}\n"
+            + f"Train {self.learn_num_samples * self.world_size} samples, throughput: {avg_learn_throughput:.2f} samples/s, TFLOPS per GPU: {avg_learn_tflops:.2f}\n"
+            + f"Overall throughput: {avg_overall_throughput:.2f} samples/s\n"
+            + f"Overall time per sample: {overall_time_per_sample:.2f} s\n"
+            + f"Make experience time per sample: {make_experience_time_per_sample:.2f} s, {make_experience_time_per_sample / overall_time_per_sample * 100:.2f}%\n"
+            + f"Learn time per sample: {learn_time_per_sample:.2f} s, {learn_time_per_sample / overall_time_per_sample * 100:.2f}%"
         )
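The throughput summary printed above follows directly from the timers: per-phase throughput is the samples seen on one rank, times the world size, divided by the all-reduced mean duration. A small self-contained arithmetic check of that formula, using example numbers only (not measurements from this commit):

# Example-only numbers to sanity check the throughput formula used above:
# throughput = num_samples_per_rank * world_size / mean_duration
world_size = 4
make_experience_num_samples = 128      # samples generated on one rank
avg_make_experience_duration = 16.0    # seconds, already averaged across ranks

throughput = make_experience_num_samples * world_size / (avg_make_experience_duration + 1e-12)
print(f"{throughput:.2f} samples/s")   # 32.00 samples/s across the whole job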
applications/Chat/coati/trainer/callbacks/save_checkpoint.py

 import os

 import torch.distributed as dist
-from coati.trainer.strategies import ColossalAIStrategy, Strategy
+from coati.trainer.strategies import GeminiStrategy, LowLevelZeroStrategy, Strategy
 from coati.trainer.utils import is_rank_0
 from torch import nn
 from torch.optim import Optimizer

@@ -36,40 +36,41 @@ class SaveCheckpoint(Callback):
     """

-    def __init__(self,
-                 path: str,
-                 interval: int,
-                 strategy: Strategy,
-                 actor: nn.Module = None,
-                 critic: nn.Module = None,
-                 actor_optim: Optimizer = None,
-                 critic_optim: Optimizer = None) -> None:
+    def __init__(
+        self,
+        path: str,
+        interval: int,
+        strategy: Strategy,
+        actor: nn.Module = None,
+        critic: nn.Module = None,
+        actor_optim: Optimizer = None,
+        critic_optim: Optimizer = None,
+    ) -> None:
         super().__init__()
-        self.path = os.path.join(path, 'checkpoint')
+        self.path = os.path.join(path, "checkpoint")
         self.interval = interval
         self.strategy = strategy
-        self.model_dict = {'actor': [actor, actor_optim], 'critic': [critic, critic_optim]}
+        self.model_dict = {"actor": [actor, actor_optim], "critic": [critic, critic_optim]}

     def on_episode_end(self, episode: int) -> None:
         if (episode + 1) % self.interval != 0:
             return
-        base_path = os.path.join(self.path, f'episode_{episode}')
+        base_path = os.path.join(self.path, f"episode_{episode}")
         if not os.path.exists(base_path):
             os.makedirs(base_path)

         for model in self.model_dict.keys():
             # save model
             if self.model_dict[model][0] is None:
                 # saving only optimizer states is meaningless, so it would be skipped
                 continue
-            model_path = os.path.join(base_path, f'{model}.pt')
+            model_path = os.path.join(base_path, f"{model}.pt")
             self.strategy.save_model(model=self.model_dict[model][0], path=model_path, only_rank0=True)

             # save optimizer
             if self.model_dict[model][1] is None:
                 continue
-            only_rank0 = not isinstance(self.strategy, ColossalAIStrategy)
+            only_rank0 = not isinstance(self.strategy, (LowLevelZeroStrategy, GeminiStrategy))
             rank = 0 if is_rank_0() else dist.get_rank()
-            optim_path = os.path.join(base_path, f'{model}-optim-rank-{rank}.pt')
+            optim_path = os.path.join(base_path, f"{model}-optim-rank-{rank}.pt")
             self.strategy.save_optimizer(optimizer=self.model_dict[model][1], path=optim_path, only_rank0=only_rank0)
applications/Chat/coati/trainer/ppo.py

-from typing import Any, Callable, Dict, List, Optional, Union
+from typing import Dict, List, Optional

 import torch
 import torch.nn as nn
+from coati.experience_buffer import NaiveExperienceBuffer
 from coati.experience_maker import Experience, NaiveExperienceMaker
-from coati.models.base import Actor, Critic
+from coati.models.base import Actor, Critic, RewardModel, get_base_model
 from coati.models.loss import GPTLMLoss, PolicyLoss, ValueLoss
-from coati.replay_buffer import NaiveReplayBuffer
-from torch import Tensor
+from coati.models.utils import calc_action_log_probs
 from torch.optim import Optimizer
-from torch.utils.data import DistributedSampler
+from torch.utils.data import DataLoader, DistributedSampler
 from tqdm import tqdm
-from transformers.tokenization_utils_base import PreTrainedTokenizerBase
+from transformers import PreTrainedTokenizerBase

-from .base import Trainer
+from colossalai.utils import get_current_device
+
+from .base import OnPolicyTrainer
 from .callbacks import Callback
-from .strategies import Strategy
-from .utils import is_rank_0, to_device
+from .strategies import GeminiStrategy, Strategy
+from .utils import CycledDataLoader, is_rank_0, to_device


-class PPOTrainer(Trainer):
+def _set_default_generate_kwargs(strategy: Strategy, generate_kwargs: dict, actor: Actor) -> Dict:
+    unwrapped_model = strategy.unwrap_model(actor)
+    hf_model = get_base_model(unwrapped_model)
+    new_kwargs = {**generate_kwargs}
+    # use huggingface models method directly
+    if "prepare_inputs_fn" not in generate_kwargs and hasattr(hf_model, "prepare_inputs_for_generation"):
+        new_kwargs["prepare_inputs_fn"] = hf_model.prepare_inputs_for_generation
+    if "update_model_kwargs_fn" not in generate_kwargs and hasattr(hf_model, "_update_model_kwargs_for_generation"):
+        new_kwargs["update_model_kwargs_fn"] = hf_model._update_model_kwargs_for_generation
+    return new_kwargs
+
+
+class PPOTrainer(OnPolicyTrainer):
     """
     Trainer for PPO algorithm.

@@ -28,60 +40,61 @@ class PPOTrainer(Trainer):
         strategy (Strategy): the strategy to use for training
         actor (Actor): the actor model in ppo algorithm
         critic (Critic): the critic model in ppo algorithm
-        reward_model (nn.Module): the reward model in rlhf algorithm to make reward of sentences
+        reward_model (RewardModel): the reward model in rlhf algorithm to make reward of sentences
         initial_model (Actor): the initial model in rlhf algorithm to generate reference logics to limit the update of actor
         actor_optim (Optimizer): the optimizer to use for actor model
         critic_optim (Optimizer): the optimizer to use for critic model
         kl_coef (float, defaults to 0.1): the coefficient of kl divergence loss
         train_batch_size (int, defaults to 8): the batch size to use for training
-        buffer_limit (int, defaults to 0): the max_size limitation of replay buffer
-        buffer_cpu_offload (bool, defaults to True): whether to offload replay buffer to cpu
+        buffer_limit (int, defaults to 0): the max_size limitation of buffer
+        buffer_cpu_offload (bool, defaults to True): whether to offload buffer to cpu
         eps_clip (float, defaults to 0.2): the clip coefficient of policy loss
         vf_coef (float, defaults to 1.0): the coefficient of value loss
         ptx_coef (float, defaults to 0.9): the coefficient of ptx loss
         value_clip (float, defaults to 0.4): the clip coefficient of value loss
-        max_epochs (int, defaults to 1): the number of epochs of training process
-        sample_replay_buffer (bool, defaults to False): whether to sample from replay buffer
+        sample_buffer (bool, defaults to False): whether to sample from buffer
         dataloader_pin_memory (bool, defaults to True): whether to pin memory for data loader
         offload_inference_models (bool, defaults to True): whether to offload inference models to cpu during training process
         callbacks (List[Callback], defaults to []): the callbacks to call during training process
         generate_kwargs (dict, optional): the kwargs to use while model generating
     """

-    def __init__(self,
-                 strategy: Strategy,
-                 actor: Actor,
-                 critic: Critic,
-                 reward_model: nn.Module,
-                 initial_model: Actor,
-                 actor_optim: Optimizer,
-                 critic_optim: Optimizer,
-                 kl_coef: float = 0.1,
-                 ptx_coef: float = 0.9,
-                 train_batch_size: int = 8,
-                 buffer_limit: int = 0,
-                 buffer_cpu_offload: bool = True,
-                 eps_clip: float = 0.2,
-                 vf_coef: float = 1.0,
-                 value_clip: float = 0.4,
-                 max_epochs: int = 1,
-                 sample_replay_buffer: bool = False,
-                 dataloader_pin_memory: bool = True,
-                 offload_inference_models: bool = True,
-                 callbacks: List[Callback] = [],
-                 **generate_kwargs) -> None:
-        experience_maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model, kl_coef)
-        replay_buffer = NaiveReplayBuffer(train_batch_size, buffer_limit, buffer_cpu_offload)
-        generate_kwargs = _set_default_generate_kwargs(strategy, generate_kwargs, actor)
-        super().__init__(strategy, max_epochs, dataloader_pin_memory, callbacks, **generate_kwargs)
-        self.experience_maker = experience_maker
-        self.replay_buffer = replay_buffer
-        self.sample_replay_buffer = sample_replay_buffer
-        self.offload_inference_models = offload_inference_models
+    def __init__(
+        self,
+        strategy: Strategy,
+        actor: Actor,
+        critic: Critic,
+        reward_model: RewardModel,
+        initial_model: Actor,
+        actor_optim: Optimizer,
+        critic_optim: Optimizer,
+        tokenizer: PreTrainedTokenizerBase,
+        kl_coef: float = 0.1,
+        ptx_coef: float = 0.9,
+        train_batch_size: int = 8,
+        buffer_limit: int = 0,
+        buffer_cpu_offload: bool = True,
+        eps_clip: float = 0.2,
+        vf_coef: float = 1.0,
+        value_clip: float = 0.4,
+        sample_buffer: bool = False,
+        dataloader_pin_memory: bool = True,
+        offload_inference_models: bool = True,
+        callbacks: List[Callback] = [],
+        **generate_kwargs,
+    ) -> None:
+        if isinstance(strategy, GeminiStrategy):
+            assert not offload_inference_models, "GeminiPlugin is not compatible with manual model.to('cpu')"
+
+        data_buffer = NaiveExperienceBuffer(train_batch_size, buffer_limit, buffer_cpu_offload)
+        super().__init__(strategy, data_buffer, sample_buffer, dataloader_pin_memory, callbacks)
+
+        self.generate_kwargs = _set_default_generate_kwargs(strategy, generate_kwargs, actor)
+        self.experience_maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model, tokenizer, kl_coef)
+
+        self.actor = actor
+        self.critic = critic
+        self.tokenizer = tokenizer

         self.actor_loss_fn = PolicyLoss(eps_clip)
         self.critic_loss_fn = ValueLoss(value_clip)

@@ -91,123 +104,99 @@ class PPOTrainer(Trainer):
         self.actor_optim = actor_optim
         self.critic_optim = critic_optim

         self.offload_inference_models = offload_inference_models
+        self.device = get_current_device()

-    def _make_experience(self, inputs: Union[Tensor, Dict[str, Tensor]]) -> Experience:
-        if isinstance(inputs, Tensor):
-            return self.experience_maker.make_experience(inputs, **self.generate_kwargs)
-        elif isinstance(inputs, dict):
-            return self.experience_maker.make_experience(**inputs, **self.generate_kwargs)
-        else:
-            raise ValueError(f'Unsupported input type "{type(inputs)}"')
-
-    def _learn(self):
-        # replay buffer may be empty at first, we should rebuild at each training
-        if not self.sample_replay_buffer:
-            dataloader = self.strategy.setup_dataloader(self.replay_buffer, self.dataloader_pin_memory)
-        if self.sample_replay_buffer:
-            pbar = tqdm(range(self.max_epochs), desc='Train epoch', disable=not is_rank_0())
-            for _ in pbar:
-                experience = self.replay_buffer.sample()
-                experience.to_device(self.device)
-                metrics = self.training_step(experience)
-                pbar.set_postfix(metrics)
-        else:
-            for epoch in range(self.max_epochs):
-                self._on_learn_epoch_start(epoch)
-                if isinstance(dataloader.sampler, DistributedSampler):
-                    dataloader.sampler.set_epoch(epoch)
-                pbar = tqdm(dataloader, desc=f'Train epoch [{epoch + 1}/{self.max_epochs}]', disable=not is_rank_0())
-                for experience in pbar:
-                    self._on_learn_batch_start()
-                    experience.to_device(self.device)
-                    metrics = self.training_step(experience)
-                    self._on_learn_batch_end(metrics, experience)
-                    pbar.set_postfix(metrics)
-                self._on_learn_epoch_end(epoch)
-
-    def fit(self,
-            prompt_dataloader,
-            pretrain_dataloader,
-            num_episodes: int = 50000,
-            max_timesteps: int = 500,
-            update_timesteps: int = 5000) -> None:
-        time = 0
-        self.pretrain_dataloader = pretrain_dataloader
-        self.prompt_dataloader = prompt_dataloader
-        self._on_fit_start()
-        for episode in range(num_episodes):
-            self._on_episode_start(episode)
-            for timestep in tqdm(range(max_timesteps),
-                                 desc=f'Episode [{episode + 1}/{num_episodes}]',
-                                 disable=not is_rank_0()):
-                time += 1
-                prompts = next(iter(self.prompt_dataloader))
-                self._on_make_experience_start()
-                if self.offload_inference_models:
-                    # TODO(ver217): this may be controlled by strategy if they are prepared by strategy
-                    self.experience_maker.initial_model.to(self.device)
-                    self.experience_maker.reward_model.to(self.device)
-                experience = self._make_experience(prompts)
-                self._on_make_experience_end(experience)
-                self.replay_buffer.append(experience)
-                if time % update_timesteps == 0:
-                    if self.offload_inference_models:
-                        self.experience_maker.initial_model.to('cpu')
-                        self.experience_maker.reward_model.to('cpu')
-                    self._learn()
-                    self.replay_buffer.clear()
-            self._on_episode_end(episode)
-        self._on_fit_end()
-
-    def training_step(self, experience: Experience) -> Dict[str, float]:
+    def _before_fit(
+        self,
+        prompt_dataloader: DataLoader,
+        pretrain_dataloader: DataLoader,
+        log_dir: Optional[str] = None,
+        use_wandb: bool = False,
+    ):
+        """
+        Args:
+            prompt_dataloader (DataLoader): the dataloader to use for prompt data
+            pretrain_dataloader (DataLoader): the dataloader to use for pretrain data
+        """
+        self.prompt_dataloader = CycledDataLoader(prompt_dataloader)
+        self.pretrain_dataloader = CycledDataLoader(pretrain_dataloader)
+
+        self.writer = None
+        if use_wandb and is_rank_0():
+            assert log_dir is not None, "log_dir must be provided when use_wandb is True"
+            import wandb
+
+            wandb.init(project="Coati-ppo", sync_tensorboard=True)
+        if log_dir is not None and is_rank_0():
+            import os
+            import time
+
+            from torch.utils.tensorboard import SummaryWriter
+
+            log_dir = os.path.join(log_dir, "ppo")
+            log_dir = os.path.join(log_dir, time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime()))
+            self.writer = SummaryWriter(log_dir=log_dir)
+
+    def _make_experience(self, collect_step: int) -> Experience:
+        prompts = self.prompt_dataloader.next()
+        if self.offload_inference_models:
+            # TODO(ver217): this may be controlled by strategy if they are prepared by strategy
+            self.experience_maker.initial_model.to(self.device)
+            self.experience_maker.reward_model.to(self.device)
+        assert isinstance(prompts, dict), f'Unsupported input type "{type(prompts)}"'
+        return self.experience_maker.make_experience(**prompts, **self.generate_kwargs)
+
+    def _training_step(self, experience: Experience):
         self.actor.train()
         self.critic.train()
         # policy loss
-        num_actions = experience.action_mask.size(1)
-        action_log_probs = self.actor(experience.sequences, num_actions, attention_mask=experience.attention_mask)
-        actor_loss = self.actor_loss_fn(action_log_probs,
-                                        experience.action_log_probs,
-                                        experience.advantages,
-                                        action_mask=experience.action_mask)
+        num_actions = experience.action_log_probs.size(1)
+        actor_logits = self.actor(experience.sequences, experience.attention_mask)["logits"]
+        action_log_probs = calc_action_log_probs(actor_logits, experience.sequences, num_actions)
+        actor_loss = self.actor_loss_fn(
+            action_log_probs, experience.action_log_probs, experience.advantages, action_mask=experience.action_mask
+        )
+        actor_loss = (1 - self.ptx_coef) * actor_loss
+        self.strategy.backward(actor_loss, self.actor, self.actor_optim)

         # ptx loss
         if self.ptx_coef != 0:
-            batch = next(iter(self.pretrain_dataloader))
+            batch = self.pretrain_dataloader.next()
             batch = to_device(batch, self.device)
-            ptx_log_probs = self.actor.get_base_model()(batch['input_ids'], attention_mask=batch['attention_mask'])['logits']
-            ptx_loss = self.ptx_loss_fn(ptx_log_probs, batch['labels'])
-            actor_loss = ptx_loss * self.ptx_coef + actor_loss * (1 - self.ptx_coef)
+            ptx_log_probs = self.actor(batch["input_ids"], batch["attention_mask"])["logits"]
+            ptx_loss = self.ptx_coef * self.ptx_loss_fn(ptx_log_probs, batch["labels"])
+            self.strategy.backward(ptx_loss, self.actor, self.actor_optim)

-        self.strategy.backward(actor_loss, self.actor, self.actor_optim)
         self.strategy.optimizer_step(self.actor_optim)
         self.actor_optim.zero_grad()

         # value loss
-        values = self.critic(experience.sequences,
-                             action_mask=experience.action_mask,
-                             attention_mask=experience.attention_mask)
-        critic_loss = self.critic_loss_fn(values,
-                                          experience.values,
-                                          experience.reward,
-                                          action_mask=experience.action_mask)
+        values = self.critic(experience.sequences, attention_mask=experience.attention_mask)
+        critic_loss = self.critic_loss_fn(values, experience.values, experience.reward)
+        critic_loss = critic_loss * self.vf_coef
         self.strategy.backward(critic_loss, self.critic, self.critic_optim)
         self.strategy.optimizer_step(self.critic_optim)
         self.critic_optim.zero_grad()

-        return {'reward': experience.reward.mean().item()}
-
-
-def _set_default_generate_kwargs(strategy: Strategy, generate_kwargs: dict, actor: Actor) -> None:
-    origin_model = strategy.unwrap_model(actor)
-    new_kwargs = {**generate_kwargs}
-    # use huggingface models method directly
-    if 'prepare_inputs_fn' not in generate_kwargs and hasattr(origin_model, 'prepare_inputs_for_generation'):
-        new_kwargs['prepare_inputs_fn'] = origin_model.prepare_inputs_for_generation
-    if 'update_model_kwargs_fn' not in generate_kwargs and hasattr(origin_model, '_update_model_kwargs_for_generation'):
-        new_kwargs['update_model_kwargs_fn'] = origin_model._update_model_kwargs_for_generation
-    return new_kwargs
+    def _learn(self, update_step: int):
+        if self.offload_inference_models:
+            self.experience_maker.initial_model.to("cpu")
+            self.experience_maker.reward_model.to("cpu")
+
+        # buffer may be empty at first, we should rebuild at each training
+        if self.sample_buffer:
+            experience = self.data_buffer.sample()
+            self._on_learn_batch_start()
+            experience.to_device(self.device)
+            self._training_step(experience)
+            self._on_learn_batch_end(experience)
+        else:
+            if isinstance(self.dataloader.sampler, DistributedSampler):
+                self.dataloader.sampler.set_epoch(update_step)
+            pbar = tqdm(self.dataloader, desc=f"Train epoch [{update_step + 1}]", disable=not is_rank_0())
+            for experience in pbar:
+                self._on_learn_batch_start()
+                experience.to_device(self.device)
+                self._training_step(experience)
+                self._on_learn_batch_end(experience)
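The new PPO trainer pulls prompt and pretrain batches through a CycledDataLoader, which lives in coati/trainer/utils.py (changed in this commit but not shown on this page). The sketch below is a hedged guess at the behavior the trainer relies on, namely "return the next batch, restarting from the top when the loader is exhausted"; the actual implementation in the repository may differ.

# Hedged sketch of a CycledDataLoader-style wrapper (assumed behavior, not the
# committed implementation): .next() never raises StopIteration, it restarts
# a fresh pass over the dataset instead.
from torch.utils.data import DataLoader


class CycledDataLoader:
    def __init__(self, dataloader: DataLoader) -> None:
        self.dataloader = dataloader
        self.dataloader_iter = iter(dataloader)

    def next(self):
        try:
            return next(self.dataloader_iter)
        except StopIteration:
            # restart from the beginning of the dataset
            self.dataloader_iter = iter(self.dataloader)
            return next(self.dataloader_iter)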
applications/Chat/coati/trainer/rm.py

-from datetime import datetime
-from typing import List, Optional
+from typing import Callable, Optional

-import pandas as pd
 import torch
-import torch.distributed as dist
-from torch.optim import Optimizer, lr_scheduler
-from torch.utils.data import DataLoader, Dataset, DistributedSampler
-from tqdm import tqdm
-from transformers.tokenization_utils_base import PreTrainedTokenizerBase
-
-from .base import Trainer
-from .callbacks import Callback
+import tqdm
+from torch.optim import Optimizer
+from torch.optim.lr_scheduler import _LRScheduler
+from torch.utils.data import DataLoader
+
+from .base import SLTrainer
 from .strategies import Strategy
 from .utils import is_rank_0


-class RewardModelTrainer(Trainer):
+class RewardModelTrainer(SLTrainer):
     """
     Trainer to use while training reward model.

     Args:
         model (torch.nn.Module): the model to train
         strategy (Strategy): the strategy to use for training
-        optim(Optimizer): the optimizer to use for training
+        optim (Optimizer): the optimizer to use for training
+        lr_scheduler (_LRScheduler): the lr scheduler to use for training
         loss_fn (callable): the loss function to use for training
-        train_dataloader (DataLoader): the dataloader to use for training
-        valid_dataloader (DataLoader): the dataloader to use for validation
-        eval_dataloader (DataLoader): the dataloader to use for evaluation
-        batch_size (int, defaults to 1): the batch size while training
         max_epochs (int, defaults to 2): the number of epochs to train
-        callbacks (List[Callback], defaults to []): the callbacks to call during training process
     """

     def __init__(

@@ -37,87 +29,95 @@ class RewardModelTrainer(Trainer):
         model,
         strategy: Strategy,
         optim: Optimizer,
-        loss_fn,
-        train_dataloader: DataLoader,
-        valid_dataloader: DataLoader,
-        eval_dataloader: DataLoader,
+        lr_scheduler: _LRScheduler,
+        loss_fn: Callable,
         max_epochs: int = 1,
-        callbacks: List[Callback] = [],
     ) -> None:
-        super().__init__(strategy, max_epochs, callbacks=callbacks)
+        super().__init__(strategy, max_epochs, model, optim)
+        self.loss_fn = loss_fn
+        self.scheduler = lr_scheduler
+        self.num_train_step = 0
+
+    def _eval(self, epoch):
+        if self.eval_dataloader is not None:
+            self.model.eval()
+            dist, num_correct, num_samples = 0, 0, 0
+            with torch.no_grad():
+                for chosen_ids, c_mask, reject_ids, r_mask in self.eval_dataloader:
+                    chosen_ids = chosen_ids.squeeze(1).to(torch.cuda.current_device())
+                    c_mask = c_mask.squeeze(1).to(torch.cuda.current_device())
+                    reject_ids = reject_ids.squeeze(1).to(torch.cuda.current_device())
+                    r_mask = r_mask.squeeze(1).to(torch.cuda.current_device())
+                    chosen_reward = self.model(chosen_ids, attention_mask=c_mask)
+                    reject_reward = self.model(reject_ids, attention_mask=r_mask)
+                    num_samples += chosen_ids.size(0)
+                    num_correct += (chosen_reward > reject_reward).sum().item()
+                    dist += (chosen_reward - reject_reward).mean().item()
+                self.dist = dist / len(self.eval_dataloader)
+                self.acc = num_correct / num_samples
+
+            if self.writer:
+                self.writer.add_scalar("eval/dist", self.dist, epoch)
+                self.writer.add_scalar("eval/acc", self.acc, epoch)
+
+    def _train(self, epoch):
+        self.model.train()
+        step_bar = tqdm.trange(
+            len(self.train_dataloader), desc=f"Epoch {epoch + 1}/{self.max_epochs}", disable=not is_rank_0()
+        )
+        for chosen_ids, c_mask, reject_ids, r_mask in self.train_dataloader:
+            chosen_ids = chosen_ids.squeeze(1).to(torch.cuda.current_device())
+            c_mask = c_mask.squeeze(1).to(torch.cuda.current_device())
+            reject_ids = reject_ids.squeeze(1).to(torch.cuda.current_device())
+            r_mask = r_mask.squeeze(1).to(torch.cuda.current_device())
+            chosen_reward = self.model(chosen_ids, attention_mask=c_mask)
+            reject_reward = self.model(reject_ids, attention_mask=r_mask)
+            loss = self.loss_fn(chosen_reward, reject_reward)
+            self.strategy.backward(loss, self.model, self.optimizer)
+            self.strategy.optimizer_step(self.optimizer)
+            self.optimizer.zero_grad()
+            if self.writer:
+                self.writer.add_scalar("train/loss", loss.item(), self.num_train_step)
+                self.writer.add_scalar("train/lr", self.optimizer.param_groups[0]["lr"], self.num_train_step)
+                self.writer.add_scalar("train/dist", (chosen_reward - reject_reward).mean().item(), self.num_train_step)
+                self.writer.add_scalar("train/acc", (chosen_reward > reject_reward).float().mean().item(), self.num_train_step)
+            self.num_train_step += 1
+            if self.num_train_step % 100 == 0:
+                self.scheduler.step()
+            step_bar.update()
+        step_bar.close()
+
+    def _before_fit(
+        self,
+        train_dataloader: DataLoader,
+        eval_dataloader: DataLoader,
+        log_dir: Optional[str] = None,
+        use_wandb: bool = False,
+    ):
+        """
+        Args:
+            train_dataloader (DataLoader): the dataloader to use for training
+            eval_dataloader (DataLoader): the dataloader to use for evaluation
+        """
         self.train_dataloader = train_dataloader
-        self.valid_dataloader = valid_dataloader
         self.eval_dataloader = eval_dataloader
-        self.model = model
-        self.loss_fn = loss_fn
-        self.optimizer = optim
-        self.scheduler = lr_scheduler.CosineAnnealingLR(self.optimizer, self.train_dataloader.__len__() // 100)
-
-    def eval_acc(self, dataloader):
-        dist = 0
-        on = 0
-        cnt = 0
-        self.model.eval()
-        with torch.no_grad():
-            for chosen_ids, c_mask, reject_ids, r_mask in dataloader:
-                chosen_ids = chosen_ids.squeeze(1).to(torch.cuda.current_device())
-                c_mask = c_mask.squeeze(1).to(torch.cuda.current_device())
-                reject_ids = reject_ids.squeeze(1).to(torch.cuda.current_device())
-                r_mask = r_mask.squeeze(1).to(torch.cuda.current_device())
-                chosen_reward = self.model(chosen_ids, attention_mask=c_mask)
-                reject_reward = self.model(reject_ids, attention_mask=r_mask)
-                for i in range(len(chosen_reward)):
-                    cnt += 1
-                    if chosen_reward[i] > reject_reward[i]:
-                        on += 1
-                dist += (chosen_reward - reject_reward).mean().item()
-        dist_mean = dist / len(dataloader)
-        acc = on / cnt
-        self.model.train()
-        return dist_mean, acc
-
-    def fit(self):
-        time = datetime.now()
-        epoch_bar = tqdm(range(self.max_epochs), desc='Train epoch', disable=not is_rank_0())
-        for epoch in range(self.max_epochs):
-            step_bar = tqdm(range(self.train_dataloader.__len__()),
-                            desc='Train step of epoch %d' % epoch,
-                            disable=not is_rank_0())
-            # train
-            self.model.train()
-            cnt = 0
-            acc = 0
-            dist = 0
-            for chosen_ids, c_mask, reject_ids, r_mask in self.train_dataloader:
-                chosen_ids = chosen_ids.squeeze(1).to(torch.cuda.current_device())
-                c_mask = c_mask.squeeze(1).to(torch.cuda.current_device())
-                reject_ids = reject_ids.squeeze(1).to(torch.cuda.current_device())
-                r_mask = r_mask.squeeze(1).to(torch.cuda.current_device())
-                chosen_reward = self.model(chosen_ids, attention_mask=c_mask)
-                reject_reward = self.model(reject_ids, attention_mask=r_mask)
-                loss = self.loss_fn(chosen_reward, reject_reward)
-                self.strategy.backward(loss, self.model, self.optimizer)
-                self.strategy.optimizer_step(self.optimizer)
-                self.optimizer.zero_grad()
-                cnt += 1
-                if cnt == 100:
-                    self.scheduler.step()
-                    dist, acc = self.eval_acc(self.valid_dataloader)
-                    cnt = 0
-                    if is_rank_0():
-                        log = pd.DataFrame([[step_bar.n, loss.item(), dist, acc]],
-                                           columns=['step', 'loss', 'dist', 'acc'])
-                        log.to_csv('log_%s.csv' % time, mode='a', header=False, index=False)
-                step_bar.update()
-                step_bar.set_postfix({'dist': dist, 'acc': acc})
-            # eval
-            dist, acc = self.eval_acc(self.eval_dataloader)
-            if is_rank_0():
-                log = pd.DataFrame([[step_bar.n, loss.item(), dist, acc]],
-                                   columns=['step', 'loss', 'dist', 'acc'])
-                log.to_csv('log.csv', mode='a', header=False, index=False)
-            epoch_bar.update()
-            step_bar.set_postfix({'dist': dist, 'acc': acc})
-            step_bar.close()
+        self.writer = None
+        if use_wandb and is_rank_0():
+            assert log_dir is not None, "log_dir must be provided when use_wandb is True"
+            import wandb
+
+            wandb.init(project="Coati-rm", sync_tensorboard=True)
+        if log_dir is not None and is_rank_0():
+            import os
+            import time
+
+            from torch.utils.tensorboard import SummaryWriter
+
+            log_dir = os.path.join(log_dir, "rm")
+            log_dir = os.path.join(log_dir, time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime()))
+            self.writer = SummaryWriter(log_dir=log_dir)
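The reward-model trainer receives its loss_fn from the caller, so the objective itself is not part of this diff. For context, a common choice for pairwise reward modeling is a log-sigmoid ranking loss over (chosen, rejected) reward pairs; the following is a hedged, self-contained example of that kind of loss, not necessarily the class the Coati scripts actually pass in.

# Hedged example of a pairwise ranking loss of the kind injected as loss_fn:
# minimize -log(sigmoid(r_chosen - r_rejected)), i.e. reward chosen > rejected.
import torch
import torch.nn as nn
import torch.nn.functional as F


class PairwiseLogSigLoss(nn.Module):
    def forward(self, chosen_reward: torch.Tensor, reject_reward: torch.Tensor) -> torch.Tensor:
        # loss shrinks as the chosen reward exceeds the rejected reward
        return -F.logsigmoid(chosen_reward - reject_reward).mean()


# quick check with toy reward values
loss_fn = PairwiseLogSigLoss()
print(loss_fn(torch.tensor([1.0, 2.0]), torch.tensor([0.0, 0.5])))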
applications/Chat/coati/trainer/sft.py

 import math
-import time
-from typing import List, Optional
+from typing import Optional

 import torch
 import torch.distributed as dist
-import wandb
+import tqdm
 from torch.optim import Optimizer
+from torch.optim.lr_scheduler import _LRScheduler
 from torch.utils.data import DataLoader
-from tqdm import tqdm
-from transformers.tokenization_utils_base import PreTrainedTokenizerBase
-from transformers.trainer import get_scheduler

-from .base import Trainer
-from .callbacks import Callback
-from .strategies import ColossalAIStrategy, Strategy
+from colossalai.logging import DistributedLogger
+
+from .base import SLTrainer
+from .strategies import GeminiStrategy, Strategy
 from .utils import is_rank_0, to_device


-class SFTTrainer(Trainer):
+class SFTTrainer(SLTrainer):
     """
     Trainer to use while training reward model.

@@ -25,12 +22,9 @@ class SFTTrainer(Trainer):
         model (torch.nn.Module): the model to train
         strategy (Strategy): the strategy to use for training
         optim(Optimizer): the optimizer to use for training
-        train_dataloader: the dataloader to use for training
-        eval_dataloader: the dataloader to use for evaluation
-        batch_size (int, defaults to 1): the batch size while training
+        lr_scheduler(_LRScheduler): the lr scheduler to use for training
         max_epochs (int, defaults to 2): the number of epochs to train
-        callbacks (List[Callback], defaults to []): the callbacks to call during training process
-        optim_kwargs (dict, defaults to {'lr':1e-4}): the kwargs to use while initializing optimizer
         accumulation_steps (int, defaults to 8): the number of steps to accumulate gradients
     """

     def __init__(

@@ -38,98 +32,99 @@ class SFTTrainer(Trainer):
         model,
         strategy: Strategy,
         optim: Optimizer,
-        train_dataloader: DataLoader,
-        eval_dataloader: DataLoader = None,
+        lr_scheduler: _LRScheduler,
         max_epochs: int = 2,
         accumulation_steps: int = 8,
-        callbacks: List[Callback] = [],
     ) -> None:
-        if accumulation_steps > 1 and isinstance(strategy, ColossalAIStrategy) and strategy.stage == 3:
-            raise ValueError("Accumulation steps are not supported in stage 3 of ColossalAI")
-        super().__init__(strategy, max_epochs, callbacks=callbacks)
+        if accumulation_steps > 1:
+            assert not isinstance(strategy, GeminiStrategy), "Accumulation steps are not supported in stage 3 of ColossalAI"
+        super().__init__(strategy, max_epochs, model, optim)
+        self.accumulation_steps = accumulation_steps
+        self.scheduler = lr_scheduler
+
+        self.num_train_step = 0
+        self.num_eval_step = 0
+
+    def _train(self, epoch: int):
+        self.model.train()
+        step_bar = tqdm.trange(
+            len(self.train_dataloader) // self.accumulation_steps,
+            desc=f"Epoch {epoch + 1}/{self.max_epochs}",
+            disable=not is_rank_0(),
+        )
+        for i, batch in enumerate(self.train_dataloader):
+            batch = to_device(batch, torch.cuda.current_device())
+            outputs = self.model(batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"])
+            loss = outputs.loss / self.accumulation_steps
+            self.total_loss += loss.item()
+            self.strategy.backward(loss, self.model, self.optimizer)
+            # gradient accumulation
+            if (i + 1) % self.accumulation_steps == 0:
+                self.strategy.optimizer_step(self.optimizer)
+                self.optimizer.zero_grad()
+                self.scheduler.step()
+                if self.writer:
+                    self.writer.add_scalar("train/loss", self.total_loss, self.num_train_step)
+                    self.writer.add_scalar("train/lr", self.scheduler.get_last_lr()[0], self.num_train_step)
+                    self.num_train_step += 1
+                self.total_loss = 0
+                step_bar.update()
+        step_bar.close()
+
+    def _eval(self, epoch: int):
+        if self.eval_dataloader is not None:
+            self.model.eval()
+            with torch.no_grad():
+                loss_sum, num_seen = 0, 0
+                for batch in self.eval_dataloader:
+                    batch = to_device(batch, torch.cuda.current_device())
+                    outputs = self.model(batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"])
+                    loss_sum += outputs.loss.item()
+                    num_seen += batch["input_ids"].size(0)
+                loss_mean = loss_sum / num_seen
+                if dist.get_rank() == 0:
+                    self.logger.info(f"Eval Epoch {epoch}/{self.max_epochs} loss {loss_mean}")
+                if self.writer:
+                    self.writer.add_scalar("eval/loss", loss_mean, self.num_eval_step)
+                    self.num_eval_step += 1
+
+    def _before_fit(
+        self,
+        train_dataloader: DataLoader,
+        eval_dataloader: Optional[DataLoader] = None,
+        logger: Optional[DistributedLogger] = None,
+        log_dir: Optional[str] = None,
+        use_wandb: bool = False,
+    ):
+        """
+        Args:
+            train_dataloader: the dataloader to use for training
+            eval_dataloader: the dataloader to use for evaluation
+        """
         self.train_dataloader = train_dataloader
         self.eval_dataloader = eval_dataloader
-        self.model = model
-        self.optimizer = optim
-        self.accumulation_steps = accumulation_steps
-        num_update_steps_per_epoch = len(train_dataloader) // self.accumulation_steps
-        max_steps = math.ceil(self.max_epochs * num_update_steps_per_epoch)
-        self.scheduler = get_scheduler("cosine",
-                                       self.optimizer,
-                                       num_warmup_steps=math.ceil(max_steps * 0.03),
-                                       num_training_steps=max_steps)
-
-    def fit(self, logger, use_wandb: bool = False):
-        if use_wandb:
-            wandb.init(project="Coati", name=time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
-            wandb.watch(self.model)
-        total_loss = 0
-        # epoch_bar = tqdm(range(self.epochs), desc='Epochs', disable=not is_rank_0())
-        step_bar = tqdm(range(len(self.train_dataloader) // self.accumulation_steps * self.max_epochs),
-                        desc=f'steps',
-                        disable=not is_rank_0())
-        for epoch in range(self.max_epochs):
-            # process_bar = tqdm(range(len(self.train_dataloader)), desc=f'Train process for{epoch}', disable=not is_rank_0())
-            # train
-            self.model.train()
-            for batch_id, batch in enumerate(self.train_dataloader):
-                batch = to_device(batch, torch.cuda.current_device())
-                outputs = self.model(batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"])
-                loss = outputs.loss
-                if loss >= 2.5 and is_rank_0():
-                    logger.warning(f"batch_id:{batch_id}, abnormal loss: {loss}")
-                loss = loss / self.accumulation_steps
-                self.strategy.backward(loss, self.model, self.optimizer)
-                total_loss += loss.item()
-                # gradient accumulation
-                if (batch_id + 1) % self.accumulation_steps == 0:
-                    self.strategy.optimizer_step(self.optimizer)
-                    self.optimizer.zero_grad()
-                    self.scheduler.step()
-                    if is_rank_0() and use_wandb:
-                        wandb.log({
-                            "loss": total_loss / self.accumulation_steps,
-                            "lr": self.scheduler.get_last_lr()[0],
-                            "epoch": epoch,
-                            "batch_id": batch_id
-                        })
-                    total_loss = 0
-                    step_bar.update()
-                # if batch_id % log_interval == 0:
-                #     logger.info(f'Train Epoch {epoch}/{self.epochs} Batch {batch_id} Rank {dist.get_rank()} loss {loss.item()}')
-                #     wandb.log({"loss": loss.item()})
-                # process_bar.update()
-            # eval
-            if self.eval_dataloader is not None:
-                self.model.eval()
-                with torch.no_grad():
-                    loss_sum = 0
-                    num_seen = 0
-                    for batch in self.eval_dataloader:
-                        batch = to_device(batch, torch.cuda.current_device())
-                        outputs = self.model(batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"])
-                        loss = outputs.loss
-                        loss_sum += loss.item()
-                        num_seen += batch["input_ids"].size(0)
-                    loss_mean = loss_sum / num_seen
-                    if dist.get_rank() == 0:
-                        logger.info(f'Eval Epoch {epoch}/{self.max_epochs} loss {loss_mean}')
-            # epoch_bar.update()
+        self.logger = logger
+        self.writer = None
+        if use_wandb and is_rank_0():
+            assert log_dir is not None, "log_dir must be provided when use_wandb is True"
+            import wandb
+
+            wandb.init(project="Coati-sft", sync_tensorboard=True)
+        if log_dir is not None and is_rank_0():
+            import os
+            import time
+
+            from torch.utils.tensorboard import SummaryWriter
+
+            log_dir = os.path.join(log_dir, "sft")
+            log_dir = os.path.join(log_dir, time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime()))
+            self.writer = SummaryWriter(log_dir=log_dir)
+        self.total_loss = 0
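Both the old and new SFT loops use the same gradient-accumulation pattern: each micro-batch loss is scaled by 1/accumulation_steps and the optimizer steps only every accumulation_steps batches. The following is a self-contained toy illustration of that pattern (toy model and data, plain PyTorch, not the trainer code):

# Self-contained illustration of the gradient-accumulation pattern used in SFTTrainer._train.
import torch
import torch.nn as nn

model = nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
accumulation_steps = 8

for i in range(32):                      # 32 micro-batches -> 4 optimizer steps
    x, y = torch.randn(2, 4), torch.randn(2, 1)
    loss = nn.functional.mse_loss(model(x), y) / accumulation_steps
    loss.backward()                      # gradients accumulate across micro-batches
    if (i + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()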
applications/Chat/coati/trainer/strategies/__init__.py

 from .base import Strategy
-from .colossalai import ColossalAIStrategy
+from .colossalai import GeminiStrategy, LowLevelZeroStrategy
 from .ddp import DDPStrategy
-from .naive import NaiveStrategy

-__all__ = ['Strategy', 'NaiveStrategy', 'DDPStrategy', 'ColossalAIStrategy']
+__all__ = ["Strategy", "DDPStrategy", "LowLevelZeroStrategy", "GeminiStrategy"]
applications/Chat/coati/trainer/strategies/base.py
View file @
9e768b59
from abc import ABC, abstractmethod
from contextlib import nullcontext
from typing import Any, List, Optional, Tuple, Union
from typing import Callable, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
from coati.models.base import Actor, get_base_model
from coati.replay_buffer import ReplayBuffer
from coati.experience_buffer import ExperienceBuffer
from torch.optim import Optimizer
from torch.utils.data import DataLoader
from transformers.tokenization_utils_base import PreTrainedTokenizerBase

from colossalai.booster import Booster
from colossalai.booster.plugin import Plugin

from .sampler import DistributedSampler

ModelOptimPair = Tuple[nn.Module, Optimizer]
ModelOrModelOptimPair = Union[nn.Module, ModelOptimPair]
_BoostArgSpec = Union[nn.Module, Tuple[nn.Module, Optimizer], Dict]


class Strategy(ABC):
    """
    Base class for training strategies.
    """

    def __init__(self) -> None:
    def __init__(self, plugin_initializer: Callable[..., Optional[Plugin]] = lambda: None) -> None:
        super().__init__()
        # NOTE: dist must be initialized before Booster
        self.setup_distributed()
        self.plugin = plugin_initializer()
        self.booster = Booster(plugin=self.plugin)
        self._post_init()

    @abstractmethod
    def backward(self, loss: torch.Tensor, model: nn.Module, optimizer: Optimizer, **kwargs) -> None:
    def _post_init(self) -> None:
        pass

    @abstractmethod
    def backward(self, loss: torch.Tensor, model: nn.Module, optimizer: Optimizer, **kwargs) -> None:
        self.booster.backward(loss, optimizer)

    def optimizer_step(self, optimizer: Optimizer, **kwargs) -> None:
        pass
        optimizer.step()

    @abstractmethod
    def setup_distributed(self) -> None:
        pass

    @abstractmethod
    def setup_model(self, model: nn.Module) -> nn.Module:
        pass

    @abstractmethod
    def setup_optimizer(self, optimizer: Optimizer, model: nn.Module) -> Optimizer:
        pass

    @abstractmethod
    def setup_dataloader(self, replay_buffer: ReplayBuffer, pin_memory: bool = False) -> DataLoader:
    def setup_dataloader(self, data_buffer: ExperienceBuffer, pin_memory: bool = False) -> DataLoader:
        pass

    def model_init_context(self):
        return nullcontext()

    def prepare(self, *models_or_model_optim_pairs: ModelOrModelOptimPair) -> Union[List[ModelOrModelOptimPair], ModelOrModelOptimPair]:
        """Prepare models or model-optimizer-pairs based on each strategy.
    def prepare(self, *boost_args: _BoostArgSpec) -> Union[List[_BoostArgSpec], _BoostArgSpec]:
        """Prepare [model | (model, optimizer) | Dict] based on each strategy.
        NOTE: the keys of Dict must be a subset of `self.booster.boost`'s arguments.

        Example::
            >>> # e.g., include lr_scheduler
            >>> result_dict = strategy.prepare(dict(model=model, lr_scheduler=lr_scheduler))
            >>> # when fine-tuning actor and critic
            >>> (actor, actor_optim), (critic, critic_optim), reward_model, initial_model = strategy.prepare((actor, actor_optim), (critic, critic_optim), reward_model, initial_model)
            >>> # or when training reward model
...
...
@@ -66,67 +66,72 @@ class Strategy(ABC):
            >>> actor, critic = strategy.prepare(actor, critic)

        Returns:
            Union[List[ModelOrModelOptimPair], ModelOrModelOptimPair]: Models or model-optimizer-pairs in the original order.
            Union[List[_BoostArgSpec], _BoostArgSpec]: [model | (model, optimizer) | Dict] in the original order.
        """

        def prepare_model(model: nn.Module):
            if isinstance(model, Actor):
                return Actor(self.setup_model(model.get_base_model()))
            return self.setup_model(model)

        rets = []
        for arg in models_or_model_optim_pairs:
            if isinstance(arg, tuple):
                assert len(arg) == 2, f'Expect (model, optimizer) pair, got a tuple with size "{len(arg)}"'
                model, optimizer = arg
                model = prepare_model(model)
                optimizer = self.setup_optimizer(optimizer, get_base_model(model))
        for arg in boost_args:
            if isinstance(arg, nn.Module):
                model, *_ = self.booster.boost(arg)
                rets.append(model)
            elif isinstance(arg, tuple):
                try:
                    model, optimizer = arg
                except ValueError:
                    raise RuntimeError(f'Expect (model, optimizer) pair, got a tuple with size "{len(arg)}"')
                model, optimizer, *_ = self.booster.boost(model=model, optimizer=optimizer)
                rets.append((model, optimizer))
            elif isinstance(arg, nn.Module):
                rets.append(prepare_model(arg))
            elif isinstance(arg, Dict):
                model, optimizer, criterion, dataloader, lr_scheduler = self.booster.boost(**arg)
                boost_result = dict(
                    model=model,
                    optimizer=optimizer,
                    criterion=criterion,
                    dataloader=dataloader,
                    lr_scheduler=lr_scheduler,
                )
                # remove None values
                boost_result = {key: value for key, value in boost_result.items() if value is not None}
                rets.append(boost_result)
            else:
                raise RuntimeError(f'Expect model or (model, optimizer) pair, got {type(arg)}')
                raise RuntimeError(f"Type {type(arg)} is not supported")

        if len(rets) == 1:
            return rets[0]
        return rets
        return rets[0] if len(rets) == 1 else rets

    @staticmethod
    def unwrap_model(model: nn.Module) -> nn.Module:
        """Get the unwrapped model from a wrapped model. Useful for getting original huggingface model.
        For Actor, it will unwrap `actor.model`.
        """Get the unwrapped model from a wrapped model made by Strategy.prepare.

        Args:
            model (nn.Module): the model to unwrap

        Returns:
            nn.Module: the original model (usually a huggingface model)
            nn.Module: the original model
        """
        return get_base_model(model)
        return model

    @abstractmethod
    def save_model(self, model: nn.Module, path: str, only_rank0: bool = True) -> None:
        pass
    def save_model(self, model: nn.Module, path: str, shard: bool = False, **kwargs) -> None:
        self.booster.save_model(model, path, shard=shard, **kwargs)

    @abstractmethod
    def load_model(self, model: nn.Module, path: str, map_location: Any = None, strict: bool = True) -> None:
        pass
    def load_model(self, model: nn.Module, path: str, strict: bool = True) -> None:
        self.booster.load_model(model, path, strict)

    @abstractmethod
    def save_optimizer(self, optimizer: Optimizer, path: str, only_rank0: bool = False) -> None:
        pass
    def save_optimizer(self, optimizer: Optimizer, path: str, only_rank0: bool = False, **kwargs) -> None:
        self.booster.save_optimizer(optimizer, path, shard=not only_rank0, **kwargs)

    @abstractmethod
    def load_optimizer(self, optimizer: Optimizer, path: str, map_location: Any = None) -> None:
        pass
    def load_optimizer(self, optimizer: Optimizer, path: str) -> None:
        self.booster.load_optimizer(optimizer, path)

    def setup_sampler(self, dataset) -> DistributedSampler:
        # FIXME(cwher): this is only invoked in train_on_ray, not tested after adapt Boost API.
        return DistributedSampler(dataset, 1, 0)

    @abstractmethod
    def save_pretrained(self, model: nn.Module, path: str, only_rank0: bool = True, tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
    def save_pretrained(self, model: nn.Module, path: str, only_rank0: bool = True, tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
        pass

    @abstractmethod
    def get_model_state_dict_shard(self, model: nn.Module, **config):
        pass
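The `prepare` docstring above already shows the intended call patterns; for completeness, a hedged sketch of how a training loop might drive the Booster-based Strategy after this refactor (the concrete strategy class, model/optimizer constructors, and loss function are assumed placeholders, not taken from this commit):

    # sketch only: assumes a concrete Strategy subclass and HuggingFace-style models
    strategy = SomeStrategy()                      # e.g. a DDP- or Gemini-based subclass

    with strategy.model_init_context():
        actor, critic = build_actor(), build_critic()          # hypothetical constructors
    actor_optim, critic_optim = make_optim(actor), make_optim(critic)  # hypothetical helpers

    # boost models and optimizers in one call; results come back in the original order
    (actor, actor_optim), (critic, critic_optim) = strategy.prepare((actor, actor_optim), (critic, critic_optim))

    loss = compute_loss(actor, batch)              # hypothetical loss on an assumed batch
    strategy.backward(loss, actor, actor_optim)
    strategy.optimizer_step(actor_optim)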
applications/Chat/coati/trainer/strategies/colossalai.py
View file @ 9e768b59
import warnings
from typing import Optional, Union
from typing import Optional

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
from coati.models.base import get_base_model
from torch.optim import Optimizer
from transformers.tokenization_utils_base import PreTrainedTokenizerBase

import colossalai
from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import CPUAdam, HybridAdam
from colossalai.tensor import ProcessGroup, ShardSpec
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin
from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroModel
from colossalai.utils import get_current_device
from colossalai.zero import ColoInitContext, ZeroDDP, zero_model_wrapper, zero_optim_wrapper
from colossalai.zero.gemini.gemini_ddp import GeminiDDP

from .ddp import DDPStrategy

logger = get_dist_logger(__name__)


class LowLevelZeroStrategy(DDPStrategy):
    """
    The strategy for training with ColossalAI.

    Args:
        stage(int): The stage to use in ZeRO. Choose in (1, 2)
        precision(str): The precision to use. Choose in ('fp32', 'fp16').
        seed(int): The seed for the random number generator.
        placement_policy(str): The placement policy for gemini. Choose in ('cpu', 'cuda')
            If it is "cpu", parameters, gradients and optimizer states will be offloaded to CPU,
            If it is "cuda", they will not be offloaded, which means max CUDA memory will be used. It is the fastest.
        reduce_bucket_size(int): The reduce bucket size in bytes. Only for ZeRO-1 and ZeRO-2.
        overlap_communication(bool): Whether to overlap communication and computation. Only for ZeRO-1 and ZeRO-2.
        initial_scale(float): The initial scale for the optimizer.
        growth_factor(float): The growth factor for the optimizer.
        backoff_factor(float): The backoff factor for the optimizer.
        growth_interval(int): The growth interval for the optimizer.
        hysteresis(int): The hysteresis for the optimizer.
        min_scale(float): The minimum scale for the optimizer.
        max_scale(float): The maximum scale for the optimizer.
        max_norm(float): The maximum norm for the optimizer.
        norm_type(float): The norm type for the optimizer.
    """

    def __init__(self,
                 stage: int = 2,
                 precision: str = "fp16",
                 seed: int = 42,
                 placement_policy: str = "cuda",
                 reduce_bucket_size: int = 12 * 1024**2,    # only for stage 1&2
                 overlap_communication: bool = True,        # only for stage 1&2
                 initial_scale: float = 2**16,
                 growth_factor: float = 2,
                 backoff_factor: float = 0.5,
                 growth_interval: int = 1000,
                 hysteresis: int = 2,
                 min_scale: float = 1,
                 max_scale: float = 2**32,
                 max_norm: float = 0.0,
                 norm_type: float = 2.0,
                 ) -> None:
        assert stage in (1, 2), f'Unsupported stage "{stage}"'
        assert placement_policy in ("cpu", "cuda"), f'Unsupported placement policy "{placement_policy}"'
        assert precision in ("fp32", "fp16"), f'Unsupported precision "{precision}"'

        plugin_initializer = lambda: LowLevelZeroPlugin(stage=stage,
                                                        precision=precision,
                                                        reduce_bucket_size_in_m=reduce_bucket_size,
                                                        overlap_communication=overlap_communication,
                                                        cpu_offload=(placement_policy == "cpu"),
                                                        initial_scale=initial_scale,
                                                        growth_factor=growth_factor,
                                                        backoff_factor=backoff_factor,
                                                        growth_interval=growth_interval,
                                                        hysteresis=hysteresis,
                                                        min_scale=min_scale,
                                                        max_scale=max_scale,
                                                        max_norm=max_norm,
                                                        norm_type=norm_type,
                                                        )
        super().__init__(seed, plugin_initializer)

    def _post_init(self) -> None:
        assert isinstance(self.plugin, LowLevelZeroPlugin), f"{type(self).__name__}'s plugin is not initialized properly."

    def setup_distributed(self) -> None:
        colossalai.launch_from_torch({}, seed=self.seed)

    def unwrap_model(self, model: nn.Module) -> nn.Module:
        assert isinstance(model, LowLevelZeroModel)
        return model.module

    def get_model_state_dict_shard(self, model: nn.Module, **config):
        assert isinstance(model, LowLevelZeroModel)
        yield from model.state_dict_shard(max_shard_size=1024, only_rank_0=False)


class ColossalAIStrategy(DDPStrategy):
class GeminiStrategy(DDPStrategy):
    """
    The strategy for training with ColossalAI.

    Args:
        stage(int): The stage to use in ZeRO. Choose in (1, 2, 3)
        precision(str): The precision to use. Choose in ('fp32', 'fp16'). Stage 3 only supports fp16.
        seed(int): The seed for the random number generator.
        shard_init(bool): Whether to shard the model parameters during initialization. Only for ZeRO-3.
            This is not compatible with `from_pretrained()`. We temporarily disable this and will support it in the future.
        placement_policy(str): The placement policy for gemini. Choose in ('cpu', 'cuda')
            If it is "cpu", parameters, gradients and optimizer states will be offloaded to CPU,
            If it is "cuda", they will not be offloaded, which means max CUDA memory will be used. It is the fastest.
        pin_memory(bool): Whether to pin the memory for the data loader. Only for ZeRO-3.
        force_outputs_fp32(bool): Whether to force the outputs to be fp32. Only for ZeRO-3.
        search_range_mb(int): The search range in MB for the chunk size. Only for ZeRO-3.
        search_range_m(int): The search range for the chunk size, divided by 2^20. Only for ZeRO-3.
        hidden_dim(optional, int): The hidden dimension for the gemini. Only for ZeRO-3.
        min_chunk_size_mb(float): The minimum chunk size in MB. Only for ZeRO-3.
        min_chunk_size_m(float): The minimum chunk size divided by 2^20. Only for ZeRO-3.
        gpu_margin_mem_ratio(float): The margin memory ratio for the GPU. Only for ZeRO-3.
        reduce_bucket_size(int): The reduce bucket size in bytes. Only for ZeRO-1 and ZeRO-2.
        overlap_communication(bool): Whether to overlap communication and computation. Only for ZeRO-1 and ZeRO-2.
        initial_scale(float): The initial scale for the optimizer.
        growth_factor(float): The growth factor for the optimizer.
        backoff_factor(float): The backoff factor for the optimizer.
...
...
@@ -55,134 +125,76 @@ class ColossalAIStrategy(DDPStrategy):
    """

    def __init__(self,
                 stage: int = 3,
                 precision: str = 'fp16',
                 seed: int = 42,
                 shard_init: bool = False,                  # only for stage 3
                 placement_policy: str = 'cuda',
                 pin_memory: bool = True,                   # only for stage 3
                 force_outputs_fp32: bool = False,          # only for stage 3
                 scatter_after_inference: bool = False,     # only for stage 3
                 search_range_mb: int = 32,                 # only for stage 3
                 hidden_dim: Optional[int] = None,          # only for stage 3
                 min_chunk_size_mb: float = 32,             # only for stage 3
                 gpu_margin_mem_ratio: float = 0.0,         # only for stage 3
                 reduce_bucket_size: int = 12 * 1024**2,    # only for stage 1&2
                 overlap_communication: bool = True,        # only for stage 1&2
                 initial_scale: float = 2**16,
                 growth_factor: float = 2,
                 backoff_factor: float = 0.5,
                 growth_interval: int = 1000,
                 hysteresis: int = 2,
                 min_scale: float = 1,
                 max_scale: float = 2**32,
                 max_norm: float = 0.0,
                 norm_type: float = 2.0) -> None:
        super().__init__(seed)
        assert placement_policy in ('cpu', 'cuda'), f'Unsupported placement policy "{placement_policy}"'
        assert precision in ('fp32', 'fp16'), f'Unsupported precision "{precision}"'
        self.stage = stage
    def __init__(self,
                 seed: int = 42,
                 shard_init: bool = False,                  # only for stage 3
                 placement_policy: str = "auto",
                 shard_param_frac: float = 1.0,             # only for static placement
                 offload_optim_frac: float = 0.0,           # only for static placement
                 offload_param_frac: float = 0.0,           # only for static placement
                 pin_memory: bool = True,                   # only for stage 3
                 force_outputs_fp32: bool = False,          # only for stage 3
                 search_range_m: int = 32,                  # only for stage 3
                 hidden_dim: Optional[int] = None,          # only for stage 3
                 min_chunk_size_m: float = 32,              # only for stage 3
                 gpu_margin_mem_ratio: float = 0.0,         # only for stage 3
                 initial_scale: float = 2**16,
                 growth_factor: float = 2,
                 backoff_factor: float = 0.5,
                 growth_interval: int = 1000,
                 hysteresis: int = 2,
                 min_scale: float = 1,
                 max_scale: float = 2**32,
                 max_norm: float = 0.0,
                 norm_type: float = 2.0,
                 ) -> None:
        # TODO(ver217): support shard_init when using from_pretrained()
        if shard_init:
            warnings.warn(
                f'Shard init is not supported model.from_pretrained() yet. Please load weights after strategy.prepare()'
                f"Shard init is not supported model.from_pretrained() yet. "
                "Please load weights after strategy.prepare()"
            )
        if stage == 3 and precision == 'fp32':
            warnings.warn(f'Stage 3 only supports fp16. Precision is set to fp16.')
            precision = 'fp16'
        self.precision = precision
        self.shard_init = shard_init

        self.gemini_config = dict(device=get_current_device(),
                                  placement_policy=placement_policy,
                                  pin_memory=pin_memory,
                                  force_outputs_fp32=force_outputs_fp32,
                                  strict_ddp_mode=shard_init,
                                  search_range_mb=search_range_mb,
                                  hidden_dim=hidden_dim,
                                  min_chunk_size_mb=min_chunk_size_mb,
                                  scatter_after_inference=scatter_after_inference)
        if stage == 3:
            self.zero_optim_config = dict(gpu_margin_mem_ratio=gpu_margin_mem_ratio)
        else:
            self.zero_optim_config = dict(reduce_bucket_size=reduce_bucket_size,
                                          overlap_communication=overlap_communication,
                                          cpu_offload=(placement_policy == 'cpu'))
        self.optim_kwargs = dict(initial_scale=initial_scale,
                                 growth_factor=growth_factor,
                                 backoff_factor=backoff_factor,
                                 growth_interval=growth_interval,
                                 hysteresis=hysteresis,
                                 min_scale=min_scale,
                                 max_scale=max_scale,
                                 max_norm=max_norm,
                                 norm_type=norm_type)
        warnings.warn(f"Stage 3 only supports fp16. Precision is set to fp16.")

        # NOTE: dist should be initialized before calling get_current_device()
        plugin_initializer = lambda: GeminiPlugin(chunk_init_device=get_current_device(),
                                                  placement_policy=placement_policy,
                                                  shard_param_frac=shard_param_frac,
                                                  offload_optim_frac=offload_optim_frac,
                                                  offload_param_frac=offload_param_frac,
                                                  precision="fp16",
                                                  pin_memory=pin_memory,
                                                  force_outputs_fp32=force_outputs_fp32,
                                                  strict_ddp_mode=shard_init,
                                                  search_range_m=search_range_m,
                                                  hidden_dim=hidden_dim,
                                                  min_chunk_size_m=min_chunk_size_m,
                                                  gpu_margin_mem_ratio=gpu_margin_mem_ratio,
                                                  initial_scale=initial_scale,
                                                  growth_factor=growth_factor,
                                                  backoff_factor=backoff_factor,
                                                  growth_interval=growth_interval,
                                                  hysteresis=hysteresis,
                                                  min_scale=min_scale,
                                                  max_scale=max_scale,
                                                  max_norm=max_norm,
                                                  norm_type=norm_type,
                                                  )
        super().__init__(seed, plugin_initializer)

    def _post_init(self) -> None:
        assert isinstance(self.plugin, GeminiPlugin), f"{type(self).__name__}'s plugin is not initialized properly."

    def setup_distributed(self) -> None:
        colossalai.launch_from_torch({}, seed=self.seed)

    def model_init_context(self):
        if self.stage == 3:
            world_size = dist.get_world_size()
            shard_pg = ProcessGroup(tp_degree=world_size) if self.shard_init else None
            default_dist_spec = ShardSpec([-1], [world_size]) if self.shard_init else None
            return ColoInitContext(device=get_current_device(),
                                   dtype=torch.half,
                                   default_pg=shard_pg,
                                   default_dist_spec=default_dist_spec)
        return super().model_init_context()

    def setup_model(self, model: nn.Module) -> nn.Module:
        model = zero_model_wrapper(model, zero_stage=self.stage, gemini_config=self.gemini_config)
        if self.stage != 3 and self.precision == 'fp16':
            model = model.half().cuda()
        return model

    def setup_optimizer(self, optimizer: optim.Optimizer, model: nn.Module) -> optim.Optimizer:
        assert isinstance(optimizer, (CPUAdam, HybridAdam)), f'Unsupported optimizer {type(optimizer)}'
        return zero_optim_wrapper(model, optimizer, optim_config=self.zero_optim_config, **self.optim_kwargs)

    def backward(self, loss: torch.Tensor, model: nn.Module, optimizer: optim.Optimizer, **kwargs) -> None:
        optimizer.backward(loss)

    def optimizer_step(self, optimizer: optim.Optimizer, **kwargs) -> None:
        optimizer.step()

    def save_model(self, model: nn.Module, path: str, only_rank0: bool = True) -> None:
        if only_rank0 and dist.get_rank() != 0 and self.stage != 3:
            return
        base_model = get_base_model(model)
        if self.stage == 3:
            assert isinstance(base_model, ZeroDDP)
            # for stage 3, state_dict() method should be called on every rank
            state_dict = base_model.state_dict(only_rank_0=only_rank0)
        else:
            # only_rank0 is false or rank == 0
            state_dict = base_model.state_dict()
        if only_rank0 and dist.get_rank() != 0:
            return
        torch.save(state_dict, path)

    def save_optimizer(self, optimizer: Optimizer, path: str, only_rank0: bool = False) -> None:
        if only_rank0:
            raise RuntimeError(
                f'Optimizer states are sharded when using ColossalAIStrategy. Only rank0 is not supported.')
        torch.save(optimizer.state_dict(), path)

    def unwrap_model(self, model: nn.Module) -> nn.Module:
        base_model: Union[nn.Module, ZeroDDP] = get_base_model(model)
        if self.stage == 3:
            assert isinstance(base_model, ZeroDDP)
            return base_model.module
        return base_model

    def save_pretrained(self,
                        model: nn.Module,
                        path: str,
                        only_rank0: bool = True,
                        tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
        if self.stage == 3:
            raise RuntimeError('ColossalAI strategy with stage-3 does not support save_pretrained() now')
        super().save_pretrained(model, path, only_rank0, tokenizer)
        assert isinstance(model, GeminiDDP)
        return model.module
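A hedged sketch of how these two strategies are typically constructed in a Coati training script after this commit (argument values are illustrative only; they are not taken from this diff):

    from coati.trainer.strategies import GeminiStrategy, LowLevelZeroStrategy

    # ZeRO-2 with optimizer states offloaded to CPU
    zero2_strategy = LowLevelZeroStrategy(stage=2, precision="fp16", placement_policy="cpu")

    # Gemini (ZeRO-3-style chunked memory management) with automatic placement
    gemini_strategy = GeminiStrategy(seed=42, placement_policy="auto", initial_scale=2**5)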
applications/Chat/coati/trainer/strategies/ddp.py
View file @ 9e768b59
import os
import random
from typing import Optional
from collections import OrderedDict
from typing import Callable, Optional

import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
from coati.replay_buffer import ReplayBuffer
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Optimizer
from coati.experience_buffer import ExperienceBuffer
from coati.models import Actor, Critic, RewardModel
from torch.utils.data import DataLoader
from transformers.modeling_utils import PreTrainedModel
from transformers.tokenization_utils_base import PreTrainedTokenizerBase

from .naive import NaiveStrategy
from colossalai.booster.plugin import TorchDDPPlugin
from colossalai.booster.plugin.torch_ddp_plugin import TorchDDPModel

from .base import Strategy
from .sampler import DistributedSampler


class DDPStrategy(NaiveStrategy):
# TODO Move this to a util.py (Moving to ray.util introduces ringed import)
def get_grad_required_state_dict(model: nn.Module):
    state_dict = OrderedDict()
    for name, parameter in model.named_parameters():
        if parameter.requires_grad:
            state_dict[name] = parameter.detach()
    return state_dict


class DDPStrategy(Strategy):
    """
    Strategy for distributed training using torch.distributed.
    """

    def __init__(self, seed: int = 42) -> None:
    def __init__(self, seed: int = 42, plugin_initializer: Callable = TorchDDPPlugin) -> None:
        self.seed = seed
        super().__init__()
        super().__init__(plugin_initializer)

    def setup_distributed(self) -> None:
    def _try_init_dist(self, force: bool = False) -> None:
        try:
            rank = int(os.environ['RANK'])
            local_rank = int(os.environ['LOCAL_RANK'])
            world_size = int(os.environ['WORLD_SIZE'])
            host = os.environ['MASTER_ADDR']
            port = int(os.environ['MASTER_PORT'])
            rank = int(os.environ["RANK"])
            local_rank = int(os.environ["LOCAL_RANK"])
            world_size = int(os.environ["WORLD_SIZE"])
            host = os.environ["MASTER_ADDR"]
            port = int(os.environ["MASTER_PORT"])
            dist.init_process_group("nccl", init_method=f"tcp://[{host}]:{port}", world_size=world_size, rank=rank)
            torch.cuda.set_device(local_rank)
        except KeyError as e:
            raise RuntimeError(
                f"Could not find {e} in the torch environment, visit https://www.colossalai.org/ for more information on launching with torch"
            )
            dist.init_process_group('nccl', init_method=f'tcp://[{host}]:{port}', world_size=world_size, rank=rank)
            if force:
                raise RuntimeError(
                    f"Could not find {e} in the torch environment, visit https://www.colossalai.org/ for more information on launching with torch"
                )
        except Exception as e:
            if force:
                raise e

    def _post_init(self) -> None:
        assert isinstance(self.plugin, TorchDDPPlugin), f"{type(self).__name__}'s plugin is not initialized properly."

    def setup_distributed(self) -> None:
        self._try_init_dist(force=True)
        self.set_seed(self.seed)
        torch.cuda.set_device(local_rank)

    def set_seed(self, seed: int) -> None:
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)

    def setup_model(self, model: nn.Module) -> nn.Module:
        device = torch.cuda.current_device()
        return DDP(model, device_ids=[device])

    def setup_dataloader(self, replay_buffer: ReplayBuffer, pin_memory: bool = False) -> DataLoader:
        # DDP only mode, replay buffers on each rank are different.
        # sampler = DistributedSampler(replay_buffer,
        #                              num_replicas=dist.get_world_size(),
        #                              rank=dist.get_rank(),
        #                              shuffle=True,
        #                              seed=self.seed,
        #                              drop_last=True)
        return DataLoader(replay_buffer,
                          batch_size=replay_buffer.sample_batch_size,
                          # sampler=sampler,
                          shuffle=True,
                          drop_last=True,
                          pin_memory=pin_memory,
                          collate_fn=replay_buffer.collate_fn)
    def setup_dataloader(self, data_buffer: ExperienceBuffer, pin_memory: bool = False) -> DataLoader:
        return self.plugin.prepare_dataloader(data_buffer,
                                              batch_size=data_buffer.sample_batch_size,
                                              shuffle=True,
                                              drop_last=True,
                                              pin_memory=pin_memory,
                                              collate_fn=data_buffer.collate_fn,)

    def save_model(self, model: nn.Module, path: str, only_rank0: bool = True) -> None:
        if only_rank0 and dist.get_rank() != 0:
            return
        super().save_model(model, path, only_rank0)

    def save_optimizer(self, optimizer: Optimizer, path: str, only_rank0: bool = False) -> None:
        if only_rank0 and dist.get_rank() != 0:
            return
        super().save_optimizer(optimizer, path, only_rank0)

    def setup_sampler(self, dataset) -> DistributedSampler:
        # FIXME(cwher): this is only invoked in train_on_ray, not tested after adapt Boost API.
        return DistributedSampler(dataset, dist.get_world_size(), dist.get_rank())

    def unwrap_model(self, model: nn.Module) -> nn.Module:
        base_model: DDP = super().unwrap_model(model)
        return base_model.module

    def save_pretrained(self, model: nn.Module, path: str, only_rank0: bool = True, tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
        if only_rank0 and dist.get_rank() != 0:
            return
        super().save_pretrained(model, path, only_rank0, tokenizer)
        assert isinstance(model, TorchDDPModel), "model is not wrapped by TorchDDPModel."
        return model.unwrap()

    def save_pretrained(self, model: nn.Module, path: str, shard: bool = False, tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
        if dist.get_rank() == 0:
            unwrapped_model = self.unwrap_model(model)
            assert isinstance(unwrapped_model, (Actor, Critic, RewardModel))
            pretrained_model = unwrapped_model.model
            assert isinstance(pretrained_model, PreTrainedModel)
            # HACK: only use hf save_pretrained to save config
            pretrained_model.save_pretrained(path, save_function=lambda *args, **kwargs: None)
            if tokenizer is not None:
                tokenizer.save_pretrained(path)
        model_path = os.path.join(path, "pytorch_model.bin")
        self.save_model(model, model_path, shard=shard)

        def _replace_keys(model_path: str, replace_fn: Callable):
            state_dict = torch.load(model_path, map_location="cpu")
            state_dict = {replace_fn(k): v for k, v in state_dict.items()}
            torch.save(state_dict, model_path)

        # FIXME: save_model would add "model." prefix to keys of pytorch_model.bin
        # HACK: rename keys of pytorch_model.bin
        if dist.get_rank() == 0:
            _replace_keys(model_path, lambda k: k.replace("model.", "", 1))

    def get_model_state_dict_shard(self, model: nn.Module, **config):
        # TODO: implement sharding on naive strategy
        model = self.unwrap_model(model)
        if "requires_grad_only" in config and config["requires_grad_only"] == True:
            state_dict = get_grad_required_state_dict(model)
        else:
            state_dict = model.state_dict()

        if "shard_size" in config:
            shard_size = config["shard_size"]
            accumulate_size = 0
            state_dict_shard = OrderedDict()
            for name, param in state_dict.items():
                state_dict_shard[name] = param
                accumulate_size += param.numel() * param.element_size()
                if accumulate_size >= shard_size:
                    accumulate_size = 0
                    yield state_dict_shard
                    state_dict_shard = OrderedDict()
            if accumulate_size > 0:
                yield state_dict_shard
        else:
            yield state_dict
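A short, hedged sketch of how the sharded state-dict generator above might be consumed (the `strategy` and `model` objects are assumed to already be prepared; the 1 GB shard size and the `send_shard` transport are illustrative placeholders):

    # iterate over ~1 GB shards of the trainable parameters and hand each to a hypothetical sender
    for shard in strategy.get_model_state_dict_shard(model,
                                                     requires_grad_only=True,
                                                     shard_size=1024**3):
        send_shard(shard)  # hypothetical transport, e.g. a Ray actor call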
applications/Chat/coati/trainer/strategies/naive.py
deleted 100644 → 0    View file @ 7bc5a8e3
from typing import Any, Optional

import torch
import torch.nn as nn
import torch.optim as optim
from coati.models.base import get_base_model
from coati.replay_buffer import ReplayBuffer
from torch.optim import Optimizer
from torch.utils.data import DataLoader
from transformers.modeling_utils import PreTrainedModel
from transformers.tokenization_utils_base import PreTrainedTokenizerBase

from .base import Strategy


class NaiveStrategy(Strategy):
    """
    Strategy for single GPU. No parallelism is used.
    """

    def backward(self, loss: torch.Tensor, model: nn.Module, optimizer: optim.Optimizer, **kwargs) -> None:
        loss.backward()

    def optimizer_step(self, optimizer: optim.Optimizer, **kwargs) -> None:
        optimizer.step()

    def setup_distributed(self) -> None:
        pass

    def setup_model(self, model: nn.Module) -> nn.Module:
        return model

    def setup_optimizer(self, optimizer: optim.Optimizer, model: nn.Module) -> optim.Optimizer:
        return optimizer

    def setup_dataloader(self, replay_buffer: ReplayBuffer, pin_memory: bool = False) -> DataLoader:
        return DataLoader(replay_buffer,
                          batch_size=replay_buffer.sample_batch_size,
                          shuffle=True,
                          drop_last=True,
                          pin_memory=pin_memory,
                          collate_fn=replay_buffer.collate_fn)

    def save_model(self, model: nn.Module, path: str, only_rank0: bool = True) -> None:
        base_model = get_base_model(model)
        state_dict = base_model.state_dict()
        torch.save(state_dict, path)

    def load_model(self, model: nn.Module, path: str, map_location: Any = None, strict: bool = True) -> None:
        base_model = get_base_model(model)
        state_dict = torch.load(path, map_location=map_location)
        base_model.load_state_dict(state_dict, strict=strict)

    def save_optimizer(self, optimizer: Optimizer, path: str, only_rank0: bool = False) -> None:
        torch.save(optimizer.state_dict(), path)

    def load_optimizer(self, optimizer: Optimizer, path: str, map_location: Any = None) -> None:
        state_dict = torch.load(path, map_location=map_location)
        optimizer.load_state_dict(state_dict)

    def save_pretrained(self,
                        model: nn.Module,
                        path: str,
                        only_rank0: bool = True,
                        tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
        unwrapped_model = self.unwrap_model(model)
        assert isinstance(unwrapped_model, PreTrainedModel)
        unwrapped_model.save_pretrained(path)
        if tokenizer is not None:
            tokenizer.save_pretrained(path)
applications/Chat/coati/trainer/strategies/sampler.py
View file @ 9e768b59
...
@@ -4,7 +4,6 @@ import numpy as np

class DistributedSampler:

    def __init__(self, dataset, num_replicas: int, rank: int) -> None:
        self.dataset = dataset
        self.num_replicas = num_replicas
...
@@ -12,7 +11,7 @@ class DistributedSampler:
        if len(self.dataset) % self.num_replicas != 0:
            self.num_samples = math.ceil(
                (len(self.dataset) - self.num_replicas) / self.num_replicas  # type: ignore[arg-type]
            )
        else:
            self.num_samples = math.ceil(len(self.dataset) / self.num_replicas)
...
@@ -20,10 +19,10 @@ class DistributedSampler:
        self.total_size = self.num_samples * self.num_replicas

        indices = list(range(len(self.dataset)))
        indices = indices[:self.total_size]
        assert len(indices) == self.total_size
        # subsample
        indices = indices[self.rank:self.total_size:self.num_replicas]
        assert len(indices) == self.num_samples
        self.indices = indices
...
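The subsampling above is a plain strided slice over the index list, so each rank ends up with a disjoint, equally sized shard. A tiny standalone check of that arithmetic (illustrative only, not part of the repository):

    import math

    dataset_len, num_replicas = 10, 4
    num_samples = math.ceil((dataset_len - num_replicas) / num_replicas)  # 2 per rank, as in the sampler
    total_size = num_samples * num_replicas                               # 8 indices kept in total

    indices = list(range(dataset_len))[:total_size]
    per_rank = [indices[rank:total_size:num_replicas] for rank in range(num_replicas)]
    print(per_rank)  # [[0, 4], [1, 5], [2, 6], [3, 7]] - disjoint shards of equal size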
applications/Chat/coati/trainer/utils.py
View file @ 9e768b59
...
@@ -3,6 +3,38 @@ from typing import Any
import torch
import torch.distributed as dist
from torch.utils._pytree import tree_map
from torch.utils.data import DataLoader


class CycledDataLoader:
    """
    Why do we need this class?
    In version 4da324cd60, "prompts = next(iter(self.prompt_dataloader))" is used to sample a batch of prompts/pretrain.
    However, this may be inefficient due to frequent re-initialization of the dataloader. (re-initialize workers...)
    NOTE: next(iter(dataloader)) is not equivalent to for batch in dataloader: break, it causes slightly different behavior.
    """

    def __init__(self,
                 dataloader: DataLoader,
                 ) -> None:
        self.dataloader = dataloader

        self.count = 0
        self.dataloader_iter = None

    def next(self):
        # defer initialization
        if self.dataloader_iter is None:
            self.dataloader_iter = iter(self.dataloader)

        self.count += 1
        try:
            return next(self.dataloader_iter)
        except StopIteration:
            self.count = 0
            self.dataloader_iter = iter(self.dataloader)
            return next(self.dataloader_iter)


def is_rank_0() -> bool:
...
@@ -10,7 +42,6 @@ def is_rank_0() -> bool:
def to_device(x: Any, device: torch.device) -> Any:

    def _to(t: Any):
        if isinstance(t, torch.Tensor):
            return t.to(device)
...
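The `CycledDataLoader` docstring explains the motivation: keep one iterator (and its worker pool) alive instead of rebuilding it for every sampled batch. A brief, hedged usage sketch, with `prompt_dataloader`, `num_update_steps`, and the consumer function assumed rather than taken from this commit:

    from coati.trainer.utils import CycledDataLoader

    prompt_loader = CycledDataLoader(prompt_dataloader)  # prompt_dataloader: an assumed torch DataLoader
    for step in range(num_update_steps):                 # num_update_steps: assumed loop bound
        prompts = prompt_loader.next()                   # cycles back to the start when exhausted
        run_ppo_step(prompts)                            # hypothetical consumer of the sampled batch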
applications/Chat/coati/utils/__init__.py
deleted 100644 → 0    View file @ 7bc5a8e3
from .tokenizer_utils import prepare_llama_tokenizer_and_embedding, smart_tokenizer_and_embedding_resize

__all__ = ['smart_tokenizer_and_embedding_resize', 'prepare_llama_tokenizer_and_embedding']
\ No newline at end of file
applications/Chat/coati/utils/tokenizer_utils.py
deleted 100644 → 0    View file @ 7bc5a8e3
# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict

import transformers

DEFAULT_PAD_TOKEN = "[PAD]"
DEFAULT_EOS_TOKEN = "</s>"
DEFAULT_BOS_TOKEN = "</s>"
DEFAULT_UNK_TOKEN = "</s>"


def prepare_llama_tokenizer_and_embedding(
    tokenizer: transformers.PreTrainedTokenizer,
    model: transformers.PreTrainedModel,
    special_tokens_dict: Dict = dict(pad_token=DEFAULT_PAD_TOKEN),
):
    """prepare llama tokenizer and embedding."""

    if tokenizer.pad_token is None:
        smart_tokenizer_and_embedding_resize(
            special_tokens_dict=dict(pad_token=DEFAULT_PAD_TOKEN),
            tokenizer=tokenizer,
            model=model,
        )

    tokenizer.add_special_tokens({
        "eos_token": DEFAULT_EOS_TOKEN,
        "bos_token": DEFAULT_BOS_TOKEN,
        "unk_token": DEFAULT_UNK_TOKEN,
    })

    return tokenizer


def smart_tokenizer_and_embedding_resize(
    tokenizer: transformers.PreTrainedTokenizer,
    model: transformers.PreTrainedModel,
    special_tokens_dict: Dict = dict(pad_token=DEFAULT_PAD_TOKEN),
):
    """Resize tokenizer and embedding.

    Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
    """

    if tokenizer.pad_token is None:
        num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
        model.resize_token_embeddings(len(tokenizer))

        if num_new_tokens > 0:
            input_embeddings = model.get_input_embeddings().weight.data
            output_embeddings = model.get_output_embeddings().weight.data

            input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
            output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)

            input_embeddings[-num_new_tokens:] = input_embeddings_avg
            output_embeddings[-num_new_tokens:] = output_embeddings_avg
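For context, the (now removed) helper was typically used to give a LLaMA-style tokenizer a pad token and grow the embedding matrix to match. A hedged sketch with assumed checkpoint path and model/tokenizer names, not taken from this commit:

    from transformers import AutoTokenizer, LlamaForCausalLM

    # assumed checkpoint path; any LLaMA-style model whose tokenizer lacks a pad token works the same way
    model = LlamaForCausalLM.from_pretrained("path/to/llama")
    tokenizer = AutoTokenizer.from_pretrained("path/to/llama")

    tokenizer = prepare_llama_tokenizer_and_embedding(tokenizer, model)
    # the embedding matrix is now resized and the new [PAD] row initialized to the mean of the old embeddings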