Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
4d90a7b5
Unverified
Commit
4d90a7b5
authored
Apr 11, 2022
by
Jiarui Fang
Committed by
GitHub
Apr 11, 2022
Browse files
[refactor] zero directory (#724)
parent
20ab1f55
Changes
19
Hide whitespace changes
Inline
Side-by-side
Showing
19 changed files
with
196 additions
and
223 deletions
+196
-223
colossalai/amp/__init__.py
colossalai/amp/__init__.py
+2
-0
colossalai/engine/ophooks/__init__.py
colossalai/engine/ophooks/__init__.py
+2
-118
colossalai/engine/ophooks/_base_ophook.py
colossalai/engine/ophooks/_base_ophook.py
+0
-30
colossalai/engine/ophooks/utils.py
colossalai/engine/ophooks/utils.py
+142
-0
colossalai/zero/shard_utils/__init__.py
colossalai/zero/shard_utils/__init__.py
+1
-2
colossalai/zero/shard_utils/tensor_shard_strategy.py
colossalai/zero/shard_utils/tensor_shard_strategy.py
+1
-1
colossalai/zero/sharded_model/sharded_model_v2.py
colossalai/zero/sharded_model/sharded_model_v2.py
+3
-4
colossalai/zero/sharded_optim/_utils.py
colossalai/zero/sharded_optim/_utils.py
+15
-50
colossalai/zero/sharded_optim/sharded_optim_v2.py
colossalai/zero/sharded_optim/sharded_optim_v2.py
+2
-2
colossalai/zero/sharded_param/__init__.py
colossalai/zero/sharded_param/__init__.py
+8
-1
colossalai/zero/sharded_param/sharded_param.py
colossalai/zero/sharded_param/sharded_param.py
+1
-1
colossalai/zero/sharded_param/tensor_utils.py
colossalai/zero/sharded_param/tensor_utils.py
+0
-0
colossalai/zero/utils/__init__.py
colossalai/zero/utils/__init__.py
+4
-0
colossalai/zero/utils/stateful_tensor_mgr.py
colossalai/zero/utils/stateful_tensor_mgr.py
+1
-1
colossalai/zero/utils/zero_hook.py
colossalai/zero/utils/zero_hook.py
+5
-4
tests/test_utils/test_commons.py
tests/test_utils/test_commons.py
+1
-1
tests/test_utils/test_tensor_move.py
tests/test_utils/test_tensor_move.py
+4
-3
tests/test_zero_data_parallel/test_found_inf.py
tests/test_zero_data_parallel/test_found_inf.py
+3
-4
tests/test_zero_data_parallel/test_stateful_tensor_mgr.py
tests/test_zero_data_parallel/test_stateful_tensor_mgr.py
+1
-1
No files found.
colossalai/amp/__init__.py
View file @
4d90a7b5
...
...
@@ -10,6 +10,8 @@ from .torch_amp import convert_to_torch_amp
from
.apex_amp
import
convert_to_apex_amp
from
.naive_amp
import
convert_to_naive_amp
__all__
=
[
'convert_to_amp'
,
'convert_to_naive_amp'
,
'convert_to_apex_amp'
,
'convert_to_torch_amp'
,
'AMP_TYPE'
]
def
convert_to_amp
(
model
:
nn
.
Module
,
optimizer
:
Optimizer
,
criterion
:
_Loss
,
mode
:
AMP_TYPE
,
amp_config
:
Config
=
None
):
"""A helper function to wrap training components with Torch AMP modules.
...
...
colossalai/engine/ophooks/__init__.py
View file @
4d90a7b5
from
typing
import
List
,
Callable
,
Optional
from
.utils
import
register_ophooks_recursively
,
BaseOpHook
import
torch
from
._base_ophook
import
BaseOpHook
from
._memtracer_ophook
import
MemTracerOpHook
from
._shard_grad_ophook
import
ShardGradHook
from
._shard_param_ophook
import
ShardParamHook
all
=
[
"BaseOpHook"
,
"MemTracerOpHook"
,
"register_ophooks_recursively"
,
"ShardParamHook"
,
"ShardGradHook"
]
# apply torch.autograd.Function that calls a backward_function to tensors in output
def
_apply_to_tensors_only
(
module
,
functional
,
backward_function
,
outputs
):
if
type
(
outputs
)
is
tuple
:
touched_outputs
=
[]
for
output
in
outputs
:
touched_output
=
_apply_to_tensors_only
(
module
,
functional
,
backward_function
,
output
)
touched_outputs
.
append
(
touched_output
)
return
tuple
(
touched_outputs
)
elif
type
(
outputs
)
is
torch
.
Tensor
:
return
functional
.
apply
(
module
,
backward_function
,
outputs
)
else
:
return
outputs
class
PreBackwardFunction
(
torch
.
autograd
.
Function
):
@
staticmethod
def
forward
(
ctx
,
module
,
pre_backward_function
,
outputs
):
ctx
.
module
=
module
ctx
.
pre_backward_function
=
pre_backward_function
module
.
applied_pre_backward
=
False
outputs
=
outputs
.
detach
()
return
outputs
@
staticmethod
def
backward
(
ctx
,
*
args
):
ctx
.
pre_backward_function
(
ctx
.
module
)
return
(
None
,
None
)
+
args
class
PostBackwardFunction
(
torch
.
autograd
.
Function
):
@
staticmethod
def
forward
(
ctx
,
module
,
pre_backward_function
,
output
):
ctx
.
module
=
module
output
=
output
.
detach
()
ctx
.
pre_backward_function
=
pre_backward_function
return
output
@
staticmethod
def
backward
(
ctx
,
*
args
):
"""
Args:
activation_grad of the next layer.
Returns:
grad of the input activation.
"""
ctx
.
pre_backward_function
(
ctx
.
module
)
return
(
None
,
None
)
+
args
def
register_ophooks_recursively
(
module
:
torch
.
nn
.
Module
,
ophook_list
:
List
[
BaseOpHook
]
=
None
,
name
:
str
=
""
,
filter_fn
:
Optional
[
Callable
]
=
None
):
r
"""Recursilvely register pre/post hooks for all submodules in the module in FWD and BWD."""
assert
isinstance
(
module
,
torch
.
nn
.
Module
)
# Add hooks for submodules
for
child_name
,
child
in
module
.
named_children
():
register_ophooks_recursively
(
child
,
ophook_list
,
name
+
child_name
,
filter_fn
)
# Early return on modules with no parameters.
if
len
(
list
(
module
.
parameters
(
recurse
=
False
)))
==
0
:
return
# return early from a filtered module
if
filter_fn
is
not
None
and
filter_fn
(
module
):
return
if
ophook_list
is
not
None
:
for
hook
in
ophook_list
:
assert
(
isinstance
(
hook
,
BaseOpHook
))
def
_pre_forward_module_hook
(
submodule
,
*
args
):
for
hook
in
ophook_list
:
assert
isinstance
(
submodule
,
torch
.
nn
.
Module
)
hook
.
pre_fwd_exec
(
submodule
,
*
args
)
def
_post_forward_module_hook
(
submodule
,
*
args
):
for
hook
in
ophook_list
:
assert
isinstance
(
submodule
,
torch
.
nn
.
Module
)
hook
.
post_fwd_exec
(
submodule
,
*
args
)
def
_pre_backward_module_hook
(
submodule
,
inputs
,
output
):
def
_run_before_backward_function
(
submodule
):
for
hook
in
ophook_list
:
assert
isinstance
(
submodule
,
torch
.
nn
.
Module
)
hook
.
pre_bwd_exec
(
submodule
,
inputs
,
output
)
return
_apply_to_tensors_only
(
submodule
,
PreBackwardFunction
,
_run_before_backward_function
,
output
)
def
_post_backward_module_hook
(
submodule
,
inputs
):
def
_run_after_backward_function
(
submodule
):
for
hook
in
ophook_list
:
assert
isinstance
(
submodule
,
torch
.
nn
.
Module
)
hook
.
post_bwd_exec
(
submodule
,
inputs
)
return
_apply_to_tensors_only
(
submodule
,
PostBackwardFunction
,
_run_after_backward_function
,
inputs
)
module
.
register_forward_pre_hook
(
_pre_forward_module_hook
)
module
.
register_forward_hook
(
_post_forward_module_hook
)
module
.
register_forward_hook
(
_pre_backward_module_hook
)
module
.
register_forward_pre_hook
(
_post_backward_module_hook
)
__all__
=
[
"BaseOpHook"
,
"MemTracerOpHook"
,
"register_ophooks_recursively"
]
colossalai/engine/ophooks/_base_ophook.py
deleted
100644 → 0
View file @
20ab1f55
from
abc
import
ABC
,
abstractmethod
import
torch
class
BaseOpHook
(
ABC
):
"""This class allows users to add customized operations
before and after the execution of a PyTorch submodule"""
def
__init__
(
self
):
pass
@
abstractmethod
def
pre_fwd_exec
(
self
,
module
:
torch
.
nn
.
Module
,
*
args
):
pass
@
abstractmethod
def
post_fwd_exec
(
self
,
module
:
torch
.
nn
.
Module
,
*
args
):
pass
@
abstractmethod
def
pre_bwd_exec
(
self
,
module
:
torch
.
nn
.
Module
,
input
,
output
):
pass
@
abstractmethod
def
post_bwd_exec
(
self
,
module
:
torch
.
nn
.
Module
,
input
):
pass
@
abstractmethod
def
post_iter
(
self
):
pass
colossalai/engine/ophooks/utils.py
0 → 100644
View file @
4d90a7b5
import
torch
from
typing
import
List
,
Callable
,
Optional
from
abc
import
ABC
,
abstractmethod
import
torch
class BaseOpHook(ABC):
    """Abstract interface for operator hooks run around a PyTorch submodule.

    ``register_ophooks_recursively`` (same file) invokes the four ``*_exec``
    methods on every registered hook around the forward and backward
    execution of each submodule it instruments. Subclasses must implement
    every abstract method below, otherwise they cannot be instantiated.
    """

    def __init__(self):
        pass

    @abstractmethod
    def pre_fwd_exec(self, module: torch.nn.Module, *args):
        """Called from a forward pre-hook on ``module``.

        ``*args`` carries whatever PyTorch passes to a forward pre-hook after
        the module itself.
        """
        pass

    @abstractmethod
    def post_fwd_exec(self, module: torch.nn.Module, *args):
        """Called from a forward hook on ``module``.

        ``*args`` carries whatever PyTorch passes to a forward hook after the
        module itself.
        """
        pass

    @abstractmethod
    def pre_bwd_exec(self, module: torch.nn.Module, input, output):
        """Called when gradients reach the outputs of ``module`` during backward.

        ``input`` and ``output`` are the values captured by the forward hook
        that scheduled this call.
        """
        pass

    @abstractmethod
    def post_bwd_exec(self, module: torch.nn.Module, input):
        """Called when gradients reach the inputs of ``module`` during backward.

        ``input`` is the value captured by the forward pre-hook that scheduled
        this call.
        """
        pass

    @abstractmethod
    def post_iter(self):
        """End-of-iteration callback.

        NOTE(review): nothing in this file calls it; presumably the training
        engine does -- confirm against the caller.
        """
        pass
# apply torch.autograd.Function that calls a backward_function to tensors in output
def
_apply_to_tensors_only
(
module
,
functional
,
backward_function
,
outputs
):
if
type
(
outputs
)
is
tuple
:
touched_outputs
=
[]
for
output
in
outputs
:
touched_output
=
_apply_to_tensors_only
(
module
,
functional
,
backward_function
,
output
)
touched_outputs
.
append
(
touched_output
)
return
tuple
(
touched_outputs
)
elif
type
(
outputs
)
is
torch
.
Tensor
:
return
functional
.
apply
(
module
,
backward_function
,
outputs
)
else
:
return
outputs
class PreBackwardFunction(torch.autograd.Function):
    """Pass-through autograd node that runs a callback when backward reaches it.

    Applied to a module's outputs, so the callback runs before the gradients
    flow back into that module.
    """

    @staticmethod
    def forward(ctx, module, pre_backward_function, outputs):
        """Remember ``module`` and the callback, then return ``outputs`` detached.

        Also sets ``module.applied_pre_backward = False`` as a side effect.
        """
        ctx.module = module
        ctx.pre_backward_function = pre_backward_function
        module.applied_pre_backward = False
        outputs = outputs.detach()
        return outputs

    @staticmethod
    def backward(ctx, *args):
        """Run the stored callback on the module and forward the gradients unchanged."""
        ctx.pre_backward_function(ctx.module)
        # No gradient for the two non-tensor inputs (module, callback).
        return (None, None) + args
class PostBackwardFunction(torch.autograd.Function):
    """Pass-through autograd node that runs a callback when backward reaches it.

    Applied to a module's inputs, so the callback runs once the gradients have
    flowed back through that module.
    """

    @staticmethod
    def forward(ctx, module, pre_backward_function, output):
        """Remember ``module`` and the callback, then return ``output`` detached.

        NOTE: despite its name, ``pre_backward_function`` is the callback run
        *after* the module's backward; the name is kept for compatibility.
        """
        ctx.module = module
        output = output.detach()
        ctx.pre_backward_function = pre_backward_function
        return output

    @staticmethod
    def backward(ctx, *args):
        """
        Args:
            activation_grad of the next layer.
        Returns:
            grad of the input activation.
        """
        ctx.pre_backward_function(ctx.module)
        # No gradient for the two non-tensor inputs (module, callback).
        return (None, None) + args
def register_ophooks_recursively(module: torch.nn.Module,
                                 ophook_list: List[BaseOpHook] = None,
                                 name: str = "",
                                 filter_fn: Optional[Callable] = None):
    r"""Recursively register pre/post hooks for all submodules in the module in FWD and BWD.

    Children are visited first. A module then gets hooks only if it directly
    owns at least one parameter and is not rejected by ``filter_fn``.

    Args:
        module: root module to instrument.
        ophook_list: hooks to invoke; every element must be a ``BaseOpHook``.
            ``None`` is treated as an empty list.
        name: name prefix accumulated while recursing (child names are appended).
        filter_fn: optional predicate; modules for which it returns a truthy
            value are skipped (their children are still visited).
    """
    assert isinstance(module, torch.nn.Module)

    # The default ``None`` used to be captured by the closures below and made
    # the first forward pass fail with "'NoneType' object is not iterable".
    if ophook_list is None:
        ophook_list = []

    # Add hooks for submodules
    for child_name, child in module.named_children():
        register_ophooks_recursively(child, ophook_list, name + child_name, filter_fn)

    # Early return on modules with no parameters.
    if len(list(module.parameters(recurse=False))) == 0:
        return

    # Return early from a filtered module.
    if filter_fn is not None and filter_fn(module):
        return

    for hook in ophook_list:
        assert isinstance(hook, BaseOpHook)

    def _pre_forward_module_hook(submodule, *args):
        for hook in ophook_list:
            assert isinstance(submodule, torch.nn.Module)
            hook.pre_fwd_exec(submodule, *args)

    def _post_forward_module_hook(submodule, *args):
        for hook in ophook_list:
            assert isinstance(submodule, torch.nn.Module)
            hook.post_fwd_exec(submodule, *args)

    def _pre_backward_module_hook(submodule, inputs, output):

        def _run_before_backward_function(submodule):
            for hook in ophook_list:
                assert isinstance(submodule, torch.nn.Module)
                hook.pre_bwd_exec(submodule, inputs, output)

        # Wrapping the outputs makes the callback fire when backward reaches them.
        return _apply_to_tensors_only(submodule, PreBackwardFunction, _run_before_backward_function, output)

    def _post_backward_module_hook(submodule, inputs):

        def _run_after_backward_function(submodule):
            for hook in ophook_list:
                assert isinstance(submodule, torch.nn.Module)
                hook.post_bwd_exec(submodule, inputs)

        # Wrapping the inputs makes the callback fire after the module's own backward.
        return _apply_to_tensors_only(submodule, PostBackwardFunction, _run_after_backward_function, inputs)

    module.register_forward_pre_hook(_pre_forward_module_hook)
    module.register_forward_hook(_post_forward_module_hook)

    # The backward callbacks are installed through forward hooks: they wrap the
    # forward tensors in autograd functions whose backward() runs the callbacks.
    module.register_forward_hook(_pre_backward_module_hook)
    module.register_forward_pre_hook(_post_backward_module_hook)
colossalai/zero/shard_utils/__init__.py
View file @
4d90a7b5
from
.base_shard_strategy
import
BaseShardStrategy
from
.bucket_tensor_shard_strategy
import
BucketTensorShardStrategy
from
.tensor_shard_strategy
import
TensorShardStrategy
from
.stateful_tensor_mgr
import
StatefulTensorMgr
__all__
=
[
'BaseShardStrategy'
,
'TensorShardStrategy'
,
'BucketTensorShardStrategy'
,
'StatefulTensorMgr'
]
__all__
=
[
'BaseShardStrategy'
,
'TensorShardStrategy'
,
'BucketTensorShardStrategy'
]
colossalai/zero/shard_utils/tensor_shard_strategy.py
View file @
4d90a7b5
...
...
@@ -3,7 +3,7 @@ from typing import List, Optional
import
torch
import
torch.distributed
as
dist
from
colossalai.utils
import
get_current_device
from
colossalai.zero.shard
_utils
.tensor_utils
import
colo_model_data_tensor_move_inline
from
colossalai.zero.shard
ed_param
.tensor_utils
import
colo_model_data_tensor_move_inline
from
colossalai.zero.shard_utils
import
BaseShardStrategy
from
colossalai.zero.shard_utils.commons
import
get_shard
from
colossalai.zero.sharded_param.sharded_tensor
import
ShardedTensor
...
...
colossalai/zero/sharded_model/sharded_model_v2.py
View file @
4d90a7b5
...
...
@@ -8,9 +8,8 @@ import torch.nn as nn
from
colossalai.context.parallel_mode
import
ParallelMode
from
colossalai.core
import
global_context
as
gpc
from
colossalai.engine.ophooks
import
register_ophooks_recursively
from
colossalai.
engine.ophooks.zero_hook
import
ZeroHook
from
colossalai.
zero.utils
import
ZeroHook
from
colossalai.engine.paramhooks
import
BaseParamHookMgr
from
colossalai.engine.gradient_handler.utils
import
bucket_allreduce
from
colossalai.logging
import
get_dist_logger
from
colossalai.utils
import
get_current_device
,
disposable
from
colossalai.utils.memory_tracer.memstats_collector
import
MemStatsCollector
...
...
@@ -18,12 +17,12 @@ from colossalai.utils.memory_tracer.model_data_memtracer import \
GLOBAL_MODEL_DATA_TRACER
from
colossalai.utils.memory
import
colo_device_memory_capacity
from
colossalai.zero.shard_utils
import
BaseShardStrategy
from
colossalai.zero.shard
_utils
.tensor_utils
import
colo_model_data_move_to_cpu
from
colossalai.zero.shard
ed_param
.tensor_utils
import
colo_model_data_move_to_cpu
from
colossalai.zero.sharded_model.reduce_scatter
import
ReduceScatterBucketer
from
colossalai.zero.sharded_param.tensorful_state
import
TensorState
from
torch.distributed
import
ProcessGroup
from
torch.nn.parameter
import
Parameter
from
colossalai.zero.
shard_
utils.stateful_tensor_mgr
import
StatefulTensorMgr
from
colossalai.zero.utils.stateful_tensor_mgr
import
StatefulTensorMgr
from
._utils
import
(
cast_float_arguments
,
cast_tensor_to_fp16
,
cast_tensor_to_fp32
,
chunk_and_pad
,
free_storage
,
get_gradient_predivide_factor
)
...
...
colossalai/zero/sharded_optim/_utils.py
View file @
4d90a7b5
...
...
@@ -8,22 +8,6 @@ from colossalai.utils import is_model_parallel_parameter
import
torch.distributed
as
dist
def
move_tensor
(
input_
,
device
):
assert
device
in
[
'cpu'
,
'gpu'
]
if
isinstance
(
input_
,
(
list
,
tuple
)):
for
tensor
in
input_
:
tensor
.
data
=
tensor
.
data
.
cpu
(
)
if
device
==
'cpu'
else
tensor
.
data
.
cuda
()
elif
torch
.
is_tensor
(
input_
):
input_
.
data
=
input_
.
data
.
cpu
(
)
if
device
==
'cpu'
else
tensor
.
data
.
cuda
()
else
:
raise
TypeError
(
f
"Expected argument 'input_' to be torch.Tensor, list or tuple, but got
{
type
(
input_
)
}
"
)
def flatten(input_):
    """Return ``_flatten_dense_tensors(input_)``.

    NOTE(review): ``_flatten_dense_tensors`` is imported outside the visible
    hunk; presumably it is ``torch._utils._flatten_dense_tensors`` -- confirm.
    """
    return _flatten_dense_tensors(input_)
...
...
@@ -51,8 +35,7 @@ def shuffle_by_round_robin(tensor_list, num_partitions):
partition_to_go
=
tensor_idx
%
num_partitions
if
partition_to_go
not
in
partitions
:
partitions
[
partition_to_go
]
=
[]
partitions
[
partition_to_go
].
append
(
dict
(
tensor
=
tensor
,
index
=
tensor_idx
))
partitions
[
partition_to_go
].
append
(
dict
(
tensor
=
tensor
,
index
=
tensor_idx
))
partitions_count
=
len
(
partitions
)
new_tensor_list
=
[]
...
...
@@ -73,9 +56,7 @@ def flatten_dense_tensors_with_padding(tensor_list, unit_size):
padding
=
calculate_padding
(
num_elements
,
unit_size
=
unit_size
)
if
padding
>
0
:
pad_tensor
=
torch
.
zeros
(
padding
,
device
=
tensor_list
[
0
].
device
,
dtype
=
tensor_list
[
0
].
dtype
)
pad_tensor
=
torch
.
zeros
(
padding
,
device
=
tensor_list
[
0
].
device
,
dtype
=
tensor_list
[
0
].
dtype
)
padded_tensor_list
=
tensor_list
+
[
pad_tensor
]
else
:
padded_tensor_list
=
tensor_list
...
...
@@ -86,6 +67,7 @@ def flatten_dense_tensors_with_padding(tensor_list, unit_size):
def is_nccl_aligned(tensor):
    """Return True when ``tensor.data_ptr()`` is a multiple of 4."""
    return tensor.data_ptr() % 4 == 0
def
get_grad_accumulate_object
(
tensor
):
"""
Return the AccumulateGrad of the input tensor
...
...
@@ -108,10 +90,7 @@ def get_grad_accumulate_object(tensor):
def
split_half_float_double
(
tensor_list
):
dtypes
=
[
"torch.cuda.HalfTensor"
,
"torch.cuda.FloatTensor"
,
"torch.cuda.DoubleTensor"
,
"torch.cuda.BFloat16Tensor"
]
dtypes
=
[
"torch.cuda.HalfTensor"
,
"torch.cuda.FloatTensor"
,
"torch.cuda.DoubleTensor"
,
"torch.cuda.BFloat16Tensor"
]
buckets
=
[]
for
i
,
dtype
in
enumerate
(
dtypes
):
bucket
=
[
t
for
t
in
tensor_list
if
t
.
type
()
==
dtype
]
...
...
@@ -120,10 +99,7 @@ def split_half_float_double(tensor_list):
return
buckets
def
reduce_tensor
(
tensor
,
dtype
,
dst_rank
=
None
,
parallel_mode
=
ParallelMode
.
DATA
):
def
reduce_tensor
(
tensor
,
dtype
,
dst_rank
=
None
,
parallel_mode
=
ParallelMode
.
DATA
):
"""
Reduce the tensor in the data parallel process group
...
...
@@ -165,6 +141,7 @@ def reduce_tensor(tensor,
tensor
.
copy_
(
tensor_to_reduce
)
return
tensor
def
has_inf_or_nan
(
tensor
):
try
:
# if tensor is half, the .float() incurs an additional deep copy, but it's necessary if
...
...
@@ -181,8 +158,7 @@ def has_inf_or_nan(tensor):
raise
return
True
else
:
if
tensor_sum
==
float
(
'inf'
)
or
tensor_sum
==
-
float
(
'inf'
)
or
tensor_sum
!=
tensor_sum
:
if
tensor_sum
==
float
(
'inf'
)
or
tensor_sum
==
-
float
(
'inf'
)
or
tensor_sum
!=
tensor_sum
:
return
True
return
False
...
...
@@ -201,11 +177,7 @@ def calculate_global_norm_from_list(norm_list):
return
math
.
sqrt
(
total_norm
)
def
compute_norm
(
gradients
,
params
,
dp_group
,
mp_group
,
norm_type
=
2
):
def
compute_norm
(
gradients
,
params
,
dp_group
,
mp_group
,
norm_type
=
2
):
"""Clips gradient norm of an iterable of parameters.
This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and
added functionality to handle model parallel parameters. Note that
...
...
@@ -229,14 +201,11 @@ def compute_norm(gradients,
if
norm_type
==
inf
:
total_norm
=
max
(
g
.
data
.
abs
().
max
()
for
g
in
gradients
)
total_norm_cuda
=
torch
.
cuda
.
FloatTensor
([
float
(
total_norm
)])
dist
.
all_reduce
(
total_norm_cuda
,
op
=
torch
.
distributed
.
ReduceOp
.
MAX
,
group
=
dp_group
)
dist
.
all_reduce
(
total_norm_cuda
,
op
=
torch
.
distributed
.
ReduceOp
.
MAX
,
group
=
dp_group
)
# Take max across all GPUs.
if
mp_group
is
not
None
:
dist
.
all_reduce
(
tensor
=
total_norm_cuda
,
op
=
torch
.
distributed
.
ReduceOp
.
MAX
)
dist
.
all_reduce
(
tensor
=
total_norm_cuda
,
op
=
torch
.
distributed
.
ReduceOp
.
MAX
)
total_norm
=
total_norm_cuda
[
0
].
item
()
else
:
total_norm
=
0.0
...
...
@@ -248,21 +217,17 @@ def compute_norm(gradients,
if
is_model_parallel_parameter
(
p
)
or
mp_rank
==
0
:
param_norm
=
g
.
data
.
double
().
norm
(
2
)
total_norm
+=
param_norm
.
item
()
**
2
# Sum across all model parallel GPUs.
total_norm_cuda
=
torch
.
cuda
.
FloatTensor
([
float
(
total_norm
)])
torch
.
distributed
.
all_reduce
(
total_norm_cuda
,
op
=
torch
.
distributed
.
ReduceOp
.
SUM
,
group
=
dp_group
)
torch
.
distributed
.
all_reduce
(
total_norm_cuda
,
op
=
torch
.
distributed
.
ReduceOp
.
SUM
,
group
=
dp_group
)
if
mp_group
is
not
None
:
dist
.
all_reduce
(
tensor
=
total_norm_cuda
,
op
=
torch
.
distributed
.
ReduceOp
.
SUM
)
dist
.
all_reduce
(
tensor
=
total_norm_cuda
,
op
=
torch
.
distributed
.
ReduceOp
.
SUM
)
total_norm
=
total_norm_cuda
[
0
].
item
()
**
(
1.
/
norm_type
)
if
total_norm
==
float
(
'inf'
)
or
total_norm
==
-
float
(
'inf'
)
or
total_norm
!=
total_norm
:
if
total_norm
==
float
(
'inf'
)
or
total_norm
==
-
float
(
'inf'
)
or
total_norm
!=
total_norm
:
total_norm
=
-
1
return
total_norm
...
...
colossalai/zero/sharded_optim/sharded_optim_v2.py
View file @
4d90a7b5
...
...
@@ -12,8 +12,8 @@ from colossalai.logging import get_dist_logger
from
colossalai.nn.optimizer
import
ColossalaiOptimizer
from
colossalai.utils.memory_tracer.model_data_memtracer
import
\
GLOBAL_MODEL_DATA_TRACER
from
colossalai.zero.shard
_utils
.tensor_utils
import
(
colo_model_data_tensor_move_inline
,
colo_model_tensor_clone
,
colo_tensor_mem_usage
)
from
colossalai.zero.shard
ed_param
.tensor_utils
import
(
colo_model_data_tensor_move_inline
,
colo_model_tensor_clone
,
colo_tensor_mem_usage
)
from
colossalai.zero.sharded_model
import
ShardedModelV2
from
colossalai.zero.sharded_model._utils
import
cast_tensor_to_fp32
from
colossalai.zero.sharded_optim._utils
import
has_inf_or_nan
...
...
colossalai/zero/sharded_param/__init__.py
View file @
4d90a7b5
from
colossalai.zero.sharded_param.sharded_tensor
import
ShardedTensor
from
colossalai.zero.sharded_param.sharded_param
import
ShardedParamV2
from
colossalai.zero.sharded_param.tensor_utils
import
(
colo_model_data_tensor_move
,
colo_model_data_tensor_move_inline
,
colo_model_data_move_to_cpu
,
colo_model_tensor_clone
,
colo_tensor_mem_usage
)
from
colossalai.zero.sharded_param.tensorful_state
import
TensorState
,
StatefulTensor
__all__
=
[
'ShardedTensor'
,
'ShardedParamV2'
]
__all__
=
[
'ShardedTensor'
,
'ShardedParamV2'
,
'colo_model_data_tensor_move'
,
'colo_model_data_tensor_move_inline'
,
'colo_model_data_move_to_cpu'
,
'colo_model_tensor_clone'
,
'colo_tensor_mem_usage'
,
'TensorState'
,
'StatefulTensor'
]
colossalai/zero/sharded_param/sharded_param.py
View file @
4d90a7b5
import
torch
from
colossalai.zero.sharded_param
import
ShardedTensor
from
typing
import
Optional
,
Tuple
from
colossalai.zero.shard
_utils
.tensor_utils
import
colo_tensor_mem_usage
from
colossalai.zero.shard
ed_param
.tensor_utils
import
colo_tensor_mem_usage
from
.tensorful_state
import
StatefulTensor
,
TensorState
from
typing
import
List
...
...
colossalai/zero/shard
_utils
/tensor_utils.py
→
colossalai/zero/shard
ed_param
/tensor_utils.py
View file @
4d90a7b5
File moved
colossalai/zero/utils/__init__.py
0 → 100644
View file @
4d90a7b5
from
.stateful_tensor_mgr
import
StatefulTensorMgr
from
.zero_hook
import
ZeroHook
__all__
=
[
'StatefulTensorMgr'
,
'ZeroHook'
]
\ No newline at end of file
colossalai/zero/
shard_
utils/stateful_tensor_mgr.py
→
colossalai/zero/utils/stateful_tensor_mgr.py
View file @
4d90a7b5
...
...
@@ -4,7 +4,7 @@ import types
from
colossalai.utils.cuda
import
get_current_device
from
colossalai.zero.sharded_param.sharded_param
import
ShardedParamV2
from
colossalai.zero.sharded_param.tensorful_state
import
StatefulTensor
,
TensorState
from
colossalai.zero.shard
_utils
.tensor_utils
import
colo_model_data_tensor_move_inline
,
colo_tensor_mem_usage
from
colossalai.zero.shard
ed_param
.tensor_utils
import
colo_model_data_tensor_move_inline
,
colo_tensor_mem_usage
from
colossalai.utils.memory
import
colo_device_memory_capacity
from
colossalai.utils.memory_tracer.model_data_memtracer
import
GLOBAL_MODEL_DATA_TRACER
from
typing
import
Dict
,
List
...
...
colossalai/
engine/ophook
s/zero_hook.py
→
colossalai/
zero/util
s/zero_hook.py
View file @
4d90a7b5
...
...
@@ -3,15 +3,16 @@ from typing import Optional
import
torch
import
torch.distributed
as
dist
from
colossalai.registry
import
OPHOOKS
from
colossalai.utils
import
get_current_device
from
colossalai.utils.memory_tracer.memstats_collector
import
MemStatsCollector
from
colossalai.zero.shard_utils
import
BaseShardStrategy
from
colossalai.zero.sharded_param.tensorful_state
import
TensorState
from
colossalai.zero.shard_utils.stateful_tensor_mgr
import
StatefulTensorMgr
from
._base_ophook
import
BaseOpHook
from
colossalai.zero.utils.stateful_tensor_mgr
import
StatefulTensorMgr
from
colossalai.zero.sharded_param.tensor_utils
import
colo_model_data_tensor_move_inline
from
colossalai.
zero.shard_utils.tensor_utils
import
colo_model_data_tensor_move_inline
from
colossalai.
engine.ophooks
import
BaseOpHook
@
OPHOOKS
.
register_module
...
...
tests/test_utils/test_commons.py
View file @
4d90a7b5
from
colossalai.zero.shard
_utils
.tensor_utils
import
colo_model_data_tensor_move
,
colo_model_data_tensor_move_inline
from
colossalai.zero.shard
ed_param
.tensor_utils
import
colo_model_data_tensor_move
,
colo_model_data_tensor_move_inline
from
colossalai.utils
import
free_port
from
colossalai.testing
import
rerun_on_exception
from
colossalai.zero.sharded_param
import
ShardedTensor
...
...
tests/test_utils/test_tensor_move.py
View file @
4d90a7b5
import
pytest
import
colossalai
from
colossalai.utils.cuda
import
get_current_device
from
colossalai.zero.shard_utils.tensor_utils
import
colo_tensor_mem_usage
,
colo_model_data_tensor_move
,
colo_model_data_tensor_move_inline
,
colo_model_data_move_to_cpu
,
colo_model_tensor_clone
from
colossalai.zero.sharded_param
import
(
StatefulTensor
,
colo_tensor_mem_usage
,
colo_model_data_tensor_move
,
colo_model_data_tensor_move_inline
,
colo_model_data_move_to_cpu
,
colo_model_tensor_clone
)
from
colossalai.utils.memory
import
colo_set_process_memory_fraction
,
colo_device_memory_capacity
from
colossalai.utils
import
free_port
from
colossalai.zero.sharded_param.tensorful_state
import
StatefulTensor
import
colossalai
import
torch
...
...
tests/test_zero_data_parallel/test_found_inf.py
View file @
4d90a7b5
...
...
@@ -30,10 +30,9 @@ def _run_test_found_inf(cpu_offload, shard_strategy_class, gpu_margin_mem_ratio)
get_components_func
=
non_distributed_component_funcs
.
get_callable
(
model_name
)
model_builder
,
train_dataloader
,
_
,
optimizer_class
,
criterion
=
get_components_func
()
with
ZeroInitContext
(
target_device
=
torch
.
device
(
f
'cpu:0'
)
if
cpu_offload
else
torch
.
device
(
get_current_device
()),
shard_strategy
=
shard_strategy
,
shard_param
=
True
):
with
ZeroInitContext
(
target_device
=
torch
.
device
(
f
'cpu:0'
)
if
cpu_offload
else
get_current_device
(),
shard_strategy
=
shard_strategy
,
shard_param
=
True
):
zero_model
=
model_builder
(
checkpoint
=
True
)
zero_model
=
ShardedModelV2
(
zero_model
,
...
...
tests/test_zero_data_parallel/test_stateful_tensor_mgr.py
View file @
4d90a7b5
...
...
@@ -6,7 +6,7 @@ from colossalai.utils.cuda import get_current_device
from
colossalai.utils.memory_tracer
import
MemStatsCollector
from
colossalai.utils.memory_tracer.model_data_memtracer
import
GLOBAL_MODEL_DATA_TRACER
from
colossalai.utils.memory
import
colo_device_memory_capacity
,
colo_set_process_memory_fraction
from
colossalai.zero.
shard_
utils
import
StatefulTensorMgr
from
colossalai.zero.utils
import
StatefulTensorMgr
from
colossalai.zero.sharded_param.sharded_param
import
ShardedParamV2
from
colossalai.zero.sharded_param.tensorful_state
import
TensorState
from
colossalai.utils
import
free_port
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment