Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
fb841dd5
Unverified
Commit
fb841dd5
authored
Mar 29, 2022
by
ver217
Committed by
GitHub
Mar 29, 2022
Browse files
[zero] optimize grad offload (#539)
* optimize grad offload * polish code * polish code
parent
7d81b5b4
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing 2 changed files with 13 additions and 11 deletions
+13
-11
colossalai/zero/sharded_model/sharded_model_v2.py
colossalai/zero/sharded_model/sharded_model_v2.py
+5
-4
colossalai/zero/sharded_optim/sharded_optim_v2.py
colossalai/zero/sharded_optim/sharded_optim_v2.py
+8
-7
No files found.
colossalai/zero/sharded_model/sharded_model_v2.py
View file @
fb841dd5
...
@@ -11,9 +11,10 @@ from colossalai.engine.ophooks import register_ophooks_recursively
...
@@ -11,9 +11,10 @@ from colossalai.engine.ophooks import register_ophooks_recursively
from
colossalai.engine.ophooks.zero_hook
import
ZeroHook
from
colossalai.engine.ophooks.zero_hook
import
ZeroHook
from
colossalai.engine.paramhooks
import
BaseParamHookMgr
from
colossalai.engine.paramhooks
import
BaseParamHookMgr
from
colossalai.logging
import
get_dist_logger
from
colossalai.logging
import
get_dist_logger
from
colossalai.utils.memory_tracer.model_data_memtracer
import
GLOBAL_MODEL_DATA_TRACER
from
colossalai.utils.memory_utils.utils
import
colo_model_data_move_to_cpu
,
colo_cuda_memory_capacity
,
colo_model_tensor_clone
from
colossalai.utils.memory_tracer.memstats_collector
import
MemStatsCollector
from
colossalai.utils.memory_tracer.memstats_collector
import
MemStatsCollector
from
colossalai.utils.memory_tracer.model_data_memtracer
import
\
GLOBAL_MODEL_DATA_TRACER
from
colossalai.utils.memory_utils.utils
import
(
colo_cuda_memory_capacity
,
colo_model_data_move_to_cpu
)
from
colossalai.zero.shard_utils
import
BaseShardStrategy
from
colossalai.zero.shard_utils
import
BaseShardStrategy
from
colossalai.zero.sharded_model.reduce_scatter
import
ReduceScatterBucketer
from
colossalai.zero.sharded_model.reduce_scatter
import
ReduceScatterBucketer
from
torch.distributed
import
ProcessGroup
from
torch.distributed
import
ProcessGroup
...
@@ -206,7 +207,7 @@ class ShardedModelV2(nn.Module):
...
@@ -206,7 +207,7 @@ class ShardedModelV2(nn.Module):
else
:
else
:
grad_payload
=
cast_tensor_to_fp32
(
p
.
col_attr
.
fp16_grad
)
grad_payload
=
cast_tensor_to_fp32
(
p
.
col_attr
.
fp16_grad
)
if
p
.
col_attr
.
offload_grad
:
if
p
.
col_attr
.
offload_grad
:
grad_payload = colo_model_tensor_clone(grad_payload, torch.device('cpu'))
colo_model_data_move_to_cpu(grad_payload)
if
p
.
col_attr
.
fp32_grad
is
not
None
:
if
p
.
col_attr
.
fp32_grad
is
not
None
:
assert
not
self
.
reuse_fp16_shard
,
'Gradien accumulation is not supported when reuse_fp16_shard=True'
assert
not
self
.
reuse_fp16_shard
,
'Gradien accumulation is not supported when reuse_fp16_shard=True'
p
.
col_attr
.
fp32_grad
.
add_
(
grad_payload
.
view_as
(
p
.
col_attr
.
fp32_grad
))
p
.
col_attr
.
fp32_grad
.
add_
(
grad_payload
.
view_as
(
p
.
col_attr
.
fp32_grad
))
...
...
colossalai/zero/sharded_optim/sharded_optim_v2.py
View file @
fb841dd5
...
@@ -10,14 +10,14 @@ from colossalai.context.parallel_mode import ParallelMode
...
@@ -10,14 +10,14 @@ from colossalai.context.parallel_mode import ParallelMode
from
colossalai.core
import
global_context
as
gpc
from
colossalai.core
import
global_context
as
gpc
from
colossalai.logging
import
get_dist_logger
from
colossalai.logging
import
get_dist_logger
from
colossalai.nn.optimizer
import
ColossalaiOptimizer
from
colossalai.nn.optimizer
import
ColossalaiOptimizer
from
colossalai.utils.memory_utils.utils
import
(
colo_model_tensor_clone
,
colo_tensor_mem_usage
)
from
colossalai.zero.sharded_model
import
ShardedModelV2
from
colossalai.zero.sharded_model
import
ShardedModelV2
from
colossalai.zero.sharded_model._utils
import
cast_tensor_to_fp32
from
colossalai.zero.sharded_model._utils
import
cast_tensor_to_fp32
from
colossalai.zero.sharded_optim._utils
import
has_inf_or_nan
from
torch
import
Tensor
from
torch
import
Tensor
from
torch.distributed
import
ProcessGroup
from
torch.distributed
import
ProcessGroup
from
torch.nn.parameter
import
Parameter
from
torch.nn.parameter
import
Parameter
from
torch.optim
import
Optimizer
from
torch.optim
import
Optimizer
from
colossalai.zero.sharded_optim._utils
import
has_inf_or_nan
from
colossalai.utils.memory_utils.utils
import
colo_model_data_tensor_move
,
colo_tensor_mem_usage
class
OptimState
(
Enum
):
class
OptimState
(
Enum
):
...
@@ -204,7 +204,8 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
...
@@ -204,7 +204,8 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
# Since p.data is fp32 and p.col_attr.sharded_data_tensor is fp16
# Since p.data is fp32 and p.col_attr.sharded_data_tensor is fp16
# TODO() optimize this line CPU (fp32) -> GPU (fp16)
# TODO() optimize this line CPU (fp32) -> GPU (fp16)
colo_model_data_tensor_move(p, p.col_attr.sharded_data_tensor)
p.col_attr.sharded_data_tensor.reset_payload(colo_model_tensor_clone(p.half(), torch.cuda.current_device()))
if
not
is_param_sharded
:
if
not
is_param_sharded
:
# We gather full fp16 param here
# We gather full fp16 param here
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment