"tests/test_zero/test_shard_model_v2.py" did not exist on "08eccfe681c7f2f6b47b1e9a162252ad0068b6af"
Unverified Commit 89fd10a1 authored by ver217's avatar ver217 Committed by GitHub
Browse files

[chat] add zero2 cpu strategy for sft training (#3520)

parent 990d4c3e
......@@ -35,6 +35,8 @@ def train(args):
strategy = ColossalAIStrategy(stage=3, placement_policy='cuda')
elif args.strategy == 'colossalai_zero2':
strategy = ColossalAIStrategy(stage=2, placement_policy='cuda')
elif args.strategy == 'colossalai_zero2_cpu':
strategy = ColossalAIStrategy(stage=2, placement_policy='cpu')
else:
raise ValueError(f'Unsupported strategy "{args.strategy}"')
......@@ -168,7 +170,7 @@ def train(args):
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--strategy',
choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'],
choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2', 'colossalai_zero2_cpu'],
default='naive')
parser.add_argument('--model', choices=['gpt2', 'bloom', 'opt', 'llama'], default='bloom')
parser.add_argument('--pretrain', type=str, default=None)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment