"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "95792a948e68b8dc89a68bb9cc5bb7fc0a8a3e9c"
Unverified Commit 1ae132a0 authored by Patrick von Platen, committed by GitHub

[Reformer] Axial Pos Emb Improve mem usage reformer (#5209)

* improve mem handling

* improve mem for pos ax encodings
parent 51441040
@@ -154,9 +154,14 @@ class AxialPositionEmbeddings(nn.Module):
                 self.axial_pos_shape, sequence_length, self.least_common_mult_chunk_length,
             )
-            # reshape axial encodings and use only until sequence_length
-            position_encodings = torch.cat(broadcasted_weights, dim=-1)
-            position_encodings = position_encodings.view(batch_size, -1, position_encodings.shape[-1])[
+            # compute how many columns are needed
+            required_pos_encodings_columns = -(-sequence_length // self.axial_pos_shape[1])
+            # cut to columns that are needed
+            position_encodings = torch.cat(
+                [weight[:, :required_pos_encodings_columns] for weight in broadcasted_weights], dim=-1
+            )
+            position_encodings = torch.reshape(position_encodings, (batch_size, -1, position_encodings.shape[-1]))[
                 :, :sequence_length
             ]
...
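In short: the old code concatenated the fully broadcasted axial weights into a tensor covering the entire axial_pos_shape grid and only then sliced it to sequence_length, so peak memory scaled with the maximum supported sequence length. The new code first computes, via ceiling division, how many grid columns the input actually spans and slices each broadcasted weight before concatenating. Below is a minimal, self-contained sketch of that trick; the shapes (axial_pos_shape = (8, 16), per-axis dims (32, 96), sequence_length = 40) are made up for illustration and are not values from the library.

```python
import torch

# Sketch only, not the library code: stand-ins for the broadcasted axial weights,
# each of shape (batch, axial_pos_shape[0], axial_pos_shape[1], dim_i).
batch_size = 2
axial_pos_shape = (8, 16)      # the axial grid covers 8 * 16 = 128 positions
dims = (32, 96)                # per-axis embedding dims, summing to the hidden size
sequence_length = 40           # actual input length, well below 128

broadcasted_weights = [torch.randn(batch_size, *axial_pos_shape, dim) for dim in dims]

# Ceiling division (-(-a // b)): number of grid columns the sequence actually spans.
required_pos_encodings_columns = -(-sequence_length // axial_pos_shape[1])  # -> 3

# Slice each weight to the needed columns *before* concatenating, so the
# intermediate tensor covers only ~3 * 16 = 48 positions instead of all 128.
position_encodings = torch.cat(
    [weight[:, :required_pos_encodings_columns] for weight in broadcasted_weights], dim=-1
)
position_encodings = torch.reshape(
    position_encodings, (batch_size, -1, position_encodings.shape[-1])
)[:, :sequence_length]

print(position_encodings.shape)  # torch.Size([2, 40, 128])
```

The saving matters for Reformer because the axial grid is sized for very long maximum sequences, while typical inputs use only a small fraction of it.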