Commit 762ded9b authored by thomwolf
Browse files

wip examples

parent 74429563
......@@ -13,7 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Run BERT on SQuAD."""
""" Finetuning a question-answering Bert model on SQuAD."""
from __future__ import absolute_import, division, print_function
......
......@@ -13,7 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""BERT finetuning runner."""
""" Finetuning a classification model (Bert, XLM, XLNet,...) on GLUE."""
from __future__ import absolute_import, division, print_function
......
......@@ -13,7 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Run BERT on SQuAD."""
""" Finetuning a question-answering model (Bert, XLM, XLNet,...) on SQuAD."""
from __future__ import absolute_import, division, print_function
......@@ -21,7 +21,6 @@ import argparse
import logging
import os
import random
import sys
from io import open
import numpy as np
......@@ -33,31 +32,35 @@ from tqdm import tqdm, trange
from tensorboardX import SummaryWriter
from pytorch_transformers import (BertForQuestionAnswering, XLNetForQuestionAnswering,
XLMForQuestionAnswering, BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
XLNET_PRETRAINED_MODEL_ARCHIVE_MAP, XLM_PRETRAINED_MODEL_ARCHIVE_MAP)
from pytorch_transformers import (BertTokenizer, XLNetTokenizer,
XLMTokenizer)
from pytorch_transformers import (WEIGHTS_NAME, BertConfig,
BertForQuestionAnswering, BertTokenizer,
XLMConfig, XLMForQuestionAnswering,
XLMTokenizer, XLNetConfig,
XLNetForQuestionAnswering,
XLNetTokenizer)
from pytorch_transformers import AdamW, WarmupLinearSchedule
from utils_squad import read_squad_examples, convert_examples_to_features, RawResult, write_predictions
logger = logging.getLogger(__name__)
ALL_MODELS = sum((tuple(m.keys()) for m in (BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
XLNET_PRETRAINED_MODEL_ARCHIVE_MAP,
XLM_PRETRAINED_MODEL_ARCHIVE_MAP)), ())
ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) \
for conf in (BertConfig, XLNetConfig, XLMConfig)), ())
MODEL_CLASSES = {
'bert': BertForQuestionAnswering,
'xlnet': XLNetForQuestionAnswering,
'xlm': XLMForQuestionAnswering,
'bert': (BertConfig, BertForQuestionAnswering, BertTokenizer),
'xlnet': (XLNetConfig, XLNetForQuestionAnswering, XLNetTokenizer),
'xlm': (XLMConfig, XLMForQuestionAnswering, XLMTokenizer),
}
TOKENIZER_CLASSES = {
'bert': BertTokenizer,
'xlnet': XLNetTokenizer,
'xlm': XLMTokenizer,
}
def set_seed(args):
    """Seed every random number generator used by this script.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    with the same value so that repeated runs draw the same random numbers.

    Args:
        args: object exposing ``seed`` (the value handed to each generator)
            and ``n_gpu`` (the CUDA generators are seeded only when this is
            greater than zero).
    """
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    # Only touch the CUDA generators when at least one GPU is in use.
    if args.n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)
def train(args, train_dataset, model):
""" Train the model """
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment