import os
import transformers

from lm_eval.base import LM
from lm_eval import utils


class GPT3LM(LM):
    """LM adapter that scores and generates text through the OpenAI Completion API.

    Uses a local GPT-2 tokenizer to count tokens for truncation and for
    locating the continuation inside the echoed logprobs.
    """

    # Maximum context window, in tokens, of the targeted OpenAI engines.
    MAX_LENGTH = 2048

    def __init__(self, engine, truncate=False):
        """

        :param engine: str
            OpenAI API engine (e.g. davinci)
        :param truncate: bool
            Truncate input if too long (if False and input is too long, throw error)
        """
        import openai

        self.engine = engine
        # Local tokenizer used for token counting/truncation; assumes it agrees
        # closely enough with the server-side tokenization — TODO confirm.
        self.tokenizer = transformers.GPT2Tokenizer.from_pretrained('gpt2')
        self.truncate = truncate

        # Read from environment variable OPENAI_API_SECRET_KEY
        # (raises KeyError if unset — deliberate fail-fast).
        openai.api_key = os.environ["OPENAI_API_SECRET_KEY"]

    @classmethod
    def create_from_arg_string(cls, arg_string):
        """Build a GPT3LM from a key=value argument string (e.g. "engine=davinci")."""
        args = utils.simple_parse_args_string(arg_string)
        return cls(engine=args.get("engine", "davinci"))

    def generate(self, context, max_gen_length):
        """Greedily generate up to ``max_gen_length`` tokens continuing ``context``.

        :param context: str
            Prompt text.
        :param max_gen_length: int
            Maximum number of tokens to generate.
        :return: str
            The generated completion text.
        """
        import openai

        if self.truncate:
            # Reserve room for the tokens we are about to generate.
            prompt = self.smart_truncate(context, buffer=max_gen_length)
        else:
            prompt = context

        response = openai.Completion.create(
            engine=self.engine,
            prompt=prompt,
            max_tokens=max_gen_length,
            temperature=0.0,  # deterministic (greedy) decoding
        )
        return response.choices[0]["text"]

    def loglikelihood(self, context, continuation):
        """Return the summed log-probability of ``continuation`` given ``context``.

        :param context: str
        :param continuation: str
        :return: float
            Sum of per-token log-probabilities over the continuation tokens.
        """
        import openai

        full_text = context + continuation
        full_text_length = len(self.tokenizer.tokenize(full_text))
        context_length = len(self.tokenizer.tokenize(context))
        continuation_length = len(self.tokenizer.tokenize(continuation))
        # BPE tokenization is not guaranteed additive across the boundary;
        # fail loudly if context/continuation tokens merged.
        assert full_text_length == context_length + continuation_length

        # BUG FIX: with an empty continuation, logprobs[-0:] below would slice
        # the ENTIRE list and sum every token's logprob. An empty continuation
        # has log-probability 0 by definition.
        if continuation_length == 0:
            return 0.0

        if self.truncate:
            prompt = self.smart_truncate(full_text, buffer=0)
        else:
            prompt = full_text

        response = openai.Completion.create(
            engine=self.engine,
            prompt=prompt,
            echo=True,        # echo the prompt so its tokens are scored too
            max_tokens=0,     # score only; generate nothing
            temperature=0.0,
            logprobs=0,
        )
        logprobs = response.choices[0]["logprobs"]["token_logprobs"]
        continuation_logprobs = logprobs[-continuation_length:]
        return sum(continuation_logprobs)

    def smart_truncate(self, string, buffer=1):
        """Keep only the last ``MAX_LENGTH - 1 - buffer`` tokens of ``string``.

        :param string: str
            Text to truncate from the left.
        :param buffer: int
            Tokens to reserve (e.g. for generation).
        :return: str
            The detokenized suffix that fits the model's window.
        """
        tokens = self.tokenizer.tokenize(string)
        available_length = self.MAX_LENGTH - 1 - buffer  # OpenAI adds 1 token
        # BUG FIX: when available_length <= 0, tokens[-available_length:] keeps
        # the wrong span (tokens[-0:] is the WHOLE list); keep nothing instead.
        kept_tokens = tokens[-available_length:] if available_length > 0 else []
        new_string = self.tokenizer.convert_tokens_to_string(kept_tokens)
        return new_string

    def num_tokens(self, string):
        """Return the number of GPT-2 tokens in ``string``."""
        return len(self.tokenizer.tokenize(string))