Commit a468870f authored by thomwolf's avatar thomwolf
Browse files

refactoring generation

parent 07bc8efb
...@@ -57,8 +57,19 @@ class PretrainedConfig(object):
        self.torchscript = kwargs.pop('torchscript', False)  # Only used by PyTorch models
        self.use_bfloat16 = kwargs.pop('use_bfloat16', False)
        self.pruned_heads = kwargs.pop('pruned_heads', {})
        # `is_decoder` is used in encoder-decoder models to differentiate the encoder from the decoder
        self.is_decoder = kwargs.pop('is_decoder', False)
# Parameters for sequence generation
self.generate_length = kwargs.pop('generate_length', 10)
self.generate_do_sample = kwargs.pop('generate_do_sample', False)
self.generate_num_beams = kwargs.pop('generate_num_beams', 1)
self.generate_temperature = kwargs.pop('generate_temperature', 1.0)
self.generate_top_k = kwargs.pop('generate_top_k', 50)
self.generate_top_p = kwargs.pop('generate_top_p', 0.0)
self.generate_repetition_penalty = kwargs.pop('generate_repetition_penalty', 1.0)
    def save_pretrained(self, save_directory):
        """ Save a configuration object to the directory `save_directory`, so that it
            can be re-loaded using the :func:`~transformers.PretrainedConfig.from_pretrained` class method.
......
This diff is collapsed.
Markdown is supported
0% uploaded — or attach a file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment