Unverified Commit f72c7c22 authored by Wesley Gifford, committed by GitHub

PatchTST and PatchTSMixer fixes (#28083)

* 🐛 fix .max bug

* remove prediction_length from regression output dimensions

* fix parameter names, fix output names, update tests

* ensure shape for PatchTST

* ensure output shape for PatchTSMixer

* update model, batch, and expected for regression distribution test

* update test expected
Signed-off-by: Wesley M. Gifford <wmgifford@us.ibm.com>

* Update tests/models/patchtst/test_modeling_patchtst.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update tests/models/patchtst/test_modeling_patchtst.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update tests/models/patchtst/test_modeling_patchtst.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/patchtsmixer/modeling_patchtsmixer.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update tests/models/patchtsmixer/test_modeling_patchtsmixer.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update tests/models/patchtsmixer/test_modeling_patchtsmixer.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* standardize on patch_length
Signed-off-by: Wesley M. Gifford <wmgifford@us.ibm.com>

* Update tests/models/patchtsmixer/test_modeling_patchtsmixer.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update tests/models/patchtsmixer/test_modeling_patchtsmixer.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Make arguments more explicit
Signed-off-by: Wesley M. Gifford <wmgifford@us.ibm.com>

* adjust prepared inputs
Signed-off-by: Wesley M. Gifford <wmgifford@us.ibm.com>

---------
Signed-off-by: Wesley M. Gifford <wmgifford@us.ibm.com>
Co-authored-by: Wesley M. Gifford <wmgifford@us.ibm.com>
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
parent 3a08cc48
......@@ -40,7 +40,7 @@ class PatchTSMixerConfig(PretrainedConfig):
Args:
context_length (`int`, *optional*, defaults to 32):
The context/history length for the input sequence.
patch_len (`int`, *optional*, defaults to 8):
patch_length (`int`, *optional*, defaults to 8):
The patch length for the input sequence.
num_input_channels (`int`, *optional*, defaults to 1):
Number of input variates. For Univariate, set it to 1.
......@@ -51,7 +51,7 @@ class PatchTSMixerConfig(PretrainedConfig):
The number of samples to generate in parallel for probabilistic forecast.
d_model (`int`, *optional*, defaults to 8):
Hidden dimension of the model. Recommended to set it as a multiple of patch_length (i.e. 2-5X of
patch_len). Larger value indicates more complex model.
patch_length). Larger value indicates more complex model.
expansion_factor (`int`, *optional*, defaults to 2):
Expansion factor to use inside MLP. Recommended range is 2-5. Larger value indicates more complex model.
num_layers (`int`, *optional*, defaults to 3):
......@@ -155,7 +155,7 @@ class PatchTSMixerConfig(PretrainedConfig):
self,
# Time series specific configuration
context_length: int = 32,
patch_len: int = 8,
patch_length: int = 8,
num_input_channels: int = 1,
patch_stride: int = 8,
num_parallel_samples: int = 100,
......@@ -198,7 +198,7 @@ class PatchTSMixerConfig(PretrainedConfig):
):
self.num_input_channels = num_input_channels
self.context_length = context_length
self.patch_length = patch_len
self.patch_length = patch_length
self.patch_stride = patch_stride
self.d_model = d_model
self.expansion_factor = expansion_factor
......@@ -209,7 +209,7 @@ class PatchTSMixerConfig(PretrainedConfig):
self.norm_mlp = norm_mlp
self.scaling = scaling
self.head_dropout = head_dropout
self.num_patches = (max(context_length, patch_len) - patch_len) // patch_stride + 1
self.num_patches = (max(context_length, patch_length) - patch_length) // patch_stride + 1
self.mask_type = mask_type
self.random_mask_ratio = random_mask_ratio
self.num_forecast_mask_patches = num_forecast_mask_patches
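As a quick sanity check on the renamed argument, the patch-count formula above works out as follows with the documented defaults (an illustrative snippet, not library code):

```python
# Worked example of the num_patches formula with the documented defaults:
# context_length=32, patch_length=8, patch_stride=8.
context_length, patch_length, patch_stride = 32, 8, 8

num_patches = (max(context_length, patch_length) - patch_length) // patch_stride + 1
assert num_patches == 4  # (32 - 8) // 8 + 1

# The max() guard keeps the count at 1 when the context is shorter than one patch,
# rather than producing zero or a negative patch count.
assert (max(4, patch_length) - patch_length) // patch_stride + 1 == 1
```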
......
......@@ -888,7 +888,7 @@ def forecast_masking(
Parameters:
inputs (`torch.Tensor`):
Input of shape `(bs, num_channels, num_patch, patch_len)`
Input of shape `(bs, num_channels, num_patch, patch_length)`
num_forecast_mask_patches (`list`):
Number of patches to be masked at the end of each batch sample. e.g. 4 or [3, 5].
unmasked_channel_indices (`list`, *optional*):
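For context on the shape convention the docstring now uses, here is a minimal sketch of forecast-style masking: the last `num_forecast_mask_patches` patches of every sample are marked as masked. This illustrates the pattern only; the real `forecast_masking` also supports per-sample patch counts and `unmasked_channel_indices`.

```python
import torch

def simple_forecast_mask(inputs: torch.Tensor, num_forecast_mask_patches: int) -> torch.Tensor:
    """Toy version: mask the trailing patches of every sample and channel.

    inputs: (bs, num_channels, num_patch, patch_length)
    returns a boolean mask of shape (bs, num_channels, num_patch).
    """
    bs, num_channels, num_patch, patch_length = inputs.shape
    mask = torch.zeros(bs, num_channels, num_patch, dtype=torch.bool)
    mask[..., num_patch - num_forecast_mask_patches:] = True
    return mask

x = torch.randn(2, 3, 10, 8)            # (bs, num_channels, num_patch, patch_length)
mask = simple_forecast_mask(x, 4)
assert mask[..., -4:].all() and not mask[..., :-4].any()
```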
......@@ -1864,15 +1864,15 @@ class PatchTSMixerForTimeSeriesClassification(PatchTSMixerPreTrainedModel):
def forward(
self,
past_values: torch.Tensor,
future_values: torch.Tensor = None,
target_values: torch.Tensor = None,
output_hidden_states: Optional[bool] = False,
return_loss: bool = True,
return_dict: Optional[bool] = None,
) -> PatchTSMixerForTimeSeriesClassificationOutput:
r"""
future_values (`torch.FloatTensor` of shape `(batch_size, target_len, num_input_channels)` for forecasting,
target_values (`torch.FloatTensor` of shape `(batch_size, target_len, num_input_channels)` for forecasting,
`(batch_size, num_targets)` for regression, or `(batch_size,)` for classification, *optional*): Target
values of the time series, that serve as labels for the model. The `future_values` is what the
values of the time series, that serve as labels for the model. The `target_values` is what the
Transformer needs during training to learn to output, given the `past_values`. Note that, this is NOT
required for a pretraining task.
......@@ -1912,8 +1912,8 @@ class PatchTSMixerForTimeSeriesClassification(PatchTSMixerPreTrainedModel):
y_hat = self.head(model_output.last_hidden_state) # tensor [batch_size x n_labels]
if future_values is not None and return_loss is True:
loss_val = loss(y_hat, future_values)
if target_values is not None and return_loss is True:
loss_val = loss(y_hat, target_values)
else:
loss_val = None
......@@ -1942,7 +1942,7 @@ class PatchTSMixerForRegressionOutput(ModelOutput):
Output type of [`PatchTSMixerForRegressionOutput`].
Args:
prediction_outputs (`torch.FloatTensor` of shape `(batch_size, num_targets)`):
regression_outputs (`torch.FloatTensor` of shape `(batch_size, num_targets)`):
Prediction output from the regression head.
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_input_channels, num_patches, d_model)`):
Backbone embeddings before passing through the head.
......@@ -1953,7 +1953,7 @@ class PatchTSMixerForRegressionOutput(ModelOutput):
"""
loss: Optional[torch.FloatTensor] = None
prediction_outputs: torch.FloatTensor = None
regression_outputs: torch.FloatTensor = None
last_hidden_state: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
......@@ -2054,15 +2054,15 @@ class PatchTSMixerForRegression(PatchTSMixerPreTrainedModel):
def forward(
self,
past_values: torch.Tensor,
future_values: torch.Tensor = None,
target_values: torch.Tensor = None,
output_hidden_states: Optional[bool] = False,
return_loss: bool = True,
return_dict: Optional[bool] = None,
) -> PatchTSMixerForRegressionOutput:
r"""
future_values (`torch.FloatTensor` of shape `(batch_size, target_len, num_input_channels)` for forecasting,
target_values (`torch.FloatTensor` of shape `(batch_size, target_len, num_input_channels)` for forecasting,
`(batch_size, num_targets)` for regression, or `(batch_size,)` for classification, *optional*): Target
values of the time series, that serve as labels for the model. The `future_values` is what the
values of the time series, that serve as labels for the model. The `target_values` is what the
Transformer needs during training to learn to output, given the `past_values`. Note that, this is NOT
required for a pretraining task.
......@@ -2106,16 +2106,18 @@ class PatchTSMixerForRegression(PatchTSMixerPreTrainedModel):
y_hat = self.head(model_output.last_hidden_state) # [batch_size x num_targets]
if future_values is not None and return_loss is True:
if target_values is not None and return_loss is True:
if self.distribution_output:
if self.distribution_output == "negative_binomial" and torch.any(future_values < 0):
raise Exception("future_values cannot be negative for negative_binomial distribution.")
if self.distribution_output == "negative_binomial" and torch.any(target_values < 0):
raise Exception("target_values cannot be negative for negative_binomial distribution.")
distribution = self.distribution_output.distribution(y_hat)
loss_val = loss(distribution, future_values)
# y_hat should be a 2-tuple, each with dimension [bs, num_targets]
y_hat = tuple([item.view(-1, self.config.num_targets) for item in y_hat])
loss_val = loss(distribution, target_values)
# take average of the loss
loss_val = weighted_average(loss_val)
else:
loss_val = loss(y_hat, future_values)
loss_val = loss(y_hat, target_values)
else:
loss_val = None
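A rough sketch of the distribution branch above, to make the reshape and loss shapes concrete. It uses a plain `torch.distributions.Normal` as a stand-in for the model's distribution head and a simple mean in place of `weighted_average`; for a negative-binomial head the targets must additionally be non-negative, which is what the explicit check guards against:

```python
import torch
from torch.distributions import Normal

batch_size, num_targets = 4, 3

# The head emits a tuple of distribution parameters (e.g. loc and scale),
# each of shape (batch_size, num_targets); targets have the same shape.
loc = torch.randn(batch_size, num_targets)
scale = torch.rand(batch_size, num_targets) + 0.1
target_values = torch.randn(batch_size, num_targets)

# Same idea as `y_hat = tuple(item.view(-1, num_targets) for item in y_hat)` above.
y_hat = tuple(item.view(-1, num_targets) for item in (loc, scale))

distribution = Normal(*y_hat)
nll = -distribution.log_prob(target_values)   # (batch_size, num_targets)
loss_val = nll.mean()                         # stand-in for weighted_average(nll)
```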
......@@ -2132,7 +2134,7 @@ class PatchTSMixerForRegression(PatchTSMixerPreTrainedModel):
return PatchTSMixerForRegressionOutput(
loss=loss_val,
prediction_outputs=y_hat, # tensor [batch_size x num_targets]
regression_outputs=y_hat, # tensor [batch_size x num_targets]
last_hidden_state=model_output.last_hidden_state, # [batch_size x nvars x num_patch x d_model]
hidden_states=model_output.hidden_states,
)
......@@ -2146,7 +2148,7 @@ class PatchTSMixerForRegression(PatchTSMixerPreTrainedModel):
Args:
past_values (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_input_channels)`):
Past values of the time series that serves as context in order to predict the future.
Past values of the time series that serves as context in order to predict the target values.
Return:
[`SamplePatchTSMixerRegressionOutput`] where the outputs `sequences` tensor will have shape `(batch_size,
......@@ -2158,17 +2160,18 @@ class PatchTSMixerForRegression(PatchTSMixerPreTrainedModel):
# get model output
outputs = self(
past_values=past_values,
future_values=None,
target_values=None,
output_hidden_states=False,
)
# get distribution
distribution = self.distribution_output.distribution(outputs.prediction_outputs)
distribution = self.distribution_output.distribution(outputs.regression_outputs)
# get samples
samples = [
distribution.sample() for _ in range(num_parallel_samples)
] # samples: list of [batch_size x num_targets]
# stack tensors
samples = torch.stack(samples, dim=1) # [batch_size x num_samples x num_targets]
# [batch_size x num_samples x num_targets]
samples = torch.stack(samples, dim=1).view(-1, num_parallel_samples, self.config.num_targets)
return SamplePatchTSMixerRegressionOutput(sequences=samples)
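The added `.view(...)` pins down the documented sample shape; a small shape check under assumed sizes:

```python
import torch

batch_size, num_parallel_samples, num_targets = 8, 100, 2

# Each draw from the predictive distribution has shape (batch_size, num_targets).
samples = [torch.randn(batch_size, num_targets) for _ in range(num_parallel_samples)]

# Stacking along dim=1 gives (batch_size, num_parallel_samples, num_targets);
# the trailing view makes that layout explicit even if a draw comes back flattened.
stacked = torch.stack(samples, dim=1).view(-1, num_parallel_samples, num_targets)
assert stacked.shape == (batch_size, num_parallel_samples, num_targets)
```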
......@@ -289,7 +289,7 @@ def forecast_masking(
Parameters:
inputs (`torch.Tensor`):
Input of shape `(bs, num_channels, num_patch, patch_len)`
Input of shape `(bs, num_channels, num_patch, patch_length)`
num_forecast_mask_patches (`list`):
Number of patches to be masked at the end of each batch sample. e.g. 4 or [3, 5].
unmasked_channel_indices (`list`, *optional*):
......@@ -1430,7 +1430,7 @@ class PatchTSTClassificationHead(nn.Module):
pooled_embedding = embedding.mean(dim=2)
elif self.pooling_type == "max":
# pooled_embedding: [bs x num_channels x d_model]
pooled_embedding = embedding.max(dim=2)
pooled_embedding = embedding.max(dim=2).values
else:
raise ValueError(f"pooling operator {self.pooling_type} is not implemented yet")
# pooled_embedding: bs x num_channels * d_model
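The underlying gotcha in these pooling fixes: `torch.Tensor.max(dim=...)` returns a `(values, indices)` pair rather than a tensor, so the pooled embedding needs `.values`. A minimal reproduction:

```python
import torch

embedding = torch.randn(2, 3, 5, 4)         # (bs, num_channels, num_patches, d_model)

# max over a dimension returns values and argmax indices — the source of the bug.
values, indices = embedding.max(dim=2)

pooled_embedding = embedding.max(dim=2).values
assert pooled_embedding.shape == (2, 3, 4)  # (bs, num_channels, d_model)
assert torch.equal(pooled_embedding, values)
```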
......@@ -1602,7 +1602,7 @@ class PatchTSTPredictionHead(nn.Module):
pooled_embedding = embedding.mean(dim=2)
elif self.pooling_type == "max":
# pooled_embedding: [bs x num_channels x d_model]
pooled_embedding = embedding.max(dim=2)
pooled_embedding = embedding.max(dim=2).values
else:
# pooled_embedding: [bs x num_channels x num_patches x d_model]
pooled_embedding = embedding
......@@ -1866,7 +1866,7 @@ class PatchTSTRegressionHead(nn.Module):
pooled_embedding = embedding.mean(dim=2)
elif self.pooling_type == "max":
# pooled_embedding: [bs x num_channels x d_model]
pooled_embedding = embedding.max(dim=2)
pooled_embedding = embedding.max(dim=2).values
else:
raise ValueError(f"pooling operator {self.pooling_type} is not implemented yet")
# flatten the input
......@@ -1899,11 +1899,11 @@ class PatchTSTForRegression(PatchTSTPreTrainedModel):
self.distribution_output = None
else:
if config.distribution_output == "student_t":
self.distribution_output = StudentTOutput(dim=config.prediction_length * config.num_targets)
self.distribution_output = StudentTOutput(dim=config.num_targets)
elif config.distribution_output == "normal":
self.distribution_output = NormalOutput(dim=config.prediction_length * config.num_targets)
self.distribution_output = NormalOutput(dim=config.num_targets)
elif config.distribution_output == "negative_binomial":
self.distribution_output = NegativeBinomialOutput(dim=config.prediction_length * config.num_targets)
self.distribution_output = NegativeBinomialOutput(dim=config.num_targets)
else:
raise ValueError(f"Unknown distribution output {config.distribution_output}")
......@@ -1974,6 +1974,8 @@ class PatchTSTForRegression(PatchTSTPreTrainedModel):
if target_values is not None:
if self.distribution_output:
distribution = self.distribution_output.distribution(y_hat)
# y_hat should be a 2-tuple, each with dimension [bs, num_targets]
y_hat = tuple([item.view(-1, self.config.num_targets) for item in y_hat])
loss = nll(distribution, target_values)
# take average of the loss
loss = weighted_average(loss)
......@@ -1982,6 +1984,7 @@ class PatchTSTForRegression(PatchTSTPreTrainedModel):
loss = loss(y_hat, target_values)
if not return_dict:
# hidden_states, attentions, mask
outputs = (y_hat,) + model_output[1:-3]
outputs = (loss,) + outputs if loss is not None else outputs
return outputs
......@@ -2030,5 +2033,5 @@ class PatchTSTForRegression(PatchTSTPreTrainedModel):
# get samples: list of [bs x num_targets]
samples = [distribution.sample() for _ in range(num_parallel_samples)]
# samples: [bs x num_samples x num_targets]
samples = torch.stack(samples, dim=1)
samples = torch.stack(samples, dim=1).view(-1, num_parallel_samples, self.config.num_targets)
return SamplePatchTSTOutput(sequences=samples)
......@@ -191,11 +191,8 @@ class PatchTSMixerModelTester:
# [bs x context_length x n_vars]
past_values = floats_tensor([self.batch_size, _past_length, self.num_input_channels])
future_values = floats_tensor([self.batch_size, config.prediction_length, self.num_input_channels])
inputs_dict = {
"past_values": past_values,
"future_values": future_values,
}
return inputs_dict
......@@ -256,21 +253,25 @@ class PatchTSMixerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.Test
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
# if classification model:
if model_class in get_values(MODEL_FOR_TIME_SERIES_CLASSIFICATION_MAPPING):
if model_class == PatchTSMixerForPrediction:
rng = random.Random(self.model_tester.seed_number)
labels = ids_tensor([self.model_tester.batch_size], self.model_tester.num_targets, rng=rng)
# inputs_dict["labels"] = labels
labels = floats_tensor(
[
self.model_tester.batch_size,
self.model_tester.prediction_length,
self.model_tester.num_input_channels,
],
rng=rng,
)
inputs_dict["future_values"] = labels
# inputs_dict.pop("future_values")
elif model_class in get_values(MODEL_FOR_TIME_SERIES_CLASSIFICATION_MAPPING):
rng = random.Random(self.model_tester.seed_number)
labels = ids_tensor([self.model_tester.batch_size], self.model_tester.num_targets, rng=rng)
inputs_dict["target_values"] = labels
elif model_class in get_values(MODEL_FOR_TIME_SERIES_REGRESSION_MAPPING):
rng = random.Random(self.model_tester.seed_number)
labels = floats_tensor([self.model_tester.batch_size, self.model_tester.num_targets], rng=rng)
# inputs_dict["labels"] = labels
inputs_dict["future_values"] = labels
# inputs_dict.pop("future_values")
elif model_class in [PatchTSMixerModel, PatchTSMixerForPretraining]:
inputs_dict.pop("future_values")
inputs_dict["target_values"] = labels
inputs_dict["output_hidden_states"] = True
return inputs_dict
......@@ -409,28 +410,37 @@ class PatchTSMixerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.Test
# signature.parameters is an OrderedDict => so arg_names order is deterministic
arg_names = [*signature.parameters.keys()]
expected_arg_names_with_target = [
if model_class == PatchTSMixerForPretraining:
expected_arg_names = [
"past_values",
"observed_mask",
"future_values",
"output_hidden_states",
"return_loss",
]
expected_arg_names_without_target = [
elif model_class == PatchTSMixerModel:
expected_arg_names = [
"past_values",
"observed_mask",
"output_hidden_states",
]
expected_arg_names = expected_arg_names_with_target
if model_class == PatchTSMixerForPretraining:
expected_arg_names = expected_arg_names_without_target + ["return_loss"]
if model_class == PatchTSMixerModel:
expected_arg_names = expected_arg_names_without_target
if model_class in get_values(MODEL_FOR_TIME_SERIES_CLASSIFICATION_MAPPING) or model_class in get_values(
elif model_class in get_values(MODEL_FOR_TIME_SERIES_CLASSIFICATION_MAPPING) or model_class in get_values(
MODEL_FOR_TIME_SERIES_REGRESSION_MAPPING
):
expected_arg_names.remove("observed_mask")
expected_arg_names = [
"past_values",
"target_values",
"output_hidden_states",
"return_loss",
]
else:
# PatchTSMixerForPrediction
expected_arg_names = [
"past_values",
"observed_mask",
"future_values",
"output_hidden_states",
"return_loss",
]
self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names)
......@@ -686,20 +696,27 @@ class PatchTSMixerFunctionalTests(unittest.TestCase):
else:
target_output = target_input
ref_samples = target_output.unsqueeze(1).expand(-1, config.num_parallel_samples, -1, -1)
ground_truth_arg = "future_values"
output_predictions_arg = "prediction_outputs"
elif task == "classification":
mdl = PatchTSMixerForTimeSeriesClassification(config)
target_input = self.__class__.correct_classification_classes
target_output = self.__class__.correct_classification_output
ground_truth_arg = "target_values"
output_predictions_arg = "prediction_outputs"
elif task == "regression":
mdl = PatchTSMixerForRegression(config)
target_input = self.__class__.correct_regression_output
target_output = self.__class__.correct_regression_output
ref_samples = target_output.unsqueeze(1).expand(-1, config.num_parallel_samples, -1)
ground_truth_arg = "target_values"
output_predictions_arg = "regression_outputs"
elif task == "pretrain":
mdl = PatchTSMixerForPretraining(config)
target_input = None
target_output = self.__class__.correct_pretrain_output
ground_truth_arg = None
output_predictions_arg = "prediction_outputs"
else:
print("invalid task")
......@@ -710,15 +727,18 @@ class PatchTSMixerFunctionalTests(unittest.TestCase):
else:
output = mdl(
self.__class__.data,
future_values=target_input,
output_hidden_states=output_hidden_states,
**{
ground_truth_arg: target_input,
"output_hidden_states": output_hidden_states,
},
)
if isinstance(output.prediction_outputs, tuple):
for t in output.prediction_outputs:
prediction_outputs = getattr(output, output_predictions_arg)
if isinstance(prediction_outputs, tuple):
for t in prediction_outputs:
self.assertEqual(t.shape, target_output.shape)
else:
self.assertEqual(output.prediction_outputs.shape, target_output.shape)
self.assertEqual(prediction_outputs.shape, target_output.shape)
self.assertEqual(output.last_hidden_state.shape, enc_output.shape)
......@@ -980,7 +1000,7 @@ class PatchTSMixerFunctionalTests(unittest.TestCase):
mdl = PatchTSMixerForTimeSeriesClassification(config)
output = mdl(
self.__class__.data,
future_values=self.__class__.correct_classification_classes,
target_values=self.__class__.correct_classification_classes,
)
self.assertEqual(
output.prediction_outputs.shape,
......@@ -994,7 +1014,7 @@ class PatchTSMixerFunctionalTests(unittest.TestCase):
mdl = PatchTSMixerForTimeSeriesClassification(config)
output = mdl(
self.__class__.data,
future_values=self.__class__.correct_classification_classes,
target_values=self.__class__.correct_classification_classes,
return_dict=False,
)
if isinstance(output, tuple):
......@@ -1017,9 +1037,9 @@ class PatchTSMixerFunctionalTests(unittest.TestCase):
def test_regression_full(self):
config = PatchTSMixerConfig(**self.__class__.params)
mdl = PatchTSMixerForRegression(config)
output = mdl(self.__class__.data, future_values=self.__class__.correct_regression_output)
output = mdl(self.__class__.data, target_values=self.__class__.correct_regression_output)
self.assertEqual(
output.prediction_outputs.shape,
output.regression_outputs.shape,
self.__class__.correct_regression_output.shape,
)
self.assertEqual(output.last_hidden_state.shape, self.__class__.enc_output.shape)
......@@ -1030,13 +1050,13 @@ class PatchTSMixerFunctionalTests(unittest.TestCase):
mdl = PatchTSMixerForRegression(config)
output = mdl(
self.__class__.data,
future_values=self.__class__.correct_regression_output,
target_values=self.__class__.correct_regression_output,
return_dict=False,
)
if isinstance(output, tuple):
output = PatchTSMixerForRegressionOutput(*output)
self.assertEqual(
output.prediction_outputs.shape,
output.regression_outputs.shape,
self.__class__.correct_regression_output.shape,
)
self.assertEqual(output.last_hidden_state.shape, self.__class__.enc_output.shape)
......@@ -1049,13 +1069,13 @@ class PatchTSMixerFunctionalTests(unittest.TestCase):
config = PatchTSMixerConfig(**params)
mdl = PatchTSMixerForRegression(config)
output = mdl(self.__class__.data, future_values=self.__class__.correct_regression_output)
output = mdl(self.__class__.data, target_values=self.__class__.correct_regression_output)
self.assertEqual(
output.prediction_outputs[0].shape,
output.regression_outputs[0].shape,
self.__class__.correct_regression_output.shape,
)
self.assertEqual(
output.prediction_outputs[1].shape,
output.regression_outputs[1].shape,
self.__class__.correct_regression_output.shape,
)
self.assertEqual(output.last_hidden_state.shape, self.__class__.enc_output.shape)
......@@ -1075,13 +1095,13 @@ class PatchTSMixerFunctionalTests(unittest.TestCase):
config = PatchTSMixerConfig(**params)
mdl = PatchTSMixerForRegression(config)
output = mdl(self.__class__.data, future_values=self.__class__.correct_regression_output)
output = mdl(self.__class__.data, target_values=self.__class__.correct_regression_output)
self.assertEqual(
output.prediction_outputs[0].shape,
output.regression_outputs[0].shape,
self.__class__.correct_regression_output.shape,
)
self.assertEqual(
output.prediction_outputs[1].shape,
output.regression_outputs[1].shape,
self.__class__.correct_regression_output.shape,
)
self.assertEqual(output.last_hidden_state.shape, self.__class__.enc_output.shape)
......
......@@ -367,19 +367,19 @@ class PatchTSTModelIntegrationTests(unittest.TestCase):
self.assertTrue(torch.allclose(mean_prediction[0, -1:], expected_slice, atol=TOLERANCE))
def test_regression_generation(self):
model = PatchTSTForRegression.from_pretrained("namctin/patchtst_etth1_regression").to(torch_device)
batch = prepare_batch(file="test-batch.pt")
model = PatchTSTForRegression.from_pretrained("ibm/patchtst-etth1-regression-distribution").to(torch_device)
batch = prepare_batch(repo_id="ibm/patchtst-etth1-test-data", file="regression_distribution_batch.pt")
torch.manual_seed(0)
model.eval()
with torch.no_grad():
outputs = model.generate(past_values=batch["past_values"].to(torch_device))
expected_shape = torch.Size((64, model.config.num_parallel_samples, model.config.num_targets))
self.assertEqual(outputs.sequences.shape, expected_shape)
expected_slice = torch.tensor(
[[0.3228, 0.4320, 0.4591, 0.4066, -0.3461, 0.3094, -0.8426]],
[[-0.08046409], [-0.06570087], [-0.28218266], [-0.20636195], [-0.11787311]],
device=torch_device,
)
mean_prediction = outputs.sequences.mean(dim=1)
self.assertTrue(torch.allclose(mean_prediction[0, -1:], expected_slice, rtol=TOLERANCE))
self.assertTrue(torch.allclose(mean_prediction[-5:], expected_slice, rtol=TOLERANCE))
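For the updated assertion, the shape arithmetic behind it, assuming `num_targets` is 1 for this checkpoint (as the 5x1 `expected_slice` implies) and an illustrative number of parallel samples:

```python
import torch

num_parallel_samples, num_targets = 100, 1     # illustrative / assumed values
sequences = torch.randn(64, num_parallel_samples, num_targets)

mean_prediction = sequences.mean(dim=1)        # (64, num_targets)
last_five = mean_prediction[-5:]               # (5, num_targets), matching expected_slice
assert last_five.shape == (5, num_targets)
```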