Unverified Commit 1d37f420 authored by Yuxuan Zhang's avatar Yuxuan Zhang Committed by GitHub
Browse files

Modify the implementation of retrieve_timesteps in CogView4-Control. (#11125)



* 1

* change to channel 1

* cogview4 control training

* add CacheMixin

* 1

* remove initial_input_channels change for val

* 1

* update

* use 3.5

* new loss

* 1

* use imagetoken

* for megatron convert

* 1

* train con and uc

* 2

* remove guidance_scale

* Update pipeline_cogview4_control.py

* fix

* use cogview4 pipeline with timestep

* update shift_factor

* remove the uncond

* add max length

* change convert and use GLMModel instead of GLMForCasualLM

* fix

* [cogview4] Add attention mask support to transformer model

* [fix] Add attention mask for padded token

* update

* remove padding type

* Update train_control_cogview4.py

* resolve conflicts with #10981

* add control convert

* use control format

* fix

* add missing import

* update with cogview4 formate

* make style

* Update pipeline_cogview4_control.py

* Update pipeline_cogview4_control.py

* remove

* Update pipeline_cogview4_control.py

* put back

* Apply style fixes

---------
Co-authored-by: default avatarOleehyO <leehy0357@gmail.com>
Co-authored-by: default avataryiyixuxu <yixu310@gmail.com>
Co-authored-by: default avatargithub-actions[bot] <github-actions[bot]@users.noreply.github.com>
parent 0213179b
...@@ -68,7 +68,7 @@ def calculate_shift( ...@@ -68,7 +68,7 @@ def calculate_shift(
return mu return mu
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps # Copied from diffusers.pipelines.cogview4.pipeline_cogview4.retrieve_timesteps
def retrieve_timesteps( def retrieve_timesteps(
scheduler, scheduler,
num_inference_steps: Optional[int] = None, num_inference_steps: Optional[int] = None,
...@@ -100,10 +100,19 @@ def retrieve_timesteps( ...@@ -100,10 +100,19 @@ def retrieve_timesteps(
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
second element is the number of inference steps. second element is the number of inference steps.
""" """
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
accepts_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if timesteps is not None and sigmas is not None: if timesteps is not None and sigmas is not None:
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values") if not accepts_timesteps and not accepts_sigmas:
if timesteps is not None: raise ValueError(
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" timestep or sigma schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(timesteps=timesteps, sigmas=sigmas, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
elif timesteps is not None and sigmas is None:
if not accepts_timesteps: if not accepts_timesteps:
raise ValueError( raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
...@@ -112,9 +121,8 @@ def retrieve_timesteps( ...@@ -112,9 +121,8 @@ def retrieve_timesteps(
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
timesteps = scheduler.timesteps timesteps = scheduler.timesteps
num_inference_steps = len(timesteps) num_inference_steps = len(timesteps)
elif sigmas is not None: elif timesteps is None and sigmas is not None:
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) if not accepts_sigmas:
if not accept_sigmas:
raise ValueError( raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" sigmas schedules. Please check whether you are using the correct scheduler." f" sigmas schedules. Please check whether you are using the correct scheduler."
...@@ -515,8 +523,8 @@ class CogView4ControlPipeline(DiffusionPipeline): ...@@ -515,8 +523,8 @@ class CogView4ControlPipeline(DiffusionPipeline):
The output format of the generate image. Choose between The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`): return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead Whether or not to return a [`~pipelines.pipeline_CogView4.CogView4PipelineOutput`] instead of a plain
of a plain tuple. tuple.
attention_kwargs (`dict`, *optional*): attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in `self.processor` in
...@@ -532,7 +540,6 @@ class CogView4ControlPipeline(DiffusionPipeline): ...@@ -532,7 +540,6 @@ class CogView4ControlPipeline(DiffusionPipeline):
`._callback_tensor_inputs` attribute of your pipeline class. `._callback_tensor_inputs` attribute of your pipeline class.
max_sequence_length (`int`, defaults to `224`): max_sequence_length (`int`, defaults to `224`):
Maximum sequence length in encoded prompt. Can be set to other values but may lead to poorer results. Maximum sequence length in encoded prompt. Can be set to other values but may lead to poorer results.
Examples: Examples:
Returns: Returns:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment