"vscode:/vscode.git/clone" did not exist on "e1b69c13c957ba6521daa5b1054740980dc6fc21"
Unverified Commit b09adff7 authored by Yuanchen's avatar Yuanchen Committed by GitHub
Browse files

[chat]fix sft training for bloom, gpt and opt (#3418)

fix sft training for bloom, gpt and opt 
parent 638a07a7
......@@ -33,3 +33,6 @@ class BLOOMLM(LM):
if checkpoint:
model.gradient_checkpointing_enable()
super().__init__(model, lora_rank, lora_train_bias)
def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
return self.model(input_ids, attention_mask=attention_mask, labels=labels, **kwargs)
......@@ -33,3 +33,6 @@ class GPTLM(LM):
if checkpoint:
model.gradient_checkpointing_enable()
super().__init__(model, lora_rank, lora_train_bias)
def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
return self.model(input_ids, attention_mask=attention_mask, labels=labels, **kwargs)
......@@ -33,3 +33,6 @@ class OPTLM(LM):
if checkpoint:
model.gradient_checkpointing_enable()
super().__init__(model, lora_rank, lora_train_bias)
def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
return self.model(input_ids, attention_mask=attention_mask, labels=labels, **kwargs)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment