Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
c2fdc6a0
Unverified
Commit
c2fdc6a0
authored
May 16, 2022
by
ver217
Committed by
GitHub
May 16, 2022
Browse files
[tensor] derive compute pattern from dist spec (#971)
* derive compute pattern from dist spec * polish code
parent
46bc9570
Changes
10
Show whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
79 additions
and
65 deletions
+79
-65
colossalai/tensor/_ops/addmm.py
colossalai/tensor/_ops/addmm.py
+8
-8
colossalai/tensor/_ops/embedding.py
colossalai/tensor/_ops/embedding.py
+7
-6
colossalai/tensor/_ops/linear.py
colossalai/tensor/_ops/linear.py
+12
-10
colossalai/tensor/_ops/loss.py
colossalai/tensor/_ops/loss.py
+7
-6
colossalai/tensor/dist_spec.py
colossalai/tensor/dist_spec.py
+2
-0
colossalai/tensor/spec.py
colossalai/tensor/spec.py
+16
-17
tests/test_tensor/test_addmm_tp.py
tests/test_tensor/test_addmm_tp.py
+2
-2
tests/test_tensor/test_embedding_tp.py
tests/test_tensor/test_embedding_tp.py
+2
-2
tests/test_tensor/test_linear_tp.py
tests/test_tensor/test_linear_tp.py
+2
-2
tests/test_tensor/test_model.py
tests/test_tensor/test_model.py
+21
-12
No files found.
colossalai/tensor/_ops/addmm.py
View file @
c2fdc6a0
...
...
@@ -11,7 +11,7 @@ from colossalai.tensor import dist_spec
def
colo_addmm_1Drow
(
input_tensor
:
ColoTensor
,
mat1
:
ColoTensor
,
mat2
:
ColoTensor
,
beta
:
Union
[
int
,
float
],
alpha
:
Union
[
int
,
float
])
->
ColoTensor
:
parallel_action
=
mat2
.
spec
.
get_action_by_compute_pattern
(
ComputePattern
.
TP1D
Row
)
parallel_action
=
mat2
.
spec
.
get_action_by_compute_pattern
(
ComputePattern
.
TP1D
)
# mat1:S[1] x mat2:S[0] = Output:P
# beta * input + alpha * All-Reduce(Output) = res
...
...
@@ -32,7 +32,7 @@ def colo_addmm_1Drow(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTenso
def
colo_addmm_1Dcol
(
input_tensor
:
ColoTensor
,
mat1
:
ColoTensor
,
mat2
:
ColoTensor
,
beta
:
Union
[
int
,
float
],
alpha
:
Union
[
int
,
float
])
->
ColoTensor
:
# mat1:B x mat2:S[1] + input:S[1] = Output:S[1]
parallel_action
=
mat2
.
spec
.
get_action_by_compute_pattern
(
ComputePattern
.
TP1D
Col
)
parallel_action
=
mat2
.
spec
.
get_action_by_compute_pattern
(
ComputePattern
.
TP1D
)
mat1
.
to_dist_spec
(
dist_spec
.
replicate
(
mat2
.
spec
.
get_process_group
()))
mat1_torch_tensor
=
reduce_grad
(
mat1
.
torch_tensor
(),
parallel_action
.
parallel_mode
)
...
...
@@ -71,16 +71,16 @@ def colo_addmm(types, args, kwargs, pg):
# Add communication logic before and after the addmm call.
ret_tensor
=
None
if
not
mat2
.
has_spec
():
# No Model Parallel Applied
assert
not
input_tensor
.
has_spec
(),
'Invalid input spec for native addmm op'
assert
mat2
.
spec
.
is_gathered
(),
'Invalid mat2 spec for native addmm op'
assert
input_tensor
.
spec
.
is_gathered
(),
'Invalid input spec for native addmm op'
ret_tensor
=
ColoTensor
.
init_from_torch_tensor
(
torch
.
add
b
mm
(
input_tensor
.
torch_tensor
(),
mat1
,
mat2
.
torch_tensor
(),
beta
=
beta
,
alpha
=
alpha
))
elif
mat2
.
spec
.
num_action
==
1
:
# Single Model Parallel Applied
torch
.
addmm
(
input_tensor
.
torch_tensor
(),
mat1
,
mat2
.
torch_tensor
(),
beta
=
beta
,
alpha
=
alpha
))
elif
mat2
.
spec
.
has_compute_pattern
(
ComputePattern
.
TP1D
)
:
# Single Model Parallel Applied
spec
=
TensorSpec
(
dist_spec
.
replicate
(
mat2
.
spec
.
get_process_group
()))
mat1
=
args
[
1
]
if
isinstance
(
args
[
1
],
ColoTensor
)
else
ColoTensor
.
init_from_torch_tensor
(
args
[
1
],
spec
=
spec
)
compute_patterns
=
mat2
.
spec
.
compute_patterns
if
ComputePattern
.
TP1DRow
in
compute_patterns
:
if
mat2
.
spec
.
is_1D_row
()
and
input_tensor
.
spec
.
is_gathered
():
ret_tensor
=
colo_addmm_1Drow
(
input_tensor
,
mat1
,
mat2
,
beta
,
alpha
)
elif
ComputePattern
.
TP1DCol
in
compute_patterns
:
elif
mat2
.
spec
.
is_1D_col
()
and
(
input_tensor
.
spec
.
is_1D_col
()
or
input_tensor
.
spec
.
is_1D_row
())
:
ret_tensor
=
colo_addmm_1Dcol
(
input_tensor
,
mat1
,
mat2
,
beta
,
alpha
)
else
:
raise
NotImplementedError
...
...
colossalai/tensor/_ops/embedding.py
View file @
c2fdc6a0
...
...
@@ -12,7 +12,7 @@ from colossalai.tensor import ComputePattern, TensorSpec, ComputePattern, Parall
def
colo_embedding_1Dcol
(
input_tensor
:
ColoTensor
,
weight
:
ColoTensor
,
args
,
kwargs
)
->
ColoTensor
:
# embedding_1Dcol splits the weight (lookup table) into shards of shape (num_embeddings, embedding_dim/P)
# Gather the split lookup table
parallel_action
=
weight
.
spec
.
get_action_by_compute_pattern
(
ComputePattern
.
TP1D
Col
)
parallel_action
=
weight
.
spec
.
get_action_by_compute_pattern
(
ComputePattern
.
TP1D
)
input_tensor
.
to_dist_spec
(
dist_spec
.
replicate
(
weight
.
spec
.
get_process_group
()))
output_parallel
=
torch
.
nn
.
functional
.
embedding
(
input_tensor
.
torch_tensor
(),
weight
.
torch_tensor
(),
*
args
,
**
kwargs
)
...
...
@@ -28,7 +28,7 @@ def colo_embedding_1Drow(input_tensor: ColoTensor, weight: ColoTensor, args, kwa
# embedding_1Drow splits the weight (lookup table) into shards of shape (num_embeddings/P, embedding_dim)
# Find index in this shard and mask those not here
# Reduce all
parallel_action
=
weight
.
spec
.
get_action_by_compute_pattern
(
ComputePattern
.
TP1D
Row
)
parallel_action
=
weight
.
spec
.
get_action_by_compute_pattern
(
ComputePattern
.
TP1D
)
input_tensor
.
to_dist_spec
(
dist_spec
.
replicate
(
weight
.
spec
.
get_process_group
()))
tensor_parallel_rank
=
gpc
.
get_local_rank
(
parallel_action
.
parallel_mode
)
...
...
@@ -71,16 +71,17 @@ def colo_embedding(types, args, kwargs, pg):
weight
=
ColoTensor
.
init_from_torch_tensor
(
weight
)
# Handle different parallel actions.
if
not
weight
.
has_spec
():
# No Model Parallel Applied
assert
weight
.
spec
.
is_gathered
(),
'Invalid weight spec for native embedding op'
input_tensor
=
input_tensor
.
torch_tensor
()
weight
=
weight
.
torch_tensor
()
output
=
torch
.
nn
.
functional
.
embedding
(
input_tensor
,
weight
,
*
args
,
**
kwargs
)
return
ColoTensor
.
init_from_torch_tensor
(
output
)
elif
weight
.
spec
.
num_action
==
1
:
# Single Model Parallel Applied
compute_patterns
=
weight
.
spec
.
compute_patterns
if
ComputePattern
.
TP1DRow
in
compute_patterns
:
elif
weight
.
spec
.
has_compute_pattern
(
ComputePattern
.
TP1D
):
# Single Model Parallel Applied
if
weight
.
spec
.
is_1D_row
():
return
colo_embedding_1Drow
(
input_tensor
,
weight
,
args
,
kwargs
)
elif
ComputePattern
.
TP1DCol
in
compute_patterns
:
elif
weight
.
spec
.
is_1D_col
()
:
return
colo_embedding_1Dcol
(
input_tensor
,
weight
,
args
,
kwargs
)
else
:
raise
NotImplementedError
...
...
colossalai/tensor/_ops/linear.py
View file @
c2fdc6a0
...
...
@@ -9,7 +9,7 @@ from colossalai.tensor.graph import GraphOpNode, GraphGlobalEnv
def
colo_linear_1Drow
(
input_tensor
:
ColoTensor
,
weight
:
ColoTensor
,
bias
:
ColoTensor
)
->
ColoTensor
:
parallel_action
=
weight
.
spec
.
get_action_by_compute_pattern
(
ComputePattern
.
TP1D
Row
)
parallel_action
=
weight
.
spec
.
get_action_by_compute_pattern
(
ComputePattern
.
TP1D
)
# Input:S[1] x Weight:S[0] = Output:P
# All-Reduce(Output) + bias = res
# Input:S[1]
...
...
@@ -33,11 +33,12 @@ def colo_linear_1Dcol(input_tensor: ColoTensor, weight: ColoTensor, bias: ColoTe
# Input:B x Weight:S[1] + Bias:S[1] = Output:S[1]
# All-Gather(Output)
# Input:B
parallel_action
=
weight
.
spec
.
get_action_by_compute_pattern
(
ComputePattern
.
TP1D
Col
)
parallel_action
=
weight
.
spec
.
get_action_by_compute_pattern
(
ComputePattern
.
TP1D
)
input_tensor
.
to_dist_spec
(
dist_spec
.
replicate
(
weight
.
spec
.
get_process_group
()))
input_parallel
=
reduce_grad
(
input_tensor
.
torch_tensor
(),
parallel_action
.
parallel_mode
)
output_parallel
=
torch
.
nn
.
functional
.
linear
(
input_parallel
,
weight
.
torch_tensor
(),
bias
.
torch_tensor
())
if
bias
is
not
None
:
bias
=
bias
.
torch_tensor
()
output_parallel
=
torch
.
nn
.
functional
.
linear
(
input_parallel
,
weight
.
torch_tensor
(),
bias
)
output
=
ColoTensor
.
init_from_torch_tensor
(
output_parallel
,
...
...
@@ -83,16 +84,17 @@ def colo_linear(types, args, kwargs, pg):
# Add communication logic before and after linear call.
ret_tensor
=
None
if
not
weight
.
has_spec
():
# No Model Parallel Applied
assert
not
bias
.
has_spec
(),
'Invalid bias spec for native Linear op'
assert
bias
.
spec
.
is_gathered
(),
'Invalid bias spec for native Linear op'
assert
bias
.
spec
.
is_gathered
(),
'Invalid bias spec for native Linear op'
input_tensor
=
input_tensor
.
torch_tensor
()
weight
=
weight
.
torch_tensor
()
if
bias
is
not
None
:
bias
=
bias
.
torch_tensor
()
ret_tensor
=
ColoTensor
.
init_from_torch_tensor
(
torch
.
nn
.
functional
.
linear
(
input_tensor
,
weight
,
bias
))
elif
weight
.
spec
.
num_action
==
1
:
# Single Model Parallel Applied
compute_patterns
=
weight
.
spec
.
compute_patterns
if
ComputePattern
.
TP1DRow
in
compute_patterns
:
elif
weight
.
spec
.
has_compute_pattern
(
ComputePattern
.
TP1D
):
# Single Model Parallel Applied
if
weight
.
spec
.
is_1D_col
()
and
(
bias
is
None
or
bias
.
spec
.
is_gathered
()):
ret_tensor
=
colo_linear_1Drow
(
input_tensor
,
weight
,
bias
)
elif
ComputePattern
.
TP1DCol
in
compute_patterns
:
elif
weight
.
spec
.
is_1D_row
()
and
(
bias
is
None
or
bias
.
spec
.
is_1D_row
()
or
bias
.
spec
.
is_1D_col
())
:
ret_tensor
=
colo_linear_1Dcol
(
input_tensor
,
weight
,
bias
)
else
:
raise
NotImplementedError
...
...
colossalai/tensor/_ops/loss.py
View file @
c2fdc6a0
...
...
@@ -4,6 +4,7 @@ from colossalai.tensor.op_wrapper import colo_op_impl
from
colossalai.tensor
import
ColoTensor
from
colossalai.nn.loss.loss_1d
import
VocabParallelCrossEntropyLoss1D
@
colo_op_impl
(
torch
.
nn
.
functional
.
cross_entropy
)
def
colo_cross_entropy
(
types
,
args
=
(),
kwargs
=
None
,
pg
=
None
):
arg_num
=
len
(
args
)
...
...
@@ -28,12 +29,12 @@ def colo_cross_entropy(types, args=(), kwargs=None, pg=None):
target
=
target
.
torch_tensor
()
if
input_tensor
.
spec
.
is_gathered
():
# Input is gathered
return
ColoTensor
.
init_from_torch_tensor
(
torch
.
nn
.
functional
.
cross_entropy
(
input_tensor
.
torch_tensor
(),
target
,
weight
))
elif
input_tensor
.
has_spec
()
and
input_tensor
.
spec
.
num_action
==
1
:
# Single Model Parallel Applied
if
input_tensor
.
spec
.
is_1Dcol
():
return
ColoTensor
.
init_from_torch_tensor
(
VocabParallelCrossEntropyLoss1D
()(
input_tensor
.
torch_tensor
(),
target
))
torch
.
nn
.
functional
.
cross_entropy
(
input_tensor
.
torch_tensor
(),
target
,
weight
))
elif
input_tensor
.
has_spec
()
and
input_tensor
.
spec
.
num_action
==
1
:
# Single Model Parallel Applied
if
input_tensor
.
spec
.
is_1D_col
():
return
ColoTensor
.
init_from_torch_tensor
(
VocabParallelCrossEntropyLoss1D
()(
input_tensor
.
torch_tensor
(),
target
))
else
:
raise
NotImplementedError
else
:
...
...
colossalai/tensor/dist_spec.py
View file @
c2fdc6a0
from
enum
import
Enum
from
torch.distributed
import
ProcessGroup
from
typing
import
Optional
,
List
from
numpy
import
prod
__all__
=
[
'replicate'
,
'shard'
]
...
...
@@ -39,4 +40,5 @@ def shard(process_group: ProcessGroup, dims: List[int], num_partitions: List[int
assert
process_group
is
not
None
assert
isinstance
(
dims
,
list
)
and
isinstance
(
num_partitions
,
list
)
assert
len
(
dims
)
==
len
(
num_partitions
)
assert
prod
(
num_partitions
)
==
process_group
.
size
()
return
_DistSpec
(
DistPlacementPattern
.
SHARD
,
process_group
,
dims
=
tuple
(
dims
),
num_partitions
=
tuple
(
num_partitions
))
colossalai/tensor/spec.py
View file @
c2fdc6a0
...
...
@@ -5,17 +5,9 @@ from colossalai.tensor.dist_spec import _DistSpec, DistPlacementPattern
class
ComputePattern
(
Enum
):
# TODO (ver217): remove TP1DRow_<ops>
TP1DRow
=
0
TP1DCol
=
9
TP1DRow_Linear
=
1
TP1DCol_Linear
=
2
TP1DRow_Embedding
=
3
TP1DCol_Embedding
=
4
TP1DRow_mm
=
5
TP1DCol_mm
=
6
ZeRO
=
7
DP
=
8
TP1D
=
0
ZeRO
=
1
DP
=
2
class
ParallelAction
(
object
):
...
...
@@ -45,14 +37,14 @@ class TensorSpec(object):
# using ZeRO with DP-degree = 4 and 1DRowTP with TP-degree = 2.
# parallel_action_list = [
# ParallelAction(10, ComputePattern.ZeRO, gpc.get_group(ParallelMode.DATA)),
# ParallelAction(1, ComputePattern.TP1D
Row
_Linear, gpc.get_group(ParallelMode.PARALLEL_1D))
# ParallelAction(1, ComputePattern.TP1D_Linear, gpc.get_group(ParallelMode.PARALLEL_1D))
# ]
# When the ColoTensor is initialized,
# we first split the tensor according to the ParallelAction of ZeRO,
# then splitting tensor according to ParallelAction of TP1D
Row
_Linear.
# then splitting tensor according to ParallelAction of TP1D_Linear.
# During Linear computation
# Before Linear Op, we gather the tensors according to ZeRO.
# We perform Linear Op according to compute pattern of TP1D
Row
_Linear.
# We perform Linear Op according to compute pattern of TP1D_Linear.
# After Linear Op, we split the tensors according to ZeRO.
def
__init__
(
self
,
dist_spec
:
_DistSpec
,
parallel_action_list
:
List
[
ParallelAction
]
=
[]):
...
...
@@ -94,6 +86,13 @@ class TensorSpec(object):
and
self
.
dist_spec
.
num_partitions
[
0
]
==
1
)
\
or
(
self
.
dist_spec
.
process_group
.
size
()
==
1
)
def
is_1Dcol
(
self
):
def
is_1D
_
col
(
self
):
return
self
.
dist_spec
.
placement
==
DistPlacementPattern
.
SHARD
\
and
len
(
self
.
dist_spec
.
dims
)
==
1
and
self
.
dist_spec
.
dims
[
0
]
==
-
1
def
is_1D_row
(
self
):
return
self
.
dist_spec
.
placement
==
DistPlacementPattern
.
SHARD
\
and
len
(
self
.
dist_spec
.
dims
)
==
1
and
self
.
dist_spec
.
dims
[
0
]
==
0
def
has_compute_pattern
(
self
,
compute_pattern
:
ComputePattern
):
return
self
.
get_action_by_compute_pattern
(
compute_pattern
)
is
not
None
tests/test_tensor/test_addmm_tp.py
View file @
c2fdc6a0
...
...
@@ -40,7 +40,7 @@ class Conv1D(nn.Module):
def
init_1d_row
(
weight
,
bias
):
spec
=
TensorSpec
(
dist_spec
.
shard
(
gpc
.
get_group
(
ParallelMode
.
PARALLEL_1D
),
[
0
],
[
gpc
.
get_world_size
(
ParallelMode
.
PARALLEL_1D
)]),
[
ParallelAction
(
priority
=
1
,
compute_pattern
=
ComputePattern
.
TP1D
Row
,
parallel_mode
=
ParallelMode
.
PARALLEL_1D
)])
[
ParallelAction
(
priority
=
1
,
compute_pattern
=
ComputePattern
.
TP1D
,
parallel_mode
=
ParallelMode
.
PARALLEL_1D
)])
with
DistSpecManager
.
no_grad
():
weight
.
set_spec
(
spec
)
...
...
@@ -55,7 +55,7 @@ def check_grad_1d_row(model: torch.nn.Module, weight, bias):
def
init_1d_col
(
weight
,
bias
):
spec
=
TensorSpec
(
dist_spec
.
shard
(
gpc
.
get_group
(
ParallelMode
.
PARALLEL_1D
),
[
-
1
],
[
gpc
.
get_world_size
(
ParallelMode
.
PARALLEL_1D
)]),
[
ParallelAction
(
priority
=
1
,
compute_pattern
=
ComputePattern
.
TP1D
Col
,
parallel_mode
=
ParallelMode
.
PARALLEL_1D
)])
[
ParallelAction
(
priority
=
1
,
compute_pattern
=
ComputePattern
.
TP1D
,
parallel_mode
=
ParallelMode
.
PARALLEL_1D
)])
with
DistSpecManager
.
no_grad
():
weight
.
set_spec
(
spec
)
bias
.
set_spec
(
spec
)
...
...
tests/test_tensor/test_embedding_tp.py
View file @
c2fdc6a0
...
...
@@ -17,7 +17,7 @@ from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction, dist_s
def
init_1d_row
(
weight
):
spec
=
TensorSpec
(
dist_spec
.
shard
(
gpc
.
get_group
(
ParallelMode
.
PARALLEL_1D
),
[
0
],
[
gpc
.
get_world_size
(
ParallelMode
.
PARALLEL_1D
)]),
[
ParallelAction
(
priority
=
1
,
compute_pattern
=
ComputePattern
.
TP1D
Row
,
parallel_mode
=
ParallelMode
.
PARALLEL_1D
)])
[
ParallelAction
(
priority
=
1
,
compute_pattern
=
ComputePattern
.
TP1D
,
parallel_mode
=
ParallelMode
.
PARALLEL_1D
)])
with
DistSpecManager
.
no_grad
():
weight
.
set_spec
(
spec
)
...
...
@@ -31,7 +31,7 @@ def check_grad_1d_row(model: torch.nn.Module, weight):
def
init_1d_col
(
weight
):
spec
=
TensorSpec
(
dist_spec
.
shard
(
gpc
.
get_group
(
ParallelMode
.
PARALLEL_1D
),
[
-
1
],
[
gpc
.
get_world_size
(
ParallelMode
.
PARALLEL_1D
)]),
[
ParallelAction
(
priority
=
1
,
compute_pattern
=
ComputePattern
.
TP1D
Col
,
parallel_mode
=
ParallelMode
.
PARALLEL_1D
)])
[
ParallelAction
(
priority
=
1
,
compute_pattern
=
ComputePattern
.
TP1D
,
parallel_mode
=
ParallelMode
.
PARALLEL_1D
)])
with
DistSpecManager
.
no_grad
():
weight
.
set_spec
(
spec
)
...
...
tests/test_tensor/test_linear_tp.py
View file @
c2fdc6a0
...
...
@@ -18,7 +18,7 @@ from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction, dist_s
def
init_1d_row
(
weight
,
bias
):
spec
=
TensorSpec
(
dist_spec
.
shard
(
gpc
.
get_group
(
ParallelMode
.
PARALLEL_1D
),
[
-
1
],
[
gpc
.
get_world_size
(
ParallelMode
.
PARALLEL_1D
)]),
[
ParallelAction
(
priority
=
1
,
compute_pattern
=
ComputePattern
.
TP1D
Row
,
parallel_mode
=
ParallelMode
.
PARALLEL_1D
)])
[
ParallelAction
(
priority
=
1
,
compute_pattern
=
ComputePattern
.
TP1D
,
parallel_mode
=
ParallelMode
.
PARALLEL_1D
)])
with
DistSpecManager
.
no_grad
():
weight
.
set_spec
(
spec
)
...
...
@@ -33,7 +33,7 @@ def check_grad_1d_row(model: torch.nn.Module, weight, bias):
def
init_1d_col
(
weight
,
bias
):
spec
=
TensorSpec
(
dist_spec
.
shard
(
gpc
.
get_group
(
ParallelMode
.
PARALLEL_1D
),
[
0
],
[
gpc
.
get_world_size
(
ParallelMode
.
PARALLEL_1D
)]),
[
ParallelAction
(
priority
=
1
,
compute_pattern
=
ComputePattern
.
TP1D
Col
,
parallel_mode
=
ParallelMode
.
PARALLEL_1D
)])
[
ParallelAction
(
priority
=
1
,
compute_pattern
=
ComputePattern
.
TP1D
,
parallel_mode
=
ParallelMode
.
PARALLEL_1D
)])
with
DistSpecManager
.
no_grad
():
weight
.
set_spec
(
spec
)
bias
.
set_spec
(
spec
)
...
...
tests/test_tensor/test_model.py
View file @
c2fdc6a0
...
...
@@ -86,35 +86,43 @@ def set_seed(seed):
torch
.
cuda
.
manual_seed
(
seed
)
torch
.
backends
.
cudnn
.
deterministic
=
True
def
init_1d_row_linear
(
weight
):
spec
=
TensorSpec
(
dist_spec
.
shard
(
gpc
.
get_group
(
ParallelMode
.
PARALLEL_1D
),
[
-
1
],
[
gpc
.
get_world_size
(
ParallelMode
.
PARALLEL_1D
)]),
[
ParallelAction
(
priority
=
1
,
compute_pattern
=
ComputePattern
.
TP1D
Row
,
parallel_mode
=
ParallelMode
.
PARALLEL_1D
)])
[
ParallelAction
(
priority
=
1
,
compute_pattern
=
ComputePattern
.
TP1D
,
parallel_mode
=
ParallelMode
.
PARALLEL_1D
)])
with
DistSpecManager
.
no_grad
():
weight
.
set_spec
(
spec
)
def
init_1d_col_linear
(
weight
,
gather_out
=
True
):
spec
=
TensorSpec
(
dist_spec
.
shard
(
gpc
.
get_group
(
ParallelMode
.
PARALLEL_1D
),
[
0
],
[
gpc
.
get_world_size
(
ParallelMode
.
PARALLEL_1D
)]),
[
ParallelAction
(
priority
=
1
,
compute_pattern
=
ComputePattern
.
TP1DCol
,
parallel_mode
=
ParallelMode
.
PARALLEL_1D
,
\
gather_out
=
gather_out
)])
dist_spec
.
shard
(
gpc
.
get_group
(
ParallelMode
.
PARALLEL_1D
),
[
0
],
[
gpc
.
get_world_size
(
ParallelMode
.
PARALLEL_1D
)]),
[
ParallelAction
(
priority
=
1
,
compute_pattern
=
ComputePattern
.
TP1D
,
parallel_mode
=
ParallelMode
.
PARALLEL_1D
,
gather_out
=
gather_out
)
])
with
DistSpecManager
.
no_grad
():
weight
.
set_spec
(
spec
)
def
init_1d_row_embedding
(
weight
):
spec
=
TensorSpec
(
dist_spec
.
shard
(
gpc
.
get_group
(
ParallelMode
.
PARALLEL_1D
),
[
0
],
[
gpc
.
get_world_size
(
ParallelMode
.
PARALLEL_1D
)]),
[
ParallelAction
(
priority
=
1
,
compute_pattern
=
ComputePattern
.
TP1D
Row
,
parallel_mode
=
ParallelMode
.
PARALLEL_1D
)])
[
ParallelAction
(
priority
=
1
,
compute_pattern
=
ComputePattern
.
TP1D
,
parallel_mode
=
ParallelMode
.
PARALLEL_1D
)])
with
DistSpecManager
.
no_grad
():
weight
.
set_spec
(
spec
)
def
init_1d_col_embedding
(
weight
):
spec
=
TensorSpec
(
dist_spec
.
shard
(
gpc
.
get_group
(
ParallelMode
.
PARALLEL_1D
),
[
-
1
],
[
gpc
.
get_world_size
(
ParallelMode
.
PARALLEL_1D
)]),
[
ParallelAction
(
priority
=
1
,
compute_pattern
=
ComputePattern
.
TP1D
Col
,
parallel_mode
=
ParallelMode
.
PARALLEL_1D
)])
[
ParallelAction
(
priority
=
1
,
compute_pattern
=
ComputePattern
.
TP1D
,
parallel_mode
=
ParallelMode
.
PARALLEL_1D
)])
with
DistSpecManager
.
no_grad
():
weight
.
set_spec
(
spec
)
def
run_1d_hybrid_tp
(
model_name
):
# A simple net with two stacked nn.Linear
get_components_func
=
non_distributed_component_funcs
.
get_callable
(
model_name
)
...
...
@@ -376,7 +384,7 @@ def _run_pretrain_load():
if
isinstance
(
param
,
ColoParameter
):
c1
+=
1
else
:
c2
+=
1
c2
+=
1
dict_col
[
name
]
=
param
assert
c_ref
==
c1
assert
c2
==
0
...
...
@@ -395,6 +403,7 @@ def run_model_dist(rank, world_size, port):
for
name
in
[
'bert'
,
'simple_net'
]:
run_1d_hybrid_tp
(
name
)
@
pytest
.
mark
.
dist
@
pytest
.
mark
.
parametrize
(
'world_size'
,
[
1
,
4
])
# @parameterize('world_size', [1, 4])
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment