"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "a725db4f6cb0c73ea526f73cb0f9767eda671726"
Unverified Commit 6c26faa1 authored by Angela Yi, committed by GitHub

Skip warning if tracing with dynamo (#25581)

* Ignore warning if tracing with dynamo

* fix import error

* separate to function

* add test
parent 18ee1fe7
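
For orientation before the diff: the change guards an eager-mode warning behind a "compiling with torchdynamo" check, so tracing via torch.jit, FX, or dynamo never executes the warning's Python control flow. A rough standalone sketch of that pattern (helper and function names here are illustrative, not the library's exact code):

# Sketch of the guard pattern this commit adds (hypothetical names, assumes torch is installed).
import torch

def _compiling_with_dynamo():
    # Mirrors the idea of the new helper: report True only when torch._dynamo
    # is importable and currently tracing; anything else counts as "not compiling".
    try:
        import torch._dynamo as dynamo
        return dynamo.is_compiling()
    except Exception:
        return False

def warn_if_no_attention_mask(input_ids, attention_mask, pad_token_id=0):
    # Warnings involve data-dependent Python branching, so skip them entirely
    # while torch.jit or torchdynamo is tracing.
    if torch.jit.is_tracing() or _compiling_with_dynamo():
        return
    if attention_mask is None and (input_ids == pad_token_id).any():
        print("warning: padding tokens present but no attention_mask was passed")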
@@ -81,7 +81,12 @@ from .utils import (
     strtobool,
 )
 from .utils.hub import convert_file_size_to_int, get_checkpoint_shard_files
-from .utils.import_utils import ENV_VARS_TRUE_VALUES, is_sagemaker_mp_enabled, is_torch_fx_proxy
+from .utils.import_utils import (
+    ENV_VARS_TRUE_VALUES,
+    is_sagemaker_mp_enabled,
+    is_torch_fx_proxy,
+    is_torchdynamo_compiling,
+)
 from .utils.quantization_config import BitsAndBytesConfig, GPTQConfig, QuantizationMethod
 from .utils.versions import require_version_core
@@ -3799,7 +3804,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
         """
         # Skip the check during tracing.
-        if is_torch_fx_proxy(input_ids) or torch.jit.is_tracing():
+        if is_torch_fx_proxy(input_ids) or torch.jit.is_tracing() or is_torchdynamo_compiling():
             return
         if (attention_mask is not None) or (self.config.pad_token_id is None):
...
@@ -463,6 +463,17 @@ def is_torch_compile_available():
     return hasattr(torch, "compile")
 
 
+def is_torchdynamo_compiling():
+    if not is_torch_available():
+        return False
+    try:
+        import torch._dynamo as dynamo  # noqa: F401
+
+        return dynamo.is_compiling()
+    except Exception:
+        return False
+
+
 def is_torch_tensorrt_fx_available():
     if importlib.util.find_spec("torch_tensorrt") is None:
         return False
...
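
The helper above leans on the fact that torch._dynamo.is_compiling() is treated as True while dynamo traces a function and returns False in plain eager execution. A small runnable check of that behavior (assumes torch >= 2.0 with dynamo available; the "eager" backend is used only to keep the example cheap):

import torch
import torch._dynamo as dynamo

def branch_on_compiling(x):
    if dynamo.is_compiling():
        return x + 1   # branch taken only inside the traced/compiled version
    return x - 1       # branch taken in plain eager execution

x = torch.zeros(1)
print(branch_on_compiling(x))                                  # tensor([-1.])
print(torch.compile(branch_on_compiling, backend="eager")(x))  # tensor([1.])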
@@ -55,6 +55,7 @@ from transformers.utils import (
     WEIGHTS_INDEX_NAME,
     WEIGHTS_NAME,
 )
+from transformers.utils.import_utils import is_torchdynamo_available
 
 sys.path.append(str(Path(__file__).parent.parent / "utils"))
@@ -1014,6 +1015,25 @@ class ModelUtilsTest(TestCasePlus):
                 model.warn_if_padding_and_no_attention_mask(input_ids, attention_mask=None)
             self.assertIn("You may ignore this warning if your `pad_token_id`", cl.out)
 
+        if not is_torchdynamo_available():
+            return
+        with self.subTest("Ensure that the warning code is skipped when compiling with torchdynamo."):
+            logger.warning_once.cache_clear()
+            from torch._dynamo import config, testing
+
+            config = PretrainedConfig()
+            config.pad_token_id = 0
+            model = ModelWithHead(config)
+            input_ids = torch.tensor([[0, 345, 232, 328, 740, 140, 1695, 69, 6078, 432, 5232]])
+
+            def f(input_ids):
+                model.warn_if_padding_and_no_attention_mask(input_ids, attention_mask=None)
+
+            compile_counter = testing.CompileCounter()
+            opt_fn = torch.compile(f, dynamic=True, backend=compile_counter)
+            opt_fn(input_ids)
+            self.assertEqual(compile_counter.frame_count, 0)
+
     @require_torch_gpu
     @slow
     def test_pretrained_low_mem_new_config(self):
...
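
The test uses torch._dynamo.testing.CompileCounter, a callable backend that records how many frames dynamo captured; because the guarded warning helper returns early under compilation, there are no tensor operations to trace and the test expects frame_count to stay at 0. A minimal, separate usage sketch (assumes torch >= 2.0):

import torch
from torch._dynamo import testing

def add_one(x):
    return x + 1

counter = testing.CompileCounter()   # counts captured frames/ops instead of optimizing them
compiled = torch.compile(add_one, backend=counter)
compiled(torch.ones(2))
print(counter.frame_count)  # 1: one frame with tensor ops was captured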