patched_decoder.py 14.9 KB
Newer Older
suily's avatar
suily committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Pax ML model for patched time-series decoder.

The file implements Residual MLPs, Patched Decoder layers and PAX ML models.
"""

import dataclasses
from typing import Optional, Tuple

import einshape as es
from jax import lax
import jax.numpy as jnp
from praxis import base_layer
from praxis import layers
from praxis import pax_fiddle
from praxis import py_utils
from praxis import pytypes
from praxis.layers import activations
from praxis.layers import embedding_softmax
from praxis.layers import linears
from praxis.layers import normalizations
from praxis.layers import stochastics
from praxis.layers import transformers


# PAX shortcuts
NestedMap = py_utils.NestedMap
JTensor = pytypes.JTensor

LayerTpl = pax_fiddle.Config[base_layer.BaseLayer]
template_field = base_layer.template_field


PAD_VAL = 1123581321.0
DEFAULT_QUANTILES = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]

# NestedMap keys
_INPUT_TS = "input_ts"
_INPUT_PADDING = "input_padding"
_OUTPUT_TS = "output_ts"
_FREQ = "freq"
_OUTPUT_TOKENS = "output_tokens"
_STATS = "stats"


# Small numerical value.
_TOLERANCE = 1e-7


def _shift_padded_seq(mask: JTensor, seq: JTensor) -> JTensor:
  """Shifts rows of seq based on the first 0 in each row of the mask."""
  num = seq.shape[1]

  # Find the index of the first 0 in each row of the mask
  first_zero_idx = jnp.argmin(mask, axis=1)

  # Create a range array for indexing
  idx_range = jnp.arange(num)

  def shift_row(carry, x):
    seq_row, shift = x
    shifted_idx = (idx_range - shift) % num
    shifted_row = seq_row[shifted_idx]
    return carry, shifted_row

  # Use lax.scan to shift each row of seq based on the corresponding
  # first_zero_idx.
  _, shifted_seq = lax.scan(shift_row, None, (seq, first_zero_idx))

  return shifted_seq


class ResidualBlock(base_layer.BaseLayer):
  """Simple feedforward block with residual connection.

  Attributes:
    input_dims: input dimension.
    hidden_dims: hidden dimension.
    output_dims: output dimension.
    dropout_prob: dropout probability.
    layer_norm: whether to use layer norm or not.
    dropout_tpl: config for dropout.
    ln_tpl: config for layer norm.
    act_tpl: config for activation in hidden layer.
  """

  input_dims: int = 0
  hidden_dims: int = 0
  output_dims: int = 0
  dropout_prob: float = 0.0
  layer_norm: bool = False
  dropout_tpl: LayerTpl = template_field(stochastics.Dropout)
  ln_tpl: LayerTpl = template_field(normalizations.LayerNorm)
  act_tpl: LayerTpl = template_field(activations.Swish)

  def setup(self):
    lnorm_tpl = self.ln_tpl.clone()
    lnorm_tpl.dim = self.output_dims
    self.create_child("ln_layer", lnorm_tpl)

    dropout_tpl = self.dropout_tpl.clone()
    dropout_tpl.keep_prob = 1.0 - self.dropout_prob
    self.create_child("dropout", dropout_tpl)

    self.create_child(
        "hidden_layer",
        pax_fiddle.Config(
            linears.FeedForward,
            input_dims=self.input_dims,
            output_dims=self.hidden_dims,
            activation_tpl=self.act_tpl.clone(),
        ),
    )

    self.create_child(
        "output_layer",
        pax_fiddle.Config(
            linears.FeedForward,
            input_dims=self.hidden_dims,
            output_dims=self.output_dims,
            activation_tpl=pax_fiddle.Config(activations.Identity),
        ),
    )

    self.create_child(
        "residual_layer",
        pax_fiddle.Config(
            linears.FeedForward,
            input_dims=self.input_dims,
            output_dims=self.output_dims,
            activation_tpl=pax_fiddle.Config(activations.Identity),
        ),
    )

  def __call__(self, inputs: JTensor) -> JTensor:
    hidden = self.hidden_layer(inputs)
    output = self.output_layer(hidden)
    output = self.dropout(output)
    residual = self.residual_layer(inputs)
    if self.layer_norm:
      return self.ln_layer(output + residual)
    else:
      return output + residual


def _masked_mean_std(
    inputs: JTensor, padding: JTensor
) -> Tuple[JTensor, JTensor]:
  """Calculates mean and standard deviation of arr across axis 1.

  It should exclude values where pad is 1.

  Args:
    inputs: A JAX array of shape [b, n, p].
    padding: A JAX array of shape [b, n, p] with values 0 or 1.

  Returns:
    A tuple containing the mean and standard deviation of arr. We return the
    statistics of the first patch with more than three non-padded values.
  """
  # Selecting the first pad with more than 3 unpadded values.
  pad_sum = jnp.sum(1 - padding, axis=2)

  def _get_patch_index(arr: JTensor):
    indices = jnp.argmax(arr >= 3, axis=1)
    row_sum = (arr >= 3).sum(axis=1)
    return jnp.where(row_sum == 0, arr.shape[1] - 1, indices)

  patch_indices = _get_patch_index(pad_sum)
  bidxs = jnp.arange(inputs.shape[0])

  arr = inputs[bidxs, patch_indices, :]
  pad = padding[bidxs, patch_indices, :]

  # Create a mask where P is 0
  mask = 1 - pad

  # Calculate the number of valid elements
  num_valid_elements = jnp.sum(mask, axis=1)

  num_valid_elements = jnp.where(num_valid_elements == 0, 1, num_valid_elements)

  # Calculate the masked sum and squared sum of M
  masked_sum = jnp.sum(arr * mask, axis=1)
  masked_squared_sum = jnp.sum((arr * mask) ** 2, axis=1)

  # Calculate the masked mean and standard deviation
  masked_mean = masked_sum / num_valid_elements
  masked_var = masked_squared_sum / num_valid_elements - masked_mean**2
  masked_var = jnp.where(masked_var < 0.0, 0.0, masked_var)
  masked_std = jnp.sqrt(masked_var)

  return masked_mean, masked_std


def _create_quantiles() -> list[float]:
  """Returns the quantiles for forecasting."""
  return DEFAULT_QUANTILES


class PatchedTimeSeriesDecoder(base_layer.BaseLayer):
  """Patch decoder layer for time-series foundation model.

  Attributes:
    patch_len: length of input patches.
    horizon_len: length of output patches. Referred to as `output_patch_len`
      during inference.
    model_dims: model dimension of stacked transformer layer.
    hidden_dims: hidden dimensions in fully connected layers.
    quantiles: list of quantiles for non prob model.
    residual_block_tpl: config for residual block.
    stacked_transformer_params_tpl: config for stacked transformer.
    use_freq: whether to use frequency encoding.

  In all of what followed, except specified otherwise, B is batch size, T is
  sequence length of time-series. N is the number of input patches that can be
  obtained from T. P is the input patch length and H is the horizon length. Q is
  number of output logits. D is model dimension.
  """

  patch_len: int = 0
  horizon_len: int = 0
  model_dims: int = 0
  hidden_dims: int = 0
  quantiles: list[float] = dataclasses.field(default_factory=_create_quantiles)
  residual_block_tpl: LayerTpl = template_field(ResidualBlock)
  stacked_transformer_params_tpl: LayerTpl = template_field(
      transformers.StackedTransformer
  )
  use_freq: bool = True

  def setup(self) -> None:
    """Construct the model."""
    num_outputs = len(self.quantiles) + 1

    stl = self.stacked_transformer_params_tpl.clone()
    stl.model_dims = self.model_dims
    stl.hidden_dims = self.hidden_dims
    stl.mask_self_attention = True

    self.create_child("stacked_transformer_layer", stl)

    input_resl = self.residual_block_tpl.clone()
    ff_in_dims = 2 * self.patch_len
    input_resl.input_dims = ff_in_dims
    input_resl.hidden_dims = self.hidden_dims
    input_resl.output_dims = self.model_dims
    self.create_child(
        "input_ff_layer",
        input_resl,
    )

    horizon_resl = self.residual_block_tpl.clone()
    horizon_resl.input_dims = self.model_dims
    horizon_resl.hidden_dims = self.hidden_dims
    horizon_resl.output_dims = self.horizon_len * num_outputs
    self.create_child(
        "horizon_ff_layer",
        horizon_resl,
    )

    self.create_child(
        "position_emb",
        pax_fiddle.Config(
            layers.PositionalEmbedding, embedding_dims=self.model_dims
        ),
    )

    if self.use_freq:
      self.create_child(
          "freq_emb",
          pax_fiddle.Config(
              embedding_softmax.Embedding,
              num_classes=3,
              input_dims=self.model_dims,
          ),
      )

  def transform_decode_state(
      self, transform_fn: base_layer.DecodeStateTransformFn
  ) -> None:
    """Transforms all decode state variables based on transform_fn."""
    self.stacked_transformer_layer.transform_decode_state(transform_fn)

  def _forward_transform(
      self, inputs: JTensor, patched_pads: JTensor
  ) -> Tuple[JTensor, Tuple[JTensor, JTensor]]:
    """Input is of shape [B, N, P]."""
    mu, sigma = _masked_mean_std(inputs, patched_pads)
    sigma = jnp.where(sigma < _TOLERANCE, 1.0, sigma)
    # Normalize each patch.
    outputs = (inputs - mu[:, None, None]) / sigma[:, None, None]
    outputs = jnp.where(
        jnp.abs(inputs - PAD_VAL) < _TOLERANCE, PAD_VAL, outputs
    )
    return outputs, (mu, sigma)

  def _reverse_transform(
      self, outputs: JTensor, stats: Tuple[JTensor, JTensor]
  ) -> JTensor:
    """Output is of shape [B, N, P, Q]."""
    mu, sigma = stats
    return outputs * sigma[:, None, None, None] + mu[:, None, None, None]

  def _preprocess_input(
      self,
      input_ts: JTensor,
      input_padding: JTensor,
      pos_emb: Optional[JTensor] = None,
  ) -> Tuple[JTensor, JTensor, Optional[Tuple[JTensor, JTensor]], JTensor]:
    """Preprocess input for stacked transformer."""
    # Reshape into patches.
    patched_inputs = es.jax_einshape("b(np)->bnp", input_ts, p=self.patch_len)
    input_padding = jnp.where(
        jnp.abs(input_ts - PAD_VAL) < _TOLERANCE, 1, input_padding
    )
    patched_pads = es.jax_einshape(
        "b(np)->bnp", input_padding, p=self.patch_len
    )
    patched_inputs, stats = self._forward_transform(
        patched_inputs, patched_pads
    )
    # B x N x D
    patched_inputs = patched_inputs * (1.0 - patched_pads)
    concat_inputs = jnp.concatenate([patched_inputs, patched_pads], axis=-1)
    model_input = self.input_ff_layer(concat_inputs)
    # A patch should not be padded even if there is at least one zero.
    patched_padding = jnp.min(patched_pads, axis=-1)

    if pos_emb is None:
      position_emb = self.position_emb(seq_length=model_input.shape[1])
    else:
      position_emb = pos_emb
    if self.do_eval:
      if position_emb.shape[0] != model_input.shape[0]:
        position_emb = jnp.repeat(position_emb, model_input.shape[0], axis=0)
      position_emb = _shift_padded_seq(patched_padding, position_emb)
    model_input += position_emb

    return model_input, patched_padding, stats, patched_inputs

  def _postprocess_output(
      self,
      model_output: JTensor,
      num_outputs: int,
      stats: Tuple[JTensor, JTensor],
  ) -> JTensor:
    """Postprocess output of stacked transformer."""
    # B x N x (H.Q)
    output_ts = self.horizon_ff_layer(model_output)
    output_ts = es.jax_einshape(
        "bn(hq)->bnhq", output_ts, q=num_outputs, h=self.horizon_len
    )
    return self._reverse_transform(output_ts, stats)

  def __call__(self, inputs: NestedMap) -> NestedMap:
    """PatchTST call.

    Args:
      inputs: A NestedMap containing (1) input_ts: input sequence of shape [B,
        T] where T must be multiple of patch_length; (2) input_padding: that
        contains padding map.

    Returns:
      A nested map with two keys:
      (1) 'output_tokens' of shape [B, N, D].
      (2) 'output_ts' of shape [B, N, H, Q]
      (3) 'stats' a Tuple of statistics for renormalization.
    """
    input_ts, input_padding = inputs[_INPUT_TS], inputs[_INPUT_PADDING]
    num_outputs = len(self.quantiles) + 1
    model_input, patched_padding, stats, _ = self._preprocess_input(
        input_ts=input_ts,
        input_padding=input_padding,
    )
    if self.use_freq:
      freq = inputs[_FREQ].astype(jnp.int32)
      f_emb = self.freq_emb(freq)  # B x 1 x D
      f_emb = jnp.repeat(f_emb, model_input.shape[1], axis=1)
      model_input += f_emb
    model_output = self.stacked_transformer_layer(model_input, patched_padding)

    output_ts = self._postprocess_output(model_output, num_outputs, stats)
    return NestedMap(
        {_OUTPUT_TOKENS: model_output, _OUTPUT_TS: output_ts, _STATS: stats}
    )

  def decode(
      self,
      inputs: NestedMap,
      horizon_len: int,
      output_patch_len: Optional[int] = None,
      max_len: int = 512,
  ) -> tuple[JTensor, JTensor]:
    """Auto-regressive decoding without caching.

    Args:
      inputs: input time-series and paddings. Time-series shape B x C, padding
        shape shape B x (C + H) where H is the prediction length.
      horizon_len: prediction length.
      output_patch_len: output length to be fetched from one step of
        auto-regressive decoding.
      max_len: maximum training context length.

    Returns:
      Tuple of two forecasting results:
      - Point (mean) output predictions as a tensor with shape B x H.
      - Full predictions (mean and quantiles) as a tensor with shape
        B x H x (1 + # quantiles).
    """
    final_out = inputs[_INPUT_TS]
    inp_time_len = final_out.shape[1]
    paddings = inputs[_INPUT_PADDING]
    if self.use_freq:
      freq = inputs[_FREQ].astype(jnp.int32)
    else:
      freq = jnp.zeros([final_out.shape[0], 1], dtype=jnp.int32)
    full_outputs = []
    if paddings.shape[1] != final_out.shape[1] + horizon_len:
      raise ValueError(
          "Length of paddings must match length of input + horizon_len:"
          f" {paddings.shape[1]} != {final_out.shape[1]} + {horizon_len}"
      )
    if output_patch_len is None:
      output_patch_len = self.horizon_len
    num_decode_patches = (
        horizon_len + output_patch_len - 1
    ) // output_patch_len
    for _ in range(num_decode_patches):
      current_padding = paddings[:, 0 : final_out.shape[1]]
      input_ts = final_out[:, -max_len:]
      input_padding = current_padding[:, -max_len:]
      model_input = NestedMap(
          input_ts=input_ts,
          input_padding=input_padding,
          freq=freq,
      )
      fprop_outputs = self(model_input)[_OUTPUT_TS]
      # (full batch, last patch, output_patch_len, index of mean forecast = 0)
      new_ts = fprop_outputs[:, -1, :output_patch_len, 0]
      # (full batch, last patch, output_patch_len, all output indices)
      full_outputs.append(fprop_outputs[:, -1, :output_patch_len, :])
      final_out = jnp.concatenate([final_out, new_ts], axis=-1)

    return (
        final_out[:, inp_time_len : inp_time_len + horizon_len],
        jnp.concatenate(full_outputs, axis=1)[:, 0:horizon_len, :],
    )