Unverified Commit 28a08116 authored by regisss, committed by GitHub

Improve mismatched sizes management when loading a pretrained model (#17257)

- Add --ignore_mismatched_sizes argument to classification examples

- Expand the error message shown when loading a model whose head dimensions differ from the expected ones
parent 1f13ba81
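For context, the new flag simply forwards to the existing `ignore_mismatched_sizes` argument of `from_pretrained`. A minimal sketch of the behavior being exposed (the checkpoint and label count are illustrative, not taken from this commit):

```python
from transformers import AutoModelForSequenceClassification

# Checkpoint fine-tuned with a 2-label head, loaded for a 5-label task.
# Without ignore_mismatched_sizes=True this raises a "size mismatch" RuntimeError.
model = AutoModelForSequenceClassification.from_pretrained(
    "textattack/bert-base-uncased-SST-2",  # illustrative 2-label checkpoint
    num_labels=5,
    ignore_mismatched_sizes=True,  # discard the old head, newly initialize a 5-label one
)
```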
@@ -18,13 +18,13 @@ limitations under the License.
The following examples showcase how to fine-tune `Wav2Vec2` for audio classification using PyTorch.
Speech recognition models that have been pretrained in an unsupervised fashion on audio data alone,
*e.g.* [Wav2Vec2](https://huggingface.co/transformers/main/model_doc/wav2vec2.html),
[HuBERT](https://huggingface.co/transformers/main/model_doc/hubert.html),
[XLSR-Wav2Vec2](https://huggingface.co/transformers/main/model_doc/xlsr_wav2vec2.html), have been shown to require
very little annotated data to yield good performance on speech classification datasets.

## Single-GPU

The following command shows how to fine-tune [wav2vec2-base](https://huggingface.co/facebook/wav2vec2-base) on the 🗣️ [Keyword Spotting subset](https://huggingface.co/datasets/superb#ks) of the SUPERB dataset.
@@ -63,7 +63,9 @@ On a single V100 GPU (16GB), this script should run in ~14 minutes and yield acc
👀 See the results here: [anton-l/wav2vec2-base-ft-keyword-spotting](https://huggingface.co/anton-l/wav2vec2-base-ft-keyword-spotting)

> If your model's classification head dimensions do not match the number of labels in the dataset, you can specify `--ignore_mismatched_sizes` to adapt it.
## Multi-GPU
The following command shows how to fine-tune [wav2vec2-base](https://huggingface.co/facebook/wav2vec2-base) for 🌎 **Language Identification** on the [CommonLanguage dataset](https://huggingface.co/datasets/anton-l/common_language).
@@ -139,7 +141,7 @@ It has been verified that the script works for the following datasets:

| Dataset | Pretrained Model | # transformer layers | Accuracy on eval | GPU setup | Training time | Fine-tuned Model & Logs |
|---------|------------------|----------------------|------------------|-----------|---------------|--------------------------|
| Keyword Spotting | [ntu-spml/distilhubert](https://huggingface.co/ntu-spml/distilhubert) | 2 | 0.9706 | 1 V100 GPU | 11min | [here](https://huggingface.co/anton-l/distilhubert-ft-keyword-spotting) |
| Keyword Spotting | [facebook/wav2vec2-base](https://huggingface.co/facebook/wav2vec2-base) | 12 | 0.9826 | 1 V100 GPU | 14min | [here](https://huggingface.co/anton-l/wav2vec2-base-ft-keyword-spotting) |
| Keyword Spotting | [facebook/hubert-base-ls960](https://huggingface.co/facebook/hubert-base-ls960) | 12 | 0.9819 | 1 V100 GPU | 14min | [here](https://huggingface.co/anton-l/hubert-base-ft-keyword-spotting) |
| Keyword Spotting | [asapp/sew-mid-100k](https://huggingface.co/asapp/sew-mid-100k) | 24 | 0.9757 | 1 V100 GPU | 15min | [here](https://huggingface.co/anton-l/sew-mid-100k-ft-keyword-spotting) |
...
@@ -163,6 +163,10 @@ class ModelArguments:
    freeze_feature_extractor: Optional[bool] = field(
        default=None, metadata={"help": "Whether to freeze the feature extractor layers of the model."}
    )
    ignore_mismatched_sizes: bool = field(
        default=False,
        metadata={"help": "Will enable loading a pretrained model whose head dimensions are different."},
    )
    def __post_init__(self):
        if not self.freeze_feature_extractor and self.freeze_feature_encoder:
@@ -333,6 +337,7 @@ def main():
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
        ignore_mismatched_sizes=model_args.ignore_mismatched_sizes,
    )
    # freeze the convolutional waveform encoder
...
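In the Trainer-based scripts the option is declared as a dataclass field like the one above, which is typically consumed by `HfArgumentParser` and turned into a `--ignore_mismatched_sizes` command-line flag. A rough sketch of that round trip, assuming the standard `HfArgumentParser` wiring used by the example scripts and trimmed to the new field only:

```python
from dataclasses import dataclass, field
from transformers import HfArgumentParser

@dataclass
class ModelArguments:
    ignore_mismatched_sizes: bool = field(
        default=False,
        metadata={"help": "Will enable loading a pretrained model whose head dimensions are different."},
    )

# Passing the bare flag flips the boolean; omitting it keeps the default False.
(model_args,) = HfArgumentParser(ModelArguments).parse_args_into_dataclasses(
    args=["--ignore_mismatched_sizes"]
)
assert model_args.ignore_mismatched_sizes is True
```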
@@ -62,9 +62,11 @@ python run_image_classification.py \
Note that you can replace the model and dataset by simply setting the `model_name_or_path` and `dataset_name` arguments respectively, with any model or dataset from the [hub](https://huggingface.co/). For an overview of all possible arguments, we refer to the [docs](https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments) of the `TrainingArguments`, which can be passed as flags.

> If your model's classification head dimensions do not match the number of labels in the dataset, you can specify `--ignore_mismatched_sizes` to adapt it.
### Using your own data

To use your own dataset, there are 2 ways:
- you can either provide your own folders as `--train_dir` and/or `--validation_dir` arguments
- you can upload your dataset to the hub (possibly as a private repo, if you prefer), and simply pass the `--dataset_name` argument.
...
@@ -150,6 +150,10 @@ class ModelArguments:
            )
        },
    )
    ignore_mismatched_sizes: bool = field(
        default=False,
        metadata={"help": "Will enable loading a pretrained model whose head dimensions are different."},
    )
def collate_fn(examples):
@@ -269,6 +273,7 @@ def main():
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
        ignore_mismatched_sizes=model_args.ignore_mismatched_sizes,
    )
    feature_extractor = AutoFeatureExtractor.from_pretrained(
        model_args.feature_extractor_name or model_args.model_name_or_path,
...
@@ -165,6 +165,11 @@ def parse_args():
        action="store_true",
        help="Whether to load in all available experiment trackers from the environment and use them for logging.",
    )
    parser.add_argument(
        "--ignore_mismatched_sizes",
        action="store_true",
        help="Whether or not to enable loading a pretrained model whose head dimensions are different.",
    )
    args = parser.parse_args()
    # Sanity checks
@@ -278,6 +283,7 @@ def main():
        args.model_name_or_path,
        from_tf=bool(".ckpt" in args.model_name_or_path),
        config=config,
        ignore_mismatched_sizes=args.ignore_mismatched_sizes,
    )
    # Preprocessing the datasets
...
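The scripts that build their options with `parse_args()` use plain `argparse` rather than dataclasses, so the same option is registered with `action="store_true"`, as in the hunk above. A condensed sketch of the resulting parsing behavior:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--ignore_mismatched_sizes",
    action="store_true",  # defaults to False, becomes True when the flag is present
    help="Whether or not to enable loading a pretrained model whose head dimensions are different.",
)

assert parser.parse_args([]).ignore_mismatched_sizes is False
assert parser.parse_args(["--ignore_mismatched_sizes"]).ignore_mismatched_sizes is True
```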
@@ -22,7 +22,7 @@ Based on the script [`run_glue.py`](https://github.com/huggingface/transformers/
Fine-tuning the library models for sequence classification on the GLUE benchmark: [General Language Understanding
Evaluation](https://gluebenchmark.com/). This script can fine-tune any of the models on the [hub](https://huggingface.co/models)
and can also be used for a dataset hosted on our [hub](https://huggingface.co/datasets) or your own data in a CSV or a JSON file
(the script might need some tweaks in that case, refer to the comments inside for help).

GLUE is made up of a total of 9 different tasks. Here is how to run the script on one of them:
@@ -79,6 +79,8 @@ python run_glue.py \
  --output_dir /tmp/imdb/
```

> If your model's classification head dimensions do not match the number of labels in the dataset, you can specify `--ignore_mismatched_sizes` to adapt it.

### Mixed precision training
...
@@ -196,6 +196,10 @@ class ModelArguments:
            )
        },
    )
    ignore_mismatched_sizes: bool = field(
        default=False,
        metadata={"help": "Will enable loading a pretrained model whose head dimensions are different."},
    )
def main():
@@ -364,6 +368,7 @@ def main():
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
        ignore_mismatched_sizes=model_args.ignore_mismatched_sizes,
    )
    # Preprocessing the raw_datasets
...
@@ -170,6 +170,11 @@ def parse_args():
        action="store_true",
        help="Whether to load in all available experiment trackers from the environment and use them for logging.",
    )
    parser.add_argument(
        "--ignore_mismatched_sizes",
        action="store_true",
        help="Whether or not to enable loading a pretrained model whose head dimensions are different.",
    )
    args = parser.parse_args()
    # Sanity checks
@@ -288,6 +293,7 @@ def main():
        args.model_name_or_path,
        from_tf=bool(".ckpt" in args.model_name_or_path),
        config=config,
        ignore_mismatched_sizes=args.ignore_mismatched_sizes,
    )
    # Preprocessing the datasets
...
...@@ -162,6 +162,10 @@ class ModelArguments: ...@@ -162,6 +162,10 @@ class ModelArguments:
) )
}, },
) )
    ignore_mismatched_sizes: bool = field(
        default=False,
        metadata={"help": "Will enable loading a pretrained model whose head dimensions are different."},
    )
def main():
@@ -291,6 +295,7 @@ def main():
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
        ignore_mismatched_sizes=model_args.ignore_mismatched_sizes,
    )
    # Preprocessing the datasets
...
@@ -55,6 +55,8 @@ uses special features of those tokenizers. You can check if your favorite model
[this table](https://huggingface.co/transformers/index.html#supported-frameworks); if it doesn't, you can still use the old version
of the script.

> If your model's classification head dimensions do not match the number of labels in the dataset, you can specify `--ignore_mismatched_sizes` to adapt it.

## Old version of the script

You can find the old version of the PyTorch script [here](https://github.com/huggingface/transformers/blob/main/examples/legacy/token-classification/run_ner.py).
...
@@ -87,6 +87,10 @@ class ModelArguments:
            )
        },
    )
    ignore_mismatched_sizes: bool = field(
        default=False,
        metadata={"help": "Will enable loading a pretrained model whose head dimensions are different."},
    )
@dataclass
@@ -364,6 +368,7 @@ def main():
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
        ignore_mismatched_sizes=model_args.ignore_mismatched_sizes,
    )
    # Tokenizer check: this script requires a fast tokenizer.
...
@@ -223,6 +223,11 @@ def parse_args():
        action="store_true",
        help="Whether to load in all available experiment trackers from the environment and use them for logging.",
    )
    parser.add_argument(
        "--ignore_mismatched_sizes",
        action="store_true",
        help="Whether or not to enable loading a pretrained model whose head dimensions are different.",
    )
    args = parser.parse_args()
    # Sanity checks
@@ -383,6 +388,7 @@ def main():
            args.model_name_or_path,
            from_tf=bool(".ckpt" in args.model_name_or_path),
            config=config,
            ignore_mismatched_sizes=args.ignore_mismatched_sizes,
        )
    else:
        logger.info("Training new model from scratch")
...
@@ -2253,6 +2253,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
        if len(error_msgs) > 0:
            error_msg = "\n\t".join(error_msgs)
if "size mismatch" in error_msg:
error_msg += (
"\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method."
)
            raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")

        if len(unexpected_keys) > 0:
...
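The `PreTrainedModel` hunk only appends a hint to the existing `RuntimeError`, so a head-size mismatch now fails with an actionable message. A hedged sketch of what a caller sees (the checkpoint name is illustrative):

```python
from transformers import AutoModelForSequenceClassification

try:
    # 2-label checkpoint loaded into a 5-label head, without the escape hatch
    AutoModelForSequenceClassification.from_pretrained(
        "textattack/bert-base-uncased-SST-2",  # illustrative 2-label checkpoint
        num_labels=5,
    )
except RuntimeError as err:
    # The message now ends with the suggestion added in this commit:
    # "You may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method."
    assert "size mismatch" in str(err)
```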