Unverified commit 2da5d81d, authored by zhang-yi-chi and committed by GitHub
Browse files

[chat] fix train_prompts.py gemini strategy bug (#3666)

* fix gemini strategy bug

* add comment

* add comment

* better solution
parent d5566488
@@ -36,6 +36,7 @@ def main(args):
if args.rm_path is not None:
state_dict = torch.load(args.rm_path, map_location='cpu')
with strategy.model_init_context():
# configure model
if args.model == 'gpt2':
initial_model = GPTActor(pretrained=args.pretrain)
@@ -74,7 +75,6 @@ def main(args):
initial_model.to(torch.float16).to(torch.cuda.current_device())
reward_model.to(torch.float16).to(torch.cuda.current_device())
with strategy.model_init_context():
if args.model == 'gpt2':
actor = GPTActor(pretrained=args.pretrain, lora_rank=args.lora_rank)
elif args.model == 'bloom':
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment