Unverified Commit 85cf90a1 authored by Joao Gante's avatar Joao Gante Committed by GitHub
Browse files

Generate: add missing logits processors docs (#25653)

parent cb8e3ee2
...@@ -75,39 +75,104 @@ values. Here, for instance, it has two keys that are `sequences` and `scores`. ...@@ -75,39 +75,104 @@ values. Here, for instance, it has two keys that are `sequences` and `scores`.
We document here all output types. We document here all output types.
### GreedySearchOutput ### PyTorch
[[autodoc]] generation.GreedySearchDecoderOnlyOutput
[[autodoc]] generation.GreedySearchEncoderDecoderOutput [[autodoc]] generation.GreedySearchEncoderDecoderOutput
[[autodoc]] generation.FlaxGreedySearchOutput [[autodoc]] generation.GreedySearchDecoderOnlyOutput
### SampleOutput [[autodoc]] generation.SampleEncoderDecoderOutput
[[autodoc]] generation.SampleDecoderOnlyOutput [[autodoc]] generation.SampleDecoderOnlyOutput
[[autodoc]] generation.SampleEncoderDecoderOutput [[autodoc]] generation.BeamSearchEncoderDecoderOutput
[[autodoc]] generation.FlaxSampleOutput [[autodoc]] generation.BeamSearchDecoderOnlyOutput
### BeamSearchOutput [[autodoc]] generation.BeamSampleEncoderDecoderOutput
[[autodoc]] generation.BeamSearchDecoderOnlyOutput [[autodoc]] generation.BeamSampleDecoderOnlyOutput
[[autodoc]] generation.BeamSearchEncoderDecoderOutput [[autodoc]] generation.ContrastiveSearchEncoderDecoderOutput
### BeamSampleOutput [[autodoc]] generation.ContrastiveSearchDecoderOnlyOutput
[[autodoc]] generation.BeamSampleDecoderOnlyOutput ### TensorFlow
[[autodoc]] generation.BeamSampleEncoderDecoderOutput [[autodoc]] generation.TFGreedySearchEncoderDecoderOutput
[[autodoc]] generation.TFGreedySearchDecoderOnlyOutput
[[autodoc]] generation.TFSampleEncoderDecoderOutput
[[autodoc]] generation.TFSampleDecoderOnlyOutput
[[autodoc]] generation.TFBeamSearchEncoderDecoderOutput
[[autodoc]] generation.TFBeamSearchDecoderOnlyOutput
[[autodoc]] generation.TFBeamSampleEncoderDecoderOutput
[[autodoc]] generation.TFBeamSampleDecoderOnlyOutput
[[autodoc]] generation.TFContrastiveSearchEncoderDecoderOutput
[[autodoc]] generation.TFContrastiveSearchDecoderOnlyOutput
### FLAX
[[autodoc]] generation.FlaxSampleOutput
[[autodoc]] generation.FlaxGreedySearchOutput
[[autodoc]] generation.FlaxBeamSearchOutput
## LogitsProcessor ## LogitsProcessor
A [`LogitsProcessor`] can be used to modify the prediction scores of a language model head for A [`LogitsProcessor`] can be used to modify the prediction scores of a language model head for
generation. generation.
### PyTorch
[[autodoc]] AlternatingCodebooksLogitsProcessor
- __call__
[[autodoc]] ClassifierFreeGuidanceLogitsProcessor
- __call__
[[autodoc]] EncoderNoRepeatNGramLogitsProcessor
- __call__
[[autodoc]] EncoderRepetitionPenaltyLogitsProcessor
- __call__
[[autodoc]] EpsilonLogitsWarper
- __call__
[[autodoc]] EtaLogitsWarper
- __call__
[[autodoc]] ExponentialDecayLengthPenalty
- __call__
[[autodoc]] ForcedBOSTokenLogitsProcessor
- __call__
[[autodoc]] ForcedEOSTokenLogitsProcessor
- __call__
[[autodoc]] ForceTokensLogitsProcessor
- __call__
[[autodoc]] HammingDiversityLogitsProcessor
- __call__
[[autodoc]] InfNanRemoveLogitsProcessor
- __call__
[[autodoc]] LogitNormalization
- __call__
[[autodoc]] LogitsProcessor [[autodoc]] LogitsProcessor
- __call__ - __call__
...@@ -123,61 +188,63 @@ generation. ...@@ -123,61 +188,63 @@ generation.
[[autodoc]] MinNewTokensLengthLogitsProcessor [[autodoc]] MinNewTokensLengthLogitsProcessor
- __call__ - __call__
[[autodoc]] TemperatureLogitsWarper [[autodoc]] NoBadWordsLogitsProcessor
- __call__ - __call__
[[autodoc]] RepetitionPenaltyLogitsProcessor [[autodoc]] NoRepeatNGramLogitsProcessor
- __call__ - __call__
[[autodoc]] TopPLogitsWarper [[autodoc]] PrefixConstrainedLogitsProcessor
- __call__ - __call__
[[autodoc]] TopKLogitsWarper [[autodoc]] RepetitionPenaltyLogitsProcessor
- __call__ - __call__
[[autodoc]] TypicalLogitsWarper [[autodoc]] SequenceBiasLogitsProcessor
- __call__ - __call__
[[autodoc]] NoRepeatNGramLogitsProcessor [[autodoc]] SuppressTokensAtBeginLogitsProcessor
- __call__ - __call__
[[autodoc]] SequenceBiasLogitsProcessor [[autodoc]] SuppressTokensLogitsProcessor
- __call__ - __call__
[[autodoc]] NoBadWordsLogitsProcessor [[autodoc]] TemperatureLogitsWarper
- __call__ - __call__
[[autodoc]] PrefixConstrainedLogitsProcessor [[autodoc]] TopKLogitsWarper
- __call__ - __call__
[[autodoc]] HammingDiversityLogitsProcessor [[autodoc]] TopPLogitsWarper
- __call__ - __call__
[[autodoc]] ForcedBOSTokenLogitsProcessor [[autodoc]] TypicalLogitsWarper
- __call__ - __call__
[[autodoc]] ForcedEOSTokenLogitsProcessor [[autodoc]] UnbatchedClassifierFreeGuidanceLogitsProcessor
- __call__ - __call__
[[autodoc]] InfNanRemoveLogitsProcessor [[autodoc]] WhisperTimeStampLogitsProcessor
- __call__ - __call__
[[autodoc]] TFLogitsProcessor ### TensorFlow
[[autodoc]] TFForcedBOSTokenLogitsProcessor
- __call__ - __call__
[[autodoc]] TFLogitsProcessorList [[autodoc]] TFForcedEOSTokenLogitsProcessor
- __call__ - __call__
[[autodoc]] TFLogitsWarper [[autodoc]] TFForceTokensLogitsProcessor
- __call__ - __call__
[[autodoc]] TFTemperatureLogitsWarper [[autodoc]] TFLogitsProcessor
- __call__ - __call__
[[autodoc]] TFTopPLogitsWarper [[autodoc]] TFLogitsProcessorList
- __call__ - __call__
[[autodoc]] TFTopKLogitsWarper [[autodoc]] TFLogitsWarper
- __call__ - __call__
[[autodoc]] TFMinLengthLogitsProcessor [[autodoc]] TFMinLengthLogitsProcessor
...@@ -192,10 +259,30 @@ generation. ...@@ -192,10 +259,30 @@ generation.
[[autodoc]] TFRepetitionPenaltyLogitsProcessor [[autodoc]] TFRepetitionPenaltyLogitsProcessor
- __call__ - __call__
[[autodoc]] TFForcedBOSTokenLogitsProcessor [[autodoc]] TFSuppressTokensAtBeginLogitsProcessor
- __call__ - __call__
[[autodoc]] TFForcedEOSTokenLogitsProcessor [[autodoc]] TFSuppressTokensLogitsProcessor
- __call__
[[autodoc]] TFTemperatureLogitsWarper
- __call__
[[autodoc]] TFTopKLogitsWarper
- __call__
[[autodoc]] TFTopPLogitsWarper
- __call__
### FLAX
[[autodoc]] FlaxForcedBOSTokenLogitsProcessor
- __call__
[[autodoc]] FlaxForcedEOSTokenLogitsProcessor
- __call__
[[autodoc]] FlaxForceTokensLogitsProcessor
- __call__ - __call__
[[autodoc]] FlaxLogitsProcessor [[autodoc]] FlaxLogitsProcessor
...@@ -207,27 +294,30 @@ generation. ...@@ -207,27 +294,30 @@ generation.
[[autodoc]] FlaxLogitsWarper [[autodoc]] FlaxLogitsWarper
- __call__ - __call__
[[autodoc]] FlaxTemperatureLogitsWarper [[autodoc]] FlaxMinLengthLogitsProcessor
- __call__ - __call__
[[autodoc]] FlaxTopPLogitsWarper [[autodoc]] FlaxSuppressTokensAtBeginLogitsProcessor
- __call__ - __call__
[[autodoc]] FlaxTopKLogitsWarper [[autodoc]] FlaxSuppressTokensLogitsProcessor
- __call__ - __call__
[[autodoc]] FlaxForcedBOSTokenLogitsProcessor [[autodoc]] FlaxTemperatureLogitsWarper
- __call__ - __call__
[[autodoc]] FlaxForcedEOSTokenLogitsProcessor [[autodoc]] FlaxTopKLogitsWarper
- __call__ - __call__
[[autodoc]] FlaxMinLengthLogitsProcessor [[autodoc]] FlaxTopPLogitsWarper
- __call__
[[autodoc]] FlaxWhisperTimeStampLogitsProcessor
- __call__ - __call__
## StoppingCriteria ## StoppingCriteria
A [`StoppingCriteria`] can be used to change when to stop generation (other than EOS token). A [`StoppingCriteria`] can be used to change when to stop generation (other than EOS token). Please note that this is exclusivelly available to our PyTorch implementations.
[[autodoc]] StoppingCriteria [[autodoc]] StoppingCriteria
- __call__ - __call__
...@@ -243,7 +333,7 @@ A [`StoppingCriteria`] can be used to change when to stop generation (other than ...@@ -243,7 +333,7 @@ A [`StoppingCriteria`] can be used to change when to stop generation (other than
## Constraints ## Constraints
A [`Constraint`] can be used to force the generation to include specific tokens or sequences in the output. A [`Constraint`] can be used to force the generation to include specific tokens or sequences in the output. Please note that this is exclusivelly available to our PyTorch implementations.
[[autodoc]] Constraint [[autodoc]] Constraint
......
...@@ -1005,17 +1005,26 @@ else: ...@@ -1005,17 +1005,26 @@ else:
_import_structure["deepspeed"] = [] _import_structure["deepspeed"] = []
_import_structure["generation"].extend( _import_structure["generation"].extend(
[ [
"AlternatingCodebooksLogitsProcessor",
"BeamScorer", "BeamScorer",
"BeamSearchScorer", "BeamSearchScorer",
"ClassifierFreeGuidanceLogitsProcessor",
"ConstrainedBeamSearchScorer", "ConstrainedBeamSearchScorer",
"Constraint", "Constraint",
"ConstraintListState", "ConstraintListState",
"DisjunctiveConstraint", "DisjunctiveConstraint",
"EncoderNoRepeatNGramLogitsProcessor",
"EncoderRepetitionPenaltyLogitsProcessor",
"EpsilonLogitsWarper",
"EtaLogitsWarper",
"ExponentialDecayLengthPenalty",
"ForcedBOSTokenLogitsProcessor", "ForcedBOSTokenLogitsProcessor",
"ForcedEOSTokenLogitsProcessor", "ForcedEOSTokenLogitsProcessor",
"ForceTokensLogitsProcessor",
"GenerationMixin", "GenerationMixin",
"HammingDiversityLogitsProcessor", "HammingDiversityLogitsProcessor",
"InfNanRemoveLogitsProcessor", "InfNanRemoveLogitsProcessor",
"LogitNormalization",
"LogitsProcessor", "LogitsProcessor",
"LogitsProcessorList", "LogitsProcessorList",
"LogitsWarper", "LogitsWarper",
...@@ -1031,10 +1040,14 @@ else: ...@@ -1031,10 +1040,14 @@ else:
"SequenceBiasLogitsProcessor", "SequenceBiasLogitsProcessor",
"StoppingCriteria", "StoppingCriteria",
"StoppingCriteriaList", "StoppingCriteriaList",
"SuppressTokensAtBeginLogitsProcessor",
"SuppressTokensLogitsProcessor",
"TemperatureLogitsWarper", "TemperatureLogitsWarper",
"TopKLogitsWarper", "TopKLogitsWarper",
"TopPLogitsWarper", "TopPLogitsWarper",
"TypicalLogitsWarper", "TypicalLogitsWarper",
"UnbatchedClassifierFreeGuidanceLogitsProcessor",
"WhisperTimeStampLogitsProcessor",
"top_k_top_p_filtering", "top_k_top_p_filtering",
] ]
) )
...@@ -3115,6 +3128,7 @@ else: ...@@ -3115,6 +3128,7 @@ else:
[ [
"TFForcedBOSTokenLogitsProcessor", "TFForcedBOSTokenLogitsProcessor",
"TFForcedEOSTokenLogitsProcessor", "TFForcedEOSTokenLogitsProcessor",
"TFForceTokensLogitsProcessor",
"TFGenerationMixin", "TFGenerationMixin",
"TFLogitsProcessor", "TFLogitsProcessor",
"TFLogitsProcessorList", "TFLogitsProcessorList",
...@@ -3123,6 +3137,8 @@ else: ...@@ -3123,6 +3137,8 @@ else:
"TFNoBadWordsLogitsProcessor", "TFNoBadWordsLogitsProcessor",
"TFNoRepeatNGramLogitsProcessor", "TFNoRepeatNGramLogitsProcessor",
"TFRepetitionPenaltyLogitsProcessor", "TFRepetitionPenaltyLogitsProcessor",
"TFSuppressTokensAtBeginLogitsProcessor",
"TFSuppressTokensLogitsProcessor",
"TFTemperatureLogitsWarper", "TFTemperatureLogitsWarper",
"TFTopKLogitsWarper", "TFTopKLogitsWarper",
"TFTopPLogitsWarper", "TFTopPLogitsWarper",
...@@ -3836,14 +3852,18 @@ else: ...@@ -3836,14 +3852,18 @@ else:
[ [
"FlaxForcedBOSTokenLogitsProcessor", "FlaxForcedBOSTokenLogitsProcessor",
"FlaxForcedEOSTokenLogitsProcessor", "FlaxForcedEOSTokenLogitsProcessor",
"FlaxForceTokensLogitsProcessor",
"FlaxGenerationMixin", "FlaxGenerationMixin",
"FlaxLogitsProcessor", "FlaxLogitsProcessor",
"FlaxLogitsProcessorList", "FlaxLogitsProcessorList",
"FlaxLogitsWarper", "FlaxLogitsWarper",
"FlaxMinLengthLogitsProcessor", "FlaxMinLengthLogitsProcessor",
"FlaxTemperatureLogitsWarper", "FlaxTemperatureLogitsWarper",
"FlaxSuppressTokensAtBeginLogitsProcessor",
"FlaxSuppressTokensLogitsProcessor",
"FlaxTopKLogitsWarper", "FlaxTopKLogitsWarper",
"FlaxTopPLogitsWarper", "FlaxTopPLogitsWarper",
"FlaxWhisperTimeStampLogitsProcessor",
] ]
) )
_import_structure["generation_flax_utils"] = [] _import_structure["generation_flax_utils"] = []
...@@ -4983,17 +5003,26 @@ if TYPE_CHECKING: ...@@ -4983,17 +5003,26 @@ if TYPE_CHECKING:
TextDatasetForNextSentencePrediction, TextDatasetForNextSentencePrediction,
) )
from .generation import ( from .generation import (
AlternatingCodebooksLogitsProcessor,
BeamScorer, BeamScorer,
BeamSearchScorer, BeamSearchScorer,
ClassifierFreeGuidanceLogitsProcessor,
ConstrainedBeamSearchScorer, ConstrainedBeamSearchScorer,
Constraint, Constraint,
ConstraintListState, ConstraintListState,
DisjunctiveConstraint, DisjunctiveConstraint,
EncoderNoRepeatNGramLogitsProcessor,
EncoderRepetitionPenaltyLogitsProcessor,
EpsilonLogitsWarper,
EtaLogitsWarper,
ExponentialDecayLengthPenalty,
ForcedBOSTokenLogitsProcessor, ForcedBOSTokenLogitsProcessor,
ForcedEOSTokenLogitsProcessor, ForcedEOSTokenLogitsProcessor,
ForceTokensLogitsProcessor,
GenerationMixin, GenerationMixin,
HammingDiversityLogitsProcessor, HammingDiversityLogitsProcessor,
InfNanRemoveLogitsProcessor, InfNanRemoveLogitsProcessor,
LogitNormalization,
LogitsProcessor, LogitsProcessor,
LogitsProcessorList, LogitsProcessorList,
LogitsWarper, LogitsWarper,
...@@ -5009,10 +5038,14 @@ if TYPE_CHECKING: ...@@ -5009,10 +5038,14 @@ if TYPE_CHECKING:
SequenceBiasLogitsProcessor, SequenceBiasLogitsProcessor,
StoppingCriteria, StoppingCriteria,
StoppingCriteriaList, StoppingCriteriaList,
SuppressTokensAtBeginLogitsProcessor,
SuppressTokensLogitsProcessor,
TemperatureLogitsWarper, TemperatureLogitsWarper,
TopKLogitsWarper, TopKLogitsWarper,
TopPLogitsWarper, TopPLogitsWarper,
TypicalLogitsWarper, TypicalLogitsWarper,
UnbatchedClassifierFreeGuidanceLogitsProcessor,
WhisperTimeStampLogitsProcessor,
top_k_top_p_filtering, top_k_top_p_filtering,
) )
from .modeling_utils import PreTrainedModel from .modeling_utils import PreTrainedModel
...@@ -6712,6 +6745,7 @@ if TYPE_CHECKING: ...@@ -6712,6 +6745,7 @@ if TYPE_CHECKING:
from .generation import ( from .generation import (
TFForcedBOSTokenLogitsProcessor, TFForcedBOSTokenLogitsProcessor,
TFForcedEOSTokenLogitsProcessor, TFForcedEOSTokenLogitsProcessor,
TFForceTokensLogitsProcessor,
TFGenerationMixin, TFGenerationMixin,
TFLogitsProcessor, TFLogitsProcessor,
TFLogitsProcessorList, TFLogitsProcessorList,
...@@ -6720,6 +6754,8 @@ if TYPE_CHECKING: ...@@ -6720,6 +6754,8 @@ if TYPE_CHECKING:
TFNoBadWordsLogitsProcessor, TFNoBadWordsLogitsProcessor,
TFNoRepeatNGramLogitsProcessor, TFNoRepeatNGramLogitsProcessor,
TFRepetitionPenaltyLogitsProcessor, TFRepetitionPenaltyLogitsProcessor,
TFSuppressTokensAtBeginLogitsProcessor,
TFSuppressTokensLogitsProcessor,
TFTemperatureLogitsWarper, TFTemperatureLogitsWarper,
TFTopKLogitsWarper, TFTopKLogitsWarper,
TFTopPLogitsWarper, TFTopPLogitsWarper,
...@@ -7285,14 +7321,18 @@ if TYPE_CHECKING: ...@@ -7285,14 +7321,18 @@ if TYPE_CHECKING:
from .generation import ( from .generation import (
FlaxForcedBOSTokenLogitsProcessor, FlaxForcedBOSTokenLogitsProcessor,
FlaxForcedEOSTokenLogitsProcessor, FlaxForcedEOSTokenLogitsProcessor,
FlaxForceTokensLogitsProcessor,
FlaxGenerationMixin, FlaxGenerationMixin,
FlaxLogitsProcessor, FlaxLogitsProcessor,
FlaxLogitsProcessorList, FlaxLogitsProcessorList,
FlaxLogitsWarper, FlaxLogitsWarper,
FlaxMinLengthLogitsProcessor, FlaxMinLengthLogitsProcessor,
FlaxSuppressTokensAtBeginLogitsProcessor,
FlaxSuppressTokensLogitsProcessor,
FlaxTemperatureLogitsWarper, FlaxTemperatureLogitsWarper,
FlaxTopKLogitsWarper, FlaxTopKLogitsWarper,
FlaxTopPLogitsWarper, FlaxTopPLogitsWarper,
FlaxWhisperTimeStampLogitsProcessor,
) )
from .modeling_flax_utils import FlaxPreTrainedModel from .modeling_flax_utils import FlaxPreTrainedModel
......
...@@ -41,12 +41,19 @@ else: ...@@ -41,12 +41,19 @@ else:
"ConstrainedBeamSearchScorer", "ConstrainedBeamSearchScorer",
] ]
_import_structure["logits_process"] = [ _import_structure["logits_process"] = [
"AlternatingCodebooksLogitsProcessor",
"ClassifierFreeGuidanceLogitsProcessor",
"EncoderNoRepeatNGramLogitsProcessor",
"EncoderRepetitionPenaltyLogitsProcessor",
"EpsilonLogitsWarper", "EpsilonLogitsWarper",
"EtaLogitsWarper", "EtaLogitsWarper",
"ExponentialDecayLengthPenalty",
"ForcedBOSTokenLogitsProcessor", "ForcedBOSTokenLogitsProcessor",
"ForcedEOSTokenLogitsProcessor", "ForcedEOSTokenLogitsProcessor",
"ForceTokensLogitsProcessor",
"HammingDiversityLogitsProcessor", "HammingDiversityLogitsProcessor",
"InfNanRemoveLogitsProcessor", "InfNanRemoveLogitsProcessor",
"LogitNormalization",
"LogitsProcessor", "LogitsProcessor",
"LogitsProcessorList", "LogitsProcessorList",
"LogitsWarper", "LogitsWarper",
...@@ -57,15 +64,14 @@ else: ...@@ -57,15 +64,14 @@ else:
"PrefixConstrainedLogitsProcessor", "PrefixConstrainedLogitsProcessor",
"RepetitionPenaltyLogitsProcessor", "RepetitionPenaltyLogitsProcessor",
"SequenceBiasLogitsProcessor", "SequenceBiasLogitsProcessor",
"EncoderRepetitionPenaltyLogitsProcessor", "SuppressTokensLogitsProcessor",
"SuppressTokensAtBeginLogitsProcessor",
"TemperatureLogitsWarper", "TemperatureLogitsWarper",
"TopKLogitsWarper", "TopKLogitsWarper",
"TopPLogitsWarper", "TopPLogitsWarper",
"TypicalLogitsWarper", "TypicalLogitsWarper",
"EncoderNoRepeatNGramLogitsProcessor",
"ExponentialDecayLengthPenalty",
"LogitNormalization",
"UnbatchedClassifierFreeGuidanceLogitsProcessor", "UnbatchedClassifierFreeGuidanceLogitsProcessor",
"WhisperTimeStampLogitsProcessor",
] ]
_import_structure["stopping_criteria"] = [ _import_structure["stopping_criteria"] = [
"MaxNewTokensCriteria", "MaxNewTokensCriteria",
...@@ -99,6 +105,7 @@ else: ...@@ -99,6 +105,7 @@ else:
_import_structure["tf_logits_process"] = [ _import_structure["tf_logits_process"] = [
"TFForcedBOSTokenLogitsProcessor", "TFForcedBOSTokenLogitsProcessor",
"TFForcedEOSTokenLogitsProcessor", "TFForcedEOSTokenLogitsProcessor",
"TFForceTokensLogitsProcessor",
"TFLogitsProcessor", "TFLogitsProcessor",
"TFLogitsProcessorList", "TFLogitsProcessorList",
"TFLogitsWarper", "TFLogitsWarper",
...@@ -106,12 +113,11 @@ else: ...@@ -106,12 +113,11 @@ else:
"TFNoBadWordsLogitsProcessor", "TFNoBadWordsLogitsProcessor",
"TFNoRepeatNGramLogitsProcessor", "TFNoRepeatNGramLogitsProcessor",
"TFRepetitionPenaltyLogitsProcessor", "TFRepetitionPenaltyLogitsProcessor",
"TFSuppressTokensAtBeginLogitsProcessor",
"TFSuppressTokensLogitsProcessor",
"TFTemperatureLogitsWarper", "TFTemperatureLogitsWarper",
"TFTopKLogitsWarper", "TFTopKLogitsWarper",
"TFTopPLogitsWarper", "TFTopPLogitsWarper",
"TFForceTokensLogitsProcessor",
"TFSuppressTokensAtBeginLogitsProcessor",
"TFSuppressTokensLogitsProcessor",
] ]
_import_structure["tf_utils"] = [ _import_structure["tf_utils"] = [
"TFGenerationMixin", "TFGenerationMixin",
...@@ -137,13 +143,17 @@ else: ...@@ -137,13 +143,17 @@ else:
_import_structure["flax_logits_process"] = [ _import_structure["flax_logits_process"] = [
"FlaxForcedBOSTokenLogitsProcessor", "FlaxForcedBOSTokenLogitsProcessor",
"FlaxForcedEOSTokenLogitsProcessor", "FlaxForcedEOSTokenLogitsProcessor",
"FlaxForceTokensLogitsProcessor",
"FlaxLogitsProcessor", "FlaxLogitsProcessor",
"FlaxLogitsProcessorList", "FlaxLogitsProcessorList",
"FlaxLogitsWarper", "FlaxLogitsWarper",
"FlaxMinLengthLogitsProcessor", "FlaxMinLengthLogitsProcessor",
"FlaxSuppressTokensAtBeginLogitsProcessor",
"FlaxSuppressTokensLogitsProcessor",
"FlaxTemperatureLogitsWarper", "FlaxTemperatureLogitsWarper",
"FlaxTopKLogitsWarper", "FlaxTopKLogitsWarper",
"FlaxTopPLogitsWarper", "FlaxTopPLogitsWarper",
"FlaxWhisperTimeStampLogitsProcessor",
] ]
_import_structure["flax_utils"] = [ _import_structure["flax_utils"] = [
"FlaxGenerationMixin", "FlaxGenerationMixin",
...@@ -165,6 +175,8 @@ if TYPE_CHECKING: ...@@ -165,6 +175,8 @@ if TYPE_CHECKING:
from .beam_constraints import Constraint, ConstraintListState, DisjunctiveConstraint, PhrasalConstraint from .beam_constraints import Constraint, ConstraintListState, DisjunctiveConstraint, PhrasalConstraint
from .beam_search import BeamHypotheses, BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer from .beam_search import BeamHypotheses, BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer
from .logits_process import ( from .logits_process import (
AlternatingCodebooksLogitsProcessor,
ClassifierFreeGuidanceLogitsProcessor,
EncoderNoRepeatNGramLogitsProcessor, EncoderNoRepeatNGramLogitsProcessor,
EncoderRepetitionPenaltyLogitsProcessor, EncoderRepetitionPenaltyLogitsProcessor,
EpsilonLogitsWarper, EpsilonLogitsWarper,
...@@ -172,6 +184,7 @@ if TYPE_CHECKING: ...@@ -172,6 +184,7 @@ if TYPE_CHECKING:
ExponentialDecayLengthPenalty, ExponentialDecayLengthPenalty,
ForcedBOSTokenLogitsProcessor, ForcedBOSTokenLogitsProcessor,
ForcedEOSTokenLogitsProcessor, ForcedEOSTokenLogitsProcessor,
ForceTokensLogitsProcessor,
HammingDiversityLogitsProcessor, HammingDiversityLogitsProcessor,
InfNanRemoveLogitsProcessor, InfNanRemoveLogitsProcessor,
LogitNormalization, LogitNormalization,
...@@ -185,11 +198,14 @@ if TYPE_CHECKING: ...@@ -185,11 +198,14 @@ if TYPE_CHECKING:
PrefixConstrainedLogitsProcessor, PrefixConstrainedLogitsProcessor,
RepetitionPenaltyLogitsProcessor, RepetitionPenaltyLogitsProcessor,
SequenceBiasLogitsProcessor, SequenceBiasLogitsProcessor,
SuppressTokensAtBeginLogitsProcessor,
SuppressTokensLogitsProcessor,
TemperatureLogitsWarper, TemperatureLogitsWarper,
TopKLogitsWarper, TopKLogitsWarper,
TopPLogitsWarper, TopPLogitsWarper,
TypicalLogitsWarper, TypicalLogitsWarper,
UnbatchedClassifierFreeGuidanceLogitsProcessor, UnbatchedClassifierFreeGuidanceLogitsProcessor,
WhisperTimeStampLogitsProcessor,
) )
from .stopping_criteria import ( from .stopping_criteria import (
MaxLengthCriteria, MaxLengthCriteria,
...@@ -261,13 +277,17 @@ if TYPE_CHECKING: ...@@ -261,13 +277,17 @@ if TYPE_CHECKING:
from .flax_logits_process import ( from .flax_logits_process import (
FlaxForcedBOSTokenLogitsProcessor, FlaxForcedBOSTokenLogitsProcessor,
FlaxForcedEOSTokenLogitsProcessor, FlaxForcedEOSTokenLogitsProcessor,
FlaxForceTokensLogitsProcessor,
FlaxLogitsProcessor, FlaxLogitsProcessor,
FlaxLogitsProcessorList, FlaxLogitsProcessorList,
FlaxLogitsWarper, FlaxLogitsWarper,
FlaxMinLengthLogitsProcessor, FlaxMinLengthLogitsProcessor,
FlaxSuppressTokensAtBeginLogitsProcessor,
FlaxSuppressTokensLogitsProcessor,
FlaxTemperatureLogitsWarper, FlaxTemperatureLogitsWarper,
FlaxTopKLogitsWarper, FlaxTopKLogitsWarper,
FlaxTopPLogitsWarper, FlaxTopPLogitsWarper,
FlaxWhisperTimeStampLogitsProcessor,
) )
from .flax_utils import FlaxBeamSearchOutput, FlaxGenerationMixin, FlaxGreedySearchOutput, FlaxSampleOutput from .flax_utils import FlaxBeamSearchOutput, FlaxGenerationMixin, FlaxGreedySearchOutput, FlaxSampleOutput
else: else:
......
...@@ -16,6 +16,13 @@ class FlaxForcedEOSTokenLogitsProcessor(metaclass=DummyObject): ...@@ -16,6 +16,13 @@ class FlaxForcedEOSTokenLogitsProcessor(metaclass=DummyObject):
requires_backends(self, ["flax"]) requires_backends(self, ["flax"])
class FlaxForceTokensLogitsProcessor(metaclass=DummyObject):
_backends = ["flax"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
class FlaxGenerationMixin(metaclass=DummyObject): class FlaxGenerationMixin(metaclass=DummyObject):
_backends = ["flax"] _backends = ["flax"]
...@@ -51,6 +58,20 @@ class FlaxMinLengthLogitsProcessor(metaclass=DummyObject): ...@@ -51,6 +58,20 @@ class FlaxMinLengthLogitsProcessor(metaclass=DummyObject):
requires_backends(self, ["flax"]) requires_backends(self, ["flax"])
class FlaxSuppressTokensAtBeginLogitsProcessor(metaclass=DummyObject):
_backends = ["flax"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
class FlaxSuppressTokensLogitsProcessor(metaclass=DummyObject):
_backends = ["flax"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
class FlaxTemperatureLogitsWarper(metaclass=DummyObject): class FlaxTemperatureLogitsWarper(metaclass=DummyObject):
_backends = ["flax"] _backends = ["flax"]
...@@ -72,6 +93,13 @@ class FlaxTopPLogitsWarper(metaclass=DummyObject): ...@@ -72,6 +93,13 @@ class FlaxTopPLogitsWarper(metaclass=DummyObject):
requires_backends(self, ["flax"]) requires_backends(self, ["flax"])
class FlaxWhisperTimeStampLogitsProcessor(metaclass=DummyObject):
_backends = ["flax"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
class FlaxPreTrainedModel(metaclass=DummyObject): class FlaxPreTrainedModel(metaclass=DummyObject):
_backends = ["flax"] _backends = ["flax"]
......
...@@ -79,6 +79,13 @@ class TextDatasetForNextSentencePrediction(metaclass=DummyObject): ...@@ -79,6 +79,13 @@ class TextDatasetForNextSentencePrediction(metaclass=DummyObject):
requires_backends(self, ["torch"]) requires_backends(self, ["torch"])
class AlternatingCodebooksLogitsProcessor(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class BeamScorer(metaclass=DummyObject): class BeamScorer(metaclass=DummyObject):
_backends = ["torch"] _backends = ["torch"]
...@@ -93,6 +100,13 @@ class BeamSearchScorer(metaclass=DummyObject): ...@@ -93,6 +100,13 @@ class BeamSearchScorer(metaclass=DummyObject):
requires_backends(self, ["torch"]) requires_backends(self, ["torch"])
class ClassifierFreeGuidanceLogitsProcessor(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class ConstrainedBeamSearchScorer(metaclass=DummyObject): class ConstrainedBeamSearchScorer(metaclass=DummyObject):
_backends = ["torch"] _backends = ["torch"]
...@@ -121,6 +135,41 @@ class DisjunctiveConstraint(metaclass=DummyObject): ...@@ -121,6 +135,41 @@ class DisjunctiveConstraint(metaclass=DummyObject):
requires_backends(self, ["torch"]) requires_backends(self, ["torch"])
class EncoderNoRepeatNGramLogitsProcessor(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class EncoderRepetitionPenaltyLogitsProcessor(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class EpsilonLogitsWarper(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class EtaLogitsWarper(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class ExponentialDecayLengthPenalty(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class ForcedBOSTokenLogitsProcessor(metaclass=DummyObject): class ForcedBOSTokenLogitsProcessor(metaclass=DummyObject):
_backends = ["torch"] _backends = ["torch"]
...@@ -135,6 +184,13 @@ class ForcedEOSTokenLogitsProcessor(metaclass=DummyObject): ...@@ -135,6 +184,13 @@ class ForcedEOSTokenLogitsProcessor(metaclass=DummyObject):
requires_backends(self, ["torch"]) requires_backends(self, ["torch"])
class ForceTokensLogitsProcessor(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class GenerationMixin(metaclass=DummyObject): class GenerationMixin(metaclass=DummyObject):
_backends = ["torch"] _backends = ["torch"]
...@@ -156,6 +212,13 @@ class InfNanRemoveLogitsProcessor(metaclass=DummyObject): ...@@ -156,6 +212,13 @@ class InfNanRemoveLogitsProcessor(metaclass=DummyObject):
requires_backends(self, ["torch"]) requires_backends(self, ["torch"])
class LogitNormalization(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class LogitsProcessor(metaclass=DummyObject): class LogitsProcessor(metaclass=DummyObject):
_backends = ["torch"] _backends = ["torch"]
...@@ -261,6 +324,20 @@ class StoppingCriteriaList(metaclass=DummyObject): ...@@ -261,6 +324,20 @@ class StoppingCriteriaList(metaclass=DummyObject):
requires_backends(self, ["torch"]) requires_backends(self, ["torch"])
class SuppressTokensAtBeginLogitsProcessor(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class SuppressTokensLogitsProcessor(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class TemperatureLogitsWarper(metaclass=DummyObject): class TemperatureLogitsWarper(metaclass=DummyObject):
_backends = ["torch"] _backends = ["torch"]
...@@ -289,6 +366,20 @@ class TypicalLogitsWarper(metaclass=DummyObject): ...@@ -289,6 +366,20 @@ class TypicalLogitsWarper(metaclass=DummyObject):
requires_backends(self, ["torch"]) requires_backends(self, ["torch"])
class UnbatchedClassifierFreeGuidanceLogitsProcessor(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class WhisperTimeStampLogitsProcessor(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
def top_k_top_p_filtering(*args, **kwargs): def top_k_top_p_filtering(*args, **kwargs):
requires_backends(top_k_top_p_filtering, ["torch"]) requires_backends(top_k_top_p_filtering, ["torch"])
......
...@@ -30,6 +30,13 @@ class TFForcedEOSTokenLogitsProcessor(metaclass=DummyObject): ...@@ -30,6 +30,13 @@ class TFForcedEOSTokenLogitsProcessor(metaclass=DummyObject):
requires_backends(self, ["tf"]) requires_backends(self, ["tf"])
class TFForceTokensLogitsProcessor(metaclass=DummyObject):
_backends = ["tf"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
class TFGenerationMixin(metaclass=DummyObject): class TFGenerationMixin(metaclass=DummyObject):
_backends = ["tf"] _backends = ["tf"]
...@@ -86,6 +93,20 @@ class TFRepetitionPenaltyLogitsProcessor(metaclass=DummyObject): ...@@ -86,6 +93,20 @@ class TFRepetitionPenaltyLogitsProcessor(metaclass=DummyObject):
requires_backends(self, ["tf"]) requires_backends(self, ["tf"])
class TFSuppressTokensAtBeginLogitsProcessor(metaclass=DummyObject):
_backends = ["tf"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
class TFSuppressTokensLogitsProcessor(metaclass=DummyObject):
_backends = ["tf"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
class TFTemperatureLogitsWarper(metaclass=DummyObject): class TFTemperatureLogitsWarper(metaclass=DummyObject):
_backends = ["tf"] _backends = ["tf"]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment