OpenDAS / ColossalAI / Commits / f4ef2243

Commit f4ef2243 (unverified), authored Jun 23, 2022 by Jiarui Fang, committed by GitHub on Jun 23, 2022

[Tensor] remove ParallelAction, use ComputeSpec instread (#1166)

Parent: 177c3744
Showing 20 changed files with 87 additions and 77 deletions (+87, -77).
colossalai/nn/_ops/addmm.py                        +3  -3
colossalai/nn/_ops/embedding.py                    +2  -2
colossalai/nn/_ops/embedding_bag.py                +2  -2
colossalai/nn/_ops/linear.py                       +3  -3
colossalai/nn/parallel/layers/module_utils.py      +4  -4
colossalai/tensor/__init__.py                      +4  -4
colossalai/tensor/colo_parameter.py                +5  -3
colossalai/tensor/colo_tensor.py                   +1  -1
colossalai/tensor/compute_spec.py                  +23 -0
colossalai/tensor/tensor_spec.py                   +7  -24
docker/Dockerfile                                  +1  -1
tests/test_tensor/test_addmm_tp.py                 +3  -3
tests/test_tensor/test_embedding_bag_tp.py         +2  -2
tests/test_tensor/test_embedding_tp.py             +3  -3
tests/test_tensor/test_gpt.py                      +3  -3
tests/test_tensor/test_hybrid_device.py            +5  -3
tests/test_tensor/test_linear_tp.py                +3  -3
tests/test_tensor/test_model.py                    +5  -5
tests/test_tensor/test_module_spec.py              +5  -5
tests/test_tensor/test_zero_optim.py               +3  -3
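Across all 20 files the change is the same rename: the ParallelAction class becomes ComputeSpec (moved into its own module, colossalai/tensor/compute_spec.py), and the TensorSpec attribute parallel_action becomes compute_spec. A minimal before/after sketch of a typical call site, assuming colossalai.launch has already set up a 1D tensor-parallel group (as the tests below do):

from colossalai.core import global_context as gpc
from colossalai.context import ParallelMode
from colossalai.tensor import TensorSpec, ComputePattern, ComputeSpec, distspec

# shard dim 0 of a weight over the 1D tensor-parallel group
dist = distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0],
                      [gpc.get_world_size(ParallelMode.PARALLEL_1D)])

# before this commit:
#   spec = TensorSpec(dist, ParallelAction(ComputePattern.TP1D))
#   pattern = weight.spec.parallel_action.compute_pattern
# after this commit:
spec = TensorSpec(dist, ComputeSpec(ComputePattern.TP1D))
# pattern = weight.spec.compute_spec.compute_pattern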
colossalai/nn/_ops/addmm.py

 import torch
 from colossalai.tensor.op_wrapper import colo_op_impl
 from colossalai.nn.layer.parallel_1d._utils import reduce_input, reduce_grad
-from colossalai.tensor import ComputePattern, TensorSpec, ComputePattern, ParallelAction, ColoTensor
+from colossalai.tensor import ComputePattern, TensorSpec, ComputePattern, ComputeSpec, ColoTensor
 from colossalai.tensor import distspec
 from colossalai.context import ParallelMode
 from ._utils import GeneralTensor, Number, convert_to_colo_tensor

@@ -29,13 +29,13 @@ def colo_addmm_1Drow(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTenso
 def colo_addmm_1Dcol(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTensor, beta: Number,
                      alpha: Number) -> ColoTensor:
     # mat1:B x mat2:S[1] + input:S[1] = Output:S[1]
-    parallel_action = mat2.spec.parallel_action
+    parallel_action = mat2.spec.compute_spec
     mat1 = mat1.convert_to_dist_spec(distspec.replicate(mat2.spec.get_process_group()))
     mat1 = reduce_grad(mat1, ParallelMode.PARALLEL_1D)

     output_parallel = torch.addmm(input_tensor, mat1, mat2, beta=beta, alpha=alpha)
     output_spec = TensorSpec(distspec.shard(mat2.spec.get_process_group(), [-1], [mat2.spec.get_process_group_size()]),
-                             ParallelAction(ComputePattern.TP1D))
+                             ComputeSpec(ComputePattern.TP1D))
     output = ColoTensor.from_torch_tensor(output_parallel, spec=output_spec)

     # TODO(jiaruifang) addam is special case
colossalai/nn/_ops/embedding.py

@@ -3,7 +3,7 @@ from typing import Optional
 from colossalai.tensor.op_wrapper import colo_op_impl
 from colossalai.nn.layer.parallel_1d._utils import reduce_input
 from colossalai.core import global_context as gpc
-from colossalai.tensor import ComputePattern, TensorSpec, ComputePattern, ParallelAction, ColoTensor, distspec
+from colossalai.tensor import ComputePattern, TensorSpec, ComputePattern, ComputeSpec, ColoTensor, distspec
 from colossalai.context import ParallelMode
 from ._utils import GeneralTensor, convert_to_colo_tensor

@@ -28,7 +28,7 @@ def colo_embedding_1Dcol(input_tensor: ColoTensor,
                                       sparse=sparse)
     output_spec = TensorSpec(distspec.shard(weight.spec.get_process_group(), [-1], [weight.spec.get_process_group_size()]),
-                             ParallelAction(ComputePattern.TP1D))
+                             ComputeSpec(ComputePattern.TP1D))
     output = ColoTensor.from_torch_tensor(output_parallel, spec=output_spec)
     return output.to_replicate()
colossalai/nn/_ops/embedding_bag.py

@@ -2,7 +2,7 @@ import torch.nn.functional as F
 from typing import Optional
 from torch import Tensor
 from colossalai.tensor.op_wrapper import colo_op_impl
-from colossalai.tensor import ComputePattern, TensorSpec, ComputePattern, ParallelAction, ColoTensor, distspec
+from colossalai.tensor import ComputePattern, TensorSpec, ComputePattern, ComputeSpec, ColoTensor, distspec
 from ._utils import GeneralTensor, convert_to_colo_tensor

@@ -34,7 +34,7 @@ def colo_embedding_bag_1Dcol(input_tensor: ColoTensor,
                                           padding_idx=padding_idx)
     output_spec = TensorSpec(distspec.shard(weight.spec.get_process_group(), [-1], [weight.spec.get_process_group_size()]),
-                             ParallelAction(ComputePattern.TP1D))
+                             ComputeSpec(ComputePattern.TP1D))
     output = ColoTensor.from_torch_tensor(output_parallel, spec=output_spec)
     return output.to_replicate()
colossalai/nn/_ops/linear.py

@@ -3,7 +3,7 @@ from typing import Optional
 from ._utils import GeneralTensor, convert_to_colo_tensor
 from colossalai.tensor.op_wrapper import colo_op_impl
 from colossalai.nn.layer.parallel_1d._utils import reduce_input, reduce_grad
-from colossalai.tensor import ComputePattern, TensorSpec, ComputePattern, ParallelAction, ColoTensor, distspec
+from colossalai.tensor import ComputePattern, TensorSpec, ComputePattern, ComputeSpec, ColoTensor, distspec
 from colossalai.context import ParallelMode
 from colossalai.nn.graph import register_colo_graph, GraphOpNode, GraphGlobalEnv

@@ -32,7 +32,7 @@ def colo_linear_1Dcol(input_tensor: ColoTensor, weight: ColoTensor, bias: Option
     # Input:B x Weight:S[1] + Bias:S[1] = Output:S[1]
     # All-Gather(Output)
     # Input:B
-    parallel_action = weight.spec.parallel_action
+    parallel_action = weight.spec.compute_spec
     input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate(weight.spec.get_process_group()))
     input_parallel = reduce_grad(input_tensor, ParallelMode.PARALLEL_1D)

@@ -41,7 +41,7 @@ def colo_linear_1Dcol(input_tensor: ColoTensor, weight: ColoTensor, bias: Option
         spec=TensorSpec(distspec.shard(weight.spec.get_process_group(), [-1], [weight.spec.get_process_group_size()]),
-                        ParallelAction(ComputePattern.TP1D)))
+                        ComputeSpec(ComputePattern.TP1D)))
     return output.to_replicate()
colossalai/nn/parallel/layers/module_utils.py

 from typing import Dict
-from colossalai.tensor import ColoParameter, ParallelAction, TensorSpec
+from colossalai.tensor import ColoParameter, ComputeSpec, TensorSpec
 from . import ColoModule
 import torch

@@ -39,7 +39,7 @@ def check_colo_module(module: torch.nn.Module, recursive=True):
             if not isinstance(param, ColoParameter):
                 raise Exception(f'Invalid ColoParameter spec: {param} in {module} is not a ColoParameter.')
             if param.has_spec():
-                cur_compute_pattern = param.spec.parallel_action.compute_pattern
+                cur_compute_pattern = param.spec.compute_spec.compute_pattern
                 if compute_pattern is None:
                     compute_pattern = cur_compute_pattern
                 else:

@@ -79,11 +79,11 @@ def check_colo_module(module: torch.nn.Module, recursive=True):
             check_colo_module(submodule, recursive=True)


-def init_colo_module(module: torch.nn.Module, parallel_action: ParallelAction, recursive=True, mode='default'):
+def init_colo_module(module: torch.nn.Module, parallel_action: ComputeSpec, recursive=True, mode='default'):
     compute_pattern = parallel_action.compute_pattern
     if is_colo_module(module):
         # for each param
-        # set DistSpec and ParallelAction
+        # set DistSpec and ComputeSpec
         colo_module = get_colo_module(module)
         colo_module.register(compute_pattern)
         if not colo_module.has_compute_pattern_with_mode(compute_pattern, mode=mode):
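The init_colo_module parameter keeps the name parallel_action but is now annotated as a ComputeSpec. A short usage sketch, mirroring tests/test_tensor/test_hybrid_device.py further down; the model is assumed to be a torch.nn.Module built under ColoInitContext so that its parameters are ColoParameters:

from colossalai.tensor import ComputePattern, ComputeSpec
from colossalai.nn.parallel.layers import init_colo_module

compute_spec = ComputeSpec(ComputePattern.TP1D)
# "model" is an assumed, already-constructed module; mode='default' is the default shown in the signature above
init_colo_module(model, compute_spec, recursive=True, mode='default')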
colossalai/tensor/__init__.py

-from .spec import ComputePattern, ParallelAction, TensorSpec
+from .tensor_spec import TensorSpec
+from .compute_spec import ComputeSpec, ComputePattern
 from .colo_tensor import ColoTensor
 from .colo_parameter import ColoParameter
 from .utils import convert_parameter, named_params_with_colotensor
-from . import distspec
 from .dist_spec_mgr import DistSpecManager
 from .param_op_hook import ParamOpHook, ParamOpHookManager
 from .chunk import ChunkManager, TensorState
+from . import distspec

 __all__ = [
-    'ColoTensor', 'convert_parameter', 'ComputePattern', 'TensorSpec', 'ParallelAction', 'named_params_with_colotensor',
+    'ColoTensor', 'convert_parameter', 'ComputePattern', 'TensorSpec', 'ComputeSpec', 'named_params_with_colotensor',
     'ColoParameter', 'distspec', 'DistSpecManager', 'ParamOpHook', 'ParamOpHookManager', 'ChunkManager', 'TensorState'
 ]
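After this commit the public surface of colossalai.tensor exports ComputeSpec in place of ParallelAction; a quick sanity check (sketch, run against the post-commit package):

import colossalai.tensor as colo_tensor

assert 'ComputeSpec' in colo_tensor.__all__
assert 'ParallelAction' not in colo_tensor.__all__
from colossalai.tensor import ComputeSpec, ComputePattern, TensorSpec  # the names used throughout this diff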
colossalai/tensor/colo_parameter.py

-import torch
-from typing import Optional
-from copy import copy
 from colossalai.tensor.colo_tensor import ColoTensor
 from colossalai.tensor.const import TensorType
+import torch
+from colossalai.tensor import TensorSpec, distspec
+from copy import copy
+from colossalai.tensor.param_op_hook import ParamOpHookManager
+from typing import Optional


 def filter_args(func, *args):
colossalai/tensor/colo_tensor.py

@@ -66,7 +66,7 @@ class ColoTensor(torch.Tensor):
         self._tensor_spec = spec

     def has_spec(self) -> bool:
-        return self._tensor_spec.parallel_action is not None
+        return self._tensor_spec.compute_spec is not None

     def is_model_data(self) -> bool:
         return self._type == TensorType.MODEL
colossalai/tensor/compute_spec.py  (new file, mode 100644)

from enum import Enum


class ComputePattern(Enum):
    TP1D = 0
    TP2D = 1
    TP2P5D = 2
    TP3D = 3


class ComputeSpec(object):
    """ComputeSpec
    The Specification for compuattion pattern
    Args:
        compute_pattern (ComputePattern): an Enum instance for compute pattern.
    """

    def __init__(self, compute_pattern: ComputePattern) -> None:
        assert isinstance(compute_pattern, ComputePattern)
        self.compute_pattern = compute_pattern

    def __repr__(self):
        return f'compute pattern: {self.compute_pattern}'
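ComputeSpec is a thin wrapper over the ComputePattern enum; a minimal sketch of using it in isolation:

from colossalai.tensor import ComputeSpec, ComputePattern

cs = ComputeSpec(ComputePattern.TP1D)
assert cs.compute_pattern == ComputePattern.TP1D
print(cs)    # prints "compute pattern: ComputePattern.TP1D" per the __repr__ above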
colossalai/tensor/spec.py → colossalai/tensor/tensor_spec.py  (renamed)

 import torch.distributed as dist
-from enum import Enum
-from typing import List, Optional
+from typing import Optional
 from colossalai.tensor.distspec import _DistSpec, DistPlacementPattern
-
-
-class ComputePattern(Enum):
-    TP1D = 0
-    TP2D = 1
-    TP2P5D = 2
-    TP3D = 3
-
-
-class ParallelAction(object):
-
-    def __init__(self, compute_pattern: ComputePattern) -> None:
-        assert isinstance(compute_pattern, ComputePattern)
-        self.compute_pattern = compute_pattern
-
-    def __repr__(self):
-        return f'compute pattern: {self.compute_pattern}'
+from .compute_spec import ComputeSpec, ComputePattern


 class TensorSpec(object):

@@ -26,12 +9,12 @@ class TensorSpec(object):
     The specification of the ColoTensor.
     Args:
         dist_spec (_DistSpec): descriping the layout among processes.
-        parallel_action (Optional[ParallelAction], optional): actions conducted on the tensor after initialization if it's a model data tensor.
+        parallel_action (Optional[ComputeSpec], optional): actions conducted on the tensor after initialization if it's a model data tensor.
             Defaults to None.
     """

-    def __init__(self, dist_spec: _DistSpec, parallel_action: Optional[ParallelAction] = None):
-        self.parallel_action = parallel_action
+    def __init__(self, dist_spec: _DistSpec, compute_spec: Optional[ComputeSpec] = None):
+        self.compute_spec = compute_spec
         self.dist_spec = dist_spec

     def get_process_group(self):

@@ -58,7 +41,7 @@ class TensorSpec(object):
             and len(self.dist_spec.dims) == 1 and self.dist_spec.dims[0] == 0

     def has_compute_pattern(self, compute_pattern: ComputePattern):
-        return self.parallel_action.compute_pattern == compute_pattern
+        return self.compute_spec.compute_pattern == compute_pattern

     def __repr__(self):
-        return f'parallel action: {self.parallel_action}, dist_spec: {self.dist_spec}'
+        return f'parallel action: {self.compute_spec}, dist_spec: {self.dist_spec}'
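With the rename, TensorSpec carries the compute specification under compute_spec, which is also what ColoTensor.has_spec() and TensorSpec.has_compute_pattern() now look at. A minimal self-contained sketch; it assumes distspec.replicate() may be called without an explicit process group (the ops in this commit always pass one):

from colossalai.tensor import TensorSpec, ComputeSpec, ComputePattern, distspec

# replicated placement plus a 1D TP compute pattern (replicate() without a group: assumption)
spec = TensorSpec(distspec.replicate(), ComputeSpec(ComputePattern.TP1D))
assert spec.compute_spec is not None                      # the condition has_spec() now checks
assert spec.has_compute_pattern(ComputePattern.TP1D)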
docker/Dockerfile

@@ -14,4 +14,4 @@ RUN git clone https://github.com/hpcaitech/ColossalAI.git \
     && pip install -v --no-cache-dir .

 # install titans
-RUN pip install -no-cache-dir titans
+RUN pip install --no-cache-dir titans
tests/test_tensor/test_addmm_tp.py

@@ -5,7 +5,7 @@ import torch.nn as nn
 import torch.multiprocessing as mp
 from colossalai.tensor import ColoTensor
 from colossalai.tensor import distspec
-from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction, DistSpecManager
+from colossalai.tensor import TensorSpec, ComputePattern, ComputeSpec, DistSpecManager
 from colossalai.context import ParallelMode
 from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.utils import free_port

@@ -41,7 +41,7 @@ class Conv1D(nn.Module):

 def init_1d_row(weight, bias):
     spec = TensorSpec(distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
-                      ParallelAction(ComputePattern.TP1D))
+                      ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         weight.set_spec(spec)

@@ -49,7 +49,7 @@ def init_1d_row(weight, bias):

 def init_1d_col(weight, bias):
     spec = TensorSpec(distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
-                      ParallelAction(ComputePattern.TP1D))
+                      ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         weight.set_spec(spec)
         bias.set_spec(spec)
tests/test_tensor/test_embedding_bag_tp.py

@@ -11,14 +11,14 @@ import torch.multiprocessing as mp
 from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.utils import free_port
 from colossalai.core import global_context as gpc
-from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction, DistSpecManager
+from colossalai.tensor import TensorSpec, ComputePattern, ComputeSpec, DistSpecManager
 from _utils import tensor_equal, tensor_shard_equal


 def init_1d_col(weight):
     spec = TensorSpec(distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
-                      ParallelAction(ComputePattern.TP1D))
+                      ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         weight.set_spec(spec)
tests/test_tensor/test_embedding_tp.py

@@ -11,14 +11,14 @@ import torch.multiprocessing as mp
 from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.utils import free_port
 from colossalai.core import global_context as gpc
-from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction, DistSpecManager
+from colossalai.tensor import TensorSpec, ComputePattern, ComputeSpec, DistSpecManager
 from _utils import tensor_equal, tensor_shard_equal


 def init_1d_row(weight):
     spec = TensorSpec(distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
-                      ParallelAction(ComputePattern.TP1D))
+                      ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         weight.set_spec(spec)

@@ -26,7 +26,7 @@ def init_1d_row(weight):

 def init_1d_col(weight):
     spec = TensorSpec(distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
-                      ParallelAction(ComputePattern.TP1D))
+                      ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         weight.set_spec(spec)
tests/test_tensor/test_gpt.py

@@ -6,7 +6,7 @@ from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.utils.cuda import get_current_device
 from colossalai.utils import free_port
 from colossalai.utils.model.colo_init_context import ColoInitContext
-from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction, DistSpecManager, distspec
+from colossalai.tensor import TensorSpec, ComputePattern, ComputeSpec, DistSpecManager, distspec
 from colossalai.core import global_context as gpc
 from functools import partial
 from _utils import tensor_equal, tensor_shard_equal, set_seed

@@ -18,7 +18,7 @@ from colossalai.nn.parallel.data_parallel import ColoDDP

 def init_1d_row_spec(model):
     spec = TensorSpec(distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
-                      ParallelAction(ComputePattern.TP1D))
+                      ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         for n, p in model.named_parameters():
             if 'weight' in n and 'ln' not in n:

@@ -28,7 +28,7 @@ def init_1d_row_spec(model):

 def init_1d_col_spec(model):
     spec = TensorSpec(distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
-                      ParallelAction(ComputePattern.TP1D))
+                      ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         for n, p in model.named_parameters():
             if 'ln' not in n and ('weight' in n or 'bias' in n):
tests/test_tensor/test_hybrid_device.py

 from colossalai.utils import free_port, get_current_device
 from colossalai.utils.model.colo_init_context import ColoInitContext
 from colossalai.testing import rerun_if_address_is_in_use
-from colossalai.tensor import ComputePattern, ParallelAction
+from colossalai.tensor import ComputePattern, ComputeSpec
 from functools import partial
 from colossalai.core import global_context as gpc

@@ -46,7 +46,7 @@ def run_hybrid_device(use_ddp, mode):
     print(f'embedding weight size: {real_model.embed.weight.size()} | device: {real_model.embed.weight.device}')
     #print(f'linear weight size: {real_model.proj.weight.size()} | device: {real_model.proj.weight.device}')
-    parallel_action = ParallelAction(ComputePattern.TP1D)
+    parallel_action = ComputeSpec(ComputePattern.TP1D)
     init_colo_module(model, parallel_action, recursive=True, mode=mode)

     # use cpu gloo to handle embedding

@@ -63,6 +63,7 @@ def run_hybrid_device(use_ddp, mode):
     out.sum().backward()
     optimizer.step()

+
 def run_dist(rank, world_size, port, use_ddp, mode):
     if use_ddp and world_size == 1:
         return

@@ -71,6 +72,7 @@ def run_dist(rank, world_size, port, use_ddp, mode):
     colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     run_hybrid_device(use_ddp, mode)

+
 @pytest.mark.dist
 @pytest.mark.parametrize('world_size', [1, 4])
 @pytest.mark.parametrize('use_ddp', [False, True])

@@ -78,7 +80,7 @@ def run_dist(rank, world_size, port, use_ddp, mode):
 @rerun_if_address_is_in_use()
 # Working for simulate the embedding(CPU DP+TP) -> nn(GPU DP+TP)
 def _test_hybrid_device(world_size, use_ddp, mode):
-    run_func = partial(run_dist, world_size=world_size, port=free_port(), use_ddp=use_ddp, mode=mode)
+    run_func = partial(run_dist, world_size=world_size, port=free_port(), use_ddp=use_ddp, mode=mode)
     mp.spawn(run_func, nprocs=world_size)
tests/test_tensor/test_linear_tp.py

@@ -12,14 +12,14 @@ import torch.nn.functional as F
 from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.utils import free_port
 from colossalai.core import global_context as gpc
-from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction, DistSpecManager
+from colossalai.tensor import TensorSpec, ComputePattern, ComputeSpec, DistSpecManager
 from _utils import tensor_equal, tensor_shard_equal


 def init_1d_row(weight, bias):
     spec = TensorSpec(distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
-                      ParallelAction(ComputePattern.TP1D))
+                      ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         weight.set_spec(spec)

@@ -27,7 +27,7 @@ def init_1d_row(weight, bias):

 def init_1d_col(weight, bias):
     spec = TensorSpec(distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
-                      ParallelAction(ComputePattern.TP1D))
+                      ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         weight.set_spec(spec)
         bias.set_spec(spec)
tests/test_tensor/test_model.py

@@ -10,7 +10,7 @@ from colossalai.utils.cuda import get_current_device
 from colossalai.utils import free_port
 from colossalai.utils.model.colo_init_context import ColoInitContext
 from colossalai.tensor import distspec, TensorSpec, ComputePattern, \
-    ParallelAction, ColoTensor, DistSpecManager
+    ComputeSpec, ColoTensor, DistSpecManager
 from colossalai.context import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.nn.optimizer import ColoOptimizer

@@ -21,7 +21,7 @@ from _utils import tensor_equal, tensor_shard_equal, set_seed

 def init_1d_row_linear(weight):
     spec = TensorSpec(distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
-                      ParallelAction(ComputePattern.TP1D))
+                      ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         weight.set_spec(spec)

@@ -29,7 +29,7 @@ def init_1d_row_linear(weight):

 def init_1d_col_linear(weight):
     spec = TensorSpec(distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
-                      ParallelAction(ComputePattern.TP1D))
+                      ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         weight.set_spec(spec)

@@ -37,7 +37,7 @@ def init_1d_col_linear(weight):

 def init_1d_row_embedding(weight):
     spec = TensorSpec(distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
-                      ParallelAction(ComputePattern.TP1D))
+                      ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         weight.set_spec(spec)

@@ -45,7 +45,7 @@ def init_1d_row_embedding(weight):

 def init_1d_col_embedding(weight):
     spec = TensorSpec(distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
-                      ParallelAction(ComputePattern.TP1D))
+                      ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         weight.set_spec(spec)
tests/test_tensor/test_module_spec.py

@@ -5,7 +5,7 @@ from functools import partial
 import torch
 import torch.multiprocessing as mp

-from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction
+from colossalai.tensor import TensorSpec, ComputePattern, ComputeSpec
 from colossalai.nn.parallel.layers import init_colo_module, check_colo_module
 from _utils import tensor_equal, tensor_shard_equal, set_seed

@@ -40,7 +40,7 @@ def run_model_with_spec(mode, model_name):
     for p1, p2 in zip(model.parameters(), model_seq.parameters()):
         p2.data.copy_(p1.data)

-    parallel_action = ParallelAction(ComputePattern.TP1D)
+    parallel_action = ComputeSpec(ComputePattern.TP1D)
     # Not all layers in Bert can be mod by 4.
     # e.g. row shard for all layers is invalid because the first dim of some layer is the classification type size 2.
     if 'bert' == model_name:

@@ -114,7 +114,7 @@ def run_linear_with_spec(mode):
     model_handy = copy(model)

-    parallel_action = ParallelAction(ComputePattern.TP1D)
+    parallel_action = ComputeSpec(ComputePattern.TP1D)
     init_colo_module(model, parallel_action, recursive=True, mode=mode)

     x = torch.rand(2, 4).cuda()

@@ -148,7 +148,7 @@ def run_check_shared_param():
         model = BertForMaskedLM(config)

     model = model.cuda()
-    parallel_action = ParallelAction(ComputePattern.TP1D)
+    parallel_action = ComputeSpec(ComputePattern.TP1D)
     # model.cls.predictions.decoder and model.cls.predictions share the bias, so they should have the same spec
     assert len(model.cls.predictions.decoder.bias.shared_param_modules) == 2
     # They are all Linear, so both row is allowed. This should pass check.

@@ -156,7 +156,7 @@ def run_check_shared_param():
     # This should be detected by check because you can not set weight as row while set bias as col.
     col_spec = TensorSpec(distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
-                          ParallelAction(ComputePattern.TP1D))
+                          ComputeSpec(ComputePattern.TP1D))
     model.cls.predictions.bias.set_spec(col_spec)
     try:
         check_colo_module(model.cls.predictions.decoder, recursive=False)
tests/test_tensor/test_zero_optim.py

@@ -19,7 +19,7 @@ from colossalai.zero import ZeroOptimizer
 from colossalai.testing import parameterize
 from colossalai.amp import convert_to_apex_amp
 from colossalai.gemini.gemini_mgr import GeminiManager
-from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction, DistSpecManager, distspec
+from colossalai.tensor import TensorSpec, ComputePattern, ComputeSpec, DistSpecManager, distspec


 def check_param_equal(model, torch_model):

@@ -47,7 +47,7 @@ def run_fwd_bwd(model, criterion, optimizer, input_ids, attn_mask):

 def init_1d_row_spec(model):
     spec = TensorSpec(distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
-                      ParallelAction(ComputePattern.TP1D))
+                      ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         for n, p in model.named_parameters():
             if 'weight' in n and 'ln' not in n:

@@ -57,7 +57,7 @@ def init_1d_row_spec(model):

 def init_1d_col_spec(model):
     spec = TensorSpec(distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
-                      ParallelAction(ComputePattern.TP1D))
+                      ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         for n, p in model.named_parameters():
             if 'ln' not in n and ('weight' in n or 'bias' in n):