"vscode:/vscode.git/clone" did not exist on "29daf498cd05e2e94ba2e4189e977afdec675a2b"
Unverified commit 1fa02395 authored by Lintang Sutawika, committed by GitHub

Merge pull request #565 from fattorib/seq2seq-refactor

[Refactor] Seq2Seq Models with Multi-Device Support
parents 9a8fee14 d3cfdcf6
@@ -98,13 +98,16 @@ class TaskConfig(dict):
         if type(self.gold_alias) == str:
             self.gold_alias = self.template_aliases + self.gold_alias
 
-        if self.generation_kwargs or self.output_type == "greedy_until":
+        if self.generation_kwargs:
             assert (
                 self.output_type == "greedy_until"
             ), "passed `generation_kwargs`, but not using a generation request type!"
+        elif self.output_type == "greedy_until":
             # ensure that we greedily generate in absence of explicit arguments otherwise
             self.generation_kwargs = {"do_sample": False, "temperature": 0.0}
 
+    # TODO: how to make TaskConfigs be de- and re-serializable, even when using the !function constructor?
     def __getitem__(self, item):
         return getattr(self, item)
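In words: explicit generation_kwargs are only legal for generation-style (greedy_until) tasks, and generation tasks that omit them now default to greedy decoding. A minimal sketch of that behavior, assuming TaskConfig can be constructed directly with these keyword fields and lives at lm_eval.api.task (both assumptions, not shown in this diff):

from lm_eval.api.task import TaskConfig  # import path assumed

# a greedy_until task with no explicit kwargs falls back to greedy decoding
cfg = TaskConfig(output_type="greedy_until")
assert cfg.generation_kwargs == {"do_sample": False, "temperature": 0.0}

# passing generation_kwargs to a non-generation task now fails fast
try:
    TaskConfig(output_type="multiple_choice", generation_kwargs={"do_sample": True})
except AssertionError as err:
    print(err)  # passed `generation_kwargs`, but not using a generation request type!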
@@ -123,6 +126,9 @@ class TaskConfig(dict):
         for k, v in list(cfg_dict.items()):
             if v is None:
                 cfg_dict.pop(k)
+            elif isinstance(v, Callable):
+                # TODO: this should handle Promptsource template objects as a separate case?
+                cfg_dict[k] = str(v)
         return cfg_dict
@@ -877,7 +883,9 @@ class ConfigurableTask(Task):
         for key, result in zip(self._metric_fn_list.keys(), results):
             _dict = self._metric_fn_list[key].compute(
-                references=[gold], predictions=[result], **self._metric_kwargs[key]
+                references=[gold],
+                predictions=[result],
+                **self._metric_fn_kwargs[key],
             )
 
             result_dict = {**result_dict, **_dict}
......
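The per-metric kwargs forwarded above come from each task's metric_list entry (see the boolq YAMLs later in this diff). A minimal sketch of what such a compute() call reduces to for exact_match, assuming the metric function is backed by the Hugging Face evaluate library:

import evaluate

exact_match = evaluate.load("exact_match")
scores = exact_match.compute(
    references=["Yes"],
    predictions=["yes."],
    ignore_case=True,          # kwargs taken from the task YAML's metric_list entry
    ignore_punctuation=True,
)
print(scores["exact_match"])  # 1.0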
@@ -183,9 +183,7 @@ def evaluate(
     # get lists of each type of request
     for task_name, task in task_dict.items():
         versions[task_name] = task.VERSION
-        configs[task_name] = dict(
-            task.dump_config()
-        )  # TODO: don't access a private attribute here ; for non-YAML tasks handle this case
+        configs[task_name] = dict(task.dump_config())
 
         # deterministically shuffle docs and chop off the first `limit` because sometimes docs are in some kind of order
         # task_docs = list(task_doc_func())
......
@@ -2,5 +2,6 @@ from . import hf_causal
 from . import openai_completions
 from . import textsynth
 from . import dummy
+from . import huggingface
 
 # TODO: implement __all__
@@ -26,7 +26,6 @@ def anthropic_completion(
             max_tokens_to_sample=max_tokens_to_sample,
             temperature=temperature,
         )
-        print(response)
         return response["completion"]
     except RuntimeError:
         # TODO: I don't actually know what error Anthropic raises when it times out
@@ -99,7 +98,7 @@ class AnthropicLM(LM):
                 model=self.model,
                 prompt=inp,
                 max_tokens_to_sample=self.max_gen_toks,
-                temperature=0.0,
+                temperature=0.0,  # TODO: implement non-greedy sampling for Anthropic
                 stop=until,
             )
             res.append(response)
......
@@ -11,12 +11,14 @@ from lm_eval.logger import eval_logger
 from lm_eval.api.model import LM
 from lm_eval.api.registry import register_model
+from lm_eval.utils import MultiTokenEOSCriteria, stop_sequences_criteria
 
 from accelerate import Accelerator
 from typing import Optional, Union
 
 
 @register_model("hf-causal")
-class HFLM(LM):
+class HFCausalLM(LM):
     def __init__(
         self,
         device="cuda",
@@ -35,6 +37,7 @@ class HFLM(LM):
         assert isinstance(batch_size, int)
 
         gpus = torch.cuda.device_count()
         if gpus <= 1:
             if device:
                 if device not in ["cuda", "cpu"]:
@@ -66,7 +69,7 @@ class HFLM(LM):
             ).to(self.device)
             self.model.eval()
 
-            print(self.model.dtype)
+            eval_logger.info(self.model.dtype)
 
             self.tokenizer = transformers.AutoTokenizer.from_pretrained(
                 pretrained if tokenizer is None else tokenizer,
@@ -90,6 +93,14 @@ class HFLM(LM):
                 )
                 self._rank = accelerator.local_process_index
                 self._world_size = accelerator.num_processes
+                # manually set model to use gpu, for case where many GPUs available but
+                # only seek to use one
+                self._device = (
+                    torch.device(f"cuda:{accelerator.local_process_index}")
+                    if torch.cuda.is_available()
+                    else torch.device("cpu")
+                )
+                self.model.to(self.device)
             else:
                 self.model = accelerator.prepare(self.model)
                 self._device = torch.device(f"cuda:{accelerator.local_process_index}")
@@ -157,27 +168,33 @@ class HFLM(LM):
             logits returned from the model
         """
         with torch.no_grad():
-            return self.model(inps)[0]
+            return self.model(inps).logits
 
-    def _model_generate(self, context, max_length, eos_token_id, **generation_kwargs):
+    def _model_generate(self, context, max_length, stop, **generation_kwargs):
         # we require users to pass do_sample=True explicitly
         # for non-greedy gen. This should be reevaluated when considering beam search.
         if "do_sample" not in generation_kwargs.keys():
             generation_kwargs["do_sample"] = False
+
+        # build stopping criteria
+        stopping_criteria = stop_sequences_criteria(
+            self.tokenizer, stop, 1, context.shape[0]
+        )
         if hasattr(self, "accelerator"):
             return self.accelerator.unwrap_model(self.model).generate(
                 context,
                 max_length=max_length,
-                pad_token_id=eos_token_id,
-                eos_token_id=eos_token_id,
+                stopping_criteria=stopping_criteria,
+                pad_token_id=self.eot_token_id,
+                use_cache=True,
                 **generation_kwargs,
             )
         else:
             return self.model.generate(
                 context,
                 max_length=max_length,
-                pad_token_id=eos_token_id,
-                eos_token_id=eos_token_id,
+                stopping_criteria=stopping_criteria,
+                pad_token_id=self.eot_token_id,
+                use_cache=True,
                 **generation_kwargs,
             )
@@ -197,9 +214,6 @@ class HFLM(LM):
         return self._loglikelihood_tokens(new_reqs)
 
     def loglikelihood_rolling(self, requests):
-        # TODO: Implement caching once we've confirmed the perplexity implementation
-        # TODO: automatic batch size detection for vectorization
-
         loglikelihoods = []
 
         for (string,) in tqdm([req.args for req in requests], disable=(self.rank != 0)):
             rolling_token_windows = list(
@@ -368,6 +382,7 @@ class HFLM(LM):
         re_ord = utils.Reorderer([req.args for req in requests], _collate)
 
         for context, gen_kwargs in tqdm(re_ord.get_reordered()):
+            until = None
             if isinstance(gen_kwargs, dict):
                 gen_kwargs = copy.deepcopy(gen_kwargs)  # edge case for repeats > 1
                 if "until" in gen_kwargs.keys():
@@ -389,12 +404,13 @@ class HFLM(LM):
             else:
                 max_gen_toks = self.max_gen_toks
 
-            try:
-                (primary_until,) = self.tok_encode(until[0])
-            except Exception:
-                # if our primary until would be multiple tokens long, we'll have errors.
-                # TODO: handling this better will let us stop generating earlier + often.
-                primary_until = self.eot_token_id
+            primary_until = until[0]
+            # try:
+            #     (primary_until,) = self.tok_encode(until[0])
+            # except Exception:
+            #     # if our primary until would be multiple tokens long, we'll have errors.
+            #     # TODO: handling this better will let us stop generating earlier + often.
+            #     primary_until = self.eot_token_id
 
             context_enc = torch.tensor(
                 [self.tok_encode(context)[max_gen_toks - self.max_length :]]
@@ -403,7 +419,7 @@ class HFLM(LM):
             cont = self._model_generate(
                 context=context_enc,
                 max_length=context_enc.shape[1] + max_gen_toks,
-                eos_token_id=primary_until,
+                stop=primary_until,
                 **gen_kwargs,
             )
......
This diff is collapsed.
 group:
   - super-glue-lm-eval-v1
-task: "default"
+task: "boolq"
 dataset_path: super_glue
 dataset_name: boolq
 output_type: multiple_choice
 training_split: train
 validation_split: validation
 doc_to_text: "{{passage}}\nQuestion: {{question}}\nAnswer:"
-doc_to_target: "{{answer_choices[labe]}}"
+doc_to_target: "{{answer_choices[label]}}"
 gold_alias: "{{label}}" # this will be cast to an int.
 template_aliases: "{% set answer_choices = ['no', 'yes'] %}"
+metric_list:
+  - metric: exact_match
+    aggregation: mean
+    higher_is_better: true
+    ignore_case: true
+    ignore_punctuation: true
+group:
+  - super-glue-lm-eval-v1
+task: "boolq-seq2seq"
+dataset_path: super_glue
+dataset_name: boolq
+output_type: greedy_until
+training_split: train
+validation_split: validation
+doc_to_text: "{{passage}}\nQuestion: {{question}}\nAnswer:"
+doc_to_target: "{{answer_choices[label]}}"
+gold_alias: "{{label}}" # this will be cast to an int.
+template_aliases: "{% set answer_choices = ['no', 'yes'] %}"
+metric_list:
+  - metric: exact_match
+    aggregation: mean
+    higher_is_better: true
+    ignore_case: true
+    ignore_punctuation: true
@@ -14,6 +14,7 @@ from typing import List, Union
 import gc
 
 import torch
+import transformers
 from omegaconf import OmegaConf
 from jinja2 import BaseLoader, Environment, StrictUndefined
@@ -422,6 +423,51 @@ def create_iterator(raw_iterator, rank, world_size, limit=None):
     return islice(raw_iterator, rank, limit, world_size)
 
 
+def pad_and_concat(max_length: int, tensors: List[torch.Tensor], padding_side="right"):
+    """
+    Method for padding a list of tensors given the maximum tensor
+    length in the batch. Used for batching inputs and continuations in
+    seq2seq models.
+    """
+    assert (
+        padding_side == "left" or padding_side == "right"
+    ), f"Unrecognized padding type: '{padding_side}' not 'left' or 'right'"
+
+    for i, tensor in enumerate(tensors):
+        tensor_len = tensor.shape[0]
+        if tensor_len < max_length:
+            if padding_side == "right":
+                # right-pad
+                tensors[i] = torch.cat(
+                    [
+                        tensor,  # [seq]
+                        torch.zeros(
+                            max_length - tensor_len,
+                            dtype=torch.long,
+                            device=tensor.device,
+                        ),  # [padding_length - seq]
+                    ],
+                    dim=0,
+                ).unsqueeze(0)
+            else:
+                # left-pad
+                tensors[i] = torch.cat(
+                    [
+                        torch.zeros(
+                            max_length - tensor_len,
+                            dtype=torch.long,
+                            device=tensor.device,
+                        ),  # [padding_length - seq]
+                        tensor,  # [seq]
+                    ],
+                    dim=0,
+                ).unsqueeze(0)
+        else:
+            tensors[i] = tensor.unsqueeze(0)
+
+    return torch.cat(tensors, dim=0)
+
+
 def clear_torch_cache():
     gc.collect()
     torch.cuda.empty_cache()
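A short usage sketch of the new helper (tensor values are arbitrary; pad_and_concat is imported from lm_eval.utils, the module this hunk modifies):

import torch

from lm_eval.utils import pad_and_concat

a = torch.tensor([1, 2, 3], dtype=torch.long)
b = torch.tensor([4, 5], dtype=torch.long)

# shorter tensors are zero-padded up to max_length, every tensor gains a batch
# dimension, and the results are concatenated along dim 0
print(pad_and_concat(3, [a, b], padding_side="right"))
# tensor([[1, 2, 3],
#         [4, 5, 0]])
print(pad_and_concat(3, [a, b], padding_side="left"))
# tensor([[1, 2, 3],
#         [0, 4, 5]])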
@@ -435,3 +481,53 @@ def get_dtype(dtype: Union[str, torch.dtype]) -> torch.dtype:
     else:
         _torch_dtype = dtype
     return _torch_dtype
+
+
+# Multi-token stopping criteria
+class MultiTokenEOSCriteria(transformers.StoppingCriteria):
+    """Criteria to stop on the specified multi-token sequence."""
+
+    def __init__(
+        self,
+        sequence: str,
+        tokenizer: transformers.PreTrainedTokenizer,
+        initial_decoder_input_length: int,
+        batch_size: int,
+    ):
+        self.initial_decoder_input_length = initial_decoder_input_length
+        self.done_tracker = [False] * batch_size
+        self.sequence = sequence
+        self.sequence_ids = tokenizer.encode(sequence, add_special_tokens=False)
+        self.sequence_id_len = len(self.sequence_ids)
+        self.tokenizer = tokenizer
+
+    def __call__(self, input_ids, scores, **kwargs) -> bool:
+        # For efficiency, we compare the last n tokens where n is the number of tokens in the stop_sequence
+        lookback_ids_batch = input_ids[:, self.initial_decoder_input_length :][
+            :, -self.sequence_id_len :
+        ]
+        lookback_tokens_batch = self.tokenizer.batch_decode(lookback_ids_batch)
+        for i, done in enumerate(self.done_tracker):
+            if not done:
+                self.done_tracker[i] = self.sequence in lookback_tokens_batch[i]
+        return False not in self.done_tracker
+
+
+def stop_sequences_criteria(
+    tokenizer: transformers.PreTrainedTokenizer,
+    stop_sequences: List[str],
+    initial_decoder_input_length: int,
+    batch_size: int,
+) -> transformers.StoppingCriteriaList:
+    return transformers.StoppingCriteriaList(
+        [
+            *[
+                MultiTokenEOSCriteria(
+                    sequence, tokenizer, initial_decoder_input_length, batch_size
+                )
+                for sequence in stop_sequences
+            ],
+        ]
+    )
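A short sketch of how these criteria plug into a Hugging Face generate() call, mirroring the _model_generate change earlier in this diff (gpt2 and the stop strings are purely illustrative choices):

import transformers

from lm_eval.utils import stop_sequences_criteria

tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2")
model = transformers.AutoModelForCausalLM.from_pretrained("gpt2")

context = tokenizer(
    "Question: Is the sky blue?\nAnswer:", return_tensors="pt"
).input_ids

# one MultiTokenEOSCriteria per stop string; generation halts once every row
# in the batch has produced one of them after the prompt
stopping_criteria = stop_sequences_criteria(
    tokenizer,
    stop_sequences=["\n\n", "Question:"],
    initial_decoder_input_length=context.shape[1],
    batch_size=context.shape[0],
)

output = model.generate(
    context,
    max_length=context.shape[1] + 32,
    stopping_criteria=stopping_criteria,
    pad_token_id=tokenizer.eos_token_id,
    do_sample=False,
)
print(tokenizer.decode(output[0][context.shape[1]:]))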