open-webui, commit 7b5f434a (Unverified)
Authored Jun 13, 2024 by Que Nguyen; committed by GitHub on Jun 13, 2024

Implement domain whitelisting for web search results

Parent: a382e82d
Showing 10 changed files with 63 additions and 29 deletions (+63, -29).
backend/apps/rag/main.py                 +9   -1
backend/apps/rag/search/brave.py         +5   -4
backend/apps/rag/search/duckduckgo.py    +6   -5
backend/apps/rag/search/google_pse.py    +5   -4
backend/apps/rag/search/main.py          +11  -1
backend/apps/rag/search/searxng.py       +3   -2
backend/apps/rag/search/serper.py        +5   -4
backend/apps/rag/search/serply.py        +5   -4
backend/apps/rag/search/serpstack.py     +5   -4
backend/config.py                        +9   -0
backend/apps/rag/main.py

@@ -111,6 +111,7 @@ from config import (
     YOUTUBE_LOADER_LANGUAGE,
     ENABLE_RAG_WEB_SEARCH,
     RAG_WEB_SEARCH_ENGINE,
+    RAG_WEB_SEARCH_WHITE_LIST_DOMAINS,
     SEARXNG_QUERY_URL,
     GOOGLE_PSE_API_KEY,
     GOOGLE_PSE_ENGINE_ID,

@@ -163,6 +164,7 @@ app.state.YOUTUBE_LOADER_TRANSLATION = None
 app.state.config.ENABLE_RAG_WEB_SEARCH = ENABLE_RAG_WEB_SEARCH
 app.state.config.RAG_WEB_SEARCH_ENGINE = RAG_WEB_SEARCH_ENGINE
+app.state.config.RAG_WEB_SEARCH_WHITE_LIST_DOMAINS = RAG_WEB_SEARCH_WHITE_LIST_DOMAINS
 app.state.config.SEARXNG_QUERY_URL = SEARXNG_QUERY_URL
 app.state.config.GOOGLE_PSE_API_KEY = GOOGLE_PSE_API_KEY

@@ -768,6 +770,7 @@ def search_web(engine: str, query: str) -> list[SearchResult]:
                 app.state.config.SEARXNG_QUERY_URL,
                 query,
                 app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
+                app.state.config.RAG_WEB_SEARCH_WHITE_LIST_DOMAINS
             )
         else:
             raise Exception("No SEARXNG_QUERY_URL found in environment variables")

@@ -781,6 +784,7 @@ def search_web(engine: str, query: str) -> list[SearchResult]:
                 app.state.config.GOOGLE_PSE_ENGINE_ID,
                 query,
                 app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
+                app.state.config.RAG_WEB_SEARCH_WHITE_LIST_DOMAINS
             )
         else:
             raise Exception(

@@ -792,6 +796,7 @@ def search_web(engine: str, query: str) -> list[SearchResult]:
                 app.state.config.BRAVE_SEARCH_API_KEY,
                 query,
                 app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
+                app.state.config.RAG_WEB_SEARCH_WHITE_LIST_DOMAINS
             )
         else:
             raise Exception("No BRAVE_SEARCH_API_KEY found in environment variables")

@@ -801,6 +806,7 @@ def search_web(engine: str, query: str) -> list[SearchResult]:
                 app.state.config.SERPSTACK_API_KEY,
                 query,
                 app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
+                app.state.config.RAG_WEB_SEARCH_WHITE_LIST_DOMAINS,
                 https_enabled=app.state.config.SERPSTACK_HTTPS,
             )
         else:

@@ -811,6 +817,7 @@ def search_web(engine: str, query: str) -> list[SearchResult]:
                 app.state.config.SERPER_API_KEY,
                 query,
                 app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
+                app.state.config.RAG_WEB_SEARCH_WHITE_LIST_DOMAINS
             )
         else:
             raise Exception("No SERPER_API_KEY found in environment variables")

@@ -820,11 +827,12 @@ def search_web(engine: str, query: str) -> list[SearchResult]:
                 app.state.config.SERPLY_API_KEY,
                 query,
                 app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
+                app.state.config.RAG_WEB_SEARCH_WHITE_LIST_DOMAINS
             )
         else:
             raise Exception("No SERPLY_API_KEY found in environment variables")
     elif engine == "duckduckgo":
-        return search_duckduckgo(query, app.state.config.RAG_WEB_SEARCH_RESULT_COUNT)
+        return search_duckduckgo(query, app.state.config.RAG_WEB_SEARCH_RESULT_COUNT, app.state.config.RAG_WEB_SEARCH_WHITE_LIST_DOMAINS)
     else:
         raise Exception("No search engine API key found in environment variables")
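Every provider branch of search_web now follows the same pattern: the configured domain whitelist is read from app.state.config and passed through to the provider function alongside the query and result count. Below is a minimal standalone sketch of one branch. The placeholder values and the import path for search_searxng are assumptions for illustration (the import is not shown in this diff), not defaults introduced by this commit.

    from typing import List

    # Assumed import path; the backend package must be on sys.path.
    from apps.rag.search.searxng import search_searxng

    # Placeholders standing in for app.state.config values.
    SEARXNG_QUERY_URL = "http://localhost:8080/search?q=<query>"
    RAG_WEB_SEARCH_RESULT_COUNT = 5
    RAG_WEB_SEARCH_WHITE_LIST_DOMAINS: List[str] = ["example.com"]

    def search_web(engine: str, query: str):
        if engine == "searxng":
            if SEARXNG_QUERY_URL:
                # The whitelist is threaded through as the new trailing argument.
                return search_searxng(
                    SEARXNG_QUERY_URL,
                    query,
                    RAG_WEB_SEARCH_RESULT_COUNT,
                    RAG_WEB_SEARCH_WHITE_LIST_DOMAINS,
                )
            else:
                raise Exception("No SEARXNG_QUERY_URL found in environment variables")
        else:
            raise Exception("No search engine API key found in environment variables")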
backend/apps/rag/search/brave.py

 import logging
+from typing import List

 import requests

-from apps.rag.search.main import SearchResult
+from apps.rag.search.main import SearchResult, filter_by_whitelist
 from config import SRC_LOG_LEVELS

 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["RAG"])


-def search_brave(api_key: str, query: str, count: int) -> list[SearchResult]:
+def search_brave(api_key: str, query: str, whitelist: List[str], count: int) -> list[SearchResult]:
     """Search using Brave's Search API and return the results as a list of SearchResult objects.

     Args:

@@ -29,9 +29,10 @@ def search_brave(api_key: str, query: str, count: int) -> list[SearchResult]:
     json_response = response.json()
     results = json_response.get("web", {}).get("results", [])
+    filtered_results = filter_by_whitelist(results, whitelist)
     return [
         SearchResult(
             link=result["url"], title=result.get("title"), snippet=result.get("snippet")
         )
-        for result in results[:count]
+        for result in filtered_results[:count]
     ]
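One detail of the rewritten comprehension is ordering: filter_by_whitelist runs over the full result list and the [:count] slice is applied afterwards, so up to count whitelisted results are returned rather than only the whitelisted subset of the top count. A toy illustration of the difference, using dictionaries with just the "url" key the filter looks at and assuming the backend package is importable:

    from apps.rag.search.main import filter_by_whitelist

    results = [
        {"url": "https://example.com/a"},
        {"url": "https://other.org/b"},
        {"url": "https://example.com/c"},
    ]
    whitelist, count = ["example.com"], 2

    # Filter first, then slice (what brave.py now does): two example.com results.
    print(filter_by_whitelist(results, whitelist)[:count])
    # Slice first, then filter: only one example.com result would survive.
    print(filter_by_whitelist(results[:count], whitelist))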
backend/apps/rag/search/duckduckgo.py

 import logging
-from apps.rag.search.main import SearchResult
+from typing import List
+from apps.rag.search.main import SearchResult, filter_by_whitelist
 from duckduckgo_search import DDGS
 from config import SRC_LOG_LEVELS

@@ -8,7 +8,7 @@ log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["RAG"])


-def search_duckduckgo(query: str, count: int) -> list[SearchResult]:
+def search_duckduckgo(query: str, count: int, whitelist: List[str]) -> list[SearchResult]:
     """
     Search using DuckDuckGo's Search API and return the results as a list of SearchResult objects.

     Args:

@@ -41,6 +41,7 @@ def search_duckduckgo(query: str, count: int) -> list[SearchResult]:
                 snippet=result.get("body"),
             )
         )
-    print(results)
+    # print(results)
+    filtered_results = filter_by_whitelist(results, whitelist)
     # Return the list of search results
-    return results
+    return filtered_results
backend/apps/rag/search/google_pse.py

 import json
 import logging
+from typing import List

 import requests

-from apps.rag.search.main import SearchResult
+from apps.rag.search.main import SearchResult, filter_by_whitelist
 from config import SRC_LOG_LEVELS

 log = logging.getLogger(__name__)

@@ -11,7 +11,7 @@ log.setLevel(SRC_LOG_LEVELS["RAG"])

 def search_google_pse(
-    api_key: str, search_engine_id: str, query: str, count: int
+    api_key: str, search_engine_id: str, query: str, count: int, whitelist: List[str]
 ) -> list[SearchResult]:
     """Search using Google's Programmable Search Engine API and return the results as a list of SearchResult objects.

@@ -35,11 +35,12 @@ def search_google_pse(
     json_response = response.json()
     results = json_response.get("items", [])
+    filtered_results = filter_by_whitelist(results, whitelist)
     return [
         SearchResult(
             link=result["link"],
             title=result.get("title"),
             snippet=result.get("snippet"),
         )
-        for result in results
+        for result in filtered_results
     ]
backend/apps/rag/search/main.py

 from typing import Optional
+from urllib.parse import urlparse

 from pydantic import BaseModel


+def filter_by_whitelist(results, whitelist):
+    if not whitelist:
+        return results
+    filtered_results = []
+    for result in results:
+        domain = urlparse(result["url"]).netloc
+        if any(domain.endswith(whitelisted_domain) for whitelisted_domain in whitelist):
+            filtered_results.append(result)
+    return filtered_results


 class SearchResult(BaseModel):
     link: str
     title: Optional[str]
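filter_by_whitelist matches the netloc of each result's "url" with endswith, so subdomains of a whitelisted domain pass as well, and an empty whitelist leaves the results untouched. A quick check of that behaviour, assuming the backend package is importable:

    from apps.rag.search.main import filter_by_whitelist

    results = [
        {"url": "https://docs.example.com/guide"},   # subdomain of example.com
        {"url": "https://en.wikipedia.org/wiki/RAG"},
    ]

    # Suffix match on the host: keeps only the docs.example.com entry.
    print(filter_by_whitelist(results, ["example.com"]))
    # Empty whitelist disables filtering and returns everything unchanged.
    print(filter_by_whitelist(results, []))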
backend/apps/rag/search/searxng.py

@@ -11,7 +11,7 @@ log.setLevel(SRC_LOG_LEVELS["RAG"])

 def search_searxng(
-    query_url: str, query: str, count: int, **kwargs
+    query_url: str, query: str, count: int, whitelist: List[str], **kwargs
 ) -> List[SearchResult]:
     """
     Search a SearXNG instance for a given query and return the results as a list of SearchResult objects.

@@ -78,9 +78,10 @@ def search_searxng(
     json_response = response.json()
     results = json_response.get("results", [])
     sorted_results = sorted(results, key=lambda x: x.get("score", 0), reverse=True)
+    filtered_results = filter_by_whitelist(sorted_results, whitelist)
     return [
         SearchResult(
             link=result["url"], title=result.get("title"), snippet=result.get("content")
         )
-        for result in sorted_results[:count]
+        for result in filtered_results[:count]
     ]
backend/apps/rag/search/serper.py

 import json
 import logging
+from typing import List

 import requests

-from apps.rag.search.main import SearchResult
+from apps.rag.search.main import SearchResult, filter_by_whitelist
 from config import SRC_LOG_LEVELS

 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["RAG"])


-def search_serper(api_key: str, query: str, count: int) -> list[SearchResult]:
+def search_serper(api_key: str, query: str, count: int, whitelist: List[str]) -> list[SearchResult]:
     """Search using serper.dev's API and return the results as a list of SearchResult objects.

     Args:

@@ -29,11 +29,12 @@ def search_serper(api_key: str, query: str, count: int) -> list[SearchResult]:
     results = sorted(
         json_response.get("organic", []), key=lambda x: x.get("position", 0)
     )
+    filtered_results = filter_by_whitelist(results, whitelist)
     return [
         SearchResult(
             link=result["link"],
             title=result.get("title"),
             snippet=result.get("description"),
         )
-        for result in results[:count]
+        for result in filtered_results[:count]
     ]
backend/apps/rag/search/serply.py

 import json
 import logging
+from typing import List

 import requests
 from urllib.parse import urlencode

-from apps.rag.search.main import SearchResult
+from apps.rag.search.main import SearchResult, filter_by_whitelist
 from config import SRC_LOG_LEVELS

 log = logging.getLogger(__name__)

@@ -15,6 +15,7 @@ def search_serply(
     api_key: str,
     query: str,
     count: int,
+    whitelist: List[str],
     hl: str = "us",
     limit: int = 10,
     device_type: str = "desktop",

@@ -57,12 +58,12 @@ def search_serply(
     results = sorted(
         json_response.get("results", []), key=lambda x: x.get("realPosition", 0)
     )
+    filtered_results = filter_by_whitelist(results, whitelist)
     return [
         SearchResult(
             link=result["link"],
             title=result.get("title"),
             snippet=result.get("description"),
         )
-        for result in results[:count]
+        for result in filtered_results[:count]
     ]
backend/apps/rag/search/serpstack.py

 import json
 import logging
+from typing import List

 import requests

-from apps.rag.search.main import SearchResult
+from apps.rag.search.main import SearchResult, filter_by_whitelist
 from config import SRC_LOG_LEVELS

 log = logging.getLogger(__name__)

@@ -11,7 +11,7 @@ log.setLevel(SRC_LOG_LEVELS["RAG"])

 def search_serpstack(
-    api_key: str, query: str, count: int, https_enabled: bool = True
+    api_key: str, query: str, count: int, whitelist: List[str], https_enabled: bool = True
 ) -> list[SearchResult]:
     """Search using serpstack.com's and return the results as a list of SearchResult objects.

@@ -35,9 +35,10 @@ def search_serpstack(
     results = sorted(
         json_response.get("organic_results", []), key=lambda x: x.get("position", 0)
     )
+    filtered_results = filter_by_whitelist(results, whitelist)
     return [
         SearchResult(
             link=result["url"], title=result.get("title"), snippet=result.get("snippet")
         )
-        for result in results[:count]
+        for result in filtered_results[:count]
     ]
backend/config.py

@@ -894,6 +894,15 @@ RAG_WEB_SEARCH_ENGINE = PersistentConfig(
     os.getenv("RAG_WEB_SEARCH_ENGINE", ""),
 )

+RAG_WEB_SEARCH_WHITE_LIST_DOMAINS = PersistentConfig(
+    "RAG_WEB_SEARCH_WHITE_LIST_DOMAINS",
+    "rag.rag_web_search_white_list_domains",
+    [
+        # "example.com",
+        # "anotherdomain.com",
+    ],
+)
+
 SEARXNG_QUERY_URL = PersistentConfig(
     "SEARXNG_QUERY_URL",
     "rag.web.search.searxng_query_url",
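The new config entry defaults to an empty list, so whitelisting stays off until domains are added; the commented entries show the expected shape. A sketch of enabling it by filling in the default list in backend/config.py (this diff does not show whether the value can also be supplied through an environment variable, so the domains below are illustrative):

    RAG_WEB_SEARCH_WHITE_LIST_DOMAINS = PersistentConfig(
        "RAG_WEB_SEARCH_WHITE_LIST_DOMAINS",
        "rag.rag_web_search_white_list_domains",
        [
            "example.com",        # results from example.com and its subdomains are kept
            "anotherdomain.com",
        ],
    )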