Unverified Commit 0ac52d6f authored by hlky's avatar hlky Committed by GitHub
Browse files

Use `torch` in `get_2d_rotary_pos_embed` (#10155)

* Use `torch` in `get_2d_rotary_pos_embed`

* Add deprecation
parent ba6fd6eb
...@@ -1008,6 +1008,8 @@ class HunyuanDiTDifferentialImg2ImgPipeline(DiffusionPipeline): ...@@ -1008,6 +1008,8 @@ class HunyuanDiTDifferentialImg2ImgPipeline(DiffusionPipeline):
self.transformer.inner_dim // self.transformer.num_heads, self.transformer.inner_dim // self.transformer.num_heads,
grid_crops_coords, grid_crops_coords,
(grid_height, grid_width), (grid_height, grid_width),
device=device,
output_type="pt",
) )
style = torch.tensor([0], device=device) style = torch.tensor([0], device=device)
......
...@@ -957,7 +957,57 @@ def get_3d_rotary_pos_embed_allegro( ...@@ -957,7 +957,57 @@ def get_3d_rotary_pos_embed_allegro(
return freqs_t, freqs_h, freqs_w, grid_t, grid_h, grid_w return freqs_t, freqs_h, freqs_w, grid_t, grid_h, grid_w
def get_2d_rotary_pos_embed(
    embed_dim, crops_coords, grid_size, use_real=True, device: Optional[torch.device] = None, output_type: str = "np"
):
    """
    RoPE for image tokens with 2d structure.

    Args:
        embed_dim (`int`):
            The embedding dimension size.
        crops_coords (`Tuple[int]`):
            The top-left and bottom-right coordinates of the crop.
        grid_size (`Tuple[int]`):
            The grid size of the positional embedding.
        use_real (`bool`):
            If True, return real part and imaginary part separately. Otherwise, return complex numbers.
        device: (`torch.device`, **optional**):
            The device used to create tensors.

    Returns:
        `torch.Tensor`: positional embedding with shape `( grid_size * grid_size, embed_dim/2)`.
    """
    if output_type == "np":
        # Legacy numpy path, kept for backward compatibility until removal.
        deprecation_message = (
            "`get_2d_rotary_pos_embed` uses `torch` and supports `device`."
            " `from_numpy` is no longer required."
            " Pass `output_type='pt'` to use the new version now."
        )
        deprecate("output_type=='np'", "0.33.0", deprecation_message, standard_warn=False)
        return _get_2d_rotary_pos_embed_np(
            embed_dim=embed_dim,
            crops_coords=crops_coords,
            grid_size=grid_size,
            use_real=use_real,
        )
    start, stop = crops_coords
    # Reproduce `np.linspace(start, stop, steps, endpoint=False)`: torch.linspace always
    # includes its endpoint, so shrink the span by (steps-1)/steps *relative to start*.
    # (Scaling only `stop` by (steps-1)/steps is correct solely when start == 0.)
    grid_h = torch.linspace(
        start[0],
        start[0] + (stop[0] - start[0]) * (grid_size[0] - 1) / grid_size[0],
        grid_size[0],
        device=device,
        dtype=torch.float32,
    )
    grid_w = torch.linspace(
        start[1],
        start[1] + (stop[1] - start[1]) * (grid_size[1] - 1) / grid_size[1],
        grid_size[1],
        device=device,
        dtype=torch.float32,
    )
    grid = torch.meshgrid(grid_w, grid_h, indexing="xy")
    grid = torch.stack(grid, dim=0)  # [2, W, H]
    grid = grid.reshape([2, 1, *grid.shape[1:]])
    pos_embed = get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=use_real)
    return pos_embed
def _get_2d_rotary_pos_embed_np(embed_dim, crops_coords, grid_size, use_real=True):
""" """
RoPE for image tokens with 2d structure. RoPE for image tokens with 2d structure.
......
...@@ -925,7 +925,11 @@ class HunyuanDiTControlNetPipeline(DiffusionPipeline): ...@@ -925,7 +925,11 @@ class HunyuanDiTControlNetPipeline(DiffusionPipeline):
base_size = 512 // 8 // self.transformer.config.patch_size base_size = 512 // 8 // self.transformer.config.patch_size
grid_crops_coords = get_resize_crop_region_for_grid((grid_height, grid_width), base_size) grid_crops_coords = get_resize_crop_region_for_grid((grid_height, grid_width), base_size)
image_rotary_emb = get_2d_rotary_pos_embed( image_rotary_emb = get_2d_rotary_pos_embed(
self.transformer.inner_dim // self.transformer.num_heads, grid_crops_coords, (grid_height, grid_width) self.transformer.inner_dim // self.transformer.num_heads,
grid_crops_coords,
(grid_height, grid_width),
device=device,
output_type="pt",
) )
style = torch.tensor([0], device=device) style = torch.tensor([0], device=device)
......
...@@ -798,7 +798,11 @@ class HunyuanDiTPipeline(DiffusionPipeline): ...@@ -798,7 +798,11 @@ class HunyuanDiTPipeline(DiffusionPipeline):
base_size = 512 // 8 // self.transformer.config.patch_size base_size = 512 // 8 // self.transformer.config.patch_size
grid_crops_coords = get_resize_crop_region_for_grid((grid_height, grid_width), base_size) grid_crops_coords = get_resize_crop_region_for_grid((grid_height, grid_width), base_size)
image_rotary_emb = get_2d_rotary_pos_embed( image_rotary_emb = get_2d_rotary_pos_embed(
self.transformer.inner_dim // self.transformer.num_heads, grid_crops_coords, (grid_height, grid_width) self.transformer.inner_dim // self.transformer.num_heads,
grid_crops_coords,
(grid_height, grid_width),
device=device,
output_type="pt",
) )
style = torch.tensor([0], device=device) style = torch.tensor([0], device=device)
......
...@@ -818,7 +818,11 @@ class HunyuanDiTPAGPipeline(DiffusionPipeline, PAGMixin): ...@@ -818,7 +818,11 @@ class HunyuanDiTPAGPipeline(DiffusionPipeline, PAGMixin):
base_size = 512 // 8 // self.transformer.config.patch_size base_size = 512 // 8 // self.transformer.config.patch_size
grid_crops_coords = get_resize_crop_region_for_grid((grid_height, grid_width), base_size) grid_crops_coords = get_resize_crop_region_for_grid((grid_height, grid_width), base_size)
image_rotary_emb = get_2d_rotary_pos_embed( image_rotary_emb = get_2d_rotary_pos_embed(
self.transformer.inner_dim // self.transformer.num_heads, grid_crops_coords, (grid_height, grid_width) self.transformer.inner_dim // self.transformer.num_heads,
grid_crops_coords,
(grid_height, grid_width),
device=device,
output_type="pt",
) )
style = torch.tensor([0], device=device) style = torch.tensor([0], device=device)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment