"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "9d9b872b66f9ab9b7b7c73f2c00985dd92c4121b"
Unverified Commit ce9724e1 authored by James Thomin's avatar James Thomin Committed by GitHub
Browse files

Fix bug in input check for LengthGroupSampler (#10783)

This commit fixes a bug in the LengthGroupSampler where if
model_input_name is not set, the default value is None instead of
"input_ids"
parent 5f19c07a
......@@ -497,7 +497,7 @@ class LengthGroupedSampler(Sampler):
self.batch_size = batch_size
self.model_input_name = model_input_name if model_input_name is not None else "input_ids"
if lengths is None:
if not isinstance(dataset[0], dict) or model_input_name not in dataset[0]:
if not isinstance(dataset[0], dict) or self.model_input_name not in dataset[0]:
raise ValueError(
"Can only automatically infer lengths for datasets whose items are dictionaries with an "
f"'{self.model_input_name}' key."
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment