Commit 8d5a47c7 (unverified)
Authored Dec 20, 2019 by Thomas Wolf; committed via GitHub on Dec 20, 2019

Merge pull request #2243 from huggingface/fix-xlm-roberta

fixing xlm-roberta tokenizer max_length and automodels

Parents: 65c75fc5, 79e4a6a2

Showing 7 changed files with 98 additions and 57 deletions (+98, -57)
transformers/commands/run.py                +22  -10
transformers/commands/serving.py             +7   -3
transformers/modeling_auto.py               +14   -3
transformers/modeling_utils.py               +1   -1
transformers/pipelines.py                   +39  -33
transformers/tokenization_utils.py           +5   -1
transformers/tokenization_xlm_roberta.py    +10   -6
transformers/commands/run.py

@@ -9,6 +9,9 @@ logger = logging.getLogger(__name__)  # pylint: disable=invalid-name
 def try_infer_format_from_ext(path: str):
+    if not path:
+        return 'pipe'
+
     for ext in PipelineDataFormat.SUPPORTED_FORMATS:
         if path.endswith(ext):
             return ext
@@ -20,9 +23,16 @@ def try_infer_format_from_ext(path: str):
 def run_command_factory(args):
-    nlp = pipeline(task=args.task, model=args.model, config=args.config, tokenizer=args.tokenizer, device=args.device)
+    nlp = pipeline(task=args.task,
+                   model=args.model if args.model else None,
+                   config=args.config,
+                   tokenizer=args.tokenizer,
+                   device=args.device)
     format = try_infer_format_from_ext(args.input) if args.format == 'infer' else args.format
-    reader = PipelineDataFormat.from_str(format, args.output, args.input, args.column)
+    reader = PipelineDataFormat.from_str(format=format,
+                                         output_path=args.output,
+                                         input_path=args.input,
+                                         column=args.column if args.column else nlp.default_input_names)
     return RunCommand(nlp, reader)
@@ -35,24 +45,26 @@ class RunCommand(BaseTransformersCLICommand):
     @staticmethod
     def register_subcommand(parser: ArgumentParser):
         run_parser = parser.add_parser('run', help="Run a pipeline through the CLI")
-        run_parser.add_argument('--device', type=int, default=-1, help='Indicate the device to run onto, -1 indicates CPU, >= 0 indicates GPU (default: -1)')
         run_parser.add_argument('--task', choices=SUPPORTED_TASKS.keys(), help='Task to run')
-        run_parser.add_argument('--model', type=str, required=True, help='Name or path to the model to instantiate.')
-        run_parser.add_argument('--input', type=str, help='Path to the file to use for inference')
-        run_parser.add_argument('--output', type=str, help='Path to the file that will be used post to write results.')
+        run_parser.add_argument('--model', type=str, help='Name or path to the model to instantiate.')
         run_parser.add_argument('--config', type=str, help='Name or path to the model\'s config to instantiate.')
         run_parser.add_argument('--tokenizer', type=str, help='Name of the tokenizer to use. (default: same as the model name)')
         run_parser.add_argument('--column', type=str, help='Name of the column to use as input. (For multi columns input as QA use column1,columns2)')
         run_parser.add_argument('--format', type=str, default='infer', choices=PipelineDataFormat.SUPPORTED_FORMATS, help='Input format to read from')
+        run_parser.add_argument('--input', type=str, help='Path to the file to use for inference')
+        run_parser.add_argument('--output', type=str, help='Path to the file that will be used post to write results.')
+        run_parser.add_argument('--device', type=int, default=-1, help='Indicate the device to run onto, -1 indicates CPU, >= 0 indicates GPU (default: -1)')
         run_parser.set_defaults(func=run_command_factory)

     def run(self):
-        nlp, output = self._nlp, []
+        nlp, outputs = self._nlp, []
+
         for entry in self._reader:
-            if self._reader.is_multi_columns:
-                output += nlp(**entry)
-            else:
-                output += nlp(entry)
+            output = nlp(**entry) if self._reader.is_multi_columns else nlp(entry)
+            if isinstance(output, dict):
+                outputs.append(output)
+            else:
+                outputs += output

         # Saving data
         if self._nlp.binary_output:
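The net effect in run.py: --model is no longer required, a missing --column falls back to the pipeline's default_input_names, and run() now collects dict results correctly instead of concatenating them. A minimal sketch of the same fallback done programmatically (the input sentence is illustrative; the task's default checkpoint is downloaded on first use):

from transformers import pipeline

# model=None is now a valid input: the task's default checkpoint is used,
# exactly as the CLI does when --model is omitted.
nlp = pipeline(task='sentiment-analysis', model=None)
print(nlp('Transformers pipelines are easy to run.'))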
transformers/commands/serving.py

@@ -24,7 +24,11 @@ def serve_command_factory(args: Namespace):
     Factory function used to instantiate serving server from provided command line arguments.
     :return: ServeCommand
     """
-    nlp = pipeline(task=args.task, model=args.model, config=args.config, tokenizer=args.tokenizer, device=args.device)
+    nlp = pipeline(task=args.task,
+                   model=args.model if args.model else None,
+                   config=args.config,
+                   tokenizer=args.tokenizer,
+                   device=args.device)
     return ServeCommand(nlp, args.host, args.port)
@@ -68,12 +72,12 @@ class ServeCommand(BaseTransformersCLICommand):
         """
         serve_parser = parser.add_parser('serve', help='CLI tool to run inference requests through REST and GraphQL endpoints.')
         serve_parser.add_argument('--task', type=str, choices=SUPPORTED_TASKS.keys(), help='The task to run the pipeline on')
-        serve_parser.add_argument('--device', type=int, default=-1, help='Indicate the device to run onto, -1 indicates CPU, >= 0 indicates GPU (default: -1)')
         serve_parser.add_argument('--host', type=str, default='localhost', help='Interface the server will listen on.')
         serve_parser.add_argument('--port', type=int, default=8888, help='Port the serving will listen to.')
-        serve_parser.add_argument('--model', type=str, required=True, help='Model\'s name or path to stored model.')
+        serve_parser.add_argument('--model', type=str, help='Model\'s name or path to stored model.')
         serve_parser.add_argument('--config', type=str, help='Model\'s config name or path to stored model.')
         serve_parser.add_argument('--tokenizer', type=str, help='Tokenizer name to use.')
+        serve_parser.add_argument('--device', type=int, default=-1, help='Indicate the device to run onto, -1 indicates CPU, >= 0 indicates GPU (default: -1)')
         serve_parser.set_defaults(func=serve_command_factory)

     def __init__(self, pipeline: Pipeline, host: str, port: int):
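serving.py mirrors the same relaxation: --model becomes optional and --device moves to the end of the argument list. A hedged programmatic equivalent (the ServeCommand signature is taken from the diff above; the final run() call follows the BaseTransformersCLICommand convention and is an assumption):

from transformers import pipeline
from transformers.commands.serving import ServeCommand

# Build the same pipeline the factory builds when --model is omitted,
# then serve it on the CLI defaults (localhost:8888).
nlp = pipeline(task='ner', model=None, config=None, tokenizer=None, device=-1)
server = ServeCommand(nlp, 'localhost', 8888)
server.run()  # assumption: the standard CLI command entry point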
transformers/modeling_auto.py

@@ -20,7 +20,7 @@ import logging
 from .configuration_auto import (AlbertConfig, BertConfig, CamembertConfig, CTRLConfig,
                                  DistilBertConfig, GPT2Config, OpenAIGPTConfig, RobertaConfig,
-                                 TransfoXLConfig, XLMConfig, XLNetConfig)
+                                 TransfoXLConfig, XLMConfig, XLNetConfig, XLMRobertaConfig)

 from .modeling_bert import BertModel, BertForMaskedLM, BertForSequenceClassification, BertForQuestionAnswering, \
     BertForTokenClassification, BERT_PRETRAINED_MODEL_ARCHIVE_MAP
@@ -41,7 +41,8 @@ from .modeling_camembert import CamembertModel, CamembertForMaskedLM, CamembertF
 from .modeling_albert import AlbertModel, AlbertForMaskedLM, AlbertForSequenceClassification, \
     AlbertForQuestionAnswering, ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP
 from .modeling_t5 import T5Model, T5WithLMHeadModel, T5_PRETRAINED_MODEL_ARCHIVE_MAP
-from .modeling_xlm_roberta import XLMRobertaModel, XLMRobertaForMaskedLM, XLMRobertaForSequenceClassification, XLMRobertaForMultipleChoice, XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
+from .modeling_xlm_roberta import XLMRobertaModel, XLMRobertaForMaskedLM, XLMRobertaForSequenceClassification, \
+    XLMRobertaForMultipleChoice, XLMRobertaForTokenClassification, XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP

 from .modeling_utils import PreTrainedModel, SequenceSummary
@@ -146,6 +147,8 @@ class AutoModel(object):
             return AlbertModel(config)
         elif isinstance(config, CamembertConfig):
             return CamembertModel(config)
+        elif isinstance(config, XLMRobertaConfig):
+            return XLMRobertaModel(config)
         raise ValueError("Unrecognized configuration class {}".format(config))

     @classmethod
@@ -333,6 +336,8 @@ class AutoModelWithLMHead(object):
             return XLMWithLMHeadModel(config)
         elif isinstance(config, CTRLConfig):
             return CTRLLMHeadModel(config)
+        elif isinstance(config, XLMRobertaConfig):
+            return XLMRobertaForMaskedLM(config)
         raise ValueError("Unrecognized configuration class {}".format(config))

     @classmethod
@@ -509,6 +514,8 @@ class AutoModelForSequenceClassification(object):
             return XLNetForSequenceClassification(config)
         elif isinstance(config, XLMConfig):
             return XLMForSequenceClassification(config)
+        elif isinstance(config, XLMRobertaConfig):
+            return XLMRobertaForSequenceClassification(config)
         raise ValueError("Unrecognized configuration class {}".format(config))

     @classmethod
@@ -787,6 +794,8 @@ class AutoModelForTokenClassification:
             return XLNetForTokenClassification(config)
         elif isinstance(config, RobertaConfig):
             return RobertaForTokenClassification(config)
+        elif isinstance(config, XLMRobertaConfig):
+            return XLMRobertaForTokenClassification(config)
         raise ValueError("Unrecognized configuration class {}".format(config))

     @classmethod
@@ -865,6 +874,8 @@ class AutoModelForTokenClassification:
             return CamembertForTokenClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
         elif 'distilbert' in pretrained_model_name_or_path:
             return DistilBertForTokenClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
+        elif 'xlm-roberta' in pretrained_model_name_or_path:
+            return XLMRobertaForTokenClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
         elif 'roberta' in pretrained_model_name_or_path:
             return RobertaForTokenClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
         elif 'bert' in pretrained_model_name_or_path:
@@ -873,4 +884,4 @@ class AutoModelForTokenClassification:
             return XLNetForTokenClassification.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
         raise ValueError("Unrecognized model identifier in {}. Should contains one of "
-                         "'bert', 'xlnet', 'camembert', 'distilbert', 'roberta'".format(pretrained_model_name_or_path))
+                         "'bert', 'xlnet', 'camembert', 'distilbert', 'xlm-roberta', 'roberta'".format(pretrained_model_name_or_path))
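Each Auto* class gains an XLMRobertaConfig branch, and the name-based dispatch checks 'xlm-roberta' before 'roberta' because the latter is a substring of the former. A minimal sketch of the resulting behavior (weights download on first use):

from transformers import AutoModel, AutoModelWithLMHead

# Both of these raised "Unrecognized configuration class" before this commit.
encoder = AutoModel.from_pretrained('xlm-roberta-base')        # -> XLMRobertaModel
mlm = AutoModelWithLMHead.from_pretrained('xlm-roberta-base')  # -> XLMRobertaForMaskedLM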
transformers/modeling_utils.py
transformers/pipelines.py

@@ -14,12 +14,14 @@
 # limitations under the License.
 from __future__ import absolute_import, division, print_function, unicode_literals

+import sys
 import csv
 import json
 import os
 import pickle
 import logging
+import six

 from abc import ABC, abstractmethod
 from contextlib import contextmanager
 from itertools import groupby
@@ -49,7 +51,7 @@ logger = logging.getLogger(__name__)
 def get_framework(model=None):
     """ Select framework (TensorFlow/PyTorch) to use.
-        If both frameworks are installed and no specific model is provided, defaults to using TensorFlow.
+        If both frameworks are installed and no specific model is provided, defaults to using PyTorch.
     """
     if is_tf_available() and is_torch_available() and model is not None and not isinstance(model, str):
         # Both framework are available but the use supplied a model class instance.
@@ -60,7 +62,8 @@ def get_framework(model=None):
                 "To install TensorFlow 2.0, read the instructions at https://www.tensorflow.org/install/ "
                 "To install PyTorch, read the instructions at https://pytorch.org/."
             )
     else:
-        framework = 'tf' if is_tf_available() else 'pt'
+        # framework = 'tf' if is_tf_available() else 'pt'
+        framework = 'pt' if is_torch_available() else 'tf'
     return framework


 class ArgumentHandler(ABC):
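The framework default flips here: with both backends installed and no model instance supplied, pipelines now prefer PyTorch over TensorFlow. A quick check (sketch; prints 'pt' when torch is installed):

from transformers.pipelines import get_framework

# Returns 'pt' when PyTorch is available, falling back to 'tf' otherwise.
print(get_framework())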
@@ -97,28 +100,29 @@ class PipelineDataFormat:
     Supported data formats currently includes:
      - JSON
      - CSV
      - stdin/stdout (pipe)

     PipelineDataFormat also includes some utilities to work with multi-columns like mapping from datasets columns
     to pipelines keyword arguments through the `dataset_kwarg_1=dataset_column_1` format.
     """
     SUPPORTED_FORMATS = ['json', 'csv', 'pipe']

-    def __init__(self, output: Optional[str], input: Optional[str], column: Optional[str]):
-        self.output = output
-        self.path = input
-        self.column = column.split(',') if column else ['']
+    def __init__(self, output_path: Optional[str], input_path: Optional[str], column: Optional[str]):
+        self.output_path = output_path
+        self.input_path = input_path
+        self.column = column.split(',') if column is not None else ['']
         self.is_multi_columns = len(self.column) > 1

         if self.is_multi_columns:
             self.column = [tuple(c.split('=')) if '=' in c else (c, c) for c in self.column]

-        if output is not None:
-            if exists(abspath(self.output)):
-                raise OSError('{} already exists on disk'.format(self.output))
+        if output_path is not None:
+            if exists(abspath(self.output_path)):
+                raise OSError('{} already exists on disk'.format(self.output_path))

-        if input is not None:
-            if not exists(abspath(self.path)):
-                raise OSError('{} doesnt exist on disk'.format(self.path))
+        if input_path is not None:
+            if not exists(abspath(self.input_path)):
+                raise OSError('{} doesnt exist on disk'.format(self.input_path))

     @abstractmethod
     def __iter__(self):
@@ -139,7 +143,7 @@ class PipelineDataFormat:
         :param data: data to store
         :return: (str) Path where the data has been saved
         """
-        path, _ = os.path.splitext(self.output)
+        path, _ = os.path.splitext(self.output_path)
         binary_path = os.path.extsep.join((path, 'pickle'))

         with open(binary_path, 'wb+') as f_output:
@@ -148,23 +152,23 @@ class PipelineDataFormat:
         return binary_path

     @staticmethod
-    def from_str(name: str, output: Optional[str], path: Optional[str], column: Optional[str]):
-        if name == 'json':
-            return JsonPipelineDataFormat(output, path, column)
-        elif name == 'csv':
-            return CsvPipelineDataFormat(output, path, column)
-        elif name == 'pipe':
-            return PipedPipelineDataFormat(output, path, column)
+    def from_str(format: str, output_path: Optional[str], input_path: Optional[str], column: Optional[str]):
+        if format == 'json':
+            return JsonPipelineDataFormat(output_path, input_path, column)
+        elif format == 'csv':
+            return CsvPipelineDataFormat(output_path, input_path, column)
+        elif format == 'pipe':
+            return PipedPipelineDataFormat(output_path, input_path, column)
         else:
-            raise KeyError('Unknown reader {} (Available reader are json/csv/pipe)'.format(name))
+            raise KeyError('Unknown reader {} (Available reader are json/csv/pipe)'.format(format))


 class CsvPipelineDataFormat(PipelineDataFormat):
-    def __init__(self, output: Optional[str], input: Optional[str], column: Optional[str]):
-        super().__init__(output, input, column)
+    def __init__(self, output_path: Optional[str], input_path: Optional[str], column: Optional[str]):
+        super().__init__(output_path, input_path, column)

     def __iter__(self):
-        with open(self.path, 'r') as f:
+        with open(self.input_path, 'r') as f:
             reader = csv.DictReader(f)
             for row in reader:
                 if self.is_multi_columns:
@@ -173,7 +177,7 @@ class CsvPipelineDataFormat(PipelineDataFormat):
                     yield row[self.column[0]]

     def save(self, data: List[dict]):
-        with open(self.output, 'w') as f:
+        with open(self.output_path, 'w') as f:
             if len(data) > 0:
                 writer = csv.DictWriter(f, list(data[0].keys()))
                 writer.writeheader()
@@ -181,10 +185,10 @@ class CsvPipelineDataFormat(PipelineDataFormat):
 class JsonPipelineDataFormat(PipelineDataFormat):
-    def __init__(self, output: Optional[str], input: Optional[str], column: Optional[str]):
-        super().__init__(output, input, column)
+    def __init__(self, output_path: Optional[str], input_path: Optional[str], column: Optional[str]):
+        super().__init__(output_path, input_path, column)

-        with open(input, 'r') as f:
+        with open(input_path, 'r') as f:
             self._entries = json.load(f)

     def __iter__(self):
@@ -195,7 +199,7 @@ class JsonPipelineDataFormat(PipelineDataFormat):
             yield entry[self.column[0]]

     def save(self, data: dict):
-        with open(self.output, 'w') as f:
+        with open(self.output_path, 'w') as f:
             json.dump(data, f)
@@ -207,9 +211,7 @@ class PipedPipelineDataFormat(PipelineDataFormat):
     If columns are provided, then the output will be a dictionary with {column_x: value_x}
     """
     def __iter__(self):
-        import sys
-
         for line in sys.stdin:
             # Split for multi-columns
             if '\t' in line:
@@ -228,7 +230,7 @@ class PipedPipelineDataFormat(PipelineDataFormat):
             print(data)

     def save_binary(self, data: Union[dict, List[dict]]) -> str:
-        if self.output is None:
+        if self.output_path is None:
             raise KeyError(
                 'When using piped input on pipeline outputting large object requires an output file path. '
                 'Please provide such output path through --output argument.'
@@ -293,6 +295,9 @@ class Pipeline(_ScikitCompat):
         nlp = NerPipeline(model='...', config='...', tokenizer='...')
         nlp = QuestionAnsweringPipeline(model=AutoModel.from_pretrained('...'), tokenizer='...')
     """
+
+    default_input_names = None
+
     def __init__(self, model, tokenizer: PreTrainedTokenizer = None,
                  modelcard: ModelCard = None, framework: Optional[str] = None,
                  args_parser: ArgumentHandler = None, device: int = -1,
@@ -581,6 +586,8 @@ class QuestionAnsweringPipeline(Pipeline):
     Question Answering pipeline using ModelForQuestionAnswering head.
     """

+    default_input_names = 'question,context'
+
     def __init__(self, model, tokenizer: Optional[PreTrainedTokenizer], modelcard: Optional[ModelCard],
@@ -683,7 +690,6 @@ class QuestionAnsweringPipeline(Pipeline):
             }
             for s, e, score in zip(starts, ends, scores)
         ]
-
         if len(answers) == 1:
             return answers[0]
         return answers
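PipelineDataFormat's constructor and from_str switch from the builtin-shadowing names input/output to explicit input_path/output_path keywords, and pipelines gain a default_input_names class attribute (e.g. 'question,context' for question answering) that the CLI uses when --column is omitted. A minimal sketch of the renamed API (file names are illustrative; the input file must exist and the output file must not):

from transformers.pipelines import PipelineDataFormat

# Build a CSV reader mapping the 'question' and 'context' columns to
# pipeline keyword arguments, as run_command_factory now does.
reader = PipelineDataFormat.from_str(format='csv',
                                     output_path='answers.csv',
                                     input_path='questions.csv',
                                     column='question,context')
for entry in reader:
    print(entry)  # a dict of column values when is_multi_columns is True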
transformers/tokenization_utils.py

@@ -434,7 +434,11 @@ class PreTrainedTokenizer(object):
             init_kwargs[key] = value

         # Instantiate tokenizer.
-        tokenizer = cls(*init_inputs, **init_kwargs)
+        try:
+            tokenizer = cls(*init_inputs, **init_kwargs)
+        except OSError:
+            OSError("Unable to load vocabulary from file. "
+                    "Please check that the provided vocabulary is accessible and not corrupted.")

         # Save inputs and kwargs for saving and re-loading with ``save_pretrained``
         tokenizer.init_inputs = init_inputs
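Note that, as committed, the except branch constructs the new OSError but never raises it, so a failed instantiation is silently swallowed and the next line trips a NameError on the undefined tokenizer; a raise was presumably intended. A sketch of that presumably intended form:

# Inside PreTrainedTokenizer.from_pretrained (sketch, not the committed code):
try:
    tokenizer = cls(*init_inputs, **init_kwargs)
except OSError:
    raise OSError("Unable to load vocabulary from file. "
                  "Please check that the provided vocabulary is accessible and not corrupted.")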
transformers/tokenization_xlm_roberta.py

@@ -40,8 +40,12 @@ PRETRAINED_VOCAB_FILES_MAP = {
 }

 PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
-    'xlm-roberta-base': None,
-    'xlm-roberta-large': None,
+    'xlm-roberta-base': 512,
+    'xlm-roberta-large': 512,
+    'xlm-roberta-large-finetuned-conll02-dutch': 512,
+    'xlm-roberta-large-finetuned-conll02-spanish': 512,
+    'xlm-roberta-large-finetuned-conll03-english': 512,
+    'xlm-roberta-large-finetuned-conll03-german': 512,
 }

 class XLMRobertaTokenizer(PreTrainedTokenizer):
@@ -58,7 +62,7 @@ class XLMRobertaTokenizer(PreTrainedTokenizer):
     def __init__(self, vocab_file, bos_token="<s>", eos_token="</s>", sep_token="</s>",
                  cls_token="<s>", unk_token="<unk>", pad_token='<pad>', mask_token='<mask>',
                  **kwargs):
-        super(XLMRobertaTokenizer, self).__init__(max_len=512, bos_token=bos_token, eos_token=eos_token, unk_token=unk_token,
+        super(XLMRobertaTokenizer, self).__init__(bos_token=bos_token, eos_token=eos_token, unk_token=unk_token,
                                                   sep_token=sep_token, cls_token=cls_token, pad_token=pad_token,
                                                   mask_token=mask_token, **kwargs)
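This is the max_length fix named in the commit message: the 512-token limit moves out of the hardcoded super().__init__(max_len=512, ...) call and into PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES, so from_pretrained picks it up per checkpoint, including the four CoNLL-finetuned ones. A quick check (sketch; downloads the sentencepiece model on first use):

from transformers import XLMRobertaTokenizer

tokenizer = XLMRobertaTokenizer.from_pretrained('xlm-roberta-base')
print(tokenizer.max_len)  # 512, now sourced from the pretrained sizes map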