Unverified Commit 9eae4aa5 authored by Eli Simhayev's avatar Eli Simhayev Committed by GitHub
Browse files

[Time-Series] fix past_observed_mask type (#22076)

added > 0.5 to `past_observed_mask`
parent 559a45d1
...@@ -117,7 +117,7 @@ class InformerModelTester: ...@@ -117,7 +117,7 @@ class InformerModelTester:
past_time_features = floats_tensor([self.batch_size, _past_length, config.num_time_features]) past_time_features = floats_tensor([self.batch_size, _past_length, config.num_time_features])
past_values = floats_tensor([self.batch_size, _past_length]) past_values = floats_tensor([self.batch_size, _past_length])
past_observed_mask = floats_tensor([self.batch_size, _past_length]) past_observed_mask = floats_tensor([self.batch_size, _past_length]) > 0.5
# decoder inputs # decoder inputs
future_time_features = floats_tensor([self.batch_size, config.prediction_length, config.num_time_features]) future_time_features = floats_tensor([self.batch_size, config.prediction_length, config.num_time_features])
......
...@@ -114,7 +114,7 @@ class TimeSeriesTransformerModelTester: ...@@ -114,7 +114,7 @@ class TimeSeriesTransformerModelTester:
past_time_features = floats_tensor([self.batch_size, _past_length, config.num_time_features]) past_time_features = floats_tensor([self.batch_size, _past_length, config.num_time_features])
past_values = floats_tensor([self.batch_size, _past_length]) past_values = floats_tensor([self.batch_size, _past_length])
past_observed_mask = floats_tensor([self.batch_size, _past_length]) past_observed_mask = floats_tensor([self.batch_size, _past_length]) > 0.5
# decoder inputs # decoder inputs
future_time_features = floats_tensor([self.batch_size, config.prediction_length, config.num_time_features]) future_time_features = floats_tensor([self.batch_size, config.prediction_length, config.num_time_features])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment