Unverified Commit 1f82a5d9 authored by Anthony MOI's avatar Anthony MOI
Browse files

Update for changes in tokenizers API

parent 734d29b0
...@@ -86,7 +86,7 @@ setup( ...@@ -86,7 +86,7 @@ setup(
packages=find_packages("src"), packages=find_packages("src"),
install_requires=[ install_requires=[
"numpy", "numpy",
"tokenizers", "tokenizers == 0.0.10",
# accessing files from S3 directly # accessing files from S3 directly
"boto3", "boto3",
# filesystem locks e.g. to prevent parallel downloads # filesystem locks e.g. to prevent parallel downloads
......
...@@ -583,12 +583,14 @@ class BertTokenizerFast(FastPreTrainedTokenizer): ...@@ -583,12 +583,14 @@ class BertTokenizerFast(FastPreTrainedTokenizer):
) )
) )
if max_length is not None: if max_length is not None:
self._tokenizer.with_truncation(max_length, stride, truncation_strategy) self._tokenizer.with_truncation(max_length,
stride=stride,
strategy=truncation_strategy)
self._tokenizer.with_padding( self._tokenizer.with_padding(
max_length if pad_to_max_length else None, max_length=max_length if pad_to_max_length else None,
self.padding_side, direction=self.padding_side,
self.pad_token_id, pad_id=self.pad_token_id,
self.pad_token_type_id, pad_type_id=self.pad_token_type_id,
self.pad_token, pad_token=self.pad_token,
) )
self._decoder = tk.decoders.WordPiece.new() self._decoder = tk.decoders.WordPiece.new()
...@@ -274,15 +274,17 @@ class GPT2TokenizerFast(FastPreTrainedTokenizer): ...@@ -274,15 +274,17 @@ class GPT2TokenizerFast(FastPreTrainedTokenizer):
self._tokenizer = tk.Tokenizer(tk.models.BPE.from_files(vocab_file, merges_file)) self._tokenizer = tk.Tokenizer(tk.models.BPE.from_files(vocab_file, merges_file))
self._update_special_tokens() self._update_special_tokens()
self._tokenizer.with_pre_tokenizer(tk.pre_tokenizers.ByteLevel.new(add_prefix_space)) self._tokenizer.with_pre_tokenizer(tk.pre_tokenizers.ByteLevel.new(add_prefix_space=add_prefix_space))
self._tokenizer.with_decoder(tk.decoders.ByteLevel.new()) self._tokenizer.with_decoder(tk.decoders.ByteLevel.new())
if max_length: if max_length:
self._tokenizer.with_truncation(max_length, stride, truncation_strategy) self._tokenizer.with_truncation(max_length,
stride=stride,
strategy=truncation_strategy)
self._tokenizer.with_padding( self._tokenizer.with_padding(
max_length if pad_to_max_length else None, max_length=max_length if pad_to_max_length else None,
self.padding_side, direction=self.padding_side,
self.pad_token_id if self.pad_token_id is not None else 0, pad_id=self.pad_token_id if self.pad_token_id is not None else 0,
self.pad_token_type_id, pad_type_id=self.pad_token_type_id,
self.pad_token if self.pad_token is not None else "", pad_token=self.pad_token if self.pad_token is not None else "",
) )
self._decoder = tk.decoders.ByteLevel.new() self._decoder = tk.decoders.ByteLevel.new()
...@@ -1430,10 +1430,10 @@ class FastPreTrainedTokenizer(PreTrainedTokenizer): ...@@ -1430,10 +1430,10 @@ class FastPreTrainedTokenizer(PreTrainedTokenizer):
@property @property
def vocab_size(self): def vocab_size(self):
return self.tokenizer.get_vocab_size(False) return self.tokenizer.get_vocab_size(with_added_tokens=False)
def __len__(self): def __len__(self):
return self.tokenizer.get_vocab_size(True) return self.tokenizer.get_vocab_size(with_added_tokens=True)
def _update_special_tokens(self): def _update_special_tokens(self):
self.tokenizer.add_special_tokens(self.all_special_tokens) self.tokenizer.add_special_tokens(self.all_special_tokens)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment