Unverified Commit 9e0fd780 authored by Gunjan Chhablani's avatar Gunjan Chhablani Committed by GitHub
Browse files

Fix reference to tpu short seq length (#13686)

parent 6dc41d9f
...@@ -174,7 +174,7 @@ class FNetBasicFourierTransform(nn.Module): ...@@ -174,7 +174,7 @@ class FNetBasicFourierTransform(nn.Module):
"dft_mat_hidden", torch.tensor(linalg.dft(config.hidden_size), dtype=torch.complex64) "dft_mat_hidden", torch.tensor(linalg.dft(config.hidden_size), dtype=torch.complex64)
) )
self.register_buffer( self.register_buffer(
"dft_mat_seq", torch.tensor(linalg.dft(config.tpu_short_sequence_length), dtype=torch.complex64) "dft_mat_seq", torch.tensor(linalg.dft(config.tpu_short_seq_length), dtype=torch.complex64)
) )
self.fourier_transform = partial( self.fourier_transform = partial(
two_dim_matmul, matrix_dim_one=self.dft_mat_seq, matrix_dim_two=self.dft_mat_hidden two_dim_matmul, matrix_dim_one=self.dft_mat_seq, matrix_dim_two=self.dft_mat_hidden
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment