Commit 576174f0 authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Fix bug in MSA chunking code

parent de60b410
......@@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
import math
import torch
import torch.nn as nn
......@@ -93,13 +93,16 @@ class MSAAttention(nn.Module):
use_memory_efficient_kernel: bool,
chunk_size: int,
) -> torch.Tensor:
mha = partial(
self.mha,
use_memory_efficient_kernel=use_memory_efficient_kernel
)
return chunk_layer(
self.mha,
mha,
{
"q_x": m,
"kv_x": m,
"biases": biases,
"use_memory_efficient_kernel": use_memory_efficient_kernel,
},
chunk_size=chunk_size,
no_batch_dims=len(m.shape[:-2])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment