Unverified Commit 86294d3c authored by co63oc, committed by GitHub

Fix typos in docs and comments (#11416)



* Fix typos in docs and comments

* Apply style fixes

---------
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
parent d70f8ee1
@@ -394,7 +394,7 @@ if __name__ == "__main__":
help="Scheduler type to use. Use 'scm' for Sana Sprint models.",
)
parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output pipeline.")
- parser.add_argument("--save_full_pipeline", action="store_true", help="save all the pipelien elemets in one.")
+ parser.add_argument("--save_full_pipeline", action="store_true", help="save all the pipeline elements in one.")
parser.add_argument("--dtype", default="fp32", type=str, choices=["fp32", "fp16", "bf16"], help="Weight dtype.")
args = parser.parse_args()
...
@@ -984,7 +984,7 @@ def renderer(*, args, checkpoint_map_location):
return renderer_model
- # prior model will expect clip_mean and clip_std, whic are missing from the state_dict
+ # prior model will expect clip_mean and clip_std, which are missing from the state_dict
PRIOR_EXPECTED_MISSING_KEYS = ["clip_mean", "clip_std"]
...
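Note: a minimal sketch of the expected-missing-keys pattern this hunk refers to; `prior_model` and `state_dict` are placeholders, not names from the conversion script.

```python
import torch.nn as nn

PRIOR_EXPECTED_MISSING_KEYS = ["clip_mean", "clip_std"]

def load_prior(prior_model: nn.Module, state_dict: dict) -> None:
    # The model defines clip_mean/clip_std, but the checkpoint does not ship them,
    # so load non-strictly and only fail on keys that are unexpectedly missing.
    missing, unexpected = prior_model.load_state_dict(state_dict, strict=False)
    surprises = [k for k in missing if k not in PRIOR_EXPECTED_MISSING_KEYS]
    if surprises or unexpected:
        raise ValueError(f"Unexpected state dict mismatch: {surprises + list(unexpected)}")
```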
@@ -55,8 +55,8 @@ for key in orig_state_dict.keys():
state_dict[key.replace("attn.out_proj.bias", "to_out.0.bias")] = weights
else:
state_dict[key] = orig_state_dict[key]
- deocder = WuerstchenDiffNeXt()
+ decoder = WuerstchenDiffNeXt()
- deocder.load_state_dict(state_dict)
+ decoder.load_state_dict(state_dict)
# Prior
orig_state_dict = torch.load(os.path.join(model_path, "model_v3_stage_c.pt"), map_location=device)["ema_state_dict"]
@@ -94,7 +94,7 @@ prior_pipeline = WuerstchenPriorPipeline(
prior_pipeline.save_pretrained("warp-ai/wuerstchen-prior")
decoder_pipeline = WuerstchenDecoderPipeline(
- text_encoder=gen_text_encoder, tokenizer=gen_tokenizer, vqgan=vqmodel, decoder=deocder, scheduler=scheduler
+ text_encoder=gen_text_encoder, tokenizer=gen_tokenizer, vqgan=vqmodel, decoder=decoder, scheduler=scheduler
)
decoder_pipeline.save_pretrained("warp-ai/wuerstchen")
@@ -103,7 +103,7 @@ wuerstchen_pipeline = WuerstchenCombinedPipeline(
# Decoder
text_encoder=gen_text_encoder,
tokenizer=gen_tokenizer,
- decoder=deocder,
+ decoder=decoder,
scheduler=scheduler,
vqgan=vqmodel,
# Prior
...
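Note: a hedged usage sketch of the two stages the conversion script saves above; the repo ids are taken from the `save_pretrained` calls shown, and their availability on the Hub is an assumption.

```python
import torch
from diffusers import WuerstchenDecoderPipeline, WuerstchenPriorPipeline

# Load the prior and decoder stages produced by the conversion script.
prior = WuerstchenPriorPipeline.from_pretrained("warp-ai/wuerstchen-prior", torch_dtype=torch.float16)
decoder = WuerstchenDecoderPipeline.from_pretrained("warp-ai/wuerstchen", torch_dtype=torch.float16)
```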
@@ -243,7 +243,7 @@ class GroupOffloadingHook(ModelHook):
class LazyPrefetchGroupOffloadingHook(ModelHook):
r"""
- A hook, used in conjuction with GroupOffloadingHook, that applies lazy prefetching to groups of torch.nn.Module.
+ A hook, used in conjunction with GroupOffloadingHook, that applies lazy prefetching to groups of torch.nn.Module.
This hook is used to determine the order in which the layers are executed during the forward pass. Once the layer
invocation order is known, assignments of the next_group attribute for prefetching can be made, which allows
prefetching groups in the correct order.
...
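Note: the docstring above describes lazy prefetching for group offloading. As a hedged usage sketch (API names follow recent diffusers releases; exact arguments may differ across versions), group offloading with prefetching is typically enabled like this:

```python
import torch
from diffusers import CogVideoXTransformer3DModel
from diffusers.hooks import apply_group_offloading

transformer = CogVideoXTransformer3DModel.from_pretrained(
    "THUDM/CogVideoX-5b", subfolder="transformer", torch_dtype=torch.bfloat16
)
# Offload groups of blocks to CPU and prefetch the next group onto the GPU while the
# current one runs; use_stream=True is what engages the lazy prefetching hook above.
apply_group_offloading(
    transformer,
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="block_level",
    num_blocks_per_group=2,
    use_stream=True,
)
```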
@@ -90,7 +90,7 @@ class PeftInputAutocastDisableHook(ModelHook):
that the inputs are casted to the computation dtype correctly always. However, there are two goals we are
hoping to achieve:
1. Making forward implementations independent of device/dtype casting operations as much as possible.
- 2. Peforming inference without losing information from casting to different precisions. With the current
+ 2. Performing inference without losing information from casting to different precisions. With the current
PEFT implementation (as linked in the reference above), and assuming running layerwise casting inference
with storage_dtype=torch.float8_e4m3fn and compute_dtype=torch.bfloat16, inputs are cast to
torch.float8_e4m3fn in the lora layer. We will then upcast back to torch.bfloat16 when we continue the
...
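Note: the goals listed above concern layerwise casting. A hedged sketch of how it is typically enabled (method and model names follow recent diffusers releases and may vary):

```python
import torch
from diffusers import FluxTransformer2DModel

transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer", torch_dtype=torch.bfloat16
)
# Store weights in float8 but compute in bfloat16; the PEFT input-autocast hook above
# is what keeps LoRA inputs from being downcast to the float8 storage dtype.
transformer.enable_layerwise_casting(
    storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16
)
```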
@@ -819,7 +819,7 @@ def _convert_kohya_flux_lora_to_diffusers(state_dict):
if zero_status_pe:
logger.info(
"The `position_embedding` LoRA params are all zeros which make them ineffective. "
- "So, we will purge them out of the curret state dict to make loading possible."
+ "So, we will purge them out of the current state dict to make loading possible."
)
else:
@@ -835,7 +835,7 @@ def _convert_kohya_flux_lora_to_diffusers(state_dict):
if zero_status_t5:
logger.info(
"The `t5xxl` LoRA params are all zeros which make them ineffective. "
- "So, we will purge them out of the curret state dict to make loading possible."
+ "So, we will purge them out of the current state dict to make loading possible."
)
else:
logger.info(
@@ -850,7 +850,7 @@ def _convert_kohya_flux_lora_to_diffusers(state_dict):
if zero_status_diff_b:
logger.info(
"The `diff_b` LoRA params are all zeros which make them ineffective. "
- "So, we will purge them out of the curret state dict to make loading possible."
+ "So, we will purge them out of the current state dict to make loading possible."
)
else:
logger.info(
@@ -866,7 +866,7 @@ def _convert_kohya_flux_lora_to_diffusers(state_dict):
if zero_status_diff:
logger.info(
"The `diff` LoRA params are all zeros which make them ineffective. "
- "So, we will purge them out of the curret state dict to make loading possible."
+ "So, we will purge them out of the current state dict to make loading possible."
)
else:
logger.info(
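Note: a minimal sketch of the kind of check behind these log messages; the helper name is illustrative, not the converter's internal function. It detects LoRA keys whose tensors are all zeros and drops them before conversion.

```python
import torch

def purge_if_all_zero(state_dict: dict, prefix: str) -> bool:
    """Remove every key under `prefix` if all of their tensors are zero-filled."""
    keys = [k for k in state_dict if k.startswith(prefix)]
    if keys and all(torch.count_nonzero(state_dict[k]) == 0 for k in keys):
        for k in keys:
            state_dict.pop(k)
        return True
    return False
```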
@@ -1237,7 +1237,7 @@ def _convert_bfl_flux_control_lora_to_diffusers(original_state_dict):
f"double_blocks.{i}.txt_attn.norm.key_norm.scale"
)
- # single transfomer blocks
+ # single transformer blocks
for i in range(num_single_layers):
block_prefix = f"single_transformer_blocks.{i}."
...
@@ -2413,7 +2413,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
) -> bool:
"""
Control LoRA expands the shape of the input layer from (3072, 64) to (3072, 128). This method handles that and
- generalizes things a bit so that any parameter that needs expansion receives appropriate treatement.
+ generalizes things a bit so that any parameter that needs expansion receives appropriate treatment.
"""
state_dict = {}
if lora_state_dict is not None:
...
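Note: a hedged illustration of the expansion the docstring describes, not the mixin's actual code path. The base weight is zero-padded along the input dimension so the expanded input fits.

```python
import torch

base_weight = torch.randn(3072, 64)          # original input-projection weight shape
expanded = base_weight.new_zeros(3072, 128)  # Control LoRA expects 128 input features
expanded[:, : base_weight.shape[1]] = base_weight  # copy old weights, leave the rest zero
```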
@@ -330,7 +330,7 @@ class PeftAdapterMixin:
new_sd[k] = v
return new_sd
- # To handle scenarios where we cannot successfully set state dict. If it's unsucessful,
+ # To handle scenarios where we cannot successfully set state dict. If it's unsuccessful,
# we should also delete the `peft_config` associated to the `adapter_name`.
try:
if hotswap:
@@ -344,7 +344,7 @@ class PeftAdapterMixin:
config=lora_config,
)
except Exception as e:
- logger.error(f"Hotswapping {adapter_name} was unsucessful with the following error: \n{e}")
+ logger.error(f"Hotswapping {adapter_name} was unsuccessful with the following error: \n{e}")
raise
# the hotswap function raises if there are incompatible keys, so if we reach this point we can set
# it to None
@@ -379,7 +379,7 @@ class PeftAdapterMixin:
module.delete_adapter(adapter_name)
self.peft_config.pop(adapter_name)
- logger.error(f"Loading {adapter_name} was unsucessful with the following error: \n{e}")
+ logger.error(f"Loading {adapter_name} was unsuccessful with the following error: \n{e}")
raise
warn_msg = ""
@@ -712,7 +712,7 @@ class PeftAdapterMixin:
if self.lora_scale != 1.0:
module.scale_layer(self.lora_scale)
- # For BC with prevous PEFT versions, we need to check the signature
+ # For BC with previous PEFT versions, we need to check the signature
# of the `merge` method to see if it supports the `adapter_names` argument.
supported_merge_kwargs = list(inspect.signature(module.merge).parameters)
if "adapter_names" in supported_merge_kwargs:
...
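Note: a self-contained sketch of the backward-compatibility pattern in the last hunk above; the function and variable names are illustrative. The idea is to pass `adapter_names` to `merge` only when the installed PEFT version accepts it.

```python
import inspect

def merge_with_bc(module, adapter_names=None):
    merge_kwargs = {}
    # Older PEFT versions do not expose `adapter_names` on `merge`.
    if "adapter_names" in inspect.signature(module.merge).parameters:
        merge_kwargs["adapter_names"] = adapter_names
    elif adapter_names is not None:
        raise ValueError("Passing `adapter_names` requires a newer PEFT version.")
    module.merge(**merge_kwargs)
```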
@@ -453,7 +453,7 @@ class FromSingleFileMixin:
logger.warning(
"Detected legacy `from_single_file` loading behavior. Attempting to create the pipeline based on inferred components.\n"
"This may lead to errors if the model components are not correctly inferred. \n"
- "To avoid this warning, please explicity pass the `config` argument to `from_single_file` with a path to a local diffusers model repo \n"
+ "To avoid this warning, please explicitly pass the `config` argument to `from_single_file` with a path to a local diffusers model repo \n"
"e.g. `from_single_file(<my model checkpoint path>, config=<path to local diffusers model repo>) \n"
"or run `from_single_file` with `local_files_only=False` first to update the local cache directory with "
"the necessary config files.\n"
...
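Note: a hedged example of the non-legacy call the warning recommends; the paths are placeholders. Passing `config` with a local diffusers-format repo means the components do not have to be inferred.

```python
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_single_file(
    "/path/to/checkpoint.safetensors",       # placeholder single-file checkpoint
    config="/path/to/local-diffusers-repo",  # placeholder local diffusers model repo
    local_files_only=True,
)
```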
@@ -2278,7 +2278,7 @@ def convert_flux_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
f"double_blocks.{i}.txt_attn.proj.bias"
)
- # single transfomer blocks
+ # single transformer blocks
for i in range(num_single_layers):
block_prefix = f"single_transformer_blocks.{i}."
# norm.linear <- single_blocks.0.modulation.lin
@@ -2872,7 +2872,7 @@ def convert_auraflow_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
def convert_lumina2_to_diffusers(checkpoint, **kwargs):
converted_state_dict = {}
- # Original Lumina-Image-2 has an extra norm paramter that is unused
+ # Original Lumina-Image-2 has an extra norm parameter that is unused
# We just remove it here
checkpoint.pop("norm_final.weight", None)
...
@@ -123,7 +123,7 @@ class SD3Transformer2DLoadersMixin:
key = key.replace(f"layers.{idx}.2.1", f"layers.{idx}.adaln_proj")
updated_state_dict[key] = value
- # Image projetion parameters
+ # Image projection parameters
embed_dim = updated_state_dict["proj_in.weight"].shape[1]
output_dim = updated_state_dict["proj_out.weight"].shape[0]
hidden_dim = updated_state_dict["proj_in.weight"].shape[0]
...
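Note: the three lines above infer projection sizes from linear-layer weight shapes. As a reminder of the convention they rely on (sizes here are illustrative), `nn.Linear` stores its weight as `(out_features, in_features)`:

```python
import torch.nn as nn

proj_in = nn.Linear(1152, 1280)   # illustrative sizes
proj_out = nn.Linear(1280, 2432)

embed_dim = proj_in.weight.shape[1]    # in_features of proj_in  -> 1152
hidden_dim = proj_in.weight.shape[0]   # out_features of proj_in -> 1280
output_dim = proj_out.weight.shape[0]  # out_features of proj_out -> 2432
```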
@@ -734,17 +734,17 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
unet (`UNet2DConditionModel`):
The UNet model we want to control.
controlnet (`ControlNetXSAdapter`):
- The ConntrolNet-XS adapter with which the UNet will be fused. If none is given, a new ConntrolNet-XS
+ The ControlNet-XS adapter with which the UNet will be fused. If none is given, a new ControlNet-XS
adapter will be created.
size_ratio (float, *optional*, defaults to `None`):
- Used to contruct the controlnet if none is given. See [`ControlNetXSAdapter.from_unet`] for details.
+ Used to construct the controlnet if none is given. See [`ControlNetXSAdapter.from_unet`] for details.
ctrl_block_out_channels (`List[int]`, *optional*, defaults to `None`):
- Used to contruct the controlnet if none is given. See [`ControlNetXSAdapter.from_unet`] for details,
+ Used to construct the controlnet if none is given. See [`ControlNetXSAdapter.from_unet`] for details,
where this parameter is called `block_out_channels`.
time_embedding_mix (`float`, *optional*, defaults to None):
- Used to contruct the controlnet if none is given. See [`ControlNetXSAdapter.from_unet`] for details.
+ Used to construct the controlnet if none is given. See [`ControlNetXSAdapter.from_unet`] for details.
ctrl_optional_kwargs (`Dict`, *optional*, defaults to `None`):
- Passed to the `init` of the new controlent if no controlent was given.
+ Passed to the `init` of the new controlnet if no controlnet was given.
"""
if controlnet is None:
controlnet = ControlNetXSAdapter.from_unet(
...
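Note: a hedged usage sketch of the construction path the docstring describes; the repo id is a placeholder and the argument values are illustrative.

```python
from diffusers import ControlNetXSAdapter, UNet2DConditionModel, UNetControlNetXSModel

# Load a base UNet (placeholder repo), build an adapter from it, then fuse the two.
unet = UNet2DConditionModel.from_pretrained("some/sd-repo", subfolder="unet")
adapter = ControlNetXSAdapter.from_unet(unet, size_ratio=0.1)
fused = UNetControlNetXSModel.from_unet(unet, controlnet=adapter)
```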
@@ -97,7 +97,7 @@ def get_3d_sincos_pos_embed(
The spatial dimension of positional embeddings. If an integer is provided, the same size is applied to both
spatial dimensions (height and width).
temporal_size (`int`):
- The temporal dimension of postional embeddings (number of frames).
+ The temporal dimension of positional embeddings (number of frames).
spatial_interpolation_scale (`float`, defaults to 1.0):
Scale factor for spatial grid interpolation.
temporal_interpolation_scale (`float`, defaults to 1.0):
@@ -169,7 +169,7 @@ def _get_3d_sincos_pos_embed_np(
The spatial dimension of positional embeddings. If an integer is provided, the same size is applied to both
spatial dimensions (height and width).
temporal_size (`int`):
- The temporal dimension of postional embeddings (number of frames).
+ The temporal dimension of positional embeddings (number of frames).
spatial_interpolation_scale (`float`, defaults to 1.0):
Scale factor for spatial grid interpolation.
temporal_interpolation_scale (`float`, defaults to 1.0):
...
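Note: for reference, a minimal sketch of the 1D sine-cosine embedding these 3D helpers build on (illustrative, not the library's exact implementation). The spatial and temporal grids each receive such an embedding before being combined.

```python
import torch

def sincos_1d(embed_dim: int, positions: torch.Tensor) -> torch.Tensor:
    # positions: (N,) -> returns (N, embed_dim), sin half followed by cos half.
    omega = 1.0 / 10000 ** (torch.arange(embed_dim // 2, dtype=torch.float64) / (embed_dim / 2))
    angles = positions.to(torch.float64)[:, None] * omega[None, :]
    return torch.cat([torch.sin(angles), torch.cos(angles)], dim=1)
```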
@@ -30,7 +30,7 @@ class LatteTransformer3DModel(ModelMixin, ConfigMixin, CacheMixin):
_supports_gradient_checkpointing = True
"""
- A 3D Transformer model for video-like data, paper: https://arxiv.org/abs/2401.03048, offical code:
+ A 3D Transformer model for video-like data, paper: https://arxiv.org/abs/2401.03048, official code:
https://github.com/Vchitect/Latte
Parameters:
@@ -216,7 +216,7 @@ class LatteTransformer3DModel(ModelMixin, ConfigMixin, CacheMixin):
)
num_patches = height * width
- hidden_states = self.pos_embed(hidden_states) # alrady add positional embeddings
+ hidden_states = self.pos_embed(hidden_states) # already add positional embeddings
added_cond_kwargs = {"resolution": None, "aspect_ratio": None}
timestep, embedded_timestep = self.adaln_single(
...
@@ -43,7 +43,7 @@ class LuminaNextDiTBlock(nn.Module):
num_kv_heads (`int`):
Number of attention heads in key and value features (if using GQA), or set to None for the same as query.
multiple_of (`int`): The number of multiple of ffn layer.
- ffn_dim_multiplier (`float`): The multipier factor of ffn layer dimension.
+ ffn_dim_multiplier (`float`): The multiplier factor of ffn layer dimension.
norm_eps (`float`): The eps for norm layer.
qk_norm (`bool`): normalization for query and key.
cross_attention_dim (`int`): Cross attention embedding dimension of the input text prompt hidden_states.
...
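Note: a hedged illustration of how `multiple_of` and `ffn_dim_multiplier` are commonly used to size the feed-forward layer in LLaMA-style blocks such as this one; the constants are not necessarily Lumina's exact values.

```python
def ffn_hidden_dim(dim, multiple_of=256, ffn_dim_multiplier=None):
    hidden = int(2 * (4 * dim) / 3)
    if ffn_dim_multiplier is not None:
        hidden = int(ffn_dim_multiplier * hidden)
    # Round up to the nearest multiple of `multiple_of`.
    return multiple_of * ((hidden + multiple_of - 1) // multiple_of)
```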
@@ -154,7 +154,7 @@ class I2VGenXLUNet(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
# of that, we used `num_attention_heads` for arguments that actually denote attention head dimension. This
# is why we ignore `num_attention_heads` and calculate it from `attention_head_dims` below.
# This is still an incorrect way of calculating `num_attention_heads` but we need to stick to it
- # without running proper depcrecation cycles for the {down,mid,up} blocks which are a
+ # without running proper deprecation cycles for the {down,mid,up} blocks which are a
# part of the public API.
num_attention_heads = attention_head_dim
...
@@ -131,7 +131,7 @@ class AmusedPipeline(DiffusionPipeline):
generation deterministic.
latents (`torch.IntTensor`, *optional*):
Pre-generated tokens representing latent vectors in `self.vqvae`, to be used as inputs for image
- gneration. If not provided, the starting latents will be completely masked.
+ generation. If not provided, the starting latents will be completely masked.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
provided, text embeddings are generated from the `prompt` input argument. A single vector from the
...
@@ -373,7 +373,7 @@ class AudioLDM2Pipeline(DiffusionPipeline):
*e.g.* prompt weighting. If not provided, negative_prompt_embeds will be computed from
`negative_prompt` input argument.
generated_prompt_embeds (`torch.Tensor`, *optional*):
- Pre-generated text embeddings from the GPT2 langauge model. Can be used to easily tweak text inputs,
+ Pre-generated text embeddings from the GPT2 language model. Can be used to easily tweak text inputs,
*e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input
argument.
negative_generated_prompt_embeds (`torch.Tensor`, *optional*):
@@ -394,7 +394,7 @@ class AudioLDM2Pipeline(DiffusionPipeline):
attention_mask (`torch.LongTensor`):
Attention mask to be applied to the `prompt_embeds`.
generated_prompt_embeds (`torch.Tensor`):
- Text embeddings generated from the GPT2 langauge model.
+ Text embeddings generated from the GPT2 language model.
Example:
@@ -904,7 +904,7 @@ class AudioLDM2Pipeline(DiffusionPipeline):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
generated_prompt_embeds (`torch.Tensor`, *optional*):
- Pre-generated text embeddings from the GPT2 langauge model. Can be used to easily tweak text inputs,
+ Pre-generated text embeddings from the GPT2 language model. Can be used to easily tweak text inputs,
*e.g.* prompt weighting. If not provided, text embeddings will be generated from `prompt` input
argument.
negative_generated_prompt_embeds (`torch.Tensor`, *optional*):
...
@@ -138,7 +138,7 @@ class BlipDiffusionPipeline(DiffusionPipeline):
def get_query_embeddings(self, input_image, src_subject):
return self.qformer(image_input=input_image, text_input=src_subject, return_dict=False)
- # from the original Blip Diffusion code, speciefies the target subject and augments the prompt by repeating it
+ # from the original Blip Diffusion code, specifies the target subject and augments the prompt by repeating it
def _build_prompt(self, prompts, tgt_subjects, prompt_strength=1.0, prompt_reps=20):
rv = []
for prompt, tgt_subject in zip(prompts, tgt_subjects):
...
@@ -149,7 +149,7 @@ class BlipDiffusionControlNetPipeline(DiffusionPipeline):
def get_query_embeddings(self, input_image, src_subject):
return self.qformer(image_input=input_image, text_input=src_subject, return_dict=False)
- # from the original Blip Diffusion code, speciefies the target subject and augments the prompt by repeating it
+ # from the original Blip Diffusion code, specifies the target subject and augments the prompt by repeating it
def _build_prompt(self, prompts, tgt_subjects, prompt_strength=1.0, prompt_reps=20):
rv = []
for prompt, tgt_subject in zip(prompts, tgt_subjects):
...
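Note: a hedged, simplified sketch of the prompt augmentation the comments above describe; names mirror `_build_prompt` but this is not the pipeline's exact code. The subject-conditioned prompt is repeated in proportion to `prompt_strength * prompt_reps`.

```python
def build_prompt(prompt: str, tgt_subject: str, prompt_strength: float = 1.0, prompt_reps: int = 20) -> str:
    # Prepend the target subject, then repeat the prompt to strengthen the conditioning.
    subject_prompt = f"a {tgt_subject} {prompt.strip()}"
    return ", ".join([subject_prompt] * int(prompt_strength * prompt_reps))
```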