Commit 161d1e06 authored by Myle Ott's avatar Myle Ott Committed by Facebook Github Bot
Browse files

Fix dummy batch when --max-tokens is small (fixes #347)

Summary: Pull Request resolved: https://github.com/pytorch/fairseq/pull/366

Differential Revision: D13058513

Pulled By: myleott

fbshipit-source-id: a146d2cfb345d404775ed8d6b8e4a4ad4e7a33b4
parent 7e60d45b
......@@ -192,7 +192,7 @@ class LanguagePairDataset(FairseqDataset):
max_positions,
(self.max_source_positions, self.max_target_positions),
)
bsz = num_tokens // max(src_len, tgt_len)
bsz = max(num_tokens // max(src_len, tgt_len), 1)
return self.collater([
{
'id': i,
......
......@@ -153,7 +153,7 @@ class MonolingualDataset(FairseqDataset):
"""Return a dummy batch with a given number of tokens."""
if isinstance(max_positions, float) or isinstance(max_positions, int):
tgt_len = min(tgt_len, max_positions)
bsz = num_tokens // tgt_len
bsz = max(num_tokens // tgt_len, 1)
target = self.vocab.dummy_sentence(tgt_len + 2)
source, past_target, future_target = target[1:-1], target[2:], target[:-2]
source, target = self._make_source_target(source, past_target, future_target)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment