Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
986f8cba
Unverified
Commit
986f8cba
authored
Nov 10, 2022
by
Jiarui Fang
Committed by
GitHub
Nov 10, 2022
Browse files
[inference] overlap comm and compute in Linear1D_Row when stream_chunk_num > 1 (#1876)
parent
1b494ad7
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
20 additions
and
8 deletions
+20
-8
colossalai/nn/layer/parallel_1d/layers.py
colossalai/nn/layer/parallel_1d/layers.py
+15
-6
tests/test_layers/test_1d/checks_1d/check_layer_1d.py
tests/test_layers/test_1d/checks_1d/check_layer_1d.py
+5
-2
No files found.
colossalai/nn/layer/parallel_1d/layers.py
View file @
986f8cba
...
@@ -706,13 +706,22 @@ class Linear1D_Row(ParallelLayer):
...
@@ -706,13 +706,22 @@ class Linear1D_Row(ParallelLayer):
input_
=
split_forward_gather_backward
(
input_
,
ParallelMode
.
PARALLEL_1D
,
dim
=-
1
)
input_
=
split_forward_gather_backward
(
input_
,
ParallelMode
.
PARALLEL_1D
,
dim
=-
1
)
if
self
.
stream_chunk_num
>
1
:
if
self
.
stream_chunk_num
>
1
:
output_parallel_list
=
[
None
for
i
in
range
(
self
.
stream_chunk_num
)]
if
self
.
training
:
for
i
in
range
(
self
.
stream_chunk_num
):
raise
RuntimeError
(
"use stream_chunk_num=1 in Linear1D_Row for training!"
)
output_parallel_list
[
i
]
=
F
.
linear
(
input_
,
self
.
weight_list
[
i
])
with
torch
.
no_grad
():
output_parallel_list
[
i
]
=
reduce_input
(
output_parallel_list
[
i
],
ParallelMode
.
PARALLEL_1D
)
output_parallel_list
=
[
None
for
i
in
range
(
self
.
stream_chunk_num
)]
output
=
torch
.
cat
(
output_parallel_list
,
dim
=-
1
)
handle_list
=
[]
for
i
in
range
(
self
.
stream_chunk_num
):
output_parallel_list
[
i
]
=
F
.
linear
(
input_
,
self
.
weight_list
[
i
])
handle
=
torch
.
distributed
.
all_reduce
(
output_parallel_list
[
i
],
group
=
gpc
.
get_group
(
ParallelMode
.
PARALLEL_1D
),
async_op
=
True
)
handle_list
.
append
(
handle
)
# output_parallel_list[i] = reduce_input(output_parallel_list[i], ParallelMode.PARALLEL_1D)
for
handle
in
handle_list
:
handle
.
wait
()
output
=
torch
.
cat
(
output_parallel_list
,
dim
=-
1
)
else
:
else
:
print
(
input_
.
shape
,
self
.
weight
.
shape
)
output_parallel
=
F
.
linear
(
input_
,
self
.
weight
)
output_parallel
=
F
.
linear
(
input_
,
self
.
weight
)
# output_parallel = linear_with_async_comm(input_, self.weight, None, ParallelMode.PARALLEL_1D, False)
# output_parallel = linear_with_async_comm(input_, self.weight, None, ParallelMode.PARALLEL_1D, False)
output
=
reduce_input
(
output_parallel
,
ParallelMode
.
PARALLEL_1D
)
output
=
reduce_input
(
output_parallel
,
ParallelMode
.
PARALLEL_1D
)
...
...
tests/test_layers/test_1d/checks_1d/check_layer_1d.py
View file @
986f8cba
...
@@ -514,8 +514,9 @@ def check_linear_row_stream_inference():
...
@@ -514,8 +514,9 @@ def check_linear_row_stream_inference():
i
=
gpc
.
get_local_rank
(
ParallelMode
.
PARALLEL_1D
)
i
=
gpc
.
get_local_rank
(
ParallelMode
.
PARALLEL_1D
)
assert
HIDDEN_SIZE
%
2
==
0
stream_chunk_num
=
4
layer
=
Linear1D_Row
(
OUTPUT_SIZE
,
INPUT_SIZE
,
stream_chunk_num
=
2
)
assert
HIDDEN_SIZE
%
stream_chunk_num
==
0
layer
=
Linear1D_Row
(
OUTPUT_SIZE
,
INPUT_SIZE
,
stream_chunk_num
=
stream_chunk_num
)
A_shape
=
(
BATCH_SIZE
,
SEQ_LENGTH
,
OUTPUT_SIZE
)
A_shape
=
(
BATCH_SIZE
,
SEQ_LENGTH
,
OUTPUT_SIZE
)
A_master
=
torch
.
randn
(
A_shape
,
dtype
=
dtype
,
device
=
device
)
A_master
=
torch
.
randn
(
A_shape
,
dtype
=
dtype
,
device
=
device
)
...
@@ -537,6 +538,8 @@ def check_linear_row_stream_inference():
...
@@ -537,6 +538,8 @@ def check_linear_row_stream_inference():
layer
.
weight
=
Parameter
(
W
)
layer
.
weight
=
Parameter
(
W
)
layer
.
bias
=
Parameter
(
B
)
layer
.
bias
=
Parameter
(
B
)
layer
.
chunk_weight
()
layer
.
chunk_weight
()
layer
.
eval
()
out
=
layer
(
A
)
out
=
layer
(
A
)
A_master
=
A_master
.
clone
()
A_master
=
A_master
.
clone
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment