Unverified Commit f53041a7 authored by Hz, Ji, committed by GitHub

device agnostic pipelines testing (#27129)

* device agnostic pipelines testing

* pass torch_device
parent 08fadc80
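
The change applies one pattern across the pipeline test files below: CUDA-only decorators (`require_torch_gpu`) become `require_torch_accelerator`, hard-coded `device=0` / `"cuda:0"` arguments become the `torch_device` string, and `torch.cuda.empty_cache()` calls become `backend_empty_cache(torch_device)`, all coming from `transformers.testing_utils`. A minimal sketch of a test written in the new style, modeled on the fp16 text-generation case in the diff (the class name is illustrative):

```python
# Sketch of the device-agnostic test pattern used throughout this commit.
# `torch_device` resolves to "cuda", "npu", "xpu", ... or "cpu" depending on the
# environment; `require_torch_accelerator` skips the test when no accelerator is present.
import unittest

import torch

from transformers import pipeline
from transformers.testing_utils import require_torch_accelerator, torch_device


class ExamplePipelineFp16Test(unittest.TestCase):
    @require_torch_accelerator
    def test_fp16_pipeline(self):
        pipe = pipeline(
            "text-generation",
            model="hf-internal-testing/tiny-random-bloom",
            device=torch_device,  # backend name instead of a hard-coded 0 / "cuda:0"
            torch_dtype=torch.float16,
        )
        pipe("This is a test")
```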
@@ -39,9 +39,10 @@ from transformers.testing_utils import (
     require_pyctcdecode,
     require_tf,
     require_torch,
-    require_torch_gpu,
+    require_torch_accelerator,
     require_torchaudio,
     slow,
+    torch_device,
 )

 from .test_pipelines_common import ANY
@@ -166,13 +167,11 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
         _ = speech_recognizer(waveform, return_timestamps="char")

     @slow
-    @require_torch
+    @require_torch_accelerator
     def test_whisper_fp16(self):
-        if not torch.cuda.is_available():
-            self.skipTest("Cuda is necessary for this test")
         speech_recognizer = pipeline(
             model="openai/whisper-base",
-            device=0,
+            device=torch_device,
             torch_dtype=torch.float16,
         )
         waveform = np.tile(np.arange(1000, dtype=np.float32), 34)
@@ -904,12 +903,12 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
         self.assertEqual(output, {"text": "a man said to the universe sir i exist"})

     @slow
-    @require_torch_gpu
+    @require_torch_accelerator
     def test_wav2vec2_conformer_float16(self):
         speech_recognizer = pipeline(
             task="automatic-speech-recognition",
             model="facebook/wav2vec2-conformer-rope-large-960h-ft",
-            device="cuda:0",
+            device=torch_device,
             torch_dtype=torch.float16,
             framework="pt",
         )
@@ -1304,14 +1303,14 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
         self.assertEqual(output, {"text": "XB"})

     @slow
-    @require_torch_gpu
+    @require_torch_accelerator
     def test_slow_unfinished_sequence(self):
         from transformers import GenerationConfig

         pipe = pipeline(
             "automatic-speech-recognition",
             model="vasista22/whisper-hindi-large-v2",
-            device="cuda:0",
+            device=torch_device,
         )
         # Original model wasn't trained with timestamps and has incorrect generation config
         pipe.model.generation_config = GenerationConfig.from_pretrained("openai/whisper-large-v2")
...
@@ -40,15 +40,17 @@ from transformers.testing_utils import (
     USER,
     CaptureLogger,
     RequestCounter,
+    backend_empty_cache,
     is_pipeline_test,
     is_staging_test,
     nested_simplify,
     require_tensorflow_probability,
     require_tf,
     require_torch,
-    require_torch_gpu,
+    require_torch_accelerator,
     require_torch_or_tf,
     slow,
+    torch_device,
 )
 from transformers.utils import direct_transformers_import, is_tf_available, is_torch_available
 from transformers.utils import logging as transformers_logging
@@ -511,7 +513,7 @@ class PipelineUtilsTest(unittest.TestCase):
         # clean-up as much as possible GPU memory occupied by PyTorch
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     @slow
     @require_tf
@@ -541,20 +543,20 @@ class PipelineUtilsTest(unittest.TestCase):
         # clean-up as much as possible GPU memory occupied by PyTorch
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     @slow
     @require_torch
-    @require_torch_gpu
-    def test_pipeline_cuda(self):
-        pipe = pipeline("text-generation", device="cuda")
+    @require_torch_accelerator
+    def test_pipeline_accelerator(self):
+        pipe = pipeline("text-generation", device=torch_device)
         _ = pipe("Hello")

     @slow
     @require_torch
-    @require_torch_gpu
-    def test_pipeline_cuda_indexed(self):
-        pipe = pipeline("text-generation", device="cuda:0")
+    @require_torch_accelerator
+    def test_pipeline_accelerator_indexed(self):
+        pipe = pipeline("text-generation", device=torch_device)
         _ = pipe("Hello")

     @slow
...
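
In the tearDown hooks above, `torch.cuda.empty_cache()` becomes `backend_empty_cache(torch_device)`, so the allocator cache of whichever accelerator actually ran the test is cleared. A hedged sketch of what such a dispatch helper could look like (the real helper ships in `transformers.testing_utils` and may differ in detail):

```python
# Illustrative dispatch-by-device-string helper; not the actual transformers implementation.
import torch


def empty_cache_for(device: str) -> None:
    """Free cached allocator memory for whichever accelerator backs `device`."""
    if device.startswith("cuda"):
        torch.cuda.empty_cache()
    elif device.startswith("xpu") and hasattr(torch, "xpu"):
        torch.xpu.empty_cache()
    elif device.startswith("npu") and hasattr(torch, "npu"):
        torch.npu.empty_cache()
    # On "cpu" (or an unknown backend) there is no allocator cache to clear.
```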
@@ -31,6 +31,7 @@ from transformers import (
     pipeline,
 )
 from transformers.testing_utils import (
+    backend_empty_cache,
     is_pipeline_test,
     is_torch_available,
     require_tf,
@@ -42,9 +43,6 @@ from transformers.testing_utils import (
 from .test_pipelines_common import ANY

-DEFAULT_DEVICE_NUM = -1 if torch_device == "cpu" else 0
-

 @is_pipeline_test
 class ConversationalPipelineTests(unittest.TestCase):
     def tearDown(self):
@@ -52,9 +50,7 @@ class ConversationalPipelineTests(unittest.TestCase):
         # clean-up as much as possible GPU memory occupied by PyTorch
         gc.collect()
         if is_torch_available():
-            import torch
-
-            torch.cuda.empty_cache()
+            backend_empty_cache(torch_device)

     model_mapping = dict(
         list(MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.items())
@@ -136,7 +132,7 @@ class ConversationalPipelineTests(unittest.TestCase):
     @slow
     def test_integration_torch_conversation(self):
         # When
-        conversation_agent = pipeline(task="conversational", device=DEFAULT_DEVICE_NUM)
+        conversation_agent = pipeline(task="conversational", device=torch_device)
         conversation_1 = Conversation("Going to the movies tonight - any suggestions?")
         conversation_2 = Conversation("What's the last book you have read?")
         # Then
@@ -168,7 +164,7 @@ class ConversationalPipelineTests(unittest.TestCase):
     @slow
     def test_integration_torch_conversation_truncated_history(self):
         # When
-        conversation_agent = pipeline(task="conversational", min_length_for_response=24, device=DEFAULT_DEVICE_NUM)
+        conversation_agent = pipeline(task="conversational", min_length_for_response=24, device=torch_device)
         conversation_1 = Conversation("Going to the movies tonight - any suggestions?")
         # Then
         self.assertEqual(len(conversation_1.past_user_inputs), 1)
@@ -374,7 +370,7 @@ These are just a few of the many attractions that Paris has to offer. With so mu
         # When
         tokenizer = AutoTokenizer.from_pretrained("facebook/blenderbot_small-90M")
         model = AutoModelForSeq2SeqLM.from_pretrained("facebook/blenderbot_small-90M")
-        conversation_agent = ConversationalPipeline(model=model, tokenizer=tokenizer, device=DEFAULT_DEVICE_NUM)
+        conversation_agent = ConversationalPipeline(model=model, tokenizer=tokenizer, device=torch_device)
         conversation_1 = Conversation("My name is Sarah and I live in London")
         conversation_2 = Conversation("Going to the movies tonight, What movie would you recommend? ")
...
@@ -18,13 +18,15 @@ import unittest
 from transformers import MODEL_FOR_MASKED_LM_MAPPING, TF_MODEL_FOR_MASKED_LM_MAPPING, FillMaskPipeline, pipeline
 from transformers.pipelines import PipelineException
 from transformers.testing_utils import (
+    backend_empty_cache,
     is_pipeline_test,
     is_torch_available,
     nested_simplify,
     require_tf,
     require_torch,
-    require_torch_gpu,
+    require_torch_accelerator,
     slow,
+    torch_device,
 )

 from .test_pipelines_common import ANY
@@ -40,9 +42,7 @@ class FillMaskPipelineTests(unittest.TestCase):
         # clean-up as much as possible GPU memory occupied by PyTorch
         gc.collect()
         if is_torch_available():
-            import torch
-
-            torch.cuda.empty_cache()
+            backend_empty_cache(torch_device)

     @require_tf
     def test_small_model_tf(self):
@@ -148,9 +148,14 @@ class FillMaskPipelineTests(unittest.TestCase):
             ],
         )

-    @require_torch_gpu
+    @require_torch_accelerator
     def test_fp16_casting(self):
-        pipe = pipeline("fill-mask", model="hf-internal-testing/tiny-random-distilbert", device=0, framework="pt")
+        pipe = pipeline(
+            "fill-mask",
+            model="hf-internal-testing/tiny-random-distilbert",
+            device=torch_device,
+            framework="pt",
+        )

         # convert model to fp16
         pipe.model.half()
...
@@ -27,9 +27,6 @@ from transformers.tokenization_utils import TruncationStrategy
 from .test_pipelines_common import ANY

-DEFAULT_DEVICE_NUM = -1 if torch_device == "cpu" else 0
-

 @is_pipeline_test
 class SummarizationPipelineTests(unittest.TestCase):
     model_mapping = MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
@@ -106,7 +103,7 @@ class SummarizationPipelineTests(unittest.TestCase):
     @require_torch
     @slow
     def test_integration_torch_summarization(self):
-        summarizer = pipeline(task="summarization", device=DEFAULT_DEVICE_NUM)
+        summarizer = pipeline(task="summarization", device=torch_device)
         cnn_article = (
             " (CNN)The Palestinian Authority officially became the 123rd member of the International Criminal Court on"
             " Wednesday, a step that gives the court jurisdiction over alleged crimes in Palestinian territories. The"
...
@@ -20,7 +20,7 @@ from transformers import (
     TextClassificationPipeline,
     pipeline,
 )
-from transformers.testing_utils import is_pipeline_test, nested_simplify, require_tf, require_torch, slow
+from transformers.testing_utils import is_pipeline_test, nested_simplify, require_tf, require_torch, slow, torch_device

 from .test_pipelines_common import ANY
@@ -96,13 +96,11 @@ class TextClassificationPipelineTests(unittest.TestCase):
     @require_torch
     def test_accepts_torch_device(self):
-        import torch
-
         text_classifier = pipeline(
             task="text-classification",
             model="hf-internal-testing/tiny-random-distilbert",
             framework="pt",
-            device=torch.device("cpu"),
+            device=torch_device,
         )

         outputs = text_classifier("This is great !")
...
@@ -27,8 +27,10 @@ from transformers.testing_utils import (
     require_accelerate,
     require_tf,
     require_torch,
+    require_torch_accelerator,
     require_torch_gpu,
     require_torch_or_tf,
+    torch_device,
 )

 from .test_pipelines_common import ANY
@@ -319,16 +321,20 @@ class TextGenerationPipelineTests(unittest.TestCase):
     )

     @require_torch
-    @require_torch_gpu
+    @require_torch_accelerator
     def test_small_model_fp16(self):
         import torch

-        pipe = pipeline(model="hf-internal-testing/tiny-random-bloom", device=0, torch_dtype=torch.float16)
+        pipe = pipeline(
+            model="hf-internal-testing/tiny-random-bloom",
+            device=torch_device,
+            torch_dtype=torch.float16,
+        )
         pipe("This is a test")

     @require_torch
     @require_accelerate
-    @require_torch_gpu
+    @require_torch_accelerator
     def test_pipeline_accelerate_top_p(self):
         import torch
...
@@ -25,9 +25,10 @@ from transformers import (
 from transformers.testing_utils import (
     is_pipeline_test,
     require_torch,
-    require_torch_gpu,
+    require_torch_accelerator,
     require_torch_or_tf,
     slow,
+    torch_device,
 )

 from .test_pipelines_common import ANY
@@ -115,9 +116,9 @@ class TextToAudioPipelineTests(unittest.TestCase):
         self.assertEqual([ANY(np.ndarray), ANY(np.ndarray)], audio)

     @slow
-    @require_torch_gpu
+    @require_torch_accelerator
     def test_conversion_additional_tensor(self):
-        speech_generator = pipeline(task="text-to-audio", model="suno/bark-small", framework="pt", device=0)
+        speech_generator = pipeline(task="text-to-audio", model="suno/bark-small", framework="pt", device=torch_device)
         processor = AutoProcessor.from_pretrained("suno/bark-small")

         forward_params = {
...
@@ -30,8 +30,9 @@ from transformers.testing_utils import (
     nested_simplify,
     require_tf,
     require_torch,
-    require_torch_gpu,
+    require_torch_accelerator,
     slow,
+    torch_device,
 )

 from .test_pipelines_common import ANY
@@ -391,13 +392,13 @@ class TokenClassificationPipelineTests(unittest.TestCase):
         ],
     )

-    @require_torch_gpu
+    @require_torch_accelerator
     @slow
-    def test_gpu(self):
+    def test_accelerator(self):
         sentence = "This is dummy sentence"
         ner = pipeline(
             "token-classification",
-            device=0,
+            device=torch_device,
             aggregation_strategy=AggregationStrategy.SIMPLE,
         )
...
@@ -22,9 +22,10 @@ from transformers.testing_utils import (
     nested_simplify,
     require_tf,
     require_torch,
-    require_torch_gpu,
+    require_torch_accelerator,
     require_vision,
     slow,
+    torch_device,
 )

 from .test_pipelines_common import ANY
@@ -91,7 +92,7 @@ class VisualQuestionAnsweringPipelineTests(unittest.TestCase):
     )

     @require_torch
-    @require_torch_gpu
+    @require_torch_accelerator
     def test_small_model_pt_blip2(self):
         vqa_pipeline = pipeline(
             "visual-question-answering", model="hf-internal-testing/tiny-random-Blip2ForConditionalGeneration"
@@ -112,9 +113,9 @@ class VisualQuestionAnsweringPipelineTests(unittest.TestCase):
             "visual-question-answering",
             model="hf-internal-testing/tiny-random-Blip2ForConditionalGeneration",
             model_kwargs={"torch_dtype": torch.float16},
-            device=0,
+            device=torch_device,
         )
-        self.assertEqual(vqa_pipeline.model.device, torch.device(0))
+        self.assertEqual(vqa_pipeline.model.device, torch.device("{}:0".format(torch_device)))
         self.assertEqual(vqa_pipeline.model.language_model.dtype, torch.float16)
         self.assertEqual(vqa_pipeline.model.vision_model.dtype, torch.float16)
@@ -148,15 +149,15 @@ class VisualQuestionAnsweringPipelineTests(unittest.TestCase):
     @slow
     @require_torch
-    @require_torch_gpu
+    @require_torch_accelerator
     def test_large_model_pt_blip2(self):
         vqa_pipeline = pipeline(
             "visual-question-answering",
             model="Salesforce/blip2-opt-2.7b",
             model_kwargs={"torch_dtype": torch.float16},
-            device=0,
+            device=torch_device,
         )
-        self.assertEqual(vqa_pipeline.model.device, torch.device(0))
+        self.assertEqual(vqa_pipeline.model.device, torch.device("{}:0".format(torch_device)))
         self.assertEqual(vqa_pipeline.model.language_model.dtype, torch.float16)
         image = "./tests/fixtures/tests_samples/COCO/000000039769.png"
...
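
In the last two hunks the expected device in the assertion is rebuilt from `torch_device` rather than a bare index, so the check holds on any accelerator. A rough illustration, assuming `torch_device == "cuda"` (another backend would substitute its own prefix):

```python
import torch

# With torch_device == "cuda", the expected device becomes "cuda:0".
expected = torch.device("{}:0".format("cuda"))
print(expected)  # cuda:0
```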