Unverified Commit 552ff244 authored by Komal Kumar's avatar Komal Kumar Committed by GitHub
Browse files

Fixed base model class name extraction from PeftModels (#27162)

* Fixed base model class name extraction from PeftModels

* Changes to first unwrap the model then extract the base model name

* Changed base_model to base_model.model to stay consistent with peft model abstractions
parent 49912168
...@@ -646,7 +646,7 @@ class Trainer: ...@@ -646,7 +646,7 @@ class Trainer:
unwrapped_model = unwrap_model(model) unwrapped_model = unwrap_model(model)
if is_peft_available() and isinstance(unwrapped_model, PeftModel): if is_peft_available() and isinstance(unwrapped_model, PeftModel):
embeddings = unwrapped_model.base_model.get_input_embeddings() embeddings = unwrapped_model.base_model.model.get_input_embeddings()
else: else:
embeddings = unwrapped_model.get_input_embeddings() embeddings = unwrapped_model.get_input_embeddings()
...@@ -667,7 +667,7 @@ class Trainer: ...@@ -667,7 +667,7 @@ class Trainer:
unwrapped_model = unwrap_model(model) unwrapped_model = unwrap_model(model)
if is_peft_available() and isinstance(unwrapped_model, PeftModel): if is_peft_available() and isinstance(unwrapped_model, PeftModel):
embeddings = unwrapped_model.base_model.get_input_embeddings() embeddings = unwrapped_model.base_model.model.get_input_embeddings()
else: else:
embeddings = unwrapped_model.get_input_embeddings() embeddings = unwrapped_model.get_input_embeddings()
...@@ -2752,10 +2752,11 @@ class Trainer: ...@@ -2752,10 +2752,11 @@ class Trainer:
self._past = outputs[self.args.past_index] self._past = outputs[self.args.past_index]
if labels is not None: if labels is not None:
if is_peft_available() and isinstance(model, PeftModel): unwrapped_model = unwrap_model(model)
model_name = unwrap_model(model.base_model)._get_name() if is_peft_available() and isinstance(unwrapped_model, PeftModel):
model_name = unwrapped_model.base_model.model._get_name()
else: else:
model_name = unwrap_model(model)._get_name() model_name = unwrapped_model._get_name()
if model_name in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values(): if model_name in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values():
loss = self.label_smoother(outputs, labels, shift_labels=True) loss = self.label_smoother(outputs, labels, shift_labels=True)
else: else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment