"...git@developer.sourcefind.cn:chenpangpang/open-webui.git" did not exist on "8d99bffbdcb1db3496fa64c92fe6fe4009b524e1"
Commit ea9dbea9 authored by thomwolf

update GPT2 loss computation for more flexibility

parent ce863365
@@ -336,6 +336,7 @@ class GPT2MultipleChoiceHead(nn.Module):
         # (bsz, num_choices, 1, hidden_size)
         multiple_choice_h = hidden_states.gather(2, mc_token_ids).squeeze(2)
         # (bsz, num_choices, hidden_size)
+        multiple_choice_h = self.dropout(multiple_choice_h.transpose(1, 2)).transpose(1, 2)
         multiple_choice_logits = self.linear(multiple_choice_h).squeeze(-1)
         # (bsz, num_choices)
         return multiple_choice_logits
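The added line routes the pooled classification-token states through dropout before the linear scorer. A minimal standalone sketch of this head's pooling path, with illustrative shapes and a plain nn.Dropout stand-in (the actual dropout module and its probability are defined outside this hunk, so both are assumptions here):

```python
import torch
import torch.nn as nn

bsz, num_choices, seq_len, hidden_size = 2, 4, 16, 8
hidden_states = torch.randn(bsz, num_choices, seq_len, hidden_size)

# Position of the classification token in each choice, expanded so that
# gather() picks exactly one hidden state per choice (hypothetical values).
mc_token_ids = torch.randint(0, seq_len, (bsz, num_choices))
index = mc_token_ids.unsqueeze(-1).unsqueeze(-1).expand(-1, -1, 1, hidden_size)

dropout = nn.Dropout(0.1)           # stand-in module; probability is an assumption
linear = nn.Linear(hidden_size, 1)

# (bsz, num_choices, 1, hidden_size) -> (bsz, num_choices, hidden_size)
multiple_choice_h = hidden_states.gather(2, index).squeeze(2)
# The transpose pair only matters if the real module is a channel-wise
# dropout (e.g. nn.Dropout2d), which would then zero whole hidden units
# consistently across choices; for elementwise dropout it is a no-op.
multiple_choice_h = dropout(multiple_choice_h.transpose(1, 2)).transpose(1, 2)
multiple_choice_logits = linear(multiple_choice_h).squeeze(-1)  # (bsz, num_choices)
```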
@@ -665,9 +666,8 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
         lm_logits = self.lm_head(hidden_states)
         if lm_labels is not None:
             # Shift so that tokens < n predict n
-            shift_logits = lm_logits[:, :-1].contiguous()
-            shift_labels = lm_labels[:, 1:].contiguous()
+            shift_logits = lm_logits[..., :-1, :].contiguous()
+            shift_labels = lm_labels[..., 1:].contiguous()
             # Flatten the tokens
             loss_fct = CrossEntropyLoss(ignore_index=-1)
             loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
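The point of switching from `[:, :-1]` to `[..., :-1, :]` is that the shift now always acts on the sequence axis, however many leading batch-like dimensions the logits carry. A quick check under assumed shapes (the helper name `shifted_lm_loss` is hypothetical):

```python
import torch
from torch.nn import CrossEntropyLoss

def shifted_lm_loss(lm_logits, lm_labels):
    # Shift so that tokens < n predict n, acting on the last two axes
    # regardless of how many leading batch-like dimensions there are.
    shift_logits = lm_logits[..., :-1, :].contiguous()
    shift_labels = lm_labels[..., 1:].contiguous()
    loss_fct = CrossEntropyLoss(ignore_index=-1)
    return loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
                    shift_labels.view(-1))

vocab = 50
# (bsz, seq_len, vocab) as in GPT2LMHeadModel
loss_a = shifted_lm_loss(torch.randn(2, 16, vocab),
                         torch.randint(0, vocab, (2, 16)))
# (bsz, num_choices, seq_len, vocab) as in GPT2DoubleHeadsModel
loss_b = shifted_lm_loss(torch.randn(2, 4, 16, vocab),
                         torch.randint(0, vocab, (2, 4, 16)))
# The old slicing lm_logits[:, :-1] would have cut the last *choice*,
# not the last *token*, in the 4-D case.
```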
@@ -746,11 +746,10 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
         mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids)
         losses = []
         if lm_labels is not None:
-            shift_logits = lm_logits[:, :-1].contiguous()
-            shift_labels = lm_labels[:, 1:].contiguous()
+            shift_logits = lm_logits[..., :-1, :].contiguous()
+            shift_labels = lm_labels[..., 1:].contiguous()
             loss_fct = CrossEntropyLoss(ignore_index=-1)
-            losses.append(loss_fct(shift_logits.view(-1,
-                                   shift_logits.size(-1)), shift_labels.view(-1)))
+            losses.append(loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)))
         if mc_labels is not None:
             loss_fct = CrossEntropyLoss()
             losses.append(loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1)))
...
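The multiple-choice branch of the hunk above is unchanged in substance: mc_logits are (bsz, num_choices) and mc_labels index the correct choice, so a plain CrossEntropyLoss (no ignore_index) applies. A self-contained sketch with assumed shapes:

```python
import torch
from torch.nn import CrossEntropyLoss

bsz, num_choices = 2, 4
mc_logits = torch.randn(bsz, num_choices)          # one score per choice
mc_labels = torch.randint(0, num_choices, (bsz,))  # index of the correct choice

loss_fct = CrossEntropyLoss()
mc_loss = loss_fct(mc_logits.view(-1, mc_logits.size(-1)), mc_labels.view(-1))
# In the model this lands in `losses` alongside the LM loss; returning both
# in a list lets callers weight the two objectives themselves.
```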