Unverified commit b9633c46, authored by Jae-Won Chung, committed by GitHub

Fix typing in `Model.generate_token` (#733)

## What does this PR do?

This PR fixes a minor type annotation issue in the signature of
`Model.generate_token`.

All existing overrides of `Model.generate_token` return
`Tuple[List[Generation], Optional[B]]`:

https://github.com/huggingface/text-generation-inference/blob/3ef5ffbc6400370ff2e1546550a6bad3ac61b079/server/text_generation_server/models/causal_lm.py#L535-L537

https://github.com/huggingface/text-generation-inference/blob/3ef5ffbc6400370ff2e1546550a6bad3ac61b079/server/text_generation_server/models/flash_causal_lm.py#L802-L804

https://github.com/huggingface/text-generation-inference/blob/3ef5ffbc6400370ff2e1546550a6bad3ac61b079/server/text_generation_server/models/seq2seq_lm.py#L589-L591

I suspect that the signature was simply not updated back in 017a2a8c, when `GeneratedText` and `Generation` were split into separate types.
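For readers who don't open the permalinks, each of those overrides has the shape sketched below (a simplified sketch based on the linked lines; the concrete batch types `CausalLMBatch`, `FlashCausalLMBatch`, and `Seq2SeqLMBatch` are the ones used in the respective files, and decorators and method bodies are omitted):

```python
# Simplified shape of the override in causal_lm.py (see the first permalink above);
# flash_causal_lm.py and seq2seq_lm.py follow the same pattern with their own batch types.
def generate_token(
    self, batch: CausalLMBatch
) -> Tuple[List[Generation], Optional[CausalLMBatch]]:
    ...
```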

## Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [x] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section?
- [ ] Was this discussed/approved via a GitHub issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?

CC @OlivierDehaene
parent 92bb56b0
server/text_generation_server/models/model.py

@@ -5,7 +5,7 @@ from abc import ABC, abstractmethod
 from typing import List, Tuple, Optional, TypeVar, Type
 from transformers import PreTrainedTokenizerBase, PretrainedConfig
 
-from text_generation_server.models.types import Batch, GeneratedText
+from text_generation_server.models.types import Batch, Generation
 from text_generation_server.pb.generate_pb2 import InfoResponse
 
 B = TypeVar("B", bound=Batch)
@@ -52,7 +52,7 @@ class Model(ABC):
         raise NotImplementedError
 
     @abstractmethod
-    def generate_token(self, batch: B) -> Tuple[List[GeneratedText], Optional[B]]:
+    def generate_token(self, batch: B) -> Tuple[List[Generation], Optional[B]]:
         raise NotImplementedError
 
     def warmup(self, batch: B) -> Optional[int]:
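For a self-contained illustration of why the element type matters, here is a minimal sketch with toy stand-ins for `Batch` and `Generation` (not the real TGI classes): the `B` TypeVar is bound to `Batch`, and with `Generation` as the list element type the base-class annotation now matches what concrete subclasses actually return.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import List, Optional, Tuple, TypeVar


@dataclass
class Batch:
    """Toy stand-in for text_generation_server.models.types.Batch."""
    request_ids: List[int]


@dataclass
class Generation:
    """Toy stand-in for the per-step Generation type (not GeneratedText)."""
    request_id: int
    token_text: str


B = TypeVar("B", bound=Batch)


class Model(ABC):
    @abstractmethod
    def generate_token(self, batch: B) -> Tuple[List[Generation], Optional[B]]:
        raise NotImplementedError


class ToyCausalLM(Model):
    def generate_token(
        self, batch: Batch
    ) -> Tuple[List[Generation], Optional[Batch]]:
        # One Generation per request; return the batch again if decoding should
        # continue, or None once every request in the batch has finished.
        generations = [Generation(request_id=i, token_text="tok") for i in batch.request_ids]
        return generations, None
```

With the old annotation (`List[GeneratedText]`), a type checker such as mypy would be expected to flag every such override as an incompatible return type, since `Generation` is not a `GeneratedText`.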