Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
632e622d
Unverified
Commit
632e622d
authored
Dec 16, 2021
by
HELSON
Committed by
GitHub
Dec 16, 2021
Browse files
overlap computation and communication in 2d operations (#75)
parent
cd9c28e0
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
140 additions
and
45 deletions
+140
-45
colossalai/nn/layer/parallel_2d/_operation.py
colossalai/nn/layer/parallel_2d/_operation.py
+140
-45
No files found.
colossalai/nn/layer/parallel_2d/_operation.py
View file @
632e622d
...
...
@@ -85,30 +85,57 @@ class Matmul_AB_2D(torch.autograd.Function):
ctx
.
save_for_backward
(
A
,
B
)
A_shape
=
A
.
shape
A
=
A
.
reshape
((
-
1
,
A_shape
[
-
1
]))
.
contiguous
()
A
=
A
.
reshape
((
-
1
,
A_shape
[
-
1
]))
B_shape
=
B
.
shape
B
=
B
.
reshape
((
-
1
,
B_shape
[
-
1
]))
.
contiguous
()
B
=
B
.
reshape
((
-
1
,
B_shape
[
-
1
]))
C_shape
=
(
A
.
shape
[
0
],
B
.
shape
[
-
1
])
C
=
torch
.
zeros
(
C_shape
,
dtype
=
A
.
dtype
,
device
=
get_current_device
())
A_list
=
[
torch
.
empty_like
(
A
)
for
_
in
range
(
gpc
.
get_world_size
(
row_parallel_mode
)
-
1
)]
B_list
=
[
torch
.
empty_like
(
B
)
for
_
in
range
(
gpc
.
get_world_size
(
col_parallel_mode
)
-
1
)]
A_list
.
insert
(
gpc
.
get_local_rank
(
row_parallel_mode
),
A
)
B_list
.
insert
(
gpc
.
get_local_rank
(
col_parallel_mode
),
B
)
op_a
=
dist
.
all_gather
(
A_list
,
A
,
group
=
gpc
.
get_group
(
row_parallel_mode
),
async_op
=
True
)
op_a
.
wait
()
op_b
=
dist
.
all_gather
(
B_list
,
B
,
group
=
gpc
.
get_group
(
col_parallel_mode
),
async_op
=
True
)
for
op
in
[
op_a
,
op_b
]:
op
.
wait
()
# use circular buffer to store the communication tensor
# 2 is enough for all cases
A_list
=
[
torch
.
empty_like
(
A
)
for
_
in
range
(
2
)]
B_list
=
[
torch
.
empty_like
(
B
)
for
_
in
range
(
2
)]
row_group
=
gpc
.
get_group
(
row_parallel_mode
)
col_group
=
gpc
.
get_group
(
col_parallel_mode
)
src_a
=
summa_dim
*
row_rank
+
data_parallel_rank
*
pipeline_parallel_size
*
tensor_parallel_size
+
\
pipeline_parallel_rank
*
tensor_parallel_size
src_b
=
col_rank
+
data_parallel_rank
*
pipeline_parallel_size
*
tensor_parallel_size
+
\
pipeline_parallel_rank
*
tensor_parallel_size
opa
=
[
None
]
*
2
opb
=
[
None
]
*
2
A_list
[
0
].
copy_
(
A
)
B_list
[
0
].
copy_
(
B
)
opa
[
0
]
=
dist
.
broadcast
(
A_list
[
0
],
src
=
src_a
,
group
=
row_group
,
async_op
=
True
)
opb
[
0
]
=
dist
.
broadcast
(
B_list
[
0
],
src
=
src_b
,
group
=
col_group
,
async_op
=
True
)
cur
=
0
for
i
in
range
(
summa_dim
):
src_a
=
i
+
summa_dim
*
row_rank
src_b
=
i
+
summa_dim
*
col_rank
src_a
=
src_a
%
summa_dim
src_b
=
src_b
%
summa_dim
A_temp
=
A_list
[
src_a
]
B_temp
=
B_list
[
src_b
]
torch
.
addmm
(
C
,
A_temp
,
B_temp
,
out
=
C
)
if
i
!=
summa_dim
-
1
:
A_list
[
1
-
cur
].
copy_
(
A
)
opa
[
1
-
cur
]
=
dist
.
broadcast
(
A_list
[
1
-
cur
],
src
=
src_a
+
1
,
group
=
row_group
,
async_op
=
True
)
B_list
[
1
-
cur
].
copy_
(
B
)
opb
[
1
-
cur
]
=
dist
.
broadcast
(
B_list
[
1
-
cur
],
src
=
src_b
+
summa_dim
,
group
=
col_group
,
async_op
=
True
)
if
opa
[
cur
]
is
not
None
:
opa
[
cur
].
wait
()
if
opb
[
cur
]
is
not
None
:
opb
[
cur
].
wait
()
torch
.
addmm
(
C
,
A_list
[
cur
],
B_list
[
cur
],
out
=
C
)
cur
=
1
-
cur
src_a
+=
1
src_b
+=
summa_dim
out
=
C
.
reshape
(
out_shape
)
if
ctx
:
...
...
@@ -188,21 +215,55 @@ class Matmul_ABT_2D(torch.autograd.Function):
C_shape
=
(
A
.
shape
[
0
],
B
.
shape
[
0
])
C
=
torch
.
empty
(
C_shape
,
dtype
=
A
.
dtype
,
device
=
get_current_device
())
for
i
in
range
(
summa_dim
):
B_temp
=
B
.
clone
()
# C_temp = torch.zeros(C_shape, dtype=C.dtype, device=get_current_device())
src_b
=
col_rank
+
summa_dim
*
i
+
data_parallel_rank
*
pipeline_parallel_size
*
tensor_parallel_size
+
\
# use circular buffer to store the communication tensor
# 2 is enough for all cases
B_list
=
[
torch
.
empty_like
(
B
)
for
_
in
range
(
2
)]
C_list
=
[
torch
.
empty_like
(
C
)
for
_
in
range
(
2
)]
row_group
=
gpc
.
get_group
(
row_parallel_mode
)
col_group
=
gpc
.
get_group
(
col_parallel_mode
)
src_b
=
col_rank
+
data_parallel_rank
*
pipeline_parallel_size
*
tensor_parallel_size
+
\
pipeline_parallel_rank
*
tensor_parallel_size
dist
.
broadcast
(
B_temp
,
src
=
src_b
,
group
=
gpc
.
get_group
(
col_parallel_mode
))
C_temp
=
torch
.
matmul
(
A
,
B_temp
.
transpose
(
0
,
1
))
src_c
=
i
+
summa_dim
*
row_rank
+
data_parallel_rank
*
pipeline_parallel_size
*
tensor_parallel_size
+
\
src_c
=
summa_dim
*
row_rank
+
data_parallel_rank
*
pipeline_parallel_size
*
tensor_parallel_size
+
\
pipeline_parallel_rank
*
tensor_parallel_size
dist
.
reduce
(
C_temp
,
dst
=
src_c
,
group
=
gpc
.
get_group
(
row_parallel_mode
))
if
i
==
col_rank
:
C
=
C_temp
.
clone
()
opb
=
[
None
]
*
2
opr
=
[
None
]
*
2
B_list
[
0
].
copy_
(
B
)
opb
[
0
]
=
dist
.
broadcast
(
B_list
[
0
],
src
=
src_b
,
group
=
col_group
,
async_op
=
True
)
cur
=
0
for
i
in
range
(
summa_dim
):
if
i
!=
summa_dim
-
1
:
B_list
[
1
-
cur
].
copy_
(
B
)
opb
[
1
-
cur
]
=
dist
.
broadcast
(
B_list
[
1
-
cur
],
src
=
src_b
+
summa_dim
,
group
=
col_group
,
async_op
=
True
)
if
opr
[
cur
]
is
not
None
:
opr
[
cur
].
wait
()
if
i
-
2
==
col_rank
:
C
.
copy_
(
C_list
[
cur
])
if
opb
[
cur
]
is
not
None
:
opb
[
cur
].
wait
()
torch
.
matmul
(
A
,
B_list
[
cur
].
transpose
(
0
,
1
),
out
=
C_list
[
cur
])
opr
[
cur
]
=
dist
.
reduce
(
C_list
[
cur
],
dst
=
src_c
,
group
=
row_group
,
async_op
=
True
)
cur
=
1
-
cur
src_b
+=
summa_dim
src_c
+=
1
for
op
in
opr
:
op
.
wait
()
if
summa_dim
-
2
==
col_rank
:
C
.
copy_
(
C_list
[
cur
])
if
summa_dim
-
1
==
col_rank
:
C
.
copy_
(
C_list
[
1
-
cur
])
out
=
C
.
reshape
(
out_shape
)
if
ctx
:
...
...
@@ -284,21 +345,55 @@ class Matmul_ATB_2D(torch.autograd.Function):
C_shape
=
(
A
.
shape
[
-
1
],
B
.
shape
[
-
1
])
C
=
torch
.
empty
(
C_shape
,
dtype
=
A
.
dtype
,
device
=
get_current_device
())
for
i
in
range
(
summa_dim
):
A_temp
=
A
.
clone
()
# C_temp = torch.zeros(C_shape, dtype=C.dtype, device=get_current_device())
src_a
=
i
+
summa_dim
*
row_rank
+
data_parallel_rank
*
pipeline_parallel_size
*
tensor_parallel_size
+
\
# use circular buffer to store the communication tensor
# 2 is enough for all cases
A_list
=
[
torch
.
empty_like
(
A
)
for
_
in
range
(
2
)]
C_list
=
[
torch
.
empty_like
(
C
)
for
_
in
range
(
2
)]
row_group
=
gpc
.
get_group
(
row_parallel_mode
)
col_group
=
gpc
.
get_group
(
col_parallel_mode
)
src_a
=
summa_dim
*
row_rank
+
data_parallel_rank
*
pipeline_parallel_size
*
tensor_parallel_size
+
\
pipeline_parallel_rank
*
tensor_parallel_size
dist
.
broadcast
(
A_temp
,
src
=
src_a
,
group
=
gpc
.
get_group
(
row_parallel_mode
))
C_temp
=
torch
.
matmul
(
A_temp
.
transpose
(
0
,
1
),
B
)
src_c
=
col_rank
+
summa_dim
*
i
+
data_parallel_rank
*
pipeline_parallel_size
*
tensor_parallel_size
+
\
src_c
=
col_rank
+
data_parallel_rank
*
pipeline_parallel_size
*
tensor_parallel_size
+
\
pipeline_parallel_rank
*
tensor_parallel_size
dist
.
reduce
(
C_temp
,
dst
=
src_c
,
group
=
gpc
.
get_group
(
col_parallel_mode
))
if
i
==
row_rank
:
C
=
C_temp
.
clone
()
opa
=
[
None
]
*
2
opr
=
[
None
]
*
2
A_list
[
0
].
copy_
(
A
)
opa
[
0
]
=
dist
.
broadcast
(
A_list
[
0
],
src
=
src_a
,
group
=
row_group
,
async_op
=
True
)
cur
=
0
for
i
in
range
(
summa_dim
):
if
i
!=
summa_dim
-
1
:
A_list
[
1
-
cur
].
copy_
(
A
)
opa
[
1
-
cur
]
=
dist
.
broadcast
(
A_list
[
1
-
cur
],
src
=
src_a
+
1
,
group
=
row_group
,
async_op
=
True
)
if
opr
[
cur
]
is
not
None
:
opr
[
cur
].
wait
()
if
i
-
2
==
row_rank
:
C
.
copy_
(
C_list
[
cur
])
if
opa
[
cur
]
is
not
None
:
opa
[
cur
].
wait
()
torch
.
matmul
(
A_list
[
cur
].
transpose
(
0
,
1
),
B
,
out
=
C_list
[
cur
])
opr
[
cur
]
=
dist
.
reduce
(
C_list
[
cur
],
dst
=
src_c
,
group
=
col_group
,
async_op
=
True
)
cur
=
1
-
cur
src_a
+=
1
src_c
+=
summa_dim
for
op
in
opr
:
op
.
wait
()
if
summa_dim
-
2
==
row_rank
:
C
.
copy_
(
C_list
[
cur
])
if
summa_dim
-
1
==
row_rank
:
C
.
copy_
(
C_list
[
1
-
cur
])
out
=
C
.
reshape
(
out_shape
)
if
ctx
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment