"...git@developer.sourcefind.cn:OpenDAS/openpcdet.git" did not exist on "72d587d4c603da1d15f4f53d5d861f1a95bc67ad"
Commit f63f2f6e authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Move TorchScript utils

parent d3acabd1
......@@ -3,6 +3,11 @@ from typing import Optional, Sequence
import torch
import torch.nn as nn
from openfold.model.primitives import Attention, GlobalAttention
def script_primitives_(model):
script_submodules_(model, [Attention, GlobalAttention])
def script_submodules_(
model: nn.Module,
......@@ -13,12 +18,14 @@ def script_submodules_(
list to recursively scripted equivalents in place. To script the entire
model, just call torch.jit.script on it directly.
When types is None, all submodules are scripted.
Args:
model: A torch.nn.Module
types: A list of types of submodules to script
"""
for name, child in model.named_children():
if(any(isinstance(child, t) for t in types)):
if(types is None or any(isinstance(child, t) for t in types)):
setattr(model, name, torch.jit.script(child))
else:
script_submodules_(child, types)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment