Unverified Commit 54abc67a authored by Thomas Wolf's avatar Thomas Wolf Committed by GitHub
Browse files

Merge pull request #2255 from aaugustin/implement-best-practices

Implement some Python best practices
parents 645713e2 c11b3e29
......@@ -13,42 +13,44 @@
# See the License for the specific language governing permissions and
# limitations under the License
""" Tokenization classes for XLM-RoBERTa model."""
from __future__ import (absolute_import, division, print_function,
unicode_literals)
from __future__ import absolute_import, division, print_function, unicode_literals
import logging
import os
from shutil import copyfile
import sentencepiece as spm
from transformers.tokenization_utils import PreTrainedTokenizer
from .tokenization_xlnet import SPIECE_UNDERLINE
logger = logging.getLogger(__name__)
VOCAB_FILES_NAMES = {'vocab_file': 'sentencepiece.bpe.model'}
VOCAB_FILES_NAMES = {"vocab_file": "sentencepiece.bpe.model"}
PRETRAINED_VOCAB_FILES_MAP = {
'vocab_file':
{
'xlm-roberta-base': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-base-sentencepiece.bpe.model",
'xlm-roberta-large': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-sentencepiece.bpe.model",
'xlm-roberta-large-finetuned-conll02-dutch': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-finetuned-conll02-dutch-sentencepiece.bpe.model",
'xlm-roberta-large-finetuned-conll02-spanish': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-finetuned-conll02-spanish-sentencepiece.bpe.model",
'xlm-roberta-large-finetuned-conll03-english': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-finetuned-conll03-english-sentencepiece.bpe.model",
'xlm-roberta-large-finetuned-conll03-german': "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-finetuned-conll03-german-sentencepiece.bpe.model",
"vocab_file": {
"xlm-roberta-base": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-base-sentencepiece.bpe.model",
"xlm-roberta-large": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-sentencepiece.bpe.model",
"xlm-roberta-large-finetuned-conll02-dutch": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-finetuned-conll02-dutch-sentencepiece.bpe.model",
"xlm-roberta-large-finetuned-conll02-spanish": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-finetuned-conll02-spanish-sentencepiece.bpe.model",
"xlm-roberta-large-finetuned-conll03-english": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-finetuned-conll03-english-sentencepiece.bpe.model",
"xlm-roberta-large-finetuned-conll03-german": "https://s3.amazonaws.com/models.huggingface.co/bert/xlm-roberta-large-finetuned-conll03-german-sentencepiece.bpe.model",
}
}
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
'xlm-roberta-base': 512,
'xlm-roberta-large': 512,
'xlm-roberta-large-finetuned-conll02-dutch': 512,
'xlm-roberta-large-finetuned-conll02-spanish': 512,
'xlm-roberta-large-finetuned-conll03-english': 512,
'xlm-roberta-large-finetuned-conll03-german': 512,
"xlm-roberta-base": 512,
"xlm-roberta-large": 512,
"xlm-roberta-large-finetuned-conll02-dutch": 512,
"xlm-roberta-large-finetuned-conll02-spanish": 512,
"xlm-roberta-large-finetuned-conll03-english": 512,
"xlm-roberta-large-finetuned-conll03-german": 512,
}
class XLMRobertaTokenizer(PreTrainedTokenizer):
"""
Adapted from RobertaTokenizer and XLNetTokenizer
......@@ -56,17 +58,33 @@ class XLMRobertaTokenizer(PreTrainedTokenizer):
- requires `SentencePiece <https://github.com/google/sentencepiece>`_
"""
vocab_files_names = VOCAB_FILES_NAMES
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
def __init__(self, vocab_file, bos_token="<s>", eos_token="</s>", sep_token="</s>",
cls_token="<s>", unk_token="<unk>", pad_token='<pad>', mask_token='<mask>',
**kwargs):
super(XLMRobertaTokenizer, self).__init__(bos_token=bos_token, eos_token=eos_token, unk_token=unk_token,
sep_token=sep_token, cls_token=cls_token, pad_token=pad_token,
def __init__(
self,
vocab_file,
bos_token="<s>",
eos_token="</s>",
sep_token="</s>",
cls_token="<s>",
unk_token="<unk>",
pad_token="<pad>",
mask_token="<mask>",
**kwargs
):
super(XLMRobertaTokenizer, self).__init__(
bos_token=bos_token,
eos_token=eos_token,
unk_token=unk_token,
sep_token=sep_token,
cls_token=cls_token,
pad_token=pad_token,
mask_token=mask_token,
**kwargs)
**kwargs
)
self.max_len_single_sentence = self.max_len - 2 # take into account special tokens
self.max_len_sentences_pair = self.max_len - 4 # take into account special tokens
self.sp_model = spm.SentencePieceProcessor()
......@@ -85,7 +103,7 @@ class XLMRobertaTokenizer(PreTrainedTokenizer):
# The first "real" token "," has position 4 in the original fairseq vocab and position 3 in the spm vocab
self.fairseq_offset = 1
self.fairseq_tokens_to_ids['<mask>'] = len(self.sp_model) + len(self.fairseq_tokens_to_ids)
self.fairseq_tokens_to_ids["<mask>"] = len(self.sp_model) + len(self.fairseq_tokens_to_ids)
self.fairseq_ids_to_tokens = {v: k for k, v in self.fairseq_tokens_to_ids.items()}
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
......@@ -119,8 +137,10 @@ class XLMRobertaTokenizer(PreTrainedTokenizer):
"""
if already_has_special_tokens:
if token_ids_1 is not None:
raise ValueError("You should not supply a second sequence if the provided sequence of "
"ids is already formated with special tokens for the model.")
raise ValueError(
"You should not supply a second sequence if the provided sequence of "
"ids is already formated with special tokens for the model."
)
return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0))
if token_ids_1 is None:
......@@ -164,7 +184,7 @@ class XLMRobertaTokenizer(PreTrainedTokenizer):
def convert_tokens_to_string(self, tokens):
"""Converts a sequence of tokens (strings for sub-words) in a single string."""
out_string = ''.join(tokens).replace(SPIECE_UNDERLINE, ' ').strip()
out_string = "".join(tokens).replace(SPIECE_UNDERLINE, " ").strip()
return out_string
def save_vocabulary(self, save_directory):
......@@ -174,7 +194,7 @@ class XLMRobertaTokenizer(PreTrainedTokenizer):
if not os.path.isdir(save_directory):
logger.error("Vocabulary path ({}) should be a directory".format(save_directory))
return
out_vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES['vocab_file'])
out_vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES["vocab_file"])
if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
copyfile(self.vocab_file, out_vocab_file)
......
......@@ -13,36 +13,35 @@
# See the License for the specific language governing permissions and
# limitations under the License.
""" Tokenization classes for XLNet model."""
from __future__ import (absolute_import, division, print_function,
unicode_literals)
from __future__ import absolute_import, division, print_function, unicode_literals
import logging
import os
import unicodedata
from shutil import copyfile
import unicodedata
import six
from .tokenization_utils import PreTrainedTokenizer
logger = logging.getLogger(__name__)
VOCAB_FILES_NAMES = {'vocab_file': 'spiece.model'}
VOCAB_FILES_NAMES = {"vocab_file": "spiece.model"}
PRETRAINED_VOCAB_FILES_MAP = {
'vocab_file':
{
'xlnet-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-base-cased-spiece.model",
'xlnet-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-large-cased-spiece.model",
"vocab_file": {
"xlnet-base-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-base-cased-spiece.model",
"xlnet-large-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/xlnet-large-cased-spiece.model",
}
}
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
'xlnet-base-cased': None,
'xlnet-large-cased': None,
"xlnet-base-cased": None,
"xlnet-large-cased": None,
}
SPIECE_UNDERLINE = u'▁'
SPIECE_UNDERLINE = "▁"
# Segments (not really needed)
SEG_ID_A = 0
......@@ -51,27 +50,46 @@ SEG_ID_CLS = 2
SEG_ID_SEP = 3
SEG_ID_PAD = 4
class XLNetTokenizer(PreTrainedTokenizer):
"""
SentencePiece based tokenizer. Peculiarities:
- requires `SentencePiece <https://github.com/google/sentencepiece>`_
"""
vocab_files_names = VOCAB_FILES_NAMES
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
padding_side = "left"
def __init__(self, vocab_file,
do_lower_case=False, remove_space=True, keep_accents=False,
bos_token="<s>", eos_token="</s>", unk_token="<unk>", sep_token="<sep>",
pad_token="<pad>", cls_token="<cls>", mask_token="<mask>",
additional_special_tokens=["<eop>", "<eod>"], **kwargs):
super(XLNetTokenizer, self).__init__(bos_token=bos_token, eos_token=eos_token,
unk_token=unk_token, sep_token=sep_token,
pad_token=pad_token, cls_token=cls_token,
mask_token=mask_token, additional_special_tokens=
additional_special_tokens, **kwargs)
def __init__(
self,
vocab_file,
do_lower_case=False,
remove_space=True,
keep_accents=False,
bos_token="<s>",
eos_token="</s>",
unk_token="<unk>",
sep_token="<sep>",
pad_token="<pad>",
cls_token="<cls>",
mask_token="<mask>",
additional_special_tokens=["<eop>", "<eod>"],
**kwargs
):
super(XLNetTokenizer, self).__init__(
bos_token=bos_token,
eos_token=eos_token,
unk_token=unk_token,
sep_token=sep_token,
pad_token=pad_token,
cls_token=cls_token,
mask_token=mask_token,
additional_special_tokens=additional_special_tokens,
**kwargs
)
self.max_len_single_sentence = self.max_len - 2 # take into account special tokens
self.max_len_sentences_pair = self.max_len - 3 # take into account special tokens
......@@ -80,8 +98,10 @@ class XLNetTokenizer(PreTrainedTokenizer):
try:
import sentencepiece as spm
except ImportError:
logger.warning("You need to install SentencePiece to use XLNetTokenizer: https://github.com/google/sentencepiece"
"pip install sentencepiece")
logger.warning(
"You need to install SentencePiece to use XLNetTokenizer: https://github.com/google/sentencepiece"
"pip install sentencepiece"
)
self.do_lower_case = do_lower_case
self.remove_space = remove_space
......@@ -105,24 +125,26 @@ class XLNetTokenizer(PreTrainedTokenizer):
try:
import sentencepiece as spm
except ImportError:
logger.warning("You need to install SentencePiece to use XLNetTokenizer: https://github.com/google/sentencepiece"
"pip install sentencepiece")
logger.warning(
"You need to install SentencePiece to use XLNetTokenizer: https://github.com/google/sentencepiece"
"pip install sentencepiece"
)
self.sp_model = spm.SentencePieceProcessor()
self.sp_model.Load(self.vocab_file)
def preprocess_text(self, inputs):
if self.remove_space:
outputs = ' '.join(inputs.strip().split())
outputs = " ".join(inputs.strip().split())
else:
outputs = inputs
outputs = outputs.replace("``", '"').replace("''", '"')
if six.PY2 and isinstance(outputs, str):
outputs = outputs.decode('utf-8')
outputs = outputs.decode("utf-8")
if not self.keep_accents:
outputs = unicodedata.normalize('NFKD', outputs)
outputs = ''.join([c for c in outputs if not unicodedata.combining(c)])
outputs = unicodedata.normalize("NFKD", outputs)
outputs = "".join([c for c in outputs if not unicodedata.combining(c)])
if self.do_lower_case:
outputs = outputs.lower()
......@@ -134,8 +156,8 @@ class XLNetTokenizer(PreTrainedTokenizer):
"""
text = self.preprocess_text(text)
# note(zhiliny): in some systems, sentencepiece only accepts str for py2
if six.PY2 and isinstance(text, unicode):
text = text.encode('utf-8')
if six.PY2 and isinstance(text, unicode): # noqa: F821
text = text.encode("utf-8")
if not sample:
pieces = self.sp_model.EncodeAsPieces(text)
......@@ -143,9 +165,8 @@ class XLNetTokenizer(PreTrainedTokenizer):
pieces = self.sp_model.SampleEncodeAsPieces(text, 64, 0.1)
new_pieces = []
for piece in pieces:
if len(piece) > 1 and piece[-1] == str(',') and piece[-2].isdigit():
cur_pieces = self.sp_model.EncodeAsPieces(
piece[:-1].replace(SPIECE_UNDERLINE, ''))
if len(piece) > 1 and piece[-1] == str(",") and piece[-2].isdigit():
cur_pieces = self.sp_model.EncodeAsPieces(piece[:-1].replace(SPIECE_UNDERLINE, ""))
if piece[0] != SPIECE_UNDERLINE and cur_pieces[0][0] == SPIECE_UNDERLINE:
if len(cur_pieces[0]) == 1:
cur_pieces = cur_pieces[1:]
......@@ -161,7 +182,7 @@ class XLNetTokenizer(PreTrainedTokenizer):
ret_pieces = []
for piece in new_pieces:
if isinstance(piece, str):
piece = piece.decode('utf-8')
piece = piece.decode("utf-8")
ret_pieces.append(piece)
new_pieces = ret_pieces
......@@ -175,12 +196,12 @@ class XLNetTokenizer(PreTrainedTokenizer):
"""Converts an index (integer) in a token (string/unicode) using the vocab."""
token = self.sp_model.IdToPiece(index)
if six.PY2 and return_unicode and isinstance(token, str):
token = token.decode('utf-8')
token = token.decode("utf-8")
return token
def convert_tokens_to_string(self, tokens):
"""Converts a sequence of tokens (strings for sub-words) in a single string."""
out_string = ''.join(tokens).replace(SPIECE_UNDERLINE, ' ').strip()
out_string = "".join(tokens).replace(SPIECE_UNDERLINE, " ").strip()
return out_string
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
......@@ -215,8 +236,10 @@ class XLNetTokenizer(PreTrainedTokenizer):
if already_has_special_tokens:
if token_ids_1 is not None:
raise ValueError("You should not supply a second sequence if the provided sequence of "
"ids is already formated with special tokens for the model.")
raise ValueError(
"You should not supply a second sequence if the provided sequence of "
"ids is already formated with special tokens for the model."
)
return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0))
if token_ids_1 is not None:
......@@ -247,7 +270,7 @@ class XLNetTokenizer(PreTrainedTokenizer):
if not os.path.isdir(save_directory):
logger.error("Vocabulary path ({}) should be a directory".format(save_directory))
return
out_vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES['vocab_file'])
out_vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES["vocab_file"])
if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
copyfile(self.vocab_file, out_vocab_file)
......
''' Script for downloading all GLUE data.
""" Script for downloading all GLUE data.
Original source: https://gist.github.com/W4ngatang/60c2bdb54d156a41194446737ce03e2e
Note: for legal reasons, we are unable to host MRPC.
......@@ -16,31 +16,33 @@ rm MSRParaphraseCorpus.msi
1/30/19: It looks like SentEval is no longer hosting their extracted and tokenized MRPC data, so you'll need to download the data from the original source for now.
2/11/19: It looks like SentEval actually *is* hosting the extracted data. Hooray!
'''
"""
import argparse
import os
import sys
import shutil
import argparse
import tempfile
import urllib.request
import zipfile
TASKS = ["CoLA", "SST", "MRPC", "QQP", "STS", "MNLI", "SNLI", "QNLI", "RTE", "WNLI", "diagnostic"]
TASK2PATH = {"CoLA":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FCoLA.zip?alt=media&token=46d5e637-3411-4188-bc44-5809b5bfb5f4',
"SST":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSST-2.zip?alt=media&token=aabc5f6b-e466-44a2-b9b4-cf6337f84ac8',
"MRPC":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2Fmrpc_dev_ids.tsv?alt=media&token=ec5c0836-31d5-48f4-b431-7480817f1adc',
"QQP":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FQQP.zip?alt=media&token=700c6acf-160d-4d89-81d1-de4191d02cb5',
"STS":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSTS-B.zip?alt=media&token=bddb94a7-8706-4e0d-a694-1109e12273b5',
"MNLI":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FMNLI.zip?alt=media&token=50329ea1-e339-40e2-809c-10c40afff3ce',
"SNLI":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSNLI.zip?alt=media&token=4afcfbb2-ff0c-4b2d-a09a-dbf07926f4df',
"QNLI": 'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FQNLIv2.zip?alt=media&token=6fdcf570-0fc5-4631-8456-9505272d1601',
"RTE":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FRTE.zip?alt=media&token=5efa7e85-a0bb-4f19-8ea2-9e1840f077fb',
"WNLI":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FWNLI.zip?alt=media&token=068ad0a0-ded7-4bd7-99a5-5e00222e0faf',
"diagnostic":'https://storage.googleapis.com/mtl-sentence-representations.appspot.com/tsvsWithoutLabels%2FAX.tsv?GoogleAccessId=firebase-adminsdk-0khhl@mtl-sentence-representations.iam.gserviceaccount.com&Expires=2498860800&Signature=DuQ2CSPt2Yfre0C%2BiISrVYrIFaZH1Lc7hBVZDD4ZyR7fZYOMNOUGpi8QxBmTNOrNPjR3z1cggo7WXFfrgECP6FBJSsURv8Ybrue8Ypt%2FTPxbuJ0Xc2FhDi%2BarnecCBFO77RSbfuz%2Bs95hRrYhTnByqu3U%2FYZPaj3tZt5QdfpH2IUROY8LiBXoXS46LE%2FgOQc%2FKN%2BA9SoscRDYsnxHfG0IjXGwHN%2Bf88q6hOmAxeNPx6moDulUF6XMUAaXCSFU%2BnRO2RDL9CapWxj%2BDl7syNyHhB7987hZ80B%2FwFkQ3MEs8auvt5XW1%2Bd4aCU7ytgM69r8JDCwibfhZxpaa4gd50QXQ%3D%3D'}
MRPC_TRAIN = 'https://dl.fbaipublicfiles.com/senteval/senteval_data/msr_paraphrase_train.txt'
MRPC_TEST = 'https://dl.fbaipublicfiles.com/senteval/senteval_data/msr_paraphrase_test.txt'
TASK2PATH = {
"CoLA": "https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FCoLA.zip?alt=media&token=46d5e637-3411-4188-bc44-5809b5bfb5f4",
"SST": "https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSST-2.zip?alt=media&token=aabc5f6b-e466-44a2-b9b4-cf6337f84ac8",
"MRPC": "https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2Fmrpc_dev_ids.tsv?alt=media&token=ec5c0836-31d5-48f4-b431-7480817f1adc",
"QQP": "https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FQQP.zip?alt=media&token=700c6acf-160d-4d89-81d1-de4191d02cb5",
"STS": "https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSTS-B.zip?alt=media&token=bddb94a7-8706-4e0d-a694-1109e12273b5",
"MNLI": "https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FMNLI.zip?alt=media&token=50329ea1-e339-40e2-809c-10c40afff3ce",
"SNLI": "https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSNLI.zip?alt=media&token=4afcfbb2-ff0c-4b2d-a09a-dbf07926f4df",
"QNLI": "https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FQNLIv2.zip?alt=media&token=6fdcf570-0fc5-4631-8456-9505272d1601",
"RTE": "https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FRTE.zip?alt=media&token=5efa7e85-a0bb-4f19-8ea2-9e1840f077fb",
"WNLI": "https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FWNLI.zip?alt=media&token=068ad0a0-ded7-4bd7-99a5-5e00222e0faf",
"diagnostic": "https://storage.googleapis.com/mtl-sentence-representations.appspot.com/tsvsWithoutLabels%2FAX.tsv?GoogleAccessId=firebase-adminsdk-0khhl@mtl-sentence-representations.iam.gserviceaccount.com&Expires=2498860800&Signature=DuQ2CSPt2Yfre0C%2BiISrVYrIFaZH1Lc7hBVZDD4ZyR7fZYOMNOUGpi8QxBmTNOrNPjR3z1cggo7WXFfrgECP6FBJSsURv8Ybrue8Ypt%2FTPxbuJ0Xc2FhDi%2BarnecCBFO77RSbfuz%2Bs95hRrYhTnByqu3U%2FYZPaj3tZt5QdfpH2IUROY8LiBXoXS46LE%2FgOQc%2FKN%2BA9SoscRDYsnxHfG0IjXGwHN%2Bf88q6hOmAxeNPx6moDulUF6XMUAaXCSFU%2BnRO2RDL9CapWxj%2BDl7syNyHhB7987hZ80B%2FwFkQ3MEs8auvt5XW1%2Bd4aCU7ytgM69r8JDCwibfhZxpaa4gd50QXQ%3D%3D",
}
MRPC_TRAIN = "https://dl.fbaipublicfiles.com/senteval/senteval_data/msr_paraphrase_train.txt"
MRPC_TEST = "https://dl.fbaipublicfiles.com/senteval/senteval_data/msr_paraphrase_test.txt"
def download_and_extract(task, data_dir):
print("Downloading and extracting %s..." % task)
......@@ -51,6 +53,7 @@ def download_and_extract(task, data_dir):
os.remove(data_file)
print("\tCompleted!")
def format_mrpc(data_dir, path_to_data):
print("Processing MRPC...")
mrpc_dir = os.path.join(data_dir, "MRPC")
......@@ -72,30 +75,32 @@ def format_mrpc(data_dir, path_to_data):
dev_ids = []
with open(os.path.join(mrpc_dir, "dev_ids.tsv"), encoding="utf8") as ids_fh:
for row in ids_fh:
dev_ids.append(row.strip().split('\t'))
dev_ids.append(row.strip().split("\t"))
with open(mrpc_train_file, encoding="utf8") as data_fh, \
open(os.path.join(mrpc_dir, "train.tsv"), 'w', encoding="utf8") as train_fh, \
open(os.path.join(mrpc_dir, "dev.tsv"), 'w', encoding="utf8") as dev_fh:
with open(mrpc_train_file, encoding="utf8") as data_fh, open(
os.path.join(mrpc_dir, "train.tsv"), "w", encoding="utf8"
) as train_fh, open(os.path.join(mrpc_dir, "dev.tsv"), "w", encoding="utf8") as dev_fh:
header = data_fh.readline()
train_fh.write(header)
dev_fh.write(header)
for row in data_fh:
label, id1, id2, s1, s2 = row.strip().split('\t')
label, id1, id2, s1, s2 = row.strip().split("\t")
if [id1, id2] in dev_ids:
dev_fh.write("%s\t%s\t%s\t%s\t%s\n" % (label, id1, id2, s1, s2))
else:
train_fh.write("%s\t%s\t%s\t%s\t%s\n" % (label, id1, id2, s1, s2))
with open(mrpc_test_file, encoding="utf8") as data_fh, \
open(os.path.join(mrpc_dir, "test.tsv"), 'w', encoding="utf8") as test_fh:
with open(mrpc_test_file, encoding="utf8") as data_fh, open(
os.path.join(mrpc_dir, "test.tsv"), "w", encoding="utf8"
) as test_fh:
header = data_fh.readline()
test_fh.write("index\t#1 ID\t#2 ID\t#1 String\t#2 String\n")
for idx, row in enumerate(data_fh):
label, id1, id2, s1, s2 = row.strip().split('\t')
label, id1, id2, s1, s2 = row.strip().split("\t")
test_fh.write("%d\t%s\t%s\t%s\t%s\n" % (idx, id1, id2, s1, s2))
print("\tCompleted!")
def download_diagnostic(data_dir):
print("Downloading and extracting diagnostic...")
if not os.path.isdir(os.path.join(data_dir, "diagnostic")):
......@@ -105,8 +110,9 @@ def download_diagnostic(data_dir):
print("\tCompleted!")
return
def get_tasks(task_names):
task_names = task_names.split(',')
task_names = task_names.split(",")
if "all" in task_names:
tasks = TASKS
else:
......@@ -116,13 +122,19 @@ def get_tasks(task_names):
tasks.append(task_name)
return tasks
def main(arguments):
parser = argparse.ArgumentParser()
parser.add_argument('--data_dir', help='directory to save data to', type=str, default='glue_data')
parser.add_argument('--tasks', help='tasks to download data for as a comma separated string',
type=str, default='all')
parser.add_argument('--path_to_mrpc', help='path to directory containing extracted MRPC data, msr_paraphrase_train.txt and msr_paraphrase_text.txt',
type=str, default='')
parser.add_argument("--data_dir", help="directory to save data to", type=str, default="glue_data")
parser.add_argument(
"--tasks", help="tasks to download data for as a comma separated string", type=str, default="all"
)
parser.add_argument(
"--path_to_mrpc",
help="path to directory containing extracted MRPC data, msr_paraphrase_train.txt and msr_paraphrase_text.txt",
type=str,
default="",
)
args = parser.parse_args(arguments)
if not os.path.isdir(args.data_dir):
......@@ -130,13 +142,13 @@ def main(arguments):
tasks = get_tasks(args.tasks)
for task in tasks:
if task == 'MRPC':
if task == "MRPC":
format_mrpc(args.data_dir, args.path_to_mrpc)
elif task == 'diagnostic':
elif task == "diagnostic":
download_diagnostic(args.data_dir)
else:
download_and_extract(task, args.data_dir)
if __name__ == '__main__':
if __name__ == "__main__":
sys.exit(main(sys.argv[1:]))
......@@ -43,7 +43,7 @@ def scan_code_for_links(source):
""" Scans the file to find links using a regular expression.
Returns a list of links.
"""
with open(source, 'r') as content:
with open(source, "r") as content:
content = content.read()
raw_links = re.findall(REGEXP_FIND_S3_LINKS, content)
links = [prefix + suffix for _, prefix, suffix in raw_links]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment