Commit 40dbda68 authored by thomwolf's avatar thomwolf
Browse files

updating classification example

parent 7388c83b
...@@ -228,10 +228,10 @@ def main(): ...@@ -228,10 +228,10 @@ def main():
# Prepare data loader # Prepare data loader
train_examples = processor.get_train_examples(args.data_dir) train_examples = processor.get_train_examples(args.data_dir)
cached_train_features_file = args.data_dir + '_{0}_{1}_{2}'.format( cached_train_features_file = os.path.join(args.data_dir, 'train_{0}_{1}_{2}'.format(
list(filter(None, args.bert_model.split('/'))).pop(), list(filter(None, args.bert_model.split('/'))).pop(),
str(args.max_seq_length), str(args.max_seq_length),
str(task_name)) str(task_name)))
try: try:
with open(cached_train_features_file, "rb") as reader: with open(cached_train_features_file, "rb") as reader:
train_features = pickle.load(reader) train_features = pickle.load(reader)
...@@ -311,7 +311,7 @@ def main(): ...@@ -311,7 +311,7 @@ def main():
input_ids, input_mask, segment_ids, label_ids = batch input_ids, input_mask, segment_ids, label_ids = batch
# define a new function to compute loss values for both output_modes # define a new function to compute loss values for both output_modes
logits = model(input_ids, segment_ids, input_mask, labels=None) logits = model(input_ids, segment_ids, input_mask)
if output_mode == "classification": if output_mode == "classification":
loss_fct = CrossEntropyLoss() loss_fct = CrossEntropyLoss()
...@@ -380,6 +380,22 @@ def main(): ...@@ -380,6 +380,22 @@ def main():
### Evaluation ### Evaluation
if args.do_eval: if args.do_eval:
eval_examples = processor.get_dev_examples(args.data_dir) eval_examples = processor.get_dev_examples(args.data_dir)
cached_train_features_file = os.path.join(args.data_dir, 'dev_{0}_{1}_{2}'.format(
list(filter(None, args.bert_model.split('/'))).pop(),
str(args.max_seq_length),
str(task_name)))
try:
with open(cached_train_features_file, "rb") as reader:
train_features = pickle.load(reader)
except:
train_features = convert_examples_to_features(
train_examples, label_list, args.max_seq_length, tokenizer, output_mode)
if args.local_rank == -1 or torch.distributed.get_rank() == 0:
logger.info(" Saving train features into cached file %s", cached_train_features_file)
with open(cached_train_features_file, "wb") as writer:
pickle.dump(train_features, writer)
eval_features = convert_examples_to_features( eval_features = convert_examples_to_features(
eval_examples, label_list, args.max_seq_length, tokenizer, output_mode) eval_examples, label_list, args.max_seq_length, tokenizer, output_mode)
logger.info("***** Running evaluation *****") logger.info("***** Running evaluation *****")
...@@ -414,7 +430,7 @@ def main(): ...@@ -414,7 +430,7 @@ def main():
label_ids = label_ids.to(device) label_ids = label_ids.to(device)
with torch.no_grad(): with torch.no_grad():
logits = model(input_ids, segment_ids, input_mask, labels=None) logits = model(input_ids, segment_ids, input_mask)
# create eval loss and other metric required by the task # create eval loss and other metric required by the task
if output_mode == "classification": if output_mode == "classification":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment