Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
4fef5919
Commit
4fef5919
authored
Jul 11, 2019
by
thomwolf
Browse files
updating examples
parent
50b7e52a
Changes
10
Show whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
116 additions
and
150 deletions
+116
-150
examples/run_glue.py
examples/run_glue.py
+59
-35
examples/utils_glue.py
examples/utils_glue.py
+1
-0
pytorch_transformers/modeling_bert.py
pytorch_transformers/modeling_bert.py
+6
-6
pytorch_transformers/modeling_gpt2.py
pytorch_transformers/modeling_gpt2.py
+4
-4
pytorch_transformers/modeling_openai.py
pytorch_transformers/modeling_openai.py
+1
-1
pytorch_transformers/modeling_transfo_xl.py
pytorch_transformers/modeling_transfo_xl.py
+5
-5
pytorch_transformers/modeling_transfo_xl_utilities.py
pytorch_transformers/modeling_transfo_xl_utilities.py
+0
-70
pytorch_transformers/modeling_utils.py
pytorch_transformers/modeling_utils.py
+16
-5
pytorch_transformers/modeling_xlnet.py
pytorch_transformers/modeling_xlnet.py
+8
-8
pytorch_transformers/tokenization_transfo_xl.py
pytorch_transformers/tokenization_transfo_xl.py
+16
-16
No files found.
examples/run_glue.py
View file @
4fef5919
...
...
@@ -18,46 +18,37 @@
from
__future__
import
absolute_import
,
division
,
print_function
import
argparse
import
glob
import
logging
import
os
import
random
from
tqdm
import
tqdm
,
trange
import
numpy
as
np
import
torch
from
tensorboardX
import
SummaryWriter
from
torch.utils.data
import
(
DataLoader
,
RandomSampler
,
SequentialSampler
,
TensorDataset
)
from
torch.utils.data.distributed
import
DistributedSampler
from
tqdm
import
tqdm
,
trange
from
tensorboardX
import
SummaryWriter
from
pytorch_transformers
import
(
BertForSequenceClassification
,
XLNetForSequenceClassification
,
XLMForSequenceClassification
,
BERT_PRETRAINED_MODEL_ARCHIVE_MAP
,
XLNET_PRETRAINED_MODEL_ARCHIVE_MAP
,
XLM_PRETRAINED_MODEL_ARCHIVE_MAP
)
from
pytorch_transformers
import
(
BertTokenizer
,
XLNetTokenizer
,
XLMTokenizer
)
from
pytorch_transformers
import
WEIGHTS_NAME
from
pytorch_transformers
import
(
BertConfig
,
BertForSequenceClassification
,
BertTokenizer
,
XLMConfig
,
XLMForSequenceClassification
,
XLMTokenizer
,
XLNetConfig
,
XLNetForSequenceClassification
,
XLNetTokenizer
)
from
pytorch_transformers.optimization
import
BertAdam
from
utils_glue
import
processors
,
output_modes
,
convert_examples_to_features
,
compute_metrics
from
utils_glue
import
(
compute_metrics
,
convert_examples_to_features
,
output_modes
,
processors
)
logger
=
logging
.
getLogger
(
__name__
)
ALL_MODELS
=
sum
((
tuple
(
m
.
keys
())
for
m
in
(
BERT_PRETRAINED_MODEL_ARCHIVE_MAP
,
XLNET_PRETRAINED_MODEL_ARCHIVE_MAP
,
XLM_PRETRAINED_MODEL_ARCHIVE_MAP
)),
())
ALL_MODELS
=
sum
((
tuple
(
conf
.
pretrained_config_archive_map
.
keys
())
for
conf
in
(
BertConfig
,
XLNetConfig
,
XLMConfig
)),
())
MODEL_CLASSES
=
{
'bert'
:
BertForSequenceClassification
,
'xlnet'
:
XLNetForSequenceClassification
,
'xlm'
:
XLMForSequenceClassification
,
}
TOKENIZER_CLASSES
=
{
'bert'
:
BertTokenizer
,
'xlnet'
:
XLNetTokenizer
,
'xlm'
:
XLMTokenizer
,
'bert'
:
(
BertConfig
,
BertForSequenceClassification
,
BertTokenizer
),
'xlnet'
:
(
XLNetConfig
,
XLNetForSequenceClassification
,
XLNetTokenizer
),
'xlm'
:
(
XLMConfig
,
XLMForSequenceClassification
,
XLMTokenizer
),
}
def
train
(
args
,
train_dataset
,
model
,
tokenizer
):
...
...
@@ -130,14 +121,26 @@ def train(args, train_dataset, model, tokenizer):
optimizer
.
step
()
optimizer
.
zero_grad
()
global_step
+=
1
if
args
.
local_rank
in
[
-
1
,
0
]
and
args
.
logging_steps
>
0
and
global_step
%
args
.
logging_steps
==
0
:
# Log metrics
if
args
.
local_rank
==
-
1
:
# Only evaluate on single GPU otherwise metrics may not average well
results
=
evaluate
(
args
,
model
,
tokenizer
)
results
=
evaluate
(
args
,
model
,
tokenizer
,
prefix
=
global_step
)
for
key
,
value
in
results
.
items
():
tb_writer
.
add_scalar
(
'eval_{}'
.
format
(
key
),
value
,
global_step
)
tb_writer
.
add_scalar
(
'lr'
,
optimizer
.
get_lr
()[
0
],
global_step
)
tb_writer
.
add_scalar
(
'loss'
,
(
tr_loss
-
logging_loss
)
/
args
.
logging_steps
,
global_step
)
logging_loss
=
tr_loss
if
args
.
local_rank
in
[
-
1
,
0
]
and
args
.
save_steps
>
0
and
global_step
%
args
.
save_steps
==
0
:
# Save model checkpoint
output_dir
=
os
.
path
.
join
(
args
.
output_dir
,
'checkpoint-{}'
.
format
(
global_step
))
if
not
os
.
path
.
exists
(
output_dir
):
os
.
makedirs
(
output_dir
)
model_to_save
=
model
.
module
if
hasattr
(
model
,
'module'
)
else
model
# Take care of distributed/parallel training
model_to_save
.
save_pretrained
(
output_dir
)
torch
.
save
(
args
,
os
.
path
.
join
(
output_dir
,
'training_args.bin'
))
if
args
.
max_steps
>
0
and
global_step
>
args
.
max_steps
:
break
if
args
.
max_steps
>
0
and
global_step
>
args
.
max_steps
:
...
...
@@ -146,7 +149,7 @@ def train(args, train_dataset, model, tokenizer):
return
global_step
,
tr_loss
/
global_step
def
evaluate
(
args
,
model
,
tokenizer
):
def
evaluate
(
args
,
model
,
tokenizer
,
prefix
=
""
):
# Loop to handle MNLI double evaluation (matched, mis-matched)
eval_task_names
=
(
"mnli"
,
"mnli-mm"
)
if
args
.
task_name
==
"mnli"
else
(
args
.
task_name
,)
eval_outputs_dirs
=
(
args
.
output_dir
,
args
.
output_dir
+
'-MM'
)
if
args
.
task_name
==
"mnli"
else
(
args
.
output_dir
,)
...
...
@@ -202,7 +205,7 @@ def evaluate(args, model, tokenizer):
output_eval_file
=
os
.
path
.
join
(
eval_output_dir
,
"eval_results.txt"
)
with
open
(
output_eval_file
,
"w"
)
as
writer
:
logger
.
info
(
"***** Eval results *****"
)
logger
.
info
(
"***** Eval results
{}
*****"
.
format
(
prefix
)
)
for
key
in
sorted
(
result
.
keys
()):
logger
.
info
(
" %s = %s"
,
key
,
str
(
result
[
key
]))
writer
.
write
(
"%s = %s
\n
"
%
(
key
,
str
(
result
[
key
])))
...
...
@@ -264,6 +267,10 @@ def main():
help
=
"The output directory where the model predictions and checkpoints will be written."
)
## Other parameters
parser
.
add_argument
(
"--config_name"
,
default
=
""
,
type
=
str
,
help
=
"Pretrained config name or path if not the same as model_name"
)
parser
.
add_argument
(
"--tokenizer_name"
,
default
=
""
,
type
=
str
,
help
=
"Pretrained tokenizer name or path if not the same as model_name"
)
parser
.
add_argument
(
"--cache_dir"
,
default
=
""
,
type
=
str
,
help
=
"Where do you want to store the pre-trained models downloaded from s3"
)
parser
.
add_argument
(
"--max_seq_length"
,
default
=
128
,
type
=
int
,
...
...
@@ -293,8 +300,12 @@ def main():
parser
.
add_argument
(
"--warmup_proportion"
,
default
=
0.1
,
type
=
float
,
help
=
"Proportion of training with linear learning rate warmup (0.1 = 10%% of training)."
)
parser
.
add_argument
(
'--logging_steps'
,
type
=
int
,
default
=
10
0
,
parser
.
add_argument
(
'--logging_steps'
,
type
=
int
,
default
=
5
0
,
help
=
"Log every X updates steps."
)
parser
.
add_argument
(
'--save_steps'
,
type
=
int
,
default
=
50
,
help
=
"Save checkpoint every X updates steps."
)
parser
.
add_argument
(
"--eval_all_checkpoints"
,
action
=
'store_true'
,
help
=
"Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number"
)
parser
.
add_argument
(
"--no_cuda"
,
action
=
'store_true'
,
help
=
"Avoid using CUDA when available"
)
parser
.
add_argument
(
'--overwrite_output_dir'
,
action
=
'store_true'
,
...
...
@@ -363,11 +374,15 @@ def main():
# Make sure only the first process in distributed training will download model & vocab
torch
.
distributed
.
barrier
()
args
.
model_type
=
args
.
model_name
.
lower
().
split
(
'-'
)[
0
]
tokenizer_class
=
TOKENIZER_CLASSES
[
args
.
model_type
]
model_class
=
MODEL_CLASSES
[
args
.
model_type
]
tokenizer
=
tokenizer_class
.
from_pretrained
(
args
.
model_name
,
do_lower_case
=
args
.
do_lower_case
)
model
=
model_class
.
from_pretrained
(
args
.
model_name
,
num_labels
=
num_labels
)
args
.
model_type
=
""
for
key
in
MODEL_CLASSES
:
if
key
in
args
.
model_name
.
lower
():
args
.
model_type
=
key
# take the first match in model types
break
config_class
,
model_class
,
tokenizer_class
=
MODEL_CLASSES
[
args
.
model_type
]
config
=
config_class
.
from_pretrained
(
args
.
config_name
if
args
.
config_name
else
args
.
model_name
,
num_labels
=
num_labels
,
finetuning_task
=
args
.
task_name
)
tokenizer
=
tokenizer_class
.
from_pretrained
(
args
.
tokenizer_name
if
args
.
tokenizer_name
else
args
.
model_name
,
do_lower_case
=
args
.
do_lower_case
)
model
=
model_class
.
from_pretrained
(
args
.
model_name
,
from_tf
=
bool
(
'.ckpt'
in
args
.
model_name
),
config
=
config
)
if
args
.
local_rank
==
0
:
torch
.
distributed
.
barrier
()
...
...
@@ -410,8 +425,17 @@ def main():
# Evaluation
if
args
.
do_eval
and
args
.
local_rank
in
[
-
1
,
0
]:
results
=
evaluate
(
args
,
model
,
tokenizer
)
checkpoints
=
[
args
.
output_dir
+
'./'
+
WEIGHTS_NAME
]
if
args
.
eval_all_checkpoints
:
checkpoints
=
list
(
os
.
path
.
dirname
(
c
)
for
c
in
glob
.
glob
(
args
.
output_dir
+
'/**/'
+
WEIGHTS_NAME
,
recursive
=
True
))
logger
.
info
(
"Evaluate the following checkpoints: %s"
,
checkpoints
)
results
=
{}
for
checkpoint
in
checkpoints
:
global_step
=
int
(
checkpoints
.
split
(
'-'
)[
-
1
])
model
=
model_class
.
from_pretrained
(
checkpoints
)
model
.
to
(
args
.
device
)
result
=
evaluate
(
args
,
model
,
tokenizer
,
prefix
=
global_step
)
result
=
dict
(
n
+
'_{}'
.
format
())
return
results
...
...
examples/utils_glue.py
View file @
4fef5919
...
...
@@ -21,6 +21,7 @@ import csv
import
logging
import
os
import
sys
from
io
import
open
from
scipy.stats
import
pearsonr
,
spearmanr
from
sklearn.metrics
import
matthews_corrcoef
,
f1_score
...
...
pytorch_transformers/modeling_bert.py
View file @
4fef5919
...
...
@@ -73,17 +73,17 @@ def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
import
numpy
as
np
import
tensorflow
as
tf
except
ImportError
:
print
(
"Loading a TensorFlow models in PyTorch, requires TensorFlow to be installed. Please see "
logger
.
error
(
"Loading a TensorFlow models in PyTorch, requires TensorFlow to be installed. Please see "
"https://www.tensorflow.org/install/ for installation instructions."
)
raise
tf_path
=
os
.
path
.
abspath
(
tf_checkpoint_path
)
print
(
"Converting TensorFlow checkpoint from {}"
.
format
(
tf_path
))
logger
.
info
(
"Converting TensorFlow checkpoint from {}"
.
format
(
tf_path
))
# Load weights from TF model
init_vars
=
tf
.
train
.
list_variables
(
tf_path
)
names
=
[]
arrays
=
[]
for
name
,
shape
in
init_vars
:
print
(
"Loading TF weight {} with shape {}"
.
format
(
name
,
shape
))
logger
.
info
(
"Loading TF weight {} with shape {}"
.
format
(
name
,
shape
))
array
=
tf
.
train
.
load_variable
(
tf_path
,
name
)
names
.
append
(
name
)
arrays
.
append
(
array
)
...
...
@@ -93,7 +93,7 @@ def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
# adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
# which are not required for using pretrained model
if
any
(
n
in
[
"adam_v"
,
"adam_m"
,
"global_step"
]
for
n
in
name
):
print
(
"Skipping {}"
.
format
(
"/"
.
join
(
name
)))
logger
.
info
(
"Skipping {}"
.
format
(
"/"
.
join
(
name
)))
continue
pointer
=
model
for
m_name
in
name
:
...
...
@@ -113,7 +113,7 @@ def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
try
:
pointer
=
getattr
(
pointer
,
l
[
0
])
except
AttributeError
:
print
(
"Skipping {}"
.
format
(
"/"
.
join
(
name
)))
logger
.
info
(
"Skipping {}"
.
format
(
"/"
.
join
(
name
)))
continue
if
len
(
l
)
>=
2
:
num
=
int
(
l
[
1
])
...
...
@@ -127,7 +127,7 @@ def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
except
AssertionError
as
e
:
e
.
args
+=
(
pointer
.
shape
,
array
.
shape
)
raise
print
(
"Initialize PyTorch weight {}"
.
format
(
name
))
logger
.
info
(
"Initialize PyTorch weight {}"
.
format
(
name
))
pointer
.
data
=
torch
.
from_numpy
(
array
)
return
model
...
...
pytorch_transformers/modeling_gpt2.py
View file @
4fef5919
...
...
@@ -49,17 +49,17 @@ def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path):
import
numpy
as
np
import
tensorflow
as
tf
except
ImportError
:
print
(
"Loading a TensorFlow models in PyTorch, requires TensorFlow to be installed. Please see "
logger
.
error
(
"Loading a TensorFlow models in PyTorch, requires TensorFlow to be installed. Please see "
"https://www.tensorflow.org/install/ for installation instructions."
)
raise
tf_path
=
os
.
path
.
abspath
(
gpt2_checkpoint_path
)
print
(
"Converting TensorFlow checkpoint from {}"
.
format
(
tf_path
))
logger
.
info
(
"Converting TensorFlow checkpoint from {}"
.
format
(
tf_path
))
# Load weights from TF model
init_vars
=
tf
.
train
.
list_variables
(
tf_path
)
names
=
[]
arrays
=
[]
for
name
,
shape
in
init_vars
:
print
(
"Loading TF weight {} with shape {}"
.
format
(
name
,
shape
))
logger
.
info
(
"Loading TF weight {} with shape {}"
.
format
(
name
,
shape
))
array
=
tf
.
train
.
load_variable
(
tf_path
,
name
)
names
.
append
(
name
)
arrays
.
append
(
array
.
squeeze
())
...
...
@@ -90,7 +90,7 @@ def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path):
except
AssertionError
as
e
:
e
.
args
+=
(
pointer
.
shape
,
array
.
shape
)
raise
print
(
"Initialize PyTorch weight {}"
.
format
(
name
))
logger
.
info
(
"Initialize PyTorch weight {}"
.
format
(
name
))
pointer
.
data
=
torch
.
from_numpy
(
array
)
return
model
...
...
pytorch_transformers/modeling_openai.py
View file @
4fef5919
...
...
@@ -110,7 +110,7 @@ def load_tf_weights_in_openai_gpt(model, config, openai_checkpoint_folder_path):
except
AssertionError
as
e
:
e
.
args
+=
(
pointer
.
shape
,
array
.
shape
)
raise
print
(
"Initialize PyTorch weight {}"
.
format
(
name
))
logger
.
info
(
"Initialize PyTorch weight {}"
.
format
(
name
))
pointer
.
data
=
torch
.
from_numpy
(
array
)
return
model
...
...
pytorch_transformers/modeling_transfo_xl.py
View file @
4fef5919
...
...
@@ -126,7 +126,7 @@ def load_tf_weights_in_transfo_xl(model, config, tf_path):
import
numpy
as
np
import
tensorflow
as
tf
except
ImportError
:
print
(
"Loading a TensorFlow models in PyTorch, requires TensorFlow to be installed. Please see "
logger
.
error
(
"Loading a TensorFlow models in PyTorch, requires TensorFlow to be installed. Please see "
"https://www.tensorflow.org/install/ for installation instructions."
)
raise
# Build TF to PyTorch weights loading map
...
...
@@ -136,7 +136,7 @@ def load_tf_weights_in_transfo_xl(model, config, tf_path):
init_vars
=
tf
.
train
.
list_variables
(
tf_path
)
tf_weights
=
{}
for
name
,
shape
in
init_vars
:
print
(
"Loading TF weight {} with shape {}"
.
format
(
name
,
shape
))
logger
.
info
(
"Loading TF weight {} with shape {}"
.
format
(
name
,
shape
))
array
=
tf
.
train
.
load_variable
(
tf_path
,
name
)
tf_weights
[
name
]
=
array
...
...
@@ -157,7 +157,7 @@ def load_tf_weights_in_transfo_xl(model, config, tf_path):
except
AssertionError
as
e
:
e
.
args
+=
(
p_i
.
shape
,
arr_i
.
shape
)
raise
print
(
"Initialize PyTorch weight {} for layer {}"
.
format
(
name
,
i
))
logger
.
info
(
"Initialize PyTorch weight {} for layer {}"
.
format
(
name
,
i
))
p_i
.
data
=
torch
.
from_numpy
(
arr_i
)
else
:
try
:
...
...
@@ -165,13 +165,13 @@ def load_tf_weights_in_transfo_xl(model, config, tf_path):
except
AssertionError
as
e
:
e
.
args
+=
(
pointer
.
shape
,
array
.
shape
)
raise
print
(
"Initialize PyTorch weight {}"
.
format
(
name
))
logger
.
info
(
"Initialize PyTorch weight {}"
.
format
(
name
))
pointer
.
data
=
torch
.
from_numpy
(
array
)
tf_weights
.
pop
(
name
,
None
)
tf_weights
.
pop
(
name
+
'/Adam'
,
None
)
tf_weights
.
pop
(
name
+
'/Adam_1'
,
None
)
print
(
"Weights not copied to PyTorch model: {}"
.
format
(
', '
.
join
(
tf_weights
.
keys
())))
logger
.
info
(
"Weights not copied to PyTorch model: {}"
.
format
(
', '
.
join
(
tf_weights
.
keys
())))
return
model
...
...
pytorch_transformers/modeling_transfo_xl_utilities.py
View file @
4fef5919
...
...
@@ -272,7 +272,6 @@ class LogUniformSampler(object):
self
.
range_max
=
range_max
log_indices
=
torch
.
arange
(
1.
,
range_max
+
2.
,
1.
).
log_
()
self
.
dist
=
(
log_indices
[
1
:]
-
log_indices
[:
-
1
])
/
log_indices
[
-
1
]
# print('P', self.dist.numpy().tolist()[-30:])
self
.
log_q
=
(
-
(
-
self
.
dist
.
double
().
log1p_
()
*
2
*
n_sample
).
expm1_
()).
log_
().
float
()
...
...
@@ -331,72 +330,3 @@ def sample_logits(embedding, bias, labels, inputs, sampler):
logits
=
torch
.
cat
([
true_logits
[:,
:,
None
],
sample_logits
],
-
1
)
return
logits
# class LogUniformSampler(object):
# def __init__(self, range_max, unique=False):
# """
# Reference : https://github.com/tensorflow/tensorflow/blob/r1.10/tensorflow/python/ops/candidate_sampling_ops.py
# `P(class) = (log(class + 2) - log(class + 1)) / log(range_max + 1)`
# """
# self.range_max = range_max
# log_indices = torch.arange(1., range_max+2., 1.).log_()
# self.dist = (log_indices[1:] - log_indices[:-1]) / log_indices[-1]
# self.unique = unique
# if self.unique:
# self.exclude_mask = torch.ByteTensor(range_max).fill_(0)
# def sample(self, n_sample, labels):
# pos_sample, new_labels = labels.unique(return_inverse=True)
# n_pos_sample = pos_sample.size(0)
# n_neg_sample = n_sample - n_pos_sample
# if self.unique:
# self.exclude_mask.index_fill_(0, pos_sample, 1)
# sample_dist = self.dist.clone().masked_fill_(self.exclude_mask, 0)
# self.exclude_mask.index_fill_(0, pos_sample, 0)
# else:
# sample_dist = self.dist
# neg_sample = torch.multinomial(sample_dist, n_neg_sample)
# sample = torch.cat([pos_sample, neg_sample])
# sample_prob = self.dist[sample]
# return new_labels, sample, sample_prob
if
__name__
==
'__main__'
:
S
,
B
=
3
,
4
n_vocab
=
10000
n_sample
=
5
H
=
32
labels
=
torch
.
LongTensor
(
S
,
B
).
random_
(
0
,
n_vocab
)
# sampler = LogUniformSampler(n_vocab, unique=False)
# new_labels, sample, sample_prob = sampler.sample(n_sample, labels)
sampler
=
LogUniformSampler
(
n_vocab
,
n_sample
)
#, unique=True)
# true_probs, samp_probs, neg_samples = sampler.sample(n_sample, labels)
# print('true_probs', true_probs.numpy().tolist())
# print('samp_probs', samp_probs.numpy().tolist())
# print('neg_samples', neg_samples.numpy().tolist())
# print('sum', torch.sum(sampler.dist).item())
# assert torch.all(torch.sort(sample.unique())[0].eq(torch.sort(sample)[0])).item()
embedding
=
nn
.
Embedding
(
n_vocab
,
H
)
bias
=
torch
.
zeros
(
n_vocab
)
inputs
=
torch
.
Tensor
(
S
,
B
,
H
).
normal_
()
logits
,
out_labels
=
sample_logits
(
embedding
,
bias
,
labels
,
inputs
,
sampler
,
n_sample
)
print
(
'logits'
,
logits
.
detach
().
numpy
().
tolist
())
print
(
'logits shape'
,
logits
.
size
())
print
(
'out_labels'
,
out_labels
.
detach
().
numpy
().
tolist
())
print
(
'out_labels shape'
,
out_labels
.
size
())
pytorch_transformers/modeling_utils.py
View file @
4fef5919
...
...
@@ -57,16 +57,18 @@ class PretrainedConfig(object):
pretrained_model_name_or_path: either:
- a str with the name of a pre-trained model to load selected in the list of:
. `xlnet-large-cased`
- a path or url to a
pretrained model archive containing:
. `config.json`
a configuration file for the model
- a path or url to a
directory containing a configuration file `config.json` for the model,
- a path or url to
a configuration file for the model
.
cache_dir: an optional path to a folder in which the pre-trained model configuration will be cached.
"""
cache_dir
=
kwargs
.
pop
(
'cache_dir'
,
None
)
if
pretrained_model_name_or_path
in
cls
.
pretrained_config_archive_map
:
config_file
=
cls
.
pretrained_config_archive_map
[
pretrained_model_name_or_path
]
el
se
:
el
if
os
.
path
.
isdir
(
pretrained_model_name_or_path
)
:
config_file
=
os
.
path
.
join
(
pretrained_model_name_or_path
,
CONFIG_NAME
)
else
:
config_file
=
pretrained_model_name_or_path
# redirect to the cache, if necessary
try
:
resolved_config_file
=
cached_path
(
config_file
,
cache_dir
=
cache_dir
)
...
...
@@ -200,6 +202,7 @@ class PreTrainedModel(nn.Module):
- a path or url to a tensorflow pretrained model checkpoint containing:
. `config.json` a configuration file for the model
. `model.chkpt` a TensorFlow checkpoint
config: an optional configuration for the model
from_tf: should we load the weights from a locally saved TensorFlow checkpoint
cache_dir: an optional path to a folder in which the pre-trained models will be cached.
state_dict: an optional state dictionnary (collections.OrderedDict object) to use
...
...
@@ -207,23 +210,31 @@ class PreTrainedModel(nn.Module):
*inputs, **kwargs: additional input for the specific XLNet class
(ex: num_labels for XLNetForSequenceClassification)
"""
config
=
kwargs
.
pop
(
'config'
,
None
)
state_dict
=
kwargs
.
pop
(
'state_dict'
,
None
)
cache_dir
=
kwargs
.
pop
(
'cache_dir'
,
None
)
from_tf
=
kwargs
.
pop
(
'from_tf'
,
False
)
output_loading_info
=
kwargs
.
pop
(
'output_loading_info'
,
False
)
# Load config
if
config
is
None
:
config
=
cls
.
config_class
.
from_pretrained
(
pretrained_model_name_or_path
,
*
inputs
,
**
kwargs
)
# Load model
if
pretrained_model_name_or_path
in
cls
.
pretrained_model_archive_map
:
archive_file
=
cls
.
pretrained_model_archive_map
[
pretrained_model_name_or_path
]
el
se
:
el
if
os
.
path
.
isdir
(
pretrained_model_name_or_path
)
:
if
from_tf
:
# Directly load from a TensorFlow checkpoint
archive_file
=
os
.
path
.
join
(
pretrained_model_name_or_path
,
TF_WEIGHTS_NAME
+
".index"
)
else
:
archive_file
=
os
.
path
.
join
(
pretrained_model_name_or_path
,
WEIGHTS_NAME
)
else
:
if
from_tf
:
# Directly load from a TensorFlow checkpoint
archive_file
=
pretrained_model_name_or_path
+
".index"
else
:
archive_file
=
pretrained_model_name_or_path
# redirect to the cache, if necessary
try
:
resolved_archive_file
=
cached_path
(
archive_file
,
cache_dir
=
cache_dir
)
...
...
pytorch_transformers/modeling_xlnet.py
View file @
4fef5919
...
...
@@ -122,14 +122,14 @@ def load_tf_weights_in_xlnet(model, config, tf_path):
import
numpy
as
np
import
tensorflow
as
tf
except
ImportError
:
print
(
"Loading a TensorFlow models in PyTorch, requires TensorFlow to be installed. Please see "
logger
.
error
(
"Loading a TensorFlow models in PyTorch, requires TensorFlow to be installed. Please see "
"https://www.tensorflow.org/install/ for installation instructions."
)
raise
# Load weights from TF model
init_vars
=
tf
.
train
.
list_variables
(
tf_path
)
tf_weights
=
{}
for
name
,
shape
in
init_vars
:
print
(
"Loading TF weight {} with shape {}"
.
format
(
name
,
shape
))
logger
.
info
(
"Loading TF weight {} with shape {}"
.
format
(
name
,
shape
))
array
=
tf
.
train
.
load_variable
(
tf_path
,
name
)
tf_weights
[
name
]
=
array
...
...
@@ -137,15 +137,15 @@ def load_tf_weights_in_xlnet(model, config, tf_path):
tf_to_pt_map
=
build_tf_xlnet_to_pytorch_map
(
model
,
config
,
tf_weights
)
for
name
,
pointer
in
tf_to_pt_map
.
items
():
print
(
"Importing {}"
.
format
(
name
))
logger
.
info
(
"Importing {}"
.
format
(
name
))
if
name
not
in
tf_weights
:
print
(
"{} not in tf pre-trained weights, skipping"
.
format
(
name
))
logger
.
info
(
"{} not in tf pre-trained weights, skipping"
.
format
(
name
))
continue
array
=
tf_weights
[
name
]
# adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
# which are not required for using pretrained model
if
'kernel'
in
name
and
(
'ff'
in
name
or
'summary'
in
name
or
'logit'
in
name
):
print
(
"Transposing"
)
logger
.
info
(
"Transposing"
)
array
=
np
.
transpose
(
array
)
if
isinstance
(
pointer
,
list
):
# Here we will split the TF weigths
...
...
@@ -157,7 +157,7 @@ def load_tf_weights_in_xlnet(model, config, tf_path):
except
AssertionError
as
e
:
e
.
args
+=
(
p_i
.
shape
,
arr_i
.
shape
)
raise
print
(
"Initialize PyTorch weight {} for layer {}"
.
format
(
name
,
i
))
logger
.
info
(
"Initialize PyTorch weight {} for layer {}"
.
format
(
name
,
i
))
p_i
.
data
=
torch
.
from_numpy
(
arr_i
)
else
:
try
:
...
...
@@ -165,13 +165,13 @@ def load_tf_weights_in_xlnet(model, config, tf_path):
except
AssertionError
as
e
:
e
.
args
+=
(
pointer
.
shape
,
array
.
shape
)
raise
print
(
"Initialize PyTorch weight {}"
.
format
(
name
))
logger
.
info
(
"Initialize PyTorch weight {}"
.
format
(
name
))
pointer
.
data
=
torch
.
from_numpy
(
array
)
tf_weights
.
pop
(
name
,
None
)
tf_weights
.
pop
(
name
+
'/Adam'
,
None
)
tf_weights
.
pop
(
name
+
'/Adam_1'
,
None
)
print
(
"Weights not copied to PyTorch model: {}"
.
format
(
', '
.
join
(
tf_weights
.
keys
())))
logger
.
info
(
"Weights not copied to PyTorch model: {}"
.
format
(
', '
.
join
(
tf_weights
.
keys
())))
return
model
...
...
pytorch_transformers/tokenization_transfo_xl.py
View file @
4fef5919
...
...
@@ -98,14 +98,14 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
self
.
build_vocab
()
def
count_file
(
self
,
path
,
verbose
=
False
,
add_eos
=
False
):
if
verbose
:
print
(
'counting file {} ...'
.
format
(
path
))
if
verbose
:
logger
.
info
(
'counting file {} ...'
.
format
(
path
))
assert
os
.
path
.
exists
(
path
)
sents
=
[]
with
open
(
path
,
'r'
,
encoding
=
'utf-8'
)
as
f
:
for
idx
,
line
in
enumerate
(
f
):
if
verbose
and
idx
>
0
and
idx
%
500000
==
0
:
print
(
' line {}'
.
format
(
idx
))
logger
.
info
(
' line {}'
.
format
(
idx
))
symbols
=
self
.
tokenize
(
line
,
add_eos
=
add_eos
)
self
.
counter
.
update
(
symbols
)
sents
.
append
(
symbols
)
...
...
@@ -116,10 +116,10 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
"""
sents : a list of sentences, each a list of tokenized symbols
"""
if
verbose
:
print
(
'counting {} sents ...'
.
format
(
len
(
sents
)))
if
verbose
:
logger
.
info
(
'counting {} sents ...'
.
format
(
len
(
sents
)))
for
idx
,
symbols
in
enumerate
(
sents
):
if
verbose
and
idx
>
0
and
idx
%
500000
==
0
:
print
(
' line {}'
.
format
(
idx
))
logger
.
info
(
' line {}'
.
format
(
idx
))
self
.
counter
.
update
(
symbols
)
def
_build_from_file
(
self
,
vocab_file
):
...
...
@@ -147,11 +147,11 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
def
build_vocab
(
self
):
if
self
.
vocab_file
:
print
(
'building vocab from {}'
.
format
(
self
.
vocab_file
))
logger
.
info
(
'building vocab from {}'
.
format
(
self
.
vocab_file
))
self
.
_build_from_file
(
self
.
vocab_file
)
print
(
'final vocab size {}'
.
format
(
len
(
self
)))
logger
.
info
(
'final vocab size {}'
.
format
(
len
(
self
)))
else
:
print
(
'building vocab with min_freq={}, max_size={}'
.
format
(
logger
.
info
(
'building vocab with min_freq={}, max_size={}'
.
format
(
self
.
min_freq
,
self
.
max_size
))
self
.
idx2sym
=
[]
self
.
sym2idx
=
OrderedDict
()
...
...
@@ -163,18 +163,18 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
if
cnt
<
self
.
min_freq
:
break
self
.
add_symbol
(
sym
)
print
(
'final vocab size {} from {} unique tokens'
.
format
(
logger
.
info
(
'final vocab size {} from {} unique tokens'
.
format
(
len
(
self
),
len
(
self
.
counter
)))
def
encode_file
(
self
,
path
,
ordered
=
False
,
verbose
=
False
,
add_eos
=
True
,
add_double_eos
=
False
):
if
verbose
:
print
(
'encoding file {} ...'
.
format
(
path
))
if
verbose
:
logger
.
info
(
'encoding file {} ...'
.
format
(
path
))
assert
os
.
path
.
exists
(
path
)
encoded
=
[]
with
open
(
path
,
'r'
,
encoding
=
'utf-8'
)
as
f
:
for
idx
,
line
in
enumerate
(
f
):
if
verbose
and
idx
>
0
and
idx
%
500000
==
0
:
print
(
' line {}'
.
format
(
idx
))
logger
.
info
(
' line {}'
.
format
(
idx
))
symbols
=
self
.
tokenize
(
line
,
add_eos
=
add_eos
,
add_double_eos
=
add_double_eos
)
encoded
.
append
(
self
.
convert_to_tensor
(
symbols
))
...
...
@@ -185,11 +185,11 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
return
encoded
def
encode_sents
(
self
,
sents
,
ordered
=
False
,
verbose
=
False
):
if
verbose
:
print
(
'encoding {} sents ...'
.
format
(
len
(
sents
)))
if
verbose
:
logger
.
info
(
'encoding {} sents ...'
.
format
(
len
(
sents
)))
encoded
=
[]
for
idx
,
symbols
in
enumerate
(
sents
):
if
verbose
and
idx
>
0
and
idx
%
500000
==
0
:
print
(
' line {}'
.
format
(
idx
))
logger
.
info
(
' line {}'
.
format
(
idx
))
encoded
.
append
(
self
.
convert_to_tensor
(
symbols
))
if
ordered
:
...
...
@@ -218,7 +218,7 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
if
sym
in
self
.
sym2idx
:
return
self
.
sym2idx
[
sym
]
else
:
#
print
('encounter unk {}'.format(sym))
#
logger.info
('encounter unk {}'.format(sym))
# assert '<eos>' not in sym
if
hasattr
(
self
,
'unk_idx'
):
return
self
.
sym2idx
.
get
(
sym
,
self
.
unk_idx
)
...
...
@@ -544,14 +544,14 @@ def get_lm_corpus(datadir, dataset):
fn
=
os
.
path
.
join
(
datadir
,
'cache.pt'
)
fn_pickle
=
os
.
path
.
join
(
datadir
,
'cache.pkl'
)
if
os
.
path
.
exists
(
fn
):
print
(
'Loading cached dataset...'
)
logger
.
info
(
'Loading cached dataset...'
)
corpus
=
torch
.
load
(
fn_pickle
)
elif
os
.
path
.
exists
(
fn
):
print
(
'Loading cached dataset from pickle...'
)
logger
.
info
(
'Loading cached dataset from pickle...'
)
with
open
(
fn
,
"rb"
)
as
fp
:
corpus
=
pickle
.
load
(
fp
)
else
:
print
(
'Producing dataset {}...'
.
format
(
dataset
))
logger
.
info
(
'Producing dataset {}...'
.
format
(
dataset
))
kwargs
=
{}
if
dataset
in
[
'wt103'
,
'wt2'
]:
kwargs
[
'special'
]
=
[
'<eos>'
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment