chenpangpang / transformers · Commits · 32dbb2d9

make style (#11442)

Unverified commit 32dbb2d9, authored Apr 26, 2021 by Patrick von Platen, committed by GitHub Apr 26, 2021.
parent 04ab2ca6
Changes: 105 changed files in this commit. Showing 20 changed files with 36 additions and 36 deletions (+36 / -36) on this page (1 of 6).
src/transformers/models/xlm_roberta/tokenization_xlm_roberta.py   +1 -1
src/transformers/models/xlnet/tokenization_xlnet.py               +2 -2
src/transformers/testing_utils.py                                 +3 -3
src/transformers/tokenization_utils_base.py                       +1 -1
src/transformers/trainer.py                                       +1 -1
src/transformers/trainer_callback.py                              +6 -6
src/transformers/trainer_pt_utils.py                              +2 -2
src/transformers/trainer_utils.py                                 +6 -6
src/transformers/utils/versions.py                                +2 -2
tests/deepspeed/test_deepspeed.py                                 +2 -2
tests/test_modeling_common.py                                     +1 -1
tests/test_modeling_funnel.py                                     +1 -1
tests/test_modeling_layoutlm.py                                   +1 -1
tests/test_modeling_lxmert.py                                     +1 -1
tests/test_modeling_tapas.py                                      +1 -1
tests/test_modeling_tf_funnel.py                                  +1 -1
tests/test_tokenization_common.py                                 +1 -1
tests/test_tokenization_fsmt.py                                   +1 -1
tests/test_tokenization_layoutlm.py                               +1 -1
tests/test_tokenization_xlm.py                                    +1 -1
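Every hunk below follows the same pattern: stray spaces just inside the triple quotes of a single-line docstring are stripped. This matches the docstring normalization introduced in black 21.4b0 (released the same day as this commit), so `make style` presumably just re-ran an updated formatter; that attribution is an inference, not stated in the commit. A minimal sketch of the rule:

    # Before: stray padding just inside the quotes
    def convert(token):
        """ Converts a token (str) in an id using the vocab. """

    # After black >= 21.4b0 (assumed trigger for this commit)
    def convert(token):
        """Converts a token (str) in an id using the vocab."""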
src/transformers/models/xlm_roberta/tokenization_xlm_roberta.py

@@ -270,7 +270,7 @@ class XLMRobertaTokenizer(PreTrainedTokenizer):
         return self.sp_model.encode(text, out_type=str)
 
     def _convert_token_to_id(self, token):
-        """ Converts a token (str) in an id using the vocab. """
+        """Converts a token (str) in an id using the vocab."""
         if token in self.fairseq_tokens_to_ids:
             return self.fairseq_tokens_to_ids[token]
         spm_id = self.sp_model.PieceToId(token)
src/transformers/models/xlnet/tokenization_xlnet.py

@@ -189,7 +189,7 @@ class XLNetTokenizer(PreTrainedTokenizer):
         return outputs
 
     def _tokenize(self, text, sample=False):
-        """ Tokenize a string. """
+        """Tokenize a string."""
         text = self.preprocess_text(text)
 
         if not sample:

@@ -213,7 +213,7 @@ class XLNetTokenizer(PreTrainedTokenizer):
         return new_pieces
 
     def _convert_token_to_id(self, token):
-        """ Converts a token (str) in an id using the vocab. """
+        """Converts a token (str) in an id using the vocab."""
         return self.sp_model.PieceToId(token)
 
     def _convert_id_to_token(self, index):
src/transformers/testing_utils.py

@@ -389,7 +389,7 @@ if is_tf_available():
 
 def require_torch_gpu(test_case):
-    """Decorator marking a test that requires CUDA and PyTorch. """
+    """Decorator marking a test that requires CUDA and PyTorch."""
     if torch_device != "cuda":
         return unittest.skip("test requires CUDA")(test_case)
     else:

@@ -593,14 +593,14 @@ class CaptureStd:
 
 class CaptureStdout(CaptureStd):
-    """ Same as CaptureStd but captures only stdout """
+    """Same as CaptureStd but captures only stdout"""
 
     def __init__(self):
        super().__init__(err=False)
 
 
 class CaptureStderr(CaptureStd):
-    """ Same as CaptureStd but captures only stderr """
+    """Same as CaptureStd but captures only stderr"""
 
     def __init__(self):
         super().__init__(out=False)
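For reference, these capture helpers are used as context managers; a minimal usage sketch, assuming the `out` attribute that CaptureStd exposes for captured output:

    from transformers.testing_utils import CaptureStdout

    with CaptureStdout() as cs:
        print("hello")
    assert "hello" in cs.out  # stdout was diverted into the capture object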
src/transformers/tokenization_utils_base.py

@@ -88,7 +88,7 @@ else:
 
     @dataclass
     class EncodingFast:
-        """ This is dummy class because without the `tokenizers` library we don't have these objects anyway """
+        """This is dummy class because without the `tokenizers` library we don't have these objects anyway"""
 
         pass
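The `else:` in the hunk header suggests this dummy dataclass is the fallback branch of an import guard; roughly as follows (an assumption, since only the else-branch is visible in the diff):

    from dataclasses import dataclass

    from transformers.file_utils import is_tokenizers_available

    if is_tokenizers_available():
        from tokenizers import Encoding as EncodingFast  # the real fast-tokenizer object
    else:

        @dataclass
        class EncodingFast:
            """This is dummy class because without the `tokenizers` library we don't have these objects anyway"""

            pass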
src/transformers/trainer.py

@@ -805,7 +805,7 @@ class Trainer:
         return len(dataloader.dataset)
 
     def _hp_search_setup(self, trial: Union["optuna.Trial", Dict[str, Any]]):
-        """ HP search setup code """
+        """HP search setup code"""
         self._trial = trial
 
         if self.hp_search_backend is None or trial is None:
src/transformers/trainer_callback.py

@@ -92,14 +92,14 @@ class TrainerState:
             self.log_history = []
 
     def save_to_json(self, json_path: str):
-        """ Save the content of this instance in JSON format inside :obj:`json_path`."""
+        """Save the content of this instance in JSON format inside :obj:`json_path`."""
         json_string = json.dumps(dataclasses.asdict(self), indent=2, sort_keys=True) + "\n"
         with open(json_path, "w", encoding="utf-8") as f:
             f.write(json_string)
 
     @classmethod
     def load_from_json(cls, json_path: str):
-        """ Create an instance from the content of :obj:`json_path`."""
+        """Create an instance from the content of :obj:`json_path`."""
         with open(json_path, "r", encoding="utf-8") as f:
             text = f.read()
         return cls(**json.loads(text))

@@ -141,15 +141,15 @@ class TrainerControl:
     should_log: bool = False
 
     def _new_training(self):
-        """ Internal method that resets the variable for a new training. """
+        """Internal method that resets the variable for a new training."""
         self.should_training_stop = False
 
     def _new_epoch(self):
-        """ Internal method that resets the variable for a new epoch. """
+        """Internal method that resets the variable for a new epoch."""
         self.should_epoch_stop = False
 
     def _new_step(self):
-        """ Internal method that resets the variable for a new step. """
+        """Internal method that resets the variable for a new step."""
         self.should_save = False
         self.should_evaluate = False
         self.should_log = False

@@ -275,7 +275,7 @@ class TrainerCallback:
 
 class CallbackHandler(TrainerCallback):
-    """ Internal class that just calls the list of callbacks in order. """
+    """Internal class that just calls the list of callbacks in order."""
 
     def __init__(self, callbacks, model, tokenizer, optimizer, lr_scheduler):
         self.callbacks = []
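The save/load pair visible above round-trips the dataclass through JSON; a short sketch (TrainerState is exported at the top level of transformers, to the best of my knowledge):

    from transformers import TrainerState

    state = TrainerState()
    state.save_to_json("state.json")  # json.dumps(..., indent=2, sort_keys=True) + "\n", per the diff
    restored = TrainerState.load_from_json("state.json")
    assert restored == state  # dataclass equality holds after the round trip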
src/transformers/trainer_pt_utils.py

@@ -294,14 +294,14 @@ def get_tpu_sampler(dataset: torch.utils.data.dataset.Dataset, bach_size: int):
 
 def nested_new_like(arrays, num_samples, padding_index=-100):
-    """ Create the same nested structure as `arrays` with a first dimension always at `num_samples`."""
+    """Create the same nested structure as `arrays` with a first dimension always at `num_samples`."""
     if isinstance(arrays, (list, tuple)):
         return type(arrays)(nested_new_like(x, num_samples) for x in arrays)
     return np.full_like(arrays, padding_index, shape=(num_samples, *arrays.shape[1:]))
 
 
 def expand_like(arrays, new_seq_length, padding_index=-100):
-    """ Expand the `arrays` so that the second dimension grows to `new_seq_length`. Uses `padding_index` for padding."""
+    """Expand the `arrays` so that the second dimension grows to `new_seq_length`. Uses `padding_index` for padding."""
     result = np.full_like(arrays, padding_index, shape=(arrays.shape[0], new_seq_length) + arrays.shape[2:])
     result[:, : arrays.shape[1]] = arrays
     return result
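The two helpers above build padded numpy buffers; their visible bodies amount to the following (a standalone demo of the same calls, not the library code itself):

    import numpy as np

    arr = np.arange(6).reshape(2, 3)

    # nested_new_like: same structure, first dimension grown to num_samples, filled with -100
    grown = np.full_like(arr, -100, shape=(4, *arr.shape[1:]))
    assert grown.shape == (4, 3)

    # expand_like: second dimension grown to new_seq_length, original values preserved
    wider = np.full_like(arr, -100, shape=(arr.shape[0], 5) + arr.shape[2:])
    wider[:, : arr.shape[1]] = arr
    assert wider.shape == (2, 5) and (wider[:, :3] == arr).all()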
src/transformers/trainer_utils.py

@@ -320,7 +320,7 @@ class TrainerMemoryTracker:
         self.init_reported = False
 
     def derive_stage(self):
-        """ derives the stage/caller name automatically """
+        """derives the stage/caller name automatically"""
         caller = inspect.currentframe().f_back.f_back.f_code.co_name
         if caller in self.stages:
             return self.stages[caller]

@@ -330,7 +330,7 @@ class TrainerMemoryTracker:
         )
 
     def cpu_mem_used(self):
-        """ get resident set size memory for the current process """
+        """get resident set size memory for the current process"""
         return self.process.memory_info().rss
 
     def peak_monitor_func(self):

@@ -346,7 +346,7 @@ class TrainerMemoryTracker:
                 break
 
     def start(self):
-        """ start tracking for the caller's stage """
+        """start tracking for the caller's stage"""
         if self.skip_memory_metrics:
             return
 

@@ -376,7 +376,7 @@ class TrainerMemoryTracker:
         peak_monitor_thread.start()
 
     def stop(self, stage):
-        """ stop tracking for the passed stage """
+        """stop tracking for the passed stage"""
         # deal with nested calls of eval during train - simply ignore those
         if self.cur_stage is not None and self.cur_stage != stage:
             return

@@ -416,7 +416,7 @@ class TrainerMemoryTracker:
         self.cur_stage = None
 
     def update_metrics(self, stage, metrics):
-        """ stop tracking for the passed stage """
+        """stop tracking for the passed stage"""
         if self.skip_memory_metrics:
             return
 

@@ -438,7 +438,7 @@ class TrainerMemoryTracker:
             metrics[f"{stage}_mem_gpu_{t}_delta"] = self.gpu[stage][t]
 
     def stop_and_update_metrics(self, metrics=None):
-        """ combine stop + update in one call for simpler code """
+        """combine stop + update in one call for simpler code"""
         if self.skip_memory_metrics:
             return
 
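Taken together, these docstrings describe the tracker's lifecycle. A hypothetical usage sketch (the constructor signature is assumed; it is not shown in this diff):

    from transformers.trainer_utils import TrainerMemoryTracker

    tracker = TrainerMemoryTracker(skip_memory_metrics=False)  # assumed signature
    tracker.start()  # stage name is derived from the calling function via derive_stage()
    metrics = {}
    # ... do the work being measured ...
    tracker.stop_and_update_metrics(metrics)  # "combine stop + update in one call"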
src/transformers/utils/versions.py

@@ -115,12 +115,12 @@ def require_version(requirement: str, hint: Optional[str] = None) -> None:
 
 def require_version_core(requirement):
-    """ require_version wrapper which emits a core-specific hint on failure """
+    """require_version wrapper which emits a core-specific hint on failure"""
     hint = "Try: pip install transformers -U or pip install -e '.[dev]' if you're working with git master"
     return require_version(requirement, hint)
 
 
 def require_version_examples(requirement):
-    """ require_version wrapper which emits examples-specific hint on failure """
+    """require_version wrapper which emits examples-specific hint on failure"""
     hint = "Try: pip install -r examples/requirements.txt"
     return require_version(requirement, hint)
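Both wrappers delegate to require_version with a tailored hint; a short sketch (the version specs below are illustrative, and I assume require_version raises when the installed version does not satisfy the spec):

    from transformers.utils.versions import require_version, require_version_core

    require_version("numpy>=1.17")  # passes silently, or raises with the optional hint appended
    require_version_core("tokenizers>=0.10.1")  # same check; failure adds the core-install hint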
tests/deepspeed/test_deepspeed.py

@@ -122,7 +122,7 @@ class TrainerIntegrationDeepSpeed(TestCasePlus, TrainerIntegrationCommon):
         transformers.integrations._is_deepspeed_zero3_enabled = None
 
     def get_config_dict(self, stage):
-        """ As the tests modify the dict, always make a copy """
+        """As the tests modify the dict, always make a copy"""
         config = deepcopy(self.ds_config_dict[stage])
         if stage == ZERO3:
             # This setting slows things down, so don't enable it by default unless needed by a test.

@@ -430,7 +430,7 @@ class TrainerIntegrationDeepSpeed(TestCasePlus, TrainerIntegrationCommon):
 @require_deepspeed
 @require_torch_gpu
 class TestDeepSpeedWithLauncher(TestCasePlus):
-    """ This class is for testing via an external script - can do multiple gpus """
+    """This class is for testing via an external script - can do multiple gpus"""
 
     # Tests to devise #
     #
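The deepcopy in get_config_dict matters because nested dicts are shared by reference; a minimal illustration of why the docstring insists on a copy:

    from copy import deepcopy

    template = {"zero_optimization": {"stage": 3}}
    config = deepcopy(template)
    config["zero_optimization"]["stage"] = 2  # mutate the copy...
    assert template["zero_optimization"]["stage"] == 3  # ...the shared template is untouched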
tests/test_modeling_common.py

@@ -1122,7 +1122,7 @@ class ModelTesterMixin:
 
 # a candidate for testing_utils
 def get_current_gpu_memory_use():
-    """ returns a list of cuda memory allocations per GPU in MBs"""
+    """returns a list of cuda memory allocations per GPU in MBs"""
     per_device_memory = []
     for id in range(torch.cuda.device_count()):
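Only the loop header of get_current_gpu_memory_use is visible in this hunk; a plausible completion of the same idea (the body below is an assumption, not taken from the diff):

    import torch

    def current_gpu_memory_use_mb():
        """Returns a list of CUDA memory allocations per GPU in MBs (sketch)."""
        return [
            torch.cuda.memory_allocated(device_id) >> 20  # bytes -> MiB
            for device_id in range(torch.cuda.device_count())
        ]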
tests/test_modeling_funnel.py

@@ -42,7 +42,7 @@ if is_torch_available():
 
 class FunnelModelTester:
-    """You can also import this e.g, from .test_modeling_funnel import FunnelModelTester """
+    """You can also import this e.g, from .test_modeling_funnel import FunnelModelTester"""
 
     def __init__(
         self,
tests/test_modeling_layoutlm.py

@@ -36,7 +36,7 @@ if is_torch_available():
 
 class LayoutLMModelTester:
-    """You can also import this e.g from .test_modeling_layoutlm import LayoutLMModelTester """
+    """You can also import this e.g from .test_modeling_layoutlm import LayoutLMModelTester"""
 
     def __init__(
         self,
tests/test_modeling_lxmert.py

@@ -40,7 +40,7 @@ if is_torch_available():
 
 class LxmertModelTester:
-    """You can also import this e.g from .test_modeling_bart import BartModelTester """
+    """You can also import this e.g from .test_modeling_bart import BartModelTester"""
 
     def __init__(
         self,
tests/test_modeling_tapas.py

@@ -63,7 +63,7 @@ if is_torch_available():
 
 class TapasModelTester:
-    """You can also import this e.g from .test_modeling_tapas import TapasModelTester """
+    """You can also import this e.g from .test_modeling_tapas import TapasModelTester"""
 
     def __init__(
         self,
tests/test_modeling_tf_funnel.py

@@ -39,7 +39,7 @@ if is_tf_available():
 
 class TFFunnelModelTester:
-    """You can also import this e.g, from .test_modeling_funnel import FunnelModelTester """
+    """You can also import this e.g, from .test_modeling_funnel import FunnelModelTester"""
 
     def __init__(
         self,
tests/test_tokenization_common.py

@@ -58,7 +58,7 @@ NON_ENGLISH_TAGS = ["chinese", "dutch", "french", "finnish", "german", "multilin
 
 def filter_non_english(_, pretrained_name: str):
-    """ Filter all the model for non-english language """
+    """Filter all the model for non-english language"""
     return not any([lang in pretrained_name for lang in NON_ENGLISH_TAGS])
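The predicate above is easy to check in isolation; a self-contained demo using the tag list from the hunk header (the last entry is truncated there, so "multilingual" is an assumption):

    # "multilingual" completes the truncated constant and is an assumption
    NON_ENGLISH_TAGS = ["chinese", "dutch", "french", "finnish", "german", "multilingual"]

    def filter_non_english(_, pretrained_name: str):
        return not any(lang in pretrained_name for lang in NON_ENGLISH_TAGS)

    assert filter_non_english(None, "bert-base-uncased")
    assert not filter_non_english(None, "bert-base-german-cased")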
tests/test_tokenization_fsmt.py

@@ -100,7 +100,7 @@ class FSMTTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
         self.assertEqual(tokenizer.tgt_vocab_size, 21)
 
     def test_full_tokenizer(self):
-        """ Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt """
+        """Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt"""
         tokenizer = FSMTTokenizer(self.langs, self.src_vocab_file, self.tgt_vocab_file, self.merges_file)
 
         text = "lower"
tests/test_tokenization_layoutlm.py

@@ -70,5 +70,5 @@ class LayoutLMTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
         self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9])
 
     def test_special_tokens_as_you_expect(self):
-        """If you are training a seq2seq model that expects a decoder_prefix token make sure it is prepended to decoder_input_ids """
+        """If you are training a seq2seq model that expects a decoder_prefix token make sure it is prepended to decoder_input_ids"""
         pass
tests/test_tokenization_xlm.py

@@ -72,7 +72,7 @@ class XLMTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
         return input_text, output_text
 
     def test_full_tokenizer(self):
-        """ Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt """
+        """Adapted from Sennrich et al. 2015 and https://github.com/rsennrich/subword-nmt"""
        tokenizer = XLMTokenizer(self.vocab_file, self.merges_file)
 
         text = "lower"