Commit 783ffd1e authored by lintangsutawika's avatar lintangsutawika
Browse files

merged conflict resolved

parents 82a21c36 53ff475b
...@@ -19,11 +19,6 @@ import transformers ...@@ -19,11 +19,6 @@ import transformers
from omegaconf import OmegaConf
from jinja2 import BaseLoader, Environment, StrictUndefined
from itertools import islice
<<<<<<< HEAD
=======
import transformers
>>>>>>> more pre-commit
from lm_eval.logger import eval_logger
...@@ -422,34 +417,11 @@ def create_iterator(raw_iterator, rank, world_size, limit=None): ...@@ -422,34 +417,11 @@ def create_iterator(raw_iterator, rank, world_size, limit=None):
    return islice(raw_iterator, rank, limit, world_size)
<<<<<<< HEAD
def clear_torch_cache():
    """Free unused memory: run the Python garbage collector, then release
    cached CUDA allocator blocks back to the driver.

    GC runs first so that freed tensors are actually collectible before
    ``empty_cache()`` returns their blocks to the device.
    """
    gc.collect()
    torch.cuda.empty_cache()
def get_dtype(dtype: Union[str, torch.dtype]) -> torch.dtype:
    """Convert ``dtype`` from ``str`` to ``torch.dtype`` when possible.

    Does not use an instantiated HF AutoConfig.

    Args:
        dtype: Either a ``torch.dtype`` (returned unchanged), the literal
            string ``"auto"`` (passed through unchanged), or the name of a
            torch dtype attribute, e.g. ``"float16"`` -> ``torch.float16``.

    Returns:
        The resolved ``torch.dtype`` (or ``"auto"`` passed through as-is).

    Raises:
        ValueError: if ``dtype`` is a string that names a torch attribute
            which is not actually a dtype (e.g. ``"tensor"``).
        AttributeError: if ``dtype`` names no torch attribute at all.
    """
    if isinstance(dtype, str) and dtype != "auto":
        # Convert `str` args to a torch dtype: `"float16"` -> `torch.float16`
        _torch_dtype = getattr(torch, dtype)
        # Guard: getattr would otherwise silently return non-dtype torch
        # attributes (functions, submodules) for a mistyped dtype string.
        if not isinstance(_torch_dtype, torch.dtype):
            raise ValueError(f"Cannot convert {dtype!r} to a torch.dtype")
    else:
        _torch_dtype = dtype
    return _torch_dtype
def pad_and_concat(max_length:int, tensors: List[torch.Tensor], padding_side="right"):
"""
Method for padding a list of tensors given the maximum tensor
length in the batch. Used for batching inputs and continuations in
seq2seq models.
=======
def pad_and_concat(max_length: int, tensors: List[torch.Tensor], padding_side="right"):
    """
    Method for padding a list of tensors given the maximum tensor
    length in the batch. Used for batching inputs and continuations in
    seq2seq models.
>>>>>>> more pre-commit
""" """
assert ( assert (
padding_side == "left" or padding_side == "right" padding_side == "left" or padding_side == "right"
...@@ -490,16 +462,22 @@ def pad_and_concat(max_length: int, tensors: List[torch.Tensor], padding_side="r ...@@ -490,16 +462,22 @@ def pad_and_concat(max_length: int, tensors: List[torch.Tensor], padding_side="r
return torch.cat(tensors, dim=0) return torch.cat(tensors, dim=0)
<<<<<<< HEAD
# Multi-token stopping criteria
=======
def clear_torch_cache():
    gc.collect()
    torch.cuda.empty_cache()
def get_dtype(dtype: Union[str, torch.dtype]) -> torch.dtype:
"""Converts `dtype` from `str` to torch.dtype when possible. Does not use an instantiated HF AutoConfig"""
if isinstance(dtype, str) and dtype != "auto":
# Convert `str` args torch dtype: `float16` -> `torch.float16`
_torch_dtype = getattr(torch, dtype)
else:
_torch_dtype = dtype
return _torch_dtype
# Multi-token stopping criteria
>>>>>>> more pre-commit
class MultiTokenEOSCriteria(transformers.StoppingCriteria):
    """Criteria to stop on the specified multi-token sequence."""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment