Bw-bestperf / Wan2.1-14B-480P-INT8-Lightx2v · Commits

Commit e2778d0d, authored Feb 05, 2026 by litzh

Initial commit

Pipeline #3370: canceled with stages · Changes: 532 · Pipelines: 1
Showing 12 changed files with 5684 additions and 0 deletions (+5684, -0)
lightx2v/deploy/common/audio_separator.py        +405  -0
lightx2v/deploy/common/face_detector.py          +728  -0
lightx2v/deploy/common/pipeline.py               +167  -0
lightx2v/deploy/common/podcasts.py               +696  -0
lightx2v/deploy/common/sensetime_voice_clone.py  +846  -0
lightx2v/deploy/common/utils.py                  +276  -0
lightx2v/deploy/common/va_controller.py          +253  -0
lightx2v/deploy/common/va_reader.py              +278  -0
lightx2v/deploy/common/va_reader_omni.py         +597  -0
lightx2v/deploy/common/va_recorder.py            +695  -0
lightx2v/deploy/common/va_recorder_x264.py       +321  -0
lightx2v/deploy/common/video_recorder.py         +422  -0
lightx2v/deploy/common/audio_separator.py (new file, mode 100644)
# -*- coding: utf-8 -*-
"""
Audio Source Separation Module
Separates different voice tracks in audio, supports multi-person audio separation
"""

import base64
import io
import os
import tempfile
import traceback
from collections import defaultdict
from typing import Dict, Optional, Union

import torch
import torch.serialization
import torchaudio
from loguru import logger

# Import pyannote.audio for speaker diarization
from pyannote.audio import Audio, Pipeline

# Fix for PyTorch 2.6 compatibility: allow pyannote.audio classes in torch.load.
# PyTorch 2.6 changed the torch.load default to weights_only=True for security.
try:
    # Add safe globals for pyannote.audio classes.
    # This allows torch.load to work with pyannote.audio model files.
    from pyannote.audio.core.task import Specifications

    torch.serialization.add_safe_globals([Specifications])
except (ImportError, AttributeError) as e:
    # If pyannote.audio is not installed or the class doesn't exist, log it;
    # the actual error will be handled when Pipeline.from_pretrained is called.
    logger.debug(f"Could not add pyannote.audio safe globals (may need to use weights_only=False): {e}")

_origin_torch_load = torch.load


def our_torch_load(checkpoint_file, *args, **kwargs):
    kwargs["weights_only"] = False
    return _origin_torch_load(checkpoint_file, *args, **kwargs)
class AudioSeparator:
    """
    Audio separator that splits different voice tracks in audio using pyannote.audio.
    Supports multi-person conversation separation and preserves the original duration
    (other speakers' tracks are left empty).
    """

    def __init__(
        self,
        model_path: str = None,
        device: str = None,
        sample_rate: int = 16000,
    ):
        """
        Initialize audio separator

        Args:
            model_path: Model path (if using a custom model); defaults to pyannote/speaker-diarization-community-1
            device: Device ('cpu', 'cuda', etc.), None for auto selection
            sample_rate: Target sample rate, default 16000
        """
        self.sample_rate = sample_rate
        self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu")
        self._init_pyannote(model_path)

    def _init_pyannote(self, model_path: str = None):
        """Initialize pyannote.audio pipeline"""
        try:
            huggingface_token = os.getenv("HUGGINGFACE_TOKEN") or os.getenv("HF_TOKEN")
            model_name = model_path or "pyannote/speaker-diarization-community-1"
            # Fix for PyTorch 2.6: use the safe_globals context manager to allow pyannote.audio classes.
            # PyTorch 2.6 changed the torch.load default to weights_only=True.
            try:
                from pyannote.audio.core.task import Specifications

                safe_globals_context = torch.serialization.safe_globals([Specifications])
            except (ImportError, AttributeError):
                # If the Specifications class is not available, use an empty context
                safe_globals_context = torch.serialization.safe_globals([])
            # Note: safe_globals_context is constructed but never entered; the
            # torch.load patch below is what actually takes effect.
            try:
                torch.load = our_torch_load
                # Try loading with token if available
                if huggingface_token:
                    self.pipeline = Pipeline.from_pretrained(model_name, token=huggingface_token)
                else:
                    # Try without token (may work for public models)
                    self.pipeline = Pipeline.from_pretrained(model_name)
            except Exception as e:
                if "gated" in str(e).lower() or "token" in str(e).lower():
                    raise RuntimeError(f"Model requires authentication. Set HUGGINGFACE_TOKEN or HF_TOKEN environment variable: {e}")
                # If safe_globals didn't work, try with weights_only=False as fallback
                if "weights_only" in str(e).lower() or "Unsupported global" in str(e):
                    logger.warning(f"PyTorch 2.6 compatibility issue detected, attempting fallback: {e}")
                    # Note: we can't directly control weights_only in Pipeline.from_pretrained,
                    # but the safe_globals should have worked. If not, the error is re-raised.
                raise RuntimeError(f"Failed to load pyannote model: {e}")
            finally:
                torch.load = _origin_torch_load
            # Move pipeline to specified device
            if self.device:
                self.pipeline.to(torch.device(self.device))
            # Initialize Audio helper for waveform loading
            self.pyannote_audio = Audio()
            logger.info("Initialized pyannote.audio speaker diarization pipeline")
        except Exception as e:
            logger.error(f"Failed to initialize pyannote: {e}")
            raise RuntimeError(f"Failed to initialize pyannote.audio pipeline: {e}")
    def separate_speakers(
        self,
        audio_path: Union[str, bytes],
        num_speakers: Optional[int] = None,
        min_speakers: int = 1,
        max_speakers: int = 5,
    ) -> Dict:
        """
        Separate different speakers in audio

        Args:
            audio_path: Audio file path or bytes data
            num_speakers: Specified number of speakers, None for auto detection
            min_speakers: Minimum number of speakers
            max_speakers: Maximum number of speakers

        Returns:
            Dict containing:
                - speakers: List of speaker audio segments, each containing:
                    - speaker_id: Speaker ID (0, 1, 2, ...)
                    - audio: torch.Tensor audio data [channels, samples]
                    - segments: List of (start_time, end_time) tuples
                    - sample_rate: Sample rate
        """
        try:
            # Load audio
            if isinstance(audio_path, bytes):
                # Try to infer the audio format from the byte data.
                # Check for WAV format (RIFF header)
                is_wav = audio_path[:4] == b"RIFF" and audio_path[8:12] == b"WAVE"
                # Check for MP3 format (ID3 tag or MPEG frame header)
                is_mp3 = audio_path[:3] == b"ID3" or audio_path[:2] == b"\xff\xfb" or audio_path[:2] == b"\xff\xf3"
                # Choose the file suffix based on the detected format
                if is_wav:
                    suffix = ".wav"
                elif is_mp3:
                    suffix = ".mp3"
                else:
                    # Default to WAV; loading will fail later if the format is wrong
                    suffix = ".wav"
                with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp_file:
                    tmp_file.write(audio_path)
                    tmp_audio_path = tmp_file.name
                try:
                    result = self._separate_speakers_internal(tmp_audio_path, num_speakers, min_speakers, max_speakers)
                finally:
                    # Make sure the temporary file is deleted
                    try:
                        os.unlink(tmp_audio_path)
                    except Exception as e:
                        logger.warning(f"Failed to delete temp file {tmp_audio_path}: {e}")
                return result
            else:
                return self._separate_speakers_internal(audio_path, num_speakers, min_speakers, max_speakers)
        except Exception as e:
            logger.error(f"Speaker separation failed: {traceback.format_exc()}")
            raise RuntimeError(f"Audio separation error: {e}")
    def _separate_speakers_internal(
        self,
        audio_path: str,
        num_speakers: Optional[int] = None,
        min_speakers: int = 1,
        max_speakers: int = 5,
    ) -> Dict:
        """Internal method: execute speaker separation"""
        # Load audio
        waveform, original_sr = torchaudio.load(audio_path)
        if original_sr != self.sample_rate:
            resampler = torchaudio.transforms.Resample(original_sr, self.sample_rate)
            waveform = resampler(waveform)
        # Convert to mono if stereo
        if waveform.shape[0] > 1:
            waveform = waveform.mean(dim=0, keepdim=True)
        # Ensure waveform is float32 and normalized (pyannote expects this format)
        if waveform.dtype != torch.float32:
            waveform = waveform.float()
        # Ensure waveform is in range [-1, 1] (normalize if needed)
        if waveform.abs().max() > 1.0:
            waveform = waveform / waveform.abs().max()
        if self.pipeline is None:
            raise RuntimeError("Pyannote pipeline not initialized")
        return self._separate_with_pyannote(audio_path, waveform, num_speakers, min_speakers, max_speakers)
    def _separate_with_pyannote(
        self,
        audio_path: str,
        waveform: torch.Tensor,
        num_speakers: Optional[int],
        min_speakers: int,
        max_speakers: int,
    ) -> Dict:
        """Use pyannote.audio for speaker diarization"""
        try:
            # Use a waveform dict to avoid AudioDecoder dependency issues.
            # The pipeline accepts either a file path or a waveform dict; the
            # waveform dict is more reliable when torchcodec is not properly installed.
            audio_input = {
                "waveform": waveform,
                "sample_rate": self.sample_rate,
            }
            # Run speaker diarization
            output = self.pipeline(
                audio_input,
                min_speakers=min_speakers if num_speakers is None else num_speakers,
                max_speakers=max_speakers if num_speakers is None else num_speakers,
            )
            # Extract audio segments for each speaker
            speakers_dict = defaultdict(list)
            for turn, speaker in output.speaker_diarization:
                print(f"Speaker: {speaker}, Start time: {turn.start}, End time: {turn.end}")
                start_time = turn.start
                end_time = turn.end
                start_sample = int(start_time * self.sample_rate)
                end_sample = int(end_time * self.sample_rate)
                # Extract the audio segment for this time period
                segment_audio = waveform[:, start_sample:end_sample]
                speakers_dict[speaker].append((start_time, end_time, segment_audio))
            # Generate complete audio for each speaker (other speakers' segments are empty)
            speakers = []
            audio_duration = waveform.shape[1] / self.sample_rate
            num_samples = waveform.shape[1]
            for speaker_id, segments in speakers_dict.items():
                # Create zero-filled audio
                speaker_audio = torch.zeros_like(waveform)
                # Fill in this speaker's segments
                for start_time, end_time, segment_audio in segments:
                    start_sample = int(start_time * self.sample_rate)
                    end_sample = int(end_time * self.sample_rate)
                    # Ensure no out-of-bounds access
                    end_sample = min(end_sample, num_samples)
                    segment_len = end_sample - start_sample
                    if segment_len > 0 and segment_audio.shape[1] > 0:
                        actual_len = min(segment_len, segment_audio.shape[1])
                        speaker_audio[:, start_sample : start_sample + actual_len] = segment_audio[:, :actual_len]
                speakers.append(
                    {
                        "speaker_id": speaker_id,
                        "audio": speaker_audio,
                        "segments": [(s[0], s[1]) for s in segments],
                        "sample_rate": self.sample_rate,
                    }
                )
            logger.info(f"Separated audio into {len(speakers)} speakers using pyannote")
            return {"speakers": speakers, "method": "pyannote"}
        except Exception as e:
            logger.error(f"Pyannote separation failed: {e}")
            raise RuntimeError(f"Audio separation failed: {e}")
    def save_speaker_audio(self, speaker_audio: torch.Tensor, output_path: str, sample_rate: int = None):
        """
        Save speaker audio to file

        Args:
            speaker_audio: Audio tensor [channels, samples]
            output_path: Output path
            sample_rate: Sample rate; if None, uses self.sample_rate
        """
        sr = sample_rate if sample_rate else self.sample_rate
        torchaudio.save(output_path, speaker_audio, sr)
        logger.info(f"Saved speaker audio to {output_path}")

    def speaker_audio_to_base64(self, speaker_audio: torch.Tensor, sample_rate: int = None, format: str = "wav") -> str:
        """
        Convert a speaker audio tensor to a base64-encoded string without saving to file

        Args:
            speaker_audio: Audio tensor [channels, samples]
            sample_rate: Sample rate; if None, uses self.sample_rate
            format: Audio format (default: "wav")

        Returns:
            Base64-encoded audio string
        """
        sr = sample_rate if sample_rate else self.sample_rate
        # Use BytesIO to save audio to memory
        buffer = io.BytesIO()
        torchaudio.save(buffer, speaker_audio, sr, format=format)
        # Get the audio bytes
        audio_bytes = buffer.getvalue()
        # Encode to base64
        audio_base64 = base64.b64encode(audio_bytes).decode("utf-8")
        logger.debug(f"Converted speaker audio to base64, size: {len(audio_bytes)} bytes")
        return audio_base64
    def separate_and_save(
        self,
        audio_path: Union[str, bytes],
        output_dir: str,
        num_speakers: Optional[int] = None,
        min_speakers: int = 1,
        max_speakers: int = 5,
    ) -> Dict:
        """
        Separate audio and save to files

        Args:
            audio_path: Input audio path or bytes data
            output_dir: Output directory
            num_speakers: Specified number of speakers
            min_speakers: Minimum number of speakers
            max_speakers: Maximum number of speakers

        Returns:
            Separation result dictionary, including output file paths
        """
        os.makedirs(output_dir, exist_ok=True)
        result = self.separate_speakers(audio_path, num_speakers, min_speakers, max_speakers)
        output_paths = []
        for speaker in result["speakers"]:
            speaker_id = speaker["speaker_id"]
            output_path = os.path.join(output_dir, f"{speaker_id}.wav")
            self.save_speaker_audio(speaker["audio"], output_path, speaker["sample_rate"])
            output_paths.append(output_path)
            speaker["output_path"] = output_path
        result["output_paths"] = output_paths
        return result
def separate_audio_tracks(
    audio_path: str,
    output_dir: str = None,
    num_speakers: int = None,
    model_path: str = None,
) -> Dict:
    """
    Convenience function: separate different audio tracks

    Args:
        audio_path: Audio file path
        output_dir: Output directory; if None, files are not saved
        num_speakers: Number of speakers
        model_path: Model path (optional)

    Returns:
        Separation result dictionary
    """
    separator = AudioSeparator(model_path=model_path)
    if output_dir:
        return separator.separate_and_save(audio_path, output_dir, num_speakers=num_speakers)
    else:
        return separator.separate_speakers(audio_path, num_speakers=num_speakers)
if __name__ == "__main__":
    # Test code
    import sys

    if len(sys.argv) < 2:
        print("Usage: python audio_separator.py <audio_path> [output_dir] [num_speakers]")
        sys.exit(1)
    audio_path = sys.argv[1]
    output_dir = sys.argv[2] if len(sys.argv) > 2 else "./separated_audio"
    num_speakers = int(sys.argv[3]) if len(sys.argv) > 3 else None
    separator = AudioSeparator()
    result = separator.separate_and_save(audio_path, output_dir, num_speakers=num_speakers)
    print(f"Separated audio into {len(result['speakers'])} speakers:")
    for speaker in result["speakers"]:
        print(f"  Speaker {speaker['speaker_id']}: {len(speaker['segments'])} segments")
        if "output_path" in speaker:
            print(f"    Saved to: {speaker['output_path']}")
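For reference, a minimal usage sketch of the module above (not part of the commit), assuming pyannote.audio is installed and HUGGINGFACE_TOKEN is set for the gated diarization model; the input path and output directory are placeholders:

    from lightx2v.deploy.common.audio_separator import separate_audio_tracks

    # Auto-detect the speaker count and write one full-length WAV per speaker
    # (other speakers' time ranges are silent in each track).
    result = separate_audio_tracks("meeting.wav", output_dir="./separated_audio")
    for speaker in result["speakers"]:
        print(speaker["speaker_id"], speaker["output_path"], len(speaker["segments"]))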
lightx2v/deploy/common/face_detector.py (new file, mode 100644)
import io
import os
import traceback
from typing import Dict, List, Union

import numpy as np
import torch
from PIL import Image, ImageDraw
from loguru import logger
from ultralytics import YOLO

# Try to import transformers for Grounding DINO
try:
    from transformers import AutoModelForZeroShotObjectDetection, AutoProcessor

    TRANSFORMERS_AVAILABLE = True
except ImportError:
    TRANSFORMERS_AVAILABLE = False
    logger.warning("transformers not available, Grounding DINO method will not work")
class FaceDetector:
    """
    Face detection using multiple methods

    Supports two detection methods:
    1. YOLO World (method='yolo'):
       - Open-vocabulary detection
       - Supports various face types: human, animal, anime, sketch
       - More flexible but slower
       - Can detect custom classes via text description
    2. Grounding DINO (method='grounding'):
       - Open-vocabulary object detection
       - Supports various face types via text prompts
       - Requires the transformers library
       - Good balance between accuracy and flexibility
    """

    def __init__(
        self,
        method: str = "yolo",
        model_path: str = None,
        conf_threshold: float = None,
        device: str = None,
        custom_classes: List[str] = None,
        cascade_path: str = None,
    ):
        """
        Initialize face detector

        Args:
            method: Detection method. Options:
                - "yolo": Use YOLO World (supports various face types)
                - "grounding": Use Grounding DINO (requires transformers)
                Default: "yolo"
            model_path: YOLO World model path (only used when method="yolo").
                If None, uses the default YOLO World model.
            conf_threshold: Confidence threshold (only used when method="yolo").
                If None, uses an adaptive threshold based on the classes.
            device: Device for YOLO ('cpu', 'cuda', '0', '1', etc.), None for auto
            custom_classes: List of custom class names for YOLO World. Default: ["face"]
                Examples: ["face"], ["animal face"], ["human face", "animal face"]
        """
        self.method = method.lower()
        self.device = device
        if self.method == "yolo":
            # Initialize YOLO World detector
            # Set custom classes (default to "face")
            if custom_classes is None:
                custom_classes = ["human face", "animal face", "anime face", "sketch face"]
            self.custom_classes = custom_classes
            # Adaptive confidence threshold based on class specificity
            if conf_threshold is None:
                if len(custom_classes) > 1:
                    # Multiple classes: use a lower threshold to catch all detections
                    conf_threshold = 0.1
                elif len(custom_classes) == 1:
                    class_name = custom_classes[0].lower()
                    if "face" in class_name and class_name.strip() == "face":
                        # Generic "face" class: needs a higher threshold, but not too high
                        conf_threshold = 0.15
                    else:
                        # Specific class like "animal face": a moderate threshold works
                        conf_threshold = 0.15
                else:
                    conf_threshold = 0.25
            self.conf_threshold = conf_threshold
            if model_path is None:
                # Use a YOLO World model for open-vocabulary detection
                logger.info("Loading YOLO World model for face detection")
                try:
                    # Try the small YOLO World model first (lighter and faster)
                    self.model = YOLO("yolov8s-world.pt")
                except Exception as e:
                    logger.warning(f"Failed to load yolov8s-world.pt, trying yolov8m-world.pt: {e}")
                    try:
                        self.model = YOLO("yolov8m-world.pt")
                    except Exception as e2:
                        logger.warning(f"Failed to load yolov8m-world.pt, trying yolov8l-world.pt: {e2}")
                        self.model = YOLO("yolov8l-world.pt")
                # Set custom classes for YOLO World.
                # YOLO World can detect any object described in natural language.
                self.model.set_classes(self.custom_classes)
            else:
                logger.info(f"Loading YOLO World model from {model_path}")
                self.model = YOLO(model_path)
            logger.info(f"Face detector initialized with YOLO World, custom classes: {self.custom_classes}, confidence threshold: {self.conf_threshold}")
            self.face_cascade = None
        elif self.method == "grounding":
            # Initialize Grounding DINO detector
            if not TRANSFORMERS_AVAILABLE:
                raise ImportError("transformers library is required for Grounding DINO. Install it with: pip install transformers torch")
            # Set up a proxy for the HuggingFace model download.
            # If no HTTP proxy is set, fall back to common HTTPS proxy settings.
            if not os.environ.get("HTTP_PROXY") and not os.environ.get("http_proxy"):
                # Use HTTPS_PROXY for HTTP requests as well, if available
                https_proxy = os.environ.get("HTTPS_PROXY") or os.environ.get("https_proxy")
                if https_proxy:
                    os.environ["HTTP_PROXY"] = https_proxy
                    os.environ["http_proxy"] = https_proxy
                    logger.info(f"Using proxy from HTTPS_PROXY: {https_proxy}")
            # Log proxy settings
            http_proxy = os.environ.get("HTTP_PROXY") or os.environ.get("http_proxy")
            https_proxy = os.environ.get("HTTPS_PROXY") or os.environ.get("https_proxy")
            if http_proxy or https_proxy:
                logger.info(f"Using proxy - HTTP: {http_proxy}, HTTPS: {https_proxy}")
            # Set custom classes (default to "face")
            if custom_classes is None:
                custom_classes = ["human face", "animal face", "anime face", "sketch face"]
            self.custom_classes = custom_classes
            # Adaptive confidence threshold
            if conf_threshold is None:
                if len(custom_classes) > 1:
                    conf_threshold = 0.1
                else:
                    conf_threshold = 0.3  # Grounding DINO typically needs a higher threshold
            self.conf_threshold = conf_threshold
            # Load Grounding DINO model
            model_id = "IDEA-Research/grounding-dino-base"  # or "grounding-dino-tiny" for faster inference
            if model_path is not None:
                model_id = model_path
            logger.info(f"Loading Grounding DINO model: {model_id}")
            try:
                # Grounding DINO requires trust_remote_code=True
                self.processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
                self.model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id, trust_remote_code=True)
                if device:
                    self.model = self.model.to(device)
                logger.info(f"Face detector initialized with Grounding DINO, custom classes: {self.custom_classes}, confidence threshold: {self.conf_threshold}")
            except Exception as e:
                error_msg = str(e)
                if "connection" in error_msg.lower() or "proxy" in error_msg.lower() or "network" in error_msg.lower():
                    logger.error("Failed to download model. Please check your network connection and proxy settings.")
                    logger.error(f"Current proxy settings - HTTP_PROXY: {http_proxy}, HTTPS_PROXY: {https_proxy}")
                    logger.error("You can set a proxy with: export http_proxy=... && export https_proxy=...")
                raise
            self.face_cascade = None
        else:
            raise ValueError(f"Unknown method: {method}. Must be 'yolo' or 'grounding'")
    def detect_faces(
        self,
        image: Union[str, Image.Image, bytes, np.ndarray],
        return_image: bool = False,
    ) -> Dict:
        """
        Detect faces in an image

        Args:
            image: Input image; can be a path, PIL Image, bytes, or numpy array
            return_image: Whether to return an annotated image with detection boxes

        Returns:
            Dict containing:
                - faces: List of face detection results, each containing:
                    - bbox: [x1, y1, x2, y2] bounding box coordinates (absolute pixel coordinates)
                    - confidence: Confidence score (0.0-1.0)
                    - class_id: Class ID
                    - class_name: Class name
                    - face_type: Type of face detected
                - image (optional): PIL Image with detection boxes drawn (if return_image=True)
        """
        try:
            if self.method == "grounding":
                return self._detect_faces_grounding(image, return_image)
            elif self.method == "yolo":
                return self._detect_faces_yolo(image, return_image)
        except Exception as e:
            logger.error(f"Face detection failed: {traceback.format_exc()}")
            raise RuntimeError(f"Face detection error: {e}")
    def _detect_faces_yolo(
        self,
        image: Union[str, Image.Image, bytes, np.ndarray],
        return_image: bool = False,
    ) -> Dict:
        """Detect faces using YOLO World"""
        # Load image
        if isinstance(image, str):
            img = Image.open(image).convert("RGB")
        elif isinstance(image, bytes):
            img = Image.open(io.BytesIO(image)).convert("RGB")
        elif isinstance(image, np.ndarray):
            img = Image.fromarray(image).convert("RGB")
        elif isinstance(image, Image.Image):
            img = image.convert("RGB")
        else:
            raise ValueError(f"Unsupported image type: {type(image)}")
        # Use YOLO World for open-vocabulary detection.
        # YOLO World detects objects based on the custom classes set via set_classes().
        results = self.model.predict(
            source=img,
            conf=self.conf_threshold,
            device=self.device,
            verbose=False,
        )
        faces = []
        annotated_img = img.copy() if return_image else None
        if len(results) > 0:
            result = results[0]
            boxes = result.boxes
            if boxes is not None and len(boxes) > 0:
                for i in range(len(boxes)):
                    # Get bounding box coordinates (xyxy format)
                    bbox = boxes.xyxy[i].cpu().numpy().tolist()
                    confidence = float(boxes.conf[i].cpu().numpy())
                    class_id = int(boxes.cls[i].cpu().numpy())
                    # Get the class name from the custom classes list.
                    # YOLO World returns a class_id that indexes into custom_classes.
                    if class_id < len(self.custom_classes):
                        class_name = self.custom_classes[class_id]
                    else:
                        class_name = result.names.get(class_id, "unknown")
                    # Determine face type based on the class name.
                    # The generic "face" class can match all types of faces.
                    if class_name.lower() == "face":
                        face_type = "face"  # Generic face type (can be human, animal, anime, etc.)
                    elif any(keyword in class_name.lower() for keyword in ["human", "person"]):
                        face_type = "human"
                    elif any(keyword in class_name.lower() for keyword in ["animal", "cat", "dog", "bird", "horse", "sheep", "cow", "elephant", "bear", "zebra", "giraffe"]):
                        face_type = "animal"
                    elif any(keyword in class_name.lower() for keyword in ["anime", "cartoon", "manga"]):
                        face_type = "anime"
                    elif any(keyword in class_name.lower() for keyword in ["sketch", "line", "drawing"]):
                        face_type = "sketch"
                    else:
                        logger.debug(f"Dropped unused detected result: {class_name}")
                        face_type = None
                    face_info = {
                        "bbox": bbox,  # [x1, y1, x2, y2] - absolute pixel coordinates
                        "confidence": confidence,
                        "class_id": class_id,
                        "class_name": class_name,
                        "face_type": face_type,
                    }
                    if face_type is not None:
                        faces.append(face_info)
                        # Draw annotations on the image if needed
                        if return_image and annotated_img is not None:
                            draw = ImageDraw.Draw(annotated_img)
                            x1, y1, x2, y2 = bbox
                            # Draw bounding box
                            draw.rectangle([x1, y1, x2, y2], outline="red", width=2)
                            # Draw label
                            label = f"{class_name} {confidence:.2f}"
                            draw.text((x1, y1 - 15), label, fill="red")
        result_dict = {"faces": faces}
        if return_image and annotated_img is not None:
            result_dict["image"] = annotated_img
        logger.info(f"Detected {len(faces)} faces using YOLO World")
        return result_dict
    def _calculate_iou(self, bbox1: List[float], bbox2: List[float]) -> float:
        """
        Calculate the Intersection over Union (IoU) of two bounding boxes

        Args:
            bbox1: [x1, y1, x2, y2] format
            bbox2: [x1, y1, x2, y2] format

        Returns:
            IoU value between 0 and 1
        """
        x1_1, y1_1, x2_1, y2_1 = bbox1
        x1_2, y1_2, x2_2, y2_2 = bbox2
        # Calculate intersection area
        inter_x1 = max(x1_1, x1_2)
        inter_y1 = max(y1_1, y1_2)
        inter_x2 = min(x2_1, x2_2)
        inter_y2 = min(y2_1, y2_2)
        if inter_x2 <= inter_x1 or inter_y2 <= inter_y1:
            return 0.0
        inter_area = (inter_x2 - inter_x1) * (inter_y2 - inter_y1)
        # Calculate union area
        area1 = (x2_1 - x1_1) * (y2_1 - y1_1)
        area2 = (x2_2 - x1_2) * (y2_2 - y1_2)
        union_area = area1 + area2 - inter_area
        if union_area == 0:
            return 0.0
        return inter_area / union_area

    def _calculate_bbox_area(self, bbox: List[float]) -> float:
        """Calculate the area of a bounding box"""
        x1, y1, x2, y2 = bbox
        return (x2 - x1) * (y2 - y1)

    def _calculate_containment(self, bbox_small: List[float], bbox_large: List[float]) -> float:
        """
        Calculate how much of bbox_small is contained in bbox_large.
        Returns the ratio of the intersection area to the area of bbox_small.
        """
        x1_s, y1_s, x2_s, y2_s = bbox_small
        x1_l, y1_l, x2_l, y2_l = bbox_large
        # Calculate intersection
        inter_x1 = max(x1_s, x1_l)
        inter_y1 = max(y1_s, y1_l)
        inter_x2 = min(x2_s, x2_l)
        inter_y2 = min(y2_s, y2_l)
        if inter_x2 <= inter_x1 or inter_y2 <= inter_y1:
            return 0.0
        inter_area = (inter_x2 - inter_x1) * (inter_y2 - inter_y1)
        small_area = (x2_s - x1_s) * (y2_s - y1_s)
        if small_area == 0:
            return 0.0
        return inter_area / small_area

    def _apply_nms(self, faces: List[Dict], iou_threshold: float = 0.4, containment_threshold: float = 0.6) -> List[Dict]:
        """
        Apply Non-Maximum Suppression (NMS) to remove duplicate detections.
        When detections overlap, keeps the one with the larger area
        (preferring whole objects over parts).

        Args:
            faces: List of face detection dictionaries
            iou_threshold: IoU threshold for considering detections duplicates
            containment_threshold: If a smaller box is contained in a larger box by this ratio, suppress it

        Returns:
            Filtered list of faces with duplicates removed
        """
        if len(faces) == 0:
            return faces
        # Sort by area (largest first), then by confidence as a tie-breaker.
        # This keeps the larger detection (whole object) over smaller ones (parts).
        for face in faces:
            face["_area"] = self._calculate_bbox_area(face["bbox"])
        sorted_faces = sorted(faces, key=lambda x: (x["_area"], x["confidence"]), reverse=True)
        keep = []
        suppressed = set()
        for i, face in enumerate(sorted_faces):
            if i in suppressed:
                continue
            keep.append(face)
            bbox_i = face["bbox"]
            area_i = face["_area"]
            # Suppress overlapping detections (prefer the larger area)
            for j in range(i + 1, len(sorted_faces)):
                if j in suppressed:
                    continue
                bbox_j = sorted_faces[j]["bbox"]
                area_j = sorted_faces[j]["_area"]
                # Check IoU overlap
                iou = self._calculate_iou(bbox_i, bbox_j)
                if iou > iou_threshold:
                    # If overlapping, suppress the smaller one
                    suppressed.add(j)
                    continue
                # Check whether the smaller box is mostly contained in the larger box
                # (e.g., a face contained in a whole animal body).
                # Since we sorted by area, area_i >= area_j for j > i.
                if area_j < area_i:
                    containment = self._calculate_containment(bbox_j, bbox_i)
                    if containment > containment_threshold:
                        suppressed.add(j)
        # Clean up the temporary area field
        for face in keep:
            face.pop("_area", None)
        logger.info(f"NMS filtered {len(faces)} detections to {len(keep)} (IoU threshold: {iou_threshold}, containment threshold: {containment_threshold}, prefer larger area)")
        return keep
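
    # Illustrative only (not part of the original commit): a minimal sketch of the
    # area-preferring NMS above. The geometry helpers use no instance state, so the
    # demo skips __init__ (and model loading) via __new__. The small box is fully
    # inside the larger one (containment 1.0 > 0.6) while their IoU (~0.14) stays
    # under the 0.4 threshold, so it is suppressed despite its higher confidence.
    @staticmethod
    def _example_nms_behavior() -> None:
        detector = FaceDetector.__new__(FaceDetector)  # hypothetical helper, demo only
        faces = [
            {"bbox": [10.0, 10.0, 200.0, 200.0], "confidence": 0.5, "class_name": "animal face"},
            {"bbox": [50.0, 50.0, 120.0, 120.0], "confidence": 0.9, "class_name": "animal face"},
        ]
        kept = detector._apply_nms(faces)
        assert len(kept) == 1 and kept[0]["bbox"] == [10.0, 10.0, 200.0, 200.0]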
    def _detect_faces_grounding(
        self,
        image: Union[str, Image.Image, bytes, np.ndarray],
        return_image: bool = False,
    ) -> Dict:
        """Detect faces using Grounding DINO"""
        # Load image
        if isinstance(image, str):
            img = Image.open(image).convert("RGB")
        elif isinstance(image, bytes):
            img = Image.open(io.BytesIO(image)).convert("RGB")
        elif isinstance(image, np.ndarray):
            img = Image.fromarray(image).convert("RGB")
        elif isinstance(image, Image.Image):
            img = image.convert("RGB")
        else:
            raise ValueError(f"Unsupported image type: {type(image)}")
        # Prepare the text prompt - join custom classes with a ". " separator
        text_prompt = ". ".join(self.custom_classes)
        if not text_prompt.endswith("."):
            text_prompt += "."
        # Process image and text
        inputs = self.processor(images=img, text=text_prompt, return_tensors="pt")
        if self.device:
            inputs = {k: v.to(self.device) if isinstance(v, torch.Tensor) else v for k, v in inputs.items()}
        # Run inference
        with torch.no_grad():
            outputs = self.model(**inputs)
        # Post-process results.
        # Note: Grounding DINO uses 'threshold' instead of 'box_threshold'.
        results = self.processor.post_process_grounded_object_detection(
            outputs,
            input_ids=inputs["input_ids"],
            threshold=self.conf_threshold,
            text_threshold=self.conf_threshold,
            target_sizes=[img.size[::-1]],  # [height, width]
        )
        faces = []
        annotated_img = img.copy() if return_image else None
        if len(results) > 0:
            result = results[0]
            # Get detections.
            # Use text_labels instead of labels to avoid a FutureWarning.
            boxes = result.get("boxes", [])
            text_labels = result.get("text_labels", [])
            # Fall back to labels if text_labels is not available
            if not text_labels:
                labels = result.get("labels", [])
                # Convert label IDs to class names if needed
                text_labels = [self.custom_classes[label] if isinstance(label, int) and label < len(self.custom_classes) else str(label) for label in labels]
            scores = result.get("scores", [])
            for i, (box, label, score) in enumerate(zip(boxes, text_labels, scores)):
                # Grounding DINO returns boxes as [x1, y1, x2, y2]
                if isinstance(box, torch.Tensor):
                    bbox = box.tolist()
                else:
                    bbox = list(box)
                # Ensure [x1, y1, x2, y2] format
                if len(bbox) == 4:
                    bbox = [float(bbox[0]), float(bbox[1]), float(bbox[2]), float(bbox[3])]
                else:
                    # If it's in center format, convert
                    x_center, y_center, width, height = bbox
                    x1 = x_center - width / 2
                    y1 = y_center - height / 2
                    x2 = x_center + width / 2
                    y2 = y_center + height / 2
                    bbox = [float(x1), float(y1), float(x2), float(y2)]
                # Get the class name from the label.
                # Grounding DINO may return multiple class names concatenated.
                class_name_raw = label if isinstance(label, str) else self.custom_classes[0]
                # If class_name contains multiple classes, try to extract the most specific one.
                # Priority: specific classes (animal, anime, sketch) > human > generic face.
                class_name = class_name_raw
                if isinstance(class_name_raw, str) and len(self.custom_classes) > 1:
                    class_name_lower = class_name_raw.lower()
                    # Check for specific classes first
                    if any(keyword in class_name_lower for keyword in ["animal"]):
                        for c in self.custom_classes:
                            if "animal" in c.lower():
                                class_name = c
                                break
                    elif any(keyword in class_name_lower for keyword in ["anime", "cartoon"]):
                        for c in self.custom_classes:
                            if any(k in c.lower() for k in ["anime", "cartoon"]):
                                class_name = c
                                break
                    elif any(keyword in class_name_lower for keyword in ["sketch", "line", "drawing"]):
                        for c in self.custom_classes:
                            if any(k in c.lower() for k in ["sketch", "line", "drawing"]):
                                class_name = c
                                break
                    elif any(keyword in class_name_lower for keyword in ["human", "person"]):
                        for c in self.custom_classes:
                            if any(k in c.lower() for k in ["human", "person"]):
                                class_name = c
                                break
                # Determine face type based on the class name
                if class_name.lower() == "face":
                    face_type = "face"
                elif any(keyword in class_name.lower() for keyword in ["human", "person"]):
                    face_type = "human"
                elif any(keyword in class_name.lower() for keyword in ["animal", "cat", "dog", "bird"]):
                    face_type = "animal"
                elif any(keyword in class_name.lower() for keyword in ["anime", "cartoon", "manga"]):
                    face_type = "anime"
                elif any(keyword in class_name.lower() for keyword in ["sketch", "line", "drawing"]):
                    face_type = "sketch"
                else:
                    face_type = class_name.lower()
                face_info = {
                    "bbox": bbox,
                    "confidence": float(score),
                    "class_id": i,
                    "class_name": class_name,
                    "face_type": face_type,
                }
                faces.append(face_info)
                # Draw annotations if needed
                if return_image and annotated_img is not None:
                    draw = ImageDraw.Draw(annotated_img)
                    x1, y1, x2, y2 = bbox
                    draw.rectangle([x1, y1, x2, y2], outline="red", width=2)
                    label = f"{class_name} {score:.2f}"
                    draw.text((x1, y1 - 15), label, fill="red")
        # Apply NMS to remove duplicate detections (only when multiple classes are specified)
        if len(self.custom_classes) > 1:
            faces = self._apply_nms(faces, iou_threshold=0.4, containment_threshold=0.6)
            # Re-draw annotations after NMS if needed
            if return_image and annotated_img is not None and len(faces) > 0:
                annotated_img = img.copy()
                draw = ImageDraw.Draw(annotated_img)
                for face in faces:
                    x1, y1, x2, y2 = face["bbox"]
                    draw.rectangle([x1, y1, x2, y2], outline="red", width=2)
                    label = f"{face['class_name']} {face['confidence']:.2f}"
                    draw.text((x1, y1 - 15), label, fill="red")
        result_dict = {"faces": faces}
        if return_image and annotated_img is not None:
            result_dict["image"] = annotated_img
        logger.info(f"Detected {len(faces)} faces using Grounding DINO (after NMS)")
        return result_dict
    def detect_faces_from_bytes(self, image_bytes: bytes, **kwargs) -> Dict:
        """
        Detect faces from byte data

        Args:
            image_bytes: Image byte data
            **kwargs: Additional parameters passed to detect_faces

        Returns:
            Detection result dictionary
        """
        return self.detect_faces(image_bytes, **kwargs)

    def extract_face_regions(self, image: Union[str, Image.Image, bytes], expand_ratio: float = 0.1) -> List[Image.Image]:
        """
        Extract detected face regions

        Args:
            image: Input image
            expand_ratio: Bounding box expansion ratio to include more context

        Returns:
            List of extracted face region images
        """
        result = self.detect_faces(image)
        faces = result["faces"]
        # Load original image
        if isinstance(image, str):
            img = Image.open(image).convert("RGB")
        elif isinstance(image, bytes):
            img = Image.open(io.BytesIO(image)).convert("RGB")
        elif isinstance(image, Image.Image):
            img = image.convert("RGB")
        else:
            raise ValueError(f"Unsupported image type: {type(image)}")
        face_regions = []
        img_width, img_height = img.size
        for face in faces:
            x1, y1, x2, y2 = face["bbox"]
            # Expand the bounding box
            width = x2 - x1
            height = y2 - y1
            expand_x = width * expand_ratio
            expand_y = height * expand_ratio
            x1 = max(0, int(x1 - expand_x))
            y1 = max(0, int(y1 - expand_y))
            x2 = min(img_width, int(x2 + expand_x))
            y2 = min(img_height, int(y2 + expand_y))
            # Crop the region
            face_region = img.crop((x1, y1, x2, y2))
            face_regions.append(face_region)
        return face_regions

    def count_faces(self, image: Union[str, Image.Image, bytes]) -> int:
        """
        Count the number of faces in an image

        Args:
            image: Input image

        Returns:
            Number of detected faces
        """
        result = self.detect_faces(image, return_image=False)
        return len(result["faces"])
def detect_faces_in_image(
    image_path: str,
    method: str = "grounding",
    model_path: str = None,
    conf_threshold: float = None,
    return_image: bool = False,
    custom_classes: List[str] = None,
) -> Dict:
    """
    Convenience function: detect faces in an image

    Args:
        image_path: Image path
        method: Detection method ("yolo" or "grounding"), default "grounding"
        model_path: YOLO World model path (only for method="yolo")
        conf_threshold: Confidence threshold (None for adaptive, only for method="yolo")
        return_image: Whether to return the annotated image
        custom_classes: List of custom class names for YOLO (default: ["face"])

    Returns:
        Detection result dictionary containing:
            - faces: List of face detection results with bbox coordinates [x1, y1, x2, y2].
              Each face contains: bbox, confidence, class_id, class_name, face_type
            - image (optional): Annotated image with detection boxes

    Examples:
        # Detect faces using YOLO World with the default classes
        result = detect_faces_in_image("image.jpg", method="yolo")

        # Detect with YOLO World and custom classes
        result = detect_faces_in_image("image.jpg", method="yolo",
                                       custom_classes=["human face", "animal face"])

        # Detect with Grounding DINO
        result = detect_faces_in_image("image.jpg", method="grounding",
                                       custom_classes=["animal face"])
    """
    detector = FaceDetector(
        method=method,
        model_path=model_path,
        conf_threshold=conf_threshold,
        custom_classes=custom_classes,
    )
    return detector.detect_faces(image_path, return_image=return_image)
if __name__ == "__main__":
    # Test code
    import sys

    if len(sys.argv) < 2:
        print("Usage: python face_detector.py <image_path>")
        sys.exit(1)
    image_path = sys.argv[1]
    detector = FaceDetector()
    result = detector.detect_faces(image_path, return_image=True)
    print(f"Detected {len(result['faces'])} faces:")
    for i, face in enumerate(result["faces"]):
        print(f"  Face {i + 1}: {face}")
    output_path = "detected_faces.png"
    result["image"].save(output_path)
    print(f"Annotated image saved to: {output_path}")
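For reference, a minimal usage sketch of the detector above (not part of the commit), assuming ultralytics can fetch a YOLO World checkpoint; "photo.jpg" is a placeholder:

    from lightx2v.deploy.common.face_detector import FaceDetector

    # Open-vocabulary detection: the class list drives both the text prompt and
    # the adaptive confidence threshold (0.1 here, since more than one class is given).
    detector = FaceDetector(method="yolo", custom_classes=["human face", "animal face"])
    result = detector.detect_faces("photo.jpg", return_image=True)
    for face in result["faces"]:
        print(face["face_type"], face["confidence"], face["bbox"])
    result["image"].save("annotated.png")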
lightx2v/deploy/common/pipeline.py (new file, mode 100644)
import json
import sys

from loguru import logger


class Pipeline:
    def __init__(self, pipeline_json_file):
        self.pipeline_json_file = pipeline_json_file
        with open(pipeline_json_file) as f:
            x = json.load(f)
        self.data = x["data"]
        self.meta = x["meta"]
        self.inputs = {}
        self.outputs = {}
        self.temps = {}
        self.model_lists = []
        self.types = {}
        self.queues = set()
        self.model_name_inner_to_outer = self.meta.get("model_name_inner_to_outer", {})
        self.model_name_outer_to_inner = self.meta.get("model_name_outer_to_inner", {})
        self.tidy_pipeline()

    def init_dict(self, base, task, model_cls):
        if task not in base:
            base[task] = {}
        if model_cls not in base[task]:
            base[task][model_cls] = {}

    # tidy each task item, e.g. ['t2v', 'wan2.1', 'multi_stage']
    def tidy_task(self, task, model_cls, stage, v3):
        out2worker = {}
        out2num = {}
        cur_inps = set()
        cur_temps = set()
        cur_types = {}
        for worker_name, worker_item in v3.items():
            prevs = []
            for inp in worker_item["inputs"]:
                cur_types[inp] = self.get_type(inp)
                if inp in out2worker:
                    prevs.append(out2worker[inp])
                    out2num[inp] -= 1
                    if out2num[inp] <= 0:
                        cur_temps.add(inp)
                else:
                    cur_inps.add(inp)
            worker_item["previous"] = prevs
            for out in worker_item["outputs"]:
                cur_types[out] = self.get_type(out)
                out2worker[out] = worker_name
                if out not in out2num:
                    out2num[out] = 0
                out2num[out] += 1
            if "queue" not in worker_item:
                worker_item["queue"] = "-".join([task, model_cls, stage, worker_name])
            self.queues.add(worker_item["queue"])
        cur_outs = [out for out, num in out2num.items() if num > 0]
        self.inputs[task][model_cls][stage] = list(cur_inps)
        self.outputs[task][model_cls][stage] = cur_outs
        self.temps[task][model_cls][stage] = list(cur_temps)
        self.types[task][model_cls][stage] = cur_types

    # tidy previous dependent workers and queue names
    def tidy_pipeline(self):
        for task, v1 in self.data.items():
            for model_cls, v2 in v1.items():
                for stage, v3 in v2.items():
                    self.init_dict(self.inputs, task, model_cls)
                    self.init_dict(self.outputs, task, model_cls)
                    self.init_dict(self.temps, task, model_cls)
                    self.init_dict(self.types, task, model_cls)
                    self.tidy_task(task, model_cls, stage, v3)
                    self.model_lists.append({"task": task, "model_cls": model_cls, "stage": stage})
        logger.info(f"pipelines: {json.dumps(self.data, indent=4)}")
        logger.info(f"inputs: {self.inputs}")
        logger.info(f"outputs: {self.outputs}")
        logger.info(f"temps: {self.temps}")
        logger.info(f"types: {self.types}")
        logger.info(f"model_lists: {self.model_lists}")
        logger.info(f"queues: {self.queues}")

    def get_item_by_keys(self, keys):
        item = self.data
        for k in keys:
            if k not in item:
                raise Exception(f"{keys} are not in {self.pipeline_json_file}!")
            item = item[k]
        return item

    # e.g. keys: ['t2v', 'wan2.1', 'multi_stage', 'text_encoder']
    def get_worker(self, keys):
        return self.get_item_by_keys(keys)

    # e.g. keys: ['t2v', 'wan2.1', 'multi_stage']
    def get_workers(self, keys):
        return self.get_item_by_keys(keys)

    # e.g. keys: ['t2v', 'wan2.1', 'multi_stage']
    def get_inputs(self, keys):
        item = self.inputs
        for k in keys:
            if k not in item:
                raise Exception(f"{keys} are not in inputs!")
            item = item[k]
        return item

    # e.g. keys: ['t2v', 'wan2.1', 'multi_stage']
    def get_outputs(self, keys):
        item = self.outputs
        for k in keys:
            if k not in item:
                raise Exception(f"{keys} are not in outputs!")
            item = item[k]
        return item

    # e.g. keys: ['t2v', 'wan2.1', 'multi_stage']
    def get_temps(self, keys):
        item = self.temps
        for k in keys:
            if k not in item:
                raise Exception(f"{keys} are not in temps!")
            item = item[k]
        return item

    # e.g. keys: ['t2v', 'wan2.1', 'multi_stage']
    def get_types(self, keys):
        item = self.types
        for k in keys:
            if k not in item:
                raise Exception(f"{keys} are not in types!")
            item = item[k]
        return item

    def check_item_by_keys(self, keys):
        item = self.data
        for k in keys:
            if k not in item:
                return False
            item = item[k]
        return True

    def get_model_lists(self):
        return self.model_lists

    def get_type(self, name):
        return self.meta["special_types"].get(name, "OBJECT")

    def get_monitor_config(self):
        return self.meta["monitor"]

    def get_queues(self):
        return self.queues

    def inner_model_name(self, name):
        return self.model_name_outer_to_inner.get(name, name)

    def outer_model_name(self, name):
        return self.model_name_inner_to_outer.get(name, name)


if __name__ == "__main__":
    pipeline = Pipeline(sys.argv[1])
    print(pipeline.get_workers(["t2v", "wan2.1", "multi_stage"]))
    print(pipeline.get_worker(["i2v", "wan2.1", "multi_stage", "dit"]))
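For reference, a minimal sketch of the JSON shape this class expects (not part of the commit), inferred from how tidy_pipeline walks the structure; the worker and field names below are illustrative placeholders:

    import json
    import tempfile

    example = {
        "meta": {"special_types": {"prompt": "TEXT"}, "monitor": {}},
        "data": {
            "t2v": {
                "wan2.1": {
                    "multi_stage": {
                        "text_encoder": {"inputs": ["prompt"], "outputs": ["text_emb"]},
                        "dit": {"inputs": ["text_emb"], "outputs": ["latents"]},
                        "vae_decoder": {"inputs": ["latents"], "outputs": ["video"]},
                    }
                }
            }
        },
    }
    with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as f:
        json.dump(example, f)
    p = Pipeline(f.name)
    # Derived bookkeeping: external inputs, final outputs, and one queue per worker;
    # "text_emb" and "latents" are consumed within the stage and become temps.
    assert p.get_inputs(["t2v", "wan2.1", "multi_stage"]) == ["prompt"]
    assert p.get_outputs(["t2v", "wan2.1", "multi_stage"]) == ["video"]
    assert "t2v-wan2.1-multi_stage-dit" in p.get_queues()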
lightx2v/deploy/common/podcasts.py (new file, mode 100644)
# -*- coding: utf-8 -*-
import asyncio
import io
import json
import os
import struct
import uuid
from dataclasses import dataclass
from enum import IntEnum
from typing import Callable, List, Optional

import websockets
from loguru import logger
from pydub import AudioSegment

# Protocol definitions (from podcasts_protocols)
class MsgType(IntEnum):
    """Message type enumeration"""

    Invalid = 0
    FullClientRequest = 0b1
    AudioOnlyClient = 0b10
    FullServerResponse = 0b1001
    AudioOnlyServer = 0b1011
    FrontEndResultServer = 0b1100
    Error = 0b1111

    ServerACK = AudioOnlyServer


class MsgTypeFlagBits(IntEnum):
    """Message type flag bits"""

    NoSeq = 0
    PositiveSeq = 0b1
    LastNoSeq = 0b10
    NegativeSeq = 0b11
    WithEvent = 0b100


class VersionBits(IntEnum):
    """Version bits"""

    Version1 = 1


class HeaderSizeBits(IntEnum):
    """Header size bits"""

    HeaderSize4 = 1
    HeaderSize8 = 2
    HeaderSize12 = 3
    HeaderSize16 = 4


class SerializationBits(IntEnum):
    """Serialization method bits"""

    Raw = 0
    JSON = 0b1
    Thrift = 0b11
    Custom = 0b1111


class CompressionBits(IntEnum):
    """Compression method bits"""

    None_ = 0
    Gzip = 0b1
    Custom = 0b1111


class EventType(IntEnum):
    """Event type enumeration"""

    None_ = 0
    StartConnection = 1
    StartTask = 1
    FinishConnection = 2
    FinishTask = 2
    ConnectionStarted = 50
    TaskStarted = 50
    ConnectionFailed = 51
    TaskFailed = 51
    ConnectionFinished = 52
    TaskFinished = 52
    StartSession = 100
    CancelSession = 101
    FinishSession = 102
    SessionStarted = 150
    SessionCanceled = 151
    SessionFinished = 152
    SessionFailed = 153
    UsageResponse = 154
    ChargeData = 154
    TaskRequest = 200
    UpdateConfig = 201
    AudioMuted = 250
    SayHello = 300
    TTSSentenceStart = 350
    TTSSentenceEnd = 351
    TTSResponse = 352
    TTSEnded = 359
    PodcastRoundStart = 360
    PodcastRoundResponse = 361
    PodcastRoundEnd = 362
    PodcastEnd = 363
@dataclass
class Message:
    """Message object"""

    version: VersionBits = VersionBits.Version1
    header_size: HeaderSizeBits = HeaderSizeBits.HeaderSize4
    type: MsgType = MsgType.Invalid
    flag: MsgTypeFlagBits = MsgTypeFlagBits.NoSeq
    serialization: SerializationBits = SerializationBits.JSON
    compression: CompressionBits = CompressionBits.None_
    event: EventType = EventType.None_
    session_id: str = ""
    connect_id: str = ""
    sequence: int = 0
    error_code: int = 0
    payload: bytes = b""

    @classmethod
    def from_bytes(cls, data: bytes) -> "Message":
        """Create a message object from bytes"""
        if len(data) < 3:
            raise ValueError(f"Data too short: expected at least 3 bytes, got {len(data)}")
        type_and_flag = data[1]
        msg_type = MsgType(type_and_flag >> 4)
        flag = MsgTypeFlagBits(type_and_flag & 0b00001111)
        msg = cls(type=msg_type, flag=flag)
        msg.unmarshal(data)
        return msg

    def marshal(self) -> bytes:
        """Serialize the message to bytes"""
        buffer = io.BytesIO()
        header = [
            (self.version << 4) | self.header_size,
            (self.type << 4) | self.flag,
            (self.serialization << 4) | self.compression,
        ]
        header_size = 4 * self.header_size
        if padding := header_size - len(header):
            header.extend([0] * padding)
        buffer.write(bytes(header))
        writers = self._get_writers()
        for writer in writers:
            writer(buffer)
        return buffer.getvalue()

    def unmarshal(self, data: bytes) -> None:
        """Deserialize the message from bytes"""
        buffer = io.BytesIO(data)
        version_and_header_size = buffer.read(1)[0]
        self.version = VersionBits(version_and_header_size >> 4)
        self.header_size = HeaderSizeBits(version_and_header_size & 0b00001111)
        buffer.read(1)  # type/flag byte was already parsed in from_bytes
        serialization_compression = buffer.read(1)[0]
        self.serialization = SerializationBits(serialization_compression >> 4)
        self.compression = CompressionBits(serialization_compression & 0b00001111)
        header_size = 4 * self.header_size
        read_size = 3
        if padding_size := header_size - read_size:
            buffer.read(padding_size)
        readers = self._get_readers()
        for reader in readers:
            reader(buffer)
        remaining = buffer.read()
        if remaining:
            raise ValueError(f"Unexpected data after message: {remaining}")

    def _get_writers(self) -> List[Callable[[io.BytesIO], None]]:
        """Get the list of writer functions"""
        writers = []
        if self.flag == MsgTypeFlagBits.WithEvent:
            writers.extend([self._write_event, self._write_session_id])
        if self.type in [MsgType.FullClientRequest, MsgType.FullServerResponse, MsgType.FrontEndResultServer, MsgType.AudioOnlyClient, MsgType.AudioOnlyServer]:
            if self.flag in [MsgTypeFlagBits.PositiveSeq, MsgTypeFlagBits.NegativeSeq]:
                writers.append(self._write_sequence)
        elif self.type == MsgType.Error:
            writers.append(self._write_error_code)
        else:
            raise ValueError(f"Unsupported message type: {self.type}")
        writers.append(self._write_payload)
        return writers

    def _get_readers(self) -> List[Callable[[io.BytesIO], None]]:
        """Get the list of reader functions"""
        readers = []
        if self.type in [MsgType.FullClientRequest, MsgType.FullServerResponse, MsgType.FrontEndResultServer, MsgType.AudioOnlyClient, MsgType.AudioOnlyServer]:
            if self.flag in [MsgTypeFlagBits.PositiveSeq, MsgTypeFlagBits.NegativeSeq]:
                readers.append(self._read_sequence)
        elif self.type == MsgType.Error:
            readers.append(self._read_error_code)
        if self.flag == MsgTypeFlagBits.WithEvent:
            readers.extend([self._read_event, self._read_session_id, self._read_connect_id])
        readers.append(self._read_payload)
        return readers

    def _write_event(self, buffer: io.BytesIO) -> None:
        buffer.write(struct.pack(">i", self.event))

    def _write_session_id(self, buffer: io.BytesIO) -> None:
        if self.event in [EventType.StartConnection, EventType.FinishConnection, EventType.ConnectionStarted, EventType.ConnectionFailed]:
            return
        session_id_bytes = self.session_id.encode("utf-8")
        size = len(session_id_bytes)
        if size > 0xFFFFFFFF:
            raise ValueError(f"Session ID size ({size}) exceeds max(uint32)")
        buffer.write(struct.pack(">I", size))
        if size > 0:
            buffer.write(session_id_bytes)

    def _write_sequence(self, buffer: io.BytesIO) -> None:
        buffer.write(struct.pack(">i", self.sequence))

    def _write_error_code(self, buffer: io.BytesIO) -> None:
        buffer.write(struct.pack(">I", self.error_code))

    def _write_payload(self, buffer: io.BytesIO) -> None:
        size = len(self.payload)
        if size > 0xFFFFFFFF:
            raise ValueError(f"Payload size ({size}) exceeds max(uint32)")
        buffer.write(struct.pack(">I", size))
        buffer.write(self.payload)

    def _read_event(self, buffer: io.BytesIO) -> None:
        event_bytes = buffer.read(4)
        if event_bytes:
            self.event = EventType(struct.unpack(">i", event_bytes)[0])

    def _read_session_id(self, buffer: io.BytesIO) -> None:
        if self.event in [EventType.StartConnection, EventType.FinishConnection, EventType.ConnectionStarted, EventType.ConnectionFailed, EventType.ConnectionFinished]:
            return
        size_bytes = buffer.read(4)
        if size_bytes:
            size = struct.unpack(">I", size_bytes)[0]
            if size > 0:
                session_id_bytes = buffer.read(size)
                if len(session_id_bytes) == size:
                    self.session_id = session_id_bytes.decode("utf-8")

    def _read_connect_id(self, buffer: io.BytesIO) -> None:
        if self.event in [EventType.ConnectionStarted, EventType.ConnectionFailed, EventType.ConnectionFinished]:
            size_bytes = buffer.read(4)
            if size_bytes:
                size = struct.unpack(">I", size_bytes)[0]
                if size > 0:
                    self.connect_id = buffer.read(size).decode("utf-8")

    def _read_sequence(self, buffer: io.BytesIO) -> None:
        sequence_bytes = buffer.read(4)
        if sequence_bytes:
            self.sequence = struct.unpack(">i", sequence_bytes)[0]

    def _read_error_code(self, buffer: io.BytesIO) -> None:
        error_code_bytes = buffer.read(4)
        if error_code_bytes:
            self.error_code = struct.unpack(">I", error_code_bytes)[0]

    def _read_payload(self, buffer: io.BytesIO) -> None:
        size_bytes = buffer.read(4)
        if size_bytes:
            size = struct.unpack(">I", size_bytes)[0]
            if size > 0:
                self.payload = buffer.read(size)
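

# Illustrative only (not part of the original commit): a minimal round-trip sketch
# of the binary framing above. A WithEvent request serializes as a 4-byte header
# (version/header-size, type/flag, serialization/compression, padding), then the
# event (int32), a length-prefixed session id, and a length-prefixed payload;
# from_bytes() reverses the same layout.
def _example_message_roundtrip() -> None:
    msg = Message(
        type=MsgType.FullClientRequest,
        flag=MsgTypeFlagBits.WithEvent,
        event=EventType.StartSession,
        session_id="demo-session",
        payload=b"{}",
    )
    parsed = Message.from_bytes(msg.marshal())
    assert parsed.event == EventType.StartSession
    assert parsed.session_id == "demo-session"
    assert parsed.payload == b"{}"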
async def receive_message(websocket: websockets.WebSocketClientProtocol) -> Message:
    """Receive a message from the websocket"""
    try:
        data = await websocket.recv()
        if isinstance(data, str):
            raise ValueError(f"Unexpected text message: {data}")
        elif isinstance(data, bytes):
            msg = Message.from_bytes(data)
            # logger.debug(f"Received: {msg}")
            return msg
        else:
            raise ValueError(f"Unexpected message type: {type(data)}")
    except Exception as e:
        logger.error(f"Failed to receive message: {e}")
        raise


async def wait_for_event(websocket: websockets.WebSocketClientProtocol, msg_type: MsgType, event_type: EventType) -> Message:
    """Wait for a specific event"""
    while True:
        msg = await receive_message(websocket)
        if msg.type != msg_type or msg.event != event_type:
            raise ValueError(f"Unexpected message: {msg}")
        if msg.type == msg_type and msg.event == event_type:
            return msg


async def start_connection(websocket: websockets.WebSocketClientProtocol) -> None:
    """Start connection"""
    msg = Message(type=MsgType.FullClientRequest, flag=MsgTypeFlagBits.WithEvent)
    msg.event = EventType.StartConnection
    msg.payload = b"{}"
    logger.debug(f"Sending: {msg}")
    await websocket.send(msg.marshal())


async def finish_connection(websocket: websockets.WebSocketClientProtocol) -> None:
    """Finish connection"""
    msg = Message(type=MsgType.FullClientRequest, flag=MsgTypeFlagBits.WithEvent)
    msg.event = EventType.FinishConnection
    msg.payload = b"{}"
    logger.debug(f"Sending: {msg}")
    await websocket.send(msg.marshal())


async def start_session(websocket: websockets.WebSocketClientProtocol, payload: bytes, session_id: str) -> None:
    """Start session"""
    msg = Message(type=MsgType.FullClientRequest, flag=MsgTypeFlagBits.WithEvent)
    msg.event = EventType.StartSession
    msg.session_id = session_id
    msg.payload = payload
    logger.debug(f"Sending: {msg}")
    await websocket.send(msg.marshal())


async def finish_session(websocket: websockets.WebSocketClientProtocol, session_id: str) -> None:
    """Finish session"""
    msg = Message(type=MsgType.FullClientRequest, flag=MsgTypeFlagBits.WithEvent)
    msg.event = EventType.FinishSession
    msg.session_id = session_id
    msg.payload = b"{}"
    logger.debug(f"Sending: {msg}")
    await websocket.send(msg.marshal())

class PodcastRoundPostProcessor:
    def __init__(self, session_id, data_manager):
        self.session_id = session_id
        self.data_manager = data_manager
        self.temp_merged_audio_name = "merged_audio.mp3"
        self.output_merged_audio_name = f"{session_id}-merged_audio.mp3"
        self.subtitle_timestamps = []  # recorded subtitle timestamps
        self.current_audio_duration = 0.0  # current merged audio duration in seconds
        self.merged_audio = None  # AudioSegment holding the merged audio
        self.merged_audio_bytes = None

    async def init(self):
        if self.data_manager:
            await self.data_manager.create_podcast_temp_session_dir(self.session_id)

    async def postprocess_round(self, current_round, voice, audio, podcast_texts):
        text = ""
        if podcast_texts:
            text = podcast_texts[-1].get("text", "")
        logger.debug(f"Processing round: {current_round}, voice: {voice}, text: {text}, audio: {len(audio)} bytes")
        new_segment = AudioSegment.from_mp3(io.BytesIO(bytes(audio)))
        round_duration = len(new_segment) / 1000.0
        if self.merged_audio is None:
            self.merged_audio = new_segment
        else:
            self.merged_audio = self.merged_audio + new_segment
        # Save the merged audio to a temp file (for real-time access by the frontend)
        merged_io = io.BytesIO()
        self.merged_audio.export(merged_io, format="mp3")
        self.merged_audio_bytes = merged_io.getvalue()
        if self.data_manager:
            await self.data_manager.save_podcast_temp_session_file(self.session_id, self.temp_merged_audio_name, self.merged_audio_bytes)
        merged_file_size = len(self.merged_audio_bytes)
        # Record the subtitle timestamps
        self.subtitle_timestamps.append(
            {
                "start": self.current_audio_duration,
                "end": self.current_audio_duration + round_duration,
                "text": text,
                "speaker": voice,
            }
        )
        self.current_audio_duration += round_duration
        logger.debug(f"Merged audio updated: {merged_file_size} bytes, duration: {self.current_audio_duration:.2f}s")
        return {
            "url": f"/api/v1/podcast/audio?session_id={self.session_id}&filename={self.temp_merged_audio_name}",
            "size": merged_file_size,
            "duration": self.current_audio_duration,
            "round": current_round,
            "text": text,
            "speaker": voice,
        }

    async def postprocess_final(self):
        if self.data_manager:
            await self.data_manager.save_podcast_output_file(self.output_merged_audio_name, self.merged_audio_bytes)
        return {
            "subtitles": self.subtitle_timestamps,
            "audio_name": self.output_merged_audio_name,
        }

    async def cleanup(self):
        if self.data_manager:
            await self.data_manager.clear_podcast_temp_session_dir(self.session_id)
        self.data_manager = None
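
# A minimal sketch of how PodcastRoundPostProcessor is driven (here with
# data_manager=None so nothing is persisted): each round's MP3 bytes are
# appended to the running merge and a subtitle entry is recorded.
async def _example_round_postprocessing(mp3_rounds):
    processor = PodcastRoundPostProcessor(session_id="demo", data_manager=None)
    await processor.init()
    for i, mp3_bytes in enumerate(mp3_rounds):
        info = await processor.postprocess_round(i, "speaker_a", bytearray(mp3_bytes), [{"text": f"round {i}"}])
        logger.debug(f"round {info['round']}: total {info['duration']:.2f}s")
    final = await processor.postprocess_final()
    await processor.cleanup()
    return final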

class VolcEnginePodcastClient:
    """
    VolcEngine podcast client.

    Supported podcast types:
    - action=0: text to podcast
    - action=3: NLP texts to podcast
    - action=4: prompt-generated podcast
    """

    def __init__(self):
        self.endpoint = "wss://openspeech.bytedance.com/api/v3/sami/podcasttts"
        self.appid = os.getenv("VOLCENGINE_PODCAST_APPID")
        self.access_token = os.getenv("VOLCENGINE_PODCAST_ACCESS_TOKEN")
        self.app_key = "aGjiRDfUWi"
        self.proxy = os.getenv("HTTPS_PROXY", None)
        if self.proxy:
            logger.info(f"volcengine podcast use proxy: {self.proxy}")

    async def podcast_request(
        self,
        session_id: str,
        data_manager=None,
        text: str = "",
        input_url: str = "",
        prompt_text: str = "",
        nlp_texts: str = "",
        action: int = 0,
        resource_id: str = "volc.service_type.10050",
        encoding: str = "mp3",
        input_id: str = "test_podcast",
        speaker_info: str = '{"random_order":false}',
        use_head_music: bool = False,
        use_tail_music: bool = False,
        only_nlp_text: bool = False,
        return_audio_url: bool = False,
        skip_round_audio_save: bool = False,
        on_round_complete: Optional[Callable] = None,
    ):
        """
        Execute a podcast request.

        Args:
            text: Input text (used when action in [0])
            input_url: Web URL or file URL (used when action in [0])
            prompt_text: Prompt text (required when action in [4])
            nlp_texts: NLP texts (required when action in [3])
            action: Podcast type (0/3/4)
            resource_id: Audio resource ID
            encoding: Audio format (mp3/wav)
            input_id: Unique input identifier
            speaker_info: Podcast speaker info
            use_head_music: Whether to use opening music
            use_tail_music: Whether to use closing music
            only_nlp_text: Whether to return only the podcast text
            return_audio_url: Whether to return an audio URL
            skip_round_audio_save: Whether to skip saving per-round audio
            on_round_complete: Callback invoked when a round completes
        """
        if not self.appid or not self.access_token:
            logger.error("APP ID or Access Key is required")
            return None, None
        headers = {
            "X-Api-App-Id": self.appid,
            "X-Api-App-Key": self.app_key,
            "X-Api-Access-Key": self.access_token,
            "X-Api-Resource-Id": resource_id,
            "X-Api-Connect-Id": str(uuid.uuid4()),
        }
        is_podcast_round_end = True
        audio_received = False
        last_round_id = -1
        task_id = ""
        websocket = None
        retry_num = 5
        audio = bytearray()
        voice = ""
        current_round = 0
        podcast_texts = []
        post_processor = PodcastRoundPostProcessor(session_id, data_manager)
        await post_processor.init()
        try:
            while retry_num > 0:
                # Establish the WebSocket connection
                websocket = await websockets.connect(self.endpoint, additional_headers=headers)
                logger.debug(f"WebSocket connected: {websocket.response.headers}")
                # Build request parameters
                if input_url:
                    req_params = {
                        "input_id": input_id,
                        "nlp_texts": json.loads(nlp_texts) if nlp_texts else None,
                        "prompt_text": prompt_text,
                        "action": action,
                        "use_head_music": use_head_music,
                        "use_tail_music": use_tail_music,
                        "input_info": {
                            "input_url": input_url,
                            "return_audio_url": return_audio_url,
                            "only_nlp_text": only_nlp_text,
                        },
                        "speaker_info": json.loads(speaker_info) if speaker_info else None,
                        "audio_config": {"format": encoding, "sample_rate": 24000, "speech_rate": 0},
                    }
                else:
                    req_params = {
                        "input_id": input_id,
                        "input_text": text,
                        "nlp_texts": json.loads(nlp_texts) if nlp_texts else None,
                        "prompt_text": prompt_text,
                        "action": action,
                        "use_head_music": use_head_music,
                        "use_tail_music": use_tail_music,
                        "input_info": {
                            "input_url": input_url,
                            "return_audio_url": return_audio_url,
                            "only_nlp_text": only_nlp_text,
                        },
                        "speaker_info": json.loads(speaker_info) if speaker_info else None,
                        "audio_config": {"format": encoding, "sample_rate": 24000, "speech_rate": 0},
                    }
                logger.debug(f"Request params: {json.dumps(req_params, indent=2, ensure_ascii=False)}")
                if not is_podcast_round_end:
                    req_params["retry_info"] = {"retry_task_id": task_id, "last_finished_round_id": last_round_id}
                # Start connection
                await start_connection(websocket)
                await wait_for_event(websocket, MsgType.FullServerResponse, EventType.ConnectionStarted)
                session_id = str(uuid.uuid4())
                if not task_id:
                    task_id = session_id
                # Start session
                await start_session(websocket, json.dumps(req_params).encode(), session_id)
                await wait_for_event(websocket, MsgType.FullServerResponse, EventType.SessionStarted)
                # Finish session
                await finish_session(websocket, session_id)
                while True:
                    msg = await receive_message(websocket)
                    # Audio data chunk
                    if msg.type == MsgType.AudioOnlyServer and msg.event == EventType.PodcastRoundResponse:
                        if not audio_received and audio:
                            audio_received = True
                        audio.extend(msg.payload)
                    # Error message
                    elif msg.type == MsgType.Error:
                        raise RuntimeError(f"Server error: {msg.payload.decode()}")
                    elif msg.type == MsgType.FullServerResponse:
                        # Podcast round started
                        if msg.event == EventType.PodcastRoundStart:
                            data = json.loads(msg.payload.decode())
                            if data.get("text"):
                                filtered_payload = {"text": data.get("text"), "speaker": data.get("speaker")}
                                podcast_texts.append(filtered_payload)
                            voice = data.get("speaker")
                            current_round = data.get("round_id")
                            if current_round == -1:
                                voice = "head_music"
                            if current_round == 9999:
                                voice = "tail_music"
                            is_podcast_round_end = False
                            logger.debug(f"New round started: {data}")
                        # Podcast round ended
                        if msg.event == EventType.PodcastRoundEnd:
                            data = json.loads(msg.payload.decode())
                            logger.debug(f"Podcast round end: {data}")
                            if data.get("is_error"):
                                break
                            is_podcast_round_end = True
                            last_round_id = current_round
                            if audio:
                                round_info = await post_processor.postprocess_round(current_round, voice, audio, podcast_texts)
                                if on_round_complete:
                                    await on_round_complete(round_info)
                                audio.clear()
                        # Podcast finished
                        if msg.event == EventType.PodcastEnd:
                            data = json.loads(msg.payload.decode())
                            logger.info(f"Podcast end: {data}")
                        # Session finished
                        if msg.event == EventType.SessionFinished:
                            break
                if not audio_received and not only_nlp_text:
                    raise RuntimeError("No audio data received")
                # Close the connection
                await finish_connection(websocket)
                await wait_for_event(websocket, MsgType.FullServerResponse, EventType.ConnectionFinished)
                # Podcast finished, save the final audio file
                if is_podcast_round_end:
                    podcast_info = await post_processor.postprocess_final()
                    return podcast_info
                else:
                    logger.error(f"Current podcast not finished, resuming from round {last_round_id}")
                    retry_num -= 1
                    await asyncio.sleep(1)
                    if websocket:
                        await websocket.close()
        finally:
            await post_processor.cleanup()
            if websocket:
                await websocket.close()
        return None

async def test(args):
    """
    Podcast test function.

    Args:
        args: dict containing all podcast parameters
    """
    client = VolcEnginePodcastClient()
    # Default parameters
    params = {
        "text": "",
        "input_url": "https://zhuanlan.zhihu.com/p/607822576",
        "prompt_text": "",
        "nlp_texts": "",
        "action": 0,
        "resource_id": "volc.service_type.10050",
        "encoding": "mp3",
        "input_id": "test_podcast",
        "speaker_info": '{"random_order":false}',
        "use_head_music": False,
        "use_tail_music": False,
        "only_nlp_text": False,
        "return_audio_url": True,
        "skip_round_audio_save": False,
    }
    # Override the defaults with user-provided arguments
    if args:
        params.update(args)
    # podcast_request() has no output_dir parameter and requires a session_id
    params.pop("output_dir", None)
    await client.podcast_request(session_id=str(uuid.uuid4()), **params)
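
# Hypothetical example of the on_round_complete hook accepted by
# podcast_request(): it receives the dict returned by postprocess_round() for
# each finished round, so a caller can push progress to a frontend.
async def _example_on_round_complete(round_info):
    logger.info(f"round {round_info['round']} by {round_info['speaker']}: {round_info['duration']:.2f}s total, {round_info['size']} bytes")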

if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--text", default="", help="Input text. Used when action in [0]")
    parser.add_argument("--input_url", default="", help="Web url or file url. Used when action in [0]")
    parser.add_argument("--prompt_text", default="", help="Input prompt text, must not be empty when action in [4]")
    parser.add_argument("--nlp_texts", default="", help="Input NLP texts, must not be empty when action in [3]")
    parser.add_argument("--resource_id", default="volc.service_type.10050", help="Audio resource ID")
    parser.add_argument("--encoding", default="mp3", choices=["mp3", "wav"], help="Audio format")
    parser.add_argument("--input_id", default="test_podcast", help="Unique input identifier")
    parser.add_argument("--speaker_info", default='{"random_order":false}', help="Podcast speaker info")
    parser.add_argument("--use_head_music", default=False, action="store_true", help="Enable head music")
    parser.add_argument("--use_tail_music", default=False, action="store_true", help="Enable tail music")
    parser.add_argument("--only_nlp_text", default=False, action="store_true", help="Return only podcast text when action in [0, 4]")
    parser.add_argument("--return_audio_url", default=False, action="store_true", help="Return a downloadable audio url")
    parser.add_argument("--action", default=0, type=int, choices=[0, 3, 4], help="Podcast type")
    parser.add_argument("--skip_round_audio_save", default=False, action="store_true", help="Skip per-round audio save")
    parser.add_argument("--output_dir", default="output", help="Output directory")
    args = parser.parse_args()
    kwargs = {k: v for k, v in vars(args).items() if v is not None and not (isinstance(v, bool) and not v)}
    asyncio.run(test(kwargs))
lightx2v/deploy/common/sensetime_voice_clone.py
0 → 100644
# -*- coding: utf-8 -*-
import asyncio
import os
import struct
import subprocess
import sys
import tempfile
import time
import uuid
from typing import Optional, Tuple

import aiohttp
import numpy as np
import soundfile as sf
from aiohttp import ClientWebSocketResponse

# Protobuf imports
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
from loguru import logger

# ============================================================================
# Generated protocol buffer code (from tts.proto)
# ============================================================================
_sym_db = _symbol_database.Default()

DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
    b'\n\ttts.proto\x12\x03tts"\x8a\x01\n\rSubtitleEntry\x12\x15\n\rstart_time_ms\x18\x01 \x01(\r\x12\x13\n\x0b\x65nd_time_ms\x18\x02 \x01(\r\x12\x0f\n\x07speaker\x18\x03 \x01(\t\x12\r\n\x05style\x18\x04 \x01(\t\x12\x1f\n\x08language\x18\x05 \x01(\x0e\x32\r.tts.Language\x12\x0c\n\x04text\x18\x06 \x01(\t"\x88\x01\n\nAudioChunk\x12\x12\n\naudio_data\x18\x01 \x01(\x0c\x12\x17\n\x0f\x61udio_chunk_seq\x18\x02 \x01(\x05\x12\x15\n\ris_last_chunk\x18\x03 \x01(\x08\x12\x0c\n\x04text\x18\x04 \x01(\t\x12\x14\n\x0c\x61udio_format\x18\x05 \x01(\t\x12\x12\n\ndisable_ns\x18\x06 \x01(\x08"\x84\x04\n\nTtsRequest\x12-\n\x0cmessage_type\x18\x01 \x01(\x0e\x32\x17.tts.RequestMessageType\x12\x0e\n\x06\x61pp_id\x18\x02 \x01(\t\x12\x15\n\rapp_signature\x18\x03 \x01(\t\x12\x0c\n\x04text\x18\x04 \x01(\t\x12\x16\n\x0etext_chunk_seq\x18\x05 \x01(\x05\x12\x1a\n\x12is_last_text_chunk\x18\x06 \x01(\x08\x12 \n\ttext_type\x18\x07 \x01(\x0e\x32\r.tts.TextType\x12\x0f\n\x07speaker\x18\x08 \x01(\t\x12\x1f\n\x08language\x18\t \x01(\x0e\x32\r.tts.Language\x12\r\n\x05style\x18\n \x01(\t\x12\r\n\x05speed\x18\x0b \x01(\x02\x12\x0e\n\x06volume\x18\x0c \x01(\x02\x12\r\n\x05pitch\x18\r \x01(\x02\x12\x15\n\rstream_output\x18\x0e \x01(\x08\x12\x19\n\x11\x61udio_sample_rate\x18\x0f \x01(\x05\x12*\n\x0e\x61udio_encoding\x18\x10 \x01(\x0e\x32\x12.tts.AudioEncoding\x12\x18\n\x10output_subtitles\x18\x11 \x01(\x08\x12\x12\n\nsession_id\x18\x12 \x01(\t\x12%\n\x0cupload_audio\x18\x13 \x01(\x0b\x32\x0f.tts.AudioChunk\x12\x1a\n\x12pronunciation_dict\x18\x14 \x03(\t"\xe7\x02\n\x0bTtsResponse\x12$\n\x0bstatus_code\x18\x01 \x01(\x0e\x32\x0f.tts.StatusCode\x12\x14\n\x0c\x65rror_detail\x18\x02 \x01(\t\x12\x14\n\x0ctime_cost_ms\x18\x03 \x01(\r\x12*\n\x0e\x61udio_encoding\x18\x04 \x01(\x0e\x32\x12.tts.AudioEncoding\x12\x17\n\x0f\x61udio_chunk_seq\x18\x05 \x01(\x05\x12\x12\n\naudio_data\x18\x06 \x01(\x0c\x12\x1b\n\x13is_last_audio_chunk\x18\x07 \x01(\x08\x12\x12\n\nsession_id\x18\x08 \x01(\t\x12%\n\tsubtitles\x18\t \x03(\x0b\x32\x12.tts.SubtitleEntry\x12\x0f\n\x07speaker\x18\n \x01(\t\x12\x1a\n\x12request_char_count\x18\x0b \x01(\r\x12(\n\rerror_subcode\x18\x0c \x01(\x0e\x32\x11.tts.ErrorSubCode*\xa9\x01\n\x12RequestMessageType\x12\x1c\n\x18\x43LIENT_SYNTHESIS_REQUEST\x10\x00\x12\x19\n\x15\x43LIENT_FINISH_REQUEST\x10\x01\x12\x1d\n\x19\x43LIENT_UPLOAD_CLONE_AUDIO\x10\x02\x12\x1c\n\x18\x43LIENT_QUERY_CLONE_AUDIO\x10\x03\x12\x1d\n\x19\x43LIENT_DELETE_CLONE_AUDIO\x10\x04*\x1f\n\x08TextType\x12\t\n\x05PLAIN\x10\x00\x12\x08\n\x04SSML\x10\x01*A\n\x08Language\x12\t\n\x05ZH_CN\x10\x00\x12\t\n\x05\x45N_US\x10\x01\x12\x11\n\rZH_CN_SICHUAN\x10\x02\x12\x0c\n\x08ZH_CN_HK\x10\x03**\n\rAudioEncoding\x12\x07\n\x03PCM\x10\x00\x12\x07\n\x03WAV\x10\x01\x12\x07\n\x03MP3\x10\x02*\xa7\x01\n\nStatusCode\x12\x0b\n\x07SUCCESS\x10\x00\x12\t\n\x05\x45RROR\x10\x01\x12\x0b\n\x07TIMEOUT\x10\x02\x12\x13\n\x0fINVALID_REQUEST\x10\x03\x12\x12\n\x0eINTERNAL_ERROR\x10\x04\x12\x18\n\x14UPLOAD_AUDIO_SUCCESS\x10\x05\x12\x17\n\x13QUERY_AUDIO_SUCCESS\x10\x06\x12\x18\n\x14\x44\x45LETE_AUDIO_SUCCESS\x10\x07*\xe1\x02\n\x0c\x45rrorSubCode\x12\x0c\n\x08\x45RR_NONE\x10\x00\x12\x16\n\x12\x45RR_BASE_FILE_READ\x10\x65\x12\x17\n\x13\x45RR_BASE_FILE_WRITE\x10\x66\x12\x1c\n\x18\x45RR_BASE_INVALID_SEQ_NUM\x10g\x12\x1e\n\x1a\x45RR_BASE_SPEAKER_NOT_FOUND\x10h\x12\x14\n\x0f\x45RR_AC_INTERNAL\x10\xc9\x01\x12\x16\n\x11\x45RR_AC_LONG_AUDIO\x10\xca\x01\x12\x15\n\x10\x45RR_AC_LONG_TEXT\x10\xcb\x01\x12\x1f\n\x1a\x45RR_AC_AUDIO_TEXT_MISMATCH\x10\xcc\x01\x12 \n\x1b\x45RR_AC_UNAUTHORIZED_SPEAKER\x10\xcd\x01\x12\x1b\n\x16\x45RR_AC_INVALID_SPEAKER\x10\xce\x01\x12\x17\n\x12\x45RR_AC_SHORT_AUDIO\x10\xcf\x01\x12\x16\n\x11\x45RR_AC_SHORT_TEXT\x10\xd0\x01\x62\x06proto3'
)

_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "tts_pb2", _globals)
if not _descriptor._USE_C_DESCRIPTORS:
    DESCRIPTOR._options = None
    _globals["_REQUESTMESSAGETYPE"]._serialized_start = 1180
    _globals["_REQUESTMESSAGETYPE"]._serialized_end = 1349
    _globals["_TEXTTYPE"]._serialized_start = 1351
    _globals["_TEXTTYPE"]._serialized_end = 1382
    _globals["_LANGUAGE"]._serialized_start = 1384
    _globals["_LANGUAGE"]._serialized_end = 1449
    _globals["_AUDIOENCODING"]._serialized_start = 1451
    _globals["_AUDIOENCODING"]._serialized_end = 1493
    _globals["_STATUSCODE"]._serialized_start = 1496
    _globals["_STATUSCODE"]._serialized_end = 1663
    _globals["_ERRORSUBCODE"]._serialized_start = 1666
    _globals["_ERRORSUBCODE"]._serialized_end = 2019
    _globals["_SUBTITLEENTRY"]._serialized_start = 19
    _globals["_SUBTITLEENTRY"]._serialized_end = 157
    _globals["_AUDIOCHUNK"]._serialized_start = 160
    _globals["_AUDIOCHUNK"]._serialized_end = 296
    _globals["_TTSREQUEST"]._serialized_start = 299
    _globals["_TTSREQUEST"]._serialized_end = 815
    _globals["_TTSRESPONSE"]._serialized_start = 818
    _globals["_TTSRESPONSE"]._serialized_end = 1177

# Import protobuf classes for easier access.
# These are created by the protobuf builder above and added to _globals.
# ============================================================================
# Get protobuf classes from _globals (they are created by the builder)
SubtitleEntry = _globals.get("SubtitleEntry")
AudioChunk = _globals.get("AudioChunk")
TtsRequest = _globals.get("TtsRequest")
TtsResponse = _globals.get("TtsResponse")
RequestMessageType = _globals.get("RequestMessageType")
TextType = _globals.get("TextType")
Language = _globals.get("Language")
AudioEncoding = _globals.get("AudioEncoding")
StatusCode = _globals.get("StatusCode")
ErrorSubCode = _globals.get("ErrorSubCode")

# Verify that all required classes are available
if not all([SubtitleEntry, AudioChunk, TtsRequest, TtsResponse, RequestMessageType, TextType, Language, AudioEncoding, StatusCode, ErrorSubCode]):
    raise RuntimeError("Failed to load protobuf classes. Please check protobuf installation.")

# ============================================================================
# Configuration parameters
RECEIVE_TIMEOUT = 30  # Receive timeout (seconds)

# Language mapping
lang_id2str_mapping = {
    Language.ZH_CN: "ZH_CN",
    Language.ZH_CN_SICHUAN: "ZH_CN_SICHUAN",
    Language.ZH_CN_HK: "ZH_CN_HK",
    Language.EN_US: "EN_US",
}
lang_str2id_mapping = {v: k for k, v in lang_id2str_mapping.items()}

# Audio encoding mapping
codec_id2str_mapping = {
    AudioEncoding.PCM: "pcm",
    AudioEncoding.WAV: "wav",
    AudioEncoding.MP3: "mp3",
}
codec_str2id_mapping = {v: k for k, v in codec_id2str_mapping.items()}

def parse_response(protocol_type: int, data: bytes) -> TtsResponse:
    try:
        response = TtsResponse()
        response.ParseFromString(data)
        return response
    except Exception as e:
        raise ValueError(f"Failed to parse response: {str(e)}")

def create_synthesis_request(
    message_type,
    text: str,
    text_chunk_seq: int = 0,
    is_last_text_chunk: bool = False,
    app_id: str = "",
    app_signature: str = "",
    text_type: TextType = TextType.PLAIN,
    speaker: str = "default",
    language: Language = Language.ZH_CN,
    style: str = "",
    speed: float = 1,
    volume: float = 0,
    pitch: float = 0,
    stream_output: bool = True,
    audio_sample_rate: int = 24000,
    audio_encoding: AudioEncoding = AudioEncoding.PCM,
    output_subtitles: bool = False,
    session_id: str = "",
    upload_data: Optional[AudioChunk] = None,
) -> TtsRequest:
    request = TtsRequest()
    request.message_type = message_type
    request.app_id = app_id
    request.text = text
    request.text_chunk_seq = text_chunk_seq
    request.is_last_text_chunk = is_last_text_chunk
    request.text_type = text_type
    request.speaker = speaker
    request.language = language
    request.style = style
    request.speed = speed
    request.volume = volume
    request.pitch = pitch
    request.stream_output = stream_output
    request.audio_sample_rate = audio_sample_rate
    request.audio_encoding = audio_encoding
    request.output_subtitles = output_subtitles
    request.session_id = session_id
    if upload_data is not None:
        request.upload_audio.CopyFrom(upload_data)
    return request

def serialize_request(request: TtsRequest) -> bytes:
    request_bytes = request.SerializeToString()
    request_length = struct.pack("!I", len(request_bytes))
    full_request = b"\x01" + request_length + request_bytes
    return full_request
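
# The wire framing produced by serialize_request() is: one protocol-type byte
# (0x01), a 4-byte big-endian length, then the protobuf payload. A minimal
# sketch of the inverse, for a single complete frame already in memory (the
# real reader, receive_full_message() below, also handles partial reads):
def _example_unframe(frame: bytes) -> TtsResponse:
    assert frame[0] == 0x01, "unsupported protocol type"
    (length,) = struct.unpack("!I", frame[1:5])
    payload = frame[5 : 5 + length]
    assert len(payload) == length, "incomplete frame"
    return parse_response(frame[0], payload)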

async def receive_full_message(websocket: ClientWebSocketResponse) -> Tuple[int, bytes]:
    try:
        # Receive data
        message = await asyncio.wait_for(websocket.receive_bytes(), timeout=RECEIVE_TIMEOUT)
        if len(message) < 5:
            raise ValueError("Invalid response: too short")
        protocol_type = message[0]
        if protocol_type != 0x01:
            raise ValueError("Unsupported protocol type")
        protocol_length = struct.unpack("!I", message[1:5])[0]
        data = message[5:]
        if len(data) != protocol_length:
            logger.info(f"Length error {protocol_length}, got {len(data)}")
            # If the data is incomplete, continue receiving
            while len(data) < protocol_length:
                try:
                    chunk = await asyncio.wait_for(websocket.receive_bytes(), timeout=RECEIVE_TIMEOUT)
                    if not chunk:
                        raise ValueError("Got disconnected or empty data")
                    data += chunk
                    logger.info(f"Received additional {len(chunk)} bytes, total {len(data)}/{protocol_length}")
                except asyncio.TimeoutError:
                    raise ValueError(f"Timeout while receiving message. Got {len(data)}/{protocol_length} bytes")
        return protocol_type, data
    except asyncio.TimeoutError:
        raise ValueError(f"Response timed out after {RECEIVE_TIMEOUT} seconds")
    except aiohttp.WSServerHandshakeError as e:
        # WebSocket handshake error, may contain error information
        error_msg = f"WebSocket handshake error: {str(e)}"
        if hasattr(e, "message") and e.message:
            error_msg = e.message
        raise ValueError(error_msg)
    except Exception as e:
        error_str = str(e)
        # Check if it's a WebSocket close message error
        if "1009" in error_str:
            raise ValueError("Audio file too large or format not supported. Please use WAV/MP3 audio file (max size limit).")
        elif "1000" in error_str or "WSMsgType" in error_str:
            # WebSocket close message
            raise ValueError(f"WebSocket connection closed: {error_str}")
        raise ValueError(f"Error receiving data: {str(e)}")

class SenseTimeTTSClient:
    """
    SenseTime TTS Client

    Parameter ranges:
    - speed: 0.5~2.0 (1.0 is normal speed)
    - volume: -12~12 dB (0 is normal volume)
    - pitch: -24~24 halftone (0 is normal pitch)
    """

    def __init__(self, url=None, app_id=None, apikey=None):
        self.url = url or os.getenv("SENSETIME_TTS_URL")
        self.app_id = app_id or os.getenv("SENSETIME_APP_ID")
        self.apikey = apikey or os.getenv("SENSETIME_APIKEY")
        if not self.apikey:
            raise ValueError("SENSETIME_APIKEY is not set")
        if not self.app_id:
            raise ValueError("SENSETIME_APP_ID is not set")
        if not self.url:
            raise ValueError("SENSETIME_TTS_URL is not set")

    async def _receive_loop(self, websocket, session_id, params, result_dict):
        """Continuously receive server responses in a loop"""
        is_running = True
        data = b""
        seq = -1
        subtitles = []
        first_latency = None
        try:
            while is_running:
                try:
                    ptype, data_bytes = await receive_full_message(websocket)
                    response = parse_response(ptype, data_bytes)
                    if response.status_code == StatusCode.SUCCESS:
                        chunk_seq = response.audio_chunk_seq
                        is_last_chunk = response.is_last_audio_chunk
                        stream = params.get("stream_output", True)
                        # Check the sequence number
                        valid = chunk_seq == seq + 1
                        seq = chunk_seq
                        if not valid:
                            logger.warning(f"Session {session_id} Invalid seq")
                            is_running = False
                            break
                        if chunk_seq == 0:
                            start_time = result_dict.get("start_time")
                            if start_time is not None:
                                first_latency = (time.time() - start_time) * 1000
                                logger.info(f"Session {session_id} stream({int(stream)}) Got first package, cost(ms): {first_latency:.3f}")
                        if response.audio_data:
                            data += response.audio_data
                            logger.info(f"Audio seq: {chunk_seq}, is_last: {is_last_chunk}, data length: {len(response.audio_data)} bytes")
                        if response.subtitles:
                            for subtitle in response.subtitles:
                                start_time_ms = subtitle.start_time_ms
                                end_time_ms = subtitle.end_time_ms
                                fmt_sub = f"{subtitle.text}({start_time_ms}-{end_time_ms}ms)"
                                subtitles.append(fmt_sub)
                        if response.is_last_audio_chunk:
                            start_time = result_dict.get("start_time")
                            whole_cost = time.time() - start_time if start_time else 0
                            if len(data) > 0:
                                sample_rate = params.get("sample_rate", 24000)
                                duration = len(data) / 2 / sample_rate
                                rtf = whole_cost / duration if duration > 0 else 0
                                if len(subtitles) > 0:
                                    joint_sub = "\t".join(subtitles)
                                    logger.info(f"Session {session_id} subtitle: {joint_sub}")
                                out_info = f"spk {params.get('speaker', 'default')} "
                                out_info += f"stream {int(stream)} "
                                if first_latency is not None:
                                    out_info += f"latency {first_latency:.3f}ms "
                                out_info += f"cost {whole_cost:.3f} secs "
                                if params.get("audio_format") == "pcm":
                                    out_info += f"duration {duration:.3f} secs "
                                out_info += f"RTF {rtf:.3f}"
                                logger.info(f"Session {session_id} done, {out_info}")
                            result_dict["audio_data"] = data
                            result_dict["subtitles"] = subtitles
                            result_dict["success"] = True
                            is_running = False
                    elif response.status_code == StatusCode.INTERNAL_ERROR:
                        error_msg = response.error_detail if response.error_detail else "Internal error"
                        logger.error(f"INTERNAL_ERROR in response: {error_msg}")
                        result_dict["error"] = error_msg
                        result_dict["success"] = False
                        is_running = False
                        break
                    elif response.status_code == StatusCode.ERROR:
                        error_msg = response.error_detail if response.error_detail else "Unknown error"
                        logger.error(f"ERROR in response: {error_msg}")
                        result_dict["error"] = error_msg
                        result_dict["success"] = False
                        is_running = False
                        break
                    elif response.status_code == StatusCode.UPLOAD_AUDIO_SUCCESS:
                        if response.speaker == "":
                            logger.error("ERROR: Got none speaker for UPLOAD_AUDIO_SUCCESS")
                            result_dict["error"] = "Got none speaker for UPLOAD_AUDIO_SUCCESS"
                        else:
                            logger.info(f"OK, Got speaker id {response.speaker} session id {response.session_id}")
                            result_dict["speaker"] = response.speaker
                            result_dict["session_id"] = response.session_id
                            result_dict["success"] = True
                        is_running = False
                        break
                    elif response.status_code == StatusCode.QUERY_AUDIO_SUCCESS:
                        logger.info(f"Query speaker {response.speaker} successful")
                        result_dict["speaker"] = response.speaker
                        result_dict["success"] = True
                        is_running = False
                        break
                    elif response.status_code == StatusCode.DELETE_AUDIO_SUCCESS:
                        logger.info(f"Delete speaker {response.speaker} successful")
                        result_dict["success"] = True
                        is_running = False
                        break
                    else:
                        # Handle other error status codes, return error details directly
                        error_msg = response.error_detail if response.error_detail else "Unknown error"
                        logger.error(f"Error in response: {error_msg}")
                        result_dict["error"] = error_msg
                        result_dict["success"] = False
                        is_running = False
                        break
                except asyncio.CancelledError:
                    logger.info("Receive loop cancelled")
                    is_running = False
                    break
                except Exception as e:
                    logger.error(f"Error in receive loop: {e}")
                    result_dict["error"] = str(e)
                    break
        except Exception as e:
            logger.error(f"Receive loop terminated: {e}")
            result_dict["error"] = str(e)
        logger.info("Exit receive loop.")

    async def tts_request(
        self,
        text,
        speaker="M20",
        style="正常",
        speed=1.0,
        volume=0,
        pitch=0,
        language="ZH_CN",
        output="tts_output.wav",
        sample_rate=24000,
        audio_format="wav",
        stream_output=True,
        output_subtitles=False,
    ):
        """
        Execute TTS request

        Args:
            text: Text to convert
            speaker: Speaker, common values include "M20", "F12", "zhili", "nvguo59", or ID returned by audioclone
            style: Speaker style, common values include "正常" (normal), "高兴" (happy), "愤怒" (angry), etc.
            speed: Speech rate (0.5~2.0, 1.0 is normal speed)
            volume: Volume (-12~12 dB, 0 is normal volume)
            pitch: Pitch (-24~24 halftone, 0 is normal pitch)
            language: Language, options: "ZH_CN", "ZH_CN_SICHUAN", "ZH_CN_HK", "EN_US"
            output: Output file path
            sample_rate: Sample rate, options: 8000, 16000, 24000, 32000, 48000
            audio_format: Audio format, options: "pcm", "wav", "mp3"
            stream_output: Whether to stream output
            output_subtitles: Whether to output subtitles
        """
        # Validate parameter ranges
        if not (0.5 <= speed <= 2.0):
            logger.warning(f"speed {speed} is out of valid range [0.5, 2.0], using default value 1.0")
            speed = 1.0
        if not (-12 <= volume <= 12):
            logger.warning(f"volume {volume} is out of valid range [-12, 12], using default value 0")
            volume = 0
        if not (-24 <= pitch <= 24):
            logger.warning(f"pitch {pitch} is out of valid range [-24, 24], using default value 0")
            pitch = 0
        if language not in lang_str2id_mapping:
            logger.warning(f"language {language} is invalid, using default value ZH_CN")
            language = "ZH_CN"
        if audio_format not in codec_str2id_mapping:
            logger.warning(f"audio_format {audio_format} is invalid, using default value pcm")
            audio_format = "pcm"
        logger.info(f"Connecting to {self.url} ...")
        headers = {"apikey": self.apikey} if self.url.startswith("wss:") else None
        result_dict = {"success": False, "audio_data": None, "subtitles": [], "error": None}
        try:
            async with aiohttp.ClientSession(headers=headers) as session:
                async with session.ws_connect(self.url) as websocket:
                    logger.info("WebSocket connection established")
                    session_id = str(uuid.uuid4())
                    params = {
                        "speaker": speaker,
                        "style": style,
                        "speed": speed,
                        "volume": volume,
                        "pitch": pitch,
                        "language": language,
                        "sample_rate": sample_rate,
                        "audio_format": audio_format,
                        "stream_output": stream_output,
                        "output_subtitles": output_subtitles,
                    }
                    # Set the start time (before sending the request)
                    start_time = time.time()
                    result_dict["start_time"] = start_time
                    # Start receive loop
                    receive_task = asyncio.create_task(self._receive_loop(websocket, session_id, params, result_dict))
                    # Simulate streaming: send character by character
                    for i, chunk in enumerate(text):
                        if not receive_task.done():
                            is_last = i == len(text) - 1
                            request = create_synthesis_request(
                                message_type=RequestMessageType.CLIENT_SYNTHESIS_REQUEST,
                                app_id=self.app_id,
                                text=chunk,
                                text_chunk_seq=i,
                                is_last_text_chunk=is_last,
                                session_id=session_id,
                                speaker=speaker,
                                style=style,
                                speed=speed,
                                output_subtitles=output_subtitles,
                                audio_sample_rate=sample_rate,
                                language=lang_str2id_mapping[language],
                                volume=volume,
                                audio_encoding=codec_str2id_mapping[audio_format],
                                stream_output=stream_output,
                                pitch=pitch,
                            )
                            full_request = serialize_request(request)
                            await websocket.send_bytes(full_request)
                    # Wait for the receive task to complete
                    await receive_task
                    if result_dict["success"] and result_dict["audio_data"]:
                        audio_data = result_dict["audio_data"]
                        # Save the audio file
                        if audio_format == "pcm":
                            if not output.endswith(".wav"):
                                output += ".wav"
                            audio_np = np.frombuffer(audio_data, dtype=np.int16)
                            sf.write(output, audio_np, samplerate=sample_rate, subtype="PCM_16")
                        else:
                            if not output.endswith(f".{audio_format}"):
                                output += f".{audio_format}"
                            with open(output, "wb") as fp:
                                fp.write(audio_data)
                        logger.info(f"audio saved to {output}, audio size: {len(audio_data) / 1024:.2f}KB")
                        os.chmod(output, 0o644)
                        return True
                    else:
                        error_msg = result_dict.get("error", "Unknown error")
                        logger.warning(f"SenseTimeTTSClient tts request failed: {error_msg}")
                        return False
        except Exception as e:
            logger.warning(f"SenseTimeTTSClient tts request failed: {e}")
            return False
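
    # A minimal usage sketch (values are placeholders), assuming the
    # SENSETIME_TTS_URL / SENSETIME_APP_ID / SENSETIME_APIKEY environment
    # variables are set:
    #
    #     client = SenseTimeTTSClient()
    #     ok = asyncio.run(client.tts_request("你好,世界", speaker="M20",
    #                                         sample_rate=16000, audio_format="pcm",
    #                                         output="out.wav"))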

    async def upload_audio_clone(
        self,
        audio_path,
        audio_text,
        disable_ns=False,
    ):
        """
        Upload audio for voice cloning

        Args:
            audio_path: Audio file path
            audio_text: Text corresponding to the audio
            disable_ns: Whether to disable audio noise reduction processing

        Returns:
            tuple: (success: bool, result: str)
            - success: True indicates success, False indicates failure
            - result: Returns speaker_id on success, error message string on failure
        """
        logger.info(f"Connecting to {self.url} ...")
        headers = {"apikey": self.apikey} if self.url.startswith("wss:") else None
        result_dict = {"success": False, "speaker": None, "session_id": None, "error": None}
        try:
            async with aiohttp.ClientSession(headers=headers) as session:
                async with session.ws_connect(self.url) as websocket:
                    logger.info("WebSocket connection established")
                    session_id = str(uuid.uuid4())
                    # Start receive loop
                    receive_task = asyncio.create_task(self._receive_loop(websocket, session_id, {}, result_dict))
                    # Read and send the audio.
                    # Check the file format; if it's a video file (e.g., MP4), extract the audio first.
                    tmp_audio_path = None
                    original_audio_path = audio_path
                    try:
                        file_ext = os.path.splitext(audio_path)[1].lower()
                        if file_ext in [".mp4", ".mov", ".avi", ".mkv", ".flv"]:
                            # Video file, need to extract audio first
                            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_audio:
                                tmp_audio_path = tmp_audio.name
                            try:
                                # Use ffmpeg to extract audio
                                cmd = ["ffmpeg", "-i", audio_path, "-vn", "-acodec", "pcm_s16le", "-ar", "16000", "-ac", "1", "-y", tmp_audio_path]
                                proc = await asyncio.create_subprocess_exec(*cmd, stderr=asyncio.subprocess.PIPE)
                                try:
                                    _, stderr = await asyncio.wait_for(proc.communicate(), timeout=60)
                                except asyncio.TimeoutError:
                                    proc.kill()
                                    await proc.wait()
                                    raise ValueError("Audio extraction timeout. Video file may be too large.")
                                if proc.returncode != 0:
                                    raise ValueError(f"Failed to extract audio from video: {stderr.decode(errors='ignore')}")
                                logger.info(f"Extracted audio from video file to {tmp_audio_path}")
                                audio_path = tmp_audio_path
                            except subprocess.TimeoutExpired:
                                raise ValueError("Audio extraction timeout. Video file may be too large.")
                            except FileNotFoundError:
                                raise ValueError("ffmpeg not found. Please install ffmpeg to process video files.")
                            except Exception as e:
                                raise ValueError(f"Failed to extract audio: {str(e)}")
                        with open(audio_path, "rb") as fp:
                            audio_bytes = fp.read()
                        # Check file size (recommended not to exceed 10MB)
                        if len(audio_bytes) > 10 * 1024 * 1024:
                            logger.warning(f"Audio file size ({len(audio_bytes) / 1024 / 1024:.2f}MB) may be too large")
                        audio_chunk = AudioChunk()
                        audio_chunk.audio_data = audio_bytes
                        audio_chunk.audio_chunk_seq = 0
                        audio_chunk.is_last_chunk = 1
                        audio_chunk.text = audio_text
                        audio_chunk.disable_ns = disable_ns
                    finally:
                        # Clean up temporary files
                        if tmp_audio_path and os.path.exists(tmp_audio_path):
                            try:
                                os.unlink(tmp_audio_path)
                            except Exception:
                                pass
                    request = create_synthesis_request(
                        message_type=RequestMessageType.CLIENT_UPLOAD_CLONE_AUDIO,
                        app_id=self.app_id,
                        text="",
                        session_id=session_id,
                        upload_data=audio_chunk,
                    )
                    full_request = serialize_request(request)
                    await websocket.send_bytes(full_request)
                    logger.info("Sent audio chunk for cloning")
                    # Wait for the receive task to complete
                    await receive_task
                    if result_dict["success"]:
                        speaker_id = result_dict.get("speaker")
                        logger.info(f"SenseTimeTTSClient upload audio clone successful, speaker: {speaker_id}")
                        return True, speaker_id
                    else:
                        # Return the error message string directly
                        error_msg = result_dict.get("error", "Unknown error")
                        logger.warning(f"SenseTimeTTSClient upload audio clone failed: {error_msg}")
                        return False, error_msg
        except Exception as e:
            error_msg = str(e)
            logger.warning(f"SenseTimeTTSClient upload audio clone failed: {error_msg}")
            return False, error_msg
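
    # Sketch of the clone-then-synthesize flow (file name and transcript are
    # placeholders): upload a reference recording with its transcript, then
    # reuse the returned speaker id in tts_request().
    #
    #     ok, speaker_or_err = await client.upload_audio_clone("reference.wav", "reference transcript")
    #     if ok:
    #         await client.tts_request("text to speak", speaker=speaker_or_err)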

    async def query_speaker(self, speaker):
        """
        Query if the specified speaker exists

        Args:
            speaker: speaker ID
        """
        logger.info(f"Connecting to {self.url} ...")
        headers = {"apikey": self.apikey} if self.url.startswith("wss:") else None
        result_dict = {"success": False, "speaker": None, "error": None}
        try:
            async with aiohttp.ClientSession(headers=headers) as session:
                async with session.ws_connect(self.url) as websocket:
                    logger.info("WebSocket connection established")
                    session_id = str(uuid.uuid4())
                    # Start receive loop
                    receive_task = asyncio.create_task(self._receive_loop(websocket, session_id, {}, result_dict))
                    # Send query request
                    request = create_synthesis_request(
                        message_type=RequestMessageType.CLIENT_QUERY_CLONE_AUDIO,
                        app_id=self.app_id,
                        text="",
                        session_id=session_id,
                        speaker=speaker,
                    )
                    full_request = serialize_request(request)
                    await websocket.send_bytes(full_request)
                    logger.info(f"Sent query for speaker {speaker}")
                    # Wait for the receive task to complete
                    await receive_task
                    if result_dict["success"]:
                        logger.info("SenseTimeTTSClient query speaker successful")
                        return True
                    else:
                        error_msg = result_dict.get("error", "Unknown error")
                        logger.warning(f"SenseTimeTTSClient query speaker failed: {error_msg}")
                        return False
        except Exception as e:
            logger.warning(f"SenseTimeTTSClient query speaker failed: {e}")
            return False

    async def delete_speaker(self, speaker):
        """
        Delete the specified speaker

        Args:
            speaker: speaker ID
        """
        logger.info(f"Connecting to {self.url} ...")
        headers = {"apikey": self.apikey} if self.url.startswith("wss:") else None
        result_dict = {"success": False, "error": None}
        try:
            async with aiohttp.ClientSession(headers=headers) as session:
                async with session.ws_connect(self.url) as websocket:
                    logger.info("WebSocket connection established")
                    session_id = str(uuid.uuid4())
                    # Start receive loop
                    receive_task = asyncio.create_task(self._receive_loop(websocket, session_id, {}, result_dict))
                    # Send delete request
                    request = create_synthesis_request(
                        message_type=RequestMessageType.CLIENT_DELETE_CLONE_AUDIO,
                        app_id=self.app_id,
                        text="",
                        session_id=session_id,
                        speaker=speaker,
                    )
                    full_request = serialize_request(request)
                    await websocket.send_bytes(full_request)
                    logger.info(f"Sent delete request for speaker {speaker}")
                    # Wait for the receive task to complete
                    await receive_task
                    if result_dict["success"]:
                        logger.info("SenseTimeTTSClient delete speaker successful")
                        return True
                    else:
                        error_msg = result_dict.get("error", "Unknown error")
                        logger.warning(f"SenseTimeTTSClient delete speaker failed: {error_msg}")
                        return False
        except Exception as e:
            logger.warning(f"SenseTimeTTSClient delete speaker failed: {e}")
            return False

async def test(args):
    """
    TTS test function

    Args:
        args: list, e.g. [text, speaker, style, speed, volume, pitch, language, output, sample_rate, audio_format, stream_output, output_subtitles]
              Provide as many as needed, from left to right.

    Parameter ranges:
    - speed: 0.5~2.0 (1.0 is normal speed)
    - volume: -12~12 dB (0 is normal volume)
    - pitch: -24~24 halftone (0 is normal pitch)
    """
    client = SenseTimeTTSClient()
    # Set default parameters
    params = {
        "text": "今天天气真不错,阳光明媚,微风轻拂,让人心情愉悦。",
        "speaker": "M20",
        "style": "正常",
        "speed": 1.0,
        "volume": 0,
        "pitch": 0,
        "language": "ZH_CN",
        "output": "tts_output.wav",
        "sample_rate": 24000,
        "audio_format": "pcm",
        "stream_output": True,
        "output_subtitles": False,
    }
    keys = list(params.keys())
    # Override default parameters
    for i, arg in enumerate(args):
        if i < len(keys):
            # Type conversion
            if keys[i] in ["sample_rate"]:
                params[keys[i]] = int(arg)
            elif keys[i] in ["stream_output", "output_subtitles"]:
                # Support multiple boolean spellings
                params[keys[i]] = str(arg).lower() in ("1", "true", "yes", "on")
            elif keys[i] in ["speed", "volume", "pitch"]:
                params[keys[i]] = float(arg)
            else:
                params[keys[i]] = arg
    await client.tts_request(
        params["text"],
        params["speaker"],
        params["style"],
        params["speed"],
        params["volume"],
        params["pitch"],
        params["language"],
        params["output"],
        params["sample_rate"],
        params["audio_format"],
        params["stream_output"],
        params["output_subtitles"],
    )

async def test_audio_clone(args):
    """
    Voice cloning test function

    Args:
        args: list, e.g. [audio_path, audio_text, disable_ns]
              Provide as many as needed, from left to right.

    Parameters:
    - audio_path: Audio file path (required)
    - audio_text: Text corresponding to the audio (required)
    - disable_ns: Whether to disable audio noise reduction processing, default False (optional, supports "1", "true", "yes", "on" for True)
    """
    client = SenseTimeTTSClient()
    # Set default parameters
    params = {
        "audio_path": "",
        "audio_text": "",
        "disable_ns": False,
    }
    keys = list(params.keys())
    # Override default parameters
    for i, arg in enumerate(args):
        if i < len(keys):
            # Type conversion
            if keys[i] == "disable_ns":
                # Support multiple boolean spellings
                params[keys[i]] = str(arg).lower() in ("1", "true", "yes", "on")
            else:
                params[keys[i]] = arg
    # Validate required parameters
    if not params["audio_path"]:
        logger.error("audio_path is required for audio clone test")
        return
    if not params["audio_text"]:
        logger.error("audio_text is required for audio clone test")
        return
    # Check if the file exists
    if not os.path.exists(params["audio_path"]):
        logger.error(f"Audio file not found: {params['audio_path']}")
        return
    success, result = await client.upload_audio_clone(
        params["audio_path"],
        params["audio_text"],
        params["disable_ns"],
    )
    if success:
        logger.info(f"Audio clone successful! Speaker ID: {result}")
    else:
        logger.warning(f"Audio clone failed: {result}")

if __name__ == "__main__":
    # Support two test modes: regular TTS test and voice cloning test
    if len(sys.argv) > 1 and sys.argv[1] == "clone":
        # Voice cloning test mode: python sensetime_tts.py clone [audio_path] [audio_text] [disable_ns]
        asyncio.run(test_audio_clone(sys.argv[2:]))
    else:
        # Regular TTS test mode: python sensetime_tts.py [text] [speaker] ...
        asyncio.run(test(sys.argv[1:]))
lightx2v/deploy/common/utils.py
0 → 100644
import asyncio
import base64
import io
import os
import subprocess
import tempfile
import time
import traceback
from datetime import datetime

import httpx
import torchaudio
from PIL import Image
from loguru import logger

FMT = "%Y-%m-%d %H:%M:%S"

def current_time():
    return datetime.now().timestamp()


def time2str(t):
    d = datetime.fromtimestamp(t)
    return d.strftime(FMT)


def str2time(s):
    d = datetime.strptime(s, FMT)
    return d.timestamp()

def try_catch(func):
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except Exception:
            logger.error(f"Error in {func.__name__}:")
            traceback.print_exc()
            return None

    return wrapper


def class_try_catch(func):
    def wrapper(self, *args, **kwargs):
        try:
            return func(self, *args, **kwargs)
        except Exception:
            logger.error(f"Error in {self.__class__.__name__}.{func.__name__}:")
            traceback.print_exc()
            return None

    return wrapper


def class_try_catch_async(func):
    async def wrapper(self, *args, **kwargs):
        try:
            return await func(self, *args, **kwargs)
        except Exception:
            logger.error(f"Error in {self.__class__.__name__}.{func.__name__}:")
            traceback.print_exc()
            return None

    return wrapper

def data_name(x, task_id):
    if x == "input_image" or x.startswith("input_image/"):
        x = x + ".png"
    elif x == "input_video":
        x = x + ".mp4"
    elif x == "input_last_frame":
        x = x + ".png"
    elif x == "output_video":
        x = x + ".mp4"
    elif x == "output_image":
        x = x + ".png"
    return f"{task_id}-{x}"

async def fetch_resource(url, timeout):
    logger.info(f"Begin to download resource from url: {url}")
    t0 = time.time()
    async with httpx.AsyncClient() as client:
        async with client.stream("GET", url, timeout=timeout) as response:
            response.raise_for_status()
            ans_bytes = []
            async for chunk in response.aiter_bytes(chunk_size=1024 * 1024):
                ans_bytes.append(chunk)
                if len(ans_bytes) > 128:
                    raise Exception(f"url {url} recv data is too big")
            content = b"".join(ans_bytes)
            logger.info(f"Download url {url} resource cost time: {time.time() - t0} seconds")
            return content

# check, resize, read rotate meta info
def format_image_data(data, max_size=1280):
    image = Image.open(io.BytesIO(data)).convert("RGB")
    exif = image.getexif()
    changed = False
    w, h = image.size
    assert w > 0 and h > 0, "image is empty"
    logger.info(f"load image: {w}x{h}, exif: {exif}")
    if w > max_size or h > max_size:
        ratio = max_size / max(w, h)
        w = int(w * ratio)
        h = int(h * ratio)
        image = image.resize((w, h))
        logger.info(f"resize image to: {image.size}")
        changed = True
    orientation_key = 274
    if orientation_key and orientation_key in exif:
        orientation = exif[orientation_key]
        if orientation == 2:
            image = image.transpose(Image.FLIP_LEFT_RIGHT)
        elif orientation == 3:
            image = image.rotate(180, expand=True)
        elif orientation == 4:
            image = image.transpose(Image.FLIP_TOP_BOTTOM)
        elif orientation == 5:
            image = image.transpose(Image.FLIP_LEFT_RIGHT).rotate(90, expand=True)
        elif orientation == 6:
            image = image.rotate(270, expand=True)
        elif orientation == 7:
            image = image.transpose(Image.FLIP_LEFT_RIGHT).rotate(270, expand=True)
        elif orientation == 8:
            image = image.rotate(90, expand=True)
        # reset orientation to 1
        if orientation != 1:
            logger.info(f"reset orientation from {orientation} to 1")
            exif[orientation_key] = 1
            changed = True
    if not changed:
        return data
    output = io.BytesIO()
    image.save(output, format=image.format or "JPEG", exif=exif.tobytes())
    return output.getvalue()
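
# Quick sanity check for the EXIF handling above (file name is a placeholder):
# after one pass through format_image_data(), a re-read image should report
# orientation 1 or carry no orientation tag at all, so downstream code never
# needs to rotate again. 274 is the standard EXIF orientation tag id.
def _example_check_orientation(path="photo.jpg"):
    with open(path, "rb") as f:
        fixed = format_image_data(f.read())
    img = Image.open(io.BytesIO(fixed))
    return img.getexif().get(274, 1)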

def media_to_audio(data, max_duration=None, sample_rate=44100, channels=2, output_format="wav"):
    with tempfile.NamedTemporaryFile() as fin:
        fin.write(data)
        fin.flush()
        ds = ["-t", str(max_duration)] if max_duration is not None else []
        fmts = ["mp3", "libmp3lame"] if output_format == "mp3" else ["wav", "pcm_s16le"]
        cmd = ["ffmpeg", "-i", fin.name, *ds, "-f", fmts[0], "-acodec", fmts[1], "-ar", str(sample_rate), "-ac", str(channels), "pipe:1"]
        logger.info(f"media_to_audio cmd: {cmd}")
        p = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        assert p.returncode == 0, f"media to {output_format} failed: {p.stderr.decode()}"
        return p.stdout
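
# media_to_audio() shells out to ffmpeg and returns the encoded bytes from
# stdout. A minimal sketch extracting the first 10 seconds of a video as mono
# 16 kHz WAV (the input file name is a placeholder):
def _example_extract_audio(video_path="clip.mp4"):
    with open(video_path, "rb") as f:
        return media_to_audio(f.read(), max_duration=10, sample_rate=16000, channels=1, output_format="wav")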

def format_audio_data(data, max_duration=None):
    if len(data) < 4:
        raise ValueError("Audio file too short")
    data = media_to_audio(data, max_duration)
    waveform, sample_rate = torchaudio.load(io.BytesIO(data), num_frames=10)
    logger.info(f"load audio: {waveform.size()}, {sample_rate}")
    assert waveform.numel() > 0, "audio is empty"
    assert sample_rate > 0, "audio sample rate is not valid"
    return data

async def preload_data(inp, inp_type, typ, val):
    try:
        if typ == "url":
            timeout = int(os.getenv("REQUEST_TIMEOUT", "5"))
            data = await fetch_resource(val, timeout=timeout)
        elif typ == "base64":
            # Check if this is multiple base64 images (for i2i tasks)
            # Frontend now sends a list of base64 strings: ["base64string1", "base64string2", ...]
            if isinstance(val, list):
                data = {}
                for idx, encoded in enumerate(val):
                    if encoded.startswith("data:image"):
                        _, encoded = encoded.split(",", 1)
                    decoded = await asyncio.to_thread(base64.b64decode, encoded)
                    data[f"{inp}_{idx + 1}"] = decoded
            else:
                data = await asyncio.to_thread(base64.b64decode, val)
        # For a multi-person audio directory, val should be a dict with the file structure
        elif typ == "directory":
            data = {}
            for fname, b64_data in val.items():
                data[fname] = await asyncio.to_thread(base64.b64decode, b64_data)
            return {"type": "directory", "data": data}
        elif typ == "stream":
            # no bytes data need to be saved by data_manager
            data = None
        else:
            raise ValueError(f"cannot read {inp}[{inp_type}] which type is {typ}!")
        # check if valid image bytes
        if inp_type == "IMAGE":
            if isinstance(data, dict):
                for key, value in data.items():
                    data[key] = await asyncio.to_thread(format_image_data, value)
                return {"type": "directory", "data": data}
            else:
                data = await asyncio.to_thread(format_image_data, data)
        elif inp_type == "AUDIO":
            if typ != "stream" and typ != "directory":
                data = await asyncio.to_thread(format_audio_data, data)
        elif inp_type == "VIDEO":
            # Video data doesn't need special formatting, just validate it's not empty
            if len(data) == 0:
                raise ValueError("Video file is empty")
            logger.info(f"load video: {len(data)} bytes")
        else:
            raise Exception(f"cannot parse inp_type={inp_type} data")
        return data
    except Exception as e:
        raise ValueError(f"Failed to read {inp}, type={typ}, val={val[:100]}: {e}!")

async def load_inputs(params, raw_inputs, types):
    inputs_data = {}
    for inp in raw_inputs:
        item = params.pop(inp)
        bytes_data = await preload_data(inp, types[inp], item["type"], item["data"])
        # Handle multi-person audio directories and multiple images (for i2i tasks)
        if bytes_data is not None and isinstance(bytes_data, dict) and bytes_data.get("type") == "directory":
            fs = []
            for fname, fdata in bytes_data["data"].items():
                inputs_data[f"{inp}/{fname}"] = fdata
                fs.append(f"{inp}/{fname}")
            if "extra_inputs" not in params:
                params["extra_inputs"] = {}
            params["extra_inputs"][inp] = fs
        elif bytes_data is not None:
            inputs_data[inp] = bytes_data
        else:
            params[inp] = item
    return inputs_data

def check_params(params, raw_inputs, raw_outputs, types):
    stream_audio = os.getenv("STREAM_AUDIO", "0") == "1"
    stream_video = os.getenv("STREAM_VIDEO", "0") == "1"
    for x in raw_inputs + raw_outputs:
        if x in params and "type" in params[x]:
            if params[x]["type"] == "stream":
                if types[x] == "AUDIO":
                    assert stream_audio, "stream audio is not supported, please set env STREAM_AUDIO=1"
                elif types[x] == "VIDEO":
                    assert stream_video, "stream video is not supported, please set env STREAM_VIDEO=1"
            elif params[x]["type"] == "directory":
                # Multi-person audio directory is only supported for AUDIO type
                assert types[x] == "AUDIO", f"directory type is only supported for AUDIO input, got {types[x]}"

if __name__ == "__main__":
    # https://github.com/recurser/exif-orientation-examples
    exif_dir = "/data/nvme0/liuliang1/exif-orientation-examples"
    out_dir = "/data/nvme0/liuliang1/exif-orientation-examples/outs"
    os.makedirs(out_dir, exist_ok=True)
    for base_name in ["Landscape", "Portrait"]:
        for i in range(9):
            fin_name = os.path.join(exif_dir, f"{base_name}_{i}.jpg")
            fout_name = os.path.join(out_dir, f"{base_name}_{i}_formatted.jpg")
            logger.info(f"format image: {fin_name} -> {fout_name}")
            with open(fin_name, "rb") as f:
                data = f.read()
            data = format_image_data(data)
            with open(fout_name, "wb") as f:
                f.write(data)
lightx2v/deploy/common/va_controller.py
0 → 100644
import math
import os

import torch
import torch.distributed as dist
from loguru import logger

from lightx2v.models.runners.vsr.vsr_wrapper import compute_scaled_and_target_dims
from lightx2v_platform.base.global_var import AI_DEVICE


class NextControl:
    def __init__(self, action: str, data: any = None):
        # action: blank_to_voice, data: prev_video tensor
        # action: wait, data: None
        # action: fetch, data: None
        # action: switch_image, data: image_path
        # action: perform_action, data: action prompt
        self.action = action
        self.data = data

class VAController:
    def __init__(self, model_runner):
        self.reader = None
        self.recorder = None
        self.rank = 0
        self.world_size = 1
        if dist.is_initialized():
            self.rank = dist.get_rank()
            self.world_size = dist.get_world_size()
        self.target_reader_rank = int(os.getenv("READER_RANK", "0")) % self.world_size
        self.target_recorder_rank = int(os.getenv("RECORDER_RANK", "0")) % self.world_size
        self.init_base(model_runner.config, model_runner.input_info, model_runner.vfi_model is not None, model_runner.vsr_model is not None)
        self.init_recorder()
        self.init_reader(model_runner)

    def init_base(self, config, input_info, has_vfi_model, has_vsr_model):
        if "stream_config" in input_info.__dataclass_fields__:
            self.stream_config = input_info.stream_config
            logger.info(f"VAController init base with stream config: {self.stream_config}")
        self.audio_path = input_info.audio_path
        self.output_video_path = input_info.save_result_path
        if isinstance(self.output_video_path, dict):
            self.output_video_path = self.output_video_path["data"]
        self.audio_sr = config.get("audio_sr", 16000)
        self.target_fps = config.get("target_fps", 16)
        self.max_num_frames = config.get("target_video_length", 81)
        self.prev_frame_length = config.get("prev_frame_length", 5)
        self.record_fps = config.get("target_fps", 16)
        if "video_frame_interpolation" in config and has_vfi_model:
            self.record_fps = config["video_frame_interpolation"]["target_fps"]
        self.record_fps = config.get("record_fps", self.record_fps)
        self.tgt_h = input_info.target_shape[0]
        self.tgt_w = input_info.target_shape[1]
        self.record_h, self.record_w = self.tgt_h, self.tgt_w
        if "video_super_resolution" in config and has_vsr_model:
            _, _, self.record_w, self.record_h = compute_scaled_and_target_dims(
                self.record_w,
                self.record_h,
                scale=config["video_super_resolution"]["scale"],
                multiple=128,
            )
        # how many frames to publish to the stream as one batch
        self.slice_frame = config.get("slice_frame", self.prev_frame_length)
        # estimate the max infer seconds, for immediate switch with local omni
        slice_interval = self.slice_frame / self.record_fps
        est_max_infer_secs = config.get("est_max_infer_secs", 0.6)
        est_max_switch_image_secs = config.get("est_max_switch_image_secs", 0)
        est_max_switch_action_secs = config.get("est_max_switch_action_secs", 0)
        self.est_infer_end_idx = math.ceil(est_max_infer_secs / slice_interval)
        self.est_switch_image_end_idx = math.ceil(est_max_switch_image_secs / slice_interval)
        self.est_switch_action_end_idx = math.ceil(est_max_switch_action_secs / slice_interval)
        max_end_idx = max(self.est_infer_end_idx, self.est_switch_image_end_idx, self.est_switch_action_end_idx)
        self.min_stay_queue_num = max_end_idx * 2 + 1
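
    # Worked example of the sizing above, using the defaults in this file:
    # slice_frame=5 and record_fps=16 give slice_interval = 5/16 = 0.3125 s per
    # published batch; est_max_infer_secs=0.6 gives est_infer_end_idx =
    # ceil(0.6 / 0.3125) = 2; with both switch estimates at 0, max_end_idx = 2,
    # so min_stay_queue_num = 2 * 2 + 1 = 5 batches must be queued before
    # inference can be skipped.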

    def init_recorder(self):
        if not self.output_video_path or self.rank != self.target_recorder_rank:
            return
        logger.info(f"Rank {self.rank} init recorder with: {self.output_video_path}")
        whip_shared_path = os.getenv("WHIP_SHARED_LIB", None)
        if whip_shared_path and self.output_video_path.startswith("http"):
            from lightx2v.deploy.common.va_recorder_x264 import X264VARecorder

            self.recorder = X264VARecorder(
                whip_shared_path=whip_shared_path,
                livestream_url=self.output_video_path,
                fps=self.record_fps,
                sample_rate=self.audio_sr,
                slice_frame=self.slice_frame,
                prev_frame=self.prev_frame_length,
            )
        else:
            from lightx2v.deploy.common.va_recorder import VARecorder

            self.recorder = VARecorder(
                livestream_url=self.output_video_path,
                fps=self.record_fps,
                sample_rate=self.audio_sr,
                slice_frame=self.slice_frame,
                prev_frame=self.prev_frame_length,
                stream_config=self.stream_config,
            )

    def init_reader(self, model_runner=None):
        if not isinstance(self.audio_path, dict):
            return
        assert self.audio_path["type"] == "stream", f"unexpected audio_path: {self.audio_path}"
        segment_duration = self.max_num_frames / self.target_fps
        prev_duration = self.prev_frame_length / self.target_fps
        omni_work_dir = os.getenv("OMNI_WORK_DIR", None)
        if omni_work_dir:
            from lightx2v.deploy.common.va_reader_omni import OmniVAReader

            self.reader = OmniVAReader(
                rank=self.rank,
                world_size=self.world_size,
                stream_url=self.audio_path["data"],
                sample_rate=self.audio_sr,
                segment_duration=segment_duration,
                prev_duration=prev_duration,
                target_rank=self.target_reader_rank,
                model_runner=model_runner,
                huoshan_tts_voice_type=self.audio_path.get("huoshan_tts_voice_type", None),
                stream_config=self.stream_config,
                va_recorder=self.recorder,
            )
        else:
            from lightx2v.deploy.common.va_reader import VAReader

            self.reader = VAReader(
                rank=self.rank,
                world_size=self.world_size,
                stream_url=self.audio_path["data"],
                sample_rate=self.audio_sr,
                segment_duration=segment_duration,
                prev_duration=prev_duration,
                target_rank=self.target_reader_rank,
            )
def
start
(
self
):
self
.
reader
.
start
()
if
self
.
rank
==
self
.
target_recorder_rank
:
assert
self
.
recorder
is
not
None
,
f
"recorder is required for stream audio input for rank
{
self
.
rank
}
"
self
.
recorder
.
start
(
self
.
record_w
,
self
.
record_h
)
if
self
.
world_size
>
1
:
dist
.
barrier
()
def
next_control
(
self
):
from
lightx2v.deploy.common.va_reader_omni
import
OmniVAReader
if
isinstance
(
self
.
reader
,
OmniVAReader
):
action_control
=
self
.
omni_reader_action_control
()
if
action_control
is
not
None
:
return
action_control
image_control
=
self
.
omni_reader_image_control
()
if
image_control
is
not
None
:
return
image_control
return
self
.
omni_reader_next_control
()
return
NextControl
(
action
=
"fetch"
)
def
before_control
(
self
):
from
lightx2v.deploy.common.va_reader_omni
import
OmniVAReader
if
isinstance
(
self
.
reader
,
OmniVAReader
):
self
.
len_tensor
=
torch
.
tensor
([
0
],
dtype
=
torch
.
int32
,
device
=
AI_DEVICE
)
self
.
flag_tensor
=
torch
.
tensor
([
0
],
dtype
=
torch
.
int32
,
device
=
AI_DEVICE
)
self
.
prev_tensor
=
torch
.
zeros
((
1
,
3
,
self
.
prev_frame_length
,
self
.
tgt_h
,
self
.
tgt_w
),
dtype
=
torch
.
float
,
device
=
AI_DEVICE
)
def
omni_reader_next_control
(
self
):
immediate_switch
=
self
.
reader
.
get_immediate_switch
()
if
immediate_switch
==
1
:
# truncate the stream buffer to keep the max infer time length
# and broadcast the prev video tensor to all ranks
if
self
.
rank
==
self
.
target_recorder_rank
:
logger
.
warning
(
f
"runner recv immediate switch, truncate stream buffer"
)
video_tensor
=
self
.
recorder
.
truncate_stream_buffer
(
self
.
est_infer_end_idx
)
if
video_tensor
is
not
None
:
self
.
flag_tensor
.
fill_
(
1
)
self
.
prev_tensor
.
copy_
(
video_tensor
)
else
:
self
.
flag_tensor
.
fill_
(
0
)
dist
.
broadcast
(
self
.
flag_tensor
,
src
=
self
.
target_recorder_rank
)
if
self
.
flag_tensor
.
item
()
==
1
:
dist
.
broadcast
(
self
.
prev_tensor
,
src
=
self
.
target_recorder_rank
)
return
NextControl
(
action
=
"blank_to_voice"
,
data
=
self
.
prev_tensor
)
else
:
# get the length of stream buffer, broadcast to all ranks
if
self
.
rank
==
self
.
target_recorder_rank
:
stream_buffer_length
=
self
.
recorder
.
get_buffer_stream_size
()
self
.
len_tensor
.
copy_
(
stream_buffer_length
)
dist
.
broadcast
(
self
.
len_tensor
,
src
=
self
.
target_recorder_rank
)
buffer_length
=
self
.
len_tensor
.
item
()
# stream buffer is enough, skip infer
if
buffer_length
>=
self
.
min_stay_queue_num
:
return
NextControl
(
action
=
"wait"
)
return
NextControl
(
action
=
"fetch"
)
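    # Collective-call sketch for the method above: every rank first learns the
    # same immediate_switch value (broadcast inside reader.get_immediate_switch()).
    # On a switch, only target_recorder_rank truncates the buffer; the 1-int
    # flag broadcast then tells the other ranks whether a prev video tensor
    # follows. Otherwise a 1-int length broadcast lets all ranks agree on
    # "wait" vs "fetch". Keeping the number of broadcasts identical on every
    # rank is what prevents the NCCL collectives from deadlocking.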
    def omni_reader_image_control(self):
        image_switch = self.reader.get_image_switch()
        if not isinstance(image_switch, str) or len(image_switch) == 0:
            return None
        if not os.path.exists(image_switch):
            logger.warning(f"Switch image path {image_switch} does not exist")
            return None
        # truncate the stream buffer to keep the max infer time length
        if self.rank == self.target_recorder_rank:
            logger.warning("runner recv image switch, truncate stream buffer")
            self.recorder.truncate_stream_buffer(self.est_switch_image_end_idx)
        return NextControl(action="switch_image", data=image_switch)

    def omni_reader_action_control(self):
        action_switch = self.reader.get_action_switch()
        if not isinstance(action_switch, str) or len(action_switch) == 0:
            return None
        # truncate the stream buffer to keep the max infer time length
        if self.rank == self.target_recorder_rank:
            logger.warning("runner recv action switch, truncate stream buffer")
            self.recorder.truncate_stream_buffer(self.est_switch_action_end_idx)
        return NextControl(action="perform_action", data=action_switch)

    def pub_livestream(self, images: torch.Tensor, audios: torch.Tensor, gen_video: torch.Tensor, valid_duration=1e9):
        if self.recorder.realtime:
            self.recorder.buffer_stream(images, audios, gen_video, valid_duration=valid_duration)
        else:
            self.recorder.pub_livestream(images, audios)

    def clear(self):
        self.len_tensor = None
        self.flag_tensor = None
        self.prev_tensor = None
        if self.reader is not None:
            try:
                self.reader.stop()
            except Exception as e:
                logger.warning(f"Error stopping reader: {e}")
            self.reader = None
        if self.recorder is not None:
            try:
                self.recorder.stop()
            except Exception as e:
                logger.warning(f"Error stopping recorder: {e}")
            self.recorder = None

    def __del__(self):
        self.clear()
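Taken together, the methods above give the controller a simple lifecycle. A minimal driver-loop sketch is shown below; the names `controller`, `runner`, `running`, and `fetch_and_infer` are illustrative assumptions, not part of this repository:

    controller.init_recorder()                   # rank-gated: only target_recorder_rank gets a recorder
    controller.init_reader(model_runner=runner)  # OmniVAReader if OMNI_WORK_DIR is set, else VAReader
    controller.start()                           # starts reader/recorder, then dist.barrier()
    controller.before_control()                  # allocates the broadcast tensors for omni readers
    try:
        while running:
            ctrl = controller.next_control()
            if ctrl.action == "wait":            # enough slices buffered, skip this inference step
                time.sleep(0.05)
            else:                                # "fetch", "blank_to_voice", "switch_image", "perform_action"
                frames, audios, gen_video = fetch_and_infer(ctrl)
                controller.pub_livestream(frames, audios, gen_video)
    finally:
        controller.clear()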
lightx2v/deploy/common/va_reader.py (new file, mode 100644)
import os
import queue
import signal
import subprocess
import threading
import time
import traceback

import numpy as np
import torch
import torch.distributed as dist
from loguru import logger


class VAReader:
    def __init__(
        self,
        rank: int,
        world_size: int,
        stream_url: str,
        segment_duration: float = 5.0,
        sample_rate: int = 16000,
        audio_channels: int = 1,
        buffer_size: int = 1,
        prev_duration: float = 0.3125,
        target_rank: int = 0,
    ):
        self.rank = rank
        self.world_size = world_size
        self.stream_url = stream_url
        self.segment_duration = segment_duration
        self.sample_rate = sample_rate
        self.audio_channels = audio_channels
        self.prev_duration = prev_duration
        # int16 = 2 bytes
        self.chunk_size = int(self.segment_duration * self.sample_rate) * 2
        self.prev_size = int(self.prev_duration * self.sample_rate) * 2
        self.prev_chunk = None
        self.buffer_size = buffer_size
        self.audio_queue = queue.Queue(maxsize=self.buffer_size)
        self.audio_thread = None
        self.ffmpeg_process = None
        self.bytes_buffer = bytearray()
        self.target_rank = target_rank % self.world_size
        self.flag_tensor = torch.tensor([0], dtype=torch.int32).to(device="cuda")
        self.audio_tensor = torch.zeros(self.chunk_size, dtype=torch.uint8, device="cuda")
        logger.info(f"VAReader initialized for stream: {stream_url} target_rank: {self.target_rank}")
        logger.info(f"Audio duration per chunk: {segment_duration}s, sample rate: {sample_rate}Hz")

    def start(self):
        if self.rank == self.target_rank:
            if self.stream_url.startswith("rtmp://"):
                self.start_ffmpeg_process_rtmp()
            elif self.stream_url.startswith("http"):
                self.start_ffmpeg_process_whep()
            else:
                raise Exception(f"Unsupported stream URL: {self.stream_url}")
            self.audio_thread = threading.Thread(target=self.audio_worker, daemon=True)
            self.audio_thread.start()
            logger.info(f"VAReader {self.rank}/{self.world_size} started successfully")
        else:
            logger.info(f"VAReader {self.rank}/{self.world_size} wait only")
        if self.world_size > 1:
            logger.info(f"VAReader {self.rank}/{self.world_size} wait barrier")
            dist.barrier()
            logger.info(f"VAReader {self.rank}/{self.world_size} end barrier")

    def start_ffmpeg_process_rtmp(self):
        """Start an ffmpeg process that reads audio from the stream."""
        ffmpeg_cmd = [
            "ffmpeg",
            "-i",
            self.stream_url,
            "-vn",
            # "-acodec",
            # "pcm_s16le",
            "-ar",
            str(self.sample_rate),
            "-ac",
            str(self.audio_channels),
            "-f",
            "s16le",
            "-",
        ]
        try:
            self.ffmpeg_process = subprocess.Popen(ffmpeg_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, bufsize=0)
            logger.info(f"FFmpeg audio pull process started with PID: {self.ffmpeg_process.pid}")
            logger.info(f"FFmpeg command: {' '.join(ffmpeg_cmd)}")
        except Exception as e:
            logger.error(f"Failed to start FFmpeg process: {e}")
            raise

    def start_ffmpeg_process_whep(self):
        """Start a GStreamer pipeline that reads audio from the WHEP stream."""
        ffmpeg_cmd = [
            "gst-launch-1.0",
            "-q",
            "whepsrc",
            f"whep-endpoint={self.stream_url}",
            "video-caps=none",
            "!rtpopusdepay",
            "!opusdec",
            "plc=false",
            "!audioconvert",
            "!audioresample",
            f"!audio/x-raw,format=S16LE,channels={self.audio_channels},rate={self.sample_rate}",
            "!fdsink",
            "fd=1",
        ]
        try:
            self.ffmpeg_process = subprocess.Popen(
                ffmpeg_cmd,
                stdout=subprocess.PIPE,
                # stderr=subprocess.PIPE,
                bufsize=0,
            )
            logger.info(f"FFmpeg audio pull process started with PID: {self.ffmpeg_process.pid}")
            logger.info(f"FFmpeg command: {' '.join(ffmpeg_cmd)}")
        except Exception as e:
            logger.error(f"Failed to start FFmpeg process: {e}")
            raise

    def audio_worker(self):
        logger.info("Audio pull worker thread started")
        try:
            while True:
                if not self.ffmpeg_process or self.ffmpeg_process.poll() is not None:
                    logger.warning("FFmpeg process exited, audio worker thread stopped")
                    break
                self.fetch_audio_data()
                time.sleep(0.01)
        except:  # noqa
            logger.error(f"Audio pull worker error: {traceback.format_exc()}")
        finally:
            logger.warning("Audio pull worker thread stopped")

    def fetch_audio_data(self):
        """Fetch audio data from the ffmpeg process."""
        try:
            audio_bytes = self.ffmpeg_process.stdout.read(self.chunk_size)
            if not audio_bytes:
                return
            self.bytes_buffer.extend(audio_bytes)
            # logger.info(f"Fetch audio data: {len(audio_bytes)} bytes, bytes_buffer: {len(self.bytes_buffer)} bytes")
            if len(self.bytes_buffer) >= self.chunk_size:
                audio_data = self.bytes_buffer[: self.chunk_size]
                self.bytes_buffer = self.bytes_buffer[self.chunk_size :]
                # first chunk, read original 81 frames
                # for other chunks, read 81 - 5 = 76 frames, concat with previous 5 frames
                if self.prev_chunk is None:
                    logger.info(f"change chunk_size: from {self.chunk_size} to {self.chunk_size - self.prev_size}")
                    self.chunk_size -= self.prev_size
                else:
                    audio_data = self.prev_chunk + audio_data
                self.prev_chunk = audio_data[-self.prev_size :]
                try:
                    self.audio_queue.put_nowait(audio_data)
                except queue.Full:
                    logger.warning(f"Audio queue full: {self.audio_queue.qsize()}, discarded oldest chunk")
                    self.audio_queue.get_nowait()
                    self.audio_queue.put_nowait(audio_data)
                logger.info(f"Put audio data: {len(audio_data)} bytes, audio_queue: {self.audio_queue.qsize()}, chunk_size: {self.chunk_size}")
        except:  # noqa
            logger.error(f"Fetch audio data error: {traceback.format_exc()}")
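    # Chunk-size arithmetic for the defaults used elsewhere in this repo
    # (segment_duration = 81/16 s, prev_duration = 5/16 s, 16 kHz s16le mono):
    #   chunk_size = int(5.0625 * 16000) * 2 = 162000 bytes (81000 samples)
    #   prev_size  = int(0.3125 * 16000) * 2 = 10000 bytes  (5000 samples)
    # After the first chunk, chunk_size shrinks to 152000 bytes, so each later
    # segment = carried 5000 samples + fresh 76000 samples = 81000 samples.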
    def broadcast_audio_data(self, audio_data):
        if self.rank == self.target_rank:
            if audio_data is None:
                self.flag_tensor.fill_(0)
            else:
                self.flag_tensor.fill_(1)
                self.audio_tensor.copy_(torch.frombuffer(bytearray(audio_data), dtype=torch.uint8))
                logger.info(f"rank {self.rank} send audio_tensor: {self.audio_tensor.shape}")
        dist.broadcast(self.flag_tensor, src=self.target_rank)
        if self.flag_tensor.item() == 0:
            return None
        dist.broadcast(self.audio_tensor, src=self.target_rank)
        if self.rank != self.target_rank:
            logger.info(f"rank {self.rank} recv audio_tensor: {self.audio_tensor.shape}")
            audio_data = self.audio_tensor.cpu().numpy().tobytes()
        return audio_data
    def bytes_to_ndarray(self, audio_data):
        if audio_data is None:
            return None
        audio_data = np.frombuffer(audio_data, dtype=np.int16)
        audio_data = audio_data.astype(np.float32) / 32768.0
        logger.info(f"Got segment audio rank={self.rank}: {audio_data.shape} {audio_data.dtype} {audio_data.min()} {audio_data.max()}")
        return audio_data

    def get_audio_segment(self, timeout: float = 1.0, fetch_duration: float = None, prev_duration: float = None):
        if fetch_duration is not None and self.segment_duration != fetch_duration:
            logger.warning(f"ignore fetch_duration, {fetch_duration} != {self.segment_duration}")
        if prev_duration is not None and self.prev_duration != prev_duration:
            raise ValueError(f"prev_duration {prev_duration} != {self.prev_duration}")
        audio_data = None
        if self.rank == self.target_rank:
            try:
                audio_data = self.audio_queue.get(timeout=timeout)
            except:  # noqa
                logger.warning(f"Failed to get audio segment: {traceback.format_exc()}")
        if self.world_size > 1:
            audio_data = self.broadcast_audio_data(audio_data)
        audio_data = self.bytes_to_ndarray(audio_data)
        return audio_data, self.segment_duration

    def stop(self):
        # Stop ffmpeg process
        if self.ffmpeg_process:
            self.ffmpeg_process.send_signal(signal.SIGINT)
            try:
                self.ffmpeg_process.wait(timeout=5)
            except subprocess.TimeoutExpired:
                self.ffmpeg_process.kill()
            logger.warning("FFmpeg reader process stopped")
        # Wait for threads to finish
        if self.audio_thread and self.audio_thread.is_alive():
            self.audio_thread.join(timeout=5)
            if self.audio_thread.is_alive():
                logger.error("Audio pull thread did not stop gracefully")
        while self.audio_queue and self.audio_queue.qsize() > 0:
            self.audio_queue.get_nowait()
        self.audio_queue = None
        logger.warning("Audio pull queue cleaned")

    def __del__(self):
        self.stop()
if __name__ == "__main__":
    WORLD_SIZE = int(os.environ.get("WORLD_SIZE", 1))
    RANK = int(os.environ.get("RANK", 0))
    if WORLD_SIZE > 1:
        dist.init_process_group(backend="nccl")
        torch.cuda.set_device(dist.get_rank())
        logger.info(f"Distributed initialized: rank={RANK}, world_size={WORLD_SIZE}")
    reader = VAReader(
        RANK,
        WORLD_SIZE,
        # "rtmp://localhost/live/test_audio",
        "https://reverse.st-oc-01.chielo.org/10.5.64.49:8000/rtc/v1/whep/?app=live&stream=ll_test_audio&eip=10.120.114.76:8000",
        segment_duration=1.0,
        sample_rate=16000,
        audio_channels=1,
        prev_duration=1 / 16,
    )
    reader.start()
    fail_count = 0
    max_fail_count = 2
    try:
        while True:
            # get_audio_segment returns an (ndarray_or_None, segment_duration) tuple
            audio_data, _ = reader.get_audio_segment(timeout=2)
            if audio_data is not None:
                # logger.info(f"Got audio chunk, shape: {audio_data.shape}, range: [{audio_data.min()}, {audio_data.max()}]")
                fail_count = 0
            else:
                fail_count += 1
                if fail_count > max_fail_count:
                    logger.warning("Failed to get audio chunk, stop reader")
                    reader.stop()
                    break
            time.sleep(0.95)
    finally:
        reader.stop()
lightx2v/deploy/common/va_reader_omni.py (new file, mode 100644)
import datetime
import json
import os
import random
import subprocess
import threading
import time
import traceback
from collections import deque
from copy import deepcopy

import jsonschema
import numpy as np
import torch
import torch.distributed as dist
import zmq
from loguru import logger

try:
    from bson import BSON
except ImportError:
    BSON = None
    logger.warning("BSON is not installed")

from scipy.signal import resample


class AudioInfo:
    def __init__(self, info: dict):
        self.sample_count = info["sample_count"]
        self.sample_rate = info["sample_rate"]
        self.channel_count = info["channel_count"]
        self.sample_fmt = info["sample_fmt"]
        self.pts = info["pts"]

    def is_spec_equal(self, other: "AudioInfo") -> bool:
        return self.sample_fmt == other.sample_fmt and self.sample_rate == other.sample_rate and self.channel_count == other.channel_count

    def duration(self) -> datetime.timedelta:
        return datetime.timedelta(seconds=self.sample_count / self.sample_rate)

    def __str__(self):
        return "AudioInfo(sample_count={}, sample_rate={}, channel_count={}, sample_fmt={}, pts={})".format(
            self.sample_count, self.sample_rate, self.channel_count, self.sample_fmt, self.pts
        )


class ByteBuffer:
    def __init__(self):
        self.buffer = deque()
        self.current_size = 0
        # is the audio belonging to current turn finished
        self.audio_finished = False

    def add(self, byte_data: bytes):
        self.buffer.append(byte_data)
        self.current_size += len(byte_data)

    def get(self, size=1024):
        data = bytearray()
        while size > 0 and len(self.buffer) > 0:
            chunk = self.buffer.popleft()
            if len(chunk) <= size:
                # chunk fits entirely in the request, consume it whole
                data.extend(chunk)
                self.current_size -= len(chunk)
                size -= len(chunk)
            else:
                # chunk is larger than the request: take a prefix and push
                # the remainder back to the front of the buffer
                data.extend(chunk[:size])
                self.buffer.appendleft(chunk[size:])
                self.current_size -= size
                size = 0
        return bytes(data)

    def mark_finished(self):
        self.audio_finished = True

    def has_more_voice(self):
        return not self.audio_finished

    def __len__(self):
        return self.current_size
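    # Example of the partial-read semantics above:
    #   buf = ByteBuffer()
    #   buf.add(b"abcd"); buf.add(b"ef")
    #   buf.get(3)   # -> b"abc"; b"d" is pushed back, len(buf) == 3
    #   buf.get(10)  # -> b"def"; a short read once the buffer drains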
class ChatAdapter:
    def __init__(
        self,
        omni_work_dir: str,
        whep_url: str,
        session_id: str,
        account: str,
        config_files: list[str],
        config_schema_path: str,
        seg_duration: float,
        model_runner,
        huoshan_tts_voice_type,
        stream_config: dict,
    ):
        assert os.path.exists(omni_work_dir), f"OMNI work directory {omni_work_dir} does not exist"
        self.omni_work_dir = omni_work_dir
        self.stream_config = stream_config
        self.context = zmq.Context()
        self.w2f_socket = self.context.socket(zmq.PULL)
        self.w2f_url = ChatAdapter.select_and_bind(self.w2f_socket)
        self.f2w_socket = self.context.socket(zmq.PUSH)
        self.f2w_url = ChatAdapter.select_and_bind(self.f2w_socket)
        self.recv_thread = None
        self.audio_buffer = ByteBuffer()
        self.audio_info = None
        self.chat_server_cmd = [
            os.path.join(self.omni_work_dir, "bin", "seko-chatter"),
            "--session-id",
            session_id,
            "--account",
            account,
            "--whep-server-url",
            whep_url,
            "--w2f-endpoint",
            self.w2f_url,
            "--f2w-endpoint",
            self.f2w_url,
            "--config-files",
            *config_files,
        ]
        override_config = {}
        if huoshan_tts_voice_type is not None:
            logger.info(f"Use Huoshan TTS voice type: {huoshan_tts_voice_type}")
            override_config["TTS"] = {
                "default_voice_info": {
                    "voice_type": huoshan_tts_voice_type,
                    "provider": "huoshan_stream_tts",
                }
            }
        system_prompt = stream_config.get("system_prompt", "")
        if system_prompt:
            override_config["model"] = {"system_prompt": system_prompt}
            logger.info(f"Omni use custom system prompt: {system_prompt}")
        with open(config_schema_path, "r") as f:
            schema = json.load(f)
        jsonschema.validate(instance=override_config, schema=schema)
        if override_config is not None:
            self.chat_server_cmd.extend(["--override-config", json.dumps(override_config)])
        self.chatter_proc = None
        self.seg_duration = seg_duration
        self.reset_prev = False
        self.status = "blank"
        self.immediate_switch = 0
        self.image_switch = ""
        self.action_switch = ""
        self.model_runner = model_runner

    def launch_chat_server(self):
        env = {
            "RUST_LOG": "info,duplex_server=debug,backend_5o=debug",
            "LD_LIBRARY_PATH": os.environ.get("LD_LIBRARY_PATH", "") + ":" + os.path.join(self.omni_work_dir, "lib/"),
            "PATH": os.environ["PATH"] + ":" + os.path.join(self.omni_work_dir, "bin/"),
        }
        self.chatter_proc = subprocess.Popen(self.chat_server_cmd, env=env, cwd=self.omni_work_dir)

    @staticmethod
    def select_and_bind(socket: zmq.Socket) -> str:
        # randomly select a port between 1024 and 65535
        retry_count = 20
        err = None
        while retry_count > 0:
            try:
                port = random.randint(1024, 65535)
                # port = 5555
                url = f"tcp://localhost:{port}"
                socket.bind(url)
                return url
            except zmq.error.ZMQError as e:
                retry_count -= 1
                err = e
        raise err
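    # The chatter binary is assumed to connect() the mirror-image socket types:
    # a zmq PUSH into w2f_url (whose messages recv_loop below consumes) and a
    # zmq PULL from f2w_url. A hypothetical Python stand-in for that peer, for
    # local testing only:
    #   ctx = zmq.Context()
    #   push = ctx.socket(zmq.PUSH)
    #   push.connect(w2f_url)
    #   push.send(BSON.encode({"type": "AgentStartPlay"}))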
    # immediate switch to status, discard prev bytes, set immediate_switch to 1
    def immediate_switch_to(self, status):
        logger.warning(f"VA reader immediate switch to {status}")
        self.reset_prev = True
        self.status = status
        self.immediate_switch = 1
        # only no action switch can be paused immediately
        if self.model_runner is not None and self.model_runner.can_pause:
            self.model_runner.pause_signal = True
            logger.warning("Model runner pause signal set to True")

    def set_image_switch(self, image_path):
        logger.warning(f"Setting image switch: {image_path}")
        self.image_switch = image_path
        # only blank status and no action switch can be paused immediately
        if self.model_runner is not None and self.model_runner.can_pause:
            self.model_runner.pause_signal = True
            logger.warning("Model runner set pause signal for image switch & blank status")

    def set_action_switch(self, prompt):
        logger.warning(f"Setting action switch: {prompt}")
        self.action_switch = prompt
        # only blank status can be paused immediately
        if self.model_runner is not None and self.model_runner.can_pause:
            self.model_runner.pause_signal = True
            logger.warning("Model runner set pause signal for action switch & blank status")

    def recv_loop(self):
        while True:
            try:
                message = self.w2f_socket.recv()
            except Exception:
                logger.error(f"Error receiving message: {traceback.format_exc()}")
                break
            try:
                message = BSON.decode(message)
                msg_type = message["type"]
                logger.debug("Received message type: {}".format(msg_type))
                if msg_type == "AgentAudio":
                    audio = message["audio"]
                    if audio["type"] != "Pcm":
                        logger.error("Unsupported audio type: {}".format(audio["type"]))
                        continue
                    pcm_data = audio["data"]
                    audio_info = AudioInfo(audio["info"])
                    logger.debug("Received audio with duration: {}".format(audio_info.duration()))
                    if self.audio_info is None:
                        self.audio_info = audio_info
                    else:
                        # check if the audio info is the same
                        if not self.audio_info.is_spec_equal(audio_info):
                            raise ValueError("Audio info mismatch")
                    self.audio_buffer.add(pcm_data)
                    # if status is blank and has voice, set immediate switch to 1
                    if self.status == "blank" and self.has_voice(self.seg_duration):
                        self.immediate_switch_to("voice")
                elif msg_type == "AgentStartPlay":
                    logger.debug("Received AgentStartPlay, create new audio buffer")
                    self.audio_buffer = ByteBuffer()
                elif msg_type == "AgentEndPlay":
                    logger.debug("Received AgentEndPlay, mark audio finished")
                    self.audio_buffer.mark_finished()
                elif msg_type == "ClearAgentAudio":
                    logger.warning("Received ClearAgentAudio, clear audio buffer")
                    self.audio_buffer = None
                    self.audio_info = None
                    if self.status == "voice":
                        self.status = "blank"
                        # self.immediate_switch_to("blank")
            except Exception as e:
                logger.error("Error decoding message: {}, continue".format(e))
                continue
        logger.warning("recv loop interrupted")
    def start(self):
        self.launch_chat_server()
        self.recv_thread = threading.Thread(target=self.recv_loop)
        self.recv_thread.start()

    def has_voice(self, duration):
        # returns False when not enough audio is buffered, otherwise the
        # number of bytes that `duration` seconds of audio occupies
        if self.audio_info is None or self.audio_buffer.current_size == 0:
            return False
        bytes_count = round(duration * self.audio_info.sample_rate) * self.audio_info.channel_count * 2  # S16LE assumed
        # if there are not enough bytes and more voice may still arrive, return False
        if self.audio_buffer.current_size < bytes_count and self.audio_buffer.has_more_voice():
            logger.warning(f"Not enough bytes and maybe has more voice, content_size: {self.audio_buffer.current_size}, bytes_count: {bytes_count}")
            return False
        return bytes_count

    def get_audio(self, fetch_duration) -> tuple[bytes, AudioInfo]:
        bytes_count = self.has_voice(fetch_duration)
        if bytes_count is False or self.audio_info is None:
            return None
        pcm_data = self.audio_buffer.get(bytes_count)
        # the actual sample count fetched
        sample_count = len(pcm_data) // (self.audio_info.channel_count * 2)
        logger.debug("Fetched {} samples of audio".format(sample_count))
        logger.debug("After fetch, there are {} bytes left".format(self.audio_buffer.current_size))
        audio_info = deepcopy(self.audio_info)
        audio_info.sample_count = sample_count
        return (pcm_data, audio_info)
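    # Example: for fetch_duration = 81/16 s with 48 kHz stereo S16LE audio,
    #   bytes_count = round(5.0625 * 48000) * 2 * 2 = 972000 bytes,
    # and a full fetch yields sample_count = 972000 // (2 * 2) = 243000.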
    def stop(self):
        self.model_runner = None
        if self.chatter_proc is not None:
            self.chatter_proc.terminate()
            self.chatter_proc.wait()
            self.chatter_proc = None
        self.w2f_socket.close()
        self.f2w_socket.close()

    def __del__(self):
        self.stop()
class OmniVAReader:
    def __init__(
        self,
        rank: int,
        world_size: int,
        stream_url: str,
        segment_duration: float = 5.0625,
        sample_rate: int = 16000,
        audio_channels: int = 1,
        buffer_size: int = 1,
        prev_duration: float = 0.3125,
        target_rank: int = 0,
        model_runner=None,
        huoshan_tts_voice_type=None,
        stream_config: dict = {},
        **kwargs,
    ):
        self.rank = rank
        self.world_size = world_size
        self.stream_url = stream_url
        self.segment_duration = segment_duration
        self.sample_rate = sample_rate
        self.audio_channels = audio_channels
        self.prev_duration = prev_duration
        self.all_seg_sample_count = int(self.segment_duration * self.sample_rate)
        self.prev_seg_sample_count = int(self.prev_duration * self.sample_rate)
        self.prev_seg_chunk = None
        self.target_rank = target_rank % self.world_size
        self.flag_tensor = torch.tensor([0], dtype=torch.int32).to(device="cuda")
        self.valid_duration_tensor = torch.tensor([0], dtype=torch.float32).to(device="cuda")
        self.immediate_switch_tensor = torch.tensor([0], dtype=torch.int32).to(device="cuda")
        chunk_size = int(self.segment_duration * self.sample_rate) * 2
        self.audio_tensor = torch.zeros(chunk_size, dtype=torch.uint8, device="cuda")
        self.chat_adapter = None
        self.model_runner = model_runner
        self.huoshan_tts_voice_type = huoshan_tts_voice_type
        self.stream_config = stream_config
        assert self.audio_channels == 1, "Only mono audio is supported for OmniVAReader"
        logger.info(f"OmniVAReader initialized for stream: {stream_url} target_rank: {self.target_rank}")
        logger.info(f"Audio duration per chunk: {segment_duration}s, sample rate: {sample_rate}Hz")

    def init_omni_env(self):
        self.omni_work_dir = os.getenv("OMNI_WORK_DIR", "/path/of/seko_chatter/")
        self.session_id = os.getenv("OMNI_SESSION_ID", "")
        self.account = os.getenv("OMNI_ACCOUNT", "")
        self.config_files = os.getenv("OMNI_CONFIG_FILES", "").split(",")
        self.config_schema_path = os.getenv("OMNI_CONFIG_SCHEMA_PATH", None)
        assert os.path.exists(self.omni_work_dir), f"OMNI work directory {self.omni_work_dir} does not exist"
        assert self.session_id and self.account, "OMNI_SESSION_ID and OMNI_ACCOUNT are required"
        logger.info(
            f"OMNI work directory: {self.omni_work_dir}, session_id: {self.session_id}, account: {self.account}, "
            f"config_files: {self.config_files}, config_schema_path: {self.config_schema_path}"
        )
    def start(self):
        if self.rank == self.target_rank:
            self.init_omni_env()
            assert self.stream_url.startswith("http"), "Only HTTP stream is supported for OmniVAReader"
            self.chat_adapter = ChatAdapter(
                omni_work_dir=self.omni_work_dir,
                whep_url=self.stream_url,
                session_id=self.session_id,
                account=self.account,
                config_files=self.config_files,
                config_schema_path=self.config_schema_path,
                seg_duration=self.segment_duration,
                model_runner=self.model_runner,
                huoshan_tts_voice_type=self.huoshan_tts_voice_type,
                stream_config=self.stream_config,
            )
            self.chat_adapter.start()
            logger.info(f"OmniVAReader {self.rank}/{self.world_size} started successfully")
        else:
            logger.info(f"OmniVAReader {self.rank}/{self.world_size} wait only")
        if self.world_size > 1:
            logger.info(f"OmniVAReader {self.rank}/{self.world_size} wait barrier")
            dist.barrier()
            logger.info(f"OmniVAReader {self.rank}/{self.world_size} end barrier")
    def broadcast_audio_data(self, audio_data):
        if self.rank == self.target_rank:
            if audio_data is None:
                self.flag_tensor.fill_(0)
            else:
                self.flag_tensor.fill_(1)
                self.audio_tensor.copy_(torch.frombuffer(bytearray(audio_data), dtype=torch.uint8))
                # logger.info(f"rank {self.rank} send audio_tensor: {self.audio_tensor.shape}")
        dist.broadcast(self.flag_tensor, src=self.target_rank)
        if self.flag_tensor.item() == 0:
            return None
        dist.broadcast(self.audio_tensor, src=self.target_rank)
        if self.rank != self.target_rank:
            # logger.info(f"rank {self.rank} recv audio_tensor: {self.audio_tensor.shape}")
            audio_data = self.audio_tensor.cpu().numpy().tobytes()
        return audio_data

    def broadcast_valid_duration(self, valid_duration):
        if self.rank == self.target_rank:
            self.valid_duration_tensor.fill_(valid_duration)
        dist.broadcast(self.valid_duration_tensor, src=self.target_rank)
        return self.valid_duration_tensor.item()
    def bytes_to_ndarray(self, audio_data):
        if audio_data is None:
            return None
        audio_data = np.frombuffer(audio_data, dtype=np.int16)
        audio_data = audio_data.astype(np.float32) / 32768.0
        # logger.info(f"Got segment audio rank={self.rank}: {audio_data.shape} {audio_data.dtype} {audio_data.min()} {audio_data.max()}")
        return audio_data

    def convert_pcm_s16le_to_mono_resampled(self, audio_data, audio_info):
        audio = np.frombuffer(audio_data, dtype=np.int16)
        sample_count = audio_info.sample_count
        assert len(audio) == sample_count * audio_info.channel_count, f"audio length {len(audio)} != sample_count * channel_count {sample_count * audio_info.channel_count}"
        # convert to mono
        if audio_info.channel_count > 1:
            audio = audio.reshape(-1, audio_info.channel_count).mean(axis=1)
        # logger.info(f"audio: {audio.shape} {audio.dtype} {audio.min()} {audio.max()}")
        if audio_info.sample_rate != self.sample_rate:
            sample_count = int(len(audio) * self.sample_rate / audio_info.sample_rate)
            audio = resample(audio, sample_count).astype(np.int16)
            # logger.info(f"resampled audio: {audio.shape} {audio.dtype} {audio.min()} {audio.max()} {sample_count}")
        logger.warning(f"valid audio: {audio.shape} {audio.dtype} {audio.min()} {audio.max()} {sample_count}")
        return audio, sample_count

    def prepare_audio_data(self, chat_audio_result):
        sample_count = 0
        audio = np.array([], dtype=np.int16)
        # convert chat audio result to mono and target sample rate
        if chat_audio_result is not None:
            audio_data, audio_info = chat_audio_result
            audio, sample_count = self.convert_pcm_s16le_to_mono_resampled(audio_data, audio_info)
        valid_duration = sample_count / self.sample_rate
        # if this is not the first segment, concat with the previous segment
        if self.prev_seg_chunk is not None:
            audio = np.concatenate([self.prev_seg_chunk, audio])
        sample_count = len(audio)
        assert sample_count <= self.all_seg_sample_count, f"audio length {sample_count} > all_seg_sample_count {self.all_seg_sample_count}"
        # pad with zeros so the segment always has all_seg_sample_count samples
        if sample_count < self.all_seg_sample_count:
            pad_count = self.all_seg_sample_count - sample_count
            # logger.info(f"pad {pad_count} samples to audio")
            audio = np.pad(audio, (0, pad_count), mode="constant", constant_values=0)
            sample_count = len(audio)
        # update prev seg chunk
        self.prev_seg_chunk = audio[-self.prev_seg_sample_count :]
        # logger.info(f"audio: {audio.shape} {audio.dtype} {audio.min()} {audio.max()} {sample_count}, prev seg chunk: {self.prev_seg_chunk.shape}")
        return audio.tobytes(), valid_duration
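    # Segment arithmetic with the defaults (segment_duration=5.0625 s,
    # prev_duration=0.3125 s, 16 kHz):
    #   all_seg_sample_count = 81000, prev_seg_sample_count = 5000
    # A follow-up fetch of 4.75 s of voice (76000 samples) is concatenated with
    # the carried 5000-sample tail to exactly 81000 samples; shorter fetches
    # are zero-padded, and valid_duration reports only the freshly fetched
    # voice, excluding both the carried tail and the padding.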
    def get_fetch_duration(self):
        fetch_duration = self.segment_duration
        # after an immediate switch, reset the prev seg chunk
        if self.chat_adapter is not None and self.chat_adapter.reset_prev:
            self.prev_seg_chunk = None
            self.chat_adapter.reset_prev = False
            logger.warning("Reset prev seg chunk")
        # first segment fetches segment_duration, later ones segment_duration - prev_duration
        if self.prev_seg_chunk is not None:
            fetch_duration -= self.prev_duration
        return fetch_duration

    def change_segment_duration(self, segment_duration):
        if segment_duration is None or self.segment_duration == segment_duration:
            return
        if self.rank == self.target_rank:
            logger.warning(f"segment duration changed: {self.segment_duration} -> {segment_duration}")
        self.segment_duration = segment_duration
        self.all_seg_sample_count = int(self.segment_duration * self.sample_rate)
        chunk_size = int(self.segment_duration * self.sample_rate) * 2
        self.audio_tensor = torch.zeros(chunk_size, dtype=torch.uint8, device="cuda")
        if self.chat_adapter is not None:
            self.chat_adapter.seg_duration = segment_duration

    def get_audio_segment(self, fetch_duration: float = None, prev_duration: float = None):
        audio_data = None
        valid_duration = 0
        if prev_duration is not None and self.prev_duration != prev_duration:
            raise ValueError(f"prev_duration {prev_duration} != {self.prev_duration}")
        self.change_segment_duration(fetch_duration)
        if self.rank == self.target_rank:
            try:
                fetch_duration = self.get_fetch_duration()
                # logger.info(f"Get segment, fetch_duration: {fetch_duration}")
                if self.chat_adapter.status == "voice":
                    audio_result = self.chat_adapter.get_audio(fetch_duration)
                    audio_data, valid_duration = self.prepare_audio_data(audio_result)
                    # all buffered voice segments inferred, naturally switch to blank
                    if audio_result is None:
                        logger.info("Think all voice segments inferred, naturally switch to blank")
                        self.chat_adapter.status = "blank"
                else:
                    audio_data, valid_duration = self.prepare_audio_data(None)
            except Exception as e:
                logger.warning(f"Failed to get voice segment: {e}")
                return None, 0
        if self.world_size > 1:
            audio_data = self.broadcast_audio_data(audio_data)
            valid_duration = self.broadcast_valid_duration(valid_duration)
        audio_data = self.bytes_to_ndarray(audio_data)
        return audio_data, valid_duration
    def get_immediate_switch(self):
        if self.rank == self.target_rank:
            if self.chat_adapter is not None and self.chat_adapter.immediate_switch == 1:
                self.immediate_switch_tensor.fill_(1)
                # reset immediate switch
                self.chat_adapter.immediate_switch = 0
            else:
                self.immediate_switch_tensor.fill_(0)
        if self.world_size > 1:
            dist.broadcast(self.immediate_switch_tensor, src=self.target_rank)
        return self.immediate_switch_tensor.item()

    def get_image_switch(self):
        data = "" if self.chat_adapter is None else self.chat_adapter.image_switch
        image_switch = self.broadcast_data(data)
        # reset image switch
        if self.chat_adapter is not None:
            self.chat_adapter.image_switch = ""
        return image_switch

    def get_action_switch(self):
        data = "" if self.chat_adapter is None else self.chat_adapter.action_switch
        action_switch = self.broadcast_data(data)
        # reset action switch
        if self.chat_adapter is not None:
            self.chat_adapter.action_switch = ""
        return action_switch

    def broadcast_data(self, data):
        if self.world_size <= 1:
            return data
        if self.rank == self.target_rank:
            val = json.dumps(data, ensure_ascii=False).encode("utf-8")
            T = torch.frombuffer(bytearray(val), dtype=torch.uint8).to(device="cuda")
            S = torch.tensor([T.shape[0]], dtype=torch.int32).to(device="cuda")
        else:
            S = torch.zeros(1, dtype=torch.int32, device="cuda")
        dist.broadcast(S, src=self.target_rank)
        if self.rank != self.target_rank:
            T = torch.zeros(S.item(), dtype=torch.uint8, device="cuda")
        dist.broadcast(T, src=self.target_rank)
        if self.rank != self.target_rank:
            val = T.cpu().numpy().tobytes()
            data = json.loads(val.decode("utf-8"))
        return data
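    # broadcast_data is a length-prefixed broadcast: the target rank first
    # sends a 1-int size tensor so the other ranks can allocate a matching
    # uint8 buffer, then the JSON-encoded payload itself. E.g. for
    # data = "/tmp/a.png" the payload is b'"/tmp/a.png"' (12 bytes), so S
    # carries 12.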
    def stop(self):
        self.model_runner = None
        if self.chat_adapter is not None:
            self.chat_adapter.stop()
            self.chat_adapter = None
        logger.warning("OmniVAReader stopped")

    def __del__(self):
        self.stop()


if __name__ == "__main__":
    WORLD_SIZE = int(os.environ.get("WORLD_SIZE", 1))
    RANK = int(os.environ.get("RANK", 0))
    if WORLD_SIZE > 1:
        dist.init_process_group(backend="nccl")
        torch.cuda.set_device(dist.get_rank())
        logger.info(f"Distributed initialized: rank={RANK}, world_size={WORLD_SIZE}")
    reader = OmniVAReader(
        RANK,
        WORLD_SIZE,
        "https://reverse.st-oc-01.chielo.org/10.5.64.49:8000/rtc/v1/whep/?app=publish&stream=test_stream_ll&eip=10.120.114.82:8000",
        segment_duration=17 / 16,
        sample_rate=16000,
        audio_channels=1,
        prev_duration=1 / 16,
    )
    reader.start()
    fail_count = 0
    max_fail_count = 100000000
    try:
        while True:
            # get_audio_segment takes no timeout and returns (ndarray_or_None, valid_duration)
            audio_data, _ = reader.get_audio_segment()
            if audio_data is not None:
                logger.info(f"Got audio chunk, shape: {audio_data.shape}, range: [{audio_data.min()}, {audio_data.max()}]")
                fail_count = 0
            else:
                fail_count += 1
                if fail_count > max_fail_count:
                    logger.warning("Failed to get audio chunk, stop reader")
                    reader.stop()
                    break
            time.sleep(0.95)
    finally:
        reader.stop()
lightx2v/deploy/common/va_recorder.py (new file, mode 100644)
import math
import os
import queue
import socket
import subprocess
import threading
import time
import traceback

import numpy as np
import torch
import torchaudio as ta
from loguru import logger


def pseudo_random(a, b):
    x = str(time.time()).split(".")[1]
    y = int(float("0." + x) * 1000000)
    return a + (y % (b - a + 1))
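# Maps the fractional part of time.time() into [a, b] inclusive. Example: at a
# timestamp ending in exactly .5, x = "5" and y = 500000, so
# pseudo_random(32000, 40000) = 32000 + 500000 % 8001 = 35938.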
class VARecorder:
    def __init__(
        self,
        livestream_url: str,
        fps: float = 16.0,
        sample_rate: int = 16000,
        slice_frame: int = 1,
        prev_frame: int = 1,
        stream_config: dict = {},
    ):
        self.livestream_url = livestream_url
        self.stream_config = stream_config
        self.fps = fps
        self.sample_rate = sample_rate
        self.audio_port = pseudo_random(32000, 40000)
        self.video_port = self.audio_port + 1
        self.ffmpeg_log_level = os.getenv("FFMPEG_LOG_LEVEL", "error")
        logger.info(f"VARecorder audio port: {self.audio_port}, video port: {self.video_port}, ffmpeg_log_level: {self.ffmpeg_log_level}")
        self.width = None
        self.height = None
        self.stoppable_t = None
        self.realtime = False
        if self.livestream_url.startswith("rtmp://") or self.livestream_url.startswith("http"):
            self.realtime = True
        # ffmpeg process that mixes video and audio data and pushes to the livestream
        self.ffmpeg_process = None
        # TCP connection objects
        self.audio_socket = None
        self.video_socket = None
        self.audio_conn = None
        self.video_conn = None
        self.audio_thread = None
        self.video_thread = None
        # queues for sending data to the ffmpeg process
        self.audio_queue = queue.Queue()
        self.video_queue = queue.Queue()
        # buffer for stream data
        self.audio_samples_per_frame = round(self.sample_rate / self.fps)
        self.stream_buffer = []
        self.stream_buffer_lock = threading.Lock()
        self.stop_schedule = False
        self.schedule_thread = None
        self.slice_frame = slice_frame
        self.prev_frame = prev_frame
        assert self.slice_frame >= self.prev_frame, "slice_frame must be greater than or equal to prev_frame"
    def init_sockets(self):
        # TCP sockets for sending video and audio data to ffmpeg
        self.video_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self.video_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        self.video_socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
        self.video_socket.bind(("127.0.0.1", self.video_port))
        self.video_socket.listen(1)
        self.audio_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self.audio_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        self.audio_socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
        self.audio_socket.bind(("127.0.0.1", self.audio_port))
        self.audio_socket.listen(1)
    def audio_worker(self):
        try:
            logger.info("Waiting for ffmpeg to connect to audio socket...")
            self.audio_conn, _ = self.audio_socket.accept()
            logger.info(f"Audio connection established from {self.audio_conn.getpeername()}")
            fail_time, max_fail_time = 0, 10
            while True:
                try:
                    if self.audio_queue is None:
                        break
                    data = self.audio_queue.get()
                    if data is None:
                        logger.info("Audio thread received stop signal")
                        break
                    # Convert audio data to 16-bit integer format
                    audios = torch.clamp(torch.round(data * 32767), -32768, 32767).to(torch.int16)
                    try:
                        self.audio_conn.send(audios[None].cpu().numpy().tobytes())
                    except (BrokenPipeError, OSError, ConnectionResetError) as e:
                        logger.info(f"Audio connection closed, stopping worker: {type(e).__name__}")
                        return
                    fail_time = 0
                except (BrokenPipeError, OSError, ConnectionResetError):
                    logger.info("Audio connection closed during queue processing")
                    break
                except Exception:
                    logger.error(f"Send audio data error: {traceback.format_exc()}")
                    fail_time += 1
                    if fail_time > max_fail_time:
                        logger.error(f"Audio push worker thread failed {fail_time} times, stopping...")
                        break
        except Exception:
            logger.error(f"Audio push worker thread error: {traceback.format_exc()}")
        finally:
            logger.info("Audio push worker thread stopped")

    def video_worker(self):
        try:
            logger.info("Waiting for ffmpeg to connect to video socket...")
            self.video_conn, _ = self.video_socket.accept()
            logger.info(f"Video connection established from {self.video_conn.getpeername()}")
            fail_time, max_fail_time = 0, 10
            packet_secs = 1.0 / self.fps
            while True:
                try:
                    if self.video_queue is None:
                        break
                    data = self.video_queue.get()
                    if data is None:
                        logger.info("Video thread received stop signal")
                        break
                    # Convert to numpy and scale to [0, 255]; frames are sent to ffmpeg as raw rgb24
                    for i in range(data.shape[0]):
                        t0 = time.time()
                        frame = (data[i] * 255).clamp(0, 255).to(torch.uint8).cpu().numpy()
                        try:
                            self.video_conn.send(frame.tobytes())
                        except (BrokenPipeError, OSError, ConnectionResetError) as e:
                            logger.info(f"Video connection closed, stopping worker: {type(e).__name__}")
                            return
                        if self.realtime and i < data.shape[0] - 1:
                            time.sleep(max(0, packet_secs - (time.time() - t0)))
                    fail_time = 0
                except (BrokenPipeError, OSError, ConnectionResetError):
                    logger.info("Video connection closed during queue processing")
                    break
                except Exception:
                    logger.error(f"Send video data error: {traceback.format_exc()}")
                    fail_time += 1
                    if fail_time > max_fail_time:
                        logger.error(f"Video push worker thread failed {fail_time} times, stopping...")
                        break
        except Exception:
            logger.error(f"Video push worker thread error: {traceback.format_exc()}")
        finally:
            logger.info("Video push worker thread stopped")
    def start_ffmpeg_process_local(self):
        """Start an ffmpeg process that connects to our TCP sockets and writes a local file."""
        ffmpeg_cmd = [
            "ffmpeg",
            "-fflags",
            "nobuffer",
            "-analyzeduration",
            "0",
            "-probesize",
            "32",
            "-flush_packets",
            "1",
            "-f",
            "s16le",
            "-ar",
            str(self.sample_rate),
            "-ac",
            "1",
            "-i",
            f"tcp://127.0.0.1:{self.audio_port}",
            "-f",
            "rawvideo",
            "-pix_fmt",
            "rgb24",
            "-color_range",
            "pc",
            "-colorspace",
            "rgb",
            "-color_primaries",
            "bt709",
            "-color_trc",
            "iec61966-2-1",
            "-r",
            str(self.fps),
            "-s",
            f"{self.width}x{self.height}",
            "-i",
            f"tcp://127.0.0.1:{self.video_port}",
            "-ar",
            "44100",
            "-b:v",
            "4M",
            "-c:v",
            "libx264",
            "-preset",
            "ultrafast",
            "-tune",
            "zerolatency",
            "-g",
            f"{self.fps}",
            "-pix_fmt",
            "yuv420p",
            "-f",
            "mp4",
            self.livestream_url,
            "-y",
            "-loglevel",
            self.ffmpeg_log_level,
        ]
        try:
            self.ffmpeg_process = subprocess.Popen(ffmpeg_cmd)
            logger.info(f"FFmpeg streaming started with PID: {self.ffmpeg_process.pid}")
            logger.info(f"FFmpeg command: {' '.join(ffmpeg_cmd)}")
        except Exception as e:
            logger.error(f"Failed to start FFmpeg: {e}")

    def start_ffmpeg_process_rtmp(self):
        """Start an ffmpeg process that connects to our TCP sockets and pushes RTMP/FLV."""
        ffmpeg_cmd = [
            "ffmpeg",
            "-re",
            "-f",
            "s16le",
            "-ar",
            str(self.sample_rate),
            "-ac",
            "1",
            "-i",
            f"tcp://127.0.0.1:{self.audio_port}",
            "-f",
            "rawvideo",
            "-re",
            "-pix_fmt",
            "rgb24",
            "-r",
            str(self.fps),
            "-s",
            f"{self.width}x{self.height}",
            "-i",
            f"tcp://127.0.0.1:{self.video_port}",
            "-ar",
            "44100",
            "-b:v",
            "2M",
            "-c:v",
            "libx264",
            "-preset",
            "ultrafast",
            "-tune",
            "zerolatency",
            "-g",
            f"{self.fps}",
            "-pix_fmt",
            "yuv420p",
            "-f",
            "flv",
            self.livestream_url,
            "-y",
            "-loglevel",
            self.ffmpeg_log_level,
        ]
        try:
            self.ffmpeg_process = subprocess.Popen(ffmpeg_cmd)
            logger.info(f"FFmpeg streaming started with PID: {self.ffmpeg_process.pid}")
            logger.info(f"FFmpeg command: {' '.join(ffmpeg_cmd)}")
        except Exception as e:
            logger.error(f"Failed to start FFmpeg: {e}")

    def start_ffmpeg_process_whip(self):
        """Start an ffmpeg process that connects to our TCP sockets and publishes via WHIP."""
        ffmpeg_cmd = [
            "ffmpeg",
            "-re",
            "-fflags",
            "nobuffer",
            "-analyzeduration",
            "0",
            "-probesize",
            "32",
            "-flush_packets",
            "1",
            "-f",
            "s16le",
            "-ar",
            str(self.sample_rate),
            "-ac",
            "1",
            "-ch_layout",
            "mono",
            "-i",
            f"tcp://127.0.0.1:{self.audio_port}",
            "-f",
            "rawvideo",
            "-re",
            "-pix_fmt",
            "rgb24",
            "-r",
            str(self.fps),
            "-s",
            f"{self.width}x{self.height}",
            "-i",
            f"tcp://127.0.0.1:{self.video_port}",
            "-ar",
            "48000",
            "-c:a",
            "libopus",
            "-ac",
            "2",
            "-b:v",
            "2M",
            "-c:v",
            "libx264",
            "-preset",
            "ultrafast",
            "-tune",
            "zerolatency",
            "-g",
            f"{self.fps}",
            "-pix_fmt",
            "yuv420p",
            "-threads",
            "1",
            "-bf",
            "0",
            "-f",
            "whip",
            self.livestream_url,
            "-y",
            "-loglevel",
            self.ffmpeg_log_level,
        ]
        try:
            self.ffmpeg_process = subprocess.Popen(ffmpeg_cmd)
            logger.info(f"FFmpeg streaming started with PID: {self.ffmpeg_process.pid}")
            logger.info(f"FFmpeg command: {' '.join(ffmpeg_cmd)}")
        except Exception as e:
            logger.error(f"Failed to start FFmpeg: {e}")
    def start(self, width: int, height: int):
        self.set_video_size(width, height)
        duration = 1.0
        frames = int(self.fps * duration)
        samples = int(self.sample_rate * (frames / self.fps))
        tensor = torch.zeros((frames, height, width, 3), dtype=torch.float16)
        self.pub_livestream(tensor, torch.zeros(samples, dtype=torch.float16))
        time.sleep(duration)

    def config_video_padding(self):
        pass

    def padding_video_frames(self, frames: torch.Tensor):
        return frames

    def try_init_sockets(self, max_try=10):
        for i in range(max_try):
            try:
                self.init_sockets()
                return True
            except OSError:
                self.audio_port = pseudo_random(32000, 40000)
                self.video_port = self.audio_port + 1
                logger.warning(f"Failed to initialize sockets {i + 1}/{max_try}: {traceback.format_exc()}")
                logger.warning(f"change port to {self.audio_port} and {self.video_port}, retry ...")

    def set_video_size(self, width: int, height: int):
        if self.width is not None and self.height is not None:
            assert self.width == width and self.height == height, "Video size already set"
            return
        self.width = width
        self.height = height
        self.config_video_padding()
        self.try_init_sockets()
        if self.livestream_url.startswith("rtmp://"):
            self.start_ffmpeg_process_rtmp()
        elif self.livestream_url.startswith("http"):
            self.start_ffmpeg_process_whip()
        else:
            self.start_ffmpeg_process_local()
        self.audio_thread = threading.Thread(target=self.audio_worker)
        self.video_thread = threading.Thread(target=self.video_worker)
        self.audio_thread.start()
        self.video_thread.start()
        if self.realtime:
            self.schedule_thread = threading.Thread(target=self.schedule_stream_buffer)
            self.schedule_thread.start()
    # Publish ComfyUI Image tensor and audio tensor to livestream
    def pub_livestream(self, images: torch.Tensor, audios: torch.Tensor):
        N, height, width, C = images.shape
        M = audios.reshape(-1).shape[0]
        assert C == 3, "Input must be [N, H, W, C] with C=3"
        logger.info(f"Publishing video [{N}x{width}x{height}], audio: [{M}]")
        audio_frames = round(M * self.fps / self.sample_rate)
        if audio_frames != N:
            logger.warning(f"Video and audio frames mismatch, {N} vs {audio_frames}")
        self.set_video_size(width, height)
        self.audio_queue.put(audios)
        self.video_queue.put(self.padding_video_frames(images))
        logger.info(f"Published {N} frames and {M} audio samples")
        self.stoppable_t = time.time() + M / self.sample_rate + 3
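    # Consistency-check example at fps=16, sample_rate=16000: a 2 s clip is
    # N=32 frames and M=32000 samples, so audio_frames = round(32000*16/16000)
    # = 32 == N; stoppable_t then allows 2 s of playout plus a 3 s grace period.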
    def buffer_stream(self, images: torch.Tensor, audios: torch.Tensor, gen_video: torch.Tensor, valid_duration=1e9):
        N, height, width, C = images.shape
        M = audios.reshape(-1).shape[0]
        assert N % self.slice_frame == 0, "Video frames must be divisible by slice_frame"
        assert C == 3, "Input must be [N, H, W, C] with C=3"
        audio_frames = round(M * self.fps / self.sample_rate)
        if audio_frames != N:
            logger.warning(f"Video and audio frames mismatch, {N} vs {audio_frames}")
        self.set_video_size(width, height)
        valid_frames = math.ceil(valid_duration * self.fps)
        # logger.info(f"Buffer stream images {images.shape} {audios.shape} {gen_video.shape}")
        rets = []
        for i in range(0, N, self.slice_frame):
            end_frame = i + self.slice_frame
            can_truncate = valid_frames < end_frame
            img = self.padding_video_frames(images[i:end_frame])
            aud = audios[i * self.audio_samples_per_frame : end_frame * self.audio_samples_per_frame]
            gen = gen_video[:, :, (end_frame - self.prev_frame) : end_frame]
            rets.append([img, aud, gen, can_truncate])
        with self.stream_buffer_lock:
            origin_size = len(self.stream_buffer)
            self.stream_buffer.extend(rets)
            logger.info(f"Buffered {origin_size} + {len(rets)} = {len(self.stream_buffer)} stream segments, valid_frames: {valid_frames}")
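    # Slicing example: with fps=16, slice_frame=5, prev_frame=5 and an 80-frame
    # batch, buffer_stream emits 16 segments of 5 frames; each carries
    # 5 * audio_samples_per_frame = 5000 audio samples plus the matching
    # gen_video tail frames, and only segments wholly past valid_duration are
    # marked truncatable.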
    def get_buffer_stream_size(self):
        return len(self.stream_buffer)

    def truncate_stream_buffer(self, size: int, check_can_truncate: bool = True):
        with self.stream_buffer_lock:
            # scanning from the end, find the last segment that cannot be truncated
            idx = len(self.stream_buffer) - 1
            while check_can_truncate and idx >= size and idx >= 0:
                if not self.stream_buffer[idx][3]:
                    logger.warning(f"can not truncate segment: {idx}, truncate size: {size} -> {idx + 1}")
                    size = idx + 1
                    break
                idx -= 1
            self.stream_buffer = self.stream_buffer[:size]
            logger.info(f"Truncated stream buffer to {len(self.stream_buffer)} segments")
            if len(self.stream_buffer) > 0:
                # after truncating, mark the last remaining segment as not truncatable
                self.stream_buffer[-1][3] = False
                return self.stream_buffer[-1][2]  # return the last video tensor
            else:
                return None
    def schedule_stream_buffer(self):
        schedule_interval = self.slice_frame / self.fps
        logger.info(f"Schedule stream buffer with interval: {schedule_interval} seconds")
        t = None
        fail_time = 0
        while True:
            try:
                if self.stop_schedule:
                    break
                img, aud, gen = None, None, None
                with self.stream_buffer_lock:
                    if len(self.stream_buffer) > 0:
                        img, aud, gen, _ = self.stream_buffer.pop(0)
                if t is not None:
                    wait_secs = schedule_interval - (time.time() - t)
                    if wait_secs > 0:
                        time.sleep(wait_secs)
                t = time.time()
                if img is not None and aud is not None:
                    fail_time = 0
                    self.audio_queue.put(aud)
                    self.video_queue.put(img)
                    # logger.info(f"Scheduled {img.shape[0]} frames and {aud.shape[0]} audio samples to publish")
                    del gen
                    self.stoppable_t = time.time() + aud.shape[0] / self.sample_rate + 3
                else:
                    fail_time += 1
                    if fail_time % 10 == 0:
                        logger.warning(f"No stream buffer to schedule: {fail_time} times")
            except Exception:
                logger.error(f"Schedule stream buffer error: {traceback.format_exc()}")
                break
        logger.info("Schedule stream buffer thread stopped")
def
stop
(
self
,
wait
=
True
):
if
wait
and
self
.
stoppable_t
:
t
=
self
.
stoppable_t
-
time
.
time
()
if
t
>
0
:
logger
.
warning
(
f
"Waiting for
{
t
}
seconds to stop ..."
)
time
.
sleep
(
t
)
self
.
stoppable_t
=
None
if
self
.
schedule_thread
:
self
.
stop_schedule
=
True
self
.
schedule_thread
.
join
(
timeout
=
5
)
if
self
.
schedule_thread and self.schedule_thread.is_alive():
            logger.error("Schedule thread did not stop after 5s")
        # Send stop signals to queues
        if self.audio_queue:
            self.audio_queue.put(None)
        if self.video_queue:
            self.video_queue.put(None)
        # Wait for threads to finish processing queued data (increased timeout)
        queue_timeout = 30  # Increased from 5s to 30s to allow sufficient time for large video frames
        if self.audio_thread and self.audio_thread.is_alive():
            self.audio_thread.join(timeout=queue_timeout)
            if self.audio_thread.is_alive():
                logger.error(f"Audio push thread did not stop after {queue_timeout}s")
        if self.video_thread and self.video_thread.is_alive():
            self.video_thread.join(timeout=queue_timeout)
            if self.video_thread.is_alive():
                logger.error(f"Video push thread did not stop after {queue_timeout}s")
        # Shutdown connections to signal EOF to FFmpeg
        # shutdown(SHUT_WR) will wait for the send buffer to flush, no explicit sleep needed
        if self.audio_conn:
            try:
                self.audio_conn.getpeername()
                self.audio_conn.shutdown(socket.SHUT_WR)
                logger.info("Audio connection shutdown initiated")
            except OSError:
                # Connection already closed, skip shutdown
                pass
        if self.video_conn:
            try:
                self.video_conn.getpeername()
                self.video_conn.shutdown(socket.SHUT_WR)
                logger.info("Video connection shutdown initiated")
            except OSError:
                # Connection already closed, skip shutdown
                pass
        if self.ffmpeg_process:
            is_local_file = not self.livestream_url.startswith(("rtmp://", "http"))
            # Local MP4 files need time to write the moov atom and finalize the container
            timeout_seconds = 30 if is_local_file else 10
            logger.info(f"Waiting for FFmpeg to finalize file (timeout={timeout_seconds}s, local_file={is_local_file})")
            logger.info(f"FFmpeg output: {self.livestream_url}")
            try:
                returncode = self.ffmpeg_process.wait(timeout=timeout_seconds)
                if returncode == 0:
                    logger.info(f"FFmpeg process exited successfully (exit code: {returncode})")
                else:
                    logger.warning(f"FFmpeg process exited with non-zero code: {returncode}")
            except subprocess.TimeoutExpired:
                logger.warning(f"FFmpeg process did not exit within {timeout_seconds}s, sending SIGTERM...")
                try:
                    self.ffmpeg_process.terminate()  # SIGTERM
                    returncode = self.ffmpeg_process.wait(timeout=5)
                    logger.warning(f"FFmpeg process terminated with SIGTERM (exit code: {returncode})")
                except subprocess.TimeoutExpired:
                    logger.error("FFmpeg process still running after SIGTERM, killing with SIGKILL...")
                    self.ffmpeg_process.kill()
                    self.ffmpeg_process.wait()  # Wait for kill to complete
                    logger.error("FFmpeg process killed with SIGKILL")
            finally:
                self.ffmpeg_process = None
        if self.audio_conn:
            try:
                self.audio_conn.close()
            except Exception as e:
                logger.debug(f"Error closing audio connection: {e}")
            finally:
                self.audio_conn = None
        if self.video_conn:
            try:
                self.video_conn.close()
            except Exception as e:
                logger.debug(f"Error closing video connection: {e}")
            finally:
                self.video_conn = None
        if self.audio_socket:
            try:
                self.audio_socket.close()
            except Exception as e:
                logger.debug(f"Error closing audio socket: {e}")
            finally:
                self.audio_socket = None
        if self.video_socket:
            try:
                self.video_socket.close()
            except Exception as e:
                logger.debug(f"Error closing video socket: {e}")
            finally:
                self.video_socket = None
        if self.audio_queue:
            while self.audio_queue.qsize() > 0:
                try:
                    self.audio_queue.get_nowait()
                except:  # noqa
                    break
        if self.video_queue:
            while self.video_queue.qsize() > 0:
                try:
                    self.video_queue.get_nowait()
                except:  # noqa
                    break
        self.audio_queue = None
        self.video_queue = None
        logger.info("VARecorder stopped and resources cleaned up")
    def __del__(self):
        self.stop(wait=False)
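For reference, the EOF handshake that stop() relies on, shown in isolation: shutting down only the write half of a TCP connection lets the peer (here, FFmpeg) drain any buffered data and then observe a clean end-of-stream. A minimal, self-contained sketch, independent of this file:

import socket

a, b = socket.socketpair()  # a stands in for our conn, b for FFmpeg's end
a.send(b"last bytes")
a.shutdown(socket.SHUT_WR)  # signal EOF; already-sent data is still delivered
print(b.recv(1024))  # b'last bytes'
print(b.recv(1024))  # b'' -> EOF
a.close()
b.close()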
def create_simple_video(frames=10, height=480, width=640):
    video_data = []
    for i in range(frames):
        frame = np.zeros((height, width, 3), dtype=np.float32)
        stripe_height = height // 8
        colors = [
            [1.0, 0.0, 0.0],  # red
            [0.0, 1.0, 0.0],  # green
            [0.0, 0.0, 1.0],  # blue
            [1.0, 1.0, 0.0],  # yellow
            [1.0, 0.0, 1.0],  # magenta
            [0.0, 1.0, 1.0],  # cyan
            [1.0, 1.0, 1.0],  # white
            [0.5, 0.5, 0.5],  # gray
        ]
        for j, color in enumerate(colors):
            start_y = j * stripe_height
            end_y = min((j + 1) * stripe_height, height)
            frame[start_y:end_y, :] = color
        offset = int((i / frames) * width)
        frame = np.roll(frame, offset, axis=1)
        frame = torch.tensor(frame, dtype=torch.float32)
        video_data.append(frame)
    return torch.stack(video_data, dim=0)
if __name__ == "__main__":
    sample_rate = 16000
    fps = 16
    width = 640
    height = 480
    recorder = VARecorder(
        # livestream_url="rtmp://localhost/live/test",
        # livestream_url="https://reverse.st-oc-01.chielo.org/10.5.64.49:8000/rtc/v1/whip/?app=live&stream=ll_test_video&eip=127.0.0.1:8000",
        livestream_url="/path/to/output_video.mp4",
        fps=fps,
        sample_rate=sample_rate,
    )
    audio_path = "/path/to/test_b_2min.wav"
    audio_array, ori_sr = ta.load(audio_path)
    audio_array = ta.functional.resample(audio_array.mean(0), orig_freq=ori_sr, new_freq=16000)
    audio_array = audio_array.reshape(-1)
    secs = audio_array.shape[0] // sample_rate
    interval = 1
    for i in range(0, secs, interval):
        logger.info(f"{i}/{secs}s")
        start = i * sample_rate
        end = (i + interval) * sample_rate
        cur_audio_array = audio_array[start:end]
        logger.info(f"audio: {cur_audio_array.shape} {cur_audio_array.dtype} {cur_audio_array.min()} {cur_audio_array.max()}")
        num_frames = int(interval * fps)
        images = create_simple_video(num_frames, height, width)
        logger.info(f"images: {images.shape} {images.dtype} {images.min()} {images.max()}")
        recorder.pub_livestream(images, cur_audio_array)
        time.sleep(interval)
    recorder.stop()
lightx2v/deploy/common/va_recorder_x264.py
0 → 100644
View file @
e2778d0d
import ctypes
import queue
import threading
import time
import traceback

import numpy as np
import torch
import torchaudio as ta
from loguru import logger
from scipy.signal import resample
class X264VARecorder:
    def __init__(
        self,
        whip_shared_path: str,
        livestream_url: str,
        fps: float = 16.0,
        sample_rate: int = 16000,
        slice_frame: int = 1,
        prev_frame: int = 1,
    ):
        assert livestream_url.startswith("http"), "X264VARecorder only supports WHIP http livestreams"
        self.livestream_url = livestream_url
        self.fps = fps
        self.sample_rate = sample_rate
        self.width = None
        self.height = None
        self.stoppable_t = None
        # only enable the whip shared api for whip http livestreams
        self.whip_shared_path = whip_shared_path
        self.whip_shared_lib = None
        self.whip_shared_handle = None
        self.realtime = True
        # queue for sending data to the whip shared api
        self.queue = queue.Queue()
        self.worker_thread = None
        # buffer for stream data
        self.target_sample_rate = 48000
        self.target_samples_per_frame = round(self.target_sample_rate / self.fps)
        self.target_chunks_per_frame = self.target_samples_per_frame * 2
        self.stream_buffer = []
        self.stream_buffer_lock = threading.Lock()
        self.stop_schedule = False
        self.schedule_thread = None
        self.slice_frame = slice_frame
        self.prev_frame = prev_frame
        assert self.slice_frame >= self.prev_frame, "slice_frame must be >= prev_frame"
    def worker(self):
        try:
            fail_time, max_fail_time = 0, 10
            packet_secs = 1.0 / self.fps
            while True:
                try:
                    if self.queue is None:
                        break
                    data = self.queue.get()
                    if data is None:
                        logger.info("Worker thread received stop signal")
                        break
                    audios, images = data
                    for i in range(images.shape[0]):
                        t0 = time.time()
                        cur_audio = audios[i * self.target_chunks_per_frame : (i + 1) * self.target_chunks_per_frame].flatten()
                        audio_ptr = cur_audio.ctypes.data_as(ctypes.POINTER(ctypes.c_int16))
                        self.whip_shared_lib.pushWhipRawAudioFrame(self.whip_shared_handle, audio_ptr, self.target_samples_per_frame)
                        cur_video = images[i].flatten()
                        video_ptr = cur_video.ctypes.data_as(ctypes.POINTER(ctypes.c_uint8))
                        self.whip_shared_lib.pushWhipRawVideoFrame(self.whip_shared_handle, video_ptr, self.width, self.height)
                        if self.realtime and i < images.shape[0] - 1:
                            time.sleep(max(0, packet_secs - (time.time() - t0)))
                    fail_time = 0
                except:  # noqa
                    logger.error(f"Send audio data error: {traceback.format_exc()}")
                    fail_time += 1
                    if fail_time > max_fail_time:
                        logger.error(f"Audio push worker thread failed {fail_time} times, stopping...")
                        break
        except:  # noqa
            logger.error(f"Audio push worker thread error: {traceback.format_exc()}")
        finally:
            logger.info("Audio push worker thread stopped")
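The numpy-to-C handoff used by worker() above, in isolation: .ctypes.data_as() returns a typed pointer into the array's existing buffer without copying, which is why the array must stay alive while the shared library reads from it. A minimal sketch:

import ctypes

import numpy as np

buf = np.arange(8, dtype=np.int16)
ptr = buf.ctypes.data_as(ctypes.POINTER(ctypes.c_int16))  # zero-copy view into buf
print(ptr[0], ptr[7])  # 0 7; keep buf referenced while ptr is in use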
    def start_libx264_whip_shared_api(self, width: int, height: int):
        self.whip_shared_lib = ctypes.CDLL(self.whip_shared_path)
        # define function argtypes and restype
        self.whip_shared_lib.initWhipStream.argtypes = [ctypes.c_char_p, ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int]
        self.whip_shared_lib.initWhipStream.restype = ctypes.c_void_p
        self.whip_shared_lib.pushWhipRawAudioFrame.argtypes = [ctypes.c_void_p, ctypes.POINTER(ctypes.c_int16), ctypes.c_int]
        self.whip_shared_lib.pushWhipRawVideoFrame.argtypes = [ctypes.c_void_p, ctypes.POINTER(ctypes.c_uint8), ctypes.c_int, ctypes.c_int]
        self.whip_shared_lib.destroyWhipStream.argtypes = [ctypes.c_void_p]
        whip_url = ctypes.c_char_p(self.livestream_url.encode("utf-8"))
        self.whip_shared_handle = ctypes.c_void_p(self.whip_shared_lib.initWhipStream(whip_url, 1, 1, 0, width, height))
        logger.info(f"WHIP shared API initialized with handle: {self.whip_shared_handle}")
    def convert_data(self, audios, images):
        # Convert audio data to 16-bit integer format
        audio_datas = torch.clamp(torch.round(audios * 32767), -32768, 32767).to(torch.int16).cpu().numpy().reshape(-1)
        # Convert to numpy and scale to [0, 255] (frames stay in RGB order)
        image_datas = (images * 255).clamp(0, 255).to(torch.uint8).cpu().numpy()
        logger.info(f"image_datas: {image_datas.shape} {image_datas.dtype} {image_datas.min()} {image_datas.max()}")
        resampled_audios = resample(audio_datas, int(len(audio_datas) * 48000 / self.sample_rate))
        stereo_audios = np.stack([resampled_audios, resampled_audios], axis=-1).astype(np.int16).reshape(-1)
        return stereo_audios, image_datas
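convert_data() chains four steps: quantize float samples in [-1, 1] to int16 PCM, resample from the input rate to 48 kHz, duplicate mono to stereo, and interleave. The same chain on a hypothetical 440 Hz test tone (a standalone sketch; the tone and variable names are illustrative, not part of the original file):

import numpy as np
from scipy.signal import resample

sample_rate = 16000
t = np.arange(sample_rate) / sample_rate
mono = np.sin(2 * np.pi * 440 * t).astype(np.float32)  # 1 s tone in [-1.0, 1.0]
pcm16 = np.clip(np.round(mono * 32767), -32768, 32767).astype(np.int16)  # float -> int16 PCM
pcm48k = resample(pcm16, int(len(pcm16) * 48000 / sample_rate))  # 16 kHz -> 48 kHz
stereo = np.stack([pcm48k, pcm48k], axis=-1).astype(np.int16).reshape(-1)  # interleave L/R
assert stereo.shape[0] == 2 * 48000  # 1 s of interleaved stereo at 48 kHz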
    def start(self, width: int, height: int):
        self.set_video_size(width, height)
    def set_video_size(self, width: int, height: int):
        if self.width is not None and self.height is not None:
            assert self.width == width and self.height == height, "Video size already set"
            return
        self.width = width
        self.height = height
        self.start_libx264_whip_shared_api(width, height)
        self.worker_thread = threading.Thread(target=self.worker)
        self.worker_thread.start()
        if self.realtime:
            self.schedule_thread = threading.Thread(target=self.schedule_stream_buffer)
            self.schedule_thread.start()
    def buffer_stream(self, images: torch.Tensor, audios: torch.Tensor, gen_video: torch.Tensor, valid_duration=1e9):
        N, height, width, C = images.shape
        M = audios.reshape(-1).shape[0]
        assert N % self.slice_frame == 0, "Video frames must be divisible by slice_frame"
        assert C == 3, "Input must be [N, H, W, C] with C=3"
        audio_frames = round(M * self.fps / self.sample_rate)
        if audio_frames != N:
            logger.warning(f"Video and audio frames mismatch, {N} vs {audio_frames}")
        self.set_video_size(width, height)
        audio_datas, image_datas = self.convert_data(audios, images)
        # logger.info(f"Buffer stream images {images.shape} {audios.shape} {gen_video.shape}")
        rets = []
        for i in range(0, N, self.slice_frame):
            end_frame = i + self.slice_frame
            img = image_datas[i:end_frame]
            aud = audio_datas[i * self.target_chunks_per_frame : end_frame * self.target_chunks_per_frame]
            gen = gen_video[:, :, (end_frame - self.prev_frame) : end_frame]
            rets.append((img, aud, gen))
        with self.stream_buffer_lock:
            origin_size = len(self.stream_buffer)
            self.stream_buffer.extend(rets)
            logger.info(f"Buffered {origin_size} + {len(rets)} = {len(self.stream_buffer)} stream segments")
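The per-segment sizes that buffer_stream() produces follow directly from the constructor's constants; a quick arithmetic check under assumed values (fps=16, slice_frame=4):

fps, target_sample_rate, slice_frame = 16, 48000, 4  # illustrative values
samples_per_frame = round(target_sample_rate / fps)  # 3000 audio samples per video frame
chunks_per_frame = samples_per_frame * 2  # 6000 int16 values per frame (stereo interleaved)
print(slice_frame, "frames and", slice_frame * chunks_per_frame, "int16 values per segment")
# -> 4 frames and 24000 int16 values per segment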
    def get_buffer_stream_size(self):
        return len(self.stream_buffer)

    def truncate_stream_buffer(self, size: int):
        with self.stream_buffer_lock:
            self.stream_buffer = self.stream_buffer[:size]
            logger.info(f"Truncated stream buffer to {len(self.stream_buffer)} segments")
            if len(self.stream_buffer) > 0:
                return self.stream_buffer[-1][2]  # return the last video tensor
            else:
                return None
    def schedule_stream_buffer(self):
        schedule_interval = self.slice_frame / self.fps
        logger.info(f"Schedule stream buffer with interval: {schedule_interval} seconds")
        t = None
        while True:
            try:
                if self.stop_schedule:
                    break
                img, aud, gen = None, None, None
                with self.stream_buffer_lock:
                    if len(self.stream_buffer) > 0:
                        img, aud, gen = self.stream_buffer.pop(0)
                if t is not None:
                    wait_secs = schedule_interval - (time.time() - t)
                    if wait_secs > 0:
                        time.sleep(wait_secs)
                t = time.time()
                if img is not None and aud is not None:
                    self.queue.put((aud, img))
                    # logger.info(f"Scheduled {img.shape[0]} frames and {aud.shape[0]} audio samples to publish")
                    del gen
                    self.stoppable_t = time.time() + img.shape[0] / self.fps + 3
                else:
                    logger.warning("No stream buffer to schedule")
            except Exception:
                logger.error(f"Schedule stream buffer error: {traceback.format_exc()}")
                break
        logger.info("Schedule stream buffer thread stopped")
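The pacing logic above, reduced to its core (a minimal sketch): time each dispatch from the previous one and sleep only for the remainder, so lock waits and processing time are absorbed instead of accumulating as drift.

import time

def paced(items, interval):
    t = None
    for item in items:
        if t is not None:
            wait = interval - (time.time() - t)
            if wait > 0:
                time.sleep(wait)
        t = time.time()
        yield item

for x in paced(range(3), 0.25):  # emits roughly every 0.25 s
    print(round(time.time(), 2), x)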
    def stop(self, wait=True):
        if wait and self.stoppable_t:
            t = self.stoppable_t - time.time()
            if t > 0:
                logger.warning(f"Waiting for {t} seconds to stop ...")
                time.sleep(t)
            self.stoppable_t = None
        if self.schedule_thread:
            self.stop_schedule = True
            self.schedule_thread.join(timeout=5)
        if self.schedule_thread and self.schedule_thread.is_alive():
            logger.error("Schedule thread did not stop after 5s")
        # Send stop signals to queues
        if self.queue:
            self.queue.put(None)
        # Wait for threads to finish
        if self.worker_thread and self.worker_thread.is_alive():
            self.worker_thread.join(timeout=5)
            if self.worker_thread.is_alive():
                logger.warning("Worker thread did not stop gracefully")
        # Destroy WHIP shared API
        if self.whip_shared_lib and self.whip_shared_handle:
            self.whip_shared_lib.destroyWhipStream(self.whip_shared_handle)
            self.whip_shared_handle = None
            self.whip_shared_lib = None
            logger.warning("WHIP shared API destroyed")
    def __del__(self):
        self.stop()
def create_simple_video(frames=10, height=480, width=640):
    video_data = []
    for i in range(frames):
        frame = np.zeros((height, width, 3), dtype=np.float32)
        stripe_height = height // 8
        colors = [
            [1.0, 0.0, 0.0],  # red
            [0.0, 1.0, 0.0],  # green
            [0.0, 0.0, 1.0],  # blue
            [1.0, 1.0, 0.0],  # yellow
            [1.0, 0.0, 1.0],  # magenta
            [0.0, 1.0, 1.0],  # cyan
            [1.0, 1.0, 1.0],  # white
            [0.5, 0.5, 0.5],  # gray
        ]
        for j, color in enumerate(colors):
            start_y = j * stripe_height
            end_y = min((j + 1) * stripe_height, height)
            frame[start_y:end_y, :] = color
        offset = int((i / frames) * width)
        frame = np.roll(frame, offset, axis=1)
        frame = torch.tensor(frame, dtype=torch.float32)
        video_data.append(frame)
    return torch.stack(video_data, dim=0)
if __name__ == "__main__":
    sample_rate = 16000
    fps = 16
    width = 452
    height = 352
    recorder = X264VARecorder(
        whip_shared_path="/data/nvme0/liuliang1/lightx2v/test_deploy/test_whip_so/0.1.1/go_whxp.so",
        livestream_url="https://reverse.st-oc-01.chielo.org/10.5.64.49:8000/rtc/v1/whip/?app=subscribe&stream=ll2&eip=10.120.114.82:8000",
        fps=fps,
        sample_rate=sample_rate,
    )
    recorder.start(width, height)
    # time.sleep(5)
    audio_path = "/data/nvme0/liuliang1/lightx2v/test_deploy/media_test/mangzhong.wav"
    audio_array, ori_sr = ta.load(audio_path)
    audio_array = ta.functional.resample(audio_array.mean(0), orig_freq=ori_sr, new_freq=16000)
    audio_array = audio_array.numpy().reshape(-1)
    secs = audio_array.shape[0] // sample_rate
    interval = 1
    space = 10
    i = 0
    while i < space:
        t0 = time.time()
        logger.info(f"space {i}/{space}s")
        cur_audio_array = np.zeros(int(interval * sample_rate), dtype=np.float32)
        num_frames = int(interval * fps)
        images = create_simple_video(num_frames, height, width)
        recorder.buffer_stream(images, torch.tensor(cur_audio_array, dtype=torch.float32), images)
        i += interval
        time.sleep(max(0, interval - (time.time() - t0)))  # guard against negative sleep
    started = True
    i = 0
    while i < secs:
        t0 = time.time()
        start = int(i * sample_rate)
        end = int((i + interval) * sample_rate)
        cur_audio_array = torch.tensor(audio_array[start:end], dtype=torch.float32)
        num_frames = int(interval * fps)
        images = create_simple_video(num_frames, height, width)
        logger.info(f"{i}/{secs}s")
        if started:
            logger.warning("start pub_livestream !!!!!!!!!!!!!!!!!!!!!!!")
            started = False
        recorder.buffer_stream(images, cur_audio_array, images)
        i += interval
        time.sleep(max(0, interval - (time.time() - t0)))  # guard against negative sleep
    recorder.stop()
lightx2v/deploy/common/video_recorder.py
0 → 100644
View file @
e2778d0d
import os
import queue
import socket
import subprocess
import threading
import time
import traceback

import numpy as np
import torch
from loguru import logger
def pseudo_random(a, b):
    # Derive a pseudo-random integer in [a, b] from the fractional part of the current time
    x = str(time.time()).split(".")[1]
    y = int(float("0." + x) * 1000000)
    return a + (y % (b - a + 1))
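pseudo_random() derives the port from the clock's fractional digits, so collisions with ports already in use remain possible. A common alternative (a sketch of a different technique, not what this file does) is to bind to port 0 and let the OS pick a free port; note the port is released when the probe socket closes, so a narrower version of the same race remains:

import socket

def free_port():
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("127.0.0.1", 0))
        return s.getsockname()[1]  # OS-assigned free port

print(free_port())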
class VideoRecorder:
    def __init__(
        self,
        livestream_url: str,
        fps: float = 16.0,
    ):
        self.livestream_url = livestream_url
        self.fps = fps
        self.video_port = pseudo_random(32000, 40000)
        self.ffmpeg_log_level = os.getenv("FFMPEG_LOG_LEVEL", "error")
        logger.info(f"VideoRecorder video port: {self.video_port}, ffmpeg_log_level: {self.ffmpeg_log_level}")
        self.width = None
        self.height = None
        self.stoppable_t = None
        self.realtime = True
        # ffmpeg process for video data and push to livestream
        self.ffmpeg_process = None
        # TCP connection objects
        self.video_socket = None
        self.video_conn = None
        self.video_thread = None
        # queue for sending data to the ffmpeg process
        self.video_queue = queue.Queue()
    def init_sockets(self):
        # TCP socket for sending and receiving video data
        self.video_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self.video_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        self.video_socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
        self.video_socket.bind(("127.0.0.1", self.video_port))
        self.video_socket.listen(1)
    def video_worker(self):
        try:
            logger.info("Waiting for ffmpeg to connect to video socket...")
            self.video_conn, _ = self.video_socket.accept()
            logger.info(f"Video connection established from {self.video_conn.getpeername()}")
            fail_time, max_fail_time = 0, 10
            packet_secs = 1.0 / self.fps
            while True:
                try:
                    if self.video_queue is None:
                        break
                    data = self.video_queue.get()
                    if data is None:
                        logger.info("Video thread received stop signal")
                        break
                    # Convert each frame to uint8 [0, 255] and stream it as raw RGB bytes
                    for i in range(data.shape[0]):
                        t0 = time.time()
                        frame = (data[i] * 255).clamp(0, 255).to(torch.uint8).cpu().numpy()
                        try:
                            self.video_conn.send(frame.tobytes())
                        except (BrokenPipeError, OSError, ConnectionResetError) as e:
                            logger.info(f"Video connection closed, stopping worker: {type(e).__name__}")
                            return
                        if self.realtime:
                            time.sleep(max(0, packet_secs - (time.time() - t0)))
                    fail_time = 0
                except (BrokenPipeError, OSError, ConnectionResetError):
                    logger.info("Video connection closed during queue processing")
                    break
                except Exception:
                    logger.error(f"Send video data error: {traceback.format_exc()}")
                    fail_time += 1
                    if fail_time > max_fail_time:
                        logger.error(f"Video push worker thread failed {fail_time} times, stopping...")
                        break
        except Exception:
            logger.error(f"Video push worker thread error: {traceback.format_exc()}")
        finally:
            logger.info("Video push worker thread stopped")
    def start_ffmpeg_process_local(self):
        """Start an ffmpeg process that reads raw frames from our TCP socket and writes a local MP4 file"""
        ffmpeg_cmd = [
            "ffmpeg",
            "-fflags", "nobuffer",
            "-analyzeduration", "0",
            "-probesize", "32",
            "-flush_packets", "1",
            "-f", "rawvideo",
            "-pix_fmt", "rgb24",
            "-color_range", "pc",
            "-colorspace", "rgb",
            "-color_primaries", "bt709",
            "-color_trc", "iec61966-2-1",
            "-r", str(self.fps),
            "-s", f"{self.width}x{self.height}",
            "-i", f"tcp://127.0.0.1:{self.video_port}",
            "-b:v", "4M",
            "-c:v", "libx264",
            "-preset", "ultrafast",
            "-tune", "zerolatency",
            "-g", f"{self.fps}",
            "-pix_fmt", "yuv420p",
            "-f", "mp4",
            self.livestream_url,
            "-y",
            "-loglevel", self.ffmpeg_log_level,
        ]
        try:
            self.ffmpeg_process = subprocess.Popen(ffmpeg_cmd)
            logger.info(f"FFmpeg streaming started with PID: {self.ffmpeg_process.pid}")
            logger.info(f"FFmpeg command: {' '.join(ffmpeg_cmd)}")
        except Exception as e:
            logger.error(f"Failed to start FFmpeg: {e}")
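Because the input is headerless rawvideo, FFmpeg relies entirely on -pix_fmt, -s, and -r to slice the TCP byte stream into frames; any mismatch shears the picture. The implied data rate, for assumed dimensions:

width, height, fps = 640, 480, 16  # illustrative values
frame_bytes = width * height * 3  # rgb24: 3 bytes per pixel
print(frame_bytes, "bytes/frame,", frame_bytes * fps / 1e6, "MB/s")
# -> 921600 bytes/frame, 14.7456 MB/s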
    def start_ffmpeg_process_rtmp(self):
        """Start an ffmpeg process that reads raw frames from our TCP socket and pushes an FLV stream to the RTMP URL"""
        ffmpeg_cmd = [
            "ffmpeg",
            "-f", "rawvideo",
            "-re",
            "-pix_fmt", "rgb24",
            "-r", str(self.fps),
            "-s", f"{self.width}x{self.height}",
            "-i", f"tcp://127.0.0.1:{self.video_port}",
            "-b:v", "2M",
            "-c:v", "libx264",
            "-preset", "ultrafast",
            "-tune", "zerolatency",
            "-g", f"{self.fps}",
            "-pix_fmt", "yuv420p",
            "-f", "flv",
            self.livestream_url,
            "-y",
            "-loglevel", self.ffmpeg_log_level,
        ]
        try:
            self.ffmpeg_process = subprocess.Popen(ffmpeg_cmd)
            logger.info(f"FFmpeg streaming started with PID: {self.ffmpeg_process.pid}")
            logger.info(f"FFmpeg command: {' '.join(ffmpeg_cmd)}")
        except Exception as e:
            logger.error(f"Failed to start FFmpeg: {e}")
    def start_ffmpeg_process_whip(self):
        """Start an ffmpeg process that reads raw frames from our TCP socket and publishes via WHIP (WebRTC)"""
        ffmpeg_cmd = [
            "ffmpeg",
            "-re",
            "-fflags", "nobuffer",
            "-analyzeduration", "0",
            "-probesize", "32",
            "-flush_packets", "1",
            "-f", "rawvideo",
            "-pix_fmt", "rgb24",
            "-r", str(self.fps),
            "-s", f"{self.width}x{self.height}",
            "-i", f"tcp://127.0.0.1:{self.video_port}",
            "-b:v", "2M",
            "-c:v", "libx264",
            "-preset", "ultrafast",
            "-tune", "zerolatency",
            "-g", f"{self.fps}",
            "-pix_fmt", "yuv420p",
            "-threads", "1",
            "-bf", "0",
            "-f", "whip",
            self.livestream_url,
            "-y",
            "-loglevel", self.ffmpeg_log_level,
        ]
        try:
            self.ffmpeg_process = subprocess.Popen(ffmpeg_cmd)
            logger.info(f"FFmpeg streaming started with PID: {self.ffmpeg_process.pid}")
            logger.info(f"FFmpeg command: {' '.join(ffmpeg_cmd)}")
        except Exception as e:
            logger.error(f"Failed to start FFmpeg: {e}")
    def start(self, width: int, height: int):
        self.set_video_size(width, height)
        duration = 1.0
        self.pub_video(torch.zeros((int(self.fps * duration), height, width, 3), dtype=torch.float16))
        time.sleep(duration)
    def set_video_size(self, width: int, height: int):
        if self.width is not None and self.height is not None:
            assert self.width == width and self.height == height, "Video size already set"
            return
        self.width = width
        self.height = height
        self.init_sockets()
        if self.livestream_url.startswith("rtmp://"):
            self.start_ffmpeg_process_rtmp()
        elif self.livestream_url.startswith("http"):
            self.start_ffmpeg_process_whip()
        else:
            self.start_ffmpeg_process_local()
            self.realtime = False
        self.video_thread = threading.Thread(target=self.video_worker)
        self.video_thread.start()
    # Publish a ComfyUI image tensor to the livestream
    def pub_video(self, images: torch.Tensor):
        N, height, width, C = images.shape
        assert C == 3, "Input must be [N, H, W, C] with C=3"
        logger.info(f"Publishing video [{N}x{width}x{height}]")
        self.set_video_size(width, height)
        self.video_queue.put(images)
        logger.info(f"Published {N} frames")
        self.stoppable_t = time.time() + N / self.fps + 3
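stoppable_t marks the earliest wall-clock time at which stop() may begin teardown: the playout duration of everything just queued plus a fixed 3-second margin. With assumed numbers:

import time

N, fps = 160, 16.0  # illustrative: 160 queued frames at 16 fps
stoppable_t = time.time() + N / fps + 3  # ~10 s of playout + 3 s margin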
    def stop(self, wait=True):
        if wait and self.stoppable_t:
            t = self.stoppable_t - time.time()
            if t > 0:
                logger.warning(f"Waiting for {t} seconds to stop ...")
                time.sleep(t)
            self.stoppable_t = None
        # Send stop signals to queues
        if self.video_queue:
            self.video_queue.put(None)
        # Wait for threads to finish processing queued data (increased timeout)
        queue_timeout = 30  # Increased from 5s to 30s to allow sufficient time for large video frames
        if self.video_thread and self.video_thread.is_alive():
            self.video_thread.join(timeout=queue_timeout)
            if self.video_thread.is_alive():
                logger.error(f"Video push thread did not stop after {queue_timeout}s")
        # Shutdown connections to signal EOF to FFmpeg
        # shutdown(SHUT_WR) will wait for the send buffer to flush, no explicit sleep needed
        if self.video_conn:
            try:
                self.video_conn.getpeername()
                self.video_conn.shutdown(socket.SHUT_WR)
                logger.info("Video connection shutdown initiated")
            except OSError:
                # Connection already closed, skip shutdown
                pass
        if self.ffmpeg_process:
            is_local_file = not self.livestream_url.startswith(("rtmp://", "http"))
            # Local MP4 files need time to write the moov atom and finalize the container
            timeout_seconds = 30 if is_local_file else 10
            logger.info(f"Waiting for FFmpeg to finalize file (timeout={timeout_seconds}s, local_file={is_local_file})")
            logger.info(f"FFmpeg output: {self.livestream_url}")
            try:
                returncode = self.ffmpeg_process.wait(timeout=timeout_seconds)
                if returncode == 0:
                    logger.info(f"FFmpeg process exited successfully (exit code: {returncode})")
                else:
                    logger.warning(f"FFmpeg process exited with non-zero code: {returncode}")
            except subprocess.TimeoutExpired:
                logger.warning(f"FFmpeg process did not exit within {timeout_seconds}s, sending SIGTERM...")
                try:
                    self.ffmpeg_process.terminate()  # SIGTERM
                    returncode = self.ffmpeg_process.wait(timeout=5)
                    logger.warning(f"FFmpeg process terminated with SIGTERM (exit code: {returncode})")
                except subprocess.TimeoutExpired:
                    logger.error("FFmpeg process still running after SIGTERM, killing with SIGKILL...")
                    self.ffmpeg_process.kill()
                    self.ffmpeg_process.wait()  # Wait for kill to complete
                    logger.error("FFmpeg process killed with SIGKILL")
            finally:
                self.ffmpeg_process = None
        if self.video_conn:
            try:
                self.video_conn.close()
            except Exception as e:
                logger.debug(f"Error closing video connection: {e}")
            finally:
                self.video_conn = None
        if self.video_socket:
            try:
                self.video_socket.close()
            except Exception as e:
                logger.debug(f"Error closing video socket: {e}")
            finally:
                self.video_socket = None
        if self.video_queue:
            while self.video_queue.qsize() > 0:
                try:
                    self.video_queue.get_nowait()
                except:  # noqa
                    break
        self.video_queue = None
        logger.info("VideoRecorder stopped and resources cleaned up")
    def __del__(self):
        self.stop(wait=False)
def create_simple_video(frames=10, height=480, width=640):
    video_data = []
    for i in range(frames):
        frame = np.zeros((height, width, 3), dtype=np.float32)
        stripe_height = height // 8
        colors = [
            [1.0, 0.0, 0.0],  # red
            [0.0, 1.0, 0.0],  # green
            [0.0, 0.0, 1.0],  # blue
            [1.0, 1.0, 0.0],  # yellow
            [1.0, 0.0, 1.0],  # magenta
            [0.0, 1.0, 1.0],  # cyan
            [1.0, 1.0, 1.0],  # white
            [0.5, 0.5, 0.5],  # gray
        ]
        for j, color in enumerate(colors):
            start_y = j * stripe_height
            end_y = min((j + 1) * stripe_height, height)
            frame[start_y:end_y, :] = color
        offset = int((i / frames) * width)
        frame = np.roll(frame, offset, axis=1)
        frame = torch.tensor(frame, dtype=torch.float32)
        video_data.append(frame)
    return torch.stack(video_data, dim=0)
if __name__ == "__main__":
    fps = 16
    width = 640
    height = 480
    recorder = VideoRecorder(
        # livestream_url="rtmp://localhost/live/test",
        # livestream_url="https://reverse.st-oc-01.chielo.org/10.5.64.49:8000/rtc/v1/whip/?app=live&stream=ll_test_video&eip=127.0.0.1:8000",
        livestream_url="/path/to/output_video.mp4",
        fps=fps,
    )
    secs = 10  # 10-second video
    interval = 1
    for i in range(0, secs, interval):
        logger.info(f"{i}/{secs}s")
        num_frames = int(interval * fps)
        images = create_simple_video(num_frames, height, width)
        logger.info(f"images: {images.shape} {images.dtype} {images.min()} {images.max()}")
        recorder.pub_video(images)
        time.sleep(interval)
    recorder.stop()