Commit ceef010a authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Add offloading sanity checks

parent 4b410596
......@@ -12,8 +12,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import sys
import torch
import torch.nn as nn
from typing import Tuple, Sequence, Optional
......@@ -208,6 +208,7 @@ class EvoformerBlockCore(nn.Module):
if(_offload_inference and inplace_safe):
del m, z
assert(sys.getrefcount(input_tensors[1]) == 2)
input_tensors[1] = input_tensors[1].cpu()
torch.cuda.empty_cache()
m, z = input_tensors
......@@ -218,6 +219,7 @@ class EvoformerBlockCore(nn.Module):
if(_offload_inference and inplace_safe):
del m, z
assert(sys.getrefcount(input_tensors[0]) == 2)
input_tensors[0] = input_tensors[0].cpu()
input_tensors[1] = input_tensors[1].to(opm.device)
m, z = input_tensors
......@@ -300,6 +302,8 @@ class EvoformerBlockCore(nn.Module):
if(_offload_inference and inplace_safe):
device = z.device
del m, z
assert(sys.getrefcount(input_tensors[0]) == 2)
assert(sys.getrefcount(input_tensors[1]) == 2)
input_tensors[0] = input_tensors[0].to(device)
input_tensors[1] = input_tensors[1].to(device)
m, z = input_tensors
......
......@@ -15,6 +15,7 @@
from functools import reduce
import importlib
import math
import sys
from operator import mul
import torch
......@@ -307,6 +308,7 @@ class InvariantPointAttention(nn.Module):
b = self.linear_b(z[0])
if(_offload_inference):
assert(sys.getrefcount(z[0]) == 2)
z[0] = z[0].cpu()
# [*, H, N_res, N_res]
......@@ -651,6 +653,7 @@ class StructureModule(nn.Module):
z_reference_list = None
if(_offload_inference):
assert(sys.getrefcount(evoformer_output_dict["pair"]) == 2)
evoformer_output_dict["pair"] = evoformer_output_dict["pair"].cpu()
z_reference_list = [z]
z = None
......
......@@ -14,6 +14,7 @@
# limitations under the License.
from functools import partial
import math
import sys
from typing import Optional, List
import torch
......@@ -471,6 +472,8 @@ def embed_templates_offload(
_mask_trans=model.config._mask_trans,
)
assert(sys.getrefcount(t) == 2)
pair_embeds_cpu.append(t.cpu())
del t
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment