import hipdnn
import torch


def build_kthvalue_graph(
    hipdnn_handle,
    torch_tensor_input,
    hipdnn_data_type,
    *,
    k=2,
    dim=1,
    keep_dim=False,
):
    """Build a hipDNN graph that selects the k-th value along one dimension.

    The graph input is templated from ``torch_tensor_input`` via
    ``graph.tensor_like``. The shapes of the value and index outputs are
    derived from that input's shape, so any input rank and batch size works.

    Args:
        hipdnn_handle: Handle from ``hipdnn.create_handle()``. It is used both
            to create and to build the graph.
        torch_tensor_input: Tensor whose shape/dtype template the graph input.
        hipdnn_data_type: I/O data type of the graph. Intermediate and compute
            types are fixed to ``hipdnn.data_type.FLOAT``.
        k: Which order statistic to select. NOTE(review): presumably 1-based
            like ``torch.kthvalue``; confirm against the hipDNN docs.
        dim: Dimension to reduce. Negative values count from the end.
        keep_dim: If True, the reduced dimension is kept with size 1.
            Otherwise it is removed.

    Returns:
        Tuple ``(graph, hipdnn_tensor_input, output, indices)``. The graph is
        already built, and ``output``/``indices`` are marked as outputs.

    Raises:
        ValueError: If ``dim`` is out of range for the input's rank.
    """
    input_shape = list(torch_tensor_input.shape)
    ndim = len(input_shape)
    if not -ndim <= dim < ndim:
        raise ValueError(
            f"dim={dim} is out of range for a tensor with {ndim} dimension(s)"
        )
    reduced_dim = dim % ndim

    # The reduction removes (or collapses to 1) the selected dimension. The
    # output descriptors must match that, not a hard-coded size.
    output_shape = list(input_shape)
    if keep_dim:
        output_shape[reduced_dim] = 1
    else:
        del output_shape[reduced_dim]

    graph = hipdnn.pygraph(
        handle=hipdnn_handle,
        io_data_type=hipdnn_data_type,
        intermediate_data_type=hipdnn.data_type.FLOAT,
        compute_data_type=hipdnn.data_type.FLOAT,
        name="kthvalue_inference",
    )
    hipdnn_tensor_input = graph.tensor_like(torch_tensor_input)
    output, indices = graph.kthvalue(
        input=hipdnn_tensor_input, k=k, dim=dim, keep_dim=keep_dim, name="kthvalue"
    )
    output.set_output(True).set_dim(list(output_shape))
    indices.set_output(True).set_dim(list(output_shape))
    graph.build(hipdnn_handle)
    return (graph, hipdnn_tensor_input, output, indices)


def main():
    """Run the example: kthvalue (k=2) along dim 1 of a random (4, 10) tensor."""
    batch, dim = 4, 10
    hipdnn_data_type = hipdnn.data_type.FLOAT
    torch_data_type = torch.float32
    torch_tensor_input = torch.rand(batch, dim, dtype=torch_data_type, device="cuda")

    hipdnn_handle = hipdnn.create_handle()
    graph, hipdnn_tensor_input, hipdnn_tensor_output, hipdnn_tensor_indices = (
        build_kthvalue_graph(hipdnn_handle, torch_tensor_input, hipdnn_data_type)
    )

    # Reducing dim 1 of a (batch, dim) input with keep_dim=False leaves (batch,).
    torch_tensor_output = torch.empty(batch, dtype=torch_data_type, device="cuda")
    # NOTE(review): int64 indices assumed; confirm the index dtype hipDNN writes.
    torch_tensor_indices = torch.empty(batch, dtype=torch.int64, device="cuda")

    # Map each graph tensor to the device pointer backing it.
    variant_pack = {
        hipdnn_tensor_input: torch_tensor_input.data_ptr(),
        hipdnn_tensor_output: torch_tensor_output.data_ptr(),
        hipdnn_tensor_indices: torch_tensor_indices.data_ptr(),
    }
    workspace = torch.empty(
        graph.get_workspace_size(), dtype=torch.uint8, device="cuda"
    )
    graph.exec(variant_pack=variant_pack, workspace=workspace.data_ptr())
    print("kthvalue graph execution complete.")


if __name__ == "__main__":
    main()