chenpangpang / transformers · Commits

Commit e248e9b0 (unverified)
Authored Oct 26, 2021 by Patrick von Platen, committed Oct 26, 2021 by GitHub

up (#14154)

Parent: 1f60df81
Showing 1 changed file with 16 additions and 3 deletions:

examples/pytorch/speech-recognition/run_speech_recognition_ctc.py (+16, -3)
@@ -187,6 +187,13 @@ class DataTrainingArguments:
             "so that the cached datasets can consequently be loaded in distributed training"
         },
     )
+    use_auth_token: Optional[bool] = field(
+        default=False,
+        metadata={
+            "help": "If :obj:`True`, will use the token generated when running"
+            ":obj:`transformers-cli login` as HTTP bearer authorization for remote files."
+        },
+    )


 @dataclass
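Note on this hunk: fields declared on DataTrainingArguments are turned into command-line flags by HfArgumentParser in the example's main(). A minimal standalone sketch of that mechanism, using a stripped-down dataclass with only the new field (the flag value and shortened help text below are illustrative, not part of this commit):

# Sketch: HfArgumentParser exposes the dataclass field as a --use_auth_token flag.
from dataclasses import dataclass, field
from typing import Optional

from transformers import HfArgumentParser


@dataclass
class DataTrainingArguments:
    # same field as added in the hunk above (help text shortened)
    use_auth_token: Optional[bool] = field(
        default=False,
        metadata={"help": "Use the token from `transformers-cli login` for remote files."},
    )


parser = HfArgumentParser(DataTrainingArguments)
# e.g. `python run_speech_recognition_ctc.py ... --use_auth_token True` ends up here:
(data_args,) = parser.parse_args_into_dataclasses(args=["--use_auth_token", "True"])
print(data_args.use_auth_token)  # True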
@@ -408,7 +415,9 @@ def main():
     # one local process can concurrently download model & vocab.

     # load config
-    config = AutoConfig.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir)
+    config = AutoConfig.from_pretrained(
+        model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_auth_token=data_args.use_auth_token
+    )

     # tokenizer is defined by `tokenizer_class` if present in config else by `model_type`
     config_for_tokenizer = config if config.tokenizer_class is not None else None
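The config hunk above forwards data_args.use_auth_token into AutoConfig.from_pretrained. With use_auth_token=True, transformers sends the token saved by `transformers-cli login`; passing the token string explicitly behaves the same. A hedged standalone sketch (the repo id below is hypothetical):

# Sketch: equivalent standalone call; use_auth_token=True reads the locally saved token.
from huggingface_hub import HfFolder
from transformers import AutoConfig

token = HfFolder.get_token()  # token written by `transformers-cli login`
config = AutoConfig.from_pretrained(
    "my-org/my-private-wav2vec2",  # hypothetical private repo
    cache_dir="./cache",
    use_auth_token=token,          # or use_auth_token=True for the same effect
)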
@@ -422,9 +431,10 @@ def main():
         unk_token="[UNK]",
         pad_token="[PAD]",
         word_delimiter_token="|",
+        use_auth_token=data_args.use_auth_token,
     )
     feature_extractor = AutoFeatureExtractor.from_pretrained(
-        model_args.model_name_or_path, cache_dir=model_args.cache_dir
+        model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_auth_token=data_args.use_auth_token
     )
     processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)
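For reference, a small self-contained sketch of the processor construction this hunk touches: a CTC tokenizer built from a toy character vocab plus a default feature extractor, wrapped in Wav2Vec2Processor the same way as in the script (the vocab and inputs below are placeholders, not part of this commit):

# Sketch: building a Wav2Vec2Processor from a toy character vocab, mirroring the script.
import json

import numpy as np
from transformers import Wav2Vec2CTCTokenizer, Wav2Vec2FeatureExtractor, Wav2Vec2Processor

vocab = {"[PAD]": 0, "[UNK]": 1, "|": 2, "h": 3, "e": 4, "l": 5, "o": 6}
with open("vocab.json", "w") as f:
    json.dump(vocab, f)

tokenizer = Wav2Vec2CTCTokenizer(
    "vocab.json", unk_token="[UNK]", pad_token="[PAD]", word_delimiter_token="|"
)
feature_extractor = Wav2Vec2FeatureExtractor()  # defaults: 16 kHz mono, feature_size=1
processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)

inputs = processor(np.zeros(16000, dtype=np.float32), sampling_rate=16000)  # audio -> input_values
with processor.as_target_processor():
    labels = processor("hello").input_ids  # text -> label ids via the tokenizer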
@@ -447,7 +457,10 @@ def main():

     # create model
     model = AutoModelForCTC.from_pretrained(
-        model_args.model_name_or_path, cache_dir=model_args.cache_dir, config=config
+        model_args.model_name_or_path,
+        cache_dir=model_args.cache_dir,
+        config=config,
+        use_auth_token=data_args.use_auth_token,
     )

     # freeze encoder
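Taken together, the commit threads the same flag through every hub download in the example, so a private fine-tuned checkpoint can be loaded end to end. A hedged sketch of that combined effect outside the script (the repo id is hypothetical and assumes a prior `transformers-cli login`):

# Sketch: the from_pretrained calls touched by this commit, with auth enabled.
from transformers import AutoConfig, AutoFeatureExtractor, AutoModelForCTC, AutoTokenizer

repo = "my-org/my-private-wav2vec2"  # hypothetical private checkpoint

config = AutoConfig.from_pretrained(repo, use_auth_token=True)
tokenizer = AutoTokenizer.from_pretrained(repo, use_auth_token=True)
feature_extractor = AutoFeatureExtractor.from_pretrained(repo, use_auth_token=True)
model = AutoModelForCTC.from_pretrained(repo, config=config, use_auth_token=True)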