chenpangpang / transformers · commit f5c2a122 (unverified)

Upgrade examples to pl=0.8.1 (#5146)

Authored Jun 22, 2020 by Sam Shleifer, committed via GitHub on Jun 22, 2020.
Parent: 06b60c8b

Showing 11 changed files with 53 additions and 150 deletions (+53 -150).
Changed files:

  examples/lightning_base.py                               +22  -41
  examples/requirements.txt                                  +1   -1
  examples/summarization/callbacks.py                        +1   -2
  examples/summarization/distillation.py                     +1   -0
  examples/summarization/finetune.py                         +6  -23
  examples/summarization/run_distiller.sh                    +0   -1
  examples/summarization/run_eval.py                         +1   -0
  examples/summarization/test_summarization_examples.py     +16  -73
  examples/summarization/utils.py                            +2   -0
  examples/text-classification/run_pl_glue.py                +2   -8
  src/transformers/tokenization_auto.py                      +1   -1
examples/lightning_base.py  (+22 -41)

```diff
@@ -8,6 +8,7 @@ from typing import Any, Dict
 import numpy as np
 import pytorch_lightning as pl
 import torch
+from pytorch_lightning.utilities import rank_zero_info, rank_zero_only

 from transformers import (
     AdamW,
@@ -60,10 +61,9 @@ class BaseTransformer(pl.LightningModule):
         model=None,
         **config_kwargs
     ):
-        "Initialize a model."
+        """Initialize a model, tokenizer and config."""
         super().__init__()
-        self.hparams = hparams
+        self.hparams = hparams  # TODO: move to self.save_hyperparameters()
         self.step_count = 0
         self.tfmr_ckpts = {}
         self.output_dir = Path(self.hparams.output_dir)
@@ -84,8 +84,8 @@ class BaseTransformer(pl.LightningModule):
             )
         else:
             self.tokenizer: PreTrainedTokenizer = tokenizer
-        self.model_type = MODEL_MODES[mode]
         if model is None:
+            self.model_type = MODEL_MODES[mode]
             self.model = self.model_type.from_pretrained(
                 self.hparams.model_name_or_path,
                 from_tf=bool(".ckpt" in self.hparams.model_name_or_path),
@@ -93,18 +93,13 @@ class BaseTransformer(pl.LightningModule):
                 cache_dir=cache_dir,
             )
         else:
+            self.model_type = None
             self.model = model

     def load_hf_checkpoint(self, *args, **kwargs):
         self.model = self.model_type.from_pretrained(*args, **kwargs)

-    def is_logger(self):
-        return self.trainer.proc_rank <= 0
-
     def configure_optimizers(self):
         "Prepare optimizer and schedule (linear warmup and decay)"
         model = self.model
         no_decay = ["bias", "LayerNorm.weight"]
         optimizer_grouped_parameters = [
@@ -121,23 +116,10 @@ class BaseTransformer(pl.LightningModule):
         self.opt = optimizer
         return [optimizer]

-    def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, second_order_closure=None):
-        if self.trainer.use_tpu:
-            xm.optimizer_step(optimizer)
-        else:
-            optimizer.step()
-        optimizer.zero_grad()
-        self.lr_scheduler.step()
-
-    def get_tqdm_dict(self):
-        avg_loss = getattr(self.trainer, "avg_loss", 0.0)
-        tqdm_dict = {"loss": "{:.3f}".format(avg_loss), "lr": self.lr_scheduler.get_last_lr()[-1]}
-        return tqdm_dict
-
     def test_step(self, batch, batch_nb):
         return self.validation_step(batch, batch_nb)

-    def test_end(self, outputs):
+    def test_epoch_end(self, outputs):
         return self.validation_end(outputs)

     def train_dataloader(self):
@@ -208,6 +190,7 @@ class BaseTransformer(pl.LightningModule):
         parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.")
         parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
         parser.add_argument("--warmup_steps", default=500, type=int, help="Linear warmup over warmup_steps.")
+        parser.add_argument("--num_workers", default=4, type=int, help="kwarg passed to DataLoader")
         parser.add_argument("--num_train_epochs", default=3, type=int, help="Total number of training epochs to perform.")
@@ -217,28 +200,26 @@ class BaseTransformer(pl.LightningModule):
 class LoggingCallback(pl.Callback):
-    def on_validation_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
-        logger.info("***** Validation results *****")
-        if pl_module.is_logger():
-            metrics = trainer.callback_metrics
-            # Log results
-            for key in sorted(metrics):
-                if key not in ["log", "progress_bar"]:
-                    logger.info("{} = {}\n".format(key, str(metrics[key])))
+    @rank_zero_only
+    def on_validation_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
+        rank_zero_info("***** Validation results *****")
+        metrics = trainer.callback_metrics
+        # Log results
+        for key in sorted(metrics):
+            if key not in ["log", "progress_bar"]:
+                rank_zero_info("{} = {}\n".format(key, str(metrics[key])))

-    def on_test_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
-        logger.info("***** Test results *****")
-        if pl_module.is_logger():
-            metrics = trainer.callback_metrics
-            # Log and save results to file
-            output_test_results_file = os.path.join(pl_module.hparams.output_dir, "test_results.txt")
-            with open(output_test_results_file, "w") as writer:
-                for key in sorted(metrics):
-                    if key not in ["log", "progress_bar"]:
-                        logger.info("{} = {}\n".format(key, str(metrics[key])))
-                        writer.write("{} = {}\n".format(key, str(metrics[key])))
+    @rank_zero_only
+    def on_test_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
+        logger.info("***** Test results *****")
+        metrics = trainer.callback_metrics
+        # Log and save results to file
+        output_test_results_file = os.path.join(pl_module.hparams.output_dir, "test_results.txt")
+        with open(output_test_results_file, "w") as writer:
+            for key in sorted(metrics):
+                if key not in ["log", "progress_bar"]:
+                    logger.info("{} = {}\n".format(key, str(metrics[key])))
+                    writer.write("{} = {}\n".format(key, str(metrics[key])))


 def add_generic_args(parser, root_dir) -> None:
```
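The lightning_base.py changes track the pytorch-lightning 0.8 API: `test_end` becomes `test_epoch_end`, the manual `optimizer_step`/`get_tqdm_dict` overrides go away because the 0.8 trainer drives the optimizer and progress bar itself, and per-rank logging moves to `rank_zero_only`/`rank_zero_info`. Below is a minimal, self-contained sketch of that hook layout, assuming `pytorch-lightning==0.8.1` (the version this commit pins); the toy module, data, and hyperparameters are made up and are not part of the repository.

```python
# Minimal sketch of the pl>=0.8 hook layout used above.
# Assumes pytorch-lightning==0.8.1; hook names differ in much later releases.
import torch
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl


class ToyModule(pl.LightningModule):  # hypothetical module, not the repo's BaseTransformer
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return {"loss": torch.nn.functional.mse_loss(self(x), y)}

    def validation_step(self, batch, batch_idx):
        x, y = batch
        return {"val_loss": torch.nn.functional.mse_loss(self(x), y)}

    # pl 0.8 renamed validation_end/test_end to *_epoch_end, mirroring the diff above.
    def validation_epoch_end(self, outputs):
        avg = torch.stack([o["val_loss"] for o in outputs]).mean()
        return {"val_loss": avg, "log": {"val_loss": avg}}

    def test_step(self, batch, batch_idx):
        return self.validation_step(batch, batch_idx)

    def test_epoch_end(self, outputs):
        return self.validation_epoch_end(outputs)

    # No optimizer_step override: the trainer now calls step()/zero_grad() itself.
    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

    def train_dataloader(self):
        x, y = torch.randn(32, 4), torch.randn(32, 1)
        return DataLoader(TensorDataset(x, y), batch_size=8)

    def val_dataloader(self):
        x, y = torch.randn(16, 4), torch.randn(16, 1)
        return DataLoader(TensorDataset(x, y), batch_size=8)


if __name__ == "__main__":
    trainer = pl.Trainer(max_epochs=1, logger=False, checkpoint_callback=False)
    trainer.fit(ToyModule())
```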
examples/requirements.txt  (+1 -1)

```diff
@@ -5,7 +5,7 @@ psutil
 sacrebleu
 rouge-score
 tensorflow_datasets
-pytorch-lightning==0.7.6
+pytorch-lightning==0.8.1
 matplotlib
 git-python==1.0.3
 faiss
```
examples/summarization/callbacks.py  (+1 -2)

```diff
@@ -19,12 +19,11 @@ logger = logging.getLogger(__name__)
 class Seq2SeqLoggingCallback(pl.Callback):
+    @rank_zero_only
     def _write_logs(
         self, trainer: pl.Trainer, pl_module: pl.LightningModule, type_path: str, save_generations=True
     ) -> None:
         logger.info(f"***** {type_path} results at step {trainer.global_step:05d} *****")
-        if not pl_module.is_logger():
-            return
         metrics = trainer.callback_metrics
         trainer.logger.log_metrics({k: v for k, v in metrics.items() if k not in ["log", "progress_bar", "preds"]})
         # Log results
```
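Dropping the `if not pl_module.is_logger(): return` guard in favor of `@rank_zero_only` delegates the "only run on the main process" check to pytorch-lightning's own utility. A small hedged sketch of a callback guarded that way, again assuming `pytorch-lightning==0.8.1`; the callback name and logger setup are illustrative and not part of the examples code:

```python
# Minimal sketch of a rank-zero-guarded callback (assumes pytorch-lightning==0.8.1).
import logging

import pytorch_lightning as pl
from pytorch_lightning.utilities import rank_zero_info, rank_zero_only

logger = logging.getLogger(__name__)


class MetricsPrinter(pl.Callback):  # hypothetical callback, not part of examples/
    @rank_zero_only  # the decorated method only executes on the rank-0 process
    def on_validation_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
        rank_zero_info("***** Validation results *****")
        for key, value in sorted(trainer.callback_metrics.items()):
            if key not in ("log", "progress_bar"):
                logger.info("%s = %s", key, value)
```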
examples/summarization/distillation.py  (+1 -0)

```diff
@@ -271,6 +271,7 @@ class SummarizationDistiller(SummarizationModule):
 class T5SummarizationDistiller(SummarizationDistiller):
     def pre_init(self, hparams):
+        raise NotImplementedError("T5 Distillation does not work yet")
         teacher = T5ForConditionalGeneration.from_pretrained(hparams.teacher)
         n_layer = hparams.student_decoder_layers
         assert n_layer == hparams.student_encoder_layers  # TODO(SS): relax this
```
examples/summarization/finetune.py  (+6 -23)

```diff
@@ -85,7 +85,7 @@ class SummarizationModule(BaseTransformer):
         if self.hparams.freeze_encoder:
             freeze_params(self.model.model.encoder)  # TODO: this will break for t5
         self.hparams.git_sha = get_git_info()["repo_sha"]
-        self.num_workers = 4 if self.hparams.gpus <= 1 else None  # passing num_workers breaks lightning for multigpu
+        self.num_workers = hparams.num_workers

     def freeze_embeds(self):
         """Freeze token embeddings and positional embeddings for bart, just token embeddings for t5."""
@@ -126,7 +126,7 @@ class SummarizationModule(BaseTransformer):
     def validation_step(self, batch, batch_idx) -> Dict:
         return self._generative_step(batch)

-    def validation_end(self, outputs, prefix="val") -> Dict:
+    def validation_epoch_end(self, outputs, prefix="val") -> Dict:
         self.step_count += 1
         losses = {k: torch.stack([x[k] for x in outputs]).mean() for k in self.loss_names}
         loss = losses["loss"]
@@ -144,14 +144,12 @@ class SummarizationModule(BaseTransformer):
         self.metrics[prefix].append(metrics)
         pickle_save(self.metrics, self.metrics_save_path)

-    def _generative_step(self, batch):
+    def _generative_step(self, batch: dict) -> dict:
         pad_token_id = self.tokenizer.pad_token_id
         source_ids, source_mask, y = SummarizationDataset.trim_seq2seq_batch(batch, pad_token_id)
         # TODO(SS): task specific params
         t0 = time.time()
         generated_ids = self.model.generate(input_ids=source_ids, attention_mask=source_mask, use_cache=True,)
-        gen_time = time.time() - t0
+        gen_time = time.time() - t0 / source_ids.shape[0]
         preds = self.ids_to_clean_text(generated_ids)
         target = self.ids_to_clean_text(y)
         loss_tensors = self._step(batch)
@@ -164,24 +162,8 @@ class SummarizationModule(BaseTransformer):
     def test_step(self, batch, batch_idx):
         return self._generative_step(batch)

-    def test_end(self, outputs):
-        return self.validation_end(outputs, prefix="test")
-
-    def test_epoch_end(self, outputs):
-        output_test_predictions_file = os.path.join(self.hparams.output_dir, "test_predictions.txt")
-        output_test_targets_file = os.path.join(self.hparams.output_dir, "test_targets.txt")
-        # write predictions and targets for later rouge evaluation.
-        with open(output_test_predictions_file, "w+") as p_writer, open(output_test_targets_file, "w+") as t_writer:
-            for output_batch in outputs:
-                p_writer.writelines(s + "\n" for s in output_batch["preds"])
-                t_writer.writelines(s + "\n" for s in output_batch["target"])
-            p_writer.close()
-            t_writer.close()
-        return self.test_end(outputs)
-
-    def validation_epoch_end(self, outputs):
-        self.validation_end(outputs, "val")
+    def test_epoch_end(self, outputs):
+        return self.validation_epoch_end(outputs, prefix="test")

     def get_dataset(self, type_path) -> SummarizationDataset:
         n_obs = self.n_obs[type_path]
@@ -310,6 +292,7 @@ def main(args, model=None) -> SummarizationModule:
         logger=logger,
+        # TODO: early stopping callback seems messed up
     )
     pickle_save(model.hparams, model.output_dir / "hparams.pkl")
     if not args.do_predict:
         return model
```
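In the renamed `validation_epoch_end`, `outputs` is the list of dicts returned by each `validation_step`/`_generative_step`, and the per-key `torch.stack(...).mean()` collapses them into epoch-level values. A standalone sketch with toy values follows; the keys, numbers, and batch size are made up, and the last lines only illustrate computing a per-example generation time with explicit parentheses, not the repository's exact expression.

```python
# Toy illustration of the epoch-end aggregation pattern used in validation_epoch_end above.
import time

import torch

loss_names = ["loss", "ce_loss"]  # illustrative keys
outputs = [  # what pl passes to *_epoch_end: one dict per validation batch
    {"loss": torch.tensor(2.0), "ce_loss": torch.tensor(1.5), "gen_time": 0.12},
    {"loss": torch.tensor(1.0), "ce_loss": torch.tensor(0.5), "gen_time": 0.10},
]
losses = {k: torch.stack([x[k] for x in outputs]).mean() for k in loss_names}
print(losses)  # {'loss': tensor(1.5000), 'ce_loss': tensor(1.)}

# Per-example generation latency, written with explicit parentheses (batch of 8 is made up):
t0 = time.time()
batch_size = 8
# ... model.generate(...) would run here ...
gen_time_per_example = (time.time() - t0) / batch_size
```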
examples/summarization/run_distiller.sh  (+0 -1)

```diff
@@ -7,6 +7,5 @@ python distillation.py \
 --learning_rate=3e-4 \
 --do_train \
 --do_predict \
---fp16 \
 --val_check_interval 0.1 \
 $@
```
examples/summarization/run_eval.py  (+1 -0)

```diff
@@ -26,6 +26,7 @@ def generate_summaries(
     examples: list, out_file: str, model_name: str, batch_size: int = 8, device: str = DEFAULT_DEVICE, fp16=False,
 ) -> None:
     fout = Path(out_file).open("w", encoding="utf-8")
+    model_name = str(model_name)
     model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)
     if fp16:
         model = model.half()
```
examples/summarization/test_summarization_examples.py  (+16 -73)

```diff
@@ -24,6 +24,7 @@ logger = logging.getLogger()
 FP16_EVER = False
 CHEAP_ARGS = {
     "logger": "default",
+    "num_workers": 2,
     "alpha_hid": 0,
     "freeze_embeds": True,
     "enc_only": False,
@@ -79,7 +80,8 @@ def _dump_articles(path: Path, articles: list):
         f.write("\n".join(articles))


 BDIR = Path("~/transformers_fork/examples/summarization/bart/").absolute()
 MSG = "T5 is broken at the moment"
+T5_TINY = "patrickvonplaten/t5-tiny-random"


 def make_test_data_dir():
@@ -92,7 +94,6 @@ def make_test_data_dir():
     return tmp_dir


-@unittest.skip("These wont' pass until hidden_states kwarg is merged.")
 class TestSummarizationDistiller(unittest.TestCase):
     @classmethod
     def setUpClass(cls):
@@ -108,47 +109,22 @@ class TestSummarizationDistiller(unittest.TestCase):
             freeze_encoder=True,
             gpus=2,
             sortish_sampler=False,
         )
         self._bart_distiller_cli(updates)

-    @unittest.skipUnless(torch.cuda.is_available(), "skipping fp16 test")
-    def test_bdc_fp16(self):
-        updates = dict(
-            student_encoder_layers=2,
-            student_decoder_layers=1,
-            alpha_hid=3.0,
-            freeze_encoder=True,
-            gpus=1,
-            fp16=FP16_EVER,
-            fp16_opt_level="O1",
-        )
-        self._bart_distiller_cli(updates)
-
-    @unittest.skipUnless(torch.cuda.is_available(), "skipping fp16 test")
-    def test_bdc_t5_eval_fp16(self):
-        updates = dict(
-            fp16=FP16_EVER,
-            gpus=1,
-            model_type="t5",
-            model_name_or_path="patrickvonplaten/t5-tiny-random",
-            do_train=False,
-            do_predict=True,
-            tokenizer_name=None,
-            no_teacher=True,
-        )
-        self._bart_distiller_cli(updates, check_contents=False)
-
-    @unittest.skipUnless(torch.cuda.is_available(), "skipping fp16 test")
-    def test_bdc_t5_train_fp16(self):
+    def test_bdc_t5_train(self):
         updates = dict(
-            fp16=FP16_EVER,
-            gpus=1,
+            gpus=1 if torch.cuda.is_available() else 0,
             model_type="t5",
-            model_name_or_path="patrickvonplaten/t5-tiny-random",
+            model_name_or_path=T5_TINY,
             do_train=True,
             do_predict=True,
-            tokenizer_name="patrickvonplaten/t5-tiny-random",
+            tokenizer_name=T5_TINY,
             no_teacher=True,
             alpha_hid=2.0,
         )
         self._bart_distiller_cli(updates)
@@ -161,7 +137,6 @@ class TestSummarizationDistiller(unittest.TestCase):
         self._bart_distiller_cli(updates)

     def test_bdc_checkpointing(self):
         updates = dict(
             student_encoder_layers=2,
             student_decoder_layers=1,
@@ -184,32 +159,8 @@ class TestSummarizationDistiller(unittest.TestCase):
         evaluate_checkpoint(ckpts[0], dest_dir=Path(tempfile.mkdtemp()))

-    def test_bdc_t5(self):
-        updates = dict(
-            student_encoder_layers=1,
-            student_decoder_layers=1,
-            alpha_hid=2.0,
-            teacher="patrickvonplaten/t5-tiny-random",
-            model_type="t5",
-            model_name_or_path="patrickvonplaten/t5-tiny-random",
-            tokenizer_name="patrickvonplaten/t5-tiny-random",
-        )
-        self._bart_distiller_cli(updates)
-
-    def test_bdc_t5_eval(self):
-        updates = dict(
-            model_type="t5",
-            model_name_or_path="patrickvonplaten/t5-tiny-random",
-            do_train=False,
-            do_predict=True,
-            tokenizer_name="patrickvonplaten/t5-tiny-random",
-            no_teacher=True,
-        )
-        self._bart_distiller_cli(updates, check_contents=False)
-
     def _bart_distiller_cli(self, updates, check_contents=True):
         default_updates = dict(
             model_type="bart",
             train_batch_size=1,
             eval_batch_size=2,
             num_train_epochs=2,
@@ -237,21 +188,14 @@ class TestSummarizationDistiller(unittest.TestCase):
         self.assertIn(ckpt_name, contents)
         self.assertIn("metrics.pkl", contents)
         self.assertIn("test_generations.txt", contents)
-        self.assertIn("val_generations_1.txt", contents)
-        self.assertIn("val_1_results.txt", contents)
+        self.assertIn("val_generations_00001.txt", contents)
+        self.assertIn("val_results_00001.txt", contents)
         self.assertIn("test_results.txt", contents)
         # self.assertEqual(len(contents), 15)
         metrics = pickle_load(Path(output_dir) / "metrics.pkl")
-        import pandas as pd
-        val_df = pd.DataFrame(metrics["val"])
-        train_df = pd.DataFrame(metrics["train"])
-        test_df = pd.DataFrame(metrics["test"])
-        desired_n_evals = args_d["num_train_epochs"] * 2 + 1
-        self.assertEqual(val_df.shape[0], desired_n_evals)
-        # self.assertEqual(test_df.shape[1], val_df.shape[1])
-        self.assertEqual(train_df.shape[0], 0)
+        desired_n_evals = int(args_d["num_train_epochs"] * (1 / args_d["val_check_interval"]) + 1)
+        self.assertEqual(len(metrics["val"]), desired_n_evals)
+        self.assertEqual(len(metrics["train"]), 0)  # doesn't get logged here
         return model
@@ -281,9 +225,8 @@ class TestBartExamples(unittest.TestCase):
         output_dir = tempfile.mkdtemp(prefix="output_")
         args_d.update(
             data_dir=tmp_dir,
             model_type="t5",
-            model_name_or_path="patrickvonplaten/t5-tiny-random",
-            tokenizer_name=None,  # "patrickvonplaten/t5-tiny-random",
+            model_name_or_path=T5_TINY,
+            tokenizer_name=None,  # T5_TINY,
             train_batch_size=2,
             eval_batch_size=2,
             gpus=0,
```
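The test changes replace GPU-only fp16 variants with parameters that adapt to the machine, e.g. `gpus=1 if torch.cuda.is_available() else 0`, while `@unittest.skipUnless` stays for cases that genuinely need CUDA. A minimal sketch of that pattern in isolation; the test case below is hypothetical and not part of this suite.

```python
# Device-conditional test configuration, sketched with a throwaway test case.
import unittest

import torch


class ToyDeviceTest(unittest.TestCase):
    def test_runs_on_cpu_or_gpu(self):
        # Fall back to CPU instead of skipping when no GPU is present.
        updates = dict(gpus=1 if torch.cuda.is_available() else 0)
        self.assertIn(updates["gpus"], (0, 1))

    @unittest.skipUnless(torch.cuda.is_available(), "needs a CUDA device")
    def test_gpu_only(self):
        self.assertTrue(torch.tensor([1.0], device="cuda").is_cuda)


if __name__ == "__main__":
    unittest.main()
```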
examples/summarization/utils.py  (+2 -0)

```diff
@@ -45,8 +45,10 @@ def encode_file(
             max_length=max_length,
             pad_to_max_length=pad_to_max_length,
             add_prefix_space=True,
+            truncation=True,
             return_tensors=return_tensors,
         )
+        assert tokenized.input_ids.shape[1] == max_length
         examples.append(tokenized)
     torch.save(lmap(dict, examples), cache_path.open("wb"))
     return examples
```
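Adding `truncation=True` alongside padding is what guarantees the new `assert tokenized.input_ids.shape[1] == max_length`: every encoded example comes back as a fixed-length tensor. A standalone sketch of the same behaviour; the checkpoint name is only illustrative, and `padding="max_length"` is the newer spelling of the `pad_to_max_length=True` flag used in the diff.

```python
# Pad/truncate to a fixed max_length so every example has the same tensor shape.
# Assumes `pip install transformers torch`; the checkpoint is illustrative.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")
batch = tokenizer(
    ["a short line", "a much longer line that will be truncated " * 16],
    max_length=32,
    padding="max_length",   # pad the short example up to max_length
    truncation=True,        # cut the long example down to max_length
    return_tensors="pt",
)
assert batch.input_ids.shape[1] == 32  # both rows are exactly max_length tokens
```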
examples/text-classification/run_pl_glue.py  (+2 -8)

```diff
@@ -108,7 +108,7 @@ class GLUETransformer(BaseTransformer):
         return {"val_loss": tmp_eval_loss.detach().cpu(), "pred": preds, "target": out_label_ids}

-    def _eval_end(self, outputs):
+    def _eval_end(self, outputs) -> tuple:
         val_loss_mean = torch.stack([x["val_loss"] for x in outputs]).mean().detach().cpu().item()
         preds = np.concatenate([x["pred"] for x in outputs], axis=0)
@@ -132,20 +132,14 @@ class GLUETransformer(BaseTransformer):
         logs = ret["log"]
         return {"val_loss": logs["val_loss"], "log": logs, "progress_bar": logs}

-    def test_epoch_end(self, outputs):
-        # updating to test_epoch_end instead of deprecated test_end
+    def test_epoch_end(self, outputs) -> dict:
         ret, predictions, targets = self._eval_end(outputs)
-        # Converting to the dic required by pl
-        # https://github.com/PyTorchLightning/pytorch-lightning/blob/master/\
-        # pytorch_lightning/trainer/logging.py#L139
         logs = ret["log"]
         # `val_loss` is the key returned by `self._eval_end()` but actually refers to `test_loss`
         return {"avg_test_loss": logs["val_loss"], "log": logs, "progress_bar": logs}

     @staticmethod
     def add_model_specific_args(parser, root_dir):
         # Add NER specific options
         BaseTransformer.add_model_specific_args(parser, root_dir)
         parser.add_argument(
             "--max_seq_length",
```
src/transformers/tokenization_auto.py  (+1 -1)

```diff
@@ -205,7 +205,7 @@ class AutoTokenizer:
         if not isinstance(config, PretrainedConfig):
             config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)

-        if "bert-base-japanese" in pretrained_model_name_or_path:
+        if "bert-base-japanese" in str(pretrained_model_name_or_path):
             return BertJapaneseTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)

         use_fast = kwargs.pop("use_fast", False)
```
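The `str(...)` cast matters because `from_pretrained` may receive a `pathlib.Path`: the `in` containment test works on strings but raises `TypeError` on `Path` objects. A standalone illustration, with a hypothetical local checkpoint directory:

```python
# Why the str() cast matters: substring checks work on str, not on pathlib.Path.
from pathlib import Path

local_dir = Path("/tmp/bert-base-japanese-finetuned")  # hypothetical local checkpoint dir

try:
    "bert-base-japanese" in local_dir  # raises TypeError: Path does not support `in`
except TypeError as exc:
    print("without str():", exc)

print("with str():", "bert-base-japanese" in str(local_dir))  # True
```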