Commit d4e075e3 authored by lintangsutawika's avatar lintangsutawika
Browse files

added truncation option

parent f96f5fad
......@@ -69,6 +69,7 @@ class HFLM(LM):
revision: Optional[str] = "main",
subfolder: Optional[str] = None,
tokenizer: Optional[str] = None,
truncation: Optional[bool] = False,
max_length: Optional[int] = None,
device: Optional[str] = "cuda",
dtype: Optional[Union[str, torch.dtype]] = "auto",
......@@ -243,6 +244,8 @@ class HFLM(LM):
use_fast=use_fast_tokenizer,
)
self.truncation = truncation
self.vocab_size = self.tokenizer.vocab_size
self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
......@@ -422,7 +425,7 @@ class HFLM(LM):
return encoding
def tok_batch_encode(
self, strings: List[str], padding_side="left", left_truncate_len=None
self, strings: List[str], padding_side="left", left_truncate_len=None, truncation=False
):
# encode a batch of strings. converts to tensors and pads automatically, unlike tok_encode.
old_padding_side = self.tokenizer.padding_side
......@@ -435,6 +438,7 @@ class HFLM(LM):
encoding = self.tokenizer(
strings,
truncation=truncation,
padding="longest",
return_tensors="pt",
add_special_tokens=add_special_tokens,
......@@ -859,7 +863,7 @@ class HFLM(LM):
# encode, pad, and truncate contexts for this batch
context_enc, attn_masks = self.tok_batch_encode(
contexts, left_truncate_len=max_ctx_len
contexts, left_truncate_len=max_ctx_len, truncation=self.truncation,
)
context_enc = context_enc.to(self.device)
attn_masks = attn_masks.to(self.device)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment