Unverified Commit 78501190 authored by dconathan's avatar dconathan Committed by GitHub
Browse files

feat(python-client): add cookies to Client constructors and requests (#132)

I have a use case where we need to pass cookies (for auth reasons) to an
internally hosted server.

Note: I couldn't get the client tests to pass - do you need to have an
HF token?

```python
FAILED tests/test_client.py::test_generate - text_generation.errors.BadRequestError: Authorization header is correct, but the token seems invalid
```
parent a3b7db93
...@@ -36,7 +36,11 @@ class Client: ...@@ -36,7 +36,11 @@ class Client:
""" """
def __init__( def __init__(
self, base_url: str, headers: Optional[Dict[str, str]] = None, timeout: int = 10 self,
base_url: str,
headers: Optional[Dict[str, str]] = None,
cookies: Optional[Dict[str, str]] = None,
timeout: int = 10,
): ):
""" """
Args: Args:
...@@ -44,11 +48,14 @@ class Client: ...@@ -44,11 +48,14 @@ class Client:
text-generation-inference instance base url text-generation-inference instance base url
headers (`Optional[Dict[str, str]]`): headers (`Optional[Dict[str, str]]`):
Additional headers Additional headers
cookies (`Optional[Dict[str, str]]`):
Cookies to include in the requests
timeout (`int`): timeout (`int`):
Timeout in seconds Timeout in seconds
""" """
self.base_url = base_url self.base_url = base_url
self.headers = headers self.headers = headers
self.cookies = cookies
self.timeout = timeout self.timeout = timeout
def generate( def generate(
...@@ -130,6 +137,7 @@ class Client: ...@@ -130,6 +137,7 @@ class Client:
self.base_url, self.base_url,
json=request.dict(), json=request.dict(),
headers=self.headers, headers=self.headers,
cookies=self.cookies,
timeout=self.timeout, timeout=self.timeout,
) )
payload = resp.json() payload = resp.json()
...@@ -216,6 +224,7 @@ class Client: ...@@ -216,6 +224,7 @@ class Client:
self.base_url, self.base_url,
json=request.dict(), json=request.dict(),
headers=self.headers, headers=self.headers,
cookies=self.cookies,
timeout=self.timeout, timeout=self.timeout,
stream=True, stream=True,
) )
...@@ -267,7 +276,11 @@ class AsyncClient: ...@@ -267,7 +276,11 @@ class AsyncClient:
""" """
def __init__( def __init__(
self, base_url: str, headers: Optional[Dict[str, str]] = None, timeout: int = 10 self,
base_url: str,
headers: Optional[Dict[str, str]] = None,
cookies: Optional[Dict[str, str]] = None,
timeout: int = 10,
): ):
""" """
Args: Args:
...@@ -275,11 +288,14 @@ class AsyncClient: ...@@ -275,11 +288,14 @@ class AsyncClient:
text-generation-inference instance base url text-generation-inference instance base url
headers (`Optional[Dict[str, str]]`): headers (`Optional[Dict[str, str]]`):
Additional headers Additional headers
cookies (`Optional[Dict[str, str]]`):
Cookies to include in the requests
timeout (`int`): timeout (`int`):
Timeout in seconds Timeout in seconds
""" """
self.base_url = base_url self.base_url = base_url
self.headers = headers self.headers = headers
self.cookies = cookies
self.timeout = ClientTimeout(timeout * 60) self.timeout = ClientTimeout(timeout * 60)
async def generate( async def generate(
...@@ -357,7 +373,9 @@ class AsyncClient: ...@@ -357,7 +373,9 @@ class AsyncClient:
) )
request = Request(inputs=prompt, stream=False, parameters=parameters) request = Request(inputs=prompt, stream=False, parameters=parameters)
async with ClientSession(headers=self.headers, timeout=self.timeout) as session: async with ClientSession(
headers=self.headers, cookies=self.cookies, timeout=self.timeout
) as session:
async with session.post(self.base_url, json=request.dict()) as resp: async with session.post(self.base_url, json=request.dict()) as resp:
payload = await resp.json() payload = await resp.json()
...@@ -440,7 +458,9 @@ class AsyncClient: ...@@ -440,7 +458,9 @@ class AsyncClient:
) )
request = Request(inputs=prompt, stream=True, parameters=parameters) request = Request(inputs=prompt, stream=True, parameters=parameters)
async with ClientSession(headers=self.headers, timeout=self.timeout) as session: async with ClientSession(
headers=self.headers, cookies=self.cookies, timeout=self.timeout
) as session:
async with session.post(self.base_url, json=request.dict()) as resp: async with session.post(self.base_url, json=request.dict()) as resp:
if resp.status != 200: if resp.status != 200:
......
...@@ -92,7 +92,9 @@ class InferenceAPIClient(Client): ...@@ -92,7 +92,9 @@ class InferenceAPIClient(Client):
) )
base_url = f"{INFERENCE_ENDPOINT}/models/{repo_id}" base_url = f"{INFERENCE_ENDPOINT}/models/{repo_id}"
super(InferenceAPIClient, self).__init__(base_url, headers, timeout) super(InferenceAPIClient, self).__init__(
base_url, headers=headers, timeout=timeout
)
class InferenceAPIAsyncClient(AsyncClient): class InferenceAPIAsyncClient(AsyncClient):
...@@ -147,4 +149,6 @@ class InferenceAPIAsyncClient(AsyncClient): ...@@ -147,4 +149,6 @@ class InferenceAPIAsyncClient(AsyncClient):
) )
base_url = f"{INFERENCE_ENDPOINT}/models/{repo_id}" base_url = f"{INFERENCE_ENDPOINT}/models/{repo_id}"
super(InferenceAPIAsyncClient, self).__init__(base_url, headers, timeout) super(InferenceAPIAsyncClient, self).__init__(
base_url, headers=headers, timeout=timeout
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment