OpenDAS / ColossalAI / Commits / 4b9bba81

Unverified commit 4b9bba81, authored Jun 24, 2022 by Jiarui Fang; committed via GitHub on Jun 24, 2022.

[ColoTensor] rename APIs and add output_replicate to ComputeSpec (#1168)

parent f4ef2243
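In short, the commit renames the ColoTensor spec accessors (`spec` → `tensor_spec`, `set_spec` → `set_tensor_spec`, `has_spec` → `has_compute_spec`) and gives `ComputeSpec` an `output_replicate` flag so tensor-parallel ops can keep their outputs sharded instead of always replicating them. A minimal sketch of how the renamed API reads at a call site, assuming an initialized Colossal-AI runtime; the import path is inferred from the file layout in this commit:

```python
import torch
# Import path assumed from the colossalai/tensor/* modules touched below.
from colossalai.tensor import ColoTensor, TensorSpec, ComputeSpec, ComputePattern, distspec

t = ColoTensor.from_torch_tensor(torch.randn(16, 16))    # replicated by default
spec = TensorSpec(distspec.replicate(), ComputeSpec(ComputePattern.TP1D))

t.set_tensor_spec(spec)           # previously: t.set_spec(spec)
assert t.has_compute_spec()       # previously: t.has_spec()
print(t.tensor_spec.dist_spec)    # previously: t.spec.dist_spec
```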
Changes: 23. Showing 20 changed files with 110 additions and 99 deletions (+110, -99):

- colossalai/nn/_ops/addmm.py (+20, -16)
- colossalai/nn/_ops/element_wise.py (+1, -1)
- colossalai/nn/_ops/embedding.py (+16, -10)
- colossalai/nn/_ops/embedding_bag.py (+10, -7)
- colossalai/nn/_ops/layernorm.py (+2, -2)
- colossalai/nn/_ops/linear.py (+19, -14)
- colossalai/nn/_ops/loss.py (+3, -3)
- colossalai/nn/parallel/layers/module_utils.py (+5, -5)
- colossalai/tensor/colo_parameter.py (+1, -1)
- colossalai/tensor/colo_tensor.py (+9, -19)
- colossalai/tensor/compute_spec.py (+2, -0)
- colossalai/tensor/param_op_hook.py (+1, -1)
- colossalai/utils/model/colo_init_context.py (+4, -4)
- tests/test_tensor/test_addmm_tp.py (+4, -3)
- tests/test_tensor/test_embedding_bag_tp.py (+1, -1)
- tests/test_tensor/test_embedding_tp.py (+2, -2)
- tests/test_tensor/test_gpt.py (+2, -2)
- tests/test_tensor/test_linear_tp.py (+3, -3)
- tests/test_tensor/test_model.py (+4, -4)
- tests/test_tensor/test_module_spec.py (+1, -1)
colossalai/nn/_ops/addmm.py

```diff
@@ -13,34 +13,37 @@ def colo_addmm_1Drow(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTenso
     # beta * input + alpha * All-Reduce(Output) = res
     mat1 = mat1.convert_to_dist_spec(
-        distspec.shard(mat2.spec.get_process_group(), [-1], [mat2.spec.get_process_group_size()]))
+        distspec.shard(mat2.tensor_spec.get_process_group(), [-1], [mat2.tensor_spec.get_process_group_size()]))
     # Output:P
     partial_output = torch.mm(mat1, mat2)
     # Reduce(Output)
     output = reduce_input(partial_output, ParallelMode.PARALLEL_1D)
     # input
-    assert not input_tensor.has_spec(), 'Invalid input spec for 1Drow addmm op'
+    assert not input_tensor.has_compute_spec(), 'Invalid input spec for 1Drow addmm op'
     output = beta * input_tensor + alpha * output
-    output = ColoTensor.from_torch_tensor(output, spec=TensorSpec(distspec.replicate(mat2.spec.get_process_group())))
+    output = ColoTensor.from_torch_tensor(output,
+                                          spec=TensorSpec(distspec.replicate(mat2.tensor_spec.get_process_group())))
     return output
 
 
 def colo_addmm_1Dcol(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTensor, beta: Number,
                      alpha: Number) -> ColoTensor:
     # mat1:B x mat2:S[1] + input:S[1] = Output:S[1]
-    parallel_action = mat2.spec.compute_spec
-    mat1 = mat1.convert_to_dist_spec(distspec.replicate(mat2.spec.get_process_group()))
+    compute_spec = mat2.tensor_spec.compute_spec
+    mat1 = mat1.convert_to_dist_spec(distspec.replicate(mat2.tensor_spec.get_process_group()))
     mat1 = reduce_grad(mat1, ParallelMode.PARALLEL_1D)
 
     output_parallel = torch.addmm(input_tensor, mat1, mat2, beta=beta, alpha=alpha)
-    output_spec = TensorSpec(distspec.shard(mat2.spec.get_process_group(), [-1], [mat2.spec.get_process_group_size()]),
-                             ComputeSpec(ComputePattern.TP1D))
+    output_spec = TensorSpec(
+        distspec.shard(mat2.tensor_spec.get_process_group(), [-1], [mat2.tensor_spec.get_process_group_size()]),
+        ComputeSpec(ComputePattern.TP1D))
     output = ColoTensor.from_torch_tensor(output_parallel, spec=output_spec)
-    # TODO(jiaruifang) addam is special case
-    # since gpt call view after the Op.
-    return output.to_replicate()
+
+    if compute_spec.output_replicate:
+        return output.to_replicate()
+    else:
+        return output
 
 
 def colo_addmm_1d(mode: str, input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTensor, beta: Number,
@@ -64,14 +67,15 @@ def colo_addmm(input_tensor: GeneralTensor,
     # Add communication logic before and after linear call.
     ret_tensor = None
-    if not mat2.has_spec():    # No Model Parallel Applied
-        assert mat2.spec.is_gathered(), 'Invalid mat2 spec for native addmm op'
-        assert input_tensor.spec.is_gathered(), 'Invalid input spec for native addmm op'
+    if not mat2.has_compute_spec():    # No Model Parallel Applied
+        assert mat2.tensor_spec.is_gathered(), 'Invalid mat2 spec for native addmm op'
+        assert input_tensor.tensor_spec.is_gathered(), 'Invalid input spec for native addmm op'
         ret_tensor = ColoTensor.from_torch_tensor(torch.addmm(input_tensor, mat1, mat2, beta=beta, alpha=alpha))
-    elif mat2.spec.has_compute_pattern(ComputePattern.TP1D):    # Single Model Parallel Applied
-        if mat2.spec.is_1D_row() and input_tensor.spec.is_gathered():
+    elif mat2.tensor_spec.has_compute_pattern(ComputePattern.TP1D):    # Single Model Parallel Applied
+        if mat2.tensor_spec.is_1D_row() and input_tensor.tensor_spec.is_gathered():
             mode = 'row'
-        elif mat2.spec.is_1D_col() and (input_tensor.spec.is_1D_col() or input_tensor.spec.is_1D_row()):
+        elif mat2.tensor_spec.is_1D_col() and (input_tensor.tensor_spec.is_1D_col() or input_tensor.tensor_spec.is_1D_row()):
             mode = 'col'
         else:
             raise NotImplementedError
```
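The behavioral change above: `colo_addmm_1Dcol` no longer calls `to_replicate()` unconditionally; it now consults the weight's `ComputeSpec`. A hedged caller-side sketch of opting out of output replication, reusing the identifiers the test files in this commit use (process-group initialization and the `mat2` weight parameter are assumed):

```python
# Sketch only: assumes a launched Colossal-AI job with a 1D tensor-parallel
# group, and that mat2 is a ColoParameter serving as the addmm weight.
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.tensor import TensorSpec, ComputeSpec, ComputePattern, distspec

compute_spec = ComputeSpec(ComputePattern.TP1D)
compute_spec.output_replicate = False    # keep the 1Dcol output sharded (Output:S[1])

spec = TensorSpec(
    distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1],
                   [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
    compute_spec)
mat2.set_tensor_spec(spec)    # colo_addmm_1Dcol will now skip the final all-gather
```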
colossalai/nn/_ops/element_wise.py

```diff
@@ -18,7 +18,7 @@ def register_elementwise_op(op):
         """
         output = op(input_tensor, *args, **kwargs)
         if isinstance(input_tensor, ColoTensor):
-            spec = copy(input_tensor.spec)
+            spec = copy(input_tensor.tensor_spec)
             return ColoTensor.from_torch_tensor(output, spec=spec)
         return ColoTensor.from_torch_tensor(output)
```
colossalai/nn/_ops/embedding.py

```diff
@@ -17,7 +17,7 @@ def colo_embedding_1Dcol(input_tensor: ColoTensor,
                          sparse: bool = False) -> ColoTensor:
     # embedding_1Dcol split the weight(lookup table) to (num_embeddings, embedding_dim/P)
     # Gather splitted lookup table
-    input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate(weight.spec.get_process_group()))
+    input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate(weight.tensor_spec.get_process_group()))
 
     output_parallel = F.embedding(input_tensor,
                                   weight,
@@ -27,10 +27,15 @@ def colo_embedding_1Dcol(input_tensor: ColoTensor,
                                   scale_grad_by_freq=scale_grad_by_freq,
                                   sparse=sparse)
     output_spec = TensorSpec(
-        distspec.shard(weight.spec.get_process_group(), [-1], [weight.spec.get_process_group_size()]),
+        distspec.shard(weight.tensor_spec.get_process_group(), [-1], [weight.tensor_spec.get_process_group_size()]),
         ComputeSpec(ComputePattern.TP1D))
     output = ColoTensor.from_torch_tensor(output_parallel, spec=output_spec)
-    return output.to_replicate()
+
+    compute_spec = weight.tensor_spec.compute_spec
+
+    if compute_spec.output_replicate:
+        return output.to_replicate()
+    else:
+        return output
 
 
 def colo_embedding_1Drow(input_tensor: ColoTensor,
@@ -43,7 +48,7 @@ def colo_embedding_1Drow(input_tensor: ColoTensor,
     # embedding_1Drow split the weight(lookup table) to (num_embeddings/P, embedding_dim)
     # Find index in this shard and mask those not here
     # Reduce all
-    input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate(weight.spec.get_process_group()))
+    input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate(weight.tensor_spec.get_process_group()))
     tensor_parallel_rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
     num_embeddings_per_partition = weight.size(0)
@@ -70,7 +75,8 @@ def colo_embedding_1Drow(input_tensor: ColoTensor,
     partial_output[input_mask, :] = 0.
     # Reduce across all the model parallel GPUs.
     output = reduce_input(partial_output, ParallelMode.PARALLEL_1D)
-    output = ColoTensor.from_torch_tensor(output, spec=TensorSpec(distspec.replicate(weight.spec.get_process_group())))
+    output = ColoTensor.from_torch_tensor(output,
+                                          spec=TensorSpec(distspec.replicate(weight.tensor_spec.get_process_group())))
     return output
@@ -108,8 +114,8 @@ def colo_embedding(input_tensor: GeneralTensor,
     # Handle differen parallel actions.
 
-    if not weight.has_spec():    # No Model Parallel Applied
-        assert weight.spec.is_gathered(), 'Invalid weight spec for native embedding op'
+    if not weight.has_compute_spec():    # No Model Parallel Applied
+        assert weight.tensor_spec.is_gathered(), 'Invalid weight spec for native embedding op'
         return ColoTensor.from_torch_tensor(
             F.embedding(input_tensor,
                         weight,
@@ -118,10 +124,10 @@ def colo_embedding(input_tensor: GeneralTensor,
                         norm_type=norm_type,
                         scale_grad_by_freq=scale_grad_by_freq,
                         sparse=sparse))
-    elif weight.spec.has_compute_pattern(ComputePattern.TP1D):    # Single Model Parallel Applied
-        if weight.spec.is_1D_row():
+    elif weight.tensor_spec.has_compute_pattern(ComputePattern.TP1D):    # Single Model Parallel Applied
+        if weight.tensor_spec.is_1D_row():
             mode = 'row'
-        elif weight.spec.is_1D_col():
+        elif weight.tensor_spec.is_1D_col():
             mode = 'col'
         else:
             raise NotImplementedError
```
colossalai/nn/_ops/embedding_bag.py

```diff
@@ -19,7 +19,7 @@ def colo_embedding_bag_1Dcol(input_tensor: ColoTensor,
                              padding_idx: Optional[int] = None) -> ColoTensor:
     # embedding_bag_1Dcol split the weight(lookup table) to (num_embeddings, embedding_dim/P)
     # Gather splitted lookup table
-    input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate(weight.spec.get_process_group()))
+    input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate(weight.tensor_spec.get_process_group()))
 
     output_parallel = F.embedding_bag(input_tensor,
                                       weight,
@@ -33,11 +33,14 @@ def colo_embedding_bag_1Dcol(input_tensor: ColoTensor,
                                       include_last_offset=include_last_offset,
                                       padding_idx=padding_idx)
     output_spec = TensorSpec(
-        distspec.shard(weight.spec.get_process_group(), [-1], [weight.spec.get_process_group_size()]),
+        distspec.shard(weight.tensor_spec.get_process_group(), [-1], [weight.tensor_spec.get_process_group_size()]),
         ComputeSpec(ComputePattern.TP1D))
     output = ColoTensor.from_torch_tensor(output_parallel, spec=output_spec)
-    return output.to_replicate()
+
+    if weight.tensor_spec.compute_spec.output_replicate:
+        return output.to_replicate()
+    else:
+        return output
 
 
 def colo_embedding_bag_1d(tp_mode: str,
@@ -86,8 +89,8 @@ def colo_embedding_bag(input_tensor: GeneralTensor,
     # Handle differen parallel actions.
 
-    if not weight.has_spec():    # No Model Parallel Applied
-        assert weight.spec.is_gathered(), 'Invalid weight spec for native embedding op'
+    if not weight.has_compute_spec():    # No Model Parallel Applied
+        assert weight.tensor_spec.is_gathered(), 'Invalid weight spec for native embedding op'
         return ColoTensor.from_torch_tensor(
             F.embedding_bag(input_tensor,
                             weight,
@@ -100,8 +103,8 @@ def colo_embedding_bag(input_tensor: GeneralTensor,
                             per_sample_weights=per_sample_weights,
                             include_last_offset=include_last_offset,
                             padding_idx=padding_idx))
-    elif weight.spec.has_compute_pattern(ComputePattern.TP1D):    # Single Model Parallel Applied
-        if weight.spec.is_1D_col():
+    elif weight.tensor_spec.has_compute_pattern(ComputePattern.TP1D):    # Single Model Parallel Applied
+        if weight.tensor_spec.is_1D_col():
             tp_mode = 'col'
         else:
             raise NotImplementedError
```
colossalai/nn/_ops/layernorm.py

```diff
@@ -17,8 +17,8 @@ def colo_layernorm(
     input_tensor, weight, bias = tuple(map(convert_to_colo_tensor, (input_tensor, weight, bias)))
 
     # TODO (ver217): check dist spec
-    input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate(input_tensor.spec.get_process_group()))
+    input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate(input_tensor.tensor_spec.get_process_group()))
 
     output = F.layer_norm(input_tensor, normalized_shape, weight=weight, bias=bias, eps=eps)
-    output = ColoTensor.from_torch_tensor(output, input_tensor.spec)
+    output = ColoTensor.from_torch_tensor(output, input_tensor.tensor_spec)
     return output
```
colossalai/nn/_ops/linear.py

```diff
@@ -13,7 +13,7 @@ def colo_linear_1Drow(input_tensor: ColoTensor, weight: ColoTensor, bias: Option
     # All-Reduce(Output) + bias = res
     # Input:S[1]
     input_tensor = input_tensor.convert_to_dist_spec(
-        distspec.shard(weight.spec.get_process_group(), [-1], [weight.spec.get_process_group_size()]))
+        distspec.shard(weight.tensor_spec.get_process_group(), [-1], [weight.tensor_spec.get_process_group_size()]))
 
     # Output:P
     partial_output = F.linear(input_tensor, weight)
@@ -21,10 +21,11 @@ def colo_linear_1Drow(input_tensor: ColoTensor, weight: ColoTensor, bias: Option
     output = reduce_input(partial_output, ParallelMode.PARALLEL_1D)
     # Bias
     if bias is not None:
-        assert not bias.has_spec(), 'Invalid bias spec for 1Drow Linear op'
+        assert not bias.has_compute_spec(), 'Invalid bias spec for 1Drow Linear op'
         output = output + bias
 
-    output = ColoTensor.from_torch_tensor(output, spec=TensorSpec(distspec.replicate(weight.spec.get_process_group())))
+    output = ColoTensor.from_torch_tensor(output,
+                                          spec=TensorSpec(distspec.replicate(weight.tensor_spec.get_process_group())))
     return output
@@ -32,17 +33,20 @@ def colo_linear_1Dcol(input_tensor: ColoTensor, weight: ColoTensor, bias: Option
     # Input:B x Weight:S[1] + Bias:S[1] = Output:S[1]
     # All-Gather(Output)
     # Input:B
-    parallel_action = weight.spec.compute_spec
-    input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate(weight.spec.get_process_group()))
+    compute_spec = weight.tensor_spec.compute_spec
+    input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate(weight.tensor_spec.get_process_group()))
     input_parallel = reduce_grad(input_tensor, ParallelMode.PARALLEL_1D)
 
     output_parallel = F.linear(input_parallel, weight, bias)
     output = ColoTensor.from_torch_tensor(
         output_parallel,
         spec=TensorSpec(
-            distspec.shard(weight.spec.get_process_group(), [-1], [weight.spec.get_process_group_size()]),
+            distspec.shard(weight.tensor_spec.get_process_group(), [-1], [weight.tensor_spec.get_process_group_size()]),
             ComputeSpec(ComputePattern.TP1D)))
-    return output.to_replicate()
+
+    if compute_spec.output_replicate:
+        return output.to_replicate()
+    else:
+        return output
 
 
 def colo_linear_1d(mode: str, input_tensor: ColoTensor, weight: ColoTensor, bias: Optional[ColoTensor]) -> 'ColoTensor':
@@ -62,14 +66,15 @@ def colo_linear_imp(input_tensor: GeneralTensor,
     # Add communication logic before and after linear call.
     ret_tensor = None
-    if not weight.has_spec():    # No Model Parallel Applied
-        assert weight.spec.is_gathered(), 'Invalid weight spec for native Linear op'
-        assert bias is None or bias.spec.is_gathered(), 'Invalid bias spec for native Linear op'
+    if not weight.has_compute_spec():    # No Model Parallel Applied
+        assert weight.tensor_spec.is_gathered(), 'Invalid weight spec for native Linear op'
+        assert bias is None or bias.tensor_spec.is_gathered(), 'Invalid bias spec for native Linear op'
         ret_tensor = ColoTensor.from_torch_tensor(F.linear(input_tensor, weight, bias))
-    elif weight.spec.has_compute_pattern(ComputePattern.TP1D):    # Single Model Parallel Applied
-        if weight.spec.is_1D_col() and (bias is None or bias.spec.is_gathered()):
+    elif weight.tensor_spec.has_compute_pattern(ComputePattern.TP1D):    # Single Model Parallel Applied
+        if weight.tensor_spec.is_1D_col() and (bias is None or bias.tensor_spec.is_gathered()):
             mode = 'row'
-        elif weight.spec.is_1D_row() and (bias is None or bias.spec.is_1D_row() or bias.spec.is_1D_col()):
+        elif weight.tensor_spec.is_1D_row() and (bias is None or bias.tensor_spec.is_1D_row() or bias.tensor_spec.is_1D_col()):
             mode = 'col'
         else:
             raise NotImplementedError
```
colossalai/nn/_ops/loss.py

```diff
@@ -18,7 +18,7 @@ def colo_cross_entropy(input_tensor: GeneralTensor,
                        label_smoothing: float = 0.0):
     input_tensor, target, weight = tuple(map(convert_to_colo_tensor, (input_tensor, target, weight)))
 
-    if input_tensor.spec.is_gathered():    # Input is gathered
+    if input_tensor.tensor_spec.is_gathered():    # Input is gathered
         output = F.cross_entropy(input_tensor,
                                  target,
                                  weight=weight,
@@ -28,8 +28,8 @@ def colo_cross_entropy(input_tensor: GeneralTensor,
                                  reduction=reduction,
                                  label_smoothing=label_smoothing)
         return ColoTensor.from_torch_tensor(output)
-    elif input_tensor.has_spec():    # Single Model Parallel Applied
-        if input_tensor.spec.is_1D_col():
+    elif input_tensor.has_compute_spec():    # Single Model Parallel Applied
+        if input_tensor.tensor_spec.is_1D_col():
             output = VocabParallelCrossEntropyLoss1D()(input_tensor, target)
             return ColoTensor.from_torch_tensor(output)
         else:
```
colossalai/nn/parallel/layers/module_utils.py

```diff
@@ -38,8 +38,8 @@ def check_colo_module(module: torch.nn.Module, recursive=True):
             param = module.get_parameter(param_name)
             if not isinstance(param, ColoParameter):
                 raise Exception(f'Invalid ColoParameter spec: {param} in {module} is not a ColoParameter.')
-            if param.has_spec():
-                cur_compute_pattern = param.spec.compute_spec.compute_pattern
+            if param.has_compute_spec():
+                cur_compute_pattern = param.tensor_spec.compute_spec.compute_pattern
                 if compute_pattern is None:
                     compute_pattern = cur_compute_pattern
                 else:
@@ -61,8 +61,8 @@ def check_colo_module(module: torch.nn.Module, recursive=True):
                 cur_match = True
                 for param_name, dist_spec in param_specs.items():
                     param = module.get_parameter(param_name)
-                    if param.has_spec():
-                        if dist_spec != param.spec.dist_spec:
+                    if param.has_compute_spec():
+                        if dist_spec != param.tensor_spec.dist_spec:
                             cur_match = False
                             break
                     else:
@@ -97,7 +97,7 @@ def init_colo_module(module: torch.nn.Module, parallel_action: ComputeSpec, recu
         param = module.get_parameter(param_name)
         if isinstance(param, ColoParameter):
             spec = TensorSpec(dist_spec, parallel_action)
-            param.set_spec(spec)
+            param.set_tensor_spec(spec)
             for mod in param.shared_param_modules:
                 modules_update_param.add(mod)
     for mod in modules_update_param:
```
colossalai/tensor/colo_parameter.py

```diff
@@ -82,7 +82,7 @@ class ColoParameter(ColoTensor, torch.nn.Parameter):
         else:
             with torch._C.DisableTorchFunction():
                 data = self.data.clone()
-            tensor = ColoParameter(data, self.requires_grad, spec=copy(self.spec))
+            tensor = ColoParameter(data, self.requires_grad, spec=copy(self.tensor_spec))
             memo[id(self)] = tensor
             return tensor
```
colossalai/tensor/colo_tensor.py

```diff
@@ -57,15 +57,15 @@ class ColoTensor(torch.Tensor):
         self._graph_node = None
 
     @property
-    def spec(self) -> TensorSpec:
+    def tensor_spec(self) -> TensorSpec:
         return self._tensor_spec
 
-    def set_spec(self, spec: TensorSpec) -> None:
+    def set_tensor_spec(self, spec: TensorSpec) -> None:
         spec = copy(spec)
         self._convert_to_dist_spec(spec.dist_spec)
         self._tensor_spec = spec
 
-    def has_spec(self) -> bool:
+    def has_compute_spec(self) -> bool:
         return self._tensor_spec.compute_spec is not None
 
     def is_model_data(self) -> bool:
@@ -100,27 +100,27 @@ class ColoTensor(torch.Tensor):
             dist_spec (_DistSpec): the target dist. spec.
         """
         with DistSpecManager.no_grad():
-            self.data = DistSpecManager.handle_trans_spec(self, self.spec.dist_spec, dist_spec)
+            self.data = DistSpecManager.handle_trans_spec(self, self.tensor_spec.dist_spec, dist_spec)
         self._tensor_spec.dist_spec = dist_spec
 
     def convert_to_dist_spec(self, dist_spec: _DistSpec) -> 'ColoTensor':
         tensor_spec = copy(self._tensor_spec)
         tensor_spec.dist_spec = dist_spec
-        ret = DistSpecManager.handle_trans_spec(self, self.spec.dist_spec, dist_spec)
+        ret = DistSpecManager.handle_trans_spec(self, self.tensor_spec.dist_spec, dist_spec)
         return ColoTensor.from_torch_tensor(ret, tensor_spec)
 
     def to_replicate_(self):
         """to_replicate_
         an inline member function, converting dist spec of the tensor to REPLICATE
         """
-        self.data = DistSpecManager.handle_trans_spec(self, self.spec.dist_spec, distspec.replicate())
+        self.data = DistSpecManager.handle_trans_spec(self, self.tensor_spec.dist_spec, distspec.replicate())
         self._tensor_spec.dist_spec = distspec.replicate()
 
     def to_replicate(self) -> 'ColoTensor':
         """to_replicate
         converting dist spec of the tensor to REPLICATE
         """
-        return self.convert_to_dist_spec(distspec.replicate(self.spec.get_process_group()))
+        return self.convert_to_dist_spec(distspec.replicate(self.tensor_spec.get_process_group()))
 
     @staticmethod
     def from_torch_tensor(tensor: torch.Tensor, spec: TensorSpec = TensorSpec(distspec.replicate())) -> 'ColoTensor':
@@ -134,16 +134,6 @@ class ColoTensor(torch.Tensor):
         else:
             with torch._C.DisableTorchFunction():
                 data = self.data.clone()
-            tensor = ColoTensor(data, spec=copy(self.spec))
+            tensor = ColoTensor(data, spec=copy(self.tensor_spec))
             memo[id(self)] = tensor
             return tensor
-
-    # TODO(jiaruifang) a patch for gpt test.
-    # We need to override the member function must operate on a replicated tensor
-    # def view(self, *args, **kwargs):
-    #     self.data = DistSpecManager.handle_trans_spec(self,
-    #                                                   self.spec.dist_spec,
-    #                                                   distspec.replicate(self.spec.get_process_group()))
-    #     # self._tensor_spec.dist_spec = distspec.replicate(self.spec.get_process_group())
-    #     self.data.view(*args, **kwargs)
-    #     return ColoTensor.from_torch_tensor(self.data)
\ No newline at end of file
```
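This file carries the core rename; every other file in the commit updates call sites. The new names are also more accurate: `has_compute_spec()` reports whether a compute spec is attached, which the old `has_spec()` obscured. A short migration sketch for callers, assuming `t` is any existing ColoTensor:

```python
spec = t.tensor_spec       # previously: t.spec
t.set_tensor_spec(spec)    # previously: t.set_spec(spec)
if t.has_compute_spec():   # previously: t.has_spec()
    print(spec.compute_spec.compute_pattern)
```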
colossalai/tensor/compute_spec.py

```diff
@@ -18,6 +18,8 @@ class ComputeSpec(object):
     def __init__(self, compute_pattern: ComputePattern) -> None:
         assert isinstance(compute_pattern, ComputePattern)
         self.compute_pattern = compute_pattern
+        # Make sure output tensors are replicate
+        self.output_replicate = True
 
     def __repr__(self):
         return f'compute pattern: {self.compute_pattern}'
```
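This is the commit's one functional addition: every `ComputeSpec` now carries an `output_replicate` flag. It defaults to `True`, so the ops that check it (addmm, linear, embedding, embedding_bag above) keep the old always-replicate behavior unless a caller opts out:

```python
cs = ComputeSpec(ComputePattern.TP1D)
assert cs.output_replicate        # True by default: op outputs replicated as before
cs.output_replicate = False       # opt out: 1Dcol ops return sharded outputs instead
```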
colossalai/tensor/param_op_hook.py

```diff
@@ -129,7 +129,7 @@ def _get_colo_tensors_info(*args) -> list:
     info = []
     for arg in args:
         if isinstance(arg, ColoTensor):
-            info.append((arg.__class__, arg.spec))
+            info.append((arg.__class__, arg.tensor_spec))
         else:
             info.append(None)
     return info
```
colossalai/utils/model/colo_init_context.py

```diff
@@ -42,10 +42,10 @@ def colo_state_dict(self, destination=None, prefix='', keep_vars=False, state_di
     has_dist_parameter = False
     with torch.no_grad():
         for param in self.parameters():
-            if isinstance(param, ColoParameter) and param.has_spec():
+            if isinstance(param, ColoParameter) and param.has_compute_spec():
                 has_dist_parameter = True
-                mapping[id(param)] = copy(param.spec)
-                param.set_spec(TensorSpec(distspec.replicate()))
+                mapping[id(param)] = copy(param.tensor_spec)
+                param.set_tensor_spec(TensorSpec(distspec.replicate()))
 
     # TODO: fix when keep_vars = True
     # when keep_vars = False, the state_dict_func will call detach to create
@@ -62,7 +62,7 @@ def colo_state_dict(self, destination=None, prefix='', keep_vars=False, state_di
             param_id = id(param)
             if param_id in mapping:
                 spec = mapping[id(param)]
-                param.set_spec(spec)
+                param.set_tensor_spec(spec)
     return ret
```
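For context, this state-dict hook temporarily swaps each distributed parameter's spec for a replicated one while saving, then restores it; the rename threads through that round trip. A condensed sketch of the pattern, using only identifiers from the diff (the model, imports, and the wrapped `state_dict_func` call are assumptions standing in for the hook's surroundings):

```python
mapping = {}
for param in model.parameters():
    if isinstance(param, ColoParameter) and param.has_compute_spec():
        mapping[id(param)] = copy(param.tensor_spec)               # remember sharding
        param.set_tensor_spec(TensorSpec(distspec.replicate()))    # gather for saving
ret = state_dict_func(model)    # the original torch state_dict, per the hook above
for param in model.parameters():
    if id(param) in mapping:
        param.set_tensor_spec(mapping[id(param)])                  # restore sharding
```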
tests/test_tensor/test_addmm_tp.py

```diff
@@ -43,7 +43,7 @@ def init_1d_row(weight, bias):
         distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
         ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
-        weight.set_spec(spec)
+        weight.set_tensor_spec(spec)
 
 
 def init_1d_col(weight, bias):
@@ -51,8 +51,8 @@ def init_1d_col(weight, bias):
         distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
         ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
-        weight.set_spec(spec)
-        bias.set_spec(spec)
+        weight.set_tensor_spec(spec)
+        bias.set_tensor_spec(spec)
 
 
 def run_with_spec(spec_init_func):
@@ -63,6 +63,7 @@ def run_with_spec(spec_init_func):
     x = torch.rand(2, 16).cuda()
     out = model(x)
     colo_out = torch.addmm(bias, x, weight)
+    colo_out = colo_out.to_replicate()
     assert tensor_equal(out, colo_out)
     grad = torch.rand_like(out)
     out.backward(grad)
```
tests/test_tensor/test_embedding_bag_tp.py

```diff
@@ -20,7 +20,7 @@ def init_1d_col(weight):
         distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
         ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
-        weight.set_spec(spec)
+        weight.set_tensor_spec(spec)
 
 
 def run_with_spec(spec_init_func):
```
tests/test_tensor/test_embedding_tp.py

```diff
@@ -20,7 +20,7 @@ def init_1d_row(weight):
         distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
         ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
-        weight.set_spec(spec)
+        weight.set_tensor_spec(spec)
 
 
 def init_1d_col(weight):
@@ -28,7 +28,7 @@ def init_1d_col(weight):
         distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
         ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
-        weight.set_spec(spec)
+        weight.set_tensor_spec(spec)
 
 
 def run_with_spec(spec_init_func):
```
tests/test_tensor/test_gpt.py

```diff
@@ -22,7 +22,7 @@ def init_1d_row_spec(model):
     with DistSpecManager.no_grad():
         for n, p in model.named_parameters():
             if 'weight' in n and 'ln' not in n:
-                p.set_spec(spec)
+                p.set_tensor_spec(spec)
 
 
 def init_1d_col_spec(model):
@@ -32,7 +32,7 @@ def init_1d_col_spec(model):
     with DistSpecManager.no_grad():
         for n, p in model.named_parameters():
             if 'ln' not in n and ('weight' in n or 'bias' in n):
-                p.set_spec(spec)
+                p.set_tensor_spec(spec)
 
 
 def check_param_equal(model, torch_model):
```
tests/test_tensor/test_linear_tp.py

```diff
@@ -21,7 +21,7 @@ def init_1d_row(weight, bias):
         distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
         ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
-        weight.set_spec(spec)
+        weight.set_tensor_spec(spec)
 
 
 def init_1d_col(weight, bias):
@@ -29,8 +29,8 @@ def init_1d_col(weight, bias):
         distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
         ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
-        weight.set_spec(spec)
-        bias.set_spec(spec)
+        weight.set_tensor_spec(spec)
+        bias.set_tensor_spec(spec)
 
 
 def run_with_spec(spec_init_func):
```
tests/test_tensor/test_model.py

```diff
@@ -23,7 +23,7 @@ def init_1d_row_linear(weight):
         distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
         ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
-        weight.set_spec(spec)
+        weight.set_tensor_spec(spec)
 
 
 def init_1d_col_linear(weight):
@@ -31,7 +31,7 @@ def init_1d_col_linear(weight):
         distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
         ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
-        weight.set_spec(spec)
+        weight.set_tensor_spec(spec)
 
 
 def init_1d_row_embedding(weight):
@@ -39,7 +39,7 @@ def init_1d_row_embedding(weight):
         distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
         ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
-        weight.set_spec(spec)
+        weight.set_tensor_spec(spec)
 
 
 def init_1d_col_embedding(weight):
@@ -47,7 +47,7 @@ def init_1d_col_embedding(weight):
         distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
         ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
-        weight.set_spec(spec)
+        weight.set_tensor_spec(spec)
 
 
 def run_1d_hybrid_tp(model_name):
```
tests/test_tensor/test_module_spec.py

```diff
@@ -157,7 +157,7 @@ def run_check_shared_param():
     col_spec = TensorSpec(
         distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
         ComputeSpec(ComputePattern.TP1D))
-    model.cls.predictions.bias.set_spec(col_spec)
+    model.cls.predictions.bias.set_tensor_spec(col_spec)
     try:
         check_colo_module(model.cls.predictions.decoder, recursive=False)
     except Exception as e:
```