Unverified Commit ce9724e1 authored by James Thomin's avatar James Thomin Committed by GitHub
Browse files

Fix bug in input check for LengthGroupSampler (#10783)

This commit fixes a bug in the LengthGroupSampler where if
model_input_name is not set, the default value is None instead of
"input_ids"
parent 5f19c07a
...@@ -497,7 +497,7 @@ class LengthGroupedSampler(Sampler): ...@@ -497,7 +497,7 @@ class LengthGroupedSampler(Sampler):
self.batch_size = batch_size self.batch_size = batch_size
self.model_input_name = model_input_name if model_input_name is not None else "input_ids" self.model_input_name = model_input_name if model_input_name is not None else "input_ids"
if lengths is None: if lengths is None:
if not isinstance(dataset[0], dict) or model_input_name not in dataset[0]: if not isinstance(dataset[0], dict) or self.model_input_name not in dataset[0]:
raise ValueError( raise ValueError(
"Can only automatically infer lengths for datasets whose items are dictionaries with an " "Can only automatically infer lengths for datasets whose items are dictionaries with an "
f"'{self.model_input_name}' key." f"'{self.model_input_name}' key."
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment