"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "4d6dfbd3762ac57e44d97b1c6d0243cfffd1880b"
Commit e83d9f1c authored by VictorSanh's avatar VictorSanh
Browse files

cleaning - change ' to " (black requirements)

parent ebba9e92
...@@ -114,17 +114,17 @@ class LmSeqsDataset(Dataset): ...@@ -114,17 +114,17 @@ class LmSeqsDataset(Dataset):
""" """
Remove sequences with a (too) high level of unknown tokens. Remove sequences with a (too) high level of unknown tokens.
""" """
if 'unk_token' not in self.params.special_tok_ids: if "unk_token" not in self.params.special_tok_ids:
return return
else: else:
unk_token_id = self.params.special_tok_ids['unk_token'] unk_token_id = self.params.special_tok_ids["unk_token"]
init_size = len(self) init_size = len(self)
unk_occs = np.array([np.count_nonzero(a == unk_token_id) for a in self.token_ids]) unk_occs = np.array([np.count_nonzero(a == unk_token_id) for a in self.token_ids])
indices = (unk_occs/self.lengths) < 0.5 indices = (unk_occs / self.lengths) < 0.5
self.token_ids = self.token_ids[indices] self.token_ids = self.token_ids[indices]
self.lengths = self.lengths[indices] self.lengths = self.lengths[indices]
new_size = len(self) new_size = len(self)
logger.info(f'Remove {init_size - new_size} sequences with a high level of unknown tokens (50%).') logger.info(f"Remove {init_size - new_size} sequences with a high level of unknown tokens (50%).")
def print_statistics(self): def print_statistics(self):
""" """
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment