    Add Consistency Models Pipeline (#3492) · aed7499a
    dg845 authored
    
    
    * initial commit
    
    * Improve consistency models sampling implementation.
    
    * Add CMStochasticIterativeScheduler, which implements the multi-step sampler (stochastic_iterative_sampler) in the original code, and make further improvements to sampling.
    
    * Add UNet blocks for consistency models
    
    * Add conversion script for UNet
    
    * Fix bug in new unet blocks
    
    * Fix attention weight loading
    
    * Make design improvements to ConsistencyModelPipeline and CMStochasticIterativeScheduler and add initial version of tests.
    
    * make style
    
    * Make small random test UNet class conditional and set resnet_time_scale_shift to 'scale_shift' to better match consistency model checkpoints.
    
    * Add support for converting a test UNet and non-class-conditional UNets to the consistency models conversion script.
    
    * make style
    
    * Change num_class_embeds to 1000 to better match the original consistency models implementation.
    
    * Add support for distillation in pipeline_consistency_models.py.
    
    * Improve consistency model tests:
    	- Get small testing checkpoints from hub
    	- Modify tests to take into account "distillation" parameter of ConsistencyModelPipeline
    	- Add onestep, multistep tests for distillation and distillation + class conditional
    	- Add expected image slices for onestep tests
    
    * make style
    
    * Improve ConsistencyModelPipeline:
    	- Add initial support for class-conditional generation
    	- Fix initial sigma for onestep generation
    	- Fix some sigma shape issues
    
    * make style
    
    * Improve ConsistencyModelPipeline:
    	- add latents __call__ argument and prepare_latents method
    	- add check_inputs method
    	- add initial docstrings for ConsistencyModelPipeline.__call__
    
    * make style
    
    * Fix bug when randomly generating class labels for class-conditional generation.
    
    * Switch CMStochasticIterativeScheduler to configuring a sigma schedule and make related changes to the pipeline and tests.
    
    * Remove some unused code and make style.
    
    * Fix small bug in CMStochasticIterativeScheduler.
    
    * Add expected slices for multistep sampling tests and make them pass.
    
    * Work on consistency model fast tests:
    	- in pipeline, call self.scheduler.scale_model_input before denoising
    	- get expected slices for Euler and Heun scheduler tests
    	- make Euler test pass
    	- mark Heun test as expected fail because it doesn't support prediction_type "sample" yet
    	- remove DPM and Euler Ancestral tests because they don't support use_karras_sigmas
    
    * make style
    
    * Refactor conversion script to make it easier to add more model architectures to convert in the future.
    
    * Work on ConsistencyModelPipeline tests:
    	- Fix device bug when handling class labels in ConsistencyModelPipeline.__call__
    	- Add slow tests for onestep and multistep sampling and make them pass
    	- Refactor fast tests
    	- Refactor ConsistencyModelPipeline.__init__
    
    * make style
    
    * Remove the add_noise and add_noise_to_input methods from CMStochasticIterativeScheduler for now.
    
    * Run python utils/check_copies.py --fix_and_overwrite and python utils/check_dummies.py --fix_and_overwrite to make dummy objects for the new pipeline and scheduler.
    
    * Make fast tests from PipelineTesterMixin pass.
    
    * make style
    
    * Refactor consistency models pipeline and scheduler:
    	- Remove support for Karras schedulers (only support CMStochasticIterativeScheduler)
    	- Move sigma manipulation, input scaling, denoising from pipeline to scheduler
    	- Make corresponding changes to tests and ensure they pass
    
    * make style
    
    * Add docstrings and further refactor pipeline and scheduler.
    
    * make style
    
    * Add initial version of the consistency models documentation.
    
    * Refactor custom timesteps logic following DDPMScheduler/IFPipeline and temporarily add torch 2.0 SDPA kernel selection logic for debugging.
    
    * make style
    
    * Convert current slow tests to use fp16 and flash attention.
    
    * make style
    
    * Add slow tests for normal attention on cuda device.
    
    * make style
    
    * Fix attention weights loading
    
    * Update consistency model fast tests for new test checkpoints with attention fix.
    
    * make style
    
    * apply suggestions
    
    * Add add_noise method to CMStochasticIterativeScheduler (copied from EulerDiscreteScheduler).
    
    * Conversion script now outputs a pipeline instead of a UNet and adds support for LSUN-256 models and different schedulers.
    
    * When both timesteps and num_inference_steps are supplied, raise warning instead of error (timesteps take precedence).
    
    * make style
    
    * Add remaining diffusers model checkpoints for models in the original consistency model release and update usage example.
    
    * apply suggestions from review
    
    * make style
    
    * fix attention naming
    
    * Add tests for CMStochasticIterativeScheduler.
    
    * make style
    
    * Make CMStochasticIterativeScheduler tests pass.
    
    * make style
    
    * Override test_step_shape in CMStochasticIterativeSchedulerTest instead of modifying it in SchedulerCommonTest.
    
    * make style
    
    * rename some models
    
    * Improve API
    
    * rename some models
    
    * Remove duplicated block
    
    * Add docstring and make torch compile work
    
    * More fixes
    
    * Fixes
    
    * Apply suggestions from code review
    
    * Apply suggestions from code review
    
    * add more docstring
    
    * update consistency conversion script
    
    ---------
    Co-authored-by: ayushmangal <ayushmangal@microsoft.com>
    Co-authored-by: Ayush Mangal <43698245+ayushtues@users.noreply.github.com>
    Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
dummy_pt_objects.py 19.1 KB