"src/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "8dba1808852e7f5c08f91296006ec254cecdd1b1"
Commit 161d1e06 authored by Myle Ott's avatar Myle Ott Committed by Facebook Github Bot
Browse files

Fix dummy batch when --max-tokens is small (fixes #347)

Summary: Pull Request resolved: https://github.com/pytorch/fairseq/pull/366

Differential Revision: D13058513

Pulled By: myleott

fbshipit-source-id: a146d2cfb345d404775ed8d6b8e4a4ad4e7a33b4
parent 7e60d45b
...@@ -192,7 +192,7 @@ class LanguagePairDataset(FairseqDataset): ...@@ -192,7 +192,7 @@ class LanguagePairDataset(FairseqDataset):
max_positions, max_positions,
(self.max_source_positions, self.max_target_positions), (self.max_source_positions, self.max_target_positions),
) )
bsz = num_tokens // max(src_len, tgt_len) bsz = max(num_tokens // max(src_len, tgt_len), 1)
return self.collater([ return self.collater([
{ {
'id': i, 'id': i,
......
...@@ -153,7 +153,7 @@ class MonolingualDataset(FairseqDataset): ...@@ -153,7 +153,7 @@ class MonolingualDataset(FairseqDataset):
"""Return a dummy batch with a given number of tokens.""" """Return a dummy batch with a given number of tokens."""
if isinstance(max_positions, float) or isinstance(max_positions, int): if isinstance(max_positions, float) or isinstance(max_positions, int):
tgt_len = min(tgt_len, max_positions) tgt_len = min(tgt_len, max_positions)
bsz = num_tokens // tgt_len bsz = max(num_tokens // tgt_len, 1)
target = self.vocab.dummy_sentence(tgt_len + 2) target = self.vocab.dummy_sentence(tgt_len + 2)
source, past_target, future_target = target[1:-1], target[2:], target[:-2] source, past_target, future_target = target[1:-1], target[2:], target[:-2]
source, target = self._make_source_target(source, past_target, future_target) source, target = self._make_source_target(source, past_target, future_target)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment