Unverified Commit e248e9b0 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

up (#14154)

parent 1f60df81
......@@ -187,6 +187,13 @@ class DataTrainingArguments:
"so that the cached datasets can consequently be loaded in distributed training"
},
)
# Opt-in switch (off by default) for authenticating remote file downloads
# with the token stored by the CLI login command.
use_auth_token: Optional[bool] = field(
    default=False,
    metadata={
        # The two literals are concatenated implicitly, so the first one
        # must end with a space to keep the words apart in the help output.
        "help": "If :obj:`True`, will use the token generated when running "
        ":obj:`transformers-cli login` as HTTP bearer authorization for remote files."
    },
)
@dataclass
......@@ -408,7 +415,9 @@ def main():
# one local process can concurrently download model & vocab.
# load config
config = AutoConfig.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir)
config = AutoConfig.from_pretrained(
model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_auth_token=data_args.use_auth_token
)
# tokenizer is defined by `tokenizer_class` if present in config else by `model_type`
config_for_tokenizer = config if config.tokenizer_class is not None else None
......@@ -422,9 +431,10 @@ def main():
unk_token="[UNK]",
pad_token="[PAD]",
word_delimiter_token="|",
use_auth_token=data_args.use_auth_token,
)
feature_extractor = AutoFeatureExtractor.from_pretrained(
model_args.model_name_or_path, cache_dir=model_args.cache_dir
model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_auth_token=data_args.use_auth_token
)
processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)
......@@ -447,7 +457,10 @@ def main():
# create model
model = AutoModelForCTC.from_pretrained(
model_args.model_name_or_path, cache_dir=model_args.cache_dir, config=config
model_args.model_name_or_path,
cache_dir=model_args.cache_dir,
config=config,
use_auth_token=data_args.use_auth_token,
)
# freeze encoder
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment