import hipdnn
import torch


def build_getitem_backward_graph(
    hipdnn_handle, torch_tensor_dy, hipdnn_data_type, dx_dim, index_dims, torch_tensor_indeices
):
    """Build a hipDNN graph containing a single ``getitem_backward`` node.

    The graph takes ``dy`` plus one index tensor per entry of ``index_dims``.
    It produces ``dx`` (its dimensions set to ``dx_dim``) and an ``error`` tensor.
    Both outputs are marked as graph outputs, and the graph is built against
    ``hipdnn_handle`` before returning.

    Args:
        hipdnn_handle: Handle from ``hipdnn.create_handle()``. It is used both to
            construct and to build the graph.
        torch_tensor_dy: Torch tensor used as the template for the ``dy`` graph
            tensor (via ``graph.tensor_like``).
        hipdnn_data_type: hipDNN data type used for graph I/O. Intermediate and
            compute types are fixed to ``hipdnn.data_type.FLOAT``.
        dx_dim: Dimensions assigned to the ``dx`` output tensor.
        index_dims: Dimensions passed as ``dims`` to ``getitem_backward``. One
            index tensor is expected per entry.
        torch_tensor_indeices: Torch index tensors, one per entry of
            ``index_dims``, in the same order. (The misspelled name is kept so
            existing keyword callers keep working.)

    Returns:
        Tuple ``(graph, dy, dx, error, indices)``. ``dy``, ``dx`` and ``error``
        are hipDNN graph tensors. ``indices`` is the list of hipDNN index tensors,
        ordered like ``torch_tensor_indeices``.

    Raises:
        ValueError: If the number of index tensors differs from
            ``len(index_dims)``.
    """
    if len(torch_tensor_indeices) != len(index_dims):
        raise ValueError(
            f"expected {len(index_dims)} index tensors (one per entry of "
            f"index_dims), got {len(torch_tensor_indeices)}"
        )

    graph = hipdnn.pygraph(
        handle=hipdnn_handle,
        io_data_type=hipdnn_data_type,
        intermediate_data_type=hipdnn.data_type.FLOAT,
        compute_data_type=hipdnn.data_type.FLOAT,
        name="getitem_backward_inference",
    )

    hipdnn_tensor_dy = graph.tensor_like(torch_tensor_dy)
    hipdnn_tensor_indices = [graph.tensor_like(t) for t in torch_tensor_indeices]

    dx, error = graph.getitem_backward(
        dy=hipdnn_tensor_dy,
        indices=hipdnn_tensor_indices,
        dims=index_dims,
        offset=0,
        name="getitem_backward",
    )
    dx.set_output(True).set_dim(dx_dim)
    error.set_output(True)

    graph.build(hipdnn_handle)
    return (graph, hipdnn_tensor_dy, dx, error, hipdnn_tensor_indices)


def main():
    """Run getitem_backward once on random data and report the error tensor."""
    dy_batch, dy_channel = 32, 16
    dx_batch, dx_channel = 64, 32
    height = width = 32

    dy_dim = [dy_batch, dy_channel, height, width]
    dx_dim = [dx_batch, dx_channel, height, width]
    # Dimensions of dx addressed by the index tensors (passed as `dims`).
    index_dims = [1, 2]

    hipdnn_data_type = hipdnn.data_type.FLOAT
    torch_data_type = torch.float32

    torch_tensor_dy = torch.rand(dy_dim, dtype=torch_data_type, device="cuda")
    # One int32 index tensor per indexed dim, shaped like dy. Values are drawn
    # from [0, dx_dim[dim]) (randint's upper bound is exclusive), so every
    # generated index lies within that dim of dx.
    torch_tensor_indices = [
        torch.randint(0, dx_dim[dim], tuple(dy_dim), dtype=torch.int32, device="cuda")
        for dim in index_dims
    ]

    hipdnn_handle = hipdnn.create_handle()
    (
        graph,
        hipdnn_tensor_dy,
        hipdnn_tensor_dx,
        hipdnn_tensor_error,
        hipdnn_tensor_indices,
    ) = build_getitem_backward_graph(
        hipdnn_handle,
        torch_tensor_dy,
        hipdnn_data_type,
        dx_dim,
        index_dims,
        torch_tensor_indices,
    )

    # Zero-initialised instead of torch.empty, so the buffers hold defined
    # values even if the kernel does not write every element.
    # NOTE(review): confirm whether hipDNN's getitem_backward fully overwrites
    # dx/error; if it does, this is merely defensive.
    torch_tensor_dx = torch.zeros(
        hipdnn_tensor_dx.get_dim(), dtype=torch_data_type, device="cuda"
    )
    # error tensor must be int32
    torch_tensor_error = torch.zeros(len(index_dims), dtype=torch.int32, device="cuda")

    # Bind every index tensor (not a hard-coded first two), so the script
    # keeps working if index_dims changes length.
    variant_pack = {
        hipdnn_tensor_dy: torch_tensor_dy.data_ptr(),
        hipdnn_tensor_dx: torch_tensor_dx.data_ptr(),
        **{
            h_idx: t_idx.data_ptr()
            for h_idx, t_idx in zip(hipdnn_tensor_indices, torch_tensor_indices)
        },
        hipdnn_tensor_error: torch_tensor_error.data_ptr(),
    }

    workspace = torch.empty(graph.get_workspace_size(), dtype=torch.uint8, device="cuda")
    graph.exec(variant_pack=variant_pack, workspace=workspace.data_ptr())
    # The launch may be asynchronous with respect to the host. Wait for the
    # device before claiming completion or reading results back.
    torch.cuda.synchronize()

    print("getitem_backward graph execution complete.")
    # NOTE(review): assumes nonzero entries flag index errors for the
    # corresponding dim -- confirm against the hipDNN getitem_backward docs.
    print(f"error tensor: {torch_tensor_error.tolist()}")


if __name__ == "__main__":
    main()