"docs/source/vscode:/vscode.git/clone" did not exist on "367a0dbd53cc1b826d986b166f3ac520d500db64"
Unverified Commit 4ed07528 authored by Kashif Rasul's avatar Kashif Rasul Committed by GitHub
Browse files

[Time Series] use mean scaler when scaling is a boolean True (#24237)

* use mean scaler when scaling is boolean True

* remove debug
parent 695928e1
...@@ -1495,7 +1495,7 @@ class AutoformerModel(AutoformerPreTrainedModel): ...@@ -1495,7 +1495,7 @@ class AutoformerModel(AutoformerPreTrainedModel):
def __init__(self, config: AutoformerConfig): def __init__(self, config: AutoformerConfig):
super().__init__(config) super().__init__(config)
if config.scaling == "mean" or config.scaling: if config.scaling == "mean" or config.scaling is True:
self.scaler = AutoformerMeanScaler(dim=1, keepdim=True) self.scaler = AutoformerMeanScaler(dim=1, keepdim=True)
elif config.scaling == "std": elif config.scaling == "std":
self.scaler = AutoformerStdScaler(dim=1, keepdim=True) self.scaler = AutoformerStdScaler(dim=1, keepdim=True)
......
...@@ -1504,7 +1504,7 @@ class InformerModel(InformerPreTrainedModel): ...@@ -1504,7 +1504,7 @@ class InformerModel(InformerPreTrainedModel):
def __init__(self, config: InformerConfig): def __init__(self, config: InformerConfig):
super().__init__(config) super().__init__(config)
if config.scaling == "mean" or config.scaling: if config.scaling == "mean" or config.scaling is True:
self.scaler = InformerMeanScaler(dim=1, keepdim=True) self.scaler = InformerMeanScaler(dim=1, keepdim=True)
elif config.scaling == "std": elif config.scaling == "std":
self.scaler = InformerStdScaler(dim=1, keepdim=True) self.scaler = InformerStdScaler(dim=1, keepdim=True)
......
...@@ -1229,7 +1229,7 @@ class TimeSeriesTransformerModel(TimeSeriesTransformerPreTrainedModel): ...@@ -1229,7 +1229,7 @@ class TimeSeriesTransformerModel(TimeSeriesTransformerPreTrainedModel):
def __init__(self, config: TimeSeriesTransformerConfig): def __init__(self, config: TimeSeriesTransformerConfig):
super().__init__(config) super().__init__(config)
if config.scaling == "mean" or config.scaling: if config.scaling == "mean" or config.scaling is True:
self.scaler = TimeSeriesMeanScaler(dim=1, keepdim=True) self.scaler = TimeSeriesMeanScaler(dim=1, keepdim=True)
elif config.scaling == "std": elif config.scaling == "std":
self.scaler = TimeSeriesStdScaler(dim=1, keepdim=True) self.scaler = TimeSeriesStdScaler(dim=1, keepdim=True)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment