Unverified commit f72c7c22 authored by Wesley Gifford, committed by GitHub

PatchTST and PatchTSMixer fixes (#28083)

* 🐛 fix .max bug

* remove prediction_length from regression output dimensions

* fix parameter names, fix output names, update tests

* ensure shape for PatchTST

* ensure output shape for PatchTSMixer

* update model, batch, and expected for regression distribution test

* update test expected
Signed-off-by: Wesley M. Gifford <wmgifford@us.ibm.com>

* Update tests/models/patchtst/test_modeling_patchtst.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update tests/models/patchtst/test_modeling_patchtst.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update tests/models/patchtst/test_modeling_patchtst.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/patchtsmixer/modeling_patchtsmixer.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update tests/models/patchtsmixer/test_modeling_patchtsmixer.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update tests/models/patchtsmixer/test_modeling_patchtsmixer.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* standardize on patch_length
Signed-off-by: Wesley M. Gifford <wmgifford@us.ibm.com>

* Update tests/models/patchtsmixer/test_modeling_patchtsmixer.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update tests/models/patchtsmixer/test_modeling_patchtsmixer.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Make arguments more explicit
Signed-off-by: Wesley M. Gifford <wmgifford@us.ibm.com>

* adjust prepared inputs
Signed-off-by: Wesley M. Gifford <wmgifford@us.ibm.com>

---------
Signed-off-by: Wesley M. Gifford <wmgifford@us.ibm.com>
Co-authored-by: Wesley M. Gifford <wmgifford@us.ibm.com>
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
parent 3a08cc48
@@ -40,7 +40,7 @@ class PatchTSMixerConfig(PretrainedConfig):
     Args:
         context_length (`int`, *optional*, defaults to 32):
             The context/history length for the input sequence.
-        patch_len (`int`, *optional*, defaults to 8):
+        patch_length (`int`, *optional*, defaults to 8):
             The patch length for the input sequence.
         num_input_channels (`int`, *optional*, defaults to 1):
             Number of input variates. For Univariate, set it to 1.
@@ -51,7 +51,7 @@ class PatchTSMixerConfig(PretrainedConfig):
             The number of samples to generate in parallel for probabilistic forecast.
         d_model (`int`, *optional*, defaults to 8):
             Hidden dimension of the model. Recommended to set it as a multiple of patch_length (i.e. 2-5X of
-            patch_len). Larger value indicates more complex model.
+            patch_length). Larger value indicates more complex model.
         expansion_factor (`int`, *optional*, defaults to 2):
             Expansion factor to use inside MLP. Recommended range is 2-5. Larger value indicates more complex model.
         num_layers (`int`, *optional*, defaults to 3):
@@ -155,7 +155,7 @@ class PatchTSMixerConfig(PretrainedConfig):
         self,
         # Time series specific configuration
         context_length: int = 32,
-        patch_len: int = 8,
+        patch_length: int = 8,
         num_input_channels: int = 1,
         patch_stride: int = 8,
         num_parallel_samples: int = 100,
@@ -198,7 +198,7 @@ class PatchTSMixerConfig(PretrainedConfig):
     ):
         self.num_input_channels = num_input_channels
         self.context_length = context_length
-        self.patch_length = patch_len
+        self.patch_length = patch_length
         self.patch_stride = patch_stride
         self.d_model = d_model
         self.expansion_factor = expansion_factor
@@ -209,7 +209,7 @@ class PatchTSMixerConfig(PretrainedConfig):
         self.norm_mlp = norm_mlp
         self.scaling = scaling
         self.head_dropout = head_dropout
-        self.num_patches = (max(context_length, patch_len) - patch_len) // patch_stride + 1
+        self.num_patches = (max(context_length, patch_length) - patch_length) // patch_stride + 1
         self.mask_type = mask_type
         self.random_mask_ratio = random_mask_ratio
         self.num_forecast_mask_patches = num_forecast_mask_patches
...
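For reference, `num_patches` follows directly from the renamed arguments; a quick check of the arithmetic with the default values shown above (context_length=32, patch_length=8, patch_stride=8):

    # Sketch of the patch-count formula using the config defaults above.
    context_length, patch_length, patch_stride = 32, 8, 8
    num_patches = (max(context_length, patch_length) - patch_length) // patch_stride + 1
    assert num_patches == 4  # (32 - 8) // 8 + 1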
@@ -888,7 +888,7 @@ def forecast_masking(
     Parameters:
         inputs (`torch.Tensor`):
-            Input of shape `(bs, num_channels, num_patch, patch_len)`
+            Input of shape `(bs, num_channels, num_patch, patch_length)`
         num_forecast_mask_patches (`list`):
             Number of patches to be masked at the end of each batch sample. e.g. 4 or [3, 5].
         unmasked_channel_indices (`list`, *optional*):
@@ -1864,15 +1864,15 @@ class PatchTSMixerForTimeSeriesClassification(PatchTSMixerPreTrainedModel):
     def forward(
         self,
         past_values: torch.Tensor,
-        future_values: torch.Tensor = None,
+        target_values: torch.Tensor = None,
         output_hidden_states: Optional[bool] = False,
         return_loss: bool = True,
         return_dict: Optional[bool] = None,
     ) -> PatchTSMixerForTimeSeriesClassificationOutput:
         r"""
-        future_values (`torch.FloatTensor` of shape `(batch_size, target_len, num_input_channels)` for forecasting,
+        target_values (`torch.FloatTensor` of shape `(batch_size, target_len, num_input_channels)` for forecasting,
             `(batch_size, num_targets)` for regression, or `(batch_size,)` for classification, *optional*): Target
-            values of the time series, that serve as labels for the model. The `future_values` is what the
+            values of the time series, that serve as labels for the model. The `target_values` is what the
             Transformer needs during training to learn to output, given the `past_values`. Note that, this is NOT
             required for a pretraining task.
@@ -1912,8 +1912,8 @@ class PatchTSMixerForTimeSeriesClassification(PatchTSMixerPreTrainedModel):
         y_hat = self.head(model_output.last_hidden_state)  # tensor [batch_size x n_labels]
-        if future_values is not None and return_loss is True:
-            loss_val = loss(y_hat, future_values)
+        if target_values is not None and return_loss is True:
+            loss_val = loss(y_hat, target_values)
         else:
             loss_val = None
@@ -1942,7 +1942,7 @@ class PatchTSMixerForRegressionOutput(ModelOutput):
     Output type of [`PatchTSMixerForRegressionOutput`].
     Args:
-        prediction_outputs (`torch.FloatTensor` of shape `(batch_size, num_targets)`):
+        regression_outputs (`torch.FloatTensor` of shape `(batch_size, num_targets)`):
             Prediction output from the regression head.
         last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_input_channels, num_patches, d_model)`):
             Backbone embeddings before passing through the head.
@@ -1953,7 +1953,7 @@ class PatchTSMixerForRegressionOutput(ModelOutput):
     """
     loss: Optional[torch.FloatTensor] = None
-    prediction_outputs: torch.FloatTensor = None
+    regression_outputs: torch.FloatTensor = None
     last_hidden_state: torch.FloatTensor = None
     hidden_states: Optional[Tuple[torch.FloatTensor]] = None
@@ -2054,15 +2054,15 @@ class PatchTSMixerForRegression(PatchTSMixerPreTrainedModel):
     def forward(
         self,
         past_values: torch.Tensor,
-        future_values: torch.Tensor = None,
+        target_values: torch.Tensor = None,
         output_hidden_states: Optional[bool] = False,
         return_loss: bool = True,
         return_dict: Optional[bool] = None,
     ) -> PatchTSMixerForRegressionOutput:
         r"""
-        future_values (`torch.FloatTensor` of shape `(batch_size, target_len, num_input_channels)` for forecasting,
+        target_values (`torch.FloatTensor` of shape `(batch_size, target_len, num_input_channels)` for forecasting,
             `(batch_size, num_targets)` for regression, or `(batch_size,)` for classification, *optional*): Target
-            values of the time series, that serve as labels for the model. The `future_values` is what the
+            values of the time series, that serve as labels for the model. The `target_values` is what the
             Transformer needs during training to learn to output, given the `past_values`. Note that, this is NOT
             required for a pretraining task.
@@ -2106,16 +2106,18 @@ class PatchTSMixerForRegression(PatchTSMixerPreTrainedModel):
         y_hat = self.head(model_output.last_hidden_state)  # [batch_size x num_targets]
-        if future_values is not None and return_loss is True:
+        if target_values is not None and return_loss is True:
             if self.distribution_output:
-                if self.distribution_output == "negative_binomial" and torch.any(future_values < 0):
-                    raise Exception("future_values cannot be negative for negative_binomial distribution.")
+                if self.distribution_output == "negative_binomial" and torch.any(target_values < 0):
+                    raise Exception("target_values cannot be negative for negative_binomial distribution.")
                 distribution = self.distribution_output.distribution(y_hat)
-                loss_val = loss(distribution, future_values)
+                # y_hat should be a 2-tuple, each with dimension [bs, num_targets]
+                y_hat = tuple([item.view(-1, self.config.num_targets) for item in y_hat])
+                loss_val = loss(distribution, target_values)
                 # take average of the loss
                 loss_val = weighted_average(loss_val)
             else:
-                loss_val = loss(y_hat, future_values)
+                loss_val = loss(y_hat, target_values)
         else:
             loss_val = None
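The added `view` normalizes each distribution-parameter tensor to `[batch_size, num_targets]` before the output is returned. A minimal standalone sketch of the effect (the shapes and the 2-tuple here are illustrative assumptions, not the model's actual head):

    import torch

    batch_size, num_targets = 4, 2
    # Hypothetical head output: a 2-tuple of distribution parameters
    # (e.g. loc and scale), each carrying an extra singleton dimension.
    y_hat = (torch.randn(batch_size, 1, num_targets), torch.rand(batch_size, 1, num_targets))
    y_hat = tuple(item.view(-1, num_targets) for item in y_hat)
    assert all(t.shape == (batch_size, num_targets) for t in y_hat)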
@@ -2132,7 +2134,7 @@ class PatchTSMixerForRegression(PatchTSMixerPreTrainedModel):
         return PatchTSMixerForRegressionOutput(
             loss=loss_val,
-            prediction_outputs=y_hat,  # tensor [batch_size x num_targets]
+            regression_outputs=y_hat,  # tensor [batch_size x num_targets]
             last_hidden_state=model_output.last_hidden_state,  # [batch_size x nvars x num_patch x d_model]
             hidden_states=model_output.hidden_states,
         )
@@ -2146,7 +2148,7 @@ class PatchTSMixerForRegression(PatchTSMixerPreTrainedModel):
         Args:
             past_values (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_input_channels)`):
-                Past values of the time series that serves as context in order to predict the future.
+                Past values of the time series that serves as context in order to predict the target values.
         Return:
             [`SamplePatchTSMixerRegressionOutput`] where the outputs `sequences` tensor will have shape `(batch_size,
@@ -2158,17 +2160,18 @@ class PatchTSMixerForRegression(PatchTSMixerPreTrainedModel):
         # get model output
         outputs = self(
             past_values=past_values,
-            future_values=None,
+            target_values=None,
             output_hidden_states=False,
         )
         # get distribution
-        distribution = self.distribution_output.distribution(outputs.prediction_outputs)
+        distribution = self.distribution_output.distribution(outputs.regression_outputs)
         # get samples
         samples = [
             distribution.sample() for _ in range(num_parallel_samples)
         ]  # samples: list of [batch_size x num_targets]
         # stack tensors
-        samples = torch.stack(samples, dim=1)  # [batch_size x num_samples x num_targets]
+        # [batch_size x num_samples x num_targets]
+        samples = torch.stack(samples, dim=1).view(-1, num_parallel_samples, self.config.num_targets)
         return SamplePatchTSMixerRegressionOutput(sequences=samples)
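The appended `.view` pins the sampled sequences to `(batch_size, num_parallel_samples, num_targets)` even if `distribution.sample()` returns tensors with an extra trailing dimension. A self-contained sketch with plain torch.distributions (the Normal distribution and the shapes are assumptions for illustration):

    import torch
    from torch.distributions import Normal

    batch_size, num_parallel_samples, num_targets = 8, 100, 3
    distribution = Normal(
        loc=torch.zeros(batch_size, num_targets),
        scale=torch.ones(batch_size, num_targets),
    )
    # samples: list of [batch_size x num_targets]
    samples = [distribution.sample() for _ in range(num_parallel_samples)]
    # stack, then force [batch_size x num_samples x num_targets]
    samples = torch.stack(samples, dim=1).view(-1, num_parallel_samples, num_targets)
    assert samples.shape == (batch_size, num_parallel_samples, num_targets)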
@@ -289,7 +289,7 @@ def forecast_masking(
     Parameters:
         inputs (`torch.Tensor`):
-            Input of shape `(bs, num_channels, num_patch, patch_len)`
+            Input of shape `(bs, num_channels, num_patch, patch_length)`
         num_forecast_mask_patches (`list`):
             Number of patches to be masked at the end of each batch sample. e.g. 4 or [3, 5].
         unmasked_channel_indices (`list`, *optional*):
@@ -1430,7 +1430,7 @@ class PatchTSTClassificationHead(nn.Module):
             pooled_embedding = embedding.mean(dim=2)
         elif self.pooling_type == "max":
             # pooled_embedding: [bs x num_channels x d_model]
-            pooled_embedding = embedding.max(dim=2)
+            pooled_embedding = embedding.max(dim=2).values
         else:
             raise ValueError(f"pooling operator {self.pooling_type} is not implemented yet")
         # pooled_embedding: bs x num_channels * d_model
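The root cause of the `.max` bug: unlike `Tensor.mean(dim=...)`, `Tensor.max(dim=...)` returns a `(values, indices)` named tuple rather than a tensor, so the old code stored a tuple in `pooled_embedding`. A quick illustration:

    import torch

    embedding = torch.randn(2, 3, 5, 4)  # [bs x num_channels x num_patches x d_model]
    pooled = embedding.max(dim=2)        # named tuple (values, indices), not a tensor
    assert isinstance(pooled, tuple)
    pooled_embedding = embedding.max(dim=2).values  # [bs x num_channels x d_model]
    assert pooled_embedding.shape == (2, 3, 4)

The same one-line fix is applied in all three heads below (classification, prediction, regression).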
@@ -1602,7 +1602,7 @@ class PatchTSTPredictionHead(nn.Module):
             pooled_embedding = embedding.mean(dim=2)
         elif self.pooling_type == "max":
             # pooled_embedding: [bs x num_channels x d_model]
-            pooled_embedding = embedding.max(dim=2)
+            pooled_embedding = embedding.max(dim=2).values
         else:
             # pooled_embedding: [bs x num_channels x num_patches x d_model]
             pooled_embedding = embedding
@@ -1866,7 +1866,7 @@ class PatchTSTRegressionHead(nn.Module):
             pooled_embedding = embedding.mean(dim=2)
         elif self.pooling_type == "max":
             # pooled_embedding: [bs x num_channels x d_model]
-            pooled_embedding = embedding.max(dim=2)
+            pooled_embedding = embedding.max(dim=2).values
         else:
             raise ValueError(f"pooling operator {self.pooling_type} is not implemented yet")
         # flatten the input
@@ -1899,11 +1899,11 @@ class PatchTSTForRegression(PatchTSTPreTrainedModel):
             self.distribution_output = None
         else:
             if config.distribution_output == "student_t":
-                self.distribution_output = StudentTOutput(dim=config.prediction_length * config.num_targets)
+                self.distribution_output = StudentTOutput(dim=config.num_targets)
             elif config.distribution_output == "normal":
-                self.distribution_output = NormalOutput(dim=config.prediction_length * config.num_targets)
+                self.distribution_output = NormalOutput(dim=config.num_targets)
             elif config.distribution_output == "negative_binomial":
-                self.distribution_output = NegativeBinomialOutput(dim=config.prediction_length * config.num_targets)
+                self.distribution_output = NegativeBinomialOutput(dim=config.num_targets)
             else:
                 raise ValueError(f"Unknown distribution output {config.distribution_output}")
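The regression head predicts one value per target with no time axis, so the event dimension of the output distribution is just `num_targets`; the old `prediction_length * num_targets` sizing only makes sense for forecasting. A minimal shape check with raw torch.distributions (the StudentT parameters here are illustrative stand-ins, not the library's projection layers):

    import torch
    from torch.distributions import StudentT

    batch_size, num_targets = 4, 2
    # Parameters shaped [batch_size, num_targets] -- no prediction_length factor.
    distribution = StudentT(
        df=torch.full((batch_size, num_targets), 3.0),
        loc=torch.zeros(batch_size, num_targets),
        scale=torch.ones(batch_size, num_targets),
    )
    target_values = torch.randn(batch_size, num_targets)
    nll = -distribution.log_prob(target_values)  # [batch_size x num_targets]
    assert nll.shape == (batch_size, num_targets)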
@@ -1974,6 +1974,8 @@ class PatchTSTForRegression(PatchTSTPreTrainedModel):
         if target_values is not None:
             if self.distribution_output:
                 distribution = self.distribution_output.distribution(y_hat)
+                # y_hat should be a 2-tuple, each with dimension [bs, num_targets]
+                y_hat = tuple([item.view(-1, self.config.num_targets) for item in y_hat])
                 loss = nll(distribution, target_values)
                 # take average of the loss
                 loss = weighted_average(loss)
@@ -1982,6 +1984,7 @@ class PatchTSTForRegression(PatchTSTPreTrainedModel):
                 loss = loss(y_hat, target_values)
         if not return_dict:
+            # hidden_states, attentions, mask
             outputs = (y_hat,) + model_output[1:-3]
             outputs = (loss,) + outputs if loss is not None else outputs
             return outputs
@@ -2030,5 +2033,5 @@ class PatchTSTForRegression(PatchTSTPreTrainedModel):
         # get samples: list of [bs x num_targets]
         samples = [distribution.sample() for _ in range(num_parallel_samples)]
         # samples: [bs x num_samples x num_targets]
-        samples = torch.stack(samples, dim=1)
+        samples = torch.stack(samples, dim=1).view(-1, num_parallel_samples, self.config.num_targets)
         return SamplePatchTSTOutput(sequences=samples)
@@ -191,11 +191,8 @@ class PatchTSMixerModelTester:
         # [bs x context_length x n_vars]
         past_values = floats_tensor([self.batch_size, _past_length, self.num_input_channels])
-        future_values = floats_tensor([self.batch_size, config.prediction_length, self.num_input_channels])
-
         inputs_dict = {
             "past_values": past_values,
-            "future_values": future_values,
         }
         return inputs_dict
@@ -256,21 +253,25 @@ class PatchTSMixerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.Test
     def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
         inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
-        # if classification model:
-        if model_class in get_values(MODEL_FOR_TIME_SERIES_CLASSIFICATION_MAPPING):
+        if model_class == PatchTSMixerForPrediction:
+            rng = random.Random(self.model_tester.seed_number)
+            labels = floats_tensor(
+                [
+                    self.model_tester.batch_size,
+                    self.model_tester.prediction_length,
+                    self.model_tester.num_input_channels,
+                ],
+                rng=rng,
+            )
+            inputs_dict["future_values"] = labels
+        elif model_class in get_values(MODEL_FOR_TIME_SERIES_CLASSIFICATION_MAPPING):
             rng = random.Random(self.model_tester.seed_number)
             labels = ids_tensor([self.model_tester.batch_size], self.model_tester.num_targets, rng=rng)
-            # inputs_dict["labels"] = labels
-            inputs_dict["future_values"] = labels
-            # inputs_dict.pop("future_values")
+            inputs_dict["target_values"] = labels
         elif model_class in get_values(MODEL_FOR_TIME_SERIES_REGRESSION_MAPPING):
             rng = random.Random(self.model_tester.seed_number)
             labels = floats_tensor([self.model_tester.batch_size, self.model_tester.num_targets], rng=rng)
-            # inputs_dict["labels"] = labels
-            inputs_dict["future_values"] = labels
-            # inputs_dict.pop("future_values")
-        elif model_class in [PatchTSMixerModel, PatchTSMixerForPretraining]:
-            inputs_dict.pop("future_values")
         inputs_dict["output_hidden_states"] = True
         return inputs_dict
@@ -409,28 +410,37 @@ class PatchTSMixerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.Test
         # signature.parameters is an OrderedDict => so arg_names order is deterministic
         arg_names = [*signature.parameters.keys()]
-        expected_arg_names_with_target = [
-            "past_values",
-            "observed_mask",
-            "future_values",
-            "output_hidden_states",
-            "return_loss",
-        ]
-        expected_arg_names_without_target = [
-            "past_values",
-            "observed_mask",
-            "output_hidden_states",
-        ]
-
-        expected_arg_names = expected_arg_names_with_target
         if model_class == PatchTSMixerForPretraining:
-            expected_arg_names = expected_arg_names_without_target + ["return_loss"]
-        if model_class == PatchTSMixerModel:
-            expected_arg_names = expected_arg_names_without_target
-        if model_class in get_values(MODEL_FOR_TIME_SERIES_CLASSIFICATION_MAPPING) or model_class in get_values(
+            expected_arg_names = [
+                "past_values",
+                "observed_mask",
+                "output_hidden_states",
+                "return_loss",
+            ]
+        elif model_class == PatchTSMixerModel:
+            expected_arg_names = [
+                "past_values",
+                "observed_mask",
+                "output_hidden_states",
+            ]
+        elif model_class in get_values(MODEL_FOR_TIME_SERIES_CLASSIFICATION_MAPPING) or model_class in get_values(
             MODEL_FOR_TIME_SERIES_REGRESSION_MAPPING
         ):
-            expected_arg_names.remove("observed_mask")
+            expected_arg_names = [
+                "past_values",
+                "target_values",
+                "output_hidden_states",
+                "return_loss",
+            ]
+        else:
+            # PatchTSMixerForPrediction
+            expected_arg_names = [
+                "past_values",
+                "observed_mask",
+                "future_values",
+                "output_hidden_states",
+                "return_loss",
+            ]
         self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names)
@@ -686,20 +696,27 @@ class PatchTSMixerFunctionalTests(unittest.TestCase):
             else:
                 target_output = target_input
             ref_samples = target_output.unsqueeze(1).expand(-1, config.num_parallel_samples, -1, -1)
+            ground_truth_arg = "future_values"
+            output_predictions_arg = "prediction_outputs"
         elif task == "classification":
             mdl = PatchTSMixerForTimeSeriesClassification(config)
             target_input = self.__class__.correct_classification_classes
             target_output = self.__class__.correct_classification_output
+            ground_truth_arg = "target_values"
+            output_predictions_arg = "prediction_outputs"
         elif task == "regression":
             mdl = PatchTSMixerForRegression(config)
             target_input = self.__class__.correct_regression_output
             target_output = self.__class__.correct_regression_output
             ref_samples = target_output.unsqueeze(1).expand(-1, config.num_parallel_samples, -1)
+            ground_truth_arg = "target_values"
+            output_predictions_arg = "regression_outputs"
         elif task == "pretrain":
            mdl = PatchTSMixerForPretraining(config)
            target_input = None
            target_output = self.__class__.correct_pretrain_output
+            ground_truth_arg = None
+            output_predictions_arg = "prediction_outputs"
         else:
             print("invalid task")
@@ -710,15 +727,18 @@ class PatchTSMixerFunctionalTests(unittest.TestCase):
         else:
             output = mdl(
                 self.__class__.data,
-                future_values=target_input,
-                output_hidden_states=output_hidden_states,
+                **{
+                    ground_truth_arg: target_input,
+                    "output_hidden_states": output_hidden_states,
+                },
             )
-        if isinstance(output.prediction_outputs, tuple):
-            for t in output.prediction_outputs:
+        prediction_outputs = getattr(output, output_predictions_arg)
+        if isinstance(prediction_outputs, tuple):
+            for t in prediction_outputs:
                 self.assertEqual(t.shape, target_output.shape)
         else:
-            self.assertEqual(output.prediction_outputs.shape, target_output.shape)
+            self.assertEqual(prediction_outputs.shape, target_output.shape)
         self.assertEqual(output.last_hidden_state.shape, enc_output.shape)
@@ -980,7 +1000,7 @@ class PatchTSMixerFunctionalTests(unittest.TestCase):
         mdl = PatchTSMixerForTimeSeriesClassification(config)
         output = mdl(
             self.__class__.data,
-            future_values=self.__class__.correct_classification_classes,
+            target_values=self.__class__.correct_classification_classes,
         )
         self.assertEqual(
             output.prediction_outputs.shape,
@@ -994,7 +1014,7 @@ class PatchTSMixerFunctionalTests(unittest.TestCase):
         mdl = PatchTSMixerForTimeSeriesClassification(config)
         output = mdl(
             self.__class__.data,
-            future_values=self.__class__.correct_classification_classes,
+            target_values=self.__class__.correct_classification_classes,
             return_dict=False,
         )
         if isinstance(output, tuple):
@@ -1017,9 +1037,9 @@ class PatchTSMixerFunctionalTests(unittest.TestCase):
     def test_regression_full(self):
         config = PatchTSMixerConfig(**self.__class__.params)
         mdl = PatchTSMixerForRegression(config)
-        output = mdl(self.__class__.data, future_values=self.__class__.correct_regression_output)
+        output = mdl(self.__class__.data, target_values=self.__class__.correct_regression_output)
         self.assertEqual(
-            output.prediction_outputs.shape,
+            output.regression_outputs.shape,
             self.__class__.correct_regression_output.shape,
         )
         self.assertEqual(output.last_hidden_state.shape, self.__class__.enc_output.shape)
@@ -1030,13 +1050,13 @@ class PatchTSMixerFunctionalTests(unittest.TestCase):
         mdl = PatchTSMixerForRegression(config)
         output = mdl(
             self.__class__.data,
-            future_values=self.__class__.correct_regression_output,
+            target_values=self.__class__.correct_regression_output,
             return_dict=False,
         )
         if isinstance(output, tuple):
             output = PatchTSMixerForRegressionOutput(*output)
         self.assertEqual(
-            output.prediction_outputs.shape,
+            output.regression_outputs.shape,
             self.__class__.correct_regression_output.shape,
         )
         self.assertEqual(output.last_hidden_state.shape, self.__class__.enc_output.shape)
@@ -1049,13 +1069,13 @@ class PatchTSMixerFunctionalTests(unittest.TestCase):
         config = PatchTSMixerConfig(**params)
         mdl = PatchTSMixerForRegression(config)
-        output = mdl(self.__class__.data, future_values=self.__class__.correct_regression_output)
+        output = mdl(self.__class__.data, target_values=self.__class__.correct_regression_output)
         self.assertEqual(
-            output.prediction_outputs[0].shape,
+            output.regression_outputs[0].shape,
             self.__class__.correct_regression_output.shape,
         )
         self.assertEqual(
-            output.prediction_outputs[1].shape,
+            output.regression_outputs[1].shape,
             self.__class__.correct_regression_output.shape,
         )
         self.assertEqual(output.last_hidden_state.shape, self.__class__.enc_output.shape)
@@ -1075,13 +1095,13 @@ class PatchTSMixerFunctionalTests(unittest.TestCase):
         config = PatchTSMixerConfig(**params)
         mdl = PatchTSMixerForRegression(config)
-        output = mdl(self.__class__.data, future_values=self.__class__.correct_regression_output)
+        output = mdl(self.__class__.data, target_values=self.__class__.correct_regression_output)
         self.assertEqual(
-            output.prediction_outputs[0].shape,
+            output.regression_outputs[0].shape,
             self.__class__.correct_regression_output.shape,
         )
         self.assertEqual(
-            output.prediction_outputs[1].shape,
+            output.regression_outputs[1].shape,
             self.__class__.correct_regression_output.shape,
         )
         self.assertEqual(output.last_hidden_state.shape, self.__class__.enc_output.shape)
...
@@ -367,19 +367,19 @@ class PatchTSTModelIntegrationTests(unittest.TestCase):
         self.assertTrue(torch.allclose(mean_prediction[0, -1:], expected_slice, atol=TOLERANCE))
     def test_regression_generation(self):
-        model = PatchTSTForRegression.from_pretrained("namctin/patchtst_etth1_regression").to(torch_device)
-        batch = prepare_batch(file="test-batch.pt")
+        model = PatchTSTForRegression.from_pretrained("ibm/patchtst-etth1-regression-distribution").to(torch_device)
+        batch = prepare_batch(repo_id="ibm/patchtst-etth1-test-data", file="regression_distribution_batch.pt")
         torch.manual_seed(0)
+        model.eval()
         with torch.no_grad():
             outputs = model.generate(past_values=batch["past_values"].to(torch_device))
         expected_shape = torch.Size((64, model.config.num_parallel_samples, model.config.num_targets))
         self.assertEqual(outputs.sequences.shape, expected_shape)
         expected_slice = torch.tensor(
-            [[0.3228, 0.4320, 0.4591, 0.4066, -0.3461, 0.3094, -0.8426]],
+            [[-0.08046409], [-0.06570087], [-0.28218266], [-0.20636195], [-0.11787311]],
             device=torch_device,
         )
         mean_prediction = outputs.sequences.mean(dim=1)
-        self.assertTrue(torch.allclose(mean_prediction[0, -1:], expected_slice, rtol=TOLERANCE))
+        self.assertTrue(torch.allclose(mean_prediction[-5:], expected_slice, rtol=TOLERANCE))
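For anyone reproducing the updated integration test locally, a hedged sketch of the generation call (the checkpoint id comes from the diff above; the random `past_values` is a stand-in for the prepared batch from ibm/patchtst-etth1-test-data):

    import torch
    from transformers import PatchTSTForRegression

    model = PatchTSTForRegression.from_pretrained("ibm/patchtst-etth1-regression-distribution")
    model.eval()

    # Dummy context window; the real test loads regression_distribution_batch.pt.
    past_values = torch.randn(64, model.config.context_length, model.config.num_input_channels)
    torch.manual_seed(0)
    with torch.no_grad():
        outputs = model.generate(past_values=past_values)

    # [batch_size x num_parallel_samples x num_targets]
    print(outputs.sequences.shape)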