"vscode:/vscode.git/clone" did not exist on "b406c4d2611d9425dabf927d3de0fdf7981de2cb"
Unverified commit 14373821, authored by StevenTang1998 and committed by GitHub
Browse files

Fix the loss calculation of ProphetNet (#13132)

* Fix the loss calculation of ProphetNet

* Fix the loss calculation of ProphetNet

Fix the loss calculation of ProphetNet and remove warning
parent 91ff480e
......@@ -1812,14 +1812,6 @@ class ProphetNetModel(ProphetNetPreTrainedModel):
>>> last_hidden_states = outputs.last_hidden_state # main stream hidden states
>>> last_hidden_states_ngram = outputs.last_hidden_state_ngram # predict hidden states
"""
if self.training:
logger.warning(
"There is a known issue with ProphetNet training/fine-tuning that hasn't been fixed yet:"
"https://github.com/huggingface/transformers/issues/9804. Please try to use an off-the-shelf"
"checkpoint from the model hub or fine-tune another architecture instead."
)
# Fall back to the model config when the caller did not pass use_cache explicitly.
use_cache = use_cache if use_cache is not None else self.config.use_cache
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
......@@ -2006,6 +1998,7 @@ class ProphetNetForConditionalGeneration(ProphetNetPreTrainedModel):
break
expend_targets[i, :, :] = labels
logits = logits.transpose(0, 1).contiguous()
lprobs = nn.functional.log_softmax(
logits.view(-1, logits.size(-1)),
dim=-1,
......@@ -2250,6 +2243,7 @@ class ProphetNetForCausalLM(ProphetNetPreTrainedModel):
break
expend_targets[i, :, :] = labels
logits = logits.transpose(0, 1).contiguous()
lprobs = nn.functional.log_softmax(
logits.view(-1, logits.size(-1)),
dim=-1,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment