Unverified Commit f8a77e1b authored by Jakob Görgen's avatar Jakob Görgen
Browse files

client/simbricks/client/client: functions for handling the creation, retrieval...

client/simbricks/client/client: functions for handling the creation, retrieval and deletion + updates of runs
parent e91118f4
...@@ -36,17 +36,25 @@ class BaseClient: ...@@ -36,17 +36,25 @@ class BaseClient:
self._base_url = base_url self._base_url = base_url
self._token_provider = TokenProvider() self._token_provider = TokenProvider()
async def _get_headers(self) -> dict: async def _get_headers(self, overwrite_headers: dict[str, typing.Any] | None = None) -> dict:
headers = {} headers = {}
token = await self._token_provider.access_token() token = await self._token_provider.access_token()
headers["Authorization"] = f"Bearer {token}" headers["Authorization"] = f"Bearer {token}"
headers["accept"] = "application/json"
headers["Content-Type"] = "application/json" if overwrite_headers:
headers.update(overwrite_headers)
headers = {k: v for k, v in headers.items() if v is not None}
return headers return headers
def build_url(self, url: str) -> str:
return f"{self._base_url}{url}"
@contextlib.asynccontextmanager @contextlib.asynccontextmanager
async def session(self) -> typing.AsyncIterator[aiohttp.ClientSession]: async def session(
headers = await self._get_headers() self, overwrite_headers: dict[str, typing.Any] | None = None
) -> typing.AsyncIterator[aiohttp.ClientSession]:
headers = await self._get_headers(overwrite_headers=overwrite_headers)
session = aiohttp.ClientSession(headers=headers) session = aiohttp.ClientSession(headers=headers)
try: try:
yield session yield session
...@@ -55,26 +63,43 @@ class BaseClient: ...@@ -55,26 +63,43 @@ class BaseClient:
@contextlib.asynccontextmanager @contextlib.asynccontextmanager
async def post( async def post(
self, url: str, data: typing.Any = None, **kwargs: typing.Any self,
url: str,
data: typing.Any = None,
**kwargs: typing.Any,
) -> typing.AsyncIterator[aiohttp.ClientResponse]: ) -> typing.AsyncIterator[aiohttp.ClientResponse]:
url = f"{self._base_url}{url}"
async with self.session() as session: async with self.session() as session:
async with session.post(url=url, data=data, **kwargs) as resp: # TODO: handel connection error async with session.post(
url=self.build_url(url), data=data, **kwargs
) as resp: # TODO: handel connection error
print(await resp.text()) print(await resp.text())
resp.raise_for_status() # TODO: handel gracefully resp.raise_for_status() # TODO: handel gracefully
yield resp yield resp
@contextlib.asynccontextmanager @contextlib.asynccontextmanager
async def put( async def put(
self, url: str, data: typing.Any = None, **kwargs: typing.Any self,
url: str,
overwrite_headers: dict[str, typing.Any] | None = None,
data: typing.Any = None,
**kwargs: typing.Any,
) -> typing.AsyncIterator[aiohttp.ClientResponse]: ) -> typing.AsyncIterator[aiohttp.ClientResponse]:
async with self.session(overwrite_headers=overwrite_headers) as session:
async with session.put(
url=self.build_url(url), data=data, **kwargs
) as resp: # TODO: handel connection error
print(await resp.text())
resp.raise_for_status() # TODO: handel gracefully
yield resp
url = f"{self._base_url}{url}" @contextlib.asynccontextmanager
async def patch(
self, url: str, data: typing.Any = None, **kwargs: typing.Any
) -> typing.AsyncIterator[aiohttp.ClientResponse]:
async with self.session() as session: async with self.session() as session:
async with session.put(url=url, data=data, **kwargs) as resp: # TODO: handel connection error async with session.patch(
url=self.build_url(url), data=data, **kwargs
) as resp: # TODO: handel connection error
print(await resp.text()) print(await resp.text())
resp.raise_for_status() # TODO: handel gracefully resp.raise_for_status() # TODO: handel gracefully
yield resp yield resp
...@@ -83,21 +108,18 @@ class BaseClient: ...@@ -83,21 +108,18 @@ class BaseClient:
async def get( async def get(
self, url: str, data: typing.Any = None, **kwargs: typing.Any self, url: str, data: typing.Any = None, **kwargs: typing.Any
) -> typing.AsyncIterator[aiohttp.ClientResponse]: ) -> typing.AsyncIterator[aiohttp.ClientResponse]:
url = f"{self._base_url}{url}"
async with self.session() as session: async with self.session() as session:
async with session.get(url=url, data=data, **kwargs) as resp: # TODO: handel connection error async with session.get(
url=self.build_url(url), data=data, **kwargs
) as resp: # TODO: handel connection error
print(await resp.text()) print(await resp.text())
resp.raise_for_status() # TODO: handel gracefully resp.raise_for_status() # TODO: handel gracefully
yield resp yield resp
@contextlib.asynccontextmanager @contextlib.asynccontextmanager
async def delete(self, url: str, **kwargs: typing.Any) -> typing.AsyncIterator[aiohttp.ClientResponse]: async def delete(self, url: str, **kwargs: typing.Any) -> typing.AsyncIterator[aiohttp.ClientResponse]:
url = f"{self._base_url}{url}"
async with self.session() as session: async with self.session() as session:
async with session.delete(url=url, **kwargs) as resp: # TODO: handel connection error async with session.delete(url=self.build_url(url), **kwargs) as resp: # TODO: handel connection error
print(await resp.text()) print(await resp.text())
resp.raise_for_status() # TODO: handel gracefully resp.raise_for_status() # TODO: handel gracefully
yield resp yield resp
...@@ -152,9 +174,22 @@ class NSClient: ...@@ -152,9 +174,22 @@ class NSClient:
@contextlib.asynccontextmanager @contextlib.asynccontextmanager
async def put( async def put(
self,
url: str,
overwrite_headers: dict[str, typing.Any] | None = None,
data: typing.Any = None,
**kwargs: typing.Any,
) -> typing.AsyncIterator[aiohttp.ClientResponse]:
async with self._base_client.put(
url=self._build_ns_prefix(url=url), overwrite_headers=overwrite_headers, data=data, **kwargs
) as resp:
yield resp
@contextlib.asynccontextmanager
async def patch(
self, url: str, data: typing.Any = None, **kwargs: typing.Any self, url: str, data: typing.Any = None, **kwargs: typing.Any
) -> typing.AsyncIterator[aiohttp.ClientResponse]: ) -> typing.AsyncIterator[aiohttp.ClientResponse]:
async with self._base_client.put(url=self._build_ns_prefix(url=url), data=data, **kwargs) as resp: async with self._base_client.patch(url=self._build_ns_prefix(url=url), data=data, **kwargs) as resp:
yield resp yield resp
@contextlib.asynccontextmanager @contextlib.asynccontextmanager
...@@ -165,6 +200,11 @@ class NSClient: ...@@ -165,6 +200,11 @@ class NSClient:
async with self._base_client.get(url=self._build_ns_prefix(url=url), data=data, **kwargs) as resp: async with self._base_client.get(url=self._build_ns_prefix(url=url), data=data, **kwargs) as resp:
yield resp yield resp
@contextlib.asynccontextmanager
async def delete(self, url: str, **kwargs: typing.Any) -> typing.AsyncIterator[aiohttp.ClientResponse]:
async with self._base_client.delete(url=self._build_ns_prefix(url=url), **kwargs) as resp:
yield resp
async def info(self): async def info(self):
async with self.get(url="/info") as resp: async with self.get(url="/info") as resp:
return await resp.json() return await resp.json()
...@@ -174,8 +214,8 @@ class NSClient: ...@@ -174,8 +214,8 @@ class NSClient:
async with self.post(url="/", json=namespace_json) as resp: async with self.post(url="/", json=namespace_json) as resp:
return await resp.json() return await resp.json()
async def delete(self, ns_id: int): async def delete_ns(self, ns_id: int):
async with self._base_client.delete(url=self._build_ns_prefix(f"/{ns_id}")) as _: async with self.delete(url=self._build_ns_prefix(f"/{ns_id}")) as _:
return return
# retrieve namespace ns_id, useful for retrieving a child the current namespace # retrieve namespace ns_id, useful for retrieving a child the current namespace
...@@ -249,6 +289,14 @@ class SimBricksClient: ...@@ -249,6 +289,14 @@ class SimBricksClient:
async with self._ns_client.post(url="/runs", json=json_obj) as resp: async with self._ns_client.post(url="/runs", json=json_obj) as resp:
return await resp.json() return await resp.json()
async def delete_run(self, rid: int):
async with self._ns_client.delete(url=f"/runs/{rid}") as resp:
return await resp.json()
async def update_run(self, rid: int, updates: dict[str, typing.Any] = {"state": "pending"}) -> dict:
async with self._ns_client.patch(url=f"/runs/{rid}", json=updates) as resp:
return await resp.json()
async def get_run(self, run_id: int) -> dict: async def get_run(self, run_id: int) -> dict:
async with self._ns_client.get(url=f"/runs/{run_id}") as resp: async with self._ns_client.get(url=f"/runs/{run_id}") as resp:
return await resp.json() return await resp.json()
...@@ -257,6 +305,30 @@ class SimBricksClient: ...@@ -257,6 +305,30 @@ class SimBricksClient:
async with self._ns_client.get(url=f"/runs") as resp: async with self._ns_client.get(url=f"/runs") as resp:
return await resp.json() return await resp.json()
async def set_run_input(self, rid: int, uploaded_input_file: str):
with open(uploaded_input_file, "rb") as f:
file_data = {"file": f}
async with self._ns_client.put(url=f"/runs/input/{rid}", data=file_data) as resp:
return await resp.json()
async def get_run_input(self, rid: int, store_path: str):
async with self._ns_client.post(url=f"/runs/input/{rid}") as resp:
content = await resp.read()
with open(store_path, "wb") as f:
f.write(content)
async def set_run_artifact(self, rid: int, uploaded_output_file: str):
with open(uploaded_output_file, "rb") as f:
file_data = {"file": f}
async with self._ns_client.put(url=f"/runs/output/{rid}", data=file_data) as resp:
return await resp.json()
async def get_run_artifact(self, rid: int, store_path: str):
async with self._ns_client.post(url=f"/runs/output/{rid}") as resp:
content = await resp.read()
with open(store_path, "wb") as f:
f.write(content)
class RunnerClient: class RunnerClient:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment