gpt3.py 1.57 KB
Newer Older
Jason Phang's avatar
gpt3  
Jason Phang committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
import os
import openai
import transformers
from ..base import LM
from .. import utils
from . import MODEL_REGISTRY


@MODEL_REGISTRY.register("gpt3")
class GPT3LM(LM):
    """Language model backed by the OpenAI Completions API.

    Text is tokenized locally with the GPT-2 BPE tokenizer. Those token ids are
    what ``logprob_of`` sends to the API, so the context/continuation split
    point is known exactly. (Assumes the engine's vocabulary matches GPT-2's
    BPE -- TODO confirm for every engine used.)
    """

    def __init__(self, engine):
        """Create a client for an OpenAI engine.

        Args:
            engine: Name of the OpenAI engine to query (e.g. ``"davinci"``).

        Raises:
            KeyError: If the ``OPENAI_API_SECRET_KEY`` environment variable is
                not set.
        """
        self.engine = engine
        self.tokenizer = transformers.GPT2Tokenizer.from_pretrained('gpt2')
        # Read from environment variable OPENAI_API_SECRET_KEY.
        # NOTE: this assigns to the ``openai`` module itself, so the key is
        # shared process-wide rather than scoped to this instance.
        openai.api_key = os.environ["OPENAI_API_SECRET_KEY"]

    @classmethod
    def create_from_args(cls, arg_string):
        """Build an instance from an argument string.

        ``arg_string`` is parsed by ``utils.simple_parse_args_string``; only the
        ``engine`` key is read, defaulting to ``"davinci"``.
        """
        args = utils.simple_parse_args_string(arg_string)
        return cls(engine=args.get("engine", "davinci"))

    def generate(self, context, max_gen_length):
        """Generate a continuation of ``context`` at temperature 0.0.

        Args:
            context: Prompt text.
            max_gen_length: Maximum number of tokens the API may generate.

        Returns:
            The ``text`` of the first choice in the API response.
        """
        response = openai.Completion.create(
            engine=self.engine,
            prompt=context,
            max_tokens=max_gen_length,
            temperature=0.0,
        )
        return response.choices[0]["text"]

    def logprob_of(self, context, continuation):
        """Return the summed log-probability of ``continuation`` given ``context``.

        Args:
            context: Conditioning text. It must be non-empty whenever
                ``continuation`` is non-empty, because the API reports no
                log-probability (``None``) for the first prompt token.
            continuation: Text to score.

        Returns:
            The sum of the per-token log-probabilities of the continuation
            tokens, or ``0.0`` when ``continuation`` is empty.

        Raises:
            ValueError: If ``context`` is empty while ``continuation`` is not.
        """
        context_tokens = self.tokenizer.encode(context)
        continuation_tokens = self.tokenizer.encode(continuation)

        if not continuation_tokens:
            # log P("") == 0. Returning early also avoids the old
            # ``logprobs[-0:]`` slice, which selected the entire prompt.
            return 0.0
        if not context_tokens:
            raise ValueError(
                "context must be non-empty: the API returns no log-probability "
                "for the first prompt token"
            )

        response = openai.Completion.create(
            engine=self.engine,
            # Send token ids rather than text. Re-tokenizing the concatenated
            # string can merge BPE tokens across the context/continuation
            # boundary, which would misalign the slice below.
            prompt=context_tokens + continuation_tokens,
            # echo=True makes the API return logprobs for the prompt tokens;
            # max_tokens=0 means nothing new is generated.
            echo=True,
            max_tokens=0,
            temperature=0.0,
            logprobs=0,
        )
        token_logprobs = response.choices[0]["logprobs"]["token_logprobs"]
        # Everything after the context tokens belongs to the continuation.
        return sum(token_logprobs[len(context_tokens):])