OpenDAS / LightX2V · Commits

Commit a1ebc651, authored Dec 11, 2025 by xuwx1 ("updata lightx2v"), parent 5a4db490.
Pipeline #3149 canceled with stages. The commit touches 428 files; the 20 files shown below add 5295 lines and delete 0 (+5295 −0).
Files shown (20 of 428):

lightx2v/common/ops/norm/__init__.py                      +2    -0
lightx2v/common/ops/norm/layer_norm_weight.py             +220  -0
lightx2v/common/ops/norm/rms_norm_weight.py               +204  -0
lightx2v/common/ops/norm/triton_ops.py                    +900  -0
lightx2v/common/ops/tensor/__init__.py                    +1    -0
lightx2v/common/ops/tensor/tensor.py                      +110  -0
lightx2v/common/transformer_infer/transformer_infer.py    +46   -0
lightx2v/deploy/__init__.py                               +0    -0
lightx2v/deploy/common/__init__.py                        +0    -0
lightx2v/deploy/common/aliyun.py                          +81   -0
lightx2v/deploy/common/audio_separator.py                 +376  -0
lightx2v/deploy/common/face_detector.py                   +277  -0
lightx2v/deploy/common/pipeline.py                        +167  -0
lightx2v/deploy/common/podcasts.py                        +696  -0
lightx2v/deploy/common/utils.py                           +253  -0
lightx2v/deploy/common/va_controller.py                   +202  -0
lightx2v/deploy/common/va_reader.py                       +274  -0
lightx2v/deploy/common/va_reader_omni.py                  +508  -0
lightx2v/deploy/common/va_recorder.py                     +657  -0
lightx2v/deploy/common/va_recorder_x264.py                +321  -0
lightx2v/common/ops/norm/__init__.py (new file, mode 100644)

from .layer_norm_weight import *
from .rms_norm_weight import *
lightx2v/common/ops/norm/layer_norm_weight.py (new file, mode 100644)

import os
import re
from abc import ABCMeta, abstractmethod
from pathlib import Path

import torch
from safetensors import safe_open

from lightx2v.utils.envs import *
from lightx2v.utils.registry_factory import LN_WEIGHT_REGISTER
from lightx2v_platform.base.global_var import AI_DEVICE

from .triton_ops import norm_infer


class LNWeightTemplate(metaclass=ABCMeta):
    def __init__(self, weight_name=None, bias_name=None, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False, eps=1e-6):
        self.weight_name = weight_name
        self.bias_name = bias_name
        self.eps = eps
        self.create_cuda_buffer = create_cuda_buffer
        self.create_cpu_buffer = create_cpu_buffer
        self.lazy_load = lazy_load
        self.lazy_load_file = lazy_load_file
        self.is_post_adapter = is_post_adapter
        self.config = {}
        self.infer_dtype = GET_DTYPE()
        self.sensitive_layer_dtype = GET_SENSITIVE_DTYPE()

    def load(self, weight_dict):
        if self.create_cuda_buffer:
            self._load_cuda_buffers(weight_dict)
        elif self.create_cpu_buffer:
            self._load_cpu_pin_buffers()
        else:
            self._load_default_tensors(weight_dict)

    def _load_default_tensors(self, weight_dict):
        if not self.lazy_load and self.weight_name is not None:
            device = weight_dict[self.weight_name].device
            if device.type == "cpu":
                weight_tensor = weight_dict[self.weight_name]
                self.pin_weight = self._create_cpu_pin_tensor(weight_tensor)
                bias_tensor = weight_dict[self.bias_name] if self.bias_name is not None else None
                self.pin_bias = self._create_cpu_pin_tensor(bias_tensor) if bias_tensor is not None else None
                self.bias = None
                del weight_dict[self.weight_name]
            else:
                self.weight = weight_dict[self.weight_name]
                self.bias = weight_dict[self.bias_name] if self.bias_name is not None else None
        else:
            self.weight = None
            self.bias = None

    def _get_tensor(self, name, weight_dict=None, use_infer_dtype=False):
        if name is None:
            return None
        if self.lazy_load:
            if Path(self.lazy_load_file).is_file():
                lazy_load_file_path = self.lazy_load_file
            else:
                lazy_load_file_path = os.path.join(self.lazy_load_file, f"block_{name.split('.')[1]}.safetensors")
            with safe_open(lazy_load_file_path, framework="pt", device="cpu") as lazy_load_file:
                tensor = lazy_load_file.get_tensor(name)
                if use_infer_dtype:
                    tensor = tensor.to(self.infer_dtype)
        else:
            tensor = weight_dict[name]
        return tensor

    def _create_cpu_pin_tensor(self, tensor):
        if tensor is None:
            return None
        pin_tensor = torch.empty(tensor.shape, pin_memory=True, dtype=tensor.dtype)
        pin_tensor.copy_(tensor)
        del tensor
        return pin_tensor
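Reviewer note: the pinned-host-buffer pattern above is what makes the later `non_blocking` transfers effective. A `pin_memory=True` tensor lives in page-locked RAM, so a copy to the device can be a true asynchronous DMA transfer. A minimal, self-contained sketch of the idea (shapes and the explicit stream are illustrative, not part of this file):

import torch

pinned = torch.empty(1024, 1024, pin_memory=True)   # page-locked host buffer
pinned.copy_(torch.randn(1024, 1024))               # stage the weights on the host
stream = torch.cuda.Stream()
with torch.cuda.stream(stream):
    on_gpu = pinned.to("cuda", non_blocking=True)   # asynchronous H2D copy on this stream
stream.synchronize()                                # wait before reading on_gpu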
    def _load_cuda_buffers(self, weight_dict):
        weight_tensor = self._get_tensor(self.weight_name, weight_dict, use_infer_dtype=self.lazy_load)
        if weight_tensor is not None:
            self.weight_cuda_buffer = weight_tensor.to(AI_DEVICE)
        bias_tensor = self._get_tensor(self.bias_name, weight_dict, use_infer_dtype=self.lazy_load)
        if bias_tensor is not None:
            self.bias_cuda_buffer = bias_tensor.to(AI_DEVICE)

    def _load_cpu_pin_buffers(self):
        weight_tensor = self._get_tensor(self.weight_name, use_infer_dtype=True)
        if weight_tensor is not None:
            self.pin_weight = self._create_cpu_pin_tensor(weight_tensor)
        else:
            self.weight = None
        bias_tensor = self._get_tensor(self.bias_name, use_infer_dtype=True)
        if bias_tensor is not None:
            self.pin_bias = self._create_cpu_pin_tensor(bias_tensor)
        else:
            self.bias = None
            self.pin_bias = None

    @abstractmethod
    def apply(self, input_tensor):
        pass

    def set_config(self, config=None):
        if config is not None:
            self.config = config

    def to_cuda(self, non_blocking=False):
        if hasattr(self, "pin_weight") and self.pin_weight is not None:
            self.weight = self.pin_weight.to(AI_DEVICE, non_blocking=non_blocking)
        else:
            self.weight = None
        if hasattr(self, "pin_bias") and self.pin_bias is not None:
            self.bias = self.pin_bias.to(AI_DEVICE, non_blocking=non_blocking)
        else:
            self.bias = None

    def to_cpu(self, non_blocking=False):
        if hasattr(self, "pin_weight") and self.pin_weight is not None:
            self.weight = self.pin_weight.copy_(self.weight, non_blocking=non_blocking).cpu()
            if self.bias is not None:
                self.bias = self.pin_bias.copy_(self.bias, non_blocking=non_blocking).cpu()
        elif hasattr(self, "weight") and self.weight is not None:
            self.weight = self.weight.to("cpu", non_blocking=non_blocking)
            if hasattr(self, "bias") and self.bias is not None:
                self.bias = self.bias.to("cpu", non_blocking=non_blocking)

    def state_dict(self, destination=None):
        if destination is None:
            destination = {}
        if self.weight_name is not None:
            destination[self.weight_name] = self.pin_weight if hasattr(self, "pin_weight") else self.weight
        if self.bias_name is not None:
            destination[self.bias_name] = self.pin_bias if hasattr(self, "pin_bias") else self.bias
        return destination

    def load_state_dict(self, destination, block_index, adapter_block_index=None):
        if self.weight_name is not None:
            if self.is_post_adapter:
                assert adapter_block_index is not None
                weight_name = re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", self.weight_name, count=1)
            else:
                weight_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.weight_name, count=1)
            if weight_name not in destination:
                self.weight = None
                return
            self.weight = self.weight_cuda_buffer.copy_(destination[weight_name], non_blocking=True)
        else:
            self.weight = None
        if self.bias_name is not None:
            bias_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.bias_name, count=1)
            self.bias = self.bias_cuda_buffer.copy_(destination[bias_name], non_blocking=True)
        else:
            self.bias = None
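Reviewer note: the `re.sub(r"\.\d+", ..., count=1)` calls above retarget a template parameter name at its first numeric component, which is how one buffer object is reused across transformer blocks. A quick illustration (the weight name is made up):

import re

name = "blocks.0.norm1.weight"   # hypothetical template name
block_index = 5
print(re.sub(r"\.\d+", lambda m: f".{block_index}", name, count=1))
# -> "blocks.5.norm1.weight"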
    def load_state_dict_from_disk(self, block_index, adapter_block_index=None):
        if self.weight_name is not None:
            if Path(self.lazy_load_file).is_file():
                lazy_load_file_path = self.lazy_load_file
            else:
                lazy_load_file_path = os.path.join(self.lazy_load_file, f"block_{block_index}.safetensors")
            if self.is_post_adapter:
                self.weight_name = re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", self.weight_name, count=1)
            else:
                self.weight_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.weight_name, count=1)
            with safe_open(lazy_load_file_path, framework="pt", device="cpu") as lazy_load_file:
                weight_tensor = lazy_load_file.get_tensor(self.weight_name).to(self.infer_dtype)
                self.pin_weight = self.pin_weight.copy_(weight_tensor)
                del weight_tensor
        if self.bias_name is not None:
            if Path(self.lazy_load_file).is_file():
                lazy_load_file_path = self.lazy_load_file
            else:
                lazy_load_file_path = os.path.join(self.lazy_load_file, f"block_{block_index}.safetensors")
            if self.is_post_adapter:
                assert adapter_block_index is not None
                self.bias_name = re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", self.bias_name, count=1)
            else:
                self.bias_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.bias_name, count=1)
            with safe_open(lazy_load_file_path, framework="pt", device="cpu") as lazy_load_file:
                bias_tensor = lazy_load_file.get_tensor(self.bias_name).to(self.infer_dtype)
                self.pin_bias.copy_(bias_tensor)
                del bias_tensor


@LN_WEIGHT_REGISTER("Default")
class LNWeight(LNWeightTemplate):
    def __init__(self, weight_name=None, bias_name=None, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False, eps=1e-6):
        super().__init__(weight_name, bias_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter, eps)

    def apply(self, input_tensor):
        if self.sensitive_layer_dtype != self.infer_dtype:
            input_tensor = torch.nn.functional.layer_norm(
                input_tensor.float(),
                (input_tensor.shape[-1],),
                self.weight,
                self.bias,
                self.eps,
            ).to(self.infer_dtype)
        else:
            input_tensor = torch.nn.functional.layer_norm(input_tensor, (input_tensor.shape[-1],), self.weight, self.bias, self.eps)
        return input_tensor


@LN_WEIGHT_REGISTER("Triton")
class LNWeight(LNWeightTemplate):
    def __init__(self, weight_name=None, bias_name=None, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False, eps=1e-6):
        super().__init__(weight_name, bias_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter, eps)

    def apply(self, input_tensor):
        input_tensor = norm_infer(input_tensor, self.weight, self.bias, self.eps)
        return input_tensor
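Reviewer note: both registrations above define a class named `LNWeight`; the registry decorator apparently captures each class object under its own key ("Default", "Triton"), with the second definition shadowing the first at module level. As a rough sketch of how such a wrapper might be driven (the weight names, shapes, direct construction of the `Default` class, and the assumption that the env-configured dtypes match the float32 input are all illustrative, not a documented API):

import torch

weights = {
    "blocks.0.norm1.weight": torch.ones(128),   # hypothetical checkpoint entries
    "blocks.0.norm1.bias": torch.zeros(128),
}
ln = LNWeight(weight_name="blocks.0.norm1.weight", bias_name="blocks.0.norm1.bias")
ln.load(weights)   # CPU tensors are staged into pinned host buffers
ln.to_cuda()       # pinned buffers are moved onto the device
y = ln.apply(torch.randn(4, 128, device="cuda"))  # plain F.layer_norm under the hood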
lightx2v/common/ops/norm/rms_norm_weight.py (new file, mode 100644)

import os
import re
from abc import ABCMeta, abstractmethod
from pathlib import Path

import torch
from safetensors import safe_open

from lightx2v.utils.envs import *
from lightx2v.utils.registry_factory import RMS_WEIGHT_REGISTER
from lightx2v_platform.base.global_var import AI_DEVICE

try:
    import sgl_kernel
except ImportError:
    sgl_kernel = None


class RMSWeightTemplate(metaclass=ABCMeta):
    def __init__(self, weight_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False, eps=1e-6):
        self.weight_name = weight_name
        self.eps = eps
        self.create_cuda_buffer = create_cuda_buffer
        self.create_cpu_buffer = create_cpu_buffer
        self.lazy_load = lazy_load
        self.lazy_load_file = lazy_load_file
        self.is_post_adapter = is_post_adapter
        self.infer_dtype = GET_DTYPE()
        self.sensitive_layer_dtype = GET_SENSITIVE_DTYPE()
        self.config = {}

    def load(self, weight_dict):
        if self.create_cuda_buffer:
            self._load_cuda_buffer(weight_dict)
        elif self.create_cpu_buffer:
            self._load_cpu_pin_buffer()
        else:
            self._load_default_tensors(weight_dict)

    def _load_default_tensors(self, weight_dict):
        if not self.lazy_load:
            device = weight_dict[self.weight_name].device
            if device.type == "cpu":
                weight_tensor = weight_dict[self.weight_name]
                self.pin_weight = self._create_cpu_pin_weight(weight_tensor)
                del weight_dict[self.weight_name]
            else:
                self.weight = weight_dict[self.weight_name]

    def _get_weight_tensor(self, weight_dict=None, use_infer_dtype=False):
        if self.lazy_load:
            if Path(self.lazy_load_file).is_file():
                lazy_load_file_path = self.lazy_load_file
            else:
                lazy_load_file_path = os.path.join(self.lazy_load_file, f"block_{self.weight_name.split('.')[1]}.safetensors")
            with safe_open(lazy_load_file_path, framework="pt", device="cpu") as lazy_load_file:
                tensor = lazy_load_file.get_tensor(self.weight_name)
                if use_infer_dtype:
                    tensor = tensor.to(self.infer_dtype)
        else:
            tensor = weight_dict[self.weight_name]
        return tensor

    def _create_cpu_pin_weight(self, tensor):
        pin_tensor = torch.empty(tensor.shape, pin_memory=True, dtype=tensor.dtype)
        pin_tensor.copy_(tensor)
        del tensor
        return pin_tensor

    def _load_cuda_buffer(self, weight_dict):
        weight_tensor = self._get_weight_tensor(weight_dict, use_infer_dtype=self.lazy_load)
        self.weight_cuda_buffer = weight_tensor.to(AI_DEVICE)

    def _load_cpu_pin_buffer(self):
        weight_tensor = self._get_weight_tensor(use_infer_dtype=True)
        self.pin_weight = self._create_cpu_pin_weight(weight_tensor)

    @abstractmethod
    def apply(self, input_tensor):
        pass

    def set_config(self, config=None):
        if config is not None:
            self.config = config

    def to_cuda(self, non_blocking=False):
        self.weight = self.pin_weight.to(AI_DEVICE, non_blocking=non_blocking)

    def to_cpu(self, non_blocking=False):
        if hasattr(self, "pin_weight"):
            self.weight = self.pin_weight.copy_(self.weight, non_blocking=non_blocking).cpu()
        else:
            self.weight = self.weight.to("cpu", non_blocking=non_blocking)

    def state_dict(self, destination=None):
        if destination is None:
            destination = {}
        destination[self.weight_name] = self.pin_weight if hasattr(self, "pin_weight") else self.weight
        return destination

    def load_state_dict(self, destination, block_index, adapter_block_index=None):
        if self.is_post_adapter:
            assert adapter_block_index is not None
            weight_name = re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", self.weight_name, count=1)
        else:
            weight_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.weight_name, count=1)
        if weight_name not in destination:
            self.weight = None
            return
        self.weight = self.weight_cuda_buffer.copy_(destination[weight_name], non_blocking=True)

    def load_state_dict_from_disk(self, block_index, adapter_block_index=None):
        if self.is_post_adapter:
            self.weight_name = re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", self.weight_name, count=1)
        else:
            self.weight_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.weight_name, count=1)
        if Path(self.lazy_load_file).is_file():
            lazy_load_file_path = self.lazy_load_file
        else:
            lazy_load_file_path = os.path.join(self.lazy_load_file, f"block_{block_index}.safetensors")
        with safe_open(lazy_load_file_path, framework="pt", device="cpu") as lazy_load_file:
            weight_tensor = lazy_load_file.get_tensor(self.weight_name).to(self.infer_dtype)
            self.pin_weight = self.pin_weight.copy_(weight_tensor)
            del weight_tensor


@RMS_WEIGHT_REGISTER("Default")
class RMSWeight(RMSWeightTemplate):
    def __init__(self, weight_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False, eps=1e-6):
        super().__init__(weight_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter, eps)

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)

    def apply(self, input_tensor):
        if GET_SENSITIVE_DTYPE() != GET_DTYPE():
            input_tensor = self._norm(input_tensor).type_as(input_tensor) * self.weight
        else:
            input_tensor = self._norm(input_tensor.float()).type_as(input_tensor) * self.weight
        return input_tensor
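Reviewer note: for reference, `_norm` together with the gain applied in `apply` is the standard RMSNorm over the last dimension of size $d$:

\[
\operatorname{RMSNorm}(x)_i = \frac{x_i}{\sqrt{\tfrac{1}{d}\sum_{j=1}^{d} x_j^2 + \varepsilon}} \cdot w_i ,
\]

i.e. normalization by the root-mean-square with no mean subtraction and no bias, followed by the learned elementwise gain $w$.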
@RMS_WEIGHT_REGISTER("sgl-kernel")
class RMSWeightSgl(RMSWeight):
    def __init__(
        self,
        weight_name,
        create_cuda_buffer=False,
        create_cpu_buffer=False,
        lazy_load=False,
        lazy_load_file=None,
        is_post_adapter=False,
        eps=1e-6,
    ):
        super().__init__(weight_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter, eps)

    def apply(self, input_tensor):
        if sgl_kernel is not None and self.sensitive_layer_dtype == self.infer_dtype:
            input_tensor = input_tensor.contiguous()
            orig_shape = input_tensor.shape
            input_tensor = input_tensor.view(-1, orig_shape[-1])
            input_tensor = sgl_kernel.rmsnorm(input_tensor, self.weight, self.eps).view(orig_shape)
        else:
            # sgl_kernel is not available or dtype != torch.bfloat16/float16, fall back to the default implementation
            if self.sensitive_layer_dtype != self.infer_dtype:
                input_tensor = input_tensor * torch.rsqrt(input_tensor.float().pow(2).mean(-1, keepdim=True) + self.eps).to(self.infer_dtype)
                input_tensor = (input_tensor * self.weight).to(self.infer_dtype)
            else:
                input_tensor = input_tensor * torch.rsqrt(input_tensor.pow(2).mean(-1, keepdim=True) + self.eps)
                input_tensor = input_tensor * self.weight
        return input_tensor


@RMS_WEIGHT_REGISTER("fp32_variance")
class RMSWeightFP32(RMSWeight):
    def __init__(self, weight_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False, eps=1e-6):
        super().__init__(weight_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter, eps)

    def apply(self, input_tensor):
        input_dtype = input_tensor.dtype
        variance = input_tensor.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = input_tensor * torch.rsqrt(variance + self.eps)
        if self.weight.dtype in [torch.float16, torch.bfloat16]:
            hidden_states = hidden_states.to(self.weight.dtype)
        if self.weight is not None:
            hidden_states = hidden_states * self.weight
        hidden_states = hidden_states.to(input_dtype)
        return hidden_states


@RMS_WEIGHT_REGISTER("self_forcing")
class RMSWeightSF(RMSWeight):
    def __init__(self, weight_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False, eps=1e-6):
        super().__init__(weight_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter, eps)

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)

    def apply(self, x):
        return self._norm(x.float()).type_as(x) * self.weight
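Reviewer note: a hypothetical smoke test for the default variant (the weight name is a placeholder, and `apply` is exercised by assigning `weight` directly instead of going through `load`):

import torch

rms = RMSWeight("blocks.0.attn.norm_q.weight")   # placeholder name
rms.weight = torch.ones(64, device="cuda")
x = torch.randn(2, 64, device="cuda")
y = rms.apply(x)
# With a unit gain, each row of y should have RMS close to 1.
print(y.pow(2).mean(-1).sqrt())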
lightx2v/common/ops/norm/triton_ops.py (new file, mode 100644)

# Copied and adapted from: https://github.com/hao-ai-lab/FastVideo & https://github.com/sgl-project/sglang
# TODO: for temporary usage, expecting a refactor
from typing import Optional

import torch
import triton  # type: ignore
import triton.language as tl  # type: ignore
from torch import Tensor


@triton.autotune(
    configs=[
        triton.Config({"BLOCK_N": 64}, num_warps=2),
        triton.Config({"BLOCK_N": 128}, num_warps=4),
        triton.Config({"BLOCK_N": 256}, num_warps=4),
        triton.Config({"BLOCK_N": 512}, num_warps=4),
        triton.Config({"BLOCK_N": 1024}, num_warps=8),
    ],
    key=["inner_dim"],
)
@triton.jit
def _fused_scale_shift_4d_kernel(
    output_ptr,
    normalized_ptr,
    scale_ptr,
    shift_ptr,
    rows,
    inner_dim,
    seq_len,
    num_frames,
    frame_seqlen,
    BLOCK_N: tl.constexpr,
):
    pid_row = tl.program_id(0)
    pid_col = tl.program_id(1)
    col_offsets = pid_col * BLOCK_N + tl.arange(0, BLOCK_N)
    mask = col_offsets < inner_dim
    # Pointers for normalized and output
    row_base = pid_row * inner_dim
    norm_ptrs = normalized_ptr + row_base + col_offsets
    out_ptrs = output_ptr + row_base + col_offsets
    # Pointers for scale and shift for 4D
    b_idx = pid_row // seq_len
    t_idx = pid_row % seq_len
    frame_idx_in_batch = t_idx // frame_seqlen
    scale_row_idx = b_idx * num_frames + frame_idx_in_batch
    scale_ptrs = scale_ptr + scale_row_idx * inner_dim + col_offsets
    shift_ptrs = shift_ptr + scale_row_idx * inner_dim + col_offsets
    normalized = tl.load(norm_ptrs, mask=mask, other=0.0)
    scale = tl.load(scale_ptrs, mask=mask, other=0.0)
    shift = tl.load(shift_ptrs, mask=mask, other=0.0)
    one = tl.full([BLOCK_N], 1.0, dtype=scale.dtype)
    output = normalized * (one + scale) + shift
    tl.store(out_ptrs, output, mask=mask)
@triton.jit
def fuse_scale_shift_kernel_blc_opt(
    x_ptr,
    shift_ptr,
    scale_ptr,
    y_ptr,
    B,
    L,
    C,
    stride_x_b,
    stride_x_l,
    stride_x_c,
    stride_s_b,
    stride_s_l,
    stride_s_c,
    stride_sc_b,
    stride_sc_l,
    stride_sc_c,
    SCALE_IS_SCALAR: tl.constexpr,
    SHIFT_IS_SCALAR: tl.constexpr,
    BLOCK_L: tl.constexpr,
    BLOCK_C: tl.constexpr,
):
    pid_l = tl.program_id(0)
    pid_c = tl.program_id(1)
    pid_b = tl.program_id(2)
    l_offsets = pid_l * BLOCK_L + tl.arange(0, BLOCK_L)
    c_offsets = pid_c * BLOCK_C + tl.arange(0, BLOCK_C)
    mask_l = l_offsets < L
    mask_c = c_offsets < C
    mask = mask_l[:, None] & mask_c[None, :]
    x_off = pid_b * stride_x_b + l_offsets[:, None] * stride_x_l + c_offsets[None, :] * stride_x_c
    x = tl.load(x_ptr + x_off, mask=mask, other=0)
    if SHIFT_IS_SCALAR:
        shift_val = tl.load(shift_ptr)
        shift = tl.full((BLOCK_L, BLOCK_C), shift_val, dtype=shift_val.dtype)
    else:
        s_off = pid_b * stride_s_b + l_offsets[:, None] * stride_s_l + c_offsets[None, :] * stride_s_c
        shift = tl.load(shift_ptr + s_off, mask=mask, other=0)
    if SCALE_IS_SCALAR:
        scale_val = tl.load(scale_ptr)
        scale = tl.full((BLOCK_L, BLOCK_C), scale_val, dtype=scale_val.dtype)
    else:
        sc_off = pid_b * stride_sc_b + l_offsets[:, None] * stride_sc_l + c_offsets[None, :] * stride_sc_c
        scale = tl.load(scale_ptr + sc_off, mask=mask, other=0)
    y = x * (1 + scale) + shift
    tl.store(y_ptr + x_off, y, mask=mask)
def fuse_scale_shift_kernel(
    x: torch.Tensor,
    scale: torch.Tensor,
    shift: torch.Tensor,
    block_l: int = 128,
    block_c: int = 128,
):
    assert x.is_cuda and scale.is_cuda
    assert x.is_contiguous()
    B, L, C = x.shape
    output = torch.empty_like(x)
    if scale.dim() == 4:
        # scale/shift: [B, F, 1, C]
        rows = B * L
        x_2d = x.view(rows, C)
        output_2d = output.view(rows, C)
        grid = lambda META: (rows, triton.cdiv(C, META["BLOCK_N"]))  # noqa
        num_frames = scale.shape[1]
        assert L % num_frames == 0, "seq_len must be divisible by num_frames for 4D scale/shift"
        frame_seqlen = L // num_frames
        # Compact [B, F, C] without the singleton dim into [B*F, C]
        scale_reshaped = scale.squeeze(2).reshape(-1, C).contiguous()
        shift_reshaped = shift.squeeze(2).reshape(-1, C).contiguous()
        _fused_scale_shift_4d_kernel[grid](
            output_2d,
            x_2d,
            scale_reshaped,
            shift_reshaped,
            rows,
            C,
            L,
            num_frames,
            frame_seqlen,
        )
    else:
        # 2D: [B, C] or [1, C] -> treat as [B, 1, C] and broadcast over L
        # 3D: [B, L, C] (or broadcastable variants like [B, 1, C], [1, L, C], [1, 1, C])
        # Also support scalar (0D or 1-element)
        if scale.dim() == 0 or (scale.dim() == 1 and scale.numel() == 1):
            scale_blc = scale.reshape(1)
        elif scale.dim() == 2:
            scale_blc = scale[:, None, :]
        elif scale.dim() == 3:
            scale_blc = scale
        else:
            raise ValueError("scale must be 0D/1D(1)/2D/3D or 4D")
        if shift.dim() == 0 or (shift.dim() == 1 and shift.numel() == 1):
            shift_blc = shift.reshape(1)
        elif shift.dim() == 2:
            shift_blc = shift[:, None, :]
        elif shift.dim() == 3:
            shift_blc = shift
        else:
            # broadcast later via expand if possible
            shift_blc = shift
        need_scale_scalar = scale_blc.dim() == 1 and scale_blc.numel() == 1
        need_shift_scalar = shift_blc.dim() == 1 and shift_blc.numel() == 1
        if not need_scale_scalar:
            scale_exp = scale_blc.expand(B, L, C)
            s_sb, s_sl, s_sc = scale_exp.stride()
        else:
            s_sb = s_sl = s_sc = 0
        if not need_shift_scalar:
            shift_exp = shift_blc.expand(B, L, C)
            sh_sb, sh_sl, sh_sc = shift_exp.stride()
        else:
            sh_sb = sh_sl = sh_sc = 0
        # If both scalars and both zero, copy fast-path
        if need_scale_scalar and need_shift_scalar:
            if (scale_blc.abs().max() == 0) and (shift_blc.abs().max() == 0):
                output.copy_(x)
                return output
        grid = (triton.cdiv(L, block_l), triton.cdiv(C, block_c), B)
        fuse_scale_shift_kernel_blc_opt[grid](
            x,
            shift_blc if need_shift_scalar else shift_exp,
            scale_blc if need_scale_scalar else scale_exp,
            output,
            B,
            L,
            C,
            x.stride(0),
            x.stride(1),
            x.stride(2),
            sh_sb,
            sh_sl,
            sh_sc,
            s_sb,
            s_sl,
            s_sc,
            SCALE_IS_SCALAR=need_scale_scalar,
            SHIFT_IS_SCALAR=need_shift_scalar,
            BLOCK_L=block_l,
            BLOCK_C=block_c,
            num_warps=4,
            num_stages=2,
        )
    return output
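Reviewer note: this wrapper fuses the adaLN-style modulation `x * (1 + scale) + shift`. A small sketch with illustrative shapes; the fused result should match the plain PyTorch expression up to rounding:

x = torch.randn(2, 16, 128, device="cuda", dtype=torch.float16)
scale = torch.randn(2, 1, 128, device="cuda", dtype=torch.float16)   # broadcast over L
shift = torch.randn(2, 1, 128, device="cuda", dtype=torch.float16)
y = fuse_scale_shift_kernel(x, scale, shift)
torch.testing.assert_close(y, x * (1 + scale) + shift, rtol=1e-3, atol=1e-3)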
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_HS_HALF": 32}, num_warps=2),
        triton.Config({"BLOCK_HS_HALF": 64}, num_warps=4),
        triton.Config({"BLOCK_HS_HALF": 128}, num_warps=4),
        triton.Config({"BLOCK_HS_HALF": 256}, num_warps=8),
    ],
    key=["head_size", "interleaved"],
)
@triton.jit
def _rotary_embedding_kernel(
    output_ptr,
    x_ptr,
    cos_ptr,
    sin_ptr,
    num_heads,
    head_size,
    num_tokens,
    stride_x_row,
    stride_cos_row,
    stride_sin_row,
    interleaved: tl.constexpr,
    BLOCK_HS_HALF: tl.constexpr,
):
    row_idx = tl.program_id(0)
    token_idx = (row_idx // num_heads) % num_tokens
    x_row_ptr = x_ptr + row_idx * stride_x_row
    cos_row_ptr = cos_ptr + token_idx * stride_cos_row
    sin_row_ptr = sin_ptr + token_idx * stride_sin_row
    output_row_ptr = output_ptr + row_idx * stride_x_row
    # half size for x1 and x2
    head_size_half = head_size // 2
    for block_start in range(0, head_size_half, BLOCK_HS_HALF):
        offsets_half = block_start + tl.arange(0, BLOCK_HS_HALF)
        mask = offsets_half < head_size_half
        cos_vals = tl.load(cos_row_ptr + offsets_half, mask=mask, other=0.0)
        sin_vals = tl.load(sin_row_ptr + offsets_half, mask=mask, other=0.0)
        offsets_x1 = 2 * offsets_half
        offsets_x2 = 2 * offsets_half + 1
        x1_vals = tl.load(x_row_ptr + offsets_x1, mask=mask, other=0.0)
        x2_vals = tl.load(x_row_ptr + offsets_x2, mask=mask, other=0.0)
        x1_fp32 = x1_vals.to(tl.float32)
        x2_fp32 = x2_vals.to(tl.float32)
        cos_fp32 = cos_vals.to(tl.float32)
        sin_fp32 = sin_vals.to(tl.float32)
        o1_vals = tl.fma(-x2_fp32, sin_fp32, x1_fp32 * cos_fp32)
        o2_vals = tl.fma(x1_fp32, sin_fp32, x2_fp32 * cos_fp32)
        tl.store(output_row_ptr + offsets_x1, o1_vals.to(x1_vals.dtype), mask=mask)
        tl.store(output_row_ptr + offsets_x2, o2_vals.to(x2_vals.dtype), mask=mask)
def apply_rotary_embedding(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, interleaved: bool = False) -> torch.Tensor:
    output = torch.empty_like(x)
    if x.dim() > 3:
        bsz, num_tokens, num_heads, head_size = x.shape
    else:
        num_tokens, num_heads, head_size = x.shape
        bsz = 1
    assert head_size % 2 == 0, "head_size must be divisible by 2"
    x_reshaped = x.view(-1, head_size)
    output_reshaped = output.view(-1, head_size)
    # num_tokens per head, 1 token per block
    grid = (bsz * num_tokens * num_heads,)
    if interleaved and cos.shape[-1] == head_size:
        cos = cos[..., ::2].contiguous()
        sin = sin[..., ::2].contiguous()
    else:
        cos = cos.contiguous()
        sin = sin.contiguous()
    _rotary_embedding_kernel[grid](
        output_reshaped,
        x_reshaped,
        cos,
        sin,
        num_heads,
        head_size,
        num_tokens,
        x_reshaped.stride(0),
        cos.stride(0),
        sin.stride(0),
        interleaved,
    )
    return output
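Reviewer note: the kernel applies the interleaved-pair form of rotary position embedding, in fp32 and cast back. For each adjacent pair $(x_{2i}, x_{2i+1})$ within a head:

\[
o_{2i} = x_{2i}\cos\theta_i - x_{2i+1}\sin\theta_i, \qquad
o_{2i+1} = x_{2i}\sin\theta_i + x_{2i+1}\cos\theta_i .
\]

A minimal call, assuming `cos`/`sin` of shape `[num_tokens, head_size // 2]` (the shape is inferred from the pointer arithmetic above, not documented):

x = torch.randn(16, 8, 64, device="cuda", dtype=torch.float16)   # [tokens, heads, head_size]
theta = torch.rand(16, 32, device="cuda")                        # one angle per pair
out = apply_rotary_embedding(x, theta.cos(), theta.sin())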
# RMSNorm-fp32
def maybe_contiguous_lastdim(x):
    return x.contiguous() if x is not None and x.stride(-1) != 1 else x


def maybe_contiguous(x):
    return x.contiguous() if x is not None else None


def triton_autotune_configs():
    if not torch.cuda.is_available():
        return []
    # Return configs with a valid warp count for the current device
    configs = []
    # Maximum threads per block is architecture-dependent in theory, but in reality all are 1024
    max_threads_per_block = 1024
    # Default to warp size 32 if not defined by device
    warp_size = getattr(torch.cuda.get_device_properties(torch.cuda.current_device()), "warp_size", 32)
    # Autotune for warp counts which are powers of 2 and do not exceed thread per block limit
    return [triton.Config({}, num_warps=warp_count) for warp_count in [1, 2, 4, 8, 16, 32] if warp_count * warp_size <= max_threads_per_block]
    # return [triton.Config({}, num_warps=8)]
# Copied from flash-attn
@triton.autotune(
    configs=triton_autotune_configs(),
    key=[
        "N",
        "HAS_RESIDUAL",
        "STORE_RESIDUAL_OUT",
        "IS_RMS_NORM",
        "HAS_BIAS",
        "HAS_WEIGHT",
        "HAS_X1",
        "HAS_W1",
        "HAS_B1",
    ],
)
# torch compile doesn't like triton.heuristics, so we set these manually when calling the kernel
# @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
# @triton.heuristics({"HAS_RESIDUAL": lambda args: args["RESIDUAL"] is not None})
# @triton.heuristics({"HAS_X1": lambda args: args["X1"] is not None})
# @triton.heuristics({"HAS_W1": lambda args: args["W1"] is not None})
# @triton.heuristics({"HAS_B1": lambda args: args["B1"] is not None})
@triton.jit
def _layer_norm_fwd_1pass_kernel(
    X,  # pointer to the input
    Y,  # pointer to the output
    W,  # pointer to the weights
    B,  # pointer to the biases
    RESIDUAL,  # pointer to the residual
    X1,
    W1,
    B1,
    Y1,
    RESIDUAL_OUT,  # pointer to the residual
    ROWSCALE,
    SEEDS,  # Dropout seeds for each row
    DROPOUT_MASK,
    DROPOUT_MASK1,
    Mean,  # pointer to the mean
    Rstd,  # pointer to the 1/std
    stride_x_row,  # how much to increase the pointer when moving by 1 row
    stride_y_row,
    stride_res_row,
    stride_res_out_row,
    stride_x1_row,
    stride_y1_row,
    M,  # number of rows in X
    N,  # number of columns in X
    eps,  # epsilon to avoid division by zero
    dropout_p,  # Dropout probability
    zero_centered_weight,  # If true, add 1.0 to the weight
    IS_RMS_NORM: tl.constexpr,
    BLOCK_N: tl.constexpr,
    HAS_RESIDUAL: tl.constexpr,
    STORE_RESIDUAL_OUT: tl.constexpr,
    HAS_WEIGHT: tl.constexpr,
    HAS_BIAS: tl.constexpr,
    HAS_DROPOUT: tl.constexpr,
    STORE_DROPOUT_MASK: tl.constexpr,
    HAS_ROWSCALE: tl.constexpr,
    HAS_X1: tl.constexpr,
    HAS_W1: tl.constexpr,
    HAS_B1: tl.constexpr,
):
    # Map the program id to the row of X and Y it should compute.
    row = tl.program_id(0)
    X += row * stride_x_row
    Y += row * stride_y_row
    if HAS_RESIDUAL:
        RESIDUAL += row * stride_res_row
    if STORE_RESIDUAL_OUT:
        RESIDUAL_OUT += row * stride_res_out_row
    if HAS_X1:
        X1 += row * stride_x1_row
    if HAS_W1:
        Y1 += row * stride_y1_row
    # Compute mean and variance
    cols = tl.arange(0, BLOCK_N)
    x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
    if HAS_ROWSCALE:
        rowscale = tl.load(ROWSCALE + row).to(tl.float32)
        x *= rowscale
    if HAS_DROPOUT:
        # Compute dropout mask
        # 7 rounds is good enough, and reduces register pressure
        keep_mask = tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
        x = tl.where(keep_mask, x / (1.0 - dropout_p), 0.0)
        if STORE_DROPOUT_MASK:
            tl.store(DROPOUT_MASK + row * N + cols, keep_mask, mask=cols < N)
    if HAS_X1:
        x1 = tl.load(X1 + cols, mask=cols < N, other=0.0).to(tl.float32)
        if HAS_ROWSCALE:
            rowscale = tl.load(ROWSCALE + M + row).to(tl.float32)
            x1 *= rowscale
        if HAS_DROPOUT:
            # Compute dropout mask
            # 7 rounds is good enough, and reduces register pressure
            keep_mask = tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
            x1 = tl.where(keep_mask, x1 / (1.0 - dropout_p), 0.0)
            if STORE_DROPOUT_MASK:
                tl.store(DROPOUT_MASK1 + row * N + cols, keep_mask, mask=cols < N)
        x += x1
    if HAS_RESIDUAL:
        residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32)
        x += residual
    if STORE_RESIDUAL_OUT:
        tl.store(RESIDUAL_OUT + cols, x, mask=cols < N)
    if not IS_RMS_NORM:
        mean = tl.sum(x, axis=0) / N
        tl.store(Mean + row, mean)
        xbar = tl.where(cols < N, x - mean, 0.0)
        var = tl.sum(xbar * xbar, axis=0) / N
    else:
        xbar = tl.where(cols < N, x, 0.0)
        var = tl.sum(xbar * xbar, axis=0) / N
    rstd = 1 / tl.sqrt(var + eps)
    tl.store(Rstd + row, rstd)
    # Normalize and apply linear transformation
    mask = cols < N
    if HAS_WEIGHT:
        w = tl.load(W + cols, mask=mask).to(tl.float32)
        if zero_centered_weight:
            w += 1.0
    if HAS_BIAS:
        b = tl.load(B + cols, mask=mask).to(tl.float32)
    x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
    if HAS_WEIGHT:
        y = x_hat * w + b if HAS_BIAS else x_hat * w
    else:
        y = x_hat + b if HAS_BIAS else x_hat
    # Write output
    tl.store(Y + cols, y, mask=mask)
    if HAS_W1:
        w1 = tl.load(W1 + cols, mask=mask).to(tl.float32)
        if zero_centered_weight:
            w1 += 1.0
        if HAS_B1:
            b1 = tl.load(B1 + cols, mask=mask).to(tl.float32)
        y1 = x_hat * w1 + b1 if HAS_B1 else x_hat * w1
        tl.store(Y1 + cols, y1, mask=mask)
def _layer_norm_fwd(
    x: Tensor,
    weight: Tensor,
    bias: Tensor,
    eps: float,
    residual: Optional[Tensor] = None,
    x1: Optional[Tensor] = None,
    weight1: Optional[Tensor] = None,
    bias1: Optional[Tensor] = None,
    dropout_p: float = 0.0,
    rowscale: Optional[Tensor] = None,
    out_dtype: Optional[torch.dtype] = None,
    residual_dtype: Optional[torch.dtype] = None,
    zero_centered_weight: bool = False,
    is_rms_norm: bool = False,
    return_dropout_mask: bool = False,
    out: Optional[Tensor] = None,
    residual_out: Optional[Tensor] = None,
) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor):
    # Need to wrap to handle the case where residual_out is an alias of x, which makes torch.library
    # and torch.compile unhappy. Also allocate memory for out and residual_out if they are None
    # so that _layer_norm_fwd_impl doesn't have to return them.
    if out is None:
        out = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)
    if residual is not None:
        residual_dtype = residual.dtype
    if residual_out is None and (residual is not None or (residual_dtype is not None and residual_dtype != x.dtype) or dropout_p > 0.0 or rowscale is not None or x1 is not None):
        residual_out = torch.empty_like(x, dtype=residual_dtype if residual_dtype is not None else x.dtype)
    else:
        residual_out = None
    y1, mean, rstd, seeds, dropout_mask, dropout_mask1 = _layer_norm_fwd_impl(
        x,
        weight,
        bias,
        eps,
        out,
        residual=residual,
        x1=x1,
        weight1=weight1,
        bias1=bias1,
        dropout_p=dropout_p,
        rowscale=rowscale,
        zero_centered_weight=zero_centered_weight,
        is_rms_norm=is_rms_norm,
        return_dropout_mask=return_dropout_mask,
        residual_out=residual_out,
    )
    # residual_out is None if residual is None and residual_dtype == input_dtype and dropout_p == 0.0
    if residual_out is None:
        residual_out = x
    return out, y1, mean, rstd, residual_out, seeds, dropout_mask, dropout_mask1
# [2025-04-28] torch.library.triton_op ignores the schema argument, but here we need the schema
# since we're returning a tuple of tensors
def _layer_norm_fwd_impl(
    x: Tensor,
    weight: Optional[Tensor],
    bias: Tensor,
    eps: float,
    out: Tensor,
    residual: Optional[Tensor] = None,
    x1: Optional[Tensor] = None,
    weight1: Optional[Tensor] = None,
    bias1: Optional[Tensor] = None,
    dropout_p: float = 0.0,
    rowscale: Optional[Tensor] = None,
    zero_centered_weight: bool = False,
    is_rms_norm: bool = False,
    return_dropout_mask: bool = False,
    residual_out: Optional[Tensor] = None,
) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor):
    M, N = x.shape
    assert x.stride(-1) == 1
    if residual is not None:
        assert residual.stride(-1) == 1
        assert residual.shape == (M, N)
    if weight is not None:
        assert weight.shape == (N,)
        assert weight.stride(-1) == 1
    if bias is not None:
        assert bias.stride(-1) == 1
        assert bias.shape == (N,)
    if x1 is not None:
        assert x1.shape == x.shape
        assert rowscale is None
        assert x1.stride(-1) == 1
    if weight1 is not None:
        assert weight1.shape == (N,)
        assert weight1.stride(-1) == 1
    if bias1 is not None:
        assert bias1.shape == (N,)
        assert bias1.stride(-1) == 1
    if rowscale is not None:
        assert rowscale.is_contiguous()
        assert rowscale.shape == (M,)
    assert out.shape == x.shape
    assert out.stride(-1) == 1
    if residual_out is not None:
        assert residual_out.shape == x.shape
        assert residual_out.stride(-1) == 1
    if weight1 is not None:
        y1 = torch.empty_like(out)
        assert y1.stride(-1) == 1
    else:
        y1 = None
    mean = torch.empty((M,), dtype=torch.float32, device=x.device) if not is_rms_norm else None
    rstd = torch.empty((M,), dtype=torch.float32, device=x.device)
    if dropout_p > 0.0:
        seeds = torch.randint(2**32, (M if x1 is None else 2 * M,), device=x.device, dtype=torch.int64)
    else:
        seeds = None
    if return_dropout_mask and dropout_p > 0.0:
        dropout_mask = torch.empty(M, N, device=x.device, dtype=torch.bool)
        if x1 is not None:
            dropout_mask1 = torch.empty(M, N, device=x.device, dtype=torch.bool)
        else:
            dropout_mask1 = None
    else:
        dropout_mask, dropout_mask1 = None, None
    # Less than 64KB per feature: enqueue fused kernel
    MAX_FUSED_SIZE = 65536 // x.element_size()
    BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
    if N > BLOCK_N:
        raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
    with torch.cuda.device(x.device.index):
        torch.library.wrap_triton(_layer_norm_fwd_1pass_kernel)[(M,)](
            x,
            out,
            weight if weight is not None else x,  # unused when HAS_WEIGHT == False
            bias,
            residual,
            x1,
            weight1,
            bias1,
            y1,
            residual_out,
            rowscale,
            seeds,
            dropout_mask,
            dropout_mask1,
            mean,
            rstd,
            x.stride(0),
            out.stride(0),
            residual.stride(0) if residual is not None else 0,
            residual_out.stride(0) if residual_out is not None else 0,
            x1.stride(0) if x1 is not None else 0,
            y1.stride(0) if y1 is not None else 0,
            M,
            N,
            eps,
            dropout_p,
            # Passing bool make torch inductor very unhappy since it then tries to compare to int_max
            int(zero_centered_weight),
            is_rms_norm,
            BLOCK_N,
            residual is not None,
            residual_out is not None,
            weight is not None,
            bias is not None,
            dropout_p > 0.0,
            dropout_mask is not None,
            rowscale is not None,
            HAS_X1=x1 is not None,
            HAS_W1=weight1 is not None,
            HAS_B1=bias1 is not None,
        )
    return y1, mean, rstd, seeds, dropout_mask, dropout_mask1
class LayerNormFn:
    @staticmethod
    def forward(
        x,
        weight,
        bias,
        residual=None,
        x1=None,
        weight1=None,
        bias1=None,
        eps=1e-6,
        dropout_p=0.0,
        rowscale=None,
        prenorm=False,
        residual_in_fp32=False,
        zero_centered_weight=False,
        is_rms_norm=False,
        return_dropout_mask=False,
        out_dtype=None,
        out=None,
        residual_out=None,
    ):
        x_shape_og = x.shape
        # reshape input data into 2D tensor
        x = maybe_contiguous_lastdim(x.reshape(-1, x.shape[-1]))
        if residual is not None:
            assert residual.shape == x_shape_og
            residual = maybe_contiguous_lastdim(residual.reshape(-1, residual.shape[-1]))
        if x1 is not None:
            assert x1.shape == x_shape_og
            assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
            x1 = maybe_contiguous_lastdim(x1.reshape(-1, x1.shape[-1]))
        # weight can be None when elementwise_affine=False for LayerNorm
        if weight is not None:
            weight = weight.contiguous()
        bias = maybe_contiguous(bias)
        weight1 = maybe_contiguous(weight1)
        bias1 = maybe_contiguous(bias1)
        if rowscale is not None:
            rowscale = rowscale.reshape(-1).contiguous()
        residual_dtype = residual.dtype if residual is not None else (torch.float32 if residual_in_fp32 else None)
        if out is not None:
            out = out.reshape(-1, out.shape[-1])
        if residual_out is not None:
            residual_out = residual_out.reshape(-1, residual_out.shape[-1])
        y, y1, mean, rstd, residual_out, seeds, dropout_mask, dropout_mask1 = _layer_norm_fwd(
            x,
            weight,
            bias,
            eps,
            residual,
            x1,
            weight1,
            bias1,
            dropout_p=dropout_p,
            rowscale=rowscale,
            out_dtype=out_dtype,
            residual_dtype=residual_dtype,
            zero_centered_weight=zero_centered_weight,
            is_rms_norm=is_rms_norm,
            return_dropout_mask=return_dropout_mask,
            out=out,
            residual_out=residual_out,
        )
        y = y.reshape(x_shape_og)
        return y
def layer_norm_fn(
    x,
    weight,
    bias,
    residual=None,
    x1=None,
    weight1=None,
    bias1=None,
    eps=1e-6,
    dropout_p=0.0,
    rowscale=None,
    prenorm=False,
    residual_in_fp32=False,
    zero_centered_weight=False,
    is_rms_norm=False,
    return_dropout_mask=False,
    out_dtype=None,
    out=None,
    residual_out=None,
):
    return LayerNormFn.forward(
        x,
        weight,
        bias,
        residual,
        x1,
        weight1,
        bias1,
        eps,
        dropout_p,
        rowscale,
        prenorm,
        residual_in_fp32,
        zero_centered_weight,
        is_rms_norm,
        return_dropout_mask,
        out_dtype,
        out,
        residual_out,
    )
@triton.jit
def _norm_infer_kernel(
    X,
    Y,
    W,
    B,
    stride_x_row,
    stride_y_row,
    M,
    N,
    eps,
    IS_RMS_NORM: tl.constexpr,
    HAS_WEIGHT: tl.constexpr,
    HAS_BIAS: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    row = tl.program_id(0)
    X += row * stride_x_row
    Y += row * stride_y_row
    if HAS_WEIGHT:
        W += 0
    if HAS_BIAS:
        B += 0
    cols = tl.arange(0, BLOCK_N)
    x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
    if not IS_RMS_NORM:
        mean = tl.sum(x, axis=0) / N
        xbar = tl.where(cols < N, x - mean, 0.0)
        var = tl.sum(xbar * xbar, axis=0) / N
    else:
        xbar = tl.where(cols < N, x, 0.0)
        var = tl.sum(xbar * xbar, axis=0) / N
    rstd = 1 / tl.sqrt(var + eps)
    x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
    if HAS_WEIGHT:
        w = tl.load(W + cols, mask=cols < N, other=1.0).to(tl.float32)
        y = x_hat * w
    else:
        y = x_hat
    if HAS_BIAS:
        b = tl.load(B + cols, mask=cols < N, other=0.0).to(tl.float32)
        y += b
    tl.store(Y + cols, y, mask=cols < N)
def norm_infer(
    x: Tensor,
    weight: Optional[Tensor],
    bias: Optional[Tensor],
    eps: float,
    is_rms_norm: bool = False,
    out: Optional[Tensor] = None,
):
    M, N = x.shape
    assert x.stride(-1) == 1
    if weight is not None:
        assert weight.shape == (N,)
        assert weight.stride(-1) == 1
    if bias is not None:
        assert bias.shape == (N,)
        assert bias.stride(-1) == 1
    if out is None:
        out = torch.empty_like(x)
    MAX_FUSED_SIZE = 65536 // x.element_size()
    BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
    if N > BLOCK_N:
        raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
    num_warps = min(max(BLOCK_N // 256, 1), 8)
    _norm_infer_kernel[(M,)](
        x,
        out,
        weight if weight is not None else x,  # dummy when HAS_WEIGHT=False
        bias if bias is not None else x,  # dummy when HAS_BIAS=False
        x.stride(0),
        out.stride(0),
        M,
        N,
        eps,
        IS_RMS_NORM=is_rms_norm,
        HAS_WEIGHT=weight is not None,
        HAS_BIAS=bias is not None,
        BLOCK_N=BLOCK_N,
        num_warps=num_warps,
    )
    return out
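Reviewer note: a quick sanity check one might run against the fp32 PyTorch reference; since the kernel accumulates in fp32 and stores in the input dtype, the results should track the reference within half-precision tolerance (the shapes and tolerances here are guesses, not tested values):

x = torch.randn(8, 1024, device="cuda", dtype=torch.float16)
w = torch.randn(1024, device="cuda", dtype=torch.float16)
b = torch.randn(1024, device="cuda", dtype=torch.float16)
y = norm_infer(x, w, b, eps=1e-6)
ref = torch.nn.functional.layer_norm(x.float(), (1024,), w.float(), b.float(), 1e-6).to(x.dtype)
torch.testing.assert_close(y, ref, rtol=1e-2, atol=1e-2)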
def rms_norm_fn(
    x,
    weight,
    bias,
    residual=None,
    x1=None,
    weight1=None,
    bias1=None,
    eps=1e-6,
    dropout_p=0.0,
    rowscale=None,
    prenorm=False,
    residual_in_fp32=False,
    zero_centered_weight=False,
    return_dropout_mask=False,
    out_dtype=None,
    out=None,
    residual_out=None,
):
    return LayerNormFn.forward(
        x,
        weight,
        bias,
        residual,
        x1,
        weight1,
        bias1,
        eps,
        dropout_p,
        rowscale,
        prenorm,
        residual_in_fp32,
        zero_centered_weight,
        True,
        return_dropout_mask,
        out_dtype,
        out,
        residual_out,
    )
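Reviewer note: `rms_norm_fn` is `layer_norm_fn` with `is_rms_norm` hard-wired to `True` (the bare `True` in the positional argument list above). Reusing `x` and `w` from the previous snippet, the two entry points should agree:

torch.testing.assert_close(
    rms_norm_fn(x, w, None),
    layer_norm_fn(x, w, None, is_rms_norm=True),
)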
lightx2v/common/ops/tensor/__init__.py (new file, mode 100644)

from .tensor import DefaultTensor
lightx2v/common/ops/tensor/tensor.py (new file, mode 100644)

import os
import re
from pathlib import Path

import torch
from safetensors import safe_open

from lightx2v.utils.envs import *
from lightx2v.utils.registry_factory import TENSOR_REGISTER
from lightx2v_platform.base.global_var import AI_DEVICE


@TENSOR_REGISTER("Default")
class DefaultTensor:
    def __init__(self, tensor_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
        self.tensor_name = tensor_name
        self.lazy_load = lazy_load
        self.lazy_load_file = lazy_load_file
        self.is_post_adapter = is_post_adapter
        self.create_cuda_buffer = create_cuda_buffer
        self.create_cpu_buffer = create_cpu_buffer
        self.infer_dtype = GET_DTYPE()
        self.sensitive_layer_dtype = GET_SENSITIVE_DTYPE()

    def load(self, weight_dict):
        if self.create_cuda_buffer:
            self._load_cuda_buffer(weight_dict)
        elif self.create_cpu_buffer:
            self._load_cpu_pin_buffer()
        else:
            self._load_default_tensors(weight_dict)

    def _load_default_tensors(self, weight_dict):
        if not self.lazy_load:
            device = weight_dict[self.tensor_name].device
            if device.type == "cpu":
                tensor = weight_dict[self.tensor_name]
                self.pin_tensor = self._create_cpu_pin_tensor(tensor)
                del weight_dict[self.tensor_name]
            else:
                self.tensor = weight_dict[self.tensor_name]

    def _get_tensor(self, weight_dict=None, use_infer_dtype=False):
        if self.lazy_load:
            if Path(self.lazy_load_file).is_file():
                lazy_load_file_path = self.lazy_load_file
            else:
                lazy_load_file_path = os.path.join(self.lazy_load_file, f"block_{self.tensor_name.split('.')[1]}.safetensors")
            with safe_open(lazy_load_file_path, framework="pt", device="cpu") as lazy_load_file:
                tensor = lazy_load_file.get_tensor(self.tensor_name)
                if use_infer_dtype:
                    tensor = tensor.to(self.infer_dtype)
        else:
            tensor = weight_dict[self.tensor_name]
        return tensor

    def _create_cpu_pin_tensor(self, tensor):
        pin_tensor = torch.empty(tensor.shape, pin_memory=True, dtype=tensor.dtype)
        pin_tensor.copy_(tensor)
        del tensor
        return pin_tensor

    def _load_cuda_buffer(self, weight_dict):
        tensor = self._get_tensor(weight_dict, use_infer_dtype=self.lazy_load)
        self.tensor_cuda_buffer = tensor.to(AI_DEVICE)

    def _load_cpu_pin_buffer(self):
        tensor = self._get_tensor(use_infer_dtype=True)
        self.pin_tensor = self._create_cpu_pin_tensor(tensor)

    def to_cuda(self, non_blocking=False):
        self.tensor = self.pin_tensor.to(AI_DEVICE, non_blocking=non_blocking)

    def to_cpu(self, non_blocking=False):
        if hasattr(self, "pin_tensor"):
            self.tensor = self.pin_tensor.copy_(self.tensor, non_blocking=non_blocking).cpu()
        else:
            self.tensor = self.tensor.to("cpu", non_blocking=non_blocking)

    def state_dict(self, destination=None):
        if destination is None:
            destination = {}
        destination[self.tensor_name] = self.pin_tensor if hasattr(self, "pin_tensor") else self.tensor
        return destination

    def load_state_dict(self, destination, block_index, adapter_block_index=None):
        if self.is_post_adapter:
            assert adapter_block_index is not None
            tensor_name = re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", self.tensor_name, count=1)
        else:
            tensor_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.tensor_name, count=1)
        if tensor_name not in destination:
            self.tensor = None
            return
        self.tensor = self.tensor_cuda_buffer.copy_(destination[tensor_name], non_blocking=True)

    def load_state_dict_from_disk(self, block_index, adapter_block_index=None):
        if self.is_post_adapter:
            assert adapter_block_index is not None
            self.tensor_name = re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", self.tensor_name, count=1)
        else:
            self.tensor_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.tensor_name, count=1)
        if Path(self.lazy_load_file).is_file():
            lazy_load_file_path = self.lazy_load_file
        else:
            lazy_load_file_path = os.path.join(self.lazy_load_file, f"block_{block_index}.safetensors")
        with safe_open(lazy_load_file_path, framework="pt", device="cpu") as lazy_load_file:
            tensor = lazy_load_file.get_tensor(self.tensor_name).to(self.infer_dtype)
            self.pin_tensor = self.pin_tensor.copy_(tensor)
            del tensor
lightx2v/common/transformer_infer/transformer_infer.py (new file, mode 100644)

import math
from abc import ABC, abstractmethod


class BaseTransformerInfer(ABC):
    @abstractmethod
    def infer(self):
        pass

    def set_scheduler(self, scheduler):
        self.scheduler = scheduler
        self.scheduler.transformer_infer = self


class BaseTaylorCachingTransformerInfer(BaseTransformerInfer):
    @abstractmethod
    def infer_calculating(self):
        pass

    @abstractmethod
    def infer_using_cache(self):
        pass

    @abstractmethod
    def get_taylor_step_diff(self):
        pass

    # 1. when fully calculated, store the result in the cache
    def derivative_approximation(self, block_cache, module_name, out):
        if module_name not in block_cache:
            block_cache[module_name] = {0: out}
        else:
            step_diff = self.get_taylor_step_diff()
            previous_out = block_cache[module_name][0]
            block_cache[module_name][0] = out
            block_cache[module_name][1] = (out - previous_out) / step_diff

    def taylor_formula(self, tensor_dict):
        x = self.get_taylor_step_diff()
        output = 0
        for i in range(len(tensor_dict)):
            output += (1 / math.factorial(i)) * tensor_dict[i] * (x**i)
        return output
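Reviewer note: in other words, `derivative_approximation` keeps the latest output and a finite-difference estimate of its first derivative, and `taylor_formula` evaluates the Taylor expansion at the current step offset $x$:

\[
f(t + x) \approx \sum_{i=0}^{k} \frac{f^{(i)}(t)}{i!}\, x^{i},
\qquad
f^{(1)}(t) \approx \frac{f(t) - f(t - \Delta)}{\Delta},
\]

with $k = 1$ here, so cached block outputs are extrapolated linearly between fully computed steps.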
lightx2v/deploy/__init__.py (new file, mode 100644, empty)

lightx2v/deploy/common/__init__.py (new file, mode 100644, empty)
lightx2v/deploy/common/aliyun.py (new file, mode 100644)

import asyncio
import json
import os
import sys

from alibabacloud_dypnsapi20170525 import models as dypnsapi_models
from alibabacloud_dypnsapi20170525.client import Client
from alibabacloud_tea_openapi import models as openapi_models
from alibabacloud_tea_util import models as util_models
from loguru import logger


class AlibabaCloudClient:
    def __init__(self):
        config = openapi_models.Config(
            access_key_id=os.getenv("ALIBABA_CLOUD_ACCESS_KEY_ID"),
            access_key_secret=os.getenv("ALIBABA_CLOUD_ACCESS_KEY_SECRET"),
            https_proxy=os.getenv("auth_https_proxy", None),
        )
        self.client = Client(config)
        self.runtime = util_models.RuntimeOptions()

    def check_ok(self, res, prefix):
        logger.info(f"{prefix}: {res}")
        if not isinstance(res, dict) or "statusCode" not in res or res["statusCode"] != 200:
            logger.warning(f"{prefix}: error response: {res}")
            return False
        if "body" not in res or "Code" not in res["body"] or "Success" not in res["body"]:
            logger.warning(f"{prefix}: error body: {res}")
            return False
        if res["body"]["Code"] != "OK" or res["body"]["Success"] is not True:
            logger.warning(f"{prefix}: sms error: {res}")
            return False
        return True

    async def send_sms(self, phone_number):
        try:
            req = dypnsapi_models.SendSmsVerifyCodeRequest(
                phone_number=phone_number,
                sign_name="速通互联验证服务",
                template_code="100001",
                template_param=json.dumps({"code": "##code##", "min": "5"}),
                valid_time=300,
            )
            res = await self.client.send_sms_verify_code_with_options_async(req, self.runtime)
            ok = self.check_ok(res.to_map(), "AlibabaCloudClient send sms")
            logger.info(f"AlibabaCloudClient send sms for {phone_number}: {ok}")
            return ok
        except Exception as e:
            logger.warning(f"AlibabaCloudClient send sms for {phone_number}: {e}")
            return False

    async def check_sms(self, phone_number, verify_code):
        try:
            req = dypnsapi_models.CheckSmsVerifyCodeRequest(
                phone_number=phone_number,
                verify_code=verify_code,
            )
            res = await self.client.check_sms_verify_code_with_options_async(req, self.runtime)
            ok = self.check_ok(res.to_map(), "AlibabaCloudClient check sms")
            logger.info(f"AlibabaCloudClient check sms for {phone_number} with {verify_code}: {ok}")
            return ok
        except Exception as e:
            logger.warning(f"AlibabaCloudClient check sms for {phone_number} with {verify_code}: {e}")
            return False


async def test(args):
    assert len(args) in [1, 2], "Usage: python aliyun_sms.py <phone_number> [verify_code]"
    phone_number = args[0]
    client = AlibabaCloudClient()
    if len(args) == 1:
        await client.send_sms(phone_number)
    else:
        await client.check_sms(phone_number, args[1])


if __name__ == "__main__":
    asyncio.run(test(sys.argv[1:]))
lightx2v/deploy/common/audio_separator.py
0 → 100644
View file @
a1ebc651
# -*- coding: utf-8 -*-
"""
Audio Source Separation Module
Separates different voice tracks in audio, supports multi-person audio separation
"""

import base64
import io
import os
import tempfile
import traceback
from collections import defaultdict
from typing import Dict, Optional, Union

import torch
import torchaudio
from loguru import logger

# Import pyannote.audio for speaker diarization
from pyannote.audio import Audio, Pipeline

_origin_torch_load = torch.load


def our_torch_load(checkpoint_file, *args, **kwargs):
    kwargs["weights_only"] = False
    return _origin_torch_load(checkpoint_file, *args, **kwargs)


class AudioSeparator:
    """
    Audio separator for separating different voice tracks in audio using pyannote.audio
    Supports multi-person conversation separation, maintains duration (other speakers' tracks are empty)
    """

    def __init__(
        self,
        model_path: str = None,
        device: str = None,
        sample_rate: int = 16000,
    ):
        """
        Initialize audio separator

        Args:
            model_path: Model path (if using custom model), default uses pyannote/speaker-diarization-community-1
            device: Device ('cpu', 'cuda', etc.), None for auto selection
            sample_rate: Target sample rate, default 16000
        """
        self.sample_rate = sample_rate
        self.device = device if device else ("cuda" if torch.cuda.is_available() else "cpu")
        self._init_pyannote(model_path)

    def _init_pyannote(self, model_path: str = None):
        """Initialize pyannote.audio pipeline"""
        try:
            huggingface_token = os.getenv("HUGGINGFACE_TOKEN") or os.getenv("HF_TOKEN")
            model_name = model_path or "pyannote/speaker-diarization-community-1"
            try:
                torch.load = our_torch_load
                # Try loading with token if available
                if huggingface_token:
                    self.pipeline = Pipeline.from_pretrained(model_name, token=huggingface_token)
                else:
                    # Try without token (may work for public models)
                    self.pipeline = Pipeline.from_pretrained(model_name)
            except Exception as e:
                if "gated" in str(e).lower() or "token" in str(e).lower():
                    raise RuntimeError(f"Model requires authentication. Set HUGGINGFACE_TOKEN or HF_TOKEN environment variable: {e}")
                raise RuntimeError(f"Failed to load pyannote model: {e}")
            finally:
                torch.load = _origin_torch_load

            # Move pipeline to specified device
            if self.device:
                self.pipeline.to(torch.device(self.device))

            # Initialize Audio helper for waveform loading
            self.pyannote_audio = Audio()
            logger.info("Initialized pyannote.audio speaker diarization pipeline")
        except Exception as e:
            logger.error(f"Failed to initialize pyannote: {e}")
            raise RuntimeError(f"Failed to initialize pyannote.audio pipeline: {e}")

    def separate_speakers(
        self,
        audio_path: Union[str, bytes],
        num_speakers: Optional[int] = None,
        min_speakers: int = 1,
        max_speakers: int = 5,
    ) -> Dict:
        """
        Separate different speakers in audio

        Args:
            audio_path: Audio file path or bytes data
            num_speakers: Specified number of speakers, None for auto detection
            min_speakers: Minimum number of speakers
            max_speakers: Maximum number of speakers

        Returns:
            Dict containing:
            - speakers: List of speaker audio segments, each containing:
                - speaker_id: Speaker ID (0, 1, 2, ...)
                - audio: torch.Tensor audio data [channels, samples]
                - segments: List of (start_time, end_time) tuples
                - sample_rate: Sample rate
        """
        try:
            # Load audio
            if isinstance(audio_path, bytes):
                # Try to infer the audio format from the byte data
                # Check for WAV format (RIFF header)
                is_wav = audio_path[:4] == b"RIFF" and audio_path[8:12] == b"WAVE"
                # Check for MP3 format (ID3 or MPEG header)
                is_mp3 = audio_path[:3] == b"ID3" or audio_path[:2] == b"\xff\xfb" or audio_path[:2] == b"\xff\xf3"
                # Choose the file suffix based on the detected format
                if is_wav:
                    suffix = ".wav"
                elif is_mp3:
                    suffix = ".mp3"
                else:
                    # Default to WAV; loading will raise an error if the guess is wrong
                    suffix = ".wav"
                with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp_file:
                    tmp_file.write(audio_path)
                    tmp_audio_path = tmp_file.name
                try:
                    result = self._separate_speakers_internal(tmp_audio_path, num_speakers, min_speakers, max_speakers)
                finally:
                    # Make sure the temp file is deleted
                    try:
                        os.unlink(tmp_audio_path)
                    except Exception as e:
                        logger.warning(f"Failed to delete temp file {tmp_audio_path}: {e}")
                return result
            else:
                return self._separate_speakers_internal(audio_path, num_speakers, min_speakers, max_speakers)
        except Exception as e:
            logger.error(f"Speaker separation failed: {traceback.format_exc()}")
            raise RuntimeError(f"Audio separation error: {e}")

    def _separate_speakers_internal(
        self,
        audio_path: str,
        num_speakers: Optional[int] = None,
        min_speakers: int = 1,
        max_speakers: int = 5,
    ) -> Dict:
        """Internal method: execute speaker separation"""
        # Load audio
        waveform, original_sr = torchaudio.load(audio_path)
        if original_sr != self.sample_rate:
            resampler = torchaudio.transforms.Resample(original_sr, self.sample_rate)
            waveform = resampler(waveform)
        # Convert to mono if stereo
        if waveform.shape[0] > 1:
            waveform = waveform.mean(dim=0, keepdim=True)
        # Ensure waveform is float32 and normalized (pyannote expects this format)
        if waveform.dtype != torch.float32:
            waveform = waveform.float()
        # Ensure waveform is in range [-1, 1] (normalize if needed)
        if waveform.abs().max() > 1.0:
            waveform = waveform / waveform.abs().max()
        if self.pipeline is None:
            raise RuntimeError("Pyannote pipeline not initialized")
        return self._separate_with_pyannote(audio_path, waveform, num_speakers, min_speakers, max_speakers)

    def _separate_with_pyannote(
        self,
        audio_path: str,
        waveform: torch.Tensor,
        num_speakers: Optional[int],
        min_speakers: int,
        max_speakers: int,
    ) -> Dict:
        """Use pyannote.audio for speaker diarization"""
        try:
            # Use waveform dict to avoid AudioDecoder dependency issues
            # Pipeline can accept either file path or waveform dict
            # Using waveform dict is more reliable when torchcodec is not properly installed
            audio_input = {
                "waveform": waveform,
                "sample_rate": self.sample_rate,
            }
            # Run speaker diarization
            output = self.pipeline(
                audio_input,
                min_speakers=min_speakers if num_speakers is None else num_speakers,
                max_speakers=max_speakers if num_speakers is None else num_speakers,
            )
            # Extract audio segments for each speaker
            speakers_dict = defaultdict(list)
            for turn, speaker in output.speaker_diarization:
                print(f"Speaker: {speaker}, Start time: {turn.start}, End time: {turn.end}")
                start_time = turn.start
                end_time = turn.end
                start_sample = int(start_time * self.sample_rate)
                end_sample = int(end_time * self.sample_rate)
                # Extract audio segment for this time period
                segment_audio = waveform[:, start_sample:end_sample]
                speakers_dict[speaker].append((start_time, end_time, segment_audio))
            # Generate complete audio for each speaker (other speakers' segments are empty)
            speakers = []
            audio_duration = waveform.shape[1] / self.sample_rate
            num_samples = waveform.shape[1]
            for speaker_id, segments in speakers_dict.items():
                # Create zero-filled audio
                speaker_audio = torch.zeros_like(waveform)
                # Fill in this speaker's segments
                for start_time, end_time, segment_audio in segments:
                    start_sample = int(start_time * self.sample_rate)
                    end_sample = int(end_time * self.sample_rate)
                    # Ensure no out-of-bounds
                    end_sample = min(end_sample, num_samples)
                    segment_len = end_sample - start_sample
                    if segment_len > 0 and segment_audio.shape[1] > 0:
                        actual_len = min(segment_len, segment_audio.shape[1])
                        speaker_audio[:, start_sample:start_sample + actual_len] = segment_audio[:, :actual_len]
                speakers.append(
                    {
                        "speaker_id": speaker_id,
                        "audio": speaker_audio,
                        "segments": [(s[0], s[1]) for s in segments],
                        "sample_rate": self.sample_rate,
                    }
                )
            logger.info(f"Separated audio into {len(speakers)} speakers using pyannote")
            return {"speakers": speakers, "method": "pyannote"}
        except Exception as e:
            logger.error(f"Pyannote separation failed: {e}")
            raise RuntimeError(f"Audio separation failed: {e}")

    def save_speaker_audio(self, speaker_audio: torch.Tensor, output_path: str, sample_rate: int = None):
        """
        Save speaker audio to file

        Args:
            speaker_audio: Audio tensor [channels, samples]
            output_path: Output path
            sample_rate: Sample rate, if None uses self.sample_rate
        """
        sr = sample_rate if sample_rate else self.sample_rate
        torchaudio.save(output_path, speaker_audio, sr)
        logger.info(f"Saved speaker audio to {output_path}")

    def speaker_audio_to_base64(self, speaker_audio: torch.Tensor, sample_rate: int = None, format: str = "wav") -> str:
        """
        Convert speaker audio tensor to base64 encoded string without saving to file

        Args:
            speaker_audio: Audio tensor [channels, samples]
            sample_rate: Sample rate, if None uses self.sample_rate
            format: Audio format (default: "wav")

        Returns:
            Base64 encoded audio string
        """
        sr = sample_rate if sample_rate else self.sample_rate
        # Use BytesIO to save audio to memory
        buffer = io.BytesIO()
        torchaudio.save(buffer, speaker_audio, sr, format=format)
        # Get the audio bytes
        audio_bytes = buffer.getvalue()
        # Encode to base64
        audio_base64 = base64.b64encode(audio_bytes).decode("utf-8")
        logger.debug(f"Converted speaker audio to base64, size: {len(audio_bytes)} bytes")
        return audio_base64

    def separate_and_save(
        self,
        audio_path: Union[str, bytes],
        output_dir: str,
        num_speakers: Optional[int] = None,
        min_speakers: int = 1,
        max_speakers: int = 5,
    ) -> Dict:
        """
        Separate audio and save to files

        Args:
            audio_path: Input audio path or bytes data
            output_dir: Output directory
            num_speakers: Specified number of speakers
            min_speakers: Minimum number of speakers
            max_speakers: Maximum number of speakers

        Returns:
            Separation result dictionary, containing output file paths
        """
        os.makedirs(output_dir, exist_ok=True)
        result = self.separate_speakers(audio_path, num_speakers, min_speakers, max_speakers)
        output_paths = []
        for speaker in result["speakers"]:
            speaker_id = speaker["speaker_id"]
            output_path = os.path.join(output_dir, f"{speaker_id}.wav")
            self.save_speaker_audio(speaker["audio"], output_path, speaker["sample_rate"])
            output_paths.append(output_path)
            speaker["output_path"] = output_path
        result["output_paths"] = output_paths
        return result


def separate_audio_tracks(
    audio_path: str,
    output_dir: str = None,
    num_speakers: int = None,
    model_path: str = None,
) -> Dict:
    """
    Convenience function: separate different audio tracks

    Args:
        audio_path: Audio file path
        output_dir: Output directory, if None does not save files
        num_speakers: Number of speakers
        model_path: Model path (optional)

    Returns:
        Separation result dictionary
    """
    separator = AudioSeparator(model_path=model_path)
    if output_dir:
        return separator.separate_and_save(audio_path, output_dir, num_speakers=num_speakers)
    else:
        return separator.separate_speakers(audio_path, num_speakers=num_speakers)


if __name__ == "__main__":
    # Test code
    import sys

    if len(sys.argv) < 2:
        print("Usage: python audio_separator.py <audio_path> [output_dir] [num_speakers]")
        sys.exit(1)

    audio_path = sys.argv[1]
    output_dir = sys.argv[2] if len(sys.argv) > 2 else "./separated_audio"
    num_speakers = int(sys.argv[3]) if len(sys.argv) > 3 else None

    separator = AudioSeparator()
    result = separator.separate_and_save(audio_path, output_dir, num_speakers=num_speakers)
    print(f"Separated audio into {len(result['speakers'])} speakers:")
    for speaker in result["speakers"]:
        print(f"  Speaker {speaker['speaker_id']}: {len(speaker['segments'])} segments")
        if "output_path" in speaker:
            print(f"    Saved to: {speaker['output_path']}")
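For callers that need the separated tracks without touching disk, separate_speakers plus speaker_audio_to_base64 compose into a fully in-memory flow. A minimal sketch (the input file name is illustrative):

    with open("meeting.wav", "rb") as f:
        audio_bytes = f.read()
    separator = AudioSeparator()
    result = separator.separate_speakers(audio_bytes, max_speakers=3)
    for speaker in result["speakers"]:
        # Each track keeps the full duration; other speakers' spans are silence.
        b64 = separator.speaker_audio_to_base64(speaker["audio"], speaker["sample_rate"])
        print(speaker["speaker_id"], speaker["segments"][:2], len(b64))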
lightx2v/deploy/common/face_detector.py
0 → 100644
# -*- coding: utf-8 -*-
"""
Face Detection Module using YOLO
Supports detecting faces in images, including human faces, animal faces, anime faces, sketches, etc.
"""

import io
import traceback
from typing import Dict, List, Union

import numpy as np
from PIL import Image, ImageDraw
from loguru import logger
from ultralytics import YOLO


class FaceDetector:
    """
    Face detection using YOLO models
    Supports detecting: human faces, animal faces, anime faces, sketch faces, etc.
    """

    def __init__(self, model_path: str = None, conf_threshold: float = 0.25, device: str = None):
        """
        Initialize face detector

        Args:
            model_path: YOLO model path, if None uses default pretrained model
            conf_threshold: Confidence threshold, default 0.25
            device: Device ('cpu', 'cuda', '0', '1', etc.), None for auto selection
        """
        self.conf_threshold = conf_threshold
        self.device = device
        if model_path is None:
            # Use YOLO11 pretrained model, which can detect COCO dataset classes (including person),
            # or use a dedicated face detection model
            logger.info("Loading default YOLO11n model for face detection")
            try:
                self.model = YOLO("yolo11n.pt")  # Lightweight model
            except Exception as e:
                logger.warning(f"Failed to load default model, trying yolov8n: {e}")
                self.model = YOLO("yolov8n.pt")
        else:
            logger.info(f"Loading YOLO model from {model_path}")
            self.model = YOLO(model_path)
        # Person class ID in COCO dataset is 0.
        # YOLO can detect person; for more precise face detection, a dedicated face detection model
        # such as YOLOv8-face or RetinaFace is recommended and can be specified via model_path.
        # First use YOLO to detect the person region, then faces can be detected within it.
        self.target_classes = {
            "person": 0,  # Face (by detecting person class)
            # Can be extended to detect animal faces (cat, dog, etc.) and other classes
        }

    def detect_faces(
        self,
        image: Union[str, Image.Image, bytes, np.ndarray],
        return_image: bool = False,
    ) -> Dict:
        """
        Detect faces in image

        Args:
            image: Input image, can be path, PIL Image, bytes or numpy array
            return_image: Whether to return annotated image with detection boxes

        Returns:
            Dict containing:
            - faces: List of face detection results, each containing:
                - bbox: [x1, y1, x2, y2] bounding box coordinates (absolute pixel coordinates)
                - confidence: Confidence score (0.0-1.0)
                - class_id: Class ID
                - class_name: Class name
            - image (optional): PIL Image with detection boxes drawn (if return_image=True)
        """
        try:
            # Load image
            if isinstance(image, str):
                img = Image.open(image).convert("RGB")
            elif isinstance(image, bytes):
                img = Image.open(io.BytesIO(image)).convert("RGB")
            elif isinstance(image, np.ndarray):
                img = Image.fromarray(image).convert("RGB")
            elif isinstance(image, Image.Image):
                img = image.convert("RGB")
            else:
                raise ValueError(f"Unsupported image type: {type(image)}")

            # Use YOLO for detection
            # Note: YOLO by default detects person, we focus on person detection
            # For more precise face detection, can train or use dedicated face detection models
            results = self.model.predict(
                source=img,
                conf=self.conf_threshold,
                device=self.device,
                verbose=False,
            )

            faces = []
            annotated_img = img.copy() if return_image else None

            if len(results) > 0:
                result = results[0]
                boxes = result.boxes
                if boxes is not None and len(boxes) > 0:
                    for i in range(len(boxes)):
                        # Get bounding box coordinates (xyxy format)
                        bbox = boxes.xyxy[i].cpu().numpy().tolist()
                        confidence = float(boxes.conf[i].cpu().numpy())
                        class_id = int(boxes.cls[i].cpu().numpy())
                        # Get class name
                        class_name = result.names.get(class_id, "unknown")
                        # Process target classes (person, etc.)
                        # For person, the entire body box contains the face region.
                        # For more precise face detection, one can:
                        # 1. Use dedicated face detection models (RetinaFace, YOLOv8-face)
                        # 2. Further run a face detector within the current person box
                        # 3. Use specifically trained multi-class detection models (faces, animal faces, anime faces, etc.)
                        if class_id in self.target_classes.values():
                            face_info = {
                                "bbox": bbox,  # [x1, y1, x2, y2] - absolute pixel coordinates
                                "confidence": confidence,
                                "class_id": class_id,
                                "class_name": class_name,
                            }
                            faces.append(face_info)
                            # Draw annotations on image if needed
                            if return_image and annotated_img is not None:
                                draw = ImageDraw.Draw(annotated_img)
                                x1, y1, x2, y2 = bbox
                                # Draw bounding box
                                draw.rectangle(
                                    [x1, y1, x2, y2],
                                    outline="red",
                                    width=2,
                                )
                                # Draw label
                                label = f"{class_name} {confidence:.2f}"
                                draw.text((x1, y1 - 15), label, fill="red")

            result_dict = {"faces": faces}
            if return_image and annotated_img is not None:
                result_dict["image"] = annotated_img
            logger.info(f"Detected {len(faces)} faces in image")
            return result_dict
        except Exception as e:
            logger.error(f"Face detection failed: {traceback.format_exc()}")
            raise RuntimeError(f"Face detection error: {e}")

    def detect_faces_from_bytes(self, image_bytes: bytes, **kwargs) -> Dict:
        """
        Detect faces from byte data

        Args:
            image_bytes: Image byte data
            **kwargs: Additional parameters passed to detect_faces

        Returns:
            Detection result dictionary
        """
        return self.detect_faces(image_bytes, **kwargs)

    def extract_face_regions(self, image: Union[str, Image.Image, bytes], expand_ratio: float = 0.1) -> List[Image.Image]:
        """
        Extract detected face regions

        Args:
            image: Input image
            expand_ratio: Bounding box expansion ratio to include more context

        Returns:
            List of extracted face region images
        """
        result = self.detect_faces(image)
        faces = result["faces"]
        # Load original image
        if isinstance(image, str):
            img = Image.open(image).convert("RGB")
        elif isinstance(image, bytes):
            img = Image.open(io.BytesIO(image)).convert("RGB")
        elif isinstance(image, Image.Image):
            img = image.convert("RGB")
        else:
            raise ValueError(f"Unsupported image type: {type(image)}")

        face_regions = []
        img_width, img_height = img.size
        for face in faces:
            x1, y1, x2, y2 = face["bbox"]
            # Expand bounding box
            width = x2 - x1
            height = y2 - y1
            expand_x = width * expand_ratio
            expand_y = height * expand_ratio
            x1 = max(0, int(x1 - expand_x))
            y1 = max(0, int(y1 - expand_y))
            x2 = min(img_width, int(x2 + expand_x))
            y2 = min(img_height, int(y2 + expand_y))
            # Crop region
            face_region = img.crop((x1, y1, x2, y2))
            face_regions.append(face_region)
        return face_regions

    def count_faces(self, image: Union[str, Image.Image, bytes]) -> int:
        """
        Count number of faces in image

        Args:
            image: Input image

        Returns:
            Number of detected faces
        """
        result = self.detect_faces(image, return_image=False)
        return len(result["faces"])


def detect_faces_in_image(
    image_path: str,
    model_path: str = None,
    conf_threshold: float = 0.25,
    return_image: bool = False,
) -> Dict:
    """
    Convenience function: detect faces in image

    Args:
        image_path: Image path
        model_path: YOLO model path
        conf_threshold: Confidence threshold
        return_image: Whether to return annotated image

    Returns:
        Detection result dictionary containing:
        - faces: List of face detection results with bbox coordinates [x1, y1, x2, y2]
        - image (optional): Annotated image with detection boxes
    """
    detector = FaceDetector(model_path=model_path, conf_threshold=conf_threshold)
    return detector.detect_faces(image_path, return_image=return_image)


if __name__ == "__main__":
    # Test code
    import sys

    if len(sys.argv) < 2:
        print("Usage: python face_detector.py <image_path>")
        sys.exit(1)

    image_path = sys.argv[1]
    detector = FaceDetector()
    result = detector.detect_faces(image_path, return_image=True)
    print(f"Detected {len(result['faces'])} faces:")
    for i, face in enumerate(result["faces"]):
        print(f"  Face {i + 1}: {face}")
    output_path = "detected_faces.png"
    result["image"].save(output_path)
    print(f"Annotated image saved to: {output_path}")
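Besides whole-image annotation, extract_face_regions returns cropped patches directly, which is what downstream identity or animation models usually consume. A small sketch (file names illustrative):

    detector = FaceDetector(conf_threshold=0.4)
    regions = detector.extract_face_regions("group_photo.jpg", expand_ratio=0.2)
    for i, region in enumerate(regions):
        region.save(f"face_{i}.png")  # each crop is expanded 20% beyond the bbox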
lightx2v/deploy/common/pipeline.py
0 → 100644
import json
import sys

from loguru import logger


class Pipeline:
    def __init__(self, pipeline_json_file):
        self.pipeline_json_file = pipeline_json_file
        with open(pipeline_json_file) as f:
            x = json.load(f)
        self.data = x["data"]
        self.meta = x["meta"]
        self.inputs = {}
        self.outputs = {}
        self.temps = {}
        self.model_lists = []
        self.types = {}
        self.queues = set()
        self.model_name_inner_to_outer = self.meta.get("model_name_inner_to_outer", {})
        self.model_name_outer_to_inner = self.meta.get("model_name_outer_to_inner", {})
        self.tidy_pipeline()

    def init_dict(self, base, task, model_cls):
        if task not in base:
            base[task] = {}
        if model_cls not in base[task]:
            base[task][model_cls] = {}

    # tidy each task item eg, ['t2v', 'wan2.1', 'multi_stage']
    def tidy_task(self, task, model_cls, stage, v3):
        out2worker = {}
        out2num = {}
        cur_inps = set()
        cur_temps = set()
        cur_types = {}
        for worker_name, worker_item in v3.items():
            prevs = []
            for inp in worker_item["inputs"]:
                cur_types[inp] = self.get_type(inp)
                if inp in out2worker:
                    prevs.append(out2worker[inp])
                    out2num[inp] -= 1
                    if out2num[inp] <= 0:
                        cur_temps.add(inp)
                else:
                    cur_inps.add(inp)
            worker_item["previous"] = prevs
            for out in worker_item["outputs"]:
                cur_types[out] = self.get_type(out)
                out2worker[out] = worker_name
                if out not in out2num:
                    out2num[out] = 0
                out2num[out] += 1
            if "queue" not in worker_item:
                worker_item["queue"] = "-".join([task, model_cls, stage, worker_name])
            self.queues.add(worker_item["queue"])
        cur_outs = [out for out, num in out2num.items() if num > 0]
        self.inputs[task][model_cls][stage] = list(cur_inps)
        self.outputs[task][model_cls][stage] = cur_outs
        self.temps[task][model_cls][stage] = list(cur_temps)
        self.types[task][model_cls][stage] = cur_types

    # tidy previous dependence workers and queue name
    def tidy_pipeline(self):
        for task, v1 in self.data.items():
            for model_cls, v2 in v1.items():
                for stage, v3 in v2.items():
                    self.init_dict(self.inputs, task, model_cls)
                    self.init_dict(self.outputs, task, model_cls)
                    self.init_dict(self.temps, task, model_cls)
                    self.init_dict(self.types, task, model_cls)
                    self.tidy_task(task, model_cls, stage, v3)
                    self.model_lists.append({"task": task, "model_cls": model_cls, "stage": stage})
        logger.info(f"pipelines: {json.dumps(self.data, indent=4)}")
        logger.info(f"inputs: {self.inputs}")
        logger.info(f"outputs: {self.outputs}")
        logger.info(f"temps: {self.temps}")
        logger.info(f"types: {self.types}")
        logger.info(f"model_lists: {self.model_lists}")
        logger.info(f"queues: {self.queues}")

    def get_item_by_keys(self, keys):
        item = self.data
        for k in keys:
            if k not in item:
                raise Exception(f"{keys} are not in {self.pipeline_json_file}!")
            item = item[k]
        return item

    # eg. keys: ['t2v', 'wan2.1', 'multi_stage', 'text_encoder']
    def get_worker(self, keys):
        return self.get_item_by_keys(keys)

    # eg. keys: ['t2v', 'wan2.1', 'multi_stage']
    def get_workers(self, keys):
        return self.get_item_by_keys(keys)

    # eg. keys: ['t2v', 'wan2.1', 'multi_stage']
    def get_inputs(self, keys):
        item = self.inputs
        for k in keys:
            if k not in item:
                raise Exception(f"{keys} are not in inputs!")
            item = item[k]
        return item

    # eg. keys: ['t2v', 'wan2.1', 'multi_stage']
    def get_outputs(self, keys):
        item = self.outputs
        for k in keys:
            if k not in item:
                raise Exception(f"{keys} are not in outputs!")
            item = item[k]
        return item

    # eg. keys: ['t2v', 'wan2.1', 'multi_stage']
    def get_temps(self, keys):
        item = self.temps
        for k in keys:
            if k not in item:
                raise Exception(f"{keys} are not in temps!")
            item = item[k]
        return item

    # eg. keys: ['t2v', 'wan2.1', 'multi_stage']
    def get_types(self, keys):
        item = self.types
        for k in keys:
            if k not in item:
                raise Exception(f"{keys} are not in types!")
            item = item[k]
        return item

    def check_item_by_keys(self, keys):
        item = self.data
        for k in keys:
            if k not in item:
                return False
            item = item[k]
        return True

    def get_model_lists(self):
        return self.model_lists

    def get_type(self, name):
        return self.meta["special_types"].get(name, "OBJECT")

    def get_monitor_config(self):
        return self.meta["monitor"]

    def get_queues(self):
        return self.queues

    def inner_model_name(self, name):
        return self.model_name_outer_to_inner.get(name, name)

    def outer_model_name(self, name):
        return self.model_name_inner_to_outer.get(name, name)


if __name__ == "__main__":
    pipeline = Pipeline(sys.argv[1])
    print(pipeline.get_workers(["t2v", "wan2.1", "multi_stage"]))
    print(pipeline.get_worker(["i2v", "wan2.1", "multi_stage", "dit"]))
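To make tidy_task concrete: a name consumed by some worker but produced by none becomes a stage input, a produced-and-fully-consumed name becomes a temp, and produced names with leftover consumers become outputs. A minimal illustrative pipeline JSON in the shape this class parses (worker and tensor names are made up, not taken from the repo's actual pipeline files):

    {
        "meta": {"special_types": {"output_video": "VIDEO"}, "monitor": {}},
        "data": {
            "t2v": {
                "wan2.1": {
                    "multi_stage": {
                        "text_encoder": {"inputs": ["prompt"], "outputs": ["text_emb"]},
                        "dit": {"inputs": ["text_emb"], "outputs": ["latents"]},
                        "vae_decoder": {"inputs": ["latents"], "outputs": ["output_video"]}
                    }
                }
            }
        }
    }

Run through tidy_task, inputs becomes ['prompt'], temps ['text_emb', 'latents'], outputs ['output_video'], and each worker's queue defaults to e.g. 't2v-wan2.1-multi_stage-dit'.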
lightx2v/deploy/common/podcasts.py
0 → 100644
# -*- coding: utf-8 -*-
import asyncio
import io
import json
import os
import struct
import uuid
from dataclasses import dataclass
from enum import IntEnum
from typing import Callable, List, Optional

import websockets
from loguru import logger
from pydub import AudioSegment


# Protocol definitions (from podcasts_protocols)
class MsgType(IntEnum):
    """Message type enumeration"""

    Invalid = 0
    FullClientRequest = 0b1
    AudioOnlyClient = 0b10
    FullServerResponse = 0b1001
    AudioOnlyServer = 0b1011
    FrontEndResultServer = 0b1100
    Error = 0b1111

    ServerACK = AudioOnlyServer


class MsgTypeFlagBits(IntEnum):
    """Message type flag bits"""

    NoSeq = 0
    PositiveSeq = 0b1
    LastNoSeq = 0b10
    NegativeSeq = 0b11
    WithEvent = 0b100


class VersionBits(IntEnum):
    """Version bits"""

    Version1 = 1


class HeaderSizeBits(IntEnum):
    """Header size bits"""

    HeaderSize4 = 1
    HeaderSize8 = 2
    HeaderSize12 = 3
    HeaderSize16 = 4


class SerializationBits(IntEnum):
    """Serialization method bits"""

    Raw = 0
    JSON = 0b1
    Thrift = 0b11
    Custom = 0b1111


class CompressionBits(IntEnum):
    """Compression method bits"""

    None_ = 0
    Gzip = 0b1
    Custom = 0b1111


class EventType(IntEnum):
    """Event type enumeration"""

    None_ = 0
    StartConnection = 1
    StartTask = 1
    FinishConnection = 2
    FinishTask = 2
    ConnectionStarted = 50
    TaskStarted = 50
    ConnectionFailed = 51
    TaskFailed = 51
    ConnectionFinished = 52
    TaskFinished = 52
    StartSession = 100
    CancelSession = 101
    FinishSession = 102
    SessionStarted = 150
    SessionCanceled = 151
    SessionFinished = 152
    SessionFailed = 153
    UsageResponse = 154
    ChargeData = 154
    TaskRequest = 200
    UpdateConfig = 201
    AudioMuted = 250
    SayHello = 300
    TTSSentenceStart = 350
    TTSSentenceEnd = 351
    TTSResponse = 352
    TTSEnded = 359
    PodcastRoundStart = 360
    PodcastRoundResponse = 361
    PodcastRoundEnd = 362
    PodcastEnd = 363


@dataclass
class Message:
    """Message object"""

    version: VersionBits = VersionBits.Version1
    header_size: HeaderSizeBits = HeaderSizeBits.HeaderSize4
    type: MsgType = MsgType.Invalid
    flag: MsgTypeFlagBits = MsgTypeFlagBits.NoSeq
    serialization: SerializationBits = SerializationBits.JSON
    compression: CompressionBits = CompressionBits.None_
    event: EventType = EventType.None_
    session_id: str = ""
    connect_id: str = ""
    sequence: int = 0
    error_code: int = 0
    payload: bytes = b""

    @classmethod
    def from_bytes(cls, data: bytes) -> "Message":
        """Create message object from bytes"""
        if len(data) < 3:
            raise ValueError(f"Data too short: expected at least 3 bytes, got {len(data)}")
        type_and_flag = data[1]
        msg_type = MsgType(type_and_flag >> 4)
        flag = MsgTypeFlagBits(type_and_flag & 0b00001111)
        msg = cls(type=msg_type, flag=flag)
        msg.unmarshal(data)
        return msg

    def marshal(self) -> bytes:
        """Serialize message to bytes"""
        buffer = io.BytesIO()
        header = [
            (self.version << 4) | self.header_size,
            (self.type << 4) | self.flag,
            (self.serialization << 4) | self.compression,
        ]
        header_size = 4 * self.header_size
        if padding := header_size - len(header):
            header.extend([0] * padding)
        buffer.write(bytes(header))
        writers = self._get_writers()
        for writer in writers:
            writer(buffer)
        return buffer.getvalue()

    def unmarshal(self, data: bytes) -> None:
        """Deserialize message from bytes"""
        buffer = io.BytesIO(data)
        version_and_header_size = buffer.read(1)[0]
        self.version = VersionBits(version_and_header_size >> 4)
        self.header_size = HeaderSizeBits(version_and_header_size & 0b00001111)
        buffer.read(1)
        serialization_compression = buffer.read(1)[0]
        self.serialization = SerializationBits(serialization_compression >> 4)
        self.compression = CompressionBits(serialization_compression & 0b00001111)
        header_size = 4 * self.header_size
        read_size = 3
        if padding_size := header_size - read_size:
            buffer.read(padding_size)
        readers = self._get_readers()
        for reader in readers:
            reader(buffer)
        remaining = buffer.read()
        if remaining:
            raise ValueError(f"Unexpected data after message: {remaining}")

    def _get_writers(self) -> List[Callable[[io.BytesIO], None]]:
        """Get list of writer functions"""
        writers = []
        if self.flag == MsgTypeFlagBits.WithEvent:
            writers.extend([self._write_event, self._write_session_id])
        if self.type in [MsgType.FullClientRequest, MsgType.FullServerResponse, MsgType.FrontEndResultServer, MsgType.AudioOnlyClient, MsgType.AudioOnlyServer]:
            if self.flag in [MsgTypeFlagBits.PositiveSeq, MsgTypeFlagBits.NegativeSeq]:
                writers.append(self._write_sequence)
        elif self.type == MsgType.Error:
            writers.append(self._write_error_code)
        else:
            raise ValueError(f"Unsupported message type: {self.type}")
        writers.append(self._write_payload)
        return writers

    def _get_readers(self) -> List[Callable[[io.BytesIO], None]]:
        """Get list of reader functions"""
        readers = []
        if self.type in [MsgType.FullClientRequest, MsgType.FullServerResponse, MsgType.FrontEndResultServer, MsgType.AudioOnlyClient, MsgType.AudioOnlyServer]:
            if self.flag in [MsgTypeFlagBits.PositiveSeq, MsgTypeFlagBits.NegativeSeq]:
                readers.append(self._read_sequence)
        elif self.type == MsgType.Error:
            readers.append(self._read_error_code)
        if self.flag == MsgTypeFlagBits.WithEvent:
            readers.extend([self._read_event, self._read_session_id, self._read_connect_id])
        readers.append(self._read_payload)
        return readers

    def _write_event(self, buffer: io.BytesIO) -> None:
        buffer.write(struct.pack(">i", self.event))

    def _write_session_id(self, buffer: io.BytesIO) -> None:
        if self.event in [EventType.StartConnection, EventType.FinishConnection, EventType.ConnectionStarted, EventType.ConnectionFailed]:
            return
        session_id_bytes = self.session_id.encode("utf-8")
        size = len(session_id_bytes)
        if size > 0xFFFFFFFF:
            raise ValueError(f"Session ID size ({size}) exceeds max(uint32)")
        buffer.write(struct.pack(">I", size))
        if size > 0:
            buffer.write(session_id_bytes)

    def _write_sequence(self, buffer: io.BytesIO) -> None:
        buffer.write(struct.pack(">i", self.sequence))

    def _write_error_code(self, buffer: io.BytesIO) -> None:
        buffer.write(struct.pack(">I", self.error_code))

    def _write_payload(self, buffer: io.BytesIO) -> None:
        size = len(self.payload)
        if size > 0xFFFFFFFF:
            raise ValueError(f"Payload size ({size}) exceeds max(uint32)")
        buffer.write(struct.pack(">I", size))
        buffer.write(self.payload)

    def _read_event(self, buffer: io.BytesIO) -> None:
        event_bytes = buffer.read(4)
        if event_bytes:
            self.event = EventType(struct.unpack(">i", event_bytes)[0])

    def _read_session_id(self, buffer: io.BytesIO) -> None:
        if self.event in [EventType.StartConnection, EventType.FinishConnection, EventType.ConnectionStarted, EventType.ConnectionFailed, EventType.ConnectionFinished]:
            return
        size_bytes = buffer.read(4)
        if size_bytes:
            size = struct.unpack(">I", size_bytes)[0]
            if size > 0:
                session_id_bytes = buffer.read(size)
                if len(session_id_bytes) == size:
                    self.session_id = session_id_bytes.decode("utf-8")

    def _read_connect_id(self, buffer: io.BytesIO) -> None:
        if self.event in [EventType.ConnectionStarted, EventType.ConnectionFailed, EventType.ConnectionFinished]:
            size_bytes = buffer.read(4)
            if size_bytes:
                size = struct.unpack(">I", size_bytes)[0]
                if size > 0:
                    self.connect_id = buffer.read(size).decode("utf-8")

    def _read_sequence(self, buffer: io.BytesIO) -> None:
        sequence_bytes = buffer.read(4)
        if sequence_bytes:
            self.sequence = struct.unpack(">i", sequence_bytes)[0]

    def _read_error_code(self, buffer: io.BytesIO) -> None:
        error_code_bytes = buffer.read(4)
        if error_code_bytes:
            self.error_code = struct.unpack(">I", error_code_bytes)[0]

    def _read_payload(self, buffer: io.BytesIO) -> None:
        size_bytes = buffer.read(4)
        if size_bytes:
            size = struct.unpack(">I", size_bytes)[0]
            if size > 0:
                self.payload = buffer.read(size)


async def receive_message(websocket: websockets.WebSocketClientProtocol) -> Message:
    """Receive message from websocket"""
    try:
        data = await websocket.recv()
        if isinstance(data, str):
            raise ValueError(f"Unexpected text message: {data}")
        elif isinstance(data, bytes):
            msg = Message.from_bytes(data)
            # logger.debug(f"Received: {msg}")
            return msg
        else:
            raise ValueError(f"Unexpected message type: {type(data)}")
    except Exception as e:
        logger.error(f"Failed to receive message: {e}")
        raise


async def wait_for_event(websocket: websockets.WebSocketClientProtocol, msg_type: MsgType, event_type: EventType) -> Message:
    """Wait for specific event"""
    while True:
        msg = await receive_message(websocket)
        if msg.type != msg_type or msg.event != event_type:
            raise ValueError(f"Unexpected message: {msg}")
        if msg.type == msg_type and msg.event == event_type:
            return msg


async def start_connection(websocket: websockets.WebSocketClientProtocol) -> None:
    """Start connection"""
    msg = Message(type=MsgType.FullClientRequest, flag=MsgTypeFlagBits.WithEvent)
    msg.event = EventType.StartConnection
    msg.payload = b"{}"
    logger.debug(f"Sending: {msg}")
    await websocket.send(msg.marshal())


async def finish_connection(websocket: websockets.WebSocketClientProtocol) -> None:
    """Finish connection"""
    msg = Message(type=MsgType.FullClientRequest, flag=MsgTypeFlagBits.WithEvent)
    msg.event = EventType.FinishConnection
    msg.payload = b"{}"
    logger.debug(f"Sending: {msg}")
    await websocket.send(msg.marshal())


async def start_session(websocket: websockets.WebSocketClientProtocol, payload: bytes, session_id: str) -> None:
    """Start session"""
    msg = Message(type=MsgType.FullClientRequest, flag=MsgTypeFlagBits.WithEvent)
    msg.event = EventType.StartSession
    msg.session_id = session_id
    msg.payload = payload
    logger.debug(f"Sending: {msg}")
    await websocket.send(msg.marshal())


async def finish_session(websocket: websockets.WebSocketClientProtocol, session_id: str) -> None:
    """Finish session"""
    msg = Message(type=MsgType.FullClientRequest, flag=MsgTypeFlagBits.WithEvent)
    msg.event = EventType.FinishSession
    msg.session_id = session_id
    msg.payload = b"{}"
    logger.debug(f"Sending: {msg}")
    await websocket.send(msg.marshal())
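A quick round-trip to illustrate the framing above: the first three header bytes pack version/header-size, type/flag, and serialization/compression, the fourth is padding, and each variable-length field is prefixed with a big-endian uint32 size. (Illustrative snippet, not part of the module.)

    msg = Message(type=MsgType.FullClientRequest, flag=MsgTypeFlagBits.WithEvent)
    msg.event = EventType.StartSession
    msg.session_id = "demo-session"
    msg.payload = b"{}"
    data = msg.marshal()
    # byte 0: (Version1 << 4) | HeaderSize4, byte 1: (FullClientRequest << 4) | WithEvent
    assert data[0] == 0x11 and data[1] == 0x14
    parsed = Message.from_bytes(data)
    assert parsed.session_id == "demo-session" and parsed.payload == b"{}"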
class PodcastRoundPostProcessor:
    def __init__(self, session_id, data_manager):
        self.session_id = session_id
        self.data_manager = data_manager
        self.temp_merged_audio_name = "merged_audio.mp3"
        self.output_merged_audio_name = f"{session_id}-merged_audio.mp3"
        self.subtitle_timestamps = []  # Subtitle timestamps recorded so far
        self.current_audio_duration = 0.0  # Current merged audio duration
        self.merged_audio = None  # Holds the merged audio object
        self.merged_audio_bytes = None

    async def init(self):
        if self.data_manager:
            await self.data_manager.create_podcast_temp_session_dir(self.session_id)

    async def postprocess_round(self, current_round, voice, audio, podcast_texts):
        text = ""
        if podcast_texts:
            text = podcast_texts[-1].get("text", "")
        logger.debug(f"Processing round: {current_round}, voice: {voice}, text: {text}, audio: {len(audio)} bytes")
        new_segment = AudioSegment.from_mp3(io.BytesIO(bytes(audio)))
        round_duration = len(new_segment) / 1000.0
        if self.merged_audio is None:
            self.merged_audio = new_segment
        else:
            self.merged_audio = self.merged_audio + new_segment
        # Save the merged audio to a temp file (for real-time access by the frontend)
        merged_io = io.BytesIO()
        self.merged_audio.export(merged_io, format="mp3")
        self.merged_audio_bytes = merged_io.getvalue()
        if self.data_manager:
            await self.data_manager.save_podcast_temp_session_file(self.session_id, self.temp_merged_audio_name, self.merged_audio_bytes)
        merged_file_size = len(self.merged_audio_bytes)
        # Record subtitle timestamps
        self.subtitle_timestamps.append(
            {
                "start": self.current_audio_duration,
                "end": self.current_audio_duration + round_duration,
                "text": text,
                "speaker": voice,
            }
        )
        self.current_audio_duration += round_duration
        logger.debug(f"Merged audio updated: {merged_file_size} bytes, duration: {self.current_audio_duration:.2f}s")
        return {
            "url": f"/api/v1/podcast/audio?session_id={self.session_id}&filename={self.temp_merged_audio_name}",
            "size": merged_file_size,
            "duration": self.current_audio_duration,
            "round": current_round,
            "text": text,
            "speaker": voice,
        }

    async def postprocess_final(self):
        if self.data_manager:
            await self.data_manager.save_podcast_output_file(self.output_merged_audio_name, self.merged_audio_bytes)
        return {
            "subtitles": self.subtitle_timestamps,
            "audio_name": self.output_merged_audio_name,
        }

    async def cleanup(self):
        if self.data_manager:
            await self.data_manager.clear_podcast_temp_session_dir(self.session_id)
            self.data_manager = None


class VolcEnginePodcastClient:
    """
    VolcEngine podcast client

    Supports several podcast types:
    - action=0: text to podcast
    - action=3: NLP texts to podcast
    - action=4: prompt-generated podcast
    """

    def __init__(self):
        self.endpoint = "wss://openspeech.bytedance.com/api/v3/sami/podcasttts"
        self.appid = os.getenv("VOLCENGINE_PODCAST_APPID")
        self.access_token = os.getenv("VOLCENGINE_PODCAST_ACCESS_TOKEN")
        self.app_key = "aGjiRDfUWi"
        self.proxy = os.getenv("HTTPS_PROXY", None)
        if self.proxy:
            logger.info(f"volcengine podcast use proxy: {self.proxy}")

    async def podcast_request(
        self,
        session_id: str,
        data_manager=None,
        text: str = "",
        input_url: str = "",
        prompt_text: str = "",
        nlp_texts: str = "",
        action: int = 0,
        resource_id: str = "volc.service_type.10050",
        encoding: str = "mp3",
        input_id: str = "test_podcast",
        speaker_info: str = '{"random_order":false}',
        use_head_music: bool = False,
        use_tail_music: bool = False,
        only_nlp_text: bool = False,
        return_audio_url: bool = False,
        skip_round_audio_save: bool = False,
        on_round_complete: Optional[Callable] = None,
    ):
        """
        Execute a podcast request

        Args:
            text: Input text (used when action=0)
            input_url: Web URL or file URL (used when action=0)
            prompt_text: Prompt text (required when action=4)
            nlp_texts: NLP texts (required when action=3)
            action: Podcast type (0/3/4)
            resource_id: Audio resource ID
            encoding: Audio format (mp3/wav)
            input_id: Unique input identifier
            speaker_info: Podcast speaker info
            use_head_music: Whether to use opening music
            use_tail_music: Whether to use closing music
            only_nlp_text: Whether to return podcast text only
            return_audio_url: Whether to return an audio URL
            skip_round_audio_save: Whether to skip saving per-round audio
            on_round_complete: Callback invoked when a round completes
        """
        if not self.appid or not self.access_token:
            logger.error("APP ID or Access Key is required")
            return None, None
        headers = {
            "X-Api-App-Id": self.appid,
            "X-Api-App-Key": self.app_key,
            "X-Api-Access-Key": self.access_token,
            "X-Api-Resource-Id": resource_id,
            "X-Api-Connect-Id": str(uuid.uuid4()),
        }
        is_podcast_round_end = True
        audio_received = False
        last_round_id = -1
        task_id = ""
        websocket = None
        retry_num = 5
        audio = bytearray()
        voice = ""
        current_round = 0
        podcast_texts = []
        post_processor = PodcastRoundPostProcessor(session_id, data_manager)
        await post_processor.init()
        try:
            while retry_num > 0:
                # Establish the WebSocket connection
                websocket = await websockets.connect(self.endpoint, additional_headers=headers)
                logger.debug(f"WebSocket connected: {websocket.response.headers}")
                # Build request parameters
                if input_url:
                    req_params = {
                        "input_id": input_id,
                        "nlp_texts": json.loads(nlp_texts) if nlp_texts else None,
                        "prompt_text": prompt_text,
                        "action": action,
                        "use_head_music": use_head_music,
                        "use_tail_music": use_tail_music,
                        "input_info": {
                            "input_url": input_url,
                            "return_audio_url": return_audio_url,
                            "only_nlp_text": only_nlp_text,
                        },
                        "speaker_info": json.loads(speaker_info) if speaker_info else None,
                        "audio_config": {"format": encoding, "sample_rate": 24000, "speech_rate": 0},
                    }
                else:
                    req_params = {
                        "input_id": input_id,
                        "input_text": text,
                        "nlp_texts": json.loads(nlp_texts) if nlp_texts else None,
                        "prompt_text": prompt_text,
                        "action": action,
                        "use_head_music": use_head_music,
                        "use_tail_music": use_tail_music,
                        "input_info": {
                            "input_url": input_url,
                            "return_audio_url": return_audio_url,
                            "only_nlp_text": only_nlp_text,
                        },
                        "speaker_info": json.loads(speaker_info) if speaker_info else None,
                        "audio_config": {"format": encoding, "sample_rate": 24000, "speech_rate": 0},
                    }
                logger.debug(f"Request params: {json.dumps(req_params, indent=2, ensure_ascii=False)}")
                if not is_podcast_round_end:
                    req_params["retry_info"] = {"retry_task_id": task_id, "last_finished_round_id": last_round_id}
                # Start connection
                await start_connection(websocket)
                await wait_for_event(websocket, MsgType.FullServerResponse, EventType.ConnectionStarted)
                session_id = str(uuid.uuid4())
                if not task_id:
                    task_id = session_id
                # Start session
                await start_session(websocket, json.dumps(req_params).encode(), session_id)
                await wait_for_event(websocket, MsgType.FullServerResponse, EventType.SessionStarted)
                # Finish session
                await finish_session(websocket, session_id)
                while True:
                    msg = await receive_message(websocket)
                    # Audio data chunk
                    if msg.type == MsgType.AudioOnlyServer and msg.event == EventType.PodcastRoundResponse:
                        if not audio_received and audio:
                            audio_received = True
                        audio.extend(msg.payload)
                    # Error message
                    elif msg.type == MsgType.Error:
                        raise RuntimeError(f"Server error: {msg.payload.decode()}")
                    elif msg.type == MsgType.FullServerResponse:
                        # Podcast round starts
                        if msg.event == EventType.PodcastRoundStart:
                            data = json.loads(msg.payload.decode())
                            if data.get("text"):
                                filtered_payload = {"text": data.get("text"), "speaker": data.get("speaker")}
                                podcast_texts.append(filtered_payload)
                            voice = data.get("speaker")
                            current_round = data.get("round_id")
                            if current_round == -1:
                                voice = "head_music"
                            if current_round == 9999:
                                voice = "tail_music"
                            is_podcast_round_end = False
                            logger.debug(f"New round started: {data}")
                        # Podcast round ends
                        if msg.event == EventType.PodcastRoundEnd:
                            data = json.loads(msg.payload.decode())
                            logger.debug(f"Podcast round end: {data}")
                            if data.get("is_error"):
                                break
                            is_podcast_round_end = True
                            last_round_id = current_round
                            if audio:
                                round_info = await post_processor.postprocess_round(current_round, voice, audio, podcast_texts)
                                if on_round_complete:
                                    await on_round_complete(round_info)
                                audio.clear()
                        # Podcast ends
                        if msg.event == EventType.PodcastEnd:
                            data = json.loads(msg.payload.decode())
                            logger.info(f"Podcast end: {data}")
                        # Session finished
                        if msg.event == EventType.SessionFinished:
                            break
                if not audio_received and not only_nlp_text:
                    raise RuntimeError("No audio data received")
                # Finish the connection
                await finish_connection(websocket)
                await wait_for_event(websocket, MsgType.FullServerResponse, EventType.ConnectionFinished)
                # Podcast finished, save the final audio file
                if is_podcast_round_end:
                    podcast_info = await post_processor.postprocess_final()
                    return podcast_info
                else:
                    logger.error(f"Current podcast not finished, resuming from round {last_round_id}")
                    retry_num -= 1
                    await asyncio.sleep(1)
                    if websocket:
                        await websocket.close()
        finally:
            await post_processor.cleanup()
            if websocket:
                await websocket.close()
        return None


async def test(args):
    """
    Podcast test entry

    Args:
        args: dict containing all podcast parameters
    """
    client = VolcEnginePodcastClient()
    # Set default parameters
    params = {
        "text": "",
        "input_url": "https://zhuanlan.zhihu.com/p/607822576",
        "prompt_text": "",
        "nlp_texts": "",
        "action": 0,
        "resource_id": "volc.service_type.10050",
        "encoding": "mp3",
        "input_id": "test_podcast",
        "speaker_info": '{"random_order":false}',
        "use_head_music": False,
        "use_tail_music": False,
        "only_nlp_text": False,
        "return_audio_url": True,
        "skip_round_audio_save": False,
        "output_dir": "output",
    }
    # Override default parameters
    if args:
        params.update(args)
    # podcast_request does not accept output_dir and requires a session id
    params.pop("output_dir", None)
    params.setdefault("session_id", str(uuid.uuid4()))
    await client.podcast_request(**params)


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--text", default="", help="Input text Use when action in [0]")
    parser.add_argument("--input_url", default="", help="Web url or file url Use when action in [0]")
    parser.add_argument("--prompt_text", default="", help="Input Prompt Text must not empty when action in [4]")
    parser.add_argument("--nlp_texts", default="", help="Input NLP Texts must not empty when action in [3]")
    parser.add_argument("--resource_id", default="volc.service_type.10050", help="Audio Resource ID")
    parser.add_argument("--encoding", default="mp3", choices=["mp3", "wav"], help="Audio format")
    parser.add_argument("--input_id", default="test_podcast", help="Unique input identifier")
    parser.add_argument("--speaker_info", default='{"random_order":false}', help="Podcast Speaker Info")
    parser.add_argument("--use_head_music", default=False, action="store_true", help="Enable head music")
    parser.add_argument("--use_tail_music", default=False, action="store_true", help="Enable tail music")
    parser.add_argument("--only_nlp_text", default=False, action="store_true", help="Enable only podcast text when action in [0, 4]")
    parser.add_argument("--return_audio_url", default=False, action="store_true", help="Enable return audio url that can download")
    parser.add_argument("--action", default=0, type=int, choices=[0, 3, 4], help="different podcast type")
    parser.add_argument("--skip_round_audio_save", default=False, action="store_true", help="skip round audio save")
    parser.add_argument("--output_dir", default="output", help="Output directory")
    args = parser.parse_args()
    kwargs = {k: v for k, v in vars(args).items() if v is not None and not (isinstance(v, bool) and not v)}
    asyncio.run(test(kwargs))
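For programmatic use, the round callback is the main hook: it receives the dict returned by postprocess_round after every round, so a caller can stream partial audio URLs to a UI as the podcast is synthesized. A hedged sketch, assuming the VOLCENGINE_PODCAST_* environment variables are set:

    import asyncio
    import uuid

    async def on_round(info):
        print(f"round {info['round']} by {info['speaker']}: {info['duration']:.1f}s -> {info['url']}")

    async def main():
        client = VolcEnginePodcastClient()
        result = await client.podcast_request(
            session_id=str(uuid.uuid4()),
            text="A short script to turn into a two-host podcast.",
            on_round_complete=on_round,
        )
        print(result)  # {'subtitles': [...], 'audio_name': ...} on success

    asyncio.run(main())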
lightx2v/deploy/common/utils.py
0 → 100644
import asyncio
import base64
import io
import os
import subprocess
import tempfile
import time
import traceback
from datetime import datetime

import httpx
import torchaudio
from PIL import Image
from loguru import logger

FMT = "%Y-%m-%d %H:%M:%S"


def current_time():
    return datetime.now().timestamp()


def time2str(t):
    d = datetime.fromtimestamp(t)
    return d.strftime(FMT)


def str2time(s):
    d = datetime.strptime(s, FMT)
    return d.timestamp()


def try_catch(func):
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except Exception:
            logger.error(f"Error in {func.__name__}:")
            traceback.print_exc()
            return None

    return wrapper


def class_try_catch(func):
    def wrapper(self, *args, **kwargs):
        try:
            return func(self, *args, **kwargs)
        except Exception:
            logger.error(f"Error in {self.__class__.__name__}.{func.__name__}:")
            traceback.print_exc()
            return None

    return wrapper


def class_try_catch_async(func):
    async def wrapper(self, *args, **kwargs):
        try:
            return await func(self, *args, **kwargs)
        except Exception:
            logger.error(f"Error in {self.__class__.__name__}.{func.__name__}:")
            traceback.print_exc()
            return None

    return wrapper


def data_name(x, task_id):
    if x == "input_image":
        x = x + ".png"
    elif x == "input_video":
        x = x + ".mp4"
    elif x == "output_video":
        x = x + ".mp4"
    return f"{task_id}-{x}"


async def fetch_resource(url, timeout):
    logger.info(f"Begin to download resource from url: {url}")
    t0 = time.time()
    async with httpx.AsyncClient() as client:
        async with client.stream("GET", url, timeout=timeout) as response:
            response.raise_for_status()
            ans_bytes = []
            async for chunk in response.aiter_bytes(chunk_size=1024 * 1024):
                ans_bytes.append(chunk)
                if len(ans_bytes) > 128:
                    raise Exception(f"url {url} recv data is too big")
            content = b"".join(ans_bytes)
            logger.info(f"Download url {url} resource cost time: {time.time() - t0} seconds")
            return content


# check, resize, read rotate meta info
def format_image_data(data, max_size=1280):
    image = Image.open(io.BytesIO(data)).convert("RGB")
    exif = image.getexif()
    changed = False
    w, h = image.size
    assert w > 0 and h > 0, "image is empty"
    logger.info(f"load image: {w} x {h}, exif: {exif}")
    if w > max_size or h > max_size:
        ratio = max_size / max(w, h)
        w = int(w * ratio)
        h = int(h * ratio)
        image = image.resize((w, h))
        logger.info(f"resize image to: {image.size}")
        changed = True
    orientation_key = 274
    if orientation_key and orientation_key in exif:
        orientation = exif[orientation_key]
        if orientation == 2:
            image = image.transpose(Image.FLIP_LEFT_RIGHT)
        elif orientation == 3:
            image = image.rotate(180, expand=True)
        elif orientation == 4:
            image = image.transpose(Image.FLIP_TOP_BOTTOM)
        elif orientation == 5:
            image = image.transpose(Image.FLIP_LEFT_RIGHT).rotate(90, expand=True)
        elif orientation == 6:
            image = image.rotate(270, expand=True)
        elif orientation == 7:
            image = image.transpose(Image.FLIP_LEFT_RIGHT).rotate(270, expand=True)
        elif orientation == 8:
            image = image.rotate(90, expand=True)
        # reset orientation to 1
        if orientation != 1:
            logger.info(f"reset orientation from {orientation} to 1")
            exif[orientation_key] = 1
            changed = True
    if not changed:
        return data
    output = io.BytesIO()
    image.save(output, format=image.format or "JPEG", exif=exif.tobytes())
    return output.getvalue()


def media_to_wav(data):
    with tempfile.NamedTemporaryFile() as fin:
        fin.write(data)
        fin.flush()
        cmd = ["ffmpeg", "-i", fin.name, "-f", "wav", "-acodec", "pcm_s16le", "-ar", "44100", "-ac", "2", "pipe:1"]
        p = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        assert p.returncode == 0, f"media to wav failed: {p.stderr.decode()}"
        return p.stdout


def format_audio_data(data):
    if len(data) < 4:
        raise ValueError("Audio file too short")
    data = media_to_wav(data)
    waveform, sample_rate = torchaudio.load(io.BytesIO(data), num_frames=10)
    logger.info(f"load audio: {waveform.size()}, {sample_rate}")
    assert waveform.numel() > 0, "audio is empty"
    assert sample_rate > 0, "audio sample rate is not valid"
    return data


async def preload_data(inp, inp_type, typ, val):
    try:
        if typ == "url":
            timeout = int(os.getenv("REQUEST_TIMEOUT", "5"))
            data = await fetch_resource(val, timeout=timeout)
        elif typ == "base64":
            # Decode base64 in background thread to avoid blocking event loop
            data = await asyncio.to_thread(base64.b64decode, val)
        # For multi-person audio directory, val should be a dict with file structure
        elif typ == "directory":
            data = {}
            for fname, b64_data in val.items():
                data[fname] = await asyncio.to_thread(base64.b64decode, b64_data)
            return {"type": "directory", "data": data}
        elif typ == "stream":
            # no bytes data need to be saved by data_manager
            data = None
        else:
            raise ValueError(f"cannot read {inp}[{inp_type}] which type is {typ}!")
        # check if valid image bytes
        if inp_type == "IMAGE":
            data = await asyncio.to_thread(format_image_data, data)
        elif inp_type == "AUDIO":
            if typ != "stream" and typ != "directory":
                data = await asyncio.to_thread(format_audio_data, data)
        elif inp_type == "VIDEO":
            # Video data doesn't need special formatting, just validate it's not empty
            if len(data) == 0:
                raise ValueError("Video file is empty")
            logger.info(f"load video: {len(data)} bytes")
        else:
            raise Exception(f"cannot parse inp_type={inp_type} data")
        return data
    except Exception as e:
        raise ValueError(f"Failed to read {inp}, type={typ}, val={val[:100]}: {e}!")


async def load_inputs(params, raw_inputs, types):
    inputs_data = {}
    for inp in raw_inputs:
        item = params.pop(inp)
        bytes_data = await preload_data(inp, types[inp], item["type"], item["data"])
        # Handle multi-person audio directory
        if bytes_data is not None and isinstance(bytes_data, dict) and bytes_data.get("type") == "directory":
            fs = []
            for fname, fdata in bytes_data["data"].items():
                inputs_data[f"{inp}/{fname}"] = fdata
                fs.append(f"{inp}/{fname}")
            params["extra_inputs"] = {inp: fs}
        elif bytes_data is not None:
            inputs_data[inp] = bytes_data
        else:
            params[inp] = item
    return inputs_data


def check_params(params, raw_inputs, raw_outputs, types):
    stream_audio = os.getenv("STREAM_AUDIO", "0") == "1"
    stream_video = os.getenv("STREAM_VIDEO", "0") == "1"
    for x in raw_inputs + raw_outputs:
        if x in params and "type" in params[x]:
            if params[x]["type"] == "stream":
                if types[x] == "AUDIO":
                    assert stream_audio, "stream audio is not supported, please set env STREAM_AUDIO=1"
                elif types[x] == "VIDEO":
                    assert stream_video, "stream video is not supported, please set env STREAM_VIDEO=1"
            elif params[x]["type"] == "directory":
                # Multi-person audio directory is only supported for AUDIO type
                assert types[x] == "AUDIO", f"directory type is only supported for AUDIO input, got {types[x]}"


if __name__ == "__main__":
    # https://github.com/recurser/exif-orientation-examples
    exif_dir = "/data/nvme0/liuliang1/exif-orientation-examples"
    out_dir = "/data/nvme0/liuliang1/exif-orientation-examples/outs"
    os.makedirs(out_dir, exist_ok=True)
    for base_name in ["Landscape", "Portrait"]:
        for i in range(9):
            fin_name = os.path.join(exif_dir, f"{base_name}_{i}.jpg")
            fout_name = os.path.join(out_dir, f"{base_name}_{i}_formatted.jpg")
            logger.info(f"format image: {fin_name} -> {fout_name}")
            with open(fin_name, "rb") as f:
                data = f.read()
            data = format_image_data(data)
            with open(fout_name, "wb") as f:
                f.write(data)
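A short sketch of how these helpers chain for a single base64 image input (file name illustrative): preload_data decodes off the event loop, format_image_data resizes and normalizes EXIF orientation, and data_name gives the key under which the bytes would be stored.

    import asyncio
    import base64

    async def demo():
        with open("example.jpg", "rb") as f:
            val = base64.b64encode(f.read()).decode()
        data = await preload_data("input_image", "IMAGE", "base64", val)
        print(f"ready to store {len(data)} bytes as {data_name('input_image', 'task-123')}")

    asyncio.run(demo())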
lightx2v/deploy/common/va_controller.py
0 → 100644
import math
import os

import torch
import torch.distributed as dist
from loguru import logger

from lightx2v.models.runners.vsr.vsr_wrapper import compute_scaled_and_target_dims
from lightx2v_platform.base.global_var import AI_DEVICE


class NextControl:
    def __init__(self, action: str, data: any = None):
        # action: switch, data: prev_video tensor
        # action: wait, data: None
        # action: fetch, data: None
        self.action = action
        self.data = data


class VAController:
    def __init__(self, model_runner):
        self.reader = None
        self.recorder = None
        self.rank = 0
        self.world_size = 1
        if dist.is_initialized():
            self.rank = dist.get_rank()
            self.world_size = dist.get_world_size()
        self.target_reader_rank = int(os.getenv("READER_RANK", "0")) % self.world_size
        self.target_recorder_rank = int(os.getenv("RECORDER_RANK", "0")) % self.world_size
        self.init_base(model_runner.config, model_runner.input_info, model_runner.vfi_model is not None, model_runner.vsr_model is not None)
        self.init_recorder()
        self.init_reader(model_runner)

    def init_base(self, config, input_info, has_vfi_model, has_vsr_model):
        self.audio_path = input_info.audio_path
        self.output_video_path = input_info.save_result_path
        if isinstance(self.output_video_path, dict):
            self.output_video_path = self.output_video_path["data"]
        self.audio_sr = config.get("audio_sr", 16000)
        self.target_fps = config.get("target_fps", 16)
        self.max_num_frames = config.get("target_video_length", 81)
        self.prev_frame_length = config.get("prev_frame_length", 5)
        self.record_fps = config.get("target_fps", 16)
        if "video_frame_interpolation" in config and has_vfi_model:
            self.record_fps = config["video_frame_interpolation"]["target_fps"]
        self.record_fps = config.get("record_fps", self.record_fps)
        self.tgt_h = input_info.target_shape[0]
        self.tgt_w = input_info.target_shape[1]
        self.record_h, self.record_w = self.tgt_h, self.tgt_w
        if "video_super_resolution" in config and has_vsr_model:
            _, _, self.record_w, self.record_h = compute_scaled_and_target_dims(
                self.record_w,
                self.record_h,
                scale=config["video_super_resolution"]["scale"],
                multiple=128,
            )
        # how many frames to publish stream as a batch
        self.slice_frame = config.get("slice_frame", self.prev_frame_length)
        # estimate the max infer seconds, for immediate switch with local omni
        slice_interval = self.slice_frame / self.record_fps
        est_max_infer_secs = config.get("est_max_infer_secs", 0.6)
        self.est_infer_end_idx = math.ceil(est_max_infer_secs / slice_interval)
        self.min_stay_queue_num = self.est_infer_end_idx * 2 + 1
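    # Worked example with the defaults above: slice_frame = prev_frame_length = 5 and
    # record_fps = 16 give slice_interval = 5 / 16 = 0.3125 s per published batch;
    # est_max_infer_secs = 0.6 then yields est_infer_end_idx = ceil(0.6 / 0.3125) = 2,
    # so min_stay_queue_num = 2 * 2 + 1 = 5 batches must stay queued before a switch.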
    def init_recorder(self):
        if not self.output_video_path or self.rank != self.target_recorder_rank:
            return
        logger.info(f"Rank {self.rank} init recorder with: {self.output_video_path}")
        whip_shared_path = os.getenv("WHIP_SHARED_LIB", None)
        if whip_shared_path and self.output_video_path.startswith("http"):
            from lightx2v.deploy.common.va_recorder_x264 import X264VARecorder

            self.recorder = X264VARecorder(
                whip_shared_path=whip_shared_path,
                livestream_url=self.output_video_path,
                fps=self.record_fps,
                sample_rate=self.audio_sr,
                slice_frame=self.slice_frame,
                prev_frame=self.prev_frame_length,
            )
        else:
            from lightx2v.deploy.common.va_recorder import VARecorder

            self.recorder = VARecorder(
                livestream_url=self.output_video_path,
                fps=self.record_fps,
                sample_rate=self.audio_sr,
                slice_frame=self.slice_frame,
                prev_frame=self.prev_frame_length,
            )

    def init_reader(self, model_runner=None):
        if not isinstance(self.audio_path, dict):
            return
        assert self.audio_path["type"] == "stream", f"unexpected audio_path: {self.audio_path}"
        segment_duration = self.max_num_frames / self.target_fps
        prev_duration = self.prev_frame_length / self.target_fps
        omni_work_dir = os.getenv("OMNI_WORK_DIR", None)
        if omni_work_dir:
            from lightx2v.deploy.common.va_reader_omni import OmniVAReader

            self.reader = OmniVAReader(
                rank=self.rank,
                world_size=self.world_size,
                stream_url=self.audio_path["data"],
                sample_rate=self.audio_sr,
                segment_duration=segment_duration,
                prev_duration=prev_duration,
                target_rank=self.target_reader_rank,
                model_runner=model_runner,
                huoshan_tts_voice_type=self.audio_path.get("huoshan_tts_voice_type", None),
            )
        else:
            from lightx2v.deploy.common.va_reader import VAReader

            self.reader = VAReader(
                rank=self.rank,
                world_size=self.world_size,
                stream_url=self.audio_path["data"],
                sample_rate=self.audio_sr,
                segment_duration=segment_duration,
                prev_duration=prev_duration,
                target_rank=self.target_reader_rank,
            )

    def start(self):
        self.reader.start()
        if self.rank == self.target_recorder_rank:
            assert self.recorder is not None, f"recorder is required for stream audio input for rank {self.rank}"
            self.recorder.start(self.record_w, self.record_h)
        if self.world_size > 1:
            dist.barrier()

    def next_control(self):
        from lightx2v.deploy.common.va_reader_omni import OmniVAReader

        if isinstance(self.reader, OmniVAReader):
            return self.omni_reader_next_control()
        return NextControl(action="fetch")

    def before_control(self):
        from lightx2v.deploy.common.va_reader_omni import OmniVAReader

        if isinstance(self.reader, OmniVAReader):
            self.len_tensor = torch.tensor([0], dtype=torch.int32, device=AI_DEVICE)
            self.flag_tensor = torch.tensor([0], dtype=torch.int32, device=AI_DEVICE)
            self.prev_tensor = torch.zeros((1, 3, self.prev_frame_length, self.tgt_h, self.tgt_w), dtype=torch.float, device=AI_DEVICE)

    def omni_reader_next_control(self):
        immediate_switch = self.reader.get_immediate_switch()
        if immediate_switch == 1:
            # truncate the stream buffer to keep the max infer time length
            # and broadcast the prev video tensor to all ranks
            if self.rank == self.target_recorder_rank:
                logger.warning("runner recv immediate switch, truncate stream buffer")
                video_tensor = self.recorder.truncate_stream_buffer(self.est_infer_end_idx)
                if video_tensor is not None:
                    self.flag_tensor.fill_(1)
                    self.prev_tensor.copy_(video_tensor)
                else:
                    self.flag_tensor.fill_(0)
            dist.broadcast(self.flag_tensor, src=self.target_recorder_rank)
            if self.flag_tensor.item() == 1:
                dist.broadcast(self.prev_tensor, src=self.target_recorder_rank)
                return NextControl(action="switch", data=self.prev_tensor)
        else:
            # get the length of stream buffer, broadcast to all ranks
            if self.rank ==
self
.
target_recorder_rank
:
stream_buffer_length
=
self
.
recorder
.
get_buffer_stream_size
()
self
.
len_tensor
.
copy_
(
stream_buffer_length
)
dist
.
broadcast
(
self
.
len_tensor
,
src
=
self
.
target_recorder_rank
)
buffer_length
=
self
.
len_tensor
.
item
()
# stream buffer is enough, skip infer
if
buffer_length
>=
self
.
min_stay_queue_num
:
return
NextControl
(
action
=
"wait"
)
return
NextControl
(
action
=
"fetch"
)
def
pub_livestream
(
self
,
images
:
torch
.
Tensor
,
audios
:
torch
.
Tensor
,
gen_video
:
torch
.
Tensor
):
if
self
.
recorder
.
realtime
:
self
.
recorder
.
buffer_stream
(
images
,
audios
,
gen_video
)
else
:
self
.
recorder
.
pub_livestream
(
images
,
audios
)
def
clear
(
self
):
self
.
len_tensor
=
None
self
.
flag_tensor
=
None
self
.
prev_tensor
=
None
if
self
.
reader
is
not
None
:
self
.
reader
.
stop
()
self
.
reader
=
None
if
self
.
recorder
is
not
None
:
self
.
recorder
.
stop
()
self
.
recorder
=
None
def
__del__
(
self
):
self
.
clear
()
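For orientation, the control surface above boils down to three NextControl actions (fetch, wait, switch). The loop below is a minimal sketch of how a caller might drive it; `run_stream_loop`, `infer_segment`, and `reset_prev_video` are hypothetical names for illustration, not hooks defined in this commit.

import time

def run_stream_loop(controller, model_runner):
    # Hypothetical driver loop for VAController; infer_segment() and
    # reset_prev_video() are illustrative stand-ins for the real runner hooks.
    controller.start()
    controller.before_control()
    while True:
        ctrl = controller.next_control()
        if ctrl.action == "wait":
            time.sleep(0.05)  # recorder buffer is full enough, skip inference for now
            continue
        if ctrl.action == "switch":
            model_runner.reset_prev_video(ctrl.data)  # prev_video tensor from the recorder
        audio = controller.reader.get_audio_segment()
        if audio is None:
            continue
        images, audios, gen_video = model_runner.infer_segment(audio)
        controller.pub_livestream(images, audios, gen_video)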
lightx2v/deploy/common/va_reader.py
0 → 100644
import os
import queue
import signal
import subprocess
import threading
import time
import traceback

import numpy as np
import torch
import torch.distributed as dist
from loguru import logger


class VAReader:
    def __init__(
        self,
        rank: int,
        world_size: int,
        stream_url: str,
        segment_duration: float = 5.0,
        sample_rate: int = 16000,
        audio_channels: int = 1,
        buffer_size: int = 1,
        prev_duration: float = 0.3125,
        target_rank: int = 0,
    ):
        self.rank = rank
        self.world_size = world_size
        self.stream_url = stream_url
        self.segment_duration = segment_duration
        self.sample_rate = sample_rate
        self.audio_channels = audio_channels
        self.prev_duration = prev_duration
        # int16 = 2 bytes
        self.chunk_size = int(self.segment_duration * self.sample_rate) * 2
        self.prev_size = int(self.prev_duration * self.sample_rate) * 2
        self.prev_chunk = None
        self.buffer_size = buffer_size
        self.audio_queue = queue.Queue(maxsize=self.buffer_size)
        self.audio_thread = None
        self.ffmpeg_process = None
        self.bytes_buffer = bytearray()
        self.target_rank = target_rank % self.world_size
        self.flag_tensor = torch.tensor([0], dtype=torch.int32).to(device="cuda")
        self.audio_tensor = torch.zeros(self.chunk_size, dtype=torch.uint8, device="cuda")
        logger.info(f"VAReader initialized for stream: {stream_url} target_rank: {self.target_rank}")
        logger.info(f"Audio duration per chunk: {segment_duration}s, sample rate: {sample_rate}Hz")

    def start(self):
        if self.rank == self.target_rank:
            if self.stream_url.startswith("rtmp://"):
                self.start_ffmpeg_process_rtmp()
            elif self.stream_url.startswith("http"):
                self.start_ffmpeg_process_whep()
            else:
                raise Exception(f"Unsupported stream URL: {self.stream_url}")
            self.audio_thread = threading.Thread(target=self.audio_worker, daemon=True)
            self.audio_thread.start()
            logger.info(f"VAReader {self.rank}/{self.world_size} started successfully")
        else:
            logger.info(f"VAReader {self.rank}/{self.world_size} wait only")
        if self.world_size > 1:
            logger.info(f"VAReader {self.rank}/{self.world_size} wait barrier")
            dist.barrier()
            logger.info(f"VAReader {self.rank}/{self.world_size} end barrier")

    def start_ffmpeg_process_rtmp(self):
        """Start an ffmpeg process to read audio from the stream."""
        ffmpeg_cmd = [
            "ffmpeg",
            "-i", self.stream_url,
            "-vn",
            # "-acodec",
            # "pcm_s16le",
            "-ar", str(self.sample_rate),
            "-ac", str(self.audio_channels),
            "-f", "s16le",
            "-",
        ]
        try:
            self.ffmpeg_process = subprocess.Popen(ffmpeg_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, bufsize=0)
            logger.info(f"FFmpeg audio pull process started with PID: {self.ffmpeg_process.pid}")
            logger.info(f"FFmpeg command: {' '.join(ffmpeg_cmd)}")
        except Exception as e:
            logger.error(f"Failed to start FFmpeg process: {e}")
            raise

    def start_ffmpeg_process_whep(self):
        """Start a GStreamer process to read audio from the WHEP stream."""
        ffmpeg_cmd = [
            "gst-launch-1.0",
            "-q",
            "whepsrc",
            f"whep-endpoint={self.stream_url}",
            "video-caps=none",
            "!rtpopusdepay",
            "!opusdec",
            "plc=false",
            "!audioconvert",
            "!audioresample",
            f"!audio/x-raw,format=S16LE,channels={self.audio_channels},rate={self.sample_rate}",
            "!fdsink",
            "fd=1",
        ]
        try:
            self.ffmpeg_process = subprocess.Popen(
                ffmpeg_cmd,
                stdout=subprocess.PIPE,
                # stderr=subprocess.PIPE,
                bufsize=0,
            )
            logger.info(f"GStreamer audio pull process started with PID: {self.ffmpeg_process.pid}")
            logger.info(f"GStreamer command: {' '.join(ffmpeg_cmd)}")
        except Exception as e:
            logger.error(f"Failed to start GStreamer process: {e}")
            raise

    def audio_worker(self):
        logger.info("Audio pull worker thread started")
        try:
            while True:
                if not self.ffmpeg_process or self.ffmpeg_process.poll() is not None:
                    logger.warning("FFmpeg process exited, audio worker thread stopped")
                    break
                self.fetch_audio_data()
                time.sleep(0.01)
        except:  # noqa
            logger.error(f"Audio pull worker error: {traceback.format_exc()}")
        finally:
            logger.warning("Audio pull worker thread stopped")

    def fetch_audio_data(self):
        """Fetch audio data from the ffmpeg process."""
        try:
            audio_bytes = self.ffmpeg_process.stdout.read(self.chunk_size)
            if not audio_bytes:
                return
            self.bytes_buffer.extend(audio_bytes)
            # logger.info(f"Fetch audio data: {len(audio_bytes)} bytes, bytes_buffer: {len(self.bytes_buffer)} bytes")
            if len(self.bytes_buffer) >= self.chunk_size:
                audio_data = self.bytes_buffer[: self.chunk_size]
                self.bytes_buffer = self.bytes_buffer[self.chunk_size :]
                # first chunk, read original 81 frames
                # for other chunks, read 81 - 5 = 76 frames, concat with previous 5 frames
                if self.prev_chunk is None:
                    logger.info(f"change chunk_size: from {self.chunk_size} to {self.chunk_size - self.prev_size}")
                    self.chunk_size -= self.prev_size
                else:
                    audio_data = self.prev_chunk + audio_data
                self.prev_chunk = audio_data[-self.prev_size :]
                try:
                    self.audio_queue.put_nowait(audio_data)
                except queue.Full:
                    logger.warning(f"Audio queue full: {self.audio_queue.qsize()}, discarded oldest chunk")
                    self.audio_queue.get_nowait()
                    self.audio_queue.put_nowait(audio_data)
                logger.info(f"Put audio data: {len(audio_data)} bytes, audio_queue: {self.audio_queue.qsize()}, chunk_size: {self.chunk_size}")
        except:  # noqa
            logger.error(f"Fetch audio data error: {traceback.format_exc()}")

    def broadcast_audio_data(self, audio_data):
        if self.rank == self.target_rank:
            if audio_data is None:
                self.flag_tensor.fill_(0)
            else:
                self.flag_tensor.fill_(1)
                self.audio_tensor.copy_(torch.frombuffer(bytearray(audio_data), dtype=torch.uint8))
                logger.info(f"rank {self.rank} send audio_tensor: {self.audio_tensor.shape}")
        dist.broadcast(self.flag_tensor, src=self.target_rank)
        if self.flag_tensor.item() == 0:
            return None
        dist.broadcast(self.audio_tensor, src=self.target_rank)
        if self.rank != self.target_rank:
            logger.info(f"rank {self.rank} recv audio_tensor: {self.audio_tensor.shape}")
            audio_data = self.audio_tensor.cpu().numpy().tobytes()
        return audio_data

    def bytes_to_ndarray(self, audio_data):
        if audio_data is None:
            return None
        audio_data = np.frombuffer(audio_data, dtype=np.int16)
        audio_data = audio_data.astype(np.float32) / 32768.0
        logger.info(f"Got segment audio rank={self.rank}: {audio_data.shape} {audio_data.dtype} {audio_data.min()} {audio_data.max()}")
        return audio_data

    def get_audio_segment(self, timeout: float = 1.0):
        audio_data = None
        if self.rank == self.target_rank:
            try:
                audio_data = self.audio_queue.get(timeout=timeout)
            except:  # noqa
                logger.warning(f"Failed to get audio segment: {traceback.format_exc()}")
        if self.world_size > 1:
            audio_data = self.broadcast_audio_data(audio_data)
        audio_data = self.bytes_to_ndarray(audio_data)
        return audio_data

    def stop(self):
        # Stop ffmpeg process
        if self.ffmpeg_process:
            self.ffmpeg_process.send_signal(signal.SIGINT)
            try:
                self.ffmpeg_process.wait(timeout=5)
            except subprocess.TimeoutExpired:
                self.ffmpeg_process.kill()
            logger.warning("FFmpeg reader process stopped")
        # Wait for threads to finish
        if self.audio_thread and self.audio_thread.is_alive():
            self.audio_thread.join(timeout=5)
            if self.audio_thread.is_alive():
                logger.error("Audio pull thread did not stop gracefully")
        while self.audio_queue and self.audio_queue.qsize() > 0:
            self.audio_queue.get_nowait()
        self.audio_queue = None
        logger.warning("Audio pull queue cleaned")

    def __del__(self):
        self.stop()


if __name__ == "__main__":
    WORLD_SIZE = int(os.environ.get("WORLD_SIZE", 1))
    RANK = int(os.environ.get("RANK", 0))
    if WORLD_SIZE > 1:
        dist.init_process_group(backend="nccl")
        torch.cuda.set_device(dist.get_rank())
        logger.info(f"Distributed initialized: rank={RANK}, world_size={WORLD_SIZE}")
    reader = VAReader(
        RANK,
        WORLD_SIZE,
        # "rtmp://localhost/live/test_audio",
        "https://reverse.st-oc-01.chielo.org/10.5.64.49:8000/rtc/v1/whep/?app=live&stream=ll_test_audio&eip=10.120.114.76:8000",
        segment_duration=1.0,
        sample_rate=16000,
        audio_channels=1,
        prev_duration=1 / 16,
    )
    reader.start()
    fail_count = 0
    max_fail_count = 2
    try:
        while True:
            audio_data = reader.get_audio_segment(timeout=2)
            if audio_data is not None:
                # logger.info(f"Got audio chunk, shape: {audio_data.shape}, range: [{audio_data.min()}, {audio_data.max()}]")
                fail_count = 0
            else:
                fail_count += 1
                if fail_count > max_fail_count:
                    logger.warning("Failed to get audio chunk, stop reader")
                    reader.stop()
                    break
            time.sleep(0.95)
    finally:
        reader.stop()
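For reference, the chunk arithmetic above works out as follows with the defaults VAController passes in (81 frames at 16 fps, 5 previous frames, 16 kHz mono s16le); this just spells out the numbers, it is not new behavior:

segment_duration = 81 / 16                        # 5.0625 s
prev_duration = 5 / 16                            # 0.3125 s
chunk_size = int(segment_duration * 16000) * 2    # 162000 bytes = 81000 samples
prev_size = int(prev_duration * 16000) * 2        # 10000 bytes = 5000 samples
# After the first chunk, chunk_size shrinks to 162000 - 10000 = 152000 bytes;
# each later read is then prefixed with the previous 10000-byte tail, so every
# queued segment is again 162000 bytes, i.e. 5.0625 s or 81 frames of audio.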
lightx2v/deploy/common/va_reader_omni.py
0 → 100644
import datetime
import json
import os
import random
import subprocess
import threading
import time
import traceback
from collections import deque
from copy import deepcopy

import jsonschema
import numpy as np
import torch
import torch.distributed as dist
import zmq
from loguru import logger

try:
    from bson import BSON
except ImportError:
    BSON = None
    logger.warning("BSON is not installed")

from scipy.signal import resample


class AudioInfo:
    def __init__(self, info: dict):
        self.sample_count = info["sample_count"]
        self.sample_rate = info["sample_rate"]
        self.channel_count = info["channel_count"]
        self.sample_fmt = info["sample_fmt"]
        self.pts = info["pts"]

    def is_spec_equal(self, other: "AudioInfo") -> bool:
        return self.sample_fmt == other.sample_fmt and self.sample_rate == other.sample_rate and self.channel_count == other.channel_count

    def duration(self) -> datetime.timedelta:
        return datetime.timedelta(seconds=self.sample_count / self.sample_rate)

    def __str__(self):
        return "AudioInfo(sample_count={}, sample_rate={}, channel_count={}, sample_fmt={}, pts={})".format(self.sample_count, self.sample_rate, self.channel_count, self.sample_fmt, self.pts)


class ByteBuffer:
    def __init__(self):
        self.buffer = deque()
        self.current_size = 0
        # is the audio belonging to current turn finished
        self.audio_finished = False

    def add(self, byte_data: bytes):
        self.buffer.append(byte_data)
        self.current_size += len(byte_data)

    def get(self, size=1024):
        data = bytearray()
        while size > 0 and len(self.buffer) > 0:
            chunk = self.buffer.popleft()
            if len(chunk) <= size:
                # the chunk fits within the requested size, consume it whole
                data.extend(chunk)
                self.current_size -= len(chunk)
                size -= len(chunk)
            else:
                # the chunk is larger than the requested size: take a slice and
                # push the remainder back onto the front of the buffer
                data.extend(chunk[:size])
                self.buffer.appendleft(chunk[size:])
                self.current_size -= size
                size = 0
        return bytes(data)

    def mark_finished(self):
        self.audio_finished = True

    def has_more_voice(self):
        return not self.audio_finished

    def __len__(self):
        return self.current_size


class ChatAdapter:
    def __init__(
        self,
        omni_work_dir: str,
        whep_url: str,
        session_id: str,
        account: str,
        config_files: list[str],
        config_schema_path: str,
        seg_duration: float,
        model_runner,
        huoshan_tts_voice_type,
    ):
        assert os.path.exists(omni_work_dir), f"OMNI work directory {omni_work_dir} does not exist"
        self.omni_work_dir = omni_work_dir
        self.context = zmq.Context()
        self.w2f_socket = self.context.socket(zmq.PULL)
        self.w2f_url = ChatAdapter.select_and_bind(self.w2f_socket)
        self.f2w_socket = self.context.socket(zmq.PUSH)
        self.f2w_url = ChatAdapter.select_and_bind(self.f2w_socket)
        self.recv_thread = None
        self.audio_buffer = ByteBuffer()
        self.audio_info = None
        self.chat_server_cmd = [
            os.path.join(self.omni_work_dir, "bin", "seko-chatter"),
            "--session-id",
            session_id,
            "--account",
            account,
            "--whep-server-url",
            whep_url,
            "--w2f-endpoint",
            self.w2f_url,
            "--f2w-endpoint",
            self.f2w_url,
            "--config-files",
            *config_files,
        ]
        override_config = {}
        if huoshan_tts_voice_type is not None:
            logger.info(f"Use Huoshan TTS voice type: {huoshan_tts_voice_type}")
            override_config["TTS"] = {
                "default_voice_info": {
                    "voice_type": huoshan_tts_voice_type,
                    "provider": "huoshan_stream_tts",
                }
            }
        with open(config_schema_path, "r") as f:
            schema = json.load(f)
        jsonschema.validate(instance=override_config, schema=schema)
        if override_config is not None:
            self.chat_server_cmd.extend(["--override-config", json.dumps(override_config)])
        self.chatter_proc = None
        self.seg_duration = seg_duration
        self.reset_prev = False
        self.status = "blank"
        self.immediate_switch = 0
        self.model_runner = model_runner

    def launch_chat_server(self):
        env = {
            "RUST_LOG": "info,duplex_server=debug,backend_5o=debug",
            "LD_LIBRARY_PATH": os.environ.get("LD_LIBRARY_PATH", "") + ":" + os.path.join(self.omni_work_dir, "lib/"),
            "PATH": os.environ["PATH"] + ":" + os.path.join(self.omni_work_dir, "bin/"),
        }
        self.chatter_proc = subprocess.Popen(self.chat_server_cmd, env=env, cwd=self.omni_work_dir)

    @staticmethod
    def select_and_bind(socket: zmq.Socket) -> str:
        # randomly select a port between 1024 and 65535
        retry_count = 20
        err = None
        while retry_count > 0:
            try:
                port = random.randint(1024, 65535)
                # port = 5555
                url = f"tcp://localhost:{port}"
                socket.bind(url)
                return url
            except zmq.error.ZMQError as e:
                retry_count -= 1
                err = e
        raise err

    # immediately switch to the given status: discard prev bytes and set immediate_switch to 1
    def immediate_switch_to(self, status):
        logger.warning(f"VA reader immediate switch to {status}")
        self.reset_prev = True
        self.status = status
        self.immediate_switch = 1
        if self.model_runner is not None:
            self.model_runner.pause_signal = True
            logger.warning("Model runner pause signal set to True")

    def recv_loop(self):
        while True:
            try:
                message = self.w2f_socket.recv()
            except Exception:
                logger.error(f"Error receiving message: {traceback.format_exc()}")
                break
            try:
                message = BSON.decode(message)
                msg_type = message["type"]
                logger.debug("Received message type: {}".format(msg_type))
                if msg_type == "AgentAudio":
                    audio = message["audio"]
                    if audio["type"] != "Pcm":
                        logger.error("Unsupported audio type: {}".format(audio["type"]))
                        continue
                    pcm_data = audio["data"]
                    audio_info = AudioInfo(audio["info"])
                    logger.debug("Received audio with duration: {}".format(audio_info.duration()))
                    if self.audio_info is None:
                        self.audio_info = audio_info
                    else:
                        # check if the audio info is the same
                        if not self.audio_info.is_spec_equal(audio_info):
                            raise ValueError("Audio info mismatch")
                    self.audio_buffer.add(pcm_data)
                    # if status is blank and there is voice, set immediate switch to 1
                    if self.status == "blank" and self.has_voice(self.seg_duration):
                        self.immediate_switch_to("voice")
                elif msg_type == "AgentStartPlay":
                    logger.debug("Received AgentStartPlay, create new audio buffer")
                    self.audio_buffer = ByteBuffer()
                elif msg_type == "AgentEndPlay":
                    logger.debug("Received AgentEndPlay, mark audio finished")
                    self.audio_buffer.mark_finished()
                elif msg_type == "ClearAgentAudio":
                    logger.warning("Received ClearAgentAudio, clear audio buffer")
                    self.audio_buffer = None
                    self.audio_info = None
                    if self.status == "voice":
                        self.status = "blank"
                        # self.immediate_switch_to("blank")
            except Exception as e:
                logger.error("Error decoding message: {}, continue".format(e))
                continue
        logger.warning("recv loop interrupted")

    def start(self):
        self.launch_chat_server()
        self.recv_thread = threading.Thread(target=self.recv_loop)
        self.recv_thread.start()

    def has_voice(self, duration) -> bool:
        if self.audio_info is None or self.audio_buffer.current_size == 0:
            return False
        bytes_count = round(duration * self.audio_info.sample_rate) * self.audio_info.channel_count * 2  # S16LE assumed
        # if there are not enough bytes and more voice may still arrive, return False
        if self.audio_buffer.current_size < bytes_count and self.audio_buffer.has_more_voice():
            logger.warning(f"Not enough bytes and maybe has more voice, content_size: {self.audio_buffer.current_size}, bytes_count: {bytes_count}")
            return False
        return bytes_count

    def get_audio(self, fetch_duration) -> (bytes, AudioInfo):
        bytes_count = self.has_voice(fetch_duration)
        if bytes_count is False:
            return None
        pcm_data = self.audio_buffer.get(bytes_count)
        # the actual sample count fetched
        sample_count = len(pcm_data) // (self.audio_info.channel_count * 2)
        logger.debug("Fetched {} audio samples".format(sample_count))
        logger.debug("After fetch, there are {} bytes left".format(self.audio_buffer.current_size))
        audio_info = deepcopy(self.audio_info)
        audio_info.sample_count = sample_count
        return (pcm_data, audio_info)

    def stop(self):
        self.model_runner = None
        if self.chatter_proc is not None:
            self.chatter_proc.terminate()
            self.chatter_proc.wait()
            self.chatter_proc = None
        self.w2f_socket.close()
        self.f2w_socket.close()

    def __del__(self):
        self.stop()


class OmniVAReader:
    def __init__(
        self,
        rank: int,
        world_size: int,
        stream_url: str,
        segment_duration: float = 5.0625,
        sample_rate: int = 16000,
        audio_channels: int = 1,
        buffer_size: int = 1,
        prev_duration: float = 0.3125,
        target_rank: int = 0,
        model_runner=None,
        huoshan_tts_voice_type=None,
    ):
        self.rank = rank
        self.world_size = world_size
        self.stream_url = stream_url
        self.segment_duration = segment_duration
        self.sample_rate = sample_rate
        self.audio_channels = audio_channels
        self.prev_duration = prev_duration
        self.all_seg_sample_count = int(self.segment_duration * self.sample_rate)
        self.prev_seg_sample_count = int(self.prev_duration * self.sample_rate)
        self.prev_seg_chunk = None
        self.target_rank = target_rank % self.world_size
        self.flag_tensor = torch.tensor([0], dtype=torch.int32).to(device="cuda")
        self.immediate_switch_tensor = torch.tensor([0], dtype=torch.int32).to(device="cuda")
        chunk_size = int(self.segment_duration * self.sample_rate) * 2
        self.audio_tensor = torch.zeros(chunk_size, dtype=torch.uint8, device="cuda")
        self.chat_adapter = None
        self.model_runner = model_runner
        self.huoshan_tts_voice_type = huoshan_tts_voice_type
        assert self.audio_channels == 1, "Only mono audio is supported for OmniVAReader"
        logger.info(f"OmniVAReader initialized for stream: {stream_url} target_rank: {self.target_rank}")
        logger.info(f"Audio duration per chunk: {segment_duration}s, sample rate: {sample_rate}Hz")

    def init_omni_env(self):
        self.omni_work_dir = os.getenv("OMNI_WORK_DIR", "/path/of/seko_chatter/")
        self.session_id = os.getenv("OMNI_SESSION_ID", "")
        self.account = os.getenv("OMNI_ACCOUNT", "")
        self.config_files = os.getenv("OMNI_CONFIG_FILES", "").split(",")
        self.config_schema_path = os.getenv("OMNI_CONFIG_SCHEMA_PATH", None)
        assert os.path.exists(self.omni_work_dir), f"OMNI work directory {self.omni_work_dir} does not exist"
        assert self.session_id and self.account, "OMNI_SESSION_ID and OMNI_ACCOUNT are required"
        logger.info(
            f"OMNI work directory: {self.omni_work_dir}, session_id: {self.session_id}, account: {self.account}, config_files: {self.config_files}, config_schema_path: {self.config_schema_path}"
        )

    def start(self):
        if self.rank == self.target_rank:
            self.init_omni_env()
            assert self.stream_url.startswith("http"), "Only HTTP stream is supported for OmniVAReader"
            self.chat_adapter = ChatAdapter(
                omni_work_dir=self.omni_work_dir,
                whep_url=self.stream_url,
                session_id=self.session_id,
                account=self.account,
                config_files=self.config_files,
                config_schema_path=self.config_schema_path,
                seg_duration=self.segment_duration,
                model_runner=self.model_runner,
                huoshan_tts_voice_type=self.huoshan_tts_voice_type,
            )
            self.chat_adapter.start()
            logger.info(f"OmniVAReader {self.rank}/{self.world_size} started successfully")
        else:
            logger.info(f"OmniVAReader {self.rank}/{self.world_size} wait only")
        if self.world_size > 1:
            logger.info(f"OmniVAReader {self.rank}/{self.world_size} wait barrier")
            dist.barrier()
            logger.info(f"OmniVAReader {self.rank}/{self.world_size} end barrier")

    def broadcast_audio_data(self, audio_data):
        if self.rank == self.target_rank:
            if audio_data is None:
                self.flag_tensor.fill_(0)
            else:
                self.flag_tensor.fill_(1)
                self.audio_tensor.copy_(torch.frombuffer(bytearray(audio_data), dtype=torch.uint8))
                # logger.info(f"rank {self.rank} send audio_tensor: {self.audio_tensor.shape}")
        dist.broadcast(self.flag_tensor, src=self.target_rank)
        if self.flag_tensor.item() == 0:
            return None
        dist.broadcast(self.audio_tensor, src=self.target_rank)
        if self.rank != self.target_rank:
            # logger.info(f"rank {self.rank} recv audio_tensor: {self.audio_tensor.shape}")
            audio_data = self.audio_tensor.cpu().numpy().tobytes()
        return audio_data

    def bytes_to_ndarray(self, audio_data):
        if audio_data is None:
            return None
        audio_data = np.frombuffer(audio_data, dtype=np.int16)
        audio_data = audio_data.astype(np.float32) / 32768.0
        # logger.info(f"Got segment audio rank={self.rank}: {audio_data.shape} {audio_data.dtype} {audio_data.min()} {audio_data.max()}")
        return audio_data

    def convert_pcm_s16le_to_mono_resampled(self, audio_data, audio_info):
        audio = np.frombuffer(audio_data, dtype=np.int16)
        sample_count = audio_info.sample_count
        assert len(audio) == sample_count * audio_info.channel_count, f"audio length {len(audio)} != sample_count * channel_count {sample_count * audio_info.channel_count}"
        # convert to mono
        if audio_info.channel_count > 1:
            audio = audio.reshape(-1, audio_info.channel_count).mean(axis=1)
        # logger.info(f"audio: {audio.shape} {audio.dtype} {audio.min()} {audio.max()}")
        if audio_info.sample_rate != self.sample_rate:
            sample_count = int(len(audio) * self.sample_rate / audio_info.sample_rate)
            audio = resample(audio, sample_count).astype(np.int16)
            # logger.info(f"resampled audio: {audio.shape} {audio.dtype} {audio.min()} {audio.max()} {sample_count}")
        logger.warning(f"valid audio: {audio.shape} {audio.dtype} {audio.min()} {audio.max()} {sample_count}")
        return audio, sample_count

    def prepare_audio_data(self, chat_audio_result):
        sample_count = 0
        audio = np.array([], dtype=np.int16)
        # convert chat audio result to mono and target sample rate
        if chat_audio_result is not None:
            audio_data, audio_info = chat_audio_result
            audio, sample_count = self.convert_pcm_s16le_to_mono_resampled(audio_data, audio_info)
        # if this is not the first segment, concat with the previous segment
        if self.prev_seg_chunk is not None:
            audio = np.concatenate([self.prev_seg_chunk, audio])
            sample_count = len(audio)
        assert sample_count <= self.all_seg_sample_count, f"audio length {sample_count} > all_seg_sample_count {self.all_seg_sample_count}"
        # pad with zeros so the audio is exactly all_seg_sample_count long
        if sample_count < self.all_seg_sample_count:
            pad_count = self.all_seg_sample_count - sample_count
            # logger.info(f"pad {pad_count} samples to audio")
            audio = np.pad(audio, (0, pad_count), mode="constant", constant_values=0)
            sample_count = len(audio)
        # update prev seg chunk
        self.prev_seg_chunk = audio[-self.prev_seg_sample_count :]
        # logger.info(f"audio: {audio.shape} {audio.dtype} {audio.min()} {audio.max()} {sample_count}, prev seg chunk: {self.prev_seg_chunk.shape}")
        return audio.tobytes()

    def get_fetch_duration(self):
        fetch_duration = self.segment_duration
        # after immediate switch, reset prev seg chunk
        if self.chat_adapter.reset_prev:
            self.prev_seg_chunk = None
            self.chat_adapter.reset_prev = False
            logger.warning("Reset prev seg chunk")
        # first segment: fetch segment_duration; otherwise fetch segment_duration - prev_duration
        if self.prev_seg_chunk is not None:
            fetch_duration -= self.prev_duration
        return fetch_duration

    def get_audio_segment(self):
        audio_data = None
        if self.rank == self.target_rank:
            try:
                fetch_duration = self.get_fetch_duration()
                # logger.info(f"Get segment, fetch_duration: {fetch_duration}")
                if self.chat_adapter.status == "voice":
                    audio_result = self.chat_adapter.get_audio(fetch_duration)
                    audio_data = self.prepare_audio_data(audio_result)
                    # all voice segments are considered inferred, naturally switch to blank
                    if audio_result is None:
                        logger.info("All voice segments inferred, naturally switch to blank")
                        self.chat_adapter.status = "blank"
                else:
                    audio_data = self.prepare_audio_data(None)
            except Exception as e:
                logger.warning(f"Failed to get voice segment: {e}")
                return None
        if self.world_size > 1:
            audio_data = self.broadcast_audio_data(audio_data)
        audio_data = self.bytes_to_ndarray(audio_data)
        return audio_data

    def get_immediate_switch(self):
        if self.rank == self.target_rank:
            if self.chat_adapter.immediate_switch == 1:
                self.immediate_switch_tensor.fill_(1)
                # reset immediate switch
                self.chat_adapter.immediate_switch = 0
            else:
                self.immediate_switch_tensor.fill_(0)
        dist.broadcast(self.immediate_switch_tensor, src=self.target_rank)
        immediate_switch = self.immediate_switch_tensor.item()
        return immediate_switch

    def stop(self):
        self.model_runner = None
        if self.chat_adapter is not None:
            self.chat_adapter.stop()
            self.chat_adapter = None
        logger.warning("OmniVAReader stopped")

    def __del__(self):
        self.stop()


if __name__ == "__main__":
    WORLD_SIZE = int(os.environ.get("WORLD_SIZE", 1))
    RANK = int(os.environ.get("RANK", 0))
    if WORLD_SIZE > 1:
        dist.init_process_group(backend="nccl")
        torch.cuda.set_device(dist.get_rank())
        logger.info(f"Distributed initialized: rank={RANK}, world_size={WORLD_SIZE}")
    reader = OmniVAReader(
        RANK,
        WORLD_SIZE,
        "https://reverse.st-oc-01.chielo.org/10.5.64.49:8000/rtc/v1/whep/?app=publish&stream=test_stream_ll&eip=10.120.114.82:8000",
        segment_duration=17 / 16,
        sample_rate=16000,
        audio_channels=1,
        prev_duration=1 / 16,
    )
    reader.start()
    fail_count = 0
    max_fail_count = 100000000
    try:
        while True:
            audio_data = reader.get_audio_segment()  # OmniVAReader.get_audio_segment takes no timeout argument
            if audio_data is not None:
                logger.info(f"Got audio chunk, shape: {audio_data.shape}, range: [{audio_data.min()}, {audio_data.max()}]")
                fail_count = 0
            else:
                fail_count += 1
                if fail_count > max_fail_count:
                    logger.warning("Failed to get audio chunk, stop reader")
                    reader.stop()
                    break
            time.sleep(0.95)
    finally:
        reader.stop()
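The recv_loop above consumes BSON messages of types AgentAudio / AgentStartPlay / AgentEndPlay / ClearAgentAudio over a ZMQ PULL socket. As a minimal sketch for testing recv_loop in isolation, a producer could connect a PUSH socket to the w2f endpoint and feed it synthetic PCM; `push_test_audio` is a hypothetical helper, and the "s16" sample_fmt value is an assumption (recv_loop only checks that the spec stays constant across messages).

import zmq
from bson import BSON  # pymongo's bson module

def push_test_audio(w2f_url: str, pcm: bytes, sample_rate: int = 16000):
    # ChatAdapter binds the w2f endpoint with zmq.PULL, so a producer connects PUSH.
    ctx = zmq.Context.instance()
    sock = ctx.socket(zmq.PUSH)
    sock.connect(w2f_url)
    sock.send(BSON.encode({
        "type": "AgentAudio",
        "audio": {
            "type": "Pcm",
            "data": pcm,  # raw S16LE bytes
            "info": {
                "sample_count": len(pcm) // 2,  # mono, 2 bytes per sample
                "sample_rate": sample_rate,
                "channel_count": 1,
                "sample_fmt": "s16",  # assumed value; only spec equality is checked
                "pts": 0,
            },
        },
    }))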
lightx2v/deploy/common/va_recorder.py
0 → 100644
import os
import queue
import socket
import subprocess
import threading
import time
import traceback

import numpy as np
import torch
import torchaudio as ta
from loguru import logger


def pseudo_random(a, b):
    x = str(time.time()).split(".")[1]
    y = int(float("0." + x) * 1000000)
    return a + (y % (b - a + 1))


class VARecorder:
    def __init__(
        self,
        livestream_url: str,
        fps: float = 16.0,
        sample_rate: int = 16000,
        slice_frame: int = 1,
        prev_frame: int = 1,
    ):
        self.livestream_url = livestream_url
        self.fps = fps
        self.sample_rate = sample_rate
        self.audio_port = pseudo_random(32000, 40000)
        self.video_port = self.audio_port + 1
        self.ffmpeg_log_level = os.getenv("FFMPEG_LOG_LEVEL", "error")
        logger.info(f"VARecorder audio port: {self.audio_port}, video port: {self.video_port}, ffmpeg_log_level: {self.ffmpeg_log_level}")
        self.width = None
        self.height = None
        self.stoppable_t = None
        self.realtime = False
        if self.livestream_url.startswith("rtmp://") or self.livestream_url.startswith("http"):
            self.realtime = True
        # ffmpeg process to mix video and audio data and push to the livestream
        self.ffmpeg_process = None
        # TCP connection objects
        self.audio_socket = None
        self.video_socket = None
        self.audio_conn = None
        self.video_conn = None
        self.audio_thread = None
        self.video_thread = None
        # queues for sending data to the ffmpeg process
        self.audio_queue = queue.Queue()
        self.video_queue = queue.Queue()
        # buffer for stream data
        self.audio_samples_per_frame = round(self.sample_rate / self.fps)
        self.stream_buffer = []
        self.stream_buffer_lock = threading.Lock()
        self.stop_schedule = False
        self.schedule_thread = None
        self.slice_frame = slice_frame
        self.prev_frame = prev_frame
        assert self.slice_frame >= self.prev_frame, "Slice frame must be greater than or equal to previous frame"

    def init_sockets(self):
        # TCP sockets for sending and receiving video and audio data
        self.video_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self.video_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        self.video_socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
        self.video_socket.bind(("127.0.0.1", self.video_port))
        self.video_socket.listen(1)
        self.audio_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self.audio_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        self.audio_socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
        self.audio_socket.bind(("127.0.0.1", self.audio_port))
        self.audio_socket.listen(1)

    def audio_worker(self):
        try:
            logger.info("Waiting for ffmpeg to connect to audio socket...")
            self.audio_conn, _ = self.audio_socket.accept()
            logger.info(f"Audio connection established from {self.audio_conn.getpeername()}")
            fail_time, max_fail_time = 0, 10
            while True:
                try:
                    if self.audio_queue is None:
                        break
                    data = self.audio_queue.get()
                    if data is None:
                        logger.info("Audio thread received stop signal")
                        break
                    # Convert audio data to 16-bit integer format
                    audios = torch.clamp(torch.round(data * 32767), -32768, 32767).to(torch.int16)
                    try:
                        self.audio_conn.send(audios[None].cpu().numpy().tobytes())
                    except (BrokenPipeError, OSError, ConnectionResetError) as e:
                        logger.info(f"Audio connection closed, stopping worker: {type(e).__name__}")
                        return
                    fail_time = 0
                except (BrokenPipeError, OSError, ConnectionResetError):
                    logger.info("Audio connection closed during queue processing")
                    break
                except Exception:
                    logger.error(f"Send audio data error: {traceback.format_exc()}")
                    fail_time += 1
                    if fail_time > max_fail_time:
                        logger.error(f"Audio push worker thread failed {fail_time} times, stopping...")
                        break
        except Exception:
            logger.error(f"Audio push worker thread error: {traceback.format_exc()}")
        finally:
            logger.info("Audio push worker thread stopped")

    def video_worker(self):
        try:
            logger.info("Waiting for ffmpeg to connect to video socket...")
            self.video_conn, _ = self.video_socket.accept()
            logger.info(f"Video connection established from {self.video_conn.getpeername()}")
            fail_time, max_fail_time = 0, 10
            packet_secs = 1.0 / self.fps
            while True:
                try:
                    if self.video_queue is None:
                        break
                    data = self.video_queue.get()
                    if data is None:
                        logger.info("Video thread received stop signal")
                        break
                    # Convert each frame to uint8 numpy scaled to [0, 255]
                    for i in range(data.shape[0]):
                        t0 = time.time()
                        frame = (data[i] * 255).clamp(0, 255).to(torch.uint8).cpu().numpy()
                        try:
                            self.video_conn.send(frame.tobytes())
                        except (BrokenPipeError, OSError, ConnectionResetError) as e:
                            logger.info(f"Video connection closed, stopping worker: {type(e).__name__}")
                            return
                        if self.realtime and i < data.shape[0] - 1:
                            time.sleep(max(0, packet_secs - (time.time() - t0)))
                    fail_time = 0
                except (BrokenPipeError, OSError, ConnectionResetError):
                    logger.info("Video connection closed during queue processing")
                    break
                except Exception:
                    logger.error(f"Send video data error: {traceback.format_exc()}")
                    fail_time += 1
                    if fail_time > max_fail_time:
                        logger.error(f"Video push worker thread failed {fail_time} times, stopping...")
                        break
        except Exception:
            logger.error(f"Video push worker thread error: {traceback.format_exc()}")
        finally:
            logger.info("Video push worker thread stopped")

    def start_ffmpeg_process_local(self):
        """Start an ffmpeg process that connects to our TCP sockets and writes a local MP4 file."""
        ffmpeg_cmd = [
            "ffmpeg",
            "-fflags", "nobuffer",
            "-analyzeduration", "0",
            "-probesize", "32",
            "-flush_packets", "1",
            "-f", "s16le",
            "-ar", str(self.sample_rate),
            "-ac", "1",
            "-i", f"tcp://127.0.0.1:{self.audio_port}",
            "-f", "rawvideo",
            "-pix_fmt", "rgb24",
            "-color_range", "pc",
            "-colorspace", "rgb",
            "-color_primaries", "bt709",
            "-color_trc", "iec61966-2-1",
            "-r", str(self.fps),
            "-s", f"{self.width}x{self.height}",
            "-i", f"tcp://127.0.0.1:{self.video_port}",
            "-ar", "44100",
            "-b:v", "4M",
            "-c:v", "libx264",
            "-preset", "ultrafast",
            "-tune", "zerolatency",
            "-g", f"{self.fps}",
            "-pix_fmt", "yuv420p",
            "-f", "mp4",
            self.livestream_url,
            "-y",
            "-loglevel", self.ffmpeg_log_level,
        ]
        try:
            self.ffmpeg_process = subprocess.Popen(ffmpeg_cmd)
            logger.info(f"FFmpeg streaming started with PID: {self.ffmpeg_process.pid}")
            logger.info(f"FFmpeg command: {' '.join(ffmpeg_cmd)}")
        except Exception as e:
            logger.error(f"Failed to start FFmpeg: {e}")

    def start_ffmpeg_process_rtmp(self):
        """Start an ffmpeg process that connects to our TCP sockets and pushes an RTMP/FLV stream."""
        ffmpeg_cmd = [
            "ffmpeg",
            "-re",
            "-f", "s16le",
            "-ar", str(self.sample_rate),
            "-ac", "1",
            "-i", f"tcp://127.0.0.1:{self.audio_port}",
            "-f", "rawvideo",
            "-re",
            "-pix_fmt", "rgb24",
            "-r", str(self.fps),
            "-s", f"{self.width}x{self.height}",
            "-i", f"tcp://127.0.0.1:{self.video_port}",
            "-ar", "44100",
            "-b:v", "2M",
            "-c:v", "libx264",
            "-preset", "ultrafast",
            "-tune", "zerolatency",
            "-g", f"{self.fps}",
            "-pix_fmt", "yuv420p",
            "-f", "flv",
            self.livestream_url,
            "-y",
            "-loglevel", self.ffmpeg_log_level,
        ]
        try:
            self.ffmpeg_process = subprocess.Popen(ffmpeg_cmd)
            logger.info(f"FFmpeg streaming started with PID: {self.ffmpeg_process.pid}")
            logger.info(f"FFmpeg command: {' '.join(ffmpeg_cmd)}")
        except Exception as e:
            logger.error(f"Failed to start FFmpeg: {e}")

    def start_ffmpeg_process_whip(self):
        """Start an ffmpeg process that connects to our TCP sockets and pushes a WHIP/WebRTC stream."""
        ffmpeg_cmd = [
            "ffmpeg",
            "-re",
            "-fflags", "nobuffer",
            "-analyzeduration", "0",
            "-probesize", "32",
            "-flush_packets", "1",
            "-f", "s16le",
            "-ar", str(self.sample_rate),
            "-ac", "1",
            "-ch_layout", "mono",
            "-i", f"tcp://127.0.0.1:{self.audio_port}",
            "-f", "rawvideo",
            "-re",
            "-pix_fmt", "rgb24",
            "-r", str(self.fps),
            "-s", f"{self.width}x{self.height}",
            "-i", f"tcp://127.0.0.1:{self.video_port}",
            "-ar", "48000",
            "-c:a", "libopus",
            "-ac", "2",
            "-b:v", "2M",
            "-c:v", "libx264",
            "-preset", "ultrafast",
            "-tune", "zerolatency",
            "-g", f"{self.fps}",
            "-pix_fmt", "yuv420p",
            "-threads", "1",
            "-bf", "0",
            "-f", "whip",
            self.livestream_url,
            "-y",
            "-loglevel", self.ffmpeg_log_level,
        ]
        try:
            self.ffmpeg_process = subprocess.Popen(ffmpeg_cmd)
            logger.info(f"FFmpeg streaming started with PID: {self.ffmpeg_process.pid}")
            logger.info(f"FFmpeg command: {' '.join(ffmpeg_cmd)}")
        except Exception as e:
            logger.error(f"Failed to start FFmpeg: {e}")

    def start(self, width: int, height: int):
        self.set_video_size(width, height)
        duration = 1.0
        frames = int(self.fps * duration)
        samples = int(self.sample_rate * (frames / self.fps))
        self.pub_livestream(torch.zeros((frames, height, width, 3), dtype=torch.float16), torch.zeros(samples, dtype=torch.float16))
        time.sleep(duration)

    def set_video_size(self, width: int, height: int):
        if self.width is not None and self.height is not None:
            assert self.width == width and self.height == height, "Video size already set"
            return
        self.width = width
        self.height = height
        self.init_sockets()
        if self.livestream_url.startswith("rtmp://"):
            self.start_ffmpeg_process_rtmp()
        elif self.livestream_url.startswith("http"):
            self.start_ffmpeg_process_whip()
        else:
            self.start_ffmpeg_process_local()
        self.audio_thread = threading.Thread(target=self.audio_worker)
        self.video_thread = threading.Thread(target=self.video_worker)
        self.audio_thread.start()
        self.video_thread.start()
        if self.realtime:
            self.schedule_thread = threading.Thread(target=self.schedule_stream_buffer)
            self.schedule_thread.start()

    # Publish ComfyUI Image tensor and audio tensor to livestream
    def pub_livestream(self, images: torch.Tensor, audios: torch.Tensor):
        N, height, width, C = images.shape
        M = audios.reshape(-1).shape[0]
        assert C == 3, "Input must be [N, H, W, C] with C=3"
        logger.info(f"Publishing video [{N}x{width}x{height}], audio: [{M}]")
        audio_frames = round(M * self.fps / self.sample_rate)
        if audio_frames != N:
            logger.warning(f"Video and audio frames mismatch, {N} vs {audio_frames}")
        self.set_video_size(width, height)
        self.audio_queue.put(audios)
        self.video_queue.put(images)
        logger.info(f"Published {N} frames and {M} audio samples")
        self.stoppable_t = time.time() + M / self.sample_rate + 3

    def buffer_stream(self, images: torch.Tensor, audios: torch.Tensor, gen_video: torch.Tensor):
        N, height, width, C = images.shape
        M = audios.reshape(-1).shape[0]
        assert N % self.slice_frame == 0, "Video frames must be divisible by slice_frame"
        assert C == 3, "Input must be [N, H, W, C] with C=3"
        audio_frames = round(M * self.fps / self.sample_rate)
        if audio_frames != N:
            logger.warning(f"Video and audio frames mismatch, {N} vs {audio_frames}")
        self.set_video_size(width, height)
        # logger.info(f"Buffer stream images {images.shape} {audios.shape} {gen_video.shape}")
        rets = []
        for i in range(0, N, self.slice_frame):
            end_frame = i + self.slice_frame
            img = images[i:end_frame]
            aud = audios[i * self.audio_samples_per_frame : end_frame * self.audio_samples_per_frame]
            gen = gen_video[:, :, (end_frame - self.prev_frame) : end_frame]
            rets.append((img, aud, gen))
        with self.stream_buffer_lock:
            origin_size = len(self.stream_buffer)
            self.stream_buffer.extend(rets)
            logger.info(f"Buffered {origin_size} + {len(rets)} = {len(self.stream_buffer)} stream segments")

    def get_buffer_stream_size(self):
        return len(self.stream_buffer)

    def truncate_stream_buffer(self, size: int):
        with self.stream_buffer_lock:
            self.stream_buffer = self.stream_buffer[:size]
            logger.info(f"Truncated stream buffer to {len(self.stream_buffer)} segments")
            if len(self.stream_buffer) > 0:
                return self.stream_buffer[-1][2]  # return the last video tensor
            else:
                return None

    def schedule_stream_buffer(self):
        schedule_interval = self.slice_frame / self.fps
        logger.info(f"Schedule stream buffer with interval: {schedule_interval} seconds")
        t = None
        while True:
            try:
                if self.stop_schedule:
                    break
                img, aud, gen = None, None, None
                with self.stream_buffer_lock:
                    if len(self.stream_buffer) > 0:
                        img, aud, gen = self.stream_buffer.pop(0)
                if t is not None:
                    wait_secs = schedule_interval - (time.time() - t)
                    if wait_secs > 0:
                        time.sleep(wait_secs)
                t = time.time()
                if img is not None and aud is not None:
                    self.audio_queue.put(aud)
                    self.video_queue.put(img)
                    # logger.info(f"Scheduled {img.shape[0]} frames and {aud.shape[0]} audio samples to publish")
                    del gen
                    self.stoppable_t = time.time() + aud.shape[0] / self.sample_rate + 3
                else:
                    logger.warning("No stream buffer to schedule")
            except Exception:
                logger.error(f"Schedule stream buffer error: {traceback.format_exc()}")
                break
        logger.info("Schedule stream buffer thread stopped")

    def stop(self, wait=True):
        if wait and self.stoppable_t:
            t = self.stoppable_t - time.time()
            if t > 0:
                logger.warning(f"Waiting for {t} seconds to stop ...")
                time.sleep(t)
        self.stoppable_t = None
        if self.schedule_thread:
            self.stop_schedule = True
            self.schedule_thread.join(timeout=5)
            if self.schedule_thread and self.schedule_thread.is_alive():
                logger.error("Schedule thread did not stop after 5s")
        # Send stop signals to queues
        if self.audio_queue:
            self.audio_queue.put(None)
        if self.video_queue:
            self.video_queue.put(None)
        # Wait for threads to finish processing queued data (increased timeout)
        queue_timeout = 30  # Increased from 5s to 30s to allow sufficient time for large video frames
        if self.audio_thread and self.audio_thread.is_alive():
            self.audio_thread.join(timeout=queue_timeout)
            if self.audio_thread.is_alive():
                logger.error(f"Audio push thread did not stop after {queue_timeout}s")
        if self.video_thread and self.video_thread.is_alive():
            self.video_thread.join(timeout=queue_timeout)
            if self.video_thread.is_alive():
                logger.error(f"Video push thread did not stop after {queue_timeout}s")
        # Shutdown connections to signal EOF to FFmpeg
        # shutdown(SHUT_WR) will wait for the send buffer to flush, no explicit sleep needed
        if self.audio_conn:
            try:
                self.audio_conn.getpeername()
                self.audio_conn.shutdown(socket.SHUT_WR)
                logger.info("Audio connection shutdown initiated")
            except OSError:
                # Connection already closed, skip shutdown
                pass
        if self.video_conn:
            try:
                self.video_conn.getpeername()
                self.video_conn.shutdown(socket.SHUT_WR)
                logger.info("Video connection shutdown initiated")
            except OSError:
                # Connection already closed, skip shutdown
                pass
        if self.ffmpeg_process:
            is_local_file = not self.livestream_url.startswith(("rtmp://", "http"))
            # Local MP4 files need time to write the moov atom and finalize the container
            timeout_seconds = 30 if is_local_file else 10
            logger.info(f"Waiting for FFmpeg to finalize file (timeout={timeout_seconds}s, local_file={is_local_file})")
            logger.info(f"FFmpeg output: {self.livestream_url}")
            try:
                returncode = self.ffmpeg_process.wait(timeout=timeout_seconds)
                if returncode == 0:
                    logger.info(f"FFmpeg process exited successfully (exit code: {returncode})")
                else:
                    logger.warning(f"FFmpeg process exited with non-zero code: {returncode}")
            except subprocess.TimeoutExpired:
                logger.warning(f"FFmpeg process did not exit within {timeout_seconds}s, sending SIGTERM...")
                try:
                    self.ffmpeg_process.terminate()  # SIGTERM
                    returncode = self.ffmpeg_process.wait(timeout=5)
                    logger.warning(f"FFmpeg process terminated with SIGTERM (exit code: {returncode})")
                except subprocess.TimeoutExpired:
                    logger.error("FFmpeg process still running after SIGTERM, killing with SIGKILL...")
                    self.ffmpeg_process.kill()
                    self.ffmpeg_process.wait()  # Wait for kill to complete
                    logger.error("FFmpeg process killed with SIGKILL")
            finally:
                self.ffmpeg_process = None
        if self.audio_conn:
            try:
                self.audio_conn.close()
            except Exception as e:
                logger.debug(f"Error closing audio connection: {e}")
            finally:
                self.audio_conn = None
        if self.video_conn:
            try:
                self.video_conn.close()
            except Exception as e:
                logger.debug(f"Error closing video connection: {e}")
            finally:
                self.video_conn = None
        if self.audio_socket:
            try:
                self.audio_socket.close()
            except Exception as e:
                logger.debug(f"Error closing audio socket: {e}")
            finally:
                self.audio_socket = None
        if self.video_socket:
            try:
                self.video_socket.close()
            except Exception as e:
                logger.debug(f"Error closing video socket: {e}")
            finally:
                self.video_socket = None
        if self.audio_queue:
            while self.audio_queue.qsize() > 0:
                try:
                    self.audio_queue.get_nowait()
                except:  # noqa
                    break
        if self.video_queue:
            while self.video_queue.qsize() > 0:
                try:
                    self.video_queue.get_nowait()
                except:  # noqa
                    break
        self.audio_queue = None
        self.video_queue = None
        logger.info("VARecorder stopped and resources cleaned up")

    def __del__(self):
        self.stop(wait=False)


def create_simple_video(frames=10, height=480, width=640):
    video_data = []
    for i in range(frames):
        frame = np.zeros((height, width, 3), dtype=np.float32)
        stripe_height = height // 8
        colors = [
            [1.0, 0.0, 0.0],  # red
            [0.0, 1.0, 0.0],  # green
            [0.0, 0.0, 1.0],  # blue
            [1.0, 1.0, 0.0],  # yellow
            [1.0, 0.0, 1.0],  # magenta
            [0.0, 1.0, 1.0],  # cyan
            [1.0, 1.0, 1.0],  # white
            [0.5, 0.5, 0.5],  # gray
        ]
        for j, color in enumerate(colors):
            start_y = j * stripe_height
            end_y = min((j + 1) * stripe_height, height)
            frame[start_y:end_y, :] = color
        offset = int((i / frames) * width)
        frame = np.roll(frame, offset, axis=1)
        frame = torch.tensor(frame, dtype=torch.float32)
        video_data.append(frame)
    return torch.stack(video_data, dim=0)


if __name__ == "__main__":
    sample_rate = 16000
    fps = 16
    width = 640
    height = 480
    recorder = VARecorder(
        # livestream_url="rtmp://localhost/live/test",
        # livestream_url="https://reverse.st-oc-01.chielo.org/10.5.64.49:8000/rtc/v1/whip/?app=live&stream=ll_test_video&eip=127.0.0.1:8000",
        livestream_url="/path/to/output_video.mp4",
        fps=fps,
        sample_rate=sample_rate,
    )
    audio_path = "/path/to/test_b_2min.wav"
    audio_array, ori_sr = ta.load(audio_path)
    audio_array = ta.functional.resample(audio_array.mean(0), orig_freq=ori_sr, new_freq=16000)
    audio_array = audio_array.reshape(-1)
    secs = audio_array.shape[0] // sample_rate
    interval = 1
    for i in range(0, secs, interval):
        logger.info(f"{i}/{secs}s")
        start = i * sample_rate
        end = (i + interval) * sample_rate
        cur_audio_array = audio_array[start:end]
        logger.info(f"audio: {cur_audio_array.shape} {cur_audio_array.dtype} {cur_audio_array.min()} {cur_audio_array.max()}")
        num_frames = int(interval * fps)
        images = create_simple_video(num_frames, height, width)
        logger.info(f"images: {images.shape} {images.dtype} {images.min()} {images.max()}")
        recorder.pub_livestream(images, cur_audio_array)
        time.sleep(interval)
    recorder.stop()
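For reference, the buffering in buffer_stream() / schedule_stream_buffer() works out as follows under the defaults the controller passes in (fps = 16, sample_rate = 16000, slice_frame = prev_frame = 5); the N = 80 input size is only an illustrative example:

audio_samples_per_frame = round(16000 / 16)   # 1000 audio samples per video frame
# buffer_stream() with N = 80 frames yields 80 / 5 = 16 segments, each holding
# 5 video frames, 5000 audio samples (0.3125 s), and the matching 5-frame tail
# of gen_video kept as the "prev" clip for an immediate switch.
# schedule_stream_buffer() then drains one segment every 5 / 16 = 0.3125 s,
# so playback stays real time regardless of how fast inference runs.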
lightx2v/deploy/common/va_recorder_x264.py
0 → 100644
import
ctypes
import
queue
import
threading
import
time
import
traceback
import
numpy
as
np
import
torch
import
torchaudio
as
ta
from
loguru
import
logger
from
scipy.signal
import
resample
class
X264VARecorder
:
def
__init__
(
self
,
whip_shared_path
:
str
,
livestream_url
:
str
,
fps
:
float
=
16.0
,
sample_rate
:
int
=
16000
,
slice_frame
:
int
=
1
,
prev_frame
:
int
=
1
,
):
assert
livestream_url
.
startswith
(
"http"
),
"X264VARecorder only support whip http livestream"
self
.
livestream_url
=
livestream_url
self
.
fps
=
fps
self
.
sample_rate
=
sample_rate
self
.
width
=
None
self
.
height
=
None
self
.
stoppable_t
=
None
# only enable whip shared api for whip http livestream
self
.
whip_shared_path
=
whip_shared_path
self
.
whip_shared_lib
=
None
self
.
whip_shared_handle
=
None
assert
livestream_url
.
startswith
(
"http"
),
"X264VARecorder only support whip http livestream"
self
.
realtime
=
True
# queue for send data to whip shared api
self
.
queue
=
queue
.
Queue
()
self
.
worker_thread
=
None
# buffer for stream data
self
.
target_sample_rate
=
48000
self
.
target_samples_per_frame
=
round
(
self
.
target_sample_rate
/
self
.
fps
)
self
.
target_chunks_per_frame
=
self
.
target_samples_per_frame
*
2
self
.
stream_buffer
=
[]
self
.
stream_buffer_lock
=
threading
.
Lock
()
self
.
stop_schedule
=
False
self
.
schedule_thread
=
None
self
.
slice_frame
=
slice_frame
self
.
prev_frame
=
prev_frame
assert
self
.
slice_frame
>=
self
.
prev_frame
,
"Slice frame must be greater than previous frame"
def
worker
(
self
):
try
:
fail_time
,
max_fail_time
=
0
,
10
packet_secs
=
1.0
/
self
.
fps
while
True
:
try
:
if
self
.
queue
is
None
:
break
data
=
self
.
queue
.
get
()
if
data
is
None
:
logger
.
info
(
"Worker thread received stop signal"
)
break
audios
,
images
=
data
for
i
in
range
(
images
.
shape
[
0
]):
t0
=
time
.
time
()
cur_audio
=
audios
[
i
*
self
.
target_chunks_per_frame
:
(
i
+
1
)
*
self
.
target_chunks_per_frame
].
flatten
()
audio_ptr
=
cur_audio
.
ctypes
.
data_as
(
ctypes
.
POINTER
(
ctypes
.
c_int16
))
self
.
whip_shared_lib
.
pushWhipRawAudioFrame
(
self
.
whip_shared_handle
,
audio_ptr
,
self
.
target_samples_per_frame
)
cur_video
=
images
[
i
].
flatten
()
video_ptr
=
cur_video
.
ctypes
.
data_as
(
ctypes
.
POINTER
(
ctypes
.
c_uint8
))
self
.
whip_shared_lib
.
pushWhipRawVideoFrame
(
self
.
whip_shared_handle
,
video_ptr
,
self
.
width
,
self
.
height
)
if
self
.
realtime
and
i
<
images
.
shape
[
0
]
-
1
:
time
.
sleep
(
max
(
0
,
packet_secs
-
(
time
.
time
()
-
t0
)))
fail_time
=
0
except
:
# noqa
logger
.
error
(
f
"Send audio data error:
{
traceback
.
format_exc
()
}
"
)
fail_time
+=
1
if
fail_time
>
max_fail_time
:
logger
.
error
(
f
"Audio push worker thread failed
{
fail_time
}
times, stopping..."
)
break
except
:
# noqa
logger
.
error
(
f
"Audio push worker thread error:
{
traceback
.
format_exc
()
}
"
)
finally
:
logger
.
info
(
"Audio push worker thread stopped"
)
def
start_libx264_whip_shared_api
(
self
,
width
:
int
,
height
:
int
):
self
.
whip_shared_lib
=
ctypes
.
CDLL
(
self
.
whip_shared_path
)
# define function argtypes and restype
self
.
whip_shared_lib
.
initWhipStream
.
argtypes
=
[
ctypes
.
c_char_p
,
ctypes
.
c_int
,
ctypes
.
c_int
,
ctypes
.
c_int
,
ctypes
.
c_int
,
ctypes
.
c_int
]
self
.
whip_shared_lib
.
initWhipStream
.
restype
=
ctypes
.
c_void_p
self
.
whip_shared_lib
.
pushWhipRawAudioFrame
.
argtypes
=
[
ctypes
.
c_void_p
,
ctypes
.
POINTER
(
ctypes
.
c_int16
),
ctypes
.
c_int
]
self
.
whip_shared_lib
.
pushWhipRawVideoFrame
.
argtypes
=
[
ctypes
.
c_void_p
,
ctypes
.
POINTER
(
ctypes
.
c_uint8
),
ctypes
.
c_int
,
ctypes
.
c_int
]
self
.
whip_shared_lib
.
destroyWhipStream
.
argtypes
=
[
ctypes
.
c_void_p
]
whip_url
=
ctypes
.
c_char_p
(
self
.
livestream_url
.
encode
(
"utf-8"
))
self
.
whip_shared_handle
=
ctypes
.
c_void_p
(
self
.
whip_shared_lib
.
initWhipStream
(
whip_url
,
1
,
1
,
0
,
width
,
height
))
logger
.
info
(
f
"WHIP shared API initialized with handle:
{
self
.
whip_shared_handle
}
"
)

    def convert_data(self, audios, images):
        # Convert audio data to 16-bit integer format
        audio_datas = torch.clamp(torch.round(audios * 32767), -32768, 32767).to(torch.int16).cpu().numpy().reshape(-1)
        # Convert images to numpy uint8 scaled to [0, 255]
        image_datas = (images * 255).clamp(0, 255).to(torch.uint8).cpu().numpy()
        logger.info(f"image_datas: {image_datas.shape} {image_datas.dtype} {image_datas.min()} {image_datas.max()}")
        resample_audios = resample(audio_datas, int(len(audio_datas) * 48000 / self.sample_rate))
        stereo_audios = np.stack([resample_audios, resample_audios], axis=-1).astype(np.int16).reshape(-1)
        return stereo_audios, image_datas
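
    # Worked example of the conversion above (illustrative numbers): with
    # self.sample_rate == 16000, one second of mono float audio (16000 samples)
    # is resampled to 16000 * 48000 / 16000 = 48000 samples, then duplicated
    # into two identical channels and interleaved, giving 96000 int16 values.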

    def start(self, width: int, height: int):
        self.set_video_size(width, height)

    def set_video_size(self, width: int, height: int):
        if self.width is not None and self.height is not None:
            assert self.width == width and self.height == height, "Video size already set"
            return
        self.width = width
        self.height = height
        self.start_libx264_whip_shared_api(width, height)
        self.worker_thread = threading.Thread(target=self.worker)
        self.worker_thread.start()
        if self.realtime:
            self.schedule_thread = threading.Thread(target=self.schedule_stream_buffer)
            self.schedule_thread.start()

    def buffer_stream(self, images: torch.Tensor, audios: torch.Tensor, gen_video: torch.Tensor):
        N, height, width, C = images.shape
        M = audios.reshape(-1).shape[0]
        assert N % self.slice_frame == 0, "Video frames must be divisible by slice_frame"
        assert C == 3, "Input must be [N, H, W, C] with C=3"
        audio_frames = round(M * self.fps / self.sample_rate)
        if audio_frames != N:
            logger.warning(f"Video and audio frames mismatch, {N} vs {audio_frames}")
        self.set_video_size(width, height)
        audio_datas, image_datas = self.convert_data(audios, images)
        # logger.info(f"Buffer stream images {images.shape} {audios.shape} {gen_video.shape}")
        rets = []
        for i in range(0, N, self.slice_frame):
            end_frame = i + self.slice_frame
            img = image_datas[i:end_frame]
            aud = audio_datas[i * self.target_chunks_per_frame : end_frame * self.target_chunks_per_frame]
            gen = gen_video[:, :, (end_frame - self.prev_frame) : end_frame]
            rets.append((img, aud, gen))
        with self.stream_buffer_lock:
            origin_size = len(self.stream_buffer)
            self.stream_buffer.extend(rets)
            logger.info(f"Buffered {origin_size} + {len(rets)} = {len(self.stream_buffer)} stream segments")
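
    # Slicing example for the loop above (illustrative, assuming slice_frame=16
    # and prev_frame=5): a 32-frame input yields two segments, frames [0:16) and
    # [16:32); each segment carries its matching stereo audio chunk plus the last
    # prev_frame frames of gen_video along its time axis (frames [11:16) for the
    # first segment), presumably kept so generation can resume after a truncation.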

    def get_buffer_stream_size(self):
        return len(self.stream_buffer)

    def truncate_stream_buffer(self, size: int):
        with self.stream_buffer_lock:
            self.stream_buffer = self.stream_buffer[:size]
            logger.info(f"Truncated stream buffer to {len(self.stream_buffer)} segments")
            if len(self.stream_buffer) > 0:
                return self.stream_buffer[-1][2]  # return the last video tensor
            else:
                return None
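
    # Added note: the tensor returned above is the gen_video slice attached to the
    # last segment that survives truncation; a caller can presumably feed it back
    # as conditioning to regenerate video from the truncation point onward.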

    def schedule_stream_buffer(self):
        schedule_interval = self.slice_frame / self.fps
        logger.info(f"Schedule stream buffer with interval: {schedule_interval} seconds")
        t = None
        while True:
            try:
                if self.stop_schedule:
                    break
                img, aud, gen = None, None, None
                with self.stream_buffer_lock:
                    if len(self.stream_buffer) > 0:
                        img, aud, gen = self.stream_buffer.pop(0)
                if t is not None:
                    wait_secs = schedule_interval - (time.time() - t)
                    if wait_secs > 0:
                        time.sleep(wait_secs)
                t = time.time()
                if img is not None and aud is not None:
                    self.queue.put((aud, img))
                    # logger.info(f"Scheduled {img.shape[0]} frames and {aud.shape[0]} audio samples to publish")
                    del gen
                    self.stoppable_t = time.time() + img.shape[0] / self.fps + 3
                else:
                    logger.warning("No stream buffer to schedule")
            except Exception:
                logger.error(f"Schedule stream buffer error: {traceback.format_exc()}")
                break
        logger.info("Schedule stream buffer thread stopped")
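
    # Design note (inferred from the code above): the scheduler drains exactly one
    # slice_frame-sized segment every slice_frame / fps seconds, decoupling bursty
    # upstream generation from the steady real-time pacing the WHIP stream needs;
    # the worker thread then spreads each segment's frames out at 1/fps intervals.
    # stoppable_t marks when the last scheduled segment will have finished playing
    # (plus a 3-second margin), so stop(wait=True) knows how long to wait.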

    def stop(self, wait=True):
        if wait and self.stoppable_t:
            t = self.stoppable_t - time.time()
            if t > 0:
                logger.warning(f"Waiting for {t} seconds to stop ...")
                time.sleep(t)
            self.stoppable_t = None
        if self.schedule_thread:
            self.stop_schedule = True
            self.schedule_thread.join(timeout=5)
            if self.schedule_thread and self.schedule_thread.is_alive():
                logger.error("Schedule thread did not stop after 5s")
        # Send stop signals to queues
        if self.queue:
            self.queue.put(None)
        # Wait for threads to finish
        if self.worker_thread and self.worker_thread.is_alive():
            self.worker_thread.join(timeout=5)
            if self.worker_thread.is_alive():
                logger.warning("Worker thread did not stop gracefully")
        # Destroy WHIP shared API
        if self.whip_shared_lib and self.whip_shared_handle:
            self.whip_shared_lib.destroyWhipStream(self.whip_shared_handle)
            self.whip_shared_handle = None
            self.whip_shared_lib = None
            logger.warning("WHIP shared API destroyed")

    def __del__(self):
        self.stop()


def create_simple_video(frames=10, height=480, width=640):
    video_data = []
    for i in range(frames):
        frame = np.zeros((height, width, 3), dtype=np.float32)
        stripe_height = height // 8
        colors = [
            [1.0, 0.0, 0.0],  # red
            [0.0, 1.0, 0.0],  # green
            [0.0, 0.0, 1.0],  # blue
            [1.0, 1.0, 0.0],  # yellow
            [1.0, 0.0, 1.0],  # magenta
            [0.0, 1.0, 1.0],  # cyan
            [1.0, 1.0, 1.0],  # white
            [0.5, 0.5, 0.5],  # gray
        ]
        for j, color in enumerate(colors):
            start_y = j * stripe_height
            end_y = min((j + 1) * stripe_height, height)
            frame[start_y:end_y, :] = color
        offset = int((i / frames) * width)
        frame = np.roll(frame, offset, axis=1)
        frame = torch.tensor(frame, dtype=torch.float32)
        video_data.append(frame)
    return torch.stack(video_data, dim=0)
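
# Quick usage example for the helper above (illustrative, added note): 16 frames
# of a 352x452 scrolling color-bar clip shaped [N, H, W, C] with float values in
# [0, 1], which is the layout buffer_stream() expects:
#
#   clip = create_simple_video(frames=16, height=352, width=452)
#   assert clip.shape == (16, 352, 452, 3)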


if __name__ == "__main__":
    sample_rate = 16000
    fps = 16
    width = 452
    height = 352
    recorder = X264VARecorder(
        whip_shared_path="/data/nvme0/liuliang1/lightx2v/test_deploy/test_whip_so/0.1.1/go_whxp.so",
        livestream_url="https://reverse.st-oc-01.chielo.org/10.5.64.49:8000/rtc/v1/whip/?app=subscribe&stream=ll2&eip=10.120.114.82:8000",
        fps=fps,
        sample_rate=sample_rate,
    )
    recorder.start(width, height)
    # time.sleep(5)
    audio_path = "/data/nvme0/liuliang1/lightx2v/test_deploy/media_test/mangzhong.wav"
    audio_array, ori_sr = ta.load(audio_path)
    audio_array = ta.functional.resample(audio_array.mean(0), orig_freq=ori_sr, new_freq=16000)
    audio_array = audio_array.numpy().reshape(-1)
    secs = audio_array.shape[0] // sample_rate
    interval = 1
    space = 10
    i = 0
    while i < space:
        t0 = time.time()
        logger.info(f"space {i} / {space} s")
        cur_audio_array = np.zeros(int(interval * sample_rate), dtype=np.float32)
        num_frames = int(interval * fps)
        images = create_simple_video(num_frames, height, width)
        recorder.buffer_stream(images, torch.tensor(cur_audio_array, dtype=torch.float32), images)
        i += interval
        time.sleep(interval - (time.time() - t0))
    started = True
    i = 0
    while i < secs:
        t0 = time.time()
        start = int(i * sample_rate)
        end = int((i + interval) * sample_rate)
        cur_audio_array = torch.tensor(audio_array[start:end], dtype=torch.float32)
        num_frames = int(interval * fps)
        images = create_simple_video(num_frames, height, width)
        logger.info(f"{i} / {secs} s")
        if started:
            logger.warning("start pub_livestream !!!!!!!!!!!!!!!!!!!!!!!")
            started = False
        recorder.buffer_stream(images, cur_audio_array, images)
        i += interval
        time.sleep(interval - (time.time() - t0))
    recorder.stop()