Unverified commit 1b144016 authored by STwangyingrui, committed by GitHub
parent 9b13cab2
{
    "infer_steps": 50,
    "transformer_model_name": "480p_i2v",
    "fps": 24,
    "target_video_length": 121,
    "vae_stride": [4, 16, 16],
    "sample_shift": 5.0,
    "sample_guide_scale": 6.0,
    "enable_cfg": false,
    "attn_type": "sage_attn3",
    "vae_cpu_offload": false,
    "byt5_cpu_offload": false,
    "qwen25vl_cpu_offload": true,
    "siglip_cpu_offload": false,
    "dit_quantized_ckpt": "/path/to/quant_model.safetensors",
    "dit_quantized": true,
    "dit_quant_scheme": "int8-q8f",
    "parallel": {
        "seq_p_size": 8,
        "seq_p_fp8_comm": true,
        "seq_p_attn_type": "ulysses"
    }
}
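For orientation, a minimal sketch of how the "parallel" block above is consumed (the file name and variable names are illustrative, not from this commit; the transformer infer classes later in this diff read the same flag via config["parallel"].get("seq_p_fp8_comm", False)):

import json

with open("config.json") as f:  # hypothetical path to a config like the one above
    config = json.load(f)

parallel = config.get("parallel", {})
seq_p_size = parallel.get("seq_p_size", 1)  # sequence-parallel world size (8 GPUs in this config)
seq_p_fp8_comm = parallel.get("seq_p_fp8_comm", False)  # quantize sequence-parallel traffic to FP8
seq_p_attn_type = parallel.get("seq_p_attn_type", "ulysses")  # which sequence-parallel attention to use

The next config below enables the same "parallel" block for a distilled, audio-conditioned pipeline.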
{
    "infer_steps": 2,
    "target_fps": 16,
    "video_duration": 5,
    "audio_sr": 16000,
    "target_video_length": 81,
    "self_attn_1_type": "sage_attn3",
    "cross_attn_1_type": "sage_attn3",
    "cross_attn_2_type": "sage_attn3",
    "sample_guide_scale": 1,
    "sample_shift": 5,
    "enable_cfg": false,
    "use_31_block": false,
    "cpu_offload": false,
    "offload_granularity": "block",
    "offload_ratio": 1,
    "t5_cpu_offload": false,
    "t5_quantized": true,
    "t5_quant_scheme": "int8-q8f",
    "clip_cpu_offload": false,
    "clip_quantized": false,
    "audio_encoder_cpu_offload": false,
    "audio_adapter_cpu_offload": false,
    "adapter_quantized": true,
    "adapter_quant_scheme": "int8-q8f",
    "vae_cpu_offload": false,
    "use_tiling_vae": false,
    "dit_quantized": true,
    "dit_quant_scheme": "int8-q8f",
    "resize_mode": "fixed_shape",
    "fixed_shape": [
        832,
        480
    ],
    "parallel": {
        "seq_p_size": 8,
        "seq_p_fp8_comm": true,
        "seq_p_attn_type": "ulysses"
    }
}
@@ -12,8 +12,8 @@ class WeightAsyncStreamManager(object):
     def __init__(self, offload_granularity):
         self.offload_granularity = offload_granularity
         self.init_stream = torch.cuda.Stream(priority=0)
-        self.cuda_load_stream = torch.cuda.Stream(priority=0)
-        self.compute_stream = torch.cuda.Stream(priority=-1)
+        self.cuda_load_stream = torch.cuda.Stream(priority=1)
+        self.compute_stream = torch.cuda.Stream(priority=1)

     def init_cuda_buffer(self, blocks_cuda_buffer=None, phases_cuda_buffer=None):
         if self.offload_granularity == "block":
...
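For context on the priority change above: in PyTorch, torch.cuda.Stream(priority=...) treats lower numbers as higher priority (negative values are high, 0 is the default, and values outside the device's supported range are clamped), so this change lowers the priority of the weight-loading and compute streams relative to the high-priority NCCL stream enabled later in this commit via is_high_priority_stream. A small, self-contained illustration (not code from this repo):

import torch

high = torch.cuda.Stream(priority=-1)  # higher priority: scheduled ahead of default-priority work
low = torch.cuda.Stream(priority=0)    # default priority

x = torch.randn(2048, 2048, device="cuda")
with torch.cuda.stream(low):
    y = x @ x          # bulk compute on the low-priority stream
with torch.cuda.stream(high):
    z = x.sum()        # small, latency-sensitive op on the high-priority stream
torch.cuda.synchronize()  # wait for both streams before using y and z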
@@ -41,7 +41,7 @@ class RingAttnWeight(AttnWeightTemplate):
     def __init__(self):
         self.config = {}

-    def apply(self, q, k, v, img_qkv_len, cu_seqlens_qkv, attention_module=None, seq_p_group=None):
+    def apply(self, q, k, v, img_qkv_len, cu_seqlens_qkv, attention_module=None, seq_p_group=None, model_cls=None, use_fp8_comm=False):
         """
         Performs Ring attention over the combined image and text queries, keys and values.
@@ -56,6 +56,8 @@ class RingAttnWeight(AttnWeightTemplate):
         Returns:
             torch.Tensor: the computed attention result
         """
+        assert not use_fp8_comm, "RingAttn can't support fp8 comm now."
+
         # Get the current process rank and the world size
         cur_rank = dist.get_rank(seq_p_group)
         world_size = dist.get_world_size(seq_p_group)
...
 import torch
 import torch.distributed as dist

+from lightx2v.utils.quant_utils import dequant_fp8_vllm, quant_fp8_vllm
 from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER

 from .template import AttnWeightTemplate
@@ -12,7 +13,7 @@ class UlyssesAttnWeight(AttnWeightTemplate):
     def __init__(self):
         self.config = {}

-    def apply(self, q, k, v, img_qkv_len, cu_seqlens_qkv, attention_module=None, seq_p_group=None, model_cls=None):
+    def apply(self, q, k, v, img_qkv_len, cu_seqlens_qkv, attention_module=None, seq_p_group=None, model_cls=None, use_fp8_comm=False):
         """
         Performs Ulysses attention over the combined image and text queries, keys and values.
@@ -55,6 +56,22 @@ class UlyssesAttnWeight(AttnWeightTemplate):
         txt_q, txt_k, txt_v = q[img_qkv_len:, :, :].contiguous(), k[img_qkv_len:, :, :].contiguous(), v[img_qkv_len:, :, :].contiguous()

         # Convert the image queries, keys and values to the head-sharded layout
+        if use_fp8_comm:
+            original_dtype = img_q.dtype
+            original_shape = img_q.shape
+            img_q_fp8, q_scale = quant_fp8_vllm(img_q.reshape(-1, original_shape[-1]))
+            img_k_fp8, k_scale = quant_fp8_vllm(img_k.reshape(-1, original_shape[-1]))
+            img_v_fp8, v_scale = quant_fp8_vllm(img_v.reshape(-1, original_shape[-1]))
+            img_q_fp8 = all2all_seq2head(img_q_fp8.reshape(original_shape), group=seq_p_group)
+            img_k_fp8 = all2all_seq2head(img_k_fp8.reshape(original_shape), group=seq_p_group)
+            img_v_fp8 = all2all_seq2head(img_v_fp8.reshape(original_shape), group=seq_p_group)
+            q_scale = all2all_seq2head(q_scale.reshape(original_shape[0], original_shape[1], 1), group=seq_p_group)
+            k_scale = all2all_seq2head(k_scale.reshape(original_shape[0], original_shape[1], 1), group=seq_p_group)
+            v_scale = all2all_seq2head(v_scale.reshape(original_shape[0], original_shape[1], 1), group=seq_p_group)
+            img_q = dequant_fp8_vllm(img_q_fp8, q_scale, original_dtype)
+            img_k = dequant_fp8_vllm(img_k_fp8, k_scale, original_dtype)
+            img_v = dequant_fp8_vllm(img_v_fp8, v_scale, original_dtype)
+        else:
+            img_q = all2all_seq2head(img_q, group=seq_p_group)
+            img_k = all2all_seq2head(img_k, group=seq_p_group)
+            img_v = all2all_seq2head(img_v, group=seq_p_group)
@@ -91,7 +108,7 @@ class UlyssesAttnWeight(AttnWeightTemplate):
         gathered_txt_attn = [torch.empty_like(txt_attn) for _ in range(world_size)]
         dist.all_gather(gathered_txt_attn, txt_attn, group=seq_p_group)

-        img_attn = self._reshape_img_attn(img_attn, world_size, shard_seqlen, shard_heads, hidden_dims, seq_p_group)
+        img_attn = self._reshape_img_attn(img_attn, world_size, shard_seqlen, shard_heads, hidden_dims, seq_p_group, use_fp8_comm)
         txt_attn = torch.cat(gathered_txt_attn, dim=1)  # concatenate the text attention results from all ranks
@@ -101,9 +118,20 @@ class UlyssesAttnWeight(AttnWeightTemplate):
         return attn  # return the final attention result

     @torch.compiler.disable
-    def _reshape_img_attn(self, img_attn, world_size, shard_seqlen, shard_heads, hidden_dims, seq_p_group):
+    def _reshape_img_attn(self, img_attn, world_size, shard_seqlen, shard_heads, hidden_dims, seq_p_group, use_fp8_comm):
         img_attn = img_attn.reshape(world_size * shard_seqlen, shard_heads, hidden_dims)  # reshape the image attention output
-        img_attn = all2all_head2seq(img_attn, group=seq_p_group)  # convert from head layout back to sequence layout
+        # Convert from head layout back to sequence layout
+        if use_fp8_comm:
+            original_dtype = img_attn.dtype
+            original_shape = img_attn.shape
+            img_attn_fp8, attn_scale = quant_fp8_vllm(img_attn.reshape(-1, original_shape[-1]))
+            img_attn_fp8 = all2all_head2seq(img_attn_fp8.reshape(original_shape), group=seq_p_group)
+            attn_scale = all2all_head2seq(attn_scale.reshape(original_shape[0], original_shape[1], 1), group=seq_p_group)
+            img_attn = dequant_fp8_vllm(img_attn_fp8, attn_scale, original_dtype)
+        else:
+            img_attn = all2all_head2seq(img_attn, group=seq_p_group)
         img_attn = img_attn.reshape(shard_seqlen, -1)  # reshape to [shard_seqlen, -1]
         self.device_synchronize()  # make sure CUDA work has finished
         return img_attn
@@ -112,7 +140,8 @@ class UlyssesAttnWeight(AttnWeightTemplate):
         self,
     ):
         if torch.cuda.is_available():
-            torch.cuda.synchronize()
+            # no need to sync between comm and comp
+            # torch.cuda.synchronize()
             self.config["run_device"] = "cuda"
         elif hasattr(torch, "mlu") and torch.mlu.is_available():
             torch.mlu.synchronize()
@@ -127,7 +156,7 @@ class Ulysses4090AttnWeight(AttnWeightTemplate):
     def __init__(self):
         self.config = {}

-    def apply(self, q, k, v, img_qkv_len, cu_seqlens_qkv, attention_module=None, seq_p_group=None, model_cls=None):
+    def apply(self, q, k, v, img_qkv_len, cu_seqlens_qkv, attention_module=None, seq_p_group=None, model_cls=None, use_fp8_comm=False):
         """
         Performs Ulysses attention over the combined image and text queries, keys and values.
@@ -180,6 +209,107 @@ class Ulysses4090AttnWeight(AttnWeightTemplate):
         k_shards = [img_k[:, i * shard_heads : (i + 1) * shard_heads, :].contiguous() for i in range(world_size)]
         v_shards = [img_v[:, i * shard_heads : (i + 1) * shard_heads, :].contiguous() for i in range(world_size)]
+        if use_fp8_comm:
+            original_dtype = img_q.dtype
+            original_shape = img_q.shape
+
+            # Quantize every shard
+            q_shards_fp8 = []
+            q_scales = []
+            k_shards_fp8 = []
+            k_scales = []
+            v_shards_fp8 = []
+            v_scales = []
+            for i in range(world_size):
+                q_fp8, q_scale = quant_fp8_vllm(q_shards[i].reshape(-1, original_shape[-1]))
+                q_shards_fp8.append(q_fp8)
+                q_scales.append(q_scale)
+                k_fp8, k_scale = quant_fp8_vllm(k_shards[i].reshape(-1, original_shape[-1]))
+                k_shards_fp8.append(k_fp8)
+                k_scales.append(k_scale)
+                v_fp8, v_scale = quant_fp8_vllm(v_shards[i].reshape(-1, original_shape[-1]))
+                v_shards_fp8.append(v_fp8)
+                v_scales.append(v_scale)
+
+            # Prepare receive buffers (FP8 payload + scale)
+            gathered_q_fp8 = [None] * world_size
+            gathered_q_scales = [None] * world_size
+            gathered_k_fp8 = [None] * world_size
+            gathered_k_scales = [None] * world_size
+            gathered_v_fp8 = [None] * world_size
+            gathered_v_scales = [None] * world_size
+            for target_rank in range(world_size):
+                if target_rank != cur_rank:
+                    gathered_q_fp8[target_rank] = torch.empty_like(q_shards_fp8[target_rank])
+                    gathered_q_scales[target_rank] = torch.empty_like(q_scales[target_rank])
+                    gathered_k_fp8[target_rank] = torch.empty_like(k_shards_fp8[target_rank])
+                    gathered_k_scales[target_rank] = torch.empty_like(k_scales[target_rank])
+                    gathered_v_fp8[target_rank] = torch.empty_like(v_shards_fp8[target_rank])
+                    gathered_v_scales[target_rank] = torch.empty_like(v_scales[target_rank])
+                else:
+                    gathered_q_fp8[cur_rank] = q_shards_fp8[cur_rank]
+                    gathered_q_scales[cur_rank] = q_scales[cur_rank]
+                    gathered_k_fp8[cur_rank] = k_shards_fp8[cur_rank]
+                    gathered_k_scales[cur_rank] = k_scales[cur_rank]
+                    gathered_v_fp8[cur_rank] = v_shards_fp8[cur_rank]
+                    gathered_v_scales[cur_rank] = v_scales[cur_rank]
+
+            # Exchange the FP8 data and scales
+            for target_rank in range(world_size):
+                target_global_rank = cfg_p_group_index * world_size + target_rank
+                if target_rank != cur_rank:
+                    if cur_rank < target_rank:
+                        sendq_fp8_req = dist.isend(q_shards_fp8[target_rank], dst=target_global_rank, group=seq_p_group)
+                        sendk_fp8_req = dist.isend(k_shards_fp8[target_rank], dst=target_global_rank, group=seq_p_group)
+                        sendv_fp8_req = dist.isend(v_shards_fp8[target_rank], dst=target_global_rank, group=seq_p_group)
+                        sendq_scale_req = dist.isend(q_scales[target_rank], dst=target_global_rank, group=seq_p_group)
+                        sendk_scale_req = dist.isend(k_scales[target_rank], dst=target_global_rank, group=seq_p_group)
+                        sendv_scale_req = dist.isend(v_scales[target_rank], dst=target_global_rank, group=seq_p_group)
+                        recvq_fp8_req = dist.irecv(gathered_q_fp8[target_rank], src=target_global_rank, group=seq_p_group)
+                        recvk_fp8_req = dist.irecv(gathered_k_fp8[target_rank], src=target_global_rank, group=seq_p_group)
+                        recvv_fp8_req = dist.irecv(gathered_v_fp8[target_rank], src=target_global_rank, group=seq_p_group)
+                        recvq_scale_req = dist.irecv(gathered_q_scales[target_rank], src=target_global_rank, group=seq_p_group)
+                        recvk_scale_req = dist.irecv(gathered_k_scales[target_rank], src=target_global_rank, group=seq_p_group)
+                        recvv_scale_req = dist.irecv(gathered_v_scales[target_rank], src=target_global_rank, group=seq_p_group)
+                    else:
+                        recvq_fp8_req = dist.irecv(gathered_q_fp8[target_rank], src=target_global_rank, group=seq_p_group)
+                        recvk_fp8_req = dist.irecv(gathered_k_fp8[target_rank], src=target_global_rank, group=seq_p_group)
+                        recvv_fp8_req = dist.irecv(gathered_v_fp8[target_rank], src=target_global_rank, group=seq_p_group)
+                        recvq_scale_req = dist.irecv(gathered_q_scales[target_rank], src=target_global_rank, group=seq_p_group)
+                        recvk_scale_req = dist.irecv(gathered_k_scales[target_rank], src=target_global_rank, group=seq_p_group)
+                        recvv_scale_req = dist.irecv(gathered_v_scales[target_rank], src=target_global_rank, group=seq_p_group)
+                        sendq_fp8_req = dist.isend(q_shards_fp8[target_rank], dst=target_global_rank, group=seq_p_group)
+                        sendk_fp8_req = dist.isend(k_shards_fp8[target_rank], dst=target_global_rank, group=seq_p_group)
+                        sendv_fp8_req = dist.isend(v_shards_fp8[target_rank], dst=target_global_rank, group=seq_p_group)
+                        sendq_scale_req = dist.isend(q_scales[target_rank], dst=target_global_rank, group=seq_p_group)
+                        sendk_scale_req = dist.isend(k_scales[target_rank], dst=target_global_rank, group=seq_p_group)
+                        sendv_scale_req = dist.isend(v_scales[target_rank], dst=target_global_rank, group=seq_p_group)
+                    sendq_fp8_req.wait()
+                    sendk_fp8_req.wait()
+                    sendv_fp8_req.wait()
+                    sendq_scale_req.wait()
+                    sendk_scale_req.wait()
+                    sendv_scale_req.wait()
+                    recvq_fp8_req.wait()
+                    recvk_fp8_req.wait()
+                    recvv_fp8_req.wait()
+                    recvq_scale_req.wait()
+                    recvk_scale_req.wait()
+                    recvv_scale_req.wait()
+
+            # Dequantize
+            gathered_q_shards = []
+            gathered_k_shards = []
+            gathered_v_shards = []
+            for i in range(world_size):
+                q_shards_new = dequant_fp8_vllm(gathered_q_fp8[i], gathered_q_scales[i], original_dtype).reshape(-1, shard_heads, hidden_dims)
+                k_shards_new = dequant_fp8_vllm(gathered_k_fp8[i], gathered_k_scales[i], original_dtype).reshape(-1, shard_heads, hidden_dims)
+                v_shards_new = dequant_fp8_vllm(gathered_v_fp8[i], gathered_v_scales[i], original_dtype).reshape(-1, shard_heads, hidden_dims)
+                gathered_q_shards.append(q_shards_new)
+                gathered_k_shards.append(k_shards_new)
+                gathered_v_shards.append(v_shards_new)
+        else:
             # Prepare receive buffers
             gathered_q_shards = [None] * world_size
             gathered_k_shards = [None] * world_size
             gathered_v_shards = [None] * world_size
@@ -194,11 +324,10 @@ class Ulysses4090AttnWeight(AttnWeightTemplate):
                    gathered_k_shards[cur_rank] = k_shards[cur_rank]
                    gathered_v_shards[cur_rank] = v_shards[cur_rank]

-            # Launch the communication asynchronously, then synchronize
+            # Communication
            for target_rank in range(world_size):
                target_global_rank = cfg_p_group_index * world_size + target_rank
                if target_rank != cur_rank:
-                    # Avoid deadlock: decide the send/recv order by rank
                    if cur_rank < target_rank:
                        sendq_req = dist.isend(q_shards[target_rank], dst=target_global_rank, group=seq_p_group)
                        sendk_req = dist.isend(k_shards[target_rank], dst=target_global_rank, group=seq_p_group)
@@ -258,7 +387,7 @@ class Ulysses4090AttnWeight(AttnWeightTemplate):
         gathered_txt_attn = [torch.empty_like(txt_attn) for _ in range(world_size)]
         dist.all_gather(gathered_txt_attn, txt_attn, group=seq_p_group)

-        img_attn = self._reshape_img_attn(img_attn, world_size, shard_seqlen, shard_heads, hidden_dims, seq_p_group)
+        img_attn = self._reshape_img_attn(img_attn, world_size, shard_seqlen, shard_heads, hidden_dims, seq_p_group, use_fp8_comm)
         txt_attn = torch.cat(gathered_txt_attn, dim=1)  # concatenate the text attention results from all ranks
@@ -268,7 +397,7 @@ class Ulysses4090AttnWeight(AttnWeightTemplate):
         return attn  # return the final attention result

     @torch.compiler.disable
-    def _reshape_img_attn(self, img_attn, world_size, shard_seqlen, shard_heads, hidden_dims, seq_p_group):
+    def _reshape_img_attn(self, img_attn, world_size, shard_seqlen, shard_heads, hidden_dims, seq_p_group, use_fp8_comm):
         cur_rank = dist.get_rank(seq_p_group)
         global_world_size = dist.get_world_size()
         global_rank = dist.get_global_rank(seq_p_group, cur_rank)
@@ -278,6 +407,55 @@ class Ulysses4090AttnWeight(AttnWeightTemplate):
         # Split into N shards along the sequence dimension
         attn_shards = [img_attn[i * shard_seqlen : (i + 1) * shard_seqlen, :, :].contiguous() for i in range(world_size)]
+        if use_fp8_comm:
+            original_dtype = img_attn.dtype
+            original_shape = img_attn.shape
+
+            # Quantize every shard
+            attn_shards_fp8 = []
+            attn_scales = []
+            for i in range(world_size):
+                attn_fp8, attn_scale = quant_fp8_vllm(attn_shards[i].reshape(-1, original_shape[-1]))
+                attn_shards_fp8.append(attn_fp8)
+                attn_scales.append(attn_scale)
+
+            # Prepare receive buffers (FP8 payload + scale)
+            gathered_attn_fp8 = [None] * world_size
+            gathered_attn_scales = [None] * world_size
+            for target_rank in range(world_size):
+                if target_rank != cur_rank:
+                    gathered_attn_fp8[target_rank] = torch.empty_like(attn_shards_fp8[target_rank])
+                    gathered_attn_scales[target_rank] = torch.empty_like(attn_scales[target_rank])
+                else:
+                    gathered_attn_fp8[cur_rank] = attn_shards_fp8[cur_rank]
+                    gathered_attn_scales[cur_rank] = attn_scales[cur_rank]
+
+            # Exchange the FP8 data and scales
+            for target_rank in range(world_size):
+                target_global_rank = cfg_p_group_index * world_size + target_rank
+                if target_rank != cur_rank:
+                    if cur_rank < target_rank:
+                        send_attn_fp8_req = dist.isend(attn_shards_fp8[target_rank], dst=target_global_rank, group=seq_p_group)
+                        send_attn_scale_req = dist.isend(attn_scales[target_rank], dst=target_global_rank, group=seq_p_group)
+                        recv_attn_fp8_req = dist.irecv(gathered_attn_fp8[target_rank], src=target_global_rank, group=seq_p_group)
+                        recv_attn_scale_req = dist.irecv(gathered_attn_scales[target_rank], src=target_global_rank, group=seq_p_group)
+                    else:
+                        recv_attn_fp8_req = dist.irecv(gathered_attn_fp8[target_rank], src=target_global_rank, group=seq_p_group)
+                        recv_attn_scale_req = dist.irecv(gathered_attn_scales[target_rank], src=target_global_rank, group=seq_p_group)
+                        send_attn_fp8_req = dist.isend(attn_shards_fp8[target_rank], dst=target_global_rank, group=seq_p_group)
+                        send_attn_scale_req = dist.isend(attn_scales[target_rank], dst=target_global_rank, group=seq_p_group)
+                    send_attn_fp8_req.wait()
+                    send_attn_scale_req.wait()
+                    recv_attn_fp8_req.wait()
+                    recv_attn_scale_req.wait()
+
+            # Dequantize
+            gathered_attn_shards = []
+            for i in range(world_size):
+                attn_shards_new = dequant_fp8_vllm(gathered_attn_fp8[i], gathered_attn_scales[i], original_dtype).reshape(-1, shard_heads, hidden_dims)
+                gathered_attn_shards.append(attn_shards_new)
+        else:
            # Prepare receive buffers
            gathered_attn_shards = [None] * world_size
            for target_rank in range(world_size):
@@ -286,11 +464,10 @@ class Ulysses4090AttnWeight(AttnWeightTemplate):
                else:
                    gathered_attn_shards[cur_rank] = attn_shards[cur_rank]
-            # Launch the communication asynchronously, then synchronize
+            # Communication
            for target_rank in range(world_size):
                target_global_rank = cfg_p_group_index * world_size + target_rank
                if target_rank != cur_rank:
-                    # Avoid deadlock: decide the send/recv order by rank
                    if cur_rank < target_rank:
                        send_req = dist.isend(attn_shards[target_rank], dst=target_global_rank, group=seq_p_group)
                        recv_req = dist.irecv(gathered_attn_shards[target_rank], src=target_global_rank, group=seq_p_group)
@@ -304,5 +481,4 @@ class Ulysses4090AttnWeight(AttnWeightTemplate):
        img_attn = torch.cat(gathered_attn_shards, dim=1)
        img_attn = img_attn.reshape(shard_seqlen, -1)  # reshape to [shard_seqlen, -1]
-        torch.cuda.synchronize()  # make sure CUDA ops have finished
        return img_attn
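The manual shard exchange above avoids deadlock by ordering the point-to-point calls by rank: for each pair of peers, the lower-ranked side posts its sends before its receives and the higher-ranked side does the opposite, so the two sides' operations stay matched. A stripped-down sketch of just that pattern (illustrative only; the real code above additionally exchanges FP8 payloads plus their scales and maps destinations through cfg_p_group_index):

import torch
import torch.distributed as dist


def pairwise_exchange(shards):
    """shards[r] is the tensor this rank wants to send to rank r; the slot for
    our own rank is kept as-is. Returns the tensor received from every rank."""
    rank, world = dist.get_rank(), dist.get_world_size()
    received = list(shards)  # our own shard stays local
    for peer in range(world):
        if peer == rank:
            continue
        received[peer] = torch.empty_like(shards[peer])
        if rank < peer:  # lower rank posts the send first
            send_req = dist.isend(shards[peer], dst=peer)
            recv_req = dist.irecv(received[peer], src=peer)
        else:  # higher rank posts the receive first
            recv_req = dist.irecv(received[peer], src=peer)
            send_req = dist.isend(shards[peer], dst=peer)
        send_req.wait()
        recv_req.wait()
    return received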
@@ -3,6 +3,7 @@ import argparse
 import torch
 import torch.distributed as dist
 from loguru import logger
+from torch.distributed import ProcessGroupNCCL

 from lightx2v.common.ops import *
 from lightx2v.models.runners.hunyuan_video.hunyuan_video_15_distill_runner import HunyuanVideo15DistillRunner  # noqa: F401
@@ -102,7 +103,9 @@ def main():
     if config["parallel"]:
         run_device = config.get("run_device", "cuda")
         if "cuda" in run_device:
-            dist.init_process_group(backend="nccl")
+            pg_options = ProcessGroupNCCL.Options()
+            pg_options.is_high_priority_stream = True
+            dist.init_process_group(backend="nccl", pg_options=pg_options)
             torch.cuda.set_device(dist.get_rank())
         elif "mlu" in run_device:
             dist.init_process_group(backend="cncl")
...
@@ -103,8 +103,10 @@ class HunyuanVideo15TransformerInfer(BaseTransformerInfer):
         self.device = torch.device(self.config.get("run_device", "cuda"))
         if self.config["seq_parallel"]:
             self.seq_p_group = self.config.get("device_mesh").get_group(mesh_dim="seq_p")
+            self.seq_p_fp8_comm = self.config["parallel"].get("seq_p_fp8_comm", False)
         else:
             self.seq_p_group = None
+            self.seq_p_fp8_comm = False
         self.infer_func = self.infer_without_offload
         if self.config.get("modulate_type", "triton") == "triton":
             self.modulate_func = fuse_scale_shift_kernel
@@ -231,6 +233,7 @@ class HunyuanVideo15TransformerInfer(BaseTransformerInfer):
                 cu_seqlens_qkv=cu_seqlens_qkv,
                 attention_module=weights.self_attention,
                 seq_p_group=self.seq_p_group,
+                use_fp8_comm=self.seq_p_fp8_comm,
             )
         else:
             attn_out = weights.self_attention.apply(
...
@@ -37,8 +37,10 @@ class WanTransformerInfer(BaseTransformerInfer):
         if self.config["seq_parallel"]:
             self.seq_p_group = self.config.get("device_mesh").get_group(mesh_dim="seq_p")
+            self.seq_p_fp8_comm = self.config["parallel"].get("seq_p_fp8_comm", False)
         else:
             self.seq_p_group = None
+            self.seq_p_fp8_comm = False
         self.infer_func = self.infer_without_offload
         self.cos_sin = None
@@ -173,6 +175,7 @@ class WanTransformerInfer(BaseTransformerInfer):
                 cu_seqlens_qkv=cu_seqlens_qkv,
                 attention_module=phase.self_attn_1,
                 seq_p_group=self.seq_p_group,
+                use_fp8_comm=self.seq_p_fp8_comm,
                 model_cls=self.config["model_cls"],
             )
         else:
...
@@ -170,6 +170,22 @@ class FloatQuantizer(BaseQuantizer):
         return tensor


+# Import vLLM's quantization functions
+try:
+    from vllm import _custom_ops as ops
+except ImportError:
+    ops = None
+
+
+def quant_fp8_vllm(input_tensor):
+    input_tensor_fp8, input_tensor_scale = ops.scaled_fp8_quant(input_tensor, scale=None, scale_ub=None, use_per_token_if_dynamic=True)
+    return input_tensor_fp8, input_tensor_scale
+
+
+def dequant_fp8_vllm(input_tensor_fp8, input_tensor_scale, dtype):
+    return input_tensor_fp8.to(dtype) * input_tensor_scale.to(dtype)
+
+
 if __name__ == "__main__":
     weight = torch.randn(4096, 4096, dtype=torch.bfloat16).cuda()
     quantizer = IntegerQuantizer(4, False, "per_group", group_size=128)
@@ -194,3 +210,10 @@ if __name__ == "__main__":
     logger.info(f"realq_weight = {realq_weight}, {realq_weight.shape}")
     logger.info(f"scales = {scales}, {scales.shape}")
     logger.info(f"zeros = {zeros}")
+
+    input_tensor = torch.randn(4096, 4096, dtype=torch.bfloat16).cuda()
+    input_tensor_fp8, input_tensor_scale = quant_fp8_vllm(input_tensor)
+    dequant_tensor = dequant_fp8_vllm(input_tensor_fp8, input_tensor_scale, input_tensor.dtype)
+    logger.info(input_tensor)
+    logger.info(dequant_tensor)
+    logger.info(f"cosine vllm fp8 quant/dequant = {torch.cosine_similarity(input_tensor.view(1, -1).to(torch.float64), dequant_tensor.view(1, -1).to(torch.float64))}")