ColossalAI commit b37797ed (parent ad6460cf)
Authored May 15, 2023 by wukong1992; committed by GitHub on May 15, 2023.

[booster] support torch fsdp plugin in booster (#3697)

Co-authored-by: 纪少敏 <jishaomin@jishaomindeMBP.lan>
Showing 4 changed files with 358 additions and 2 deletions:
colossalai/booster/plugin/__init__.py (+7 −0)
colossalai/booster/plugin/torch_fsdp_plugin.py (+285 −0)
colossalai/testing/utils.py (+2 −2)
tests/test_booster/test_plugin/test_torch_fsdp_plugin.py (+64 −0)
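For orientation before the diffs: a minimal sketch of how the new plugin is meant to be used with the Booster API, adapted from the TorchFSDPPlugin docstring introduced in this commit. The toy model, dataset, optimizer and criterion are illustrative placeholders, and a distributed environment is assumed to have been initialized beforehand (e.g. via colossalai.launch).

import torch
import torch.nn as nn
from torch.utils.data import TensorDataset

from colossalai.booster import Booster
from colossalai.booster.plugin import TorchFSDPPlugin

# toy placeholders -- any real model / dataset / optimizer can be substituted
model = nn.Linear(8, 8)
train_dataset = TensorDataset(torch.randn(32, 8), torch.randn(32, 8))
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
criterion = nn.MSELoss()

# wrap everything with the FSDP plugin, following the docstring example
plugin = TorchFSDPPlugin()
train_dataloader = plugin.prepare_train_dataloader(train_dataset, batch_size=8)
booster = Booster(plugin=plugin)
model, optimizer, train_dataloader, criterion = booster.boost(model, optimizer, train_dataloader, criterion)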
colossalai/booster/plugin/__init__.py
@@ -4,3 +4,10 @@ from .plugin_base import Plugin
from .torch_ddp_plugin import TorchDDPPlugin

__all__ = ['Plugin', 'TorchDDPPlugin', 'GeminiPlugin', 'LowLevelZeroPlugin']

import torch
from packaging import version

if version.parse(torch.__version__) >= version.parse('1.12.0'):
    from .torch_fsdp_plugin import TorchFSDPPlugin
    __all__.append('TorchFSDPPlugin')
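Because TorchFSDPPlugin is only exported when the installed torch is at least 1.12.0, downstream code that must also import cleanly on older torch needs the same guard. A minimal consumer-side sketch (the FSDP_AVAILABLE name is illustrative, not part of the commit); the new test file further below uses the same pattern.

import torch
from packaging import version

FSDP_AVAILABLE = version.parse(torch.__version__) >= version.parse('1.12.0')

if FSDP_AVAILABLE:
    # only present in colossalai.booster.plugin.__all__ on torch >= 1.12.0
    from colossalai.booster.plugin import TorchFSDPPlugin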
colossalai/booster/plugin/torch_fsdp_plugin.py (new file, mode 100644)
from typing import Callable, Iterable, Iterator, List, Optional, Tuple, Union

import torch
import torch.nn as nn
from packaging import version
from torch.distributed import ProcessGroup

if version.parse(torch.__version__) >= version.parse('1.12.0') and version.parse(
        torch.__version__) < version.parse('2.0.0'):
    from torch.distributed.fsdp import FullStateDictConfig
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    from torch.distributed.fsdp import StateDictType
    from torch.distributed.fsdp.fully_sharded_data_parallel import (
        BackwardPrefetch,
        CPUOffload,
        MixedPrecision,
        ShardingStrategy,
    )
elif version.parse(torch.__version__) >= version.parse('2.0.0'):
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    from torch.distributed.fsdp._init_utils import ProcessGroupType
    from torch.distributed.fsdp.api import (
        BackwardPrefetch,
        CPUOffload,
        FullOptimStateDictConfig,
        FullStateDictConfig,
        MixedPrecision,
        ShardingStrategy,
        StateDictType,
    )
    from torch.distributed.fsdp.wrap import _FSDPPolicy
else:
    raise RuntimeError("FSDP is not supported while torch version under 1.12.0.")

from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils.data import DataLoader

from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO
from colossalai.cluster import DistCoordinator
from colossalai.interface import ModelWrapper, OptimizerWrapper

from .dp_plugin_base import DPPluginBase

__all__ = ['TorchFSDPPlugin']
class TorchFSDPCheckpointIO(GeneralCheckpointIO):

    def __init__(self) -> None:
        super().__init__()
        self.coordinator = DistCoordinator()

    def __set_model_optim_state(
        self,
        model,
        state_dict_type,
        state_dict_config,
        optim_state_dict_config,
    ):
        return FSDP.set_state_dict_type(model, state_dict_type, state_dict_config, optim_state_dict_config)

    def load_sharded_model(self, model: nn.Module, checkpoint: str):
        # TODO(jishaomin): implement this method as it can be supported by Huggingface model
        raise NotImplementedError("Torch FSDP sharded model checkpoint is not supported yet.")

    def load_sharded_optimizer(self, model: nn.Module, optimizer: Optimizer, checkpoint: str):
        # TODO(jishaomin): implement this method as it can be supported by Huggingface model
        raise NotImplementedError("Torch FSDP sharded model checkpoint is not supported yet.")

    def save_sharded_model(self, model: nn.Module, checkpoint: str):
        # TODO(jishaomin): implement this method as it can be supported by Huggingface model
        raise NotImplementedError("Torch FSDP sharded model checkpoint is not supported yet.")

    def save_sharded_optimizer(self, model: nn.Module, optimizer: Optimizer, checkpoint: str):
        # TODO(jishaomin): implement this method as it can be supported by Huggingface model
        raise NotImplementedError("Torch FSDP sharded model checkpoint is not supported yet.")
    def load_unsharded_model(self, model: nn.Module, checkpoint: str):
        """
        Load model from checkpoint with automatic unwrapping.
        """
        # the model should be unwrapped in self.load_model via ModelWrapper.unwrap
        if version.parse(torch.__version__) >= version.parse('1.12.0') and version.parse(
                torch.__version__) < version.parse('2.0.0'):
            full_state_dict = self.load_state_dict(checkpoint)
        elif version.parse(torch.__version__) >= version.parse('2.0.0'):
            full_state_dict = self.load_state_dict(checkpoint)
            self.__set_model_optim_state(model, StateDictType.FULL_STATE_DICT, FullStateDictConfig(rank0_only=True))
            full_state_dict = model.state_dict()
        else:
            raise RuntimeError("FSDP is not supported while torch version under 1.12.0.")
        model.load_state_dict(full_state_dict)
    def load_unsharded_optimizer(self, model: nn.Module, optim: Optimizer, checkpoint: str):
        """
        Load Optimizer from checkpoint with automatic unwrapping.
        """
        if version.parse(torch.__version__) >= version.parse('1.12.0') and version.parse(
                torch.__version__) < version.parse('2.0.0'):
            optim_full_state_dict = self.load_state_dict(checkpoint)
        elif version.parse(torch.__version__) >= version.parse('2.0.0'):
            optim_full_state_dict = self.load_state_dict(checkpoint)
            FSDP.full_optim_state_dict_to_load(optim_full_state_dict, model, optim)
        else:
            raise RuntimeError("FSDP is not supported while torch version under 1.12.0.")
        optim.load_state_dict(optim_full_state_dict)
    def save_unsharded_model(self, model: nn.Module, checkpoint: str):
        """
        Save model to checkpoint but only on master process.
        """
        # the model should be unwrapped in self.load_model via ModelWrapper.unwrap
        if version.parse(torch.__version__) >= version.parse('1.12.0') and version.parse(
                torch.__version__) < version.parse('2.0.0'):
            cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
            with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, cfg):
                model_state_dict = model.state_dict()
        elif version.parse(torch.__version__) >= version.parse('2.0.0'):
            self.__set_model_optim_state(model, StateDictType.FULL_STATE_DICT, FullStateDictConfig(rank0_only=True))
            model_state_dict = model.state_dict()
        else:
            raise RuntimeError("FSDP is not supported while torch version under 1.12.0.")
        self.save_checkpoint(model_state_dict, checkpoint)
    def save_unsharded_optimizer(self, model: nn.Module, optimizer: Optimizer, checkpoint: str):
        """
        Save optimizer to checkpoint but only on master process.
        """
        if version.parse(torch.__version__) >= version.parse('1.12.0') and version.parse(
                torch.__version__) < version.parse('2.0.0'):
            optim_state_dict = FSDP.full_optim_state_dict(model=model, optim=optimizer)
        elif version.parse(torch.__version__) >= version.parse('2.0.0'):
            self.__set_model_optim_state(model, StateDictType.FULL_STATE_DICT,
                                         FullOptimStateDictConfig(rank0_only=True))
            optim_state_dict = FSDP.optim_state_dict(model, optimizer)
        else:
            raise RuntimeError("FSDP is not supported while torch version under 1.12.0.")
        self.save_checkpoint(optim_state_dict, checkpoint)
class TorchFSDPModel(ModelWrapper):

    def __init__(self, module: nn.Module, *args, **kwargs) -> None:
        super().__init__(module)
        self.module = FSDP(module, *args, **kwargs)

    def unwrap(self):
        return self.module.module
class TorchFSDPPlugin(DPPluginBase):
    """
    Plugin for PyTorch FSDP.

    Example:
        >>> from colossalai.booster import Booster
        >>> from colossalai.booster.plugin import TorchFSDPPlugin
        >>>
        >>> model, train_dataset, optimizer, criterion = ...
        >>> plugin = TorchFSDPPlugin()

        >>> train_dataloader = plugin.prepare_train_dataloader(train_dataset, batch_size=8)
        >>> booster = Booster(plugin=plugin)
        >>> model, optimizer, train_dataloader, criterion = booster.boost(model, optimizer, train_dataloader, criterion)

    Args:
        See https://pytorch.org/docs/stable/fsdp.html for details.
    """
    if version.parse(torch.__version__) >= version.parse('1.12.0') and version.parse(
            torch.__version__) < version.parse('2.0.0'):

        def __init__(
            self,
            process_group: Optional[ProcessGroup] = None,
            sharding_strategy: Optional[ShardingStrategy] = None,
            cpu_offload: Optional[CPUOffload] = None,
            auto_wrap_policy: Optional[Callable] = None,
            backward_prefetch: Optional[BackwardPrefetch] = None,
            mixed_precision: Optional[MixedPrecision] = None,
            ignored_modules: Optional[Iterable[torch.nn.Module]] = None,
            param_init_fn: Optional[Callable[[nn.Module], None]] = None,
            device_id: Optional[Union[int, torch.device]] = None,
            sync_module_states: bool = False,
        ):
            super().__init__()
            self.fsdp_kwargs = dict(process_group=process_group,
                                    sharding_strategy=sharding_strategy,
                                    cpu_offload=cpu_offload,
                                    auto_wrap_policy=auto_wrap_policy,
                                    backward_prefetch=backward_prefetch,
                                    mixed_precision=mixed_precision,
                                    ignored_modules=ignored_modules,
                                    param_init_fn=param_init_fn,
                                    device_id=device_id,
                                    sync_module_states=sync_module_states)

    elif version.parse(torch.__version__) >= version.parse('2.0.0'):

        def __init__(
            self,
            process_group: ProcessGroupType = None,
            sharding_strategy: Optional[ShardingStrategy] = None,
            cpu_offload: Optional[CPUOffload] = None,
            auto_wrap_policy: Optional[Union[Callable, _FSDPPolicy]] = None,
            backward_prefetch: Optional[BackwardPrefetch] = BackwardPrefetch.BACKWARD_PRE,
            mixed_precision: Optional[MixedPrecision] = None,
            ignored_modules: Optional[Iterable[torch.nn.Module]] = None,
            param_init_fn: Optional[Callable[[nn.Module], None]] = None,
            device_id: Optional[Union[int, torch.device]] = None,
            sync_module_states: bool = False,
            forward_prefetch: bool = False,
            limit_all_gathers: bool = False,
            use_orig_params: bool = False,
            ignored_parameters: Optional[Iterable[torch.nn.Parameter]] = None,
        ):
            super().__init__()
            self.fsdp_kwargs = dict(process_group=process_group,
                                    sharding_strategy=sharding_strategy,
                                    cpu_offload=cpu_offload,
                                    auto_wrap_policy=auto_wrap_policy,
                                    backward_prefetch=backward_prefetch,
                                    mixed_precision=mixed_precision,
                                    ignored_modules=ignored_modules,
                                    param_init_fn=param_init_fn,
                                    device_id=device_id,
                                    sync_module_states=sync_module_states,
                                    forward_prefetch=forward_prefetch,
                                    limit_all_gathers=limit_all_gathers,
                                    use_orig_params=use_orig_params,
                                    ignored_parameters=ignored_parameters)

    else:
        raise RuntimeError("FSDP is not supported while torch version under 1.12.0.")
    def support_no_sync(self) -> bool:
        return False
    def no_sync(self, model: nn.Module) -> Iterator[None]:
        raise NotImplementedError("Torch fsdp no_sync func not supported yet.")

    def control_precision(self) -> bool:
        return True

    def supported_precisions(self) -> List[str]:
        return ['fp16', 'bf16']

    def control_device(self) -> bool:
        return True

    def supported_devices(self) -> List[str]:
        return ['cuda']

    def configure(
        self,
        model: nn.Module,
        optimizer: Optimizer,
        criterion: Callable = None,
        dataloader: DataLoader = None,
        lr_scheduler: LRScheduler = None,
    ) -> Tuple[Union[nn.Module, OptimizerWrapper, LRScheduler, DataLoader]]:

        model = model.cuda()
        # wrap the model with PyTorch FSDP
        model = TorchFSDPModel(model, **self.fsdp_kwargs)

        if not isinstance(optimizer, OptimizerWrapper):
            optimizer = OptimizerWrapper(optimizer)

        return model, optimizer, criterion, dataloader, lr_scheduler

    def control_checkpoint_io(self) -> bool:
        return True

    def get_checkpoint_io(self) -> CheckpointIO:
        return TorchFSDPCheckpointIO()
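Since the plugin stores every constructor argument in self.fsdp_kwargs and configure() forwards them to FullyShardedDataParallel when wrapping the model, standard FSDP options can be set directly on the plugin. A small sketch, assuming torch >= 1.12 so that CPUOffload and MixedPrecision are importable from torch.distributed.fsdp; it is illustrative only and not part of the commit.

import torch
from torch.distributed.fsdp import CPUOffload, MixedPrecision

from colossalai.booster.plugin import TorchFSDPPlugin

# every keyword below ends up in plugin.fsdp_kwargs and is handed to FSDP
# when Booster.boost() calls plugin.configure() on the model
plugin = TorchFSDPPlugin(
    cpu_offload=CPUOffload(offload_params=True),
    mixed_precision=MixedPrecision(
        param_dtype=torch.float16,
        reduce_dtype=torch.float16,
        buffer_dtype=torch.float16,
    ),
)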
colossalai/testing/utils.py
@@ -167,10 +167,10 @@ def rerun_if_address_is_in_use():
"""
# check version
torch_version
=
version
.
parse
(
torch
.
__version__
)
assert
torch_version
.
major
=
=
1
assert
torch_version
.
major
>
=
1
# only torch >= 1.8 has ProcessRaisedException
if
torch_version
.
minor
>=
8
:
if
torch_version
>=
version
.
parse
(
"1.8.0"
)
:
exception
=
torch
.
multiprocessing
.
ProcessRaisedException
else
:
exception
=
Exception
...
...
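The likely motivation for this two-line change (an assumption, not stated in the commit) is that both old checks break on torch 2.x: the major version is no longer 1, and packaging reports a minor of 0 for "2.0.0", so `torch_version.minor >= 8` would wrongly fall back to the plain Exception path even though ProcessRaisedException exists there. Comparing full parsed versions avoids both problems. A tiny illustration:

from packaging import version

v = version.parse("2.0.0")
print(v.minor >= 8)                 # False -- the old check misclassifies torch 2.x
print(v >= version.parse("1.8.0"))  # True  -- the new check handles it correctly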
tests/test_booster/test_plugin/test_torch_fsdp_plugin.py (new file, mode 100644)
from contextlib import nullcontext

import pytest
import torch
import torch.distributed as dist
from packaging import version
from torch import nn
from torch.optim import SGD

import colossalai
from colossalai.booster import Booster

if version.parse(torch.__version__) >= version.parse('1.12.0'):
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    from colossalai.booster.plugin import TorchFSDPPlugin

from colossalai.interface import OptimizerWrapper
from colossalai.testing import rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo


def run_fn(model_fn, data_gen_fn, output_transform_fn):
    plugin = TorchFSDPPlugin()
    booster = Booster(plugin=plugin)
    model = model_fn()
    optimizer = SGD(model.parameters(), lr=1e-3)
    criterion = lambda x: x.mean()
    data = data_gen_fn()

    data = {
        k: v.to('cuda') if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v
        for k, v in data.items()
    }

    model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)

    assert isinstance(model.module, FSDP)
    assert isinstance(optimizer, OptimizerWrapper)

    output = model(**data)
    output = output_transform_fn(output)
    output_key = list(output.keys())[0]
    loss = criterion(output[output_key])

    booster.backward(loss, optimizer)
    optimizer.clip_grad_by_norm(1.0)
    optimizer.step()


def check_torch_fsdp_plugin():
    for name, (model_fn, data_gen_fn, output_transform_fn, _) in model_zoo.items():
        if 'diffusers' in name:
            continue
        run_fn(model_fn, data_gen_fn, output_transform_fn)
        torch.cuda.empty_cache()


def run_dist(rank, world_size, port):
    # init dist env
    colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')
    check_torch_fsdp_plugin()


@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason="requires torch1.12 or higher")
@rerun_if_address_is_in_use()
def test_torch_fsdp_plugin():
    spawn(run_dist, 2)