"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "32798bf242a6b15e91a6fadc444f8806b4e8bb46"
Unverified Commit 089f0f4c authored by Fazzie-Maqianli, committed by GitHub

update to latest colossalai (#1951)

parent aba2a65d
@@ -15,8 +15,7 @@ from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.logging import disable_existing_loggers, get_dist_logger
 from colossalai.nn.optimizer.gemini_optimizer import GeminiAdamOptimizer
-from colossalai.nn.parallel.utils import convert_to_torch_module
-from colossalai.tensor import ProcessGroup
+from colossalai.nn.parallel.utils import get_static_torch_model
 from colossalai.utils import get_current_device
 from colossalai.utils.model.colo_init_context import ColoInitContext
 from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel
@@ -356,26 +355,17 @@ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token:
 # Gemini + ZeRO DDP
-def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placememt_policy: str = "auto"):
+def gemini_zero_dpp(model: torch.nn.Module, placememt_policy: str = "auto"):
     from colossalai.nn.parallel import GeminiDDP
 
     model = GeminiDDP(
-        model, device=get_current_device(), placement_policy=placememt_policy, pin_memory=True, search_range_mb=32
+        model, device=get_current_device(), placement_policy=placememt_policy, pin_memory=True, search_range_mb=64
     )
     return model
 
 
 def main(args):
-    # config for colossalai
-    config = {
-        "BATCH": args.train_batch_size,
-        "gradient_accumulation_steps": args.gradient_accumulation_steps,
-        "clip_grad_norm": args.max_grad_norm,
-    }
-    colossalai.launch_from_torch(config=config)
-    pg = ProcessGroup()
+    colossalai.launch_from_torch(config={})
 
     if args.seed is not None:
         gpc.set_seed(args.seed)
@@ -472,7 +462,7 @@ def main(args):
     )
 
     logger.info(f"Loading UNet2DConditionModel from {args.pretrained_model_name_or_path}", ranks=[0])
-    with ColoInitContext():
+    with ColoInitContext(device=get_current_device()):
         unet = UNet2DConditionModel.from_pretrained(
             args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, low_cpu_mem_usage=False
         )
@@ -484,12 +474,19 @@ def main(args):
         unet.enable_gradient_checkpointing()
 
     if args.scale_lr:
-        args.learning_rate = args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * 2
+        args.learning_rate = (
+            args.learning_rate
+            * args.gradient_accumulation_steps
+            * args.train_batch_size
+            * gpc.get_world_size(ParallelMode.DATA)
+        )
 
-    unet = gemini_zero_dpp(unet, pg, args.placement)
+    unet = gemini_zero_dpp(unet, args.placement)
 
     # config optimizer for colossalai zero
-    optimizer = GeminiAdamOptimizer(unet, lr=args.learning_rate, initial_scale=2**5)
+    optimizer = GeminiAdamOptimizer(
+        unet, lr=args.learning_rate, initial_scale=2**5, clipping_norm=args.max_grad_norm
+    )
 
     # load noise_scheduler
     noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
@@ -657,10 +654,11 @@ def main(args):
             if global_step % args.save_steps == 0:
                 torch.cuda.synchronize()
+                torch_unet = get_static_torch_model(unet)
                 if gpc.get_local_rank(ParallelMode.DATA) == 0:
                     pipeline = DiffusionPipeline.from_pretrained(
                         args.pretrained_model_name_or_path,
-                        unet=convert_to_torch_module(unet),
+                        unet=torch_unet,
                         revision=args.revision,
                     )
                     save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
@@ -670,7 +668,7 @@ def main(args):
             break
 
     torch.cuda.synchronize()
-    unet = convert_to_torch_module(unet)
+    unet = get_static_torch_model(unet)
     if gpc.get_local_rank(ParallelMode.DATA) == 0:
         pipeline = DiffusionPipeline.from_pretrained(
...
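
For context, a minimal sketch of the flow this commit moves the script to: launch with an empty config, initialize the model under ColoInitContext(device=...), wrap it with GeminiDDP without an explicit ProcessGroup, pass clipping_norm to GeminiAdamOptimizer, and call get_static_torch_model before saving. The API names come straight from the diff above; the tiny torch.nn.Linear stand-in, the dummy loss, the learning rate, the clipping norm, and the output filename are placeholders, and the sketch assumes a ColossalAI release that still exposes these Gemini/ZeRO entry points.

# Sketch of the updated ColossalAI Gemini/ZeRO usage (not the full training script).
import torch
import colossalai
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.nn.optimizer.gemini_optimizer import GeminiAdamOptimizer
from colossalai.nn.parallel import GeminiDDP
from colossalai.nn.parallel.utils import get_static_torch_model
from colossalai.utils import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext


def main():
    # Hyper-parameters no longer go into the launch config;
    # gradient clipping is handled by the optimizer instead.
    colossalai.launch_from_torch(config={})

    # Materialize parameters as ColoTensors on the current device.
    with ColoInitContext(device=get_current_device()):
        model = torch.nn.Linear(1024, 1024)  # stand-in for UNet2DConditionModel

    # Gemini + ZeRO DDP wrapper; no explicit ProcessGroup is needed any more.
    model = GeminiDDP(
        model, device=get_current_device(), placement_policy="auto", pin_memory=True, search_range_mb=64
    )

    # Scale the learning rate by the real data-parallel world size.
    lr = 1e-5 * gpc.get_world_size(ParallelMode.DATA)
    optimizer = GeminiAdamOptimizer(model, lr=lr, initial_scale=2**5, clipping_norm=1.0)

    # One dummy step to show the loop shape; Gemini keeps parameters in fp16,
    # so the input is cast to half precision.
    x = torch.randn(8, 1024, device=get_current_device(), dtype=torch.half)
    loss = model(x).float().pow(2).mean()
    optimizer.backward(loss)
    optimizer.step()
    optimizer.zero_grad()

    # Gather a plain torch.nn.Module copy and save it from one rank only.
    torch.cuda.synchronize()
    torch_model = get_static_torch_model(model)
    if gpc.get_local_rank(ParallelMode.DATA) == 0:
        torch.save(torch_model.state_dict(), "model.pt")  # placeholder path


if __name__ == "__main__":
    main()  # run with: torchrun --nproc_per_node=<N> this_script.py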