OpenDAS / ColossalAI / Commits

Unverified commit 7675366f
Authored Mar 31, 2022 by Jiarui Fang; committed by GitHub on Mar 31, 2022

[polish] rename col_attr -> colo_attr (#558)

Parent: 2c45efc3

Showing 9 changed files with 91 additions and 91 deletions (+91 / -91)
colossalai/engine/ophooks/zero_hook.py                   +18 -18
colossalai/utils/memory_tracer/model_data_memtracer.py    +2  -2
colossalai/zero/init_ctx/init_context.py                  +5  -5
colossalai/zero/sharded_model/sharded_model_v2.py        +29 -29
colossalai/zero/sharded_model/utils.py                    +5  -5
colossalai/zero/sharded_optim/sharded_optim_v2.py        +18 -18
tests/test_moe/test_moe_zero_init.py                      +7  -7
tests/test_zero_data_parallel/common.py                   +2  -2
tests/test_zero_data_parallel/test_init_context.py        +5  -5
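The change itself is purely mechanical: every read or write of the ZeRO metadata handle moves from param.col_attr to param.colo_attr, with no behavioural change (hence the symmetric +91/-91 count). For scripts that must run against ColossalAI builds from both before and after this commit, a small compatibility shim could look like the sketch below; get_zero_attr is a hypothetical helper, not part of this commit or the library.

    # Hypothetical helper (not in this commit): resolve the ZeRO metadata
    # attribute under either name so downstream code survives the rename.
    def get_zero_attr(param):
        for name in ('colo_attr', 'col_attr'):   # new name first, old name as fallback
            if hasattr(param, name):
                return getattr(param, name)
        raise AttributeError('parameter was not initialized under ZeroInitContext')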
colossalai/engine/ophooks/zero_hook.py

@@ -35,58 +35,58 @@ class ZeroHook(BaseOpHook):

     def pre_fwd_exec(self, module: torch.nn.Module, *args):
         tensor_list = []
         for param in module.parameters(recurse=False):
-            assert hasattr(param, 'col_attr')
-            tensor_list.append(param.col_attr.sharded_data_tensor)
+            assert hasattr(param, 'colo_attr')
+            tensor_list.append(param.colo_attr.sharded_data_tensor)
         self.shard_strategy.gather(tensor_list, self.process_group)
         for param in module.parameters(recurse=False):
-            colo_model_data_tensor_move_inline(param.col_attr.sharded_data_tensor, self.computing_device)
-            param.data = param.col_attr.sharded_data_tensor.payload
+            colo_model_data_tensor_move_inline(param.colo_attr.sharded_data_tensor, self.computing_device)
+            param.data = param.colo_attr.sharded_data_tensor.payload
         if self._memstarts_collector:
             self._memstarts_collector.sample_memstats()
         for param in module.parameters(recurse=False):
-            param.col_attr.sharded_data_tensor.trans_state(TensorState.COMPUTE)
+            param.colo_attr.sharded_data_tensor.trans_state(TensorState.COMPUTE)

     def post_fwd_exec(self, module: torch.nn.Module, *args):
         for param in module.parameters(recurse=False):
-            param.col_attr.sharded_data_tensor.trans_state(TensorState.HOLD_AFTER_FWD)
+            param.colo_attr.sharded_data_tensor.trans_state(TensorState.HOLD_AFTER_FWD)
         tensor_list = []
         for param in module.parameters(recurse=False):
-            assert hasattr(param, 'col_attr')
-            tensor_list.append(param.col_attr.sharded_data_tensor)
+            assert hasattr(param, 'colo_attr')
+            tensor_list.append(param.colo_attr.sharded_data_tensor)
         self.shard_strategy.shard(tensor_list, self.process_group)
         for param in module.parameters(recurse=False):
-            param.col_attr.remove_torch_payload()
+            param.colo_attr.remove_torch_payload()

     def pre_bwd_exec(self, module: torch.nn.Module, input, output):
         tensor_list = []
         for param in module.parameters(recurse=False):
-            assert hasattr(param, 'col_attr')
-            tensor_list.append(param.col_attr.sharded_data_tensor)
+            assert hasattr(param, 'colo_attr')
+            tensor_list.append(param.colo_attr.sharded_data_tensor)
         self.shard_strategy.gather(tensor_list, self.process_group)
         for param in module.parameters(recurse=False):
-            colo_model_data_tensor_move_inline(param.col_attr.sharded_data_tensor, self.computing_device)
-            param.data = param.col_attr.sharded_data_tensor.payload
+            colo_model_data_tensor_move_inline(param.colo_attr.sharded_data_tensor, self.computing_device)
+            param.data = param.colo_attr.sharded_data_tensor.payload
         if self._memstarts_collector:
             self._memstarts_collector.sample_memstats()
         for param in module.parameters(recurse=False):
-            param.col_attr.sharded_data_tensor.trans_state(TensorState.COMPUTE)
+            param.colo_attr.sharded_data_tensor.trans_state(TensorState.COMPUTE)

     def post_bwd_exec(self, module: torch.nn.Module, input):
         for param in module.parameters(recurse=False):
-            param.col_attr.sharded_data_tensor.trans_state(TensorState.HOLD_AFTER_BWD)
+            param.colo_attr.sharded_data_tensor.trans_state(TensorState.HOLD_AFTER_BWD)
         tensor_list = []
         for param in module.parameters(recurse=False):
-            assert hasattr(param, 'col_attr')
-            tensor_list.append(param.col_attr.sharded_data_tensor)
+            assert hasattr(param, 'colo_attr')
+            tensor_list.append(param.colo_attr.sharded_data_tensor)
         self.shard_strategy.shard(tensor_list, self.process_group)
         for param in module.parameters(recurse=False):
-            param.col_attr.remove_torch_payload()
+            param.colo_attr.remove_torch_payload()

     def pre_iter(self):
         pass
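Viewed together, the four hook methods above walk each parameter's sharded data tensor through a gather -> compute -> re-shard cycle on both the forward and backward passes. A minimal illustrative sketch of that state sequence follows; the TensorState names and their ordering are taken from the diff, but the enum definition here is an assumption made for illustration, not ColossalAI's actual implementation.

    from enum import Enum, auto

    class TensorState(Enum):
        # state names as used by ZeroHook above; concrete values are illustrative
        HOLD = auto()
        COMPUTE = auto()
        HOLD_AFTER_FWD = auto()
        HOLD_AFTER_BWD = auto()

    # order in which the hook moves a parameter shard through one iteration
    lifecycle = [
        TensorState.HOLD,            # set by ShardedModelV2 before the forward pass
        TensorState.COMPUTE,         # pre_fwd_exec: gathered and used for forward
        TensorState.HOLD_AFTER_FWD,  # post_fwd_exec: re-sharded, torch payload removed
        TensorState.COMPUTE,         # pre_bwd_exec: gathered again for backward
        TensorState.HOLD_AFTER_BWD,  # post_bwd_exec: re-sharded after backward
    ]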
colossalai/utils/memory_tracer/model_data_memtracer.py

@@ -45,8 +45,8 @@ def colo_model_mem_usage(model: torch.nn.Module) -> Tuple[int, int]:
     cuda_mem_usage = 0
     cpu_mem_usage = 0
     for param in model.parameters():
-        if hasattr(param, 'col_attr'):
-            t_cuda, t_cpu = param.col_attr.get_memory_usage()
+        if hasattr(param, 'colo_attr'):
+            t_cuda, t_cpu = param.colo_attr.get_memory_usage()
             cuda_mem_usage += t_cuda
             cpu_mem_usage += t_cpu
         else:
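As the hunk header shows, colo_model_mem_usage walks the parameters that carry colo_attr and returns a (CUDA bytes, CPU bytes) pair. A typical call, mirroring its use in tests/test_zero_data_parallel/test_init_context.py further down, is sketched below; the import path matches the file shown here but is stated as an assumption.

    from colossalai.utils.memory_tracer.model_data_memtracer import colo_model_mem_usage

    # model: a torch.nn.Module whose parameters were initialized under ZeroInitContext
    cuda_mem_use, cpu_mem_use = colo_model_mem_usage(model)
    print(f'model data: {cuda_mem_use / 1e6:.1f} MB CUDA, {cpu_mem_use / 1e6:.1f} MB CPU')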
colossalai/zero/init_ctx/init_context.py

@@ -162,8 +162,8 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
         """
         if not self.rm_torch_payload_on_the_fly:
             for param in self.initialized_param_list:
-                assert hasattr(param, 'col_attr')
-                param.col_attr.remove_torch_payload()
+                assert hasattr(param, 'colo_attr')
+                param.colo_attr.remove_torch_payload()
             del self.initialized_param_list

@@ -178,7 +178,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
         for param in module.parameters(recurse=False):
             # avoid adapting a param to ShardedParam twice
-            if hasattr(param, 'col_attr'):
+            if hasattr(param, 'colo_attr'):
                 continue
             self.model_numel_tensor += param.numel()

@@ -196,10 +196,10 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
             if param.grad is not None:
                 param.grad = param.grad.to(target_device)
-            param.col_attr = ShardedParamV2(param, rm_torch_payload=self.rm_torch_payload_on_the_fly)
+            param.colo_attr = ShardedParamV2(param, rm_torch_payload=self.rm_torch_payload_on_the_fly)
             if self.shard_param:
-                self.shard_strategy.shard([param.col_attr.sharded_data_tensor], self.dp_process_group)
+                self.shard_strategy.shard([param.colo_attr.sharded_data_tensor], self.dp_process_group)
             self.initialized_param_list.append(param)
         # We must cast buffers
colossalai/zero/sharded_model/sharded_model_v2.py

@@ -70,9 +70,9 @@ class ShardedModelV2(nn.Module):
         sharded = []
         unsharded = []
         for param in module.parameters():
-            assert hasattr(param, 'col_attr'), 'You must use ZeroInitContext to init your module first.'
-            sharded.append(param.col_attr.param_is_sharded)
-            unsharded.append(not param.col_attr.param_is_sharded)
+            assert hasattr(param, 'colo_attr'), 'You must use ZeroInitContext to init your module first.'
+            sharded.append(param.colo_attr.param_is_sharded)
+            unsharded.append(not param.colo_attr.param_is_sharded)
         assert all(sharded) or all(
             unsharded), 'Parameters must be all sharded or all unsharded! Parameters are partially sharded now.'
         self.shard_param = all(sharded)

@@ -103,7 +103,7 @@ class ShardedModelV2(nn.Module):
         self._cpu_offload: bool = offload_config.get('device', None) == 'cpu' if offload_config else False
         for param in module.parameters():
             # Init `offload_grad`
-            param.col_attr.offload_grad = self._cpu_offload
+            param.colo_attr.offload_grad = self._cpu_offload
         # We find if gradient_predivide_factor != 1.0, there may be wrong precision problem
         # So we use 1.0 as the default gradient_predivide_factor

@@ -162,13 +162,13 @@ class ShardedModelV2(nn.Module):
             self._memstats_collector.start_collection()
         for p in self.module.parameters():
-            if hasattr(p, 'col_attr'):
-                p.col_attr.sharded_data_tensor.trans_state(TensorState.HOLD)
+            if hasattr(p, 'colo_attr'):
+                p.colo_attr.sharded_data_tensor.trans_state(TensorState.HOLD)

     def _post_forward_operations(self):
         for p in self.module.parameters():
-            if hasattr(p, 'col_attr'):
-                p.col_attr.sharded_data_tensor.trans_state(TensorState.HOLD)
+            if hasattr(p, 'colo_attr'):
+                p.colo_attr.sharded_data_tensor.trans_state(TensorState.HOLD)

     def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
         self._pre_forward_operations()

@@ -228,10 +228,10 @@ class ShardedModelV2(nn.Module):
         if self.shard_param:
             tensor_list = []
             for p in self.module.parameters():
-                if not p.col_attr.param_is_sharded:
-                    tensor_list.append(p.col_attr.sharded_data_tensor)
-                    p.col_attr.sharded_data_tensor.trans_state(TensorState.HOLD_AFTER_BWD)
-                    p.col_attr.remove_torch_payload()
+                if not p.colo_attr.param_is_sharded:
+                    tensor_list.append(p.colo_attr.sharded_data_tensor)
+                    p.colo_attr.sharded_data_tensor.trans_state(TensorState.HOLD_AFTER_BWD)
+                    p.colo_attr.remove_torch_payload()
             self.shard_strategy.shard(tensor_list, self.process_group)
         # 4. move sharded param grad payload to param.grad

@@ -245,27 +245,27 @@ class ShardedModelV2(nn.Module):
             # We also allows to interleave no-sync pass with sync passes, if desired.
             if not self._require_backward_grad_sync:
                 continue
-            # Reduced grad is saved in `p.col_attr.saved_grad`
+            # Reduced grad is saved in `p.colo_attr.saved_grad`
             # It can be on CPU or CUDA
             # It can be fp16 or fp32
             # We set `p.grad` to None here and ShardedOptimizer will prepare `p.grad` before `step()`.
             if self.reuse_fp16_shard:
-                grad_fp16_payload = p.col_attr.sharded_data_tensor.payload
+                grad_fp16_payload = p.colo_attr.sharded_data_tensor.payload
             else:
-                grad_fp16_payload = cast_tensor_to_fp32(p.col_attr.fp16_grad.payload)
+                grad_fp16_payload = cast_tensor_to_fp32(p.colo_attr.fp16_grad.payload)
             assert isinstance(grad_fp16_payload, torch.Tensor)
-            if p.col_attr.offload_grad:
+            if p.colo_attr.offload_grad:
                 colo_model_data_move_to_cpu(grad_fp16_payload)
-            if not p.col_attr.saved_grad.is_null():
+            if not p.colo_attr.saved_grad.is_null():
                 assert not self.reuse_fp16_shard, 'Gradien accumulation is not supported when reuse_fp16_shard=True'
                 # Accumulate grad, saved grad must be fp32
-                p.col_attr.saved_grad.reset_payload(cast_tensor_to_fp32(p.col_attr.saved_grad.payload))
-                p.col_attr.saved_grad.payload.add_(grad_fp16_payload.view_as(p.col_attr.saved_grad.payload))
+                p.colo_attr.saved_grad.reset_payload(cast_tensor_to_fp32(p.colo_attr.saved_grad.payload))
+                p.colo_attr.saved_grad.payload.add_(grad_fp16_payload.view_as(p.colo_attr.saved_grad.payload))
             else:
-                p.col_attr.saved_grad.reset_payload(grad_fp16_payload)
+                p.colo_attr.saved_grad.reset_payload(grad_fp16_payload)

             p.grad = None
-            p.col_attr.fp16_grad.set_null()
+            p.colo_attr.fp16_grad.set_null()

     @torch.no_grad()
     def _grad_post_backward_hook(self, param: Parameter, grad: torch.Tensor) -> Optional[torch.Tensor]:

@@ -273,7 +273,7 @@ class ShardedModelV2(nn.Module):
         At the start of :func:`_grad_post_backward_hook`, ``param.grad`` contains the
         full gradient for the local batch. The reduce-scatter op will save
         a single shard of the summed gradient across all
-        GPUs to param.col_attr.grad. This shard will align with the current GPU rank. For example::
+        GPUs to param.colo_attr.grad. This shard will align with the current GPU rank. For example::
             before reduce_scatter:
                 param.grad (GPU #0): [1, 2, 3, 4]

@@ -285,7 +285,7 @@ class ShardedModelV2(nn.Module):
         The local GPU's ``optim.step`` is responsible for updating a single
         shard of params, also corresponding to the current GPU's rank. This
-        alignment is created by `param.col_attr.grad`, which ensures that
+        alignment is created by `param.colo_attr.grad`, which ensures that
         the local optimizer only sees the relevant parameter shard.
         """
         if grad is None:

@@ -323,20 +323,20 @@ class ShardedModelV2(nn.Module):
             # Average grad by world_size for consistency with PyTorch DDP.
             reduced_grad.data.div_(self.gradient_postdivide_factor)
         if self.reuse_fp16_shard:
-            param.col_attr.sharded_data_tensor.reset_payload(reduced_grad.data)
-            param.col_attr.sharded_data_tensor.is_sharded = True
+            param.colo_attr.sharded_data_tensor.reset_payload(reduced_grad.data)
+            param.colo_attr.sharded_data_tensor.is_sharded = True
         else:
-            param.col_attr.fp16_grad = StatefulTensor(reduced_grad.data)
+            param.colo_attr.fp16_grad = StatefulTensor(reduced_grad.data)

     def state_dict(self, destination=None, prefix='', keep_vars=False) -> 'OrderedDict[str, torch.Tensor]':
-        self.shard_strategy.gather([p.col_attr.sharded_data_tensor for p in self.module.parameters()],
+        self.shard_strategy.gather([p.colo_attr.sharded_data_tensor for p in self.module.parameters()],
                                    self.process_group)
         prev_params = {}
         for p in self.module.parameters():
             prev_params[p] = p.data
-            p.data = p.col_attr.sharded_data_tensor.payload
+            p.data = p.colo_attr.sharded_data_tensor.payload
         gathered_state_dict = self.module.state_dict(destination, prefix, keep_vars)
-        self.shard_strategy.shard([p.col_attr.sharded_data_tensor for p in self.module.parameters()],
+        self.shard_strategy.shard([p.colo_attr.sharded_data_tensor for p in self.module.parameters()],
                                   self.process_group)
         for p in self.module.parameters():
             p.data = prev_params[p]
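The first hunk of this file encodes an all-or-nothing invariant: every parameter's colo_attr must agree on param_is_sharded, otherwise construction fails. Pulled out of the constructor, the check amounts to the following sketch (an illustration, not the library's API):

    def check_shard_consistency(module):
        # every param must carry colo_attr (set by ZeroInitContext) and agree on sharding
        flags = []
        for param in module.parameters():
            assert hasattr(param, 'colo_attr'), 'You must use ZeroInitContext to init your module first.'
            flags.append(param.colo_attr.param_is_sharded)
        assert all(flags) or not any(flags), \
            'Parameters must be all sharded or all unsharded! Parameters are partially sharded now.'
        return all(flags)   # True: params are sharded; False: full fp16 params (ZeRO-2 style)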
colossalai/zero/sharded_model/utils.py

@@ -10,10 +10,10 @@ def col_model_deepcopy(sharded_model: ShardedModelV2, other_model: torch.nn.Modu
     Note the other_model has to be the same as self.
     """
     for zero_param, param in zip(sharded_model.parameters(), other_model.parameters()):
-        assert hasattr(zero_param, 'col_attr')
-        shard_flag = zero_param.col_attr.sharded_data_tensor.is_sharded
+        assert hasattr(zero_param, 'colo_attr')
+        shard_flag = zero_param.colo_attr.sharded_data_tensor.is_sharded
         if shard_flag:
-            sharded_model.shard_strategy.gather([zero_param.col_attr.sharded_data_tensor])
-        param.data = copy.deepcopy(zero_param.col_attr.sharded_data_tensor.payload)
+            sharded_model.shard_strategy.gather([zero_param.colo_attr.sharded_data_tensor])
+        param.data = copy.deepcopy(zero_param.colo_attr.sharded_data_tensor.payload)
         if shard_flag:
-            sharded_model.shard_strategy.shard([zero_param.col_attr.sharded_data_tensor])
+            sharded_model.shard_strategy.shard([zero_param.colo_attr.sharded_data_tensor])
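col_model_deepcopy, whose signature appears in the hunk header, gathers each sharded payload, deep-copies it into a plain torch module of identical structure, and re-shards. A typical call would be along these lines; the import path is inferred from the file location and should be treated as an assumption.

    from colossalai.zero.sharded_model.utils import col_model_deepcopy

    # sharded_model: a ShardedModelV2; reference_model: an identically-structured torch.nn.Module
    col_model_deepcopy(sharded_model, reference_model)   # gathers, copies payloads, re-shards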
colossalai/zero/sharded_optim/sharded_optim_v2.py

@@ -116,18 +116,18 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
         for group in self.optim.param_groups:
             for p in group['params']:
-                assert hasattr(p, 'col_attr'), 'The parameter must be wrapped with ShardedParam'
-                is_param_sharded = p.col_attr.sharded_data_tensor.is_sharded
+                assert hasattr(p, 'colo_attr'), 'The parameter must be wrapped with ShardedParam'
+                is_param_sharded = p.colo_attr.sharded_data_tensor.is_sharded
                 if not is_param_sharded:
                     # TODO (ver217): we may not use shard / gather here
                     # Param is no sharded, which means we use ZeRO-2 here
                     # As we only store param shard, we shard it here
-                    self.shard_strategy.shard([p.col_attr.sharded_data_tensor], self.dp_process_group)
-                self.master_params[p] = cast_tensor_to_fp32(p.col_attr.sharded_data_tensor.payload).to(self.device)
+                    self.shard_strategy.shard([p.colo_attr.sharded_data_tensor], self.dp_process_group)
+                self.master_params[p] = cast_tensor_to_fp32(p.colo_attr.sharded_data_tensor.payload).to(self.device)
                 if not is_param_sharded:
                     # In this branch, there's no need to shard param
                     # So we gather here
-                    self.shard_strategy.gather([p.col_attr.sharded_data_tensor], self.dp_process_group)
+                    self.shard_strategy.gather([p.colo_attr.sharded_data_tensor], self.dp_process_group)
         self._logger.debug(f"After init ShardedOptimizerV2 consumes {self.get_memory_usage()[0] / 1e6} MB CUDA Memory!",
                            ranks=[0])

@@ -201,30 +201,30 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
         self._logger.debug(
             f"After step ShardedOptimizerV2 consumes {self.get_memory_usage()[0] / 1e6} MB CUDA Memory, {self.get_memory_usage()[1] / 1e6} MB CUDA Memory!",
             ranks=[0])

-        # Copy master param data (fp32) to payload of col_attr (fp16)
+        # Copy master param data (fp32) to payload of colo_attr (fp16)
         # TODO() improve efficiency by gathering tensors into a chunk and transfering
         # a chunk.
         for group in self.optim.param_groups:
             for p in group['params']:
-                is_param_sharded = p.col_attr.sharded_data_tensor.is_sharded
+                is_param_sharded = p.colo_attr.sharded_data_tensor.is_sharded
                 if not is_param_sharded:
                     # We use ZeRO-2 here
-                    # The `p.col_attr.sharded_data_tensor` saves full fp16 param
+                    # The `p.colo_attr.sharded_data_tensor` saves full fp16 param
                     # But we only have updated fp32 param shard here
                     # So we first shard full fp16 param and copy fp32 param shard to it
                     # Then we will gather them
-                    self.shard_strategy.shard([p.col_attr.sharded_data_tensor], self.dp_process_group)
+                    self.shard_strategy.shard([p.colo_attr.sharded_data_tensor], self.dp_process_group)
                 # We have to use `copy_payload` instead of `reset_payload`
-                # Since p.data is fp32 and p.col_attr.sharded_data_tensor is fp16
+                # Since p.data is fp32 and p.colo_attr.sharded_data_tensor is fp16
                 # TODO() optimize this line CPU (fp32) -> GPU (fp16)
-                p.col_attr.sharded_data_tensor.reset_payload(
+                p.colo_attr.sharded_data_tensor.reset_payload(
                     colo_model_tensor_clone(p.half(), torch.cuda.current_device()))
                 if not is_param_sharded:
                     # We gather full fp16 param here
-                    self.shard_strategy.gather([p.col_attr.sharded_data_tensor], self.dp_process_group)
-                p.data = p.col_attr.sharded_data_tensor.payload
+                    self.shard_strategy.gather([p.colo_attr.sharded_data_tensor], self.dp_process_group)
+                p.data = p.colo_attr.sharded_data_tensor.payload
         return ret

     def backward(self, loss: Tensor) -> None:

@@ -292,7 +292,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
                 if fp32_shards_used_cuda_margin_mem + shard_mem < fp32_shards_available_cuda_margin_mem:
                     self.master_params[p] = self.master_params[p].to(torch.cuda.current_device())
                     p.grad.data = p.grad.data.to(torch.cuda.current_device())
-                    p.col_attr.offload_grad = False
+                    p.colo_attr.offload_grad = False
                     fp32_shards_used_cuda_margin_mem += shard_mem

     def _prepare_grads(self):

@@ -301,7 +301,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
             # FIXME(ver217): p.data here is an empty tensor on CUDA and has no useful infomation
             # If we change p.grad directly
             # it may raise error because of different shape/dtype/device of p.data and p.grad
-            # We just set p.data = p.col_attr.saved_grad.payload here
-            p.data = p.col_attr.saved_grad.payload
-            p.grad = p.col_attr.saved_grad.payload
-            p.col_attr.saved_grad.set_null()
+            # We just set p.data = p.colo_attr.saved_grad.payload here
+            p.data = p.colo_attr.saved_grad.payload
+            p.grad = p.colo_attr.saved_grad.payload
+            p.colo_attr.saved_grad.set_null()
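The step() hunk above writes the updated fp32 master data back into the fp16 shard behind colo_attr via p.half() and colo_model_tensor_clone, matching how the fp32 master copies were built at init time with cast_tensor_to_fp32. Stripped of the sharding bookkeeping, the dtype round-trip it relies on is plain PyTorch, sketched below (not ColossalAI code):

    import torch

    master = torch.randn(4, dtype=torch.float32)   # fp32 master copy held by the optimizer
    shard_fp16 = master.half()                     # downcast written into the fp16 payload
    restored = shard_fp16.float()                  # fp32 view recovered for the next update
    assert shard_fp16.dtype == torch.float16 and restored.dtype == torch.float32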
tests/test_moe/test_moe_zero_init.py

@@ -61,22 +61,22 @@ def run_moe_zero_init(init_device_type, shard_strategy_class):
         model = MoeModel()

     for name, param in model.named_parameters():
-        assert hasattr(param, 'col_attr')
+        assert hasattr(param, 'colo_attr')

         # the weights in the gate should be fp32
         if 'gate' in name:
-            assert param.col_attr.sharded_data_tensor.dtype == torch.float32
+            assert param.colo_attr.sharded_data_tensor.dtype == torch.float32
         else:
-            assert param.col_attr.sharded_data_tensor.dtype == torch.half
+            assert param.colo_attr.sharded_data_tensor.dtype == torch.half

         # the parameters in moe experts and its gate should not be sharded
         if ('experts' in name) or ('gate' in name) or ('residual_combine' in name):
-            assert not param.col_attr.sharded_data_tensor.is_sharded
+            assert not param.colo_attr.sharded_data_tensor.is_sharded
         else:
-            assert param.col_attr.sharded_data_tensor.is_sharded
+            assert param.colo_attr.sharded_data_tensor.is_sharded

-        assert param.col_attr.sharded_data_tensor.payload.device.type == init_device.type, \
-            f'{param.col_attr.sharded_data_tensor.payload.device.type} vs. {init_device.type}'
+        assert param.colo_attr.sharded_data_tensor.payload.device.type == init_device.type, \
+            f'{param.colo_attr.sharded_data_tensor.payload.device.type} vs. {init_device.type}'


 def _run_dist(rank, world_size, port):
tests/test_zero_data_parallel/common.py

@@ -93,7 +93,7 @@ def check_grads_padding(model, zero_model, loose=False):
     rank = dist.get_rank()
     for p, zero_p in zip(model.parameters(), zero_model.parameters()):
         # zero_grad = zero_p.grad.clone().to(p.device)
-        zero_grad = zero_p.col_attr.saved_grad.payload.clone().to(p.device)
+        zero_grad = zero_p.colo_attr.saved_grad.payload.clone().to(p.device)
         chunks = torch.flatten(p.grad).chunk(dist.get_world_size())
         if rank >= len(chunks):
             continue

@@ -124,7 +124,7 @@ def check_sharded_model_params(model, zero_model, loose=False, reuse_fp16_shard=
         if reuse_fp16_shard:
             zero_p = zero_p.data.to(p.device).float()
         else:
-            zero_p = zero_p.col_attr.sharded_data_tensor.payload.to(p.device).float()
+            zero_p = zero_p.colo_attr.sharded_data_tensor.payload.to(p.device).float()
         chunks = torch.flatten(p).chunk(dist.get_world_size())
         if rank >= len(chunks):
             continue
tests/test_zero_data_parallel/test_init_context.py

@@ -45,11 +45,11 @@ def run_model_test(init_device_type, shard_strategy_class):
             model = model_builder(checkpoint=True)

         for param in model.parameters():
-            assert hasattr(param, 'col_attr')
-            assert param.col_attr.sharded_data_tensor.dtype == torch.half
-            assert param.col_attr.sharded_data_tensor.is_sharded
-            assert param.col_attr.sharded_data_tensor.payload.device.type == init_device.type, \
-                f'{param.col_attr.sharded_data_tensor.payload.device.type} vs. {init_device.type}'
+            assert hasattr(param, 'colo_attr')
+            assert param.colo_attr.sharded_data_tensor.dtype == torch.half
+            assert param.colo_attr.sharded_data_tensor.is_sharded
+            assert param.colo_attr.sharded_data_tensor.payload.device.type == init_device.type, \
+                f'{param.colo_attr.sharded_data_tensor.payload.device.type} vs. {init_device.type}'

         cuda_mem_use, cpu_mem_use = colo_model_mem_usage(model)
         model_data_cuda_mem_MB = cuda_mem_use / 1e6