Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
a93a7d73
Unverified
Commit
a93a7d73
authored
Apr 14, 2022
by
ver217
Committed by
GitHub
Apr 14, 2022
Browse files
[hotfix] fix reuse_fp16_shard of sharded model (#756)
* fix reuse_fp16_shard * disable test stm * polish code
parent
8f7ce94b
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
3 additions
and
7 deletions
+3
-7
colossalai/zero/sharded_model/sharded_model_v2.py
colossalai/zero/sharded_model/sharded_model_v2.py
+2
-5
colossalai/zero/sharded_param/sharded_param.py
colossalai/zero/sharded_param/sharded_param.py
+0
-1
tests/test_zero/test_stateful_tensor_mgr.py
tests/test_zero/test_stateful_tensor_mgr.py
+1
-1
No files found.
colossalai/zero/sharded_model/sharded_model_v2.py
View file @
a93a7d73
...
...
@@ -253,9 +253,6 @@ class ShardedModelV2(nn.Module):
with
torch
.
cuda
.
stream
(
self
.
comm_stream
):
self
.
reducer
.
flush
()
torch
.
cuda
.
current_stream
().
wait_stream
(
self
.
comm_stream
)
if
self
.
_cpu_offload
:
# Wait for the non-blocking GPU -> CPU grad transfers to finish.
torch
.
cuda
.
current_stream
().
synchronize
()
self
.
reducer
.
free
()
# 3. shard tensors not dealed in the zero hook
...
...
@@ -338,7 +335,7 @@ class ShardedModelV2(nn.Module):
def
_reduce_scatter_callback
(
self
,
param
:
Parameter
,
reduced_grad
:
torch
.
Tensor
)
->
None
:
assert
isinstance
(
reduced_grad
,
torch
.
Tensor
),
f
"_reduce_scatter_callback accept reduced_grad as
{
type
(
reduced_grad
)
}
"
reduced_grad
.
data
=
reduced_grad
.
data
.
view
(
-
1
)
reduced_grad
.
data
=
reduced_grad
.
data
.
contiguous
().
view
(
-
1
)
if
self
.
gradient_postdivide_factor
>
1
:
# Average grad by world_size for consistency with PyTorch DDP.
reduced_grad
.
data
.
div_
(
self
.
gradient_postdivide_factor
)
...
...
@@ -362,7 +359,7 @@ class ShardedModelV2(nn.Module):
),
'Gradien accumulation is not supported when reuse_fp16_shard=True'
param
.
colo_attr
.
reset_grad_payload
(
grad
)
param
.
colo_attr
.
reset_
grad
_payload
(
grad
)
# release the memory of param
param
.
colo_attr
.
reset_
data
_payload
(
grad
)
# release the memory of param
if
param
.
colo_attr
.
is_replicated
:
param
.
colo_attr
.
sharded_data_tensor
.
is_sharded
=
True
...
...
colossalai/zero/sharded_param/sharded_param.py
View file @
a93a7d73
...
...
@@ -70,7 +70,6 @@ class ShardedParamV2(object):
assert
type
(
tensor
)
is
torch
.
Tensor
assert
tensor
.
requires_grad
is
False
self
.
sharded_data_tensor
.
reset_payload
(
tensor
)
self
.
set_data_none
()
def
reset_grad_payload
(
self
,
tensor
:
torch
.
Tensor
):
assert
type
(
tensor
)
is
torch
.
Tensor
...
...
tests/test_zero/test_stateful_tensor_mgr.py
View file @
a93a7d73
...
...
@@ -112,7 +112,7 @@ def run_dist(rank, world_size, port):
run_stm
()
@
pytest
.
mark
.
dist
@
pytest
.
mark
.
skip
@
rerun_on_exception
(
exception_type
=
mp
.
ProcessRaisedException
,
pattern
=
".*Address already in use.*"
)
def
test_stateful_tensor_manager
(
world_size
=
1
):
run_func
=
partial
(
run_dist
,
world_size
=
world_size
,
port
=
free_port
())
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment