Unverified Commit 42b8940b authored by Michael Benayoun's avatar Michael Benayoun Committed by GitHub
Browse files

[FX] _generate_dummy_input supports audio-classification models for labels (#18580)

* Support audio classification architectures for labels generation, as well as provides a flag to print warnings or not

* Use ENV_VARS_TRUE_VALUES
parent d53dffec
...@@ -19,6 +19,7 @@ import functools ...@@ -19,6 +19,7 @@ import functools
import inspect import inspect
import math import math
import operator import operator
import os
import random import random
import warnings import warnings
from typing import Any, Callable, Dict, List, Optional, Type, Union from typing import Any, Callable, Dict, List, Optional, Type, Union
...@@ -48,11 +49,12 @@ from ..models.auto.modeling_auto import ( ...@@ -48,11 +49,12 @@ from ..models.auto.modeling_auto import (
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES, MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES,
MODEL_MAPPING_NAMES, MODEL_MAPPING_NAMES,
) )
from ..utils import TORCH_FX_REQUIRED_VERSION, is_torch_fx_available from ..utils import ENV_VARS_TRUE_VALUES, TORCH_FX_REQUIRED_VERSION, is_torch_fx_available
from ..utils.versions import importlib_metadata from ..utils.versions import importlib_metadata
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
_IS_IN_DEBUG_MODE = os.environ.get("FX_DEBUG_MODE", "").upper() in ENV_VARS_TRUE_VALUES
def _generate_supported_model_class_names( def _generate_supported_model_class_names(
...@@ -678,7 +680,12 @@ class HFTracer(Tracer): ...@@ -678,7 +680,12 @@ class HFTracer(Tracer):
if input_name in ["labels", "start_positions", "end_positions"]: if input_name in ["labels", "start_positions", "end_positions"]:
batch_size = shape[0] batch_size = shape[0]
if model_class_name in get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES): if model_class_name in [
*get_values(MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES),
*get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES),
*get_values(MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES),
*get_values(MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES),
]:
inputs_dict["labels"] = torch.zeros(batch_size, dtype=torch.long, device=device) inputs_dict["labels"] = torch.zeros(batch_size, dtype=torch.long, device=device)
elif model_class_name in [ elif model_class_name in [
*get_values(MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES), *get_values(MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES),
...@@ -710,11 +717,6 @@ class HFTracer(Tracer): ...@@ -710,11 +717,6 @@ class HFTracer(Tracer):
) )
inputs_dict["labels"] = torch.zeros(*labels_shape, dtype=labels_dtype, device=device) inputs_dict["labels"] = torch.zeros(*labels_shape, dtype=labels_dtype, device=device)
elif model_class_name in [
*get_values(MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES),
*get_values(MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES),
]:
inputs_dict["labels"] = torch.zeros(batch_size, dtype=torch.long, device=device)
elif model_class_name in [ elif model_class_name in [
*get_values(MODEL_FOR_PRETRAINING_MAPPING_NAMES), *get_values(MODEL_FOR_PRETRAINING_MAPPING_NAMES),
*get_values(MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES), *get_values(MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES),
...@@ -725,7 +727,9 @@ class HFTracer(Tracer): ...@@ -725,7 +727,9 @@ class HFTracer(Tracer):
]: ]:
inputs_dict["labels"] = torch.zeros(shape, dtype=torch.long, device=device) inputs_dict["labels"] = torch.zeros(shape, dtype=torch.long, device=device)
else: else:
raise NotImplementedError(f"{model_class_name} not supported yet.") raise NotImplementedError(
f"Generating the dummy input named {input_name} for {model_class_name} is not supported yet."
)
elif "pixel_values" in input_name: elif "pixel_values" in input_name:
batch_size = shape[0] batch_size = shape[0]
image_size = getattr(model.config, "image_size", None) image_size = getattr(model.config, "image_size", None)
...@@ -846,7 +850,8 @@ class HFTracer(Tracer): ...@@ -846,7 +850,8 @@ class HFTracer(Tracer):
raise ValueError("Don't support composite output yet") raise ValueError("Don't support composite output yet")
rv.install_metadata(meta_out) rv.install_metadata(meta_out)
except Exception as e: except Exception as e:
warnings.warn(f"Could not compute metadata for {kind} target {target}: {e}") if _IS_IN_DEBUG_MODE:
warnings.warn(f"Could not compute metadata for {kind} target {target}: {e}")
return rv return rv
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment