Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
0653c63e
Unverified
Commit
0653c63e
authored
Jun 08, 2022
by
Ziyue Jiang
Committed by
GitHub
Jun 08, 2022
Browse files
[Tensor] 1d row embedding (#1075)
* Add CPU 1d row embedding * polish
parent
d66ffb4d
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
12 additions
and
10 deletions
+12
-10
colossalai/nn/layer/parallel_1d/_utils.py
colossalai/nn/layer/parallel_1d/_utils.py
+4
-2
tests/test_tensor/test_hybrid_device.py
tests/test_tensor/test_hybrid_device.py
+8
-8
No files found.
colossalai/nn/layer/parallel_1d/_utils.py
View file @
0653c63e
...
...
@@ -32,7 +32,8 @@ def _reduce(input_, parallel_mode):
# skip if only one rank involved
if
gpc
.
get_world_size
(
parallel_mode
)
==
1
:
return
input_
dist
.
all_reduce
(
input_
,
group
=
gpc
.
get_group
(
parallel_mode
))
group
=
gpc
.
get_cpu_group
(
parallel_mode
)
if
input_
.
device
.
type
==
"cpu"
else
gpc
.
get_group
(
parallel_mode
)
dist
.
all_reduce
(
input_
,
group
=
group
)
return
input_
...
...
@@ -66,7 +67,8 @@ def _gather(input_, parallel_mode, dim=-1):
rank
=
gpc
.
get_local_rank
(
parallel_mode
)
tensor_list
=
[
torch
.
empty_like
(
input_
)
for
_
in
range
(
world_size
)]
tensor_list
[
rank
]
=
input_
torch
.
distributed
.
all_gather
(
tensor_list
,
input_
,
group
=
gpc
.
get_group
(
parallel_mode
))
group
=
gpc
.
get_cpu_group
(
parallel_mode
)
if
input_
.
device
.
type
==
"cpu"
else
gpc
.
get_group
(
parallel_mode
)
torch
.
distributed
.
all_gather
(
tensor_list
,
input_
,
group
=
group
)
# concat
output
=
torch
.
cat
(
tensor_list
,
dim
=
dim
).
contiguous
()
...
...
tests/test_tensor/test_hybrid_device.py
View file @
0653c63e
...
...
@@ -35,7 +35,7 @@ class Net(torch.nn.Module):
return
x
def
run_hybrid_device
(
use_ddp
):
def
run_hybrid_device
(
use_ddp
,
mode
):
with
ColoInitContext
(
device
=
get_current_device
()):
model
=
Net
()
...
...
@@ -47,7 +47,7 @@ def run_hybrid_device(use_ddp):
print
(
f
'embedding weight size:
{
real_model
.
embed
.
weight
.
size
()
}
| device:
{
real_model
.
embed
.
weight
.
device
}
'
)
#print(f'linear weight size: {real_model.proj.weight.size()} | device: {real_model.proj.weight.device}')
parallel_action
=
ParallelAction
(
ComputePattern
.
TP1D
)
init_colo_module
(
model
,
parallel_action
,
recursive
=
True
,
mode
=
'col'
)
init_colo_module
(
model
,
parallel_action
,
recursive
=
True
,
mode
=
mode
)
# use cpu gloo to handle embedding
real_model
.
embed
.
to
(
'cpu'
)
...
...
@@ -63,24 +63,24 @@ def run_hybrid_device(use_ddp):
out
.
sum
().
backward
()
optimizer
.
step
()
def
run_dist
(
rank
,
world_size
,
port
,
use_ddp
):
def
run_dist
(
rank
,
world_size
,
port
,
use_ddp
,
mode
):
if
use_ddp
and
world_size
==
1
:
return
tp_world_size
=
world_size
//
2
if
use_ddp
else
world_size
config
=
dict
(
parallel
=
dict
(
tensor
=
dict
(
mode
=
"1d"
,
size
=
tp_world_size
),))
colossalai
.
launch
(
config
=
config
,
rank
=
rank
,
world_size
=
world_size
,
host
=
'localhost'
,
port
=
port
,
backend
=
'nccl'
)
run_hybrid_device
(
use_ddp
)
run_hybrid_device
(
use_ddp
,
mode
)
@
pytest
.
mark
.
dist
@
pytest
.
mark
.
parametrize
(
'world_size'
,
[
1
,
4
])
@
pytest
.
mark
.
parametrize
(
'use_ddp'
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
'mode'
,
[
'col'
,
'row'
])
@
rerun_if_address_is_in_use
()
# Working for simulate the embedding(CPU DP+TP) -> nn(GPU DP+TP)
def
_test_hybrid_device
(
world_size
,
use_ddp
):
run_func
=
partial
(
run_dist
,
world_size
=
world_size
,
port
=
free_port
(),
use_ddp
=
use_ddp
)
def
_test_hybrid_device
(
world_size
,
use_ddp
,
mode
):
run_func
=
partial
(
run_dist
,
world_size
=
world_size
,
port
=
free_port
(),
use_ddp
=
use_ddp
,
mode
=
mode
)
mp
.
spawn
(
run_func
,
nprocs
=
world_size
)
if
__name__
==
'__main__'
:
_test_hybrid_device
(
4
,
True
)
_test_hybrid_device
(
4
,
True
,
'row'
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment