Commit 30ef667d authored by Angela Fan


add model override argument from load_ensemble_for_inference at generation time, updating readme for stories
parent ff3db3cd
FAIR Sequence-to-Sequence Toolkit for Story Generation
The following commands provide an example of pre-processing data, training a model, and generating text for story generation with the WritingPrompts dataset.
The dataset can be downloaded like this:
@@ -8,7 +8,7 @@ The dataset can be downloaded like this:
curl https://s3.amazonaws.com/fairseq-py/data/writingPrompts.tar.gz | tar xvzf -
```
and contains a train, test, and valid split. The dataset is described here: https://arxiv.org/abs/1805.04833. We model only the first 1000 words of each story, including one newLine token.
Example usage:
@@ -26,5 +26,7 @@ $ python train.py data-bin/writingPrompts -a fconv_self_att_wp --lr 0.25 --clip-
# add the arguments: --pretrained True --pretrained-checkpoint path/to/checkpoint
# Generate:
# Note: to load the pretrained model at generation time, you need to pass a model-override argument telling the fusion model where you have placed the pretrained checkpoint. By default, it loads the exact path of the fusion model's pretrained model from training time. Use model-override if you have moved the pretrained model (or are using our provided models). If you are generating from a non-fusion model, the model-override argument is not necessary.
$ python generate.py data-bin/writingPrompts --path /path/to/trained/model/checkpoint_best.pt --batch-size 32 --beam 1 --sampling --sampling-topk 10 --sampling-temperature 0.8 --nbest 1 --model-overrides "{'pretrained_checkpoint':'/path/to/pretrained/model/checkpoint'}"
```
@@ -81,7 +81,7 @@ class FConvModelSelfAtt(FairseqModel):
trained_encoder, trained_decoder = None, None
pretrained = eval(args.pretrained)
if pretrained:
print("| loading pretrained model")
trained_model = utils.load_ensemble_for_inference(
# not actually for inference, but loads pretrained model parameters
filenames=[args.pretrained_checkpoint],
...
@@ -290,6 +290,8 @@ def add_generation_args(parser):
help='sample from top K likely next words instead of all words')
group.add_argument('--sampling-temperature', default=1, type=float, metavar='N',
help='temperature for random sampling')
group.add_argument('--model-overrides', default="{}", type=str, metavar='DICT',
help='a dictionary used to override, at generation time, model args that were set during model training')
return group
...
@@ -38,7 +38,7 @@ def main(args):
# Load ensemble
print('| loading model(s) from {}'.format(args.path))
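The `--model-overrides` value above is a Python dict literal passed as a string and parsed with `eval()` in this commit. As a minimal sketch of that parsing step, the hypothetical helper below (not part of the patch) uses `ast.literal_eval`, which accepts the same literal-dict syntax while rejecting arbitrary expressions:

```python
import ast

def parse_model_overrides(overrides_str):
    """Parse a --model-overrides string such as
    "{'pretrained_checkpoint': '/path/to/checkpoint'}" into a dict.

    The commit itself calls eval(); ast.literal_eval is a drop-in
    for literal dicts that does not execute arbitrary code.
    """
    overrides = ast.literal_eval(overrides_str)
    if not isinstance(overrides, dict):
        raise ValueError('--model-overrides must be a dict literal')
    return overrides

# The argparse default "{}" parses to an empty dict, i.e. no overrides.
print(parse_model_overrides("{'pretrained_checkpoint': '/models/ckpt.pt'}"))
```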
models, _ = utils.load_ensemble_for_inference(args.path.split(':'), task, model_arg_overrides=eval(args.model_overrides))
# Optimize ensemble for generation
for model in models:
...
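Conceptually, the `model_arg_overrides` dict is merged into the argparse namespace saved inside the checkpoint before the model is rebuilt, so a relocated `pretrained_checkpoint` path takes effect at generation time. A hedged sketch of that merge, using a hypothetical helper rather than fairseq's actual implementation:

```python
import argparse

def override_model_args(saved_args, overrides):
    # Hypothetical sketch (not fairseq's actual code): replace attributes
    # on the argparse.Namespace stored in the checkpoint with the values
    # supplied via --model-overrides at generation time.
    for name, value in overrides.items():
        setattr(saved_args, name, value)
    return saved_args

# e.g. redirect the fusion model's pretrained checkpoint to a new location
# (paths here are illustrative):
saved = argparse.Namespace(pretrained='True',
                           pretrained_checkpoint='/train/time/path.pt')
saved = override_model_args(saved, {'pretrained_checkpoint': '/new/location/checkpoint.pt'})
print(saved.pretrained_checkpoint)
```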