Unverified Commit 95900916 authored by Wesley Gifford, committed by GitHub

Fixes for PatchTST Config (#27777)



* Remove config reference and pass num_patches for PatchTSTForPrediction

* ensure return_dict is properly set

---------
Co-authored-by: Wesley M. Gifford <wmgifford@us.ibm.com>
parent cf62539a
@@ -1546,7 +1546,7 @@ class PatchTSTForClassification(PatchTSTPreTrainedModel):
     PATCHTST_START_DOCSTRING,
 )
 class PatchTSTPredictionHead(nn.Module):
-    def __init__(self, config: PatchTSTConfig, distribution_output=None):
+    def __init__(self, config: PatchTSTConfig, num_patches, distribution_output=None):
         super().__init__()
 
         self.share_projection = config.share_projection
@@ -1556,7 +1556,7 @@ class PatchTSTPredictionHead(nn.Module):
         if self.pooling_type or self.use_cls_token:
             head_dim = config.d_model
         else:
-            head_dim = config.d_model * config.num_patches
+            head_dim = config.d_model * num_patches
 
         if not self.share_projection:
             # if each channel has its own head
@@ -1662,7 +1662,9 @@ class PatchTSTForPrediction(PatchTSTPreTrainedModel):
         else:
             raise ValueError(f"Unknown distribution output {config.distribution_output}")
 
-        self.head = PatchTSTPredictionHead(config, self.distribution_output)
+        self.head = PatchTSTPredictionHead(
+            config, self.model.patchifier.num_patches, distribution_output=self.distribution_output
+        )
 
         # Initialize weights and apply final processing
         self.post_init()
@@ -1736,6 +1738,7 @@ class PatchTSTForPrediction(PatchTSTPreTrainedModel):
             past_observed_mask=past_observed_mask,
             output_hidden_states=output_hidden_states,
             output_attentions=output_attentions,
+            return_dict=True,
         )
         # get output head
         y_hat = self.head(model_output.last_hidden_state)
@@ -1962,10 +1965,10 @@ class PatchTSTForRegression(PatchTSTPreTrainedModel):
             past_observed_mask=past_observed_mask,
             output_hidden_states=output_hidden_states,
             output_attentions=output_attentions,
-            return_dict=return_dict,
+            return_dict=True,
         )
         # get output head. y_hat is of shape [bs x num_targets] or tuple of this shape
-        y_hat = self.head(model_output[0])
+        y_hat = self.head(model_output.last_hidden_state)
 
         loss = None
         if target_values is not None:
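For context, a minimal usage sketch of the fixed wiring (not part of the commit): the config values are illustrative assumptions, and the output field name `prediction_outputs` is my reading of the current PatchTST API. It shows the patch count now coming from the model's patchifier rather than from `PatchTSTConfig`, and the inner model call always returning a dict so the head can read `last_hidden_state` by name.

```python
# Minimal sketch, not part of the commit; config values below are illustrative assumptions.
import torch
from transformers import PatchTSTConfig, PatchTSTForPrediction

config = PatchTSTConfig(
    num_input_channels=7,
    context_length=512,
    patch_length=16,
    patch_stride=16,
    prediction_length=96,
)
model = PatchTSTForPrediction(config)

# The prediction head is now built with the patch count computed by the model's
# patchifier (derived from context_length, patch_length and patch_stride), so
# PatchTSTConfig no longer needs to carry num_patches.
print(model.model.patchifier.num_patches)

# forward() passes return_dict=True to the inner model, so the head can read
# model_output.last_hidden_state regardless of the return_dict value passed here.
past_values = torch.randn(2, config.context_length, config.num_input_channels)
outputs = model(past_values=past_values)
print(outputs.prediction_outputs.shape)  # (batch, prediction_length, num_input_channels)
```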