Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
56ddc9ca
Unverified
Commit
56ddc9ca
authored
Feb 17, 2023
by
HELSON
Committed by
GitHub
Feb 17, 2023
Browse files
[hotfix] add correct device for fake_param (#2796)
parent
a619a190
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
3 additions
and
4 deletions
+3
-4
colossalai/nn/optimizer/zero_optimizer.py
colossalai/nn/optimizer/zero_optimizer.py
+3
-2
tests/test_gemini/update/test_zerooptim_state_dict.py
tests/test_gemini/update/test_zerooptim_state_dict.py
+0
-2
No files found.
colossalai/nn/optimizer/zero_optimizer.py
View file @
56ddc9ca
...
@@ -136,7 +136,7 @@ class ZeroOptimizer(ColossalaiOptimizer):
...
@@ -136,7 +136,7 @@ class ZeroOptimizer(ColossalaiOptimizer):
for
group
in
self
.
param_groups
:
for
group
in
self
.
param_groups
:
for
fake_param
in
group
[
'params'
]:
for
fake_param
in
group
[
'params'
]:
assert
fake_param
.
grad
is
None
assert
fake_param
.
grad
is
None
fake_param
.
data
=
none_tensor
fake_param
.
data
=
none_tensor
.
to
(
fake_param
.
device
)
for
chunk16
in
self
.
chunk16_set
:
for
chunk16
in
self
.
chunk16_set
:
chunk16
.
optim_update
()
chunk16
.
optim_update
()
...
@@ -307,7 +307,8 @@ class ZeroOptimizer(ColossalaiOptimizer):
...
@@ -307,7 +307,8 @@ class ZeroOptimizer(ColossalaiOptimizer):
if
range_pair
[
0
]
>=
range_pair
[
1
]:
if
range_pair
[
0
]
>=
range_pair
[
1
]:
continue
continue
fake_param
=
torch
.
nn
.
Parameter
(
torch
.
empty
([
0
]))
grad_device
=
self
.
module
.
grads_device
[
param
]
fake_param
=
torch
.
nn
.
Parameter
(
torch
.
empty
([
0
],
device
=
grad_device
))
self
.
param_to_chunk32
[
fake_param
]
=
chunk16
.
paired_chunk
self
.
param_to_chunk32
[
fake_param
]
=
chunk16
.
paired_chunk
self
.
param_to_range
[
fake_param
]
=
range_pair
self
.
param_to_range
[
fake_param
]
=
range_pair
...
...
tests/test_gemini/update/test_zerooptim_state_dict.py
View file @
56ddc9ca
...
@@ -70,8 +70,6 @@ def exam_zero_optim_state_dict(placement_policy, keep_gathered):
...
@@ -70,8 +70,6 @@ def exam_zero_optim_state_dict(placement_policy, keep_gathered):
for
n
,
m
in
v
.
items
():
for
n
,
m
in
v
.
items
():
if
isinstance
(
m
,
torch
.
Tensor
):
if
isinstance
(
m
,
torch
.
Tensor
):
o
=
w
[
n
]
o
=
w
[
n
]
if
m
.
device
!=
o
.
device
:
o
=
o
.
to
(
m
.
device
)
assert
torch
.
equal
(
m
,
o
)
assert
torch
.
equal
(
m
,
o
)
else
:
else
:
assert
m
==
w
[
n
]
assert
m
==
w
[
n
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment