OpenDAS / ColossalAI · Commits

Commit 4d322b79 (unverified), authored Mar 25, 2022 by Jiarui Fang, committed via GitHub on Mar 25, 2022

[refactor] remove old zero code (#517)

Parent: 6a3f9fda
Changes: 28
Showing 20 changed files with 29 additions and 2515 deletions (+29 −2515)
Files changed:
  colossalai/engine/schedule/_pipeline_schedule.py              +0   −3
  colossalai/zero/__init__.py                                    +1   −37
  colossalai/zero/init_ctx/init_context.py                       +1   −1
  colossalai/zero/shard_utils/commons.py                         +20  −0
  colossalai/zero/shard_utils/tensor_shard_strategy.py           +1   −1
  colossalai/zero/sharded_model/__init__.py                      +1   −2
  colossalai/zero/sharded_model/_utils.py                        +1   −47
  colossalai/zero/sharded_model/param_manager.py                 +0   −385
  colossalai/zero/sharded_model/sharded_grad.py                  +0   −85
  colossalai/zero/sharded_model/sharded_model.py                 +0   −1104
  colossalai/zero/sharded_model/sharded_model_v2.py              +2   −2
  colossalai/zero/sharded_optim/__init__.py                      +1   −2
  colossalai/zero/sharded_optim/bookkeeping/__init__.py          +0   −6
  colossalai/zero/sharded_optim/bookkeeping/base_store.py        +0   −17
  colossalai/zero/sharded_optim/bookkeeping/bucket_store.py      +0   −43
  colossalai/zero/sharded_optim/bookkeeping/gradient_store.py    +0   −66
  colossalai/zero/sharded_optim/bookkeeping/parameter_store.py   +0   −96
  colossalai/zero/sharded_optim/bookkeeping/tensor_bucket.py     +0   −54
  colossalai/zero/sharded_optim/sharded_optim.py                 +0   −563
  colossalai/zero/sharded_optim/sharded_optim_v2.py              +1   −1
colossalai/engine/schedule/_pipeline_schedule.py

@@ -12,7 +12,6 @@ from colossalai.core import global_context as gpc
 from colossalai.logging import get_dist_logger
 from colossalai.utils import switch_virtual_pipeline_parallel_rank
 from colossalai.utils.cuda import get_current_device
-from colossalai.zero import ShardedModel, ShardedOptimizer
 from colossalai.zero.sharded_model import ShardedModelV2
 from ._base_schedule import BaseSchedule

@@ -92,8 +91,6 @@ class PipelineSchedule(BaseSchedule):
     def pre_processing(self, engine):
         # TODO: remove this after testing new zero with pipeline parallelism
-        if isinstance(engine.optimizer, ShardedOptimizer) or isinstance(engine.model, ShardedModel):
-            raise TypeError("Pipeline schedule is currently not compatible with ZeRO")
         model = engine.model
         if isinstance(model, (NaiveAMPModel, ShardedModelV2)):
             self.dtype = torch.half
colossalai/zero/__init__.py

@@ -2,14 +2,9 @@ from typing import Tuple
 import torch
 import torch.nn as nn
-from colossalai.amp.naive_amp import NaiveAMPModel
 from colossalai.logging import get_dist_logger
 from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2
 from colossalai.zero.sharded_optim.sharded_optim_v2 import ShardedOptimizerV2
-from torch.optim import Optimizer
-from .sharded_model import ShardedModel
-from .sharded_optim import ShardedOptimizer


 def convert_to_zero_v2(model: nn.Module, optimizer: torch.optim.Optimizer, model_config,
@@ -40,35 +35,4 @@ def convert_to_zero_v2(model: nn.Module, optimizer: torch.optim.Optimizer, model
     return zero_model, zero_optimizer


-def convert_to_zero(model: nn.Module, optimizer: Optimizer, level: int, zero_config: dict):
-    """
-    A helper function to integrate the model and optimizer with ZeRO optimizer and off-loading
-
-    :param model: Your model object
-    :type model: :class:`torch.nn.Module`
-    :param optimizer: Your optimizer object
-    :type optimizer: :class:`torch.optim.Optimizer`
-    :param level: Optimizer level, can be 2 or 3
-    :type level: int
-    :param zero_config: Configuration for zero
-    :type zero_config: dict
-
-    :return: (model, optimizer)
-    :rtype: Tuple
-    """
-    assert 1 <= level <= 3, 'Only ZERO Optimizer Level 1-3 are provided'
-    if level in [1, 2]:
-        if level == 2:
-            if 'partition_grad' in zero_config:
-                assert zero_config['partition_grad'], \
-                    'Sharded Optimizer requires partition_grad to be True'
-            else:
-                zero_config['partiton_grad'] = True
-        model = NaiveAMPModel(model, output_to_fp32=True)
-        optimizer = ShardedOptimizer(optimizer, **zero_config)
-    else:
-        model = ShardedModel(module=model, **zero_config)
-    return model, optimizer
-
-__all__ = ['convert_to_zero', 'ShardedModel', 'ShardedOptimizer']
+__all__ = ['convert_to_zerov2', 'ShardedModelV2', 'ShardedOptimizerV2']
colossalai/zero/init_ctx/init_context.py

@@ -8,7 +8,7 @@ from colossalai.utils.memory_tracer.model_data_memtracer import \
     GLOBAL_MODEL_DATA_TRACER
 from colossalai.utils.memory_utils.memory_monitor import colo_cuda_memory_used
 from colossalai.zero.shard_utils import BaseShardStrategy
-from colossalai.zero.sharded_model._zero3_utils import cast_tensor_to_fp16
+from colossalai.zero.sharded_model._utils import cast_tensor_to_fp16
 from colossalai.zero.sharded_param import ShardedParamV2
 from torch.distributed import ProcessGroup
 from colossalai.logging import get_dist_logger, disable_existing_loggers
colossalai/zero/shard_utils/commons.py (new file, 0 → 100644)

import torch
import torch.nn.functional as F
from typing import Tuple


def get_shard(tensor: torch.Tensor, rank: int, world_size: int) -> Tuple[torch.Tensor, int]:
    """Return the local shard of a full tensor."""
    # Shard using torch.chunk to match all-gather/reduce-scatter.
    chunks = list(torch.flatten(tensor).chunk(world_size))
    while len(chunks) < world_size:
        chunks.append(chunks[0].new_empty(0))

    # Determine number of padding elements.
    num_to_pad = chunks[0].numel() - chunks[rank].numel()
    assert num_to_pad >= 0, num_to_pad

    shard = chunks[rank].clone()
    if num_to_pad > 0:
        shard = F.pad(shard, [0, num_to_pad])
    return shard, num_to_pad
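get_shard() flattens the tensor, chunks it into world_size pieces, and right-pads the local chunk so every rank ends up with an equally sized shard. A small standalone illustration of that behavior (not part of the commit; it only assumes the package is importable at this revision):

import torch
from colossalai.zero.shard_utils.commons import get_shard

# torch.chunk on a 10-element tensor with world_size=4 yields pieces of sizes [3, 3, 3, 1],
# so rank 3 holds 1 real element and is right-padded with 2 zeros to match the 3-element chunks.
full = torch.arange(10, dtype=torch.float32)
shard, num_to_pad = get_shard(full, rank=3, world_size=4)
assert shard.numel() == 3 and num_to_pad == 2

# Rank 0 needs no padding.
shard0, pad0 = get_shard(full, rank=0, world_size=4)
assert shard0.numel() == 3 and pad0 == 0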
colossalai/zero/shard_utils/tensor_shard_strategy.py

@@ -4,7 +4,7 @@ import torch
 import torch.distributed as dist
 from colossalai.utils import get_current_device
 from colossalai.zero.shard_utils import BaseShardStrategy
-from colossalai.zero.sharded_model._zero3_utils import get_shard
+from colossalai.zero.shard_utils.commons import get_shard
 from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
colossalai/zero/sharded_model/__init__.py

-from .sharded_model import ShardedModel
 from .sharded_model_v2 import ShardedModelV2

-__all__ = ['ShardedModel', 'ShardedModelV2']
+__all__ = ['ShardedModelV2']
\ No newline at end of file
colossalai/zero/sharded_model/_zero3_utils.py → colossalai/zero/sharded_model/_utils.py

-from collections import OrderedDict
-from typing import Any, Callable, Dict, List, Tuple, Union
+from typing import Any, Callable, List, Tuple

 import torch
 import torch.nn.functional as F

@@ -12,23 +11,6 @@ def get_gradient_predivide_factor(world_size: int) -> float:
     return float(factor)


-def get_shard(tensor: torch.Tensor, rank: int, world_size: int) -> Tuple[torch.Tensor, int]:
-    """Return the local shard of a full tensor."""
-    # Shard using torch.chunk to match all-gather/reduce-scatter.
-    chunks = list(torch.flatten(tensor).chunk(world_size))
-    while len(chunks) < world_size:
-        chunks.append(chunks[0].new_empty(0))
-
-    # Determine number of padding elements.
-    num_to_pad = chunks[0].numel() - chunks[rank].numel()
-    assert num_to_pad >= 0, num_to_pad
-
-    shard = chunks[rank].clone()
-    if num_to_pad > 0:
-        shard = F.pad(shard, [0, num_to_pad])
-    return shard, num_to_pad
-
-
 def free_storage(data: torch.Tensor) -> None:
     """Free underlying storage of a Tensor."""
     if data.storage().size() > 0:

@@ -86,31 +68,3 @@ def chunk_and_pad(tensor: torch.Tensor, num_chunks: int) -> List[torch.Tensor]:
     if len(chunks) < num_chunks:
         chunks.extend([torch.zeros_like(chunks[0]) for _ in range(num_chunks - len(chunks))])
     return chunks
-
-
-def assert_in_engine(cond: Any, s: Any) -> None:
-    """Used in backward context to make sure error is printed."""
-    if not cond:
-        print(s)
-        raise AssertionError
-
-
-def replace_state_dict_prefix(state_dict: Union[Dict[str, torch.Tensor], "OrderedDict[str, torch.Tensor]"],
-                              old_prefix: str, new_prefix: str) -> None:
-    """
-    Replace all keys that match a given old_prefix with a new_prefix (in-place).
-
-    Usage::
-
-        state_dict = {"layer.xyz": torch.tensor(1)}
-        replace_state_dict_prefix(state_dict, "layer.", "module.layer.")
-        assert state_dict == {"module.layer.xyz": torch.tensor(1)}
-    """
-    if old_prefix == new_prefix:
-        raise ValueError("old_prefix and new_prefix must be distinct")
-    for key in list(state_dict.keys()):
-        if not key.startswith(old_prefix):
-            continue
-        new_key = new_prefix + key[len(old_prefix):]
-        state_dict[new_key] = state_dict[key]
-        del state_dict[key]
colossalai/zero/sharded_model/param_manager.py (deleted, 100644 → 0)

import os
from typing import Dict, List, Optional, Tuple

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed import ProcessGroup
from torch.nn.parameter import Parameter

from ._zero3_utils import alloc_storage, free_storage, get_shard

# TODO: Remove the toggle-enable_nccl_base_collectives in the future
if os.getenv("ENABLE_NCCL_BASE_COLLECTIVES", "1") == "0":
    enable_nccl_base_collectives = False
else:
    enable_nccl_base_collectives = True


# TODO: add flatten params
class Zero3ParameterManager:

    def __init__(self,
                 module: nn.Module,
                 process_group: Optional[ProcessGroup],
                 mixed_precision: bool = False,
                 flatten_parameters: bool = True,
                 compute_dtype: Optional[torch.dtype] = None,
                 compute_device: Optional[torch.device] = None,
                 offload_config: Optional[dict] = None) -> None:
        """Manage parameter shards. We manage several attributes on each Parameter instance:
            ``zero_is_sharded``: ``True`` if the Parameter is sharded or ``False``
                if the Parameter is intentionally not sharded (in which case we
                will all-reduce grads for this param).
            ``zero_orig_size``: the size of the original Parameter (before sharding)
            ``zero_shard_padding``: the padding size. All paddings are right padding.
            ``zero_fp32_shard``: a single shard of the parameters in full precision
                (typically FP32, but this is dependent on the dtype of the model
                as it's passed in by the user). This can be on CPU or GPU
                depending on the value of *``offload_config``*.
            ``zero_fp16_shard``: This will be a single shard of the parameters in FP16, used for all-gather.
                This can be in FP16 or FP32 depending on the value of *``compute_dtype``* and
                if params are offloaded to CPU.
            ``zero_full_param_padded``: the full weight (padded to be evenly
                divisible by ``world_size``), used for computation in the
                forward and backward pass. This will be resized in place and
                only materialized (via all-gather) as needed.
            ``zero_cpu_grad``: the gradient saved on CPU. It's set only when using CPU offload.

        :param module: original module
        :type module: nn.Module
        :param process_group: typically data parallel process group, defaults to None
        :type process_group: Optional[ProcessGroup], optional
        :param mixed_precision: whether to use mixed precision mode, defaults to False
        :type mixed_precision: bool, optional
        :param flatten_parameters: whether to flatten parameters, useless now, defaults to True
        :type flatten_parameters: bool, optional
        :param compute_dtype: the dtype of parameters when computing, defaults to None
        :type compute_dtype: Optional[torch.dtype], optional
        :param compute_device: the device of parameters when computing, defaults to None
        :type compute_device: Optional[torch.device], optional
        :param offload_config: offload config, defaults to None
        :type offload_config: Optional[dict], optional
        """
        self.process_group = process_group
        self.shard_idx = process_group.rank()
        self.num_shards = process_group.size()
        self.mixed_precision = mixed_precision
        self.compute_dtype = compute_dtype
        self.compute_device = compute_device
        self.offload_config = offload_config

        self._cpu_offload = offload_config.get('device', None) == 'cpu' if offload_config else False

        self.params: List[Parameter] = []

        for param in module.parameters():
            if not hasattr(param, 'zero_is_sharded'):
                self.params.append(param)

        self._has_params = len(self.params) > 0
        self._has_sharded_params = False
        # Flag to indicate if the full params are gathered.
        self.has_full_params: bool = False

        self._shard_params()
        # Maybe no need, reserve to prevent bugs
        # self.delete_fp32_shards()

        self._streams: Dict[str, torch.cuda.Stream] = {}

    def _shard_params(self) -> None:
        for p in self.params:
            assert not hasattr(p, "zero_is_sharded")
            assert p.is_floating_point()
            if self.mixed_precision:
                assert p.dtype == torch.float32

            # If world_size is 1, then we all-reduce grads instead of sharding.
            p.zero_is_sharded = self.num_shards > 1
            p.zero_orig_size = p.data.size()

            if not p.zero_is_sharded:
                p.zero_shard_padding = 0
                continue

            # Replace p.data with the relevant shard.
            orig_data = p.data
            p.data, p.zero_shard_padding = get_shard(p.data, self.shard_idx, self.num_shards)
            free_storage(orig_data)

    @torch.no_grad()
    def reset_param_attr(self, p: Parameter, training: bool) -> None:
        """This should be called by ``ZeroRedundancyLevel3Model._lazy_init()``
        """
        assert hasattr(p, 'zero_is_sharded') and hasattr(p, 'zero_orig_size')
        if hasattr(p, 'zero_fp32_shard'):
            return

        # A single shard of the parameters in full precision.
        p.zero_fp32_shard = p.data

        if self.mixed_precision:
            assert p.zero_fp32_shard.dtype == torch.float32

        if self._cpu_offload:
            assert p.zero_fp32_shard.device == torch.device('cpu')
            # If we plan to keep the FP32 parameters on CPU, then pinning
            # memory allows us to later use non-blocking transfers when moving
            # the FP32 param shard to compute_device.
            p.zero_fp32_shard = p.zero_fp32_shard.pin_memory()
            p.data = p.zero_fp32_shard

        if self.mixed_precision or self._cpu_offload:
            # In mixed precision mode, we maintain a reduced precision
            # (typically FP16) parameter shard on compute_device for performing
            # the computation in the forward/backward pass. We resize the
            # storage to size 0 at init (here) and re-materialize (by copying
            # from _fp32_shard) as needed. If offloading params to CPU, the
            # dtype of the fp16 shard will depend on the *`compute_dtype`*.
            p.zero_fp16_shard = torch.zeros_like(p.zero_fp32_shard, device=self.compute_device, dtype=self.compute_dtype)
            free_storage(p.zero_fp16_shard)

        if self.mixed_precision:
            assert p.zero_fp32_shard.dtype == torch.float32

        if not self.mixed_precision and not self._cpu_offload:
            # use _fp32_shard if you are not in using mixed precision or
            # offloading params and grads to CPU.
            p.zero_fp16_shard = None

        # We also maintain a full-sized parameter of type self.compute_dtype
        # (FP16 for mixed_precision or FP32 otherwise). We resize the
        # storage to size 0 at init (here) and only materialize as needed. The
        # storage may contain padding elements so that it is evenly divisible by
        # world_size, although these padding elements will be removed before the
        # relevant computation.
        if p.zero_is_sharded:
            p.zero_full_param_padded = torch.zeros(p.data.numel() * self.num_shards,
                                                   device=self.compute_device,
                                                   dtype=self.compute_dtype)
            free_storage(p.zero_full_param_padded)

        if self._cpu_offload and training:
            p.zero_cpu_grad = torch.zeros_like(p.data, device='cpu').pin_memory()

    def setup_streams(self, streams):
        self._streams = streams

    @torch.no_grad()
    def rebuild_full_params(self, force_full_precision: bool = False) -> Optional[List[Tuple[torch.Tensor, bool]]]:
        """
        Gather all shards of params.

        Note, this is idempotent if full params are already gathered. Callers
        assume the idempotency. So please keep it that way.

        Args:
            force_full_precision (bool, Optional): by default params will be gathered
                in ``compute_dtype`` (e.g., FP16), unless *force_full_precision* is
                ``True``, in which case they will be gathered in full precision
                (e.g., FP32), possibly in fresh storage. The parameter that's being
                rebuilt will end up in full precision as well.

        Returns:
            A list of tuples, where the first element is the full-sized param
            and the second element is a bool indicating if it's safe for the
            caller to free the full-sized param. This will be ``None`` if
            ``force_full_precision=False`` and the full params are already gathered.
        """
        # Store tensor and free flag
        output_tensors: List[Tuple[torch.Tensor, bool]] = []

        def update_p_data(custom_output_tensor: Optional[torch.Tensor] = None) -> None:
            """
            Helper function to update p.data pointer.

            Args:
                custom_output_tensor (torch.Tensor, Optional): if not None, this
                tensor contains the data we just gathered.
            """
            if custom_output_tensor is not None:
                assert p.zero_is_sharded
                p.data = custom_output_tensor
                output_tensors.append((p.data, True))
            elif not p.zero_is_sharded:
                if (self.mixed_precision or self._cpu_offload) and not force_full_precision:
                    assert p.zero_fp16_shard is not None
                    p.data = p.zero_fp16_shard
                    output_tensors.append((p.data, True))
                else:
                    # Here p.data == p._fp32_shard, so it's not safe to free.
                    output_tensors.append((p.data, False))
            else:
                p.data = p.zero_full_param_padded
                output_tensors.append((p.data, True))
            # Trim any padding and reshape to match original size.
            p.data = p.data[:p.zero_orig_size.numel()].view(p.zero_orig_size)

        if self._has_sharded_params:
            # self.has_full_params flag can be out of sync if a shared param is
            # sharded by another ZeroRedundancyLevel3Model instance. An example is that in eval case
            # with reshard_after_forward=False but the sharing instance has
            # reshard_after_forward=True. Then, on the second forward, the
            # other instance can shard the shared param and but this instance
            # can mistakenly think the full param is already gathered from the
            # has_full_params flag.
            #
            # Therefore, we update the flag accordingly here.
            self.has_full_params = not any(p.zero_full_param_padded.storage().size() == 0 for p in self.params)

        # Early exit if we already have full params and don't need full precision.
        if self.has_full_params and not force_full_precision:
            for p in self.params:
                update_p_data()
            return output_tensors

        self.has_full_params = True

        with torch.cuda.stream(self._streams["all_gather"]):
            if (self.mixed_precision or self._cpu_offload) and not force_full_precision:
                self.use_fp16_shards()

            if self._cpu_offload and force_full_precision:
                # If the compute_dtype and storage dtype are the same,
                # use pinned memory. Otherwise move p.data to the compute
                # device.
                if self.params[0].dtype == self.compute_dtype:
                    self.use_fp16_shards()
                else:
                    for p in self.params:
                        p.data = p.data.to(self.compute_device)

            for p in self.params:
                if not p.zero_is_sharded:    # e.g., when world_size == 1
                    update_p_data()
                else:
                    # Skip if already built. Only shared param can be rebuilt multiple times.
                    # A corner case is p.zero_orig_size = (1,), which means the shape equality is
                    # not a perfect check. But we assume we don't share a param with shape (1,).
                    # if p.data.shape == p.zero_orig_size and hasattr(p, "zero_is_shared") and p.zero_is_shared:
                    #     continue

                    # If self._cpu_offload and force_full_precision, we need to cast
                    # the FP32 CPU param to CUDA for the all-gather.
                    p_data = p.data.to(p.zero_full_param_padded.device, non_blocking=True)

                    p_size = p.zero_full_param_padded.size()
                    assert p_size.numel() % self.num_shards == 0
                    if self.mixed_precision and force_full_precision:
                        # Allocate fresh tensor in full precision since we are in
                        # mixed precision and full precision rebuild is asked.
                        output_tensor = p_data.new_zeros(p_size)
                    else:
                        if p.zero_full_param_padded.storage().size() != p_size.numel():
                            # Allocate based on full size from all shards.
                            alloc_storage(p.zero_full_param_padded, size=p_size)
                        output_tensor = p.zero_full_param_padded

                    # Fill output_tensor with (p.data for each shard in self.world_size)
                    if hasattr(dist, "_all_gather_base") and enable_nccl_base_collectives:
                        # New version of PyTorch has all_gather_base, which is faster than chunk and then all_gather.
                        dist._all_gather_base(output_tensor, p_data, group=self.process_group)
                    else:
                        chunks = list(output_tensor.chunk(self.num_shards))
                        dist.all_gather(chunks, p_data, group=self.process_group)

                    # Set p.data = output_tensor (with padding trimmed)
                    update_p_data(output_tensor)

                    if (self.mixed_precision or self._cpu_offload) and not force_full_precision:
                        self.free_fp16_shards([p])

                    if self._cpu_offload and (self.params[0].dtype == self.compute_dtype):
                        self.free_fp16_shards([p])

        torch.cuda.current_stream().wait_stream(self._streams["all_gather"])
        return output_tensors

    @torch.no_grad()
    def use_full_params(self) -> None:
        """
        Switch p.data pointers to use the full params.

        Note: this assumes full params are already gathered.

        Note: this might be called after full_params is already in used. So please
        make sure it is idempotent in that case.
        """
        assert self.has_full_params
        for p in self.params:
            if not p.zero_is_sharded:
                if self.mixed_precision or self._cpu_offload:
                    assert p.zero_fp16_shard is not None
                    assert p.zero_fp16_shard.storage().size() != 0
                    p.data = p.zero_fp16_shard
            else:
                assert p.zero_full_param_padded.storage().size() != 0, f"{p.zero_orig_size} {id(self)}"
                p.data = p.zero_full_param_padded[:p.zero_orig_size.numel()].view(p.zero_orig_size)

    @torch.no_grad()
    def use_fp16_shards(self, params: Optional[List[Parameter]] = None) -> None:
        """Cast FP32 param shard to FP16 for a list of params."""
        if params is None:
            params = self.params
        with torch.cuda.stream(self._streams["fp32_to_fp16"]):
            for p in params:
                assert p.zero_fp16_shard is not None
                alloc_storage(p.zero_fp16_shard, size=p.zero_fp32_shard.size())
                p.zero_fp16_shard.copy_(
                    # If _cpu_offload is True, this will be non-blocking
                    # because _fp32_shard is pinned, otherwise it's a no-op.
                    p.zero_fp32_shard.to(p.zero_fp16_shard.device, non_blocking=True))
                p.data = p.zero_fp16_shard
        torch.cuda.current_stream().wait_stream(self._streams["fp32_to_fp16"])

    @torch.no_grad()
    def use_fp32_shards(self, params: Optional[List[Parameter]] = None) -> None:
        """Use FP32 shard for a list of params."""
        if params is None:
            params = self.params
        for p in params:
            p.data = p.zero_fp32_shard

    @torch.no_grad()
    def free_full_params(self, params: Optional[List[Parameter]] = None) -> None:
        """Free up storage for full parameters."""
        if params is None:
            params = self.params
        self.has_full_params = False
        current_stream = torch.cuda.current_stream()
        for p in params:
            if not p.zero_is_sharded:    # e.g., world_size == 1
                if self.mixed_precision or self._cpu_offload:
                    self.free_fp16_shards([p])
                continue
            # Don't let PyTorch reuse this memory until all work in the current
            # stream is complete.
            p.zero_full_param_padded.record_stream(current_stream)
            # There may be external references to the Tensor Storage that we
            # can't modify, such as references that are created by
            # ctx.save_for_backward in the forward pass. Thus when we
            # unshard parameters, we should reuse the original Tensor
            # Storage object and unshard it in-place. For now, just resize
            # the Storage to 0 to save memory.
            free_storage(p.zero_full_param_padded)

    @torch.no_grad()
    def free_fp16_shards(self, params: Optional[List[Parameter]] = None) -> None:
        """Free storage for FP16 shards for a list of params."""
        if params is None:
            params = self.params
        current_stream = torch.cuda.current_stream()
        for p in params:
            if p.zero_fp16_shard is not None:
                # zero_fp16_shard is allocated in "fp32_to_fp16" stream, so we can't
                # free it until the work in the current stream completes.
                p.zero_fp16_shard.record_stream(current_stream)
                free_storage(p.zero_fp16_shard)

    def delete_fp32_shards(self) -> None:
        for p in self.params:
            if hasattr(p, 'zero_fp32_shard'):
                del p.zero_fp32_shard    # reset _init_param_attr
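For context, the zero_* attributes documented in the class docstring were driven in a gather/compute/free cycle by the old ZeRO-3 model wrapper (also removed in this commit). A hedged sketch of that cycle; module, process_group, and batch are assumed to exist and the real call sites lived in the removed wrapper, not here:

# Illustrative only.
manager = Zero3ParameterManager(module,
                                process_group=process_group,
                                mixed_precision=True,
                                compute_dtype=torch.half,
                                compute_device=torch.device('cuda'))
manager.setup_streams({'all_gather': torch.cuda.Stream(), 'fp32_to_fp16': torch.cuda.Stream()})
for p in manager.params:
    manager.reset_param_attr(p, training=True)

manager.rebuild_full_params()   # all-gather padded full weights into zero_full_param_padded
manager.use_full_params()       # idempotent: point p.data at the gathered, trimmed weights
output = module(batch)          # forward pass runs on the full weights
manager.free_full_params()      # resize full-param storage back to 0
manager.use_fp32_shards()       # p.data points at the local FP32 shard again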
colossalai/zero/sharded_model/sharded_grad.py (deleted, 100644 → 0)

from typing import Optional

import torch
import torch.nn as nn
from torch.nn.parameter import Parameter


class ShardedGradient:

    def __init__(self,
                 param: Parameter,
                 sharded_module: nn.Module,
                 offload_config: Optional[dict] = None) -> None:
        assert hasattr(param, 'ca_attr') and param.ca_attr.is_sharded, \
            'ShardedGradient can only be initialized with sharded parameter'

        self.param = param
        self.sharded_module = sharded_module
        self.offload_config = offload_config

        self._cpu_offload = offload_config.get('device', None) == 'cpu' if offload_config else False

        # _gpu_grad is either sharded or not
        # all saved grads are fp32
        self._gpu_grad: Optional[torch.Tensor] = None
        self._cpu_grad: Optional[torch.Tensor] = None

        if self._cpu_offload:
            # this buffer will be held and reused every iteration
            self._cpu_grad = torch.zeros(param.ca_attr.payload('cpu'), dtype=torch.float).pin_memory()

    @torch.no_grad()
    def setup(self) -> None:
        """This function will be called pre-backward. Save the local accumulated gradient to _gpu_grad.
        When no_sync() is enable (_require_backward_grad_sync=False), the grad is accumulated locally in param.grad

        :raises AssertionError: Raise if grad shape is wrong
        """
        if self.sharded_module._require_backward_grad_sync and self.param.grad is not None:
            if self.param.grad.device != self.param.data.device:
                # TODO: offload?
                raise RuntimeError(
                    'grad and param are on different device, grad {self.param.grad.device} vs. param {self.param.data.device}')
            else:
                self._gpu_grad = self.param.grad.data
            self.param.grad = None

    def reduce_scatter_callback(self, reduced_grad: torch.Tensor) -> None:
        """This function will be called in post-backward hook, so we cannot modify param.grad directly

        :param reduced_grad: the reduced grad
        :type reduced_grad: torch.Tensor
        """
        # Make sure we store fp32 grad
        if torch.is_floating_point(reduced_grad) and reduced_grad.dtype != torch.float:
            reduced_grad.data = reduced_grad.data.to(torch.float)

        if self._gpu_grad is None:
            self._gpu_grad = reduced_grad.data
        else:
            self._gpu_grad += reduced_grad.data

        # Optionally move gradients to CPU, typically used if one is running the optimizer on the CPU. Once the full
        # backwards pass completes, we will set `.grad` to the CPU copy.
        if self._cpu_offload:
            self._cpu_grad.copy_(self._gpu_grad.data, non_blocking=True)
            # Don't let this memory get reused until after the transfer.
            self._gpu_grad.data.record_stream(torch.cuda.current_stream())

    @torch.no_grad()
    def write_back(self) -> None:
        """This function will be called in final backward hook
        """
        if self._cpu_grad is not None:
            assert self.param.device == torch.device(
                'cpu'), f'Incorrect param device, expected CPU, got {self.param.device}'
            self.param.grad.data = self._cpu_grad
        elif self._gpu_grad is not None:
            assert self.param.device == self._gpu_grad.device, \
                f'Incorrect _gpu_grad device, param on {self.param.device} but _gpu_grad on {self._gpu_grad.device}'
            self.param.grad.data = self._gpu_grad
        else:
            raise RuntimeError('No grad to write back')
        # If using CPU offload, _cpu_grad will store the CPU tensor of _gpu_grad
        # They should be released here
        self._gpu_grad = None
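The three methods above were invoked from backward hooks of the removed ShardedModel. A minimal sketch of the intended call order during one backward pass; the hook wiring, the sharded param's ca_attr, and the reduce-scattered gradient are assumed here, not shown in the diff:

# Illustrative only.
sharded_grad = ShardedGradient(param, sharded_module, offload_config={'device': 'cpu'})

sharded_grad.setup()                                # pre-backward hook: stash any locally accumulated grad
# ... autograd runs; the produced gradient is reduce-scattered across data-parallel ranks ...
sharded_grad.reduce_scatter_callback(reduced_grad)  # post-backward hook: accumulate the FP32 shard gradient
sharded_grad.write_back()                           # final backward hook: expose the result via param.grad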
colossalai/zero/sharded_model/sharded_model.py (deleted, 100644 → 0)

This diff is collapsed in the page view; the file is removed in its entirety (1104 lines).
colossalai/zero/sharded_model/sharded_model_v2.py

@@ -18,8 +18,8 @@ from colossalai.zero.sharded_model.reduce_scatter import ReduceScatterBucketer
 from torch.distributed import ProcessGroup
 from torch.nn.parameter import Parameter

-from ._zero3_utils import (cast_float_arguments, cast_tensor_to_fp16, cast_tensor_to_fp32, chunk_and_pad, free_storage,
-                           get_gradient_predivide_factor)
+from ._utils import (cast_float_arguments, cast_tensor_to_fp16, cast_tensor_to_fp32, chunk_and_pad, free_storage,
+                     get_gradient_predivide_factor)


 class ShardedModelV2(nn.Module):
colossalai/zero/sharded_optim/__init__.py

-from .sharded_optim import ShardedOptimizer
 from .sharded_optim_v2 import ShardedOptimizerV2

-__all__ = ['ShardedOptimizer', 'ShardedOptimizerV2']
+__all__ = ['ShardedOptimizerV2']
colossalai/zero/sharded_optim/bookkeeping/__init__.py (deleted, 100644 → 0)

from .gradient_store import GradientStore
from .parameter_store import ParameterStore
from .bucket_store import BucketStore
from .tensor_bucket import TensorBucket

__all__ = ['GradientStore', 'ParameterStore', 'BucketStore', 'TensorBucket']
\ No newline at end of file
colossalai/zero/sharded_optim/bookkeeping/base_store.py (deleted, 100644 → 0)

from colossalai.core import global_context as gpc
from colossalai.context import ParallelMode


class BaseStore:

    def __init__(self, dp_parallel_mode=ParallelMode.DATA):
        self._world_size = gpc.get_world_size(dp_parallel_mode)
        self._local_rank = gpc.get_local_rank(dp_parallel_mode)

    @property
    def world_size(self):
        return self._world_size

    @property
    def local_rank(self):
        return self._local_rank
colossalai/zero/sharded_optim/bookkeeping/bucket_store.py (deleted, 100644 → 0)

from colossalai.core import global_context as gpc
from colossalai.context import ParallelMode
from .base_store import BaseStore


class BucketStore(BaseStore):

    def __init__(self, dp_parallel_mode):
        super().__init__(dp_parallel_mode)
        self._grads = dict()
        self._params = dict()
        self._num_elements_in_bucket = dict()

        self.reset()

    def num_elements_in_bucket(self, reduce_rank: int = None):
        return self._num_elements_in_bucket[reduce_rank]

    def add_num_elements_in_bucket(self, num_elements, reduce_rank: int = None):
        self._num_elements_in_bucket[reduce_rank] += num_elements

    def add_grad(self, tensor, reduce_rank: int = None):
        self._grads[reduce_rank].append(tensor)

    def add_param(self, tensor, reduce_rank: int = None):
        self._params[reduce_rank].append(tensor)

    def reset(self):
        keys = [None] + list(range(self._world_size))
        self._grads = {rank: [] for rank in keys}
        self._params = {rank: [] for rank in keys}
        self._num_elements_in_bucket = {rank: 0 for rank in keys}

    def reset_by_rank(self, reduce_rank=None):
        self._grads[reduce_rank] = []
        self._params[reduce_rank] = []
        self._num_elements_in_bucket[reduce_rank] = 0

    def get_grad(self, reduce_rank: int = None):
        return self._grads[reduce_rank]

    def get_param(self, reduce_rank: int = None):
        return self._params[reduce_rank]
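For reference, the removed ShardedOptimizer used this store roughly as follows: gradients destined for one rank are queued until the bucket crosses a size threshold, then flushed and reduced. The threshold variable and the reduction step are assumptions here, not part of this file:

from colossalai.context import ParallelMode

# Requires colossalai's global context (gpc) to be initialized first.
bucket_store = BucketStore(ParallelMode.DATA)

bucket_store.add_grad(grad_tensor, reduce_rank=0)
bucket_store.add_num_elements_in_bucket(grad_tensor.numel(), reduce_rank=0)

if bucket_store.num_elements_in_bucket(reduce_rank=0) > bucket_size:   # bucket_size is assumed
    grads = bucket_store.get_grad(reduce_rank=0)   # flatten and reduce these tensors (not shown)
    bucket_store.reset_by_rank(reduce_rank=0)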
colossalai/zero/sharded_optim/bookkeeping/gradient_store.py (deleted, 100644 → 0)

from typing import List
from torch import Tensor
from .base_store import BaseStore


class GradientStore(BaseStore):

    def __init__(self, *args):
        super().__init__(*args)
        # bookkeeping data structures
        self._averaged_gradients = dict()

        # for backward reduction hooks
        self._grad_acc_objs = []

    def add_accumulate_grad_object(self, obj):
        """
        Keep :class:`AccumulateGrad` objects. If these objects are not kept, reduction hooks may not
        be attached successfully.

        :param obj: An object of :class:`AccumulateGrad` class
        :type obj: :class:`AccumulateGrad`
        """
        self._grad_acc_objs.append(obj)

    def get_averaged_gradients_by_group(self, group_id: int) -> List[Tensor]:
        """
        Return average gradients of a parameter group

        :param group_id: The index of parameter group
        :type group_id: int

        :return: Return the list of averaged gradients of a parameter group. Each element is a gradient, not a parameter.
        :rtype: List[torch.Tensor]
        """
        return self._averaged_gradients[group_id]

    def add_average_gradient_by_group(self, group_id: int, tensor: Tensor) -> None:
        """
        Append an average gradient to the list of averaged gradients of a parameter group

        :param group_id: The index of a parameter group
        :param tensor: A :class:`torch.Tensor` object
        :type group_id: int
        :type tensor: torch.Tensor
        """
        if group_id in self._averaged_gradients:
            self._averaged_gradients[group_id].append(tensor)
        else:
            self._averaged_gradients[group_id] = [tensor]

    def reset_average_gradients_by_group(self, group_id: int) -> None:
        """
        Reset the bookkeeping data structure for averaged gradients to an empty list

        :param group_id: The index of a parameter group
        :type group_id: int
        """
        self._averaged_gradients[group_id] = []
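A short hedged sketch of the bookkeeping pattern: reduced gradients are accumulated per parameter group and later consumed by the optimizer step; avg_grad and the surrounding loop are assumptions, not part of this file:

from colossalai.context import ParallelMode

gradient_store = GradientStore(ParallelMode.DATA)

gradient_store.add_average_gradient_by_group(group_id=0, tensor=avg_grad)
grads = gradient_store.get_averaged_gradients_by_group(group_id=0)   # handed to the optimizer step
gradient_store.reset_average_gradients_by_group(group_id=0)          # cleared after the step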
colossalai/zero/sharded_optim/bookkeeping/parameter_store.py (deleted, 100644 → 0)

from .base_store import BaseStore
from torch import Tensor
from typing import List


class ParameterStore(BaseStore):

    def __init__(self, dp_paralle_mode):
        super().__init__(dp_paralle_mode)
        # param partitioning data structures
        self._fp16_param_to_rank = dict()
        self._rank_groupid_to_fp16_param_list = dict()
        self._rank_group_id_to_flat_fp16_param = dict()

        # param reduction data structures
        self._is_param_reduced = dict()
        self._reduced_param = []

    def set_param_to_rank(self, tensor: Tensor, rank: int) -> None:
        """
        Set the mapping between parameter to rank, each parameter should be owned by a rank.

        :param tensor: A :class:`torch.Tensor` object
        :type tensor: torch.Tensor
        :param rank: The rank of which the process is responsible for updating the parameter
        :type rank: int
        """
        self._fp16_param_to_rank[tensor] = rank

    def get_param_rank(self, tensor: Tensor) -> int:
        """
        Gives the rank which the parameter belongs to

        :param tensor: A :class:`torch.Tensor` object
        :type tensor: torch.Tensor
        """
        return self._fp16_param_to_rank[tensor]

    def belongs_to_current_rank(self, tensor) -> bool:
        """
        Check whether a parameter is supposed to be updated by the process of the current rank

        :param tensor: A :class:`torch.Tensor` object
        :type tensor: torch.Tensor

        :return: True if the parameter should be updated by the current rank. Otherwise false.
        :rtype: bool
        """
        tensor_rank = self._fp16_param_to_rank[tensor]
        return tensor_rank == self._local_rank

    def add_fp16_param_list_by_rank_group(self, rank, group_id, tensor_list) -> None:
        if rank not in self._rank_groupid_to_fp16_param_list:
            self._rank_groupid_to_fp16_param_list[rank] = dict()

        if group_id not in self._rank_groupid_to_fp16_param_list[rank]:
            self._rank_groupid_to_fp16_param_list[rank][group_id] = []

        self._rank_groupid_to_fp16_param_list[rank][group_id].extend(tensor_list)

    def get_fp16_params_by_rank_group(self, rank, group_id) -> List[Tensor]:
        return self._rank_groupid_to_fp16_param_list[rank][group_id]

    def add_flat_fp16_param_by_rank_group(self, rank, group_id, tensor) -> None:
        if rank not in self._rank_group_id_to_flat_fp16_param:
            self._rank_group_id_to_flat_fp16_param[rank] = dict()

        self._rank_group_id_to_flat_fp16_param[rank][group_id] = tensor

    def get_flat_fp16_param_by_rank_group(self, rank, group_id) -> Tensor:
        return self._rank_group_id_to_flat_fp16_param[rank][group_id]

    def is_param_reduced(self, tensor):
        return self._is_param_reduced[tensor]

    def set_param_reduction_state(self, tensor, state):
        self._is_param_reduced[tensor] = state

    def get_param_reduction_states(self):
        return self._is_param_reduced

    def reset_previous_reduced_params(self):
        self._reduced_param = []

    def add_previous_reduced_param(self, tensor):
        self._reduced_param.append(tensor)

    def clear_grads_of_previous_reduced_params(self):
        if len(self._reduced_param) > 0:
            for param in self._reduced_param:
                param.grad = None
            self.reset_previous_reduced_params()
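A hedged sketch of the ownership bookkeeping this store provided to the removed ShardedOptimizer; fp16_param and the rank value are illustrative, not from the diff:

from colossalai.context import ParallelMode

parameter_store = ParameterStore(ParallelMode.DATA)

# Each fp16 parameter is owned by exactly one data-parallel rank.
parameter_store.set_param_to_rank(fp16_param, rank=1)

if parameter_store.belongs_to_current_rank(fp16_param):
    pass  # only the owning rank updates this parameter in the optimizer step

parameter_store.set_param_reduction_state(fp16_param, True)   # mark its gradient as already reduced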
colossalai/zero/sharded_optim/bookkeeping/tensor_bucket.py (deleted, 100644 → 0)

from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors


class TensorBucket:

    def __init__(self, size):
        self._max_size = size
        self._current_size = 0
        self._bucket = []

    @property
    def max_size(self):
        return self._max_size

    @property
    def current_size(self):
        return self._current_size

    def is_full_or_oversized(self):
        return self._current_size >= self._max_size

    def is_empty(self):
        return len(self._bucket) == 0

    def add_to_bucket(self, tensor, allow_oversize=False):
        tensor_size = tensor.numel()

        if not allow_oversize and self.will_exceed_max_size(tensor_size):
            msg = f"The param bucket max size {self._max_size} is exceeded" \
                + f"by tensor (size {tensor_size})"
            raise RuntimeError(msg)

        self._bucket.append(tensor)
        self._current_size += tensor_size

    def will_exceed_max_size(self, tensor_size):
        expected_size = self._current_size + tensor_size
        return expected_size > self._max_size

    def get_bucket(self):
        return self._bucket

    def empty(self):
        self._bucket = []
        self._size = 0

    def flatten(self):
        return _flatten_dense_tensors(self._bucket)

    def unflatten_and_copy(self, flat_tensor):
        unflattened_tensor_list = _unflatten_dense_tensors(flat_tensor, self._bucket)
        for old, new in zip(self._bucket, unflattened_tensor_list):
            old.copy_(new)
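TensorBucket packs many small tensors into one flat buffer so a single collective can cover them, then copies the (possibly reduced) values back into the originals. A minimal round-trip sketch; the all-reduce itself is assumed and commented out:

import torch

bucket = TensorBucket(size=1024)                  # capacity measured in number of elements
for t in [torch.randn(128) for _ in range(4)]:
    if bucket.will_exceed_max_size(t.numel()):
        break
    bucket.add_to_bucket(t)

flat = bucket.flatten()                           # one contiguous tensor
# dist.all_reduce(flat)                           # assumed communication step
bucket.unflatten_and_copy(flat)                   # write values back into the original tensors
bucket.empty()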
colossalai/zero/sharded_optim/sharded_optim.py (deleted, 100644 → 0)

This diff is collapsed in the page view; the file is removed in its entirety (563 lines).
colossalai/zero/sharded_optim/sharded_optim_v2.py

@@ -10,7 +10,7 @@ from colossalai.core import global_context as gpc
 from colossalai.logging import get_dist_logger
 from colossalai.nn.optimizer import ColossalaiOptimizer
 from colossalai.zero.sharded_model import ShardedModelV2
-from colossalai.zero.sharded_model._zero3_utils import cast_tensor_to_fp32
+from colossalai.zero.sharded_model._utils import cast_tensor_to_fp32
 from torch import Tensor
 from torch.distributed import ProcessGroup
 from torch.nn.parameter import Parameter