Unverified commit 0583a8d1, authored by Aryan and committed by GitHub

Make CogVideoX RoPE implementation consistent (#9963)

* update cogvideox rope implementation

* apply suggestions from review
parent 7d0b9c4d
@@ -444,21 +444,34 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
         grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
 
         p = self.transformer.config.patch_size
-        p_t = self.transformer.config.patch_size_t or 1
+        p_t = self.transformer.config.patch_size_t
 
         base_size_width = self.transformer.config.sample_width // p
         base_size_height = self.transformer.config.sample_height // p
-        base_num_frames = (num_frames + p_t - 1) // p_t
 
-        grid_crops_coords = get_resize_crop_region_for_grid(
-            (grid_height, grid_width), base_size_width, base_size_height
-        )
-        freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
-            embed_dim=self.transformer.config.attention_head_dim,
-            crops_coords=grid_crops_coords,
-            grid_size=(grid_height, grid_width),
-            temporal_size=base_num_frames,
-        )
+        if p_t is None:
+            # CogVideoX 1.0
+            grid_crops_coords = get_resize_crop_region_for_grid(
+                (grid_height, grid_width), base_size_width, base_size_height
+            )
+            freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
+                embed_dim=self.transformer.config.attention_head_dim,
+                crops_coords=grid_crops_coords,
+                grid_size=(grid_height, grid_width),
+                temporal_size=num_frames,
+            )
+        else:
+            # CogVideoX 1.5
+            base_num_frames = (num_frames + p_t - 1) // p_t
+            freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
+                embed_dim=self.transformer.config.attention_head_dim,
+                crops_coords=None,
+                grid_size=(grid_height, grid_width),
+                temporal_size=base_num_frames,
+                grid_type="slice",
+                max_size=(base_size_height, base_size_width),
+            )
 
         freqs_cos = freqs_cos.to(device=device)
         freqs_sin = freqs_sin.to(device=device)
...
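As a reference, here is a minimal, self-contained sketch of the branching this commit introduces in `_prepare_rotary_positional_embeddings`. The helper names and keyword arguments are taken from the diff above; the standalone function name and the default config values (head dim, base sample sizes, VAE scale factor) are illustrative assumptions for this sketch, not values read from a specific checkpoint.

# Hedged sketch of the unified CogVideoX 1.0 / 1.5 RoPE preparation.
from typing import Optional, Tuple

import torch

from diffusers.models.embeddings import get_3d_rotary_pos_embed
from diffusers.pipelines.cogvideo.pipeline_cogvideox import get_resize_crop_region_for_grid


def prepare_rotary_positional_embeddings_sketch(
    height: int,
    width: int,
    num_frames: int,
    attention_head_dim: int = 64,        # assumed head dim, for illustration only
    sample_height: int = 60,             # assumed base latent height (transformer config units)
    sample_width: int = 90,              # assumed base latent width
    vae_scale_factor_spatial: int = 8,   # assumed VAE spatial downscale factor
    patch_size: int = 2,
    patch_size_t: Optional[int] = None,  # None -> CogVideoX 1.0 path, int -> 1.5 path
    device: torch.device = torch.device("cpu"),
) -> Tuple[torch.Tensor, torch.Tensor]:
    # Spatial grid of latent patches for the requested resolution.
    grid_height = height // (vae_scale_factor_spatial * patch_size)
    grid_width = width // (vae_scale_factor_spatial * patch_size)
    # Base grid the model was trained on, in latent-patch units.
    base_size_width = sample_width // patch_size
    base_size_height = sample_height // patch_size

    if patch_size_t is None:
        # CogVideoX 1.0: crop-based RoPE; temporal size is the raw latent frame count.
        grid_crops_coords = get_resize_crop_region_for_grid(
            (grid_height, grid_width), base_size_width, base_size_height
        )
        freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
            embed_dim=attention_head_dim,
            crops_coords=grid_crops_coords,
            grid_size=(grid_height, grid_width),
            temporal_size=num_frames,
        )
    else:
        # CogVideoX 1.5: sliced RoPE; frames are rounded up to whole temporal patches.
        base_num_frames = (num_frames + patch_size_t - 1) // patch_size_t
        freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
            embed_dim=attention_head_dim,
            crops_coords=None,
            grid_size=(grid_height, grid_width),
            temporal_size=base_num_frames,
            grid_type="slice",
            max_size=(base_size_height, base_size_width),
        )

    return freqs_cos.to(device=device), freqs_sin.to(device=device)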
@@ -490,21 +490,34 @@ class CogVideoXFunControlPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
         grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
 
         p = self.transformer.config.patch_size
-        p_t = self.transformer.config.patch_size_t or 1
+        p_t = self.transformer.config.patch_size_t
 
         base_size_width = self.transformer.config.sample_width // p
         base_size_height = self.transformer.config.sample_height // p
-        base_num_frames = (num_frames + p_t - 1) // p_t
 
-        grid_crops_coords = get_resize_crop_region_for_grid(
-            (grid_height, grid_width), base_size_width, base_size_height
-        )
-        freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
-            embed_dim=self.transformer.config.attention_head_dim,
-            crops_coords=grid_crops_coords,
-            grid_size=(grid_height, grid_width),
-            temporal_size=base_num_frames,
-        )
+        if p_t is None:
+            # CogVideoX 1.0
+            grid_crops_coords = get_resize_crop_region_for_grid(
+                (grid_height, grid_width), base_size_width, base_size_height
+            )
+            freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
+                embed_dim=self.transformer.config.attention_head_dim,
+                crops_coords=grid_crops_coords,
+                grid_size=(grid_height, grid_width),
+                temporal_size=num_frames,
+            )
+        else:
+            # CogVideoX 1.5
+            base_num_frames = (num_frames + p_t - 1) // p_t
+            freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
+                embed_dim=self.transformer.config.attention_head_dim,
+                crops_coords=None,
+                grid_size=(grid_height, grid_width),
+                temporal_size=base_num_frames,
+                grid_type="slice",
+                max_size=(base_size_height, base_size_width),
+            )
 
         freqs_cos = freqs_cos.to(device=device)
         freqs_sin = freqs_sin.to(device=device)
...
@@ -528,6 +528,7 @@ class CogVideoXImageToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
         self.transformer.unfuse_qkv_projections()
         self.fusing_transformer = False
 
+    # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline._prepare_rotary_positional_embeddings
     def _prepare_rotary_positional_embeddings(
         self,
         height: int,
@@ -541,11 +542,11 @@ class CogVideoXImageToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
         p = self.transformer.config.patch_size
         p_t = self.transformer.config.patch_size_t
 
-        if p_t is None:
-            # CogVideoX 1.0 I2V
-            base_size_width = self.transformer.config.sample_width // p
-            base_size_height = self.transformer.config.sample_height // p
+        base_size_width = self.transformer.config.sample_width // p
+        base_size_height = self.transformer.config.sample_height // p
 
+        if p_t is None:
+            # CogVideoX 1.0
             grid_crops_coords = get_resize_crop_region_for_grid(
                 (grid_height, grid_width), base_size_width, base_size_height
             )
@@ -556,9 +557,7 @@ class CogVideoXImageToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
                 temporal_size=num_frames,
             )
         else:
-            # CogVideoX 1.5 I2V
-            base_size_width = self.transformer.config.sample_width // p
-            base_size_height = self.transformer.config.sample_height // p
-
+            # CogVideoX 1.5
             base_num_frames = (num_frames + p_t - 1) // p_t
             freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
...
@@ -520,21 +520,34 @@ class CogVideoXVideoToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
         grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
 
         p = self.transformer.config.patch_size
-        p_t = self.transformer.config.patch_size_t or 1
+        p_t = self.transformer.config.patch_size_t
 
         base_size_width = self.transformer.config.sample_width // p
         base_size_height = self.transformer.config.sample_height // p
-        base_num_frames = (num_frames + p_t - 1) // p_t
 
-        grid_crops_coords = get_resize_crop_region_for_grid(
-            (grid_height, grid_width), base_size_width, base_size_height
-        )
-        freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
-            embed_dim=self.transformer.config.attention_head_dim,
-            crops_coords=grid_crops_coords,
-            grid_size=(grid_height, grid_width),
-            temporal_size=base_num_frames,
-        )
+        if p_t is None:
+            # CogVideoX 1.0
+            grid_crops_coords = get_resize_crop_region_for_grid(
+                (grid_height, grid_width), base_size_width, base_size_height
+            )
+            freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
+                embed_dim=self.transformer.config.attention_head_dim,
+                crops_coords=grid_crops_coords,
+                grid_size=(grid_height, grid_width),
+                temporal_size=num_frames,
+            )
+        else:
+            # CogVideoX 1.5
+            base_num_frames = (num_frames + p_t - 1) // p_t
+            freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
+                embed_dim=self.transformer.config.attention_head_dim,
+                crops_coords=None,
+                grid_size=(grid_height, grid_width),
+                temporal_size=base_num_frames,
+                grid_type="slice",
+                max_size=(base_size_height, base_size_width),
+            )
 
         freqs_cos = freqs_cos.to(device=device)
         freqs_sin = freqs_sin.to(device=device)
...
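For completeness, a couple of example calls against the sketch given earlier. The resolutions, frame counts, and base sizes below are illustrative assumptions chosen only to exercise both branches, not checkpoint defaults.

# Illustrative only: exercise both branches of the hypothetical sketch above.
cos_v1, sin_v1 = prepare_rotary_positional_embeddings_sketch(
    height=480, width=720, num_frames=13, patch_size_t=None,  # 1.0-style path
)
cos_v15, sin_v15 = prepare_rotary_positional_embeddings_sketch(
    height=768, width=1360, num_frames=21, patch_size_t=2,    # 1.5-style path
    sample_height=300, sample_width=300,                      # assumed larger base latent grid
)
print(cos_v1.shape, cos_v15.shape)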