OpenDAS / ColossalAI · Commits
"git@developer.sourcefind.cn:OpenDAS/colossalai.git" did not exist on "7080a8edb08400d97ba4c31458f532e4ceeacf4b"
Commit 6b305a99 (unverified), authored May 23, 2023 by wukong1992; committed via GitHub on May 23, 2023.
[booster] torch fsdp fix ckpt (#3788)
Parent: 9265f2d4
Showing 5 changed files with 223 additions and 179 deletions (+223 −179).
colossalai/booster/booster.py (+1 −1)
colossalai/booster/plugin/torch_fsdp_plugin.py (+69 −138)
colossalai/checkpoint_io/general_checkpoint_io.py (+39 −36)
tests/test_booster/test_plugin/test_torch_fsdp_plugin.py (+1 −4)
tests/test_checkpoint_io/test_torch_fsdp_checkpoint_io.py (+113 −0)
colossalai/booster/booster.py

@@ -196,7 +196,7 @@ class Booster:
                 If true, the checkpoint will be a folder. Otherwise, it will be a single file. Defaults to False.
             size_per_shard (int, optional): Maximum size of checkpoint shard file in MB. This is useful only when ``shard=True``. Defaults to 1024.
         """
-        self.checkpoint_io.save_model(model, checkpoint, prefix, shard, size_per_shard)
+        self.checkpoint_io.save_model(model, checkpoint=checkpoint, shard=shard, size_per_shard=size_per_shard)

     def load_optimizer(self, optimizer: Optimizer, checkpoint: str):
         """Load optimizer from checkpoint.
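The one-line fix above forwards checkpoint, shard and size_per_shard to CheckpointIO.save_model by keyword instead of positionally (the positional prefix argument is dropped), so each value lands on the parameter it was meant for. A minimal usage sketch of the call path this exercises, assuming a distributed launch via torchrun and the Booster/TorchFSDPPlugin API as of this commit; the toy model and checkpoint paths are illustrative, not part of the diff:

    # Hedged sketch: exercise Booster.save_model/load_model with the FSDP plugin.
    # Assumes this script is started with `torchrun --nproc_per_node=2 ...` on a CUDA host.
    import torch.nn as nn
    from torch.optim import SGD

    import colossalai
    from colossalai.booster import Booster
    from colossalai.booster.plugin import TorchFSDPPlugin

    colossalai.launch_from_torch(config={})
    booster = Booster(plugin=TorchFSDPPlugin())

    model = nn.Linear(16, 4)
    optimizer = SGD(model.parameters(), lr=1e-3)
    model, optimizer, *_ = booster.boost(model, optimizer)

    # Unsharded (single-file) checkpoint; shard/size_per_shard are now passed by keyword
    # all the way down to the plugin's checkpoint IO.
    booster.save_model(model, "fsdp_model_ckpt", shard=False)
    booster.load_model(model, "fsdp_model_ckpt")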
colossalai/booster/plugin/torch_fsdp_plugin.py

+from pathlib import Path
 from typing import Callable, Iterable, Iterator, List, Optional, Tuple, Union

 import torch

@@ -5,30 +6,18 @@ import torch.nn as nn
 from packaging import version
 from torch.distributed import ProcessGroup

-if version.parse(torch.__version__) >= version.parse('1.12.0') and version.parse(torch.__version__) < version.parse('2.0.0'):
+if version.parse(torch.__version__) >= version.parse('1.12.0'):
     from torch.distributed.fsdp import FullStateDictConfig
     from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
     from torch.distributed.fsdp import StateDictType
     from torch.distributed.fsdp.fully_sharded_data_parallel import (
         BackwardPrefetch,
         CPUOffload,
         MixedPrecision,
         ShardingStrategy,
     )
-elif version.parse(torch.__version__) >= version.parse('2.0.0'):
-    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
-    from torch.distributed.fsdp._init_utils import ProcessGroupType
-    from torch.distributed.fsdp.api import (
-        BackwardPrefetch,
-        CPUOffload,
-        FullOptimStateDictConfig,
-        FullStateDictConfig,
-        MixedPrecision,
-        ShardingStrategy,
-        StateDictType,
-    )
-    from torch.distributed.fsdp.wrap import _FSDPPolicy
 else:
     raise RuntimeError("FSDP is not supported while torch version under 1.12.0.")

@@ -36,7 +25,7 @@ from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
 from torch.utils.data import DataLoader

-from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO
+from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO, utils
 from colossalai.cluster import DistCoordinator
 from colossalai.interface import ModelWrapper, OptimizerWrapper
@@ -51,102 +40,71 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
         super().__init__()
         self.coordinator = DistCoordinator()

-    def __set_model_optim_state(
-        self,
-        model,
-        state_dict_type,
-        state_dict_config,
-        optim_state_dict_config,
-    ):
-        return FSDP.set_state_dict_type(model, state_dict_type, state_dict_config, optim_state_dict_config)
-
-    def load_sharded_model(self, model: nn.Module, checkpoint: str):
-        # TODO(jishaomin): implement this method as it can be supported by Huggingface model
-        raise NotImplementedError("Torch FSDP sharded model checkpoint is not supported yet.")
-
-    def load_sharded_optimizer(self, model: nn.Module, optimizer: Optimizer, checkpoint: str):
-        # TODO(jishaomin): implement this method as it can be supported by Huggingface model
-        raise NotImplementedError("Torch FSDP sharded model checkpoint is not supported yet.")
-
-    def save_sharded_model(self, model: nn.Module, checkpoint: str):
-        # TODO(jishaomin): implement this method as it can be supported by Huggingface model
-        raise NotImplementedError("Torch FSDP sharded model checkpoint is not supported yet.")
-
-    def save_sharded_optimizer(self, model: nn.Module, optimizer: Optimizer, checkpoint: str):
-        # TODO(jishaomin): implement this method as it can be supported by Huggingface model
-        raise NotImplementedError("Torch FSDP sharded model checkpoint is not supported yet.")
-
-    def load_unsharded_model(self, model: nn.Module, checkpoint: str):
-        """
-        Load model from checkpoint with automatic unwrapping.
-        """
-        # the model should be unwrapped in self.load_model via ModelWrapper.unwrap
-        if version.parse(torch.__version__) >= version.parse('1.12.0') and version.parse(
-                torch.__version__) < version.parse('2.0.0'):
-            full_state_dict = self.load_state_dict(checkpoint)
-        elif version.parse(torch.__version__) >= version.parse('2.0.0'):
-            full_state_dict = self.load_state_dict(checkpoint)
-            self.__set_model_optim_state(model, StateDictType.FULL_STATE_DICT, FullStateDictConfig(rank0_only=True))
-            full_state_dict = model.state_dict()
-        else:
-            raise RuntimeError("FSDP is not supported while torch version under 1.12.0.")
-        model.load_state_dict(full_state_dict)
-
-    def load_unsharded_optimizer(self, model: nn.Module, optim: Optimizer, checkpoint: str):
-        """
-        Load Optimizer from checkpoint with automatic unwrapping.
-        """
-        if version.parse(torch.__version__) >= version.parse('1.12.0') and version.parse(
-                torch.__version__) < version.parse('2.0.0'):
-            optim_full_state_dict = self.load_state_dict(checkpoint)
-        elif version.parse(torch.__version__) >= version.parse('2.0.0'):
-            optim_full_state_dict = self.load_state_dict(checkpoint)
-            FSDP.full_optim_state_dict_to_load(optim_full_state_dict, model, optim)
-        else:
-            raise RuntimeError("FSDP is not supported while torch version under 1.12.0.")
-        optim.load_state_dict(optim_full_state_dict)
-
-    def save_unsharded_model(self, model: nn.Module, checkpoint: str):
-        """
-        Save model to checkpoint but only on master process.
-        """
-        # the model should be unwrapped in self.load_model via ModelWrapper.unwrap
-        if version.parse(torch.__version__) >= version.parse('1.12.0') and version.parse(
-                torch.__version__) < version.parse('2.0.0'):
-            cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
-            with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, cfg):
-                model_state_dict = model.state_dict()
-        elif version.parse(torch.__version__) >= version.parse('2.0.0'):
-            self.__set_model_optim_state(model, StateDictType.FULL_STATE_DICT, FullStateDictConfig(rank0_only=True))
-            model_state_dict = model.state_dict()
-        else:
-            raise RuntimeError("FSDP is not supported while torch version under 1.12.0.")
-        self.save_checkpoint(model_state_dict, checkpoint)
-
-    def save_unsharded_optimizer(self, model: nn.Module, optimizer: Optimizer, checkpoint: str):
-        """
-        Save optimizer to checkpoint but only on master process.
-        """
-        if version.parse(torch.__version__) >= version.parse('1.12.0') and version.parse(
-                torch.__version__) < version.parse('2.0.0'):
-            optim_state_dict = FSDP.full_optim_state_dict(model=model, optim=optimizer)
-        elif version.parse(torch.__version__) >= version.parse('2.0.0'):
-            self.__set_model_optim_state(model, StateDictType.FULL_STATE_DICT,
-                                         FullOptimStateDictConfig(rank0_only=True))
-            optim_state_dict = FSDP.optim_state_dict(model, optimizer)
-        else:
-            raise RuntimeError("FSDP is not supported while torch version under 1.12.0.")
-        self.save_checkpoint(optim_state_dict, checkpoint)
+    def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool):
+        checkpoint = utils.load_state_dict(checkpoint)
+        model.load_state_dict(checkpoint)
+
+    def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path):
+        checkpoint = utils.load_state_dict(checkpoint)
+        fsdp_model = optimizer.unwrap_model()
+        sharded_osd = FSDP.scatter_full_optim_state_dict(checkpoint, fsdp_model)
+        optimizer.load_state_dict(sharded_osd)
+
+    def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
+        """
+        Save model to checkpoint but only on master process.
+        """
+        # the model should be unwrapped in self.load_model via ModelWrapper.unwrap
+        cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
+        with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, cfg):
+            full_model_state = model.state_dict()
+        utils.save_state_dict(full_model_state, checkpoint_file_path=checkpoint, use_safetensors=use_safetensors)
+
+    def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool):
+        """
+        Save optimizer to checkpoint but only on master process.
+        """
+        assert isinstance(optimizer, FSDPOptimizerWrapper)
+        fsdp_model = optimizer.unwrap_model()
+        full_optimizer_state = FSDP.full_optim_state_dict(fsdp_model, optim=optimizer, rank0_only=True)
+        utils.save_state_dict(full_optimizer_state, checkpoint_file_path=checkpoint, use_safetensors=False)
+
+    def save_sharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, variant: Optional[str],
+                           size_per_shard: int, use_safetensors: bool):
+        """
+        Save model to checkpoint but only on master process.
+        """
+        raise NotImplementedError("Sharded model checkpoint is not supported yet.")
+
+    def load_sharded_model(self,
+                           model: nn.Module,
+                           checkpoint_index_file: Path,
+                           strict: bool = False,
+                           use_safetensors: bool = False,
+                           load_sub_module: bool = True):
+        """
+        Load model to checkpoint but only on master process.
+        """
+        raise NotImplementedError("Sharded model checkpoint is not supported yet.")
+
+    def save_sharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool):
+        """
+        Save optimizer to checkpoint but only on master process.
+        """
+        raise NotImplementedError("Sharded optimizer checkpoint is not supported yet.")
+
+    def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, prefix: str, size_per_shard: int):
+        """
+        Load optimizer to checkpoint but only on master process.
+        """
+        raise NotImplementedError("Sharded optimizer checkpoint is not supported yet.")
+
+    def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
+        """
+        Save model to checkpoint but only on master process.
+        """
+        if self.coordinator.is_master():
+            super().save_lr_scheduler(lr_scheduler, checkpoint)


 class TorchFSDPModel(ModelWrapper):
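The rewritten save_unsharded_model and load_unsharded_optimizer above lean on two stock PyTorch FSDP mechanisms: gathering a full, CPU-offloaded state dict on rank 0 via FSDP.state_dict_type with FullStateDictConfig, and re-sharding a consolidated optimizer state with FSDP.scatter_full_optim_state_dict. A standalone sketch of that underlying pattern, assuming torch >= 1.12 and an already-initialised process group (this illustrates the torch API the new code builds on; it is not ColossalAI code):

    import torch
    import torch.distributed as dist
    from torch.distributed.fsdp import FullStateDictConfig, StateDictType
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


    def save_full_model(fsdp_model: FSDP, path: str) -> None:
        # Gather the full (unsharded) state dict, offloaded to CPU, on rank 0 only.
        cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
        with FSDP.state_dict_type(fsdp_model, StateDictType.FULL_STATE_DICT, cfg):
            full_state = fsdp_model.state_dict()
        if dist.get_rank() == 0:
            torch.save(full_state, path)


    def load_full_optimizer(fsdp_model: FSDP, optimizer: torch.optim.Optimizer, path: str) -> None:
        # Rank 0 reads the consolidated optimizer state; scatter re-shards it to every rank.
        full_osd = torch.load(path) if dist.get_rank() == 0 else None
        sharded_osd = FSDP.scatter_full_optim_state_dict(full_osd, fsdp_model)
        optimizer.load_state_dict(sharded_osd)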
@@ -156,7 +114,17 @@ class TorchFSDPModel(ModelWrapper):
         self.module = FSDP(module, *args, **kwargs)

     def unwrap(self):
-        return self.module.module
+        return self.module
+
+
+class FSDPOptimizerWrapper(OptimizerWrapper):
+
+    def __init__(self, optimizer: Optimizer, model: nn.Module):
+        self.model = model
+        super().__init__(optimizer)
+
+    def unwrap_model(self) -> nn.Module:
+        return self.model


 class TorchFSDPPlugin(DPPluginBase):
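The FSDPOptimizerWrapper added here is what makes the optimizer checkpoint routines above workable: FSDP.full_optim_state_dict and FSDP.scatter_full_optim_state_dict both need the FSDP-wrapped module that owns the flat parameters, while the checkpoint IO only receives the optimizer, so the wrapper keeps a reference to the wrapped model and hands it back through unwrap_model(). Relatedly, TorchFSDPModel.unwrap() now returns the FSDP module itself rather than the inner module, so checkpointing always operates on the wrapper that the torch FSDP helpers expect.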
@@ -178,8 +146,7 @@ class TorchFSDPPlugin(DPPluginBase):
     See https://pytorch.org/docs/stable/fsdp.html for details.
     """

-    if version.parse(torch.__version__) >= version.parse('1.12.0') and version.parse(
-            torch.__version__) < version.parse('2.0.0'):
+    if version.parse(torch.__version__) >= version.parse('1.12.0'):

         def __init__(
             self,

@@ -191,7 +158,6 @@ class TorchFSDPPlugin(DPPluginBase):
             mixed_precision: Optional[MixedPrecision] = None,
             ignored_modules: Optional[Iterable[torch.nn.Module]] = None,
             param_init_fn: Optional[Callable[[nn.Module], None]] = None,
-            device_id: Optional[Union[int, torch.device]] = None,
             sync_module_states: bool = False,
         ):
             super().__init__()

@@ -203,42 +169,7 @@ class TorchFSDPPlugin(DPPluginBase):
                                          mixed_precision=mixed_precision,
                                          ignored_modules=ignored_modules,
                                          param_init_fn=param_init_fn,
-                                         device_id=device_id,
                                          sync_module_states=sync_module_states)
-    elif version.parse(torch.__version__) >= version.parse('2.0.0'):
-
-        def __init__(
-            self,
-            process_group: ProcessGroupType = None,
-            sharding_strategy: Optional[ShardingStrategy] = None,
-            cpu_offload: Optional[CPUOffload] = None,
-            auto_wrap_policy: Optional[Union[Callable, _FSDPPolicy]] = None,
-            backward_prefetch: Optional[BackwardPrefetch] = BackwardPrefetch.BACKWARD_PRE,
-            mixed_precision: Optional[MixedPrecision] = None,
-            ignored_modules: Optional[Iterable[torch.nn.Module]] = None,
-            param_init_fn: Optional[Callable[[nn.Module], None]] = None,
-            device_id: Optional[Union[int, torch.device]] = None,
-            sync_module_states: bool = False,
-            forward_prefetch: bool = False,
-            limit_all_gathers: bool = False,
-            use_orig_params: bool = False,
-            ignored_parameters: Optional[Iterable[torch.nn.Parameter]] = None,
-        ):
-            super().__init__()
-            self.fsdp_kwargs = dict(process_group=process_group,
-                                    sharding_strategy=sharding_strategy,
-                                    cpu_offload=cpu_offload,
-                                    auto_wrap_policy=auto_wrap_policy,
-                                    backward_prefetch=backward_prefetch,
-                                    mixed_precision=mixed_precision,
-                                    ignored_modules=ignored_modules,
-                                    param_init_fn=param_init_fn,
-                                    device_id=device_id,
-                                    sync_module_states=sync_module_states,
-                                    forward_prefetch=forward_prefetch,
-                                    limit_all_gathers=limit_all_gathers,
-                                    use_orig_params=use_orig_params,
-                                    ignored_parameters=ignored_parameters)
     else:
         raise RuntimeError("FSDP is not supported while torch version under 1.12.0.")

@@ -269,14 +200,14 @@ class TorchFSDPPlugin(DPPluginBase):
         lr_scheduler: LRScheduler = None,
     ) -> Tuple[Union[nn.Module, OptimizerWrapper, LRScheduler, DataLoader]]:

-        model = model.cuda()
         # wrap the model with PyTorch FSDP
-        model = TorchFSDPModel(model, **self.fsdp_kwargs)
+        fsdp_model = TorchFSDPModel(model, device_id=torch.cuda.current_device(), **self.fsdp_kwargs)
+        optimizer.__init__(fsdp_model.parameters(), **optimizer.defaults)

-        if not isinstance(optimizer, OptimizerWrapper):
-            optimizer = OptimizerWrapper(optimizer)
+        if not isinstance(optimizer, FSDPOptimizerWrapper):
+            optimizer = FSDPOptimizerWrapper(optimizer, fsdp_model)

-        return model, optimizer, criterion, dataloader, lr_scheduler
+        return fsdp_model, optimizer, criterion, dataloader, lr_scheduler

     def control_checkpoint_io(self) -> bool:
         return True
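One subtle point in the new configure() above: the user builds the optimizer against the raw module, but wrapping in FSDP replaces those parameters with flattened shards, so the optimizer is re-initialised in place on fsdp_model.parameters() while reusing the hyper-parameters recorded in optimizer.defaults. A small self-contained sketch of that re-binding trick, with a plain container standing in for the FSDP wrapper (illustrative only, not from the commit):

    import torch.nn as nn
    from torch.optim import SGD

    model = nn.Linear(8, 8)
    optimizer = SGD(model.parameters(), lr=0.01, momentum=0.9)

    wrapped = nn.Sequential(model)    # stand-in for TorchFSDPModel / FSDP wrapping
    # Re-run the constructor on the wrapper's parameters, keeping lr, momentum, etc.
    optimizer.__init__(wrapped.parameters(), **optimizer.defaults)
    assert optimizer.defaults["momentum"] == 0.9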
colossalai/checkpoint_io/general_checkpoint_io.py

-from pathlib import Path
+import gc
+import logging
+import os
 from functools import reduce
+from pathlib import Path
+from typing import Iterator, Optional, OrderedDict, Tuple

 import torch.nn as nn
 from torch.optim import Optimizer
-import logging
-import os
-import gc
-from typing import Optional, Iterator, OrderedDict, Tuple

 from .checkpoint_io_base import CheckpointIO
 from .index_file import CheckpointIndexFile
 from .utils import (
-    has_index_file,
-    load_state_dict,
-    save_state_dict,
+    get_base_filenames,
+    get_shard_filename,
+    has_index_file,
     is_safetensors_available,
-    shard_checkpoint,
     load_shard_state_dict,
+    load_state_dict,
     load_state_dict_into_model,
-    get_shard_filename,
-    get_base_filenames
+    save_state_dict,
+    shard_checkpoint,
 )

 __all__ = ['GeneralCheckpointIO']
@@ -29,6 +29,7 @@ class GeneralCheckpointIO(CheckpointIO):
     """
     Checkpoint IO
     """

     def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool):
         checkpoint = load_state_dict(checkpoint)
         model.load_state_dict(checkpoint, strict=strict)

@@ -69,19 +70,23 @@ class GeneralCheckpointIO(CheckpointIO):
         # TODO(FrankLeeeee): handle distributed tensors
         save_state_dict(optimizer.state_dict(), checkpoint, use_safetensors=False)

-    def save_sharded_model(self,
-                           model: nn.Module, checkpoint_path: str, gather_dtensor: bool = False,
-                           variant: Optional[str] = None, max_shard_size: int = 1024,
-                           use_safetensors: bool = False):
+    def save_sharded_model(self,
+                           model: nn.Module,
+                           checkpoint_path: str,
+                           gather_dtensor: bool = False,
+                           variant: Optional[str] = None,
+                           max_shard_size: int = 1024,
+                           use_safetensors: bool = False):
         """
         implement this method as it can be supported by Huggingface model,
         save shard model, save model to multiple files
         """
         if os.path.isfile(checkpoint_path):
             logging.error(f"Provided path ({checkpoint_path}) should be a directory, not a file")
             return

         Path(checkpoint_path).mkdir(parents=True, exist_ok=True)

         # shard checkpoint
         state_dict = model.state_dict()
         state_dict_shard = shard_checkpoint(state_dict, max_shard_size=max_shard_size)
@@ -95,21 +100,22 @@ class GeneralCheckpointIO(CheckpointIO):
             total_size = total_size + shard_pair[1]
             for key in shard.keys():
                 index_file.append_weight_map(key, shard_file)

             checkpoint_file_path = os.path.join(checkpoint_path, shard_file)
             save_state_dict(shard, checkpoint_file_path, use_safetensors)

         index_file.append_meta_data("total_size", total_size)
         index_file.write_index_file(save_index_file)
         logging.info(
             f"The model is going to be split to checkpoint shards. "
             f"You can find where each parameters has been saved in the "
             f"index located at {save_index_file}.")

     def load_sharded_model(self, model: nn.Module, checkpoint_index_file: Path,
                            strict: bool = False,
                            use_safetensors: bool = False, load_sub_module: bool = True):
         """
         load shard model, load model from multiple files
         """
@@ -119,7 +125,7 @@ class GeneralCheckpointIO(CheckpointIO):
         if use_safetensors and not is_safetensors_available():
             raise ImportError("`safe_serialization` requires the `safetensors` library: `pip install safetensors`.")

         # read checkpoint index file
         ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file)
         checkpoint_files, _ = ckpt_index_file.get_checkpoint_fileanames()

@@ -134,10 +140,7 @@ class GeneralCheckpointIO(CheckpointIO):
         if strict:
             remain_keys = reduce(lambda a, b: a & b, map(set, missing_keys))
             if len(remain_keys) > 0:
                 error_msgs = 'Missing key(s) in state_dict: {}. '.format(', '.join(
                     '"{}"'.format(k) for k in missing_keys))
                 raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
                     self.__class__.__name__, "\n\t".join(error_msgs)))
tests/test_booster/test_plugin/test_torch_fsdp_plugin.py

-from contextlib import nullcontext
-
 import pytest
 import torch
-import torch.distributed as dist
 from packaging import version
-from torch import nn
 from torch.optim import SGD

 import colossalai

@@ -19,6 +15,7 @@ from colossalai.testing import rerun_if_address_is_in_use, spawn
 from tests.kit.model_zoo import model_zoo


+# test baisc fsdp function
 def run_fn(model_fn, data_gen_fn, output_transform_fn):
     plugin = TorchFSDPPlugin()
     booster = Booster(plugin=plugin)
tests/test_checkpoint_io/test_torch_fsdp_checkpoint_io.py (new file, mode 100644)

import pytest
import torch
from packaging import version
from torch import nn
from torch.optim import SGD
from torchvision.models import resnet18
from utils import shared_tempdir

import colossalai
from colossalai.booster import Booster

if version.parse(torch.__version__) >= version.parse('1.12.0'):
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from colossalai.booster.plugin import TorchFSDPPlugin
from colossalai.testing import rerun_if_address_is_in_use, spawn, check_state_dict_equal


def compare_nested_dict(dict1, dict2):
    for key in dict1:
        if key in dict2:
            if type(dict1[key]) is dict:
                assert type(dict2[key]) is dict
                diff = compare_nested_dict(dict1[key], dict2[key])
                if not diff:
                    return diff
            elif type(dict1[key]) is list:
                assert type(dict2[key]) is list
                for i, val in enumerate(dict1[key]):
                    if isinstance(val, torch.Tensor):
                        if not torch.equal(dict1[key][i], dict2[key][i]):
                            return False
                    elif val != dict2[key][i]:
                        return False
            elif type(dict1[key]) is torch.Tensor:
                assert type(dict2[key]) is torch.Tensor
                if not torch.equal(dict1[key], dict2[key]):
                    return False
            else:
                if dict1[key] != dict2[key]:
                    return False
        else:
            return False
    return True


def check_torch_fsdp_ckpt():
    model = resnet18()
    plugin = TorchFSDPPlugin()
    booster = Booster(plugin=plugin)
    optimizer = SGD(model.parameters(), lr=0.001, momentum=0.9)
    criterion = lambda x: x.mean()
    fsdp_model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)

    inputs = torch.randn(4, 3, 224, 224)
    outputs = None

    def run_model():
        nonlocal outputs
        outputs = fsdp_model(inputs)
        optimizer.zero_grad()
        criterion(outputs).backward()
        optimizer.step()

    with shared_tempdir() as tempdir:
        model_ckpt_path = f"{tempdir}/model"
        optim_ckpt_path = f"{tempdir}/optimizer"

        run_model()

        booster.save_model(fsdp_model, model_ckpt_path, shard=False)
        booster.save_optimizer(optimizer, optim_ckpt_path, shard=False)

        full_msd = fsdp_model.state_dict()
        #full_osd = FSDP.full_optim_state_dict(fsdp_model, optimizer)
        sharded_osd = optimizer.state_dict()
        import copy
        sharded_osd = copy.deepcopy(sharded_osd)

        run_model()

        full_msd_updated = fsdp_model.state_dict()
        #full_osd_updated = FSDP.full_optim_state_dict(fsdp_model, optimizer, rank0_only=True)
        sharded_osd_updated = optimizer.state_dict()

        assert not compare_nested_dict(sharded_osd, sharded_osd_updated)
        assert not compare_nested_dict(full_msd_updated, full_msd)

        outputs_first = fsdp_model(inputs)
        assert criterion(outputs_first) != criterion(outputs)

        booster.load_model(fsdp_model, model_ckpt_path)
        booster.load_optimizer(optimizer, optim_ckpt_path)

        full_msd_restore = fsdp_model.state_dict()
        #full_osd_restore = FSDP.full_optim_state_dict(fsdp_model, optimizer, rank0_only=True)
        sharded_osd_restore = optimizer.state_dict()

        assert compare_nested_dict(sharded_osd, sharded_osd_restore)
        assert compare_nested_dict(full_msd_restore, full_msd)

        outputs_sec = fsdp_model(inputs)
        assert criterion(outputs_sec) == criterion(outputs)


def run_dist(rank, world_size, port):
    # init dist env
    colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')
    check_torch_fsdp_ckpt()


@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason="requires torch1.12 or higher")
@rerun_if_address_is_in_use()
def test_torch_fsdp_ckpt():
    spawn(run_dist, 2)
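Taken together, the new test is an end-to-end round trip for this fix: it boosts a torchvision ResNet-18 with TorchFSDPPlugin on two spawned ranks, snapshots the state with booster.save_model and booster.save_optimizer, runs one more training step so the live state diverges from the snapshot, then restores with booster.load_model and booster.load_optimizer and asserts that the model state dict, the optimizer state dict, and the loss on a fixed input batch all match the saved values again.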