OpenDAS / ColossalAI · Commits · bf34c6fe

Unverified commit bf34c6fe, authored Feb 27, 2024 by QinLuo, committed via GitHub on Feb 27, 2024
Parent: d882d18c

[fsdp] impl save/load shard model/optimizer (#5357)

Showing 2 changed files with 179 additions and 12 deletions:
- colossalai/booster/plugin/torch_fsdp_plugin.py (+141, -12)
- tests/test_checkpoint_io/test_torch_fsdp_checkpoint_io.py (+38, -0)
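What this commit adds, end to end: once a model and optimizer are boosted with TorchFSDPPlugin, booster.save_model / booster.save_optimizer with shard=True write size-limited state shards plus an index file, and booster.load_model / booster.load_optimizer reassemble them, which is exactly what the new test exercises. The sketch below is illustrative only and not part of the diff; it assumes a torchrun-launched GPU environment, and the launcher call, paths, and toy model are placeholders.

```python
# Minimal usage sketch (assumptions: torchrun-launched env, launcher signature
# may differ across ColossalAI versions, paths/model are placeholders).
import torch
import torch.nn as nn

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import TorchFSDPPlugin

colossalai.launch_from_torch(config={})  # launcher call depends on your version

model = nn.Sequential(nn.Linear(8, 32), nn.ReLU(), nn.Linear(32, 8)).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = lambda out: out.mean()

booster = Booster(plugin=TorchFSDPPlugin())
model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)

# one training step so the optimizer has state worth checkpointing
out = model(torch.randn(4, 8, device="cuda"))
booster.backward(criterion(out), optimizer)
optimizer.step()

# sharded save: writes weight/state shards plus an index file per directory
booster.save_model(model, "./ckpt/model", shard=True, size_per_shard=1024)
booster.save_optimizer(optimizer, "./ckpt/optimizer", shard=True, size_per_shard=1024)

# later: restore into a boosted model/optimizer of the same architecture
booster.load_model(model, "./ckpt/model")
booster.load_optimizer(optimizer, "./ckpt/optimizer")
```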
colossalai/booster/plugin/torch_fsdp_plugin.py (view file @ bf34c6fe)
+import logging
+import os
 import warnings
 from pathlib import Path
 from typing import Callable, Iterable, Iterator, List, Optional, Tuple
...
@@ -25,7 +27,7 @@ from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
 from torch.utils.data import DataLoader

-from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO, utils
+from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO, utils, CheckpointIndexFile
 from colossalai.cluster import DistCoordinator
 from colossalai.interface import ModelWrapper, OptimizerWrapper
...
@@ -74,17 +76,54 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
     def save_sharded_model(
         self,
-        model: nn.Module,
-        checkpoint: str,
-        gather_dtensor: bool,
-        prefix: Optional[str],
-        size_per_shard: int,
-        use_safetensors: bool,
+        model: ModelWrapper,
+        checkpoint_path: str,
+        gather_dtensor: bool = True,
+        prefix: Optional[str] = None,
+        size_per_shard: int = 1024,
+        use_safetensors: bool = False,
     ):
         """
         Save model to checkpoint but only on master process.
         """
-        raise NotImplementedError("Sharded model checkpoint is not supported yet.")
+        assert isinstance(model, TorchFSDPModel), "Please boost the model before saving!"
+        if os.path.isfile(checkpoint_path):
+            logging.error(f"Provided path ({checkpoint_path}) should be a directory, not a file")
+            return
+
+        Path(checkpoint_path).mkdir(parents=True, exist_ok=True)
+        with FSDP.state_dict_type(
+            model.unwrap(), StateDictType.FULL_STATE_DICT, FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
+        ):
+            state_dict = model.unwrap().state_dict()
+
+        state_dict_shard = utils.shard_model_checkpoint(state_dict, max_shard_size=size_per_shard)
+
+        weights_name, save_index_file = utils.get_model_base_filenames(prefix, use_safetensors)
+        index_file = CheckpointIndexFile(checkpoint_path)
+
+        # In general cases, is_master is set to True to get the right behavior.
+        total_size = utils.save_state_dict_shards(
+            sharded_state_dict=state_dict_shard,
+            checkpoint=checkpoint_path,
+            index_file=index_file,
+            base_filename=weights_name,
+            is_master=self.coordinator.is_master(),
+            use_safetensors=use_safetensors,
+        )
+
+        # only save the index file on the master rank
+        if self.coordinator.is_master():
+            index_file.append_meta_data("total_size", total_size)
+            index_file.write_index_file(save_index_file)
+            utils.save_config_file(model.unwrap(), checkpoint_path)
+            logging.info(
+                f"The model is split into checkpoint shards. "
+                f"You can find where each parameters has been saved in the "
+                f"index located at {save_index_file}."
+            )

     def load_sharded_model(
         self,
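In the new save_sharded_model, the full (unsharded) state dict is gathered on rank 0 under FSDP.state_dict_type with FULL_STATE_DICT and FullStateDictConfig(offload_to_cpu=True, rank0_only=True), then cut into size-bounded pieces by utils.shard_model_checkpoint and written out together with an index file that records total_size. The snippet below is a rough stand-in for that size-based splitting step, not the library's implementation, just to show the idea:

```python
# Illustrative stand-in (assumption, not ColossalAI's utils.shard_model_checkpoint):
# walk the full state dict in order and start a new shard whenever adding the
# next tensor would exceed max_shard_size_mb.
from typing import Dict, Iterator, Tuple

import torch


def shard_by_size(
    state_dict: Dict[str, torch.Tensor], max_shard_size_mb: int
) -> Iterator[Tuple[Dict[str, torch.Tensor], int]]:
    limit = max_shard_size_mb * 1024 * 1024
    shard: Dict[str, torch.Tensor] = {}
    shard_bytes = 0
    for name, tensor in state_dict.items():
        nbytes = tensor.numel() * tensor.element_size()
        if shard and shard_bytes + nbytes > limit:
            yield shard, shard_bytes          # flush the current shard
            shard, shard_bytes = {}, 0
        shard[name] = tensor
        shard_bytes += nbytes
    if shard:
        yield shard, shard_bytes              # last (possibly smaller) shard


# tiny demo: five ~4 MiB tensors split under an 8 MiB limit -> shards of 2, 2, 1
sd = {f"layer{i}.weight": torch.zeros(1024, 1024) for i in range(5)}
for i, (shard, size) in enumerate(shard_by_size(sd, max_shard_size_mb=8)):
    print(i, list(shard), f"{size / 2**20:.1f} MiB")
```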
...
@@ -97,7 +136,24 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
         """
         Load model to checkpoint but only on master process.
         """
-        raise NotImplementedError("Sharded model checkpoint is not supported yet.")
+        assert isinstance(model, TorchFSDPModel), "Please boost the model before loading!"
+        use_safetensors = False
+        if "safetensors" in checkpoint_index_file.name:
+            use_safetensors = True
+
+        if use_safetensors and not utils.is_safetensors_available():
+            raise ImportError("`safe_serialization` requires the `safetensors` library: `pip install safetensors`.")
+
+        # read checkpoint index file
+        ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file)
+        checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames()
+
+        fsdp_state_dict = {}
+        for shard_file in checkpoint_files:
+            fsdp_state_dict.update(utils.load_shard_state_dict(Path(shard_file), use_safetensors))
+
+        with FSDP.state_dict_type(model.unwrap(), StateDictType.FULL_STATE_DICT):
+            model.unwrap().load_state_dict(fsdp_state_dict, strict=False)

     def save_sharded_optimizer(
         self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool, prefix: str, size_per_shard: int
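Loading mirrors saving: the index file lists every shard, the shards are merged back into one state dict, and the result is applied with load_state_dict(strict=False) inside FSDP.state_dict_type(..., FULL_STATE_DICT). The sketch below shows that merge pattern with an assumed, Hugging-Face-style index layout; the file names and index format are hypothetical and not ColossalAI's exact on-disk format.

```python
# Hedged sketch: merge all shards listed in an index file into one state dict,
# then load it non-strictly. Index name/layout below are assumptions.
import json
from pathlib import Path

import torch
import torch.nn as nn


def load_merged_state_dict(checkpoint_dir: str) -> dict:
    ckpt = Path(checkpoint_dir)
    # assumed index layout: {"weight_map": {"param.name": "shard_file.bin", ...}}
    index = json.loads((ckpt / "model.index.json").read_text())
    merged: dict = {}
    for shard_file in sorted(set(index["weight_map"].values())):
        merged.update(torch.load(ckpt / shard_file, map_location="cpu"))
    return merged


model = nn.Linear(8, 8)
# with a real checkpoint directory:
# state = load_merged_state_dict("./ckpt/model")
# model.load_state_dict(state, strict=False)
```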
...
@@ -105,13 +161,86 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
         """
         Save optimizer to checkpoint but only on master process.
         """
-        raise NotImplementedError("Sharded optimizer checkpoint is not supported yet.")
+        assert isinstance(optimizer, FSDPOptimizerWrapper), "Please boost the optimizer before saving!"
+
+        if os.path.isfile(checkpoint):
+            logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
+            return
+
+        Path(checkpoint).mkdir(parents=True, exist_ok=True)
+
+        with FSDP.state_dict_type(
+            optimizer.unwrap_model().unwrap(),
+            StateDictType.FULL_STATE_DICT,
+            FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
+        ):
+            fsdp_optim_state = FSDP.full_optim_state_dict(
+                optimizer.unwrap_model().unwrap(), optim=optimizer, rank0_only=True
+            )
+
+        if self.coordinator.is_master():
+            # Preparing file paths and index file.
+            states_name, save_index_file, param_group_file = utils.get_optimizer_base_filenames(prefix)
+            index_file = CheckpointIndexFile(checkpoint)
+
+            index_file.append_meta_data("param_groups", param_group_file)
+            group_file_path = os.path.join(checkpoint, param_group_file)
+            utils.save_param_groups(fsdp_optim_state, group_file_path)
+
+            sharded_state = utils.shard_optimizer_checkpoint(fsdp_optim_state, max_shard_size=size_per_shard)
+
+            # Save shards of optimizer states.
+            # In general cases, is_master is set to True to get the right behavior.
+            total_size = utils.save_state_dict_shards(
+                sharded_state_dict=sharded_state,
+                checkpoint=checkpoint,
+                index_file=index_file,
+                base_filename=states_name,
+                is_master=self.coordinator.is_master(),
+                use_safetensors=False,
+            )
+
+            index_file.append_meta_data("total_size", total_size)
+            index_file.write_index_file(save_index_file)
+            logging.info(
+                f"The optimizer is going to be split to checkpoint shards. "
+                f"You can find where each parameters has been saved in the "
+                f"index located at {save_index_file}."
+            )

     def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, size_per_shard: int):
         """
         Load optimizer to checkpoint but only on master process.
         """
-        raise NotImplementedError("Sharded optimizer checkpoint is not supported yet.")
+        assert isinstance(optimizer, FSDPOptimizerWrapper), "Please boost the optimizer before saving!"
+
+        ckpt_index_file = CheckpointIndexFile.from_file(index_file_path)
+
+        # Load param_groups
+        param_group_path = ckpt_index_file.get_param_group_filename()
+        if param_group_path is None:
+            raise RuntimeError(
+                f"Invalid index file path {index_file_path} for an optimizer. "
+                "Looking param group file under current directory."
+            )
+
+        saved_param_groups = torch.load(param_group_path)
+
+        # Load param
+        fsdp_optim_state = {}
+        checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames()
+        for shard_file in checkpoint_files:
+            state_dict_shard = utils.load_shard_state_dict(Path(shard_file), use_safetensors=False)
+            fsdp_optim_state.update(state_dict_shard)
+
+        fsdp_optim_dict = dict(state=fsdp_optim_state, param_groups=saved_param_groups)
+
+        with FSDP.state_dict_type(optimizer.unwrap_model().unwrap(), StateDictType.FULL_STATE_DICT):
+            fsdp_state = FSDP.optim_state_dict_to_load(
+                model=optimizer.unwrap_model().unwrap(), optim=optimizer, optim_state_dict=fsdp_optim_dict
+            )
+        optimizer.load_state_dict(fsdp_state)

     def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
         """
...
@@ -190,7 +319,7 @@ class TorchFSDPPlugin(DPPluginBase):
             raise RuntimeError("FSDP is not supported while torch version under 1.12.0.")

     def support_no_sync(self) -> bool:
-        False
+        return False

     def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]:
         raise NotImplementedError("Torch fsdp no_sync func not supported yet.")
...
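Aside from the checkpoint work, this hunk fixes a small bug in TorchFSDPPlugin.support_no_sync: the bare `False` expression was a no-op, so the method implicitly returned None; it now returns False explicitly.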
tests/test_checkpoint_io/test_torch_fsdp_checkpoint_io.py (view file @ bf34c6fe)
...
@@ -10,6 +10,7 @@ from colossalai.booster import Booster
 if version.parse(torch.__version__) >= version.parse("1.12.0"):
     from colossalai.booster.plugin import TorchFSDPPlugin
+    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

 from colossalai.testing import rerun_if_address_is_in_use, spawn
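The test now imports FullyShardedDataParallel directly, under the same torch >= 1.12.0 guard the plugin itself enforces, so that it can call FSDP.full_optim_state_dict when comparing saved and restored optimizer states.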
...
@@ -99,6 +100,43 @@ def check_torch_fsdp_ckpt():
         outputs_sec = fsdp_model(inputs)
         assert criterion(outputs_sec) == criterion(outputs)

+    with shared_tempdir() as tempdir:
+        model_ckpt_path = f"{tempdir}/model"
+        optim_ckpt_path = f"{tempdir}/optimizer"
+
+        run_model()
+
+        booster.save_model(fsdp_model, model_ckpt_path, shard=True)
+        booster.save_optimizer(optimizer, optim_ckpt_path, shard=True)
+
+        full_msd = fsdp_model.unwrap().state_dict()
+        full_osd = FSDP.full_optim_state_dict(optimizer.unwrap_model().unwrap(), optim=optimizer)
+        import copy
+
+        sharded_osd = copy.deepcopy(full_osd)
+
+        run_model()
+
+        full_msd_updated = fsdp_model.unwrap().state_dict()
+        full_osd_updated = FSDP.full_optim_state_dict(optimizer.unwrap_model().unwrap(), optim=optimizer)
+
+        # cost much time led to timeout
+        # assert not compare_nested_dict(full_osd_updated, sharded_osd)
+        # assert not compare_nested_dict(full_msd_updated, full_msd)
+        outputs_first = fsdp_model(inputs)
+        assert criterion(outputs_first) != criterion(outputs)
+
+        booster.load_model(fsdp_model, model_ckpt_path)
+        booster.load_optimizer(optimizer, optim_ckpt_path)
+
+        full_msd_restore = fsdp_model.unwrap().state_dict()
+        sharded_osd_restore = FSDP.full_optim_state_dict(optimizer.unwrap_model().unwrap(), optim=optimizer)
+
+        assert compare_nested_dict(sharded_osd, sharded_osd_restore)
+        assert compare_nested_dict(full_msd_restore, full_msd)
+
+        outputs_sec = fsdp_model(inputs)
+        assert criterion(outputs_sec) == criterion(outputs)
+

 def run_dist(rank, world_size, port):
     # init dist env
...
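The new test block exercises the full round trip: it trains for a few steps, saves sharded model and optimizer checkpoints, snapshots the current full model and optimizer states, then trains further so that the live states and the loss drift away from the snapshot (the stronger dict comparisons are commented out because they were too slow and caused timeouts). After booster.load_model and booster.load_optimizer, it asserts with compare_nested_dict that the restored model and optimizer states, as well as the loss on the same inputs, match the values captured at save time.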