Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
773955ab
Unverified
Commit
773955ab
authored
Apr 04, 2023
by
Yuanchen
Committed by
GitHub
Apr 04, 2023
Browse files
fix save_model in naive and ddp strategy (#3436)
Co-authored-by:
Yuanchen Xu
<
yuanchen.xu00@gmail.com
>
parent
1beb85cc
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
49 additions
and
12 deletions
+49
-12
applications/Chat/coati/trainer/strategies/ddp.py
applications/Chat/coati/trainer/strategies/ddp.py
+26
-8
applications/Chat/coati/trainer/strategies/naive.py
applications/Chat/coati/trainer/strategies/naive.py
+23
-4
No files found.
applications/Chat/coati/trainer/strategies/ddp.py
View file @
773955ab
from
typing
import
Optional
import
os
import
os
import
random
import
random
...
@@ -5,12 +7,13 @@ import numpy as np
...
@@ -5,12 +7,13 @@ import numpy as np
import
torch
import
torch
import
torch.distributed
as
dist
import
torch.distributed
as
dist
import
torch.nn
as
nn
import
torch.nn
as
nn
from
coati.models.base
import
Actor
from
coati.models.base
import
LM
,
Actor
,
RewardModel
from
coati.models.lora
import
LoraLinear
from
coati.models.lora
import
LoraLinear
from
coati.replay_buffer
import
ReplayBuffer
from
coati.replay_buffer
import
ReplayBuffer
from
torch.nn.parallel
import
DistributedDataParallel
as
DDP
from
torch.nn.parallel
import
DistributedDataParallel
as
DDP
from
torch.optim
import
Optimizer
from
torch.optim
import
Optimizer
from
torch.utils.data
import
DataLoader
from
torch.utils.data
import
DataLoader
from
transformers.tokenization_utils_base
import
PreTrainedTokenizerBase
from
.base
import
Strategy
from
.base
import
Strategy
from
.naive
import
NaiveStrategy
from
.naive
import
NaiveStrategy
...
@@ -72,17 +75,32 @@ class DDPStrategy(NaiveStrategy):
...
@@ -72,17 +75,32 @@ class DDPStrategy(NaiveStrategy):
model
:
DDP
=
Strategy
.
_unwrap_actor
(
actor
)
model
:
DDP
=
Strategy
.
_unwrap_actor
(
actor
)
return
model
.
module
return
model
.
module
def
save_model
(
self
,
model
:
nn
.
Module
,
path
:
str
,
only_rank0
:
bool
=
False
)
->
None
:
def
save_model
(
self
,
model
:
nn
.
Module
,
path
:
str
,
only_rank0
:
bool
=
False
,
tokenizer
:
Optional
[
PreTrainedTokenizerBase
]
=
None
)
->
None
:
if
only_rank0
and
dist
.
get_rank
()
!=
0
:
return
None
for
module
in
model
.
modules
():
for
module
in
model
.
modules
():
if
isinstance
(
module
,
LoraLinear
):
if
isinstance
(
module
,
LoraLinear
):
module
.
merge_weights
=
True
module
.
merge_weights
=
True
module
.
eval
()
module
.
eval
()
if
only_rank0
and
dist
.
get_rank
()
!=
0
:
if
isinstance
(
model
,
RewardModel
):
return
state_dict
=
model
.
state_dict
()
model
=
model
.
model
.
module
if
only_rank0
and
dist
.
get_rank
()
!=
0
:
state_dict
=
model
.
state_dict
()
return
torch
.
save
(
state_dict
,
path
)
torch
.
save
(
state_dict
,
path
)
else
:
try
:
if
isinstance
(
model
,
LM
):
model
=
model
.
model
model
.
save_pretrained
(
path
)
if
tokenizer
is
not
None
:
tokenizer
.
save_pretrained
(
path
)
except
AttributeError
:
state_dict
=
model
.
state_dict
()
if
only_rank0
and
dist
.
get_rank
()
!=
0
:
return
torch
.
save
(
state_dict
,
path
)
def
save_optimizer
(
self
,
optimizer
:
Optimizer
,
path
:
str
,
only_rank0
:
bool
=
False
)
->
None
:
def
save_optimizer
(
self
,
optimizer
:
Optimizer
,
path
:
str
,
only_rank0
:
bool
=
False
)
->
None
:
if
only_rank0
and
dist
.
get_rank
()
!=
0
:
if
only_rank0
and
dist
.
get_rank
()
!=
0
:
...
...
applications/Chat/coati/trainer/strategies/naive.py
View file @
773955ab
from
typing
import
Any
from
typing
import
Any
,
Optional
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
import
torch.optim
as
optim
import
torch.optim
as
optim
from
coati.replay_buffer
import
ReplayBuffer
from
coati.replay_buffer
import
ReplayBuffer
from
coati.models.base
import
LM
,
RewardModel
from
coati.models.lora
import
LoraLinear
from
torch.optim
import
Optimizer
from
torch.optim
import
Optimizer
from
torch.utils.data
import
DataLoader
from
torch.utils.data
import
DataLoader
from
transformers.tokenization_utils_base
import
PreTrainedTokenizerBase
from
.base
import
Strategy
from
.base
import
Strategy
...
@@ -38,9 +41,25 @@ class NaiveStrategy(Strategy):
...
@@ -38,9 +41,25 @@ class NaiveStrategy(Strategy):
pin_memory
=
pin_memory
,
pin_memory
=
pin_memory
,
collate_fn
=
replay_buffer
.
collate_fn
)
collate_fn
=
replay_buffer
.
collate_fn
)
def
save_model
(
self
,
model
:
nn
.
Module
,
path
:
str
,
only_rank0
:
bool
=
False
)
->
None
:
def
save_model
(
self
,
model
:
nn
.
Module
,
path
:
str
,
only_rank0
:
bool
=
False
,
tokenizer
:
Optional
[
PreTrainedTokenizerBase
]
=
None
)
->
None
:
unwrapped_model
=
self
.
_unwrap_model
(
model
)
for
module
in
model
.
modules
():
torch
.
save
(
unwrapped_model
.
state_dict
(),
path
)
if
isinstance
(
module
,
LoraLinear
):
module
.
merge_weights
=
True
module
.
eval
()
if
isinstance
(
model
,
RewardModel
):
state_dict
=
model
.
state_dict
()
torch
.
save
(
state_dict
,
path
)
else
:
try
:
if
isinstance
(
model
,
LM
):
model
=
model
.
model
model
.
save_pretrained
(
path
)
if
tokenizer
is
not
None
:
tokenizer
.
save_pretrained
(
path
)
except
AttributeError
:
state_dict
=
model
.
state_dict
()
torch
.
save
(
state_dict
,
path
)
def
load_model
(
self
,
model
:
nn
.
Module
,
path
:
str
,
map_location
:
Any
=
None
,
strict
:
bool
=
True
)
->
None
:
def
load_model
(
self
,
model
:
nn
.
Module
,
path
:
str
,
map_location
:
Any
=
None
,
strict
:
bool
=
True
)
->
None
:
unwrapped_model
=
self
.
_unwrap_model
(
model
)
unwrapped_model
=
self
.
_unwrap_model
(
model
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment