Unverified Commit 3dd030d2 authored by zspo's avatar zspo Committed by GitHub
Browse files

fix register (#25779)

parent dc0c1029
...@@ -1046,7 +1046,7 @@ class AutoConfig: ...@@ -1046,7 +1046,7 @@ class AutoConfig:
) )
@staticmethod @staticmethod
def register(model_type, config): def register(model_type, config, exist_ok=False):
""" """
Register a new configuration for this class. Register a new configuration for this class.
...@@ -1060,4 +1060,4 @@ class AutoConfig: ...@@ -1060,4 +1060,4 @@ class AutoConfig:
f"you passed (config has {config.model_type} and you passed {model_type}. Fix one of those so they " f"you passed (config has {config.model_type} and you passed {model_type}. Fix one of those so they "
"match!" "match!"
) )
CONFIG_MAPPING.register(model_type, config) CONFIG_MAPPING.register(model_type, config, exist_ok=exist_ok)
...@@ -379,7 +379,7 @@ class AutoFeatureExtractor: ...@@ -379,7 +379,7 @@ class AutoFeatureExtractor:
) )
@staticmethod @staticmethod
def register(config_class, feature_extractor_class): def register(config_class, feature_extractor_class, exist_ok=False):
""" """
Register a new feature extractor for this class. Register a new feature extractor for this class.
...@@ -388,4 +388,4 @@ class AutoFeatureExtractor: ...@@ -388,4 +388,4 @@ class AutoFeatureExtractor:
The configuration corresponding to the model to register. The configuration corresponding to the model to register.
feature_extractor_class ([`FeatureExtractorMixin`]): The feature extractor to register. feature_extractor_class ([`FeatureExtractorMixin`]): The feature extractor to register.
""" """
FEATURE_EXTRACTOR_MAPPING.register(config_class, feature_extractor_class) FEATURE_EXTRACTOR_MAPPING.register(config_class, feature_extractor_class, exist_ok=exist_ok)
...@@ -405,7 +405,7 @@ class AutoImageProcessor: ...@@ -405,7 +405,7 @@ class AutoImageProcessor:
) )
@staticmethod @staticmethod
def register(config_class, image_processor_class): def register(config_class, image_processor_class, exist_ok=False):
""" """
Register a new image processor for this class. Register a new image processor for this class.
...@@ -414,4 +414,4 @@ class AutoImageProcessor: ...@@ -414,4 +414,4 @@ class AutoImageProcessor:
The configuration corresponding to the model to register. The configuration corresponding to the model to register.
image_processor_class ([`ImageProcessingMixin`]): The image processor to register. image_processor_class ([`ImageProcessingMixin`]): The image processor to register.
""" """
IMAGE_PROCESSOR_MAPPING.register(config_class, image_processor_class) IMAGE_PROCESSOR_MAPPING.register(config_class, image_processor_class, exist_ok=exist_ok)
...@@ -319,7 +319,7 @@ class AutoProcessor: ...@@ -319,7 +319,7 @@ class AutoProcessor:
) )
@staticmethod @staticmethod
def register(config_class, processor_class): def register(config_class, processor_class, exist_ok=False):
""" """
Register a new processor for this class. Register a new processor for this class.
...@@ -328,4 +328,4 @@ class AutoProcessor: ...@@ -328,4 +328,4 @@ class AutoProcessor:
The configuration corresponding to the model to register. The configuration corresponding to the model to register.
processor_class ([`FeatureExtractorMixin`]): The processor to register. processor_class ([`FeatureExtractorMixin`]): The processor to register.
""" """
PROCESSOR_MAPPING.register(config_class, processor_class) PROCESSOR_MAPPING.register(config_class, processor_class, exist_ok=exist_ok)
...@@ -764,7 +764,7 @@ class AutoTokenizer: ...@@ -764,7 +764,7 @@ class AutoTokenizer:
f"Model type should be one of {', '.join(c.__name__ for c in TOKENIZER_MAPPING.keys())}." f"Model type should be one of {', '.join(c.__name__ for c in TOKENIZER_MAPPING.keys())}."
) )
def register(config_class, slow_tokenizer_class=None, fast_tokenizer_class=None): def register(config_class, slow_tokenizer_class=None, fast_tokenizer_class=None, exist_ok=False):
""" """
Register a new tokenizer in this mapping. Register a new tokenizer in this mapping.
...@@ -805,4 +805,4 @@ class AutoTokenizer: ...@@ -805,4 +805,4 @@ class AutoTokenizer:
if fast_tokenizer_class is None: if fast_tokenizer_class is None:
fast_tokenizer_class = existing_fast fast_tokenizer_class = existing_fast
TOKENIZER_MAPPING.register(config_class, (slow_tokenizer_class, fast_tokenizer_class)) TOKENIZER_MAPPING.register(config_class, (slow_tokenizer_class, fast_tokenizer_class), exist_ok=exist_ok)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment