# wrapper.py
import torch.nn as nn
from transformers.models.opt.modeling_opt import OPTAttention

from .opt_attn import XOPTAttention


def convert_to_xformer_model(model: nn.Module) -> nn.Module:
    """Patch *model* in place so OPT attention layers use the xformers variant.

    Walks every submodule and reassigns the class of each ``OPTAttention``
    instance to ``XOPTAttention``.  Instance state (weights, buffers) is
    untouched — only method dispatch changes.

    Args:
        model: Any ``nn.Module`` tree, typically an OPT model.

    Returns:
        The same ``model`` object, mutated in place.
    """
    attention_modules = (
        m for m in model.modules() if isinstance(m, OPTAttention)
    )
    for attn in attention_modules:
        attn.__class__ = XOPTAttention
    return model


def recover_from_xformer_model(model: nn.Module) -> nn.Module:
    """Undo :func:`convert_to_xformer_model`, restoring stock OPT attention.

    Walks every submodule and reassigns the class of each ``XOPTAttention``
    instance back to ``OPTAttention``.  Instance state is untouched — only
    method dispatch changes.

    Args:
        model: Any ``nn.Module`` tree previously converted.

    Returns:
        The same ``model`` object, mutated in place.
    """
    patched_modules = (
        m for m in model.modules() if isinstance(m, XOPTAttention)
    )
    for attn in patched_modules:
        attn.__class__ = OPTAttention
    return model