OpenDAS / ColossalAI · Commits

Commit 0c1a16ea (unverified)
[util] standard checkpoint function naming (#1377)
Authored Jul 28, 2022 by Frank Lee; committed by GitHub on Jul 28, 2022
Parent: 52bc2dc2
Showing 1 changed file with 8 additions and 8 deletions:
colossalai/utils/checkpoint/module_checkpoint.py (+8, -8)
colossalai/utils/checkpoint/module_checkpoint.py (view file @ 0c1a16ea)
@@ -6,7 +6,7 @@ from colossalai.utils.checkpoint.utils import gather_tensor, scatter_tensor
 from typing import Optional
 
 
-def save_checkpoint(dire: str,
+def save_checkpoint(path: str,
                     epoch: int,
                     model: torch.nn.Module,
                     optimizer: Optional[ColossalaiOptimizer] = None,
@@ -16,7 +16,7 @@ def save_checkpoint(dire: str,
     """save_checkpoint
     save a model, whose parameters are `ColoTensor`s.
     Args:
-        dire (str): directory to save the checkpoint files.
+        path (str): directory to save the checkpoint files.
         epoch (int): the number of epoch
         model (torch.nn.Module): a torch module initialized by ColoInitContext
         optimizer (ColossalaiOptimizer, optional): optimizers. Defaults to None.
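For reference, a minimal call-site sketch under the new parameter name. Only the signature comes from the diff above; the directory, epoch, and model here are illustrative, and the function assumes torch.distributed is already initialized (it uses rank-0 saving and barriers internally):

    import torch
    from colossalai.utils.checkpoint.module_checkpoint import save_checkpoint

    model = torch.nn.Linear(4, 4)          # stand-in; real use builds the module under ColoInitContext
    save_checkpoint(path='./checkpoints',  # a directory; '/epoch_{}_model.pth' is appended to it
                    epoch=10,
                    model=model,
                    optimizer=None)        # optional ColossalaiOptimizer; None skips optimizer state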
@@ -39,7 +39,7 @@ def save_checkpoint(dire: str,
                 delattr(v, 'save_ready')
         # model saving
         save_state = {'epoch': epoch, 'model': model_state}
-        torch.save(save_state, dire + '/epoch_{}_model.pth'.format(epoch), *args, **kwargs)
+        torch.save(save_state, path + '/epoch_{}_model.pth'.format(epoch), *args, **kwargs)
     # delete old dicts
     del model_state
@@ -57,7 +57,7 @@ def save_checkpoint(dire: str,
         if rank == 0:
             save_state = {'epoch': epoch, 'optim': optim_state}
-            torch.save(save_state, dire + '/epoch_{}_optim.pth'.format(epoch), *args, **kwargs)
+            torch.save(save_state, path + '/epoch_{}_optim.pth'.format(epoch), *args, **kwargs)
         # recover colo tensors in rank0
         for k, v in optimizer.state_dict()['state'].items():
             for n, t in v.items():
@@ -71,7 +71,7 @@ def save_checkpoint(dire: str,
     dist.barrier()
 
 
-def load_checkpoint(dire,
+def load_checkpoint(path,
                     epoch: int,
                     model: torch.nn.Module,
                     optimizer: Optional[ColossalaiOptimizer] = None,
@@ -81,7 +81,7 @@ def load_checkpoint(dire,
     """load_checkpoint
     load a model, whose parameters are `ColoTensor`s.
     Args:
-        dire (_type_): _description_
+        path (_type_): _description_
         epoch (int): _description_
         rank (int): _description_
         model (torch.nn.Module): _description_
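The matching load side, as a sketch under the same assumptions (a directory previously written by save_checkpoint, an initialized process group, and the illustrative values from the sketch above):

    from colossalai.utils.checkpoint.module_checkpoint import load_checkpoint

    # Rank 0 reads './checkpoints/epoch_10_model.pth' (and the matching
    # '..._optim.pth' when an optimizer is passed), then all ranks sync.
    load_checkpoint(path='./checkpoints',
                    epoch=10,
                    model=model,
                    optimizer=None)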
@@ -96,7 +96,7 @@ def load_checkpoint(dire,
             gather_tensor(p)
     if rank == 0:
-        load_state = torch.load(dire + '/epoch_{}_model.pth'.format(epoch), *args, **kwargs)
+        load_state = torch.load(path + '/epoch_{}_model.pth'.format(epoch), *args, **kwargs)
         model.load_state_dict(load_state['model'])
     dist.barrier()
@@ -118,7 +118,7 @@ def load_checkpoint(dire,
             gather_tensor(t)
         if rank == 0:
-            colo_checkpoint = torch.load(dire + '/epoch_{}_optim.pth'.format(epoch), *args, **kwargs)
+            colo_checkpoint = torch.load(path + '/epoch_{}_optim.pth'.format(epoch), *args, **kwargs)
             optimizer.load_state_dict(colo_checkpoint['optim'])
         dist.barrier()
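For readers new to the pattern these hunks keep repeating: tensors are gathered to rank 0, only rank 0 touches the file, and dist.barrier() keeps the other ranks in step. A simplified, standalone sketch of that pattern in plain torch.distributed (not the ColossalAI internals; save_on_rank0 is a hypothetical helper):

    import torch
    import torch.distributed as dist

    def save_on_rank0(state: dict, filename: str) -> None:
        # Only one process writes the file; the barrier stops other
        # ranks from racing ahead and reading it before it exists.
        if dist.get_rank() == 0:
            torch.save(state, filename)
        dist.barrier()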