Commit 783ffd1e authored by lintangsutawika

Merge conflict resolved

parents 82a21c36 53ff475b
@@ -19,11 +19,6 @@ import transformers
 from omegaconf import OmegaConf
 from jinja2 import BaseLoader, Environment, StrictUndefined
 from itertools import islice
-<<<<<<< HEAD
-=======
-import transformers
->>>>>>> more pre-commit
 from lm_eval.logger import eval_logger
@@ -422,34 +417,11 @@ def create_iterator(raw_iterator, rank, world_size, limit=None):
     return islice(raw_iterator, rank, limit, world_size)
-<<<<<<< HEAD
-def clear_torch_cache():
-    gc.collect()
-    torch.cuda.empty_cache()
-
-def get_dtype(dtype: Union[str, torch.dtype]) -> torch.dtype:
-    """Converts `dtype` from `str` to torch.dtype when possible. Does not use an instantiated HF AutoConfig"""
-    if isinstance(dtype, str) and dtype != "auto":
-        # Convert `str` args torch dtype: `float16` -> `torch.float16`
-        _torch_dtype = getattr(torch, dtype)
-    else:
-        _torch_dtype = dtype
-    return _torch_dtype
-
-def pad_and_concat(max_length:int, tensors: List[torch.Tensor], padding_side="right"):
-    """
-    Method for padding a list of tensors given the maximum tensor
-    length in the batch. Used for batching inputs and continuations in
-    seq2seq models.
-=======
 def pad_and_concat(max_length: int, tensors: List[torch.Tensor], padding_side="right"):
     """
     Method for padding a list of tensors given the maximum tensor
     length in the batch. Used for batching inputs and continuations in
     seq2seq models.
->>>>>>> more pre-commit
     """
     assert (
         padding_side == "left" or padding_side == "right"
@@ -490,16 +462,22 @@ def pad_and_concat(max_length: int, tensors: List[torch.Tensor], padding_side="right"):
     return torch.cat(tensors, dim=0)
-<<<<<<< HEAD
-# Multi-token stopping criteria
-=======
 def clear_torch_cache():
     gc.collect()
     torch.cuda.empty_cache()

 def get_dtype(dtype: Union[str, torch.dtype]) -> torch.dtype:
     """Converts `dtype` from `str` to torch.dtype when possible. Does not use an instantiated HF AutoConfig"""
     if isinstance(dtype, str) and dtype != "auto":
         # Convert `str` args torch dtype: `float16` -> `torch.float16`
         _torch_dtype = getattr(torch, dtype)
     else:
         _torch_dtype = dtype
     return _torch_dtype

 # Multi-token stopping criteria
->>>>>>> more pre-commit
 class MultiTokenEOSCriteria(transformers.StoppingCriteria):
     """Criteria to stop on the specified multi-token sequence."""
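`clear_torch_cache` and `get_dtype` appear in full in the last hunk (the merge only moves them below `pad_and_concat`), so a usage sketch can carry them over verbatim; the checkpoint name is only an illustrative placeholder:

import gc
from typing import Union

import torch
import transformers

def get_dtype(dtype: Union[str, torch.dtype]) -> torch.dtype:
    """Converts `dtype` from `str` to torch.dtype when possible. Does not use an instantiated HF AutoConfig"""
    if isinstance(dtype, str) and dtype != "auto":
        # Convert `str` args torch dtype: `float16` -> `torch.float16`
        _torch_dtype = getattr(torch, dtype)
    else:
        _torch_dtype = dtype
    return _torch_dtype

def clear_torch_cache():
    # Release Python-level garbage, then return cached CUDA blocks to the driver.
    gc.collect()
    torch.cuda.empty_cache()

# Resolve a user-supplied dtype string before loading a HF checkpoint.
model = transformers.AutoModelForCausalLM.from_pretrained(
    "gpt2",  # placeholder checkpoint
    torch_dtype=get_dtype("float16"),
)
del model
clear_torch_cache()  # reclaim cached GPU memory after the model is dropped

Note that `get_dtype` passes the string "auto" through unchanged, which `from_pretrained` also accepts for `torch_dtype`.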
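The diff ends at the (collapsed) `MultiTokenEOSCriteria` class. Since its body is not visible here, the following is a generic sketch of how a multi-token stopping criterion plugs into `transformers`, not the repository's implementation; the decode-and-compare strategy, the `Sketch` class name, and the constructor parameters are assumptions:

import transformers

class MultiTokenEOSCriteriaSketch(transformers.StoppingCriteria):
    """Stop generation once a multi-token stop string appears in the continuation."""

    def __init__(self, sequence, tokenizer, prompt_len):
        self.sequence = sequence
        self.tokenizer = tokenizer
        self.prompt_len = prompt_len  # skip the prompt; only scan generated tokens

    def __call__(self, input_ids, scores, **kwargs) -> bool:
        # Decode the continuation so far and check for the stop sequence.
        continuation = self.tokenizer.decode(input_ids[0, self.prompt_len:])
        return self.sequence in continuation

tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2")  # placeholder checkpoint
model = transformers.AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer("Q: What is 2+2?\nA:", return_tensors="pt")
stopping = transformers.StoppingCriteriaList(
    [MultiTokenEOSCriteriaSketch("\nQ:", tokenizer, inputs["input_ids"].shape[1])]
)
out = model.generate(**inputs, max_new_tokens=32, stopping_criteria=stopping)

Decoding and substring-matching is the simplest robust approach: it handles stop strings whose token boundaries shift with context, at the cost of re-decoding the continuation on every generation step.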