Unverified commit 8d2f953f, authored by Kashif Rasul and committed by GitHub
Browse files

[Time series Informer] fix dtype of cumsum (#25431)

* fix dtype of cumsum

* add comment
parent bc3e20dc
@@ -647,7 +647,8 @@ class InformerProbSparseAttention(nn.Module):
         # calculate context for updating the attn_output, based on:
         # https://github.com/zhouhaoyi/Informer2020/blob/ac59c7447135473fb2aafeafe94395f884d5c7a5/models/attn.py#L74
         if self.is_decoder:
-            context = value_states.cumsum(dim=-2)
+            # cast to float32 before operation to avoid overflow
+            context = value_states.cumsum(dim=-2, dtype=torch.float32).to(value_states.dtype)
         else:
             v_mean_dim_time = value_states.mean(dim=-2)
             context = (
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment