def get_args(argv=None):
    """Parse the command-line flags used by the operator tests.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse falls back to ``sys.argv[1:]``, so
            existing ``get_args()`` callers behave exactly as before.

    Returns:
        An ``argparse.Namespace`` with the boolean attributes ``profile``,
        ``cpu``, ``nvidia``, ``cambricon`` and ``ascend``. Each one is
        ``False`` unless its flag was passed.
    """
    import argparse

    parser = argparse.ArgumentParser(description="Test Operator")

    # (flag, help text) pairs, kept in their original order so that the
    # generated ``--help`` output is unchanged.
    switches = (
        ("--profile", "Whether profile tests"),
        ("--cpu", "Run CPU test"),
        ("--nvidia", "Run NVIDIA GPU test"),
        ("--cambricon", "Run Cambricon MLU test"),
        ("--ascend", "Run ASCEND NPU test"),
    )
    for flag, help_text in switches:
        parser.add_argument(flag, action="store_true", help=help_text)

    return parser.parse_args(argv)


def synchronize_device(torch_device):
    """Call ``synchronize()`` on the torch backend named by ``torch_device``.

    Args:
        torch_device: Device identifier compared against ``"cuda"``,
            ``"npu"`` and ``"mlu"``. Any other value (for example
            ``"cpu"``) results in no call at all.

    Returns:
        None.
    """
    import torch

    # NOTE(review): ``torch.npu`` / ``torch.mlu`` are presumably registered by
    # vendor extensions loaded elsewhere -- confirm against the caller.
    for backend in ("cuda", "npu", "mlu"):
        if torch_device == backend:
            getattr(torch, backend).synchronize()
            return