Commit c55e8237 authored by Leo Gao's avatar Leo Gao
Browse files

Get rid of annoying logging

parent d5cd9655
...@@ -12,6 +12,7 @@ class GPT2LM(LM): ...@@ -12,6 +12,7 @@ class GPT2LM(LM):
self.gpt2 = transformers.GPT2LMHeadModel.from_pretrained('gpt2').to(self.device) self.gpt2 = transformers.GPT2LMHeadModel.from_pretrained('gpt2').to(self.device)
self.gpt2.eval() self.gpt2.eval()
self.tokenizer = transformers.GPT2TokenizerFast.from_pretrained('gpt2') self.tokenizer = transformers.GPT2TokenizerFast.from_pretrained('gpt2')
self.tokenizer.pad_token = "<|endoftext|>"
@classmethod @classmethod
def create_from_arg_string(cls, arg_string): def create_from_arg_string(cls, arg_string):
......
...@@ -38,6 +38,9 @@ class GPT3LM(LM): ...@@ -38,6 +38,9 @@ class GPT3LM(LM):
import openai import openai
self.engine = engine self.engine = engine
self.tokenizer = transformers.GPT2TokenizerFast.from_pretrained('gpt2') self.tokenizer = transformers.GPT2TokenizerFast.from_pretrained('gpt2')
# to make the annoying "Using pad_token, but it is not set yet." error go away
self.tokenizer.pad_token = "<|endoftext|>"
self.truncate = truncate self.truncate = truncate
# Read from environment variable OPENAI_API_SECRET_KEY # Read from environment variable OPENAI_API_SECRET_KEY
...@@ -50,11 +53,12 @@ class GPT3LM(LM): ...@@ -50,11 +53,12 @@ class GPT3LM(LM):
def loglikelihood(self, requests): def loglikelihood(self, requests):
import openai import openai
for chunk in tqdm(utils.chunks(requests, self.REQ_CHUNK_SIZE)): res = []
for chunk in tqdm(list(utils.chunks(requests, self.REQ_CHUNK_SIZE))):
inps = [] inps = []
ctxlens = [] ctxlens = []
for context, continuation in chunk: for context, continuation in chunk:
print(context)
context_enc = self.tokenizer.encode(context) context_enc = self.tokenizer.encode(context)
continuation_enc = self.tokenizer.encode(continuation) continuation_enc = self.tokenizer.encode(continuation)
inp = (context_enc + continuation_enc)[-self.MAX_LENGTH:] inp = (context_enc + continuation_enc)[-self.MAX_LENGTH:]
......
...@@ -4,9 +4,11 @@ import numpy as np ...@@ -4,9 +4,11 @@ import numpy as np
import random import random
import itertools import itertools
import collections import collections
import logging
from lm_eval import models, tasks, evaluator, base from lm_eval import models, tasks, evaluator, base
logging.getLogger("openai").setLevel(logging.WARNING)
def parse_args(): def parse_args():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment