"...text-generation-inference.git" did not exist on "d22b0c1fbef747f3c38f6424a7a6d4c90ed408c3"
Commit b3cd43b2 authored by alexeib's avatar alexeib Committed by Myle Ott
Browse files

fix max_positions comparison

parent 2e507d3c
...@@ -239,7 +239,7 @@ def filter_by_size(indices, size_fn, max_positions, raise_exception=False): ...@@ -239,7 +239,7 @@ def filter_by_size(indices, size_fn, max_positions, raise_exception=False):
""" """
def check_size(idx): def check_size(idx):
if isinstance(max_positions, float) or isinstance(max_positions, int): if isinstance(max_positions, float) or isinstance(max_positions, int):
return size_fn(idx) < max_positions return size_fn(idx) <= max_positions
else: else:
return all(a <= b for a, b in zip(size_fn(idx), max_positions)) return all(a <= b for a, b in zip(size_fn(idx), max_positions))
...@@ -250,7 +250,7 @@ def filter_by_size(indices, size_fn, max_positions, raise_exception=False): ...@@ -250,7 +250,7 @@ def filter_by_size(indices, size_fn, max_positions, raise_exception=False):
raise Exception(( raise Exception((
'Size of sample #{} is invalid (={}) since max_positions={}, ' 'Size of sample #{} is invalid (={}) since max_positions={}, '
'skip this example with --skip-invalid-size-inputs-valid-test' 'skip this example with --skip-invalid-size-inputs-valid-test'
).format(idx, self.size(idx), max_positions)) ).format(idx, size_fn(idx), max_positions))
yield idx yield idx
if len(ignored) > 0: if len(ignored) > 0:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment