"""Command-line entry point that runs ``StepVideoPipeline`` once.

All settings come from ``stepvideo.config.parse_args``. The script sets up the
parallel group, loads the pipeline, runs a single generation call and then
tears the ``torch.distributed`` process group down again.
"""

from stepvideo.diffusion.video_pipeline import StepVideoPipeline
import torch.distributed as dist
import torch
from stepvideo.config import parse_args
from stepvideo.parallel import initialize_parall_group, get_parallel_group
from stepvideo.utils import setup_seed


def main():
    """Parse arguments, run the pipeline once and clean up the process group."""
    args = parse_args()

    initialize_parall_group(
        ring_degree=args.ring_degree,
        ulysses_degree=args.ulysses_degree,
    )

    try:
        local_rank = get_parallel_group().local_rank
        device = torch.device(f"cuda:{local_rank}")

        setup_seed(args.seed)

        # The pipeline as a whole is loaded in bfloat16 on the CPU; only the
        # transformer is then moved to this rank's CUDA device.
        pipeline = StepVideoPipeline.from_pretrained(args.model_dir).to(
            dtype=torch.bfloat16, device="cpu"
        )
        pipeline.transformer = pipeline.transformer.to(device)
        pipeline.setup_pipeline(args)

        prompt = args.prompt

        # The return value is not used by this script.
        # NOTE(review): presumably the pipeline writes its result itself,
        # named after ``output_file_name`` -- confirm against the pipeline.
        pipeline(
            prompt=prompt,
            first_image=args.first_image_path,
            num_frames=args.num_frames,
            height=args.height,
            width=args.width,
            num_inference_steps=args.infer_steps,
            guidance_scale=args.cfg_scale,
            time_shift=args.time_shift,
            pos_magic=args.pos_magic,
            neg_magic=args.neg_magic,
            # Fall back to the first 50 characters of the prompt when no
            # explicit output name was supplied.
            output_file_name=args.output_file_name or prompt[:50],
            motion_score=args.motion_score,
        )
    finally:
        # Tear the process group down even when loading or inference raises,
        # so a failing rank does not leave the group alive.
        if dist.is_initialized():
            dist.destroy_process_group()


if __name__ == "__main__":
    main()