import numpy as np
import argparse
import os

def parse_args():
    """Define the command-line interface and parse ``sys.argv``.

    Both options are required; argparse exits with a usage error if either
    is missing.

    Returns:
        argparse.Namespace exposing ``gpu_embeddings`` and ``dcu_embeddings``,
        each the path string given on the command line.
    """
    cli = argparse.ArgumentParser(
        description='Compare two embedding files and calculate absolute differences.'
    )
    option_specs = (
        ('--gpu_embeddings', 'Path to the GPU embeddings file (.npy)'),
        ('--dcu_embeddings', 'Path to the DCU embeddings file (.npy)'),
    )
    for flag, help_text in option_specs:
        cli.add_argument(flag, type=str, required=True, help=help_text)
    return cli.parse_args()

def main(args):
    """Load two embedding arrays and print their absolute differences.

    Prints the element-wise absolute difference ``|gpu - dcu|`` followed by
    its mean. For inputs with two or more dimensions the mean is taken along
    axis 1 (one value per row of a 2-D array). For 1-D or 0-D inputs, which
    have no axis 1, a single overall mean is printed instead.

    Args:
        args: Namespace with ``gpu_embeddings`` and ``dcu_embeddings``
            attributes, each a path to a ``.npy`` file readable by ``np.load``.

    Raises:
        ValueError: if the two loaded arrays do not have the same shape.
    """
    embeddings_1 = np.load(args.gpu_embeddings)
    embeddings_2 = np.load(args.dcu_embeddings)

    if embeddings_1.shape != embeddings_2.shape:
        # Message: "The shapes of the two embedding files do not match!"
        raise ValueError("两个嵌入文件的形状不匹配！")

    # Subtract in floating point when the inputs are not already inexact:
    # integer (especially unsigned) subtraction wraps around instead of going
    # negative, and bool subtraction is rejected by NumPy. Float/complex inputs
    # keep their own dtype so the printed precision is unchanged.
    work_dtype = np.result_type(embeddings_1, embeddings_2)
    if not np.issubdtype(work_dtype, np.inexact):
        work_dtype = np.float64
    abs_diff = np.abs(np.subtract(embeddings_1, embeddings_2, dtype=work_dtype))

    # axis=1 only exists for arrays with ndim >= 2; fall back to a global mean
    # rather than raising AxisError on 1-D / scalar inputs.
    mean_axis = 1 if abs_diff.ndim >= 2 else None
    mean_abs_diff = np.mean(abs_diff, axis=mean_axis)

    print(f"abs_diff:\n{abs_diff}")
    print(f"mean_abs_diff:\n{mean_abs_diff}")

if __name__ == "__main__":
    # Script entry point: parse CLI options and hand them straight to main().
    main(parse_args())