"vscode:/vscode.git/clone" did not exist on "abd8e77ad861a75a149cff488cc2949bc760f0f5"
Unverified commit 89fd10a1, authored by ver217 and committed by GitHub
Browse files

[chat] add zero2 cpu strategy for sft training (#3520)

parent 990d4c3e
...@@ -35,6 +35,8 @@ def train(args): ...@@ -35,6 +35,8 @@ def train(args):
strategy = ColossalAIStrategy(stage=3, placement_policy='cuda') strategy = ColossalAIStrategy(stage=3, placement_policy='cuda')
elif args.strategy == 'colossalai_zero2': elif args.strategy == 'colossalai_zero2':
strategy = ColossalAIStrategy(stage=2, placement_policy='cuda') strategy = ColossalAIStrategy(stage=2, placement_policy='cuda')
elif args.strategy == 'colossalai_zero2_cpu':
strategy = ColossalAIStrategy(stage=2, placement_policy='cpu')
else: else:
raise ValueError(f'Unsupported strategy "{args.strategy}"') raise ValueError(f'Unsupported strategy "{args.strategy}"')
...@@ -168,7 +170,7 @@ def train(args): ...@@ -168,7 +170,7 @@ def train(args):
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--strategy', parser.add_argument('--strategy',
choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'], choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2', 'colossalai_zero2_cpu'],
default='naive') default='naive')
parser.add_argument('--model', choices=['gpt2', 'bloom', 'opt', 'llama'], default='bloom') parser.add_argument('--model', choices=['gpt2', 'bloom', 'opt', 'llama'], default='bloom')
parser.add_argument('--pretrain', type=str, default=None) parser.add_argument('--pretrain', type=str, default=None)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment