"src/webui/git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "8531d8097f1b4cd4225c52b1aa01a3f85554bc11"
Unverified Commit 2e12b907 authored by Patrick von Platen, committed by GitHub

TF generate refactor - Greedy Search (#15562)



* TF generate start refactor

* Add tf tests for sample generate

* re-organize

* boom boom

* Apply suggestions from code review

* re-add

* add all code

* make random greedy pass

* make encoder-decoder random work

* further improvements

* delete bogus file

* make gpt2 and t5 tests work

* finish logits tests

* correct logits processors

* correct past / encoder_outputs drama

* refactor some methods

* another fix

* refactor shape_list

* fix more shape list

* import shape_list

* finish docs

* fix imports

* make style

* correct tf utils

* Fix TFRag as well

* Apply Lysandre's and Sylvain's suggestions

* Update tests/test_generation_tf_logits_process.py
Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>

* Update src/transformers/tf_utils.py
Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>

* remove cpu according to gante

* correct logit processor
Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>
parent a3dbbc34
@@ -148,6 +148,24 @@ generation.
 [[autodoc]] InfNanRemoveLogitsProcessor
     - __call__

+[[autodoc]] TFLogitsProcessor
+    - __call__
+
+[[autodoc]] TFLogitsProcessorList
+    - __call__
+
+[[autodoc]] TFMinLengthLogitsProcessor
+    - __call__
+
+[[autodoc]] TFNoBadWordsLogitsProcessor
+    - __call__
+
+[[autodoc]] TFNoRepeatNGramLogitsProcessor
+    - __call__
+
+[[autodoc]] TFRepetitionPenaltyLogitsProcessor
+    - __call__
+
 [[autodoc]] FlaxLogitsProcessor
     - __call__
...
@@ -1592,6 +1592,14 @@ if is_tf_available():
     _import_structure["activations_tf"] = []
     _import_structure["benchmark.benchmark_args_tf"] = ["TensorFlowBenchmarkArguments"]
     _import_structure["benchmark.benchmark_tf"] = ["TensorFlowBenchmark"]
+    _import_structure["generation_tf_logits_process"] = [
+        "TFLogitsProcessor",
+        "TFLogitsProcessorList",
+        "TFMinLengthLogitsProcessor",
+        "TFNoBadWordsLogitsProcessor",
+        "TFNoRepeatNGramLogitsProcessor",
+        "TFRepetitionPenaltyLogitsProcessor",
+    ]
     _import_structure["generation_tf_utils"] = ["tf_top_k_top_p_filtering"]
     _import_structure["keras_callbacks"] = ["KerasMetricCallback", "PushToHubCallback"]
     _import_structure["modeling_tf_outputs"] = []
@@ -2046,6 +2054,7 @@ if is_tf_available():
         ]
     )
     _import_structure["optimization_tf"] = ["AdamWeightDecay", "GradientAccumulator", "WarmUp", "create_optimizer"]
+    _import_structure["tf_utils"] = []
     _import_structure["trainer_tf"] = ["TFTrainer"]

 else:
@@ -3572,6 +3581,14 @@ if TYPE_CHECKING:

     # Benchmarks
     from .benchmark.benchmark_tf import TensorFlowBenchmark
+    from .generation_tf_logits_process import (
+        TFLogitsProcessor,
+        TFLogitsProcessorList,
+        TFMinLengthLogitsProcessor,
+        TFNoBadWordsLogitsProcessor,
+        TFNoRepeatNGramLogitsProcessor,
+        TFRepetitionPenaltyLogitsProcessor,
+    )
     from .generation_tf_utils import tf_top_k_top_p_filtering
     from .keras_callbacks import KerasMetricCallback, PushToHubCallback
     from .modeling_tf_layoutlm import (
...
@@ -14,7 +14,6 @@
 # limitations under the License.

 import inspect
-from abc import ABC

 import jax
 import jax.lax as lax
@@ -48,7 +47,7 @@ LOGITS_PROCESSOR_INPUTS_DOCSTRING = r"""
 """


-class FlaxLogitsProcessor(ABC):
+class FlaxLogitsProcessor:
     """Abstract base class for all logit processors that can be applied during generation."""

     @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
@@ -59,7 +58,7 @@ class FlaxLogitsProcessor(ABC):
 )


-class FlaxLogitsWarper(ABC):
+class FlaxLogitsWarper:
     """Abstract base class for all logit warpers that can be applied during generation with multinomial sampling."""

     @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
...
@@ -15,7 +15,6 @@

 import inspect
 import math
-from abc import ABC
 from typing import Callable, Iterable, List, Optional

 import numpy as np
@@ -49,7 +48,7 @@ LOGITS_PROCESSOR_INPUTS_DOCSTRING = r"""
 """


-class LogitsProcessor(ABC):
+class LogitsProcessor:
     """Abstract base class for all logit processors that can be applied during generation."""

     @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
@@ -60,7 +59,7 @@ class LogitsProcessor(ABC):
 )


-class LogitsWarper(ABC):
+class LogitsWarper:
     """Abstract base class for all logit warpers that can be applied during generation with multinomial sampling."""

     @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
...
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
from typing import List

import numpy as np
import tensorflow as tf

from .file_utils import add_start_docstrings
from .tf_utils import set_tensor_by_indices_to_value
from .utils.logging import get_logger


logger = get_logger(__name__)


TF_LOGITS_PROCESSOR_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`PreTrainedTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        scores (`tf.Tensor` of shape `(batch_size, config.vocab_size)`):
            Prediction scores of a language modeling head. These can be logits for each vocabulary when not using beam
            search or log softmax for each vocabulary token when using beam search.
        kwargs:
            Additional logits processor specific kwargs.

    Return:
        `tf.Tensor` of shape `(batch_size, config.vocab_size)`: The processed prediction scores.
"""

class TFLogitsProcessor:
    """Abstract base class for all logit processors that can be applied during generation."""

    @add_start_docstrings(TF_LOGITS_PROCESSOR_INPUTS_DOCSTRING)
    def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor) -> tf.Tensor:
        """TF method for processing logits."""
        raise NotImplementedError(
            f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
        )

class TFLogitsProcessorList(list):
    """
    This class can be used to create a list of [`TFLogitsProcessor`] to subsequently process a `scores` input tensor.
    This class inherits from list and adds a specific *__call__* method to apply each [`TFLogitsProcessor`] to the
    inputs.
    """

    @add_start_docstrings(TF_LOGITS_PROCESSOR_INPUTS_DOCSTRING)
    def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, **kwargs) -> tf.Tensor:
        for processor in self:
            function_args = inspect.signature(processor.__call__).parameters
            if len(function_args) > 2:
                if not all(arg in kwargs for arg in list(function_args.keys())[2:]):
                    raise ValueError(
                        f"Make sure that all the required parameters: {list(function_args.keys())} for "
                        f"{processor.__class__} are passed to the logits processor."
                    )
                scores = processor(input_ids, scores, **kwargs)
            else:
                scores = processor(input_ids, scores)
        return scores
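
As a quick orientation, here is a minimal sketch of applying such a list at one decoding step. It assumes eager execution (several processors below call `.numpy()` and use Python `if`s), and the token ids, vocabulary size, and processor settings are illustrative, not part of this diff:

import tensorflow as tf

from transformers import (
    TFLogitsProcessorList,
    TFMinLengthLogitsProcessor,
    TFNoRepeatNGramLogitsProcessor,
)

# chain several processors; each one rewrites `scores` in turn
processors = TFLogitsProcessorList(
    [
        TFMinLengthLogitsProcessor(min_length=10, eos_token_id=2),
        TFNoRepeatNGramLogitsProcessor(ngram_size=3),
    ]
)

input_ids = tf.constant([[5, 7, 9]])     # (batch_size, sequence_length) generated so far
scores = tf.random.normal((1, 100))      # (batch_size, vocab_size) next-token logits
scores = processors(input_ids, scores)   # every processor is applied in order
next_token = tf.argmax(scores, axis=-1)  # greedy pick over the processed scores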

class TFMinLengthLogitsProcessor(TFLogitsProcessor):
    r"""
    [`TFLogitsProcessor`] enforcing a min-length by setting EOS probability to 0.

    Args:
        min_length (`int`):
            The minimum length below which the score of `eos_token_id` is set to `-float("Inf")`.
        eos_token_id (`int`):
            The id of the *end-of-sequence* token.
    """

    def __init__(self, min_length: int, eos_token_id: int):
        if not isinstance(min_length, int) or min_length < 0:
            raise ValueError(f"`min_length` has to be a positive integer, but is {min_length}")

        if not isinstance(eos_token_id, int) or eos_token_id < 0:
            raise ValueError(f"`eos_token_id` has to be a positive integer, but is {eos_token_id}")

        self.min_length = min_length
        self.eos_token_id = eos_token_id

    def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor) -> tf.Tensor:
        # create boolean flag to decide if min length penalty should be applied
        cur_len = input_ids.shape[-1]
        apply_penalty = 1 - tf.clip_by_value(cur_len - self.min_length, 0, 1)

        # TODO(Matt) - this if statement has to be rewritten for XLA. Leaving it now though since
        # generate is not XLA-compileable anyways
        if apply_penalty:
            eos_token_id_mask = tf.broadcast_to(tf.range(scores.shape[-1]) == self.eos_token_id, scores.shape)
            scores = set_tensor_by_indices_to_value(scores, eos_token_id_mask, float("-inf"))

        return scores
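
A tiny worked example of the clip-based flag above, with illustrative ids: with `min_length=5` and only 3 generated tokens, `cur_len - self.min_length = -2` clips to 0, so `apply_penalty` is 1 and EOS is masked:

import tensorflow as tf

from transformers import TFMinLengthLogitsProcessor

processor = TFMinLengthLogitsProcessor(min_length=5, eos_token_id=0)

input_ids = tf.constant([[3, 8, 4]])      # only 3 tokens so far, below min_length
scores = processor(input_ids, tf.zeros((1, 10)))
print(scores[0, 0].numpy())               # -inf: EOS stays blocked until 5 tokens exist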

class TFRepetitionPenaltyLogitsProcessor(TFLogitsProcessor):
    r"""
    [`TFLogitsProcessor`] enforcing an exponential penalty on repeated sequences.

    Args:
        repetition_penalty (`float`):
            The parameter for repetition penalty. 1.0 means no penalty. See [this
            paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
    """

    def __init__(self, penalty: float):
        if not isinstance(penalty, float) or not (penalty > 0):
            raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}")

        self.penalty = penalty

    def _create_score_penalties(self, input_ids, logits):
        # create logit penalties for already seen input_ids
        token_penalties = np.ones(logits.shape)
        prev_input_ids = [np.unique(input_id) for input_id in input_ids.numpy()]
        for i, prev_input_id in enumerate(prev_input_ids):
            logit_penalized = logits[i].numpy()[prev_input_id]
            logit_penalties = np.zeros(logit_penalized.shape)
            # if previous logit score is < 0 then multiply repetition penalty else divide
            logit_penalties[logit_penalized < 0] = self.penalty
            logit_penalties[logit_penalized > 0] = 1 / self.penalty
            np.put(token_penalties[i], prev_input_id, logit_penalties)
        return tf.convert_to_tensor(token_penalties, dtype=tf.float32)

    def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor) -> tf.Tensor:
        score_penalties = self._create_score_penalties(input_ids, scores)

        scores = tf.math.multiply(scores, score_penalties)

        return scores
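
To make the sign-dependent penalty concrete, a small illustrative run: a penalty of 2.0 halves positive logits of already-seen tokens and doubles negative ones, pushing both away from being picked again:

import tensorflow as tf

from transformers import TFRepetitionPenaltyLogitsProcessor

processor = TFRepetitionPenaltyLogitsProcessor(penalty=2.0)

input_ids = tf.constant([[0, 1]])         # tokens 0 and 1 were already generated
scores = tf.constant([[2.0, -2.0, 2.0]])  # logits over a 3-token vocabulary
print(processor(input_ids, scores).numpy())
# [[ 1. -4.  2.]]: seen token 0 gives 2.0 / 2.0, seen token 1 gives -2.0 * 2.0,
# unseen token 2 is untouched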

class TFNoBadWordsLogitsProcessor(TFLogitsProcessor):
    """
    [`TFLogitsProcessor`] that enforces that specified sequences will never be sampled.

    Args:
        bad_words_ids (`List[List[int]]`):
            List of lists of token ids that are not allowed to be generated. In order to get the tokens of the words
            that should not appear in the generated text, use `tokenizer(bad_word, add_prefix_space=True).input_ids`.
        eos_token_id (`int`):
            The id of the *end-of-sequence* token.
    """

    def __init__(self, bad_words_ids: List[List[int]], eos_token_id: int):
        if not isinstance(bad_words_ids, List) or len(bad_words_ids) == 0:
            raise ValueError(f"`bad_words_ids` has to be a non-empty list, but is {bad_words_ids}.")
        if any(not isinstance(bad_word_ids, list) for bad_word_ids in bad_words_ids):
            raise ValueError(f"`bad_words_ids` has to be a list of lists, but is {bad_words_ids}.")
        if any(
            any((not isinstance(token_id, (int, np.integer)) or token_id < 0) for token_id in bad_word_ids)
            for bad_word_ids in bad_words_ids
        ):
            raise ValueError(
                f"Each list in `bad_words_ids` has to be a list of positive integers, but is {bad_words_ids}."
            )

        self.bad_words_ids = bad_words_ids

    def calc_banned_bad_words_ids(self, prev_input_ids):
        banned_tokens = []

        def _tokens_match(prev_tokens, tokens):
            if len(tokens) == 0:
                # if bad word tokens is just one token always ban it
                return True
            if len(tokens) > len(prev_tokens):
                # if bad word tokens are longer than prev tokens they can't be equal
                return False

            if prev_tokens[-len(tokens) :] == tokens:
                # if tokens match
                return True
            else:
                return False

        for prev_input_ids_slice in prev_input_ids:
            banned_tokens_slice = []
            for banned_token_seq in self.bad_words_ids:
                assert (
                    len(banned_token_seq) > 0
                ), f"Banned words token sequences {self.bad_words_ids} cannot have an empty list"

                if _tokens_match(prev_input_ids_slice.numpy().tolist(), banned_token_seq[:-1]) is False:
                    # if tokens do not match continue
                    continue

                banned_tokens_slice.append(banned_token_seq[-1])

            banned_tokens.append(banned_tokens_slice)

        return banned_tokens

    def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor) -> tf.Tensor:
        vocab_size = scores.shape[-1]

        # calculate a list of banned tokens according to bad words
        banned_tokens = self.calc_banned_bad_words_ids(input_ids)

        # create a boolean mask over the vocabulary from the list of banned tokens
        banned_tokens_indices_mask = []
        for banned_tokens_slice in banned_tokens:
            banned_tokens_indices_mask.append(
                [True if token in banned_tokens_slice else False for token in range(vocab_size)]
            )

        scores = set_tensor_by_indices_to_value(
            scores, tf.convert_to_tensor(banned_tokens_indices_mask, dtype=tf.bool), -float("inf")
        )

        return scores
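
A short illustrative check of the matching rule: a one-token bad word is always banned, while the last token of a longer bad word is only banned when its prefix matches the end of the generated sequence. The ids here are made up:

import tensorflow as tf

from transformers import TFNoBadWordsLogitsProcessor

processor = TFNoBadWordsLogitsProcessor(bad_words_ids=[[4], [5, 6]], eos_token_id=0)

input_ids = tf.constant([[2, 5]])  # the sequence ends in 5, the prefix of [5, 6]
scores = processor(input_ids, tf.zeros((1, 8)))
print(scores[0, 4].numpy(), scores[0, 6].numpy())  # -inf -inf: both 4 and 6 are banned here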

class TFNoRepeatNGramLogitsProcessor(TFLogitsProcessor):
    r"""
    [`TFLogitsProcessor`] that enforces no repetition of n-grams. See
    [Fairseq](https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345).

    Args:
        ngram_size (`int`):
            All ngrams of size `ngram_size` can only occur once.
    """

    def __init__(self, ngram_size: int):
        if not isinstance(ngram_size, int) or ngram_size <= 0:
            raise ValueError(f"`ngram_size` has to be a strictly positive integer, but is {ngram_size}")
        self.ngram_size = ngram_size

    def calc_banned_ngram_tokens(self, prev_input_ids, num_hypos, cur_len):
        # Copied from fairseq for no_repeat_ngram in beam_search
        if cur_len + 1 < self.ngram_size:
            # return no banned tokens if we haven't generated ngram_size tokens yet
            return [[] for _ in range(num_hypos)]
        generated_ngrams = [{} for _ in range(num_hypos)]
        for idx in range(num_hypos):
            gen_tokens = prev_input_ids[idx].numpy().tolist()
            generated_ngram = generated_ngrams[idx]
            for ngram in zip(*[gen_tokens[i:] for i in range(self.ngram_size)]):
                prev_ngram_tuple = tuple(ngram[:-1])
                generated_ngram[prev_ngram_tuple] = generated_ngram.get(prev_ngram_tuple, []) + [ngram[-1]]

        def _get_generated_ngrams(hypo_idx):
            # Before decoding the next token, prevent decoding of ngrams that have already appeared
            start_idx = cur_len + 1 - self.ngram_size
            ngram_idx = tuple(prev_input_ids[hypo_idx, start_idx:cur_len].numpy().tolist())
            return generated_ngrams[hypo_idx].get(ngram_idx, [])

        banned_tokens = [_get_generated_ngrams(hypo_idx) for hypo_idx in range(num_hypos)]

        return banned_tokens

    def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor) -> tf.Tensor:
        batch_size, vocab_size = scores.shape
        cur_len = input_ids.shape[-1]
        banned_tokens = self.calc_banned_ngram_tokens(input_ids, batch_size, cur_len)

        # create banned_tokens boolean mask
        banned_tokens_indices_mask = []
        for banned_tokens_slice in banned_tokens:
            banned_tokens_indices_mask.append(
                [True if token in banned_tokens_slice else False for token in range(vocab_size)]
            )

        scores = set_tensor_by_indices_to_value(
            scores, tf.convert_to_tensor(banned_tokens_indices_mask, dtype=tf.bool), -float("inf")
        )

        return scores
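
And an illustrative run of the n-gram ban: once the bigram (3, 4) has occurred, seeing 3 again bans 4 as the next token. Again, the ids and vocabulary size are made up:

import tensorflow as tf

from transformers import TFNoRepeatNGramLogitsProcessor

processor = TFNoRepeatNGramLogitsProcessor(ngram_size=2)

input_ids = tf.constant([[3, 4, 5, 3]])  # bigram (3, 4) already occurred; sequence ends in 3
scores = processor(input_ids, tf.zeros((1, 8)))
print(scores[0, 4].numpy())              # -inf: picking 4 would repeat the bigram (3, 4)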
This diff is collapsed.
@@ -54,6 +54,7 @@ from .file_utils import (
 )
 from .generation_tf_utils import TFGenerationMixin
 from .modeling_tf_outputs import TFSeq2SeqLMOutput
+from .tf_utils import shape_list
 from .tokenization_utils_base import BatchEncoding
 from .utils import logging

@@ -2041,29 +2042,6 @@ class TFSequenceSummary(tf.keras.layers.Layer):
         cls._auto_class = auto_class


-def shape_list(tensor: Union[tf.Tensor, np.ndarray]) -> List[int]:
-    """
-    Deal with dynamic shape in tensorflow cleanly.
-
-    Args:
-        tensor (`tf.Tensor` or `np.ndarray`): The tensor we want the shape of.
-
-    Returns:
-        `List[int]`: The shape of the tensor as a list.
-    """
-    if isinstance(tensor, np.ndarray):
-        return list(tensor.shape)
-
-    dynamic = tf.shape(tensor)
-
-    if tensor.shape == tf.TensorShape(None):
-        return dynamic
-
-    static = tensor.shape.as_list()
-
-    return [dynamic[i] if s is None else s for i, s in enumerate(static)]
-
-
 def get_initializer(initializer_range: float = 0.02) -> tf.initializers.TruncatedNormal:
     """
     Creates a `tf.initializers.TruncatedNormal` with the given range.
...
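
For reference, `shape_list` (which the files below now import from `transformers.tf_utils`) returns a mix of static Python ints and dynamic scalar tensors, which keeps graph-mode reshapes working when some dimensions are unknown. A small sketch with illustrative tensors:

import tensorflow as tf

from transformers.tf_utils import shape_list

print(shape_list(tf.zeros((2, 3))))  # [2, 3]: fully static shapes come back as plain ints

@tf.function(input_signature=[tf.TensorSpec([None, 3])])
def fn(t):
    batch, width = shape_list(t)  # batch is a dynamic tf.Tensor, width is the static int 3
    return tf.reshape(t, (batch, width))

fn(tf.zeros((4, 3)))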
@@ -51,8 +51,8 @@ from ...modeling_tf_utils import (
     get_initializer,
     input_processing,
     keras_serializable,
-    shape_list,
 )
+from ...tf_utils import shape_list
 from ...utils import logging
 from .configuration_albert import AlbertConfig
...
@@ -44,8 +44,8 @@ from ...modeling_tf_utils import (
     TFWrappedEmbeddings,
     input_processing,
     keras_serializable,
-    shape_list,
 )
+from ...tf_utils import shape_list
 from ...utils import logging
 from .configuration_bart import BartConfig
...
@@ -57,8 +57,8 @@ from ...modeling_tf_utils import (
     get_initializer,
     input_processing,
     keras_serializable,
-    shape_list,
 )
+from ...tf_utils import shape_list
 from ...utils import logging
 from .configuration_bert import BertConfig
...
@@ -46,8 +46,8 @@ from ...modeling_tf_utils import (
     TFWrappedEmbeddings,
     input_processing,
     keras_serializable,
-    shape_list,
 )
+from ...tf_utils import shape_list
 from ...utils import logging
 from .configuration_blenderbot import BlenderbotConfig
...
@@ -44,8 +44,8 @@ from ...modeling_tf_utils import (
     TFWrappedEmbeddings,
     input_processing,
     keras_serializable,
-    shape_list,
 )
+from ...tf_utils import shape_list
 from ...utils import logging
 from .configuration_blenderbot_small import BlenderbotSmallConfig
...
@@ -39,8 +39,8 @@ from ...modeling_tf_utils import (
     get_initializer,
     input_processing,
     keras_serializable,
-    shape_list,
 )
+from ...tf_utils import shape_list
 from ...utils import logging
 from .configuration_clip import CLIPConfig, CLIPTextConfig, CLIPVisionConfig
...
@@ -43,8 +43,8 @@ from ...modeling_tf_utils import (
     get_initializer,
     input_processing,
     keras_serializable,
-    shape_list,
 )
+from ...tf_utils import shape_list
 from ...utils import logging
 from .configuration_convbert import ConvBertConfig
...
@@ -30,8 +30,8 @@ from ...modeling_tf_utils import (
     get_initializer,
     input_processing,
     keras_serializable,
-    shape_list,
 )
+from ...tf_utils import shape_list
 from ...utils import logging
 from .configuration_ctrl import CTRLConfig
...
@@ -39,8 +39,8 @@ from ...modeling_tf_utils import (
     TFTokenClassificationLoss,
     get_initializer,
     input_processing,
-    shape_list,
 )
+from ...tf_utils import shape_list
 from ...utils import logging
 from .configuration_deberta import DebertaConfig
...
@@ -38,8 +38,8 @@ from ...modeling_tf_utils import (
     TFTokenClassificationLoss,
     get_initializer,
     input_processing,
-    shape_list,
 )
+from ...tf_utils import shape_list
 from ...utils import logging
 from .configuration_deberta_v2 import DebertaV2Config
...
@@ -45,8 +45,8 @@ from ...modeling_tf_utils import (
     get_initializer,
     input_processing,
     keras_serializable,
-    shape_list,
 )
+from ...tf_utils import shape_list
 from ...utils import logging
 from .configuration_distilbert import DistilBertConfig
...
@@ -50,8 +50,8 @@ from ...modeling_tf_utils import (
     get_initializer,
     input_processing,
     keras_serializable,
-    shape_list,
 )
+from ...tf_utils import shape_list
 from ...utils import logging
 from .configuration_electra import ElectraConfig
...
@@ -30,13 +30,8 @@ from ...file_utils import (
     replace_return_docstrings,
 )
 from ...modeling_tf_outputs import TFBaseModelOutput, TFSeq2SeqLMOutput
-from ...modeling_tf_utils import (
-    TFCausalLanguageModelingLoss,
-    TFPreTrainedModel,
-    get_initializer,
-    input_processing,
-    shape_list,
-)
+from ...modeling_tf_utils import TFCausalLanguageModelingLoss, TFPreTrainedModel, get_initializer, input_processing
+from ...tf_utils import shape_list
 from ...utils import logging
 from ..auto.configuration_auto import AutoConfig
 from ..auto.modeling_tf_auto import TFAutoModel, TFAutoModelForCausalLM
...