Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
994d8660
"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "e424d2e45d740a7d5cc4c9502bfa1c70f51d1535"
Commit
994d8660
authored
Mar 06, 2019
by
thomwolf
Browse files
fixing PYTORCH_PRETRAINED_BERT_CACHE use in examples
parent
2dd8f524
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
3 additions
and
3 deletions
+3
-3
examples/run_classifier.py
examples/run_classifier.py
+1
-1
examples/run_squad.py
examples/run_squad.py
+1
-1
examples/run_swag.py
examples/run_swag.py
+1
-1
No files found.
examples/run_classifier.py
View file @
994d8660
...
@@ -495,7 +495,7 @@ def main():
...
@@ -495,7 +495,7 @@ def main():
num_train_optimization_steps
=
num_train_optimization_steps
//
torch
.
distributed
.
get_world_size
()
num_train_optimization_steps
=
num_train_optimization_steps
//
torch
.
distributed
.
get_world_size
()
# Prepare model
# Prepare model
cache_dir
=
args
.
cache_dir
if
args
.
cache_dir
else
os
.
path
.
join
(
PYTORCH_PRETRAINED_BERT_CACHE
,
'distributed_{}'
.
format
(
args
.
local_rank
))
cache_dir
=
args
.
cache_dir
if
args
.
cache_dir
else
os
.
path
.
join
(
str
(
PYTORCH_PRETRAINED_BERT_CACHE
)
,
'distributed_{}'
.
format
(
args
.
local_rank
))
model
=
BertForSequenceClassification
.
from_pretrained
(
args
.
bert_model
,
model
=
BertForSequenceClassification
.
from_pretrained
(
args
.
bert_model
,
cache_dir
=
cache_dir
,
cache_dir
=
cache_dir
,
num_labels
=
num_labels
)
num_labels
=
num_labels
)
...
...
examples/run_squad.py
View file @
994d8660
...
@@ -894,7 +894,7 @@ def main():
...
@@ -894,7 +894,7 @@ def main():
# Prepare model
# Prepare model
model
=
BertForQuestionAnswering
.
from_pretrained
(
args
.
bert_model
,
model
=
BertForQuestionAnswering
.
from_pretrained
(
args
.
bert_model
,
cache_dir
=
os
.
path
.
join
(
PYTORCH_PRETRAINED_BERT_CACHE
,
'distributed_{}'
.
format
(
args
.
local_rank
)))
cache_dir
=
os
.
path
.
join
(
str
(
PYTORCH_PRETRAINED_BERT_CACHE
)
,
'distributed_{}'
.
format
(
args
.
local_rank
)))
if
args
.
fp16
:
if
args
.
fp16
:
model
.
half
()
model
.
half
()
...
...
examples/run_swag.py
View file @
994d8660
...
@@ -367,7 +367,7 @@ def main():
...
@@ -367,7 +367,7 @@ def main():
# Prepare model
# Prepare model
model
=
BertForMultipleChoice
.
from_pretrained
(
args
.
bert_model
,
model
=
BertForMultipleChoice
.
from_pretrained
(
args
.
bert_model
,
cache_dir
=
os
.
path
.
join
(
PYTORCH_PRETRAINED_BERT_CACHE
,
'distributed_{}'
.
format
(
args
.
local_rank
)),
cache_dir
=
os
.
path
.
join
(
str
(
PYTORCH_PRETRAINED_BERT_CACHE
)
,
'distributed_{}'
.
format
(
args
.
local_rank
)),
num_choices
=
4
)
num_choices
=
4
)
if
args
.
fp16
:
if
args
.
fp16
:
model
.
half
()
model
.
half
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment