OpenDAS / ColossalAI · Commits

Commit 060b917d (unverified)
[refactor] remove gpc dependency in colotensor's _ops (#1189)
Authored by Jiarui Fang on Jul 04, 2022; committed via GitHub on Jul 04, 2022
Parent: abf6a262

Showing 20 changed files with 342 additions and 123 deletions (+342 -123)
colossalai/nn/_ops/_utils.py                     +185  -0
colossalai/nn/_ops/addmm.py                        +8  -12
colossalai/nn/_ops/embedding.py                    +9  -13
colossalai/nn/_ops/embedding_bag.py                +2  -3
colossalai/nn/_ops/layernorm.py                    +1  -1
colossalai/nn/_ops/linear.py                       +7  -9
colossalai/nn/parallel/data_parallel.py           +15  -10
colossalai/nn/parallel/layers/colo_module.py       +1  -1
colossalai/nn/parallel/layers/embedding.py         +6  -10
colossalai/nn/parallel/layers/linear.py            +8  -17
colossalai/nn/parallel/layers/module_utils.py     +12  -8
colossalai/tensor/colo_tensor.py                   +6  -0
colossalai/tensor/dist_spec_mgr.py                 +9  -7
colossalai/tensor/distspec.py                      +3  -3
colossalai/tensor/process_group.py                +41  -4
colossalai/tensor/tensor_spec.py                   +3  -2
tests/test_ddp/test_ddp_ignore_params.py           +5  -2
tests/test_ddp/test_ddp_state_dict.py              +5  -2
tests/test_tensor/_utils/_util.py                  +5  -3
tests/test_tensor/test_addmm_tp.py                +11  -16
colossalai/nn/_ops/_utils.py  (+185 -0)

import torch
from typing import Union, Optional
from colossalai.tensor import ColoTensor
import torch.distributed as dist
from colossalai.global_variables import tensor_parallel_env as env
from colossalai.nn.layer.utils import divide
from colossalai.tensor import ProcessGroup

GeneralTensor = Union[ColoTensor, torch.Tensor]
Number = Union[int, float]

@@ -10,3 +16,182 @@ def convert_to_colo_tensor(tensor: Optional[GeneralTensor]) -> Optional[ColoTensor]:
    if tensor is not None and not isinstance(tensor, ColoTensor):
        tensor = ColoTensor.from_torch_tensor(tensor)
    return tensor


def set_parallel_input(input_parallel: bool):
    env.parallel_input_1d = input_parallel


def get_parallel_input():
    return env.parallel_input_1d


def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank):
    index_f = rank * per_partition_vocab_size
    index_l = index_f + per_partition_vocab_size
    return index_f, index_l


def vocab_range_from_global_vocab_size(global_vocab_size, rank, world_size):
    per_partition_vocab_size = divide(global_vocab_size, world_size)
    return vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank)


def _reduce(input_, pg: ProcessGroup):
    # skip if only one rank involved
    if pg.tp_world_size() == 1:
        return input_
    assert input_.device.type == 'cuda'
    group = pg.tp_process_group()
    dist.all_reduce(input_, group=group)
    return input_


def _split(input_, pg: ProcessGroup, dim=-1):
    # skip if only one rank involved
    world_size = pg.tp_world_size()
    if world_size == 1:
        return input_

    # Split along last dimension.
    dim_size = input_.size(dim)
    assert dim_size % world_size == 0, \
        f'The dimension to split ({dim_size}) is not a multiple of world size ({world_size}), ' \
        f'cannot split tensor evenly'

    tensor_list = torch.split(input_, dim_size // world_size, dim=dim)
    rank = pg.tp_local_rank()
    output = tensor_list[rank].contiguous()
    return output


def _gather(input_, pg: ProcessGroup, dim=-1):
    # skip if only one rank involved
    world_size = pg.tp_world_size()
    if world_size == 1:
        return input_

    # all gather
    rank = pg.tp_local_rank()
    tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
    tensor_list[rank] = input_
    assert input_.device.type == 'cuda'
    group = pg.tp_process_group()
    torch.distributed.all_gather(tensor_list, input_, group=group)

    # concat
    output = torch.cat(tensor_list, dim=dim).contiguous()
    return output


class _ReduceGrad(torch.autograd.Function):
    """
    Pass the input to the model parallel region.

    Args:
        input_: input matrix.
        process_group: process group.
    """

    @staticmethod
    def symbolic(graph, input_):
        return input_

    @staticmethod
    def forward(ctx, input_, process_group):
        ctx.mode = process_group
        return input_

    @staticmethod
    def backward(ctx, grad_output):
        return _reduce(grad_output, ctx.mode), None


class _ReduceInput(torch.autograd.Function):
    """
    All-reduce the input from the model parallel region.

    Args:
        input_: input matrix.
        process_group: process group.
    """

    @staticmethod
    def symbolic(graph, input_):
        return _reduce(input_)

    @staticmethod
    def forward(ctx, input_, process_group):
        return _reduce(input_, process_group)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output, None


class _SplitForwardGatherBackward(torch.autograd.Function):
    """
    Split the input and keep only the chunk corresponding to the rank.

    Args:
        input_: input matrix.
        process_group: process group.
        dim: dimension
    """

    @staticmethod
    def symbolic(graph, input_):
        return _split(input_)

    @staticmethod
    def forward(ctx, input_, process_group, dim):
        ctx.mode = process_group
        ctx.dim = dim
        return _split(input_, process_group, dim)

    @staticmethod
    def backward(ctx, grad_output):
        return _gather(grad_output, ctx.mode, ctx.dim), None, None


class _GatherForwardSplitBackward(torch.autograd.Function):
    """Gather the input from the model parallel region and concatenate.

    Args:
        input_: input matrix.
        process_group: process group.
        dim: dimension
    """

    @staticmethod
    def symbolic(graph, input_):
        return _gather(input_)

    @staticmethod
    def forward(ctx, input_, process_group, dim):
        ctx.mode = process_group
        ctx.dim = dim
        return _gather(input_, process_group, dim)

    @staticmethod
    def backward(ctx, grad_output):
        return _split(grad_output, ctx.mode, ctx.dim), None, None


def reduce_grad(input_, process_group):
    return _ReduceGrad.apply(input_, process_group)


def reduce_input(input_, process_group):
    return _ReduceInput.apply(input_, process_group)


def split_forward_gather_backward(input_, process_group, dim):
    return _SplitForwardGatherBackward.apply(input_, process_group, dim)


def gather_forward_split_backward(input_, process_group, dim):
    return _GatherForwardSplitBackward.apply(input_, process_group, dim)
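
These helpers form two mirrored pairs of autograd functions: whatever collective runs in the forward pass is inverted in the backward pass, which is what keeps the sharded ops differentiable. A minimal single-process sketch of the split-forward/gather-backward pattern follows; it is illustrative only (not part of the ColossalAI API), with torch.split standing in for keeping this rank's chunk and a zero-padded torch.cat standing in for dist.all_gather:

import torch

class SplitForwardGatherBackwardLocal(torch.autograd.Function):
    """Single-process stand-in for _SplitForwardGatherBackward."""

    @staticmethod
    def forward(ctx, input_, world_size, rank, dim):
        ctx.world_size, ctx.rank, ctx.dim = world_size, rank, dim
        chunk = input_.size(dim) // world_size
        # forward: keep only this rank's chunk (the _split step)
        return torch.split(input_, chunk, dim=dim)[rank].contiguous()

    @staticmethod
    def backward(ctx, grad_output):
        # backward: the distributed version all-gathers grad_output from every
        # rank; locally we emulate the other ranks' slots with zeros.
        grads = [torch.zeros_like(grad_output) for _ in range(ctx.world_size)]
        grads[ctx.rank] = grad_output
        return torch.cat(grads, dim=ctx.dim), None, None, None

x = torch.randn(2, 8, requires_grad=True)
y = SplitForwardGatherBackwardLocal.apply(x, 4, 1, -1)    # keep "rank 1"'s chunk
y.sum().backward()
# gradient flows back only through the columns this rank kept
assert x.grad[:, 2:4].eq(1).all() and x.grad[:, :2].eq(0).all()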
colossalai/nn/_ops/addmm.py  (+8 -12)

 import torch
 from colossalai.tensor.op_wrapper import colo_op_impl
-from colossalai.nn.layer.parallel_1d._utils import reduce_input, reduce_grad
 from colossalai.tensor import ComputePattern, TensorSpec, ComputePattern, ComputeSpec, ColoTensor
 from colossalai.tensor import distspec
-from colossalai.context import ParallelMode
 from ._utils import GeneralTensor, Number, convert_to_colo_tensor
+from ._utils import reduce_input, reduce_grad


 def colo_addmm_1Drow(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTensor, beta: Number,
@@ -12,18 +11,16 @@ def colo_addmm_1Drow(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTensor, beta: Number,
     # mat1:S[1] x mat2:S[0] = Output:P
     # beta * input + alpha * All-Reduce(Output) = res
-    mat1 = mat1.convert_to_dist_spec(
-        distspec.shard(mat2.tensor_spec.get_process_group(), [-1], [mat2.tensor_spec.get_process_group_size()]))
+    mat1 = mat1.convert_to_dist_spec(distspec.shard(mat2.get_process_group(), [-1], [mat2.get_tp_world_size()]))

     # Output:P
     partial_output = torch.mm(mat1, mat2)
     # Reduce(Output)
-    output = reduce_input(partial_output, ParallelMode.PARALLEL_1D)
+    output = reduce_input(partial_output, mat1.get_process_group())
     # input
     assert not input_tensor.has_compute_spec(), 'Invalid input spec for 1Drow addmm op'
     output = beta * input_tensor + alpha * output
-    output = ColoTensor.from_torch_tensor(output, spec=TensorSpec(distspec.replicate(mat2.tensor_spec.get_process_group())))
+    output = ColoTensor.from_torch_tensor(output, spec=TensorSpec(distspec.replicate(mat2.get_process_group())))
     return output

@@ -31,12 +28,11 @@ def colo_addmm_1Dcol(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTensor, beta: Number,
                      alpha: Number) -> ColoTensor:
     # mat1:B x mat2:S[1] + input:S[1] = Output:S[1]
     compute_spec = mat2.tensor_spec.compute_spec
-    mat1 = mat1.convert_to_dist_spec(distspec.replicate(mat2.tensor_spec.get_process_group()))
-    mat1 = reduce_grad(mat1, ParallelMode.PARALLEL_1D)
+    mat1 = mat1.convert_to_dist_spec(distspec.replicate(mat2.get_process_group()))
+    mat1 = reduce_grad(mat1, mat1.get_process_group())

     output_parallel = torch.addmm(input_tensor, mat1, mat2, beta=beta, alpha=alpha)
-    output_spec = TensorSpec(distspec.shard(mat2.tensor_spec.get_process_group(), [-1], [mat2.tensor_spec.get_process_group_size()]),
+    output_spec = TensorSpec(distspec.shard(mat2.get_process_group(), [-1], [mat2.get_tp_world_size()]),
                              ComputeSpec(ComputePattern.TP1D))
     output = ColoTensor.from_torch_tensor(output_parallel, spec=output_spec)
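
The `mat1:S[1] x mat2:S[0] = Output:P` comment above is the block-matrix identity this op relies on: multiplying the column shards of mat1 by the matching row shards of mat2 yields partial products whose element-wise sum (the all-reduce) equals the full product. A single-process sanity check of that identity (shapes chosen arbitrarily for illustration):

import torch

world_size = 4
mat1 = torch.randn(2, 8)    # activation, sharded along dim 1 (S[1])
mat2 = torch.randn(8, 4)    # weight, sharded along dim 0 (S[0])

mat1_shards = torch.chunk(mat1, world_size, dim=1)
mat2_shards = torch.chunk(mat2, world_size, dim=0)

# each "rank" computes a partial product (Output:P) ...
partials = [a @ b for a, b in zip(mat1_shards, mat2_shards)]
# ... and the all-reduce is just the sum of the partials
reduced = torch.stack(partials).sum(dim=0)

assert torch.allclose(reduced, mat1 @ mat2, atol=1e-5)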
colossalai/nn/_ops/embedding.py  (+9 -13)

 import torch.nn.functional as F
 from typing import Optional
 from colossalai.tensor.op_wrapper import colo_op_impl
-from colossalai.nn.layer.parallel_1d._utils import reduce_input
-from colossalai.core import global_context as gpc
 from colossalai.tensor import ComputePattern, TensorSpec, ComputePattern, ComputeSpec, ColoTensor, distspec
-from colossalai.context import ParallelMode
-from ._utils import GeneralTensor, convert_to_colo_tensor
+from ._utils import GeneralTensor, convert_to_colo_tensor, reduce_input


 def colo_embedding_1Dcol(input_tensor: ColoTensor,
@@ -17,7 +14,7 @@ def colo_embedding_1Dcol(input_tensor: ColoTensor,
                          sparse: bool = False) -> ColoTensor:
     # embedding_1Dcol split the weight(lookup table) to (num_embeddings, embedding_dim/P)
     # Gather splitted lookup table
-    input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate(weight.tensor_spec.get_process_group()))
+    input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate(weight.get_process_group()))

     output_parallel = F.embedding(input_tensor,
                                   weight,
@@ -26,8 +23,7 @@ def colo_embedding_1Dcol(input_tensor: ColoTensor,
                                   norm_type=norm_type,
                                   scale_grad_by_freq=scale_grad_by_freq,
                                   sparse=sparse)
-    output_spec = TensorSpec(distspec.shard(weight.tensor_spec.get_process_group(), [-1], [weight.tensor_spec.get_process_group_size()]),
+    output_spec = TensorSpec(distspec.shard(weight.get_process_group(), [-1], [weight.get_tp_world_size()]),
                              ComputeSpec(ComputePattern.TP1D))
     output = ColoTensor.from_torch_tensor(output_parallel, spec=output_spec)

@@ -49,9 +45,10 @@ def colo_embedding_1Drow(input_tensor: ColoTensor,
     # embedding_1Drow split the weight(lookup table) to (num_embeddings/P, embedding_dim)
     # Find index in this shard and mask those not here
     # Reduce all
-    input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate(weight.tensor_spec.get_process_group()))
-    tensor_parallel_rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
+    input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate(weight.get_process_group()))
+    # tensor_parallel_rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
+    tensor_parallel_rank = weight.tensor_spec.dist_spec.process_group.tp_local_rank()
     num_embeddings_per_partition = weight.size_local(0)
     vocab_start_index = tensor_parallel_rank * num_embeddings_per_partition
     vocab_end_index = vocab_start_index + num_embeddings_per_partition

@@ -75,9 +72,8 @@ def colo_embedding_1Drow(input_tensor: ColoTensor,
     # Mask the output embedding.
     partial_output[input_mask, :] = 0.
     # Reduce across all the model parallel GPUs.
-    output = reduce_input(partial_output, ParallelMode.PARALLEL_1D)
-    output = ColoTensor.from_torch_tensor(output, spec=TensorSpec(distspec.replicate(weight.tensor_spec.get_process_group())))
+    output = reduce_input(partial_output, weight.get_process_group())
+    output = ColoTensor.from_torch_tensor(output, spec=TensorSpec(distspec.replicate(weight.get_process_group())))
     return output
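
colo_embedding_1Drow is correct because masked lookups compose under addition: each rank holds a contiguous vocabulary slice, zeroes the output rows for ids it does not own, and the all-reduce then sums exactly one non-zero contribution per token. A local simulation of the mask-then-reduce step (sizes are arbitrary illustration values):

import torch
import torch.nn.functional as F

vocab, dim, world_size = 8, 4, 2
weight = torch.randn(vocab, dim)
ids = torch.tensor([0, 3, 5, 7])

per_part = vocab // world_size
out = torch.zeros(len(ids), dim)
for rank in range(world_size):
    start, end = rank * per_part, (rank + 1) * per_part
    mask = (ids < start) | (ids >= end)               # ids owned by other ranks
    local_ids = torch.clamp(ids - start, 0, per_part - 1)
    partial = F.embedding(local_ids, weight[start:end])
    partial[mask, :] = 0.                             # mask foreign ids
    out += partial                                    # stands in for reduce_input

assert torch.allclose(out, F.embedding(ids, weight))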
colossalai/nn/_ops/embedding_bag.py  (+2 -3)

@@ -32,8 +32,7 @@ def colo_embedding_bag_1Dcol(input_tensor: ColoTensor,
                                       per_sample_weights=per_sample_weights,
                                       include_last_offset=include_last_offset,
                                       padding_idx=padding_idx)
-    output_spec = TensorSpec(distspec.shard(weight.tensor_spec.get_process_group(), [-1], [weight.tensor_spec.get_process_group_size()]),
+    output_spec = TensorSpec(distspec.shard(weight.get_process_group(), [-1], [weight.get_tp_world_size()]),
                              ComputeSpec(ComputePattern.TP1D))
     output = ColoTensor.from_torch_tensor(output_parallel, spec=output_spec)
colossalai/nn/_ops/layernorm.py  (+1 -1)

@@ -17,7 +17,7 @@ def colo_layernorm(
     input_tensor, weight, bias = tuple(map(convert_to_colo_tensor, (input_tensor, weight, bias)))

     # TODO (ver217): check dist spec
-    input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate(input_tensor.tensor_spec.get_process_group()))
+    input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate(input_tensor.get_process_group()))

     output = F.layer_norm(input_tensor, normalized_shape, weight=weight, bias=bias, eps=eps)
     output = ColoTensor.from_torch_tensor(output, input_tensor.tensor_spec)
colossalai/nn/_ops/linear.py  (+7 -9)

@@ -2,9 +2,8 @@ import torch.nn.functional as F
 from typing import Optional
 from ._utils import GeneralTensor, convert_to_colo_tensor
 from colossalai.tensor.op_wrapper import colo_op_impl
-from colossalai.nn.layer.parallel_1d._utils import reduce_input, reduce_grad
+from ._utils import reduce_input, reduce_grad
 from colossalai.tensor import ComputePattern, TensorSpec, ComputePattern, ComputeSpec, ColoTensor, distspec
-from colossalai.context import ParallelMode
 from colossalai.nn.graph import register_colo_graph, GraphOpNode, GraphGlobalEnv

@@ -13,19 +12,18 @@ def colo_linear_1Drow(input_tensor: ColoTensor, weight: ColoTensor, bias: Optional[ColoTensor]) -> ColoTensor:
     # All-Reduce(Output) + bias = res
     # Input:S[1]
     input_tensor = input_tensor.convert_to_dist_spec(
-        distspec.shard(weight.tensor_spec.get_process_group(), [-1], [weight.tensor_spec.get_process_group_size()]))
+        distspec.shard(weight.get_process_group(), [-1], [weight.get_tp_world_size()]))

     # Output:P
     partial_output = F.linear(input_tensor, weight)
     # Reduce(Output)
-    output = reduce_input(partial_output, ParallelMode.PARALLEL_1D)
+    output = reduce_input(partial_output, weight.get_process_group())
     # Bias
     if bias is not None:
         assert not bias.has_compute_spec(), 'Invalid bias spec for 1Drow Linear op'
         output = output + bias
-    output = ColoTensor.from_torch_tensor(output, spec=TensorSpec(distspec.replicate(weight.tensor_spec.get_process_group())))
+    output = ColoTensor.from_torch_tensor(output, spec=TensorSpec(distspec.replicate(weight.get_process_group())))
     return output

@@ -35,13 +33,13 @@ def colo_linear_1Dcol(input_tensor: ColoTensor, weight: ColoTensor, bias: Optional[ColoTensor]) -> ColoTensor:
     # Input:B
     compute_spec = weight.tensor_spec.compute_spec
     input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate(weight.tensor_spec.get_process_group()))
-    input_parallel = reduce_grad(input_tensor, ParallelMode.PARALLEL_1D)
+    input_parallel = reduce_grad(input_tensor, weight.tensor_spec.dist_spec.process_group)

     output_parallel = F.linear(input_parallel, weight, bias)
     output = ColoTensor.from_torch_tensor(output_parallel,
                                           spec=TensorSpec(
-                                              distspec.shard(weight.tensor_spec.get_process_group(), [-1], [weight.tensor_spec.get_process_group_size()]),
+                                              distspec.shard(weight.get_process_group(), [-1], [weight.get_tp_world_size()]),
                                               ComputeSpec(ComputePattern.TP1D)))
     if compute_spec.output_replicate:
         return output.to_replicate()
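
The 1Dcol path uses the complementary identity: a linear layer whose weight and bias are sharded along the output dimension produces output shards that simply concatenate to the full result, which is why the output spec is shard(..., [-1], ...). A local check with illustrative shapes:

import torch
import torch.nn.functional as F

world_size = 2
x = torch.randn(3, 8)
weight = torch.randn(6, 8)    # (out_features, in_features)
bias = torch.randn(6)

w_shards = torch.chunk(weight, world_size, dim=0)   # shard out_features
b_shards = torch.chunk(bias, world_size, dim=0)

# each "rank" computes its slice of the output; gathering is a concat on dim -1
parts = [F.linear(x, w, b) for w, b in zip(w_shards, b_shards)]
gathered = torch.cat(parts, dim=-1)

assert torch.allclose(gathered, F.linear(x, weight, bias))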
colossalai/nn/parallel/data_parallel.py  (+15 -10)

 import torch
 import itertools
 import torch.distributed as dist
-from colossalai.core import global_context as gpc
-from colossalai.context import ParallelMode
 from functools import partial
 from colossalai.zero.utils.zero_hook_v2 import ZeROHookV2
 from colossalai.gemini.chunk import TensorState, Chunk

@@ -12,6 +10,7 @@ from typing import Dict, Iterable, List, Optional
 from colossalai.logging import get_dist_logger
 from collections import OrderedDict
 from colossalai.tensor.colo_parameter import ColoParameter
+from colossalai.tensor import ProcessGroup as ColoProcessGroup
 from .reducer import Reducer

 try:
     from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, _IncompatibleKeys

@@ -45,8 +44,8 @@ class ColoDDP(torch.nn.Module):
     >>> from colossalai.core import global_context as gpc
     >>> from colossalai.context import ParallelMode
     >>> model = torch.nn.Linear(20, 1)
-    >>> model = ColoDDP(model)
-    >>> model = ColoDDP(model, process_group=gpc.get_group(ParallelMode.DATA), cpu_process_group=gpc.get_cpu_group(ParallelMode.DATA))
+    >>> pg = ProcessGroup(tp_degree=world_size // 2)
+    >>> model = ColoDDP(model, pg)
     >>> logits = model(x)
     >>> loss = criterion(logits, labels)
     >>> model.backward(loss)

@@ -55,13 +54,13 @@ class ColoDDP(torch.nn.Module):
         module (torch.nn.Module): Module to apply DDP.
         process_group (Optional[dist.ProcessGroup], optional): The process group which DDP uses.
             If it's None, the default data parallel group will be used. Defaults to None.
-        process_group (Optional[dist.ProcessGroup], optional): The process group which DDP uses for those parameters on CPU.
+        cpu_process_group (Optional[dist.ProcessGroup], optional): The process group which DDP uses for those parameters on CPU.
             If it's None, the default CPU data parallel group will be used. Defaults to None.
     """

     def __init__(self,
                  module: torch.nn.Module,
-                 process_group: Optional[dist.ProcessGroup] = None,
+                 process_group: ColoProcessGroup,
                  cpu_process_group: Optional[dist.ProcessGroup] = None,
                  bucket_cap_mb: int = 25,
                  rebuild_bucket: bool = True) -> None:

@@ -69,8 +68,9 @@ class ColoDDP(torch.nn.Module):
         super().__init__()
         self.module = module
         self.comm_stream: torch.cuda.Stream = torch.cuda.Stream()
-        self.process_group = process_group or gpc.get_group(ParallelMode.DATA)
-        self.cpu_process_group = cpu_process_group or gpc.get_cpu_group(ParallelMode.DATA)
+        assert process_group
+        self.process_group = process_group.dp_process_group()
+        self.dp_world_size = self.process_group.size()
         self.reducer = Reducer(bucket_cap_mb)
         self.rebuild_bucket = rebuild_bucket

@@ -120,6 +120,8 @@ class ColoDDP(torch.nn.Module):
             return empty_grad
         else:
+            # TODO(jiaruifang) fixme
+            raise NotImplementedError
             dist.all_reduce(grad, group=self.cpu_process_group)
             return grad

@@ -191,8 +193,11 @@ class ZeroDDP(ColoDDP):
         For more details, see the API reference of ``GeminiManager``.
     """

-    def __init__(self, module: torch.nn.Module, gemini_manager: GeminiManager) -> None:
-        super().__init__(module.half())
+    def __init__(self,
+                 module: torch.nn.Module,
+                 gemini_manager: GeminiManager,
+                 process_group: Optional[ColoProcessGroup] = None) -> None:
+        super().__init__(module.half(), process_group=process_group)
         self.gemini_manager = gemini_manager
         self.chunk_manager = gemini_manager.chunk_manager
         self.param_op_hook = ZeROHookV2(gemini_manager)
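
With this change the wrappers take their data-parallel communicator from a ColossalAI ProcessGroup instead of reading gpc at construction time. A hedged usage sketch, assuming colossalai.launch has already initialized torch.distributed (tp_degree=2 is an arbitrary example value):

import torch
from colossalai.tensor import ProcessGroup
from colossalai.nn.parallel import ColoDDP

# the world is split into DP x TP axes; ColoDDP reduces gradients over the
# DP axis it obtains via pg.dp_process_group()
world_size = torch.distributed.get_world_size()
pg = ProcessGroup(tp_degree=2)    # dp_degree is inferred as world_size // 2

model = ColoDDP(torch.nn.Linear(20, 1).cuda(), process_group=pg)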
colossalai/nn/parallel/layers/colo_module.py  (+1 -1)

@@ -52,5 +52,5 @@ class ColoModule(object):
     def get_param_names(self):
         return self._shard_params

-    def register(self, compute_pattern):
+    def register(self, compute_pattern, pg):
         raise NotImplementedError
colossalai/nn/parallel/layers/embedding.py  (+6 -10)

 from .colo_module import ColoModule
-from colossalai.tensor import ComputePattern, distspec
-from colossalai.core import global_context as gpc
-from colossalai.context.parallel_mode import ParallelMode
+from colossalai.tensor import ComputePattern, distspec, ProcessGroup

@@ -10,20 +10,18 @@ class ColoEmbedding(ColoModule):
         super(ColoEmbedding, self).__init__()
         self._register_shard_params(['weight'])

-    def register(self, compute_pattern):
+    def register(self, compute_pattern, pg: ProcessGroup):
         if not compute_pattern in self._allowed_patterns:
             if ComputePattern.TP1D == compute_pattern:
-                self._set_TP1D()
+                self._set_TP1D(pg)

-    def _set_TP1D(self):
+    def _set_TP1D(self, pg: ProcessGroup):
         # TP1D Row Linear
         _compute_pattern = ComputePattern.TP1D
         self._register_allowed_patterns(
             compute_pattern=_compute_pattern,
             dist_specs={
-                'weight': distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
+                'weight': distspec.shard(pg, [0], [pg.tp_world_size()]),
             },
             mode='row',
         )

@@ -32,9 +30,7 @@ class ColoEmbedding(ColoModule):
         self._register_allowed_patterns(
             compute_pattern=_compute_pattern,
             dist_specs={
-                'weight': distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
+                'weight': distspec.shard(pg, [-1], [pg.tp_world_size()]),
             },
             mode='col',
         )
colossalai/nn/parallel/layers/linear.py  (+8 -17)

 from .colo_module import ColoModule
-from colossalai.tensor import ComputePattern, distspec
-from colossalai.core import global_context as gpc
-from colossalai.context.parallel_mode import ParallelMode
+from colossalai.tensor import ComputePattern, distspec, ProcessGroup


 class ColoLinear(ColoModule):
@@ -10,22 +8,19 @@ class ColoLinear(ColoModule):
         super(ColoLinear, self).__init__()
         self._register_shard_params(['weight', 'bias'])

-    def register(self, compute_pattern):
+    def register(self, compute_pattern, pg: ProcessGroup):
         if not compute_pattern in self._allowed_patterns:
             if ComputePattern.TP1D == compute_pattern:
-                self._set_TP1D()
+                self._set_TP1D(pg)

-    def _set_TP1D(self):
+    def _set_TP1D(self, pg):
         # TP1D Row Linear
         _compute_pattern = ComputePattern.TP1D
         self._register_allowed_patterns(
             compute_pattern=_compute_pattern,
             dist_specs={
-                'weight': distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
-                'bias': None
+                'weight': distspec.shard(pg, [-1], [pg.tp_world_size()]),
+                'bias': None
             },
             mode='row',
         )

@@ -34,12 +29,8 @@ class ColoLinear(ColoModule):
         self._register_allowed_patterns(
             compute_pattern=_compute_pattern,
             dist_specs={
-                'weight': distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
-                'bias': distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)])
+                'weight': distspec.shard(pg, [0], [pg.tp_world_size()]),
+                'bias': distspec.shard(pg, [0], [pg.tp_world_size()])
             },
             mode='col',
         )
colossalai/nn/parallel/layers/module_utils.py  (+12 -8)

 from typing import Dict
-from colossalai.tensor import ColoParameter, ComputeSpec, TensorSpec
+from colossalai.tensor import ColoParameter, ComputeSpec, TensorSpec, ProcessGroup
 from . import ColoModule
 import torch

@@ -29,7 +29,7 @@ def get_colo_module(module: torch.nn.Module):
         return None

-def check_colo_module(module: torch.nn.Module, recursive=True):
+def check_colo_module(module: torch.nn.Module, pg: ProcessGroup, recursive=True):
     if is_colo_module(module):
         colo_module = get_colo_module(module)
         param_names = colo_module.get_param_names()

@@ -50,7 +50,7 @@ def check_colo_module(module: torch.nn.Module, pg: ProcessGroup, recursive=True):
                 continue
         if compute_pattern is not None:
-            colo_module.register(compute_pattern)
+            colo_module.register(compute_pattern, pg)
             if not colo_module.has_compute_pattern(compute_pattern):
                 raise Exception(f'Invalid ColoParameter spec: ComputePattern {compute_pattern} in {module} is not allowed.')

@@ -76,16 +76,20 @@ def check_colo_module(module: torch.nn.Module, pg: ProcessGroup, recursive=True):
                 raise Exception(f'Invalid ColoParameter spec: Params in {module} are incorrectly sharded.')
     if recursive == True:
         for submodule in module.children():
-            check_colo_module(submodule, recursive=True)
+            check_colo_module(submodule, pg=pg, recursive=True)


-def init_colo_module(module: torch.nn.Module, compute_spec: ComputeSpec, recursive=True, mode='default'):
+def init_colo_module(module: torch.nn.Module, compute_spec: ComputeSpec, pg: ProcessGroup, recursive=True, mode='default'):
     compute_pattern = compute_spec.compute_pattern
     if is_colo_module(module):
         # for each param
         # set DistSpec and ComputeSpec
         colo_module = get_colo_module(module)
-        colo_module.register(compute_pattern)
+        colo_module.register(compute_pattern, pg)
         if not colo_module.has_compute_pattern_with_mode(compute_pattern, mode=mode):
             raise NotImplementedError
         # a set for modules which update at least one param in the init process.

@@ -101,7 +105,7 @@ def init_colo_module(module: torch.nn.Module, compute_spec: ComputeSpec, pg: ProcessGroup, recursive=True, mode='default'):
             for mod in param.shared_param_modules:
                 modules_update_param.add(mod)
         for mod in modules_update_param:
-            check_colo_module(mod, recursive=False)
+            check_colo_module(mod, pg, recursive=False)
     if recursive == True:
         for submodule in module.children():
-            init_colo_module(submodule, compute_spec, recursive=True, mode=mode)
+            init_colo_module(submodule, compute_spec, pg=pg, recursive=True, mode=mode)
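
The ProcessGroup now threads through the whole chain init_colo_module -> register -> _set_TP1D, so the sharding group is an explicit argument rather than the global PARALLEL_1D group. A hedged sketch of the updated call, assuming an initialized distributed context and that the package re-exports init_colo_module (the model here is a placeholder):

import torch
from colossalai.tensor import ComputePattern, ComputeSpec, ProcessGroup
from colossalai.nn.parallel.layers import init_colo_module

pg = ProcessGroup(tp_degree=4)    # example degree; must divide the world size
spec = ComputeSpec(ComputePattern.TP1D)

model = torch.nn.Linear(16, 16)   # any module tree containing ColoModules
# every registered ColoModule under `model` now shards against `pg`
init_colo_module(model, spec, pg, recursive=True, mode='row')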
colossalai/tensor/colo_tensor.py  (+6 -0)

@@ -78,6 +78,12 @@ class ColoTensor(torch.Tensor):
     def is_model_data(self) -> bool:
         return self._type == TensorType.MODEL

+    def get_process_group(self) -> 'ProcessGroup':
+        return self._tensor_spec.dist_spec.process_group
+
+    def get_tp_world_size(self) -> int:
+        return self._tensor_spec.dist_spec.process_group.tp_world_size()
+
     @classmethod
     def __torch_function__(cls, func, types, args=(), kwargs=None):
         if kwargs is None:
colossalai/tensor/dist_spec_mgr.py  (+9 -7)

@@ -5,6 +5,7 @@ from contextlib import contextmanager
 import torch
 import torch.distributed as dist
+from packaging import version
 from colossalai.logging import get_dist_logger

 # TODO(jiaruifang) circle import, move the divide to colossalai.commons.

@@ -64,7 +65,7 @@ class DistSpecManager:
         DistSpecManager._sanity_check(old_dist_spec, dist_spec)

         chunk = tensor
-        idx = dist_spec.process_group.rank()
+        idx = dist_spec.process_group.tp_local_rank()
         num_parts = prod(dist_spec.num_partitions)
         for i, dim in enumerate(dist_spec.dims):
             num_parts //= dist_spec.num_partitions[i]

@@ -91,8 +92,9 @@ class DistSpecManager:
         saved_dev = tensor.device
         tensor.data = tensor.data.cuda()

-        buffer = [torch.empty_like(tensor) for _ in range(old_dist_spec.process_group.size())]
-        dist.all_gather(buffer, tensor, group=old_dist_spec.process_group)
+        buffer = [torch.empty_like(tensor) for _ in range(old_dist_spec.process_group.tp_world_size())]
+        assert tensor.device.type == 'cuda'
+        dist.all_gather(buffer, tensor, group=old_dist_spec.process_group.tp_process_group())
         for i in range(len(old_dist_spec.dims) - 1, -1, -1):
             new_buffer = []
             dim = old_dist_spec.dims[i]

@@ -108,14 +110,14 @@ class DistSpecManager:
     @staticmethod
     def _all_to_all(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec) -> torch.Tensor:
-        world_size = old_dist_spec.process_group.size()
+        world_size = old_dist_spec.process_group.tp_world_size()
         if world_size == 1:
             return tensor

-        assert tensor.device.type == "cuda" and dist.get_backend(old_dist_spec.process_group) == "nccl", \
+        assert tensor.device.type == "cuda" and old_dist_spec.process_group.backend == "nccl", \
             "Currently, only CUDA Tensor with NCCL backend is supported for the requested AlltoAll " \
             f"collective function, however, we got {tensor.device.type} device and " \
-            f"{dist.get_backend(old_dist_spec.process_group)} backend"
+            f"{old_dist_spec.process_group.backend} backend"

         gather_dim = old_dist_spec.dims[0]
         scatter_dim = dist_spec.dims[0]

@@ -126,7 +128,7 @@ class DistSpecManager:
         scatter_list = [t.contiguous() for t in torch.tensor_split(tensor, world_size, scatter_dim)]
         gather_list = [torch.empty(*shapes, dtype=tensor.dtype, device=tensor.device) for _ in range(world_size)]
-        dist.all_to_all(gather_list, scatter_list, group=old_dist_spec.process_group)
+        dist.all_to_all(gather_list, scatter_list, group=old_dist_spec.process_group.tp_process_group())

         output_ = torch.cat(gather_list, dim=gather_dim).contiguous()
         assert output_.shape[scatter_dim] == scattered_dim_size and output_.shape[gather_dim] == gathered_dim_size
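
_all_to_all re-shards in a single collective: every rank splits its shard along the new (scatter) dim, sends piece j to rank j, and concatenates what it receives along the old (gather) dim. The permutation can be checked locally by performing the split/concat per simulated rank (illustrative only):

import torch

world_size = 2
full = torch.arange(16.).reshape(4, 4)

# start row-sharded (dim 0) and reshard to column shards (dim 1)
row_shards = torch.chunk(full, world_size, dim=0)

col_shards = []
for rank in range(world_size):
    # each peer splits its row shard along the new dim and "sends" piece `rank`;
    # the receiver concatenates the received pieces along the old dim
    received = [torch.chunk(shard, world_size, dim=1)[rank] for shard in row_shards]
    col_shards.append(torch.cat(received, dim=0))

expected = torch.chunk(full, world_size, dim=1)
assert all(torch.equal(c, e) for c, e in zip(col_shards, expected))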
colossalai/tensor/distspec.py  (+3 -3)

 from enum import Enum
-from torch.distributed import ProcessGroup
+from colossalai.tensor import ProcessGroup
 from typing import Optional, List
 from numpy import prod

@@ -51,8 +51,8 @@ def replicate(process_group: Optional[ProcessGroup] = None) -> _DistSpec:
 def shard(process_group: ProcessGroup, dims: List[int], num_partitions: List[int]) -> _DistSpec:
-    assert process_group is not None
+    assert process_group is not None and isinstance(process_group, ProcessGroup)
     assert isinstance(dims, list) and isinstance(num_partitions, list)
     assert len(dims) == len(num_partitions)
-    assert prod(num_partitions) == process_group.size(), f"{num_partitions} {process_group.size()}"
+    assert prod(num_partitions) == process_group.tp_world_size(), f"{num_partitions} {process_group.tp_world_size()}"
     return _DistSpec(DistPlacementPattern.SHARD, process_group, dims=tuple(dims), num_partitions=tuple(num_partitions))
colossalai/tensor/process_group.py  (+41 -4)

 import torch
 from typing import List, Optional
+from colossalai.logging import get_dist_logger


 class ProcessGroup:

@@ -41,12 +42,12 @@ class ProcessGroup:
         if dp_degree and not tp_degree:
             self._dp_degree = dp_degree
             assert self._world_size % self._dp_degree == 0, f"world size {self._world_size} should be divisible by DP degree {dp_degree} when TP degree is None"
-            self._tp_degree = self._world_size / dp_degree
+            self._tp_degree = self._world_size // dp_degree

         if not dp_degree and tp_degree:
             self._tp_degree = tp_degree
             assert self._world_size % self._tp_degree == 0, f"world size {self._world_size} should be divisible by TP degree {tp_degree} when DP degree is None"
-            self._dp_degree = self._world_size / tp_degree
+            self._dp_degree = self._world_size // tp_degree

         self._tp_rank_list = []
         self._dp_rank_list = []

@@ -58,12 +59,48 @@ class ProcessGroup:
             if rank_id // self._tp_degree == self._rank // self._tp_degree:
                 self._tp_rank_list.append(rank_id)

-        self._tp_process_group = torch.distributed.new_group(ranks=self._tp_rank_list, backend=backend)
-        self._dp_process_group = torch.distributed.new_group(ranks=self._dp_rank_list, backend=backend)
+        assert backend == 'nccl'
+        self._tp_process_group = torch.distributed.new_group(ranks=self._tp_rank_list)
+        self._dp_process_group = torch.distributed.new_group(ranks=self._dp_rank_list)
+
+        self.logger = get_dist_logger('ProcessGroup')
+        self.logger.info(f'{self._rank} initialized TP group on {self._tp_rank_list}, DP group on {self._dp_rank_list}')
+
+    @property
+    def backend(self):
+        return self._backend
+
+    def __eq__(self, obj: 'ProcessGroup') -> bool:
+        if not isinstance(obj, ProcessGroup):
+            return False
+        if self._rank != obj._rank:
+            assert False
+        if self._rank_list != obj._rank_list:
+            assert False
+        if self._tp_rank_list != obj._tp_rank_list:
+            assert False
+        if self._dp_rank_list != obj._dp_rank_list:
+            assert False
+        if self._backend != obj._backend:
+            assert False
+        if self._tp_degree != obj._tp_degree:
+            return False
+        if self._dp_degree != obj._dp_degree:
+            return False
+        return True

     def rank(self):
         return self._rank

     def world_size(self):
         return self._world_size

     def tp_local_rank(self):
         return self._rank % self._tp_degree

     def dp_local_rank(self):
         return self._rank // self._tp_degree

     def dp_world_size(self):
         return len(self._dp_rank_list)
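
The rank bookkeeping above is plain integer arithmetic: with ranks laid out TP-major, each consecutive block of tp_degree ranks forms a TP group and ranks at stride tp_degree form a DP group, so tp_local_rank = rank % tp_degree and dp_local_rank = rank // tp_degree. A pure-Python check of those invariants (no torch needed):

world_size, tp_degree = 8, 2
dp_degree = world_size // tp_degree

for rank in range(world_size):
    # same construction as the loop above: ranks in my tp_degree-sized block
    tp_ranks = [r for r in range(world_size) if r // tp_degree == rank // tp_degree]
    # ranks that share my position inside the block
    dp_ranks = [r for r in range(world_size) if r % tp_degree == rank % tp_degree]
    tp_local, dp_local = rank % tp_degree, rank // tp_degree
    assert tp_ranks[tp_local] == rank and dp_ranks[dp_local] == rank

# e.g. rank 5 -> TP group [4, 5], DP group [1, 3, 5, 7]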
colossalai/tensor/tensor_spec.py  (+3 -2)

@@ -17,11 +17,12 @@ class TensorSpec(object):
         self.compute_spec = compute_spec
         self.dist_spec = dist_spec

+    # TODO(jiaruifang) actually need tp process group
     def get_process_group(self):
         return self.dist_spec.process_group

     def get_process_group_size(self):
-        return dist.get_world_size(self.dist_spec.process_group)
+        return dist.get_world_size(self.dist_spec.process_group.tp_process_group())

     def get_placement(self):
         return self.dist_spec.placement

@@ -30,7 +31,7 @@ class TensorSpec(object):
         return self.dist_spec.placement == DistPlacementPattern.REPLICATE \
             or (len(self.dist_spec.num_partitions) == 1 and self.dist_spec.num_partitions[0] == 1) \
-            or (self.dist_spec.process_group.size() == 1)
+            or (self.dist_spec.process_group.tp_world_size() == 1)

     def is_shard_1dcol(self):
         return self.dist_spec.placement == DistPlacementPattern.SHARD \
tests/test_ddp/test_ddp_ignore_params.py  (+5 -2)

@@ -15,6 +15,7 @@ import torch.distributed as dist
 import os
 import random
 import numpy as np
+from colossalai.tensor import ProcessGroup


 def set_seed(seed):

@@ -27,14 +28,16 @@ def set_seed(seed):
 def init_ddp(module: torch.nn.Module) -> ColoDDP:
-    return ColoDDP(module)
+    pg = ProcessGroup()
+    return ColoDDP(module, process_group=pg)


 def init_ddpv2(module: torch.nn.Module, use_chunk: bool = False) -> ZeroDDP:
     chunk_size = ChunkManager.search_chunk_size(module, 64, 2) if use_chunk else None
     chunk_manager = ChunkManager(chunk_size)
     gemini_manager = GeminiManager('cuda', chunk_manager)
-    return ZeroDDP(module, gemini_manager)
+    pg = ProcessGroup()
+    return ZeroDDP(module, gemini_manager, pg)


 class Net(torch.nn.Module):
tests/test_ddp/test_ddp_state_dict.py  (+5 -2)

@@ -13,6 +13,7 @@ from colossalai.nn.parallel import ZeroDDP, ColoDDP
 from colossalai.gemini.gemini_mgr import GeminiManager
 from typing import Callable
 from collections import OrderedDict
+from colossalai.tensor import ProcessGroup


 def check_state_dict_equal(state_dict: OrderedDict, other_state_dict: OrderedDict):

@@ -22,14 +23,16 @@ def check_state_dict_equal(state_dict: OrderedDict, other_state_dict: OrderedDict):
 def init_ddp(module: torch.nn.Module) -> ColoDDP:
-    return ColoDDP(module)
+    pg = ProcessGroup()
+    return ColoDDP(module, process_group=pg)


 def init_ddpv2(module: torch.nn.Module, use_chunk: bool = False, use_zero: bool = False) -> ZeroDDP:
     chunk_size = ChunkManager.search_chunk_size(module, 64, 4) if use_chunk else None
     chunk_manager = ChunkManager(chunk_size, enable_distributed_storage=use_zero)
     gemini_manager = GeminiManager('cuda', chunk_manager)
-    return ZeroDDP(module, gemini_manager)
+    pg = ProcessGroup()
+    return ZeroDDP(module, gemini_manager, process_group=pg)


 def run_state_dict(ddp_init_func: Callable[[torch.nn.Module], ColoDDP]):
tests/test_tensor/_utils/_util.py  (+5 -3)

@@ -41,7 +41,7 @@ def tensor_equal(A, B):
     return torch.allclose(A, B, rtol=1e-3, atol=1e-1)


-def tensor_shard_equal(tensor: torch.Tensor, shard: torch.Tensor):
+def tensor_shard_equal(tensor: torch.Tensor, shard: torch.Tensor, rank, world_size):
     assert tensor.ndim == shard.ndim
     if tensor.shape == shard.shape:
         return tensor_equal(tensor, shard)

@@ -50,7 +50,9 @@ def tensor_shard_equal(tensor: torch.Tensor, shard: torch.Tensor, rank, world_size):
     if dims_not_eq.numel() == 1:
         # 1D shard
         dim = dims_not_eq.item()
+        if world_size is None:
+            world_size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
+        if rank is None:
+            rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
         return tensor_equal(tensor.chunk(world_size, dim)[rank], shard)
     else:
tests/test_tensor/test_addmm_tp.py  (+11 -16)

@@ -3,14 +3,12 @@ import torch
 import pytest
 import torch.nn as nn
 import torch.multiprocessing as mp
-from colossalai.tensor import ColoTensor
+from colossalai.tensor import ColoTensor, ProcessGroup
 from colossalai.tensor import distspec
 from colossalai.tensor import TensorSpec, ComputePattern, ComputeSpec, DistSpecManager
-from colossalai.context import ParallelMode
 from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.utils import free_port
 from functools import partial
-from colossalai.core import global_context as gpc
 from _utils import tensor_shard_equal, tensor_equal

@@ -38,18 +36,14 @@ class Conv1D(nn.Module):
         return x


-def init_1d_row(weight, bias):
-    spec = TensorSpec(
-        distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
-        ComputeSpec(ComputePattern.TP1D))
+def init_1d_row(weight, bias, pg: ProcessGroup):
+    spec = TensorSpec(distspec.shard(pg, [0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         weight.set_tensor_spec(spec)


-def init_1d_col(weight, bias):
-    spec = TensorSpec(
-        distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
-        ComputeSpec(ComputePattern.TP1D))
+def init_1d_col(weight, bias, pg: ProcessGroup):
+    spec = TensorSpec(distspec.shard(pg, [-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         weight.set_tensor_spec(spec)
         bias.set_tensor_spec(spec)

@@ -59,7 +53,9 @@ def run_with_spec(spec_init_func):
     model = Conv1D(4, 16).cuda()
     weight = ColoTensor(torch.nn.Parameter(model.weight.detach()))
     bias = ColoTensor(torch.nn.Parameter(model.bias.detach()))
-    spec_init_func(weight, bias)
+    world_size = torch.distributed.get_world_size()
+    pg = ProcessGroup(tp_degree=world_size)
+    spec_init_func(weight, bias, pg)
     x = torch.rand(2, 16).cuda()
     out = model(x)
     colo_out = torch.addmm(bias, x, weight)

@@ -68,13 +64,12 @@ def run_with_spec(spec_init_func):
     grad = torch.rand_like(out)
     out.backward(grad)
     colo_out.backward(grad)
-    tensor_shard_equal(model.weight.grad, weight.grad)
-    tensor_shard_equal(model.bias.grad, bias.grad)
+    tensor_shard_equal(model.weight.grad, weight.grad, pg.tp_local_rank(), pg.tp_world_size())
+    tensor_shard_equal(model.bias.grad, bias.grad, pg.tp_local_rank(), pg.tp_world_size())


 def run_dist(rank, world_size, port):
-    config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),))
-    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     run_with_spec(init_1d_row)
     run_with_spec(init_1d_col)