Unverified commit 4c08d72a, authored by Stella Biderman and committed by GitHub

Merge pull request #582 from EleutherAI/fix-max-len

Fix seqlen issues for bloom, remove extraneous OPT tokenizer check
Parents: f862a118, fa43ab2e
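
The key user-facing change is the new max_length constructor argument, which lets callers pin the sequence length instead of relying on automatic detection. A minimal usage sketch, assuming the lm-eval import path and an illustrative model name (neither appears in this diff):

    # Hypothetical usage of the new max_length override (import path and model name are assumptions).
    from lm_eval.models.gpt2 import HFLM

    # With max_length set, the new property returns it directly instead of
    # probing the model config or tokenizer.
    lm = HFLM(pretrained="bigscience/bloom-560m", max_length=2048)
    print(lm.max_length)  # 2048, because it was set manually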
@@ -17,6 +17,9 @@ def _get_dtype(
 class HFLM(BaseLM):
+
+    _DEFAULT_MAX_LENGTH = 2048
+
     def __init__(
         self,
         device="cuda",
@@ -26,6 +29,7 @@ class HFLM(BaseLM):
         subfolder=None,
         tokenizer=None,
         batch_size=1,
+        max_length=None,
         load_in_8bit: Optional[bool] = False,
         trust_remote_code: Optional[bool] = False,
         dtype: Optional[Union[str, torch.dtype]]="auto",
@@ -72,22 +76,14 @@ class HFLM(BaseLM):
         self.vocab_size = self.tokenizer.vocab_size
 
-        if isinstance(
-            self.tokenizer, (transformers.GPT2Tokenizer, transformers.GPT2TokenizerFast)
-        ):
-            assert self.tokenizer.encode("hello\n\nhello") == [
-                31373,
-                198,
-                198,
-                31373,
-            ], self.tokenizer.encode("hello\n\nhello")
-
         # setup for automatic batch size detection
         if batch_size == "auto":
             self.batch_size_per_gpu = batch_size
         else:
             self.batch_size_per_gpu = int(batch_size)
 
+        self._max_length = max_length
+
     @property
     def eot_token_id(self):
         # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
@@ -95,11 +91,18 @@ class HFLM(BaseLM):
 
     @property
     def max_length(self):
-        try:
-            return self.gpt2.config.n_ctx
-        except AttributeError:
-            # gptneoconfig doesn't have n_ctx apparently
-            return self.gpt2.config.max_position_embeddings
+        if self._max_length:  # if max length manually set, return it
+            return self._max_length
+        seqlen_config_attrs = ("n_positions", "max_position_embeddings", "n_ctx")
+        for attr in seqlen_config_attrs:
+            if hasattr(self.gpt2.config, attr):
+                return getattr(self.gpt2.config, attr)
+        if hasattr(self.tokenizer, "model_max_length"):
+            if self.tokenizer.model_max_length == 1000000000000000019884624838656:
+                return self._DEFAULT_MAX_LENGTH
+            return self.tokenizer.model_max_length
+        return self._DEFAULT_MAX_LENGTH
 
     @property
     def max_gen_toks(self):
         ...
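
To summarize the new fallback order in max_length: an explicit max_length argument wins; otherwise the first of n_positions, max_position_embeddings, or n_ctx found on the model config is used; otherwise the tokenizer's model_max_length (unless it is the huge Hugging Face "unset" sentinel); otherwise the hard-coded default of 2048. A standalone sketch of that chain, written as a free function purely for illustration (the function name and arguments are not part of the codebase):

    from typing import Optional

    _DEFAULT_MAX_LENGTH = 2048
    # Value Hugging Face tokenizers report for model_max_length when it is unknown.
    _UNSET_MODEL_MAX_LENGTH = 1000000000000000019884624838656

    def resolve_max_length(config, tokenizer, max_length: Optional[int] = None) -> int:
        """Illustrative restatement of the fallback chain added by this PR."""
        if max_length:  # explicit override always wins
            return max_length
        # first matching sequence-length attribute on the model config wins
        for attr in ("n_positions", "max_position_embeddings", "n_ctx"):
            if hasattr(config, attr):
                return getattr(config, attr)
        # fall back to the tokenizer, ignoring the "unknown" sentinel
        if hasattr(tokenizer, "model_max_length"):
            if tokenizer.model_max_length == _UNSET_MODEL_MAX_LENGTH:
                return _DEFAULT_MAX_LENGTH
            return tokenizer.model_max_length
        return _DEFAULT_MAX_LENGTH

Presumably this is the BLOOM fix named in the PR title: if a config exposes none of those attributes, the old try/except on n_ctx and max_position_embeddings would still raise, while the new chain degrades gracefully to the tokenizer's limit or the 2048 default.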