"""Example: build and execute a hipDNN graph holding a single AdamW op.

The graph is created with the ``is_transformeradamw`` flag enabled and is run
once on randomly initialised tensors as a smoke test.
"""

import hipdnn
import torch


def build_transformer_adamw_graph(
    hipdnn_handle,
    torch_tensor_params: torch.Tensor,
    torch_tensor_grads: torch.Tensor,
    torch_tensor_exp_avgs: torch.Tensor,
    torch_tensor_exp_avg_sqs: torch.Tensor,
    hipdnn_data_type,
):
    """Build a hipDNN graph containing one transformer-AdamW operation.

    Args:
        hipdnn_handle: hipDNN handle used both to create the graph and to
            build it.
        torch_tensor_params: Tensor used as the template for the graph's
            ``params`` tensor.
        torch_tensor_grads: Tensor used as the template for the graph's
            ``grads`` tensor.
        torch_tensor_exp_avgs: Tensor used as the template for the graph's
            ``exp_avgs`` tensor.
        torch_tensor_exp_avg_sqs: Tensor used as the template for the graph's
            ``exp_avg_sqs`` tensor.
        hipdnn_data_type: hipDNN data type used as the graph's I/O data type.
            Intermediate and compute data types are always
            ``hipdnn.data_type.FLOAT``.

    Returns:
        A 5-tuple ``(graph, params, grads, exp_avgs, exp_avg_sqs)`` where
        ``graph`` is the built graph and the remaining items are the graph
        tensors created from the corresponding torch tensors, in that order.
        These graph tensors are what the caller uses as variant-pack keys.
    """
    graph = hipdnn.pygraph(
        handle=hipdnn_handle,
        io_data_type=hipdnn_data_type,
        intermediate_data_type=hipdnn.data_type.FLOAT,
        compute_data_type=hipdnn.data_type.FLOAT,
        name="adamw",
    )

    # Graph-side tensor descriptors, each derived from its torch counterpart.
    hipdnn_tensor_params = graph.tensor_like(torch_tensor_params)
    hipdnn_tensor_grads = graph.tensor_like(torch_tensor_grads)
    hipdnn_tensor_exp_avgs = graph.tensor_like(torch_tensor_exp_avgs)
    hipdnn_tensor_exp_avg_sqs = graph.tensor_like(torch_tensor_exp_avg_sqs)

    # The op's result is not captured; only the four input descriptors are
    # handed back to the caller.
    graph.adamw(
        params=hipdnn_tensor_params,
        grads=hipdnn_tensor_grads,
        exp_avgs=hipdnn_tensor_exp_avgs,
        exp_avg_sqs=hipdnn_tensor_exp_avg_sqs,
        is_transformeradamw=True,
    )

    graph.build(hipdnn_handle)

    return (
        graph,
        hipdnn_tensor_params,
        hipdnn_tensor_grads,
        hipdnn_tensor_exp_avgs,
        hipdnn_tensor_exp_avg_sqs,
    )


def main():
    """Run the transformer-AdamW graph once on small random tensors."""
    # Input dimensions
    batch, channels, height, width = 1, 2, 3, 4
    shape = (batch, channels, height, width)

    hipdnn_data_type = hipdnn.data_type.FLOAT
    torch_data_type = torch.float32

    # Four independent random tensors, drawn in the order
    # params, grads, exp_avgs, exp_avg_sqs.
    (
        torch_tensor_params,
        torch_tensor_grads,
        torch_tensor_exp_avgs,
        torch_tensor_exp_avg_sqs,
    ) = (
        torch.rand(*shape, dtype=torch_data_type, device="cuda")
        for _ in range(4)
    )

    hipdnn_handle = hipdnn.create_handle()

    (
        graph,
        hipdnn_tensor_params,
        hipdnn_tensor_grads,
        hipdnn_tensor_exp_avgs,
        hipdnn_tensor_exp_avg_sqs,
    ) = build_transformer_adamw_graph(
        hipdnn_handle,
        torch_tensor_params,
        torch_tensor_grads,
        torch_tensor_exp_avgs,
        torch_tensor_exp_avg_sqs,
        hipdnn_data_type,
    )

    # Bind every graph tensor to the device address of its torch tensor.
    variant_pack = {
        hipdnn_tensor_params: torch_tensor_params.data_ptr(),
        hipdnn_tensor_grads: torch_tensor_grads.data_ptr(),
        hipdnn_tensor_exp_avgs: torch_tensor_exp_avgs.data_ptr(),
        hipdnn_tensor_exp_avg_sqs: torch_tensor_exp_avg_sqs.data_ptr(),
    }

    # Scratch buffer sized by the graph. The tensor object must stay
    # referenced while the graph runs because only its raw pointer is passed.
    workspace = torch.empty(
        graph.get_workspace_size(), dtype=torch.uint8, device="cuda"
    )

    graph.exec(variant_pack=variant_pack, workspace=workspace.data_ptr())

    # NOTE(review): exec presumably only enqueues GPU work; wait for the
    # device so the message below is accurate and errors surface here.
    torch.cuda.synchronize()

    print("Transformer adamw graph execution complete.")


if __name__ == "__main__":
    main()