Unverified Commit 5c17918f authored by Robert Dargavel Smith's avatar Robert Dargavel Smith Committed by GitHub
Browse files

Allow from transformers import TypicalLogitsWarper (#17477)

* Allow from transformers import TypicalLogitsWarper

* Added TypicalLogitsWarper

* Allow from transformers import TypicalLogitsWarper

* Allow from transformers import TypicalLogitsWarper

* Allow from transformers import TypicalLogitsWarper

* Allow from transformers import TypicalLogitsWarper

Added TypicalLogitsWarper

Allow from transformers import TypicalLogitsWarper

Allow from transformers import TypicalLogitsWarper

Allow from transformers import TypicalLogitsWarper
parent 607acd4f
...@@ -127,6 +127,9 @@ generation. ...@@ -127,6 +127,9 @@ generation.
[[autodoc]] TopKLogitsWarper [[autodoc]] TopKLogitsWarper
- __call__ - __call__
[[autodoc]] TypicalLogitsWarper
- __call__
[[autodoc]] NoRepeatNGramLogitsProcessor [[autodoc]] NoRepeatNGramLogitsProcessor
- __call__ - __call__
......
...@@ -703,6 +703,7 @@ else: ...@@ -703,6 +703,7 @@ else:
"TemperatureLogitsWarper", "TemperatureLogitsWarper",
"TopKLogitsWarper", "TopKLogitsWarper",
"TopPLogitsWarper", "TopPLogitsWarper",
"TypicalLogitsWarper",
] ]
_import_structure["generation_stopping_criteria"] = [ _import_structure["generation_stopping_criteria"] = [
"MaxLengthCriteria", "MaxLengthCriteria",
...@@ -3218,6 +3219,7 @@ if TYPE_CHECKING: ...@@ -3218,6 +3219,7 @@ if TYPE_CHECKING:
TemperatureLogitsWarper, TemperatureLogitsWarper,
TopKLogitsWarper, TopKLogitsWarper,
TopPLogitsWarper, TopPLogitsWarper,
TypicalLogitsWarper,
) )
from .generation_stopping_criteria import ( from .generation_stopping_criteria import (
MaxLengthCriteria, MaxLengthCriteria,
......
...@@ -234,6 +234,13 @@ class TopPLogitsWarper(metaclass=DummyObject): ...@@ -234,6 +234,13 @@ class TopPLogitsWarper(metaclass=DummyObject):
requires_backends(self, ["torch"]) requires_backends(self, ["torch"])
class TypicalLogitsWarper(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class MaxLengthCriteria(metaclass=DummyObject): class MaxLengthCriteria(metaclass=DummyObject):
_backends = ["torch"] _backends = ["torch"]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment