def s3_file_exists(s3, bucket: str, key: str) -> bool:
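    """Return True if the object exists in S3, treating a 404 from head_object as missing."""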
    from botocore.exceptions import ClientError

    try:
        s3.head_object(Bucket=bucket, Key=key)
        return True
    except ClientError as e:
        if e.response["Error"]["Code"] != "404":
            raise e
    return False


def read_s3_file(s3, bucket: str, key: str) -> str:
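    """Download an S3 object and return its contents decoded as UTF-8 text."""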
    resp = s3.get_object(Bucket=bucket, Key=key)
    return resp["Body"].read().decode("utf-8")


def forecast_partition(
    i: int,
    version: str,
    bucket: str,
    prefix: str,
    n_series: int,
    freq: str = "D",
    h: int = 7,
) -> None:
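    """Generate the synthetic series for partition i, forecast h steps ahead with the
    requested client version, and store the forecast and the elapsed time in S3."""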
    import logging
    import math
    import time

    import boto3
    import pandas as pd
    from nixtla import NixtlaClient as V2Client
    from nixtlats import NixtlaClient as V1Client
    from tqdm.auto import tqdm
    from utilsforecast.data import generate_series

    s3 = boto3.client("s3")
    # only process if we haven't saved the time
    time_key = f"{prefix}/times/{version}/{i}.txt"
    if s3_file_exists(s3, bucket, time_key):
        print(f"{i}-th partition already processed, skipping.")
        return
    logging.getLogger("nixtla").setLevel(logging.ERROR)
    logging.getLogger("nixtlats").setLevel(logging.ERROR)

    series = generate_series(
        n_series=n_series,
        freq=freq,
        min_length=100,
        max_length=200,
        seed=i,
    )
    series["unique_id"] = series["unique_id"].astype("uint32") + i * n_series

    start = time.perf_counter()
    if version == "v1":
        client = V1Client()
        # v1 is slower when partitioning, so we do this sequentially
        num_partitions = math.ceil(n_series / 50_000)
        uids = series["unique_id"].unique()
        n_ids = uids.size
        ids_per_part = math.ceil(n_ids / num_partitions)
        results = []
        for j in tqdm(range(0, n_ids, ids_per_part)):
            part_uids = uids[j : j + ids_per_part]
            part = series[series["unique_id"].isin(part_uids)]
            results.append(client.forecast(df=part, h=h, freq=freq))
        forecast = pd.concat(results)
    else:
        client = V2Client()
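        # the v2 client splits the work itself through num_partitions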
        num_partitions = math.ceil(n_series / 100_000)
        forecast = client.forecast(
            df=series, h=h, freq=freq, num_partitions=num_partitions
        )
    time_taken = f"{time.perf_counter() - start:.2f}"
    forecast.to_parquet(
        f"s3://{bucket}/{prefix}/output/{version}/{i}.parquet", index=False
    )
    s3.put_object(Bucket=bucket, Key=time_key, Body=time_taken)
    print(f"{i}: {time_taken}")


def generate_forecasts(
    version: str,
    bucket: str,
    prefix: str,
    n_partitions: int,
    series_per_partition: int,
    n_jobs: int,
) -> None:
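    """Run forecast_partition for every partition using a pool of n_jobs worker processes."""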
    from concurrent.futures import ProcessPoolExecutor
    from functools import partial

    fn = partial(
        forecast_partition,
        version=version,
        bucket=bucket,
        prefix=prefix,
        n_series=series_per_partition,
    )
    with ProcessPoolExecutor(n_jobs) as executor:
        # consume the iterator so that exceptions raised in the workers are re-raised here
        list(executor.map(fn, range(n_partitions)))


def read_times(
    s3, version: str, bucket: str, prefix: str, n_partitions: int
) -> list[str]:
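    """Gather the per-partition timing files for a version, caching the combined list in S3."""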
    from concurrent.futures import ThreadPoolExecutor

    from tqdm.auto import tqdm

    key = f"{prefix}/{version}_times.txt"
    if s3_file_exists(s3, bucket, key):
        return read_s3_file(s3, bucket, key).splitlines()
    with ThreadPoolExecutor() as executor:
        futures = [
            executor.submit(
                read_s3_file, s3, bucket, f"{prefix}/times/{version}/{i}.txt"
            )
            for i in range(n_partitions)
        ]
        # collect results in submission order so each time stays aligned with its partition
        times = [future.result() for future in tqdm(futures)]
    s3.put_object(
        Bucket=bucket,
        Key=key,
        Body="\n".join(times),
    )
    return times


def main(
    bucket: str = "datasets-nixtla",
    prefix: str = "one-billion",
    n_partitions: int = 1_000,
    series_per_partition: int = 1_000_000,
    n_jobs: int = 5,
):
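    """Forecast n_partitions * series_per_partition series (one billion with the defaults)
    with both client versions and store the per-partition times as a CSV in S3."""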
    import boto3
    import pandas as pd

    times = {}
    s3 = boto3.client("s3")
    for version in ("v1", "v2"):
        generate_forecasts(
            version=version,
            bucket=bucket,
            prefix=prefix,
            n_partitions=n_partitions,
            series_per_partition=series_per_partition,
            n_jobs=n_jobs,
        )
        times[version] = read_times(
            s3, version=version, bucket=bucket, prefix=prefix, n_partitions=n_partitions
        )
    pd.DataFrame(times).to_csv(f"s3://{bucket}/{prefix}/times.csv", index=False)


if __name__ == "__main__":
    import fire
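    # expose main's keyword arguments as command line flags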

    fire.Fire(main)