from abc import ABCMeta, abstractmethod
import torch
from lightx2v.utils.registry_factory import CONV2D_WEIGHT_REGISTER
from lightx2v_platform.base.global_var import AI_DEVICE
class Conv2dWeightTemplate(metaclass=ABCMeta):
def __init__(self, weight_name, bias_name, stride, padding, dilation, groups):
self.weight_name = weight_name
self.bias_name = bias_name
self.stride = stride
self.padding = padding
self.dilation = dilation
self.groups = groups
self.config = {}
@abstractmethod
def load(self, weight_dict):
pass
@abstractmethod
def apply(self, input_tensor):
pass
def set_config(self, config=None):
if config is not None:
self.config = config
@CONV2D_WEIGHT_REGISTER("Default")
class Conv2dWeight(Conv2dWeightTemplate):
def __init__(self, weight_name, bias_name, stride=1, padding=0, dilation=1, groups=1):
super().__init__(weight_name, bias_name, stride, padding, dilation, groups)
def load(self, weight_dict):
self.weight = weight_dict[self.weight_name].to(AI_DEVICE)
self.bias = weight_dict[self.bias_name].to(AI_DEVICE) if self.bias_name is not None else None
def apply(self, input_tensor):
input_tensor = torch.nn.functional.conv2d(input_tensor, weight=self.weight, bias=self.bias, stride=self.stride, padding=self.padding, dilation=self.dilation, groups=self.groups)
return input_tensor
def to_cpu(self, non_blocking=False):
self.weight = self.weight.cpu(non_blocking=non_blocking)
if self.bias is not None:
self.bias = self.bias.cpu(non_blocking=non_blocking)
def to_cuda(self, non_blocking=False):
self.weight = self.weight.to(AI_DEVICE, non_blocking=non_blocking)
if self.bias is not None:
self.bias = self.bias.to(AI_DEVICE, non_blocking=non_blocking)
def state_dict(self, destination=None):
if destination is None:
destination = {}
destination[self.weight_name] = self.weight.cpu().detach().clone()
if self.bias is not None:
destination[self.bias_name] = self.bias.cpu().detach().clone()
return destination
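# Illustrative-only usage sketch (not part of the library API): the tensor names and
# shapes below are hypothetical, and AI_DEVICE is assumed to point at an available
# accelerator. Conv2dWeight simply stores the conv hyper-parameters and dispatches to
# torch.nn.functional.conv2d in apply().
def _example_conv2d_weight_usage():
    conv = Conv2dWeight("patch_embed.weight", "patch_embed.bias", stride=2, padding=1)
    conv.load({
        "patch_embed.weight": torch.randn(64, 3, 3, 3),  # [out_ch, in_ch, kH, kW]
        "patch_embed.bias": torch.randn(64),
    })
    x = torch.randn(1, 3, 32, 32).to(AI_DEVICE)
    return conv.apply(x)  # [1, 64, 16, 16] with stride=2 and padding=1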
from abc import ABCMeta, abstractmethod
import torch
from loguru import logger
from lightx2v.common.ops.utils import *
from lightx2v.utils.envs import *
from lightx2v.utils.registry_factory import CONV3D_WEIGHT_REGISTER
from lightx2v_platform.base.global_var import AI_DEVICE
class Conv3dWeightTemplate(metaclass=ABCMeta):
def __init__(self, weight_name, bias_name, stride=1, padding=0, dilation=1, groups=1, lora_prefix="diffusion_model.blocks"):
self.weight_name = weight_name
self.bias_name = bias_name
self.stride = stride
self.padding = padding
self.dilation = dilation
self.groups = groups
self.config = {}
self.lora_prefix = lora_prefix
self.has_lora_branch = False
self.has_diff = False
self._get_base_attrs_mapping()
self._get_lora_attr_mapping()
def _get_base_attrs_mapping(self):
self.base_attrs = []
self.base_attrs.append((self.weight_name, "weight", False))
self.base_attrs.append((self.bias_name, "bias", False))
def _get_lora_attr_mapping(self):
_, _, _, self.weight_diff_name, self.bias_diff_name = build_lora_and_diff_names(self.weight_name, self.lora_prefix)
self.lora_attrs = {
"weight_diff": "weight_diff_name",
"bias_diff": "bias_diff_name",
}
self.weight_diff = torch.tensor(0.0, dtype=GET_DTYPE(), device=AI_DEVICE)
self.bias_diff = torch.tensor(0.0, dtype=GET_DTYPE(), device=AI_DEVICE)
def register_diff(self, weight_dict):
if self.weight_diff_name in weight_dict:
self.weight_diff = weight_dict[self.weight_diff_name]
logger.debug(f"Register Diff to {self.weight_name}")
if self.bias_diff_name in weight_dict:
self.bias_diff = weight_dict[self.bias_diff_name]
logger.debug(f"Register Diff to {self.bias_name}")
def set_config(self, config=None):
if config is not None:
self.config = config
@abstractmethod
def load(self, weight_dict):
pass
@abstractmethod
def apply(self, input_tensor):
pass
@CONV3D_WEIGHT_REGISTER("Default")
class Conv3dWeight(Conv3dWeightTemplate):
def __init__(self, weight_name, bias_name, stride=1, padding=0, dilation=1, groups=1, lora_prefix="diffusion_model.blocks"):
super().__init__(weight_name, bias_name, stride, padding, dilation, groups, lora_prefix)
def load(self, weight_dict):
device_tensors, pin_tensors = create_default_tensors(self.base_attrs, weight_dict)
self.weight = device_tensors.get("weight")
self.bias = device_tensors.get("bias")
self.pin_weight = pin_tensors.get("weight")
self.pin_bias = pin_tensors.get("bias")
def apply(self, input_tensor):
output_tensor = torch.nn.functional.conv3d(
input_tensor,
weight=self.weight + self.weight_diff,
bias=self.bias + self.bias_diff,
stride=self.stride,
padding=self.padding,
dilation=self.dilation,
groups=self.groups,
)
return output_tensor
def to_cuda(self, non_blocking=False):
move_attr_to_cuda(self, self.base_attrs, self.lora_attrs, non_blocking)
def to_cpu(self, non_blocking=False):
move_attr_to_cpu(self, self.base_attrs, self.lora_attrs, non_blocking)
def state_dict(self, destination=None):
return state_dict(self, self.base_attrs, self.lora_attrs, destination)
def load_state_dict(self, destination, block_index, adapter_block_index=None):
return load_state_dict(self, self.base_attrs, self.lora_attrs, destination, block_index, adapter_block_index)
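# Illustrative-only sketch (not part of the library API): Conv3dWeight.apply() fuses an
# optional additive delta into the kernel and bias (weight + weight_diff, bias + bias_diff)
# before calling conv3d. With a zero delta, as below, the result matches a plain conv3d;
# the shapes are hypothetical.
def _example_conv3d_weight_diff():
    x = torch.randn(1, 4, 8, 16, 16)        # [N, C_in, D, H, W]
    weight = torch.randn(8, 4, 3, 3, 3)     # [C_out, C_in, kD, kH, kW]
    bias = torch.randn(8)
    weight_diff = torch.zeros_like(weight)  # e.g. a delta registered via register_diff()
    bias_diff = torch.zeros_like(bias)
    fused = torch.nn.functional.conv3d(x, weight + weight_diff, bias + bias_diff, padding=1)
    plain = torch.nn.functional.conv3d(x, weight, bias, padding=1)
    assert torch.allclose(fused, plain)
    return fused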
from .embedding_weight import *
import os
import re
from abc import ABCMeta
from pathlib import Path
import torch
import torch.nn.functional as F
from safetensors import safe_open
from lightx2v.utils.envs import *
from lightx2v.utils.registry_factory import EMBEDDING_WEIGHT_REGISTER
from lightx2v_platform.base.global_var import AI_DEVICE
class EmbeddingWeightTemplate(metaclass=ABCMeta):
def __init__(self, weight_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
self.weight_name = weight_name
self.create_cuda_buffer = create_cuda_buffer
self.create_cpu_buffer = create_cpu_buffer
self.lazy_load = lazy_load
self.lazy_load_file = lazy_load_file
self.is_post_adapter = is_post_adapter
self.infer_dtype = GET_DTYPE()
self.config = {}
def load(self, weight_dict):
if self.create_cuda_buffer:
self._load_cuda_buffer(weight_dict)
elif self.create_cpu_buffer:
self._load_cpu_pin_buffer()
else:
self._load_default_tensors(weight_dict)
def _load_default_tensors(self, weight_dict):
if not self.lazy_load:
device = weight_dict[self.weight_name].device
if device.type == "cpu":
weight_tensor = weight_dict[self.weight_name]
self.pin_weight = self._create_cpu_pin_weight(weight_tensor)
del weight_dict[self.weight_name]
else:
self.weight = weight_dict[self.weight_name]
def _get_weight_tensor(self, weight_dict=None, use_infer_dtype=False):
if self.lazy_load:
if Path(self.lazy_load_file).is_file():
lazy_load_file_path = self.lazy_load_file
else:
lazy_load_file_path = os.path.join(self.lazy_load_file, f"block_{self.weight_name.split('.')[1]}.safetensors")
with safe_open(lazy_load_file_path, framework="pt", device="cpu") as lazy_load_file:
tensor = lazy_load_file.get_tensor(self.weight_name)
if use_infer_dtype:
tensor = tensor.to(self.infer_dtype)
else:
tensor = weight_dict[self.weight_name]
return tensor
def _create_cpu_pin_weight(self, tensor):
pin_tensor = torch.empty(tensor.shape, pin_memory=True, dtype=tensor.dtype)
pin_tensor.copy_(tensor)
del tensor
return pin_tensor
def _load_cuda_buffer(self, weight_dict):
weight_tensor = self._get_weight_tensor(weight_dict, use_infer_dtype=self.lazy_load)
self.weight_cuda_buffer = weight_tensor.to(AI_DEVICE)
def _load_cpu_pin_buffer(self):
weight_tensor = self._get_weight_tensor(use_infer_dtype=True)
self.pin_weight = self._create_cpu_pin_weight(weight_tensor)
def to_cuda(self, non_blocking=False):
self.weight = self.pin_weight.to(AI_DEVICE, non_blocking=non_blocking)
def to_cpu(self, non_blocking=False):
if hasattr(self, "pin_weight"):
self.weight = self.pin_weight.copy_(self.weight, non_blocking=non_blocking).cpu()
else:
self.weight = self.weight.to("cpu", non_blocking=non_blocking)
def state_dict(self, destination=None):
if destination is None:
destination = {}
destination[self.weight_name] = self.pin_weight if hasattr(self, "pin_weight") else self.weight
return destination
def load_state_dict(self, destination, block_index, adapter_block_index=None):
if self.is_post_adapter:
assert adapter_block_index is not None
weight_name = re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", self.weight_name, count=1)
else:
weight_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.weight_name, count=1)
if weight_name not in destination:
self.weight = None
return
self.weight = self.weight_cuda_buffer.copy_(destination[weight_name], non_blocking=True)
def load_state_dict_from_disk(self, block_index, adapter_block_index=None):
if self.is_post_adapter:
self.weight_name = re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", self.weight_name, count=1)
else:
self.weight_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.weight_name, count=1)
if Path(self.lazy_load_file).is_file():
lazy_load_file_path = self.lazy_load_file
else:
lazy_load_file_path = os.path.join(self.lazy_load_file, f"block_{block_index}.safetensors")
with safe_open(lazy_load_file_path, framework="pt", device="cpu") as lazy_load_file:
weight_tensor = lazy_load_file.get_tensor(self.weight_name).to(self.infer_dtype)
self.pin_weight = self.pin_weight.copy_(weight_tensor)
del weight_tensor
@EMBEDDING_WEIGHT_REGISTER("Default")
class EmbeddingWeight(EmbeddingWeightTemplate):
def __init__(self, weight_name=None, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None):
super().__init__(weight_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file)
def apply(self, input_indices):
output = F.embedding(input=input_indices, weight=self.weight, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False)
return output
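# Illustrative-only usage sketch (not part of the library API): EmbeddingWeight.apply()
# is a thin wrapper around F.embedding, i.e. rows of the weight matrix are gathered by
# integer index. The weight name and sizes are hypothetical, and the table is created
# directly on AI_DEVICE (assumed to be a non-CPU accelerator) so load() keeps it as
# self.weight instead of staging a pinned CPU copy.
def _example_embedding_weight_usage():
    table = torch.randn(1000, 128, device=AI_DEVICE)  # [vocab_size, hidden_dim]
    emb = EmbeddingWeight("token_embedding.weight")
    emb.load({"token_embedding.weight": table})
    idx = torch.tensor([[1, 5, 7]], device=AI_DEVICE)
    out = emb.apply(idx)                              # [1, 3, 128]
    assert torch.equal(out[0, 0], table[1])
    return out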
from .mm_weight import *
import os
import re
from abc import ABCMeta, abstractmethod
from pathlib import Path
import torch
import torch.distributed as dist
from loguru import logger
from safetensors import safe_open
from lightx2v.common.ops.mm.triton_kernels import (
fp8_gemm_bias_triton,
fp8_gemm_triton,
fp8_quantize_triton,
int8_gemm_bias_triton,
int8_gemm_triton,
int8_quantize_triton,
)
from lightx2v.common.ops.utils import *
from lightx2v.utils.envs import *
from lightx2v.utils.ggml_tensor import GGMLTensor
from lightx2v.utils.ggml_tensor import dequantize_tensor as gguf_dequantize_tensor
from lightx2v.utils.global_paras import CALIB
from lightx2v.utils.quant_utils import FloatQuantizer, IntegerQuantizer
from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER
from lightx2v_platform.base.global_var import AI_DEVICE
try:
from lightx2v_kernel.gemm import (
cutlass_scaled_mxfp4_mm,
cutlass_scaled_mxfp6_mxfp8_mm,
cutlass_scaled_mxfp8_mm,
cutlass_scaled_nvfp4_mm,
scaled_mxfp4_quant,
scaled_mxfp6_quant,
scaled_mxfp8_quant,
scaled_nvfp4_quant,
)
except ImportError:
scaled_nvfp4_quant, cutlass_scaled_nvfp4_mm = None, None
scaled_mxfp4_quant, cutlass_scaled_mxfp4_mm = None, None
scaled_mxfp6_quant, cutlass_scaled_mxfp6_mxfp8_mm = None, None
scaled_mxfp8_quant, cutlass_scaled_mxfp8_mm = None, None
try:
from vllm import _custom_ops as ops
except ImportError:
ops = None
try:
import sgl_kernel
except ImportError:
sgl_kernel = None
try:
from q8_kernels.functional.linear import q8_linear
except ImportError:
q8_linear = None
try:
from q8_kernels.functional.linear import fp8_linear
except ImportError:
fp8_linear = None
try:
import deep_gemm
except ImportError:
deep_gemm = None
try:
from torchao.quantization.utils import (
quant_int8_per_token_matmul as torchao_int8_gemm,
)
from torchao.quantization.utils import (
quantize_activation_per_token_absmax as torchao_int8_quant,
)
except ImportError:
try:
from torchao.quantization.utils import (
_quant_int8_per_token_matmul as torchao_int8_gemm,
)
from torchao.quantization.utils import (
_quantize_activation_per_token_absmax as torchao_int8_quant,
)
except ImportError:
torchao_int8_gemm, torchao_int8_quant = None, None
try:
import gguf
except ImportError:
gguf = None
try:
import marlin_cuda_quant
except ImportError:
marlin_cuda_quant = None
class MMWeightTemplate(metaclass=ABCMeta):
def __init__(
self,
weight_name,
bias_name,
create_cuda_buffer=False,
create_cpu_buffer=False,
lazy_load=False,
lazy_load_file=None,
is_post_adapter=False,
lora_prefix="diffusion_model.blocks",
lora_path="",
):
self.weight_name = weight_name
self.bias_name = bias_name
self.create_cuda_buffer = create_cuda_buffer
self.create_cpu_buffer = create_cpu_buffer
self.lazy_load = lazy_load
self.lazy_load_file = lazy_load_file
self.is_post_adapter = is_post_adapter
self.config = {}
self.lora_prefix = lora_prefix
self.lora_path = lora_path
self.has_lora_branch = False
self.has_diff = False
self._get_base_attrs_mapping()
self._get_lora_attr_mapping()
def _get_base_attrs_mapping(self):
self.base_attrs = [
(self.weight_name, "weight", True),
]
if self.bias_name is not None:
self.base_attrs.append((self.bias_name, "bias", False))
def _get_lora_attr_mapping(self):
self.lora_down_name, self.lora_up_name, self.lora_alpha_name, self.weight_diff_name, self.bias_diff_name = build_lora_and_diff_names(self.weight_name, self.lora_prefix)
self.lora_attrs = {
"lora_alpha": "lora_alpha_name",
"lora_down": "lora_down_name",
"lora_up": "lora_up_name",
"weight_diff": "weight_diff_name",
"bias_diff": "bias_diff_name",
}
def _get_actual_weight(self):
if not hasattr(self, "weight_diff"):
return self.weight
return self.weight + self.weight_diff
def _get_actual_bias(self, bias=None):
if bias is not None:
if not hasattr(self, "bias_diff"):
return bias
return bias + self.bias_diff
else:
if not hasattr(self, "bias") or self.bias is None:
return None
if not hasattr(self, "bias_diff"):
return self.bias
return self.bias + self.bias_diff
def apply_lora(self, input_tensor):
h = torch.mm(input_tensor, self.lora_down.t())
out = torch.mm(h, self.lora_up.t())
return self.lora_strength * self.lora_scale * out
def set_config(self, config=None):
    self.config = config if config is not None else {}
def register_diff(self, weight_dict):
if not self.lazy_load or self.create_cuda_buffer or self.create_cpu_buffer:
if self.weight_diff_name in weight_dict:
self.has_diff = True
self.weight_diff = weight_dict[self.weight_diff_name].t()
logger.debug(f"Register Diff to {self.weight_name}")
if self.bias_diff_name in weight_dict:
self.has_diff = True
self.bias_diff = weight_dict[self.bias_diff_name]
logger.debug(f"Register Diff to {self.bias_name}")
def register_lora(self, weight_dict, lora_strength=1):
if not self.lazy_load or self.create_cuda_buffer or self.create_cpu_buffer:
if self.lora_down_name in weight_dict:
self.has_lora_branch = True
self.lora_down = weight_dict[self.lora_down_name]
self.lora_up = weight_dict[self.lora_up_name]
self.lora_strength = lora_strength
if self.lora_alpha_name in weight_dict:
self.lora_alpha = weight_dict[self.lora_alpha_name]
self.lora_scale = self.lora_alpha / self.lora_down.shape[0]
else:
self.lora_scale = torch.tensor(1.0, device=AI_DEVICE)
logger.debug(f"Register LoRA to {self.weight_name} with lora_scale={self.lora_scale}")
def update_lora(self, weight_dict, lora_strength=1):
if not self.lazy_load or self.create_cuda_buffer or self.create_cpu_buffer:
if self.lora_down_name in weight_dict:
self.has_lora_branch = True
self.lora_down.copy_(weight_dict[self.lora_down_name])
self.lora_up.copy_(weight_dict[self.lora_up_name])
self.lora_strength = lora_strength
if self.lora_alpha_name in weight_dict:
self.lora_alpha.copy_(weight_dict[self.lora_alpha_name])
self.lora_scale.copy_(self.lora_alpha / self.lora_down.shape[0])
else:
self.lora_scale = torch.tensor(1.0, device=AI_DEVICE)
logger.debug(f"Update LoRA to {self.weight_name}")
def remove_lora(self):
if hasattr(self, "lora_down"):
del self.lora_down
if hasattr(self, "lora_up"):
del self.lora_up
if hasattr(self, "lora_alpha"):
del self.lora_alpha
if hasattr(self, "lora_scale"):
del self.lora_scale
self.has_lora_branch = False
logger.debug(f"Remove LoRA from {self.weight_name}")
def state_dict(self, destination=None):
return state_dict(self, self.base_attrs, self.lora_attrs, destination)
def load_state_dict(self, destination, block_index, adapter_block_index=None):
return load_state_dict(self, self.base_attrs, self.lora_attrs, destination, block_index, adapter_block_index)
def load_lora_state_dict_from_disk(self, block_index):
self.lora_alpha_name = resolve_block_name(self.lora_alpha_name, block_index)
self.lora_down_name = resolve_block_name(self.lora_down_name, block_index)
self.lora_up_name = resolve_block_name(self.lora_up_name, block_index)
self.weight_diff_name = resolve_block_name(self.weight_diff_name, block_index)
self.bias_diff_name = resolve_block_name(self.bias_diff_name, block_index)
with safe_open(self.lora_path, framework="pt", device="cpu") as lora_load_file:
for lora_attr, lora_attr_name in self.lora_attrs.items():
if getattr(self, lora_attr_name) in lora_load_file.keys():
setattr(self, lora_attr, getattr(self, lora_attr).copy_(lora_load_file.get_tensor(getattr(self, lora_attr_name)), non_blocking=True))
def to_cuda(self, non_blocking=False):
move_attr_to_cuda(self, self.base_attrs, self.lora_attrs, non_blocking)
def to_cpu(self, non_blocking=False):
move_attr_to_cpu(self, self.base_attrs, self.lora_attrs, non_blocking)
@abstractmethod
def load(self, weight_dict):
pass
@abstractmethod
def apply(self):
pass
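# Illustrative-only sketch (not part of the library API) of the LoRA branch implemented
# by apply_lora() above: with lora_down of shape [rank, in_features] and lora_up of shape
# [out_features, rank], the extra term is strength * (alpha / rank) * (x @ down.T @ up.T),
# where alpha / rank is the lora_scale computed in register_lora(). Sizes are hypothetical.
def _example_lora_branch_math():
    in_features, out_features, rank, alpha, strength = 64, 32, 4, 8.0, 1.0
    x = torch.randn(2, in_features)
    lora_down = torch.randn(rank, in_features)
    lora_up = torch.randn(out_features, rank)
    lora_scale = alpha / rank
    delta = strength * lora_scale * (x @ lora_down.t() @ lora_up.t())
    return delta  # [2, out_features], added on top of the base matmul output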
@MM_WEIGHT_REGISTER("Default")
class MMWeight(MMWeightTemplate):
def __init__(
self,
weight_name,
bias_name=None,
create_cuda_buffer=False,
create_cpu_buffer=False,
lazy_load=False,
lazy_load_file=None,
is_post_adapter=False,
lora_prefix="diffusion_model.blocks",
lora_path="",
):
super().__init__(
weight_name,
bias_name,
create_cuda_buffer,
create_cpu_buffer,
lazy_load,
lazy_load_file,
is_post_adapter,
lora_prefix,
lora_path,
)
def load(self, weight_dict):
if not self.create_cuda_buffer and not self.create_cpu_buffer and not self.lazy_load:
device_tensors, pin_tensors = create_default_tensors(self.base_attrs, weight_dict)
self.weight = device_tensors.get("weight")
self.bias = device_tensors.get("bias")
self.pin_weight = pin_tensors.get("weight")
self.pin_bias = pin_tensors.get("bias")
elif self.create_cuda_buffer:
result = create_cuda_buffers(self.base_attrs, weight_dict, self.lazy_load, self.lazy_load_file)
self.weight_cuda_buffer = result.get("weight")
self.bias_cuda_buffer = result.get("bias")
elif self.create_cpu_buffer:
result = create_cpu_buffers(self.base_attrs, self.lazy_load_file)
self.pin_weight = result.get("weight")
self.pin_bias = result.get("bias")
self.weight = None
self.bias = None
def apply(self, input_tensor):
shape = (input_tensor.shape[0], self.weight.shape[1])
dtype = input_tensor.dtype
device = input_tensor.device
output_tensor = torch.empty(shape, dtype=dtype, device=device, requires_grad=False)
if not self.has_lora_branch:
if hasattr(self, "bias") and self.bias is not None:
return torch.addmm(self._get_actual_bias(), input_tensor, self._get_actual_weight(), out=output_tensor)
return torch.mm(input_tensor, self._get_actual_weight(), out=output_tensor)
else:
if hasattr(self, "bias") and self.bias is not None:
return torch.addmm(self._get_actual_bias(), input_tensor, self._get_actual_weight(), out=output_tensor) + self.apply_lora(input_tensor)
return torch.mm(input_tensor, self._get_actual_weight(), out=output_tensor) + self.apply_lora(input_tensor)
def load_state_dict_from_disk(self, block_index, adapter_block_index=None):
if self.has_lora_branch or self.has_diff:
self.load_lora_state_dict_from_disk(block_index)
self.weight_name = resolve_block_name(self.weight_name, block_index, adapter_block_index, self.is_post_adapter)
if self.bias_name is not None:
self.bias_name = resolve_block_name(self.bias_name, block_index, adapter_block_index, self.is_post_adapter)
lazy_load_file_path = get_lazy_load_file_path(self.lazy_load_file, self.weight_name)
with safe_open(lazy_load_file_path, framework="pt", device="cpu") as lazy_load_file:
weight_tensor = lazy_load_file.get_tensor(self.weight_name).t()
self.pin_weight = self.pin_weight.copy_(weight_tensor)
del weight_tensor
if self.bias_name is not None:
bias_tensor = lazy_load_file.get_tensor(self.bias_name)
self.pin_bias.copy_(bias_tensor)
del bias_tensor
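# Illustrative-only sketch (not part of the library API) of the layout convention that
# MMWeight.apply() relies on: checkpoints store Linear weights as [out_features, in_features],
# but this class keeps them transposed to [in_features, out_features] (note the .t() in
# load_state_dict_from_disk above), so the forward pass is a plain addmm. Sizes are hypothetical.
def _example_mm_weight_layout():
    x = torch.randn(2, 64)
    checkpoint_weight = torch.randn(32, 64)      # nn.Linear layout: [out, in]
    bias = torch.randn(32)
    weight = checkpoint_weight.t().contiguous()  # stored layout: [in, out]
    out = torch.addmm(bias, x, weight)           # matches torch.addmm in MMWeight.apply()
    ref = torch.nn.functional.linear(x, checkpoint_weight, bias)
    assert torch.allclose(out, ref, atol=1e-5)
    return out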
class MMWeightQuantTemplate(MMWeightTemplate):
def __init__(
self,
weight_name,
bias_name,
create_cuda_buffer=False,
create_cpu_buffer=False,
lazy_load=False,
lazy_load_file=None,
is_post_adapter=False,
lora_prefix="diffusion_model.blocks",
lora_path="",
):
super().__init__(
weight_name,
bias_name,
create_cuda_buffer,
create_cpu_buffer,
lazy_load,
lazy_load_file,
is_post_adapter,
lora_prefix,
lora_path,
)
self.weight_scale_name = self.weight_name.removesuffix(".weight") + ".weight_scale"
self.load_func = None
self.weight_need_transpose = True
self.act_quant_func = None
self.lazy_load = lazy_load
self.lazy_load_file = lazy_load_file
self.infer_dtype = GET_DTYPE()
self.bias_force_fp32 = False
self.scale_force_fp32 = False
self._update_base_attrs()
def _update_base_attrs(self):
self.base_attrs = [(self.weight_name, "weight", False), (self.weight_scale_name, "weight_scale", False)]
if self.bias_name is not None:
self.base_attrs.append((self.bias_name, "bias", False))
# =========================
# weight load functions
# =========================
def load(self, weight_dict):
    self.load_func(weight_dict)
    self.post_process()
def post_process(self):
if self.weight_need_transpose:
if hasattr(self, "weight") and self.weight is not None:
self.weight = self.weight.t()
self.weight = self.weight.contiguous()
if hasattr(self, "pin_weight") and self.pin_weight is not None:
self.pin_weight = self.pin_weight.t()
if hasattr(self, "weight_cuda_buffer") and self.weight_cuda_buffer is not None:
self.weight_cuda_buffer = self.weight_cuda_buffer.t()
if hasattr(self, "bias") and self.bias is not None:
if self.bias_force_fp32:
self.bias = self.bias.to(torch.float32)
else:
self.bias = self.bias.to(self.infer_dtype)
if hasattr(self, "pin_bias") and self.pin_bias is not None:
if self.bias_force_fp32:
self.pin_bias = self.pin_bias.to(torch.float32)
else:
self.pin_bias = self.pin_bias.to(self.infer_dtype)
if self.bias_force_fp32 and hasattr(self, "bias_diff"):
self.bias_diff = self.bias_diff.to(torch.float32)
if self.scale_force_fp32:
if hasattr(self, "weight_scale") and self.weight_scale is not None:
self.weight_scale = self.weight_scale.to(torch.float32)
if hasattr(self, "pin_weight_scale") and self.pin_weight_scale is not None:
self.pin_weight_scale = self.pin_weight_scale.to(torch.float32)
def load_quantized(self, weight_dict):
if not self.create_cuda_buffer and not self.create_cpu_buffer and not self.lazy_load:
device_tensors, pin_tensors = create_default_tensors(self.base_attrs, weight_dict)
self.weight = device_tensors.get("weight")
self.weight_scale = device_tensors.get("weight_scale")
self.bias = device_tensors.get("bias")
self.pin_weight = pin_tensors.get("weight")
self.pin_weight_scale = pin_tensors.get("weight_scale")
self.pin_bias = pin_tensors.get("bias")
elif self.create_cuda_buffer:
result = create_cuda_buffers(self.base_attrs, weight_dict, self.lazy_load, self.lazy_load_file, scale_force_fp32=self.scale_force_fp32, bias_force_fp32=self.bias_force_fp32)
self.weight_cuda_buffer = result.get("weight")
self.weight_scale_cuda_buffer = result.get("weight_scale")
self.bias_cuda_buffer = result.get("bias")
elif self.create_cpu_buffer:
result = create_cpu_buffers(self.base_attrs, self.lazy_load_file, scale_force_fp32=self.scale_force_fp32, bias_force_fp32=self.bias_force_fp32)
self.pin_weight = result.get("weight")
self.pin_weight_scale = result.get("weight_scale")
self.pin_bias = result.get("bias")
self.weight = None
self.weight_scale = None
self.bias = None
def load_fp8_perchannel_sym(self, weight_dict):
if self.config.get("weight_auto_quant", False):
self.weight = weight_dict[self.weight_name].to(torch.float32)
w_quantizer = FloatQuantizer("e4m3", True, "per_channel")
self.weight, self.weight_scale, _ = w_quantizer.real_quant_tensor(self.weight)
self.weight = self.weight.to(torch.float8_e4m3fn)
self.weight_scale = self.weight_scale.to(torch.float32)
else:
self.load_quantized(weight_dict)
def load_int8_perchannel_sym(self, weight_dict):
if self.config.get("weight_auto_quant", False):
self.weight = weight_dict[self.weight_name].to(torch.float32)
w_quantizer = IntegerQuantizer(8, True, "per_channel")
self.weight, self.weight_scale, _ = w_quantizer.real_quant_tensor(self.weight)
self.weight = self.weight.to(torch.int8)
self.weight_scale = self.weight_scale.to(torch.float32)
else:
self.load_quantized(weight_dict)
def load_mxfp4(self, weight_dict):
if self.config.get("weight_auto_quant", False):
device = weight_dict[self.weight_name].device
self.weight = weight_dict[self.weight_name].to(AI_DEVICE).to(torch.bfloat16)
self.weight, self.weight_scale = scaled_mxfp4_quant(self.weight)
self.weight, self.weight_scale = self.weight.to(device), self.weight_scale.to(device)
else:
self.load_quantized(weight_dict)
def load_mxfp6(self, weight_dict):
if self.config.get("weight_auto_quant", False):
device = weight_dict[self.weight_name].device
self.weight = weight_dict[self.weight_name].to(AI_DEVICE).to(torch.bfloat16)
self.weight, self.weight_scale = scaled_mxfp6_quant(self.weight)
self.weight, self.weight_scale = self.weight.to(device), self.weight_scale.to(device)
else:
self.load_quantized(weight_dict)
def load_mxfp8(self, weight_dict):
if self.config.get("weight_auto_quant", False):
device = weight_dict[self.weight_name].device
self.weight = weight_dict[self.weight_name].to(AI_DEVICE).to(torch.bfloat16)
self.weight, self.weight_scale = scaled_mxfp8_quant(self.weight)
self.weight, self.weight_scale = self.weight.to(device), self.weight_scale.to(device)
else:
self.load_quantized(weight_dict)
def load_nvfp4(self, weight_dict):
assert not self.config.get("weight_auto_quant", False)
self.load_quantized(weight_dict)
def load_fp8_perblock128_sym(self, weight_dict):
if self.config.get("weight_auto_quant", False):
self.weight = weight_dict[self.weight_name]
self.weight, self.weight_scale = self.per_block_cast_to_fp8(self.weight)
else:
self.load_quantized(weight_dict)
def load_state_dict_from_disk(self, block_index, adapter_block_index=None):
if self.has_lora_branch or self.has_diff:
self.load_lora_state_dict_from_disk(block_index)
self.weight_name = resolve_block_name(self.weight_name, block_index, adapter_block_index, self.is_post_adapter)
self.weight_scale_name = resolve_block_name(self.weight_scale_name, block_index, adapter_block_index, self.is_post_adapter)
if self.bias_name is not None:
self.bias_name = resolve_block_name(self.bias_name, block_index, adapter_block_index, self.is_post_adapter)
lazy_load_file_path = get_lazy_load_file_path(self.lazy_load_file, self.weight_name)
with safe_open(lazy_load_file_path, framework="pt", device="cpu") as lazy_load_file:
if self.weight_need_transpose:
weight_tensor = lazy_load_file.get_tensor(self.weight_name).t()
else:
weight_tensor = lazy_load_file.get_tensor(self.weight_name)
self.pin_weight = self.pin_weight.copy_(weight_tensor)
del weight_tensor
weight_scale_tensor = lazy_load_file.get_tensor(self.weight_scale_name)
self.pin_weight_scale = self.pin_weight_scale.copy_(weight_scale_tensor)
del weight_scale_tensor
if self.bias_name is not None:
bias_tensor = lazy_load_file.get_tensor(self.bias_name)
self.pin_bias.copy_(bias_tensor)
del bias_tensor
def per_block_cast_to_fp8(self, x):
assert x.dim() == 2
m, n = x.shape
x_padded = torch.zeros(
(deep_gemm.ceil_div(m, 128) * 128, deep_gemm.ceil_div(n, 128) * 128),
dtype=x.dtype,
device=x.device,
)
x_padded[:m, :n] = x
x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128)
x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn)
return x_scaled.view_as(x_padded)[:m, :n].contiguous(), (x_amax / 448.0).view(x_view.size(0), x_view.size(2))
# =========================
# act quant kernels
# =========================
def act_quant_int8_perchannel_sym_torchao(self, x):
input_tensor_quant, input_tensor_scale = torchao_int8_quant(x)
return input_tensor_quant, input_tensor_scale
def act_quant_fp8_perchannel_sym_torchao(self, x):
abs_max = x.abs().max(dim=-1, keepdim=True)[0]
abs_max = torch.clamp(abs_max, min=1e-8)
scale = abs_max / 448.0
quantized = torch.clamp(x / scale, -448, 448).to(torch.float8_e4m3fn)
return quantized, scale.float()
def act_quant_fp8_perchannel_sym_vllm(self, x):
input_tensor_quant, input_tensor_scale = ops.scaled_fp8_quant(x, None, scale_ub=None, use_per_token_if_dynamic=True)
return input_tensor_quant, input_tensor_scale
def act_quant_fp8_perchannel_sym_sgl(self, x):
m, k = x.shape
input_tensor_quant = torch.empty((m, k), dtype=torch.float8_e4m3fn, device="cuda", requires_grad=False)
input_tensor_scale = torch.empty((m, 1), dtype=torch.float32, device="cuda", requires_grad=False)
sgl_kernel.sgl_per_token_quant_fp8(x, input_tensor_quant, input_tensor_scale)
return input_tensor_quant, input_tensor_scale
def act_quant_int8_perchannel_sym_vllm(self, x):
input_tensor_quant, input_tensor_scale, _ = ops.scaled_int8_quant(x, scale=None, azp=None, symmetric=True)
return input_tensor_quant, input_tensor_scale
def act_quant_nvfp4(self, x):
input_tensor_quant, input_tensor_scale = scaled_nvfp4_quant(x, self.input_global_scale)
return input_tensor_quant, input_tensor_scale
def act_quant_mxfp4(self, x):
input_tensor_quant, input_tensor_scale = scaled_mxfp4_quant(x)
return input_tensor_quant, input_tensor_scale
def act_quant_mxfp8(self, x):
input_tensor_quant, input_tensor_scale = scaled_mxfp8_quant(x)
return input_tensor_quant, input_tensor_scale
def act_quant_fp8_perchannelgroup128_sym_deepgemm(self, x):
assert x.dim() == 2 and x.size(1) % 128 == 0
m, n = x.shape
x_view = x.view(m, -1, 128)
x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
return (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn).view(m, n), (x_amax / 448.0).view(m, -1)
def act_quant_fp8_perchannelgroup128_sym_sgl(self, x):
m, k = x.shape
input_tensor_quant = torch.empty((m, k), dtype=torch.float8_e4m3fn, device="cuda", requires_grad=False)
input_tensor_scale = torch.empty((m, k // 128), dtype=torch.float32, device="cuda", requires_grad=False)
sgl_kernel.sgl_per_token_group_quant_fp8(
x,
input_tensor_quant,
input_tensor_scale,
group_size=128,
eps=1e-10,
fp8_min=-448.0,
fp8_max=448.0,
)
return input_tensor_quant, input_tensor_scale
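# Illustrative-only sketch (not part of the library API) of the per-token dynamic fp8
# activation quantization used by act_quant_fp8_perchannel_sym_torchao() above: each row
# is scaled so its absolute maximum maps onto the float8_e4m3fn range (+/-448); the GEMM
# kernels later multiply the per-token and per-channel scales back in. Sizes are hypothetical.
def _example_dynamic_fp8_act_quant():
    x = torch.randn(4, 64, dtype=torch.bfloat16)
    abs_max = x.abs().max(dim=-1, keepdim=True)[0].clamp(min=1e-8)
    scale = abs_max / 448.0                                 # one scale per token (row)
    x_q = torch.clamp(x / scale, -448, 448).to(torch.float8_e4m3fn)
    x_deq = x_q.to(torch.bfloat16) * scale                  # approximate reconstruction
    assert (x_deq - x).abs().max() < 0.1 * abs_max.max()
    return x_q, scale.float()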
@MM_WEIGHT_REGISTER("fp8-vllm")
class MMWeightWfp8channelAfp8channeldynamicVllm(MMWeightQuantTemplate):
"""
Name: W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm
Quant MM:
Weight: fp8 perchannel sym
Act: fp8 perchannel dynamic sym
Kernel: vllm
"""
def __init__(
self,
weight_name,
bias_name,
create_cuda_buffer=False,
create_cpu_buffer=False,
lazy_load=False,
lazy_load_file=None,
is_post_adapter=False,
lora_prefix="diffusion_model.blocks",
lora_path="",
):
super().__init__(
weight_name,
bias_name,
create_cuda_buffer,
create_cpu_buffer,
lazy_load,
lazy_load_file,
is_post_adapter,
lora_prefix,
lora_path,
)
self.load_func = self.load_fp8_perchannel_sym
self.act_quant_func = self.act_quant_fp8_perchannel_sym_vllm
self.weight_need_transpose = True
self.scale_force_fp32 = True
def apply(self, input_tensor):
shape = (input_tensor.shape[0], self.weight.shape[1])
dtype = input_tensor.dtype
device = input_tensor.device
output_tensor = torch.empty(shape, dtype=dtype, device=device, requires_grad=False)
input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
torch.ops._C.cutlass_scaled_mm(
output_tensor,
input_tensor_quant,
self.weight,
input_tensor_scale,
self.weight_scale,
self._get_actual_bias(),
)
if self.has_lora_branch:
return output_tensor + self.apply_lora(input_tensor)
return output_tensor
@MM_WEIGHT_REGISTER("int8-vllm")
class MMWeightWint8channelAint8channeldynamicVllm(MMWeightQuantTemplate):
"""
Name: W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm
Quant MM:
Weight: int8 perchannel sym
Act: int8 perchannel dynamic sym
Kernel: vllm
"""
def __init__(
self,
weight_name,
bias_name,
create_cuda_buffer=False,
create_cpu_buffer=False,
lazy_load=False,
lazy_load_file=None,
is_post_adapter=False,
lora_prefix="diffusion_model.blocks",
lora_path="",
):
super().__init__(
weight_name,
bias_name,
create_cuda_buffer,
create_cpu_buffer,
lazy_load,
lazy_load_file,
is_post_adapter,
lora_prefix,
lora_path,
)
self.load_func = self.load_int8_perchannel_sym
self.act_quant_func = self.act_quant_int8_perchannel_sym_vllm
self.weight_need_transpose = False
self.scale_force_fp32 = True
def apply(self, input_tensor):
shape = (input_tensor.shape[0], self.weight.shape[1])
dtype = input_tensor.dtype
device = input_tensor.device
output_tensor = torch.empty(shape, dtype=dtype, device=device, requires_grad=False)
input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
# Prefer blaslt_scaled_mm when the vllm build provides it; otherwise fall back to the
# cutlass kernel so output_tensor is never returned uninitialized.
if ops is not None and hasattr(ops, "blaslt_scaled_mm"):
    out_dtype = dtype if dtype in (torch.bfloat16, torch.float16) else torch.bfloat16
    input_tensor_quant = input_tensor_quant.contiguous()
    output_tensor = ops.blaslt_scaled_mm(
        input_tensor_quant,
        self.weight,
        input_tensor_scale,
        self.weight_scale,
        out_dtype,
        self.bias,
    )
else:
    torch.ops._C.cutlass_scaled_mm(
        output_tensor,
        input_tensor_quant,
        self.weight,
        input_tensor_scale,
        self.weight_scale,
        self._get_actual_bias(),
    )
if self.has_lora_branch:
return output_tensor + self.apply_lora(input_tensor)
return output_tensor
@MM_WEIGHT_REGISTER("mxfp4")
class MMWeightWmxfp4Amxfp4dynamic(MMWeightQuantTemplate):
"""
Name: W-mxfp4-A-mxfp4-dynamic
Quant MM:
Weight: mxfp4
Act: mxfp4
"""
def __init__(
self,
weight_name,
bias_name,
create_cuda_buffer=False,
create_cpu_buffer=False,
lazy_load=False,
lazy_load_file=None,
is_post_adapter=False,
lora_prefix="diffusion_model.blocks",
lora_path="",
):
super().__init__(
weight_name,
bias_name,
create_cuda_buffer,
create_cpu_buffer,
lazy_load,
lazy_load_file,
is_post_adapter,
lora_prefix,
lora_path,
)
self.load_func = self.load_mxfp4
self.weight_need_transpose = False
self.act_quant_func = self.act_quant_mxfp4
self.set_alpha()
def set_alpha(self):
self.alpha = torch.tensor(1.0, dtype=torch.float32)
def apply(self, input_tensor):
input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
self.alpha = self.alpha.to(self.weight.device)
output_tensor = cutlass_scaled_mxfp4_mm(
input_tensor_quant,
self.weight,
input_tensor_scale,
self.weight_scale,
alpha=self.alpha,
bias=self._get_actual_bias(),
)
if self.has_lora_branch:
return output_tensor + self.apply_lora(input_tensor)
return output_tensor
@MM_WEIGHT_REGISTER("mxfp6-mxfp8")
class MMWeightWmxfp6Amxfp8dynamic(MMWeightQuantTemplate):
"""
Name: W-mxfp6-A-mxfp8-dynamic
Quant MM:
Weight: mxfp6
Act: mxfp8
"""
def __init__(
self,
weight_name,
bias_name,
create_cuda_buffer=False,
create_cpu_buffer=False,
lazy_load=False,
lazy_load_file=None,
is_post_adapter=False,
lora_prefix="diffusion_model.blocks",
lora_path="",
):
super().__init__(
weight_name,
bias_name,
create_cuda_buffer,
create_cpu_buffer,
lazy_load,
lazy_load_file,
is_post_adapter,
lora_prefix,
lora_path,
)
self.load_func = self.load_mxfp6
self.weight_need_transpose = False
self.act_quant_func = self.act_quant_mxfp8
self.set_alpha()
def set_alpha(self):
self.alpha = torch.tensor(1.0, dtype=torch.float32)
def apply(self, input_tensor):
input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
self.alpha = self.alpha.to(self.weight.device)
output_tensor = cutlass_scaled_mxfp6_mxfp8_mm(
input_tensor_quant,
self.weight,
input_tensor_scale,
self.weight_scale,
alpha=self.alpha,
bias=self._get_actual_bias(),
)
if self.has_lora_branch:
return output_tensor + self.apply_lora(input_tensor)
return output_tensor
@MM_WEIGHT_REGISTER("mxfp8")
class MMWeightWmxfp8Amxfp8dynamic(MMWeightQuantTemplate):
"""
Name: W-mxfp8-A-mxfp8-dynamic
Quant MM:
Weight: mxfp8
Act: mxfp8
"""
def __init__(
self,
weight_name,
bias_name,
create_cuda_buffer=False,
create_cpu_buffer=False,
lazy_load=False,
lazy_load_file=None,
is_post_adapter=False,
lora_prefix="diffusion_model.blocks",
lora_path="",
):
super().__init__(
weight_name,
bias_name,
create_cuda_buffer,
create_cpu_buffer,
lazy_load,
lazy_load_file,
is_post_adapter,
lora_prefix,
lora_path,
)
self.load_func = self.load_mxfp8
self.weight_need_transpose = False
self.act_quant_func = self.act_quant_mxfp8
self.set_alpha()
def set_alpha(self):
self.alpha = torch.tensor(1.0, dtype=torch.float32)
def apply(self, input_tensor):
input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
self.alpha = self.alpha.to(self.weight.device)
output_tensor = cutlass_scaled_mxfp8_mm(
input_tensor_quant,
self.weight,
input_tensor_scale,
self.weight_scale,
alpha=self.alpha,
bias=self._get_actual_bias(),
)
if self.has_lora_branch:
return output_tensor + self.apply_lora(input_tensor)
return output_tensor
@MM_WEIGHT_REGISTER("nvfp4")
class MMWeightWnvfp4Anvfp4dynamic(MMWeightQuantTemplate):
"""
Name: W-nvfp4-A-nvfp4-dynamic
Quant MM:
Weight: nvfp4
Act: nvfp4
"""
def __init__(
self,
weight_name,
bias_name,
create_cuda_buffer=False,
create_cpu_buffer=False,
lazy_load=False,
lazy_load_file=None,
is_post_adapter=False,
lora_prefix="diffusion_model.blocks",
lora_path="",
):
super().__init__(
weight_name,
bias_name,
create_cuda_buffer,
create_cpu_buffer,
lazy_load,
lazy_load_file,
is_post_adapter,
lora_prefix=lora_prefix,
lora_path=lora_path,
)
self.load_func = self.load_nvfp4
self.input_absmax_name = self.weight_name.replace(".weight", ".input_absmax")
self.weight_global_scale_name = self.weight_name + "_global_scale"
self.input_global_scale_name = self.weight_name.replace(".weight", ".input_global_scale")
self.alpha_name = self.weight_name.replace(".weight", ".alpha")
self.act_quant_func = self.act_quant_nvfp4
self.weight_need_transpose = False
def load_quantized(self, weight_dict):
if self.create_cuda_buffer:
self._load_cuda_buffers(weight_dict)
elif self.create_cpu_buffer:
self._load_cpu_pin_buffers()
else:
self._load_default_tensors(weight_dict)
def _load_cuda_buffers(self, weight_dict):
if self.lazy_load:
if Path(self.lazy_load_file).is_file():
lazy_load_file_path = self.lazy_load_file
else:
lazy_load_file_path = os.path.join(
self.lazy_load_file,
f"block_{self.weight_name.split('.')[1]}.safetensors",
)
with safe_open(lazy_load_file_path, framework="pt", device="cpu") as source:
(
self.weight_cuda_buffer,
self.weight_scale_cuda_buffer,
self.input_global_scale_cuda_buffer,
self.alpha_cuda_buffer,
) = self._get_cuda_tensor_pair(source, self.lazy_load)
self.bias_cuda_buffer = self._get_cuda_bias_tensor(source, self.lazy_load)
else:
source = weight_dict
(
self.weight_cuda_buffer,
self.weight_scale_cuda_buffer,
self.input_global_scale_cuda_buffer,
self.alpha_cuda_buffer,
) = self._get_cuda_tensor_pair(source, self.lazy_load)
self.bias_cuda_buffer = self._get_cuda_bias_tensor(source, self.lazy_load)
def _get_cuda_tensor_pair(self, source, is_lazy):
if is_lazy:
if self.input_absmax_name in source.keys():
input_absmax = source.get_tensor(self.input_absmax_name)
input_global_scale = (2688.0 / input_absmax).to(torch.float32).to(AI_DEVICE)
weight_global_scale = source.get_tensor(self.weight_global_scale_name).to(AI_DEVICE)
alpha = 1.0 / (input_global_scale * weight_global_scale)
else:
input_global_scale = source.get_tensor(self.input_global_scale_name).to(torch.float32).to(AI_DEVICE)
alpha = source.get_tensor(self.alpha_name).to(torch.float32).to(AI_DEVICE)
weight = source.get_tensor(self.weight_name).to(AI_DEVICE)
scale = source.get_tensor(self.weight_scale_name).to(AI_DEVICE)
else:
if self.input_absmax_name in source:
input_absmax = source[self.input_absmax_name]
input_global_scale = (2688.0 / input_absmax).to(torch.float32).to(AI_DEVICE)
weight_global_scale = source[self.weight_global_scale_name].to(AI_DEVICE)
alpha = 1.0 / (input_global_scale * weight_global_scale)
else:
input_global_scale = source[self.input_global_scale_name].to(torch.float32).to(AI_DEVICE)
alpha = source[self.alpha_name].to(torch.float32).to(AI_DEVICE)
weight = source[self.weight_name].to(AI_DEVICE)
scale = source[self.weight_scale_name].to(AI_DEVICE)
return weight, scale, input_global_scale, alpha
def _get_cuda_bias_tensor(self, source, is_lazy):
if self.bias_name is None:
return None
if is_lazy:
bias = source.get_tensor(self.bias_name)
dtype = self.infer_dtype
else:
bias = source[self.bias_name]
dtype = bias.dtype
if self.bias_force_fp32:
bias = bias.to(torch.float32)
else:
bias = bias.to(dtype)
return bias.to(AI_DEVICE)
def _load_cpu_pin_buffers(self):
(
self.pin_weight,
self.pin_weight_scale,
self.pin_input_global_scale,
self.pin_alpha,
) = self._get_cpu_pin_tensor_pair(self.lazy_load_file, is_lazy=True)
self.pin_bias = self._get_cpu_pin_bias_tensor(self.lazy_load_file, is_lazy=True)
self.bias = None
def _get_cpu_pin_tensor_pair(self, source, is_lazy):
if is_lazy:
if Path(self.lazy_load_file).is_file():
lazy_load_file_path = self.lazy_load_file
else:
lazy_load_file_path = os.path.join(
self.lazy_load_file,
f"block_{self.weight_name.split('.')[1]}.safetensors",
)
with safe_open(lazy_load_file_path, framework="pt", device="cpu") as source:
weight_tensor = source.get_tensor(self.weight_name)
scale_tensor = source.get_tensor(self.weight_scale_name)
if self.input_absmax_name in source.keys():
input_absmax = source.get_tensor(self.input_absmax_name)
input_global_scale = (2688.0 / input_absmax).to(torch.float32)
weight_global_scale = source.get_tensor(self.weight_global_scale_name)
alpha = 1.0 / (input_global_scale * weight_global_scale)
else:
input_global_scale = source.get_tensor(self.input_global_scale_name).to(torch.float32)
alpha = source.get_tensor(self.alpha_name).to(torch.float32)
pin_weight = self._create_pin_tensor(weight_tensor)
pin_scale = self._create_pin_tensor(scale_tensor)
pin_input_global_scale = self._create_pin_tensor(input_global_scale)
pin_alpha = self._create_pin_tensor(alpha)
else:
weight_tensor = source[self.weight_name]
scale_tensor = source[self.weight_scale_name]
if self.input_absmax_name in source:
input_absmax = source[self.input_absmax_name]
input_global_scale = (2688.0 / input_absmax).to(torch.float32)
weight_global_scale = source[self.weight_global_scale_name]
alpha = 1.0 / (input_global_scale * weight_global_scale)
else:
input_global_scale = source[self.input_global_scale_name].to(torch.float32)
alpha = source[self.alpha_name].to(torch.float32)
pin_weight = self._create_pin_tensor(weight_tensor)
pin_scale = self._create_pin_tensor(scale_tensor)
pin_input_global_scale = self._create_pin_tensor(input_global_scale)
pin_alpha = self._create_pin_tensor(alpha)
return pin_weight, pin_scale, pin_input_global_scale, pin_alpha
def _get_cpu_pin_bias_tensor(self, source, is_lazy):
if self.bias_name is None:
return None
if is_lazy:
if Path(self.lazy_load_file).is_file():
lazy_load_file_path = self.lazy_load_file
else:
lazy_load_file_path = os.path.join(
self.lazy_load_file,
f"block_{self.weight_name.split('.')[1]}.safetensors",
)
with safe_open(lazy_load_file_path, framework="pt", device="cpu") as source:
bias_tensor = source.get_tensor(self.bias_name)
if self.bias_force_fp32:
    bias_tensor = bias_tensor.to(torch.float32)
else:
    bias_tensor = bias_tensor.to(self.infer_dtype)
return self._create_pin_tensor(bias_tensor)
else:
bias_tensor = source[self.bias_name]
if self.bias_force_fp32:
bias_tensor = bias_tensor.to(torch.float32)
return self._create_pin_tensor(bias_tensor)
def _create_pin_tensor(self, tensor, dtype=None):
dtype = dtype or tensor.dtype
pin_tensor = torch.empty(tensor.shape, pin_memory=True, dtype=dtype)
pin_tensor.copy_(tensor)
del tensor
return pin_tensor
def _load_default_tensors(self, weight_dict):
if not self.lazy_load:
(
self.weight,
self.weight_scale,
self.input_global_scale,
self.alpha,
self.pin_weight,
self.pin_weight_scale,
self.pin_input_global_scale,
self.pin_alpha,
) = self._get_device_tensor_pair(weight_dict)
self._load_default_bias(weight_dict)
else:
self.bias = None
self.pin_bias = None
def _get_device_tensor_pair(self, source):
device = source[self.weight_name].device
if device.type == "cpu":
pin_weight, pin_scale, pin_input_global_scale, pin_alpha = self._get_cpu_pin_tensor_pair(source, is_lazy=False)
return (
None,
None,
None,
None,
pin_weight,
pin_scale,
pin_input_global_scale,
pin_alpha,
)
else:
if self.input_absmax_name in source:
input_absmax = source[self.input_absmax_name]
input_global_scale = (2688.0 / input_absmax).to(torch.float32)
weight_global_scale = source[self.weight_global_scale_name]
alpha = 1.0 / (input_global_scale * weight_global_scale)
else:
input_global_scale = source[self.input_global_scale_name].to(torch.float32).to(AI_DEVICE)
alpha = source[self.alpha_name].to(torch.float32).to(AI_DEVICE)
return (
source[self.weight_name],
source[self.weight_scale_name],
input_global_scale,
alpha,
None,
None,
None,
None,
)
def _load_default_bias(self, source):
if self.bias_name is None:
self.bias = None
self.pin_bias = None
self.bias_cuda_buffer = None
return
if self.create_cuda_buffer:
self.bias_cuda_buffer = self._get_cuda_bias_tensor(source, is_lazy=False)
self.bias = None
self.pin_bias = None
else:
bias_tensor = source[self.bias_name].float() if self.bias_force_fp32 else source[self.bias_name]
device = bias_tensor.device
if device.type == "cpu":
self.pin_bias = self._get_cpu_pin_bias_tensor(source, is_lazy=False)
self.bias = None
else:
self.bias = bias_tensor
self.pin_bias = None
def apply(self, input_tensor):
input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
output_tensor = cutlass_scaled_nvfp4_mm(
input_tensor_quant,
self.weight,
input_tensor_scale,
self.weight_scale,
alpha=self.alpha,
bias=self.bias,
)
return output_tensor
def to_cuda(self, non_blocking=False):
self.weight = self.pin_weight.to(AI_DEVICE, non_blocking=non_blocking)
if hasattr(self, "pin_weight_scale"):
self.weight_scale = self.pin_weight_scale.to(AI_DEVICE, non_blocking=non_blocking)
self.input_global_scale = self.pin_input_global_scale.to(AI_DEVICE, non_blocking=non_blocking)
self.alpha = self.pin_alpha.to(AI_DEVICE, non_blocking=non_blocking)
if hasattr(self, "pin_bias") and self.pin_bias is not None:
self.bias = self.pin_bias.to(AI_DEVICE, non_blocking=non_blocking)
def to_cpu(self, non_blocking=False):
if hasattr(self, "pin_weight"):
self.weight = self.pin_weight.copy_(self.weight, non_blocking=non_blocking).cpu()
if hasattr(self, "weight_scale_name"):
self.weight_scale = self.pin_weight_scale.copy_(self.weight_scale, non_blocking=non_blocking).cpu()
self.input_global_scale = self.pin_input_global_scale.copy_(self.input_global_scale, non_blocking=non_blocking).cpu()
self.alpha = self.pin_alpha.copy_(self.alpha, non_blocking=non_blocking).cpu()
if self.bias is not None:
self.bias = self.pin_bias.copy_(self.bias, non_blocking=non_blocking).cpu()
else:
self.weight = self.weight.to("cpu", non_blocking=non_blocking)
if hasattr(self, "weight_scale"):
self.weight_scale = self.weight_scale.to("cpu", non_blocking=non_blocking)
self.input_global_scale = self.input_global_scale.to("cpu", non_blocking=non_blocking)
self.alpha = self.alpha.to("cpu", non_blocking=non_blocking)
if hasattr(self, "bias") and self.bias is not None:
self.bias = self.bias.to("cpu", non_blocking=non_blocking)
def state_dict(self, destination=None):
if destination is None:
destination = {}
destination[self.weight_name] = self.pin_weight if hasattr(self, "pin_weight") else self.weight
if self.bias_name is not None:
destination[self.bias_name] = self.pin_bias if hasattr(self, "pin_bias") else self.bias
destination[self.weight_scale_name] = self.pin_weight_scale if hasattr(self, "pin_weight_scale") else self.weight_scale
destination[self.input_global_scale_name] = self.pin_input_global_scale if hasattr(self, "pin_input_global_scale") else self.input_global_scale
destination[self.alpha_name] = self.pin_alpha if hasattr(self, "pin_alpha") else self.alpha
return destination
def load_state_dict(self, destination, block_index, adapter_block_index=None):
weight_name = resolve_block_name(self.weight_name, block_index, adapter_block_index, self.is_post_adapter)
weight_scale_name = resolve_block_name(self.weight_scale_name, block_index, adapter_block_index, self.is_post_adapter)
input_global_scale_name = resolve_block_name(self.input_global_scale_name, block_index, adapter_block_index, self.is_post_adapter)
alpha_name = resolve_block_name(self.alpha_name, block_index, adapter_block_index, self.is_post_adapter)
if weight_name not in destination:
self.weight = None
return
self.weight = self.weight_cuda_buffer.copy_(destination[weight_name], non_blocking=True)
self.weight_scale = self.weight_scale_cuda_buffer.copy_(destination[weight_scale_name], non_blocking=True)
self.input_global_scale = self.input_global_scale_cuda_buffer.copy_(destination[input_global_scale_name], non_blocking=True)
self.alpha = self.alpha_cuda_buffer.copy_(destination[alpha_name], non_blocking=True)
if self.bias_name is not None:
bias_name = resolve_block_name(self.bias_name, block_index, adapter_block_index, self.is_post_adapter)
self.bias = self.bias_cuda_buffer.copy_(destination[bias_name], non_blocking=True)
else:
self.bias = None
def load_state_dict_from_disk(self, block_index, adapter_block_index=None):
self.weight_name = resolve_block_name(self.weight_name, block_index, adapter_block_index, self.is_post_adapter)
self.weight_scale_name = resolve_block_name(self.weight_scale_name, block_index, adapter_block_index, self.is_post_adapter)
self.input_global_scale_name = resolve_block_name(self.input_global_scale_name, block_index, adapter_block_index, self.is_post_adapter)
self.alpha_name = resolve_block_name(self.alpha_name, block_index, adapter_block_index, self.is_post_adapter)
if self.bias_name is not None:
self.bias_name = resolve_block_name(self.bias_name, block_index, adapter_block_index, self.is_post_adapter)
lazy_load_file_path = get_lazy_load_file_path(self.lazy_load_file, self.weight_name)
with safe_open(lazy_load_file_path, framework="pt", device="cpu") as lazy_load_file:
if self.weight_need_transpose:
weight_tensor = lazy_load_file.get_tensor(self.weight_name).t()
else:
weight_tensor = lazy_load_file.get_tensor(self.weight_name)
self.pin_weight = self.pin_weight.copy_(weight_tensor)
del weight_tensor
weight_scale_tensor = lazy_load_file.get_tensor(self.weight_scale_name)
self.pin_weight_scale = self.pin_weight_scale.copy_(weight_scale_tensor)
del weight_scale_tensor
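# Illustrative-only sketch (not part of the library API) of the nvfp4 scale bookkeeping
# above: when only a calibrated activation absmax is stored, the input global scale is
# 2688 / absmax (2688 = 448 * 6, the fp8-e4m3 and fp4-e2m1 maxima), and alpha folds both
# global scales back into the GEMM epilogue. The numbers below are hypothetical.
def _example_nvfp4_scale_arithmetic():
    input_absmax = torch.tensor(12.0)                # e.g. collected by the Calib class below
    weight_global_scale = torch.tensor(200.0)        # produced by offline weight quantization
    input_global_scale = (2688.0 / input_absmax).to(torch.float32)
    alpha = 1.0 / (input_global_scale * weight_global_scale)
    return input_global_scale, alpha                 # 224.0 and roughly 2.23e-5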
@MM_WEIGHT_REGISTER("Calib")
class MMCalibNvfp4(MMWeight):
"""
Name: calib
Calib:
absmax: torch.max(torch.abs(input_tensor))
"""
def __init__(
self,
weight_name,
bias_name,
create_cuda_buffer=False,
create_cpu_buffer=False,
lazy_load=False,
lazy_load_file=None,
is_post_adapter=False,
lora_prefix="diffusion_model.blocks",
lora_path="",
):
super().__init__(
weight_name,
bias_name,
create_cuda_buffer,
create_cpu_buffer,
lazy_load,
lazy_load_file,
is_post_adapter,
lora_prefix,
lora_path,
)
self.running_absmax = None
self.count = 0
self.decay = 0.9
def apply(self, input_tensor):
shape = (input_tensor.shape[0], self.weight.shape[1])
dtype, device = input_tensor.dtype, input_tensor.device
current_absmax = torch.max(torch.abs(input_tensor)).to("cpu")
if self.count % 2 == 0:
if self.running_absmax is None:
self.running_absmax = current_absmax
else:
self.running_absmax = self.decay * self.running_absmax + (1 - self.decay) * current_absmax
CALIB["absmax"][self.weight_name] = self.running_absmax
self.count = self.count + 1
output_tensor = torch.empty(shape, dtype=dtype, device=device, requires_grad=False)
if hasattr(self, "bias") and self.bias is not None:
return torch.addmm(self.bias, input_tensor, self.weight, out=output_tensor)
return torch.mm(input_tensor, self.weight, out=output_tensor)
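# Illustrative-only sketch (not part of the library API) of the running-absmax update in
# MMCalibNvfp4.apply() above: on every other call (count % 2 == 0) the current batch absmax
# is folded into an exponential moving average with decay 0.9 and exported via CALIB["absmax"].
# The loop below just replays the EMA arithmetic on hypothetical values.
def _example_calibration_ema():
    decay = 0.9
    running_absmax = None
    for current_absmax in (torch.tensor(10.0), torch.tensor(20.0), torch.tensor(5.0)):
        if running_absmax is None:
            running_absmax = current_absmax
        else:
            running_absmax = decay * running_absmax + (1 - decay) * current_absmax
    return running_absmax  # 0.9 * (0.9 * 10 + 0.1 * 20) + 0.1 * 5 = 10.4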
@MM_WEIGHT_REGISTER("fp8-q8f")
class MMWeightWfp8channelAfp8channeldynamicQ8F(MMWeightQuantTemplate):
"""
Name: W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Q8F
Quant MM:
Weight: fp8 perchannel sym
Act: fp8 perchannel dynamic sym
Kernel: Q8F
"""
def __init__(
self,
weight_name,
bias_name,
create_cuda_buffer=False,
create_cpu_buffer=False,
lazy_load=False,
lazy_load_file=None,
is_post_adapter=False,
lora_prefix="diffusion_model.blocks",
lora_path="",
):
super().__init__(
weight_name,
bias_name,
create_cuda_buffer,
create_cpu_buffer,
lazy_load,
lazy_load_file,
is_post_adapter,
lora_prefix,
lora_path,
)
self.load_func = self.load_fp8_perchannel_sym
self.weight_need_transpose = False
self.bias_force_fp32 = True
self.scale_force_fp32 = True
if ops is not None:
self.act_quant_func = self.act_quant_fp8_perchannel_sym_vllm
else:
self.act_quant_func = fp8_quantize_triton
def apply(self, input_tensor):
input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
output_tensor = fp8_linear(
input_tensor_quant,
self.weight,
self._get_actual_bias(),
input_tensor_scale.float(),
self.weight_scale,
out_dtype=self.infer_dtype,
)
if self.has_lora_branch:
return output_tensor.squeeze(0) + self.apply_lora(input_tensor) if len(output_tensor.shape) == 3 else output_tensor + self.apply_lora(input_tensor)
return output_tensor.squeeze(0) if len(output_tensor.shape) == 3 else output_tensor
@MM_WEIGHT_REGISTER("int8-q8f")
class MMWeightWint8channelAint8channeldynamicQ8F(MMWeightQuantTemplate):
"""
Name: W-int8-channel-sym-A-int8-channel-sym-dynamic-Q8F
Quant MM:
Weight: int8 perchannel sym
Act: int8 perchannel dynamic sym
Kernel: Q8F
"""
def __init__(
self,
weight_name,
bias_name,
create_cuda_buffer=False,
create_cpu_buffer=False,
lazy_load=False,
lazy_load_file=None,
is_post_adapter=False,
lora_prefix="diffusion_model.blocks",
lora_path="",
):
super().__init__(
weight_name,
bias_name,
create_cuda_buffer,
create_cpu_buffer,
lazy_load,
lazy_load_file,
is_post_adapter,
lora_prefix,
lora_path,
)
self.load_func = self.load_int8_perchannel_sym
self.weight_need_transpose = False
self.bias_force_fp32 = True
self.scale_force_fp32 = True
if ops is not None:
self.act_quant_func = self.act_quant_int8_perchannel_sym_vllm
else:
self.act_quant_func = int8_quantize_triton
def apply(self, input_tensor):
input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
output_tensor = q8_linear(
input_tensor_quant,
self.weight,
self._get_actual_bias(),
input_tensor_scale.float(),
self.weight_scale,
fuse_gelu=False,
out_dtype=self.infer_dtype,
)
if self.has_lora_branch:
return output_tensor.squeeze(0) + self.apply_lora(input_tensor) if len(output_tensor.shape) == 3 else output_tensor + self.apply_lora(input_tensor)
return output_tensor.squeeze(0) if len(output_tensor.shape) == 3 else output_tensor
@MM_WEIGHT_REGISTER("fp8-triton")
class MMWeightWfp8channelAfp8channeldynamicTriton(MMWeightQuantTemplate):
"""
Name: W-fp8-channel-sym-A-fp8-channel-sym-dynamic-triton
Quant MM:
Weight: fp8 perchannel sym
Act: fp8 perchannel dynamic sym
Kernel: triton
"""
def __init__(
self,
weight_name,
bias_name,
create_cuda_buffer=False,
create_cpu_buffer=False,
lazy_load=False,
lazy_load_file=None,
is_post_adapter=False,
lora_prefix="diffusion_model.blocks",
lora_path="",
):
super().__init__(
weight_name,
bias_name,
create_cuda_buffer,
create_cpu_buffer,
lazy_load,
lazy_load_file,
is_post_adapter,
lora_prefix,
lora_path,
)
self.load_func = self.load_fp8_perchannel_sym
self.act_quant_func = fp8_quantize_triton
self.weight_need_transpose = False
self.bias_force_fp32 = True
def apply(self, input_tensor):
input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
if self.bias is not None:
output_tensor = fp8_gemm_bias_triton(
input_tensor_quant,
self.weight,
self._get_actual_bias(),
input_tensor_scale,
self.weight_scale,
output_dtype=self.infer_dtype,
)
else:
output_tensor = fp8_gemm_triton(
input_tensor_quant,
self.weight,
input_tensor_scale,
self.weight_scale,
output_dtype=self.infer_dtype,
)
if self.has_lora_branch:
return output_tensor.squeeze(0) + self.apply_lora(input_tensor) if len(output_tensor.shape) == 3 else output_tensor + self.apply_lora(input_tensor)
return output_tensor.squeeze(0) if len(output_tensor.shape) == 3 else output_tensor
@MM_WEIGHT_REGISTER("int8-triton")
class MMWeightWint8channelAint8channeldynamicTriton(MMWeightQuantTemplate):
"""
Name: W-int8-channel-sym-A-int8-channel-sym-dynamic-triton
Quant MM:
Weight: int8 perchannel sym
Act: int8 perchannel dynamic sym
Kernel: triton
"""
def __init__(
self,
weight_name,
bias_name,
create_cuda_buffer=False,
create_cpu_buffer=False,
lazy_load=False,
lazy_load_file=None,
is_post_adapter=False,
lora_prefix="diffusion_model.blocks",
lora_path="",
):
super().__init__(
weight_name,
bias_name,
create_cuda_buffer,
create_cpu_buffer,
lazy_load,
lazy_load_file,
is_post_adapter,
lora_prefix,
lora_path,
)
self.load_func = self.load_int8_perchannel_sym
self.act_quant_func = int8_quantize_triton
self.weight_need_transpose = False
self.bias_force_fp32 = True
def apply(self, input_tensor):
input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
if self.bias is not None:
output_tensor = int8_gemm_bias_triton(
input_tensor_quant,
self.weight,
self._get_actual_bias(),
input_tensor_scale,
self.weight_scale,
output_dtype=self.infer_dtype,
)
else:
output_tensor = int8_gemm_triton(
input_tensor_quant,
self.weight,
input_tensor_scale,
self.weight_scale,
output_dtype=self.infer_dtype,
)
if self.has_lora_branch:
            return output_tensor.squeeze(0) + self.apply_lora(input_tensor) if len(output_tensor.shape) == 3 else output_tensor + self.apply_lora(input_tensor)
return output_tensor.squeeze(0) if len(output_tensor.shape) == 3 else output_tensor
@MM_WEIGHT_REGISTER("fp8-b128-deepgemm")
class MMWeightWfp8block128Afp8channelgroup128dynamicDeepgemmActSgl(MMWeightQuantTemplate):
"""
Name: W-fp8-block128-sym-A-fp8-channel-group128-sym-dynamic-Deepgemm-ActSgl
Quant MM:
Weight: fp8 perblock 128x128 sym
Act: fp8 pertoken-pergroup group=128 dynamic sym
Kernel: quant-mm using Deepgemm, act dynamic quant using Sgl-kernel
"""
def __init__(
self,
weight_name,
bias_name,
create_cuda_buffer=False,
create_cpu_buffer=False,
lazy_load=False,
lazy_load_file=None,
is_post_adapter=False,
lora_prefix="diffusion_model.blocks",
lora_path="",
):
super().__init__(
weight_name,
bias_name,
create_cuda_buffer,
create_cpu_buffer,
lazy_load,
lazy_load_file,
is_post_adapter,
lora_prefix,
lora_path,
)
self.load_func = self.load_fp8_perblock128_sym
self.weight_need_transpose = False
self.act_quant_func = self.act_quant_fp8_perchannelgroup128_sym_sgl
def apply(self, input_tensor):
shape = (input_tensor.shape[0], self.weight.shape[0])
dtype = input_tensor.dtype
device = input_tensor.device
output_tensor = torch.empty(shape, dtype=dtype, device=device, requires_grad=False)
input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
deep_gemm.gemm_fp8_fp8_bf16_nt(
(input_tensor_quant, input_tensor_scale),
(self.weight, self.weight_scale),
output_tensor,
)
if hasattr(self, "bias") and self.bias is not None:
output_tensor.add_(self._get_actual_bias())
if self.has_lora_branch:
return output_tensor + self.apply_lora(input_tensor)
return output_tensor
@MM_WEIGHT_REGISTER("fp8-sgl")
class MMWeightWfp8channelAfp8channeldynamicSgl(MMWeightQuantTemplate):
"""
Name: W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Sgl
Quant MM:
Weight: fp8 perchannel sym
Act: fp8 perchannel dynamic sym
Kernel: Sgl-kernel
"""
def __init__(
self,
weight_name,
bias_name,
create_cuda_buffer=False,
create_cpu_buffer=False,
lazy_load=False,
lazy_load_file=None,
is_post_adapter=False,
lora_prefix="diffusion_model.blocks",
lora_path="",
):
super().__init__(
weight_name,
bias_name,
create_cuda_buffer,
create_cpu_buffer,
lazy_load,
lazy_load_file,
is_post_adapter,
lora_prefix,
lora_path,
)
self.weight_need_transpose = True
self.scale_force_fp32 = True
self.load_func = self.load_fp8_perchannel_sym
self.act_quant_func = self.act_quant_fp8_perchannel_sym_sgl
def apply(self, input_tensor):
input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
output_tensor = sgl_kernel.fp8_scaled_mm(
input_tensor_quant,
self.weight,
input_tensor_scale,
self.weight_scale,
self.infer_dtype,
self._get_actual_bias(),
)
if self.has_lora_branch:
return output_tensor + self.apply_lora(input_tensor)
return output_tensor
@MM_WEIGHT_REGISTER("int8-sgl")
class MMWeightWint8channelAint8channeldynamicSglActVllm(MMWeightQuantTemplate):
"""
Name: W-int8-channel-sym-A-int8-channel-sym-dynamic-Sgl-ActVllm
Quant MM:
Weight: int8 perchannel sym
Act: int8 perchannel dynamic sym
Kernel: quant-mm using Sgl-kernel, act dynamic quant using vllm
"""
def __init__(
self,
weight_name,
bias_name,
create_cuda_buffer=False,
create_cpu_buffer=False,
lazy_load=False,
lazy_load_file=None,
is_post_adapter=False,
lora_prefix="diffusion_model.blocks",
lora_path="",
):
super().__init__(
weight_name,
bias_name,
create_cuda_buffer,
create_cpu_buffer,
lazy_load,
lazy_load_file,
is_post_adapter,
lora_prefix,
lora_path,
)
self.load_func = self.load_int8_perchannel_sym
self.act_quant_func = self.act_quant_int8_perchannel_sym_vllm
self.weight_need_transpose = True
self.scale_force_fp32 = True
def apply(self, input_tensor):
input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
output_tensor = sgl_kernel.int8_scaled_mm(
input_tensor_quant,
self.weight,
input_tensor_scale,
self.weight_scale,
self.infer_dtype,
self._get_actual_bias(),
)
if self.has_lora_branch:
return output_tensor + self.apply_lora(input_tensor)
return output_tensor
@MM_WEIGHT_REGISTER("fp8-torchao")
class MMWeightWfp8channelAfp8channeldynamicTorchao(MMWeightQuantTemplate):
"""
Name: W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Torchao
Quant MM:
Weight: fp8 perchannel sym
Act: fp8 perchannel dynamic sym
Kernel: Torchao
"""
def __init__(
self,
weight_name,
bias_name,
create_cuda_buffer=False,
create_cpu_buffer=False,
lazy_load=False,
lazy_load_file=None,
is_post_adapter=False,
lora_prefix="diffusion_model.blocks",
lora_path="",
):
super().__init__(
weight_name,
bias_name,
create_cuda_buffer,
create_cpu_buffer,
lazy_load,
lazy_load_file,
is_post_adapter,
lora_prefix,
lora_path,
)
self.load_func = self.load_fp8_perchannel_sym
self.act_quant_func = self.act_quant_fp8_perchannel_sym_torchao
self.weight_need_transpose = True
self.scale_force_fp32 = True
def apply(self, input_tensor):
input_tensor_quant, input_tensor_scale = self.act_quant_fp8_perchannel_sym_torchao(input_tensor)
output_tensor = torch._scaled_mm(
input_tensor_quant,
self.weight,
scale_a=input_tensor_scale.float(),
scale_b=self.weight_scale.t(),
bias=self._get_actual_bias(),
out_dtype=self.infer_dtype,
use_fast_accum=True,
)
if self.has_lora_branch:
return output_tensor + self.apply_lora(input_tensor)
return output_tensor
@MM_WEIGHT_REGISTER("int8-torchao")
class MMWeightWint8channelAint8channeldynamicTorchao(MMWeightQuantTemplate):
"""
Name: W-int8-channel-sym-A-int8-channel-sym-dynamic-Torchao
Quant MM:
Weight: int8 perchannel sym
Act: int8 perchannel dynamic sym
Kernel: Torchao
"""
def __init__(
self,
weight_name,
bias_name,
create_cuda_buffer=False,
create_cpu_buffer=False,
lazy_load=False,
lazy_load_file=None,
is_post_adapter=False,
lora_prefix="diffusion_model.blocks",
lora_path="",
):
super().__init__(
weight_name,
bias_name,
create_cuda_buffer,
create_cpu_buffer,
lazy_load,
lazy_load_file,
is_post_adapter,
lora_prefix,
lora_path,
)
self.load_func = self.load_int8_perchannel_sym
self.weight_need_transpose = True
self.act_quant_func = self.act_quant_int8_perchannel_sym_torchao
def apply(self, input_tensor):
input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
output_tensor = torchao_int8_gemm(
input_tensor_quant,
input_tensor_scale,
self.weight,
self.weight_scale.t().float(),
output_dtype=self.infer_dtype,
)
if self.bias is not None:
output_tensor.add_(self._get_actual_bias())
if self.has_lora_branch:
return output_tensor + self.apply_lora(input_tensor)
return output_tensor
class MMWeightGGUFTemplate(MMWeightTemplate):
def __init__(
self,
weight_name,
bias_name,
create_cuda_buffer=False,
create_cpu_buffer=False,
lazy_load=False,
lazy_load_file=None,
is_post_adapter=False,
lora_prefix="diffusion_model.blocks",
lora_path="",
):
super().__init__(
weight_name,
bias_name,
create_cuda_buffer,
create_cpu_buffer,
lazy_load,
lazy_load_file,
is_post_adapter,
lora_prefix,
lora_path,
)
def load(self, weight_dict):
if not self.lazy_load:
assert not self.create_cuda_buffer, "GGUF Unsupported offload block"
self.weight = weight_dict[self.weight_name]
weight_shape = self.weight.shape
weight_dtype = self.weight.dtype
if isinstance(self.weight, GGMLTensor):
self.pin_weight = GGMLTensor.empty_pinned(
weight_shape,
orig_shape=self.weight.orig_shape,
dtype=weight_dtype,
gguf_type=self.weight.gguf_type,
)
self.pin_weight.copy_from(self.weight)
else:
self.pin_weight = torch.empty(weight_shape, pin_memory=True, dtype=weight_dtype)
self.pin_weight.copy_(weight_dict[self.weight_name])
if self.bias_name is not None:
self.bias = weight_dict[self.bias_name]
if isinstance(self.bias, GGMLTensor):
self.pin_bias = GGMLTensor.empty_pinned(
self.bias.shape,
orig_shape=self.bias.orig_shape,
dtype=self.bias.dtype,
gguf_type=self.bias.gguf_type,
)
self.pin_bias.copy_from(self.bias)
else:
self.pin_bias = torch.empty(self.bias.shape, pin_memory=True, dtype=self.bias.dtype)
self.pin_bias.copy_(weight_dict[self.bias_name])
else:
self.bias = None
def load_state_dict(self, destination, block_index, adapter_block_index=None):
if self.is_post_adapter:
assert adapter_block_index is not None
weight_name = re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", self.weight_name, count=1)
else:
weight_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.weight_name, count=1)
if weight_name not in destination:
self.weight = None
return
self.weight = self.weight_cuda_buffer.copy_(destination[weight_name], non_blocking=True)
if self.bias_name is not None:
if self.is_post_adapter:
assert adapter_block_index is not None
bias_name = re.sub(
r"\.\d+",
lambda m: f".{adapter_block_index}",
self.bias_name,
count=1,
)
else:
bias_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.bias_name, count=1)
self.bias = self.bias_cuda_buffer.copy_(destination[bias_name], non_blocking=True)
else:
self.bias = None
def state_dict(self, destination=None):
if destination is None:
destination = {}
destination[self.weight_name] = self.pin_weight if hasattr(self, "pin_weight") else self.weight
if self.bias_name is not None:
destination[self.bias_name] = self.pin_bias if hasattr(self, "pin_bias") else self.bias
return destination
def get_weight(self, tensor, dtype):
if tensor is None:
return
weight = gguf_dequantize_tensor(tensor, dtype)
if isinstance(weight, GGMLTensor):
weight = torch.Tensor(weight)
return weight
def cast_bias_weight(self, input_tensor=None, dtype=None, device=None, bias_dtype=None):
if input_tensor is not None:
if dtype is None:
dtype = getattr(input_tensor, "dtype", torch.float32)
bias = None
if self.bias is not None:
bias = self.get_weight(self.bias, dtype)
weight = self.get_weight(self.weight, dtype)
return weight, bias
def apply(self, input_tensor):
weight, bias = self.cast_bias_weight(input_tensor)
if self.has_lora_branch:
return torch.nn.functional.linear(input_tensor, weight, self._get_actual_bias(bias)) + self.apply_lora(input_tensor)
return torch.nn.functional.linear(input_tensor, weight, self._get_actual_bias(bias))
@MM_WEIGHT_REGISTER("gguf-BF16")
class MMWeightGGUFBF16(MMWeightGGUFTemplate):
qtype = gguf.GGMLQuantizationType.BF16
@MM_WEIGHT_REGISTER("gguf-Q8_0")
class MMWeightGGUFQ80(MMWeightGGUFTemplate):
qtype = gguf.GGMLQuantizationType.Q8_0
@MM_WEIGHT_REGISTER("gguf-Q6_K")
class MMWeightGGUFQ6K(MMWeightGGUFTemplate):
qtype = gguf.GGMLQuantizationType.Q6_K
@MM_WEIGHT_REGISTER("gguf-Q5_K_S")
class MMWeightGGUFQ5KS(MMWeightGGUFTemplate):
    qtype = gguf.GGMLQuantizationType.Q5_K
@MM_WEIGHT_REGISTER("gguf-Q5_K_M")
class MMWeightGGUFQ5KM(MMWeightGGUFTemplate):
    qtype = gguf.GGMLQuantizationType.Q5_K
@MM_WEIGHT_REGISTER("gguf-Q5_1")
class MMWeightGGUFQ51(MMWeightGGUFTemplate):
qtype = gguf.GGMLQuantizationType.Q5_1
@MM_WEIGHT_REGISTER("gguf-Q5_0")
class MMWeightGGUFQ50(MMWeightGGUFTemplate):
qtype = gguf.GGMLQuantizationType.Q5_0
@MM_WEIGHT_REGISTER("gguf-Q4_K_M")
class MMWeightGGUFQ4KM(MMWeightGGUFTemplate):
    qtype = gguf.GGMLQuantizationType.Q4_K
@MM_WEIGHT_REGISTER("gguf-Q4_K_S")
class MMWeightGGUFQ4KS(MMWeightGGUFTemplate):
qtype = gguf.GGMLQuantizationType.Q4_K
@MM_WEIGHT_REGISTER("gguf-Q4_1")
class MMWeightGGUFQ41(MMWeightGGUFTemplate):
qtype = gguf.GGMLQuantizationType.Q4_1
@MM_WEIGHT_REGISTER("gguf-Q4_0")
class MMWeightGGUFQ40(MMWeightGGUFTemplate):
qtype = gguf.GGMLQuantizationType.Q4_0
@MM_WEIGHT_REGISTER("gguf-Q3_K_M")
class MMWeightGGUFQ3KM(MMWeightGGUFTemplate):
qtype = gguf.GGMLQuantizationType.Q3_K
@MM_WEIGHT_REGISTER("gguf-Q3_K_S")
class MMWeightGGUFQ3KS(MMWeightGGUFTemplate):
    qtype = gguf.GGMLQuantizationType.Q3_K
@MM_WEIGHT_REGISTER("int4-g128-marlin")
class MMWeightWint4group128Marlin(MMWeightQuantTemplate):
"""
Name: "W-int4-group128-sym-Marlin
Quant int4 x FP16:
Weight: int4 pergroup sym
Kernel: Marlin
"""
def __init__(
self,
weight_name,
bias_name,
create_cuda_buffer=False,
create_cpu_buffer=False,
lazy_load=False,
lazy_load_file=None,
is_post_adapter=False,
lora_prefix="diffusion_model.blocks",
lora_path="",
):
super().__init__(
weight_name,
bias_name,
create_cuda_buffer,
create_cpu_buffer,
lazy_load,
lazy_load_file,
is_post_adapter,
lora_prefix,
lora_path,
)
self.load_func = self.load_quantized
def load(self, weight_dict):
assert not self.lazy_load
self.load_func(weight_dict)
self.workspace = weight_dict[f"{self.weight_name}_workspace"]
if self.bias_name is not None:
bias_shape = weight_dict[self.bias_name].shape
bias_dtype = weight_dict[self.bias_name].dtype
self.bias = torch.empty(bias_shape, pin_memory=True, dtype=bias_dtype)
self.bias.copy_(weight_dict[self.bias_name])
else:
self.bias = None
def apply(self, input_tensor):
output_tensor = torch.empty(
input_tensor.shape[:-1] + (self.weight_scale.shape[1],),
dtype=input_tensor.dtype,
device=input_tensor.device,
)
marlin_cuda_quant.mul(
input_tensor,
self.weight,
output_tensor,
self.weight_scale.half(),
self.workspace,
-1,
-1,
-1,
-1,
)
if hasattr(self, "bias") and self.bias is not None:
output_tensor.add_(self._get_actual_bias())
if self.has_lora_branch:
return output_tensor + self.apply_lora(input_tensor)
return output_tensor
@MM_WEIGHT_REGISTER("fp8-pertensor")
class MMWeightWfp8tensorAfp8tensordynamic(MMWeightQuantTemplate):
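    """
    Name: W-fp8-tensor-sym-A-fp8-tensor-sym
    Quant MM:
        Weight: fp8 pertensor sym
        Act: fp8 pertensor sym, static scale loaded from the `.input_scale` tensor
        Kernel: torch._scaled_mm
    """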
def __init__(
self,
weight_name,
bias_name,
create_cuda_buffer=False,
create_cpu_buffer=False,
lazy_load=False,
lazy_load_file=None,
is_post_adapter=False,
lora_prefix="diffusion_model.blocks",
lora_path="",
):
super().__init__(
weight_name,
bias_name,
create_cuda_buffer,
create_cpu_buffer,
lazy_load,
lazy_load_file,
is_post_adapter,
lora_prefix,
lora_path,
)
self.load_func = self.load_fp8_pertensor_sym
self.act_quant_func = self.act_quant_fp8_pertensor_sym
self.weight_need_transpose = True
self.scale_force_fp32 = True
def _update_base_attrs(self):
super()._update_base_attrs()
self.input_scale_name = self.weight_name.removesuffix(".weight") + ".input_scale"
self.base_attrs.append((self.input_scale_name, "input_scale", False))
def load_quantized(self, weight_dict):
super().load_quantized(weight_dict)
if not self.create_cuda_buffer and not self.create_cpu_buffer and not self.lazy_load:
device_tensors, pin_tensors = create_default_tensors(self.base_attrs, weight_dict)
self.input_scale = device_tensors.get("input_scale")
self.pin_input_scale = pin_tensors.get("input_scale")
elif self.create_cuda_buffer:
result = create_cuda_buffers(self.base_attrs, weight_dict, self.lazy_load, self.lazy_load_file, scale_force_fp32=self.scale_force_fp32, bias_force_fp32=self.bias_force_fp32)
self.input_scale_cuda_buffer = result.get("input_scale")
elif self.create_cpu_buffer:
result = create_cpu_buffers(self.base_attrs, self.lazy_load_file, scale_force_fp32=self.scale_force_fp32, bias_force_fp32=self.bias_force_fp32)
self.pin_input_scale = result.get("input_scale")
self.input_scale = None
def post_process(self):
super().post_process()
if self.scale_force_fp32:
if hasattr(self, "input_scale") and self.input_scale is not None:
self.input_scale = self.input_scale.to(torch.float32)
if hasattr(self, "pin_input_scale") and self.pin_input_scale is not None:
self.pin_input_scale = self.pin_input_scale.to(torch.float32)
def load_fp8_pertensor_sym(self, weight_dict):
if self.config.get("weight_auto_quant", False):
raise NotImplementedError
else:
self.load_quantized(weight_dict)
def act_quant_fp8_pertensor_sym(self, x):
quantized = torch.clamp(x / self.input_scale, -448, 448).to(torch.float8_e4m3fn)
return quantized
def apply(self, input_tensor):
        dtype = input_tensor.dtype
input_tensor_quant = self.act_quant_func(input_tensor)
output_tensor = torch._scaled_mm(
input_tensor_quant,
self.weight,
scale_a=self.input_scale,
scale_b=self.weight_scale.reshape(1),
bias=self._get_actual_bias(),
out_dtype=dtype,
use_fast_accum=True,
)
if self.has_lora_branch:
return output_tensor + self.apply_lora(input_tensor)
return output_tensor
@MM_WEIGHT_REGISTER("TensorParallel")
class MMWeightTP(MMWeightTemplate):
"""
Tensor Parallel wrapper for any MMWeight type.
This is a generic wrapper that can wrap any MMWeight implementation (Default, fp8, int8, etc.)
and add tensor parallelism support by:
1. Handling weight splitting in load() method
2. Adding all-reduce for row-wise split in apply() method
Supports column-wise and row-wise weight splitting:
- Column split: weight [in_dim, out_dim] -> [in_dim, out_dim/tp_size] per rank
- Row split: weight [in_dim, out_dim] -> [in_dim/tp_size, out_dim] per rank
"""
def __init__(
self,
weight_name,
bias_name,
mm_type="Default",
tp_group=None,
tp_rank=0,
tp_size=1,
split_dim="col",
create_cuda_buffer=False,
create_cpu_buffer=False,
lazy_load=False,
lazy_load_file=None,
is_post_adapter=False,
lora_prefix="diffusion_model.blocks",
lora_path="",
):
super().__init__(
weight_name,
bias_name,
create_cuda_buffer,
create_cpu_buffer,
lazy_load,
lazy_load_file,
is_post_adapter,
lora_prefix,
lora_path,
)
self.tp_group = tp_group
self.tp_rank = tp_rank
self.tp_size = tp_size
self.split_dim = split_dim # "col" for column split, "row" for row split
assert split_dim in ["col", "row"], f"split_dim must be 'col' or 'row', got {split_dim}"
self._mm = MM_WEIGHT_REGISTER.get(mm_type, MMWeight)(
weight_name=weight_name,
bias_name=bias_name,
create_cuda_buffer=create_cuda_buffer,
create_cpu_buffer=create_cpu_buffer,
lazy_load=lazy_load,
lazy_load_file=lazy_load_file,
is_post_adapter=is_post_adapter,
lora_prefix=lora_prefix,
lora_path=lora_path,
)
self._row_split_bias = None
def load(self, weight_dict):
"""Load weights using internal MMWeight's load method.
Note: Weights in weight_dict are already split by _load_weights_from_rank0.
The format is [out_dim/tp_size, in_dim] for column split or [out_dim, in_dim/tp_size] for row split.
MMWeight.load will handle the transposition via create_default_tensors.
For row split, bias is not split and should be added after all-reduce.
We temporarily remove bias from _mm to prevent it from being added before all-reduce.
"""
self._mm.load(weight_dict)
if self.split_dim == "row" and self.bias_name is not None and self.bias_name in weight_dict:
self._row_split_bias = self._mm.bias.clone()
self._mm.bias = None
def apply(self, input_tensor):
"""Apply matrix multiplication with tensor parallel support."""
# Use internal MMWeight's apply method (handles fp8, int8, etc.)
# For row split, _mm.bias is None, so bias won't be added here
output = self._mm.apply(input_tensor)
# For row split, need all-reduce to combine results from all ranks
if self.split_dim == "row" and self.tp_size > 1 and self.tp_group is not None:
dist.all_reduce(output, op=dist.ReduceOp.SUM, group=self.tp_group)
# Add bias after all-reduce (bias is not split for row split)
if self._row_split_bias is not None:
output = output + self._row_split_bias
return output
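# Usage sketch (illustrative only, never executed on import): wrapping a plain MM weight
# with tensor parallelism. The weight/bias key names below are hypothetical placeholders;
# in the real pipeline the per-rank shard of weight_dict comes from the rank-0 loader.
def _mm_weight_tp_usage_example(weight_dict, x, tp_group, tp_rank, tp_size):
    mm = MMWeightTP(
        weight_name="blocks.0.attn.qkv.weight",  # hypothetical key
        bias_name="blocks.0.attn.qkv.bias",  # hypothetical key
        mm_type="Default",
        tp_group=tp_group,
        tp_rank=tp_rank,
        tp_size=tp_size,
        split_dim="col",  # each rank holds out_dim / tp_size output channels
    )
    mm.load(weight_dict)  # weight_dict already contains this rank's shard
    return mm.apply(x)  # with split_dim="row", apply() also all-reduces the partial sums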
import torch
from triton import Config, autotune, cdiv, jit, next_power_of_2
from triton import language as tl
_ordered_datatypes = [torch.int8, torch.float16, torch.bfloat16, torch.float32]
@jit
def gelu(x):
return x * tl.sigmoid(x * 1.702)
@jit
def int8_quantize_kernel(X, OUT, SCALES, HDIM, BLOCK_SIZE: tl.constexpr):
row_idx = tl.program_id(0)
x_ptr = X + row_idx * HDIM
out_ptr = OUT + row_idx * HDIM
h_offset = tl.arange(0, BLOCK_SIZE)
x = tl.load(x_ptr + h_offset, mask=h_offset < HDIM).to(tl.float32)
x_scale = 127.0 / tl.max(tl.abs(x))
x_scaled = x * x_scale
    x_scaled += 0.5 * tl.where(x_scaled >= 0, 1, -1)
    x_scaled = x_scaled.to(tl.int8)
tl.store(out_ptr + h_offset, x_scaled, mask=h_offset < HDIM)
tl.store(SCALES + row_idx, 1 / x_scale)
def int8_quantize_triton(x):
x_shape_orig = x.shape
x = x.view(-1, x_shape_orig[-1])
out = torch.empty(x_shape_orig, dtype=torch.int8, device=x.device)
scales = torch.empty(x.shape[0], dtype=torch.float32, device=x.device)
BLOCK_SIZE = next_power_of_2(x_shape_orig[-1])
grid = (x.shape[0],)
int8_quantize_kernel[grid](x, out, scales, x_shape_orig[-1], BLOCK_SIZE, num_warps=8)
return out.view(x_shape_orig), scales.view(x_shape_orig[:-1])
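# Eager reference for int8_quantize_kernel above, kept as documentation of the kernel's
# contract (per-token absmax scaling to int8, returning scales = absmax / 127). Sketch
# only, modulo rounding-mode details; it is not called anywhere in this module.
def _int8_quantize_reference(x):
    x2 = x.float().reshape(-1, x.shape[-1])
    absmax = x2.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8)
    scale = absmax / 127.0
    q = torch.clamp(torch.round(x2 / scale), -127, 127).to(torch.int8)
    return q.reshape(x.shape), scale.reshape(x.shape[:-1]).to(torch.float32)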
@jit
def fp8_quantize_kernel(X, OUT, SCALES, HDIM, BLOCK_SIZE: tl.constexpr, FP8_MAX_VAL: tl.constexpr):
row_idx = tl.program_id(0)
x_ptr = X + row_idx * HDIM
out_ptr = OUT + row_idx * HDIM
h_offset = tl.arange(0, BLOCK_SIZE)
x = tl.load(x_ptr + h_offset, mask=h_offset < HDIM).to(tl.float32)
absmax = tl.max(tl.abs(x))
eps = 1e-8
absmax = tl.maximum(absmax, eps)
x_scale = absmax / FP8_MAX_VAL
x_scaled = x / x_scale
x_scaled = tl.clamp(x_scaled, -FP8_MAX_VAL, FP8_MAX_VAL)
tl.store(out_ptr + h_offset, x_scaled, mask=h_offset < HDIM)
tl.store(SCALES + row_idx, x_scale)
def fp8_quantize_triton(x):
x_shape_orig = x.shape
x = x.view(-1, x_shape_orig[-1])
out_scaled = torch.empty(x_shape_orig, dtype=torch.float32, device=x.device)
scales = torch.empty(x.shape[0], dtype=torch.bfloat16, device=x.device)
BLOCK_SIZE = next_power_of_2(x_shape_orig[-1])
grid = (x.shape[0],)
FP8_MAX = 448.0
fp8_quantize_kernel[grid](x, out_scaled, scales, x_shape_orig[-1], BLOCK_SIZE, FP8_MAX_VAL=FP8_MAX, num_warps=8)
quantized = out_scaled.to(torch.float8_e4m3fn)
return quantized.view(x_shape_orig), scales.view(x_shape_orig[:-1])
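# Same idea as the int8 reference above, but for fp8: the per-token scale is
# absmax / 448 (448 being roughly the max normal float8_e4m3fn value) and the scale
# itself, not its inverse, is returned, matching what fp8_quantize_kernel stores.
# Documentation sketch only.
def _fp8_quantize_reference(x):
    x2 = x.float().reshape(-1, x.shape[-1])
    scale = x2.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 448.0
    q = torch.clamp(x2 / scale, -448.0, 448.0).to(torch.float8_e4m3fn)
    return q.reshape(x.shape), scale.reshape(x.shape[:-1]).to(torch.bfloat16)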
def upcast_if_fp8(a):
if "fp8" in str(a):
return torch.float16
return a
def get_higher_dtype(a, b):
a = upcast_if_fp8(a)
b = upcast_if_fp8(b)
if a is b:
return a
assert a in _ordered_datatypes
assert b in _ordered_datatypes
for d in _ordered_datatypes:
if a is d:
return b
if b is d:
return a
@autotune(
configs=[
Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=3, num_warps=8),
Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=8),
Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=3, num_warps=8),
Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=8),
],
key=["M", "N", "K"],
)
@jit
def int8_gemm_bias_kernel(
A,
B,
BIAS,
A_SCALES,
B_SCALES,
C,
M,
N,
K, #
stride_am,
stride_ak, #
stride_bk,
stride_bn, #
stride_cm,
stride_cn, #
acc_dtype: tl.constexpr, #
fuse_gelu: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr, #
GROUP_M: tl.constexpr,
SPLIT_K: tl.constexpr,
EVEN_K: tl.constexpr,
AB_DTYPE: tl.constexpr, #
):
# matrix multiplication
pid = tl.program_id(0)
pid_z = tl.program_id(1)
grid_m = tl.cdiv(M, BLOCK_M)
grid_n = tl.cdiv(N, BLOCK_N)
# re-order program ID for better L2 performance
width = GROUP_M * grid_n
group_id = pid // width
group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
pid_m = group_id * GROUP_M + (pid % group_size)
pid_n = (pid % width) // (group_size)
# do matrix multiplication
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K)
# pointers
A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=acc_dtype)
for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)):
if EVEN_K:
a = tl.load(A)
b = tl.load(B)
else:
k_remaining = K - k * (BLOCK_K * SPLIT_K)
_0 = tl.zeros((1, 1), dtype=C.dtype.element_ty)
a = tl.load(A, mask=rk[None, :] < k_remaining, other=_0)
b = tl.load(B, mask=rk[:, None] < k_remaining, other=_0)
if AB_DTYPE is not None:
a = a.to(AB_DTYPE)
b = b.to(AB_DTYPE)
acc = tl.dot(a, b, acc, out_dtype=acc_dtype, input_precision=None)
A += BLOCK_K * SPLIT_K * stride_ak
B += BLOCK_K * SPLIT_K * stride_bk
acc = acc.to(tl.float32)
a_scales_ptr = A_SCALES + pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
b_scales_ptr = B_SCALES + pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
a_scales = tl.load(a_scales_ptr) # [BM]
b_scales = tl.load(b_scales_ptr) # [BN]
# [BM, BN] * [BM, 1] * [1, BN]
bias_ptr = BIAS + pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
bias = tl.load(bias_ptr)
if fuse_gelu:
acc = gelu(((acc * a_scales[:, None]) * b_scales[None, :]) + bias[None, :])
else:
acc = ((acc * a_scales[:, None]) * b_scales[None, :]) + bias[None, :]
acc = acc.to(C.dtype.element_ty)
# rematerialize rm and rn to save registers
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
mask = (rm < M)[:, None] & (rn < N)[None, :]
# handles write-back with reduction-splitting
if SPLIT_K == 1:
tl.store(C, acc, mask=mask)
else:
tl.atomic_add(C, acc, mask=mask)
# @torch.compiler.disable()
def int8_gemm_bias_triton(a, b, bias, a_scales, b_scales, fuse_gelu=False, output_dtype=None):
device = a.device
# handle non-contiguous inputs if necessary
a_orig_shape = a.shape
a = a.view(-1, a.shape[-1])
b = b.t()
if a.stride(0) > 1 and a.stride(1) > 1:
a = a.contiguous()
if b.stride(0) > 1 and b.stride(1) > 1:
b = b.contiguous()
# checks constraints
assert a.shape[1] == b.shape[0], f"incompatible dimensions {a.shape} and {b.shape}"
M, K = a.shape
_, N = b.shape
out_shape = a_orig_shape[:-1] + (N,)
# common type between a and b
ab_dtype = get_higher_dtype(a.dtype, b.dtype)
# allocates output
if output_dtype is None:
output_dtype = ab_dtype
c = torch.empty((M, N), device=device, dtype=output_dtype)
# Allowed types for acc_type given the types of a and b.
supported_acc_dtypes = {
torch.float16: (torch.float32, torch.float16),
torch.bfloat16: (torch.float32, torch.bfloat16),
torch.float32: (torch.float32,),
torch.int8: (torch.int32,),
}
acc_dtype = supported_acc_dtypes[ab_dtype][0]
def to_tl_type(ty):
return getattr(tl, str(ty).split(".")[-1])
acc_dtype = to_tl_type(acc_dtype)
ab_dtype = to_tl_type(ab_dtype)
output_dtype = to_tl_type(output_dtype)
# Tensor cores support input with mixed float8 types.
if a.dtype in [tl.float8e4nv, tl.float8e5] and b.dtype in [
tl.float8e4nv,
tl.float8e5,
]:
ab_dtype = None
# launch kernel
grid = lambda META: ( # noqa E731
cdiv(M, META["BLOCK_M"]) * cdiv(N, META["BLOCK_N"]),
META["SPLIT_K"],
) # noqa E731
int8_gemm_bias_kernel[grid](
a,
b,
bias,
a_scales,
b_scales,
c,
M,
N,
K, #
a.stride(0),
a.stride(1), #
b.stride(0),
b.stride(1), #
c.stride(0),
c.stride(1), #
acc_dtype=acc_dtype, #
fuse_gelu=fuse_gelu,
GROUP_M=8,
EVEN_K=True,
AB_DTYPE=ab_dtype,
)
return c.view(*out_shape)
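# Dequantized reference for int8_gemm_bias_triton: the kernel computes
# C = (A_q @ W_q^T) * a_scales[:, None] * b_scales[None, :] + bias, with A_q [M, K] int8,
# W_q [N, K] int8 (the weight layout passed in, before the internal transpose),
# a_scales per token and b_scales per output channel. Sketch for documentation/testing only.
def _int8_gemm_bias_reference(a_q, w_q, bias, a_scales, b_scales):
    acc = a_q.float() @ w_q.float().t()  # [M, N] accumulated in fp32
    return acc * a_scales.view(-1, 1).float() * b_scales.view(1, -1).float() + bias.float()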
@autotune(
configs=[
Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=3, num_warps=8),
Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=8),
Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=3, num_warps=8),
Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=8),
],
key=["M", "N", "K"],
)
@jit
def int8_gemm_kernel(
A,
B,
A_SCALES,
B_SCALES,
C,
M,
N,
K, #
stride_am,
stride_ak, #
stride_bk,
stride_bn, #
stride_cm,
stride_cn, #
acc_dtype: tl.constexpr, #
fuse_gelu: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr, #
GROUP_M: tl.constexpr,
SPLIT_K: tl.constexpr,
EVEN_K: tl.constexpr,
AB_DTYPE: tl.constexpr, #
):
# matrix multiplication
pid = tl.program_id(0)
pid_z = tl.program_id(1)
grid_m = tl.cdiv(M, BLOCK_M)
grid_n = tl.cdiv(N, BLOCK_N)
# re-order program ID for better L2 performance
width = GROUP_M * grid_n
group_id = pid // width
group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
pid_m = group_id * GROUP_M + (pid % group_size)
pid_n = (pid % width) // (group_size)
# do matrix multiplication
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K)
# pointers
A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=acc_dtype)
for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)):
if EVEN_K:
a = tl.load(A)
b = tl.load(B)
else:
k_remaining = K - k * (BLOCK_K * SPLIT_K)
_0 = tl.zeros((1, 1), dtype=C.dtype.element_ty)
a = tl.load(A, mask=rk[None, :] < k_remaining, other=_0)
b = tl.load(B, mask=rk[:, None] < k_remaining, other=_0)
if AB_DTYPE is not None:
a = a.to(AB_DTYPE)
b = b.to(AB_DTYPE)
acc = tl.dot(a, b, acc, out_dtype=acc_dtype, input_precision=None)
A += BLOCK_K * SPLIT_K * stride_ak
B += BLOCK_K * SPLIT_K * stride_bk
acc = acc.to(tl.float32)
a_scales_ptr = A_SCALES + pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
b_scales_ptr = B_SCALES + pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
a_scales = tl.load(a_scales_ptr) # [BM]
b_scales = tl.load(b_scales_ptr) # [BN]
# [BM, BN] * [BM, 1] * [1, BN]
if fuse_gelu:
acc = gelu((acc * a_scales[:, None]) * b_scales[None, :])
else:
acc = (acc * a_scales[:, None]) * b_scales[None, :]
acc = acc.to(C.dtype.element_ty)
# rematerialize rm and rn to save registers
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
mask = (rm < M)[:, None] & (rn < N)[None, :]
# handles write-back with reduction-splitting
if SPLIT_K == 1:
tl.store(C, acc, mask=mask)
else:
tl.atomic_add(C, acc, mask=mask)
# @torch.compiler.disable()
def int8_gemm_triton(a, b, a_scales, b_scales, fuse_gelu=False, output_dtype=None):
device = a.device
# handle non-contiguous inputs if necessary
# USE ONLY IN linear layer. NOT GENERAL MATRIX MULTIPLY
a_orig_shape = a.shape
a = a.view(-1, a.shape[-1])
b = b.t()
if a.stride(0) > 1 and a.stride(1) > 1:
a = a.contiguous()
if b.stride(0) > 1 and b.stride(1) > 1:
b = b.contiguous()
# checks constraints
assert a.shape[1] == b.shape[0], f"incompatible dimensions {a.shape} and {b.shape}"
M, K = a.shape
_, N = b.shape
out_shape = a_orig_shape[:-1] + (N,)
# common type between a and b
ab_dtype = get_higher_dtype(a.dtype, b.dtype)
# allocates output
if output_dtype is None:
output_dtype = ab_dtype
c = torch.empty((M, N), device=device, dtype=output_dtype)
# Allowed types for acc_type given the types of a and b.
supported_acc_dtypes = {
torch.float16: (torch.float32, torch.float16),
torch.bfloat16: (torch.float32, torch.bfloat16),
torch.float32: (torch.float32,),
torch.int8: (torch.int32,),
}
acc_dtype = supported_acc_dtypes[ab_dtype][0]
def to_tl_type(ty):
return getattr(tl, str(ty).split(".")[-1])
acc_dtype = to_tl_type(acc_dtype)
ab_dtype = to_tl_type(ab_dtype)
output_dtype = to_tl_type(output_dtype)
# Tensor cores support input with mixed float8 types.
if a.dtype in [tl.float8e4nv, tl.float8e5] and b.dtype in [
tl.float8e4nv,
tl.float8e5,
]:
ab_dtype = None
# launch kernel
grid = lambda META: ( # noqa E731
cdiv(M, META["BLOCK_M"]) * cdiv(N, META["BLOCK_N"]),
META["SPLIT_K"],
) # noqa E731
int8_gemm_kernel[grid](
a,
b,
a_scales,
b_scales,
c,
M,
N,
K, #
a.stride(0),
a.stride(1), #
b.stride(0),
b.stride(1), #
c.stride(0),
c.stride(1), #
acc_dtype=acc_dtype, #
fuse_gelu=fuse_gelu,
EVEN_K=True,
GROUP_M=8,
AB_DTYPE=ab_dtype,
)
return c.view(*out_shape)
@autotune(
configs=[
Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=3, num_warps=8),
Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=8),
Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=3, num_warps=8),
Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=8),
],
key=["M", "N", "K"],
)
@jit
def fp8_gemm_bias_kernel(
A,
B,
BIAS,
A_SCALES,
B_SCALES,
C,
M,
N,
K,
stride_am,
stride_ak,
stride_bk,
stride_bn,
stride_cm,
stride_cn,
fuse_gelu: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
GROUP_M: tl.constexpr,
SPLIT_K: tl.constexpr,
EVEN_K: tl.constexpr,
):
pid = tl.program_id(0)
pid_z = tl.program_id(1)
grid_m = tl.cdiv(M, BLOCK_M)
grid_n = tl.cdiv(N, BLOCK_N)
width = GROUP_M * grid_n
group_id = pid // width
group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
pid_m = group_id * GROUP_M + (pid % group_size)
pid_n = (pid % width) // group_size
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K)
A_ptr = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
B_ptr = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)):
if EVEN_K:
a = tl.load(A_ptr)
b = tl.load(B_ptr)
else:
k_remaining = K - k * (BLOCK_K * SPLIT_K)
a = tl.load(A_ptr, mask=rk[None, :] < k_remaining, other=0.0)
b = tl.load(B_ptr, mask=rk[:, None] < k_remaining, other=0.0)
acc = tl.dot(a, b, acc, out_dtype=tl.float32, input_precision=None)
A_ptr += BLOCK_K * SPLIT_K * stride_ak
B_ptr += BLOCK_K * SPLIT_K * stride_bk
a_scales_ptr = A_SCALES + pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
b_scales_ptr = B_SCALES + pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
a_scales = tl.load(a_scales_ptr).to(tl.float32) # [BM]
b_scales = tl.load(b_scales_ptr).to(tl.float32) # [BN]
bias_ptr = BIAS + pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
bias = tl.load(bias_ptr).to(tl.float32) # [BN]
out = (acc * a_scales[:, None]) * b_scales[None, :] + bias[None, :]
if fuse_gelu:
out = gelu(out)
out = out.to(C.dtype.element_ty)
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
C_ptr = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
mask = (rm < M)[:, None] & (rn < N)[None, :]
if SPLIT_K == 1:
tl.store(C_ptr, out, mask=mask)
else:
tl.atomic_add(C_ptr, out, mask=mask)
def fp8_gemm_bias_triton(a, b, bias, a_scales, b_scales, fuse_gelu=False, output_dtype=None):
assert a.is_cuda and b.is_cuda, "This kernel is for CUDA"
assert a.dtype in (getattr(torch, "float8_e4m3fn", None), getattr(torch, "float8_e4m3fnuz", None)), f"a.dtype={a.dtype} is not FP8 E4M3"
assert b.dtype in (getattr(torch, "float8_e4m3fn", None), getattr(torch, "float8_e4m3fnuz", None)), f"b.dtype={b.dtype} is not FP8 E4M3"
a_orig_shape = a.shape
a2 = a.view(-1, a.shape[-1])
b2 = b.t()
if a2.stride(0) > 1 and a2.stride(1) > 1:
a2 = a2.contiguous()
if b2.stride(0) > 1 and b2.stride(1) > 1:
b2 = b2.contiguous()
M, K = a2.shape
_, N = b2.shape
out_shape = a_orig_shape[:-1] + (N,)
if output_dtype is None:
output_dtype = torch.float16
c = torch.empty((M, N), device=a.device, dtype=output_dtype)
grid = lambda META: (cdiv(M, META["BLOCK_M"]) * cdiv(N, META["BLOCK_N"]), META["SPLIT_K"]) # noqa E731
even_k = K % 128 == 0
fp8_gemm_bias_kernel[grid](
a2,
b2,
bias,
a_scales,
b_scales,
c,
M,
N,
K,
a2.stride(0),
a2.stride(1),
b2.stride(0),
b2.stride(1),
c.stride(0),
c.stride(1),
fuse_gelu=fuse_gelu,
GROUP_M=8,
EVEN_K=even_k,
)
return c.view(*out_shape)
@autotune(
configs=[
Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=3, num_warps=8),
Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, num_stages=4, num_warps=8),
Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=3, num_warps=8),
Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1}, num_stages=4, num_warps=8),
],
key=["M", "N", "K"],
)
@jit
def fp8_gemm_kernel(
A,
B,
A_SCALES,
B_SCALES,
C,
M,
N,
K,
stride_am,
stride_ak,
stride_bk,
stride_bn,
stride_cm,
stride_cn,
fuse_gelu: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
GROUP_M: tl.constexpr,
SPLIT_K: tl.constexpr,
EVEN_K: tl.constexpr,
):
pid = tl.program_id(0)
pid_z = tl.program_id(1)
grid_m = tl.cdiv(M, BLOCK_M)
grid_n = tl.cdiv(N, BLOCK_N)
width = GROUP_M * grid_n
group_id = pid // width
group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
pid_m = group_id * GROUP_M + (pid % group_size)
pid_n = (pid % width) // group_size
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K)
A_ptr = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
B_ptr = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)):
if EVEN_K:
a = tl.load(A_ptr)
b = tl.load(B_ptr)
else:
k_remaining = K - k * (BLOCK_K * SPLIT_K)
a = tl.load(A_ptr, mask=rk[None, :] < k_remaining, other=0.0)
b = tl.load(B_ptr, mask=rk[:, None] < k_remaining, other=0.0)
acc = tl.dot(a, b, acc, out_dtype=tl.float32, input_precision=None)
A_ptr += BLOCK_K * SPLIT_K * stride_ak
B_ptr += BLOCK_K * SPLIT_K * stride_bk
a_scales_ptr = A_SCALES + pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
b_scales_ptr = B_SCALES + pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
a_scales = tl.load(a_scales_ptr).to(tl.float32) # [BM]
b_scales = tl.load(b_scales_ptr).to(tl.float32) # [BN]
out = (acc * a_scales[:, None]) * b_scales[None, :]
if fuse_gelu:
out = gelu(out)
out = out.to(C.dtype.element_ty)
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
C_ptr = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
mask = (rm < M)[:, None] & (rn < N)[None, :]
if SPLIT_K == 1:
tl.store(C_ptr, out, mask=mask)
else:
tl.atomic_add(C_ptr, out, mask=mask)
def fp8_gemm_triton(a, b, a_scales, b_scales, fuse_gelu=False, output_dtype=None):
assert a.is_cuda and b.is_cuda
e4m3_ok = []
if hasattr(torch, "float8_e4m3fn"):
e4m3_ok.append(torch.float8_e4m3fn)
if hasattr(torch, "float8_e4m3fnuz"):
e4m3_ok.append(torch.float8_e4m3fnuz)
e4m3_ok = tuple(e4m3_ok)
assert a.dtype in e4m3_ok, f"a.dtype={a.dtype} is not FP8 E4M3"
assert b.dtype in e4m3_ok, f"b.dtype={b.dtype} is not FP8 E4M3"
a_orig_shape = a.shape
a2 = a.view(-1, a.shape[-1])
b2 = b.t()
if a2.stride(0) > 1 and a2.stride(1) > 1:
a2 = a2.contiguous()
if b2.stride(0) > 1 and b2.stride(1) > 1:
b2 = b2.contiguous()
M, K = a2.shape
_, N = b2.shape
out_shape = a_orig_shape[:-1] + (N,)
if output_dtype is None:
output_dtype = torch.float16
c = torch.empty((M, N), device=a.device, dtype=output_dtype)
grid = lambda META: ( # noqa E731
cdiv(M, META["BLOCK_M"]) * cdiv(N, META["BLOCK_N"]),
META["SPLIT_K"],
) # noqa E731
even_k = K % 128 == 0
fp8_gemm_kernel[grid](
a2,
b2,
a_scales,
b_scales,
c,
M,
N,
K,
a2.stride(0),
a2.stride(1),
b2.stride(0),
b2.stride(1),
c.stride(0),
c.stride(1),
fuse_gelu=fuse_gelu,
GROUP_M=8,
EVEN_K=even_k,
)
return c.view(*out_shape)
from .layer_norm_weight import *
from .rms_norm_weight import *
from abc import ABCMeta, abstractmethod
import torch
from loguru import logger
from safetensors import safe_open
from lightx2v.common.ops.utils import *
from lightx2v.utils.envs import *
from lightx2v.utils.registry_factory import LN_WEIGHT_REGISTER
from .triton_ops import norm_infer
class LNWeightTemplate(metaclass=ABCMeta):
def __init__(
self,
weight_name=None,
bias_name=None,
create_cuda_buffer=False,
create_cpu_buffer=False,
lazy_load=False,
lazy_load_file=None,
is_post_adapter=False,
eps=1e-6,
lora_prefix="diffusion_model.blocks",
lora_path="",
):
self.weight_name = weight_name
self.bias_name = bias_name
self.eps = eps
self.create_cuda_buffer = create_cuda_buffer
self.create_cpu_buffer = create_cpu_buffer
self.lazy_load = lazy_load
self.lazy_load_file = lazy_load_file
self.is_post_adapter = is_post_adapter
self.config = {}
self.infer_dtype = GET_DTYPE()
self.sensitive_layer_dtype = GET_SENSITIVE_DTYPE()
self.lora_prefix = lora_prefix
self.lora_path = lora_path
self.has_lora_branch = False
self.has_diff = False
self._get_base_attrs_mapping()
self._get_lora_attr_mapping()
def _get_base_attrs_mapping(self):
self.base_attrs = []
if self.weight_name is not None:
self.base_attrs.append((self.weight_name, "weight", False))
else:
self.weight = None
if self.bias_name is not None:
self.base_attrs.append((self.bias_name, "bias", False))
else:
self.bias = None
def _get_lora_attr_mapping(self):
if self.weight_name is not None:
_, _, _, self.weight_diff_name, self.bias_diff_name = build_lora_and_diff_names(self.weight_name, self.lora_prefix)
self.lora_attrs = {
"weight_diff": "weight_diff_name",
"bias_diff": "bias_diff_name",
}
else:
self.weight_diff_name = None
self.bias_diff_name = None
self.lora_attrs = {}
def _get_actual_weight(self):
if self.weight is None:
return None
if not hasattr(self, "weight_diff"):
return self.weight
return self.weight + self.weight_diff
def _get_actual_bias(self):
if self.bias is None:
return None
if not hasattr(self, "bias_diff"):
return self.bias
return self.bias + self.bias_diff
def load(self, weight_dict):
if not self.create_cuda_buffer and not self.create_cpu_buffer and not self.lazy_load:
device_tensors, pin_tensors = create_default_tensors(self.base_attrs, weight_dict)
self.weight = device_tensors.get("weight")
self.bias = device_tensors.get("bias")
self.pin_weight = pin_tensors.get("weight")
self.pin_bias = pin_tensors.get("bias")
elif self.create_cuda_buffer:
result = create_cuda_buffers(
self.base_attrs,
weight_dict,
self.lazy_load,
self.lazy_load_file,
use_infer_dtype=True,
)
self.weight_cuda_buffer = result.get("weight")
self.bias_cuda_buffer = result.get("bias")
elif self.create_cpu_buffer:
result = create_cpu_buffers(self.base_attrs, self.lazy_load_file, use_infer_dtype=True)
self.pin_weight = result.get("weight")
self.pin_bias = result.get("bias")
self.weight = None
self.bias = None
def register_diff(self, weight_dict):
if not self.lazy_load or self.create_cuda_buffer or self.create_cpu_buffer:
if self.weight_diff_name is not None and self.weight_diff_name in weight_dict:
self.weight_diff = weight_dict[self.weight_diff_name]
self.has_diff = True
logger.debug(f"Register Diff to {self.weight_name}")
if self.bias_diff_name is not None and self.bias_diff_name in weight_dict:
self.bias_diff = weight_dict[self.bias_diff_name]
self.has_diff = True
logger.debug(f"Register Diff to {self.bias_name}")
def set_config(self, config=None):
if config is not None:
self.config = config
def state_dict(self, destination=None):
return state_dict(self, self.base_attrs, self.lora_attrs, destination)
def load_state_dict(self, destination, block_index, adapter_block_index=None):
return load_state_dict(
self,
self.base_attrs,
self.lora_attrs,
destination,
block_index,
adapter_block_index,
)
def load_lora_state_dict_from_disk(self, block_index):
self.weight_diff_name = resolve_block_name(self.weight_diff_name, block_index)
self.bias_diff_name = resolve_block_name(self.bias_diff_name, block_index)
with safe_open(self.lora_path, framework="pt", device="cpu") as lora_load_file:
for lora_attr, lora_attr_name in self.lora_attrs.items():
if getattr(self, lora_attr_name) in lora_load_file.keys():
setattr(
self,
lora_attr,
getattr(self, lora_attr).copy_(
lora_load_file.get_tensor(getattr(self, lora_attr_name)),
non_blocking=True,
),
)
def load_state_dict_from_disk(self, block_index, adapter_block_index=None):
if self.weight_name is not None:
if self.has_lora_branch or self.has_diff:
self.load_lora_state_dict_from_disk(block_index)
self.weight_name = resolve_block_name(self.weight_name, block_index, adapter_block_index, self.is_post_adapter)
lazy_load_file_path = get_lazy_load_file_path(self.lazy_load_file, self.weight_name)
if self.bias_name is not None:
self.bias_name = resolve_block_name(
self.bias_name,
block_index,
adapter_block_index,
self.is_post_adapter,
)
with safe_open(lazy_load_file_path, framework="pt", device="cpu") as lazy_load_file:
weight_tensor = lazy_load_file.get_tensor(self.weight_name).to(self.infer_dtype)
self.pin_weight = self.pin_weight.copy_(weight_tensor)
if self.bias_name is not None:
bias_tensor = lazy_load_file.get_tensor(self.bias_name).to(self.infer_dtype)
self.pin_bias = self.pin_bias.copy_(bias_tensor)
else:
self.pin_bias = None
del weight_tensor
else:
self.weight = None
self.bias = None
def to_cuda(self, non_blocking=False):
move_attr_to_cuda(self, self.base_attrs, self.lora_attrs, non_blocking)
def to_cpu(self, non_blocking=False):
move_attr_to_cpu(self, self.base_attrs, self.lora_attrs, non_blocking)
@abstractmethod
def apply(self, input_tensor):
pass
@LN_WEIGHT_REGISTER("torch")
class LNWeight(LNWeightTemplate):
def __init__(
self,
weight_name=None,
bias_name=None,
create_cuda_buffer=False,
create_cpu_buffer=False,
lazy_load=False,
lazy_load_file=None,
is_post_adapter=False,
eps=1e-6,
lora_prefix="diffusion_model.blocks",
lora_path="",
):
super().__init__(
weight_name,
bias_name,
create_cuda_buffer,
create_cpu_buffer,
lazy_load,
lazy_load_file,
is_post_adapter,
eps,
lora_prefix,
lora_path,
)
def apply(self, input_tensor):
if self.sensitive_layer_dtype != self.infer_dtype:
output_tensor = torch.nn.functional.layer_norm(
input_tensor.float(),
(input_tensor.shape[-1],),
(self._get_actual_weight()),
(self._get_actual_bias()),
self.eps,
).to(self.infer_dtype)
else:
output_tensor = torch.nn.functional.layer_norm(
input_tensor,
(input_tensor.shape[-1],),
(self._get_actual_weight()),
(self._get_actual_bias()),
self.eps,
)
return output_tensor
@LN_WEIGHT_REGISTER("Triton")
class LNWeight(LNWeightTemplate):
def __init__(
self,
weight_name=None,
bias_name=None,
create_cuda_buffer=False,
create_cpu_buffer=False,
lazy_load=False,
lazy_load_file=None,
is_post_adapter=False,
eps=1e-6,
lora_prefix="diffusion_model.blocks",
lora_path="",
):
super().__init__(
weight_name,
bias_name,
create_cuda_buffer,
create_cpu_buffer,
lazy_load,
lazy_load_file,
is_post_adapter,
eps,
lora_prefix,
lora_path,
)
def apply(self, input_tensor):
output_tensor = norm_infer(
input_tensor,
(self._get_actual_weight()),
self._get_actual_bias(),
self.eps,
)
return output_tensor
from abc import ABCMeta, abstractmethod
import torch
import torch.distributed as dist
from loguru import logger
from safetensors import safe_open
from lightx2v.common.ops.norm.triton_ops import rms_norm_kernel
from lightx2v.common.ops.utils import *
from lightx2v.utils.envs import *
from lightx2v.utils.registry_factory import RMS_WEIGHT_REGISTER
from lightx2v_platform.base.global_var import AI_DEVICE
try:
import sgl_kernel
except ImportError:
sgl_kernel = None
class RMSWeightTemplate(metaclass=ABCMeta):
def __init__(
self,
weight_name,
create_cuda_buffer=False,
create_cpu_buffer=False,
lazy_load=False,
lazy_load_file=None,
is_post_adapter=False,
eps=1e-6,
lora_prefix="diffusion_model.blocks",
lora_path="",
):
self.weight_name = weight_name
self.eps = eps
self.create_cuda_buffer = create_cuda_buffer
self.create_cpu_buffer = create_cpu_buffer
self.lazy_load = lazy_load
self.lazy_load_file = lazy_load_file
self.is_post_adapter = is_post_adapter
self.infer_dtype = GET_DTYPE()
self.sensitive_layer_dtype = GET_SENSITIVE_DTYPE()
self.config = {}
self.lora_prefix = lora_prefix
self.lora_path = lora_path
self.has_lora_branch = False
self.has_diff = False
self._get_base_attrs_mapping()
self._get_lora_attr_mapping()
def _get_base_attrs_mapping(self):
self.base_attrs = []
self.base_attrs.append((self.weight_name, "weight", False))
def _get_lora_attr_mapping(self):
_, _, _, self.weight_diff_name, _ = build_lora_and_diff_names(self.weight_name, self.lora_prefix)
self.lora_attrs = {
"weight_diff": "weight_diff_name",
}
self.weight_diff = torch.tensor(0.0, dtype=GET_DTYPE(), device=AI_DEVICE)
def _get_actual_weight(self):
if not hasattr(self, "weight_diff"):
return self.weight
return self.weight + self.weight_diff
def register_diff(self, weight_dict):
if not self.lazy_load or self.create_cuda_buffer or self.create_cpu_buffer:
if self.weight_diff_name in weight_dict:
self.weight_diff = weight_dict[self.weight_diff_name]
logger.debug(f"Register Diff to {self.weight_name}")
def load(self, weight_dict):
if not self.create_cuda_buffer and not self.create_cpu_buffer and not self.lazy_load:
device_tensors, pin_tensors = create_default_tensors(self.base_attrs, weight_dict)
self.weight = device_tensors.get("weight")
self.pin_weight = pin_tensors.get("weight")
elif self.create_cuda_buffer:
result = create_cuda_buffers(
self.base_attrs,
weight_dict,
self.lazy_load,
self.lazy_load_file,
use_infer_dtype=True,
)
self.weight_cuda_buffer = result.get("weight")
elif self.create_cpu_buffer:
result = create_cpu_buffers(self.base_attrs, self.lazy_load_file, use_infer_dtype=True)
self.pin_weight = result.get("weight")
self.weight = None
def set_config(self, config=None):
if config is not None:
self.config = config
def to_cuda(self, non_blocking=False):
move_attr_to_cuda(self, self.base_attrs, self.lora_attrs, non_blocking)
def to_cpu(self, non_blocking=False):
move_attr_to_cpu(self, self.base_attrs, self.lora_attrs, non_blocking)
def state_dict(self, destination=None):
return state_dict(self, self.base_attrs, self.lora_attrs, destination)
def load_state_dict(self, destination, block_index, adapter_block_index=None):
return load_state_dict(
self,
self.base_attrs,
self.lora_attrs,
destination,
block_index,
adapter_block_index,
)
def load_lora_state_dict_from_disk(self, block_index):
self.weight_diff_name = resolve_block_name(self.weight_diff_name, block_index)
with safe_open(self.lora_path, framework="pt", device="cpu") as lora_load_file:
for lora_attr, lora_attr_name in self.lora_attrs.items():
if getattr(self, lora_attr_name) in lora_load_file.keys():
setattr(
self,
lora_attr,
getattr(self, lora_attr).copy_(
lora_load_file.get_tensor(getattr(self, lora_attr_name)),
non_blocking=True,
),
)
def load_state_dict_from_disk(self, block_index, adapter_block_index=None):
if self.has_lora_branch or self.has_diff:
self.load_lora_state_dict_from_disk(block_index)
self.weight_name = resolve_block_name(self.weight_name, block_index, adapter_block_index, self.is_post_adapter)
lazy_load_file_path = get_lazy_load_file_path(self.lazy_load_file, self.weight_name)
with safe_open(lazy_load_file_path, framework="pt", device="cpu") as lazy_load_file:
weight_tensor = lazy_load_file.get_tensor(self.weight_name).to(self.infer_dtype)
self.pin_weight = self.pin_weight.copy_(weight_tensor)
del weight_tensor
@abstractmethod
def apply(self, input_tensor):
pass
@RMS_WEIGHT_REGISTER("torch")
class RMSWeight(RMSWeightTemplate):
def __init__(
self,
weight_name,
create_cuda_buffer=False,
create_cpu_buffer=False,
lazy_load=False,
lazy_load_file=None,
is_post_adapter=False,
eps=1e-6,
lora_prefix="diffusion_model.blocks",
lora_path="",
):
super().__init__(
weight_name,
create_cuda_buffer,
create_cpu_buffer,
lazy_load,
lazy_load_file,
is_post_adapter,
eps,
lora_prefix,
lora_path,
)
def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
def apply(self, input_tensor):
        if GET_SENSITIVE_DTYPE() != GET_DTYPE():
            input_tensor = self._norm(input_tensor.float()).type_as(input_tensor) * (self._get_actual_weight())
        else:
            input_tensor = self._norm(input_tensor).type_as(input_tensor) * (self._get_actual_weight())
return input_tensor
@RMS_WEIGHT_REGISTER("TensorParallel")
class RMSWeightTP(RMSWeightTemplate):
"""
RMSNorm weight module with tensor parallelism support.
The weight is split along the hidden dimension to match the split QKV outputs.
"""
def __init__(
self,
weight_name,
tp_group=None,
tp_rank=0,
tp_size=1,
create_cuda_buffer=False,
create_cpu_buffer=False,
lazy_load=False,
lazy_load_file=None,
is_post_adapter=False,
eps=1e-6,
lora_prefix="diffusion_model.blocks",
lora_path="",
):
super().__init__(
weight_name,
create_cuda_buffer,
create_cpu_buffer,
lazy_load,
lazy_load_file,
is_post_adapter,
eps,
lora_prefix,
lora_path,
)
self.tp_group = tp_group
self.tp_rank = tp_rank
self.tp_size = tp_size
def apply(self, input_tensor):
local_sum = input_tensor.pow(2).sum(-1, keepdim=True)
# All-reduce to get global sum
if self.tp_size > 1 and self.tp_group is not None:
dist.all_reduce(local_sum, op=dist.ReduceOp.SUM, group=self.tp_group)
# Compute global mean: global_sum / hidden_dim
hidden_dim = input_tensor.shape[-1] * self.tp_size
global_mean = local_sum / hidden_dim
# Apply normalization with global mean
if self.sensitive_layer_dtype != self.infer_dtype:
input_tensor = input_tensor * torch.rsqrt(global_mean.float() + self.eps).to(self.infer_dtype)
input_tensor = (input_tensor * self._get_actual_weight()).to(self.infer_dtype)
else:
input_tensor = input_tensor * torch.rsqrt(global_mean + self.eps)
input_tensor = input_tensor * self._get_actual_weight()
return input_tensor
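# Single-process sketch of the reduction above: when the hidden dimension is sharded
# across ranks, the global mean of x^2 equals the sum of per-shard sums of squares
# divided by the full hidden size. Illustrative only; the real path all-reduces
# local_sum over tp_group.
def _tp_rms_variance_reference(x_full, tp_size, eps=1e-6):
    shards = x_full.chunk(tp_size, dim=-1)  # what each rank would hold locally
    local_sums = [s.pow(2).sum(-1, keepdim=True) for s in shards]
    global_mean = sum(local_sums) / x_full.shape[-1]  # equals the all-reduced mean
    return [s * torch.rsqrt(global_mean + eps) for s in shards]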
@RMS_WEIGHT_REGISTER("sgl-kernel")
class RMSWeightSgl(RMSWeight):
def __init__(
self,
weight_name,
create_cuda_buffer=False,
create_cpu_buffer=False,
lazy_load=False,
lazy_load_file=None,
is_post_adapter=False,
eps=1e-6,
lora_prefix="diffusion_model.blocks",
lora_path="",
):
super().__init__(
weight_name,
create_cuda_buffer,
create_cpu_buffer,
lazy_load,
lazy_load_file,
is_post_adapter,
eps,
lora_prefix,
lora_path,
)
def apply(self, input_tensor):
if sgl_kernel is not None and self.sensitive_layer_dtype == self.infer_dtype:
input_tensor = input_tensor.contiguous()
orig_shape = input_tensor.shape
input_tensor = input_tensor.view(-1, orig_shape[-1])
input_tensor = sgl_kernel.rmsnorm(input_tensor, (self._get_actual_weight()), self.eps).view(orig_shape)
else:
# sgl_kernel is not available or dtype!=torch.bfloat16/float16, fallback to default implementation
if self.sensitive_layer_dtype != self.infer_dtype:
input_tensor = input_tensor * torch.rsqrt(input_tensor.float().pow(2).mean(-1, keepdim=True) + self.eps).to(self.infer_dtype)
input_tensor = (input_tensor * (self._get_actual_weight())).to(self.infer_dtype)
else:
input_tensor = input_tensor * torch.rsqrt(input_tensor.pow(2).mean(-1, keepdim=True) + self.eps)
input_tensor = input_tensor * (self._get_actual_weight())
return input_tensor
@RMS_WEIGHT_REGISTER("fp32_variance")
class RMSWeightFP32(RMSWeight):
def __init__(
self,
weight_name,
create_cuda_buffer=False,
create_cpu_buffer=False,
lazy_load=False,
lazy_load_file=None,
is_post_adapter=False,
eps=1e-6,
lora_prefix="diffusion_model.blocks",
lora_path="",
):
super().__init__(
weight_name,
create_cuda_buffer,
create_cpu_buffer,
lazy_load,
lazy_load_file,
is_post_adapter,
eps,
lora_prefix,
lora_path,
)
def apply(self, input_tensor, moe_gen=False):
input_dtype = input_tensor.dtype
if moe_gen:
hidden_states = input_tensor.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
return self.weight * hidden_states.to(input_dtype)
else:
variance = input_tensor.to(torch.float32).pow(2).mean(-1, keepdim=True)
hidden_states = input_tensor * torch.rsqrt(variance + self.eps)
            if self.weight is not None:
                if self.weight.dtype in [torch.float16, torch.bfloat16]:
                    hidden_states = hidden_states.to(self.weight.dtype)
                hidden_states = hidden_states * self.weight
hidden_states = hidden_states.to(input_dtype)
return hidden_states
@RMS_WEIGHT_REGISTER("self_forcing")
class RMSWeightSF(RMSWeight):
def __init__(
self,
weight_name,
create_cuda_buffer=False,
create_cpu_buffer=False,
lazy_load=False,
lazy_load_file=None,
is_post_adapter=False,
eps=1e-6,
lora_prefix="diffusion_model.blocks",
lora_path="",
):
super().__init__(
weight_name,
create_cuda_buffer,
create_cpu_buffer,
lazy_load,
lazy_load_file,
is_post_adapter,
eps,
lora_prefix,
lora_path,
)
def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
def apply(self, x):
return self._norm(x.float()).type_as(x) * (self._get_actual_weight())
@RMS_WEIGHT_REGISTER("one-pass")
class RMSWeightOnePass(RMSWeight):
def __init__(
self,
weight_name,
create_cuda_buffer=False,
create_cpu_buffer=False,
lazy_load=False,
lazy_load_file=None,
is_post_adapter=False,
eps=1e-6,
lora_prefix="diffusion_model.blocks",
lora_path="",
):
super().__init__(
weight_name,
create_cuda_buffer,
create_cpu_buffer,
lazy_load,
lazy_load_file,
is_post_adapter,
eps,
lora_prefix,
lora_path,
)
def apply(self, input_tensor):
return rms_norm_kernel(input_tensor, (self._get_actual_weight()), self.eps)
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo & https://github.com/sgl-project/sglang
# TODO: for temporary usage, expecting a refactor
from typing import Optional
import torch
import triton # type: ignore
import triton.language as tl # type: ignore
from torch import Tensor
@triton.autotune(
configs=[
triton.Config({"BLOCK_N": 64}, num_warps=2),
triton.Config({"BLOCK_N": 128}, num_warps=4),
triton.Config({"BLOCK_N": 256}, num_warps=4),
triton.Config({"BLOCK_N": 512}, num_warps=4),
triton.Config({"BLOCK_N": 1024}, num_warps=8),
],
key=["inner_dim"],
)
@triton.jit
def _fused_scale_shift_4d_kernel(
output_ptr,
normalized_ptr,
scale_ptr,
shift_ptr,
rows,
inner_dim,
seq_len,
num_frames,
frame_seqlen,
BLOCK_N: tl.constexpr,
):
pid_row = tl.program_id(0)
pid_col = tl.program_id(1)
col_offsets = pid_col * BLOCK_N + tl.arange(0, BLOCK_N)
mask = col_offsets < inner_dim
# Pointers for normalized and output
row_base = pid_row * inner_dim
norm_ptrs = normalized_ptr + row_base + col_offsets
out_ptrs = output_ptr + row_base + col_offsets
# Pointers for scale and shift for 4D
b_idx = pid_row // seq_len
t_idx = pid_row % seq_len
frame_idx_in_batch = t_idx // frame_seqlen
scale_row_idx = b_idx * num_frames + frame_idx_in_batch
scale_ptrs = scale_ptr + scale_row_idx * inner_dim + col_offsets
shift_ptrs = shift_ptr + scale_row_idx * inner_dim + col_offsets
normalized = tl.load(norm_ptrs, mask=mask, other=0.0)
scale = tl.load(scale_ptrs, mask=mask, other=0.0)
shift = tl.load(shift_ptrs, mask=mask, other=0.0)
one = tl.full([BLOCK_N], 1.0, dtype=scale.dtype)
output = normalized * (one + scale) + shift
tl.store(out_ptrs, output, mask=mask)
@triton.jit
def fuse_scale_shift_kernel_blc_opt(
x_ptr,
shift_ptr,
scale_ptr,
y_ptr,
B,
L,
C,
stride_x_b,
stride_x_l,
stride_x_c,
stride_s_b,
stride_s_l,
stride_s_c,
stride_sc_b,
stride_sc_l,
stride_sc_c,
SCALE_IS_SCALAR: tl.constexpr,
SHIFT_IS_SCALAR: tl.constexpr,
BLOCK_L: tl.constexpr,
BLOCK_C: tl.constexpr,
):
pid_l = tl.program_id(0)
pid_c = tl.program_id(1)
pid_b = tl.program_id(2)
l_offsets = pid_l * BLOCK_L + tl.arange(0, BLOCK_L)
c_offsets = pid_c * BLOCK_C + tl.arange(0, BLOCK_C)
mask_l = l_offsets < L
mask_c = c_offsets < C
mask = mask_l[:, None] & mask_c[None, :]
x_off = pid_b * stride_x_b + l_offsets[:, None] * stride_x_l + c_offsets[None, :] * stride_x_c
x = tl.load(x_ptr + x_off, mask=mask, other=0)
if SHIFT_IS_SCALAR:
shift_val = tl.load(shift_ptr)
shift = tl.full((BLOCK_L, BLOCK_C), shift_val, dtype=shift_val.dtype)
else:
s_off = pid_b * stride_s_b + l_offsets[:, None] * stride_s_l + c_offsets[None, :] * stride_s_c
shift = tl.load(shift_ptr + s_off, mask=mask, other=0)
if SCALE_IS_SCALAR:
scale_val = tl.load(scale_ptr)
scale = tl.full((BLOCK_L, BLOCK_C), scale_val, dtype=scale_val.dtype)
else:
sc_off = pid_b * stride_sc_b + l_offsets[:, None] * stride_sc_l + c_offsets[None, :] * stride_sc_c
scale = tl.load(scale_ptr + sc_off, mask=mask, other=0)
y = x * (1 + scale) + shift
tl.store(y_ptr + x_off, y, mask=mask)
def fuse_scale_shift_kernel(
x: torch.Tensor,
scale: torch.Tensor,
shift: torch.Tensor,
block_l: int = 128,
block_c: int = 128,
):
# assert x.is_cuda and scale.is_cuda
assert x.is_contiguous()
if x.dim() == 2:
x = x.unsqueeze(0)
B, L, C = x.shape
output = torch.empty_like(x)
if scale.dim() == 4:
# scale/shift: [B, F, 1, C]
rows = B * L
x_2d = x.view(rows, C)
output_2d = output.view(rows, C)
grid = lambda META: (rows, triton.cdiv(C, META["BLOCK_N"])) # noqa
num_frames = scale.shape[1]
assert L % num_frames == 0, "seq_len must be divisible by num_frames for 4D scale/shift"
frame_seqlen = L // num_frames
# Compact [B, F, C] without the singleton dim into [B*F, C]
scale_reshaped = scale.squeeze(2).reshape(-1, C).contiguous()
shift_reshaped = shift.squeeze(2).reshape(-1, C).contiguous()
_fused_scale_shift_4d_kernel[grid](
output_2d,
x_2d,
scale_reshaped,
shift_reshaped,
rows,
C,
L,
num_frames,
frame_seqlen,
)
else:
# 2D: [B, C] or [1, C] -> treat as [B, 1, C] and broadcast over L
# 3D: [B, L, C] (or broadcastable variants like [B, 1, C], [1, L, C], [1, 1, C])
# Also support scalar (0D or 1-element)
if scale.dim() == 0 or (scale.dim() == 1 and scale.numel() == 1):
scale_blc = scale.reshape(1)
elif scale.dim() == 2:
scale_blc = scale[:, None, :]
elif scale.dim() == 3:
scale_blc = scale
else:
raise ValueError("scale must be 0D/1D(1)/2D/3D or 4D")
if shift.dim() == 0 or (shift.dim() == 1 and shift.numel() == 1):
shift_blc = shift.reshape(1)
elif shift.dim() == 2:
shift_blc = shift[:, None, :]
elif shift.dim() == 3:
shift_blc = shift
else:
# broadcast later via expand if possible
shift_blc = shift
need_scale_scalar = scale_blc.dim() == 1 and scale_blc.numel() == 1
need_shift_scalar = shift_blc.dim() == 1 and shift_blc.numel() == 1
if not need_scale_scalar:
scale_exp = scale_blc.expand(B, L, C)
s_sb, s_sl, s_sc = scale_exp.stride()
else:
s_sb = s_sl = s_sc = 0
if not need_shift_scalar:
shift_exp = shift_blc.expand(B, L, C)
sh_sb, sh_sl, sh_sc = shift_exp.stride()
else:
sh_sb = sh_sl = sh_sc = 0
        # If scale and shift are both scalars and both zero, take the copy fast path
if need_scale_scalar and need_shift_scalar:
if (scale_blc.abs().max() == 0) and (shift_blc.abs().max() == 0):
output.copy_(x)
return output
grid = (triton.cdiv(L, block_l), triton.cdiv(C, block_c), B)
fuse_scale_shift_kernel_blc_opt[grid](
x,
shift_blc if need_shift_scalar else shift_exp,
scale_blc if need_scale_scalar else scale_exp,
output,
B,
L,
C,
x.stride(0),
x.stride(1),
x.stride(2),
sh_sb,
sh_sl,
sh_sc,
s_sb,
s_sl,
s_sc,
SCALE_IS_SCALAR=need_scale_scalar,
SHIFT_IS_SCALAR=need_shift_scalar,
BLOCK_L=block_l,
BLOCK_C=block_c,
num_warps=4,
num_stages=2,
)
return output
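# --- Hedged usage sketch (added for illustration; not part of the original module) ---
# Exercises both broadcast paths of fuse_scale_shift_kernel above: a 3D per-batch
# scale/shift and a 4D per-frame scale/shift. Shapes below are illustrative
# assumptions; a CUDA device is required for the Triton launch.
def _example_fuse_scale_shift():
    if not torch.cuda.is_available():
        return
    B, num_frames, frame_len, C = 2, 4, 16, 64
    L = num_frames * frame_len
    x = torch.randn(B, L, C, device="cuda", dtype=torch.float32)
    # 3D path: per-batch modulation broadcast over the sequence dimension
    scale3d = torch.randn(B, 1, C, device="cuda", dtype=torch.float32)
    shift3d = torch.randn(B, 1, C, device="cuda", dtype=torch.float32)
    y3d = fuse_scale_shift_kernel(x, scale3d, shift3d)
    assert torch.allclose(y3d, x * (1 + scale3d) + shift3d, atol=1e-5)
    # 4D path: per-frame modulation with scale/shift shaped [B, F, 1, C]
    scale4d = torch.randn(B, num_frames, 1, C, device="cuda", dtype=torch.float32)
    shift4d = torch.randn(B, num_frames, 1, C, device="cuda", dtype=torch.float32)
    y4d = fuse_scale_shift_kernel(x, scale4d, shift4d)
    ref4d = (x.view(B, num_frames, frame_len, C) * (1 + scale4d) + shift4d).view(B, L, C)
    assert torch.allclose(y4d, ref4d, atol=1e-5)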
@triton.autotune(
configs=[
triton.Config({"BLOCK_HS_HALF": 32}, num_warps=2),
triton.Config({"BLOCK_HS_HALF": 64}, num_warps=4),
triton.Config({"BLOCK_HS_HALF": 128}, num_warps=4),
triton.Config({"BLOCK_HS_HALF": 256}, num_warps=8),
],
key=["head_size", "interleaved"],
)
@triton.jit
def _rotary_embedding_kernel(
output_ptr,
x_ptr,
cos_ptr,
sin_ptr,
num_heads,
head_size,
num_tokens,
stride_x_row,
stride_cos_row,
stride_sin_row,
interleaved: tl.constexpr,
BLOCK_HS_HALF: tl.constexpr,
):
row_idx = tl.program_id(0)
token_idx = (row_idx // num_heads) % num_tokens
x_row_ptr = x_ptr + row_idx * stride_x_row
cos_row_ptr = cos_ptr + token_idx * stride_cos_row
sin_row_ptr = sin_ptr + token_idx * stride_sin_row
output_row_ptr = output_ptr + row_idx * stride_x_row
# half size for x1 and x2
head_size_half = head_size // 2
for block_start in range(0, head_size_half, BLOCK_HS_HALF):
offsets_half = block_start + tl.arange(0, BLOCK_HS_HALF)
mask = offsets_half < head_size_half
cos_vals = tl.load(cos_row_ptr + offsets_half, mask=mask, other=0.0)
sin_vals = tl.load(sin_row_ptr + offsets_half, mask=mask, other=0.0)
offsets_x1 = 2 * offsets_half
offsets_x2 = 2 * offsets_half + 1
x1_vals = tl.load(x_row_ptr + offsets_x1, mask=mask, other=0.0)
x2_vals = tl.load(x_row_ptr + offsets_x2, mask=mask, other=0.0)
x1_fp32 = x1_vals.to(tl.float32)
x2_fp32 = x2_vals.to(tl.float32)
cos_fp32 = cos_vals.to(tl.float32)
sin_fp32 = sin_vals.to(tl.float32)
o1_vals = tl.fma(-x2_fp32, sin_fp32, x1_fp32 * cos_fp32)
o2_vals = tl.fma(x1_fp32, sin_fp32, x2_fp32 * cos_fp32)
tl.store(output_row_ptr + offsets_x1, o1_vals.to(x1_vals.dtype), mask=mask)
tl.store(output_row_ptr + offsets_x2, o2_vals.to(x2_vals.dtype), mask=mask)
def apply_rotary_embedding(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, interleaved: bool = False) -> torch.Tensor:
output = torch.empty_like(x)
if x.dim() > 3:
bsz, num_tokens, num_heads, head_size = x.shape
else:
num_tokens, num_heads, head_size = x.shape
bsz = 1
assert head_size % 2 == 0, "head_size must be divisible by 2"
x_reshaped = x.view(-1, head_size)
output_reshaped = output.view(-1, head_size)
    # one program per (batch, token, head) row
grid = (bsz * num_tokens * num_heads,)
if interleaved and cos.shape[-1] == head_size:
cos = cos[..., ::2].contiguous()
sin = sin[..., ::2].contiguous()
else:
cos = cos.contiguous()
sin = sin.contiguous()
_rotary_embedding_kernel[grid](
output_reshaped,
x_reshaped,
cos,
sin,
num_heads,
head_size,
num_tokens,
x_reshaped.stride(0),
cos.stride(0),
sin.stride(0),
interleaved,
)
return output
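# --- Hedged usage sketch (added for illustration; not part of the original module) ---
# Applies the Triton rotary embedding to a [num_tokens, num_heads, head_size]
# tensor and checks it against a naive reference over interleaved (even, odd)
# channel pairs. Shapes are illustrative assumptions; a CUDA device is required.
def _example_apply_rotary_embedding():
    if not torch.cuda.is_available():
        return
    num_tokens, num_heads, head_size = 8, 4, 64
    x = torch.randn(num_tokens, num_heads, head_size, device="cuda", dtype=torch.float32)
    freqs = torch.randn(num_tokens, head_size // 2, device="cuda", dtype=torch.float32)
    cos, sin = freqs.cos(), freqs.sin()
    out = apply_rotary_embedding(x, cos, sin)
    # Naive reference: rotate each (even, odd) channel pair by the per-position angle
    x1, x2 = x[..., 0::2], x[..., 1::2]
    ref = torch.empty_like(x)
    ref[..., 0::2] = x1 * cos[:, None, :] - x2 * sin[:, None, :]
    ref[..., 1::2] = x1 * sin[:, None, :] + x2 * cos[:, None, :]
    assert torch.allclose(out, ref, atol=1e-5)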
# RMSNorm-fp32
def maybe_contiguous_lastdim(x):
return x.contiguous() if x is not None and x.stride(-1) != 1 else x
def maybe_contiguous(x):
return x.contiguous() if x is not None else None
def triton_autotune_configs():
if not torch.cuda.is_available():
return []
# Return configs with a valid warp count for the current device
configs = []
# Maximum threads per block is architecture-dependent in theory, but in reality all are 1024
max_threads_per_block = 1024
# Default to warp size 32 if not defined by device
warp_size = getattr(torch.cuda.get_device_properties(torch.cuda.current_device()), "warp_size", 32)
    # Autotune over warp counts that are powers of 2 and do not exceed the threads-per-block limit
return [triton.Config({}, num_warps=warp_count) for warp_count in [1, 2, 4, 8, 16, 32] if warp_count * warp_size <= max_threads_per_block]
# return [triton.Config({}, num_warps=8)]
# Copied from flash-attn
@triton.autotune(
configs=triton_autotune_configs(),
key=[
"N",
"HAS_RESIDUAL",
"STORE_RESIDUAL_OUT",
"IS_RMS_NORM",
"HAS_BIAS",
"HAS_WEIGHT",
"HAS_X1",
"HAS_W1",
"HAS_B1",
],
)
# torch compile doesn't like triton.heuristics, so we set these manually when calling the kernel
# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
# @triton.heuristics({"HAS_RESIDUAL": lambda args: args["RESIDUAL"] is not None})
# @triton.heuristics({"HAS_X1": lambda args: args["X1"] is not None})
# @triton.heuristics({"HAS_W1": lambda args: args["W1"] is not None})
# @triton.heuristics({"HAS_B1": lambda args: args["B1"] is not None})
@triton.jit
def _layer_norm_fwd_1pass_kernel(
X, # pointer to the input
Y, # pointer to the output
W, # pointer to the weights
B, # pointer to the biases
RESIDUAL, # pointer to the residual
X1,
W1,
B1,
Y1,
RESIDUAL_OUT, # pointer to the residual
ROWSCALE,
SEEDS, # Dropout seeds for each row
DROPOUT_MASK,
DROPOUT_MASK1,
Mean, # pointer to the mean
Rstd, # pointer to the 1/std
stride_x_row, # how much to increase the pointer when moving by 1 row
stride_y_row,
stride_res_row,
stride_res_out_row,
stride_x1_row,
stride_y1_row,
M, # number of rows in X
N, # number of columns in X
eps, # epsilon to avoid division by zero
dropout_p, # Dropout probability
zero_centered_weight, # If true, add 1.0 to the weight
IS_RMS_NORM: tl.constexpr,
BLOCK_N: tl.constexpr,
HAS_RESIDUAL: tl.constexpr,
STORE_RESIDUAL_OUT: tl.constexpr,
HAS_WEIGHT: tl.constexpr,
HAS_BIAS: tl.constexpr,
HAS_DROPOUT: tl.constexpr,
STORE_DROPOUT_MASK: tl.constexpr,
HAS_ROWSCALE: tl.constexpr,
HAS_X1: tl.constexpr,
HAS_W1: tl.constexpr,
HAS_B1: tl.constexpr,
):
# Map the program id to the row of X and Y it should compute.
row = tl.program_id(0)
X += row * stride_x_row
Y += row * stride_y_row
if HAS_RESIDUAL:
RESIDUAL += row * stride_res_row
if STORE_RESIDUAL_OUT:
RESIDUAL_OUT += row * stride_res_out_row
if HAS_X1:
X1 += row * stride_x1_row
if HAS_W1:
Y1 += row * stride_y1_row
# Compute mean and variance
cols = tl.arange(0, BLOCK_N)
x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
if HAS_ROWSCALE:
rowscale = tl.load(ROWSCALE + row).to(tl.float32)
x *= rowscale
if HAS_DROPOUT:
# Compute dropout mask
# 7 rounds is good enough, and reduces register pressure
keep_mask = tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
x = tl.where(keep_mask, x / (1.0 - dropout_p), 0.0)
if STORE_DROPOUT_MASK:
tl.store(DROPOUT_MASK + row * N + cols, keep_mask, mask=cols < N)
if HAS_X1:
x1 = tl.load(X1 + cols, mask=cols < N, other=0.0).to(tl.float32)
if HAS_ROWSCALE:
rowscale = tl.load(ROWSCALE + M + row).to(tl.float32)
x1 *= rowscale
if HAS_DROPOUT:
# Compute dropout mask
# 7 rounds is good enough, and reduces register pressure
keep_mask = tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
x1 = tl.where(keep_mask, x1 / (1.0 - dropout_p), 0.0)
if STORE_DROPOUT_MASK:
tl.store(DROPOUT_MASK1 + row * N + cols, keep_mask, mask=cols < N)
x += x1
if HAS_RESIDUAL:
residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32)
x += residual
if STORE_RESIDUAL_OUT:
tl.store(RESIDUAL_OUT + cols, x, mask=cols < N)
if not IS_RMS_NORM:
mean = tl.sum(x, axis=0) / N
tl.store(Mean + row, mean)
xbar = tl.where(cols < N, x - mean, 0.0)
var = tl.sum(xbar * xbar, axis=0) / N
else:
xbar = tl.where(cols < N, x, 0.0)
var = tl.sum(xbar * xbar, axis=0) / N
rstd = 1 / tl.sqrt(var + eps)
tl.store(Rstd + row, rstd)
# Normalize and apply linear transformation
mask = cols < N
if HAS_WEIGHT:
w = tl.load(W + cols, mask=mask).to(tl.float32)
if zero_centered_weight:
w += 1.0
if HAS_BIAS:
b = tl.load(B + cols, mask=mask).to(tl.float32)
x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
if HAS_WEIGHT:
y = x_hat * w + b if HAS_BIAS else x_hat * w
else:
y = x_hat + b if HAS_BIAS else x_hat
# Write output
tl.store(Y + cols, y, mask=mask)
if HAS_W1:
w1 = tl.load(W1 + cols, mask=mask).to(tl.float32)
if zero_centered_weight:
w1 += 1.0
if HAS_B1:
b1 = tl.load(B1 + cols, mask=mask).to(tl.float32)
y1 = x_hat * w1 + b1 if HAS_B1 else x_hat * w1
tl.store(Y1 + cols, y1, mask=mask)
def _layer_norm_fwd(
x: Tensor,
weight: Tensor,
bias: Tensor,
eps: float,
residual: Optional[Tensor] = None,
x1: Optional[Tensor] = None,
weight1: Optional[Tensor] = None,
bias1: Optional[Tensor] = None,
dropout_p: float = 0.0,
rowscale: Optional[Tensor] = None,
out_dtype: Optional[torch.dtype] = None,
residual_dtype: Optional[torch.dtype] = None,
zero_centered_weight: bool = False,
is_rms_norm: bool = False,
return_dropout_mask: bool = False,
out: Optional[Tensor] = None,
residual_out: Optional[Tensor] = None,
) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor):
    # Need to wrap to handle the case where residual_out is an alias of x, which makes torch.library
# and torch.compile unhappy. Also allocate memory for out and residual_out if they are None
# so that _layer_norm_fwd_impl doesn't have to return them.
if out is None:
out = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)
if residual is not None:
residual_dtype = residual.dtype
if residual_out is None and (residual is not None or (residual_dtype is not None and residual_dtype != x.dtype) or dropout_p > 0.0 or rowscale is not None or x1 is not None):
residual_out = torch.empty_like(x, dtype=residual_dtype if residual_dtype is not None else x.dtype)
else:
residual_out = None
y1, mean, rstd, seeds, dropout_mask, dropout_mask1 = _layer_norm_fwd_impl(
x,
weight,
bias,
eps,
out,
residual=residual,
x1=x1,
weight1=weight1,
bias1=bias1,
dropout_p=dropout_p,
rowscale=rowscale,
zero_centered_weight=zero_centered_weight,
is_rms_norm=is_rms_norm,
return_dropout_mask=return_dropout_mask,
residual_out=residual_out,
)
# residual_out is None if residual is None and residual_dtype == input_dtype and dropout_p == 0.0
if residual_out is None:
residual_out = x
return out, y1, mean, rstd, residual_out, seeds, dropout_mask, dropout_mask1
# [2025-04-28] torch.library.triton_op ignores the schema argument, but here we need the schema
# since we're returning a tuple of tensors
def _layer_norm_fwd_impl(
x: Tensor,
weight: Optional[Tensor],
bias: Tensor,
eps: float,
out: Tensor,
residual: Optional[Tensor] = None,
x1: Optional[Tensor] = None,
weight1: Optional[Tensor] = None,
bias1: Optional[Tensor] = None,
dropout_p: float = 0.0,
rowscale: Optional[Tensor] = None,
zero_centered_weight: bool = False,
is_rms_norm: bool = False,
return_dropout_mask: bool = False,
residual_out: Optional[Tensor] = None,
) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor):
M, N = x.shape
assert x.stride(-1) == 1
if residual is not None:
assert residual.stride(-1) == 1
assert residual.shape == (M, N)
if weight is not None:
assert weight.shape == (N,)
assert weight.stride(-1) == 1
if bias is not None:
assert bias.stride(-1) == 1
assert bias.shape == (N,)
if x1 is not None:
assert x1.shape == x.shape
assert rowscale is None
assert x1.stride(-1) == 1
if weight1 is not None:
assert weight1.shape == (N,)
assert weight1.stride(-1) == 1
if bias1 is not None:
assert bias1.shape == (N,)
assert bias1.stride(-1) == 1
if rowscale is not None:
assert rowscale.is_contiguous()
assert rowscale.shape == (M,)
assert out.shape == x.shape
assert out.stride(-1) == 1
if residual_out is not None:
assert residual_out.shape == x.shape
assert residual_out.stride(-1) == 1
if weight1 is not None:
y1 = torch.empty_like(out)
assert y1.stride(-1) == 1
else:
y1 = None
mean = torch.empty((M,), dtype=torch.float32, device=x.device) if not is_rms_norm else None
rstd = torch.empty((M,), dtype=torch.float32, device=x.device)
if dropout_p > 0.0:
seeds = torch.randint(2**32, (M if x1 is None else 2 * M,), device=x.device, dtype=torch.int64)
else:
seeds = None
if return_dropout_mask and dropout_p > 0.0:
dropout_mask = torch.empty(M, N, device=x.device, dtype=torch.bool)
if x1 is not None:
dropout_mask1 = torch.empty(M, N, device=x.device, dtype=torch.bool)
else:
dropout_mask1 = None
else:
dropout_mask, dropout_mask1 = None, None
# Less than 64KB per feature: enqueue fused kernel
MAX_FUSED_SIZE = 65536 // x.element_size()
BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
if N > BLOCK_N:
raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
with torch.cuda.device(x.device.index):
torch.library.wrap_triton(_layer_norm_fwd_1pass_kernel)[(M,)](
x,
out,
weight if weight is not None else x, # unused when HAS_WEIGHT == False
bias,
residual,
x1,
weight1,
bias1,
y1,
residual_out,
rowscale,
seeds,
dropout_mask,
dropout_mask1,
mean,
rstd,
x.stride(0),
out.stride(0),
residual.stride(0) if residual is not None else 0,
residual_out.stride(0) if residual_out is not None else 0,
x1.stride(0) if x1 is not None else 0,
y1.stride(0) if y1 is not None else 0,
M,
N,
eps,
dropout_p,
            # Passing a bool makes torch inductor very unhappy since it then tries to compare to int_max
int(zero_centered_weight),
is_rms_norm,
BLOCK_N,
residual is not None,
residual_out is not None,
weight is not None,
bias is not None,
dropout_p > 0.0,
dropout_mask is not None,
rowscale is not None,
HAS_X1=x1 is not None,
HAS_W1=weight1 is not None,
HAS_B1=bias1 is not None,
)
return y1, mean, rstd, seeds, dropout_mask, dropout_mask1
class LayerNormFn:
@staticmethod
def forward(
x,
weight,
bias,
residual=None,
x1=None,
weight1=None,
bias1=None,
eps=1e-6,
dropout_p=0.0,
rowscale=None,
prenorm=False,
residual_in_fp32=False,
zero_centered_weight=False,
is_rms_norm=False,
return_dropout_mask=False,
out_dtype=None,
out=None,
residual_out=None,
):
x_shape_og = x.shape
# reshape input data into 2D tensor
x = maybe_contiguous_lastdim(x.reshape(-1, x.shape[-1]))
if residual is not None:
assert residual.shape == x_shape_og
residual = maybe_contiguous_lastdim(residual.reshape(-1, residual.shape[-1]))
if x1 is not None:
assert x1.shape == x_shape_og
assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
x1 = maybe_contiguous_lastdim(x1.reshape(-1, x1.shape[-1]))
# weight can be None when elementwise_affine=False for LayerNorm
if weight is not None:
weight = weight.contiguous()
bias = maybe_contiguous(bias)
weight1 = maybe_contiguous(weight1)
bias1 = maybe_contiguous(bias1)
if rowscale is not None:
rowscale = rowscale.reshape(-1).contiguous()
residual_dtype = residual.dtype if residual is not None else (torch.float32 if residual_in_fp32 else None)
if out is not None:
out = out.reshape(-1, out.shape[-1])
if residual_out is not None:
residual_out = residual_out.reshape(-1, residual_out.shape[-1])
y, y1, mean, rstd, residual_out, seeds, dropout_mask, dropout_mask1 = _layer_norm_fwd(
x,
weight,
bias,
eps,
residual,
x1,
weight1,
bias1,
dropout_p=dropout_p,
rowscale=rowscale,
out_dtype=out_dtype,
residual_dtype=residual_dtype,
zero_centered_weight=zero_centered_weight,
is_rms_norm=is_rms_norm,
return_dropout_mask=return_dropout_mask,
out=out,
residual_out=residual_out,
)
y = y.reshape(x_shape_og)
return y
def layer_norm_fn(
x,
weight,
bias,
residual=None,
x1=None,
weight1=None,
bias1=None,
eps=1e-6,
dropout_p=0.0,
rowscale=None,
prenorm=False,
residual_in_fp32=False,
zero_centered_weight=False,
is_rms_norm=False,
return_dropout_mask=False,
out_dtype=None,
out=None,
residual_out=None,
):
return LayerNormFn.forward(
x,
weight,
bias,
residual,
x1,
weight1,
bias1,
eps,
dropout_p,
rowscale,
prenorm,
residual_in_fp32,
zero_centered_weight,
is_rms_norm,
return_dropout_mask,
out_dtype,
out,
residual_out,
)
@triton.jit
def _norm_infer_kernel(
X,
Y,
W,
B,
stride_x_row,
stride_y_row,
M,
N,
eps,
IS_RMS_NORM: tl.constexpr,
HAS_WEIGHT: tl.constexpr,
HAS_BIAS: tl.constexpr,
BLOCK_N: tl.constexpr,
):
row = tl.program_id(0)
X += row * stride_x_row
Y += row * stride_y_row
if HAS_WEIGHT:
W += 0
if HAS_BIAS:
B += 0
cols = tl.arange(0, BLOCK_N)
x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
if not IS_RMS_NORM:
mean = tl.sum(x, axis=0) / N
xbar = tl.where(cols < N, x - mean, 0.0)
var = tl.sum(xbar * xbar, axis=0) / N
else:
xbar = tl.where(cols < N, x, 0.0)
var = tl.sum(xbar * xbar, axis=0) / N
rstd = 1 / tl.sqrt(var + eps)
x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
if HAS_WEIGHT:
w = tl.load(W + cols, mask=cols < N, other=1.0).to(tl.float32)
y = x_hat * w
else:
y = x_hat
if HAS_BIAS:
b = tl.load(B + cols, mask=cols < N, other=0.0).to(tl.float32)
y += b
tl.store(Y + cols, y, mask=cols < N)
def norm_infer(
x: Tensor,
weight: Optional[Tensor],
bias: Optional[Tensor],
eps: float,
is_rms_norm: bool = False,
out: Optional[Tensor] = None,
):
M, N = x.shape
x = x.contiguous()
if weight is not None:
assert weight.shape == (N,)
assert weight.stride(-1) == 1
if bias is not None:
assert bias.shape == (N,)
assert bias.stride(-1) == 1
if out is None:
out = torch.empty_like(x)
MAX_FUSED_SIZE = 65536 // x.element_size()
BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
if N > BLOCK_N:
raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
num_warps = min(max(BLOCK_N // 256, 1), 8)
_norm_infer_kernel[(M,)](
x,
out,
weight if weight is not None else x, # dummy when HAS_WEIGHT=False
bias if bias is not None else x, # dummy when HAS_BIAS=False
x.stride(0),
out.stride(0),
M,
N,
eps,
IS_RMS_NORM=is_rms_norm,
HAS_WEIGHT=weight is not None,
HAS_BIAS=bias is not None,
BLOCK_N=BLOCK_N,
num_warps=num_warps,
)
return out
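# --- Hedged usage sketch (added for illustration; not part of the original module) ---
# Runs norm_infer in both LayerNorm and RMSNorm modes on a small tensor and checks
# the results against plain PyTorch. Shapes are assumptions; a CUDA device is required.
def _example_norm_infer():
    if not torch.cuda.is_available():
        return
    x = torch.randn(32, 256, device="cuda", dtype=torch.float32)
    w = torch.randn(256, device="cuda", dtype=torch.float32)
    b = torch.randn(256, device="cuda", dtype=torch.float32)
    # LayerNorm mode (mean subtraction + variance)
    y_ln = norm_infer(x, w, b, eps=1e-6, is_rms_norm=False)
    ref_ln = torch.nn.functional.layer_norm(x, (256,), w, b, eps=1e-6)
    assert torch.allclose(y_ln, ref_ln, atol=1e-4)
    # RMSNorm mode (no mean subtraction, no bias)
    y_rms = norm_infer(x, w, None, eps=1e-6, is_rms_norm=True)
    ref_rms = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6) * w
    assert torch.allclose(y_rms, ref_rms, atol=1e-4)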
def rms_norm_fn(
x,
weight,
bias,
residual=None,
x1=None,
weight1=None,
bias1=None,
eps=1e-6,
dropout_p=0.0,
rowscale=None,
prenorm=False,
residual_in_fp32=False,
zero_centered_weight=False,
return_dropout_mask=False,
out_dtype=None,
out=None,
residual_out=None,
):
return LayerNormFn.forward(
x,
weight,
bias,
residual,
x1,
weight1,
bias1,
eps,
dropout_p,
rowscale,
prenorm,
residual_in_fp32,
zero_centered_weight,
True,
return_dropout_mask,
out_dtype,
out,
residual_out,
)
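# --- Hedged usage sketch (added for illustration; not part of the original module) ---
# Calls rms_norm_fn (the fused forward path with is_rms_norm=True) on a small
# tensor and compares it with a plain PyTorch RMSNorm reference. Shapes and
# dtype are assumptions; a CUDA device is required.
def _example_rms_norm_fn():
    if not torch.cuda.is_available():
        return
    x = torch.randn(64, 512, device="cuda", dtype=torch.float32)
    w = torch.randn(512, device="cuda", dtype=torch.float32)
    y = rms_norm_fn(x, w, bias=None, eps=1e-6)
    ref = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6) * w
    assert torch.allclose(y, ref, atol=1e-4)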
@triton.jit
def _rms_norm_tiled_onepass(
y_ptr,
x_ptr,
w_ptr,
SEQ: tl.constexpr,
DIM: tl.constexpr,
EPS: tl.constexpr,
BLOCK_SIZE_SEQ: tl.constexpr,
BLOCK_SIZE_DIM: tl.constexpr,
):
seq_blk_id = tl.program_id(0)
seq_id = seq_blk_id * BLOCK_SIZE_SEQ
seq_offset = seq_id + tl.arange(0, BLOCK_SIZE_SEQ)[:, None]
s_mask = seq_offset < SEQ
d_offset = tl.arange(0, BLOCK_SIZE_DIM)[None, :]
d_mask = d_offset < DIM
y_blk = y_ptr + seq_offset * DIM + d_offset
x_blk = x_ptr + seq_offset * DIM + d_offset
mask = s_mask & d_mask
x = tl.load(x_blk, mask=mask, other=0.0).to(tl.float32)
mean_square = tl.sum(x * x, axis=1, keep_dims=True) / DIM
rstd = tl.math.rsqrt(mean_square + EPS)
w = tl.load(w_ptr + d_offset, mask=d_mask)
tl.store(y_blk, x * rstd * w, mask=mask)
def rms_norm_kernel(x: torch.Tensor, w: torch.Tensor, eps: float = 1e-6):
shape = x.shape
x = x.contiguous()
y = torch.empty_like(x)
x_view = x.reshape(-1, shape[-1])
y_view = y.reshape(-1, shape[-1])
S, D = x_view.shape
BLOCK_SIZE_SEQ = min(16, triton.next_power_of_2(max(1, S // 512)))
grid = (triton.cdiv(S, BLOCK_SIZE_SEQ),)
with torch.cuda.device(x.device):
torch.library.wrap_triton(_rms_norm_tiled_onepass)[grid](
y_view,
x_view,
w,
S,
D,
eps,
BLOCK_SIZE_DIM=triton.next_power_of_2(D),
BLOCK_SIZE_SEQ=BLOCK_SIZE_SEQ,
)
return y
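# --- Hedged usage sketch (added for illustration; not part of the original module) ---
# Compares the tiled one-pass rms_norm_kernel with a plain PyTorch RMSNorm
# reference. Shapes and dtype are assumptions; a CUDA device is required.
def _example_rms_norm_kernel():
    if not torch.cuda.is_available():
        return
    x = torch.randn(128, 1024, device="cuda", dtype=torch.float32)
    w = torch.ones(1024, device="cuda", dtype=torch.float32)
    y = rms_norm_kernel(x, w, eps=1e-6)
    ref = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6) * w
    assert torch.allclose(y, ref, atol=1e-4)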
from .tensor import DefaultTensor
import os
import re
from pathlib import Path
import torch
from safetensors import safe_open
from lightx2v.utils.envs import *
from lightx2v.utils.registry_factory import TENSOR_REGISTER
from lightx2v_platform.base.global_var import AI_DEVICE
@TENSOR_REGISTER("Default")
class DefaultTensor:
def __init__(self, tensor_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
self.tensor_name = tensor_name
self.lazy_load = lazy_load
self.lazy_load_file = lazy_load_file
self.is_post_adapter = is_post_adapter
self.create_cuda_buffer = create_cuda_buffer
self.create_cpu_buffer = create_cpu_buffer
self.infer_dtype = GET_DTYPE()
self.sensitive_layer_dtype = GET_SENSITIVE_DTYPE()
def load(self, weight_dict):
if self.create_cuda_buffer:
self._load_cuda_buffer(weight_dict)
elif self.create_cpu_buffer:
self._load_cpu_pin_buffer()
else:
self._load_default_tensors(weight_dict)
def _load_default_tensors(self, weight_dict):
if not self.lazy_load:
device = weight_dict[self.tensor_name].device
if device.type == "cpu":
tensor = weight_dict[self.tensor_name]
self.pin_tensor = self._create_cpu_pin_tensor(tensor)
del weight_dict[self.tensor_name]
else:
self.tensor = weight_dict[self.tensor_name]
def _get_tensor(self, weight_dict=None, use_infer_dtype=False):
if self.lazy_load:
if Path(self.lazy_load_file).is_file():
lazy_load_file_path = self.lazy_load_file
else:
lazy_load_file_path = os.path.join(self.lazy_load_file, f"block_{self.tensor_name.split('.')[1]}.safetensors")
with safe_open(lazy_load_file_path, framework="pt", device="cpu") as lazy_load_file:
tensor = lazy_load_file.get_tensor(self.tensor_name)
if use_infer_dtype:
tensor = tensor.to(self.infer_dtype)
else:
tensor = weight_dict[self.tensor_name]
return tensor
def _create_cpu_pin_tensor(self, tensor):
pin_tensor = torch.empty(tensor.shape, pin_memory=True, dtype=tensor.dtype)
pin_tensor.copy_(tensor)
del tensor
return pin_tensor
def _load_cuda_buffer(self, weight_dict):
tensor = self._get_tensor(weight_dict, use_infer_dtype=self.lazy_load)
self.tensor_cuda_buffer = tensor.to(AI_DEVICE)
def _load_cpu_pin_buffer(self):
tensor = self._get_tensor(use_infer_dtype=True)
self.pin_tensor = self._create_cpu_pin_tensor(tensor)
def to_cuda(self, non_blocking=False):
self.tensor = self.pin_tensor.to(AI_DEVICE, non_blocking=non_blocking)
def to_cpu(self, non_blocking=False):
if hasattr(self, "pin_tensor"):
self.tensor = self.pin_tensor.copy_(self.tensor, non_blocking=non_blocking).cpu()
else:
self.tensor = self.tensor.to("cpu", non_blocking=non_blocking)
def state_dict(self, destination=None):
if destination is None:
destination = {}
destination[self.tensor_name] = self.pin_tensor if hasattr(self, "pin_tensor") else self.tensor
return destination
def load_state_dict(self, destination, block_index, adapter_block_index=None):
if self.is_post_adapter:
assert adapter_block_index is not None
tensor_name = re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", self.tensor_name, count=1)
else:
tensor_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.tensor_name, count=1)
if tensor_name not in destination:
self.tensor = None
return
self.tensor = self.tensor_cuda_buffer.copy_(destination[tensor_name], non_blocking=True)
def load_state_dict_from_disk(self, block_index, adapter_block_index=None):
if self.is_post_adapter:
assert adapter_block_index is not None
self.tensor_name = re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", self.tensor_name, count=1)
else:
self.tensor_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.tensor_name, count=1)
if Path(self.lazy_load_file).is_file():
lazy_load_file_path = self.lazy_load_file
else:
lazy_load_file_path = os.path.join(self.lazy_load_file, f"block_{block_index}.safetensors")
with safe_open(lazy_load_file_path, framework="pt", device="cpu") as lazy_load_file:
tensor = lazy_load_file.get_tensor(self.tensor_name).to(self.infer_dtype)
self.pin_tensor = self.pin_tensor.copy_(tensor)
del tensor
import os
import re
from pathlib import Path
import torch
from safetensors import safe_open
from lightx2v.utils.envs import *
from lightx2v_platform.base.global_var import AI_DEVICE
def resolve_block_name(name, block_index, adapter_block_index=None, is_post_adapter=False):
"""Resolve the name according to the block index, replacing the block index in the name with the specified block_index.
Args:
name: Original tensor name, e.g. "blocks.0.weight"
block_index: Target block index
adapter_block_index: Target adapter block index (optional)
is_post_adapter: Whether to perform post-adapter block index replacement (optional)
Returns:
Resolved name, e.g. "blocks.1.weight" (when block_index=1)
Example:
        >>> resolve_block_name("blocks.0.weight", 1)
"blocks.1.weight"
"""
if is_post_adapter:
return re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", name, count=1)
else:
return re.sub(r"\.\d+", lambda m: f".{block_index}", name, count=1)
def get_source_tensor(source_name, weight_dict, lazy_load, lazy_load_file, use_infer_dtype, scale_force_fp32, bias_force_fp32):
"""Get the source tensor from either weight dictionary or lazy loading safetensors file.
Args:
source_name: Name of the target tensor to get
weight_dict: Preloaded weight dictionary
lazy_load: Whether to enable lazy loading mode
lazy_load_file: File or directory path for lazy loading
use_infer_dtype: Whether to convert tensor to inference dtype
scale_force_fp32: Whether to force weight_scale tensors to float32
bias_force_fp32: Whether to force bias tensors to float32
Returns:
The target tensor retrieved from the source with appropriate dtype conversion applied
"""
if lazy_load:
if Path(lazy_load_file).is_file():
lazy_load_file_path = lazy_load_file
else:
lazy_load_file_path = os.path.join(
lazy_load_file,
f"block_{source_name.split('.')[1]}.safetensors",
)
with safe_open(lazy_load_file_path, framework="pt", device="cpu") as lazy_load_file:
if use_infer_dtype:
return lazy_load_file.get_tensor(source_name).to(GET_DTYPE())
elif scale_force_fp32 and "weight_scale" in source_name:
return lazy_load_file.get_tensor(source_name).to(torch.float32)
elif bias_force_fp32 and "bias" in source_name:
return lazy_load_file.get_tensor(source_name).to(torch.float32)
return lazy_load_file.get_tensor(source_name)
else:
if use_infer_dtype:
return weight_dict[source_name].to(GET_DTYPE())
elif scale_force_fp32 and "weight_scale" in source_name:
return weight_dict[source_name].to(torch.float32)
elif bias_force_fp32 and "bias" in source_name:
return weight_dict[source_name].to(torch.float32)
return weight_dict[source_name]
def create_pin_tensor(tensor, transpose=False, dtype=None):
"""Create a tensor with pinned memory for faster data transfer to GPU.
Args:
tensor: Source tensor to be converted to pinned memory
transpose: Whether to transpose the tensor after creating pinned memory (optional)
dtype: Target data type of the pinned tensor (optional, defaults to source tensor's dtype)
Returns:
Pinned memory tensor (on CPU) with optional transposition applied
"""
dtype = dtype or tensor.dtype
pin_tensor = torch.empty(tensor.shape, pin_memory=True, dtype=dtype)
pin_tensor = pin_tensor.copy_(tensor)
if transpose:
pin_tensor = pin_tensor.t()
del tensor
return pin_tensor
def get_lazy_load_file_path(lazy_load_file, weight_name_for_block=None):
"""Get the full file path for lazy loading, handling both file and directory inputs.
Args:
lazy_load_file: Base file or directory path for lazy loading
weight_name_for_block: Tensor weight name to generate block-specific file path (optional)
Returns:
Resolved full file path for lazy loading
"""
if weight_name_for_block is None:
return lazy_load_file
if Path(lazy_load_file).is_file():
return lazy_load_file
else:
return os.path.join(
lazy_load_file,
f"block_{weight_name_for_block.split('.')[1]}.safetensors",
)
def create_cuda_buffers(base_attrs, weight_dict, lazy_load, lazy_load_file, use_infer_dtype=None, scale_force_fp32=False, bias_force_fp32=False):
"""Create tensor buffers and move them to CUDA device (specified by AI_DEVICE).
Args:
base_attrs: [(name, attr_name, transpose), ...] List of tensor loading specifications,
where transpose indicates whether transposition is required
weight_dict: Preloaded weight dictionary
lazy_load: Whether to use lazy loading mode
lazy_load_file: File or directory path for lazy loading
use_infer_dtype: Whether to convert tensors to inference dtype (optional)
scale_force_fp32: Whether to force weight_scale tensors to float32 (optional)
bias_force_fp32: Whether to force bias tensors to float32 (optional)
Returns:
dict: {attr_name: tensor, ...} Dictionary of tensors located on CUDA device
"""
result = {}
for name, attr_name, transpose in base_attrs:
tensor = get_source_tensor(name, weight_dict, lazy_load, lazy_load_file, use_infer_dtype, scale_force_fp32, bias_force_fp32)
if transpose:
tensor = tensor.t()
result[attr_name] = tensor.to(AI_DEVICE)
return result
def create_cpu_buffers(base_attrs, lazy_load_file, use_infer_dtype=False, scale_force_fp32=False, bias_force_fp32=False):
"""Create pinned memory tensor buffers on CPU for lazy loading scenario.
Args:
base_attrs: [(name, attr_name, transpose), ...] Configuration list,
where transpose indicates whether transposition is required
lazy_load_file: File or directory path for lazy loading
use_infer_dtype: Whether to convert tensors to inference dtype (optional)
scale_force_fp32: Whether to force weight_scale tensors to float32 (optional)
bias_force_fp32: Whether to force bias tensors to float32 (optional)
Returns:
dict: {attr_name: tensor, ...} Dictionary of pinned memory tensors on CPU
"""
result = {}
# Use get_source_tensor to load the tensor (weight_dict is not required when lazy_load=True)
for name, attr_name, transpose in base_attrs:
tensor = get_source_tensor(name, {}, lazy_load=True, lazy_load_file=lazy_load_file, use_infer_dtype=use_infer_dtype, scale_force_fp32=scale_force_fp32, bias_force_fp32=bias_force_fp32)
result[attr_name] = create_pin_tensor(tensor, transpose=transpose)
return result
def create_default_tensors(base_attrs, weight_dict):
"""Create default tensors (device tensors and pinned memory tensors) based on the source weight device.
Args:
base_attrs: [(name, attr_name, transpose), ...] Configuration list,
where transpose indicates whether transposition is required
weight_dict: Preloaded weight dictionary
Returns:
tuple: (device_tensors_dict, pin_tensors_dict)
device_tensors_dict: {attr_name: tensor, ...} Tensors located on the original weight device
pin_tensors_dict: {attr_name: tensor, ...} Tensors with pinned memory on CPU
"""
device_tensors = {}
pin_tensors = {}
if not base_attrs:
return device_tensors, pin_tensors
first_tensor_name = base_attrs[0][0]
device = weight_dict[first_tensor_name].device
if device.type == "cpu":
for name, attr_name, transpose in base_attrs:
if name in weight_dict:
tensor = weight_dict[name]
pin_tensors[attr_name] = create_pin_tensor(tensor, transpose=transpose)
del weight_dict[name]
else:
for name, attr_name, transpose in base_attrs:
if name in weight_dict:
tensor = weight_dict[name]
if transpose:
tensor = tensor.t()
device_tensors[attr_name] = tensor
return device_tensors, pin_tensors
def move_tensor_to_device(obj, attr_name, target_device, non_blocking=False, use_copy=False):
"""Move the specified tensor attribute of an object to the target device,
with support for pinned memory tensors for faster transfer.
Args:
obj: Target object containing the tensor attribute
attr_name: Name of the tensor attribute to be moved
target_device: Target device to move the tensor to
non_blocking: Whether to perform non-blocking data transfer (optional)
use_copy: Whether to copy the tensor content before moving (optional)
"""
pin_attr_name = f"pin_{attr_name}"
if hasattr(obj, pin_attr_name) and getattr(obj, pin_attr_name) is not None:
pin_tensor = getattr(obj, pin_attr_name)
if hasattr(obj, attr_name) and getattr(obj, attr_name) is not None and use_copy:
setattr(obj, attr_name, pin_tensor.copy_(getattr(obj, attr_name), non_blocking=non_blocking).to(target_device))
else:
setattr(obj, attr_name, pin_tensor.to(target_device, non_blocking=non_blocking))
elif hasattr(obj, attr_name) and getattr(obj, attr_name) is not None:
setattr(obj, attr_name, getattr(obj, attr_name).to(target_device, non_blocking=non_blocking))
def build_lora_and_diff_names(weight_name, lora_prefix):
"""Build the full names of LoRA (down/up/alpha) and weight difference tensors.
Args:
weight_name: Original weight tensor name
lora_prefix: Prefix string for LoRA tensor names
Returns:
tuple: (lora_down_name, lora_up_name, lora_alpha_name, weight_diff_name, bias_diff_name)
Full names of various LoRA and difference tensors
"""
base_name = weight_name[:-7]
parts = base_name.split(".")
relative_path = ".".join(parts[1:])
lora_base = f"{lora_prefix}.{relative_path}"
lora_down_name = f"{lora_base}.lora_down.weight"
lora_up_name = f"{lora_base}.lora_up.weight"
lora_alpha_name = f"{lora_base}.alpha"
weight_diff_name = f"{lora_base}.diff"
bias_diff_name = f"{lora_base}.diff_b"
return lora_down_name, lora_up_name, lora_alpha_name, weight_diff_name, bias_diff_name
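# --- Hedged usage sketch (added for illustration; not part of the original module) ---
# Shows the tensor names produced by build_lora_and_diff_names for one assumed
# weight name; the weight name and prefix below are illustrative only.
def _example_build_lora_and_diff_names():
    down, up, alpha, w_diff, b_diff = build_lora_and_diff_names("blocks.0.attn.q.weight", "diffusion_model.blocks")
    assert down == "diffusion_model.blocks.0.attn.q.lora_down.weight"
    assert up == "diffusion_model.blocks.0.attn.q.lora_up.weight"
    assert alpha == "diffusion_model.blocks.0.attn.q.alpha"
    assert w_diff == "diffusion_model.blocks.0.attn.q.diff"
    assert b_diff == "diffusion_model.blocks.0.attn.q.diff_b"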
def move_attr_to_cuda(cls, base_attrs, lora_attrs, non_blocking=False):
"""Move base attributes and LoRA attributes to CUDA device.
Args:
cls: Target class instance containing tensor attributes
base_attrs: [(name, attr_name, transpose), ...] List of base attribute specifications
lora_attrs: Dictionary mapping LoRA attribute names to their name attributes
non_blocking: Whether to perform non-blocking data transfer (optional)
"""
# Base
for _, base_attr_name, _ in base_attrs:
move_tensor_to_device(cls, base_attr_name, AI_DEVICE, non_blocking)
# Lora
for lora_attr, _ in lora_attrs.items():
if hasattr(cls, lora_attr) and getattr(cls, lora_attr) is not None:
setattr(cls, lora_attr, getattr(cls, lora_attr).to(AI_DEVICE, non_blocking=non_blocking))
def move_attr_to_cpu(cls, base_attrs, lora_attrs, non_blocking=False):
"""Move base attributes and LoRA attributes to CPU device.
Args:
cls: Target class instance containing tensor attributes
base_attrs: [(name, attr_name, transpose), ...] List of base attribute specifications
lora_attrs: Dictionary mapping LoRA attribute names to their name attributes
non_blocking: Whether to perform non-blocking data transfer (optional)
"""
# Base
for _, base_attr_name, _ in base_attrs:
move_tensor_to_device(cls, base_attr_name, "cpu", non_blocking, use_copy=True)
# Lora
for lora_attr, _ in lora_attrs.items():
if hasattr(cls, lora_attr) and getattr(cls, lora_attr) is not None:
setattr(cls, lora_attr, getattr(cls, lora_attr).to("cpu", non_blocking=non_blocking))
def state_dict(cls, base_attrs, lora_attrs, destination=None):
"""Generate state dictionary containing base attributes and LoRA attributes.
Args:
cls: Target class instance containing tensor attributes
base_attrs: [(name, attr_name, transpose), ...] List of base attribute specifications
lora_attrs: Dictionary mapping LoRA attribute names to their name attributes
destination: Optional destination dictionary to store state dict (if None, creates new dict)
Returns:
dict: State dictionary containing all base and LoRA attributes with their corresponding names
"""
if destination is None:
destination = {}
# Base
for _, base_attr, _ in base_attrs:
pin_base_attr = getattr(cls, f"pin_{base_attr}", None)
device_attr = getattr(cls, base_attr, None)
name_attr = f"{base_attr}_name" if hasattr(cls, f"{base_attr}_name") else None
if name_attr:
name = getattr(cls, name_attr)
destination[name] = pin_base_attr if pin_base_attr is not None else device_attr
# Lora
for lora_attr, name_attr in lora_attrs.items():
if hasattr(cls, lora_attr):
destination[getattr(cls, name_attr)] = getattr(cls, lora_attr)
return destination
def load_state_dict(cls, base_attrs, lora_attrs, destination, block_index, adapter_block_index=None):
"""Load state dictionary into class instance, resolving block indices for base and LoRA attributes.
Args:
cls: Target class instance to load state dict into
base_attrs: [(name, attr_name, transpose), ...] List of base attribute specifications
lora_attrs: Dictionary mapping LoRA attribute names to their name attributes
destination: Source state dictionary to load from
block_index: Block index to resolve tensor names
adapter_block_index: Adapter block index for post-adapter scenarios (optional)
"""
# Base
for name, attr_name, _ in base_attrs:
actual_name = resolve_block_name(name, block_index, adapter_block_index, cls.is_post_adapter)
cuda_buffer_attr = f"{attr_name}_cuda_buffer"
if actual_name in destination:
if hasattr(cls, cuda_buffer_attr):
setattr(cls, attr_name, getattr(cls, cuda_buffer_attr).copy_(destination[actual_name], non_blocking=True))
else:
setattr(cls, attr_name, None)
# Lora
for lora_attr, lora_attr_name in lora_attrs.items():
name = resolve_block_name(getattr(cls, lora_attr_name), block_index)
if name in destination:
setattr(cls, lora_attr, getattr(cls, lora_attr).copy_(destination[name], non_blocking=True).to(AI_DEVICE))
import math
from abc import ABC, abstractmethod
class BaseTransformerInfer(ABC):
@abstractmethod
def infer(self):
pass
def set_scheduler(self, scheduler):
self.scheduler = scheduler
self.scheduler.transformer_infer = self
class BaseTaylorCachingTransformerInfer(BaseTransformerInfer):
@abstractmethod
def infer_calculating(self):
pass
@abstractmethod
def infer_using_cache(self):
pass
@abstractmethod
def get_taylor_step_diff(self):
pass
    # 1. when fully calculated, store the result in the cache
def derivative_approximation(self, block_cache, module_name, out):
if module_name not in block_cache:
block_cache[module_name] = {0: out}
else:
step_diff = self.get_taylor_step_diff()
previous_out = block_cache[module_name][0]
block_cache[module_name][0] = out
block_cache[module_name][1] = (out - previous_out) / step_diff
def taylor_formula(self, tensor_dict):
x = self.get_taylor_step_diff()
output = 0
for i in range(len(tensor_dict)):
output += (1 / math.factorial(i)) * tensor_dict[i] * (x**i)
return output
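# --- Hedged usage sketch (added for illustration; not part of the original module) ---
# A toy subclass showing how derivative_approximation caches a value and a
# finite-difference derivative that taylor_formula then uses to extrapolate
# skipped steps. The constant step difference below is an assumption.
class _ToyTaylorInfer(BaseTaylorCachingTransformerInfer):
    def infer(self):
        pass
    def infer_calculating(self):
        pass
    def infer_using_cache(self):
        pass
    def get_taylor_step_diff(self):
        return 1.0
def _example_taylor_caching():
    infer = _ToyTaylorInfer()
    cache = {}
    infer.derivative_approximation(cache, "ffn", 1.0)  # first full step: store the output
    infer.derivative_approximation(cache, "ffn", 3.0)  # next full step: store output and first derivative (3.0 - 1.0) / 1.0
    # First-order Taylor extrapolation one step ahead: 3.0 + 2.0 * 1.0
    assert infer.taylor_formula(cache["ffn"]) == 5.0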
"""
WorldPlay AR Dataset and Data Utilities.
"""
from lightx2v.data.worldplay_ar_dataset import (
WorldPlayARDataset,
collate_fn,
)
__all__ = [
"WorldPlayARDataset",
"collate_fn",
]
"""
WorldPlay AR Dataset for autoregressive video generation training.
This dataset supports:
- Camera pose (w2c, intrinsic) loading
- Action label extraction
- Image conditioning for I2V
- Chunk-based training with memory window
- I2V masking for conditional generation
"""
import json
import os
import random
from typing import Any, Dict, List, Optional, Tuple
import numpy as np
import torch
import torch.nn.functional as F
from loguru import logger
from torch.utils.data import Dataset
try:
from PIL import Image
except ImportError:
Image = None
try:
import decord
decord.bridge.set_bridge("torch")
except ImportError:
decord = None
class WorldPlayARDataset(Dataset):
"""
Dataset for WorldPlay AR model training.
Supports loading video data with camera poses and action labels
for autoregressive video generation training.
Args:
data_root: Root directory containing video data
meta_file: Path to metadata JSON file
video_length: Number of frames per video sample
resolution: Target resolution (height, width)
chunk_latent_num: Number of latent frames per chunk
memory_window_size: Size of memory window for AR training
select_window_out_flag: Whether to use memory window selection
task: Task type ('t2v' or 'i2v')
transform: Optional transform to apply to frames
action_trans_thresh: Translation threshold for action quantization
action_rot_thresh: Rotation threshold for action quantization
num_action_classes: Number of discrete action classes (default 81 = 3^4)
"""
def __init__(
self,
data_root: str,
meta_file: str,
video_length: int = 125,
resolution: Tuple[int, int] = (480, 832),
chunk_latent_num: int = 4,
memory_window_size: int = 8,
select_window_out_flag: bool = True,
task: str = "i2v",
transform: Optional[Any] = None,
action_trans_thresh: float = 0.1,
action_rot_thresh: float = 0.05,
num_action_classes: int = 81,
):
super().__init__()
self.data_root = data_root
self.video_length = video_length
self.resolution = resolution
self.chunk_latent_num = chunk_latent_num
self.memory_window_size = memory_window_size
self.select_window_out_flag = select_window_out_flag
self.task = task
self.transform = transform
# Action quantization parameters
self.action_trans_thresh = action_trans_thresh
self.action_rot_thresh = action_rot_thresh
self.num_action_classes = num_action_classes
# Load metadata
self.samples = self._load_metadata(meta_file)
logger.info(f"Loaded {len(self.samples)} samples from {meta_file}")
def _load_metadata(self, meta_file: str) -> List[Dict]:
"""Load metadata from JSON file."""
with open(meta_file, "r") as f:
data = json.load(f)
samples = []
for item in data:
sample = {
"video_path": os.path.join(self.data_root, item["video_path"]),
"caption": item.get("caption", ""),
}
# Camera pose data
if "w2c" in item:
sample["w2c"] = item["w2c"]
if "intrinsic" in item:
sample["intrinsic"] = item["intrinsic"]
if "action" in item:
sample["action"] = item["action"]
# Image conditioning for I2V
if "image_cond" in item:
sample["image_cond"] = os.path.join(self.data_root, item["image_cond"])
samples.append(sample)
return samples
def __len__(self) -> int:
return len(self.samples)
def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
sample = self.samples[idx]
# Load video frames
video = self._load_video(sample["video_path"])
# Load camera poses
w2c = self._load_camera_pose(sample.get("w2c"))
intrinsic = self._load_intrinsic(sample.get("intrinsic"))
# Load or compute action labels
action = self._load_action(sample.get("action"), w2c)
# Load image condition for I2V
image_cond = None
if self.task == "i2v" and "image_cond" in sample:
image_cond = self._load_image(sample["image_cond"])
# Prepare I2V mask
i2v_mask = self._prepare_i2v_mask(video.shape[0])
# Select memory window for training
if self.select_window_out_flag:
(video, w2c, intrinsic, action, i2v_mask) = self._select_memory_window(video, w2c, intrinsic, action, i2v_mask)
output = {
"video": video,
"caption": sample["caption"],
"w2c": w2c,
"intrinsic": intrinsic,
"action": action,
"i2v_mask": i2v_mask,
}
if image_cond is not None:
output["image_cond"] = image_cond
return output
def _load_video(self, video_path: str) -> torch.Tensor:
"""Load video frames from file."""
if decord is None:
raise ImportError("decord is required for video loading")
vr = decord.VideoReader(video_path)
total_frames = len(vr)
# Sample frames
if total_frames >= self.video_length:
start_idx = random.randint(0, total_frames - self.video_length)
frame_indices = list(range(start_idx, start_idx + self.video_length))
else:
frame_indices = list(range(total_frames))
# Pad with last frame
frame_indices += [total_frames - 1] * (self.video_length - total_frames)
frames = vr.get_batch(frame_indices) # [T, H, W, C]
frames = frames.permute(0, 3, 1, 2).float() / 255.0 # [T, C, H, W]
# Resize to target resolution
frames = F.interpolate(frames, size=self.resolution, mode="bilinear", align_corners=False)
if self.transform is not None:
frames = self.transform(frames)
return frames
def _load_camera_pose(self, w2c_data: Optional[Any]) -> torch.Tensor:
"""Load world-to-camera transformation matrices."""
if w2c_data is None:
# Return identity matrices
return torch.eye(4).unsqueeze(0).repeat(self.video_length, 1, 1)
if isinstance(w2c_data, str):
# Load from file
w2c = np.load(w2c_data)
elif isinstance(w2c_data, list):
w2c = np.array(w2c_data)
else:
w2c = w2c_data
w2c = torch.from_numpy(w2c).float()
# Ensure correct shape [T, 4, 4]
if w2c.dim() == 2:
w2c = w2c.unsqueeze(0).repeat(self.video_length, 1, 1)
elif w2c.shape[0] < self.video_length:
# Pad with last pose
pad_size = self.video_length - w2c.shape[0]
w2c = torch.cat([w2c, w2c[-1:].repeat(pad_size, 1, 1)], dim=0)
elif w2c.shape[0] > self.video_length:
w2c = w2c[: self.video_length]
return w2c
def _load_intrinsic(self, intrinsic_data: Optional[Any]) -> torch.Tensor:
"""Load camera intrinsic matrices."""
if intrinsic_data is None:
# Return default intrinsics
K = torch.tensor([[500.0, 0.0, self.resolution[1] / 2], [0.0, 500.0, self.resolution[0] / 2], [0.0, 0.0, 1.0]])
return K.unsqueeze(0).repeat(self.video_length, 1, 1)
if isinstance(intrinsic_data, str):
intrinsic = np.load(intrinsic_data)
elif isinstance(intrinsic_data, list):
intrinsic = np.array(intrinsic_data)
else:
intrinsic = intrinsic_data
intrinsic = torch.from_numpy(intrinsic).float()
# Ensure correct shape [T, 3, 3]
if intrinsic.dim() == 2:
intrinsic = intrinsic.unsqueeze(0).repeat(self.video_length, 1, 1)
elif intrinsic.shape[0] < self.video_length:
pad_size = self.video_length - intrinsic.shape[0]
intrinsic = torch.cat([intrinsic, intrinsic[-1:].repeat(pad_size, 1, 1)], dim=0)
elif intrinsic.shape[0] > self.video_length:
intrinsic = intrinsic[: self.video_length]
return intrinsic
def _load_action(self, action_data: Optional[Any], w2c: torch.Tensor) -> torch.Tensor:
"""Load or compute action labels from camera poses."""
if action_data is not None:
if isinstance(action_data, str):
action = np.load(action_data)
elif isinstance(action_data, list):
action = np.array(action_data)
else:
action = action_data
action = torch.from_numpy(action).long()
else:
# Compute action from camera pose differences
action = self._compute_action_from_pose(w2c)
# Ensure correct shape [T]
if action.shape[0] < self.video_length:
pad_size = self.video_length - action.shape[0]
action = torch.cat([action, action[-1:].repeat(pad_size)], dim=0)
elif action.shape[0] > self.video_length:
action = action[: self.video_length]
return action
def _compute_action_from_pose(self, w2c: torch.Tensor) -> torch.Tensor:
"""
Compute discrete action labels from camera pose differences.
Action space: 81 classes (3^4 for forward/backward, left/right,
up/down, rotation)
"""
T = w2c.shape[0]
actions = torch.zeros(T, dtype=torch.long)
for t in range(1, T):
# Compute relative transformation
rel_pose = torch.inverse(w2c[t - 1]) @ w2c[t]
# Extract translation and rotation
translation = rel_pose[:3, 3]
rotation = rel_pose[:3, :3]
# Quantize to discrete action
action_idx = self._quantize_action(translation, rotation)
actions[t] = action_idx
return actions
def _quantize_action(self, translation: torch.Tensor, rotation: torch.Tensor) -> int:
"""Quantize continuous motion to discrete action index."""
# Use configurable thresholds
trans_thresh = self.action_trans_thresh
rot_thresh = self.action_rot_thresh
# Forward/backward (z-axis)
if translation[2] > trans_thresh:
fb = 2 # forward
elif translation[2] < -trans_thresh:
fb = 0 # backward
else:
fb = 1 # stationary
# Left/right (x-axis)
if translation[0] > trans_thresh:
lr = 2 # right
elif translation[0] < -trans_thresh:
lr = 0 # left
else:
lr = 1 # stationary
# Up/down (y-axis)
if translation[1] > trans_thresh:
ud = 2 # up
elif translation[1] < -trans_thresh:
ud = 0 # down
else:
ud = 1 # stationary
# Rotation (simplified to yaw)
yaw = torch.atan2(rotation[0, 2], rotation[2, 2])
if yaw > rot_thresh:
rot = 2 # rotate right
elif yaw < -rot_thresh:
rot = 0 # rotate left
else:
rot = 1 # no rotation
# Combine into single index (base-3 encoding)
action_idx = fb * 27 + lr * 9 + ud * 3 + rot
return action_idx
def _load_image(self, image_path: str) -> torch.Tensor:
"""Load conditioning image for I2V."""
if Image is None:
raise ImportError("PIL is required for image loading")
img = Image.open(image_path).convert("RGB")
img = img.resize((self.resolution[1], self.resolution[0]))
img = torch.from_numpy(np.array(img)).permute(2, 0, 1).float() / 255.0
return img
def _prepare_i2v_mask(self, num_frames: int) -> torch.Tensor:
"""
Prepare I2V mask for conditional generation.
For I2V task, the first frame is conditioned (mask=0),
and remaining frames are generated (mask=1).
"""
mask = torch.ones(num_frames)
if self.task == "i2v":
mask[0] = 0 # First frame is conditioned
return mask
def _select_memory_window(
self,
video: torch.Tensor,
w2c: torch.Tensor,
intrinsic: torch.Tensor,
action: torch.Tensor,
i2v_mask: torch.Tensor,
) -> Tuple[torch.Tensor, ...]:
"""
Select a random memory window for training.
This simulates the AR generation process where we only
attend to a window of previous frames.
"""
T = video.shape[0]
window_size = self.memory_window_size * self.chunk_latent_num
if T <= window_size:
return video, w2c, intrinsic, action, i2v_mask
# Random start position
start_idx = random.randint(0, T - window_size)
end_idx = start_idx + window_size
return (
video[start_idx:end_idx],
w2c[start_idx:end_idx],
intrinsic[start_idx:end_idx],
action[start_idx:end_idx],
i2v_mask[start_idx:end_idx],
)
def collate_fn(batch: List[Dict]) -> Dict[str, torch.Tensor]:
"""Collate function for DataLoader."""
output = {}
# Stack tensors
for key in ["video", "w2c", "intrinsic", "action", "i2v_mask"]:
if key in batch[0]:
output[key] = torch.stack([item[key] for item in batch])
# Handle optional image_cond
if "image_cond" in batch[0]:
output["image_cond"] = torch.stack([item["image_cond"] for item in batch])
# Collect captions
output["caption"] = [item["caption"] for item in batch]
return output
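# --- Hedged usage sketch (added for illustration; not part of the original module) ---
# Builds a fake two-sample batch and runs it through collate_fn. Shapes are
# assumptions and far smaller than real training data.
def _example_collate_fn():
    def fake_sample():
        return {
            "video": torch.zeros(8, 3, 32, 32),
            "w2c": torch.eye(4).repeat(8, 1, 1),
            "intrinsic": torch.eye(3).repeat(8, 1, 1),
            "action": torch.zeros(8, dtype=torch.long),
            "i2v_mask": torch.ones(8),
            "caption": "a test clip",
        }
    batch = collate_fn([fake_sample(), fake_sample()])
    assert batch["video"].shape == (2, 8, 3, 32, 32)
    assert batch["caption"] == ["a test clip", "a test clip"]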
import asyncio
import json
import os
import sys
from alibabacloud_dypnsapi20170525 import models as dypnsapi_models
from alibabacloud_dypnsapi20170525.client import Client
from alibabacloud_tea_openapi import models as openapi_models
from alibabacloud_tea_util import models as util_models
from loguru import logger
class AlibabaCloudClient:
def __init__(self):
config = openapi_models.Config(
access_key_id=os.getenv("ALIBABA_CLOUD_ACCESS_KEY_ID"),
access_key_secret=os.getenv("ALIBABA_CLOUD_ACCESS_KEY_SECRET"),
https_proxy=os.getenv("auth_https_proxy", None),
)
self.client = Client(config)
self.runtime = util_models.RuntimeOptions()
def check_ok(self, res, prefix):
logger.info(f"{prefix}: {res}")
if not isinstance(res, dict) or "statusCode" not in res or res["statusCode"] != 200:
logger.warning(f"{prefix}: error response: {res}")
return False
if "body" not in res or "Code" not in res["body"] or "Success" not in res["body"]:
logger.warning(f"{prefix}: error body: {res}")
return False
if res["body"]["Code"] != "OK" or res["body"]["Success"] is not True:
logger.warning(f"{prefix}: sms error: {res}")
return False
return True
async def send_sms(self, phone_number):
try:
req = dypnsapi_models.SendSmsVerifyCodeRequest(
phone_number=phone_number,
sign_name="速通互联验证服务",
template_code="100001",
template_param=json.dumps({"code": "##code##", "min": "5"}),
valid_time=300,
)
res = await self.client.send_sms_verify_code_with_options_async(req, self.runtime)
ok = self.check_ok(res.to_map(), "AlibabaCloudClient send sms")
logger.info(f"AlibabaCloudClient send sms for {phone_number}: {ok}")
return ok
except Exception as e:
logger.warning(f"AlibabaCloudClient send sms for {phone_number}: {e}")
return False
async def check_sms(self, phone_number, verify_code):
try:
req = dypnsapi_models.CheckSmsVerifyCodeRequest(
phone_number=phone_number,
verify_code=verify_code,
)
res = await self.client.check_sms_verify_code_with_options_async(req, self.runtime)
ok = self.check_ok(res.to_map(), "AlibabaCloudClient check sms")
logger.info(f"AlibabaCloudClient check sms for {phone_number} with {verify_code}: {ok}")
return ok
except Exception as e:
logger.warning(f"AlibabaCloudClient check sms for {phone_number} with {verify_code}: {e}")
return False
async def test(args):
assert len(args) in [1, 2], "Usage: python aliyun_sms.py <phone_number> [verify_code]"
phone_number = args[0]
client = AlibabaCloudClient()
if len(args) == 1:
await client.send_sms(phone_number)
else:
await client.check_sms(phone_number, args[1])
if __name__ == "__main__":
asyncio.run(test(sys.argv[1:]))