Commit b9da7256 authored by Yash-1511's avatar Yash-1511
Browse files

feat: add tavily web search in web search provider

parent 162643a4
...@@ -73,6 +73,7 @@ from apps.rag.search.serper import search_serper ...@@ -73,6 +73,7 @@ from apps.rag.search.serper import search_serper
from apps.rag.search.serpstack import search_serpstack from apps.rag.search.serpstack import search_serpstack
from apps.rag.search.serply import search_serply from apps.rag.search.serply import search_serply
from apps.rag.search.duckduckgo import search_duckduckgo from apps.rag.search.duckduckgo import search_duckduckgo
from apps.rag.search.tavily import search_tavily
from utils.misc import ( from utils.misc import (
calculate_sha256, calculate_sha256,
...@@ -119,6 +120,7 @@ from config import ( ...@@ -119,6 +120,7 @@ from config import (
SERPSTACK_HTTPS, SERPSTACK_HTTPS,
SERPER_API_KEY, SERPER_API_KEY,
SERPLY_API_KEY, SERPLY_API_KEY,
TAVILY_API_KEY,
RAG_WEB_SEARCH_RESULT_COUNT, RAG_WEB_SEARCH_RESULT_COUNT,
RAG_WEB_SEARCH_CONCURRENT_REQUESTS, RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
RAG_EMBEDDING_OPENAI_BATCH_SIZE, RAG_EMBEDDING_OPENAI_BATCH_SIZE,
...@@ -172,6 +174,7 @@ app.state.config.SERPSTACK_API_KEY = SERPSTACK_API_KEY ...@@ -172,6 +174,7 @@ app.state.config.SERPSTACK_API_KEY = SERPSTACK_API_KEY
app.state.config.SERPSTACK_HTTPS = SERPSTACK_HTTPS app.state.config.SERPSTACK_HTTPS = SERPSTACK_HTTPS
app.state.config.SERPER_API_KEY = SERPER_API_KEY app.state.config.SERPER_API_KEY = SERPER_API_KEY
app.state.config.SERPLY_API_KEY = SERPLY_API_KEY app.state.config.SERPLY_API_KEY = SERPLY_API_KEY
app.state.config.TAVILY_API_KEY = TAVILY_API_KEY
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = RAG_WEB_SEARCH_RESULT_COUNT app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = RAG_WEB_SEARCH_RESULT_COUNT
app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = RAG_WEB_SEARCH_CONCURRENT_REQUESTS app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = RAG_WEB_SEARCH_CONCURRENT_REQUESTS
...@@ -400,6 +403,7 @@ async def get_rag_config(user=Depends(get_admin_user)): ...@@ -400,6 +403,7 @@ async def get_rag_config(user=Depends(get_admin_user)):
"serpstack_https": app.state.config.SERPSTACK_HTTPS, "serpstack_https": app.state.config.SERPSTACK_HTTPS,
"serper_api_key": app.state.config.SERPER_API_KEY, "serper_api_key": app.state.config.SERPER_API_KEY,
"serply_api_key": app.state.config.SERPLY_API_KEY, "serply_api_key": app.state.config.SERPLY_API_KEY,
"tavily_api_key": app.state.config.TAVILY_API_KEY,
"result_count": app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, "result_count": app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
"concurrent_requests": app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS, "concurrent_requests": app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
}, },
...@@ -428,6 +432,7 @@ class WebSearchConfig(BaseModel): ...@@ -428,6 +432,7 @@ class WebSearchConfig(BaseModel):
serpstack_https: Optional[bool] = None serpstack_https: Optional[bool] = None
serper_api_key: Optional[str] = None serper_api_key: Optional[str] = None
serply_api_key: Optional[str] = None serply_api_key: Optional[str] = None
tavily_api_key: Optional[str] = None
result_count: Optional[int] = None result_count: Optional[int] = None
concurrent_requests: Optional[int] = None concurrent_requests: Optional[int] = None
...@@ -479,6 +484,7 @@ async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_ ...@@ -479,6 +484,7 @@ async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_
app.state.config.SERPSTACK_HTTPS = form_data.web.search.serpstack_https app.state.config.SERPSTACK_HTTPS = form_data.web.search.serpstack_https
app.state.config.SERPER_API_KEY = form_data.web.search.serper_api_key app.state.config.SERPER_API_KEY = form_data.web.search.serper_api_key
app.state.config.SERPLY_API_KEY = form_data.web.search.serply_api_key app.state.config.SERPLY_API_KEY = form_data.web.search.serply_api_key
app.state.config.TAVILY_API_KEY = form_data.web.search.tavily_api_key
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = form_data.web.search.result_count app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = form_data.web.search.result_count
app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = ( app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = (
form_data.web.search.concurrent_requests form_data.web.search.concurrent_requests
...@@ -508,6 +514,7 @@ async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_ ...@@ -508,6 +514,7 @@ async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_
"serpstack_https": app.state.config.SERPSTACK_HTTPS, "serpstack_https": app.state.config.SERPSTACK_HTTPS,
"serper_api_key": app.state.config.SERPER_API_KEY, "serper_api_key": app.state.config.SERPER_API_KEY,
"serply_api_key": app.state.config.SERPLY_API_KEY, "serply_api_key": app.state.config.SERPLY_API_KEY,
"tavily_api_key": app.state.config.TAVILY_API_KEY,
"result_count": app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, "result_count": app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
"concurrent_requests": app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS, "concurrent_requests": app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
}, },
...@@ -756,7 +763,7 @@ def search_web(engine: str, query: str) -> list[SearchResult]: ...@@ -756,7 +763,7 @@ def search_web(engine: str, query: str) -> list[SearchResult]:
- SERPSTACK_API_KEY - SERPSTACK_API_KEY
- SERPER_API_KEY - SERPER_API_KEY
- SERPLY_API_KEY - SERPLY_API_KEY
- TAVILY_API_KEY
Args: Args:
query (str): The query to search for query (str): The query to search for
""" """
...@@ -825,6 +832,15 @@ def search_web(engine: str, query: str) -> list[SearchResult]: ...@@ -825,6 +832,15 @@ def search_web(engine: str, query: str) -> list[SearchResult]:
raise Exception("No SERPLY_API_KEY found in environment variables") raise Exception("No SERPLY_API_KEY found in environment variables")
elif engine == "duckduckgo": elif engine == "duckduckgo":
return search_duckduckgo(query, app.state.config.RAG_WEB_SEARCH_RESULT_COUNT) return search_duckduckgo(query, app.state.config.RAG_WEB_SEARCH_RESULT_COUNT)
elif engine == "tavily":
if app.state.config.TAVILY_API_KEY:
return search_tavily(
app.state.config.TAVILY_API_KEY,
query,
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
)
else:
raise Exception("No TAVILY_API_KEY found in environment variables")
else: else:
raise Exception("No search engine API key found in environment variables") raise Exception("No search engine API key found in environment variables")
......
import logging
import requests
from apps.rag.search.main import SearchResult
from config import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
def search_tavily(api_key: str, query: str, count: int) -> list[SearchResult]:
"""Search using Tavily's Search API and return the results as a list of SearchResult objects.
Args:
api_key (str): A Tavily Search API key
query (str): The query to search for
Returns:
List[SearchResult]: A list of search results
"""
url = "https://api.tavily.com/search"
data = {"query": query, "api_key": api_key}
response = requests.post(url, json=data)
response.raise_for_status()
json_response = response.json()
raw_search_results = json_response.get("results", [])
return [
SearchResult(
link=result["url"],
title=result.get("title", ""),
snippet=result.get("content"),
)
for result in raw_search_results[:count]
]
...@@ -942,6 +942,11 @@ SERPLY_API_KEY = PersistentConfig( ...@@ -942,6 +942,11 @@ SERPLY_API_KEY = PersistentConfig(
os.getenv("SERPLY_API_KEY", ""), os.getenv("SERPLY_API_KEY", ""),
) )
TAVILY_API_KEY = PersistentConfig(
"TAVILY_API_KEY",
"rag.web.search.tavily_api_key",
os.getenv("TAVILY_API_KEY", ""),
)
RAG_WEB_SEARCH_RESULT_COUNT = PersistentConfig( RAG_WEB_SEARCH_RESULT_COUNT = PersistentConfig(
"RAG_WEB_SEARCH_RESULT_COUNT", "RAG_WEB_SEARCH_RESULT_COUNT",
......
...@@ -18,7 +18,8 @@ ...@@ -18,7 +18,8 @@
'serpstack', 'serpstack',
'serper', 'serper',
'serply', 'serply',
'duckduckgo' 'duckduckgo',
'tavily'
]; ];
let youtubeLanguage = 'en'; let youtubeLanguage = 'en';
...@@ -214,6 +215,24 @@ ...@@ -214,6 +215,24 @@
</div> </div>
</div> </div>
</div> </div>
{:else if webConfig.search.engine === 'tavily'}
<div>
<div class=" self-center text-xs font-medium mb-1">
{$i18n.t('Tavily API Key')}
</div>
<div class="flex w-full">
<div class="flex-1">
<input
class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
type="text"
placeholder={$i18n.t('Enter Tavily API Key')}
bind:value={webConfig.search.tavily_api_key}
autocomplete="off"
/>
</div>
</div>
</div>
{/if} {/if}
</div> </div>
{/if} {/if}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment