Unverified commit 14373821, authored by StevenTang1998 and committed by GitHub
Browse files

Fix the loss calculation of ProphetNet (#13132)

* Fix the loss calculation of ProphetNet

* Fix the loss calculation of ProphetNet

Fix the loss calculation of ProphetNet and remove warning
parent 91ff480e
...@@ -1812,14 +1812,6 @@ class ProphetNetModel(ProphetNetPreTrainedModel): ...@@ -1812,14 +1812,6 @@ class ProphetNetModel(ProphetNetPreTrainedModel):
>>> last_hidden_states = outputs.last_hidden_state # main stream hidden states >>> last_hidden_states = outputs.last_hidden_state # main stream hidden states
>>> last_hidden_states_ngram = outputs.last_hidden_state_ngram # predict hidden states >>> last_hidden_states_ngram = outputs.last_hidden_state_ngram # predict hidden states
""" """
if self.training:
logger.warning(
"There is a known issue with ProphetNet training/fine-tuning that hasn't been fixed yet:"
"https://github.com/huggingface/transformers/issues/9804. Please try to use an off-the-shelf"
"checkpoint from the model hub or fine-tune another architecture instead."
)
use_cache == use_cache if use_cache is not None else self.config.use_cache use_cache == use_cache if use_cache is not None else self.config.use_cache
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = ( output_hidden_states = (
...@@ -2006,6 +1998,7 @@ class ProphetNetForConditionalGeneration(ProphetNetPreTrainedModel): ...@@ -2006,6 +1998,7 @@ class ProphetNetForConditionalGeneration(ProphetNetPreTrainedModel):
break break
expend_targets[i, :, :] = labels expend_targets[i, :, :] = labels
logits = logits.transpose(0, 1).contiguous()
lprobs = nn.functional.log_softmax( lprobs = nn.functional.log_softmax(
logits.view(-1, logits.size(-1)), logits.view(-1, logits.size(-1)),
dim=-1, dim=-1,
...@@ -2250,6 +2243,7 @@ class ProphetNetForCausalLM(ProphetNetPreTrainedModel): ...@@ -2250,6 +2243,7 @@ class ProphetNetForCausalLM(ProphetNetPreTrainedModel):
break break
expend_targets[i, :, :] = labels expend_targets[i, :, :] = labels
logits = logits.transpose(0, 1).contiguous()
lprobs = nn.functional.log_softmax( lprobs = nn.functional.log_softmax(
logits.view(-1, logits.size(-1)), logits.view(-1, logits.size(-1)),
dim=-1, dim=-1,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment