Unverified commit 95900916, authored by Wesley Gifford, committed by GitHub
Browse files

Fixes for PatchTST Config (#27777)



* Remove config reference and pass num_patches for PatchTSTforPrediction

* ensure return_dict is properly set

---------
Co-authored-by: Wesley M. Gifford <wmgifford@us.ibm.com>
parent cf62539a
......@@ -1546,7 +1546,7 @@ class PatchTSTForClassification(PatchTSTPreTrainedModel):
PATCHTST_START_DOCSTRING,
)
class PatchTSTPredictionHead(nn.Module):
def __init__(self, config: PatchTSTConfig, distribution_output=None):
def __init__(self, config: PatchTSTConfig, num_patches, distribution_output=None):
super().__init__()
self.share_projection = config.share_projection
......@@ -1556,7 +1556,7 @@ class PatchTSTPredictionHead(nn.Module):
if self.pooling_type or self.use_cls_token:
head_dim = config.d_model
else:
head_dim = config.d_model * config.num_patches
head_dim = config.d_model * num_patches
if not self.share_projection:
# if each channel has its own head
......@@ -1662,7 +1662,9 @@ class PatchTSTForPrediction(PatchTSTPreTrainedModel):
else:
raise ValueError(f"Unknown distribution output {config.distribution_output}")
self.head = PatchTSTPredictionHead(config, self.distribution_output)
self.head = PatchTSTPredictionHead(
config, self.model.patchifier.num_patches, distribution_output=self.distribution_output
)
# Initialize weights and apply final processing
self.post_init()
......@@ -1736,6 +1738,7 @@ class PatchTSTForPrediction(PatchTSTPreTrainedModel):
past_observed_mask=past_observed_mask,
output_hidden_states=output_hidden_states,
output_attentions=output_attentions,
return_dict=True,
)
# get output head
y_hat = self.head(model_output.last_hidden_state)
......@@ -1962,10 +1965,10 @@ class PatchTSTForRegression(PatchTSTPreTrainedModel):
past_observed_mask=past_observed_mask,
output_hidden_states=output_hidden_states,
output_attentions=output_attentions,
return_dict=return_dict,
return_dict=True,
)
# get output head. y_hat is of shape [bs x num_targets] or tuple of this shape
y_hat = self.head(model_output[0])
y_hat = self.head(model_output.last_hidden_state)
loss = None
if target_values is not None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment