deform_attention.py 4.23 KB
Newer Older
yanjl1's avatar
Initial  
yanjl1 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import hipdnn
import torch


def build_deform_attention_graph(
    hipdnn_handle,
    torch_tensor_value,
    torch_tensor_spatial_shapes,
    torch_tensor_level_start_index,
    torch_tensor_sampling_locations,
    torch_tensor_attention_weights,
    hipdnn_data_type,
):
    """Build a hipDNN graph containing a single deformable-attention forward op.

    Each torch input is mirrored as a graph tensor with ``graph.tensor_like``
    (presumably taking shape/layout/dtype from the torch tensor -- confirm
    against the hipdnn docs). The ``deform_attn_fprop`` node is then added,
    its result is flagged as a graph output, and the graph is built against
    ``hipdnn_handle``.

    Args:
        hipdnn_handle: handle passed to both ``hipdnn.pygraph`` and ``graph.build``.
        torch_tensor_value: example tensor for the ``value`` input.
        torch_tensor_spatial_shapes: example tensor for ``spatial_shapes``.
        torch_tensor_level_start_index: example tensor for ``level_start_index``.
        torch_tensor_sampling_locations: example tensor for ``sampling_locations``.
        torch_tensor_attention_weights: example tensor for ``attention_weights``.
        hipdnn_data_type: graph I/O data type; intermediate and compute types
            are fixed to ``hipdnn.data_type.FLOAT``.

    Returns:
        Tuple ``(graph, value, spatial_shapes, level_start_index,
        sampling_locations, attention_weights, y)``: the built graph followed
        by its hipdnn tensor handles, in that order.
    """
    float_type = hipdnn.data_type.FLOAT
    graph = hipdnn.pygraph(
        handle=hipdnn_handle,
        io_data_type=hipdnn_data_type,
        intermediate_data_type=float_type,
        compute_data_type=float_type,
        name="deform_attention",
    )

    # Register one graph tensor per torch input, in the op's argument order.
    (
        value,
        spatial_shapes,
        level_start_index,
        sampling_locations,
        attention_weights,
    ) = [
        graph.tensor_like(example)
        for example in (
            torch_tensor_value,
            torch_tensor_spatial_shapes,
            torch_tensor_level_start_index,
            torch_tensor_sampling_locations,
            torch_tensor_attention_weights,
        )
    ]

    y = graph.deform_attn_fprop(
        value=value,
        spatial_shapes=spatial_shapes,
        level_start_index=level_start_index,
        sampling_locations=sampling_locations,
        attention_weights=attention_weights,
        name="deform_attn_fprop",
    )
    # Flag the op result as a graph output before building.
    y.set_output(True)
    graph.build(hipdnn_handle)

    return (
        graph,
        value,
        spatial_shapes,
        level_start_index,
        sampling_locations,
        attention_weights,
        y,
    )


if __name__ == "__main__":
    # Input dimensions
    n = 2  # batch size
    n_heads = 2
    embed_dims_per_head = 32
    embed_dims = n_heads * embed_dims_per_head
    n_levels = 2
    n_points = 2
    n_queries = 32

    spatial_shapes_cpu = torch.randint(low=1, high=16, size=(n_levels, 2), dtype=torch.int64)
    # calculate n_keys based on spatial_shapes_cpu
    n_keys = spatial_shapes_cpu.prod(dim=1).sum()
    # calculate level_start_index based on spatial_shapes_cpu
    count_per_level = spatial_shapes_cpu.prod(dim=1)
    level_start_index_cpu = torch.zeros_like(count_per_level)
    level_start_index_cpu[1:] = torch.cumsum(count_per_level[:-1], dim=0)

    hipdnn_data_type = hipdnn.data_type.FLOAT
    torch_data_type = torch.float32

    torch_tensor_value = torch.rand(
        n, n_keys, n_heads, embed_dims_per_head, dtype=torch_data_type, device="cuda"
    )
    torch_tensor_spatial_shapes = spatial_shapes_cpu.to("cuda")
    torch_tensor_level_start_index = level_start_index_cpu.to("cuda")
    torch_tensor_sampling_locations = torch.rand(
        n, n_queries, n_heads, n_levels, n_points, 2, dtype=torch_data_type, device="cuda"
    )
    torch_tensor_attention_weights = torch.rand(
        n, n_queries, n_heads, n_levels, n_points, dtype=torch_data_type, device="cuda"
    )

    hipdnn_handle = hipdnn.create_handle()

    (
        graph,
        hipdnn_tensor_value,
        hipdnn_tensor_spatial_shapes,
        hipdnn_tensor_level_start_index,
        hipdnn_tensor_sampling_locations,
        hipdnn_tensor_attention_weights,
        hipdnn_tensor_y,
    ) = build_deform_attention_graph(
        hipdnn_handle,
        torch_tensor_value,
        torch_tensor_spatial_shapes,
        torch_tensor_level_start_index,
        torch_tensor_sampling_locations,
        torch_tensor_attention_weights,
        hipdnn_data_type,
    )

    torch_tensor_y = torch.empty(hipdnn_tensor_y.get_dim(), dtype=torch_data_type, device="cuda")
    variant_pack = {
        hipdnn_tensor_value: torch_tensor_value.data_ptr(),
        hipdnn_tensor_spatial_shapes: torch_tensor_spatial_shapes.data_ptr(),
        hipdnn_tensor_level_start_index: torch_tensor_level_start_index.data_ptr(),
        hipdnn_tensor_sampling_locations: torch_tensor_sampling_locations.data_ptr(),
        hipdnn_tensor_attention_weights: torch_tensor_attention_weights.data_ptr(),
        hipdnn_tensor_y: torch_tensor_y.data_ptr(),
    }
    workspace = torch.empty(graph.get_workspace_size(), dtype=torch.uint8, device="cuda")

    graph.exec(variant_pack=variant_pack, workspace=workspace.data_ptr())
    print("Deform attention graph execution complete.")