Commit a12ab0a8 authored by VictorSanh, committed by Victor SANH

update binarized_data

parent 4d6dfbd3
@@ -13,14 +13,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """
-Preprocessing script before training DistilBERT.
+Preprocessing script before distillation.
 """
 import argparse
 import pickle
 import random
 import time
 import numpy as np
-from transformers import BertTokenizer, RobertaTokenizer
+from transformers import BertTokenizer, RobertaTokenizer, GPT2Tokenizer
 
 import logging
 logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
@@ -32,7 +32,7 @@ def main():
     parser = argparse.ArgumentParser(description="Preprocess the data to avoid re-doing it several times by (tokenization + token_to_ids).")
     parser.add_argument('--file_path', type=str, default='data/dump.txt',
                         help='The path to the data.')
-    parser.add_argument('--tokenizer_type', type=str, default='bert', choices=['bert', 'roberta'])
+    parser.add_argument('--tokenizer_type', type=str, default='bert', choices=['bert', 'roberta', 'gpt2'])
     parser.add_argument('--tokenizer_name', type=str, default='bert-base-uncased',
                         help="The tokenizer to use.")
     parser.add_argument('--dump_file', type=str, default='data/dump',
@@ -43,10 +43,16 @@ def main():
     logger.info(f'Loading Tokenizer ({args.tokenizer_name})')
     if args.tokenizer_type == 'bert':
         tokenizer = BertTokenizer.from_pretrained(args.tokenizer_name)
+        bos = tokenizer.special_tokens_map['cls_token'] # `[CLS]`
+        sep = tokenizer.special_tokens_map['sep_token'] # `[SEP]`
     elif args.tokenizer_type == 'roberta':
         tokenizer = RobertaTokenizer.from_pretrained(args.tokenizer_name)
-    bos = tokenizer.special_tokens_map['bos_token'] # `[CLS]` for bert, `<s>` for roberta
-    sep = tokenizer.special_tokens_map['sep_token'] # `[SEP]` for bert, `</s>` for roberta
+        bos = tokenizer.special_tokens_map['cls_token'] # `<s>`
+        sep = tokenizer.special_tokens_map['sep_token'] # `</s>`
+    elif args.tokenizer_type == 'gpt2':
+        tokenizer = GPT2Tokenizer.from_pretrained(args.tokenizer_name)
+        bos = tokenizer.special_tokens_map['bos_token'] # `<|endoftext|>`
+        sep = tokenizer.special_tokens_map['eos_token'] # `<|endoftext|>`
 
     logger.info(f'Loading text from {args.file_path}')
     with open(args.file_path, 'r', encoding='utf8') as fp:
...
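
For reference, a minimal sketch (not part of the commit) of what the bos/sep lookups in the diff resolve to for each supported tokenizer type; the pretrained names 'bert-base-uncased', 'roberta-base', and 'gpt2' are assumptions for illustration:

from transformers import BertTokenizer, RobertaTokenizer, GPT2Tokenizer

# Resolve the sequence-boundary tokens the same way the updated script does:
# cls/sep keys for bert and roberta, bos/eos keys for gpt2.
for tok_cls, name, bos_key, sep_key in [
        (BertTokenizer, 'bert-base-uncased', 'cls_token', 'sep_token'),
        (RobertaTokenizer, 'roberta-base', 'cls_token', 'sep_token'),
        (GPT2Tokenizer, 'gpt2', 'bos_token', 'eos_token'),
]:
    tokenizer = tok_cls.from_pretrained(name)
    bos = tokenizer.special_tokens_map[bos_key]
    sep = tokenizer.special_tokens_map[sep_key]
    print(f'{name}: bos={bos!r} sep={sep!r}')
    # Expected: '[CLS]'/'[SEP]' for bert, '<s>'/'</s>' for roberta,
    # and '<|endoftext|>' for both with gpt2.

With the gpt2 branch in place, the script should accept an invocation such as python scripts/binarized_data.py --file_path data/dump.txt --tokenizer_type gpt2 --tokenizer_name gpt2 --dump_file data/binarized_text (the script path is assumed from the commit message; the flags come from the argparse definitions shown in the diff).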