".github/vscode:/vscode.git/clone" did not exist on "a9fd80336dd93ed671e4eb42242938f4d990a1bf"
Commit b7f1b050 authored by Neel Kant

Lint whole repo

parent c99fa80c
@@ -39,16 +39,13 @@ class RaceDataset(Dataset):
         print_rank_0(' >> total number of samples: {}'.format(
             len(self.samples)))
 
     def __len__(self):
         return len(self.samples)
 
     def __getitem__(self, idx):
         return self.samples[idx]
 
 
 def process_single_datapath(datapath, tokenizer, max_qa_length, max_seq_length):
     """Read in RACE files, combine, clean-up, tokenize, and convert to
     samples."""
...
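The RaceDataset hunk above follows the standard map-style torch Dataset protocol: __len__ reports the sample count and __getitem__ returns one pre-built sample by index. A minimal, self-contained sketch of that pattern (ListDataset and its placeholder samples are illustrative, not part of this repo):

import torch
from torch.utils.data import Dataset, DataLoader

class ListDataset(Dataset):
    """Map-style dataset over a pre-built list of samples."""

    def __init__(self, samples):
        self.samples = samples            # list built once, up front

    def __len__(self):
        return len(self.samples)          # lets the DataLoader know how many indices exist

    def __getitem__(self, idx):
        return self.samples[idx]          # return one example per index

loader = DataLoader(ListDataset([{'text': torch.zeros(4)} for _ in range(8)]),
                    batch_size=2)
for batch in loader:
    print(batch['text'].shape)            # torch.Size([2, 4]) after default collation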
@@ -64,12 +64,12 @@ class _LMDataset(torch.utils.data.Dataset):
     def __getitem__(self, idx):
         start_idx = idx * self.overalapping_eval
         end_idx = start_idx + self.seq_len
-        tokens = self.tokens[start_idx:end_idx+1]
+        tokens = self.tokens[start_idx:end_idx + 1]
         num_tokens = len(tokens)
-        pad_mask = [1]*num_tokens
-        if num_tokens < self.seq_len+1:
-            num_pad = (self.seq_len+1-num_tokens)
-            pad_mask += [0]*(num_pad)
+        pad_mask = [1] * num_tokens
+        if num_tokens < self.seq_len + 1:
+            num_pad = (self.seq_len + 1 - num_tokens)
+            pad_mask += [0] * (num_pad)
             tokens += [self.pad_idx] * num_pad
         pad_mask = np.array(pad_mask[1:])
         if self.overalapping_eval != self.seq_len and idx != 0:
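For context on the _LMDataset hunk: __getitem__ slices the token stream into windows of seq_len + 1 tokens with stride overalapping_eval, right-pads a short final window, and drops the first mask entry so the mask lines up with the shifted target positions. The standalone sketch below reproduces that logic for illustration; lm_window and its zeroing of already-scored positions are my assumptions, not code taken from this commit.

import numpy as np

def lm_window(tokens, idx, seq_len, overlapping_eval, pad_idx=0):
    """Illustrative re-creation of the sliding-window slicing shown above."""
    start_idx = idx * overlapping_eval
    end_idx = start_idx + seq_len
    window = list(tokens[start_idx:end_idx + 1])   # seq_len inputs + 1 extra target token
    pad_mask = [1] * len(window)
    if len(window) < seq_len + 1:                  # right-pad the last, short window
        num_pad = seq_len + 1 - len(window)
        pad_mask += [0] * num_pad
        window += [pad_idx] * num_pad
    pad_mask = np.array(pad_mask[1:])              # align mask with target positions
    if overlapping_eval != seq_len and idx != 0:
        pad_mask[:-overlapping_eval] *= 0          # assumed: score only the newly seen tokens
    return np.array(window), pad_mask

print(lm_window(list(range(10)), idx=2, seq_len=4, overlapping_eval=3))
# (array([6, 7, 8, 9, 0]), array([0, 1, 1, 0]))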
@@ -103,7 +103,7 @@ class _LambadaDataset(torch.utils.data.Dataset):
         last_token = text.split()[-1]
         start_idx = text.rfind(last_token)
         beginning_tokens = self.tokenizer.tokenize(text[:start_idx].strip())
-        last_token = self.tokenizer.tokenize(' '+last_token)
+        last_token = self.tokenizer.tokenize(' ' + last_token)
         return beginning_tokens, last_token
 
     def __len__(self):
@@ -112,14 +112,14 @@ class _LambadaDataset(torch.utils.data.Dataset):
     def __getitem__(self, idx):
         tokens = self.tokens[idx]
         num_tokens = len(tokens)
-        pad_mask = [0]*num_tokens
+        pad_mask = [0] * num_tokens
         labels = self.labels[idx]
-        pad_mask += [1]*len(labels)
-        tokens = tokens+labels
+        pad_mask += [1] * len(labels)
+        tokens = tokens + labels
         num_tokens = len(tokens)
-        if num_tokens < self.seq_len+1:
-            num_pad = (self.seq_len+1-num_tokens)
-            pad_mask += [0]*(num_pad)
+        if num_tokens < self.seq_len + 1:
+            num_pad = (self.seq_len + 1 - num_tokens)
+            pad_mask += [0] * (num_pad)
             tokens += [self.pad_idx] * num_pad
         pad_mask = np.array(pad_mask[1:])
...
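The _LambadaDataset hunk builds its mask the other way round: context positions get 0 and only the final-word label positions get 1, so evaluation scores just the last word of each passage. A small standalone sketch of that masking, assuming a pre-tokenized context and label (lambada_example is illustrative, not the repo's class):

import numpy as np

def lambada_example(context_tokens, label_tokens, seq_len, pad_idx=0):
    """Illustrative re-creation of the LAMBADA-style masking shown above."""
    tokens = list(context_tokens)
    pad_mask = [0] * len(tokens)            # context positions are not scored
    pad_mask += [1] * len(label_tokens)     # only the final-word tokens are scored
    tokens = tokens + list(label_tokens)
    if len(tokens) < seq_len + 1:           # right-pad to a fixed length
        num_pad = seq_len + 1 - len(tokens)
        pad_mask += [0] * num_pad
        tokens += [pad_idx] * num_pad
    return np.array(tokens), np.array(pad_mask[1:])   # drop first entry to align with targets

toks, mask = lambada_example([11, 12, 13], [99], seq_len=6)
print(toks)   # [11 12 13 99  0  0  0]
print(mask)   # [0 0 1 0 0 0]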
@@ -23,21 +23,21 @@ def ptb_detokenizer(string):
     string = string.replace(" \n", "\n")
     string = string.replace("\n ", "\n")
     string = string.replace(" n't", "n't")
-    string = string.replace(" N ","1 ")
+    string = string.replace(" N ", "1 ")
     string = string.replace("$ 1", "$1")
     string = string.replace("# 1", "#1")
     return string
 
 
 def wikitext_detokenizer(string):
-    #contractions
+    # contractions
     string = string.replace("s '", "s'")
     string = re.sub(r"/' [0-9]/", r"/'[0-9]/", string)
     # number separators
     string = string.replace(" @-@ ", "-")
     string = string.replace(" @,@ ", ",")
     string = string.replace(" @.@ ", ".")
-    #punctuation
+    # punctuation
     string = string.replace(" : ", ": ")
     string = string.replace(" ; ", "; ")
     string = string.replace(" . ", ". ")
@@ -54,7 +54,7 @@ def wikitext_detokenizer(string):
     string = string.replace("= = = =", "====")
     string = string.replace("= = =", "===")
     string = string.replace("= =", "==")
-    string = string.replace(" "+chr(176)+" ", chr(176))
+    string = string.replace(" " + chr(176) + " ", chr(176))
     string = string.replace(" \n", "\n")
     string = string.replace("\n ", "\n")
     string = string.replace(" N ", " 1 ")
@@ -78,5 +78,3 @@ def get_detokenizer(path):
     for key in _DETOKENIZERS.keys():
         if key in path:
             return _DETOKENIZERS[key]
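The detokenizers above undo PTB/WikiText tokenization artifacts with plain string replacements, and get_detokenizer picks one by matching a key against the dataset path. A minimal usage sketch under those assumptions; the demo replacement function, the identity fallback, and the example path are hypothetical, not taken from this repo:

def wikitext_detok_demo(string):
    """A couple of the same replacements shown in the diff, for demonstration."""
    string = string.replace(" @-@ ", "-")   # undo hyphen separator
    string = string.replace(" @,@ ", ",")   # undo thousands separator
    string = string.replace(" .", ".")      # reattach sentence-final period
    return string

_DETOKENIZERS = {'wikitext': wikitext_detok_demo}   # keyed by a substring of the data path

def get_detokenizer(path):
    for key in _DETOKENIZERS.keys():
        if key in path:
            return _DETOKENIZERS[key]
    return lambda s: s                               # assumed fallback: leave text unchanged

detok = get_detokenizer('/data/wikitext-103/wiki.test.tokens')
print(detok("the 1 @,@ 000 @-@ year @-@ old text ."))   # "the 1,000-year-old text."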