OpenDAS / ColossalAI · Commit 4d322b79 (unverified)

[refactor] remove old zero code (#517)

Authored Mar 25, 2022 by Jiarui Fang; committed via GitHub on Mar 25, 2022.
Parent: 6a3f9fda
Showing 20 changed files with 29 additions and 2515 deletions (+29 −2515).
  colossalai/engine/schedule/_pipeline_schedule.py              +0   −3
  colossalai/zero/__init__.py                                   +1   −37
  colossalai/zero/init_ctx/init_context.py                      +1   −1
  colossalai/zero/shard_utils/commons.py                        +20  −0
  colossalai/zero/shard_utils/tensor_shard_strategy.py          +1   −1
  colossalai/zero/sharded_model/__init__.py                     +1   −2
  colossalai/zero/sharded_model/_utils.py                       +1   −47
  colossalai/zero/sharded_model/param_manager.py                +0   −385
  colossalai/zero/sharded_model/sharded_grad.py                 +0   −85
  colossalai/zero/sharded_model/sharded_model.py                +0   −1104
  colossalai/zero/sharded_model/sharded_model_v2.py             +2   −2
  colossalai/zero/sharded_optim/__init__.py                     +1   −2
  colossalai/zero/sharded_optim/bookkeeping/__init__.py         +0   −6
  colossalai/zero/sharded_optim/bookkeeping/base_store.py       +0   −17
  colossalai/zero/sharded_optim/bookkeeping/bucket_store.py     +0   −43
  colossalai/zero/sharded_optim/bookkeeping/gradient_store.py   +0   −66
  colossalai/zero/sharded_optim/bookkeeping/parameter_store.py  +0   −96
  colossalai/zero/sharded_optim/bookkeeping/tensor_bucket.py    +0   −54
  colossalai/zero/sharded_optim/sharded_optim.py                +0   −563
  colossalai/zero/sharded_optim/sharded_optim_v2.py             +1   −1
colossalai/engine/schedule/_pipeline_schedule.py

@@ -12,7 +12,6 @@ from colossalai.core import global_context as gpc
 from colossalai.logging import get_dist_logger
 from colossalai.utils import switch_virtual_pipeline_parallel_rank
 from colossalai.utils.cuda import get_current_device
-from colossalai.zero import ShardedModel, ShardedOptimizer
 from colossalai.zero.sharded_model import ShardedModelV2
 from ._base_schedule import BaseSchedule
@@ -92,8 +91,6 @@ class PipelineSchedule(BaseSchedule):
     def pre_processing(self, engine):
         # TODO: remove this after testing new zero with pipeline parallelism
-        if isinstance(engine.optimizer, ShardedOptimizer) or isinstance(engine.model, ShardedModel):
-            raise TypeError("Pipeline schedule is currently not compatible with ZeRO")
         model = engine.model
         if isinstance(model, (NaiveAMPModel, ShardedModelV2)):
             self.dtype = torch.half
colossalai/zero/__init__.py

@@ -2,14 +2,9 @@ from typing import Tuple
 import torch
 import torch.nn as nn
-from colossalai.amp.naive_amp import NaiveAMPModel
 from colossalai.logging import get_dist_logger
 from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2
 from colossalai.zero.sharded_optim.sharded_optim_v2 import ShardedOptimizerV2
-from torch.optim import Optimizer
-
-from .sharded_model import ShardedModel
-from .sharded_optim import ShardedOptimizer


 def convert_to_zero_v2(model: nn.Module, optimizer: torch.optim.Optimizer, model_config,
@@ -40,35 +35,4 @@ def convert_to_zero_v2(model: nn.Module, optimizer: torch.optim.Optimizer, model
     return zero_model, zero_optimizer

-
-def convert_to_zero(model: nn.Module, optimizer: Optimizer, level: int, zero_config: dict):
-    """
-    A helper function to integrate the model and optimizer with ZeRO optimizer and off-loading
-
-    :param model: Your model object
-    :type model: :class:`torch.nn.Module`
-    :param optimizer: Your optimizer object
-    :type optimizer: :class:`torch.optim.Optimizer`
-    :param level: Optimizer level, can be 2 or 3
-    :type level: int
-    :param zero_config: Configuration for zero
-    :type zero_config: dict
-
-    :return: (model, optimizer)
-    :rtype: Tuple
-    """
-    assert 1 <= level <= 3, 'Only ZERO Optimizer Level 1-3 are provided'
-    if level in [1, 2]:
-        if level == 2:
-            if 'partition_grad' in zero_config:
-                assert zero_config['partition_grad'], \
-                    'Sharded Optimizer requires partition_grad to be True'
-            else:
-                zero_config['partiton_grad'] = True
-        model = NaiveAMPModel(model, output_to_fp32=True)
-        optimizer = ShardedOptimizer(optimizer, **zero_config)
-    else:
-        model = ShardedModel(module=model, **zero_config)
-    return model, optimizer
-
-__all__ = ['convert_to_zero', 'ShardedModel', 'ShardedOptimizer']
+__all__ = ['convert_to_zerov2', 'ShardedModelV2', 'ShardedOptimizerV2']
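Callers of the removed level-based API must migrate to the surviving v2 entry point. A hypothetical migration sketch follows (not part of this commit; the v2 signature is truncated in the diff above, so the call shape here, including the empty model_config, is an assumption for illustration):

import torch
import torch.nn as nn
from colossalai.zero import convert_to_zero_v2

model = nn.Linear(8, 8)
optimizer = torch.optim.Adam(model.parameters())

# Before this commit (now removed):
#   model, optimizer = convert_to_zero(model, optimizer, level=2,
#                                      zero_config={'partition_grad': True})

# After this commit, only the v2 path remains. model_config={} is an assumed
# minimal value; check the full convert_to_zero_v2 signature for any further
# required arguments that the truncated diff line does not show.
zero_model, zero_optimizer = convert_to_zero_v2(model, optimizer, model_config={})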
colossalai/zero/init_ctx/init_context.py

@@ -8,7 +8,7 @@ from colossalai.utils.memory_tracer.model_data_memtracer import \
     GLOBAL_MODEL_DATA_TRACER
 from colossalai.utils.memory_utils.memory_monitor import colo_cuda_memory_used
 from colossalai.zero.shard_utils import BaseShardStrategy
-from colossalai.zero.sharded_model._zero3_utils import cast_tensor_to_fp16
+from colossalai.zero.sharded_model._utils import cast_tensor_to_fp16
 from colossalai.zero.sharded_param import ShardedParamV2
 from torch.distributed import ProcessGroup
 from colossalai.logging import get_dist_logger, disable_existing_loggers
colossalai/zero/shard_utils/commons.py (new file, mode 0 → 100644)

import torch
import torch.nn.functional as F
from typing import Tuple


def get_shard(tensor: torch.Tensor, rank: int, world_size: int) -> Tuple[torch.Tensor, int]:
    """Return the local shard of a full tensor."""
    # Shard using torch.chunk to match all-gather/reduce-scatter.
    chunks = list(torch.flatten(tensor).chunk(world_size))
    while len(chunks) < world_size:
        chunks.append(chunks[0].new_empty(0))

    # Determine number of padding elements.
    num_to_pad = chunks[0].numel() - chunks[rank].numel()
    assert num_to_pad >= 0, num_to_pad

    shard = chunks[rank].clone()
    if num_to_pad > 0:
        shard = F.pad(shard, [0, num_to_pad])
    return shard, num_to_pad
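The padding contract is easiest to see on a small tensor; a quick illustration (not part of the commit):

import torch
from colossalai.zero.shard_utils.commons import get_shard

full = torch.arange(10.)                 # 10 elements, world_size = 4
# torch.chunk yields pieces of 3, 3, 3 and 1 elements; each rank's shard is
# right-padded to the first chunk's size so all shards line up for all-gather.
shard, num_to_pad = get_shard(full, rank=3, world_size=4)
assert shard.numel() == 3 and num_to_pad == 2
assert torch.equal(shard, torch.tensor([9., 0., 0.]))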
colossalai/zero/shard_utils/tensor_shard_strategy.py

@@ -4,7 +4,7 @@ import torch
 import torch.distributed as dist
 from colossalai.utils import get_current_device
 from colossalai.zero.shard_utils import BaseShardStrategy
-from colossalai.zero.sharded_model._zero3_utils import get_shard
+from colossalai.zero.shard_utils.commons import get_shard
 from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
...
colossalai/zero/sharded_model/__init__.py
View file @
4d322b79
from
.sharded_model
import
ShardedModel
from
.sharded_model_v2
import
ShardedModelV2
from
.sharded_model_v2
import
ShardedModelV2
__all__
=
[
'ShardedModel'
,
'ShardedModelV2'
]
__all__
=
[
'ShardedModelV2'
]
\ No newline at end of file
\ No newline at end of file
colossalai/zero/sharded_model/_zero3_utils.py → colossalai/zero/sharded_model/_utils.py

-from collections import OrderedDict
-from typing import Any, Callable, Dict, List, Tuple, Union
+from typing import Any, Callable, List, Tuple

 import torch
 import torch.nn.functional as F

@@ -12,23 +11,6 @@ def get_gradient_predivide_factor(world_size: int) -> float:
     return float(factor)


-def get_shard(tensor: torch.Tensor, rank: int, world_size: int) -> Tuple[torch.Tensor, int]:
-    """Return the local shard of a full tensor."""
-    # Shard using torch.chunk to match all-gather/reduce-scatter.
-    chunks = list(torch.flatten(tensor).chunk(world_size))
-    while len(chunks) < world_size:
-        chunks.append(chunks[0].new_empty(0))
-
-    # Determine number of padding elements.
-    num_to_pad = chunks[0].numel() - chunks[rank].numel()
-    assert num_to_pad >= 0, num_to_pad
-
-    shard = chunks[rank].clone()
-    if num_to_pad > 0:
-        shard = F.pad(shard, [0, num_to_pad])
-    return shard, num_to_pad
-
-
 def free_storage(data: torch.Tensor) -> None:
     """Free underlying storage of a Tensor."""
     if data.storage().size() > 0:

@@ -86,31 +68,3 @@ def chunk_and_pad(tensor: torch.Tensor, num_chunks: int) -> List[torch.Tensor]:
     if len(chunks) < num_chunks:
         chunks.extend([torch.zeros_like(chunks[0]) for _ in range(num_chunks - len(chunks))])
     return chunks
-
-
-def assert_in_engine(cond: Any, s: Any) -> None:
-    """Used in backward context to make sure error is printed."""
-    if not cond:
-        print(s)
-        raise AssertionError
-
-
-def replace_state_dict_prefix(state_dict: Union[Dict[str, torch.Tensor], "OrderedDict[str, torch.Tensor]"],
-                              old_prefix: str, new_prefix: str) -> None:
-    """
-    Replace all keys that match a given old_prefix with a new_prefix (in-place).
-
-    Usage::
-
-        state_dict = {"layer.xyz": torch.tensor(1)}
-        replace_state_dict_prefix(state_dict, "layer.", "module.layer.")
-        assert state_dict == {"module.layer.xyz": torch.tensor(1)}
-    """
-    if old_prefix == new_prefix:
-        raise ValueError("old_prefix and new_prefix must be distinct")
-    for key in list(state_dict.keys()):
-        if not key.startswith(old_prefix):
-            continue
-        new_key = new_prefix + key[len(old_prefix):]
-        state_dict[new_key] = state_dict[key]
-        del state_dict[key]
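The surviving chunk_and_pad, whose tail is visible in the last hunk, remains the reduce-scatter companion of get_shard. A small illustration (not from the diff; it assumes the unshown head of the helper also right-pads the final short chunk, as its fairscale ancestor does):

import torch
from colossalai.zero.sharded_model._utils import chunk_and_pad

grad = torch.ones(5)
chunks = chunk_and_pad(grad, num_chunks=4)   # chunk size ceil(5/4) = 2
# 5 elements split as 2, 2, 1; the short chunk is padded and a fourth
# all-zero chunk is appended so every rank gets an equal piece.
assert len(chunks) == 4 and all(c.numel() == 2 for c in chunks)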
colossalai/zero/sharded_model/param_manager.py (deleted, mode 100644 → 0; last version at parent 6a3f9fda)

import os
from typing import Dict, List, Optional, Tuple

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed import ProcessGroup
from torch.nn.parameter import Parameter

from ._zero3_utils import alloc_storage, free_storage, get_shard

# TODO: Remove the toggle-enable_nccl_base_collectives in the future
if os.getenv("ENABLE_NCCL_BASE_COLLECTIVES", "1") == "0":
    enable_nccl_base_collectives = False
else:
    enable_nccl_base_collectives = True


# TODO: add flatten params
class Zero3ParameterManager:

    def __init__(self,
                 module: nn.Module,
                 process_group: Optional[ProcessGroup],
                 mixed_precision: bool = False,
                 flatten_parameters: bool = True,
                 compute_dtype: Optional[torch.dtype] = None,
                 compute_device: Optional[torch.device] = None,
                 offload_config: Optional[dict] = None) -> None:
        """Manage parameter shards. We manage several attributes on each Parameter instance:
        ``zero_is_sharded``: ``True`` if the Parameter is sharded or ``False``
            if the Parameter is intentionally not sharded (in which case we
            will all-reduce grads for this param).
        ``zero_orig_size``: the size of the original Parameter (before sharding)
        ``zero_shard_padding``: the padding size. All paddings are right padding.
        ``zero_fp32_shard``: a single shard of the parameters in full precision
            (typically FP32, but this is dependent on the dtype of the model
            as it's passed in by the user). This can be on CPU or GPU
            depending on the value of *``offload_config``*.
        ``zero_fp16_shard``: This will be a single shard of the parameters in FP16, used for all-gather.
            This can be in FP16 or FP32 depending on the value of *``compute_dtype``* and
            if params are offloaded to CPU.
        ``zero_full_param_padded``: the full weight (padded to be evenly
            divisible by ``world_size``), used for computation in the
            forward and backward pass. This will be resized in place and
            only materialized (via all-gather) as needed.
        ``zero_cpu_grad``: the gradient saved on CPU. It's set only when using CPU offload.

        :param module: original module
        :type module: nn.Module
        :param process_group: typically data parallel process group, defaults to None
        :type process_group: Optional[ProcessGroup], optional
        :param mixed_precision: whether to use mixed precision mode, defaults to False
        :type mixed_precision: bool, optional
        :param flatten_parameters: whether to flatten parameters, useless now, defaults to True
        :type flatten_parameters: bool, optional
        :param compute_dtype: the dtype of parameters when computing, defaults to None
        :type compute_dtype: Optional[torch.dtype], optional
        :param compute_device: the device of parameters when computing, defaults to None
        :type compute_device: Optional[torch.device], optional
        :param offload_config: offload config, defaults to None
        :type offload_config: Optional[dict], optional
        """
        self.process_group = process_group
        self.shard_idx = process_group.rank()
        self.num_shards = process_group.size()
        self.mixed_precision = mixed_precision
        self.compute_dtype = compute_dtype
        self.compute_device = compute_device
        self.offload_config = offload_config

        self._cpu_offload = offload_config.get('device', None) == 'cpu' if offload_config else False

        self.params: List[Parameter] = []

        for param in module.parameters():
            if not hasattr(param, 'zero_is_sharded'):
                self.params.append(param)

        self._has_params = len(self.params) > 0
        self._has_sharded_params = False
        # Flag to indicate if the full params are gathered.
        self.has_full_params: bool = False

        self._shard_params()
        # Maybe no need, reserve to prevent bugs
        # self.delete_fp32_shards()

        self._streams: Dict[str, torch.cuda.Stream] = {}

    def _shard_params(self) -> None:
        for p in self.params:
            assert not hasattr(p, "zero_is_sharded")
            assert p.is_floating_point()
            if self.mixed_precision:
                assert p.dtype == torch.float32

            # If world_size is 1, then we all-reduce grads instead of sharding.
            p.zero_is_sharded = self.num_shards > 1
            p.zero_orig_size = p.data.size()

            if not p.zero_is_sharded:
                p.zero_shard_padding = 0
                continue

            # Replace p.data with the relevant shard.
            orig_data = p.data
            p.data, p.zero_shard_padding = get_shard(p.data, self.shard_idx, self.num_shards)
            free_storage(orig_data)

    @torch.no_grad()
    def reset_param_attr(self, p: Parameter, training: bool) -> None:
        """This should be called by ``ZeroRedundancyLevel3Model._lazy_init()``
        """
        assert hasattr(p, 'zero_is_sharded') and hasattr(p, 'zero_orig_size')
        if hasattr(p, 'zero_fp32_shard'):
            return

        # A single shard of the parameters in full precision.
        p.zero_fp32_shard = p.data

        if self.mixed_precision:
            assert p.zero_fp32_shard.dtype == torch.float32

        if self._cpu_offload:
            assert p.zero_fp32_shard.device == torch.device('cpu')
            # If we plan to keep the FP32 parameters on CPU, then pinning
            # memory allows us to later use non-blocking transfers when moving
            # the FP32 param shard to compute_device.
            p.zero_fp32_shard = p.zero_fp32_shard.pin_memory()
            p.data = p.zero_fp32_shard

        if self.mixed_precision or self._cpu_offload:
            # In mixed precision mode, we maintain a reduced precision
            # (typically FP16) parameter shard on compute_device for performing
            # the computation in the forward/backward pass. We resize the
            # storage to size 0 at init (here) and re-materialize (by copying
            # from _fp32_shard) as needed. If offloading params to CPU, the
            # dtype of the fp16 shard will depend on the *`compute_dtype`*.
            p.zero_fp16_shard = torch.zeros_like(p.zero_fp32_shard, device=self.compute_device, dtype=self.compute_dtype)
            free_storage(p.zero_fp16_shard)

        if self.mixed_precision:
            assert p.zero_fp32_shard.dtype == torch.float32

        if not self.mixed_precision and not self._cpu_offload:
            # use _fp32_shard if you are not in using mixed precision or
            # offloading params and grads to CPU.
            p.zero_fp16_shard = None

        # We also maintain a full-sized parameter of type self.compute_dtype
        # (FP16 for mixed_precision or FP32 otherwise). We resize the
        # storage to size 0 at init (here) and only materialize as needed. The
        # storage may contain padding elements so that it is evenly divisible by
        # world_size, although these padding elements will be removed before the
        # relevant computation.
        if p.zero_is_sharded:
            p.zero_full_param_padded = torch.zeros(p.data.numel() * self.num_shards,
                                                   device=self.compute_device,
                                                   dtype=self.compute_dtype)
            free_storage(p.zero_full_param_padded)

        if self._cpu_offload and training:
            p.zero_cpu_grad = torch.zeros_like(p.data, device='cpu').pin_memory()

    def setup_streams(self, streams):
        self._streams = streams

    @torch.no_grad()
    def rebuild_full_params(self, force_full_precision: bool = False) -> Optional[List[Tuple[torch.Tensor, bool]]]:
        """
        Gather all shards of params.

        Note, this is idempotent if full params are already gathered. Callers
        assume the idempotency. So please keep it that way.

        Args:
            force_full_precision (bool, Optional): by default params will be gathered
                in ``compute_dtype`` (e.g., FP16), unless *force_full_precision* is
                ``True``, in which case they will be gathered in full precision
                (e.g., FP32), possibly in fresh storage. The parameter that's being
                rebuilt will end up in full precision as well.

        Returns:
            A list of tuples, where the first element is the full-sized param
            and the second element is a bool indicating if it's safe for the
            caller to free the full-sized param. This will be ``None`` if
            ``force_full_precision=False`` and the full params are already gathered.
        """
        # Store tensor and free flag
        output_tensors: List[Tuple[torch.Tensor, bool]] = []

        def update_p_data(custom_output_tensor: Optional[torch.Tensor] = None) -> None:
            """
            Helper function to update p.data pointer.

            Args:
                custom_output_tensor (torch.Tensor, Optional): if not None, this
                    tensor contains the data we just gathered.
            """
            if custom_output_tensor is not None:
                assert p.zero_is_sharded
                p.data = custom_output_tensor
                output_tensors.append((p.data, True))
            elif not p.zero_is_sharded:
                if (self.mixed_precision or self._cpu_offload) and not force_full_precision:
                    assert p.zero_fp16_shard is not None
                    p.data = p.zero_fp16_shard
                    output_tensors.append((p.data, True))
                else:
                    # Here p.data == p._fp32_shard, so it's not safe to free.
                    output_tensors.append((p.data, False))
            else:
                p.data = p.zero_full_param_padded
                output_tensors.append((p.data, True))
            # Trim any padding and reshape to match original size.
            p.data = p.data[:p.zero_orig_size.numel()].view(p.zero_orig_size)

        if self._has_sharded_params:
            # self.has_full_params flag can be out of sync if a shared param is
            # sharded by another ZeroRedundancyLevel3Model instance. An example is that in eval case
            # with reshard_after_forward=False but the sharing instance has
            # reshard_after_forward=True. Then, on the second forward, the
            # other instance can shard the shared param and but this instance
            # can mistakenly think the full param is already gathered from the
            # has_full_params flag.
            #
            # Therefore, we update the flag accordingly here.
            self.has_full_params = not any(p.zero_full_param_padded.storage().size() == 0 for p in self.params)

        # Early exit if we already have full params and don't need full precision.
        if self.has_full_params and not force_full_precision:
            for p in self.params:
                update_p_data()
            return output_tensors

        self.has_full_params = True

        with torch.cuda.stream(self._streams["all_gather"]):
            if (self.mixed_precision or self._cpu_offload) and not force_full_precision:
                self.use_fp16_shards()

            if self._cpu_offload and force_full_precision:
                # If the compute_dtype and storage dtype are the same,
                # use pinned memory. Otherwise move p.data to the compute
                # device.
                if self.params[0].dtype == self.compute_dtype:
                    self.use_fp16_shards()
                else:
                    for p in self.params:
                        p.data = p.data.to(self.compute_device)

            for p in self.params:
                if not p.zero_is_sharded:    # e.g., when world_size == 1
                    update_p_data()
                else:
                    # Skip if already built. Only shared param can be rebuilt multiple times.
                    # A corner case is p.zero_orig_size = (1,), which means the shape equality is
                    # not a perfect check. But we assume we don't share a param with shape (1,).
                    # if p.data.shape == p.zero_orig_size and hasattr(p, "zero_is_shared") and p.zero_is_shared:
                    #     continue

                    # If self._cpu_offload and force_full_precision, we need to cast
                    # the FP32 CPU param to CUDA for the all-gather.
                    p_data = p.data.to(p.zero_full_param_padded.device, non_blocking=True)

                    p_size = p.zero_full_param_padded.size()
                    assert p_size.numel() % self.num_shards == 0
                    if self.mixed_precision and force_full_precision:
                        # Allocate fresh tensor in full precision since we are in
                        # mixed precision and full precision rebuild is asked.
                        output_tensor = p_data.new_zeros(p_size)
                    else:
                        if p.zero_full_param_padded.storage().size() != p_size.numel():
                            # Allocate based on full size from all shards.
                            alloc_storage(p.zero_full_param_padded, size=p_size)
                        output_tensor = p.zero_full_param_padded

                    # Fill output_tensor with (p.data for each shard in self.world_size)
                    if hasattr(dist, "_all_gather_base") and enable_nccl_base_collectives:
                        # New version of PyTorch has all_gather_base, which is faster than chunk and then all_gather.
                        dist._all_gather_base(output_tensor, p_data, group=self.process_group)
                    else:
                        chunks = list(output_tensor.chunk(self.num_shards))
                        dist.all_gather(chunks, p_data, group=self.process_group)

                    # Set p.data = output_tensor (with padding trimmed)
                    update_p_data(output_tensor)

                    if (self.mixed_precision or self._cpu_offload) and not force_full_precision:
                        self.free_fp16_shards([p])

                    if self._cpu_offload and (self.params[0].dtype == self.compute_dtype):
                        self.free_fp16_shards([p])

        torch.cuda.current_stream().wait_stream(self._streams["all_gather"])
        return output_tensors

    @torch.no_grad()
    def use_full_params(self) -> None:
        """
        Switch p.data pointers to use the full params.

        Note: this assumes full params are already gathered.

        Note: this might be called after full_params is already in used. So please
            make sure it is idempotent in that case.
        """
        assert self.has_full_params
        for p in self.params:
            if not p.zero_is_sharded:
                if self.mixed_precision or self._cpu_offload:
                    assert p.zero_fp16_shard is not None
                    assert p.zero_fp16_shard.storage().size() != 0
                    p.data = p.zero_fp16_shard
            else:
                assert p.zero_full_param_padded.storage().size() != 0, f"{p.zero_orig_size} {id(self)}"
                p.data = p.zero_full_param_padded[:p.zero_orig_size.numel()].view(p.zero_orig_size)

    @torch.no_grad()
    def use_fp16_shards(self, params: Optional[List[Parameter]] = None) -> None:
        """Cast FP32 param shard to FP16 for a list of params."""
        if params is None:
            params = self.params
        with torch.cuda.stream(self._streams["fp32_to_fp16"]):
            for p in params:
                assert p.zero_fp16_shard is not None
                alloc_storage(p.zero_fp16_shard, size=p.zero_fp32_shard.size())
                p.zero_fp16_shard.copy_(
                    # If _cpu_offload is True, this will be non-blocking
                    # because _fp32_shard is pinned, otherwise it's a no-op.
                    p.zero_fp32_shard.to(p.zero_fp16_shard.device, non_blocking=True))
                p.data = p.zero_fp16_shard
        torch.cuda.current_stream().wait_stream(self._streams["fp32_to_fp16"])

    @torch.no_grad()
    def use_fp32_shards(self, params: Optional[List[Parameter]] = None) -> None:
        """Use FP32 shard for a list of params."""
        if params is None:
            params = self.params
        for p in params:
            p.data = p.zero_fp32_shard

    @torch.no_grad()
    def free_full_params(self, params: Optional[List[Parameter]] = None) -> None:
        """Free up storage for full parameters."""
        if params is None:
            params = self.params
        self.has_full_params = False
        current_stream = torch.cuda.current_stream()
        for p in params:
            if not p.zero_is_sharded:    # e.g., world_size == 1
                if self.mixed_precision or self._cpu_offload:
                    self.free_fp16_shards([p])
                continue
            # Don't let PyTorch reuse this memory until all work in the current
            # stream is complete.
            p.zero_full_param_padded.record_stream(current_stream)
            # There may be external references to the Tensor Storage that we
            # can't modify, such as references that are created by
            # ctx.save_for_backward in the forward pass. Thus when we
            # unshard parameters, we should reuse the original Tensor
            # Storage object and unshard it in-place. For now, just resize
            # the Storage to 0 to save memory.
            free_storage(p.zero_full_param_padded)

    @torch.no_grad()
    def free_fp16_shards(self, params: Optional[List[Parameter]] = None) -> None:
        """Free storage for FP16 shards for a list of params."""
        if params is None:
            params = self.params
        current_stream = torch.cuda.current_stream()
        for p in params:
            if p.zero_fp16_shard is not None:
                # zero_fp16_shard is allocated in "fp32_to_fp16" stream, so we can't
                # free it until the work in the current stream completes.
                p.zero_fp16_shard.record_stream(current_stream)
                free_storage(p.zero_fp16_shard)

    def delete_fp32_shards(self) -> None:
        for p in self.params:
            if hasattr(p, 'zero_fp32_shard'):
                del p.zero_fp32_shard    # reset _init_param_attr
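The padded-gather-then-trim dance at the heart of rebuild_full_params can be checked in plain torch; a runnable illustration (not from the repository):

import torch

orig = torch.randn(3, 3)                                   # 9 elements
num_shards = 2
shard_len = (orig.numel() + num_shards - 1) // num_shards  # 5: shards are right-padded
full_padded = torch.zeros(shard_len * num_shards)          # the all-gather target
full_padded[:orig.numel()] = orig.flatten()
# Trim any padding and reshape to match original size (see update_p_data above).
restored = full_padded[:orig.numel()].view(orig.size())
assert torch.equal(restored, orig)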
colossalai/zero/sharded_model/sharded_grad.py (deleted, mode 100644 → 0; last version at parent 6a3f9fda)

from typing import Optional

import torch
import torch.nn as nn
from torch.nn.parameter import Parameter


class ShardedGradient:

    def __init__(self, param: Parameter, sharded_module: nn.Module, offload_config: Optional[dict] = None) -> None:
        assert hasattr(param, 'ca_attr') and param.ca_attr.is_sharded, \
            'ShardedGradient can only be initialized with sharded parameter'

        self.param = param
        self.sharded_module = sharded_module
        self.offload_config = offload_config

        self._cpu_offload = offload_config.get('device', None) == 'cpu' if offload_config else False

        # _gpu_grad is either sharded or not
        # all saved grads are fp32
        self._gpu_grad: Optional[torch.Tensor] = None
        self._cpu_grad: Optional[torch.Tensor] = None

        if self._cpu_offload:
            # this buffer will be held and reused every iteration
            self._cpu_grad = torch.zeros(param.ca_attr.payload('cpu'), dtype=torch.float).pin_memory()

    @torch.no_grad()
    def setup(self) -> None:
        """This function will be called pre-backward. Save the local accumulated gradient to _gpu_grad.
        When no_sync() is enable (_require_backward_grad_sync=False), the grad is accumulated locally in param.grad

        :raises AssertionError: Raise if grad shape is wrong
        """
        if self.sharded_module._require_backward_grad_sync and self.param.grad is not None:
            if self.param.grad.device != self.param.data.device:
                # TODO: offload?
                raise RuntimeError(
                    'grad and param are on different device, grad {self.param.grad.device} vs. param {self.param.data.device}')
            else:
                self._gpu_grad = self.param.grad.data
            self.param.grad = None

    def reduce_scatter_callback(self, reduced_grad: torch.Tensor) -> None:
        """This function will be called in post-backward hook, so we cannot modify param.grad directly

        :param reduced_grad: the reduced grad
        :type reduced_grad: torch.Tensor
        """
        # Make sure we store fp32 grad
        if torch.is_floating_point(reduced_grad) and reduced_grad.dtype != torch.float:
            reduced_grad.data = reduced_grad.data.to(torch.float)

        if self._gpu_grad is None:
            self._gpu_grad = reduced_grad.data
        else:
            self._gpu_grad += reduced_grad.data

        # Optionally move gradients to CPU, typically used if one is running the optimizer on the CPU. Once the full
        # backwards pass completes, we will set `.grad` to the CPU copy.
        if self._cpu_offload:
            self._cpu_grad.copy_(self._gpu_grad.data, non_blocking=True)
            # Don't let this memory get reused until after the transfer.
            self._gpu_grad.data.record_stream(torch.cuda.current_stream())

    @torch.no_grad()
    def write_back(self) -> None:
        """This function will be called in final backward hook
        """
        if self._cpu_grad is not None:
            assert self.param.device == torch.device('cpu'), \
                f'Incorrect param device, expected CPU, got {self.param.device}'
            self.param.grad.data = self._cpu_grad
        elif self._gpu_grad is not None:
            assert self.param.device == self._gpu_grad.device, \
                f'Incorrect _gpu_grad device, param on {self.param.device} but _gpu_grad on {self._gpu_grad.device}'
            self.param.grad.data = self._gpu_grad
        else:
            raise RuntimeError('No grad to write back')
        # If using CPU offload, _cpu_grad will store the CPU tensor of _gpu_grad
        # They should be released here
        self._gpu_grad = None
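The fp32 accumulation contract of reduce_scatter_callback above, where reduced shards are upcast and summed across micro-batches, can be mimicked in isolation (an illustration only, not repository code):

import torch

gpu_grad = None
for reduced in (torch.ones(4, dtype=torch.float16), torch.ones(4, dtype=torch.float16)):
    # Make sure we store fp32 grad, mirroring the deleted callback above.
    if torch.is_floating_point(reduced) and reduced.dtype != torch.float:
        reduced = reduced.to(torch.float)
    gpu_grad = reduced if gpu_grad is None else gpu_grad + reduced
assert gpu_grad.dtype == torch.float32
assert torch.equal(gpu_grad, torch.full((4,), 2.))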
colossalai/zero/sharded_model/sharded_model.py
deleted
100644 → 0
View file @
6a3f9fda
import
contextlib
import
copy
import
functools
import
os
import
traceback
from
collections
import
OrderedDict
from
enum
import
Enum
,
auto
from
typing
import
(
Any
,
Callable
,
Dict
,
Generator
,
List
,
NamedTuple
,
Optional
,
Set
,
Union
)
import
torch
import
torch.distributed
as
dist
import
torch.nn
as
nn
from
colossalai.context.parallel_mode
import
ParallelMode
from
colossalai.core
import
global_context
as
gpc
from
colossalai.logging
import
get_dist_logger
from
colossalai.utils
import
get_current_device
from
torch.autograd
import
Variable
from
torch.distributed
import
ProcessGroup
from
torch.nn.parameter
import
Parameter
from
._zero3_utils
import
(
apply_to_tensors
,
assert_in_engine
,
cast_float_arguments
,
cast_tensor_to_fp16
,
cast_tensor_to_fp32
,
chunk_and_pad
,
free_storage
,
get_gradient_predivide_factor
,
get_shard
,
replace_state_dict_prefix
)
from
.param_manager
import
Zero3ParameterManager
from
.reduce_scatter
import
ReduceScatterBucketer
# TODO: Remove the toggle-enable_nccl_base_collectives in the future
if
os
.
getenv
(
"ENABLE_NCCL_BASE_COLLECTIVES"
,
"1"
)
==
"0"
:
enable_nccl_base_collectives
=
False
else
:
enable_nccl_base_collectives
=
True
class
TrainingState
(
Enum
):
IDLE
=
auto
()
FORWARD
=
auto
()
PRE_BACKWARD
=
auto
()
POST_BACKWARD
=
auto
()
GATHER_FULL_PARAMS
=
auto
()
# TODO: Add clip_grad_norm_
# TODO: Add gather_full_optim_state_dict and get_shard_from_optim_state_dict
class
ShardedModel
(
nn
.
Module
):
def
__init__
(
self
,
module
:
nn
.
Module
,
process_group
:
Optional
[
ProcessGroup
]
=
None
,
reduce_scatter_process_group
:
Optional
[
ProcessGroup
]
=
None
,
reshard_after_forward
:
bool
=
True
,
disable_reshard_on_root
:
bool
=
True
,
mixed_precision
:
bool
=
False
,
fp32_reduce_scatter
:
bool
=
False
,
flatten_parameters
:
bool
=
True
,
compute_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
buffer_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
reduce_scatter_bucket_size_mb
:
int
=
25
,
compute_device
:
Optional
[
torch
.
device
]
=
None
,
no_broadcast_optim_state
:
Optional
[
bool
]
=
False
,
state_dict_device
:
Optional
[
torch
.
device
]
=
None
,
clear_autocast_cache
:
bool
=
False
,
force_input_to_fp32
:
bool
=
False
,
verbose
:
bool
=
False
,
offload_config
:
Optional
[
dict
]
=
None
,
state_dict_on_rank_0_only
:
bool
=
False
,
gradient_predivide_factor
:
Optional
[
float
]
=
1.0
)
->
None
:
super
().
__init__
()
self
.
logger
=
get_dist_logger
()
self
.
process_group
=
process_group
or
gpc
.
get_group
(
ParallelMode
.
DATA
)
self
.
reduce_scatter_process_group
=
reduce_scatter_process_group
or
self
.
process_group
self
.
world_size
=
dist
.
get_world_size
(
self
.
process_group
)
self
.
rank
=
dist
.
get_rank
(
self
.
process_group
)
self
.
reshard_after_forward
=
self
.
_orig_reshard_after_forward
=
reshard_after_forward
self
.
disable_reshard_on_root
=
disable_reshard_on_root
self
.
mixed_precision
=
mixed_precision
self
.
fp32_reduce_scatter
=
fp32_reduce_scatter
self
.
offload_config
=
offload_config
self
.
compute_dtype
=
compute_dtype
or
(
torch
.
float16
if
mixed_precision
else
torch
.
float32
)
self
.
buffer_dtype
=
buffer_dtype
or
self
.
compute_dtype
self
.
reduce_scatter_bucket_size_mb
=
reduce_scatter_bucket_size_mb
self
.
compute_device
=
compute_device
or
torch
.
device
(
f
'cuda:
{
get_current_device
()
}
'
)
self
.
uncollected_opt_state
:
Dict
[
int
,
Dict
]
=
{}
self
.
no_broadcast_optim_state
=
no_broadcast_optim_state
self
.
state_dict_device
=
state_dict_device
or
self
.
compute_device
self
.
clear_autocast_cache
=
clear_autocast_cache
self
.
force_input_to_fp32
=
force_input_to_fp32
self
.
verbose
=
verbose
self
.
state_dict_on_rank_0_only
=
state_dict_on_rank_0_only
self
.
_cpu_offload
=
offload_config
.
get
(
'device'
,
None
)
==
'cpu'
if
offload_config
else
False
# We find if gradient_predivide_factor != 1.0, there may be wrong precision problem
# So we use 1.0 as the default gradient_predivide_factor
# However, if you set gradient_predivide_factor to None
# we will set gradient_predivide_factor to a value >= 1.0 automatically
self
.
gradient_predivide_factor
:
float
=
gradient_predivide_factor
if
\
gradient_predivide_factor
is
not
None
else
\
get_gradient_predivide_factor
(
self
.
world_size
)
self
.
gradient_postdivide_factor
:
float
=
self
.
world_size
/
self
.
gradient_predivide_factor
self
.
_check_sanity
()
self
.
params
:
List
[
Parameter
]
=
[]
for
name
,
param
in
module
.
named_parameters
():
if
not
hasattr
(
param
,
'zero_is_sharded'
):
self
.
params
.
append
(
param
)
self
.
module
=
module
self
.
param_manager
=
Zero3ParameterManager
(
module
,
process_group
=
self
.
process_group
,
mixed_precision
=
self
.
mixed_precision
,
flatten_parameters
=
flatten_parameters
,
compute_dtype
=
self
.
compute_dtype
,
compute_device
=
self
.
compute_device
,
offload_config
=
offload_config
)
self
.
_reset_lazy_init_info
()
# Flag to indicate if we require gradient reduction in the backward
# pass. This will be False when inside the no_sync context manager.
self
.
_require_backward_grad_sync
:
bool
=
True
# Enum to indicate if we're in the forward/backward pass, idle, etc.
self
.
training_state
=
TrainingState
.
IDLE
# Register hook after state_dict() to remove the "_zero3_module."
# prefix and before load_state_dict() to add it back.
self
.
_register_state_dict_hook
(
functools
.
partial
(
_post_state_dict_hook
,
self
.
state_dict_on_rank_0_only
))
self
.
_register_load_state_dict_pre_hook
(
_pre_load_state_dict_hook
)
# Flag to indicate whether state_dict() should automatically gather the full params.
self
.
_return_full_state_dict
=
True
# Flag to guard against preparing gradients multiple times per iteration.
# This is reset at the end of the backward pass.
self
.
_pre_backward_hook_has_run
=
False
def
forward
(
self
,
*
args
:
Any
,
**
kwargs
:
Any
)
->
torch
.
Tensor
:
self
.
_lazy_init
()
# Start of a forward pass.
self
.
training_state
=
TrainingState
.
FORWARD
# For root and mixed precision, we convert the input to FP16 (no_grad is needed for
# the conversion).
if
self
.
_is_root
and
self
.
mixed_precision
:
args
,
kwargs
=
cast_float_arguments
(
cast_tensor_to_fp16
,
*
args
,
**
kwargs
)
# If enabled, convert the input to FP32 if we are in full precision.
# no_grad is not used because the input might be for a non-root instance,
# which mean autograd needs to go through the conversion.
if
self
.
force_input_to_fp32
and
not
self
.
mixed_precision
:
args
,
kwargs
=
cast_float_arguments
(
cast_tensor_to_fp32
,
*
args
,
**
kwargs
)
# All-gather full parameters. This will also transfer FP32 parameters to
# ``self.compute_dtype`` (e.g., FP16 if *mixed_precision* is ``True``).
self
.
param_manager
.
rebuild_full_params
()
# Register backward hooks to reshard params and reduce-scatter grads.
# These need to be re-registered every forward pass.
self
.
_register_post_backward_hooks
()
outputs
=
self
.
module
(
*
args
,
**
kwargs
)
if
self
.
reshard_after_forward
:
self
.
param_manager
.
free_full_params
()
if
self
.
mixed_precision
or
self
.
_cpu_offload
:
self
.
param_manager
.
free_fp16_shards
()
# Switch to main FP32 param shard. We maintain this invariant throughout
# the code, i.e., ``p.data == p.zero_fp32_shard`` after each function. This
# also ensures that after the first forward, the optimizer state will be
# initialized with the correct dtype and (sharded) size, since optimizer
# state is typically initialized lazily in ``optim.step()``.
self
.
param_manager
.
use_fp32_shards
()
# Register pre-backward hooks to all-gather the params for the backward
# pass (if output's grad was needed). This won't register anything if
# we are in eval mode.
#
# Some model does forward pass multiple times, we need to register the
# pre-backward hook on every output since the last output's hook has to
# fire first to setup for backward. However, we use ``self._pre_backward_hook_has_run``
# to prevent repeated overhead from multiple hook callbacks.
outputs
=
self
.
_register_pre_backward_hooks
(
outputs
)
# Done with a forward pass.
self
.
training_state
=
TrainingState
.
IDLE
# Only need to clear cache during forward. During backward, the cache is not used.
if
self
.
clear_autocast_cache
:
torch
.
clear_autocast_cache
()
return
outputs
def
_check_sanity
(
self
)
->
None
:
if
self
.
fp32_reduce_scatter
and
not
self
.
mixed_precision
:
raise
ValueError
(
"fp32_reduce_scatter requires mixed_precision=True"
)
if
self
.
compute_device
.
type
==
'cuda'
:
input_tensor
=
torch
.
ones
(
1
).
to
(
self
.
compute_device
)
output
=
list
(
torch
.
zeros
(
self
.
world_size
).
to
(
self
.
compute_device
).
chunk
(
self
.
world_size
))
dist
.
all_gather
(
output
,
input_tensor
,
group
=
self
.
process_group
)
assert
torch
.
cat
(
output
).
sum
()
==
float
(
self
.
world_size
),
(
f
"found
{
torch
.
cat
(
output
).
sum
()
}
devices in process group but "
f
"world_size=
{
self
.
world_size
}
. Check torch.cuda.set_device is called properly"
)
def
_reset_lazy_init_info
(
self
)
->
None
:
self
.
_is_root
:
Optional
[
bool
]
=
None
self
.
_streams
:
Dict
[
str
,
torch
.
cuda
.
Stream
]
=
{}
self
.
_reducer
:
Optional
[
ReduceScatterBucketer
]
=
None
self
.
param_manager
.
delete_fp32_shards
()
self
.
_output_pre_backward_hook_registered
:
Optional
[
List
]
=
None
self
.
reshard_after_forward
=
self
.
_orig_reshard_after_forward
def
_lazy_init
(
self
):
# Initialize param attributes lazily, in case the param's dtype or
# device changes after __init__.
for
p
in
self
.
params
:
self
.
param_manager
.
reset_param_attr
(
p
,
self
.
training
)
# Initialize _is_root and setup streams. These steps would ideally
# happen in __init__, but _is_root can only be determined after the
# entire model hierarchy is setup, thus we run it lazily.
if
self
.
_is_root
is
None
:
self
.
_set_is_root
()
self
.
_setup_streams
()
self
.
_setup_output_hook_list
()
if
self
.
_is_root
:
# Buffers stay on GPU, and don't get sharded. Since _cast_buffers
# applies recursively, we only call this from the root instance.
self
.
_cast_buffers
()
if
self
.
disable_reshard_on_root
:
# Don't free the full params for the outer-most (root) instance,
# since those params will be needed immediately after for the
# backward pass.
self
.
reshard_after_forward
=
False
# Due to the use of streams, we need to make sure the previous
# ``optim.step()`` is done before we all-gather parameters.
self
.
_wait_for_previous_optim_step
()
def
_set_is_root
(
self
)
->
None
:
"""If ``True``, implies that no other :class:`ShardedModel`
instance wraps this one. Called once by :func:`_lazy_init`.
Also sets self.children_share_process_group = True if all child
instances share the same process group. If some child instances use a
different process group, self.clip_grad_norm_ will raise an error.
"""
if
self
.
_is_root
is
not
None
:
return
# No Zero3Model instance wraps this, else _is_root would be set to False.
self
.
_is_root
=
True
# If final backward callback is never been queued, state should be IDLE.
# If final backward callback is queued, the callback should be finished
# and the state was reset to be IDLE.
# This should be asserted at the beginning of forward pass in the root instance only.
# For children instances, if they are checkpointed, state will not be reset to
# IDLE after each inner forward/backward.
self
.
_assert_state
(
TrainingState
.
IDLE
)
# As the root, we now set all children instances to False and
# give them a closure to try to queue a wait_for_post_backward.
self
.
children_share_process_group
=
True
for
n
,
m
in
self
.
named_modules
():
# `n != ""` excludes self.
if
n
!=
''
and
isinstance
(
m
,
ShardedModel
):
# We relax the assert for non-root instance, when the nested inialized module is wrapped
# again in ShardedModel later, for example after training to run inference.
assert
m
.
_is_root
is
None
or
not
m
.
_is_root
if
m
.
_is_root
is
None
:
m
.
_is_root
=
False
if
m
.
process_group
!=
self
.
process_group
:
self
.
children_share_process_group
=
False
# if child instance in its own (smaller) world, that was probably an attempt to avoid OOM.
# Therefore gathering this child's optim state will probably cause OOM, so we won't do it.
m
.
no_broadcast_optim_state
=
m
.
no_broadcast_optim_state
or
\
((
m
.
world_size
==
1
)
and
(
m
.
world_size
<
self
.
world_size
)
and
(
m
.
process_group
!=
self
.
process_group
))
def
_setup_streams
(
self
)
->
None
:
"""Create streams to overlap data transfer and computation."""
if
len
(
self
.
_streams
)
>
0
or
not
self
.
_is_root
:
return
if
torch
.
cuda
.
is_available
():
# Stream to move main FP32 params (may be on CPU) to FP16 for forward.
self
.
_streams
[
'fp32_to_fp16'
]
=
torch
.
cuda
.
Stream
()
# Stream for all-gathering parameters.
self
.
_streams
[
'all_gather'
]
=
torch
.
cuda
.
Stream
()
# Stream for overlapping grad reduction with the backward pass.
self
.
_streams
[
'post_backward'
]
=
torch
.
cuda
.
Stream
()
self
.
param_manager
.
setup_streams
(
self
.
_streams
)
# Helper for bucketing reduce-scatter ops. This is also shared with
# children instances to improve bucket utilization.
self
.
_reducer
=
ReduceScatterBucketer
(
self
.
reduce_scatter_bucket_size_mb
)
# We share streams with all children instances, which allows them to
# overlap transfers across the forward pass without synchronizing with
# the default stream.
for
n
,
m
in
self
.
named_modules
():
if
n
!=
""
and
isinstance
(
m
,
ShardedModel
):
m
.
_streams
=
self
.
_streams
m
.
_reducer
=
self
.
_reducer
m
.
param_manager
.
setup_streams
(
self
.
_streams
)
def
_setup_output_hook_list
(
self
)
->
None
:
"""set up a list to avoid registering pre-backward hooks
incorrectly.
"""
assert
self
.
_is_root
,
"This should only be called on the root"
self
.
_output_pre_backward_hook_registered
=
[]
for
n
,
m
in
self
.
named_modules
():
if
n
!=
""
and
isinstance
(
m
,
ShardedModel
):
m
.
_output_pre_backward_hook_registered
=
self
.
_output_pre_backward_hook_registered
def
_wait_for_previous_optim_step
(
self
)
->
None
:
"""
The outer-most :class:`ShardedModel` instance (i.e., the root
instance) needs to synchronize with the default stream to ensure the
previous optimizer step is done.
"""
if
not
torch
.
cuda
.
is_available
():
return
if
self
.
mixed_precision
or
self
.
_cpu_offload
:
self
.
_streams
[
"fp32_to_fp16"
].
wait_stream
(
torch
.
cuda
.
current_stream
())
else
:
self
.
_streams
[
"all_gather"
].
wait_stream
(
torch
.
cuda
.
current_stream
())
def
_cast_buffers
(
self
,
device
:
Optional
[
torch
.
device
]
=
None
,
dtype
:
Optional
[
torch
.
dtype
]
=
None
,
memo
:
Optional
[
Set
]
=
None
)
->
None
:
"""Move all buffers to the given *device* and *dtype*.
If *device* or *dtype* are not given, then they will default to
``self.compute_device`` and ``self.buffer_dtype``, respectively. In the
case of nested ShardedModel instances, we will respect the child instance's
``compute_device`` and ``buffer_dtype`` configuration.
Args:
device (torch.device, Optional):
device to cast buffers to (defaults to compute_device)
dtype (torch.dtype, Optional):
dtype to cast buffers to (defaults to buffer_dtype)
memo (Set, Optional):
set of modules that have already been processed
"""
if
memo
is
None
:
memo
=
set
()
for
module
in
self
.
modules
():
if
module
is
not
self
and
isinstance
(
module
,
ShardedModel
):
# Allow any child Zero3Model instances to handle their own buffers.
module
.
_cast_buffers
(
device
=
device
,
dtype
=
dtype
,
memo
=
memo
)
elif
module
not
in
memo
:
memo
.
add
(
module
)
for
name
,
buf
in
module
.
named_buffers
(
recurse
=
False
):
if
buf
is
None
:
continue
buf
=
buf
.
to
(
device
=
device
or
self
.
compute_device
)
if
torch
.
is_floating_point
(
buf
):
buf
=
buf
.
to
(
dtype
=
dtype
or
self
.
buffer_dtype
)
setattr
(
module
,
name
,
buf
)
@
torch
.
no_grad
()
def
_prep_grads_for_backward
(
self
)
->
None
:
"""Make sure p.grad is correctly prepared for the backward with
right shape, device, accumulated values, etc.
"""
for
p
in
self
.
params
:
if
p
.
grad
is
not
None
:
if
p
.
grad
.
device
!=
p
.
data
.
device
:
p
.
grad
=
None
elif
p
.
grad
.
size
()
==
p
.
zero_orig_size
:
if
not
p
.
zero_is_sharded
:
p
.
zero_saved_grad
=
p
.
grad
.
data
p
.
grad
=
None
else
:
# This is gradient accumulation with no_sync context.
pass
elif
p
.
grad
.
size
()
==
p
.
zero_fp32_shard
.
shape
:
# This is gradient accumulation without no_sync context.
# We save the grad shard and set p.grad to None for this backward pass.
# We will accumulate after this pass's grad is generated and reduced and
# sharded.
p
.
zero_saved_grad_shard
=
p
.
grad
.
data
p
.
grad
=
None
else
:
raise
AssertionError
(
f
"unexpected grad shape:
{
p
.
grad
.
size
()
}
"
)
def
_register_pre_backward_hooks
(
self
,
outputs
:
Any
)
->
Any
:
"""Register pre-backward hook to run before the wrapped module's
backward. Hooks should be attached to all outputs from the forward.
Returns:
outputs: new outputs with hooks registered if they requires gradient.
"""
if
not
torch
.
is_grad_enabled
():
return
outputs
# don't register hooks if grad isn't enabled
if
self
.
_is_root
:
# This actually means that only root instance has
# _post_backward_callback_queued defined. Accidentally accessing this field
# will assert on all other instances, giving us a nice bug checker.
self
.
_post_backward_callback_queued
=
False
def
_pre_backward_hook
(
*
unused
:
Any
)
->
None
:
# try to queue final backward callback only once for root, so
# that final backward callback is attached to the outer most
# backward graph task and called after all the backward
# calls are completed.
if
self
.
_is_root
:
self
.
_register_final_backward_hook
()
# All-gather full parameters or switching to the full params.
#
# This needs to be done on every pre_backward hook, even within the same
# iteration (i.e. for checkpointed, multiple forward pass modules). This is
# because after the forward pass (i.e. in checkpoint inner graph), we always
# switch to fp32_shard in the ``forward`` function.
#
# We used to do this only after the ``self._pre_backward_hook_has_run``
# boolean guard below, which is incorrect. It worked in pytorch < 1.9 for
# some unknown reason, but pytorch 1.10 nightly exposed this bug.
#
# Note, both ``self.param_manager.rebuild_full_params`` and ``self.param_manager.use_full_params`` are
# idempotent. So in case they are called unnecessarily, they don't incur much
# overhead.
if
self
.
reshard_after_forward
:
self
.
param_manager
.
rebuild_full_params
()
else
:
self
.
param_manager
.
use_full_params
()
# Only run the ``self._prep_grads_for_backward`` once per iteration (i.e. in case
# it is multiple outputs or multiple forward passes).
if
not
self
.
_pre_backward_hook_has_run
:
self
.
_pre_backward_hook_has_run
=
True
# Start of a backward pass for the first time in an iteration.
self
.
_assert_state
([
TrainingState
.
IDLE
,
TrainingState
.
PRE_BACKWARD
])
# Prepare p.grad so that it is in the right shape, device, accumulated values, etc.
self
.
_prep_grads_for_backward
()
# Transition to PRE_BACKWARD state if currently IDLE. We can transition from POST_BACKWARD
# to IDLE when ShardedModel is within activation checkpointing and called multiple times, due to the
# extra forward pass for re-computation.
if
self
.
training_state
==
TrainingState
.
IDLE
:
self
.
training_state
=
TrainingState
.
PRE_BACKWARD
self
.
_assert_state
([
TrainingState
.
PRE_BACKWARD
,
TrainingState
.
POST_BACKWARD
])
_registered
=
0
def
_register_hook
(
t
:
torch
.
Tensor
)
->
torch
.
Tensor
:
# We don't register the pre_backward hook on the same tensor that has been
# returned from an inner ShardedModel, unless it is the first one. This does
# not cover all problematic cases though. A tensor not from an inner
# ShardedModel can cause problems too:
# ```
# x = layer1(input)
# state = [x] # better change to x.detach(), not fixed by the following if-condition
# x = inner_zero3_module_layer2(x)
# state.append(x) # better change to x.detach(), but fixed by the following if-condition
# x = layer3(x)
# return x, state
# ```
# The tensors in `state`, if not detached, can be registered with
# backward hooks (in addition to the `x` on the last line). In that case,
# pre-backward hook can fire multiple times in the order that causes
# the outer ShardedModel to crash.
#
# The best practice is for modules to be wrapped by ShardedModel to return 1 and only
# 1 tensor to be used for backward. All other tensors returned should be
# detached.
nonlocal
_registered
assert
self
.
_output_pre_backward_hook_registered
is
not
None
if
t
.
requires_grad
and
(
_registered
==
0
or
id
(
t
)
not
in
self
.
_output_pre_backward_hook_registered
):
t
.
register_hook
(
_pre_backward_hook
)
self
.
_output_pre_backward_hook_registered
.
append
(
id
(
t
))
_registered
+=
1
return
t
# Attach hooks to Tensor outputs.
outputs
=
apply_to_tensors
(
outputs
,
_register_hook
)
return
outputs
def
_register_post_backward_hooks
(
self
)
->
None
:
"""
Register backward hooks to reshard params and reduce-scatter grads.
This is called during forward pass. The goal is to attach a hook
on each of the parameter's gradient generating function (``grad_acc``
below) so that the hook is called *after* all gradients for that
param are computed.
Goals:
1. We want the hook to fire once and only once *after* all gradients
are accumulated for a param.
2. If it fires more than once, we end up incorrectly shard the grad
multiple times. (could lead to dimension too small)
3. If it fires once but too early or doesn't fire, we leave gradients
unsharded. (could lead to dimension too large)
Due to multiple-pass forward, this function can be called on
the same parameter multiple times in a single forward pass. If we register
the hook multiple time, we end up getting called multiple times. We
could try to get a new hook every time and delete the previous one
registered. However, due to *unknown reason* (I have debugged it for
a long time!), in mixed precision mode, we get two different ``grad_acc``
objects below during different calls of this function (in the same
forward pass). If we keep the last one, the hook end up firing too
early. In full precision mode, we luckily get the *same* ``grad_acc``
object, so deleting and re-registering still ensured the hook fire
once after all gradients are generated. However, we find if we use activation
checkpoint in mixed precision mode, hook on ``grad_acc`` object won't be
fire for *unknown reason*. So we finally register hook on parameter directly.
Empirically, keep the first hook register per forward pass seems to
work the best. We do need to remove the hook at the end of the
backward pass. Otherwise, the next forward pass will not register
a new hook, which is needed for a new forward pass.
"""
if
not
torch
.
is_grad_enabled
():
return
# don't register grad hooks if grad isn't enabled
for
p
in
self
.
params
:
if
p
.
requires_grad
:
if
hasattr
(
p
,
"zero_shard_bwd_hook"
):
continue
# For mixed precision with activation checkpoint, hooks on GradAccumulation won't be fired normally
# Instead we register hook on parameter
# In this way, we can't modify param.grad and param.data directly, which leads to more memory usage
# Register a hook on the first call, empirically, autograd
# fires it at the end for this param, which makes sense.
# p_tmp = p.expand_as(p) # Get a grad_fn on p_tmp.
# assert p_tmp.grad_fn is not None
# grad_acc = p_tmp.grad_fn.next_functions[0][0] # Gets its GradAccumulation object.
# handle = grad_acc.register_hook(functools.partial(self._post_backward_hook, p))
# p.zero_shard_bwd_hook = (grad_acc, handle)
handle
=
p
.
register_hook
(
functools
.
partial
(
self
.
_post_backward_hook
,
p
))
p
.
zero_shard_bwd_hook
=
handle
@
torch
.
no_grad
()
def
_post_backward_hook
(
self
,
param
:
Parameter
,
grad
:
torch
.
Tensor
)
->
Optional
[
torch
.
Tensor
]:
"""
At the start of :func:`_post_backward_hook`, ``param.grad`` contains the
full gradient for the local batch. The reduce-scatter op will replace
``param.grad`` with a single shard of the summed gradient across all
GPUs. This shard will align with the current GPU rank. For example::
before reduce_scatter:
param.grad (GPU #0): [1, 2, 3, 4]
param.grad (GPU #1): [5, 6, 7, 8]
after reduce_scatter:
param.grad (GPU #0): [6, 8] # 1+5, 2+6
param.grad (GPU #1): [10, 12] # 3+7, 4+8
The local GPU's ``optim.step`` is responsible for updating a single
shard of params, also corresponding to the current GPU's rank. This
alignment is created by `param_manager`, which ensures that
the local optimizer only sees the relevant parameter shard.
"""
        # First hook callback will see PRE state. If we have multiple params,
        # then subsequent hook callbacks will see POST state.
        self._assert_state([TrainingState.PRE_BACKWARD, TrainingState.POST_BACKWARD])
        self.training_state = TrainingState.POST_BACKWARD
        if grad is None:
            return
        assert grad is not None, param.shape
        if grad.requires_grad:
            raise RuntimeError("ShardedModel only works with gradients that don't require gradients")

        if self._require_backward_grad_sync or self.reshard_after_forward:
            # Free full params. As a special case, we don't free the full params
            # when in a ``no_sync`` context (as inversely indicated by
            # ``self._require_backward_grad_sync``), since the params will not
            # get updated before the next forward. This saves networking
            # bandwidth but uses more GPU memory.
            self.param_manager.free_full_params([param])

        if self.mixed_precision:
            # This is a no-op if reshard_after_forward is True, since we already
            # free the param shard when rebuilding the full params in the
            # pre_backward_hook.
            self.param_manager.free_fp16_shards([param])

        # Switch to FP32 shard after backward.
        # Cannot modify param.data, so we switch to FP32 in the final backward hook.
        # self.param_manager.use_fp32_shards([param])

        if not self._require_backward_grad_sync:
            return

        # Wait for all work in the current stream to finish, then start the
        # reductions in the post_backward stream.
        self._streams["post_backward"].wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(self._streams["post_backward"]):
            new_grad = grad.clone()
            if self.mixed_precision and self.fp32_reduce_scatter:
                # Cast grad to FP32.
                new_grad.data = new_grad.data.to(param.dtype)
            if self.gradient_predivide_factor > 1:
                # Average grad by world_size for consistency with PyTorch DDP.
                new_grad.data.div_(self.gradient_predivide_factor)
            orig_grad_data = new_grad.data
            if param.zero_is_sharded:
                assert self._reducer is not None
                # Save the unsharded grad for reduction. We will asynchronously accumulate the reduced gradient into
                # param.zero_saved_grad_shard. If this ShardedModel module was called multiple times, it's possible
                # that multiple gradient reductions will happen in an undefined order. But addition commutes, so this
                # order doesn't matter, neglecting rounding.
                # Clear grad on the tensor, so any repeated gradient computations do not interfere with this reduction.
                #
                # The effect on memory consumption is not usually significant. No extra memory is allocated if this
                # module is called only once, reduction happens quickly, or the tensor is bucketed. If the module is
                # called multiple times, and the backwards pass runs far enough ahead of the `post_backward` stream,
                # then we can end up with multiple unsharded gradients allocated and queued for reduction.
                #
                # We could guard against this by using CUDA events (see record_event, wait_event in torch.cuda.Stream).
                # This ensures the `default` stream will wait for the `post_backward` stream to complete the last
                # reduction for this module, before scheduling additional reduction work. Then at most there are two
                # unsharded gradients allocated; one for a pending reduction, and one for gradient computation.
                callback_fn = functools.partial(self._reduce_scatter_callback, param)
                grad_chunks = chunk_and_pad(orig_grad_data, self.reduce_scatter_process_group.size())
                self._reducer.reduce_scatter_async(grad_chunks,
                                                   group=self.reduce_scatter_process_group,
                                                   callback_fn=callback_fn)
            else:
                # Currently the only way for _is_sharded to be False is if
                # world_size == 1. This could be relaxed in the future, in which
                # case grads should be all-reduced here.
                assert self.world_size == 1
                self._reduce_scatter_callback(param, new_grad)

            # After _post_backward_hook returns, orig_grad_data will eventually
            # go out of scope, at which point it could otherwise be freed for
            # further reuse by the main stream while the div/reduce_scatter/copy
            # are underway in the post_backward stream. See:
            # github.com/NVIDIA/apex/blob/master/apex/parallel/distributed.py
            orig_grad_data.record_stream(self._streams["post_backward"])
    def _reduce_scatter_callback(self, param: Parameter, reduced_grad: torch.Tensor) -> None:
        """Hook to call on each param after the reduce-scatter."""
        assert torch.cuda.current_stream() == self._streams["post_backward"]
        self._assert_state(TrainingState.POST_BACKWARD)
        if self.gradient_postdivide_factor > 1:
            # Average grad by world_size for consistency with PyTorch DDP.
            reduced_grad.data.div_(self.gradient_postdivide_factor)
        # Cast grad to param's dtype (typically FP32). Note: we do this
        # before the cpu offload step so that this entire hook remains
        # non-blocking. The downside is a bit more D2H transfer in that case.
        if self.mixed_precision:
            orig_param_grad_data = reduced_grad.data
            reduced_grad.data = reduced_grad.data.to(dtype=param.zero_fp32_shard.dtype)
            # Don't let this memory get reused until after the transfer.
            orig_param_grad_data.record_stream(torch.cuda.current_stream())

        if param.zero_is_sharded:
            # Accumulate into the gradient shard.
            if getattr(param, "zero_saved_grad_shard", None) is None:
                param.zero_saved_grad_shard = reduced_grad.data
            else:
                assert (param.zero_saved_grad_shard.shape == reduced_grad.shape), \
                    f"{param.zero_saved_grad_shard.shape} vs {reduced_grad.shape}"
                param.zero_saved_grad_shard.data += reduced_grad.data
            reduced_grad = param.zero_saved_grad_shard.data
        else:
            # We can't modify the dtype of grad in this function,
            # so we use `param.zero_saved_grad` to store the gradient.
            # This is useful when using mixed precision mode on a single node.
            if getattr(param, 'zero_saved_grad', None) is None:
                param.zero_saved_grad = reduced_grad.data
            else:
                param.zero_saved_grad.data += reduced_grad.data

        # Optionally move gradients to CPU, typically used if one is running the optimizer on the CPU. Once the full
        # backwards pass completes, we will set `.grad` to the CPU copy.
        if self._cpu_offload:
            param.zero_cpu_grad.copy_(reduced_grad.data, non_blocking=True)
            # Don't let this memory get reused until after the transfer.
            reduced_grad.data.record_stream(torch.cuda.current_stream())
    def _register_final_backward_hook(self) -> None:
        """Try to queue a `_final_backward_hook` callback.

        Only called on the root, and only queues one callback at the beginning
        of the outermost backward.
        """
        assert self._is_root
        if not self._post_backward_callback_queued:
            self._assert_state([TrainingState.IDLE])
            self._post_backward_callback_queued = True
            Variable._execution_engine.queue_callback(self._final_backward_hook)
    @torch.no_grad()
    def _final_backward_hook(self) -> None:
        """Wait for post-backward to finish. Only called on root instance."""
        # Note: the backward runtime swallows assertion errors, so we use assert_in_engine() here.
        assert_in_engine(self._is_root, "FinalBackwardHook not called on root")
        # Check if the root module has params and if any of them has
        # the `requires_grad` field set. If `requires_grad=False` for
        # all the params, the post_backward hook will not fire and the
        # state will remain in `TrainingState.PRE_BACKWARD`.
        if any([p.requires_grad for p in self.params]):
            self._assert_state(TrainingState.POST_BACKWARD)
        else:
            self._assert_state(TrainingState.PRE_BACKWARD)
        self.param_manager.use_fp32_shards()
        if self._require_backward_grad_sync:
            # Flush any unreduced buckets in the post_backward stream.
            with torch.cuda.stream(self._streams["post_backward"]):
                assert_in_engine(self._reducer is not None, "FinalBackwardHook: reducer is None")
                assert self._reducer is not None    # make mypy happy
                self._reducer.flush()
            torch.cuda.current_stream().wait_stream(self._streams["post_backward"])
            if self._cpu_offload:
                # Wait for the non-blocking GPU -> CPU grad transfers to finish.
                torch.cuda.current_stream().synchronize()
        # A backward pass is done, clean up below.
        # Free reducer buffers.
        if self._reducer is not None:
            self._reducer.free()
        def _finalize_parameters(zero_module: ShardedModel) -> None:
            """Helper used below on all zero3 modules."""
            for p in zero_module.params:
                if not p.requires_grad:
                    continue
                if hasattr(p, "zero_shard_bwd_hook"):
                    p.zero_shard_bwd_hook.remove()
                    delattr(p, "zero_shard_bwd_hook")

                # Leave the gradient accumulation state as-is if not synchronizing this pass. This ensures p.grad
                # remains the unsharded gradient accumulated from prior no-sync passes, and p.zero_saved_grad_shard
                # remains the sharded gradient from the last synchronized pass. This also allows interleaved no-sync
                # and sync passes, if desired.
                if not self._require_backward_grad_sync:
                    continue

                # Parameter and gradient devices must match.
                if hasattr(p, "zero_cpu_grad"):
                    assert_in_engine(p.device == torch.device("cpu"),
                                     f"FinalBackwardHook: incorrect cpu_grad device {p.device}")
                    p.grad = p.zero_cpu_grad
                elif hasattr(p, "zero_saved_grad_shard"):
                    assert_in_engine(
                        p.device == p.zero_saved_grad_shard.device,
                        f"FinalBackwardHook: incorrect saved_grad_shard device "
                        f"{p.device} vs {p.zero_saved_grad_shard.device}",
                    )
                    p.grad = p.zero_saved_grad_shard
                elif hasattr(p, 'zero_saved_grad'):
                    p.grad = p.zero_saved_grad

                if hasattr(p, "zero_saved_grad_shard"):
                    delattr(p, "zero_saved_grad_shard")
                if hasattr(p, 'zero_saved_grad'):
                    delattr(p, "zero_saved_grad")
        # Update root and nested ShardedModel's hooks and flags.
        for m in self.modules():    # includes self
            if isinstance(m, ShardedModel):
                _finalize_parameters(m)
                m._pre_backward_hook_has_run = False
                if any(p.requires_grad for p in m.parameters()):
                    # Check if the module has params and if any of them has
                    # the `requires_grad` field set. If `requires_grad=False` for
                    # all the params, the post_backward hook will not fire and the
                    # state will remain in `TrainingState.PRE_BACKWARD`.
                    if any([p.requires_grad for p in m.params]):
                        m._assert_state(TrainingState.POST_BACKWARD)
                    else:
                        m._assert_state(TrainingState.PRE_BACKWARD)
                else:
                    # When `m` and its children have no params, or have params but
                    # none with `requires_grad==True`, there are two cases:
                    # 1. output tensors are `requires_grad==True`. In this case,
                    #    the pre-backward hook is still registered, so it is in PRE_BACKWARD state.
                    # 2. output tensors are `requires_grad==False`. In this case,
                    #    the pre-backward hook is not registered, so it is in IDLE state.
                    m._assert_state([TrainingState.PRE_BACKWARD, TrainingState.IDLE])
                m.training_state = TrainingState.IDLE

                if m._is_root:
                    # reset this flag for cases like "one forward pass + multiple backward passes"
                    self._post_backward_callback_queued = False
                    # clear this list for the next iteration
                    assert_in_engine(
                        self._output_pre_backward_hook_registered is not None,
                        "FinalBackwardHook: self._output_pre_backward_hook_registered should not be None",
                    )
                    assert self._output_pre_backward_hook_registered is not None    # make mypy happy
                    self._output_pre_backward_hook_registered.clear()
    @contextlib.contextmanager
    def gather_full_params(self, recurse: bool = True, volatile: bool = False) -> Generator:
        """
        A context manager to expose full params for the current ShardedModel instance.
        Can be useful *after* forward/backward for a model to get the params for
        additional processing or checking. Parameters will be gathered in full
        precision (e.g., FP32).

        .. note:: This can be used on inner ShardedModels.
        .. note:: This can *not* be used within a forward or backward pass. Nor
            can forward and backward be started from within this context.
        .. note:: The full parameters will be freed after the context manager
            exits; it is up to the caller to clone them if needed.
        .. note:: The full parameters can be modified, but only the portion
            corresponding to the local param shard will persist after the
            context manager exits (unless ``volatile=True``, in which case there
            are no guarantees about persistence).

        Args:
            recurse (bool, Optional): recursively summon all params for nested
                ShardedModel instances (default: True)
            volatile (bool, Optional): if ``True``, modifications to params are
                not guaranteed to persist after the context manager exits;
                enabling this can be slightly more efficient (default: False)
        """
        if recurse:
            with contextlib.ExitStack() as stack:
                # Summon all params for any nested ShardedModel instances.
                for module in self.modules():
                    if isinstance(module, ShardedModel):
                        stack.enter_context(module.gather_full_params(recurse=False, volatile=volatile))
                # Yield to the caller, with full params in all nested instances.
                yield
            # Exiting from the ExitStack will re-shard params.
            return
        else:
            torch.cuda.synchronize()
            self._lazy_init()
            self._assert_state(TrainingState.IDLE)
            # Set the state so that we assert when trying to go into
            # forward/backward.
            self.training_state = TrainingState.GATHER_FULL_PARAMS
            full_tensors = self.param_manager.rebuild_full_params(force_full_precision=True)
            assert full_tensors is not None
            with contextlib.ExitStack() as stack:
                try:
                    yield
                finally:
                    stack.close()
                    for p, (full_tensor, safe_to_free) in zip(self.params, full_tensors):
                        if not volatile:
                            # Copy any changes made to the full params back into
                            # the corresponding local shards.
                            local_shard, _ = get_shard(full_tensor)
                            p.zero_fp32_shard.copy_(local_shard.view_as(p.zero_fp32_shard))
                        if safe_to_free:
                            free_storage(full_tensor)
                    self.has_full_params = False
                    self.param_manager.use_fp32_shards()
                    self.training_state = TrainingState.IDLE
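    # A minimal usage sketch for ``gather_full_params`` (``sharded`` and the
    # surrounding training code are hypothetical, for illustration only):
    #
    #   sharded = ShardedModel(model)
    #   ...run forward/backward...
    #   with sharded.gather_full_params(recurse=True):
    #       # full, unsharded FP32 params are visible here
    #       for p in sharded.parameters():
    #           print(p.shape)    # original (unsharded) shapes
    #   # on exit, params are re-sharded; only local-shard edits persist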
    def apply(self, fn: Callable[[nn.Module], None]) -> "ShardedModel":
        """
        Applies ``fn`` recursively to every submodule (as returned by
        ``.children()``) as well as self. Typical use includes initializing the
        parameters of a model.

        Compared to ``torch.nn.Module.apply``, this version additionally gathers
        the full parameters before applying ``fn``. It should not be called from
        within another ``summon_full_params`` context.

        Args:
            fn (Callable[[nn.Module], None]): function to be applied to each submodule

        Returns:
            Module: self
        """
        is_uninitialized = self._is_root is None
        self._assert_state(TrainingState.IDLE)
        with self.gather_full_params(recurse=False):
            return_value = super().apply(fn)
        # summon_full_params will call _lazy_init, which sets _is_root. However,
        # apply() may be called directly on children instances to do weight
        # init, so we should reset the _is_root flag in this case.
        if is_uninitialized and self._is_root:
            for module in self.modules():
                if isinstance(module, ShardedModel):
                    module._reset_lazy_init_info()
        return return_value
    def __getattr__(self, name: str) -> Any:
        try:
            return super().__getattr__(name)
        except AttributeError:
            return getattr(self.module, name)
    def __getstate__(self) -> Dict[str, str]:
        """Serialize the state.

        Some properties are not serializable (e.g., process groups, streams), so
        we remove them and try to reconstruct them in :func:`__setstate__`.
        """
        state = copy.copy(self.__dict__)
        state["is_sharded"] = [p.zero_is_sharded for p in self.params]
        state["orig_sizes"] = [p.zero_orig_size for p in self.params]
        if state["process_group"] is not None:
            state["process_group"] = "MISSING"    # process_group isn't pickleable
        if state["process_group_reduce_scatter"] is not None:
            state["process_group_reduce_scatter"] = "MISSING"    # process_group_reduce_scatter isn't pickleable
        self._reset_lazy_init_info()
        return state
    def __setstate__(self, state: Dict[str, Any]) -> None:
        """Intercept state setting and perform needed changes on params."""
        super().__setstate__(state)

        def fixup(p: Parameter, is_sharded: bool, size: torch.Size) -> Parameter:
            assert isinstance(p, Parameter)
            p.data = p.data.clone()    # move tensors out of shared memory
            p.zero_is_sharded = is_sharded
            p.zero_orig_size = size
            return p

        self.params = [
            fixup(p, is_sharded, size) for p, is_sharded, size in zip(self.params, self.is_sharded, self.orig_sizes)
        ]
        del self.is_sharded
        del self.orig_sizes
        self._reset_lazy_init_info()
    def __getitem__(self, key: int) -> Any:
        """Forward indexing calls in case the module is an nn.Sequential."""
        return self.module.__getitem__(key)
    @contextlib.contextmanager
    def no_sync(self) -> Generator:
        """
        A context manager to disable gradient synchronizations across ShardedModel
        processes. Within this context, gradients will be accumulated on module
        variables, which will later be synchronized in the first
        forward-backward pass after exiting the context.

        .. note:: This likely results in higher memory usage because ShardedModel will
            accumulate the full model gradients (instead of gradient shards)
            until the eventual sync.
        .. note:: Gradient accumulation can be done without this context,
            avoiding the extra GPU memory overhead, but with extra
            networking overhead.
        """
        self._lazy_init()
        assert self._is_root, "no_sync on inner ShardedModel is not supported"
        self._assert_state(TrainingState.IDLE)
        # This instance may wrap other ShardedModel instances and we
        # need to set all of them to accumulate gradients.
        old_flags = []
        for m in self.modules():    # includes self
            if isinstance(m, ShardedModel):
                old_flags.append((m, m._require_backward_grad_sync))
                m._require_backward_grad_sync = False
        try:
            yield
        finally:
            for m, old_flag in old_flags:
                assert m._require_backward_grad_sync is False
                m._require_backward_grad_sync = old_flag
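    # A minimal gradient-accumulation sketch with ``no_sync`` (``sharded``,
    # ``x1`` and ``x2`` are hypothetical inputs, for illustration only):
    #
    #   with sharded.no_sync():
    #       sharded(x1).sum().backward()    # grads accumulate locally, no reduce-scatter
    #   sharded(x2).sum().backward()        # this pass synchronizes the accumulated grads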
    def _assert_state(self, state: Union[TrainingState, List[TrainingState]]) -> None:
        """Assert we are in the given state."""
        # Since assert can be turned off and this error checking
        # is really important, we use explicit error checking
        # and raise a ValueError if needed.
        if isinstance(state, TrainingState):
            state = [state]
        if self.training_state not in state:
            msg = f"expected to be in states {state} but current state " \
                  f"is {self.training_state}"
            # In case we are failing in the context of an autograd hook, asserting
            # may not generate a useful msg. So, let's print it to be sure.
            self.logger.error(f'Zero3 instance {self} got error: {msg}', ranks=[0])
            if self.rank == 0:
                traceback.print_stack()
            raise ValueError(msg)
    def extra_repr(self) -> str:
        repr = (f"world_size={self.world_size}, "
                f"mixed_precision={self.mixed_precision}, ")
        if self.verbose:
            repr = (f"rank={self.rank}, " + repr +
                    f"reshard_after_forward={self.reshard_after_forward}, "
                    f"compute_dtype={self.compute_dtype}, "
                    f"buffer_dtype={self.buffer_dtype}, "
                    f"fp32_reduce_scatter={self.fp32_reduce_scatter}, "
                    f"compute_device={self.compute_device}, "
                    f"reduce_scatter_bucket_size_mb={self.reduce_scatter_bucket_size_mb}, "
                    f"clear_autocast_cache={self.clear_autocast_cache}, "
                    f"force_input_to_fp32={self.force_input_to_fp32}, "
                    f"offload_config={self.offload_config}")
        return repr
    def state_dict(self, destination=None, prefix='', keep_vars=False):
        """
        Returns the whole (unsharded) state of the module. Parameters are not
        sharded, so the resulting state_dict can be loaded directly by the
        wrapped Module without any sharding-specific logic. Returned tensors
        will be full precision (e.g., FP32).

        .. warning:: This needs to be called on all ranks, since synchronization
            primitives will be used.
        """
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        self._lazy_init()

        def maybe_cast_buffers(dtype: Optional[torch.dtype] = None) -> None:
            if self.mixed_precision:
                self._cast_buffers(dtype=dtype)

        assert self._return_full_state_dict is True, 'Only the full state dict is supported now'
        if self.training_state != TrainingState.GATHER_FULL_PARAMS:
            with self.gather_full_params(recurse=False, volatile=True):
                maybe_cast_buffers(torch.float32)
                state_dict = super().state_dict()
        else:
            maybe_cast_buffers(torch.float32)
            state_dict = super().state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)

        if self._cpu_offload:
            for k, tensor in state_dict.items():
                state_dict[k] = tensor.cpu()

        # In case we are in mixed precision, restore buffers back to buffer_dtype.
        maybe_cast_buffers()
        return state_dict
    def load_state_dict(self,
                        state_dict: Union[Dict[str, torch.Tensor], "OrderedDict[str, torch.Tensor]"],
                        strict: bool = True) -> NamedTuple:
        """
        Load a whole (unsharded) state_dict.

        .. warning:: This needs to be called on all ranks, since synchronization
            primitives will be used.
        """
        if self._return_full_state_dict:
            with self.gather_full_params():
                return self.module.load_state_dict(state_dict, strict)
        else:
            torch.cuda.synchronize()
            self._lazy_init()
            return self.module.load_state_dict(state_dict, strict)
def _post_state_dict_hook(
    state_dict_on_rank_0_only: bool,
    module: Zero3ParameterManager,
    state_dict: "OrderedDict[str, torch.Tensor]",
    prefix: str,
    *args: Any,
) -> "OrderedDict[str, torch.Tensor]":
    # When state_dict_on_rank_0_only is ``True``, ``model.state_dict()`` will only
    # return the full state dict on rank 0 and return an empty dict on non-rank-0,
    # which allows ShardedModel to skip the GPU -> CPU copy on
    # non-rank-0 altogether and prevents OOM.
    if state_dict_on_rank_0_only and dist.get_rank() != 0:
        state_dict.clear()
        return state_dict
    # Assuming we are in a ``gather_full_params()`` context, we need to clone
    # each tensor so that it does not get freed (in-place) when the context
    # exits. At the same time, this hook can be called multiple times
    # recursively, so we need to make sure that we only clone each tensor at
    # most once. Thus we add an attribute on the tensor called "_has_been_cloned"
    # which keeps track of tensors that are no longer at risk of being freed.
    for key in state_dict.keys():
        if not key.startswith(prefix) or getattr(state_dict[key], "_has_been_cloned", False):
            continue
        if state_dict[key].device.type != module.state_dict_device.type:
            state_dict[key] = state_dict[key].to(device=module.state_dict_device)
            state_dict[key]._has_been_cloned = True
        elif module.training_state == TrainingState.GATHER_FULL_PARAMS:
            # We copy the state_dict since the full param will be freed after we
            # exit the ``summon_full_params()`` context.
            state_dict[key] = state_dict[key].clone()
            state_dict[key]._has_been_cloned = True

    # Remove "_zero3_module." prefix
    replace_state_dict_prefix(state_dict, prefix + "_zero3_module.", prefix)
    return state_dict
def _pre_load_state_dict_hook(state_dict: Union[Dict[str, torch.Tensor], "OrderedDict[str, torch.Tensor]"],
                              prefix: str, *args: Any) -> None:
    replace_state_dict_prefix(state_dict, prefix, prefix + "_zero3_module.")
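The reduce-scatter arithmetic described in the ``_post_backward_hook`` docstring above maps directly onto ``torch.distributed.reduce_scatter``. A minimal sketch, assuming a two-process group is already initialized; ``demo_reduce_scatter`` and its literal values are illustrative only, mirroring the docstring example:

import torch
import torch.distributed as dist

def demo_reduce_scatter(rank: int) -> torch.Tensor:
    # Full local gradient on each rank, as in the docstring example above.
    full_grad = {0: torch.tensor([1., 2., 3., 4.]),
                 1: torch.tensor([5., 6., 7., 8.])}[rank]
    # Chunk the full gradient into world_size pieces (cf. chunk_and_pad).
    chunks = list(full_grad.chunk(2))
    out = torch.empty(2)
    # Each rank receives only the summed chunk matching its rank:
    # rank 0 -> [6., 8.], rank 1 -> [10., 12.]
    dist.reduce_scatter(out, chunks, op=dist.ReduceOp.SUM)
    return out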
colossalai/zero/sharded_model/sharded_model_v2.py
View file @ 4d322b79
...
@@ -18,8 +18,8 @@ from colossalai.zero.sharded_model.reduce_scatter import ReduceScatterBucketer
 from torch.distributed import ProcessGroup
 from torch.nn.parameter import Parameter
-from ._zero3_utils import (cast_float_arguments, cast_tensor_to_fp16, cast_tensor_to_fp32, chunk_and_pad, free_storage,
-                           get_gradient_predivide_factor)
+from ._utils import (cast_float_arguments, cast_tensor_to_fp16, cast_tensor_to_fp32, chunk_and_pad, free_storage,
+                     get_gradient_predivide_factor)

 class ShardedModelV2(nn.Module):
...
colossalai/zero/sharded_optim/__init__.py
View file @ 4d322b79
-from .sharded_optim import ShardedOptimizer
 from .sharded_optim_v2 import ShardedOptimizerV2

-__all__ = ['ShardedOptimizer', 'ShardedOptimizerV2']
+__all__ = ['ShardedOptimizerV2']
colossalai/zero/sharded_optim/bookkeeping/__init__.py
deleted 100644 → 0
View file @ 6a3f9fda
from .gradient_store import GradientStore
from .parameter_store import ParameterStore
from .bucket_store import BucketStore
from .tensor_bucket import TensorBucket

__all__ = ['GradientStore', 'ParameterStore', 'BucketStore', 'TensorBucket']
\ No newline at end of file
colossalai/zero/sharded_optim/bookkeeping/base_store.py
deleted 100644 → 0
View file @ 6a3f9fda
from colossalai.core import global_context as gpc
from colossalai.context import ParallelMode


class BaseStore:

    def __init__(self, dp_parallel_mode=ParallelMode.DATA):
        self._world_size = gpc.get_world_size(dp_parallel_mode)
        self._local_rank = gpc.get_local_rank(dp_parallel_mode)

    @property
    def world_size(self):
        return self._world_size

    @property
    def local_rank(self):
        return self._local_rank
colossalai/zero/sharded_optim/bookkeeping/bucket_store.py
deleted 100644 → 0
View file @ 6a3f9fda
from colossalai.core import global_context as gpc
from colossalai.context import ParallelMode
from .base_store import BaseStore


class BucketStore(BaseStore):

    def __init__(self, dp_parallel_mode):
        super().__init__(dp_parallel_mode)
        self._grads = dict()
        self._params = dict()
        self._num_elements_in_bucket = dict()
        self.reset()

    def num_elements_in_bucket(self, reduce_rank: int = None):
        return self._num_elements_in_bucket[reduce_rank]

    def add_num_elements_in_bucket(self, num_elements, reduce_rank: int = None):
        self._num_elements_in_bucket[reduce_rank] += num_elements

    def add_grad(self, tensor, reduce_rank: int = None):
        self._grads[reduce_rank].append(tensor)

    def add_param(self, tensor, reduce_rank: int = None):
        self._params[reduce_rank].append(tensor)

    def reset(self):
        keys = [None] + list(range(self._world_size))
        self._grads = {rank: [] for rank in keys}
        self._params = {rank: [] for rank in keys}
        self._num_elements_in_bucket = {rank: 0 for rank in keys}

    def reset_by_rank(self, reduce_rank=None):
        self._grads[reduce_rank] = []
        self._params[reduce_rank] = []
        self._num_elements_in_bucket[reduce_rank] = 0

    def get_grad(self, reduce_rank: int = None):
        return self._grads[reduce_rank]

    def get_param(self, reduce_rank: int = None):
        return self._params[reduce_rank]
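For context, a minimal usage sketch of the ``BucketStore`` above, assuming ``colossalai.launch`` has already initialized the global context; this reflects the API as it stood before this commit removed it:

import torch
from colossalai.context import ParallelMode
from colossalai.zero.sharded_optim.bookkeeping import BucketStore    # pre-removal import path

store = BucketStore(ParallelMode.DATA)
grad = torch.randn(1024)
# Buckets are keyed by reduce_rank; None stands for the all-reduce bucket.
store.add_num_elements_in_bucket(grad.numel(), reduce_rank=None)
store.add_grad(grad, reduce_rank=None)
assert store.num_elements_in_bucket(reduce_rank=None) == 1024
store.reset_by_rank(reduce_rank=None)    # empty the bucket after reduction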
colossalai/zero/sharded_optim/bookkeeping/gradient_store.py
deleted 100644 → 0
View file @ 6a3f9fda
from typing import List

from torch import Tensor

from .base_store import BaseStore


class GradientStore(BaseStore):

    def __init__(self, *args):
        super().__init__(*args)
        # bookkeeping data structures
        self._averaged_gradients = dict()

        # for backward reduction hooks
        self._grad_acc_objs = []

    def add_accumulate_grad_object(self, obj):
        """
        Keep :class:`AccumulateGrad` objects. If these objects are not kept, reduction hooks may not
        be attached successfully.

        :param obj: An object of :class:`AccumulateGrad` class
        :type obj: :class:`AccumulateGrad`
        """
        self._grad_acc_objs.append(obj)

    def get_averaged_gradients_by_group(self, group_id: int) -> List[Tensor]:
        """
        Return the averaged gradients of a parameter group.

        :param group_id: The index of the parameter group
        :type group_id: int
        :return: The list of averaged gradients of a parameter group. Each element is a gradient, not a parameter.
        :rtype: List[torch.Tensor]
        """
        return self._averaged_gradients[group_id]

    def add_average_gradient_by_group(self, group_id: int, tensor: Tensor) -> None:
        """
        Append an averaged gradient to the list of averaged gradients of a parameter group.

        :param group_id: The index of a parameter group
        :param tensor: A :class:`torch.Tensor` object
        :type group_id: int
        :type tensor: torch.Tensor
        """
        if group_id in self._averaged_gradients:
            self._averaged_gradients[group_id].append(tensor)
        else:
            self._averaged_gradients[group_id] = [tensor]

    def reset_average_gradients_by_group(self, group_id: int) -> None:
        """
        Reset the bookkeeping data structure for averaged gradients to an empty list.

        :param group_id: The index of a parameter group
        :type group_id: int
        """
        self._averaged_gradients[group_id] = []
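A minimal sketch of how the averaged-gradient bookkeeping above was used (same initialization assumption as before; ``g0`` is illustrative):

import torch
from colossalai.context import ParallelMode
from colossalai.zero.sharded_optim.bookkeeping import GradientStore    # pre-removal import path

store = GradientStore(ParallelMode.DATA)
g0 = torch.ones(4)
store.add_average_gradient_by_group(group_id=0, tensor=g0)        # creates the list
store.add_average_gradient_by_group(group_id=0, tensor=2 * g0)    # appends to it
assert len(store.get_averaged_gradients_by_group(0)) == 2
store.reset_average_gradients_by_group(0)                         # back to []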
colossalai/zero/sharded_optim/bookkeeping/parameter_store.py
deleted 100644 → 0
View file @ 6a3f9fda
from .base_store import BaseStore
from torch import Tensor
from typing import List


class ParameterStore(BaseStore):

    def __init__(self, dp_parallel_mode):
        super().__init__(dp_parallel_mode)
        # param partitioning data structures
        self._fp16_param_to_rank = dict()
        self._rank_groupid_to_fp16_param_list = dict()
        self._rank_group_id_to_flat_fp16_param = dict()

        # param reduction data structures
        self._is_param_reduced = dict()
        self._reduced_param = []

    def set_param_to_rank(self, tensor: Tensor, rank: int) -> None:
        """
        Set the mapping between parameter and rank; each parameter should be owned by a rank.

        :param tensor: A :class:`torch.Tensor` object
        :type tensor: torch.Tensor
        :param rank: The rank of the process responsible for updating the parameter
        :type rank: int
        """
        self._fp16_param_to_rank[tensor] = rank

    def get_param_rank(self, tensor: Tensor) -> int:
        """
        Give the rank to which the parameter belongs.

        :param tensor: A :class:`torch.Tensor` object
        :type tensor: torch.Tensor
        """
        return self._fp16_param_to_rank[tensor]

    def belongs_to_current_rank(self, tensor) -> bool:
        """
        Check whether a parameter is supposed to be updated by the process of the current rank.

        :param tensor: A :class:`torch.Tensor` object
        :type tensor: torch.Tensor
        :return: True if the parameter should be updated by the current rank. Otherwise False.
        :rtype: bool
        """
        tensor_rank = self._fp16_param_to_rank[tensor]
        return tensor_rank == self._local_rank

    def add_fp16_param_list_by_rank_group(self, rank, group_id, tensor_list) -> None:
        if rank not in self._rank_groupid_to_fp16_param_list:
            self._rank_groupid_to_fp16_param_list[rank] = dict()

        if group_id not in self._rank_groupid_to_fp16_param_list[rank]:
            self._rank_groupid_to_fp16_param_list[rank][group_id] = []

        self._rank_groupid_to_fp16_param_list[rank][group_id].extend(tensor_list)

    def get_fp16_params_by_rank_group(self, rank, group_id) -> List[Tensor]:
        return self._rank_groupid_to_fp16_param_list[rank][group_id]

    def add_flat_fp16_param_by_rank_group(self, rank, group_id, tensor) -> None:
        if rank not in self._rank_group_id_to_flat_fp16_param:
            self._rank_group_id_to_flat_fp16_param[rank] = dict()

        self._rank_group_id_to_flat_fp16_param[rank][group_id] = tensor

    def get_flat_fp16_param_by_rank_group(self, rank, group_id) -> Tensor:
        return self._rank_group_id_to_flat_fp16_param[rank][group_id]

    def is_param_reduced(self, tensor):
        return self._is_param_reduced[tensor]

    def set_param_reduction_state(self, tensor, state):
        self._is_param_reduced[tensor] = state

    def get_param_reduction_states(self):
        return self._is_param_reduced

    def reset_previous_reduced_params(self):
        self._reduced_param = []

    def add_previous_reduced_param(self, tensor):
        self._reduced_param.append(tensor)

    def clear_grads_of_previous_reduced_params(self):
        if len(self._reduced_param) > 0:
            for param in self._reduced_param:
                param.grad = None
            self.reset_previous_reduced_params()
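A minimal sketch of the param-to-rank mapping above (same initialization assumption; ``p`` is an illustrative parameter):

import torch
from colossalai.context import ParallelMode
from colossalai.zero.sharded_optim.bookkeeping import ParameterStore    # pre-removal import path

store = ParameterStore(ParallelMode.DATA)
p = torch.nn.Parameter(torch.randn(8))
store.set_param_to_rank(p, rank=1)
assert store.get_param_rank(p) == 1
# True only on the process whose local data-parallel rank is 1:
store.belongs_to_current_rank(p)
store.set_param_reduction_state(p, False)    # mark as not yet reduced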
colossalai/zero/sharded_optim/bookkeeping/tensor_bucket.py
deleted 100644 → 0
View file @ 6a3f9fda
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors


class TensorBucket:

    def __init__(self, size):
        self._max_size = size
        self._current_size = 0
        self._bucket = []

    @property
    def max_size(self):
        return self._max_size

    @property
    def current_size(self):
        return self._current_size

    def is_full_or_oversized(self):
        return self._current_size >= self._max_size

    def is_empty(self):
        return len(self._bucket) == 0

    def add_to_bucket(self, tensor, allow_oversize=False):
        tensor_size = tensor.numel()

        if not allow_oversize and self.will_exceed_max_size(tensor_size):
            msg = f"The param bucket max size {self._max_size} is exceeded" \
                  + f" by tensor (size {tensor_size})"
            raise RuntimeError(msg)

        self._bucket.append(tensor)
        self._current_size += tensor_size

    def will_exceed_max_size(self, tensor_size):
        expected_size = self._current_size + tensor_size
        return expected_size > self._max_size

    def get_bucket(self):
        return self._bucket

    def empty(self):
        self._bucket = []
        self._current_size = 0    # was `self._size = 0`, which never reset the tracked size

    def flatten(self):
        return _flatten_dense_tensors(self._bucket)

    def unflatten_and_copy(self, flat_tensor):
        unflattened_tensor_list = _unflatten_dense_tensors(flat_tensor, self._bucket)
        for old, new in zip(self._bucket, unflattened_tensor_list):
            old.copy_(new)
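A minimal round-trip sketch of the flatten/unflatten pattern above (pure PyTorch, no distributed setup needed; the in-place ``mul_`` stands in for a reduction result):

import torch
from colossalai.zero.sharded_optim.bookkeeping import TensorBucket    # pre-removal import path

bucket = TensorBucket(size=1024)
a, b = torch.randn(4), torch.randn(6)
bucket.add_to_bucket(a)
bucket.add_to_bucket(b)
assert bucket.current_size == 10
flat = bucket.flatten()            # one contiguous tensor with 10 elements
flat.mul_(2)                       # e.g., an in-place reduction result
bucket.unflatten_and_copy(flat)    # writes the values back into a and b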
colossalai/zero/sharded_optim/sharded_optim.py
deleted 100644 → 0
View file @ 6a3f9fda
from colossalai.utils.cuda import get_current_device
import torch
import torch.distributed as dist
from colossalai.logging import get_dist_logger
from torch.optim import Optimizer
from .bookkeeping import ParameterStore, GradientStore, BucketStore, TensorBucket
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
from colossalai.nn.optimizer import ColossalaiOptimizer
from ._utils import (move_tensor, flatten, get_grad_accumulate_object, split_half_float_double, reduce_tensor,
                     release_param_grad, calculate_global_norm_from_list, compute_norm, sync_param, has_inf_or_nan)
from functools import partial


class ShardedOptimizer(ColossalaiOptimizer):

    def __init__(self,
                 optimizer: Optimizer,
                 initial_scale=2**32,
                 min_scale=1,
                 growth_factor=2,
                 backoff_factor=0.5,
                 growth_interval=1000,
                 hysteresis=2,
                 max_scale: int = 2**32,
                 clip_grad_norm=2.0,
                 verbose=False,
                 reduce_bucket_size=500000000,
                 communication_dtype=torch.float16,
                 overlap_communication=False,
                 partition_grad=False,
                 dp_parallel_mode=ParallelMode.DATA,
                 mp_parallel_mode=ParallelMode.MODEL,
                 cpu_offload=False,
                 cpu_fp16_param=False,
                 cpu_fp16_grad=False):
        # TODO: add support for
        # 1. fp16 master weights
        # 2. contiguous gradients
        # 3. cpu offload
        # 4. the case when some parameters have requires_grad = False
        self._optimizer = optimizer
        self._dtype = self._optimizer.param_groups[0]['params'][0].dtype
        self._logger = get_dist_logger()
        self._verbose = verbose

        # stage 2
        self._partition_grads = partition_grad

        # cpu offload
        self._cpu_offload = cpu_offload
        self._cpu_fp16_param = cpu_fp16_param
        self._cpu_fp16_grad = cpu_fp16_grad

        # get process groups
        self._dp_parallel_mode = dp_parallel_mode
        self._mp_parallel_mode = mp_parallel_mode
        self._local_rank = gpc.get_local_rank(dp_parallel_mode)
        self._world_size = gpc.get_world_size(dp_parallel_mode)

        self._dp_group = gpc.get_group(dp_parallel_mode)
        if gpc.is_initialized(mp_parallel_mode) and gpc.get_world_size(mp_parallel_mode) > 1:
            self._mp_group = gpc.get_group(mp_parallel_mode)
        else:
            self._mp_group = None

        # fp16 and fp32 params for mixed precision training
        self._fp16_param_groups = dict()
        self._fp32_flat_param_groups_of_current_rank = dict()

        # communication params
        self._overlap_communication = overlap_communication
        self._reduce_bucket_size = reduce_bucket_size
        self._communication_dtype = communication_dtype

        # gradient scaler
        self.grad_scaler = DynamicGradScaler(initial_scale=initial_scale,
                                             min_scale=min_scale,
                                             growth_factor=growth_factor,
                                             backoff_factor=backoff_factor,
                                             growth_interval=growth_interval,
                                             hysteresis=hysteresis,
                                             max_scale=max_scale,
                                             verbose=verbose)
        self._found_overflow = torch.FloatTensor([0]).to(get_current_device())

        # gradient clipping
        self._clip_grad_norm = clip_grad_norm

        # check argument conflicts
        self._sanity_checks()

        # ParameterStore will manage the tensor buffers used for zero;
        # it will not manage the tensors used by mixed precision training
        self._param_store = ParameterStore(self._dp_parallel_mode)
        self._grad_store = GradientStore(self._dp_parallel_mode)
        self._bucket_store = BucketStore(self._dp_parallel_mode)

        # iterate over the param groups in the optimizer,
        # partition these param groups for data parallel training,
        # and add buffers to the parameter store for future access
        for group_id, param_group in enumerate(self._optimizer.param_groups):
            params = param_group['params']

            # add the fp16 params to fp16_param_groups for bookkeeping
            self._fp16_param_groups[group_id] = params

            # assign parameters to ranks
            # the params in the list are sorted
            params_per_rank = self._partition_param_list(params)

            # store the mapping between param and rank
            # each param should belong to only one rank
            for rank, params in enumerate(params_per_rank):
                self._param_store.add_fp16_param_list_by_rank_group(rank, group_id, params)
                for param in params:
                    self._param_store.set_param_to_rank(param, rank)

            # move to cpu to make room to create the flat tensor
            move_tensor(params, device='cpu')

            # flatten the reordered tensors
            for rank in range(self._world_size):
                tensor_list = self._param_store.get_fp16_params_by_rank_group(rank, group_id)
                flat_tensor = flatten(tensor_list)
                flat_tensor = flat_tensor.cuda()
                self._param_store.add_flat_fp16_param_by_rank_group(rank, group_id, flat_tensor)

            # sync parameters
            for rank in range(self._world_size):
                flat_tensor = self._param_store.get_flat_fp16_param_by_rank_group(rank, group_id)
                tensor_list = self._param_store.get_fp16_params_by_rank_group(rank, group_id)
                sync_param(flat_tensor=flat_tensor, tensor_list=tensor_list)

            # create a copy of the fp32 weights of the parameters for which this rank is responsible
            fp16_flat_current_rank = self._param_store.get_flat_fp16_param_by_rank_group(self._local_rank, group_id)
            # when using cpu offload, our cpu adam supports fp16 parameters
            if self._cpu_fp16_param:
                fp32_flat_current_rank = fp16_flat_current_rank.detach()
            else:
                fp32_flat_current_rank = fp16_flat_current_rank.detach().float()
            device = 'cpu' if self._cpu_offload else get_current_device()
            fp32_flat_current_rank = fp32_flat_current_rank.to(device)
            fp32_flat_current_rank.requires_grad = True
            self._fp32_flat_param_groups_of_current_rank[group_id] = fp32_flat_current_rank

            # need to replace the params in the `params` field in the optimizer
            # so that when the optimizer calls step(), it only updates the tensors
            # managed by this data parallel rank
            param_group['params'] = [fp32_flat_current_rank]

            # set reduction state
            for param in self._fp16_param_groups[group_id]:
                self._param_store.set_param_reduction_state(param, False)

        # initialize the communication stream for
        # communication-computation overlapping
        if self._overlap_communication:
            self._comm_stream = torch.cuda.Stream()

        # the reduction hook is only used if overlapping communication
        # or stage 2 is used;
        # if it is stage 1 without overlapping, no hook will be attached
        if self._overlap_communication or self._partition_grads:
            self._attach_reduction_hook()

        self._initialize_optimizer_states()
    @property
    def loss_scale(self):
        return self.grad_scaler.scale

    @property
    def num_param_groups(self):
        return len(self._fp16_param_groups)

    def _partition_param_list(self, param_list):
        params_per_rank = [[] for _ in range(self._world_size)]
        numel_per_rank = [0 for _ in range(self._world_size)]

        # partition the parameters in a greedy fashion
        sorted_params = sorted(param_list, key=lambda x: x.numel(), reverse=True)
        for param in sorted_params:
            # allocate this parameter to the rank with
            # the smallest numel for load balancing purposes
            rank_to_go = numel_per_rank.index(min(numel_per_rank))
            params_per_rank[rank_to_go].append(param)
            numel_per_rank[rank_to_go] += param.numel()

        if self._verbose:
            self._logger.info(f'Number of elements on ranks: {numel_per_rank}',
                              ranks=[0],
                              parallel_mode=self._dp_parallel_mode)
        return params_per_rank

    def _initialize_optimizer_states(self):
        # create a dummy zero tensor which has the same shape as that of the param
        # and set this dummy zero tensor as the grad
        for group_id in range(len(self._fp32_flat_param_groups_of_current_rank)):
            fp32_partition_param = self._fp32_flat_param_groups_of_current_rank[group_id]
            fp32_partition_grad = torch.zeros_like(fp32_partition_param)
            fp32_partition_param.grad = fp32_partition_grad

        # update the parameters with zero gradients for initialization of optimizer states
        self._optimizer.step()

        # remove the grads of the parameters to save memory
        for group_id, fp32_flat_tensor in self._fp32_flat_param_groups_of_current_rank.items():
            fp32_flat_tensor.grad = None

    def _sanity_checks(self):
        assert torch.cuda.is_available(), 'CUDA is required'
        assert self._dtype == torch.float16, \
            f'Parameters are expected to be of type torch.float16, but got {self._dtype}'
    ###########################################################
    # Backward Reduction Hook
    ###########################################################

    def _attach_reduction_hook(self):
        # we iterate over the fp16 params;
        # on each param, we register a hook to its AccumulateGrad object
        for group_id in range(self.num_param_groups):
            param_group = self._fp16_param_groups[group_id]
            for param in param_group:
                if param.requires_grad:
                    # determine the reduction destination rank;
                    # this is only valid for stage 2:
                    # dst_rank = None means using all-reduce,
                    # else using reduce
                    if self._partition_grads:
                        reduce_rank = self._param_store.get_param_rank(param)
                    else:
                        reduce_rank = None

                    def _define_and_attach(param, reduce_rank):
                        # get the AccumulateGrad object of the param itself
                        accum_grad_obj = get_grad_accumulate_object(param)
                        self._grad_store.add_accumulate_grad_object(accum_grad_obj)

                        reduction_func = partial(self._reduce_and_remove_grads_by_bucket,
                                                 param=param,
                                                 reduce_rank=reduce_rank)

                        # define hook
                        # NOT IMPORTANT BUT GOOD TO KNOW:
                        # args here is not grad, but allow_unreachable and accumulate_grad
                        def reduce_grad_hook(*args):
                            reduction_func()

                        accum_grad_obj.register_hook(reduce_grad_hook)

                    _define_and_attach(param, reduce_rank)
    def _reduce_and_remove_grads_by_bucket(self, param, reduce_rank=None):
        param_size = param.numel()

        # check if the bucket is full;
        # if full, reduce the grads already in the bucket;
        # after reduction, the bucket will be empty
        if self._bucket_store.num_elements_in_bucket(reduce_rank) + param_size > self._reduce_bucket_size:
            self._reduce_grads_in_bucket(reduce_rank)

        # the param must not have been reduced yet, to ensure correctness
        is_param_reduced = self._param_store.is_param_reduced(param)
        if is_param_reduced:
            msg = f'Parameter of size ({param.size()}) has already been reduced, ' \
                  + 'duplicate reduction will lead to arithmetic incorrectness'
            raise RuntimeError(msg)

        # the param must have a grad for reduction
        assert param.grad is not None, f'Parameter of size ({param.size()}) has None grad, cannot be reduced'

        self._bucket_store.add_num_elements_in_bucket(param_size, reduce_rank)
        self._bucket_store.add_grad(param.grad, reduce_rank)
        self._bucket_store.add_param(param, reduce_rank)

    def _reduce_grads_in_bucket(self, reduce_rank=None):
        # reduce grads
        self._reduce_grads_by_rank(reduce_rank=reduce_rank,
                                   grads=self._bucket_store.get_grad(reduce_rank=reduce_rank),
                                   bucket_size=self._bucket_store.num_elements_in_bucket(reduce_rank))

        # use the communication stream if overlapping
        # communication with computation
        if self._overlap_communication:
            stream = self._comm_stream
        else:
            stream = torch.cuda.current_stream()

        with torch.cuda.stream(stream):
            params_in_bucket = self._bucket_store.get_param(reduce_rank=reduce_rank)

            for param in params_in_bucket:
                # the is_param_reduced flag should be False, showing that
                # this param was not reduced before calling self._reduce_grads_by_rank
                is_param_reduced = self._param_store.is_param_reduced(param)

                if is_param_reduced:
                    msg = f'Parameter of size ({param.size()}) has been reduced, ' + \
                          'duplicate reduction will lead to arithmetic incorrectness'
                    raise RuntimeError(msg)

                # update the flag
                self._param_store.set_param_reduction_state(param, True)

                # if partition_grads is True,
                # we do not keep the gradient after reduction
                if self._partition_grads and not self._param_store.belongs_to_current_rank(param):
                    if self._overlap_communication:
                        # we need to keep this gradient for now, as the reduction may
                        # not be completed yet since it is using a different cuda stream
                        self._param_store.add_previous_reduced_param(param)
                    else:
                        param.grad = None

        self._bucket_store.reset_by_rank(reduce_rank)
    def _reduce_grads_by_rank(self, reduce_rank, grads, bucket_size):
        grad_buckets_by_dtype = split_half_float_double(grads)

        for tensor_list in grad_buckets_by_dtype:
            self._reduce_no_retain(tensor_list=tensor_list, bucket_size=bucket_size, reduce_rank=reduce_rank)

    ##############################
    # Reduction Utility Function #
    ##############################

    def _reduce_no_retain(self, tensor_list, bucket_size, reduce_rank):
        param_bucket = TensorBucket(size=bucket_size)

        for tensor in tensor_list:
            param_bucket.add_to_bucket(tensor, allow_oversize=True)

            if param_bucket.is_full_or_oversized():
                self._reduce_and_copy(bucket=param_bucket, reduce_rank=reduce_rank)
                param_bucket.empty()

        if not param_bucket.is_empty():
            self._reduce_and_copy(bucket=param_bucket, reduce_rank=reduce_rank)

    def _reduce_and_copy(self, bucket: TensorBucket, reduce_rank):
        if self._overlap_communication:
            torch.cuda.synchronize()
            self._param_store.clear_grads_of_previous_reduced_params()
            stream = self._comm_stream
        else:
            stream = torch.cuda.current_stream()

        with torch.cuda.stream(stream):
            flat = bucket.flatten()
            reduced_flat = reduce_tensor(tensor=flat,
                                         dtype=self._communication_dtype,
                                         dst_rank=reduce_rank,
                                         parallel_mode=self._dp_parallel_mode)

            # update the reduced tensor
            if reduce_rank is None or reduce_rank == self._local_rank:
                bucket.unflatten_and_copy(reduced_flat)
    ################################
    # torch.optim.Optimizer methods
    ################################

    def backward(self, loss, retain_graph=True):
        loss = self.loss_scale * loss
        loss.backward(retain_graph=retain_graph)

    def zero_grad(self, set_to_none=True):
        """
        Set parameter gradients to zero. If set_to_none = True, the gradient
        will be set to None to save memory.

        :param set_to_none: Whether to set the gradient to None. Default value is True.
        :type set_to_none: bool
        """
        for group_id, param_group in self._fp16_param_groups.items():
            for param in param_group:
                if set_to_none:
                    param.grad = None
                else:
                    if param.grad is not None:
                        param.grad.detach()
                        param.grad.zero_()
    ####################
    # Update Parameter #
    ####################

    def step(self, closure=None):
        assert closure is None, 'closure is not supported by step()'

        # check for overflow
        found_inf = self._check_overflow()
        self.grad_scaler.update(found_inf)

        # update loss scale if overflow occurs
        if found_inf:
            self._grad_store._averaged_gradients = dict()
            self.zero_grad()
            return

        # copy the grads of the fp16 params to the fp32 params
        single_grad_partition_groups = []
        norm_groups = []

        for group_id in range(self.num_param_groups):
            # compute norm
            norm_group = compute_norm(gradients=self._grad_store._averaged_gradients[group_id],
                                      params=self._param_store.get_fp16_params_by_rank_group(group_id=group_id,
                                                                                             rank=self._local_rank),
                                      dp_group=self._dp_group,
                                      mp_group=self._mp_group)
            norm_groups.append(norm_group)

            # create a flat gradient for the flat fp32 params
            fp16_avg_grads = self._grad_store.get_averaged_gradients_by_group(group_id)
            flat_fp16_avg_grads = flatten(fp16_avg_grads)

            dtype = self._fp32_flat_param_groups_of_current_rank[group_id].dtype
            flat_fp32_avg_grads = flat_fp16_avg_grads.to(dtype)

            param_shape = self._fp32_flat_param_groups_of_current_rank[group_id].shape
            assert param_shape == flat_fp32_avg_grads.shape, \
                f'fp32 param and grad have different shape {param_shape} vs {flat_fp32_avg_grads.shape}'

            single_grad_partition_groups.append(flat_fp32_avg_grads)
            device = self._fp32_flat_param_groups_of_current_rank[group_id].device
            self._fp32_flat_param_groups_of_current_rank[group_id].grad = flat_fp32_avg_grads.to(device)
            self._grad_store._averaged_gradients[group_id] = []

        # unscale and clip grads
        global_norm = calculate_global_norm_from_list(norm_list=norm_groups)
        self._unscale_and_clip_grads(single_grad_partition_groups, global_norm)

        # update the parameters
        self._optimizer.step()

        # release the fp32 grads
        release_param_grad(self._fp32_flat_param_groups_of_current_rank.values())

        # update the fp16 partition updated by the current rank
        for group_id in range(len(self._fp16_param_groups)):
            fp16_param = self._param_store.get_flat_fp16_param_by_rank_group(rank=self._local_rank, group_id=group_id)
            fp32_param = self._fp32_flat_param_groups_of_current_rank[group_id].to(fp16_param.device)
            fp16_param.data.copy_(fp32_param)

        # broadcast the updated model weights
        handles = []
        for group_id in range(self.num_param_groups):
            for rank in range(self._world_size):
                fp16_param = self._param_store.get_flat_fp16_param_by_rank_group(rank=rank, group_id=group_id)
                handle = dist.broadcast(fp16_param, src=rank, group=self._dp_group, async_op=True)
                handles.append(handle)

        for handle in handles:
            handle.wait()
    ##################
    # FP16 Utilities #
    ##################

    def _check_overflow(self):
        # clear the previous overflow record
        self._found_overflow.fill_(0.0)

        # check for overflow
        for group_id in range(len(self._fp16_param_groups)):
            for avg_grad in self._grad_store.get_averaged_gradients_by_group(group_id):
                if avg_grad is not None and has_inf_or_nan(avg_grad):
                    self._found_overflow.fill_(1.0)
                    break

        # all-reduce across the dp group
        dist.all_reduce(self._found_overflow, op=dist.ReduceOp.MAX, group=self._dp_group)

        # all-reduce over the model parallel group
        if self._mp_group:
            dist.all_reduce(self._found_overflow, op=dist.ReduceOp.MAX, group=self._mp_group)

        if self._found_overflow.item() > 0:
            return True
        else:
            return False

    def _unscale_and_clip_grads(self, grad_groups_flat, total_norm):
        # compute the combined scale factor for this group
        combined_scale = self.loss_scale

        if self._clip_grad_norm > 0.:
            # norm is in fact norm*scale
            clip = ((total_norm / self.loss_scale) + 1e-6) / self._clip_grad_norm
            if clip > 1:
                combined_scale = clip * self.loss_scale

        for grad in grad_groups_flat:
            grad.data.mul_(1. / combined_scale)
    ############################
    # Gradient Synchronization #
    ############################

    def sync_grad(self):
        if not self._partition_grads:
            self._reduce_grad_stage1()
        else:
            # TODO: support async comm in reduce
            self._reduce_grad_stage2()

        # update the param-already-reduced flags
        reduction_states = self._param_store.get_param_reduction_states()
        for tensor, state in reduction_states.items():
            reduction_states[tensor] = False

        # clear reduced grads
        if self._overlap_communication:
            torch.cuda.synchronize()
            self._param_store.clear_grads_of_previous_reduced_params()

        # accumulate gradients
        avg_gradients = self._grad_store._averaged_gradients
        for group_id in range(self.num_param_groups):
            param_group = self._param_store.get_fp16_params_by_rank_group(self._local_rank, group_id)

            if group_id not in avg_gradients:
                avg_gradients[group_id] = []

            param_idx = 0
            for param in param_group:
                if param.grad is not None:
                    if len(avg_gradients[group_id]) == param_idx:
                        avg_gradients[group_id].append(param.grad)
                    else:
                        avg_gradients[group_id][param_idx].add_(param.grad)
                    param_idx += 1

        # the gradients needed are stored in the avg_gradients buffer,
        # thus we can clear them here
        self.zero_grad()

    def _reduce_grad_stage1(self):
        # if not overlapping communication (no reduction hook is attached),
        # we need to manually reduce these gradients
        if not self._overlap_communication:
            for group_id in range(len(self._fp16_param_groups)):
                param_group = self._fp16_param_groups[group_id]
                for param in param_group:
                    if param.grad is not None:
                        self._reduce_and_remove_grads_by_bucket(param)

        # we need to reduce the gradients
        # left in the communication bucket
        self._reduce_grads_in_bucket()

    def _reduce_grad_stage2(self):
        # when partition_grads is True, reduction hooks
        # are attached in the __init__ function, so we
        # only need to reduce the gradients
        # left in the communication bucket
        for reduce_rank in range(self._world_size):
            self._reduce_grads_in_bucket(reduce_rank)
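For reference, a standalone sketch of the greedy load balancing performed by `_partition_param_list` above (pure Python, no distributed setup required; `greedy_partition` and the example sizes are hypothetical):

def greedy_partition(numels, world_size):
    """Assign each size to the rank with the smallest running total, largest first."""
    per_rank = [[] for _ in range(world_size)]
    totals = [0] * world_size
    for n in sorted(numels, reverse=True):
        rank = totals.index(min(totals))    # rank with the least work so far
        per_rank[rank].append(n)
        totals[rank] += n
    return per_rank, totals

# greedy_partition([100, 90, 40, 30, 20], 2)
# -> ([[100, 30, 20], [90, 40]], [150, 130])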
colossalai/zero/sharded_optim/sharded_optim_v2.py
View file @ 4d322b79
...
@@ -10,7 +10,7 @@ from colossalai.core import global_context as gpc
 from colossalai.logging import get_dist_logger
 from colossalai.nn.optimizer import ColossalaiOptimizer
 from colossalai.zero.sharded_model import ShardedModelV2
-from colossalai.zero.sharded_model._zero3_utils import cast_tensor_to_fp32
+from colossalai.zero.sharded_model._utils import cast_tensor_to_fp32
 from torch import Tensor
 from torch.distributed import ProcessGroup
 from torch.nn.parameter import Parameter
...