Unverified Commit 0583a8d1 authored by Aryan, committed by GitHub

Make CogVideoX RoPE implementation consistent (#9963)

* update cogvideox rope implementation

* apply suggestions from review
parent 7d0b9c4d
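For orientation, the sketch below shows, as a standalone function, the branching that this commit standardizes across the CogVideoX pipelines: CogVideoX 1.0 checkpoints have no temporal patch size (`patch_size_t is None`) and keep the resize-crop RoPE grid, while CogVideoX 1.5 checkpoints ceil-divide the latent frame count by `patch_size_t` and build the grid with `grid_type="slice"`. This is not the upstream method itself; the function name, its parameter list, and the import location of `get_resize_crop_region_for_grid` are assumptions made for the sketch, while `get_3d_rotary_pos_embed` is the real diffusers helper the pipelines call.

```python
# Sketch only: a standalone version of the branching this commit applies to all
# CogVideoX pipelines. Function name, parameters, and the import location of
# get_resize_crop_region_for_grid are assumptions for illustration.
from typing import Optional, Tuple

import torch
from diffusers.models.embeddings import get_3d_rotary_pos_embed
from diffusers.pipelines.cogvideo.pipeline_cogvideox import get_resize_crop_region_for_grid


def prepare_cogvideox_rope(
    height: int,
    width: int,
    num_frames: int,  # number of latent frames, as passed by the pipelines
    *,
    patch_size: int,
    patch_size_t: Optional[int],  # None for CogVideoX 1.0 checkpoints
    sample_height: int,
    sample_width: int,
    attention_head_dim: int,
    vae_scale_factor_spatial: int,
    device: torch.device,
) -> Tuple[torch.Tensor, torch.Tensor]:
    grid_height = height // (vae_scale_factor_spatial * patch_size)
    grid_width = width // (vae_scale_factor_spatial * patch_size)
    base_size_width = sample_width // patch_size
    base_size_height = sample_height // patch_size

    if patch_size_t is None:
        # CogVideoX 1.0: crop the base RoPE grid to the requested resolution.
        grid_crops_coords = get_resize_crop_region_for_grid(
            (grid_height, grid_width), base_size_width, base_size_height
        )
        freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
            embed_dim=attention_head_dim,
            crops_coords=grid_crops_coords,
            grid_size=(grid_height, grid_width),
            temporal_size=num_frames,
        )
    else:
        # CogVideoX 1.5: frames are patched temporally, so ceil-divide by patch_size_t
        # and slice the grid out of the model's maximum spatial size.
        base_num_frames = (num_frames + patch_size_t - 1) // patch_size_t
        freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
            embed_dim=attention_head_dim,
            crops_coords=None,
            grid_size=(grid_height, grid_width),
            temporal_size=base_num_frames,
            grid_type="slice",
            max_size=(base_size_height, base_size_width),
        )

    return freqs_cos.to(device=device), freqs_sin.to(device=device)
```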
```diff
@@ -444,12 +444,13 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
         grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
 
         p = self.transformer.config.patch_size
-        p_t = self.transformer.config.patch_size_t or 1
+        p_t = self.transformer.config.patch_size_t
 
         base_size_width = self.transformer.config.sample_width // p
         base_size_height = self.transformer.config.sample_height // p
-        base_num_frames = (num_frames + p_t - 1) // p_t
 
-        grid_crops_coords = get_resize_crop_region_for_grid(
-            (grid_height, grid_width), base_size_width, base_size_height
-        )
+        if p_t is None:
+            # CogVideoX 1.0
+            grid_crops_coords = get_resize_crop_region_for_grid(
+                (grid_height, grid_width), base_size_width, base_size_height
+            )
@@ -457,7 +458,19 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
-            embed_dim=self.transformer.config.attention_head_dim,
-            crops_coords=grid_crops_coords,
-            grid_size=(grid_height, grid_width),
-            temporal_size=base_num_frames,
-        )
+                embed_dim=self.transformer.config.attention_head_dim,
+                crops_coords=grid_crops_coords,
+                grid_size=(grid_height, grid_width),
+                temporal_size=num_frames,
+            )
+        else:
+            # CogVideoX 1.5
+            base_num_frames = (num_frames + p_t - 1) // p_t
+            freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
+                embed_dim=self.transformer.config.attention_head_dim,
+                crops_coords=None,
+                grid_size=(grid_height, grid_width),
+                temporal_size=base_num_frames,
+                grid_type="slice",
+                max_size=(base_size_height, base_size_width),
+            )
 
         freqs_cos = freqs_cos.to(device=device)
```
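The only arithmetic in the new CogVideoX 1.5 branch is the ceiling division that maps latent frames onto temporal RoPE positions once a temporal patch size is configured. The numbers below are illustrative only, not taken from any specific checkpoint config:

```python
# Illustrative values only (not from a real config): with a temporal patch size of 2,
# 13 latent frames collapse to ceil(13 / 2) = 7 temporal RoPE positions; with
# patch_size_t = None (CogVideoX 1.0) the frame count is used as-is.
num_frames = 13  # latent frames passed to _prepare_rotary_positional_embeddings
p_t = 2          # hypothetical patch_size_t for a CogVideoX 1.5-style config
base_num_frames = (num_frames + p_t - 1) // p_t
assert base_num_frames == 7
```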
```diff
@@ -490,12 +490,13 @@ class CogVideoXFunControlPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
         grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
 
         p = self.transformer.config.patch_size
-        p_t = self.transformer.config.patch_size_t or 1
+        p_t = self.transformer.config.patch_size_t
 
         base_size_width = self.transformer.config.sample_width // p
         base_size_height = self.transformer.config.sample_height // p
-        base_num_frames = (num_frames + p_t - 1) // p_t
 
-        grid_crops_coords = get_resize_crop_region_for_grid(
-            (grid_height, grid_width), base_size_width, base_size_height
-        )
+        if p_t is None:
+            # CogVideoX 1.0
+            grid_crops_coords = get_resize_crop_region_for_grid(
+                (grid_height, grid_width), base_size_width, base_size_height
+            )
@@ -503,7 +504,19 @@ class CogVideoXFunControlPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
-            embed_dim=self.transformer.config.attention_head_dim,
-            crops_coords=grid_crops_coords,
-            grid_size=(grid_height, grid_width),
-            temporal_size=base_num_frames,
-        )
+                embed_dim=self.transformer.config.attention_head_dim,
+                crops_coords=grid_crops_coords,
+                grid_size=(grid_height, grid_width),
+                temporal_size=num_frames,
+            )
+        else:
+            # CogVideoX 1.5
+            base_num_frames = (num_frames + p_t - 1) // p_t
+            freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
+                embed_dim=self.transformer.config.attention_head_dim,
+                crops_coords=None,
+                grid_size=(grid_height, grid_width),
+                temporal_size=base_num_frames,
+                grid_type="slice",
+                max_size=(base_size_height, base_size_width),
+            )
 
         freqs_cos = freqs_cos.to(device=device)
```
```diff
@@ -528,6 +528,7 @@ class CogVideoXImageToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
         self.transformer.unfuse_qkv_projections()
         self.fusing_transformer = False
 
+    # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline._prepare_rotary_positional_embeddings
     def _prepare_rotary_positional_embeddings(
         self,
         height: int,
@@ -541,11 +542,11 @@ class CogVideoXImageToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
         p = self.transformer.config.patch_size
         p_t = self.transformer.config.patch_size_t
 
-        if p_t is None:
-            # CogVideoX 1.0 I2V
-            base_size_width = self.transformer.config.sample_width // p
-            base_size_height = self.transformer.config.sample_height // p
+        base_size_width = self.transformer.config.sample_width // p
+        base_size_height = self.transformer.config.sample_height // p
 
+        if p_t is None:
+            # CogVideoX 1.0
             grid_crops_coords = get_resize_crop_region_for_grid(
                 (grid_height, grid_width), base_size_width, base_size_height
             )
@@ -556,9 +557,7 @@ class CogVideoXImageToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
                 grid_size=(grid_height, grid_width),
                 temporal_size=num_frames,
             )
         else:
-            # CogVideoX 1.5 I2V
-            base_size_width = self.transformer.config.sample_width // p
-            base_size_height = self.transformer.config.sample_height // p
+            # CogVideoX 1.5
             base_num_frames = (num_frames + p_t - 1) // p_t
             freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
```
```diff
@@ -520,12 +520,13 @@ class CogVideoXVideoToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
         grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
 
         p = self.transformer.config.patch_size
-        p_t = self.transformer.config.patch_size_t or 1
+        p_t = self.transformer.config.patch_size_t
 
         base_size_width = self.transformer.config.sample_width // p
         base_size_height = self.transformer.config.sample_height // p
-        base_num_frames = (num_frames + p_t - 1) // p_t
 
-        grid_crops_coords = get_resize_crop_region_for_grid(
-            (grid_height, grid_width), base_size_width, base_size_height
-        )
+        if p_t is None:
+            # CogVideoX 1.0
+            grid_crops_coords = get_resize_crop_region_for_grid(
+                (grid_height, grid_width), base_size_width, base_size_height
+            )
@@ -533,7 +534,19 @@ class CogVideoXVideoToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
-            embed_dim=self.transformer.config.attention_head_dim,
-            crops_coords=grid_crops_coords,
-            grid_size=(grid_height, grid_width),
-            temporal_size=base_num_frames,
-        )
+                embed_dim=self.transformer.config.attention_head_dim,
+                crops_coords=grid_crops_coords,
+                grid_size=(grid_height, grid_width),
+                temporal_size=num_frames,
+            )
+        else:
+            # CogVideoX 1.5
+            base_num_frames = (num_frames + p_t - 1) // p_t
+            freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
+                embed_dim=self.transformer.config.attention_head_dim,
+                crops_coords=None,
+                grid_size=(grid_height, grid_width),
+                temporal_size=base_num_frames,
+                grid_type="slice",
+                max_size=(base_size_height, base_size_width),
+            )
 
         freqs_cos = freqs_cos.to(device=device)
```
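Because the commit only changes how the pipelines build their rotary embeddings internally, the public API is untouched. A quick smoke test along the lines of the sketch below exercises the updated code path end to end; the checkpoint id and generation settings are just an example, and it assumes a CUDA GPU with enough memory:

```python
# Minimal smoke-test sketch: runs CogVideoXPipeline end to end so that the updated
# _prepare_rotary_positional_embeddings path is exercised. Checkpoint and settings
# are examples, not values mandated by this commit.
import torch
from diffusers import CogVideoXPipeline
from diffusers.utils import export_to_video

pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
pipe.to("cuda")

video = pipe(
    prompt="A panda strumming a guitar in a bamboo forest",
    num_frames=49,
    num_inference_steps=50,
    guidance_scale=6.0,
    generator=torch.Generator(device="cuda").manual_seed(42),
).frames[0]

export_to_video(video, "cogvideox_smoke_test.mp4", fps=8)
```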