Unverified commit c9e27f0d authored by BlueRum, committed by GitHub
Browse files

[chatgpt]fix lora bug (#2974)

* fix lora bug

* polish
parent 82149e9d
...@@ -23,7 +23,7 @@ class RewardModel(LoRAModule): ...@@ -23,7 +23,7 @@ class RewardModel(LoRAModule):
lora_rank: int = 0, lora_rank: int = 0,
lora_train_bias: str = 'none') -> None: lora_train_bias: str = 'none') -> None:
super().__init__(lora_rank=lora_rank, lora_train_bias=lora_train_bias) super().__init__(lora_rank=lora_rank, lora_train_bias=lora_train_bias)
self.body = model self.model = model
if value_head is not None: if value_head is not None:
if value_head.out_features != 1: if value_head.out_features != 1:
raise ValueError("The value head of reward model's output dim should be 1!") raise ValueError("The value head of reward model's output dim should be 1!")
...@@ -34,7 +34,7 @@ class RewardModel(LoRAModule): ...@@ -34,7 +34,7 @@ class RewardModel(LoRAModule):
self.convert_to_lora() self.convert_to_lora()
def forward(self, sequences: torch.LongTensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: def forward(self, sequences: torch.LongTensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
outputs = self.body(sequences, attention_mask=attention_mask) outputs = self.model(sequences, attention_mask=attention_mask)
last_hidden_states = outputs['last_hidden_state'] last_hidden_states = outputs['last_hidden_state']
values = self.value_head(last_hidden_states)[:, :-1] values = self.value_head(last_hidden_states)[:, :-1]
value = values.mean(dim=1).squeeze(1) # ensure shape is (B) value = values.mean(dim=1).squeeze(1) # ensure shape is (B)
......
...@@ -44,6 +44,8 @@ class RewardModelTrainer(ABC): ...@@ -44,6 +44,8 @@ class RewardModelTrainer(ABC):
self.eval_dataloader = DataLoader(eval_dataset, batch_size=batch_size) self.eval_dataloader = DataLoader(eval_dataset, batch_size=batch_size)
self.model = strategy.setup_model(model) self.model = strategy.setup_model(model)
if "DDP" in str(self.strategy):
self.model = self.model.module
self.loss_fn = PairWiseLoss() self.loss_fn = PairWiseLoss()
self.optimizer = strategy.setup_optimizer(optim, self.model) self.optimizer = strategy.setup_optimizer(optim, self.model)
...@@ -56,7 +58,7 @@ class RewardModelTrainer(ABC): ...@@ -56,7 +58,7 @@ class RewardModelTrainer(ABC):
# train # train
if use_lora > 0: if use_lora > 0:
print("Using Lora") print("Using Lora")
lora.mark_only_lora_as_trainable(self.model.body) lora.mark_only_lora_as_trainable(self.model.model)
else: else:
self.model.train() self.model.train()
......
...@@ -61,8 +61,8 @@ def train(args): ...@@ -61,8 +61,8 @@ def train(args):
# prepare for data and dataset # prepare for data and dataset
data = load_dataset(args.dataset) data = load_dataset(args.dataset)
train_data = data["train"].select(range(100)) train_data = data["train"]
eval_data = data['test'].select(range(5)) eval_data = data['test']
train_dataset = RewardDataset(train_data, tokenizer, max_len) train_dataset = RewardDataset(train_data, tokenizer, max_len)
eval_dataset = RewardDataset(eval_data, tokenizer, max_len) eval_dataset = RewardDataset(eval_data, tokenizer, max_len)
...@@ -93,7 +93,7 @@ if __name__ == '__main__': ...@@ -93,7 +93,7 @@ if __name__ == '__main__':
parser.add_argument('--pretrain', type=str, default=None) parser.add_argument('--pretrain', type=str, default=None)
parser.add_argument('--dataset', type=str, default='Dahoas/rm-static') parser.add_argument('--dataset', type=str, default='Dahoas/rm-static')
parser.add_argument('--save_path', type=str, default='rm_ckpt.pth') parser.add_argument('--save_path', type=str, default='rm_ckpt.pth')
parser.add_argument('--max_epochs', type=int, default=10) parser.add_argument('--max_epochs', type=int, default=1)
parser.add_argument('--batch_size', type=int, default=4) parser.add_argument('--batch_size', type=int, default=4)
parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank") parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")
args = parser.parse_args() args = parser.parse_args()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment