import json
import requests

from aiohttp import ClientSession, ClientTimeout
from pydantic import ValidationError
from typing import Dict, Optional, List, AsyncIterator, Iterator

from text_generation.types import (
    StreamResponse,
    Response,
    Request,
    Parameters,
    Grammar,
)
from text_generation.errors import parse_error


class Client:
    """Client to make calls to a text-generation-inference instance

     Example:

     ```python
     >>> from text_generation import Client

     >>> client = Client("https://api-inference.huggingface.co/models/bigscience/bloomz")
     >>> client.generate("Why is the sky blue?").generated_text
     ' Rayleigh scattering'

     >>> result = ""
     >>> for response in client.generate_stream("Why is the sky blue?"):
     >>>     if not response.token.special:
     >>>         result += response.token.text
     >>> result
     ' Rayleigh scattering'
     ```
    """

    def __init__(
        self,
        base_url: str,
        headers: Optional[Dict[str, str]] = None,
        cookies: Optional[Dict[str, str]] = None,
        timeout: int = 10,
    ):
        """
        Args:
            base_url (`str`):
                text-generation-inference instance base url
            headers (`Optional[Dict[str, str]]`):
                Additional headers
            cookies (`Optional[Dict[str, str]]`):
                Cookies to include in the requests
            timeout (`int`):
                Timeout in seconds
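
        Example (a minimal sketch; the token value is an illustrative placeholder):

        ```python
        >>> client = Client("http://127.0.0.1:8080", headers={"Authorization": "Bearer <hf_token>"}, timeout=30)
        ```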
        """
        self.base_url = base_url
        self.headers = headers
        self.cookies = cookies
        self.timeout = timeout

    def generate(
        self,
        prompt: str,
        do_sample: bool = False,
        max_new_tokens: int = 20,
        best_of: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
        return_full_text: bool = False,
        seed: Optional[int] = None,
        stop_sequences: Optional[List[str]] = None,
        temperature: Optional[float] = None,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        truncate: Optional[int] = None,
        typical_p: Optional[float] = None,
        watermark: bool = False,
        decoder_input_details: bool = False,
        top_n_tokens: Optional[int] = None,
        grammar: Optional[Grammar] = None,
    ) -> Response:
        """
        Given a prompt, generate the following text

        Args:
            prompt (`str`):
                Input text
            do_sample (`bool`):
                Activate logits sampling
            max_new_tokens (`int`):
                Maximum number of generated tokens
            best_of (`int`):
                Generate `best_of` sequences and return the one with the highest token logprobs
            repetition_penalty (`float`):
                The parameter for repetition penalty. 1.0 means no penalty. See [this
                paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
            return_full_text (`bool`):
                Whether to prepend the prompt to the generated text
            seed (`int`):
                Random sampling seed
            stop_sequences (`List[str]`):
                Stop generating tokens if a member of `stop_sequences` is generated
            temperature (`float`):
                The value used to modulate the logits distribution.
            top_k (`int`):
                The number of highest probability vocabulary tokens to keep for top-k-filtering.
            top_p (`float`):
                If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
                higher are kept for generation.
            truncate (`int`):
                Truncate input tokens to the given size
            typical_p (`float`):
                Typical Decoding mass
                See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
            watermark (`bool`):
                Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
            decoder_input_details (`bool`):
                Return the decoder input token logprobs and ids
            top_n_tokens (`int`):
                Return the `n` most likely tokens at each step
            grammar (`Grammar`):
                Optional grammar used to constrain the generated output, for example to valid JSON or to a regular expression

        Returns:
            Response: generated response
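
        Example (a minimal sketch; the endpoint URL and sampling values are illustrative):

        ```python
        >>> client = Client("http://127.0.0.1:8080")
        >>> response = client.generate("Why is the sky blue?", max_new_tokens=64, do_sample=True, temperature=0.5)
        >>> response.generated_text
        ```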
        """
        # Validate parameters
        parameters = Parameters(
            best_of=best_of,
            details=True,
            do_sample=do_sample,
            max_new_tokens=max_new_tokens,
            repetition_penalty=repetition_penalty,
            return_full_text=return_full_text,
            seed=seed,
            stop=stop_sequences if stop_sequences is not None else [],
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            truncate=truncate,
            typical_p=typical_p,
            watermark=watermark,
            decoder_input_details=decoder_input_details,
            top_n_tokens=top_n_tokens,
            grammar=grammar,
        )
        request = Request(inputs=prompt, stream=False, parameters=parameters)

        resp = requests.post(
            self.base_url,
            json=request.dict(),
            headers=self.headers,
            cookies=self.cookies,
            timeout=self.timeout,
        )
        payload = resp.json()
        if resp.status_code != 200:
            raise parse_error(resp.status_code, payload)
        return Response(**payload[0])

    def generate_stream(
        self,
        prompt: str,
        do_sample: bool = False,
        max_new_tokens: int = 20,
        repetition_penalty: Optional[float] = None,
        return_full_text: bool = False,
        seed: Optional[int] = None,
        stop_sequences: Optional[List[str]] = None,
        temperature: Optional[float] = None,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        truncate: Optional[int] = None,
        typical_p: Optional[float] = None,
        watermark: bool = False,
        top_n_tokens: Optional[int] = None,
        grammar: Optional[Grammar] = None,
    ) -> Iterator[StreamResponse]:
        """
        Given a prompt, generate the following stream of tokens

        Args:
            prompt (`str`):
                Input text
            do_sample (`bool`):
                Activate logits sampling
            max_new_tokens (`int`):
                Maximum number of generated tokens
            repetition_penalty (`float`):
                The parameter for repetition penalty. 1.0 means no penalty. See [this
                paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
            return_full_text (`bool`):
                Whether to prepend the prompt to the generated text
            seed (`int`):
                Random sampling seed
            stop_sequences (`List[str]`):
                Stop generating tokens if a member of `stop_sequences` is generated
            temperature (`float`):
                The value used to modulate the logits distribution.
            top_k (`int`):
                The number of highest probability vocabulary tokens to keep for top-k-filtering.
            top_p (`float`):
                If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
                higher are kept for generation.
            truncate (`int`):
                Truncate input tokens to the given size
            typical_p (`float`):
                Typical Decoding mass
                See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
            watermark (`bool`):
                Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
            top_n_tokens (`int`):
                Return the `n` most likely tokens at each step
            grammar (`Grammar`):
                Optional grammar used to constrain the generated output, for example to valid JSON or to a regular expression

        Returns:
            Iterator[StreamResponse]: stream of generated tokens
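
        Example (a minimal sketch; assumes `top_tokens` is filled on each response when `top_n_tokens` is set):

        ```python
        >>> client = Client("http://127.0.0.1:8080")
        >>> for response in client.generate_stream("Why is the sky blue?", max_new_tokens=64, top_n_tokens=3):
        >>>     if not response.token.special:
        >>>         print(response.token.text, response.top_tokens)
        ```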
        """
        # Validate parameters
        parameters = Parameters(
            best_of=None,
            details=True,
            decoder_input_details=False,
            do_sample=do_sample,
            max_new_tokens=max_new_tokens,
            repetition_penalty=repetition_penalty,
            return_full_text=return_full_text,
            seed=seed,
            stop=stop_sequences if stop_sequences is not None else [],
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            truncate=truncate,
            typical_p=typical_p,
            watermark=watermark,
            top_n_tokens=top_n_tokens,
            grammar=grammar,
        )
        request = Request(inputs=prompt, stream=True, parameters=parameters)

        resp = requests.post(
            self.base_url,
            json=request.dict(),
            headers=self.headers,
            cookies=self.cookies,
            timeout=self.timeout,
            stream=True,
        )

        if resp.status_code != 200:
            raise parse_error(resp.status_code, resp.json())

        # Parse ServerSentEvents
        for byte_payload in resp.iter_lines():
            # Skip empty separator lines between server-sent events
            if not byte_payload or byte_payload == b"\n":
                continue

            payload = byte_payload.decode("utf-8")

            # Event data
            if payload.startswith("data:"):
                # Strip the "data:" prefix and trailing newline, then decode the JSON payload
                json_payload = json.loads(payload[len("data:"):].rstrip("\n"))
                # Parse payload
                try:
                    response = StreamResponse(**json_payload)
                except ValidationError:
                    # If we failed to parse the payload, then it is an error payload
                    raise parse_error(resp.status_code, json_payload)
                yield response


class AsyncClient:
    """Asynchronous Client to make calls to a text-generation-inference instance

     Example:

     ```python
     >>> from text_generation import AsyncClient

     >>> client = AsyncClient("https://api-inference.huggingface.co/models/bigscience/bloomz")
     >>> response = await client.generate("Why is the sky blue?")
     >>> response.generated_text
     ' Rayleigh scattering'

     >>> result = ""
     >>> async for response in client.generate_stream("Why is the sky blue?"):
     >>>     if not response.token.special:
     >>>         result += response.token.text
     >>> result
     ' Rayleigh scattering'
     ```
    """

    def __init__(
        self,
        base_url: str,
        headers: Optional[Dict[str, str]] = None,
        cookies: Optional[Dict[str, str]] = None,
        timeout: int = 10,
    ):
        """
        Args:
            base_url (`str`):
                text-generation-inference instance base url
            headers (`Optional[Dict[str, str]]`):
                Additional headers
            cookies (`Optional[Dict[str, str]]`):
                Cookies to include in the requests
            timeout (`int`):
                Timeout in seconds
        """
        self.base_url = base_url
        self.headers = headers
        self.cookies = cookies
        self.timeout = ClientTimeout(total=timeout)

    async def generate(
        self,
        prompt: str,
        do_sample: bool = False,
        max_new_tokens: int = 20,
        best_of: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
        return_full_text: bool = False,
        seed: Optional[int] = None,
        stop_sequences: Optional[List[str]] = None,
        temperature: Optional[float] = None,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        truncate: Optional[int] = None,
        typical_p: Optional[float] = None,
        watermark: bool = False,
        decoder_input_details: bool = False,
        top_n_tokens: Optional[int] = None,
        grammar: Optional[Grammar] = None,
    ) -> Response:
        """
        Given a prompt, generate the following text asynchronously

        Args:
            prompt (`str`):
                Input text
            do_sample (`bool`):
                Activate logits sampling
            max_new_tokens (`int`):
                Maximum number of generated tokens
            best_of (`int`):
                Generate `best_of` sequences and return the one with the highest token logprobs
            repetition_penalty (`float`):
                The parameter for repetition penalty. 1.0 means no penalty. See [this
                paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
            return_full_text (`bool`):
                Whether to prepend the prompt to the generated text
            seed (`int`):
                Random sampling seed
            stop_sequences (`List[str]`):
                Stop generating tokens if a member of `stop_sequences` is generated
            temperature (`float`):
                The value used to modulate the logits distribution.
            top_k (`int`):
                The number of highest probability vocabulary tokens to keep for top-k-filtering.
            top_p (`float`):
                If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
                higher are kept for generation.
            truncate (`int`):
                Truncate input tokens to the given size
            typical_p (`float`):
                Typical Decoding mass
                See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
            watermark (`bool`):
                Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
            decoder_input_details (`bool`):
                Return the decoder input token logprobs and ids
            top_n_tokens (`int`):
                Return the `n` most likely tokens at each step
            grammar (`Grammar`):
                Optional grammar used to constrain the generated output, for example to valid JSON or to a regular expression

        Returns:
            Response: generated response
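
        Example (a minimal sketch; the endpoint URL is illustrative and the call must run inside an event loop):

        ```python
        >>> client = AsyncClient("http://127.0.0.1:8080")
        >>> response = await client.generate("Why is the sky blue?", max_new_tokens=64)
        >>> response.generated_text
        ```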
        """
        # Validate parameters
        parameters = Parameters(
            best_of=best_of,
            details=True,
            decoder_input_details=decoder_input_details,
            do_sample=do_sample,
            max_new_tokens=max_new_tokens,
            repetition_penalty=repetition_penalty,
            return_full_text=return_full_text,
            seed=seed,
            stop=stop_sequences if stop_sequences is not None else [],
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            truncate=truncate,
            typical_p=typical_p,
            watermark=watermark,
            top_n_tokens=top_n_tokens,
            grammar=grammar,
        )
        request = Request(inputs=prompt, stream=False, parameters=parameters)

        async with ClientSession(
            headers=self.headers, cookies=self.cookies, timeout=self.timeout
        ) as session:
            async with session.post(self.base_url, json=request.dict()) as resp:
                payload = await resp.json()

                if resp.status != 200:
                    raise parse_error(resp.status, payload)
                return Response(**payload[0])

    async def generate_stream(
        self,
        prompt: str,
        do_sample: bool = False,
        max_new_tokens: int = 20,
        repetition_penalty: Optional[float] = None,
        return_full_text: bool = False,
        seed: Optional[int] = None,
        stop_sequences: Optional[List[str]] = None,
        temperature: Optional[float] = None,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        truncate: Optional[int] = None,
        typical_p: Optional[float] = None,
        watermark: bool = False,
        top_n_tokens: Optional[int] = None,
        grammar: Optional[Grammar] = None,
    ) -> AsyncIterator[StreamResponse]:
        """
        Given a prompt, generate the following stream of tokens asynchronously

        Args:
            prompt (`str`):
                Input text
            do_sample (`bool`):
                Activate logits sampling
            max_new_tokens (`int`):
                Maximum number of generated tokens
            repetition_penalty (`float`):
                The parameter for repetition penalty. 1.0 means no penalty. See [this
                paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
            return_full_text (`bool`):
                Whether to prepend the prompt to the generated text
            seed (`int`):
                Random sampling seed
            stop_sequences (`List[str]`):
                Stop generating tokens if a member of `stop_sequences` is generated
            temperature (`float`):
                The value used to modulate the logits distribution.
            top_k (`int`):
                The number of highest probability vocabulary tokens to keep for top-k-filtering.
            top_p (`float`):
                If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
                higher are kept for generation.
            truncate (`int`):
                Truncate input tokens to the given size
            typical_p (`float`):
                Typical Decoding mass
                See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
            watermark (`bool`):
                Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
            top_n_tokens (`int`):
                Return the `n` most likely tokens at each step
            grammar (`Grammar`):
                Optional grammar used to constrain the generated output, for example to valid JSON or to a regular expression

        Returns:
            AsyncIterator[StreamResponse]: stream of generated tokens
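
        Example (a minimal sketch; the endpoint URL and seed are illustrative):

        ```python
        >>> client = AsyncClient("http://127.0.0.1:8080")
        >>> async for response in client.generate_stream("Once upon a time", max_new_tokens=32, seed=42):
        >>>     if not response.token.special:
        >>>         print(response.token.text, end="")
        ```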
        """
        # Validate parameters
        parameters = Parameters(
            best_of=None,
            details=True,
            decoder_input_details=False,
            do_sample=do_sample,
            max_new_tokens=max_new_tokens,
            repetition_penalty=repetition_penalty,
            return_full_text=return_full_text,
            seed=seed,
            stop=stop_sequences if stop_sequences is not None else [],
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            truncate=truncate,
            typical_p=typical_p,
            watermark=watermark,
            top_n_tokens=top_n_tokens,
            grammar=grammar,
        )
        request = Request(inputs=prompt, stream=True, parameters=parameters)

        async with ClientSession(
            headers=self.headers, cookies=self.cookies, timeout=self.timeout
        ) as session:
            async with session.post(self.base_url, json=request.dict()) as resp:
                if resp.status != 200:
                    raise parse_error(resp.status, await resp.json())

                # Parse ServerSentEvents
                async for byte_payload in resp.content:
                    # Skip empty separator lines between server-sent events
                    if byte_payload == b"\n":
                        continue

                    payload = byte_payload.decode("utf-8")

                    # Event data
                    if payload.startswith("data:"):
                        # Strip the "data:" prefix and trailing newline, then decode the JSON payload
                        json_payload = json.loads(payload[len("data:"):].rstrip("\n"))
                        # Parse payload
                        try:
                            response = StreamResponse(**json_payload)
                        except ValidationError:
                            # If we failed to parse the payload, then it is an error payload
                            raise parse_error(resp.status, json_payload)
                        yield response