Unverified Commit 42b8940b authored by Michael Benayoun, committed by GitHub

[FX] _generate_dummy_input supports audio-classification models for labels (#18580)

* Support audio-classification architectures for labels generation, and provide a flag to control whether warnings are printed

* Use ENV_VARS_TRUE_VALUES
parent d53dffec
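In practice, the new label support can be exercised by symbolically tracing an audio-classification architecture with "labels" among the requested inputs. The snippet below is a minimal sketch, not part of the commit; it assumes Wav2Vec2ForSequenceClassification (an audio-classification head) is among the FX-traceable architectures and uses a freshly initialized config rather than a pretrained checkpoint.

```python
# Minimal sketch (assumptions noted above): trace an audio-classification model
# while asking the tracer to generate a dummy "labels" tensor.
from transformers import Wav2Vec2Config, Wav2Vec2ForSequenceClassification
from transformers.utils.fx import symbolic_trace

config = Wav2Vec2Config(num_labels=4)  # hypothetical label count
model = Wav2Vec2ForSequenceClassification(config)

# With this change, the tracer knows how to build "labels" for
# audio-classification heads: one class index per example in the batch.
traced = symbolic_trace(model, input_names=["input_values", "labels"])
print(traced.graph)
```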
@@ -19,6 +19,7 @@ import functools
 import inspect
 import math
 import operator
+import os
 import random
 import warnings
 from typing import Any, Callable, Dict, List, Optional, Type, Union
@@ -48,11 +49,12 @@ from ..models.auto.modeling_auto import (
     MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES,
     MODEL_MAPPING_NAMES,
 )
-from ..utils import TORCH_FX_REQUIRED_VERSION, is_torch_fx_available
+from ..utils import ENV_VARS_TRUE_VALUES, TORCH_FX_REQUIRED_VERSION, is_torch_fx_available
 from ..utils.versions import importlib_metadata

 logger = logging.get_logger(__name__)

+_IS_IN_DEBUG_MODE = os.environ.get("FX_DEBUG_MODE", "").upper() in ENV_VARS_TRUE_VALUES

 def _generate_supported_model_class_names(
@@ -678,7 +680,12 @@ class HFTracer(Tracer):
         if input_name in ["labels", "start_positions", "end_positions"]:
             batch_size = shape[0]
-            if model_class_name in get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES):
+            if model_class_name in [
+                *get_values(MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES),
+                *get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES),
+                *get_values(MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES),
+                *get_values(MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES),
+            ]:
                 inputs_dict["labels"] = torch.zeros(batch_size, dtype=torch.long, device=device)
             elif model_class_name in [
                 *get_values(MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES),
@@ -710,11 +717,6 @@ class HFTracer(Tracer):
                     )
                 inputs_dict["labels"] = torch.zeros(*labels_shape, dtype=labels_dtype, device=device)
-            elif model_class_name in [
-                *get_values(MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING_NAMES),
-                *get_values(MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES),
-            ]:
-                inputs_dict["labels"] = torch.zeros(batch_size, dtype=torch.long, device=device)
             elif model_class_name in [
                 *get_values(MODEL_FOR_PRETRAINING_MAPPING_NAMES),
                 *get_values(MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES),
@@ -725,7 +727,9 @@ class HFTracer(Tracer):
             ]:
                 inputs_dict["labels"] = torch.zeros(shape, dtype=torch.long, device=device)
             else:
-                raise NotImplementedError(f"{model_class_name} not supported yet.")
+                raise NotImplementedError(
+                    f"Generating the dummy input named {input_name} for {model_class_name} is not supported yet."
+                )
         elif "pixel_values" in input_name:
             batch_size = shape[0]
             image_size = getattr(model.config, "image_size", None)
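For reference, the dummy-label shapes handled by the branches above, shown as a standalone sketch (the batch size and sequence length are assumed values, not taken from the commit):

```python
# Standalone sketch of the dummy-label shapes the tracer generates.
import torch

batch_size, sequence_length = 2, 128  # assumed values

# Sequence-level heads (audio/image classification, multiple choice,
# next sentence prediction): one integer class index per example.
sequence_level_labels = torch.zeros(batch_size, dtype=torch.long)

# Token-level heads (token classification, pretraining, ...):
# one label per input position.
token_level_labels = torch.zeros((batch_size, sequence_length), dtype=torch.long)

print(sequence_level_labels.shape)  # torch.Size([2])
print(token_level_labels.shape)     # torch.Size([2, 128])
```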
@@ -846,6 +850,7 @@ class HFTracer(Tracer):
                 raise ValueError("Don't support composite output yet")
             rv.install_metadata(meta_out)
         except Exception as e:
-            warnings.warn(f"Could not compute metadata for {kind} target {target}: {e}")
+            if _IS_IN_DEBUG_MODE:
+                warnings.warn(f"Could not compute metadata for {kind} target {target}: {e}")
         return rv
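The metadata warnings silenced by this last hunk can be re-enabled by setting the FX_DEBUG_MODE environment variable to a truthy value before transformers.utils.fx is imported, since _IS_IN_DEBUG_MODE is evaluated at module import time. A minimal sketch:

```python
# Sketch: opt back in to the "Could not compute metadata" warnings.
import os

# Must be set before transformers.utils.fx is imported, because
# _IS_IN_DEBUG_MODE is read from the environment at module import time.
os.environ["FX_DEBUG_MODE"] = "1"

from transformers.utils.fx import symbolic_trace  # noqa: E402

# ... trace models as usual; metadata computation failures will now warn.
```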