Commit 55a989e8 authored by Myle Ott's avatar Myle Ott
Browse files

Fix max_positions_valid in train.py

parent c6d6256b
...@@ -74,7 +74,10 @@ def main(): ...@@ -74,7 +74,10 @@ def main():
# The max number of positions can be different for train and valid # The max number of positions can be different for train and valid
# e.g., RNNs may support more positions at test time than seen in training # e.g., RNNs may support more positions at test time than seen in training
max_positions_train = (args.max_source_positions, args.max_target_positions) max_positions_train = (args.max_source_positions, args.max_target_positions)
max_positions_valid = (model.max_encoder_positions(), model.max_decoder_positions()) max_positions_valid = (
min(args.max_source_positions, model.max_encoder_positions()),
min(args.max_target_positions, model.max_decoder_positions())
)
# Start multiprocessing # Start multiprocessing
trainer = MultiprocessingTrainer(args, model, criterion) trainer = MultiprocessingTrainer(args, model, criterion)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment