OpenDAS / ColossalAI / Commits / 632e622d

Unverified commit 632e622d, authored Dec 16, 2021 by HELSON, committed by GitHub on Dec 16, 2021

overlap computation and communication in 2d operations (#75)

parent cd9c28e0
Showing 1 changed file with 140 additions and 45 deletions

colossalai/nn/layer/parallel_2d/_operation.py  (+140, -45)
@@ -85,30 +85,57 @@ class Matmul_AB_2D(torch.autograd.Function):
         ctx.save_for_backward(A, B)

         A_shape = A.shape
-        A = A.reshape((-1, A_shape[-1])).contiguous()
+        A = A.reshape((-1, A_shape[-1]))
         B_shape = B.shape
-        B = B.reshape((-1, B_shape[-1])).contiguous()
+        B = B.reshape((-1, B_shape[-1]))
         C_shape = (A.shape[0], B.shape[-1])
         C = torch.zeros(C_shape, dtype=A.dtype, device=get_current_device())

-        A_list = [torch.empty_like(A) for _ in range(gpc.get_world_size(row_parallel_mode) - 1)]
-        B_list = [torch.empty_like(B) for _ in range(gpc.get_world_size(col_parallel_mode) - 1)]
-        A_list.insert(gpc.get_local_rank(row_parallel_mode), A)
-        B_list.insert(gpc.get_local_rank(col_parallel_mode), B)
-        op_a = dist.all_gather(A_list, A, group=gpc.get_group(row_parallel_mode), async_op=True)
-        op_a.wait()
-        op_b = dist.all_gather(B_list, B, group=gpc.get_group(col_parallel_mode), async_op=True)
-        for op in [op_a, op_b]:
-            op.wait()
+        # use circular buffer to store the communication tensor
+        # 2 is enough for all cases
+        A_list = [torch.empty_like(A) for _ in range(2)]
+        B_list = [torch.empty_like(B) for _ in range(2)]
+        row_group = gpc.get_group(row_parallel_mode)
+        col_group = gpc.get_group(col_parallel_mode)
+        src_a = summa_dim * row_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
+            pipeline_parallel_rank * tensor_parallel_size
+        src_b = col_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
+            pipeline_parallel_rank * tensor_parallel_size
+        opa = [None] * 2
+        opb = [None] * 2
+        A_list[0].copy_(A)
+        B_list[0].copy_(B)
+        opa[0] = dist.broadcast(A_list[0], src=src_a, group=row_group, async_op=True)
+        opb[0] = dist.broadcast(B_list[0], src=src_b, group=col_group, async_op=True)
+        cur = 0

         for i in range(summa_dim):
-            src_a = i + summa_dim * row_rank
-            src_b = i + summa_dim * col_rank
-            src_a = src_a % summa_dim
-            src_b = src_b % summa_dim
-            A_temp = A_list[src_a]
-            B_temp = B_list[src_b]
-            torch.addmm(C, A_temp, B_temp, out=C)
+            if i != summa_dim - 1:
+                A_list[1 - cur].copy_(A)
+                opa[1 - cur] = dist.broadcast(A_list[1 - cur],
+                                              src=src_a + 1,
+                                              group=row_group,
+                                              async_op=True)
+                B_list[1 - cur].copy_(B)
+                opb[1 - cur] = dist.broadcast(B_list[1 - cur],
+                                              src=src_b + summa_dim,
+                                              group=col_group,
+                                              async_op=True)
+
+            if opa[cur] is not None:
+                opa[cur].wait()
+            if opb[cur] is not None:
+                opb[cur].wait()
+
+            torch.addmm(C, A_list[cur], B_list[cur], out=C)
+            cur = 1 - cur
+            src_a += 1
+            src_b += summa_dim

         out = C.reshape(out_shape)

         if ctx:
@@ -188,21 +215,55 @@ class Matmul_ABT_2D(torch.autograd.Function):
         C_shape = (A.shape[0], B.shape[0])
         C = torch.empty(C_shape, dtype=A.dtype, device=get_current_device())

-        for i in range(summa_dim):
-            B_temp = B.clone()
-            # C_temp = torch.zeros(C_shape, dtype=C.dtype, device=get_current_device())
-            src_b = col_rank + summa_dim * i + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
-                pipeline_parallel_rank * tensor_parallel_size
-            dist.broadcast(B_temp, src=src_b,
-                           group=gpc.get_group(col_parallel_mode))
-            C_temp = torch.matmul(A, B_temp.transpose(0, 1))
-            src_c = i + summa_dim * row_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
-                pipeline_parallel_rank * tensor_parallel_size
-            dist.reduce(C_temp, dst=src_c,
-                        group=gpc.get_group(row_parallel_mode))
-            if i == col_rank:
-                C = C_temp.clone()
+        # use circular buffer to store the communication tensor
+        # 2 is enough for all cases
+        B_list = [torch.empty_like(B) for _ in range(2)]
+        C_list = [torch.empty_like(C) for _ in range(2)]
+        row_group = gpc.get_group(row_parallel_mode)
+        col_group = gpc.get_group(col_parallel_mode)
+        src_b = col_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
+            pipeline_parallel_rank * tensor_parallel_size
+        src_c = summa_dim * row_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
+            pipeline_parallel_rank * tensor_parallel_size
+        opb = [None] * 2
+        opr = [None] * 2
+        B_list[0].copy_(B)
+        opb[0] = dist.broadcast(B_list[0], src=src_b, group=col_group, async_op=True)
+        cur = 0
+
+        for i in range(summa_dim):
+            if i != summa_dim - 1:
+                B_list[1 - cur].copy_(B)
+                opb[1 - cur] = dist.broadcast(B_list[1 - cur], src=src_b + summa_dim,
+                                              group=col_group, async_op=True)
+
+            if opr[cur] is not None:
+                opr[cur].wait()
+                if i - 2 == col_rank:
+                    C.copy_(C_list[cur])
+
+            if opb[cur] is not None:
+                opb[cur].wait()
+
+            torch.matmul(A, B_list[cur].transpose(0, 1), out=C_list[cur])
+            opr[cur] = dist.reduce(C_list[cur], dst=src_c, group=row_group, async_op=True)
+            cur = 1 - cur
+            src_b += summa_dim
+            src_c += 1
+
+        for op in opr:
+            op.wait()
+        if summa_dim - 2 == col_rank:
+            C.copy_(C_list[cur])
+        if summa_dim - 1 == col_rank:
+            C.copy_(C_list[1 - cur])

         out = C.reshape(out_shape)

         if ctx:
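In Matmul_ABT_2D the per-step reduce also becomes asynchronous: each partial result goes into a two-slot C_list, and the rank that owns step i only copies its reduced block into C once that reduce's handle has completed, two iterations later (the i - 2 == col_rank check), with the final two results collected after the loop. A single-process sketch of that deferred-collection pattern, with FakeWork as a hypothetical stand-in for the async reduce handle:

import torch

class FakeWork:
    # hypothetical stand-in for the handle returned by dist.reduce(async_op=True)
    def wait(self):
        pass

summa_dim = 4
col_rank = 2                    # pretend this process owns step number col_rank
m, k, n = 8, 6, 5
A = torch.randn(m, k)
B_panels = [torch.randn(n, k) for _ in range(summa_dim)]

C = torch.empty(m, n)
C_list = [torch.empty(m, n) for _ in range(2)]
opr = [None] * 2
cur = 0

for i in range(summa_dim):
    if opr[cur] is not None:
        opr[cur].wait()
        # the result produced two iterations ago is only now safe to consume
        if i - 2 == col_rank:
            C.copy_(C_list[cur])
    torch.matmul(A, B_panels[i].transpose(0, 1), out=C_list[cur])
    opr[cur] = FakeWork()       # real code: dist.reduce(C_list[cur], ..., async_op=True)
    cur = 1 - cur

# drain: the last two steps were never collected inside the loop
for op in opr:
    op.wait()
if summa_dim - 2 == col_rank:
    C.copy_(C_list[cur])
if summa_dim - 1 == col_rank:
    C.copy_(C_list[1 - cur])

assert torch.allclose(C, A @ B_panels[col_rank].transpose(0, 1))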
@@ -284,21 +345,55 @@ class Matmul_ATB_2D(torch.autograd.Function):
         C_shape = (A.shape[-1], B.shape[-1])
         C = torch.empty(C_shape, dtype=A.dtype, device=get_current_device())

-        for i in range(summa_dim):
-            A_temp = A.clone()
-            # C_temp = torch.zeros(C_shape, dtype=C.dtype, device=get_current_device())
-            src_a = i + summa_dim * row_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
-                pipeline_parallel_rank * tensor_parallel_size
-            dist.broadcast(A_temp, src=src_a,
-                           group=gpc.get_group(row_parallel_mode))
-            C_temp = torch.matmul(A_temp.transpose(0, 1), B)
-            src_c = col_rank + summa_dim * i + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
-                pipeline_parallel_rank * tensor_parallel_size
-            dist.reduce(C_temp, dst=src_c,
-                        group=gpc.get_group(col_parallel_mode))
-            if i == row_rank:
-                C = C_temp.clone()
+        # use circular buffer to store the communication tensor
+        # 2 is enough for all cases
+        A_list = [torch.empty_like(A) for _ in range(2)]
+        C_list = [torch.empty_like(C) for _ in range(2)]
+        row_group = gpc.get_group(row_parallel_mode)
+        col_group = gpc.get_group(col_parallel_mode)
+        src_a = summa_dim * row_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
+            pipeline_parallel_rank * tensor_parallel_size
+        src_c = col_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
+            pipeline_parallel_rank * tensor_parallel_size
+        opa = [None] * 2
+        opr = [None] * 2
+        A_list[0].copy_(A)
+        opa[0] = dist.broadcast(A_list[0], src=src_a, group=row_group, async_op=True)
+        cur = 0
+
+        for i in range(summa_dim):
+            if i != summa_dim - 1:
+                A_list[1 - cur].copy_(A)
+                opa[1 - cur] = dist.broadcast(A_list[1 - cur], src=src_a + 1,
+                                              group=row_group, async_op=True)
+
+            if opr[cur] is not None:
+                opr[cur].wait()
+                if i - 2 == row_rank:
+                    C.copy_(C_list[cur])
+
+            if opa[cur] is not None:
+                opa[cur].wait()
+
+            torch.matmul(A_list[cur].transpose(0, 1), B, out=C_list[cur])
+            opr[cur] = dist.reduce(C_list[cur], dst=src_c, group=col_group, async_op=True)
+            cur = 1 - cur
+            src_a += 1
+            src_c += summa_dim
+
+        for op in opr:
+            op.wait()
+        if summa_dim - 2 == row_rank:
+            C.copy_(C_list[cur])
+        if summa_dim - 1 == row_rank:
+            C.copy_(C_list[1 - cur])

         out = C.reshape(out_shape)

         if ctx:
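Matmul_ATB_2D mirrors the previous kernel with the roles of rows and columns swapped. The correctness of all three pipelined loops rests on the ordinary block-matmul identity: splitting the operands into summa_dim blocks and accumulating the per-block products gives the same result as the monolithic product. A quick single-process check of that identity (shapes chosen arbitrarily for illustration):

import torch

# Block-matmul identity behind the SUMMA-style loops: with A and B split into
# summa_dim row-blocks, A^T @ B equals the sum of the per-block products
# A_i^T @ B_i, which is exactly what the pipelined loop accumulates, one block
# per broadcast/reduce step.
summa_dim = 4
rows, p, q = 16, 6, 5           # rows must be divisible by summa_dim
A = torch.randn(rows, p)
B = torch.randn(rows, q)

C = torch.zeros(p, q)
for A_i, B_i in zip(A.chunk(summa_dim, dim=0), B.chunk(summa_dim, dim=0)):
    C += A_i.transpose(0, 1) @ B_i      # partial product for one step

assert torch.allclose(C, A.transpose(0, 1) @ B, atol=1e-5)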
In the three remaining hunks (Add_Bias_2D), both sides of the side-by-side view show the same code, so the change there appears to be formatting-only.

@@ -374,7 +469,7 @@ class Add_Bias_2D(torch.autograd.Function):
                                     dtype=bias.dtype,
                                     device=get_current_device())
            src_rank = col_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
                pipeline_parallel_rank * tensor_parallel_size
            dist.broadcast(bias_temp, src=src_rank,
                           group=gpc.get_group(col_parallel_mode))

@@ -408,7 +503,7 @@ class Add_Bias_2D(torch.autograd.Function):
        if ctx.bias:
            dst_rank = col_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
                pipeline_parallel_rank * tensor_parallel_size
            dist.reduce(output_grad, dst=dst_rank,
                        group=gpc.get_group(col_parallel_mode))
            if row_rank == 0:

@@ -421,7 +516,7 @@ class Add_Bias_2D(torch.autograd.Function):
            reduce_dim = tuple(range(output_grad.ndim - 1))
            reduce = torch.sum(output_grad, dim=reduce_dim)
            dst_rank = col_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
                pipeline_parallel_rank * tensor_parallel_size
            dist.reduce(reduce, dst=dst_rank,
                        group=gpc.get_group(col_parallel_mode))
            if row_rank == 0:
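All the src and dst arguments in this file use the same arithmetic to turn a 2D-mesh coordinate into a global rank. A toy evaluation of those expressions with made-up sizes, just to show how the offsets combine (the variable values below are illustrative, not taken from any real configuration):

# Illustrative values only (hypothetical configuration):
data_parallel_rank = 1
pipeline_parallel_rank = 0
pipeline_parallel_size = 2
tensor_parallel_size = 4        # summa_dim ** 2 on a 2x2 mesh
summa_dim = 2
row_rank, col_rank = 1, 0

offset = data_parallel_rank * pipeline_parallel_size * tensor_parallel_size \
    + pipeline_parallel_rank * tensor_parallel_size
src_a = summa_dim * row_rank + offset   # source for the first row-wise broadcast
src_b = col_rank + offset               # source for the first column-wise broadcast
print(offset, src_a, src_b)             # -> 8 10 8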