OpenDAS / ColossalAI · Commits

Commit 4fc748f6 (unverified)

[Tensor] fix optimizer for CPU parallel (#1069)

Authored Jun 06, 2022 by Ziyue Jiang, committed by GitHub on Jun 06, 2022.
Parent: 49832b23
Showing 2 changed files, with 16 additions and 12 deletions (+16 −12):

- colossalai/nn/parallel/data_parallel.py (+12 −10)
- tests/test_tensor/test_hybrid_device.py (+4 −2)
colossalai/nn/parallel/data_parallel.py @ 4fc748f6

(The body of `grad_handle` that survives the change moves one indent level under the new `if grad.device.type != "cpu":` guard; indentation-only moves are shown as context at their new depth.)

```diff
@@ -42,14 +42,15 @@ class ColoDDP(torch.nn.Module):
         loss.backward()
         torch.cuda.current_stream().wait_stream(self.comm_stream)
         for p in self.module.parameters():
-            p.grad = p._saved_grad
+            if p.grad.device.type != "cpu":
+                p.grad = p._saved_grad

     def grad_handle(self, p, grad):
-        empty_grad = torch.empty_like(grad)
-        free_storage(empty_grad)
-        if self.dp_world_size > 1:
-            grad = grad / self.dp_world_size
+        if grad.device.type != "cpu":
+            empty_grad = torch.empty_like(grad)
+            free_storage(empty_grad)
+            if self.dp_world_size > 1:
+                grad = grad / self.dp_world_size
                 self.comm_stream.wait_stream(torch.cuda.current_stream())
                 with torch.cuda.stream(self.comm_stream):
                     group = gpc.get_group(ParallelMode.DATA)
@@ -57,12 +58,13 @@ class ColoDDP(torch.nn.Module):
                     ColoDDP._save_grad(p, grad)
                 grad.record_stream(self.comm_stream)
             else:
-                group = gpc.get_cpu_group(ParallelMode.DATA)
-                dist.all_reduce(grad, group=group)
                 ColoDDP._save_grad(p, grad)
             return empty_grad
+        else:
+            group = gpc.get_cpu_group(ParallelMode.DATA)
+            dist.all_reduce(grad, group=group)
+            return grad

     @staticmethod
     def _save_grad(p, grad):
```
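As read from the diff, CUDA gradients keep the stream-synchronized all-reduce over the data-parallel GPU group, while CPU-resident gradients (e.g. parameters placed on the host in the hybrid-device setup) now take a dedicated branch: they are all-reduced over a CPU process group and the reduced tensor is returned directly, and `backward()` only swaps in `_saved_grad` for non-CPU parameters. The sketch below shows the underlying `torch.distributed` pattern that a helper like `gpc.get_cpu_group` presumably wraps: a Gloo group can all-reduce host tensors, which the usual GPU backend (NCCL) cannot. Everything here (port, world size, tensor shape, the pre-division to turn a SUM reduce into an average) is illustrative, not taken from the commit.

```python
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def worker(rank: int, world_size: int):
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29501"   # arbitrary free port for the demo
    # Gloo can reduce CPU tensors; NCCL cannot, hence a separate CPU group.
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    cpu_group = dist.new_group(backend="gloo")

    grad = torch.full((4,), float(rank + 1))  # stand-in for a CPU-resident gradient
    grad = grad / world_size                  # divide first so the SUM reduce averages
    dist.all_reduce(grad, group=cpu_group)    # legal for host tensors under Gloo
    print(f"rank {rank}: averaged grad = {grad.tolist()}")

    dist.destroy_process_group()

if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2)
```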
tests/test_tensor/test_hybrid_device.py @ 4fc748f6

```diff
@@ -9,6 +9,7 @@ from colossalai.context import ParallelMode
 from colossalai.nn.parallel.layers import init_colo_module
 from colossalai.nn.parallel.data_parallel import ColoDDP
+from colossalai.nn.optimizer import ColoOptimizer

 import colossalai
 import torch
@@ -56,10 +57,11 @@ def run_hybrid_device(use_ddp):
     print(f'embedding weight size: {real_model.embed.weight.size()} | new device: {real_model.embed.weight.device}')
     #print(f'linear weight size: {real_model.proj.weight.size()} | new device: {real_model.proj.weight.device}')

+    optimizer = ColoOptimizer(dict(model.named_parameters()), torch.optim.SGD, lr=0.1)
     data = torch.randint(low=0, high=20, size=(16,), device=get_current_device())
     out = model(data)
     out.sum().backward()
+    optimizer.step()

 def run_dist(rank, world_size, port, use_ddp):
     if use_ddp and world_size == 1:
@@ -81,4 +83,4 @@ def _test_hybrid_device(world_size, use_ddp):

 if __name__ == '__main__':
     _test_hybrid_device(1, False)
     _test_hybrid_device(4, True)
```
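The test now builds a `ColoOptimizer` over the hybrid model's parameters and runs a full step after `backward()`, which is exactly the case the `data_parallel.py` fix enables. Below is a rough single-process analogue of what the hybrid-device setup asks of an optimizer, with plain `torch.optim.SGD` standing in for `ColoOptimizer`; the `embed`/`proj` split echoes the test model's attribute names, and the device placement is illustrative only.

```python
import torch

# One parameter stays on the CPU, one moves to the accelerator (if present),
# mimicking a hybrid placement like the test's embed/proj split.
device = "cuda" if torch.cuda.is_available() else "cpu"

embed = torch.nn.Embedding(20, 8)         # host-resident weight
proj = torch.nn.Linear(8, 8).to(device)   # device-resident weight

# A single optimizer must update parameters wherever they live; SGD applies
# each update on the device holding that parameter and its gradient.
optimizer = torch.optim.SGD(
    list(embed.parameters()) + list(proj.parameters()), lr=0.1)

data = torch.randint(low=0, high=20, size=(16,))
out = proj(embed(data).to(device))
out.sum().backward()

print(embed.weight.grad.device, proj.weight.grad.device)  # cpu vs. cuda:0 (if available)
optimizer.step()
```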