Commit 762ded9b authored by thomwolf

wip examples

parent 74429563
@@ -13,7 +13,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""Run BERT on SQuAD."""
+""" Finetuning a question-answering Bert model on SQuAD."""
 from __future__ import absolute_import, division, print_function
...
@@ -13,7 +13,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""BERT finetuning runner."""
+""" Finetuning a classification model (Bert, XLM, XLNet,...) on GLUE."""
 from __future__ import absolute_import, division, print_function
...
@@ -13,7 +13,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""Run BERT on SQuAD."""
+""" Finetuning a question-answering model (Bert, XLM, XLNet,...) on SQuAD."""
 from __future__ import absolute_import, division, print_function
@@ -21,7 +21,6 @@ import argparse
 import logging
 import os
 import random
-import sys
 from io import open
 import numpy as np
@@ -33,31 +32,35 @@ from tqdm import tqdm, trange
 from tensorboardX import SummaryWriter
-from pytorch_transformers import (BertForQuestionAnswering, XLNetForQuestionAnswering,
-                                  XLMForQuestionAnswering, BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
-                                  XLNET_PRETRAINED_MODEL_ARCHIVE_MAP, XLM_PRETRAINED_MODEL_ARCHIVE_MAP)
-from pytorch_transformers import (BertTokenizer, XLNetTokenizer,
-                                  XLMTokenizer)
+from pytorch_transformers import (WEIGHTS_NAME, BertConfig,
+                                  BertForQuestionAnswering, BertTokenizer,
+                                  XLMConfig, XLMForQuestionAnswering,
+                                  XLMTokenizer, XLNetConfig,
+                                  XLNetForQuestionAnswering,
+                                  XLNetTokenizer)
+from pytorch_transformers import AdamW, WarmupLinearSchedule
 from utils_squad import read_squad_examples, convert_examples_to_features, RawResult, write_predictions
 logger = logging.getLogger(__name__)
-ALL_MODELS = sum((tuple(m.keys()) for m in (BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
-                                            XLNET_PRETRAINED_MODEL_ARCHIVE_MAP,
-                                            XLM_PRETRAINED_MODEL_ARCHIVE_MAP)), ())
+ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) \
+                  for conf in (BertConfig, XLNetConfig, XLMConfig)), ())
 MODEL_CLASSES = {
-    'bert': BertForQuestionAnswering,
-    'xlnet': XLNetForQuestionAnswering,
-    'xlm': XLMForQuestionAnswering,
+    'bert': (BertConfig, BertForQuestionAnswering, BertTokenizer),
+    'xlnet': (XLNetConfig, XLNetForQuestionAnswering, XLNetTokenizer),
+    'xlm': (XLMConfig, XLMForQuestionAnswering, XLMTokenizer),
 }
-TOKENIZER_CLASSES = {
-    'bert': BertTokenizer,
-    'xlnet': XLNetTokenizer,
-    'xlm': XLMTokenizer,
-}
+def set_seed(args):
+    random.seed(args.seed)
+    np.random.seed(args.seed)
+    torch.manual_seed(args.seed)
+    if args.n_gpu > 0:
+        torch.cuda.manual_seed_all(args.seed)
 def train(args, train_dataset, model):
     """ Train the model """
...
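
The refactor in the last hunk collapses the separate model and tokenizer dictionaries into a single MODEL_CLASSES map of (config, model, tokenizer) tuples. Below is a minimal sketch of how such a map is typically consumed, assuming the pytorch_transformers 1.0 from_pretrained API; the 'bert-base-uncased' checkpoint and the do_lower_case flag are illustrative choices, not part of this commit.

# Sketch only, not part of this commit: unpack a (config, model, tokenizer)
# tuple from MODEL_CLASSES and instantiate all three from one checkpoint name.
from pytorch_transformers import (BertConfig, BertForQuestionAnswering, BertTokenizer,
                                  XLMConfig, XLMForQuestionAnswering, XLMTokenizer,
                                  XLNetConfig, XLNetForQuestionAnswering, XLNetTokenizer)

MODEL_CLASSES = {
    'bert': (BertConfig, BertForQuestionAnswering, BertTokenizer),
    'xlnet': (XLNetConfig, XLNetForQuestionAnswering, XLNetTokenizer),
    'xlm': (XLMConfig, XLMForQuestionAnswering, XLMTokenizer),
}

model_type, model_name = 'bert', 'bert-base-uncased'   # illustrative checkpoint
config_class, model_class, tokenizer_class = MODEL_CLASSES[model_type]

# One lookup yields matching config, tokenizer and model classes.
config = config_class.from_pretrained(model_name)
tokenizer = tokenizer_class.from_pretrained(model_name, do_lower_case=True)
model = model_class.from_pretrained(model_name, config=config)

Keying the tuple on a single model_type string is what lets one script drive Bert, XLNet and XLM variants from a single command-line switch rather than maintaining parallel dictionaries per class.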
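
The same hunk also adds a set_seed helper and imports AdamW and WarmupLinearSchedule for the rewritten train() function. The following is a hedged sketch of the optimizer/scheduler setup these two classes support; the build_optimizer name, the weight-decay grouping, and all default hyperparameters are assumptions for illustration, not the commit's exact code.

# Hedged sketch: build_optimizer and its defaults are illustrative, not taken
# from this commit. It wires AdamW and WarmupLinearSchedule together in the
# way the newly added imports are typically used.
from pytorch_transformers import AdamW, WarmupLinearSchedule

def build_optimizer(model, learning_rate=3e-5, adam_epsilon=1e-8,
                    weight_decay=0.01, warmup_steps=0, t_total=1000):
    # Common convention: exclude biases and LayerNorm weights from weight decay.
    no_decay = ['bias', 'LayerNorm.weight']
    grouped_parameters = [
        {'params': [p for n, p in model.named_parameters()
                    if not any(nd in n for nd in no_decay)],
         'weight_decay': weight_decay},
        {'params': [p for n, p in model.named_parameters()
                    if any(nd in n for nd in no_decay)],
         'weight_decay': 0.0},
    ]
    optimizer = AdamW(grouped_parameters, lr=learning_rate, eps=adam_epsilon)
    # Linear warmup over warmup_steps, then linear decay to zero at t_total steps.
    scheduler = WarmupLinearSchedule(optimizer, warmup_steps=warmup_steps, t_total=t_total)
    return optimizer, scheduler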