Commit 2d07f945 authored by VictorSanh's avatar VictorSanh
Browse files

fix error with torch.no_grad and loss computation

parent 6b8d2270
......@@ -238,8 +238,7 @@ def bertForSequenceClassification(*args, **kwargs):
seq_classif_logits = model(tokens_tensor, segments_tensors)
# Or get the sequence classification loss
>>> labels = torch.tensor([1])
>>> with torch.no_grad():
seq_classif_loss = model(tokens_tensor, segments_tensors, labels=labels)
>>> seq_classif_loss = model(tokens_tensor, segments_tensors, labels=labels)
"""
model = BertForSequenceClassification.from_pretrained(*args, **kwargs)
return model
......@@ -273,8 +272,7 @@ def bertForMultipleChoice(*args, **kwargs):
multiple_choice_logits = model(tokens_tensor, segments_tensors)
# Or get the multiple choice loss
>>> labels = torch.tensor([1])
>>> with torch.no_grad():
multiple_choice_loss = model(tokens_tensor, segments_tensors, labels=labels)
>>> multiple_choice_loss = model(tokens_tensor, segments_tensors, labels=labels)
"""
model = BertForMultipleChoice.from_pretrained(*args, **kwargs)
return model
......@@ -306,8 +304,7 @@ def bertForQuestionAnswering(*args, **kwargs):
start_logits, end_logits = model(tokens_tensor, segments_tensors)
# Or get the total loss which is the sum of the CrossEntropy loss for the start and end token positions
>>> start_positions, end_positions = torch.tensor([12]), torch.tensor([14])
>>> with torch.no_grad():
multiple_choice_loss = model(tokens_tensor, segments_tensors, start_positions=start_positions, end_positions=end_positions)
>>> multiple_choice_loss = model(tokens_tensor, segments_tensors, start_positions=start_positions, end_positions=end_positions)
"""
model = BertForQuestionAnswering.from_pretrained(*args, **kwargs)
return model
......@@ -344,8 +341,7 @@ def bertForTokenClassification(*args, **kwargs):
classif_logits = model(tokens_tensor, segments_tensors)
# Or get the token classification loss
>>> labels = torch.tensor([[0, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 1, 0, 0, 1, 0]])
>>> with torch.no_grad():
classif_loss = model(tokens_tensor, segments_tensors, labels=labels)
>>> classif_loss = model(tokens_tensor, segments_tensors, labels=labels)
"""
model = BertForTokenClassification.from_pretrained(*args, **kwargs)
return model
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment