OpenDAS / ColossalAI

Commit 9bcd2fd4 (Unverified), authored Jul 11, 2022 by Jiarui Fang; committed by GitHub on Jul 11, 2022

[tensor] a shorter shard and replicate spec (#1245)

Parent: 2699dfbb
Showing 5 changed files with 14 additions and 14 deletions (+14 -14):

- tests/test_tensor/test_module_spec.py (+3 -3)
- tests/test_tensor/test_op.py (+2 -2)
- tests/test_tensor/test_tensor.py (+4 -4)
- tests/test_tensor/test_zero_optim.py (+3 -3)
- tests/test_utils/test_colo_checkpoint.py (+2 -2)
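The rename itself is small: every distspec.shard(...) call in the tests becomes a ShardSpec(...), and ReplicaSpec is added to the imports as the shorthand for replicated placement. The following is a minimal sketch of the two spellings, assuming a tensor-parallel ProcessGroup is already initialized; the tp_degree value is illustrative, and the zero-argument ReplicaSpec() call is an assumption, since the hunks in this commit only import ReplicaSpec without calling it.

from colossalai.tensor import ProcessGroup, ShardSpec, ReplicaSpec, distspec

pg = ProcessGroup(tp_degree=2)    # illustrative degree; the real tests derive it from the launch

# Before this commit: the longer factory calls from the distspec module.
row_shard_old = distspec.shard(dims=[0], num_partitions=[pg.tp_world_size()])
replica_old = distspec.replicate()

# After this commit: the shorter spec classes exported by colossalai.tensor.
row_shard_new = ShardSpec([0], [pg.tp_world_size()])
replica_new = ReplicaSpec()    # assumed zero-argument, mirroring distspec.replicate()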
tests/test_tensor/test_module_spec.py

@@ -5,7 +5,7 @@ from functools import partial
 import torch
 import torch.multiprocessing as mp
-from colossalai.tensor import ColoTensorSpec, ComputePattern, ComputeSpec
+from colossalai.tensor import ColoTensorSpec, ComputePattern, ComputeSpec, ShardSpec, ReplicaSpec
 from colossalai.nn.parallel.layers import init_colo_module, check_colo_module
 from _utils import tensor_equal, tensor_shard_equal, set_seed
@@ -13,7 +13,7 @@ import colossalai
 from colossalai.utils.cuda import get_current_device
 from colossalai.utils.model.colo_init_context import ColoInitContext
-from colossalai.tensor import distspec, ProcessGroup
+from colossalai.tensor import distspec, ProcessGroup, ReplicaSpec
 from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.utils import free_port
@@ -159,7 +159,7 @@ def run_check_shared_param():
     # They are all Linear, so both row is allowed. This should pass check.
     init_colo_module(model, compute_spec, pg=pg, recursive=True, mode='row')
     # This should be detected by check because you can not set weight as row while set bias as col.
-    col_spec = (distspec.shard([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
+    col_spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
     # TODO(jiaruifang) optimize this line
     if not model.cls.predictions.bias.has_initialized:
tests/test_tensor/test_op.py

@@ -4,7 +4,7 @@ import colossalai
 import torch.nn.functional as F
 import torch.multiprocessing as mp
 from functools import partial
-from colossalai.tensor import ColoTensor, ProcessGroup, ColoTensorSpec
+from colossalai.tensor import ColoTensor, ProcessGroup, ColoTensorSpec, ShardSpec
 from colossalai.utils import get_current_device
 from torch.nn import Parameter
 from colossalai.testing import rerun_if_address_is_in_use
@@ -47,7 +47,7 @@ def check_element_wise_ops():
     world_size = torch.distributed.get_world_size()
     pg = ProcessGroup(tp_degree=world_size)
     t = torch.rand(2, 2)
-    x = ColoTensor(t, spec=ColoTensorSpec(pg, distspec.shard([0], [pg.tp_world_size()])))
+    x = ColoTensor(t, spec=ColoTensorSpec(pg, ShardSpec([0], [pg.tp_world_size()])))
     check_spec_eq(x, x.cuda())
     assert torch.equal(x.cuda(), t.cuda())
tests/test_tensor/test_tensor.py

@@ -10,7 +10,7 @@ from colossalai.core import global_context as gpc
 import torch.multiprocessing as mp
 from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.utils import free_port
-from colossalai.tensor import distspec, ColoTensor, ProcessGroup
+from colossalai.tensor import distspec, ColoTensor, ProcessGroup, ShardSpec, ReplicaSpec
 from functools import partial
@@ -55,7 +55,7 @@ def _run_operand(world_size):
     pg = ProcessGroup(tp_degree=world_size)
     t = ColoTensor.from_torch_tensor(t_ref.clone(), ColoTensorSpec(pg))
-    t.set_dist_spec(distspec.shard([0], [world_size]))
+    t.set_dist_spec(ShardSpec([0], [world_size]))
     t_new = torch.zeros_like(t)
     assert isinstance(t_new, ColoTensor)
     assert t_new.is_sharded()
@@ -69,7 +69,7 @@ def _run_view(world_size):
     rank = gpc.get_global_rank()
     pg = ProcessGroup(rank, list(range(world_size)), tp_degree=world_size)
     t = ColoTensor.from_torch_tensor(
-        t_ref, ColoTensorSpec(pg, dist_attr=distspec.shard(dims=[0], num_partitions=[pg.tp_world_size()])))
+        t_ref, ColoTensorSpec(pg, dist_attr=ShardSpec(dims=[0], num_partitions=[pg.tp_world_size()])))
     assert t.size_global()[0] == 4 * world_size
     assert t.size_global(1) == 5
@@ -82,7 +82,7 @@ def _run_view(world_size):
 def _run_tensor_shard_init(world_size):
     t_ref = torch.randn(4, 5)
     pg = ProcessGroup(tp_degree=world_size)
-    shard_attr = distspec.shard(dims=[0], num_partitions=[pg.tp_world_size()])
+    shard_attr = ShardSpec(dims=[0], num_partitions=[pg.tp_world_size()])
     tensor_spec = ColoTensorSpec(pg, dist_attr=shard_attr)
     t = ColoTensor.from_torch_tensor(t_ref.clone(), tensor_spec)
     t.set_dist_spec(distspec.replicate())
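Note that the new lines in the hunks above spell ShardSpec two ways: positionally in _run_operand and with keyword arguments in _run_view and _run_tensor_shard_init. Both appear to construct the same row-sharded placement, with the keyword form mirroring the old distspec.shard(dims=..., num_partitions=...) signature. A small sketch, using an illustrative partition count of 2:

from colossalai.tensor import ShardSpec

spec_positional = ShardSpec([0], [2])                    # as in _run_operand
spec_keyword = ShardSpec(dims=[0], num_partitions=[2])   # as in _run_view and _run_tensor_shard_init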
tests/test_tensor/test_zero_optim.py

@@ -17,7 +17,7 @@ from colossalai.zero import ZeroOptimizer
 from colossalai.testing import parameterize
 from colossalai.amp import convert_to_apex_amp
 from colossalai.gemini.gemini_mgr import GeminiManager
-from colossalai.tensor import ColoTensorSpec, ComputePattern, ComputeSpec, DistSpecManager, distspec, ProcessGroup
+from colossalai.tensor import ColoTensorSpec, ShardSpec, ComputePattern, ComputeSpec, DistSpecManager, ProcessGroup
 def check_param_equal(model, torch_model, pg: ProcessGroup):
@@ -45,7 +45,7 @@ def run_fwd_bwd(model, criterion, optimizer, input_ids, attn_mask):
 def init_1d_row_spec(model, pg: ProcessGroup):
-    spec = (distspec.shard([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
+    spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         for n, p in model.named_parameters():
             if 'weight' in n and 'ln' not in n:
@@ -53,7 +53,7 @@ def init_1d_row_spec(model, pg: ProcessGroup):
 def init_1d_col_spec(model, pg: ProcessGroup):
-    spec = (distspec.shard([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
+    spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         for n, p in model.named_parameters():
             if 'ln' not in n and ('weight' in n or 'bias' in n):
tests/test_utils/test_colo_checkpoint.py

@@ -16,7 +16,7 @@ from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.utils.cuda import get_current_device
 from colossalai.utils import free_port
 from colossalai.utils.model.colo_init_context import ColoInitContext
-from colossalai.tensor import ComputePattern, ComputeSpec, DistSpecManager, distspec, ProcessGroup
+from colossalai.tensor import ComputePattern, ComputeSpec, DistSpecManager, ShardSpec, ProcessGroup
 from colossalai.nn.parallel.data_parallel import ColoDDP
 from colossalai.utils.checkpoint import save_checkpoint, load_checkpoint
 from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
@@ -81,7 +81,7 @@ class MLP(nn.Module):
 def init_1d_row_for_linear_weight_spec(model, pg: ProcessGroup):
-    spec = (distspec.shard([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
+    spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         for n, p in model.named_parameters():
             if 'weight' in n: