Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
a3b66f6d
Unverified
Commit
a3b66f6d
authored
May 20, 2022
by
ver217
Committed by
GitHub
May 20, 2022
Browse files
[tensor] refactor parallel action (#1007)
* refactor parallel action * polish unit tests
parent
9e3d602d
Changes
12
Hide whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
45 additions
and
77 deletions
+45
-77
colossalai/tensor/_ops/addmm.py
colossalai/tensor/_ops/addmm.py
+5
-5
colossalai/tensor/_ops/embedding.py
colossalai/tensor/_ops/embedding.py
+4
-6
colossalai/tensor/_ops/linear.py
colossalai/tensor/_ops/linear.py
+9
-10
colossalai/tensor/_ops/loss.py
colossalai/tensor/_ops/loss.py
+1
-1
colossalai/tensor/colo_parameter.py
colossalai/tensor/colo_parameter.py
+3
-0
colossalai/tensor/colo_tensor.py
colossalai/tensor/colo_tensor.py
+1
-1
colossalai/tensor/spec.py
colossalai/tensor/spec.py
+9
-37
tests/test_tensor/test_addmm_tp.py
tests/test_tensor/test_addmm_tp.py
+2
-2
tests/test_tensor/test_embedding_tp.py
tests/test_tensor/test_embedding_tp.py
+2
-2
tests/test_tensor/test_gpt.py
tests/test_tensor/test_gpt.py
+2
-2
tests/test_tensor/test_linear_tp.py
tests/test_tensor/test_linear_tp.py
+2
-2
tests/test_tensor/test_model.py
tests/test_tensor/test_model.py
+5
-9
No files found.
colossalai/tensor/_ops/addmm.py
View file @
a3b66f6d
...
...
@@ -3,12 +3,12 @@ from colossalai.tensor.op_wrapper import colo_op_impl
from
colossalai.nn.layer.parallel_1d._utils
import
reduce_input
,
reduce_grad
from
colossalai.tensor
import
ComputePattern
,
TensorSpec
,
ComputePattern
,
ParallelAction
,
ColoTensor
from
colossalai.tensor
import
distspec
from
colossalai.context
import
ParallelMode
from
._utils
import
GeneralTensor
,
Number
,
convert_to_colo_tensor
def
colo_addmm_1Drow
(
input_tensor
:
ColoTensor
,
mat1
:
ColoTensor
,
mat2
:
ColoTensor
,
beta
:
Number
,
alpha
:
Number
)
->
ColoTensor
:
parallel_action
=
mat2
.
spec
.
get_action_by_compute_pattern
(
ComputePattern
.
TP1D
)
# mat1:S[1] x mat2:S[0] = Output:P
# beta * input + alpha * All-Reduce(Output) = res
...
...
@@ -18,7 +18,7 @@ def colo_addmm_1Drow(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTenso
# Output:P
partial_output
=
torch
.
mm
(
mat1
,
mat2
)
# Reduce(Output)
output
=
reduce_input
(
partial_output
,
p
arallel
_action
.
parallel_mode
)
output
=
reduce_input
(
partial_output
,
P
arallel
Mode
.
PARALLEL_1D
)
# input
assert
not
input_tensor
.
has_spec
(),
'Invalid input spec for 1Drow addmm op'
output
=
beta
*
input_tensor
+
alpha
*
output
...
...
@@ -29,13 +29,13 @@ def colo_addmm_1Drow(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTenso
def
colo_addmm_1Dcol
(
input_tensor
:
ColoTensor
,
mat1
:
ColoTensor
,
mat2
:
ColoTensor
,
beta
:
Number
,
alpha
:
Number
)
->
ColoTensor
:
# mat1:B x mat2:S[1] + input:S[1] = Output:S[1]
parallel_action
=
mat2
.
spec
.
get_action_by_compute_pattern
(
ComputePattern
.
TP1D
)
parallel_action
=
mat2
.
spec
.
parallel_action
mat1
=
mat1
.
convert_to_dist_spec
(
distspec
.
replicate
(
mat2
.
spec
.
get_process_group
()))
mat1
=
reduce_grad
(
mat1
,
p
arallel
_action
.
parallel_mode
)
mat1
=
reduce_grad
(
mat1
,
P
arallel
Mode
.
PARALLEL_1D
)
output_parallel
=
torch
.
addmm
(
input_tensor
,
mat1
,
mat2
,
beta
=
beta
,
alpha
=
alpha
)
output_spec
=
TensorSpec
(
distspec
.
shard
(
mat2
.
spec
.
get_process_group
(),
[
-
1
],
[
mat2
.
spec
.
get_process_group_size
()]),
[
ParallelAction
(
priority
=
1
,
parallel_mode
=
parallel_action
.
parallel_mode
)]
)
ParallelAction
(
ComputePattern
.
TP1D
)
)
output
=
ColoTensor
.
from_torch_tensor
(
output_parallel
,
spec
=
output_spec
)
if
parallel_action
.
gather_out
:
# All-Gather(Output)
...
...
colossalai/tensor/_ops/embedding.py
View file @
a3b66f6d
import
torch
import
torch.nn.functional
as
F
from
typing
import
Optional
from
colossalai.tensor.op_wrapper
import
colo_op_impl
from
colossalai.nn.layer.parallel_1d._utils
import
reduce_input
from
colossalai.core
import
global_context
as
gpc
from
colossalai.tensor
import
ComputePattern
,
TensorSpec
,
ComputePattern
,
ParallelAction
,
ColoTensor
,
distspec
from
colossalai.context
import
ParallelMode
from
._utils
import
GeneralTensor
,
convert_to_colo_tensor
...
...
@@ -17,7 +17,6 @@ def colo_embedding_1Dcol(input_tensor: ColoTensor,
sparse
:
bool
=
False
)
->
ColoTensor
:
# embedding_1Dcol split the weight(lookup table) to (num_embeddings, embedding_dim/P)
# Gather splitted lookup table
parallel_action
=
weight
.
spec
.
get_action_by_compute_pattern
(
ComputePattern
.
TP1D
)
input_tensor
=
input_tensor
.
convert_to_dist_spec
(
distspec
.
replicate
(
weight
.
spec
.
get_process_group
()))
output_parallel
=
F
.
embedding
(
input_tensor
,
...
...
@@ -29,7 +28,7 @@ def colo_embedding_1Dcol(input_tensor: ColoTensor,
sparse
=
sparse
)
output_spec
=
TensorSpec
(
distspec
.
shard
(
weight
.
spec
.
get_process_group
(),
[
-
1
],
[
weight
.
spec
.
get_process_group_size
()]),
[
ParallelAction
(
priority
=
1
,
parallel_mode
=
parallel_action
.
parallel_mode
)]
)
ParallelAction
(
ComputePattern
.
TP1D
)
)
output
=
ColoTensor
.
from_torch_tensor
(
output_parallel
,
spec
=
output_spec
)
output
=
output
.
convert_to_dist_spec
(
distspec
.
replicate
(
weight
.
spec
.
get_process_group
()))
return
output
...
...
@@ -45,10 +44,9 @@ def colo_embedding_1Drow(input_tensor: ColoTensor,
# embedding_1Drow split the weight(lookup table) to (num_embeddings/P, embedding_dim)
# Find index in this shard and mask those not here
# Reduce all
parallel_action
=
weight
.
spec
.
get_action_by_compute_pattern
(
ComputePattern
.
TP1D
)
input_tensor
=
input_tensor
.
convert_to_dist_spec
(
distspec
.
replicate
(
weight
.
spec
.
get_process_group
()))
tensor_parallel_rank
=
gpc
.
get_local_rank
(
p
arallel
_action
.
parallel_mode
)
tensor_parallel_rank
=
gpc
.
get_local_rank
(
P
arallel
Mode
.
PARALLEL_1D
)
num_embeddings_per_partition
=
weight
.
size
(
0
)
vocab_start_index
=
tensor_parallel_rank
*
num_embeddings_per_partition
vocab_end_index
=
vocab_start_index
+
num_embeddings_per_partition
...
...
@@ -72,7 +70,7 @@ def colo_embedding_1Drow(input_tensor: ColoTensor,
# Mask the output embedding.
partial_output
[
input_mask
,
:]
=
0.
# Reduce across all the model parallel GPUs.
output
=
reduce_input
(
partial_output
,
p
arallel
_action
.
parallel_mode
)
output
=
reduce_input
(
partial_output
,
P
arallel
Mode
.
PARALLEL_1D
)
output
=
ColoTensor
.
from_torch_tensor
(
output
,
spec
=
TensorSpec
(
distspec
.
replicate
(
weight
.
spec
.
get_process_group
())))
return
output
...
...
colossalai/tensor/_ops/linear.py
View file @
a3b66f6d
import
torch
import
torch.nn.functional
as
F
import
torch.distributed
as
dist
from
typing
import
Optional
from
colossalai.tensor.op_wrapper
import
colo_op_impl
from
colossalai.nn.layer.parallel_1d._utils
import
reduce_input
,
reduce_grad
from
colossalai.tensor
import
ComputePattern
,
TensorSpec
,
ComputePattern
,
ParallelAction
,
ColoTensor
,
distspec
from
colossalai.tensor.graph
import
GraphOpNode
,
GraphGlobalEnv
from
colossalai.context
import
ParallelMode
from
._utils
import
GeneralTensor
,
convert_to_colo_tensor
def
colo_linear_1Drow
(
input_tensor
:
ColoTensor
,
weight
:
ColoTensor
,
bias
:
Optional
[
ColoTensor
])
->
ColoTensor
:
parallel_action
=
weight
.
spec
.
get_action_by_compute_pattern
(
ComputePattern
.
TP1D
)
# Input:S[1] x Weight:S[0] = Output:P
# All-Reduce(Output) + bias = res
# Input:S[1]
...
...
@@ -20,7 +18,7 @@ def colo_linear_1Drow(input_tensor: ColoTensor, weight: ColoTensor, bias: Option
# Output:P
partial_output
=
F
.
linear
(
input_tensor
,
weight
)
# Reduce(Output)
output
=
reduce_input
(
partial_output
,
p
arallel
_action
.
parallel_mode
)
output
=
reduce_input
(
partial_output
,
P
arallel
Mode
.
PARALLEL_1D
)
# Bias
if
bias
is
not
None
:
assert
not
bias
.
has_spec
(),
'Invalid bias spec for 1Drow Linear op'
...
...
@@ -34,15 +32,16 @@ def colo_linear_1Dcol(input_tensor: ColoTensor, weight: ColoTensor, bias: Option
# Input:B x Weight:S[1] + Bias:S[1] = Output:S[1]
# All-Gather(Output)
# Input:B
parallel_action
=
weight
.
spec
.
get_action_by_compute_pattern
(
ComputePattern
.
TP1D
)
parallel_action
=
weight
.
spec
.
parallel_action
input_tensor
=
input_tensor
.
convert_to_dist_spec
(
distspec
.
replicate
(
weight
.
spec
.
get_process_group
()))
input_parallel
=
reduce_grad
(
input_tensor
,
p
arallel
_action
.
parallel_mode
)
input_parallel
=
reduce_grad
(
input_tensor
,
P
arallel
Mode
.
PARALLEL_1D
)
output_parallel
=
F
.
linear
(
input_parallel
,
weight
,
bias
)
output
=
ColoTensor
.
from_torch_tensor
(
output_parallel
,
spec
=
TensorSpec
(
distspec
.
shard
(
weight
.
spec
.
get_process_group
(),
[
-
1
],
[
weight
.
spec
.
get_process_group_size
()]),
[
ParallelAction
(
priority
=
1
,
parallel_mode
=
parallel_action
.
parallel_mode
)]))
output
=
ColoTensor
.
from_torch_tensor
(
output_parallel
,
spec
=
TensorSpec
(
distspec
.
shard
(
weight
.
spec
.
get_process_group
(),
[
-
1
],
[
weight
.
spec
.
get_process_group_size
()]),
ParallelAction
(
ComputePattern
.
TP1D
)))
if
parallel_action
.
gather_out
:
# All-Gather(Output)
output
=
output
.
convert_to_dist_spec
(
distspec
.
replicate
(
weight
.
spec
.
get_process_group
()))
...
...
colossalai/tensor/_ops/loss.py
View file @
a3b66f6d
...
...
@@ -28,7 +28,7 @@ def colo_cross_entropy(input_tensor: GeneralTensor,
reduction
=
reduction
,
label_smoothing
=
label_smoothing
)
return
ColoTensor
.
from_torch_tensor
(
output
)
elif
input_tensor
.
has_spec
()
and
input_tensor
.
spec
.
num_action
==
1
:
# Single Model Parallel Applied
elif
input_tensor
.
has_spec
():
# Single Model Parallel Applied
if
input_tensor
.
spec
.
is_1D_col
():
output
=
VocabParallelCrossEntropyLoss1D
()(
input_tensor
,
target
)
return
ColoTensor
.
from_torch_tensor
(
output
)
...
...
colossalai/tensor/colo_parameter.py
View file @
a3b66f6d
...
...
@@ -33,3 +33,6 @@ class ColoParameter(ColoTensor):
tensor
=
tensor
.
as_subclass
(
ColoParameter
)
tensor
.
__init__
(
tensor
,
requires_grad
=
requires_grad
,
spec
=
spec
)
return
tensor
def
__repr__
(
self
):
return
f
'ColoParameter:
{
torch
.
Tensor
.
__repr__
(
self
)
}
'
colossalai/tensor/colo_tensor.py
View file @
a3b66f6d
...
...
@@ -45,7 +45,7 @@ class ColoTensor(torch.Tensor):
self
.
_spec
=
spec
def
has_spec
(
self
)
->
bool
:
return
self
.
_spec
.
num_action
>
0
return
self
.
_spec
.
parallel_action
is
not
None
def
is_model_data
(
self
)
->
bool
:
return
self
.
_type
==
TensorType
.
MODEL
...
...
colossalai/tensor/spec.py
View file @
a3b66f6d
import
torch.distributed
as
dist
from
enum
import
Enum
from
typing
import
List
from
colossalai.context.parallel_mode
import
ParallelMode
from
typing
import
List
,
Optional
from
colossalai.tensor.distspec
import
_DistSpec
,
DistPlacementPattern
class
ComputePattern
(
Enum
):
TP1D
=
0
ZeRO
=
1
DP
=
2
TP2D
=
1
TP2P5D
=
2
TP3D
=
3
class
ParallelAction
(
object
):
def
__init__
(
self
,
priority
=
0
,
compute_pattern
=
ComputePattern
.
DP
,
parallel_mode
=
ParallelMode
.
DATA
,
gather_out
=
True
)
->
None
:
self
.
priority
=
priority
def
__init__
(
self
,
compute_pattern
:
ComputePattern
,
gather_out
:
bool
=
True
)
->
None
:
assert
isinstance
(
compute_pattern
,
ComputePattern
)
self
.
compute_pattern
=
compute_pattern
self
.
parallel_mode
=
parallel_mode
self
.
gather_out
=
gather_out
...
...
@@ -48,32 +43,9 @@ class TensorSpec(object):
# We perform Linear Op according to compute pattern of TP1D_Linear.
# After Linear Op, we split the tensors according to ZeRO.
def
__init__
(
self
,
dist_spec
:
_DistSpec
,
parallel_action
_list
:
List
[
ParallelAction
]
=
[]
):
self
.
_
parallel_action
_list
=
parallel_action
_list
def
__init__
(
self
,
dist_spec
:
_DistSpec
,
parallel_action
:
Optional
[
ParallelAction
]
=
None
):
self
.
parallel_action
=
parallel_action
self
.
dist_spec
=
dist_spec
self
.
sort
()
@
property
def
parallel_action_list
(
self
):
return
self
.
_parallel_action_list
@
property
def
num_action
(
self
):
return
len
(
self
.
_parallel_action_list
)
@
property
def
compute_patterns
(
self
):
return
[
parallel_action
.
compute_pattern
for
parallel_action
in
self
.
_parallel_action_list
]
def
sort
(
self
):
if
len
(
self
.
_parallel_action_list
)
>
0
:
self
.
_parallel_action_list
.
sort
(
key
=
lambda
parallel_action
:
parallel_action
.
priority
)
def
get_action_by_compute_pattern
(
self
,
compute_pattern
:
ComputePattern
):
for
parallel_action
in
self
.
_parallel_action_list
:
if
parallel_action
.
compute_pattern
==
compute_pattern
:
return
parallel_action
return
None
def
get_process_group
(
self
):
return
self
.
dist_spec
.
process_group
...
...
@@ -99,4 +71,4 @@ class TensorSpec(object):
and
len
(
self
.
dist_spec
.
dims
)
==
1
and
self
.
dist_spec
.
dims
[
0
]
==
0
def
has_compute_pattern
(
self
,
compute_pattern
:
ComputePattern
):
return
self
.
get
_action
_by_
compute_pattern
(
compute_pattern
)
is
not
None
return
self
.
parallel
_action
.
compute_pattern
==
compute_pattern
tests/test_tensor/test_addmm_tp.py
View file @
a3b66f6d
...
...
@@ -41,7 +41,7 @@ class Conv1D(nn.Module):
def
init_1d_row
(
weight
,
bias
):
spec
=
TensorSpec
(
distspec
.
shard
(
gpc
.
get_group
(
ParallelMode
.
PARALLEL_1D
),
[
0
],
[
gpc
.
get_world_size
(
ParallelMode
.
PARALLEL_1D
)]),
[
ParallelAction
(
priority
=
1
,
compute_pattern
=
ComputePattern
.
TP1D
,
parallel_mode
=
ParallelMode
.
PARALLEL_1D
)]
)
ParallelAction
(
ComputePattern
.
TP1D
)
)
with
DistSpecManager
.
no_grad
():
weight
.
set_spec
(
spec
)
...
...
@@ -49,7 +49,7 @@ def init_1d_row(weight, bias):
def
init_1d_col
(
weight
,
bias
):
spec
=
TensorSpec
(
distspec
.
shard
(
gpc
.
get_group
(
ParallelMode
.
PARALLEL_1D
),
[
-
1
],
[
gpc
.
get_world_size
(
ParallelMode
.
PARALLEL_1D
)]),
[
ParallelAction
(
priority
=
1
,
compute_pattern
=
ComputePattern
.
TP1D
,
parallel_mode
=
ParallelMode
.
PARALLEL_1D
)]
)
ParallelAction
(
ComputePattern
.
TP1D
)
)
with
DistSpecManager
.
no_grad
():
weight
.
set_spec
(
spec
)
bias
.
set_spec
(
spec
)
...
...
tests/test_tensor/test_embedding_tp.py
View file @
a3b66f6d
...
...
@@ -18,7 +18,7 @@ from _utils import tensor_equal, tensor_shard_equal
def
init_1d_row
(
weight
):
spec
=
TensorSpec
(
distspec
.
shard
(
gpc
.
get_group
(
ParallelMode
.
PARALLEL_1D
),
[
0
],
[
gpc
.
get_world_size
(
ParallelMode
.
PARALLEL_1D
)]),
[
ParallelAction
(
priority
=
1
,
compute_pattern
=
ComputePattern
.
TP1D
,
parallel_mode
=
ParallelMode
.
PARALLEL_1D
)]
)
ParallelAction
(
ComputePattern
.
TP1D
)
)
with
DistSpecManager
.
no_grad
():
weight
.
set_spec
(
spec
)
...
...
@@ -26,7 +26,7 @@ def init_1d_row(weight):
def
init_1d_col
(
weight
):
spec
=
TensorSpec
(
distspec
.
shard
(
gpc
.
get_group
(
ParallelMode
.
PARALLEL_1D
),
[
-
1
],
[
gpc
.
get_world_size
(
ParallelMode
.
PARALLEL_1D
)]),
[
ParallelAction
(
priority
=
1
,
compute_pattern
=
ComputePattern
.
TP1D
,
parallel_mode
=
ParallelMode
.
PARALLEL_1D
)]
)
ParallelAction
(
ComputePattern
.
TP1D
)
)
with
DistSpecManager
.
no_grad
():
weight
.
set_spec
(
spec
)
...
...
tests/test_tensor/test_gpt.py
View file @
a3b66f6d
...
...
@@ -16,7 +16,7 @@ from tests.components_to_test.registry import non_distributed_component_funcs
def
init_1d_row_spec
(
model
):
spec
=
TensorSpec
(
distspec
.
shard
(
gpc
.
get_group
(
ParallelMode
.
PARALLEL_1D
),
[
0
],
[
gpc
.
get_world_size
(
ParallelMode
.
PARALLEL_1D
)]),
[
ParallelAction
(
priority
=
1
,
compute_pattern
=
ComputePattern
.
TP1D
,
parallel_mode
=
ParallelMode
.
PARALLEL_1D
)]
)
ParallelAction
(
ComputePattern
.
TP1D
)
)
with
DistSpecManager
.
no_grad
():
for
n
,
p
in
model
.
named_parameters
():
if
'weight'
in
n
and
'ln'
not
in
n
:
...
...
@@ -26,7 +26,7 @@ def init_1d_row_spec(model):
def
init_1d_col_spec
(
model
):
spec
=
TensorSpec
(
distspec
.
shard
(
gpc
.
get_group
(
ParallelMode
.
PARALLEL_1D
),
[
-
1
],
[
gpc
.
get_world_size
(
ParallelMode
.
PARALLEL_1D
)]),
[
ParallelAction
(
priority
=
1
,
compute_pattern
=
ComputePattern
.
TP1D
,
parallel_mode
=
ParallelMode
.
PARALLEL_1D
)]
)
ParallelAction
(
ComputePattern
.
TP1D
)
)
with
DistSpecManager
.
no_grad
():
for
n
,
p
in
model
.
named_parameters
():
if
'ln'
not
in
n
and
(
'weight'
in
n
or
'bias'
in
n
):
...
...
tests/test_tensor/test_linear_tp.py
View file @
a3b66f6d
...
...
@@ -19,7 +19,7 @@ from _utils import tensor_equal, tensor_shard_equal
def
init_1d_row
(
weight
,
bias
):
spec
=
TensorSpec
(
distspec
.
shard
(
gpc
.
get_group
(
ParallelMode
.
PARALLEL_1D
),
[
-
1
],
[
gpc
.
get_world_size
(
ParallelMode
.
PARALLEL_1D
)]),
[
ParallelAction
(
priority
=
1
,
compute_pattern
=
ComputePattern
.
TP1D
,
parallel_mode
=
ParallelMode
.
PARALLEL_1D
)]
)
ParallelAction
(
ComputePattern
.
TP1D
)
)
with
DistSpecManager
.
no_grad
():
weight
.
set_spec
(
spec
)
...
...
@@ -27,7 +27,7 @@ def init_1d_row(weight, bias):
def
init_1d_col
(
weight
,
bias
):
spec
=
TensorSpec
(
distspec
.
shard
(
gpc
.
get_group
(
ParallelMode
.
PARALLEL_1D
),
[
0
],
[
gpc
.
get_world_size
(
ParallelMode
.
PARALLEL_1D
)]),
[
ParallelAction
(
priority
=
1
,
compute_pattern
=
ComputePattern
.
TP1D
,
parallel_mode
=
ParallelMode
.
PARALLEL_1D
)]
)
ParallelAction
(
ComputePattern
.
TP1D
)
)
with
DistSpecManager
.
no_grad
():
weight
.
set_spec
(
spec
)
bias
.
set_spec
(
spec
)
...
...
tests/test_tensor/test_model.py
View file @
a3b66f6d
...
...
@@ -20,19 +20,15 @@ from _utils import set_seed
def
init_1d_row_linear
(
weight
):
spec
=
TensorSpec
(
distspec
.
shard
(
gpc
.
get_group
(
ParallelMode
.
PARALLEL_1D
),
[
-
1
],
[
gpc
.
get_world_size
(
ParallelMode
.
PARALLEL_1D
)]),
[
ParallelAction
(
priority
=
1
,
compute_pattern
=
ComputePattern
.
TP1D
,
parallel_mode
=
ParallelMode
.
PARALLEL_1D
)]
)
ParallelAction
(
ComputePattern
.
TP1D
)
)
with
DistSpecManager
.
no_grad
():
weight
.
set_spec
(
spec
)
def
init_1d_col_linear
(
weight
,
gather_out
=
True
):
spec
=
TensorSpec
(
distspec
.
shard
(
gpc
.
get_group
(
ParallelMode
.
PARALLEL_1D
),
[
0
],
[
gpc
.
get_world_size
(
ParallelMode
.
PARALLEL_1D
)]),
[
ParallelAction
(
priority
=
1
,
compute_pattern
=
ComputePattern
.
TP1D
,
parallel_mode
=
ParallelMode
.
PARALLEL_1D
,
gather_out
=
gather_out
)
])
distspec
.
shard
(
gpc
.
get_group
(
ParallelMode
.
PARALLEL_1D
),
[
0
],
[
gpc
.
get_world_size
(
ParallelMode
.
PARALLEL_1D
)]),
ParallelAction
(
ComputePattern
.
TP1D
,
gather_out
=
gather_out
))
with
DistSpecManager
.
no_grad
():
weight
.
set_spec
(
spec
)
...
...
@@ -40,7 +36,7 @@ def init_1d_col_linear(weight, gather_out=True):
def
init_1d_row_embedding
(
weight
):
spec
=
TensorSpec
(
distspec
.
shard
(
gpc
.
get_group
(
ParallelMode
.
PARALLEL_1D
),
[
0
],
[
gpc
.
get_world_size
(
ParallelMode
.
PARALLEL_1D
)]),
[
ParallelAction
(
priority
=
1
,
compute_pattern
=
ComputePattern
.
TP1D
,
parallel_mode
=
ParallelMode
.
PARALLEL_1D
)]
)
ParallelAction
(
ComputePattern
.
TP1D
)
)
with
DistSpecManager
.
no_grad
():
weight
.
set_spec
(
spec
)
...
...
@@ -48,7 +44,7 @@ def init_1d_row_embedding(weight):
def
init_1d_col_embedding
(
weight
):
spec
=
TensorSpec
(
distspec
.
shard
(
gpc
.
get_group
(
ParallelMode
.
PARALLEL_1D
),
[
-
1
],
[
gpc
.
get_world_size
(
ParallelMode
.
PARALLEL_1D
)]),
[
ParallelAction
(
priority
=
1
,
compute_pattern
=
ComputePattern
.
TP1D
,
parallel_mode
=
ParallelMode
.
PARALLEL_1D
)]
)
ParallelAction
(
ComputePattern
.
TP1D
)
)
with
DistSpecManager
.
no_grad
():
weight
.
set_spec
(
spec
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment