chenpangpang / transformers

Commit 4812a5a7 authored Sep 16, 2019 by erenup

add doc string

parent 6e1ac34e
Showing 4 changed files with 158 additions and 36 deletions (+158 -36)

examples/single_model_scripts/run_multiple_choice.py    +6  -7
examples/single_model_scripts/utils_multiple_choice.py  +34 -25
pytorch_transformers/modeling_roberta.py                +69 -2
pytorch_transformers/modeling_xlnet.py                  +49 -2
examples/single_model_scripts/run_multiple_choice.py
...
@@ -13,7 +13,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-""" Finetuning the library models for multiple choice (Bert, XLM, XLNet)."""
+""" Finetuning the library models for multiple choice (Bert, Roberta, XLNet)."""
 from __future__ import absolute_import, division, print_function
...
@@ -44,7 +44,7 @@ from utils_multiple_choice import (convert_examples_to_features, processors)
 logger = logging.getLogger(__name__)
-ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig, XLNetConfig)), ())
+ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig, XLNetConfig, RobertaConfig)), ())
 MODEL_CLASSES = {
     'bert': (BertConfig, BertForMultipleChoice, BertTokenizer),
...
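Only the first entry of MODEL_CLASSES is visible in this hunk. A plausible completion, assuming the script also registers the RoBERTa and XLNet multiple-choice heads added later in this commit (the entries beyond 'bert' are an assumption, not shown in the diff):

    # Hypothetical completion of MODEL_CLASSES; only the 'bert' entry appears in the diff above.
    MODEL_CLASSES = {
        'bert':    (BertConfig,    BertForMultipleChoice,    BertTokenizer),
        'xlnet':   (XLNetConfig,   XLNetForMultipleChoice,   XLNetTokenizer),    # assumed
        'roberta': (RobertaConfig, RobertaForMultipleChoice, RobertaTokenizer),  # assumed
    }
    # The script would then pick the classes from the --model_type argument:
    config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]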
@@ -208,7 +208,6 @@ def train(args, train_dataset, model, tokenizer):
 def evaluate(args, model, tokenizer, prefix="", test=False):
     # Loop to handle MNLI double evaluation (matched, mis-matched)
     eval_task_names = (args.task_name,)
     eval_outputs_dirs = (args.output_dir,)
...
@@ -259,7 +258,7 @@ def evaluate(args, model, tokenizer, prefix="", test=False):
         result = {"eval_acc": acc, "eval_loss": eval_loss}
         results.update(result)
-        output_eval_file = os.path.join(eval_output_dir, "is_test_" + str(test) + "_eval_results.txt")
+        output_eval_file = os.path.join(eval_output_dir, "is_test_" + str(test).lower() + "_eval_results.txt")
         with open(output_eval_file, "w") as writer:
             logger.info("***** Eval results {} *****".format(str(prefix) + " is test:" + str(test)))
...
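The only change here is that the boolean flag is lowercased before it is spliced into the results file name, so evaluation and test runs write to predictable, all-lowercase paths. A small illustration (the output directory is made up):

    # str(True) == "True", but str(True).lower() == "true"
    os.path.join("./race_output", "is_test_" + str(True).lower() + "_eval_results.txt")
    # -> './race_output/is_test_true_eval_results.txt'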
@@ -522,9 +521,9 @@ def main():
         if not args.do_train:
             args.output_dir = args.model_name_or_path
         checkpoints = [args.output_dir]
-        if args.eval_all_checkpoints:  # can not use this to do test!! just for different paras
-            checkpoints = list(os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + '/**/' + WEIGHTS_NAME, recursive=True)))
-            logging.getLogger("pytorch_transformers.modeling_utils").setLevel(logging.WARN)  # Reduce logging
+        # if args.eval_all_checkpoints:  # can not use this to do test!!
+        #     checkpoints = list(os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + '/**/' + WEIGHTS_NAME, recursive=True)))
+        #     logging.getLogger("pytorch_transformers.modeling_utils").setLevel(logging.WARN)  # Reduce logging
         logger.info("Evaluate the following checkpoints: %s", checkpoints)
         for checkpoint in checkpoints:
             global_step = checkpoint.split('-')[-1] if len(checkpoints) > 1 else ""
...
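For reference, the eval_all_checkpoints branch that this hunk comments out discovers every saved checkpoint by globbing for the weights file and keeping its parent directory. A minimal sketch of that lookup (the output path is illustrative):

    import glob, os

    WEIGHTS_NAME = "pytorch_model.bin"
    output_dir = "./race_output"  # illustrative

    # e.g. ./race_output/checkpoint-500/pytorch_model.bin -> ./race_output/checkpoint-500
    checkpoints = list(os.path.dirname(c)
                       for c in sorted(glob.glob(output_dir + '/**/' + WEIGHTS_NAME, recursive=True)))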
examples/single_model_scripts/utils_multiple_choice.py
...
@@ -13,7 +13,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-""" BERT classification fine-tuning: utilities to work with GLUE tasks """
+""" BERT multiple choice fine-tuning: utilities to work with multiple choice tasks of reading comprehension """
 from __future__ import absolute_import, division, print_function
...
@@ -38,11 +38,10 @@ class InputExample(object):
         """Constructs a InputExample.
         Args:
-            guid: Unique id for the example.
-            text_a: string. The untokenized text of the first sequence. For single
-            sequence tasks, only this sequence must be specified.
-            text_b: (Optional) string. The untokenized text of the second sequence.
-            Only must be specified for sequence pair tasks.
+            example_id: Unique id for the example.
+            contexts: list of str. The untokenized text of the first sequence (context of the corresponding question).
+            question: string. The untokenized text of the second sequence (question).
+            endings: list of str. The multiple choice options. Its length must equal the length of contexts.
             label: (Optional) string. The label of the example. This should be
             specified for train and dev examples, but not for test examples.
         """
...
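Concretely, a four-way reading-comprehension item is wrapped with one context per option. The values below are illustrative and `article` is assumed to hold the passage text (this mirrors how RaceProcessor builds its examples):

    example = InputExample(
        example_id="high1234.txt-0",                      # illustrative id
        question="What does the author mainly discuss?",  # illustrative question
        contexts=[article, article, article, article],    # same passage repeated, one per option
        endings=["option A", "option B", "option C", "option D"],
        label="2",                                        # index of the correct option, as a string
    )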
@@ -73,7 +72,7 @@ class InputFeatures(object):
 class DataProcessor(object):
-    """Base class for data converters for sequence classification data sets."""
+    """Base class for data converters for multiple choice data sets."""
     def get_train_examples(self, data_dir):
         """Gets a collection of `InputExample`s for the train set."""
...
@@ -84,7 +83,7 @@ class DataProcessor(object):
         raise NotImplementedError()
     def get_test_examples(self, data_dir):
-        """Gets a collection of `InputExample`s for the dev set."""
+        """Gets a collection of `InputExample`s for the test set."""
         raise NotImplementedError()
     def get_labels(self):
...
@@ -93,7 +92,7 @@ class DataProcessor(object):
 class RaceProcessor(DataProcessor):
-    """Processor for the MRPC data set (GLUE version)."""
+    """Processor for the RACE data set."""
     def get_train_examples(self, data_dir):
         """See base class."""
...
@@ -152,13 +151,13 @@ class RaceProcessor(DataProcessor):
                     InputExample(
                         example_id=race_id,
                         question=question,
-                        contexts=[article, article, article, article],
+                        contexts=[article, article, article, article],  # this is not efficient but convenient
                         endings=[options[0], options[1], options[2], options[3]],
                         label=truth))
         return examples
 class SwagProcessor(DataProcessor):
-    """Processor for the MRPC data set (GLUE version)."""
+    """Processor for the SWAG data set."""
     def get_train_examples(self, data_dir):
         """See base class."""
...
@@ -172,9 +171,12 @@ class SwagProcessor(DataProcessor):
     def get_test_examples(self, data_dir):
         """See base class."""
-        logger.info("LOOKING AT {} test".format(data_dir))
+        logger.info("LOOKING AT {} dev".format(data_dir))
+        raise ValueError(
+            "For swag testing, the input file does not contain a label column. It can not be tested in the current "
+            "code setting!")
         return self._create_examples(self._read_csv(os.path.join(data_dir, "test.csv")), "test")
     def get_labels(self):
         """See base class."""
         return ["0", "1", "2", "3"]
...
@@ -213,7 +215,7 @@ class SwagProcessor(DataProcessor):
 class ArcProcessor(DataProcessor):
-    """Processor for the MRPC data set (GLUE version)."""
+    """Processor for the ARC data set (request from allennlp)."""
     def get_train_examples(self, data_dir):
         """See base class."""
...
@@ -242,6 +244,7 @@ class ArcProcessor(DataProcessor):
     def _create_examples(self, lines, type):
         """Creates examples for the training and dev sets."""
+        # There are two types of labels. They should be normalized
        def normalize(truth):
             if truth in "ABCD":
                 return ord(truth) - ord("A")
...
@@ -256,6 +259,7 @@ class ArcProcessor(DataProcessor):
         four_choice = 0
         five_choice = 0
         other_choices = 0
+        # we deleted examples which have more than or less than four choices
         for line in tqdm.tqdm(lines, desc="read arc data"):
             data_raw = json.loads(line.strip("\n"))
             if len(data_raw["question"]["choices"]) == 3:
...
@@ -274,7 +278,6 @@ class ArcProcessor(DataProcessor):
             question = question_choices["stem"]
             id = data_raw["id"]
             options = question_choices["choices"]
             if len(options) == 4:
                 examples.append(
                     InputExample(
...
@@ -328,13 +331,16 @@ def convert_examples_to_features(examples, label_list, max_seq_length,
             tokens_a = tokenizer.tokenize(context)
             tokens_b = None
             if example.question.find("_") != -1:
+                # this is for a cloze question
                 tokens_b = tokenizer.tokenize(example.question.replace("_", ending))
             else:
-                tokens_b = tokenizer.tokenize(example.question)
-                tokens_b += [sep_token]
-                if sep_token_extra:
-                    tokens_b += [sep_token]
-                tokens_b += tokenizer.tokenize(ending)
+                tokens_b = tokenizer.tokenize(example.question + " " + ending)
+                # you can add a sep token between question and ending. This does not make too much difference.
+                # tokens_b = tokenizer.tokenize(example.question)
+                # tokens_b += [sep_token]
+                # if sep_token_extra:
+                #     tokens_b += [sep_token]
+                # tokens_b += tokenizer.tokenize(ending)
             special_tokens_count = 4 if sep_token_extra else 3
             _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - special_tokens_count)
...
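The cloze branch above substitutes the candidate ending into the question's blank before tokenizing, while an ordinary question simply gets the ending appended. A small illustration with made-up strings:

    question = "The woman _ the window before leaving."
    ending = "closed"

    if question.find("_") != -1:
        text_b = question.replace("_", ending)  # "The woman closed the window before leaving."
    else:
        text_b = question + " " + ending        # ordinary question followed by the candidate answer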
@@ -427,15 +433,18 @@ def _truncate_seq_pair(tokens_a, tokens_b, max_length):
     # one token at a time. This makes more sense than truncating an equal percent
     # of tokens from each, since if one sequence is very short then each token
     # that's truncated likely contains more information than a longer sequence.
+    # However, since we'd better not remove tokens from the options and the question, you can choose to use a bigger
+    # length or only pop from the context
     while True:
         total_length = len(tokens_a) + len(tokens_b)
         if total_length <= max_length:
             break
         # if len(tokens_a) > len(tokens_b):
         #     tokens_a.pop()
         # else:
         #     tokens_b.pop()
         if len(tokens_a) > len(tokens_b):
             tokens_a.pop()
         else:
+            logger.info('Attention! you are removing from question + options. Try to use a bigger max seq length!')
             tokens_b.pop()
 processors = {
...
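The new comment suggests an alternative truncation policy: leave the question and options untouched and only pop tokens from the context. A minimal sketch of that variant (the helper name is hypothetical, not part of the commit):

    def _truncate_context_only(tokens_a, tokens_b, max_length):
        # tokens_a holds the context; tokens_b holds question + ending.
        # Only the context is shortened, so the question and options stay intact.
        while len(tokens_a) + len(tokens_b) > max_length and tokens_a:
            tokens_a.pop()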
pytorch_transformers/modeling_roberta.py
...
@@ -294,7 +294,7 @@ class RobertaForSequenceClassification(BertPreTrainedModel):
     Examples::
-        tokenizer = RoertaTokenizer.from_pretrained('roberta-base')
+        tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
         model = RobertaForSequenceClassification.from_pretrained('roberta-base')
         input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0)  # Batch size 1
         labels = torch.tensor([1]).unsqueeze(0)  # Batch size 1
...
@@ -333,8 +333,75 @@ class RobertaForSequenceClassification(BertPreTrainedModel):
         return outputs  # (loss), logits, (hidden_states), (attentions)
+@add_start_docstrings("""Roberta Model with a multiple choice classification head on top (a linear layer on top of
+    the pooled output and a softmax) e.g. for RocStories/SWAG tasks. """,
+    ROBERTA_START_DOCSTRING, ROBERTA_INPUTS_DOCSTRING)
+class RobertaForMultipleChoice(BertPreTrainedModel):
+    r"""
+    Inputs:
+        **input_ids**: ``torch.LongTensor`` of shape ``(batch_size, num_choices, sequence_length)``:
+            Indices of input sequence tokens in the vocabulary.
+            The second dimension of the input (`num_choices`) indicates the number of choices to score.
+            To match pre-training, RoBERTa input sequences should be formatted with [CLS] and [SEP] tokens as follows:
+            (a) For sequence pairs:
+                ``tokens:         [CLS] is this jack ##son ##ville ? [SEP] [SEP] no it is not . [SEP]``
+                ``token_type_ids:   0   0    0    0    0      0    0   0     0   1  1  1   1  1   1``
+            (b) For single sequences:
+                ``tokens:         [CLS] the dog is hairy . [SEP]``
+                ``token_type_ids:   0    0   0   0   0   0   0``
+            Indices can be obtained using :class:`pytorch_transformers.BertTokenizer`.
+            See :func:`pytorch_transformers.PreTrainedTokenizer.encode` and
+            :func:`pytorch_transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
+        **token_type_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, num_choices, sequence_length)``:
+            Segment token indices to indicate first and second portions of the inputs.
+            The second dimension of the input (`num_choices`) indicates the number of choices to score.
+            Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1`` to a `sentence B` token.
+        **attention_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, num_choices, sequence_length)``:
+            Mask to avoid performing attention on padding token indices.
+            The second dimension of the input (`num_choices`) indicates the number of choices to score.
+            Mask values selected in ``[0, 1]``:
+            ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
+        **head_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``:
+            Mask to nullify selected heads of the self-attention modules.
+            Mask values selected in ``[0, 1]``:
+            ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
+        **labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
+            Labels for computing the multiple choice classification loss.
+            Indices should be in ``[0, ..., num_choices]`` where `num_choices` is the size of the second dimension
+            of the input tensors. (see `input_ids` above)
+    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
+        **loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
+            Classification loss.
+        **classification_scores**: ``torch.FloatTensor`` of shape ``(batch_size, num_choices)`` where `num_choices` is the size of the second dimension
+            of the input tensors. (see `input_ids` above).
+            Classification scores (before SoftMax).
+        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
+            list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
+            of shape ``(batch_size, sequence_length, hidden_size)``:
+            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
+            list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
+            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
+    Examples::
+        tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
+        model = RobertaForMultipleChoice.from_pretrained('roberta-base')
+        choices = ["Hello, my dog is cute", "Hello, my cat is amazing"]
+        input_ids = torch.tensor([tokenizer.encode(s, add_special_tokens=True) for s in choices]).unsqueeze(0)  # Batch size 1, 2 choices
+        labels = torch.tensor(1).unsqueeze(0)  # Batch size 1
+        outputs = model(input_ids, labels=labels)
+        loss, classification_scores = outputs[:2]
+    """
+    config_class = RobertaConfig
+    pretrained_model_archive_map = ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
+    base_model_prefix = "roberta"
...
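The forward pass of the new head is collapsed in this view. The usual pattern for multiple choice models, which the docstring's shapes imply, is to flatten the choice dimension, run the encoder once per choice, and reshape the per-choice scores back to (batch_size, num_choices). A rough sketch under that assumption (encoder and classifier are placeholders, not the actual attribute names):

    import torch

    def multiple_choice_scores(encoder, classifier, input_ids, attention_mask=None):
        batch_size, num_choices, seq_len = input_ids.shape
        flat_ids = input_ids.view(-1, seq_len)                    # (batch * num_choices, seq_len)
        flat_mask = attention_mask.view(-1, seq_len) if attention_mask is not None else None
        pooled = encoder(flat_ids, attention_mask=flat_mask)[1]   # pooled output per flattened sequence
        logits = classifier(pooled)                               # (batch * num_choices, 1)
        return logits.view(batch_size, num_choices)               # one score per choice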
pytorch_transformers/modeling_xlnet.py
...
@@ -1152,9 +1152,56 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
         return outputs  # return (loss), logits, mems, (hidden states), (attentions)
+@add_start_docstrings("""XLNet Model with a multiple choice classification head on top (a linear layer on top of
+    the pooled output and a softmax) e.g. for RACE/SWAG tasks. """,
+    XLNET_START_DOCSTRING, XLNET_INPUTS_DOCSTRING)
+class XLNetForMultipleChoice(XLNetPreTrainedModel):
+    r"""
+    Inputs:
+        **input_ids**: ``torch.LongTensor`` of shape ``(batch_size, num_choices, sequence_length)``:
+            Indices of input sequence tokens in the vocabulary.
+            The second dimension of the input (`num_choices`) indicates the number of choices to score.
+        **token_type_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, num_choices, sequence_length)``:
+            Segment token indices to indicate first and second portions of the inputs.
+            The second dimension of the input (`num_choices`) indicates the number of choices to score.
+            Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1`` to a `sentence B` token.
+        **attention_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, num_choices, sequence_length)``:
+            Mask to avoid performing attention on padding token indices.
+            The second dimension of the input (`num_choices`) indicates the number of choices to score.
+            Mask values selected in ``[0, 1]``:
+            ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
+        **head_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``:
+            Mask to nullify selected heads of the self-attention modules.
+            Mask values selected in ``[0, 1]``:
+            ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
+        **labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
+            Labels for computing the multiple choice classification loss.
+            Indices should be in ``[0, ..., num_choices]`` where `num_choices` is the size of the second dimension
+            of the input tensors. (see `input_ids` above)
+    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
+        **loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
+            Classification loss.
+        **classification_scores**: ``torch.FloatTensor`` of shape ``(batch_size, num_choices)`` where `num_choices` is the size of the second dimension
+            of the input tensors. (see `input_ids` above).
+            Classification scores (before SoftMax).
+        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
+            list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
+            of shape ``(batch_size, sequence_length, hidden_size)``:
+            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
+        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
+            list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
+            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
+    Examples::
+        tokenizer = XLNetTokenizer.from_pretrained('xlnet-base-cased')
+        model = XLNetForMultipleChoice.from_pretrained('xlnet-base-cased')
+        choices = ["Hello, my dog is cute", "Hello, my cat is amazing"]
+        input_ids = torch.tensor([tokenizer.encode(s) for s in choices]).unsqueeze(0)  # Batch size 1, 2 choices
+        labels = torch.tensor(1).unsqueeze(0)  # Batch size 1
+        outputs = model(input_ids, labels=labels)
+        loss, classification_scores = outputs[:2]
+    """
+    def __init__(self, config):
...
@@ -1251,7 +1298,7 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel):
     Examples::
-        tokenizer = XLMTokenizer.from_pretrained('xlm-mlm-en-2048')
+        tokenizer = XLNetTokenizer.from_pretrained('xlnet-large-cased')
         model = XLMForQuestionAnswering.from_pretrained('xlnet-large-cased')
         input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0)  # Batch size 1
         start_positions = torch.tensor([1])
...