"""Benchmark a reference box-target encoder against the compiled extension.

Run as a script, this times ``encode_reg_target_ref`` (pure PyTorch) against
``encode_reg_target_optimized`` (backed by ``encode_reg_target_ext``) on
random inputs and reports the element-wise difference between their outputs.
"""
import time

import torch

# NOTE(review): kept after ``import torch``, as in the original import order;
# compiled torch extensions usually need torch loaded first -- confirm.
import encode_reg_target_ext

# Column indices of the encoded (output) layout produced by the encoders.
# They live at module level so ``encode_reg_target_ref`` also works when this
# file is imported rather than executed as a script.
X, Y, Z, W, L, H, SIN_YAW, COS_YAW, VX, VY, VZ = range(11)
# Column index of the raw yaw angle in the un-encoded (input) layout.
YAW = 6


# -----------------------------
# reference implementation
# -----------------------------
def encode_reg_target_ref(box_target, device=None):
    """Encode each box tensor with plain PyTorch ops.

    For every tensor the last dimension is rebuilt as: columns ``X, Y, Z``
    unchanged, the natural log of columns ``W, L, H``, the sine and cosine of
    column ``YAW``, then every column after ``YAW`` unchanged.

    Args:
        box_target: Iterable of tensors indexed along their last dimension.
        device: Optional device each encoded tensor is moved to.

    Returns:
        A list with one encoded tensor per input tensor.
    """
    outputs = []
    for box in box_target:
        yaw = box[..., YAW]
        output = torch.cat(
            [
                box[..., [X, Y, Z]],
                box[..., [W, L, H]].log(),
                torch.sin(yaw).unsqueeze(-1),
                torch.cos(yaw).unsqueeze(-1),
                box[..., YAW + 1:],
            ],
            dim=-1,
        )
        if device is not None:
            output = output.to(device=device)
        outputs.append(output)
    return outputs


# -----------------------------
# Optimized HIP/C++ implementation
# -----------------------------
def encode_reg_target_optimized(box_target_list, device=None):
    """Encode each box tensor with the compiled extension kernel.

    Args:
        box_target_list: Sequence of 2-D tensors.
        device: Optional device the inputs are moved to before encoding.
            When omitted, the device of the first tensor is used for all.

    Returns:
        A list holding the result of ``encode_reg_target_ext.encode_reg_t``
        for each input tensor; an empty list for empty input.

    Raises:
        ValueError: If an input tensor is not 2-D.
    """
    # ``len`` rather than truthiness: the input may be a non-list sequence.
    if len(box_target_list) == 0:
        return []

    dev = device if device is not None else box_target_list[0].device

    outputs = []
    for box in box_target_list:
        # ``Tensor.to`` hands back the same tensor when it already lives on
        # ``dev``, so no explicit device comparison is needed (comparing a
        # ``torch.device`` with a plain string such as "cuda:0" is unreliable).
        box = box.to(dev)
        # Two-value unpacking keeps the original 2-D requirement.
        num_rows, _ = box.shape
        outputs.append(encode_reg_target_ext.encode_reg_t(box, num_rows))
    return outputs


# -----------------------------
# Main benchmark
# -----------------------------
def main():
    """Time both encoders on random data and print an accuracy report."""
    device = "cuda:0"
    d_dim = 10
    m_dims = [
        40, 12, 25, 111, 45, 84, 19, 14, 52, 13,
        20, 19, 28, 28, 6, 62, 26, 56, 13, 33,
    ]
    box_target = [
        torch.rand(m, d_dim, dtype=torch.float32, device=device)
        for m in m_dims
    ]

    num_warmup = 10
    num_runs = 100

    # -----------------------------
    # Warmup
    # -----------------------------
    for _ in range(num_warmup):
        encode_reg_target_ref(box_target, device)
    for _ in range(num_warmup):
        encode_reg_target_optimized(box_target, device)
    torch.cuda.synchronize()

    # -----------------------------
    # Measure reference
    # -----------------------------
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    start_event.record()
    for _ in range(num_runs):
        out_ref = encode_reg_target_ref(box_target, device)
    end_event.record()
    # Wait for the recorded events to complete before reading the timer.
    torch.cuda.synchronize()
    elapsed_time_ms_ref = start_event.elapsed_time(end_event)
    avg_time_ms_ref = elapsed_time_ms_ref / num_runs
    print(f"reference avg time: {avg_time_ms_ref:.4f} ms")

    # -----------------------------
    # Measure optimized
    # -----------------------------
    start_event.record()
    for _ in range(num_runs):
        out_optimized = encode_reg_target_optimized(box_target, device)
    end_event.record()
    torch.cuda.synchronize()
    elapsed_time_ms_opt = start_event.elapsed_time(end_event)
    avg_time_ms_opt = elapsed_time_ms_opt / num_runs
    print(f"Optimized kernel avg time: {avg_time_ms_opt:.4f} ms")

    # -----------------------------
    # Accuracy check
    # -----------------------------
    # ``zip`` silently truncates, so a short (or empty) optimized result must
    # be caught explicitly or it would be reported as a pass.
    all_close = len(out_ref) == len(out_optimized)
    max_diff = 0.0
    mean_diff = 0.0
    for ref, opt in zip(out_ref, out_optimized):
        diff = torch.abs(ref - opt)
        if not torch.allclose(ref, opt, atol=1e-5):
            all_close = False
        max_diff = max(max_diff, diff.max().item())
        mean_diff += diff.mean().item()
    mean_diff /= len(out_ref)

    print("Accuracy check:")
    print("Pass" if all_close else "Fail")
    print(f"Max diff: {max_diff}")
    print(f"Mean diff: {mean_diff}")
    print("Test end.")


if __name__ == "__main__":
    main()