OpenDAS / ColossalAI · Commits

Commit 2c0d19d7 (unverified)
[Tensor] add ColoTensor TP1Dcol Embedding (#899)
Authored Apr 28, 2022 by Ziyue Jiang; committed by GitHub on Apr 28, 2022
Parent: e46e423c
Changes: 9 changed files with 173 additions and 27 deletions (+173 -27)
colossalai/tensor/_ops/__init__.py        +1  -0
colossalai/tensor/_ops/embedding.py       +56 -0
colossalai/tensor/_ops/layernorm.py       +1  -1
colossalai/tensor/_ops/linear.py          +5  -5
colossalai/tensor/colo_tensor.py          +13 -8
colossalai/tensor/spec.py                 +9  -7
tests/test_tensor/test_embedding_tp.py    +82 -0
tests/test_tensor/test_linear_tp.py       +3  -3
tests/test_tensor/test_model.py           +3  -3
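Taken together, the change teaches the ColoTensor dispatch layer to intercept torch.nn.functional.embedding when the weight carries the new TP1DCol_Embedding pattern. A minimal usage sketch, modeled on the new test file below (init_from_torch_tensor, set_spec, and the keyword arguments come from this commit; an initialized 1D parallel context is assumed to be launched already):

import torch
from colossalai.context.parallel_mode import ParallelMode
from colossalai.tensor import ColoTensor, ComputePattern, ParallelAction, TensorSpec

weight = torch.randn(12, 32)
sharded_weight = ColoTensor.init_from_torch_tensor(weight)
spec = TensorSpec([
    ParallelAction(priority=1,
                   compute_pattern=ComputePattern.TP1DCol_Embedding,
                   parallel_mode=ParallelMode.PARALLEL_1D)
])
sharded_weight.set_spec(spec)    # reshard: the weight is split along the embedding dim

# With the weight wrapped this way, a plain functional call is routed
# through the colo_embedding handler added in this commit:
# out = torch.nn.functional.embedding(torch.tensor([0, 3, 6, 9]), sharded_weight)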
colossalai/tensor/_ops/__init__.py
@@ -2,3 +2,4 @@ from .linear import colo_linear
 from .element_wise import *
 from .layernorm import colo_layernorm
 from .loss import colo_cross_entropy
+from .embedding import colo_embedding
\ No newline at end of file
colossalai/tensor/_ops/embedding.py (new file, mode 100644)

import torch
from colossalai.tensor.op_wrapper import colo_op_impl
from colossalai.context import ParallelMode
from colossalai.nn.layer.parallel_1d._utils import split_forward_gather_backward, reduce_input, \
    gather_forward_split_backward, reduce_grad
from colossalai.nn.layer.utils import divide
from colossalai.core import global_context as gpc
from packaging import version
from colossalai.tensor import ComputePattern, TensorSpec, ComputePattern, ParallelAction, ColoTensor, ShardPattern


def colo_embedding_1Dcol(input_tensor: ColoTensor, weight: ColoTensor, args, kwargs) -> ColoTensor:
    # embedding_1Dcol splits the weight (lookup table) along the embedding dim,
    # then gathers the split lookup results.
    parallel_action = weight.shard_spec.get_action_by_compute_pattern(ComputePattern.TP1DCol_Embedding)
    if not input_tensor.is_gathered():
        input_tensor.gather()

    output_parallel = torch.nn.functional.embedding(input_tensor.torch_tensor(), weight.torch_tensor(),
                                                    *args, **kwargs)

    output = ColoTensor.init_from_torch_tensor(output_parallel)
    out_parallel_action_list = [ParallelAction(priority=1, parallel_mode=parallel_action.parallel_mode)]
    output_spec = TensorSpec(out_parallel_action_list)
    output.set_spec(output_spec, shard=False)
    output.set_shard_pattern(ShardPattern.Col)
    output.gather()
    return output


@colo_op_impl(torch.nn.functional.embedding)
def colo_embedding(types, args, kwargs, pg):
    """Handles ``__torch_function__`` dispatch for ``torch.nn.functional.embedding``.
    This method looks up an embedding table.
    """
    input_tensor = args[0]
    weight = args[1]
    args = args[2:]

    if not isinstance(input_tensor, ColoTensor):
        input_tensor = ColoTensor.init_from_torch_tensor(input_tensor)

    if not isinstance(weight, ColoTensor):
        weight = ColoTensor.init_from_torch_tensor(weight)

    # Handle different parallel actions.
    if not weight.has_spec():    # No Model Parallel Applied
        input_tensor = input_tensor.torch_tensor()
        weight = weight.torch_tensor()
        output = torch.nn.functional.embedding(input_tensor, weight, *args, **kwargs)
        return ColoTensor.init_from_torch_tensor(output)
    elif weight.shard_spec.num_action == 1:    # Single Model Parallel Applied
        compute_patterns = weight.shard_spec.compute_patterns
        if ComputePattern.TP1DCol_Embedding in compute_patterns:
            return colo_embedding_1Dcol(input_tensor, weight, args, kwargs)
        else:
            raise NotImplementedError
    else:
        raise NotImplementedError
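The comments in colo_embedding_1Dcol describe the column-parallel scheme: the lookup table is split along the embedding dimension, every rank looks up the full (gathered) indices in its slice, and the per-rank outputs are gathered back. The following single-process sketch is illustrative only; torch.chunk and torch.cat stand in for the distributed shard and gather steps:

import torch

num_embeddings, embedding_dim, world_size = 12, 32, 4
weight = torch.randn(num_embeddings, embedding_dim)
ids = torch.tensor([0, 3, 6, 9])

shards = torch.chunk(weight, world_size, dim=-1)               # TP1DCol split of the lookup table
partials = [torch.nn.functional.embedding(ids, w) for w in shards]
gathered = torch.cat(partials, dim=-1)                         # stands in for the distributed gather

assert torch.equal(gathered, torch.nn.functional.embedding(ids, weight))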
colossalai/tensor/_ops/layernorm.py
@@ -27,7 +27,7 @@ def colo_layernorm(types, args=(), kwargs=None, pg=None):
     eps = kwargs['eps']
     if isinstance(input_tensor, ColoTensor):
-        if input_tensor.is_activation() and not input_tensor.is_gathered():
+        if not input_tensor.is_gathered():
             input_tensor.gather()
         input_tensor = input_tensor.torch_tensor()
     if isinstance(weight, ColoTensor):
colossalai/tensor/_ops/linear.py
@@ -9,8 +9,8 @@ from packaging import version
 from colossalai.tensor import ComputePattern, TensorSpec, ComputePattern, ParallelAction, ColoTensor, ShardPattern
 
 
 def colo_linear_1Drow(input_tensor: ColoTensor, weight: ColoTensor, bias: ColoTensor) -> ColoTensor:
-    parallel_action = weight.shard_spec.get_action_by_compute_pattern(ComputePattern.TP1DRow)
+    parallel_action = weight.shard_spec.get_action_by_compute_pattern(ComputePattern.TP1DRow_Linear)
     # Input:S[1] x Weight:S[0] = Output:P
     # All-Reduce(Output) + bias = res
     # Input:S[1]
@@ -47,7 +47,7 @@ def colo_linear_1Dcol(input_tensor: ColoTensor, weight: ColoTensor, bias: ColoTe
     # Input:B x Weight:S[1] + Bias:S[1] = Output:S[1]
     # All-Gather(Output)
     # Input:B
-    parallel_action = weight.shard_spec.get_action_by_compute_pattern(ComputePattern.TP1DCol)
+    parallel_action = weight.shard_spec.get_action_by_compute_pattern(ComputePattern.TP1DCol_Linear)
     if input_tensor.is_gathered():
         # Not split yet.
         assert input_tensor.shape[-1] == weight.size(-1), \
@@ -108,9 +108,9 @@ def colo_linear(types, args, kwargs, pg):
         return ColoTensor.init_from_torch_tensor(torch.nn.functional.linear(input_tensor, weight, bias))
     elif weight.shard_spec.num_action == 1:    # Single Model Parallel Applied
         compute_patterns = weight.shard_spec.compute_patterns
-        if ComputePattern.TP1DRow in compute_patterns:
+        if ComputePattern.TP1DRow_Linear in compute_patterns:
             return colo_linear_1Drow(input_tensor, weight, bias)
-        elif ComputePattern.TP1DCol in compute_patterns:
+        elif ComputePattern.TP1DCol_Linear in compute_patterns:
             return colo_linear_1Dcol(input_tensor, weight, bias)
         else:
             raise NotImplementedError
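The "Input:S[1] x Weight:S[0] = Output:P" comment above says that for row-parallel linear, sharding the input's last dim and the weight's input dim yields partial outputs that sum (the all-reduce) to the full result. A single-process sketch of that identity, illustrative only; note the S[0] notation refers to the mathematical W, which in torch's (out_features, in_features) storage is dim -1:

import torch

x = torch.randn(4, 8)      # (batch, in_features)
w = torch.randn(16, 8)     # torch layout: (out_features, in_features)
world_size = 2

x_shards = torch.chunk(x, world_size, dim=-1)    # Input:S[1]
w_shards = torch.chunk(w, world_size, dim=-1)    # shard in_features, i.e. S[0] of the mathematical W
partials = [torch.nn.functional.linear(xs, ws) for xs, ws in zip(x_shards, w_shards)]
out = sum(partials)                              # stands in for the all-reduce

assert torch.allclose(out, torch.nn.functional.linear(x, w), atol=1e-5)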
colossalai/tensor/colo_tensor.py
@@ -142,14 +142,19 @@ class ColoTensor(object):
         if self._shard_pattern is not ShardPattern.NA:    # reshard
             self.gather()
         # Model Parameters
-        if ComputePattern.TP1DRow in self._shard_spec.compute_patterns:
-            parallel_action = self._shard_spec.get_action_by_compute_pattern(ComputePattern.TP1DRow)
-            self._shard_1d(parallel_action=parallel_action, dim=-1)
-            self._shard_pattern = ShardPattern.Col    # We bind our ComputePattern on weight, which has to be transposed when linear().
-        elif ComputePattern.TP1DCol in self._shard_spec.compute_patterns:
-            parallel_action = self._shard_spec.get_action_by_compute_pattern(ComputePattern.TP1DCol)
-            self._shard_1d(parallel_action=parallel_action, dim=0)
-            self._shard_pattern = ShardPattern.Row
+        if self._shard_spec.num_action == 1:
+            parallel_action = self._shard_spec.get_action_by_compute_pattern(
+                self._shard_spec.compute_patterns[0])
+            # We bind our ComputePattern on weight, which has to be transposed when linear().
+            if parallel_action.compute_pattern in [ComputePattern.TP1DRow_Linear, \
+                ComputePattern.TP1DCol_Embedding]:
+                self._shard_1d(parallel_action=parallel_action, dim=-1)
+                self._shard_pattern = ShardPattern.Col
+            elif parallel_action.compute_pattern in [ComputePattern.TP1DCol_Linear, \
+                ComputePattern.TP1DRow_Embedding]:
+                self._shard_1d(parallel_action=parallel_action, dim=0)
+                self._shard_pattern = ShardPattern.Row
+            else:
+                raise NotImplementedError
 
     def gather(self):
         assert self.is_activation(), 'Currently we only support gather Activation ColoTensor.'
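The hunk above is why the pattern names had to be split per op: "column" parallelism always splits an op's output-feature dimension, but that dimension lands on different axes of the stored weight. A short illustrative sketch of the mapping (shapes are examples, not API):

import torch

linear_w = torch.empty(16, 8)      # (out_features, in_features); transposed inside F.linear
embedding_w = torch.empty(12, 32)  # (num_embeddings, embedding_dim)

# TP1DCol_Linear: split out_features -> dim 0 of the stored weight (ShardPattern.Row)
print(torch.chunk(linear_w, 2, dim=0)[0].shape)      # torch.Size([8, 8])
# TP1DCol_Embedding: split embedding_dim -> dim -1 of the stored weight (ShardPattern.Col)
print(torch.chunk(embedding_w, 2, dim=-1)[0].shape)  # torch.Size([12, 16])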
colossalai/tensor/spec.py
@@ -4,10 +4,12 @@ from colossalai.context.parallel_mode import ParallelMode
 class ComputePattern(Enum):
-    TP1DRow = 1
-    TP1DCol = 2
-    ZeRO = 3
-    DP = 4
+    TP1DRow_Linear = 1
+    TP1DCol_Linear = 2
+    TP1DRow_Embedding = 3
+    TP1DCol_Embedding = 4
+    ZeRO = 5
+    DP = 6
 
 
 class ShardPattern(Enum):
@@ -43,14 +45,14 @@ class TensorSpec(object):
     # using ZeRO with DP-degree = 4 and 1DRowTP with TP-degree = 2.
     # parallel_action_list = [
     #       ParallelAction(10, ComputePattern.ZeRO, gpc.get_group(ParallelMode.DATA)),
-    #       ParallelAction(1, ComputePattern.TP1DRow, gpc.get_group(ParallelMode.PARALLEL_1D))
+    #       ParallelAction(1, ComputePattern.TP1DRow_Linear, gpc.get_group(ParallelMode.PARALLEL_1D))
     # ]
     # When the ColoTensor is initialized,
     # we first splitting tensor according to ParallelAction of ZeRO,
-    # then splitting tensor according to ParallelAction of TP1DRow.
+    # then splitting tensor according to ParallelAction of TP1DRow_Linear.
     # During Linear computation
     # Before Linear Op, we gather the tensors according to ZeRO.
-    # We perform Linear Op according to compute pattern of TP1DRow.
+    # We perform Linear Op according to compute pattern of TP1DRow_Linear.
     # After Linear Op, we split the tensors according to ZeRO.
     def __init__(self, parallel_action_list: List[ParallelAction] = [], shard_pattern: ShardPattern = ShardPattern.NA):
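For concreteness, the multi-action list described in the docstring above would be built as follows after the rename. This is a hedged sketch that follows that comment verbatim (positional ParallelAction arguments and an already-initialized gpc are assumptions taken from the comment; presumably the higher-priority ZeRO action is applied before the TP action):

from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.tensor import ComputePattern, ParallelAction, TensorSpec

parallel_action_list = [
    ParallelAction(10, ComputePattern.ZeRO, gpc.get_group(ParallelMode.DATA)),
    ParallelAction(1, ComputePattern.TP1DRow_Linear, gpc.get_group(ParallelMode.PARALLEL_1D)),
]
spec = TensorSpec(parallel_action_list)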
tests/test_tensor/test_embedding_tp.py (new file, mode 100644)

import torch
from colossalai.context.parallel_mode import ParallelMode
from colossalai.tensor import ColoTensor
from functools import partial

import colossalai
import pytest
import torch
import torch.multiprocessing as mp
from colossalai.testing import parameterize, rerun_if_address_is_in_use
from colossalai.utils.cuda import get_current_device
from colossalai.utils import free_port
from colossalai.core import global_context as gpc
from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction
from _utils import check_equal, replace_parameter_add_grad, broadcast_tensor_chunk


def run_embedding_tp1d_col_test():
    device = get_current_device()
    dtype = torch.float32
    DEPTH = gpc.get_world_size(ParallelMode.PARALLEL_1D)
    num_embeddings = 12
    embedding_dim = 32
    local_rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)

    layer_master = torch.nn.Embedding(num_embeddings, embedding_dim)
    layer = torch.nn.Embedding(num_embeddings, embedding_dim)

    A_master = torch.tensor((0, 3, 6, 9), device=device)
    A = broadcast_tensor_chunk(A_master, chunk_size=1)

    W_shape = (num_embeddings, embedding_dim)
    W_master = torch.randn(W_shape, dtype=dtype, device=device)
    W = broadcast_tensor_chunk(W_master, chunk_size=1)
    W.requires_grad = True

    # replace the torch nn.Parameters with ColoTensor
    sharded_weight = ColoTensor.init_from_torch_tensor(W)
    parallel_action_list = [
        ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DCol_Embedding,
                       parallel_mode=ParallelMode.PARALLEL_1D)
    ]
    spec = TensorSpec(parallel_action_list)
    sharded_weight.set_spec(spec)    # reshard
    replace_parameter_add_grad(layer, sharded_weight)
    out = layer(A)

    replace_parameter_add_grad(layer_master, W_master)
    C_master = layer_master(A_master)
    C = C_master.clone()

    check_equal(out, C)

    grad_shape = C_master.shape
    grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device())
    grad = broadcast_tensor_chunk(grad_master, chunk_size=1)
    out.backward(grad)

    grad_master = grad_master.clone()
    C_master.backward(grad_master)

    W_grad = W_master.grad
    W_grad = torch.chunk(W_grad, DEPTH, dim=-1)[local_rank]
    check_equal(W_grad, layer.weight.grad)


def run_dist(rank, world_size, port):
    config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),))
    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    run_embedding_tp1d_col_test()


@pytest.mark.dist
@parameterize('world_size', [1, 4])
@rerun_if_address_is_in_use()
def test_embedding_1d(world_size):
    run_func = partial(run_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_embedding_1d()
tests/test_tensor/test_linear_tp.py
@@ -47,7 +47,7 @@ def run_linear_tp1d_col_test():
     sharded_weight = ColoTensor.init_from_torch_tensor(W)
     sharded_bias = ColoTensor.init_from_torch_tensor(B)
     parallel_action_list = [
-        ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DCol, parallel_mode=ParallelMode.PARALLEL_1D)
+        ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DCol_Linear, parallel_mode=ParallelMode.PARALLEL_1D)
     ]
     spec = TensorSpec(parallel_action_list)
     sharded_weight.set_spec(spec)    # reshard
@@ -110,7 +110,7 @@ def run_linear_tp1d_row_test():
     # replace the torch nn.Parameters with ColoTensor
     sharded_weight = ColoTensor.init_from_torch_tensor(W)
     parallel_action_list = [
-        ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DRow, parallel_mode=ParallelMode.PARALLEL_1D)
+        ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DRow_Linear, parallel_mode=ParallelMode.PARALLEL_1D)
     ]
     spec = TensorSpec(parallel_action_list)
     sharded_weight.set_spec(spec=spec)    # reshard
@@ -145,7 +145,7 @@ def run_linear_tp1d_row_test():
 def run_dist(rank, world_size, port):
     config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),))
     colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    # run_linear_tp1d_row_test()
+    run_linear_tp1d_row_test()
     run_linear_tp1d_col_test()
 
 @pytest.mark.dist
tests/test_tensor/test_model.py
@@ -38,12 +38,12 @@ def run_1d_col_tp():
     model = model_builder(checkpoint=True)
     parallel_action_list_row = [
-        ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DRow, parallel_mode=ParallelMode.PARALLEL_1D)
+        ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DRow_Linear, parallel_mode=ParallelMode.PARALLEL_1D)
     ]
     spec_row = TensorSpec(parallel_action_list_row)
     parallel_action_list_col = [
-        ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DCol, parallel_mode=ParallelMode.PARALLEL_1D)
+        ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DCol_Linear, parallel_mode=ParallelMode.PARALLEL_1D)
     ]
     spec_col = TensorSpec(parallel_action_list_col)
@@ -168,7 +168,7 @@ def run_1d_row_tp():
     model = model_builder(checkpoint=True)
     parallel_action_list = [
-        ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DRow, parallel_mode=ParallelMode.PARALLEL_1D)
+        ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DRow_Linear, parallel_mode=ParallelMode.PARALLEL_1D)
     ]
     spec = TensorSpec(parallel_action_list)