Commit bc3dd34d authored by Jannik Streidl's avatar Jannik Streidl
Browse files

collection query fix

parent 1846c1e8
...@@ -29,11 +29,13 @@ from langchain_community.document_loaders import ( ...@@ -29,11 +29,13 @@ from langchain_community.document_loaders import (
from langchain.text_splitter import RecursiveCharacterTextSplitter from langchain.text_splitter import RecursiveCharacterTextSplitter
from pydantic import BaseModel from pydantic import BaseModel
from typing import Optional from typing import Optional
import uuid import uuid
from utils.misc import calculate_sha256, calculate_sha256_string from utils.misc import calculate_sha256, calculate_sha256_string
from utils.utils import get_current_user, get_admin_user from utils.utils import get_current_user, get_admin_user
from config import UPLOAD_DIR, SENTENCE_TRANSFORMER_EMBED_MODEL, CHROMA_CLIENT, CHUNK_SIZE, CHUNK_OVERLAP from config import UPLOAD_DIR, SENTENCE_TRANSFORMER_EMBED_MODEL, CHROMA_CLIENT, CHUNK_SIZE, CHUNK_OVERLAP
...@@ -113,12 +115,12 @@ def query_doc( ...@@ -113,12 +115,12 @@ def query_doc(
# if you use docker use the model from the environment variable # if you use docker use the model from the environment variable
collection = CHROMA_CLIENT.get_collection( collection = CHROMA_CLIENT.get_collection(
name=form_data.collection_name, name=form_data.collection_name,
embedding_function=sentence_transformer_ef embedding_function=sentence_transformer_ef,
) )
else: else:
# for local development use the default model # for local development use the default model
collection = CHROMA_CLIENT.get_collection( collection = CHROMA_CLIENT.get_collection(
name=form_data.collection_name, name=form_data.collection_name,
) )
result = collection.query(query_texts=[form_data.query], n_results=form_data.k) result = collection.query(query_texts=[form_data.query], n_results=form_data.k)
return result return result
...@@ -191,16 +193,16 @@ def query_collection( ...@@ -191,16 +193,16 @@ def query_collection(
for collection_name in form_data.collection_names: for collection_name in form_data.collection_names:
try: try:
if 'DOCKER_SENTENCE_TRANSFORMER_EMBED_MODEL' in os.environ: if 'DOCKER_SENTENCE_TRANSFORMER_EMBED_MODEL' in os.environ:
# if you use docker use the model from the environment variable # if you use docker use the model from the environment variable
collection = CHROMA_CLIENT.get_collection( collection = CHROMA_CLIENT.get_collection(
name=form_data.collection_name, name=collection_name,
embedding_function=sentence_transformer_ef embedding_function=sentence_transformer_ef,
) )
else: else:
# for local development use the default model # for local development use the default model
collection = CHROMA_CLIENT.get_collection( collection = CHROMA_CLIENT.get_collection(
name=form_data.collection_name, name=collection_name,
) )
result = collection.query( result = collection.query(
query_texts=[form_data.query], n_results=form_data.k query_texts=[form_data.query], n_results=form_data.k
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment