"...components/git@developer.sourcefind.cn:wangsen/mineru.git" did not exist on "5965b788ea3eae754b21f11c9734169544defcf8"
Commit a0c62e94 authored by sanchit-gandhi's avatar sanchit-gandhi
Browse files

regularisation

parent 8d6d3540
...@@ -248,6 +248,52 @@ class ModelArguments: ...@@ -248,6 +248,52 @@ class ModelArguments:
default=True, default=True,
metadata={"help": "Will enable to load a pretrained model whose head dimensions are different."}, metadata={"help": "Will enable to load a pretrained model whose head dimensions are different."},
) )
attention_dropout: float = field(
default=0.0, metadata={"help": "The dropout ratio for the attention probabilities."}
)
activation_dropout: float = field(
default=0.0, metadata={"help": "The dropout ratio for activations inside the fully connected layer."}
)
feat_proj_dropout: float = field(default=0.0, metadata={"help": "The dropout ratio for the projected features."})
hidden_dropout: float = field(
default=0.0,
metadata={
"help": "The dropout probability for all fully connected layers in the embeddings, encoder, and pooler."
},
)
final_dropout: float = field(
default=0.0,
metadata={"help": "The dropout probability for the final projection layer."},
)
mask_time_prob: float = field(
default=0.05,
metadata={
"help": (
"Probability of each feature vector along the time axis to be chosen as the start of the vector "
"span to be masked. Approximately ``mask_time_prob * sequence_length // mask_time_length`` feature "
"vectors will be masked along the time axis."
)
},
)
mask_time_length: int = field(
default=10,
metadata={"help": "Length of vector span to mask along the time axis."},
)
mask_feature_prob: float = field(
default=0.0,
metadata={
"help": (
"Probability of each feature vector along the feature axis to be chosen as the start of the vectorspan"
" to be masked. Approximately ``mask_feature_prob * sequence_length // mask_feature_length`` feature"
" bins will be masked along the time axis."
)
},
)
mask_feature_length: int = field(
default=10,
metadata={"help": "Length of vector span to mask along the feature axis."},
)
layerdrop: float = field(default=0.0, metadata={"help": "The LayerDrop probability."})
def convert_dataset_str_to_list( def convert_dataset_str_to_list(
...@@ -606,7 +652,9 @@ def main(): ...@@ -606,7 +652,9 @@ def main():
) )
# We'll include these in the model's config to get human readable labels in the Inference API. # We'll include these in the model's config to get human readable labels in the Inference API.
set_labels = set(raw_datasets["train"]["labels"]).union(set(raw_datasets["eval"]["labels"])) set_labels = set(raw_datasets["train"]["labels"])
if training_args.do_eval:
set_labels = set_labels.union(set(raw_datasets["eval"]["labels"]))
label2id, id2label = {}, {} label2id, id2label = {}, {}
for i, label in enumerate(set(set_labels)): for i, label in enumerate(set(set_labels)):
label2id[label] = str(i) label2id[label] = str(i)
...@@ -654,6 +702,22 @@ def main(): ...@@ -654,6 +702,22 @@ def main():
token=model_args.token, token=model_args.token,
trust_remote_code=model_args.trust_remote_code, trust_remote_code=model_args.trust_remote_code,
) )
# adapt config with regularization
config.update(
{
"feat_proj_dropout": model_args.feat_proj_dropout,
"attention_dropout": model_args.attention_dropout,
"hidden_dropout": model_args.hidden_dropout,
"final_dropout": model_args.final_dropout,
"mask_time_prob": model_args.mask_time_prob,
"mask_time_length": model_args.mask_time_length,
"mask_feature_prob": model_args.mask_feature_prob,
"mask_feature_length": model_args.mask_feature_length,
"layerdrop": model_args.layerdrop,
"activation_dropout": model_args.activation_dropout,
}
)
model = AutoModelForAudioClassification.from_pretrained( model = AutoModelForAudioClassification.from_pretrained(
model_args.model_name_or_path, model_args.model_name_or_path,
from_tf=bool(".ckpt" in model_args.model_name_or_path), from_tf=bool(".ckpt" in model_args.model_name_or_path),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment