Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
9506a8be
Commit
9506a8be
authored
Mar 15, 2022
by
ver217
Browse files
use double buffer to handle grad
parent
0f5f5dd5
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
29 additions
and
41 deletions
+29
-41
colossalai/engine/ophooks/zero_hook.py
colossalai/engine/ophooks/zero_hook.py
+7
-5
colossalai/zero/init_ctx/init_context.py
colossalai/zero/init_ctx/init_context.py
+5
-4
colossalai/zero/sharded_model/sharded_model_v2.py
colossalai/zero/sharded_model/sharded_model_v2.py
+15
-23
colossalai/zero/sharded_param/sharded_param.py
colossalai/zero/sharded_param/sharded_param.py
+2
-9
No files found.
colossalai/engine/ophooks/zero_hook.py
View file @
9506a8be
from
typing
import
Optional
import
torch
import
torch
from
colossalai.registry
import
OPHOOKS
from
colossalai.registry
import
OPHOOKS
from
colossalai.utils
import
get_current_device
from
colossalai.utils
import
get_current_device
from
colossalai.utils.memory_tracer.memstats_collector
import
MemStatsCollector
from
colossalai.utils.memory_tracer.model_data_memtracer
import
\
GLOBAL_MODEL_DATA_TRACER
from
colossalai.zero.shard_utils
import
BaseShardStrategy
from
colossalai.zero.shard_utils
import
BaseShardStrategy
from
._base_ophook
import
BaseOpHook
from
._base_ophook
import
BaseOpHook
from
colossalai.utils.memory_tracer.memstats_collector
import
MemStatsCollector
from
colossalai.utils.memory_tracer.model_data_memtracer
import
GLOBAL_MODEL_DATA_TRACER
from
typing
import
Optional
@
OPHOOKS
.
register_module
@
OPHOOKS
.
register_module
...
@@ -62,8 +64,8 @@ class ZeroHook(BaseOpHook):
...
@@ -62,8 +64,8 @@ class ZeroHook(BaseOpHook):
if
param
.
grad
is
not
None
:
if
param
.
grad
is
not
None
:
if
param
.
col_attr
.
bwd_count
==
0
:
if
param
.
col_attr
.
bwd_count
==
0
:
# We haven't stored local accumulated grad yet
# We haven't stored local accumulated grad yet
assert
param
.
col_attr
.
grad
is
None
assert
param
.
col_attr
.
fp32_
grad
is
None
param
.
col_attr
.
grad
=
param
.
grad
.
data
param
.
col_attr
.
fp32_
grad
=
param
.
grad
.
data
param
.
grad
=
None
param
.
grad
=
None
else
:
else
:
# We have stored local accumulated grad
# We have stored local accumulated grad
...
...
colossalai/zero/init_ctx/init_context.py
View file @
9506a8be
import
functools
import
functools
import
torch
import
torch
from
colossalai.utils.memory_tracer.model_data_memtracer
import
\
GLOBAL_MODEL_DATA_TRACER
from
colossalai.zero.shard_utils
import
BaseShardStrategy
from
colossalai.zero.shard_utils
import
BaseShardStrategy
from
colossalai.zero.sharded_param
import
ShardedParamV2
from
colossalai.zero.sharded_param
import
ShardedParamV2
from
colossalai.utils.memory_tracer.model_data_memtracer
import
GLOBAL_MODEL_DATA_TRACER
# Inserts _post_init_method at the end of init method
# Inserts _post_init_method at the end of init method
...
@@ -154,6 +155,6 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
...
@@ -154,6 +155,6 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
if
self
.
shard_param
:
if
self
.
shard_param
:
self
.
shard_strategy
.
shard
(
tensor_list
=
[
param
.
col_attr
.
_data_sharded_tensor
])
self
.
shard_strategy
.
shard
(
tensor_list
=
[
param
.
col_attr
.
_data_sharded_tensor
])
GLOBAL_MODEL_DATA_TRACER
.
add_tensor
(
param
.
col_attr
.
_data_sharded_tensor
.
payload
)
GLOBAL_MODEL_DATA_TRACER
.
add_tensor
(
param
.
col_attr
.
_data_sharded_tensor
.
payload
)
if
param
.
col_attr
.
grad
and
self
.
shard_grad
:
#
if param.col_attr.grad and self.shard_grad:
self
.
shard_strategy
.
shard
(
tensor_list
=
[
param
.
col_attr
.
_grad_sharded_tensor
])
#
self.shard_strategy.shard(tensor_list=[param.col_attr._grad_sharded_tensor])
GLOBAL_MODEL_DATA_TRACER
.
add_tensor
(
param
.
col_attr
.
_grad_sharded_tensor
.
payload
)
#
GLOBAL_MODEL_DATA_TRACER.add_tensor(param.col_attr._grad_sharded_tensor.payload)
colossalai/zero/sharded_model/sharded_model_v2.py
View file @
9506a8be
from
ast
import
Try
import
functools
import
functools
from
ast
import
Try
from
collections
import
OrderedDict
from
collections
import
OrderedDict
from
typing
import
Any
,
Optional
from
typing
import
Any
,
Optional
...
@@ -12,16 +12,17 @@ from colossalai.engine.ophooks import register_ophooks_recursively
...
@@ -12,16 +12,17 @@ from colossalai.engine.ophooks import register_ophooks_recursively
from
colossalai.engine.ophooks.zero_hook
import
ZeroHook
from
colossalai.engine.ophooks.zero_hook
import
ZeroHook
from
colossalai.engine.paramhooks
import
BaseParamHookMgr
from
colossalai.engine.paramhooks
import
BaseParamHookMgr
from
colossalai.logging
import
get_dist_logger
from
colossalai.logging
import
get_dist_logger
from
colossalai.utils.commons.memory
import
col_cuda_memory_capacity
from
colossalai.utils.memory_tracer.allocator
import
col_move_to_cpu
from
colossalai.utils.memory_tracer.memstats_collector
import
MemStatsCollector
from
colossalai.zero.shard_utils
import
BaseShardStrategy
from
colossalai.zero.shard_utils
import
BaseShardStrategy
from
colossalai.zero.sharded_model.reduce_scatter
import
ReduceScatterBucketer
from
colossalai.zero.sharded_model.reduce_scatter
import
ReduceScatterBucketer
from
colossalai.zero.sharded_param
import
ShardedParamV2
from
colossalai.zero.sharded_param
import
ShardedParamV2
from
torch.distributed
import
ProcessGroup
from
torch.distributed
import
ProcessGroup
from
torch.nn.parameter
import
Parameter
from
torch.nn.parameter
import
Parameter
from
colossalai.utils.memory_tracer.memstats_collector
import
MemStatsCollector
from
colossalai.utils.memory_tracer.allocator
import
col_move_to_cpu
from
._zero3_utils
import
(
cast_float_arguments
,
cast_tensor_to_fp16
,
cast_tensor_to_fp32
,
chunk_and_pad
,
from
._zero3_utils
import
(
cast_float_arguments
,
cast_tensor_to_fp16
,
cast_tensor_to_fp32
,
chunk_and_pad
,
get_gradient_predivide_factor
)
get_gradient_predivide_factor
)
from
colossalai.utils.commons.memory
import
col_cuda_memory_capacity
class
ShardedModelV2
(
nn
.
Module
):
class
ShardedModelV2
(
nn
.
Module
):
...
@@ -164,8 +165,15 @@ class ShardedModelV2(nn.Module):
...
@@ -164,8 +165,15 @@ class ShardedModelV2(nn.Module):
# If world size == 1 and sharded param,
# If world size == 1 and sharded param,
# the shape `grad` is the same as unsharded param
# the shape `grad` is the same as unsharded param
# So we can just use `view(-1)` to ensure grad is a flat tensor shard
# So we can just use `view(-1)` to ensure grad is a flat tensor shard
p
.
grad
.
data
=
p
.
col_attr
.
grad
.
view
(
-
1
)
grad
=
cast_tensor_to_fp32
(
p
.
col_attr
.
fp16_grad
)
p
.
col_attr
.
grad
=
None
if
self
.
_cpu_offload
:
col_move_to_cpu
(
grad
)
if
p
.
col_attr
.
fp32_grad
is
not
None
:
p
.
col_attr
.
fp32_grad
.
add_
(
grad
.
view_as
(
p
.
col_attr
.
fp32_grad
))
grad
=
p
.
col_attr
.
fp32_grad
p
.
grad
.
data
=
grad
.
view
(
-
1
)
p
.
col_attr
.
fp16_grad
=
None
p
.
col_attr
.
fp32_grad
=
None
@
torch
.
no_grad
()
@
torch
.
no_grad
()
def
_grad_post_backward_hook
(
self
,
param
:
Parameter
,
grad
:
torch
.
Tensor
)
->
Optional
[
torch
.
Tensor
]:
def
_grad_post_backward_hook
(
self
,
param
:
Parameter
,
grad
:
torch
.
Tensor
)
->
Optional
[
torch
.
Tensor
]:
...
@@ -216,23 +224,7 @@ class ShardedModelV2(nn.Module):
...
@@ -216,23 +224,7 @@ class ShardedModelV2(nn.Module):
# Average grad by world_size for consistency with PyTorch DDP.
# Average grad by world_size for consistency with PyTorch DDP.
reduced_grad
.
data
.
div_
(
self
.
gradient_postdivide_factor
)
reduced_grad
.
data
.
div_
(
self
.
gradient_postdivide_factor
)
# Make sure we store fp32 grad
param
.
col_attr
.
fp16_grad
=
reduced_grad
.
data
reduced_grad
.
data
=
cast_tensor_to_fp32
(
reduced_grad
.
data
)
# Maybe offload
# TODO() optimize GPU->CPU bandwidth utilization
if
self
.
_cpu_offload
:
col_move_to_cpu
(
reduced_grad
)
# reduced_grad.data = reduced_grad.data.cpu()
if
param
.
col_attr
.
grad
is
None
:
param
.
col_attr
.
grad
=
reduced_grad
.
data
else
:
# When dp size = 1
# param.col_attr.grad is local accumulated grad shard (full but flatten)
# But reduced_grad here is full grad
# We should call `view_as`
param
.
col_attr
.
grad
.
add_
(
reduced_grad
.
data
.
view_as
(
param
.
col_attr
.
grad
))
def
state_dict
(
self
,
destination
=
None
,
prefix
=
''
,
keep_vars
=
False
)
->
'OrderedDict[str, torch.Tensor]'
:
def
state_dict
(
self
,
destination
=
None
,
prefix
=
''
,
keep_vars
=
False
)
->
'OrderedDict[str, torch.Tensor]'
:
self
.
shard_strategy
.
gather
([
p
.
col_attr
.
data
for
p
in
self
.
module
.
parameters
()])
self
.
shard_strategy
.
gather
([
p
.
col_attr
.
data
for
p
in
self
.
module
.
parameters
()])
...
...
colossalai/zero/sharded_param/sharded_param.py
View file @
9506a8be
...
@@ -16,7 +16,8 @@ class ShardedParamV2(object):
...
@@ -16,7 +16,8 @@ class ShardedParamV2(object):
process_group
:
Optional
[
dist
.
ProcessGroup
]
=
None
,
process_group
:
Optional
[
dist
.
ProcessGroup
]
=
None
,
rm_torch_payload
=
False
)
->
None
:
rm_torch_payload
=
False
)
->
None
:
self
.
_data_sharded_tensor
:
ShardedTensor
=
ShardedTensor
(
param
.
data
,
process_group
)
self
.
_data_sharded_tensor
:
ShardedTensor
=
ShardedTensor
(
param
.
data
,
process_group
)
self
.
_grad_sharded_tensor
:
Optional
[
torch
.
Tensor
]
=
None
self
.
fp16_grad
:
Optional
[
torch
.
Tensor
]
=
None
self
.
fp32_grad
:
Optional
[
torch
.
Tensor
]
=
None
# make sure the shared param is the only owner of payload
# make sure the shared param is the only owner of payload
# The param.data maybe used to init the other part of the model.
# The param.data maybe used to init the other part of the model.
...
@@ -39,14 +40,6 @@ class ShardedParamV2(object):
...
@@ -39,14 +40,6 @@ class ShardedParamV2(object):
def
data
(
self
):
def
data
(
self
):
return
self
.
_data_sharded_tensor
return
self
.
_data_sharded_tensor
@
property
def
grad
(
self
):
return
self
.
_grad_sharded_tensor
@
grad
.
setter
def
grad
(
self
,
t
:
torch
.
Tensor
):
self
.
_grad_sharded_tensor
=
t
@
property
@
property
def
param_is_sharded
(
self
):
def
param_is_sharded
(
self
):
return
self
.
_data_sharded_tensor
.
is_sharded
return
self
.
_data_sharded_tensor
.
is_sharded
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment