"docs/source/vscode:/vscode.git/clone" did not exist on "011b15c1c75a575fcaee5a50de02ff316881816a"
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import collections
import contextlib
import inspect
import logging
import os
import re
import shlex
import shutil
import subprocess
import sys
import tempfile
import unittest
from collections.abc import Mapping
from distutils.util import strtobool
from io import StringIO
from pathlib import Path
from typing import Iterator, List, Union
from unittest import mock

from transformers import logging as transformers_logging

from .deepspeed import is_deepspeed_available
from .integrations import (
    is_fairscale_available,
    is_optuna_available,
    is_ray_available,
    is_sigopt_available,
    is_wandb_available,
)
from .utils import (
    is_accelerate_available,
    is_apex_available,
    is_bitsandbytes_available,
    is_detectron2_available,
    is_faiss_available,
    is_flax_available,
    is_ftfy_available,
    is_ipex_available,
    is_librosa_available,
    is_onnx_available,
    is_pandas_available,
    is_phonemizer_available,
    is_pyctcdecode_available,
    is_pytesseract_available,
    is_pytorch_quantization_available,
    is_rjieba_available,
    is_scatter_available,
    is_scipy_available,
    is_sentencepiece_available,
    is_soundfile_availble,
    is_spacy_available,
    is_tensorflow_probability_available,
    is_tensorflow_text_available,
    is_tf2onnx_available,
    is_tf_available,
    is_timm_available,
    is_tokenizers_available,
    is_torch_available,
    is_torch_bf16_cpu_available,
    is_torch_bf16_gpu_available,
    is_torch_tensorrt_fx_available,
    is_torch_tf32_available,
    is_torch_tpu_available,
    is_torchaudio_available,
    is_torchdynamo_available,
    is_vision_available,
)


SMALL_MODEL_IDENTIFIER = "julien-c/bert-xsmall-dummy"
DUMMY_UNKNOWN_IDENTIFIER = "julien-c/dummy-unknown"
DUMMY_DIFF_TOKENIZER_IDENTIFIER = "julien-c/dummy-diff-tokenizer"
# Used to test Auto{Config, Model, Tokenizer} model_type detection.

# Used to test the hub
USER = "__DUMMY_TRANSFORMERS_USER__"
ENDPOINT_STAGING = "https://hub-ci.huggingface.co"

# Not critical, only usable on the sandboxed CI instance.
TOKEN = "hf_94wBhPGp6KrrTH3KDchhKpRxZwd6dmHWLL"


def parse_flag_from_env(key, default=False):
    try:
        value = os.environ[key]
    except KeyError:
        # KEY isn't set, default to `default`.
        _value = default
    else:
        # KEY is set, convert it to True or False.
        try:
            _value = strtobool(value)
        except ValueError:
            # More values are supported, but let's keep the message simple.
            raise ValueError(f"If set, {key} must be yes or no.")
    return _value
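# Note (illustrative): `strtobool` accepts the usual spellings, so e.g. RUN_SLOW=1, RUN_SLOW=yes or
# RUN_SLOW=true all make parse_flag_from_env("RUN_SLOW") truthy, while an unset variable falls back
# to `default`.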


def parse_int_from_env(key, default=None):
    try:
        value = os.environ[key]
    except KeyError:
        _value = default
    else:
        try:
            _value = int(value)
        except ValueError:
            raise ValueError(f"If set, {key} must be an int.")
    return _value


_run_slow_tests = parse_flag_from_env("RUN_SLOW", default=False)
_run_pt_tf_cross_tests = parse_flag_from_env("RUN_PT_TF_CROSS_TESTS", default=False)
_run_pt_flax_cross_tests = parse_flag_from_env("RUN_PT_FLAX_CROSS_TESTS", default=False)
_run_custom_tokenizers = parse_flag_from_env("RUN_CUSTOM_TOKENIZERS", default=False)
_run_staging = parse_flag_from_env("HUGGINGFACE_CO_STAGING", default=False)
_run_pipeline_tests = parse_flag_from_env("RUN_PIPELINE_TESTS", default=False)
_run_git_lfs_tests = parse_flag_from_env("RUN_GIT_LFS_TESTS", default=False)
_tf_gpu_memory_limit = parse_int_from_env("TF_GPU_MEMORY_LIMIT", default=None)
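# Illustrative usage (hypothetical CI invocation): the gated test groups are enabled purely through
# environment variables, e.g.
#
#   RUN_SLOW=1 RUN_PIPELINE_TESTS=1 python -m pytest tests/
#
# which flips the corresponding flags above to a truthy value for that run.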


def is_pt_tf_cross_test(test_case):
    """
    Decorator marking a test as a test that controls interactions between PyTorch and TensorFlow.

    PT+TF tests are skipped by default and can only be run by setting the RUN_PT_TF_CROSS_TESTS environment variable
    to a truthy value and selecting the is_pt_tf_cross_test pytest mark.

    """
    if not _run_pt_tf_cross_tests or not is_torch_available() or not is_tf_available():
        return unittest.skip("test is PT+TF test")(test_case)
    else:
        try:
            import pytest  # We don't need a hard dependency on pytest in the main library
        except ImportError:
            return test_case
        else:
            return pytest.mark.is_pt_tf_cross_test()(test_case)


def is_pt_flax_cross_test(test_case):
    """
    Decorator marking a test as a test that controls interactions between PyTorch and Flax.

    PT+FLAX tests are skipped by default and can only be run by setting the RUN_PT_FLAX_CROSS_TESTS environment
    variable to a truthy value and selecting the is_pt_flax_cross_test pytest mark.

    """
    if not _run_pt_flax_cross_tests or not is_torch_available() or not is_flax_available():
        return unittest.skip("test is PT+FLAX test")(test_case)
    else:
        try:
            import pytest  # We don't need a hard dependency on pytest in the main library
        except ImportError:
            return test_case
        else:
            return pytest.mark.is_pt_flax_cross_test()(test_case)


def is_pipeline_test(test_case):
    """
    Decorator marking a test as a pipeline test.

    Pipeline tests are skipped by default and can only be run by setting the RUN_PIPELINE_TESTS environment variable
    to a truthy value and selecting the is_pipeline_test pytest mark.

    """
    if not _run_pipeline_tests:
        return unittest.skip("test is pipeline test")(test_case)
    else:
        try:
            import pytest  # We don't need a hard dependency on pytest in the main library
        except ImportError:
            return test_case
        else:
            return pytest.mark.is_pipeline_test()(test_case)


def is_staging_test(test_case):
    """
    Decorator marking a test as a staging test.

    Those tests will run using the staging environment of huggingface.co instead of the real model hub.
    """
    if not _run_staging:
        return unittest.skip("test is staging test")(test_case)
    else:
        try:
            import pytest  # We don't need a hard dependency on pytest in the main library
        except ImportError:
            return test_case
        else:
            return pytest.mark.is_staging_test()(test_case)


def slow(test_case):
    """
    Decorator marking a test as slow.

    Slow tests are skipped by default. Set the RUN_SLOW environment variable to a truthy value to run them.

    """
    return unittest.skipUnless(_run_slow_tests, "test is slow")(test_case)
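# Illustrative usage (hypothetical test method): a test decorated with `slow` is collected but skipped
# unless RUN_SLOW is set to a truthy value, e.g.
#
#   @slow
#   def test_inference_on_large_checkpoint(self):
#       ...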


def tooslow(test_case):
    """
    Decorator marking a test as too slow.

    Slow tests are skipped while they're in the process of being fixed. No test should stay tagged as "tooslow" as
    these will not be tested by the CI.

    """
    return unittest.skip("test is too slow")(test_case)


def custom_tokenizers(test_case):
    """
    Decorator marking a test for a custom tokenizer.

    Custom tokenizers require additional dependencies, and are skipped by default. Set the RUN_CUSTOM_TOKENIZERS
    environment variable to a truthy value to run them.
    """
    return unittest.skipUnless(_run_custom_tokenizers, "test of custom tokenizers")(test_case)


def require_git_lfs(test_case):
    """
    Decorator marking a test that requires git-lfs.

    git-lfs requires additional dependencies, and tests are skipped by default. Set the RUN_GIT_LFS_TESTS environment
    variable to a truthy value to run them.
    """
    return unittest.skipUnless(_run_git_lfs_tests, "test of git lfs workflow")(test_case)


def require_accelerate(test_case):
    """
    Decorator marking a test that requires accelerate. These tests are skipped when accelerate isn't installed.
    """
    return unittest.skipUnless(is_accelerate_available(), "test requires accelerate")(test_case)


def require_rjieba(test_case):
    """
    Decorator marking a test that requires rjieba. These tests are skipped when rjieba isn't installed.
    """
    return unittest.skipUnless(is_rjieba_available(), "test requires rjieba")(test_case)


def require_tf2onnx(test_case):
    return unittest.skipUnless(is_tf2onnx_available(), "test requires tf2onnx")(test_case)


def require_onnx(test_case):
    return unittest.skipUnless(is_onnx_available(), "test requires ONNX")(test_case)


def require_timm(test_case):
    """
    Decorator marking a test that requires Timm.

    These tests are skipped when Timm isn't installed.

    """
    return unittest.skipUnless(is_timm_available(), "test requires Timm")(test_case)


def require_torch(test_case):
    """
    Decorator marking a test that requires PyTorch.

    These tests are skipped when PyTorch isn't installed.

    """
    return unittest.skipUnless(is_torch_available(), "test requires PyTorch")(test_case)


def require_intel_extension_for_pytorch(test_case):
    """
    Decorator marking a test that requires Intel Extension for PyTorch.

    These tests are skipped when Intel Extension for PyTorch isn't installed or it does not match current PyTorch
    version.

    """
    return unittest.skipUnless(
        is_ipex_available(),
        "test requires Intel Extension for PyTorch to be installed and match current PyTorch version, see"
        " https://github.com/intel/intel-extension-for-pytorch",
    )(test_case)


def require_torch_scatter(test_case):
    """
    Decorator marking a test that requires PyTorch scatter.

    These tests are skipped when PyTorch scatter isn't installed.

    """
    return unittest.skipUnless(is_scatter_available(), "test requires PyTorch scatter")(test_case)


def require_tensorflow_probability(test_case):
    """
    Decorator marking a test that requires TensorFlow probability.

    These tests are skipped when TensorFlow probability isn't installed.

    """
    return unittest.skipUnless(is_tensorflow_probability_available(), "test requires TensorFlow probability")(
        test_case
    )


def require_torchaudio(test_case):
    """
    Decorator marking a test that requires torchaudio. These tests are skipped when torchaudio isn't installed.
    """
    return unittest.skipUnless(is_torchaudio_available(), "test requires torchaudio")(test_case)


def require_tf(test_case):
    """
    Decorator marking a test that requires TensorFlow. These tests are skipped when TensorFlow isn't installed.
    """
    return unittest.skipUnless(is_tf_available(), "test requires TensorFlow")(test_case)


def require_flax(test_case):
    """
    Decorator marking a test that requires JAX & Flax. These tests are skipped when one / both are not installed.
    """
    return unittest.skipUnless(is_flax_available(), "test requires JAX & Flax")(test_case)


def require_sentencepiece(test_case):
    """
    Decorator marking a test that requires SentencePiece. These tests are skipped when SentencePiece isn't installed.
    """
    return unittest.skipUnless(is_sentencepiece_available(), "test requires SentencePiece")(test_case)


def require_scipy(test_case):
    """
    Decorator marking a test that requires Scipy. These tests are skipped when Scipy isn't installed.
    """
    return unittest.skipUnless(is_scipy_available(), "test requires Scipy")(test_case)


def require_tokenizers(test_case):
    """
    Decorator marking a test that requires 🤗 Tokenizers. These tests are skipped when 🤗 Tokenizers isn't installed.
    """
    return unittest.skipUnless(is_tokenizers_available(), "test requires tokenizers")(test_case)


def require_tensorflow_text(test_case):
    """
    Decorator marking a test that requires tensorflow_text. These tests are skipped when tensorflow_text isn't
    installed.
    """
    return unittest.skipUnless(is_tensorflow_text_available(), "test requires tensorflow_text")(test_case)


def require_pandas(test_case):
    """
    Decorator marking a test that requires pandas. These tests are skipped when pandas isn't installed.
    """
    return unittest.skipUnless(is_pandas_available(), "test requires pandas")(test_case)


def require_pytesseract(test_case):
    """
    Decorator marking a test that requires PyTesseract. These tests are skipped when PyTesseract isn't installed.
    """
    return unittest.skipUnless(is_pytesseract_available(), "test requires PyTesseract")(test_case)


def require_scatter(test_case):
    """
    Decorator marking a test that requires PyTorch Scatter. These tests are skipped when PyTorch Scatter isn't
    installed.
    """
    return unittest.skipUnless(is_scatter_available(), "test requires PyTorch Scatter")(test_case)


def require_pytorch_quantization(test_case):
    """
    Decorator marking a test that requires PyTorch Quantization Toolkit. These tests are skipped when PyTorch
    Quantization Toolkit isn't installed.
    """
    return unittest.skipUnless(is_pytorch_quantization_available(), "test requires PyTorch Quantization Toolkit")(
        test_case
    )


def require_vision(test_case):
    """
    Decorator marking a test that requires the vision dependencies. These tests are skipped when the vision
    dependencies aren't installed.
    """
    return unittest.skipUnless(is_vision_available(), "test requires vision")(test_case)


def require_ftfy(test_case):
    """
    Decorator marking a test that requires ftfy. These tests are skipped when ftfy isn't installed.
    """
    return unittest.skipUnless(is_ftfy_available(), "test requires ftfy")(test_case)


def require_spacy(test_case):
    """
    Decorator marking a test that requires SpaCy. These tests are skipped when SpaCy isn't installed.
    """
    return unittest.skipUnless(is_spacy_available(), "test requires spacy")(test_case)


def require_torch_multi_gpu(test_case):
    """
    Decorator marking a test that requires a multi-GPU setup (in PyTorch). These tests are skipped on a machine without
    multiple GPUs.

    To run *only* the multi_gpu tests, assuming all test names contain multi_gpu: $ pytest -sv ./tests -k "multi_gpu"
    """
    if not is_torch_available():
        return unittest.skip("test requires PyTorch")(test_case)

    import torch

    return unittest.skipUnless(torch.cuda.device_count() > 1, "test requires multiple GPUs")(test_case)


def require_torch_non_multi_gpu(test_case):
    """
    Decorator marking a test that requires 0 or 1 GPU setup (in PyTorch).
    """
    if not is_torch_available():
        return unittest.skip("test requires PyTorch")(test_case)

    import torch

    return unittest.skipUnless(torch.cuda.device_count() < 2, "test requires 0 or 1 GPU")(test_case)


def require_torch_up_to_2_gpus(test_case):
    """
    Decorator marking a test that requires 0 or 1 or 2 GPU setup (in PyTorch).
    """
    if not is_torch_available():
        return unittest.skip("test requires PyTorch")(test_case)

    import torch

    return unittest.skipUnless(torch.cuda.device_count() < 3, "test requires 0 or 1 or 2 GPUs")(test_case)


def require_torch_tpu(test_case):
    """
    Decorator marking a test that requires a TPU (in PyTorch).
    """
    return unittest.skipUnless(is_torch_tpu_available(check_device=False), "test requires PyTorch TPU")(test_case)


if is_torch_available():
    # Set env var CUDA_VISIBLE_DEVICES="" to force cpu-mode
    import torch

    torch_device = "cuda" if torch.cuda.is_available() else "cpu"
else:
    torch_device = None

if is_tf_available():
    import tensorflow as tf

if is_flax_available():
    import jax

    jax_device = jax.default_backend()
else:
    jax_device = None
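# Illustrative note (hypothetical model/tensor names): tests typically use `torch_device` to place
# modules and tensors on whatever accelerator the current machine exposes, e.g.
#
#   model = MyModel().to(torch_device)
#   batch = torch.ones(2, 3, device=torch_device)
#
# so the same test runs unchanged on CPU-only and CUDA machines.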


def require_torchdynamo(test_case):
    """Decorator marking a test that requires TorchDynamo"""
    return unittest.skipUnless(is_torchdynamo_available(), "test requires TorchDynamo")(test_case)


def require_torch_tensorrt_fx(test_case):
    """Decorator marking a test that requires Torch-TensorRT FX"""
    return unittest.skipUnless(is_torch_tensorrt_fx_available(), "test requires Torch-TensorRT FX")(test_case)


def require_torch_gpu(test_case):
    """Decorator marking a test that requires CUDA and PyTorch."""
    return unittest.skipUnless(torch_device == "cuda", "test requires CUDA")(test_case)


def require_torch_bf16_gpu(test_case):
    """Decorator marking a test that requires torch>=1.10, using Ampere GPU or newer arch with cuda>=11.0"""
    return unittest.skipUnless(
        is_torch_bf16_gpu_available(),
        "test requires torch>=1.10, using Ampere GPU or newer arch with cuda>=11.0",
    )(test_case)


def require_torch_bf16_cpu(test_case):
    """Decorator marking a test that requires torch>=1.10, using CPU."""
    return unittest.skipUnless(
        is_torch_bf16_cpu_available(),
        "test requires torch>=1.10, using CPU",
    )(test_case)


def require_torch_tf32(test_case):
    """Decorator marking a test that requires Ampere or a newer GPU arch, cuda>=11 and torch>=1.7."""
    return unittest.skipUnless(
        is_torch_tf32_available(), "test requires Ampere or a newer GPU arch, cuda>=11 and torch>=1.7"
    )(test_case)


def require_detectron2(test_case):
    """Decorator marking a test that requires detectron2."""
    return unittest.skipUnless(is_detectron2_available(), "test requires `detectron2`")(test_case)


def require_faiss(test_case):
    """Decorator marking a test that requires faiss."""
    return unittest.skipUnless(is_faiss_available(), "test requires `faiss`")(test_case)


def require_optuna(test_case):
    """
    Decorator marking a test that requires optuna.

    These tests are skipped when optuna isn't installed.

    """
    return unittest.skipUnless(is_optuna_available(), "test requires optuna")(test_case)


def require_ray(test_case):
    """
    Decorator marking a test that requires Ray/tune.

    These tests are skipped when Ray/tune isn't installed.

    """
    return unittest.skipUnless(is_ray_available(), "test requires Ray/tune")(test_case)


def require_sigopt(test_case):
    """
    Decorator marking a test that requires SigOpt.

    These tests are skipped when SigOpt isn't installed.

    """
    return unittest.skipUnless(is_sigopt_available(), "test requires SigOpt")(test_case)


def require_wandb(test_case):
    """
    Decorator marking a test that requires wandb.

    These tests are skipped when wandb isn't installed.

    """
    return unittest.skipUnless(is_wandb_available(), "test requires wandb")(test_case)


def require_soundfile(test_case):
    """
    Decorator marking a test that requires soundfile

    These tests are skipped when soundfile isn't installed.

    """
    return unittest.skipUnless(is_soundfile_availble(), "test requires soundfile")(test_case)


def require_deepspeed(test_case):
    """
    Decorator marking a test that requires deepspeed
    """
    return unittest.skipUnless(is_deepspeed_available(), "test requires deepspeed")(test_case)


def require_fairscale(test_case):
    """
    Decorator marking a test that requires fairscale
    """
    return unittest.skipUnless(is_fairscale_available(), "test requires fairscale")(test_case)


def require_apex(test_case):
    """
    Decorator marking a test that requires apex
    """
    return unittest.skipUnless(is_apex_available(), "test requires apex")(test_case)


def require_bitsandbytes(test_case):
    """
    Decorator for bits and bytes (bnb) dependency
    """
    return unittest.skipUnless(is_bitsandbytes_available(), "test requires bnb")(test_case)


def require_phonemizer(test_case):
    """
    Decorator marking a test that requires phonemizer
    """
    return unittest.skipUnless(is_phonemizer_available(), "test requires phonemizer")(test_case)


def require_pyctcdecode(test_case):
    """
    Decorator marking a test that requires pyctcdecode
    """
    return unittest.skipUnless(is_pyctcdecode_available(), "test requires pyctcdecode")(test_case)


def require_librosa(test_case):
    """
    Decorator marking a test that requires librosa
    """
    return unittest.skipUnless(is_librosa_available(), "test requires librosa")(test_case)


def cmd_exists(cmd):
    return shutil.which(cmd) is not None


def require_usr_bin_time(test_case):
    """
    Decorator marking a test that requires `/usr/bin/time`
    """
    return unittest.skipUnless(cmd_exists("/usr/bin/time"), "test requires /usr/bin/time")(test_case)


def get_gpu_count():
    """
    Return the number of available gpus (regardless of whether torch, tf or jax is used)
    """
    if is_torch_available():
        import torch

        return torch.cuda.device_count()
    elif is_tf_available():
        import tensorflow as tf

        return len(tf.config.list_physical_devices("GPU"))
    elif is_flax_available():
        import jax

        return jax.device_count()
    else:
        return 0


def get_tests_dir(append_path=None):
    """
    Args:
        append_path: optional path to append to the tests dir path

    Return:
        The full path to the `tests` dir, so that the tests can be invoked from anywhere. Optionally `append_path` is
        joined after the `tests` dir if the former is provided.

    """
    # this function caller's __file__
    caller__file__ = inspect.stack()[1][1]
    tests_dir = os.path.abspath(os.path.dirname(caller__file__))

    while not tests_dir.endswith("tests"):
        tests_dir = os.path.dirname(tests_dir)

    if append_path:
        return os.path.join(tests_dir, append_path)
    else:
        return tests_dir
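# Illustrative call (hypothetical fixture path): from inside any test module,
#
#   data_dir = get_tests_dir("fixtures")
#
# resolves to an absolute path under the repository's `tests/` directory, regardless of the directory
# pytest was launched from.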


#
# Helper functions for dealing with testing text outputs
# The original code came from:
# https://github.com/fastai/fastai/blob/master/tests/utils/text.py

# When any function contains print() calls that get overwritten, like progress bars,
# a special care needs to be applied, since under pytest -s captured output (capsys
# or contextlib.redirect_stdout) contains any temporary printed strings, followed by
# \r's. This helper function ensures that the buffer will contain the same output
# with and without -s in pytest, by turning:
# foo bar\r tar mar\r final message
# into:
# final message
# it can handle a single string or a multiline buffer
def apply_print_resets(buf):
    return re.sub(r"^.*\r", "", buf, 0, re.M)


def assert_screenout(out, what):
    out_pr = apply_print_resets(out).lower()
    match_str = out_pr.find(what.lower())
    assert match_str != -1, f"expecting to find {what} in output: {out_pr}"


class CaptureStd:
    """
    Context manager to capture:

        - stdout: replay it, clean it up and make it available via `obj.out`
        - stderr: replay it and make it available via `obj.err`

    Args:
        out (`bool`, *optional*, defaults to `True`): Whether to capture stdout or not.
        err (`bool`, *optional*, defaults to `True`): Whether to capture stderr or not.
        replay (`bool`, *optional*, defaults to `True`): Whether to replay or not.
            By default each captured stream gets replayed back on context's exit, so that one can see what the test was
            doing. If this is not the desired behavior and the captured data shouldn't be replayed, pass `replay=False`
            to disable this feature.

    Examples:

    ```python
    # to capture stdout only with auto-replay
    with CaptureStdout() as cs:
        print("Secret message")
    assert "message" in cs.out

    # to capture stderr only with auto-replay
    import sys

    with CaptureStderr() as cs:
        print("Warning: ", file=sys.stderr)
    assert "Warning" in cs.err

    # to capture both streams with auto-replay
    with CaptureStd() as cs:
        print("Secret message")
        print("Warning: ", file=sys.stderr)
    assert "message" in cs.out
    assert "Warning" in cs.err

    # to capture just one of the streams, and not the other, with auto-replay
    with CaptureStd(err=False) as cs:
        print("Secret message")
    assert "message" in cs.out
    # but best use the stream-specific subclasses

    # to capture without auto-replay
    with CaptureStd(replay=False) as cs:
        print("Secret message")
    assert "message" in cs.out
    ```"""

    def __init__(self, out=True, err=True, replay=True):

        self.replay = replay

        if out:
            self.out_buf = StringIO()
            self.out = "error: CaptureStd context is unfinished yet, called too early"
        else:
            self.out_buf = None
            self.out = "not capturing stdout"

        if err:
            self.err_buf = StringIO()
            self.err = "error: CaptureStd context is unfinished yet, called too early"
        else:
            self.err_buf = None
            self.err = "not capturing stderr"

    def __enter__(self):
        if self.out_buf:
            self.out_old = sys.stdout
            sys.stdout = self.out_buf

        if self.err_buf:
            self.err_old = sys.stderr
            sys.stderr = self.err_buf

        return self

    def __exit__(self, *exc):
        if self.out_buf:
            sys.stdout = self.out_old
            captured = self.out_buf.getvalue()
            if self.replay:
                sys.stdout.write(captured)
            self.out = apply_print_resets(captured)

        if self.err_buf:
            sys.stderr = self.err_old
            captured = self.err_buf.getvalue()
            if self.replay:
                sys.stderr.write(captured)
            self.err = captured

    def __repr__(self):
        msg = ""
        if self.out_buf:
            msg += f"stdout: {self.out}\n"
        if self.err_buf:
            msg += f"stderr: {self.err}\n"
        return msg


# in tests it's best to capture only the stream that's wanted, otherwise
# it's easy to miss things, so unless you need to capture both streams, use the
# subclasses below (less typing). Or alternatively, configure `CaptureStd` to
# disable the stream you don't need to test.


class CaptureStdout(CaptureStd):
    """Same as CaptureStd but captures only stdout"""

    def __init__(self, replay=True):
        super().__init__(err=False, replay=replay)


class CaptureStderr(CaptureStd):
    """Same as CaptureStd but captures only stderr"""

    def __init__(self, replay=True):
        super().__init__(out=False, replay=replay)


class CaptureLogger:
    """
    Context manager to capture `logging` streams

    Args:
        logger: `logging` logger object

    Returns:
        The captured output is available via `self.out`

    Example:

    ```python
    >>> from transformers import logging
    >>> from transformers.testing_utils import CaptureLogger

    >>> msg = "Testing 1, 2, 3"
    >>> logging.set_verbosity_info()
    >>> logger = logging.get_logger("transformers.models.bart.tokenization_bart")
    >>> with CaptureLogger(logger) as cl:
    ...     logger.info(msg)
    >>> assert cl.out, msg + "\n"
    ```
    """

    def __init__(self, logger):
        self.logger = logger
        self.io = StringIO()
        self.sh = logging.StreamHandler(self.io)
        self.out = ""

    def __enter__(self):
        self.logger.addHandler(self.sh)
        return self

    def __exit__(self, *exc):
        self.logger.removeHandler(self.sh)
        self.out = self.io.getvalue()

    def __repr__(self):
        return f"captured: {self.out}\n"


@contextlib.contextmanager
def LoggingLevel(level):
    """
    This is a context manager to temporarily change transformers modules logging level to the desired value and have it
    restored to the original setting at the end of the scope.

    Example:

    ```python
    with LoggingLevel(logging.INFO):
        AutoModel.from_pretrained("gpt2")  # calls logger.info() several times
    ```
    """
    orig_level = transformers_logging.get_verbosity()
    try:
        transformers_logging.set_verbosity(level)
        yield
    finally:
        transformers_logging.set_verbosity(orig_level)


@contextlib.contextmanager
# adapted from https://stackoverflow.com/a/64789046/9201239
def ExtendSysPath(path: Union[str, os.PathLike]) -> Iterator[None]:
    """
    Temporarily add the given path to `sys.path`.

    Usage:

    ```python
    with ExtendSysPath("/path/to/dir"):
        mymodule = importlib.import_module("mymodule")
    ```
    """

    path = os.fspath(path)
    try:
        sys.path.insert(0, path)
        yield
    finally:
        sys.path.remove(path)


class TestCasePlus(unittest.TestCase):
    """
    This class extends *unittest.TestCase* with additional features.

    Feature 1: A set of fully resolved important file and dir path accessors.

    In tests often we need to know where things are relative to the current test file, and it's not trivial since the
    test could be invoked from more than one directory or could reside in sub-directories with different depths. This
    class solves this problem by sorting out all the basic paths and provides easy accessors to them:

    - `pathlib` objects (all fully resolved):

       - `test_file_path` - the current test file path (=`__file__`)
       - `test_file_dir` - the directory containing the current test file
       - `tests_dir` - the directory of the `tests` test suite
       - `examples_dir` - the directory of the `examples` test suite
       - `repo_root_dir` - the directory of the repository
       - `src_dir` - the directory of `src` (i.e. where the `transformers` sub-dir resides)

    - stringified paths---same as above but these return paths as strings, rather than `pathlib` objects:

       - `test_file_path_str`
       - `test_file_dir_str`
       - `tests_dir_str`
       - `examples_dir_str`
       - `repo_root_dir_str`
       - `src_dir_str`

    Feature 2: Flexible auto-removable temporary dirs which are guaranteed to get removed at the end of test.

    1. Create a unique temporary dir:

    ```python
    def test_whatever(self):
        tmp_dir = self.get_auto_remove_tmp_dir()
    ```

    `tmp_dir` will contain the path to the created temporary dir. It will be automatically removed at the end of the
    test.

    2. Create a temporary dir of my choice, ensure it's empty before the test starts and don't
    empty it after the test.

    ```python
    def test_whatever(self):
        tmp_dir = self.get_auto_remove_tmp_dir("./xxx")
    ```

    This is useful for debugging when you want to monitor a specific directory and want to make sure the previous tests
    didn't leave any data in there.

    3. You can override the first two options by directly overriding the `before` and `after` args, leading to the
        following behavior:

    `before=True`: the temporary dir will always be cleared at the beginning of the test.

    `before=False`: if the temporary dir already existed, any existing files will remain there.

    `after=True`: the temporary dir will always be deleted at the end of the test.

    `after=False`: the temporary dir will always be left intact at the end of the test.

    Note 1: In order to run the equivalent of `rm -r` safely, only subdirs of the project repository checkout are
    allowed if an explicit `tmp_dir` is used, so that by mistake no `/tmp` or similar important part of the filesystem
    will get nuked. i.e. please always pass paths that start with `./`

    Note 2: Each test can register multiple temporary dirs and they all will get auto-removed, unless requested
    otherwise.

    Feature 3: Get a copy of the `os.environ` object that sets up `PYTHONPATH` specific to the current test suite. This
    is useful for invoking external programs from the test suite - e.g. distributed training.

    ```python
    def test_whatever(self):
        env = self.get_env()
    ```"""

    def setUp(self):
        # get_auto_remove_tmp_dir feature:
        self.teardown_tmp_dirs = []

        # figure out the resolved paths for repo_root, tests, examples, etc.
        self._test_file_path = inspect.getfile(self.__class__)
        path = Path(self._test_file_path).resolve()
        self._test_file_dir = path.parents[0]
        for up in [1, 2, 3]:
            tmp_dir = path.parents[up]
            if (tmp_dir / "src").is_dir() and (tmp_dir / "tests").is_dir():
                break
        if tmp_dir:
            self._repo_root_dir = tmp_dir
        else:
            raise ValueError(f"can't figure out the root of the repo from {self._test_file_path}")
        self._tests_dir = self._repo_root_dir / "tests"
        self._examples_dir = self._repo_root_dir / "examples"
        self._src_dir = self._repo_root_dir / "src"

    @property
    def test_file_path(self):
        return self._test_file_path

    @property
    def test_file_path_str(self):
        return str(self._test_file_path)

    @property
    def test_file_dir(self):
        return self._test_file_dir

    @property
    def test_file_dir_str(self):
        return str(self._test_file_dir)

    @property
    def tests_dir(self):
        return self._tests_dir

    @property
    def tests_dir_str(self):
        return str(self._tests_dir)

    @property
    def examples_dir(self):
        return self._examples_dir

    @property
    def examples_dir_str(self):
        return str(self._examples_dir)

    @property
    def repo_root_dir(self):
        return self._repo_root_dir

    @property
    def repo_root_dir_str(self):
        return str(self._repo_root_dir)

    @property
    def src_dir(self):
        return self._src_dir

    @property
    def src_dir_str(self):
        return str(self._src_dir)

    def get_env(self):
        """
        Return a copy of the `os.environ` object that sets up `PYTHONPATH` correctly, depending on the test suite it's
        invoked from. This is useful for invoking external programs from the test suite - e.g. distributed training.

        It always inserts `./src` first, then `./tests` or `./examples` depending on the test suite type and finally
        the preset `PYTHONPATH` if any (all full resolved paths).

        """
        env = os.environ.copy()
        paths = [self.src_dir_str]
        if "/examples" in self.test_file_dir_str:
            paths.append(self.examples_dir_str)
        else:
            paths.append(self.tests_dir_str)
        paths.append(env.get("PYTHONPATH", ""))

        env["PYTHONPATH"] = ":".join(paths)
        return env

    def get_auto_remove_tmp_dir(self, tmp_dir=None, before=None, after=None):
        """
        Args:
            tmp_dir (`string`, *optional*):
                if `None`:

                   - a unique temporary path will be created
                   - sets `before=True` if `before` is `None`
                   - sets `after=True` if `after` is `None`

                else:

                   - `tmp_dir` will be created
                   - sets `before=True` if `before` is `None`
                   - sets `after=False` if `after` is `None`
            before (`bool`, *optional*):
                If `True` and the `tmp_dir` already exists, make sure to empty it right away; if `False` and the
                `tmp_dir` already exists, any existing files will remain there.
            after (`bool`, *optional*):
                If `True`, delete the `tmp_dir` at the end of the test; if `False`, leave the `tmp_dir` and its
                contents intact at the end of the test.

        Returns:
            tmp_dir(`string`): either the same value as passed via *tmp_dir* or the path to the auto-selected tmp dir
        """
        if tmp_dir is not None:

            # defining the most likely desired behavior for when a custom path is provided.
            # this most likely indicates the debug mode where we want an easily locatable dir that:
            # 1. gets cleared out before the test (if it already exists)
            # 2. is left intact after the test
            if before is None:
                before = True
            if after is None:
                after = False

            # using provided path
            path = Path(tmp_dir).resolve()

            # to avoid nuking parts of the filesystem, only relative paths are allowed
            if not tmp_dir.startswith("./"):
                raise ValueError(
                    f"`tmp_dir` can only be a relative path, i.e. `./some/path`, but received `{tmp_dir}`"
                )

            # ensure the dir is empty to start with
            if before is True and path.exists():
                shutil.rmtree(tmp_dir, ignore_errors=True)

            path.mkdir(parents=True, exist_ok=True)

        else:
            # defining the most likely desired behavior for when a unique tmp path is auto generated
            # (not a debug mode), here we require a unique tmp dir that:
            # 1. is empty before the test (it will be empty in this situation anyway)
            # 2. gets fully removed after the test
            if before is None:
                before = True
            if after is None:
                after = True

            # using unique tmp dir (always empty, regardless of `before`)
            tmp_dir = tempfile.mkdtemp()

        if after is True:
            # register for deletion
            self.teardown_tmp_dirs.append(tmp_dir)

        return tmp_dir

    def python_one_liner_max_rss(self, one_liner_str):
        """
        Runs the passed python one liner (just the code) and returns how much max cpu memory was used to run the
        program.

        Args:
            one_liner_str (`string`):
                a python one liner code that gets passed to `python -c`

        Returns:
            max cpu memory bytes used to run the program. This value is likely to vary slightly from run to run.

        Requirements:
            this helper needs `/usr/bin/time` to be installed (`apt install time`)

        Example:

        ```
        one_liner_str = 'from transformers import AutoModel; AutoModel.from_pretrained("t5-large")'
        max_rss = self.python_one_liner_max_rss(one_liner_str)
        ```
        """

        if not cmd_exists("/usr/bin/time"):
            raise ValueError("/usr/bin/time is required, install with `apt install time`")

        cmd = shlex.split(f"/usr/bin/time -f %M python -c '{one_liner_str}'")
        with CaptureStd() as cs:
            execute_subprocess_async(cmd, env=self.get_env())
        # returned data is in KB so convert to bytes
        max_rss = int(cs.err.split("\n")[-2].replace("stderr: ", "")) * 1024
        return max_rss

    def tearDown(self):

        # get_auto_remove_tmp_dir feature: remove registered temp dirs
        for path in self.teardown_tmp_dirs:
            shutil.rmtree(path, ignore_errors=True)
        self.teardown_tmp_dirs = []


def mockenv(**kwargs):
    """
    this is a convenience wrapper that allows this::

    @mockenv(RUN_SLOW=True, USE_TF=False)
    def test_something():
        run_slow = os.getenv("RUN_SLOW", False)
        use_tf = os.getenv("USE_TF", False)

    """
    return mock.patch.dict(os.environ, kwargs)


# from https://stackoverflow.com/a/34333710/9201239
@contextlib.contextmanager
def mockenv_context(*remove, **update):
    """
    Temporarily updates the `os.environ` dictionary in-place. Similar to mockenv

    The `os.environ` dictionary is updated in-place so that the modification is sure to work in all situations.

    Args:
      remove: Environment variables to remove.
      update: Dictionary of environment variables and values to add/update.
    """
    env = os.environ
    update = update or {}
    remove = remove or []

    # List of environment variables being updated or removed.
    stomped = (set(update.keys()) | set(remove)) & set(env.keys())
    # Environment variables and values to restore on exit.
    update_after = {k: env[k] for k in stomped}
    # Environment variables and values to remove on exit.
    remove_after = frozenset(k for k in update if k not in env)

    try:
        env.update(update)
        [env.pop(k, None) for k in remove]
        yield
    finally:
        env.update(update_after)
        [env.pop(k) for k in remove_after]


# --- pytest conf functions --- #

# to avoid multiple invocation from tests/conftest.py and examples/conftest.py - make sure it's called only once
pytest_opt_registered = {}


def pytest_addoption_shared(parser):
    """
    This function is to be called from `conftest.py` via `pytest_addoption` wrapper that has to be defined there.

    It allows loading both `conftest.py` files at once without causing a failure due to adding the same `pytest`
    option.

    """
    option = "--make-reports"
    if option not in pytest_opt_registered:
        parser.addoption(
            option,
            action="store",
            default=False,
            help="generate report files. The value of this option is used as a prefix to report names",
        )
        pytest_opt_registered[option] = 1


def pytest_terminal_summary_main(tr, id):
    """
    Generate multiple reports at the end of test suite run - each report goes into a dedicated file in the current
    directory. The report files are prefixed with the test suite name.

    This function emulates --duration and -rA pytest arguments.

    This function is to be called from `conftest.py` via `pytest_terminal_summary` wrapper that has to be defined
    there.

    Args:

    - tr: `terminalreporter` passed from `conftest.py`
    - id: unique id like `tests` or `examples` that will be incorporated into the final reports filenames - this is
      needed as some jobs have multiple runs of pytest, so we can't have them overwrite each other.

    NB: this function taps into a private _pytest API and while unlikely, it could break should pytest do internal
    changes - also it calls default internal methods of terminalreporter which can be hijacked by various `pytest-`
    plugins and interfere.

    """
    from _pytest.config import create_terminal_writer

    if not len(id):
        id = "tests"

    config = tr.config
    orig_writer = config.get_terminal_writer()
    orig_tbstyle = config.option.tbstyle
    orig_reportchars = tr.reportchars

    dir = f"reports/{id}"
    Path(dir).mkdir(parents=True, exist_ok=True)
    report_files = {
        k: f"{dir}/{k}.txt"
        for k in [
            "durations",
            "errors",
            "failures_long",
            "failures_short",
            "failures_line",
            "passes",
            "stats",
            "summary_short",
            "warnings",
        ]
    }

    # custom durations report
    # note: there is no need to call pytest --durations=XX to get this separate report
    # adapted from https://github.com/pytest-dev/pytest/blob/897f151e/src/_pytest/runner.py#L66
    dlist = []
    for replist in tr.stats.values():
        for rep in replist:
            if hasattr(rep, "duration"):
                dlist.append(rep)
    if dlist:
        dlist.sort(key=lambda x: x.duration, reverse=True)
        with open(report_files["durations"], "w") as f:
            durations_min = 0.05  # sec
            f.write("slowest durations\n")
            for i, rep in enumerate(dlist):
                if rep.duration < durations_min:
                    f.write(f"{len(dlist)-i} durations < {durations_min} secs were omitted")
                    break
                f.write(f"{rep.duration:02.2f}s {rep.when:<8} {rep.nodeid}\n")

    def summary_failures_short(tr):
        # expecting that the reports were --tb=long (default) so we chop them off here to the last frame
        reports = tr.getreports("failed")
        if not reports:
            return
        tr.write_sep("=", "FAILURES SHORT STACK")
        for rep in reports:
            msg = tr._getfailureheadline(rep)
            tr.write_sep("_", msg, red=True, bold=True)
            # chop off the optional leading extra frames, leaving only the last one
            longrepr = re.sub(r".*_ _ _ (_ ){10,}_ _ ", "", rep.longreprtext, 0, re.M | re.S)
            tr._tw.line(longrepr)
            # note: not printing out any rep.sections to keep the report short

    # use ready-made report funcs, we are just hijacking the filehandle to log to a dedicated file each
    # adapted from https://github.com/pytest-dev/pytest/blob/897f151e/src/_pytest/terminal.py#L814
    # note: some pytest plugins may interfere by hijacking the default `terminalreporter` (e.g.
    # pytest-instafail does that)

    # report failures with line/short/long styles
    config.option.tbstyle = "auto"  # full tb
    with open(report_files["failures_long"], "w") as f:
        tr._tw = create_terminal_writer(config, f)
        tr.summary_failures()

    # config.option.tbstyle = "short" # short tb
    with open(report_files["failures_short"], "w") as f:
        tr._tw = create_terminal_writer(config, f)
        summary_failures_short(tr)

    config.option.tbstyle = "line"  # one line per error
    with open(report_files["failures_line"], "w") as f:
        tr._tw = create_terminal_writer(config, f)
        tr.summary_failures()

    with open(report_files["errors"], "w") as f:
        tr._tw = create_terminal_writer(config, f)
        tr.summary_errors()

    with open(report_files["warnings"], "w") as f:
        tr._tw = create_terminal_writer(config, f)
        tr.summary_warnings()  # normal warnings
        tr.summary_warnings()  # final warnings

    tr.reportchars = "wPpsxXEf"  # emulate -rA (used in summary_passes() and short_test_summary())
    with open(report_files["passes"], "w") as f:
        tr._tw = create_terminal_writer(config, f)
        tr.summary_passes()

    with open(report_files["summary_short"], "w") as f:
        tr._tw = create_terminal_writer(config, f)
        tr.short_test_summary()

    with open(report_files["stats"], "w") as f:
        tr._tw = create_terminal_writer(config, f)
        tr.summary_stats()

    # restore:
    tr._tw = orig_writer
    tr.reportchars = orig_reportchars
    config.option.tbstyle = orig_tbstyle


# --- distributed testing functions --- #

# adapted from https://stackoverflow.com/a/59041913/9201239
import asyncio  # noqa


class _RunOutput:
    def __init__(self, returncode, stdout, stderr):
        self.returncode = returncode
        self.stdout = stdout
        self.stderr = stderr


async def _read_stream(stream, callback):
    while True:
        line = await stream.readline()
        if line:
            callback(line)
        else:
            break


async def _stream_subprocess(cmd, env=None, stdin=None, timeout=None, quiet=False, echo=False) -> _RunOutput:
    if echo:
        print("\nRunning: ", " ".join(cmd))

    p = await asyncio.create_subprocess_exec(
        cmd[0],
        *cmd[1:],
        stdin=stdin,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.PIPE,
        env=env,
    )

    # note: there is a warning for a possible deadlock when using `wait` with huge amounts of data in the pipe
    # https://docs.python.org/3/library/asyncio-subprocess.html#asyncio.asyncio.subprocess.Process.wait
    #
    # If it starts hanging, will need to switch to the following code. The problem is that no data
    # will be seen until it's done and if it hangs for example there will be no debug info.
    # out, err = await p.communicate()
    # return _RunOutput(p.returncode, out, err)

    out = []
    err = []

    def tee(line, sink, pipe, label=""):
        line = line.decode("utf-8").rstrip()
        sink.append(line)
        if not quiet:
            print(label, line, file=pipe)

    # XXX: the timeout doesn't seem to make any difference here
    await asyncio.wait(
        [
            _read_stream(p.stdout, lambda l: tee(l, out, sys.stdout, label="stdout:")),
            _read_stream(p.stderr, lambda l: tee(l, err, sys.stderr, label="stderr:")),
        ],
        timeout=timeout,
    )
    return _RunOutput(await p.wait(), out, err)


def execute_subprocess_async(cmd, env=None, stdin=None, timeout=180, quiet=False, echo=True) -> _RunOutput:

    loop = asyncio.get_event_loop()
    result = loop.run_until_complete(
        _stream_subprocess(cmd, env=env, stdin=stdin, timeout=timeout, quiet=quiet, echo=echo)
    )

    cmd_str = " ".join(cmd)
    if result.returncode > 0:
        stderr = "\n".join(result.stderr)
        raise RuntimeError(
            f"'{cmd_str}' failed with returncode {result.returncode}\n\n"
            f"The combined stderr from workers follows:\n{stderr}"
        )

    # check that the subprocess actually did run and produced some output, should the test rely on
    # the remote side to do the testing
    if not result.stdout and not result.stderr:
        raise RuntimeError(f"'{cmd_str}' produced no output.")

    return result
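# Illustrative call (hypothetical script path): distributed tests typically build a launcher command and
# run it under the environment prepared by `TestCasePlus.get_env()`, e.g.
#
#   cmd = ["python", "-m", "torch.distributed.launch", "--nproc_per_node=2", "examples/run_something.py"]
#   execute_subprocess_async(cmd, env=self.get_env())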


def pytest_xdist_worker_id():
    """
    Returns an int value of worker's numerical id under `pytest-xdist`'s concurrent workers `pytest -n N` regime, or 0
    if `-n 1` or `pytest-xdist` isn't being used.
    """
    worker = os.environ.get("PYTEST_XDIST_WORKER", "gw0")
    worker = re.sub(r"^gw", "", worker, 0, re.M)
    return int(worker)


def get_torch_dist_unique_port():
    """
    Returns a port number that can be fed to `torch.distributed.launch`'s `--master_port` argument.

    Under `pytest-xdist` it adds a delta number based on a worker id so that concurrent tests don't try to use the same
    port at once.
    """
    port = 29500
    uniq_delta = pytest_xdist_worker_id()
    return port + uniq_delta
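# Illustrative usage (hypothetical command line): the port is typically interpolated into the launcher
# invocation so that concurrent xdist workers don't collide, e.g.
#
#   cmd = f"python -m torch.distributed.launch --master_port={get_torch_dist_unique_port()} train.py"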


def nested_simplify(obj, decimals=3):
    """
    Simplifies an object by rounding float numbers, and downcasting tensors/numpy arrays to get simple equality test
    within tests.
    """
    import numpy as np

    if isinstance(obj, list):
        return [nested_simplify(item, decimals) for item in obj]
    elif isinstance(obj, np.ndarray):
        return nested_simplify(obj.tolist())
    elif isinstance(obj, Mapping):
        return {nested_simplify(k, decimals): nested_simplify(v, decimals) for k, v in obj.items()}
    elif isinstance(obj, (str, int, np.int64)):
        return obj
    elif obj is None:
        return obj
    elif is_torch_available() and isinstance(obj, torch.Tensor):
        return nested_simplify(obj.tolist(), decimals)
    elif is_tf_available() and tf.is_tensor(obj):
        return nested_simplify(obj.numpy().tolist())
    elif isinstance(obj, float):
        return round(obj, decimals)
    elif isinstance(obj, (np.int32, np.float32)):
        return nested_simplify(obj.item(), decimals)
    else:
        raise Exception(f"Not supported: {type(obj)}")


def check_json_file_has_correct_format(file_path):
    with open(file_path, "r") as f:
        lines = f.readlines()
        if len(lines) == 1:
            # length can only be 1 if dict is empty
            assert lines[0] == "{}"
        else:
            # otherwise make sure json has correct format (at least 3 lines)
            assert len(lines) >= 3
            # each key one line, ident should be 2, min length is 3
            assert lines[0].strip() == "{"
            for line in lines[1:-1]:
                left_indent = len(lines[1]) - len(lines[1].lstrip())
                assert left_indent == 2
            assert lines[-1].strip() == "}"
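# For reference (illustrative keys): the accepted layout is a flat, indent-2 dump such as
#
#   {
#     "model_type": "bert",
#     "vocab_size": 30522
#   }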


def to_2tuple(x):
    if isinstance(x, collections.abc.Iterable):
        return x
    return (x, x)
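# e.g. to_2tuple(7) returns (7, 7), while an iterable such as (3, 5) or [3, 5] is passed through unchanged.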


# These utils relate to ensuring the right error message is received when running scripts
class SubprocessCallException(Exception):
    pass


def run_command(command: List[str], return_stdout=False):
    """
    Runs `command` with `subprocess.check_output` and will potentially return the `stdout`. Will also properly capture
    if an error occurred while running `command`
    """
    try:
        output = subprocess.check_output(command, stderr=subprocess.STDOUT)
        if return_stdout:
            if hasattr(output, "decode"):
                output = output.decode("utf-8")
            return output
    except subprocess.CalledProcessError as e:
        raise SubprocessCallException(
            f"Command `{' '.join(command)}` failed with the following error:\n\n{e.output.decode()}"
        ) from e
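# Illustrative call (hypothetical command): a failing command raises SubprocessCallException with the
# subprocess output embedded, while a successful one can return its stdout:
#
#   run_command(["python", "-c", "print('ok')"], return_stdout=True)  # -> "ok\n"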