getitem_backward.py 2.85 KB
Newer Older
yanjl1's avatar
Initial  
yanjl1 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import hipdnn
import torch


def build_getitem_backward_graph(
    hipdnn_handle, torch_tensor_dy, hipdnn_data_type, dx_dim, index_dims, torch_tensor_indeices
):
    """Create and build a hipDNN graph holding a single ``getitem_backward`` node.

    Args:
        hipdnn_handle: hipDNN handle; used both to construct the graph and to build it.
        torch_tensor_dy: torch tensor from which the graph's ``dy`` input is derived
            (via ``graph.tensor_like``).
        hipdnn_data_type: hipDNN data type used for graph I/O. Intermediate and
            compute data types are fixed to ``hipdnn.data_type.FLOAT``.
        dx_dim: dimensions assigned to the ``dx`` output tensor.
        index_dims: passed as ``dims`` to ``getitem_backward``; its length also
            decides how many entries of ``torch_tensor_indeices`` are used.
        torch_tensor_indeices: sequence of torch index tensors; the first
            ``len(index_dims)`` entries become graph inputs, in order.

    Returns:
        Tuple ``(graph, dy, dx, error, indices)``. ``dy``, ``dx`` and ``error`` are
        graph tensors; ``indices`` is the list of graph index tensors.
    """
    pg = hipdnn.pygraph(
        handle=hipdnn_handle,
        io_data_type=hipdnn_data_type,
        intermediate_data_type=hipdnn.data_type.FLOAT,
        compute_data_type=hipdnn.data_type.FLOAT,
        name="getitem_backward_inference",
    )
    # dy is created before the index tensors, as in the original ordering.
    dy_node = pg.tensor_like(torch_tensor_dy)
    # Positional indexing (not zip) so a list shorter than index_dims still fails loudly.
    index_nodes = [pg.tensor_like(torch_tensor_indeices[k]) for k in range(len(index_dims))]

    dx_node, err_node = pg.getitem_backward(
        dy=dy_node,
        indices=index_nodes,
        dims=index_dims,
        offset=0,
        name="getitem_backward",
    )
    # Mark dx (with its explicit shape) and the error tensor as graph outputs.
    dx_node.set_output(True).set_dim(dx_dim)
    err_node.set_output(True)

    pg.build(hipdnn_handle)
    return (pg, dy_node, dx_node, err_node, index_nodes)


def main():
    """Run one getitem_backward graph on random CUDA data.

    Creates a random ``dy`` tensor and one random int32 index tensor per entry of
    ``index_dims``. It then builds the graph with ``build_getitem_backward_graph``,
    binds every tensor in a variant pack and executes the graph once.
    """
    dy_batch = 32
    dy_channel = 16
    dx_batch = 64
    dx_channel = 32
    height = 32
    width = 32

    dy_dim = [dy_batch, dy_channel, height, width]
    dx_dim = [dx_batch, dx_channel, height, width]
    # Dimensions of dx that are indexed; one index tensor is created per entry.
    index_dims = [1, 2]

    hipdnn_data_type = hipdnn.data_type.FLOAT
    torch_data_type = torch.float32

    torch_tensor_dy = torch.rand(dy_dim, dtype=torch_data_type, device="cuda")

    # Each index tensor has dy's shape and holds values in [0, dx_dim[d]) for its
    # indexed dimension d (randint's upper bound is exclusive), so all indices are in range.
    torch_tensor_indeices = [
        torch.randint(0, dx_dim[d], tuple(dy_dim), dtype=torch.int32, device="cuda")
        for d in index_dims
    ]

    hipdnn_handle = hipdnn.create_handle()

    (
        graph,
        hipdnn_tensor_dy,
        hipdnn_tensor_dx,
        hipdnn_tensor_error,
        hipdnn_tensor_indeices,
    ) = build_getitem_backward_graph(
        hipdnn_handle,
        torch_tensor_dy,
        hipdnn_data_type,
        dx_dim,
        index_dims,
        torch_tensor_indeices,
    )

    # Zero-initialised instead of torch.empty: harmless if the kernel overwrites dx,
    # and avoids reading garbage if it accumulates into it.
    # NOTE(review): confirm whether hipDNN clears dx / error itself.
    torch_tensor_dx = torch.zeros(hipdnn_tensor_dx.get_dim(), dtype=torch_data_type, device="cuda")
    # error tensor must be int32
    torch_tensor_error = torch.zeros(len(index_dims), dtype=torch.int32, device="cuda")

    variant_pack = {
        hipdnn_tensor_dy: torch_tensor_dy.data_ptr(),
        hipdnn_tensor_dx: torch_tensor_dx.data_ptr(),
    }
    # Bind every index tensor (previously only [0] and [1] were hard-coded, which broke
    # for any index_dims whose length is not 2).
    for graph_index, torch_index in zip(hipdnn_tensor_indeices, torch_tensor_indeices):
        variant_pack[graph_index] = torch_index.data_ptr()
    variant_pack[hipdnn_tensor_error] = torch_tensor_error.data_ptr()

    workspace = torch.empty(graph.get_workspace_size(), dtype=torch.uint8, device="cuda")
    graph.exec(variant_pack=variant_pack, workspace=workspace.data_ptr())
    # exec may only enqueue work; wait for the device so the message below is accurate.
    torch.cuda.synchronize()
    print("getitem_backward graph execution complete.")


if __name__ == "__main__":
    main()