Commit c995bdbb authored by Alisehen's avatar Alisehen
Browse files

add check-para

parent 48558801
...@@ -13,7 +13,7 @@ from ktransformers.server.schemas.endpoints.chat import RawUsage, Role ...@@ -13,7 +13,7 @@ from ktransformers.server.schemas.endpoints.chat import RawUsage, Role
from ktransformers.server.backend.base import BackendInterfaceBase from ktransformers.server.backend.base import BackendInterfaceBase
from ktransformers.server.config.config import Config from ktransformers.server.config.config import Config
from ktransformers.server.config.log import logger from ktransformers.server.config.log import logger
from fastapi.responses import JSONResponse
from ktransformers.server.schemas.endpoints.chat import ChatCompletionChunk from ktransformers.server.schemas.endpoints.chat import ChatCompletionChunk
# Define own data structure instead of importing from OpenAI # Define own data structure instead of importing from OpenAI
...@@ -143,7 +143,67 @@ async def chat_completion(request: Request, create: ChatCompletionCreate): ...@@ -143,7 +143,67 @@ async def chat_completion(request: Request, create: ChatCompletionCreate):
# Process messages with tool functionality if needed # Process messages with tool functionality if needed
enhanced_messages = list(create.messages) enhanced_messages = list(create.messages)
if create.model != Config().model_name:
return JSONResponse(
status_code=400,
content={
"error": {
"message": "Model not found",
"code": 404,
"type": "NotFound"
}
})
if create.max_tokens<0 or create.max_completion_tokens<0:
return JSONResponse(
status_code=400,
content={
"object": "error",
"message": f"max_new_tokens must be at least 0, got {create.max_tokens}.",
"type": "BadRequestError",
"param": None,
"code": 400
})
if create.temperature<0 or create.temperature>2:
return JSONResponse(
status_code=400,
content={
"object": "error",
"message": f"temperature must be in [0, 2], got {create.temperature}.",
"type": "BadRequestError",
"param": None,
"code": 400
})
if create.top_p<=0 or create.top_p>1:
return JSONResponse(
status_code=400,
content={
"object": "error",
"message": f"top_p must be in (0, 1], got {create.top_p}.",
"type": "BadRequestError",
"param": None,
"code": 400
})
if create.frequency_penalty<-2 or create.frequency_penalty>2:
return JSONResponse(
status_code=400,
content={
"object": "error",
"message": f"frequency_penalty must be in [-2, 2], got {create.frequency_penalty}.",
"type": "BadRequestError",
"param": None,
"code": 400
})
if create.presence_penalty<-2 or create.presence_penalty>2:
return JSONResponse(
status_code=400,
content={
"object": "error",
"message": f"presence_penalty must be in [-2, 2], got {create.presence_penalty}.",
"type": "BadRequestError",
"param": None,
"code": 400
})
# Check if tools are present # Check if tools are present
has_tools = create.tools and len(create.tools) > 0 has_tools = create.tools and len(create.tools) > 0
......
...@@ -7,13 +7,63 @@ from ktransformers.server.utils.create_interface import get_interface ...@@ -7,13 +7,63 @@ from ktransformers.server.utils.create_interface import get_interface
from ktransformers.server.schemas.assistants.streaming import stream_response from ktransformers.server.schemas.assistants.streaming import stream_response
from ktransformers.server.schemas.legacy.completions import CompletionCreate,CompletionObject from ktransformers.server.schemas.legacy.completions import CompletionCreate,CompletionObject
from ktransformers.server.schemas.endpoints.chat import RawUsage from ktransformers.server.schemas.endpoints.chat import RawUsage
from fastapi.responses import JSONResponse
from ktransformers.server.config.config import Config
router = APIRouter() router = APIRouter()
@router.post("/completions",tags=['openai']) @router.post("/completions",tags=['openai'])
async def create_completion(request:Request, create:CompletionCreate): async def create_completion(request:Request, create:CompletionCreate):
id = str(uuid4()) id = str(uuid4())
if create.model != Config().model_name:
return JSONResponse(
status_code=400,
content={
"error": {
"message": "Model not found",
"code": 404,
"type": "NotFound"
}
})
if create.max_tokens<0:
return JSONResponse(
status_code=400,
content={
"object": "error",
"message": f"max_new_tokens must be at least 0, got {create.max_tokens}.",
"type": "BadRequestError",
"param": None,
"code": 400
})
if create.max_completion_tokens<0:
return JSONResponse(
status_code=400,
content={
"object": "error",
"message": f"max_new_tokens must be at least 0, got {create.max_completion_tokens}.",
"type": "BadRequestError",
"param": None,
"code": 400
})
if create.temperature<0 or create.temperature>2:
return JSONResponse(
status_code=400,
content={
"object": "error",
"message": f"temperature must be in [0, 2], got {create.temperature}.",
"type": "BadRequestError",
"param": None,
"code": 400
})
if create.top_p<=0 or create.top_p>1:
return JSONResponse(
status_code=400,
content={
"object": "error",
"message": f"top_p must be in (0, 1], got {create.top_p}.",
"type": "BadRequestError",
"param": None,
"code": 400
})
interface = get_interface() interface = get_interface()
print(f'COMPLETION INPUT:----\n{create.prompt}\n----') print(f'COMPLETION INPUT:----\n{create.prompt}\n----')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment