Commit a1ebc651 authored by xuwx1

update lightx2v

parent 5a4db490
import torch
from loguru import logger
try:
from magi_attention.functional import flex_flash_attn_func as magi_ffa_func
except ImportError:
magi_ffa_func = None
try:
import flashinfer
except ImportError:
flashinfer = None
from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER
from .template import AttnWeightTemplate
def generate_nbhd_mask(a, block_num, attnmap_frame_num, coefficient=[1.0, 0.5, 0.056], min_width=1.0, device="cpu"):
"""
a : number of blocks per frame (may be fractional)
block_num : number of blocks per row/column of the block-level attention map
attnmap_frame_num : total number of frames
"""
i_indices = torch.arange(block_num, device=device).unsqueeze(1) # [block_num, 1]
j_indices = torch.arange(block_num, device=device).unsqueeze(0) # [1, block_num]
assert len(coefficient) <= attnmap_frame_num, f"coefficient length {len(coefficient)} should <= attnmap_frame_num {attnmap_frame_num}"
width_list = [max(min_width, coefficient[i] * a) for i in range(len(coefficient))] + [min_width] * (attnmap_frame_num - len(coefficient))
logger.info(f"nbhd_attn width_list: {width_list}, len={len(width_list)}")
# attention sink frame: j <= a
mask_sink = j_indices <= a
mask_sparse = torch.zeros((block_num, block_num), dtype=torch.bool, device=device)
for interval in range(0, attnmap_frame_num):
n = i_indices // a
mask_sparse_base_1 = (j_indices >= (n + interval) * a) & (j_indices <= (n + interval + 1) * a)
n = j_indices // a
mask_sparse_base_2 = (i_indices >= (n + interval) * a) & (i_indices <= (n + interval + 1) * a)
width = width_list[interval]
mask_1 = mask_sparse_base_1 & (i_indices - j_indices + (interval * a + width) >= 0) & (i_indices - j_indices + (interval * a - width) <= 0)
mask_2 = mask_sparse_base_2 & (i_indices - j_indices - (interval * a - width) >= 0) & (i_indices - j_indices - (interval * a + width) <= 0)
mask_sparse = mask_sparse | mask_1 | mask_2
mask = mask_sink | mask_sparse
return mask
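# Illustrative usage sketch (toy values only, not exercised by the registry): build the
# neighborhood mask for 3 frames with 4 blocks per frame and inspect its sparsity.
def _toy_nbhd_mask_example():
    a = 4.0  # blocks per frame; a float, matching how prepare_mask computes it
    attnmap_frame_num = 3
    block_num = 12  # blocks per row/column of the block-level attention map
    mask = generate_nbhd_mask(a, block_num, attnmap_frame_num, coefficient=[1.0, 0.5], min_width=1.0, device="cpu")
    sparsity = 1 - mask.sum().item() / mask.numel()
    return mask, sparsity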
def generate_qk_ranges(mask, block_size, seqlen):
indices = torch.nonzero(mask, as_tuple=False) # shape: [N, 2]
i_indices = indices[:, 0] # [N]
j_indices = indices[:, 1] # [N]
q_start = i_indices * block_size # [N]
q_end = torch.clamp((i_indices + 1) * block_size, max=seqlen) # [N]
k_start = j_indices * block_size # [N]
k_end = torch.clamp((j_indices + 1) * block_size, max=seqlen) # [N]
q_ranges = torch.stack([q_start, q_end], dim=1) # [N, 2]
k_ranges = torch.stack([k_start, k_end], dim=1) # [N, 2]
return q_ranges, k_ranges
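# Illustrative sketch: convert a hand-written 3x3 block mask into flat q/k token ranges
# with block_size=4 and seqlen=10, so the last block is truncated by the clamp above.
def _toy_qk_ranges_example():
    toy_mask = torch.tensor(
        [
            [True, False, False],
            [True, True, False],
            [False, True, True],
        ]
    )
    q_ranges, k_ranges = generate_qk_ranges(toy_mask, block_size=4, seqlen=10)
    # One [start, end) pair per active block; the last active block becomes q=[8, 10), k=[8, 10).
    return q_ranges, k_ranges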
@ATTN_WEIGHT_REGISTER("nbhd_attn")
class NbhdAttnWeight(AttnWeightTemplate):
block_size = 128
seqlen = None
attnmap_frame_num = None
q_ranges = None
k_ranges = None
attn_type_map = None
coefficient = [1.0, 0.5, 0.056]
min_width = 1.0
def __init__(self):
self.config = {}
@classmethod
@torch.compiler.disable
def prepare_mask(cls, seqlen):
if seqlen == cls.seqlen:
return
block_num = (seqlen + cls.block_size - 1) // cls.block_size
block_num_per_frame = seqlen / cls.attnmap_frame_num / cls.block_size
mask = generate_nbhd_mask(block_num_per_frame, block_num, cls.attnmap_frame_num, coefficient=cls.coefficient, min_width=cls.min_width, device="cpu")
q_ranges, k_ranges = generate_qk_ranges(mask, cls.block_size, seqlen)
attn_type_map = torch.zeros(len(q_ranges), dtype=torch.int32, device="cuda")
q_ranges = q_ranges.to(torch.int32).to("cuda")
k_ranges = k_ranges.to(torch.int32).to("cuda")
cls.seqlen = seqlen
cls.q_ranges = q_ranges
cls.k_ranges = k_ranges
cls.attn_type_map = attn_type_map
logger.info(f"NbhdAttnWeight Update: seqlen={seqlen}")
sparsity = 1 - mask.sum().item() / mask.numel()
logger.info(f"Attention sparsity: {sparsity}")
def apply(
self,
q,
k,
v,
cu_seqlens_q=None,
cu_seqlens_kv=None,
max_seqlen_q=None,
max_seqlen_kv=None,
model_cls=None,
):
"""
q: [seqlen, head_num, head_dim]
k: [seqlen, head_num, head_dim]
v: [seqlen, head_num, head_dim]
"""
self.prepare_mask(seqlen=q.shape[0])
out = magi_ffa_func(
q,
k,
v,
q_ranges=self.q_ranges,
k_ranges=self.k_ranges,
attn_type_map=self.attn_type_map,
auto_range_merge=True,
)[0]
return out.reshape(out.shape[0], -1)
@ATTN_WEIGHT_REGISTER("nbhd_attn_flashinfer")
class NbhdAttnWeightFlashInfer(AttnWeightTemplate):
block_size = 128
seqlen = None
attnmap_frame_num = None
coefficient = [1.0, 0.5, 0.056]
min_width = 1.0
sparse_wrapper = None
def __init__(self):
self.config = {}
@classmethod
@torch.compiler.disable
def prepare_mask(cls, seqlen, head_num, head_dim):
if seqlen == cls.seqlen:
return
block_num = (seqlen + cls.block_size - 1) // cls.block_size
block_num_per_frame = seqlen / cls.attnmap_frame_num / cls.block_size
mask = generate_nbhd_mask(block_num_per_frame, block_num, cls.attnmap_frame_num, coefficient=cls.coefficient, min_width=cls.min_width, device="cpu")
mask = mask.unsqueeze(0).repeat(head_num, 1, 1)
block_rowcol_size = torch.ones(block_num, dtype=torch.int32) * cls.block_size
block_rowcol_size[-1] = seqlen - cls.block_size * (block_num - 1)
block_rowcol_size = block_rowcol_size.unsqueeze(0).repeat(head_num, 1)
float_workspace_buffer = torch.empty(1024 * 1024 * 1024, dtype=torch.uint8, device="cuda:0")
cls.sparse_wrapper = flashinfer.sparse.VariableBlockSparseAttentionWrapper(float_workspace_buffer, backend="fa2")
cls.sparse_wrapper.plan(
block_mask_map=mask,
block_row_sz=block_rowcol_size,
block_col_sz=block_rowcol_size,
num_qo_heads=head_num,
num_kv_heads=head_num,
head_dim=head_dim,
q_data_type=torch.bfloat16,
)
cls.seqlen = seqlen
logger.info(f"NbhdAttnWeightFlashInfer Update: seqlen={seqlen}")
sparsity = 1 - mask.sum().item() / mask.numel()
logger.info(f"Attention sparsity: {sparsity}")
def apply(
self,
q,
k,
v,
cu_seqlens_q=None,
cu_seqlens_kv=None,
max_seqlen_q=None,
max_seqlen_kv=None,
model_cls=None,
):
"""
q: [seqlen, head_num, head_dim]
k: [seqlen, head_num, head_dim]
v: [seqlen, head_num, head_dim]
"""
self.prepare_mask(seqlen=q.shape[0], head_num=q.shape[1], head_dim=q.shape[2])
q = q.transpose(0, 1)
k = k.transpose(0, 1)
v = v.transpose(0, 1)
out = self.sparse_wrapper.run(q, k, v)
out = out.transpose(0, 1)
return out.reshape(out.shape[0], -1)
import torch
from loguru import logger
try:
from magi_attention.functional import flex_flash_attn_func as magi_ffa_func
except ImportError:
magi_ffa_func = None
from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER
from .template import AttnWeightTemplate
def shrinkMaskStrict(mask, block_size=128):
seqlen = mask.shape[0]
block_num = seqlen // block_size
mask = mask[: block_num * block_size, : block_num * block_size].view(block_num, block_size, block_num, block_size)
col_densities = mask.sum(dim=1) / block_size
# keep a block only if most of its non-empty columns are reasonably dense (>1/3 of rows active)
non_zero_densities = col_densities > 0
high_density_cols = col_densities > 1 / 3
frac_high_density_cols = high_density_cols.sum(dim=-1) / (non_zero_densities.sum(dim=-1) + 1e-9)
block_mask = frac_high_density_cols > 0.6
block_mask[0:0] = True
block_mask[-1:-1] = True
return block_mask
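# Illustrative sketch: shrink a dense 8x8 boolean mask into a 2x2 block mask with
# block_size=4. A block survives only when most of its non-empty columns are at least
# 1/3 dense; a diagonal-only block therefore gets dropped.
def _toy_shrink_mask_example():
    dense = torch.zeros(8, 8, dtype=torch.bool)
    dense[:4, :4] = True  # fully dense top-left block -> kept
    dense[4:, 4:] = torch.eye(4, dtype=torch.bool)  # diagonal-only block -> dropped
    block_mask = shrinkMaskStrict(dense, block_size=4)
    return block_mask  # expected: [[True, False], [False, False]]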
def get_window_width(i, j, token_per_frame, sparse_type, num_frame, decay_factor=1, block_size=128, model_type=None):
assert sparse_type in ["radial"]
dist = abs(i - j)
if model_type == "wan":
if dist < 1:
return token_per_frame
if dist == 1:
return token_per_frame // 2
elif model_type == "hunyuan":
if dist <= 1:
return token_per_frame
else:
raise ValueError(f"Unknown model type: {model_type}")
group = dist.bit_length()
decay_length = 2 ** token_per_frame.bit_length() / 2**group * decay_factor
threshold = block_size
if decay_length >= threshold:
return decay_length
else:
return threshold
def get_diagonal_split_mask(i, j, token_per_frame, sparse_type, device):
assert sparse_type in ["radial"]
dist = abs(i - j)
group = dist.bit_length()
threshold = 128 # hardcoded threshold for now, which is equal to block-size
decay_length = 2 ** token_per_frame.bit_length() / 2**group
if decay_length >= threshold:
return torch.ones((token_per_frame, token_per_frame), device=device, dtype=torch.bool)
split_factor = int(threshold / decay_length)
modular = dist % split_factor
if modular == 0:
return torch.ones((token_per_frame, token_per_frame), device=device, dtype=torch.bool)
else:
return torch.zeros((token_per_frame, token_per_frame), device=device, dtype=torch.bool)
def gen_log_mask_shrinked(device, s, video_token_num, num_frame, block_size=128, sparse_type="log", decay_factor=0.5, model_type=None):
"""
A more memory-friendly version: we generate the attention mask for each frame pair one at a time,
shrink it, and store it into the final result
"""
final_log_mask = torch.zeros(((s + block_size - 1) // block_size, (s + block_size - 1) // block_size), device=device, dtype=torch.bool)
token_per_frame = video_token_num // num_frame
video_text_border = video_token_num // block_size
col_indices = torch.arange(0, token_per_frame, device=device).view(1, -1)
row_indices = torch.arange(0, token_per_frame, device=device).view(-1, 1)
final_log_mask[video_text_border:] = True
final_log_mask[:, video_text_border:] = True
for i in range(num_frame):
for j in range(num_frame):
local_mask = torch.zeros((token_per_frame, token_per_frame), device=device, dtype=torch.bool)
if j == 0 and model_type == "wan": # this is attention sink
local_mask = torch.ones((token_per_frame, token_per_frame), device=device, dtype=torch.bool)
else:
window_width = get_window_width(i, j, token_per_frame, sparse_type, num_frame, decay_factor=decay_factor, block_size=block_size, model_type=model_type)
local_mask = torch.abs(col_indices - row_indices) <= window_width
split_mask = get_diagonal_split_mask(i, j, token_per_frame, sparse_type, device)
local_mask = torch.logical_and(local_mask, split_mask)
remainder_row = (i * token_per_frame) % block_size
remainder_col = (j * token_per_frame) % block_size
# get the padded size
all_length_row = remainder_row + ((token_per_frame - 1) // block_size + 1) * block_size
all_length_col = remainder_col + ((token_per_frame - 1) // block_size + 1) * block_size
padded_local_mask = torch.zeros((all_length_row, all_length_col), device=device, dtype=torch.bool)
padded_local_mask[remainder_row : remainder_row + token_per_frame, remainder_col : remainder_col + token_per_frame] = local_mask
# shrink the mask
block_mask = shrinkMaskStrict(padded_local_mask, block_size=block_size)
# set the block mask to the final log mask
block_row_start = (i * token_per_frame) // block_size
block_col_start = (j * token_per_frame) // block_size
block_row_end = block_row_start + block_mask.shape[0]
block_col_end = block_col_start + block_mask.shape[1]
final_log_mask[block_row_start:block_row_end, block_col_start:block_col_end] = torch.logical_or(final_log_mask[block_row_start:block_row_end, block_col_start:block_col_end], block_mask)
return final_log_mask
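# Illustrative usage sketch (toy sizes only): generate the shrunk radial mask on CPU for
# 4 frames x 256 tokens and report its block-level sparsity.
def _toy_log_mask_example():
    s = 4 * 256
    mask = gen_log_mask_shrinked(
        device="cpu", s=s, video_token_num=s, num_frame=4, block_size=128, sparse_type="radial", decay_factor=0.5, model_type="wan"
    )
    sparsity = 1 - mask.sum().item() / mask.numel()
    return mask, sparsity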
def generate_qk_ranges(mask, block_size, seqlen):
indices = torch.nonzero(mask, as_tuple=False) # shape: [N, 2]
i_indices = indices[:, 0] # [N]
j_indices = indices[:, 1] # [N]
q_start = i_indices * block_size # [N]
q_end = torch.clamp((i_indices + 1) * block_size, max=seqlen) # [N]
k_start = j_indices * block_size # [N]
k_end = torch.clamp((j_indices + 1) * block_size, max=seqlen) # [N]
q_ranges = torch.stack([q_start, q_end], dim=1) # [N, 2]
k_ranges = torch.stack([k_start, k_end], dim=1) # [N, 2]
return q_ranges, k_ranges
@ATTN_WEIGHT_REGISTER("radial_attn")
class RadialAttnWeight(AttnWeightTemplate):
block_size = 128
seqlen = None
attnmap_frame_num = None
q_ranges = None
k_ranges = None
attn_type_map = None
def __init__(self):
self.config = {}
@classmethod
def prepare_mask(cls, seqlen):
if seqlen == cls.seqlen:
return
mask = gen_log_mask_shrinked(
device="cuda", s=seqlen, video_token_num=seqlen, num_frame=cls.attnmap_frame_num, block_size=cls.block_size, sparse_type="radial", decay_factor=0.2, model_type="wan"
)
q_ranges, k_ranges = generate_qk_ranges(mask, cls.block_size, seqlen)
attn_type_map = torch.zeros(len(q_ranges), dtype=torch.int32, device="cuda")
q_ranges = q_ranges.to(torch.int32).to("cuda")
k_ranges = k_ranges.to(torch.int32).to("cuda")
cls.seqlen = seqlen
cls.q_ranges = q_ranges
cls.k_ranges = k_ranges
cls.attn_type_map = attn_type_map
logger.info(f"RadialAttnWeight Update: seqlen={seqlen}")
sparsity = 1 - mask.sum().item() / mask.numel()
logger.info(f"Attention sparsity: {sparsity}")
def apply(
self,
q,
k,
v,
cu_seqlens_q=None,
cu_seqlens_kv=None,
max_seqlen_q=None,
max_seqlen_kv=None,
model_cls=None,
):
"""
q: [seqlen, head_num, head_dim]
k: [seqlen, head_num, head_dim]
v: [seqlen, head_num, head_dim]
"""
self.prepare_mask(seqlen=q.shape[0])
out = magi_ffa_func(
q,
k,
v,
q_ranges=self.q_ranges,
k_ranges=self.k_ranges,
attn_type_map=self.attn_type_map,
auto_range_merge=True,
)[0]
return out.reshape(out.shape[0], -1)
import torch
import torch.distributed as dist
import torch.nn.functional as F
from loguru import logger
from lightx2v.utils.envs import *
from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER
from .template import AttnWeightTemplate
from .utils.ring_comm import RingComm
try:
import flash_attn
from flash_attn.flash_attn_interface import flash_attn_varlen_func
except ImportError:
logger.info("flash_attn_varlen_func not found, please install flash_attn2 first")
flash_attn_varlen_func = None
@torch.jit.script
def _update_out_and_lse(
out,
lse,
block_out,
block_lse,
):
block_out = block_out.to(torch.float32)
block_lse = block_lse.transpose(-2, -1).unsqueeze(dim=-1)
# new_lse = lse + torch.log(1 + torch.exp(block_lse - lse))
# torch.exp(lse - new_lse) * out + torch.exp(block_lse - new_lse) * block_out
# For additional context and discussion, please refer to:
# https://github.com/zhuzilin/ring-flash-attention/pull/34#issuecomment-2076126795
out = out - F.sigmoid(block_lse - lse) * (out - block_out)
lse = lse - F.logsigmoid(lse - block_lse)
return out, lse
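# Numerical sanity sketch (toy tensors only): the sigmoid/logsigmoid update above is a
# stable rewrite of the naive log-sum-exp merge written in the comments.
def _lse_merge_identity_check():
    torch.manual_seed(0)
    out, block_out = torch.randn(4, 1), torch.randn(4, 1)
    lse, block_lse = torch.randn(4, 1), torch.randn(4, 1)
    new_lse = torch.logaddexp(lse, block_lse)
    new_out = torch.exp(lse - new_lse) * out + torch.exp(block_lse - new_lse) * block_out
    fused_out = out - F.sigmoid(block_lse - lse) * (out - block_out)
    fused_lse = lse - F.logsigmoid(lse - block_lse)
    assert torch.allclose(new_out, fused_out, atol=1e-5)
    assert torch.allclose(new_lse, fused_lse, atol=1e-5)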
@ATTN_WEIGHT_REGISTER("ring")
class RingAttnWeight(AttnWeightTemplate):
def __init__(self):
self.config = {}
def apply(self, q, k, v, img_qkv_len, cu_seqlens_qkv, attention_module=None, seq_p_group=None, model_cls=None, use_fp8_comm=False):
"""
Run Ring attention over the image and text queries, keys, and values.

Args:
    q (torch.Tensor): query tensor of shape [shard_seqlen, heads, hidden_dims]
    k (torch.Tensor): key tensor of shape [shard_seqlen, heads, hidden_dims]
    v (torch.Tensor): value tensor of shape [shard_seqlen, heads, hidden_dims]
    img_qkv_len (int): length of the image query/key/value segment
    cu_seqlens_qkv (torch.Tensor): cumulative sequence lengths covering the image and text segments

Returns:
    torch.Tensor: the computed attention output
"""
assert not use_fp8_comm, "RingAttn can't support fp8 comm now."
# Get the rank of the current process and the world size within the sequence-parallel group
cur_rank = dist.get_rank(seq_p_group)
world_size = dist.get_world_size(seq_p_group)
if len(cu_seqlens_qkv) == 3:
txt_qkv_len = cu_seqlens_qkv[1] - img_qkv_len  # length of the text query/key/value
txt_mask_len = cu_seqlens_qkv[2] - img_qkv_len  # length of the text mask
elif len(cu_seqlens_qkv) == 2:
txt_qkv_len = cu_seqlens_qkv[1] - img_qkv_len  # length of the text query/key/value
txt_mask_len = 0
# if RING_COMM is None:
# init_ring_comm()
RING_COMM = RingComm(seq_p_group)
# if len(cu_seqlens_qkv) == 3:
# txt_qkv_len = cu_seqlens_qkv[1] - img_qkv_len  # length of the text query/key/value
# txt_mask_len = cu_seqlens_qkv[2] - img_qkv_len  # length of the text mask
# elif len(cu_seqlens_qkv) == 2:
# txt_qkv_len = cu_seqlens_qkv[1] - img_qkv_len  # length of the text query/key/value
# txt_mask_len = None
q = q.unsqueeze(0)
k = k.unsqueeze(0)
v = v.unsqueeze(0)
img_q, img_k, img_v = q[:, :img_qkv_len, :, :].contiguous(), k[:, :img_qkv_len, :, :].contiguous(), v[:, :img_qkv_len, :, :].contiguous()
txt_q, txt_k, txt_v = (
q[:, img_qkv_len : img_qkv_len + txt_qkv_len, :, :].contiguous(),
k[:, img_qkv_len : img_qkv_len + txt_qkv_len, :, :].contiguous(),
v[:, img_qkv_len : img_qkv_len + txt_qkv_len, :, :].contiguous(),
)
out, lse, next_k, next_v = None, None, None, None
if len(cu_seqlens_qkv) == 3:
q = torch.cat((img_q, txt_q), dim=1)
k = img_k
v = img_v
for step in range(world_size):
if step + 1 != world_size:
next_k = RING_COMM.send_recv(k)
next_v = RING_COMM.send_recv(v)
RING_COMM.commit()
if step + 1 == world_size:
k = torch.cat((k, txt_k), dim=1)
v = torch.cat((v, txt_v), dim=1)
block_out, block_lse = self.ring_attn_sub(q, k, v)
out, lse = self.update_out_and_lse(out, lse, block_out, block_lse)
if step + 1 != world_size:
RING_COMM.wait()
k = next_k
v = next_v
attn1 = out.to(GET_DTYPE()).squeeze(0).reshape(img_qkv_len + txt_qkv_len, -1)
if txt_mask_len > 0:
attn2, *_ = flash_attn.flash_attn_interface._flash_attn_forward(
q[:, -(txt_mask_len - txt_qkv_len) :, :, :].contiguous(),
k[:, -(txt_mask_len - txt_qkv_len) :, :, :].contiguous(),
v[:, -(txt_mask_len - txt_qkv_len) :, :, :].contiguous(),
dropout_p=0.0,
softmax_scale=q.shape[-1] ** (-0.5),
causal=False,
window_size_left=-1,
window_size_right=-1,
softcap=0.0,
alibi_slopes=None,
return_softmax=False,
)
attn2 = attn2.to(GET_DTYPE()).squeeze(0).reshape((txt_mask_len - txt_qkv_len), -1)
attn1 = torch.cat([attn1, attn2], dim=0)
return attn1
def ring_attn_sub(self, q, k, v, dropout_p=0.0, softmax_scale=None, causal=False, window_size=(-1, -1), softcap=0.0, alibi_slopes=None, return_softmax=False):
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
block_out, block_lse, _, _ = flash_attn.flash_attn_interface._flash_attn_forward(
q,
k,
v,
dropout_p=dropout_p,
softmax_scale=softmax_scale,
causal=causal,
window_size_left=window_size[0],
window_size_right=window_size[1],
softcap=softcap,
alibi_slopes=alibi_slopes,
return_softmax=return_softmax,
)
return block_out, block_lse
def update_out_and_lse(
self,
out,
lse,
block_out,
block_lse,
slice_=None,
):
if out is None:
if slice_ is not None:
raise RuntimeError("first update_out_and_lse should not pass slice_ args")
out = block_out.to(torch.float32)
lse = block_lse.transpose(-2, -1).unsqueeze(dim=-1)
elif slice_ is not None:
slice_out, slice_lse = out[slice_], lse[slice_]
slice_out, slice_lse = _update_out_and_lse(slice_out, slice_lse, block_out, block_lse)
out[slice_], lse[slice_] = slice_out, slice_lse
else:
out, lse = _update_out_and_lse(out, lse, block_out, block_lse)
return out, lse
import torch
from loguru import logger
from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER
from .template import AttnWeightTemplate
if torch.cuda.is_available() and torch.cuda.get_device_capability(0) in [(8, 9), (12, 0)]:
try:
from sageattention import sageattn_qk_int8_pv_fp16_triton as sageattn
except ImportError:
logger.info("sageattn not found, please install sageattention first")
sageattn = None
else:
try:
from sageattention import sageattn
except ImportError:
logger.info("sageattn not found, please install sageattention first")
sageattn = None
try:
from sageattn3 import sageattn3_blackwell
except ImportError:
logger.info("sageattn3 not found, please install sageattention first")
sageattn3_blackwell = None
@ATTN_WEIGHT_REGISTER("sage_attn2")
class SageAttn2Weight(AttnWeightTemplate):
def __init__(self):
self.config = {}
def apply(
self,
q,
k,
v,
cu_seqlens_q=None,
cu_seqlens_kv=None,
max_seqlen_q=None,
max_seqlen_kv=None,
model_cls=None,
):
q, k, v = q.contiguous(), k.contiguous(), v.contiguous()
if len(q.shape) == 3:
bs = 1
q, k, v = q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0)
elif len(q.shape) == 4:
bs = q.shape[0]
x = sageattn(
q,
k,
v,
tensor_layout="NHD",
).view(bs * max_seqlen_q, -1)
return x
@ATTN_WEIGHT_REGISTER("sage_attn3")
class SageAttn3Weight(AttnWeightTemplate):
def __init__(self):
self.config = {}
def apply(
self,
q,
k,
v,
cu_seqlens_q=None,
cu_seqlens_kv=None,
max_seqlen_q=None,
max_seqlen_kv=None,
model_cls=None,
):
q, k, v = q.contiguous(), k.contiguous(), v.contiguous()
if len(q.shape) == 3:
bs = 1
q, k, v = q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0)
elif len(q.shape) == 4:
bs = q.shape[0]
x = sageattn3_blackwell(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)).transpose(1, 2).reshape(bs * max_seqlen_q, -1)
return x
import os
import torch
try:
import spas_sage_attn
except ImportError:
spas_sage_attn = None
from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER
from .template import AttnWeightTemplate
@ATTN_WEIGHT_REGISTER("spas_sage_attn")
class SageAttnWeight(AttnWeightTemplate):
def __init__(self):
self.config = {}
@classmethod
def apply(self, q, k, v, cu_seqlens_q=None, cu_seqlens_kv=None, max_seqlen_q=None, max_seqlen_kv=None, model_cls=None, tensor_layout="HND"):
q = q.unsqueeze(0)
k = k.unsqueeze(0)
v = v.unsqueeze(0)
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
attn_out = spas_sage_attn.core.spas_sage2_attn_meansim_cuda(q, k, v, tensor_layout)
_, H, N, D = attn_out.shape
attn_out = attn_out.permute(2, 1, 3, 0).contiguous().view(N, H * D)
return attn_out
if __name__ == "__main__":
import matplotlib.pyplot as plt
# 1. Build random inputs
q = torch.randn(32760, 12, 128, dtype=torch.bfloat16).cuda()
k = torch.randn(32760, 12, 128, dtype=torch.bfloat16).cuda()
v = torch.randn(32760, 12, 128, dtype=torch.bfloat16).cuda()
# 2. Compute attention directly with PyTorch
q_ = q.float()
k_ = k.float()
v_ = v.float()
attn_weights = torch.matmul(q_, k_.transpose(-2, -1)) / (128**0.5)
attn_weights = torch.softmax(attn_weights, dim=-1)
output_pt = torch.matmul(attn_weights, v_)
# 3. Compute attention with spas_sage2_attn_meansim_cuda
q = q.unsqueeze(0) # shape: (1, 32760, 12, 128)
k = k.unsqueeze(0)
v = v.unsqueeze(0)
q = q.transpose(1, 2) # shape: (1, 12, 32760, 128)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
output_cuda = spas_sage_attn.core.spas_sage2_attn_meansim_cuda(q, k, v, tensor_layout="HND")
output_cuda = output_cuda.float()
# 4. Take the top-left [3000, 3000] crop of the first head only
output_pt_crop = output_pt[0, :3000, :3000].cpu().detach().numpy()
output_cuda_crop = output_cuda[0, 0, :3000, :3000].cpu().detach().numpy()
# 5. Save the images
save_dir = os.path.expanduser("~/Log/10-22/")
os.makedirs(save_dir, exist_ok=True)
plt.imshow(output_pt_crop, aspect="auto")
plt.title("PyTorch Attention (left-top 3000x3000)")
plt.savefig(os.path.join(save_dir, "attn.png"))
plt.close()
plt.imshow(output_cuda_crop, aspect="auto")
plt.title("spas_sage2_attn_meansim_cuda (left-top 3000x3000)")
plt.savefig(os.path.join(save_dir, "spas_attn.png"))
plt.close()
from typing import Optional
# Please reinstall flashinfer by referring to https://github.com/svg-project/Sparse-VideoGen
try:
import flashinfer
except ImportError:
flashinfer = None
import torch
import triton
import triton.language as tl
from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER
from .svg2_attn_utils import (
batch_kmeans_Euclid,
identify_dynamic_map,
)
from .template import AttnWeightTemplate
@triton.jit
def _permute_kernel(
X_ptr,
IDX_ptr,
Y_ptr,
S: tl.constexpr,
D: tl.constexpr,
BLOCK_S: tl.constexpr,
):
"""Each program permutes BLOCK_S tokens *all* hidden features (D). No inner python loop."""
pid_bh = tl.program_id(0)
tile_s = tl.program_id(1)
# Offsets along sequence
s_offsets = tile_s * BLOCK_S + tl.arange(0, BLOCK_S)
token_mask = s_offsets < S
# Gather source indices for these tokens
idx_ptrs = IDX_ptr + pid_bh * S + s_offsets
src_row_idx = tl.load(idx_ptrs, mask=token_mask, other=0).to(tl.int32)
# Broadcast to create 2-D pointer matrix (BLOCK_S, D)
d_offsets = tl.arange(0, D)
src_ptrs = X_ptr + (pid_bh * S + src_row_idx[:, None]) * D + d_offsets[None, :]
dst_ptrs = Y_ptr + (pid_bh * S + s_offsets[:, None]) * D + d_offsets[None, :]
full_mask = token_mask[:, None]
values = tl.load(src_ptrs, mask=full_mask, other=0.0)
tl.store(dst_ptrs, values, mask=full_mask)
def permute_tensor_by_labels_triton(
tensor: torch.Tensor,
labels: Optional[torch.Tensor],
dim: int,
*,
sorted_indices: Optional[torch.Tensor] = None,
):
"""
Permute `tensor` along `dim` according to ascending order of `labels`.
This is a Triton-accelerated replacement for the original implementation.
It currently supports 4-D tensors of shape [B, H, S, D] and `dim == 2`.
If these conditions are not met or the tensors reside on CPU, we fall back
to the reference PyTorch implementation.
"""
# Assertions – we only support the optimized CUDA path.
assert dim == 2, "permute_tensor_by_labels currently only supports dim==2 (sequence dimension)"
assert tensor.dim() == 4, "Expected tensor shape [B,H,S,D]"
assert tensor.is_cuda, "permute_tensor_by_labels requires CUDA tensors"
B, H, S, D = tensor.shape
BH = B * H
# Determine sorted indices
if sorted_indices is not None:
sorted_indices = sorted_indices.to(torch.int32).contiguous()
else:
assert labels is not None, "Either `labels` or `sorted_indices` must be provided."
labels = labels.to(tensor.device)
sorted_indices = torch.argsort(labels, dim=-1).to(torch.int32).contiguous()
# Flatten tensor and allocate output
inp_flat = tensor.reshape(BH, S, D).contiguous()
out_flat = torch.empty_like(inp_flat)
# Triton kernel tile size
BLOCK_S = 64 # number of tokens per program, tunable
n_s_tiles = triton.cdiv(S, BLOCK_S)
grid = (BH, n_s_tiles)
_permute_kernel[grid](inp_flat, sorted_indices, out_flat, S, D, BLOCK_S, num_warps=4)
permuted_tensor = out_flat.reshape(B, H, S, D)
return permuted_tensor, sorted_indices
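# Illustrative check (requires a CUDA device; toy shapes only): the Triton permutation
# should match a plain torch.gather along the sequence dimension.
def _triton_permute_check():
    if not torch.cuda.is_available():
        return None
    x = torch.randn(1, 2, 64, 32, device="cuda", dtype=torch.float16)
    labels = torch.randint(0, 4, (1, 2, 64), device="cuda")
    permuted, sorted_idx = permute_tensor_by_labels_triton(x, labels, dim=2)
    ref = torch.gather(x, 2, sorted_idx.long().unsqueeze(-1).expand_as(x))
    assert torch.equal(permuted, ref)
    return permuted, sorted_idx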
@triton.jit
def _inverse_permute_kernel(
X_ptr,
IDX_ptr,
Y_ptr,
S: tl.constexpr,
D: tl.constexpr,
BLOCK_S: tl.constexpr,
):
"""Inverse permutation: scatter BLOCK_S tokens back in one shot."""
pid_bh = tl.program_id(0)
tile_s = tl.program_id(1)
s_offsets = tile_s * BLOCK_S + tl.arange(0, BLOCK_S)
token_mask = s_offsets < S
idx_ptrs = IDX_ptr + pid_bh * S + s_offsets
src_pos_idx = s_offsets.to(tl.int32)
dst_pos_idx = tl.load(idx_ptrs, mask=token_mask, other=0).to(tl.int32)
d_offsets = tl.arange(0, D)
src_ptrs = X_ptr + (pid_bh * S + src_pos_idx[:, None]) * D + d_offsets[None, :]
dst_ptrs = Y_ptr + (pid_bh * S + dst_pos_idx[:, None]) * D + d_offsets[None, :]
full_mask = token_mask[:, None]
values = tl.load(src_ptrs, mask=full_mask, other=0.0)
tl.store(dst_ptrs, values, mask=full_mask)
def apply_inverse_permutation_triton(
permuted_tensor: torch.Tensor,
sorted_indices: torch.Tensor,
dim: int,
):
"""
Triton implementation of inverse permutation. Inverse the permutation applied by `permute_tensor_by_labels`.
Args:
permuted_tensor: (B, H, S, D).
sorted_indices: (B, H, S).
dim: Dimension along which to apply the inverse permutation. Typically 2, i.e. the sequence length dimension.
Returns:
Tensor of shape (B, H, S, D).
"""
assert dim == 2, "apply_inverse_permutation currently only supports dim==2"
assert permuted_tensor.dim() == 4, "Expected tensor shape [B,H,S,D]"
assert permuted_tensor.is_cuda, "apply_inverse_permutation requires CUDA tensors"
B, H, S, D = permuted_tensor.shape
BH = B * H
# Ensure index dtype
sorted_indices = sorted_indices.to(torch.int32).contiguous()
# Flatten inputs
inp_flat = permuted_tensor.reshape(BH, S, D).contiguous()
out_flat = torch.empty_like(inp_flat)
BLOCK_S = 64
n_s_tiles = triton.cdiv(S, BLOCK_S)
grid = (BH, n_s_tiles)
_inverse_permute_kernel[grid](inp_flat, sorted_indices, out_flat, S, D, BLOCK_S, num_warps=4)
original_tensor = out_flat.reshape(B, H, S, D)
return original_tensor
@ATTN_WEIGHT_REGISTER("svg2_attn")
class Svg2AttnWeight(AttnWeightTemplate):
centroids_init = False
num_q_centroids = 300
num_k_centroids = 1000
kmeans_iter_init = 50
top_p_kmeans = 0.9
min_kc_ratio = 0.10
kmeans_iter_step = 2
def __init__(self):
self.config = {}
def apply(
self,
q,
k,
v,
cu_seqlens_q=None,
cu_seqlens_kv=None,
max_seqlen_q=None,
max_seqlen_kv=None,
model_cls=None,
):
q = q.unsqueeze(0).transpose(1, 2)
k = k.unsqueeze(0).transpose(1, 2)
v = v.unsqueeze(0).transpose(1, 2)
bs, num_heads, seq_len, dim = q.size()
q_perm, k_perm, v_perm, dyn_map, qc_sz_s, kc_sz_s, q_sorted_indices = self.semantic_aware_permutation(q, k, v)
output_permuted = self.dynamic_block_sparse_fwd_flashinfer(q_perm, k_perm, v_perm, dyn_map, qc_sz_s, kc_sz_s, is_cpu=False)
attn_output = apply_inverse_permutation_triton(output_permuted, q_sorted_indices, dim=2)
return attn_output.reshape(bs, num_heads, seq_len, dim).transpose(1, 2).reshape(bs * seq_len, -1)
def dynamic_block_sparse_fwd_flashinfer(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
block_mask_map: torch.Tensor,
block_row_sz: torch.Tensor,
block_col_sz: torch.Tensor,
is_cpu: bool = True,
):
"""
Launcher for the Flashinfer dynamic block sparse attention kernel.
Args:
q (torch.Tensor): Query tensor, shape [B, H, S, D].
k (torch.Tensor): Key tensor, shape [B, H, S, D].
v (torch.Tensor): Value tensor, shape [B, H, S, D].
block_mask_map (torch.Tensor): Boolean mask, shape [B, H, qc_num, kc_num]. Currently must be on CPU.
block_row_sz (torch.Tensor): Query block sizes, shape [B, H, qc_num]. Currently must be on CPU.
block_col_sz (torch.Tensor): Key block sizes, shape [B, H, kc_num]. Currently must be on CPU.
is_cpu (bool): Whether the block metadata is on CPU. FlashInfer's default is to plan on CPU; we switch to GPU for faster planning. Default is True.
"""
# Input shape check
B, H, S, D = q.shape
qc_num = block_row_sz.shape[-1]
kc_num = block_col_sz.shape[-1]
assert block_mask_map.shape == (B, H, qc_num, kc_num)
assert all(t.device == torch.device("cpu") for t in [block_mask_map, block_row_sz, block_col_sz]) if is_cpu else True
# Check if block_col_sz and block_row_sz are the same for each head
assert torch.all(block_col_sz.sum(dim=2) == block_col_sz.sum(dim=2)[0, 0])
assert torch.all(block_row_sz.sum(dim=2) == block_row_sz.sum(dim=2)[0, 0])
# Prepare flashinfer wrapper
float_workspace_buffer = torch.empty(128 * 1024 * 1024, device=q.device)
vector_sparse_indices_buffer = torch.empty(1024 * 1024 * 1024, device=q.device)
wrapper = flashinfer.sparse.VariableBlockSparseAttentionWrapper(float_workspace_buffer, backend="auto")
wrapper.reset_workspace_buffer(
float_workspace_buffer=wrapper._float_workspace_buffer,
int_workspace_buffer=wrapper._int_workspace_buffer,
vector_sparse_indices_buffer=vector_sparse_indices_buffer, # Only reset this buffer size
vector_sparse_indptr_buffer=wrapper._vector_sparse_indptr_buffer,
)
block_mask_map = block_mask_map.reshape(B * H, qc_num, kc_num)
block_row_sz = block_row_sz.reshape(B * H, qc_num)
block_col_sz = block_col_sz.reshape(B * H, kc_num)
wrapper.plan(
block_mask_map=block_mask_map,
block_row_sz=block_row_sz,
block_col_sz=block_col_sz,
num_qo_heads=B * H,
num_kv_heads=B * H,
head_dim=D,
q_data_type=q.dtype,
kv_data_type=k.dtype,
)
# print_memory_usage("After plan")
q = q.reshape(B * H, S, D)
k = k.reshape(B * H, S, D)
v = v.reshape(B * H, S, D)
o = wrapper.run(q, k, v) # [num_qo_heads, qo_len, head_dim]
o = o.reshape(B, H, S, D)
return o
def semantic_aware_permutation(self, query, key, value):
cfg, num_heads, seq_len, dim = query.size()
# 1. Kmeans clustering
qlabels, qcentroids, qcluster_sizes, qiter, klabels, kcentroids, kcluster_sizes, kiter = self.kmeans_clustering(query, key)
# 2. Identify dynamic map
q_cluster_sizes = qcluster_sizes.view(cfg, num_heads, self.num_q_centroids)
k_cluster_sizes = kcluster_sizes.view(cfg, num_heads, self.num_k_centroids)
dynamic_map = identify_dynamic_map(
qcentroids.view(cfg, num_heads, self.num_q_centroids, dim),
kcentroids.view(cfg, num_heads, self.num_k_centroids, dim),
q_cluster_sizes,
k_cluster_sizes,
self.top_p_kmeans,
self.min_kc_ratio,
)
# 3. Permute the query, key, value
q_permuted, q_sorted_indices = permute_tensor_by_labels_triton(query, qlabels, dim=2)
k_permuted, k_sorted_indices = permute_tensor_by_labels_triton(key, klabels, dim=2)
v_permuted, v_sorted_indices = permute_tensor_by_labels_triton(value, klabels, dim=2, sorted_indices=k_sorted_indices)
return q_permuted, k_permuted, v_permuted, dynamic_map, q_cluster_sizes, k_cluster_sizes, q_sorted_indices
def kmeans_clustering(self, query, key):
if not self.centroids_init:
qlabels, qcentroids, qcluster_sizes, qiter, klabels, kcentroids, kcluster_sizes, kiter = self.kmeans_init(query, key)
self.centroids_init = True
else:
qlabels, qcentroids, qcluster_sizes, qiter, klabels, kcentroids, kcluster_sizes, kiter = self.kmeans_step(query, key)
return qlabels, qcentroids, qcluster_sizes, qiter, klabels, kcentroids, kcluster_sizes, kiter
def kmeans_init(self, query, key):
cfg, num_heads, seq_len, dim = query.size()
qlabels, qcentroids, qcluster_sizes, qiter = batch_kmeans_Euclid(query.view(cfg * num_heads, seq_len, dim), n_clusters=self.num_q_centroids, max_iters=self.kmeans_iter_init)
klabels, kcentroids, kcluster_sizes, kiter = batch_kmeans_Euclid(key.view(cfg * num_heads, seq_len, dim), n_clusters=self.num_k_centroids, max_iters=self.kmeans_iter_init)
self.q_centroids = qcentroids
self.k_centroids = kcentroids
return qlabels, qcentroids, qcluster_sizes, qiter, klabels, kcentroids, kcluster_sizes, kiter
def kmeans_step(self, query, key):
cfg, num_heads, seq_len, dim = query.size()
qlabels, qcentroids, qcluster_sizes, qiter = batch_kmeans_Euclid(
query.view(cfg * num_heads, seq_len, dim),
n_clusters=self.num_q_centroids,
max_iters=self.kmeans_iter_step,
init_centroids=self.q_centroids,
)
klabels, kcentroids, kcluster_sizes, kiter = batch_kmeans_Euclid(
key.view(cfg * num_heads, seq_len, dim),
n_clusters=self.num_k_centroids,
max_iters=self.kmeans_iter_step,
init_centroids=self.k_centroids,
)
self.q_centroids = qcentroids
self.k_centroids = kcentroids
return qlabels, qcentroids, qcluster_sizes, qiter, klabels, kcentroids, kcluster_sizes, kiter
if __name__ == "__main__":
q, k, v = torch.randn(32130, 40, 128, dtype=torch.bfloat16).cuda(), torch.randn(32130, 40, 128, dtype=torch.bfloat16).cuda(), torch.randn(32130, 40, 128, dtype=torch.bfloat16).cuda()
svg2_attn = Svg2AttnWeight()
print("Svg2AttnWeight initialized.")
out = svg2_attn.apply(q, k, v)
print(f"out: {out.shape}, {out.dtype}, {out.device}")
import torch
import torch.nn.functional as F
import triton
import triton.language as tl
try:
from cuvs.cluster.kmeans import KMeansParams, fit
except ImportError:
KMeansParams = None
fit = None
# --- New functions ---
def density_calculation(dynamic_map, q_cluster_sizes, k_cluster_sizes):
"""
Calculate the density of the dynamic map. Currently only batch size = 1 and head size = 1 are supported.
Input:
dynamic_map: [cfg, num_heads, qc_num, kc_num]
q_cluster_sizes: [cfg, num_heads, qc_num]
k_cluster_sizes: [cfg, num_heads, kc_num]
"""
cfg, num_heads, qc_num, kc_num = dynamic_map.shape
# Calculate the block size of each block
clustered_block_size = q_cluster_sizes[:, :, :, None] * k_cluster_sizes[:, :, None, :]
masked_block_size = clustered_block_size * dynamic_map
# Calculate the density of each block
density = torch.sum(masked_block_size, dim=(2, 3)) / torch.sum(clustered_block_size, dim=(2, 3))
return density
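# Illustrative sketch: density of a toy 2x2 cluster map where one (query-cluster,
# key-cluster) pair is masked out. Expected density: (10*20 + 30*20 + 30*20) / 1600 = 0.875.
def _toy_density_example():
    dynamic_map = torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]])  # [1, 1, 2, 2]
    q_cluster_sizes = torch.tensor([[[10.0, 30.0]]])  # [1, 1, 2]
    k_cluster_sizes = torch.tensor([[[20.0, 20.0]]])  # [1, 1, 2]
    return density_calculation(dynamic_map, q_cluster_sizes, k_cluster_sizes)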
# --- Functions from analyze/kmeans_rapidai.py ---
def pairwise_distance(x, y):
"""
Computes pairwise squared Euclidean distance between two sets of points.
"""
x_norm = (x**2).sum(1).view(-1, 1)
y_norm = (y**2).sum(1).view(1, -1)
dist = torch.clamp(x_norm + y_norm - 2.0 * torch.mm(x, torch.transpose(y, 0, 1)), min=0.0)
return dist
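# Illustrative check: pairwise_distance returns *squared* Euclidean distances, so it
# should agree with torch.cdist(x, y) ** 2 up to float rounding.
def _pairwise_distance_check():
    x, y = torch.randn(5, 3), torch.randn(7, 3)
    assert torch.allclose(pairwise_distance(x, y), torch.cdist(x, y) ** 2, atol=1e-4)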
def kmeans_predict(centroids, input_tensor): # Removed unused params argument
"""
Predict the labels for the input tensor using the centroids.
"""
input_tensor = input_tensor.to(torch.float32)
dist = pairwise_distance(input_tensor, centroids)
labels = torch.argmin(dist, dim=1)
return labels
def kmeans_rapidai(tensor, k, max_iter=5, tol=1e-4, init_method="Array", centroids_init=None): # Renamed centroids to centroids_init
"""
Performs K-means clustering using cuVS.
"""
assert tensor.dtype == torch.float32, "Tensor must be float32 for cuVS KMeans"
assert tensor.ndim == 2, f"Tensor must be 2D, but got {tensor.shape}"
# assert init_method == "Array", "init_method must be 'Array' for now"
L, D = tensor.shape
# cuVS KMeans in RAPIDS >=23.10 uses 'centroids_init' for initial centroids
current_centroids = centroids_init
if current_centroids is None:
# Default init: cuVS handles KMeansPlusPlus if centroids_init is None and init_method is KMeansPlusPlus
# If you need to pass an empty tensor for cuVS to initialize:
current_centroids = torch.empty(k, D, device=tensor.device, dtype=torch.float32) # Or pass None
else:
assert current_centroids.dtype == torch.float32, "Initial centroids must be float32"
assert current_centroids.shape == (
k,
D,
), f"Initial centroids shape mismatch, got {current_centroids.shape}, expected ({k}, {D})"
# cuVS uses 'init_method="Array"' when 'centroids_init' is provided.
# import IPython; IPython.embed()
params = KMeansParams(n_clusters=k, max_iter=max_iter, tol=tol, init_method=init_method)
# Call fit with centroids_init (can be None)
new_centroids, inertia, n_iter_ = fit(params, tensor, current_centroids)
labels = kmeans_predict(new_centroids, tensor)
return labels, new_centroids, n_iter_
@triton.jit
def _centroid_update_kernel(
x_ptr, # *f16 [B, N, D]
cluster_ptr, # *i32 [B, N]
sum_ptr, # *f32 [B, K, D]
count_ptr, # *i32 [B, K]
B: tl.constexpr,
N: tl.constexpr,
D: tl.constexpr,
K: tl.constexpr,
BLOCK_D: tl.constexpr, # number of dims processed per program
):
"""Each program processes 1 point (token) across BLOCK_D dimensions with atomics."""
pid = tl.program_id(axis=0)
token_idx = pid # range: [0, B * N)
# Derive (b, n) indices
b = token_idx // N
n = token_idx % N
# Pointers to the token features and its cluster id
x_offset = (b * N + n) * D
x_ptr = x_ptr + x_offset
cluster_idx = tl.load(cluster_ptr + b * N + n) # int32
# Guard for invalid cluster ids (should not happen)
cluster_idx = tl.where(cluster_idx < K, cluster_idx, 0)
# Base pointer for this centroid in the output sum tensor
centroid_base = (b * K + cluster_idx) * D
# Process feature vector in chunks of BLOCK_D
offs = tl.arange(0, BLOCK_D)
for d_start in range(0, D, BLOCK_D):
mask = offs + d_start < D
feats = tl.load(x_ptr + d_start + offs, mask=mask, other=0.0)
feats = feats.to(tl.float32)
dest_ptr = sum_ptr + centroid_base + d_start + offs
tl.atomic_add(dest_ptr, feats, mask=mask)
# Update counts (only once per point)
tl.atomic_add(count_ptr + b * K + cluster_idx, 1)
def triton_centroid_update_cosine(x_norm: torch.Tensor, cluster_ids: torch.Tensor, old_centroids: torch.Tensor):
"""Compute centroids using custom Triton kernel.
Args:
x_norm (Tensor): (B, N, D) normalized input vectors (float16/float32)
cluster_ids (LongTensor): (B, N) cluster assignment per point
old_centroids (Tensor): (B, K, D) previous centroids (same dtype as x_norm)
Returns:
Tensor: (B, K, D) updated and L2-normalized centroids (dtype == x_norm.dtype)
"""
assert x_norm.is_cuda and cluster_ids.is_cuda, "Input tensors must be on CUDA device"
B, N, D = x_norm.shape
K = old_centroids.shape[1]
assert cluster_ids.shape == (B, N)
# Allocate accumulation buffers
centroid_sums = torch.zeros((B, K, D), device=x_norm.device, dtype=torch.float32)
centroid_counts = torch.zeros((B, K), device=x_norm.device, dtype=torch.int32)
# Launch Triton kernel – one program per token
total_tokens = B * N
BLOCK_D = 128 # tuneable
grid = (total_tokens,)
_centroid_update_kernel[grid](
x_norm,
cluster_ids.to(torch.int32),
centroid_sums,
centroid_counts,
B,
N,
D,
K,
BLOCK_D=BLOCK_D,
)
# Compute means; keep old centroid if empty cluster
counts_f = centroid_counts.to(torch.float32).unsqueeze(-1).clamp(min=1.0)
centroids = centroid_sums / counts_f
# For clusters with zero count, revert to old centroids
zero_mask = (centroid_counts == 0).unsqueeze(-1)
centroids = torch.where(zero_mask, old_centroids.to(torch.float32), centroids)
centroids = centroids.to(x_norm.dtype)
centroids = F.normalize(centroids, p=2, dim=-1)
return centroids
def torch_loop_centroid_update_cosine(x_norm: torch.Tensor, cluster_ids: torch.Tensor, old_centroids: torch.Tensor):
"""Reference Python implementation (double for-loop)"""
B, N, D = x_norm.shape
K = old_centroids.shape[1]
new_centroids = torch.zeros_like(old_centroids)
for b in range(B):
for k in range(K):
mask = cluster_ids[b] == k
if mask.any():
new_centroids[b, k] = F.normalize(x_norm[b][mask].mean(dim=0, dtype=x_norm.dtype), p=2, dim=0)
else:
new_centroids[b, k] = old_centroids[b, k]
return new_centroids
def triton_centroid_update_euclid(x: torch.Tensor, cluster_ids: torch.Tensor, old_centroids: torch.Tensor):
"""Compute centroids for Euclidean KMeans using Triton.
Args:
x (Tensor): (B, N, D) input vectors (float16/float32)
cluster_ids (LongTensor): (B, N) cluster assignment per point
old_centroids (Tensor): (B, K, D) previous centroids (same dtype as x)
Returns:
Tensor: (B, K, D) updated centroids (dtype == x.dtype)
"""
assert x.is_cuda and cluster_ids.is_cuda, "Input tensors must be on CUDA device"
B, N, D = x.shape
K = old_centroids.shape[1]
assert cluster_ids.shape == (B, N)
# Allocate accumulation buffers
centroid_sums = torch.zeros((B, K, D), device=x.device, dtype=torch.float32)
centroid_counts = torch.zeros((B, K), device=x.device, dtype=torch.int32)
total_tokens = B * N
BLOCK_D = 128 # tuneable
grid = (total_tokens,)
_centroid_update_kernel[grid](
x,
cluster_ids.to(torch.int32),
centroid_sums,
centroid_counts,
B,
N,
D,
K,
BLOCK_D=BLOCK_D,
)
# Compute means; keep old centroid if empty cluster
counts_f = centroid_counts.to(torch.float32).unsqueeze(-1).clamp(min=1.0)
centroids = centroid_sums / counts_f
# For clusters with zero count, revert to old centroids
zero_mask = (centroid_counts == 0).unsqueeze(-1)
centroids = torch.where(zero_mask, old_centroids.to(torch.float32), centroids)
return centroids.to(x.dtype)
# ------------------------------ NEW: chunk-wise centroid update (sorted ids) ------------------------------
@triton.jit
def _centroid_update_chunk_kernel(
x_ptr, # *f16 / *f32 [B, N, D] – ORIGINAL ORDER
sorted_idx_ptr, # *i32 [B, N] – indices after sort
sorted_cluster_ptr, # *i32 [B, N] – cluster ids in sorted order
sum_ptr, # *f32 [B, K, D]
count_ptr, # *i32 [B, K]
B: tl.constexpr,
N: tl.constexpr,
D: tl.constexpr,
K: tl.constexpr,
BLOCK_N: tl.constexpr, # how many tokens (points) each program processes
):
"""Each program processes **BLOCK_N consecutive, already-sorted tokens**.
Because the tokens are sorted by cluster id, identical ids appear in
contiguous runs. We therefore accumulate a local sum/count for the
current run and perform **a single atomic update per run**, instead of
per-token.
"""
# program indices – 2-D launch grid: (chunk_id, batch_id)
pid_chunk = tl.program_id(axis=0)
pid_b = tl.program_id(axis=1)
b = pid_b
chunk_start = pid_chunk * BLOCK_N # position of the first token handled by this program
# Nothing to do – out of range
if chunk_start >= N:
return
# base pointers for this batch
idx_batch_base = sorted_idx_ptr + b * N
cid_batch_base = sorted_cluster_ptr + b * N
x_batch_base = x_ptr + b * N * D # for pointer arithmetic
# helper aranges
offs_token = tl.arange(0, BLOCK_N)
offs_dim = tl.arange(0, D)
# first token index & validity mask
token_idx = chunk_start + offs_token
valid_tok = token_idx < N
first_token_idx = chunk_start
last_token_idx = tl.minimum(chunk_start + BLOCK_N, N) - 1
# Load first cluster id to initialise the running accumulator
first_id = tl.load(cid_batch_base + first_token_idx)
last_id = tl.load(cid_batch_base + last_token_idx)
all_ids = tl.load(cid_batch_base + token_idx, mask=valid_tok, other=-1)
all_tokens_idxs = tl.load(idx_batch_base + token_idx, mask=valid_tok, other=-1) # [BLOCK_N]
load_mask = all_tokens_idxs[:, None] * D + offs_dim[None, :]  # per-token flat element offsets into x (used as load addresses, not a boolean mask)
for cid in range(first_id, last_id + 1):
cluster_mask = all_ids == cid
cluster_size = tl.sum(cluster_mask.to(tl.int32))
if cluster_size != 0:
cluster_feats = tl.load(x_batch_base + load_mask, mask=cluster_mask[:, None], other=0.0) # [BLOCK_N, D]
cluster_feats = cluster_feats.to(tl.float32)
sum_feats = tl.sum(cluster_feats, axis=0)
dest_ptr = sum_ptr + (b * K + cid) * D + offs_dim
tl.atomic_add(dest_ptr, sum_feats)
tl.atomic_add(count_ptr + b * K + cid, cluster_size)
# ---------------------------------------------------------------------------------------------
def triton_centroid_update_sorted_cosine(x_norm: torch.Tensor, cluster_ids: torch.Tensor, old_centroids: torch.Tensor, *, BLOCK_N: int = 256):
"""Fast centroid update assuming **cluster_ids are sorted along N**.
This helper will sort the assignments (together with `x_norm`) and launch the
chunk kernel above. Compared to the naive per-token kernel it performs *one
atomic add per run of identical ids* instead of per token, providing large
speed-ups when clusters are reasonably sized.
"""
assert x_norm.is_cuda and cluster_ids.is_cuda, "Inputs must be on CUDA"
B, N, D = x_norm.shape
K = old_centroids.shape[1]
assert cluster_ids.shape == (B, N)
# -------- sort per-batch --------
sorted_cluster_ids, sorted_idx = torch.sort(cluster_ids, dim=-1)
sorted_idx_int = sorted_idx.to(torch.int32)
# accumulation buffers
centroid_sums = torch.zeros((B, K, D), device=x_norm.device, dtype=torch.float32)
centroid_cnts = torch.zeros((B, K), device=x_norm.device, dtype=torch.int32)
grid = (triton.cdiv(N, BLOCK_N), B)
_centroid_update_chunk_kernel[grid](
x_norm,
sorted_idx_int,
sorted_cluster_ids.to(torch.int32),
centroid_sums,
centroid_cnts,
B,
N,
D,
K,
BLOCK_N=BLOCK_N,
)
# finalise – convert to means, handle empty clusters, renormalise
counts_f = centroid_cnts.to(torch.float32).unsqueeze(-1).clamp(min=1.0)
centroids = centroid_sums / counts_f
empty_mask = (centroid_cnts == 0).unsqueeze(-1)
centroids = torch.where(empty_mask, old_centroids.to(torch.float32), centroids)
centroids = centroids.to(x_norm.dtype)
centroids = F.normalize(centroids, p=2, dim=-1)
return centroids
def triton_centroid_update_sorted_euclid(x: torch.Tensor, cluster_ids: torch.Tensor, old_centroids: torch.Tensor, *, BLOCK_N: int = 256):
"""Fast centroid update for *Euclidean* KMeans assuming cluster IDs are pre-sorted.
Parameters
----------
x : Tensor [B, N, D]
Input feature vectors (no normalization assumed).
cluster_ids : LongTensor [B, N]
Cluster assignment for each point.
old_centroids : Tensor [B, K, D]
Previous centroids (used to fill empty clusters).
BLOCK_N : int, optional
Tokens per Triton program (affects occupancy/perf).
"""
assert x.is_cuda and cluster_ids.is_cuda, "Inputs must be on CUDA device"
B, N, D = x.shape
K = old_centroids.shape[1]
# Batch-wise sort of cluster assignments
sorted_cluster_ids, sorted_idx = torch.sort(cluster_ids, dim=-1)
sorted_idx_int = sorted_idx.to(torch.int32)
centroid_sums = torch.zeros((B, K, D), device=x.device, dtype=torch.float32)
centroid_cnts = torch.zeros((B, K), device=x.device, dtype=torch.int32)
grid = (triton.cdiv(N, BLOCK_N), B)
_centroid_update_chunk_kernel[grid](
x, # original features
sorted_idx_int, # gather indices
sorted_cluster_ids.to(torch.int32),
centroid_sums,
centroid_cnts,
B,
N,
D,
K,
BLOCK_N=BLOCK_N,
)
# Convert sums to means; replace empty clusters with old centroids
counts_f = centroid_cnts.to(torch.float32).unsqueeze(-1).clamp(min=1.0)
centroids = centroid_sums / counts_f
empty_mask = (centroid_cnts == 0).unsqueeze(-1)
centroids = torch.where(empty_mask, old_centroids.to(torch.float32), centroids)
return centroids.to(x.dtype), centroid_cnts
# ===============================================================
# Triton kernel: compute nearest-centroid IDs (Euclidean distance)
# Inputs:
# x : (B, N, D) float16 / float32
# centroids : (B, K, D) same dtype as x
# x_sq : (B, N) float32 – pre-computed ||x||^2 per point
# Output:
# cluster_ids : (B, N) int32 – nearest centroid index per point
# ===============================================================
def _ceil_div(a: int, b: int) -> int:
return (a + b - 1) // b
# -----------------------------------------------------------------------------
# Auto-tuning setup – explore various tile sizes / warp counts
# -----------------------------------------------------------------------------
_TUNE_CONFIGS = [triton.Config({"BLOCK_N": BN, "BLOCK_K": BK}, num_stages=4, num_warps=wp) for BN in [32, 64, 128] for BK in [32, 64, 128] for wp in [4, 8]]
def _cfg_keep(conf):
"""Basic heuristic to prune unbalanced configs."""
BN = conf.kwargs["BLOCK_N"]
BK = conf.kwargs["BLOCK_K"]
# Avoid tiny tiles on many warps
if BN * BK < 32 * 32 and conf.num_warps > 4:
return False
return True
_TUNE_CONFIGS = list(filter(_cfg_keep, _TUNE_CONFIGS))
@triton.autotune(_TUNE_CONFIGS, key=["N", "K"])
@triton.jit
def _euclid_assign_kernel(
x_ptr, # *f16 / *f32 [B, N, D]
c_ptr, # *f16 / *f32 [B, K, D]
x_sq_ptr, # *f32 [B, N]
out_ptr, # *i32 [B, N]
B: tl.constexpr,
N: tl.constexpr,
K: tl.constexpr,
D: tl.constexpr,
stride_x_b: tl.constexpr,
stride_x_n: tl.constexpr,
stride_x_d: tl.constexpr,
stride_c_b: tl.constexpr,
stride_c_k: tl.constexpr,
stride_c_d: tl.constexpr,
stride_xsq_b: tl.constexpr,
stride_xsq_n: tl.constexpr,
stride_out_b: tl.constexpr,
stride_out_n: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
):
"""Each program handles a tile of BLOCK_N points for a given batch element.
The kernel iterates over the centroid dimension K in chunks of BLOCK_K and
maintains the running minimum distance as well as the corresponding index
for every point in the tile.
"""
pid_n = tl.program_id(0) # tile index along N dimension
pid_b = tl.program_id(1) # batch index
n_start = pid_n * BLOCK_N
n_offsets = n_start + tl.arange(0, BLOCK_N)
n_mask = n_offsets < N
# ------------------------------------------------------------------
# Load x tile (BLOCK_N, D)
# ------------------------------------------------------------------
offs_d = tl.arange(0, D)
# Compute pointer for x block: base + b*stride_x_b + n*stride_x_n + d*stride_x_d
x_ptrs = x_ptr + pid_b * stride_x_b + n_offsets[:, None] * stride_x_n + offs_d[None, :] * stride_x_d
x_tile = tl.load(x_ptrs, mask=n_mask[:, None], other=0.0)
# Pre-load x_sq for the tile (BLOCK_N,)
xsq_ptrs = x_sq_ptr + pid_b * stride_xsq_b + n_offsets * stride_xsq_n
x_sq_tile = tl.load(xsq_ptrs, mask=n_mask, other=0.0).to(tl.float32)
# Init best distance / index
best_dist = tl.full((BLOCK_N,), 3.4e38, tl.float32) # large number
best_idx = tl.zeros((BLOCK_N,), tl.int32)
# ------------------------------------------------------------------
# Iterate over centroids in chunks of BLOCK_K
# ------------------------------------------------------------------
for k_start in range(0, K, BLOCK_K):
k_offsets = k_start + tl.arange(0, BLOCK_K)
k_mask = k_offsets < K
# Load centroid tile (D, BLOCK_K)
c_ptrs = c_ptr + pid_b * stride_c_b + k_offsets[None, :] * stride_c_k + offs_d[:, None] * stride_c_d
c_tile = tl.load(c_ptrs, mask=k_mask[None, :], other=0.0)
# Compute centroid squared norms (BLOCK_K,)
cent_sq = tl.sum(c_tile * c_tile, axis=0).to(tl.float32)
# Compute cross term (BLOCK_N, BLOCK_K) = x_tile @ c_tile
cross = tl.dot(x_tile, c_tile).to(tl.float32) # float32
# Squared Euclidean distance
dist = x_sq_tile[:, None] + cent_sq[None, :] - 2.0 * cross
dist = tl.maximum(dist, 0.0)
# Mask out invalid centroid columns before reduction
dist = tl.where(k_mask[None, :], dist, 3.4e38)
curr_min = tl.min(dist, axis=1)
curr_idx = tl.argmin(dist, axis=1)
update = curr_min < best_dist
best_dist = tl.where(update, curr_min, best_dist)
best_idx = tl.where(update, k_start + curr_idx, best_idx)
# ------------------------------------------------------------------
# Write results
# ------------------------------------------------------------------
out_ptrs = out_ptr + pid_b * stride_out_b + n_offsets * stride_out_n
tl.store(out_ptrs, best_idx, mask=n_mask)
# ---------------------------------------------------------------
# Python wrapper
# ---------------------------------------------------------------
def euclid_assign_triton(
x: torch.Tensor,
centroids: torch.Tensor,
x_sq: torch.Tensor,
out: torch.Tensor = None,
*,
BLOCK_N: int = 128,
BLOCK_K: int = 128,
) -> torch.Tensor:
"""Return nearest-centroid indices using Triton kernel.
Args:
x : (B, N, D) float16 / float32 (on CUDA)
centroids : (B, K, D) same dtype/device as x
x_sq : (B, N) float32 – ||x||^2 per point (on CUDA)
Returns:
cluster_ids (B, N): nearest-centroid index per point (allocated as int64 unless an `out` buffer is supplied)
"""
assert x.is_cuda and centroids.is_cuda and x_sq.is_cuda, "All tensors must be on CUDA"
# assert x.dtype in (torch.float16, torch.float32), "x must be fp16/fp32"
assert centroids.dtype == x.dtype, "centroids dtype mismatch"
B, N, D = x.shape
K = centroids.shape[1]
assert centroids.shape == (B, K, D), "centroids shape mismatch"
assert x_sq.shape == (B, N), "x_sq shape mismatch"
# x = x.contiguous()
# centroids = centroids.contiguous()
# x_sq = x_sq.contiguous()
if out is None:
out = torch.empty((B, N), device=x.device, dtype=torch.int64)
# Strides (in elements)
stride_x_b, stride_x_n, stride_x_d = x.stride()
stride_c_b, stride_c_k, stride_c_d = centroids.stride()
stride_xsq_b, stride_xsq_n = x_sq.stride()
stride_out_b, stride_out_n = out.stride()
grid = lambda META: (triton.cdiv(N, META["BLOCK_N"]), B) # noqa
_euclid_assign_kernel[grid](
x,
centroids,
x_sq,
out,
B,
N,
K,
D,
stride_x_b,
stride_x_n,
stride_x_d,
stride_c_b,
stride_c_k,
stride_c_d,
stride_xsq_b,
stride_xsq_n,
stride_out_b,
stride_out_n,
)
return out
# 1. Euclidean
def _euclid_iter(x, x_sq, centroids):
# cent_sq = (centroids ** 2).sum(dim=-1)
# cross = torch.einsum('bnd,bkd->bnk', x, centroids)
# dist_sq = (x_sq[:,:,None] + cent_sq[:,None,:] - 2.0 * cross).clamp_min_(0.0)
# cluster_ids = dist_sq.argmin(dim=-1)
cluster_ids = euclid_assign_triton(x, centroids, x_sq)
centroids_new, cluster_sizes = triton_centroid_update_sorted_euclid(x, cluster_ids, centroids)
# centroids_new = triton_centroid_update_euclid(x, cluster_ids, centroids)
# centroids_new = centroids_new.clone() # avoid CUDA graphs aliasing
shift = (centroids_new - centroids).norm(dim=-1).max()
return centroids_new, shift, cluster_ids, cluster_sizes
# 2. Cosine
def _cosine_iter(x_norm, centroids):
cos_sim = torch.einsum("bnd,bkd->bnk", x_norm, centroids)
cluster_ids = cos_sim.argmax(dim=-1)
centroids_new = triton_centroid_update_cosine(x_norm, cluster_ids, centroids)
# centroids_new = centroids_new.clone()
shift = (centroids_new - centroids).norm(dim=-1).max()
return centroids_new, shift, cluster_ids
# 3. Dot-product
def _dot_iter(x, centroids):
sim = torch.einsum("bnd,bkd->bnk", x, centroids)
cluster_ids = sim.argmax(dim=-1)
centroids_new = triton_centroid_update_cosine(x, cluster_ids, centroids)
# centroids_new = centroids_new.clone()
shift = (centroids_new - centroids).norm(dim=-1).max()
return centroids_new, shift, cluster_ids
COMPILE_FLAG = False
# Try to compile; if PyTorch < 2.0 or compile is not available, fallback to original function
try:
if COMPILE_FLAG:
_euclid_iter_compiled = torch.compile(_euclid_iter, dynamic=True, mode="reduce-overhead")
_cosine_iter_compiled = torch.compile(_cosine_iter, dynamic=True, mode="reduce-overhead")
_dot_iter_compiled = torch.compile(_dot_iter, dynamic=True, mode="reduce-overhead")
else:
_euclid_iter_compiled = _euclid_iter
_cosine_iter_compiled = _cosine_iter
_dot_iter_compiled = _dot_iter
except Exception: # pragma: no cover
_euclid_iter_compiled = _euclid_iter
_cosine_iter_compiled = _cosine_iter
_dot_iter_compiled = _dot_iter
def batch_kmeans_Euclid(x, n_clusters, max_iters=100, tol=1e-4, init_centroids=None, verbose=False):
"""
Batched KMeans clustering in PyTorch using Euclidean distance.
Args:
x: Tensor of shape (B, N, D), batch_size B, N points per batch, D dims.
n_clusters: Number of clusters.
max_iters: Max number of iterations.
tol: Relative tolerance for center movement.
verbose: Print loss for each iter.
Returns:
cluster_ids: (B, N) LongTensor, cluster assignment for each point.
centroids: (B, n_clusters, D) final cluster centers.
cluster_sizes: (B, n_clusters) LongTensor, number of points per cluster.
n_iters: actual number of iterations executed (int)
"""
B, N, D = x.shape
# Pre-compute squared L2 norm of all points (constant during iterations)
x_sq = (x**2).sum(dim=-1) # (B, N)
if init_centroids is None:
# Randomly select initial centers from x
indices = torch.randint(0, N, (B, n_clusters), device=x.device)
centroids = torch.gather(x, dim=1, index=indices[..., None].expand(-1, -1, D)) # (B, n_clusters, D)
else:
# centroids = init_centroids.clone()
centroids = init_centroids
centroids = centroids.view(B, n_clusters, D)
for it in range(max_iters):
# ---- compiled single iteration ----
centroids_new, center_shift, cluster_ids, cluster_sizes = _euclid_iter_compiled(x, x_sq, centroids)
# 4. Check for convergence
if verbose:
print(f"Iter {it}, center shift: {center_shift.item():.6f}")
if center_shift < tol:
break
# centroids = centroids_new.clone()
centroids = centroids_new
# # --- compute cluster sizes ---
# ones = torch.ones_like(cluster_ids, dtype=torch.int64)
# cluster_sizes = torch.zeros(B, n_clusters, dtype=torch.int64, device=x.device)
# cluster_sizes.scatter_add_(1, cluster_ids, ones)
return cluster_ids, centroids, cluster_sizes, it + 1
# return cluster_ids.clone(), centroids.clone(), cluster_sizes.clone(), it + 1
# batch_kmeans_Euclid = torch.compile(batch_kmeans_Euclid, dynamic=True, mode="reduce-overhead")
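# Illustrative usage sketch (requires a CUDA device, since each iteration runs the Triton
# assignment/update kernels; all sizes below are toy values): cluster one batch of 512
# random 16-d points into 8 clusters.
def _batch_kmeans_euclid_example():
    if not torch.cuda.is_available():
        return None
    x = torch.randn(1, 512, 16, device="cuda", dtype=torch.bfloat16)
    cluster_ids, centroids, cluster_sizes, n_iters = batch_kmeans_Euclid(x, n_clusters=8, max_iters=10)
    # cluster_ids: (1, 512), centroids: (1, 8, 16), cluster_sizes: (1, 8)
    return cluster_ids, centroids, cluster_sizes, n_iters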
def batch_kmeans_Cosine(x, n_clusters, max_iters=100, tol=1e-4, init_centroids=None, verbose=False):
"""
Batched KMeans clustering in PyTorch using Cosine similarity.
Args:
x: Tensor of shape (B, N, D), batch_size B, N points per batch, D dims.
n_clusters: Number of clusters.
max_iters: Max number of iterations.
tol: Relative tolerance for center movement.
verbose: Print loss for each iter.
Returns:
cluster_ids: (B, N) LongTensor, cluster assignment for each point.
centroids: (B, n_clusters, D) final cluster centers.
cluster_sizes: (B, n_clusters) LongTensor, number of points per cluster.
n_iters: actual number of iterations executed (int)
"""
B, N, D = x.shape
# Normalize input vectors for cosine similarity
x_norm = F.normalize(x, p=2, dim=-1) # (B, N, D)
if init_centroids is None:
# Randomly select initial centers from x_norm
indices = torch.randint(0, N, (B, n_clusters), device=x.device)
centroids = torch.gather(x_norm, dim=1, index=indices[..., None].expand(-1, -1, D)) # (B, n_clusters, D)
else:
centroids = init_centroids
centroids = centroids.view(B, n_clusters, D)
centroids = F.normalize(centroids, p=2, dim=-1) # Ensure centroids are normalized
for it in range(max_iters):
# ---- compiled single iteration ----
centroids_new, center_shift, cluster_ids = _cosine_iter_compiled(x_norm, centroids)
# 4. Check for convergence
if verbose:
print(f"Iter {it}, center shift: {center_shift.item():.6f}")
if center_shift < tol:
break
centroids = centroids_new.clone()
# --- compute cluster sizes ---
ones = torch.ones_like(cluster_ids, dtype=torch.int64)
cluster_sizes = torch.zeros(B, n_clusters, dtype=torch.int64, device=x.device)
cluster_sizes.scatter_add_(1, cluster_ids, ones)
return cluster_ids, centroids, cluster_sizes, it + 1
def batch_kmeans_Dot(x, n_clusters, max_iters=100, tol=1e-4, init_centroids=None, verbose=False):
"""
Batched KMeans clustering in PyTorch using raw dot-product as similarity.
"""
B, N, D = x.shape
if init_centroids is None:
# Randomly initialize centroids
indices = torch.randint(0, N, (B, n_clusters), device=x.device)
centroids = torch.gather(x, dim=1, index=indices[..., None].expand(-1, -1, D))
else:
centroids = init_centroids
centroids = centroids.view(B, n_clusters, D)
for it in range(max_iters):
# ---- compiled single iteration ----
centroids_new, center_shift, cluster_ids = _dot_iter_compiled(x, centroids)
# 4. Check for convergence
if verbose:
print(f"Iter {it} (dot), center shift: {center_shift.item():.6f}")
if center_shift < tol:
break
centroids = centroids_new.clone()
# --- compute cluster sizes ---
ones = torch.ones_like(cluster_ids, dtype=torch.int64)
cluster_sizes = torch.zeros(B, n_clusters, dtype=torch.int64, device=x.device)
cluster_sizes.scatter_add_(1, cluster_ids, ones)
return cluster_ids, centroids, cluster_sizes, it + 1
# --- Functions from analyze/kmeans_block_sparse_attention.py (helpers) ---
def permute_tensor_by_labels(tensor, labels, dim):
labels = labels.to(tensor.device)
sorted_indices = torch.argsort(labels, dim=-1)
gather_indices = sorted_indices
for i in range(dim + 1, tensor.dim()):
gather_indices = gather_indices.unsqueeze(-1)
expand_shape = list(tensor.shape)
gather_indices = gather_indices.expand(expand_shape)
permuted_tensor = torch.gather(tensor, dim, gather_indices)
return permuted_tensor, sorted_indices
def apply_inverse_permutation(permuted_tensor, sorted_indices, dim):
inverse_indices = torch.argsort(sorted_indices, dim=-1)
gather_indices = inverse_indices
for i in range(dim + 1, permuted_tensor.dim()):
gather_indices = gather_indices.unsqueeze(-1)
gather_indices = gather_indices.expand(permuted_tensor.shape)
original_tensor = torch.gather(permuted_tensor, dim, gather_indices)
return original_tensor
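# A small sketch (illustrative shapes) of the two helpers above: permuting a tensor so that tokens
# sharing a cluster label become contiguous along `dim`, then undoing the permutation.
def _demo_label_permutation_roundtrip():
    x = torch.randn(2, 8, 100, 64)  # (B, H, S, D)
    labels = torch.randint(0, 4, (2, 8, 100))  # one cluster id per token along dim=2
    x_perm, sorted_idx = permute_tensor_by_labels(x, labels, dim=2)
    x_back = apply_inverse_permutation(x_perm, sorted_idx, dim=2)
    assert torch.equal(x_back, x)  # the inverse permutation restores the original order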
def weighted_softmax(scores, weights):
input_dtype = scores.dtype
scores = scores.float()
weights = weights.float()
max_score = torch.max(scores, dim=-1, keepdim=True)[0]
exp_scores = torch.exp(scores - max_score)
weighted_exp = weights * exp_scores
softmax_out = weighted_exp / torch.sum(weighted_exp, dim=-1, keepdim=True).clamp(min=1e-12)
return softmax_out.to(input_dtype)
def identify_dynamic_map(
query_centroids,
key_centroids,
q_cluster_sizes,
k_cluster_sizes,
p,
min_kc_ratio=0,
):
B, H, qc_num, D = query_centroids.shape
kc_num = key_centroids.shape[2]
device = query_centroids.device
attn_scores = torch.matmul(query_centroids, key_centroids.transpose(-2, -1)) / (D**0.5)
k_weights = k_cluster_sizes.unsqueeze(-2).float()
weighted_attn_probs = weighted_softmax(attn_scores, k_weights)
sorted_probs, sorted_indices = torch.sort(weighted_attn_probs, dim=-1, descending=True)
cumsum_probs = torch.cumsum(sorted_probs, dim=-1)
remove_indices = cumsum_probs > p
remove_indices[..., 1:] = remove_indices[..., :-1].clone()
remove_indices[..., 0] = False
if min_kc_ratio > 0:
preserve_length = int(min_kc_ratio * kc_num)
remove_indices[..., :preserve_length] = False
sorted_clusters_to_keep = ~remove_indices
dynamic_map = torch.zeros(B, H, qc_num, kc_num, dtype=torch.bool, device=device)
dynamic_map.scatter_(-1, sorted_indices, sorted_clusters_to_keep)
return dynamic_map
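# A hedged sketch of how identify_dynamic_map could be driven: per query cluster it keeps the
# smallest set of key clusters whose size-weighted attention mass reaches p (top-p over clusters).
# All dimensions and sizes below are illustrative assumptions.
def _demo_identify_dynamic_map():
    B, H, qc_num, kc_num, D = 1, 2, 8, 8, 64
    q_cent = torch.randn(B, H, qc_num, D)
    k_cent = torch.randn(B, H, kc_num, D)
    q_sizes = torch.full((B, H, qc_num), 128, dtype=torch.int64)
    k_sizes = torch.full((B, H, kc_num), 128, dtype=torch.int64)
    dyn_map = identify_dynamic_map(q_cent, k_cent, q_sizes, k_sizes, p=0.9)
    # Boolean map; at least one key cluster is always kept per query cluster
    assert dyn_map.shape == (B, H, qc_num, kc_num) and dyn_map.any(dim=-1).all()
    return dyn_map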
# --- Functions from analyze/dynamic_block_sparse_attention.py ---
def dynamic_block_sparse_fwd_torch(q, k, v, dynamic_map, qc_size, kc_size):
"""
Computes dynamic block sparse attention using pure PyTorch.
Args:
q (torch.Tensor): Query tensor, shape [B, H, S, D].
k (torch.Tensor): Key tensor, shape [B, H, S, D].
v (torch.Tensor): Value tensor, shape [B, H, S, D].
dynamic_map (torch.Tensor): Boolean mask, shape [B, H, qc_num, kc_num].
qc_size (torch.Tensor): Query block sizes, shape [B, H, qc_num].
kc_size (torch.Tensor): Key block sizes, shape [B, H, kc_num].
Returns:
torch.Tensor: Output tensor, shape [B, H, S, D].
"""
B, H, S, D = q.shape
qc_num = qc_size.shape[-1]
kc_num = kc_size.shape[-1]
device = q.device
dtype = q.dtype
# Ensure sequence lengths match sum of block sizes
assert S == torch.sum(qc_size[0, 0, :]), "Sum of qc_size must equal S"
assert S == torch.sum(kc_size[0, 0, :]), "Sum of kc_size must equal S"
# Precompute cumulative sizes for block indexing
# Add a 0 at the beginning for easier slicing
qc_cum_size = torch.cumsum(torch.cat([torch.zeros_like(qc_size[..., :1]), qc_size], dim=-1), dim=-1)
kc_cum_size = torch.cumsum(torch.cat([torch.zeros_like(kc_size[..., :1]), kc_size], dim=-1), dim=-1)
out = torch.zeros_like(q)
scale = D**-0.5
# Naive implementation: Iterate through batch, head, and blocks
for b in range(B):
for h in range(H):
# Precompute start/end indices for this batch/head
q_starts = qc_cum_size[b, h, :-1]
q_ends = qc_cum_size[b, h, 1:]
k_starts = kc_cum_size[b, h, :-1]
k_ends = kc_cum_size[b, h, 1:]
# Iterate through query blocks
for i in range(qc_num):
q_start, q_end = q_starts[i], q_ends[i]
q_block = q[b, h, q_start:q_end, :] # Shape: [qc_i, D]
if q_block.shape[0] == 0:
continue # Skip empty blocks
m_i = torch.full((q_block.shape[0], 1), -float("inf"), device=device, dtype=dtype)
l_i = torch.zeros((q_block.shape[0], 1), device=device, dtype=dtype)
acc_o_i = torch.zeros_like(q_block) # Shape: [qc_i, D]
# Iterate through key/value blocks for the current query block
for j in range(kc_num):
# Check if this block needs computation
if dynamic_map[b, h, i, j]:
k_start, k_end = k_starts[j], k_ends[j]
k_block = k[b, h, k_start:k_end, :] # Shape: [kc_j, D]
v_block = v[b, h, k_start:k_end, :] # Shape: [kc_j, D]
if k_block.shape[0] == 0:
continue # Skip empty blocks
# Compute attention scores for the block
# QK^T: [qc_i, D] @ [D, kc_j] -> [qc_i, kc_j]
s_ij = (q_block @ k_block.transpose(-1, -2)) * scale
# --- Online Softmax ---
# Find max score per query token in this block
m_ij = torch.max(s_ij, dim=-1, keepdim=True)[0] # Shape: [qc_i, 1]
# Update overall max score (m_i)
m_new = torch.maximum(m_i, m_ij) # Shape: [qc_i, 1]
# Calculate scaling factors for previous accumulator and current block
p_ij = torch.exp(s_ij - m_new) # Shape: [qc_i, kc_j]
exp_m_diff = torch.exp(m_i - m_new) # Shape: [qc_i, 1]
# Update softmax denominator (l_i)
l_i = (l_i * exp_m_diff) + torch.sum(p_ij, dim=-1, keepdim=True) # Shape: [qc_i, 1]
# Update output accumulator (acc_o_i)
# P_ij @ V_j: [qc_i, kc_j] @ [kc_j, D] -> [qc_i, D]
acc_o_i = (acc_o_i * exp_m_diff) + (p_ij @ v_block) # Shape: [qc_i, D]
# Update max score for next iteration
m_i = m_new
# Normalize the accumulated output
out[b, h, q_start:q_end, :] = acc_o_i / l_i.clamp(min=1e-12) # Avoid division by zero
return out
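# Sanity-check sketch for the PyTorch reference above (illustrative shapes): with an all-True
# dynamic_map and equal block sizes, the block-sparse result should match dense scaled dot-product
# attention up to numerical tolerance.
def _demo_dense_map_matches_sdpa():
    B, H, S, D = 1, 2, 128, 64
    q, k, v = (torch.randn(B, H, S, D) for _ in range(3))
    qc_size = torch.full((B, H, 4), S // 4, dtype=torch.int64)  # 4 equal query blocks
    kc_size = torch.full((B, H, 4), S // 4, dtype=torch.int64)  # 4 equal key blocks
    dynamic_map = torch.ones(B, H, 4, 4, dtype=torch.bool)
    out_sparse = dynamic_block_sparse_fwd_torch(q, k, v, dynamic_map, qc_size, kc_size)
    out_dense = F.scaled_dot_product_attention(q, k, v)
    assert torch.allclose(out_sparse, out_dense, atol=1e-4)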
# --- Triton Implementation ---
@triton.jit
def _dynamic_block_sparse_fwd_kernel(
Q,
K,
V,
Out,
dynamic_map,
qc_cum_size,
kc_cum_size,
stride_qb,
stride_qh,
stride_qs,
stride_qd,
stride_kb,
stride_kh,
stride_ks,
stride_kd,
stride_vb,
stride_vh,
stride_vs,
stride_vd,
stride_ob,
stride_oh,
stride_os,
stride_od,
stride_dmap_b,
stride_dmap_h,
stride_dmap_qc,
stride_dmap_kc,
stride_qcs_b,
stride_qcs_h,
stride_qcs_qc,
stride_kcs_b,
stride_kcs_h,
stride_kcs_kc,
B,
H,
S,
D,
scale,
QC_NUM: tl.constexpr,
KC_NUM: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_D: tl.constexpr,
):
"""
Triton kernel for dynamic block sparse attention.
Each program computes attention for one query block within a batch/head.
Processes query block in chunks of BLOCK_M.
Iterates through key blocks, checking dynamic_map.
Processes key/value blocks in chunks of BLOCK_N.
Uses online softmax.
"""
# --- Grid Calculation ---
# Each program instance handles one query block for a specific batch and head
pid = tl.program_id(axis=0)
# Total number of program instances launched by the grid: B * H * QC_NUM
# Calculate batch, head, and query block index
pid_q_block_global = pid # 0 to B*H*QC_NUM - 1
# pid_bh = pid // QC_NUM # Deprecated: Causes issues if QC_NUM is not constant across BH
# pid_q_block_idx = pid % QC_NUM
# Need to map pid (0.. B*H*QC_NUM-1) back to (b, h, q_block_idx)
# q_block_idx changes fastest, then h, then b
q_block_idx = pid_q_block_global % QC_NUM
pid_h_temp = pid_q_block_global // QC_NUM
h = pid_h_temp % H
b = pid_h_temp // H
# --- Load Q block info (start/end offsets) ---
qcs_offset = b * stride_qcs_b + h * stride_qcs_h
q_start_offset = tl.load(qc_cum_size + qcs_offset + q_block_idx * stride_qcs_qc)
q_end_offset = tl.load(qc_cum_size + qcs_offset + (q_block_idx + 1) * stride_qcs_qc)
q_block_size = q_end_offset - q_start_offset
# Early exit if the query block is empty
if q_block_size == 0:
return
# --- Pointers setup ---
q_ptr_base = Q + b * stride_qb + h * stride_qh + q_start_offset * stride_qs
k_ptr_base = K + b * stride_kb + h * stride_kh
v_ptr_base = V + b * stride_vb + h * stride_vh
out_ptr_base = Out + b * stride_ob + h * stride_oh + q_start_offset * stride_os
dmap_ptr = dynamic_map + b * stride_dmap_b + h * stride_dmap_h + q_block_idx * stride_dmap_qc
kcs_ptr = kc_cum_size + b * stride_kcs_b + h * stride_kcs_h
# --- Iterate over the query block rows in chunks of BLOCK_M ---
offs_qm = tl.arange(0, BLOCK_M) # Query block row offsets [0, 1, ..., BLOCK_M-1]
offs_d = tl.arange(0, BLOCK_D) # Dimension offsets [0, 1, ..., BLOCK_D-1]
for q_chunk_start in range(0, q_block_size, BLOCK_M):
q_chunk_rows = offs_qm + q_chunk_start
q_rows_mask = q_chunk_rows < q_block_size # Mask for valid rows in this Q chunk [BLOCK_M]
# --- Initialize accumulators for this Q chunk ---
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") # Max score
l_i = tl.zeros([BLOCK_M], dtype=tl.float32) # Sum of exp(scores - max)
acc_o = tl.zeros([BLOCK_M, BLOCK_D], dtype=tl.float32) # Accumulated output
# --- Load Q chunk ---
q_ptr = q_ptr_base + q_chunk_rows[:, None] * stride_qs + offs_d[None, :]
# Mask ensures we don't read out of bounds for the query block or dimension D
mask_q = q_rows_mask[:, None] & (offs_d[None, :] < D)
q_chunk = tl.load(q_ptr, mask=mask_q, other=0.0) # Shape: [BLOCK_M, BLOCK_D]
# --- Inner loop over K blocks (columns in the block sparse map) ---
for k_block_idx in range(KC_NUM):
# --- Check dynamic_map: Is this block active? ---
is_active = tl.load(dmap_ptr + k_block_idx * stride_dmap_kc)
if is_active: # Process block only if it's active
# --- Load K block info (start/end offsets) ---
k_start_offset = tl.load(kcs_ptr + k_block_idx * stride_kcs_kc)
k_end_offset = tl.load(kcs_ptr + (k_block_idx + 1) * stride_kcs_kc)
k_block_size = k_end_offset - k_start_offset
# Skip if the key block is empty (inside the active block check)
if k_block_size > 0:
k_block_ptr_base = k_ptr_base + k_start_offset * stride_ks
v_block_ptr_base = v_ptr_base + k_start_offset * stride_vs
# --- Loop over K block chunks (size BLOCK_N) ---
offs_kn = tl.arange(0, BLOCK_N) # Key block row offsets [0, ..., BLOCK_N-1]
for k_chunk_start in range(0, k_block_size, BLOCK_N):
k_chunk_rows = offs_kn + k_chunk_start
k_rows_mask = k_chunk_rows < k_block_size # Mask for valid rows in this K/V chunk [BLOCK_N]
# --- Load K, V chunks ---
k_ptr = k_block_ptr_base + k_chunk_rows[:, None] * stride_ks + offs_d[None, :]
v_ptr = v_block_ptr_base + k_chunk_rows[:, None] * stride_vs + offs_d[None, :]
# Mask ensures we don't read out of bounds for the key block or dimension D
mask_kv = k_rows_mask[:, None] & (offs_d[None, :] < D)
k_chunk = tl.load(k_ptr, mask=mask_kv, other=0.0) # Shape: [BLOCK_N, BLOCK_D]
v_chunk = tl.load(v_ptr, mask=mask_kv, other=0.0) # Shape: [BLOCK_N, BLOCK_D]
# --- Compute Scores (Attention) ---
# QK^T: [BLOCK_M, BLOCK_D] @ [BLOCK_D, BLOCK_N] -> [BLOCK_M, BLOCK_N]
s_ij_chunk = tl.dot(q_chunk, k_chunk.T) * scale
# IMPORTANT: Mask out scores corresponding to padding in K before max/softmax
# Set scores for invalid K elements to -inf
s_ij_chunk = tl.where(k_rows_mask[None, :], s_ij_chunk, -float("inf"))
# Mask out scores for invalid Q elements as well (although q_chunk elements are 0, avoid potential issues)
s_ij_chunk = tl.where(q_rows_mask[:, None], s_ij_chunk, -float("inf"))
# --- Online Softmax Update ---
# Current max for this Q-K chunk interaction
m_ij_chunk = tl.max(s_ij_chunk, axis=1) # Shape: [BLOCK_M]
# Update overall max (across K chunks seen so far for this Q chunk)
m_new = tl.maximum(m_i, m_ij_chunk) # Shape: [BLOCK_M]
# Calculate scaled probabilities P_ij = exp(S_ij - m_new)
p_ij_chunk = tl.exp(s_ij_chunk - m_new[:, None]) # Shape: [BLOCK_M, BLOCK_N]
# Zero out probabilities for masked K elements before summing
p_ij_chunk = tl.where(k_rows_mask[None, :], p_ij_chunk, 0.0)
# Calculate scaling factor for previous accumulator state
exp_m_diff = tl.exp(m_i - m_new) # Shape: [BLOCK_M]
# Update sum accumulator (denominator L)
l_i_chunk = tl.sum(p_ij_chunk, axis=1) # Sum probabilities for this chunk, shape [BLOCK_M]
l_i = (l_i * exp_m_diff) + l_i_chunk # Shape: [BLOCK_M]
# Update output accumulator O
# P_ij @ V_j: [BLOCK_M, BLOCK_N] @ [BLOCK_N, BLOCK_D] -> [BLOCK_M, BLOCK_D]
# Ensure p_ij_chunk is the correct dtype for dot product
p_ij_chunk_casted = p_ij_chunk.to(V.dtype.element_ty)
o_chunk = tl.dot(p_ij_chunk_casted, v_chunk) # Shape: [BLOCK_M, BLOCK_D]
acc_o = (acc_o * exp_m_diff[:, None]) + o_chunk # Shape: [BLOCK_M, BLOCK_D]
# Update max for the next K chunk/block
m_i = m_new
# End of 'if is_active:' block
# --- End of loop over K blocks ---
# --- Finalize output for this Q chunk ---
# Normalize the accumulated output: O = acc_o / l_i
# Add epsilon to l_i to avoid division by zero
l_i_safe = tl.where(l_i == 0, 1.0, l_i) # Avoid 0/0 -> NaN
o_final_chunk = acc_o / (l_i_safe[:, None])
o_final_chunk = tl.where(l_i[:, None] == 0, 0.0, o_final_chunk) # Ensure output is 0 if l_i was 0
# --- Write output chunk to global memory ---
out_ptr = out_ptr_base + q_chunk_rows[:, None] * stride_os + offs_d[None, :]
# Mask ensures we don't write out of bounds for the query block or dimension D
mask_out = q_rows_mask[:, None] & (offs_d[None, :] < D)
tl.store(out_ptr, o_final_chunk.to(Out.dtype.element_ty), mask=mask_out)
# --- (Optional: Write L and M stats if needed) ---
# Example:
# l_ptr = L + b * stride_lb + h * stride_lh + (q_start_offset + q_chunk_rows) * stride_ls
# tl.store(l_ptr, l_i, mask=q_rows_mask)
# m_ptr = M + ...
# tl.store(m_ptr, m_i, mask=q_rows_mask)
# --- End of loop over Q chunks ---
def dynamic_block_sparse_fwd_triton(q, k, v, dynamic_map, qc_size, kc_size):
"""
Launcher for the Triton dynamic block sparse attention kernel.
Args:
q (torch.Tensor): Query tensor, shape [B, H, S, D].
k (torch.Tensor): Key tensor, shape [B, H, S, D].
v (torch.Tensor): Value tensor, shape [B, H, S, D].
dynamic_map (torch.Tensor): Boolean mask, shape [B, H, qc_num, kc_num].
qc_size (torch.Tensor): Query block sizes, shape [B, H, qc_num].
kc_size (torch.Tensor): Key block sizes, shape [B, H, kc_num].
Returns:
torch.Tensor: Output tensor, shape [B, H, S, D].
"""
B, H, S, D = q.shape
qc_num = qc_size.shape[-1]
kc_num = kc_size.shape[-1]
dtype = q.dtype
# Assertions and checks
assert q.is_cuda and k.is_cuda and v.is_cuda, "Inputs must be CUDA tensors"
assert dynamic_map.is_cuda and qc_size.is_cuda and kc_size.is_cuda
assert q.dtype == k.dtype == v.dtype, "Input dtypes must match"
assert dtype in [torch.float16, torch.bfloat16, torch.float32], "Unsupported dtype"
assert D in [16, 32, 64, 128], "Head dimension D must be 16, 32, 64, or 128 for efficient Triton dot"
# Ensure sequence lengths match sum of block sizes (check on one batch/head for simplicity)
assert S == torch.sum(qc_size[0, 0, :]), "Sum of qc_size must equal S"
assert S == torch.sum(kc_size[0, 0, :]), "Sum of kc_size must equal S"
# Ensure dynamic_map is boolean
assert dynamic_map.dtype == torch.bool
# Calculate scale factor (using float32 for stability)
scale = D**-0.5
# Precompute cumulative sizes (on CPU/GPU, keep on device)
qc_cum_size = torch.cumsum(torch.cat([torch.zeros_like(qc_size[..., :1]), qc_size], dim=-1), dim=-1).int()
kc_cum_size = torch.cumsum(torch.cat([torch.zeros_like(kc_size[..., :1]), kc_size], dim=-1), dim=-1).int()
# Output tensor
out = torch.empty_like(q)
# Triton kernel config
# BLOCK_M/N can be tuned. Larger blocks may increase occupancy but need more shared memory.
# Let's start with reasonably sized blocks.
BLOCK_D = D
if S <= 512: # Smaller sequence, smaller blocks might be ok
BLOCK_M = 64
BLOCK_N = 64
elif S <= 1024:
BLOCK_M = 64
BLOCK_N = 64
else: # Larger sequence, potentially larger blocks
BLOCK_M = 128 # Or keep 64? Test
BLOCK_N = 64
# Adjust block size if sequence length is smaller
BLOCK_M = min(BLOCK_M, S)
BLOCK_N = min(BLOCK_N, S)
# Launch grid: One program per query block per batch/head
grid = (B * H * qc_num,)
# Call the kernel
_dynamic_block_sparse_fwd_kernel[grid](
q,
k,
v,
out,
dynamic_map,
qc_cum_size,
kc_cum_size,
q.stride(0),
q.stride(1),
q.stride(2),
q.stride(3),
k.stride(0),
k.stride(1),
k.stride(2),
k.stride(3),
v.stride(0),
v.stride(1),
v.stride(2),
v.stride(3),
out.stride(0),
out.stride(1),
out.stride(2),
out.stride(3),
dynamic_map.stride(0),
dynamic_map.stride(1),
dynamic_map.stride(2),
dynamic_map.stride(3),
qc_cum_size.stride(0),
qc_cum_size.stride(1),
qc_cum_size.stride(2),
kc_cum_size.stride(0),
kc_cum_size.stride(1),
kc_cum_size.stride(2),
B,
H,
S,
D,
scale,
QC_NUM=qc_num,
KC_NUM=kc_num,
BLOCK_M=BLOCK_M,
BLOCK_N=BLOCK_N,
BLOCK_D=BLOCK_D,
# num_warps=4 # Can tune this
)
return out
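# A hedged usage sketch for the Triton launcher (assumes a CUDA device): compare it against the
# pure-PyTorch reference above on a small random problem. Tolerances are loose because the
# reference also runs in float16 here.
def _demo_triton_vs_torch_reference():
    B, H, S, D = 1, 2, 256, 64
    q, k, v = (torch.randn(B, H, S, D, device="cuda", dtype=torch.float16) for _ in range(3))
    qc_size = torch.full((B, H, 4), S // 4, dtype=torch.int64, device="cuda")
    kc_size = torch.full((B, H, 4), S // 4, dtype=torch.int64, device="cuda")
    dynamic_map = torch.rand(B, H, 4, 4, device="cuda") > 0.5
    out_triton = dynamic_block_sparse_fwd_triton(q, k, v, dynamic_map, qc_size, kc_size)
    out_ref = dynamic_block_sparse_fwd_torch(q, k, v, dynamic_map, qc_size, kc_size)
    assert torch.allclose(out_triton, out_ref, atol=1e-2, rtol=1e-2)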
# ---------------- Batch wrapper for cuVS KMeans -----------------
def batch_kmeans_rapidai(x, n_clusters, max_iters=100, tol=1e-4, init_centroids=None, verbose=False):
"""Batched K-Means using RAPIDS cuVS implementation.
Args:
x (Tensor): (B, N, D) float32 tensor on CUDA.
n_clusters (int): K.
max_iters (int): maximum iterations.
tol (float): tolerance.
init_centroids (Tensor|None): optional initial centroids (B,K,D) float32.
verbose (bool): print per-batch info.
Returns:
cluster_ids (B, N) LongTensor
centroids (B, K, D) float32
cluster_sizes (B, K) LongTensor
n_iters_list (List[int]) iterations per batch
"""
B, N, D = x.shape
if init_centroids is not None:
assert init_centroids.shape == (B, n_clusters, D)
cluster_ids_list = []
centroids_list = []
# cluster_sizes_list = []
n_iters_list = []
x_float = x.float()
if init_centroids is not None:
init_centroids_float = init_centroids.float()
for b in range(B):
xb = x_float[b]
if init_centroids is None:
centroids_init_b = None
init_method = "KMeansPlusPlus"
else:
centroids_init_b = init_centroids_float[b]
init_method = "Array"
labels_b, centroids_b, n_iter_b = kmeans_rapidai(xb, n_clusters, max_iter=max_iters, tol=tol, init_method=init_method, centroids_init=centroids_init_b)
cluster_ids_list.append(labels_b.to(torch.int64)) # (N,)
centroids_list.append(centroids_b)
# cluster_sizes_b = torch.bincount(labels_b, minlength=n_clusters).to(torch.int64)
# cluster_sizes_list.append(cluster_sizes_b)
n_iters_list.append(n_iter_b)  # record iterations so the returned n_iters_list is actually populated
# if verbose:
# print(f"Batch {b}: iters={n_iter_b}, cluster sizes min={cluster_sizes_b.min().item()} max={cluster_sizes_b.max().item()}")
cluster_ids = torch.stack(cluster_ids_list, dim=0) # (B,N)
centroids = torch.stack(centroids_list, dim=0).to(x.dtype) # (B,K,D)
# cluster_sizes = torch.stack(cluster_sizes_list, dim=0) # (B,K)
# --- compute cluster sizes ---
ones = torch.ones_like(cluster_ids, dtype=torch.int64)
cluster_sizes = torch.zeros(B, n_clusters, dtype=torch.int64, device=x.device)
cluster_sizes.scatter_add_(1, cluster_ids, ones)
return cluster_ids, centroids, cluster_sizes, n_iters_list
import math
from functools import lru_cache
from math import ceil
import torch
import torch.nn.functional as F
import triton
import triton.language as tl
from loguru import logger
from torch.nn.attention.flex_attention import create_block_mask, flex_attention
from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER
from .template import AttnWeightTemplate
@triton.jit
def wan_hidden_states_placement_kernel(
hidden_states_ptr, # [cfg, num_heads, seq_len, head_dim] seq_len = context_length + num_frame * frame_size
hidden_states_out_ptr, # [cfg, num_heads, seq_len, head_dim]
best_mask_idx_ptr, # [cfg, num_heads]
hidden_states_stride_b,
hidden_states_stride_h,
hidden_states_stride_s,
hidden_states_stride_d,
mask_idx_stride_b,
mask_idx_stride_h,
seq_len: tl.constexpr,
head_dim: tl.constexpr,
context_length: tl.constexpr,
num_frame: tl.constexpr,
frame_size: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
# Copy hidden_states to output
# range: [b, h, block_id * block_size: block_id * block_size + block_size, :]
cfg = tl.program_id(0)
head = tl.program_id(1)
block_id = tl.program_id(2)
start_id = block_id * BLOCK_SIZE
end_id = start_id + BLOCK_SIZE
end_id = tl.where(end_id > seq_len, seq_len, end_id)
# Load best mask idx (0 is spatial, 1 is temporal)
is_temporal = tl.load(best_mask_idx_ptr + cfg * mask_idx_stride_b + head * mask_idx_stride_h)
offset_token = tl.arange(0, BLOCK_SIZE) + start_id
offset_mask = offset_token < seq_len
offset_d = tl.arange(0, head_dim)
if is_temporal:
patch_id = offset_token // num_frame
frame_id = offset_token - patch_id * num_frame
offset_store_token = tl.where(offset_token >= seq_len - context_length, offset_token, frame_id * frame_size + patch_id)
offset_load = (cfg * hidden_states_stride_b + head * hidden_states_stride_h + offset_token[:, None] * hidden_states_stride_s) + offset_d[None, :] * hidden_states_stride_d
offset_hidden_states = hidden_states_ptr + offset_load
offset_store = (cfg * hidden_states_stride_b + head * hidden_states_stride_h + offset_store_token[:, None] * hidden_states_stride_s) + offset_d[None, :] * hidden_states_stride_d
offset_hidden_states_out = hidden_states_out_ptr + offset_store
# Maybe tune the pipeline here
hidden_states = tl.load(offset_hidden_states, mask=offset_mask[:, None])
tl.store(offset_hidden_states_out, hidden_states, mask=offset_mask[:, None])
else:
offset_load = (cfg * hidden_states_stride_b + head * hidden_states_stride_h + offset_token[:, None] * hidden_states_stride_s) + offset_d[None, :] * hidden_states_stride_d
offset_hidden_states = hidden_states_ptr + offset_load
offset_store = offset_load
offset_hidden_states_out = hidden_states_out_ptr + offset_store
# Maybe tune the pipeline here
hidden_states = tl.load(offset_hidden_states, mask=offset_mask[:, None])
tl.store(offset_hidden_states_out, hidden_states, mask=offset_mask[:, None])
def wan_hidden_states_placement(hidden_states, hidden_states_out, best_mask_idx, context_length, num_frame, frame_size):
cfg, num_heads, seq_len, head_dim = hidden_states.shape
BLOCK_SIZE = 128
assert seq_len == context_length + num_frame * frame_size
grid = (cfg, num_heads, (seq_len + BLOCK_SIZE - 1) // BLOCK_SIZE)
wan_hidden_states_placement_kernel[grid](
hidden_states,
hidden_states_out,
best_mask_idx,
hidden_states.stride(0),
hidden_states.stride(1),
hidden_states.stride(2),
hidden_states.stride(3),
best_mask_idx.stride(0),
best_mask_idx.stride(1),
seq_len,
head_dim,
context_length,
num_frame,
frame_size,
BLOCK_SIZE,
)
return hidden_states_out
@triton.jit
def wan_sparse_head_placement_kernel(
query_ptr,
key_ptr,
value_ptr, # [cfg, num_heads, seq_len, head_dim] seq_len = context_length + num_frame * frame_size
query_out_ptr,
key_out_ptr,
value_out_ptr, # [cfg, num_heads, seq_len, head_dim]
best_mask_idx_ptr, # [cfg, num_heads]
query_stride_b,
query_stride_h,
query_stride_s,
query_stride_d,
mask_idx_stride_b,
mask_idx_stride_h,
seq_len: tl.constexpr,
head_dim: tl.constexpr,
context_length: tl.constexpr,
num_frame: tl.constexpr,
frame_size: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
# Copy query, key, value to output
# range: [b, h, block_id * block_size: block_id * block_size + block_size, :]
cfg = tl.program_id(0)
head = tl.program_id(1)
block_id = tl.program_id(2)
start_id = block_id * BLOCK_SIZE
end_id = start_id + BLOCK_SIZE
end_id = tl.where(end_id > seq_len, seq_len, end_id)
# Load best mask idx (0 is spatial, 1 is temporal)
is_temporal = tl.load(best_mask_idx_ptr + cfg * mask_idx_stride_b + head * mask_idx_stride_h)
offset_token = tl.arange(0, BLOCK_SIZE) + start_id
offset_mask = offset_token < seq_len
offset_d = tl.arange(0, head_dim)
if is_temporal:
frame_id = offset_token // frame_size
patch_id = offset_token - frame_id * frame_size
offset_store_token = tl.where(offset_token >= seq_len - context_length, offset_token, patch_id * num_frame + frame_id)
offset_load = (cfg * query_stride_b + head * query_stride_h + offset_token[:, None] * query_stride_s) + offset_d[None, :] * query_stride_d
offset_query = query_ptr + offset_load
offset_key = key_ptr + offset_load
offset_value = value_ptr + offset_load
offset_store = (cfg * query_stride_b + head * query_stride_h + offset_store_token[:, None] * query_stride_s) + offset_d[None, :] * query_stride_d
offset_query_out = query_out_ptr + offset_store
offset_key_out = key_out_ptr + offset_store
offset_value_out = value_out_ptr + offset_store
# Maybe tune the pipeline here
query = tl.load(offset_query, mask=offset_mask[:, None])
tl.store(offset_query_out, query, mask=offset_mask[:, None])
key = tl.load(offset_key, mask=offset_mask[:, None])
tl.store(offset_key_out, key, mask=offset_mask[:, None])
value = tl.load(offset_value, mask=offset_mask[:, None])
tl.store(offset_value_out, value, mask=offset_mask[:, None])
else:
offset_load = (cfg * query_stride_b + head * query_stride_h + offset_token[:, None] * query_stride_s) + offset_d[None, :] * query_stride_d
offset_query = query_ptr + offset_load
offset_key = key_ptr + offset_load
offset_value = value_ptr + offset_load
offset_store = offset_load
offset_query_out = query_out_ptr + offset_store
offset_key_out = key_out_ptr + offset_store
offset_value_out = value_out_ptr + offset_store
# Maybe tune the pipeline here
query = tl.load(offset_query, mask=offset_mask[:, None])
tl.store(offset_query_out, query, mask=offset_mask[:, None])
key = tl.load(offset_key, mask=offset_mask[:, None])
tl.store(offset_key_out, key, mask=offset_mask[:, None])
value = tl.load(offset_value, mask=offset_mask[:, None])
tl.store(offset_value_out, value, mask=offset_mask[:, None])
def wan_sparse_head_placement(query, key, value, query_out, key_out, value_out, best_mask_idx, context_length, num_frame, frame_size):
cfg, num_heads, seq_len, head_dim = query.shape
BLOCK_SIZE = 128
assert seq_len == context_length + num_frame * frame_size
grid = (cfg, num_heads, (seq_len + BLOCK_SIZE - 1) // BLOCK_SIZE)
wan_sparse_head_placement_kernel[grid](
query,
key,
value,
query_out,
key_out,
value_out,
best_mask_idx,
query.stride(0),
query.stride(1),
query.stride(2),
query.stride(3),
best_mask_idx.stride(0),
best_mask_idx.stride(1),
seq_len,
head_dim,
context_length,
num_frame,
frame_size,
BLOCK_SIZE,
)
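# Note on the two placement kernels above (restating the index math informally): for heads flagged
# temporal (best_mask_idx == 1), wan_sparse_head_placement regroups video tokens from frame-major
# order (frame_id * frame_size + patch_id) into patch-major order (patch_id * num_frame + frame_id)
# before attention, and wan_hidden_states_placement applies the inverse mapping to the attention
# output; spatial heads are copied through unchanged.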
def generate_temporal_head_mask_mod(context_length: int = 226, prompt_length: int = 226, num_frames: int = 13, token_per_frame: int = 1350, mul: int = 2):
def round_to_multiple(idx):
return ceil(idx / 128) * 128
def temporal_mask_mod(b, h, q_idx, kv_idx):
two_frame = round_to_multiple(mul * token_per_frame)
temporal_head_mask = torch.abs(q_idx - kv_idx) <= two_frame
# return temporal_head_mask
first_frame_mask = kv_idx < token_per_frame
video_mask = first_frame_mask | temporal_head_mask
return video_mask
return temporal_mask_mod
@lru_cache
def create_block_mask_cached(score_mod, B, H, M, N, device="cuda", _compile=False):
block_mask = create_block_mask(score_mod, B, H, M, N, device=device, _compile=_compile)
return block_mask
def prepare_flexattention(cfg_size, num_head, head_dim, dtype, device, context_length, prompt_length, num_frame, frame_size, diag_width=1, multiplier=2):
assert diag_width == multiplier, f"{diag_width} is not equivalent to {multiplier}"
seq_len = context_length + num_frame * frame_size
mask_mod = generate_temporal_head_mask_mod(context_length, prompt_length, num_frame, frame_size, mul=multiplier)
block_mask = create_block_mask_cached(mask_mod, None, None, seq_len, seq_len, device=device, _compile=True)
return block_mask
def sparsity_to_width(sparsity, context_length, num_frame, frame_size):
seq_len = context_length + num_frame * frame_size
total_elements = seq_len**2
sparsity = (sparsity * total_elements - 2 * seq_len * context_length) / total_elements
width = seq_len * (1 - math.sqrt(1 - sparsity))
width_frame = width / frame_size
return width_frame
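# Worked example for sparsity_to_width (illustrative numbers): with context_length=0, num_frame=13
# and frame_size=1350, seq_len = 17550 and the context correction vanishes, so for sparsity=0.25
# width = 17550 * (1 - sqrt(1 - 0.25)) ≈ 2351 tokens, i.e. a band of roughly 2351 / 1350 ≈ 1.74
# frames, which is then used as diag_width/multiplier in prepare_flexattention.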
def get_attention_mask(mask_name, sample_mse_max_row, context_length, num_frame, frame_size):
attention_mask = torch.zeros((context_length + num_frame * frame_size, context_length + num_frame * frame_size), device="cpu")
# TODO: fix hard coded mask
if mask_name == "spatial":
pixel_attn_mask = torch.zeros_like(attention_mask, dtype=torch.bool, device="cpu")
pixel_attn_mask[:, :frame_size] = 1 # First Frame Sink
block_size, block_thres = 128, frame_size * 2
num_block = math.ceil(num_frame * frame_size / block_size)
for i in range(num_block):
for j in range(num_block):
if abs(i - j) < block_thres // block_size:
pixel_attn_mask[i * block_size : (i + 1) * block_size, j * block_size : (j + 1) * block_size] = 1
attention_mask = pixel_attn_mask
else:
pixel_attn_mask = torch.zeros_like(attention_mask, dtype=torch.bool, device="cpu")
pixel_attn_mask[:, :frame_size] = 1 # First Frame Sink
block_size, block_thres = 128, frame_size * 2
num_block = math.ceil(num_frame * frame_size / block_size)
for i in range(num_block):
for j in range(num_block):
if abs(i - j) < block_thres // block_size:
pixel_attn_mask[i * block_size : (i + 1) * block_size, j * block_size : (j + 1) * block_size] = 1
pixel_attn_mask = pixel_attn_mask.reshape(frame_size, num_frame, frame_size, num_frame).permute(1, 0, 3, 2).reshape(frame_size * num_frame, frame_size * num_frame)
attention_mask = pixel_attn_mask
attention_mask = attention_mask[:sample_mse_max_row].cuda()
return attention_mask
@ATTN_WEIGHT_REGISTER("svg_attn")
class SvgAttnWeight(AttnWeightTemplate):
head_num = None
head_dim = None
sample_mse_max_row = None
num_sampled_rows = None
context_length = None
attnmap_frame_num = None
seqlen = None
sparsity = None
mask_name_list = ["spatial", "temporal"]
attention_masks = None
block_mask = None
@classmethod
def prepare(cls, head_num, head_dim, sample_mse_max_row, num_sampled_rows, context_length, sparsity):
cls.head_num = head_num
cls.head_dim = head_dim
cls.sample_mse_max_row = sample_mse_max_row
cls.num_sampled_rows = num_sampled_rows
cls.context_length = context_length
cls.sparsity = sparsity
torch._dynamo.config.cache_size_limit = 192 * 3
torch._dynamo.config.accumulated_cache_size_limit = 192 * 3
logger.info(
f"SvgAttnWeight Prepare: head_num={head_num}, head_dim={head_dim}, sample_mse_max_row={sample_mse_max_row}, num_sampled_rows={num_sampled_rows}, context_length={context_length}, sparsity={sparsity}"
)
def __init__(self):
self.config = {}
self.sparse_attention = torch.compile(flex_attention, dynamic=False, mode="max-autotune-no-cudagraphs")
@classmethod
def prepare_mask(cls, seqlen):
# Use class attributes so updates affect all instances of this class
if seqlen == cls.seqlen:
return
frame_size = seqlen // cls.attnmap_frame_num
cls.attention_masks = [get_attention_mask(mask_name, cls.sample_mse_max_row, cls.context_length, cls.attnmap_frame_num, frame_size) for mask_name in cls.mask_name_list]
multiplier = diag_width = sparsity_to_width(cls.sparsity, cls.context_length, cls.attnmap_frame_num, frame_size)
cls.block_mask = prepare_flexattention(
1, cls.head_num, cls.head_dim, torch.bfloat16, "cuda", cls.context_length, cls.context_length, cls.attnmap_frame_num, frame_size, diag_width=diag_width, multiplier=multiplier
)
cls.seqlen = seqlen
logger.info(f"SvgAttnWeight Update: seqlen={seqlen}")
def apply(
self,
q,
k,
v,
cu_seqlens_q=None,
cu_seqlens_kv=None,
max_seqlen_q=None,
max_seqlen_kv=None,
model_cls=None,
):
q = q.unsqueeze(0).transpose(1, 2)
k = k.unsqueeze(0).transpose(1, 2)
v = v.unsqueeze(0).transpose(1, 2)
bs, num_heads, seq_len, dim = q.size()
self.prepare_mask(seq_len)
sampled_mses = self.sample_mse(q, k, v)
best_mask_idx = torch.argmin(sampled_mses, dim=0)
output_hidden_states = torch.zeros_like(q)
query_out, key_out, value_out = torch.zeros_like(q), torch.zeros_like(k), torch.zeros_like(v)
query_out, key_out, value_out = self.fast_sparse_head_placement(
q, k, v, query_out, key_out, value_out, best_mask_idx, self.context_length, self.attnmap_frame_num, seq_len // self.attnmap_frame_num
)
hidden_states = self.sparse_attention(query_out, key_out, value_out)
wan_hidden_states_placement(hidden_states, output_hidden_states, best_mask_idx, self.context_length, self.attnmap_frame_num, seq_len // self.attnmap_frame_num)
return output_hidden_states.reshape(bs, num_heads, seq_len, dim).transpose(1, 2).reshape(bs * seq_len, -1)
def fast_sparse_head_placement(self, query, key, value, query_out, key_out, value_out, best_mask_idx, context_length, num_frame, frame_size):
wan_sparse_head_placement(query, key, value, query_out, key_out, value_out, best_mask_idx, context_length, num_frame, frame_size)
return query_out, key_out, value_out
def sample_mse(self, query, key, value):
cfg, num_heads, seq_len, dim = query.size()
num_sampled_rows = min(self.num_sampled_rows, seq_len)
sampled_rows = torch.randint(low=0, high=self.sample_mse_max_row, size=(num_sampled_rows,))
sampled_q = query[:, :, sampled_rows, :]
sampled_qk_scores = torch.matmul(sampled_q, key.transpose(-2, -1)) / (dim**0.5)
sampled_attn_weights = F.softmax(sampled_qk_scores, dim=-1)
sampled_golden_hidden_states = torch.matmul(sampled_attn_weights, value) # (1, seq_len, dim)
sampled_mses = torch.zeros(len(self.attention_masks), cfg, num_heads, device=query.device, dtype=query.dtype)
# Only have Tri-diagonal and Striped
for mask_idx, attn_mask in enumerate(self.attention_masks):
sampled_attention_mask = attn_mask[sampled_rows, :]
sampled_attention_scores = sampled_qk_scores.masked_fill(sampled_attention_mask == 0, float("-inf"))
sampled_attn_weights = F.softmax(sampled_attention_scores, dim=-1)
sampled_hidden_states = torch.matmul(sampled_attn_weights, value)
mse = torch.mean((sampled_hidden_states - sampled_golden_hidden_states) ** 2, dim=(2, 3))
sampled_mses[mask_idx] = mse
return sampled_mses
if __name__ == "__main__":
q, k, v = torch.randn(32130, 40, 128, dtype=torch.bfloat16).cuda(), torch.randn(32130, 40, 128, dtype=torch.bfloat16).cuda(), torch.randn(32130, 40, 128, dtype=torch.bfloat16).cuda()
SvgAttnWeight.prepare(head_num=40, head_dim=128, sample_mse_max_row=10000, num_sampled_rows=64, context_length=0, sparsity=0.25)
svg_attn = SvgAttnWeight()
print("SvgAttnWeight initialized.")
out = svg_attn.apply(q, k, v)
print(f"out: {out.shape}, {out.dtype}, {out.device}")
from abc import ABCMeta, abstractmethod
class AttnWeightTemplate(metaclass=ABCMeta):
def __init__(self, weight_name):
self.weight_name = weight_name
self.config = {}
def load(self, weight_dict):
pass
@abstractmethod
def apply(self, input_tensor):
pass
def set_config(self, config=None):
if config is not None:
self.config = config
def to_cpu(self, non_blocking=False):
pass
def to_cuda(self, non_blocking=False):
pass
def state_dict(self, destination=None):
if destination is None:
destination = {}
return destination
def load_state_dict(self, destination, block_index, adapter_block_index=None):
return {}
def load_state_dict_from_disk(self, block_index, adapter_block_index=None):
pass
import torch
import torch.nn.functional as F
from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER
from .template import AttnWeightTemplate
@ATTN_WEIGHT_REGISTER("torch_sdpa")
class TorchSDPAWeight(AttnWeightTemplate):
def __init__(self):
self.config = {}
def apply(
self,
q,
k,
v,
drop_rate=0,
attn_mask=None,
causal=False,
cu_seqlens_q=None,
cu_seqlens_kv=None,
max_seqlen_q=None,
max_seqlen_kv=None,
model_cls=None,
):
if q.ndim == 3:
q, k, v = q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0)
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
if attn_mask is not None and attn_mask.dtype != torch.bool:
attn_mask = attn_mask.to(q.dtype)
x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal)
x = x.transpose(1, 2)
b, s, a, d = x.shape
out = x.reshape(b, s, -1)
return out.squeeze(0)
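# A minimal usage sketch (illustrative shapes): TorchSDPAWeight takes packed (seq, heads, head_dim)
# tensors and returns a (seq, heads * head_dim) output.
def _demo_torch_sdpa_weight():
    attn = TorchSDPAWeight()
    q = torch.randn(1024, 8, 64)
    k = torch.randn(1024, 8, 64)
    v = torch.randn(1024, 8, 64)
    out = attn.apply(q, k, v)
    assert out.shape == (1024, 8 * 64)
    return out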
import torch
import torch.distributed as dist
from loguru import logger
from lightx2v.utils.quant_utils import dequant_fp8_vllm, quant_fp8_vllm
from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER
from lightx2v_platform.base.global_var import AI_DEVICE
from .template import AttnWeightTemplate
from .utils.all2all import all2all_head2seq, all2all_seq2head
@ATTN_WEIGHT_REGISTER("ulysses")
class UlyssesAttnWeight(AttnWeightTemplate):
def __init__(self):
self.config = {}
def apply(self, q, k, v, img_qkv_len, cu_seqlens_qkv, attention_module=None, seq_p_group=None, model_cls=None, use_fp8_comm=False):
"""
执行 Ulysses 注意力机制,结合图像和文本的查询、键和值。
参数:
q (torch.Tensor): 查询张量,形状为 [shard_seqlen, heads, hidden_dims]
k (torch.Tensor): 键张量,形状为 [shard_seqlen, heads, hidden_dims]
v (torch.Tensor): 值张量,形状为 [shard_seqlen, heads, hidden_dims]
img_qkv_len (int): 图像查询、键和值的长度
cu_seqlens_qkv (torch.Tensor): 累积序列长度,包含文本和图像的长度信息
attention_type (str): 注意力类型,默认为 "flash_attn2"
返回:
torch.Tensor: 计算得到的注意力结果
"""
if len(q.shape) == 4:
q = q.reshape(-1, q.shape[-2], q.shape[-1])
k = k.reshape(-1, k.shape[-2], k.shape[-1])
v = v.reshape(-1, v.shape[-2], v.shape[-1])
# Get the current process rank and the world size of the sequence-parallel group
world_size = dist.get_world_size(seq_p_group)
cur_rank = dist.get_rank(seq_p_group)
# Get the sequence length and the text-related lengths
seq_len = q.shape[0]
if len(cu_seqlens_qkv) == 3:
txt_qkv_len = cu_seqlens_qkv[1] - img_qkv_len  # length of the text part of q/k/v
txt_mask_len = cu_seqlens_qkv[2] - img_qkv_len  # length of the text mask
elif len(cu_seqlens_qkv) == 2:
txt_qkv_len = cu_seqlens_qkv[1] - img_qkv_len  # length of the text part of q/k/v
txt_mask_len = None
# Number of heads and hidden dim of the query tensor
_, heads, hidden_dims = q.shape
shard_heads = heads // world_size  # heads handled by each process
shard_seqlen = img_qkv_len  # sequence length handled by each process
# Split the image and text queries, keys and values
img_q, img_k, img_v = q[:img_qkv_len, :, :].contiguous(), k[:img_qkv_len, :, :].contiguous(), v[:img_qkv_len, :, :].contiguous()
txt_q, txt_k, txt_v = q[img_qkv_len:, :, :].contiguous(), k[img_qkv_len:, :, :].contiguous(), v[img_qkv_len:, :, :].contiguous()
# Convert the image q/k/v from the sequence layout to the head layout
if use_fp8_comm:
original_dtype = img_q.dtype
original_shape = img_q.shape
img_q_fp8, q_scale = quant_fp8_vllm(img_q.reshape(-1, original_shape[-1]))
img_k_fp8, k_scale = quant_fp8_vllm(img_k.reshape(-1, original_shape[-1]))
img_v_fp8, v_scale = quant_fp8_vllm(img_v.reshape(-1, original_shape[-1]))
img_q_fp8 = all2all_seq2head(img_q_fp8.reshape(original_shape), group=seq_p_group)
img_k_fp8 = all2all_seq2head(img_k_fp8.reshape(original_shape), group=seq_p_group)
img_v_fp8 = all2all_seq2head(img_v_fp8.reshape(original_shape), group=seq_p_group)
q_scale = all2all_seq2head(q_scale.reshape(original_shape[0], original_shape[1], 1), group=seq_p_group)
k_scale = all2all_seq2head(k_scale.reshape(original_shape[0], original_shape[1], 1), group=seq_p_group)
v_scale = all2all_seq2head(v_scale.reshape(original_shape[0], original_shape[1], 1), group=seq_p_group)
img_q = dequant_fp8_vllm(img_q_fp8, q_scale, original_dtype)
img_k = dequant_fp8_vllm(img_k_fp8, k_scale, original_dtype)
img_v = dequant_fp8_vllm(img_v_fp8, v_scale, original_dtype)
else:
img_q = all2all_seq2head(img_q, group=seq_p_group)
img_k = all2all_seq2head(img_k, group=seq_p_group)
img_v = all2all_seq2head(img_v, group=seq_p_group)
# Select the text q/k/v heads owned by the current process
txt_q = txt_q[:, cur_rank * shard_heads : (cur_rank + 1) * shard_heads, :]
txt_k = txt_k[:, cur_rank * shard_heads : (cur_rank + 1) * shard_heads, :]
txt_v = txt_v[:, cur_rank * shard_heads : (cur_rank + 1) * shard_heads, :]
# Concatenate the image and text queries, keys and values
q = torch.cat((img_q, txt_q), dim=0)
k = torch.cat((img_k, txt_k), dim=0)
v = torch.cat((img_v, txt_v), dim=0)
# Initialize the cumulative sequence length tensor
cu_seqlens_qkv = torch.zeros([2], dtype=torch.int32, device=AI_DEVICE)
s = txt_qkv_len + img_q.shape[0]  # total length of text plus image
s1 = s  # end position of the current sample
cu_seqlens_qkv[1] = s1  # set the cumulative sequence length
if txt_mask_len:
s2 = txt_mask_len + img_q.shape[0]  # end position of the text mask
# torch.cat expects a sequence of tensors; append s2 as a 1-element tensor
cu_seqlens_qkv = torch.cat([cu_seqlens_qkv, s2.reshape(1).to(dtype=cu_seqlens_qkv.dtype, device=cu_seqlens_qkv.device)])
max_seqlen_qkv = img_q.shape[0] + txt_q.shape[0]  # maximum sequence length
# Call the attention function to compute the attention result
# attn = attention(attention_type=attention_type, q=q, k=k, v=v, cu_seqlens_q=cu_seqlens_qkv, cu_seqlens_kv=cu_seqlens_qkv, max_seqlen_q=max_seqlen_qkv, max_seqlen_kv=max_seqlen_qkv)
attn = attention_module.apply(q=q, k=k, v=v, cu_seqlens_q=cu_seqlens_qkv, cu_seqlens_kv=cu_seqlens_qkv, max_seqlen_q=max_seqlen_qkv, max_seqlen_kv=max_seqlen_qkv, model_cls=model_cls)
# Split the attention result into image and text parts
img_attn, txt_attn = attn[: img_q.shape[0], :], attn[img_q.shape[0] :,]
# Gather the text attention results from all processes
gathered_txt_attn = [torch.empty_like(txt_attn) for _ in range(world_size)]
dist.all_gather(gathered_txt_attn, txt_attn, group=seq_p_group)
img_attn = self._reshape_img_attn(img_attn, world_size, shard_seqlen, shard_heads, hidden_dims, seq_p_group, use_fp8_comm)
txt_attn = torch.cat(gathered_txt_attn, dim=1)  # merge the text attention results from all processes
# Merge the image and text attention results
attn = torch.cat([img_attn, txt_attn], dim=0)
return attn  # return the final attention result
@torch.compiler.disable
def _reshape_img_attn(self, img_attn, world_size, shard_seqlen, shard_heads, hidden_dims, seq_p_group, use_fp8_comm):
img_attn = img_attn.reshape(world_size * shard_seqlen, shard_heads, hidden_dims)  # reshape the image attention result
# Convert from the head layout back to the sequence layout
if use_fp8_comm:
original_dtype = img_attn.dtype
original_shape = img_attn.shape
img_attn_fp8, attn_scale = quant_fp8_vllm(img_attn.reshape(-1, original_shape[-1]))
img_attn_fp8 = all2all_head2seq(img_attn_fp8.reshape(original_shape), group=seq_p_group)
attn_scale = all2all_head2seq(attn_scale.reshape(original_shape[0], original_shape[1], 1), group=seq_p_group)
img_attn = dequant_fp8_vllm(img_attn_fp8, attn_scale, original_dtype)
else:
img_attn = all2all_head2seq(img_attn, group=seq_p_group)
img_attn = img_attn.reshape(shard_seqlen, -1)  # reshape to [shard_seqlen, -1]
return img_attn
@ATTN_WEIGHT_REGISTER("ulysses-4090")
class Ulysses4090AttnWeight(AttnWeightTemplate):
def __init__(self):
self.config = {}
self.rounds = []
def generate_round_robin_pairs(self, seq_p_group=None):
"""
生成循环赛配对表,并确保每个配对中的第一个元素小于第二个
这样我们可以用简单的规则确定通信顺序
"""
cur_rank = dist.get_rank(seq_p_group)
world_size = dist.get_world_size(seq_p_group)
if world_size % 2 != 0:
raise ValueError("world_size must be even; the odd case needs special handling")
teams = list(range(world_size))
for _ in range(world_size - 1):
round_schedule = {}
for i in range(world_size // 2):
team1, team2 = teams[i], teams[world_size - 1 - i]
smaller, larger = min(team1, team2), max(team1, team2)
round_schedule[smaller] = (larger, True)
round_schedule[larger] = (smaller, False)
self.rounds.append(round_schedule)
# Rotate the list (keeping the first element fixed)
teams = [teams[0]] + [teams[-1]] + teams[1:-1]
# if cur_rank == 0:
# self.print_pairing_schedule(seq_p_group)
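# Example schedule produced for world_size = 4 (three rounds; every pair of ranks meets exactly once):
#   round 1: 0 <-> 3, 1 <-> 2
#   round 2: 0 <-> 2, 1 <-> 3
#   round 3: 0 <-> 1, 2 <-> 3
# Within each pair the smaller rank sends first and then receives while the larger rank does the
# opposite, giving both sides an agreed ordering of the paired send/recv calls.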
def print_pairing_schedule(self, seq_p_group):
"""打印通信调度表"""
world_size = dist.get_world_size(seq_p_group)
logger.info("循环赛通信调度表:")
logger.info("=" * 50)
for i, round_schedule in enumerate(self.rounds):
logger.info(f"第 {i + 1} 轮:")
for cur_rank in range(world_size):
partner, is_smaller_in_pair = round_schedule[cur_rank]
logger.info(f" 进程 {cur_rank} ←→ 进程 {partner}")
logger.info("=" * 50)
def load_balanced_all_to_all(self, shards, seq_p_group=None):
"""
负载均衡all-to-all通信实现
"""
world_size = dist.get_world_size(seq_p_group)
cur_rank = dist.get_rank(seq_p_group)
global_rank = dist.get_global_rank(seq_p_group, cur_rank)
cfg_p_group_index = global_rank // world_size
# Prepare the receive buffers
gathered_shards = [None] * world_size
for target_rank in range(world_size):
if target_rank != cur_rank:
gathered_shards[target_rank] = torch.empty_like(shards[target_rank])
else:
gathered_shards[cur_rank] = shards[cur_rank]
for i, round_schedule in enumerate(self.rounds):
# Find the partner of the current process in this round
partner = None
is_smaller_in_pair = False
if cur_rank in round_schedule:
partner, is_smaller_in_pair = round_schedule[cur_rank]
# If no partner is found, the current process is idle this round
if partner is None:
continue
# Compute the partner's global rank
partner_global_rank = cfg_p_group_index * world_size + partner
if is_smaller_in_pair:
# The current process is the smaller rank in the pair: send first, then receive
send_req = dist.isend(shards[partner], dst=partner_global_rank, group=seq_p_group)
recv_req = dist.irecv(gathered_shards[partner], src=partner_global_rank, group=seq_p_group)
send_req.wait()
recv_req.wait()
else:
# The current process is the larger rank in the pair: receive first, then send
recv_req = dist.irecv(gathered_shards[partner], src=partner_global_rank, group=seq_p_group)
send_req = dist.isend(shards[partner], dst=partner_global_rank, group=seq_p_group)
recv_req.wait()
send_req.wait()
return gathered_shards
def apply(self, q, k, v, img_qkv_len, cu_seqlens_qkv, attention_module=None, seq_p_group=None, model_cls=None, use_fp8_comm=False):
"""
执行 Ulysses 注意力机制,结合图像和文本的查询、键和值。
参数:
q (torch.Tensor): 查询张量,形状为 [shard_seqlen, heads, hidden_dims]
k (torch.Tensor): 键张量,形状为 [shard_seqlen, heads, hidden_dims]
v (torch.Tensor): 值张量,形状为 [shard_seqlen, heads, hidden_dims]
img_qkv_len (int): 图像查询、键和值的长度
cu_seqlens_qkv (torch.Tensor): 累积序列长度,包含文本和图像的长度信息
attention_type (str): 注意力类型,默认为 "flash_attn2"
返回:
torch.Tensor: 计算得到的注意力结果
"""
if len(self.rounds) == 0:
self.generate_round_robin_pairs(seq_p_group)
if len(q.shape) == 4:
q = q.reshape(-1, q.shape[-2], q.shape[-1])
k = k.reshape(-1, k.shape[-2], k.shape[-1])
v = v.reshape(-1, v.shape[-2], v.shape[-1])
# Get the current process rank and the world size of the sequence-parallel group
world_size = dist.get_world_size(seq_p_group)
cur_rank = dist.get_rank(seq_p_group)
global_world_size = dist.get_world_size()
global_rank = dist.get_global_rank(seq_p_group, cur_rank)
cfg_p_group_index = global_rank // world_size
# Get the sequence length and the text-related lengths
seq_len = q.shape[0]
if len(cu_seqlens_qkv) == 3:
txt_qkv_len = cu_seqlens_qkv[1] - img_qkv_len  # length of the text part of q/k/v
txt_mask_len = cu_seqlens_qkv[2] - img_qkv_len  # length of the text mask
elif len(cu_seqlens_qkv) == 2:
txt_qkv_len = cu_seqlens_qkv[1] - img_qkv_len  # length of the text part of q/k/v
txt_mask_len = None
# Number of heads and hidden dim of the query tensor
_, heads, hidden_dims = q.shape
shard_heads = heads // world_size  # heads handled by each process
shard_seqlen = img_qkv_len  # sequence length handled by each process
# Split the image and text queries, keys and values
img_q, img_k, img_v = q[:img_qkv_len, :, :].contiguous(), k[:img_qkv_len, :, :].contiguous(), v[:img_qkv_len, :, :].contiguous()
txt_q, txt_k, txt_v = q[img_qkv_len:, :, :].contiguous(), k[img_qkv_len:, :, :].contiguous(), v[img_qkv_len:, :, :].contiguous()
# Number of head shards each process should hold
num_heads = img_q.shape[1]
shard_heads = num_heads // world_size
# Stack the image QKV, then split it along the head dimension into N shards, each of size D/N
img_qkv = torch.stack([img_q, img_k, img_v], dim=0)
qkv_shards = [img_qkv[:, :, i * shard_heads : (i + 1) * shard_heads, :].contiguous() for i in range(world_size)]
qkv_dtype = img_qkv.dtype
if use_fp8_comm:
qkv_fp8_byte_tensors = []
qkv_fp8_bytes = 0
qkv_fp8_dtype = None
qkv_scale_dtype = None
for i in range(world_size):
qkv_fp8, qkv_scale = quant_fp8_vllm(qkv_shards[i].reshape(-1, hidden_dims))
if i == 0:
qkv_fp8_bytes = qkv_fp8.numel() * qkv_fp8.element_size()
qkv_fp8_dtype = qkv_fp8.dtype
qkv_scale_dtype = qkv_scale.dtype
qkv_fp8_byte_tensors.append(torch.cat([qkv_fp8.contiguous().reshape(-1).view(torch.uint8), qkv_scale.contiguous().reshape(-1).view(torch.uint8)], dim=0))
gathered_qkv_fp8_byte_tensors = self.load_balanced_all_to_all(qkv_fp8_byte_tensors, seq_p_group)
gathered_q_shards = []
gathered_k_shards = []
gathered_v_shards = []
for i in range(world_size):
qkv_fp8_byte_tensor = gathered_qkv_fp8_byte_tensors[i]
qkv_fp8 = qkv_fp8_byte_tensor[:qkv_fp8_bytes].view(qkv_fp8_dtype).reshape(3, -1, hidden_dims)
qkv_scale = qkv_fp8_byte_tensor[qkv_fp8_bytes:].view(qkv_scale_dtype).reshape(3, -1, 1)
q_shards_new = dequant_fp8_vllm(qkv_fp8[0], qkv_scale[0], qkv_dtype).reshape(-1, shard_heads, hidden_dims)
k_shards_new = dequant_fp8_vllm(qkv_fp8[1], qkv_scale[1], qkv_dtype).reshape(-1, shard_heads, hidden_dims)
v_shards_new = dequant_fp8_vllm(qkv_fp8[2], qkv_scale[2], qkv_dtype).reshape(-1, shard_heads, hidden_dims)
gathered_q_shards.append(q_shards_new)
gathered_k_shards.append(k_shards_new)
gathered_v_shards.append(v_shards_new)
else:
gathered_qkv_byte_tensors = self.load_balanced_all_to_all(qkv_shards, seq_p_group)
gathered_q_shards = []
gathered_k_shards = []
gathered_v_shards = []
for i in range(world_size):
qkv_tensor = gathered_qkv_byte_tensors[i].view(qkv_dtype).reshape(3, -1, shard_heads, hidden_dims)
gathered_q_shards.append(qkv_tensor[0])
gathered_k_shards.append(qkv_tensor[1])
gathered_v_shards.append(qkv_tensor[2])
# Concatenate all shards (along the sequence dimension)
# Each gathered_*_shards[i] has shape (seq_len/N, num_heads/N, head_dim)
# After concatenation the shape is (seq_len, num_heads/N, head_dim)
img_q = torch.cat(gathered_q_shards, dim=0)
img_k = torch.cat(gathered_k_shards, dim=0)
img_v = torch.cat(gathered_v_shards, dim=0)
# Select the text q/k/v heads owned by the current process
txt_q = txt_q[:, cur_rank * shard_heads : (cur_rank + 1) * shard_heads, :]
txt_k = txt_k[:, cur_rank * shard_heads : (cur_rank + 1) * shard_heads, :]
txt_v = txt_v[:, cur_rank * shard_heads : (cur_rank + 1) * shard_heads, :]
# Concatenate the image and text queries, keys and values
q = torch.cat((img_q, txt_q), dim=0)
k = torch.cat((img_k, txt_k), dim=0)
v = torch.cat((img_v, txt_v), dim=0)
# Initialize the cumulative sequence length tensor
cu_seqlens_qkv = torch.zeros([2], dtype=torch.int32, device="cuda")
s = txt_qkv_len + img_q.shape[0]  # total length of text plus image
s1 = s  # end position of the current sample
cu_seqlens_qkv[1] = s1  # set the cumulative sequence length
if txt_mask_len:
s2 = txt_mask_len + img_q.shape[0]  # end position of the text mask
# torch.cat expects a sequence of tensors; append s2 as a 1-element tensor
cu_seqlens_qkv = torch.cat([cu_seqlens_qkv, s2.reshape(1).to(dtype=cu_seqlens_qkv.dtype, device=cu_seqlens_qkv.device)])
max_seqlen_qkv = img_q.shape[0] + txt_q.shape[0]  # maximum sequence length
# Call the attention function to compute the attention result
# attn = attention(attention_type=attention_type, q=q, k=k, v=v, cu_seqlens_q=cu_seqlens_qkv, cu_seqlens_kv=cu_seqlens_qkv, max_seqlen_q=max_seqlen_qkv, max_seqlen_kv=max_seqlen_qkv)
attn = attention_module.apply(q=q, k=k, v=v, cu_seqlens_q=cu_seqlens_qkv, cu_seqlens_kv=cu_seqlens_qkv, max_seqlen_q=max_seqlen_qkv, max_seqlen_kv=max_seqlen_qkv, model_cls=model_cls)
# Split the attention result into image and text parts
img_attn, txt_attn = attn[: img_q.shape[0], :], attn[img_q.shape[0] :,]
# Gather the text attention results from all processes
gathered_txt_attn = [torch.empty_like(txt_attn) for _ in range(world_size)]
dist.all_gather(gathered_txt_attn, txt_attn, group=seq_p_group)
img_attn = self._reshape_img_attn(img_attn, world_size, shard_seqlen, shard_heads, hidden_dims, seq_p_group, use_fp8_comm)
txt_attn = torch.cat(gathered_txt_attn, dim=1)  # merge the text attention results from all processes
# Merge the image and text attention results
attn = torch.cat([img_attn, txt_attn], dim=0)
return attn  # return the final attention result
@torch.compiler.disable
def _reshape_img_attn(self, img_attn, world_size, shard_seqlen, shard_heads, hidden_dims, seq_p_group, use_fp8_comm):
cur_rank = dist.get_rank(seq_p_group)
global_world_size = dist.get_world_size()
global_rank = dist.get_global_rank(seq_p_group, cur_rank)
cfg_p_group_index = global_rank // world_size
img_attn = img_attn.reshape(world_size * shard_seqlen, shard_heads, hidden_dims)  # reshape the image attention result
attn_dtype = img_attn.dtype
# Split into N shards along the sequence dimension
attn_shards = [img_attn[i * shard_seqlen : (i + 1) * shard_seqlen, :, :].contiguous() for i in range(world_size)]
if use_fp8_comm:
attn_fp8_byte_tensors = []
attn_fp8_bytes = 0
attn_fp8_dtype = None
attn_scale_dtype = None
for i in range(world_size):
attn_fp8, attn_scale = quant_fp8_vllm(attn_shards[i].reshape(-1, hidden_dims))
if i == 0:
attn_fp8_bytes = attn_fp8.numel() * attn_fp8.element_size()
attn_fp8_dtype = attn_fp8.dtype
attn_scale_dtype = attn_scale.dtype
attn_fp8_byte_tensors.append(torch.cat([attn_fp8.contiguous().reshape(-1).view(torch.uint8), attn_scale.contiguous().reshape(-1).view(torch.uint8)], dim=0))
gathered_attn_fp8_byte_tensors = self.load_balanced_all_to_all(attn_fp8_byte_tensors, seq_p_group)
gathered_attn_shards = []
for i in range(world_size):
attn_fp8_byte_tensor = gathered_attn_fp8_byte_tensors[i]
attn_fp8 = attn_fp8_byte_tensor[:attn_fp8_bytes].view(attn_fp8_dtype).reshape(-1, hidden_dims)
attn_scale = attn_fp8_byte_tensor[attn_fp8_bytes:].view(attn_scale_dtype).reshape(-1, 1)
attn_shards_new = dequant_fp8_vllm(attn_fp8, attn_scale, attn_dtype).reshape(-1, shard_heads, hidden_dims)
gathered_attn_shards.append(attn_shards_new)
else:
gathered_attn_shards = self.load_balanced_all_to_all(attn_shards, seq_p_group)
# Concatenate all shards (along the head dimension)
img_attn = torch.cat(gathered_attn_shards, dim=1)
img_attn = img_attn.reshape(shard_seqlen, -1)  # reshape to [shard_seqlen, -1]
return img_attn
import torch
import torch._dynamo as dynamo
import torch.distributed as dist
@dynamo.disable
def all2all_seq2head(input, group=None):
"""
将输入张量从 [seq_len/N, heads, hidden_dims] 转换为 [seq_len, heads/N, hidden_dims] 的格式。
参数:
input (torch.Tensor): 输入张量,形状为 [seq_len/N, heads, hidden_dims]
返回:
torch.Tensor: 转换后的输出张量,形状为 [seq_len, heads/N, hidden_dims]
"""
# 确保输入是一个3D张量
assert input.dim() == 3, f"input must be 3D tensor"
# 获取当前进程的世界大小
world_size = dist.get_world_size(group=group)
# 获取输入张量的形状
shard_seq_len, heads, hidden_dims = input.shape
seq_len = shard_seq_len * world_size # 计算总序列长度
shard_heads = heads // world_size # 计算每个进程处理的头数
# 重塑输入张量以便进行 all-to-all 操作
input_t = (
input.reshape(shard_seq_len, world_size, shard_heads, hidden_dims) # 重塑为 [shard_seq_len, world_size, shard_heads, hidden_dims]
.transpose(0, 1) # 转置以便进行 all-to-all 操作
.contiguous() # 确保内存连续
)
# 创建一个与输入张量相同形状的输出张量
output = torch.empty_like(input_t)
# 执行 all-to-all 操作,将输入张量的内容分发到所有进程
dist.all_to_all_single(output, input_t, group=group)
# 重塑输出张量为 [seq_len, heads/N, hidden_dims] 形状
output = output.reshape(seq_len, shard_heads, hidden_dims).contiguous()
return output # 返回转换后的输出张量
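# Shape-flow note for all2all_seq2head (derived from the reshapes above): with world size N, a
# per-rank [S/N, H, D] input is viewed as [S/N, N, H/N, D], transposed to [N, S/N, H/N, D],
# exchanged with all_to_all_single, and reassembled into [S, H/N, D] -- each rank trades its
# sequence shard of every head for a head shard of the full sequence.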
@dynamo.disable
def all2all_head2seq(input, group=None):
"""
将输入张量从 [seq_len, heads/N, hidden_dims] 转换为 [seq_len/N, heads, hidden_dims] 的格式。
参数:
input (torch.Tensor): 输入张量,形状为 [seq_len, heads/N, hidden_dims]
返回:
torch.Tensor: 转换后的输出张量,形状为 [seq_len/N, heads, hidden_dims]
"""
# 确保输入是一个3D张量
assert input.dim() == 3, f"input must be 3D tensor"
# 获取当前进程的世界大小
world_size = dist.get_world_size(group=group)
# 获取输入张量的形状
seq_len, shard_heads, hidden_dims = input.shape
heads = shard_heads * world_size # 计算总头数
shard_seq_len = seq_len // world_size # 计算每个进程处理的序列长度
# 重塑输入张量以便进行 all-to-all 操作
input_t = (
input.reshape(world_size, shard_seq_len, shard_heads, hidden_dims) # 重塑为 [world_size, shard_seq_len, shard_heads, hidden_dims]
.transpose(1, 2) # 转置以便进行 all-to-all 操作
.contiguous() # 确保内存连续
.reshape(world_size, shard_heads, shard_seq_len, hidden_dims) # 再次重塑为 [world_size, shard_heads, shard_seq_len, hidden_dims]
)
# 创建一个与输入张量相同形状的输出张量
output = torch.empty_like(input_t)
# 执行 all-to-all 操作,将输入张量的内容分发到所有进程
dist.all_to_all_single(output, input_t, group=group)
# 重塑输出张量为 [heads, shard_seq_len, hidden_dims] 形状
output = output.reshape(heads, shard_seq_len, hidden_dims)
# 转置输出张量并重塑为 [shard_seq_len, heads, hidden_dims] 形状
output = output.transpose(0, 1).contiguous().reshape(shard_seq_len, heads, hidden_dims)
return output # 返回转换后的输出张量
from typing import Optional
import torch
import torch.distributed as dist
class RingComm:
def __init__(self, process_group: dist.ProcessGroup = None):
self._process_group = process_group
self._ops = []
self.rank = dist.get_rank(self._process_group)
self.world_size = dist.get_world_size(self._process_group)
self._reqs = None
self.send_rank = (self.rank + 1) % self.world_size
self.recv_rank = (self.rank - 1) % self.world_size
if process_group is not None:
self.send_rank = dist.get_global_rank(self._process_group, self.send_rank)
self.recv_rank = dist.get_global_rank(self._process_group, self.recv_rank)
def send_recv(self, to_send: torch.Tensor, recv_tensor: Optional[torch.Tensor] = None) -> torch.Tensor:
if recv_tensor is None:
res = torch.empty_like(to_send)
# logger.info(f"send_recv: empty_like {to_send.shape}")
else:
res = recv_tensor
send_op = dist.P2POp(dist.isend, to_send, self.send_rank, group=self._process_group)
recv_op = dist.P2POp(dist.irecv, res, self.recv_rank, group=self._process_group)
self._ops.append(send_op)
self._ops.append(recv_op)
return res
def commit(self):
if self._reqs is not None:
raise RuntimeError("commit called twice")
self._reqs = dist.batch_isend_irecv(self._ops)
def wait(self):
if self._reqs is None:
raise RuntimeError("wait called before commit")
for req in self._reqs:
req.wait()
self._reqs = None
self._ops = []
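# --- Illustrative usage sketch (not part of the original file) ---
# Minimal ring-exchange loop built on RingComm, assuming torch.distributed is
# initialized and every rank holds a tensor of identical shape and dtype. Each
# step sends the current buffer to rank+1 and receives the buffer of rank-1.
# The helper name _ring_rotate is hypothetical.
def _ring_rotate(local_tensor: torch.Tensor, steps: int, group=None) -> torch.Tensor:
    comm = RingComm(group)
    buf = local_tensor
    for _ in range(steps):
        nxt = comm.send_recv(buf)  # queue the isend/irecv pair
        comm.commit()              # launch the batched P2P ops
        comm.wait()                # block until this step completes
        buf = nxt
    return buf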
from .conv2d import *
from .conv3d import *
from abc import ABCMeta, abstractmethod
import torch
from lightx2v.utils.registry_factory import CONV2D_WEIGHT_REGISTER
from lightx2v_platform.base.global_var import AI_DEVICE
class Conv2dWeightTemplate(metaclass=ABCMeta):
def __init__(self, weight_name, bias_name, stride, padding, dilation, groups):
self.weight_name = weight_name
self.bias_name = bias_name
self.stride = stride
self.padding = padding
self.dilation = dilation
self.groups = groups
self.config = {}
@abstractmethod
def load(self, weight_dict):
pass
@abstractmethod
def apply(self, input_tensor):
pass
def set_config(self, config=None):
if config is not None:
self.config = config
@CONV2D_WEIGHT_REGISTER("Default")
class Conv2dWeight(Conv2dWeightTemplate):
def __init__(self, weight_name, bias_name, stride=1, padding=0, dilation=1, groups=1):
super().__init__(weight_name, bias_name, stride, padding, dilation, groups)
def load(self, weight_dict):
self.weight = weight_dict[self.weight_name].to(AI_DEVICE)
self.bias = weight_dict[self.bias_name].to(AI_DEVICE) if self.bias_name is not None else None
def apply(self, input_tensor):
input_tensor = torch.nn.functional.conv2d(input_tensor, weight=self.weight, bias=self.bias, stride=self.stride, padding=self.padding, dilation=self.dilation, groups=self.groups)
return input_tensor
def to_cpu(self, non_blocking=False):
self.weight = self.weight.cpu(non_blocking=non_blocking)
if self.bias is not None:
self.bias = self.bias.cpu(non_blocking=non_blocking)
def to_cuda(self, non_blocking=False):
self.weight = self.weight.to(AI_DEVICE, non_blocking=non_blocking)
if self.bias is not None:
self.bias = self.bias.to(AI_DEVICE, non_blocking=non_blocking)
def state_dict(self, destination=None):
if destination is None:
destination = {}
destination[self.weight_name] = self.weight.cpu().detach().clone()
if self.bias is not None:
destination[self.bias_name] = self.bias.cpu().detach().clone()
return destination
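# --- Illustrative usage sketch (not part of the original file) ---
# Typical life cycle of the default Conv2dWeight: load() moves tensors to
# AI_DEVICE, apply() runs F.conv2d. The weight names and shapes below are
# hypothetical; a CUDA (or other AI_DEVICE) backend is assumed.
def _conv2d_weight_example():
    conv_w = Conv2dWeight("conv.weight", "conv.bias", stride=1, padding=1)
    conv_w.load({"conv.weight": torch.randn(8, 3, 3, 3), "conv.bias": torch.randn(8)})
    x = torch.randn(1, 3, 32, 32).to(AI_DEVICE)
    return conv_w.apply(x)  # -> [1, 8, 32, 32]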
from abc import ABCMeta, abstractmethod
import torch
from lightx2v.utils.registry_factory import CONV3D_WEIGHT_REGISTER
from lightx2v_platform.base.global_var import AI_DEVICE
class Conv3dWeightTemplate(metaclass=ABCMeta):
def __init__(self, weight_name, bias_name, stride=1, padding=0, dilation=1, groups=1):
self.weight_name = weight_name
self.bias_name = bias_name
self.stride = stride
self.padding = padding
self.dilation = dilation
self.groups = groups
self.config = {}
@abstractmethod
def load(self, weight_dict):
pass
@abstractmethod
def apply(self, input_tensor):
pass
def set_config(self, config=None):
if config is not None:
self.config = config
@CONV3D_WEIGHT_REGISTER("Default")
class Conv3dWeight(Conv3dWeightTemplate):
def __init__(self, weight_name, bias_name, stride=1, padding=0, dilation=1, groups=1):
super().__init__(weight_name, bias_name, stride, padding, dilation, groups)
def load(self, weight_dict):
device = weight_dict[self.weight_name].device
if device.type == "cpu":
weight_shape = weight_dict[self.weight_name].shape
weight_dtype = weight_dict[self.weight_name].dtype
self.pin_weight = torch.empty(weight_shape, pin_memory=True, dtype=weight_dtype)
self.pin_weight.copy_(weight_dict[self.weight_name])
if self.bias_name is not None:
bias_shape = weight_dict[self.bias_name].shape
bias_dtype = weight_dict[self.bias_name].dtype
self.pin_bias = torch.empty(bias_shape, pin_memory=True, dtype=bias_dtype)
self.pin_bias.copy_(weight_dict[self.bias_name])
else:
self.bias = None
self.pin_bias = None
del weight_dict[self.weight_name]
else:
self.weight = weight_dict[self.weight_name]
if self.bias_name is not None:
self.bias = weight_dict[self.bias_name]
else:
self.bias = None
def apply(self, input_tensor):
input_tensor = torch.nn.functional.conv3d(
input_tensor,
weight=self.weight,
bias=self.bias,
stride=self.stride,
padding=self.padding,
dilation=self.dilation,
groups=self.groups,
)
return input_tensor
def to_cuda(self, non_blocking=False):
self.weight = self.pin_weight.to(AI_DEVICE, non_blocking=non_blocking)
if hasattr(self, "pin_bias") and self.pin_bias is not None:
self.bias = self.pin_bias.to(AI_DEVICE, non_blocking=non_blocking)
def to_cpu(self, non_blocking=False):
if hasattr(self, "pin_weight"):
self.weight = self.pin_weight.copy_(self.weight, non_blocking=non_blocking).cpu()
if self.bias is not None:
self.bias = self.pin_bias.copy_(self.bias, non_blocking=non_blocking).cpu()
else:
self.weight = self.weight.to("cpu", non_blocking=non_blocking)
if hasattr(self, "bias") and self.bias is not None:
self.bias = self.bias.to("cpu", non_blocking=non_blocking)
def state_dict(self, destination=None):
if destination is None:
destination = {}
destination[self.weight_name] = self.pin_weight if hasattr(self, "pin_weight") else self.weight # .cpu().detach().clone().contiguous()
if self.bias_name is not None:
destination[self.bias_name] = self.pin_bias if hasattr(self, "pin_bias") else self.bias # .cpu().detach().clone()
return destination
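# --- Illustrative usage sketch (not part of the original file) ---
# The pinned-buffer offload pattern behind Conv3dWeight: load() keeps a pinned
# CPU copy, to_cuda() materializes the device weight, to_cpu() writes it back
# into the pinned buffer. The helper name _conv3d_offload_step is hypothetical
# and assumes the weights were loaded from CPU tensors.
def _conv3d_offload_step(conv3d_weight, frames):
    conv3d_weight.to_cuda(non_blocking=True)  # pinned CPU buffer -> AI_DEVICE
    out = conv3d_weight.apply(frames)         # run F.conv3d on the device
    conv3d_weight.to_cpu(non_blocking=True)   # device tensor -> pinned CPU buffer
    return out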
from .embedding_weight import *
import re
from abc import ABCMeta
import torch
import torch.nn.functional as F
from lightx2v.utils.registry_factory import EMBEDDING_WEIGHT_REGISTER
from lightx2v_platform.base.global_var import AI_DEVICE
class EmbeddingWeightTemplate(metaclass=ABCMeta):
def __init__(self, weight_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
self.weight_name = weight_name
self.create_cuda_buffer = create_cuda_buffer
self.create_cpu_buffer = create_cpu_buffer
self.lazy_load = lazy_load
self.lazy_load_file = lazy_load_file
self.is_post_adapter = is_post_adapter
self.config = {}
def load(self, weight_dict):
if not self.lazy_load:
if self.create_cuda_buffer:
self.weight_cuda_buffer = weight_dict[self.weight_name].to(AI_DEVICE)
else:
device = weight_dict[self.weight_name].device
if device.type == "cpu":
weight_shape = weight_dict[self.weight_name].shape
weight_dtype = weight_dict[self.weight_name].dtype
self.pin_weight = torch.empty(weight_shape, pin_memory=True, dtype=weight_dtype)
self.pin_weight.copy_(weight_dict[self.weight_name])
del weight_dict[self.weight_name]
else:
self.weight = weight_dict[self.weight_name]
def to_cuda(self, non_blocking=False):
self.weight = self.pin_weight.to(AI_DEVICE, non_blocking=non_blocking)
def to_cpu(self, non_blocking=False):
if hasattr(self, "pin_weight"):
self.weight = self.pin_weight.copy_(self.weight, non_blocking=non_blocking).cpu()
else:
self.weight = self.weight.to("cpu", non_blocking=non_blocking)
def state_dict(self, destination=None):
if destination is None:
destination = {}
destination[self.weight_name] = self.pin_weight if hasattr(self, "pin_weight") else self.weight
return destination
def load_state_dict(self, destination, block_index, adapter_block_index=None):
if self.is_post_adapter:
assert adapter_block_index is not None
weight_name = re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", self.weight_name, count=1)
else:
weight_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.weight_name, count=1)
if weight_name not in destination:
self.weight = None
return
self.weight = self.weight_cuda_buffer.copy_(destination[weight_name], non_blocking=True)
@EMBEDDING_WEIGHT_REGISTER("Default")
class EmbeddingWeight(EmbeddingWeightTemplate):
def __init__(self, weight_name=None, lazy_load=False, lazy_load_file=None):
super().__init__(weight_name, lazy_load=lazy_load, lazy_load_file=lazy_load_file)
def apply(self, input_indices):
output = F.embedding(input=input_indices, weight=self.weight, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False)
return output
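# --- Illustrative usage sketch (not part of the original file) ---
# Hypothetical EmbeddingWeight usage: load() pins a CPU copy of the [vocab, dim]
# table, to_cuda() materializes it on AI_DEVICE, apply() is a plain F.embedding
# lookup. Names and shapes are made up for illustration.
def _embedding_example():
    emb = EmbeddingWeight("token_embedding.weight")
    emb.load({"token_embedding.weight": torch.randn(1000, 64)})
    emb.to_cuda()
    ids = torch.tensor([1, 5, 7], device=AI_DEVICE)
    return emb.apply(ids)  # -> [3, 64]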
from .mm_weight import *
import os
import re
from abc import ABCMeta, abstractmethod
from pathlib import Path
import torch
from safetensors import safe_open
from lightx2v.utils.envs import *
from lightx2v.utils.ggml_tensor import GGMLTensor
from lightx2v.utils.ggml_tensor import dequantize_tensor as gguf_dequantize_tensor
from lightx2v.utils.global_paras import CALIB
from lightx2v.utils.quant_utils import FloatQuantizer, IntegerQuantizer
from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER
from lightx2v_platform.base.global_var import AI_DEVICE
try:
from lightx2v_kernel.gemm import (
cutlass_scaled_mxfp4_mm,
cutlass_scaled_mxfp6_mxfp8_mm,
cutlass_scaled_mxfp8_mm,
cutlass_scaled_nvfp4_mm,
scaled_mxfp4_quant,
scaled_mxfp6_quant,
scaled_mxfp8_quant,
scaled_nvfp4_quant,
)
except ImportError:
scaled_nvfp4_quant, cutlass_scaled_nvfp4_mm = None, None
scaled_mxfp4_quant, cutlass_scaled_mxfp4_mm = None, None
scaled_mxfp6_quant, cutlass_scaled_mxfp6_mxfp8_mm = None, None
scaled_mxfp8_quant, cutlass_scaled_mxfp8_mm = None, None
try:
from vllm import _custom_ops as ops
except ImportError:
ops = None
try:
import sgl_kernel
except ImportError:
sgl_kernel = None
try:
from q8_kernels.functional.linear import q8_linear
except ImportError:
q8_linear = None
try:
from q8_kernels.functional.linear import fp8_linear
except ImportError:
fp8_linear = None
try:
import deep_gemm
except ImportError:
deep_gemm = None
try:
from torchao.quantization.utils import quant_int8_per_token_matmul, quantize_activation_per_token_absmax
except ImportError:
quant_int8_per_token_matmul, quantize_activation_per_token_absmax = None, None
try:
import gguf
except ImportError:
gguf = None
try:
import marlin_cuda_quant
except ImportError:
marlin_cuda_quant = None
class MMWeightTemplate(metaclass=ABCMeta):
def __init__(self, weight_name, bias_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
self.weight_name = weight_name
self.bias_name = bias_name
self.create_cuda_buffer = create_cuda_buffer
self.create_cpu_buffer = create_cpu_buffer
self.lazy_load = lazy_load
self.lazy_load_file = lazy_load_file
self.is_post_adapter = is_post_adapter
self.config = {}
@abstractmethod
def load(self, weight_dict):
pass
@abstractmethod
def apply(self):
pass
def set_config(self, config={}):
self.config = config
def to_cuda(self, non_blocking=False):
self.weight = self.pin_weight.to(AI_DEVICE, non_blocking=non_blocking)
if hasattr(self, "pin_weight_scale"):
self.weight_scale = self.pin_weight_scale.to(AI_DEVICE, non_blocking=non_blocking)
if hasattr(self, "pin_bias") and self.pin_bias is not None:
self.bias = self.pin_bias.to(AI_DEVICE, non_blocking=non_blocking)
def to_cpu(self, non_blocking=False):
if hasattr(self, "pin_weight"):
self.weight = self.pin_weight.copy_(self.weight, non_blocking=non_blocking).cpu()
if hasattr(self, "weight_scale_name"):
self.weight_scale = self.pin_weight_scale.copy_(self.weight_scale, non_blocking=non_blocking).cpu()
if self.bias is not None:
self.bias = self.pin_bias.copy_(self.bias, non_blocking=non_blocking).cpu()
else:
self.weight = self.weight.to("cpu", non_blocking=non_blocking)
if hasattr(self, "weight_scale"):
self.weight_scale = self.weight_scale.to("cpu", non_blocking=non_blocking)
if hasattr(self, "bias") and self.bias is not None:
self.bias = self.bias.to("cpu", non_blocking=non_blocking)
@MM_WEIGHT_REGISTER("Default")
class MMWeight(MMWeightTemplate):
def __init__(self, weight_name, bias_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
super().__init__(weight_name, bias_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter)
def load(self, weight_dict):
if self.create_cuda_buffer:
self._load_cuda_buffers(weight_dict)
elif self.create_cpu_buffer:
self._load_cpu_pin_buffers()
else:
self._load_default_tensors(weight_dict)
def _get_source_tensor(self, source_name, weight_dict=None):
if self.lazy_load:
if Path(self.lazy_load_file).is_file():
lazy_load_file_path = self.lazy_load_file
else:
lazy_load_file_path = os.path.join(self.lazy_load_file, f"block_{source_name.split('.')[1]}.safetensors")
with safe_open(lazy_load_file_path, framework="pt", device="cpu") as lazy_load_file:
return lazy_load_file.get_tensor(source_name)
return weight_dict[source_name]
def _create_pin_tensor(self, tensor, transpose=False):
pin_tensor = torch.empty(tensor.shape, pin_memory=True, dtype=tensor.dtype)
pin_tensor = pin_tensor.copy_(tensor)
if transpose:
pin_tensor = pin_tensor.t()
del tensor
return pin_tensor
def _load_cuda_buffers(self, weight_dict):
self.weight_cuda_buffer = self._get_source_tensor(self.weight_name, weight_dict).t().to(AI_DEVICE)
if self.bias_name is not None:
self.bias_cuda_buffer = self._get_source_tensor(self.bias_name, weight_dict).to(AI_DEVICE)
def _load_cpu_pin_buffers(self):
if self.lazy_load:
if Path(self.lazy_load_file).is_file():
lazy_load_file_path = self.lazy_load_file
else:
lazy_load_file_path = os.path.join(self.lazy_load_file, f"block_{self.weight_name.split('.')[1]}.safetensors")
with safe_open(lazy_load_file_path, framework="pt", device="cpu") as lazy_load_file:
weight_tensor = lazy_load_file.get_tensor(self.weight_name)
self.pin_weight = self._create_pin_tensor(weight_tensor, transpose=True)
if self.bias_name is not None:
bias_tensor = lazy_load_file.get_tensor(self.bias_name)
self.pin_bias = self._create_pin_tensor(bias_tensor)
else:
self.bias = None
self.pin_bias = None
def _load_default_tensors(self, weight_dict):
if not self.lazy_load:
device = weight_dict[self.weight_name].device
if device.type == "cpu":
weight_tensor = weight_dict[self.weight_name]
self.pin_weight = self._create_pin_tensor(weight_tensor, transpose=True)
if self.bias_name is not None:
bias_tensor = weight_dict[self.bias_name]
self.pin_bias = self._create_pin_tensor(bias_tensor)
else:
self.bias = None
self.pin_bias = None
del weight_dict[self.weight_name]
else:
self.weight = weight_dict[self.weight_name].t()
self.bias = weight_dict[self.bias_name] if self.bias_name is not None else None
def apply(self, input_tensor):
shape = (input_tensor.shape[0], self.weight.shape[1])
dtype = input_tensor.dtype
device = input_tensor.device
output_tensor = torch.empty(shape, dtype=dtype, device=device, requires_grad=False)
if self.bias is None:
return torch.mm(input_tensor, self.weight, out=output_tensor)
return torch.addmm(self.bias, input_tensor, self.weight, out=output_tensor)
def state_dict(self, destination=None):
if destination is None:
destination = {}
destination[self.weight_name] = self.pin_weight if hasattr(self, "pin_weight") else self.weight
if self.bias_name is not None:
destination[self.bias_name] = self.pin_bias if hasattr(self, "pin_bias") else self.bias
return destination
def load_state_dict_from_disk(self, block_index, adapter_block_index=None):
if self.is_post_adapter:
assert adapter_block_index is not None
self.weight_name = re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", self.weight_name, count=1)
else:
self.weight_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.weight_name, count=1)
if self.bias_name is not None:
if self.is_post_adapter:
assert adapter_block_index is not None
self.bias_name = re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", self.bias_name, count=1)
else:
self.bias_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.bias_name, count=1)
if Path(self.lazy_load_file).is_file():
lazy_load_file_path = self.lazy_load_file
else:
lazy_load_file_path = os.path.join(self.lazy_load_file, f"block_{block_index}.safetensors")
with safe_open(lazy_load_file_path, framework="pt", device="cpu") as lazy_load_file:
weight_tensor = lazy_load_file.get_tensor(self.weight_name).t()
self.pin_weight = self.pin_weight.copy_(weight_tensor)
del weight_tensor
if self.bias_name is not None:
bias_tensor = lazy_load_file.get_tensor(self.bias_name)
self.pin_bias.copy_(bias_tensor)
del bias_tensor
def load_state_dict(self, destination, block_index, adapter_block_index=None):
if self.is_post_adapter:
assert adapter_block_index is not None
weight_name = re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", self.weight_name, count=1)
else:
weight_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.weight_name, count=1)
if weight_name not in destination:
self.weight = None
return
self.weight = self.weight_cuda_buffer.copy_(destination[weight_name], non_blocking=True)
if self.bias_name is not None:
if self.is_post_adapter:
assert adapter_block_index is not None
bias_name = re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", self.bias_name, count=1)
else:
bias_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.bias_name, count=1)
self.bias = self.bias_cuda_buffer.copy_(destination[bias_name], non_blocking=True)
else:
self.bias = None
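# --- Illustrative usage sketch (not part of the original file) ---
# Default MMWeight path: the weight is stored transposed ([in, out]) so apply()
# reduces to torch.mm / torch.addmm. The tensor names and shapes below are
# hypothetical, and a device backend with pinned-memory support is assumed.
def _mm_weight_example():
    mm = MMWeight("blocks.0.proj.weight", "blocks.0.proj.bias")
    mm.load({
        "blocks.0.proj.weight": torch.randn(128, 64),  # [out, in]; transposed to [in, out] inside load()
        "blocks.0.proj.bias": torch.randn(128),
    })
    mm.to_cuda()  # pinned CPU copies -> AI_DEVICE
    x = torch.randn(4, 64, device=AI_DEVICE)
    return mm.apply(x)  # -> [4, 128]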
@MM_WEIGHT_REGISTER("Default-Force-FP32")
class MMWeightForceFP32(MMWeight):
def __init__(self, weight_name, bias_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
super().__init__(weight_name, bias_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter)
def load(self, weight_dict):
if not self.lazy_load:
super().load(weight_dict)
self.weight = self.weight.to(torch.float32)
if hasattr(self, "bias") and self.bias is not None:
self.bias = self.bias.to(torch.float32)
class MMWeightQuantTemplate(MMWeightTemplate):
def __init__(self, weight_name, bias_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
super().__init__(weight_name, bias_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter)
self.weight_scale_name = self.weight_name.removesuffix(".weight") + ".weight_scale"
self.load_func = None
self.weight_need_transpose = True
self.act_quant_func = None
self.lazy_load = lazy_load
self.lazy_load_file = lazy_load_file
self.infer_dtype = GET_DTYPE()
self.bias_force_fp32 = False
# =========================
# weight load functions
# =========================
def load(self, weight_dict):
self.load_quantized(weight_dict)
if self.weight_need_transpose:
if hasattr(self, "weight") and self.weight is not None:
self.weight = self.weight.t()
if hasattr(self, "pin_weight") and self.pin_weight is not None:
self.pin_weight = self.pin_weight.t()
if hasattr(self, "weight_cuda_buffer") and self.weight_cuda_buffer is not None:
self.weight_cuda_buffer = self.weight_cuda_buffer.t()
def load_quantized(self, weight_dict):
if self.create_cuda_buffer:
self._load_cuda_buffers(weight_dict)
elif self.create_cpu_buffer:
self._load_cpu_pin_buffers()
else:
self._load_default_tensors(weight_dict)
def _load_cuda_buffers(self, weight_dict):
if self.lazy_load:
if Path(self.lazy_load_file).is_file():
lazy_load_file_path = self.lazy_load_file
else:
lazy_load_file_path = os.path.join(self.lazy_load_file, f"block_{self.weight_name.split('.')[1]}.safetensors")
with safe_open(lazy_load_file_path, framework="pt", device="cpu") as source:
self.weight_cuda_buffer, self.weight_scale_cuda_buffer = self._get_cuda_tensor_pair(source, self.lazy_load)
self.bias_cuda_buffer = self._get_cuda_bias_tensor(source, self.lazy_load)
else:
source = weight_dict
self.weight_cuda_buffer, self.weight_scale_cuda_buffer = self._get_cuda_tensor_pair(source, self.lazy_load)
self.bias_cuda_buffer = self._get_cuda_bias_tensor(source, self.lazy_load)
def _get_cuda_tensor_pair(self, source, is_lazy):
if is_lazy:
weight = source.get_tensor(self.weight_name).to(AI_DEVICE)
scale = source.get_tensor(self.weight_scale_name).float().to(AI_DEVICE)
else:
weight = source[self.weight_name].to(AI_DEVICE)
scale = source[self.weight_scale_name].float().to(AI_DEVICE)
return weight, scale
def _get_cuda_bias_tensor(self, source, is_lazy):
if self.bias_name is None:
return None
if is_lazy:
bias = source.get_tensor(self.bias_name)
dtype = self.infer_dtype
else:
bias = source[self.bias_name]
dtype = bias.dtype
if self.bias_force_fp32:
bias = bias.to(torch.float32)
else:
bias = bias.to(dtype)
return bias.to(AI_DEVICE)
def _load_cpu_pin_buffers(self):
self.pin_weight, self.pin_weight_scale = self._get_cpu_pin_tensor_pair(self.lazy_load_file, is_lazy=True)
self.pin_bias = self._get_cpu_pin_bias_tensor(self.lazy_load_file, is_lazy=True)
self.bias = None
def _get_cpu_pin_tensor_pair(self, source, is_lazy):
if is_lazy:
if Path(self.lazy_load_file).is_file():
lazy_load_file_path = self.lazy_load_file
else:
lazy_load_file_path = os.path.join(self.lazy_load_file, f"block_{self.weight_name.split('.')[1]}.safetensors")
with safe_open(lazy_load_file_path, framework="pt", device="cpu") as source:
weight_tensor = source.get_tensor(self.weight_name)
scale_tensor = source.get_tensor(self.weight_scale_name)
scale_dtype = torch.float
pin_weight = self._create_pin_tensor(weight_tensor)
pin_scale = self._create_pin_tensor(scale_tensor, scale_dtype)
else:
weight_tensor = source[self.weight_name]
scale_tensor = source[self.weight_scale_name]
scale_dtype = torch.float
pin_weight = self._create_pin_tensor(weight_tensor)
pin_scale = self._create_pin_tensor(scale_tensor, scale_dtype)
return pin_weight, pin_scale
def _get_cpu_pin_bias_tensor(self, source, is_lazy):
if self.bias_name is None:
return None
if is_lazy:
if Path(self.lazy_load_file).is_file():
lazy_load_file_path = self.lazy_load_file
else:
lazy_load_file_path = os.path.join(self.lazy_load_file, f"block_{self.weight_name.split('.')[1]}.safetensors")
with safe_open(lazy_load_file_path, framework="pt", device="cpu") as source:
bias_tensor = source.get_tensor(self.bias_name)
if not self.bias_force_fp32:
bias_tensor = bias_tensor.to(self.infer_dtype)
if self.bias_force_fp32:
bias_tensor = bias_tensor.to(torch.float32)
return self._create_pin_tensor(bias_tensor)
else:
bias_tensor = source[self.bias_name]
if self.bias_force_fp32:
bias_tensor = bias_tensor.to(torch.float32)
return self._create_pin_tensor(bias_tensor)
def _create_pin_tensor(self, tensor, dtype=None):
dtype = dtype or tensor.dtype
pin_tensor = torch.empty(tensor.shape, pin_memory=True, dtype=dtype)
pin_tensor.copy_(tensor)
del tensor
return pin_tensor
def _load_default_tensors(self, weight_dict):
if not self.lazy_load:
self.weight, self.weight_scale, self.pin_weight, self.pin_weight_scale = self._get_device_tensor_pair(weight_dict)
self._load_default_bias(weight_dict)
else:
self.bias = None
self.pin_bias = None
def _get_device_tensor_pair(self, source):
device = source[self.weight_name].device
if device.type == "cpu":
pin_weight, pin_scale = self._get_cpu_pin_tensor_pair(source, is_lazy=False)
return None, None, pin_weight, pin_scale
else:
return source[self.weight_name], source[self.weight_scale_name].float(), None, None
def _load_default_bias(self, source):
if self.bias_name is None:
self.bias = None
self.pin_bias = None
self.bias_cuda_buffer = None
return
if self.create_cuda_buffer:
self.bias_cuda_buffer = self._get_cuda_bias_tensor(source, is_lazy=False)
self.bias = None
self.pin_bias = None
else:
bias_tensor = source[self.bias_name].float() if self.bias_force_fp32 else source[self.bias_name]
device = bias_tensor.device
if device.type == "cpu":
self.pin_bias = self._get_cpu_pin_bias_tensor(source, is_lazy=False)
self.bias = None
else:
self.bias = bias_tensor
self.pin_bias = None
def load_fp8_perchannel_sym(self, weight_dict):
if self.config.get("weight_auto_quant", False):
self.weight = weight_dict[self.weight_name].to(torch.float32)
w_quantizer = FloatQuantizer("e4m3", True, "per_channel")
self.weight, self.weight_scale, _ = w_quantizer.real_quant_tensor(self.weight)
self.weight = self.weight.to(torch.float8_e4m3fn)
self.weight_scale = self.weight_scale.to(torch.float32)
else:
self.load_quantized(weight_dict)
def load_int8_perchannel_sym(self, weight_dict):
if self.config.get("weight_auto_quant", False):
self.weight = weight_dict[self.weight_name].to(torch.float32)
w_quantizer = IntegerQuantizer(8, True, "per_channel")
self.weight, self.weight_scale, _ = w_quantizer.real_quant_tensor(self.weight)
self.weight = self.weight.to(torch.int8)
self.weight_scale = self.weight_scale.to(torch.float32)
else:
self.load_quantized(weight_dict)
def load_mxfp4(self, weight_dict):
if self.config.get("weight_auto_quant", False):
device = weight_dict[self.weight_name].device
self.weight = weight_dict[self.weight_name].to(AI_DEVICE).to(torch.bfloat16)
self.weight, self.weight_scale = scaled_mxfp4_quant(self.weight)
self.weight, self.weight_scale = self.weight.to(device), self.weight_scale.to(device)
else:
device = weight_dict[self.weight_name].device
if device.type == "cpu":
weight_shape = weight_dict[self.weight_name].shape
weight_dtype = weight_dict[self.weight_name].dtype
self.pin_weight = torch.empty(weight_shape, pin_memory=True, dtype=weight_dtype)
self.pin_weight.copy_(weight_dict[self.weight_name])
weight_scale_shape = weight_dict[self.weight_scale_name].shape
weight_scale_dtype = weight_dict[self.weight_scale_name].dtype
self.pin_weight_scale = torch.empty(weight_scale_shape, pin_memory=True, dtype=weight_scale_dtype)
self.pin_weight_scale.copy_(weight_dict[self.weight_scale_name])
del weight_dict[self.weight_name]
else:
self.weight = weight_dict[self.weight_name]
self.weight_scale = weight_dict[self.weight_scale_name]
def load_mxfp6(self, weight_dict):
if self.config.get("weight_auto_quant", False):
device = weight_dict[self.weight_name].device
self.weight = weight_dict[self.weight_name].to(AI_DEVICE).to(torch.bfloat16)
self.weight, self.weight_scale = scaled_mxfp6_quant(self.weight)
self.weight, self.weight_scale = self.weight.to(device), self.weight_scale.to(device)
else:
device = weight_dict[self.weight_name].device
if device.type == "cpu":
weight_shape = weight_dict[self.weight_name].shape
weight_dtype = weight_dict[self.weight_name].dtype
self.pin_weight = torch.empty(weight_shape, pin_memory=True, dtype=weight_dtype)
self.pin_weight.copy_(weight_dict[self.weight_name])
weight_scale_shape = weight_dict[self.weight_scale_name].shape
weight_scale_dtype = weight_dict[self.weight_scale_name].dtype
self.pin_weight_scale = torch.empty(weight_scale_shape, pin_memory=True, dtype=weight_scale_dtype)
self.pin_weight_scale.copy_(weight_dict[self.weight_scale_name])
del weight_dict[self.weight_name]
else:
self.weight = weight_dict[self.weight_name]
self.weight_scale = weight_dict[self.weight_scale_name]
def load_mxfp8(self, weight_dict):
if self.config.get("weight_auto_quant", False):
device = weight_dict[self.weight_name].device
self.weight = weight_dict[self.weight_name].to(AI_DEVICE).to(torch.bfloat16)
self.weight, self.weight_scale = scaled_mxfp8_quant(self.weight)
self.weight, self.weight_scale = self.weight.to(device), self.weight_scale.to(device)
else:
device = weight_dict[self.weight_name].device
if device.type == "cpu":
weight_shape = weight_dict[self.weight_name].shape
weight_dtype = weight_dict[self.weight_name].dtype
self.pin_weight = torch.empty(weight_shape, pin_memory=True, dtype=weight_dtype)
self.pin_weight.copy_(weight_dict[self.weight_name])
weight_scale_shape = weight_dict[self.weight_scale_name].shape
weight_scale_dtype = weight_dict[self.weight_scale_name].dtype
self.pin_weight_scale = torch.empty(weight_scale_shape, pin_memory=True, dtype=weight_scale_dtype)
self.pin_weight_scale.copy_(weight_dict[self.weight_scale_name])
del weight_dict[self.weight_name]
else:
self.weight = weight_dict[self.weight_name]
self.weight_scale = weight_dict[self.weight_scale_name]
def load_nvfp4(self, weight_dict):
device = weight_dict[self.weight_name].device
input_absmax = weight_dict[self.weight_name.replace(".weight", ".input_absmax")]
input_global_scale = (2688.0 / input_absmax).to(torch.float32)
weight_global_scale = weight_dict[f"{self.weight_name}_global_scale"]
alpha = 1.0 / (input_global_scale * weight_global_scale)
if device.type == "cpu":
weight_shape = weight_dict[self.weight_name].shape
weight_dtype = weight_dict[self.weight_name].dtype
self.pin_weight = torch.empty(weight_shape, pin_memory=True, dtype=weight_dtype)
self.pin_weight.copy_(weight_dict[self.weight_name])
weight_scale_shape = weight_dict[self.weight_scale_name].shape
weight_scale_dtype = weight_dict[self.weight_scale_name].dtype
self.pin_weight_scale = torch.empty(weight_scale_shape, pin_memory=True, dtype=weight_scale_dtype)
self.pin_weight_scale.copy_(weight_dict[self.weight_scale_name])
input_global_scale_shape = input_global_scale.shape
input_global_scale_dtype = input_global_scale.dtype
self.pin_input_global_scale = torch.empty(input_global_scale_shape, pin_memory=True, dtype=input_global_scale_dtype)
self.pin_input_global_scale.copy_(input_global_scale)
alpha_shape = alpha.shape
alpha_dtype = alpha.dtype
self.pin_alpha = torch.empty(alpha_shape, pin_memory=True, dtype=alpha_dtype)
self.pin_alpha.copy_(alpha)
del weight_dict[self.weight_name]
else:
self.weight = weight_dict[self.weight_name]
self.weight_scale = weight_dict[self.weight_scale_name]
self.input_global_scale = input_global_scale
self.alpha = alpha
if self.bias_name is not None:
if self.create_cuda_buffer:
self.bias_cuda_buffer = weight_dict[self.bias_name].to(AI_DEVICE)
else:
device = weight_dict[self.bias_name].device
if device.type == "cpu":
bias_shape = weight_dict[self.bias_name].shape
bias_dtype = weight_dict[self.bias_name].dtype
self.pin_bias = torch.empty(bias_shape, pin_memory=True, dtype=bias_dtype)
self.pin_bias.copy_(weight_dict[self.bias_name])
else:
self.bias = weight_dict[self.bias_name]
else:
self.bias = None
self.pin_bias = None
def load_fp8_perblock128_sym(self, weight_dict):
if self.config.get("weight_auto_quant", False):
self.weight = weight_dict[self.weight_name]
self.weight, self.weight_scale = self.per_block_cast_to_fp8(self.weight)
else:
self.load_quantized(weight_dict)
def per_block_cast_to_fp8(self, x):
assert x.dim() == 2
m, n = x.shape
x_padded = torch.zeros(
(deep_gemm.ceil_div(m, 128) * 128, deep_gemm.ceil_div(n, 128) * 128),
dtype=x.dtype,
device=x.device,
)
x_padded[:m, :n] = x
x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128)
x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn)
return x_scaled.view_as(x_padded)[:m, :n].contiguous(), (x_amax / 448.0).view(x_view.size(0), x_view.size(2))
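# --- Illustrative note (not part of the original file) ---
# Relationship implemented by per_block_cast_to_fp8 for one 128x128 block B with
# amax = max|B|:
#   q = cast_to_fp8(B * 448 / amax),  scale = amax / 448,  so  q * scale ~= B
# 448 is the largest representable float8_e4m3fn value, and amax is clamped to
# 1e-4 so all-zero blocks do not divide by zero.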
# =========================
# act quant kernels
# =========================
def act_quant_int8_perchannel_sym_torchao(self, x):
input_tensor_quant, input_tensor_scale = quantize_activation_per_token_absmax(x)
return input_tensor_quant, input_tensor_scale
def act_quant_fp8_perchannel_sym_vllm(self, x):
input_tensor_quant, input_tensor_scale = ops.scaled_fp8_quant(x, None, scale_ub=None, use_per_token_if_dynamic=True)
return input_tensor_quant, input_tensor_scale
def act_quant_fp8_perchannel_sym_sgl(self, x):
m, k = x.shape
input_tensor_quant = torch.empty((m, k), dtype=torch.float8_e4m3fn, device="cuda", requires_grad=False)
input_tensor_scale = torch.empty((m, 1), dtype=torch.float32, device="cuda", requires_grad=False)
sgl_kernel.sgl_per_token_quant_fp8(x, input_tensor_quant, input_tensor_scale)
return input_tensor_quant, input_tensor_scale
def act_quant_int8_perchannel_sym_vllm(self, x):
input_tensor_quant, input_tensor_scale, _ = ops.scaled_int8_quant(x, scale=None, azp=None, symmetric=True)
return input_tensor_quant, input_tensor_scale
def act_quant_nvfp4(self, x):
input_tensor_quant, input_tensor_scale = scaled_nvfp4_quant(x, self.input_global_scale)
return input_tensor_quant, input_tensor_scale
def act_quant_mxfp4(self, x):
input_tensor_quant, input_tensor_scale = scaled_mxfp4_quant(x)
return input_tensor_quant, input_tensor_scale
def act_quant_mxfp8(self, x):
input_tensor_quant, input_tensor_scale = scaled_mxfp8_quant(x)
return input_tensor_quant, input_tensor_scale
def act_quant_fp8_perchannelgroup128_sym_deepgemm(self, x):
assert x.dim() == 2 and x.size(1) % 128 == 0
m, n = x.shape
x_view = x.view(m, -1, 128)
x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
return (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn).view(m, n), (x_amax / 448.0).view(m, -1)
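# --- Illustrative note (not part of the original file) ---
# Per-token, per-group(128) activation quantization above: for row m and group g,
#   scale[m, g] = max|x[m, 128*g : 128*(g+1)]| / 448,  q = cast_to_fp8(x / scale)
# so every token row carries k/128 scales rather than a single per-tensor scale.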
def act_quant_fp8_perchannelgroup128_sym_sgl(self, x):
m, k = x.shape
input_tensor_quant = torch.empty((m, k), dtype=torch.float8_e4m3fn, device="cuda", requires_grad=False)
input_tensor_scale = torch.empty((m, k // 128), dtype=torch.float32, device="cuda", requires_grad=False)
sgl_kernel.sgl_per_token_group_quant_fp8(
x,
input_tensor_quant,
input_tensor_scale,
group_size=128,
eps=1e-10,
fp8_min=-448.0,
fp8_max=448.0,
)
return input_tensor_quant, input_tensor_scale
def state_dict(self, destination=None):
if destination is None:
destination = {}
destination[self.weight_name] = self.pin_weight if hasattr(self, "pin_weight") else self.weight
if self.bias_name is not None:
destination[self.bias_name] = self.pin_bias if hasattr(self, "pin_bias") else self.bias
destination[self.weight_scale_name] = self.pin_weight_scale if hasattr(self, "pin_weight_scale") else self.weight_scale
return destination
def load_state_dict(self, destination, block_index, adapter_block_index=None):
if self.is_post_adapter:
weight_name = re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", self.weight_name, count=1)
weight_scale_name = re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", self.weight_scale_name, count=1)
else:
weight_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.weight_name, count=1)
weight_scale_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.weight_scale_name, count=1)
if weight_name not in destination:
self.weight = None
return
self.weight = self.weight_cuda_buffer.copy_(destination[weight_name], non_blocking=True)
self.weight_scale = self.weight_scale_cuda_buffer.copy_(destination[weight_scale_name], non_blocking=True)
if self.bias_name is not None:
bias_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.bias_name, count=1)
self.bias = self.bias_cuda_buffer.copy_(destination[bias_name], non_blocking=True)
else:
self.bias = None
def load_state_dict_from_disk(self, block_index, adapter_block_index=None):
if self.is_post_adapter:
self.weight_name = re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", self.weight_name, count=1)
self.weight_scale_name = re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", self.weight_scale_name, count=1)
else:
self.weight_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.weight_name, count=1)
self.weight_scale_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.weight_scale_name, count=1)
if self.bias_name is not None:
if self.is_post_adapter:
assert adapter_block_index is not None
self.bias_name = re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", self.bias_name, count=1)
else:
self.bias_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.bias_name, count=1)
if Path(self.lazy_load_file).is_file():
lazy_load_file_path = self.lazy_load_file
else:
lazy_load_file_path = os.path.join(self.lazy_load_file, f"block_{block_index}.safetensors")
with safe_open(lazy_load_file_path, framework="pt", device="cpu") as lazy_load_file:
if self.weight_need_transpose:
weight_tensor = lazy_load_file.get_tensor(self.weight_name).t()
else:
weight_tensor = lazy_load_file.get_tensor(self.weight_name)
self.pin_weight = self.pin_weight.copy_(weight_tensor)
del weight_tensor
weight_scale_tensor = lazy_load_file.get_tensor(self.weight_scale_name)
self.pin_weight_scale = self.pin_weight_scale.copy_(weight_scale_tensor)
del weight_scale_tensor
if self.bias_name is not None:
bias_tensor = lazy_load_file.get_tensor(self.bias_name)
self.pin_bias.copy_(bias_tensor)
del bias_tensor
@MM_WEIGHT_REGISTER("fp8-vllm")
class MMWeightWfp8channelAfp8channeldynamicVllm(MMWeightQuantTemplate):
"""
Name: W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm
Quant MM:
Weight: fp8 perchannel sym
Act: fp8 perchannel dynamic sym
Kernel: vllm
"""
def __init__(self, weight_name, bias_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
super().__init__(weight_name, bias_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter)
self.load_func = self.load_fp8_perchannel_sym
self.weight_need_transpose = True
self.act_quant_func = self.act_quant_fp8_perchannel_sym_vllm
def apply(self, input_tensor):
shape = (input_tensor.shape[0], self.weight.shape[1])
dtype = input_tensor.dtype
device = input_tensor.device
output_tensor = torch.empty(shape, dtype=dtype, device=device, requires_grad=False)
input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
torch.ops._C.cutlass_scaled_mm(
output_tensor,
input_tensor_quant,
self.weight,
input_tensor_scale,
self.weight_scale,
self.bias if self.bias is not None else None,
)
return output_tensor
@MM_WEIGHT_REGISTER("int8-vllm")
class MMWeightWint8channelAint8channeldynamicVllm(MMWeightQuantTemplate):
"""
Name: W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm
Quant MM:
Weight: int8 perchannel sym
Act: int8 perchannel dynamic sym
Kernel: vllm
"""
def __init__(self, weight_name, bias_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
super().__init__(weight_name, bias_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter)
self.load_func = self.load_int8_perchannel_sym
self.weight_need_transpose = True
self.act_quant_func = self.act_quant_int8_perchannel_sym_vllm
def apply(self, input_tensor):
shape = (input_tensor.shape[0], self.weight.shape[1])
dtype = input_tensor.dtype
device = input_tensor.device
output_tensor = torch.empty(shape, dtype=dtype, device=device, requires_grad=False)
input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
torch.ops._C.cutlass_scaled_mm(
output_tensor,
input_tensor_quant,
self.weight,
input_tensor_scale,
self.weight_scale,
self.bias if self.bias is not None else None,
)
return output_tensor
@MM_WEIGHT_REGISTER("mxfp4")
class MMWeightWmxfp4Amxfp4dynamic(MMWeightQuantTemplate):
"""
Name: W-mxfp4-A-mxfp4-dynamic
Quant MM:
Weight: mxfp4
Act: mxfp4
"""
def __init__(self, weight_name, bias_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
super().__init__(weight_name, bias_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter)
self.load_func = self.load_mxfp4
self.weight_need_transpose = False
self.act_quant_func = self.act_quant_mxfp4
self.set_alpha()
def set_alpha(self):
self.alpha = torch.tensor(1.0, dtype=torch.float32)
def apply(self, input_tensor):
input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
self.alpha = self.alpha.to(self.weight.device)
output_tensor = cutlass_scaled_mxfp4_mm(input_tensor_quant, self.weight, input_tensor_scale, self.weight_scale, alpha=self.alpha, bias=self.bias)
return output_tensor
@MM_WEIGHT_REGISTER("mxfp6-mxfp8")
class MMWeightWmxfp6Amxfp8dynamic(MMWeightQuantTemplate):
"""
Name: W-mxfp6-A-mxfp8-dynamic
Quant MM:
Weight: mxfp6
Act: mxfp8
"""
def __init__(self, weight_name, bias_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
super().__init__(weight_name, bias_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter)
self.load_func = self.load_mxfp6
self.weight_need_transpose = False
self.act_quant_func = self.act_quant_mxfp8
self.set_alpha()
def set_alpha(self):
self.alpha = torch.tensor(1.0, dtype=torch.float32)
def apply(self, input_tensor):
input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
self.alpha = self.alpha.to(self.weight.device)
output_tensor = cutlass_scaled_mxfp6_mxfp8_mm(input_tensor_quant, self.weight, input_tensor_scale, self.weight_scale, alpha=self.alpha, bias=self.bias)
return output_tensor
@MM_WEIGHT_REGISTER("mxfp8")
class MMWeightWmxfp8Amxfp8dynamic(MMWeightQuantTemplate):
"""
Name: W-mxfp8-A-mxfp8-dynamic
Quant MM:
Weight: mxfp8
Act: mxfp8
"""
def __init__(self, weight_name, bias_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
super().__init__(weight_name, bias_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter)
self.load_func = self.load_mxfp8
self.weight_need_transpose = False
self.act_quant_func = self.act_quant_mxfp8
self.set_alpha()
def set_alpha(self):
self.alpha = torch.tensor(1.0, dtype=torch.float32)
def apply(self, input_tensor):
input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
self.alpha = self.alpha.to(self.weight.device)
output_tensor = cutlass_scaled_mxfp8_mm(input_tensor_quant, self.weight, input_tensor_scale, self.weight_scale, alpha=self.alpha, bias=self.bias)
return output_tensor
@MM_WEIGHT_REGISTER("nvfp4")
class MMWeightWnvfp4Anvfp4dynamic(MMWeightQuantTemplate):
"""
Name: W-nvfp4-A-nvfp4-dynamic
Quant MM:
Weight: nvfp4
Act: nvfp4
"""
def __init__(self, weight_name, bias_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
super().__init__(weight_name, bias_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter)
self.load_func = self.load_nvfp4
self.weight_need_transpose = False
self.act_quant_func = self.act_quant_nvfp4
def apply(self, input_tensor):
input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
output_tensor = cutlass_scaled_nvfp4_mm(input_tensor_quant, self.weight, input_tensor_scale, self.weight_scale, alpha=self.alpha, bias=self.bias)
return output_tensor
def to_cuda(self, non_blocking=False):
self.weight = self.pin_weight.to(AI_DEVICE, non_blocking=non_blocking)
if hasattr(self, "pin_weight_scale"):
self.weight_scale = self.pin_weight_scale.to(AI_DEVICE, non_blocking=non_blocking)
self.input_global_scale = self.pin_input_global_scale.to(AI_DEVICE, non_blocking=non_blocking)
self.alpha = self.pin_alpha.to(AI_DEVICE, non_blocking=non_blocking)
if hasattr(self, "pin_bias") and self.pin_bias is not None:
self.bias = self.pin_bias.to(AI_DEVICE, non_blocking=non_blocking)
def to_cpu(self, non_blocking=False):
if hasattr(self, "pin_weight"):
self.weight = self.pin_weight.copy_(self.weight, non_blocking=non_blocking).cpu()
if hasattr(self, "weight_scale_name"):
self.weight_scale = self.pin_weight_scale.copy_(self.weight_scale, non_blocking=non_blocking).cpu()
self.input_global_scale = self.pin_input_global_scale.copy_(self.input_global_scale, non_blocking=non_blocking).cpu()
self.alpha = self.pin_alpha.copy_(self.alpha, non_blocking=non_blocking).cpu()
if self.bias is not None:
self.bias = self.pin_bias.copy_(self.bias, non_blocking=non_blocking).cpu()
else:
self.weight = self.weight.to("cpu", non_blocking=non_blocking)
if hasattr(self, "weight_scale"):
self.weight_scale = self.weight_scale.to("cpu", non_blocking=non_blocking)
self.input_global_scale = self.input_global_scale.to("cpu", non_blocking=non_blocking)
self.alpha = self.alpha.to("cpu", non_blocking=non_blocking)
if hasattr(self, "bias") and self.bias is not None:
self.bias = self.bias.to("cpu", non_blocking=non_blocking)
@MM_WEIGHT_REGISTER("Calib")
class MMCalibNvfp4(MMWeight):
"""
Name: calib
Calib:
absmax: torch.max(torch.abs(input_tensor))
"""
def __init__(self, weight_name, bias_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
super().__init__(weight_name, bias_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter)
self.running_absmax = None
self.count = 0
self.decay = 0.9
def apply(self, input_tensor):
shape = (input_tensor.shape[0], self.weight.shape[1])
dtype, device = input_tensor.dtype, input_tensor.device
current_absmax = torch.max(torch.abs(input_tensor)).to("cpu")
if self.count % 2 == 0:
if self.running_absmax is None:
self.running_absmax = current_absmax
else:
self.running_absmax = self.decay * self.running_absmax + (1 - self.decay) * current_absmax
CALIB["absmax"][self.weight_name] = self.running_absmax
self.count = self.count + 1
output_tensor = torch.empty(shape, dtype=dtype, device=device, requires_grad=False)
if self.bias is None:
return torch.mm(input_tensor, self.weight, out=output_tensor)
return torch.addmm(self.bias, input_tensor, self.weight, out=output_tensor)
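# --- Illustrative note (not part of the original file) ---
# The calibration above keeps an exponential moving average of the activation
# absmax, updated on every other call:
#   running <- 0.9 * running + 0.1 * max|x|   (seeded with the first observed max|x|)
# CALIB["absmax"] then holds one scalar per weight name for later quantization scaling.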
@MM_WEIGHT_REGISTER("fp8-q8f")
class MMWeightWfp8channelAfp8channeldynamicQ8F(MMWeightQuantTemplate):
"""
Name: W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Q8F
Quant MM:
Weight: fp8 perchannel sym
Act: fp8 perchannel dynamic sym
Kernel: Q8F
"""
def __init__(self, weight_name, bias_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
super().__init__(weight_name, bias_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter)
self.load_func = self.load_fp8_perchannel_sym
self.weight_need_transpose = False
self.act_quant_func = self.act_quant_fp8_perchannel_sym_vllm
self.bias_force_fp32 = True
def apply(self, input_tensor):
input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
output_tensor = fp8_linear(
input_tensor_quant,
self.weight,
self.bias.float() if self.bias is not None else None,
input_tensor_scale,
self.weight_scale,
out_dtype=self.infer_dtype,
)
return output_tensor.squeeze(0) if len(output_tensor.shape) == 3 else output_tensor
@MM_WEIGHT_REGISTER("int8-q8f")
class MMWeightWint8channelAint8channeldynamicQ8F(MMWeightQuantTemplate):
"""
Name: W-int8-channel-sym-A-int8-channel-sym-dynamic-Q8F
Quant MM:
Weight: int8 perchannel sym
Act: int8 perchannel dynamic sym
Kernel: Q8F
"""
def __init__(self, weight_name, bias_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
super().__init__(weight_name, bias_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter)
self.load_func = self.load_int8_perchannel_sym
self.weight_need_transpose = False
self.act_quant_func = self.act_quant_int8_perchannel_sym_vllm
def apply(self, input_tensor):
input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
output_tensor = q8_linear(
input_tensor_quant,
self.weight,
self.bias.float() if self.bias is not None else None,
input_tensor_scale,
self.weight_scale,
fuse_gelu=False,
out_dtype=self.infer_dtype,
)
return output_tensor.squeeze(0) if len(output_tensor.shape) == 3 else output_tensor
@MM_WEIGHT_REGISTER("fp8-b128-deepgemm")
class MMWeightWfp8block128Afp8channelgroup128dynamicDeepgemmActSgl(MMWeightQuantTemplate):
"""
Name: W-fp8-block128-sym-A-fp8-channel-group128-sym-dynamic-Deepgemm-ActSgl
Quant MM:
Weight: fp8 perblock 128x128 sym
Act: fp8 pertoken-pergroup group=128 dynamic sym
Kernel: quant-mm using Deepgemm, act dynamic quant using Sgl-kernel
"""
def __init__(self, weight_name, bias_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
super().__init__(weight_name, bias_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter)
self.load_func = self.load_fp8_perblock128_sym
self.weight_need_transpose = False
self.act_quant_func = self.act_quant_fp8_perchannelgroup128_sym_sgl
def apply(self, input_tensor):
shape = (input_tensor.shape[0], self.weight.shape[0])
dtype = input_tensor.dtype
device = input_tensor.device
output_tensor = torch.empty(shape, dtype=dtype, device=device, requires_grad=False)
input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
deep_gemm.gemm_fp8_fp8_bf16_nt(
(input_tensor_quant, input_tensor_scale),
(self.weight, self.weight_scale),
output_tensor,
)
if hasattr(self, "bias") and self.bias is not None:
output_tensor.add_(self.bias)
return output_tensor
@MM_WEIGHT_REGISTER("fp8-sgl")
class MMWeightWfp8channelAfp8channeldynamicSgl(MMWeightQuantTemplate):
"""
Name: W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Sgl
Quant MM:
Weight: fp8 perchannel sym
Act: fp8 perchannel dynamic sym
Kernel: Sgl-kernel
"""
def __init__(self, weight_name, bias_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
super().__init__(weight_name, bias_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter)
self.load_func = self.load_fp8_perchannel_sym
self.weight_need_transpose = True
self.act_quant_func = self.act_quant_fp8_perchannel_sym_sgl
def apply(self, input_tensor):
input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
output_tensor = sgl_kernel.fp8_scaled_mm(
input_tensor_quant,
self.weight,
input_tensor_scale,
self.weight_scale,
self.infer_dtype,
self.bias if self.bias is not None else None,
)
return output_tensor
@MM_WEIGHT_REGISTER("int8-sgl")
class MMWeightWint8channelAint8channeldynamicSglActVllm(MMWeightQuantTemplate):
"""
Name: W-int8-channel-sym-A-int8-channel-sym-dynamic-Sgl-ActVllm
Quant MM:
Weight: int8 perchannel sym
Act: int8 perchannel dynamic sym
Kernel: quant-mm using Sgl-kernel, act dynamic quant using vllm
"""
def __init__(self, weight_name, bias_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
super().__init__(weight_name, bias_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter)
self.load_func = self.load_int8_perchannel_sym
self.weight_need_transpose = True
self.act_quant_func = self.act_quant_int8_perchannel_sym_vllm
def apply(self, input_tensor):
shape = (input_tensor.shape[0], self.weight.shape[1])
dtype = input_tensor.dtype
device = input_tensor.device
output_tensor = torch.empty(shape, dtype=dtype, device=device, requires_grad=False)
input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
output_tensor = sgl_kernel.int8_scaled_mm(
input_tensor_quant,
self.weight,
input_tensor_scale,
self.weight_scale,
self.infer_dtype,
self.bias if self.bias is not None else None,
)
return output_tensor
@MM_WEIGHT_REGISTER("int8-torchao")
class MMWeightWint8channelAint8channeldynamicTorchao(MMWeightQuantTemplate):
"""
Name: W-int8-channel-sym-A-int8-channel-sym-dynamic-Torchao
Quant MM:
Weight: int8 perchannel sym
Act: int8 perchannel dynamic sym
Kernel: Torchao
"""
def __init__(self, weight_name, bias_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
super().__init__(weight_name, bias_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter)
self.load_func = self.load_int8_perchannel_sym
self.weight_need_transpose = True
self.act_quant_func = self.act_quant_int8_perchannel_sym_torchao
def apply(self, input_tensor):
input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
output_tensor = quant_int8_per_token_matmul(input_tensor_quant, input_tensor_scale, self.weight, self.weight_scale.t().float(), output_dtype=self.infer_dtype)
if self.bias is not None:
output_tensor = output_tensor + self.bias
return output_tensor
class MMWeightGGUFTemplate(MMWeightTemplate):
def __init__(self, weight_name, bias_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
super().__init__(weight_name, bias_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter)
def load(self, weight_dict):
if not self.lazy_load:
assert not self.create_cuda_buffer, "GGUF Unsupported offload block"
self.weight = weight_dict[self.weight_name]
weight_shape = self.weight.shape
weight_dtype = self.weight.dtype
if isinstance(self.weight, GGMLTensor):
self.pin_weight = GGMLTensor.empty_pinned(weight_shape, orig_shape=self.weight.orig_shape, dtype=weight_dtype, gguf_type=self.weight.gguf_type)
self.pin_weight.copy_from(self.weight)
else:
self.pin_weight = torch.empty(weight_shape, pin_memory=True, dtype=weight_dtype)
self.pin_weight.copy_(weight_dict[self.weight_name])
if self.bias_name is not None:
self.bias = weight_dict[self.bias_name]
if isinstance(self.bias, GGMLTensor):
self.pin_bias = GGMLTensor.empty_pinned(self.bias.shape, orig_shape=self.bias.orig_shape, dtype=self.bias.dtype, gguf_type=self.bias.gguf_type)
self.pin_bias.copy_from(self.bias)
else:
self.pin_bias = torch.empty(self.bias.shape, pin_memory=True, dtype=self.bias.dtype)
self.pin_bias.copy_(weight_dict[self.bias_name])
else:
self.bias = None
def load_state_dict(self, destination, block_index, adapter_block_index=None):
if self.is_post_adapter:
assert adapter_block_index is not None
weight_name = re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", self.weight_name, count=1)
else:
weight_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.weight_name, count=1)
if weight_name not in destination:
self.weight = None
return
self.weight = self.weight_cuda_buffer.copy_(destination[weight_name], non_blocking=True)
if self.bias_name is not None:
if self.is_post_adapter:
assert adapter_block_index is not None
bias_name = re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", self.bias_name, count=1)
else:
bias_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.bias_name, count=1)
self.bias = self.bias_cuda_buffer.copy_(destination[bias_name], non_blocking=True)
else:
self.bias = None
def state_dict(self, destination=None):
if destination is None:
destination = {}
destination[self.weight_name] = self.pin_weight if hasattr(self, "pin_weight") else self.weight
if self.bias_name is not None:
destination[self.bias_name] = self.pin_bias if hasattr(self, "pin_bias") else self.bias
return destination
def get_weight(self, tensor, dtype):
if tensor is None:
return
weight = gguf_dequantize_tensor(tensor, dtype)
if isinstance(weight, GGMLTensor):
weight = torch.Tensor(weight)
return weight
def cast_bias_weight(self, input_tensor=None, dtype=None, device=None, bias_dtype=None):
if input_tensor is not None:
if dtype is None:
dtype = getattr(input_tensor, "dtype", torch.float32)
bias = None
if self.bias is not None:
bias = self.get_weight(self.bias, dtype)
weight = self.get_weight(self.weight, dtype)
return weight, bias
def apply(self, input_tensor):
weight, bias = self.cast_bias_weight(input_tensor)
return torch.nn.functional.linear(input_tensor, weight, bias)
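# --- Illustrative note (not part of the original file) ---
# GGUF weights stay quantized (GGMLTensor) in host/device memory; apply()
# dequantizes them to the activation dtype on the fly via cast_bias_weight()
# and then falls back to a standard F.linear(input, weight, bias).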
@MM_WEIGHT_REGISTER("gguf-BF16")
class MMWeightGGUFBF16(MMWeightGGUFTemplate):
qtype = gguf.GGMLQuantizationType.BF16
@MM_WEIGHT_REGISTER("gguf-Q8_0")
class MMWeightGGUFQ80(MMWeightGGUFTemplate):
qtype = gguf.GGMLQuantizationType.Q8_0
@MM_WEIGHT_REGISTER("gguf-Q6_K")
class MMWeightGGUFQ6K(MMWeightGGUFTemplate):
qtype = gguf.GGMLQuantizationType.Q6_K
@MM_WEIGHT_REGISTER("gguf-Q5_K_S")
class MMWeightGGUFQ5KS(MMWeightGGUFTemplate):
qtype = gguf.GGMLQuantizationType.Q5_K
@MM_WEIGHT_REGISTER("gguf-Q5_K_M")
class MMWeightGGUFQ5KM(MMWeightGGUFTemplate):
qtype = gguf.GGMLQuantizationType.Q5_K
@MM_WEIGHT_REGISTER("gguf-Q5_1")
class MMWeightGGUFQ51(MMWeightGGUFTemplate):
qtype = gguf.GGMLQuantizationType.Q5_1
@MM_WEIGHT_REGISTER("gguf-Q5_0")
class MMWeightGGUFQ50(MMWeightGGUFTemplate):
qtype = gguf.GGMLQuantizationType.Q5_0
@MM_WEIGHT_REGISTER("gguf-Q4_K_M")
class MMWeightGGUFQ4KM(MMWeightGGUFTemplate):
qtype = gguf.GGMLQuantizationType.Q4_K
@MM_WEIGHT_REGISTER("gguf-Q4_K_S")
class MMWeightGGUFQ4KS(MMWeightGGUFTemplate):
qtype = gguf.GGMLQuantizationType.Q4_K
@MM_WEIGHT_REGISTER("gguf-Q4_1")
class MMWeightGGUFQ41(MMWeightGGUFTemplate):
qtype = gguf.GGMLQuantizationType.Q4_1
@MM_WEIGHT_REGISTER("gguf-Q4_0")
class MMWeightGGUFQ40(MMWeightGGUFTemplate):
qtype = gguf.GGMLQuantizationType.Q4_0
@MM_WEIGHT_REGISTER("gguf-Q3_K_M")
class MMWeightGGUFQ3KM(MMWeightGGUFTemplate):
qtype = gguf.GGMLQuantizationType.Q3_K
@MM_WEIGHT_REGISTER("gguf-Q3_K_S")
class MMWeightGGUFQ3KS(MMWeightGGUFTemplate):
qtype = gguf.GGMLQuantizationType.Q3_K
@MM_WEIGHT_REGISTER("int4-g128-marlin")
class MMWeightWint4group128Marlin(MMWeightQuantTemplate):
"""
Name: "W-int4-group128-sym-Marlin
Quant int4 x FP16:
Weight: int4 pergroup sym
Kernel: Marlin
"""
def __init__(self, weight_name, bias_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
super().__init__(weight_name, bias_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter)
self.load_func = self.load_quantized
def load(self, weight_dict):
assert not self.lazy_load
self.load_func(weight_dict)
self.workspace = weight_dict[f"{self.weight_name}_workspace"]
if self.bias_name is not None:
bias_shape = weight_dict[self.bias_name].shape
bias_dtype = weight_dict[self.bias_name].dtype
self.bias = torch.empty(bias_shape, pin_memory=True, dtype=bias_dtype)
self.bias.copy_(weight_dict[self.bias_name])
else:
self.bias = None
def apply(self, input_tensor):
output_tensor = torch.empty(input_tensor.shape[:-1] + (self.weight_scale.shape[1],), dtype=input_tensor.dtype, device=input_tensor.device)
marlin_cuda_quant.mul(input_tensor, self.weight, output_tensor, self.weight_scale.half(), self.workspace, -1, -1, -1, -1)
if hasattr(self, "bias") and self.bias is not None:
output_tensor.add_(self.bias)
return output_tensor