Unverified Commit 739d6ec7 authored by Junsong Chen's avatar Junsong Chen Committed by GitHub
Browse files

add a timestep scale for sana-sprint teacher model (#11150)

parent 1ddf3f3a
...@@ -326,6 +326,10 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig ...@@ -326,6 +326,10 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
Whether to use elementwise affinity in the normalization layer. Whether to use elementwise affinity in the normalization layer.
norm_eps (`float`, defaults to `1e-6`): norm_eps (`float`, defaults to `1e-6`):
The epsilon value for the normalization layer. The epsilon value for the normalization layer.
qk_norm (`str`, *optional*, defaults to `None`):
The normalization to use for the query and key.
timestep_scale (`float`, defaults to `1.0`):
The scale to use for the timesteps.
""" """
_supports_gradient_checkpointing = True _supports_gradient_checkpointing = True
...@@ -355,6 +359,7 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig ...@@ -355,6 +359,7 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
guidance_embeds: bool = False, guidance_embeds: bool = False,
guidance_embeds_scale: float = 0.1, guidance_embeds_scale: float = 0.1,
qk_norm: Optional[str] = None, qk_norm: Optional[str] = None,
timestep_scale: float = 1.0,
) -> None: ) -> None:
super().__init__() super().__init__()
......
...@@ -938,6 +938,7 @@ class SanaPipeline(DiffusionPipeline, SanaLoraLoaderMixin): ...@@ -938,6 +938,7 @@ class SanaPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype) timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype)
timestep = timestep * self.transformer.config.timestep_scale
# predict noise model_output # predict noise model_output
noise_pred = self.transformer( noise_pred = self.transformer(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment