Unverified Commit 4ed07528 authored by Kashif Rasul's avatar Kashif Rasul Committed by GitHub
Browse files

[Time Series] use mean scaler when scaling is a boolean True (#24237)

* use mean scaler when scaling is boolean True

* remove debug
parent 695928e1
......@@ -1495,7 +1495,7 @@ class AutoformerModel(AutoformerPreTrainedModel):
def __init__(self, config: AutoformerConfig):
super().__init__(config)
if config.scaling == "mean" or config.scaling:
if config.scaling == "mean" or config.scaling is True:
self.scaler = AutoformerMeanScaler(dim=1, keepdim=True)
elif config.scaling == "std":
self.scaler = AutoformerStdScaler(dim=1, keepdim=True)
......
......@@ -1504,7 +1504,7 @@ class InformerModel(InformerPreTrainedModel):
def __init__(self, config: InformerConfig):
super().__init__(config)
if config.scaling == "mean" or config.scaling:
if config.scaling == "mean" or config.scaling is True:
self.scaler = InformerMeanScaler(dim=1, keepdim=True)
elif config.scaling == "std":
self.scaler = InformerStdScaler(dim=1, keepdim=True)
......
......@@ -1229,7 +1229,7 @@ class TimeSeriesTransformerModel(TimeSeriesTransformerPreTrainedModel):
def __init__(self, config: TimeSeriesTransformerConfig):
super().__init__(config)
if config.scaling == "mean" or config.scaling:
if config.scaling == "mean" or config.scaling is True:
self.scaler = TimeSeriesMeanScaler(dim=1, keepdim=True)
elif config.scaling == "std":
self.scaler = TimeSeriesStdScaler(dim=1, keepdim=True)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment