Unverified Commit 94ca7d2f authored by BOSEOP KIM's avatar BOSEOP KIM Committed by GitHub
Browse files

Fix type issue in using bucketing with Trainer (#18051)



* Fix type issue in using bucketing with Trainer

- Fix type issues in LengthGrouperSampler,
  DistributedLengthGroupedSampler

refs: #18003

* Change logging type in LengthGroupedSampler

- Change `logger.warning` to `logger.info`
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Change logging type in DistributedLengthGroupedSampler

- Change `logger.warning` to `logger.info`
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Remove adundant clause in LengthGroupedSampler

- Use `elif`
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Remove adundant clause in DistributedLengthGroupedSampler

- Use `elif`
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Apply black, isort to modified codes in the script
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent 9bd39685
......@@ -558,6 +558,12 @@ class LengthGroupedSampler(Sampler):
f"'{model_input_name}' key."
)
lengths = [len(feature[model_input_name]) for feature in dataset]
elif isinstance(lengths, torch.Tensor):
logger.info(
"If lengths is a torch.Tensor, LengthGroupedSampler will be slow. Converting lengths to List[int]..."
)
lengths = lengths.tolist()
self.lengths = lengths
self.generator = generator
......@@ -614,6 +620,13 @@ class DistributedLengthGroupedSampler(DistributedSampler):
f"'{model_input_name}' key."
)
lengths = [len(feature[model_input_name]) for feature in dataset]
elif isinstance(lengths, torch.Tensor):
logger.info(
"If lengths is a torch.Tensor, DistributedLengthGroupedSampler will be slow. Converting lengths to"
" List[int]..."
)
lengths = lengths.tolist()
self.lengths = lengths
# If the dataset length is evenly divisible by # of replicas, then there
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment