import json
import logging
import os
from pathlib import Path

import torch
from accelerate import init_empty_weights
from torch import nn
from transformers import T5Config, T5EncoderModel

from ...utils import load_state_dict_in_safetensors
from .linear import W4Linear

# Logging level is taken from the LOG_LEVEL environment variable; unknown
# or unset values fall back to INFO.
log_level = os.environ.get("LOG_LEVEL", "INFO").upper()

# Configure root logging once at import time, then grab a module logger.
logging.basicConfig(
    level=getattr(logging, log_level, logging.INFO),
    format="%(asctime)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)


class NunchakuT5EncoderModel(T5EncoderModel):
    """T5 encoder whose Linear layers are loaded as 4-bit (W4) quantized modules.

    The checkpoint is a single safetensors file whose metadata carries the
    model config as a JSON string, and whose state dict contains packed
    ``<name>.qweight`` tensors for each quantized linear layer.
    """

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str | os.PathLike[str], **kwargs):
        """Build a T5 encoder and load (partially quantized) weights.

        Parameters
        ----------
        pretrained_model_name_or_path:
            Path to the safetensors checkpoint.
        **kwargs:
            ``torch_dtype`` (default ``torch.bfloat16``) — dtype for the
            non-quantized parameters; ``device`` (default ``"cuda"``) —
            device the materialized model is moved to.

        Returns
        -------
        T5EncoderModel
            The encoder in eval mode with quantized Linear layers swapped
            for :class:`W4Linear`.
        """
        pretrained_model_name_or_path = Path(pretrained_model_name_or_path)
        state_dict, metadata = load_state_dict_in_safetensors(pretrained_model_name_or_path, return_metadata=True)

        # The model config is embedded in the safetensors metadata as JSON.
        config = json.loads(metadata["config"])
        config = T5Config(**config)

        # Initialize model on 'meta' device (no memory allocation for weights)
        with init_empty_weights():
            t5_encoder = T5EncoderModel(config).to(kwargs.get("torch_dtype", torch.bfloat16))

        t5_encoder.eval()

        # Walk the module tree; any Linear that has a packed `qweight` in the
        # checkpoint is replaced by a W4Linear shell (init_only=True allocates
        # no real data — weights are materialized by to_empty/load_state_dict).
        named_modules = {}
        for name, module in t5_encoder.named_modules():
            assert isinstance(name, str)
            if isinstance(module, nn.Linear):
                if f"{name}.qweight" in state_dict:
                    logger.debug(f"Switching {name} to W4Linear")
                    qmodule = W4Linear.from_linear(module, group_size=128, init_only=True)
                    # modeling_t5.py: T5DenseGatedActDense needs dtype of weight
                    qmodule.weight = torch.empty([1], dtype=module.weight.dtype, device=module.weight.device)

                    if "." in name:
                        # Parents are visited before children by named_modules(),
                        # so the parent is already recorded.
                        parent_name, child_name = name.rsplit(".", 1)
                        setattr(named_modules[parent_name], child_name, qmodule)
                    else:
                        # Degenerate case: a quantized Linear directly on the root.
                        setattr(t5_encoder, name, qmodule)
            else:
                named_modules[name] = module

        device = kwargs.get("device", "cuda")
        if isinstance(device, str):
            device = torch.device(device)
        # Materialize the meta-device parameters on the target device, then
        # load the real values from the checkpoint.
        t5_encoder.to_empty(device=device)
        t5_encoder.load_state_dict(state_dict, strict=True)

        return t5_encoder