Unverified Commit 449a83df authored by wang jiahao's avatar wang jiahao Committed by GitHub
Browse files

Merge pull request #1183 from kvcache-ai/check-para

add check-para
parents 7e4813e8 f7d93931
...@@ -13,6 +13,7 @@ from ktransformers.server.schemas.endpoints.chat import RawUsage, Role ...@@ -13,6 +13,7 @@ from ktransformers.server.schemas.endpoints.chat import RawUsage, Role
from ktransformers.server.backend.base import BackendInterfaceBase from ktransformers.server.backend.base import BackendInterfaceBase
from ktransformers.server.config.config import Config from ktransformers.server.config.config import Config
from ktransformers.server.config.log import logger from ktransformers.server.config.log import logger
from fastapi.responses import JSONResponse
from ktransformers.server.schemas.endpoints.chat import ChatCompletionChunk, CompletionUsage from ktransformers.server.schemas.endpoints.chat import ChatCompletionChunk, CompletionUsage
# Define own data structure instead of importing from OpenAI # Define own data structure instead of importing from OpenAI
...@@ -137,7 +138,57 @@ async def chat_completion(request: Request, create: ChatCompletionCreate): ...@@ -137,7 +138,57 @@ async def chat_completion(request: Request, create: ChatCompletionCreate):
# Process messages with tool functionality if needed # Process messages with tool functionality if needed
enhanced_messages = list(create.messages) enhanced_messages = list(create.messages)
if create.max_tokens<0 or create.max_completion_tokens<0:
return JSONResponse(
status_code=400,
content={
"object": "error",
"message": f"max_new_tokens must be at least 0, got {create.max_tokens}.",
"type": "BadRequestError",
"param": None,
"code": 400
})
if create.temperature<0 or create.temperature>2:
return JSONResponse(
status_code=400,
content={
"object": "error",
"message": f"temperature must be in [0, 2], got {create.temperature}.",
"type": "BadRequestError",
"param": None,
"code": 400
})
if create.top_p<=0 or create.top_p>1:
return JSONResponse(
status_code=400,
content={
"object": "error",
"message": f"top_p must be in (0, 1], got {create.top_p}.",
"type": "BadRequestError",
"param": None,
"code": 400
})
if create.frequency_penalty<-2 or create.frequency_penalty>2:
return JSONResponse(
status_code=400,
content={
"object": "error",
"message": f"frequency_penalty must be in [-2, 2], got {create.frequency_penalty}.",
"type": "BadRequestError",
"param": None,
"code": 400
})
if create.presence_penalty<-2 or create.presence_penalty>2:
return JSONResponse(
status_code=400,
content={
"object": "error",
"message": f"presence_penalty must be in [-2, 2], got {create.presence_penalty}.",
"type": "BadRequestError",
"param": None,
"code": 400
})
# Check if tools are present # Check if tools are present
has_tools = create.tools and len(create.tools) > 0 has_tools = create.tools and len(create.tools) > 0
......
...@@ -7,13 +7,53 @@ from ktransformers.server.utils.create_interface import get_interface ...@@ -7,13 +7,53 @@ from ktransformers.server.utils.create_interface import get_interface
from ktransformers.server.schemas.assistants.streaming import stream_response from ktransformers.server.schemas.assistants.streaming import stream_response
from ktransformers.server.schemas.legacy.completions import CompletionCreate,CompletionObject from ktransformers.server.schemas.legacy.completions import CompletionCreate,CompletionObject
from ktransformers.server.schemas.endpoints.chat import RawUsage from ktransformers.server.schemas.endpoints.chat import RawUsage
from fastapi.responses import JSONResponse
from ktransformers.server.config.config import Config
router = APIRouter() router = APIRouter()
@router.post("/completions",tags=['openai']) @router.post("/completions",tags=['openai'])
async def create_completion(request:Request, create:CompletionCreate): async def create_completion(request:Request, create:CompletionCreate):
id = str(uuid4()) id = str(uuid4())
if create.max_tokens<0:
return JSONResponse(
status_code=400,
content={
"object": "error",
"message": f"max_new_tokens must be at least 0, got {create.max_tokens}.",
"type": "BadRequestError",
"param": None,
"code": 400
})
if create.max_completion_tokens<0:
return JSONResponse(
status_code=400,
content={
"object": "error",
"message": f"max_new_tokens must be at least 0, got {create.max_completion_tokens}.",
"type": "BadRequestError",
"param": None,
"code": 400
})
if create.temperature<0 or create.temperature>2:
return JSONResponse(
status_code=400,
content={
"object": "error",
"message": f"temperature must be in [0, 2], got {create.temperature}.",
"type": "BadRequestError",
"param": None,
"code": 400
})
if create.top_p<=0 or create.top_p>1:
return JSONResponse(
status_code=400,
content={
"object": "error",
"message": f"top_p must be in (0, 1], got {create.top_p}.",
"type": "BadRequestError",
"param": None,
"code": 400
})
interface = get_interface() interface = get_interface()
print(f'COMPLETION INPUT:----\n{create.prompt}\n----') print(f'COMPLETION INPUT:----\n{create.prompt}\n----')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment