timesfm.py 20.7 KB
Newer Older
suily's avatar
suily committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""TimesFM forecast API for inference."""

import logging
import multiprocessing
from os import path
import time
from typing import Any, Literal, Optional, Sequence

import einshape as es
import jax
import jax.numpy as jnp
import numpy as np
import pandas as pd
from huggingface_hub import snapshot_download
from paxml import checkpoints
from paxml import tasks_lib
from praxis import base_hyperparams
from praxis import base_layer
from praxis import pax_fiddle
from praxis import py_utils
from praxis import pytypes
from praxis.layers import normalizations
from praxis.layers import transformers
from . import patched_decoder  # TODO: this import raises an error.
from utilsforecast.processing import make_future_dataframe

instantiate = base_hyperparams.instantiate
NestedMap = py_utils.NestedMap
JTensor = pytypes.JTensor


def process_group(key, group, value_name, forecast_context_len):
  """Extracts the trailing values of one series as a float32 array.

  Args:
    key: identifier of the series; returned unchanged.
    group: dataframe holding the rows of a single series.
    value_name: name of the column in `group` that holds the values.
    forecast_context_len: number of trailing rows of `group` to keep.

  Returns:
    A `(values, key)` tuple, where `values` is a float32 `np.ndarray` built
    from the last `forecast_context_len` rows of the `value_name` column.
  """
  recent_rows = group.tail(forecast_context_len)
  values = np.array(recent_rows[value_name], dtype=np.float32)
  return values, key


def moving_average(arr, window_size):
  """Decomposes a series into a trailing moving average and its residual.

  Args:
    arr: 1d array of values.
    window_size: number of points averaged at each position.

  Returns:
    A two-element list `[smoothed, residual]`, where `smoothed` has the same
    length as `arr` and `residual` equals `arr - smoothed`.
  """
  # Left-pad with zeros so the first `window_size - 1` positions still get a
  # full-length window; those leading averages are therefore pulled towards 0.
  left_padded = np.pad(arr, (window_size - 1, 0), mode="constant")
  window_sums = np.convolve(left_padded, np.ones(window_size), mode="valid")
  trend = window_sums / window_size
  residual = arr - trend
  return [trend, residual]


def freq_map(freq: str):
  """Maps a frequency string to the integer frequency category.

  The string is upper-cased and then matched on its suffix only, so any
  leading multiplier (e.g. the "15" in "15min") is ignored.

  Args:
    freq: frequency string such as "H", "15min", "D", "W", "MS" or "Q".

  Returns:
    0 if the suffix is H, T, MIN, D, B or U; 1 if it is W, M or MS; 2 if it is
    Y or Q.

  Raises:
    ValueError: if the suffix matches none of the categories above.
  """
  freq = str.upper(freq)
  # Checked in order; the first category with a matching suffix wins.
  suffixes_by_category = (
      (0, ("H", "T", "MIN", "D", "B", "U")),
      (1, ("W", "M", "MS")),
      (2, ("Y", "Q")),
  )
  for category, suffixes in suffixes_by_category:
    if freq.endswith(suffixes):
      return category
  raise ValueError(f"Invalid frequency: {freq}")


class TimesFm:
  """TimesFM forecast API for inference.

  This class is the scaffolding for calling TimesFM forecast. To properly use:
    1. Create an instance with the correct hyperparameters of a TimesFM model.
    2. Call `load_from_checkpoint` to load a compatible checkpoint.
    3. Call `forecast` for inference.

  Given the model size, this API does not shard the model weights for SPMD. All
  parallelism happens on the data dimension.

  Compilation happens in `load_from_checkpoint`, which runs the decoder once on
  dummy inputs shaped by `per_core_batch_size` to set and freeze the input
  signature. Subsequent calls to `forecast` reflect the actual inference
  latency.

  Attributes:
    per_core_batch_size: Batch size on each core for data parallelism.
    backend: One of "cpu", "gpu" or "tpu".
    num_devices: Number of cores provided the backend.
    global_batch_size: per_core_batch_size * num_devices. Each batch of
      inference task will be padded with respect to global_batch_size to
      minimize latency.
    context_len: Largest context length the model allows for each decode call.
      This can technically be arbitrarily large, but in practice should be set
      to the context length the checkpoint was trained with.
    horizon_len: Forecast horizon.
    input_patch_len: Input patch len.
    output_patch_len: Output patch len. How many timepoints are taken from a
      single step of autoregressive decoding. Can be set as the training horizon
      of the checkpoint.
    mesh_shape: Shape of the data parallelism mesh.
    mesh_name: Names of the data parallelism mesh.
    model_p: Configuration of the TimesFM model deduced from the hparams.
  """

  def _logging(self, s):
    """Prints `s` if this instance was created with `verbose=True`."""
    if self._verbose:
      print(s)

  def __init__(
      self,
      context_len: int,
      horizon_len: int,
      input_patch_len: int,
      output_patch_len: int,
      num_layers: int,
      model_dims: int,
      per_core_batch_size: int = 32,
      backend: Literal["cpu", "gpu", "tpu"] = "cpu",
      quantiles: Sequence[float] | None = None,
      verbose: bool = True,
  ) -> None:
    """Initializes the TimesFM forecast API.

    Args:
      context_len: Largest context length the model allows for each decode call.
        This can technically be arbitrarily large, but in practice should be
        set to the context length the checkpoint was trained with.
      horizon_len: Forecast horizon.
      input_patch_len: Input patch len.
      output_patch_len: Output patch len. How many timepoints are taken from a
        single step of autoregressive decoding. Can be set as the training
        horizon of the checkpoint.
      num_layers: Number of transformer layers.
      model_dims: Model dimension.
      per_core_batch_size: Batch size on each core for data parallelism.
      backend: One of "cpu", "gpu" or "tpu".
      quantiles: list of output quantiles supported by the model. Defaults to
        `patched_decoder.DEFAULT_QUANTILES` when `None`.
      verbose: Whether to print logging messages.
    """
    self.per_core_batch_size = per_core_batch_size
    self.backend = backend
    self.num_devices = jax.local_device_count(self.backend)
    self.global_batch_size = self.per_core_batch_size * self.num_devices

    self.context_len = context_len
    self.horizon_len = horizon_len
    self.input_patch_len = input_patch_len
    self.output_patch_len = output_patch_len

    # All devices are laid out along the "data" axis: data parallelism only.
    self.mesh_shape = [1, self.num_devices, 1]
    self.mesh_name = ["replica", "data", "mdl"]
    if quantiles is None:
      quantiles = patched_decoder.DEFAULT_QUANTILES

    self.model_p = pax_fiddle.Config(
        patched_decoder.PatchedTimeSeriesDecoder,
        name="patched_decoder",
        horizon_len=self.output_patch_len,
        patch_len=input_patch_len,
        model_dims=model_dims,
        hidden_dims=model_dims,
        residual_block_tpl=pax_fiddle.Config(patched_decoder.ResidualBlock),
        quantiles=quantiles,
        use_freq=True,
        stacked_transformer_params_tpl=pax_fiddle.Config(
            transformers.StackedTransformer,
            num_heads=16,
            num_layers=num_layers,
            transformer_layer_params_tpl=pax_fiddle.Config(
                transformers.Transformer,
                ln_tpl=pax_fiddle.Config(
                    normalizations.RmsNorm,
                ),
            ),
        ),
    )

    self._key1, self._key2 = jax.random.split(jax.random.PRNGKey(42))
    self._model = None
    self._train_state = None
    self._pmapped_decode = None
    self._verbose = verbose
    self._eval_context = base_layer.JaxContext.HParams(do_eval=True)
    try:
      multiprocessing.set_start_method("spawn")
    except RuntimeError:
      # The start method can only be set once per process; keep the existing
      # one instead of failing.
      self._logging("Multiprocessing context has already been set.")

  def _get_sample_inputs(self):
    """Returns all-zero dummy inputs used to abstractly initialize the model.

    Returns:
      A dict with "input_ts" and "input_padding" float32 arrays of shape
      (per_core_batch_size, context_len + output_patch_len) and a "freq" int32
      array of shape (per_core_batch_size, 1).
    """
    return {
        "input_ts": jnp.zeros(
            (
                self.per_core_batch_size,
                self.context_len + self.output_patch_len,
            ),
            dtype=jnp.float32,
        ),
        "input_padding": jnp.zeros(
            (
                self.per_core_batch_size,
                self.context_len + self.output_patch_len,
            ),
            dtype=jnp.float32,
        ),
        "freq": jnp.zeros(
            (
                self.per_core_batch_size,
                1,
            ),
            dtype=jnp.int32,
        ),
    }

  def load_from_checkpoint(
      self,
      checkpoint_path: Optional[str] = None,
      repo_id: str = "google/timesfm-1.0-200m",
      checkpoint_type: checkpoints.CheckpointType = checkpoints.CheckpointType.FLAX,
      step: int | None = None,
  ) -> None:
    """Loads a checkpoint, initializes the model and compiles the decoder.

    Args:
      checkpoint_path: Optional path to the checkpoint directory. If `None`,
        the checkpoint is downloaded from the Hugging Face Hub.
      repo_id: Hugging Face Hub repo id. Only used to download the checkpoint
        when `checkpoint_path` is not given.
      checkpoint_type: type of PAX checkpoint
      step: step of the checkpoint to load. If `None`, load latest checkpoint.
    """
    # Download the checkpoint from Hugging Face Hub if not given
    if checkpoint_path is None:
      checkpoint_path = path.join(snapshot_download(repo_id), "checkpoints")

    #  Initialize the model weights.
    self._logging("Constructing model weights.")
    start_time = time.time()
    self._model = instantiate(self.model_p)
    var_weight_hparams = self._model.abstract_init_with_metadata(
        self._get_sample_inputs(), do_eval=True
    )
    train_state_partition_specs = tasks_lib.create_state_partition_specs(
        var_weight_hparams,
        mesh_shape=self.mesh_shape,
        mesh_axis_names=self.mesh_name,
        discard_opt_states=True,
        learners=None,
    )
    train_state_local_shapes = tasks_lib.create_state_unpadded_shapes(
        var_weight_hparams,
        discard_opt_states=True,
        learners=None,
    )
    self._logging(
        f"Constructed model weights in {time.time() - start_time:.2f} seconds."
    )
    # Load the model weights.
    self._logging(f"Restoring checkpoint from {checkpoint_path}.")
    start_time = time.time()
    self._train_state = checkpoints.restore_checkpoint(
        train_state_local_shapes,
        checkpoint_dir=checkpoint_path,
        checkpoint_type=checkpoint_type,
        state_specs=train_state_partition_specs,
        step=step,
    )
    self._logging(
        f"Restored checkpoint in {time.time() - start_time:.2f} seconds."
    )

    # Initialize and jit the decode fn.
    def _decode(inputs):
      assert self._model is not None
      assert self._train_state is not None
      return self._model.apply(
          self._train_state.mdl_vars,
          inputs,
          horizon_len=self.horizon_len,
          output_patch_len=self.output_patch_len,
          max_len=self.context_len,
          rngs={
              base_layer.PARAMS: self._key1,
              base_layer.RANDOM: self._key2,
          },
          method=self._model.decode,
      )

    self._logging("Jitting decoding.")
    start_time = time.time()
    self._pmapped_decode = jax.pmap(
        _decode,  # This is where the restored (trained) model is applied.
        axis_name="batch",
        devices=jax.devices(self.backend),
        backend=self.backend,
        axis_size=self.num_devices,
    )
    # Run the decoder once on all-zero inputs so that compilation happens here
    # rather than on the first real `forecast` call.
    with base_layer.JaxContext.new_context(hparams=self._eval_context):
      _ = self._pmapped_decode(
          NestedMap({
              "input_ts": jnp.zeros(
                  (
                      self.num_devices,
                      self.per_core_batch_size,
                      self.context_len,
                  ),
                  dtype=jnp.float32,
              ),
              "input_padding": jnp.zeros(
                  (
                      self.num_devices,
                      self.per_core_batch_size,
                      self.context_len + self.horizon_len,
                  ),
                  dtype=jnp.float32,
              ),
              "date_features": None,
              "freq": jnp.zeros(
                  (self.num_devices, self.per_core_batch_size, 1),
                  dtype=jnp.int32,
              ),
          })
      )
    self._logging(f"Jitted decoding in {time.time() - start_time:.2f} seconds.")

  def _preprocess(
      self, inputs: Sequence[np.ndarray], freq: Sequence[int]
  ) -> tuple[np.ndarray, np.ndarray, np.ndarray, int]:
    """Formats and pads raw inputs to feed into the model.

    This function both pads each time series to match the context length, and
    pads the inputs to meet the SPMD shape requirement.

    Args:
      inputs: A list of 1d arrays. Each array is the context time series of a
        single forecast task.
      freq: list of frequencies, one per entry of `inputs`.

    Returns:
    A tuple of:
    - the padded input time series to meet the model required context.
    - the padding indicator (1.0 marks a front-padded position, 0.0 a real
        context or horizon position).
    - the frequencies as an int32 array of shape (# examples, 1).
    - the number of padded examples for SPMD so that each core has the same
        number (a multiple of `batch_size`) of examples.
    """

    input_ts, input_padding, inp_freq = [], [], []

    # Number of dummy examples needed to round the batch up to the next
    # multiple of `global_batch_size`.
    pmap_pad = (
        (len(inputs) - 1) // self.global_batch_size + 1
    ) * self.global_batch_size - len(inputs)

    for i, ts in enumerate(inputs):
      input_len = ts.shape[0]
      padding = np.zeros(shape=(input_len + self.horizon_len,), dtype=float)
      if input_len < self.context_len:
        # Too short: front-pad the series with zeros and flag those positions.
        num_front_pad = self.context_len - input_len
        ts = np.concatenate(
            [np.zeros(shape=(num_front_pad,), dtype=float), ts], axis=0
        )
        padding = np.concatenate(
            [np.ones(shape=(num_front_pad,), dtype=float), padding], axis=0
        )
      elif input_len > self.context_len:
        # Too long: keep only the most recent `context_len` points.
        ts = ts[-self.context_len :]
        padding = padding[-(self.context_len + self.horizon_len) :]

      input_ts.append(ts)
      input_padding.append(padding)
      inp_freq.append(freq[i])

    # Padding the remainder batch by repeating the last example.
    for _ in range(pmap_pad):
      input_ts.append(input_ts[-1])
      input_padding.append(input_padding[-1])
      inp_freq.append(inp_freq[-1])

    return (
        np.stack(input_ts, axis=0),
        np.stack(input_padding, axis=0),
        np.array(inp_freq).astype(np.int32).reshape(-1, 1),
        pmap_pad,
    )

  def forecast(
      self,
      inputs: Sequence[Any],
      freq: Sequence[int] | None = None,
      window_size: int | None = None,
      forecast_context_len: int | None = None,
  ) -> tuple[JTensor, JTensor]:
    """Forecasts on a list of time series.

    Args:
      inputs: list of time series forecast contexts. Each context time series
        should be in a format convertible to an array by `np.array`.
      freq: frequency of each context time series. 0 for high frequency
        (default), 1 for medium, and 2 for low. Notice this is different from
        the `freq` required by `forecast_on_df`.
      window_size: window size of trend + residual decomposition. If None then
        we do not do decomposition.
      forecast_context_len: optional max context length. Defaults to
        `context_len`.

    Returns:
    A tuple of arrays:
    - the mean forecast of size (# inputs, # forecast horizon),
    - the full forecast (mean + quantiles) of size
        (# inputs,  # forecast horizon, 1 + # quantiles).

    Raises:
    ValueError: If the checkpoint is not properly loaded, or if `inputs` is
      empty.
    """
    if not self._train_state or not self._model:
      raise ValueError(
          "Checkpoint not loaded. Call `load_from_checkpoint` before"
          " `forecast`."
      )
    if forecast_context_len is None:
      forecast_context_len = self.context_len
    inputs = [np.array(ts)[-forecast_context_len:] for ts in inputs]
    if not inputs:
      # Fail early with a clear message instead of an opaque NumPy reduction
      # error further down.
      raise ValueError("`inputs` must contain at least one time series.")
    inp_min = np.min([np.min(ts) for ts in inputs])

    if window_size is not None:
      # Each series is replaced by two series (trend, residual) that are
      # forecast independently and summed back together below.
      new_inputs = []
      for ts in inputs:
        new_inputs.extend(moving_average(ts, window_size))
      inputs = new_inputs
      if freq is not None:
        # Keep `freq` aligned with the doubled inputs: the trend and the
        # residual of a series both use that series' frequency.
        freq = [f for f in freq for _ in range(2)]

    if freq is None:
      logging.info("No frequency provided via `freq`. Default to high (0).")
      freq = [0] * len(inputs)

    input_ts, input_padding, inp_freq, pmap_pad = self._preprocess(inputs, freq)
    with base_layer.JaxContext.new_context(hparams=self._eval_context):
      mean_outputs = []
      full_outputs = []
      assert input_ts.shape[0] % self.global_batch_size == 0
      for i in range(input_ts.shape[0] // self.global_batch_size):
        input_ts_in = jnp.array(
            input_ts[
                i * self.global_batch_size : (i + 1) * self.global_batch_size
            ]
        )
        input_padding_in = jnp.array(
            input_padding[
                i * self.global_batch_size : (i + 1) * self.global_batch_size
            ],
        )
        inp_freq_in = jnp.array(
            inp_freq[
                i * self.global_batch_size : (i + 1) * self.global_batch_size, :
            ],
            dtype=jnp.int32,
        )
        # Split the leading (global batch) axis into (devices, per-core batch)
        # as required by the pmapped decode fn.
        pmapped_inputs = NestedMap({
            "input_ts": es.jax_einshape(
                "(db)...->db...",
                input_ts_in,
                d=self.num_devices,
            ),
            "input_padding": es.jax_einshape(
                "(db)...->db...",
                input_padding_in,
                d=self.num_devices,
            ),
            "date_features": None,
            "freq": es.jax_einshape(
                "(db)...->db...",
                inp_freq_in,
                d=self.num_devices,
            ),
        })
        mean_output, full_output = self._pmapped_decode(pmapped_inputs)
        # Merge the (devices, per-core batch) axes back into one batch axis.
        mean_output = es.jax_einshape(
            "db...->(db)...", mean_output, d=self.num_devices
        )
        full_output = es.jax_einshape(
            "db...->(db)...", full_output, d=self.num_devices
        )
        mean_output = np.array(mean_output)
        full_output = np.array(full_output)
        mean_outputs.append(mean_output)
        full_outputs.append(full_output)

    mean_outputs = np.concatenate(mean_outputs, axis=0)
    full_outputs = np.concatenate(full_outputs, axis=0)

    # Drop the dummy examples that were appended to fill the last batch.
    if pmap_pad > 0:
      mean_outputs = mean_outputs[:-pmap_pad, ...]
      full_outputs = full_outputs[:-pmap_pad, ...]

    if window_size is not None:
      # Even rows are trend forecasts, odd rows are residual forecasts.
      mean_outputs = mean_outputs[0::2, ...] + mean_outputs[1::2, ...]
      full_outputs = full_outputs[0::2, ...] + full_outputs[1::2, ...]
    if inp_min >= 0:
      # All inputs are non-negative, so clip the forecasts at zero as well.
      mean_outputs = np.maximum(mean_outputs, 0.0)
      full_outputs = np.maximum(full_outputs, 0.0)
    return mean_outputs, full_outputs

  def forecast_on_df(
      self,
      inputs: pd.DataFrame,
      freq: str,
      forecast_context_len: int = 0,
      value_name: str = "values",
      model_name: str = "timesfm",
      window_size: int | None = None,
      num_jobs: int = 1,
  ) -> pd.DataFrame:
    """Forecasts on all time series of a dataframe.

    Args:
      inputs: A pd.DataFrame of all time series. The dataframe should have a
        `unique_id` column for identifying the time series, a `ds` column for
        timestamps and a value column for the time series values.
      freq: string valued `freq` of data. Notice this is different from the
        `freq` required by `forecast`. See `freq_map` for allowed values.
      forecast_context_len: If provided non-zero, we take the last
        `forecast_context_len` time-points from each series as the forecast
        context instead of the `context_len` set by the model.
      value_name: The name of the value column.
      model_name: name of the model to be written into future df.
      window_size: window size of trend + residual decomposition. If None then
        we do not do decomposition.
      num_jobs: number of parallel processes to use for dataframe processing.
        Use -1 for one process per CPU.

    Returns:
      Future forecasts dataframe.

    Raises:
      ValueError: If a required column is missing from `inputs`.
    """
    if not (
        "unique_id" in inputs.columns
        and "ds" in inputs.columns
        and value_name in inputs.columns
    ):
      raise ValueError(
          f"DataFrame must have unique_id, ds and {value_name} columns."
      )
    if not forecast_context_len:
      forecast_context_len = self.context_len
    logging.info("Preprocessing dataframe.")
    df_sorted = inputs.sort_values(by=["unique_id", "ds"])
    new_inputs = []
    uids = []
    if num_jobs == 1:
      self._logging("Processing dataframe with single process.")
      for key, group in df_sorted.groupby("unique_id"):
        inp, uid = process_group(
            key,
            group,
            value_name,
            forecast_context_len,
        )
        new_inputs.append(inp)
        uids.append(uid)
    else:
      if num_jobs == -1:
        num_jobs = multiprocessing.cpu_count()
      self._logging("Processing dataframe with multiple processes.")
      with multiprocessing.Pool(processes=num_jobs) as pool:
        results = pool.starmap(
            process_group,
            [
                (key, group, value_name, forecast_context_len)
                for key, group in df_sorted.groupby("unique_id")
            ],
        )
      new_inputs, uids = zip(*results)
    self._logging("Finished preprocessing dataframe.")
    freq_inps = [freq_map(freq)] * len(new_inputs)
    _, full_forecast = self.forecast(
        new_inputs, freq=freq_inps, window_size=window_size
    )
    self._logging("Finished forecasting.")
    fcst_df = make_future_dataframe(
        uids=uids,
        last_times=df_sorted.groupby("unique_id")["ds"].tail(1),
        h=self.horizon_len,
        freq=freq,
    )
    # Index 0 of the last axis is the mean forecast; quantiles follow it.
    fcst_df[model_name] = full_forecast[:, 0 : self.horizon_len, 0].reshape(
        -1, 1
    )

    if self._model.quantiles is not None:
      for i, q in enumerate(self._model.quantiles):
        q_col = f"{model_name}-q-{q}"
        fcst_df[q_col] = full_forecast[:, 0 : self.horizon_len, 1 + i].reshape(
            -1, 1
        )
        # When the median is available it replaces the mean as the point
        # forecast column.
        if q == 0.5:
          fcst_df[model_name] = fcst_df[q_col]
    logging.info("Finished creating output dataframe.")
    return fcst_df