Unverified Commit 3611fc90 authored by Nate Brake's avatar Nate Brake Committed by GitHub
Browse files

compute_loss in trainer failing to label shift for PEFT model when label...


compute_loss in trainer failing to label shift for PEFT model when label smoothing enabled. (#25044)

* added PeftModelForCausalLM to MODEL_FOR_CAUSAL_LM_MAPPING_NAMES dict

* check for PEFT model in compute_loss section

---------
Co-authored-by: default avatarNathan Brake <nbrake3@mmm.com>
parent a03d13c8
...@@ -2677,7 +2677,11 @@ class Trainer: ...@@ -2677,7 +2677,11 @@ class Trainer:
self._past = outputs[self.args.past_index] self._past = outputs[self.args.past_index]
if labels is not None: if labels is not None:
if unwrap_model(model)._get_name() in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values(): if is_peft_available() and isinstance(model, PeftModel):
model_name = unwrap_model(model.base_model)._get_name()
else:
model_name = unwrap_model(model)._get_name()
if model_name in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values():
loss = self.label_smoother(outputs, labels, shift_labels=True) loss = self.label_smoother(outputs, labels, shift_labels=True)
else: else:
loss = self.label_smoother(outputs, labels) loss = self.label_smoother(outputs, labels)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment