Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
773955ab
Unverified
Commit
773955ab
authored
Apr 04, 2023
by
Yuanchen
Committed by
GitHub
Apr 04, 2023
Browse files
fix save_model inin naive and ddp strategy (#3436)
Co-authored-by:
Yuanchen Xu
<
yuanchen.xu00@gmail.com
>
parent
1beb85cc
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
49 additions
and
12 deletions
+49
-12
applications/Chat/coati/trainer/strategies/ddp.py
applications/Chat/coati/trainer/strategies/ddp.py
+26
-8
applications/Chat/coati/trainer/strategies/naive.py
applications/Chat/coati/trainer/strategies/naive.py
+23
-4
No files found.
applications/Chat/coati/trainer/strategies/ddp.py
View file @
773955ab
from
typing
import
Optional
import
os
import
random
...
...
@@ -5,12 +7,13 @@ import numpy as np
import
torch
import
torch.distributed
as
dist
import
torch.nn
as
nn
from
coati.models.base
import
Actor
from
coati.models.base
import
LM
,
Actor
,
RewardModel
from
coati.models.lora
import
LoraLinear
from
coati.replay_buffer
import
ReplayBuffer
from
torch.nn.parallel
import
DistributedDataParallel
as
DDP
from
torch.optim
import
Optimizer
from
torch.utils.data
import
DataLoader
from
transformers.tokenization_utils_base
import
PreTrainedTokenizerBase
from
.base
import
Strategy
from
.naive
import
NaiveStrategy
...
...
@@ -72,16 +75,31 @@ class DDPStrategy(NaiveStrategy):
model
:
DDP
=
Strategy
.
_unwrap_actor
(
actor
)
return
model
.
module
def
save_model
(
self
,
model
:
nn
.
Module
,
path
:
str
,
only_rank0
:
bool
=
False
)
->
None
:
def
save_model
(
self
,
model
:
nn
.
Module
,
path
:
str
,
only_rank0
:
bool
=
False
,
tokenizer
:
Optional
[
PreTrainedTokenizerBase
]
=
None
)
->
None
:
if
only_rank0
and
dist
.
get_rank
()
!=
0
:
return
None
for
module
in
model
.
modules
():
if
isinstance
(
module
,
LoraLinear
):
module
.
merge_weights
=
True
module
.
eval
()
if
isinstance
(
model
,
RewardModel
):
state_dict
=
model
.
state_dict
()
if
only_rank0
and
dist
.
get_rank
()
!=
0
:
return
model
=
model
.
model
.
module
torch
.
save
(
state_dict
,
path
)
else
:
try
:
if
isinstance
(
model
,
LM
):
model
=
model
.
model
model
.
save_pretrained
(
path
)
if
tokenizer
is
not
None
:
tokenizer
.
save_pretrained
(
path
)
except
AttributeError
:
state_dict
=
model
.
state_dict
()
if
only_rank0
and
dist
.
get_rank
()
!=
0
:
return
torch
.
save
(
state_dict
,
path
)
def
save_optimizer
(
self
,
optimizer
:
Optimizer
,
path
:
str
,
only_rank0
:
bool
=
False
)
->
None
:
...
...
applications/Chat/coati/trainer/strategies/naive.py
View file @
773955ab
from
typing
import
Any
from
typing
import
Any
,
Optional
import
torch
import
torch.nn
as
nn
import
torch.optim
as
optim
from
coati.replay_buffer
import
ReplayBuffer
from
coati.models.base
import
LM
,
RewardModel
from
coati.models.lora
import
LoraLinear
from
torch.optim
import
Optimizer
from
torch.utils.data
import
DataLoader
from
transformers.tokenization_utils_base
import
PreTrainedTokenizerBase
from
.base
import
Strategy
...
...
@@ -38,9 +41,25 @@ class NaiveStrategy(Strategy):
pin_memory
=
pin_memory
,
collate_fn
=
replay_buffer
.
collate_fn
)
def
save_model
(
self
,
model
:
nn
.
Module
,
path
:
str
,
only_rank0
:
bool
=
False
)
->
None
:
unwrapped_model
=
self
.
_unwrap_model
(
model
)
torch
.
save
(
unwrapped_model
.
state_dict
(),
path
)
def
save_model
(
self
,
model
:
nn
.
Module
,
path
:
str
,
only_rank0
:
bool
=
False
,
tokenizer
:
Optional
[
PreTrainedTokenizerBase
]
=
None
)
->
None
:
for
module
in
model
.
modules
():
if
isinstance
(
module
,
LoraLinear
):
module
.
merge_weights
=
True
module
.
eval
()
if
isinstance
(
model
,
RewardModel
):
state_dict
=
model
.
state_dict
()
torch
.
save
(
state_dict
,
path
)
else
:
try
:
if
isinstance
(
model
,
LM
):
model
=
model
.
model
model
.
save_pretrained
(
path
)
if
tokenizer
is
not
None
:
tokenizer
.
save_pretrained
(
path
)
except
AttributeError
:
state_dict
=
model
.
state_dict
()
torch
.
save
(
state_dict
,
path
)
def
load_model
(
self
,
model
:
nn
.
Module
,
path
:
str
,
map_location
:
Any
=
None
,
strict
:
bool
=
True
)
->
None
:
unwrapped_model
=
self
.
_unwrap_model
(
model
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment