Unverified commit ebf84f07, authored by fxmarty, committed by GitHub

Fix PyTorch Perceiver `PerceiverFourierPositionEncoding` with fp16 (#21787)

* fix perceiver fp16

* hopefully fix tests
parent 831f3144
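For context on the bug this diff fixes: `generate_fourier_features` builds the position encodings in float32, so under `model.half()` they no longer match the fp16 activations they are combined with downstream. A minimal sketch of the failure mode, assuming hypothetical shapes and a stand-in `proj` layer (neither is taken from the model code):

import torch

# fp16 model weights, as produced by model.half()
proj = torch.nn.Linear(16, 16).half()

inputs = torch.randn(1, 4, 8, dtype=torch.float16)  # fp16 activations
pos_enc = torch.randn(1, 4, 8)                      # fp32 Fourier features (pre-fix behavior)

# torch.cat type-promotes the mixed dtypes, silently upcasting to fp32...
x = torch.cat([inputs, pos_enc], dim=-1)
try:
    proj(x)  # ...so the fp16 layer then fails on the fp32 input
except RuntimeError as e:
    print(e)

# After the fix: the encoding is cast to the input dtype first.
x = torch.cat([inputs, pos_enc.to(dtype=inputs.dtype)], dim=-1)
out = proj(x)
assert out.dtype == torch.float16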
@@ -2201,7 +2201,7 @@ class PerceiverBasicDecoder(PerceiverAbstractDecoder):
             pos_emb = self.output_position_encodings(batch_size)
         elif self.position_encoding_type == "fourier":
             pos_emb = self.output_position_encodings(
-                self.output_index_dims, batch_size=batch_size, device=inputs.device, pos=pos
+                self.output_index_dims, batch_size=batch_size, device=inputs.device, dtype=inputs.dtype, pos=pos
             )
         # Optionally project them to a target dimension.
@@ -2215,7 +2215,9 @@ class PerceiverBasicDecoder(PerceiverAbstractDecoder):
         if self.position_encoding_type == "trainable":
             pos_emb = self.output_position_encodings(batch_size)
         elif self.position_encoding_type == "fourier":
-            pos_emb = self.output_position_encodings(index_dims, batch_size, device=inputs.device)
+            pos_emb = self.output_position_encodings(
+                index_dims, batch_size, device=inputs.device, dtype=inputs.dtype
+            )
         # Optionally project them to a target dimension.
         pos_emb = self.positions_projection(pos_emb)
@@ -2816,7 +2818,12 @@ class PerceiverFourierPositionEncoding(PerceiverAbstractPositionEncoding):
         return encoding_size

     def forward(
-        self, index_dims: List[int], batch_size: int, device, pos: torch.FloatTensor = None
+        self,
+        index_dims: List[int],
+        batch_size: int,
+        device: torch.device,
+        dtype: torch.dtype,
+        pos: torch.FloatTensor = None,
     ) -> torch.FloatTensor:
         pos = _check_or_build_spatial_positions(pos, index_dims, batch_size)
         fourier_pos_enc = generate_fourier_features(
@@ -2825,7 +2832,7 @@ class PerceiverFourierPositionEncoding(PerceiverAbstractPositionEncoding):
             max_resolution=self.max_resolution,
             concat_pos=self.concat_pos,
             sine_only=self.sine_only,
-        ).to(device)
+        ).to(device=device, dtype=dtype)
         return fourier_pos_enc
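With the new signature above, `forward` takes an explicit `dtype`, and every call site in the surrounding hunks now forwards `inputs.dtype`. A hedged usage sketch; the constructor arguments (`num_bands`, `max_resolution`) and tensor shapes are illustrative assumptions, not values from this diff:

import torch
from transformers.models.perceiver.modeling_perceiver import PerceiverFourierPositionEncoding

# Assumed configuration, for illustration only.
pos_encoder = PerceiverFourierPositionEncoding(num_bands=4, max_resolution=(8, 8))

inputs = torch.randn(2, 64, 32, dtype=torch.float16)  # e.g. a flattened 8x8 feature map
pos_enc = pos_encoder(
    index_dims=(8, 8), batch_size=2, device=inputs.device, dtype=inputs.dtype
)
# After this commit, the encoding matches the input dtype instead of staying fp32.
assert pos_enc.dtype == torch.float16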
@@ -3156,7 +3163,7 @@ class PerceiverImagePreprocessor(AbstractPreprocessor):
         if self.position_encoding_type == "trainable":
             pos_enc = self.position_embeddings(batch_size)
         elif self.position_encoding_type == "fourier":
-            pos_enc = self.position_embeddings(index_dims, batch_size, device=inputs.device)
+            pos_enc = self.position_embeddings(index_dims, batch_size, device=inputs.device, dtype=inputs.dtype)
         # Optionally project them to a target dimension.
         pos_enc = self.positions_projection(pos_enc)
@@ -3324,7 +3331,7 @@ class PerceiverAudioPreprocessor(AbstractPreprocessor):
         if self.position_encoding_type == "trainable":
             pos_enc = self.position_embeddings(batch_size)
         elif self.position_encoding_type == "fourier":
-            pos_enc = self.position_embeddings(index_dims, batch_size, device=inputs.device)
+            pos_enc = self.position_embeddings(index_dims, batch_size, device=inputs.device, dtype=inputs.dtype)
         # Optionally project them to a target dimension.
         pos_enc = self.positions_projection(pos_enc)