tile_image.py 3.01 KB
Newer Older
yangzhong's avatar
yangzhong committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import cv2
import torch
import numpy as np
from einops import rearrange

def block_image(image, block_size, overlap):
    """Split a channel-first image into equally sized, zero-padded tiles.

    Tiles are emitted in row-major order (top to bottom, then left to right)
    using a stride of ``block_size - int(block_size * overlap)`` pixels.
    Tiles that reach past the bottom or right edge of the image are padded
    with zeros so that every tile has the same shape.

    Args:
        image: Array laid out as ``(C, H, W)``.
        block_size: Side length, in pixels, of each square tile.
        overlap: Fraction of ``block_size`` shared by neighbouring tiles.
            It must leave a positive stride.

    Returns:
        A list of arrays of shape ``(C, block_size, block_size)`` with the
        same dtype as ``image``.

    Raises:
        ValueError: If the resulting stride is zero or negative.
    """
    # C H W -> H W C so that tiles can be cut with plain 2-D slicing.
    image = np.transpose(image, (1, 2, 0))
    height, width, channels = image.shape

    # Number of pixels shared by two neighbouring tiles.
    overlap_pixels = int(block_size * overlap)
    stride = block_size - overlap_pixels
    if stride <= 0:
        raise ValueError(
            "block_size and overlap must give a positive stride, got "
            f"block_size={block_size}, overlap={overlap}"
        )

    block_images = []
    # Walk the image row by row.
    for y in range(0, height, stride):
        y_end = min(y + block_size, height)
        for x in range(0, width, stride):
            x_end = min(x + block_size, width)
            # Keep every tile the same size: whatever lies outside the image
            # stays zero. The channel count follows the input image.
            block = np.zeros((block_size, block_size, channels), dtype=image.dtype)
            block[:y_end - y, :x_end - x] = image[y:y_end, x:x_end]
            # H W C -> C H W, matching the layout of the input.
            block_images.append(np.transpose(block, (2, 0, 1)))

    return block_images

def combine_blocks(blocks, image_shape, block_size, overlap):
    """Blend overlapping tiles back into a single image with Gaussian weights.

    Tiles are consumed in row-major order using a stride of
    ``block_size - int(block_size * overlap)``. Each tile is multiplied by a
    Gaussian mask from ``_gaussian_weights`` and accumulated; the accumulated
    image is then divided by the accumulated weights, so overlapping regions
    become a weighted average.

    Args:
        blocks: Sequence of tiles, each laid out as ``(C, H, W)``; indexed in
            the same order in which the tile grid is walked here.
        image_shape: ``(height, width, channels)`` of the output image. Only
            height and width are read.
        block_size: Side length, in pixels, of each square tile.
        overlap: Fraction of ``block_size`` shared by neighbouring tiles.
            It must leave a positive stride.

    Returns:
        A float32 tensor of shape ``(height, width, 3)``. It lives on the
        CUDA device when one is available and on the CPU otherwise.

    Raises:
        ValueError: If the resulting stride is zero or negative.
    """
    height, width, _ = image_shape
    overlap_pixels = int(block_size * overlap)
    stride = block_size - overlap_pixels
    if stride <= 0:
        raise ValueError(
            "block_size and overlap must give a positive stride, got "
            f"block_size={block_size}, overlap={overlap}"
        )

    # Prefer the GPU, but do not crash on CPU-only machines.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    reconstructed_image = torch.zeros((height, width, 3), dtype=torch.float32, device=device)
    weight_sum = torch.zeros_like(reconstructed_image)

    # Gaussian weight mask, built once: (block_size, block_size, 1). The
    # trailing axis broadcasts over channels.
    weights = _gaussian_weights(block_size, block_size, 1)[0, 0].to(device).unsqueeze(-1)

    idx = 0
    for y in range(0, height, stride):
        y_end = min(y + block_size, height)
        for x in range(0, width, stride):
            x_end = min(x + block_size, width)
            # detach() keeps the accumulation free of autograd history, as a
            # fresh torch.tensor(...) copy would.
            block = torch.as_tensor(blocks[idx], dtype=torch.float32, device=device).detach()
            block = block.permute(1, 2, 0)  # C H W -> H W C

            # Crop the mask to this tile's size, then to the part of the tile
            # that actually lies inside the image.
            block_height, block_width = block.shape[:2]
            weight = weights[:block_height, :block_width]
            weight = weight[:y_end - y, :x_end - x]

            reconstructed_image[y:y_end, x:x_end, :] += block[:y_end - y, :x_end - x] * weight
            weight_sum[y:y_end, x:x_end, :] += weight
            idx += 1

    # Pixels no tile contributed to keep their zero value instead of 0 / 0.
    weight_sum[weight_sum == 0] = 1.0
    reconstructed_image /= weight_sum

    return reconstructed_image

def _gaussian_weights(tile_width, tile_height, nbatches):
    """Generates a gaussian mask of weights for tile contributions"""
    var = 0.01
    midpoint_w = (tile_width - 1) / 2
    x_probs = [np.exp(-(x - midpoint_w) * (x - midpoint_w) / (tile_width * tile_width) / (2 * var)) / np.sqrt(2 * np.pi * var)
               for x in range(tile_width)]
    midpoint_h = (tile_height - 1) / 2
    y_probs = [np.exp(-(y - midpoint_h) * (y - midpoint_h) / (tile_height * tile_height) / (2 * var)) / np.sqrt(2 * np.pi * var)
               for y in range(tile_height)]

    weights = np.outer(y_probs, x_probs)
    return torch.tensor(weights, dtype=torch.float32).unsqueeze(0).unsqueeze(0).repeat(nbatches, 1, 1, 1)