Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
bb640ec7
Unverified
Commit
bb640ec7
authored
Jul 26, 2022
by
Boyuan Yao
Committed by
GitHub
Jul 26, 2022
Browse files
[fx] Add colotracer compatibility test on torchrec (#1370)
parent
c415240d
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
208 additions
and
0 deletions
+208
-0
requirements/requirements-test.txt
requirements/requirements-test.txt
+1
-0
tests/test_fx/test_tracer/test_torchrec_model/test_deepfm_model.py
...t_fx/test_tracer/test_torchrec_model/test_deepfm_model.py
+91
-0
tests/test_fx/test_tracer/test_torchrec_model/test_dlrm_model.py
...est_fx/test_tracer/test_torchrec_model/test_dlrm_model.py
+116
-0
No files found.
requirements/requirements-test.txt
View file @
bb640ec7
...
@@ -3,3 +3,4 @@ torchvision
...
@@ -3,3 +3,4 @@ torchvision
transformers
transformers
timm
timm
titans
titans
torchrec
tests/test_fx/test_tracer/test_torchrec_model/test_deepfm_model.py
0 → 100644
View file @
bb640ec7
from
curses
import
meta
from
math
import
dist
from
xml.dom
import
HierarchyRequestErr
from
colossalai.fx.tracer
import
meta_patch
from
colossalai.fx.tracer.tracer
import
ColoTracer
from
colossalai.fx.tracer.meta_patch.patched_function
import
python_ops
import
torch
from
torchrec.sparse.jagged_tensor
import
KeyedTensor
,
KeyedJaggedTensor
from
torchrec.modules.embedding_modules
import
EmbeddingBagCollection
from
torchrec.modules.embedding_configs
import
EmbeddingBagConfig
from
torchrec.models
import
deepfm
,
dlrm
import
colossalai.fx
as
fx
import
pdb
from
torch.fx
import
GraphModule
# Leading (batch) dimension of the random dense inputs built in the test below.
BATCH = 2
# Single size reused for the dense feature width, the embedding dimension,
# the number of embeddings per table and the layer sizes of the models.
SHAPE = 10
def test_torchrec_deepfm_models():
    """Compare ColoTracer-traced torchrec DeepFM modules with the eager modules.

    Every module class listed below is instantiated, traced with
    ``ColoTracer``, wrapped into a ``GraphModule`` and switched to eval mode
    together with the original module. Both are then called on the same
    inputs under ``torch.no_grad()`` and their outputs must satisfy
    ``torch.allclose``.
    """
    # Data Preparation
    # EmbeddingBagCollection built from two tables, one feature each.
    table_configs = [
        EmbeddingBagConfig(name="t1", embedding_dim=SHAPE, num_embeddings=SHAPE, feature_names=["f1"]),
        EmbeddingBagConfig(name="t2", embedding_dim=SHAPE, num_embeddings=SHAPE, feature_names=["f2"]),
    ]
    ebc = EmbeddingBagCollection(tables=table_configs)
    feature_keys = ["f1", "f2"]

    # KeyedTensor
    keyed_tensor = KeyedTensor(
        keys=feature_keys,
        length_per_key=[SHAPE, SHAPE],
        values=torch.rand((BATCH, 2 * SHAPE)),
    )

    # KeyedJaggedTensor
    keyed_jagged_tensor = KeyedJaggedTensor.from_offsets_sync(
        keys=feature_keys,
        values=torch.tensor([1, 2, 3, 4, 5, 6, 7, 8]),
        offsets=torch.tensor([0, 2, 4, 6, 8]),
    )

    # Dense Features
    dense_input = torch.rand((BATCH, SHAPE))

    # One row per module under test:
    # (module class, constructor arguments, forward arguments).
    # The rows are processed top to bottom, so modules are created in this order.
    cases = [
        (deepfm.DenseArch, (SHAPE, SHAPE, SHAPE), (dense_input,)),
        (deepfm.FMInteractionArch, (SHAPE * 3, feature_keys, SHAPE), (dense_input, keyed_tensor)),
        (deepfm.OverArch, (SHAPE,), (dense_input,)),
        (deepfm.SimpleDeepFMNN, (SHAPE, ebc, SHAPE, SHAPE), (dense_input, keyed_jagged_tensor)),
        (deepfm.SparseArch, (ebc,), (keyed_jagged_tensor,)),
    ]

    # Tracer
    tracer = ColoTracer()

    for model_cls, ctor_args, forward_args in cases:
        # Initializing model
        model = model_cls(*ctor_args)

        # Setup GraphModule
        traced_graph = tracer.trace(model)
        gm = GraphModule(model, traced_graph, model.__class__.__name__)
        gm.recompile()

        model.eval()
        gm.eval()

        # Aligned Test: traced module first, then the eager one.
        with torch.no_grad():
            fx_out = gm(*forward_args)
            non_fx_out = model(*forward_args)

        # Non-tensor outputs are compared through their ``values()`` tensor.
        if torch.is_tensor(fx_out):
            fx_cmp, eager_cmp = fx_out, non_fx_out
        else:
            fx_cmp, eager_cmp = fx_out.values(), non_fx_out.values()

        assert torch.allclose(
            fx_cmp, eager_cmp), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}'
# Run the DeepFM tracer test directly when this file is executed as a script.
if __name__ == "__main__":
    test_torchrec_deepfm_models()
tests/test_fx/test_tracer/test_torchrec_model/test_dlrm_model.py
0 → 100644
View file @
bb640ec7
from
curses
import
meta
from
math
import
dist
from
xml.dom
import
HierarchyRequestErr
from
colossalai.fx.tracer
import
meta_patch
from
colossalai.fx.tracer.tracer
import
ColoTracer
from
colossalai.fx.tracer.meta_patch.patched_function
import
python_ops
import
torch
from
torchrec.sparse.jagged_tensor
import
KeyedTensor
,
KeyedJaggedTensor
from
torchrec.modules.embedding_modules
import
EmbeddingBagCollection
from
torchrec.modules.embedding_configs
import
EmbeddingBagConfig
from
torchrec.models
import
deepfm
,
dlrm
import
colossalai.fx
as
fx
import
pdb
from
torch.fx
import
GraphModule
# Leading (batch) dimension of the random dense and sparse inputs built in the test below.
BATCH = 2
# Single size reused for the dense feature width, the embedding dimension,
# the number of embeddings per table and the layer sizes of the models.
SHAPE = 10
def test_torchrec_dlrm_models():
    """Compare ColoTracer-traced torchrec DLRM modules with the eager modules.

    Every module class in ``MODEL_LIST`` is instantiated, traced with
    ``ColoTracer`` (``InteractionV2Arch`` is traced with concrete arguments),
    wrapped into a ``GraphModule`` and switched to eval mode together with
    the original module. Both are then called on the same inputs under
    ``torch.no_grad()`` and their outputs must satisfy ``torch.allclose``.
    """
    MODEL_LIST = [
        dlrm.DLRM,
        dlrm.DenseArch,
        dlrm.InteractionArch,
        dlrm.InteractionV2Arch,
        dlrm.OverArch,
        dlrm.SparseArch,
        # dlrm.DLRMV2
    ]

    # Data Preparation
    # EmbeddingBagCollection built from two tables, one feature each.
    table_configs = [
        EmbeddingBagConfig(name="t1", embedding_dim=SHAPE, num_embeddings=SHAPE, feature_names=["f1"]),
        EmbeddingBagConfig(name="t2", embedding_dim=SHAPE, num_embeddings=SHAPE, feature_names=["f2"]),
    ]
    ebc = EmbeddingBagCollection(tables=table_configs)
    feature_keys = ["f1", "f2"]

    # KeyedTensor: not passed to any module below, but still constructed
    # (including its ``torch.rand`` call) exactly as before.
    KT = KeyedTensor(
        keys=feature_keys,
        length_per_key=[SHAPE, SHAPE],
        values=torch.rand((BATCH, 2 * SHAPE)),
    )

    # KeyedJaggedTensor
    KJT = KeyedJaggedTensor.from_offsets_sync(
        keys=feature_keys,
        values=torch.tensor([1, 2, 3, 4, 5, 6, 7, 8]),
        offsets=torch.tensor([0, 2, 4, 6, 8]),
    )

    # Dense Features
    dense_features = torch.rand((BATCH, SHAPE))

    # Sparse Features
    sparse_features = torch.rand((BATCH, len(feature_keys), SHAPE))

    def build_model(model_cls):
        """Instantiate ``model_cls`` with the constructor arguments used by this test."""
        if model_cls == dlrm.DLRM:
            return model_cls(ebc, SHAPE, [SHAPE, SHAPE], [5, 1])
        if model_cls == dlrm.DenseArch:
            return model_cls(SHAPE, [SHAPE, SHAPE])
        if model_cls == dlrm.InteractionArch:
            return model_cls(len(feature_keys))
        if model_cls == dlrm.InteractionV2Arch:
            # The two sub-modules are created only when this class is reached.
            first_interaction = dlrm.DenseArch(3 * SHAPE, [3 * SHAPE, 3 * SHAPE])
            second_interaction = dlrm.DenseArch(3 * SHAPE, [3 * SHAPE, 3 * SHAPE])
            return model_cls(len(feature_keys), first_interaction, second_interaction)
        if model_cls == dlrm.OverArch:
            return model_cls(SHAPE, [5, 1])
        if model_cls == dlrm.SparseArch:
            return model_cls(ebc)
        if model_cls == dlrm.DLRMV2:
            # Currently DLRMV2 cannot be traced
            return model_cls(ebc, SHAPE, [SHAPE, SHAPE], [5, 1], [4 * SHAPE, 4 * SHAPE], [4 * SHAPE, 4 * SHAPE])
        return None

    def forward_inputs(model_cls):
        """Return the positional inputs fed to both the traced and the eager module."""
        if model_cls == dlrm.DLRM or model_cls == dlrm.DLRMV2:
            return (dense_features, KJT)
        if model_cls == dlrm.DenseArch:
            return (dense_features,)
        if model_cls == dlrm.InteractionArch or model_cls == dlrm.InteractionV2Arch:
            return (dense_features, sparse_features)
        if model_cls == dlrm.OverArch:
            return (dense_features,)
        if model_cls == dlrm.SparseArch:
            return (KJT,)
        return None

    # Tracer
    tracer = ColoTracer()

    for model_cls in MODEL_LIST:
        # Initializing model
        model = build_model(model_cls)

        # Setup GraphModule
        trace_kwargs = {}
        if model_cls == dlrm.InteractionV2Arch:
            # Only this module is traced with concrete input tensors.
            trace_kwargs["concrete_args"] = {
                "dense_features": dense_features,
                "sparse_features": sparse_features,
            }
        traced_graph = tracer.trace(model, **trace_kwargs)

        gm = GraphModule(model, traced_graph, model.__class__.__name__)
        gm.recompile()

        model.eval()
        gm.eval()

        # Aligned Test: traced module first, then the eager one.
        with torch.no_grad():
            call_args = forward_inputs(model_cls)
            fx_out = gm(*call_args)
            non_fx_out = model(*call_args)

        # Non-tensor outputs are compared through their ``values()`` tensor.
        if torch.is_tensor(fx_out):
            fx_cmp, eager_cmp = fx_out, non_fx_out
        else:
            fx_cmp, eager_cmp = fx_out.values(), non_fx_out.values()

        assert torch.allclose(
            fx_cmp, eager_cmp), f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}'
# Run the DLRM tracer test directly when this file is executed as a script.
if __name__ == "__main__":
    test_torchrec_dlrm_models()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment