Commit a1ebc651 authored by xuwx1

update lightx2v

parent 5a4db490
Pipeline #3149 canceled with stages
from .layer_norm_weight import *
from .rms_norm_weight import *
import os
import re
from abc import ABCMeta, abstractmethod
from pathlib import Path
import torch
from safetensors import safe_open
from lightx2v.utils.envs import *
from lightx2v.utils.registry_factory import LN_WEIGHT_REGISTER
from lightx2v_platform.base.global_var import AI_DEVICE
from .triton_ops import norm_infer
class LNWeightTemplate(metaclass=ABCMeta):
def __init__(self, weight_name=None, bias_name=None, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False, eps=1e-6):
self.weight_name = weight_name
self.bias_name = bias_name
self.eps = eps
self.create_cuda_buffer = create_cuda_buffer
self.create_cpu_buffer = create_cpu_buffer
self.lazy_load = lazy_load
self.lazy_load_file = lazy_load_file
self.is_post_adapter = is_post_adapter
self.config = {}
self.infer_dtype = GET_DTYPE()
self.sensitive_layer_dtype = GET_SENSITIVE_DTYPE()
def load(self, weight_dict):
if self.create_cuda_buffer:
self._load_cuda_buffers(weight_dict)
elif self.create_cpu_buffer:
self._load_cpu_pin_buffers()
else:
self._load_default_tensors(weight_dict)
def _load_default_tensors(self, weight_dict):
if not self.lazy_load and self.weight_name is not None:
device = weight_dict[self.weight_name].device
if device.type == "cpu":
weight_tensor = weight_dict[self.weight_name]
self.pin_weight = self._create_cpu_pin_tensor(weight_tensor)
bias_tensor = weight_dict[self.bias_name] if self.bias_name is not None else None
self.pin_bias = self._create_cpu_pin_tensor(bias_tensor) if bias_tensor is not None else None
self.bias = None
del weight_dict[self.weight_name]
else:
self.weight = weight_dict[self.weight_name]
self.bias = weight_dict[self.bias_name] if self.bias_name is not None else None
else:
self.weight = None
self.bias = None
def _get_tensor(self, name, weight_dict=None, use_infer_dtype=False):
if name is None:
return None
if self.lazy_load:
if Path(self.lazy_load_file).is_file():
lazy_load_file_path = self.lazy_load_file
else:
lazy_load_file_path = os.path.join(self.lazy_load_file, f"block_{name.split('.')[1]}.safetensors")
with safe_open(lazy_load_file_path, framework="pt", device="cpu") as lazy_load_file:
tensor = lazy_load_file.get_tensor(name)
if use_infer_dtype:
tensor = tensor.to(self.infer_dtype)
else:
tensor = weight_dict[name]
return tensor
def _create_cpu_pin_tensor(self, tensor):
if tensor is None:
return None
pin_tensor = torch.empty(tensor.shape, pin_memory=True, dtype=tensor.dtype)
pin_tensor.copy_(tensor)
del tensor
return pin_tensor
def _load_cuda_buffers(self, weight_dict):
weight_tensor = self._get_tensor(self.weight_name, weight_dict, use_infer_dtype=self.lazy_load)
if weight_tensor is not None:
self.weight_cuda_buffer = weight_tensor.to(AI_DEVICE)
bias_tensor = self._get_tensor(self.bias_name, weight_dict, use_infer_dtype=self.lazy_load)
if bias_tensor is not None:
self.bias_cuda_buffer = bias_tensor.to(AI_DEVICE)
def _load_cpu_pin_buffers(self):
weight_tensor = self._get_tensor(self.weight_name, use_infer_dtype=True)
if weight_tensor is not None:
self.pin_weight = self._create_cpu_pin_tensor(weight_tensor)
else:
self.weight = None
bias_tensor = self._get_tensor(self.bias_name, use_infer_dtype=True)
if bias_tensor is not None:
self.pin_bias = self._create_cpu_pin_tensor(bias_tensor)
else:
self.bias = None
self.pin_bias = None
@abstractmethod
def apply(self, input_tensor):
pass
def set_config(self, config=None):
if config is not None:
self.config = config
def to_cuda(self, non_blocking=False):
if hasattr(self, "pin_weight") and self.pin_weight is not None:
self.weight = self.pin_weight.to(AI_DEVICE, non_blocking=non_blocking)
else:
self.weight = None
if hasattr(self, "pin_bias") and self.pin_bias is not None:
self.bias = self.pin_bias.to(AI_DEVICE, non_blocking=non_blocking)
else:
self.bias = None
def to_cpu(self, non_blocking=False):
if hasattr(self, "pin_weight") and self.pin_weight is not None:
self.weight = self.pin_weight.copy_(self.weight, non_blocking=non_blocking).cpu()
if self.bias is not None:
self.bias = self.pin_bias.copy_(self.bias, non_blocking=non_blocking).cpu()
elif hasattr(self, "weight") and self.weight is not None:
self.weight = self.weight.to("cpu", non_blocking=non_blocking)
if hasattr(self, "bias") and self.bias is not None:
self.bias = self.bias.to("cpu", non_blocking=non_blocking)
def state_dict(self, destination=None):
if destination is None:
destination = {}
if self.weight_name is not None:
destination[self.weight_name] = self.pin_weight if hasattr(self, "pin_weight") else self.weight
if self.bias_name is not None:
destination[self.bias_name] = self.pin_bias if hasattr(self, "pin_bias") else self.bias
return destination
def load_state_dict(self, destination, block_index, adapter_block_index=None):
if self.weight_name is not None:
if self.is_post_adapter:
assert adapter_block_index is not None
weight_name = re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", self.weight_name, count=1)
else:
weight_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.weight_name, count=1)
if weight_name not in destination:
self.weight = None
return
self.weight = self.weight_cuda_buffer.copy_(destination[weight_name], non_blocking=True)
else:
self.weight = None
if self.bias_name is not None:
if self.is_post_adapter:
assert adapter_block_index is not None
bias_name = re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", self.bias_name, count=1)
else:
bias_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.bias_name, count=1)
self.bias = self.bias_cuda_buffer.copy_(destination[bias_name], non_blocking=True)
else:
self.bias = None
def load_state_dict_from_disk(self, block_index, adapter_block_index=None):
if self.weight_name is not None:
if Path(self.lazy_load_file).is_file():
lazy_load_file_path = self.lazy_load_file
else:
lazy_load_file_path = os.path.join(self.lazy_load_file, f"block_{block_index}.safetensors")
if self.is_post_adapter:
self.weight_name = re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", self.weight_name, count=1)
else:
self.weight_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.weight_name, count=1)
with safe_open(lazy_load_file_path, framework="pt", device="cpu") as lazy_load_file:
weight_tensor = lazy_load_file.get_tensor(self.weight_name).to(self.infer_dtype)
self.pin_weight = self.pin_weight.copy_(weight_tensor)
del weight_tensor
if self.bias_name is not None:
if Path(self.lazy_load_file).is_file():
lazy_load_file_path = self.lazy_load_file
else:
lazy_load_file_path = os.path.join(self.lazy_load_file, f"block_{block_index}.safetensors")
if self.is_post_adapter:
assert adapter_block_index is not None
self.bias_name = re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", self.bias_name, count=1)
else:
self.bias_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.bias_name, count=1)
with safe_open(lazy_load_file_path, framework="pt", device="cpu") as lazy_load_file:
bias_tensor = lazy_load_file.get_tensor(self.bias_name).to(self.infer_dtype)
self.pin_bias.copy_(bias_tensor)
del bias_tensor
@LN_WEIGHT_REGISTER("Default")
class LNWeight(LNWeightTemplate):
def __init__(self, weight_name=None, bias_name=None, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False, eps=1e-6):
super().__init__(weight_name, bias_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter, eps)
def apply(self, input_tensor):
if self.sensitive_layer_dtype != self.infer_dtype:
input_tensor = torch.nn.functional.layer_norm(
input_tensor.float(),
(input_tensor.shape[-1],),
self.weight,
self.bias,
self.eps,
).to(self.infer_dtype)
else:
input_tensor = torch.nn.functional.layer_norm(input_tensor, (input_tensor.shape[-1],), self.weight, self.bias, self.eps)
return input_tensor
@LN_WEIGHT_REGISTER("Triton")
class LNWeightTriton(LNWeightTemplate):
def __init__(self, weight_name=None, bias_name=None, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False, eps=1e-6):
super().__init__(weight_name, bias_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter, eps)
def apply(self, input_tensor):
input_tensor = norm_infer(input_tensor, self.weight, self.bias, self.eps)
return input_tensor
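# Usage sketch (illustrative only, not part of the original file; tensor names are hypothetical):
# the runtime is expected to drive a registered LN weight roughly as
#   ln = LNWeight(weight_name="blocks.0.norm1.weight", bias_name="blocks.0.norm1.bias")
#   ln.load(weight_dict)   # pins CPU tensors, keeps device tensors as-is
#   ln.to_cuda()           # materialize weight/bias on AI_DEVICE
#   y = ln.apply(x)        # LayerNorm over the last dimension of x
#   ln.to_cpu()            # offload again when the block is swapped out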
import os
import re
from abc import ABCMeta, abstractmethod
from pathlib import Path
import torch
from safetensors import safe_open
from lightx2v.utils.envs import *
from lightx2v.utils.registry_factory import RMS_WEIGHT_REGISTER
from lightx2v_platform.base.global_var import AI_DEVICE
try:
import sgl_kernel
except ImportError:
sgl_kernel = None
class RMSWeightTemplate(metaclass=ABCMeta):
def __init__(self, weight_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False, eps=1e-6):
self.weight_name = weight_name
self.eps = eps
self.create_cuda_buffer = create_cuda_buffer
self.create_cpu_buffer = create_cpu_buffer
self.lazy_load = lazy_load
self.lazy_load_file = lazy_load_file
self.is_post_adapter = is_post_adapter
self.infer_dtype = GET_DTYPE()
self.sensitive_layer_dtype = GET_SENSITIVE_DTYPE()
self.config = {}
def load(self, weight_dict):
if self.create_cuda_buffer:
self._load_cuda_buffer(weight_dict)
elif self.create_cpu_buffer:
self._load_cpu_pin_buffer()
else:
self._load_default_tensors(weight_dict)
def _load_default_tensors(self, weight_dict):
if not self.lazy_load:
device = weight_dict[self.weight_name].device
if device.type == "cpu":
weight_tensor = weight_dict[self.weight_name]
self.pin_weight = self._create_cpu_pin_weight(weight_tensor)
del weight_dict[self.weight_name]
else:
self.weight = weight_dict[self.weight_name]
def _get_weight_tensor(self, weight_dict=None, use_infer_dtype=False):
if self.lazy_load:
if Path(self.lazy_load_file).is_file():
lazy_load_file_path = self.lazy_load_file
else:
lazy_load_file_path = os.path.join(self.lazy_load_file, f"block_{self.weight_name.split('.')[1]}.safetensors")
with safe_open(lazy_load_file_path, framework="pt", device="cpu") as lazy_load_file:
tensor = lazy_load_file.get_tensor(self.weight_name)
if use_infer_dtype:
tensor = tensor.to(self.infer_dtype)
else:
tensor = weight_dict[self.weight_name]
return tensor
def _create_cpu_pin_weight(self, tensor):
pin_tensor = torch.empty(tensor.shape, pin_memory=True, dtype=tensor.dtype)
pin_tensor.copy_(tensor)
del tensor
return pin_tensor
def _load_cuda_buffer(self, weight_dict):
weight_tensor = self._get_weight_tensor(weight_dict, use_infer_dtype=self.lazy_load)
self.weight_cuda_buffer = weight_tensor.to(AI_DEVICE)
def _load_cpu_pin_buffer(self):
weight_tensor = self._get_weight_tensor(use_infer_dtype=True)
self.pin_weight = self._create_cpu_pin_weight(weight_tensor)
@abstractmethod
def apply(self, input_tensor):
pass
def set_config(self, config=None):
if config is not None:
self.config = config
def to_cuda(self, non_blocking=False):
self.weight = self.pin_weight.to(AI_DEVICE, non_blocking=non_blocking)
def to_cpu(self, non_blocking=False):
if hasattr(self, "pin_weight"):
self.weight = self.pin_weight.copy_(self.weight, non_blocking=non_blocking).cpu()
else:
self.weight = self.weight.to("cpu", non_blocking=non_blocking)
def state_dict(self, destination=None):
if destination is None:
destination = {}
destination[self.weight_name] = self.pin_weight if hasattr(self, "pin_weight") else self.weight
return destination
def load_state_dict(self, destination, block_index, adapter_block_index=None):
if self.is_post_adapter:
assert adapter_block_index is not None
weight_name = re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", self.weight_name, count=1)
else:
weight_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.weight_name, count=1)
if weight_name not in destination:
self.weight = None
return
self.weight = self.weight_cuda_buffer.copy_(destination[weight_name], non_blocking=True)
def load_state_dict_from_disk(self, block_index, adapter_block_index=None):
if self.is_post_adapter:
self.weight_name = re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", self.weight_name, count=1)
else:
self.weight_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.weight_name, count=1)
if Path(self.lazy_load_file).is_file():
lazy_load_file_path = self.lazy_load_file
else:
lazy_load_file_path = os.path.join(self.lazy_load_file, f"block_{block_index}.safetensors")
with safe_open(lazy_load_file_path, framework="pt", device="cpu") as lazy_load_file:
weight_tensor = lazy_load_file.get_tensor(self.weight_name).to(self.infer_dtype)
self.pin_weight = self.pin_weight.copy_(weight_tensor)
del weight_tensor
@RMS_WEIGHT_REGISTER("Default")
class RMSWeight(RMSWeightTemplate):
def __init__(self, weight_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False, eps=1e-6):
super().__init__(weight_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter, eps)
def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
def apply(self, input_tensor):
# Compute the norm in float32 when the sensitive-layer dtype differs from the inference
# dtype, mirroring the default LayerNorm implementation above.
if self.sensitive_layer_dtype != self.infer_dtype:
input_tensor = self._norm(input_tensor.float()).type_as(input_tensor) * self.weight
else:
input_tensor = self._norm(input_tensor).type_as(input_tensor) * self.weight
return input_tensor
@RMS_WEIGHT_REGISTER("sgl-kernel")
class RMSWeightSgl(RMSWeight):
def __init__(
self,
weight_name,
create_cuda_buffer=False,
create_cpu_buffer=False,
lazy_load=False,
lazy_load_file=None,
is_post_adapter=False,
eps=1e-6,
):
super().__init__(weight_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter, eps)
def apply(self, input_tensor):
if sgl_kernel is not None and self.sensitive_layer_dtype == self.infer_dtype:
input_tensor = input_tensor.contiguous()
orig_shape = input_tensor.shape
input_tensor = input_tensor.view(-1, orig_shape[-1])
input_tensor = sgl_kernel.rmsnorm(input_tensor, self.weight, self.eps).view(orig_shape)
else:
# Fall back to the eager implementation when sgl_kernel is unavailable or the
# sensitive-layer dtype differs from the inference dtype.
if self.sensitive_layer_dtype != self.infer_dtype:
input_tensor = input_tensor * torch.rsqrt(input_tensor.float().pow(2).mean(-1, keepdim=True) + self.eps).to(self.infer_dtype)
input_tensor = (input_tensor * self.weight).to(self.infer_dtype)
else:
input_tensor = input_tensor * torch.rsqrt(input_tensor.pow(2).mean(-1, keepdim=True) + self.eps)
input_tensor = input_tensor * self.weight
return input_tensor
@RMS_WEIGHT_REGISTER("fp32_variance")
class RMSWeightFP32(RMSWeight):
def __init__(self, weight_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False, eps=1e-6):
super().__init__(weight_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter, eps)
def apply(self, input_tensor):
input_dtype = input_tensor.dtype
variance = input_tensor.to(torch.float32).pow(2).mean(-1, keepdim=True)
hidden_states = input_tensor * torch.rsqrt(variance + self.eps)
if self.weight is not None:
# convert to half precision if necessary so the multiply matches the weight dtype
if self.weight.dtype in [torch.float16, torch.bfloat16]:
hidden_states = hidden_states.to(self.weight.dtype)
hidden_states = hidden_states * self.weight
hidden_states = hidden_states.to(input_dtype)
return hidden_states
@RMS_WEIGHT_REGISTER("self_forcing")
class RMSWeightSF(RMSWeight):
def __init__(self, weight_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False, eps=1e-6):
super().__init__(weight_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter, eps)
def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
def apply(self, x):
return self._norm(x.float()).type_as(x) * self.weight
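# Reference semantics (illustrative sketch, not used by the runtime): every RMS variant above
# computes y = x * rsqrt(mean(x^2, dim=-1) + eps) * w; they differ only in where the float32
# upcast happens and in which kernel performs the reduction.
def _rms_norm_reference(x: torch.Tensor, w: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
# Compute the reduction in float32 for stability, then cast back to the input dtype.
return (x.float() * torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) + eps)).type_as(x) * w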
# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo & https://github.com/sgl-project/sglang
# TODO: for temporary usage, expecting a refactor
from typing import Optional
import torch
import triton # type: ignore
import triton.language as tl # type: ignore
from torch import Tensor
@triton.autotune(
configs=[
triton.Config({"BLOCK_N": 64}, num_warps=2),
triton.Config({"BLOCK_N": 128}, num_warps=4),
triton.Config({"BLOCK_N": 256}, num_warps=4),
triton.Config({"BLOCK_N": 512}, num_warps=4),
triton.Config({"BLOCK_N": 1024}, num_warps=8),
],
key=["inner_dim"],
)
@triton.jit
def _fused_scale_shift_4d_kernel(
output_ptr,
normalized_ptr,
scale_ptr,
shift_ptr,
rows,
inner_dim,
seq_len,
num_frames,
frame_seqlen,
BLOCK_N: tl.constexpr,
):
pid_row = tl.program_id(0)
pid_col = tl.program_id(1)
col_offsets = pid_col * BLOCK_N + tl.arange(0, BLOCK_N)
mask = col_offsets < inner_dim
# Pointers for normalized and output
row_base = pid_row * inner_dim
norm_ptrs = normalized_ptr + row_base + col_offsets
out_ptrs = output_ptr + row_base + col_offsets
# Pointers for scale and shift for 4D
b_idx = pid_row // seq_len
t_idx = pid_row % seq_len
frame_idx_in_batch = t_idx // frame_seqlen
scale_row_idx = b_idx * num_frames + frame_idx_in_batch
scale_ptrs = scale_ptr + scale_row_idx * inner_dim + col_offsets
shift_ptrs = shift_ptr + scale_row_idx * inner_dim + col_offsets
normalized = tl.load(norm_ptrs, mask=mask, other=0.0)
scale = tl.load(scale_ptrs, mask=mask, other=0.0)
shift = tl.load(shift_ptrs, mask=mask, other=0.0)
one = tl.full([BLOCK_N], 1.0, dtype=scale.dtype)
output = normalized * (one + scale) + shift
tl.store(out_ptrs, output, mask=mask)
@triton.jit
def fuse_scale_shift_kernel_blc_opt(
x_ptr,
shift_ptr,
scale_ptr,
y_ptr,
B,
L,
C,
stride_x_b,
stride_x_l,
stride_x_c,
stride_s_b,
stride_s_l,
stride_s_c,
stride_sc_b,
stride_sc_l,
stride_sc_c,
SCALE_IS_SCALAR: tl.constexpr,
SHIFT_IS_SCALAR: tl.constexpr,
BLOCK_L: tl.constexpr,
BLOCK_C: tl.constexpr,
):
pid_l = tl.program_id(0)
pid_c = tl.program_id(1)
pid_b = tl.program_id(2)
l_offsets = pid_l * BLOCK_L + tl.arange(0, BLOCK_L)
c_offsets = pid_c * BLOCK_C + tl.arange(0, BLOCK_C)
mask_l = l_offsets < L
mask_c = c_offsets < C
mask = mask_l[:, None] & mask_c[None, :]
x_off = pid_b * stride_x_b + l_offsets[:, None] * stride_x_l + c_offsets[None, :] * stride_x_c
x = tl.load(x_ptr + x_off, mask=mask, other=0)
if SHIFT_IS_SCALAR:
shift_val = tl.load(shift_ptr)
shift = tl.full((BLOCK_L, BLOCK_C), shift_val, dtype=shift_val.dtype)
else:
s_off = pid_b * stride_s_b + l_offsets[:, None] * stride_s_l + c_offsets[None, :] * stride_s_c
shift = tl.load(shift_ptr + s_off, mask=mask, other=0)
if SCALE_IS_SCALAR:
scale_val = tl.load(scale_ptr)
scale = tl.full((BLOCK_L, BLOCK_C), scale_val, dtype=scale_val.dtype)
else:
sc_off = pid_b * stride_sc_b + l_offsets[:, None] * stride_sc_l + c_offsets[None, :] * stride_sc_c
scale = tl.load(scale_ptr + sc_off, mask=mask, other=0)
y = x * (1 + scale) + shift
tl.store(y_ptr + x_off, y, mask=mask)
def fuse_scale_shift_kernel(
x: torch.Tensor,
scale: torch.Tensor,
shift: torch.Tensor,
block_l: int = 128,
block_c: int = 128,
):
assert x.is_cuda and scale.is_cuda
assert x.is_contiguous()
B, L, C = x.shape
output = torch.empty_like(x)
if scale.dim() == 4:
# scale/shift: [B, F, 1, C]
rows = B * L
x_2d = x.view(rows, C)
output_2d = output.view(rows, C)
grid = lambda META: (rows, triton.cdiv(C, META["BLOCK_N"])) # noqa
num_frames = scale.shape[1]
assert L % num_frames == 0, "seq_len must be divisible by num_frames for 4D scale/shift"
frame_seqlen = L // num_frames
# Compact [B, F, C] without the singleton dim into [B*F, C]
scale_reshaped = scale.squeeze(2).reshape(-1, C).contiguous()
shift_reshaped = shift.squeeze(2).reshape(-1, C).contiguous()
_fused_scale_shift_4d_kernel[grid](
output_2d,
x_2d,
scale_reshaped,
shift_reshaped,
rows,
C,
L,
num_frames,
frame_seqlen,
)
else:
# 2D: [B, C] or [1, C] -> treat as [B, 1, C] and broadcast over L
# 3D: [B, L, C] (or broadcastable variants like [B, 1, C], [1, L, C], [1, 1, C])
# Also support scalar (0D or 1-element)
if scale.dim() == 0 or (scale.dim() == 1 and scale.numel() == 1):
scale_blc = scale.reshape(1)
elif scale.dim() == 2:
scale_blc = scale[:, None, :]
elif scale.dim() == 3:
scale_blc = scale
else:
raise ValueError("scale must be 0D/1D(1)/2D/3D or 4D")
if shift.dim() == 0 or (shift.dim() == 1 and shift.numel() == 1):
shift_blc = shift.reshape(1)
elif shift.dim() == 2:
shift_blc = shift[:, None, :]
elif shift.dim() == 3:
shift_blc = shift
else:
# broadcast later via expand if possible
shift_blc = shift
need_scale_scalar = scale_blc.dim() == 1 and scale_blc.numel() == 1
need_shift_scalar = shift_blc.dim() == 1 and shift_blc.numel() == 1
if not need_scale_scalar:
scale_exp = scale_blc.expand(B, L, C)
s_sb, s_sl, s_sc = scale_exp.stride()
else:
s_sb = s_sl = s_sc = 0
if not need_shift_scalar:
shift_exp = shift_blc.expand(B, L, C)
sh_sb, sh_sl, sh_sc = shift_exp.stride()
else:
sh_sb = sh_sl = sh_sc = 0
# If both scalars and both zero, copy fast-path
if need_scale_scalar and need_shift_scalar:
if (scale_blc.abs().max() == 0) and (shift_blc.abs().max() == 0):
output.copy_(x)
return output
grid = (triton.cdiv(L, block_l), triton.cdiv(C, block_c), B)
fuse_scale_shift_kernel_blc_opt[grid](
x,
shift_blc if need_shift_scalar else shift_exp,
scale_blc if need_scale_scalar else scale_exp,
output,
B,
L,
C,
x.stride(0),
x.stride(1),
x.stride(2),
sh_sb,
sh_sl,
sh_sc,
s_sb,
s_sl,
s_sc,
SCALE_IS_SCALAR=need_scale_scalar,
SHIFT_IS_SCALAR=need_shift_scalar,
BLOCK_L=block_l,
BLOCK_C=block_c,
num_warps=4,
num_stages=2,
)
return output
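# Eager-mode reference (illustrative sketch) for what fuse_scale_shift_kernel computes,
# useful for sanity-checking the Triton path: y = x * (1 + scale) + shift, with 4D
# scale/shift of shape [B, F, 1, C] applied per frame of the flattened sequence.
def _fuse_scale_shift_reference(x: torch.Tensor, scale: torch.Tensor, shift: torch.Tensor) -> torch.Tensor:
if scale.dim() == 4:
B, L, C = x.shape
frames = scale.shape[1]
x4 = x.view(B, frames, L // frames, C)
return (x4 * (1 + scale) + shift).view(B, L, C)
if scale.dim() == 2:
# [B, C] (or [1, C]) broadcasts over the sequence dimension
scale, shift = scale[:, None, :], shift[:, None, :]
return x * (1 + scale) + shift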
@triton.autotune(
configs=[
triton.Config({"BLOCK_HS_HALF": 32}, num_warps=2),
triton.Config({"BLOCK_HS_HALF": 64}, num_warps=4),
triton.Config({"BLOCK_HS_HALF": 128}, num_warps=4),
triton.Config({"BLOCK_HS_HALF": 256}, num_warps=8),
],
key=["head_size", "interleaved"],
)
@triton.jit
def _rotary_embedding_kernel(
output_ptr,
x_ptr,
cos_ptr,
sin_ptr,
num_heads,
head_size,
num_tokens,
stride_x_row,
stride_cos_row,
stride_sin_row,
interleaved: tl.constexpr,
BLOCK_HS_HALF: tl.constexpr,
):
row_idx = tl.program_id(0)
token_idx = (row_idx // num_heads) % num_tokens
x_row_ptr = x_ptr + row_idx * stride_x_row
cos_row_ptr = cos_ptr + token_idx * stride_cos_row
sin_row_ptr = sin_ptr + token_idx * stride_sin_row
output_row_ptr = output_ptr + row_idx * stride_x_row
# half size for x1 and x2
head_size_half = head_size // 2
for block_start in range(0, head_size_half, BLOCK_HS_HALF):
offsets_half = block_start + tl.arange(0, BLOCK_HS_HALF)
mask = offsets_half < head_size_half
cos_vals = tl.load(cos_row_ptr + offsets_half, mask=mask, other=0.0)
sin_vals = tl.load(sin_row_ptr + offsets_half, mask=mask, other=0.0)
offsets_x1 = 2 * offsets_half
offsets_x2 = 2 * offsets_half + 1
x1_vals = tl.load(x_row_ptr + offsets_x1, mask=mask, other=0.0)
x2_vals = tl.load(x_row_ptr + offsets_x2, mask=mask, other=0.0)
x1_fp32 = x1_vals.to(tl.float32)
x2_fp32 = x2_vals.to(tl.float32)
cos_fp32 = cos_vals.to(tl.float32)
sin_fp32 = sin_vals.to(tl.float32)
o1_vals = tl.fma(-x2_fp32, sin_fp32, x1_fp32 * cos_fp32)
o2_vals = tl.fma(x1_fp32, sin_fp32, x2_fp32 * cos_fp32)
tl.store(output_row_ptr + offsets_x1, o1_vals.to(x1_vals.dtype), mask=mask)
tl.store(output_row_ptr + offsets_x2, o2_vals.to(x2_vals.dtype), mask=mask)
def apply_rotary_embedding(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, interleaved: bool = False) -> torch.Tensor:
output = torch.empty_like(x)
if x.dim() > 3:
bsz, num_tokens, num_heads, head_size = x.shape
else:
num_tokens, num_heads, head_size = x.shape
bsz = 1
assert head_size % 2 == 0, "head_size must be divisible by 2"
x_reshaped = x.view(-1, head_size)
output_reshaped = output.view(-1, head_size)
# one program instance per (batch, token, head) row
grid = (bsz * num_tokens * num_heads,)
if interleaved and cos.shape[-1] == head_size:
cos = cos[..., ::2].contiguous()
sin = sin[..., ::2].contiguous()
else:
cos = cos.contiguous()
sin = sin.contiguous()
_rotary_embedding_kernel[grid](
output_reshaped,
x_reshaped,
cos,
sin,
num_heads,
head_size,
num_tokens,
x_reshaped.stride(0),
cos.stride(0),
sin.stride(0),
interleaved,
)
return output
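# Eager-mode reference (illustrative sketch) for the interleaved rotary embedding applied by
# the kernel above, assuming cos/sin of shape [num_tokens, head_size // 2]; pairs are
# (x[..., 2i], x[..., 2i + 1]) and the math is done in float32 as in the kernel.
def _rotary_reference(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
x1, x2 = x[..., 0::2].float(), x[..., 1::2].float()
cos_b, sin_b = cos.unsqueeze(-2).float(), sin.unsqueeze(-2).float()  # broadcast over heads
out = torch.empty_like(x)
out[..., 0::2] = (x1 * cos_b - x2 * sin_b).to(x.dtype)
out[..., 1::2] = (x1 * sin_b + x2 * cos_b).to(x.dtype)
return out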
# RMSNorm-fp32
def maybe_contiguous_lastdim(x):
return x.contiguous() if x is not None and x.stride(-1) != 1 else x
def maybe_contiguous(x):
return x.contiguous() if x is not None else None
def triton_autotune_configs():
if not torch.cuda.is_available():
return []
# Return configs with a valid warp count for the current device
configs = []
# Maximum threads per block is architecture-dependent in theory, but in reality all are 1024
max_threads_per_block = 1024
# Default to warp size 32 if not defined by device
warp_size = getattr(torch.cuda.get_device_properties(torch.cuda.current_device()), "warp_size", 32)
# Autotune over warp counts which are powers of 2 and do not exceed the threads-per-block limit
return [triton.Config({}, num_warps=warp_count) for warp_count in [1, 2, 4, 8, 16, 32] if warp_count * warp_size <= max_threads_per_block]
# return [triton.Config({}, num_warps=8)]
# Copied from flash-attn
@triton.autotune(
configs=triton_autotune_configs(),
key=[
"N",
"HAS_RESIDUAL",
"STORE_RESIDUAL_OUT",
"IS_RMS_NORM",
"HAS_BIAS",
"HAS_WEIGHT",
"HAS_X1",
"HAS_W1",
"HAS_B1",
],
)
# torch compile doesn't like triton.heuristics, so we set these manually when calling the kernel
# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
# @triton.heuristics({"HAS_RESIDUAL": lambda args: args["RESIDUAL"] is not None})
# @triton.heuristics({"HAS_X1": lambda args: args["X1"] is not None})
# @triton.heuristics({"HAS_W1": lambda args: args["W1"] is not None})
# @triton.heuristics({"HAS_B1": lambda args: args["B1"] is not None})
@triton.jit
def _layer_norm_fwd_1pass_kernel(
X, # pointer to the input
Y, # pointer to the output
W, # pointer to the weights
B, # pointer to the biases
RESIDUAL, # pointer to the residual
X1,
W1,
B1,
Y1,
RESIDUAL_OUT, # pointer to the residual
ROWSCALE,
SEEDS, # Dropout seeds for each row
DROPOUT_MASK,
DROPOUT_MASK1,
Mean, # pointer to the mean
Rstd, # pointer to the 1/std
stride_x_row, # how much to increase the pointer when moving by 1 row
stride_y_row,
stride_res_row,
stride_res_out_row,
stride_x1_row,
stride_y1_row,
M, # number of rows in X
N, # number of columns in X
eps, # epsilon to avoid division by zero
dropout_p, # Dropout probability
zero_centered_weight, # If true, add 1.0 to the weight
IS_RMS_NORM: tl.constexpr,
BLOCK_N: tl.constexpr,
HAS_RESIDUAL: tl.constexpr,
STORE_RESIDUAL_OUT: tl.constexpr,
HAS_WEIGHT: tl.constexpr,
HAS_BIAS: tl.constexpr,
HAS_DROPOUT: tl.constexpr,
STORE_DROPOUT_MASK: tl.constexpr,
HAS_ROWSCALE: tl.constexpr,
HAS_X1: tl.constexpr,
HAS_W1: tl.constexpr,
HAS_B1: tl.constexpr,
):
# Map the program id to the row of X and Y it should compute.
row = tl.program_id(0)
X += row * stride_x_row
Y += row * stride_y_row
if HAS_RESIDUAL:
RESIDUAL += row * stride_res_row
if STORE_RESIDUAL_OUT:
RESIDUAL_OUT += row * stride_res_out_row
if HAS_X1:
X1 += row * stride_x1_row
if HAS_W1:
Y1 += row * stride_y1_row
# Compute mean and variance
cols = tl.arange(0, BLOCK_N)
x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
if HAS_ROWSCALE:
rowscale = tl.load(ROWSCALE + row).to(tl.float32)
x *= rowscale
if HAS_DROPOUT:
# Compute dropout mask
# 7 rounds is good enough, and reduces register pressure
keep_mask = tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
x = tl.where(keep_mask, x / (1.0 - dropout_p), 0.0)
if STORE_DROPOUT_MASK:
tl.store(DROPOUT_MASK + row * N + cols, keep_mask, mask=cols < N)
if HAS_X1:
x1 = tl.load(X1 + cols, mask=cols < N, other=0.0).to(tl.float32)
if HAS_ROWSCALE:
rowscale = tl.load(ROWSCALE + M + row).to(tl.float32)
x1 *= rowscale
if HAS_DROPOUT:
# Compute dropout mask
# 7 rounds is good enough, and reduces register pressure
keep_mask = tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
x1 = tl.where(keep_mask, x1 / (1.0 - dropout_p), 0.0)
if STORE_DROPOUT_MASK:
tl.store(DROPOUT_MASK1 + row * N + cols, keep_mask, mask=cols < N)
x += x1
if HAS_RESIDUAL:
residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32)
x += residual
if STORE_RESIDUAL_OUT:
tl.store(RESIDUAL_OUT + cols, x, mask=cols < N)
if not IS_RMS_NORM:
mean = tl.sum(x, axis=0) / N
tl.store(Mean + row, mean)
xbar = tl.where(cols < N, x - mean, 0.0)
var = tl.sum(xbar * xbar, axis=0) / N
else:
xbar = tl.where(cols < N, x, 0.0)
var = tl.sum(xbar * xbar, axis=0) / N
rstd = 1 / tl.sqrt(var + eps)
tl.store(Rstd + row, rstd)
# Normalize and apply linear transformation
mask = cols < N
if HAS_WEIGHT:
w = tl.load(W + cols, mask=mask).to(tl.float32)
if zero_centered_weight:
w += 1.0
if HAS_BIAS:
b = tl.load(B + cols, mask=mask).to(tl.float32)
x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
if HAS_WEIGHT:
y = x_hat * w + b if HAS_BIAS else x_hat * w
else:
y = x_hat + b if HAS_BIAS else x_hat
# Write output
tl.store(Y + cols, y, mask=mask)
if HAS_W1:
w1 = tl.load(W1 + cols, mask=mask).to(tl.float32)
if zero_centered_weight:
w1 += 1.0
if HAS_B1:
b1 = tl.load(B1 + cols, mask=mask).to(tl.float32)
y1 = x_hat * w1 + b1 if HAS_B1 else x_hat * w1
tl.store(Y1 + cols, y1, mask=mask)
def _layer_norm_fwd(
x: Tensor,
weight: Tensor,
bias: Tensor,
eps: float,
residual: Optional[Tensor] = None,
x1: Optional[Tensor] = None,
weight1: Optional[Tensor] = None,
bias1: Optional[Tensor] = None,
dropout_p: float = 0.0,
rowscale: Optional[Tensor] = None,
out_dtype: Optional[torch.dtype] = None,
residual_dtype: Optional[torch.dtype] = None,
zero_centered_weight: bool = False,
is_rms_norm: bool = False,
return_dropout_mask: bool = False,
out: Optional[Tensor] = None,
residual_out: Optional[Tensor] = None,
) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor):
# Need to wrap to handle the case where residual_out is an alias of x, which makes torch.library
# and torch.compile unhappy. Also allocate memory for out and residual_out if they are None
# so that _layer_norm_fwd_impl doesn't have to return them.
if out is None:
out = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)
if residual is not None:
residual_dtype = residual.dtype
if residual_out is None and (residual is not None or (residual_dtype is not None and residual_dtype != x.dtype) or dropout_p > 0.0 or rowscale is not None or x1 is not None):
residual_out = torch.empty_like(x, dtype=residual_dtype if residual_dtype is not None else x.dtype)
else:
residual_out = None
y1, mean, rstd, seeds, dropout_mask, dropout_mask1 = _layer_norm_fwd_impl(
x,
weight,
bias,
eps,
out,
residual=residual,
x1=x1,
weight1=weight1,
bias1=bias1,
dropout_p=dropout_p,
rowscale=rowscale,
zero_centered_weight=zero_centered_weight,
is_rms_norm=is_rms_norm,
return_dropout_mask=return_dropout_mask,
residual_out=residual_out,
)
# residual_out is None if residual is None and residual_dtype == input_dtype and dropout_p == 0.0
if residual_out is None:
residual_out = x
return out, y1, mean, rstd, residual_out, seeds, dropout_mask, dropout_mask1
# [2025-04-28] torch.library.triton_op ignores the schema argument, but here we need the schema
# since we're returning a tuple of tensors
def _layer_norm_fwd_impl(
x: Tensor,
weight: Optional[Tensor],
bias: Tensor,
eps: float,
out: Tensor,
residual: Optional[Tensor] = None,
x1: Optional[Tensor] = None,
weight1: Optional[Tensor] = None,
bias1: Optional[Tensor] = None,
dropout_p: float = 0.0,
rowscale: Optional[Tensor] = None,
zero_centered_weight: bool = False,
is_rms_norm: bool = False,
return_dropout_mask: bool = False,
residual_out: Optional[Tensor] = None,
) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor):
M, N = x.shape
assert x.stride(-1) == 1
if residual is not None:
assert residual.stride(-1) == 1
assert residual.shape == (M, N)
if weight is not None:
assert weight.shape == (N,)
assert weight.stride(-1) == 1
if bias is not None:
assert bias.stride(-1) == 1
assert bias.shape == (N,)
if x1 is not None:
assert x1.shape == x.shape
assert rowscale is None
assert x1.stride(-1) == 1
if weight1 is not None:
assert weight1.shape == (N,)
assert weight1.stride(-1) == 1
if bias1 is not None:
assert bias1.shape == (N,)
assert bias1.stride(-1) == 1
if rowscale is not None:
assert rowscale.is_contiguous()
assert rowscale.shape == (M,)
assert out.shape == x.shape
assert out.stride(-1) == 1
if residual_out is not None:
assert residual_out.shape == x.shape
assert residual_out.stride(-1) == 1
if weight1 is not None:
y1 = torch.empty_like(out)
assert y1.stride(-1) == 1
else:
y1 = None
mean = torch.empty((M,), dtype=torch.float32, device=x.device) if not is_rms_norm else None
rstd = torch.empty((M,), dtype=torch.float32, device=x.device)
if dropout_p > 0.0:
seeds = torch.randint(2**32, (M if x1 is None else 2 * M,), device=x.device, dtype=torch.int64)
else:
seeds = None
if return_dropout_mask and dropout_p > 0.0:
dropout_mask = torch.empty(M, N, device=x.device, dtype=torch.bool)
if x1 is not None:
dropout_mask1 = torch.empty(M, N, device=x.device, dtype=torch.bool)
else:
dropout_mask1 = None
else:
dropout_mask, dropout_mask1 = None, None
# Less than 64KB per feature: enqueue fused kernel
MAX_FUSED_SIZE = 65536 // x.element_size()
BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
if N > BLOCK_N:
raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
with torch.cuda.device(x.device.index):
torch.library.wrap_triton(_layer_norm_fwd_1pass_kernel)[(M,)](
x,
out,
weight if weight is not None else x, # unused when HAS_WEIGHT == False
bias,
residual,
x1,
weight1,
bias1,
y1,
residual_out,
rowscale,
seeds,
dropout_mask,
dropout_mask1,
mean,
rstd,
x.stride(0),
out.stride(0),
residual.stride(0) if residual is not None else 0,
residual_out.stride(0) if residual_out is not None else 0,
x1.stride(0) if x1 is not None else 0,
y1.stride(0) if y1 is not None else 0,
M,
N,
eps,
dropout_p,
# Passing a bool makes torch inductor very unhappy since it then tries to compare to int_max
int(zero_centered_weight),
is_rms_norm,
BLOCK_N,
residual is not None,
residual_out is not None,
weight is not None,
bias is not None,
dropout_p > 0.0,
dropout_mask is not None,
rowscale is not None,
HAS_X1=x1 is not None,
HAS_W1=weight1 is not None,
HAS_B1=bias1 is not None,
)
return y1, mean, rstd, seeds, dropout_mask, dropout_mask1
class LayerNormFn:
@staticmethod
def forward(
x,
weight,
bias,
residual=None,
x1=None,
weight1=None,
bias1=None,
eps=1e-6,
dropout_p=0.0,
rowscale=None,
prenorm=False,
residual_in_fp32=False,
zero_centered_weight=False,
is_rms_norm=False,
return_dropout_mask=False,
out_dtype=None,
out=None,
residual_out=None,
):
x_shape_og = x.shape
# reshape input data into 2D tensor
x = maybe_contiguous_lastdim(x.reshape(-1, x.shape[-1]))
if residual is not None:
assert residual.shape == x_shape_og
residual = maybe_contiguous_lastdim(residual.reshape(-1, residual.shape[-1]))
if x1 is not None:
assert x1.shape == x_shape_og
assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
x1 = maybe_contiguous_lastdim(x1.reshape(-1, x1.shape[-1]))
# weight can be None when elementwise_affine=False for LayerNorm
if weight is not None:
weight = weight.contiguous()
bias = maybe_contiguous(bias)
weight1 = maybe_contiguous(weight1)
bias1 = maybe_contiguous(bias1)
if rowscale is not None:
rowscale = rowscale.reshape(-1).contiguous()
residual_dtype = residual.dtype if residual is not None else (torch.float32 if residual_in_fp32 else None)
if out is not None:
out = out.reshape(-1, out.shape[-1])
if residual_out is not None:
residual_out = residual_out.reshape(-1, residual_out.shape[-1])
y, y1, mean, rstd, residual_out, seeds, dropout_mask, dropout_mask1 = _layer_norm_fwd(
x,
weight,
bias,
eps,
residual,
x1,
weight1,
bias1,
dropout_p=dropout_p,
rowscale=rowscale,
out_dtype=out_dtype,
residual_dtype=residual_dtype,
zero_centered_weight=zero_centered_weight,
is_rms_norm=is_rms_norm,
return_dropout_mask=return_dropout_mask,
out=out,
residual_out=residual_out,
)
y = y.reshape(x_shape_og)
return y
def layer_norm_fn(
x,
weight,
bias,
residual=None,
x1=None,
weight1=None,
bias1=None,
eps=1e-6,
dropout_p=0.0,
rowscale=None,
prenorm=False,
residual_in_fp32=False,
zero_centered_weight=False,
is_rms_norm=False,
return_dropout_mask=False,
out_dtype=None,
out=None,
residual_out=None,
):
return LayerNormFn.forward(
x,
weight,
bias,
residual,
x1,
weight1,
bias1,
eps,
dropout_p,
rowscale,
prenorm,
residual_in_fp32,
zero_centered_weight,
is_rms_norm,
return_dropout_mask,
out_dtype,
out,
residual_out,
)
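# Usage sketch (illustrative): with no residual, dropout, rowscale or secondary input,
#   y = layer_norm_fn(x, weight, bias, eps=1e-6)
# matches torch.nn.functional.layer_norm(x, (x.shape[-1],), weight, bias, eps) up to the
# numerical differences of the fused kernel.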
@triton.jit
def _norm_infer_kernel(
X,
Y,
W,
B,
stride_x_row,
stride_y_row,
M,
N,
eps,
IS_RMS_NORM: tl.constexpr,
HAS_WEIGHT: tl.constexpr,
HAS_BIAS: tl.constexpr,
BLOCK_N: tl.constexpr,
):
row = tl.program_id(0)
X += row * stride_x_row
Y += row * stride_y_row
if HAS_WEIGHT:
W += 0
if HAS_BIAS:
B += 0
cols = tl.arange(0, BLOCK_N)
x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
if not IS_RMS_NORM:
mean = tl.sum(x, axis=0) / N
xbar = tl.where(cols < N, x - mean, 0.0)
var = tl.sum(xbar * xbar, axis=0) / N
else:
xbar = tl.where(cols < N, x, 0.0)
var = tl.sum(xbar * xbar, axis=0) / N
rstd = 1 / tl.sqrt(var + eps)
x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
if HAS_WEIGHT:
w = tl.load(W + cols, mask=cols < N, other=1.0).to(tl.float32)
y = x_hat * w
else:
y = x_hat
if HAS_BIAS:
b = tl.load(B + cols, mask=cols < N, other=0.0).to(tl.float32)
y += b
tl.store(Y + cols, y, mask=cols < N)
def norm_infer(
x: Tensor,
weight: Optional[Tensor],
bias: Optional[Tensor],
eps: float,
is_rms_norm: bool = False,
out: Optional[Tensor] = None,
):
M, N = x.shape
assert x.stride(-1) == 1
if weight is not None:
assert weight.shape == (N,)
assert weight.stride(-1) == 1
if bias is not None:
assert bias.shape == (N,)
assert bias.stride(-1) == 1
if out is None:
out = torch.empty_like(x)
MAX_FUSED_SIZE = 65536 // x.element_size()
BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
if N > BLOCK_N:
raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
num_warps = min(max(BLOCK_N // 256, 1), 8)
_norm_infer_kernel[(M,)](
x,
out,
weight if weight is not None else x, # dummy when HAS_WEIGHT=False
bias if bias is not None else x, # dummy when HAS_BIAS=False
x.stride(0),
out.stride(0),
M,
N,
eps,
IS_RMS_NORM=is_rms_norm,
HAS_WEIGHT=weight is not None,
HAS_BIAS=bias is not None,
BLOCK_N=BLOCK_N,
num_warps=num_warps,
)
return out
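# Usage sketch (illustrative): norm_infer is the inference-only kernel imported by the
# Triton-registered norm weights earlier in this commit, e.g.
#   y = norm_infer(x.view(-1, x.shape[-1]), weight, bias, eps)                    # LayerNorm
#   y = norm_infer(x.view(-1, x.shape[-1]), weight, None, eps, is_rms_norm=True)  # RMSNorm
# The input must be 2D with a contiguous last dimension; the output dtype follows x.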
def rms_norm_fn(
x,
weight,
bias,
residual=None,
x1=None,
weight1=None,
bias1=None,
eps=1e-6,
dropout_p=0.0,
rowscale=None,
prenorm=False,
residual_in_fp32=False,
zero_centered_weight=False,
return_dropout_mask=False,
out_dtype=None,
out=None,
residual_out=None,
):
return LayerNormFn.forward(
x,
weight,
bias,
residual,
x1,
weight1,
bias1,
eps,
dropout_p,
rowscale,
prenorm,
residual_in_fp32,
zero_centered_weight,
True,
return_dropout_mask,
out_dtype,
out,
residual_out,
)
from .tensor import DefaultTensor
import os
import re
from pathlib import Path
import torch
from safetensors import safe_open
from lightx2v.utils.envs import *
from lightx2v.utils.registry_factory import TENSOR_REGISTER
from lightx2v_platform.base.global_var import AI_DEVICE
@TENSOR_REGISTER("Default")
class DefaultTensor:
def __init__(self, tensor_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
self.tensor_name = tensor_name
self.lazy_load = lazy_load
self.lazy_load_file = lazy_load_file
self.is_post_adapter = is_post_adapter
self.create_cuda_buffer = create_cuda_buffer
self.create_cpu_buffer = create_cpu_buffer
self.infer_dtype = GET_DTYPE()
self.sensitive_layer_dtype = GET_SENSITIVE_DTYPE()
def load(self, weight_dict):
if self.create_cuda_buffer:
self._load_cuda_buffer(weight_dict)
elif self.create_cpu_buffer:
self._load_cpu_pin_buffer()
else:
self._load_default_tensors(weight_dict)
def _load_default_tensors(self, weight_dict):
if not self.lazy_load:
device = weight_dict[self.tensor_name].device
if device.type == "cpu":
tensor = weight_dict[self.tensor_name]
self.pin_tensor = self._create_cpu_pin_tensor(tensor)
del weight_dict[self.tensor_name]
else:
self.tensor = weight_dict[self.tensor_name]
def _get_tensor(self, weight_dict=None, use_infer_dtype=False):
if self.lazy_load:
if Path(self.lazy_load_file).is_file():
lazy_load_file_path = self.lazy_load_file
else:
lazy_load_file_path = os.path.join(self.lazy_load_file, f"block_{self.tensor_name.split('.')[1]}.safetensors")
with safe_open(lazy_load_file_path, framework="pt", device="cpu") as lazy_load_file:
tensor = lazy_load_file.get_tensor(self.tensor_name)
if use_infer_dtype:
tensor = tensor.to(self.infer_dtype)
else:
tensor = weight_dict[self.tensor_name]
return tensor
def _create_cpu_pin_tensor(self, tensor):
pin_tensor = torch.empty(tensor.shape, pin_memory=True, dtype=tensor.dtype)
pin_tensor.copy_(tensor)
del tensor
return pin_tensor
def _load_cuda_buffer(self, weight_dict):
tensor = self._get_tensor(weight_dict, use_infer_dtype=self.lazy_load)
self.tensor_cuda_buffer = tensor.to(AI_DEVICE)
def _load_cpu_pin_buffer(self):
tensor = self._get_tensor(use_infer_dtype=True)
self.pin_tensor = self._create_cpu_pin_tensor(tensor)
def to_cuda(self, non_blocking=False):
self.tensor = self.pin_tensor.to(AI_DEVICE, non_blocking=non_blocking)
def to_cpu(self, non_blocking=False):
if hasattr(self, "pin_tensor"):
self.tensor = self.pin_tensor.copy_(self.tensor, non_blocking=non_blocking).cpu()
else:
self.tensor = self.tensor.to("cpu", non_blocking=non_blocking)
def state_dict(self, destination=None):
if destination is None:
destination = {}
destination[self.tensor_name] = self.pin_tensor if hasattr(self, "pin_tensor") else self.tensor
return destination
def load_state_dict(self, destination, block_index, adapter_block_index=None):
if self.is_post_adapter:
assert adapter_block_index is not None
tensor_name = re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", self.tensor_name, count=1)
else:
tensor_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.tensor_name, count=1)
if tensor_name not in destination:
self.tensor = None
return
self.tensor = self.tensor_cuda_buffer.copy_(destination[tensor_name], non_blocking=True)
def load_state_dict_from_disk(self, block_index, adapter_block_index=None):
if self.is_post_adapter:
assert adapter_block_index is not None
self.tensor_name = re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", self.tensor_name, count=1)
else:
self.tensor_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.tensor_name, count=1)
if Path(self.lazy_load_file).is_file():
lazy_load_file_path = self.lazy_load_file
else:
lazy_load_file_path = os.path.join(self.lazy_load_file, f"block_{block_index}.safetensors")
with safe_open(lazy_load_file_path, framework="pt", device="cpu") as lazy_load_file:
tensor = lazy_load_file.get_tensor(self.tensor_name).to(self.infer_dtype)
self.pin_tensor = self.pin_tensor.copy_(tensor)
del tensor
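# Usage sketch (illustrative; the tensor name is a hypothetical placeholder): non-weight
# tensors follow the same offload lifecycle as the norm weights above:
#   t = DefaultTensor("blocks.0.modulation")
#   t.load(weight_dict)   # pins the tensor when the state dict lives on CPU
#   t.to_cuda()           # then t.tensor is available on AI_DEVICE
#   t.to_cpu()            # offload again when the block is swapped out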
import math
from abc import ABC, abstractmethod
class BaseTransformerInfer(ABC):
@abstractmethod
def infer(self):
pass
def set_scheduler(self, scheduler):
self.scheduler = scheduler
self.scheduler.transformer_infer = self
class BaseTaylorCachingTransformerInfer(BaseTransformerInfer):
@abstractmethod
def infer_calculating(self):
pass
@abstractmethod
def infer_using_cache(self):
pass
@abstractmethod
def get_taylor_step_diff(self):
pass
# 1. When a block is fully calculated, store its output (and a finite-difference derivative) in the cache
def derivative_approximation(self, block_cache, module_name, out):
if module_name not in block_cache:
block_cache[module_name] = {0: out}
else:
step_diff = self.get_taylor_step_diff()
previous_out = block_cache[module_name][0]
block_cache[module_name][0] = out
block_cache[module_name][1] = (out - previous_out) / step_diff
def taylor_formula(self, tensor_dict):
x = self.get_taylor_step_diff()
output = 0
for i in range(len(tensor_dict)):
output += (1 / math.factorial(i)) * tensor_dict[i] * (x**i)
return output
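# Worked example (illustrative): derivative_approximation caches the latest output and a
# first-order finite difference,
#   block_cache["attn"] = {0: out_t, 1: (out_t - out_prev) / step_diff}
# and taylor_formula then extrapolates a skipped step x ahead as
#   out_{t+x} ≈ cache[0] + (1 / 1!) * cache[1] * x
# i.e. a first-order Taylor expansion around the last fully calculated step.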
import asyncio
import json
import os
import sys
from alibabacloud_dypnsapi20170525 import models as dypnsapi_models
from alibabacloud_dypnsapi20170525.client import Client
from alibabacloud_tea_openapi import models as openapi_models
from alibabacloud_tea_util import models as util_models
from loguru import logger
class AlibabaCloudClient:
def __init__(self):
config = openapi_models.Config(
access_key_id=os.getenv("ALIBABA_CLOUD_ACCESS_KEY_ID"),
access_key_secret=os.getenv("ALIBABA_CLOUD_ACCESS_KEY_SECRET"),
https_proxy=os.getenv("auth_https_proxy", None),
)
self.client = Client(config)
self.runtime = util_models.RuntimeOptions()
def check_ok(self, res, prefix):
logger.info(f"{prefix}: {res}")
if not isinstance(res, dict) or "statusCode" not in res or res["statusCode"] != 200:
logger.warning(f"{prefix}: error response: {res}")
return False
if "body" not in res or "Code" not in res["body"] or "Success" not in res["body"]:
logger.warning(f"{prefix}: error body: {res}")
return False
if res["body"]["Code"] != "OK" or res["body"]["Success"] is not True:
logger.warning(f"{prefix}: sms error: {res}")
return False
return True
async def send_sms(self, phone_number):
try:
req = dypnsapi_models.SendSmsVerifyCodeRequest(
phone_number=phone_number,
sign_name="速通互联验证服务",
template_code="100001",
template_param=json.dumps({"code": "##code##", "min": "5"}),
valid_time=300,
)
res = await self.client.send_sms_verify_code_with_options_async(req, self.runtime)
ok = self.check_ok(res.to_map(), "AlibabaCloudClient send sms")
logger.info(f"AlibabaCloudClient send sms for {phone_number}: {ok}")
return ok
except Exception as e:
logger.warning(f"AlibabaCloudClient send sms for {phone_number}: {e}")
return False
async def check_sms(self, phone_number, verify_code):
try:
req = dypnsapi_models.CheckSmsVerifyCodeRequest(
phone_number=phone_number,
verify_code=verify_code,
)
res = await self.client.check_sms_verify_code_with_options_async(req, self.runtime)
ok = self.check_ok(res.to_map(), "AlibabaCloudClient check sms")
logger.info(f"AlibabaCloudClient check sms for {phone_number} with {verify_code}: {ok}")
return ok
except Exception as e:
logger.warning(f"AlibabaCloudClient check sms for {phone_number} with {verify_code}: {e}")
return False
async def test(args):
assert len(args) in [1, 2], "Usage: python aliyun_sms.py <phone_number> [verify_code]"
phone_number = args[0]
client = AlibabaCloudClient()
if len(args) == 1:
await client.send_sms(phone_number)
else:
await client.check_sms(phone_number, args[1])
if __name__ == "__main__":
asyncio.run(test(sys.argv[1:]))
# -*- coding: utf-8 -*-
"""
Audio Source Separation Module
Separates different voice tracks in audio, supports multi-person audio separation
"""
import base64
import io
import os
import tempfile
import traceback
from collections import defaultdict
from typing import Dict, Optional, Union
import torch
import torchaudio
from loguru import logger
# Import pyannote.audio for speaker diarization
from pyannote.audio import Audio, Pipeline
_origin_torch_load = torch.load
def our_torch_load(checkpoint_file, *args, **kwargs):
kwargs["weights_only"] = False
return _origin_torch_load(checkpoint_file, *args, **kwargs)
class AudioSeparator:
"""
Audio separator that isolates individual speakers' voice tracks from an audio clip using pyannote.audio.
Supports multi-speaker conversations; each output track keeps the original duration, with other speakers' segments left silent.
"""
def __init__(
self,
model_path: str = None,
device: str = None,
sample_rate: int = 16000,
):
"""
Initialize audio separator
Args:
model_path: Model path (if using custom model), default uses pyannote/speaker-diarization-community-1
device: Device ('cpu', 'cuda', etc.), None for auto selection
sample_rate: Target sample rate, default 16000
"""
self.sample_rate = sample_rate
self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu")
self._init_pyannote(model_path)
def _init_pyannote(self, model_path: str = None):
"""Initialize pyannote.audio pipeline"""
try:
huggingface_token = os.getenv("HUGGINGFACE_TOKEN") or os.getenv("HF_TOKEN")
model_name = model_path or "pyannote/speaker-diarization-community-1"
try:
torch.load = our_torch_load
# Try loading with token if available
if huggingface_token:
self.pipeline = Pipeline.from_pretrained(model_name, token=huggingface_token)
else:
# Try without token (may work for public models)
self.pipeline = Pipeline.from_pretrained(model_name)
except Exception as e:
if "gated" in str(e).lower() or "token" in str(e).lower():
raise RuntimeError(f"Model requires authentication. Set HUGGINGFACE_TOKEN or HF_TOKEN environment variable: {e}")
raise RuntimeError(f"Failed to load pyannote model: {e}")
finally:
torch.load = _origin_torch_load
# Move pipeline to specified device
if self.device:
self.pipeline.to(torch.device(self.device))
# Initialize Audio helper for waveform loading
self.pyannote_audio = Audio()
logger.info("Initialized pyannote.audio speaker diarization pipeline")
except Exception as e:
logger.error(f"Failed to initialize pyannote: {e}")
raise RuntimeError(f"Failed to initialize pyannote.audio pipeline: {e}")
def separate_speakers(
self,
audio_path: Union[str, bytes],
num_speakers: Optional[int] = None,
min_speakers: int = 1,
max_speakers: int = 5,
) -> Dict:
"""
Separate different speakers in audio
Args:
audio_path: Audio file path or bytes data
num_speakers: Specified number of speakers, None for auto detection
min_speakers: Minimum number of speakers
max_speakers: Maximum number of speakers
Returns:
Dict containing:
- speakers: List of speaker audio segments, each containing:
- speaker_id: Speaker ID (0, 1, 2, ...)
- audio: torch.Tensor audio data [channels, samples]
- segments: List of (start_time, end_time) tuples
- sample_rate: Sample rate
"""
try:
# Load audio
if isinstance(audio_path, bytes):
# Try to infer the audio format from the raw bytes
# Check for WAV (RIFF header)
is_wav = audio_path[:4] == b"RIFF" and audio_path[8:12] == b"WAVE"
# Check for MP3 (ID3 tag or MPEG frame header)
is_mp3 = audio_path[:3] == b"ID3" or audio_path[:2] == b"\xff\xfb" or audio_path[:2] == b"\xff\xf3"
# Pick the temp-file suffix based on the detected format
if is_wav:
suffix = ".wav"
elif is_mp3:
suffix = ".mp3"
else:
# Default to WAV; loading will raise if the format is actually unsupported
suffix = ".wav"
with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp_file:
tmp_file.write(audio_path)
tmp_audio_path = tmp_file.name
try:
result = self._separate_speakers_internal(tmp_audio_path, num_speakers, min_speakers, max_speakers)
finally:
# Make sure the temporary file is removed
try:
os.unlink(tmp_audio_path)
except Exception as e:
logger.warning(f"Failed to delete temp file {tmp_audio_path}: {e}")
return result
else:
return self._separate_speakers_internal(audio_path, num_speakers, min_speakers, max_speakers)
except Exception as e:
logger.error(f"Speaker separation failed: {traceback.format_exc()}")
raise RuntimeError(f"Audio separation error: {e}")
def _separate_speakers_internal(
self,
audio_path: str,
num_speakers: Optional[int] = None,
min_speakers: int = 1,
max_speakers: int = 5,
) -> Dict:
"""Internal method: execute speaker separation"""
# Load audio
waveform, original_sr = torchaudio.load(audio_path)
if original_sr != self.sample_rate:
resampler = torchaudio.transforms.Resample(original_sr, self.sample_rate)
waveform = resampler(waveform)
# Convert to mono if stereo
if waveform.shape[0] > 1:
waveform = waveform.mean(dim=0, keepdim=True)
# Ensure waveform is float32 and normalized (pyannote expects this format)
if waveform.dtype != torch.float32:
waveform = waveform.float()
# Ensure waveform is in range [-1, 1] (normalize if needed)
if waveform.abs().max() > 1.0:
waveform = waveform / waveform.abs().max()
if self.pipeline is None:
raise RuntimeError("Pyannote pipeline not initialized")
return self._separate_with_pyannote(audio_path, waveform, num_speakers, min_speakers, max_speakers)
def _separate_with_pyannote(
self,
audio_path: str,
waveform: torch.Tensor,
num_speakers: Optional[int],
min_speakers: int,
max_speakers: int,
) -> Dict:
"""Use pyannote.audio for speaker diarization"""
try:
# Use waveform dict to avoid AudioDecoder dependency issues
# Pipeline can accept either file path or waveform dict
# Using waveform dict is more reliable when torchcodec is not properly installed
audio_input = {
"waveform": waveform,
"sample_rate": self.sample_rate,
}
# Run speaker diarization
output = self.pipeline(
audio_input,
min_speakers=min_speakers if num_speakers is None else num_speakers,
max_speakers=max_speakers if num_speakers is None else num_speakers,
)
# Extract audio segments for each speaker
speakers_dict = defaultdict(list)
for turn, speaker in output.speaker_diarization:
print(f"Speaker: {speaker}, Start time: {turn.start}, End time: {turn.end}")
start_time = turn.start
end_time = turn.end
start_sample = int(start_time * self.sample_rate)
end_sample = int(end_time * self.sample_rate)
# Extract audio segment for this time period
segment_audio = waveform[:, start_sample:end_sample]
speakers_dict[speaker].append((start_time, end_time, segment_audio))
# Generate complete audio for each speaker (other speakers' segments are empty)
speakers = []
audio_duration = waveform.shape[1] / self.sample_rate
num_samples = waveform.shape[1]
for speaker_id, segments in speakers_dict.items():
# Create zero-filled audio
speaker_audio = torch.zeros_like(waveform)
# Fill in this speaker's segments
for start_time, end_time, segment_audio in segments:
start_sample = int(start_time * self.sample_rate)
end_sample = int(end_time * self.sample_rate)
# Ensure no out-of-bounds
end_sample = min(end_sample, num_samples)
segment_len = end_sample - start_sample
if segment_len > 0 and segment_audio.shape[1] > 0:
actual_len = min(segment_len, segment_audio.shape[1])
speaker_audio[:, start_sample : start_sample + actual_len] = segment_audio[:, :actual_len]
speakers.append(
{
"speaker_id": speaker_id,
"audio": speaker_audio,
"segments": [(s[0], s[1]) for s in segments],
"sample_rate": self.sample_rate,
}
)
logger.info(f"Separated audio into {len(speakers)} speakers using pyannote")
return {"speakers": speakers, "method": "pyannote"}
except Exception as e:
logger.error(f"Pyannote separation failed: {e}")
raise RuntimeError(f"Audio separation failed: {e}")
def save_speaker_audio(self, speaker_audio: torch.Tensor, output_path: str, sample_rate: int = None):
"""
Save speaker audio to file
Args:
speaker_audio: Audio tensor [channels, samples]
output_path: Output path
sample_rate: Sample rate, if None uses self.sample_rate
"""
sr = sample_rate if sample_rate else self.sample_rate
torchaudio.save(output_path, speaker_audio, sr)
logger.info(f"Saved speaker audio to {output_path}")
def speaker_audio_to_base64(self, speaker_audio: torch.Tensor, sample_rate: int = None, format: str = "wav") -> str:
"""
Convert speaker audio tensor to base64 encoded string without saving to file
Args:
speaker_audio: Audio tensor [channels, samples]
sample_rate: Sample rate, if None uses self.sample_rate
format: Audio format (default: "wav")
Returns:
Base64 encoded audio string
"""
sr = sample_rate if sample_rate else self.sample_rate
# Use BytesIO to save audio to memory
buffer = io.BytesIO()
torchaudio.save(buffer, speaker_audio, sr, format=format)
# Get the audio bytes
audio_bytes = buffer.getvalue()
# Encode to base64
audio_base64 = base64.b64encode(audio_bytes).decode("utf-8")
logger.debug(f"Converted speaker audio to base64, size: {len(audio_bytes)} bytes")
return audio_base64
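# Round-trip sketch (illustrative): the base64 string can be decoded back into a waveform,
# assuming a torchaudio backend that accepts file-like objects:
#   data = base64.b64decode(audio_base64)
#   waveform, sr = torchaudio.load(io.BytesIO(data))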
def separate_and_save(
self,
audio_path: Union[str, bytes],
output_dir: str,
num_speakers: Optional[int] = None,
min_speakers: int = 1,
max_speakers: int = 5,
) -> Dict:
"""
Separate audio and save to files
Args:
audio_path: Input audio path or bytes data
output_dir: Output directory
num_speakers: Specified number of speakers
min_speakers: Minimum number of speakers
max_speakers: Maximum number of speakers
Returns:
Separation result dictionary, containing output file paths
"""
os.makedirs(output_dir, exist_ok=True)
result = self.separate_speakers(audio_path, num_speakers, min_speakers, max_speakers)
output_paths = []
for speaker in result["speakers"]:
speaker_id = speaker["speaker_id"]
output_path = os.path.join(output_dir, f"{speaker_id}.wav")
self.save_speaker_audio(speaker["audio"], output_path, speaker["sample_rate"])
output_paths.append(output_path)
speaker["output_path"] = output_path
result["output_paths"] = output_paths
return result
def separate_audio_tracks(
audio_path: str,
output_dir: str = None,
num_speakers: int = None,
model_path: str = None,
) -> Dict:
"""
Convenience function: separate different audio tracks
Args:
audio_path: Audio file path
output_dir: Output directory, if None does not save files
num_speakers: Number of speakers
model_path: Model path (optional)
Returns:
Separation result dictionary
"""
separator = AudioSeparator(model_path=model_path)
if output_dir:
return separator.separate_and_save(audio_path, output_dir, num_speakers=num_speakers)
else:
return separator.separate_speakers(audio_path, num_speakers=num_speakers)
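# Usage sketch (illustrative, hypothetical path): when output_dir is None the separated
# waveforms are only returned in memory, e.g.
#   result = separate_audio_tracks("clip.wav", num_speakers=2)
#   first_track = result["speakers"][0]["audio"]  # torch.Tensor [channels, samples]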
if __name__ == "__main__":
# Test code
import sys
if len(sys.argv) < 2:
print("Usage: python audio_separator.py <audio_path> [output_dir] [num_speakers]")
sys.exit(1)
audio_path = sys.argv[1]
output_dir = sys.argv[2] if len(sys.argv) > 2 else "./separated_audio"
num_speakers = int(sys.argv[3]) if len(sys.argv) > 3 else None
separator = AudioSeparator()
result = separator.separate_and_save(audio_path, output_dir, num_speakers=num_speakers)
print(f"Separated audio into {len(result['speakers'])} speakers:")
for speaker in result["speakers"]:
print(f" Speaker {speaker['speaker_id']}: {len(speaker['segments'])} segments")
if "output_path" in speaker:
print(f" Saved to: {speaker['output_path']}")
# -*- coding: utf-8 -*-
"""
Face Detection Module using YOLO
Supports detecting faces in images, including human faces, animal faces, anime faces, sketches, etc.
"""
import io
import traceback
from typing import Dict, List, Union
import numpy as np
from PIL import Image, ImageDraw
from loguru import logger
from ultralytics import YOLO
class FaceDetector:
"""
Face detection using YOLO models
Supports detecting: human faces, animal faces, anime faces, sketch faces, etc.
"""
def __init__(self, model_path: str = None, conf_threshold: float = 0.25, device: str = None):
"""
Initialize face detector
Args:
model_path: YOLO model path, if None uses default pretrained model
conf_threshold: Confidence threshold, default 0.25
device: Device ('cpu', 'cuda', '0', '1', etc.), None for auto selection
"""
self.conf_threshold = conf_threshold
self.device = device
if model_path is None:
# Use YOLO11 pretrained model, can detect COCO dataset classes (including person)
# Or use dedicated face detection model
logger.info("Loading default YOLO11n model for face detection")
try:
self.model = YOLO("yolo11n.pt") # Lightweight model
except Exception as e:
logger.warning(f"Failed to load default model, trying yolov8n: {e}")
self.model = YOLO("yolov8n.pt")
else:
logger.info(f"Loading YOLO model from {model_path}")
self.model = YOLO(model_path)
# Person class ID in the COCO dataset is 0
# YOLO detects the person class; for more precise face detection, use a dedicated face model
# such as YOLOv8-face or RetinaFace, which can be specified via the model_path parameter
# First use YOLO to detect the person region, then faces can be detected within it
self.target_classes = {
"person": 0, # Face (by detecting person class)
# Can be extended to detect animal faces (cat, dog, etc.) and other classes
}
def detect_faces(
self,
image: Union[str, Image.Image, bytes, np.ndarray],
return_image: bool = False,
) -> Dict:
"""
Detect faces in image
Args:
image: Input image, can be path, PIL Image, bytes or numpy array
return_image: Whether to return annotated image with detection boxes
Returns:
Dict containing:
- faces: List of face detection results, each containing:
- bbox: [x1, y1, x2, y2] bounding box coordinates (absolute pixel coordinates)
- confidence: Confidence score (0.0-1.0)
- class_id: Class ID
- class_name: Class name
- image (optional): PIL Image with detection boxes drawn (if return_image=True)
"""
try:
# Load image
if isinstance(image, str):
img = Image.open(image).convert("RGB")
elif isinstance(image, bytes):
img = Image.open(io.BytesIO(image)).convert("RGB")
elif isinstance(image, np.ndarray):
img = Image.fromarray(image).convert("RGB")
elif isinstance(image, Image.Image):
img = image.convert("RGB")
else:
raise ValueError(f"Unsupported image type: {type(image)}")
# Use YOLO for detection
# Note: the default YOLO model detects the person class, which we use as a proxy for faces
# For more precise face detection, train or use a dedicated face detection model
results = self.model.predict(
source=img,
conf=self.conf_threshold,
device=self.device,
verbose=False,
)
faces = []
annotated_img = img.copy() if return_image else None
if len(results) > 0:
result = results[0]
boxes = result.boxes
if boxes is not None and len(boxes) > 0:
for i in range(len(boxes)):
# Get bounding box coordinates (xyxy format)
bbox = boxes.xyxy[i].cpu().numpy().tolist()
confidence = float(boxes.conf[i].cpu().numpy())
class_id = int(boxes.cls[i].cpu().numpy())
# Get class name
class_name = result.names.get(class_id, "unknown")
# Process target classes (person, etc.)
# For person, the entire body box contains face region
# For more precise face detection, can:
# 1. Use dedicated face detection models (RetinaFace, YOLOv8-face)
# 2. Further use face detection model within current person box
# 3. Use specifically trained multi-class detection models (faces, animal faces, anime faces, etc.)
if class_id in self.target_classes.values():
face_info = {
"bbox": bbox, # [x1, y1, x2, y2] - absolute pixel coordinates
"confidence": confidence,
"class_id": class_id,
"class_name": class_name,
}
faces.append(face_info)
# Draw annotations on image if needed
if return_image and annotated_img is not None:
draw = ImageDraw.Draw(annotated_img)
x1, y1, x2, y2 = bbox
# Draw bounding box
draw.rectangle(
[x1, y1, x2, y2],
outline="red",
width=2,
)
# Draw label
label = f"{class_name} {confidence:.2f}"
draw.text((x1, y1 - 15), label, fill="red")
result_dict = {"faces": faces}
if return_image and annotated_img is not None:
result_dict["image"] = annotated_img
logger.info(f"Detected {len(faces)} faces in image")
return result_dict
except Exception as e:
logger.error(f"Face detection failed: {traceback.format_exc()}")
raise RuntimeError(f"Face detection error: {e}")
def detect_faces_from_bytes(self, image_bytes: bytes, **kwargs) -> Dict:
"""
Detect faces from byte data
Args:
image_bytes: Image byte data
**kwargs: Additional parameters passed to detect_faces
Returns:
Detection result dictionary
"""
return self.detect_faces(image_bytes, **kwargs)
def extract_face_regions(self, image: Union[str, Image.Image, bytes], expand_ratio: float = 0.1) -> List[Image.Image]:
"""
Extract detected face regions
Args:
image: Input image
expand_ratio: Bounding box expansion ratio to include more context
Returns:
List of extracted face region images
"""
result = self.detect_faces(image)
faces = result["faces"]
# Load original image
if isinstance(image, str):
img = Image.open(image).convert("RGB")
elif isinstance(image, bytes):
img = Image.open(io.BytesIO(image)).convert("RGB")
elif isinstance(image, Image.Image):
img = image.convert("RGB")
else:
raise ValueError(f"Unsupported image type: {type(image)}")
face_regions = []
img_width, img_height = img.size
for face in faces:
x1, y1, x2, y2 = face["bbox"]
# Expand bounding box
width = x2 - x1
height = y2 - y1
expand_x = width * expand_ratio
expand_y = height * expand_ratio
x1 = max(0, int(x1 - expand_x))
y1 = max(0, int(y1 - expand_y))
x2 = min(img_width, int(x2 + expand_x))
y2 = min(img_height, int(y2 + expand_y))
# Crop region
face_region = img.crop((x1, y1, x2, y2))
face_regions.append(face_region)
return face_regions
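# Illustrative sketch (assumes a local "portrait.jpg"): cropping each detected region with
# 20% padding and saving the crops; file names are hypothetical.
#
#   detector = FaceDetector(conf_threshold=0.3)
#   regions = detector.extract_face_regions("portrait.jpg", expand_ratio=0.2)
#   for idx, region in enumerate(regions):
#       region.save(f"face_{idx}.png")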
def count_faces(self, image: Union[str, Image.Image, bytes]) -> int:
"""
Count number of faces in image
Args:
image: Input image
Returns:
Number of detected faces
"""
result = self.detect_faces(image, return_image=False)
return len(result["faces"])
def detect_faces_in_image(
image_path: str,
model_path: str = None,
conf_threshold: float = 0.25,
return_image: bool = False,
) -> Dict:
"""
Convenience function: detect faces in image
Args:
image_path: Image path
model_path: YOLO model path
conf_threshold: Confidence threshold
return_image: Whether to return annotated image
Returns:
Detection result dictionary containing:
- faces: List of face detection results with bbox coordinates [x1, y1, x2, y2]
- image (optional): Annotated image with detection boxes
"""
detector = FaceDetector(model_path=model_path, conf_threshold=conf_threshold)
return detector.detect_faces(image_path, return_image=return_image)
if __name__ == "__main__":
# Test code
import sys
if len(sys.argv) < 2:
print("Usage: python face_detector.py <image_path>")
sys.exit(1)
image_path = sys.argv[1]
detector = FaceDetector()
result = detector.detect_faces(image_path, return_image=True)
print(f"Detected {len(result['faces'])} faces:")
for i, face in enumerate(result["faces"]):
print(f" Face {i + 1}: {face}")
output_path = "detected_faces.png"
result["image"].save(output_path)
print(f"Annotated image saved to: {output_path}")
import json
import sys
from loguru import logger
class Pipeline:
def __init__(self, pipeline_json_file):
self.pipeline_json_file = pipeline_json_file
with open(pipeline_json_file) as f:
x = json.load(f)
self.data = x["data"]
self.meta = x["meta"]
self.inputs = {}
self.outputs = {}
self.temps = {}
self.model_lists = []
self.types = {}
self.queues = set()
self.model_name_inner_to_outer = self.meta.get("model_name_inner_to_outer", {})
self.model_name_outer_to_inner = self.meta.get("model_name_outer_to_inner", {})
self.tidy_pipeline()
def init_dict(self, base, task, model_cls):
if task not in base:
base[task] = {}
if model_cls not in base[task]:
base[task][model_cls] = {}
# tidy each task item, e.g. ['t2v', 'wan2.1', 'multi_stage']
def tidy_task(self, task, model_cls, stage, v3):
out2worker = {}
out2num = {}
cur_inps = set()
cur_temps = set()
cur_types = {}
for worker_name, worker_item in v3.items():
prevs = []
for inp in worker_item["inputs"]:
cur_types[inp] = self.get_type(inp)
if inp in out2worker:
prevs.append(out2worker[inp])
out2num[inp] -= 1
if out2num[inp] <= 0:
cur_temps.add(inp)
else:
cur_inps.add(inp)
worker_item["previous"] = prevs
for out in worker_item["outputs"]:
cur_types[out] = self.get_type(out)
out2worker[out] = worker_name
if out not in out2num:
out2num[out] = 0
out2num[out] += 1
if "queue" not in worker_item:
worker_item["queue"] = "-".join([task, model_cls, stage, worker_name])
self.queues.add(worker_item["queue"])
cur_outs = [out for out, num in out2num.items() if num > 0]
self.inputs[task][model_cls][stage] = list(cur_inps)
self.outputs[task][model_cls][stage] = cur_outs
self.temps[task][model_cls][stage] = list(cur_temps)
self.types[task][model_cls][stage] = cur_types
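# Worked example (hypothetical two-worker stage, assuming "text_encoder" appears before
# "dit" in the JSON) of what tidy_task derives:
# given workers {"text_encoder": {"inputs": ["prompt"], "outputs": ["text_emb"]},
#                "dit": {"inputs": ["text_emb"], "outputs": ["latents"]}},
# "prompt" is never produced by a worker, so it becomes a stage input;
# "text_emb" is produced and fully consumed, so it becomes a temp;
# "latents" is produced but never consumed, so it remains a stage output;
# "dit" gets previous=["text_encoder"] and a queue name like "t2v-wan2.1-multi_stage-dit".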
# tidy previous dependent workers and queue names
def tidy_pipeline(self):
for task, v1 in self.data.items():
for model_cls, v2 in v1.items():
for stage, v3 in v2.items():
self.init_dict(self.inputs, task, model_cls)
self.init_dict(self.outputs, task, model_cls)
self.init_dict(self.temps, task, model_cls)
self.init_dict(self.types, task, model_cls)
self.tidy_task(task, model_cls, stage, v3)
self.model_lists.append({"task": task, "model_cls": model_cls, "stage": stage})
logger.info(f"pipelines: {json.dumps(self.data, indent=4)}")
logger.info(f"inputs: {self.inputs}")
logger.info(f"outputs: {self.outputs}")
logger.info(f"temps: {self.temps}")
logger.info(f"types: {self.types}")
logger.info(f"model_lists: {self.model_lists}")
logger.info(f"queues: {self.queues}")
def get_item_by_keys(self, keys):
item = self.data
for k in keys:
if k not in item:
raise Exception(f"{keys} are not in {self.pipeline_json_file}!")
item = item[k]
return item
# eg. keys: ['t2v', 'wan2.1', 'multi_stage', 'text_encoder']
def get_worker(self, keys):
return self.get_item_by_keys(keys)
# eg. keys: ['t2v', 'wan2.1', 'multi_stage']
def get_workers(self, keys):
return self.get_item_by_keys(keys)
# eg. keys: ['t2v', 'wan2.1', 'multi_stage']
def get_inputs(self, keys):
item = self.inputs
for k in keys:
if k not in item:
raise Exception(f"{keys} are not in inputs!")
item = item[k]
return item
# eg. keys: ['t2v', 'wan2.1', 'multi_stage']
def get_outputs(self, keys):
item = self.outputs
for k in keys:
if k not in item:
raise Exception(f"{keys} are not in outputs!")
item = item[k]
return item
# eg. keys: ['t2v', 'wan2.1', 'multi_stage']
def get_temps(self, keys):
item = self.temps
for k in keys:
if k not in item:
raise Exception(f"{keys} are not in temps!")
item = item[k]
return item
# eg. keys: ['t2v', 'wan2.1', 'multi_stage']
def get_types(self, keys):
item = self.types
for k in keys:
if k not in item:
raise Exception(f"{keys} are not in types!")
item = item[k]
return item
def check_item_by_keys(self, keys):
item = self.data
for k in keys:
if k not in item:
return False
item = item[k]
return True
def get_model_lists(self):
return self.model_lists
def get_type(self, name):
return self.meta["special_types"].get(name, "OBJECT")
def get_monitor_config(self):
return self.meta["monitor"]
def get_queues(self):
return self.queues
def inner_model_name(self, name):
return self.model_name_outer_to_inner.get(name, name)
def outer_model_name(self, name):
return self.model_name_inner_to_outer.get(name, name)
if __name__ == "__main__":
pipeline = Pipeline(sys.argv[1])
print(pipeline.get_workers(["t2v", "wan2.1", "multi_stage"]))
print(pipeline.get_worker(["i2v", "wan2.1", "multi_stage", "dit"]))
# -*- coding: utf-8 -*-
import asyncio
import io
import json
import os
import struct
import uuid
from dataclasses import dataclass
from enum import IntEnum
from typing import Callable, List, Optional
import websockets
from loguru import logger
from pydub import AudioSegment
# Protocol definitions (from podcasts_protocols)
class MsgType(IntEnum):
"""Message type enumeration"""
Invalid = 0
FullClientRequest = 0b1
AudioOnlyClient = 0b10
FullServerResponse = 0b1001
AudioOnlyServer = 0b1011
FrontEndResultServer = 0b1100
Error = 0b1111
ServerACK = AudioOnlyServer
class MsgTypeFlagBits(IntEnum):
"""Message type flag bits"""
NoSeq = 0
PositiveSeq = 0b1
LastNoSeq = 0b10
NegativeSeq = 0b11
WithEvent = 0b100
class VersionBits(IntEnum):
"""Version bits"""
Version1 = 1
class HeaderSizeBits(IntEnum):
"""Header size bits"""
HeaderSize4 = 1
HeaderSize8 = 2
HeaderSize12 = 3
HeaderSize16 = 4
class SerializationBits(IntEnum):
"""Serialization method bits"""
Raw = 0
JSON = 0b1
Thrift = 0b11
Custom = 0b1111
class CompressionBits(IntEnum):
"""Compression method bits"""
None_ = 0
Gzip = 0b1
Custom = 0b1111
class EventType(IntEnum):
"""Event type enumeration"""
None_ = 0
StartConnection = 1
StartTask = 1
FinishConnection = 2
FinishTask = 2
ConnectionStarted = 50
TaskStarted = 50
ConnectionFailed = 51
TaskFailed = 51
ConnectionFinished = 52
TaskFinished = 52
StartSession = 100
CancelSession = 101
FinishSession = 102
SessionStarted = 150
SessionCanceled = 151
SessionFinished = 152
SessionFailed = 153
UsageResponse = 154
ChargeData = 154
TaskRequest = 200
UpdateConfig = 201
AudioMuted = 250
SayHello = 300
TTSSentenceStart = 350
TTSSentenceEnd = 351
TTSResponse = 352
TTSEnded = 359
PodcastRoundStart = 360
PodcastRoundResponse = 361
PodcastRoundEnd = 362
PodcastEnd = 363
@dataclass
class Message:
"""Message object"""
version: VersionBits = VersionBits.Version1
header_size: HeaderSizeBits = HeaderSizeBits.HeaderSize4
type: MsgType = MsgType.Invalid
flag: MsgTypeFlagBits = MsgTypeFlagBits.NoSeq
serialization: SerializationBits = SerializationBits.JSON
compression: CompressionBits = CompressionBits.None_
event: EventType = EventType.None_
session_id: str = ""
connect_id: str = ""
sequence: int = 0
error_code: int = 0
payload: bytes = b""
@classmethod
def from_bytes(cls, data: bytes) -> "Message":
"""Create message object from bytes"""
if len(data) < 3:
raise ValueError(f"Data too short: expected at least 3 bytes, got {len(data)}")
type_and_flag = data[1]
msg_type = MsgType(type_and_flag >> 4)
flag = MsgTypeFlagBits(type_and_flag & 0b00001111)
msg = cls(type=msg_type, flag=flag)
msg.unmarshal(data)
return msg
def marshal(self) -> bytes:
"""Serialize message to bytes"""
buffer = io.BytesIO()
header = [
(self.version << 4) | self.header_size,
(self.type << 4) | self.flag,
(self.serialization << 4) | self.compression,
]
header_size = 4 * self.header_size
if padding := header_size - len(header):
header.extend([0] * padding)
buffer.write(bytes(header))
writers = self._get_writers()
for writer in writers:
writer(buffer)
return buffer.getvalue()
def unmarshal(self, data: bytes) -> None:
"""Deserialize message from bytes"""
buffer = io.BytesIO(data)
version_and_header_size = buffer.read(1)[0]
self.version = VersionBits(version_and_header_size >> 4)
self.header_size = HeaderSizeBits(version_and_header_size & 0b00001111)
buffer.read(1)
serialization_compression = buffer.read(1)[0]
self.serialization = SerializationBits(serialization_compression >> 4)
self.compression = CompressionBits(serialization_compression & 0b00001111)
header_size = 4 * self.header_size
read_size = 3
if padding_size := header_size - read_size:
buffer.read(padding_size)
readers = self._get_readers()
for reader in readers:
reader(buffer)
remaining = buffer.read()
if remaining:
raise ValueError(f"Unexpected data after message: {remaining}")
def _get_writers(self) -> List[Callable[[io.BytesIO], None]]:
"""Get list of writer functions"""
writers = []
if self.flag == MsgTypeFlagBits.WithEvent:
writers.extend([self._write_event, self._write_session_id])
if self.type in [MsgType.FullClientRequest, MsgType.FullServerResponse, MsgType.FrontEndResultServer, MsgType.AudioOnlyClient, MsgType.AudioOnlyServer]:
if self.flag in [MsgTypeFlagBits.PositiveSeq, MsgTypeFlagBits.NegativeSeq]:
writers.append(self._write_sequence)
elif self.type == MsgType.Error:
writers.append(self._write_error_code)
else:
raise ValueError(f"Unsupported message type: {self.type}")
writers.append(self._write_payload)
return writers
def _get_readers(self) -> List[Callable[[io.BytesIO], None]]:
"""Get list of reader functions"""
readers = []
if self.type in [MsgType.FullClientRequest, MsgType.FullServerResponse, MsgType.FrontEndResultServer, MsgType.AudioOnlyClient, MsgType.AudioOnlyServer]:
if self.flag in [MsgTypeFlagBits.PositiveSeq, MsgTypeFlagBits.NegativeSeq]:
readers.append(self._read_sequence)
elif self.type == MsgType.Error:
readers.append(self._read_error_code)
if self.flag == MsgTypeFlagBits.WithEvent:
readers.extend([self._read_event, self._read_session_id, self._read_connect_id])
readers.append(self._read_payload)
return readers
def _write_event(self, buffer: io.BytesIO) -> None:
buffer.write(struct.pack(">i", self.event))
def _write_session_id(self, buffer: io.BytesIO) -> None:
if self.event in [EventType.StartConnection, EventType.FinishConnection, EventType.ConnectionStarted, EventType.ConnectionFailed]:
return
session_id_bytes = self.session_id.encode("utf-8")
size = len(session_id_bytes)
if size > 0xFFFFFFFF:
raise ValueError(f"Session ID size ({size}) exceeds max(uint32)")
buffer.write(struct.pack(">I", size))
if size > 0:
buffer.write(session_id_bytes)
def _write_sequence(self, buffer: io.BytesIO) -> None:
buffer.write(struct.pack(">i", self.sequence))
def _write_error_code(self, buffer: io.BytesIO) -> None:
buffer.write(struct.pack(">I", self.error_code))
def _write_payload(self, buffer: io.BytesIO) -> None:
size = len(self.payload)
if size > 0xFFFFFFFF:
raise ValueError(f"Payload size ({size}) exceeds max(uint32)")
buffer.write(struct.pack(">I", size))
buffer.write(self.payload)
def _read_event(self, buffer: io.BytesIO) -> None:
event_bytes = buffer.read(4)
if event_bytes:
self.event = EventType(struct.unpack(">i", event_bytes)[0])
def _read_session_id(self, buffer: io.BytesIO) -> None:
if self.event in [EventType.StartConnection, EventType.FinishConnection, EventType.ConnectionStarted, EventType.ConnectionFailed, EventType.ConnectionFinished]:
return
size_bytes = buffer.read(4)
if size_bytes:
size = struct.unpack(">I", size_bytes)[0]
if size > 0:
session_id_bytes = buffer.read(size)
if len(session_id_bytes) == size:
self.session_id = session_id_bytes.decode("utf-8")
def _read_connect_id(self, buffer: io.BytesIO) -> None:
if self.event in [EventType.ConnectionStarted, EventType.ConnectionFailed, EventType.ConnectionFinished]:
size_bytes = buffer.read(4)
if size_bytes:
size = struct.unpack(">I", size_bytes)[0]
if size > 0:
self.connect_id = buffer.read(size).decode("utf-8")
def _read_sequence(self, buffer: io.BytesIO) -> None:
sequence_bytes = buffer.read(4)
if sequence_bytes:
self.sequence = struct.unpack(">i", sequence_bytes)[0]
def _read_error_code(self, buffer: io.BytesIO) -> None:
error_code_bytes = buffer.read(4)
if error_code_bytes:
self.error_code = struct.unpack(">I", error_code_bytes)[0]
def _read_payload(self, buffer: io.BytesIO) -> None:
size_bytes = buffer.read(4)
if size_bytes:
size = struct.unpack(">I", size_bytes)[0]
if size > 0:
self.payload = buffer.read(size)
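# Worked example of the wire format produced by Message.marshal for the StartConnection
# request built in start_connection() below (values follow directly from the code above):
#
#   msg = Message(type=MsgType.FullClientRequest, flag=MsgTypeFlagBits.WithEvent)
#   msg.event = EventType.StartConnection
#   msg.payload = b"{}"
#   msg.marshal() yields, in order:
#   # b'\x11\x14\x10\x00'  header: version=1/header_size=1, type=0b0001/flag=0b100,
#   #                      serialization=JSON/compression=None, one padding byte
#   # b'\x00\x00\x00\x01'  event = StartConnection (no session id for connection events)
#   # b'\x00\x00\x00\x02'  payload size, followed by b'{}'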
async def receive_message(websocket: websockets.WebSocketClientProtocol) -> Message:
"""Receive message from websocket"""
try:
data = await websocket.recv()
if isinstance(data, str):
raise ValueError(f"Unexpected text message: {data}")
elif isinstance(data, bytes):
msg = Message.from_bytes(data)
# logger.debug(f"Received: {msg}")
return msg
else:
raise ValueError(f"Unexpected message type: {type(data)}")
except Exception as e:
logger.error(f"Failed to receive message: {e}")
raise
async def wait_for_event(websocket: websockets.WebSocketClientProtocol, msg_type: MsgType, event_type: EventType) -> Message:
"""Wait for specific event"""
msg = await receive_message(websocket)
if msg.type != msg_type or msg.event != event_type:
raise ValueError(f"Unexpected message: {msg}")
return msg
async def start_connection(websocket: websockets.WebSocketClientProtocol) -> None:
"""Start connection"""
msg = Message(type=MsgType.FullClientRequest, flag=MsgTypeFlagBits.WithEvent)
msg.event = EventType.StartConnection
msg.payload = b"{}"
logger.debug(f"Sending: {msg}")
await websocket.send(msg.marshal())
async def finish_connection(websocket: websockets.WebSocketClientProtocol) -> None:
"""Finish connection"""
msg = Message(type=MsgType.FullClientRequest, flag=MsgTypeFlagBits.WithEvent)
msg.event = EventType.FinishConnection
msg.payload = b"{}"
logger.debug(f"Sending: {msg}")
await websocket.send(msg.marshal())
async def start_session(websocket: websockets.WebSocketClientProtocol, payload: bytes, session_id: str) -> None:
"""Start session"""
msg = Message(type=MsgType.FullClientRequest, flag=MsgTypeFlagBits.WithEvent)
msg.event = EventType.StartSession
msg.session_id = session_id
msg.payload = payload
logger.debug(f"Sending: {msg}")
await websocket.send(msg.marshal())
async def finish_session(websocket: websockets.WebSocketClientProtocol, session_id: str) -> None:
"""Finish session"""
msg = Message(type=MsgType.FullClientRequest, flag=MsgTypeFlagBits.WithEvent)
msg.event = EventType.FinishSession
msg.session_id = session_id
msg.payload = b"{}"
logger.debug(f"Sending: {msg}")
await websocket.send(msg.marshal())
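# The helpers above implement the handshake used by podcast_request below; roughly:
#
#   await start_connection(ws)                     # FullClientRequest + StartConnection
#   await wait_for_event(ws, MsgType.FullServerResponse, EventType.ConnectionStarted)
#   await start_session(ws, request_json, session_id)
#   await wait_for_event(ws, MsgType.FullServerResponse, EventType.SessionStarted)
#   await finish_session(ws, session_id)           # server keeps streaming round events
#   ...                                            # receive audio / round / session events
#   await finish_connection(ws)
#   await wait_for_event(ws, MsgType.FullServerResponse, EventType.ConnectionFinished)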
class PodcastRoundPostProcessor:
def __init__(self, session_id, data_manager):
self.session_id = session_id
self.data_manager = data_manager
self.temp_merged_audio_name = "merged_audio.mp3"
self.output_merged_audio_name = f"{session_id}-merged_audio.mp3"
self.subtitle_timestamps = []  # record subtitle timestamps
self.current_audio_duration = 0.0  # current total audio duration in seconds
self.merged_audio = None  # holds the merged AudioSegment object
self.merged_audio_bytes = None
async def init(self):
if self.data_manager:
await self.data_manager.create_podcast_temp_session_dir(self.session_id)
async def postprocess_round(self, current_round, voice, audio, podcast_texts):
text = ""
if podcast_texts:
text = podcast_texts[-1].get("text", "")
logger.debug(f"Processing round: {current_round}, voice: {voice}, text: {text}, audio: {len(audio)} bytes")
new_segment = AudioSegment.from_mp3(io.BytesIO(bytes(audio)))
round_duration = len(new_segment) / 1000.0
if self.merged_audio is None:
self.merged_audio = new_segment
else:
self.merged_audio = self.merged_audio + new_segment
# Save the merged audio to a temporary file (for real-time access by the frontend)
merged_io = io.BytesIO()
self.merged_audio.export(merged_io, format="mp3")
self.merged_audio_bytes = merged_io.getvalue()
if self.data_manager:
await self.data_manager.save_podcast_temp_session_file(self.session_id, self.temp_merged_audio_name, self.merged_audio_bytes)
merged_file_size = len(self.merged_audio_bytes)
# Record the subtitle timestamp for this round
self.subtitle_timestamps.append(
{
"start": self.current_audio_duration,
"end": self.current_audio_duration + round_duration,
"text": text,
"speaker": voice,
}
)
self.current_audio_duration += round_duration
logger.debug(f"Merged audio updated: {merged_file_size} bytes, duration: {self.current_audio_duration:.2f}s")
return {
"url": f"/api/v1/podcast/audio?session_id={self.session_id}&filename={self.temp_merged_audio_name}",
"size": merged_file_size,
"duration": self.current_audio_duration,
"round": current_round,
"text": text,
"speaker": voice,
}
async def postprocess_final(self):
if self.data_manager:
await self.data_manager.save_podcast_output_file(self.output_merged_audio_name, self.merged_audio_bytes)
return {
"subtitles": self.subtitle_timestamps,
"audio_name": self.output_merged_audio_name,
}
async def cleanup(self):
if self.data_manager:
await self.data_manager.clear_podcast_temp_session_dir(self.session_id)
self.data_manager = None
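# Shape of the result returned by postprocess_final (values are illustrative):
#
#   {
#       "subtitles": [
#           {"start": 0.0, "end": 12.4, "text": "...", "speaker": "host"},
#           {"start": 12.4, "end": 25.1, "text": "...", "speaker": "guest"},
#       ],
#       "audio_name": "<session_id>-merged_audio.mp3",
#   }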
class VolcEnginePodcastClient:
"""
VolcEngine Podcast client
Supported podcast types:
- action=0: convert text to a podcast
- action=3: convert NLP texts to a podcast
- action=4: generate a podcast from a prompt
"""
def __init__(self):
self.endpoint = "wss://openspeech.bytedance.com/api/v3/sami/podcasttts"
self.appid = os.getenv("VOLCENGINE_PODCAST_APPID")
self.access_token = os.getenv("VOLCENGINE_PODCAST_ACCESS_TOKEN")
self.app_key = "aGjiRDfUWi"
self.proxy = os.getenv("HTTPS_PROXY", None)
if self.proxy:
logger.info(f"volcengine podcast use proxy: {self.proxy}")
async def podcast_request(
self,
session_id: str,
data_manager=None,
text: str = "",
input_url: str = "",
prompt_text: str = "",
nlp_texts: str = "",
action: int = 0,
resource_id: str = "volc.service_type.10050",
encoding: str = "mp3",
input_id: str = "test_podcast",
speaker_info: str = '{"random_order":false}',
use_head_music: bool = False,
use_tail_music: bool = False,
only_nlp_text: bool = False,
return_audio_url: bool = False,
skip_round_audio_save: bool = False,
on_round_complete: Optional[Callable] = None,
):
"""
Execute a podcast request
Args:
session_id: Session identifier used for temporary files and output naming
data_manager: Optional data manager used to persist per-round and final audio
text: Input text (used when action=0)
input_url: Web URL or file URL (used when action=0)
prompt_text: Prompt text (required when action=4)
nlp_texts: NLP texts (required when action=3)
action: Podcast type (0/3/4)
resource_id: Audio resource ID
encoding: Audio format (mp3/wav)
input_id: Unique input identifier
speaker_info: Podcast speaker info
use_head_music: Whether to prepend opening music
use_tail_music: Whether to append closing music
only_nlp_text: Whether to return only the podcast text
return_audio_url: Whether to return an audio URL
skip_round_audio_save: Whether to skip saving per-round audio
on_round_complete: Callback invoked when a round completes
"""
if not self.appid or not self.access_token:
logger.error("APP ID or Access Key is required")
return None, None
headers = {
"X-Api-App-Id": self.appid,
"X-Api-App-Key": self.app_key,
"X-Api-Access-Key": self.access_token,
"X-Api-Resource-Id": resource_id,
"X-Api-Connect-Id": str(uuid.uuid4()),
}
is_podcast_round_end = True
audio_received = False
last_round_id = -1
task_id = ""
websocket = None
retry_num = 5
audio = bytearray()
voice = ""
current_round = 0
podcast_texts = []
post_processor = PodcastRoundPostProcessor(session_id, data_manager)
await post_processor.init()
try:
while retry_num > 0:
# Establish the WebSocket connection
websocket = await websockets.connect(self.endpoint, additional_headers=headers)
logger.debug(f"WebSocket connected: {websocket.response.headers}")
# Build request parameters
if input_url:
req_params = {
"input_id": input_id,
"nlp_texts": json.loads(nlp_texts) if nlp_texts else None,
"prompt_text": prompt_text,
"action": action,
"use_head_music": use_head_music,
"use_tail_music": use_tail_music,
"input_info": {
"input_url": input_url,
"return_audio_url": return_audio_url,
"only_nlp_text": only_nlp_text,
},
"speaker_info": json.loads(speaker_info) if speaker_info else None,
"audio_config": {"format": encoding, "sample_rate": 24000, "speech_rate": 0},
}
else:
req_params = {
"input_id": input_id,
"input_text": text,
"nlp_texts": json.loads(nlp_texts) if nlp_texts else None,
"prompt_text": prompt_text,
"action": action,
"use_head_music": use_head_music,
"use_tail_music": use_tail_music,
"input_info": {
"input_url": input_url,
"return_audio_url": return_audio_url,
"only_nlp_text": only_nlp_text,
},
"speaker_info": json.loads(speaker_info) if speaker_info else None,
"audio_config": {"format": encoding, "sample_rate": 24000, "speech_rate": 0},
}
logger.debug(f"Request params: {json.dumps(req_params, indent=2, ensure_ascii=False)}")
if not is_podcast_round_end:
req_params["retry_info"] = {"retry_task_id": task_id, "last_finished_round_id": last_round_id}
# Start connection
await start_connection(websocket)
await wait_for_event(websocket, MsgType.FullServerResponse, EventType.ConnectionStarted)
session_id = str(uuid.uuid4())
if not task_id:
task_id = session_id
# Start session
await start_session(websocket, json.dumps(req_params).encode(), session_id)
await wait_for_event(websocket, MsgType.FullServerResponse, EventType.SessionStarted)
# Finish session
await finish_session(websocket, session_id)
while True:
msg = await receive_message(websocket)
# Audio data chunk
if msg.type == MsgType.AudioOnlyServer and msg.event == EventType.PodcastRoundResponse:
if not audio_received and audio:
audio_received = True
audio.extend(msg.payload)
# Error message
elif msg.type == MsgType.Error:
raise RuntimeError(f"Server error: {msg.payload.decode()}")
elif msg.type == MsgType.FullServerResponse:
# Podcast round started
if msg.event == EventType.PodcastRoundStart:
data = json.loads(msg.payload.decode())
if data.get("text"):
filtered_payload = {"text": data.get("text"), "speaker": data.get("speaker")}
podcast_texts.append(filtered_payload)
voice = data.get("speaker")
current_round = data.get("round_id")
if current_round == -1:
voice = "head_music"
if current_round == 9999:
voice = "tail_music"
is_podcast_round_end = False
logger.debug(f"New round started: {data}")
# Podcast round ended
if msg.event == EventType.PodcastRoundEnd:
data = json.loads(msg.payload.decode())
logger.debug(f"Podcast round end: {data}")
if data.get("is_error"):
break
is_podcast_round_end = True
last_round_id = current_round
if audio:
round_info = await post_processor.postprocess_round(current_round, voice, audio, podcast_texts)
if on_round_complete:
await on_round_complete(round_info)
audio.clear()
# Podcast finished
if msg.event == EventType.PodcastEnd:
data = json.loads(msg.payload.decode())
logger.info(f"Podcast end: {data}")
# Session finished
if msg.event == EventType.SessionFinished:
break
if not audio_received and not only_nlp_text:
raise RuntimeError("No audio data received")
# Close the connection cleanly
await finish_connection(websocket)
await wait_for_event(websocket, MsgType.FullServerResponse, EventType.ConnectionFinished)
# Podcast finished, save the final audio file
if is_podcast_round_end:
podcast_info = await post_processor.postprocess_final()
return podcast_info
else:
logger.error(f"Current podcast not finished, resuming from round {last_round_id}")
retry_num -= 1
await asyncio.sleep(1)
if websocket:
await websocket.close()
finally:
await post_processor.cleanup()
if websocket:
await websocket.close()
return None
async def test(args):
"""
Podcast test helper
Args:
args: dict containing all podcast request parameters
"""
client = VolcEnginePodcastClient()
# Set default parameters
params = {
"text": "",
"input_url": "https://zhuanlan.zhihu.com/p/607822576",
"prompt_text": "",
"nlp_texts": "",
"action": 0,
"resource_id": "volc.service_type.10050",
"encoding": "mp3",
"input_id": "test_podcast",
"speaker_info": '{"random_order":false}',
"use_head_music": False,
"use_tail_music": False,
"only_nlp_text": False,
"return_audio_url": True,
"skip_round_audio_save": False,
"output_dir": "output",
}
# Override defaults with user-provided arguments
if args:
params.update(args)
await client.podcast_request(**params)
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--text", default="", help="Input text Use when action in [0]")
parser.add_argument("--input_url", default="", help="Web url or file url Use when action in [0]")
parser.add_argument("--prompt_text", default="", help="Input Prompt Text must not empty when action in [4]")
parser.add_argument("--nlp_texts", default="", help="Input NLP Texts must not empty when action in [3]")
parser.add_argument("--resource_id", default="volc.service_type.10050", help="Audio Resource ID")
parser.add_argument("--encoding", default="mp3", choices=["mp3", "wav"], help="Audio format")
parser.add_argument("--input_id", default="test_podcast", help="Unique input identifier")
parser.add_argument("--speaker_info", default='{"random_order":false}', help="Podcast Speaker Info")
parser.add_argument("--use_head_music", default=False, action="store_true", help="Enable head music")
parser.add_argument("--use_tail_music", default=False, action="store_true", help="Enable tail music")
parser.add_argument("--only_nlp_text", default=False, action="store_true", help="Enable only podcast text when action in [0, 4]")
parser.add_argument("--return_audio_url", default=False, action="store_true", help="Enable return audio url that can download")
parser.add_argument("--action", default=0, type=int, choices=[0, 3, 4], help="different podcast type")
parser.add_argument("--skip_round_audio_save", default=False, action="store_true", help="skip round audio save")
parser.add_argument("--output_dir", default="output", help="Output directory")
args = parser.parse_args()
kwargs = {k: v for k, v in vars(args).items() if v is not None and not (isinstance(v, bool) and not v)}
asyncio.run(test(kwargs))
import asyncio
import base64
import io
import os
import subprocess
import tempfile
import time
import traceback
from datetime import datetime
import httpx
import torchaudio
from PIL import Image
from loguru import logger
FMT = "%Y-%m-%d %H:%M:%S"
def current_time():
return datetime.now().timestamp()
def time2str(t):
d = datetime.fromtimestamp(t)
return d.strftime(FMT)
def str2time(s):
d = datetime.strptime(s, FMT)
return d.timestamp()
def try_catch(func):
def wrapper(*args, **kwargs):
try:
return func(*args, **kwargs)
except Exception:
logger.error(f"Error in {func.__name__}:")
traceback.print_exc()
return None
return wrapper
def class_try_catch(func):
def wrapper(self, *args, **kwargs):
try:
return func(self, *args, **kwargs)
except Exception:
logger.error(f"Error in {self.__class__.__name__}.{func.__name__}:")
traceback.print_exc()
return None
return wrapper
def class_try_catch_async(func):
async def wrapper(self, *args, **kwargs):
try:
return await func(self, *args, **kwargs)
except Exception:
logger.error(f"Error in {self.__class__.__name__}.{func.__name__}:")
traceback.print_exc()
return None
return wrapper
def data_name(x, task_id):
if x == "input_image":
x = x + ".png"
elif x == "input_video":
x = x + ".mp4"
elif x == "output_video":
x = x + ".mp4"
return f"{task_id}-{x}"
async def fetch_resource(url, timeout):
logger.info(f"Begin to download resource from url: {url}")
t0 = time.time()
async with httpx.AsyncClient() as client:
async with client.stream("GET", url, timeout=timeout) as response:
response.raise_for_status()
ans_bytes = []
async for chunk in response.aiter_bytes(chunk_size=1024 * 1024):
ans_bytes.append(chunk)
if len(ans_bytes) > 128:
raise Exception(f"url {url} recv data is too big")
content = b"".join(ans_bytes)
logger.info(f"Download url {url} resource cost time: {time.time() - t0} seconds")
return content
# Validate the image, resize if too large, and apply EXIF orientation metadata
def format_image_data(data, max_size=1280):
image = Image.open(io.BytesIO(data)).convert("RGB")
exif = image.getexif()
changed = False
w, h = image.size
assert w > 0 and h > 0, "image is empty"
logger.info(f"load image: {w}x{h}, exif: {exif}")
if w > max_size or h > max_size:
ratio = max_size / max(w, h)
w = int(w * ratio)
h = int(h * ratio)
image = image.resize((w, h))
logger.info(f"resize image to: {image.size}")
changed = True
orientation_key = 274
if orientation_key and orientation_key in exif:
orientation = exif[orientation_key]
if orientation == 2:
image = image.transpose(Image.FLIP_LEFT_RIGHT)
elif orientation == 3:
image = image.rotate(180, expand=True)
elif orientation == 4:
image = image.transpose(Image.FLIP_TOP_BOTTOM)
elif orientation == 5:
image = image.transpose(Image.FLIP_LEFT_RIGHT).rotate(90, expand=True)
elif orientation == 6:
image = image.rotate(270, expand=True)
elif orientation == 7:
image = image.transpose(Image.FLIP_LEFT_RIGHT).rotate(270, expand=True)
elif orientation == 8:
image = image.rotate(90, expand=True)
# reset orientation to 1
if orientation != 1:
logger.info(f"reset orientation from {orientation} to 1")
exif[orientation_key] = 1
changed = True
if not changed:
return data
output = io.BytesIO()
image.save(output, format=image.format or "JPEG", exif=exif.tobytes())
return output.getvalue()
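# Minimal usage sketch (assumes a local "photo.jpg"): the bytes pass through unchanged when
# the image is already small and upright, otherwise a resized/re-oriented re-encode is returned.
#
#   with open("photo.jpg", "rb") as f:
#       raw = f.read()
#   fixed = format_image_data(raw, max_size=1280)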
def media_to_wav(data):
with tempfile.NamedTemporaryFile() as fin:
fin.write(data)
fin.flush()
cmd = ["ffmpeg", "-i", fin.name, "-f", "wav", "-acodec", "pcm_s16le", "-ar", "44100", "-ac", "2", "pipe:1"]
p = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
assert p.returncode == 0, f"media to wav failed: {p.stderr.decode()}"
return p.stdout
def format_audio_data(data):
if len(data) < 4:
raise ValueError("Audio file too short")
data = media_to_wav(data)
waveform, sample_rate = torchaudio.load(io.BytesIO(data), num_frames=10)
logger.info(f"load audio: {waveform.size()}, {sample_rate}")
assert waveform.numel() > 0, "audio is empty"
assert sample_rate > 0, "audio sample rate is not valid"
return data
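# Minimal usage sketch (assumes ffmpeg on PATH and a local "voice.m4a"): any container that
# ffmpeg can decode is normalized to 44.1 kHz 16-bit stereo WAV bytes, then sanity-checked.
#
#   with open("voice.m4a", "rb") as f:
#       wav_bytes = format_audio_data(f.read())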
async def preload_data(inp, inp_type, typ, val):
try:
if typ == "url":
timeout = int(os.getenv("REQUEST_TIMEOUT", "5"))
data = await fetch_resource(val, timeout=timeout)
elif typ == "base64":
# Decode base64 in background thread to avoid blocking event loop
data = await asyncio.to_thread(base64.b64decode, val)
# For multi-person audio directory, val should be a dict with file structure
elif typ == "directory":
data = {}
for fname, b64_data in val.items():
data[fname] = await asyncio.to_thread(base64.b64decode, b64_data)
return {"type": "directory", "data": data}
elif typ == "stream":
# stream inputs have no bytes data that needs to be saved by the data_manager
data = None
else:
raise ValueError(f"cannot read {inp}[{inp_type}] which type is {typ}!")
# check if valid image bytes
if inp_type == "IMAGE":
data = await asyncio.to_thread(format_image_data, data)
elif inp_type == "AUDIO":
if typ != "stream" and typ != "directory":
data = await asyncio.to_thread(format_audio_data, data)
elif inp_type == "VIDEO":
# Video data doesn't need special formatting, just validate it's not empty
if len(data) == 0:
raise ValueError("Video file is empty")
logger.info(f"load video: {len(data)} bytes")
else:
raise Exception(f"cannot parse inp_type={inp_type} data")
return data
except Exception as e:
raise ValueError(f"Failed to read {inp}, type={typ}, val={val[:100]}: {e}!")
async def load_inputs(params, raw_inputs, types):
inputs_data = {}
for inp in raw_inputs:
item = params.pop(inp)
bytes_data = await preload_data(inp, types[inp], item["type"], item["data"])
# Handle multi-person audio directory
if bytes_data is not None and isinstance(bytes_data, dict) and bytes_data.get("type") == "directory":
fs = []
for fname, fdata in bytes_data["data"].items():
inputs_data[f"{inp}/{fname}"] = fdata
fs.append(f"{inp}/{fname}")
params["extra_inputs"] = {inp: fs}
elif bytes_data is not None:
inputs_data[inp] = bytes_data
else:
params[inp] = item
return inputs_data
def check_params(params, raw_inputs, raw_outputs, types):
stream_audio = os.getenv("STREAM_AUDIO", "0") == "1"
stream_video = os.getenv("STREAM_VIDEO", "0") == "1"
for x in raw_inputs + raw_outputs:
if x in params and "type" in params[x]:
if params[x]["type"] == "stream":
if types[x] == "AUDIO":
assert stream_audio, "stream audio is not supported, please set env STREAM_AUDIO=1"
elif types[x] == "VIDEO":
assert stream_video, "stream video is not supported, please set env STREAM_VIDEO=1"
elif params[x]["type"] == "directory":
# Multi-person audio directory is only supported for AUDIO type
assert types[x] == "AUDIO", f"directory type is only supported for AUDIO input, got {types[x]}"
if __name__ == "__main__":
# https://github.com/recurser/exif-orientation-examples
exif_dir = "/data/nvme0/liuliang1/exif-orientation-examples"
out_dir = "/data/nvme0/liuliang1/exif-orientation-examples/outs"
os.makedirs(out_dir, exist_ok=True)
for base_name in ["Landscape", "Portrait"]:
for i in range(9):
fin_name = os.path.join(exif_dir, f"{base_name}_{i}.jpg")
fout_name = os.path.join(out_dir, f"{base_name}_{i}_formatted.jpg")
logger.info(f"format image: {fin_name} -> {fout_name}")
with open(fin_name, "rb") as f:
data = f.read()
data = format_image_data(data)
with open(fout_name, "wb") as f:
f.write(data)
import math
import os
import torch
import torch.distributed as dist
from loguru import logger
from lightx2v.models.runners.vsr.vsr_wrapper import compute_scaled_and_target_dims
from lightx2v_platform.base.global_var import AI_DEVICE
class NextControl:
def __init__(self, action: str, data: any = None):
# action: switch, data: prev_video tensor
# action: wait, data: None
# action: fetch, data: None
self.action = action
self.data = data
class VAController:
def __init__(self, model_runner):
self.reader = None
self.recorder = None
self.rank = 0
self.world_size = 1
if dist.is_initialized():
self.rank = dist.get_rank()
self.world_size = dist.get_world_size()
self.target_reader_rank = int(os.getenv("READER_RANK", "0")) % self.world_size
self.target_recorder_rank = int(os.getenv("RECORDER_RANK", "0")) % self.world_size
self.init_base(model_runner.config, model_runner.input_info, model_runner.vfi_model is not None, model_runner.vsr_model is not None)
self.init_recorder()
self.init_reader(model_runner)
def init_base(self, config, input_info, has_vfi_model, has_vsr_model):
self.audio_path = input_info.audio_path
self.output_video_path = input_info.save_result_path
if isinstance(self.output_video_path, dict):
self.output_video_path = self.output_video_path["data"]
self.audio_sr = config.get("audio_sr", 16000)
self.target_fps = config.get("target_fps", 16)
self.max_num_frames = config.get("target_video_length", 81)
self.prev_frame_length = config.get("prev_frame_length", 5)
self.record_fps = config.get("target_fps", 16)
if "video_frame_interpolation" in config and has_vfi_model:
self.record_fps = config["video_frame_interpolation"]["target_fps"]
self.record_fps = config.get("record_fps", self.record_fps)
self.tgt_h = input_info.target_shape[0]
self.tgt_w = input_info.target_shape[1]
self.record_h, self.record_w = self.tgt_h, self.tgt_w
if "video_super_resolution" in config and has_vsr_model:
_, _, self.record_w, self.record_h = compute_scaled_and_target_dims(
self.record_w,
self.record_h,
scale=config["video_super_resolution"]["scale"],
multiple=128,
)
# how many frames to publish to the stream as one batch
self.slice_frame = config.get("slice_frame", self.prev_frame_length)
# estimate the maximum inference time in seconds, used for immediate switch with local omni
slice_interval = self.slice_frame / self.record_fps
est_max_infer_secs = config.get("est_max_infer_secs", 0.6)
self.est_infer_end_idx = math.ceil(est_max_infer_secs / slice_interval)
self.min_stay_queue_num = self.est_infer_end_idx * 2 + 1
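# Worked example with the defaults above (slice_frame=prev_frame_length=5, record_fps=16,
# est_max_infer_secs=0.6): slice_interval = 5 / 16 = 0.3125 s,
# est_infer_end_idx = ceil(0.6 / 0.3125) = 2, min_stay_queue_num = 2 * 2 + 1 = 5,
# i.e. inference is skipped while at least 5 sliced batches are still queued for streaming.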
def init_recorder(self):
if not self.output_video_path or self.rank != self.target_recorder_rank:
return
logger.info(f"Rank {self.rank} init recorder with: {self.output_video_path}")
whip_shared_path = os.getenv("WHIP_SHARED_LIB", None)
if whip_shared_path and self.output_video_path.startswith("http"):
from lightx2v.deploy.common.va_recorder_x264 import X264VARecorder
self.recorder = X264VARecorder(
whip_shared_path=whip_shared_path,
livestream_url=self.output_video_path,
fps=self.record_fps,
sample_rate=self.audio_sr,
slice_frame=self.slice_frame,
prev_frame=self.prev_frame_length,
)
else:
from lightx2v.deploy.common.va_recorder import VARecorder
self.recorder = VARecorder(
livestream_url=self.output_video_path,
fps=self.record_fps,
sample_rate=self.audio_sr,
slice_frame=self.slice_frame,
prev_frame=self.prev_frame_length,
)
def init_reader(self, model_runner=None):
if not isinstance(self.audio_path, dict):
return
assert self.audio_path["type"] == "stream", f"unexcept audio_path: {self.audio_path}"
segment_duration = self.max_num_frames / self.target_fps
prev_duration = self.prev_frame_length / self.target_fps
omni_work_dir = os.getenv("OMNI_WORK_DIR", None)
if omni_work_dir:
from lightx2v.deploy.common.va_reader_omni import OmniVAReader
self.reader = OmniVAReader(
rank=self.rank,
world_size=self.world_size,
stream_url=self.audio_path["data"],
sample_rate=self.audio_sr,
segment_duration=segment_duration,
prev_duration=prev_duration,
target_rank=self.target_reader_rank,
model_runner=model_runner,
huoshan_tts_voice_type=self.audio_path.get("huoshan_tts_voice_type", None),
)
else:
from lightx2v.deploy.common.va_reader import VAReader
self.reader = VAReader(
rank=self.rank,
world_size=self.world_size,
stream_url=self.audio_path["data"],
sample_rate=self.audio_sr,
segment_duration=segment_duration,
prev_duration=prev_duration,
target_rank=self.target_reader_rank,
)
def start(self):
self.reader.start()
if self.rank == self.target_recorder_rank:
assert self.recorder is not None, f"recorder is required for stream audio input for rank {self.rank}"
self.recorder.start(self.record_w, self.record_h)
if self.world_size > 1:
dist.barrier()
def next_control(self):
from lightx2v.deploy.common.va_reader_omni import OmniVAReader
if isinstance(self.reader, OmniVAReader):
return self.omni_reader_next_control()
return NextControl(action="fetch")
def before_control(self):
from lightx2v.deploy.common.va_reader_omni import OmniVAReader
if isinstance(self.reader, OmniVAReader):
self.len_tensor = torch.tensor([0], dtype=torch.int32, device=AI_DEVICE)
self.flag_tensor = torch.tensor([0], dtype=torch.int32, device=AI_DEVICE)
self.prev_tensor = torch.zeros((1, 3, self.prev_frame_length, self.tgt_h, self.tgt_w), dtype=torch.float, device=AI_DEVICE)
def omni_reader_next_control(self):
immediate_switch = self.reader.get_immediate_switch()
if immediate_switch == 1:
# truncate the stream buffer to keep the max infer time length
# and broadcast the prev video tensor to all ranks
if self.rank == self.target_recorder_rank:
logger.warning(f"runner recv immediate switch, truncate stream buffer")
video_tensor = self.recorder.truncate_stream_buffer(self.est_infer_end_idx)
if video_tensor is not None:
self.flag_tensor.fill_(1)
self.prev_tensor.copy_(video_tensor)
else:
self.flag_tensor.fill_(0)
dist.broadcast(self.flag_tensor, src=self.target_recorder_rank)
if self.flag_tensor.item() == 1:
dist.broadcast(self.prev_tensor, src=self.target_recorder_rank)
return NextControl(action="switch", data=self.prev_tensor)
else:
# get the length of stream buffer, broadcast to all ranks
if self.rank == self.target_recorder_rank:
stream_buffer_length = self.recorder.get_buffer_stream_size()
self.len_tensor.copy_(stream_buffer_length)
dist.broadcast(self.len_tensor, src=self.target_recorder_rank)
buffer_length = self.len_tensor.item()
# stream buffer is enough, skip infer
if buffer_length >= self.min_stay_queue_num:
return NextControl(action="wait")
return NextControl(action="fetch")
def pub_livestream(self, images: torch.Tensor, audios: torch.Tensor, gen_video: torch.Tensor):
if self.recorder.realtime:
self.recorder.buffer_stream(images, audios, gen_video)
else:
self.recorder.pub_livestream(images, audios)
def clear(self):
self.len_tensor = None
self.flag_tensor = None
self.prev_tensor = None
if self.reader is not None:
self.reader.stop()
self.reader = None
if self.recorder is not None:
self.recorder.stop()
self.recorder = None
def __del__(self):
self.clear()
import os
import queue
import signal
import subprocess
import threading
import time
import traceback
import numpy as np
import torch
import torch.distributed as dist
from loguru import logger
class VAReader:
def __init__(
self,
rank: int,
world_size: int,
stream_url: str,
segment_duration: float = 5.0,
sample_rate: int = 16000,
audio_channels: int = 1,
buffer_size: int = 1,
prev_duration: float = 0.3125,
target_rank: int = 0,
):
self.rank = rank
self.world_size = world_size
self.stream_url = stream_url
self.segment_duration = segment_duration
self.sample_rate = sample_rate
self.audio_channels = audio_channels
self.prev_duration = prev_duration
# int16 = 2 bytes
self.chunk_size = int(self.segment_duration * self.sample_rate) * 2
self.prev_size = int(self.prev_duration * self.sample_rate) * 2
self.prev_chunk = None
self.buffer_size = buffer_size
self.audio_queue = queue.Queue(maxsize=self.buffer_size)
self.audio_thread = None
self.ffmpeg_process = None
self.bytes_buffer = bytearray()
self.target_rank = target_rank % self.world_size
self.flag_tensor = torch.tensor([0], dtype=torch.int32).to(device="cuda")
self.audio_tensor = torch.zeros(self.chunk_size, dtype=torch.uint8, device="cuda")
logger.info(f"VAReader initialized for stream: {stream_url} target_rank: {self.target_rank}")
logger.info(f"Audio duration per chunk: {segment_duration}s, sample rate: {sample_rate}Hz")
def start(self):
if self.rank == self.target_rank:
if self.stream_url.startswith("rtmp://"):
self.start_ffmpeg_process_rtmp()
elif self.stream_url.startswith("http"):
self.start_ffmpeg_process_whep()
else:
raise Exception(f"Unsupported stream URL: {self.stream_url}")
self.audio_thread = threading.Thread(target=self.audio_worker, daemon=True)
self.audio_thread.start()
logger.info(f"VAReader {self.rank}/{self.world_size} started successfully")
else:
logger.info(f"VAReader {self.rank}/{self.world_size} wait only")
if self.world_size > 1:
logger.info(f"VAReader {self.rank}/{self.world_size} wait barrier")
dist.barrier()
logger.info(f"VAReader {self.rank}/{self.world_size} end barrier")
def start_ffmpeg_process_rtmp(self):
"""Start ffmpeg process read audio from stream"""
ffmpeg_cmd = [
"ffmpeg",
"-i",
self.stream_url,
"-vn",
# "-acodec",
# "pcm_s16le",
"-ar",
str(self.sample_rate),
"-ac",
str(self.audio_channels),
"-f",
"s16le",
"-",
]
try:
self.ffmpeg_process = subprocess.Popen(ffmpeg_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, bufsize=0)
logger.info(f"FFmpeg audio pull process started with PID: {self.ffmpeg_process.pid}")
logger.info(f"FFmpeg command: {' '.join(ffmpeg_cmd)}")
except Exception as e:
logger.error(f"Failed to start FFmpeg process: {e}")
raise
def start_ffmpeg_process_whep(self):
"""Start gstream process read audio from stream"""
ffmpeg_cmd = [
"gst-launch-1.0",
"-q",
"whepsrc",
f"whep-endpoint={self.stream_url}",
"video-caps=none",
"!rtpopusdepay",
"!opusdec",
"plc=false",
"!audioconvert",
"!audioresample",
f"!audio/x-raw,format=S16LE,channels={self.audio_channels},rate={self.sample_rate}",
"!fdsink",
"fd=1",
]
try:
self.ffmpeg_process = subprocess.Popen(
ffmpeg_cmd,
stdout=subprocess.PIPE,
# stderr=subprocess.PIPE,
bufsize=0,
)
logger.info(f"FFmpeg audio pull process started with PID: {self.ffmpeg_process.pid}")
logger.info(f"FFmpeg command: {' '.join(ffmpeg_cmd)}")
except Exception as e:
logger.error(f"Failed to start FFmpeg process: {e}")
raise
def audio_worker(self):
logger.info("Audio pull worker thread started")
try:
while True:
if not self.ffmpeg_process or self.ffmpeg_process.poll() is not None:
logger.warning("FFmpeg process exited, audio worker thread stopped")
break
self.fetch_audio_data()
time.sleep(0.01)
except: # noqa
logger.error(f"Audio pull worker error: {traceback.format_exc()}")
finally:
logger.warning("Audio pull worker thread stopped")
def fetch_audio_data(self):
"""Fetch audio data from ffmpeg process"""
try:
audio_bytes = self.ffmpeg_process.stdout.read(self.chunk_size)
if not audio_bytes:
return
self.bytes_buffer.extend(audio_bytes)
# logger.info(f"Fetch audio data: {len(audio_bytes)} bytes, bytes_buffer: {len(self.bytes_buffer)} bytes")
if len(self.bytes_buffer) >= self.chunk_size:
audio_data = self.bytes_buffer[: self.chunk_size]
self.bytes_buffer = self.bytes_buffer[self.chunk_size :]
# first chunk, read original 81 frames
# for other chunks, read 81 - 5 = 76 frames, concat with previous 5 frames
if self.prev_chunk is None:
logger.info(f"change chunk_size: from {self.chunk_size} to {self.chunk_size - self.prev_size}")
self.chunk_size -= self.prev_size
else:
audio_data = self.prev_chunk + audio_data
self.prev_chunk = audio_data[-self.prev_size :]
try:
self.audio_queue.put_nowait(audio_data)
except queue.Full:
logger.warning(f"Audio queue full:{self.audio_queue.qsize()}, discarded oldest chunk")
self.audio_queue.get_nowait()
self.audio_queue.put_nowait(audio_data)
logger.info(f"Put audio data: {len(audio_data)} bytes, audio_queue: {self.audio_queue.qsize()}, chunk_size:{self.chunk_size}")
except: # noqa
logger.error(f"Fetch audio data error: {traceback.format_exc()}")
def broadcast_audio_data(self, audio_data):
if self.rank == self.target_rank:
if audio_data is None:
self.flag_tensor.fill_(0)
else:
self.flag_tensor.fill_(1)
self.audio_tensor.copy_(torch.frombuffer(bytearray(audio_data), dtype=torch.uint8))
logger.info(f"rank {self.rank} send audio_tensor: {self.audio_tensor.shape}")
dist.broadcast(self.flag_tensor, src=self.target_rank)
if self.flag_tensor.item() == 0:
return None
dist.broadcast(self.audio_tensor, src=self.target_rank)
if self.rank != self.target_rank:
logger.info(f"rank {self.rank} recv audio_tensor: {self.audio_tensor.shape}")
audio_data = self.audio_tensor.cpu().numpy().tobytes()
return audio_data
def bytes_to_ndarray(self, audio_data):
if audio_data is None:
return None
audio_data = np.frombuffer(audio_data, dtype=np.int16)
audio_data = audio_data.astype(np.float32) / 32768.0
logger.info(f"Got segment audio rank={self.rank}: {audio_data.shape} {audio_data.dtype} {audio_data.min()} {audio_data.max()}")
return audio_data
def get_audio_segment(self, timeout: float = 1.0):
audio_data = None
if self.rank == self.target_rank:
try:
audio_data = self.audio_queue.get(timeout=timeout)
except: # noqa
logger.warning(f"Failed to get audio segment: {traceback.format_exc()}")
if self.world_size > 1:
audio_data = self.broadcast_audio_data(audio_data)
audio_data = self.bytes_to_ndarray(audio_data)
return audio_data
def stop(self):
# Stop ffmpeg process
if self.ffmpeg_process:
self.ffmpeg_process.send_signal(signal.SIGINT)
try:
self.ffmpeg_process.wait(timeout=5)
except subprocess.TimeoutExpired:
self.ffmpeg_process.kill()
logger.warning("FFmpeg reader process stopped")
# Wait for threads to finish
if self.audio_thread and self.audio_thread.is_alive():
self.audio_thread.join(timeout=5)
if self.audio_thread.is_alive():
logger.error("Audio pull thread did not stop gracefully")
while self.audio_queue and self.audio_queue.qsize() > 0:
self.audio_queue.get_nowait()
self.audio_queue = None
logger.warning("Audio pull queue cleaned")
def __del__(self):
self.stop()
if __name__ == "__main__":
WORLD_SIZE = int(os.environ.get("WORLD_SIZE", 1))
RANK = int(os.environ.get("RANK", 0))
if WORLD_SIZE > 1:
dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank())
logger.info(f"Distributed initialized: rank={RANK}, world_size={WORLD_SIZE}")
reader = VAReader(
RANK,
WORLD_SIZE,
# "rtmp://localhost/live/test_audio",
"https://reverse.st-oc-01.chielo.org/10.5.64.49:8000/rtc/v1/whep/?app=live&stream=ll_test_audio&eip=10.120.114.76:8000",
segment_duration=1.0,
sample_rate=16000,
audio_channels=1,
prev_duration=1 / 16,
)
reader.start()
fail_count = 0
max_fail_count = 2
try:
while True:
audio_data = reader.get_audio_segment(timeout=2)
if audio_data is not None:
# logger.info(f"Got audio chunk, shape: {audio_data.shape}, range: [{audio_data.min()}, {audio_data.max()}]")
fail_count = 0
else:
fail_count += 1
if fail_count > max_fail_count:
logger.warning("Failed to get audio chunk, stop reader")
reader.stop()
break
time.sleep(0.95)
finally:
reader.stop()
import datetime
import json
import os
import random
import subprocess
import threading
import time
import traceback
from collections import deque
from copy import deepcopy
import jsonschema
import numpy as np
import torch
import torch.distributed as dist
import zmq
from loguru import logger
try:
from bson import BSON
except ImportError:
BSON = None
logger.warning("BSON is not installed")
from scipy.signal import resample
class AudioInfo:
def __init__(self, info: dict):
self.sample_count = info["sample_count"]
self.sample_rate = info["sample_rate"]
self.channel_count = info["channel_count"]
self.sample_fmt = info["sample_fmt"]
self.pts = info["pts"]
def is_spec_equal(self, other: "AudioInfo") -> bool:
return self.sample_fmt == other.sample_fmt and self.sample_rate == other.sample_rate and self.channel_count == other.channel_count
def duration(self) -> datetime.timedelta:
return datetime.timedelta(seconds=self.sample_count / self.sample_rate)
def __str__(self):
return "AudioInfo(sample_count={}, sample_rate={}, channel_count={}, sample_fmt={}, pts={})".format(self.sample_count, self.sample_rate, self.channel_count, self.sample_fmt, self.pts)
class ByteBuffer:
def __init__(self):
self.buffer = deque()
self.current_size = 0
# whether the audio belonging to the current turn has finished
self.audio_finished = False
def add(self, byte_data: bytes):
self.buffer.append(byte_data)
self.current_size += len(byte_data)
def get(self, size=1024):
data = bytearray()
while size > 0 and len(self.buffer) > 0:
chunk = self.buffer.popleft()
if len(chunk) <= size:
# If this chunk is no larger than size, append it entirely to data
data.extend(chunk)
self.current_size -= len(chunk)
size -= len(chunk)
else:
# If this chunk is larger than size, append only the needed part and keep the remainder in the buffer
data.extend(chunk[:size])
self.buffer.appendleft(chunk[size:])  # the remainder stays in the buffer
self.current_size -= size
size = 0
return bytes(data)
def mark_finished(self):
self.audio_finished = True
def has_more_voice(self):
return not self.audio_finished
def __len__(self):
return self.current_size
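# ChatAdapter launches the external seko-chatter process and talks to it over two ZMQ sockets:
# a PULL socket (w2f) for receiving BSON messages from the chatter and a PUSH socket (f2w) for sending messages to it.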
class ChatAdapter:
def __init__(
self,
omni_work_dir: str,
whep_url: str,
session_id: str,
account: str,
config_files: list[str],
config_schema_path: str,
seg_duration: float,
model_runner,
huoshan_tts_voice_type,
):
assert os.path.exists(omni_work_dir), f"OMNI work directory {omni_work_dir} does not exist"
self.omni_work_dir = omni_work_dir
self.context = zmq.Context()
self.w2f_socket = self.context.socket(zmq.PULL)
self.w2f_url = ChatAdapter.select_and_bind(self.w2f_socket)
self.f2w_socket = self.context.socket(zmq.PUSH)
self.f2w_url = ChatAdapter.select_and_bind(self.f2w_socket)
self.recv_thread = None
self.audio_buffer = ByteBuffer()
self.audio_info = None
self.chat_server_cmd = [
os.path.join(self.omni_work_dir, "bin", "seko-chatter"),
"--session-id",
session_id,
"--account",
account,
"--whep-server-url",
whep_url,
"--w2f-endpoint",
self.w2f_url,
"--f2w-endpoint",
self.f2w_url,
"--config-files",
*config_files,
]
override_config = {}
if huoshan_tts_voice_type is not None:
logger.info(f"Use Huoshan TTS voice type: {huoshan_tts_voice_type}")
override_config["TTS"] = {
"default_voice_info": {
"voice_type": huoshan_tts_voice_type,
"provider": "huoshan_stream_tts",
}
}
with open(config_schema_path, "r") as f:
schema = json.load(f)
jsonschema.validate(instance=override_config, schema=schema)
if override_config:
self.chat_server_cmd.extend(["--override-config", json.dumps(override_config)])
self.chatter_proc = None
self.seg_duration = seg_duration
self.reset_prev = False
self.status = "blank"
self.immediate_switch = 0
self.model_runner = model_runner
def launch_chat_server(self):
env = {
"RUST_LOG": "info,duplex_server=debug,backend_5o=debug",
"LD_LIBRARY_PATH": os.environ.get("LD_LIBRARY_PATH", "") + ":" + os.path.join(self.omni_work_dir, "lib/"),
"PATH": os.environ["PATH"] + ":" + os.path.join(self.omni_work_dir, "bin/"),
}
self.chatter_proc = subprocess.Popen(self.chat_server_cmd, env=env, cwd=self.omni_work_dir)
@staticmethod
def select_and_bind(socket: zmq.Socket) -> str:
# randomly select a port between 1024 and 65535
retry_count = 20
err = None
while retry_count > 0:
try:
port = random.randint(1024, 65535)
# port = 5555
url = f"tcp://localhost:{port}"
socket.bind(url)
return url
except zmq.error.ZMQError as e:
retry_count -= 1
err = e
raise err
# immediately switch to the given status, discard the previous segment chunk, and set immediate_switch to 1
def immediate_switch_to(self, status):
logger.warning(f"VA reader immediate switch to {status}")
self.reset_prev = True
self.status = status
self.immediate_switch = 1
if self.model_runner is not None:
self.model_runner.pause_signal = True
logger.warning(f"Model runner pause signal set to True")
def recv_loop(self):
while True:
try:
message = self.w2f_socket.recv()
except Exception:
logger.error(f"Error receiving message: {traceback.format_exc()}")
break
try:
message = BSON.decode(message)
msg_type = message["type"]
logger.debug("Received message type: {}".format(msg_type))
if msg_type == "AgentAudio":
audio = message["audio"]
if audio["type"] != "Pcm":
logger.error("Unsupported audio type: {}".format(audio["type"]))
continue
pcm_data = audio["data"]
audio_info = AudioInfo(audio["info"])
logger.debug("Received audio with duration: {}".format(audio_info.duration()))
if self.audio_info is None:
self.audio_info = audio_info
else:
# check if the audio info is the same
if not self.audio_info.is_spec_equal(audio_info):
raise ValueError("Audio info mismatch")
self.audio_buffer.add(pcm_data)
# if status is blank and has voice, set immediate switch to 1
if self.status == "blank" and self.has_voice(self.seg_duration):
self.immediate_switch_to("voice")
elif msg_type == "AgentStartPlay":
logger.debug("Received AgentStartPlay, create new audio buffer")
self.audio_buffer = ByteBuffer()
elif msg_type == "AgentEndPlay":
logger.debug("Received AgentEndPlay, mark audio finished")
self.audio_buffer.mark_finished()
elif msg_type == "ClearAgentAudio":
logger.warning("Received ClearAgentAudio, clear audio buffer")
self.audio_buffer = None
self.audio_info = None
if self.status == "voice":
self.status = "blank"
# self.immediate_switch_to("blank")
except Exception as e:
logger.error("Error decoding message: {}, continue".format(e))
continue
logger.warning("recv loop interrupted")
def start(self):
self.launch_chat_server()
self.recv_thread = threading.Thread(target=self.recv_loop)
self.recv_thread.start()
# returns False when not enough audio is buffered yet, otherwise the number of bytes that can be fetched
def has_voice(self, duration):
if self.audio_info is None or self.audio_buffer.current_size == 0:
return False
bytes_count = round(duration * self.audio_info.sample_rate) * self.audio_info.channel_count * 2 # S16LE assumed
# if there are not enough bytes and more voice may still arrive, return False
if self.audio_buffer.current_size < bytes_count and self.audio_buffer.has_more_voice():
logger.warning(f"Not enough bytes and more voice may still arrive, content_size: {self.audio_buffer.current_size}, bytes_count: {bytes_count}")
return False
return bytes_count
def get_audio(self, fetch_duration) -> "tuple[bytes, AudioInfo] | None":
bytes_count = self.has_voice(fetch_duration)
if bytes_count is False:
return None
pcm_data = self.audio_buffer.get(bytes_count)
# the actual sample count fetched
sample_count = len(pcm_data) // (self.audio_info.channel_count * 2)
logger.debug("Fetched {} bytes audio".format(sample_count))
logger.debug("After fetch, there are {} bytes left".format(self.audio_buffer.current_size))
audio_info = deepcopy(self.audio_info)
audio_info.sample_count = sample_count
return (pcm_data, audio_info)
def stop(self):
self.model_runner = None
if self.chatter_proc is not None:
self.chatter_proc.terminate()
self.chatter_proc.wait()
self.chatter_proc = None
self.w2f_socket.close()
self.f2w_socket.close()
def __del__(self):
self.stop()
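# OmniVAReader: the target rank pulls fixed-duration audio segments from the ChatAdapter and broadcasts them
# (as raw int16 PCM bytes) to the other ranks via torch.distributed, so every rank sees the same segment.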
class OmniVAReader:
def __init__(
self,
rank: int,
world_size: int,
stream_url: str,
segment_duration: float = 5.0625,
sample_rate: int = 16000,
audio_channels: int = 1,
buffer_size: int = 1,
prev_duration: float = 0.3125,
target_rank: int = 0,
model_runner=None,
huoshan_tts_voice_type=None,
):
self.rank = rank
self.world_size = world_size
self.stream_url = stream_url
self.segment_duration = segment_duration
self.sample_rate = sample_rate
self.audio_channels = audio_channels
self.prev_duration = prev_duration
self.all_seg_sample_count = int(self.segment_duration * self.sample_rate)
self.prev_seg_sample_count = int(self.prev_duration * self.sample_rate)
self.prev_seg_chunk = None
self.target_rank = target_rank % self.world_size
self.flag_tensor = torch.tensor([0], dtype=torch.int32).to(device="cuda")
self.immediate_switch_tensor = torch.tensor([0], dtype=torch.int32).to(device="cuda")
chunk_size = int(self.segment_duration * self.sample_rate) * 2
self.audio_tensor = torch.zeros(chunk_size, dtype=torch.uint8, device="cuda")
self.chat_adapter = None
self.model_runner = model_runner
self.huoshan_tts_voice_type = huoshan_tts_voice_type
assert self.audio_channels == 1, "Only mono audio is supported for OmniVAReader"
logger.info(f"VAReader initialized for stream: {stream_url} target_rank: {self.target_rank}")
logger.info(f"Audio duration per chunk: {segment_duration}s, sample rate: {sample_rate}Hz")
def init_omni_env(self):
self.omni_work_dir = os.getenv("OMNI_WORK_DIR", "/path/of/seko_chatter/")
self.session_id = os.getenv("OMNI_SESSION_ID", "")
self.account = os.getenv("OMNI_ACCOUNT", "")
self.config_files = os.getenv("OMNI_CONFIG_FILES", "").split(",")
self.config_schema_path = os.getenv("OMNI_CONFIG_SCHEMA_PATH", None)
assert os.path.exists(self.omni_work_dir), f"OMNI work directory {self.omni_work_dir} does not exist"
assert self.session_id and self.account, "OMNI_SESSION_ID and OMNI_ACCOUNT are required"
logger.info(
f"OMNI work directory: {self.omni_work_dir}, session_id: {self.session_id}, account: {self.account}, config_files: {self.config_files}, config_schema_path: {self.config_schema_path}"
)
def start(self):
if self.rank == self.target_rank:
self.init_omni_env()
assert self.stream_url.startswith("http"), "Only HTTP stream is supported for OmniVAReader"
self.chat_adapter = ChatAdapter(
omni_work_dir=self.omni_work_dir,
whep_url=self.stream_url,
session_id=self.session_id,
account=self.account,
config_files=self.config_files,
config_schema_path=self.config_schema_path,
seg_duration=self.segment_duration,
model_runner=self.model_runner,
huoshan_tts_voice_type=self.huoshan_tts_voice_type,
)
self.chat_adapter.start()
logger.info(f"OmniVAReader {self.rank}/{self.world_size} started successfully")
else:
logger.info(f"OmniVAReader {self.rank}/{self.world_size} wait only")
if self.world_size > 1:
logger.info(f"OmniVAReader {self.rank}/{self.world_size} wait barrier")
dist.barrier()
logger.info(f"OmniVAReader {self.rank}/{self.world_size} end barrier")
def broadcast_audio_data(self, audio_data):
if self.rank == self.target_rank:
if audio_data is None:
self.flag_tensor.fill_(0)
else:
self.flag_tensor.fill_(1)
self.audio_tensor.copy_(torch.frombuffer(bytearray(audio_data), dtype=torch.uint8))
# logger.info(f"rank {self.rank} send audio_tensor: {self.audio_tensor.shape}")
dist.broadcast(self.flag_tensor, src=self.target_rank)
if self.flag_tensor.item() == 0:
return None
dist.broadcast(self.audio_tensor, src=self.target_rank)
if self.rank != self.target_rank:
# logger.info(f"rank {self.rank} recv audio_tensor: {self.audio_tensor.shape}")
audio_data = self.audio_tensor.cpu().numpy().tobytes()
return audio_data
def bytes_to_ndarray(self, audio_data):
if audio_data is None:
return None
audio_data = np.frombuffer(audio_data, dtype=np.int16)
audio_data = audio_data.astype(np.float32) / 32768.0
# logger.info(f"Got segment audio rank={self.rank}: {audio_data.shape} {audio_data.dtype} {audio_data.min()} {audio_data.max()}")
return audio_data
def convert_pcm_s16le_to_mono_resampled(self, audio_data, audio_info):
audio = np.frombuffer(audio_data, dtype=np.int16)
sample_count = audio_info.sample_count
assert len(audio) == sample_count * audio_info.channel_count, f"audio length {len(audio)} != sample_count * channel_count {sample_count * audio_info.channel_count}"
# convert to mono
if audio_info.channel_count > 1:
audio = audio.reshape(-1, audio_info.channel_count).mean(axis=1)
# logger.info(f"audio: {audio.shape} {audio.dtype} {audio.min()} {audio.max()}")
if audio_info.sample_rate != self.sample_rate:
sample_count = int(len(audio) * self.sample_rate / audio_info.sample_rate)
audio = resample(audio, sample_count).astype(np.int16)
# logger.info(f"resampled audio: {audio.shape} {audio.dtype} {audio.min()} {audio.max()} {sample_count}")
logger.warning(f"valid audio: {audio.shape} {audio.dtype} {audio.min()} {audio.max()} {sample_count}")
return audio, sample_count
def prepare_audio_data(self, chat_audio_result):
sample_count = 0
audio = np.array([], dtype=np.int16)
# convert chat audio result to mono and target sample rate
if chat_audio_result is not None:
audio_data, audio_info = chat_audio_result
audio, sample_count = self.convert_pcm_s16le_to_mono_resampled(audio_data, audio_info)
# if is not the first segment, concat with previous segment
if self.prev_seg_chunk is not None:
audio = np.concatenate([self.prev_seg_chunk, audio])
sample_count = len(audio)
assert sample_count <= self.all_seg_sample_count, f"audio length {sample_count} > all_seg_sample_count {self.all_seg_sample_count}"
# pad 0 to the audio to make it the same length as all_seg_sample_count
if sample_count < self.all_seg_sample_count:
pad_count = self.all_seg_sample_count - sample_count
# logger.info(f"pad {pad_count} samples to audio")
audio = np.pad(audio, (0, pad_count), mode="constant", constant_values=0)
sample_count = len(audio)
# update prev seg chunk
self.prev_seg_chunk = audio[-self.prev_seg_sample_count :]
# logger.info(f"audio: {audio.shape} {audio.dtype} {audio.min()} {audio.max()} {sample_count}, prev seg chunk: {self.prev_seg_chunk.shape}")
return audio.tobytes()
def get_fetch_duration(self):
fetch_duration = self.segment_duration
# after immediate switch, reset prev seg chunk
if self.chat_adapter.reset_prev:
self.prev_seg_chunk = None
self.chat_adapter.reset_prev = False
logger.warning(f"Reset prev seg chunk")
# first segment, fetch segment_duration, else fetch segment_duration - prev_duration
if self.prev_seg_chunk is not None:
fetch_duration -= self.prev_duration
return fetch_duration
def get_audio_segment(self):
audio_data = None
if self.rank == self.target_rank:
try:
fetch_duration = self.get_fetch_duration()
# logger.info(f"Get segment, fetch_duration: {fetch_duration}")
if self.chat_adapter.status == "voice":
audio_result = self.chat_adapter.get_audio(fetch_duration)
audio_data = self.prepare_audio_data(audio_result)
# all buffered voice segments have been inferred, so naturally switch back to blank
if audio_result is None:
logger.info("All voice segments inferred, naturally switching back to blank")
self.chat_adapter.status = "blank"
else:
audio_data = self.prepare_audio_data(None)
except Exception as e:
logger.warning(f"Failed to get voice segment: {e}")
return None
if self.world_size > 1:
audio_data = self.broadcast_audio_data(audio_data)
audio_data = self.bytes_to_ndarray(audio_data)
return audio_data
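# Broadcast the immediate-switch flag from the target rank so every rank interrupts the current segment together.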
def get_immediate_switch(self):
if self.rank == self.target_rank:
if self.chat_adapter.immediate_switch == 1:
self.immediate_switch_tensor.fill_(1)
# reset immediate switch
self.chat_adapter.immediate_switch = 0
else:
self.immediate_switch_tensor.fill_(0)
dist.broadcast(self.immediate_switch_tensor, src=self.target_rank)
immediate_switch = self.immediate_switch_tensor.item()
return immediate_switch
def stop(self):
self.model_runner = None
if self.chat_adapter is not None:
self.chat_adapter.stop()
self.chat_adapter = None
logger.warning("OmniVAReader stopped")
def __del__(self):
self.stop()
if __name__ == "__main__":
WORLD_SIZE = int(os.environ.get("WORLD_SIZE", 1))
RANK = int(os.environ.get("RANK", 0))
if WORLD_SIZE > 1:
dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank())
logger.info(f"Distributed initialized: rank={RANK}, world_size={WORLD_SIZE}")
reader = OmniVAReader(
RANK,
WORLD_SIZE,
"https://reverse.st-oc-01.chielo.org/10.5.64.49:8000/rtc/v1/whep/?app=publish&stream=test_stream_ll&eip=10.120.114.82:8000",
segment_duration=17 / 16,
sample_rate=16000,
audio_channels=1,
prev_duration=1 / 16,
)
reader.start()
fail_count = 0
max_fail_count = 100000000
try:
while True:
audio_data = reader.get_audio_segment()
if audio_data is not None:
logger.info(f"Got audio chunk, shape: {audio_data.shape}, range: [{audio_data.min()}, {audio_data.max()}]")
fail_count = 0
else:
fail_count += 1
if fail_count > max_fail_count:
logger.warning("Failed to get audio chunk, stop reader")
reader.stop()
break
time.sleep(0.95)
finally:
reader.stop()
import os
import queue
import socket
import subprocess
import threading
import time
import traceback
import numpy as np
import torch
import torchaudio as ta
from loguru import logger
def pseudo_random(a, b):
x = str(time.time()).split(".")[1]
y = int(float("0." + x) * 1000000)
return a + (y % (b - a + 1))
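# VARecorder feeds raw PCM audio and raw RGB frames to a local ffmpeg process over two loopback TCP sockets;
# ffmpeg muxes them and either writes an MP4 file, pushes RTMP/FLV, or publishes via WHIP depending on livestream_url.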
class VARecorder:
def __init__(
self,
livestream_url: str,
fps: float = 16.0,
sample_rate: int = 16000,
slice_frame: int = 1,
prev_frame: int = 1,
):
self.livestream_url = livestream_url
self.fps = fps
self.sample_rate = sample_rate
self.audio_port = pseudo_random(32000, 40000)
self.video_port = self.audio_port + 1
self.ffmpeg_log_level = os.getenv("FFMPEG_LOG_LEVEL", "error")
logger.info(f"VARecorder audio port: {self.audio_port}, video port: {self.video_port}, ffmpeg_log_level: {self.ffmpeg_log_level}")
self.width = None
self.height = None
self.stoppable_t = None
self.realtime = False
if self.livestream_url.startswith("rtmp://") or self.livestream_url.startswith("http"):
self.realtime = True
# ffmpeg process that muxes video and audio data and pushes them to the livestream
self.ffmpeg_process = None
# TCP connection objects
self.audio_socket = None
self.video_socket = None
self.audio_conn = None
self.video_conn = None
self.audio_thread = None
self.video_thread = None
# queues for sending data to the ffmpeg process
self.audio_queue = queue.Queue()
self.video_queue = queue.Queue()
# buffer for stream data
self.audio_samples_per_frame = round(self.sample_rate / self.fps)
self.stream_buffer = []
self.stream_buffer_lock = threading.Lock()
self.stop_schedule = False
self.schedule_thread = None
self.slice_frame = slice_frame
self.prev_frame = prev_frame
assert self.slice_frame >= self.prev_frame, "Slice frame must be greater than or equal to previous frame"
def init_sockets(self):
# TCP sockets used to send video and audio data to the ffmpeg process
self.video_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.video_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
self.video_socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
self.video_socket.bind(("127.0.0.1", self.video_port))
self.video_socket.listen(1)
self.audio_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.audio_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
self.audio_socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
self.audio_socket.bind(("127.0.0.1", self.audio_port))
self.audio_socket.listen(1)
def audio_worker(self):
try:
logger.info("Waiting for ffmpeg to connect to audio socket...")
self.audio_conn, _ = self.audio_socket.accept()
logger.info(f"Audio connection established from {self.audio_conn.getpeername()}")
fail_time, max_fail_time = 0, 10
while True:
try:
if self.audio_queue is None:
break
data = self.audio_queue.get()
if data is None:
logger.info("Audio thread received stop signal")
break
# Convert audio data to 16-bit integer format
audios = torch.clamp(torch.round(data * 32767), -32768, 32767).to(torch.int16)
try:
self.audio_conn.send(audios[None].cpu().numpy().tobytes())
except (BrokenPipeError, OSError, ConnectionResetError) as e:
logger.info(f"Audio connection closed, stopping worker: {type(e).__name__}")
return
fail_time = 0
except (BrokenPipeError, OSError, ConnectionResetError):
logger.info("Audio connection closed during queue processing")
break
except Exception:
logger.error(f"Send audio data error: {traceback.format_exc()}")
fail_time += 1
if fail_time > max_fail_time:
logger.error(f"Audio push worker thread failed {fail_time} times, stopping...")
break
except Exception:
logger.error(f"Audio push worker thread error: {traceback.format_exc()}")
finally:
logger.info("Audio push worker thread stopped")
def video_worker(self):
try:
logger.info("Waiting for ffmpeg to connect to video socket...")
self.video_conn, _ = self.video_socket.accept()
logger.info(f"Video connection established from {self.video_conn.getpeername()}")
fail_time, max_fail_time = 0, 10
packet_secs = 1.0 / self.fps
while True:
try:
if self.video_queue is None:
break
data = self.video_queue.get()
if data is None:
logger.info("Video thread received stop signal")
break
# Convert to numpy uint8 and scale to [0, 255]; frames are sent to ffmpeg as rgb24
for i in range(data.shape[0]):
t0 = time.time()
frame = (data[i] * 255).clamp(0, 255).to(torch.uint8).cpu().numpy()
try:
self.video_conn.send(frame.tobytes())
except (BrokenPipeError, OSError, ConnectionResetError) as e:
logger.info(f"Video connection closed, stopping worker: {type(e).__name__}")
return
if self.realtime and i < data.shape[0] - 1:
time.sleep(max(0, packet_secs - (time.time() - t0)))
fail_time = 0
except (BrokenPipeError, OSError, ConnectionResetError):
logger.info("Video connection closed during queue processing")
break
except Exception:
logger.error(f"Send video data error: {traceback.format_exc()}")
fail_time += 1
if fail_time > max_fail_time:
logger.error(f"Video push worker thread failed {fail_time} times, stopping...")
break
except Exception:
logger.error(f"Video push worker thread error: {traceback.format_exc()}")
finally:
logger.info("Video push worker thread stopped")
def start_ffmpeg_process_local(self):
"""Start ffmpeg process that connects to our TCP sockets"""
ffmpeg_cmd = [
"ffmpeg",
"-fflags",
"nobuffer",
"-analyzeduration",
"0",
"-probesize",
"32",
"-flush_packets",
"1",
"-f",
"s16le",
"-ar",
str(self.sample_rate),
"-ac",
"1",
"-i",
f"tcp://127.0.0.1:{self.audio_port}",
"-f",
"rawvideo",
"-pix_fmt",
"rgb24",
"-color_range",
"pc",
"-colorspace",
"rgb",
"-color_primaries",
"bt709",
"-color_trc",
"iec61966-2-1",
"-r",
str(self.fps),
"-s",
f"{self.width}x{self.height}",
"-i",
f"tcp://127.0.0.1:{self.video_port}",
"-ar",
"44100",
"-b:v",
"4M",
"-c:v",
"libx264",
"-preset",
"ultrafast",
"-tune",
"zerolatency",
"-g",
f"{self.fps}",
"-pix_fmt",
"yuv420p",
"-f",
"mp4",
self.livestream_url,
"-y",
"-loglevel",
self.ffmpeg_log_level,
]
try:
self.ffmpeg_process = subprocess.Popen(ffmpeg_cmd)
logger.info(f"FFmpeg streaming started with PID: {self.ffmpeg_process.pid}")
logger.info(f"FFmpeg command: {' '.join(ffmpeg_cmd)}")
except Exception as e:
logger.error(f"Failed to start FFmpeg: {e}")
def start_ffmpeg_process_rtmp(self):
"""Start ffmpeg process that connects to our TCP sockets"""
ffmpeg_cmd = [
"ffmpeg",
"-re",
"-f",
"s16le",
"-ar",
str(self.sample_rate),
"-ac",
"1",
"-i",
f"tcp://127.0.0.1:{self.audio_port}",
"-f",
"rawvideo",
"-re",
"-pix_fmt",
"rgb24",
"-r",
str(self.fps),
"-s",
f"{self.width}x{self.height}",
"-i",
f"tcp://127.0.0.1:{self.video_port}",
"-ar",
"44100",
"-b:v",
"2M",
"-c:v",
"libx264",
"-preset",
"ultrafast",
"-tune",
"zerolatency",
"-g",
f"{self.fps}",
"-pix_fmt",
"yuv420p",
"-f",
"flv",
self.livestream_url,
"-y",
"-loglevel",
self.ffmpeg_log_level,
]
try:
self.ffmpeg_process = subprocess.Popen(ffmpeg_cmd)
logger.info(f"FFmpeg streaming started with PID: {self.ffmpeg_process.pid}")
logger.info(f"FFmpeg command: {' '.join(ffmpeg_cmd)}")
except Exception as e:
logger.error(f"Failed to start FFmpeg: {e}")
def start_ffmpeg_process_whip(self):
"""Start ffmpeg process that connects to our TCP sockets"""
ffmpeg_cmd = [
"ffmpeg",
"-re",
"-fflags",
"nobuffer",
"-analyzeduration",
"0",
"-probesize",
"32",
"-flush_packets",
"1",
"-f",
"s16le",
"-ar",
str(self.sample_rate),
"-ac",
"1",
"-ch_layout",
"mono",
"-i",
f"tcp://127.0.0.1:{self.audio_port}",
"-f",
"rawvideo",
"-re",
"-pix_fmt",
"rgb24",
"-r",
str(self.fps),
"-s",
f"{self.width}x{self.height}",
"-i",
f"tcp://127.0.0.1:{self.video_port}",
"-ar",
"48000",
"-c:a",
"libopus",
"-ac",
"2",
"-b:v",
"2M",
"-c:v",
"libx264",
"-preset",
"ultrafast",
"-tune",
"zerolatency",
"-g",
f"{self.fps}",
"-pix_fmt",
"yuv420p",
"-threads",
"1",
"-bf",
"0",
"-f",
"whip",
self.livestream_url,
"-y",
"-loglevel",
self.ffmpeg_log_level,
]
try:
self.ffmpeg_process = subprocess.Popen(ffmpeg_cmd)
logger.info(f"FFmpeg streaming started with PID: {self.ffmpeg_process.pid}")
logger.info(f"FFmpeg command: {' '.join(ffmpeg_cmd)}")
except Exception as e:
logger.error(f"Failed to start FFmpeg: {e}")
def start(self, width: int, height: int):
self.set_video_size(width, height)
duration = 1.0
frames = int(self.fps * duration)
samples = int(self.sample_rate * (frames / self.fps))
self.pub_livestream(torch.zeros((frames, height, width, 3), dtype=torch.float16), torch.zeros(samples, dtype=torch.float16))
time.sleep(duration)
def set_video_size(self, width: int, height: int):
if self.width is not None and self.height is not None:
assert self.width == width and self.height == height, "Video size already set and cannot change"
return
self.width = width
self.height = height
self.init_sockets()
if self.livestream_url.startswith("rtmp://"):
self.start_ffmpeg_process_rtmp()
elif self.livestream_url.startswith("http"):
self.start_ffmpeg_process_whip()
else:
self.start_ffmpeg_process_local()
self.audio_thread = threading.Thread(target=self.audio_worker)
self.video_thread = threading.Thread(target=self.video_worker)
self.audio_thread.start()
self.video_thread.start()
if self.realtime:
self.schedule_thread = threading.Thread(target=self.schedule_stream_buffer)
self.schedule_thread.start()
# Publish ComfyUI Image tensor and audio tensor to livestream
def pub_livestream(self, images: torch.Tensor, audios: torch.Tensor):
N, height, width, C = images.shape
M = audios.reshape(-1).shape[0]
assert C == 3, "Input must be [N, H, W, C] with C=3"
logger.info(f"Publishing video [{N}x{width}x{height}], audio: [{M}]")
audio_frames = round(M * self.fps / self.sample_rate)
if audio_frames != N:
logger.warning(f"Video and audio frames mismatch, {N} vs {audio_frames}")
self.set_video_size(width, height)
self.audio_queue.put(audios)
self.video_queue.put(images)
logger.info(f"Published {N} frames and {M} audio samples")
self.stoppable_t = time.time() + M / self.sample_rate + 3
def buffer_stream(self, images: torch.Tensor, audios: torch.Tensor, gen_video: torch.Tensor):
N, height, width, C = images.shape
M = audios.reshape(-1).shape[0]
assert N % self.slice_frame == 0, "Video frames must be divisible by slice_frame"
assert C == 3, "Input must be [N, H, W, C] with C=3"
audio_frames = round(M * self.fps / self.sample_rate)
if audio_frames != N:
logger.warning(f"Video and audio frames mismatch, {N} vs {audio_frames}")
self.set_video_size(width, height)
# logger.info(f"Buffer stream images {images.shape} {audios.shape} {gen_video.shape}")
rets = []
for i in range(0, N, self.slice_frame):
end_frame = i + self.slice_frame
img = images[i:end_frame]
aud = audios[i * self.audio_samples_per_frame : end_frame * self.audio_samples_per_frame]
gen = gen_video[:, :, (end_frame - self.prev_frame) : end_frame]
rets.append((img, aud, gen))
with self.stream_buffer_lock:
origin_size = len(self.stream_buffer)
self.stream_buffer.extend(rets)
logger.info(f"Buffered {origin_size} + {len(rets)} = {len(self.stream_buffer)} stream segments")
def get_buffer_stream_size(self):
return len(self.stream_buffer)
def truncate_stream_buffer(self, size: int):
with self.stream_buffer_lock:
self.stream_buffer = self.stream_buffer[:size]
logger.info(f"Truncated stream buffer to {len(self.stream_buffer)} segments")
if len(self.stream_buffer) > 0:
return self.stream_buffer[-1][2] # return the last video tensor
else:
return None
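# Pops one buffered (images, audio, gen_video) slice roughly every slice_frame / fps seconds and forwards it
# to the worker queues, so the ffmpeg process is fed in real time.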
def schedule_stream_buffer(self):
schedule_interval = self.slice_frame / self.fps
logger.info(f"Schedule stream buffer with interval: {schedule_interval} seconds")
t = None
while True:
try:
if self.stop_schedule:
break
img, aud, gen = None, None, None
with self.stream_buffer_lock:
if len(self.stream_buffer) > 0:
img, aud, gen = self.stream_buffer.pop(0)
if t is not None:
wait_secs = schedule_interval - (time.time() - t)
if wait_secs > 0:
time.sleep(wait_secs)
t = time.time()
if img is not None and aud is not None:
self.audio_queue.put(aud)
self.video_queue.put(img)
# logger.info(f"Scheduled {img.shape[0]} frames and {aud.shape[0]} audio samples to publish")
del gen
self.stoppable_t = time.time() + aud.shape[0] / self.sample_rate + 3
else:
logger.warning(f"No stream buffer to schedule")
except Exception:
logger.error(f"Schedule stream buffer error: {traceback.format_exc()}")
break
logger.info("Schedule stream buffer thread stopped")
def stop(self, wait=True):
if wait and self.stoppable_t:
t = self.stoppable_t - time.time()
if t > 0:
logger.warning(f"Waiting for {t} seconds to stop ...")
time.sleep(t)
self.stoppable_t = None
if self.schedule_thread:
self.stop_schedule = True
self.schedule_thread.join(timeout=5)
if self.schedule_thread and self.schedule_thread.is_alive():
logger.error(f"Schedule thread did not stop after 5s")
# Send stop signals to queues
if self.audio_queue:
self.audio_queue.put(None)
if self.video_queue:
self.video_queue.put(None)
# Wait for threads to finish processing queued data (increased timeout)
queue_timeout = 30 # Increased from 5s to 30s to allow sufficient time for large video frames
if self.audio_thread and self.audio_thread.is_alive():
self.audio_thread.join(timeout=queue_timeout)
if self.audio_thread.is_alive():
logger.error(f"Audio push thread did not stop after {queue_timeout}s")
if self.video_thread and self.video_thread.is_alive():
self.video_thread.join(timeout=queue_timeout)
if self.video_thread.is_alive():
logger.error(f"Video push thread did not stop after {queue_timeout}s")
# Shutdown connections to signal EOF to FFmpeg
# shutdown(SHUT_WR) will wait for send buffer to flush, no explicit sleep needed
if self.audio_conn:
try:
self.audio_conn.getpeername()
self.audio_conn.shutdown(socket.SHUT_WR)
logger.info("Audio connection shutdown initiated")
except OSError:
# Connection already closed, skip shutdown
pass
if self.video_conn:
try:
self.video_conn.getpeername()
self.video_conn.shutdown(socket.SHUT_WR)
logger.info("Video connection shutdown initiated")
except OSError:
# Connection already closed, skip shutdown
pass
if self.ffmpeg_process:
is_local_file = not self.livestream_url.startswith(("rtmp://", "http"))
# Local MP4 files need time to write moov atom and finalize the container
timeout_seconds = 30 if is_local_file else 10
logger.info(f"Waiting for FFmpeg to finalize file (timeout={timeout_seconds}s, local_file={is_local_file})")
logger.info(f"FFmpeg output: {self.livestream_url}")
try:
returncode = self.ffmpeg_process.wait(timeout=timeout_seconds)
if returncode == 0:
logger.info(f"FFmpeg process exited successfully (exit code: {returncode})")
else:
logger.warning(f"FFmpeg process exited with non-zero code: {returncode}")
except subprocess.TimeoutExpired:
logger.warning(f"FFmpeg process did not exit within {timeout_seconds}s, sending SIGTERM...")
try:
self.ffmpeg_process.terminate() # SIGTERM
returncode = self.ffmpeg_process.wait(timeout=5)
logger.warning(f"FFmpeg process terminated with SIGTERM (exit code: {returncode})")
except subprocess.TimeoutExpired:
logger.error("FFmpeg process still running after SIGTERM, killing with SIGKILL...")
self.ffmpeg_process.kill()
self.ffmpeg_process.wait() # Wait for kill to complete
logger.error("FFmpeg process killed with SIGKILL")
finally:
self.ffmpeg_process = None
if self.audio_conn:
try:
self.audio_conn.close()
except Exception as e:
logger.debug(f"Error closing audio connection: {e}")
finally:
self.audio_conn = None
if self.video_conn:
try:
self.video_conn.close()
except Exception as e:
logger.debug(f"Error closing video connection: {e}")
finally:
self.video_conn = None
if self.audio_socket:
try:
self.audio_socket.close()
except Exception as e:
logger.debug(f"Error closing audio socket: {e}")
finally:
self.audio_socket = None
if self.video_socket:
try:
self.video_socket.close()
except Exception as e:
logger.debug(f"Error closing video socket: {e}")
finally:
self.video_socket = None
if self.audio_queue:
while self.audio_queue.qsize() > 0:
try:
self.audio_queue.get_nowait()
except: # noqa
break
if self.video_queue:
while self.video_queue.qsize() > 0:
try:
self.video_queue.get_nowait()
except: # noqa
break
self.audio_queue = None
self.video_queue = None
logger.info("VARecorder stopped and resources cleaned up")
def __del__(self):
self.stop(wait=False)
def create_simple_video(frames=10, height=480, width=640):
video_data = []
for i in range(frames):
frame = np.zeros((height, width, 3), dtype=np.float32)
stripe_height = height // 8
colors = [
[1.0, 0.0, 0.0], # red
[0.0, 1.0, 0.0], # green
[0.0, 0.0, 1.0], # blue
[1.0, 1.0, 0.0], # yellow
[1.0, 0.0, 1.0], # magenta
[0.0, 1.0, 1.0], # cyan
[1.0, 1.0, 1.0], # white
[0.5, 0.5, 0.5], # gray
]
for j, color in enumerate(colors):
start_y = j * stripe_height
end_y = min((j + 1) * stripe_height, height)
frame[start_y:end_y, :] = color
offset = int((i / frames) * width)
frame = np.roll(frame, offset, axis=1)
frame = torch.tensor(frame, dtype=torch.float32)
video_data.append(frame)
return torch.stack(video_data, dim=0)
if __name__ == "__main__":
sample_rate = 16000
fps = 16
width = 640
height = 480
recorder = VARecorder(
# livestream_url="rtmp://localhost/live/test",
# livestream_url="https://reverse.st-oc-01.chielo.org/10.5.64.49:8000/rtc/v1/whip/?app=live&stream=ll_test_video&eip=127.0.0.1:8000",
livestream_url="/path/to/output_video.mp4",
fps=fps,
sample_rate=sample_rate,
)
audio_path = "/path/to/test_b_2min.wav"
audio_array, ori_sr = ta.load(audio_path)
audio_array = ta.functional.resample(audio_array.mean(0), orig_freq=ori_sr, new_freq=16000)
audio_array = audio_array.reshape(-1)
secs = audio_array.shape[0] // sample_rate
interval = 1
for i in range(0, secs, interval):
logger.info(f"{i} / {secs} s")
start = i * sample_rate
end = (i + interval) * sample_rate
cur_audio_array = audio_array[start:end]
logger.info(f"audio: {cur_audio_array.shape} {cur_audio_array.dtype} {cur_audio_array.min()} {cur_audio_array.max()}")
num_frames = int(interval * fps)
images = create_simple_video(num_frames, height, width)
logger.info(f"images: {images.shape} {images.dtype} {images.min()} {images.max()}")
recorder.pub_livestream(images, cur_audio_array)
time.sleep(interval)
recorder.stop()
import ctypes
import queue
import threading
import time
import traceback
import numpy as np
import torch
import torchaudio as ta
from loguru import logger
from scipy.signal import resample
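# X264VARecorder bypasses ffmpeg and pushes raw audio/video frames through a WHIP shared library loaded via ctypes.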
class X264VARecorder:
def __init__(
self,
whip_shared_path: str,
livestream_url: str,
fps: float = 16.0,
sample_rate: int = 16000,
slice_frame: int = 1,
prev_frame: int = 1,
):
assert livestream_url.startswith("http"), "X264VARecorder only supports WHIP HTTP livestreams"
self.livestream_url = livestream_url
self.fps = fps
self.sample_rate = sample_rate
self.width = None
self.height = None
self.stoppable_t = None
# only enable whip shared api for whip http livestream
self.whip_shared_path = whip_shared_path
self.whip_shared_lib = None
self.whip_shared_handle = None
self.realtime = True
# queue for send data to whip shared api
self.queue = queue.Queue()
self.worker_thread = None
# buffer for stream data
self.target_sample_rate = 48000
self.target_samples_per_frame = round(self.target_sample_rate / self.fps)
self.target_chunks_per_frame = self.target_samples_per_frame * 2
self.stream_buffer = []
self.stream_buffer_lock = threading.Lock()
self.stop_schedule = False
self.schedule_thread = None
self.slice_frame = slice_frame
self.prev_frame = prev_frame
assert self.slice_frame >= self.prev_frame, "Slice frame must be greater than or equal to previous frame"
def worker(self):
try:
fail_time, max_fail_time = 0, 10
packet_secs = 1.0 / self.fps
while True:
try:
if self.queue is None:
break
data = self.queue.get()
if data is None:
logger.info("Worker thread received stop signal")
break
audios, images = data
for i in range(images.shape[0]):
t0 = time.time()
cur_audio = audios[i * self.target_chunks_per_frame : (i + 1) * self.target_chunks_per_frame].flatten()
audio_ptr = cur_audio.ctypes.data_as(ctypes.POINTER(ctypes.c_int16))
self.whip_shared_lib.pushWhipRawAudioFrame(self.whip_shared_handle, audio_ptr, self.target_samples_per_frame)
cur_video = images[i].flatten()
video_ptr = cur_video.ctypes.data_as(ctypes.POINTER(ctypes.c_uint8))
self.whip_shared_lib.pushWhipRawVideoFrame(self.whip_shared_handle, video_ptr, self.width, self.height)
if self.realtime and i < images.shape[0] - 1:
time.sleep(max(0, packet_secs - (time.time() - t0)))
fail_time = 0
except: # noqa
logger.error(f"Send audio data error: {traceback.format_exc()}")
fail_time += 1
if fail_time > max_fail_time:
logger.error(f"Audio push worker thread failed {fail_time} times, stopping...")
break
except: # noqa
logger.error(f"Audio push worker thread error: {traceback.format_exc()}")
finally:
logger.info("Audio push worker thread stopped")
def start_libx264_whip_shared_api(self, width: int, height: int):
self.whip_shared_lib = ctypes.CDLL(self.whip_shared_path)
# define function argtypes and restype
self.whip_shared_lib.initWhipStream.argtypes = [ctypes.c_char_p, ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int]
self.whip_shared_lib.initWhipStream.restype = ctypes.c_void_p
self.whip_shared_lib.pushWhipRawAudioFrame.argtypes = [ctypes.c_void_p, ctypes.POINTER(ctypes.c_int16), ctypes.c_int]
self.whip_shared_lib.pushWhipRawVideoFrame.argtypes = [ctypes.c_void_p, ctypes.POINTER(ctypes.c_uint8), ctypes.c_int, ctypes.c_int]
self.whip_shared_lib.destroyWhipStream.argtypes = [ctypes.c_void_p]
whip_url = ctypes.c_char_p(self.livestream_url.encode("utf-8"))
self.whip_shared_handle = ctypes.c_void_p(self.whip_shared_lib.initWhipStream(whip_url, 1, 1, 0, width, height))
logger.info(f"WHIP shared API initialized with handle: {self.whip_shared_handle}")
def convert_data(self, audios, images):
# Convert audio data to 16-bit integer format
audio_datas = torch.clamp(torch.round(audios * 32767), -32768, 32767).to(torch.int16).cpu().numpy().reshape(-1)
# Convert to numpy uint8 and scale to [0, 255]; frames are pushed as raw RGB
image_datas = (images * 255).clamp(0, 255).to(torch.uint8).cpu().numpy()
logger.info(f"image_datas: {image_datas.shape} {image_datas.dtype} {image_datas.min()} {image_datas.max()}")
resampled_audios = resample(audio_datas, int(len(audio_datas) * 48000 / self.sample_rate))
stereo_audios = np.stack([resampled_audios, resampled_audios], axis=-1).astype(np.int16).reshape(-1)
return stereo_audios, image_datas
def start(self, width: int, height: int):
self.set_video_size(width, height)
def set_video_size(self, width: int, height: int):
if self.width is not None and self.height is not None:
assert self.width == width and self.height == height, "Video size already set and cannot change"
return
self.width = width
self.height = height
self.start_libx264_whip_shared_api(width, height)
self.worker_thread = threading.Thread(target=self.worker)
self.worker_thread.start()
if self.realtime:
self.schedule_thread = threading.Thread(target=self.schedule_stream_buffer)
self.schedule_thread.start()
def buffer_stream(self, images: torch.Tensor, audios: torch.Tensor, gen_video: torch.Tensor):
N, height, width, C = images.shape
M = audios.reshape(-1).shape[0]
assert N % self.slice_frame == 0, "Video frames must be divisible by slice_frame"
assert C == 3, "Input must be [N, H, W, C] with C=3"
audio_frames = round(M * self.fps / self.sample_rate)
if audio_frames != N:
logger.warning(f"Video and audio frames mismatch, {N} vs {audio_frames}")
self.set_video_size(width, height)
audio_datas, image_datas = self.convert_data(audios, images)
# logger.info(f"Buffer stream images {images.shape} {audios.shape} {gen_video.shape}")
rets = []
for i in range(0, N, self.slice_frame):
end_frame = i + self.slice_frame
img = image_datas[i:end_frame]
aud = audio_datas[i * self.target_chunks_per_frame : end_frame * self.target_chunks_per_frame]
gen = gen_video[:, :, (end_frame - self.prev_frame) : end_frame]
rets.append((img, aud, gen))
with self.stream_buffer_lock:
origin_size = len(self.stream_buffer)
self.stream_buffer.extend(rets)
logger.info(f"Buffered {origin_size} + {len(rets)} = {len(self.stream_buffer)} stream segments")
def get_buffer_stream_size(self):
return len(self.stream_buffer)
def truncate_stream_buffer(self, size: int):
with self.stream_buffer_lock:
self.stream_buffer = self.stream_buffer[:size]
logger.info(f"Truncated stream buffer to {len(self.stream_buffer)} segments")
if len(self.stream_buffer) > 0:
return self.stream_buffer[-1][2] # return the last video tensor
else:
return None
def schedule_stream_buffer(self):
schedule_interval = self.slice_frame / self.fps
logger.info(f"Schedule stream buffer with interval: {schedule_interval} seconds")
t = None
while True:
try:
if self.stop_schedule:
break
img, aud, gen = None, None, None
with self.stream_buffer_lock:
if len(self.stream_buffer) > 0:
img, aud, gen = self.stream_buffer.pop(0)
if t is not None:
wait_secs = schedule_interval - (time.time() - t)
if wait_secs > 0:
time.sleep(wait_secs)
t = time.time()
if img is not None and aud is not None:
self.queue.put((aud, img))
# logger.info(f"Scheduled {img.shape[0]} frames and {aud.shape[0]} audio samples to publish")
del gen
self.stoppable_t = time.time() + img.shape[0] / self.fps + 3
else:
logger.warning(f"No stream buffer to schedule")
except Exception:
logger.error(f"Schedule stream buffer error: {traceback.format_exc()}")
break
logger.info("Schedule stream buffer thread stopped")
def stop(self, wait=True):
if wait and self.stoppable_t:
t = self.stoppable_t - time.time()
if t > 0:
logger.warning(f"Waiting for {t} seconds to stop ...")
time.sleep(t)
self.stoppable_t = None
if self.schedule_thread:
self.stop_schedule = True
self.schedule_thread.join(timeout=5)
if self.schedule_thread and self.schedule_thread.is_alive():
logger.error(f"Schedule thread did not stop after 5s")
# Send stop signals to queues
if self.queue:
self.queue.put(None)
# Wait for threads to finish
if self.worker_thread and self.worker_thread.is_alive():
self.worker_thread.join(timeout=5)
if self.worker_thread.is_alive():
logger.warning("Worker thread did not stop gracefully")
# Destroy WHIP shared API
if self.whip_shared_lib and self.whip_shared_handle:
self.whip_shared_lib.destroyWhipStream(self.whip_shared_handle)
self.whip_shared_handle = None
self.whip_shared_lib = None
logger.warning("WHIP shared API destroyed")
def __del__(self):
self.stop()
def create_simple_video(frames=10, height=480, width=640):
video_data = []
for i in range(frames):
frame = np.zeros((height, width, 3), dtype=np.float32)
stripe_height = height // 8
colors = [
[1.0, 0.0, 0.0], # red
[0.0, 1.0, 0.0], # green
[0.0, 0.0, 1.0], # blue
[1.0, 1.0, 0.0], # yellow
[1.0, 0.0, 1.0], # magenta
[0.0, 1.0, 1.0], # cyan
[1.0, 1.0, 1.0], # white
[0.5, 0.5, 0.5], # gray
]
for j, color in enumerate(colors):
start_y = j * stripe_height
end_y = min((j + 1) * stripe_height, height)
frame[start_y:end_y, :] = color
offset = int((i / frames) * width)
frame = np.roll(frame, offset, axis=1)
frame = torch.tensor(frame, dtype=torch.float32)
video_data.append(frame)
return torch.stack(video_data, dim=0)
if __name__ == "__main__":
sample_rate = 16000
fps = 16
width = 452
height = 352
recorder = X264VARecorder(
whip_shared_path="/data/nvme0/liuliang1/lightx2v/test_deploy/test_whip_so/0.1.1/go_whxp.so",
livestream_url="https://reverse.st-oc-01.chielo.org/10.5.64.49:8000/rtc/v1/whip/?app=subscribe&stream=ll2&eip=10.120.114.82:8000",
fps=fps,
sample_rate=sample_rate,
)
recorder.start(width, height)
# time.sleep(5)
audio_path = "/data/nvme0/liuliang1/lightx2v/test_deploy/media_test/mangzhong.wav"
audio_array, ori_sr = ta.load(audio_path)
audio_array = ta.functional.resample(audio_array.mean(0), orig_freq=ori_sr, new_freq=16000)
audio_array = audio_array.numpy().reshape(-1)
secs = audio_array.shape[0] // sample_rate
interval = 1
space = 10
i = 0
while i < space:
t0 = time.time()
logger.info(f"space {i} / {space} s")
cur_audio_array = np.zeros(int(interval * sample_rate), dtype=np.float32)
num_frames = int(interval * fps)
images = create_simple_video(num_frames, height, width)
recorder.buffer_stream(images, torch.tensor(cur_audio_array, dtype=torch.float32), images)
i += interval
time.sleep(max(0, interval - (time.time() - t0)))
started = True
i = 0
while i < secs:
t0 = time.time()
start = int(i * sample_rate)
end = int((i + interval) * sample_rate)
cur_audio_array = torch.tensor(audio_array[start:end], dtype=torch.float32)
num_frames = int(interval * fps)
images = create_simple_video(num_frames, height, width)
logger.info(f"{i} / {secs} s")
if started:
logger.warning(f"start pub_livestream !!!!!!!!!!!!!!!!!!!!!!!")
started = False
recorder.buffer_stream(images, cur_audio_array, images)
i += interval
time.sleep(max(0, interval - (time.time() - t0)))
recorder.stop()