OpenDAS / ColossalAI · Commits · 9bcd2fd4

Unverified commit 9bcd2fd4, authored Jul 11, 2022 by Jiarui Fang, committed by GitHub on Jul 11, 2022.

[tensor] a shorter shard and replicate spec (#1245)

Parent: 2699dfbb
Changes: 25 files in total; this page shows 20 changed files with 77 additions and 84 deletions (+77, -84). The remaining files are on the next page.
Changed files on this page:

colossalai/fx/passes/shard_1d_pass.py (+2, -3)
colossalai/nn/_ops/addmm.py (+6, -5)
colossalai/nn/_ops/embedding.py (+7, -5)
colossalai/nn/_ops/embedding_bag.py (+3, -4)
colossalai/nn/_ops/layernorm.py (+2, -2)
colossalai/nn/_ops/linear.py (+7, -6)
colossalai/nn/parallel/layers/colo_module.py (+0, -10)
colossalai/nn/parallel/layers/embedding.py (+3, -5)
colossalai/nn/parallel/layers/linear.py (+4, -4)
colossalai/tensor/__init__.py (+4, -1)
colossalai/tensor/colo_tensor.py (+10, -9)
colossalai/utils/model/colo_init_context.py (+1, -1)
tests/test_tensor/test_addmm_tp.py (+3, -3)
tests/test_tensor/test_dist_spec_mgr.py (+7, -8)
tests/test_tensor/test_embedding_bag_tp.py (+2, -2)
tests/test_tensor/test_embedding_tp.py (+3, -3)
tests/test_tensor/test_gpt.py (+3, -3)
tests/test_tensor/test_linear_tp.py (+3, -3)
tests/test_tensor/test_loss_func.py (+2, -2)
tests/test_tensor/test_model.py (+5, -5)
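Across all of these files the commit makes one mechanical change: dist specs spelled through the distspec module (distspec.shard(...) and distspec.replicate()) are rewritten to the new, shorter ShardSpec(...) and ReplicaSpec() names exported from colossalai.tensor. A minimal before/after sketch, using the same positional arguments that appear throughout the diffs below (the 4-way sharding is only an illustration):

```python
# Old spelling: dist specs come from the distspec module.
from colossalai.tensor import distspec

row_shard_old = distspec.shard([0], [4])    # shard dim 0 across 4 ranks
replica_old = distspec.replicate()

# New spelling introduced by this commit: the same factories, re-exported
# from colossalai.tensor under shorter names.
from colossalai.tensor import ShardSpec, ReplicaSpec

row_shard_new = ShardSpec([0], [4])
replica_new = ReplicaSpec()
```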
colossalai/fx/passes/shard_1d_pass.py (view file @ 9bcd2fd4)

```diff
  import torch
- from torch.fx.node import map_arg
- from colossalai.tensor import ColoTensorSpec, distspec, ProcessGroup, ComputeSpec, ComputePattern
+ from colossalai.tensor import ColoTensorSpec, distspec, ProcessGroup, ComputeSpec, ComputePattern, ShardSpec

  def weight_split(weight: torch.Tensor, dim: int) -> torch.nn.parameter.Parameter:
  ...
@@ -25,7 +24,7 @@ def weight_split(weight: torch.Tensor, dim: int) -> torch.nn.parameter.Parameter
      world_size = torch.distributed.get_world_size()
      pg = ProcessGroup(tp_degree=world_size)
-     spec = ColoTensorSpec(pg, distspec.shard([dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
+     spec = ColoTensorSpec(pg, ShardSpec([dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
      # As you has constructed a Spec, why not directly convert the tensor to ColoTensor.
      setattr(weight, "fx_attr", spec)
      return weight
  ...
```
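For reference, the spec the FX pass attaches to each split weight is the same three-part ColoTensorSpec as before, only with ShardSpec spelled directly. A condensed sketch of that step, assuming torch.distributed is already initialized (the function name here is illustrative, not part of the pass):

```python
import torch
from colossalai.tensor import (ColoTensorSpec, ComputePattern, ComputeSpec,
                               ProcessGroup, ShardSpec)

def tag_weight_for_1d_shard(weight: torch.Tensor, dim: int) -> torch.Tensor:
    # Mirror of weight_split() above: one tensor-parallel group over all ranks,
    # a 1D shard along `dim`, and the TP1D compute pattern, stored on the tensor.
    world_size = torch.distributed.get_world_size()
    pg = ProcessGroup(tp_degree=world_size)
    spec = ColoTensorSpec(pg,
                          ShardSpec([dim], [pg.tp_world_size()]),
                          ComputeSpec(ComputePattern.TP1D))
    setattr(weight, "fx_attr", spec)
    return weight
```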
colossalai/nn/_ops/addmm.py (view file @ 9bcd2fd4)

```diff
  import torch
  from colossalai.tensor.op_wrapper import colo_op_impl
  from colossalai.tensor import ComputePattern, ComputePattern, ComputeSpec, ColoTensor
- from colossalai.tensor import distspec, ColoTensorSpec
+ from colossalai.tensor import distspec, ColoTensorSpec, ShardSpec, ReplicaSpec
  from ._utils import GeneralTensor, Number, convert_to_colo_tensor
  from ._utils import reduce_input, reduce_grad
  ...
@@ -11,7 +11,8 @@ def colo_addmm_1Drow(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTenso
      # mat1:S[1] x mat2:S[0] = Output:P
      # beta * input + alpha * All-Reduce(Output) = res
-     mat1 = mat1.redistribute(distspec.shard([-1], [mat2.get_tp_world_size()]))
+     mat1 = mat1.redistribute(ShardSpec([-1], [mat2.get_tp_world_size()]))
      # Output:P
      partial_output = torch.mm(mat1, mat2)
  ...
@@ -20,7 +21,7 @@ def colo_addmm_1Drow(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTenso
      # input
      assert not input_tensor.has_compute_spec(), 'Invalid input spec for 1Drow addmm op'
      output = beta * input_tensor + alpha * output
-     output = ColoTensor.from_torch_tensor(output, spec=ColoTensorSpec(distspec.replicate()))
+     output = ColoTensor.from_torch_tensor(output, spec=ColoTensorSpec(ReplicaSpec()))
      return output
  ...
@@ -28,11 +29,11 @@ def colo_addmm_1Dcol(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTenso
                       alpha: Number) -> ColoTensor:
      # mat1:B x mat2:S[1] + input:S[1] = Output:S[1]
      compute_spec = mat2.compute_spec
-     mat1 = mat1.redistribute(distspec.replicate())
+     mat1 = mat1.redistribute(ReplicaSpec())
      mat1 = reduce_grad(mat1, mat1.get_process_group())
      output_parallel = torch.addmm(input_tensor, mat1, mat2, beta=beta, alpha=alpha)
-     output_spec = ColoTensorSpec(input_tensor.get_process_group(), distspec.shard([-1], [mat2.get_tp_world_size()]),
+     output_spec = ColoTensorSpec(input_tensor.get_process_group(), ShardSpec([-1], [mat2.get_tp_world_size()]),
                                   ComputeSpec(ComputePattern.TP1D))
      output = ColoTensor.from_torch_tensor(output_parallel, spec=output_spec)
  ...
```
colossalai/nn/_ops/embedding.py (view file @ 9bcd2fd4)

```diff
  import torch.nn.functional as F
  from typing import Optional
  from colossalai.tensor.op_wrapper import colo_op_impl
- from colossalai.tensor import ComputePattern, ColoTensorSpec, ComputePattern, ComputeSpec, ColoTensor, distspec
+ from colossalai.tensor import ComputePattern, ColoTensorSpec, ComputePattern, ComputeSpec, ColoTensor, ShardSpec, ReplicaSpec
  from ._utils import GeneralTensor, convert_to_colo_tensor, reduce_input
  ...
@@ -14,7 +14,8 @@ def colo_embedding_1Dcol(input_tensor: ColoTensor,
                           sparse: bool = False) -> ColoTensor:
      # embedding_1Dcol split the weight(lookup table) to (num_embeddings, embedding_dim/P)
      # Gather splitted lookup table
-     input_tensor = input_tensor.redistribute(distspec.replicate())
+     input_tensor = input_tensor.redistribute(ReplicaSpec())
      output_parallel = F.embedding(input_tensor,
                                    weight,
  ...
@@ -23,7 +24,7 @@ def colo_embedding_1Dcol(input_tensor: ColoTensor,
                                    norm_type=norm_type,
                                    scale_grad_by_freq=scale_grad_by_freq,
                                    sparse=sparse)
-     output_spec = ColoTensorSpec(weight.get_process_group(), distspec.shard([-1], [weight.get_tp_world_size()]),
+     output_spec = ColoTensorSpec(weight.get_process_group(), ShardSpec([-1], [weight.get_tp_world_size()]),
                                   ComputeSpec(ComputePattern.TP1D))
      output = ColoTensor.from_torch_tensor(output_parallel, spec=output_spec)
  ...
@@ -46,7 +47,8 @@ def colo_embedding_1Drow(input_tensor: ColoTensor,
      # Find index in this shard and mask those not here
      # Reduce all
      pg = weight.get_process_group()
-     input_tensor = input_tensor.redistribute(distspec.replicate())
+     input_tensor = input_tensor.redistribute(ReplicaSpec())
      # tensor_parallel_rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
      tensor_parallel_rank = weight.get_process_group().tp_local_rank()
  ...
@@ -74,7 +76,7 @@ def colo_embedding_1Drow(input_tensor: ColoTensor,
      partial_output[input_mask, :] = 0.
      # Reduce across all the model parallel GPUs.
      output = reduce_input(partial_output, weight.get_process_group())
-     output = ColoTensor.from_torch_tensor(output, spec=ColoTensorSpec(weight.get_process_group(), distspec.replicate()))
+     output = ColoTensor.from_torch_tensor(output, spec=ColoTensorSpec(weight.get_process_group(), ReplicaSpec()))
      return output
  ...
```
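Put together, the column-parallel embedding path now reads: replicate the indices, look them up against the locally held slice of the table, and stamp the result as sharded along the last dimension. A condensed sketch under the same assumptions as the hunks above (the function name is illustrative, `weight` is a ColoTensor whose lookup table is already column-sharded over a tensor-parallel ProcessGroup, and the extra F.embedding keyword arguments are omitted):

```python
import torch.nn.functional as F
from colossalai.tensor import (ColoTensor, ColoTensorSpec, ComputePattern,
                               ComputeSpec, ReplicaSpec, ShardSpec)

def embedding_1dcol_sketch(input_tensor: ColoTensor, weight: ColoTensor) -> ColoTensor:
    # Indices must be identical on every rank before the local lookup.
    input_tensor = input_tensor.redistribute(ReplicaSpec())
    # Each rank holds a (num_embeddings, embedding_dim / P) slice of the table.
    output_parallel = F.embedding(input_tensor, weight)
    # Concatenating the rank outputs along the last dim yields the full result,
    # so the output is described as ShardSpec([-1], ...) with the TP1D pattern.
    output_spec = ColoTensorSpec(weight.get_process_group(),
                                 ShardSpec([-1], [weight.get_tp_world_size()]),
                                 ComputeSpec(ComputePattern.TP1D))
    return ColoTensor.from_torch_tensor(output_parallel, spec=output_spec)
```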
colossalai/nn/_ops/embedding_bag.py (view file @ 9bcd2fd4)

```diff
  ...
@@ -2,7 +2,7 @@ import torch.nn.functional as F
  from typing import Optional
  from torch import Tensor
  from colossalai.tensor.op_wrapper import colo_op_impl
- from colossalai.tensor import ComputePattern, ComputePattern, ComputeSpec, ColoTensor, distspec, ColoTensorSpec
+ from colossalai.tensor import ComputePattern, ComputePattern, ComputeSpec, ColoTensor, distspec, ColoTensorSpec, ShardSpec, ReplicaSpec
  from ._utils import GeneralTensor, convert_to_colo_tensor
  ...
@@ -20,7 +20,7 @@ def colo_embedding_bag_1Dcol(input_tensor: ColoTensor,
      # embedding_bag_1Dcol split the weight(lookup table) to (num_embeddings, embedding_dim/P)
      # Gather splitted lookup table
      pg = weight.get_process_group()
-     input_tensor = input_tensor.redistribute(distspec.replicate())
+     input_tensor = input_tensor.redistribute(ReplicaSpec())
      output_parallel = F.embedding_bag(input_tensor,
                                        weight,
  ...
@@ -33,8 +33,7 @@ def colo_embedding_bag_1Dcol(input_tensor: ColoTensor,
                                        per_sample_weights=per_sample_weights,
                                        include_last_offset=include_last_offset,
                                        padding_idx=padding_idx)
-     output_spec = ColoTensorSpec(pg, distspec.shard([-1], [weight.get_tp_world_size()]),
+     output_spec = ColoTensorSpec(pg, ShardSpec([-1], [weight.get_tp_world_size()]),
                                   ComputeSpec(ComputePattern.TP1D))
      output = ColoTensor.from_torch_tensor(output_parallel, spec=output_spec)
      if weight.compute_spec.output_replicate:
  ...
```
colossalai/nn/_ops/layernorm.py (view file @ 9bcd2fd4)

```diff
  from typing import List, Optional
  import torch.nn.functional as F
  from colossalai.tensor.op_wrapper import colo_op_impl
- from colossalai.tensor import ColoTensor, distspec, ColoTensorSpec
+ from colossalai.tensor import ColoTensor, distspec, ColoTensorSpec, ReplicaSpec
  from ._utils import GeneralTensor, convert_to_colo_tensor
  ...
@@ -16,7 +16,7 @@ def colo_layernorm(
      assert isinstance(weight, ColoTensor)
      input_tensor = convert_to_colo_tensor(input_tensor, weight.get_process_group())
      bias = convert_to_colo_tensor(bias, weight.get_process_group())
-     input_tensor = input_tensor.redistribute(distspec.replicate())
+     input_tensor = input_tensor.redistribute(ReplicaSpec())
      output = F.layer_norm(input_tensor, normalized_shape, weight=weight, bias=bias, eps=eps)
      output = ColoTensor.from_torch_tensor(output, ColoTensorSpec(input_tensor.get_process_group()))
  ...
```
colossalai/nn/_ops/linear.py (view file @ 9bcd2fd4)

```diff
  ...
@@ -3,8 +3,7 @@ from typing import Optional
  from ._utils import GeneralTensor, convert_to_colo_tensor
  from colossalai.tensor.op_wrapper import colo_op_impl
  from ._utils import reduce_input, reduce_grad
- from colossalai.tensor import ComputePattern, ComputePattern, ComputeSpec, ColoTensor, distspec, ColoTensorSpec
+ from colossalai.tensor import ComputePattern, ComputePattern, ComputeSpec, ColoTensor, ShardSpec, ReplicaSpec, ColoTensorSpec
  from colossalai.nn.graph import register_colo_graph, GraphOpNode, GraphGlobalEnv

  def colo_linear_1Drow(input_tensor: ColoTensor, weight: ColoTensor, bias: Optional[ColoTensor]) -> 'ColoTensor':
  ...
@@ -12,7 +11,7 @@ def colo_linear_1Drow(input_tensor: ColoTensor, weight: ColoTensor, bias: Option
      # All-Reduce(Output) + bias = res
      # Input:S[1]
      pg = weight.get_process_group()
-     input_tensor = input_tensor.redistribute(distspec.shard([-1], [weight.get_tp_world_size()]))
+     input_tensor = input_tensor.redistribute(ShardSpec([-1], [weight.get_tp_world_size()]))
      # Output:P
      partial_output = F.linear(input_tensor, weight)
  ...
@@ -24,7 +23,7 @@ def colo_linear_1Drow(input_tensor: ColoTensor, weight: ColoTensor, bias: Option
      assert not bias.has_compute_spec(), 'Invalid bias spec for 1Drow Linear op'
      output = output + bias
-     output = ColoTensor.from_torch_tensor(output, spec=ColoTensorSpec(pg, distspec.replicate()))
+     output = ColoTensor.from_torch_tensor(output, spec=ColoTensorSpec(pg, ReplicaSpec()))
      return output
  ...
@@ -33,13 +32,15 @@ def colo_linear_1Dcol(input_tensor: ColoTensor, weight: ColoTensor, bias: Option
      # All-Gather(Output)
      # Input:B
      compute_spec = weight.compute_spec
-     input_tensor = input_tensor.redistribute(distspec.replicate())
+     input_tensor = input_tensor.redistribute(ReplicaSpec())
      input_parallel = reduce_grad(input_tensor, weight.get_process_group())
      output_parallel = F.linear(input_parallel, weight, bias)
      output = ColoTensor.from_torch_tensor(output_parallel,
                                            spec=ColoTensorSpec(weight.get_process_group(),
-                                                               distspec.shard([-1], [weight.get_tp_world_size()]),
+                                                               ShardSpec([-1], [weight.get_tp_world_size()]),
                                                                ComputeSpec(ComputePattern.TP1D)))
      if compute_spec.output_replicate:
          return output.to_replicate()
  ...
```
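The two linear paths show where each spec spelling lands after the rename: row parallelism shards the input to match the weight, column parallelism replicates the input and leaves the output sharded. A hedged sketch of just those placement decisions (function names are illustrative; the all-reduce of partial results and the gradient hooks from the real ops are elided and only noted in comments):

```python
import torch.nn.functional as F
from colossalai.tensor import (ColoTensor, ColoTensorSpec, ComputePattern,
                               ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec)

def linear_1drow_sketch(x: ColoTensor, weight: ColoTensor, pg: ProcessGroup) -> ColoTensor:
    # Row-parallel: input is sharded along its last dim to line up with the weight shard.
    x = x.redistribute(ShardSpec([-1], [weight.get_tp_world_size()]))
    partial = F.linear(x, weight)          # Output:P (partial result per rank)
    # In the real op, reduce_input(partial, pg) all-reduces the partials here.
    return ColoTensor.from_torch_tensor(partial, spec=ColoTensorSpec(pg, ReplicaSpec()))

def linear_1dcol_sketch(x: ColoTensor, weight: ColoTensor) -> ColoTensor:
    # Column-parallel: input is replicated; the output stays sharded along the last dim.
    x = x.redistribute(ReplicaSpec())
    out = F.linear(x, weight)
    spec = ColoTensorSpec(weight.get_process_group(),
                          ShardSpec([-1], [weight.get_tp_world_size()]),
                          ComputeSpec(ComputePattern.TP1D))
    return ColoTensor.from_torch_tensor(out, spec=spec)
```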
colossalai/nn/parallel/layers/colo_module.py (view file @ 9bcd2fd4)

```diff
  ...
@@ -7,16 +7,6 @@ class ColoModule(object):
      def __init__(self):
          self._shard_params: List[str] = []
-         # Example:
-         # {ComputePattern.TP1D:
-         #   'default':
-         #       'weight':
-         #           distspec.shard(xxxxx)
-         #       'bias':
-         #           distspec.shard(xxxxx)
-         #   'row': ...
-         #   'col': ...
-         # }
          self._allowed_patterns: Dict[ComputePattern, Dict[str, Dict[str, _DistSpec]]] = {}

      def _register_shard_params(self, params: List[str]):
  ...
```
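The deleted comment described the nesting of _allowed_patterns (compute pattern, then mode, then parameter name, then dist spec). With the new spelling, an equivalent illustration would look like the dict below. This is a hedged example only: the "xxxxx" placeholders from the removed comment are filled in with the row/col shards that ColoLinear registers further down, and the 4-way sharding is hypothetical.

```python
from colossalai.tensor import ComputePattern, ShardSpec

# Illustrative contents of ColoModule._allowed_patterns for a linear layer.
allowed_patterns = {
    ComputePattern.TP1D: {
        'row': {
            'weight': ShardSpec([-1], [4]),   # shard the weight's last dim over 4 ranks
            'bias': None,                     # bias stays unsharded in row mode
        },
        'col': {
            'weight': ShardSpec([0], [4]),
            'bias': ShardSpec([0], [4]),
        },
    },
}
```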
colossalai/nn/parallel/layers/embedding.py (view file @ 9bcd2fd4)

```diff
  from .colo_module import ColoModule
- from colossalai.tensor import ComputePattern, distspec, ProcessGroup
- from colossalai.core import global_context as gpc
- from colossalai.context.parallel_mode import ParallelMode
+ from colossalai.tensor import ComputePattern, distspec, ProcessGroup, ShardSpec

  class ColoEmbedding(ColoModule):
  ...
@@ -21,7 +19,7 @@ class ColoEmbedding(ColoModule):
          self._register_allowed_patterns(
              compute_pattern=_compute_pattern,
              dist_specs={
-                 'weight': distspec.shard([0], [pg.tp_world_size()]),
+                 'weight': ShardSpec([0], [pg.tp_world_size()]),
              },
              mode='row',
          )
  ...
@@ -30,7 +28,7 @@ class ColoEmbedding(ColoModule):
          self._register_allowed_patterns(
              compute_pattern=_compute_pattern,
              dist_specs={
-                 'weight': distspec.shard([-1], [pg.tp_world_size()]),
+                 'weight': ShardSpec([-1], [pg.tp_world_size()]),
              },
              mode='col',
          )
  ...
```
colossalai/nn/parallel/layers/linear.py (view file @ 9bcd2fd4)

```diff
  from .colo_module import ColoModule
- from colossalai.tensor import ComputePattern, distspec, ProcessGroup
+ from colossalai.tensor import ComputePattern, distspec, ProcessGroup, ShardSpec

  class ColoLinear(ColoModule):
  ...
@@ -19,7 +19,7 @@ class ColoLinear(ColoModule):
          self._register_allowed_patterns(
              compute_pattern=_compute_pattern,
              dist_specs={
-                 'weight': distspec.shard([-1], [pg.tp_world_size()]),
+                 'weight': ShardSpec([-1], [pg.tp_world_size()]),
                  'bias': None
              },
              mode='row',
  ...
@@ -29,8 +29,8 @@ class ColoLinear(ColoModule):
          self._register_allowed_patterns(
              compute_pattern=_compute_pattern,
              dist_specs={
-                 'weight': distspec.shard([0], [pg.tp_world_size()]),
-                 'bias': distspec.shard([0], [pg.tp_world_size()])
+                 'weight': ShardSpec([0], [pg.tp_world_size()]),
+                 'bias': ShardSpec([0], [pg.tp_world_size()])
              },
              mode='col',
          )
  ...
```
colossalai/tensor/__init__.py (view file @ 9bcd2fd4)

```diff
  from .process_group import ProcessGroup
  from .tensor_spec import ColoTensorSpec
+ from .distspec import shard as ShardSpec
+ from .distspec import replicate as ReplicaSpec
  from .compute_spec import ComputeSpec, ComputePattern
  from .colo_tensor import ColoTensor
  from .colo_parameter import ColoParameter
  ...
@@ -11,5 +14,5 @@ from . import distspec
  __all__ = [
      'ColoTensor', 'convert_parameter', 'ComputePattern', 'ComputeSpec', 'named_params_with_colotensor', 'ColoParameter',
      'distspec', 'DistSpecManager', 'ParamOpHook', 'ParamOpHookManager', 'ChunkManager', 'TensorState', 'ProcessGroup',
-     'ColoTensorSpec', 'TensorSpec'
+     'ColoTensorSpec', 'TensorSpec', 'ShardSpec', 'ReplicaSpec'
  ]
```
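As the two added import lines show, the new names are plain aliases of the existing distspec factories, so the old and new spellings are interchangeable. A small sketch of what that implies for callers:

```python
from colossalai.tensor import distspec, ShardSpec, ReplicaSpec

# ShardSpec / ReplicaSpec are re-exports of the existing factory functions.
assert ShardSpec is distspec.shard
assert ReplicaSpec is distspec.replicate

old_style = distspec.shard([0], [4])   # previous spelling
new_style = ShardSpec([0], [4])        # same factory, same arguments
```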
colossalai/tensor/colo_tensor.py (view file @ 9bcd2fd4)

```diff
  ...
@@ -5,7 +5,7 @@ import torch
  from functools import lru_cache
  from colossalai.tensor import ColoTensorSpec
- from colossalai.tensor import distspec, ProcessGroup
+ from colossalai.tensor import ProcessGroup, ReplicaSpec
  from colossalai.tensor.dist_spec_mgr import DistSpecManager
  from colossalai.tensor.distspec import _DistSpec, DistPlacementPattern
  from typing import Optional, Set, Callable
  ...
@@ -51,21 +51,21 @@ class ColoTensor(torch.Tensor):
      """ Data Structure for Tensor in Colossal-AI. It is a subclass of torch.Tensor.
      Args:
          data (torch.Tensor): a torch tensor used as the payload the colotensor.
-         spec (ColoTensorSpec, optional): the tensor spec of initialization. Defaults to ColoTensorSpec(distspec.replicate()).
+         spec (ColoTensorSpec, optional): the tensor spec of initialization. Defaults to ColoTensorSpec(ReplicaSpec()).
      The signature of the function has to be consistent with the __new__ except for the 1st arg.
      The class should be initialized with a torch tensor in the following ways.
      1. directly init.
      >>> pg = ProcessGroup()
-     >>> colo_t1 = ColoTensor(torch.randn(2,3), spec = ColoTensorSpec(pg, distspec.replicate())
+     >>> colo_t1 = ColoTensor(torch.randn(2,3), spec = ColoTensorSpec(pg, ReplicaSpec())
      >>> # If initializaed in a shard model, the tensor passed in is one shard of the global tensor.
-     >>> shard_spec = distspec.shard(process_group=ProcessGroup(tp=world_size),
+     >>> shard_spec = ShardSpec(process_group=ProcessGroup(tp=world_size),
      >>>                 dims=[0],
      >>>                 num_partitions=[world_size])
      >>> tensor_spec = ColoTensorSpec(pg, shard_spec)
      >>> colo_t2 = ColoTensor.from_torch_tensor(t_ref.clone(), tensor_spec)
      2. use static method from_torch_tensor
-     >>> colo_t = ColoTensor.from_torch_tensor(torch.randn(2,3), spec = ColoTensorSpec(pg, distspec.replicate())
+     >>> colo_t = ColoTensor.from_torch_tensor(torch.randn(2,3), spec = ColoTensorSpec(pg, ReplicaSpec())
      """

      def __new__(cls, data: torch.Tensor, spec: ColoTensorSpec) -> 'ColoTensor':
  ...
@@ -85,7 +85,7 @@ class ColoTensor(torch.Tensor):
          # If not set spec, use a DP process group and replicate dist spec
          if spec is None:
              self.has_initialized = False
-             self.dist_spec = distspec.replicate()
+             self.dist_spec = ReplicaSpec()
              self.compute_spec = None
              self.process_group = ProcessGroup()
          else:
  ...
@@ -194,13 +194,14 @@ class ColoTensor(torch.Tensor):
          """to_replicate_
          an inline member function, converting dist spec of the tensor to REPLICATE
          """
-         self._redistribute(dist_spec=distspec.replicate())
+         self._redistribute(dist_spec=ReplicaSpec())

      def to_replicate(self) -> 'ColoTensor':
          """to_replicate
          converting dist spec of the tensor to REPLICATE
          """
-         return self.redistribute(distspec.replicate())
+         return self.redistribute(ReplicaSpec())

      @staticmethod
      def from_torch_tensor(tensor: torch.Tensor, spec: Optional[ColoTensorSpec] = None) -> 'ColoTensor':
  ...
@@ -234,7 +235,7 @@ class ColoTensor(torch.Tensor):
          """
          if self.is_replicate():
              return super().view(*args)
-         replicated_t = self.redistribute(dist_spec=distspec.replicate())
+         replicated_t = self.redistribute(dist_spec=ReplicaSpec())
          return replicated_t.view(*args)

      def size_global(self, args: Optional[int] = None):
  ...
```
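The updated docstring covers both construction paths with the new spelling; a condensed version using the positional ShardSpec form seen elsewhere in this commit (it assumes a ProcessGroup can be built in the current, already-initialized distributed context, as the docstring does):

```python
import torch
from colossalai.tensor import (ColoTensor, ColoTensorSpec, ProcessGroup,
                               ReplicaSpec, ShardSpec)

pg = ProcessGroup()

# 1. Direct init with a replicated dist spec.
colo_t1 = ColoTensor(torch.randn(2, 3), spec=ColoTensorSpec(pg, ReplicaSpec()))

# 2. From a local shard of the global tensor, via the static constructor.
world_size = pg.tp_world_size()
shard_spec = ShardSpec([0], [world_size])          # shard dim 0 across the TP group
colo_t2 = ColoTensor.from_torch_tensor(torch.randn(2, 3), ColoTensorSpec(pg, shard_spec))
```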
colossalai/utils/model/colo_init_context.py (view file @ 9bcd2fd4)

```diff
  from .utils import InsertPostInitMethodToModuleSubClasses
  import torch
- from colossalai.tensor import ColoTensor, ColoParameter, distspec, ProcessGroup
+ from colossalai.tensor import ColoTensor, ColoParameter, distspec, ProcessGroup, ReplicaSpec
  from colossalai.nn.parallel.layers import register_colo_module, \
      ColoLinear, ColoEmbedding
  ...
```
tests/test_tensor/test_addmm_tp.py (view file @ 9bcd2fd4)

```diff
  ...
@@ -4,7 +4,7 @@ import pytest
  import torch.nn as nn
  import torch.multiprocessing as mp
  from colossalai.tensor import ColoTensor, ProcessGroup
- from colossalai.tensor import distspec
+ from colossalai.tensor import ShardSpec
  from colossalai.tensor import ColoTensorSpec, ComputePattern, ComputeSpec, DistSpecManager
  from colossalai.testing import rerun_if_address_is_in_use
  from colossalai.utils import free_port
  ...
@@ -37,13 +37,13 @@ class Conv1D(nn.Module):
  def init_1d_row(weight, bias, pg: ProcessGroup):
-     spec = (distspec.shard([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
+     spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
      with DistSpecManager.no_grad():
          weight.set_tensor_spec(*spec)

  def init_1d_col(weight, bias, pg: ProcessGroup):
-     spec = (distspec.shard([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
+     spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
      with DistSpecManager.no_grad():
          weight.set_tensor_spec(*spec)
          bias.set_tensor_spec(*spec)
  ...
```
tests/test_tensor/test_dist_spec_mgr.py (view file @ 9bcd2fd4)

```diff
  ...
@@ -4,10 +4,9 @@ import torch.distributed as dist
  import pytest
  import colossalai
  import torch.multiprocessing as mp
- from torch.distributed.distributed_c10d import _get_default_group
  from colossalai.testing import rerun_if_address_is_in_use
  from colossalai.utils import free_port
- from colossalai.tensor import DistSpecManager, distspec, ProcessGroup
+ from colossalai.tensor import DistSpecManager, ProcessGroup, ShardSpec, ReplicaSpec
  from functools import partial
  ...
@@ -18,10 +17,10 @@ def run():
      depth = int(math.sqrt(size))
      assert depth == math.sqrt(size)
      x = torch.rand(8, 8).cuda()
-     old_dist_spec = distspec.replicate()
-     row_spec = distspec.shard([0], [size])
-     col_spec = distspec.shard([-1], [size])
-     mat_spec = distspec.shard([0, 1], [depth, depth])
+     old_dist_spec = ReplicaSpec()
+     row_spec = ShardSpec([0], [size])
+     col_spec = ShardSpec([-1], [size])
+     mat_spec = ShardSpec([0, 1], [depth, depth])
      row_shard = DistSpecManager._shard_as(x, old_dist_spec, row_spec, group)
      assert torch.equal(x.chunk(size, 0)[rank], row_shard)
      assert torch.equal(x, DistSpecManager._gather(row_shard, row_spec, group))
  ...
@@ -40,8 +39,8 @@ def check_mem():
      x = torch.rand(32, 32).cuda()
      orig_mem = x.numel() * x.element_size()
      assert torch.cuda.memory_allocated() == orig_mem
-     old_dist_spec = distspec.replicate()
-     row_spec = distspec.shard([0], [size])
+     old_dist_spec = ReplicaSpec()
+     row_spec = ShardSpec([0], [size])
      x.data = DistSpecManager._shard_as(x, old_dist_spec, row_spec, pg)
      assert x.size(0) == 32 // size and x.size(1) == 32
      assert torch.cuda.memory_allocated() == orig_mem // size
  ...
```
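The manager test above exercises the rename end to end. A trimmed-down version of its shard/gather round trip is sketched below; this is not part of the commit, it only reuses the private helpers the test itself calls (DistSpecManager._shard_as and _gather), assumes an initialized torch.distributed process with CUDA, and passes the ProcessGroup as the group argument the way check_mem() does:

```python
import torch
import torch.distributed as dist
from colossalai.tensor import DistSpecManager, ProcessGroup, ReplicaSpec, ShardSpec

def roundtrip_check():
    pg = ProcessGroup(tp_degree=dist.get_world_size())
    x = torch.rand(8, 8).cuda()

    old_dist_spec = ReplicaSpec()                     # current placement: replicated
    row_spec = ShardSpec([0], [pg.tp_world_size()])   # target: rows sharded across the group

    # Shard rows across the group, then gather them back; the result must match.
    row_shard = DistSpecManager._shard_as(x, old_dist_spec, row_spec, pg)
    assert torch.equal(x.chunk(pg.tp_world_size(), 0)[pg.tp_local_rank()], row_shard)
    assert torch.equal(x, DistSpecManager._gather(row_shard, row_spec, pg))
```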
tests/test_tensor/test_embedding_bag_tp.py (view file @ 9bcd2fd4)

```diff
  import torch
- from colossalai.tensor import distspec, ColoParameter
+ from colossalai.tensor import ShardSpec, ColoParameter
  from torch.nn import functional as F
  from functools import partial
  ...
@@ -14,7 +14,7 @@ from _utils import tensor_equal, tensor_shard_equal
  def init_1d_col(weight, pg: ProcessGroup):
-     spec = (distspec.shard([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
+     spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
      with DistSpecManager.no_grad():
          weight.set_tensor_spec(*spec)
  ...
```
tests/test_tensor/test_embedding_tp.py (view file @ 9bcd2fd4)

```diff
  import torch
- from colossalai.tensor import ColoTensor, distspec
+ from colossalai.tensor import ColoTensor, ShardSpec
  from torch.nn import functional as F
  from functools import partial
  ...
@@ -14,13 +14,13 @@ from _utils import tensor_equal, tensor_shard_equal
  def init_1d_row(weight, pg: ProcessGroup):
-     spec = (distspec.shard([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
+     spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
      with DistSpecManager.no_grad():
          weight.set_tensor_spec(*spec)

  def init_1d_col(weight, pg: ProcessGroup):
-     spec = (distspec.shard([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
+     spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
      with DistSpecManager.no_grad():
          weight.set_tensor_spec(*spec)
  ...
```
tests/test_tensor/test_gpt.py (view file @ 9bcd2fd4)

```diff
  ...
@@ -12,7 +12,7 @@ from colossalai.testing import rerun_if_address_is_in_use
  from colossalai.utils.cuda import get_current_device
  from colossalai.utils import free_port
  from colossalai.utils.model.colo_init_context import ColoInitContext
- from colossalai.tensor import ColoTensorSpec, ComputePattern, ComputeSpec, DistSpecManager, distspec, ProcessGroup
+ from colossalai.tensor import ShardSpec, ComputePattern, ComputeSpec, DistSpecManager, ProcessGroup
  from colossalai.nn.parallel.data_parallel import ColoDDP
  from colossalai.core import global_context as gpc
  from colossalai.context.parallel_mode import ParallelMode
  ...
@@ -20,7 +20,7 @@ from tests.components_to_test.registry import non_distributed_component_funcs
  def init_1d_row_spec(model, pg: ProcessGroup):
-     tensor_spec = (distspec.shard([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
+     tensor_spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
      with DistSpecManager.no_grad():
          for n, p in model.named_parameters():
              if 'weight' in n and 'ln' not in n:
  ...
@@ -28,7 +28,7 @@ def init_1d_row_spec(model, pg: ProcessGroup):
  def init_1d_col_spec(model, pg: ProcessGroup):
-     spec = (distspec.shard([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
+     spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
      with DistSpecManager.no_grad():
          for n, p in model.named_parameters():
              if 'ln' not in n and ('weight' in n or 'bias' in n):
  ...
```
tests/test_tensor/test_linear_tp.py (view file @ 9bcd2fd4)

```diff
  import torch
- from colossalai.tensor import ColoTensor, distspec
+ from colossalai.tensor import ColoTensor, ShardSpec
  from functools import partial
  ...
@@ -15,13 +15,13 @@ from _utils import tensor_equal, tensor_shard_equal
  def init_1d_row(weight, bias, pg: ProcessGroup):
-     spec = (distspec.shard([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
+     spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
      with DistSpecManager.no_grad():
          weight.set_tensor_spec(*spec)

  def init_1d_col(weight, bias, pg: ProcessGroup):
-     spec = (distspec.shard([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
+     spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
      with DistSpecManager.no_grad():
          weight.set_tensor_spec(*spec)
          bias.set_tensor_spec(*spec)
  ...
```
tests/test_tensor/test_loss_func.py (view file @ 9bcd2fd4)

```diff
  ...
@@ -8,7 +8,7 @@ from colossalai.tensor import ColoTensor, ProcessGroup, ColoTensorSpec
  from colossalai.utils import get_current_device
  from colossalai.testing import rerun_if_address_is_in_use
  from colossalai.utils import free_port
- from colossalai.tensor import distspec, ComputeSpec, ComputePattern
+ from colossalai.tensor import ShardSpec, ComputeSpec, ComputePattern

  def check_cross_entropy():
  ...
@@ -22,7 +22,7 @@ def check_cross_entropy():
      world_size = torch.distributed.get_world_size()
      pg = ProcessGroup(tp_degree=world_size)
      input_t_colo = ColoTensor.from_torch_tensor(tensor=input_ct, spec=ColoTensorSpec(pg))
-     input_shard = input_t_colo.redistribute(distspec.shard([-1], [pg.tp_world_size()]))
+     input_shard = input_t_colo.redistribute(ShardSpec([-1], [pg.tp_world_size()]))
      input_shard.set_tensor_spec(dist_spec=None, compute_spec=ComputeSpec(ComputePattern.TP1D))
      output = F.cross_entropy(input_t, target)
  ...
```
tests/test_tensor/test_model.py (view file @ 9bcd2fd4)

```diff
  ...
@@ -11,7 +11,7 @@ from colossalai.testing import rerun_if_address_is_in_use
  from colossalai.utils.cuda import get_current_device
  from colossalai.utils import free_port
  from colossalai.utils.model.colo_init_context import ColoInitContext
- from colossalai.tensor import distspec, ColoTensorSpec, ComputePattern, \
+ from colossalai.tensor import ShardSpec, ColoTensorSpec, ComputePattern, \
      ComputeSpec, ColoTensor, DistSpecManager, ProcessGroup
  from colossalai.nn.optimizer import ColoOptimizer
  ...
@@ -19,28 +19,28 @@ from tests.components_to_test.registry import non_distributed_component_funcs
  def init_1d_row_linear(weight: ColoTensor, pg: ProcessGroup):
-     spec = (distspec.shard([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
+     spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
      with DistSpecManager.no_grad():
          weight.set_process_group(pg)
          weight.set_tensor_spec(*spec)

  def init_1d_col_linear(weight, pg):
-     spec = (distspec.shard([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
+     spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
      with DistSpecManager.no_grad():
          weight.set_process_group(pg)
          weight.set_tensor_spec(*spec)

  def init_1d_row_embedding(weight, pg):
-     spec = (distspec.shard([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
+     spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
      with DistSpecManager.no_grad():
          weight.set_process_group(pg)
          weight.set_tensor_spec(*spec)

  def init_1d_col_embedding(weight, pg):
-     spec = (distspec.shard([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
+     spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
      with DistSpecManager.no_grad():
          weight.set_process_group(pg)
          weight.set_tensor_spec(*spec)
  ...
```
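All of these test helpers follow the same two-step pattern: build a (dist spec, compute spec) pair with the new ShardSpec spelling, then install it on the parameter inside DistSpecManager.no_grad(). A generic version of that pattern, assuming `weight` is already a ColoTensor (for example, created under ColoInitContext as in these tests); the function name is illustrative:

```python
from colossalai.tensor import (ColoTensor, ComputePattern, ComputeSpec,
                               DistSpecManager, ProcessGroup, ShardSpec)

def init_1d_shard(weight: ColoTensor, pg: ProcessGroup, dim: int) -> None:
    # Describe the placement (1D shard along `dim` over the TP group) and the
    # compute pattern, then install both without recording autograd history.
    spec = (ShardSpec([dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
    with DistSpecManager.no_grad():
        weight.set_process_group(pg)
        weight.set_tensor_spec(*spec)
```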