Unverified Commit 3de07473 authored by Joao Gante, committed by GitHub

Generate: add generation config class (#20218)


Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent 8503cc75
......@@ -18,6 +18,14 @@ Each framework has a generate method for auto-regressive text generation impleme
- TensorFlow [`~generation.TFGenerationMixin.generate`] is implemented in [`~generation.TFGenerationMixin`].
- Flax/JAX [`~generation.FlaxGenerationMixin.generate`] is implemented in [`~generation.FlaxGenerationMixin`].
<!--- TODO: add a brief description of GenerationConfig (with examples) when it becomes usable with generate --->

## GenerationConfig

[[autodoc]] generation.GenerationConfig
    - from_pretrained
    - save_pretrained

## GenerationMixin

[[autodoc]] generation.GenerationMixin
......
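Until the TODO above is addressed in the docs themselves, here is a minimal sketch of the save/load round trip that the two listed methods provide. The parameter values mirror the test added further down in this commit; the temporary directory is used purely for illustration.

```python
import tempfile

from transformers import GenerationConfig

# Capture generation-time parameters in a standalone, serializable object.
generation_config = GenerationConfig(do_sample=True, temperature=0.7, max_length=20)

with tempfile.TemporaryDirectory() as tmp_dir:
    # Saved as generation_config.json by default (see GENERATION_CONFIG_NAME below).
    generation_config.save_pretrained(tmp_dir)
    # Reload and confirm the values survived the round trip.
    reloaded = GenerationConfig.from_pretrained(tmp_dir)

assert reloaded.temperature == 0.7
```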
......@@ -96,7 +96,7 @@ _import_structure = {
"feature_extraction_sequence_utils": ["SequenceFeatureExtractor"],
"feature_extraction_utils": ["BatchFeature", "FeatureExtractionMixin"],
"file_utils": [],
"generation": [],
"generation": ["GenerationConfig"],
"hf_argparser": ["HfArgumentParser"],
"integrations": [
"is_clearml_available",
......@@ -3258,6 +3258,9 @@ if TYPE_CHECKING:
# Feature Extractor
from .feature_extraction_utils import BatchFeature, FeatureExtractionMixin
# Generation
from .generation import GenerationConfig
from .hf_argparser import HfArgumentParser
# Integrations
......
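With both the `_import_structure` entry and the `TYPE_CHECKING` import in place, the class is reachable from the package root as well as from the `generation` subpackage. A quick sketch of the two equivalent import paths:

```python
from transformers import GenerationConfig as TopLevelGenerationConfig
from transformers.generation import GenerationConfig

# The top-level re-export and the subpackage symbol refer to the same class.
assert TopLevelGenerationConfig is GenerationConfig
```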
......@@ -21,7 +21,7 @@ from typing import TYPE_CHECKING
from ..utils import OptionalDependencyNotAvailable, _LazyModule, is_flax_available, is_tf_available, is_torch_available
_import_structure = {}
_import_structure = {"configuration_utils": ["GenerationConfig"]}
try:
......@@ -149,6 +149,8 @@ else:
]
if TYPE_CHECKING:
from .configuration_utils import GenerationConfig
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
......
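For readers unfamiliar with the `_import_structure` / `TYPE_CHECKING` split used here: the dict tells the library's lazy-module machinery which submodule owns each public symbol, while the `TYPE_CHECKING` branch gives static type checkers a real import. Below is a stripped-down sketch of the idea using PEP 562 module `__getattr__`; it is a stand-in for illustration only, not the actual `_LazyModule` implementation.

```python
import importlib
from typing import TYPE_CHECKING

# Maps submodule name -> public symbols it provides (mirrors the diff above).
_import_structure = {"configuration_utils": ["GenerationConfig"]}

if TYPE_CHECKING:
    # Static type checkers see a plain import; nothing is executed at runtime.
    from .configuration_utils import GenerationConfig
else:
    def __getattr__(name):
        # Called only when `name` is not already in the module namespace,
        # so the owning submodule is imported on first access.
        for submodule, symbols in _import_structure.items():
            if name in symbols:
                module = importlib.import_module(f".{submodule}", __name__)
                return getattr(module, name)
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
```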
This diff is collapsed.
......@@ -178,6 +178,7 @@ SAFE_WEIGHTS_INDEX_NAME = "model.safetensors.index.json"
CONFIG_NAME = "config.json"
FEATURE_EXTRACTOR_NAME = "preprocessor_config.json"
IMAGE_PROCESSOR_NAME = FEATURE_EXTRACTOR_NAME
GENERATION_CONFIG_NAME = "generation_config.json"
MODEL_CARD_NAME = "modelcard.json"
SENTENCEPIECE_UNDERLINE = "▁"
......
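The new constant is presumably the default file name used when no explicit `config_name` is passed to `save_pretrained` (the `(None,)` case in the test below). A small check under that assumption, also assuming the constant is re-exported from `transformers.utils` alongside `CONFIG_NAME`:

```python
import os
import tempfile

from transformers import GenerationConfig
from transformers.utils import GENERATION_CONFIG_NAME  # assumed re-export location

with tempfile.TemporaryDirectory() as tmp_dir:
    GenerationConfig(do_sample=True).save_pretrained(tmp_dir)
    # With no config_name given, the file should land at generation_config.json.
    assert os.path.isfile(os.path.join(tmp_dir, GENERATION_CONFIG_NAME))
```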
# coding=utf-8
# Copyright 2022 The HuggingFace Team Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tempfile
import unittest

from parameterized import parameterized

from transformers.generation import GenerationConfig


class GenerationConfigTest(unittest.TestCase):
    @parameterized.expand([(None,), ("foo.json",)])
    def test_save_load_config(self, config_name):
        config = GenerationConfig(
            do_sample=True,
            temperature=0.7,
            length_penalty=1.0,
            bad_words_ids=[[1, 2, 3], [4, 5]],
        )

        with tempfile.TemporaryDirectory() as tmp_dir:
            config.save_pretrained(tmp_dir, config_name=config_name)
            loaded_config = GenerationConfig.from_pretrained(tmp_dir, config_name=config_name)

        # Checks parameters that were specified
        self.assertEqual(loaded_config.do_sample, True)
        self.assertEqual(loaded_config.temperature, 0.7)
        self.assertEqual(loaded_config.length_penalty, 1.0)
        self.assertEqual(loaded_config.bad_words_ids, [[1, 2, 3], [4, 5]])

        # Checks parameters that were not specified (defaults)
        self.assertEqual(loaded_config.top_k, 50)
        self.assertEqual(loaded_config.max_length, 20)
        self.assertEqual(loaded_config.max_time, None)
......@@ -12,8 +12,9 @@ docs/source/en/model_doc/byt5.mdx
docs/source/en/model_doc/tapex.mdx
docs/source/en/model_doc/donut.mdx
docs/source/en/model_doc/encoder-decoder.mdx
src/transformers/generation/utils.py
src/transformers/generation/configuration_utils.py
src/transformers/generation/tf_utils.py
src/transformers/generation/utils.py
src/transformers/models/albert/configuration_albert.py
src/transformers/models/albert/modeling_albert.py
src/transformers/models/albert/modeling_tf_albert.py
......