Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
6b305a99
Unverified
Commit
6b305a99
authored
May 23, 2023
by
wukong1992
Committed by
GitHub
May 23, 2023
Browse files
[booster] torch fsdp fix ckpt (#3788)
parent
9265f2d4
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
223 additions
and
179 deletions
+223
-179
colossalai/booster/booster.py
colossalai/booster/booster.py
+1
-1
colossalai/booster/plugin/torch_fsdp_plugin.py
colossalai/booster/plugin/torch_fsdp_plugin.py
+69
-138
colossalai/checkpoint_io/general_checkpoint_io.py
colossalai/checkpoint_io/general_checkpoint_io.py
+39
-36
tests/test_booster/test_plugin/test_torch_fsdp_plugin.py
tests/test_booster/test_plugin/test_torch_fsdp_plugin.py
+1
-4
tests/test_checkpoint_io/test_torch_fsdp_checkpoint_io.py
tests/test_checkpoint_io/test_torch_fsdp_checkpoint_io.py
+113
-0
No files found.
colossalai/booster/booster.py
View file @
6b305a99
...
...
@@ -196,7 +196,7 @@ class Booster:
If true, the checkpoint will be a folder. Otherwise, it will be a single file. Defaults to False.
size_per_shard (int, optional): Maximum size of checkpoint shard file in MB. This is useful only when ``shard=True``. Defaults to 1024.
"""
self
.
checkpoint_io
.
save_model
(
model
,
checkpoint
,
prefix
,
shard
,
size_per_shard
)
self
.
checkpoint_io
.
save_model
(
model
,
checkpoint
=
checkpoint
,
shard
=
shard
,
size_per_shard
=
size_per_shard
)
def
load_optimizer
(
self
,
optimizer
:
Optimizer
,
checkpoint
:
str
):
"""Load optimizer from checkpoint.
...
...
colossalai/booster/plugin/torch_fsdp_plugin.py
View file @
6b305a99
from
pathlib
import
Path
from
typing
import
Callable
,
Iterable
,
Iterator
,
List
,
Optional
,
Tuple
,
Union
import
torch
...
...
@@ -5,30 +6,18 @@ import torch.nn as nn
from
packaging
import
version
from
torch.distributed
import
ProcessGroup
if
version
.
parse
(
torch
.
__version__
)
>=
version
.
parse
(
'1.12.0'
)
and
version
.
parse
(
torch
.
__version__
)
<
version
.
parse
(
'
2.0
.0'
):
if
version
.
parse
(
torch
.
__version__
)
>=
version
.
parse
(
'
1.12
.0'
):
from
torch.distributed.fsdp
import
FullStateDictConfig
from
torch.distributed.fsdp
import
FullyShardedDataParallel
as
FSDP
from
torch.distributed.fsdp
import
StateDictType
from
torch.distributed.fsdp.fully_sharded_data_parallel
import
(
BackwardPrefetch
,
CPUOffload
,
MixedPrecision
,
ShardingStrategy
,
)
elif
version
.
parse
(
torch
.
__version__
)
>=
version
.
parse
(
'2.0.0'
):
from
torch.distributed.fsdp
import
FullyShardedDataParallel
as
FSDP
from
torch.distributed.fsdp._init_utils
import
ProcessGroupType
from
torch.distributed.fsdp.api
import
(
BackwardPrefetch
,
CPUOffload
,
FullOptimStateDictConfig
,
FullStateDictConfig
,
MixedPrecision
,
ShardingStrategy
,
StateDictType
,
)
from
torch.distributed.fsdp.wrap
import
_FSDPPolicy
else
:
raise
RuntimeError
(
"FSDP is not supported while torch version under 1.12.0."
)
...
...
@@ -36,7 +25,7 @@ from torch.optim import Optimizer
from
torch.optim.lr_scheduler
import
_LRScheduler
as
LRScheduler
from
torch.utils.data
import
DataLoader
from
colossalai.checkpoint_io
import
CheckpointIO
,
GeneralCheckpointIO
from
colossalai.checkpoint_io
import
CheckpointIO
,
GeneralCheckpointIO
,
utils
from
colossalai.cluster
import
DistCoordinator
from
colossalai.interface
import
ModelWrapper
,
OptimizerWrapper
...
...
@@ -51,102 +40,71 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
super
().
__init__
()
self
.
coordinator
=
DistCoordinator
()
def
__set_model_optim_state
(
self
,
model
,
state_dict_type
,
state_dict_config
,
optim_state_dict_config
,
):
return
FSDP
.
set_state_dict_type
(
model
,
state_dict_type
,
state_dict_config
,
optim_state_dict_config
)
def
load_sharded_model
(
self
,
model
:
nn
.
Module
,
checkpoint
:
str
):
# TODO(jishaomin): implement this method as it can be supported by Huggingface model
raise
NotImplementedError
(
"Torch FSDP sharded model checkpoint is not supported yet."
)
def
load_sharded_optimizer
(
self
,
model
:
nn
.
Module
,
optimizer
:
Optimizer
,
checkpoint
:
str
):
# TODO(jishaomin): implement this method as it can be supported by Huggingface model
raise
NotImplementedError
(
"Torch FSDP sharded model checkpoint is not supported yet."
)
def
save_sharded_model
(
self
,
model
:
nn
.
Module
,
checkpoint
:
str
):
# TODO(jishaomin): implement this method as it can be supported by Huggingface model
raise
NotImplementedError
(
"Torch FSDP sharded model checkpoint is not supported yet."
)
def
load_unsharded_model
(
self
,
model
:
nn
.
Module
,
checkpoint
:
str
,
strict
:
bool
):
checkpoint
=
utils
.
load_state_dict
(
checkpoint
)
model
.
load_state_dict
(
checkpoint
)
def
save_sharded_optimizer
(
self
,
model
:
nn
.
Module
,
optimizer
:
Optimizer
,
checkpoint
:
str
):
def
load_unsharded_optimizer
(
self
,
optimizer
:
Optimizer
,
checkpoint
:
Path
):
checkpoint
=
utils
.
load_state_dict
(
checkpoint
)
fsdp_model
=
optimizer
.
unwrap_model
()
sharded_osd
=
FSDP
.
scatter_full_optim_state_dict
(
checkpoint
,
fsdp_model
)
optimizer
.
load_state_dict
(
sharded_osd
)
# TODO(jishaomin): implement this method as it can be supported by Huggingface model
raise
NotImplementedError
(
"Torch FSDP sharded model checkpoint is not supported yet."
)
def
load_unsharded_model
(
self
,
model
:
nn
.
Module
,
checkpoint
:
str
):
def
save_unsharded_model
(
self
,
model
:
nn
.
Module
,
checkpoint
:
str
,
gather_dtensor
:
bool
,
use_safetensors
:
bool
):
"""
Load
model
from
checkpoint
with automatic unwrapping
.
Save
model
to
checkpoint
but only on master process
.
"""
# the model should be unwrapped in self.load_model via ModelWrapper.unwrap
cfg
=
FullStateDictConfig
(
offload_to_cpu
=
True
,
rank0_only
=
True
)
with
FSDP
.
state_dict_type
(
model
,
StateDictType
.
FULL_STATE_DICT
,
cfg
):
full_model_state
=
model
.
state_dict
()
utils
.
save_state_dict
(
full_model_state
,
checkpoint_file_path
=
checkpoint
,
use_safetensors
=
use_safetensors
)
if
version
.
parse
(
torch
.
__version__
)
>=
version
.
parse
(
'1.12.0'
)
and
version
.
parse
(
torch
.
__version__
)
<
version
.
parse
(
'2.0.0'
):
full_state_dict
=
self
.
load_state_dict
(
checkpoint
)
elif
version
.
parse
(
torch
.
__version__
)
>=
version
.
parse
(
'2.0.0'
):
full_state_dict
=
self
.
load_state_dict
(
checkpoint
)
self
.
__set_model_optim_state
(
model
,
StateDictType
.
FULL_STATE_DICT
,
FullStateDictConfig
(
rank0_only
=
True
))
full_state_dict
=
model
.
state_dict
()
else
:
raise
RuntimeError
(
"FSDP is not supported while torch version under 1.12.0."
)
model
.
load_state_dict
(
full_state_dict
)
def
load_unsharded_optimizer
(
self
,
model
:
nn
.
Module
,
optim
:
Optimizer
,
checkpoint
:
str
):
def
save_unsharded_optimizer
(
self
,
optimizer
:
Optimizer
,
checkpoint
:
str
,
gather_dtensor
:
bool
):
"""
Load O
ptimizer
from
checkpoint
with automatic unwrapping
.
Save o
ptimizer
to
checkpoint
but only on master process
.
"""
assert
isinstance
(
optimizer
,
FSDPOptimizerWrapper
)
fsdp_model
=
optimizer
.
unwrap_model
()
full_optimizer_state
=
FSDP
.
full_optim_state_dict
(
fsdp_model
,
optim
=
optimizer
,
rank0_only
=
True
)
utils
.
save_state_dict
(
full_optimizer_state
,
checkpoint_file_path
=
checkpoint
,
use_safetensors
=
False
)
if
version
.
parse
(
torch
.
__version__
)
>=
version
.
parse
(
'1.12.0'
)
and
version
.
parse
(
torch
.
__version__
)
<
version
.
parse
(
'2.0.0'
):
optim_full_state_dict
=
self
.
load_state_dict
(
checkpoint
)
elif
version
.
parse
(
torch
.
__version__
)
>=
version
.
parse
(
'2.0.0'
):
optim_full_state_dict
=
self
.
load_state_dict
(
checkpoint
)
FSDP
.
full_optim_state_dict_to_load
(
optim_full_state_dict
,
model
,
optim
)
else
:
raise
RuntimeError
(
"FSDP is not supported while torch version under 1.12.0."
)
optim
.
load_state_dict
(
optim_full_state_dict
)
def
save_unsharded_model
(
self
,
model
:
nn
.
Module
,
checkpoint
:
str
):
def
save_sharded_model
(
self
,
model
:
nn
.
Module
,
checkpoint
:
str
,
gather_dtensor
:
bool
,
variant
:
Optional
[
str
],
size_per_shard
:
int
,
use_safetensors
:
bool
):
"""
Save model to checkpoint but only on master process.
"""
# the model should be unwrapped in self.load_model via ModelWrapper.unwrap
raise
NotImplementedError
(
"Sharded model checkpoint is not supported yet."
)
def
load_sharded_model
(
self
,
model
:
nn
.
Module
,
checkpoint_index_file
:
Path
,
strict
:
bool
=
False
,
use_safetensors
:
bool
=
False
,
load_sub_module
:
bool
=
True
):
"""
Load model to checkpoint but only on master process.
"""
raise
NotImplementedError
(
"Sharded model checkpoint is not supported yet."
)
if
version
.
parse
(
torch
.
__version__
)
>=
version
.
parse
(
'1.12.0'
)
and
version
.
parse
(
torch
.
__version__
)
<
version
.
parse
(
'2.0.0'
):
cfg
=
FullStateDictConfig
(
offload_to_cpu
=
True
,
rank0_only
=
True
)
with
FSDP
.
state_dict_type
(
model
,
StateDictType
.
FULL_STATE_DICT
,
cfg
):
model_state_dict
=
model
.
state_dict
()
elif
version
.
parse
(
torch
.
__version__
)
>=
version
.
parse
(
'2.0.0'
):
self
.
__set_model_optim_state
(
model
,
StateDictType
.
FULL_STATE_DICT
,
FullStateDictConfig
(
rank0_only
=
True
))
model_state_dict
=
model
.
state_dict
()
else
:
raise
RuntimeError
(
"FSDP is not supported while torch version under 1.12.0."
)
self
.
save_checkpoint
(
model_state_dict
,
checkpoint
)
def
save_unsharded_optimizer
(
self
,
model
:
nn
.
Module
,
optimizer
:
Optimizer
,
checkpoint
:
str
):
def
save_sharded_optimizer
(
self
,
optimizer
:
Optimizer
,
checkpoint
:
str
,
gather_dtensor
:
bool
):
"""
Save optimizer to checkpoint but only on master process.
"""
raise
NotImplementedError
(
"Sharded optimizer checkpoint is not supported yet."
)
if
version
.
parse
(
torch
.
__version__
)
>=
version
.
parse
(
'1.12.0'
)
and
version
.
parse
(
torch
.
__version__
)
<
version
.
parse
(
'2.0.0'
):
optim_state_dict
=
FSDP
.
full_optim_state_dict
(
model
=
model
,
optim
=
optimizer
)
elif
version
.
parse
(
torch
.
__version__
)
>=
version
.
parse
(
'2.0.0'
):
self
.
__set_model_optim_state
(
model
,
StateDictType
.
FULL_STATE_DICT
,
FullOptimStateDictConfig
(
rank0_only
=
True
))
optim_state_dict
=
FSDP
.
optim_state_dict
(
model
,
optimizer
)
else
:
raise
RuntimeError
(
"FSDP is not supported while torch version under 1.12.0."
)
self
.
save_checkpoint
(
optim_state_dict
,
checkpoint
)
def
load_sharded_optimizer
(
self
,
optimizer
:
Optimizer
,
index_file_path
:
str
,
prefix
:
str
,
size_per_shard
:
int
):
"""
Load optimizer to checkpoint but only on master process.
"""
raise
NotImplementedError
(
"Sharded optimizer checkpoint is not supported yet."
)
def
save_lr_scheduler
(
self
,
lr_scheduler
:
LRScheduler
,
checkpoint
:
str
):
"""
Save model to checkpoint but only on master process.
"""
if
self
.
coordinator
.
is_master
():
super
().
save_lr_scheduler
(
lr_scheduler
,
checkpoint
)
class
TorchFSDPModel
(
ModelWrapper
):
...
...
@@ -156,7 +114,17 @@ class TorchFSDPModel(ModelWrapper):
self
.
module
=
FSDP
(
module
,
*
args
,
**
kwargs
)
def
unwrap
(
self
):
return
self
.
module
.
module
return
self
.
module
class
FSDPOptimizerWrapper
(
OptimizerWrapper
):
def
__init__
(
self
,
optimizer
:
Optimizer
,
model
:
nn
.
Module
):
self
.
model
=
model
super
().
__init__
(
optimizer
)
def
unwrap_model
(
self
)
->
nn
.
Module
:
return
self
.
model
class
TorchFSDPPlugin
(
DPPluginBase
):
...
...
@@ -178,8 +146,7 @@ class TorchFSDPPlugin(DPPluginBase):
See https://pytorch.org/docs/stable/fsdp.html for details.
"""
if
version
.
parse
(
torch
.
__version__
)
>=
version
.
parse
(
'1.12.0'
)
and
version
.
parse
(
torch
.
__version__
)
<
version
.
parse
(
'2.0.0'
):
if
version
.
parse
(
torch
.
__version__
)
>=
version
.
parse
(
'1.12.0'
):
def
__init__
(
self
,
...
...
@@ -191,7 +158,6 @@ class TorchFSDPPlugin(DPPluginBase):
mixed_precision
:
Optional
[
MixedPrecision
]
=
None
,
ignored_modules
:
Optional
[
Iterable
[
torch
.
nn
.
Module
]]
=
None
,
param_init_fn
:
Optional
[
Callable
[[
nn
.
Module
],
None
]]
=
None
,
device_id
:
Optional
[
Union
[
int
,
torch
.
device
]]
=
None
,
sync_module_states
:
bool
=
False
,
):
super
().
__init__
()
...
...
@@ -203,42 +169,7 @@ class TorchFSDPPlugin(DPPluginBase):
mixed_precision
=
mixed_precision
,
ignored_modules
=
ignored_modules
,
param_init_fn
=
param_init_fn
,
device_id
=
device_id
,
sync_module_states
=
sync_module_states
)
elif
version
.
parse
(
torch
.
__version__
)
>=
version
.
parse
(
'2.0.0'
):
def
__init__
(
self
,
process_group
:
ProcessGroupType
=
None
,
sharding_strategy
:
Optional
[
ShardingStrategy
]
=
None
,
cpu_offload
:
Optional
[
CPUOffload
]
=
None
,
auto_wrap_policy
:
Optional
[
Union
[
Callable
,
_FSDPPolicy
]]
=
None
,
backward_prefetch
:
Optional
[
BackwardPrefetch
]
=
BackwardPrefetch
.
BACKWARD_PRE
,
mixed_precision
:
Optional
[
MixedPrecision
]
=
None
,
ignored_modules
:
Optional
[
Iterable
[
torch
.
nn
.
Module
]]
=
None
,
param_init_fn
:
Optional
[
Callable
[[
nn
.
Module
],
None
]]
=
None
,
device_id
:
Optional
[
Union
[
int
,
torch
.
device
]]
=
None
,
sync_module_states
:
bool
=
False
,
forward_prefetch
:
bool
=
False
,
limit_all_gathers
:
bool
=
False
,
use_orig_params
:
bool
=
False
,
ignored_parameters
:
Optional
[
Iterable
[
torch
.
nn
.
Parameter
]]
=
None
,
):
super
().
__init__
()
self
.
fsdp_kwargs
=
dict
(
process_group
=
process_group
,
sharding_strategy
=
sharding_strategy
,
cpu_offload
=
cpu_offload
,
auto_wrap_policy
=
auto_wrap_policy
,
backward_prefetch
=
backward_prefetch
,
mixed_precision
=
mixed_precision
,
ignored_modules
=
ignored_modules
,
param_init_fn
=
param_init_fn
,
device_id
=
device_id
,
sync_module_states
=
sync_module_states
,
forward_prefetch
=
forward_prefetch
,
limit_all_gathers
=
limit_all_gathers
,
use_orig_params
=
use_orig_params
,
ignored_parameters
=
ignored_parameters
)
else
:
raise
RuntimeError
(
"FSDP is not supported while torch version under 1.12.0."
)
...
...
@@ -269,14 +200,14 @@ class TorchFSDPPlugin(DPPluginBase):
lr_scheduler
:
LRScheduler
=
None
,
)
->
Tuple
[
Union
[
nn
.
Module
,
OptimizerWrapper
,
LRScheduler
,
DataLoader
]]:
model
=
model
.
cuda
()
# wrap the model with PyTorch FSDP
model
=
TorchFSDPModel
(
model
,
**
self
.
fsdp_kwargs
)
fsdp_model
=
TorchFSDPModel
(
model
,
device_id
=
torch
.
cuda
.
current_device
(),
**
self
.
fsdp_kwargs
)
optimizer
.
__init__
(
fsdp_model
.
parameters
(),
**
optimizer
.
defaults
)
if
not
isinstance
(
optimizer
,
OptimizerWrapper
):
optimizer
=
OptimizerWrapper
(
optimizer
)
if
not
isinstance
(
optimizer
,
FSDP
OptimizerWrapper
):
optimizer
=
FSDP
OptimizerWrapper
(
optimizer
,
fsdp_model
)
return
model
,
optimizer
,
criterion
,
dataloader
,
lr_scheduler
return
fsdp_
model
,
optimizer
,
criterion
,
dataloader
,
lr_scheduler
def
control_checkpoint_io
(
self
)
->
bool
:
return
True
...
...
colossalai/checkpoint_io/general_checkpoint_io.py
View file @
6b305a99
from
pathlib
import
Path
import
gc
import
logging
import
os
from
functools
import
reduce
from
pathlib
import
Path
from
typing
import
Iterator
,
Optional
,
OrderedDict
,
Tuple
import
torch.nn
as
nn
from
torch.optim
import
Optimizer
import
logging
import
os
import
gc
from
typing
import
Optional
,
Iterator
,
OrderedDict
,
Tuple
from
.checkpoint_io_base
import
CheckpointIO
from
.index_file
import
CheckpointIndexFile
from
.utils
import
(
has_index_file
,
load_state_dict
,
save_state_dict
,
get_base_filenames
,
get_shard_filename
,
has_index_file
,
is_safetensors_available
,
shard_checkpoint
,
load_shard_state_dict
,
load_state_dict
,
load_state_dict_into_model
,
get_shard_filename
,
get_base_filenames
)
save_state_dict
,
shard_checkpoint
,
)
__all__
=
[
'GeneralCheckpointIO'
]
...
...
@@ -29,6 +29,7 @@ class GeneralCheckpointIO(CheckpointIO):
"""
Checkpoint IO
"""
def
load_unsharded_model
(
self
,
model
:
nn
.
Module
,
checkpoint
:
str
,
strict
:
bool
):
checkpoint
=
load_state_dict
(
checkpoint
)
model
.
load_state_dict
(
checkpoint
,
strict
=
strict
)
...
...
@@ -69,19 +70,23 @@ class GeneralCheckpointIO(CheckpointIO):
# TODO(FrankLeeeee): handle distributed tensors
save_state_dict
(
optimizer
.
state_dict
(),
checkpoint
,
use_safetensors
=
False
)
def
save_sharded_model
(
self
,
model
:
nn
.
Module
,
checkpoint_path
:
str
,
gather_dtensor
:
bool
=
False
,
variant
:
Optional
[
str
]
=
None
,
max_shard_size
:
int
=
1024
,
use_safetensors
:
bool
=
False
):
"""
def
save_sharded_model
(
self
,
model
:
nn
.
Module
,
checkpoint_path
:
str
,
gather_dtensor
:
bool
=
False
,
variant
:
Optional
[
str
]
=
None
,
max_shard_size
:
int
=
1024
,
use_safetensors
:
bool
=
False
):
"""
implement this method as it can be supported by Huggingface model,
save shard model, save model to multiple files
"""
if
os
.
path
.
isfile
(
checkpoint_path
):
logging
.
error
(
f
"Provided path (
{
checkpoint_path
}
) should be a directory, not a file"
)
return
Path
(
checkpoint_path
).
mkdir
(
parents
=
True
,
exist_ok
=
True
)
# shard checkpoint
state_dict
=
model
.
state_dict
()
state_dict_shard
=
shard_checkpoint
(
state_dict
,
max_shard_size
=
max_shard_size
)
...
...
@@ -95,21 +100,22 @@ class GeneralCheckpointIO(CheckpointIO):
total_size
=
total_size
+
shard_pair
[
1
]
for
key
in
shard
.
keys
():
index_file
.
append_weight_map
(
key
,
shard_file
)
checkpoint_file_path
=
os
.
path
.
join
(
checkpoint_path
,
shard_file
)
save_state_dict
(
shard
,
checkpoint_file_path
,
use_safetensors
)
index_file
.
append_meta_data
(
"total_size"
,
total_size
)
index_file
.
write_index_file
(
save_index_file
)
logging
.
info
(
f
"The model is going to be split to checkpoint shards. "
f
"You can find where each parameters has been saved in the "
f
"index located at
{
save_index_file
}
."
)
def
load_sharded_model
(
self
,
model
:
nn
.
Module
,
checkpoint_index_file
:
Path
,
strict
:
bool
=
False
,
use_safetensors
:
bool
=
False
,
load_sub_module
:
bool
=
True
):
logging
.
info
(
f
"The model is going to be split to checkpoint shards. "
f
"You can find where each parameters has been saved in the "
f
"index located at
{
save_index_file
}
."
)
def
load_sharded_model
(
self
,
model
:
nn
.
Module
,
checkpoint_index_file
:
Path
,
strict
:
bool
=
False
,
use_safetensors
:
bool
=
False
,
load_sub_module
:
bool
=
True
):
"""
load shard model, load model from multiple files
"""
...
...
@@ -119,7 +125,7 @@ class GeneralCheckpointIO(CheckpointIO):
if
use_safetensors
and
not
is_safetensors_available
():
raise
ImportError
(
"`safe_serialization` requires the `safetensors` library: `pip install safetensors`."
)
# read checkpoint index file
ckpt_index_file
=
CheckpointIndexFile
.
from_file
(
checkpoint_index_file
)
checkpoint_files
,
_
=
ckpt_index_file
.
get_checkpoint_fileanames
()
...
...
@@ -134,10 +140,7 @@ class GeneralCheckpointIO(CheckpointIO):
if
strict
:
remain_keys
=
reduce
(
lambda
a
,
b
:
a
&
b
,
map
(
set
,
missing_keys
))
if
len
(
remain_keys
)
>
0
:
error_msgs
=
'Missing key(s) in state_dict: {}. '
.
format
(
', '
.
join
(
'"{}"'
.
format
(
k
)
for
k
in
missing_keys
))
error_msgs
=
'Missing key(s) in state_dict: {}. '
.
format
(
', '
.
join
(
'"{}"'
.
format
(
k
)
for
k
in
missing_keys
))
raise
RuntimeError
(
'Error(s) in loading state_dict for {}:
\n\t
{}'
.
format
(
self
.
__class__
.
__name__
,
"
\n\t
"
.
join
(
error_msgs
)))
self
.
__class__
.
__name__
,
"
\n\t
"
.
join
(
error_msgs
)))
tests/test_booster/test_plugin/test_torch_fsdp_plugin.py
View file @
6b305a99
from
contextlib
import
nullcontext
import
pytest
import
torch
import
torch.distributed
as
dist
from
packaging
import
version
from
torch
import
nn
from
torch.optim
import
SGD
import
colossalai
...
...
@@ -19,6 +15,7 @@ from colossalai.testing import rerun_if_address_is_in_use, spawn
from
tests.kit.model_zoo
import
model_zoo
# test baisc fsdp function
def
run_fn
(
model_fn
,
data_gen_fn
,
output_transform_fn
):
plugin
=
TorchFSDPPlugin
()
booster
=
Booster
(
plugin
=
plugin
)
...
...
tests/test_checkpoint_io/test_torch_fsdp_checkpoint_io.py
0 → 100644
View file @
6b305a99
import
pytest
import
torch
from
packaging
import
version
from
torch
import
nn
from
torch.optim
import
SGD
from
torchvision.models
import
resnet18
from
utils
import
shared_tempdir
import
colossalai
from
colossalai.booster
import
Booster
if
version
.
parse
(
torch
.
__version__
)
>=
version
.
parse
(
'1.12.0'
):
from
torch.distributed.fsdp
import
FullyShardedDataParallel
as
FSDP
from
colossalai.booster.plugin
import
TorchFSDPPlugin
from
colossalai.testing
import
rerun_if_address_is_in_use
,
spawn
,
check_state_dict_equal
def
compare_nested_dict
(
dict1
,
dict2
):
for
key
in
dict1
:
if
key
in
dict2
:
if
type
(
dict1
[
key
])
is
dict
:
assert
type
(
dict2
[
key
])
is
dict
diff
=
compare_nested_dict
(
dict1
[
key
],
dict2
[
key
])
if
not
diff
:
return
diff
elif
type
(
dict1
[
key
])
is
list
:
assert
type
(
dict2
[
key
])
is
list
for
i
,
val
in
enumerate
(
dict1
[
key
]):
if
isinstance
(
val
,
torch
.
Tensor
):
if
not
torch
.
equal
(
dict1
[
key
][
i
],
dict2
[
key
][
i
]):
return
False
elif
val
!=
dict2
[
key
][
i
]:
return
False
elif
type
(
dict1
[
key
])
is
torch
.
Tensor
:
assert
type
(
dict2
[
key
])
is
torch
.
Tensor
if
not
torch
.
equal
(
dict1
[
key
],
dict2
[
key
]):
return
False
else
:
if
dict1
[
key
]
!=
dict2
[
key
]:
return
False
else
:
return
False
return
True
def
check_torch_fsdp_ckpt
():
model
=
resnet18
()
plugin
=
TorchFSDPPlugin
()
booster
=
Booster
(
plugin
=
plugin
)
optimizer
=
SGD
(
model
.
parameters
(),
lr
=
0.001
,
momentum
=
0.9
)
criterion
=
lambda
x
:
x
.
mean
()
fsdp_model
,
optimizer
,
criterion
,
_
,
_
=
booster
.
boost
(
model
,
optimizer
,
criterion
)
inputs
=
torch
.
randn
(
4
,
3
,
224
,
224
)
outputs
=
None
def
run_model
():
nonlocal
outputs
outputs
=
fsdp_model
(
inputs
)
optimizer
.
zero_grad
()
criterion
(
outputs
).
backward
()
optimizer
.
step
()
with
shared_tempdir
()
as
tempdir
:
model_ckpt_path
=
f
"
{
tempdir
}
/model"
optim_ckpt_path
=
f
"
{
tempdir
}
/optimizer"
run_model
()
booster
.
save_model
(
fsdp_model
,
model_ckpt_path
,
shard
=
False
)
booster
.
save_optimizer
(
optimizer
,
optim_ckpt_path
,
shard
=
False
)
full_msd
=
fsdp_model
.
state_dict
()
#full_osd = FSDP.full_optim_state_dict(fsdp_model, optimizer)
sharded_osd
=
optimizer
.
state_dict
()
import
copy
sharded_osd
=
copy
.
deepcopy
(
sharded_osd
)
run_model
()
full_msd_updated
=
fsdp_model
.
state_dict
()
#full_osd_updated = FSDP.full_optim_state_dict(fsdp_model, optimizer, rank0_only=True)
sharded_osd_updated
=
optimizer
.
state_dict
()
assert
not
compare_nested_dict
(
sharded_osd
,
sharded_osd_updated
)
assert
not
compare_nested_dict
(
full_msd_updated
,
full_msd
)
outputs_first
=
fsdp_model
(
inputs
)
assert
criterion
(
outputs_first
)
!=
criterion
(
outputs
)
booster
.
load_model
(
fsdp_model
,
model_ckpt_path
)
booster
.
load_optimizer
(
optimizer
,
optim_ckpt_path
)
full_msd_restore
=
fsdp_model
.
state_dict
()
#full_osd_restore = FSDP.full_optim_state_dict(fsdp_model, optimizer, rank0_only=True)
sharded_osd_restore
=
optimizer
.
state_dict
()
assert
compare_nested_dict
(
sharded_osd
,
sharded_osd_restore
)
assert
compare_nested_dict
(
full_msd_restore
,
full_msd
)
outputs_sec
=
fsdp_model
(
inputs
)
assert
criterion
(
outputs_sec
)
==
criterion
(
outputs
)
def
run_dist
(
rank
,
world_size
,
port
):
# init dist env
colossalai
.
launch
(
config
=
dict
(),
rank
=
rank
,
world_size
=
world_size
,
port
=
port
,
host
=
'localhost'
)
check_torch_fsdp_ckpt
()
@
pytest
.
mark
.
skipif
(
version
.
parse
(
torch
.
__version__
)
<
version
.
parse
(
'1.12.0'
),
reason
=
"requires torch1.12 or higher"
)
@
rerun_if_address_is_in_use
()
def
test_torch_fsdp_ckpt
():
spawn
(
run_dist
,
2
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment