"tests/t5/test_modeling_tf_t5.py" did not exist on "0b933584473c7e7a1d1e231ffa3e95b71c3e2139"
Unverified Commit df06fb1f authored by Kashif Rasul, committed by GitHub

Time series transformer: input projection and Std scaler (#21020)



* added loc and scale outputs from scalers

* fix typo

* fix tests

* fixed formatting

* initial StdScaler

* move scaling to optional str

* calculate std feature for scalers

* undid change as it does not help

* added StdScaler with weights

* added input projection layer and d_model hyperparam

* use linear proj

* add back layernorm_embedding

* add sin-cos pos embeddings

* updated scalers

* formatting

* fix type

* fixed test

* fix repeated_past_values calculation

* fix when keepdim=False

* fix default_scale

* backward compatibility of scaling config

* update integration test expected output

* fix style

* fix docs

* use the actual num_static_real_features in feature_dim calculation

* clarified docs

* Update src/transformers/models/time_series_transformer/modeling_time_series_transformer.py
Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

* Update src/transformers/models/time_series_transformer/modeling_time_series_transformer.py
Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

* Update src/transformers/models/time_series_transformer/modeling_time_series_transformer.py
Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

* prediction_length is not optional

* fix for reviewer

* Update src/transformers/models/time_series_transformer/configuration_time_series_transformer.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* get rid of unneeded new lines

* fix doc

* remove unneeded new lines

* fix style

* static_categorical_features and static_real_features are optional

* fix integration test

* Update src/transformers/models/time_series_transformer/modeling_time_series_transformer.py
Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

* fixing docs for multivariate setting

* documentation for generate

---------
Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent bb5a2f2f
......@@ -14,7 +14,7 @@
# limitations under the License.
""" Time Series Transformer model configuration"""
from typing import List, Optional
from typing import List, Optional, Union
from ...configuration_utils import PretrainedConfig
from ...utils import logging
......@@ -56,8 +56,9 @@ class TimeSeriesTransformerConfig(PretrainedConfig):
input_size (`int`, *optional*, defaults to 1):
The size of the target variable which by default is 1 for univariate targets. Would be > 1 in case of
multivariate targets.
scaling (`bool`, *optional*, defaults to `True`):
Whether to scale the input targets.
scaling (`string` or `bool`, *optional*, defaults to `"mean"`):
Whether to scale the input targets via the "mean" scaler, the "std" scaler, or no scaling if `None`. If
`True`, the scaler is set to "mean".
lags_sequence (`list[int]`, *optional*, defaults to `[1, 2, 3, 4, 5, 6, 7]`):
The lags of the input time series as covariates often dictated by the frequency. Default is `[1, 2, 3, 4,
5, 6, 7]`.
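For illustration, a minimal sketch of the new scaling option (hypothetical values; note that `prediction_length` is now a required argument, see below):

from transformers import TimeSeriesTransformerConfig

# "mean" is the default; "std" selects the new StdScaler; None disables scaling.
# Passing True is kept for backward compatibility and maps to the "mean" scaler.
config = TimeSeriesTransformerConfig(prediction_length=24, scaling="std")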
......@@ -77,6 +78,8 @@ class TimeSeriesTransformerConfig(PretrainedConfig):
The dimension of the embedding for each of the static categorical features. Should be a list of integers,
having the same length as `num_static_categorical_features`. Cannot be `None` if
`num_static_categorical_features` is > 0.
d_model (`int`, *optional*, defaults to 64):
Dimensionality of the transformer layers.
encoder_layers (`int`, *optional*, defaults to 2):
Number of encoder layers.
decoder_layers (`int`, *optional*, defaults to 2):
......@@ -132,13 +135,13 @@ class TimeSeriesTransformerConfig(PretrainedConfig):
def __init__(
self,
input_size: int = 1,
prediction_length: Optional[int] = None,
context_length: Optional[int] = None,
distribution_output: str = "student_t",
loss: str = "nll",
input_size: int = 1,
lags_sequence: List[int] = [1, 2, 3, 4, 5, 6, 7],
scaling: bool = True,
scaling: Optional[Union[str, bool]] = "mean",
num_dynamic_real_features: int = 0,
num_static_categorical_features: int = 0,
num_static_real_features: int = 0,
......@@ -153,6 +156,7 @@ class TimeSeriesTransformerConfig(PretrainedConfig):
decoder_layers: int = 2,
is_encoder_decoder: bool = True,
activation_function: str = "gelu",
d_model: int = 64,
dropout: float = 0.1,
encoder_layerdrop: float = 0.1,
decoder_layerdrop: float = 0.1,
......@@ -182,7 +186,7 @@ class TimeSeriesTransformerConfig(PretrainedConfig):
)
self.cardinality = cardinality
else:
self.cardinality = [1]
self.cardinality = [0]
if embedding_dimension and num_static_categorical_features > 0:
if len(embedding_dimension) != num_static_categorical_features:
raise ValueError(
......@@ -194,7 +198,8 @@ class TimeSeriesTransformerConfig(PretrainedConfig):
self.num_parallel_samples = num_parallel_samples
# Transformer architecture configuration
self.d_model = input_size * len(lags_sequence) + self._number_of_features
self.feature_size = input_size * len(lags_sequence) + self._number_of_features
self.d_model = d_model
self.encoder_attention_heads = encoder_attention_heads
self.decoder_attention_heads = decoder_attention_heads
self.encoder_ffn_dim = encoder_ffn_dim
......@@ -224,6 +229,6 @@ class TimeSeriesTransformerConfig(PretrainedConfig):
sum(self.embedding_dimension)
+ self.num_dynamic_real_features
+ self.num_time_features
+ max(1, self.num_static_real_features) # there is at least one dummy static real feature
+ self.input_size # the log(scale)
+ self.num_static_real_features
+ self.input_size * 2 # the log1p(abs(loc)) and log(scale) features
)
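To make the split between `feature_size` (the raw input width) and the new `d_model` hyperparameter concrete, here is a hedged back-of-the-envelope computation with assumed settings (univariate target, default lags, one time feature, no other covariates):

number_of_features = 0 + 0 + 1 + 0 + 1 * 2  # embeddings + dynamic real + time + static real + loc/scale
feature_size = 1 * len([1, 2, 3, 4, 5, 6, 7]) + number_of_features  # = 10, pre-projection width
d_model = 64  # width after the ValueEmbedding projection, now a free hyperparameter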
......@@ -19,6 +19,7 @@ import random
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
from torch import nn
from torch.distributions import (
......@@ -255,6 +256,39 @@ class FeatureEmbedder(nn.Module):
)
class StdScaler(nn.Module):
"""
Standardize features by calculating the mean and standard deviation along the given dimension `dim`, and then
normalize by subtracting the mean and dividing by the standard deviation.
Args:
dim (`int`):
Dimension along which to calculate the mean and standard deviation.
keepdim (`bool`, *optional*, defaults to `False`):
Controls whether to retain dimension `dim` (of length 1) in the scale tensor, or suppress it.
minimum_scale (`float`, *optional*, defaults to 1e-5):
Minimum scale that is used for elements that are constantly zero along dimension `dim` (it is added to the
variance before taking the square root).
"""
def __init__(self, dim: int, keepdim: bool = False, minimum_scale: float = 1e-5):
super().__init__()
if not dim > 0:
raise ValueError("Cannot compute scale along dim = 0 (batch dimension), please provide dim > 0")
self.dim = dim
self.keepdim = keepdim
self.minimum_scale = minimum_scale
@torch.no_grad()
def forward(self, data: torch.Tensor, weights: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
denominator = weights.sum(self.dim, keepdim=self.keepdim)
denominator = denominator.clamp_min(1.0)
loc = (data * weights).sum(self.dim, keepdim=self.keepdim) / denominator
variance = (((data - loc) * weights) ** 2).sum(self.dim, keepdim=self.keepdim) / denominator
scale = torch.sqrt(variance + self.minimum_scale)
return (data - loc) / scale, loc, scale
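A quick sanity check of the new scaler on toy tensors (values assumed; the class above is taken to be in scope):

import torch

scaler = StdScaler(dim=1, keepdim=True)
data = torch.tensor([[1.0, 2.0, 3.0, 4.0]]).unsqueeze(-1)  # (batch=1, time=4, input_size=1)
weights = torch.ones_like(data)                            # every point observed
scaled, loc, scale = scaler(data, weights)
# loc == 2.5 and scale == sqrt(1.25 + 1e-5) ~ 1.118, so `scaled` has zero mean
# and approximately unit standard deviation along the time dimension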
class MeanScaler(nn.Module):
"""
Computes a scaling factor as the weighted average absolute value along dimension `dim`, and scales the data
......@@ -265,48 +299,49 @@ class MeanScaler(nn.Module):
Dimension along which to compute the scale.
keepdim (`bool`, *optional*, defaults to `False`):
Controls whether to retain dimension `dim` (of length 1) in the scale tensor, or suppress it.
default_scale (`float`, *optional*, defaults to `None`):
Default scale that is used for elements that are constantly zero. If `None`, we use the scale of the batch.
minimum_scale (`float`, *optional*, defaults to 1e-10):
Default scale that is used for elements that are constantly zero along dimension `dim`.
Default minimum possible scale that is used for any item.
"""
def __init__(self, dim: int, keepdim: bool = False, minimum_scale: float = 1e-10):
def __init__(
self, dim: int = -1, keepdim: bool = True, default_scale: Optional[float] = None, minimum_scale: float = 1e-10
):
super().__init__()
if not dim > 0:
raise ValueError("Cannot compute scale along dim = 0 (batch dimension), please provide dim > 0")
self.dim = dim
self.keepdim = keepdim
self.register_buffer("minimum_scale", torch.tensor(minimum_scale))
def forward(self, data: torch.Tensor, weights: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
# these will have shape (N, C)
total_weight = weights.sum(dim=self.dim)
weighted_sum = (data.abs() * weights).sum(dim=self.dim)
# first compute a global scale per-dimension
total_observed = total_weight.sum(dim=0)
denominator = torch.max(total_observed, torch.ones_like(total_observed))
default_scale = weighted_sum.sum(dim=0) / denominator
# then compute a per-item, per-dimension scale
denominator = torch.max(total_weight, torch.ones_like(total_weight))
scale = weighted_sum / denominator
# use per-batch scale when no element is observed
# or when the sequence contains only zeros
scale = (
torch.max(
self.minimum_scale,
torch.where(
weighted_sum > torch.zeros_like(weighted_sum),
scale,
default_scale * torch.ones_like(total_weight),
),
)
.detach()
.unsqueeze(dim=self.dim)
)
self.minimum_scale = minimum_scale
self.default_scale = default_scale
@torch.no_grad()
def forward(self, data: torch.Tensor, observed_indicator: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
# shape: (N, [C], T=1)
ts_sum = (data * observed_indicator).abs().sum(self.dim, keepdim=True)
num_observed = observed_indicator.sum(self.dim, keepdim=True)
scale = ts_sum / torch.clamp(num_observed, min=1)
# If `default_scale` is provided, we use it, otherwise we use the scale
# of the batch.
if self.default_scale is None:
batch_sum = ts_sum.sum(dim=0)
batch_observations = torch.clamp(num_observed.sum(0), min=1)
default_scale = torch.squeeze(batch_sum / batch_observations)
else:
default_scale = self.default_scale * torch.ones_like(scale)
return data / scale, scale if self.keepdim else scale.squeeze(dim=self.dim)
# apply default scale where there are no observations
scale = torch.where(num_observed > 0, scale, default_scale)
# ensure the scale is at least `self.minimum_scale`
scale = torch.clamp(scale, min=self.minimum_scale)
scaled_data = data / scale
if not self.keepdim:
scale = scale.squeeze(dim=self.dim)
return scaled_data, torch.zeros_like(scale), scale
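A matching sketch for the reworked MeanScaler, showing that unobserved points (mask 0) are excluded and that its `loc` is always zero (toy values assumed):

import torch

scaler = MeanScaler(dim=1, keepdim=True)
data = torch.tensor([[2.0, 0.0, 4.0]]).unsqueeze(-1)      # (batch=1, time=3, input_size=1)
observed = torch.tensor([[1.0, 0.0, 1.0]]).unsqueeze(-1)  # middle value is missing
scaled, loc, scale = scaler(data, observed)
# scale == (|2| + |4|) / 2 == 3.0; loc is all zeros, so only a rescale
# (no shift) is applied: scaled == data / 3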
class NOPScaler(nn.Module):
......@@ -325,9 +360,12 @@ class NOPScaler(nn.Module):
self.dim = dim
self.keepdim = keepdim
def forward(self, data: torch.Tensor, observed_indicator: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
scale = torch.ones_like(data).mean(dim=self.dim, keepdim=self.keepdim)
return data, scale
def forward(
self, data: torch.Tensor, observed_indicator: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
scale = torch.ones_like(data, requires_grad=False).mean(dim=self.dim, keepdim=self.keepdim)
loc = torch.zeros_like(data, requires_grad=False).mean(dim=self.dim, keepdim=self.keepdim)
return data, loc, scale
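The NOPScaler change is purely about the interface: it now returns the same `(data, loc, scale)` triple as the other scalers, with a zero shift and a unit scale, e.g. (a sketch):

import torch

scaler = NOPScaler(dim=1, keepdim=True)
data = torch.randn(2, 5, 1)
out, loc, scale = scaler(data, torch.ones_like(data))
# out is data untouched; loc is all zeros and scale all ones, so downstream
# code can treat every scaler uniformly through the new three-tuple contract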
def weighted_average(input_tensor: torch.Tensor, weights: Optional[torch.Tensor] = None, dim=None) -> torch.Tensor:
......@@ -394,6 +432,50 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
# Copied from transformers.models.marian.modeling_marian.MarianSinusoidalPositionalEmbedding with Marian->TimeSeries
class TimeSeriesSinusoidalPositionalEmbedding(nn.Embedding):
"""This module produces sinusoidal positional embeddings of any length."""
def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None) -> None:
super().__init__(num_positions, embedding_dim)
self.weight = self._init_weight(self.weight)
@staticmethod
def _init_weight(out: nn.Parameter) -> nn.Parameter:
"""
Identical to the XLM create_sinusoidal_embeddings, except features are not interleaved: the cos features are
in the second half of the vector (`[dim // 2:]`).
"""
n_pos, dim = out.shape
position_enc = np.array(
[[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)]
)
out.requires_grad = False # set early to avoid an error in pytorch-1.8+
sentinel = dim // 2 if dim % 2 == 0 else (dim // 2) + 1
out[:, 0:sentinel] = torch.FloatTensor(np.sin(position_enc[:, 0::2]))
out[:, sentinel:] = torch.FloatTensor(np.cos(position_enc[:, 1::2]))
out.detach_()
return out
@torch.no_grad()
def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0) -> torch.Tensor:
"""`input_ids_shape` is expected to be [bsz x seqlen]."""
bsz, seq_len = input_ids_shape[:2]
positions = torch.arange(
past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device
)
return super().forward(positions)
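For reference, a hedged usage sketch of the sinusoidal positional embedding (toy sizes assumed):

import torch

emb = TimeSeriesSinusoidalPositionalEmbedding(num_positions=48, embedding_dim=64)
pos = emb(torch.Size([32, 24]))  # (bsz=32, seq_len=24) -> embeddings of shape (24, 64)
# sin features fill columns [:32] and cos features columns [32:]; the weights are
# frozen, which is why _init_weights below skips this module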
class ValueEmbedding(nn.Module):
def __init__(self, feature_size, d_model):
super().__init__()
self.value_projection = nn.Linear(in_features=feature_size, out_features=d_model, bias=False)
def forward(self, x):
return self.value_projection(x)
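ValueEmbedding is the new linear input projection that decouples `feature_size` from `d_model`; a sketch with assumed sizes:

import torch

proj = ValueEmbedding(feature_size=10, d_model=64)
x = torch.randn(32, 24, 10)  # (batch, time, feature_size): lagged values + features
h = proj(x)                  # -> (32, 24, 64) via a bias-free linear layer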
@dataclass
class Seq2SeqTimeSeriesModelOutput(ModelOutput):
"""
......@@ -443,9 +525,12 @@ class Seq2SeqTimeSeriesModelOutput(ModelOutput):
Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the
self-attention heads.
scale: (`torch.FloatTensor` of shape `(batch_size,)`, *optional*):
loc (`torch.FloatTensor` of shape `(batch_size,)`, *optional*):
Shift values of each time series' context window, which are used to give the model inputs of the same
magnitude and then used to shift back to the original magnitude.
scale (`torch.FloatTensor` of shape `(batch_size,)`, *optional*):
Scaling values of each time series' context window, which are used to give the model inputs of the same
magnitude and then used to rescale to the original scale.
magnitude and then used to rescale back to the original magnitude.
static_features (`torch.FloatTensor` of shape `(batch_size, feature size)`, *optional*):
Static features of each time series in a batch which are copied to the covariates at inference time.
"""
......@@ -458,6 +543,7 @@ class Seq2SeqTimeSeriesModelOutput(ModelOutput):
encoder_last_hidden_state: Optional[torch.FloatTensor] = None
encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
loc: Optional[torch.FloatTensor] = None
scale: Optional[torch.FloatTensor] = None
static_features: Optional[torch.FloatTensor] = None
......@@ -510,9 +596,12 @@ class Seq2SeqTimeSeriesPredictionOutput(ModelOutput):
Attentions weights of the encoder, after the attention softmax, used to compute the weighted average in the
self-attention heads.
scale: (`torch.FloatTensor` of shape `(batch_size,)`, *optional*):
loc (`torch.FloatTensor` of shape `(batch_size,)`, *optional*):
Shift values of each time series' context window, which are used to give the model inputs of the same
magnitude and then used to shift back to the original magnitude.
scale (`torch.FloatTensor` of shape `(batch_size,)`, *optional*):
Scaling values of each time series' context window, which are used to give the model inputs of the same
magnitude and then used to rescale to the original scale.
magnitude and then used to rescale back to the original magnitude.
static_features (`torch.FloatTensor` of shape `(batch_size, feature size)`, *optional*):
Static features of each time series in a batch which are copied to the covariates at inference time.
"""
......@@ -526,6 +615,7 @@ class Seq2SeqTimeSeriesPredictionOutput(ModelOutput):
encoder_last_hidden_state: Optional[torch.FloatTensor] = None
encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
loc: Optional[torch.FloatTensor] = None
scale: Optional[torch.FloatTensor] = None
static_features: Optional[torch.FloatTensor] = None
......@@ -889,6 +979,8 @@ class TimeSeriesTransformerPreTrainedModel(PreTrainedModel):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, TimeSeriesSinusoidalPositionalEmbedding):
pass
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
......@@ -917,30 +1009,41 @@ TIME_SERIES_TRANSFORMER_START_DOCSTRING = r"""
TIME_SERIES_TRANSFORMER_INPUTS_DOCSTRING = r"""
Args:
past_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
Past values of the time series, which serve as context in order to predict the future. These values may
contain lags, i.e. additional values from the past which are added in order to serve as "extra context".
The `past_values` is what the Transformer encoder gets as input (with optional additional features, such as
`static_categorical_features`, `static_real_features`, `past_time_features`).
past_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, input_size)`):
Past values of the time series, which serve as context in order to predict the future. The sequence size of
this tensor must be larger than the `context_length` of the model, since the model will use the larger size
to construct lag features, i.e. additional values from the past which are added in order to serve as "extra
context".
The `sequence_length` here is equal to `config.context_length` + `max(config.lags_sequence)`, which if no
`lags_sequence` is configured, is equal to `config.context_length` + 7 (as by default, the largest
look-back index in `config.lags_sequence` is 7). The property `_past_length` returns the actual length of
the past.
The sequence length here is equal to `context_length` + `max(config.lags_sequence)`.
The `past_values` is what the Transformer encoder gets as input (with optional additional features, such as
`static_categorical_features`, `static_real_features`, `past_time_features` and lags).
Missing values need to be replaced with zeros.
Optionally, missing values need to be replaced with zeros and indicated via the `past_observed_mask`.
past_time_features (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_features)`, *optional*):
Optional time features, which the model internally will add to `past_values`. These could be things like
For multivariate time series, the `input_size` > 1 dimension is required and corresponds to the number of
variates in the time series per time step.
past_time_features (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_features)`):
Required time features, which the model internally will add to `past_values`. These could be things like
"month of year", "day of the month", etc. encoded as vectors (for instance as Fourier features). These
could also be so-called "age" features, which basically help the model know "at which point in life" a
time-series is. Age features have small values for distant past time steps and increase monotonically the
more we approach the current time step.
more we approach the current time step. Holiday features are also a good example of time features.
These features serve as the "positional encodings" of the inputs. So contrary to a model like BERT, where
the position encodings are learned from scratch internally as parameters of the model, the Time Series
Transformer requires you to provide additional time features.
Transformer requires you to provide additional time features. The Time Series Transformer only learns
additional embeddings for `static_categorical_features`.
The Time Series Transformer only learns additional embeddings for `static_categorical_features`.
Additional dynamic real covariates can be concatenated to this tensor, with the caveat that these features
must be known at prediction time.
past_observed_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length)`, *optional*):
The `num_features` here is equal to `config.num_time_features` + `config.num_dynamic_real_features`.
past_observed_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, input_size)`, *optional*):
Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected in
`[0, 1]`:
......@@ -954,35 +1057,50 @@ TIME_SERIES_TRANSFORMER_INPUTS_DOCSTRING = r"""
Static categorical features are features which have the same value for all time steps (static over time).
A typical example of a static categorical feature is a time series ID.
static_real_features (`torch.FloatTensor` of shape `(batch_size, number of static real features)`, *optional*):
Optional static real features which the model will add to the values of the time series.
Static real features are features which have the same value for all time steps (static over time).
A typical example of a static real feature is promotion information.
future_values (`torch.FloatTensor` of shape `(batch_size, prediction_length)`):
future_values (`torch.FloatTensor` of shape `(batch_size, prediction_length)` or `(batch_size, prediction_length, input_size)`, *optional*):
Future values of the time series, that serve as labels for the model. The `future_values` is what the
Transformer needs to learn to output, given the `past_values`.
Transformer needs during training to learn to output, given the `past_values`.
The sequence length here is equal to `prediction_length`.
See the demo notebook and code snippets for details.
Missing values need to be replaced with zeros.
Optionally, during training any missing values need to be replaced with zeros and indicated via the
`future_observed_mask`.
future_time_features (`torch.FloatTensor` of shape `(batch_size, prediction_length, num_features)`, *optional*):
Optional time features, which the model internally will add to `future_values`. These could be things like
"month of year", "day of the month", etc. encoded as vectors (for instance as Fourier features). These
could also be so-called "age" features, which basically help the model know "at which point in life" a
time-series is. Age features have small values for distant past time steps and increase monotonically the
more we approach the current time step.
For multivariate time series, the `input_size` > 1 dimension is required and corresponds to the number of
variates in the time series per time step.
future_time_features (`torch.FloatTensor` of shape `(batch_size, prediction_length, num_features)`):
Required time features for the prediction window, which the model internally will add to `future_values`.
These could be things like "month of year", "day of the month", etc. encoded as vectors (for instance as
Fourier features). These could also be so-called "age" features, which basically help the model know "at
which point in life" a time-series is. Age features have small values for distant past time steps and
increase monotonically the more we approach the current time step. Holiday features are also a good example
of time features.
These features serve as the "positional encodings" of the inputs. So contrary to a model like BERT, where
the position encodings are learned from scratch internally as parameters of the model, the Time Series
Transformer requires you to provide additional features.
Transformer requires you to provide additional time features. The Time Series Transformer only learns
additional embeddings for `static_categorical_features`.
Additional dynamic real covariates can be concatenated to this tensor, with the caveat that these features
must be known at prediction time.
The `num_features` here is equal to `config.num_time_features` + `config.num_dynamic_real_features`.
future_observed_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, input_size)`, *optional*):
Boolean mask to indicate which `future_values` were observed and which were missing. Mask values selected
in `[0, 1]`:
The Time Series Transformer only learns additional embeddings for `static_categorical_features`.
- 1 for values that are **observed**,
- 0 for values that are **missing** (i.e. NaNs that were replaced by zeros).
This mask is used to filter out missing values for the final loss calculation.
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on certain token indices. Mask values selected in `[0, 1]`:
......@@ -990,11 +1108,9 @@ TIME_SERIES_TRANSFORMER_INPUTS_DOCSTRING = r"""
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
Mask to avoid performing attention on certain token indices. By default, a causal mask will be used, to
make sure the model can only look at previous inputs in order to predict the future.
head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
Mask to nullify selected heads of the attention modules in the encoder. Mask values selected in `[0, 1]`:
......@@ -1032,7 +1148,6 @@ TIME_SERIES_TRANSFORMER_INPUTS_DOCSTRING = r"""
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
model's internal embedding lookup matrix.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
`past_key_values`).
......@@ -1062,10 +1177,12 @@ class TimeSeriesTransformerEncoder(TimeSeriesTransformerPreTrainedModel):
self.dropout = config.dropout
self.layerdrop = config.encoder_layerdrop
embed_dim = config.d_model
self.value_embedding = ValueEmbedding(feature_size=config.feature_size, d_model=config.d_model)
self.embed_positions = TimeSeriesSinusoidalPositionalEmbedding(
config.context_length + config.prediction_length, config.d_model
)
self.layers = nn.ModuleList([TimeSeriesTransformerEncoderLayer(config) for _ in range(config.encoder_layers)])
self.layernorm_embedding = nn.LayerNorm(embed_dim)
self.layernorm_embedding = nn.LayerNorm(config.d_model)
self.gradient_checkpointing = False
# Initialize weights and apply final processing
......@@ -1114,8 +1231,10 @@ class TimeSeriesTransformerEncoder(TimeSeriesTransformerPreTrainedModel):
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
hidden_states = inputs_embeds
hidden_states = self.layernorm_embedding(hidden_states)
hidden_states = self.value_embedding(inputs_embeds)
embed_pos = self.embed_positions(inputs_embeds.size())
hidden_states = self.layernorm_embedding(hidden_states + embed_pos)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
# expand attention_mask
......@@ -1193,6 +1312,10 @@ class TimeSeriesTransformerDecoder(TimeSeriesTransformerPreTrainedModel):
self.dropout = config.dropout
self.layerdrop = config.decoder_layerdrop
self.value_embedding = ValueEmbedding(feature_size=config.feature_size, d_model=config.d_model)
self.embed_positions = TimeSeriesSinusoidalPositionalEmbedding(
config.context_length + config.prediction_length, config.d_model
)
self.layers = nn.ModuleList([TimeSeriesTransformerDecoderLayer(config) for _ in range(config.decoder_layers)])
self.layernorm_embedding = nn.LayerNorm(config.d_model)
......@@ -1278,20 +1401,16 @@ class TimeSeriesTransformerDecoder(TimeSeriesTransformerPreTrainedModel):
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
than the model's internal embedding lookup matrix.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
for more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
......@@ -1316,9 +1435,9 @@ class TimeSeriesTransformerDecoder(TimeSeriesTransformerPreTrainedModel):
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
encoder_attention_mask = _expand_mask(encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
hidden_states = inputs_embeds
hidden_states = self.layernorm_embedding(hidden_states)
hidden_states = self.value_embedding(inputs_embeds)
embed_pos = self.embed_positions(inputs_embeds.size(), past_key_values_length=self.config.context_length)
hidden_states = self.layernorm_embedding(hidden_states + embed_pos)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
# decoder layers
......@@ -1423,11 +1542,14 @@ class TimeSeriesTransformerModel(TimeSeriesTransformerPreTrainedModel):
def __init__(self, config: TimeSeriesTransformerConfig):
super().__init__(config)
if config.scaling:
if config.scaling == "mean" or config.scaling:
self.scaler = MeanScaler(dim=1, keepdim=True)
elif config.scaling == "std":
self.scaler = StdScaler(dim=1, keepdim=True)
else:
self.scaler = NOPScaler(dim=1, keepdim=True)
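In other words, the backward-compatible config maps onto the three scalers as follows; this standalone restatement assumes the `is True` comparison above, without which `"std"` would be unreachable:

def pick_scaler(scaling):
    # "mean" and the legacy boolean True select the mean scaler,
    # "std" the new std scaler, and None/False the no-op scaler
    if scaling == "mean" or scaling is True:
        return MeanScaler(dim=1, keepdim=True)
    elif scaling == "std":
        return StdScaler(dim=1, keepdim=True)
    return NOPScaler(dim=1, keepdim=True)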
if config.num_static_categorical_features > 0:
self.embedder = FeatureEmbedder(
cardinalities=config.cardinality,
embedding_dims=config.embedding_dimension,
......@@ -1483,8 +1605,8 @@ class TimeSeriesTransformerModel(TimeSeriesTransformerPreTrainedModel):
self,
past_values: torch.Tensor,
past_time_features: torch.Tensor,
static_categorical_features: torch.Tensor,
static_real_features: torch.Tensor,
static_categorical_features: Optional[torch.Tensor] = None,
static_real_features: Optional[torch.Tensor] = None,
past_observed_mask: Optional[torch.Tensor] = None,
future_values: Optional[torch.Tensor] = None,
future_time_features: Optional[torch.Tensor] = None,
......@@ -1508,12 +1630,12 @@ class TimeSeriesTransformerModel(TimeSeriesTransformerPreTrainedModel):
context = past_values[:, -self.config.context_length :]
observed_context = past_observed_mask[:, -self.config.context_length :]
_, scale = self.scaler(context, observed_context)
_, loc, scale = self.scaler(context, observed_context)
inputs = (
torch.cat((past_values, future_values), dim=1) / scale
(torch.cat((past_values, future_values), dim=1) - loc) / scale
if future_values is not None
else past_values / scale
else (past_values - loc) / scale
)
inputs_length = (
......@@ -1533,34 +1655,29 @@ class TimeSeriesTransformerModel(TimeSeriesTransformerPreTrainedModel):
else self.config.context_length
)
# embeddings
embedded_cat = self.embedder(static_categorical_features)
# static features
log_abs_loc = loc.abs().log1p() if self.config.input_size == 1 else loc.squeeze(1).abs().log1p()
log_scale = scale.log() if self.config.input_size == 1 else scale.squeeze(1).log()
static_feat = torch.cat((embedded_cat, static_real_features, log_scale), dim=1)
static_feat = torch.cat((log_abs_loc, log_scale), dim=1)
if static_real_features is not None:
static_feat = torch.cat((static_real_features, static_feat), dim=1)
if static_categorical_features is not None:
embedded_cat = self.embedder(static_categorical_features)
static_feat = torch.cat((embedded_cat, static_feat), dim=1)
expanded_static_feat = static_feat.unsqueeze(1).expand(-1, time_feat.shape[1], -1)
# all features
features = torch.cat((expanded_static_feat, time_feat), dim=-1)
# lagged features
lagged_sequence = self.get_lagged_subsequences(sequence=inputs, subsequences_length=subsequences_length)
lags_shape = lagged_sequence.shape
reshaped_lagged_sequence = lagged_sequence.reshape(lags_shape[0], lags_shape[1], -1)
transformer_inputs = torch.cat((reshaped_lagged_sequence, features), dim=-1)
return transformer_inputs, scale, static_feat
def enc_dec_outputs(self, transformer_inputs):
enc_input = transformer_inputs[:, : self.config.context_length, ...]
dec_input = transformer_inputs[:, self.config.context_length :, ...]
encoder_outputs = self.encoder(inputs_embeds=enc_input)
decoder_outputs = self.decoder(
inputs_embeds=dec_input, encoder_hidden_states=encoder_outputs.last_hidden_state
)
return encoder_outputs, decoder_outputs
return transformer_inputs, loc, scale, static_feat
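A small numeric sketch of the static features derived from the scaler outputs (univariate case, values assumed):

import torch

loc = torch.tensor([[-10.0], [0.0]])  # (batch, 1) shift from the scaler
scale = torch.tensor([[2.0], [1.0]])  # (batch, 1) scale from the scaler
log_abs_loc = loc.abs().log1p()       # log1p keeps loc == 0 finite
log_scale = scale.log()
static_feat = torch.cat((log_abs_loc, log_scale), dim=1)  # (batch, 2)
# static real features and categorical embeddings, when provided, are
# concatenated in front of these two columns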
def get_encoder(self):
return self.encoder
......@@ -1575,8 +1692,8 @@ class TimeSeriesTransformerModel(TimeSeriesTransformerPreTrainedModel):
past_values: torch.Tensor,
past_time_features: torch.Tensor,
past_observed_mask: torch.Tensor,
static_categorical_features: torch.Tensor,
static_real_features: torch.Tensor,
static_categorical_features: Optional[torch.Tensor] = None,
static_real_features: Optional[torch.Tensor] = None,
future_values: Optional[torch.Tensor] = None,
future_time_features: Optional[torch.Tensor] = None,
decoder_attention_mask: Optional[torch.LongTensor] = None,
......@@ -1628,7 +1745,7 @@ class TimeSeriesTransformerModel(TimeSeriesTransformerPreTrainedModel):
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
transformer_inputs, scale, static_feat = self.create_network_inputs(
transformer_inputs, loc, scale, static_feat = self.create_network_inputs(
past_values=past_values,
past_time_features=past_time_features,
past_observed_mask=past_observed_mask,
......@@ -1670,7 +1787,7 @@ class TimeSeriesTransformerModel(TimeSeriesTransformerPreTrainedModel):
)
if not return_dict:
return decoder_outputs + encoder_outputs + (scale, static_feat)
return decoder_outputs + encoder_outputs + (loc, scale, static_feat)
return Seq2SeqTimeSeriesModelOutput(
last_hidden_state=decoder_outputs.last_hidden_state,
......@@ -1681,6 +1798,7 @@ class TimeSeriesTransformerModel(TimeSeriesTransformerPreTrainedModel):
encoder_last_hidden_state=encoder_outputs.last_hidden_state,
encoder_hidden_states=encoder_outputs.hidden_states,
encoder_attentions=encoder_outputs.attentions,
loc=loc,
scale=scale,
static_features=static_feat,
)
......@@ -1724,11 +1842,11 @@ class TimeSeriesTransformerForPrediction(TimeSeriesTransformerPreTrainedModel):
return self.model.get_decoder()
@torch.jit.ignore
def output_distribution(self, params, scale=None, trailing_n=None) -> torch.distributions.Distribution:
def output_distribution(self, params, loc=None, scale=None, trailing_n=None) -> torch.distributions.Distribution:
sliced_params = params
if trailing_n is not None:
sliced_params = [p[:, -trailing_n:] for p in params]
return self.distribution_output.distribution(sliced_params, scale=scale)
return self.distribution_output.distribution(sliced_params, loc=loc, scale=scale)
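Conceptually, passing `loc` and `scale` to the head re-anchors the predictive distribution in the original data space. A hedged sketch of the equivalent torch.distributions construction (the actual head lives in `self.distribution_output`; the numbers are made up):

import torch
from torch.distributions import AffineTransform, StudentT, TransformedDistribution

base = StudentT(df=torch.tensor(3.0), loc=torch.tensor(0.0), scale=torch.tensor(1.0))
distr = TransformedDistribution(
    base, [AffineTransform(loc=torch.tensor(2.5), scale=torch.tensor(1.1))]
)
sample = distr.sample()  # samples land back on the unscaled magnitude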
@add_start_docstrings_to_model_forward(TIME_SERIES_TRANSFORMER_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=Seq2SeqTimeSeriesModelOutput, config_class=_CONFIG_FOR_DOC)
......@@ -1737,8 +1855,8 @@ class TimeSeriesTransformerForPrediction(TimeSeriesTransformerPreTrainedModel):
past_values: torch.Tensor,
past_time_features: torch.Tensor,
past_observed_mask: torch.Tensor,
static_categorical_features: torch.Tensor,
static_real_features: torch.Tensor,
static_categorical_features: Optional[torch.Tensor] = None,
static_real_features: Optional[torch.Tensor] = None,
future_values: Optional[torch.Tensor] = None,
future_time_features: Optional[torch.Tensor] = None,
future_observed_mask: Optional[torch.Tensor] = None,
......@@ -1756,15 +1874,6 @@ class TimeSeriesTransformerForPrediction(TimeSeriesTransformerPreTrainedModel):
r"""
Returns:
future_observed_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length)`, *optional*):
Boolean mask to indicate which `future_values` were observed and which were missing. Mask values selected
in `[0, 1]`:
- 1 for values that are **observed**,
- 0 for values that are **missing** (i.e. NaNs that were replaced by zeros).
This mask is used to filter out missing values for the final loss calculation.
Examples:
```python
......@@ -1839,7 +1948,8 @@ class TimeSeriesTransformerForPrediction(TimeSeriesTransformerPreTrainedModel):
params = None
if future_values is not None:
params = self.output_params(outputs[0]) # outputs.last_hidden_state
distribution = self.output_distribution(params, outputs[-2]) # outputs.scale
# loc is 3rd last and scale is 2nd last output
distribution = self.output_distribution(params, loc=outputs[-3], scale=outputs[-2])
loss = self.loss(distribution, future_values)
......@@ -1867,6 +1977,7 @@ class TimeSeriesTransformerForPrediction(TimeSeriesTransformerPreTrainedModel):
encoder_last_hidden_state=outputs.encoder_last_hidden_state,
encoder_hidden_states=outputs.encoder_hidden_states,
encoder_attentions=outputs.encoder_attentions,
loc=outputs.loc,
scale=outputs.scale,
static_features=outputs.static_features,
)
......@@ -1874,15 +1985,102 @@ class TimeSeriesTransformerForPrediction(TimeSeriesTransformerPreTrainedModel):
@torch.no_grad()
def generate(
self,
static_categorical_features: torch.Tensor,
static_real_features: torch.Tensor,
past_time_features: torch.Tensor,
past_values: torch.Tensor,
past_observed_mask: torch.Tensor,
future_time_features: Optional[torch.Tensor],
past_time_features: torch.Tensor,
future_time_features: torch.Tensor,
past_observed_mask: Optional[torch.Tensor] = None,
static_categorical_features: Optional[torch.Tensor] = None,
static_real_features: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
) -> torch.Tensor:
) -> SampleTimeSeriesPredictionOutput:
r"""
Greedily generate sequences of sample predictions from a model with a probability distribution head.
Parameters:
past_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, input_size)`):
Past values of the time series, which serve as context in order to predict the future. The sequence size
of this tensor must be larger than the `context_length` of the model, since the model will use the
larger size to construct lag features, i.e. additional values from the past which are added in order to
serve as "extra context".
The `sequence_length` here is equal to `config.context_length` + `max(config.lags_sequence)`, which if
no `lags_sequence` is configured, is equal to `config.context_length` + 7 (as by default, the largest
look-back index in `config.lags_sequence` is 7). The property `_past_length` returns the actual length
of the past.
The `past_values` is what the Transformer encoder gets as input (with optional additional features,
such as `static_categorical_features`, `static_real_features`, `past_time_features` and lags).
Optionally, missing values need to be replaced with zeros and indicated via the `past_observed_mask`.
For multivariate time series, the `input_size` > 1 dimension is required and corresponds to the number
of variates in the time series per time step.
past_time_features (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_features)`):
Required time features, which the model internally will add to `past_values`. These could be things
like "month of year", "day of the month", etc. encoded as vectors (for instance as Fourier features).
These could also be so-called "age" features, which basically help the model know "at which point in
life" a time-series is. Age features have small values for distant past time steps and increase
monotonically the more we approach the current time step. Holiday features are also a good example of
time features.
These features serve as the "positional encodings" of the inputs. So contrary to a model like BERT,
where the position encodings are learned from scratch internally as parameters of the model, the Time
Series Transformer requires you to provide additional time features. The Time Series Transformer only
learns additional embeddings for `static_categorical_features`.
Additional dynamic real covariates can be concatenated to this tensor, with the caveat that these
features must be known at prediction time.
The `num_features` here is equal to `config.num_time_features` + `config.num_dynamic_real_features`.
future_time_features (`torch.FloatTensor` of shape `(batch_size, prediction_length, num_features)`):
Required time features for the prediction window, which the model internally will add to sampled
predictions. These could be things like "month of year", "day of the month", etc. encoded as vectors
(for instance as Fourier features). These could also be so-called "age" features, which basically help
the model know "at which point in life" a time-series is. Age features have small values for distant
past time steps and increase monotonically the more we approach the current time step. Holiday features
are also a good example of time features.
These features serve as the "positional encodings" of the inputs. So contrary to a model like BERT,
where the position encodings are learned from scratch internally as parameters of the model, the Time
Series Transformer requires you to provide additional time features. The Time Series Transformer only
learns additional embeddings for `static_categorical_features`.
Additional dynamic real covariates can be concatenated to this tensor, with the caveat that these
features must be known at prediction time.
The `num_features` here is equal to `config.num_time_features` + `config.num_dynamic_real_features`.
past_observed_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, input_size)`, *optional*):
Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected
in `[0, 1]`:
- 1 for values that are **observed**,
- 0 for values that are **missing** (i.e. NaNs that were replaced by zeros).
static_categorical_features (`torch.LongTensor` of shape `(batch_size, number of static categorical features)`, *optional*):
Optional static categorical features for which the model will learn an embedding, which it will add to
the values of the time series.
Static categorical features are features which have the same value for all time steps (static over
time).
A typical example of a static categorical feature is a time series ID.
static_real_features (`torch.FloatTensor` of shape `(batch_size, number of static real features)`, *optional*):
Optional static real features which the model will add to the values of the time series.
Static real features are features which have the same value for all time steps (static over time).
A typical example of a static real feature is promotion information.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers.
Return:
[`SampleTimeSeriesPredictionOutput`] where the outputs `sequences` tensor will have shape `(batch_size,
number of samples, prediction_length)` or `(batch_size, number of samples, prediction_length, input_size)`
for multivariate predictions.
"""
outputs = self(
static_categorical_features=static_categorical_features,
static_real_features=static_real_features,
......@@ -1899,13 +2097,17 @@ class TimeSeriesTransformerForPrediction(TimeSeriesTransformerPreTrainedModel):
decoder = self.model.get_decoder()
enc_last_hidden = outputs.encoder_last_hidden_state
loc = outputs.loc
scale = outputs.scale
static_feat = outputs.static_features
num_parallel_samples = self.config.num_parallel_samples
repeated_loc = loc.repeat_interleave(repeats=num_parallel_samples, dim=0)
repeated_scale = scale.repeat_interleave(repeats=num_parallel_samples, dim=0)
repeated_past_values = past_values.repeat_interleave(repeats=num_parallel_samples, dim=0) / repeated_scale
repeated_past_values = (
past_values.repeat_interleave(repeats=num_parallel_samples, dim=0) - repeated_loc
) / repeated_scale
expanded_static_feat = static_feat.unsqueeze(1).expand(-1, future_time_features.shape[1], -1)
features = torch.cat((expanded_static_feat, future_time_features), dim=-1)
......@@ -1932,10 +2134,12 @@ class TimeSeriesTransformerForPrediction(TimeSeriesTransformerPreTrainedModel):
dec_last_hidden = dec_output.last_hidden_state
params = self.parameter_projection(dec_last_hidden[:, -1:])
distr = self.output_distribution(params, scale=repeated_scale)
distr = self.output_distribution(params, loc=repeated_loc, scale=repeated_scale)
next_sample = distr.sample()
repeated_past_values = torch.cat((repeated_past_values, next_sample / repeated_scale), dim=1)
repeated_past_values = torch.cat(
(repeated_past_values, (next_sample - repeated_loc) / repeated_scale), dim=1
)
future_samples.append(next_sample)
concat_future_samples = torch.cat(future_samples, dim=1)
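The loop above tiles each series so that all `num_parallel_samples` trajectories are drawn in one batch; a sketch of the bookkeeping (toy shapes assumed):

import torch

num_parallel_samples = 3
loc = torch.tensor([[1.0], [2.0]])  # (batch=2, 1)
repeated_loc = loc.repeat_interleave(repeats=num_parallel_samples, dim=0)  # -> (6, 1)
# each sampled step is re-standardized as (next_sample - loc) / scale before
# being appended to the decoder context for the next autoregressive step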
......
......@@ -55,7 +55,7 @@ class TimeSeriesTransformerModelTester:
embedding_dimension=5,
num_time_features=4,
is_training=True,
hidden_size=16,
hidden_size=64,
num_hidden_layers=2,
num_attention_heads=4,
intermediate_size=4,
......@@ -98,6 +98,7 @@ class TimeSeriesTransformerModelTester:
context_length=self.context_length,
lags_sequence=self.lags_sequence,
num_time_features=self.num_time_features,
num_static_real_features=1,
num_static_categorical_features=1,
cardinality=[self.cardinality],
embedding_dimension=[self.embedding_dimension],
......@@ -149,7 +150,7 @@ class TimeSeriesTransformerModelTester:
encoder.save_pretrained(tmpdirname)
encoder = TimeSeriesTransformerEncoder.from_pretrained(tmpdirname).to(torch_device)
transformer_inputs, _, _ = model.create_network_inputs(**inputs_dict)
transformer_inputs, _, _, _ = model.create_network_inputs(**inputs_dict)
enc_input = transformer_inputs[:, : config.context_length, ...]
dec_input = transformer_inputs[:, config.context_length :, ...]
......@@ -186,13 +187,18 @@ class TimeSeriesTransformerModelTest(ModelTesterMixin, unittest.TestCase):
def setUp(self):
self.model_tester = TimeSeriesTransformerModelTester(self)
self.config_tester = ConfigTester(self, config_class=TimeSeriesTransformerConfig, has_text_modality=False)
self.config_tester = ConfigTester(
self,
config_class=TimeSeriesTransformerConfig,
has_text_modality=False,
prediction_length=self.model_tester.prediction_length,
)
def test_config(self):
self.config_tester.run_common_tests()
def test_save_load_strict(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs()
config, _ = self.model_tester.prepare_config_and_inputs()
for model_class in self.all_model_classes:
model = model_class(config)
......@@ -303,7 +309,7 @@ class TimeSeriesTransformerModelTest(ModelTesterMixin, unittest.TestCase):
)
out_len = len(outputs)
correct_outlen = 6
correct_outlen = 7
if "last_hidden_state" in outputs:
correct_outlen += 1
......@@ -389,13 +395,13 @@ class TimeSeriesTransformerModelIntegrationTests(unittest.TestCase):
static_real_features=batch["static_real_features"],
future_values=batch["future_values"],
future_time_features=batch["future_time_features"],
)[0]
).last_hidden_state
expected_shape = torch.Size((64, model.config.prediction_length, model.config.d_model))
expected_shape = torch.Size((64, model.config.context_length, model.config.d_model))
self.assertEqual(output.shape, expected_shape)
expected_slice = torch.tensor(
[[-0.3125, -1.2884, -1.1118], [-0.5801, -1.4907, -0.7782], [0.0849, -1.6557, -0.9755]], device=torch_device
[[-0.6322, -1.5771, -0.9340], [-0.1011, -1.0263, -0.7208], [0.4979, -0.6487, -0.7189]], device=torch_device
)
self.assertTrue(torch.allclose(output[0, :3, :3], expected_slice, atol=TOLERANCE))
......@@ -412,12 +418,12 @@ class TimeSeriesTransformerModelIntegrationTests(unittest.TestCase):
static_categorical_features=batch["static_categorical_features"],
static_real_features=batch["static_real_features"],
future_time_features=batch["future_time_features"],
)[1]
expected_shape = torch.Size((64, model.config.prediction_length, model.config.d_model))
).encoder_last_hidden_state
expected_shape = torch.Size((64, model.config.context_length, model.config.d_model))
self.assertEqual(output.shape, expected_shape)
expected_slice = torch.tensor(
[[0.9127, -0.2056, -0.5259], [1.0572, 1.4104, -0.1964], [0.1358, 2.0348, 0.5739]], device=torch_device
[[0.8177, -1.7989, -0.3127], [1.6964, -1.0607, -0.1749], [1.8395, 0.1110, 0.0263]], device=torch_device
)
self.assertTrue(torch.allclose(output[0, :3, :3], expected_slice, atol=TOLERANCE))
......@@ -438,6 +444,6 @@ class TimeSeriesTransformerModelIntegrationTests(unittest.TestCase):
expected_shape = torch.Size((64, model.config.num_parallel_samples, model.config.prediction_length))
self.assertEqual(outputs.sequences.shape, expected_shape)
expected_slice = torch.tensor([2289.5203, 2778.3054, 4648.1313], device=torch_device)
expected_slice = torch.tensor([3883.5037, 4630.2251, 7562.1338], device=torch_device)
mean_prediction = outputs.sequences.mean(dim=1)
self.assertTrue(torch.allclose(mean_prediction[0, -3:], expected_slice, rtol=1e-1))