Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
2a2832ce
Unverified
Commit
2a2832ce
authored
Aug 29, 2019
by
erenup
Committed by
GitHub
Aug 29, 2019
Browse files
Merge pull request #1 from erenup/run_multiple_choice
roberta, xlnet for multiple choice
parents
e384ae2b
942d3f4b
Changes
5
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
1080 additions
and
1 deletion
+1080
-1
examples/single_model_scripts/run_multiple_choice.py
examples/single_model_scripts/run_multiple_choice.py
+542
-0
examples/single_model_scripts/utils_multiple_choice.py
examples/single_model_scripts/utils_multiple_choice.py
+452
-0
pytorch_transformers/__init__.py
pytorch_transformers/__init__.py
+2
-1
pytorch_transformers/modeling_roberta.py
pytorch_transformers/modeling_roberta.py
+40
-0
pytorch_transformers/modeling_xlnet.py
pytorch_transformers/modeling_xlnet.py
+44
-0
No files found.
examples/single_model_scripts/run_multiple_choice.py
0 → 100644
View file @
2a2832ce
This diff is collapsed.
Click to expand it.
examples/single_model_scripts/utils_multiple_choice.py
0 → 100644
View file @
2a2832ce
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" BERT classification fine-tuning: utilities to work with GLUE tasks """
from
__future__
import
absolute_import
,
division
,
print_function
import
logging
import
os
import
sys
from
io
import
open
import
json
import
csv
import
glob
import
tqdm
logger
=
logging
.
getLogger
(
__name__
)
class
InputExample
(
object
):
"""A single training/test example for multiple choice"""
def
__init__
(
self
,
example_id
,
question
,
contexts
,
endings
,
label
=
None
):
"""Constructs a InputExample.
Args:
guid: Unique id for the example.
text_a: string. The untokenized text of the first sequence. For single
sequence tasks, only this sequence must be specified.
text_b: (Optional) string. The untokenized text of the second sequence.
Only must be specified for sequence pair tasks.
label: (Optional) string. The label of the example. This should be
specified for train and dev examples, but not for test examples.
"""
self
.
example_id
=
example_id
self
.
question
=
question
self
.
contexts
=
contexts
self
.
endings
=
endings
self
.
label
=
label
class
InputFeatures
(
object
):
def
__init__
(
self
,
example_id
,
choices_features
,
label
):
self
.
example_id
=
example_id
self
.
choices_features
=
[
{
'input_ids'
:
input_ids
,
'input_mask'
:
input_mask
,
'segment_ids'
:
segment_ids
}
for
_
,
input_ids
,
input_mask
,
segment_ids
in
choices_features
]
self
.
label
=
label
class
DataProcessor
(
object
):
"""Base class for data converters for sequence classification data sets."""
def
get_train_examples
(
self
,
data_dir
):
"""Gets a collection of `InputExample`s for the train set."""
raise
NotImplementedError
()
def
get_dev_examples
(
self
,
data_dir
):
"""Gets a collection of `InputExample`s for the dev set."""
raise
NotImplementedError
()
def
get_test_examples
(
self
,
data_dir
):
"""Gets a collection of `InputExample`s for the dev set."""
raise
NotImplementedError
()
def
get_labels
(
self
):
"""Gets the list of labels for this data set."""
raise
NotImplementedError
()
class
RaceProcessor
(
DataProcessor
):
"""Processor for the MRPC data set (GLUE version)."""
def
get_train_examples
(
self
,
data_dir
):
"""See base class."""
logger
.
info
(
"LOOKING AT {} train"
.
format
(
data_dir
))
high
=
os
.
path
.
join
(
data_dir
,
'train/high'
)
middle
=
os
.
path
.
join
(
data_dir
,
'train/middle'
)
high
=
self
.
_read_txt
(
high
)
middle
=
self
.
_read_txt
(
middle
)
return
self
.
_create_examples
(
high
+
middle
,
'train'
)
def
get_dev_examples
(
self
,
data_dir
):
"""See base class."""
logger
.
info
(
"LOOKING AT {} dev"
.
format
(
data_dir
))
high
=
os
.
path
.
join
(
data_dir
,
'dev/high'
)
middle
=
os
.
path
.
join
(
data_dir
,
'dev/middle'
)
high
=
self
.
_read_txt
(
high
)
middle
=
self
.
_read_txt
(
middle
)
return
self
.
_create_examples
(
high
+
middle
,
'dev'
)
def
get_test_examples
(
self
,
data_dir
):
"""See base class."""
logger
.
info
(
"LOOKING AT {} test"
.
format
(
data_dir
))
high
=
os
.
path
.
join
(
data_dir
,
'test/high'
)
middle
=
os
.
path
.
join
(
data_dir
,
'test/middle'
)
high
=
self
.
_read_txt
(
high
)
middle
=
self
.
_read_txt
(
middle
)
return
self
.
_create_examples
(
high
+
middle
,
'test'
)
def
get_labels
(
self
):
"""See base class."""
return
[
"0"
,
"1"
,
"2"
,
"3"
]
def
_read_txt
(
self
,
input_dir
):
lines
=
[]
files
=
glob
.
glob
(
input_dir
+
"/*txt"
)
for
file
in
tqdm
.
tqdm
(
files
,
desc
=
"read files"
):
with
open
(
file
,
'r'
,
encoding
=
'utf-8'
)
as
fin
:
data_raw
=
json
.
load
(
fin
)
data_raw
[
"race_id"
]
=
file
lines
.
append
(
data_raw
)
return
lines
def
_create_examples
(
self
,
lines
,
set_type
):
"""Creates examples for the training and dev sets."""
examples
=
[]
for
(
_
,
data_raw
)
in
enumerate
(
lines
):
race_id
=
"%s-%s"
%
(
set_type
,
data_raw
[
"race_id"
])
article
=
data_raw
[
"article"
]
for
i
in
range
(
len
(
data_raw
[
"answers"
])):
truth
=
str
(
ord
(
data_raw
[
'answers'
][
i
])
-
ord
(
'A'
))
question
=
data_raw
[
'questions'
][
i
]
options
=
data_raw
[
'options'
][
i
]
examples
.
append
(
InputExample
(
example_id
=
race_id
,
question
=
question
,
contexts
=
[
article
,
article
,
article
,
article
],
endings
=
[
options
[
0
],
options
[
1
],
options
[
2
],
options
[
3
]],
label
=
truth
))
return
examples
class
SwagProcessor
(
DataProcessor
):
"""Processor for the MRPC data set (GLUE version)."""
def
get_train_examples
(
self
,
data_dir
):
"""See base class."""
logger
.
info
(
"LOOKING AT {} train"
.
format
(
data_dir
))
return
self
.
_create_examples
(
self
.
_read_csv
(
os
.
path
.
join
(
data_dir
,
"train.csv"
)),
"train"
)
def
get_dev_examples
(
self
,
data_dir
):
"""See base class."""
logger
.
info
(
"LOOKING AT {} dev"
.
format
(
data_dir
))
return
self
.
_create_examples
(
self
.
_read_csv
(
os
.
path
.
join
(
data_dir
,
"val.csv"
)),
"dev"
)
def
get_test_examples
(
self
,
data_dir
):
"""See base class."""
logger
.
info
(
"LOOKING AT {} test"
.
format
(
data_dir
))
return
self
.
_create_examples
(
self
.
_read_csv
(
os
.
path
.
join
(
data_dir
,
"test.csv"
)),
"test"
)
def
get_labels
(
self
):
"""See base class."""
return
[
"0"
,
"1"
,
"2"
,
"3"
]
def
_read_csv
(
self
,
input_file
):
with
open
(
input_file
,
'r'
,
encoding
=
'utf-8'
)
as
f
:
reader
=
csv
.
reader
(
f
)
lines
=
[]
for
line
in
reader
:
if
sys
.
version_info
[
0
]
==
2
:
line
=
list
(
unicode
(
cell
,
'utf-8'
)
for
cell
in
line
)
lines
.
append
(
line
)
return
lines
def
_create_examples
(
self
,
lines
,
type
):
"""Creates examples for the training and dev sets."""
if
type
==
"train"
and
lines
[
0
][
-
1
]
!=
'label'
:
raise
ValueError
(
"For training, the input file must contain a label column."
)
examples
=
[
InputExample
(
example_id
=
line
[
2
],
question
=
line
[
5
],
# in the swag dataset, the
# common beginning of each
# choice is stored in "sent2".
contexts
=
[
line
[
4
],
line
[
4
],
line
[
4
],
line
[
4
]],
endings
=
[
line
[
7
],
line
[
8
],
line
[
9
],
line
[
10
]],
label
=
line
[
11
]
)
for
line
in
lines
[
1
:]
# we skip the line with the column names
]
return
examples
class
ArcProcessor
(
DataProcessor
):
"""Processor for the MRPC data set (GLUE version)."""
def
get_train_examples
(
self
,
data_dir
):
"""See base class."""
logger
.
info
(
"LOOKING AT {} train"
.
format
(
data_dir
))
return
self
.
_create_examples
(
self
.
_read_json
(
os
.
path
.
join
(
data_dir
,
"train.jsonl"
)),
"train"
)
def
get_dev_examples
(
self
,
data_dir
):
"""See base class."""
logger
.
info
(
"LOOKING AT {} dev"
.
format
(
data_dir
))
return
self
.
_create_examples
(
self
.
_read_json
(
os
.
path
.
join
(
data_dir
,
"dev.jsonl"
)),
"dev"
)
def
get_test_examples
(
self
,
data_dir
):
logger
.
info
(
"LOOKING AT {} test"
.
format
(
data_dir
))
return
self
.
_create_examples
(
self
.
_read_json
(
os
.
path
.
join
(
data_dir
,
"test.jsonl"
)),
"test"
)
def
get_labels
(
self
):
"""See base class."""
return
[
"0"
,
"1"
,
"2"
,
"3"
]
def
_read_json
(
self
,
input_file
):
with
open
(
input_file
,
'r'
,
encoding
=
'utf-8'
)
as
fin
:
lines
=
fin
.
readlines
()
return
lines
def
_create_examples
(
self
,
lines
,
type
):
"""Creates examples for the training and dev sets."""
def
normalize
(
truth
):
if
truth
in
"ABCD"
:
return
ord
(
truth
)
-
ord
(
"A"
)
elif
truth
in
"1234"
:
return
int
(
truth
)
-
1
else
:
logger
.
info
(
"truth ERROR! %s"
,
str
(
truth
))
return
None
examples
=
[]
three_choice
=
0
four_choice
=
0
five_choice
=
0
other_choices
=
0
for
line
in
tqdm
.
tqdm
(
lines
,
desc
=
"read arc data"
):
data_raw
=
json
.
loads
(
line
.
strip
(
"
\n
"
))
if
len
(
data_raw
[
"question"
][
"choices"
])
==
3
:
three_choice
+=
1
continue
elif
len
(
data_raw
[
"question"
][
"choices"
])
==
5
:
five_choice
+=
1
continue
elif
len
(
data_raw
[
"question"
][
"choices"
])
!=
4
:
other_choices
+=
1
continue
four_choice
+=
1
truth
=
str
(
normalize
(
data_raw
[
"answerKey"
]))
assert
truth
!=
"None"
question_choices
=
data_raw
[
"question"
]
question
=
question_choices
[
"stem"
]
id
=
data_raw
[
"id"
]
options
=
question_choices
[
"choices"
]
if
len
(
options
)
==
4
:
examples
.
append
(
InputExample
(
example_id
=
id
,
question
=
question
,
contexts
=
[
options
[
0
][
"para"
].
replace
(
"_"
,
""
),
options
[
1
][
"para"
].
replace
(
"_"
,
""
),
options
[
2
][
"para"
].
replace
(
"_"
,
""
),
options
[
3
][
"para"
].
replace
(
"_"
,
""
)],
endings
=
[
options
[
0
][
"text"
],
options
[
1
][
"text"
],
options
[
2
][
"text"
],
options
[
3
][
"text"
]],
label
=
truth
))
if
type
==
"train"
:
assert
len
(
examples
)
>
1
assert
examples
[
0
].
label
is
not
None
logger
.
info
(
"len examples: %s}"
,
str
(
len
(
examples
)))
logger
.
info
(
"Three choices: %s"
,
str
(
three_choice
))
logger
.
info
(
"Five choices: %s"
,
str
(
five_choice
))
logger
.
info
(
"Other choices: %s"
,
str
(
other_choices
))
logger
.
info
(
"four choices: %s"
,
str
(
four_choice
))
return
examples
def
convert_examples_to_features
(
examples
,
label_list
,
max_seq_length
,
tokenizer
,
cls_token_at_end
=
False
,
cls_token
=
'[CLS]'
,
cls_token_segment_id
=
1
,
sep_token
=
'[SEP]'
,
sequence_a_segment_id
=
0
,
sequence_b_segment_id
=
1
,
sep_token_extra
=
False
,
pad_token_segment_id
=
0
,
pad_on_left
=
False
,
pad_token
=
0
,
mask_padding_with_zero
=
True
):
""" Loads a data file into a list of `InputBatch`s
`cls_token_at_end` define the location of the CLS token:
- False (Default, BERT/XLM pattern): [CLS] + A + [SEP] + B + [SEP]
- True (XLNet/GPT pattern): A + [SEP] + B + [SEP] + [CLS]
`cls_token_segment_id` define the segment id associated to the CLS token (0 for BERT, 2 for XLNet)
"""
label_map
=
{
label
:
i
for
i
,
label
in
enumerate
(
label_list
)}
features
=
[]
for
(
ex_index
,
example
)
in
tqdm
.
tqdm
(
enumerate
(
examples
),
desc
=
"convert examples to features"
):
if
ex_index
%
10000
==
0
:
logger
.
info
(
"Writing example %d of %d"
%
(
ex_index
,
len
(
examples
)))
choices_features
=
[]
for
ending_idx
,
(
context
,
ending
)
in
enumerate
(
zip
(
example
.
contexts
,
example
.
endings
)):
tokens_a
=
tokenizer
.
tokenize
(
context
)
tokens_b
=
None
if
example
.
question
.
find
(
"_"
)
!=
-
1
:
tokens_b
=
tokenizer
.
tokenize
(
example
.
question
.
replace
(
"_"
,
ending
))
else
:
tokens_b
=
tokenizer
.
tokenize
(
example
.
question
)
tokens_b
+=
[
sep_token
]
if
sep_token_extra
:
tokens_b
+=
[
sep_token
]
tokens_b
+=
tokenizer
.
tokenize
(
ending
)
special_tokens_count
=
4
if
sep_token_extra
else
3
_truncate_seq_pair
(
tokens_a
,
tokens_b
,
max_seq_length
-
special_tokens_count
)
# The convention in BERT is:
# (a) For sequence pairs:
# tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
# type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1
# (b) For single sequences:
# tokens: [CLS] the dog is hairy . [SEP]
# type_ids: 0 0 0 0 0 0 0
#
# Where "type_ids" are used to indicate whether this is the first
# sequence or the second sequence. The embedding vectors for `type=0` and
# `type=1` were learned during pre-training and are added to the wordpiece
# embedding vector (and position vector). This is not *strictly* necessary
# since the [SEP] token unambiguously separates the sequences, but it makes
# it easier for the model to learn the concept of sequences.
#
# For classification tasks, the first vector (corresponding to [CLS]) is
# used as as the "sentence vector". Note that this only makes sense because
# the entire model is fine-tuned.
tokens
=
tokens_a
+
[
sep_token
]
if
sep_token_extra
:
# roberta uses an extra separator b/w pairs of sentences
tokens
+=
[
sep_token
]
segment_ids
=
[
sequence_a_segment_id
]
*
len
(
tokens
)
if
tokens_b
:
tokens
+=
tokens_b
+
[
sep_token
]
segment_ids
+=
[
sequence_b_segment_id
]
*
(
len
(
tokens_b
)
+
1
)
if
cls_token_at_end
:
tokens
=
tokens
+
[
cls_token
]
segment_ids
=
segment_ids
+
[
cls_token_segment_id
]
else
:
tokens
=
[
cls_token
]
+
tokens
segment_ids
=
[
cls_token_segment_id
]
+
segment_ids
input_ids
=
tokenizer
.
convert_tokens_to_ids
(
tokens
)
# The mask has 1 for real tokens and 0 for padding tokens. Only real
# tokens are attended to.
input_mask
=
[
1
if
mask_padding_with_zero
else
0
]
*
len
(
input_ids
)
# Zero-pad up to the sequence length.
padding_length
=
max_seq_length
-
len
(
input_ids
)
if
pad_on_left
:
input_ids
=
([
pad_token
]
*
padding_length
)
+
input_ids
input_mask
=
([
0
if
mask_padding_with_zero
else
1
]
*
padding_length
)
+
input_mask
segment_ids
=
([
pad_token_segment_id
]
*
padding_length
)
+
segment_ids
else
:
input_ids
=
input_ids
+
([
pad_token
]
*
padding_length
)
input_mask
=
input_mask
+
([
0
if
mask_padding_with_zero
else
1
]
*
padding_length
)
segment_ids
=
segment_ids
+
([
pad_token_segment_id
]
*
padding_length
)
assert
len
(
input_ids
)
==
max_seq_length
assert
len
(
input_mask
)
==
max_seq_length
assert
len
(
segment_ids
)
==
max_seq_length
choices_features
.
append
((
tokens
,
input_ids
,
input_mask
,
segment_ids
))
label
=
label_map
[
example
.
label
]
if
ex_index
<
2
:
logger
.
info
(
"*** Example ***"
)
logger
.
info
(
"race_id: {}"
.
format
(
example
.
example_id
))
for
choice_idx
,
(
tokens
,
input_ids
,
input_mask
,
segment_ids
)
in
enumerate
(
choices_features
):
logger
.
info
(
"choice: {}"
.
format
(
choice_idx
))
logger
.
info
(
"tokens: {}"
.
format
(
' '
.
join
(
tokens
)))
logger
.
info
(
"input_ids: {}"
.
format
(
' '
.
join
(
map
(
str
,
input_ids
))))
logger
.
info
(
"input_mask: {}"
.
format
(
' '
.
join
(
map
(
str
,
input_mask
))))
logger
.
info
(
"segment_ids: {}"
.
format
(
' '
.
join
(
map
(
str
,
segment_ids
))))
logger
.
info
(
"label: {}"
.
format
(
label
))
features
.
append
(
InputFeatures
(
example_id
=
example
.
example_id
,
choices_features
=
choices_features
,
label
=
label
)
)
return
features
def
_truncate_seq_pair
(
tokens_a
,
tokens_b
,
max_length
):
"""Truncates a sequence pair in place to the maximum length."""
# This is a simple heuristic which will always truncate the longer sequence
# one token at a time. This makes more sense than truncating an equal percent
# of tokens from each, since if one sequence is very short then each token
# that's truncated likely contains more information than a longer sequence.
while
True
:
total_length
=
len
(
tokens_a
)
+
len
(
tokens_b
)
if
total_length
<=
max_length
:
break
# if len(tokens_a) > len(tokens_b):
# tokens_a.pop()
# else:
# tokens_b.pop()
tokens_a
.
pop
()
processors
=
{
"race"
:
RaceProcessor
,
"swag"
:
SwagProcessor
,
"arc"
:
ArcProcessor
}
GLUE_TASKS_NUM_LABELS
=
{
"race"
,
4
,
"swag"
,
4
,
"arc"
,
4
}
pytorch_transformers/__init__.py
View file @
2a2832ce
...
...
@@ -31,7 +31,7 @@ from .modeling_gpt2 import (GPT2Config, GPT2PreTrainedModel, GPT2Model,
GPT2_PRETRAINED_MODEL_ARCHIVE_MAP
)
from
.modeling_xlnet
import
(
XLNetConfig
,
XLNetPreTrainedModel
,
XLNetModel
,
XLNetLMHeadModel
,
XLNetForSequenceClassification
,
XLNetForQuestionAnswering
,
XLNetForSequenceClassification
,
XLNetForQuestionAnswering
,
XLNetForMultipleChoice
,
load_tf_weights_in_xlnet
,
XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP
,
XLNET_PRETRAINED_MODEL_ARCHIVE_MAP
)
from
.modeling_xlm
import
(
XLMConfig
,
XLMPreTrainedModel
,
XLMModel
,
...
...
@@ -39,6 +39,7 @@ from .modeling_xlm import (XLMConfig, XLMPreTrainedModel , XLMModel,
XLMForQuestionAnswering
,
XLM_PRETRAINED_CONFIG_ARCHIVE_MAP
,
XLM_PRETRAINED_MODEL_ARCHIVE_MAP
)
from
.modeling_roberta
import
(
RobertaConfig
,
RobertaForMaskedLM
,
RobertaModel
,
RobertaForSequenceClassification
,
RobertaForMultipleChoice
,
ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP
,
ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
)
from
.modeling_utils
import
(
WEIGHTS_NAME
,
CONFIG_NAME
,
TF_WEIGHTS_NAME
,
PretrainedConfig
,
PreTrainedModel
,
prune_layer
,
Conv1D
)
...
...
pytorch_transformers/modeling_roberta.py
View file @
2a2832ce
...
...
@@ -329,6 +329,46 @@ class RobertaForSequenceClassification(BertPreTrainedModel):
return
outputs
# (loss), logits, (hidden_states), (attentions)
class
RobertaForMultipleChoice
(
BertPreTrainedModel
):
config_class
=
RobertaConfig
pretrained_model_archive_map
=
ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
base_model_prefix
=
"roberta"
def
__init__
(
self
,
config
):
super
(
RobertaForMultipleChoice
,
self
).
__init__
(
config
)
self
.
roberta
=
RobertaModel
(
config
)
self
.
dropout
=
nn
.
Dropout
(
config
.
hidden_dropout_prob
)
self
.
classifier
=
nn
.
Linear
(
config
.
hidden_size
,
1
)
self
.
apply
(
self
.
init_weights
)
def
forward
(
self
,
input_ids
,
token_type_ids
=
None
,
attention_mask
=
None
,
labels
=
None
,
position_ids
=
None
,
head_mask
=
None
):
num_choices
=
input_ids
.
shape
[
1
]
flat_input_ids
=
input_ids
.
view
(
-
1
,
input_ids
.
size
(
-
1
))
flat_position_ids
=
position_ids
.
view
(
-
1
,
position_ids
.
size
(
-
1
))
if
position_ids
is
not
None
else
None
flat_token_type_ids
=
token_type_ids
.
view
(
-
1
,
token_type_ids
.
size
(
-
1
))
if
token_type_ids
is
not
None
else
None
flat_attention_mask
=
attention_mask
.
view
(
-
1
,
attention_mask
.
size
(
-
1
))
if
attention_mask
is
not
None
else
None
outputs
=
self
.
roberta
(
flat_input_ids
,
position_ids
=
flat_position_ids
,
token_type_ids
=
flat_token_type_ids
,
attention_mask
=
flat_attention_mask
,
head_mask
=
head_mask
)
pooled_output
=
outputs
[
1
]
pooled_output
=
self
.
dropout
(
pooled_output
)
logits
=
self
.
classifier
(
pooled_output
)
reshaped_logits
=
logits
.
view
(
-
1
,
num_choices
)
outputs
=
(
reshaped_logits
,)
+
outputs
[
2
:]
# add hidden states and attention if they are here
if
labels
is
not
None
:
loss_fct
=
CrossEntropyLoss
()
loss
=
loss_fct
(
reshaped_logits
,
labels
)
outputs
=
(
loss
,)
+
outputs
return
outputs
# (loss), reshaped_logits, (hidden_states), (attentions)
class
RobertaClassificationHead
(
nn
.
Module
):
"""Head for sentence-level classification tasks."""
...
...
pytorch_transformers/modeling_xlnet.py
View file @
2a2832ce
...
...
@@ -1143,6 +1143,50 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
return
outputs
# return (loss), logits, mems, (hidden states), (attentions)
class
XLNetForMultipleChoice
(
XLNetPreTrainedModel
):
r
"""
"""
def
__init__
(
self
,
config
):
super
(
XLNetForMultipleChoice
,
self
).
__init__
(
config
)
self
.
transformer
=
XLNetModel
(
config
)
self
.
sequence_summary
=
SequenceSummary
(
config
)
self
.
logits_proj
=
nn
.
Linear
(
config
.
d_model
,
1
)
self
.
apply
(
self
.
init_weights
)
def
forward
(
self
,
input_ids
,
token_type_ids
=
None
,
input_mask
=
None
,
attention_mask
=
None
,
mems
=
None
,
perm_mask
=
None
,
target_mapping
=
None
,
labels
=
None
,
head_mask
=
None
):
num_choices
=
input_ids
.
shape
[
1
]
flat_input_ids
=
input_ids
.
view
(
-
1
,
input_ids
.
size
(
-
1
))
flat_token_type_ids
=
token_type_ids
.
view
(
-
1
,
token_type_ids
.
size
(
-
1
))
if
token_type_ids
is
not
None
else
None
flat_attention_mask
=
attention_mask
.
view
(
-
1
,
attention_mask
.
size
(
-
1
))
if
attention_mask
is
not
None
else
None
flat_input_mask
=
input_mask
.
view
(
-
1
,
input_mask
.
size
(
-
1
))
if
input_mask
is
not
None
else
None
transformer_outputs
=
self
.
transformer
(
flat_input_ids
,
token_type_ids
=
flat_token_type_ids
,
input_mask
=
flat_input_mask
,
attention_mask
=
flat_attention_mask
,
mems
=
mems
,
perm_mask
=
perm_mask
,
target_mapping
=
target_mapping
,
head_mask
=
head_mask
)
output
=
transformer_outputs
[
0
]
output
=
self
.
sequence_summary
(
output
)
logits
=
self
.
logits_proj
(
output
)
reshaped_logits
=
logits
.
view
(
-
1
,
num_choices
)
outputs
=
(
reshaped_logits
,)
+
transformer_outputs
[
1
:]
# Keep mems, hidden states, attentions if there are in it
if
labels
is
not
None
:
loss_fct
=
CrossEntropyLoss
()
loss
=
loss_fct
(
reshaped_logits
,
labels
.
view
(
-
1
))
outputs
=
(
loss
,)
+
outputs
return
outputs
# return (loss), logits, mems, (hidden states), (attentions)
@
add_start_docstrings
(
"""XLNet Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of
the hidden-states output to compute `span start logits` and `span end logits`). """
,
XLNET_START_DOCSTRING
,
XLNET_INPUTS_DOCSTRING
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment