Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
060b917d
Unverified
Commit
060b917d
authored
Jul 04, 2022
by
Jiarui Fang
Committed by
GitHub
Jul 04, 2022
Browse files
[refactor] remove gpc dependency in colotensor's _ops (#1189)
parent
abf6a262
Changes
33
Hide whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
157 additions
and
234 deletions
+157
-234
tests/test_tensor/test_dist_spec_mgr.py
tests/test_tensor/test_dist_spec_mgr.py
+3
-3
tests/test_tensor/test_embedding_bag_tp.py
tests/test_tensor/test_embedding_bag_tp.py
+7
-10
tests/test_tensor/test_embedding_tp.py
tests/test_tensor/test_embedding_tp.py
+14
-17
tests/test_tensor/test_gpt.py
tests/test_tensor/test_gpt.py
+31
-23
tests/test_tensor/test_hybrid_device.py
tests/test_tensor/test_hybrid_device.py
+0
-88
tests/test_tensor/test_linear_tp.py
tests/test_tensor/test_linear_tp.py
+9
-12
tests/test_tensor/test_model.py
tests/test_tensor/test_model.py
+13
-15
tests/test_tensor/test_module_spec.py
tests/test_tensor/test_module_spec.py
+28
-25
tests/test_tensor/test_op.py
tests/test_tensor/test_op.py
+4
-4
tests/test_tensor/test_tensor.py
tests/test_tensor/test_tensor.py
+13
-8
tests/test_tensor/test_zero_optim.py
tests/test_tensor/test_zero_optim.py
+29
-25
tests/test_zero/test_sharded_optim_state_dict.py
tests/test_zero/test_sharded_optim_state_dict.py
+3
-3
tests/test_zero/test_zero_optim_state_dict.py
tests/test_zero/test_zero_optim_state_dict.py
+3
-1
No files found.
tests/test_tensor/test_dist_spec_mgr.py
View file @
060b917d
...
@@ -7,12 +7,12 @@ import torch.multiprocessing as mp
...
@@ -7,12 +7,12 @@ import torch.multiprocessing as mp
from
torch.distributed.distributed_c10d
import
_get_default_group
from
torch.distributed.distributed_c10d
import
_get_default_group
from
colossalai.testing
import
rerun_if_address_is_in_use
from
colossalai.testing
import
rerun_if_address_is_in_use
from
colossalai.utils
import
free_port
from
colossalai.utils
import
free_port
from
colossalai.tensor
import
DistSpecManager
,
distspec
from
colossalai.tensor
import
DistSpecManager
,
distspec
,
ProcessGroup
from
functools
import
partial
from
functools
import
partial
def
run
():
def
run
():
group
=
_get_default_group
()
group
=
ProcessGroup
(
tp_degree
=
dist
.
get_world_size
()
)
rank
=
dist
.
get_rank
()
rank
=
dist
.
get_rank
()
size
=
dist
.
get_world_size
()
size
=
dist
.
get_world_size
()
depth
=
int
(
math
.
sqrt
(
size
))
depth
=
int
(
math
.
sqrt
(
size
))
...
@@ -34,7 +34,7 @@ def run():
...
@@ -34,7 +34,7 @@ def run():
def
check_mem
():
def
check_mem
():
group
=
_get_default_group
()
group
=
ProcessGroup
(
tp_degree
=
dist
.
get_world_size
()
)
size
=
dist
.
get_world_size
()
size
=
dist
.
get_world_size
()
assert
torch
.
cuda
.
memory_allocated
()
==
0
assert
torch
.
cuda
.
memory_allocated
()
==
0
x
=
torch
.
rand
(
32
,
32
).
cuda
()
x
=
torch
.
rand
(
32
,
32
).
cuda
()
...
...
tests/test_tensor/test_embedding_bag_tp.py
View file @
060b917d
import
torch
import
torch
from
colossalai.context.parallel_mode
import
ParallelMode
from
colossalai.tensor
import
distspec
,
ColoParameter
from
colossalai.tensor
import
ColoTensor
,
distspec
,
ColoParameter
from
torch.nn
import
functional
as
F
from
torch.nn
import
functional
as
F
from
functools
import
partial
from
functools
import
partial
...
@@ -10,23 +9,21 @@ import torch
...
@@ -10,23 +9,21 @@ import torch
import
torch.multiprocessing
as
mp
import
torch.multiprocessing
as
mp
from
colossalai.testing
import
rerun_if_address_is_in_use
from
colossalai.testing
import
rerun_if_address_is_in_use
from
colossalai.utils
import
free_port
from
colossalai.utils
import
free_port
from
colossalai.core
import
global_context
as
gpc
from
colossalai.tensor
import
TensorSpec
,
ComputePattern
,
ComputeSpec
,
DistSpecManager
,
ProcessGroup
from
colossalai.tensor
import
TensorSpec
,
ComputePattern
,
ComputeSpec
,
DistSpecManager
from
_utils
import
tensor_equal
,
tensor_shard_equal
from
_utils
import
tensor_equal
,
tensor_shard_equal
def
init_1d_col
(
weight
):
def
init_1d_col
(
weight
,
pg
:
ProcessGroup
):
spec
=
TensorSpec
(
spec
=
TensorSpec
(
distspec
.
shard
(
pg
,
[
-
1
],
[
pg
.
tp_world_size
()]),
ComputeSpec
(
ComputePattern
.
TP1D
))
distspec
.
shard
(
gpc
.
get_group
(
ParallelMode
.
PARALLEL_1D
),
[
-
1
],
[
gpc
.
get_world_size
(
ParallelMode
.
PARALLEL_1D
)]),
ComputeSpec
(
ComputePattern
.
TP1D
))
with
DistSpecManager
.
no_grad
():
with
DistSpecManager
.
no_grad
():
weight
.
set_tensor_spec
(
spec
)
weight
.
set_tensor_spec
(
spec
)
def
run_with_spec
(
spec_init_func
):
def
run_with_spec
(
spec_init_func
):
pg
=
ProcessGroup
(
tp_degree
=
torch
.
distributed
.
get_world_size
())
model
=
torch
.
nn
.
EmbeddingBag
(
10
,
4
).
cuda
()
model
=
torch
.
nn
.
EmbeddingBag
(
10
,
4
).
cuda
()
weight
=
ColoParameter
(
model
.
weight
.
clone
())
weight
=
ColoParameter
(
model
.
weight
.
clone
())
spec_init_func
(
weight
)
spec_init_func
(
weight
,
pg
)
inputs
=
torch
.
tensor
([
1
,
2
,
4
,
5
,
4
,
3
,
2
,
9
]).
cuda
()
inputs
=
torch
.
tensor
([
1
,
2
,
4
,
5
,
4
,
3
,
2
,
9
]).
cuda
()
offsets
=
torch
.
tensor
([
0
,
4
]).
cuda
()
offsets
=
torch
.
tensor
([
0
,
4
]).
cuda
()
out
=
model
(
inputs
,
offsets
=
offsets
)
out
=
model
(
inputs
,
offsets
=
offsets
)
...
@@ -35,7 +32,7 @@ def run_with_spec(spec_init_func):
...
@@ -35,7 +32,7 @@ def run_with_spec(spec_init_func):
grad
=
torch
.
rand_like
(
out
)
grad
=
torch
.
rand_like
(
out
)
out
.
backward
(
grad
)
out
.
backward
(
grad
)
colo_out
.
backward
(
grad
)
colo_out
.
backward
(
grad
)
assert
tensor_shard_equal
(
model
.
weight
.
grad
,
weight
.
grad
)
assert
tensor_shard_equal
(
model
.
weight
.
grad
,
weight
.
grad
,
pg
.
tp_local_rank
(),
pg
.
tp_world_size
()
)
def
run_dist
(
rank
,
world_size
,
port
):
def
run_dist
(
rank
,
world_size
,
port
):
...
...
tests/test_tensor/test_embedding_tp.py
View file @
060b917d
import
torch
import
torch
from
colossalai.context.parallel_mode
import
ParallelMode
from
colossalai.tensor
import
ColoTensor
,
distspec
from
colossalai.tensor
import
ColoTensor
,
distspec
from
torch.nn
import
functional
as
F
from
torch.nn
import
functional
as
F
from
functools
import
partial
from
functools
import
partial
...
@@ -11,30 +10,26 @@ import torch.multiprocessing as mp
...
@@ -11,30 +10,26 @@ import torch.multiprocessing as mp
from
colossalai.testing
import
rerun_if_address_is_in_use
from
colossalai.testing
import
rerun_if_address_is_in_use
from
colossalai.utils
import
free_port
from
colossalai.utils
import
free_port
from
colossalai.core
import
global_context
as
gpc
from
colossalai.core
import
global_context
as
gpc
from
colossalai.tensor
import
TensorSpec
,
ComputePattern
,
ComputeSpec
,
DistSpecManager
from
colossalai.tensor
import
TensorSpec
,
ComputePattern
,
ComputeSpec
,
DistSpecManager
,
ProcessGroup
from
_utils
import
tensor_equal
,
tensor_shard_equal
from
_utils
import
tensor_equal
,
tensor_shard_equal
def
init_1d_row
(
weight
):
def
init_1d_row
(
weight
,
pg
:
ProcessGroup
):
spec
=
TensorSpec
(
spec
=
TensorSpec
(
distspec
.
shard
(
pg
,
[
0
],
[
pg
.
tp_world_size
()]),
ComputeSpec
(
ComputePattern
.
TP1D
))
distspec
.
shard
(
gpc
.
get_group
(
ParallelMode
.
PARALLEL_1D
),
[
0
],
[
gpc
.
get_world_size
(
ParallelMode
.
PARALLEL_1D
)]),
ComputeSpec
(
ComputePattern
.
TP1D
))
with
DistSpecManager
.
no_grad
():
with
DistSpecManager
.
no_grad
():
weight
.
set_tensor_spec
(
spec
)
weight
.
set_tensor_spec
(
spec
)
def
init_1d_col
(
weight
):
def
init_1d_col
(
weight
,
pg
:
ProcessGroup
):
spec
=
TensorSpec
(
spec
=
TensorSpec
(
distspec
.
shard
(
pg
,
[
-
1
],
[
pg
.
tp_world_size
()]),
ComputeSpec
(
ComputePattern
.
TP1D
))
distspec
.
shard
(
gpc
.
get_group
(
ParallelMode
.
PARALLEL_1D
),
[
-
1
],
[
gpc
.
get_world_size
(
ParallelMode
.
PARALLEL_1D
)]),
ComputeSpec
(
ComputePattern
.
TP1D
))
with
DistSpecManager
.
no_grad
():
with
DistSpecManager
.
no_grad
():
weight
.
set_tensor_spec
(
spec
)
weight
.
set_tensor_spec
(
spec
)
def
run_with_spec
(
spec_init_func
):
def
run_with_spec
(
spec_init_func
,
pg
:
ProcessGroup
):
model
=
torch
.
nn
.
Embedding
(
12
,
32
).
cuda
()
model
=
torch
.
nn
.
Embedding
(
12
,
32
).
cuda
()
weight
=
ColoTensor
(
torch
.
nn
.
Parameter
(
model
.
weight
.
detach
()))
weight
=
ColoTensor
(
torch
.
nn
.
Parameter
(
model
.
weight
.
detach
()))
spec_init_func
(
weight
)
spec_init_func
(
weight
,
pg
)
x
=
torch
.
tensor
((
0
,
3
,
6
,
9
)).
cuda
()
x
=
torch
.
tensor
((
0
,
3
,
6
,
9
)).
cuda
()
out
=
model
(
x
)
out
=
model
(
x
)
colo_out
=
F
.
embedding
(
x
,
weight
)
colo_out
=
F
.
embedding
(
x
,
weight
)
...
@@ -42,14 +37,16 @@ def run_with_spec(spec_init_func):
...
@@ -42,14 +37,16 @@ def run_with_spec(spec_init_func):
grad
=
torch
.
rand_like
(
out
)
grad
=
torch
.
rand_like
(
out
)
out
.
backward
(
grad
)
out
.
backward
(
grad
)
colo_out
.
backward
(
grad
)
colo_out
.
backward
(
grad
)
assert
tensor_shard_equal
(
model
.
weight
.
grad
,
weight
.
grad
)
# compare grad inside a TP group
assert
tensor_shard_equal
(
model
.
weight
.
grad
,
weight
.
grad
,
pg
.
tp_local_rank
(),
pg
.
tp_world_size
())
def
run_dist
(
rank
,
world_size
,
port
):
def
run_dist
(
rank
,
world_size
,
port
):
config
=
dict
(
parallel
=
dict
(
tensor
=
dict
(
mode
=
"1d"
,
size
=
world_size
),))
# config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),))
colossalai
.
launch
(
config
=
config
,
rank
=
rank
,
world_size
=
world_size
,
host
=
'localhost'
,
port
=
port
,
backend
=
'nccl'
)
colossalai
.
launch
(
config
=
{},
rank
=
rank
,
world_size
=
world_size
,
host
=
'localhost'
,
port
=
port
,
backend
=
'nccl'
)
run_with_spec
(
init_1d_row
)
pg
=
ProcessGroup
(
tp_degree
=
world_size
)
run_with_spec
(
init_1d_col
)
run_with_spec
(
init_1d_row
,
pg
)
run_with_spec
(
init_1d_col
,
pg
)
@
pytest
.
mark
.
dist
@
pytest
.
mark
.
dist
...
...
tests/test_tensor/test_gpt.py
View file @
060b917d
import
pytest
import
pytest
import
colossalai
import
colossalai
from
colossalai.context.parallel_mode
import
ParallelMode
import
torch.multiprocessing
as
mp
import
torch.multiprocessing
as
mp
from
colossalai.testing
import
rerun_if_address_is_in_use
from
colossalai.testing
import
rerun_if_address_is_in_use
from
colossalai.utils.cuda
import
get_current_device
from
colossalai.utils.cuda
import
get_current_device
from
colossalai.utils
import
free_port
from
colossalai.utils
import
free_port
from
colossalai.utils.model.colo_init_context
import
ColoInitContext
from
colossalai.utils.model.colo_init_context
import
ColoInitContext
from
colossalai.tensor
import
TensorSpec
,
ComputePattern
,
ComputeSpec
,
DistSpecManager
,
distspec
from
colossalai.tensor
import
TensorSpec
,
ComputePattern
,
ComputeSpec
,
DistSpecManager
,
distspec
,
ProcessGroup
from
colossalai.core
import
global_context
as
gpc
from
functools
import
partial
from
functools
import
partial
from
_utils
import
tensor_equal
,
tensor_shard_equal
,
set_seed
from
_utils
import
tensor_equal
,
tensor_shard_equal
,
set_seed
from
tests.components_to_test.registry
import
non_distributed_component_funcs
from
tests.components_to_test.registry
import
non_distributed_component_funcs
import
torch
from
torch.nn.parallel
import
DistributedDataParallel
as
DDP
from
torch.nn.parallel
import
DistributedDataParallel
as
DDP
from
colossalai.nn.parallel.data_parallel
import
ColoDDP
from
colossalai.nn.parallel.data_parallel
import
ColoDDP
from
colossalai.core
import
global_context
as
gpc
from
colossalai.context.parallel_mode
import
ParallelMode
def
init_1d_row_spec
(
model
):
def
init_1d_row_spec
(
model
,
pg
:
ProcessGroup
):
spec
=
TensorSpec
(
tensor_spec
=
TensorSpec
(
distspec
.
shard
(
pg
,
[
0
],
[
pg
.
tp_world_size
()]),
ComputeSpec
(
ComputePattern
.
TP1D
))
distspec
.
shard
(
gpc
.
get_group
(
ParallelMode
.
PARALLEL_1D
),
[
0
],
[
gpc
.
get_world_size
(
ParallelMode
.
PARALLEL_1D
)]),
ComputeSpec
(
ComputePattern
.
TP1D
))
with
DistSpecManager
.
no_grad
():
with
DistSpecManager
.
no_grad
():
for
n
,
p
in
model
.
named_parameters
():
for
n
,
p
in
model
.
named_parameters
():
if
'weight'
in
n
and
'ln'
not
in
n
:
if
'weight'
in
n
and
'ln'
not
in
n
:
p
.
set_tensor_spec
(
spec
)
p
.
set_tensor_spec
(
tensor_
spec
)
def
init_1d_col_spec
(
model
):
def
init_1d_col_spec
(
model
,
pg
:
ProcessGroup
):
spec
=
TensorSpec
(
spec
=
TensorSpec
(
distspec
.
shard
(
pg
,
[
-
1
],
[
pg
.
tp_world_size
()]),
ComputeSpec
(
ComputePattern
.
TP1D
))
distspec
.
shard
(
gpc
.
get_group
(
ParallelMode
.
PARALLEL_1D
),
[
-
1
],
[
gpc
.
get_world_size
(
ParallelMode
.
PARALLEL_1D
)]),
ComputeSpec
(
ComputePattern
.
TP1D
))
with
DistSpecManager
.
no_grad
():
with
DistSpecManager
.
no_grad
():
for
n
,
p
in
model
.
named_parameters
():
for
n
,
p
in
model
.
named_parameters
():
if
'ln'
not
in
n
and
(
'weight'
in
n
or
'bias'
in
n
):
if
'ln'
not
in
n
and
(
'weight'
in
n
or
'bias'
in
n
):
p
.
set_tensor_spec
(
spec
)
p
.
set_tensor_spec
(
spec
)
def
check_param_equal
(
model
,
torch_model
):
def
check_param_equal
(
model
,
torch_model
,
pg
:
ProcessGroup
):
for
p
,
torch_p
in
zip
(
model
.
parameters
(),
torch_model
.
parameters
()):
for
p
,
torch_p
in
zip
(
model
.
parameters
(),
torch_model
.
parameters
()):
assert
tensor_shard_equal
(
torch_p
,
p
)
assert
pg
.
tp_local_rank
()
is
not
None
,
f
"
{
pg
.
rank
()
}
{
pg
.
tp_world_size
()
}
{
pg
.
_tp_degree
}
{
pg
.
tp_local_rank
()
}
1"
assert
pg
.
tp_world_size
()
is
not
None
assert
tensor_shard_equal
(
torch_p
,
p
,
pg
.
tp_local_rank
(),
pg
.
tp_world_size
())
def
check_grad_equal
(
model
,
torch_model
):
def
check_grad_equal
(
model
,
torch_model
,
pg
:
ProcessGroup
):
for
p
,
torch_p
in
zip
(
model
.
parameters
(),
torch_model
.
parameters
()):
for
p
,
torch_p
in
zip
(
model
.
parameters
(),
torch_model
.
parameters
()):
assert
tensor_shard_equal
(
torch_p
.
grad
,
p
.
grad
)
assert
tensor_shard_equal
(
torch_p
.
grad
,
p
.
grad
,
pg
.
tp_local_rank
(),
pg
.
tp_world_size
()
)
def
run_gpt
(
init_spec_func
,
use_ddp
):
def
run_gpt
(
init_spec_func
,
use_ddp
):
world_size
=
torch
.
distributed
.
get_world_size
()
pg
=
ProcessGroup
(
dp_degree
=
(
2
if
(
use_ddp
and
world_size
>=
2
)
else
1
))
get_components_func
=
non_distributed_component_funcs
.
get_callable
(
'gpt2'
)
get_components_func
=
non_distributed_component_funcs
.
get_callable
(
'gpt2'
)
model_builder
,
train_dataloader
,
test_dataloader
,
optimizer_class
,
criterion
=
get_components_func
()
model_builder
,
train_dataloader
,
test_dataloader
,
optimizer_class
,
criterion
=
get_components_func
()
...
@@ -54,21 +57,25 @@ def run_gpt(init_spec_func, use_ddp):
...
@@ -54,21 +57,25 @@ def run_gpt(init_spec_func, use_ddp):
model
=
model
.
cuda
()
model
=
model
.
cuda
()
torch_model
=
model_builder
().
cuda
()
torch_model
=
model_builder
().
cuda
()
if
use_ddp
:
if
use_ddp
:
model
=
ColoDDP
(
model
)
# torch_model = DDP(torch_model, device_ids=[pg.rank()], process_group=pg)
# torch.distributed.barrier()
torch_model
=
DDP
(
torch_model
,
torch_model
=
DDP
(
torch_model
,
device_ids
=
[
gpc
.
get_global_rank
()],
device_ids
=
[
gpc
.
get_global_rank
()],
process_group
=
gpc
.
get_group
(
ParallelMode
.
DATA
))
process_group
=
gpc
.
get_group
(
ParallelMode
.
DATA
))
model
=
ColoDDP
(
model
,
process_group
=
pg
)
for
torch_p
,
p
in
zip
(
torch_model
.
parameters
(),
model
.
parameters
()):
for
torch_p
,
p
in
zip
(
torch_model
.
parameters
(),
model
.
parameters
()):
torch_p
.
data
.
copy_
(
p
)
torch_p
.
data
.
copy_
(
p
)
init_spec_func
(
model
)
init_spec_func
(
model
,
pg
)
check_param_equal
(
model
,
torch_model
)
check_param_equal
(
model
,
torch_model
,
pg
)
model
.
train
()
model
.
train
()
torch_model
.
train
()
torch_model
.
train
()
set_seed
(
gpc
.
get_local_rank
(
ParallelMode
.
DATA
))
set_seed
(
pg
.
tp_local_rank
())
for
i
,
(
input_ids
,
attn_mask
)
in
enumerate
(
train_dataloader
):
for
i
,
(
input_ids
,
attn_mask
)
in
enumerate
(
train_dataloader
):
logits
=
model
(
input_ids
,
attn_mask
)
logits
=
model
(
input_ids
,
attn_mask
)
torch_logits
=
torch_model
(
input_ids
,
attn_mask
)
torch_logits
=
torch_model
(
input_ids
,
attn_mask
)
assert
tensor_equal
(
torch_logits
,
logits
)
assert
tensor_equal
(
torch_logits
,
logits
)
,
f
"
{
torch_logits
-
logits
}
"
loss
=
criterion
(
logits
,
input_ids
)
loss
=
criterion
(
logits
,
input_ids
)
torch_loss
=
criterion
(
torch_logits
,
input_ids
)
torch_loss
=
criterion
(
torch_logits
,
input_ids
)
if
use_ddp
:
if
use_ddp
:
...
@@ -76,7 +83,7 @@ def run_gpt(init_spec_func, use_ddp):
...
@@ -76,7 +83,7 @@ def run_gpt(init_spec_func, use_ddp):
else
:
else
:
loss
.
backward
()
loss
.
backward
()
torch_loss
.
backward
()
torch_loss
.
backward
()
check_grad_equal
(
model
,
torch_model
)
check_grad_equal
(
model
,
torch_model
,
pg
)
if
i
>
0
:
if
i
>
0
:
break
break
...
@@ -87,11 +94,12 @@ def run_dist(rank, world_size, port, use_ddp):
...
@@ -87,11 +94,12 @@ def run_dist(rank, world_size, port, use_ddp):
tp_world_size
=
world_size
//
2
if
use_ddp
else
world_size
tp_world_size
=
world_size
//
2
if
use_ddp
else
world_size
config
=
dict
(
parallel
=
dict
(
tensor
=
dict
(
mode
=
"1d"
,
size
=
tp_world_size
),))
config
=
dict
(
parallel
=
dict
(
tensor
=
dict
(
mode
=
"1d"
,
size
=
tp_world_size
),))
colossalai
.
launch
(
config
=
config
,
rank
=
rank
,
world_size
=
world_size
,
host
=
'localhost'
,
port
=
port
,
backend
=
'nccl'
)
colossalai
.
launch
(
config
=
config
,
rank
=
rank
,
world_size
=
world_size
,
host
=
'localhost'
,
port
=
port
,
backend
=
'nccl'
)
run_gpt
(
init_1d_row_spec
,
use_ddp
)
#
run_gpt(init_1d_row_spec, use_ddp)
run_gpt
(
init_1d_col_spec
,
use_ddp
)
run_gpt
(
init_1d_col_spec
,
use_ddp
)
@
pytest
.
mark
.
dist
@
pytest
.
mark
.
dist
@
pytest
.
mark
.
skip
(
"under development"
)
@
pytest
.
mark
.
parametrize
(
'world_size'
,
[
1
,
4
])
@
pytest
.
mark
.
parametrize
(
'world_size'
,
[
1
,
4
])
@
pytest
.
mark
.
parametrize
(
'use_ddp'
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
'use_ddp'
,
[
False
,
True
])
@
rerun_if_address_is_in_use
()
@
rerun_if_address_is_in_use
()
...
...
tests/test_tensor/test_hybrid_device.py
deleted
100644 → 0
View file @
abf6a262
from
colossalai.utils
import
free_port
,
get_current_device
from
colossalai.utils.model.colo_init_context
import
ColoInitContext
from
colossalai.testing
import
rerun_if_address_is_in_use
from
colossalai.tensor
import
ComputePattern
,
ComputeSpec
from
functools
import
partial
from
colossalai.core
import
global_context
as
gpc
from
colossalai.context
import
ParallelMode
from
colossalai.nn.parallel.layers
import
init_colo_module
from
colossalai.nn.parallel.data_parallel
import
ColoDDP
from
colossalai.nn.optimizer
import
ColoOptimizer
import
colossalai
import
torch
import
torch.multiprocessing
as
mp
import
pytest
class
Net
(
torch
.
nn
.
Module
):
def
__init__
(
self
):
super
(
Net
,
self
).
__init__
()
self
.
embed
=
torch
.
nn
.
Embedding
(
20
,
4
)
self
.
proj
=
torch
.
nn
.
Linear
(
4
,
8
)
def
forward
(
self
,
x
):
# move input to cpu and restore output
current_dev
=
x
.
device
x
=
x
.
to
(
'cpu'
)
x
=
self
.
embed
(
x
)
x
=
x
.
to
(
current_dev
)
x
=
self
.
proj
(
x
)
return
x
def
run_hybrid_device
(
use_ddp
,
mode
):
with
ColoInitContext
(
device
=
get_current_device
()):
model
=
Net
()
real_model
=
model
if
use_ddp
:
model
=
ColoDDP
(
model
)
real_model
=
model
.
module
print
(
f
'embedding weight size:
{
real_model
.
embed
.
weight
.
size
()
}
| device:
{
real_model
.
embed
.
weight
.
device
}
'
)
#print(f'linear weight size: {real_model.proj.weight.size()} | device: {real_model.proj.weight.device}')
parallel_action
=
ComputeSpec
(
ComputePattern
.
TP1D
)
init_colo_module
(
model
,
parallel_action
,
recursive
=
True
,
mode
=
mode
)
# use cpu gloo to handle embedding
real_model
.
embed
.
to
(
'cpu'
)
gloo_group_tp
=
gpc
.
get_cpu_group
(
ParallelMode
.
PARALLEL_1D
)
real_model
.
embed
.
weight
.
spec
.
dist_spec
.
process_group
=
gloo_group_tp
print
(
f
'embedding weight size:
{
real_model
.
embed
.
weight
.
size
()
}
| new device:
{
real_model
.
embed
.
weight
.
device
}
'
)
#print(f'linear weight size: {real_model.proj.weight.size()} | new device: {real_model.proj.weight.device}')
optimizer
=
ColoOptimizer
(
dict
(
model
.
named_parameters
()),
torch
.
optim
.
SGD
,
lr
=
0.1
)
data
=
torch
.
randint
(
low
=
0
,
high
=
20
,
size
=
(
16
,),
device
=
get_current_device
())
out
=
model
(
data
)
out
.
sum
().
backward
()
optimizer
.
step
()
def
run_dist
(
rank
,
world_size
,
port
,
use_ddp
,
mode
):
if
use_ddp
and
world_size
==
1
:
return
tp_world_size
=
world_size
//
2
if
use_ddp
else
world_size
config
=
dict
(
parallel
=
dict
(
tensor
=
dict
(
mode
=
"1d"
,
size
=
tp_world_size
),))
colossalai
.
launch
(
config
=
config
,
rank
=
rank
,
world_size
=
world_size
,
host
=
'localhost'
,
port
=
port
,
backend
=
'nccl'
)
run_hybrid_device
(
use_ddp
,
mode
)
@
pytest
.
mark
.
dist
@
pytest
.
mark
.
parametrize
(
'world_size'
,
[
1
,
4
])
@
pytest
.
mark
.
parametrize
(
'use_ddp'
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
'mode'
,
[
'col'
,
'row'
])
@
rerun_if_address_is_in_use
()
# Working for simulate the embedding(CPU DP+TP) -> nn(GPU DP+TP)
def
_test_hybrid_device
(
world_size
,
use_ddp
,
mode
):
run_func
=
partial
(
run_dist
,
world_size
=
world_size
,
port
=
free_port
(),
use_ddp
=
use_ddp
,
mode
=
mode
)
mp
.
spawn
(
run_func
,
nprocs
=
world_size
)
if
__name__
==
'__main__'
:
_test_hybrid_device
(
4
,
True
,
'row'
)
tests/test_tensor/test_linear_tp.py
View file @
060b917d
...
@@ -12,32 +12,29 @@ import torch.nn.functional as F
...
@@ -12,32 +12,29 @@ import torch.nn.functional as F
from
colossalai.testing
import
rerun_if_address_is_in_use
from
colossalai.testing
import
rerun_if_address_is_in_use
from
colossalai.utils
import
free_port
from
colossalai.utils
import
free_port
from
colossalai.core
import
global_context
as
gpc
from
colossalai.core
import
global_context
as
gpc
from
colossalai.tensor
import
TensorSpec
,
ComputePattern
,
ComputeSpec
,
DistSpecManager
from
colossalai.tensor
import
TensorSpec
,
ComputePattern
,
ComputeSpec
,
DistSpecManager
,
ProcessGroup
from
_utils
import
tensor_equal
,
tensor_shard_equal
from
_utils
import
tensor_equal
,
tensor_shard_equal
def
init_1d_row
(
weight
,
bias
):
def
init_1d_row
(
weight
,
bias
,
pg
:
ProcessGroup
):
spec
=
TensorSpec
(
spec
=
TensorSpec
(
distspec
.
shard
(
pg
,
[
-
1
],
[
pg
.
tp_world_size
()]),
ComputeSpec
(
ComputePattern
.
TP1D
))
distspec
.
shard
(
gpc
.
get_group
(
ParallelMode
.
PARALLEL_1D
),
[
-
1
],
[
gpc
.
get_world_size
(
ParallelMode
.
PARALLEL_1D
)]),
ComputeSpec
(
ComputePattern
.
TP1D
))
with
DistSpecManager
.
no_grad
():
with
DistSpecManager
.
no_grad
():
weight
.
set_tensor_spec
(
spec
)
weight
.
set_tensor_spec
(
spec
)
def
init_1d_col
(
weight
,
bias
):
def
init_1d_col
(
weight
,
bias
,
pg
:
ProcessGroup
):
spec
=
TensorSpec
(
spec
=
TensorSpec
(
distspec
.
shard
(
pg
,
[
0
],
[
pg
.
tp_world_size
()]),
ComputeSpec
(
ComputePattern
.
TP1D
))
distspec
.
shard
(
gpc
.
get_group
(
ParallelMode
.
PARALLEL_1D
),
[
0
],
[
gpc
.
get_world_size
(
ParallelMode
.
PARALLEL_1D
)]),
ComputeSpec
(
ComputePattern
.
TP1D
))
with
DistSpecManager
.
no_grad
():
with
DistSpecManager
.
no_grad
():
weight
.
set_tensor_spec
(
spec
)
weight
.
set_tensor_spec
(
spec
)
bias
.
set_tensor_spec
(
spec
)
bias
.
set_tensor_spec
(
spec
)
def
run_with_spec
(
spec_init_func
):
def
run_with_spec
(
spec_init_func
):
pg
=
ProcessGroup
(
tp_degree
=
torch
.
distributed
.
get_world_size
())
model
=
torch
.
nn
.
Linear
(
4
,
8
).
cuda
()
model
=
torch
.
nn
.
Linear
(
4
,
8
).
cuda
()
weight
=
ColoTensor
(
torch
.
nn
.
Parameter
(
model
.
weight
.
detach
()))
weight
=
ColoTensor
(
torch
.
nn
.
Parameter
(
model
.
weight
.
detach
()))
bias
=
ColoTensor
(
torch
.
nn
.
Parameter
(
model
.
bias
.
detach
()))
bias
=
ColoTensor
(
torch
.
nn
.
Parameter
(
model
.
bias
.
detach
()))
spec_init_func
(
weight
,
bias
)
spec_init_func
(
weight
,
bias
,
pg
)
x
=
torch
.
rand
(
2
,
4
).
cuda
()
x
=
torch
.
rand
(
2
,
4
).
cuda
()
out
=
model
(
x
)
out
=
model
(
x
)
colo_out
=
F
.
linear
(
x
,
weight
,
bias
)
colo_out
=
F
.
linear
(
x
,
weight
,
bias
)
...
@@ -46,8 +43,8 @@ def run_with_spec(spec_init_func):
...
@@ -46,8 +43,8 @@ def run_with_spec(spec_init_func):
grad
=
torch
.
rand_like
(
out
)
grad
=
torch
.
rand_like
(
out
)
out
.
backward
(
grad
)
out
.
backward
(
grad
)
colo_out
.
backward
(
grad
)
colo_out
.
backward
(
grad
)
assert
tensor_shard_equal
(
model
.
weight
.
grad
,
weight
.
grad
)
assert
tensor_shard_equal
(
model
.
weight
.
grad
,
weight
.
grad
,
pg
.
tp_local_rank
(),
pg
.
tp_world_size
()
)
assert
tensor_shard_equal
(
model
.
bias
.
grad
,
bias
.
grad
)
assert
tensor_shard_equal
(
model
.
bias
.
grad
,
bias
.
grad
,
pg
.
tp_local_rank
(),
pg
.
tp_world_size
()
)
def
run_dist
(
rank
,
world_size
,
port
):
def
run_dist
(
rank
,
world_size
,
port
):
...
...
tests/test_tensor/test_model.py
View file @
060b917d
from
colossalai.tensor.colo_parameter
import
ColoParameter
from
tests.components_to_test.registry
import
non_distributed_component_funcs
import
colossalai
import
pytest
import
pytest
from
functools
import
partial
from
_utils
import
tensor_shard_equal
,
set_seed
import
torch
import
torch
import
torch.multiprocessing
as
mp
import
torch.multiprocessing
as
mp
from
colossalai.tensor.colo_parameter
import
ColoParameter
import
colossalai
from
colossalai.testing
import
rerun_if_address_is_in_use
from
colossalai.testing
import
rerun_if_address_is_in_use
from
colossalai.utils.cuda
import
get_current_device
from
colossalai.utils.cuda
import
get_current_device
from
colossalai.utils
import
free_port
from
colossalai.utils
import
free_port
...
@@ -12,34 +14,30 @@ from colossalai.utils.model.colo_init_context import ColoInitContext
...
@@ -12,34 +14,30 @@ from colossalai.utils.model.colo_init_context import ColoInitContext
from
colossalai.tensor
import
distspec
,
TensorSpec
,
ComputePattern
,
\
from
colossalai.tensor
import
distspec
,
TensorSpec
,
ComputePattern
,
\
ComputeSpec
,
ColoTensor
,
DistSpecManager
,
ProcessGroup
ComputeSpec
,
ColoTensor
,
DistSpecManager
,
ProcessGroup
from
colossalai.nn.optimizer
import
ColoOptimizer
from
colossalai.nn.optimizer
import
ColoOptimizer
from
functools
import
partial
from
_utils
import
tensor_shard_equal
,
set_seed
from
tests.components_to_test.registry
import
non_distributed_component_funcs
def
init_1d_row_linear
(
weight
,
pg
:
ProcessGroup
):
def
init_1d_row_linear
(
weight
,
pg
:
ProcessGroup
):
spec
=
TensorSpec
(
distspec
.
shard
(
pg
.
tp_process_group
(),
[
-
1
],
[
pg
.
tp_world_size
()]),
spec
=
TensorSpec
(
distspec
.
shard
(
pg
,
[
-
1
],
[
pg
.
tp_world_size
()]),
ComputeSpec
(
ComputePattern
.
TP1D
))
ComputeSpec
(
ComputePattern
.
TP1D
))
with
DistSpecManager
.
no_grad
():
with
DistSpecManager
.
no_grad
():
weight
.
set_tensor_spec
(
spec
)
weight
.
set_tensor_spec
(
spec
)
def
init_1d_col_linear
(
weight
,
pg
):
def
init_1d_col_linear
(
weight
,
pg
):
spec
=
TensorSpec
(
distspec
.
shard
(
pg
.
tp_process_group
(),
[
0
],
[
pg
.
tp_world_size
()]),
spec
=
TensorSpec
(
distspec
.
shard
(
pg
,
[
0
],
[
pg
.
tp_world_size
()]),
ComputeSpec
(
ComputePattern
.
TP1D
))
ComputeSpec
(
ComputePattern
.
TP1D
))
with
DistSpecManager
.
no_grad
():
with
DistSpecManager
.
no_grad
():
weight
.
set_tensor_spec
(
spec
)
weight
.
set_tensor_spec
(
spec
)
def
init_1d_row_embedding
(
weight
,
pg
):
def
init_1d_row_embedding
(
weight
,
pg
):
spec
=
TensorSpec
(
distspec
.
shard
(
pg
.
tp_process_group
(),
[
0
],
[
pg
.
tp_world_size
()]),
spec
=
TensorSpec
(
distspec
.
shard
(
pg
,
[
0
],
[
pg
.
tp_world_size
()]),
ComputeSpec
(
ComputePattern
.
TP1D
))
ComputeSpec
(
ComputePattern
.
TP1D
))
with
DistSpecManager
.
no_grad
():
with
DistSpecManager
.
no_grad
():
weight
.
set_tensor_spec
(
spec
)
weight
.
set_tensor_spec
(
spec
)
def
init_1d_col_embedding
(
weight
,
pg
):
def
init_1d_col_embedding
(
weight
,
pg
):
spec
=
TensorSpec
(
distspec
.
shard
(
pg
.
tp_process_group
(),
[
-
1
],
[
pg
.
tp_world_size
()]),
spec
=
TensorSpec
(
distspec
.
shard
(
pg
,
[
-
1
],
[
pg
.
tp_world_size
()]),
ComputeSpec
(
ComputePattern
.
TP1D
))
ComputeSpec
(
ComputePattern
.
TP1D
))
with
DistSpecManager
.
no_grad
():
with
DistSpecManager
.
no_grad
():
weight
.
set_tensor_spec
(
spec
)
weight
.
set_tensor_spec
(
spec
)
...
@@ -142,7 +140,7 @@ def run_1d_hybrid_tp(model_name):
...
@@ -142,7 +140,7 @@ def run_1d_hybrid_tp(model_name):
with
torch
.
no_grad
():
with
torch
.
no_grad
():
# check param
# check param
for
p
,
torch_p
in
zip
(
model
.
parameters
(),
model_torch
.
parameters
()):
for
p
,
torch_p
in
zip
(
model
.
parameters
(),
model_torch
.
parameters
()):
assert
tensor_shard_equal
(
torch_p
,
p
)
assert
tensor_shard_equal
(
torch_p
,
p
,
pg
.
tp_local_rank
(),
pg
.
tp_world_size
()
)
if
i
>
5
:
if
i
>
5
:
break
break
...
...
tests/test_tensor/test_module_spec.py
View file @
060b917d
...
@@ -13,12 +13,10 @@ import colossalai
...
@@ -13,12 +13,10 @@ import colossalai
from
colossalai.utils.cuda
import
get_current_device
from
colossalai.utils.cuda
import
get_current_device
from
colossalai.utils.model.colo_init_context
import
ColoInitContext
from
colossalai.utils.model.colo_init_context
import
ColoInitContext
from
colossalai.context.parallel_mode
import
ParallelMode
from
colossalai.tensor
import
distspec
,
ProcessGroup
from
colossalai.tensor
import
distspec
from
colossalai.testing
import
rerun_if_address_is_in_use
from
colossalai.testing
import
rerun_if_address_is_in_use
from
colossalai.utils
import
free_port
from
colossalai.utils
import
free_port
from
colossalai.core
import
global_context
as
gpc
from
tests.components_to_test.registry
import
non_distributed_component_funcs
from
tests.components_to_test.registry
import
non_distributed_component_funcs
...
@@ -26,7 +24,9 @@ from tests.components_to_test.registry import non_distributed_component_funcs
...
@@ -26,7 +24,9 @@ from tests.components_to_test.registry import non_distributed_component_funcs
def
run_model_with_spec
(
mode
,
model_name
):
def
run_model_with_spec
(
mode
,
model_name
):
get_components_func
=
non_distributed_component_funcs
.
get_callable
(
model_name
)
get_components_func
=
non_distributed_component_funcs
.
get_callable
(
model_name
)
model_builder
,
train_dataloader
,
test_dataloader
,
optimizer_class
,
criterion
=
get_components_func
()
model_builder
,
train_dataloader
,
test_dataloader
,
optimizer_class
,
criterion
=
get_components_func
()
rank
=
gpc
.
get_local_rank
(
ParallelMode
.
PARALLEL_1D
)
world_size
=
torch
.
distributed
.
get_world_size
()
pg
=
ProcessGroup
(
tp_degree
=
world_size
)
rank
=
pg
.
rank
()
set_seed
(
1
)
set_seed
(
1
)
with
ColoInitContext
(
device
=
get_current_device
()):
with
ColoInitContext
(
device
=
get_current_device
()):
...
@@ -40,28 +40,28 @@ def run_model_with_spec(mode, model_name):
...
@@ -40,28 +40,28 @@ def run_model_with_spec(mode, model_name):
for
p1
,
p2
in
zip
(
model
.
parameters
(),
model_seq
.
parameters
()):
for
p1
,
p2
in
zip
(
model
.
parameters
(),
model_seq
.
parameters
()):
p2
.
data
.
copy_
(
p1
.
data
)
p2
.
data
.
copy_
(
p1
.
data
)
parallel_action
=
ComputeSpec
(
ComputePattern
.
TP1D
)
compute_spec
=
ComputeSpec
(
ComputePattern
.
TP1D
)
# Not all layers in Bert can be mod by 4.
# Not all layers in Bert can be mod by 4.
# e.g. row shard for all layers is invalid because the first dim of some layer is the classification type size 2.
# e.g. row shard for all layers is invalid because the first dim of some layer is the classification type size 2.
if
'bert'
==
model_name
:
if
'bert'
==
model_name
:
if
'col'
==
mode
:
if
'col'
==
mode
:
init_colo_module
(
model
.
bert
.
embeddings
,
parallel_action
,
recursive
=
True
,
mode
=
mode
)
init_colo_module
(
model
.
bert
.
embeddings
,
compute_spec
,
pg
=
pg
,
recursive
=
True
,
mode
=
mode
)
init_colo_module
(
model
.
bert
.
encoder
,
parallel_action
,
recursive
=
True
,
mode
=
mode
)
init_colo_module
(
model
.
bert
.
encoder
,
compute_spec
,
pg
=
pg
,
recursive
=
True
,
mode
=
mode
)
init_colo_module
(
model
.
classifier
,
parallel_action
,
recursive
=
True
,
mode
=
'row'
)
init_colo_module
(
model
.
classifier
,
compute_spec
,
pg
=
pg
,
recursive
=
True
,
mode
=
'row'
)
elif
'row'
==
mode
:
elif
'row'
==
mode
:
init_colo_module
(
model
.
bert
.
embeddings
,
parallel_action
,
recursive
=
True
,
mode
=
'col'
)
init_colo_module
(
model
.
bert
.
embeddings
,
compute_spec
,
pg
=
pg
,
recursive
=
True
,
mode
=
'col'
)
init_colo_module
(
model
.
bert
.
encoder
,
parallel_action
,
recursive
=
True
,
mode
=
mode
)
init_colo_module
(
model
.
bert
.
encoder
,
compute_spec
,
pg
=
pg
,
recursive
=
True
,
mode
=
mode
)
init_colo_module
(
model
.
classifier
,
parallel_action
,
recursive
=
True
,
mode
=
mode
)
init_colo_module
(
model
.
classifier
,
compute_spec
,
pg
=
pg
,
recursive
=
True
,
mode
=
mode
)
elif
'simple_net'
==
model_name
:
elif
'simple_net'
==
model_name
:
init_colo_module
(
model
,
parallel_action
,
recursive
=
True
,
mode
=
mode
)
init_colo_module
(
model
,
compute_spec
,
pg
=
pg
,
recursive
=
True
,
mode
=
mode
)
model
=
model
.
cuda
()
model
=
model
.
cuda
()
for
i
,
(
data
,
label
)
in
enumerate
(
train_dataloader
):
for
i
,
(
data
,
label
)
in
enumerate
(
train_dataloader
):
data
=
data
.
to
(
get_current_device
())
data
=
data
.
to
(
get_current_device
())
label
=
label
.
to
(
get_current_device
())
label
=
label
.
to
(
get_current_device
())
torch
.
distributed
.
broadcast
(
data
,
0
,
group
=
gpc
.
get_group
(
ParallelMode
.
PARALLEL_1D
))
torch
.
distributed
.
broadcast
(
data
,
0
,
group
=
pg
.
tp_process_group
(
))
torch
.
distributed
.
broadcast
(
label
,
0
,
group
=
gpc
.
get_group
(
ParallelMode
.
PARALLEL_1D
))
torch
.
distributed
.
broadcast
(
label
,
0
,
group
=
pg
.
tp_process_group
(
))
if
criterion
:
if
criterion
:
output
=
model
(
data
)
output
=
model
(
data
)
...
@@ -113,9 +113,10 @@ def run_linear_with_spec(mode):
...
@@ -113,9 +113,10 @@ def run_linear_with_spec(mode):
model
=
torch
.
nn
.
Linear
(
4
,
8
)
model
=
torch
.
nn
.
Linear
(
4
,
8
)
model_handy
=
copy
(
model
)
model_handy
=
copy
(
model
)
world_size
=
torch
.
distributed
.
get_world_size
()
parallel_action
=
ComputeSpec
(
ComputePattern
.
TP1D
)
pg
=
ProcessGroup
(
tp_degree
=
world_size
)
init_colo_module
(
model
,
parallel_action
,
recursive
=
True
,
mode
=
mode
)
compute_spec
=
ComputeSpec
(
ComputePattern
.
TP1D
)
init_colo_module
(
model
,
compute_spec
,
pg
=
pg
,
recursive
=
True
,
mode
=
mode
)
x
=
torch
.
rand
(
2
,
4
).
cuda
()
x
=
torch
.
rand
(
2
,
4
).
cuda
()
out
=
model
(
x
)
out
=
model
(
x
)
...
@@ -124,8 +125,8 @@ def run_linear_with_spec(mode):
...
@@ -124,8 +125,8 @@ def run_linear_with_spec(mode):
grad
=
torch
.
rand_like
(
out
)
grad
=
torch
.
rand_like
(
out
)
out
.
backward
(
grad
)
out
.
backward
(
grad
)
colo_out
.
backward
(
grad
)
colo_out
.
backward
(
grad
)
assert
tensor_shard_equal
(
model
.
weight
.
grad
,
model_handy
.
weight
.
grad
)
assert
tensor_shard_equal
(
model
.
weight
.
grad
,
model_handy
.
weight
.
grad
,
pg
.
tp_local_rank
(),
pg
.
tp_world_size
()
)
assert
tensor_shard_equal
(
model
.
bias
.
grad
,
model_handy
.
bias
.
grad
)
assert
tensor_shard_equal
(
model
.
bias
.
grad
,
model_handy
.
bias
.
grad
,
pg
.
tp_local_rank
(),
pg
.
tp_world_size
()
)
def
run_check_shared_param
():
def
run_check_shared_param
():
...
@@ -136,6 +137,10 @@ def run_check_shared_param():
...
@@ -136,6 +137,10 @@ def run_check_shared_param():
num_layer
=
2
num_layer
=
2
vocab_size
=
24
vocab_size
=
24
world_size
=
torch
.
distributed
.
get_world_size
()
pg
=
ProcessGroup
(
tp_degree
=
world_size
)
rank
=
pg
.
rank
()
config
=
BertConfig
(
vocab_size
=
vocab_size
,
config
=
BertConfig
(
vocab_size
=
vocab_size
,
hidden_size
=
hidden_dim
,
hidden_size
=
hidden_dim
,
intermediate_size
=
hidden_dim
*
4
,
intermediate_size
=
hidden_dim
*
4
,
...
@@ -148,18 +153,16 @@ def run_check_shared_param():
...
@@ -148,18 +153,16 @@ def run_check_shared_param():
model
=
BertForMaskedLM
(
config
)
model
=
BertForMaskedLM
(
config
)
model
=
model
.
cuda
()
model
=
model
.
cuda
()
parallel_action
=
ComputeSpec
(
ComputePattern
.
TP1D
)
compute_spec
=
ComputeSpec
(
ComputePattern
.
TP1D
)
# model.cls.predictions.decoder and model.cls.predictions share the bias, so they should have the same spec
# model.cls.predictions.decoder and model.cls.predictions share the bias, so they should have the same spec
assert
len
(
model
.
cls
.
predictions
.
decoder
.
bias
.
shared_param_modules
)
==
2
assert
len
(
model
.
cls
.
predictions
.
decoder
.
bias
.
shared_param_modules
)
==
2
# They are all Linear, so both row is allowed. This should pass check.
# They are all Linear, so both row is allowed. This should pass check.
init_colo_module
(
model
,
parallel_action
,
recursive
=
True
,
mode
=
'row'
)
init_colo_module
(
model
,
compute_spec
,
pg
=
pg
,
recursive
=
True
,
mode
=
'row'
)
# This should be detected by check because you can not set weight as row while set bias as col.
# This should be detected by check because you can not set weight as row while set bias as col.
col_spec
=
TensorSpec
(
col_spec
=
TensorSpec
(
distspec
.
shard
(
pg
,
[
0
],
[
pg
.
tp_world_size
()]),
ComputeSpec
(
ComputePattern
.
TP1D
))
distspec
.
shard
(
gpc
.
get_group
(
ParallelMode
.
PARALLEL_1D
),
[
0
],
[
gpc
.
get_world_size
(
ParallelMode
.
PARALLEL_1D
)]),
ComputeSpec
(
ComputePattern
.
TP1D
))
model
.
cls
.
predictions
.
bias
.
set_tensor_spec
(
col_spec
)
model
.
cls
.
predictions
.
bias
.
set_tensor_spec
(
col_spec
)
try
:
try
:
check_colo_module
(
model
.
cls
.
predictions
.
decoder
,
recursive
=
False
)
check_colo_module
(
model
.
cls
.
predictions
.
decoder
,
pg
=
pg
,
recursive
=
False
)
except
Exception
as
e
:
except
Exception
as
e
:
assert
'incorrectly sharded'
in
str
(
e
)
assert
'incorrectly sharded'
in
str
(
e
)
...
...
tests/test_tensor/test_op.py
View file @
060b917d
...
@@ -4,10 +4,9 @@ import colossalai
...
@@ -4,10 +4,9 @@ import colossalai
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
import
torch.multiprocessing
as
mp
import
torch.multiprocessing
as
mp
from
functools
import
partial
from
functools
import
partial
from
colossalai.tensor
import
ColoTensor
,
ColoParameter
from
colossalai.tensor
import
ColoTensor
,
ProcessGroup
from
colossalai.utils
import
get_current_device
from
colossalai.utils
import
get_current_device
from
torch.nn
import
Parameter
from
torch.nn
import
Parameter
from
torch.distributed.distributed_c10d
import
_get_default_group
from
colossalai.testing
import
rerun_if_address_is_in_use
from
colossalai.testing
import
rerun_if_address_is_in_use
from
colossalai.utils
import
free_port
from
colossalai.utils
import
free_port
from
colossalai.tensor
import
distspec
,
TensorSpec
from
colossalai.tensor
import
distspec
,
TensorSpec
...
@@ -43,9 +42,10 @@ def check_spec_eq(tensor, other):
...
@@ -43,9 +42,10 @@ def check_spec_eq(tensor, other):
def
check_element_wise_ops
():
def
check_element_wise_ops
():
pg
=
_get_default_group
()
world_size
=
torch
.
distributed
.
get_world_size
()
pg
=
ProcessGroup
(
tp_degree
=
world_size
)
t
=
torch
.
rand
(
2
,
2
)
t
=
torch
.
rand
(
2
,
2
)
x
=
ColoTensor
(
t
,
spec
=
TensorSpec
(
distspec
.
shard
(
pg
,
[
0
],
[
pg
.
size
()])))
x
=
ColoTensor
(
t
,
spec
=
TensorSpec
(
distspec
.
shard
(
pg
,
[
0
],
[
pg
.
tp_world_
size
()])))
check_spec_eq
(
x
,
x
.
cuda
())
check_spec_eq
(
x
,
x
.
cuda
())
assert
torch
.
equal
(
x
.
cuda
(),
t
.
cuda
())
assert
torch
.
equal
(
x
.
cuda
(),
t
.
cuda
())
check_spec_eq
(
x
,
torch
.
abs
(
x
))
check_spec_eq
(
x
,
torch
.
abs
(
x
))
...
...
tests/test_tensor/test_tensor.py
View file @
060b917d
...
@@ -11,7 +11,6 @@ import torch.multiprocessing as mp
...
@@ -11,7 +11,6 @@ import torch.multiprocessing as mp
from
colossalai.testing
import
rerun_if_address_is_in_use
from
colossalai.testing
import
rerun_if_address_is_in_use
from
colossalai.utils
import
free_port
from
colossalai.utils
import
free_port
from
colossalai.tensor
import
distspec
,
TensorSpec
,
ColoTensor
,
ProcessGroup
from
colossalai.tensor
import
distspec
,
TensorSpec
,
ColoTensor
,
ProcessGroup
from
colossalai.context
import
ParallelMode
from
functools
import
partial
from
functools
import
partial
...
@@ -55,11 +54,9 @@ def test_operand():
...
@@ -55,11 +54,9 @@ def test_operand():
def
_run_view
(
world_size
):
def
_run_view
(
world_size
):
t_ref
=
torch
.
randn
(
4
,
5
)
t_ref
=
torch
.
randn
(
4
,
5
)
rank
=
gpc
.
get_global_rank
()
rank
=
gpc
.
get_global_rank
()
pg
=
ProcessGroup
(
rank
,
list
(
range
(
world_size
)))
pg
=
ProcessGroup
(
rank
,
list
(
range
(
world_size
)),
tp_degree
=
world_size
)
assert
pg
.
dp_world_size
()
==
world_size
,
f
"
{
pg
.
dp_world_size
()
}
vs
{
world_size
}
"
t
=
ColoTensor
.
from_torch_tensor
(
t
=
ColoTensor
.
from_torch_tensor
(
t_ref
,
t_ref
,
TensorSpec
(
distspec
.
shard
(
process_group
=
pg
,
dims
=
[
0
],
num_partitions
=
[
pg
.
tp_world_size
()])))
TensorSpec
(
distspec
.
shard
(
process_group
=
pg
.
dp_process_group
(),
dims
=
[
0
],
num_partitions
=
[
pg
.
dp_world_size
()])))
assert
t
.
size_global
()[
0
]
==
4
*
world_size
assert
t
.
size_global
()[
0
]
==
4
*
world_size
assert
t
.
size_global
(
1
)
==
5
assert
t
.
size_global
(
1
)
==
5
...
@@ -77,12 +74,12 @@ def _run_tensor_shard_init(world_size):
...
@@ -77,12 +74,12 @@ def _run_tensor_shard_init(world_size):
t_ref
=
torch
.
randn
(
4
,
5
)
t_ref
=
torch
.
randn
(
4
,
5
)
rank
=
gpc
.
get_global_rank
()
rank
=
gpc
.
get_global_rank
()
pg
=
ProcessGroup
(
rank
,
list
(
range
(
world_size
)))
pg
=
ProcessGroup
(
rank
,
list
(
range
(
world_size
))
,
tp_degree
=
world_size
)
shard_spec
=
distspec
.
shard
(
process_group
=
pg
.
dp_process_group
()
,
dims
=
[
0
],
num_partitions
=
[
pg
.
d
p_world_size
()])
shard_spec
=
distspec
.
shard
(
process_group
=
pg
,
dims
=
[
0
],
num_partitions
=
[
pg
.
t
p_world_size
()])
tensor_spec
=
TensorSpec
(
shard_spec
)
tensor_spec
=
TensorSpec
(
shard_spec
)
t
=
ColoTensor
.
from_torch_tensor
(
t_ref
.
clone
(),
tensor_spec
)
t
=
ColoTensor
.
from_torch_tensor
(
t_ref
.
clone
(),
tensor_spec
)
t
.
set_tensor_spec
(
TensorSpec
(
dist_spec
=
distspec
.
replicate
()))
t
.
set_tensor_spec
(
TensorSpec
(
dist_spec
=
distspec
.
replicate
()))
assert
t
.
shape
==
torch
.
Size
((
4
*
world_size
,
5
))
assert
t
.
shape
==
torch
.
Size
((
4
*
world_size
,
5
))
,
f
"
{
t
.
shape
}
vs (
{
4
*
world_size
,
5
}
)"
def
_run_tensor_replicated_init
(
world_size
):
def
_run_tensor_replicated_init
(
world_size
):
...
@@ -92,11 +89,19 @@ def _run_tensor_replicated_init(world_size):
...
@@ -92,11 +89,19 @@ def _run_tensor_replicated_init(world_size):
assert
t
.
shape
==
torch
.
Size
((
4
*
world_size
,
5
)),
f
"
{
t
.
shape
}
"
assert
t
.
shape
==
torch
.
Size
((
4
*
world_size
,
5
)),
f
"
{
t
.
shape
}
"
def
_run_process_group
(
world_size
):
pg1
=
ProcessGroup
()
pg2
=
ProcessGroup
()
assert
pg1
==
pg2
def
run_dist_tests
(
rank
,
world_size
,
port
):
def
run_dist_tests
(
rank
,
world_size
,
port
):
colossalai
.
launch
(
config
=
{},
rank
=
rank
,
world_size
=
world_size
,
host
=
'localhost'
,
port
=
port
,
backend
=
'nccl'
)
colossalai
.
launch
(
config
=
{},
rank
=
rank
,
world_size
=
world_size
,
host
=
'localhost'
,
port
=
port
,
backend
=
'nccl'
)
_run_tensor_shard_init
(
world_size
)
_run_tensor_shard_init
(
world_size
)
_run_tensor_replicated_init
(
world_size
)
_run_tensor_replicated_init
(
world_size
)
_run_view
(
world_size
)
_run_view
(
world_size
)
_run_process_group
(
world_size
)
@
pytest
.
mark
.
dist
@
pytest
.
mark
.
dist
...
...
tests/test_tensor/test_zero_optim.py
View file @
060b917d
...
@@ -2,13 +2,11 @@ import pytest
...
@@ -2,13 +2,11 @@ import pytest
import
colossalai
import
colossalai
import
torch
import
torch
import
torch.multiprocessing
as
mp
import
torch.multiprocessing
as
mp
from
colossalai.context.parallel_mode
import
ParallelMode
from
colossalai.testing
import
rerun_if_address_is_in_use
from
colossalai.testing
import
rerun_if_address_is_in_use
from
colossalai.utils.cuda
import
get_current_device
from
colossalai.utils.cuda
import
get_current_device
from
colossalai.utils
import
free_port
from
colossalai.utils
import
free_port
from
colossalai.utils.model.colo_init_context
import
ColoInitContext
from
colossalai.utils.model.colo_init_context
import
ColoInitContext
from
colossalai.gemini
import
ChunkManager
from
colossalai.gemini
import
ChunkManager
from
colossalai.core
import
global_context
as
gpc
from
functools
import
partial
from
functools
import
partial
from
_utils
import
tensor_equal
,
set_seed
,
tensor_shard_equal
from
_utils
import
tensor_equal
,
set_seed
,
tensor_shard_equal
from
tests.components_to_test.registry
import
non_distributed_component_funcs
from
tests.components_to_test.registry
import
non_distributed_component_funcs
...
@@ -19,20 +17,22 @@ from colossalai.zero import ZeroOptimizer
...
@@ -19,20 +17,22 @@ from colossalai.zero import ZeroOptimizer
from
colossalai.testing
import
parameterize
from
colossalai.testing
import
parameterize
from
colossalai.amp
import
convert_to_apex_amp
from
colossalai.amp
import
convert_to_apex_amp
from
colossalai.gemini.gemini_mgr
import
GeminiManager
from
colossalai.gemini.gemini_mgr
import
GeminiManager
from
colossalai.tensor
import
TensorSpec
,
ComputePattern
,
ComputeSpec
,
DistSpecManager
,
distspec
from
colossalai.tensor
import
TensorSpec
,
ComputePattern
,
ComputeSpec
,
DistSpecManager
,
distspec
,
ProcessGroup
def
check_param_equal
(
model
,
torch_model
):
def
check_param_equal
(
model
,
torch_model
,
pg
:
ProcessGroup
):
for
p
,
torch_p
in
zip
(
model
.
parameters
(),
torch_model
.
parameters
()):
for
p
,
torch_p
in
zip
(
model
.
parameters
(),
torch_model
.
parameters
()):
if
p
.
storage
().
size
()
>
0
:
if
p
.
storage
().
size
()
>
0
:
assert
p
.
dtype
==
torch
.
half
assert
p
.
dtype
==
torch
.
half
assert
tensor_shard_equal
(
torch_p
.
to
(
dtype
=
p
.
dtype
,
device
=
p
.
device
),
p
),
f
'
{
torch_p
}
vs
{
p
}
'
assert
tensor_shard_equal
(
torch_p
.
to
(
dtype
=
p
.
dtype
,
device
=
p
.
device
),
p
,
pg
.
tp_local_rank
(),
pg
.
tp_world_size
()),
f
'
{
torch_p
}
vs
{
p
}
'
def
check_grad_equal
(
model
,
torch_model
):
def
check_grad_equal
(
model
,
torch_model
,
pg
:
ProcessGroup
):
for
p
,
torch_p
in
zip
(
model
.
parameters
(),
torch_model
.
parameters
()):
for
p
,
torch_p
in
zip
(
model
.
parameters
(),
torch_model
.
parameters
()):
if
p
.
grad
is
not
None
:
if
p
.
grad
is
not
None
:
assert
tensor_shard_equal
(
torch_p
.
grad
.
to
(
dtype
=
p
.
grad
.
dtype
,
device
=
p
.
grad
.
device
),
p
.
grad
)
assert
tensor_shard_equal
(
torch_p
.
grad
.
to
(
dtype
=
p
.
grad
.
dtype
,
device
=
p
.
grad
.
device
),
p
.
grad
,
pg
.
tp_local_rank
(),
pg
.
tp_world_size
())
def
run_fwd_bwd
(
model
,
criterion
,
optimizer
,
input_ids
,
attn_mask
):
def
run_fwd_bwd
(
model
,
criterion
,
optimizer
,
input_ids
,
attn_mask
):
...
@@ -44,20 +44,16 @@ def run_fwd_bwd(model, criterion, optimizer, input_ids, attn_mask):
...
@@ -44,20 +44,16 @@ def run_fwd_bwd(model, criterion, optimizer, input_ids, attn_mask):
return
logits
return
logits
def
init_1d_row_spec
(
model
):
def
init_1d_row_spec
(
model
,
pg
:
ProcessGroup
):
spec
=
TensorSpec
(
spec
=
TensorSpec
(
distspec
.
shard
(
pg
,
[
0
],
[
pg
.
tp_world_size
()]),
ComputeSpec
(
ComputePattern
.
TP1D
))
distspec
.
shard
(
gpc
.
get_group
(
ParallelMode
.
PARALLEL_1D
),
[
0
],
[
gpc
.
get_world_size
(
ParallelMode
.
PARALLEL_1D
)]),
ComputeSpec
(
ComputePattern
.
TP1D
))
with
DistSpecManager
.
no_grad
():
with
DistSpecManager
.
no_grad
():
for
n
,
p
in
model
.
named_parameters
():
for
n
,
p
in
model
.
named_parameters
():
if
'weight'
in
n
and
'ln'
not
in
n
:
if
'weight'
in
n
and
'ln'
not
in
n
:
p
.
set_tensor_spec
(
spec
)
p
.
set_tensor_spec
(
spec
)
def
init_1d_col_spec
(
model
):
def
init_1d_col_spec
(
model
,
pg
:
ProcessGroup
):
spec
=
TensorSpec
(
spec
=
TensorSpec
(
distspec
.
shard
(
pg
,
[
-
1
],
[
pg
.
tp_world_size
()]),
ComputeSpec
(
ComputePattern
.
TP1D
))
distspec
.
shard
(
gpc
.
get_group
(
ParallelMode
.
PARALLEL_1D
),
[
-
1
],
[
gpc
.
get_world_size
(
ParallelMode
.
PARALLEL_1D
)]),
ComputeSpec
(
ComputePattern
.
TP1D
))
with
DistSpecManager
.
no_grad
():
with
DistSpecManager
.
no_grad
():
for
n
,
p
in
model
.
named_parameters
():
for
n
,
p
in
model
.
named_parameters
():
if
'ln'
not
in
n
and
(
'weight'
in
n
or
'bias'
in
n
):
if
'ln'
not
in
n
and
(
'weight'
in
n
or
'bias'
in
n
):
...
@@ -79,44 +75,51 @@ def run_gpt(use_chunk, use_zero, placement_policy, tp_init_spec_func=None):
...
@@ -79,44 +75,51 @@ def run_gpt(use_chunk, use_zero, placement_policy, tp_init_spec_func=None):
for
torch_p
,
p
in
zip
(
torch_model
.
parameters
(),
model
.
parameters
()):
for
torch_p
,
p
in
zip
(
torch_model
.
parameters
(),
model
.
parameters
()):
torch_p
.
data
.
copy_
(
p
)
torch_p
.
data
.
copy_
(
p
)
world_size
=
torch
.
distributed
.
get_world_size
()
# world size, dp = 2, tp =2, construct a hybrid parallelism.
if
world_size
==
4
:
pg
=
ProcessGroup
(
tp_degree
=
2
)
else
:
pg
=
ProcessGroup
(
tp_degree
=
world_size
)
if
tp_init_spec_func
:
if
tp_init_spec_func
:
tp_init_spec_func
(
model
)
tp_init_spec_func
(
model
,
pg
)
chunk_size
=
ChunkManager
.
search_chunk_size
(
model
,
8192
,
8
)
if
use_chunk
else
None
chunk_size
=
ChunkManager
.
search_chunk_size
(
model
,
8192
,
8
)
if
use_chunk
else
None
chunk_manager
=
ChunkManager
(
chunk_size
,
chunk_manager
=
ChunkManager
(
chunk_size
,
enable_distributed_storage
=
use_zero
,
enable_distributed_storage
=
use_zero
,
init_device
=
GeminiManager
.
get_default_device
(
placement_policy
))
init_device
=
GeminiManager
.
get_default_device
(
placement_policy
))
gemini_manager
=
GeminiManager
(
placement_policy
,
chunk_manager
)
gemini_manager
=
GeminiManager
(
placement_policy
,
chunk_manager
)
model
=
ZeroDDP
(
model
,
gemini_manager
)
model
=
ZeroDDP
(
model
,
gemini_manager
,
pg
)
optim
=
HybridAdam
(
model
.
parameters
(),
lr
=
1e-3
)
optim
=
HybridAdam
(
model
.
parameters
(),
lr
=
1e-3
)
optim
=
ZeroOptimizer
(
optim
,
model
,
initial_scale
=
32
)
optim
=
ZeroOptimizer
(
optim
,
model
,
initial_scale
=
32
)
amp_config
=
dict
(
opt_level
=
'O2'
,
keep_batchnorm_fp32
=
False
,
loss_scale
=
32
)
amp_config
=
dict
(
opt_level
=
'O2'
,
keep_batchnorm_fp32
=
False
,
loss_scale
=
32
)
torch_optim
=
torch
.
optim
.
Adam
(
torch_model
.
parameters
(),
lr
=
1e-3
)
torch_optim
=
torch
.
optim
.
Adam
(
torch_model
.
parameters
(),
lr
=
1e-3
)
torch_model
,
torch_optim
=
convert_to_apex_amp
(
torch_model
,
torch_optim
,
amp_config
)
torch_model
,
torch_optim
=
convert_to_apex_amp
(
torch_model
,
torch_optim
,
amp_config
)
torch_model
=
DDP
(
torch_model
,
device_ids
=
[
gpc
.
get_global_
rank
()],
process_group
=
gpc
.
get_group
(
ParallelMode
.
DATA
))
torch_model
=
DDP
(
torch_model
,
device_ids
=
[
pg
.
rank
()],
process_group
=
pg
.
dp_process_group
(
))
print
(
chunk_manager
)
#
print(chunk_manager)
check_param_equal
(
model
,
torch_model
)
check_param_equal
(
model
,
torch_model
,
pg
)
model
.
train
()
model
.
train
()
torch_model
.
train
()
torch_model
.
train
()
set_seed
(
gpc
.
get
_local_rank
(
ParallelMode
.
DATA
))
set_seed
(
pg
.
dp
_local_rank
())
for
i
,
(
input_ids
,
attn_mask
)
in
enumerate
(
train_dataloader
):
for
i
,
(
input_ids
,
attn_mask
)
in
enumerate
(
train_dataloader
):
if
i
>
2
:
if
i
>
2
:
break
break
logits
=
run_fwd_bwd
(
model
,
criterion
,
optim
,
input_ids
,
attn_mask
)
logits
=
run_fwd_bwd
(
model
,
criterion
,
optim
,
input_ids
,
attn_mask
)
torch_logits
=
run_fwd_bwd
(
torch_model
,
criterion
,
torch_optim
,
input_ids
,
attn_mask
)
torch_logits
=
run_fwd_bwd
(
torch_model
,
criterion
,
torch_optim
,
input_ids
,
attn_mask
)
assert
tensor_equal
(
logits
,
torch_logits
)
assert
tensor_equal
(
logits
,
torch_logits
)
check_grad_equal
(
model
,
torch_model
)
check_grad_equal
(
model
,
torch_model
,
pg
)
optim
.
step
()
optim
.
step
()
torch_optim
.
step
()
torch_optim
.
step
()
check_param_equal
(
model
,
torch_model
)
check_param_equal
(
model
,
torch_model
,
pg
)
def
run_dist
(
rank
,
world_size
,
port
):
def
run_dist
(
rank
,
world_size
,
port
):
config
=
{}
config
=
{}
if
world_size
==
4
:
config
[
'parallel'
]
=
{
'tensor'
:
{
'mode'
:
'1d'
,
'size'
:
2
}}
colossalai
.
launch
(
config
=
config
,
rank
=
rank
,
world_size
=
world_size
,
host
=
'localhost'
,
port
=
port
,
backend
=
'nccl'
)
colossalai
.
launch
(
config
=
config
,
rank
=
rank
,
world_size
=
world_size
,
host
=
'localhost'
,
port
=
port
,
backend
=
'nccl'
)
if
world_size
==
4
:
if
world_size
==
4
:
run_gpt
(
tp_init_spec_func
=
init_1d_col_spec
)
run_gpt
(
tp_init_spec_func
=
init_1d_col_spec
)
...
@@ -126,6 +129,7 @@ def run_dist(rank, world_size, port):
...
@@ -126,6 +129,7 @@ def run_dist(rank, world_size, port):
@
pytest
.
mark
.
dist
@
pytest
.
mark
.
dist
@
pytest
.
mark
.
skip
(
"under development"
)
@
pytest
.
mark
.
parametrize
(
'world_size'
,
[
1
,
4
])
@
pytest
.
mark
.
parametrize
(
'world_size'
,
[
1
,
4
])
@
rerun_if_address_is_in_use
()
@
rerun_if_address_is_in_use
()
def
test_gpt
(
world_size
):
def
test_gpt
(
world_size
):
...
...
tests/test_zero/test_sharded_optim_state_dict.py
View file @
060b917d
import
pytest
import
pytest
import
colossalai
import
colossalai
import
torch
import
torch
from
colossalai.context.parallel_mode
import
ParallelMode
import
torch.multiprocessing
as
mp
import
torch.multiprocessing
as
mp
from
colossalai.testing
import
rerun_if_address_is_in_use
from
colossalai.testing
import
rerun_if_address_is_in_use
from
colossalai.utils.cuda
import
get_current_device
from
colossalai.utils.cuda
import
get_current_device
from
colossalai.utils
import
free_port
from
colossalai.utils
import
free_port
from
colossalai.core
import
global_context
as
gpc
from
functools
import
partial
from
functools
import
partial
from
tests.test_tensor._utils
import
set_seed
from
tests.test_tensor._utils
import
set_seed
from
tests.components_to_test.registry
import
non_distributed_component_funcs
from
tests.components_to_test.registry
import
non_distributed_component_funcs
...
@@ -16,6 +14,7 @@ from colossalai.zero.init_ctx import ZeroInitContext
...
@@ -16,6 +14,7 @@ from colossalai.zero.init_ctx import ZeroInitContext
from
colossalai.zero.shard_utils
import
TensorShardStrategy
from
colossalai.zero.shard_utils
import
TensorShardStrategy
from
colossalai.zero.sharded_model
import
ShardedModelV2
from
colossalai.zero.sharded_model
import
ShardedModelV2
from
colossalai.zero.sharded_optim
import
ShardedOptimizerV2
from
colossalai.zero.sharded_optim
import
ShardedOptimizerV2
from
colossalai.tensor
import
ProcessGroup
def
init_zero
(
model_builder
,
placement_policy
):
def
init_zero
(
model_builder
,
placement_policy
):
...
@@ -64,7 +63,8 @@ def run_nested_model(placement_policy):
...
@@ -64,7 +63,8 @@ def run_nested_model(placement_policy):
model
.
train
()
model
.
train
()
model_copy
.
train
()
model_copy
.
train
()
set_seed
(
gpc
.
get_local_rank
(
ParallelMode
.
DATA
))
pg
=
ProcessGroup
()
set_seed
(
pg
.
dp_local_rank
())
data_iter
=
iter
(
train_dataloader
)
data_iter
=
iter
(
train_dataloader
)
data
,
label
=
map
(
lambda
x
:
x
.
cuda
(),
next
(
data_iter
))
data
,
label
=
map
(
lambda
x
:
x
.
cuda
(),
next
(
data_iter
))
...
...
tests/test_zero/test_zero_optim_state_dict.py
View file @
060b917d
...
@@ -16,6 +16,7 @@ from colossalai.gemini import ChunkManager, GeminiManager
...
@@ -16,6 +16,7 @@ from colossalai.gemini import ChunkManager, GeminiManager
from
colossalai.testing
import
parameterize
from
colossalai.testing
import
parameterize
from
colossalai.nn.optimizer
import
HybridAdam
from
colossalai.nn.optimizer
import
HybridAdam
from
colossalai.zero
import
ZeroOptimizer
from
colossalai.zero
import
ZeroOptimizer
from
colossalai.tensor
import
ProcessGroup
def
init_zero
(
model
,
use_chunk
,
use_zero
,
placement_policy
):
def
init_zero
(
model
,
use_chunk
,
use_zero
,
placement_policy
):
...
@@ -24,7 +25,8 @@ def init_zero(model, use_chunk, use_zero, placement_policy):
...
@@ -24,7 +25,8 @@ def init_zero(model, use_chunk, use_zero, placement_policy):
enable_distributed_storage
=
use_zero
,
enable_distributed_storage
=
use_zero
,
init_device
=
GeminiManager
.
get_default_device
(
placement_policy
))
init_device
=
GeminiManager
.
get_default_device
(
placement_policy
))
gemini_manager
=
GeminiManager
(
placement_policy
,
chunk_manager
)
gemini_manager
=
GeminiManager
(
placement_policy
,
chunk_manager
)
return
ZeroDDP
(
model
,
gemini_manager
)
pg
=
ProcessGroup
()
return
ZeroDDP
(
model
,
gemini_manager
,
pg
)
def
run_step
(
model
,
optim
,
criterion
,
data
,
label
):
def
run_step
(
model
,
optim
,
criterion
,
data
,
label
):
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment