import hipdnn
import torch


def build_deform_attention_bwd_graph(
    hipdnn_handle,
    torch_tensor_value: torch.Tensor,
    torch_tensor_spatial_shapes: torch.Tensor,
    torch_tensor_level_start_index: torch.Tensor,
    torch_tensor_sampling_locations: torch.Tensor,
    torch_tensor_attention_weights: torch.Tensor,
    torch_tensor_grad_output: torch.Tensor,
    hipdnn_data_type,
) -> tuple:
    """Build a hipdnn graph holding a single ``deform_attn_dgrad`` operation.

    Each torch tensor is mirrored into the graph with ``graph.tensor_like``,
    the three gradient tensors produced by the op are marked as graph outputs,
    and the graph is built before being returned.

    Args:
        hipdnn_handle: Handle passed to ``hipdnn.pygraph`` and ``graph.build``.
        torch_tensor_value: Tensor mirrored as the op's ``value`` input.
        torch_tensor_spatial_shapes: Tensor mirrored as ``spatial_shapes``.
        torch_tensor_level_start_index: Tensor mirrored as ``level_start_index``.
        torch_tensor_sampling_locations: Tensor mirrored as ``sampling_locations``.
        torch_tensor_attention_weights: Tensor mirrored as ``attention_weights``.
        torch_tensor_grad_output: Tensor mirrored as ``grad_output``.
        hipdnn_data_type: Used as the graph's ``io_data_type``; intermediate
            and compute data types are fixed to ``hipdnn.data_type.FLOAT``.

    Returns:
        A 10-tuple, in this order: the built graph, the six input hipdnn
        tensors (value, spatial_shapes, level_start_index, sampling_locations,
        attention_weights, grad_output), then the three output hipdnn tensors
        (grad_value, grad_sampling_loc, grad_attn_weight).
    """
    # Create graph
    graph = hipdnn.pygraph(
        handle=hipdnn_handle,
        io_data_type=hipdnn_data_type,
        intermediate_data_type=hipdnn.data_type.FLOAT,
        compute_data_type=hipdnn.data_type.FLOAT,
        name="deform_attention_bwd",
    )

    # Create hipdnn tensors mirroring the torch inputs
    hipdnn_tensor_value = graph.tensor_like(torch_tensor_value)
    hipdnn_tensor_spatial_shapes = graph.tensor_like(torch_tensor_spatial_shapes)
    hipdnn_tensor_level_start_index = graph.tensor_like(torch_tensor_level_start_index)
    hipdnn_tensor_sampling_locations = graph.tensor_like(torch_tensor_sampling_locations)
    hipdnn_tensor_attention_weights = graph.tensor_like(torch_tensor_attention_weights)
    hipdnn_tensor_grad_output = graph.tensor_like(torch_tensor_grad_output)

    # Create deform attn op
    hipdnn_tensor_grad_value, hipdnn_tensor_grad_sampling_loc, hipdnn_tensor_grad_attn_weight = (
        graph.deform_attn_dgrad(
            value=hipdnn_tensor_value,
            spatial_shapes=hipdnn_tensor_spatial_shapes,
            level_start_index=hipdnn_tensor_level_start_index,
            sampling_locations=hipdnn_tensor_sampling_locations,
            attention_weights=hipdnn_tensor_attention_weights,
            grad_output=hipdnn_tensor_grad_output,
            name="deform_attn_dgrad",
        )
    )

    # Flag all three gradients as graph outputs before building.
    hipdnn_tensor_grad_value.set_output(True)
    hipdnn_tensor_grad_sampling_loc.set_output(True)
    hipdnn_tensor_grad_attn_weight.set_output(True)

    graph.build(hipdnn_handle)

    return (
        graph,
        hipdnn_tensor_value,
        hipdnn_tensor_spatial_shapes,
        hipdnn_tensor_level_start_index,
        hipdnn_tensor_sampling_locations,
        hipdnn_tensor_attention_weights,
        hipdnn_tensor_grad_output,
        hipdnn_tensor_grad_value,
        hipdnn_tensor_grad_sampling_loc,
        hipdnn_tensor_grad_attn_weight,
    )


if __name__ == "__main__":
    # Input dimensions
    n = 2  # batch size
    n_heads = 2
    embed_dims_per_head = 32
    embed_dims = n_heads * embed_dims_per_head
    n_levels = 2
    n_points = 2
    n_queries = 32

    # One row of two positive extents per level.
    spatial_shapes_cpu = torch.randint(low=1, high=16, size=(n_levels, 2), dtype=torch.int64)

    # Number of keys contributed by each level; computed once and reused below.
    count_per_level = spatial_shapes_cpu.prod(dim=1)

    # Total key count as a plain Python int, so it can be used directly as a
    # tensor dimension instead of relying on a 0-dim tensor being accepted.
    n_keys = count_per_level.sum().item()

    # Each level starts where the previous ones end (exclusive prefix sum).
    level_start_index_cpu = torch.zeros_like(count_per_level)
    level_start_index_cpu[1:] = torch.cumsum(count_per_level[:-1], dim=0)

    hipdnn_data_type = hipdnn.data_type.FLOAT
    torch_data_type = torch.float32

    torch_tensor_value = torch.rand(
        n, n_keys, n_heads, embed_dims_per_head, dtype=torch_data_type, device="cuda"
    )
    torch_tensor_spatial_shapes = spatial_shapes_cpu.to("cuda")
    torch_tensor_level_start_index = level_start_index_cpu.to("cuda")
    torch_tensor_sampling_locations = torch.rand(
        n, n_queries, n_heads, n_levels, n_points, 2, dtype=torch_data_type, device="cuda"
    )
    torch_tensor_attention_weights = torch.rand(
        n, n_queries, n_heads, n_levels, n_points, dtype=torch_data_type, device="cuda"
    )
    torch_tensor_grad_output = torch.rand(
        n, n_queries, embed_dims, dtype=torch_data_type, device="cuda"
    )

    hipdnn_handle = hipdnn.create_handle()

    (
        graph,
        hipdnn_tensor_value,
        hipdnn_tensor_spatial_shapes,
        hipdnn_tensor_level_start_index,
        hipdnn_tensor_sampling_locations,
        hipdnn_tensor_attention_weights,
        hipdnn_tensor_grad_output,
        hipdnn_tensor_grad_value,
        hipdnn_tensor_grad_sampling_loc,
        hipdnn_tensor_grad_attn_weight,
    ) = build_deform_attention_bwd_graph(
        hipdnn_handle,
        torch_tensor_value,
        torch_tensor_spatial_shapes,
        torch_tensor_level_start_index,
        torch_tensor_sampling_locations,
        torch_tensor_attention_weights,
        torch_tensor_grad_output,
        hipdnn_data_type,
    )

    # Allocate output buffers using the dims reported by the built graph.
    torch_tensor_grad_value = torch.empty(
        hipdnn_tensor_grad_value.get_dim(), dtype=torch_data_type, device="cuda"
    )
    torch_tensor_grad_sampling_loc = torch.empty(
        hipdnn_tensor_grad_sampling_loc.get_dim(), dtype=torch_data_type, device="cuda"
    )
    torch_tensor_grad_attn_weight = torch.empty(
        hipdnn_tensor_grad_attn_weight.get_dim(), dtype=torch_data_type, device="cuda"
    )

    # Map every graph tensor to the device pointer of its backing torch tensor.
    variant_pack = {
        hipdnn_tensor_value: torch_tensor_value.data_ptr(),
        hipdnn_tensor_spatial_shapes: torch_tensor_spatial_shapes.data_ptr(),
        hipdnn_tensor_level_start_index: torch_tensor_level_start_index.data_ptr(),
        hipdnn_tensor_sampling_locations: torch_tensor_sampling_locations.data_ptr(),
        hipdnn_tensor_attention_weights: torch_tensor_attention_weights.data_ptr(),
        hipdnn_tensor_grad_output: torch_tensor_grad_output.data_ptr(),
        hipdnn_tensor_grad_value: torch_tensor_grad_value.data_ptr(),
        hipdnn_tensor_grad_sampling_loc: torch_tensor_grad_sampling_loc.data_ptr(),
        hipdnn_tensor_grad_attn_weight: torch_tensor_grad_attn_weight.data_ptr(),
    }

    # Byte buffer sized to whatever the graph reports it needs.
    workspace = torch.empty(graph.get_workspace_size(), dtype=torch.uint8, device="cuda")

    graph.exec(variant_pack=variant_pack, workspace=workspace.data_ptr())

    print("Deform attention bwd graph execution complete.")