"docs/vscode:/vscode.git/clone" did not exist on "3095ee9dab739f212a8753b5be4e1a72ba42e28e"
Commit 511bce58 authored by thomwolf's avatar thomwolf
Browse files

update new token classification model

parent 258eb500
...@@ -932,8 +932,8 @@ class BertForTokenClassification(PreTrainedBertModel): ...@@ -932,8 +932,8 @@ class BertForTokenClassification(PreTrainedBertModel):
def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None): def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None):
sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False) sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)
pooled_output = self.dropout(sequence_output) sequence_output = self.dropout(sequence_output)
logits = self.classifier(pooled_output) logits = self.classifier(sequence_output)
if labels is not None: if labels is not None:
loss_fct = CrossEntropyLoss() loss_fct = CrossEntropyLoss()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment