Unverified Commit a88bc828 authored by ver217's avatar ver217 Committed by GitHub
Browse files

[chatgpt] disable shard init for colossalai (#2767)

parent d6d6dec1
import warnings
from typing import Optional
import torch
......@@ -23,6 +24,7 @@ class ColossalAIStrategy(DDPStrategy):
stage(int): The stage to use in ZeRO. Choose in (1, 2, 3)
seed(int): The seed for the random number generator.
shard_init(bool): Whether to shard the model parameters during initialization. Only for ZeRO-3.
This is not compativle with `from_pretrained()`. We temporarily disable this and will support it in the future.
placement_policy(str): The placement policy for gemini. Choose in ('cpu', 'cuda')
If it is “cpu”, parameters, gradients and optimizer states will be offloaded to CPU,
If it is “cuda”, they will not be offloaded, which means max CUDA memory will be used. It is the fastest.
......@@ -50,7 +52,7 @@ class ColossalAIStrategy(DDPStrategy):
self,
stage: int = 3,
seed: int = 42,
shard_init: bool = True, # only for stage 3
shard_init: bool = False, # only for stage 3
placement_policy: str = 'cuda',
pin_memory: bool = True, # only for stage 3
force_outputs_fp32: bool = False, # only for stage 3
......@@ -72,6 +74,10 @@ class ColossalAIStrategy(DDPStrategy):
super().__init__(seed)
assert placement_policy in ('cpu', 'cuda'), f'Unsupported placement policy "{placement_policy}"'
self.stage = stage
# TODO(ver217): support shard_init when using from_pretrained()
if shard_init:
warnings.warn(f'Shard init is not supported yet. Ignore.')
shard_init = False
self.shard_init = shard_init
self.gemini_config = dict(device=get_current_device(),
placement_policy=placement_policy,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment