Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
241e3254
Commit
241e3254
authored
Sep 14, 2018
by
Sinan Tan
Committed by
xuehui
Sep 14, 2018
Browse files
Fix some pylint warnings for SQuAD QA model.
parent
5e01504d
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
63 additions
and
53 deletions
+63
-53
examples/trials/ga_squad/data.py
examples/trials/ga_squad/data.py
+34
-30
examples/trials/ga_squad/evaluate.py
examples/trials/ga_squad/evaluate.py
+14
-13
examples/trials/ga_squad/graph.py
examples/trials/ga_squad/graph.py
+10
-10
examples/trials/ga_squad/train_model.py
examples/trials/ga_squad/train_model.py
+5
-0
No files found.
examples/trials/ga_squad/data.py
View file @
241e3254
...
@@ -19,6 +19,10 @@
...
@@ -19,6 +19,10 @@
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
'''
Data processing script for the QA model.
'''
import
csv
import
csv
import
json
import
json
from
random
import
shuffle
from
random
import
shuffle
...
@@ -73,19 +77,19 @@ def load_from_file(path, fmt=None, is_training=True):
...
@@ -73,19 +77,19 @@ def load_from_file(path, fmt=None, is_training=True):
for
doc
in
data
:
for
doc
in
data
:
for
paragraph
in
doc
[
'paragraphs'
]:
for
paragraph
in
doc
[
'paragraphs'
]:
passage
=
paragraph
[
'context'
]
passage
=
paragraph
[
'context'
]
for
qa
in
paragraph
[
'qas'
]:
for
qa
_pair
in
paragraph
[
'qas'
]:
question
=
qa
[
'question'
]
question
=
qa
_pair
[
'question'
]
id
=
qa
[
'id'
]
qa_
id
=
qa
_pair
[
'id'
]
if
not
is_training
:
if
not
is_training
:
qp_pairs
.
append
(
qp_pairs
.
append
(
{
'passage'
:
passage
,
'question'
:
question
,
'id'
:
id
})
{
'passage'
:
passage
,
'question'
:
question
,
'id'
:
qa_
id
})
else
:
else
:
for
answer
in
qa
[
'answers'
]:
for
answer
in
qa
_pair
[
'answers'
]:
answer_begin
=
int
(
answer
[
'answer_start'
])
answer_begin
=
int
(
answer
[
'answer_start'
])
answer_end
=
answer_begin
+
len
(
answer
[
'text'
])
answer_end
=
answer_begin
+
len
(
answer
[
'text'
])
qp_pairs
.
append
({
'passage'
:
passage
,
qp_pairs
.
append
({
'passage'
:
passage
,
'question'
:
question
,
'question'
:
question
,
'id'
:
id
,
'id'
:
qa_
id
,
'answer_begin'
:
answer_begin
,
'answer_begin'
:
answer_begin
,
'answer_end'
:
answer_end
})
'answer_end'
:
answer_end
})
else
:
else
:
...
@@ -121,21 +125,21 @@ def collect_vocab(qp_pairs):
...
@@ -121,21 +125,21 @@ def collect_vocab(qp_pairs):
Build the vocab from corpus.
Build the vocab from corpus.
'''
'''
vocab
=
set
()
vocab
=
set
()
for
qp
in
qp_pairs
:
for
qp
_pair
in
qp_pairs
:
for
word
in
qp
[
'question_tokens'
]:
for
word
in
qp
_pair
[
'question_tokens'
]:
vocab
.
add
(
word
[
'word'
])
vocab
.
add
(
word
[
'word'
])
for
word
in
qp
[
'passage_tokens'
]:
for
word
in
qp
_pair
[
'passage_tokens'
]:
vocab
.
add
(
word
[
'word'
])
vocab
.
add
(
word
[
'word'
])
return
vocab
return
vocab
def
shuffle_step
(
l
,
step
):
def
shuffle_step
(
entries
,
step
):
'''
'''
Shuffle the step
Shuffle the step
'''
'''
answer
=
[]
answer
=
[]
for
i
in
range
(
0
,
len
(
l
),
step
):
for
i
in
range
(
0
,
len
(
entries
),
step
):
sub
=
l
[
i
:
i
+
step
]
sub
=
entries
[
i
:
i
+
step
]
shuffle
(
sub
)
shuffle
(
sub
)
answer
+=
sub
answer
+=
sub
return
answer
return
answer
...
@@ -163,13 +167,13 @@ def get_char_input(data, char_dict, max_char_length):
...
@@ -163,13 +167,13 @@ def get_char_input(data, char_dict, max_char_length):
char_id
=
np
.
zeros
((
max_char_length
,
sequence_length
,
char_id
=
np
.
zeros
((
max_char_length
,
sequence_length
,
batch_size
),
dtype
=
np
.
int32
)
batch_size
),
dtype
=
np
.
int32
)
char_lengths
=
np
.
zeros
((
sequence_length
,
batch_size
),
dtype
=
np
.
float32
)
char_lengths
=
np
.
zeros
((
sequence_length
,
batch_size
),
dtype
=
np
.
float32
)
for
b
in
range
(
0
,
min
(
len
(
data
),
batch_size
)):
for
b
atch_idx
in
range
(
0
,
min
(
len
(
data
),
batch_size
)):
d
=
data
[
b
]
batch_data
=
data
[
b
atch_idx
]
for
s
in
range
(
0
,
min
(
len
(
d
),
sequence_length
)):
for
s
ample_idx
in
range
(
0
,
min
(
len
(
batch_data
),
sequence_length
)):
word
=
d
[
s
][
'word'
]
word
=
batch_data
[
sample_idx
][
'word'
]
char_lengths
[
s
,
b
]
=
min
(
len
(
word
),
max_char_length
)
char_lengths
[
s
ample_idx
,
batch_idx
]
=
min
(
len
(
word
),
max_char_length
)
for
i
in
range
(
0
,
min
(
len
(
word
),
max_char_length
)):
for
i
in
range
(
0
,
min
(
len
(
word
),
max_char_length
)):
char_id
[
i
,
s
,
b
]
=
get_id
(
char_dict
,
word
[
i
])
char_id
[
i
,
s
ample_idx
,
batch_idx
]
=
get_id
(
char_dict
,
word
[
i
])
return
char_id
,
char_lengths
return
char_id
,
char_lengths
...
@@ -180,26 +184,26 @@ def get_word_input(data, word_dict, embed, embed_dim):
...
@@ -180,26 +184,26 @@ def get_word_input(data, word_dict, embed, embed_dim):
batch_size
=
len
(
data
)
batch_size
=
len
(
data
)
max_sequence_length
=
max
(
len
(
d
)
for
d
in
data
)
max_sequence_length
=
max
(
len
(
d
)
for
d
in
data
)
sequence_length
=
max_sequence_length
sequence_length
=
max_sequence_length
t
=
np
.
zeros
((
max_sequence_length
,
batch_size
,
word_inpu
t
=
np
.
zeros
((
max_sequence_length
,
batch_size
,
embed_dim
),
dtype
=
np
.
float32
)
embed_dim
),
dtype
=
np
.
float32
)
ids
=
np
.
zeros
((
sequence_length
,
batch_size
),
dtype
=
np
.
int32
)
ids
=
np
.
zeros
((
sequence_length
,
batch_size
),
dtype
=
np
.
int32
)
masks
=
np
.
zeros
((
sequence_length
,
batch_size
),
dtype
=
np
.
float32
)
masks
=
np
.
zeros
((
sequence_length
,
batch_size
),
dtype
=
np
.
float32
)
lengths
=
np
.
zeros
([
batch_size
],
dtype
=
np
.
int32
)
lengths
=
np
.
zeros
([
batch_size
],
dtype
=
np
.
int32
)
for
b
in
range
(
0
,
min
(
len
(
data
),
batch_size
)):
for
b
atch_idx
in
range
(
0
,
min
(
len
(
data
),
batch_size
)):
d
=
data
[
b
]
batch_data
=
data
[
b
atch_idx
]
lengths
[
b
]
=
len
(
d
)
lengths
[
b
atch_idx
]
=
len
(
batch_data
)
for
s
in
range
(
0
,
min
(
len
(
d
),
sequence_length
)):
for
s
ample_idx
in
range
(
0
,
min
(
len
(
batch_data
),
sequence_length
)):
word
=
d
[
s
][
'word'
].
lower
()
word
=
batch_data
[
sample_idx
][
'word'
].
lower
()
if
word
in
word_dict
.
keys
():
if
word
in
word_dict
.
keys
():
t
[
s
,
b
]
=
embed
[
word_dict
[
word
]]
word_input
[
sample_idx
,
batch_idx
]
=
embed
[
word_dict
[
word
]]
ids
[
s
,
b
]
=
word_dict
[
word
]
ids
[
s
ample_idx
,
batch_idx
]
=
word_dict
[
word
]
masks
[
s
,
b
]
=
1
masks
[
s
ample_idx
,
batch_idx
]
=
1
t
=
np
.
reshape
(
t
,
(
-
1
,
embed_dim
))
word_inpu
t
=
np
.
reshape
(
word_inpu
t
,
(
-
1
,
embed_dim
))
return
t
,
ids
,
masks
,
lengths
return
word_inpu
t
,
ids
,
masks
,
lengths
def
get_word_index
(
tokens
,
char_index
):
def
get_word_index
(
tokens
,
char_index
):
...
...
examples/trials/ga_squad/evaluate.py
View file @
241e3254
...
@@ -19,6 +19,10 @@
...
@@ -19,6 +19,10 @@
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
'''
Evaluation scripts for QA model.
'''
from
__future__
import
print_function
from
__future__
import
print_function
from
collections
import
Counter
from
collections
import
Counter
import
string
import
string
...
@@ -68,8 +72,8 @@ def f1_score(prediction, ground_truth):
...
@@ -68,8 +72,8 @@ def f1_score(prediction, ground_truth):
return
0
return
0
precision
=
1.0
*
num_same
/
len
(
prediction_tokens
)
precision
=
1.0
*
num_same
/
len
(
prediction_tokens
)
recall
=
1.0
*
num_same
/
len
(
ground_truth_tokens
)
recall
=
1.0
*
num_same
/
len
(
ground_truth_tokens
)
f1
=
(
2
*
precision
*
recall
)
/
(
precision
+
recall
)
f1
_result
=
(
2
*
precision
*
recall
)
/
(
precision
+
recall
)
return
f1
return
f1
_result
def
exact_match_score
(
prediction
,
ground_truth
):
def
exact_match_score
(
prediction
,
ground_truth
):
'''
'''
...
@@ -91,28 +95,25 @@ def _evaluate(dataset, predictions):
...
@@ -91,28 +95,25 @@ def _evaluate(dataset, predictions):
'''
'''
Evaluate function.
Evaluate function.
'''
'''
f1
=
exact_match
=
total
=
0
f1
_result
=
exact_match
=
total
=
0
count
=
0
count
=
0
for
article
in
dataset
:
for
article
in
dataset
:
for
paragraph
in
article
[
'paragraphs'
]:
for
paragraph
in
article
[
'paragraphs'
]:
for
qa
in
paragraph
[
'qas'
]:
for
qa
_pair
in
paragraph
[
'qas'
]:
total
+=
1
total
+=
1
if
qa
[
'id'
]
not
in
predictions
:
if
qa_pair
[
'id'
]
not
in
predictions
:
message
=
'Unanswered question '
+
qa
[
'id'
]
+
\
' will receive score 0.'
#print(message, file=sys.stderr)
count
+=
1
count
+=
1
continue
continue
ground_truths
=
list
(
map
(
lambda
x
:
x
[
'text'
],
qa
[
'answers'
]))
ground_truths
=
list
(
map
(
lambda
x
:
x
[
'text'
],
qa
_pair
[
'answers'
]))
prediction
=
predictions
[
qa
[
'id'
]]
prediction
=
predictions
[
qa
_pair
[
'id'
]]
exact_match
+=
metric_max_over_ground_truths
(
exact_match
+=
metric_max_over_ground_truths
(
exact_match_score
,
prediction
,
ground_truths
)
exact_match_score
,
prediction
,
ground_truths
)
f1
+=
metric_max_over_ground_truths
(
f1
_result
+=
metric_max_over_ground_truths
(
f1_score
,
prediction
,
ground_truths
)
f1_score
,
prediction
,
ground_truths
)
print
(
'total'
,
total
,
'exact_match'
,
exact_match
,
'unanswer_question '
,
count
)
print
(
'total'
,
total
,
'exact_match'
,
exact_match
,
'unanswer_question '
,
count
)
exact_match
=
100.0
*
exact_match
/
total
exact_match
=
100.0
*
exact_match
/
total
f1
=
100.0
*
f1
/
total
f1
_result
=
100.0
*
f1
_result
/
total
return
{
'exact_match'
:
exact_match
,
'f1'
:
f1
}
return
{
'exact_match'
:
exact_match
,
'f1'
:
f1
_result
}
def
evaluate
(
data_file
,
pred_file
):
def
evaluate
(
data_file
,
pred_file
):
'''
'''
...
...
examples/trials/ga_squad/graph.py
View file @
241e3254
...
@@ -43,8 +43,8 @@ class Layer(object):
...
@@ -43,8 +43,8 @@ class Layer(object):
'''
'''
Layer class, which contains the information of graph.
Layer class, which contains the information of graph.
'''
'''
def
__init__
(
self
,
graph_type
,
input
=
None
,
output
=
None
,
size
=
None
):
def
__init__
(
self
,
graph_type
,
input
s
=
None
,
output
=
None
,
size
=
None
):
self
.
input
=
input
if
input
is
not
None
else
[]
self
.
input
=
input
s
if
input
s
is
not
None
else
[]
self
.
output
=
output
if
output
is
not
None
else
[]
self
.
output
=
output
if
output
is
not
None
else
[]
self
.
graph_type
=
graph_type
self
.
graph_type
=
graph_type
self
.
is_delete
=
False
self
.
is_delete
=
False
...
@@ -117,11 +117,11 @@ class Graph(object):
...
@@ -117,11 +117,11 @@ class Graph(object):
'''
'''
Customed Graph class.
Customed Graph class.
'''
'''
def
__init__
(
self
,
max_layer_num
,
input
,
output
,
hide
):
def
__init__
(
self
,
max_layer_num
,
input
s
,
output
,
hide
):
self
.
layers
=
[]
self
.
layers
=
[]
self
.
max_layer_num
=
max_layer_num
self
.
max_layer_num
=
max_layer_num
for
layer
in
input
:
for
layer
in
input
s
:
self
.
layers
.
append
(
layer
)
self
.
layers
.
append
(
layer
)
for
layer
in
output
:
for
layer
in
output
:
self
.
layers
.
append
(
layer
)
self
.
layers
.
append
(
layer
)
...
@@ -240,7 +240,7 @@ class Graph(object):
...
@@ -240,7 +240,7 @@ class Graph(object):
if
graph_type
<=
1
:
if
graph_type
<=
1
:
new_id
=
len
(
layers
)
new_id
=
len
(
layers
)
out
=
random
.
choice
(
layers_out
)
out
=
random
.
choice
(
layers_out
)
input
=
[]
input
s
=
[]
output
=
[
out
]
output
=
[
out
]
pos
=
random
.
randint
(
0
,
len
(
layers
[
out
].
input
)
-
1
)
pos
=
random
.
randint
(
0
,
len
(
layers
[
out
].
input
)
-
1
)
last_in
=
layers
[
out
].
input
[
pos
]
last_in
=
layers
[
out
].
input
[
pos
]
...
@@ -250,13 +250,13 @@ class Graph(object):
...
@@ -250,13 +250,13 @@ class Graph(object):
if
graph_type
==
1
:
if
graph_type
==
1
:
layers
[
last_in
].
output
.
remove
(
out
)
layers
[
last_in
].
output
.
remove
(
out
)
layers
[
last_in
].
output
.
append
(
new_id
)
layers
[
last_in
].
output
.
append
(
new_id
)
input
=
[
last_in
]
input
s
=
[
last_in
]
lay
=
Layer
(
graph_type
=
layer_type
,
input
=
input
,
output
=
output
)
lay
=
Layer
(
graph_type
=
layer_type
,
input
s
=
input
s
,
output
=
output
)
while
len
(
input
)
<
lay
.
input_size
:
while
len
(
input
s
)
<
lay
.
input_size
:
layer1
=
random
.
choice
(
layers_in
)
layer1
=
random
.
choice
(
layers_in
)
input
.
append
(
layer1
)
input
s
.
append
(
layer1
)
layers
[
layer1
].
output
.
append
(
new_id
)
layers
[
layer1
].
output
.
append
(
new_id
)
lay
.
input
=
input
lay
.
input
=
input
s
layers
.
append
(
lay
)
layers
.
append
(
lay
)
else
:
else
:
layer1
=
random
.
choice
(
layers_del
)
layer1
=
random
.
choice
(
layers_del
)
...
...
examples/trials/ga_squad/train_model.py
View file @
241e3254
...
@@ -32,6 +32,7 @@ from graph_to_tf import graph_to_network
...
@@ -32,6 +32,7 @@ from graph_to_tf import graph_to_network
class
GAGConfig
:
class
GAGConfig
:
"""The class for model hyper-parameter configuration."""
def
__init__
(
self
):
def
__init__
(
self
):
self
.
batch_size
=
128
self
.
batch_size
=
128
...
@@ -56,6 +57,7 @@ class GAGConfig:
...
@@ -56,6 +57,7 @@ class GAGConfig:
class
GAG
:
class
GAG
:
"""The class for the computation graph based QA model."""
def
__init__
(
self
,
cfg
,
embed
,
graph
):
def
__init__
(
self
,
cfg
,
embed
,
graph
):
self
.
cfg
=
cfg
self
.
cfg
=
cfg
self
.
embed
=
embed
self
.
embed
=
embed
...
@@ -83,6 +85,7 @@ class GAG:
...
@@ -83,6 +85,7 @@ class GAG:
def
build_net
(
self
,
is_training
):
def
build_net
(
self
,
is_training
):
"""Build the whole neural network for the QA model."""
cfg
=
self
.
cfg
cfg
=
self
.
cfg
with
tf
.
device
(
'/cpu:0'
):
with
tf
.
device
(
'/cpu:0'
):
word_embed
=
tf
.
get_variable
(
word_embed
=
tf
.
get_variable
(
...
@@ -202,6 +205,7 @@ class GAG:
...
@@ -202,6 +205,7 @@ class GAG:
if
is_training
:
if
is_training
:
def
label_smoothing
(
inputs
,
masks
,
epsilon
=
0.1
):
def
label_smoothing
(
inputs
,
masks
,
epsilon
=
0.1
):
"""Modify target for label smoothing."""
epsilon
=
cfg
.
labelsmoothing
epsilon
=
cfg
.
labelsmoothing
num_of_channel
=
tf
.
shape
(
inputs
)[
-
1
]
# number of channels
num_of_channel
=
tf
.
shape
(
inputs
)[
-
1
]
# number of channels
inputs
=
tf
.
cast
(
inputs
,
tf
.
float32
)
inputs
=
tf
.
cast
(
inputs
,
tf
.
float32
)
...
@@ -229,6 +233,7 @@ class GAG:
...
@@ -229,6 +233,7 @@ class GAG:
return
tf
.
stack
([
self
.
begin_prob
,
self
.
end_prob
])
return
tf
.
stack
([
self
.
begin_prob
,
self
.
end_prob
])
def
build_char_states
(
self
,
char_embed
,
is_training
,
reuse
,
char_ids
,
char_lengths
):
def
build_char_states
(
self
,
char_embed
,
is_training
,
reuse
,
char_ids
,
char_lengths
):
"""Build char embedding network for the QA model."""
max_char_length
=
self
.
cfg
.
max_char_length
max_char_length
=
self
.
cfg
.
max_char_length
inputs
=
dropout
(
tf
.
nn
.
embedding_lookup
(
char_embed
,
char_ids
),
inputs
=
dropout
(
tf
.
nn
.
embedding_lookup
(
char_embed
,
char_ids
),
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment