browser.py 1.67 KB
Newer Older
Rayyyyy's avatar
Rayyyyy committed
1
2
3
4
5
6
7
8
9
"""
Simple browser tool.

# Usage

Please start the backend browser server according to the instructions in the README.
"""

import re
Rayyyyy's avatar
Rayyyyy committed
10
11
12
from dataclasses import dataclass
from pprint import pprint

Rayyyyy's avatar
Rayyyyy committed
13
14
15
16
17
18
import requests
import streamlit as st

from .config import BROWSER_SERVER_URL
from .interface import ToolObservation

Rayyyyy's avatar
Rayyyyy committed
19

Rayyyyy's avatar
Rayyyyy committed
20
21
# Matches a quote marker of the form "[<digits>†<text>]": group 1 captures
# the digits and group 2 the (non-greedy) text up to the next "]".
QUOTE_REGEX = re.compile(r"\[(\d+)†(.+?)\]")

Rayyyyy's avatar
Rayyyyy committed
22

Rayyyyy's avatar
Rayyyyy committed
23
24
25
26
27
@dataclass
class Quote:
    """Reference entry holding the title and URL of a quoted source."""

    # Human-readable title of the source.
    title: str
    # Address of the source.
    url: str

Rayyyyy's avatar
Rayyyyy committed
28

Rayyyyy's avatar
Rayyyyy committed
29
30
31
32
33
34
35
36
37
# Quotes for displaying reference.
# The mapping is created in Streamlit's session state only when it is not
# already present, so an existing "quotes" entry is reused rather than reset.
if "quotes" not in st.session_state:
    st.session_state.quotes = {}

# Module-level alias to the session-state mapping; map_response() writes
# Quote entries into it, keyed by a string id.
quotes: dict[str, Quote] = st.session_state.quotes


def map_response(response: dict) -> ToolObservation:
    """Convert one browser-server response item into a ``ToolObservation``.

    As a side effect, quote metadata carried by the response is saved in the
    module-level ``quotes`` mapping so it can be displayed as a reference:

    * ``roleMetadata`` starting with ``"quote_result"``: the id is parsed
      from the quote marker that follows (see ``QUOTE_REGEX``) and the first
      entry of ``metadata["metadata_list"]`` is stored under that id.
    * ``roleMetadata`` equal to ``"browser_result"``: every entry of
      ``metadata["metadata_list"]`` is stored under its list index (as str).

    Responses whose ``roleMetadata`` is missing/empty, whose quote marker
    cannot be parsed, or whose metadata list is missing/empty are passed
    through without touching ``quotes`` instead of raising.

    Args:
        response: Decoded JSON object for a single response item. The keys
            read are ``"roleMetadata"``, ``"metadata"``, ``"contentType"``
            and ``"result"``; missing keys yield ``None``.

    Returns:
        A ``ToolObservation`` built from the response fields.

    Raises:
        KeyError: If a metadata entry that is being stored lacks ``"title"``
            or ``"url"``.
    """
    # Debug trace of the raw server response.
    print("===BROWSER_RESPONSE===")
    pprint(response)
    role_metadata = response.get("roleMetadata")
    metadata = response.get("metadata")

    # Split into the role name and the remainder. Guarding on truthiness
    # avoids AttributeError/IndexError for a missing or empty roleMetadata.
    role_parts = role_metadata.split(maxsplit=1) if role_metadata else []
    metadata_list = metadata.get("metadata_list") if metadata else None

    # Save quotes for reference
    if role_parts and role_parts[0] == "quote_result" and metadata_list:
        # Search the whole remainder rather than only the next
        # whitespace-separated token, so a marker whose text contains spaces
        # still matches; a marker that fits in one token matches identically.
        match = QUOTE_REGEX.search(role_parts[1]) if len(role_parts) > 1 else None
        if match is not None:
            first: dict[str, str] = metadata_list[0]
            quotes[match.group(1)] = Quote(first["title"], first["url"])
    elif role_metadata == "browser_result" and metadata_list:
        for i, quote in enumerate(metadata_list):
            quotes[str(i)] = Quote(quote["title"], quote["url"])

    return ToolObservation(
        content_type=response.get("contentType"),
        text=response.get("result"),
        role_metadata=role_metadata,
        metadata=metadata,
    )


def tool_call(code: str, session_id: str) -> list[ToolObservation]:
    """Send a browser action to the browser server and map its reply.

    Args:
        code: The action to execute; sent as the ``"action"`` field.
        session_id: Sent as the ``"session_id"`` field of the request.

    Returns:
        One ``ToolObservation`` per item of the JSON reply, in reply order.
    """
    payload = {"session_id": session_id, "action": code}
    reply = requests.post(BROWSER_SERVER_URL, json=payload)
    return [map_response(item) for item in reply.json()]