curate_list.py
import argparse
import json
import os
import pickle
import threading

import torch
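
# Curate a trace list: load a pickled list of extracted trace paths, look up each
# trace's CLIP scores (per-trace .pth files under --clip_score_dir), keep traces
# whose maximum score is at least --min_score, and write the surviving trace paths
# to a JSON file. Filtering is parallelized across threads sharing two global lists.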

parser = argparse.ArgumentParser(description='Curate a list of traces by CLIP score')
parser.add_argument('--dataset_name', type=str, default="video-dataset", metavar='DN',
                    help='dataset name for finding annotation files')
parser.add_argument('--trace_path', type=str, default="/pickle/file/format", metavar='TP',
                    help='path to a file which is a list containing paths of extracted traces')
parser.add_argument('--clip_score_dir', type=str, default="/path/to/dir/clip_filtered_scores/", metavar='CSD',
                    help='path to directory which contains the clip scores')
parser.add_argument('--output_path', type=str, default="/path/to/output/file", metavar='OVD',
                    help='path to output json file containing list of valid traces')

parser.add_argument('--min_score', type=float, default=0.25, metavar='MS',
                    help='minimum CLIP score required to keep a trace')
parser.add_argument('--split_idx', type=int, default=0, metavar='SI',
                    help='index for splitting entire dataset over multiple GPUs')
parser.add_argument('--num_samples_per_segment', type=int, default=10400145, metavar='NS',
                    help='number of samples per segment when splitting the dataset over GPUs')
parser.add_argument('--num_workers', type=int, default=8, metavar='NW',
                    help='number of worker processes')
parser.add_argument('--batch_size', type=int, default=128, metavar='BS',
                    help='batch size')
parser.add_argument('--thread_num', type=int, default=72, metavar='TN',
                    help='number of threads')
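
# Example invocation (paths below are placeholders, not real defaults):
#   python curate_list.py \
#       --trace_path /data/traces.pkl \
#       --clip_score_dir /data/clip_filtered_scores/ \
#       --output_path /data/valid_traces.json \
#       --min_score 0.25 --thread_num 72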

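# Shared accumulators: `valid_traces` collects traces whose max CLIP score clears
# --min_score; `full_set_traces` collects every trace with a readable score file.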
valid_traces = []
full_set_traces = []

def main():
    global args
    args = parser.parse_args()

    # --output_path points to a file, so create its parent directory if needed.
    os.makedirs(os.path.dirname(args.output_path) or '.', exist_ok=True)
    output_path = args.output_path
    # The trace file is a pickled list containing the paths of all extracted traces.
    with open(args.trace_path, 'rb') as f:
        all_traces = list(pickle.load(f))

    print('all_traces: ', len(all_traces))
    print('')

    lock = threading.Lock()

    # Function that writes to the global set
    def add_to_set(tid, split_traces):
        print('split %s: ' % tid, len(split_traces))
        print('')
        for idx, trace in enumerate(split_traces):
            if tid == 0 and idx % 1000 == 0:
                print(idx)

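            # Each trace entry is expected to provide a subdirectory (trace[0]) and a
            # file name (trace[-1]); the matching CLIP scores live in a .pth file
            # under --clip_score_dir with the same layout.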
            trace_path = trace[0] + '/' + trace[-1]
            score_path = os.path.join(args.clip_score_dir, trace[0], '%s.pth' % trace[-1])

            try:
                trace_score = torch.load(score_path, map_location='cpu').max()
            except Exception:
                # Skip traces whose score file is missing or unreadable.
                continue

            with lock:  # Ensure that only one thread can modify the lists at a time.
                if trace_score >= args.min_score:
                    valid_traces.append(trace_path)
                full_set_traces.append(trace_path)


    # Create threads
    per_process_video_num = len(all_traces) // args.thread_num
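    # Each thread handles a contiguous slice; the last thread also takes the remainder.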

    threads = []
    for i in range(args.thread_num):
        if i == args.thread_num - 1:
            sub_files = all_traces[i * per_process_video_num :]
        else:
            sub_files = all_traces[i * per_process_video_num : (i + 1) * per_process_video_num]

        t = threading.Thread(target=add_to_set, args=(i, sub_files,))
        threads.append(t)
        t.start()

    # Wait for all threads to finish
    for t in threads:
        t.join()

    # Persist the curated list of trace paths that passed the CLIP-score filter.
    with open(output_path, 'w') as f:
        json.dump(valid_traces, f)
    print('valid_traces: ', len(valid_traces))
    print('full_set_traces: ', len(full_set_traces))
    
if __name__ == "__main__":
    main()