Commit 1a6b31a8 authored by lintangsutawika

resolved conflict

parent 1d7d3de5
@@ -11,12 +11,14 @@ from lm_eval.logger import eval_logger
from lm_eval.api.model import LM
from lm_eval.api.registry import register_model
from lm_eval.utils import MultiTokenEOSCriteria, stop_sequences_criteria
from accelerate import Accelerator
from typing import Optional, Union
@register_model("hf-causal")
class HFLM(LM):
class HFCausalLM(LM):
def __init__(
self,
device="cuda",
@@ -35,6 +37,7 @@ class HFLM(LM):
assert isinstance(batch_size, int)
gpus = torch.cuda.device_count()
if gpus <= 1:
if device:
if device not in ["cuda", "cpu"]:
@@ -66,7 +69,7 @@ class HFLM(LM):
).to(self.device)
self.model.eval()
print(self.model.dtype)
eval_logger.info(self.model.dtype)
self.tokenizer = transformers.AutoTokenizer.from_pretrained(
pretrained if tokenizer is None else tokenizer,
@@ -90,6 +93,14 @@ class HFLM(LM):
)
self._rank = accelerator.local_process_index
self._world_size = accelerator.num_processes
# manually set model to use gpu, for case where many GPUs available but
# only seek to use one
self._device = (
torch.device(f"cuda:{accelerator.local_process_index}")
if torch.cuda.is_available()
else torch.device("cpu")
)
self.model.to(self.device)
else:
self.model = accelerator.prepare(self.model)
self._device = torch.device(f"cuda:{accelerator.local_process_index}")
@@ -157,27 +168,33 @@ class HFLM(LM):
logits returned from the model
"""
with torch.no_grad():
return self.model(inps)[0]
return self.model(inps).logits
def _model_generate(self, context, max_length, eos_token_id, **generation_kwargs):
def _model_generate(self, context, max_length, stop, **generation_kwargs):
# we require users to pass do_sample=True explicitly
# for non-greedy gen. This should be reevaluated when considering beam search.
if "do_sample" not in generation_kwargs.keys():
generation_kwargs["do_sample"] = False
# build stopping criteria
stopping_criteria = stop_sequences_criteria(
self.tokenizer, stop, 1, context.shape[0]
)
if hasattr(self, "accelerator"):
return self.accelerator.unwrap_model(self.model).generate(
context,
max_length=max_length,
pad_token_id=eos_token_id,
eos_token_id=eos_token_id,
stopping_criteria=stopping_criteria,
pad_token_id=self.eot_token_id,
use_cache=True,
**generation_kwargs,
)
else:
return self.model.generate(
context,
max_length=max_length,
pad_token_id=eos_token_id,
eos_token_id=eos_token_id,
stopping_criteria=stopping_criteria,
pad_token_id=self.eot_token_id,
use_cache=True,
**generation_kwargs,
)
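`stop_sequences_criteria` (imported from `lm_eval.utils` at the top of the file) replaces the single `eos_token_id` cutoff with HuggingFace stopping criteria, so generation can halt on a stop string that spans multiple tokens. Below is a minimal sketch of how such a criterion can be built, assuming it behaves like a `transformers.StoppingCriteria` that watches the decoded tail of each sequence; this is an illustration, not the actual `MultiTokenEOSCriteria` implementation:

```python
import torch
import transformers


class StopOnString(transformers.StoppingCriteria):
    """Hypothetical stand-in for MultiTokenEOSCriteria: stop once every
    sequence in the batch contains the stop string after the prompt."""

    def __init__(self, tokenizer, stop_string, prompt_len, batch_size):
        self.tokenizer = tokenizer
        self.stop_string = stop_string
        self.prompt_len = prompt_len
        self.done = [False] * batch_size

    def __call__(self, input_ids: torch.LongTensor, scores, **kwargs) -> bool:
        # decode only the freshly generated tail of each sequence
        tails = self.tokenizer.batch_decode(input_ids[:, self.prompt_len :])
        for i, tail in enumerate(tails):
            self.done[i] = self.done[i] or (self.stop_string in tail)
        return all(self.done)


def build_stop_criteria(tokenizer, stop_string, prompt_len, batch_size):
    # rough analogue of stop_sequences_criteria(tokenizer, stop, 1, batch_size)
    return transformers.StoppingCriteriaList(
        [StopOnString(tokenizer, stop_string, prompt_len, batch_size)]
    )
```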
@@ -197,9 +214,6 @@ class HFLM(LM):
return self._loglikelihood_tokens(new_reqs)
def loglikelihood_rolling(self, requests):
# TODO: Implement caching once we've confirmed the perplexity implementation
# TODO: automatic batch size detection for vectorization
loglikelihoods = []
for (string,) in tqdm([req.args for req in requests], disable=(self.rank != 0)):
rolling_token_windows = list(
@@ -368,6 +382,7 @@ class HFLM(LM):
re_ord = utils.Reorderer([req.args for req in requests], _collate)
for context, gen_kwargs in tqdm(re_ord.get_reordered()):
until = None
if isinstance(gen_kwargs, dict):
gen_kwargs = copy.deepcopy(gen_kwargs) # edge case for repeats > 1
if "until" in gen_kwargs.keys():
@@ -389,12 +404,13 @@ class HFLM(LM):
else:
max_gen_toks = self.max_gen_toks
try:
(primary_until,) = self.tok_encode(until[0])
except Exception:
# if our primary until would be multiple tokens long, we'll have errors.
# TODO: handling this better will let us stop generating earlier + often.
primary_until = self.eot_token_id
primary_until = until[0]
# try:
# (primary_until,) = self.tok_encode(until[0])
# except Exception:
# # if our primary until would be multiple tokens long, we'll have errors.
# # TODO: handling this better will let us stop generating earlier + often.
# primary_until = self.eot_token_id
context_enc = torch.tensor(
[self.tok_encode(context)[max_gen_toks - self.max_length :]]
@@ -403,7 +419,7 @@ class HFLM(LM):
cont = self._model_generate(
context=context_enc,
max_length=context_enc.shape[1] + max_gen_toks,
eos_token_id=primary_until,
stop=primary_until,
**gen_kwargs,
)
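With this change, `greedy_until` hands the first `until` string straight to `_model_generate` as `stop` instead of encoding it to a single token id first. A short usage sketch of the new call shape, with illustrative values only:

```python
# illustrative values; in the diff these come from the task's gen_kwargs
gen_kwargs = {"until": ["\n\n"], "do_sample": False}

until = gen_kwargs.pop("until")   # list of stop strings
primary_until = until[0]          # forwarded as a raw string

# the previous code tried `(primary_until,) = self.tok_encode(until[0])`,
# which failed whenever the stop string tokenized to more than one token;
# the stopping-criteria path has no such restriction.

# cont = self._model_generate(
#     context=context_enc,
#     max_length=context_enc.shape[1] + max_gen_toks,
#     stop=primary_until,
#     **gen_kwargs,
# )
```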