utils.py 1.3 KB
Newer Older
Muyang Li's avatar
Muyang Li committed
1
2
import typing as tp

3
4
import torch

Muyang Li's avatar
Muyang Li committed
5
from ...utils import ceil_divide, load_state_dict_in_safetensors
6
7


Muyang Li's avatar
Muyang Li committed
8
def is_nunchaku_format(lora: str | dict[str, torch.Tensor]) -> bool:
9
10
11
12
13
14
    if isinstance(lora, str):
        tensors = load_state_dict_in_safetensors(lora, device="cpu")
    else:
        tensors = lora

    for k in tensors.keys():
Muyang Li's avatar
Muyang Li committed
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
        if ".mlp_fc" in k or "mlp_context_fc1" in k:
            return True
    return False


def pad(
    tensor: tp.Optional[torch.Tensor],
    divisor: int | tp.Sequence[int],
    dim: int | tp.Sequence[int],
    fill_value: float | int = 0,
) -> torch.Tensor | None:
    if isinstance(divisor, int):
        if divisor <= 1:
            return tensor
    elif all(d <= 1 for d in divisor):
        return tensor
    if tensor is None:
        return None
    shape = list(tensor.shape)
    if isinstance(dim, int):
        assert isinstance(divisor, int)
        shape[dim] = ceil_divide(shape[dim], divisor) * divisor
    else:
        if isinstance(divisor, int):
            divisor = [divisor] * len(dim)
        for d, div in zip(dim, divisor, strict=True):
            shape[d] = ceil_divide(shape[d], div) * div
    result = torch.full(shape, fill_value, dtype=tensor.dtype, device=tensor.device)
    result[[slice(0, extent) for extent in tensor.shape]] = tensor
    return result