Unverified Commit e248e9b0 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

up (#14154)

parent 1f60df81
...@@ -187,6 +187,13 @@ class DataTrainingArguments: ...@@ -187,6 +187,13 @@ class DataTrainingArguments:
"so that the cached datasets can consequently be loaded in distributed training" "so that the cached datasets can consequently be loaded in distributed training"
}, },
) )
# Opt-in flag for authenticated downloads of remote files. Defaults to False;
# main() forwards it as `use_auth_token` to the `from_pretrained` calls.
use_auth_token: Optional[bool] = field(
    default=False,
    metadata={
        # Adjacent literals are concatenated, so the first one must end with a space.
        "help": "If :obj:`True`, will use the token generated when running "
        ":obj:`transformers-cli login` as HTTP bearer authorization for remote files."
    },
)
@dataclass @dataclass
...@@ -408,7 +415,9 @@ def main(): ...@@ -408,7 +415,9 @@ def main():
# one local process can concurrently download model & vocab. # one local process can concurrently download model & vocab.
# load config # load config
config = AutoConfig.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir) config = AutoConfig.from_pretrained(
model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_auth_token=data_args.use_auth_token
)
# tokenizer is defined by `tokenizer_class` if present in config else by `model_type` # tokenizer is defined by `tokenizer_class` if present in config else by `model_type`
config_for_tokenizer = config if config.tokenizer_class is not None else None config_for_tokenizer = config if config.tokenizer_class is not None else None
...@@ -422,9 +431,10 @@ def main(): ...@@ -422,9 +431,10 @@ def main():
unk_token="[UNK]", unk_token="[UNK]",
pad_token="[PAD]", pad_token="[PAD]",
word_delimiter_token="|", word_delimiter_token="|",
use_auth_token=data_args.use_auth_token,
) )
feature_extractor = AutoFeatureExtractor.from_pretrained( feature_extractor = AutoFeatureExtractor.from_pretrained(
model_args.model_name_or_path, cache_dir=model_args.cache_dir model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_auth_token=data_args.use_auth_token
) )
processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer) processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)
...@@ -447,7 +457,10 @@ def main(): ...@@ -447,7 +457,10 @@ def main():
# create model # create model
model = AutoModelForCTC.from_pretrained( model = AutoModelForCTC.from_pretrained(
model_args.model_name_or_path, cache_dir=model_args.cache_dir, config=config model_args.model_name_or_path,
cache_dir=model_args.cache_dir,
config=config,
use_auth_token=data_args.use_auth_token,
) )
# freeze encoder # freeze encoder
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment