Unverified commit 2d56f106, authored by freitng, committed by GitHub

Modify default for max_new_tokens in python client (#1336)

# What does this PR do?
Since [#1097](https://github.com/huggingface/text-generation-inference/pull/1097),
clients no longer need to specify `max_new_tokens`. However, the
Python client in this repo had not yet been adapted to that change.
This PR makes it possible to use the Python client without providing
`max_new_tokens`.
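
For reference, a minimal sketch of the new call pattern (the endpoint URL is a
placeholder; any running text-generation-inference server works):

```python
from text_generation import Client

# Placeholder endpoint: point this at a running text-generation-inference server.
client = Client("http://127.0.0.1:8080")

# max_new_tokens is omitted; with this change it defaults to None and the
# server-side default decides how many tokens to generate.
response = client.generate("What is Deep Learning?")
print(response.generated_text)
```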



## Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [x] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
      Pull Request section?
- [ ] Was this discussed/approved via a GitHub issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
      to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
Here are the
[documentation
guidelines](https://github.com/huggingface/transformers/tree/main/docs),
and
[here are tips on formatting
docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [x] Did you write any new necessary tests?


## Who can review?

Anyone in the community is free to review the PR once the tests have
passed. Feel free to tag
members/contributors who may be interested in your PR.
parent a9ea6068
```diff
@@ -21,6 +21,22 @@ def test_generate(flan_t5_xxl_url, hf_headers):
     assert not response.details.tokens[0].special
 
 
+def test_generate_max_new_tokens_not_set(flan_t5_xxl_url, hf_headers):
+    client = Client(flan_t5_xxl_url, hf_headers)
+    response = client.generate("test", decoder_input_details=True)
+    assert response.generated_text != ""
+    assert response.details.finish_reason == FinishReason.EndOfSequenceToken
+    assert response.details.generated_tokens > 1
+    assert response.details.seed is None
+    assert len(response.details.prefill) == 1
+    assert response.details.prefill[0] == InputToken(id=0, text="<pad>", logprob=None)
+    assert len(response.details.tokens) > 1
+    assert response.details.tokens[0].id == 3
+    assert response.details.tokens[0].text == " "
+    assert not response.details.tokens[0].special
+
+
 def test_generate_best_of(flan_t5_xxl_url, hf_headers):
     client = Client(flan_t5_xxl_url, hf_headers)
     response = client.generate(
```
```diff
@@ -62,7 +62,7 @@ class Client:
         self,
         prompt: str,
         do_sample: bool = False,
-        max_new_tokens: int = 20,
+        max_new_tokens: Optional[int] = None,
         best_of: Optional[int] = None,
         repetition_penalty: Optional[float] = None,
         return_full_text: bool = False,
@@ -157,7 +157,7 @@ class Client:
         self,
         prompt: str,
         do_sample: bool = False,
-        max_new_tokens: int = 20,
+        max_new_tokens: Optional[int] = None,
         repetition_penalty: Optional[float] = None,
         return_full_text: bool = False,
         seed: Optional[int] = None,
@@ -312,7 +312,7 @@ class AsyncClient:
         self,
         prompt: str,
         do_sample: bool = False,
-        max_new_tokens: int = 20,
+        max_new_tokens: Optional[int] = None,
         best_of: Optional[int] = None,
         repetition_penalty: Optional[float] = None,
         return_full_text: bool = False,
@@ -405,7 +405,7 @@ class AsyncClient:
         self,
         prompt: str,
         do_sample: bool = False,
-        max_new_tokens: int = 20,
+        max_new_tokens: Optional[int] = None,
         repetition_penalty: Optional[float] = None,
         return_full_text: bool = False,
         seed: Optional[int] = None,
```
```diff
@@ -9,7 +9,7 @@ class Parameters(BaseModel):
     # Activate logits sampling
     do_sample: bool = False
     # Maximum number of generated tokens
-    max_new_tokens: int = 20
+    max_new_tokens: Optional[int] = None
     # The parameter for repetition penalty. 1.0 means no penalty.
     # See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
     repetition_penalty: Optional[float] = None
```
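
To illustrate what the new default means at the model level, here is a hedged
sketch using a trimmed-down two-field copy of the `Parameters` model above (not
the full model): when the field is left unset it stays `None` and serializes as
`null`, leaving the token budget to the server.

```python
from typing import Optional

from pydantic import BaseModel


class ParametersSketch(BaseModel):
    # Illustration only: a two-field subset of the Parameters model above.
    do_sample: bool = False
    max_new_tokens: Optional[int] = None


params = ParametersSketch()
assert params.max_new_tokens is None
# max_new_tokens serializes as null (on pydantic v2, prefer model_dump_json()).
print(params.json())
```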