Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
f53041a7
Unverified
Commit
f53041a7
authored
Oct 31, 2023
by
Hz, Ji
Committed by
GitHub
Oct 31, 2023
Browse files
device agnostic pipelines testing (#27129)
* device agnostic pipelines testing * pass torch_device
parent
08fadc80
Changes
10
Show whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
64 additions
and
58 deletions
+64
-58
tests/pipelines/test_pipelines_automatic_speech_recognition.py
.../pipelines/test_pipelines_automatic_speech_recognition.py
+8
-9
tests/pipelines/test_pipelines_common.py
tests/pipelines/test_pipelines_common.py
+11
-9
tests/pipelines/test_pipelines_conversational.py
tests/pipelines/test_pipelines_conversational.py
+5
-9
tests/pipelines/test_pipelines_fill_mask.py
tests/pipelines/test_pipelines_fill_mask.py
+11
-6
tests/pipelines/test_pipelines_summarization.py
tests/pipelines/test_pipelines_summarization.py
+1
-4
tests/pipelines/test_pipelines_text_classification.py
tests/pipelines/test_pipelines_text_classification.py
+2
-4
tests/pipelines/test_pipelines_text_generation.py
tests/pipelines/test_pipelines_text_generation.py
+9
-3
tests/pipelines/test_pipelines_text_to_audio.py
tests/pipelines/test_pipelines_text_to_audio.py
+4
-3
tests/pipelines/test_pipelines_token_classification.py
tests/pipelines/test_pipelines_token_classification.py
+5
-4
tests/pipelines/test_pipelines_visual_question_answering.py
tests/pipelines/test_pipelines_visual_question_answering.py
+8
-7
No files found.
tests/pipelines/test_pipelines_automatic_speech_recognition.py
View file @
f53041a7
...
@@ -39,9 +39,10 @@ from transformers.testing_utils import (
...
@@ -39,9 +39,10 @@ from transformers.testing_utils import (
require_pyctcdecode
,
require_pyctcdecode
,
require_tf
,
require_tf
,
require_torch
,
require_torch
,
require_torch_
gpu
,
require_torch_
accelerator
,
require_torchaudio
,
require_torchaudio
,
slow
,
slow
,
torch_device
,
)
)
from
.test_pipelines_common
import
ANY
from
.test_pipelines_common
import
ANY
...
@@ -166,13 +167,11 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
...
@@ -166,13 +167,11 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
_
=
speech_recognizer
(
waveform
,
return_timestamps
=
"char"
)
_
=
speech_recognizer
(
waveform
,
return_timestamps
=
"char"
)
@
slow
@
slow
@
require_torch
@
require_torch
_accelerator
def
test_whisper_fp16
(
self
):
def
test_whisper_fp16
(
self
):
if
not
torch
.
cuda
.
is_available
():
self
.
skipTest
(
"Cuda is necessary for this test"
)
speech_recognizer
=
pipeline
(
speech_recognizer
=
pipeline
(
model
=
"openai/whisper-base"
,
model
=
"openai/whisper-base"
,
device
=
0
,
device
=
torch_device
,
torch_dtype
=
torch
.
float16
,
torch_dtype
=
torch
.
float16
,
)
)
waveform
=
np
.
tile
(
np
.
arange
(
1000
,
dtype
=
np
.
float32
),
34
)
waveform
=
np
.
tile
(
np
.
arange
(
1000
,
dtype
=
np
.
float32
),
34
)
...
@@ -904,12 +903,12 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
...
@@ -904,12 +903,12 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
self
.
assertEqual
(
output
,
{
"text"
:
"a man said to the universe sir i exist"
})
self
.
assertEqual
(
output
,
{
"text"
:
"a man said to the universe sir i exist"
})
@
slow
@
slow
@
require_torch_
gpu
@
require_torch_
accelerator
def
test_wav2vec2_conformer_float16
(
self
):
def
test_wav2vec2_conformer_float16
(
self
):
speech_recognizer
=
pipeline
(
speech_recognizer
=
pipeline
(
task
=
"automatic-speech-recognition"
,
task
=
"automatic-speech-recognition"
,
model
=
"facebook/wav2vec2-conformer-rope-large-960h-ft"
,
model
=
"facebook/wav2vec2-conformer-rope-large-960h-ft"
,
device
=
"cuda:0"
,
device
=
torch_device
,
torch_dtype
=
torch
.
float16
,
torch_dtype
=
torch
.
float16
,
framework
=
"pt"
,
framework
=
"pt"
,
)
)
...
@@ -1304,14 +1303,14 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
...
@@ -1304,14 +1303,14 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
self
.
assertEqual
(
output
,
{
"text"
:
"XB"
})
self
.
assertEqual
(
output
,
{
"text"
:
"XB"
})
@
slow
@
slow
@
require_torch_
gpu
@
require_torch_
accelerator
def
test_slow_unfinished_sequence
(
self
):
def
test_slow_unfinished_sequence
(
self
):
from
transformers
import
GenerationConfig
from
transformers
import
GenerationConfig
pipe
=
pipeline
(
pipe
=
pipeline
(
"automatic-speech-recognition"
,
"automatic-speech-recognition"
,
model
=
"vasista22/whisper-hindi-large-v2"
,
model
=
"vasista22/whisper-hindi-large-v2"
,
device
=
"cuda:0"
,
device
=
torch_device
,
)
)
# Original model wasn't trained with timestamps and has incorrect generation config
# Original model wasn't trained with timestamps and has incorrect generation config
pipe
.
model
.
generation_config
=
GenerationConfig
.
from_pretrained
(
"openai/whisper-large-v2"
)
pipe
.
model
.
generation_config
=
GenerationConfig
.
from_pretrained
(
"openai/whisper-large-v2"
)
...
...
tests/pipelines/test_pipelines_common.py
View file @
f53041a7
...
@@ -40,15 +40,17 @@ from transformers.testing_utils import (
...
@@ -40,15 +40,17 @@ from transformers.testing_utils import (
USER
,
USER
,
CaptureLogger
,
CaptureLogger
,
RequestCounter
,
RequestCounter
,
backend_empty_cache
,
is_pipeline_test
,
is_pipeline_test
,
is_staging_test
,
is_staging_test
,
nested_simplify
,
nested_simplify
,
require_tensorflow_probability
,
require_tensorflow_probability
,
require_tf
,
require_tf
,
require_torch
,
require_torch
,
require_torch_
gpu
,
require_torch_
accelerator
,
require_torch_or_tf
,
require_torch_or_tf
,
slow
,
slow
,
torch_device
,
)
)
from
transformers.utils
import
direct_transformers_import
,
is_tf_available
,
is_torch_available
from
transformers.utils
import
direct_transformers_import
,
is_tf_available
,
is_torch_available
from
transformers.utils
import
logging
as
transformers_logging
from
transformers.utils
import
logging
as
transformers_logging
...
@@ -511,7 +513,7 @@ class PipelineUtilsTest(unittest.TestCase):
...
@@ -511,7 +513,7 @@ class PipelineUtilsTest(unittest.TestCase):
# clean-up as much as possible GPU memory occupied by PyTorch
# clean-up as much as possible GPU memory occupied by PyTorch
gc
.
collect
()
gc
.
collect
()
torch
.
cuda
.
empty_cache
()
backend_
empty_cache
(
torch_device
)
@
slow
@
slow
@
require_tf
@
require_tf
...
@@ -541,20 +543,20 @@ class PipelineUtilsTest(unittest.TestCase):
...
@@ -541,20 +543,20 @@ class PipelineUtilsTest(unittest.TestCase):
# clean-up as much as possible GPU memory occupied by PyTorch
# clean-up as much as possible GPU memory occupied by PyTorch
gc
.
collect
()
gc
.
collect
()
torch
.
cuda
.
empty_cache
()
backend_
empty_cache
(
torch_device
)
@
slow
@
slow
@
require_torch
@
require_torch
@
require_torch_
gpu
@
require_torch_
accelerator
def
test_pipeline_
cuda
(
self
):
def
test_pipeline_
accelerator
(
self
):
pipe
=
pipeline
(
"text-generation"
,
device
=
"cuda"
)
pipe
=
pipeline
(
"text-generation"
,
device
=
torch_device
)
_
=
pipe
(
"Hello"
)
_
=
pipe
(
"Hello"
)
@
slow
@
slow
@
require_torch
@
require_torch
@
require_torch_
gpu
@
require_torch_
accelerator
def
test_pipeline_
cuda
_indexed
(
self
):
def
test_pipeline_
accelerator
_indexed
(
self
):
pipe
=
pipeline
(
"text-generation"
,
device
=
"cuda:0"
)
pipe
=
pipeline
(
"text-generation"
,
device
=
torch_device
)
_
=
pipe
(
"Hello"
)
_
=
pipe
(
"Hello"
)
@
slow
@
slow
...
...
tests/pipelines/test_pipelines_conversational.py
View file @
f53041a7
...
@@ -31,6 +31,7 @@ from transformers import (
...
@@ -31,6 +31,7 @@ from transformers import (
pipeline
,
pipeline
,
)
)
from
transformers.testing_utils
import
(
from
transformers.testing_utils
import
(
backend_empty_cache
,
is_pipeline_test
,
is_pipeline_test
,
is_torch_available
,
is_torch_available
,
require_tf
,
require_tf
,
...
@@ -42,9 +43,6 @@ from transformers.testing_utils import (
...
@@ -42,9 +43,6 @@ from transformers.testing_utils import (
from
.test_pipelines_common
import
ANY
from
.test_pipelines_common
import
ANY
DEFAULT_DEVICE_NUM
=
-
1
if
torch_device
==
"cpu"
else
0
@
is_pipeline_test
@
is_pipeline_test
class
ConversationalPipelineTests
(
unittest
.
TestCase
):
class
ConversationalPipelineTests
(
unittest
.
TestCase
):
def
tearDown
(
self
):
def
tearDown
(
self
):
...
@@ -52,9 +50,7 @@ class ConversationalPipelineTests(unittest.TestCase):
...
@@ -52,9 +50,7 @@ class ConversationalPipelineTests(unittest.TestCase):
# clean-up as much as possible GPU memory occupied by PyTorch
# clean-up as much as possible GPU memory occupied by PyTorch
gc
.
collect
()
gc
.
collect
()
if
is_torch_available
():
if
is_torch_available
():
import
torch
backend_empty_cache
(
torch_device
)
torch
.
cuda
.
empty_cache
()
model_mapping
=
dict
(
model_mapping
=
dict
(
list
(
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
.
items
())
list
(
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
.
items
())
...
@@ -136,7 +132,7 @@ class ConversationalPipelineTests(unittest.TestCase):
...
@@ -136,7 +132,7 @@ class ConversationalPipelineTests(unittest.TestCase):
@
slow
@
slow
def
test_integration_torch_conversation
(
self
):
def
test_integration_torch_conversation
(
self
):
# When
# When
conversation_agent
=
pipeline
(
task
=
"conversational"
,
device
=
DEFAULT_DEVICE_NUM
)
conversation_agent
=
pipeline
(
task
=
"conversational"
,
device
=
torch_device
)
conversation_1
=
Conversation
(
"Going to the movies tonight - any suggestions?"
)
conversation_1
=
Conversation
(
"Going to the movies tonight - any suggestions?"
)
conversation_2
=
Conversation
(
"What's the last book you have read?"
)
conversation_2
=
Conversation
(
"What's the last book you have read?"
)
# Then
# Then
...
@@ -168,7 +164,7 @@ class ConversationalPipelineTests(unittest.TestCase):
...
@@ -168,7 +164,7 @@ class ConversationalPipelineTests(unittest.TestCase):
@
slow
@
slow
def
test_integration_torch_conversation_truncated_history
(
self
):
def
test_integration_torch_conversation_truncated_history
(
self
):
# When
# When
conversation_agent
=
pipeline
(
task
=
"conversational"
,
min_length_for_response
=
24
,
device
=
DEFAULT_DEVICE_NUM
)
conversation_agent
=
pipeline
(
task
=
"conversational"
,
min_length_for_response
=
24
,
device
=
torch_device
)
conversation_1
=
Conversation
(
"Going to the movies tonight - any suggestions?"
)
conversation_1
=
Conversation
(
"Going to the movies tonight - any suggestions?"
)
# Then
# Then
self
.
assertEqual
(
len
(
conversation_1
.
past_user_inputs
),
1
)
self
.
assertEqual
(
len
(
conversation_1
.
past_user_inputs
),
1
)
...
@@ -374,7 +370,7 @@ These are just a few of the many attractions that Paris has to offer. With so mu
...
@@ -374,7 +370,7 @@ These are just a few of the many attractions that Paris has to offer. With so mu
# When
# When
tokenizer
=
AutoTokenizer
.
from_pretrained
(
"facebook/blenderbot_small-90M"
)
tokenizer
=
AutoTokenizer
.
from_pretrained
(
"facebook/blenderbot_small-90M"
)
model
=
AutoModelForSeq2SeqLM
.
from_pretrained
(
"facebook/blenderbot_small-90M"
)
model
=
AutoModelForSeq2SeqLM
.
from_pretrained
(
"facebook/blenderbot_small-90M"
)
conversation_agent
=
ConversationalPipeline
(
model
=
model
,
tokenizer
=
tokenizer
,
device
=
DEFAULT_DEVICE_NUM
)
conversation_agent
=
ConversationalPipeline
(
model
=
model
,
tokenizer
=
tokenizer
,
device
=
torch_device
)
conversation_1
=
Conversation
(
"My name is Sarah and I live in London"
)
conversation_1
=
Conversation
(
"My name is Sarah and I live in London"
)
conversation_2
=
Conversation
(
"Going to the movies tonight, What movie would you recommend? "
)
conversation_2
=
Conversation
(
"Going to the movies tonight, What movie would you recommend? "
)
...
...
tests/pipelines/test_pipelines_fill_mask.py
View file @
f53041a7
...
@@ -18,13 +18,15 @@ import unittest
...
@@ -18,13 +18,15 @@ import unittest
from
transformers
import
MODEL_FOR_MASKED_LM_MAPPING
,
TF_MODEL_FOR_MASKED_LM_MAPPING
,
FillMaskPipeline
,
pipeline
from
transformers
import
MODEL_FOR_MASKED_LM_MAPPING
,
TF_MODEL_FOR_MASKED_LM_MAPPING
,
FillMaskPipeline
,
pipeline
from
transformers.pipelines
import
PipelineException
from
transformers.pipelines
import
PipelineException
from
transformers.testing_utils
import
(
from
transformers.testing_utils
import
(
backend_empty_cache
,
is_pipeline_test
,
is_pipeline_test
,
is_torch_available
,
is_torch_available
,
nested_simplify
,
nested_simplify
,
require_tf
,
require_tf
,
require_torch
,
require_torch
,
require_torch_
gpu
,
require_torch_
accelerator
,
slow
,
slow
,
torch_device
,
)
)
from
.test_pipelines_common
import
ANY
from
.test_pipelines_common
import
ANY
...
@@ -40,9 +42,7 @@ class FillMaskPipelineTests(unittest.TestCase):
...
@@ -40,9 +42,7 @@ class FillMaskPipelineTests(unittest.TestCase):
# clean-up as much as possible GPU memory occupied by PyTorch
# clean-up as much as possible GPU memory occupied by PyTorch
gc
.
collect
()
gc
.
collect
()
if
is_torch_available
():
if
is_torch_available
():
import
torch
backend_empty_cache
(
torch_device
)
torch
.
cuda
.
empty_cache
()
@
require_tf
@
require_tf
def
test_small_model_tf
(
self
):
def
test_small_model_tf
(
self
):
...
@@ -148,9 +148,14 @@ class FillMaskPipelineTests(unittest.TestCase):
...
@@ -148,9 +148,14 @@ class FillMaskPipelineTests(unittest.TestCase):
],
],
)
)
@
require_torch_
gpu
@
require_torch_
accelerator
def
test_fp16_casting
(
self
):
def
test_fp16_casting
(
self
):
pipe
=
pipeline
(
"fill-mask"
,
model
=
"hf-internal-testing/tiny-random-distilbert"
,
device
=
0
,
framework
=
"pt"
)
pipe
=
pipeline
(
"fill-mask"
,
model
=
"hf-internal-testing/tiny-random-distilbert"
,
device
=
torch_device
,
framework
=
"pt"
,
)
# convert model to fp16
# convert model to fp16
pipe
.
model
.
half
()
pipe
.
model
.
half
()
...
...
tests/pipelines/test_pipelines_summarization.py
View file @
f53041a7
...
@@ -27,9 +27,6 @@ from transformers.tokenization_utils import TruncationStrategy
...
@@ -27,9 +27,6 @@ from transformers.tokenization_utils import TruncationStrategy
from
.test_pipelines_common
import
ANY
from
.test_pipelines_common
import
ANY
DEFAULT_DEVICE_NUM
=
-
1
if
torch_device
==
"cpu"
else
0
@
is_pipeline_test
@
is_pipeline_test
class
SummarizationPipelineTests
(
unittest
.
TestCase
):
class
SummarizationPipelineTests
(
unittest
.
TestCase
):
model_mapping
=
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
model_mapping
=
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
...
@@ -106,7 +103,7 @@ class SummarizationPipelineTests(unittest.TestCase):
...
@@ -106,7 +103,7 @@ class SummarizationPipelineTests(unittest.TestCase):
@
require_torch
@
require_torch
@
slow
@
slow
def
test_integration_torch_summarization
(
self
):
def
test_integration_torch_summarization
(
self
):
summarizer
=
pipeline
(
task
=
"summarization"
,
device
=
DEFAULT_DEVICE_NUM
)
summarizer
=
pipeline
(
task
=
"summarization"
,
device
=
torch_device
)
cnn_article
=
(
cnn_article
=
(
" (CNN)The Palestinian Authority officially became the 123rd member of the International Criminal Court on"
" (CNN)The Palestinian Authority officially became the 123rd member of the International Criminal Court on"
" Wednesday, a step that gives the court jurisdiction over alleged crimes in Palestinian territories. The"
" Wednesday, a step that gives the court jurisdiction over alleged crimes in Palestinian territories. The"
...
...
tests/pipelines/test_pipelines_text_classification.py
View file @
f53041a7
...
@@ -20,7 +20,7 @@ from transformers import (
...
@@ -20,7 +20,7 @@ from transformers import (
TextClassificationPipeline
,
TextClassificationPipeline
,
pipeline
,
pipeline
,
)
)
from
transformers.testing_utils
import
is_pipeline_test
,
nested_simplify
,
require_tf
,
require_torch
,
slow
from
transformers.testing_utils
import
is_pipeline_test
,
nested_simplify
,
require_tf
,
require_torch
,
slow
,
torch_device
from
.test_pipelines_common
import
ANY
from
.test_pipelines_common
import
ANY
...
@@ -96,13 +96,11 @@ class TextClassificationPipelineTests(unittest.TestCase):
...
@@ -96,13 +96,11 @@ class TextClassificationPipelineTests(unittest.TestCase):
@
require_torch
@
require_torch
def
test_accepts_torch_device
(
self
):
def
test_accepts_torch_device
(
self
):
import
torch
text_classifier
=
pipeline
(
text_classifier
=
pipeline
(
task
=
"text-classification"
,
task
=
"text-classification"
,
model
=
"hf-internal-testing/tiny-random-distilbert"
,
model
=
"hf-internal-testing/tiny-random-distilbert"
,
framework
=
"pt"
,
framework
=
"pt"
,
device
=
torch
.
device
(
"cpu"
)
,
device
=
torch
_
device
,
)
)
outputs
=
text_classifier
(
"This is great !"
)
outputs
=
text_classifier
(
"This is great !"
)
...
...
tests/pipelines/test_pipelines_text_generation.py
View file @
f53041a7
...
@@ -27,8 +27,10 @@ from transformers.testing_utils import (
...
@@ -27,8 +27,10 @@ from transformers.testing_utils import (
require_accelerate
,
require_accelerate
,
require_tf
,
require_tf
,
require_torch
,
require_torch
,
require_torch_accelerator
,
require_torch_gpu
,
require_torch_gpu
,
require_torch_or_tf
,
require_torch_or_tf
,
torch_device
,
)
)
from
.test_pipelines_common
import
ANY
from
.test_pipelines_common
import
ANY
...
@@ -319,16 +321,20 @@ class TextGenerationPipelineTests(unittest.TestCase):
...
@@ -319,16 +321,20 @@ class TextGenerationPipelineTests(unittest.TestCase):
)
)
@
require_torch
@
require_torch
@
require_torch_
gpu
@
require_torch_
accelerator
def
test_small_model_fp16
(
self
):
def
test_small_model_fp16
(
self
):
import
torch
import
torch
pipe
=
pipeline
(
model
=
"hf-internal-testing/tiny-random-bloom"
,
device
=
0
,
torch_dtype
=
torch
.
float16
)
pipe
=
pipeline
(
model
=
"hf-internal-testing/tiny-random-bloom"
,
device
=
torch_device
,
torch_dtype
=
torch
.
float16
,
)
pipe
(
"This is a test"
)
pipe
(
"This is a test"
)
@
require_torch
@
require_torch
@
require_accelerate
@
require_accelerate
@
require_torch_
gpu
@
require_torch_
accelerator
def
test_pipeline_accelerate_top_p
(
self
):
def
test_pipeline_accelerate_top_p
(
self
):
import
torch
import
torch
...
...
tests/pipelines/test_pipelines_text_to_audio.py
View file @
f53041a7
...
@@ -25,9 +25,10 @@ from transformers import (
...
@@ -25,9 +25,10 @@ from transformers import (
from
transformers.testing_utils
import
(
from
transformers.testing_utils
import
(
is_pipeline_test
,
is_pipeline_test
,
require_torch
,
require_torch
,
require_torch_
gpu
,
require_torch_
accelerator
,
require_torch_or_tf
,
require_torch_or_tf
,
slow
,
slow
,
torch_device
,
)
)
from
.test_pipelines_common
import
ANY
from
.test_pipelines_common
import
ANY
...
@@ -115,9 +116,9 @@ class TextToAudioPipelineTests(unittest.TestCase):
...
@@ -115,9 +116,9 @@ class TextToAudioPipelineTests(unittest.TestCase):
self
.
assertEqual
([
ANY
(
np
.
ndarray
),
ANY
(
np
.
ndarray
)],
audio
)
self
.
assertEqual
([
ANY
(
np
.
ndarray
),
ANY
(
np
.
ndarray
)],
audio
)
@
slow
@
slow
@
require_torch_
gpu
@
require_torch_
accelerator
def
test_conversion_additional_tensor
(
self
):
def
test_conversion_additional_tensor
(
self
):
speech_generator
=
pipeline
(
task
=
"text-to-audio"
,
model
=
"suno/bark-small"
,
framework
=
"pt"
,
device
=
0
)
speech_generator
=
pipeline
(
task
=
"text-to-audio"
,
model
=
"suno/bark-small"
,
framework
=
"pt"
,
device
=
torch_device
)
processor
=
AutoProcessor
.
from_pretrained
(
"suno/bark-small"
)
processor
=
AutoProcessor
.
from_pretrained
(
"suno/bark-small"
)
forward_params
=
{
forward_params
=
{
...
...
tests/pipelines/test_pipelines_token_classification.py
View file @
f53041a7
...
@@ -30,8 +30,9 @@ from transformers.testing_utils import (
...
@@ -30,8 +30,9 @@ from transformers.testing_utils import (
nested_simplify
,
nested_simplify
,
require_tf
,
require_tf
,
require_torch
,
require_torch
,
require_torch_
gpu
,
require_torch_
accelerator
,
slow
,
slow
,
torch_device
,
)
)
from
.test_pipelines_common
import
ANY
from
.test_pipelines_common
import
ANY
...
@@ -391,13 +392,13 @@ class TokenClassificationPipelineTests(unittest.TestCase):
...
@@ -391,13 +392,13 @@ class TokenClassificationPipelineTests(unittest.TestCase):
],
],
)
)
@
require_torch_
gpu
@
require_torch_
accelerator
@
slow
@
slow
def
test_
gpu
(
self
):
def
test_
accelerator
(
self
):
sentence
=
"This is dummy sentence"
sentence
=
"This is dummy sentence"
ner
=
pipeline
(
ner
=
pipeline
(
"token-classification"
,
"token-classification"
,
device
=
0
,
device
=
torch_device
,
aggregation_strategy
=
AggregationStrategy
.
SIMPLE
,
aggregation_strategy
=
AggregationStrategy
.
SIMPLE
,
)
)
...
...
tests/pipelines/test_pipelines_visual_question_answering.py
View file @
f53041a7
...
@@ -22,9 +22,10 @@ from transformers.testing_utils import (
...
@@ -22,9 +22,10 @@ from transformers.testing_utils import (
nested_simplify
,
nested_simplify
,
require_tf
,
require_tf
,
require_torch
,
require_torch
,
require_torch_
gpu
,
require_torch_
accelerator
,
require_vision
,
require_vision
,
slow
,
slow
,
torch_device
,
)
)
from
.test_pipelines_common
import
ANY
from
.test_pipelines_common
import
ANY
...
@@ -91,7 +92,7 @@ class VisualQuestionAnsweringPipelineTests(unittest.TestCase):
...
@@ -91,7 +92,7 @@ class VisualQuestionAnsweringPipelineTests(unittest.TestCase):
)
)
@
require_torch
@
require_torch
@
require_torch_
gpu
@
require_torch_
accelerator
def
test_small_model_pt_blip2
(
self
):
def
test_small_model_pt_blip2
(
self
):
vqa_pipeline
=
pipeline
(
vqa_pipeline
=
pipeline
(
"visual-question-answering"
,
model
=
"hf-internal-testing/tiny-random-Blip2ForConditionalGeneration"
"visual-question-answering"
,
model
=
"hf-internal-testing/tiny-random-Blip2ForConditionalGeneration"
...
@@ -112,9 +113,9 @@ class VisualQuestionAnsweringPipelineTests(unittest.TestCase):
...
@@ -112,9 +113,9 @@ class VisualQuestionAnsweringPipelineTests(unittest.TestCase):
"visual-question-answering"
,
"visual-question-answering"
,
model
=
"hf-internal-testing/tiny-random-Blip2ForConditionalGeneration"
,
model
=
"hf-internal-testing/tiny-random-Blip2ForConditionalGeneration"
,
model_kwargs
=
{
"torch_dtype"
:
torch
.
float16
},
model_kwargs
=
{
"torch_dtype"
:
torch
.
float16
},
device
=
0
,
device
=
torch_device
,
)
)
self
.
assertEqual
(
vqa_pipeline
.
model
.
device
,
torch
.
device
(
0
))
self
.
assertEqual
(
vqa_pipeline
.
model
.
device
,
torch
.
device
(
"{}:0"
.
format
(
torch_device
)
))
self
.
assertEqual
(
vqa_pipeline
.
model
.
language_model
.
dtype
,
torch
.
float16
)
self
.
assertEqual
(
vqa_pipeline
.
model
.
language_model
.
dtype
,
torch
.
float16
)
self
.
assertEqual
(
vqa_pipeline
.
model
.
vision_model
.
dtype
,
torch
.
float16
)
self
.
assertEqual
(
vqa_pipeline
.
model
.
vision_model
.
dtype
,
torch
.
float16
)
...
@@ -148,15 +149,15 @@ class VisualQuestionAnsweringPipelineTests(unittest.TestCase):
...
@@ -148,15 +149,15 @@ class VisualQuestionAnsweringPipelineTests(unittest.TestCase):
@
slow
@
slow
@
require_torch
@
require_torch
@
require_torch_
gpu
@
require_torch_
accelerator
def
test_large_model_pt_blip2
(
self
):
def
test_large_model_pt_blip2
(
self
):
vqa_pipeline
=
pipeline
(
vqa_pipeline
=
pipeline
(
"visual-question-answering"
,
"visual-question-answering"
,
model
=
"Salesforce/blip2-opt-2.7b"
,
model
=
"Salesforce/blip2-opt-2.7b"
,
model_kwargs
=
{
"torch_dtype"
:
torch
.
float16
},
model_kwargs
=
{
"torch_dtype"
:
torch
.
float16
},
device
=
0
,
device
=
torch_device
,
)
)
self
.
assertEqual
(
vqa_pipeline
.
model
.
device
,
torch
.
device
(
0
))
self
.
assertEqual
(
vqa_pipeline
.
model
.
device
,
torch
.
device
(
"{}:0"
.
format
(
torch_device
)
))
self
.
assertEqual
(
vqa_pipeline
.
model
.
language_model
.
dtype
,
torch
.
float16
)
self
.
assertEqual
(
vqa_pipeline
.
model
.
language_model
.
dtype
,
torch
.
float16
)
image
=
"./tests/fixtures/tests_samples/COCO/000000039769.png"
image
=
"./tests/fixtures/tests_samples/COCO/000000039769.png"
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment