chenpangpang / transformers / Commits / d4614729

Commit d4614729, authored Dec 13, 2019 by Lysandre

    return for SQuAD [BLACKED]

Parent: f24a228a
Showing 2 changed files with 172 additions and 110 deletions (+172 -110):

    transformers/data/processors/glue.py     +1    -1
    transformers/data/processors/squad.py    +171  -109
transformers/data/processors/glue.py  (view file @ d4614729)
@@ -133,7 +133,7 @@ def glue_convert_examples_to_features(examples, tokenizer,
     if is_tf_available() and is_tf_dataset:
         def gen():
             for ex in features:
-                yield ({'input_ids': ex.input_ids,
+                yield ({'input_ids': ex.input_ids,
                         'attention_mask': ex.attention_mask,
                         'token_type_ids': ex.token_type_ids},
                         ex.label)
transformers/data/processors/squad.py  (view file @ d4614729)
@@ -18,19 +18,20 @@ if is_tf_available():
 logger = logging.getLogger(__name__)


-def _improve_answer_span(doc_tokens, input_start, input_end, tokenizer, orig_answer_text):
+def _improve_answer_span(doc_tokens, input_start, input_end, tokenizer, orig_answer_text):
     """Returns tokenized answer spans that better match the annotated answer."""
     tok_answer_text = " ".join(tokenizer.tokenize(orig_answer_text))

     for new_start in range(input_start, input_end + 1):
         for new_end in range(input_end, new_start - 1, -1):
-            text_span = " ".join(doc_tokens[new_start:(new_end + 1)])
+            text_span = " ".join(doc_tokens[new_start : (new_end + 1)])
             if text_span == tok_answer_text:
                 return (new_start, new_end)

     return (input_start, input_end)


 def _check_is_max_context(doc_spans, cur_span_index, position):
     """Check if this is the 'max context' doc span for the token."""
     best_score = None
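The span-improvement helper above is easiest to see on a concrete case. Below is a minimal, self-contained sketch of the same idea; the stand-in tokenizer and the sample sentence are invented for illustration and are not part of this file.

class ToyTokenizer:
    def tokenize(self, text):
        # Stand-in for a subword tokenizer: lowercases and splits on whitespace.
        return text.lower().split()


def improve_answer_span(doc_tokens, input_start, input_end, tokenizer, orig_answer_text):
    tok_answer_text = " ".join(tokenizer.tokenize(orig_answer_text))
    for new_start in range(input_start, input_end + 1):
        for new_end in range(input_end, new_start - 1, -1):
            text_span = " ".join(doc_tokens[new_start : new_end + 1])
            if text_span == tok_answer_text:
                return new_start, new_end
    return input_start, input_end


doc = ["the", "japanese", "empire", "was", "founded", "in", "1895", "."]
# The annotated span covers tokens 5..6 ("in 1895") but the gold answer is just "1895".
print(improve_answer_span(doc, 5, 6, ToyTokenizer(), "1895"))  # -> (6, 6)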
@@ -50,10 +51,11 @@ def _check_is_max_context(doc_spans, cur_span_index, position):
     return cur_span_index == best_span_index


 def _new_check_is_max_context(doc_spans, cur_span_index, position):
     """Check if this is the 'max context' doc span for the token."""
     # if len(doc_spans) == 1:
-    # return True
+    # return True
     best_score = None
     best_span_index = None
     for (span_index, doc_span) in enumerate(doc_spans):
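Both _check_is_max_context variants implement the "max context" rule inherited from the original BERT SQuAD code: when a token appears in several overlapping windows, it is attributed to the window where it has the most surrounding context, with a small bonus for longer windows. A rough sketch of that scoring, with invented spans:

def max_context_span(doc_spans, position):
    # doc_spans: list of (start, length) pairs describing overlapping windows.
    best_score, best_index = None, None
    for index, (start, length) in enumerate(doc_spans):
        end = start + length - 1
        if not (start <= position <= end):
            continue
        num_left_context = position - start
        num_right_context = end - position
        score = min(num_left_context, num_right_context) + 0.01 * length
        if best_score is None or score > best_score:
            best_score, best_index = score, index
    return best_index


# Token 10 sits in both windows, but has more surrounding context in the second one.
print(max_context_span([(0, 12), (6, 12)], 10))  # -> 1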
@@ -71,14 +73,16 @@ def _new_check_is_max_context(doc_spans, cur_span_index, position):
     return cur_span_index == best_span_index


 def _is_whitespace(c):
     if c == " " or c == "\t" or c == "\r" or c == "\n" or ord(c) == 0x202F:
         return True
     return False


-def squad_convert_examples_to_features(examples, tokenizer, max_seq_length, doc_stride, max_query_length, is_training, return_dataset=False):
+def squad_convert_examples_to_features(
+    examples, tokenizer, max_seq_length, doc_stride, max_query_length, is_training, return_dataset=False
+):
     """
     Converts a list of examples into a list of features that can be directly given as input to a model.
     It is model-dependant and takes advantage of many of the tokenizer's features to create the model's inputs.
@@ -112,7 +116,7 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
         )
     """

-    # Defining helper methods
+    # Defining helper methods
     unique_id = 1000000000
     features = []
@@ -123,13 +127,12 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
             end_position = example.end_position

             # If the answer cannot be found in the text, then skip this example.
-            actual_text = " ".join(example.doc_tokens[start_position:(end_position + 1)])
+            actual_text = " ".join(example.doc_tokens[start_position : (end_position + 1)])
             cleaned_answer_text = " ".join(whitespace_tokenize(example.answer_text))
             if actual_text.find(cleaned_answer_text) == -1:
                 logger.warning("Could not find answer: '%s' vs. '%s'", actual_text, cleaned_answer_text)
                 continue

         tok_to_orig_index = []
         orig_to_tok_index = []
         all_doc_tokens = []
@@ -140,7 +143,6 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
                 tok_to_orig_index.append(i)
                 all_doc_tokens.append(sub_token)

         if is_training and not example.is_impossible:
             tok_start_position = orig_to_tok_index[example.start_position]
             if example.end_position < len(example.doc_tokens) - 1:
@@ -153,36 +155,41 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
             )

         spans = []

-        truncated_query = tokenizer.encode(example.question_text, add_special_tokens=False, max_length=max_query_length)
-        sequence_added_tokens = tokenizer.max_len - tokenizer.max_len_single_sentence
-        sequence_pair_added_tokens = tokenizer.max_len - tokenizer.max_len_sentences_pair
+        truncated_query = tokenizer.encode(example.question_text, add_special_tokens=False, max_length=max_query_length)
+        sequence_added_tokens = tokenizer.max_len - tokenizer.max_len_single_sentence
+        sequence_pair_added_tokens = tokenizer.max_len - tokenizer.max_len_sentences_pair

         span_doc_tokens = all_doc_tokens
         while len(spans) * doc_stride < len(all_doc_tokens):
             encoded_dict = tokenizer.encode_plus(
-                truncated_query if tokenizer.padding_side == "right" else span_doc_tokens,
-                span_doc_tokens if tokenizer.padding_side == "right" else truncated_query,
-                max_length=max_seq_length,
-                return_overflowing_tokens=True,
+                truncated_query if tokenizer.padding_side == "right" else span_doc_tokens,
+                span_doc_tokens if tokenizer.padding_side == "right" else truncated_query,
+                max_length=max_seq_length,
+                return_overflowing_tokens=True,
                 pad_to_max_length=True,
                 stride=max_seq_length - doc_stride - len(truncated_query) - sequence_pair_added_tokens,
-                truncation_strategy='only_second' if tokenizer.padding_side == "right" else 'only_first'
+                truncation_strategy="only_second" if tokenizer.padding_side == "right" else "only_first",
             )

-            paragraph_len = min(len(all_doc_tokens) - len(spans) * doc_stride, max_seq_length - len(truncated_query) - sequence_pair_added_tokens)
+            paragraph_len = min(
+                len(all_doc_tokens) - len(spans) * doc_stride,
+                max_seq_length - len(truncated_query) - sequence_pair_added_tokens,
+            )

-            if tokenizer.pad_token_id in encoded_dict['input_ids']:
-                non_padded_ids = encoded_dict['input_ids'][:encoded_dict['input_ids'].index(tokenizer.pad_token_id)]
+            if tokenizer.pad_token_id in encoded_dict["input_ids"]:
+                non_padded_ids = encoded_dict["input_ids"][: encoded_dict["input_ids"].index(tokenizer.pad_token_id)]
             else:
-                non_padded_ids = encoded_dict['input_ids']
+                non_padded_ids = encoded_dict["input_ids"]

             tokens = tokenizer.convert_ids_to_tokens(non_padded_ids)

             token_to_orig_map = {}
             for i in range(paragraph_len):
-                index = len(truncated_query) + sequence_added_tokens + i if tokenizer.padding_side == "right" else i
+                index = len(truncated_query) + sequence_added_tokens + i if tokenizer.padding_side == "right" else i
                 token_to_orig_map[index] = tok_to_orig_index[len(spans) * doc_stride + i]

             encoded_dict["paragraph_len"] = paragraph_len
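The while loop above slides a window over the document tokens: each span holds as many context tokens as fit next to the truncated query, and consecutive spans start doc_stride tokens apart. A small sketch with made-up sizes (none of these numbers come from the commit) shows how a long passage gets chunked:

max_seq_length = 16   # budget for query + context + special tokens
doc_stride = 4        # consecutive spans start this many tokens apart
query_len = 5         # stand-in for len(truncated_query)
added_tokens = 3      # stand-in for sequence_pair_added_tokens
doc_len = 20          # stand-in for len(all_doc_tokens)

window = max_seq_length - query_len - added_tokens  # context tokens per span
spans = []
while len(spans) * doc_stride < doc_len:
    start = len(spans) * doc_stride
    spans.append((start, min(start + window, doc_len)))
print(spans)  # [(0, 8), (4, 12), (8, 16), (12, 20), (16, 20)]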
@@ -202,16 +209,20 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
         for doc_span_index in range(len(spans)):
             for j in range(spans[doc_span_index]["paragraph_len"]):
                 is_max_context = _new_check_is_max_context(spans, doc_span_index, doc_span_index * doc_stride + j)
-                index = j if tokenizer.padding_side == "left" else spans[doc_span_index]["truncated_query_with_special_tokens_length"] + j
+                index = (
+                    j
+                    if tokenizer.padding_side == "left"
+                    else spans[doc_span_index]["truncated_query_with_special_tokens_length"] + j
+                )
                 spans[doc_span_index]["token_is_max_context"][index] = is_max_context

         for span in spans:
             # Identify the position of the CLS token
-            cls_index = span['input_ids'].index(tokenizer.cls_token_id)
+            cls_index = span["input_ids"].index(tokenizer.cls_token_id)

             # p_mask: mask with 1 for token than cannot be in the answer (0 for token which can be in an answer)
             # Original TF implem also keep the classification token (set to 0) (not sure why...)
-            p_mask = np.array(span['token_type_ids'])
+            p_mask = np.array(span["token_type_ids"])

             p_mask = np.minimum(p_mask, 1)
@@ -224,7 +235,6 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
             # Set the CLS index to '0'
             p_mask[cls_index] = 0

             span_is_impossible = example.is_impossible
             start_position = 0
             end_position = 0
@@ -247,55 +257,99 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
                     doc_offset = 0
                 else:
                     doc_offset = len(truncated_query) + sequence_added_tokens

                 start_position = tok_start_position - doc_start + doc_offset
                 end_position = tok_end_position - doc_start + doc_offset

-            features.append(SquadFeatures(
-                span['input_ids'],
-                span['attention_mask'],
-                span['token_type_ids'],
-                cls_index,
-                p_mask.tolist(),
-                example_index=example_index,
-                unique_id=unique_id,
-                paragraph_len=span['paragraph_len'],
-                token_is_max_context=span["token_is_max_context"],
-                tokens=span["tokens"],
-                token_to_orig_map=span["token_to_orig_map"],
-                start_position=start_position,
-                end_position=end_position))
+            features.append(
+                SquadFeatures(
+                    span["input_ids"],
+                    span["attention_mask"],
+                    span["token_type_ids"],
+                    cls_index,
+                    p_mask.tolist(),
+                    example_index=example_index,
+                    unique_id=unique_id,
+                    paragraph_len=span["paragraph_len"],
+                    token_is_max_context=span["token_is_max_context"],
+                    tokens=span["tokens"],
+                    token_to_orig_map=span["token_to_orig_map"],
+                    start_position=start_position,
+                    end_position=end_position,
+                )
+            )
             unique_id += 1

-    if return_dataset == 'pt':
+    if return_dataset == "pt":
         if not is_torch_available():
             raise ImportError("Pytorch must be installed to return a pytorch dataset.")

         # Convert to Tensors and build dataset
         all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
-        all_input_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.long)
-        all_segment_ids = torch.tensor([f.token_type_ids for f in features], dtype=torch.long)
+        all_attention_masks = torch.tensor([f.attention_mask for f in features], dtype=torch.long)
+        all_token_type_ids = torch.tensor([f.token_type_ids for f in features], dtype=torch.long)
         all_cls_index = torch.tensor([f.cls_index for f in features], dtype=torch.long)
         all_p_mask = torch.tensor([f.p_mask for f in features], dtype=torch.float)

         if not is_training:
             all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long)
-            dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids,
-                                    all_example_index, all_cls_index, all_p_mask)
+            dataset = TensorDataset(
+                all_input_ids, all_attention_masks, all_token_type_ids, all_example_index, all_cls_index, all_p_mask
+            )
         else:
             all_start_positions = torch.tensor([f.start_position for f in features], dtype=torch.long)
             all_end_positions = torch.tensor([f.end_position for f in features], dtype=torch.long)
-            dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids,
-                                    all_start_positions, all_end_positions,
-                                    all_cls_index, all_p_mask)
+            dataset = TensorDataset(
+                all_input_ids,
+                all_attention_masks,
+                all_token_type_ids,
+                all_start_positions,
+                all_end_positions,
+                all_cls_index,
+                all_p_mask,
+            )

         return features, dataset
     elif return_dataset == "tf":
         if not is_tf_available():
             raise ImportError("TensorFlow must be installed to return a TensorFlow dataset.")

         def gen():
             for ex in features:
                 yield (
                     {
                         "input_ids": ex.input_ids,
                         "attention_mask": ex.attention_mask,
                         "token_type_ids": ex.token_type_ids,
                     },
                     {
                         "start_position": ex.start_position,
                         "end_position": ex.end_position,
                         "cls_index": ex.cls_index,
                         "p_mask": ex.p_mask,
                     },
                 )

         return tf.data.Dataset.from_generator(
             gen,
             (
                 {"input_ids": tf.int32, "attention_mask": tf.int32, "token_type_ids": tf.int32},
                 {"start_position": tf.int64, "end_position": tf.int64, "cls_index": tf.int64, "p_mask": tf.int32},
             ),
             (
                 {
                     "input_ids": tf.TensorShape([None]),
                     "attention_mask": tf.TensorShape([None]),
                     "token_type_ids": tf.TensorShape([None]),
                 },
                 {
                     "start_position": tf.TensorShape([]),
                     "end_position": tf.TensorShape([]),
                     "cls_index": tf.TensorShape([]),
                     "p_mask": tf.TensorShape([None]),
                 },
             ),
         )

     return features
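Taken together, the function now returns both the list of SquadFeatures and, on request, a framework-specific dataset. A hedged usage sketch (paths, model name, and hyper-parameters are placeholders, not taken from this commit):

from transformers import BertTokenizer
from transformers.data.processors.squad import SquadV2Processor, squad_convert_examples_to_features

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
processor = SquadV2Processor()
examples = processor.get_dev_examples("path/to/squad")  # placeholder directory containing dev-v2.0.json

features, dataset = squad_convert_examples_to_features(
    examples,
    tokenizer,
    max_seq_length=384,
    doc_stride=128,
    max_query_length=64,
    is_training=False,
    return_dataset="pt",  # builds the TensorDataset shown in the branch above
)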
@@ -305,31 +359,32 @@ class SquadProcessor(DataProcessor):
     Processor for the SQuAD data set.
     Overriden by SquadV1Processor and SquadV2Processor, used by the version 1.1 and version 2.0 of SQuAD, respectively.
     """

     train_file = None
     dev_file = None

     def _get_example_from_tensor_dict(self, tensor_dict, evaluate=False):
         if not evaluate:
-            answer = tensor_dict['answers']['text'][0].numpy().decode('utf-8')
-            answer_start = tensor_dict['answers']['answer_start'][0].numpy()
+            answer = tensor_dict["answers"]["text"][0].numpy().decode("utf-8")
+            answer_start = tensor_dict["answers"]["answer_start"][0].numpy()
             answers = []
         else:
-            answers = [{"answer_start": start.numpy(), "text": text.numpy().decode('utf-8')}
-                       for start, text in zip(tensor_dict['answers']["answer_start"],
-                                              tensor_dict['answers']["text"])]
+            answers = [
+                {"answer_start": start.numpy(), "text": text.numpy().decode("utf-8")}
+                for start, text in zip(tensor_dict["answers"]["answer_start"], tensor_dict["answers"]["text"])
+            ]

             answer = None
             answer_start = None

         return SquadExample(
-            qas_id=tensor_dict['id'].numpy().decode("utf-8"),
-            question_text=tensor_dict['question'].numpy().decode('utf-8'),
-            context_text=tensor_dict['context'].numpy().decode('utf-8'),
+            qas_id=tensor_dict["id"].numpy().decode("utf-8"),
+            question_text=tensor_dict["question"].numpy().decode("utf-8"),
+            context_text=tensor_dict["context"].numpy().decode("utf-8"),
             answer_text=answer,
             start_position_character=answer_start,
-            title=tensor_dict['title'].numpy().decode('utf-8'),
-            answers=answers
+            title=tensor_dict["title"].numpy().decode("utf-8"),
+            answers=answers,
         )

     def get_examples_from_dataset(self, dataset, evaluate=False):
@@ -359,7 +414,7 @@ class SquadProcessor(DataProcessor):
         examples = []
         for tensor_dict in tqdm(dataset):
-            examples.append(self._get_example_from_tensor_dict(tensor_dict, evaluate=evaluate))
+            examples.append(self._get_example_from_tensor_dict(tensor_dict, evaluate=evaluate))

         return examples
@@ -379,7 +434,9 @@ class SquadProcessor(DataProcessor):
         if self.train_file is None:
             raise ValueError("SquadProcessor should be instantiated via SquadV1Processor or SquadV2Processor")

-        with open(os.path.join(data_dir, self.train_file if filename is None else filename), "r", encoding='utf-8') as reader:
+        with open(
+            os.path.join(data_dir, self.train_file if filename is None else filename), "r", encoding="utf-8"
+        ) as reader:
             input_data = json.load(reader)["data"]
         return self._create_examples(input_data, "train")
@@ -397,8 +454,10 @@ class SquadProcessor(DataProcessor):
         if self.dev_file is None:
             raise ValueError("SquadProcessor should be instantiated via SquadV1Processor or SquadV2Processor")

-        with open(os.path.join(data_dir, self.dev_file if filename is None else filename), "r", encoding='utf-8') as reader:
+        with open(
+            os.path.join(data_dir, self.dev_file if filename is None else filename), "r", encoding="utf-8"
+        ) as reader:
             input_data = json.load(reader)["data"]
         return self._create_examples(input_data, "dev")
@@ -406,7 +465,7 @@ class SquadProcessor(DataProcessor):
         is_training = set_type == "train"
         examples = []
         for entry in tqdm(input_data):
-            title = entry['title']
+            title = entry["title"]
             for paragraph in entry["paragraphs"]:
                 context_text = paragraph["context"]
                 for qa in paragraph["qas"]:
@@ -415,7 +474,7 @@ class SquadProcessor(DataProcessor):
                     start_position_character = None
                     answer_text = None
                     answers = []
                     if "is_impossible" in qa:
                         is_impossible = qa["is_impossible"]
                     else:
@@ -424,8 +483,8 @@ class SquadProcessor(DataProcessor):
                     if not is_impossible:
                         if is_training:
                             answer = qa["answers"][0]
-                            answer_text = answer['text']
-                            start_position_character = answer['answer_start']
+                            answer_text = answer["text"]
+                            start_position_character = answer["answer_start"]
                         else:
                             answers = qa["answers"]
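The nested loops above walk the standard SQuAD JSON layout (articles containing paragraphs, each containing question-answer pairs). A hedged, invented miniature of that structure, matching the keys the parser reads:

# Minimal, made-up stand-in for json.load(reader)["data"]; the texts are illustrative only.
input_data = [
    {
        "title": "Sample article",
        "paragraphs": [
            {
                "context": "The tower was built in 1889.",
                "qas": [
                    {
                        "id": "q1",
                        "question": "When was the tower built?",
                        "is_impossible": False,
                        "answers": [{"text": "1889", "answer_start": 23}],
                    }
                ],
            }
        ],
    }
]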
@@ -437,12 +496,13 @@ class SquadProcessor(DataProcessor):
                         start_position_character=start_position_character,
                         title=title,
                         is_impossible=is_impossible,
-                        answers=answers
+                        answers=answers,
                     )

                     examples.append(example)
         return examples


 class SquadV1Processor(SquadProcessor):
     train_file = "train-v1.1.json"
     dev_file = "dev-v1.1.json"
@@ -451,7 +511,7 @@ class SquadV1Processor(SquadProcessor):
 class SquadV2Processor(SquadProcessor):
     train_file = "train-v2.0.json"
     dev_file = "dev-v2.0.json"


 class SquadExample(object):
     """
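Since SquadV1Processor and SquadV2Processor differ only in their train_file and dev_file class attributes, the same mechanism can be reused for any dump that follows the SQuAD layout. A hedged sketch with invented file names and paths:

from transformers.data.processors.squad import SquadProcessor


class MyCorpusProcessor(SquadProcessor):
    train_file = "my-train.json"  # invented file names, placed inside the data directory
    dev_file = "my-dev.json"


processor = MyCorpusProcessor()
train_examples = processor.get_train_examples("path/to/data")                     # reads my-train.json
dev_examples = processor.get_dev_examples("path/to/data", filename="other.json")  # explicit override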
@@ -468,21 +528,23 @@ class SquadExample(object):
         is_impossible: False by default, set to True if the example has no possible answer.
     """

-    def __init__(self,
-                 qas_id,
-                 question_text,
-                 context_text,
-                 answer_text,
-                 start_position_character,
-                 title,
-                 answers=[],
-                 is_impossible=False):
+    def __init__(
+        self,
+        qas_id,
+        question_text,
+        context_text,
+        answer_text,
+        start_position_character,
+        title,
+        answers=[],
+        is_impossible=False,
+    ):
         self.qas_id = qas_id
         self.question_text = question_text
         self.context_text = context_text
         self.answer_text = answer_text
         self.title = title
-        self.is_impossible = is_impossible
+        self.is_impossible = is_impossible
         self.answers = answers

         self.start_position, self.end_position = 0, 0
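A SquadExample can also be built directly, outside of the JSON parsing path, which is occasionally convenient for quick tests. A hedged sketch with invented sample data:

from transformers.data.processors.squad import SquadExample

example = SquadExample(
    qas_id="sample-0001",
    question_text="When was the tower built?",
    context_text="The tower was built in 1889 for the World's Fair.",
    answer_text="1889",
    start_position_character=23,  # character offset of "1889" in the context above
    title="Sample article",
)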
@@ -537,24 +599,23 @@ class SquadFeatures(object):
         end_position: end of the answer token index
     """

-    def __init__(self,
-                 input_ids,
-                 attention_mask,
-                 token_type_ids,
-                 cls_index,
-                 p_mask,
-                 example_index,
-                 unique_id,
-                 paragraph_len,
-                 token_is_max_context,
-                 tokens,
-                 token_to_orig_map,
-                 start_position,
-                 end_position):
-        self.input_ids = input_ids
+    def __init__(
+        self,
+        input_ids,
+        attention_mask,
+        token_type_ids,
+        cls_index,
+        p_mask,
+        example_index,
+        unique_id,
+        paragraph_len,
+        token_is_max_context,
+        tokens,
+        token_to_orig_map,
+        start_position,
+        end_position,
+    ):
+        self.input_ids = input_ids
         self.attention_mask = attention_mask
         self.token_type_ids = token_type_ids
         self.cls_index = cls_index
@@ -580,12 +641,13 @@ class SquadResult(object):
         start_logits: The logits corresponding to the start of the answer
         end_logits: The logits corresponding to the end of the answer
     """

     def __init__(self, unique_id, start_logits, end_logits, start_top_index=None, end_top_index=None, cls_logits=None):
         self.start_logits = start_logits
         self.end_logits = end_logits
         self.unique_id = unique_id

         if start_top_index:
             self.start_top_index = start_top_index
             self.end_top_index = end_top_index
-            self.cls_logits = cls_logits
\ No newline at end of file
+            self.cls_logits = cls_logits
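SquadResult simply pairs a feature's unique_id with the model's start/end logits (plus optional top-index and CLS logits for models that produce them). A hedged sketch with dummy values:

from transformers.data.processors.squad import SquadResult

# Dummy logits standing in for a model's start/end scores over a five-token input.
start_logits = [0.1, 2.3, 0.2, -1.0, 0.0]
end_logits = [0.0, 0.4, 3.1, -0.5, 0.2]
result = SquadResult(unique_id=1000000000, start_logits=start_logits, end_logits=end_logits)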