reka.py 1.85 KB
Newer Older
luopl's avatar
luopl committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
from vlmeval.smp import *
from vlmeval.api.base import BaseAPI
from time import sleep
import mimetypes


class Reka_Wrapper(BaseAPI):

    is_api: bool = True
    INTERLEAVE: bool = False

    def __init__(self,
                 model: str = 'reka-flash-20240226',
                 key: str = None,
                 retry: int = 10,
                 wait: int = 3,
                 system_prompt: str = None,
                 verbose: bool = True,
                 temperature: float = 0,
                 max_tokens: int = 1024,
                 **kwargs):

        try:
            import reka
        except ImportError:
            raise ImportError('Please install reka by running "pip install reka-api"')

        self.model = model
        default_kwargs = dict(temperature=temperature, request_output_len=max_tokens)
        default_kwargs.update(kwargs)
        self.kwargs = default_kwargs
        if key is not None:
            self.key = key
        else:
            self.key = os.environ.get('REKA_API_KEY', '')
        super().__init__(retry=retry, wait=wait, verbose=verbose, system_prompt=system_prompt, **kwargs)

    def generate_inner(self, inputs, **kwargs) -> str:
        import reka
        reka.API_KEY = self.key
        dataset = kwargs.pop('dataset', None)
        prompt, image_path = self.message_to_promptimg(inputs, dataset=dataset)
        image_b64 = encode_image_file_to_base64(image_path)

        response = reka.chat(
            model_name=self.model,
            human=prompt,
            media_url=f'data:image/jpeg;base64,{image_b64}',
            **self.kwargs)

        try:
            return 0, response['text'], response
        except Exception as err:
            return -1, self.fail_msg + str(err), response


class Reka(Reka_Wrapper):

    def generate(self, message, dataset=None):
        return super(Reka_Wrapper, self).generate(message)