Unverified Commit c796b6de authored by Mohit Sharma's avatar Mohit Sharma Committed by GitHub
Browse files

Added onnx config whisper (#19525)

* Added onnx config whisper

* added whisper support onnx

* add audio input data

* added whisper support onnx

* fixed the seqlength value

* Updated the whisper onnx ocnfig

* restore files to old version

* removed attention mask from inputs

* Updated get_dummy_input_onnxruntime docstring

* Updated relative imports and token generation

* update docstring
parent c3a93d8d
......@@ -100,6 +100,7 @@ Ready-made configurations include the following architectures:
- Table Transformer
- Vision Encoder decoder
- ViT
- Whisper
- XLM
- XLM-RoBERTa
- XLM-RoBERTa-XL
......
......@@ -21,7 +21,7 @@ from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_tf_availabl
_import_structure = {
"configuration_whisper": ["WHISPER_PRETRAINED_CONFIG_ARCHIVE_MAP", "WhisperConfig"],
"configuration_whisper": ["WHISPER_PRETRAINED_CONFIG_ARCHIVE_MAP", "WhisperConfig", "WhisperOnnxConfig"],
"feature_extraction_whisper": ["WhisperFeatureExtractor"],
"processing_whisper": ["WhisperProcessor"],
"tokenization_whisper": ["WhisperTokenizer"],
......@@ -55,7 +55,7 @@ else:
]
if TYPE_CHECKING:
from .configuration_whisper import WHISPER_PRETRAINED_CONFIG_ARCHIVE_MAP, WhisperConfig
from .configuration_whisper import WHISPER_PRETRAINED_CONFIG_ARCHIVE_MAP, WhisperConfig, WhisperOnnxConfig
from .feature_extraction_whisper import WhisperFeatureExtractor
from .processing_whisper import WhisperProcessor
from .tokenization_whisper import WhisperTokenizer
......
......@@ -14,10 +14,19 @@
# limitations under the License.
""" Whisper model configuration"""
from collections import OrderedDict
from typing import TYPE_CHECKING, Any, Mapping, Optional, Union
from ...configuration_utils import PretrainedConfig
from ...onnx import OnnxConfig, OnnxSeq2SeqConfigWithPast
from ...utils import logging
if TYPE_CHECKING:
from ...feature_extraction_utils import FeatureExtractionMixin
from ...tokenization_utils_base import PreTrainedTokenizerBase
from ...utils import TensorType
logger = logging.get_logger(__name__)
WHISPER_PRETRAINED_CONFIG_ARCHIVE_MAP = {
......@@ -214,3 +223,59 @@ class WhisperConfig(PretrainedConfig):
begin_suppress_tokens=begin_suppress_tokens,
**kwargs,
)
class WhisperOnnxConfig(OnnxSeq2SeqConfigWithPast):
@property
def inputs(self) -> Mapping[str, Mapping[int, str]]:
common_inputs = OrderedDict(
[
("input_features", {0: "batch", 1: "feature_size", 2: "encoder_sequence"}),
]
)
if self.use_past:
common_inputs["decoder_input_ids"] = {0: "batch"}
else:
common_inputs["decoder_input_ids"] = {0: "batch", 1: "decoder_sequence"}
if self.use_past:
self.fill_with_past_key_values_(common_inputs, direction="inputs")
return common_inputs
def generate_dummy_inputs(
self,
preprocessor: Union["PreTrainedTokenizerBase", "FeatureExtractionMixin"],
batch_size: int = -1,
seq_length: int = -1,
is_pair: bool = False,
framework: Optional["TensorType"] = None,
sampling_rate: int = 22050,
time_duration: float = 5.0,
frequency: int = 220,
) -> Mapping[str, Any]:
dummy_inputs = OrderedDict()
encoder_inputs = OnnxConfig.generate_dummy_inputs(
self,
preprocessor=preprocessor.feature_extractor,
batch_size=batch_size,
framework=framework,
sampling_rate=sampling_rate,
time_duration=time_duration,
frequency=frequency,
)
decoder_inputs = super().generate_dummy_inputs(
preprocessor.tokenizer, batch_size, seq_length, is_pair, framework
)
dummy_inputs["input_features"] = encoder_inputs.pop("input_features")
dummy_inputs["decoder_input_ids"] = decoder_inputs.pop("decoder_input_ids")
if "past_key_values" in decoder_inputs:
dummy_inputs["past_key_values"] = decoder_inputs.pop("past_key_values")
return dummy_inputs
@property
def atol_for_validation(self) -> float:
return 1e-3
......@@ -104,6 +104,7 @@ class OnnxConfig(ABC):
"sequence-classification": OrderedDict({"logits": {0: "batch"}}),
"token-classification": OrderedDict({"logits": {0: "batch", 1: "sequence"}}),
"vision2seq-lm": OrderedDict({"logits": {0: "batch", 1: "sequence"}}),
"speech2seq-lm": OrderedDict({"logits": {0: "batch", 1: "sequence"}}),
}
def __init__(self, config: "PretrainedConfig", task: str = "default", patching_specs: List[PatchingSpec] = None):
......@@ -262,6 +263,19 @@ class OnnxConfig(ABC):
images.append(Image.fromarray(data.astype("uint8")).convert("RGB"))
return images
def _generate_dummy_audio(
self, batch_size: int = 2, sampling_rate: int = 22050, time_duration: float = 5.0, frequency: int = 220
):
audio_data = []
for _ in range(batch_size):
# time variable
t = np.linspace(0, time_duration, int(time_duration * sampling_rate), endpoint=False)
# generate pure sine wave at `frequency` Hz
audio_data.append(0.5 * np.sin(2 * np.pi * frequency * t))
return audio_data
def generate_dummy_inputs(
self,
preprocessor: Union["PreTrainedTokenizerBase", "FeatureExtractionMixin"],
......@@ -273,6 +287,9 @@ class OnnxConfig(ABC):
num_channels: int = 3,
image_width: int = 40,
image_height: int = 40,
sampling_rate: int = 22050,
time_duration: float = 5.0,
frequency: int = 220,
tokenizer: "PreTrainedTokenizerBase" = None,
) -> Mapping[str, Any]:
"""
......@@ -297,6 +314,12 @@ class OnnxConfig(ABC):
The width of the generated images.
image_height (`int`, *optional*, defaults to 40):
The height of the generated images.
sampling_rate (`int`, *optional* defaults to 22050)
The sampling rate for audio data generation.
time_duration (`float`, *optional* defaults to 5.0)
Total seconds of sampling for audio data generation.
frequency (`int`, *optional* defaults to 220)
The desired natural frequency of generated audio.
Returns:
Mapping[str, Tensor] holding the kwargs to provide to the model's forward function
......@@ -325,7 +348,12 @@ class OnnxConfig(ABC):
seq_length, fixed_dimension=OnnxConfig.default_fixed_sequence, num_token_to_add=token_to_add
)
# Generate dummy inputs according to compute batch and sequence
dummy_input = [" ".join([preprocessor.unk_token]) * seq_length] * batch_size
input_token = (
preprocessor.unk_token
if (preprocessor.unk_token is not None and len(preprocessor.unk_token) > 0)
else "0"
)
dummy_input = [" ".join([input_token]) * seq_length] * batch_size
if self.task == "multiple-choice":
# If dynamic axis (-1) we forward with a fixed dimension of 4 candidate answers to avoid optimizations
# made by ONNX
......@@ -345,11 +373,32 @@ class OnnxConfig(ABC):
batch_size = compute_effective_axis_dimension(batch_size, fixed_dimension=OnnxConfig.default_fixed_batch)
dummy_input = self._generate_dummy_images(batch_size, num_channels, image_height, image_width)
return dict(preprocessor(images=dummy_input, return_tensors=framework))
elif (
isinstance(preprocessor, FeatureExtractionMixin) and preprocessor.model_input_names[0] == "input_features"
):
# If dynamic axis (-1) we forward with a fixed dimension of 2 samples to avoid optimizations made by ONNX
batch_size = compute_effective_axis_dimension(batch_size, fixed_dimension=OnnxConfig.default_fixed_batch)
dummy_input = self._generate_dummy_audio(batch_size, sampling_rate, time_duration, frequency)
return dict(preprocessor(dummy_input, return_tensors=framework))
else:
raise ValueError(
"Unable to generate dummy inputs for the model. Please provide a tokenizer or a preprocessor."
)
def generate_dummy_inputs_onnxruntime(self, reference_model_inputs: Mapping[str, Any]) -> Mapping[str, Any]:
"""
Generate inputs for ONNX Runtime using the reference model inputs. Override this to run inference with seq2seq
models which have the encoder and decoder exported as separate ONNX files.
Args:
reference_model_inputs ([`Mapping[str, Tensor]`):
Reference inputs for the model.
Returns:
`Mapping[str, Tensor]`: The mapping holding the kwargs to provide to the model's forward function
"""
return reference_model_inputs
def patch_ops(self):
for spec in self._patching_specs:
custom_op = spec.custom_op if spec.op_wrapper is None else spec.op_wrapper(spec.custom_op)
......
......@@ -145,7 +145,21 @@ def export_pytorch(
device = torch.device(device)
if device.type == "cuda" and torch.cuda.is_available():
model.to(device)
model_inputs = dict((k, v.to(device)) for k, v in model_inputs.items())
model_inputs_device = dict()
for k, v in model_inputs.items():
if isinstance(v, Tuple):
model_inputs_device[k] = tuple(
x.to(device) if isinstance(x, torch.Tensor) else None for x in v
)
elif isinstance(v, List):
model_inputs_device[k] = [
tuple(x.to(device) if isinstance(x, torch.Tensor) else None for x in t) for t in v
]
else:
model_inputs_device[k] = v.to(device)
model_inputs = model_inputs_device
inputs_match, matched_inputs = ensure_model_and_config_inputs_match(model, model_inputs.keys())
onnx_outputs = list(config.outputs.keys())
......@@ -404,9 +418,12 @@ def validate_model_outputs(
else:
ref_outputs_dict[name] = value
# Create onnxruntime inputs from the reference model inputs
reference_model_inputs_onnxruntime = config.generate_dummy_inputs_onnxruntime(reference_model_inputs)
# We flatten potential collection of inputs (i.e. past_keys)
onnx_inputs = {}
for name, value in reference_model_inputs.items():
for name, value in reference_model_inputs_onnxruntime.items():
if isinstance(value, (list, tuple)):
value = config.flatten_output_collection_property(name, value)
onnx_inputs.update({tensor_name: pt_tensor.numpy() for tensor_name, pt_tensor in value.items()})
......
......@@ -29,6 +29,7 @@ if is_torch_available():
AutoModelForSemanticSegmentation,
AutoModelForSeq2SeqLM,
AutoModelForSequenceClassification,
AutoModelForSpeechSeq2Seq,
AutoModelForTokenClassification,
AutoModelForVision2Seq,
)
......@@ -100,6 +101,7 @@ class FeaturesManager:
"masked-im": AutoModelForMaskedImageModeling,
"semantic-segmentation": AutoModelForSemanticSegmentation,
"vision2seq-lm": AutoModelForVision2Seq,
"speech2seq-lm": AutoModelForSpeechSeq2Seq,
}
if is_tf_available():
_TASKS_TO_TF_AUTOMODELS = {
......@@ -492,6 +494,13 @@ class FeaturesManager:
"vit": supported_features_mapping(
"default", "image-classification", "masked-im", onnx_config_cls="models.vit.ViTOnnxConfig"
),
"whisper": supported_features_mapping(
"default",
"default-with-past",
"speech2seq-lm",
"speech2seq-lm-with-past",
onnx_config_cls="models.whisper.WhisperOnnxConfig",
),
"xlm": supported_features_mapping(
"default",
"masked-lm",
......
......@@ -218,6 +218,7 @@ PYTORCH_EXPORT_MODELS = {
("yolos", "hustvl/yolos-tiny"),
("segformer", "nvidia/segformer-b0-finetuned-ade-512-512"),
("swin", "microsoft/swin-tiny-patch4-window7-224"),
("whisper", "openai/whisper-tiny.en"),
}
PYTORCH_EXPORT_ENCODER_DECODER_MODELS = {
......@@ -398,7 +399,7 @@ class OnnxExportTestCaseV2(TestCase):
preprocessor = AutoTokenizer.from_pretrained(model_name)
with NamedTemporaryFile("w") as decoder_output:
onnx_inputs, onnx_outputs = export(
_, onnx_outputs = export(
preprocessor,
decoder_model,
decoder_onnx_config,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment