ColossalAI · Commit 1b416864 (Unverified)
Authored Jul 15, 2022 by HELSON; committed via GitHub on Jul 15, 2022
Parent commit: 9e4c6449

[hotfix] fix unit test test_module_spec (#1321)
Changes: 3 changed files with 29 additions and 22 deletions (+29 -22)

  colossalai/nn/parallel/layers/module_utils.py    +2  -1
  colossalai/tensor/colo_tensor.py                 +13 -11
  tests/test_tensor/test_module_spec.py            +14 -10
colossalai/nn/parallel/layers/module_utils.py

@@ -88,7 +88,7 @@ def init_colo_module(module: torch.nn.Module,
     compute_pattern = compute_spec.compute_pattern
     if is_colo_module(module):
         # for each param
-        # set DistSpec and ComputeSpec
+        # set its process_group, dist_spec and compute_spec
         colo_module = get_colo_module(module)
         colo_module.register(compute_pattern, pg)
         if not colo_module.has_compute_pattern_with_mode(compute_pattern, mode=mode):
@@ -101,6 +101,7 @@ def init_colo_module(module: torch.nn.Module,
                 continue
             param = module.get_parameter(param_name)
             if isinstance(param, ColoParameter):
+                param.set_process_group(pg)
                 param.set_dist_spec(dist_spec)
                 param.compute_spec = compute_spec
                 for mod in param.shared_param_modules:
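The second hunk adds a single call, param.set_process_group(pg), ahead of param.set_dist_spec(dist_spec). The position matters: as the colo_tensor.py hunk below shows, set_process_group only accepts a tensor whose dist spec is still REPLICATE, so each parameter has to be bound to its process group before a sharding spec is applied to it. A minimal sketch of that ordering, using only the calls that appear in the hunk (the wrapper function itself is invented for illustration and is not part of ColossalAI):

```python
# Hypothetical helper illustrating the order enforced by this hunk; the three
# calls on `param` come straight from the diff, the wrapper does not exist upstream.
def bind_param_to_group(param, pg, dist_spec, compute_spec):
    # Bind the parameter to its process group while it is still replicated;
    # ColoTensor.set_process_group (next file) rejects non-replicated tensors.
    param.set_process_group(pg)
    # Only then apply the (possibly sharding) distribution spec ...
    param.set_dist_spec(dist_spec)
    # ... and record how the parameter participates in compute.
    param.compute_spec = compute_spec
```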
colossalai/tensor/colo_tensor.py

@@ -121,11 +121,13 @@ class ColoTensor(torch.Tensor):
             RuntimeError:
         """
         assert isinstance(pg, ProcessGroup), f"pg as type {type(pg)} is invalid"
-        if self.process_group.tp_world_size() != 1:
-            raise RuntimeError("can not set_process_group on a ColoTensor whose process_group has tp world group")
-        if self.dist_spec.placement.value != 'r':
-            raise RuntimeError("can not set_process_group on a ColoTensor whose dist spec is not REPLICATE")
+        # if the new pg is the same as the old pg, just returns
+        if self.process_group == pg:
+            return
+        assert self.process_group.tp_world_size() == 1, \
+            "Can not set_process_group on a ColoTensor whose process_group has tp world group"
+        assert self.dist_spec.placement.value == 'r', \
+            "Can not set_process_group on a ColoTensor whose dist spec is not REPLICATE"
         self.process_group = pg
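Assembled from the new side of that hunk, the touched part of ColoTensor.set_process_group reads as below: reassigning the same group becomes a no-op, and rebinding to a different group is now guarded by assertions rather than explicit RuntimeErrors (the docstring's RuntimeError entry is left untouched by the commit, even though both checks now raise AssertionError). The method signature and anything in the collapsed context are assumptions; the body is taken from the diff:

```python
def set_process_group(self, pg: ProcessGroup):
    # Signature assumed; body reconstructed from the new side of the hunk above.
    assert isinstance(pg, ProcessGroup), f"pg as type {type(pg)} is invalid"
    # if the new pg is the same as the old pg, just returns
    if self.process_group == pg:
        return
    # Only a replicated tensor on a group without tensor-parallel splitting
    # may be rebound to a new process group.
    assert self.process_group.tp_world_size() == 1, \
        "Can not set_process_group on a ColoTensor whose process_group has tp world group"
    assert self.dist_spec.placement.value == 'r', \
        "Can not set_process_group on a ColoTensor whose dist spec is not REPLICATE"
    self.process_group = pg
```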
tests/test_tensor/test_module_spec.py

-from copy import copy
+from copy import deepcopy
 import pytest
 from functools import partial
 import torch
 import torch.multiprocessing as mp
-from colossalai.tensor import ColoTensorSpec, ComputePattern, ComputeSpec, ShardSpec, ReplicaSpec
+from colossalai.tensor import ColoTensor, ComputePattern, ComputeSpec, ShardSpec, ColoTensorSpec
 from colossalai.nn.parallel.layers import init_colo_module, check_colo_module
 from _utils import tensor_equal, tensor_shard_equal, set_seed

@@ -112,21 +112,25 @@ def run_linear_with_spec(mode):
     with ColoInitContext(device=get_current_device()):
         model = torch.nn.Linear(4, 8)
-    model_handy = copy(model)
+    model_handy = deepcopy(model)
     world_size = torch.distributed.get_world_size()
     pg = ProcessGroup(tp_degree=world_size)
     compute_spec = ComputeSpec(ComputePattern.TP1D)
     init_colo_module(model, compute_spec, pg=pg, recursive=True, mode=mode)
     x = torch.rand(2, 4).cuda()
+    colo_x = ColoTensor.from_torch_tensor(x, ColoTensorSpec(pg))
     out = model(x)
-    colo_out = model_handy(x)
+    colo_out = model_handy(colo_x)
     assert tensor_equal(out, colo_out)
     grad = torch.rand_like(out)
     out.backward(grad)
     colo_out.backward(grad)
-    assert tensor_shard_equal(model.weight.grad, model_handy.weight.grad, pg.tp_local_rank(), pg.tp_world_size())
-    assert tensor_shard_equal(model.bias.grad, model_handy.bias.grad, pg.tp_local_rank(), pg.tp_world_size())
+    assert tensor_shard_equal(model_handy.weight.grad, model.weight.grad, pg.tp_local_rank(), pg.tp_world_size())
+    assert tensor_shard_equal(model_handy.bias.grad, model.bias.grad, pg.tp_local_rank(), pg.tp_world_size())


 def run_check_shared_param():

@@ -196,7 +200,7 @@ def run_dist_check(rank, world_size, port):
 @pytest.mark.dist
 @pytest.mark.parametrize('world_size', [1, 4])
-@pytest.mark.skip("under development lazy init ColoParameter in Context")
+@pytest.mark.skip("for higher testing speed")
 @rerun_if_address_is_in_use()
 def test_module_linear_1d(world_size):
     run_func = partial(run_dist, world_size=world_size, port=free_port())

@@ -205,7 +209,7 @@ def test_module_linear_1d(world_size):
 @pytest.mark.dist
 @pytest.mark.parametrize('world_size', [1, 4])
-@pytest.mark.skip("under development lazy init ColoParameter in Context")
+@pytest.mark.skip("for higher testing speed")
 @rerun_if_address_is_in_use()
 def test_module_model(world_size):
     run_func = partial(run_dist_model, world_size=world_size, port=free_port())

@@ -214,7 +218,7 @@ def test_module_model(world_size):
 @pytest.mark.dist
 @pytest.mark.parametrize('world_size', [1, 2])
-@pytest.mark.skip("under development lazy init ColoParameter in Context")
+@pytest.mark.skip("for higher testing speed")
 @rerun_if_address_is_in_use()
 def test_module_check(world_size):
     run_func = partial(run_dist_check, world_size=world_size, port=free_port())

@@ -222,4 +226,4 @@ def test_module_check(world_size):
 if __name__ == '__main__':
-    test_module_check(2)
+    test_module_linear_1d(4)
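The substantive fix in this file is the switch from copy to deepcopy: copy.copy on an nn.Module is shallow, so the "copy" shares the very same Parameter objects (and parameter dict) as the original, and initializing one model with tensor-parallel specs would mutate the other as well; deepcopy gives model_handy an independent set of parameters. A small self-contained illustration of the difference, in plain PyTorch and independent of ColossalAI:

```python
import torch
from copy import copy, deepcopy

m = torch.nn.Linear(4, 8)
shallow = copy(m)       # shallow copy: shares the original's parameter objects
deep = deepcopy(m)      # deep copy: owns independent parameter objects

assert shallow.weight is m.weight           # same tensor, so spec changes would leak across copies
assert deep.weight is not m.weight          # independent tensor
assert torch.equal(deep.weight, m.weight)   # numerically identical at copy time
```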