Unverified commit 8abe4930 authored by Eli Simhayev, committed by GitHub

[Time-Series] informer model (#21099)

* added informer to gitignore

* WIP informer2020

* added a check that instantiation works

* added config using gluonTS by kashif

* WIP config

* adding InformerConfig; need to remove FeatureEmbedder

* done InformerConfig, but need to change the names

* Done informer model init. working on enc-dec

* added things to address, after reading again enc-dec in the paper

* done modeling - checking that initialization works

* moved enc-dec init to InformerEncoder/Decoder init

* added 'init_std' to config, now model init works!

* WIP conversion script, and added code sources

* WIP conversion script: loading original informer pth works

* WIP conversion script: change defaults in the config

* WIP conversion script: supporting Informer input embedding

* WIP conversion script: added parameters for the informer embed

* WIP conversion script: change dim_feedforward=2048

* WIP conversion script: remove unused args for loading checkpoint

* just cleaning up

* DataEmbedding removed, after discussing with Kashif

* working on forward pass

* WIP forward pass: trying to establish working batch for forward pass

* cleaning and finalizing

* adding HF names and docs

* init after cleaning works

* WIP in tests

* added docs for the informer specific args

* fix style

* undo change

* cleaning Informer, now need to work only on enc-dec

* initial enc-dec classes

* added encoder and decoder

* added todo

* add todos for conv_layers

* added decoder docs from vanilla

* added encoder docs from vanilla

* remove encoder decoder from the original informer

* removed AttentionLayer from the original paper

* removed TriangularCausalMask, same as decoder_attention_mask

* initial sparse attention

* use conv_layers

* fixed test_config test

* fix parentheses when iterating zip(layers, conv_layers)

* error found in prob attention, added sizes as comments

* fix sizes

* added proposal for q_reduce indexing, and remove unused

* WIP ProbMask, and changed factor=2 for testing

* remove unused libs for this PR for creating the env

* fix checking the attn_weights.size() after bmm

* Q_reduce: changed from torch.gather to simple slicing

* WIP calculate final attn_output

* finish adding v_aggregated, attn_output ready

* changed tgt_len to u in attention_mask, need to fix the size error

* comment attention_mask for encoder, and fix if cond for v_agg

* added ProbMask support (wip), removed old original code

* finished ProbMask 😃

* Revert "remove unused libs for this PR for creating the env"

This reverts commit 11a081e09e92771e51a5d2758d53a9afb59547f0.

* fixes

* make style

* fix initial tests

* fix more tests

* dry

* make style

* remove unused files

* style

* added integration tests

* fix num_static_real_features

* fix header

* remove unused function

* fix example

* fix docs

* Update src/transformers/models/informer/configuration_informer.py
Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

* Update src/transformers/models/informer/modeling_informer.py
Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

* Update src/transformers/models/informer/configuration_informer.py
Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

* Update src/transformers/models/informer/configuration_informer.py
Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

* Update src/transformers/models/informer/configuration_informer.py
Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

* Update src/transformers/models/informer/configuration_informer.py
Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

* fixes for reviewer

* use prediction_length from model

* fix style

* fixed informer.mdx

* added to index

* updated readme

* undo

* make fix-copies

* typo

* fix copy

* added Informer to toctree

* in order

* fixed comments

* remove unneeded new lines in docs

* make static real and cat optional

* fix use of distil conv layers

* fixed integration test

* added checkpoint for convlayer

* make fix-copies

* updated from time series model

* make fix-copies

* copy decoder

* fix unit tests

* updated scaling config

* fix integration tests

* IGNORE_NON_TESTED

* IGNORE_NON_AUTO_CONFIGURED

* IGNORE_NON_AUTO_CONFIGURED

* updated check configs

* fix formatting

* undo change from time series

* prediction_length should not be None

* align with the blog: prettify ProbSparse and change attention_factor to sampling_factor

* make style

* make fix-copies

* niels CR: update contributed by

* niels CR: update configuration_informer.py
Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

* niels CR: update kashif -> huggingface
Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

* niels CR: `sampling_factor` only relevant when `attention_type`=prob

* make style

* fixed U_part: added multiplication by `L_Q`

* fixed bug: remove `is not None` from `if config.distil`

* fixed test: `decoder_seq_length` to `encoder_seq_length` in cross_attentions check

* fix integration tests

* updated model hub

* do not shift as in training

* undo

* fix make-copies

* make fix-copies

* added `if prediction_length is None`

* changed `ProbSparseAttention` to `InformerProbSparseAttention`

* changed `V_sum` -> `v_mean_dim_time`

* changed `ConvLayer` to `InformerConvLayer` and fixed `super()`

* TimeSeriesTransformer->Informer in decoder's Copied from

* more descriptive in ProbSparse

* make style

* fix copied from

* Revert "added `if prediction_length is None`"

This reverts commit b4cbddfa05e3bd739b79569cd3c3b89e316f2451.

* fixed indent

* use InformerSinusoidalPositionalEmbedding

* make fix-style

* fix from #21860

* fix name

* make fix-copies

* use time series utils

* fix dec num_heads

* docstring

* added time series util doc

* _import_structure

* formatting

* changes from review

* make style

* fix docs

* fix doc

* removed NegativeLogLikelihood

---------
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>
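
The log above keeps circling the core of the model: the ProbSparse self-attention (Q_reduce slicing, top-u query selection, v_mean_dim_time aggregation, sampling_factor). For readers following along, here is a minimal single-head sketch of that mechanism for the non-causal (encoder) case. It is written from the paper/blog description, not copied from modeling_informer.py, so the function and variable names are illustrative only.

import math

import torch


def prob_sparse_attention(q, k, v, sampling_factor=5):
    """Single-head ProbSparse self-attention sketch; q, k, v: (batch, seq_len, dim)."""
    batch, l_q, dim = q.shape
    l_k = k.shape[1]

    # Number of keys sampled per query for the sparsity measurement, and number
    # of "active" queries kept for full attention: both O(log L).
    u_part = max(1, min(int(sampling_factor * math.ceil(math.log(l_k))), l_k))
    u = max(1, min(int(sampling_factor * math.ceil(math.log(l_q))), l_q))

    # Score every query against a random subset of the keys.
    index_sample = torch.randint(0, l_k, (u_part,))
    qk_sample = torch.bmm(q, k[:, index_sample, :].transpose(1, 2))  # (batch, l_q, u_part)

    # Sparsity measurement M(q, K): max score minus mean score per query. A high
    # value means the query's attention is far from uniform, i.e. worth computing.
    m = qk_sample.max(dim=-1).values - qk_sample.mean(dim=-1)  # (batch, l_q)
    top_u = m.topk(u, dim=-1).indices  # (batch, u)

    # Q_reduce: full attention is computed only for the top-u queries.
    q_reduce = q.gather(1, top_u.unsqueeze(-1).expand(-1, -1, dim))  # (batch, u, dim)
    attn = torch.softmax(torch.bmm(q_reduce, k.transpose(1, 2)) / math.sqrt(dim), dim=-1)

    # "Lazy" queries fall back to the mean of V over time (the v_mean_dim_time of
    # the log); the active queries are overwritten with their real attention output.
    output = v.mean(dim=1, keepdim=True).expand(-1, l_q, -1).clone()
    output.scatter_(1, top_u.unsqueeze(-1).expand(-1, -1, dim), torch.bmm(attn, v))
    return output

In the causal (decoder) case the fallback is a cumulative sum of V rather than its mean, combined with the ProbMask the log mentions, so lazy queries still only see past positions.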
parent dde718e7
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team.
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Time series distributional output classes and utilities.
"""
from typing import Callable, Dict, Optional, Tuple

import torch
from torch import nn
from torch.distributions import (
    AffineTransform,
    Distribution,
    Independent,
    NegativeBinomial,
    Normal,
    StudentT,
    TransformedDistribution,
)

class AffineTransformed(TransformedDistribution):
    """
    Distribution of the affinely transformed random variable Y = loc + scale * X.
    """

    def __init__(self, base_distribution: Distribution, loc=None, scale=None, event_dim=0):
        self.scale = 1.0 if scale is None else scale
        self.loc = 0.0 if loc is None else loc

        super().__init__(base_distribution, [AffineTransform(loc=self.loc, scale=self.scale, event_dim=event_dim)])

    @property
    def mean(self):
        """
        Returns the mean of the distribution.
        """
        return self.base_dist.mean * self.scale + self.loc

    @property
    def variance(self):
        """
        Returns the variance of the distribution.
        """
        return self.base_dist.variance * self.scale**2

    @property
    def stddev(self):
        """
        Returns the standard deviation of the distribution.
        """
        return self.variance.sqrt()

class ParameterProjection(nn.Module):
    """
    Projects model output to the unconstrained parameters of a distribution, using one linear head per parameter;
    `domain_map` then maps them to their valid domains.
    """

    def __init__(
        self, in_features: int, args_dim: Dict[str, int], domain_map: Callable[..., Tuple[torch.Tensor]], **kwargs
    ) -> None:
        super().__init__(**kwargs)
        self.args_dim = args_dim
        self.proj = nn.ModuleList([nn.Linear(in_features, dim) for dim in args_dim.values()])
        self.domain_map = domain_map

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor]:
        params_unbounded = [proj(x) for proj in self.proj]
        return self.domain_map(*params_unbounded)

class LambdaLayer(nn.Module):
    def __init__(self, function):
        super().__init__()
        self.function = function

    def forward(self, x, *args):
        return self.function(x, *args)

class DistributionOutput:
    distribution_class: type
    in_features: int
    args_dim: Dict[str, int]

    def __init__(self, dim: int = 1) -> None:
        self.dim = dim
        self.args_dim = {k: dim * self.args_dim[k] for k in self.args_dim}

    def _base_distribution(self, distr_args):
        if self.dim == 1:
            return self.distribution_class(*distr_args)
        else:
            return Independent(self.distribution_class(*distr_args), 1)

    def distribution(
        self,
        distr_args,
        loc: Optional[torch.Tensor] = None,
        scale: Optional[torch.Tensor] = None,
    ) -> Distribution:
        distr = self._base_distribution(distr_args)
        if loc is None and scale is None:
            return distr
        else:
            return AffineTransformed(distr, loc=loc, scale=scale, event_dim=self.event_dim)

    @property
    def event_shape(self) -> Tuple:
        r"""
        Shape of each individual event contemplated by the distributions that this object constructs.
        """
        return () if self.dim == 1 else (self.dim,)

    @property
    def event_dim(self) -> int:
        r"""
        Number of event dimensions, i.e., length of the `event_shape` tuple, of the distributions that this object
        constructs.
        """
        return len(self.event_shape)

    @property
    def value_in_support(self) -> float:
        r"""
        A float that will have a valid numeric value when computing the log-loss of the corresponding distribution. By
        default 0.0. This value will be used when padding data series.
        """
        return 0.0

    def get_parameter_projection(self, in_features: int) -> nn.Module:
        r"""
        Return the parameter projection layer that maps the input to the appropriate parameters of the distribution.
        """
        return ParameterProjection(
            in_features=in_features,
            args_dim=self.args_dim,
            domain_map=LambdaLayer(self.domain_map),
        )

    def domain_map(self, *args: torch.Tensor):
        r"""
        Converts arguments to the right shape and domain. The domain depends on the type of distribution, while the
        correct shape is obtained by reshaping the trailing axis in such a way that the returned tensors define a
        distribution of the right event_shape.
        """
        raise NotImplementedError()

    @staticmethod
    def squareplus(x: torch.Tensor) -> torch.Tensor:
        r"""
        Helper to map inputs to the positive orthant by applying the square-plus operation. Reference:
        https://twitter.com/jon_barron/status/1387167648669048833
        """
        return (x + torch.sqrt(torch.square(x) + 4.0)) / 2.0

class StudentTOutput(DistributionOutput):
    """
    Student-T distribution output class.
    """

    args_dim: Dict[str, int] = {"df": 1, "loc": 1, "scale": 1}
    distribution_class: type = StudentT

    @classmethod
    def domain_map(cls, df: torch.Tensor, loc: torch.Tensor, scale: torch.Tensor):
        scale = cls.squareplus(scale)
        df = 2.0 + cls.squareplus(df)
        return df.squeeze(-1), loc.squeeze(-1), scale.squeeze(-1)

class NormalOutput(DistributionOutput):
    """
    Normal distribution output class.
    """

    args_dim: Dict[str, int] = {"loc": 1, "scale": 1}
    distribution_class: type = Normal

    @classmethod
    def domain_map(cls, loc: torch.Tensor, scale: torch.Tensor):
        scale = cls.squareplus(scale)
        return loc.squeeze(-1), scale.squeeze(-1)

class NegativeBinomialOutput(DistributionOutput):
    """
    Negative Binomial distribution output class.
    """

    args_dim: Dict[str, int] = {"total_count": 1, "logits": 1}
    distribution_class: type = NegativeBinomial

    @classmethod
    def domain_map(cls, total_count: torch.Tensor, logits: torch.Tensor):
        total_count = cls.squareplus(total_count)
        return total_count.squeeze(-1), logits.squeeze(-1)

    def _base_distribution(self, distr_args) -> Distribution:
        total_count, logits = distr_args
        if self.dim == 1:
            return self.distribution_class(total_count=total_count, logits=logits)
        else:
            return Independent(self.distribution_class(total_count=total_count, logits=logits), 1)

    # Overwrites the parent class method. We cannot scale using the affine
    # transformation since the negative binomial should return integers. Instead
    # we scale the parameters: the mean of the distribution is
    # `total_count * exp(logits)`, so adding `log(scale)` to the logits multiplies
    # the mean by `scale` (see the scaling property of the Gamma distribution in
    # the Gamma-Poisson mixture view of the negative binomial).
    def distribution(
        self, distr_args, loc: Optional[torch.Tensor] = None, scale: Optional[torch.Tensor] = None
    ) -> Distribution:
        total_count, logits = distr_args

        if scale is not None:
            logits += scale.log()

        return self._base_distribution((total_count, logits))
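
As a quick illustration of how the pieces above fit together: the model asks the output class for a projection head, maps decoder states to distribution parameters, and trains on the negative log-likelihood. This is a minimal sketch (assuming a hidden size of 32 and a univariate target; the shapes and variable names are illustrative, not part of the commit):

import torch

distr_output = StudentTOutput(dim=1)
proj = distr_output.get_parameter_projection(in_features=32)

hidden = torch.randn(8, 24, 32)  # (batch, prediction_length, hidden) decoder output
df, loc, scale = proj(hidden)  # each (8, 24); squareplus keeps df and scale positive
distr = distr_output.distribution((df, loc, scale))

target = torch.randn(8, 24)
loss = -distr.log_prob(target).mean()  # negative log-likelihood

Because the projection is decoupled from the distribution, swapping StudentTOutput for NormalOutput or NegativeBinomialOutput is a pure configuration change, and the loss can be computed directly from log_prob, which is presumably why the log ends with `removed NegativeLogLikelihood`.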
@@ -3430,6 +3430,30 @@ def load_tf_weights_in_imagegpt(*args, **kwargs):
     requires_backends(load_tf_weights_in_imagegpt, ["torch"])


+INFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = None
+
+
+class InformerForPrediction(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+
+class InformerModel(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+
+class InformerPreTrainedModel(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+
 JUKEBOX_PRETRAINED_MODEL_ARCHIVE_LIST = None
This diff is collapsed.
@@ -69,6 +69,10 @@ SPECIAL_CASES_TO_ALLOW = {
     "RetriBertConfig": ["layer_norm_eps"],
     # having default values other than `1e-5` - we can't fix them without breaking
     "TrajectoryTransformerConfig": ["layer_norm_eps"],
+    # used internally to calculate the feature size
+    "InformerConfig": ["num_static_real_features", "num_time_features"],
+    # used internally to calculate the feature size
+    "TimeSeriesTransformerConfig": ["num_static_real_features", "num_time_features"],
 }

 # TODO (ydshieh): Check the failing cases, try to fix them or move some cases to the above block once we are sure
@@ -97,7 +101,6 @@ SPECIAL_CASES_TO_ALLOW.update(
     "SwitchTransformersConfig": True,
     "TableTransformerConfig": True,
     "TapasConfig": True,
-    "TimeSeriesTransformerConfig": True,
     "TrajectoryTransformerConfig": True,
     "TransfoXLConfig": True,
     "UniSpeechConfig": True,
@@ -68,6 +68,8 @@ IGNORE_NON_TESTED = PRIVATE_MODELS.copy() + [
     "TableTransformerDecoder",  # Building part of bigger (tested) model.
     "TimeSeriesTransformerEncoder",  # Building part of bigger (tested) model.
     "TimeSeriesTransformerDecoder",  # Building part of bigger (tested) model.
+    "InformerEncoder",  # Building part of bigger (tested) model.
+    "InformerDecoder",  # Building part of bigger (tested) model.
     "JukeboxVQVAE",  # Building part of bigger (tested) model.
     "JukeboxPrior",  # Building part of bigger (tested) model.
     "DeformableDetrEncoder",  # Building part of bigger (tested) model.
@@ -208,6 +210,7 @@ IGNORE_NON_AUTO_CONFIGURED = PRIVATE_MODELS.copy() + [
     "EsmForProteinFolding",
     "GPTSanJapaneseModel",
     "TimeSeriesTransformerForPrediction",
+    "InformerForPrediction",
     "JukeboxVQVAE",
     "JukeboxPrior",
     "PegasusXEncoder",