Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
open-webui
Commits
0221acd1
Commit
0221acd1
authored
Mar 08, 2024
by
Timothy J. Baek
Browse files
feat: dall-e integration
parent
dd3a4b38
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
296 additions
and
64 deletions
+296
-64
backend/apps/images/main.py
backend/apps/images/main.py
+187
-50
backend/main.py
backend/main.py
+1
-0
src/lib/apis/images/index.ts
src/lib/apis/images/index.ts
+67
-0
src/lib/components/chat/Messages/ResponseMessage.svelte
src/lib/components/chat/Messages/ResponseMessage.svelte
+5
-3
src/lib/components/chat/Settings/Images.svelte
src/lib/components/chat/Settings/Images.svelte
+29
-9
src/lib/components/common/Image.svelte
src/lib/components/common/Image.svelte
+7
-2
No files found.
backend/apps/images/main.py
View file @
0221acd1
...
...
@@ -21,7 +21,16 @@ from utils.utils import (
from
utils.misc
import
calculate_sha256
from
typing
import
Optional
from
pydantic
import
BaseModel
from
config
import
AUTOMATIC1111_BASE_URL
from
pathlib
import
Path
import
uuid
import
base64
import
json
from
config
import
CACHE_DIR
,
AUTOMATIC1111_BASE_URL
IMAGE_CACHE_DIR
=
Path
(
CACHE_DIR
).
joinpath
(
"./image/generations/"
)
IMAGE_CACHE_DIR
.
mkdir
(
parents
=
True
,
exist_ok
=
True
)
app
=
FastAPI
()
app
.
add_middleware
(
...
...
@@ -32,25 +41,34 @@ app.add_middleware(
allow_headers
=
[
"*"
],
)
app
.
state
.
ENGINE
=
""
app
.
state
.
ENABLED
=
False
app
.
state
.
OPENAI_API_KEY
=
""
app
.
state
.
MODEL
=
""
app
.
state
.
AUTOMATIC1111_BASE_URL
=
AUTOMATIC1111_BASE_URL
app
.
state
.
ENABLED
=
app
.
state
.
AUTOMATIC1111_BASE_URL
!=
""
app
.
state
.
IMAGE_SIZE
=
"512x512"
app
.
state
.
IMAGE_STEPS
=
50
@
app
.
get
(
"/
enabled"
,
response_model
=
bool
)
async
def
get_
enable_status
(
request
:
Request
,
user
=
Depends
(
get_admin_user
)):
return
app
.
state
.
ENABLED
@
app
.
get
(
"/
config"
)
async
def
get_
config
(
request
:
Request
,
user
=
Depends
(
get_admin_user
)):
return
{
"engine"
:
app
.
state
.
ENGINE
,
"enabled"
:
app
.
state
.
ENABLED
}
@
app
.
get
(
"/enabled/toggle"
,
response_model
=
bool
)
async
def
toggle_enabled
(
request
:
Request
,
user
=
Depends
(
get_admin_user
)):
try
:
r
=
requests
.
head
(
app
.
state
.
AUTOMATIC1111_BASE_URL
)
app
.
state
.
ENABLED
=
not
app
.
state
.
ENABLED
return
app
.
state
.
ENABLED
except
Exception
as
e
:
raise
HTTPException
(
status_code
=
400
,
detail
=
ERROR_MESSAGES
.
DEFAULT
(
e
))
class
ConfigUpdateForm
(
BaseModel
):
engine
:
str
enabled
:
bool
@
app
.
post
(
"/config/update"
)
async
def
update_config
(
form_data
:
ConfigUpdateForm
,
user
=
Depends
(
get_admin_user
)):
app
.
state
.
ENGINE
=
form_data
.
engine
app
.
state
.
ENABLED
=
form_data
.
enabled
return
{
"engine"
:
app
.
state
.
ENGINE
,
"enabled"
:
app
.
state
.
ENABLED
}
class
UrlUpdateForm
(
BaseModel
):
...
...
@@ -58,17 +76,24 @@ class UrlUpdateForm(BaseModel):
@
app
.
get
(
"/url"
)
async
def
get_
openai
_url
(
user
=
Depends
(
get_admin_user
)):
async
def
get_
automatic1111
_url
(
user
=
Depends
(
get_admin_user
)):
return
{
"AUTOMATIC1111_BASE_URL"
:
app
.
state
.
AUTOMATIC1111_BASE_URL
}
@
app
.
post
(
"/url/update"
)
async
def
update_openai_url
(
form_data
:
UrlUpdateForm
,
user
=
Depends
(
get_admin_user
)):
async
def
update_automatic1111_url
(
form_data
:
UrlUpdateForm
,
user
=
Depends
(
get_admin_user
)
):
if
form_data
.
url
==
""
:
app
.
state
.
AUTOMATIC1111_BASE_URL
=
AUTOMATIC1111_BASE_URL
else
:
app
.
state
.
AUTOMATIC1111_BASE_URL
=
form_data
.
url
.
strip
(
"/"
)
url
=
form_data
.
url
.
strip
(
"/"
)
try
:
r
=
requests
.
head
(
url
)
app
.
state
.
AUTOMATIC1111_BASE_URL
=
url
except
Exception
as
e
:
raise
HTTPException
(
status_code
=
400
,
detail
=
ERROR_MESSAGES
.
DEFAULT
(
e
))
return
{
"AUTOMATIC1111_BASE_URL"
:
app
.
state
.
AUTOMATIC1111_BASE_URL
,
...
...
@@ -76,6 +101,30 @@ async def update_openai_url(form_data: UrlUpdateForm, user=Depends(get_admin_use
}
class
OpenAIKeyUpdateForm
(
BaseModel
):
key
:
str
@
app
.
get
(
"/key"
)
async
def
get_openai_key
(
user
=
Depends
(
get_admin_user
)):
return
{
"OPENAI_API_KEY"
:
app
.
state
.
OPENAI_API_KEY
}
@
app
.
post
(
"/key/update"
)
async
def
update_openai_key
(
form_data
:
OpenAIKeyUpdateForm
,
user
=
Depends
(
get_admin_user
)
):
if
form_data
.
key
==
""
:
raise
HTTPException
(
status_code
=
400
,
detail
=
ERROR_MESSAGES
.
API_KEY_NOT_FOUND
)
app
.
state
.
OPENAI_API_KEY
=
form_data
.
key
return
{
"OPENAI_API_KEY"
:
app
.
state
.
OPENAI_API_KEY
,
"status"
:
True
,
}
class
ImageSizeUpdateForm
(
BaseModel
):
size
:
str
...
...
@@ -132,9 +181,22 @@ async def update_image_size(
@
app
.
get
(
"/models"
)
def
get_models
(
user
=
Depends
(
get_current_user
)):
try
:
r
=
requests
.
get
(
url
=
f
"
{
app
.
state
.
AUTOMATIC1111_BASE_URL
}
/sdapi/v1/sd-models"
)
models
=
r
.
json
()
return
models
if
app
.
state
.
ENGINE
==
"openai"
:
return
[
{
"id"
:
"dall-e-2"
,
"name"
:
"DALL·E 2"
},
{
"id"
:
"dall-e-3"
,
"name"
:
"DALL·E 3"
},
]
else
:
r
=
requests
.
get
(
url
=
f
"
{
app
.
state
.
AUTOMATIC1111_BASE_URL
}
/sdapi/v1/sd-models"
)
models
=
r
.
json
()
return
list
(
map
(
lambda
model
:
{
"id"
:
model
[
"title"
],
"name"
:
model
[
"model_name"
]},
models
,
)
)
except
Exception
as
e
:
app
.
state
.
ENABLED
=
False
raise
HTTPException
(
status_code
=
400
,
detail
=
ERROR_MESSAGES
.
DEFAULT
(
e
))
...
...
@@ -143,10 +205,12 @@ def get_models(user=Depends(get_current_user)):
@
app
.
get
(
"/models/default"
)
async
def
get_default_model
(
user
=
Depends
(
get_admin_user
)):
try
:
r
=
requests
.
get
(
url
=
f
"
{
app
.
state
.
AUTOMATIC1111_BASE_URL
}
/sdapi/v1/options"
)
options
=
r
.
json
()
return
{
"model"
:
options
[
"sd_model_checkpoint"
]}
if
app
.
state
.
ENGINE
==
"openai"
:
return
{
"model"
:
app
.
state
.
MODEL
if
app
.
state
.
MODEL
else
"dall-e-2"
}
else
:
r
=
requests
.
get
(
url
=
f
"
{
app
.
state
.
AUTOMATIC1111_BASE_URL
}
/sdapi/v1/options"
)
options
=
r
.
json
()
return
{
"model"
:
options
[
"sd_model_checkpoint"
]}
except
Exception
as
e
:
app
.
state
.
ENABLED
=
False
raise
HTTPException
(
status_code
=
400
,
detail
=
ERROR_MESSAGES
.
DEFAULT
(
e
))
...
...
@@ -157,16 +221,21 @@ class UpdateModelForm(BaseModel):
def
set_model_handler
(
model
:
str
):
r
=
requests
.
get
(
url
=
f
"
{
app
.
state
.
AUTOMATIC1111_BASE_URL
}
/sdapi/v1/options"
)
options
=
r
.
json
()
if
model
!=
options
[
"sd_model_checkpoint"
]:
options
[
"sd_model_checkpoint"
]
=
model
r
=
requests
.
post
(
url
=
f
"
{
app
.
state
.
AUTOMATIC1111_BASE_URL
}
/sdapi/v1/options"
,
json
=
options
)
if
app
.
state
.
ENGINE
==
"openai"
:
app
.
state
.
MODEL
=
model
return
app
.
state
.
MODEL
else
:
r
=
requests
.
get
(
url
=
f
"
{
app
.
state
.
AUTOMATIC1111_BASE_URL
}
/sdapi/v1/options"
)
options
=
r
.
json
()
if
model
!=
options
[
"sd_model_checkpoint"
]:
options
[
"sd_model_checkpoint"
]
=
model
r
=
requests
.
post
(
url
=
f
"
{
app
.
state
.
AUTOMATIC1111_BASE_URL
}
/sdapi/v1/options"
,
json
=
options
)
return
options
return
options
@
app
.
post
(
"/models/default/update"
)
...
...
@@ -185,6 +254,24 @@ class GenerateImageForm(BaseModel):
negative_prompt
:
Optional
[
str
]
=
None
def
save_b64_image
(
b64_str
):
image_id
=
str
(
uuid
.
uuid4
())
file_path
=
IMAGE_CACHE_DIR
.
joinpath
(
f
"
{
image_id
}
.png"
)
try
:
# Split the base64 string to get the actual image data
img_data
=
base64
.
b64decode
(
b64_str
)
# Write the image data to a file
with
open
(
file_path
,
"wb"
)
as
f
:
f
.
write
(
img_data
)
return
image_id
except
Exception
as
e
:
print
(
f
"Error saving image:
{
e
}
"
)
return
None
@
app
.
post
(
"/generations"
)
def
generate_image
(
form_data
:
GenerateImageForm
,
...
...
@@ -194,32 +281,82 @@ def generate_image(
print
(
form_data
)
try
:
if
form_data
.
model
:
set_model_handler
(
form_data
.
model
)
if
app
.
state
.
ENGINE
==
"openai"
:
width
,
height
=
tuple
(
map
(
int
,
app
.
state
.
IMAGE_SIZE
.
split
(
"x"
)))
headers
=
{}
headers
[
"Authorization"
]
=
f
"Bearer
{
app
.
state
.
OPENAI_API_KEY
}
"
headers
[
"Content-Type"
]
=
"application/json"
data
=
{
"prompt"
:
form_data
.
prompt
,
"batch_size"
:
form_data
.
n
,
"width"
:
width
,
"height"
:
height
,
}
data
=
{
"model"
:
app
.
state
.
MODEL
if
app
.
state
.
MODEL
!=
""
else
"dall-e-2"
,
"prompt"
:
form_data
.
prompt
,
"n"
:
form_data
.
n
,
"size"
:
form_data
.
size
,
"response_format"
:
"b64_json"
,
}
if
app
.
state
.
IMAGE_STEPS
!=
None
:
data
[
"steps"
]
=
app
.
state
.
IMAGE_STEPS
r
=
requests
.
post
(
url
=
f
"https://api.openai.com/v1/images/generations"
,
json
=
data
,
headers
=
headers
,
)
if
form_data
.
negative_prompt
!=
None
:
data
[
"negative_prompt"
]
=
form_data
.
negative_prompt
r
.
raise_for_status
()
print
(
data
)
res
=
r
.
json
(
)
r
=
requests
.
post
(
url
=
f
"
{
app
.
state
.
AUTOMATIC1111_BASE_URL
}
/sdapi/v1/txt2img"
,
json
=
data
,
)
images
=
[]
for
image
in
res
[
"data"
]:
image_id
=
save_b64_image
(
image
[
"b64_json"
])
images
.
append
({
"url"
:
f
"/cache/image/generations/
{
image_id
}
.png"
})
file_body_path
=
IMAGE_CACHE_DIR
.
joinpath
(
f
"
{
image_id
}
.json"
)
with
open
(
file_body_path
,
"w"
)
as
f
:
json
.
dump
(
data
,
f
)
return
images
else
:
if
form_data
.
model
:
set_model_handler
(
form_data
.
model
)
width
,
height
=
tuple
(
map
(
int
,
app
.
state
.
IMAGE_SIZE
.
split
(
"x"
)))
data
=
{
"prompt"
:
form_data
.
prompt
,
"batch_size"
:
form_data
.
n
,
"width"
:
width
,
"height"
:
height
,
}
if
app
.
state
.
IMAGE_STEPS
!=
None
:
data
[
"steps"
]
=
app
.
state
.
IMAGE_STEPS
if
form_data
.
negative_prompt
!=
None
:
data
[
"negative_prompt"
]
=
form_data
.
negative_prompt
r
=
requests
.
post
(
url
=
f
"
{
app
.
state
.
AUTOMATIC1111_BASE_URL
}
/sdapi/v1/txt2img"
,
json
=
data
,
)
res
=
r
.
json
()
print
(
res
)
images
=
[]
for
image
in
res
[
"images"
]:
image_id
=
save_b64_image
(
image
)
images
.
append
({
"url"
:
f
"/cache/image/generations/
{
image_id
}
.png"
})
file_body_path
=
IMAGE_CACHE_DIR
.
joinpath
(
f
"
{
image_id
}
.json"
)
with
open
(
file_body_path
,
"w"
)
as
f
:
json
.
dump
({
**
data
,
"info"
:
res
[
"info"
]},
f
)
return
images
return
r
.
json
()
except
Exception
as
e
:
print
(
e
)
raise
HTTPException
(
status_code
=
400
,
detail
=
ERROR_MESSAGES
.
DEFAULT
(
e
))
backend/main.py
View file @
0221acd1
...
...
@@ -121,6 +121,7 @@ async def get_app_latest_release_version():
app
.
mount
(
"/static"
,
StaticFiles
(
directory
=
"static"
),
name
=
"static"
)
app
.
mount
(
"/cache"
,
StaticFiles
(
directory
=
"data/cache"
),
name
=
"cache"
)
app
.
mount
(
...
...
src/lib/apis/images/index.ts
View file @
0221acd1
...
...
@@ -72,6 +72,73 @@ export const updateImageGenerationConfig = async (
return
res
;
};
export
const
getOpenAIKey
=
async
(
token
:
string
=
''
)
=>
{
let
error
=
null
;
const
res
=
await
fetch
(
`
${
IMAGES_API_BASE_URL
}
/key`
,
{
method
:
'
GET
'
,
headers
:
{
Accept
:
'
application/json
'
,
'
Content-Type
'
:
'
application/json
'
,
...(
token
&&
{
authorization
:
`Bearer
${
token
}
`
})
}
})
.
then
(
async
(
res
)
=>
{
if
(
!
res
.
ok
)
throw
await
res
.
json
();
return
res
.
json
();
})
.
catch
((
err
)
=>
{
console
.
log
(
err
);
if
(
'
detail
'
in
err
)
{
error
=
err
.
detail
;
}
else
{
error
=
'
Server connection failed
'
;
}
return
null
;
});
if
(
error
)
{
throw
error
;
}
return
res
.
OPENAI_API_KEY
;
};
export
const
updateOpenAIKey
=
async
(
token
:
string
=
''
,
key
:
string
)
=>
{
let
error
=
null
;
const
res
=
await
fetch
(
`
${
IMAGES_API_BASE_URL
}
/key/update`
,
{
method
:
'
POST
'
,
headers
:
{
Accept
:
'
application/json
'
,
'
Content-Type
'
:
'
application/json
'
,
...(
token
&&
{
authorization
:
`Bearer
${
token
}
`
})
},
body
:
JSON
.
stringify
({
key
:
key
})
})
.
then
(
async
(
res
)
=>
{
if
(
!
res
.
ok
)
throw
await
res
.
json
();
return
res
.
json
();
})
.
catch
((
err
)
=>
{
console
.
log
(
err
);
if
(
'
detail
'
in
err
)
{
error
=
err
.
detail
;
}
else
{
error
=
'
Server connection failed
'
;
}
return
null
;
});
if
(
error
)
{
throw
error
;
}
return
res
.
OPENAI_API_KEY
;
};
export
const
getAUTOMATIC1111Url
=
async
(
token
:
string
=
''
)
=>
{
let
error
=
null
;
...
...
src/lib/components/chat/Messages/ResponseMessage.svelte
View file @
0221acd1
...
...
@@ -277,13 +277,15 @@
const generateImage = async (message) => {
generatingImage = true;
const res = await imageGenerations(localStorage.token, message.content);
const res = await imageGenerations(localStorage.token, message.content).catch((error) => {
toast.error(error);
});
console.log(res);
if (res) {
message.files = res.
images.
map((image) => ({
message.files = res.map((image) => ({
type: 'image',
url: `
data:image/png;base64,
${image}`
url: `${image
.url
}`
}));
dispatch('save', message);
...
...
src/lib/components/chat/Settings/Images.svelte
View file @
0221acd1
...
...
@@ -14,7 +14,9 @@
updateAUTOMATIC1111Url,
updateImageSize,
getImageSteps,
updateImageSteps
updateImageSteps,
getOpenAIKey,
updateOpenAIKey
} from '$lib/apis/images';
import { getBackendConfig } from '$lib/apis';
const dispatch = createEventDispatcher();
...
...
@@ -27,6 +29,7 @@
let enableImageGeneration = false;
let AUTOMATIC1111_BASE_URL = '';
let OPENAI_API_KEY = '';
let selectedModel = '';
let models = null;
...
...
@@ -97,6 +100,7 @@
enableImageGeneration = res.enabled;
}
AUTOMATIC1111_BASE_URL = await getAUTOMATIC1111Url(localStorage.token);
OPENAI_API_KEY = await getOpenAIKey(localStorage.token);
imageSize = await getImageSize(localStorage.token);
steps = await getImageSteps(localStorage.token);
...
...
@@ -112,6 +116,10 @@
class="flex flex-col h-full justify-between space-y-3 text-sm"
on:submit|preventDefault={async () => {
loading = true;
await updateOpenAIKey(localStorage.token, OPENAI_API_KEY);
await updateDefaultImageGenerationModel(localStorage.token, selectedModel);
await updateDefaultImageGenerationModel(localStorage.token, selectedModel);
await updateImageSize(localStorage.token, imageSize).catch((error) => {
toast.error(error);
...
...
@@ -156,10 +164,12 @@
on:click={() => {
if (imageGenerationEngine === '' && AUTOMATIC1111_BASE_URL === '') {
toast.error('AUTOMATIC1111 Base URL is required.');
enableImageGeneration = false;
} else {
enableImageGeneration = !enableImageGeneration;
updateImageGeneration();
}
updateImageGeneration();
}}
type="button"
>
...
...
@@ -172,21 +182,20 @@
</div>
</div>
</div>
<hr class=" dark:border-gray-700" />
{#if imageGenerationEngine === ''}
<hr class=" dark:border-gray-700" />
<div class=" mb-2.5 text-sm font-medium">AUTOMATIC1111 Base URL</div>
<div class="flex w-full">
<div class="flex-1 mr-2">
<input
class="w-full rounded py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-8
0
0 outline-none"
class="w-full rounded
-lg
py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-8
5
0 outline-none"
placeholder="Enter URL (e.g. http://127.0.0.1:7860/)"
bind:value={AUTOMATIC1111_BASE_URL}
/>
</div>
<button
class="px-3 bg-gray-200 hover:bg-gray-300 dark:bg-gray-600 dark:hover:bg-gray-700 rounded transition"
class="px-3 bg-gray-200 hover:bg-gray-300 dark:bg-gray-600 dark:hover:bg-gray-700 rounded
-lg
transition"
type="button"
on:click={() => {
// updateOllamaAPIUrlHandler();
...
...
@@ -219,6 +228,17 @@
(e.g. `sh webui.sh --api`)
</a>
</div>
{:else if imageGenerationEngine === 'openai'}
<div class=" mb-2.5 text-sm font-medium">OpenAI API Key</div>
<div class="flex w-full">
<div class="flex-1 mr-2">
<input
class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
placeholder="Enter API Key"
bind:value={OPENAI_API_KEY}
/>
</div>
</div>
{/if}
{#if enableImageGeneration}
...
...
@@ -229,7 +249,7 @@
<div class="flex w-full">
<div class="flex-1 mr-2">
<select
class="w-full rounded py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-8
0
0 outline-none"
class="w-full rounded
-lg
py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-8
5
0 outline-none"
bind:value={selectedModel}
placeholder="Select a model"
>
...
...
@@ -249,7 +269,7 @@
<div class="flex w-full">
<div class="flex-1 mr-2">
<input
class="w-full rounded py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-8
0
0 outline-none"
class="w-full rounded
-lg
py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-8
5
0 outline-none"
placeholder="Enter Image Size (e.g. 512x512)"
bind:value={imageSize}
/>
...
...
@@ -262,7 +282,7 @@
<div class="flex w-full">
<div class="flex-1 mr-2">
<input
class="w-full rounded py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-8
0
0 outline-none"
class="w-full rounded
-lg
py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-8
5
0 outline-none"
placeholder="Enter Number of Steps (e.g. 50)"
bind:value={steps}
/>
...
...
src/lib/components/common/Image.svelte
View file @
0221acd1
<script lang="ts">
import { WEBUI_BASE_URL } from '$lib/constants';
import ImagePreview from './ImagePreview.svelte';
export let src = '';
export let alt = '';
let _src = '';
$: _src = src.startsWith('/') ? `${WEBUI_BASE_URL}${src}` : src;
let showImagePreview = false;
</script>
<ImagePreview bind:show={showImagePreview}
{
src} {alt} />
<ImagePreview bind:show={showImagePreview}
src={_
src} {alt} />
<button
on:click={() => {
console.log('image preview');
showImagePreview = true;
}}
>
<img
{
src} {alt} class=" max-h-96 rounded-lg" draggable="false" />
<img
src={_
src} {alt} class=" max-h-96 rounded-lg" draggable="false" />
</button>
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment