Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
169954f8
Unverified
Commit
169954f8
authored
Jul 15, 2022
by
Frank Lee
Committed by
GitHub
Jul 15, 2022
Browse files
[test] removed outdated unit test for meta context (#1329)
parent
7a053671
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
0 additions
and
55 deletions
+0
-55
tests/test_utils/test_materialize_arbitary_lazy_module.py
tests/test_utils/test_materialize_arbitary_lazy_module.py
+0
-55
No files found.
tests/test_utils/test_materialize_arbitary_lazy_module.py
deleted
100644 → 0
View file @
7a053671
import
torch
from
colossalai.utils.model.lazy_init_context
import
LazyInitContext
from
torchvision.models
import
resnet34
import
random
import
numpy
as
np
MANUAL_SEED
=
0
random
.
seed
(
MANUAL_SEED
)
np
.
random
.
seed
(
MANUAL_SEED
)
torch
.
manual_seed
(
MANUAL_SEED
)
class
MLP
(
torch
.
nn
.
Module
):
def
__init__
(
self
,
dim
:
int
=
4
):
super
().
__init__
()
intermediate_dim
=
dim
*
4
self
.
dense_1
=
torch
.
nn
.
Linear
(
dim
,
intermediate_dim
)
self
.
activation
=
torch
.
nn
.
GELU
()
self
.
dense_2
=
torch
.
nn
.
Linear
(
intermediate_dim
,
dim
)
self
.
dropout
=
torch
.
nn
.
Dropout
(
0.1
)
def
forward
(
self
,
x
):
x
=
self
.
dense_1
(
x
)
x
=
self
.
activation
(
x
)
x
=
self
.
dense_2
(
x
)
x
=
self
.
dropout
(
x
)
return
x
def
test_lazy_init
():
cpu_rng_state
=
torch
.
get_rng_state
()
origin_model
=
MLP
()
origin_param_dict
=
dict
(
origin_model
.
named_parameters
())
torch
.
set_rng_state
(
cpu_rng_state
)
ctx
=
LazyInitContext
()
with
ctx
:
model
=
MLP
()
for
param
in
model
.
parameters
():
assert
param
.
is_meta
for
buffer
in
model
.
buffers
():
assert
buffer
.
is_meta
for
module
in
model
.
children
():
ctx
.
lazy_init_parameters
(
module
)
for
param
in
module
.
parameters
():
assert
not
param
.
is_meta
for
buffer
in
module
.
buffers
():
assert
not
buffer
.
is_meta
param_dict
=
dict
(
model
.
named_parameters
())
for
key
in
origin_param_dict
.
keys
():
assert
origin_param_dict
[
key
].
data
.
equal
(
param_dict
[
key
].
data
)
if
__name__
==
'__main__'
:
test_lazy_init
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment