import json from dataclasses import dataclass, replace from typing import Callable, NamedTuple, Protocol, TypeVar import safetensors import torch ModelType = TypeVar("ModelType") @dataclass(frozen=True, slots=True) class ContentReplacement: """ Represents a content replacement operation. Used to replace a specific content with a replacement in a state dict key. """ content: str replacement: str @dataclass(frozen=True, slots=True) class ContentMatching: """ Represents a content matching operation. Used to match a specific prefix and suffix in a state dict key. """ prefix: str = "" suffix: str = "" class KeyValueOperationResult(NamedTuple): """ Represents the result of a key-value operation. Contains the new key and value after the operation has been applied. """ new_key: str new_value: torch.Tensor class KeyValueOperation(Protocol): """ Protocol for key-value operations. Used to apply operations to a specific key and value in a state dict. """ def __call__(self, tensor_key: str, tensor_value: torch.Tensor) -> list[KeyValueOperationResult]: ... @dataclass(frozen=True, slots=True) class SDKeyValueOperation: """ Represents a key-value operation. Used to apply operations to a specific key and value in a state dict. """ key_matcher: ContentMatching kv_operation: KeyValueOperation @dataclass(frozen=True, slots=True) class SDOps: """Immutable class representing state dict key operations.""" name: str mapping: tuple[ContentReplacement | ContentMatching | SDKeyValueOperation, ...] = () # Immutable tuple of (key, value) pairs def with_replacement(self, content: str, replacement: str) -> "SDOps": """Create a new SDOps instance with the specified replacement added to the mapping.""" new_mapping = (*self.mapping, ContentReplacement(content, replacement)) return replace(self, mapping=new_mapping) def with_matching(self, prefix: str = "", suffix: str = "") -> "SDOps": """Create a new SDOps instance with the specified prefix and suffix matching added to the mapping.""" new_mapping = (*self.mapping, ContentMatching(prefix, suffix)) return replace(self, mapping=new_mapping) def with_kv_operation( self, operation: KeyValueOperation, key_prefix: str = "", key_suffix: str = "", ) -> "SDOps": """Create a new SDOps instance with the specified value operation added to the mapping.""" key_matcher = ContentMatching(key_prefix, key_suffix) sd_kv_operation = SDKeyValueOperation(key_matcher, operation) new_mapping = (*self.mapping, sd_kv_operation) return replace(self, mapping=new_mapping) def apply_to_key(self, key: str) -> str | None: """Apply the mapping to the given name.""" matchers = [content for content in self.mapping if isinstance(content, ContentMatching)] valid = any(key.startswith(f.prefix) and key.endswith(f.suffix) for f in matchers) if not valid: return None for replacement in self.mapping: if not isinstance(replacement, ContentReplacement): continue if replacement.content in key: key = key.replace(replacement.content, replacement.replacement) return key def apply_to_key_value(self, key: str, value: torch.Tensor) -> list[KeyValueOperationResult]: """Apply the value operation to the given name and associated value.""" for operation in self.mapping: if not isinstance(operation, SDKeyValueOperation): continue if key.startswith(operation.key_matcher.prefix) and key.endswith(operation.key_matcher.suffix): return operation.kv_operation(key, value) return [KeyValueOperationResult(key, value)] class ModuleOps(NamedTuple): """ Defines a named operation for matching and mutating PyTorch modules. Used to selectively transform modules in a model (e.g., replacing layers with quantized versions). """ name: str matcher: Callable[[torch.nn.Module], bool] mutator: Callable[[torch.nn.Module], torch.nn.Module] class ModelConfigurator(Protocol[ModelType]): """Protocol for model loader classes that instantiates models from a configuration dictionary.""" @classmethod def from_config(cls, config: dict) -> ModelType: ... @dataclass(frozen=True) class StateDict: """ Immutable container for a PyTorch state dictionary. Contains: - sd: Dictionary of tensors (weights, buffers, etc.) - device: Device where tensors are stored - size: Total memory footprint in bytes - dtype: Set of tensor dtypes present """ sd: dict device: torch.device size: int dtype: set[torch.dtype] def footprint(self) -> tuple[int, torch.device]: return self.size, self.device class StateDictLoader(Protocol): """ Protocol for loading state dictionaries from various sources. Implementations must provide: - metadata: Extract model metadata from a single path - load: Load state dict from path(s) and apply SDOps transformations """ def metadata(self, path: str) -> dict: """Load metadata from path""" def load(self, path: str | list[str], sd_ops: SDOps | None = None, device: torch.device | None = None) -> StateDict: """Load state dict from path or paths (for sharded model storage) and apply sd_ops""" class SafetensorsStateDictLoader(StateDictLoader): """ Loads weights from safetensors files without metadata support. Use this for loading raw weight files. For model files that include configuration metadata, use SafetensorsModelStateDictLoader instead. """ def metadata(self, path: str) -> dict: raise NotImplementedError("Not implemented") def load(self, path: str | list[str], sd_ops: SDOps | None = None, device: torch.device | None = None) -> StateDict: """ Load state dict from path or paths (for sharded model storage) and apply sd_ops """ sd = {} size = 0 dtype = set() device = device or torch.device("cpu") model_paths = path if isinstance(path, list) else [path] for shard_path in model_paths: with safetensors.safe_open(shard_path, framework="pt", device=str(device)) as f: safetensor_keys = f.keys() for name in safetensor_keys: expected_name = name if sd_ops is None else sd_ops.apply_to_key(name) if expected_name is None: continue value = f.get_tensor(name).to(device=device, non_blocking=True, copy=False) key_value_pairs = ((expected_name, value),) if sd_ops is not None: key_value_pairs = sd_ops.apply_to_key_value(expected_name, value) for key, value in key_value_pairs: size += value.nbytes dtype.add(value.dtype) sd[key] = value return StateDict(sd=sd, device=device, size=size, dtype=dtype) class SafetensorsModelStateDictLoader(StateDictLoader): """ Loads weights and configuration metadata from safetensors model files. Unlike SafetensorsStateDictLoader, this loader can read model configuration from the safetensors file metadata via the metadata() method. """ def __init__(self, weight_loader: SafetensorsStateDictLoader | None = None): self.weight_loader = weight_loader if weight_loader is not None else SafetensorsStateDictLoader() def metadata(self, path: str) -> dict: with safetensors.safe_open(path, framework="pt") as f: return json.loads(f.metadata()["config"]) def load(self, path: str | list[str], sd_ops: SDOps | None = None, device: torch.device | None = None) -> StateDict: return self.weight_loader.load(path, sd_ops, device) # Predefined SDOps instances LTXV_LORA_COMFY_RENAMING_MAP = SDOps("LTXV_LORA_COMFY_PREFIX_MAP").with_matching().with_replacement("diffusion_model.", "") LTXV_LORA_COMFY_TARGET_MAP = ( SDOps("LTXV_LORA_COMFY_TARGET_MAP").with_matching().with_replacement("diffusion_model.", "").with_replacement(".lora_A.weight", ".weight").with_replacement(".lora_B.weight", ".weight") )