Unverified Commit b9633c46 authored by Jae-Won Chung's avatar Jae-Won Chung Committed by GitHub
Browse files

Fix typing in `Model.generate_token` (#733)

## What does this PR do?

This PR fixes a minor type annotation issue in the signature of
`Model.generate_token`.

All existing overrides of `Model.generate_token` return
`Tuple[List[Generation], Optional[B]]`:

https://github.com/huggingface/text-generation-inference/blob/3ef5ffbc6400370ff2e1546550a6bad3ac61b079/server/text_generation_server/models/causal_lm.py#L535-L537

https://github.com/huggingface/text-generation-inference/blob/3ef5ffbc6400370ff2e1546550a6bad3ac61b079/server/text_generation_server/models/flash_causal_lm.py#L802-L804

https://github.com/huggingface/text-generation-inference/blob/3ef5ffbc6400370ff2e1546550a6bad3ac61b079/server/text_generation_server/models/seq2seq_lm.py#L589-L591

I suspect that back in 017a2a8c when `GeneratedText` and `Generation`
were separated, the function signature was not updated.

## Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [x] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
      Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
      to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
Here are the
[documentation
guidelines](https://github.com/huggingface/transformers/tree/main/docs),
and
[here are tips on formatting
docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?

CC @OlivierDehaene
parent 92bb56b0
...@@ -5,7 +5,7 @@ from abc import ABC, abstractmethod ...@@ -5,7 +5,7 @@ from abc import ABC, abstractmethod
from typing import List, Tuple, Optional, TypeVar, Type from typing import List, Tuple, Optional, TypeVar, Type
from transformers import PreTrainedTokenizerBase, PretrainedConfig from transformers import PreTrainedTokenizerBase, PretrainedConfig
from text_generation_server.models.types import Batch, GeneratedText from text_generation_server.models.types import Batch, Generation
from text_generation_server.pb.generate_pb2 import InfoResponse from text_generation_server.pb.generate_pb2 import InfoResponse
B = TypeVar("B", bound=Batch) B = TypeVar("B", bound=Batch)
...@@ -52,7 +52,7 @@ class Model(ABC): ...@@ -52,7 +52,7 @@ class Model(ABC):
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
def generate_token(self, batch: B) -> Tuple[List[GeneratedText], Optional[B]]: def generate_token(self, batch: B) -> Tuple[List[Generation], Optional[B]]:
raise NotImplementedError raise NotImplementedError
def warmup(self, batch: B) -> Optional[int]: def warmup(self, batch: B) -> Optional[int]:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment