Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
73a03812
Unverified
Commit
73a03812
authored
Sep 01, 2021
by
Lysandre Debut
Committed by
GitHub
Sep 01, 2021
Browse files
Torchscript test (#13350)
* Torchscript test * Remove print statement
parent
b9c6a976
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
26 additions
and
3 deletions
+26
-3
tests/test_modeling_bert.py
tests/test_modeling_bert.py
+26
-3
No files found.
tests/test_modeling_bert.py
View file @
73a03812
...
...
@@ -12,13 +12,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
tempfile
import
unittest
from
transformers
import
BertConfig
,
is_torch_available
from
transformers.models.auto
import
get_values
from
transformers.testing_utils
import
require_torch
,
slow
,
torch_device
from
transformers.testing_utils
import
require_torch
,
require_torch_gpu
,
slow
,
torch_device
from
.test_configuration_common
import
ConfigTester
from
.test_generation_utils
import
GenerationTesterMixin
...
...
@@ -556,6 +556,29 @@ class BertModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
model
=
BertModel
.
from_pretrained
(
model_name
)
self
.
assertIsNotNone
(
model
)
@
slow
@
require_torch_gpu
def
test_torchscript_device_change
(
self
):
config
,
inputs_dict
=
self
.
model_tester
.
prepare_config_and_inputs_for_common
()
for
model_class
in
self
.
all_model_classes
:
# BertForMultipleChoice behaves incorrectly in JIT environments.
if
model_class
==
BertForMultipleChoice
:
return
config
.
torchscript
=
True
model
=
model_class
(
config
=
config
)
inputs_dict
=
self
.
_prepare_for_class
(
inputs_dict
,
model_class
)
traced_model
=
torch
.
jit
.
trace
(
model
,
(
inputs_dict
[
"input_ids"
].
to
(
"cpu"
),
inputs_dict
[
"attention_mask"
].
to
(
"cpu"
))
)
with
tempfile
.
TemporaryDirectory
()
as
tmp
:
torch
.
jit
.
save
(
traced_model
,
os
.
path
.
join
(
tmp
,
"bert.pt"
))
loaded
=
torch
.
jit
.
load
(
os
.
path
.
join
(
tmp
,
"bert.pt"
),
map_location
=
torch_device
)
loaded
(
inputs_dict
[
"input_ids"
].
to
(
torch_device
),
inputs_dict
[
"attention_mask"
].
to
(
torch_device
))
@
require_torch
class
BertModelIntegrationTest
(
unittest
.
TestCase
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment