Unverified Commit 02362e6a authored by Lintang Sutawika's avatar Lintang Sutawika Committed by GitHub
Browse files

Merge branch 'big-refactor' into refactor-more-tasks

parents b1b5239d 0ba4ae15
...@@ -94,10 +94,10 @@ class MultiChoice: ...@@ -94,10 +94,10 @@ class MultiChoice:
def __contains__(self, values): def __contains__(self, values):
for value in values.split(","): for value in values.split(","):
if len(fnmatch.filter(self.choices, value)) == 0: if len(fnmatch.filter(self.choices, value)) == 0:
eval_logger.warning("{} is not in task list.".format(value))
eval_logger.info(f"Available tasks to choose:") eval_logger.info(f"Available tasks to choose:")
for choice in self.choices: for choice in self.choices:
eval_logger.info(f" - {choice}") eval_logger.info(f" - {choice}")
raise ValueError("'{}' is not in task list".format(value))
return True return True
def __iter__(self): def __iter__(self):
...@@ -468,6 +468,7 @@ def pad_and_concat( ...@@ -468,6 +468,7 @@ def pad_and_concat(
), f"Unrecognized padding type: '{padding_side}' not 'left' or 'right'" ), f"Unrecognized padding type: '{padding_side}' not 'left' or 'right'"
for i, tensor in enumerate(tensors): for i, tensor in enumerate(tensors):
if len(tensor.shape) == 2:
tensor = tensor.squeeze(0) # squeeze, in case passed [1, seq] size tensor = tensor.squeeze(0) # squeeze, in case passed [1, seq] size
tensor_len = tensor.shape[0] tensor_len = tensor.shape[0]
if tensor_len < max_length: if tensor_len < max_length:
......
...@@ -43,7 +43,6 @@ def parse_args(): ...@@ -43,7 +43,6 @@ def parse_args():
parser.add_argument("--decontamination_ngrams_path", default=None) parser.add_argument("--decontamination_ngrams_path", default=None)
parser.add_argument("--check_integrity", action="store_true") parser.add_argument("--check_integrity", action="store_true")
parser.add_argument("--write_out", action="store_true", default=False) parser.add_argument("--write_out", action="store_true", default=False)
parser.add_argument("--output_base_path", type=str, default=None)
return parser.parse_args() return parser.parse_args()
...@@ -90,7 +89,6 @@ def main(): ...@@ -90,7 +89,6 @@ def main():
decontamination_ngrams_path=args.decontamination_ngrams_path, decontamination_ngrams_path=args.decontamination_ngrams_path,
check_integrity=args.check_integrity, check_integrity=args.check_integrity,
write_out=args.write_out, write_out=args.write_out,
output_base_path=args.output_base_path,
) )
if results is not None: if results is not None:
......
...@@ -43,7 +43,7 @@ setuptools.setup( ...@@ -43,7 +43,7 @@ setuptools.setup(
"sacrebleu==1.5.0", "sacrebleu==1.5.0",
"scikit-learn>=0.24.1", "scikit-learn>=0.24.1",
"sqlitedict", "sqlitedict",
"torch>=1.7", "torch>=1.8",
"tqdm-multiprocess", "tqdm-multiprocess",
"transformers>=4.1", "transformers>=4.1",
"zstandard", "zstandard",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment