Unverified Commit 02362e6a authored by Lintang Sutawika's avatar Lintang Sutawika Committed by GitHub
Browse files

Merge branch 'big-refactor' into refactor-more-tasks

parents b1b5239d 0ba4ae15
......@@ -94,10 +94,10 @@ class MultiChoice:
def __contains__(self, values):
for value in values.split(","):
if len(fnmatch.filter(self.choices, value)) == 0:
eval_logger.warning("{} is not in task list.".format(value))
eval_logger.info(f"Available tasks to choose:")
for choice in self.choices:
eval_logger.info(f" - {choice}")
raise ValueError("'{}' is not in task list".format(value))
return True
def __iter__(self):
......@@ -468,7 +468,8 @@ def pad_and_concat(
), f"Unrecognized padding type: '{padding_side}' not 'left' or 'right'"
for i, tensor in enumerate(tensors):
tensor = tensor.squeeze(0) # squeeze, in case passed [1, seq] size
if len(tensor.shape) == 2:
tensor = tensor.squeeze(0) # squeeze, in case passed [1, seq] size
tensor_len = tensor.shape[0]
if tensor_len < max_length:
if padding_side == "right":
......
......@@ -43,7 +43,6 @@ def parse_args():
parser.add_argument("--decontamination_ngrams_path", default=None)
parser.add_argument("--check_integrity", action="store_true")
parser.add_argument("--write_out", action="store_true", default=False)
parser.add_argument("--output_base_path", type=str, default=None)
return parser.parse_args()
......@@ -90,7 +89,6 @@ def main():
decontamination_ngrams_path=args.decontamination_ngrams_path,
check_integrity=args.check_integrity,
write_out=args.write_out,
output_base_path=args.output_base_path,
)
if results is not None:
......
......@@ -43,7 +43,7 @@ setuptools.setup(
"sacrebleu==1.5.0",
"scikit-learn>=0.24.1",
"sqlitedict",
"torch>=1.7",
"torch>=1.8",
"tqdm-multiprocess",
"transformers>=4.1",
"zstandard",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment