gpt3.py 1.53 KB
Newer Older
Jason Phang's avatar
gpt3  
Jason Phang committed
1
2
import os
import transformers
Jason Phang's avatar
lib  
Jason Phang committed
3
4
from lm_eval.base import LM
from lm_eval import utils
Jason Phang's avatar
gpt3  
Jason Phang committed
5
6
7


class GPT3LM(LM):
    """LM adapter that scores text with the OpenAI Completion API.

    Log-likelihoods are obtained by sending the tokenized context+continuation
    as the prompt with ``echo=True`` and ``max_tokens=0``, then reading the
    per-token log-probabilities the API returns for the echoed prompt.
    Tokenization is done locally with the Hugging Face GPT-2 fast tokenizer.
    """

    # Maximum number of prompt tokens sent to the API in one request.
    MAX_LENGTH = 2048

    def __init__(self, engine, truncate=False):
        """

        :param engine: str
            OpenAI API engine (e.g. davinci)
        :param truncate: bool
            Truncate input if too long (if False and input is too long, throw error)
        :raises KeyError:
            if the OPENAI_API_SECRET_KEY environment variable is not set
        """
        import openai

        self.engine = engine
        self.tokenizer = transformers.GPT2TokenizerFast.from_pretrained('gpt2')
        self.truncate = truncate

        # Read from environment variable OPENAI_API_SECRET_KEY
        openai.api_key = os.environ["OPENAI_API_SECRET_KEY"]

    @classmethod
    def create_from_arg_string(cls, arg_string):
        """Build an instance from an argument string.

        :param arg_string: str
            Parsed by ``utils.simple_parse_args_string``. Recognized keys:
            ``engine`` (default ``"davinci"``) and ``truncate`` (default False;
            "true"/"1"/"yes", case-insensitive, enable it).
        """
        args = utils.simple_parse_args_string(arg_string)
        # The parsed value may be a string, so a plain bool() would treat
        # "False" as truthy; normalize through its text form instead.
        truncate = str(args.get("truncate", False)).strip().lower() in ("true", "1", "yes")
        return cls(engine=args.get("engine", "davinci"), truncate=truncate)

    def loglikelihood(self, context, continuation):
        """Return the summed log-probability of ``continuation`` given ``context``.

        :param context: str
            Conditioning text.
        :param continuation: str
            Text whose tokens are scored.
        :return: float
            Sum of the API-reported log-probabilities of the continuation tokens.
        :raises ValueError:
            if the input exceeds MAX_LENGTH tokens and ``truncate`` is False,
            or if no context token would remain in front of the continuation.
        """
        import openai

        context_enc = self.tokenizer.encode(context)
        continuation_enc = self.tokenizer.encode(continuation)
        full_enc = context_enc + continuation_enc

        # Number of tokens that must be dropped (from the left) to fit the window.
        overflow = max(0, len(full_enc) - self.MAX_LENGTH)
        if overflow and not self.truncate:
            raise ValueError(
                "Input is {} tokens, longer than MAX_LENGTH={}; pass truncate=True "
                "to drop tokens from the start of the context.".format(
                    len(full_enc), self.MAX_LENGTH
                )
            )

        inp = full_enc[-self.MAX_LENGTH:]
        # Number of context tokens left in the prompt after left-truncation.
        ctxlen = len(context_enc) - overflow
        if ctxlen < 1:
            # The echoed first prompt token has no log-probability (None), and a
            # negative ctxlen would slice the wrong tokens below, so at least one
            # context token must precede the continuation.
            raise ValueError(
                "Need at least one context token before the continuation "
                "(context is empty or the continuation does not fit in MAX_LENGTH)."
            )

        response = openai.Completion.create(
            engine=self.engine,
            prompt=inp,
            echo=True,
            max_tokens=0, temperature=0.0,
            logprobs=0,
        )
        logprobs = response.choices[0]["logprobs"]["token_logprobs"]
        continuation_logprobs = logprobs[ctxlen:]
        return sum(continuation_logprobs)
Jason Phang's avatar
Jason Phang committed
50