Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
ca257a06
"vscode:/vscode.git/clone" did not exist on "68fa1e855bef7d77e227686543787d8e2c4463fc"
Unverified
Commit
ca257a06
authored
Sep 22, 2021
by
Lysandre Debut
Committed by
GitHub
Sep 22, 2021
Browse files
Fix torchscript tests (#13701)
parent
5b570754
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
9 additions
and
4 deletions
+9
-4
tests/test_modeling_convbert.py
tests/test_modeling_convbert.py
+1
-1
tests/test_modeling_distilbert.py
tests/test_modeling_distilbert.py
+1
-1
tests/test_modeling_flaubert.py
tests/test_modeling_flaubert.py
+7
-2
No files found.
tests/test_modeling_convbert.py
View file @
ca257a06
...
@@ -436,7 +436,7 @@ class ConvBertModelTest(ModelTesterMixin, unittest.TestCase):
...
@@ -436,7 +436,7 @@ class ConvBertModelTest(ModelTesterMixin, unittest.TestCase):
with
tempfile
.
TemporaryDirectory
()
as
tmp
:
with
tempfile
.
TemporaryDirectory
()
as
tmp
:
torch
.
jit
.
save
(
traced_model
,
os
.
path
.
join
(
tmp
,
"traced_model.pt"
))
torch
.
jit
.
save
(
traced_model
,
os
.
path
.
join
(
tmp
,
"traced_model.pt"
))
loaded
=
torch
.
jit
.
load
(
os
.
path
.
join
(
tmp
,
"
bert
.pt"
),
map_location
=
torch_device
)
loaded
=
torch
.
jit
.
load
(
os
.
path
.
join
(
tmp
,
"
traced_model
.pt"
),
map_location
=
torch_device
)
loaded
(
inputs_dict
[
"input_ids"
].
to
(
torch_device
),
inputs_dict
[
"attention_mask"
].
to
(
torch_device
))
loaded
(
inputs_dict
[
"input_ids"
].
to
(
torch_device
),
inputs_dict
[
"attention_mask"
].
to
(
torch_device
))
...
...
tests/test_modeling_distilbert.py
View file @
ca257a06
...
@@ -273,7 +273,7 @@ class DistilBertModelTest(ModelTesterMixin, unittest.TestCase):
...
@@ -273,7 +273,7 @@ class DistilBertModelTest(ModelTesterMixin, unittest.TestCase):
with
tempfile
.
TemporaryDirectory
()
as
tmp
:
with
tempfile
.
TemporaryDirectory
()
as
tmp
:
torch
.
jit
.
save
(
traced_model
,
os
.
path
.
join
(
tmp
,
"traced_model.pt"
))
torch
.
jit
.
save
(
traced_model
,
os
.
path
.
join
(
tmp
,
"traced_model.pt"
))
loaded
=
torch
.
jit
.
load
(
os
.
path
.
join
(
tmp
,
"
bert
.pt"
),
map_location
=
torch_device
)
loaded
=
torch
.
jit
.
load
(
os
.
path
.
join
(
tmp
,
"
traced_model
.pt"
),
map_location
=
torch_device
)
loaded
(
inputs_dict
[
"input_ids"
].
to
(
torch_device
),
inputs_dict
[
"attention_mask"
].
to
(
torch_device
))
loaded
(
inputs_dict
[
"input_ids"
].
to
(
torch_device
),
inputs_dict
[
"attention_mask"
].
to
(
torch_device
))
...
...
tests/test_modeling_flaubert.py
View file @
ca257a06
...
@@ -325,7 +325,12 @@ class FlaubertModelTester(object):
...
@@ -325,7 +325,12 @@ class FlaubertModelTester(object):
choice_labels
,
choice_labels
,
input_mask
,
input_mask
,
)
=
config_and_inputs
)
=
config_and_inputs
inputs_dict
=
{
"input_ids"
:
input_ids
,
"token_type_ids"
:
token_type_ids
,
"lengths"
:
input_lengths
}
inputs_dict
=
{
"input_ids"
:
input_ids
,
"token_type_ids"
:
token_type_ids
,
"lengths"
:
input_lengths
,
"attention_mask"
:
input_mask
,
}
return
config
,
inputs_dict
return
config
,
inputs_dict
...
@@ -422,7 +427,7 @@ class FlaubertModelTest(ModelTesterMixin, unittest.TestCase):
...
@@ -422,7 +427,7 @@ class FlaubertModelTest(ModelTesterMixin, unittest.TestCase):
with
tempfile
.
TemporaryDirectory
()
as
tmp
:
with
tempfile
.
TemporaryDirectory
()
as
tmp
:
torch
.
jit
.
save
(
traced_model
,
os
.
path
.
join
(
tmp
,
"traced_model.pt"
))
torch
.
jit
.
save
(
traced_model
,
os
.
path
.
join
(
tmp
,
"traced_model.pt"
))
loaded
=
torch
.
jit
.
load
(
os
.
path
.
join
(
tmp
,
"
bert
.pt"
),
map_location
=
torch_device
)
loaded
=
torch
.
jit
.
load
(
os
.
path
.
join
(
tmp
,
"
traced_model
.pt"
),
map_location
=
torch_device
)
loaded
(
inputs_dict
[
"input_ids"
].
to
(
torch_device
),
inputs_dict
[
"attention_mask"
].
to
(
torch_device
))
loaded
(
inputs_dict
[
"input_ids"
].
to
(
torch_device
),
inputs_dict
[
"attention_mask"
].
to
(
torch_device
))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment