Unverified Commit 2da5d81d authored by zhang-yi-chi's avatar zhang-yi-chi Committed by GitHub
Browse files

[chat] fix train_prompts.py gemini strategy bug (#3666)

* fix gemini strategy bug

* add comment

* add comment

* better solution
parent d5566488
...@@ -36,45 +36,45 @@ def main(args): ...@@ -36,45 +36,45 @@ def main(args):
if args.rm_path is not None: if args.rm_path is not None:
state_dict = torch.load(args.rm_path, map_location='cpu') state_dict = torch.load(args.rm_path, map_location='cpu')
# configure model with strategy.model_init_context():
if args.model == 'gpt2': # configure model
initial_model = GPTActor(pretrained=args.pretrain) if args.model == 'gpt2':
elif args.model == 'bloom': initial_model = GPTActor(pretrained=args.pretrain)
initial_model = BLOOMActor(pretrained=args.pretrain) elif args.model == 'bloom':
elif args.model == 'opt': initial_model = BLOOMActor(pretrained=args.pretrain)
initial_model = OPTActor(pretrained=args.pretrain) elif args.model == 'opt':
elif args.model == 'llama': initial_model = OPTActor(pretrained=args.pretrain)
initial_model = LlamaActor(pretrained=args.pretrain) elif args.model == 'llama':
elif args.model == 'roberta': initial_model = LlamaActor(pretrained=args.pretrain)
initial_model = RoBERTaActor(pretrained=args.pretrain) elif args.model == 'roberta':
else: initial_model = RoBERTaActor(pretrained=args.pretrain)
raise ValueError(f'Unsupported actor model "{args.model}"') else:
raise ValueError(f'Unsupported actor model "{args.model}"')
if args.rm_model == None: if args.rm_model == None:
rm_model_name = args.model rm_model_name = args.model
else: else:
rm_model_name = args.rm_model rm_model_name = args.rm_model
if rm_model_name == 'gpt2':
reward_model = GPTRM(pretrained=args.rm_pretrain)
elif rm_model_name == 'bloom':
reward_model = BLOOMRM(pretrained=args.rm_pretrain)
elif rm_model_name == 'opt':
reward_model = OPTRM(pretrained=args.rm_pretrain)
elif rm_model_name == 'llama':
reward_model = LlamaRM(pretrained=args.rm_pretrain)
elif rm_model_name == 'roberta':
reward_model = RoBERTaRM(pretrained=args.rm_pretrain)
else:
raise ValueError(f'Unsupported reward model "{rm_model_name}"')
if args.rm_path is not None: if rm_model_name == 'gpt2':
reward_model.load_state_dict(state_dict) reward_model = GPTRM(pretrained=args.rm_pretrain)
elif rm_model_name == 'bloom':
reward_model = BLOOMRM(pretrained=args.rm_pretrain)
elif rm_model_name == 'opt':
reward_model = OPTRM(pretrained=args.rm_pretrain)
elif rm_model_name == 'llama':
reward_model = LlamaRM(pretrained=args.rm_pretrain)
elif rm_model_name == 'roberta':
reward_model = RoBERTaRM(pretrained=args.rm_pretrain)
else:
raise ValueError(f'Unsupported reward model "{rm_model_name}"')
initial_model.to(torch.float16).to(torch.cuda.current_device()) if args.rm_path is not None:
reward_model.to(torch.float16).to(torch.cuda.current_device()) reward_model.load_state_dict(state_dict)
initial_model.to(torch.float16).to(torch.cuda.current_device())
reward_model.to(torch.float16).to(torch.cuda.current_device())
with strategy.model_init_context():
if args.model == 'gpt2': if args.model == 'gpt2':
actor = GPTActor(pretrained=args.pretrain, lora_rank=args.lora_rank) actor = GPTActor(pretrained=args.pretrain, lora_rank=args.lora_rank)
elif args.model == 'bloom': elif args.model == 'bloom':
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment