"src/git@developer.sourcefind.cn:chenpangpang/open-webui.git" did not exist on "3933db2c91e635da52a28a9e7e2927f551b2fee6"
Commit febab588 authored by Timothy J. Baek
Browse files

feat: memory integration

parent 2638ae6a
...@@ -71,7 +71,7 @@ class QueryMemoryForm(BaseModel): ...@@ -71,7 +71,7 @@ class QueryMemoryForm(BaseModel):
content: str content: str
@router.post("/query", response_model=Optional[MemoryModel]) @router.post("/query")
async def query_memory( async def query_memory(
request: Request, form_data: QueryMemoryForm, user=Depends(get_verified_user) request: Request, form_data: QueryMemoryForm, user=Depends(get_verified_user)
): ):
......
...@@ -26,8 +26,8 @@ ...@@ -26,8 +26,8 @@
if (res) { if (res) {
console.log(res); console.log(res);
toast.success('Memory added successfully'); toast.success('Memory added successfully');
content = '';
show = false; show = false;
dispatch('save'); dispatch('save');
} }
......
...@@ -41,6 +41,7 @@ ...@@ -41,6 +41,7 @@
import { LITELLM_API_BASE_URL, OLLAMA_API_BASE_URL, OPENAI_API_BASE_URL } from '$lib/constants'; import { LITELLM_API_BASE_URL, OLLAMA_API_BASE_URL, OPENAI_API_BASE_URL } from '$lib/constants';
import { WEBUI_BASE_URL } from '$lib/constants'; import { WEBUI_BASE_URL } from '$lib/constants';
import { createOpenAITextStream } from '$lib/apis/streaming'; import { createOpenAITextStream } from '$lib/apis/streaming';
import { queryMemory } from '$lib/apis/memories';
const i18n = getContext('i18n'); const i18n = getContext('i18n');
...@@ -254,6 +255,26 @@ ...@@ -254,6 +255,26 @@
const sendPrompt = async (prompt, parentId, modelId = null) => { const sendPrompt = async (prompt, parentId, modelId = null) => {
const _chatId = JSON.parse(JSON.stringify($chatId)); const _chatId = JSON.parse(JSON.stringify($chatId));
let userContext = null;
if ($settings?.memory ?? false) {
const res = await queryMemory(localStorage.token, prompt).catch((error) => {
toast.error(error);
return null;
});
if (res) {
userContext = res.documents.reduce((acc, doc, index) => {
const createdAtTimestamp = res.metadatas[index][0].created_at;
const createdAtDate = new Date(createdAtTimestamp * 1000).toISOString().split('T')[0];
acc.push(`${index + 1}. [${createdAtDate}]. ${doc[0]}`);
return acc;
}, []);
console.log(userContext);
}
}
await Promise.all( await Promise.all(
(modelId ? [modelId] : atSelectedModel !== '' ? [atSelectedModel.id] : selectedModels).map( (modelId ? [modelId] : atSelectedModel !== '' ? [atSelectedModel.id] : selectedModels).map(
async (modelId) => { async (modelId) => {
...@@ -270,6 +291,7 @@ ...@@ -270,6 +291,7 @@
role: 'assistant', role: 'assistant',
content: '', content: '',
model: model.id, model: model.id,
userContext: userContext,
timestamp: Math.floor(Date.now() / 1000) // Unix epoch timestamp: Math.floor(Date.now() / 1000) // Unix epoch
}; };
...@@ -311,10 +333,13 @@ ...@@ -311,10 +333,13 @@
scrollToBottom(); scrollToBottom();
const messagesBody = [ const messagesBody = [
$settings.system $settings.system || responseMessage?.userContext
? { ? {
role: 'system', role: 'system',
content: $settings.system content:
$settings.system + (responseMessage?.userContext ?? null)
? `\n\nUser Context:\n${responseMessage.userContext.join('\n')}`
: ''
} }
: undefined, : undefined,
...messages ...messages
...@@ -567,10 +592,13 @@ ...@@ -567,10 +592,13 @@
model: model.id, model: model.id,
stream: true, stream: true,
messages: [ messages: [
$settings.system $settings.system || responseMessage?.userContext
? { ? {
role: 'system', role: 'system',
content: $settings.system content:
$settings.system + (responseMessage?.userContext ?? null)
? `\n\nUser Context:\n${responseMessage.userContext.join('\n')}`
: ''
} }
: undefined, : undefined,
...messages ...messages
......
...@@ -43,6 +43,7 @@ ...@@ -43,6 +43,7 @@
WEBUI_BASE_URL WEBUI_BASE_URL
} from '$lib/constants'; } from '$lib/constants';
import { createOpenAITextStream } from '$lib/apis/streaming'; import { createOpenAITextStream } from '$lib/apis/streaming';
import { queryMemory } from '$lib/apis/memories';
const i18n = getContext('i18n'); const i18n = getContext('i18n');
...@@ -260,6 +261,26 @@ ...@@ -260,6 +261,26 @@
const sendPrompt = async (prompt, parentId, modelId = null) => { const sendPrompt = async (prompt, parentId, modelId = null) => {
const _chatId = JSON.parse(JSON.stringify($chatId)); const _chatId = JSON.parse(JSON.stringify($chatId));
let userContext = null;
if ($settings?.memory ?? false) {
const res = await queryMemory(localStorage.token, prompt).catch((error) => {
toast.error(error);
return null;
});
if (res) {
userContext = res.documents.reduce((acc, doc, index) => {
const createdAtTimestamp = res.metadatas[index][0].created_at;
const createdAtDate = new Date(createdAtTimestamp * 1000).toISOString().split('T')[0];
acc.push(`${index + 1}. [${createdAtDate}]. ${doc[0]}`);
return acc;
}, []);
console.log(userContext);
}
}
await Promise.all( await Promise.all(
(modelId ? [modelId] : atSelectedModel !== '' ? [atSelectedModel.id] : selectedModels).map( (modelId ? [modelId] : atSelectedModel !== '' ? [atSelectedModel.id] : selectedModels).map(
async (modelId) => { async (modelId) => {
...@@ -317,10 +338,13 @@ ...@@ -317,10 +338,13 @@
scrollToBottom(); scrollToBottom();
const messagesBody = [ const messagesBody = [
$settings.system $settings.system || responseMessage?.userContext
? { ? {
role: 'system', role: 'system',
content: $settings.system content:
$settings.system + (responseMessage?.userContext ?? null)
? `\n\nUser Context:\n${responseMessage.userContext.join('\n')}`
: ''
} }
: undefined, : undefined,
...messages ...messages
...@@ -573,10 +597,13 @@ ...@@ -573,10 +597,13 @@
model: model.id, model: model.id,
stream: true, stream: true,
messages: [ messages: [
$settings.system $settings.system || responseMessage?.userContext
? { ? {
role: 'system', role: 'system',
content: $settings.system content:
$settings.system + (responseMessage?.userContext ?? null)
? `\n\nUser Context:\n${responseMessage.userContext.join('\n')}`
: ''
} }
: undefined, : undefined,
...messages ...messages
...@@ -705,6 +732,7 @@ ...@@ -705,6 +732,7 @@
} catch (error) { } catch (error) {
await handleOpenAIError(error, null, model, responseMessage); await handleOpenAIError(error, null, model, responseMessage);
} }
messages = messages;
stopResponseFlag = false; stopResponseFlag = false;
await tick(); await tick();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment