Unverified Commit ad00c482 authored by Younes Belkada, committed by GitHub

FIX [`Gemma` / `CI`] Make sure our runners have access to the model (#29242)



* put hf token in gemma tests

* update suggestion

* add to flax

* revert

* fix

* fixup

* forward contrib credits from discussion

---------
Co-authored-by: ArthurZucker <ArthurZucker@users.noreply.github.com>
parent bd5b9863
src/transformers/testing_utils.py
@@ -31,12 +31,14 @@ import time
import unittest
from collections import defaultdict
from collections.abc import Mapping
from functools import wraps
from io import StringIO
from pathlib import Path
from typing import Callable, Dict, Iterable, Iterator, List, Optional, Union
from unittest import mock
from unittest.mock import patch
import huggingface_hub
import urllib3
from transformers import logging as transformers_logging
@@ -460,6 +462,20 @@ def require_torch_sdpa(test_case):
    return unittest.skipUnless(is_torch_sdpa_available(), "test requires PyTorch SDPA")(test_case)
def require_read_token(fn):
    """
    A decorator that loads the HF token for tests that require loading gated models.
    """
    token = os.getenv("HF_HUB_READ_TOKEN", None)

    @wraps(fn)
    def _inner(*args, **kwargs):
        # Patch huggingface_hub's token lookup so that gated downloads authenticate
        # with the CI read token for the duration of the decorated test.
        with patch("huggingface_hub.utils._headers.get_token", return_value=token):
            return fn(*args, **kwargs)

    return _inner
def require_peft(test_case):
"""
Decorator marking a test that requires PEFT.
......
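For context, a minimal, self-contained sketch (not part of this commit) of how the new require_read_token decorator above behaves: the token is read from HF_HUB_READ_TOKEN once, at decoration time, and huggingface_hub's token lookup is only patched while the decorated callable runs. The probe function and the dummy token value are illustrative, and the sketch assumes a transformers version that already ships this decorator and a huggingface_hub version that still exposes the internal utils._headers.get_token path the decorator patches.

import os

# The token has to be in the environment before the decorator is applied,
# because require_read_token reads HF_HUB_READ_TOKEN once at decoration time.
os.environ["HF_HUB_READ_TOKEN"] = "hf_dummy_read_token"  # illustrative value

from transformers.testing_utils import require_read_token  # noqa: E402


@require_read_token
def _probe_token():
    # While the decorated callable runs, the patched get_token() returns the
    # captured token; outside of it, huggingface_hub behaves as usual.
    from huggingface_hub.utils import _headers

    return _headers.get_token()


assert _probe_token() == "hf_dummy_read_token"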
tests/models/gemma/test_modeling_flax_gemma.py
@@ -11,14 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from transformers import AutoTokenizer, GemmaConfig, is_flax_available
from transformers.testing_utils import require_flax, slow
from transformers.testing_utils import require_flax, require_read_token, slow
from ...generation.test_flax_utils import FlaxGenerationTesterMixin
from ...test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor
@@ -205,6 +203,7 @@ class FlaxGemmaModelTest(FlaxModelTesterMixin, FlaxGenerationTesterMixin, unittest.TestCase):
@slow
@require_flax
@require_read_token
class FlaxGemmaIntegrationTest(unittest.TestCase):
input_text = ["The capital of France is", "To play the perfect cover drive"]
model_id = "google/gemma-2b"
......
tests/models/gemma/test_modeling_gemma.py
@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
""" Testing suite for the PyTorch Gemma model. """
import tempfile
import unittest
@@ -24,6 +23,7 @@ from transformers import AutoModelForCausalLM, AutoTokenizer, GemmaConfig, is_torch_available
from transformers.testing_utils import (
    require_bitsandbytes,
    require_flash_attn,
    require_read_token,
    require_torch,
    require_torch_gpu,
    require_torch_sdpa,
@@ -529,6 +529,7 @@ class GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
@require_torch_gpu
@slow
@require_read_token
class GemmaIntegrationTest(unittest.TestCase):
input_text = ["Hello I am doing", "Hi today"]
......
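As a usage note beyond this diff, here is a hedged sketch of how another gated-model test could consume the decorator; the class and test names are hypothetical, google/gemma-2b is the gated checkpoint already used in the suites above, and the decorator is applied to a single test method here. The runner must export HF_HUB_READ_TOKEN (for example from a CI secret) and set RUN_SLOW=1 before pytest starts.

import unittest

from transformers import AutoTokenizer
from transformers.testing_utils import require_read_token, slow


class GatedCheckpointSmokeTest(unittest.TestCase):
    # Hypothetical test, not part of this commit: it only passes on runners where
    # HF_HUB_READ_TOKEN grants read access to the gated google/gemma-2b repository.
    @slow
    @require_read_token
    def test_tokenizer_loads(self):
        tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
        self.assertIsNotNone(tokenizer.eos_token)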