ColossalAI (OpenDAS) commit c463f8ad (unverified)
Authored Jun 29, 2022 by Jiarui Fang, committed by GitHub on Jun 29, 2022
[tensor] remove gpc in tensor tests (#1186)
Parent: 372f7914

Showing 4 changed files, with 26 additions and 20 deletions (+26 / -20):
colossalai/tensor/__init__.py       +1 / -1
colossalai/tensor/colo_tensor.py    +1 / -1
colossalai/tensor/process_group.py  +14 / -6
tests/test_tensor/test_model.py     +10 / -12
colossalai/tensor/__init__.py

+from .process_group import ProcessGroup
 from .tensor_spec import TensorSpec
 from .compute_spec import ComputeSpec, ComputePattern
 from .colo_tensor import ColoTensor
@@ -6,7 +7,6 @@ from .utils import convert_parameter, named_params_with_colotensor
 from .dist_spec_mgr import DistSpecManager
 from .param_op_hook import ParamOpHook, ParamOpHookManager
 from . import distspec
-from .process_group import ProcessGroup
 
 __all__ = [
     'ColoTensor', 'convert_parameter', 'ComputePattern', 'TensorSpec', 'ComputeSpec', 'named_params_with_colotensor',
colossalai/tensor/colo_tensor.py

@@ -30,7 +30,7 @@ class ColoTensor(torch.Tensor):
     1. directly init.
     >>> colo_t1 = ColoTensor(torch.randn(2,3), spec = TensorSpec(distspec.replicate())
     >>> # If initializaed in a shard model, the tensor passed in is one shard of the global tensor.
-    >>> shard_spec = distspec.shard(process_group=gpc.get_group(ParallelMode.DATA),
+    >>> shard_spec = distspec.shard(process_group=ProcessGroup(tp=world_size),
     >>>                             dims=[0],
     >>>                             num_partitions=[world_size])
     >>> tensor_spec = TensorSpec(shard_spec)
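Assembled from the updated docstring, a self-contained sketch of building a sharded ColoTensor without gpc. Note that the docstring writes ProcessGroup(tp=world_size) while the constructor in process_group.py below takes tp_degree, so tp_degree is used here; the sketch also assumes distspec.shard accepts the new ProcessGroup, as the docstring suggests, and that torch.distributed is already initialized:

    import torch
    from colossalai.tensor import ColoTensor, TensorSpec, ProcessGroup, distspec

    world_size = torch.distributed.get_world_size()
    pg = ProcessGroup(tp_degree=world_size)          # replaces gpc.get_group(ParallelMode.DATA)
    shard_spec = distspec.shard(process_group=pg,
                                dims=[0],
                                num_partitions=[world_size])
    colo_t1 = ColoTensor(torch.randn(2, 3), spec=TensorSpec(shard_spec))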
colossalai/tensor/process_group.py

@@ -5,7 +5,7 @@ from typing import List, Optional
 class ProcessGroup:
     """
     Process Group contains group partition for Tensor Parallel and Data Parallel.
-    WARNING, the ProcessGroup must be used after torch.distributed.initialize()
+    NOTE, the ProcessGroup must be used after torch.distributed.initialize()
     args:
         rank: the global rank of the current process.
         ranks: List[int], a list of rank id belongings to this process group.
@@ -15,16 +15,24 @@ class ProcessGroup:
     """
 
     def __init__(self,
-                 rank: int,
-                 ranks: List[int],
+                 rank: Optional[int] = None,
+                 ranks: Optional[List[int]] = None,
                  backend: str = 'nccl',
                  tp_degree: Optional[int] = None,
                  dp_degree: Optional[int] = None) -> None:
-        self._rank = rank
-        self._rank_list = ranks
+        assert torch.distributed.is_initialized(), f"ProcessGroup must be used after distributed initialized"
+        if rank is None:
+            self._rank = torch.distributed.get_rank()
+        else:
+            self._rank = rank
+        if ranks is None:
+            self._rank_list = list(range(torch.distributed.get_world_size()))
+        else:
+            self._rank_list = ranks
         self._backend = backend
         self._world_size = len(self._rank_list)
-        assert torch.distributed.is_initialized(), f"ProcessGroup must be used after distributed initialized"
         if dp_degree is None and tp_degree is None:
             self._dp_degree = self._world_size
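Since rank and ranks now default to values taken from torch.distributed, both call styles seen in the tests stay valid. A small sketch, assuming the process group has already been initialized (for example via colossalai.launch or torch.distributed.init_process_group):

    import torch
    from colossalai.tensor import ProcessGroup

    world_size = torch.distributed.get_world_size()

    # new style: rank and ranks are picked up from torch.distributed
    pg = ProcessGroup(tp_degree=world_size)

    # explicit style still works, as run_1d_row_tp below keeps doing
    pg_explicit = ProcessGroup(torch.distributed.get_rank(), list(range(world_size)), tp_degree=world_size)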
tests/test_tensor/test_model.py

@@ -11,11 +11,9 @@ from colossalai.utils import free_port
 from colossalai.utils.model.colo_init_context import ColoInitContext
 from colossalai.tensor import distspec, TensorSpec, ComputePattern, \
     ComputeSpec, ColoTensor, DistSpecManager, ProcessGroup
-from colossalai.context import ParallelMode
-from colossalai.core import global_context as gpc
 from colossalai.nn.optimizer import ColoOptimizer
 from functools import partial
-from _utils import tensor_equal, tensor_shard_equal, set_seed
+from _utils import tensor_shard_equal, set_seed
 
 
 def init_1d_row_linear(weight, pg: ProcessGroup):
@@ -50,7 +48,7 @@ def run_1d_hybrid_tp(model_name):
     # A simple net with two stacked nn.Linear
     get_components_func = non_distributed_component_funcs.get_callable(model_name)
     model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
-    rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
+    rank = torch.distributed.get_rank()
     set_seed(1)
     with ColoInitContext(device=get_current_device()):
@@ -65,9 +63,9 @@ def run_1d_hybrid_tp(model_name):
         for p1, p2 in zip(model.parameters(), model_torch.parameters()):
             p2.data.copy_(p1.data)
 
-    rank = gpc.get_local_rank(ParallelMode.GLOBAL)
-    world_size = gpc.get_world_size(ParallelMode.GLOBAL)
-    pg = ProcessGroup(rank, list(range(world_size)), tp_degree=world_size)
+    rank = torch.distributed.get_rank()
+    world_size = torch.distributed.get_world_size()
+    pg = ProcessGroup(tp_degree=world_size)
     if 'bert' == model_name:
         for name, p in model.named_parameters():
             if not isinstance(p, ColoTensor):
@@ -214,14 +212,14 @@ def run_1d_row_tp(model_name: str):
     # A simple net with two stacked nn.Linear
     get_components_func = non_distributed_component_funcs.get_callable(model_name)
     model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
-    rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
+    rank = torch.distributed.get_rank()
     set_seed(1)
     with ColoInitContext(device=get_current_device()):
         model = model_builder(checkpoint=True)
 
-    rank = gpc.get_local_rank(ParallelMode.GLOBAL)
-    world_size = gpc.get_world_size(ParallelMode.GLOBAL)
+    rank = torch.distributed.get_rank()
+    world_size = torch.distributed.get_world_size()
     pg = ProcessGroup(rank, list(range(world_size)), tp_degree=world_size)
     set_seed(1)
@@ -243,8 +241,8 @@ def run_1d_row_tp(model_name: str):
         data = data.to(get_current_device())
         label = label.to(get_current_device())
-        torch.distributed.broadcast(data, 0, group=gpc.get_group(ParallelMode.PARALLEL_1D))
-        torch.distributed.broadcast(label, 0, group=gpc.get_group(ParallelMode.PARALLEL_1D))
+        torch.distributed.broadcast(data, 0, group=pg.tp_process_group())
+        torch.distributed.broadcast(label, 0, group=pg.tp_process_group())
         # Bcast rank0 data to all processes
         if criterion:
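The broadcast change is the representative pattern: the test obtains the tensor-parallel group from the ProcessGroup object instead of from gpc. A condensed sketch of that loop body, assuming data, label and pg are set up as in the diff above (the helper name is illustrative):

    import torch

    def broadcast_batch(data: torch.Tensor, label: torch.Tensor, pg) -> None:
        # Bcast rank0 data to all processes in the tensor-parallel group
        torch.distributed.broadcast(data, 0, group=pg.tp_process_group())
        torch.distributed.broadcast(label, 0, group=pg.tp_process_group())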