"""Example: build and execute a hipDNN ``bn_finalize`` graph backed by torch tensors."""

import hipdnn
import torch


def build_bn_finalize_graph(
    hipdnn_handle,
    torch_tensor_sum,
    torch_tensor_sq_sum,
    torch_tensor_scale,
    torch_tensor_bias,
    torch_tensor_prev_running_mean,
    torch_tensor_prev_running_variance,
    torch_tensor_momentum,
    torch_tensor_epsilon,
    torch_tensor_accum_count,
    hipdnn_data_type,
):
    """Create and build a hipDNN graph containing a single ``bn_finalize`` node.

    A graph tensor is created for every torch tensor via ``graph.tensor_like``,
    the three scalar inputs (momentum, epsilon, accumulation count) get their
    value assigned with ``set_value``, all six node outputs are marked as graph
    outputs, and the graph is built against ``hipdnn_handle``.

    Args:
        hipdnn_handle: Handle passed to ``hipdnn.pygraph`` and ``graph.build``.
        torch_tensor_sum: Torch tensor mirrored as the node's ``sum`` input.
        torch_tensor_sq_sum: Torch tensor mirrored as the ``sq_sum`` input.
        torch_tensor_scale: Torch tensor mirrored as the ``scale`` input.
        torch_tensor_bias: Torch tensor mirrored as the ``bias`` input.
        torch_tensor_prev_running_mean: Torch tensor mirrored as the
            ``prev_running_mean`` input.
        torch_tensor_prev_running_variance: Torch tensor mirrored as the
            ``prev_running_variance`` input.
        torch_tensor_momentum: Single-element torch tensor; its value is read
            with ``.item()`` and assigned to the momentum graph tensor.
        torch_tensor_epsilon: Single-element torch tensor; its value is read
            with ``.item()`` and assigned to the epsilon graph tensor.
        torch_tensor_accum_count: Single-element torch tensor; its value is
            read with ``.item()`` and assigned to the accum-count graph tensor.
        hipdnn_data_type: hipDNN data type used as the graph's I/O data type.

    Returns:
        A 16-tuple, in this order: the built graph, the nine input graph
        tensors (sum, sq_sum, scale, bias, prev_running_mean,
        prev_running_variance, momentum, epsilon, accum_count) and the six
        output graph tensors (eq_scale, eq_bias, mean, inv_variance,
        next_running_mean, next_running_variance).
    """
    graph = hipdnn.pygraph(
        handle=hipdnn_handle,
        io_data_type=hipdnn_data_type,
        intermediate_data_type=hipdnn.data_type.FLOAT,
        compute_data_type=hipdnn.data_type.FLOAT,
        name="bn_finalize",
    )

    # Graph tensors are created in the same order as the original script.
    hipdnn_tensor_sum = graph.tensor_like(torch_tensor_sum)
    hipdnn_tensor_sq_sum = graph.tensor_like(torch_tensor_sq_sum)
    hipdnn_tensor_scale = graph.tensor_like(torch_tensor_scale)
    hipdnn_tensor_bias = graph.tensor_like(torch_tensor_bias)
    hipdnn_tensor_prev_running_mean = graph.tensor_like(torch_tensor_prev_running_mean)
    hipdnn_tensor_prev_running_variance = graph.tensor_like(
        torch_tensor_prev_running_variance
    )

    # Scalars: take the value from the tensor the caller supplied instead of a
    # hard-coded literal, so momentum/epsilon are handled exactly like
    # accum_count and a caller-provided value is never silently ignored.
    # NOTE(review): set_value presumably bakes the scalar into the graph --
    # confirm against the hipDNN frontend documentation.
    hipdnn_tensor_momentum = graph.tensor_like(torch_tensor_momentum)
    hipdnn_tensor_momentum.set_value(torch_tensor_momentum.item())

    hipdnn_tensor_epsilon = graph.tensor_like(torch_tensor_epsilon)
    hipdnn_tensor_epsilon.set_value(torch_tensor_epsilon.item())

    hipdnn_tensor_accum_count = graph.tensor_like(torch_tensor_accum_count)
    hipdnn_tensor_accum_count.set_value(torch_tensor_accum_count.item())

    (
        hipdnn_tensor_eq_scale,
        hipdnn_tensor_eq_bias,
        hipdnn_tensor_mean,
        hipdnn_tensor_inv_variance,
        hipdnn_tensor_next_running_mean,
        hipdnn_tensor_next_running_variance,
    ) = graph.bn_finalize(
        sum=hipdnn_tensor_sum,
        sq_sum=hipdnn_tensor_sq_sum,
        scale=hipdnn_tensor_scale,
        bias=hipdnn_tensor_bias,
        epsilon=hipdnn_tensor_epsilon,
        accum_count=hipdnn_tensor_accum_count,
        prev_running_mean=hipdnn_tensor_prev_running_mean,
        prev_running_variance=hipdnn_tensor_prev_running_variance,
        momentum=hipdnn_tensor_momentum,
        name="bn_finalize_node",
    )

    # Every result of the node is exposed as a graph output.
    hipdnn_tensor_eq_scale.set_output(True)
    hipdnn_tensor_eq_bias.set_output(True)
    hipdnn_tensor_mean.set_output(True)
    hipdnn_tensor_inv_variance.set_output(True)
    hipdnn_tensor_next_running_mean.set_output(True)
    hipdnn_tensor_next_running_variance.set_output(True)

    graph.build(hipdnn_handle)

    return (
        graph,
        hipdnn_tensor_sum,
        hipdnn_tensor_sq_sum,
        hipdnn_tensor_scale,
        hipdnn_tensor_bias,
        hipdnn_tensor_prev_running_mean,
        hipdnn_tensor_prev_running_variance,
        hipdnn_tensor_momentum,
        hipdnn_tensor_epsilon,
        hipdnn_tensor_accum_count,
        hipdnn_tensor_eq_scale,
        hipdnn_tensor_eq_bias,
        hipdnn_tensor_mean,
        hipdnn_tensor_inv_variance,
        hipdnn_tensor_next_running_mean,
        hipdnn_tensor_next_running_variance,
    )


def main():
    """Build the bn_finalize graph for random inputs and execute it once."""
    n = 1
    c = 32
    h = 1
    w = 1

    hipdnn_data_type = hipdnn.data_type.FLOAT
    torch_data_type = torch.float32

    # Per-channel inputs, all of shape (n, c, h, w) on the GPU device.
    torch_tensor_sum = torch.rand(n, c, h, w, dtype=torch_data_type, device="cuda")
    torch_tensor_sq_sum = torch.rand(n, c, h, w, dtype=torch_data_type, device="cuda")
    torch_tensor_scale = torch.rand(n, c, h, w, dtype=torch_data_type, device="cuda")
    torch_tensor_bias = torch.rand(n, c, h, w, dtype=torch_data_type, device="cuda")
    torch_tensor_prev_running_mean = torch.rand(
        n, c, h, w, dtype=torch_data_type, device="cuda"
    )
    torch_tensor_prev_running_variance = torch.rand(
        n, c, h, w, dtype=torch_data_type, device="cuda"
    )

    # Single-element scalar inputs.
    torch_tensor_momentum = torch.full(
        (1, 1, 1, 1), 0.001, dtype=torch.float32, requires_grad=False, device="cuda"
    )
    torch_tensor_epsilon = torch.full(
        (1, 1, 1, 1), 1e-5, dtype=torch.float32, requires_grad=False, device="cuda"
    )
    torch_tensor_accum_count = torch.full(
        (1, 1, 1, 1), n * h * w, dtype=torch.int32, requires_grad=False, device="cuda"
    )

    hipdnn_handle = hipdnn.create_handle()

    (
        graph,
        hipdnn_tensor_sum,
        hipdnn_tensor_sq_sum,
        hipdnn_tensor_scale,
        hipdnn_tensor_bias,
        hipdnn_tensor_prev_running_mean,
        hipdnn_tensor_prev_running_variance,
        hipdnn_tensor_momentum,
        hipdnn_tensor_epsilon,
        hipdnn_tensor_accum_count,
        hipdnn_tensor_eq_scale,
        hipdnn_tensor_eq_bias,
        hipdnn_tensor_mean,
        hipdnn_tensor_inv_variance,
        hipdnn_tensor_next_running_mean,
        hipdnn_tensor_next_running_variance,
    ) = build_bn_finalize_graph(
        hipdnn_handle,
        torch_tensor_sum,
        torch_tensor_sq_sum,
        torch_tensor_scale,
        torch_tensor_bias,
        torch_tensor_prev_running_mean,
        torch_tensor_prev_running_variance,
        torch_tensor_momentum,
        torch_tensor_epsilon,
        torch_tensor_accum_count,
        hipdnn_data_type,
    )

    # Output buffers are sized from the dimensions reported by the graph tensors.
    torch_tensor_eq_scale = torch.empty(
        hipdnn_tensor_eq_scale.get_dim(), dtype=torch_data_type, device="cuda"
    )
    torch_tensor_eq_bias = torch.empty(
        hipdnn_tensor_eq_bias.get_dim(), dtype=torch_data_type, device="cuda"
    )
    torch_tensor_mean = torch.empty(
        hipdnn_tensor_mean.get_dim(), dtype=torch_data_type, device="cuda"
    )
    torch_tensor_inv_variance = torch.empty(
        hipdnn_tensor_inv_variance.get_dim(), dtype=torch_data_type, device="cuda"
    )
    torch_tensor_next_running_mean = torch.empty(
        hipdnn_tensor_next_running_mean.get_dim(), dtype=torch_data_type, device="cuda"
    )
    torch_tensor_next_running_variance = torch.empty(
        hipdnn_tensor_next_running_variance.get_dim(),
        dtype=torch_data_type,
        device="cuda",
    )

    # Map every graph tensor to the device address of its torch buffer.
    variant_pack = {
        hipdnn_tensor_sum: torch_tensor_sum.data_ptr(),
        hipdnn_tensor_sq_sum: torch_tensor_sq_sum.data_ptr(),
        hipdnn_tensor_scale: torch_tensor_scale.data_ptr(),
        hipdnn_tensor_bias: torch_tensor_bias.data_ptr(),
        hipdnn_tensor_prev_running_mean: torch_tensor_prev_running_mean.data_ptr(),
        hipdnn_tensor_prev_running_variance: torch_tensor_prev_running_variance.data_ptr(),
        hipdnn_tensor_momentum: torch_tensor_momentum.data_ptr(),
        hipdnn_tensor_epsilon: torch_tensor_epsilon.data_ptr(),
        hipdnn_tensor_accum_count: torch_tensor_accum_count.data_ptr(),
        hipdnn_tensor_eq_scale: torch_tensor_eq_scale.data_ptr(),
        hipdnn_tensor_eq_bias: torch_tensor_eq_bias.data_ptr(),
        hipdnn_tensor_mean: torch_tensor_mean.data_ptr(),
        hipdnn_tensor_inv_variance: torch_tensor_inv_variance.data_ptr(),
        hipdnn_tensor_next_running_mean: torch_tensor_next_running_mean.data_ptr(),
        hipdnn_tensor_next_running_variance: torch_tensor_next_running_variance.data_ptr(),
    }

    # Scratch buffer sized in bytes as requested by the built graph.
    workspace = torch.empty(graph.get_workspace_size(), dtype=torch.uint8, device="cuda")

    graph.exec(variant_pack=variant_pack, workspace=workspace.data_ptr())

    print("Batch normalization finalize graph execution complete.")


if __name__ == "__main__":
    main()