Unverified commit 4c4b323c authored by hlky, committed by GitHub

Use `torch` in `get_3d_rotary_pos_embed`/`_allegro` (#10161)

Replace the numpy-based grid construction in `get_3d_rotary_pos_embed` and `get_3d_rotary_pos_embed_allegro` with `torch` equivalents and add a `device` argument, so the rotary grids are created directly on the target device and callers no longer need to move the returned cos/sin tensors with `.to(device=...)`.
parent 22d3a826
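A note on the one subtlety in the port: `torch.linspace` has no `endpoint=False` option, so the diff rescales the stop value instead. A minimal sketch of the equivalence (the values are illustrative; the rescaling matches numpy exactly when the start coordinate is 0, as in the crop coordinates these callers typically pass):

import numpy as np
import torch

# np.linspace(0, stop, n, endpoint=False) excludes `stop`; torch.linspace always
# includes its endpoint, so the commit shrinks the stop value to stop * (n - 1) / n.
n, stop = 8, 16.0
grid_np = np.linspace(0, stop, n, endpoint=False, dtype=np.float32)
grid_pt = torch.linspace(0, stop * (n - 1) / n, n, dtype=torch.float32)
assert np.allclose(grid_np, grid_pt.numpy())  # both are [0, 2, 4, ..., 14]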
@@ -872,10 +872,9 @@ def prepare_rotary_positional_embeddings(
         crops_coords=grid_crops_coords,
         grid_size=(grid_height, grid_width),
         temporal_size=num_frames,
+        device=device,
     )
 
-    freqs_cos = freqs_cos.to(device=device)
-    freqs_sin = freqs_sin.to(device=device)
     return freqs_cos, freqs_sin
@@ -894,10 +894,9 @@ def prepare_rotary_positional_embeddings(
         crops_coords=grid_crops_coords,
         grid_size=(grid_height, grid_width),
         temporal_size=num_frames,
+        device=device,
     )
 
-    freqs_cos = freqs_cos.to(device=device)
-    freqs_sin = freqs_sin.to(device=device)
     return freqs_cos, freqs_sin
@@ -594,6 +594,7 @@ def get_3d_rotary_pos_embed(
     use_real: bool = True,
     grid_type: str = "linspace",
     max_size: Optional[Tuple[int, int]] = None,
+    device: Optional[torch.device] = None,
 ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
     """
     RoPE for video tokens with 3D structure.
@@ -621,16 +622,22 @@ def get_3d_rotary_pos_embed(
     if grid_type == "linspace":
         start, stop = crops_coords
         grid_size_h, grid_size_w = grid_size
-        grid_h = np.linspace(start[0], stop[0], grid_size_h, endpoint=False, dtype=np.float32)
-        grid_w = np.linspace(start[1], stop[1], grid_size_w, endpoint=False, dtype=np.float32)
-        grid_t = np.arange(temporal_size, dtype=np.float32)
-        grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32)
+        grid_h = torch.linspace(
+            start[0], stop[0] * (grid_size_h - 1) / grid_size_h, grid_size_h, device=device, dtype=torch.float32
+        )
+        grid_w = torch.linspace(
+            start[1], stop[1] * (grid_size_w - 1) / grid_size_w, grid_size_w, device=device, dtype=torch.float32
+        )
+        grid_t = torch.arange(temporal_size, device=device, dtype=torch.float32)
+        grid_t = torch.linspace(
+            0, temporal_size * (temporal_size - 1) / temporal_size, temporal_size, device=device, dtype=torch.float32
+        )
     elif grid_type == "slice":
         max_h, max_w = max_size
         grid_size_h, grid_size_w = grid_size
-        grid_h = np.arange(max_h, dtype=np.float32)
-        grid_w = np.arange(max_w, dtype=np.float32)
-        grid_t = np.arange(temporal_size, dtype=np.float32)
+        grid_h = torch.arange(max_h, device=device, dtype=torch.float32)
+        grid_w = torch.arange(max_w, device=device, dtype=torch.float32)
+        grid_t = torch.arange(temporal_size, device=device, dtype=torch.float32)
     else:
         raise ValueError("Invalid value passed for `grid_type`.")
@@ -640,10 +647,10 @@ def get_3d_rotary_pos_embed(
     dim_w = embed_dim // 8 * 3
 
     # Temporal frequencies
-    freqs_t = get_1d_rotary_pos_embed(dim_t, grid_t, use_real=True)
+    freqs_t = get_1d_rotary_pos_embed(dim_t, grid_t, theta=theta, use_real=True)
     # Spatial frequencies for height and width
-    freqs_h = get_1d_rotary_pos_embed(dim_h, grid_h, use_real=True)
-    freqs_w = get_1d_rotary_pos_embed(dim_w, grid_w, use_real=True)
+    freqs_h = get_1d_rotary_pos_embed(dim_h, grid_h, theta=theta, use_real=True)
+    freqs_w = get_1d_rotary_pos_embed(dim_w, grid_w, theta=theta, use_real=True)
 
     # Broadcast and concatenate temporal and spatial frequencies (height and width) into a 3d tensor
     def combine_time_height_width(freqs_t, freqs_h, freqs_w):
@@ -686,14 +693,21 @@ def get_3d_rotary_pos_embed_allegro(
     temporal_size,
     interpolation_scale: Tuple[float, float, float] = (1.0, 1.0, 1.0),
     theta: int = 10000,
+    device: Optional[torch.device] = None,
 ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
     # TODO(aryan): docs
     start, stop = crops_coords
     grid_size_h, grid_size_w = grid_size
     interpolation_scale_t, interpolation_scale_h, interpolation_scale_w = interpolation_scale
-    grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32)
-    grid_h = np.linspace(start[0], stop[0], grid_size_h, endpoint=False, dtype=np.float32)
-    grid_w = np.linspace(start[1], stop[1], grid_size_w, endpoint=False, dtype=np.float32)
+    grid_t = torch.linspace(
+        0, temporal_size * (temporal_size - 1) / temporal_size, temporal_size, device=device, dtype=torch.float32
+    )
+    grid_h = torch.linspace(
+        start[0], stop[0] * (grid_size_h - 1) / grid_size_h, grid_size_h, device=device, dtype=torch.float32
+    )
+    grid_w = torch.linspace(
+        start[1], stop[1] * (grid_size_w - 1) / grid_size_w, grid_size_w, device=device, dtype=torch.float32
+    )
 
     # Compute dimensions for each axis
     dim_t = embed_dim // 3
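For orientation, a hedged usage sketch of the updated function (the argument values are hypothetical, not taken from the diff): with `device` set, the grids and the returned cos/sin tables land on the accelerator without a separate transfer.

import torch
from diffusers.models.embeddings import get_3d_rotary_pos_embed

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
    embed_dim=64,                     # split internally: t -> //4, h and w -> //8 * 3
    crops_coords=((0, 0), (30, 45)),  # (top-left, bottom-right) in grid units
    grid_size=(30, 45),
    temporal_size=13,
    device=device,                    # new argument introduced by this commit
)
assert freqs_cos.device.type == device.type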
@@ -623,20 +623,17 @@ class AllegroPipeline(DiffusionPipeline):
                 self.transformer.config.interpolation_scale_h,
                 self.transformer.config.interpolation_scale_w,
             ),
+            device=device,
         )
 
-        grid_t = torch.from_numpy(grid_t).to(device=device, dtype=torch.long)
-        grid_h = torch.from_numpy(grid_h).to(device=device, dtype=torch.long)
-        grid_w = torch.from_numpy(grid_w).to(device=device, dtype=torch.long)
+        grid_t = grid_t.to(dtype=torch.long)
+        grid_h = grid_h.to(dtype=torch.long)
+        grid_w = grid_w.to(dtype=torch.long)
 
         pos = torch.cartesian_prod(grid_t, grid_h, grid_w)
         pos = pos.reshape(-1, 3).transpose(0, 1).reshape(3, 1, -1).contiguous()
         grid_t, grid_h, grid_w = pos
 
-        freqs_t = (freqs_t[0].to(device=device), freqs_t[1].to(device=device))
-        freqs_h = (freqs_h[0].to(device=device), freqs_h[1].to(device=device))
-        freqs_w = (freqs_w[0].to(device=device), freqs_w[1].to(device=device))
-
         return (freqs_t, freqs_h, freqs_w), (grid_t, grid_h, grid_w)
 
     @property
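As an aside on the context lines above, a minimal sketch of what the `torch.cartesian_prod` flattening does (toy sizes, not the pipeline's real grid dimensions):

import torch

# Enumerate every (t, h, w) combination, then produce one row of flat indices per axis.
grid_t = torch.arange(2, dtype=torch.long)
grid_h = torch.arange(3, dtype=torch.long)
grid_w = torch.arange(4, dtype=torch.long)
pos = torch.cartesian_prod(grid_t, grid_h, grid_w)          # shape (2 * 3 * 4, 3)
pos = pos.reshape(-1, 3).transpose(0, 1).reshape(3, 1, -1)  # shape (3, 1, 24)
grid_t, grid_h, grid_w = pos                                # each has shape (1, 24)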
@@ -459,6 +459,7 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
                 crops_coords=grid_crops_coords,
                 grid_size=(grid_height, grid_width),
                 temporal_size=num_frames,
+                device=device,
             )
         else:
             # CogVideoX 1.5
@@ -471,10 +472,9 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
                 temporal_size=base_num_frames,
                 grid_type="slice",
                 max_size=(base_size_height, base_size_width),
+                device=device,
             )
 
-        freqs_cos = freqs_cos.to(device=device)
-        freqs_sin = freqs_sin.to(device=device)
         return freqs_cos, freqs_sin
 
     @property
@@ -505,6 +505,7 @@ class CogVideoXFunControlPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
                 crops_coords=grid_crops_coords,
                 grid_size=(grid_height, grid_width),
                 temporal_size=num_frames,
+                device=device,
             )
         else:
             # CogVideoX 1.5
@@ -517,10 +518,9 @@ class CogVideoXFunControlPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
                 temporal_size=base_num_frames,
                 grid_type="slice",
                 max_size=(base_size_height, base_size_width),
+                device=device,
             )
 
-        freqs_cos = freqs_cos.to(device=device)
-        freqs_sin = freqs_sin.to(device=device)
         return freqs_cos, freqs_sin
 
     @property
@@ -555,6 +555,7 @@ class CogVideoXImageToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
                 crops_coords=grid_crops_coords,
                 grid_size=(grid_height, grid_width),
                 temporal_size=num_frames,
+                device=device,
             )
         else:
             # CogVideoX 1.5
@@ -567,10 +568,9 @@ class CogVideoXImageToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
                 temporal_size=base_num_frames,
                 grid_type="slice",
                 max_size=(base_size_height, base_size_width),
+                device=device,
             )
 
-        freqs_cos = freqs_cos.to(device=device)
-        freqs_sin = freqs_sin.to(device=device)
         return freqs_cos, freqs_sin
 
     @property
@@ -529,6 +529,7 @@ class CogVideoXVideoToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
                 crops_coords=grid_crops_coords,
                 grid_size=(grid_height, grid_width),
                 temporal_size=num_frames,
+                device=device,
             )
         else:
             # CogVideoX 1.5
@@ -541,10 +542,9 @@ class CogVideoXVideoToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
                 temporal_size=base_num_frames,
                 grid_type="slice",
                 max_size=(base_size_height, base_size_width),
+                device=device,
             )
 
-        freqs_cos = freqs_cos.to(device=device)
-        freqs_sin = freqs_sin.to(device=device)
         return freqs_cos, freqs_sin
 
     @property