Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
ea6905a8
Commit
ea6905a8
authored
Mar 15, 2022
by
ver217
Browse files
free param.grad
parent
9506a8be
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
5 additions
and
2 deletions
+5
-2
colossalai/zero/sharded_model/sharded_model_v2.py
colossalai/zero/sharded_model/sharded_model_v2.py
+5
-2
No files found.
colossalai/zero/sharded_model/sharded_model_v2.py
View file @
ea6905a8
import
functools
import
functools
from
as
t
import
Try
from
as
yncio.log
import
logger
from
collections
import
OrderedDict
from
collections
import
OrderedDict
from
typing
import
Any
,
Optional
from
typing
import
Any
,
Optional
...
@@ -21,7 +21,7 @@ from colossalai.zero.sharded_param import ShardedParamV2
...
@@ -21,7 +21,7 @@ from colossalai.zero.sharded_param import ShardedParamV2
from
torch.distributed
import
ProcessGroup
from
torch.distributed
import
ProcessGroup
from
torch.nn.parameter
import
Parameter
from
torch.nn.parameter
import
Parameter
from
._zero3_utils
import
(
cast_float_arguments
,
cast_tensor_to_fp16
,
cast_tensor_to_fp32
,
chunk_and_pad
,
from
._zero3_utils
import
(
cast_float_arguments
,
cast_tensor_to_fp16
,
cast_tensor_to_fp32
,
chunk_and_pad
,
free_storage
,
get_gradient_predivide_factor
)
get_gradient_predivide_factor
)
...
@@ -218,6 +218,9 @@ class ShardedModelV2(nn.Module):
...
@@ -218,6 +218,9 @@ class ShardedModelV2(nn.Module):
else
:
else
:
self
.
_reduce_scatter_callback
(
param
,
new_grad
)
self
.
_reduce_scatter_callback
(
param
,
new_grad
)
orig_grad_data
.
record_stream
(
self
.
comm_stream
)
orig_grad_data
.
record_stream
(
self
.
comm_stream
)
empty_grad
=
torch
.
empty_like
(
grad
)
free_storage
(
empty_grad
)
return
empty_grad
def
_reduce_scatter_callback
(
self
,
param
:
Parameter
,
reduced_grad
:
torch
.
Tensor
)
->
None
:
def
_reduce_scatter_callback
(
self
,
param
:
Parameter
,
reduced_grad
:
torch
.
Tensor
)
->
None
:
if
self
.
gradient_postdivide_factor
>
1
:
if
self
.
gradient_postdivide_factor
>
1
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment