cogview.py 751 Bytes
Newer Older
Rayyyyy's avatar
Rayyyyy committed
1
2
3
4
5
6
7
import streamlit as st
from zhipuai import ZhipuAI
from zhipuai.types.image import GeneratedImage

from .config import COGVIEW_MODEL, ZHIPU_AI_KEY
from .interface import ToolObservation

Rayyyyy's avatar
Rayyyyy committed
8

Rayyyyy's avatar
Rayyyyy committed
9
10
11
12
@st.cache_resource
def get_zhipu_client():
    """Build the ZhipuAI API client from the configured key.

    Wrapped in ``st.cache_resource`` so Streamlit constructs the client once
    and hands back the same cached object on later calls/reruns instead of
    creating a new client every time.
    """
    client = ZhipuAI(api_key=ZHIPU_AI_KEY)
    return client

Rayyyyy's avatar
Rayyyyy committed
13

Rayyyyy's avatar
Rayyyyy committed
14
15
def map_response(img: GeneratedImage):
    """Convert one CogView generation result into a ``ToolObservation``.

    The observation carries the image URL returned by the API, is tagged with
    ``content_type="image"`` and ``role_metadata="cogview_result"``, and uses a
    fixed Chinese status text meaning "CogView has generated the image and
    shown it to the user."

    Args:
        img: A single image entry from the ZhipuAI image-generation response.

    Returns:
        The ``ToolObservation`` describing that image.
    """
    observation = ToolObservation(
        role_metadata="cogview_result",
        image_url=img.url,
        text="CogView 已经生成并向用户展示了生成的图片。",
        content_type="image",
    )
    return observation

Rayyyyy's avatar
Rayyyyy committed
22

Rayyyyy's avatar
Rayyyyy committed
23
24
25
26
def tool_call(prompt: str, session_id: str) -> list[ToolObservation]:
    """Generate images for ``prompt`` with the configured CogView model.

    Args:
        prompt: Text description sent to the image-generation endpoint.
        session_id: Not used by this tool. Presumably accepted so the signature
            matches the other tools' ``tool_call`` interface (TODO confirm
            against the caller).

    Returns:
        One ``ToolObservation`` per image in the response's ``data`` list, in
        the order the API returned them.
    """
    images = get_zhipu_client().images.generations(
        model=COGVIEW_MODEL,
        prompt=prompt,
    ).data
    return [map_response(image) for image in images]