"lib/bindings/git@developer.sourcefind.cn:OpenDAS/dynamo.git" did not exist on "c70de37fcb50559ecd051df140db38250077c245"
Commit f63f2f6e authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Move TorchScript utils

parent d3acabd1
...@@ -3,6 +3,11 @@ from typing import Optional, Sequence ...@@ -3,6 +3,11 @@ from typing import Optional, Sequence
import torch import torch
import torch.nn as nn import torch.nn as nn
from openfold.model.primitives import Attention, GlobalAttention
def script_primitives_(model):
script_submodules_(model, [Attention, GlobalAttention])
def script_submodules_( def script_submodules_(
model: nn.Module, model: nn.Module,
...@@ -13,12 +18,14 @@ def script_submodules_( ...@@ -13,12 +18,14 @@ def script_submodules_(
list to recursively scripted equivalents in place. To script the entire list to recursively scripted equivalents in place. To script the entire
model, just call torch.jit.script on it directly. model, just call torch.jit.script on it directly.
When types is None, all submodules are scripted.
Args: Args:
model: A torch.nn.Module model: A torch.nn.Module
types: A list of types of submodules to script types: A list of types of submodules to script
""" """
for name, child in model.named_children(): for name, child in model.named_children():
if(any(isinstance(child, t) for t in types)): if(types is None or any(isinstance(child, t) for t in types)):
setattr(model, name, torch.jit.script(child)) setattr(model, name, torch.jit.script(child))
else: else:
script_submodules_(child, types) script_submodules_(child, types)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment