"server/text_generation_server/models/flash_dbrx.py" did not exist on "e71471bec95823ef69daaeb03c4657b9b5211a02"
Unverified commit 1b144016, authored by STwangyingrui and committed by GitHub
parent 9b13cab2
{
    "infer_steps": 50,
    "transformer_model_name": "480p_i2v",
    "fps": 24,
    "target_video_length": 121,
    "vae_stride": [4, 16, 16],
    "sample_shift": 5.0,
    "sample_guide_scale": 6.0,
    "enable_cfg": false,
    "attn_type": "sage_attn3",
    "vae_cpu_offload": false,
    "byt5_cpu_offload": false,
    "qwen25vl_cpu_offload": true,
    "siglip_cpu_offload": false,
    "dit_quantized_ckpt": "/path/to/quant_model.safetensors",
    "dit_quantized": true,
    "dit_quant_scheme": "int8-q8f",
    "parallel": {
        "seq_p_size": 8,
        "seq_p_fp8_comm": true,
        "seq_p_attn_type": "ulysses"
    }
}
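The `parallel` block above turns on Ulysses-style sequence parallelism across 8 ranks, with fp8 communication for the sequence-parallel collectives (`seq_p_fp8_comm`). Roughly, Ulysses shards the token dimension across the group and trades it for a head shard via all-to-all, so each rank attends over the full sequence but only a slice of the heads. The sketch below only illustrates that shape bookkeeping on a single process; the sequence length and head layout are made-up numbers, not values taken from the model.

```python
import torch

# Single-process sketch of Ulysses shape bookkeeping for "seq_p_size": 8.
# seq_len / num_heads / head_dim below are hypothetical.
seq_p_size = 8
seq_len, num_heads, head_dim = 75600, 40, 128

# Before attention, each rank holds a contiguous shard of the tokens.
local_seq = seq_len // seq_p_size            # 9450 tokens per rank
q_local = torch.randn(local_seq, num_heads, head_dim, dtype=torch.bfloat16)

# The all-to-all trades token shards for head shards: afterwards a rank sees
# all seq_len tokens but only num_heads // seq_p_size heads.
local_heads = num_heads // seq_p_size        # 5 heads per rank
q_send = q_local.reshape(local_seq, seq_p_size, local_heads, head_dim)
# After the collective, the per-rank tensor would be (seq_len, local_heads, head_dim);
# with seq_p_fp8_comm enabled, that exchange is sent as float8 payload plus scales.
print(local_seq, local_heads)                # 9450 5
```

The second config below keeps the same `parallel` block but adds offload and quantization switches for the audio-conditioned pipeline (audio encoder/adapter, T5, CLIP).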
{
    "infer_steps": 2,
    "target_fps": 16,
    "video_duration": 5,
    "audio_sr": 16000,
    "target_video_length": 81,
    "self_attn_1_type": "sage_attn3",
    "cross_attn_1_type": "sage_attn3",
    "cross_attn_2_type": "sage_attn3",
    "sample_guide_scale": 1,
    "sample_shift": 5,
    "enable_cfg": false,
    "use_31_block": false,
    "cpu_offload": false,
    "offload_granularity": "block",
    "offload_ratio": 1,
    "t5_cpu_offload": false,
    "t5_quantized": true,
    "t5_quant_scheme": "int8-q8f",
    "clip_cpu_offload": false,
    "clip_quantized": false,
    "audio_encoder_cpu_offload": false,
    "audio_adapter_cpu_offload": false,
    "adapter_quantized": true,
    "adapter_quant_scheme": "int8-q8f",
    "vae_cpu_offload": false,
    "use_tiling_vae": false,
    "dit_quantized": true,
    "dit_quant_scheme": "int8-q8f",
    "resize_mode": "fixed_shape",
    "fixed_shape": [832, 480],
    "parallel": {
        "seq_p_size": 8,
        "seq_p_fp8_comm": true,
        "seq_p_attn_type": "ulysses"
    }
}
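Configs like these fail quietly when hand-edited: standard JSON parsers keep only the last occurrence of a duplicated key, and a `seq_p_size` that does not divide the launched world size only surfaces at group-creation time. A small sanity-check sketch is shown below; it is not part of the repository, and the example path at the bottom is a placeholder.

```python
import json
import os


def load_config_checked(path):
    """Load a JSON config, rejecting duplicate keys and checking seq_p_size."""

    def no_dup_hook(pairs):
        keys = [k for k, _ in pairs]
        dups = sorted({k for k in keys if keys.count(k) > 1})
        if dups:
            raise ValueError(f"duplicate keys in {path}: {dups}")
        return dict(pairs)

    with open(path) as f:
        cfg = json.load(f, object_pairs_hook=no_dup_hook)

    seq_p_size = (cfg.get("parallel") or {}).get("seq_p_size", 1)
    # torchrun exports WORLD_SIZE; default to 1 for single-process runs.
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    if world_size % seq_p_size != 0:
        raise ValueError(f"seq_p_size={seq_p_size} does not divide WORLD_SIZE={world_size}")
    return cfg


# Example usage (path is hypothetical):
# cfg = load_config_checked("configs/distill/wan_i2v_480p_sp8.json")
```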
@@ -12,8 +12,8 @@ class WeightAsyncStreamManager(object):
     def __init__(self, offload_granularity):
         self.offload_granularity = offload_granularity
         self.init_stream = torch.cuda.Stream(priority=0)
-        self.cuda_load_stream = torch.cuda.Stream(priority=0)
-        self.compute_stream = torch.cuda.Stream(priority=-1)
+        self.cuda_load_stream = torch.cuda.Stream(priority=1)
+        self.compute_stream = torch.cuda.Stream(priority=1)

     def init_cuda_buffer(self, blocks_cuda_buffer=None, phases_cuda_buffer=None):
         if self.offload_granularity == "block":
...
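For context on the change above: CUDA stream priorities are "numerically lower means higher priority", and out-of-range values are typically clamped, so `priority=1` effectively requests the lowest priority for both the weight-loading and compute streams, leaving the high-priority NCCL communication streams enabled elsewhere in this commit to run first. The snippet below is only a generic illustration of the load/compute double-buffering pattern a manager like this is built around, with made-up tensor sizes; it is not the repository's implementation.

```python
import torch

# Generic two-stream prefetch/compute overlap (illustrative only; sizes are made up).
load_stream = torch.cuda.Stream(priority=1)     # host-to-device weight copies
compute_stream = torch.cuda.Stream(priority=1)  # block execution

cpu_weights = [torch.randn(4096, 4096).pin_memory() for _ in range(4)]
gpu_buf = [torch.empty(4096, 4096, device="cuda") for _ in range(2)]  # double buffer
ready = [torch.cuda.Event() for _ in range(2)]  # copy into slot i finished
freed = [torch.cuda.Event() for _ in range(2)]  # compute no longer reads slot i
x = torch.randn(4096, 4096, device="cuda")

for i, w_cpu in enumerate(cpu_weights):
    slot = i % 2
    with torch.cuda.stream(load_stream):
        load_stream.wait_event(freed[slot])             # don't overwrite a buffer still in use
        gpu_buf[slot].copy_(w_cpu, non_blocking=True)   # prefetch the next block's weights
        ready[slot].record(load_stream)
    with torch.cuda.stream(compute_stream):
        compute_stream.wait_event(ready[slot])          # wait only for this block's copy
        x = x @ gpu_buf[slot]                           # "run" the block
        freed[slot].record(compute_stream)

torch.cuda.synchronize()
```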
@@ -41,7 +41,7 @@ class RingAttnWeight(AttnWeightTemplate):
     def __init__(self):
         self.config = {}

-    def apply(self, q, k, v, img_qkv_len, cu_seqlens_qkv, attention_module=None, seq_p_group=None):
+    def apply(self, q, k, v, img_qkv_len, cu_seqlens_qkv, attention_module=None, seq_p_group=None, model_cls=None, use_fp8_comm=False):
         """
         Run the Ring attention mechanism over the combined image and text queries, keys, and values.
@@ -56,6 +56,8 @@ class RingAttnWeight(AttnWeightTemplate):
         Returns:
             torch.Tensor: the computed attention output
         """
+        assert not use_fp8_comm, "RingAttn does not support fp8 comm yet."
+
         # Get the rank of the current process and the world size
         cur_rank = dist.get_rank(seq_p_group)
         world_size = dist.get_world_size(seq_p_group)
...
(The diff for one additional file is collapsed and not shown here.)
@@ -3,6 +3,7 @@ import argparse
 import torch
 import torch.distributed as dist
 from loguru import logger
+from torch.distributed import ProcessGroupNCCL

 from lightx2v.common.ops import *
 from lightx2v.models.runners.hunyuan_video.hunyuan_video_15_distill_runner import HunyuanVideo15DistillRunner  # noqa: F401
@@ -102,7 +103,9 @@ def main():
     if config["parallel"]:
         run_device = config.get("run_device", "cuda")
         if "cuda" in run_device:
-            dist.init_process_group(backend="nccl")
+            pg_options = ProcessGroupNCCL.Options()
+            pg_options.is_high_priority_stream = True
+            dist.init_process_group(backend="nccl", pg_options=pg_options)
             torch.cuda.set_device(dist.get_rank())
         elif "mlu" in run_device:
             dist.init_process_group(backend="cncl")
...
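`is_high_priority_stream` tells NCCL to launch its communication kernels on high-priority CUDA streams, which pairs with the lower-priority load/compute streams earlier in this commit. `ProcessGroupNCCL` is only importable on PyTorch builds with NCCL support, so a defensive wrapper such as the following may be useful; the try/except fallback is an assumption for illustration, not part of the commit.

```python
import torch
import torch.distributed as dist


def init_nccl_group(high_priority: bool = True) -> None:
    """Initialize the NCCL process group, preferring high-priority comm streams."""
    if high_priority:
        try:
            from torch.distributed import ProcessGroupNCCL

            pg_options = ProcessGroupNCCL.Options()
            pg_options.is_high_priority_stream = True
            dist.init_process_group(backend="nccl", pg_options=pg_options)
        except (ImportError, AttributeError):
            # Fall back to a default-priority group when Options is unavailable.
            dist.init_process_group(backend="nccl")
    else:
        dist.init_process_group(backend="nccl")
    torch.cuda.set_device(dist.get_rank())
```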
@@ -103,8 +103,10 @@ class HunyuanVideo15TransformerInfer(BaseTransformerInfer):
         self.device = torch.device(self.config.get("run_device", "cuda"))
         if self.config["seq_parallel"]:
             self.seq_p_group = self.config.get("device_mesh").get_group(mesh_dim="seq_p")
+            self.seq_p_fp8_comm = self.config["parallel"].get("seq_p_fp8_comm", False)
         else:
             self.seq_p_group = None
+            self.seq_p_fp8_comm = False
         self.infer_func = self.infer_without_offload
         if self.config.get("modulate_type", "triton") == "triton":
             self.modulate_func = fuse_scale_shift_kernel
@@ -231,6 +233,7 @@ class HunyuanVideo15TransformerInfer(BaseTransformerInfer):
                 cu_seqlens_qkv=cu_seqlens_qkv,
                 attention_module=weights.self_attention,
                 seq_p_group=self.seq_p_group,
+                use_fp8_comm=self.seq_p_fp8_comm,
             )
         else:
             attn_out = weights.self_attention.apply(
...
@@ -37,8 +37,10 @@ class WanTransformerInfer(BaseTransformerInfer):
         if self.config["seq_parallel"]:
             self.seq_p_group = self.config.get("device_mesh").get_group(mesh_dim="seq_p")
+            self.seq_p_fp8_comm = self.config["parallel"].get("seq_p_fp8_comm", False)
         else:
             self.seq_p_group = None
+            self.seq_p_fp8_comm = False
         self.infer_func = self.infer_without_offload
         self.cos_sin = None
@@ -173,6 +175,7 @@ class WanTransformerInfer(BaseTransformerInfer):
                 cu_seqlens_qkv=cu_seqlens_qkv,
                 attention_module=phase.self_attn_1,
                 seq_p_group=self.seq_p_group,
+                use_fp8_comm=self.seq_p_fp8_comm,
                 model_cls=self.config["model_cls"],
             )
         else:
...
@@ -170,6 +170,22 @@ class FloatQuantizer(BaseQuantizer):
         return tensor


+# Import vLLM's quantization functions (optional dependency)
+try:
+    from vllm import _custom_ops as ops
+except ImportError:
+    ops = None
+
+
+def quant_fp8_vllm(input_tensor):
+    input_tensor_fp8, input_tensor_scale = ops.scaled_fp8_quant(input_tensor, scale=None, scale_ub=None, use_per_token_if_dynamic=True)
+    return input_tensor_fp8, input_tensor_scale
+
+
+def dequant_fp8_vllm(input_tensor_fp8, input_tensor_scale, dtype):
+    return input_tensor_fp8.to(dtype) * input_tensor_scale.to(dtype)
+
+
 if __name__ == "__main__":
     weight = torch.randn(4096, 4096, dtype=torch.bfloat16).cuda()
     quantizer = IntegerQuantizer(4, False, "per_group", group_size=128)
@@ -194,3 +210,10 @@ if __name__ == "__main__":
     logger.info(f"realq_weight = {realq_weight}, {realq_weight.shape}")
     logger.info(f"scales = {scales}, {scales.shape}")
     logger.info(f"zeros = {zeros}")
+
+    input_tensor = torch.randn(4096, 4096, dtype=torch.bfloat16).cuda()
+    input_tensor_fp8, input_tensor_scale = quant_fp8_vllm(input_tensor)
+    dequant_tensor = dequant_fp8_vllm(input_tensor_fp8, input_tensor_scale, input_tensor.dtype)
+    logger.info(input_tensor)
+    logger.info(dequant_tensor)
+    logger.info(f"cosine vllm fp8 quant/dequant = {torch.cosine_similarity(input_tensor.view(1, -1).to(torch.float64), dequant_tensor.view(1, -1).to(torch.float64))}")