OpenDAS / ColossalAI / Commits

Commit b6fd165f (unverified)
Authored Jul 28, 2022 by HELSON; committed by GitHub, Jul 28, 2022

[checkpoint] add kwargs for load_state_dict (#1374)
Parent: 50dec605

Showing 1 changed file with 15 additions and 14 deletions.

colossalai/utils/checkpoint/module_checkpoint.py (+15 -14)
@@ -3,7 +3,7 @@ import torch.distributed as dist
 from colossalai.tensor import ColoTensor
 from colossalai.nn.optimizer import ColossalaiOptimizer
 from colossalai.utils.checkpoint.utils import gather_tensor, scatter_tensor
-from typing import Optional
+from typing import Optional, Dict


 def save_checkpoint(path: str,
@@ -71,22 +71,23 @@ def save_checkpoint(path: str,
     dist.barrier()


-def load_checkpoint(path,
+def load_checkpoint(path: str,
                     epoch: int,
                     model: torch.nn.Module,
                     optimizer: Optional[ColossalaiOptimizer] = None,
                     lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
-                    *args,
-                    **kwargs):
+                    torch_load_kwargs: Optional[Dict] = None,
+                    load_state_dict_kwargs: Optional[Dict] = None):
     """load_checkpoint
     load a model, whose parameters are `ColoTensor`s.
     Args:
-        path (_type_): _description_
-        epoch (int): _description_
-        rank (int): _description_
-        model (torch.nn.Module): _description_
-        optimizer (ColossalaiOptimizer, optional): _description_. Defaults to None.
-        lr_scheduler (torch.optim.lr_scheduler._LRScheduler, optional): _description_. Defaults to None.
+        path (str): directory to save the checkpoint files.
+        epoch (int): the number of epoch
+        model (torch.nn.Module): a torch module initialized by ColoInitContext
+        optimizer (ColossalaiOptimizer, optional): optimizers. Defaults to None.
+        lr_scheduler (torch.optim.lr_scheduler._LRScheduler, optional): lr schedule. Defaults to None.
+        torch_load_kwargs: (dict, optional): The kwargs of torch.load inside the function
+        load_state_dict_kwargs (dict, optional): The kwargs of load_state_dict inside the function
     """
     rank = dist.get_rank()
     mapping = dict()
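The signature change above replaces the opaque *args/**kwargs pass-through (previously forwarded verbatim to torch.load, per the hunks below) with two explicit dictionaries, so keyword arguments intended for torch.load can no longer be conflated with those intended for load_state_dict. A minimal before/after sketch of a call site under the old and new signatures; the map_location and strict values are illustrative, not part of this commit:

    # Before this commit: extra kwargs were forwarded to torch.load only.
    load_checkpoint('./checkpoints', 10, model, map_location='cpu')

    # After this commit: each callee gets its own explicit kwargs dict.
    load_checkpoint('./checkpoints', 10, model,
                    torch_load_kwargs={'map_location': 'cpu'},
                    load_state_dict_kwargs={'strict': False})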
@@ -96,8 +97,8 @@ def load_checkpoint(path,
         gather_tensor(p)

     if rank == 0:
-        load_state = torch.load(path + '/epoch_{}_model.pth'.format(epoch), *args, **kwargs)
-        model.load_state_dict(load_state['model'])
+        load_state = torch.load(path + '/epoch_{}_model.pth'.format(epoch), **torch_load_kwargs)
+        model.load_state_dict(load_state['model'], **load_state_dict_kwargs)
     dist.barrier()

     # scatter loaded parameters
@@ -118,8 +119,8 @@ def load_checkpoint(path,
             gather_tensor(t)

         if rank == 0:
-            colo_checkpoint = torch.load(path + '/epoch_{}_optim.pth'.format(epoch), *args, **kwargs)
-            optimizer.load_state_dict(colo_checkpoint['optim'])
+            colo_checkpoint = torch.load(path + '/epoch_{}_optim.pth'.format(epoch), **torch_load_kwargs)
+            optimizer.load_state_dict(colo_checkpoint['optim'], **load_state_dict_kwargs)
         dist.barrier()

         for k, v in optimizer.state_dict()['state'].items():
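Note that in the hunks shown, both new parameters default to None yet are unpacked with ** unconditionally; unless a None-to-{} guard exists in the collapsed context, a rank-0 caller that omits them would hit a TypeError, so passing (possibly empty) dicts is the safe pattern. Also, the same load_state_dict_kwargs dict is forwarded to both model.load_state_dict and optimizer.load_state_dict, so a key like strict (accepted by nn.Module.load_state_dict but not by torch.optim optimizers) is only safe when no optimizer is passed. A hedged usage sketch under those assumptions; the directory, epoch number, and stand-in module are illustrative, not from this commit:

    import torch
    from colossalai.utils.checkpoint.module_checkpoint import load_checkpoint

    # Stand-in module for illustration only; the real function expects a module
    # whose parameters are ColoTensors (e.g. built under ColoInitContext) and an
    # initialized torch.distributed process group.
    model = torch.nn.Linear(4, 4)

    load_checkpoint('./checkpoints',                            # assumed to hold epoch_10_model.pth
                    epoch=10,
                    model=model,
                    torch_load_kwargs={'map_location': 'cpu'},  # forwarded to torch.load
                    load_state_dict_kwargs={'strict': True})    # forwarded to model.load_state_dict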