Unverified commit dd9d483d authored by Julien Chaumond, committed by GitHub

Trainer (#3800)

* doc

* [tests] Add sample files for a regression task

* [HUGE] Trainer

* Feedback from @sshleifer

* Feedback from @thomwolf + logging tweak

* [file_utils] when downloading concurrently, get_from_cache will use the cached file for subsequent processes

* [glue] Use default max_seq_length of 128 like before

* [glue] move DataTrainingArguments around

* [ner] Change interface of InputExample, and align run_{tf,pl}

* Re-align the pl scripts a little bit

* ner

* [ner] Add integration test

* Fix language_modeling with API tweak

* [ci] Tweak loss target

* Don't break console output

* amp.initialize: model must be on right device before

* [multiple-choice] update for Trainer

* Re-align to 827d6d6e
parent eb5601b0
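The headline change here is the new Trainer API. As orientation, a minimal sketch of how the pieces this commit introduces (Trainer, TrainingArguments, GlueDataset, GlueDataTrainingArguments) fit together for a GLUE task; the checkpoint name and paths below are placeholders, not part of the commit:

from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    GlueDataset,
    GlueDataTrainingArguments,
    Trainer,
    TrainingArguments,
)

# Placeholder checkpoint and data locations.
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
model = AutoModelForSequenceClassification.from_pretrained("bert-base-cased")

data_args = GlueDataTrainingArguments(task_name="mrpc", data_dir="./glue_data/MRPC", max_seq_length=128)
train_dataset = GlueDataset(data_args, tokenizer=tokenizer)

training_args = TrainingArguments(output_dir="./output")
trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset)
trainer.train()

The sample files below (German NER sentences, a label list, and STS-B style regression pairs) are the test fixtures this commit adds for the integration tests.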
Gleich O
darauf O
entwirft O
er O
seine O
Selbstdarstellung O
" O
Ecce B-OTH
homo I-OTH
" O
in O
enger O
Auseinandersetzung O
mit O
diesem O
Bild O
Jesu B-PER
. O
1980 O
kam O
der O
Crown B-OTH
als O
Versuch O
von O
Toyota B-ORG
, O
sich O
in O
der O
Oberen O
Mittelklasse O
zu O
etablieren O
, O
auch O
nach O
Deutschland B-LOC
. O
– O
4:26 O
# O
Sometime B-OTH
Ago/La I-OTH
Fiesta I-OTH
– O
23:18 O
Alle O
Stücke O
wurden O
von O
Corea B-PER
komponiert O
mit O
Ausnahme O
der O
einleitenden O
Improvisation O
zu O
Sometime B-OTH
Ago I-OTH
. O
Bis O
2013 O
steigen O
die O
Mittel O
aus O
dem O
EU-Budget B-ORGpart
auf O
rund O
120 O
Millionen O
Euro B-OTH
. O
Daraus O
entwickelte O
sich O
im O
Rokoko B-OTH
die O
Sitte O
des O
gemeinsamen O
Weinens O
im O
Theater O
, O
das O
die O
Standesgrenzen O
innerhalb O
des O
Publikums O
überbrücken O
sollte O
. O
Die O
Spinne O
hatte O
sie O
mit O
Seidenfäden O
an O
ihrem O
Schwanz O
gefesselt O
und O
nach O
oben O
gezogen O
. O
In O
Deutschland B-LOC
ist O
nach O
StGB O
eine O
Anwerbung O
für O
die O
Fremdenlegion O
strafbar O
. O
Am O
Donnerstag O
wird O
sich O
zeigen O
, O
ob O
die O
Idee O
der O
DLR-Forscher B-ORGpart
funktioniert O
. O
Der O
sechste O
Lauf O
der O
ADAC B-ORG
GT I-ORG
Masters I-ORG
stand O
ganz O
klar O
im O
Mittelpunkt O
des O
Motorsport-Wochenendes O
auf O
dem O
Eurospeedway B-ORG
Lausitz I-ORG
. O
Nach O
den O
schwächeren O
Vorgaben O
der O
Wall B-ORG
Street I-ORG
vom O
Vortag O
setzten O
die O
deutschen B-LOCderiv
Standardwerte O
ihren O
Konsolidierungskurs O
fort O
. O
Kolb B-PER
war O
seit O
1986 O
im O
Turnverein O
als O
Leiter O
tätig O
, O
darunter O
elf O
Jahre O
als O
Hauptleiter O
in O
der O
Männerriege O
. O
B-LOC
B-LOCderiv
B-LOCpart
B-ORG
B-ORGderiv
B-ORGpart
B-OTH
B-OTHderiv
B-OTHpart
B-PER
B-PERderiv
B-PERpart
I-LOC
I-LOCderiv
I-LOCpart
I-ORG
I-ORGderiv
I-ORGpart
I-OTH
I-OTHderiv
I-OTHpart
I-PER
I-PERderiv
I-PERpart
O
Schartau B-PER
sagte O
dem O
" O
Tagesspiegel B-ORG
" O
vom O
Freitag O
, O
Fischer B-PER
sei O
" O
in O
einer O
Weise O
aufgetreten O
, O
die O
alles O
andere O
als O
überzeugend O
war O
" O
. O
Firmengründer O
Wolf B-PER
Peter I-PER
Bree I-PER
arbeitete O
Anfang O
der O
siebziger O
Jahre O
als O
Möbelvertreter O
, O
als O
er O
einen O
fliegenden O
Händler O
aus O
dem O
Libanon B-LOC
traf O
. O
Ob O
sie O
dabei O
nach O
dem O
Runden O
Tisch O
am O
23. O
April O
in O
Berlin B-LOC
durch O
ein O
pädagogisches O
Konzept O
unterstützt O
wird O
, O
ist O
allerdings O
zu O
bezweifeln O
. O
Bayern B-ORG
München I-ORG
ist O
wieder O
alleiniger O
Top- O
Favorit O
auf O
den O
Gewinn O
der O
deutschen B-LOCderiv
Fußball-Meisterschaft O
. O
Dabei O
hätte O
der O
tapfere O
Schlussmann O
allen O
Grund O
gehabt O
, O
sich O
viel O
früher O
aufzuregen O
. O
ARD-Programmchef B-ORGpart
Günter B-PER
Struve I-PER
war O
wegen O
eines O
vierwöchigen O
Urlaubs O
für O
eine O
Stellungnahme O
nicht O
erreichbar O
. O
Alternativ O
sollten O
sich O
die O
Restaurantbetreiber O
aus O
Sicht O
der O
Solingerin B-LOCderiv
zu O
längeren O
Öffnungszeiten O
verpflichten O
, O
um O
wartende O
Kunden O
aufzunehmen O
. O
Die O
Deutsche B-ORG
Flugsicherung I-ORG
( O
DFS B-ORG
) O
beschloss O
ein O
Flugverbot O
für O
alle O
internationalen O
Flughäfen O
mit O
Ausnahme O
der O
beiden O
Berliner B-LOCderiv
Flughäfen O
bis O
2.00 O
Uhr O
nachts O
. O
New O
Small O
Family O
mit O
E-Motor O
: O
Studie O
E-Up O
! O
Eine O
Schwachstelle O
war O
beispielsweise O
der O
Spiegelkasten O
. O
Denn O
durch O
den O
Einsatz O
moderner O
Fahrzeugtechnik O
( O
Dieseltriebwagen O
) O
und O
schalldämmender O
Fenster O
entsteht O
keine O
Einschränkung O
der O
Wohnqualität O
. O
index genre filename year old_index source1 source2 sentence1 sentence2 score
0 main-captions MSRvid 2012test 0000 none none A man with a hard hat is dancing. A man wearing a hard hat is dancing. 5.000
1 main-captions MSRvid 2012test 0002 none none A young child is riding a horse. A child is riding a horse. 4.750
2 main-captions MSRvid 2012test 0003 none none A man is feeding a mouse to a snake. The man is feeding a mouse to the snake. 5.000
3 main-captions MSRvid 2012test 0007 none none A woman is playing the guitar. A man is playing guitar. 2.400
4 main-captions MSRvid 2012test 0008 none none A woman is playing the flute. A man is playing a flute. 2.750
5 main-captions MSRvid 2012test 0010 none none A woman is cutting an onion. A man is cutting onions. 2.615
6 main-captions MSRvid 2012test 0015 none none A man is erasing a chalk board. The man is erasing the chalk board. 5.000
7 main-captions MSRvid 2012test 0023 none none A woman is carrying a boy. A woman is carrying her baby. 2.333
8 main-captions MSRvid 2012test 0027 none none Three men are playing guitars. Three men are on stage playing guitars. 3.750
index genre filename year old_index source1 source2 sentence1 sentence2 score
0 main-captions MSRvid 2012test 0001 none none A plane is taking off. An air plane is taking off. 5.000
1 main-captions MSRvid 2012test 0004 none none A man is playing a large flute. A man is playing a flute. 3.800
2 main-captions MSRvid 2012test 0005 none none A man is spreading shreded cheese on a pizza. A man is spreading shredded cheese on an uncooked pizza. 3.800
3 main-captions MSRvid 2012test 0006 none none Three men are playing chess. Two men are playing chess. 2.600
4 main-captions MSRvid 2012test 0009 none none A man is playing the cello. A man seated is playing the cello. 4.250
5 main-captions MSRvid 2012test 0011 none none Some men are fighting. Two men are fighting. 4.250
6 main-captions MSRvid 2012test 0012 none none A man is smoking. A man is skating. 0.500
7 main-captions MSRvid 2012test 0013 none none The man is playing the piano. The man is playing the guitar. 1.600
8 main-captions MSRvid 2012test 0014 none none A man is playing on a guitar and singing. A woman is playing an acoustic guitar and singing. 2.200
@@ -8,7 +8,6 @@ import pytorch_lightning as pl
import torch
from transformers import (
ALL_PRETRAINED_MODEL_ARCHIVE_MAP,
AdamW,
AutoConfig,
AutoModel,
@@ -20,15 +19,11 @@ from transformers import (
AutoTokenizer,
get_linear_schedule_with_warmup,
)
from transformers.modeling_auto import MODEL_MAPPING
logger = logging.getLogger(__name__)
ALL_MODELS = tuple(ALL_PRETRAINED_MODEL_ARCHIVE_MAP)
MODEL_CLASSES = tuple(m.model_type for m in MODEL_MAPPING)
MODEL_MODES = {
"base": AutoModel,
"sequence-classification": AutoModelForSequenceClassification,
@@ -51,28 +46,25 @@ class BaseTransformer(pl.LightningModule):
def __init__(self, hparams: argparse.Namespace, num_labels=None, mode="base", **config_kwargs):
"Initialize a model."
super(BaseTransformer, self).__init__()
super().__init__()
self.hparams = hparams
cache_dir = self.hparams.cache_dir if self.hparams.cache_dir else None
self.hparams.model_type = self.hparams.model_type.lower()
config = AutoConfig.from_pretrained(
self.config = AutoConfig.from_pretrained(
self.hparams.config_name if self.hparams.config_name else self.hparams.model_name_or_path,
**({"num_labels": num_labels} if num_labels is not None else {}),
cache_dir=cache_dir,
**config_kwargs,
)
tokenizer = AutoTokenizer.from_pretrained(
self.tokenizer = AutoTokenizer.from_pretrained(
self.hparams.tokenizer_name if self.hparams.tokenizer_name else self.hparams.model_name_or_path,
do_lower_case=self.hparams.do_lower_case,
cache_dir=cache_dir,
)
model = MODEL_MODES[mode].from_pretrained(
self.model = MODEL_MODES[mode].from_pretrained(
self.hparams.model_name_or_path,
from_tf=bool(".ckpt" in self.hparams.model_name_or_path),
config=config,
config=self.config,
cache_dir=cache_dir,
)
self.config, self.tokenizer, self.model = config, tokenizer, model
def is_logger(self):
return self.trainer.proc_rank <= 0
@@ -148,19 +140,12 @@ class BaseTransformer(pl.LightningModule):
@staticmethod
def add_model_specific_args(parser, root_dir):
parser.add_argument(
"--model_type",
default=None,
type=str,
required=True,
help="Model type selected in the list: " + ", ".join(MODEL_CLASSES),
)
parser.add_argument(
"--model_name_or_path",
default=None,
type=str,
required=True,
help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS),
help="Path to pretrained model or model identifier from huggingface.co/models",
)
parser.add_argument(
"--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name"
@@ -177,9 +162,6 @@ class BaseTransformer(pl.LightningModule):
type=str,
help="Where do you want to store the pre-trained models downloaded from s3",
)
parser.add_argument(
"--do_lower_case", action="store_true", help="Set this flag if you are using an uncased model."
)
parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.")
parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
@@ -252,8 +234,6 @@ def add_generic_args(parser, root_dir):
help="Number of updates steps to accumulate before performing a backward/update pass.",
)
parser.add_argument("--server_ip", type=str, default="", help="For distant debugging.")
parser.add_argument("--server_port", type=str, default="", help="For distant debugging.")
parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
@@ -261,15 +241,6 @@ def generic_train(model: BaseTransformer, args: argparse.Namespace):
# init model
set_seed(args)
# Setup distant debugging if needed
if args.server_ip and args.server_port:
# Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
import ptvsd
print("Waiting for debugger attach")
ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
ptvsd.wait_for_attach()
if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train:
raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
......
@@ -13,7 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Multiple choice fine-tuning: utilities to work with multiple choice tasks of reading comprehension """
""" Multiple choice fine-tuning: utilities to work with multiple choice tasks of reading comprehension """
import csv
@@ -21,48 +21,124 @@ import glob
import json
import logging
import os
from typing import List
from dataclasses import dataclass
from enum import Enum
from typing import List, Optional
import torch
import tqdm
from torch.utils.data.dataset import Dataset
from transformers import PreTrainedTokenizer
from transformers import PreTrainedTokenizer, torch_distributed_zero_first
logger = logging.getLogger(__name__)
class InputExample(object):
"""A single training/test example for multiple choice"""
def __init__(self, example_id, question, contexts, endings, label=None):
"""Constructs a InputExample.
Args:
example_id: Unique id for the example.
contexts: list of str. The untokenized text of the first sequence (context of corresponding question).
question: string. The untokenized text of the second sequence (question).
endings: list of str. multiple choice's options. Its length must be equal to contexts' length.
label: (Optional) string. The label of the example. This should be
specified for train and dev examples, but not for test examples.
"""
self.example_id = example_id
self.question = question
self.contexts = contexts
self.endings = endings
self.label = label
class InputFeatures(object):
def __init__(self, example_id, choices_features, label):
self.example_id = example_id
self.choices_features = [
{"input_ids": input_ids, "input_mask": input_mask, "segment_ids": segment_ids}
for input_ids, input_mask, segment_ids in choices_features
]
self.label = label
@dataclass(frozen=True)
class InputExample:
"""
A single training/test example for multiple choice
Args:
example_id: Unique id for the example.
question: string. The untokenized text of the second sequence (question).
contexts: list of str. The untokenized text of the first sequence (context of corresponding question).
endings: list of str. multiple choice's options. Its length must be equal to contexts' length.
label: (Optional) string. The label of the example. This should be
specified for train and dev examples, but not for test examples.
"""
example_id: str
question: str
contexts: List[str]
endings: List[str]
label: Optional[str]
@dataclass(frozen=True)
class InputFeatures:
"""
A single set of features of data.
Property names are the same names as the corresponding inputs to a model.
"""
example_id: str
input_ids: List[List[int]]
attention_mask: Optional[List[List[int]]]
token_type_ids: Optional[List[List[int]]]
label: Optional[int]
class Split(Enum):
train = "train"
dev = "dev"
test = "test"
class MultipleChoiceDataset(Dataset):
"""
This will be superseded by a framework-agnostic approach
soon.
"""
features: List[InputFeatures]
def __init__(
self,
data_dir: str,
tokenizer: PreTrainedTokenizer,
task: str,
max_seq_length: Optional[int] = None,
overwrite_cache=False,
mode: Split = Split.train,
local_rank=-1,
):
processor = processors[task]()
cached_features_file = os.path.join(
data_dir,
"cached_{}_{}_{}_{}".format(mode.value, tokenizer.__class__.__name__, str(max_seq_length), task,),
)
with torch_distributed_zero_first(local_rank):
# Make sure only the first process in distributed training processes the dataset,
# and the others will use the cache.
if os.path.exists(cached_features_file) and not overwrite_cache:
logger.info(f"Loading features from cached file {cached_features_file}")
self.features = torch.load(cached_features_file)
else:
logger.info(f"Creating features from dataset file at {data_dir}")
label_list = processor.get_labels()
if mode == Split.dev:
examples = processor.get_dev_examples(data_dir)
elif mode == Split.test:
examples = processor.get_test_examples(data_dir)
else:
examples = processor.get_train_examples(data_dir)
logger.info("Training examples: %s", len(examples))
# TODO clean up all this to leverage built-in features of tokenizers
self.features = convert_examples_to_features(
examples,
label_list,
max_seq_length,
tokenizer,
pad_on_left=bool(tokenizer.padding_side == "left"),
pad_token=tokenizer.pad_token_id,
pad_token_segment_id=tokenizer.pad_token_type_id,
)
if local_rank in [-1, 0]:
logger.info("Saving features into cached file %s", cached_features_file)
torch.save(self.features, cached_features_file)
class DataProcessor(object):
def __len__(self):
return len(self.features)
def __getitem__(self, i) -> InputFeatures:
return self.features[i]
class DataProcessor:
"""Base class for data converters for multiple choice data sets."""
def get_train_examples(self, data_dir):
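For reference, the torch_distributed_zero_first helper imported above is what serializes cache construction across processes; a sketch of the barrier pattern it implements (written here from the usage above, so treat it as illustrative):

from contextlib import contextmanager

import torch

@contextmanager
def torch_distributed_zero_first(local_rank: int):
    # Every process except the first waits here until rank 0 has
    # finished the guarded work (building and saving the feature cache).
    if local_rank not in [-1, 0]:
        torch.distributed.barrier()
    yield
    # Rank 0 reaches its barrier only after the work is done, releasing the others.
    if local_rank == 0:
        torch.distributed.barrier()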
@@ -311,7 +387,7 @@ def convert_examples_to_features(
for (ex_index, example) in tqdm.tqdm(enumerate(examples), desc="convert examples to features"):
if ex_index % 10000 == 0:
logger.info("Writing example %d of %d" % (ex_index, len(examples)))
choices_features = []
choices_inputs = []
for ending_idx, (context, ending) in enumerate(zip(example.contexts, example.endings)):
text_a = context
if example.question.find("_") != -1:
@@ -321,7 +397,7 @@
text_b = example.question + " " + ending
inputs = tokenizer.encode_plus(
text_a, text_b, add_special_tokens=True, max_length=max_length, return_token_type_ids=True
text_a, text_b, add_special_tokens=True, max_length=max_length, pad_to_max_length=True,
)
if "num_truncated_tokens" in inputs and inputs["num_truncated_tokens"] > 0:
logger.info(
@@ -330,41 +406,31 @@
"you need to try to use a bigger max seq length!"
)
input_ids, token_type_ids = inputs["input_ids"], inputs["token_type_ids"]
# The mask has 1 for real tokens and 0 for padding tokens. Only real
# tokens are attended to.
attention_mask = [1 if mask_padding_with_zero else 0] * len(input_ids)
choices_inputs.append(inputs)
# Zero-pad up to the sequence length.
padding_length = max_length - len(input_ids)
if pad_on_left:
input_ids = ([pad_token] * padding_length) + input_ids
attention_mask = ([0 if mask_padding_with_zero else 1] * padding_length) + attention_mask
token_type_ids = ([pad_token_segment_id] * padding_length) + token_type_ids
else:
input_ids = input_ids + ([pad_token] * padding_length)
attention_mask = attention_mask + ([0 if mask_padding_with_zero else 1] * padding_length)
token_type_ids = token_type_ids + ([pad_token_segment_id] * padding_length)
label = label_map[example.label]
assert len(input_ids) == max_length
assert len(attention_mask) == max_length
assert len(token_type_ids) == max_length
choices_features.append((input_ids, attention_mask, token_type_ids))
input_ids = [x["input_ids"] for x in choices_inputs]
attention_mask = (
[x["attention_mask"] for x in choices_inputs] if "attention_mask" in choices_inputs[0] else None
)
token_type_ids = (
[x["token_type_ids"] for x in choices_inputs] if "token_type_ids" in choices_inputs[0] else None
)
label = label_map[example.label]
features.append(
InputFeatures(
example_id=example.example_id,
input_ids=input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
label=label,
)
)
if ex_index < 2:
logger.info("*** Example ***")
logger.info("race_id: {}".format(example.example_id))
for choice_idx, (input_ids, attention_mask, token_type_ids) in enumerate(choices_features):
logger.info("choice: {}".format(choice_idx))
logger.info("input_ids: {}".format(" ".join(map(str, input_ids))))
logger.info("attention_mask: {}".format(" ".join(map(str, attention_mask))))
logger.info("token_type_ids: {}".format(" ".join(map(str, token_type_ids))))
logger.info("label: {}".format(label))
features.append(InputFeatures(example_id=example.example_id, choices_features=choices_features, label=label,))
for f in features[:2]:
logger.info("*** Example ***")
logger.info("feature: %s" % f)
return features
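The hand-rolled padding deleted above is now delegated to the tokenizer via pad_to_max_length=True; a small illustration of what encode_plus returns in that mode (checkpoint name assumed):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # assumed checkpoint
inputs = tokenizer.encode_plus(
    "the context", "the question ending", add_special_tokens=True, max_length=16, pad_to_max_length=True,
)
# Every returned sequence is already padded to max_length:
assert len(inputs["input_ids"]) == 16
assert len(inputs["attention_mask"]) == 16
assert len(inputs["token_type_ids"]) == 16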
......
@@ -31,6 +31,8 @@ from .benchmark_utils import (
start_memory_tracing,
stop_memory_tracing,
)
# Configurations
from .configuration_albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig
from .configuration_auto import ALL_PRETRAINED_CONFIG_ARCHIVE_MAP, CONFIG_MAPPING, AutoConfig
from .configuration_bart import BartConfig
@@ -46,8 +48,6 @@ from .configuration_openai import OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP, Open
from .configuration_roberta import ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, RobertaConfig
from .configuration_t5 import T5_PRETRAINED_CONFIG_ARCHIVE_MAP, T5Config
from .configuration_transfo_xl import TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP, TransfoXLConfig
# Configurations
from .configuration_utils import PretrainedConfig
from .configuration_xlm import XLM_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMConfig
from .configuration_xlm_roberta import XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMRobertaConfig
@@ -121,6 +121,8 @@ from .pipelines import (
TranslationPipeline,
pipeline,
)
# Tokenizers
from .tokenization_albert import AlbertTokenizer
from .tokenization_auto import TOKENIZER_MAPPING, AutoTokenizer
from .tokenization_bart import BartTokenizer, MBartTokenizer
@@ -136,8 +138,6 @@ from .tokenization_openai import OpenAIGPTTokenizer, OpenAIGPTTokenizerFast
from .tokenization_roberta import RobertaTokenizer, RobertaTokenizerFast
from .tokenization_t5 import T5Tokenizer
from .tokenization_transfo_xl import TransfoXLCorpus, TransfoXLTokenizer, TransfoXLTokenizerFast
# Tokenizers
from .tokenization_utils import PreTrainedTokenizer
from .tokenization_xlm import XLMTokenizer
from .tokenization_xlm_roberta import XLMRobertaTokenizer
@@ -162,6 +162,7 @@ if is_torch_available():
AutoModelForQuestionAnswering,
AutoModelWithLMHead,
AutoModelForTokenClassification,
AutoModelForMultipleChoice,
ALL_PRETRAINED_MODEL_ARCHIVE_MAP,
MODEL_MAPPING,
MODEL_FOR_PRETRAINING_MAPPING,
@@ -169,6 +170,7 @@ if is_torch_available():
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
MODEL_FOR_QUESTION_ANSWERING_MAPPING,
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
)
from .modeling_bert import (
@@ -320,6 +322,10 @@ if is_torch_available():
get_linear_schedule_with_warmup,
)
# Trainer
from .trainer import Trainer, set_seed, torch_distributed_zero_first, EvalPrediction
from .data.data_collator import DefaultDataCollator, DataCollator, DataCollatorForLanguageModeling
from .data.datasets import GlueDataset, TextDataset, LineByLineTextDataset, GlueDataTrainingArguments
# TensorFlow
if is_tf_available():
......
@@ -87,7 +87,7 @@ class PretrainedConfig(object):
self.architectures = kwargs.pop("architectures", None)
self.finetuning_task = kwargs.pop("finetuning_task", None)
self.num_labels = kwargs.pop("num_labels", 2)
self.id2label = kwargs.pop("id2label", {i: "LABEL_{}".format(i) for i in range(self.num_labels)})
self.id2label = kwargs.pop("id2label", {i: f"LABEL_{i}" for i in range(self.num_labels)})
self.id2label = dict((int(key), value) for key, value in self.id2label.items())
self.label2id = kwargs.pop("label2id", dict(zip(self.id2label.values(), self.id2label.keys())))
self.label2id = dict((key, int(value)) for key, value in self.label2id.items())
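A concrete illustration of the defaults these lines produce, assuming no explicit mappings are passed:

from transformers import PretrainedConfig

config = PretrainedConfig(num_labels=2)
assert config.id2label == {0: "LABEL_0", 1: "LABEL_1"}
assert config.label2id == {"LABEL_0": 0, "LABEL_1": 1}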
......
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, List, NewType, Tuple
import torch
from torch.nn.utils.rnn import pad_sequence
from ..tokenization_utils import PreTrainedTokenizer
class DataCollator(ABC):
"""
A `DataCollator` is responsible for batching
and pre-processing samples of data as requested by the training loop.
"""
@abstractmethod
def collate_batch(self) -> Dict[str, torch.Tensor]:
"""
Take a list of samples from a Dataset and collate them into a batch.
Returns:
A dictionary of tensors
"""
pass
InputDataClass = NewType("InputDataClass", Any)
@dataclass
class DefaultDataCollator(DataCollator):
"""
Very simple data collator that:
- simply collates batches of dict-like objects
- Performs special handling for potential keys named:
- `label`: handles a single value (int or float) per object
- `label_ids`: handles a list of values per object
- does not do any additional preprocessing
i.e., Property names of the input object will be used as corresponding inputs to the model.
See glue and ner for example of how it's useful.
"""
def collate_batch(self, features: List[InputDataClass]) -> Dict[str, torch.Tensor]:
# In this method we'll make the assumption that all `features` in the batch
# have the same attributes.
# So we will look at the first element as a proxy for what attributes exist
# on the whole batch.
first = features[0]
# Special handling for labels.
# Ensure that tensor is created with the correct type
# (it should be automatically the case, but let's make sure of it.)
if hasattr(first, "label") and first.label is not None:
if type(first.label) is int:
labels = torch.tensor([f.label for f in features], dtype=torch.long)
else:
labels = torch.tensor([f.label for f in features], dtype=torch.float)
batch = {"labels": labels}
elif hasattr(first, "label_ids") and first.label_ids is not None:
if type(first.label_ids[0]) is int:
labels = torch.tensor([f.label_ids for f in features], dtype=torch.long)
else:
labels = torch.tensor([f.label_ids for f in features], dtype=torch.float)
batch = {"labels": labels}
else:
batch = {}
# Handling of all other possible attributes.
# Again, we will use the first element to figure out which key/values are not None for this model.
for k, v in vars(first).items():
if k not in ("label", "label_ids") and v is not None and not isinstance(v, str):
batch[k] = torch.tensor([getattr(f, k) for f in features], dtype=torch.long)
return batch
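A quick sketch of the collator above in action; ToyFeature is a hypothetical stand-in for a real feature class such as InputFeatures:

from dataclasses import dataclass
from typing import List

from transformers import DefaultDataCollator  # exported by this commit

@dataclass
class ToyFeature:  # hypothetical feature class
    input_ids: List[int]
    label: int

collator = DefaultDataCollator()
batch = collator.collate_batch([ToyFeature([1, 2, 3], 0), ToyFeature([4, 5, 6], 1)])
# batch == {"labels": tensor([0, 1]), "input_ids": tensor([[1, 2, 3], [4, 5, 6]])}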
@dataclass
class DataCollatorForLanguageModeling(DataCollator):
"""
Data collator used for language modeling.
- collates batches of tensors, honoring their tokenizer's pad_token
- preprocesses batches for masked language modeling
"""
tokenizer: PreTrainedTokenizer
mlm: bool = True
mlm_probability: float = 0.15
def collate_batch(self, examples: List[torch.Tensor]) -> Dict[str, torch.Tensor]:
batch = self._tensorize_batch(examples)
if self.mlm:
inputs, labels = self.mask_tokens(batch)
return {"input_ids": inputs, "masked_lm_labels": labels}
else:
return {"input_ids": batch, "labels": batch}
def _tensorize_batch(self, examples: List[torch.Tensor]) -> torch.Tensor:
length_of_first = examples[0].size(0)
are_tensors_same_length = all(x.size(0) == length_of_first for x in examples)
if are_tensors_same_length:
return torch.stack(examples, dim=0)
else:
if self.tokenizer._pad_token is None:
raise ValueError(
"You are attempting to pad samples but the tokenizer you are using"
f" ({self.tokenizer.__class__.__name__}) does not have one."
)
return pad_sequence(examples, batch_first=True, padding_value=self.tokenizer.pad_token_id)
def mask_tokens(self, inputs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.
"""
if self.tokenizer.mask_token is None:
raise ValueError(
"This tokenizer does not have a mask token which is necessary for masked language modeling. Remove the --mlm flag if you want to use this tokenizer."
)
labels = inputs.clone()
# We sample a few tokens in each sequence for masked-LM training (with probability args.mlm_probability defaults to 0.15 in Bert/RoBERTa)
probability_matrix = torch.full(labels.shape, self.mlm_probability)
special_tokens_mask = [
self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
]
probability_matrix.masked_fill_(torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0)
if self.tokenizer._pad_token is not None:
padding_mask = labels.eq(self.tokenizer.pad_token_id)
probability_matrix.masked_fill_(padding_mask, value=0.0)
masked_indices = torch.bernoulli(probability_matrix).bool()
labels[~masked_indices] = -100 # We only compute loss on masked tokens
# 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)
# 10% of the time, we replace masked input tokens with random word
indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long)
inputs[indices_random] = random_words[indices_random]
# The rest of the time (10% of the time) we keep the masked input tokens unchanged
return inputs, labels
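And the language-modeling collator in use, padding uneven examples and producing masked inputs plus labels (checkpoint name assumed):

import torch
from transformers import AutoTokenizer, DataCollatorForLanguageModeling  # exported by this commit

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # assumed checkpoint
collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True, mlm_probability=0.15)
examples = [
    torch.tensor(tokenizer.encode("Hello world"), dtype=torch.long),
    torch.tensor(tokenizer.encode("A slightly longer example sentence"), dtype=torch.long),
]
batch = collator.collate_batch(examples)
# batch["input_ids"]: padded to a common length, ~15% of positions altered per the 80/10/10 rule above
# batch["masked_lm_labels"]: -100 everywhere except the sampled positions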
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.
from .glue import GlueDataset, GlueDataTrainingArguments
from .language_modeling import LineByLineTextDataset, TextDataset
@@ -17,6 +17,7 @@
import logging
import os
from enum import Enum
from typing import List, Optional, Union
from ...file_utils import is_tf_available
@@ -153,6 +154,11 @@ def _glue_convert_examples_to_features(
return features
class OutputMode(Enum):
classification = "classification"
regression = "regression"
class MrpcProcessor(DataProcessor):
"""Processor for the MRPC data set (GLUE version)."""
......
@@ -5,8 +5,7 @@ from dataclasses import dataclass, field
from enum import Enum
from typing import Optional
from transformers.hf_argparser import HfArgumentParser
from transformers.training_args import TrainingArguments
from transformers import HfArgumentParser, TrainingArguments
@dataclass
......