OpenDAS / ColossalAI · Commits

Unverified commit 7675366f, authored Mar 31, 2022 by Jiarui Fang, committed via GitHub on Mar 31, 2022

[polish] rename col_attr -> colo_attr (#558)

Parent: 2c45efc3

Showing 9 changed files with 91 additions and 91 deletions (+91 −91):
colossalai/engine/ophooks/zero_hook.py                   +18 −18
colossalai/utils/memory_tracer/model_data_memtracer.py    +2  −2
colossalai/zero/init_ctx/init_context.py                  +5  −5
colossalai/zero/sharded_model/sharded_model_v2.py        +29 −29
colossalai/zero/sharded_model/utils.py                    +5  −5
colossalai/zero/sharded_optim/sharded_optim_v2.py        +18 −18
tests/test_moe/test_moe_zero_init.py                      +7  −7
tests/test_zero_data_parallel/common.py                   +2  −2
tests/test_zero_data_parallel/test_init_context.py        +5  −5
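Every hunk below is the same mechanical substitution, so the whole commit can be reproduced with a short script. A minimal sketch (the source roots and regex are assumptions for illustration, not part of the commit):

import re
from pathlib import Path

# Whole-word match: `col_attr` is not a substring of `colo_attr`,
# so the script is safe to re-run.
PATTERN = re.compile(r"\bcol_attr\b")

def rename_col_attr(roots=("colossalai", "tests")) -> None:
    for root in roots:
        for path in Path(root).rglob("*.py"):
            text = path.read_text()
            new_text = PATTERN.sub("colo_attr", text)
            if new_text != text:
                path.write_text(new_text)
                print(f"rewrote {path}")

if __name__ == "__main__":
    rename_col_attr()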
colossalai/engine/ophooks/zero_hook.py

@@ -35,58 +35,58 @@ class ZeroHook(BaseOpHook):

    def pre_fwd_exec(self, module: torch.nn.Module, *args):
        tensor_list = []
        for param in module.parameters(recurse=False):
-            assert hasattr(param, 'col_attr')
-            tensor_list.append(param.col_attr.sharded_data_tensor)
+            assert hasattr(param, 'colo_attr')
+            tensor_list.append(param.colo_attr.sharded_data_tensor)
        self.shard_strategy.gather(tensor_list, self.process_group)
        for param in module.parameters(recurse=False):
-            colo_model_data_tensor_move_inline(param.col_attr.sharded_data_tensor, self.computing_device)
-            param.data = param.col_attr.sharded_data_tensor.payload
+            colo_model_data_tensor_move_inline(param.colo_attr.sharded_data_tensor, self.computing_device)
+            param.data = param.colo_attr.sharded_data_tensor.payload

        if self._memstarts_collector:
            self._memstarts_collector.sample_memstats()

        for param in module.parameters(recurse=False):
-            param.col_attr.sharded_data_tensor.trans_state(TensorState.COMPUTE)
+            param.colo_attr.sharded_data_tensor.trans_state(TensorState.COMPUTE)

    def post_fwd_exec(self, module: torch.nn.Module, *args):
        for param in module.parameters(recurse=False):
-            param.col_attr.sharded_data_tensor.trans_state(TensorState.HOLD_AFTER_FWD)
+            param.colo_attr.sharded_data_tensor.trans_state(TensorState.HOLD_AFTER_FWD)
        tensor_list = []
        for param in module.parameters(recurse=False):
-            assert hasattr(param, 'col_attr')
-            tensor_list.append(param.col_attr.sharded_data_tensor)
+            assert hasattr(param, 'colo_attr')
+            tensor_list.append(param.colo_attr.sharded_data_tensor)
        self.shard_strategy.shard(tensor_list, self.process_group)
        for param in module.parameters(recurse=False):
-            param.col_attr.remove_torch_payload()
+            param.colo_attr.remove_torch_payload()

    def pre_bwd_exec(self, module: torch.nn.Module, input, output):
        tensor_list = []
        for param in module.parameters(recurse=False):
-            assert hasattr(param, 'col_attr')
-            tensor_list.append(param.col_attr.sharded_data_tensor)
+            assert hasattr(param, 'colo_attr')
+            tensor_list.append(param.colo_attr.sharded_data_tensor)
        self.shard_strategy.gather(tensor_list, self.process_group)
        for param in module.parameters(recurse=False):
-            colo_model_data_tensor_move_inline(param.col_attr.sharded_data_tensor, self.computing_device)
-            param.data = param.col_attr.sharded_data_tensor.payload
+            colo_model_data_tensor_move_inline(param.colo_attr.sharded_data_tensor, self.computing_device)
+            param.data = param.colo_attr.sharded_data_tensor.payload

        if self._memstarts_collector:
            self._memstarts_collector.sample_memstats()

        for param in module.parameters(recurse=False):
-            param.col_attr.sharded_data_tensor.trans_state(TensorState.COMPUTE)
+            param.colo_attr.sharded_data_tensor.trans_state(TensorState.COMPUTE)

    def post_bwd_exec(self, module: torch.nn.Module, input):
        for param in module.parameters(recurse=False):
-            param.col_attr.sharded_data_tensor.trans_state(TensorState.HOLD_AFTER_BWD)
+            param.colo_attr.sharded_data_tensor.trans_state(TensorState.HOLD_AFTER_BWD)
        tensor_list = []
        for param in module.parameters(recurse=False):
-            assert hasattr(param, 'col_attr')
-            tensor_list.append(param.col_attr.sharded_data_tensor)
+            assert hasattr(param, 'colo_attr')
+            tensor_list.append(param.colo_attr.sharded_data_tensor)
        self.shard_strategy.shard(tensor_list, self.process_group)
        for param in module.parameters(recurse=False):
-            param.col_attr.remove_torch_payload()
+            param.colo_attr.remove_torch_payload()

    def pre_iter(self):
        pass
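Each hook above repeats one three-step pattern: gather the shards the op needs, point `param.data` at the gathered payload, and mark the tensor as in use. A condensed sketch of that pattern, using only the attributes shown in the diff (the `TensorState` import path is an assumption based on this file's imports):

import torch
# Import path is an assumption for illustration.
from colossalai.zero.sharded_param.tensorful_state import TensorState

def gather_for_compute(hook, module: torch.nn.Module) -> None:
    # `hook` stands for a ZeroHook instance; shard_strategy and
    # process_group are the attributes used in the hunk above.
    tensors = [p.colo_attr.sharded_data_tensor for p in module.parameters(recurse=False)]
    hook.shard_strategy.gather(tensors, hook.process_group)
    for p in module.parameters(recurse=False):
        p.data = p.colo_attr.sharded_data_tensor.payload
        p.colo_attr.sharded_data_tensor.trans_state(TensorState.COMPUTE)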
colossalai/utils/memory_tracer/model_data_memtracer.py

@@ -45,8 +45,8 @@ def colo_model_mem_usage(model: torch.nn.Module) -> Tuple[int, int]:
    cuda_mem_usage = 0
    cpu_mem_usage = 0
    for param in model.parameters():
-        if hasattr(param, 'col_attr'):
-            t_cuda, t_cpu = param.col_attr.get_memory_usage()
+        if hasattr(param, 'colo_attr'):
+            t_cuda, t_cpu = param.colo_attr.get_memory_usage()
            cuda_mem_usage += t_cuda
            cpu_mem_usage += t_cpu
        else:
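`colo_model_mem_usage` walks every parameter and, when the `colo_attr` wrapper is present, accumulates its per-tensor accounting. A usage sketch (the model must have been built under ZeroInitContext so that `colo_attr` exists):

import torch.nn as nn
from colossalai.utils.memory_tracer.model_data_memtracer import colo_model_mem_usage

def report_model_mem(model: nn.Module) -> None:
    # Returns (cuda_bytes, cpu_bytes) summed over all parameters.
    cuda_bytes, cpu_bytes = colo_model_mem_usage(model)
    print(f"model data: {cuda_bytes / 1e6:.1f} MB CUDA, {cpu_bytes / 1e6:.1f} MB CPU")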
colossalai/zero/init_ctx/init_context.py

@@ -162,8 +162,8 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
        """
        if not self.rm_torch_payload_on_the_fly:
            for param in self.initialized_param_list:
-                assert hasattr(param, 'col_attr')
-                param.col_attr.remove_torch_payload()
+                assert hasattr(param, 'colo_attr')
+                param.colo_attr.remove_torch_payload()
            del self.initialized_param_list

@@ -178,7 +178,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
        for param in module.parameters(recurse=False):
            # avoid adapting a param to ShardedParam twice
-            if hasattr(param, 'col_attr'):
+            if hasattr(param, 'colo_attr'):
                continue

            self.model_numel_tensor += param.numel()

@@ -196,10 +196,10 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
            if param.grad is not None:
                param.grad = param.grad.to(target_device)

-            param.col_attr = ShardedParamV2(param, rm_torch_payload=self.rm_torch_payload_on_the_fly)
+            param.colo_attr = ShardedParamV2(param, rm_torch_payload=self.rm_torch_payload_on_the_fly)

            if self.shard_param:
-                self.shard_strategy.shard([param.col_attr.sharded_data_tensor], self.dp_process_group)
+                self.shard_strategy.shard([param.colo_attr.sharded_data_tensor], self.dp_process_group)
            self.initialized_param_list.append(param)

        # We must cast buffers
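The `_post_init_method` hunk above is what attaches `colo_attr` to each parameter, so the attribute exists on anything constructed inside the context. A minimal usage sketch; the exact constructor arguments are an assumption based on the tests in this commit, not spelled out by the hunk itself:

import torch
from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.shard_utils import TensorShardStrategy

# Constructor arguments are an assumption for illustration.
with ZeroInitContext(convert_fp16=True,
                     target_device=torch.device('cuda'),
                     shard_strategy=TensorShardStrategy(),
                     shard_param=True):
    model = torch.nn.Linear(64, 64)

# Every parameter built inside the context now carries the renamed attribute.
assert all(hasattr(p, 'colo_attr') for p in model.parameters())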
colossalai/zero/sharded_model/sharded_model_v2.py

@@ -70,9 +70,9 @@ class ShardedModelV2(nn.Module):
        sharded = []
        unsharded = []
        for param in module.parameters():
-            assert hasattr(param, 'col_attr'), 'You must use ZeroInitContext to init your module first.'
-            sharded.append(param.col_attr.param_is_sharded)
-            unsharded.append(not param.col_attr.param_is_sharded)
+            assert hasattr(param, 'colo_attr'), 'You must use ZeroInitContext to init your module first.'
+            sharded.append(param.colo_attr.param_is_sharded)
+            unsharded.append(not param.colo_attr.param_is_sharded)
        assert all(sharded) or all(unsharded), 'Parameters must be all sharded or all unsharded! Parameters are partially sharded now.'
        self.shard_param = all(sharded)

@@ -103,7 +103,7 @@ class ShardedModelV2(nn.Module):
        self._cpu_offload: bool = offload_config.get('device', None) == 'cpu' if offload_config else False
        for param in module.parameters():
            # Init `offload_grad`
-            param.col_attr.offload_grad = self._cpu_offload
+            param.colo_attr.offload_grad = self._cpu_offload

        # We find if gradient_predivide_factor != 1.0, there may be wrong precision problem
        # So we use 1.0 as the default gradient_predivide_factor

@@ -162,13 +162,13 @@ class ShardedModelV2(nn.Module):
            self._memstats_collector.start_collection()
        for p in self.module.parameters():
-            if hasattr(p, 'col_attr'):
-                p.col_attr.sharded_data_tensor.trans_state(TensorState.HOLD)
+            if hasattr(p, 'colo_attr'):
+                p.colo_attr.sharded_data_tensor.trans_state(TensorState.HOLD)

    def _post_forward_operations(self):
        for p in self.module.parameters():
-            if hasattr(p, 'col_attr'):
-                p.col_attr.sharded_data_tensor.trans_state(TensorState.HOLD)
+            if hasattr(p, 'colo_attr'):
+                p.colo_attr.sharded_data_tensor.trans_state(TensorState.HOLD)

    def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
        self._pre_forward_operations()

@@ -228,10 +228,10 @@ class ShardedModelV2(nn.Module):
        if self.shard_param:
            tensor_list = []
            for p in self.module.parameters():
-                if not p.col_attr.param_is_sharded:
-                    tensor_list.append(p.col_attr.sharded_data_tensor)
-                    p.col_attr.sharded_data_tensor.trans_state(TensorState.HOLD_AFTER_BWD)
-                    p.col_attr.remove_torch_payload()
+                if not p.colo_attr.param_is_sharded:
+                    tensor_list.append(p.colo_attr.sharded_data_tensor)
+                    p.colo_attr.sharded_data_tensor.trans_state(TensorState.HOLD_AFTER_BWD)
+                    p.colo_attr.remove_torch_payload()
            self.shard_strategy.shard(tensor_list, self.process_group)
        # 4. move sharded param grad payload to param.grad

@@ -245,27 +245,27 @@ class ShardedModelV2(nn.Module):
            # We also allows to interleave no-sync pass with sync passes, if desired.
            if not self._require_backward_grad_sync:
                continue
-            # Reduced grad is saved in `p.col_attr.saved_grad`
+            # Reduced grad is saved in `p.colo_attr.saved_grad`
            # It can be on CPU or CUDA
            # It can be fp16 or fp32
            # We set `p.grad` to None here and ShardedOptimizer will prepare `p.grad` before `step()`.
            if self.reuse_fp16_shard:
-                grad_fp16_payload = p.col_attr.sharded_data_tensor.payload
+                grad_fp16_payload = p.colo_attr.sharded_data_tensor.payload
            else:
-                grad_fp16_payload = cast_tensor_to_fp32(p.col_attr.fp16_grad.payload)
+                grad_fp16_payload = cast_tensor_to_fp32(p.colo_attr.fp16_grad.payload)
            assert isinstance(grad_fp16_payload, torch.Tensor)
-            if p.col_attr.offload_grad:
+            if p.colo_attr.offload_grad:
                colo_model_data_move_to_cpu(grad_fp16_payload)
-            if not p.col_attr.saved_grad.is_null():
+            if not p.colo_attr.saved_grad.is_null():
                assert not self.reuse_fp16_shard, 'Gradien accumulation is not supported when reuse_fp16_shard=True'
                # Accumulate grad, saved grad must be fp32
-                p.col_attr.saved_grad.reset_payload(cast_tensor_to_fp32(p.col_attr.saved_grad.payload))
-                p.col_attr.saved_grad.payload.add_(grad_fp16_payload.view_as(p.col_attr.saved_grad.payload))
+                p.colo_attr.saved_grad.reset_payload(cast_tensor_to_fp32(p.colo_attr.saved_grad.payload))
+                p.colo_attr.saved_grad.payload.add_(grad_fp16_payload.view_as(p.colo_attr.saved_grad.payload))
            else:
-                p.col_attr.saved_grad.reset_payload(grad_fp16_payload)
+                p.colo_attr.saved_grad.reset_payload(grad_fp16_payload)

            p.grad = None
-            p.col_attr.fp16_grad.set_null()
+            p.colo_attr.fp16_grad.set_null()

    @torch.no_grad()
    def _grad_post_backward_hook(self, param: Parameter, grad: torch.Tensor) -> Optional[torch.Tensor]:

@@ -273,7 +273,7 @@ class ShardedModelV2(nn.Module):
        At the start of :func:`_grad_post_backward_hook`, ``param.grad`` contains the
        full gradient for the local batch. The reduce-scatter op will save
        a single shard of the summed gradient across all
-        GPUs to param.col_attr.grad. This shard will align with the current GPU rank. For example::
+        GPUs to param.colo_attr.grad. This shard will align with the current GPU rank. For example::
            before reduce_scatter:
                param.grad (GPU #0): [1, 2, 3, 4]

@@ -285,7 +285,7 @@ class ShardedModelV2(nn.Module):
        The local GPU's ``optim.step`` is responsible for updating a single
        shard of params, also corresponding to the current GPU's rank. This
-        alignment is created by `param.col_attr.grad`, which ensures that
+        alignment is created by `param.colo_attr.grad`, which ensures that
        the local optimizer only sees the relevant parameter shard.
        """
        if grad is None:

@@ -323,20 +323,20 @@ class ShardedModelV2(nn.Module):
            # Average grad by world_size for consistency with PyTorch DDP.
            reduced_grad.data.div_(self.gradient_postdivide_factor)
        if self.reuse_fp16_shard:
-            param.col_attr.sharded_data_tensor.reset_payload(reduced_grad.data)
-            param.col_attr.sharded_data_tensor.is_sharded = True
+            param.colo_attr.sharded_data_tensor.reset_payload(reduced_grad.data)
+            param.colo_attr.sharded_data_tensor.is_sharded = True
        else:
-            param.col_attr.fp16_grad = StatefulTensor(reduced_grad.data)
+            param.colo_attr.fp16_grad = StatefulTensor(reduced_grad.data)

    def state_dict(self, destination=None, prefix='', keep_vars=False) -> 'OrderedDict[str, torch.Tensor]':
-        self.shard_strategy.gather([p.col_attr.sharded_data_tensor for p in self.module.parameters()],
+        self.shard_strategy.gather([p.colo_attr.sharded_data_tensor for p in self.module.parameters()],
                                   self.process_group)
        prev_params = {}
        for p in self.module.parameters():
            prev_params[p] = p.data
-            p.data = p.col_attr.sharded_data_tensor.payload
+            p.data = p.colo_attr.sharded_data_tensor.payload
        gathered_state_dict = self.module.state_dict(destination, prefix, keep_vars)
-        self.shard_strategy.shard([p.col_attr.sharded_data_tensor for p in self.module.parameters()],
+        self.shard_strategy.shard([p.colo_attr.sharded_data_tensor for p in self.module.parameters()],
                                  self.process_group)
        for p in self.module.parameters():
            p.data = prev_params[p]
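The `@@ -245` hunk shows why the saved gradient is promoted to fp32 before accumulation: adding many fp16 micro-batch gradients in fp16 loses precision. A plain-torch illustration of the same idea, independent of ColossalAI:

import torch

saved_grad = torch.zeros(4, dtype=torch.float32)   # fp32, as after cast_tensor_to_fp32
for _ in range(3):                                 # three micro-batches
    grad_fp16 = torch.full((4,), 0.1, dtype=torch.float16)
    # the fp16 shard is upcast during the in-place add,
    # mirroring saved_grad.payload.add_(grad_fp16_payload.view_as(...))
    saved_grad.add_(grad_fp16.view_as(saved_grad))
print(saved_grad)  # ~[0.3, 0.3, 0.3, 0.3], accumulated in fp32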
colossalai/zero/sharded_model/utils.py

@@ -10,10 +10,10 @@ def col_model_deepcopy(sharded_model: ShardedModelV2, other_model: torch.nn.Modu
    Note the other_model has to be the same as self.
    """
    for zero_param, param in zip(sharded_model.parameters(), other_model.parameters()):
-        assert hasattr(zero_param, 'col_attr')
-        shard_flag = zero_param.col_attr.sharded_data_tensor.is_sharded
+        assert hasattr(zero_param, 'colo_attr')
+        shard_flag = zero_param.colo_attr.sharded_data_tensor.is_sharded
        if shard_flag:
-            sharded_model.shard_strategy.gather([zero_param.col_attr.sharded_data_tensor])
-        param.data = copy.deepcopy(zero_param.col_attr.sharded_data_tensor.payload)
+            sharded_model.shard_strategy.gather([zero_param.colo_attr.sharded_data_tensor])
+        param.data = copy.deepcopy(zero_param.colo_attr.sharded_data_tensor.payload)
        if shard_flag:
-            sharded_model.shard_strategy.shard([zero_param.col_attr.sharded_data_tensor])
+            sharded_model.shard_strategy.shard([zero_param.colo_attr.sharded_data_tensor])
colossalai/zero/sharded_optim/sharded_optim_v2.py

@@ -116,18 +116,18 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
        for group in self.optim.param_groups:
            for p in group['params']:
-                assert hasattr(p, 'col_attr'), 'The parameter must be wrapped with ShardedParam'
-                is_param_sharded = p.col_attr.sharded_data_tensor.is_sharded
+                assert hasattr(p, 'colo_attr'), 'The parameter must be wrapped with ShardedParam'
+                is_param_sharded = p.colo_attr.sharded_data_tensor.is_sharded
                if not is_param_sharded:
                    # TODO (ver217): we may not use shard / gather here
                    # Param is no sharded, which means we use ZeRO-2 here
                    # As we only store param shard, we shard it here
-                    self.shard_strategy.shard([p.col_attr.sharded_data_tensor], self.dp_process_group)
-                self.master_params[p] = cast_tensor_to_fp32(p.col_attr.sharded_data_tensor.payload).to(self.device)
+                    self.shard_strategy.shard([p.colo_attr.sharded_data_tensor], self.dp_process_group)
+                self.master_params[p] = cast_tensor_to_fp32(p.colo_attr.sharded_data_tensor.payload).to(self.device)
                if not is_param_sharded:
                    # In this branch, there's no need to shard param
                    # So we gather here
-                    self.shard_strategy.gather([p.col_attr.sharded_data_tensor], self.dp_process_group)
+                    self.shard_strategy.gather([p.colo_attr.sharded_data_tensor], self.dp_process_group)
        self._logger.debug(f"After init ShardedOptimizerV2 consumes {self.get_memory_usage()[0] / 1e6} MB CUDA Memory!", ranks=[0])

@@ -201,30 +201,30 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
        self._logger.debug(
            f"After step ShardedOptimizerV2 consumes {self.get_memory_usage()[0] / 1e6} MB CUDA Memory, {self.get_memory_usage()[1] / 1e6} MB CUDA Memory!",
            ranks=[0])
-        # Copy master param data (fp32) to payload of col_attr (fp16)
+        # Copy master param data (fp32) to payload of colo_attr (fp16)
        # TODO() improve efficiency by gathering tensors into a chunk and transfering
        # a chunk.
        for group in self.optim.param_groups:
            for p in group['params']:
-                is_param_sharded = p.col_attr.sharded_data_tensor.is_sharded
+                is_param_sharded = p.colo_attr.sharded_data_tensor.is_sharded
                if not is_param_sharded:
                    # We use ZeRO-2 here
-                    # The `p.col_attr.sharded_data_tensor` saves full fp16 param
+                    # The `p.colo_attr.sharded_data_tensor` saves full fp16 param
                    # But we only have updated fp32 param shard here
                    # So we first shard full fp16 param and copy fp32 param shard to it
                    # Then we will gather them
-                    self.shard_strategy.shard([p.col_attr.sharded_data_tensor], self.dp_process_group)
+                    self.shard_strategy.shard([p.colo_attr.sharded_data_tensor], self.dp_process_group)
                # We have to use `copy_payload` instead of `reset_payload`
-                # Since p.data is fp32 and p.col_attr.sharded_data_tensor is fp16
+                # Since p.data is fp32 and p.colo_attr.sharded_data_tensor is fp16
                # TODO() optimize this line CPU (fp32) -> GPU (fp16)
-                p.col_attr.sharded_data_tensor.reset_payload(
+                p.colo_attr.sharded_data_tensor.reset_payload(
                    colo_model_tensor_clone(p.half(), torch.cuda.current_device()))
                if not is_param_sharded:
                    # We gather full fp16 param here
-                    self.shard_strategy.gather([p.col_attr.sharded_data_tensor], self.dp_process_group)
-                p.data = p.col_attr.sharded_data_tensor.payload
+                    self.shard_strategy.gather([p.colo_attr.sharded_data_tensor], self.dp_process_group)
+                p.data = p.colo_attr.sharded_data_tensor.payload
        return ret

    def backward(self, loss: Tensor) -> None:

@@ -292,7 +292,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
                if fp32_shards_used_cuda_margin_mem + shard_mem < fp32_shards_available_cuda_margin_mem:
                    self.master_params[p] = self.master_params[p].to(torch.cuda.current_device())
                    p.grad.data = p.grad.data.to(torch.cuda.current_device())
-                    p.col_attr.offload_grad = False
+                    p.colo_attr.offload_grad = False
                    fp32_shards_used_cuda_margin_mem += shard_mem

    def _prepare_grads(self):

@@ -301,7 +301,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
                # FIXME(ver217): p.data here is an empty tensor on CUDA and has no useful infomation
                # If we change p.grad directly
                # it may raise error because of different shape/dtype/device of p.data and p.grad
-                # We just set p.data = p.col_attr.saved_grad.payload here
-                p.data = p.col_attr.saved_grad.payload
-                p.grad = p.col_attr.saved_grad.payload
-                p.col_attr.saved_grad.set_null()
+                # We just set p.data = p.colo_attr.saved_grad.payload here
+                p.data = p.colo_attr.saved_grad.payload
+                p.grad = p.colo_attr.saved_grad.payload
+                p.colo_attr.saved_grad.set_null()
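The `step()` hunk above is the fp16/fp32 hand-off: the optimizer updates fp32 master shards, and the result is copied back into the fp16 payload the model computes with. A plain-torch illustration of that write-back, independent of ColossalAI:

import torch

param_fp16 = torch.randn(8).half()              # fp16 payload, like sharded_data_tensor.payload
master_fp32 = param_fp16.float()                # fp32 master copy, like master_params[p]
master_fp32.add_(torch.randn(8), alpha=-1e-3)   # stand-in for the real optimizer update
param_fp16.copy_(master_fp32.half())            # write the fp32 result back into fp16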
tests/test_moe/test_moe_zero_init.py

@@ -61,22 +61,22 @@ def run_moe_zero_init(init_device_type, shard_strategy_class):
        model = MoeModel()

    for name, param in model.named_parameters():
-        assert hasattr(param, 'col_attr')
+        assert hasattr(param, 'colo_attr')

        # the weights in the gate should be fp32
        if 'gate' in name:
-            assert param.col_attr.sharded_data_tensor.dtype == torch.float32
+            assert param.colo_attr.sharded_data_tensor.dtype == torch.float32
        else:
-            assert param.col_attr.sharded_data_tensor.dtype == torch.half
+            assert param.colo_attr.sharded_data_tensor.dtype == torch.half

        # the parameters in moe experts and its gate should not be sharded
        if ('experts' in name) or ('gate' in name) or ('residual_combine' in name):
-            assert not param.col_attr.sharded_data_tensor.is_sharded
+            assert not param.colo_attr.sharded_data_tensor.is_sharded
        else:
-            assert param.col_attr.sharded_data_tensor.is_sharded
+            assert param.colo_attr.sharded_data_tensor.is_sharded

-        assert param.col_attr.sharded_data_tensor.payload.device.type == init_device.type, \
-            f'{param.col_attr.sharded_data_tensor.payload.device.type} vs. {init_device.type}'
+        assert param.colo_attr.sharded_data_tensor.payload.device.type == init_device.type, \
+            f'{param.colo_attr.sharded_data_tensor.payload.device.type} vs. {init_device.type}'


def _run_dist(rank, world_size, port):
tests/test_zero_data_parallel/common.py

@@ -93,7 +93,7 @@ def check_grads_padding(model, zero_model, loose=False):
    rank = dist.get_rank()
    for p, zero_p in zip(model.parameters(), zero_model.parameters()):
        # zero_grad = zero_p.grad.clone().to(p.device)
-        zero_grad = zero_p.col_attr.saved_grad.payload.clone().to(p.device)
+        zero_grad = zero_p.colo_attr.saved_grad.payload.clone().to(p.device)
        chunks = torch.flatten(p.grad).chunk(dist.get_world_size())
        if rank >= len(chunks):
            continue

@@ -124,7 +124,7 @@ def check_sharded_model_params(model, zero_model, loose=False, reuse_fp16_shard=
        if reuse_fp16_shard:
            zero_p = zero_p.data.to(p.device).float()
        else:
-            zero_p = zero_p.col_attr.sharded_data_tensor.payload.to(p.device).float()
+            zero_p = zero_p.colo_attr.sharded_data_tensor.payload.to(p.device).float()
        chunks = torch.flatten(p).chunk(dist.get_world_size())
        if rank >= len(chunks):
            continue
tests/test_zero_data_parallel/test_init_context.py

@@ -45,11 +45,11 @@ def run_model_test(init_device_type, shard_strategy_class):
            model = model_builder(checkpoint=True)

        for param in model.parameters():
-            assert hasattr(param, 'col_attr')
-            assert param.col_attr.sharded_data_tensor.dtype == torch.half
-            assert param.col_attr.sharded_data_tensor.is_sharded
-            assert param.col_attr.sharded_data_tensor.payload.device.type == init_device.type, \
-                f'{param.col_attr.sharded_data_tensor.payload.device.type} vs. {init_device.type}'
+            assert hasattr(param, 'colo_attr')
+            assert param.colo_attr.sharded_data_tensor.dtype == torch.half
+            assert param.colo_attr.sharded_data_tensor.is_sharded
+            assert param.colo_attr.sharded_data_tensor.payload.device.type == init_device.type, \
+                f'{param.colo_attr.sharded_data_tensor.payload.device.type} vs. {init_device.type}'

        cuda_mem_use, cpu_mem_use = colo_model_mem_usage(model)
        model_data_cuda_mem_MB = cuda_mem_use / 1e6