Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
073219b4
"...composable_kernel.git" did not exist on "2ee1c0a70a305b00770c51a6ad585895992b327c"
Commit
073219b4
authored
Jan 21, 2020
by
Lysandre
Browse files
Manage impossible examples SQuAD v2
parent
983c484f
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
8 additions
and
2 deletions
+8
-2
src/transformers/data/processors/squad.py
src/transformers/data/processors/squad.py
+8
-2
No files found.
src/transformers/data/processors/squad.py
View file @
073219b4
...
@@ -242,6 +242,7 @@ def squad_convert_example_to_features(example, max_seq_length, doc_stride, max_q
...
@@ -242,6 +242,7 @@ def squad_convert_example_to_features(example, max_seq_length, doc_stride, max_q
token_to_orig_map
=
span
[
"token_to_orig_map"
],
token_to_orig_map
=
span
[
"token_to_orig_map"
],
start_position
=
start_position
,
start_position
=
start_position
,
end_position
=
end_position
,
end_position
=
end_position
,
is_impossible
=
span_is_impossible
)
)
)
)
return
features
return
features
...
@@ -332,6 +333,7 @@ def squad_convert_examples_to_features(
...
@@ -332,6 +333,7 @@ def squad_convert_examples_to_features(
all_token_type_ids
=
torch
.
tensor
([
f
.
token_type_ids
for
f
in
features
],
dtype
=
torch
.
long
)
all_token_type_ids
=
torch
.
tensor
([
f
.
token_type_ids
for
f
in
features
],
dtype
=
torch
.
long
)
all_cls_index
=
torch
.
tensor
([
f
.
cls_index
for
f
in
features
],
dtype
=
torch
.
long
)
all_cls_index
=
torch
.
tensor
([
f
.
cls_index
for
f
in
features
],
dtype
=
torch
.
long
)
all_p_mask
=
torch
.
tensor
([
f
.
p_mask
for
f
in
features
],
dtype
=
torch
.
float
)
all_p_mask
=
torch
.
tensor
([
f
.
p_mask
for
f
in
features
],
dtype
=
torch
.
float
)
all_is_impossible
=
torch
.
tensor
([
f
.
is_impossible
for
f
in
features
],
dtype
=
torch
.
float
)
if
not
is_training
:
if
not
is_training
:
all_example_index
=
torch
.
arange
(
all_input_ids
.
size
(
0
),
dtype
=
torch
.
long
)
all_example_index
=
torch
.
arange
(
all_input_ids
.
size
(
0
),
dtype
=
torch
.
long
)
...
@@ -349,6 +351,7 @@ def squad_convert_examples_to_features(
...
@@ -349,6 +351,7 @@ def squad_convert_examples_to_features(
all_end_positions
,
all_end_positions
,
all_cls_index
,
all_cls_index
,
all_p_mask
,
all_p_mask
,
all_is_impossible
)
)
return
features
,
dataset
return
features
,
dataset
...
@@ -369,6 +372,7 @@ def squad_convert_examples_to_features(
...
@@ -369,6 +372,7 @@ def squad_convert_examples_to_features(
"end_position"
:
ex
.
end_position
,
"end_position"
:
ex
.
end_position
,
"cls_index"
:
ex
.
cls_index
,
"cls_index"
:
ex
.
cls_index
,
"p_mask"
:
ex
.
p_mask
,
"p_mask"
:
ex
.
p_mask
,
"is_impossible"
:
ex
.
is_impossible
},
},
)
)
...
@@ -376,7 +380,7 @@ def squad_convert_examples_to_features(
...
@@ -376,7 +380,7 @@ def squad_convert_examples_to_features(
gen
,
gen
,
(
(
{
"input_ids"
:
tf
.
int32
,
"attention_mask"
:
tf
.
int32
,
"token_type_ids"
:
tf
.
int32
},
{
"input_ids"
:
tf
.
int32
,
"attention_mask"
:
tf
.
int32
,
"token_type_ids"
:
tf
.
int32
},
{
"start_position"
:
tf
.
int64
,
"end_position"
:
tf
.
int64
,
"cls_index"
:
tf
.
int64
,
"p_mask"
:
tf
.
int32
},
{
"start_position"
:
tf
.
int64
,
"end_position"
:
tf
.
int64
,
"cls_index"
:
tf
.
int64
,
"p_mask"
:
tf
.
int32
,
"is_impossible"
:
tf
.
int32
},
),
),
(
(
{
{
...
@@ -389,6 +393,7 @@ def squad_convert_examples_to_features(
...
@@ -389,6 +393,7 @@ def squad_convert_examples_to_features(
"end_position"
:
tf
.
TensorShape
([]),
"end_position"
:
tf
.
TensorShape
([]),
"cls_index"
:
tf
.
TensorShape
([]),
"cls_index"
:
tf
.
TensorShape
([]),
"p_mask"
:
tf
.
TensorShape
([
None
]),
"p_mask"
:
tf
.
TensorShape
([
None
]),
"is_impossible"
:
tf
.
TensorShape
([])
},
},
),
),
)
)
...
@@ -658,6 +663,7 @@ class SquadFeatures(object):
...
@@ -658,6 +663,7 @@ class SquadFeatures(object):
token_to_orig_map
,
token_to_orig_map
,
start_position
,
start_position
,
end_position
,
end_position
,
is_impossible
):
):
self
.
input_ids
=
input_ids
self
.
input_ids
=
input_ids
self
.
attention_mask
=
attention_mask
self
.
attention_mask
=
attention_mask
...
@@ -674,7 +680,7 @@ class SquadFeatures(object):
...
@@ -674,7 +680,7 @@ class SquadFeatures(object):
self
.
start_position
=
start_position
self
.
start_position
=
start_position
self
.
end_position
=
end_position
self
.
end_position
=
end_position
self
.
is_impossible
=
is_impossible
class
SquadResult
(
object
):
class
SquadResult
(
object
):
"""
"""
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment