GPTJ_QSL.py 791 Bytes
Newer Older
yangzhong's avatar
yangzhong committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import mlperf_loadgen as lg
from dataset import Dataset


class GPTJ_QSL:
    """Bundle a ``Dataset`` with the LoadGen QSL handle built from it.

    Attributes set by the constructor:
        dataset_path: The path passed in by the caller, stored unchanged.
        max_examples: The count passed in by the caller, stored unchanged.
        data_object: The ``Dataset`` created from the two values above.
        qsl: The value returned by ``lg.ConstructQSL`` for that dataset.
    """

    def __init__(self, dataset_path: str, max_examples: int):
        """Create the dataset, then register it with LoadGen.

        Args:
            dataset_path: Forwarded to ``Dataset`` as its first positional
                argument.
            max_examples: Forwarded to ``Dataset`` as
                ``total_count_override``.
        """
        self.dataset_path = dataset_path
        self.max_examples = max_examples

        # The dataset is built first: LoadGen needs its sample counts and
        # its load/unload callbacks.
        dataset = Dataset(dataset_path, total_count_override=max_examples)
        self.data_object = dataset

        # Hand the counts and the bound callbacks to the Python binding.
        self.qsl = lg.ConstructQSL(
            dataset.count,
            dataset.perf_count,
            dataset.LoadSamplesToRam,
            dataset.UnloadSamplesFromRam,
        )

        print("Finished constructing QSL.")


def get_GPTJ_QSL(dataset_path: str, max_examples: int) -> "GPTJ_QSL":
    """Build and return a ``GPTJ_QSL`` for the given path and example count.

    Both arguments are passed straight through to the ``GPTJ_QSL``
    constructor.
    """
    qsl_wrapper = GPTJ_QSL(
        dataset_path=dataset_path,
        max_examples=max_examples,
    )
    return qsl_wrapper