"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "33f936154db0bc7080960316b4ddb291e9555bf7"
Unverified commit 4c4b323c, authored by hlky, committed by GitHub
Browse files

Use `torch` in `get_3d_rotary_pos_embed`/`_allegro` (#10161)

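Use torch instead of NumPy to build the position grids in get_3d_rotary_pos_embed and get_3d_rotary_pos_embed_allegro, and add a `device` argument so the embeddings are created directly on the target device.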
parent 22d3a826
...@@ -872,10 +872,9 @@ def prepare_rotary_positional_embeddings( ...@@ -872,10 +872,9 @@ def prepare_rotary_positional_embeddings(
crops_coords=grid_crops_coords, crops_coords=grid_crops_coords,
grid_size=(grid_height, grid_width), grid_size=(grid_height, grid_width),
temporal_size=num_frames, temporal_size=num_frames,
device=device,
) )
freqs_cos = freqs_cos.to(device=device)
freqs_sin = freqs_sin.to(device=device)
return freqs_cos, freqs_sin return freqs_cos, freqs_sin
......
...@@ -894,10 +894,9 @@ def prepare_rotary_positional_embeddings( ...@@ -894,10 +894,9 @@ def prepare_rotary_positional_embeddings(
crops_coords=grid_crops_coords, crops_coords=grid_crops_coords,
grid_size=(grid_height, grid_width), grid_size=(grid_height, grid_width),
temporal_size=num_frames, temporal_size=num_frames,
device=device,
) )
freqs_cos = freqs_cos.to(device=device)
freqs_sin = freqs_sin.to(device=device)
return freqs_cos, freqs_sin return freqs_cos, freqs_sin
......
...@@ -594,6 +594,7 @@ def get_3d_rotary_pos_embed( ...@@ -594,6 +594,7 @@ def get_3d_rotary_pos_embed(
use_real: bool = True, use_real: bool = True,
grid_type: str = "linspace", grid_type: str = "linspace",
max_size: Optional[Tuple[int, int]] = None, max_size: Optional[Tuple[int, int]] = None,
device: Optional[torch.device] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
""" """
RoPE for video tokens with 3D structure. RoPE for video tokens with 3D structure.
...@@ -621,16 +622,22 @@ def get_3d_rotary_pos_embed( ...@@ -621,16 +622,22 @@ def get_3d_rotary_pos_embed(
if grid_type == "linspace": if grid_type == "linspace":
start, stop = crops_coords start, stop = crops_coords
grid_size_h, grid_size_w = grid_size grid_size_h, grid_size_w = grid_size
grid_h = np.linspace(start[0], stop[0], grid_size_h, endpoint=False, dtype=np.float32) grid_h = torch.linspace(
grid_w = np.linspace(start[1], stop[1], grid_size_w, endpoint=False, dtype=np.float32) start[0], stop[0] * (grid_size_h - 1) / grid_size_h, grid_size_h, device=device, dtype=torch.float32
grid_t = np.arange(temporal_size, dtype=np.float32) )
grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32) grid_w = torch.linspace(
start[1], stop[1] * (grid_size_w - 1) / grid_size_w, grid_size_w, device=device, dtype=torch.float32
)
grid_t = torch.arange(temporal_size, device=device, dtype=torch.float32)
grid_t = torch.linspace(
0, temporal_size * (temporal_size - 1) / temporal_size, temporal_size, device=device, dtype=torch.float32
)
elif grid_type == "slice": elif grid_type == "slice":
max_h, max_w = max_size max_h, max_w = max_size
grid_size_h, grid_size_w = grid_size grid_size_h, grid_size_w = grid_size
grid_h = np.arange(max_h, dtype=np.float32) grid_h = torch.arange(max_h, device=device, dtype=torch.float32)
grid_w = np.arange(max_w, dtype=np.float32) grid_w = torch.arange(max_w, device=device, dtype=torch.float32)
grid_t = np.arange(temporal_size, dtype=np.float32) grid_t = torch.arange(temporal_size, device=device, dtype=torch.float32)
else: else:
raise ValueError("Invalid value passed for `grid_type`.") raise ValueError("Invalid value passed for `grid_type`.")
...@@ -640,10 +647,10 @@ def get_3d_rotary_pos_embed( ...@@ -640,10 +647,10 @@ def get_3d_rotary_pos_embed(
dim_w = embed_dim // 8 * 3 dim_w = embed_dim // 8 * 3
# Temporal frequencies # Temporal frequencies
freqs_t = get_1d_rotary_pos_embed(dim_t, grid_t, use_real=True) freqs_t = get_1d_rotary_pos_embed(dim_t, grid_t, theta=theta, use_real=True)
# Spatial frequencies for height and width # Spatial frequencies for height and width
freqs_h = get_1d_rotary_pos_embed(dim_h, grid_h, use_real=True) freqs_h = get_1d_rotary_pos_embed(dim_h, grid_h, theta=theta, use_real=True)
freqs_w = get_1d_rotary_pos_embed(dim_w, grid_w, use_real=True) freqs_w = get_1d_rotary_pos_embed(dim_w, grid_w, theta=theta, use_real=True)
# BroadCast and concatenate temporal and spaial frequencie (height and width) into a 3d tensor # BroadCast and concatenate temporal and spaial frequencie (height and width) into a 3d tensor
def combine_time_height_width(freqs_t, freqs_h, freqs_w): def combine_time_height_width(freqs_t, freqs_h, freqs_w):
...@@ -686,14 +693,21 @@ def get_3d_rotary_pos_embed_allegro( ...@@ -686,14 +693,21 @@ def get_3d_rotary_pos_embed_allegro(
temporal_size, temporal_size,
interpolation_scale: Tuple[float, float, float] = (1.0, 1.0, 1.0), interpolation_scale: Tuple[float, float, float] = (1.0, 1.0, 1.0),
theta: int = 10000, theta: int = 10000,
device: Optional[torch.device] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
# TODO(aryan): docs # TODO(aryan): docs
start, stop = crops_coords start, stop = crops_coords
grid_size_h, grid_size_w = grid_size grid_size_h, grid_size_w = grid_size
interpolation_scale_t, interpolation_scale_h, interpolation_scale_w = interpolation_scale interpolation_scale_t, interpolation_scale_h, interpolation_scale_w = interpolation_scale
grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32) grid_t = torch.linspace(
grid_h = np.linspace(start[0], stop[0], grid_size_h, endpoint=False, dtype=np.float32) 0, temporal_size * (temporal_size - 1) / temporal_size, temporal_size, device=device, dtype=torch.float32
grid_w = np.linspace(start[1], stop[1], grid_size_w, endpoint=False, dtype=np.float32) )
grid_h = torch.linspace(
start[0], stop[0] * (grid_size_h - 1) / grid_size_h, grid_size_h, device=device, dtype=torch.float32
)
grid_w = torch.linspace(
start[1], stop[1] * (grid_size_w - 1) / grid_size_w, grid_size_w, device=device, dtype=torch.float32
)
# Compute dimensions for each axis # Compute dimensions for each axis
dim_t = embed_dim // 3 dim_t = embed_dim // 3
......
...@@ -623,20 +623,17 @@ class AllegroPipeline(DiffusionPipeline): ...@@ -623,20 +623,17 @@ class AllegroPipeline(DiffusionPipeline):
self.transformer.config.interpolation_scale_h, self.transformer.config.interpolation_scale_h,
self.transformer.config.interpolation_scale_w, self.transformer.config.interpolation_scale_w,
), ),
device=device,
) )
grid_t = torch.from_numpy(grid_t).to(device=device, dtype=torch.long) grid_t = grid_t.to(dtype=torch.long)
grid_h = torch.from_numpy(grid_h).to(device=device, dtype=torch.long) grid_h = grid_h.to(dtype=torch.long)
grid_w = torch.from_numpy(grid_w).to(device=device, dtype=torch.long) grid_w = grid_w.to(dtype=torch.long)
pos = torch.cartesian_prod(grid_t, grid_h, grid_w) pos = torch.cartesian_prod(grid_t, grid_h, grid_w)
pos = pos.reshape(-1, 3).transpose(0, 1).reshape(3, 1, -1).contiguous() pos = pos.reshape(-1, 3).transpose(0, 1).reshape(3, 1, -1).contiguous()
grid_t, grid_h, grid_w = pos grid_t, grid_h, grid_w = pos
freqs_t = (freqs_t[0].to(device=device), freqs_t[1].to(device=device))
freqs_h = (freqs_h[0].to(device=device), freqs_h[1].to(device=device))
freqs_w = (freqs_w[0].to(device=device), freqs_w[1].to(device=device))
return (freqs_t, freqs_h, freqs_w), (grid_t, grid_h, grid_w) return (freqs_t, freqs_h, freqs_w), (grid_t, grid_h, grid_w)
@property @property
......
...@@ -459,6 +459,7 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin): ...@@ -459,6 +459,7 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
crops_coords=grid_crops_coords, crops_coords=grid_crops_coords,
grid_size=(grid_height, grid_width), grid_size=(grid_height, grid_width),
temporal_size=num_frames, temporal_size=num_frames,
device=device,
) )
else: else:
# CogVideoX 1.5 # CogVideoX 1.5
...@@ -471,10 +472,9 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin): ...@@ -471,10 +472,9 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
temporal_size=base_num_frames, temporal_size=base_num_frames,
grid_type="slice", grid_type="slice",
max_size=(base_size_height, base_size_width), max_size=(base_size_height, base_size_width),
device=device,
) )
freqs_cos = freqs_cos.to(device=device)
freqs_sin = freqs_sin.to(device=device)
return freqs_cos, freqs_sin return freqs_cos, freqs_sin
@property @property
......
...@@ -505,6 +505,7 @@ class CogVideoXFunControlPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin): ...@@ -505,6 +505,7 @@ class CogVideoXFunControlPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
crops_coords=grid_crops_coords, crops_coords=grid_crops_coords,
grid_size=(grid_height, grid_width), grid_size=(grid_height, grid_width),
temporal_size=num_frames, temporal_size=num_frames,
device=device,
) )
else: else:
# CogVideoX 1.5 # CogVideoX 1.5
...@@ -517,10 +518,9 @@ class CogVideoXFunControlPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin): ...@@ -517,10 +518,9 @@ class CogVideoXFunControlPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
temporal_size=base_num_frames, temporal_size=base_num_frames,
grid_type="slice", grid_type="slice",
max_size=(base_size_height, base_size_width), max_size=(base_size_height, base_size_width),
device=device,
) )
freqs_cos = freqs_cos.to(device=device)
freqs_sin = freqs_sin.to(device=device)
return freqs_cos, freqs_sin return freqs_cos, freqs_sin
@property @property
......
...@@ -555,6 +555,7 @@ class CogVideoXImageToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin) ...@@ -555,6 +555,7 @@ class CogVideoXImageToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
crops_coords=grid_crops_coords, crops_coords=grid_crops_coords,
grid_size=(grid_height, grid_width), grid_size=(grid_height, grid_width),
temporal_size=num_frames, temporal_size=num_frames,
device=device,
) )
else: else:
# CogVideoX 1.5 # CogVideoX 1.5
...@@ -567,10 +568,9 @@ class CogVideoXImageToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin) ...@@ -567,10 +568,9 @@ class CogVideoXImageToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
temporal_size=base_num_frames, temporal_size=base_num_frames,
grid_type="slice", grid_type="slice",
max_size=(base_size_height, base_size_width), max_size=(base_size_height, base_size_width),
device=device,
) )
freqs_cos = freqs_cos.to(device=device)
freqs_sin = freqs_sin.to(device=device)
return freqs_cos, freqs_sin return freqs_cos, freqs_sin
@property @property
......
...@@ -529,6 +529,7 @@ class CogVideoXVideoToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin) ...@@ -529,6 +529,7 @@ class CogVideoXVideoToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
crops_coords=grid_crops_coords, crops_coords=grid_crops_coords,
grid_size=(grid_height, grid_width), grid_size=(grid_height, grid_width),
temporal_size=num_frames, temporal_size=num_frames,
device=device,
) )
else: else:
# CogVideoX 1.5 # CogVideoX 1.5
...@@ -541,10 +542,9 @@ class CogVideoXVideoToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin) ...@@ -541,10 +542,9 @@ class CogVideoXVideoToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin)
temporal_size=base_num_frames, temporal_size=base_num_frames,
grid_type="slice", grid_type="slice",
max_size=(base_size_height, base_size_width), max_size=(base_size_height, base_size_width),
device=device,
) )
freqs_cos = freqs_cos.to(device=device)
freqs_sin = freqs_sin.to(device=device)
return freqs_cos, freqs_sin return freqs_cos, freqs_sin
@property @property
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment