OpenDAS / ColossalAI / Commits / 52736205

Unverified commit 52736205, authored Jul 06, 2022 by Jiarui Fang, committed by GitHub on Jul 06, 2022.

[checkpoint] make unitest faster (#1217)

Parent: f38006ea

Showing 2 changed files with 34 additions and 40 deletions (+34 -40):
colossalai/utils/checkpoint/module_checkpoint.py  (+12 -23)
tests/test_utils/test_colo_checkpoint.py          (+22 -17)
colossalai/utils/checkpoint/module_checkpoint.py
@@ -5,7 +5,8 @@ import collections
 from torch.optim.lr_scheduler import CosineAnnealingLR as _CosineAnnealingLR
 from colossalai.utils.model.colo_init_context import colo_state_dict

-def save_checkpoint(dire,
+def save_checkpoint(dire: str,
                     epoch: int,
                     model: torch.nn.Module,
                     optimizer: torch.optim.Optimizer = None,
@@ -15,30 +16,21 @@ def save_checkpoint(dire,
     """save_checkpoint
     save a model, whose parameters are `ColoTensor`s.
     Args:
-        dire (_type_): _description_
-        epoch (int): _description_
-        model (torch.nn.Module): _description_
-        optimizer (torch.optim.Optimizer, optional): _description_. Defaults to None.
-        lr_scheduler (torch.optim.lr_scheduler._LRScheduler, optional): _description_. Defaults to None.
+        dire (str): directory to save the checkpoint files.
+        epoch (int): the number of epoch
+        model (torch.nn.Module): a torch module initialized by ColoInitContext
+        optimizer (torch.optim.Optimizer, optional): optimizers. Defaults to None.
+        lr_scheduler (torch.optim.lr_scheduler._LRScheduler, optional): lr schedule. Defaults to None.
     """
     model_state = {'epoch': epoch, 'model': colo_state_dict(model, state_dict_func=nn.Module.state_dict)}
     if dist.get_rank() == 0:
         torch.save(model_state, dire + '/epoch_{}_model.pth'.format(epoch))
     lr_scheduler_dict = lr_scheduler.state_dict()
     lr_scheduler_dict['after_scheduler'] = lr_scheduler_dict['after_scheduler'].state_dict()
     optim_state = {'epoch': epoch, 'optimizer': optimizer.state_dict(), 'lr_scheduler': lr_scheduler_dict}
     torch.save(optim_state, dire + '/epoch_{}_optim_rank_{}.pth'.format(epoch, dist.get_rank()))

 def load_checkpoint(dire,
                     epoch: int,
                     rank: int,
@@ -64,10 +56,7 @@ def load_checkpoint(dire,
     optimizer.load_state_dict(optim_state['optimizer'])
     lr_scheduler_dict = optim_state['lr_scheduler']
     after_scheduler_dict = lr_scheduler_dict['after_scheduler']
-    lr_scheduler_dict['after_scheduler'] = _CosineAnnealingLR(
-        optimizer,
-        after_scheduler_dict['T_max'],
-        after_scheduler_dict['eta_min'],
-        after_scheduler_dict['last_epoch']
-    )
+    lr_scheduler_dict['after_scheduler'] = _CosineAnnealingLR(optimizer, after_scheduler_dict['T_max'],
+                                                              after_scheduler_dict['eta_min'],
+                                                              after_scheduler_dict['last_epoch'])
     lr_scheduler.load_state_dict(lr_scheduler_dict)
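As a rough orientation for how these helpers are meant to be driven, here is a minimal sketch based only on the signatures and docstring shown above. The argument order of load_checkpoint after `rank` is an assumption, since those lines are collapsed in this view; treat the exact call as illustrative, not as the API contract.

# Hypothetical driver, not part of this commit.
import torch.distributed as dist
from colossalai.utils.checkpoint import save_checkpoint, load_checkpoint

def checkpoint_round_trip(dire, epoch, model, optimizer, lr_scheduler):
    # Rank 0 writes epoch_{epoch}_model.pth; every rank writes its own
    # epoch_{epoch}_optim_rank_{rank}.pth (see save_checkpoint above).
    save_checkpoint(dire, epoch, model, optimizer, lr_scheduler)
    dist.barrier()    # make sure all files exist before any rank reads them
    # Assumed call shape: the parameters after `rank` are not visible in this diff.
    load_checkpoint(dire, epoch, dist.get_rank(), model, optimizer, lr_scheduler)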
tests/test_utils/test_colo_checkpoint.py
 from abc import ABC, abstractmethod
-import os, sys, shutil
+import os, shutil
 import torch
 import torch.nn as nn
 import pytest
 import copy
-import operator
+from functools import partial
+import colossalai
+from colossalai.context.parallel_mode import ParallelMode
 import torch.multiprocessing as mp
 import torch.distributed as dist
-import colossalai
 from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.utils.cuda import get_current_device
 from colossalai.utils import free_port
 from colossalai.utils.model.colo_init_context import ColoInitContext
-from colossalai.tensor import ColoTensorSpec, ComputePattern, ComputeSpec, DistSpecManager, distspec, ProcessGroup, ColoTensor
-from colossalai.core import global_context as gpc
-from functools import partial
+from colossalai.tensor import ComputePattern, ComputeSpec, DistSpecManager, distspec, ProcessGroup
 from colossalai.nn.parallel.data_parallel import ColoDDP
 from colossalai.utils.checkpoint import save_checkpoint, load_checkpoint
 from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
@@ -46,15 +45,17 @@ class DummyDataGenerator(ABC):
 class DummyDataLoader(DummyDataGenerator):
-    batch_size = 128
-    category = 16
-    feature_size = 256
+
+    def __init__(self, batch_size, category, feature_size, length=10):
+        super().__init__(length)
+        self.batch_size = batch_size
+        self.category = category
+        self.feature_size = feature_size

     def generate(self):
         image_dict = {}
         image_dict['pixel_values'] = torch.rand(
-            DummyDataLoader.batch_size, DummyDataLoader.feature_size, device=get_current_device()) * 2 - 1
-        image_dict['label'] = torch.randint(DummyDataLoader.category, (DummyDataLoader.batch_size,),
+            self.batch_size, self.feature_size, device=get_current_device()) * 2 - 1
+        image_dict['label'] = torch.randint(self.category, (self.batch_size,),
                                             dtype=torch.int64,
                                             device=get_current_device())
         return image_dict
@@ -102,11 +103,15 @@ def remove(path):
 def run_checkpoint(init_spec_func, use_ddp, test_epoch, pg):
-    train_dataloader = DummyDataLoader(length=16)
+    batch = 3
+    feature = 32
+    category = 16
+    train_dataloader = DummyDataLoader(batch, category, feature, length=16)
     with ColoInitContext(device=get_current_device()):
-        model = MLP(256, 16, 64)
-        model_reload = MLP(256, 16, 64)
-        model_ref = MLP(256, 16, 64)
+        model = MLP(feature, category)
+        model_reload = MLP(feature, category)
+        model_ref = MLP(feature, category)

     model = model.cuda()
     model_reload = model_reload.cuda()
     model_ref = model_ref.cuda()
...
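The speedup in this test comes entirely from shrinking the dummy workload; a back-of-the-envelope comparison of the per-step input size (illustrative arithmetic only, not part of the commit):

# Old class defaults vs. the new values passed in run_checkpoint (illustrative only).
old_elems = 128 * 256    # batch_size=128, feature_size=256 -> 32768 input floats per step
new_elems = 3 * 32       # batch=3,        feature=32       -> 96 input floats per step
print(old_elems // new_elems)    # ~341x fewer input elements per forward pass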