chenpangpang / transformers / Commits

Commit 9af479b3
Authored Nov 02, 2018 by thomwolf
Parent: 8e81e5e6

    conversion run_squad ok

Showing 1 changed file with 23 additions and 225 deletions:

run_squad_pytorch.py (+23, -225)
@@ -27,7 +27,6 @@ import modeling
 import optimization
 import tokenization
 import six
-import tensorflow as tf
 import argparse
 from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
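The torch.utils.data imports are the PyTorch counterparts of the TF input pipeline this commit deletes further down (the removed input_fn_builder). A minimal sketch of how the four classes fit together, assuming the usual list of InputFeatures produced by convert_examples_to_features; the variable names below are illustrative, not lines from this commit:

    # Sketch (assumption): TensorDataset + a sampler + DataLoader replace the
    # tf.constant feature map and tf.data pipeline removed later in this diff.
    import collections
    import torch
    from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler

    # Toy stand-ins for the InputFeatures produced by convert_examples_to_features.
    Feature = collections.namedtuple("Feature", ["input_ids", "input_mask", "segment_ids"])
    features = [Feature([101, 2054, 102, 0], [1, 1, 1, 0], [0, 0, 0, 0]) for _ in range(8)]
    is_training = True

    all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
    all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long)
    all_segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long)

    dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids)
    # RandomSampler shuffles for training; SequentialSampler keeps eval order stable.
    sampler = RandomSampler(dataset) if is_training else SequentialSampler(dataset)
    loader = DataLoader(dataset, sampler=sampler, batch_size=4)

    for input_ids, input_mask, segment_ids in loader:
        print(input_ids.shape)  # torch.Size([4, 4])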
@@ -177,7 +176,7 @@ class InputFeatures(object):

 def read_squad_examples(input_file, is_training):
   """Read a SQuAD json file into a list of SquadExample."""
-  with tf.gfile.Open(input_file, "r") as reader:
+  with open(input_file, "r") as reader:
     input_data = json.load(reader)["data"]

   def is_whitespace(c):
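The tf.gfile.Open → open swap here (and the matching tf.gfile.GFile → open swap in write_predictions below) trades TF's filesystem abstraction, which can read and write remote paths such as gs:// buckets, for the builtin open, which handles local paths only; that is sufficient once the script no longer targets TPU infrastructure.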
@@ -229,7 +228,7 @@ def read_squad_examples(input_file, is_training):
         cleaned_answer_text = " ".join(
             tokenization.whitespace_tokenize(orig_answer_text))
         if actual_text.find(cleaned_answer_text) == -1:
-          tf.logging.warning("Could not find answer: '%s' vs. '%s'",
-                             actual_text, cleaned_answer_text)
+          logger.warning("Could not find answer: '%s' vs. '%s'",
+                         actual_text, cleaned_answer_text)
           continue
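The `logger` object the new lines call is not defined in any hunk shown here; the calls assume a module-level logging setup along these lines (a sketch of the standard stdlib pattern, not lines quoted from this file):

    # Assumed (not shown in this diff): module-level logger replacing tf.logging.
    import logging

    logging.basicConfig(format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
                        datefmt="%m/%d/%Y %H:%M:%S",
                        level=logging.INFO)
    logger = logging.getLogger(__name__)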
@@ -356,27 +355,27 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
           end_position = tok_end_position - doc_start + doc_offset

       if example_index < 20:
-        tf.logging.info("*** Example ***")
+        logger.info("*** Example ***")
-        tf.logging.info("unique_id: %s" % (unique_id))
+        logger.info("unique_id: %s" % (unique_id))
-        tf.logging.info("example_index: %s" % (example_index))
+        logger.info("example_index: %s" % (example_index))
-        tf.logging.info("doc_span_index: %s" % (doc_span_index))
+        logger.info("doc_span_index: %s" % (doc_span_index))
-        tf.logging.info("tokens: %s" % " ".join(
-            [tokenization.printable_text(x) for x in tokens]))
+        logger.info("tokens: %s" % " ".join(
+            [tokenization.printable_text(x) for x in tokens]))
-        tf.logging.info("token_to_orig_map: %s" % " ".join(
-            ["%d:%d" % (x, y) for (x, y) in six.iteritems(token_to_orig_map)]))
+        logger.info("token_to_orig_map: %s" % " ".join(
+            ["%d:%d" % (x, y) for (x, y) in six.iteritems(token_to_orig_map)]))
-        tf.logging.info("token_is_max_context: %s" % " ".join([
-            "%d:%s" % (x, y) for (x, y) in six.iteritems(token_is_max_context)
-        ]))
+        logger.info("token_is_max_context: %s" % " ".join([
+            "%d:%s" % (x, y) for (x, y) in six.iteritems(token_is_max_context)
+        ]))
-        tf.logging.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
+        logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
-        tf.logging.info(
-            "input_mask: %s" % " ".join([str(x) for x in input_mask]))
+        logger.info(
+            "input_mask: %s" % " ".join([str(x) for x in input_mask]))
-        tf.logging.info(
-            "segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
+        logger.info(
+            "segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
         if is_training:
           answer_text = " ".join(tokens[start_position:(end_position + 1)])
-          tf.logging.info("start_position: %d" % (start_position))
+          logger.info("start_position: %d" % (start_position))
-          tf.logging.info("end_position: %d" % (end_position))
+          logger.info("end_position: %d" % (end_position))
-          tf.logging.info(
-              "answer: %s" % (tokenization.printable_text(answer_text)))
+          logger.info(
+              "answer: %s" % (tokenization.printable_text(answer_text)))

       features.append(
@@ -471,207 +470,6 @@ def _check_is_max_context(doc_spans, cur_span_index, position):
   return cur_span_index == best_span_index


-def create_model(bert_config, is_training, input_ids, input_mask, segment_ids,
-                 use_one_hot_embeddings):
-  """Creates a classification model."""
-  model = modeling.BertModel(
-      config=bert_config,
-      is_training=is_training,
-      input_ids=input_ids,
-      input_mask=input_mask,
-      token_type_ids=segment_ids,
-      use_one_hot_embeddings=use_one_hot_embeddings)
-
-  final_hidden = model.get_sequence_output()
-
-  final_hidden_shape = modeling.get_shape_list(final_hidden, expected_rank=3)
-  batch_size = final_hidden_shape[0]
-  seq_length = final_hidden_shape[1]
-  hidden_size = final_hidden_shape[2]
-
-  output_weights = tf.get_variable(
-      "cls/squad/output_weights", [2, hidden_size],
-      initializer=tf.truncated_normal_initializer(stddev=0.02))
-
-  output_bias = tf.get_variable(
-      "cls/squad/output_bias", [2], initializer=tf.zeros_initializer())
-
-  final_hidden_matrix = tf.reshape(final_hidden,
-                                   [batch_size * seq_length, hidden_size])
-  logits = tf.matmul(final_hidden_matrix, output_weights, transpose_b=True)
-  logits = tf.nn.bias_add(logits, output_bias)
-
-  logits = tf.reshape(logits, [batch_size, seq_length, 2])
-  logits = tf.transpose(logits, [2, 0, 1])
-
-  unstacked_logits = tf.unstack(logits, axis=0)
-
-  (start_logits, end_logits) = (unstacked_logits[0], unstacked_logits[1])
-
-  return (start_logits, end_logits)
-
-
-def model_fn_builder(bert_config, init_checkpoint, learning_rate,
-                     num_train_steps, num_warmup_steps, use_tpu,
-                     use_one_hot_embeddings):
-  """Returns `model_fn` closure for TPUEstimator."""
-
-  def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
-    """The `model_fn` for TPUEstimator."""
-
-    tf.logging.info("*** Features ***")
-    for name in sorted(features.keys()):
-      tf.logging.info("  name = %s, shape = %s" % (name, features[name].shape))
-
-    unique_ids = features["unique_ids"]
-    input_ids = features["input_ids"]
-    input_mask = features["input_mask"]
-    segment_ids = features["segment_ids"]
-
-    is_training = (mode == tf.estimator.ModeKeys.TRAIN)
-
-    (start_logits, end_logits) = create_model(
-        bert_config=bert_config,
-        is_training=is_training,
-        input_ids=input_ids,
-        input_mask=input_mask,
-        segment_ids=segment_ids,
-        use_one_hot_embeddings=use_one_hot_embeddings)
-
-    tvars = tf.trainable_variables()
-
-    initialized_variable_names = {}
-    scaffold_fn = None
-    if init_checkpoint:
-      (assignment_map, initialized_variable_names
-      ) = modeling.get_assigment_map_from_checkpoint(tvars, init_checkpoint)
-      if use_tpu:
-
-        def tpu_scaffold():
-          tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
-          return tf.train.Scaffold()
-
-        scaffold_fn = tpu_scaffold
-      else:
-        tf.train.init_from_checkpoint(init_checkpoint, assignment_map)
-
-    tf.logging.info("**** Trainable Variables ****")
-    for var in tvars:
-      init_string = ""
-      if var.name in initialized_variable_names:
-        init_string = ", *INIT_FROM_CKPT*"
-      tf.logging.info("  name = %s, shape = %s%s", var.name, var.shape,
-                      init_string)
-
-    output_spec = None
-    if mode == tf.estimator.ModeKeys.TRAIN:
-      seq_length = modeling.get_shape_list(input_ids)[1]
-
-      def compute_loss(logits, positions):
-        one_hot_positions = tf.one_hot(
-            positions, depth=seq_length, dtype=tf.float32)
-        log_probs = tf.nn.log_softmax(logits, axis=-1)
-        loss = -tf.reduce_mean(
-            tf.reduce_sum(one_hot_positions * log_probs, axis=-1))
-        return loss
-
-      start_positions = features["start_positions"]
-      end_positions = features["end_positions"]
-
-      start_loss = compute_loss(start_logits, start_positions)
-      end_loss = compute_loss(end_logits, end_positions)
-
-      total_loss = (start_loss + end_loss) / 2.0
-
-      train_op = optimization.create_optimizer(
-          total_loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu)
-
-      output_spec = tf.contrib.tpu.TPUEstimatorSpec(
-          mode=mode,
-          loss=total_loss,
-          train_op=train_op,
-          scaffold_fn=scaffold_fn)
-    elif mode == tf.estimator.ModeKeys.PREDICT:
-      predictions = {
-          "unique_ids": unique_ids,
-          "start_logits": start_logits,
-          "end_logits": end_logits,
-      }
-      output_spec = tf.contrib.tpu.TPUEstimatorSpec(
-          mode=mode, predictions=predictions, scaffold_fn=scaffold_fn)
-    else:
-      raise ValueError(
-          "Only TRAIN and PREDICT modes are supported: %s" % (mode))
-
-    return output_spec
-
-  return model_fn
-
-
-def input_fn_builder(features, seq_length, is_training, drop_remainder):
-  """Creates an `input_fn` closure to be passed to TPUEstimator."""
-
-  all_unique_ids = []
-  all_input_ids = []
-  all_input_mask = []
-  all_segment_ids = []
-  all_start_positions = []
-  all_end_positions = []
-
-  for feature in features:
-    all_unique_ids.append(feature.unique_id)
-    all_input_ids.append(feature.input_ids)
-    all_input_mask.append(feature.input_mask)
-    all_segment_ids.append(feature.segment_ids)
-    if is_training:
-      all_start_positions.append(feature.start_position)
-      all_end_positions.append(feature.end_position)
-
-  def input_fn(params):
-    """The actual input function."""
-    batch_size = params["batch_size"]
-
-    num_examples = len(features)
-
-    # This is for demo purposes and does NOT scale to large data sets. We do
-    # not use Dataset.from_generator() because that uses tf.py_func which is
-    # not TPU compatible. The right way to load data is with TFRecordReader.
-    feature_map = {
-        "unique_ids":
-            tf.constant(all_unique_ids, shape=[num_examples], dtype=tf.int32),
-        "input_ids":
-            tf.constant(
-                all_input_ids, shape=[num_examples, seq_length],
-                dtype=tf.int32),
-        "input_mask":
-            tf.constant(
-                all_input_mask,
-                shape=[num_examples, seq_length],
-                dtype=tf.int32),
-        "segment_ids":
-            tf.constant(
-                all_segment_ids,
-                shape=[num_examples, seq_length],
-                dtype=tf.int32),
-    }
-    if is_training:
-      feature_map["start_positions"] = tf.constant(
-          all_start_positions, shape=[num_examples], dtype=tf.int32)
-      feature_map["end_positions"] = tf.constant(
-          all_end_positions, shape=[num_examples], dtype=tf.int32)
-
-    d = tf.data.Dataset.from_tensor_slices(feature_map)
-    if is_training:
-      d = d.repeat()
-      d = d.shuffle(buffer_size=100)
-    d = d.batch(batch_size=batch_size, drop_remainder=drop_remainder)
-    return d
-
-  return input_fn
-
-
 RawResult = collections.namedtuple("RawResult",
                                    ["unique_id", "start_logits", "end_logits"])
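The deleted block is the entire TPUEstimator half of the original TF script: create_model puts a 2-way linear head over BertModel's sequence output to get per-token start/end logits, model_fn_builder adds the cross-entropy loss and checkpoint loading, and input_fn_builder serves the features as tf.constant tensors. In the PyTorch port the head and loss reduce to a few tensor ops; a hedged sketch of the equivalent computation follows (sequence_output, hidden_size, and the batch variables are illustrative placeholders, not names from this commit):

    # Sketch (assumption): PyTorch equivalent of the removed create_model() and
    # compute_loss(). Each token's hidden state is mapped to (start, end) logits;
    # the loss averages the start and end cross-entropies, as in the TF code.
    import torch
    import torch.nn as nn

    batch_size, seq_len, hidden_size = 2, 8, 16
    sequence_output = torch.randn(batch_size, seq_len, hidden_size)  # stand-in for BERT output
    start_positions = torch.tensor([1, 3])
    end_positions = torch.tensor([2, 5])

    qa_outputs = nn.Linear(hidden_size, 2)   # plays the role of "cls/squad/output_*"
    logits = qa_outputs(sequence_output)     # [batch, seq_len, 2]
    start_logits, end_logits = logits.split(1, dim=-1)
    start_logits = start_logits.squeeze(-1)  # [batch, seq_len]
    end_logits = end_logits.squeeze(-1)

    # Cross-entropy over sequence positions is -mean log-softmax at the gold index,
    # matching the one-hot formulation in the deleted compute_loss().
    loss_fct = nn.CrossEntropyLoss()
    total_loss = (loss_fct(start_logits, start_positions) +
                  loss_fct(end_logits, end_positions)) / 2.0
    print(total_loss.item())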
@@ -681,8 +479,8 @@ def write_predictions(all_examples, all_features, all_results, n_best_size,
                       max_answer_length, do_lower_case, output_prediction_file,
                       output_nbest_file):
   """Write final predictions to the json file."""
-  tf.logging.info("Writing predictions to: %s" % (output_prediction_file))
-  tf.logging.info("Writing nbest to: %s" % (output_nbest_file))
+  logger.info("Writing predictions to: %s" % (output_prediction_file))
+  logger.info("Writing nbest to: %s" % (output_nbest_file))

   example_index_to_features = collections.defaultdict(list)
   for feature in all_features:
@@ -804,10 +602,10 @@ def write_predictions(all_examples, all_features, all_results, n_best_size,
     all_predictions[example.qas_id] = nbest_json[0]["text"]
     all_nbest_json[example.qas_id] = nbest_json

-  with tf.gfile.GFile(output_prediction_file, "w") as writer:
+  with open(output_prediction_file, "w") as writer:
     writer.write(json.dumps(all_predictions, indent=4) + "\n")

-  with tf.gfile.GFile(output_nbest_file, "w") as writer:
+  with open(output_nbest_file, "w") as writer:
     writer.write(json.dumps(all_nbest_json, indent=4) + "\n")
@@ -861,7 +659,7 @@ def get_final_text(pred_text, orig_text, do_lower_case):
   start_position = tok_text.find(pred_text)
   if start_position == -1:
     if args.verbose_logging:
-      tf.logging.info(
-          "Unable to find text: '%s' in '%s'" % (pred_text, orig_text))
+      logger.info(
+          "Unable to find text: '%s' in '%s'" % (pred_text, orig_text))
     return orig_text
   end_position = start_position + len(pred_text) - 1
@@ -871,7 +669,7 @@ def get_final_text(pred_text, orig_text, do_lower_case):
   if len(orig_ns_text) != len(tok_ns_text):
     if args.verbose_logging:
-      tf.logging.info("Length not equal after stripping spaces: '%s' vs '%s'",
-                      orig_ns_text, tok_ns_text)
+      logger.info("Length not equal after stripping spaces: '%s' vs '%s'",
+                  orig_ns_text, tok_ns_text)
     return orig_text
@@ -889,7 +687,7 @@ def get_final_text(pred_text, orig_text, do_lower_case):
   if orig_start_position is None:
     if args.verbose_logging:
-      tf.logging.info("Couldn't map start position")
+      logger.info("Couldn't map start position")
     return orig_text

   orig_end_position = None
@@ -900,7 +698,7 @@ def get_final_text(pred_text, orig_text, do_lower_case):
   if orig_end_position is None:
     if args.verbose_logging:
-      tf.logging.info("Couldn't map end position")
+      logger.info("Couldn't map end position")
     return orig_text

   output_text = orig_text[orig_start_position:(orig_end_position + 1)]