OpenDAS / ColossalAI · Commits

Commit 060b917d (unverified)
Authored Jul 04, 2022 by Jiarui Fang; committed by GitHub on Jul 04, 2022
[refactor] remove gpc dependency in colotensor's _ops (#1189)
Parent: abf6a262 · Changes: 33
Showing 20 changed files with 342 additions and 123 deletions (+342 −123), page 1 of 2.
colossalai/nn/_ops/_utils.py                    +185   −0
colossalai/nn/_ops/addmm.py                       +8  −12
colossalai/nn/_ops/embedding.py                   +9  −13
colossalai/nn/_ops/embedding_bag.py               +2   −3
colossalai/nn/_ops/layernorm.py                   +1   −1
colossalai/nn/_ops/linear.py                      +7   −9
colossalai/nn/parallel/data_parallel.py          +15  −10
colossalai/nn/parallel/layers/colo_module.py      +1   −1
colossalai/nn/parallel/layers/embedding.py        +6  −10
colossalai/nn/parallel/layers/linear.py           +8  −17
colossalai/nn/parallel/layers/module_utils.py    +12   −8
colossalai/tensor/colo_tensor.py                  +6   −0
colossalai/tensor/dist_spec_mgr.py                +9   −7
colossalai/tensor/distspec.py                     +3   −3
colossalai/tensor/process_group.py               +41   −4
colossalai/tensor/tensor_spec.py                  +3   −2
tests/test_ddp/test_ddp_ignore_params.py          +5   −2
tests/test_ddp/test_ddp_state_dict.py             +5   −2
tests/test_tensor/_utils/_util.py                 +5   −3
tests/test_tensor/test_addmm_tp.py               +11  −16
colossalai/nn/_ops/_utils.py

 import torch
 from typing import Union, Optional
 from colossalai.tensor import ColoTensor
+import torch
+import torch.distributed as dist
+from colossalai.global_variables import tensor_parallel_env as env
+from colossalai.nn.layer.utils import divide
+from colossalai.tensor import ProcessGroup

 GeneralTensor = Union[ColoTensor, torch.Tensor]
 Number = Union[int, float]
...
@@ -10,3 +16,182 @@ def convert_to_colo_tensor(tensor: Optional[GeneralTensor]) -> Optional[ColoTens
     if tensor is not None and not isinstance(tensor, ColoTensor):
         tensor = ColoTensor.from_torch_tensor(tensor)
     return tensor
+
+
+def set_parallel_input(input_parallel: bool):
+    env.parallel_input_1d = input_parallel
+
+
+def get_parallel_input():
+    return env.parallel_input_1d
+
+
+def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank):
+    index_f = rank * per_partition_vocab_size
+    index_l = index_f + per_partition_vocab_size
+    return index_f, index_l
+
+
+def vocab_range_from_global_vocab_size(global_vocab_size, rank, world_size):
+    per_partition_vocab_size = divide(global_vocab_size, world_size)
+    return vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank)
+
+
+def _reduce(input_, pg: ProcessGroup):
+    # skip if only one rank involved
+    if pg.tp_world_size() == 1:
+        return input_
+    assert input_.device.type == 'cuda'
+    group = pg.tp_process_group()
+    dist.all_reduce(input_, group=group)
+    return input_
+
+
+def _split(input_, pg: ProcessGroup, dim=-1):
+    # skip if only one rank involved
+    world_size = pg.tp_world_size()
+    if world_size == 1:
+        return input_
+
+    # Split along last dimension.
+    dim_size = input_.size(dim)
+    assert dim_size % world_size == 0, \
+        f'The dimension to split ({dim_size}) is not a multiple of world size ({world_size}), ' \
+        f'cannot split tensor evenly'
+
+    tensor_list = torch.split(input_, dim_size // world_size, dim=dim)
+    rank = pg.tp_local_rank()
+    output = tensor_list[rank].contiguous()
+    return output
+
+
+def _gather(input_, pg: ProcessGroup, dim=-1):
+    # skip if only one rank involved
+    world_size = pg.tp_world_size()
+    if world_size == 1:
+        return input_
+
+    # all gather
+    rank = pg.tp_local_rank()
+    tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
+    tensor_list[rank] = input_
+    assert input_.device.type == 'cuda'
+    group = pg.tp_process_group()
+    torch.distributed.all_gather(tensor_list, input_, group=group)
+
+    # concat
+    output = torch.cat(tensor_list, dim=dim).contiguous()
+    return output
+
+
+class _ReduceGrad(torch.autograd.Function):
+    """
+    Pass the input to the model parallel region.
+
+    Args:
+        input_: input matrix.
+        process_group: parallel mode.
+    """
+
+    @staticmethod
+    def symbolic(graph, input_):
+        return input_
+
+    @staticmethod
+    def forward(ctx, input_, process_group):
+        ctx.mode = process_group
+        return input_
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        return _reduce(grad_output, ctx.mode), None
+
+
+class _ReduceInput(torch.autograd.Function):
+    """
+    All-reduce the input from the model parallel region.
+
+    Args:
+        input_: input matrix.
+        process_group: parallel mode.
+    """
+
+    @staticmethod
+    def symbolic(graph, input_):
+        return _reduce(input_)
+
+    @staticmethod
+    def forward(ctx, input_, process_group):
+        return _reduce(input_, process_group)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        return grad_output, None
+
+
+class _SplitForwardGatherBackward(torch.autograd.Function):
+    """
+    Split the input and keep only the corresponding chunk to the rank.
+
+    Args:
+        input_: input matrix.
+        process_group: parallel mode.
+        dim: dimension
+    """
+
+    @staticmethod
+    def symbolic(graph, input_):
+        return _split(input_)
+
+    @staticmethod
+    def forward(ctx, input_, process_group, dim):
+        ctx.mode = process_group
+        ctx.dim = dim
+        return _split(input_, process_group, dim)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        return _gather(grad_output, ctx.mode, ctx.dim), None, None
+
+
+class _GatherForwardSplitBackward(torch.autograd.Function):
+    """Gather the input from model parallel region and concatenate.
+
+    Args:
+        input_: input matrix.
+        process_group: parallel mode.
+        dim: dimension
+    """
+
+    @staticmethod
+    def symbolic(graph, input_):
+        return _gather(input_)
+
+    @staticmethod
+    def forward(ctx, input_, process_group, dim):
+        ctx.mode = process_group
+        ctx.dim = dim
+        return _gather(input_, process_group, dim)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        return _split(grad_output, ctx.mode, ctx.dim), None, None
+
+
+def reduce_grad(input_, process_group):
+    return _ReduceGrad.apply(input_, process_group)
+
+
+def reduce_input(input_, process_group):
+    return _ReduceInput.apply(input_, process_group)
+
+
+def split_forward_gather_backward(input_, process_group, dim):
+    return _SplitForwardGatherBackward.apply(input_, process_group, dim)
+
+
+def gather_forward_split_backward(input_, process_group, dim):
+    return _GatherForwardSplitBackward.apply(input_, process_group, dim)
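To make the split/gather pair concrete, here is a minimal, hypothetical usage sketch (not part of the commit). It assumes a distributed environment has already been initialized (e.g. via colossalai.launch), that CUDA tensors are available on every rank, and that the helpers are importable from `colossalai.nn._ops._utils`, the module this diff adds them to.

```python
# Hypothetical usage sketch of the new _utils helpers with a tensor-parallel ProcessGroup.
import torch
from colossalai.tensor import ProcessGroup
from colossalai.nn._ops._utils import (split_forward_gather_backward,
                                       gather_forward_split_backward)

pg = ProcessGroup(tp_degree=torch.distributed.get_world_size())

x = torch.randn(4, 8, device='cuda')
# Forward: keep only this rank's chunk of the last dim; backward: all-gather the grads.
x_local = split_forward_gather_backward(x, pg, dim=-1)
# Forward: all-gather the chunks back to the full tensor; backward: split the grads.
x_full = gather_forward_split_backward(x_local, pg, dim=-1)
assert x_full.shape == x.shape
```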
colossalai/nn/_ops/addmm.py

 import torch
 from colossalai.tensor.op_wrapper import colo_op_impl
-from colossalai.nn.layer.parallel_1d._utils import reduce_input, reduce_grad
 from colossalai.tensor import ComputePattern, TensorSpec, ComputePattern, ComputeSpec, ColoTensor
 from colossalai.tensor import distspec
-from colossalai.context import ParallelMode
 from ._utils import GeneralTensor, Number, convert_to_colo_tensor
+from ._utils import reduce_input, reduce_grad


 def colo_addmm_1Drow(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTensor, beta: Number,
...
@@ -12,18 +11,16 @@ def colo_addmm_1Drow(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTenso
     # mat1:S[1] x mat2:S[0] = Output:P
     # beta * input + alpha * All-Reduce(Output) = res
     mat1 = mat1.convert_to_dist_spec(
-        distspec.shard(mat2.tensor_spec.get_process_group(), [-1], [mat2.tensor_spec.get_process_group_size()]))
+        distspec.shard(mat2.get_process_group(), [-1], [mat2.get_tp_world_size()]))
     # Output:P
     partial_output = torch.mm(mat1, mat2)
     # Reduce(Output)
-    output = reduce_input(partial_output, ParallelMode.PARALLEL_1D)
+    output = reduce_input(partial_output, mat1.get_process_group())
     # input
     assert not input_tensor.has_compute_spec(), 'Invalid input spec for 1Drow addmm op'
     output = beta * input_tensor + alpha * output
     output = ColoTensor.from_torch_tensor(output,
-                                          spec=TensorSpec(distspec.replicate(mat2.tensor_spec.get_process_group())))
+                                          spec=TensorSpec(distspec.replicate(mat2.get_process_group())))
     return output
...
@@ -31,13 +28,12 @@ def colo_addmm_1Dcol(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTenso
                      alpha: Number) -> ColoTensor:
     # mat1:B x mat2:S[1] + input:S[1] = Output:S[1]
     compute_spec = mat2.tensor_spec.compute_spec
-    mat1 = mat1.convert_to_dist_spec(distspec.replicate(mat2.tensor_spec.get_process_group()))
-    mat1 = reduce_grad(mat1, ParallelMode.PARALLEL_1D)
+    mat1 = mat1.convert_to_dist_spec(distspec.replicate(mat2.get_process_group()))
+    mat1 = reduce_grad(mat1, mat1.get_process_group())

     output_parallel = torch.addmm(input_tensor, mat1, mat2, beta=beta, alpha=alpha)
     output_spec = TensorSpec(
-        distspec.shard(mat2.tensor_spec.get_process_group(), [-1], [mat2.tensor_spec.get_process_group_size()]),
+        distspec.shard(mat2.get_process_group(), [-1], [mat2.get_tp_world_size()]),
         ComputeSpec(ComputePattern.TP1D))
     output = ColoTensor.from_torch_tensor(output_parallel, spec=output_spec)

     if compute_spec.output_replicate:
...
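For orientation, here is a hedged sketch of how the 1D-col addmm path is reached from user code once the weight carries a TP1D spec. It mirrors the updated test in tests/test_tensor/test_addmm_tp.py and assumes colossalai.launch (or an equivalent distributed setup) has already run.

```python
# Sketch based on tests/test_tensor/test_addmm_tp.py: a ColoTensor weight sharded over
# a tensor-parallel ProcessGroup makes torch.addmm dispatch through colo_op_impl to
# the colo_addmm 1D implementations.
import torch
from colossalai.tensor import (ColoTensor, ProcessGroup, TensorSpec, ComputePattern,
                               ComputeSpec, DistSpecManager, distspec)

pg = ProcessGroup(tp_degree=torch.distributed.get_world_size())

weight = ColoTensor(torch.nn.Parameter(torch.randn(16, 4, device='cuda')))
bias = ColoTensor(torch.nn.Parameter(torch.randn(4, device='cuda')))
spec = TensorSpec(distspec.shard(pg, [-1], [pg.tp_world_size()]),
                  ComputeSpec(ComputePattern.TP1D))
with DistSpecManager.no_grad():
    weight.set_tensor_spec(spec)   # column-shard the weight
    bias.set_tensor_spec(spec)

x = torch.rand(2, 16, device='cuda')
colo_out = torch.addmm(bias, x, weight)   # routed to colo_addmm_1Dcol
```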
colossalai/nn/_ops/embedding.py

 import torch.nn.functional as F
 from typing import Optional
 from colossalai.tensor.op_wrapper import colo_op_impl
-from colossalai.nn.layer.parallel_1d._utils import reduce_input
-from colossalai.core import global_context as gpc
 from colossalai.tensor import ComputePattern, TensorSpec, ComputePattern, ComputeSpec, ColoTensor, distspec
-from colossalai.context import ParallelMode
-from ._utils import GeneralTensor, convert_to_colo_tensor
+from ._utils import GeneralTensor, convert_to_colo_tensor, reduce_input


 def colo_embedding_1Dcol(input_tensor: ColoTensor,
...
@@ -17,7 +14,7 @@ def colo_embedding_1Dcol(input_tensor: ColoTensor,
                          sparse: bool = False) -> ColoTensor:
     # embedding_1Dcol split the weight(lookup table) to (num_embeddings, embedding_dim/P)
     # Gather splitted lookup table
-    input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate(weight.tensor_spec.get_process_group()))
+    input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate(weight.get_process_group()))

     output_parallel = F.embedding(input_tensor,
                                   weight,
...
@@ -26,9 +23,8 @@ def colo_embedding_1Dcol(input_tensor: ColoTensor,
                                   norm_type=norm_type,
                                   scale_grad_by_freq=scale_grad_by_freq,
                                   sparse=sparse)
     output_spec = TensorSpec(
-        distspec.shard(weight.tensor_spec.get_process_group(), [-1], [weight.tensor_spec.get_process_group_size()]),
+        distspec.shard(weight.get_process_group(), [-1], [weight.get_tp_world_size()]),
         ComputeSpec(ComputePattern.TP1D))
     output = ColoTensor.from_torch_tensor(output_parallel, spec=output_spec)

     compute_spec = weight.tensor_spec.compute_spec
...
@@ -49,9 +45,10 @@ def colo_embedding_1Drow(input_tensor: ColoTensor,
     # embedding_1Drow split the weight(lookup table) to (num_embeddings/P, embedding_dim)
     # Find index in this shard and mask those not here
     # Reduce all
-    input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate(weight.tensor_spec.get_process_group()))
-    tensor_parallel_rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
+    input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate(weight.get_process_group()))
+    # tensor_parallel_rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
+    tensor_parallel_rank = weight.tensor_spec.dist_spec.process_group.tp_local_rank()
     num_embeddings_per_partition = weight.size_local(0)
     vocab_start_index = tensor_parallel_rank * num_embeddings_per_partition
     vocab_end_index = vocab_start_index + num_embeddings_per_partition
...
@@ -75,9 +72,8 @@ def colo_embedding_1Drow(input_tensor: ColoTensor,
     # Mask the output embedding.
     partial_output[input_mask, :] = 0.
     # Reduce across all the model parallel GPUs.
-    output = reduce_input(partial_output, ParallelMode.PARALLEL_1D)
+    output = reduce_input(partial_output, weight.get_process_group())
     output = ColoTensor.from_torch_tensor(output,
-                                          spec=TensorSpec(distspec.replicate(weight.tensor_spec.get_process_group())))
+                                          spec=TensorSpec(distspec.replicate(weight.get_process_group())))
     return output
...
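As a quick sanity check on the row-shard index arithmetic used by colo_embedding_1Drow (illustrative numbers only, not from the commit): each rank owns a contiguous block of vocabulary rows determined by its TP local rank.

```python
# Illustrative only: how the 1D-row embedding computes its vocab range per rank.
# With num_embeddings = 1000 split over tp_world_size = 4, each partition holds 250 rows.
num_embeddings_per_partition = 1000 // 4
for tensor_parallel_rank in range(4):
    vocab_start_index = tensor_parallel_rank * num_embeddings_per_partition
    vocab_end_index = vocab_start_index + num_embeddings_per_partition
    # rank 0 -> [0, 250), rank 1 -> [250, 500), rank 2 -> [500, 750), rank 3 -> [750, 1000)
    print(tensor_parallel_rank, vocab_start_index, vocab_end_index)
```

Indices outside a rank's range are masked to zero locally and recovered by the all-reduce over the weight's process group, which is exactly the `reduce_input(partial_output, weight.get_process_group())` call in the hunk above.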
colossalai/nn/_ops/embedding_bag.py

...
@@ -32,9 +32,8 @@ def colo_embedding_bag_1Dcol(input_tensor: ColoTensor,
                                        per_sample_weights=per_sample_weights,
                                        include_last_offset=include_last_offset,
                                        padding_idx=padding_idx)
     output_spec = TensorSpec(
-        distspec.shard(weight.tensor_spec.get_process_group(), [-1], [weight.tensor_spec.get_process_group_size()]),
+        distspec.shard(weight.get_process_group(), [-1], [weight.get_tp_world_size()]),
         ComputeSpec(ComputePattern.TP1D))
     output = ColoTensor.from_torch_tensor(output_parallel, spec=output_spec)

     if weight.tensor_spec.compute_spec.output_replicate:
...
colossalai/nn/_ops/layernorm.py

...
@@ -17,7 +17,7 @@ def colo_layernorm(
     input_tensor, weight, bias = tuple(map(convert_to_colo_tensor, (input_tensor, weight, bias)))

     # TODO (ver217): check dist spec
-    input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate(input_tensor.tensor_spec.get_process_group()))
+    input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate(input_tensor.get_process_group()))

     output = F.layer_norm(input_tensor, normalized_shape, weight=weight, bias=bias, eps=eps)
     output = ColoTensor.from_torch_tensor(output, input_tensor.tensor_spec)
...
colossalai/nn/_ops/linear.py

...
@@ -2,9 +2,8 @@ import torch.nn.functional as F
 from typing import Optional
 from ._utils import GeneralTensor, convert_to_colo_tensor
 from colossalai.tensor.op_wrapper import colo_op_impl
-from colossalai.nn.layer.parallel_1d._utils import reduce_input, reduce_grad
+from ._utils import reduce_input, reduce_grad
 from colossalai.tensor import ComputePattern, TensorSpec, ComputePattern, ComputeSpec, ColoTensor, distspec
-from colossalai.context import ParallelMode
 from colossalai.nn.graph import register_colo_graph, GraphOpNode, GraphGlobalEnv
...
@@ -13,19 +12,18 @@ def colo_linear_1Drow(input_tensor: ColoTensor, weight: ColoTensor, bias: Option
     # All-Reduce(Output) + bias = res
     # Input:S[1]
     input_tensor = input_tensor.convert_to_dist_spec(
-        distspec.shard(weight.tensor_spec.get_process_group(), [-1], [weight.tensor_spec.get_process_group_size()]))
+        distspec.shard(weight.get_process_group(), [-1], [weight.get_tp_world_size()]))

     # Output:P
     partial_output = F.linear(input_tensor, weight)
     # Reduce(Output)
-    output = reduce_input(partial_output, ParallelMode.PARALLEL_1D)
+    output = reduce_input(partial_output, weight.get_process_group())
     # Bias
     if bias is not None:
         assert not bias.has_compute_spec(), 'Invalid bias spec for 1Drow Linear op'
         output = output + bias

     output = ColoTensor.from_torch_tensor(output,
-                                          spec=TensorSpec(distspec.replicate(weight.tensor_spec.get_process_group())))
+                                          spec=TensorSpec(distspec.replicate(weight.get_process_group())))
     return output
...
@@ -35,13 +33,13 @@ def colo_linear_1Dcol(input_tensor: ColoTensor, weight: ColoTensor, bias: Option
     # Input:B
     compute_spec = weight.tensor_spec.compute_spec
     input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate(weight.tensor_spec.get_process_group()))
-    input_parallel = reduce_grad(input_tensor, ParallelMode.PARALLEL_1D)
+    input_parallel = reduce_grad(input_tensor, weight.tensor_spec.dist_spec.process_group)

     output_parallel = F.linear(input_parallel, weight, bias)
     output = ColoTensor.from_torch_tensor(output_parallel,
                                           spec=TensorSpec(
-                                              distspec.shard(weight.tensor_spec.get_process_group(), [-1],
-                                                             [weight.tensor_spec.get_process_group_size()]),
+                                              distspec.shard(weight.get_process_group(), [-1],
+                                                             [weight.get_tp_world_size()]),
                                               ComputeSpec(ComputePattern.TP1D)))
     if compute_spec.output_replicate:
         return output.to_replicate()
...
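To see why the 1Drow linear path needs the all-reduce at the end, here is an illustrative single-process simulation of the math (not from the commit, and no distributed calls): sharding the input and weight along the contraction dimension gives per-rank partial products whose sum equals the full linear output.

```python
# Illustrative simulation of the colo_linear_1Drow data flow, assuming tp_world_size() == 2:
# each rank computes F.linear on its [-1]-dim shard, and reduce_input() sums the partials.
import torch
import torch.nn.functional as F

full_in = torch.randn(2, 8)     # replicated input  (batch=2, in_features=8)
full_w = torch.randn(4, 8)      # full weight       (out_features=4, in_features=8)

partials = []
for rank in range(2):
    in_shard = full_in.chunk(2, dim=-1)[rank]
    w_shard = full_w.chunk(2, dim=-1)[rank]
    partials.append(F.linear(in_shard, w_shard))    # Output:P (partial result)

output = sum(partials)          # stands in for the TP all-reduce
assert torch.allclose(output, F.linear(full_in, full_w), atol=1e-5)
```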
colossalai/nn/parallel/data_parallel.py

 import torch
 import itertools
 import torch.distributed as dist
-from colossalai.core import global_context as gpc
-from colossalai.context import ParallelMode
 from functools import partial
 from colossalai.zero.utils.zero_hook_v2 import ZeROHookV2
 from colossalai.gemini.chunk import TensorState, Chunk
...
@@ -12,6 +10,7 @@ from typing import Dict, Iterable, List, Optional
 from colossalai.logging import get_dist_logger
 from collections import OrderedDict
 from colossalai.tensor.colo_parameter import ColoParameter
+from colossalai.tensor import ProcessGroup as ColoProcessGroup
 from .reducer import Reducer

 try:
     from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, _IncompatibleKeys
...
@@ -45,8 +44,8 @@ class ColoDDP(torch.nn.Module):
         >>> from colossalai.core import global_context as gpc
         >>> from colossalai.context import ParallelMode
         >>> model = torch.nn.Linear(20, 1)
-        >>> model = ColoDDP(model)
-        >>> // model = ColoDDP(model, process_group=gpc.get_group(ParallelMode.DATA), cpu_process_group=gpc.get_cpu_group(ParallelMode.DATA))
+        >>> pg = ProcessGroup(tp_degree = world_size//2)
+        >>> model = ColoDDP(model, pg)
         >>> logits = model(x)
         >>> loss = criterion(logits, labels)
         >>> model.backward(loss)
...
@@ -55,13 +54,13 @@ class ColoDDP(torch.nn.Module):
         module (torch.nn.Module): Module to apply DDP.
         process_group (Optional[dist.ProcessGroup], optional): The process group which DDP uses.
             If it's None, the default data parallel group will be used. Defaults to None.
-        process_group (Optional[dist.ProcessGroup], optional): The process group which DDP uses for those parameters on CPU.
+        cpu_process_group (Optional[dist.ProcessGroup], optional): The process group which DDP uses for those parameters on CPU.
             If it's None, the default CPU data parallel group will be used. Defaults to None.
     """

     def __init__(self,
                  module: torch.nn.Module,
-                 process_group: Optional[dist.ProcessGroup] = None,
+                 process_group: ColoProcessGroup,
                  cpu_process_group: Optional[dist.ProcessGroup] = None,
                  bucket_cap_mb: int = 25,
                  rebuild_bucket: bool = True) -> None:
...
@@ -69,8 +68,9 @@ class ColoDDP(torch.nn.Module):
         super().__init__()
         self.module = module
         self.comm_stream: torch.cuda.Stream = torch.cuda.Stream()
-        self.process_group = process_group or gpc.get_group(ParallelMode.DATA)
-        self.cpu_process_group = cpu_process_group or gpc.get_cpu_group(ParallelMode.DATA)
+        assert process_group
+        self.process_group = process_group.dp_process_group()
         self.dp_world_size = self.process_group.size()
         self.reducer = Reducer(bucket_cap_mb)
         self.rebuild_bucket = rebuild_bucket
...
@@ -120,6 +120,8 @@ class ColoDDP(torch.nn.Module):
             return empty_grad
         else:
+            #TODO(jiaruifang) fixme
+            raise NotImplementedError
             dist.all_reduce(grad, group=self.cpu_process_group)
             return grad
...
@@ -191,8 +193,11 @@ class ZeroDDP(ColoDDP):
     For more details, see the API reference of ``GeminiManager``.
     """

-    def __init__(self, module: torch.nn.Module, gemini_manager: GeminiManager) -> None:
-        super().__init__(module.half())
+    def __init__(self,
+                 module: torch.nn.Module,
+                 gemini_manager: GeminiManager,
+                 process_group: Optional[ColoProcessGroup] = None) -> None:
+        super().__init__(module.half(), process_group=process_group)
         self.gemini_manager = gemini_manager
         self.chunk_manager = gemini_manager.chunk_manager
         self.param_op_hook = ZeROHookV2(gemini_manager)
...
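Here is a hedged sketch of the new ColoDDP wiring described by the updated docstring: the caller now passes a Colossal-AI ProcessGroup, and ColoDDP reduces gradients over `pg.dp_process_group()` instead of fetching a group from gpc. It assumes colossalai.launch (or equivalent) has already initialized torch.distributed with an even world size.

```python
# Sketch of the ProcessGroup-based ColoDDP API introduced in this commit.
import torch
from colossalai.tensor import ProcessGroup
from colossalai.nn.parallel import ColoDDP

world_size = torch.distributed.get_world_size()
pg = ProcessGroup(tp_degree=world_size // 2)     # DP degree is inferred from the TP degree

model = torch.nn.Linear(20, 1).cuda()
model = ColoDDP(model, pg)                       # gradients reduced over pg.dp_process_group()

x = torch.randn(8, 20, device='cuda')
logits = model(x)
loss = logits.sum()
model.backward(loss)                             # ColoDDP drives backward and grad reduction
```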
colossalai/nn/parallel/layers/colo_module.py

...
@@ -52,5 +52,5 @@ class ColoModule(object):
     def get_param_names(self):
         return self._shard_params

-    def register(self, compute_pattern):
+    def register(self, compute_pattern, pg):
         raise NotImplementedError
colossalai/nn/parallel/layers/embedding.py

 from .colo_module import ColoModule
-from colossalai.tensor import ComputePattern, distspec
+from colossalai.tensor import ComputePattern, distspec, ProcessGroup
 from colossalai.core import global_context as gpc
 from colossalai.context.parallel_mode import ParallelMode
...
@@ -10,20 +10,18 @@ class ColoEmbedding(ColoModule):
         super(ColoEmbedding, self).__init__()
         self._register_shard_params(['weight'])

-    def register(self, compute_pattern):
+    def register(self, compute_pattern, pg: ProcessGroup):
         if not compute_pattern in self._allowed_patterns:
             if ComputePattern.TP1D == compute_pattern:
-                self._set_TP1D()
+                self._set_TP1D(pg)

-    def _set_TP1D(self):
+    def _set_TP1D(self, pg: ProcessGroup):
         # TP1D Row Linear
         _compute_pattern = ComputePattern.TP1D
         self._register_allowed_patterns(
             compute_pattern=_compute_pattern,
             dist_specs={
-                'weight': distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
+                'weight': distspec.shard(pg, [0], [pg.tp_world_size()]),
             },
             mode='row',
         )
...
@@ -32,9 +30,7 @@ class ColoEmbedding(ColoModule):
         self._register_allowed_patterns(
             compute_pattern=_compute_pattern,
             dist_specs={
-                'weight': distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
+                'weight': distspec.shard(pg, [-1], [pg.tp_world_size()]),
             },
             mode='col',
         )
...
colossalai/nn/parallel/layers/linear.py

 from .colo_module import ColoModule
-from colossalai.tensor import ComputePattern, distspec
-from colossalai.core import global_context as gpc
-from colossalai.context.parallel_mode import ParallelMode
+from colossalai.tensor import ComputePattern, distspec, ProcessGroup


 class ColoLinear(ColoModule):
...
@@ -10,22 +8,19 @@ class ColoLinear(ColoModule):
         super(ColoLinear, self).__init__()
         self._register_shard_params(['weight', 'bias'])

-    def register(self, compute_pattern):
+    def register(self, compute_pattern, pg: ProcessGroup):
         if not compute_pattern in self._allowed_patterns:
             if ComputePattern.TP1D == compute_pattern:
-                self._set_TP1D()
+                self._set_TP1D(pg)

-    def _set_TP1D(self):
+    def _set_TP1D(self, pg):
         # TP1D Row Linear
         _compute_pattern = ComputePattern.TP1D
         self._register_allowed_patterns(
             compute_pattern=_compute_pattern,
             dist_specs={
-                'weight': distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
+                'weight': distspec.shard(pg, [-1], [pg.tp_world_size()]),
                 'bias': None
             },
             mode='row',
         )
...
@@ -34,12 +29,8 @@ class ColoLinear(ColoModule):
         self._register_allowed_patterns(
             compute_pattern=_compute_pattern,
             dist_specs={
-                'weight': distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
-                'bias': distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)])
+                'weight': distspec.shard(pg, [0], [pg.tp_world_size()]),
+                'bias': distspec.shard(pg, [0], [pg.tp_world_size()])
             },
             mode='col',
         )
...
colossalai/nn/parallel/layers/module_utils.py

 from typing import Dict
-from colossalai.tensor import ColoParameter, ComputeSpec, TensorSpec
+from colossalai.tensor import ColoParameter, ComputeSpec, TensorSpec, ProcessGroup
 from . import ColoModule
 import torch
...
@@ -29,7 +29,7 @@ def get_colo_module(module: torch.nn.Module):
     return None


-def check_colo_module(module: torch.nn.Module, recursive=True):
+def check_colo_module(module: torch.nn.Module, pg: ProcessGroup, recursive=True):
     if is_colo_module(module):
         colo_module = get_colo_module(module)
         param_names = colo_module.get_param_names()
...
@@ -50,7 +50,7 @@ def check_colo_module(module: torch.nn.Module, recursive=True):
                 continue
             if compute_pattern is not None:
-                colo_module.register(compute_pattern)
+                colo_module.register(compute_pattern, pg)
                 if not colo_module.has_compute_pattern(compute_pattern):
                     raise Exception(
                         f'Invalid ColoParameter spec: ComputePattern {compute_pattern} in {module} is not allowed.')
...
@@ -76,16 +76,20 @@ def check_colo_module(module: torch.nn.Module, recursive=True):
                 raise Exception(f'Invalid ColoParameter spec: Params in {module} are incorrectly sharded.')
     if recursive == True:
         for submodule in module.children():
-            check_colo_module(submodule, recursive=True)
+            check_colo_module(submodule, pg=pg, recursive=True)


-def init_colo_module(module: torch.nn.Module, compute_spec: ComputeSpec, recursive=True, mode='default'):
+def init_colo_module(module: torch.nn.Module, compute_spec: ComputeSpec, pg: ProcessGroup, recursive=True, mode='default'):
     compute_pattern = compute_spec.compute_pattern
     if is_colo_module(module):
         # for each param
         # set DistSpec and ComputeSpec
         colo_module = get_colo_module(module)
-        colo_module.register(compute_pattern)
+        colo_module.register(compute_pattern, pg)
         if not colo_module.has_compute_pattern_with_mode(compute_pattern, mode=mode):
             raise NotImplementedError
         # a set for modules which update at least one param in the init process.
...
@@ -101,7 +105,7 @@ def init_colo_module(module: torch.nn.Module, compute_spec: ComputeSpec, recursi
             for mod in param.shared_param_modules:
                 modules_update_param.add(mod)
         for mod in modules_update_param:
-            check_colo_module(mod, recursive=False)
+            check_colo_module(mod, pg, recursive=False)
     if recursive == True:
         for submodule in module.children():
-            init_colo_module(submodule, compute_spec, recursive=True, mode=mode)
+            init_colo_module(submodule, compute_spec, pg=pg, recursive=True, mode=mode)
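A hedged sketch of what calling the updated entry point looks like: `init_colo_module` now takes the ProcessGroup explicitly instead of reading gpc inside. The import path and the `'row'` mode below are assumptions for illustration (the function lives in this module_utils.py; it is presumed re-exported from `colossalai.nn.parallel.layers`), and the model's parameters are assumed to already be ColoParameters, e.g. built under Colossal-AI's ColoInitContext.

```python
# Hypothetical sketch of driving the refactored layer-registration API with an explicit pg.
import torch
from colossalai.tensor import ProcessGroup, ComputeSpec, ComputePattern
from colossalai.nn.parallel.layers import init_colo_module   # assumed re-export

pg = ProcessGroup(tp_degree=torch.distributed.get_world_size())
model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.Linear(16, 16))

# Register TP1D sharding for every supported child module over pg (mode name assumed).
init_colo_module(model, ComputeSpec(ComputePattern.TP1D), pg=pg, recursive=True, mode='row')
```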
colossalai/tensor/colo_tensor.py

...
@@ -78,6 +78,12 @@ class ColoTensor(torch.Tensor):
     def is_model_data(self) -> bool:
         return self._type == TensorType.MODEL

+    def get_process_group(self) -> 'ProcessGroup':
+        return self._tensor_spec.dist_spec.process_group
+
+    def get_tp_world_size(self) -> int:
+        return self._tensor_spec.dist_spec.process_group.tp_world_size()
+
     @classmethod
     def __torch_function__(cls, func, types, args=(), kwargs=None):
         if kwargs is None:
...
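A minimal, hypothetical check of the two accessors ColoTensor gains here; they are what lets the `_ops` stop spelling out `tensor_spec.get_process_group()` / `tensor_spec.get_process_group_size()`. Assumes a distributed context and a ProcessGroup as elsewhere in this diff.

```python
# Sketch: the new ColoTensor accessors resolve the TP group straight from the dist spec.
import torch
from colossalai.tensor import ColoTensor, ProcessGroup, TensorSpec, distspec

pg = ProcessGroup(tp_degree=torch.distributed.get_world_size())
t = ColoTensor.from_torch_tensor(torch.randn(4, 4),
                                 spec=TensorSpec(distspec.replicate(pg)))

assert t.get_process_group() is pg
assert t.get_tp_world_size() == pg.tp_world_size()
```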
colossalai/tensor/dist_spec_mgr.py

...
@@ -5,6 +5,7 @@ from contextlib import contextmanager
 import torch
 import torch.distributed as dist
 from packaging import version
+from colossalai.logging import get_dist_logger

 # TODO(jiaruifang) circle import, move the divide to colossalai.commons.
...
@@ -64,7 +65,7 @@ class DistSpecManager:
         DistSpecManager._sanity_check(old_dist_spec, dist_spec)

         chunk = tensor
-        idx = dist_spec.process_group.rank()
+        idx = dist_spec.process_group.tp_local_rank()
         num_parts = prod(dist_spec.num_partitions)
         for i, dim in enumerate(dist_spec.dims):
             num_parts //= dist_spec.num_partitions[i]
...
@@ -91,8 +92,9 @@ class DistSpecManager:
         saved_dev = tensor.device
         tensor.data = tensor.data.cuda()

-        buffer = [torch.empty_like(tensor) for _ in range(old_dist_spec.process_group.size())]
-        dist.all_gather(buffer, tensor, group=old_dist_spec.process_group)
+        buffer = [torch.empty_like(tensor) for _ in range(old_dist_spec.process_group.tp_world_size())]
+        assert tensor.device.type == 'cuda'
+        dist.all_gather(buffer, tensor, group=old_dist_spec.process_group.tp_process_group())
         for i in range(len(old_dist_spec.dims) - 1, -1, -1):
             new_buffer = []
             dim = old_dist_spec.dims[i]
...
@@ -108,14 +110,14 @@ class DistSpecManager:
     @staticmethod
     def _all_to_all(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec) -> torch.Tensor:
-        world_size = old_dist_spec.process_group.size()
+        world_size = old_dist_spec.process_group.tp_world_size()
         if world_size == 1:
             return tensor

-        assert tensor.device.type == "cuda" and dist.get_backend(old_dist_spec.process_group) == "nccl", \
+        assert tensor.device.type == "cuda" and old_dist_spec.process_group.backend == "nccl", \
             "Currently, only CUDA Tensor with NCCL backend is supported for the requested AlltoAll " \
             f"collective function, however, we got {tensor.device.type} device and " \
-            f"{dist.get_backend(old_dist_spec.process_group)} backend"
+            f"{old_dist_spec.process_group.backend} backend"

         gather_dim = old_dist_spec.dims[0]
         scatter_dim = dist_spec.dims[0]
...
@@ -126,7 +128,7 @@ class DistSpecManager:
         scatter_list = [t.contiguous() for t in torch.tensor_split(tensor, world_size, scatter_dim)]
         gather_list = [torch.empty(*shapes, dtype=tensor.dtype, device=tensor.device) for _ in range(world_size)]
-        dist.all_to_all(gather_list, scatter_list, group=old_dist_spec.process_group)
+        dist.all_to_all(gather_list, scatter_list, group=old_dist_spec.process_group.tp_process_group())

         output_ = torch.cat(gather_list, dim=gather_dim).contiguous()
         assert output_.shape[scatter_dim] == scattered_dim_size and output_.shape[gather_dim] == gathered_dim_size
...
colossalai/tensor/distspec.py

 from enum import Enum
-from torch.distributed import ProcessGroup
+from colossalai.tensor import ProcessGroup
 from typing import Optional, List
 from numpy import prod
...
@@ -51,8 +51,8 @@ def replicate(process_group: Optional[ProcessGroup] = None) -> _DistSpec:


 def shard(process_group: ProcessGroup, dims: List[int], num_partitions: List[int]) -> _DistSpec:
-    assert process_group is not None
+    assert process_group is not None and isinstance(process_group, ProcessGroup)
     assert isinstance(dims, list) and isinstance(num_partitions, list)
     assert len(dims) == len(num_partitions)
-    assert prod(num_partitions) == process_group.size(), f"{num_partitions} {process_group.size()}"
+    assert prod(num_partitions) == process_group.tp_world_size(), f"{num_partitions} {process_group.tp_world_size()}"
     return _DistSpec(DistPlacementPattern.SHARD, process_group, dims=tuple(dims), num_partitions=tuple(num_partitions))
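Building a sharded spec under the new API looks like the snippet below, which mirrors the updated tests: the dist spec is keyed to a Colossal-AI ProcessGroup, and the partition count must equal `pg.tp_world_size()` or `shard()` asserts.

```python
# Row- and column-sharded TensorSpecs over a colossalai ProcessGroup (mirrors the tests).
import torch
from colossalai.tensor import ProcessGroup, TensorSpec, ComputeSpec, ComputePattern, distspec

pg = ProcessGroup(tp_degree=torch.distributed.get_world_size())
row_spec = TensorSpec(distspec.shard(pg, [0], [pg.tp_world_size()]),
                      ComputeSpec(ComputePattern.TP1D))
col_spec = TensorSpec(distspec.shard(pg, [-1], [pg.tp_world_size()]),
                      ComputeSpec(ComputePattern.TP1D))
```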
colossalai/tensor/process_group.py

 import torch
 from typing import List, Optional
+from colossalai.logging import get_dist_logger


 class ProcessGroup:
...
@@ -41,12 +42,12 @@ class ProcessGroup:
         if dp_degree and not tp_degree:
             self._dp_degree = dp_degree
             assert self._world_size % self._dp_degree == 0, f"DP degree {dp_degree} should be divisible by {self._world_size} hen DP degree is None"
-            self._tp_degree = self._world_size / dp_degree
+            self._tp_degree = self._world_size // dp_degree

         if not dp_degree and tp_degree:
             self._tp_degree = tp_degree
             assert self._world_size % self._tp_degree == 0, f"TP degree {tp_degree} should be divisible by {self._world_size} when DP degree is None"
-            self._dp_degree = self._world_size / tp_degree
+            self._dp_degree = self._world_size // tp_degree

         self._tp_rank_list = []
         self._dp_rank_list = []
...
@@ -58,12 +59,48 @@ class ProcessGroup:
             if rank_id // self._tp_degree == self._rank // self._tp_degree:
                 self._tp_rank_list.append(rank_id)

-        self._tp_process_group = torch.distributed.new_group(ranks=self._tp_rank_list, backend=backend)
-        self._dp_process_group = torch.distributed.new_group(ranks=self._dp_rank_list, backend=backend)
+        assert backend == 'nccl'
+        self._tp_process_group = torch.distributed.new_group(ranks=self._tp_rank_list)
+        self._dp_process_group = torch.distributed.new_group(ranks=self._dp_rank_list)
+
+        self.logger = get_dist_logger('ProcessGroup')
+        self.logger.info(f'{self._rank} initialize TP group on {self._tp_rank_list} DP group pn {self._dp_rank_list}')
+
+    @property
+    def backend(self):
+        return self._backend
+
+    def __eq__(self, obj: 'ProcessGroup') -> bool:
+        if not isinstance(obj, ProcessGroup):
+            return False
+        if self._rank != obj._rank:
+            assert False
+        if self._rank_list != obj._rank_list:
+            assert False
+        if self._tp_rank_list != obj._tp_rank_list:
+            assert False
+        if self._dp_rank_list != obj._dp_rank_list:
+            assert False
+        if self._backend != obj._backend:
+            assert False
+        if self._tp_degree != obj._tp_degree:
+            return False
+        if self._dp_degree != obj._dp_degree:
+            return False
+        return True
+
+    def rank(self):
+        return self._rank
+
     def world_size(self):
         return self._world_size

+    def tp_local_rank(self):
+        return self._rank % self._tp_degree
+
+    def dp_local_rank(self):
+        return self._rank // self._tp_degree
+
     def dp_world_size(self):
         return len(self._dp_rank_list)
...
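A hedged sketch of the DP/TP factorisation this class encodes: with 8 ranks and `tp_degree=2`, ranks are grouped into TP groups of size 2 and DP groups of size 4, and the accessors added in this commit expose the local coordinates without going through gpc. Assumes torch.distributed is already initialized with an NCCL backend.

```python
# Sketch of querying a ProcessGroup's coordinates (all methods shown exist in this diff).
import torch
from colossalai.tensor import ProcessGroup

pg = ProcessGroup(tp_degree=2)      # dp_degree is inferred as world_size // 2

print(pg.rank(), pg.tp_local_rank(), pg.dp_local_rank())
print(pg.tp_world_size(), pg.dp_world_size())
tp_group = pg.tp_process_group()    # the torch.distributed group used for TP collectives
```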
colossalai/tensor/tensor_spec.py

...
@@ -17,11 +17,12 @@ class TensorSpec(object):
         self.compute_spec = compute_spec
         self.dist_spec = dist_spec

+    # TODO(jiaruifang) actually need tp process group
     def get_process_group(self):
         return self.dist_spec.process_group

     def get_process_group_size(self):
-        return dist.get_world_size(self.dist_spec.process_group)
+        return dist.get_world_size(self.dist_spec.process_group.tp_process_group())

     def get_placement(self):
         return self.dist_spec.placement
...
@@ -30,7 +31,7 @@ class TensorSpec(object):
         return self.dist_spec.placement == DistPlacementPattern.REPLICATE \
             or (len(self.dist_spec.num_partitions) == 1
                 and self.dist_spec.num_partitions[0] == 1) \
-            or (self.dist_spec.process_group.size() == 1)
+            or (self.dist_spec.process_group.tp_world_size() == 1)

     def is_shard_1dcol(self):
         return self.dist_spec.placement == DistPlacementPattern.SHARD \
...
tests/test_ddp/test_ddp_ignore_params.py

...
@@ -15,6 +15,7 @@ import torch.distributed as dist
 import os
 import random
 import numpy as np
+from colossalai.tensor import ProcessGroup


 def set_seed(seed):
...
@@ -27,14 +28,16 @@ def set_seed(seed):


 def init_ddp(module: torch.nn.Module) -> ColoDDP:
-    return ColoDDP(module)
+    pg = ProcessGroup()
+    return ColoDDP(module, process_group=pg)


 def init_ddpv2(module: torch.nn.Module, use_chunk: bool = False) -> ZeroDDP:
     chunk_size = ChunkManager.search_chunk_size(module, 64, 2) if use_chunk else None
     chunk_manager = ChunkManager(chunk_size)
     gemini_manager = GeminiManager('cuda', chunk_manager)
-    return ZeroDDP(module, gemini_manager)
+    pg = ProcessGroup()
+    return ZeroDDP(module, gemini_manager, pg)


 class Net(torch.nn.Module):
...
tests/test_ddp/test_ddp_state_dict.py

...
@@ -13,6 +13,7 @@ from colossalai.nn.parallel import ZeroDDP, ColoDDP
 from colossalai.gemini.gemini_mgr import GeminiManager
 from typing import Callable
 from collections import OrderedDict
+from colossalai.tensor import ProcessGroup


 def check_state_dict_equal(state_dict: OrderedDict, other_state_dict: OrderedDict):
...
@@ -22,14 +23,16 @@ def check_state_dict_equal(state_dict: OrderedDict, other_state_dict: OrderedDic


 def init_ddp(module: torch.nn.Module) -> ColoDDP:
-    return ColoDDP(module)
+    pg = ProcessGroup()
+    return ColoDDP(module, process_group=pg)


 def init_ddpv2(module: torch.nn.Module, use_chunk: bool = False, use_zero: bool = False) -> ZeroDDP:
     chunk_size = ChunkManager.search_chunk_size(module, 64, 4) if use_chunk else None
     chunk_manager = ChunkManager(chunk_size, enable_distributed_storage=use_zero)
     gemini_manager = GeminiManager('cuda', chunk_manager)
-    return ZeroDDP(module, gemini_manager)
+    pg = ProcessGroup()
+    return ZeroDDP(module, gemini_manager, process_group=pg)


 def run_state_dict(ddp_init_func: Callable[[torch.nn.Module], ColoDDP]):
...
tests/test_tensor/_utils/_util.py

...
@@ -41,7 +41,7 @@ def tensor_equal(A, B):
     return torch.allclose(A, B, rtol=1e-3, atol=1e-1)


-def tensor_shard_equal(tensor: torch.Tensor, shard: torch.Tensor):
+def tensor_shard_equal(tensor: torch.Tensor, shard: torch.Tensor, rank, world_size):
     assert tensor.ndim == shard.ndim
     if tensor.shape == shard.shape:
         return tensor_equal(tensor, shard)
...
@@ -50,8 +50,10 @@ def tensor_shard_equal(tensor: torch.Tensor, shard: torch.Tensor):
         if dims_not_eq.numel() == 1:
             # 1D shard
             dim = dims_not_eq.item()
-            world_size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
-            rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
+            if world_size is None:
+                world_size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
+            if rank is None:
+                rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
             return tensor_equal(tensor.chunk(world_size, dim)[rank], shard)
         else:
             raise NotImplementedError
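The helper's new signature shifts the shard coordinates to the caller, which the updated addmm test supplies from the ProcessGroup (`pg.tp_local_rank()`, `pg.tp_world_size()`). An illustrative, single-process check of the comparison it performs (shapes and values invented for the example):

```python
# Illustration of what tensor_shard_equal(full, shard, rank, world_size) compares.
import torch

full = torch.arange(8.).reshape(2, 4)           # unsharded reference tensor
world_size, rank = 2, 1
shard = full.chunk(world_size, dim=-1)[rank]    # this rank's piece under a last-dim shard

# tensor_shard_equal chunks `full` along the mismatched dim and compares that chunk to `shard`.
assert torch.allclose(full.chunk(world_size, dim=-1)[rank], shard)
```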
tests/test_tensor/test_addmm_tp.py

...
@@ -3,14 +3,12 @@ import torch
 import pytest
 import torch.nn as nn
 import torch.multiprocessing as mp
-from colossalai.tensor import ColoTensor
+from colossalai.tensor import ColoTensor, ProcessGroup
 from colossalai.tensor import distspec
 from colossalai.tensor import TensorSpec, ComputePattern, ComputeSpec, DistSpecManager
-from colossalai.context import ParallelMode
 from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.utils import free_port
 from functools import partial
-from colossalai.core import global_context as gpc
 from _utils import tensor_shard_equal, tensor_equal
...
@@ -38,18 +36,14 @@ class Conv1D(nn.Module):
         return x


-def init_1d_row(weight, bias):
+def init_1d_row(weight, bias, pg: ProcessGroup):
     spec = TensorSpec(
-        distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
-        ComputeSpec(ComputePattern.TP1D))
+        distspec.shard(pg, [0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         weight.set_tensor_spec(spec)


-def init_1d_col(weight, bias):
+def init_1d_col(weight, bias, pg: ProcessGroup):
     spec = TensorSpec(
-        distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
-        ComputeSpec(ComputePattern.TP1D))
+        distspec.shard(pg, [-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         weight.set_tensor_spec(spec)
         bias.set_tensor_spec(spec)
...
@@ -59,7 +53,9 @@ def run_with_spec(spec_init_func):
     model = Conv1D(4, 16).cuda()
     weight = ColoTensor(torch.nn.Parameter(model.weight.detach()))
     bias = ColoTensor(torch.nn.Parameter(model.bias.detach()))
-    spec_init_func(weight, bias)
+    world_size = torch.distributed.get_world_size()
+    pg = ProcessGroup(tp_degree=world_size)
+    spec_init_func(weight, bias, pg)
     x = torch.rand(2, 16).cuda()
     out = model(x)
     colo_out = torch.addmm(bias, x, weight)
...
@@ -68,13 +64,12 @@ def run_with_spec(spec_init_func):
     grad = torch.rand_like(out)
     out.backward(grad)
     colo_out.backward(grad)
-    tensor_shard_equal(model.weight.grad, weight.grad)
-    tensor_shard_equal(model.bias.grad, bias.grad)
+    tensor_shard_equal(model.weight.grad, weight.grad, pg.tp_local_rank(), pg.tp_world_size())
+    tensor_shard_equal(model.bias.grad, bias.grad, pg.tp_local_rank(), pg.tp_world_size())


 def run_dist(rank, world_size, port):
-    config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),))
-    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     run_with_spec(init_1d_row)
     run_with_spec(init_1d_col)
...