Commit 576174f0 authored by Gustaf Ahdritz

Fix bug in MSA chunking code

parent de60b410
@@ -12,7 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from functools import partial
 import math
 import torch
 import torch.nn as nn
@@ -93,13 +93,16 @@ class MSAAttention(nn.Module):
         use_memory_efficient_kernel: bool,
         chunk_size: int,
     ) -> torch.Tensor:
+        mha = partial(
+            self.mha,
+            use_memory_efficient_kernel=use_memory_efficient_kernel
+        )
         return chunk_layer(
-            self.mha,
+            mha,
             {
                 "q_x": m,
                 "kv_x": m,
                 "biases": biases,
-                "use_memory_efficient_kernel": use_memory_efficient_kernel,
             },
             chunk_size=chunk_size,
             no_batch_dims=len(m.shape[:-2])
...
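What the change does, sketched below: before this commit, use_memory_efficient_kernel was passed inside the inputs dict handed to chunk_layer, which expects tensor inputs it can slice along the batch dimensions; a plain bool flag has nothing to slice, so it has to be bound to the callable instead. The commit binds it with functools.partial and keeps the inputs dict tensor-only. The following is a minimal, hypothetical illustration of that pattern; chunk_inputs and the toy attention function are simplified stand-ins, not OpenFold's actual chunk_layer or MSAAttention.mha.

from functools import partial
from typing import Callable, Dict

import torch


def chunk_inputs(
    fn: Callable, inputs: Dict[str, torch.Tensor], chunk_size: int
) -> torch.Tensor:
    # Slice every entry of `inputs` along dim 0, run fn on each slice, and
    # concatenate the results. A plain bool in `inputs` would break the
    # indexing below, which is why flags are bound to `fn` beforehand.
    n = next(iter(inputs.values())).shape[0]
    out = []
    for start in range(0, n, chunk_size):
        chunk = {k: v[start:start + chunk_size] for k, v in inputs.items()}
        out.append(fn(**chunk))
    return torch.cat(out, dim=0)


def attention(
    q_x: torch.Tensor, kv_x: torch.Tensor, use_memory_efficient_kernel: bool = False
) -> torch.Tensor:
    # Toy stand-in for MSAAttention.mha; the flag would normally select a
    # different attention kernel, but here it is only carried along to show
    # that it survives chunking unchanged.
    scores = torch.softmax(q_x @ kv_x.transpose(-1, -2), dim=-1)
    return scores @ kv_x


# Bind the non-tensor flag to the callable, then chunk only the tensors.
m = torch.randn(8, 16, 32)
fn = partial(attention, use_memory_efficient_kernel=True)
out = chunk_inputs(fn, {"q_x": m, "kv_x": m}, chunk_size=4)
print(out.shape)  # torch.Size([8, 16, 32])

After the fix, every chunk call sees the same pre-bound flag, while only the tensor arguments (q_x, kv_x, biases) are split into chunks.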