"vscode:/vscode.git/clone" did not exist on "a94da42014b5714cbf6aad5e585f4fb06296377b"
Unverified Commit f8a77e1b authored by Jakob Görgen's avatar Jakob Görgen
Browse files

client/simbricks/client/client: functions for handling the creation, retrieval...

client/simbricks/client/client: functions for handling the creation, retrieval and deletion + updates of runs
parent e91118f4
......@@ -36,17 +36,25 @@ class BaseClient:
self._base_url = base_url
self._token_provider = TokenProvider()
async def _get_headers(self) -> dict:
async def _get_headers(self, overwrite_headers: dict[str, typing.Any] | None = None) -> dict:
headers = {}
token = await self._token_provider.access_token()
headers["Authorization"] = f"Bearer {token}"
headers["accept"] = "application/json"
headers["Content-Type"] = "application/json"
if overwrite_headers:
headers.update(overwrite_headers)
headers = {k: v for k, v in headers.items() if v is not None}
return headers
def build_url(self, url: str) -> str:
return f"{self._base_url}{url}"
@contextlib.asynccontextmanager
async def session(self) -> typing.AsyncIterator[aiohttp.ClientSession]:
headers = await self._get_headers()
async def session(
self, overwrite_headers: dict[str, typing.Any] | None = None
) -> typing.AsyncIterator[aiohttp.ClientSession]:
headers = await self._get_headers(overwrite_headers=overwrite_headers)
session = aiohttp.ClientSession(headers=headers)
try:
yield session
......@@ -55,26 +63,43 @@ class BaseClient:
@contextlib.asynccontextmanager
async def post(
self, url: str, data: typing.Any = None, **kwargs: typing.Any
self,
url: str,
data: typing.Any = None,
**kwargs: typing.Any,
) -> typing.AsyncIterator[aiohttp.ClientResponse]:
url = f"{self._base_url}{url}"
async with self.session() as session:
async with session.post(url=url, data=data, **kwargs) as resp: # TODO: handel connection error
async with session.post(
url=self.build_url(url), data=data, **kwargs
) as resp: # TODO: handel connection error
print(await resp.text())
resp.raise_for_status() # TODO: handel gracefully
yield resp
@contextlib.asynccontextmanager
async def put(
self, url: str, data: typing.Any = None, **kwargs: typing.Any
self,
url: str,
overwrite_headers: dict[str, typing.Any] | None = None,
data: typing.Any = None,
**kwargs: typing.Any,
) -> typing.AsyncIterator[aiohttp.ClientResponse]:
async with self.session(overwrite_headers=overwrite_headers) as session:
async with session.put(
url=self.build_url(url), data=data, **kwargs
) as resp: # TODO: handel connection error
print(await resp.text())
resp.raise_for_status() # TODO: handel gracefully
yield resp
url = f"{self._base_url}{url}"
@contextlib.asynccontextmanager
async def patch(
self, url: str, data: typing.Any = None, **kwargs: typing.Any
) -> typing.AsyncIterator[aiohttp.ClientResponse]:
async with self.session() as session:
async with session.put(url=url, data=data, **kwargs) as resp: # TODO: handel connection error
async with session.patch(
url=self.build_url(url), data=data, **kwargs
) as resp: # TODO: handel connection error
print(await resp.text())
resp.raise_for_status() # TODO: handel gracefully
yield resp
......@@ -83,21 +108,18 @@ class BaseClient:
async def get(
self, url: str, data: typing.Any = None, **kwargs: typing.Any
) -> typing.AsyncIterator[aiohttp.ClientResponse]:
url = f"{self._base_url}{url}"
async with self.session() as session:
async with session.get(url=url, data=data, **kwargs) as resp: # TODO: handel connection error
async with session.get(
url=self.build_url(url), data=data, **kwargs
) as resp: # TODO: handel connection error
print(await resp.text())
resp.raise_for_status() # TODO: handel gracefully
yield resp
@contextlib.asynccontextmanager
async def delete(self, url: str, **kwargs: typing.Any) -> typing.AsyncIterator[aiohttp.ClientResponse]:
url = f"{self._base_url}{url}"
async with self.session() as session:
async with session.delete(url=url, **kwargs) as resp: # TODO: handel connection error
async with session.delete(url=self.build_url(url), **kwargs) as resp: # TODO: handel connection error
print(await resp.text())
resp.raise_for_status() # TODO: handel gracefully
yield resp
......@@ -152,9 +174,22 @@ class NSClient:
@contextlib.asynccontextmanager
async def put(
self,
url: str,
overwrite_headers: dict[str, typing.Any] | None = None,
data: typing.Any = None,
**kwargs: typing.Any,
) -> typing.AsyncIterator[aiohttp.ClientResponse]:
async with self._base_client.put(
url=self._build_ns_prefix(url=url), overwrite_headers=overwrite_headers, data=data, **kwargs
) as resp:
yield resp
@contextlib.asynccontextmanager
async def patch(
self, url: str, data: typing.Any = None, **kwargs: typing.Any
) -> typing.AsyncIterator[aiohttp.ClientResponse]:
async with self._base_client.put(url=self._build_ns_prefix(url=url), data=data, **kwargs) as resp:
async with self._base_client.patch(url=self._build_ns_prefix(url=url), data=data, **kwargs) as resp:
yield resp
@contextlib.asynccontextmanager
......@@ -165,6 +200,11 @@ class NSClient:
async with self._base_client.get(url=self._build_ns_prefix(url=url), data=data, **kwargs) as resp:
yield resp
@contextlib.asynccontextmanager
async def delete(self, url: str, **kwargs: typing.Any) -> typing.AsyncIterator[aiohttp.ClientResponse]:
async with self._base_client.delete(url=self._build_ns_prefix(url=url), **kwargs) as resp:
yield resp
async def info(self):
async with self.get(url="/info") as resp:
return await resp.json()
......@@ -174,8 +214,8 @@ class NSClient:
async with self.post(url="/", json=namespace_json) as resp:
return await resp.json()
async def delete(self, ns_id: int):
async with self._base_client.delete(url=self._build_ns_prefix(f"/{ns_id}")) as _:
async def delete_ns(self, ns_id: int):
async with self.delete(url=self._build_ns_prefix(f"/{ns_id}")) as _:
return
# retrieve namespace ns_id, useful for retrieving a child the current namespace
......@@ -249,6 +289,14 @@ class SimBricksClient:
async with self._ns_client.post(url="/runs", json=json_obj) as resp:
return await resp.json()
async def delete_run(self, rid: int):
async with self._ns_client.delete(url=f"/runs/{rid}") as resp:
return await resp.json()
async def update_run(self, rid: int, updates: dict[str, typing.Any] = {"state": "pending"}) -> dict:
async with self._ns_client.patch(url=f"/runs/{rid}", json=updates) as resp:
return await resp.json()
async def get_run(self, run_id: int) -> dict:
async with self._ns_client.get(url=f"/runs/{run_id}") as resp:
return await resp.json()
......@@ -257,6 +305,30 @@ class SimBricksClient:
async with self._ns_client.get(url=f"/runs") as resp:
return await resp.json()
async def set_run_input(self, rid: int, uploaded_input_file: str):
with open(uploaded_input_file, "rb") as f:
file_data = {"file": f}
async with self._ns_client.put(url=f"/runs/input/{rid}", data=file_data) as resp:
return await resp.json()
async def get_run_input(self, rid: int, store_path: str):
async with self._ns_client.post(url=f"/runs/input/{rid}") as resp:
content = await resp.read()
with open(store_path, "wb") as f:
f.write(content)
async def set_run_artifact(self, rid: int, uploaded_output_file: str):
with open(uploaded_output_file, "rb") as f:
file_data = {"file": f}
async with self._ns_client.put(url=f"/runs/output/{rid}", data=file_data) as resp:
return await resp.json()
async def get_run_artifact(self, rid: int, store_path: str):
async with self._ns_client.post(url=f"/runs/output/{rid}") as resp:
content = await resp.read()
with open(store_path, "wb") as f:
f.write(content)
class RunnerClient:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment