import hipdnn
import torch


def build_ctc_loss_graph(
    hipdnn_handle,
    torch_tensor_probs,
    hipdnn_data_type,
    *,
    labels=None,
    label_lengths=None,
    input_lengths=None,
):
    """Build a hipdnn graph containing a single ``ctc_loss`` node.

    The graph uses ``hipdnn_data_type`` for I/O tensors and FLOAT for
    intermediate and compute types. Both outputs of the ``ctc_loss`` node
    (losses and gradients) are marked as graph outputs before the graph is
    built.

    Args:
        hipdnn_handle: Handle passed to ``hipdnn.pygraph`` and ``graph.build``.
        torch_tensor_probs: Torch tensor used as the template for the graph's
            probs tensor (via ``graph.tensor_like``).
        hipdnn_data_type: hipdnn data type used as the graph's I/O data type.
        labels: Label ids passed to ``ctc_loss``. Defaults to
            ``[1, 2, 3, 4, 2, 3, 2]``.
        label_lengths: Per-sample label lengths. Defaults to ``[1, 2, 1, 3]``.
        input_lengths: Per-sample input lengths. Defaults to
            ``[4, 100, 100, 200]``.

    Returns:
        A tuple ``(graph, probs_tensor, losses_tensor, gradients_tensor)`` of
        the built graph and its hipdnn tensor handles.

    NOTE(review): the defaults satisfy ``sum(label_lengths) == len(labels)``,
    so ``labels`` is presumably the concatenation of every sample's label
    sequence -- confirm against the hipdnn documentation before passing
    custom values.
    """
    # Resolve defaults here rather than in the signature so no mutable list is
    # shared between calls.
    if labels is None:
        labels = [1, 2, 3, 4, 2, 3, 2]
    if label_lengths is None:
        label_lengths = [1, 2, 1, 3]
    if input_lengths is None:
        input_lengths = [4, 100, 100, 200]

    graph = hipdnn.pygraph(
        handle=hipdnn_handle,
        io_data_type=hipdnn_data_type,
        intermediate_data_type=hipdnn.data_type.FLOAT,
        compute_data_type=hipdnn.data_type.FLOAT,
        name="ctc_loss_inference",
    )

    hipdnn_tensor_probs = graph.tensor_like(torch_tensor_probs)

    losses, gradients = graph.ctc_loss(
        probs=hipdnn_tensor_probs,
        blank_label_id=0,
        apply_softmax=False,
        algo=0,
        labels=labels,
        label_lengths=label_lengths,
        input_lengths=input_lengths,
        name="ctc_loss",
    )

    # Both results must be flagged as outputs so they can be bound to device
    # buffers in the variant pack at execution time.
    losses.set_output(True)
    gradients.set_output(True)

    graph.build(hipdnn_handle)

    return (graph, hipdnn_tensor_probs, losses, gradients)


def main():
    """Build the CTC-loss graph and execute it once on random input."""
    batch, max_time, num_classes = 4, 500, 5

    hipdnn_data_type = hipdnn.data_type.FLOAT
    torch_data_type = torch.float32

    torch_tensor_probs = torch.rand(
        max_time, batch, num_classes, dtype=torch_data_type, device="cuda"
    )

    hipdnn_handle = hipdnn.create_handle()

    graph, hipdnn_tensor_probs, hipdnn_tensor_losses, hipdnn_tensor_gradients = (
        build_ctc_loss_graph(hipdnn_handle, torch_tensor_probs, hipdnn_data_type)
    )

    torch_tensor_losses = torch.empty(batch, dtype=torch_data_type, device="cuda")
    # Allocate the gradient buffer with the same shape as probs
    # (max_time, batch, num_classes). The previous (batch, max_time,
    # num_classes) allocation had the same byte size but mislabelled the axes.
    # NOTE(review): assumes gradients share the probs layout, as in the cuDNN
    # CTC-loss contract -- confirm for hipdnn.
    torch_tensor_gradients = torch.empty_like(torch_tensor_probs)

    variant_pack = {
        hipdnn_tensor_probs: torch_tensor_probs.data_ptr(),
        hipdnn_tensor_losses: torch_tensor_losses.data_ptr(),
        hipdnn_tensor_gradients: torch_tensor_gradients.data_ptr(),
    }

    workspace = torch.empty(
        graph.get_workspace_size(), dtype=torch.uint8, device="cuda"
    )

    graph.exec(variant_pack=variant_pack, workspace=workspace.data_ptr())

    # Device work is presumably launched asynchronously; wait for it so the
    # completion message below is accurate and any device error surfaces here.
    torch.cuda.synchronize()

    print("ctc_loss graph execution complete.")


if __name__ == "__main__":
    main()