Commit ceef010a authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Add offloading sanity checks

parent 4b410596
...@@ -12,8 +12,8 @@ ...@@ -12,8 +12,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import math import math
import sys
import torch import torch
import torch.nn as nn import torch.nn as nn
from typing import Tuple, Sequence, Optional from typing import Tuple, Sequence, Optional
...@@ -208,6 +208,7 @@ class EvoformerBlockCore(nn.Module): ...@@ -208,6 +208,7 @@ class EvoformerBlockCore(nn.Module):
if(_offload_inference and inplace_safe): if(_offload_inference and inplace_safe):
del m, z del m, z
assert(sys.getrefcount(input_tensors[1]) == 2)
input_tensors[1] = input_tensors[1].cpu() input_tensors[1] = input_tensors[1].cpu()
torch.cuda.empty_cache() torch.cuda.empty_cache()
m, z = input_tensors m, z = input_tensors
...@@ -218,6 +219,7 @@ class EvoformerBlockCore(nn.Module): ...@@ -218,6 +219,7 @@ class EvoformerBlockCore(nn.Module):
if(_offload_inference and inplace_safe): if(_offload_inference and inplace_safe):
del m, z del m, z
assert(sys.getrefcount(input_tensors[0]) == 2)
input_tensors[0] = input_tensors[0].cpu() input_tensors[0] = input_tensors[0].cpu()
input_tensors[1] = input_tensors[1].to(opm.device) input_tensors[1] = input_tensors[1].to(opm.device)
m, z = input_tensors m, z = input_tensors
...@@ -300,6 +302,8 @@ class EvoformerBlockCore(nn.Module): ...@@ -300,6 +302,8 @@ class EvoformerBlockCore(nn.Module):
if(_offload_inference and inplace_safe): if(_offload_inference and inplace_safe):
device = z.device device = z.device
del m, z del m, z
assert(sys.getrefcount(input_tensors[0]) == 2)
assert(sys.getrefcount(input_tensors[1]) == 2)
input_tensors[0] = input_tensors[0].to(device) input_tensors[0] = input_tensors[0].to(device)
input_tensors[1] = input_tensors[1].to(device) input_tensors[1] = input_tensors[1].to(device)
m, z = input_tensors m, z = input_tensors
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
from functools import reduce from functools import reduce
import importlib import importlib
import math import math
import sys
from operator import mul from operator import mul
import torch import torch
...@@ -307,6 +308,7 @@ class InvariantPointAttention(nn.Module): ...@@ -307,6 +308,7 @@ class InvariantPointAttention(nn.Module):
b = self.linear_b(z[0]) b = self.linear_b(z[0])
if(_offload_inference): if(_offload_inference):
assert(sys.getrefcount(z[0]) == 2)
z[0] = z[0].cpu() z[0] = z[0].cpu()
# [*, H, N_res, N_res] # [*, H, N_res, N_res]
...@@ -651,6 +653,7 @@ class StructureModule(nn.Module): ...@@ -651,6 +653,7 @@ class StructureModule(nn.Module):
z_reference_list = None z_reference_list = None
if(_offload_inference): if(_offload_inference):
assert(sys.getrefcount(evoformer_output_dict["pair"]) == 2)
evoformer_output_dict["pair"] = evoformer_output_dict["pair"].cpu() evoformer_output_dict["pair"] = evoformer_output_dict["pair"].cpu()
z_reference_list = [z] z_reference_list = [z]
z = None z = None
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
# limitations under the License. # limitations under the License.
from functools import partial from functools import partial
import math import math
import sys
from typing import Optional, List from typing import Optional, List
import torch import torch
...@@ -471,6 +472,8 @@ def embed_templates_offload( ...@@ -471,6 +472,8 @@ def embed_templates_offload(
_mask_trans=model.config._mask_trans, _mask_trans=model.config._mask_trans,
) )
assert(sys.getrefcount(t) == 2)
pair_embeds_cpu.append(t.cpu()) pair_embeds_cpu.append(t.cpu())
del t del t
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment