Unverified commit 1892a3db authored by yihuiwen, committed by GitHub

support gguf (#510)


Co-authored-by: yihuiwen <yihuiwen@sensetime.com>
parent 58c9abb9
{
"infer_steps": 4,
"target_video_length": 81,
"target_height": 480,
"target_width": 832,
"self_attn_1_type": "sage_attn2",
"cross_attn_1_type": "sage_attn2",
"cross_attn_2_type": "sage_attn2",
"sample_guide_scale": 5,
"sample_shift": 5,
"enable_cfg": true,
"cpu_offload": true,
"offload_granularity": "model",
"dit_quantized": true,
"dit_quant_scheme": "gguf-Q4_K_S"
}
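For orientation, the new keys here are dit_quantized and dit_quant_scheme: a value with a "gguf-" prefix routes DiT checkpoint loading through the GGUF reader, and the suffix (Q4_K_S in this config) is both the registered weight-class name and the substring used to pick the right *.gguf file. A minimal sketch of that string handling, mirroring the WanModel changes further down (the model directory below is a placeholder):

import glob
import os

scheme = "gguf-Q4_K_S"                      # value of dit_quant_scheme above
gguf_type = scheme.replace("gguf-", "")     # "Q4_K_S", matched against *.gguf filenames
model_dir = "/path/to/Wan2.1-I2V-14B-480P"  # placeholder directory
candidates = [p for p in glob.glob(os.path.join(model_dir, "*.gguf")) if gguf_type in p]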
......@@ -4,6 +4,8 @@ from abc import ABCMeta, abstractmethod
import torch
from lightx2v.utils.envs import *
from lightx2v.utils.ggml_tensor import GGMLTensor
from lightx2v.utils.ggml_tensor import dequantize_tensor as gguf_dequantize_tensor
from lightx2v.utils.global_paras import CALIB
from lightx2v.utils.quant_utils import FloatQuantizer, IntegerQuantizer
from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER
......@@ -969,21 +971,159 @@ class MMWeightWint8channelAint8channeldynamicSglActVllm(MMWeightQuantTemplate):
return output_tensor
class MMWeightGGUFTemplate(MMWeightQuantTemplate):
TORCH_COMPATIBLE_QTYPES = (None, gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16)
class MMWeightGGUFTemplate(MMWeightTemplate):
def __init__(self, weight_name, bias_name, create_cuda_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
super().__init__(weight_name, bias_name, create_cuda_buffer, lazy_load, lazy_load_file, is_post_adapter)
def dequantize_func(self):
# TODO: implement dequantize_func
pass
def load(self, weight_dict):
assert not self.create_cuda_buffer, "GGUF weights do not support preallocated CUDA offload buffers"
self.weight = weight_dict[self.weight_name]
weight_shape = self.weight.shape
weight_dtype = self.weight.dtype
@MM_WEIGHT_REGISTER("W-gguf-Q4_K")
class MMWeightGGUFQ4K(MMWeightGGUFTemplate):
def __init__(self, weight_name, bias_name, create_cuda_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
super().__init__(weight_name, bias_name, create_cuda_buffer, lazy_load, lazy_load_file, is_post_adapter)
if isinstance(self.weight, GGMLTensor):
self.pin_weight = GGMLTensor.empty_pinned(weight_shape, orig_shape=self.weight.orig_shape, dtype=weight_dtype, gguf_type=self.weight.gguf_type)
self.pin_weight.copy_from(self.weight)
else:
self.pin_weight = torch.empty(weight_shape, pin_memory=True, dtype=weight_dtype)
self.pin_weight.copy_(weight_dict[self.weight_name])
if self.bias_name is not None:
self.bias = weight_dict[self.bias_name]
if isinstance(self.bias, GGMLTensor):
self.pin_bias = GGMLTensor.empty_pinned(self.bias.shape, orig_shape=self.bias.orig_shape, dtype=self.bias.dtype, gguf_type=self.bias.gguf_type)
self.pin_bias.copy_from(self.bias)
else:
self.pin_bias = torch.empty(self.bias.shape, pin_memory=True, dtype=self.bias.dtype)
self.pin_bias.copy_(weight_dict[self.bias_name])
else:
self.bias = None
def load_state_dict(self, destination, block_index, adapter_block_index=None):
if self.is_post_adapter:
assert adapter_block_index is not None
weight_name = re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", self.weight_name, count=1)
else:
weight_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.weight_name, count=1)
if weight_name not in destination:
self.weight = None
return
self.weight = self.weight_cuda_buffer.copy_(destination[weight_name], non_blocking=True)
if self.bias_name is not None:
if self.is_post_adapter:
assert adapter_block_index is not None
bias_name = re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", self.bias_name, count=1)
else:
bias_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.bias_name, count=1)
self.bias = self.bias_cuda_buffer.copy_(destination[bias_name], non_blocking=True)
else:
self.bias = None
def state_dict(self, destination=None):
if destination is None:
destination = {}
destination[self.weight_name] = self.pin_weight if hasattr(self, "pin_weight") else self.weight
if self.bias_name is not None:
destination[self.bias_name] = self.pin_bias if hasattr(self, "pin_bias") else self.bias
return destination
def get_weight(self, tensor, dtype):
if tensor is None:
return
device = tensor.device
weight = gguf_dequantize_tensor(tensor, dtype)
# prevent propagating custom tensor class
if isinstance(weight, GGMLTensor):
weight = torch.Tensor(weight)
return weight
def cast_bias_weight(self, input_tensor=None, dtype=None, device=None, bias_dtype=None):
if input_tensor is not None:
if dtype is None:
dtype = getattr(input_tensor, "dtype", torch.float32)
bias = None
if self.bias is not None:
bias = self.get_weight(self.bias, dtype)
weight = self.get_weight(self.weight, dtype)
return weight, bias
def apply(self, input_tensor):
weight, bias = self.cast_bias_weight(input_tensor)
return torch.nn.functional.linear(input_tensor, weight, bias)
@MM_WEIGHT_REGISTER("gguf-BF16")
class MMWeightGGUFBF16(MMWeightGGUFTemplate):
qtype = gguf.GGMLQuantizationType.BF16
@MM_WEIGHT_REGISTER("gguf-Q8_0")
class MMWeightGGUFQ80(MMWeightGGUFTemplate):
qtype = gguf.GGMLQuantizationType.Q8_0
@MM_WEIGHT_REGISTER("gguf-Q6_K")
class MMWeightGGUFQ6K(MMWeightGGUFTemplate):
qtype = gguf.GGMLQuantizationType.Q6_K
@MM_WEIGHT_REGISTER("gguf-Q5_K_S")
class MMWeightGGUFQ5KS(MMWeightGGUFTemplate):
qtype = gguf.GGMLQuantizationType.Q5_K
@MM_WEIGHT_REGISTER("gguf-Q5_K_M")
class MMWeightGGUFQ5KM(MMWeightGGUFTemplate):
qtype = gguf.GGMLQuantizationType.Q5_K
@MM_WEIGHT_REGISTER("gguf-Q5_1")
class MMWeightGGUFQ51(MMWeightGGUFTemplate):
qtype = gguf.GGMLQuantizationType.Q5_1
@MM_WEIGHT_REGISTER("gguf-Q5_0")
class MMWeightGGUFQ50(MMWeightGGUFTemplate):
qtype = gguf.GGMLQuantizationType.Q5_0
@MM_WEIGHT_REGISTER("gguf-Q4_K_M")
class MMWeightGGUFQ4KM(MMWeightGGUFTemplate):
qtype = gguf.GGMLQuantizationType.Q4_K
@MM_WEIGHT_REGISTER("gguf-Q4_K_S")
class MMWeightGGUFQ4KS(MMWeightGGUFTemplate):
qtype = gguf.GGMLQuantizationType.Q4_K
@MM_WEIGHT_REGISTER("gguf-Q4_1")
class MMWeightGGUFQ41(MMWeightGGUFTemplate):
qtype = gguf.GGMLQuantizationType.Q4_1
@MM_WEIGHT_REGISTER("gguf-Q4_0")
class MMWeightGGUFQ40(MMWeightGGUFTemplate):
qtype = gguf.GGMLQuantizationType.Q4_0
@MM_WEIGHT_REGISTER("gguf-Q3_K_M")
class MMWeightGGUFQ3KM(MMWeightGGUFTemplate):
qtype = gguf.GGMLQuantizationType.Q3_K
@MM_WEIGHT_REGISTER("gguf-Q3_K_S")
class MMWeightGGUFQ3KS(MMWeightGGUFTemplate):
qtype = gguf.GGMLQuantizationType.Q3_K
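A minimal usage sketch of the classes above, assuming the checkpoint has already been read with load_gguf_sd_ckpt. The path, tensor names, and hidden size are placeholders, and in practice the subclass matching dit_quant_scheme is instantiated by the weight-loading machinery rather than by hand:

import torch
from lightx2v.utils.ggml_tensor import load_gguf_sd_ckpt

weight_dict = load_gguf_sd_ckpt("/path/to/wan-dit-Q4_K_S.gguf")        # placeholder path
mm = MMWeightGGUFQ4KS("blocks.0.self_attn.q.weight", "blocks.0.self_attn.q.bias")  # placeholder names
mm.load(weight_dict)                                                    # stages pinned copies of the GGML tensors

x = torch.randn(2, 5120, dtype=torch.bfloat16)                          # placeholder hidden size
y = mm.apply(x)    # dequantizes the weight to x.dtype on the fly, then torch.nn.functional.linear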
@MM_WEIGHT_REGISTER("int4-g128-marlin")
......
......@@ -32,13 +32,9 @@ from lightx2v.models.networks.wan.weights.transformer_weights import (
)
from lightx2v.utils.custom_compiler import CompiledMethodsMixin, compiled_method
from lightx2v.utils.envs import *
from lightx2v.utils.ggml_tensor import load_gguf_sd_ckpt
from lightx2v.utils.utils import *
try:
import gguf
except ImportError:
gguf = None
class WanModel(CompiledMethodsMixin):
pre_weight_class = WanPreWeights
......@@ -76,6 +72,18 @@ class WanModel(CompiledMethodsMixin):
"mxfp6-mxfp8",
"mxfp8",
"int8-tmo",
"gguf-Q8_0",
"gguf-Q6_K",
"gguf-Q5_K_S",
"gguf-Q5_K_M",
"gguf-Q5_0",
"gguf-Q5_1",
"gguf-Q4_K_S",
"gguf-Q4_K_M",
"gguf-Q4_0",
"gguf-Q4_1",
"gguf-Q3_K_S",
"gguf-Q3_K_M",
]
self.device = device
self._init_infer_class()
......@@ -181,6 +189,17 @@ class WanModel(CompiledMethodsMixin):
else:
safetensors_path = self.model_path
if "gguf" in self.config.get("dit_quant_scheme", ""):
gguf_path = ""
if os.path.isdir(safetensors_path):
gguf_type = self.config.get("dit_quant_scheme").replace("gguf-", "")
gguf_files = list(filter(lambda x: gguf_type in x, glob.glob(os.path.join(safetensors_path, "*.gguf"))))
gguf_path = gguf_files[0]
else:
gguf_path = safetensors_path
weight_dict = self._load_gguf_ckpt(gguf_path)
return weight_dict
if os.path.isdir(safetensors_path):
safetensors_files = glob.glob(os.path.join(safetensors_path, "*.safetensors"))
else:
......@@ -192,6 +211,7 @@ class WanModel(CompiledMethodsMixin):
if self.config.get("adapter_model_path", None) is not None:
if self.config["adapter_model_path"] == safetensor_path:
continue
with safe_open(safetensor_path, framework="pt") as f:
logger.info(f"Loading weights from {safetensor_path}")
for k in f.keys():
......@@ -240,13 +260,9 @@ class WanModel(CompiledMethodsMixin):
return pre_post_weight_dict
def _load_gguf_ckpt(self):
gguf_path = self.dit_quantized_ckpt
logger.info(f"Loading gguf-quant dit model from {gguf_path}")
reader = gguf.GGUFReader(gguf_path)
for tensor in reader.tensors:
# TODO: implement _load_gguf_ckpt
pass
def _load_gguf_ckpt(self, gguf_path):
state_dict = load_gguf_sd_ckpt(gguf_path, to_device=self.device)
return state_dict
def _init_weights(self, weight_dict=None):
unified_dtype = GET_DTYPE() == GET_SENSITIVE_DTYPE()
......
from typing import Optional, Tuple, Union
import gguf
import numpy as np
import torch
from loguru import logger
TORCH_COMPATIBLE_QTYPES = (None, gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16, gguf.GGMLQuantizationType.BF16)
class GGMLTensor:
def __init__(
self,
data: Union[torch.Tensor, np.ndarray, None] = None,
orig_shape: Tuple[int, ...] = None,
dtype: torch.dtype = None,
gguf_type: gguf.GGMLQuantizationType = None,
requires_grad: bool = False,
aligned: bool = True,
pin_memory: bool = False,
preallocated: bool = False,
):
super().__init__()
assert orig_shape is not None
assert gguf_type is not None
if isinstance(data, np.ndarray):
import warnings
with warnings.catch_warnings():
warnings.filterwarnings("ignore", message="The given NumPy array is not writable")
torch_data = torch.from_numpy(data)
else:
torch_data = data
if dtype is not None and torch_data.dtype != dtype:
torch_data = torch_data.to(dtype)
self.data = torch_data
self.gguf_type = gguf_type
self._orig_shape = orig_shape
self._aligned = aligned
self._pinned_memory = pin_memory
self._requires_grad = requires_grad
self._preallocated = preallocated
self._quantized = self._is_quantized_type(gguf_type)
self._q_type = self._get_quant_type_str(gguf_type)
if aligned:
self._make_aligned()
if pin_memory:
self._pin_memory()
def _is_quantized_type(self, gguf_type: gguf.GGMLQuantizationType) -> bool:
return gguf_type not in TORCH_COMPATIBLE_QTYPES
def _get_quant_type_str(self, gguf_type: gguf.GGMLQuantizationType) -> str:
type_mapping = {
gguf.GGMLQuantizationType.F32: "ggml_f32",
gguf.GGMLQuantizationType.F16: "ggml_f16",
gguf.GGMLQuantizationType.Q4_0: "ggml_q4_0",
gguf.GGMLQuantizationType.Q4_1: "ggml_q4_1",
gguf.GGMLQuantizationType.Q5_0: "ggml_q5_0",
gguf.GGMLQuantizationType.Q5_1: "ggml_q5_1",
gguf.GGMLQuantizationType.Q8_0: "ggml_q8_0",
gguf.GGMLQuantizationType.Q8_1: "ggml_q8_1",
gguf.GGMLQuantizationType.Q2_K: "ggml_q2_k",
gguf.GGMLQuantizationType.Q3_K: "ggml_q3_k",
gguf.GGMLQuantizationType.Q4_K: "ggml_q4_k",
gguf.GGMLQuantizationType.Q5_K: "ggml_q5_k",
gguf.GGMLQuantizationType.Q6_K: "ggml_q6_k",
gguf.GGMLQuantizationType.Q8_K: "ggml_q8_k",
}
return type_mapping.get(gguf_type, "unknown")
@classmethod
def empty_pinned(
cls, shape: Tuple[int, ...], orig_shape: Tuple[int, ...] = None, dtype: torch.dtype = torch.float32, gguf_type: gguf.GGMLQuantizationType = None, aligned: bool = True
) -> "GGMLTensor":
torch_data = torch.empty(shape, pin_memory=True, dtype=dtype)
return cls(data=torch_data, dtype=dtype, orig_shape=orig_shape, gguf_type=gguf_type, pin_memory=True, aligned=aligned, preallocated=True)
@classmethod
def empty_aligned(
cls, shape: Tuple[int, ...], orig_shape: Tuple[int, ...] = None, dtype: torch.dtype = torch.float32, gguf_type: gguf.GGMLQuantizationType = None, pin_memory: bool = False
) -> "GGMLTensor":
return cls(dtype=dtype, orig_shape=orig_shape, gguf_type=gguf_type, pin_memory=pin_memory, aligned=True, preallocated=True)
def copy_from(self, source: Union[torch.Tensor, "GGMLTensor"], transpose: bool = False, non_blocking: bool = False) -> "GGMLTensor":
if not self._preallocated:
raise RuntimeError("copy_from can only be used with preallocated tensors")
if transpose:
source_data = source.data.t().contiguous()
else:
source_data = source.data.contiguous()
if self.shape != source_data.shape:
raise ValueError(f"Shape mismatch: target {self.shape} vs source {source_data.shape}")
self.data.copy_(source_data)
return self
def copy_(self, target: Union[torch.Tensor, "GGMLTensor"], transpose: bool = False, non_blocking: bool = False) -> "GGMLTensor":
# note: unlike torch.Tensor.copy_, this copies *self* into `target`
source_data = self.data
if transpose:
source_data = self.data.t().contiguous()
if isinstance(target, GGMLTensor):
target.copy_from(source_data, non_blocking=non_blocking)
else:
target.copy_(source_data)
return self
def t(self):
self.data = self.data.t()
return self
def _make_aligned(self, alignment: int = 32):
if not self.data.is_contiguous():
self.data = self.data.contiguous().data
ptr = self.data.data_ptr()
if ptr % alignment == 0:
return
if self._pinned_memory:
aligned_data = torch.empty(self.data.shape, dtype=self.data.dtype, device=self.data.device, pin_memory=True)
else:
aligned_data = torch.empty(self.data.shape, dtype=self.data.dtype, device=self.data.device)
aligned_data.copy_(self.data)
self.data = aligned_data.data
def _pin_memory(self) -> "GGMLTensor":
if self._pinned_memory or self.device.type != "cpu":
return self
pinned_data = self.data.pin_memory()
self.data = pinned_data.data
self._pinned_memory = True
return self
def to_torch(self) -> torch.Tensor:
return torch.as_tensor(self.data)
@property
def shape(self):
return self.data.shape
@property
def dtype(self):
return self.data.dtype
@property
def device(self):
return self.data.device
@property
def tensor_type(self) -> gguf.GGMLQuantizationType:
return self.gguf_type
@property
def quant_type(self) -> str:
return self._q_type
@property
def is_quantized(self) -> bool:
return self._quantized
@property
def orig_shape(self) -> Tuple[int, ...]:
return self._orig_shape
@property
def blocksize(self) -> Optional[int]:
_blocksize, _ = gguf.GGML_QUANT_SIZES[self.gguf_type]
return _blocksize
@property
def is_pinned(self) -> bool:
return self._pinned_memory
def memory_footprint(self) -> int:
# raw storage in bytes; quantized tensors hold packed uint8 data, so this is already the compressed size
return self.data.numel() * self.data.element_size()
def __repr__(self) -> str:
return f"GGMLTensor(shape={self.data.shape}, orig_shape={self.orig_shape}, dtype={self.data.dtype}, quantized={self.is_quantized}, quant_type='{self.quant_type}', pinned={self.is_pinned})"
def cuda(self, device: Optional[Union[int, torch.device]] = None, non_blocking: bool = False) -> "GGMLTensor":
if device is None:
self.data = self.data.cuda(non_blocking=non_blocking)
else:
self.data = self.data.cuda(device=device, non_blocking=non_blocking)
return self
def cpu(self, pin_memory: bool = False) -> "GGMLTensor":
self.data = self.data.cpu()
return self
def to(self, *args, **kwargs) -> "GGMLTensor":
self.data = self.data.to(*args, **kwargs)
return self
def load_gguf_sd_ckpt(gguf_path, return_arch=False, to_device: Optional[Union[int, torch.device]] = None):
import warnings
logger.info(f"Loading gguf-quant dit model from {gguf_path}")
reader = gguf.GGUFReader(gguf_path)
state_dict = {}
for tensor in reader.tensors:
tensor_name = tensor.name
with warnings.catch_warnings():
warnings.filterwarnings("ignore", message="The given NumPy array is not writable")
torch_tensor = torch.from_numpy(tensor.data) # mmap
shape = get_orig_shape(reader, tensor_name)
if shape is None:
shape = torch.Size(tuple(int(v) for v in reversed(tensor.shape)))
if tensor.tensor_type in TORCH_COMPATIBLE_QTYPES:
state_dict[tensor.name] = torch_tensor.to(to_device)
else:
state_dict[tensor.name] = GGMLTensor(
data=torch_tensor,
gguf_type=tensor.tensor_type,
orig_shape=shape,
aligned=True,
pin_memory=False,
).to(to_device)
if return_arch:
arch = get_model_architecture(reader)
return state_dict, arch
return state_dict
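Taken together, load_gguf_sd_ckpt and the GGMLTensor pinned-staging helpers reproduce the flow MMWeightGGUFTemplate.load relies on: read the quantized tensors, park them in page-locked host memory, and move them to the GPU with non-blocking copies. A minimal sketch (the checkpoint path is a placeholder; pinned allocation and .cuda() require a CUDA-enabled build):

import torch
from lightx2v.utils.ggml_tensor import GGMLTensor, load_gguf_sd_ckpt

sd = load_gguf_sd_ckpt("/path/to/wan-dit-Q4_K_S.gguf")                  # placeholder path
w = next(v for v in sd.values() if isinstance(v, GGMLTensor))           # any quantized tensor
pinned = GGMLTensor.empty_pinned(w.shape, orig_shape=w.orig_shape, dtype=w.dtype, gguf_type=w.gguf_type)
pinned.copy_from(w)                   # stage the packed bytes in pinned CPU memory
pinned.cuda(non_blocking=True)        # in-place move; async w.r.t. the host because the source is pinned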
def get_orig_shape(reader, tensor_name: str) -> Optional[Tuple[int, ...]]:
# TODO: this key needs to be replaced before the official release
field_key = f"comfy.gguf.orig_shape.{tensor_name}"
field = reader.get_field(field_key)
if field is None:
return None
# Has original shape metadata, so we try to decode it.
if len(field.types) != 2 or field.types[0] != gguf.GGUFValueType.ARRAY or field.types[1] != gguf.GGUFValueType.INT32:
raise TypeError(f"Bad original shape metadata for {field_key}: Expected ARRAY of INT32, got {field.types}")
return torch.Size(tuple(int(field.parts[part_idx][0]) for part_idx in field.data))
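For context, the comfy.gguf.orig_shape.* key follows the ComfyUI-GGUF convention for recording a tensor's pre-quantization shape when the packed GGUF layout differs from it; the TODO above notes the key should eventually be replaced. A hedged sketch of how a conversion script might emit such a field with gguf-py's GGUFWriter (output path, arch string, tensor name, and shape are all placeholders, not part of this PR):

import gguf

writer = gguf.GGUFWriter("wan-dit-Q4_K_S.gguf", arch="wan")   # placeholder output path / arch
orig_shape = (5120, 5120)                                      # placeholder pre-quantization shape
writer.add_array("comfy.gguf.orig_shape.blocks.0.self_attn.q.weight", list(orig_shape))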
def get_field(reader, field_name, field_type):
field = reader.get_field(field_name)
if field is None:
return None
elif isinstance(field_type, str):
# extra check here as this is used for checking arch string
if len(field.types) != 1 or field.types[0] != gguf.GGUFValueType.STRING:
raise TypeError(f"Bad type for GGUF {field_name} key: expected string, got {field.types!r}")
return str(field.parts[field.data[-1]], encoding="utf-8")
elif field_type in [int, float, bool]:
return field_type(field.parts[field.data[-1]])
else:
raise TypeError(f"Unknown field type {field_type}")
def get_model_architecture(reader) -> str:
arch_str = get_field(reader, "general.architecture", str)
return arch_str
def dequantize_tensor(tensor, dtype=None):
qtype = getattr(tensor, "gguf_type", None)
oshape = getattr(tensor, "orig_shape", tensor.data.shape)
if qtype in TORCH_COMPATIBLE_QTYPES:
return tensor.to(dtype)
elif qtype in dequantize_functions:
return dequantize(tensor.to_torch().data, qtype, oshape, dtype=dtype).to(dtype)
else:
# numpy fallback for quant types without a torch kernel above; this path is very slow
logger.warning(f"Falling back to numpy dequant for qtype: {qtype}")
new = gguf.quants.dequantize(tensor.cpu().numpy(), qtype)
return torch.from_numpy(new).to(tensor.device, dtype=dtype)
def dequantize(data, qtype, oshape, dtype=None):
block_size, type_size = gguf.GGML_QUANT_SIZES[qtype]
dequantize_blocks = dequantize_functions[qtype]
rows = data.reshape((-1, data.shape[-1])).view(torch.uint8)
n_blocks = rows.numel() // type_size
blocks = rows.reshape((n_blocks, type_size))
blocks = dequantize_blocks(blocks, block_size, type_size, dtype)
return blocks.reshape(oshape)
def to_uint32(x):
x = x.view(torch.uint8).to(torch.int32)
return (x[:, 0] | x[:, 1] << 8 | x[:, 2] << 16 | x[:, 3] << 24).unsqueeze(1)
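to_uint32 reassembles four little-endian bytes per row into one 32-bit value; the Q5 kernels below use it to recover the packed high-bit word. A quick check, assuming this module is importable as lightx2v.utils.ggml_tensor:

import torch
from lightx2v.utils.ggml_tensor import to_uint32

raw = torch.tensor([[0x01, 0x00, 0x00, 0x00],
                    [0x00, 0x01, 0x00, 0x00]], dtype=torch.uint8)
print(to_uint32(raw))   # tensor([[1], [256]]): byte 0 is the least significant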
def split_block_dims(blocks, *args):
n_max = blocks.shape[1]
dims = list(args) + [n_max - sum(args)]
return torch.split(blocks, dims, dim=1)
def dequantize_blocks_BF16(blocks, block_size, type_size, dtype=None):
return (blocks.view(torch.int16).to(torch.int32) << 16).view(torch.float32)
def dequantize_blocks_Q8_0(blocks, block_size, type_size, dtype=None):
d, x = split_block_dims(blocks, 2)
d = d.view(torch.float16).to(dtype)
x = x.view(torch.int8)
return d * x
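Q8_0 is the simplest packed layout: each 34-byte block is a float16 scale d followed by 32 int8 values, dequantized as d * x. A round-trip sanity check against the numpy reference in gguf-py (assuming a gguf-py version that also implements quantize for Q8_0; the array shape is arbitrary). The same pattern covers the K-quant kernels below by switching qtype:

import gguf
import numpy as np
import torch
from lightx2v.utils.ggml_tensor import dequantize

x = np.random.randn(4, 64).astype(np.float32)
qtype = gguf.GGMLQuantizationType.Q8_0
packed = gguf.quants.quantize(x, qtype)                  # uint8, two 34-byte blocks per row
ref = gguf.quants.dequantize(packed, qtype)              # numpy reference dequantization

out = dequantize(torch.from_numpy(packed), qtype, oshape=(4, 64), dtype=torch.float32)
assert torch.allclose(out, torch.from_numpy(ref))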
def dequantize_blocks_Q5_1(blocks, block_size, type_size, dtype=None):
n_blocks = blocks.shape[0]
d, m, qh, qs = split_block_dims(blocks, 2, 2, 4)
d = d.view(torch.float16).to(dtype)
m = m.view(torch.float16).to(dtype)
qh = to_uint32(qh)
qh = qh.reshape((n_blocks, 1)) >> torch.arange(32, device=d.device, dtype=torch.int32).reshape(1, 32)
ql = qs.reshape((n_blocks, -1, 1, block_size // 2)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape(1, 1, 2, 1)
qh = (qh & 1).to(torch.uint8)
ql = (ql & 0x0F).reshape((n_blocks, -1))
qs = ql | (qh << 4)
return (d * qs) + m
def dequantize_blocks_Q5_0(blocks, block_size, type_size, dtype=None):
n_blocks = blocks.shape[0]
d, qh, qs = split_block_dims(blocks, 2, 4)
d = d.view(torch.float16).to(dtype)
qh = to_uint32(qh)
qh = qh.reshape(n_blocks, 1) >> torch.arange(32, device=d.device, dtype=torch.int32).reshape(1, 32)
ql = qs.reshape(n_blocks, -1, 1, block_size // 2) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape(1, 1, 2, 1)
qh = (qh & 1).to(torch.uint8)
ql = (ql & 0x0F).reshape(n_blocks, -1)
qs = (ql | (qh << 4)).to(torch.int8) - 16
return d * qs
def dequantize_blocks_Q4_1(blocks, block_size, type_size, dtype=None):
n_blocks = blocks.shape[0]
d, m, qs = split_block_dims(blocks, 2, 2)
d = d.view(torch.float16).to(dtype)
m = m.view(torch.float16).to(dtype)
qs = qs.reshape((n_blocks, -1, 1, block_size // 2)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape(1, 1, 2, 1)
qs = (qs & 0x0F).reshape(n_blocks, -1)
return (d * qs) + m
def dequantize_blocks_Q4_0(blocks, block_size, type_size, dtype=None):
n_blocks = blocks.shape[0]
d, qs = split_block_dims(blocks, 2)
d = d.view(torch.float16).to(dtype)
qs = qs.reshape((n_blocks, -1, 1, block_size // 2)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape((1, 1, 2, 1))
qs = (qs & 0x0F).reshape((n_blocks, -1)).to(torch.int8) - 8
return d * qs
# K Quants #
QK_K = 256
K_SCALE_SIZE = 12
def get_scale_min(scales):
n_blocks = scales.shape[0]
scales = scales.view(torch.uint8)
scales = scales.reshape((n_blocks, 3, 4))
d, m, m_d = torch.split(scales, scales.shape[-2] // 3, dim=-2)
sc = torch.cat([d & 0x3F, (m_d & 0x0F) | ((d >> 2) & 0x30)], dim=-1)
min = torch.cat([m & 0x3F, (m_d >> 4) | ((m >> 2) & 0x30)], dim=-1)
return (sc.reshape((n_blocks, 8)), min.reshape((n_blocks, 8)))
def dequantize_blocks_Q6_K(blocks, block_size, type_size, dtype=None):
n_blocks = blocks.shape[0]
(
ql,
qh,
scales,
d,
) = split_block_dims(blocks, QK_K // 2, QK_K // 4, QK_K // 16)
scales = scales.view(torch.int8).to(dtype)
d = d.view(torch.float16).to(dtype)
d = (d * scales).reshape((n_blocks, QK_K // 16, 1))
ql = ql.reshape((n_blocks, -1, 1, 64)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape((1, 1, 2, 1))
ql = (ql & 0x0F).reshape((n_blocks, -1, 32))
qh = qh.reshape((n_blocks, -1, 1, 32)) >> torch.tensor([0, 2, 4, 6], device=d.device, dtype=torch.uint8).reshape((1, 1, 4, 1))
qh = (qh & 0x03).reshape((n_blocks, -1, 32))
q = (ql | (qh << 4)).to(torch.int8) - 32
q = q.reshape((n_blocks, QK_K // 16, -1))
return (d * q).reshape((n_blocks, QK_K))
def dequantize_blocks_Q5_K(blocks, block_size, type_size, dtype=None):
n_blocks = blocks.shape[0]
d, dmin, scales, qh, qs = split_block_dims(blocks, 2, 2, K_SCALE_SIZE, QK_K // 8)
d = d.view(torch.float16).to(dtype)
dmin = dmin.view(torch.float16).to(dtype)
sc, m = get_scale_min(scales)
d = (d * sc).reshape((n_blocks, -1, 1))
dm = (dmin * m).reshape((n_blocks, -1, 1))
ql = qs.reshape((n_blocks, -1, 1, 32)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape((1, 1, 2, 1))
qh = qh.reshape((n_blocks, -1, 1, 32)) >> torch.tensor([i for i in range(8)], device=d.device, dtype=torch.uint8).reshape((1, 1, 8, 1))
ql = (ql & 0x0F).reshape((n_blocks, -1, 32))
qh = (qh & 0x01).reshape((n_blocks, -1, 32))
q = ql | (qh << 4)
return (d * q - dm).reshape((n_blocks, QK_K))
def dequantize_blocks_Q4_K(blocks, block_size, type_size, dtype=None):
n_blocks = blocks.shape[0]
d, dmin, scales, qs = split_block_dims(blocks, 2, 2, K_SCALE_SIZE)
d = d.view(torch.float16).to(dtype)
dmin = dmin.view(torch.float16).to(dtype)
sc, m = get_scale_min(scales)
d = (d * sc).reshape((n_blocks, -1, 1))
dm = (dmin * m).reshape((n_blocks, -1, 1))
qs = qs.reshape((n_blocks, -1, 1, 32)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape((1, 1, 2, 1))
qs = (qs & 0x0F).reshape((n_blocks, -1, 32))
return (d * qs - dm).reshape((n_blocks, QK_K))
def dequantize_blocks_Q3_K(blocks, block_size, type_size, dtype=None):
n_blocks = blocks.shape[0]
hmask, qs, scales, d = split_block_dims(blocks, QK_K // 8, QK_K // 4, 12)
d = d.view(torch.float16).to(dtype)
lscales, hscales = scales[:, :8], scales[:, 8:]
lscales = lscales.reshape((n_blocks, 1, 8)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape((1, 2, 1))
lscales = lscales.reshape((n_blocks, 16))
hscales = hscales.reshape((n_blocks, 1, 4)) >> torch.tensor([0, 2, 4, 6], device=d.device, dtype=torch.uint8).reshape((1, 4, 1))
hscales = hscales.reshape((n_blocks, 16))
scales = (lscales & 0x0F) | ((hscales & 0x03) << 4)
scales = scales.to(torch.int8) - 32
dl = (d * scales).reshape((n_blocks, 16, 1))
ql = qs.reshape((n_blocks, -1, 1, 32)) >> torch.tensor([0, 2, 4, 6], device=d.device, dtype=torch.uint8).reshape((1, 1, 4, 1))
qh = hmask.reshape(n_blocks, -1, 1, 32) >> torch.tensor([i for i in range(8)], device=d.device, dtype=torch.uint8).reshape((1, 1, 8, 1))
ql = ql.reshape((n_blocks, 16, QK_K // 16)) & 3
qh = (qh.reshape((n_blocks, 16, QK_K // 16)) & 1) ^ 1
q = ql.to(torch.int8) - (qh << 2).to(torch.int8)
return (dl * q).reshape((n_blocks, QK_K))
def dequantize_blocks_Q2_K(blocks, block_size, type_size, dtype=None):
n_blocks = blocks.shape[0]
scales, qs, d, dmin = split_block_dims(blocks, QK_K // 16, QK_K // 4, 2)
d = d.view(torch.float16).to(dtype)
dmin = dmin.view(torch.float16).to(dtype)
# (n_blocks, 16, 1)
dl = (d * (scales & 0xF)).reshape((n_blocks, QK_K // 16, 1))
ml = (dmin * (scales >> 4)).reshape((n_blocks, QK_K // 16, 1))
shift = torch.tensor([0, 2, 4, 6], device=d.device, dtype=torch.uint8).reshape((1, 1, 4, 1))
qs = (qs.reshape((n_blocks, -1, 1, 32)) >> shift) & 3
qs = qs.reshape((n_blocks, QK_K // 16, 16))
qs = dl * qs - ml
return qs.reshape((n_blocks, -1))
dequantize_functions = {
gguf.GGMLQuantizationType.BF16: dequantize_blocks_BF16,
gguf.GGMLQuantizationType.Q8_0: dequantize_blocks_Q8_0,
gguf.GGMLQuantizationType.Q5_1: dequantize_blocks_Q5_1,
gguf.GGMLQuantizationType.Q5_0: dequantize_blocks_Q5_0,
gguf.GGMLQuantizationType.Q4_1: dequantize_blocks_Q4_1,
gguf.GGMLQuantizationType.Q4_0: dequantize_blocks_Q4_0,
gguf.GGMLQuantizationType.Q6_K: dequantize_blocks_Q6_K,
gguf.GGMLQuantizationType.Q5_K: dequantize_blocks_Q5_K,
gguf.GGMLQuantizationType.Q4_K: dequantize_blocks_Q4_K,
gguf.GGMLQuantizationType.Q3_K: dequantize_blocks_Q3_K,
gguf.GGMLQuantizationType.Q2_K: dequantize_blocks_Q2_K,
}
if __name__ == "__main__":
sd = load_gguf_sd_ckpt("/home/SENSETIME/yihuiwen/yihuiwen/workspace/models/city96/Wan2.1-I2V-14B-720P-gguf/wan2.1-i2v-14b-720p-Q4_K_S.gguf", return_arch=False)
for k, s in sd.items():
print(k)
if isinstance(s, GGMLTensor):
dequantize_tensor(s, torch.float32)
#!/bin/bash
# set lightx2v_path and model_path first
lightx2v_path=
model_path=
export CUDA_VISIBLE_DEVICES=0
# set environment variables
source ${lightx2v_path}/scripts/base/base.sh
python -m lightx2v.infer \
--model_cls wan2.1 \
--task i2v \
--model_path $model_path \
--config_json ${lightx2v_path}/configs/quantization/gguf/wan_i2v_q4_k.json \
--prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside." \
--negative_prompt "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \
--image_path ${lightx2v_path}/assets/inputs/imgs/img_0.jpg \
--save_result_path ${lightx2v_path}/save_results/output_lightx2v_wan_i2v.mp4