Unverified commit b09adff7 authored by Yuanchen, committed by GitHub

[chat]fix sft training for bloom, gpt and opt (#3418)

Fix SFT training for BLOOM, GPT and OPT by adding a forward method that passes attention_mask and labels through to the wrapped model, so the language-modeling loss is computed during supervised fine-tuning.
parent 638a07a7
@@ -33,3 +33,6 @@ class BLOOMLM(LM):
         if checkpoint:
             model.gradient_checkpointing_enable()
         super().__init__(model, lora_rank, lora_train_bias)
+
+    def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
+        return self.model(input_ids, attention_mask=attention_mask, labels=labels, **kwargs)
@@ -33,3 +33,6 @@ class GPTLM(LM):
         if checkpoint:
             model.gradient_checkpointing_enable()
         super().__init__(model, lora_rank, lora_train_bias)
+
+    def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
+        return self.model(input_ids, attention_mask=attention_mask, labels=labels, **kwargs)
@@ -33,3 +33,6 @@ class OPTLM(LM):
         if checkpoint:
             model.gradient_checkpointing_enable()
         super().__init__(model, lora_rank, lora_train_bias)
+
+    def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
+        return self.model(input_ids, attention_mask=attention_mask, labels=labels, **kwargs)
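For context on why the added forward matters, here is a minimal standalone sketch, not part of the commit: the model name and inputs below are assumptions chosen for illustration. Hugging Face causal-LM models compute the cross-entropy loss internally only when labels is supplied, which is exactly what an SFT training loop relies on; without a forward that passes labels through, the wrapper drops them and no loss is returned.

# Minimal sketch (assumed example, not part of the commit): passing labels
# through, as the new forward does, makes the underlying model return a loss.
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # assumed small stand-in model
model = AutoModelForCausalLM.from_pretrained("gpt2")

batch = tokenizer("Hello, world!", return_tensors="pt")
# For causal-LM SFT the labels are the input ids themselves; the model
# shifts them internally and computes the cross-entropy loss.
outputs = model(input_ids=batch["input_ids"],
                attention_mask=batch["attention_mask"],
                labels=batch["input_ids"])
outputs.loss.backward()  # an SFT optimizer step would follow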