"test/config_test/git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "1500458a073792621a7460abb4e4f6be918dca13"
Commit 582d11f1 authored by Timothy J. Baek's avatar Timothy J. Baek
Browse files

refac: RAG_EMBEDDING_MODEL_PATH removed

parent cb2158a7
...@@ -80,16 +80,15 @@ app.state.RAG_TEMPLATE = RAG_TEMPLATE ...@@ -80,16 +80,15 @@ app.state.RAG_TEMPLATE = RAG_TEMPLATE
app.state.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL app.state.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
app.state.RAG_EMBEDDING_MODEL_PATH = get_embedding_model_path(
app.state.RAG_EMBEDDING_MODEL, RAG_EMBEDDING_MODEL_AUTO_UPDATE
)
app.state.TOP_K = 4 app.state.TOP_K = 4
app.state.sentence_transformer_ef = ( app.state.sentence_transformer_ef = (
embedding_functions.SentenceTransformerEmbeddingFunction( embedding_functions.SentenceTransformerEmbeddingFunction(
model_name=app.state.RAG_EMBEDDING_MODEL_PATH, model_name=get_embedding_model_path(
app.state.RAG_EMBEDDING_MODEL, RAG_EMBEDDING_MODEL_AUTO_UPDATE
),
device=DEVICE_TYPE, device=DEVICE_TYPE,
) )
) )
...@@ -130,7 +129,6 @@ async def get_embedding_model(user=Depends(get_admin_user)): ...@@ -130,7 +129,6 @@ async def get_embedding_model(user=Depends(get_admin_user)):
return { return {
"status": True, "status": True,
"embedding_model": app.state.RAG_EMBEDDING_MODEL, "embedding_model": app.state.RAG_EMBEDDING_MODEL,
"embedding_model_path": app.state.RAG_EMBEDDING_MODEL_PATH,
} }
...@@ -143,43 +141,32 @@ async def update_embedding_model( ...@@ -143,43 +141,32 @@ async def update_embedding_model(
form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user) form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user)
): ):
log.debug(f"form_data.embedding_model: {form_data.embedding_model}")
log.info( log.info(
f"Updating embedding model: {app.state.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}" f"Updating embedding model: {app.state.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}"
) )
embedding_model_path = None
sentence_transformer_ef = None
try: try:
embedding_model_path = get_embedding_model_path(form_data.embedding_model, True) sentence_transformer_ef = (
if app.state.RAG_EMBEDDING_MODEL_PATH != embedding_model_path: embedding_functions.SentenceTransformerEmbeddingFunction(
sentence_transformer_ef = ( model_name=get_embedding_model_path(form_data.embedding_model, True),
embedding_functions.SentenceTransformerEmbeddingFunction( device=DEVICE_TYPE,
model_name=embedding_model_path,
device=DEVICE_TYPE,
)
) )
except Exception as e:
log.exception(f"Problem updating embedding model: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=ERROR_MESSAGES.DEFAULT(e),
) )
if sentence_transformer_ef:
app.state.RAG_EMBEDDING_MODEL = form_data.embedding_model app.state.RAG_EMBEDDING_MODEL = form_data.embedding_model
app.state.RAG_EMBEDDING_MODEL_PATH = embedding_model_path
app.state.sentence_transformer_ef = sentence_transformer_ef app.state.sentence_transformer_ef = sentence_transformer_ef
log.debug( return {
f"app.state.RAG_EMBEDDING_MODEL_PATH: {app.state.RAG_EMBEDDING_MODEL_PATH}" "status": True,
) "embedding_model": app.state.RAG_EMBEDDING_MODEL,
}
return { except Exception as e:
"status": sentence_transformer_ef != None, log.exception(f"Problem updating embedding model: {e}")
"embedding_model": app.state.RAG_EMBEDDING_MODEL, raise HTTPException(
"embedding_model_path": app.state.RAG_EMBEDDING_MODEL_PATH, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
} detail=ERROR_MESSAGES.DEFAULT(e),
)
@app.get("/config") @app.get("/config")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment