Unverified commit ebf84f07, authored by fxmarty, committed by GitHub

Fix PyTorch Perceiver `PerceiverFourierPositionEncoding` with fp16 (#21787)

* fix perceiver fp16

* hopefully fix tests
parent 831f3144
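For context, a minimal sketch of the failure mode this commit addresses (the checkpoint name, input shape, and device below are illustrative assumptions, not taken from the PR): running a Fourier-encoded Perceiver in float16 previously mixed float32 position encodings with half-precision hidden states, which could raise a dtype-mismatch error.

```python
# Hypothetical fp16 repro; checkpoint name and input shape are assumptions.
import torch
from transformers import PerceiverForImageClassificationFourier

model = (
    PerceiverForImageClassificationFourier.from_pretrained(
        "deepmind/vision-perceiver-fourier", torch_dtype=torch.float16
    )
    .to("cuda")
    .eval()
)

pixel_values = torch.rand(1, 3, 224, 224, device="cuda", dtype=torch.float16)
with torch.no_grad():
    # Before this fix the Fourier position encodings stayed in float32, so combining
    # them with fp16 inputs could fail; with the fix they follow inputs.dtype.
    logits = model(inputs=pixel_values).logits
```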
@@ -2201,7 +2201,7 @@ class PerceiverBasicDecoder(PerceiverAbstractDecoder):
                 pos_emb = self.output_position_encodings(batch_size)
             elif self.position_encoding_type == "fourier":
                 pos_emb = self.output_position_encodings(
-                    self.output_index_dims, batch_size=batch_size, device=inputs.device, pos=pos
+                    self.output_index_dims, batch_size=batch_size, device=inputs.device, dtype=inputs.dtype, pos=pos
                 )
 
             # Optionally project them to a target dimension.
@@ -2215,7 +2215,9 @@ class PerceiverBasicDecoder(PerceiverAbstractDecoder):
             if self.position_encoding_type == "trainable":
                 pos_emb = self.output_position_encodings(batch_size)
             elif self.position_encoding_type == "fourier":
-                pos_emb = self.output_position_encodings(index_dims, batch_size, device=inputs.device)
+                pos_emb = self.output_position_encodings(
+                    index_dims, batch_size, device=inputs.device, dtype=inputs.dtype
+                )
 
             # Optionally project them to a target dimension.
             pos_emb = self.positions_projection(pos_emb)
@@ -2816,7 +2818,12 @@ class PerceiverFourierPositionEncoding(PerceiverAbstractPositionEncoding):
         return encoding_size
 
     def forward(
-        self, index_dims: List[int], batch_size: int, device, pos: torch.FloatTensor = None
+        self,
+        index_dims: List[int],
+        batch_size: int,
+        device: torch.device,
+        dtype: torch.dtype,
+        pos: torch.FloatTensor = None,
     ) -> torch.FloatTensor:
         pos = _check_or_build_spatial_positions(pos, index_dims, batch_size)
         fourier_pos_enc = generate_fourier_features(
@@ -2825,7 +2832,7 @@ class PerceiverFourierPositionEncoding(PerceiverAbstractPositionEncoding):
             max_resolution=self.max_resolution,
             concat_pos=self.concat_pos,
             sine_only=self.sine_only,
-        ).to(device)
+        ).to(device=device, dtype=dtype)
         return fourier_pos_enc
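The pattern applied in `PerceiverFourierPositionEncoding.forward` is to build the Fourier features in full precision and cast them once at the end, so callers running in float16 or bfloat16 get encodings in a matching dtype. A standalone sketch of that idea follows; the helper name `cast_like` is made up for illustration and is not part of the library.

```python
import torch


def cast_like(pos_enc: torch.Tensor, reference: torch.Tensor) -> torch.Tensor:
    # Mirror the `.to(device=device, dtype=dtype)` call in the patch: move and
    # cast the encodings so they match whatever precision the model runs in.
    return pos_enc.to(device=reference.device, dtype=reference.dtype)


inputs = torch.rand(2, 16, 64, dtype=torch.float16)  # half-precision hidden states
pos_enc = torch.rand(2, 16, 32)                       # encodings built in float32
hidden = torch.cat([inputs, cast_like(pos_enc, inputs)], dim=-1)
assert hidden.dtype == torch.float16  # no dtype mismatch at the concatenation
```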
@@ -3156,7 +3163,7 @@ class PerceiverImagePreprocessor(AbstractPreprocessor):
         if self.position_encoding_type == "trainable":
             pos_enc = self.position_embeddings(batch_size)
         elif self.position_encoding_type == "fourier":
-            pos_enc = self.position_embeddings(index_dims, batch_size, device=inputs.device)
+            pos_enc = self.position_embeddings(index_dims, batch_size, device=inputs.device, dtype=inputs.dtype)
 
         # Optionally project them to a target dimension.
         pos_enc = self.positions_projection(pos_enc)
@@ -3324,7 +3331,7 @@ class PerceiverAudioPreprocessor(AbstractPreprocessor):
         if self.position_encoding_type == "trainable":
             pos_enc = self.position_embeddings(batch_size)
         elif self.position_encoding_type == "fourier":
-            pos_enc = self.position_embeddings(index_dims, batch_size, device=inputs.device)
+            pos_enc = self.position_embeddings(index_dims, batch_size, device=inputs.device, dtype=inputs.dtype)
 
         # Optionally project them to a target dimension.
         pos_enc = self.positions_projection(pos_enc)
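A quick way to exercise the patched `forward` signature in isolation is sketched below; the constructor values are arbitrary small numbers chosen for illustration, and only the forward arguments (`index_dims`, `batch_size`, `device`, `dtype`) come from this patch.

```python
import torch
from transformers.models.perceiver.modeling_perceiver import PerceiverFourierPositionEncoding

# Arbitrary small configuration, just to call the patched forward on CPU.
pos_encoding = PerceiverFourierPositionEncoding(
    num_bands=4, max_resolution=(16, 16), concat_pos=True, sine_only=False
)
pos_emb = pos_encoding(
    index_dims=[16, 16], batch_size=2, device=torch.device("cpu"), dtype=torch.float16
)
assert pos_emb.dtype == torch.float16  # encodings now follow the requested dtype
```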