xuwx1 / LightX2V · Commit daf4c74e

first commit

Authored Mar 24, 2025 by helloyongyang; committed by Yang Yong(雍洋) on Apr 08, 2025.
Parent: 6c79160f

This commit changes 107 files in total; this page shows 20 of them, with 1175 additions and 0 deletions (+1175 −0):
lightx2v/attentions/distributed/partial_heads_attn/tests/test_acc.py  +80 −0
lightx2v/attentions/distributed/partial_heads_attn/wrap.py  +5 −0
lightx2v/attentions/distributed/ring/__init__.py  +0 −0
lightx2v/attentions/distributed/ring/attn.py  +80 −0
lightx2v/attentions/distributed/ring/wrap.py  +53 −0
lightx2v/attentions/distributed/ulysses/__init__.py  +0 −0
lightx2v/attentions/distributed/ulysses/attn.py  +96 −0
lightx2v/attentions/distributed/ulysses/wrap.py  +60 −0
lightx2v/attentions/distributed/utils/__init__.py  +0 −0
lightx2v/attentions/distributed/utils/process.py  +72 −0
lightx2v/common/__init__.py  +0 −0
lightx2v/common/backend_infer/trt/common.py  +146 −0
lightx2v/common/backend_infer/trt/common_runtime.py  +170 −0
lightx2v/common/offload/me_block.py  +91 −0
lightx2v/common/ops/__init__.py  +3 −0
lightx2v/common/ops/conv/__init__.py  +2 −0
lightx2v/common/ops/conv/conv2d.py  +58 −0
lightx2v/common/ops/conv/conv3d.py  +58 −0
lightx2v/common/ops/mm/__init__.py  +2 −0
lightx2v/common/ops/mm/mm_weight.py  +199 −0
lightx2v/attentions/distributed/partial_heads_attn/tests/test_acc.py (new file, mode 100644)

import torch
import torch.distributed as dist

from lightx2v.attentions import attention
from lightx2v.utils.utils import seed_all

seed_all(42)


def prepare_tensors():
    cur_rank = dist.get_rank()  # rank of the current process
    torch.cuda.set_device(cur_rank)  # bind this process to its CUDA device
    q = torch.randn(32656, 24, 128, dtype=torch.bfloat16).cuda()
    k = torch.randn(32656, 24, 128, dtype=torch.bfloat16).cuda()
    v = torch.randn(32656, 24, 128, dtype=torch.bfloat16).cuda()
    cu_seqlens_qkv = torch.tensor([0, 32411, 32656], dtype=torch.int32).cuda()
    max_seqlen_qkv = 32656
    return q, k, v, cu_seqlens_qkv, max_seqlen_qkv


def test_part_head():
    q, k, v, cu_seqlens_qkv, max_seqlen_qkv = prepare_tensors()

    # First compute the full (single-GPU) result as the reference
    single_gpu_output = attention(
        q=q,
        k=k,
        v=v,
        cu_seqlens_q=cu_seqlens_qkv,
        cu_seqlens_kv=cu_seqlens_qkv,
        max_seqlen_q=max_seqlen_qkv,
        max_seqlen_kv=max_seqlen_qkv,
    )

    num_heads = q.shape[-2]
    cur_rank = dist.get_rank()
    world_size = dist.get_world_size()
    num_chunk_heads = int(num_heads / dist.get_world_size())

    # Each rank keeps its own slice of heads; the last rank takes any remainder
    if cur_rank == world_size - 1:
        q = q[:, num_chunk_heads * cur_rank :, :]
        k = k[:, num_chunk_heads * cur_rank :, :]
        v = v[:, num_chunk_heads * cur_rank :, :]
    else:
        q = q[:, num_chunk_heads * cur_rank : num_chunk_heads * (cur_rank + 1), :]
        k = k[:, num_chunk_heads * cur_rank : num_chunk_heads * (cur_rank + 1), :]
        v = v[:, num_chunk_heads * cur_rank : num_chunk_heads * (cur_rank + 1), :]

    output = attention(
        q=q,
        k=k,
        v=v,
        cu_seqlens_q=cu_seqlens_qkv,
        cu_seqlens_kv=cu_seqlens_qkv,
        max_seqlen_q=max_seqlen_qkv,
        max_seqlen_kv=max_seqlen_qkv,
    )

    gathered_outputs = [torch.empty_like(output) for _ in range(world_size)]
    dist.all_gather(gathered_outputs, output)
    combined_output = torch.cat(gathered_outputs, dim=1)

    # Verify that the distributed result matches the single-GPU reference
    if cur_rank == 0:
        print("Outputs match:", torch.allclose(single_gpu_output, combined_output, rtol=1e-3, atol=1e-3))


if __name__ == "__main__":
    # Initialize the distributed environment
    dist.init_process_group(backend="nccl")
    test_part_head()
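The test's head partitioning relies on a simple invariant: concatenating the per-rank head shards along the head axis reconstructs the full tensor, so the all-gathered per-shard attention outputs can be compared directly against the single-GPU reference. A minimal single-process sketch of that slicing logic (hypothetical sizes, no process group needed):

import torch

num_heads, world_size = 24, 4
x = torch.randn(128, num_heads, 64)
chunk = num_heads // world_size

shards = []
for rank in range(world_size):
    # The last rank takes any remainder, mirroring test_part_head's slicing.
    stop = num_heads if rank == world_size - 1 else (rank + 1) * chunk
    shards.append(x[:, rank * chunk : stop, :])

# Concatenating the shards along the head axis reproduces the full tensor.
assert torch.equal(torch.cat(shards, dim=1), x)

The test itself would be run under a multi-process launcher, e.g. `torchrun --nproc_per_node=<num_gpus> .../test_acc.py`; the launcher choice is an assumption, as any method that initializes the NCCL process group works.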
lightx2v/attentions/distributed/partial_heads_attn/wrap.py (new file, mode 100644)

from lightx2v.attentions.distributed.partial_heads_attn.attn import partial_heads_attn


def parallelize_hunyuan(hunyuan_model):
    hunyuan_model.transformer_infer.parallel_attention = partial_heads_attn
lightx2v/attentions/distributed/ring/__init__.py (new file, mode 100644, empty)
lightx2v/attentions/distributed/ring/attn.py (new file, mode 100644)

import torch
import torch.distributed as dist

from lightx2v.attentions import attention


def ring_attn(q, k, v, img_qkv_len, cu_seqlens_qkv, attention_type="flash_attn2"):
    """Ring-style distributed attention over combined image and text queries, keys, and values.

    Args:
        q (torch.Tensor): query tensor of shape [shard_seqlen, heads, hidden_dims]
        k (torch.Tensor): key tensor of shape [shard_seqlen, heads, hidden_dims]
        v (torch.Tensor): value tensor of shape [shard_seqlen, heads, hidden_dims]
        img_qkv_len (int): length of the image portion of q/k/v
        cu_seqlens_qkv (torch.Tensor): cumulative sequence lengths covering the text and image segments
        attention_type (str): attention backend, defaults to "flash_attn2"

    Returns:
        torch.Tensor: the computed attention result
    """
    # Rank of the current process and total number of processes
    cur_rank = dist.get_rank()
    world_size = dist.get_world_size()

    # Sequence length and the text-related lengths
    seq_len = q.shape[0]
    txt_qkv_len = cu_seqlens_qkv[1] - img_qkv_len  # length of the text q/k/v
    txt_mask_len = cu_seqlens_qkv[2] - img_qkv_len  # length of the text mask

    # Number of heads and hidden dimension of the query tensor
    _, heads, hidden_dims = q.shape
    shard_heads = heads // world_size  # heads handled per process
    shard_seqlen = img_qkv_len  # sequence length handled per process

    # Split the image and text portions of q/k/v
    img_q, img_k, img_v = q[:img_qkv_len, :, :].contiguous(), k[:img_qkv_len, :, :].contiguous(), v[:img_qkv_len, :, :].contiguous()
    txt_q, txt_k, txt_v = q[img_qkv_len:, :, :].contiguous(), k[img_qkv_len:, :, :].contiguous(), v[img_qkv_len:, :, :].contiguous()

    # Gather every rank's image keys and values
    gathered_img_k = [torch.empty_like(img_k) for _ in range(world_size)]
    gathered_img_v = [torch.empty_like(img_v) for _ in range(world_size)]
    dist.all_gather(gathered_img_k, img_k)
    dist.all_gather(gathered_img_v, img_v)
    torch.cuda.synchronize()

    # Queries stay local; keys/values combine all gathered image shards with the text segment
    k = torch.cat(gathered_img_k + [txt_k], dim=0)
    v = torch.cat(gathered_img_v + [txt_v], dim=0)

    # Cumulative sequence lengths for the queries
    cu_seqlens_q = torch.zeros([3], dtype=torch.int32, device="cuda")
    s = txt_qkv_len + img_q.shape[0]  # combined text + local image length
    s1 = s  # end position of the current sample
    s2 = txt_mask_len + img_q.shape[0]  # end position of the text mask
    cu_seqlens_q[1] = s1
    cu_seqlens_q[2] = s2
    max_seqlen_q = img_q.shape[0] + txt_q.shape[0]  # maximum query sequence length

    # Cumulative sequence lengths for the keys/values
    cu_seqlens_kv = torch.zeros([3], dtype=torch.int32, device="cuda")
    s = txt_qkv_len + img_k.shape[0] * world_size  # combined text + gathered image length
    s1 = s  # end position of the current sample
    s2 = txt_mask_len + img_k.shape[0] * world_size  # end position of the text mask
    cu_seqlens_kv[1] = s1
    cu_seqlens_kv[2] = s2
    max_seqlen_kv = img_k.shape[0] * world_size + txt_q.shape[0]  # maximum key/value sequence length

    attn = attention(
        attention_type=attention_type,
        q=q,
        k=k,
        v=v,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_kv=cu_seqlens_kv,
        max_seqlen_q=max_seqlen_q,
        max_seqlen_kv=max_seqlen_kv,
    )
    return attn
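To see why cu_seqlens_kv differs from cu_seqlens_q above: queries keep only the local image shard, while keys and values see the image shards gathered from every rank. A toy reproduction of the offset arithmetic (hypothetical lengths, no GPUs or process group needed):

# Toy reproduction of ring_attn's cu_seqlens bookkeeping.
img_len, txt_len, mask_len, world_size = 100, 20, 25, 4

cu_seqlens_q = [0, txt_len + img_len, mask_len + img_len]
cu_seqlens_kv = [0, txt_len + img_len * world_size, mask_len + img_len * world_size]
print(cu_seqlens_q)   # [0, 120, 125]
print(cu_seqlens_kv)  # [0, 420, 425]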
lightx2v/attentions/distributed/ring/wrap.py (new file, mode 100644)

import functools

from lightx2v.attentions.distributed.ring.attn import ring_attn
from lightx2v.attentions.distributed.utils.process import pre_process, post_process


def parallelize_hunyuan(hunyuan_model):
    """Parallelize the Hunyuan model's inference using ring attention.

    Args:
        hunyuan_model: a Hunyuan model instance exposing an infer method.
    """
    # Swap the model's parallel attention for ring attention
    hunyuan_model.transformer_infer.parallel_attention = ring_attn

    # Keep the original inference method so it can still be called
    original_infer = hunyuan_model.infer

    @functools.wraps(hunyuan_model.__class__.infer)  # preserve the original method's metadata
    def new_infer(self, latent_model_input, t_expand, text_states, text_mask, text_states_2, freqs_cos, freqs_sin, guidance):
        """Wrapped inference: pre-process the inputs, call the original method, post-process the output.

        Args:
            self: the Hunyuan model instance
            latent_model_input: latent model input
            t_expand: timestep expansion parameter
            text_states: text states
            text_mask: text mask
            text_states_2: second set of text states
            freqs_cos: cosine frequencies
            freqs_sin: sine frequencies
            guidance: guidance parameter

        Returns:
            combined_output: the post-processed output
        """
        # Pre-process (shard) the inputs
        latent_model_input, freqs_cos, freqs_sin, split_dim = pre_process(latent_model_input, freqs_cos, freqs_sin)
        # Run the original inference method
        output = original_infer(latent_model_input, t_expand, text_states, text_mask, text_states_2, freqs_cos, freqs_sin, guidance)
        # Post-process (gather) the output
        combined_output = post_process(output, split_dim)
        return combined_output

    # Bind the new inference method to this model instance and install it
    new_infer = new_infer.__get__(hunyuan_model)
    hunyuan_model.infer = new_infer
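The trailing `__get__` call is the descriptor protocol invoked directly: it turns the module-level function into a method bound to this one instance, so `self` is supplied automatically and other instances are unaffected. A standalone sketch of the same trick (hypothetical class):

class Model:
    def infer(self):
        return "original"

def new_infer(self):
    return f"wrapped {type(self).__name__}.infer"

m = Model()
m.infer = new_infer.__get__(m)  # bind the function to this instance only
print(m.infer())        # wrapped Model.infer
print(Model().infer())  # original (other instances keep the class method)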
lightx2v/attentions/distributed/ulysses/__init__.py (new file, mode 100644, empty)
lightx2v/attentions/distributed/ulysses/attn.py (new file, mode 100644)

import torch
import torch.distributed as dist

from lightx2v.attentions import attention
from lightx2v.attentions.distributed.comm.all2all import all2all_seq2head, all2all_head2seq


def ulysses_attn(q, k, v, img_qkv_len, cu_seqlens_qkv, attention_type="flash_attn2"):
    """Ulysses attention over combined image and text queries, keys, and values.

    Args:
        q (torch.Tensor): query tensor of shape [shard_seqlen, heads, hidden_dims]
        k (torch.Tensor): key tensor of shape [shard_seqlen, heads, hidden_dims]
        v (torch.Tensor): value tensor of shape [shard_seqlen, heads, hidden_dims]
        img_qkv_len (int): length of the image portion of q/k/v
        cu_seqlens_qkv (torch.Tensor): cumulative sequence lengths covering the text and image segments
        attention_type (str): attention backend, defaults to "flash_attn2"

    Returns:
        torch.Tensor: the computed attention result
    """
    # Rank of the current process and total number of processes
    cur_rank = dist.get_rank()
    world_size = dist.get_world_size()

    # Sequence length and the text-related lengths
    seq_len = q.shape[0]
    txt_qkv_len = cu_seqlens_qkv[1] - img_qkv_len  # length of the text q/k/v
    txt_mask_len = cu_seqlens_qkv[2] - img_qkv_len  # length of the text mask

    # Number of heads and hidden dimension of the query tensor
    _, heads, hidden_dims = q.shape
    shard_heads = heads // world_size  # heads handled per process
    shard_seqlen = img_qkv_len  # sequence length handled per process

    # Split the image and text portions of q/k/v
    img_q, img_k, img_v = q[:img_qkv_len, :, :].contiguous(), k[:img_qkv_len, :, :].contiguous(), v[:img_qkv_len, :, :].contiguous()
    txt_q, txt_k, txt_v = q[img_qkv_len:, :, :].contiguous(), k[img_qkv_len:, :, :].contiguous(), v[img_qkv_len:, :, :].contiguous()

    # Redistribute the image q/k/v from sequence shards to head shards
    img_q = all2all_seq2head(img_q)
    img_k = all2all_seq2head(img_k)
    img_v = all2all_seq2head(img_v)
    torch.cuda.synchronize()  # make sure the collectives have finished

    # For the text q/k/v, keep only the head slice owned by this rank
    txt_q = txt_q[:, cur_rank * shard_heads : (cur_rank + 1) * shard_heads, :]
    txt_k = txt_k[:, cur_rank * shard_heads : (cur_rank + 1) * shard_heads, :]
    txt_v = txt_v[:, cur_rank * shard_heads : (cur_rank + 1) * shard_heads, :]

    # Concatenate the image and text q/k/v
    q = torch.cat((img_q, txt_q), dim=0)
    k = torch.cat((img_k, txt_k), dim=0)
    v = torch.cat((img_v, txt_v), dim=0)

    # Cumulative sequence lengths
    cu_seqlens_qkv = torch.zeros([3], dtype=torch.int32, device="cuda")
    s = txt_qkv_len + img_q.shape[0]  # combined text + image length
    s1 = s  # end position of the current sample
    s2 = txt_mask_len + img_q.shape[0]  # end position of the text mask
    cu_seqlens_qkv[1] = s1
    cu_seqlens_qkv[2] = s2
    max_seqlen_qkv = img_q.shape[0] + txt_q.shape[0]  # maximum sequence length

    # Compute attention on this rank's head shard
    attn = attention(
        attention_type=attention_type,
        q=q,
        k=k,
        v=v,
        cu_seqlens_q=cu_seqlens_qkv,
        cu_seqlens_kv=cu_seqlens_qkv,
        max_seqlen_q=max_seqlen_qkv,
        max_seqlen_kv=max_seqlen_qkv,
    )

    # Split the attention output back into image and text parts
    img_attn, txt_attn = attn[: img_q.shape[0], :], attn[img_q.shape[0] :, :]

    # Gather the text attention from all ranks
    gathered_txt_attn = [torch.empty_like(txt_attn) for _ in range(world_size)]
    dist.all_gather(gathered_txt_attn, txt_attn)

    # Redistribute the image attention from head shards back to sequence shards
    img_attn = img_attn.reshape(world_size * shard_seqlen, shard_heads, hidden_dims)
    img_attn = all2all_head2seq(img_attn)
    img_attn = img_attn.reshape(shard_seqlen, -1)
    torch.cuda.synchronize()  # make sure the collectives have finished

    txt_attn = torch.cat(gathered_txt_attn, dim=1)  # merge text attention across ranks

    # Concatenate the image and text attention results
    attn = torch.cat([img_attn, txt_attn], dim=0)
    return attn
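The all2all helpers are imported from lightx2v.attentions.distributed.comm.all2all, which is not part of this diff, so their implementation isn't shown. Conceptually, all2all_seq2head trades a sequence shard for a head shard and all2all_head2seq reverses it. A single-process simulation of that assumed shape contract:

import torch

def simulate_seq2head(shards, world_size):
    """Single-process simulation of the assumed all2all_seq2head contract:
    each rank swaps its sequence shard [seq/P, heads, dim] for a head shard
    [seq, heads/P, dim], where P is the world size."""
    heads = shards[0].shape[1]
    shard_heads = heads // world_size
    out = []
    for dst in range(world_size):
        # Rank `dst` receives its head chunk from every rank's sequence shard.
        chunks = [s[:, dst * shard_heads : (dst + 1) * shard_heads, :] for s in shards]
        out.append(torch.cat(chunks, dim=0))
    return out

shards = [torch.randn(16, 8, 4) for _ in range(4)]  # per-rank [seq/P, heads, dim]
outs = simulate_seq2head(shards, world_size=4)
print(outs[0].shape)  # torch.Size([64, 2, 4])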
lightx2v/attentions/distributed/ulysses/wrap.py (new file, mode 100644)

import functools

from lightx2v.attentions.distributed.ulysses.attn import ulysses_attn
from lightx2v.attentions.distributed.utils.process import pre_process, post_process


def parallelize_hunyuan(hunyuan_model):
    """Parallelize the Hunyuan model's inference using Ulysses attention.

    Args:
        hunyuan_model: a Hunyuan model instance exposing an infer method.
    """
    # Swap the model's parallel attention for Ulysses attention
    hunyuan_model.transformer_infer.parallel_attention = ulysses_attn

    # Keep the original inference method so it can still be called
    original_infer = hunyuan_model.infer

    @functools.wraps(hunyuan_model.__class__.infer)  # preserve the original method's metadata
    def new_infer(self, text_encoders_output, args):
        """Wrapped inference: shard the scheduler state, call the original method, gather the output.

        Args:
            self: the Hunyuan model instance
            text_encoders_output: output of the text encoders
            args: other arguments

        Returns:
            None
        """
        # Save the original latents and frequency tensors
        self.scheduler.ori_latents, self.scheduler.ori_freqs_cos, self.scheduler.ori_freqs_sin = (
            self.scheduler.latents,
            self.scheduler.freqs_cos,
            self.scheduler.freqs_sin,
        )
        # Pre-process (shard) the inputs for parallel computation
        self.scheduler.latents, self.scheduler.freqs_cos, self.scheduler.freqs_sin, split_dim = pre_process(
            self.scheduler.latents, self.scheduler.freqs_cos, self.scheduler.freqs_sin
        )
        # Run the original inference method
        output = original_infer(text_encoders_output, args)
        # Post-process (gather) the noise prediction
        self.scheduler.noise_pred = post_process(self.scheduler.noise_pred, split_dim)
        # Restore the original latents and frequency tensors
        self.scheduler.latents, self.scheduler.freqs_cos, self.scheduler.freqs_sin = (
            self.scheduler.ori_latents,
            self.scheduler.ori_freqs_cos,
            self.scheduler.ori_freqs_sin,
        )

    # Bind the new inference method to this model instance and install it
    new_infer = new_infer.__get__(hunyuan_model)
    hunyuan_model.infer = new_infer
lightx2v/attentions/distributed/utils/__init__.py (new file, mode 100644, empty)
lightx2v/attentions/distributed/utils/process.py (new file, mode 100644)

import torch
import torch.distributed as dist


def pre_process(latent_model_input, freqs_cos, freqs_sin):
    """Shard the latent input and frequency tensors for distributed computation.

    Args:
        latent_model_input (torch.Tensor): latent input of shape [batch_size, channels, temporal_size, height, width]
        freqs_cos (torch.Tensor): cosine frequency tensor
        freqs_sin (torch.Tensor): sine frequency tensor

    Returns:
        tuple: the sharded latent_model_input, freqs_cos, freqs_sin, and the split dimension split_dim
    """
    # World size and rank of the current process
    world_size = dist.get_world_size()
    cur_rank = dist.get_rank()

    # Pick the split dimension based on the input shape
    if latent_model_input.shape[-2] // 2 % world_size == 0:
        split_dim = -2  # split along height
    elif latent_model_input.shape[-1] // 2 % world_size == 0:
        split_dim = -1  # split along width
    else:
        raise ValueError(f"Cannot split video sequence into world size ({world_size}) parts evenly")

    # Temporal size and the post-patchify height and width
    temporal_size, h, w = latent_model_input.shape[2], latent_model_input.shape[3] // 2, latent_model_input.shape[4] // 2

    # Shard the latent input along the chosen dimension
    latent_model_input = torch.chunk(latent_model_input, world_size, dim=split_dim)[cur_rank]

    # Shard the cosine frequencies
    dim_thw = freqs_cos.shape[-1]  # last dimension of the frequency tensor
    freqs_cos = freqs_cos.reshape(temporal_size, h, w, dim_thw)  # reshape to [temporal_size, height, width, dim_thw]
    freqs_cos = torch.chunk(freqs_cos, world_size, dim=split_dim - 1)[cur_rank]  # shard along the matching spatial dim
    freqs_cos = freqs_cos.reshape(-1, dim_thw)  # flatten back to [tokens, dim_thw]

    # Shard the sine frequencies
    dim_thw = freqs_sin.shape[-1]
    freqs_sin = freqs_sin.reshape(temporal_size, h, w, dim_thw)
    freqs_sin = torch.chunk(freqs_sin, world_size, dim=split_dim - 1)[cur_rank]
    freqs_sin = freqs_sin.reshape(-1, dim_thw)

    return latent_model_input, freqs_cos, freqs_sin, split_dim


def post_process(output, split_dim):
    """Gather the outputs from all processes and merge them.

    Args:
        output (torch.Tensor): this process's output
        split_dim (int): the dimension along which the input was split

    Returns:
        torch.Tensor: the merged output
    """
    world_size = dist.get_world_size()
    # Gather every process's output
    gathered_outputs = [torch.empty_like(output) for _ in range(world_size)]
    dist.all_gather(gathered_outputs, output)
    # Concatenate along the original split dimension
    combined_output = torch.cat(gathered_outputs, dim=split_dim)
    return combined_output
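pre_process and post_process form a shard/gather pair: whenever the chosen dimension divides evenly by the world size, concatenating the chunks restores the original tensor. A small CPU-only sketch of the split-dimension selection and the round trip (hypothetical shapes, one simulated rank):

import torch

world_size = 4
latent = torch.randn(1, 16, 8, 32, 64)  # [batch, channels, temporal, height, width]

# Mirror pre_process's split-dimension choice on (height // 2, width // 2).
if latent.shape[-2] // 2 % world_size == 0:
    split_dim = -2
elif latent.shape[-1] // 2 % world_size == 0:
    split_dim = -1
else:
    raise ValueError(f"Cannot split video sequence into world size ({world_size}) parts evenly")

shards = torch.chunk(latent, world_size, dim=split_dim)
# post_process's all_gather + cat is the inverse of the chunk above.
assert torch.equal(torch.cat(shards, dim=split_dim), latent)
print(split_dim)  # -2, since 32 // 2 % 4 == 0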
lightx2v/common/__init__.py (new file, mode 100755, empty)
lightx2v/common/backend_infer/trt/common.py (new file, mode 100644)

#
# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import argparse
import os

import tensorrt as trt

from .common_runtime import *

try:
    # Sometimes python does not understand FileNotFoundError
    FileNotFoundError
except NameError:
    FileNotFoundError = IOError


def GiB(val):
    return val * 1 << 30


def add_help(description):
    parser = argparse.ArgumentParser(description=description, formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    args, _ = parser.parse_known_args()


def find_sample_data(description="Runs a TensorRT Python sample", subfolder="", find_files=[], err_msg=""):
    """
    Parses sample arguments.

    Args:
        description (str): Description of the sample.
        subfolder (str): The subfolder containing data relevant to this sample
        find_files (List[str]): A list of filenames to find. Each filename will be replaced with an absolute path.

    Returns:
        str: Path of data directory.
    """
    # Standard command-line arguments for all samples.
    kDEFAULT_DATA_ROOT = os.path.join(os.sep, "usr", "src", "tensorrt", "data")
    parser = argparse.ArgumentParser(description=description, formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument(
        "-d",
        "--datadir",
        help="Location of the TensorRT sample data directory, and any additional data directories.",
        action="append",
        default=[kDEFAULT_DATA_ROOT],
    )
    args, _ = parser.parse_known_args()

    def get_data_path(data_dir):
        # If the subfolder exists, append it to the path, otherwise use the provided path as-is.
        data_path = os.path.join(data_dir, subfolder)
        if not os.path.exists(data_path):
            if data_dir != kDEFAULT_DATA_ROOT:
                print("WARNING: " + data_path + " does not exist. Trying " + data_dir + " instead.")
            data_path = data_dir
        # Make sure data directory exists.
        if not (os.path.exists(data_path)) and data_dir != kDEFAULT_DATA_ROOT:
            print("WARNING: {:} does not exist. Please provide the correct data path with the -d option.".format(data_path))
        return data_path

    data_paths = [get_data_path(data_dir) for data_dir in args.datadir]
    return data_paths, locate_files(data_paths, find_files, err_msg)


def locate_files(data_paths, filenames, err_msg=""):
    """
    Locates the specified files in the specified data directories.
    If a file exists in multiple data directories, the first directory is used.

    Args:
        data_paths (List[str]): The data directories.
        filenames (List[str]): The names of the files to find.

    Returns:
        List[str]: The absolute paths of the files.

    Raises:
        FileNotFoundError if a file could not be located.
    """
    found_files = [None] * len(filenames)
    for data_path in data_paths:
        # Find all requested files.
        for index, (found, filename) in enumerate(zip(found_files, filenames)):
            if not found:
                file_path = os.path.abspath(os.path.join(data_path, filename))
                if os.path.exists(file_path):
                    found_files[index] = file_path

    # Check that all files were found
    for f, filename in zip(found_files, filenames):
        if not f or not os.path.exists(f):
            raise FileNotFoundError("Could not find {:}. Searched in data paths: {:}\n{:}".format(filename, data_paths, err_msg))
    return found_files


# Sets up the builder to use the timing cache file, and creates it if it does not already exist
def setup_timing_cache(config: trt.IBuilderConfig, timing_cache_path: os.PathLike):
    buffer = b""
    if os.path.exists(timing_cache_path):
        with open(timing_cache_path, mode="rb") as timing_cache_file:
            buffer = timing_cache_file.read()
    timing_cache: trt.ITimingCache = config.create_timing_cache(buffer)
    config.set_timing_cache(timing_cache, True)


# Saves the config's timing cache to file
def save_timing_cache(config: trt.IBuilderConfig, timing_cache_path: os.PathLike):
    timing_cache: trt.ITimingCache = config.get_timing_cache()
    with open(timing_cache_path, "wb") as timing_cache_file:
        timing_cache_file.write(memoryview(timing_cache.serialize()))
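A sketch of how the timing-cache helpers are meant to bracket an engine build. The builder/network setup shown is the usual TensorRT boilerplate, the cache path is hypothetical, and the network definition itself is elided:

import tensorrt as trt

logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
config = builder.create_builder_config()

setup_timing_cache(config, "model.timing.cache")  # reuse prior tactic timings if the file exists
# ... define the network's inputs, layers, and outputs here ...
engine_bytes = builder.build_serialized_network(network, config)
save_timing_cache(config, "model.timing.cache")  # persist the timings for faster rebuilds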
lightx2v/common/backend_infer/trt/common_runtime.py (new file, mode 100644)

#
# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import ctypes
from typing import Optional, List, Union

import numpy as np
import tensorrt as trt
from cuda import cuda, cudart


def check_cuda_err(err):
    if isinstance(err, cuda.CUresult):
        if err != cuda.CUresult.CUDA_SUCCESS:
            raise RuntimeError("Cuda Error: {}".format(err))
    elif isinstance(err, cudart.cudaError_t):
        if err != cudart.cudaError_t.cudaSuccess:
            raise RuntimeError("Cuda Runtime Error: {}".format(err))
    else:
        raise RuntimeError("Unknown error type: {}".format(err))


def cuda_call(call):
    err, res = call[0], call[1:]
    check_cuda_err(err)
    if len(res) == 1:
        res = res[0]
    return res


class HostDeviceMem:
    """Pair of host and device memory, where the host memory is wrapped in a numpy array"""

    def __init__(self, size: int, dtype: Optional[np.dtype] = None):
        dtype = dtype or np.dtype(np.uint8)
        nbytes = size * dtype.itemsize
        host_mem = cuda_call(cudart.cudaMallocHost(nbytes))
        pointer_type = ctypes.POINTER(np.ctypeslib.as_ctypes_type(dtype))

        self._host = np.ctypeslib.as_array(ctypes.cast(host_mem, pointer_type), (size,))
        self._device = cuda_call(cudart.cudaMalloc(nbytes))
        self._nbytes = nbytes

    @property
    def host(self) -> np.ndarray:
        return self._host

    @host.setter
    def host(self, data: Union[np.ndarray, bytes]):
        if isinstance(data, np.ndarray):
            if data.size > self.host.size:
                raise ValueError(f"Tried to fit an array of size {data.size} into host memory of size {self.host.size}")
            np.copyto(self.host[: data.size], data.flat, casting="safe")
        else:
            assert self.host.dtype == np.uint8
            self.host[: self.nbytes] = np.frombuffer(data, dtype=np.uint8)

    @property
    def device(self) -> int:
        return self._device

    @property
    def nbytes(self) -> int:
        return self._nbytes

    def __str__(self):
        return f"Host:\n{self.host}\nDevice:\n{self.device}\nSize:\n{self.nbytes}\n"

    def __repr__(self):
        return self.__str__()

    def free(self):
        cuda_call(cudart.cudaFree(self.device))
        cuda_call(cudart.cudaFreeHost(self.host.ctypes.data))


# Allocates all buffers required for an engine, i.e. host/device inputs/outputs.
# If engine uses dynamic shapes, specify a profile to find the maximum input & output size.
def allocate_buffers(engine: trt.ICudaEngine, profile_idx: Optional[int] = None):
    inputs = []
    outputs = []
    bindings = []
    stream = cuda_call(cudart.cudaStreamCreate())
    tensor_names = [engine.get_tensor_name(i) for i in range(engine.num_io_tensors)]
    for binding in tensor_names:
        # get_tensor_profile_shape returns (min_shape, optimal_shape, max_shape)
        # Pick out the max shape to allocate enough memory for the binding.
        shape = engine.get_tensor_shape(binding) if profile_idx is None else engine.get_tensor_profile_shape(binding, profile_idx)[-1]
        shape_valid = np.all([s >= 0 for s in shape])
        if not shape_valid and profile_idx is None:
            raise ValueError(f"Binding {binding} has dynamic shape, " + "but no profile was specified.")
        size = trt.volume(shape)
        trt_type = engine.get_tensor_dtype(binding)

        # Allocate host and device buffers
        try:
            dtype = np.dtype(trt.nptype(trt_type))
            bindingMemory = HostDeviceMem(size, dtype)
        except TypeError:
            # no numpy support: create a byte array instead (BF16, FP8, INT4)
            size = int(size * trt_type.itemsize)
            bindingMemory = HostDeviceMem(size)

        # Append the device buffer to device bindings.
        bindings.append(int(bindingMemory.device))

        # Append to the appropriate list.
        if engine.get_tensor_mode(binding) == trt.TensorIOMode.INPUT:
            inputs.append(bindingMemory)
        else:
            outputs.append(bindingMemory)
    return inputs, outputs, bindings, stream


# Frees the resources allocated in allocate_buffers
def free_buffers(inputs: List[HostDeviceMem], outputs: List[HostDeviceMem], stream: cudart.cudaStream_t):
    for mem in inputs + outputs:
        mem.free()
    cuda_call(cudart.cudaStreamDestroy(stream))


# Wrapper for cudaMemcpy which infers copy size and does error checking
def memcpy_host_to_device(device_ptr: int, host_arr: np.ndarray):
    nbytes = host_arr.size * host_arr.itemsize
    cuda_call(cudart.cudaMemcpy(device_ptr, host_arr, nbytes, cudart.cudaMemcpyKind.cudaMemcpyHostToDevice))


# Wrapper for cudaMemcpy which infers copy size and does error checking
def memcpy_device_to_host(host_arr: np.ndarray, device_ptr: int):
    nbytes = host_arr.size * host_arr.itemsize
    cuda_call(cudart.cudaMemcpy(host_arr, device_ptr, nbytes, cudart.cudaMemcpyKind.cudaMemcpyDeviceToHost))


def _do_inference_base(inputs, outputs, stream, execute_async_func):
    # Transfer input data to the GPU.
    kind = cudart.cudaMemcpyKind.cudaMemcpyHostToDevice
    [cuda_call(cudart.cudaMemcpyAsync(inp.device, inp.host, inp.nbytes, kind, stream)) for inp in inputs]
    # Run inference.
    execute_async_func()
    # Transfer predictions back from the GPU.
    kind = cudart.cudaMemcpyKind.cudaMemcpyDeviceToHost
    [cuda_call(cudart.cudaMemcpyAsync(out.host, out.device, out.nbytes, kind, stream)) for out in outputs]
    # Synchronize the stream
    cuda_call(cudart.cudaStreamSynchronize(stream))
    # Return only the host outputs.
    return [out.host for out in outputs]


# This function is generalized for multiple inputs/outputs.
# inputs and outputs are expected to be lists of HostDeviceMem objects.
def do_inference(context, engine, bindings, inputs, outputs, stream):
    def execute_async_func():
        context.execute_async_v3(stream_handle=stream)

    # Setup context tensor address.
    num_io = engine.num_io_tensors
    for i in range(num_io):
        context.set_tensor_address(engine.get_tensor_name(i), bindings[i])
    return _do_inference_base(inputs, outputs, stream, execute_async_func)
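A sketch of the intended call sequence for these helpers, assuming an already-serialized engine with a float input (the engine path is hypothetical):

import numpy as np
import tensorrt as trt

logger = trt.Logger(trt.Logger.WARNING)
with open("model.engine", "rb") as f:  # hypothetical engine file
    engine = trt.Runtime(logger).deserialize_cuda_engine(f.read())
context = engine.create_execution_context()

inputs, outputs, bindings, stream = allocate_buffers(engine)
# Fill the pinned host buffer; the `host` setter copies into the preallocated array.
inputs[0].host = np.random.rand(*inputs[0].host.shape).astype(inputs[0].host.dtype)
results = do_inference(context, engine, bindings, inputs, outputs, stream)
free_buffers(inputs, outputs, stream)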
lightx2v/common/offload/me_block.py (new file, mode 100644)

import torch
import torch.nn as nn


class MemoryEfficientBlocks(nn.Module):
    def __init__(self, block_class, num_blocks, **block_params):
        super().__init__()
        self.block_class = block_class
        self.num_blocks = num_blocks
        self.block_params = block_params

        # Two resident blocks: one computes while the other loads weights
        self.active_blocks = nn.ModuleList([block_class(**block_params) for _ in range(2)])

        # Dedicated CUDA streams: compute runs at high priority, weight loading at normal priority
        self.compute_stream = torch.cuda.Stream(priority=-1)  # high priority
        self.load_stream = torch.cuda.Stream(priority=0)  # normal priority

        # torch.cuda.empty_cache() only releases cached allocations and returns
        # None, so no pinned buffer is actually retained here
        self.pinned_memory = torch.cuda.empty_cache()
        torch.cuda.memory.set_per_process_memory_fraction(0.8)  # cap GPU memory usage

        # CPU-side buffer holding every block's weights
        self.weight_buffer = []

    def initialize_weights(self, checkpoint, key):
        """Load all block weights into CPU memory."""
        for i in range(self.num_blocks):
            block_weights = {k.replace(f"{key}.{i}.", ""): v for k, v in checkpoint.items() if f"{key}.{i}." in k}
            self.weight_buffer.append(block_weights)

    def prefetch_weights(self, block_idx):
        """Preload the next block's weights on the dedicated load stream."""
        with torch.cuda.stream(self.load_stream):
            next_weights = self.weight_buffer[block_idx]
            next_weights = {k: v.cuda(non_blocking=True) for k, v in next_weights.items()}
            self.active_blocks[1].load_state_dict(next_weights)

    def swap_blocks(self):
        """Swap the two blocks once compute and load have both finished."""
        self.compute_stream.synchronize()  # wait for computation to finish
        self.load_stream.synchronize()  # wait for weight loading to finish
        self.active_blocks[0], self.active_blocks[1] = self.active_blocks[1], self.active_blocks[0]

    def forward(self, *args, **kwargs):
        """Forward pass that overlaps computation with weight loading."""
        for i in range(self.num_blocks):
            if i == 0:
                self.active_blocks[0].load_state_dict(self.weight_buffer[0])
            # Run the current block on the main compute stream
            with torch.cuda.stream(self.compute_stream):
                current_block = self.active_blocks[0]
                outputs = current_block(*args, **kwargs)
            # Prefetch the next block's weights on the load stream
            if i < self.num_blocks - 1:
                self.prefetch_weights(i + 1)
            # Swap blocks and pick up the prefetched weights
            self.swap_blocks()
            # Feed the current outputs back in as the next block's inputs
            args = list(args)
            if len(outputs) == 1:
                args[0] = outputs[0]
            else:
                for j in range(len(outputs)):
                    args[j] = outputs[j]
            args = tuple(args)
        return outputs
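A sketch of how this double-buffered container would be driven end to end. The Block class and the flat checkpoint layout (keys like "blocks.0.linear.weight") are hypothetical stand-ins; real usage would pass the actual transformer block class and checkpoint, and a GPU is required since the streams are created at construction time:

import torch
import torch.nn as nn

class Block(nn.Module):
    """Hypothetical stand-in for a transformer block; returns a 1-tuple."""
    def __init__(self, dim):
        super().__init__()
        self.linear = nn.Linear(dim, dim)

    def forward(self, x):
        return (self.linear(x),)

num_blocks, dim = 4, 64
blocks = MemoryEfficientBlocks(Block, num_blocks, dim=dim).cuda()

# Hypothetical flat checkpoint covering all blocks.
ckpt = {}
for i in range(num_blocks):
    for k, v in Block(dim).state_dict().items():
        ckpt[f"blocks.{i}.{k}"] = v
blocks.initialize_weights(ckpt, key="blocks")

out, = blocks(torch.randn(2, dim).cuda())
print(out.shape)  # torch.Size([2, 64])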
lightx2v/common/ops/__init__.py (new file, mode 100755)

from .mm import *
from .norm import *
from .conv import *
lightx2v/common/ops/conv/__init__.py (new file, mode 100755)

from .conv2d import *
from .conv3d import *
lightx2v/common/ops/conv/conv2d.py (new file, mode 100644)

import torch
from abc import ABCMeta, abstractmethod

from lightx2v.utils.registry_factory import CONV2D_WEIGHT_REGISTER


class Conv2dWeightTemplate(metaclass=ABCMeta):
    def __init__(self, weight_name, bias_name, stride, padding, dilation, groups):
        self.weight_name = weight_name
        self.bias_name = bias_name
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        self.config = {}

    @abstractmethod
    def load(self, weight_dict):
        pass

    @abstractmethod
    def apply(self, input_tensor):
        pass

    def set_config(self, config=None):
        if config is not None:
            self.config = config


@CONV2D_WEIGHT_REGISTER('Default')
class Conv2dWeight(Conv2dWeightTemplate):
    def __init__(self, weight_name, bias_name, stride=1, padding=0, dilation=1, groups=1):
        super().__init__(weight_name, bias_name, stride, padding, dilation, groups)

    def load(self, weight_dict):
        self.weight = weight_dict[self.weight_name].cuda()
        self.bias = weight_dict[self.bias_name].cuda() if self.bias_name is not None else None

    def apply(self, input_tensor):
        input_tensor = torch.nn.functional.conv2d(
            input_tensor,
            weight=self.weight,
            bias=self.bias,
            stride=self.stride,
            padding=self.padding,
            dilation=self.dilation,
            groups=self.groups,
        )
        return input_tensor

    def to_cpu(self):
        self.weight = self.weight.cpu()
        if self.bias is not None:
            self.bias = self.bias.cpu()

    def to_cuda(self):
        self.weight = self.weight.cuda()
        if self.bias is not None:
            self.bias = self.bias.cuda()
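A usage sketch for the registered wrapper, assuming CONV2D_WEIGHT_REGISTER supports dict-style lookup the same way MM_WEIGHT_REGISTER does in mm_weight.py's __main__ block below; the weight names are hypothetical:

import torch
from lightx2v.utils.registry_factory import CONV2D_WEIGHT_REGISTER

# Hypothetical state-dict entries for one conv layer.
weight_dict = {
    "conv.weight": torch.randn(16, 3, 3, 3),
    "conv.bias": torch.randn(16),
}
conv = CONV2D_WEIGHT_REGISTER["Default"]("conv.weight", "conv.bias", stride=1, padding=1)
conv.load(weight_dict)  # moves weight and bias onto the GPU
y = conv.apply(torch.randn(1, 3, 32, 32).cuda())
print(y.shape)  # torch.Size([1, 16, 32, 32])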
lightx2v/common/ops/conv/conv3d.py (new file, mode 100644)

import torch
from abc import ABCMeta, abstractmethod

from lightx2v.utils.registry_factory import CONV3D_WEIGHT_REGISTER


class Conv3dWeightTemplate(metaclass=ABCMeta):
    def __init__(self, weight_name, bias_name, stride=1, padding=0, dilation=1, groups=1):
        self.weight_name = weight_name
        self.bias_name = bias_name
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        self.config = {}

    @abstractmethod
    def load(self, weight_dict):
        pass

    @abstractmethod
    def apply(self, input_tensor):
        pass

    def set_config(self, config=None):
        if config is not None:
            self.config = config


@CONV3D_WEIGHT_REGISTER('Default')
class Conv3dWeight(Conv3dWeightTemplate):
    def __init__(self, weight_name, bias_name, stride=1, padding=0, dilation=1, groups=1):
        super().__init__(weight_name, bias_name, stride, padding, dilation, groups)

    def load(self, weight_dict):
        self.weight = weight_dict[self.weight_name].cuda()
        self.bias = weight_dict[self.bias_name].cuda() if self.bias_name is not None else None

    def apply(self, input_tensor):
        input_tensor = torch.nn.functional.conv3d(
            input_tensor,
            weight=self.weight,
            bias=self.bias,
            stride=self.stride,
            padding=self.padding,
            dilation=self.dilation,
            groups=self.groups,
        )
        return input_tensor

    def to_cpu(self):
        self.weight = self.weight.cpu()
        if self.bias is not None:
            self.bias = self.bias.cpu()

    def to_cuda(self):
        self.weight = self.weight.cuda()
        if self.bias is not None:
            self.bias = self.bias.cuda()
lightx2v/common/ops/mm/__init__.py (new file, mode 100755)

from .mm_weight import *
from .mm_weight_calib import *
lightx2v/common/ops/mm/mm_weight.py (new file, mode 100755)

import torch
from abc import ABCMeta, abstractmethod
from vllm import _custom_ops as ops

from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER
from lightx2v.utils.quant_utils import IntegerQuantizer, FloatQuantizer


class MMWeightTemplate(metaclass=ABCMeta):
    def __init__(self, weight_name, bias_name):
        self.weight_name = weight_name
        self.bias_name = bias_name
        self.config = {}

    @abstractmethod
    def load(self, weight_dict):
        pass

    @abstractmethod
    def apply(self, input_tensor):
        pass

    def set_config(self, config=None):
        if config is not None:
            self.config = config


@MM_WEIGHT_REGISTER('Default')
class MMWeight(MMWeightTemplate):
    def __init__(self, weight_name, bias_name):
        super().__init__(weight_name, bias_name)

    def load(self, weight_dict):
        self.weight = weight_dict[self.weight_name].t().cuda()
        self.bias = weight_dict[self.bias_name].cuda() if self.bias_name is not None else None

    def apply(self, input_tensor):
        shape = (input_tensor.shape[0], self.weight.shape[1])
        dtype = input_tensor.dtype
        device = input_tensor.device
        output_tensor = torch.empty(shape, dtype=dtype, device=device, requires_grad=False)
        if self.bias is None:
            return torch.mm(input_tensor, self.weight, out=output_tensor)
        return torch.addmm(self.bias, input_tensor, self.weight, out=output_tensor)

    def to_cpu(self):
        self.weight = self.weight.cpu()
        if self.bias is not None:
            self.bias = self.bias.cpu()

    def to_cuda(self):
        self.weight = self.weight.cuda()
        if self.bias is not None:
            self.bias = self.bias.cuda()


@MM_WEIGHT_REGISTER('Default-Force-FP32')
class MMWeightForceFP32(MMWeight):
    def __init__(self, weight_name, bias_name):
        super().__init__(weight_name, bias_name)

    def load(self, weight_dict):
        super().load(weight_dict)
        self.weight = self.weight.to(torch.float32)
        if self.bias is not None:
            self.bias = self.bias.to(torch.float32)


@MM_WEIGHT_REGISTER('W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm')
class MMWeightWfp8channelAfp8channeldynamicVllm(MMWeightTemplate):
    '''
    Name: W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm
    Quant MM:
        Weight: fp8 perchannel sym
        Act: fp8 perchannel dynamic sym
        Kernel: vllm
    '''

    def __init__(self, weight_name, bias_name):
        super().__init__(weight_name, bias_name)

    def load(self, weight_dict):
        if self.config.get('weight_auto_quant', True):
            self.weight = weight_dict[self.weight_name].to(torch.float32).cuda()
            w_quantizer = FloatQuantizer('e4m3', True, 'channel')
            self.weight, self.weight_scale, _ = w_quantizer.real_quant_tensor(self.weight)
            self.weight = self.weight.to(torch.float8_e4m3fn).t().cuda()
            self.weight_scale = self.weight_scale.to(torch.float32).cuda()
        else:
            self.weight = weight_dict[self.weight_name].t().cuda()
            self.weight_scale = weight_dict[self.weight_name.removesuffix(".weight") + '.weight_scale'].cuda()
        self.bias = weight_dict[self.bias_name].cuda() if self.bias_name is not None else None

    def apply(self, input_tensor):
        shape = (input_tensor.shape[0], self.weight.shape[1])
        dtype = input_tensor.dtype
        device = input_tensor.device
        output_tensor = torch.empty(shape, dtype=dtype, device=device, requires_grad=False)
        qinput, x_scale = ops.scaled_fp8_quant(input_tensor, None, scale_ub=None, use_per_token_if_dynamic=True)
        torch.ops._C.cutlass_scaled_mm(output_tensor, qinput, self.weight, x_scale, self.weight_scale, self.bias)
        return output_tensor

    def to_cpu(self):
        self.weight = self.weight.cpu()
        self.weight_scale = self.weight_scale.cpu()
        if self.bias is not None:
            self.bias = self.bias.cpu()

    def to_cuda(self):
        self.weight = self.weight.cuda()
        self.weight_scale = self.weight_scale.cuda()
        if self.bias is not None:
            self.bias = self.bias.cuda()


@MM_WEIGHT_REGISTER('W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm')
class MMWeightWint8channelAint8channeldynamicVllm(MMWeightTemplate):
    '''
    Name: W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm
    Quant MM:
        Weight: int8 perchannel sym
        Act: int8 perchannel dynamic sym
        Kernel: vllm
    '''

    def __init__(self, weight_name, bias_name):
        super().__init__(weight_name, bias_name)

    def load(self, weight_dict):
        if self.config.get('weight_auto_quant', True):
            self.weight = weight_dict[self.weight_name].to(torch.float32).cuda()
            w_quantizer = IntegerQuantizer(8, True, 'channel')
            self.weight, self.weight_scale, _ = w_quantizer.real_quant_tensor(self.weight)
            self.weight = self.weight.to(torch.int8).t().cuda()
            self.weight_scale = self.weight_scale.to(torch.float32).cuda()
        else:
            self.weight = weight_dict[self.weight_name].t().cuda()
            self.weight_scale = weight_dict[self.weight_name.removesuffix(".weight") + '.weight_scale'].cuda()
        self.bias = weight_dict[self.bias_name].cuda() if self.bias_name is not None else None

    def apply(self, input_tensor):
        shape = (input_tensor.shape[0], self.weight.shape[1])
        dtype = input_tensor.dtype
        device = input_tensor.device
        output_tensor = torch.empty(shape, dtype=dtype, device=device, requires_grad=False)
        qinput, x_scale, _ = ops.scaled_int8_quant(input_tensor, scale=None, azp=None, symmetric=True)
        torch.ops._C.cutlass_scaled_mm(output_tensor, qinput, self.weight, x_scale, self.weight_scale, self.bias)
        return output_tensor

    def to_cpu(self):
        self.weight = self.weight.cpu()
        self.weight_scale = self.weight_scale.cpu()
        if self.bias is not None:
            self.bias = self.bias.cpu()

    def to_cuda(self):
        self.weight = self.weight.cuda()
        self.weight_scale = self.weight_scale.cuda()
        if self.bias is not None:
            self.bias = self.bias.cuda()


if __name__ == '__main__':
    # Pre-quantized weights: load fp8 weight and scale directly
    weight_dict = {
        'xx.weight': torch.randn(8192, 4096).to(torch.float8_e4m3fn),
        'xx.bias': torch.randn(8192).to(torch.bfloat16),
        'xx.weight_scale': torch.randn(8192, 1).to(torch.float32),
    }
    mm_weight = MM_WEIGHT_REGISTER['W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm']('xx.weight', 'xx.bias')
    mm_weight.set_config({'weight_auto_quant': False})
    mm_weight.load(weight_dict)
    input_tensor = torch.randn(1024, 4096).to(torch.bfloat16).cuda()
    output_tensor = mm_weight.apply(input_tensor)
    print(output_tensor.shape)

    # Auto-quantize a float weight to fp8 at load time
    weight_dict = {
        'xx.weight': torch.randn(8192, 4096),
        'xx.bias': torch.randn(8192).to(torch.bfloat16),
    }
    mm_weight = MM_WEIGHT_REGISTER['W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm']('xx.weight', 'xx.bias')
    mm_weight.set_config({'weight_auto_quant': True})
    mm_weight.load(weight_dict)
    input_tensor = torch.randn(1024, 4096).to(torch.bfloat16).cuda()
    output_tensor = mm_weight.apply(input_tensor)
    print(output_tensor.shape)

    # Auto-quantize a float weight to int8 at load time
    weight_dict = {
        'xx.weight': torch.randn(8192, 4096),
        'xx.bias': torch.randn(8192).to(torch.bfloat16),
    }
    mm_weight = MM_WEIGHT_REGISTER['W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm']('xx.weight', 'xx.bias')
    mm_weight.set_config({'weight_auto_quant': True})
    mm_weight.load(weight_dict)
    input_tensor = torch.randn(1024, 4096).to(torch.bfloat16).cuda()
    output_tensor = mm_weight.apply(input_tensor)
    print(output_tensor.shape)
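For reference, a minimal per-channel symmetric int8 quantization sketch of the kind the weight_auto_quant path delegates to IntegerQuantizer. That class is not part of this diff, so this is an assumed reference implementation, not the library's:

import torch

def quantize_int8_per_channel_sym(w):
    # Assumed reference for per-channel symmetric int8 weight quantization:
    # one scale per output channel, zero-point fixed at 0.
    # w: [out_features, in_features]
    scale = w.abs().amax(dim=1, keepdim=True) / 127.0
    q = torch.clamp(torch.round(w / scale), -128, 127).to(torch.int8)
    return q, scale

w = torch.randn(8192, 4096)
q, scale = quantize_int8_per_channel_sym(w)
dequant = q.to(torch.float32) * scale
print((w - dequant).abs().max())  # small round-trip quantization error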