Unverified Commit c9f09a4f authored by Fred Reiss's avatar Fred Reiss Committed by GitHub
Browse files

[mypy] Fix mypy warnings in api_server.py (#11941)


Signed-off-by: default avatarFred Reiss <frreiss@us.ibm.com>
parent d45cbe70
......@@ -14,7 +14,7 @@ from argparse import Namespace
from contextlib import asynccontextmanager
from functools import partial
from http import HTTPStatus
from typing import AsyncIterator, Optional, Set, Tuple
from typing import AsyncIterator, Dict, Optional, Set, Tuple, Union
import uvloop
from fastapi import APIRouter, FastAPI, HTTPException, Request
......@@ -420,6 +420,8 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request):
"use the Pooling API (`/pooling`) instead.")
res = await fallback_handler.create_pooling(request, raw_request)
generator: Union[ErrorResponse, EmbeddingResponse]
if isinstance(res, PoolingResponse):
generator = EmbeddingResponse(
id=res.id,
......@@ -494,7 +496,7 @@ async def create_score_v1(request: ScoreRequest, raw_request: Request):
return await create_score(request, raw_request)
TASK_HANDLERS = {
TASK_HANDLERS: Dict[str, Dict[str, tuple]] = {
"generate": {
"messages": (ChatCompletionRequest, create_chat_completion),
"default": (CompletionRequest, create_completion),
......@@ -652,7 +654,7 @@ def build_app(args: Namespace) -> FastAPI:
module_path, object_name = middleware.rsplit(".", 1)
imported = getattr(importlib.import_module(module_path), object_name)
if inspect.isclass(imported):
app.add_middleware(imported)
app.add_middleware(imported) # type: ignore[arg-type]
elif inspect.iscoroutinefunction(imported):
app.middleware("http")(imported)
else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment