Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
1e4bf85c
Commit
1e4bf85c
authored
Mar 11, 2022
by
Frank Lee
Browse files
fixed bug in activation checkpointing test (#387)
parent
3af13a2c
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
25 additions
and
14 deletions
+25
-14
colossalai/context/random/__init__.py
colossalai/context/random/__init__.py
+4
-6
colossalai/context/random/_helper.py
colossalai/context/random/_helper.py
+6
-1
colossalai/context/random/seed_manager.py
colossalai/context/random/seed_manager.py
+6
-2
tests/test_utils/test_activation_checkpointing.py
tests/test_utils/test_activation_checkpointing.py
+9
-5
No files found.
colossalai/context/random/__init__.py
View file @
1e4bf85c
from
._helper
import
(
seed
,
set_mode
,
with_seed
,
add_seed
,
get_seeds
,
get_states
,
get_current_mode
,
set_seed_states
,
sync_states
,
moe_set_seed
)
from
._helper
import
(
seed
,
set_mode
,
with_seed
,
add_seed
,
get_seeds
,
get_states
,
get_current_mode
,
set_seed_states
,
sync_states
,
moe_set_seed
,
reset_seeds
)
# Public API of colossalai.context.random.
# NOTE(review): the scraped diff merged the old and new export lists — every
# name was duplicated and a missing comma between 'moe_set_seed' and 'seed'
# would have produced a bogus concatenated export 'moe_set_seedseed' via
# implicit string literal concatenation. This is the single, de-duplicated
# post-commit list, including the newly added 'reset_seeds'.
__all__ = [
    'seed',
    'set_mode',
    'with_seed',
    'add_seed',
    'get_seeds',
    'get_states',
    'get_current_mode',
    'set_seed_states',
    'sync_states',
    'moe_set_seed',
    'reset_seeds',
]
colossalai/context/random/_helper.py
View file @
1e4bf85c
...
...
@@ -154,4 +154,9 @@ def moe_set_seed(seed):
global_rank
=
gpc
.
get_global_rank
()
add_seed
(
ParallelMode
.
TENSOR
,
global_rank
,
True
)
print
(
f
"moe seed condition:
{
global_rank
}
with moe seed
{
moe_mp_seed
}
, "
,
f
"tensor seed
{
global_rank
}
"
,
flush
=
True
)
f
"tensor seed
{
global_rank
}
"
,
flush
=
True
)
def reset_seeds():
    """Discard every registered seed and RNG state.

    Delegates to the module-level singleton seed manager; after this call
    other code may register seeds afresh (needed because the manager is a
    process-wide singleton shared across tests).
    """
    manager = _SEED_MANAGER
    manager.reset()
colossalai/context/random/seed_manager.py
View file @
1e4bf85c
...
...
@@ -66,8 +66,7 @@ class SeedManager:
:raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance of
:class:`colossalai.context.ParallelMode` or the seed for `parallel_mode` has been added
"""
assert
isinstance
(
parallel_mode
,
ParallelMode
),
'A valid ParallelMode must be provided'
assert
isinstance
(
parallel_mode
,
ParallelMode
),
'A valid ParallelMode must be provided'
if
overwrtie
is
False
:
assert
parallel_mode
not
in
self
.
_seed_states
,
f
'The seed for
{
parallel_mode
}
has been added'
elif
parallel_mode
in
self
.
_seed_states
:
...
...
@@ -78,3 +77,8 @@ class SeedManager:
self
.
_seed_states
[
parallel_mode
]
=
torch
.
cuda
.
get_rng_state
()
self
.
_seeds
[
parallel_mode
]
=
seed
torch
.
cuda
.
set_rng_state
(
current_state
)
def reset(self):
    """Return the manager to its pristine state.

    Clears the current parallel mode and empties both the seed registry
    and the saved RNG-state registry.
    """
    self._current_mode = None
    self._seeds = {}
    self._seed_states = {}
tests/test_utils/test_activation_checkpointing.py
View file @
1e4bf85c
...
...
@@ -7,7 +7,7 @@ import torch.nn.functional as F
from
torch.utils.checkpoint
import
checkpoint
from
colossalai.context.parallel_mode
import
ParallelMode
from
colossalai.context.random
import
add_seed
,
seed
,
set_mode
from
colossalai.context.random
import
add_seed
,
seed
,
set_mode
,
reset_seeds
from
colossalai.utils
import
checkpoint
...
...
@@ -17,12 +17,12 @@ def forward(x, weight):
out_
=
F
.
dropout
(
out
,
p
=
0.4
,
training
=
True
)
return
out_
@
pytest
.
mark
.
gpu
@
pytest
.
mark
.
parametrize
(
"cpu_offload"
,
[
True
,
False
])
def
test_activation_checkpointing
(
cpu_offload
):
if
cpu_offload
:
add_seed
(
ParallelMode
.
GLOBAL
,
1024
)
add_seed
(
ParallelMode
.
DATA
,
1026
)
add_seed
(
ParallelMode
.
GLOBAL
,
1024
)
add_seed
(
ParallelMode
.
DATA
,
1026
)
set_mode
(
ParallelMode
.
GLOBAL
)
global_cuda_rng_state
=
torch
.
cuda
.
get_rng_state
()
set_mode
(
ParallelMode
.
DATA
)
...
...
@@ -56,4 +56,8 @@ def test_activation_checkpointing(cpu_offload):
assert
torch
.
all
(
data
.
grad
==
data_
.
grad
),
'Gradient of the input does not match'
torch
.
cuda
.
empty_cache
()
# as seed manager is singleton
# if we don't reset seeds here,
# other tests will fail if running together with this test
# as other tests can't overwrite the seed set by this test
reset_seeds
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment