Unverified commit 6c26faa1, authored by Angela Yi and committed by GitHub
Browse files

Skip warning if tracing with dynamo (#25581)

* Ignore warning if tracing with dynamo

* fix import error

* separate to function

* add test
parent 18ee1fe7
......@@ -81,7 +81,12 @@ from .utils import (
strtobool,
)
from .utils.hub import convert_file_size_to_int, get_checkpoint_shard_files
from .utils.import_utils import ENV_VARS_TRUE_VALUES, is_sagemaker_mp_enabled, is_torch_fx_proxy
from .utils.import_utils import (
ENV_VARS_TRUE_VALUES,
is_sagemaker_mp_enabled,
is_torch_fx_proxy,
is_torchdynamo_compiling,
)
from .utils.quantization_config import BitsAndBytesConfig, GPTQConfig, QuantizationMethod
from .utils.versions import require_version_core
......@@ -3799,7 +3804,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
"""
# Skip the check during tracing.
if is_torch_fx_proxy(input_ids) or torch.jit.is_tracing():
if is_torch_fx_proxy(input_ids) or torch.jit.is_tracing() or is_torchdynamo_compiling():
return
if (attention_mask is not None) or (self.config.pad_token_id is None):
......
......@@ -463,6 +463,17 @@ def is_torch_compile_available():
return hasattr(torch, "compile")
def is_torchdynamo_compiling():
    """
    Report whether the current code is being traced/compiled by TorchDynamo.

    Returns:
        `bool`: `False` when `is_torch_available()` is false, or when neither compile-status query can be
        executed (missing attribute, failing import, ...). Otherwise the value reported by torch.
    """
    if not is_torch_available():
        return False

    # Prefer the public `torch.compiler` entry point: when it exists there is no need to import the private
    # `torch._dynamo` package at all.
    try:
        import torch

        return torch.compiler.is_compiling()
    except Exception:
        # Best effort by design: torch builds without `torch.compiler.is_compiling` end up here and are
        # handled by the fallback below.
        pass

    # Fallback for torch builds that only expose the query through `torch._dynamo`.
    try:
        import torch._dynamo as dynamo

        return dynamo.is_compiling()
    except Exception:
        # This helper only gates optional behaviour, so any failure is treated as "not compiling".
        return False
def is_torch_tensorrt_fx_available():
if importlib.util.find_spec("torch_tensorrt") is None:
return False
......
......@@ -55,6 +55,7 @@ from transformers.utils import (
WEIGHTS_INDEX_NAME,
WEIGHTS_NAME,
)
from transformers.utils.import_utils import is_torchdynamo_available
sys.path.append(str(Path(__file__).parent.parent / "utils"))
......@@ -1014,6 +1015,25 @@ class ModelUtilsTest(TestCasePlus):
model.warn_if_padding_and_no_attention_mask(input_ids, attention_mask=None)
self.assertIn("You may ignore this warning if your `pad_token_id`", cl.out)
if not is_torchdynamo_available():
return
with self.subTest("Ensure that the warning code is skipped when compiling with torchdynamo."):
logger.warning_once.cache_clear()
from torch._dynamo import config, testing
config = PretrainedConfig()
config.pad_token_id = 0
model = ModelWithHead(config)
input_ids = torch.tensor([[0, 345, 232, 328, 740, 140, 1695, 69, 6078, 432, 5232]])
def f(input_ids):
model.warn_if_padding_and_no_attention_mask(input_ids, attention_mask=None)
compile_counter = testing.CompileCounter()
opt_fn = torch.compile(f, dynamic=True, backend=compile_counter)
opt_fn(input_ids)
self.assertEqual(compile_counter.frame_count, 0)
@require_torch_gpu
@slow
def test_pretrained_low_mem_new_config(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment