OpenDAS / ColossalAI · Commits

Commit e6212f56 (unverified)
Authored Apr 13, 2022 by ver217, committed by GitHub on Apr 13, 2022
[hotfix] fix memory leak in backward of sharded model (#741)
parent f4f42d4c

Showing 1 changed file with 10 additions and 13 deletions

colossalai/zero/sharded_model/sharded_model_v2.py  (+10 -13)
@@ -303,41 +303,38 @@ class ShardedModelV2(nn.Module):
         assert not grad.requires_grad, 'ShardedModel only works with gradients that don\'t require gradients'
         if not self._require_backward_grad_sync:
             return
-        # used to cheat Pytorch, since we can't return None
-        empty_grad = torch.empty_like(grad)
-        free_storage(empty_grad)
+        # As torch didn't allow modifying grad in hook, we make a copy
+        grad = grad.clone()
         if param.colo_attr.is_replicated:
             self._reduce_scatter_handler(param, grad)
         else:
             self._save_grad(param, grad)
+        # used to cheat Pytorch, since we can't return None
+        empty_grad = torch.empty_like(grad)
+        free_storage(empty_grad)
         return empty_grad

     def _reduce_scatter_handler(self, param: Parameter, grad: torch.Tensor) -> None:
         self.comm_stream.wait_stream(torch.cuda.current_stream())
         with torch.cuda.stream(self.comm_stream):
-            new_grad = grad.clone()
             if self.fp32_reduce_scatter:
-                new_grad.data = new_grad.data.to(param.dtype)
+                grad.data = grad.data.to(param.dtype)
             if self.gradient_predivide_factor > 1.0:
                 # Average grad by world_size for consistency with PyTorch DDP.
-                new_grad.data.div_(self.gradient_predivide_factor)
-            orig_grad_data = new_grad.data
+                grad.data.div_(self.gradient_predivide_factor)
             if self.world_size > 1:
-                grad_chunks = chunk_and_pad(orig_grad_data, self.reduce_scatter_process_group.size())
+                grad_chunks = chunk_and_pad(grad, self.reduce_scatter_process_group.size())
                 self.reducer.reduce_scatter_async(grad_chunks,
                                                   group=self.reduce_scatter_process_group,
                                                   callback_fn=functools.partial(self._reduce_scatter_callback, param))
             else:
-                self._reduce_scatter_callback(param, new_grad)
-            orig_grad_data.record_stream(self.comm_stream)
+                self._reduce_scatter_callback(param, grad)
         torch.cuda.current_stream().wait_stream(self.comm_stream)

     def _reduce_scatter_callback(self, param: Parameter, reduced_grad: torch.Tensor) -> None:
         assert isinstance(reduced_grad,
                           torch.Tensor), f"_reduce_scatter_callback accept reduced_grad as {type(reduced_grad)}"
-        reduced_grad = reduced_grad.view(-1)
+        reduced_grad.data = reduced_grad.data.view(-1)
         if self.gradient_postdivide_factor > 1:
             # Average grad by world_size for consistency with PyTorch DDP.
             reduced_grad.data.div_(self.gradient_postdivide_factor)
...
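The fix visible above moves grad.clone() into the hook itself, reuses that single copy for the dtype cast, pre-divide and reduce-scatter, and drops the extra new_grad / orig_grad_data buffers along with the record_stream call, while the storage-freed empty_grad placeholder is now created only right before returning. As a rough illustration of the resulting hook pattern only (not ColossalAI's actual code), here is a minimal standalone sketch: handle_grad is a hypothetical stand-in for the reduce-scatter/save dispatch, and the body of free_storage is an assumption about what the real helper does.

import torch


def free_storage(tensor: torch.Tensor) -> None:
    # Assumed behaviour of the real free_storage helper: release the
    # underlying storage so the placeholder tensor keeps no memory alive.
    if tensor.storage().size() > 0:
        tensor.storage().resize_(0)


def handle_grad(grad: torch.Tensor) -> None:
    # Hypothetical stand-in for the reduce-scatter / _save_grad dispatch.
    pass


def grad_post_backward_hook(grad: torch.Tensor) -> torch.Tensor:
    # Autograd hooks must not modify `grad` in place, so work on a copy.
    grad = grad.clone()
    handle_grad(grad)
    # Autograd expects a tensor back; return an empty, storage-freed tensor
    # instead of keeping a full-size gradient buffer alive here.
    empty_grad = torch.empty_like(grad)
    free_storage(empty_grad)
    return empty_grad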