Unverified commit 567e24c9, authored by Charles Lovering, committed by GitHub

Merge pull request #10 from tttyuntian/master

Fixing stopping criteria & Modifying T0, T5 to BaseLM
parents e51880d8 3c390461
@@ -121,6 +121,11 @@ class LM(abc.ABC):
 class BaseLM(LM):
+    @property
+    @abstractmethod
+    def eot_token(self):
+        pass
+
     @property
     @abstractmethod
     def eot_token_id(self):
@@ -354,8 +359,15 @@ class BaseLM(LM):
             isinstance(max_generation_length, int) or max_generation_length is None
         )
-        until = [stopping_criteria]
+        if stopping_criteria is None:
+            until = [self.eot_token]
+        else:
+            until = [stopping_criteria]
+
+        primary_until = self.tok_encode(until[0])
+        if len(primary_until) == 0:
+            primary_until = torch.tensor([self.eot_token_id])
         context_enc = torch.tensor(
             [self.tok_encode(context)[self.max_gen_toks - self.max_length :]]
         ).to(self.device)
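Note: the new block handles two failure modes in one place. A task that supplies no stopping string now falls back to the model's end-of-text token, and a stopping string that tokenizes to zero tokens falls back to the end-of-text token id. A standalone sketch of the same fallback logic (names follow BaseLM; `resolve_until` is a hypothetical helper, not part of this commit):

    import torch

    def resolve_until(stopping_criteria, tok_encode, eot_token, eot_token_id):
        # No stopping string supplied -> stop at end-of-text instead.
        until = [eot_token] if stopping_criteria is None else [stopping_criteria]
        # Stopping string encodes to zero tokens (e.g. stripped by
        # add_special_tokens=False) -> fall back to the end-of-text id.
        primary_until = tok_encode(until[0])
        if len(primary_until) == 0:
            primary_until = torch.tensor([eot_token_id])
        return until, primary_until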
@@ -72,6 +72,10 @@ class HFLM(BaseLM):
         # if gpus > 1:
         #     self.gpt2 = nn.DataParallel(self.gpt2)

+    @property
+    def eot_token(self):
+        return self.tokenizer.eos_token
+
     @property
     def eot_token_id(self):
         # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
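For the GPT-2 tokenizer that HFLM wraps, the two properties resolve to the end-of-text marker and its id; a quick check (a sketch, assuming `transformers` is installed):

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("gpt2")
    print(tok.eos_token)     # '<|endoftext|>'  -> what eot_token returns
    print(tok.eos_token_id)  # 50256            -> what eot_token_id returns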
@@ -2,37 +2,36 @@ import transformers
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from lm_eval.base import LM
+from lm_eval.base import BaseLM
 from lm_eval import utils
 from tqdm import tqdm
 import numpy as np
 import math

-class T0LM(LM):
-    MAX_GEN_TOKS = 256
-    MAX_INP_LENGTH = 512
-    VOCAB_SIZE = 32100
-    EOT_TOKEN_ID = 1
+class T0LM(BaseLM):
+    # MAX_GEN_TOKS = 256
+    # MAX_INP_LENGTH = 512
+    # VOCAB_SIZE = 32100
+    # EOT_TOKEN_ID = 1

     def __init__(self, device='cuda', parallelize=False, pretrained='t0', batch_size=1):
         super().__init__()
         if device:
-            self.device = torch.device(device)
+            self._device = torch.device(device)
         else:
-            self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
-        print(pretrained)
+            self._device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+
         self.t0 = transformers.AutoModelForSeq2SeqLM.from_pretrained(pretrained)
         self.t0.eval()
         if parallelize == "True":
-            print(parallelize)
             self.t0.parallelize()
-            self.device = torch.device('cuda:0')
+            self._device = torch.device('cuda:0')
         else:
-            self.t0.to(self.device)
+            self.t0.to(self._device)

         self.tokenizer = transformers.AutoTokenizer.from_pretrained(pretrained)
-        self.max_length = self.MAX_INP_LENGTH
-        self.batch_size = int(batch_size)
+        # self.max_length = self.MAX_INP_LENGTH
+        self._batch_size = int(batch_size)
@@ -42,6 +41,53 @@ class T0LM(LM):
         args2 = {k: v for k, v in additional_config.items() if v is not None}
         return cls(**args, **args2)

+    @property
+    def eot_token(self):
+        return self.tokenizer.eos_token
+
+    @property
+    def eot_token_id(self):
+        # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
+        return self.tokenizer.eos_token_id
+
+    @property
+    def max_length(self):
+        return self.tokenizer.model_max_length
+
+    @property
+    def max_gen_toks(self):
+        return self.tokenizer.model_max_length
+
+    @property
+    def batch_size(self):
+        # TODO: fix multi-gpu
+        return self._batch_size  # * gpus
+
+    @property
+    def device(self):
+        # TODO: fix multi-gpu
+        return self._device
+
+    def tok_encode(self, string: str):
+        return self.tokenizer.encode(string, add_special_tokens=False)
+
+    def tok_decode(self, tokens):
+        return self.tokenizer.decode(tokens)
+
+    def _model_call(self, inputs_tok, targets_tok):
+        """
+        inputs_tok: dict of batched input tensors from the tokenizer
+        targets_tok: dict of batched target tensors; targets_tok["input_ids"]
+            is passed to the model as labels
+        returns: a model output whose .logits has shape [batch, sequence, vocab]
+        """
+        with torch.no_grad():
+            return self.t0(
+                **inputs_tok,
+                labels=targets_tok["input_ids"]
+            )
+
     def loglikelihood(self, requests):
         res = []
         for chunk in tqdm(utils.chunks(requests, self.batch_size), total=math.ceil(len(requests)/self.batch_size)):
@@ -62,7 +108,7 @@ class T0LM(LM):
         targets_tok = self.tokenizer(
             list(targets),
-            max_length=self.MAX_GEN_TOKS,
+            max_length=self.max_gen_toks,
             padding=True,
             # truncation=True,
             add_special_tokens=False,
@@ -72,11 +118,7 @@ class T0LM(LM):
         for key in targets_tok:
             targets_tok[key] = targets_tok[key][:, -(self.max_length - 1) :]

-        with torch.no_grad():
-            outputs = self.t0(
-                **inputs_tok,
-                labels=targets_tok["input_ids"]
-            )
+        outputs = self._model_call(inputs_tok, targets_tok)

         log_softmaxes = F.log_softmax(outputs.logits, dim=-1)
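The surrounding `loglikelihood` code is partly collapsed in this diff; it scores each target string under the encoder-decoder model. A condensed sketch of that scoring pattern, assuming `model` and `tokenizer` are a seq2seq pair as above (the function and variable names are illustrative, not the harness's exact code):

    import torch
    import torch.nn.functional as F

    def seq2seq_logprob(model, tokenizer, context, target, device="cpu"):
        # Tokenize context and target separately, as the diff does.
        inputs_tok = tokenizer([context], return_tensors="pt", padding=True).to(device)
        targets_tok = tokenizer([target], return_tensors="pt",
                                add_special_tokens=False).to(device)
        with torch.no_grad():
            outputs = model(**inputs_tok, labels=targets_tok["input_ids"])
        # Gather the log-probability of each realized target token.
        log_softmaxes = F.log_softmax(outputs.logits, dim=-1)       # [B, T, V]
        token_logprobs = torch.gather(
            log_softmaxes, 2, targets_tok["input_ids"].unsqueeze(-1)
        ).squeeze(-1)                                               # [B, T]
        return token_logprobs.sum().item()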
@@ -103,9 +145,6 @@ class T0LM(LM):
             res.append(answer)

         return res

-    def loglikelihood_rolling(self, requests):
-        raise NotImplementedError
-
     def _get_stopping_criteria(self, stopping_criteria_ids):
         class MultitokenEOSCriteria(transformers.StoppingCriteria):
@@ -133,29 +172,11 @@ class T0LM(LM):
             EOSCriteria(self.tokenizer.eos_token)
         ])

-    def greedy_until(self, requests):
-        res = []
-
-        for context, until in tqdm(requests):
-            if isinstance(until, str): until = [until]
-
-            context_enc = self.tokenizer(context, return_tensors="pt").to(self.device).input_ids
-
-            stopping_criteria_ids = self.tokenizer.encode(until[0])
-            stopping_criteria = self._get_stopping_criteria(stopping_criteria_ids)
-
-            cont = self.t0.generate(
-                context_enc,
-                max_length=self.MAX_GEN_TOKS,
-                stopping_criteria=stopping_criteria,
-                do_sample=False
-            )
-
-            s = self.tokenizer.decode(cont[0].tolist())
-
-            self.cache_hook.add_partial("greedy_until", (context, until), s)
-
-            res.append(s)
-
-        return res
\ No newline at end of file
+    def _model_generate(self, context, max_length, stopping_criteria_ids):
+        stopping_criteria = self._get_stopping_criteria(stopping_criteria_ids)
+        return self.t0.generate(
+            context,
+            max_length=max_length,
+            stopping_criteria=stopping_criteria,
+            do_sample=False,
+        )
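The body of `MultitokenEOSCriteria` is collapsed in this diff. A minimal sketch of such a multi-token stopping criterion, assuming it fires once the most recently generated tokens match the encoded stop sequence (an illustration under that assumption, not the committed implementation):

    import torch
    import transformers

    class MultitokenEOSCriteria(transformers.StoppingCriteria):
        """Stop generation once the last tokens equal a multi-token stop sequence."""

        def __init__(self, stop_sequence_ids):
            self.stop_sequence_ids = stop_sequence_ids

        def __call__(self, input_ids: torch.LongTensor,
                     scores: torch.FloatTensor, **kwargs) -> bool:
            n = len(self.stop_sequence_ids)
            if input_ids.shape[1] < n:
                return False
            # Compare the trailing n generated tokens against the stop sequence.
            return input_ids[0, -n:].tolist() == self.stop_sequence_ids

Per the context lines above, `_get_stopping_criteria` wraps this (together with the single-token `EOSCriteria`) in a `transformers.StoppingCriteriaList`, which is the type `generate` accepts for its `stopping_criteria` argument.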
@@ -2,39 +2,44 @@ import transformers
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from lm_eval.base import LM
+from lm_eval.base import BaseLM
 from lm_eval import utils
 from tqdm import tqdm
 import numpy as np
 import math

-class T5LM(LM):
-    MAX_GEN_TOKS = 256
-    MAX_INP_LENGTH = 512
-    VOCAB_SIZE = 32128
-    EOT_TOKEN_ID = 1
-
-    def __init__(self, device='cuda', parallelize=False, pretrained='t5', batch_size=1):
+class T5LM(BaseLM):
+    # MAX_GEN_TOKS = 256
+    # MAX_INP_LENGTH = 512
+    # VOCAB_SIZE = 32128
+    # EOT_TOKEN_ID = 1
+
+    def __init__(
+        self,
+        device='cuda',
+        parallelize=False,
+        pretrained='t5',
+        batch_size=1
+    ):
         super().__init__()
         if device:
-            self.device = torch.device(device)
+            self._device = torch.device(device)
         else:
-            self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
-        print(pretrained)
+            self._device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+
         self.t5 = transformers.AutoModelForSeq2SeqLM.from_pretrained(pretrained)
         self.t5.eval()
         if parallelize == "True":
-            print(parallelize)
             self.t5.parallelize()
-            self.device = torch.device('cuda:0')
+            self._device = torch.device('cuda:0')
         else:
-            self.t5.to(self.device)
+            self.t5.to(self._device)

-        self.tokenizer = transformers.T5TokenizerFast.from_pretrained(pretrained)
-        self.max_length = self.MAX_INP_LENGTH
-        self.batch_size = int(batch_size)
+        self.tokenizer = transformers.AutoTokenizer.from_pretrained(pretrained)
+        # self.max_length = self.MAX_INP_LENGTH
+        self._batch_size = int(batch_size)

     @classmethod
     def create_from_arg_string(cls, arg_string, additional_config={}):
@@ -42,6 +47,53 @@ class T5LM(LM):
         args2 = {k: v for k, v in additional_config.items() if v is not None}
         return cls(**args, **args2)

+    @property
+    def eot_token(self):
+        return self.tokenizer.eos_token
+
+    @property
+    def eot_token_id(self):
+        # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
+        return self.tokenizer.eos_token_id
+
+    @property
+    def max_length(self):
+        return self.tokenizer.model_max_length
+
+    @property
+    def max_gen_toks(self):
+        return self.tokenizer.model_max_length
+
+    @property
+    def batch_size(self):
+        # TODO: fix multi-gpu
+        return self._batch_size  # * gpus
+
+    @property
+    def device(self):
+        # TODO: fix multi-gpu
+        return self._device
+
+    def tok_encode(self, string: str):
+        return self.tokenizer.encode(string, add_special_tokens=False)
+
+    def tok_decode(self, tokens):
+        return self.tokenizer.decode(tokens)
+
+    def _model_call(self, inputs_tok, targets_tok):
+        """
+        inputs_tok: dict of batched input tensors from the tokenizer
+        targets_tok: dict of batched target tensors; targets_tok["input_ids"]
+            is passed to the model as labels
+        returns: a model output whose .logits has shape [batch, sequence, vocab]
+        """
+        with torch.no_grad():
+            return self.t5(
+                **inputs_tok,
+                labels=targets_tok["input_ids"]
+            )
+
     def loglikelihood(self, requests):
         res = []
         for chunk in tqdm(utils.chunks(requests, self.batch_size), total=math.ceil(len(requests)/self.batch_size)):
@@ -62,7 +114,7 @@ class T5LM(LM):
         targets_tok = self.tokenizer(
             list(targets),
-            max_length=self.MAX_GEN_TOKS,
+            max_length=self.max_gen_toks,
             padding=True,
             # truncation=True,
             add_special_tokens=False,
@@ -72,11 +124,7 @@ class T5LM(LM):
         for key in targets_tok:
             targets_tok[key] = targets_tok[key][:, -(self.max_length - 1) :]

-        with torch.no_grad():
-            outputs = self.t5(
-                **inputs_tok,
-                labels=targets_tok["input_ids"]
-            )
+        outputs = self._model_call(inputs_tok, targets_tok)

         log_softmaxes = F.log_softmax(outputs.logits, dim=-1)
@@ -103,9 +151,6 @@ class T5LM(LM):
             res.append(answer)

         return res

-    def loglikelihood_rolling(self, requests):
-        raise NotImplementedError
-
     def _get_stopping_criteria(self, stopping_criteria_ids):
         class MultitokenEOSCriteria(transformers.StoppingCriteria):
@@ -133,29 +178,11 @@ class T5LM(LM):
             EOSCriteria(self.tokenizer.eos_token)
         ])

-    def greedy_until(self, requests):
-        res = []
-
-        for context, until in tqdm(requests):
-            if isinstance(until, str): until = [until]
-
-            context_enc = self.tokenizer(context, return_tensors="pt").to(self.device).input_ids
-
-            stopping_criteria_ids = self.tokenizer.encode(until[0])
-            stopping_criteria = self._get_stopping_criteria(stopping_criteria_ids)
-
-            cont = self.t5.generate(
-                context_enc,
-                max_length=self.MAX_GEN_TOKS,
-                stopping_criteria=stopping_criteria,
-                do_sample=False
-            )
-
-            s = self.tokenizer.decode(cont[0].tolist())
-
-            self.cache_hook.add_partial("greedy_until", (context, until), s)
-
-            res.append(s)
-
-        return res
\ No newline at end of file
+    def _model_generate(self, context, max_length, stopping_criteria_ids):
+        stopping_criteria = self._get_stopping_criteria(stopping_criteria_ids)
+        return self.t5.generate(
+            context,
+            max_length=max_length,
+            stopping_criteria=stopping_criteria,
+            do_sample=False,
+        )
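With `greedy_until` deleted from both subclasses, generation now flows through the shared `BaseLM.greedy_until`, which calls back into the hooks each model defines. A schematic of that call sequence (method names come from this diff; the driver itself is a simplification, not the actual `BaseLM` loop):

    import torch

    # Simplified driver, not the real BaseLM.greedy_until implementation.
    def greedy_until_one(lm, context, stopping_criteria=None):
        # Resolve the stop sequence, falling back to end-of-text (see the base.py hunk).
        until = [stopping_criteria] if stopping_criteria is not None else [lm.eot_token]
        primary_until = lm.tok_encode(until[0])

        # Encode the context and delegate generation to the subclass hook.
        context_enc = torch.tensor([lm.tok_encode(context)]).to(lm.device)
        cont = lm._model_generate(context_enc, lm.max_gen_toks, primary_until)
        return lm.tok_decode(cont[0].tolist())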