Commit 1113f97f authored by thomwolf's avatar thomwolf
Browse files

clean up glue example

parent 162ba383
......@@ -309,14 +309,7 @@ def main():
# define a new function to compute loss values for both output_modes
ouputs = model(input_ids, token_type_ids=segment_ids, attention_mask=input_mask, labels=label_ids)
loss =
if output_mode == "classification":
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, num_labels), label_ids.view(-1))
elif output_mode == "regression":
loss_fct = MSELoss()
loss = loss_fct(logits.view(-1), label_ids.view(-1))
loss = ouputs[0]
if n_gpu > 1:
loss = loss.mean() # mean() to average on multi-gpu.
......@@ -423,15 +416,8 @@ def main():
label_ids = label_ids.to(device)
with torch.no_grad():
logits = model(input_ids, token_type_ids=segment_ids, attention_mask=input_mask)
# create eval loss and other metric required by the task
if output_mode == "classification":
loss_fct = CrossEntropyLoss()
tmp_eval_loss = loss_fct(logits.view(-1, num_labels), label_ids.view(-1))
elif output_mode == "regression":
loss_fct = MSELoss()
tmp_eval_loss = loss_fct(logits.view(-1), label_ids.view(-1))
outputs = model(input_ids, token_type_ids=segment_ids, attention_mask=input_mask, labels=label_ids)
tmp_eval_loss, logits = outputs[:2]
eval_loss += tmp_eval_loss.mean().item()
nb_eval_steps += 1
......
This diff is collapsed.
......@@ -583,6 +583,7 @@ processors = {
output_modes = {
"cola": "classification",
"mnli": "classification",
"mnli-mm": "classification",
"mrpc": "classification",
"sst-2": "classification",
"sts-b": "regression",
......
......@@ -110,6 +110,24 @@ class PreTrainedTokenizer(object):
return tokenizer
def tokenize(self, text):
raise NotImplementedError
def convert_tokens_to_ids(self, tokens):
raise NotImplementedError
def convert_ids_to_tokens(self, ids):
raise NotImplementedError
def encode(self, text):
raise NotImplementedError
def decode(self, token_ids, *input, **kwargs):
raise NotImplementedError
def save_vocabulary(self, vocab_path):
raise NotImplementedError
def clean_up_tokenization(out_string):
out_string.replace(' .', '.').replace(' ?', '?').replace(' !', '!').replace(' ,', ','
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment