chenpangpang/transformers, commit 15550ce0
Authored May 15, 2020 by Julien Chaumond
Parent: 62427d08

[skip ci] remove local rank
Showing 2 changed files with 4 additions and 12 deletions:

  examples/language-modeling/run_language_modeling.py   +3 -11
  src/transformers/configuration_roberta.py             +1 -1
examples/language-modeling/run_language_modeling.py

@@ -115,7 +115,7 @@ class DataTrainingArguments:
     )


-def get_dataset(args: DataTrainingArguments, tokenizer: PreTrainedTokenizer, evaluate=False, local_rank=-1):
+def get_dataset(args: DataTrainingArguments, tokenizer: PreTrainedTokenizer, evaluate=False):
     file_path = args.eval_data_file if evaluate else args.train_data_file
     if args.line_by_line:
         return LineByLineTextDataset(tokenizer=tokenizer, file_path=file_path, block_size=args.block_size)
@@ -216,16 +216,8 @@ def main():
     data_args.block_size = min(data_args.block_size, tokenizer.max_len)

     # Get datasets
-    train_dataset = (
-        get_dataset(data_args, tokenizer=tokenizer, local_rank=training_args.local_rank)
-        if training_args.do_train
-        else None
-    )
-    eval_dataset = (
-        get_dataset(data_args, tokenizer=tokenizer, local_rank=training_args.local_rank, evaluate=True)
-        if training_args.do_eval
-        else None
-    )
+    train_dataset = get_dataset(data_args, tokenizer=tokenizer) if training_args.do_train else None
+    eval_dataset = get_dataset(data_args, tokenizer=tokenizer, evaluate=True) if training_args.do_eval else None
     data_collator = DataCollatorForLanguageModeling(
         tokenizer=tokenizer, mlm=data_args.mlm, mlm_probability=data_args.mlm_probability
     )
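For readers skimming this file's diff: after the commit, get_dataset no longer accepts or forwards a local_rank argument, and the call sites in main() collapse to one-line conditional expressions with identical behavior. Below is a minimal standalone sketch, not the example script itself, exercising the same transformers APIs that appear in the hunks (LineByLineTextDataset and DataCollatorForLanguageModeling); the model name, input file, and block size are hypothetical placeholders.

# Minimal sketch of the APIs touched by this diff; "roberta-base",
# "train.txt", and block_size=512 are hypothetical, not from the script.
from transformers import (
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    LineByLineTextDataset,
)

tokenizer = AutoTokenizer.from_pretrained("roberta-base")

# As in the updated script, dataset construction takes no local_rank argument.
train_dataset = LineByLineTextDataset(
    tokenizer=tokenizer,
    file_path="train.txt",  # hypothetical file with one example per line
    block_size=512,
)

data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=True,               # masked-language-modeling objective
    mlm_probability=0.15,   # the script reads this from data_args
)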
src/transformers/configuration_roberta.py
@@ -68,6 +68,6 @@ class RobertaConfig(BertConfig):
     model_type = "roberta"

     def __init__(self, pad_token_id=1, bos_token_id=0, eos_token_id=2, **kwargs):
-        """Constructs FlaubertConfig.
+        """Constructs RobertaConfig.
         """
         super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
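Beyond the copy-paste docstring fix, the signature in this hunk documents RoBERTa's special-token defaults, which override what RobertaConfig inherits from BertConfig. A minimal sketch checking those defaults against an installed transformers package:

# Verifies the defaults visible in the __init__ signature above.
from transformers import RobertaConfig

config = RobertaConfig()
assert (config.pad_token_id, config.bos_token_id, config.eos_token_id) == (1, 0, 2)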