Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
18546378
Unverified
Commit
18546378
authored
Apr 10, 2024
by
Fanli Lin
Committed by
GitHub
Apr 10, 2024
Browse files
[tests] make 2 tests device-agnostic (#30008)
add torch device
parent
bb76f81e
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
3 additions
and
3 deletions
+3
-3
tests/models/blip_2/test_modeling_blip_2.py
tests/models/blip_2/test_modeling_blip_2.py
+2
-2
tests/test_modeling_utils.py
tests/test_modeling_utils.py
+1
-1
No files found.
tests/models/blip_2/test_modeling_blip_2.py
View file @
18546378
...
@@ -992,7 +992,7 @@ class Blip2ModelIntegrationTest(unittest.TestCase):
...
@@ -992,7 +992,7 @@ class Blip2ModelIntegrationTest(unittest.TestCase):
# prepare image
image = prepare_img()
- inputs = processor(images=image, return_tensors="pt").to(0, dtype=torch.float16)
+ inputs = processor(images=image, return_tensors="pt").to(f"{torch_device}:0", dtype=torch.float16)
predictions = model.generate(**inputs)
generated_text = processor.batch_decode(predictions, skip_special_tokens=True)[0].strip()
...
@@ -1003,7 +1003,7 @@ class Blip2ModelIntegrationTest(unittest.TestCase):
...
@@ -1003,7 +1003,7 @@ class Blip2ModelIntegrationTest(unittest.TestCase):
# image and context
prompt = "Question: which city is this? Answer:"
- inputs = processor(images=image, text=prompt, return_tensors="pt").to(0, dtype=torch.float16)
+ inputs = processor(images=image, text=prompt, return_tensors="pt").to(f"{torch_device}:0", dtype=torch.float16)
predictions = model.generate(**inputs)
generated_text = processor.batch_decode(predictions, skip_special_tokens=True)[0].strip()
...
...
tests/test_modeling_utils.py
View file @
18546378
...
@@ -776,7 +776,7 @@ class ModelUtilsTest(TestCasePlus):
...
@@ -776,7 +776,7 @@ class ModelUtilsTest(TestCasePlus):
tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
inputs = tokenizer("Hello, my name is", return_tensors="pt")
- output = model.generate(inputs["input_ids"].to(0))
+ output = model.generate(inputs["input_ids"].to(f"{torch_device}:0"))
text_output = tokenizer.decode(output[0].tolist())
self.assertEqual(text_output, "Hello, my name is John. I'm a writer, and I'm a writer. I'm")
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment