TS-MODELS-OPT / training / Video-Generation-Model / Commits / c07946d8

Commit c07946d8 authored Apr 09, 2026 by hepj

dit & video

Changes: 270

Showing 20 changed files with 5607 additions and 0 deletions (+5607, -0)
FastVideo-main/fastvideo/v1/logging_utils/formatter.py  +18 -0
FastVideo-main/fastvideo/v1/models/__init__.py  +0 -0
FastVideo-main/fastvideo/v1/models/dits/base.py  +65 -0
FastVideo-main/fastvideo/v1/models/dits/hunyuanvideo.py  +836 -0
FastVideo-main/fastvideo/v1/models/dits/wanvideo.py  +489 -0
FastVideo-main/fastvideo/v1/models/encoders/base.py  +59 -0
FastVideo-main/fastvideo/v1/models/encoders/clip.py  +639 -0
FastVideo-main/fastvideo/v1/models/encoders/llama.py  +437 -0
FastVideo-main/fastvideo/v1/models/encoders/t5.py  +676 -0
FastVideo-main/fastvideo/v1/models/encoders/vision.py  +92 -0
FastVideo-main/fastvideo/v1/models/hf_transformer_utils.py  +152 -0
FastVideo-main/fastvideo/v1/models/loader/__init__.py  +0 -0
FastVideo-main/fastvideo/v1/models/loader/component_loader.py  +506 -0
FastVideo-main/fastvideo/v1/models/loader/fsdp_load.py  +254 -0
FastVideo-main/fastvideo/v1/models/loader/utils.py  +18 -0
FastVideo-main/fastvideo/v1/models/loader/weight_utils.py  +341 -0
FastVideo-main/fastvideo/v1/models/parameter.py  +410 -0
FastVideo-main/fastvideo/v1/models/registry.py  +307 -0
FastVideo-main/fastvideo/v1/models/schedulers/base.py  +46 -0
FastVideo-main/fastvideo/v1/models/schedulers/scheduling_flow_match_euler_discrete.py  +262 -0
FastVideo-main/fastvideo/v1/logging_utils/formatter.py
0 → 100644
# SPDX-License-Identifier: Apache-2.0
# adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/logging_utils/formatter.py
import logging


class NewLineFormatter(logging.Formatter):
    """Adds logging prefix to newlines to align multi-line messages."""

    def __init__(self, fmt, datefmt=None, style="%"):
        logging.Formatter.__init__(self, fmt, datefmt, style)

    def format(self, record):
        msg = logging.Formatter.format(self, record)
        if record.message != "":
            parts = msg.split(record.message)
            msg = msg.replace("\n", "\r\n" + parts[0])
        return msg
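A minimal usage sketch for NewLineFormatter (not part of the commit; it only assumes the Python standard library): attach the formatter to a handler so that every continuation line of a multi-line message is re-prefixed with the same log prefix as the first line. The format string and logger name below are arbitrary.

import logging
import sys

# Hypothetical wiring example for the formatter defined above.
handler = logging.StreamHandler(sys.stdout)
handler.setFormatter(NewLineFormatter("%(levelname)s %(asctime)s | %(message)s"))
log = logging.getLogger("fastvideo.demo")
log.addHandler(handler)
log.setLevel(logging.INFO)
log.info("first line\nsecond line")  # the second line is re-prefixed like the first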
FastVideo-main/fastvideo/v1/models/__init__.py
0 → 100644
FastVideo-main/fastvideo/v1/models/dits/base.py
0 → 100644
# SPDX-License-Identifier: Apache-2.0
from abc import ABC, abstractmethod
from typing import List, Optional, Tuple, Union

import torch
from torch import nn

from fastvideo.v1.configs.models import DiTConfig
from fastvideo.v1.platforms import _Backend


# TODO
class BaseDiT(nn.Module, ABC):
    _fsdp_shard_conditions: list = []
    _param_names_mapping: dict
    hidden_size: int
    num_attention_heads: int
    num_channels_latents: int
    # always supports torch_sdpa
    _supported_attention_backends: Tuple[
        _Backend, ...] = DiTConfig()._supported_attention_backends

    def __init_subclass__(cls) -> None:
        required_class_attrs = ["_fsdp_shard_conditions", "_param_names_mapping"]
        super().__init_subclass__()
        for attr in required_class_attrs:
            if not hasattr(cls, attr):
                raise AttributeError(
                    f"Subclasses of BaseDiT must define '{attr}' class variable")

    def __init__(self, config: DiTConfig, **kwargs) -> None:
        super().__init__()
        self.config = config
        if not self.supported_attention_backends:
            raise ValueError(
                f"Subclass {self.__class__.__name__} must define _supported_attention_backends")

    @abstractmethod
    def forward(self,
                hidden_states: torch.Tensor,
                encoder_hidden_states: Union[torch.Tensor, List[torch.Tensor]],
                timestep: torch.LongTensor,
                encoder_hidden_states_image: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None,
                guidance=None,
                **kwargs) -> torch.Tensor:
        pass

    def __post_init__(self) -> None:
        required_attrs = ["hidden_size", "num_attention_heads", "num_channels_latents"]
        for attr in required_attrs:
            if not hasattr(self, attr):
                raise AttributeError(
                    f"Subclasses of BaseDiT must define '{attr}' instance variable")

    @property
    def supported_attention_backends(self) -> Tuple[_Backend, ...]:
        return self._supported_attention_backends
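As a rough illustration of the contract BaseDiT enforces, the sketch below shows the minimum a subclass has to declare: both class variables checked by __init_subclass__, the three instance attributes checked by __post_init__, and a concrete forward. It is illustrative only; ToyDiT, its layer sizes, and the assumption that a DiTConfig can simply be passed through are placeholders, not part of the commit.

# Illustrative sketch only; layer choices and sizes are placeholders.
class ToyDiT(BaseDiT):
    _fsdp_shard_conditions: list = []   # required class attribute
    _param_names_mapping: dict = {}     # required class attribute

    def __init__(self, config: DiTConfig) -> None:
        super().__init__(config=config)
        self.hidden_size = 128           # required instance attributes,
        self.num_attention_heads = 8     # verified by __post_init__()
        self.num_channels_latents = 4
        self.proj = nn.Linear(self.hidden_size, self.hidden_size)
        self.__post_init__()

    def forward(self, hidden_states, encoder_hidden_states, timestep,
                encoder_hidden_states_image=None, guidance=None, **kwargs):
        return self.proj(hidden_states)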
FastVideo-main/fastvideo/v1/models/dits/hunyuanvideo.py
0 → 100644
# SPDX-License-Identifier: Apache-2.0
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn

from fastvideo.v1.attention import DistributedAttention, LocalAttention
from fastvideo.v1.configs.models.dits import HunyuanVideoConfig
from fastvideo.v1.distributed.parallel_state import (
    get_sequence_model_parallel_world_size)
from fastvideo.v1.layers.layernorm import (LayerNormScaleShift, ScaleResidual,
                                           ScaleResidualLayerNormScaleShift)
from fastvideo.v1.layers.linear import ReplicatedLinear
# TODO(will-PY-refactor): RMSNorm ....
from fastvideo.v1.layers.mlp import MLP
from fastvideo.v1.layers.rotary_embedding import (_apply_rotary_emb,
                                                  get_rotary_pos_embed)
from fastvideo.v1.layers.visual_embedding import (ModulateProjection, PatchEmbed,
                                                  TimestepEmbedder, unpatchify)
from fastvideo.v1.models.dits.base import BaseDiT
from fastvideo.v1.platforms import _Backend


class HunyuanRMSNorm(nn.Module):

    def __init__(
        self,
        dim: int,
        elementwise_affine=True,
        eps: float = 1e-6,
        device=None,
        dtype=None,
    ):
        """
        Initialize the RMSNorm normalization layer.

        Args:
            dim (int): The dimension of the input tensor.
            eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.

        Attributes:
            eps (float): A small value added to the denominator for numerical stability.
            weight (nn.Parameter): Learnable scaling parameter.
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.eps = eps
        if elementwise_affine:
            self.weight = nn.Parameter(torch.ones(dim, **factory_kwargs))

    def _norm(self, x) -> torch.Tensor:
        """
        Apply the RMSNorm normalization to the input tensor.

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: The normalized tensor.
        """
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        """
        Forward pass through the RMSNorm layer.

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: The output tensor after applying RMSNorm.
        """
        output = self._norm(x.float()).type_as(x)
        if hasattr(self, "weight"):
            output = output * self.weight
        return output
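A standalone check of the normalization HunyuanRMSNorm computes, using plain torch only (shapes and tolerance are arbitrary illustration values): each feature vector is scaled by the reciprocal of its root-mean-square, so the normalized output has approximately unit RMS before the optional learned weight is applied.

import torch

# Sketch of the RMSNorm math used in HunyuanRMSNorm._norm above.
x = torch.randn(2, 16, 64)
eps = 1e-6
normed = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
rms = normed.pow(2).mean(-1).sqrt()
assert torch.allclose(rms, torch.ones_like(rms), atol=1e-3)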
class MMDoubleStreamBlock(nn.Module):
    """
    A multimodal DiT block with separate modulation for text and image/video,
    using distributed attention and linear layers.
    """

    def __init__(
        self,
        hidden_size: int,
        num_attention_heads: int,
        mlp_ratio: float,
        dtype: Optional[torch.dtype] = None,
        supported_attention_backends: Optional[Tuple[_Backend, ...]] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.deterministic = False
        self.num_attention_heads = num_attention_heads
        head_dim = hidden_size // num_attention_heads
        mlp_hidden_dim = int(hidden_size * mlp_ratio)

        # Image modulation components
        self.img_mod = ModulateProjection(hidden_size, factor=6, act_layer="silu",
                                          dtype=dtype, prefix=f"{prefix}.img_mod")

        # Fused operations for image stream
        self.img_attn_norm = LayerNormScaleShift(hidden_size, norm_type="layer",
                                                 elementwise_affine=False, dtype=dtype)
        self.img_attn_residual_mlp_norm = ScaleResidualLayerNormScaleShift(
            hidden_size, norm_type="layer", elementwise_affine=False, dtype=dtype)
        self.img_mlp_residual = ScaleResidual()

        # Image attention components
        self.img_attn_qkv = ReplicatedLinear(hidden_size, hidden_size * 3, bias=True,
                                             params_dtype=dtype,
                                             prefix=f"{prefix}.img_attn_qkv")
        self.img_attn_q_norm = HunyuanRMSNorm(head_dim, eps=1e-6, dtype=dtype)
        self.img_attn_k_norm = HunyuanRMSNorm(head_dim, eps=1e-6, dtype=dtype)
        self.img_attn_proj = ReplicatedLinear(hidden_size, hidden_size, bias=True,
                                              params_dtype=dtype,
                                              prefix=f"{prefix}.img_attn_proj")
        self.img_mlp = MLP(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype,
                           prefix=f"{prefix}.img_mlp")

        # Text modulation components
        self.txt_mod = ModulateProjection(hidden_size, factor=6, act_layer="silu",
                                          dtype=dtype, prefix=f"{prefix}.txt_mod")

        # Fused operations for text stream
        self.txt_attn_norm = LayerNormScaleShift(hidden_size, norm_type="layer",
                                                 elementwise_affine=False, dtype=dtype)
        self.txt_attn_residual_mlp_norm = ScaleResidualLayerNormScaleShift(
            hidden_size, norm_type="layer", elementwise_affine=False, dtype=dtype)
        self.txt_mlp_residual = ScaleResidual()

        # Text attention components
        self.txt_attn_qkv = ReplicatedLinear(hidden_size, hidden_size * 3, bias=True,
                                             params_dtype=dtype)
        # QK norm layers for text
        self.txt_attn_q_norm = HunyuanRMSNorm(head_dim, eps=1e-6, dtype=dtype)
        self.txt_attn_k_norm = HunyuanRMSNorm(head_dim, eps=1e-6, dtype=dtype)
        self.txt_attn_proj = ReplicatedLinear(hidden_size, hidden_size, bias=True,
                                              params_dtype=dtype)
        self.txt_mlp = MLP(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype)

        # Distributed attention
        self.attn = DistributedAttention(
            num_heads=num_attention_heads,
            head_size=head_dim,
            causal=False,
            supported_attention_backends=supported_attention_backends,
            prefix=f"{prefix}.attn")

    def forward(
        self,
        img: torch.Tensor,
        txt: torch.Tensor,
        vec: torch.Tensor,
        freqs_cis: tuple,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Process modulation vectors
        img_mod_outputs = self.img_mod(vec)
        (img_attn_shift, img_attn_scale, img_attn_gate,
         img_mlp_shift, img_mlp_scale, img_mlp_gate) = torch.chunk(
             img_mod_outputs, 6, dim=-1)

        txt_mod_outputs = self.txt_mod(vec)
        (txt_attn_shift, txt_attn_scale, txt_attn_gate,
         txt_mlp_shift, txt_mlp_scale, txt_mlp_gate) = torch.chunk(
             txt_mod_outputs, 6, dim=-1)

        # Prepare image for attention using fused operation
        img_attn_input = self.img_attn_norm(img, img_attn_shift, img_attn_scale)

        # Get QKV for image
        img_qkv, _ = self.img_attn_qkv(img_attn_input)
        batch_size, image_seq_len = img_qkv.shape[0], img_qkv.shape[1]

        # Split QKV
        img_qkv = img_qkv.view(batch_size, image_seq_len, 3,
                               self.num_attention_heads, -1)
        img_q, img_k, img_v = img_qkv[:, :, 0], img_qkv[:, :, 1], img_qkv[:, :, 2]

        # Apply QK-Norm if needed
        img_q = self.img_attn_q_norm(img_q).to(img_v)
        img_k = self.img_attn_k_norm(img_k).to(img_v)

        # Apply rotary embeddings
        cos, sin = freqs_cis
        img_q, img_k = _apply_rotary_emb(
            img_q, cos, sin, is_neox_style=False), _apply_rotary_emb(
                img_k, cos, sin, is_neox_style=False)

        # Prepare text for attention using fused operation
        txt_attn_input = self.txt_attn_norm(txt, txt_attn_shift, txt_attn_scale)

        # Get QKV for text
        txt_qkv, _ = self.txt_attn_qkv(txt_attn_input)
        batch_size, text_seq_len = txt_qkv.shape[0], txt_qkv.shape[1]

        # Split QKV
        txt_qkv = txt_qkv.view(batch_size, text_seq_len, 3,
                               self.num_attention_heads, -1)
        txt_q, txt_k, txt_v = txt_qkv[:, :, 0], txt_qkv[:, :, 1], txt_qkv[:, :, 2]

        # Apply QK-Norm if needed
        txt_q = self.txt_attn_q_norm(txt_q).to(txt_q.dtype)
        txt_k = self.txt_attn_k_norm(txt_k).to(txt_k.dtype)

        # Run distributed attention
        img_attn, txt_attn = self.attn(img_q, img_k, img_v, txt_q, txt_k, txt_v)

        img_attn_out, _ = self.img_attn_proj(
            img_attn.view(batch_size, image_seq_len, -1))

        # Use fused operation for residual connection, normalization, and modulation
        img_mlp_input, img_residual = self.img_attn_residual_mlp_norm(
            img, img_attn_out, img_attn_gate, img_mlp_shift, img_mlp_scale)

        # Process image MLP
        img_mlp_out = self.img_mlp(img_mlp_input)
        img = self.img_mlp_residual(img_residual, img_mlp_out, img_mlp_gate)

        # Process text attention output
        txt_attn_out, _ = self.txt_attn_proj(
            txt_attn.reshape(batch_size, text_seq_len, -1))

        # Use fused operation for residual connection, normalization, and modulation
        txt_mlp_input, txt_residual = self.txt_attn_residual_mlp_norm(
            txt, txt_attn_out, txt_attn_gate, txt_mlp_shift, txt_mlp_scale)

        # Process text MLP
        txt_mlp_out = self.txt_mlp(txt_mlp_input)
        txt = self.txt_mlp_residual(txt_residual, txt_mlp_out, txt_mlp_gate)

        return img, txt
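A shape sketch of the 6-way modulation split used at the top of MMDoubleStreamBlock.forward, in plain torch (the sizes are made-up examples, not config values): ModulateProjection with factor=6 emits one shift/scale/gate triple for the attention sub-layer and one for the MLP sub-layer of each stream.

import torch

hidden_size, batch = 32, 2
mod = torch.randn(batch, 6 * hidden_size)           # stand-in for img_mod(vec)
(attn_shift, attn_scale, attn_gate,
 mlp_shift, mlp_scale, mlp_gate) = torch.chunk(mod, 6, dim=-1)
assert attn_shift.shape == (batch, hidden_size)     # one triple per sub-layer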
class MMSingleStreamBlock(nn.Module):
    """
    A DiT block with parallel linear layers using distributed attention
    and tensor parallelism.
    """

    def __init__(
        self,
        hidden_size: int,
        num_attention_heads: int,
        mlp_ratio: float = 4.0,
        dtype: Optional[torch.dtype] = None,
        supported_attention_backends: Optional[Tuple[_Backend, ...]] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.deterministic = False
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads
        head_dim = hidden_size // num_attention_heads
        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        self.mlp_hidden_dim = mlp_hidden_dim

        # Combined QKV and MLP input projection
        self.linear1 = ReplicatedLinear(hidden_size, hidden_size * 3 + mlp_hidden_dim,
                                        bias=True, params_dtype=dtype,
                                        prefix=f"{prefix}.linear1")

        # Combined projection and MLP output
        self.linear2 = ReplicatedLinear(hidden_size + mlp_hidden_dim, hidden_size,
                                        bias=True, params_dtype=dtype,
                                        prefix=f"{prefix}.linear2")

        # QK norm layers
        self.q_norm = HunyuanRMSNorm(head_dim, eps=1e-6, dtype=dtype)
        self.k_norm = HunyuanRMSNorm(head_dim, eps=1e-6, dtype=dtype)

        # Fused operations with better naming
        self.input_norm_scale_shift = LayerNormScaleShift(hidden_size,
                                                          norm_type="layer",
                                                          eps=1e-6,
                                                          elementwise_affine=False,
                                                          dtype=dtype)
        self.output_residual = ScaleResidual()

        # Activation function
        self.mlp_act = nn.GELU(approximate="tanh")

        # Modulation
        self.modulation = ModulateProjection(hidden_size, factor=3, act_layer="silu",
                                             dtype=dtype,
                                             prefix=f"{prefix}.modulation")

        # Distributed attention
        self.attn = DistributedAttention(
            num_heads=num_attention_heads,
            head_size=head_dim,
            causal=False,
            supported_attention_backends=supported_attention_backends,
            prefix=f"{prefix}.attn")

    def forward(
        self,
        x: torch.Tensor,
        vec: torch.Tensor,
        txt_len: int,
        freqs_cis: Tuple[torch.Tensor, torch.Tensor],
    ) -> torch.Tensor:
        # Process modulation
        mod_shift, mod_scale, mod_gate = self.modulation(vec).chunk(3, dim=-1)

        # Apply pre-norm and modulation using fused operation
        x_mod = self.input_norm_scale_shift(x, mod_shift, mod_scale)

        # Get combined projections
        linear1_out, _ = self.linear1(x_mod)

        # Split into QKV and MLP parts
        qkv, mlp = torch.split(linear1_out,
                               [3 * self.hidden_size, self.mlp_hidden_dim],
                               dim=-1)

        # Process QKV
        batch_size, seq_len = qkv.shape[0], qkv.shape[1]
        qkv = qkv.view(batch_size, seq_len, 3, self.num_attention_heads, -1)
        q, k, v = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]

        # Apply QK-Norm
        q = self.q_norm(q).to(v.dtype)
        k = self.k_norm(k).to(v.dtype)

        # Split into image and text parts
        img_q, txt_q = q[:, :-txt_len], q[:, -txt_len:]
        img_k, txt_k = k[:, :-txt_len], k[:, -txt_len:]
        img_v, txt_v = v[:, :-txt_len], v[:, -txt_len:]

        # Apply rotary embeddings to image parts
        cos, sin = freqs_cis
        img_q, img_k = _apply_rotary_emb(
            img_q, cos, sin, is_neox_style=False), _apply_rotary_emb(
                img_k, cos, sin, is_neox_style=False)

        # Run distributed attention
        img_attn_output, txt_attn_output = self.attn(img_q, img_k, img_v,
                                                     txt_q, txt_k, txt_v)
        attn_output = torch.cat((img_attn_output, txt_attn_output),
                                dim=1).view(batch_size, seq_len, -1)

        # Process MLP activation
        mlp_output = self.mlp_act(mlp)

        # Combine attention and MLP outputs
        combined = torch.cat((attn_output, mlp_output), dim=-1)

        # Final projection
        output, _ = self.linear2(combined)

        # Apply residual connection with gating using fused operation
        return self.output_residual(x, output, mod_gate)
class HunyuanVideoTransformer3DModel(BaseDiT):
    """
    HunyuanVideo Transformer backbone adapted for distributed training.

    This implementation uses distributed attention and linear layers for efficient
    parallel processing across multiple GPUs.

    Based on the architecture from:
    - Flux.1: https://github.com/black-forest-labs/flux
    - MMDiT: http://arxiv.org/abs/2403.03206
    """
    # PY: we make the input args the same as HF config
    # shard single stream, double stream blocks, and refiner_blocks
    _fsdp_shard_conditions = HunyuanVideoConfig()._fsdp_shard_conditions
    _supported_attention_backends = HunyuanVideoConfig()._supported_attention_backends
    _param_names_mapping = HunyuanVideoConfig()._param_names_mapping

    def __init__(self, config: HunyuanVideoConfig):
        super().__init__(config=config)

        self.patch_size = [config.patch_size_t, config.patch_size, config.patch_size]
        self.in_channels = config.in_channels
        self.num_channels_latents = config.num_channels_latents
        self.out_channels = config.in_channels if config.out_channels is None else config.out_channels
        self.unpatchify_channels = self.out_channels
        self.guidance_embeds = config.guidance_embeds
        self.rope_dim_list = list(config.rope_axes_dim)
        self.rope_theta = config.rope_theta
        self.text_states_dim = config.text_embed_dim
        self.text_states_dim_2 = config.pooled_projection_dim
        # TODO(will): hack?
        self.dtype = config.dtype

        pe_dim = config.hidden_size // config.num_attention_heads
        if sum(config.rope_axes_dim) != pe_dim:
            raise ValueError(
                f"Got {config.rope_axes_dim} but expected positional dim {pe_dim}")

        self.hidden_size = config.hidden_size
        self.num_attention_heads = config.num_attention_heads
        self.num_channels_latents = config.num_channels_latents

        # Image projection
        self.img_in = PatchEmbed(self.patch_size, self.in_channels, self.hidden_size,
                                 dtype=config.dtype, prefix=f"{config.prefix}.img_in")

        self.txt_in = SingleTokenRefiner(self.text_states_dim, config.hidden_size,
                                         config.num_attention_heads,
                                         depth=config.num_refiner_layers,
                                         dtype=config.dtype,
                                         prefix=f"{config.prefix}.txt_in")

        # Time modulation
        self.time_in = TimestepEmbedder(self.hidden_size, act_layer="silu",
                                        dtype=config.dtype,
                                        prefix=f"{config.prefix}.time_in")

        # Text modulation
        self.vector_in = MLP(self.text_states_dim_2, self.hidden_size,
                             self.hidden_size, act_type="silu", dtype=config.dtype,
                             prefix=f"{config.prefix}.vector_in")

        # Guidance modulation
        self.guidance_in = (TimestepEmbedder(
            self.hidden_size, act_layer="silu", dtype=config.dtype,
            prefix=f"{config.prefix}.guidance_in") if self.guidance_embeds else None)

        # Double blocks
        self.double_blocks = nn.ModuleList([
            MMDoubleStreamBlock(
                config.hidden_size,
                config.num_attention_heads,
                mlp_ratio=config.mlp_ratio,
                dtype=config.dtype,
                supported_attention_backends=self._supported_attention_backends,
                prefix=f"{config.prefix}.double_blocks.{i}")
            for i in range(config.num_layers)
        ])

        # Single blocks
        self.single_blocks = nn.ModuleList([
            MMSingleStreamBlock(
                config.hidden_size,
                config.num_attention_heads,
                mlp_ratio=config.mlp_ratio,
                dtype=config.dtype,
                supported_attention_backends=self._supported_attention_backends,
                prefix=f"{config.prefix}.single_blocks.{i + config.num_layers}")
            for i in range(config.num_single_layers)
        ])

        self.final_layer = FinalLayer(config.hidden_size, self.patch_size,
                                      self.out_channels, dtype=config.dtype,
                                      prefix=f"{config.prefix}.final_layer")

        self.__post_init__()

    # TODO: change the input to the FORWARD_BATCH dict
    # TODO: change output to a dict
    def forward(self,
                hidden_states: torch.Tensor,
                encoder_hidden_states: Union[torch.Tensor, List[torch.Tensor]],
                timestep: torch.LongTensor,
                encoder_hidden_states_image: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None,
                guidance=None,
                **kwargs):
        """
        Forward pass of the HunyuanDiT model.

        Args:
            hidden_states: Input image/video latents [B, C, T, H, W]
            encoder_hidden_states: Text embeddings [B, L, D]
            timestep: Diffusion timestep
            guidance: Guidance scale for CFG

        Returns:
            Tuple of (output)
        """
        if guidance is None:
            guidance = torch.tensor([6016.0],
                                    device=hidden_states.device,
                                    dtype=hidden_states.dtype)

        img = x = hidden_states
        t = timestep
        # Split text embeddings - first token is global, rest are per-token
        if isinstance(encoder_hidden_states, torch.Tensor):
            txt = encoder_hidden_states[:, 1:]
            text_states_2 = encoder_hidden_states[:, 0, :self.text_states_dim_2]
        else:
            txt = encoder_hidden_states[0]
            text_states_2 = encoder_hidden_states[1]

        # Get spatial dimensions
        _, _, ot, oh, ow = x.shape  # codespell:ignore
        tt, th, tw = (
            ot // self.patch_size[0],  # codespell:ignore
            oh // self.patch_size[1],
            ow // self.patch_size[2],
        )

        # Get rotary embeddings
        freqs_cos, freqs_sin = get_rotary_pos_embed(
            (tt * get_sequence_model_parallel_world_size(), th, tw),
            self.hidden_size, self.num_attention_heads, self.rope_dim_list,
            self.rope_theta)
        freqs_cos = freqs_cos.to(x.device)
        freqs_sin = freqs_sin.to(x.device)

        # Prepare modulation vectors
        vec = self.time_in(t)

        # Add text modulation
        vec = vec + self.vector_in(text_states_2)

        # Add guidance modulation if needed
        if self.guidance_in and guidance is not None:
            vec = vec + self.guidance_in(guidance)

        # Embed image and text
        img = self.img_in(img)
        txt = self.txt_in(txt, t)

        txt_seq_len = txt.shape[1]
        img_seq_len = img.shape[1]

        freqs_cis = (freqs_cos, freqs_sin) if freqs_cos is not None else None

        # Process through double stream blocks
        for index, block in enumerate(self.double_blocks):
            double_block_args = [img, txt, vec, freqs_cis]
            img, txt = block(*double_block_args)

        # Merge txt and img to pass through single stream blocks
        x = torch.cat((img, txt), 1)

        # Process through single stream blocks
        if len(self.single_blocks) > 0:
            for index, block in enumerate(self.single_blocks):
                single_block_args = [x, vec, txt_seq_len, freqs_cis]
                x = block(*single_block_args)

        # Extract image features
        img = x[:, :img_seq_len, ...]

        # Final layer processing
        img = self.final_layer(img, vec)

        # Unpatchify to get original shape
        img = unpatchify(img, tt, th, tw, self.patch_size, self.out_channels)

        return img
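A worked example of the patch-grid arithmetic in forward() above, in pure Python with illustrative numbers (the latent shape and patch sizes are not defaults from the config): the latent [B, C, T, H, W] is divided by [patch_size_t, patch_size, patch_size] to get the token grid, and the product of the three counts is the image sequence length that flows through the double- and single-stream blocks.

ot, oh, ow = 16, 60, 104                # latent frames, height, width
patch_size = [1, 2, 2]                  # [patch_size_t, patch_size, patch_size]
tt, th, tw = ot // patch_size[0], oh // patch_size[1], ow // patch_size[2]
img_seq_len = tt * th * tw
print(tt, th, tw, img_seq_len)          # 16 30 52 24960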
class SingleTokenRefiner(nn.Module):
    """
    A token refiner that processes text embeddings with attention to improve
    their representation for cross-attention with image features.
    """

    def __init__(
        self,
        in_channels,
        hidden_size,
        num_attention_heads,
        depth=2,
        qkv_bias=True,
        dtype=None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        # Input projection
        self.input_embedder = ReplicatedLinear(in_channels, hidden_size, bias=True,
                                               params_dtype=dtype,
                                               prefix=f"{prefix}.input_embedder")

        # Timestep embedding
        self.t_embedder = TimestepEmbedder(hidden_size, act_layer="silu", dtype=dtype,
                                           prefix=f"{prefix}.t_embedder")

        # Context embedding
        self.c_embedder = MLP(in_channels, hidden_size, hidden_size, act_type="silu",
                              dtype=dtype, prefix=f"{prefix}.c_embedder")

        # Refiner blocks
        self.refiner_blocks = nn.ModuleList([
            IndividualTokenRefinerBlock(
                hidden_size,
                num_attention_heads,
                qkv_bias=qkv_bias,
                dtype=dtype,
                prefix=f"{prefix}.refiner_blocks.{i}",
            ) for i in range(depth)
        ])

    def forward(self, x, t):
        # Get timestep embeddings
        timestep_aware_representations = self.t_embedder(t)

        # Get context-aware representations
        context_aware_representations = torch.mean(x, dim=1)
        context_aware_representations = self.c_embedder(
            context_aware_representations)
        c = timestep_aware_representations + context_aware_representations

        # Project input
        x, _ = self.input_embedder(x)

        # Process through refiner blocks
        for block in self.refiner_blocks:
            x = block(x, c)

        return x


class IndividualTokenRefinerBlock(nn.Module):
    """
    A transformer block for refining individual tokens with self-attention.
    """

    def __init__(
        self,
        hidden_size,
        num_attention_heads,
        mlp_ratio=4.0,
        qkv_bias=True,
        dtype=None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.num_attention_heads = num_attention_heads
        mlp_hidden_dim = int(hidden_size * mlp_ratio)

        # Normalization and attention
        self.norm1 = nn.LayerNorm(hidden_size, eps=1e-6, elementwise_affine=True,
                                  dtype=dtype)
        self.self_attn_qkv = ReplicatedLinear(hidden_size, hidden_size * 3,
                                              bias=qkv_bias, params_dtype=dtype,
                                              prefix=f"{prefix}.self_attn_qkv")
        self.self_attn_proj = ReplicatedLinear(hidden_size, hidden_size,
                                               bias=qkv_bias, params_dtype=dtype,
                                               prefix=f"{prefix}.self_attn_proj")

        # MLP
        self.norm2 = nn.LayerNorm(hidden_size, eps=1e-6, elementwise_affine=True,
                                  dtype=dtype)
        self.mlp = MLP(hidden_size, mlp_hidden_dim, bias=True, act_type="silu",
                       dtype=dtype, prefix=f"{prefix}.mlp")

        # Modulation
        self.adaLN_modulation = ModulateProjection(
            hidden_size, factor=2, act_layer="silu", dtype=dtype,
            prefix=f"{prefix}.adaLN_modulation")

        # Scaled dot product attention
        self.attn = LocalAttention(
            num_heads=num_attention_heads,
            head_size=hidden_size // num_attention_heads,
            # TODO: remove hardcode; remove STA
            supported_attention_backends=(_Backend.FLASH_ATTN, _Backend.TORCH_SDPA),
        )

    def forward(self, x, c):
        # Get modulation parameters
        gate_msa, gate_mlp = self.adaLN_modulation(c).chunk(2, dim=-1)

        # Self-attention
        norm_x = self.norm1(x)
        qkv, _ = self.self_attn_qkv(norm_x)
        batch_size, seq_len = qkv.shape[0], qkv.shape[1]
        qkv = qkv.view(batch_size, seq_len, 3, self.num_attention_heads, -1)
        q, k, v = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]

        # Run scaled dot product attention
        attn_output = self.attn(q, k, v)  # [B, L, H, D]
        attn_output = attn_output.reshape(batch_size, seq_len, -1)  # [B, L, H*D]

        # Project and apply residual connection with gating
        attn_out, _ = self.self_attn_proj(attn_output)
        x = x + attn_out * gate_msa.unsqueeze(1)

        # MLP
        mlp_out = self.mlp(self.norm2(x))
        x = x + mlp_out * gate_mlp.unsqueeze(1)

        return x
class FinalLayer(nn.Module):
    """
    The final layer of DiT that projects features to pixel space.
    """

    def __init__(self, hidden_size, patch_size, out_channels, dtype=None,
                 prefix: str = "") -> None:
        super().__init__()

        # Normalization
        self.norm_final = nn.LayerNorm(hidden_size, eps=1e-6,
                                       elementwise_affine=False, dtype=dtype)

        output_dim = patch_size[0] * patch_size[1] * patch_size[2] * out_channels
        self.linear = ReplicatedLinear(hidden_size, output_dim, bias=True,
                                       params_dtype=dtype, prefix=f"{prefix}.linear")

        # Modulation
        self.adaLN_modulation = ModulateProjection(
            hidden_size, factor=2, act_layer="silu", dtype=dtype,
            prefix=f"{prefix}.adaLN_modulation")

    def forward(self, x, c):
        # What the heck HF? Why you change the scale and shift order here???
        scale, shift = self.adaLN_modulation(c).chunk(2, dim=-1)
        x = self.norm_final(x) * (1.0 + scale.unsqueeze(1)) + shift.unsqueeze(1)
        x, _ = self.linear(x)
        return x
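A minimal sketch of the adaLN-style projection FinalLayer.forward performs, in plain torch (LayerNorm stands in for norm_final and all shapes are illustrative): the modulation output is chunked with scale first and shift second, matching the comment above about the swapped order.

import torch
import torch.nn as nn

hidden = 32
x = torch.randn(2, 10, hidden)          # [B, L, hidden]
mod = torch.randn(2, 2 * hidden)        # stand-in for adaLN_modulation(c)
scale, shift = mod.chunk(2, dim=-1)     # scale first, then shift
norm = nn.LayerNorm(hidden, elementwise_affine=False)
out = norm(x) * (1.0 + scale.unsqueeze(1)) + shift.unsqueeze(1)
assert out.shape == x.shape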
FastVideo-main/fastvideo/v1/models/dits/wanvideo.py
0 → 100644
# SPDX-License-Identifier: Apache-2.0
import math
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn

from fastvideo.v1.attention import DistributedAttention, LocalAttention
from fastvideo.v1.configs.models.dits import WanVideoConfig
from fastvideo.v1.distributed.parallel_state import (
    get_sequence_model_parallel_world_size)
from fastvideo.v1.layers.layernorm import (LayerNormScaleShift, RMSNorm,
                                           ScaleResidual,
                                           ScaleResidualLayerNormScaleShift)
from fastvideo.v1.layers.linear import ReplicatedLinear
# from torch.nn import RMSNorm
# TODO: RMSNorm ....
from fastvideo.v1.layers.mlp import MLP
from fastvideo.v1.layers.rotary_embedding import (_apply_rotary_emb,
                                                  get_rotary_pos_embed)
from fastvideo.v1.layers.visual_embedding import (ModulateProjection, PatchEmbed,
                                                  TimestepEmbedder)
from fastvideo.v1.models.dits.base import BaseDiT
from fastvideo.v1.platforms import _Backend


class WanImageEmbedding(torch.nn.Module):

    def __init__(self, in_features: int, out_features: int):
        super().__init__()

        self.norm1 = nn.LayerNorm(in_features)
        self.ff = MLP(in_features, in_features, out_features, act_type="gelu")
        self.norm2 = nn.LayerNorm(out_features)

    def forward(self, encoder_hidden_states_image: torch.Tensor) -> torch.Tensor:
        dtype = encoder_hidden_states_image.dtype
        hidden_states = self.norm1(encoder_hidden_states_image)
        hidden_states = self.ff(hidden_states)
        hidden_states = self.norm2(hidden_states).to(dtype)
        return hidden_states


class WanTimeTextImageEmbedding(nn.Module):

    def __init__(
        self,
        dim: int,
        time_freq_dim: int,
        text_embed_dim: int,
        image_embed_dim: Optional[int] = None,
    ):
        super().__init__()

        self.time_embedder = TimestepEmbedder(
            dim, frequency_embedding_size=time_freq_dim, act_layer="silu")
        self.time_modulation = ModulateProjection(dim, factor=6, act_layer="silu")
        self.text_embedder = MLP(text_embed_dim, dim, dim, bias=True,
                                 act_type="gelu_pytorch_tanh")

        self.image_embedder = None
        if image_embed_dim is not None:
            self.image_embedder = WanImageEmbedding(image_embed_dim, dim)

    def forward(
        self,
        timestep: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        encoder_hidden_states_image: Optional[torch.Tensor] = None,
    ):
        temb = self.time_embedder(timestep)
        timestep_proj = self.time_modulation(temb)

        encoder_hidden_states = self.text_embedder(encoder_hidden_states)
        if encoder_hidden_states_image is not None:
            assert self.image_embedder is not None
            encoder_hidden_states_image = self.image_embedder(
                encoder_hidden_states_image)

        return temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image
class WanSelfAttention(nn.Module):

    def __init__(self,
                 dim: int,
                 num_heads: int,
                 window_size=(-1, -1),
                 qk_norm=True,
                 eps=1e-6,
                 parallel_attention=False) -> None:
        assert dim % num_heads == 0
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.window_size = window_size
        self.qk_norm = qk_norm
        self.eps = eps
        self.parallel_attention = parallel_attention

        # layers
        self.to_q = ReplicatedLinear(dim, dim)
        self.to_k = ReplicatedLinear(dim, dim)
        self.to_v = ReplicatedLinear(dim, dim)
        self.to_out = ReplicatedLinear(dim, dim)
        self.norm_q = RMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
        self.norm_k = RMSNorm(dim, eps=eps) if qk_norm else nn.Identity()

        # Scaled dot product attention
        self.attn = LocalAttention(
            num_heads=num_heads,
            head_size=self.head_dim,
            dropout_rate=0,
            softmax_scale=None,
            causal=False,
            supported_attention_backends=(_Backend.FLASH_ATTN, _Backend.TORCH_SDPA))

    def forward(self, x: torch.Tensor, context: torch.Tensor, context_lens: int):
        r"""
        Args:
            x(Tensor): Shape [B, L, num_heads, C / num_heads]
            seq_lens(Tensor): Shape [B]
            grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
            freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
        """
        pass


class WanT2VCrossAttention(WanSelfAttention):

    def forward(self, x, context, context_lens):
        r"""
        Args:
            x(Tensor): Shape [B, L1, C]
            context(Tensor): Shape [B, L2, C]
            context_lens(Tensor): Shape [B]
        """
        b, n, d = x.size(0), self.num_heads, self.head_dim

        # compute query, key, value
        q = self.norm_q(self.to_q(x)[0]).view(b, -1, n, d)
        k = self.norm_k(self.to_k(context)[0]).view(b, -1, n, d)
        v = self.to_v(context)[0].view(b, -1, n, d)

        # compute attention
        x = self.attn(q, k, v)

        # output
        x = x.flatten(2)
        x, _ = self.to_out(x)
        return x


class WanI2VCrossAttention(WanSelfAttention):

    def __init__(self,
                 dim: int,
                 num_heads: int,
                 window_size=(-1, -1),
                 qk_norm=True,
                 eps=1e-6,
                 supported_attention_backends: Optional[Tuple[_Backend, ...]] = None
                 ) -> None:
        super().__init__(dim, num_heads, window_size, qk_norm, eps,
                         supported_attention_backends)

        self.add_k_proj = ReplicatedLinear(dim, dim)
        self.add_v_proj = ReplicatedLinear(dim, dim)
        self.norm_added_k = RMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
        self.norm_added_q = RMSNorm(dim, eps=eps) if qk_norm else nn.Identity()

    def forward(self, x, context, context_lens):
        r"""
        Args:
            x(Tensor): Shape [B, L1, C]
            context(Tensor): Shape [B, L2, C]
            context_lens(Tensor): Shape [B]
        """
        context_img = context[:, :257]
        context = context[:, 257:]
        b, n, d = x.size(0), self.num_heads, self.head_dim

        # compute query, key, value
        q = self.norm_q(self.to_q(x)[0]).view(b, -1, n, d)
        k = self.norm_k(self.to_k(context)[0]).view(b, -1, n, d)
        v = self.to_v(context)[0].view(b, -1, n, d)
        k_img = self.norm_added_k(self.add_k_proj(context_img)[0]).view(b, -1, n, d)
        v_img = self.add_v_proj(context_img)[0].view(b, -1, n, d)
        img_x = self.attn(q, k_img, v_img)

        # compute attention
        x = self.attn(q, k, v)

        # output
        x = x.flatten(2)
        img_x = img_x.flatten(2)
        x = x + img_x
        x, _ = self.to_out(x)
        return x
class WanTransformerBlock(nn.Module):

    def __init__(self,
                 dim: int,
                 ffn_dim: int,
                 num_heads: int,
                 qk_norm: str = "rms_norm_across_heads",
                 cross_attn_norm: bool = False,
                 eps: float = 1e-6,
                 added_kv_proj_dim: Optional[int] = None,
                 supported_attention_backends: Optional[Tuple[_Backend, ...]] = None,
                 prefix: str = ""):
        super().__init__()

        # 1. Self-attention
        self.norm1 = nn.LayerNorm(dim, eps, elementwise_affine=False)
        self.to_q = ReplicatedLinear(dim, dim, bias=True)
        self.to_k = ReplicatedLinear(dim, dim, bias=True)
        self.to_v = ReplicatedLinear(dim, dim, bias=True)
        self.to_out = ReplicatedLinear(dim, dim, bias=True)
        self.attn1 = DistributedAttention(
            num_heads=num_heads,
            head_size=dim // num_heads,
            causal=False,
            supported_attention_backends=supported_attention_backends,
            prefix=f"{prefix}.attn1")
        self.hidden_dim = dim
        self.num_attention_heads = num_heads
        dim_head = dim // num_heads
        if qk_norm == "rms_norm":
            self.norm_q = RMSNorm(dim_head, eps=eps)
            self.norm_k = RMSNorm(dim_head, eps=eps)
        elif qk_norm == "rms_norm_across_heads":
            # LTX applies qk norm across all heads
            self.norm_q = RMSNorm(dim, eps=eps)
            self.norm_k = RMSNorm(dim, eps=eps)
        else:
            print("QK Norm type not supported")
            raise Exception
        assert cross_attn_norm is True
        self.self_attn_residual_norm = ScaleResidualLayerNormScaleShift(
            dim, norm_type="layer", eps=eps, elementwise_affine=True,
            dtype=torch.float32)

        # 2. Cross-attention
        if added_kv_proj_dim is not None:
            # I2V
            self.attn2 = WanI2VCrossAttention(dim, num_heads, qk_norm=qk_norm, eps=eps)
        else:
            # T2V
            self.attn2 = WanT2VCrossAttention(dim, num_heads, qk_norm=qk_norm, eps=eps)
        self.cross_attn_residual_norm = ScaleResidualLayerNormScaleShift(
            dim, norm_type="layer", eps=eps, elementwise_affine=False,
            dtype=torch.float32)

        # 3. Feed-forward
        self.ffn = MLP(dim, ffn_dim, act_type="gelu_pytorch_tanh")
        self.mlp_residual = ScaleResidual()

        self.scale_shift_table = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        temb: torch.Tensor,
        freqs_cis: Tuple[torch.Tensor, torch.Tensor],
    ) -> torch.Tensor:
        if hidden_states.dim() == 4:
            hidden_states = hidden_states.squeeze(1)
        bs, seq_length, _ = hidden_states.shape
        orig_dtype = hidden_states.dtype
        assert orig_dtype != torch.float32
        e = self.scale_shift_table + temb.float()
        shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = e.chunk(
            6, dim=1)
        assert shift_msa.dtype == torch.float32

        # 1. Self-attention
        norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) +
                              shift_msa).to(orig_dtype)
        query, _ = self.to_q(norm_hidden_states)
        key, _ = self.to_k(norm_hidden_states)
        value, _ = self.to_v(norm_hidden_states)

        if self.norm_q is not None:
            query = self.norm_q.forward_native(query)
        if self.norm_k is not None:
            key = self.norm_k.forward_native(key)

        query = query.squeeze(1).unflatten(2, (self.num_attention_heads, -1))
        key = key.squeeze(1).unflatten(2, (self.num_attention_heads, -1))
        value = value.squeeze(1).unflatten(2, (self.num_attention_heads, -1))

        # Apply rotary embeddings
        cos, sin = freqs_cis
        query, key = _apply_rotary_emb(
            query, cos, sin, is_neox_style=False), _apply_rotary_emb(
                key, cos, sin, is_neox_style=False)

        attn_output, _ = self.attn1(query, key, value)
        attn_output = attn_output.flatten(2)
        attn_output, _ = self.to_out(attn_output)
        attn_output = attn_output.squeeze(1)

        null_shift = null_scale = torch.tensor([0], device=hidden_states.device)
        norm_hidden_states, hidden_states = self.self_attn_residual_norm(
            hidden_states, attn_output, gate_msa, null_shift, null_scale)
        norm_hidden_states, hidden_states = norm_hidden_states.to(
            orig_dtype), hidden_states.to(orig_dtype)

        # 2. Cross-attention
        attn_output = self.attn2(norm_hidden_states,
                                 context=encoder_hidden_states,
                                 context_lens=None)
        norm_hidden_states, hidden_states = self.cross_attn_residual_norm(
            hidden_states, attn_output, 1, c_shift_msa, c_scale_msa)
        norm_hidden_states, hidden_states = norm_hidden_states.to(
            orig_dtype), hidden_states.to(orig_dtype)

        # 3. Feed-forward
        ff_output = self.ffn(norm_hidden_states)
        hidden_states = self.mlp_residual(hidden_states, ff_output, c_gate_msa)
        hidden_states = hidden_states.to(orig_dtype)

        return hidden_states
class WanTransformer3DModel(BaseDiT):
    _fsdp_shard_conditions = WanVideoConfig()._fsdp_shard_conditions
    _supported_attention_backends = WanVideoConfig()._supported_attention_backends
    _param_names_mapping = WanVideoConfig()._param_names_mapping

    def __init__(self, config: WanVideoConfig) -> None:
        super().__init__(config=config)

        inner_dim = config.num_attention_heads * config.attention_head_dim
        self.hidden_size = config.hidden_size
        self.num_attention_heads = config.num_attention_heads
        self.in_channels = config.in_channels
        self.out_channels = config.out_channels
        self.num_channels_latents = config.num_channels_latents
        self.patch_size = config.patch_size
        self.text_len = config.text_len

        # 1. Patch & position embedding
        self.patch_embedding = PatchEmbed(in_chans=config.in_channels,
                                          embed_dim=inner_dim,
                                          patch_size=config.patch_size,
                                          flatten=False)

        # 2. Condition embeddings
        self.condition_embedder = WanTimeTextImageEmbedding(
            dim=inner_dim,
            time_freq_dim=config.freq_dim,
            text_embed_dim=config.text_dim,
            image_embed_dim=config.image_dim,
        )

        # 3. Transformer blocks
        self.blocks = nn.ModuleList([
            WanTransformerBlock(inner_dim, config.ffn_dim,
                                config.num_attention_heads, config.qk_norm,
                                config.cross_attn_norm, config.eps,
                                config.added_kv_proj_dim,
                                self._supported_attention_backends,
                                prefix=f"{config.prefix}.blocks.{i}")
            for i in range(config.num_layers)
        ])

        # 4. Output norm & projection
        self.norm_out = LayerNormScaleShift(inner_dim, norm_type="layer",
                                            eps=config.eps,
                                            elementwise_affine=False,
                                            dtype=torch.float32)
        self.proj_out = nn.Linear(inner_dim,
                                  config.out_channels * math.prod(config.patch_size))
        self.scale_shift_table = nn.Parameter(
            torch.randn(1, 2, inner_dim) / inner_dim**0.5)

        self.gradient_checkpointing = False

        self.__post_init__()

    def forward(self,
                hidden_states: torch.Tensor,
                encoder_hidden_states: Union[torch.Tensor, List[torch.Tensor]],
                timestep: torch.LongTensor,
                encoder_hidden_states_image: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None,
                guidance=None,
                **kwargs) -> torch.Tensor:
        orig_dtype = hidden_states.dtype
        if not isinstance(encoder_hidden_states, torch.Tensor):
            encoder_hidden_states = encoder_hidden_states[0]
        if isinstance(encoder_hidden_states_image,
                      list) and len(encoder_hidden_states_image) > 0:
            encoder_hidden_states_image = encoder_hidden_states_image[0]
        else:
            encoder_hidden_states_image = None

        batch_size, num_channels, num_frames, height, width = hidden_states.shape
        p_t, p_h, p_w = self.patch_size
        post_patch_num_frames = num_frames // p_t
        post_patch_height = height // p_h
        post_patch_width = width // p_w

        # Get rotary embeddings
        d = self.hidden_size // self.num_attention_heads
        rope_dim_list = [d - 4 * (d // 6), 2 * (d // 6), 2 * (d // 6)]
        freqs_cos, freqs_sin = get_rotary_pos_embed(
            (post_patch_num_frames * get_sequence_model_parallel_world_size(),
             post_patch_height, post_patch_width),
            self.hidden_size,
            self.num_attention_heads,
            rope_dim_list,
            dtype=torch.float64,
            rope_theta=10000)
        freqs_cos = freqs_cos.to(hidden_states.device)
        freqs_sin = freqs_sin.to(hidden_states.device)
        freqs_cis = (freqs_cos.float(),
                     freqs_sin.float()) if freqs_cos is not None else None

        hidden_states = self.patch_embedding(hidden_states)
        hidden_states = hidden_states.flatten(2).transpose(1, 2)

        temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder(
            timestep, encoder_hidden_states, encoder_hidden_states_image)
        timestep_proj = timestep_proj.unflatten(1, (6, -1))

        if encoder_hidden_states_image is not None:
            encoder_hidden_states = torch.concat(
                [encoder_hidden_states_image, encoder_hidden_states], dim=1)
        assert encoder_hidden_states.dtype == orig_dtype

        # 4. Transformer blocks
        if torch.is_grad_enabled() and self.gradient_checkpointing:
            for block in self.blocks:
                hidden_states = self._gradient_checkpointing_func(
                    block, hidden_states, encoder_hidden_states, timestep_proj,
                    freqs_cis)
        else:
            for block in self.blocks:
                hidden_states = block(hidden_states, encoder_hidden_states,
                                      timestep_proj, freqs_cis)

        # 5. Output norm, projection & unpatchify
        shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)
        hidden_states = self.norm_out(hidden_states.float(), shift, scale)
        hidden_states = self.proj_out(hidden_states)

        hidden_states = hidden_states.reshape(batch_size, post_patch_num_frames,
                                              post_patch_height, post_patch_width,
                                              p_t, p_h, p_w, -1)
        hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
        output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)

        return output
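A shape-only sketch of the unpatchify step at the end of WanTransformer3DModel.forward, in plain torch with illustrative sizes (not config defaults): the projected tokens are reshaped to the patch grid, permuted, and flattened back into the latent video layout [B, C, F*p_t, H*p_h, W*p_w].

import torch

batch, c_out = 1, 16
p_t, p_h, p_w = 1, 2, 2
f, h, w = 4, 8, 8                        # post-patch frames, height, width
tokens = torch.randn(batch, f, h, w, p_t, p_h, p_w, c_out)
video = tokens.permute(0, 7, 1, 4, 2, 5, 3, 6).flatten(6, 7).flatten(4, 5).flatten(2, 3)
assert video.shape == (batch, c_out, f * p_t, h * p_h, w * p_w)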
FastVideo-main/fastvideo/v1/models/encoders/base.py
0 → 100644
from abc import ABC, abstractmethod
from typing import Optional, Tuple

import torch
from torch import nn

from fastvideo.v1.configs.models.encoders import (BaseEncoderOutput,
                                                  ImageEncoderConfig,
                                                  TextEncoderConfig)
from fastvideo.v1.platforms import _Backend


class TextEncoder(nn.Module, ABC):
    _supported_attention_backends: Tuple[
        _Backend, ...] = TextEncoderConfig()._supported_attention_backends

    def __init__(self, config: TextEncoderConfig) -> None:
        super().__init__()
        self.config = config
        if not self.supported_attention_backends:
            raise ValueError(
                f"Subclass {self.__class__.__name__} must define _supported_attention_backends")

    @abstractmethod
    def forward(self,
                input_ids: Optional[torch.Tensor],
                position_ids: Optional[torch.Tensor] = None,
                attention_mask: Optional[torch.Tensor] = None,
                inputs_embeds: Optional[torch.Tensor] = None,
                output_hidden_states: Optional[bool] = None,
                **kwargs) -> BaseEncoderOutput:
        pass

    @property
    def supported_attention_backends(self) -> Tuple[_Backend, ...]:
        return self._supported_attention_backends


class ImageEncoder(nn.Module, ABC):
    _supported_attention_backends: Tuple[
        _Backend, ...] = ImageEncoderConfig()._supported_attention_backends

    def __init__(self, config: ImageEncoderConfig) -> None:
        super().__init__()
        self.config = config
        if not self.supported_attention_backends:
            raise ValueError(
                f"Subclass {self.__class__.__name__} must define _supported_attention_backends")

    @abstractmethod
    def forward(self, pixel_values: torch.Tensor, **kwargs) -> BaseEncoderOutput:
        pass

    @property
    def supported_attention_backends(self) -> Tuple[_Backend, ...]:
        return self._supported_attention_backends
FastVideo-main/fastvideo/v1/models/encoders/clip.py
0 → 100644
# SPDX-License-Identifier: Apache-2.0
# Adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/model_executor/models/clip.py
# Adapted from transformers: https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/clip/modeling_clip.py
"""Minimal implementation of CLIPVisionModel intended to be only used
within a vision language model."""
from typing import Iterable, Optional, Set, Tuple, Union

import torch
import torch.nn as nn

# from transformers.modeling_attn_mask_utils import _create_4d_causal_attention_mask, _prepare_4d_attention_mask
from fastvideo.v1.attention import LocalAttention
from fastvideo.v1.configs.models.encoders import (BaseEncoderOutput,
                                                  CLIPTextConfig, CLIPVisionConfig)
from fastvideo.v1.configs.quantization import QuantizationConfig
from fastvideo.v1.distributed import (divide,
                                      get_tensor_model_parallel_world_size)
from fastvideo.v1.layers.activation import get_act_fn
from fastvideo.v1.layers.linear import (ColumnParallelLinear, QKVParallelLinear,
                                        RowParallelLinear)
from fastvideo.v1.logger import init_logger
from fastvideo.v1.models.encoders.base import ImageEncoder, TextEncoder
from fastvideo.v1.models.encoders.vision import resolve_visual_encoder_outputs
# TODO: support quantization
# from vllm.model_executor.layers.quantization import QuantizationConfig
from fastvideo.v1.models.loader.weight_utils import default_weight_loader

logger = init_logger(__name__)
# Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/clip/modeling_clip.py#L164 # noqa
class CLIPVisionEmbeddings(nn.Module):

    def __init__(self, config: CLIPVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size
        assert self.image_size % self.patch_size == 0

        self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))

        self.patch_embedding = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            bias=False,
        )

        self.num_patches = (self.image_size // self.patch_size)**2
        self.num_positions = self.num_patches + 1
        self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
        self.register_buffer("position_ids",
                             torch.arange(self.num_positions).expand((1, -1)),
                             persistent=False)

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        batch_size = pixel_values.shape[0]
        target_dtype = self.patch_embedding.weight.dtype
        patch_embeds = self.patch_embedding(
            pixel_values.to(dtype=target_dtype))  # shape = [*, width, grid, grid]
        patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
        class_embeds = self.class_embedding.expand(batch_size, 1, -1)
        embeddings = torch.cat([class_embeds, patch_embeds], dim=1)

        embeddings = embeddings + self.position_embedding(self.position_ids)

        return embeddings


class CLIPTextEmbeddings(nn.Module):

    def __init__(self, config: CLIPTextConfig):
        super().__init__()
        self.config = config
        embed_dim = config.hidden_size

        self.token_embedding = nn.Embedding(config.vocab_size, embed_dim)
        self.position_embedding = nn.Embedding(config.max_position_embeddings,
                                               embed_dim)

        # position_ids (1, len position emb) is contiguous in memory and exported when serialized
        self.register_buffer(
            "position_ids",
            torch.arange(config.max_position_embeddings).expand((1, -1)),
            persistent=False)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
    ) -> torch.Tensor:
        if input_ids is not None:
            seq_length = input_ids.shape[-1]
        elif inputs_embeds is not None:
            seq_length = inputs_embeds.shape[-2]
        else:
            raise ValueError("Either input_ids or inputs_embeds must be provided.")

        max_position_embedding = self.position_embedding.weight.shape[0]

        if seq_length > max_position_embedding:
            raise ValueError(
                f"Sequence length must be less than max_position_embeddings (got `sequence length`: "
                f"{seq_length} and max_position_embeddings: {max_position_embedding}")

        if position_ids is None:
            position_ids = self.position_ids[:, :seq_length]

        if inputs_embeds is None:
            inputs_embeds = self.token_embedding(input_ids)

        position_embeddings = self.position_embedding(position_ids)
        embeddings = inputs_embeds + position_embeddings

        return embeddings
class CLIPAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(
        self,
        config: Union[CLIPVisionConfig, CLIPTextConfig],
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                "embed_dim must be divisible by num_heads "
                f"(got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads}).")

        self.scale = self.head_dim**-0.5
        self.dropout = config.attention_dropout

        self.qkv_proj = QKVParallelLinear(
            hidden_size=self.embed_dim,
            head_size=self.head_dim,
            total_num_heads=self.num_heads,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

        self.out_proj = RowParallelLinear(
            input_size=self.embed_dim,
            output_size=self.embed_dim,
            quant_config=quant_config,
            prefix=f"{prefix}.out_proj",
        )

        self.tp_size = get_tensor_model_parallel_world_size()
        self.num_heads_per_partition = divide(self.num_heads, self.tp_size)

        self.attn = LocalAttention(
            self.num_heads_per_partition,
            self.head_dim,
            self.num_heads_per_partition,
            softmax_scale=self.scale,
            causal=True,
            supported_attention_backends=config._supported_attention_backends)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return tensor.view(bsz, seq_len, self.num_heads,
                           self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
    ):
        """Input shape: Batch x Time x Channel"""
        qkv_states, _ = self.qkv_proj(hidden_states)
        query_states, key_states, value_states = qkv_states.chunk(3, dim=-1)
        # use flash_attn_func
        query_states = query_states.reshape(query_states.shape[0],
                                            query_states.shape[1],
                                            self.num_heads_per_partition,
                                            self.head_dim)
        key_states = key_states.reshape(key_states.shape[0], key_states.shape[1],
                                        self.num_heads_per_partition, self.head_dim)
        value_states = value_states.reshape(value_states.shape[0],
                                            value_states.shape[1],
                                            self.num_heads_per_partition,
                                            self.head_dim)
        attn_output = self.attn(query_states, key_states, value_states)
        attn_output = attn_output.reshape(
            attn_output.shape[0], attn_output.shape[1],
            self.num_heads_per_partition * self.head_dim)

        attn_output, _ = self.out_proj(attn_output)

        return attn_output, None
class CLIPMLP(nn.Module):

    def __init__(
        self,
        config: Union[CLIPVisionConfig, CLIPTextConfig],
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.config = config
        self.activation_fn = get_act_fn(config.hidden_act)
        self.fc1 = ColumnParallelLinear(config.hidden_size,
                                        config.intermediate_size,
                                        bias=True,
                                        quant_config=quant_config,
                                        prefix=f"{prefix}.fc1")
        self.fc2 = RowParallelLinear(config.intermediate_size,
                                     config.hidden_size,
                                     bias=True,
                                     quant_config=quant_config,
                                     prefix=f"{prefix}.fc2")

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states, _ = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states, _ = self.fc2(hidden_states)

        return hidden_states


class CLIPEncoderLayer(nn.Module):

    def __init__(
        self,
        config: Union[CLIPTextConfig, CLIPVisionConfig],
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.self_attn = CLIPAttention(
            config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
        )
        self.layer_norm1 = nn.LayerNorm(config.hidden_size,
                                        eps=config.layer_norm_eps)
        self.mlp = CLIPMLP(config, quant_config=quant_config, prefix=f"{prefix}.mlp")
        self.layer_norm2 = nn.LayerNorm(config.hidden_size,
                                        eps=config.layer_norm_eps)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        residual = hidden_states

        hidden_states = self.layer_norm1(hidden_states)
        hidden_states, _ = self.self_attn(hidden_states=hidden_states)
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.layer_norm2(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states
class CLIPEncoder(nn.Module):
    """
    Transformer encoder consisting of `config.num_hidden_layers` self
    attention layers. Each layer is a [`CLIPEncoderLayer`].

    Args:
        config: CLIPConfig
    """

    def __init__(
        self,
        config: Union[CLIPVisionConfig, CLIPTextConfig],
        quant_config: Optional[QuantizationConfig] = None,
        num_hidden_layers_override: Optional[int] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.config = config

        if num_hidden_layers_override is None:
            num_hidden_layers = config.num_hidden_layers
        else:
            num_hidden_layers = num_hidden_layers_override

        self.layers = nn.ModuleList([
            CLIPEncoderLayer(config=config,
                             quant_config=quant_config,
                             prefix=f"{prefix}.layers.{layer_idx}")
            for layer_idx in range(num_hidden_layers)
        ])

    def forward(
        self, inputs_embeds: torch.Tensor, return_all_hidden_states: bool
    ) -> Union[torch.Tensor, list[torch.Tensor]]:
        hidden_states_pool = [inputs_embeds]
        hidden_states = inputs_embeds

        for encoder_layer in self.layers:
            hidden_states = encoder_layer(hidden_states)
            if return_all_hidden_states:
                hidden_states_pool.append(hidden_states)
        # If we have multiple feature sample layers, we return all hidden
        # states in order and grab the ones we need by index.
        if return_all_hidden_states:
            return hidden_states_pool
        return [hidden_states]
class CLIPTextTransformer(nn.Module):

    def __init__(self,
                 config: CLIPTextConfig,
                 quant_config: Optional[QuantizationConfig] = None,
                 num_hidden_layers_override: Optional[int] = None,
                 prefix: str = ""):
        super().__init__()
        self.config = config
        embed_dim = config.hidden_size
        self.embeddings = CLIPTextEmbeddings(config)
        self.encoder = CLIPEncoder(
            config,
            quant_config=quant_config,
            num_hidden_layers_override=num_hidden_layers_override,
            prefix=prefix)
        self.final_layer_norm = nn.LayerNorm(embed_dim,
                                             eps=config.layer_norm_eps)

        # For `pooled_output` computation
        self.eos_token_id = config.eos_token_id

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        position_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = None,
    ) -> BaseEncoderOutput:
        r"""
        Returns:
        """
        output_hidden_states = (output_hidden_states
                                if output_hidden_states is not None else
                                self.config.output_hidden_states)

        if input_ids is None:
            raise ValueError("You have to specify input_ids")

        input_shape = input_ids.size()
        input_ids = input_ids.view(-1, input_shape[-1])

        hidden_states = self.embeddings(input_ids=input_ids,
                                        position_ids=position_ids)

        # CLIP's text model uses causal mask, prepare it here.
        # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
        # causal_attention_mask = _create_4d_causal_attention_mask(
        #     input_shape, hidden_states.dtype, device=hidden_states.device
        # )

        # # expand attention_mask
        # if attention_mask is not None and not self._use_flash_attention_2:
        #     raise NotImplementedError("attention_mask is not supported for CLIPTextTransformer")
        #     # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        #     attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)

        encoder_outputs = self.encoder(
            inputs_embeds=hidden_states,
            # attention_mask=attention_mask,
            # causal_attention_mask=causal_attention_mask,
            # output_attentions=output_attentions,
            return_all_hidden_states=output_hidden_states,
            # return_dict=return_dict,
        )

        last_hidden_state = encoder_outputs[-1]
        last_hidden_state = self.final_layer_norm(last_hidden_state)

        if self.eos_token_id == 2:
            # The `eos_token_id` was incorrect before PR #24773: let's keep what has been done here.
            # A CLIP model with such an `eos_token_id` in the config can't work correctly with extra new tokens added.
            # ------------------------------------------------------------
            # text_embeds.shape = [batch_size, sequence_length, transformer.width]
            # take features from the eot embedding (eot_token is the highest number in each sequence)
            # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
            pooled_output = last_hidden_state[
                torch.arange(last_hidden_state.shape[0],
                             device=last_hidden_state.device),
                input_ids.to(dtype=torch.int,
                             device=last_hidden_state.device).argmax(dim=-1),
            ]
        else:
            # The config gets the updated `eos_token_id` from PR #24773 (so the use of extra new tokens is possible)
            pooled_output = last_hidden_state[
                torch.arange(last_hidden_state.shape[0],
                             device=last_hidden_state.device),
                # We need to get the first position of `eos_token_id` value (`pad_token_ids` might equal to `eos_token_id`)
                # Note: we assume each sequence (along batch dim.) contains an `eos_token_id` (e.g. prepared by the tokenizer)
                (input_ids.to(dtype=torch.int,
                              device=last_hidden_state.device) ==
                 self.eos_token_id).int().argmax(dim=-1),
            ]

        return BaseEncoderOutput(
            last_hidden_state=last_hidden_state,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs,
            # attentions=encoder_outputs.attentions,
        )
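The pooled output above is gathered by indexing the last hidden state at each sequence's EOS position. A minimal, standalone sketch of that indexing on toy tensors (the shapes, token ids, and EOS id below are made up for illustration and are not taken from any real tokenizer):

import torch

# toy batch: 2 sequences, 5 tokens, hidden size 4
last_hidden_state = torch.randn(2, 5, 4)
eos_token_id = 49407  # assumed EOS id for this sketch
input_ids = torch.tensor([[101, 7, 8, 49407, 0],
                          [101, 9, 49407, 0, 0]])

# first position where input_ids equals the EOS id, per sequence
eos_pos = (input_ids == eos_token_id).int().argmax(dim=-1)   # tensor([3, 2])
pooled = last_hidden_state[torch.arange(2), eos_pos]          # shape (2, 4)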
class CLIPTextModel(TextEncoder):

    def __init__(
        self,
        config: CLIPTextConfig,
    ) -> None:
        super().__init__(config)

        self.text_model = CLIPTextTransformer(config=config,
                                              quant_config=config.quant_config,
                                              prefix=config.prefix)

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        position_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        **kwargs,
    ) -> BaseEncoderOutput:
        outputs: BaseEncoderOutput = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_hidden_states=output_hidden_states,
        )
        return outputs

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        # Define mapping for stacked parameters
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()

        for name, loaded_weight in weights:
            # Handle q_proj, k_proj, v_proj -> qkv_proj mapping
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name in name:
                    # Replace the weight name with the parameter name
                    model_param_name = name.replace(weight_name, param_name)
                    if model_param_name in params_dict:
                        param = params_dict[model_param_name]
                        weight_loader = param.weight_loader
                        weight_loader(param, loaded_weight, shard_id)
                        loaded_params.add(model_param_name)
                    break
            else:
                # Use default weight loader for all other parameters
                if name in params_dict:
                    param = params_dict[name]
                    weight_loader = getattr(param, "weight_loader",
                                            default_weight_loader)
                    weight_loader(param, loaded_weight)
                    loaded_params.add(name)
        return loaded_params
class CLIPVisionTransformer(nn.Module):

    def __init__(
        self,
        config: CLIPVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        num_hidden_layers_override: Optional[int] = None,
        require_post_norm: Optional[bool] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.config = config
        embed_dim = config.hidden_size

        self.embeddings = CLIPVisionEmbeddings(config)

        # NOTE: This typo of "layrnorm" is not fixed on purpose to match
        # the original transformers code and name of the model weights.
        self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

        self.encoder = CLIPEncoder(
            config=config,
            quant_config=quant_config,
            num_hidden_layers_override=num_hidden_layers_override,
            prefix=f"{prefix}.encoder",
        )

        num_hidden_layers = config.num_hidden_layers
        if len(self.encoder.layers) > config.num_hidden_layers:
            raise ValueError(
                f"The original encoder only has {num_hidden_layers} "
                f"layers, but you requested {len(self.encoder.layers)} layers.")

        # If possible, skip post_layernorm to conserve memory
        if require_post_norm is None:
            require_post_norm = len(self.encoder.layers) == num_hidden_layers
        if require_post_norm:
            self.post_layernorm = nn.LayerNorm(embed_dim,
                                               eps=config.layer_norm_eps)
        else:
            self.post_layernorm = None

    def forward(
        self,
        pixel_values: torch.Tensor,
        feature_sample_layers: Optional[list[int]] = None,
    ) -> torch.Tensor:

        hidden_states = self.embeddings(pixel_values)
        hidden_states = self.pre_layrnorm(hidden_states)

        return_all_hidden_states = feature_sample_layers is not None

        # Produces either the last layer output or all of the hidden states,
        # depending on if we have feature_sample_layers or not
        encoder_outputs = self.encoder(
            inputs_embeds=hidden_states,
            return_all_hidden_states=return_all_hidden_states)

        if not return_all_hidden_states:
            encoder_outputs = encoder_outputs[0]

        # Handle post-norm (if applicable) and stacks feature layers if needed
        encoder_outputs = resolve_visual_encoder_outputs(
            encoder_outputs, feature_sample_layers, self.post_layernorm,
            self.config.num_hidden_layers)

        return encoder_outputs
class CLIPVisionModel(ImageEncoder):
    config_class = CLIPVisionConfig
    main_input_name = "pixel_values"

    packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}

    def __init__(self, config: CLIPVisionConfig) -> None:
        super().__init__(config)

        self.vision_model = CLIPVisionTransformer(
            config=config,
            quant_config=config.quant_config,
            num_hidden_layers_override=config.num_hidden_layers_override,
            require_post_norm=config.require_post_norm,
            prefix=f"{config.prefix}.vision_model")

    def forward(
        self,
        pixel_values: torch.Tensor,
        feature_sample_layers: Optional[list[int]] = None,
        **kwargs,
    ) -> BaseEncoderOutput:
        last_hidden_state = self.vision_model(pixel_values,
                                              feature_sample_layers)
        return BaseEncoderOutput(last_hidden_state=last_hidden_state)

    @property
    def device(self):
        return next(self.parameters()).device

    # (TODO) Add prefix argument for filtering out weights to be loaded
    #        ref: https://github.com/vllm-project/vllm/pull/7186#discussion_r1734163986
    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()
        layer_count = len(self.vision_model.encoder.layers)

        for name, loaded_weight in weights:
            if name.startswith("visual_projection"):
                continue
            # post_layernorm is not needed in CLIPVisionModel
            if (name.startswith("vision_model.post_layernorm")
                    and self.vision_model.post_layernorm is None):
                continue

            # omit layers when num_hidden_layers_override is set
            if name.startswith("vision_model.encoder.layers"):
                layer_idx = int(name.split(".")[3])
                if layer_idx >= layer_count:
                    continue

            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params
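Both load_weights methods above fold the separate q_proj / k_proj / v_proj checkpoint tensors into the single fused qkv_proj parameter by renaming the checkpoint key and passing a shard id to the parameter's weight_loader. A rough, standalone sketch of only the renaming step (the checkpoint key below is a plausible example, and the actual shard copy is done by each parameter's weight_loader):

stacked_params_mapping = [
    ("qkv_proj", "q_proj", "q"),
    ("qkv_proj", "k_proj", "k"),
    ("qkv_proj", "v_proj", "v"),
]

checkpoint_name = "vision_model.encoder.layers.0.self_attn.k_proj.weight"
for param_name, weight_name, shard_id in stacked_params_mapping:
    if weight_name in checkpoint_name:
        model_param_name = checkpoint_name.replace(weight_name, param_name)
        # -> "vision_model.encoder.layers.0.self_attn.qkv_proj.weight", shard "k"
        break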
FastVideo-main/fastvideo/v1/models/encoders/llama.py
new file mode 100644
# SPDX-License-Identifier: Apache-2.0
# Adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/model_executor/models/llama.py
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only LLaMA model compatible with HuggingFace weights."""
from typing import Any, Dict, Iterable, Optional, Set, Tuple

import torch
from torch import nn

# from vllm.model_executor.layers.quantization import QuantizationConfig
from fastvideo.v1.attention import LocalAttention
# from ..utils import (extract_layer_index)
from fastvideo.v1.configs.models.encoders import BaseEncoderOutput, LlamaConfig
from fastvideo.v1.configs.quantization import QuantizationConfig
from fastvideo.v1.distributed import get_tensor_model_parallel_world_size
from fastvideo.v1.layers.activation import SiluAndMul
from fastvideo.v1.layers.layernorm import RMSNorm
from fastvideo.v1.layers.linear import (MergedColumnParallelLinear,
                                        QKVParallelLinear, RowParallelLinear)
from fastvideo.v1.layers.rotary_embedding import get_rope
from fastvideo.v1.layers.vocab_parallel_embedding import VocabParallelEmbedding
from fastvideo.v1.models.encoders.base import TextEncoder
from fastvideo.v1.models.loader.weight_utils import (default_weight_loader,
                                                     maybe_remap_kv_scale_name)
class LlamaMLP(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        bias: bool = False,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.gate_up_proj = MergedColumnParallelLinear(
            input_size=hidden_size,
            output_sizes=[intermediate_size] * 2,
            # output_size=intermediate_size,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.down_proj = RowParallelLinear(
            input_size=intermediate_size,
            output_size=hidden_size,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.down_proj",
        )
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.act_fn = SiluAndMul()

    def forward(self, x):
        x, _ = self.gate_up_proj(x)
        x = self.act_fn(x)
        x, _ = self.down_proj(x)
        return x
class LlamaAttention(nn.Module):

    def __init__(self,
                 config: LlamaConfig,
                 hidden_size: int,
                 num_heads: int,
                 num_kv_heads: int,
                 rope_theta: float = 10000,
                 rope_scaling: Optional[Dict[str, Any]] = None,
                 max_position_embeddings: int = 8192,
                 quant_config: Optional[QuantizationConfig] = None,
                 bias: bool = False,
                 bias_o_proj: bool = False,
                 prefix: str = "") -> None:
        super().__init__()
        # layer_idx = extract_layer_index(prefix)
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        # MistralConfig has an optional head_dim introduced by Mistral-Nemo
        self.head_dim = getattr(config, "head_dim",
                                self.hidden_size // self.total_num_heads)
        # Phi models introduced a partial_rotary_factor parameter in the config
        partial_rotary_factor = getattr(config, "partial_rotary_factor", 1)
        self.rotary_dim = int(partial_rotary_factor * self.head_dim)
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        self.qkv_proj = QKVParallelLinear(
            hidden_size=hidden_size,
            head_size=self.head_dim,
            total_num_heads=self.total_num_heads,
            total_num_kv_heads=self.total_num_kv_heads,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

        self.o_proj = RowParallelLinear(
            input_size=self.total_num_heads * self.head_dim,
            output_size=hidden_size,
            bias=bias_o_proj,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        is_neox_style = True
        is_gguf = quant_config and hasattr(
            quant_config, "get_name") and quant_config.get_name() == "gguf"
        if is_gguf and config.model_type == "llama":
            is_neox_style = False

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.rotary_dim,
            max_position=max_position_embeddings,
            base=int(rope_theta),
            rope_scaling=rope_scaling,
            is_neox_style=is_neox_style,
        )

        self.attn = LocalAttention(
            self.num_heads,
            self.head_dim,
            self.num_kv_heads,
            softmax_scale=self.scaling,
            causal=True,
            supported_attention_backends=config._supported_attention_backends)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        # attn_output = self.attn(q, k, v)

        # use flash_attn_func
        # TODO (Attn abstraction and backend)
        # from flash_attn import flash_attn_func
        # reshape q, k, v to (batch_size, seq_len, num_heads, head_dim)
        batch_size = q.shape[0]
        seq_len = q.shape[1]
        q = q.reshape(batch_size, seq_len, self.num_heads, self.head_dim)
        k = k.reshape(batch_size, seq_len, self.num_kv_heads, self.head_dim)
        v = v.reshape(batch_size, seq_len, self.num_kv_heads, self.head_dim)
        # import pdb; pdb.set_trace()
        # attn_output = flash_attn_func(q, k, v, softmax_scale=self.scaling, causal=True)
        attn_output = self.attn(q, k, v)
        attn_output = attn_output.reshape(batch_size, seq_len,
                                          self.num_heads * self.head_dim)
        output, _ = self.o_proj(attn_output)
        return output
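The head bookkeeping above decides how much of the fused QKV output each tensor-parallel rank handles. A small worked example of just that arithmetic, with assumed sizes (32 query heads, 8 KV heads, head_dim 128, tp_size 4; these numbers are illustrative, not taken from any specific checkpoint):

total_num_heads, total_num_kv_heads, head_dim, tp_size = 32, 8, 128, 4

num_heads = total_num_heads // tp_size                 # 8 query heads per rank
num_kv_heads = max(1, total_num_kv_heads // tp_size)   # 2 KV heads per rank
q_size = num_heads * head_dim                          # 1024
kv_size = num_kv_heads * head_dim                      # 256

# the per-rank qkv_proj output is then split as [q_size, kv_size, kv_size]
assert q_size + 2 * kv_size == (num_heads + 2 * num_kv_heads) * head_dim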
class LlamaDecoderLayer(nn.Module):

    def __init__(
        self,
        config: LlamaConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        if rope_scaling is not None and getattr(
                config, "original_max_position_embeddings", None):
            rope_scaling["original_max_position_embeddings"] = (
                config.original_max_position_embeddings)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)
        # Support abacusai/Smaug-72B-v0.1 with attention_bias
        # Support internlm/internlm-7b with bias
        attention_bias = getattr(config, "attention_bias", False) or getattr(
            config, "bias", False)
        bias_o_proj = attention_bias
        # support internlm/internlm3-8b with qkv_bias
        if hasattr(config, 'qkv_bias'):
            attention_bias = config.qkv_bias

        self.self_attn = LlamaAttention(
            config=config,
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=getattr(config, "num_key_value_heads",
                                 config.num_attention_heads),
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            quant_config=quant_config,
            bias=attention_bias,
            bias_o_proj=bias_o_proj,
            prefix=f"{prefix}.self_attn",
        )
        self.mlp = LlamaMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            bias=getattr(config, "mlp_bias", False),
            prefix=f"{prefix}.mlp",
        )
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        hidden_states = self.self_attn(positions=positions,
                                       hidden_states=hidden_states)

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual
class LlamaModel(TextEncoder):

    def __init__(
        self,
        config: LlamaConfig,
    ):
        super().__init__(config)

        self.config = config
        self.quant_config = self.config.quant_config
        if config.lora_config is not None:
            max_loras = 1
            lora_vocab_size = 1
            if hasattr(config.lora_config, "max_loras"):
                max_loras = config.lora_config.max_loras
            if hasattr(config.lora_config, "lora_extra_vocab_size"):
                lora_vocab_size = config.lora_config.lora_extra_vocab_size
            lora_vocab = lora_vocab_size * max_loras
        else:
            lora_vocab = 0
        self.vocab_size = config.vocab_size + lora_vocab
        self.org_vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
            quant_config=config.quant_config,
        )

        self.layers = nn.ModuleList([
            LlamaDecoderLayer(config=config,
                              quant_config=config.quant_config,
                              prefix=f"{config.prefix}.layers.{i}")
            for i in range(config.num_hidden_layers)
        ])

        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        position_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        **kwargs,
    ) -> BaseEncoderOutput:
        output_hidden_states = (output_hidden_states
                                if output_hidden_states is not None else
                                self.config.output_hidden_states)
        if inputs_embeds is not None:
            hidden_states = inputs_embeds
        else:
            hidden_states = self.get_input_embeddings(input_ids)
        residual = None

        if position_ids is None:
            position_ids = torch.arange(
                0, hidden_states.shape[1],
                device=hidden_states.device).unsqueeze(0)

        all_hidden_states: Optional[Tuple[Any, ...]] = (
        ) if output_hidden_states else None
        for layer in self.layers:
            if all_hidden_states is not None:
                # TODO
                all_hidden_states += (
                    hidden_states, ) if residual is None else (hidden_states +
                                                               residual, )
            hidden_states, residual = layer(position_ids, hidden_states,
                                            residual)

        hidden_states, _ = self.norm(hidden_states, residual)

        # add hidden states from the last decoder layer
        if all_hidden_states is not None:
            all_hidden_states += (hidden_states, )

        # TODO(will): maybe unify the output format with other models and use
        # our own class
        output = BaseEncoderOutput(
            last_hidden_state=hidden_states,
            # past_key_values=past_key_values if use_cache else None,
            hidden_states=all_hidden_states,
            # attentions=all_self_attns,
        )
        return output

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            if ("rotary_emb.cos_cached" in name
                    or "rotary_emb.sin_cached" in name):
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue
            # if (self.quant_config is not None and
            #         (scale_name := self.quant_config.get_cache_scale(name))):
            #     # Loading kv cache quantization scales
            #     param = params_dict[scale_name]
            #     weight_loader = getattr(param, "weight_loader",
            #                             default_weight_loader)
            #     loaded_weight = (loaded_weight if loaded_weight.dim() == 0 else
            #                      loaded_weight[0])
            #     weight_loader(param, loaded_weight)
            #     loaded_params.add(scale_name)
            #     continue
            if "scale" in name:
                # Remapping the name of FP8 kv-scale.
                kv_scale_name: Optional[str] = maybe_remap_kv_scale_name(
                    name, params_dict)
                if kv_scale_name is None:
                    continue
                else:
                    name = kv_scale_name
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params
FastVideo-main/fastvideo/v1/models/encoders/t5.py
new file mode 100644
# SPDX-License-Identifier: Apache-2.0
# Adapted from transformers: https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/t5/modeling_t5.py
# Derived from T5 implementation posted on HuggingFace; license below:
#
# coding=utf-8
# Copyright 2018 Mesh TensorFlow authors, T5 Authors and HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch T5 & UMT5 model."""
import math
from dataclasses import dataclass
from typing import Iterable, Optional, Set, Tuple

import torch
import torch.nn.functional as F
from torch import nn

from fastvideo.v1.configs.models.encoders import BaseEncoderOutput, T5Config
from fastvideo.v1.configs.quantization import QuantizationConfig
from fastvideo.v1.distributed import (get_tensor_model_parallel_rank,
                                      get_tensor_model_parallel_world_size)
from fastvideo.v1.layers.activation import get_act_fn
from fastvideo.v1.layers.layernorm import RMSNorm
from fastvideo.v1.layers.linear import (MergedColumnParallelLinear,
                                        QKVParallelLinear, RowParallelLinear)
from fastvideo.v1.layers.vocab_parallel_embedding import VocabParallelEmbedding
from fastvideo.v1.models.encoders.base import TextEncoder
from fastvideo.v1.models.loader.weight_utils import default_weight_loader
class AttentionType:
    """
    Attention type.
    Use string to be compatible with `torch.compile`.
    """
    # Decoder attention between previous layer Q/K/V
    DECODER = "decoder"
    # Encoder attention between previous layer Q/K/V for encoder-decoder
    ENCODER = "encoder"
    # Encoder attention between previous layer Q/K/V
    ENCODER_ONLY = "encoder_only"
    # Attention between dec. Q and enc. K/V for encoder-decoder
    ENCODER_DECODER = "encoder_decoder"


@dataclass
class AttentionMetadata:
    attn_bias: torch.Tensor
class T5DenseActDense(nn.Module):

    def __init__(self,
                 config: T5Config,
                 quant_config: Optional[QuantizationConfig] = None):
        super().__init__()
        self.wi = MergedColumnParallelLinear(config.d_model, [config.d_ff],
                                             bias=False)
        self.wo = RowParallelLinear(config.d_ff,
                                    config.d_model,
                                    bias=False,
                                    quant_config=quant_config)
        self.act = get_act_fn(config.dense_act_fn)

    def forward(self, hidden_states) -> torch.Tensor:
        hidden_states, _ = self.wi(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states, _ = self.wo(hidden_states)
        return hidden_states
class T5DenseGatedActDense(nn.Module):

    def __init__(self,
                 config: T5Config,
                 quant_config: Optional[QuantizationConfig] = None):
        super().__init__()
        self.wi_0 = MergedColumnParallelLinear(config.d_model, [config.d_ff],
                                               bias=False,
                                               quant_config=quant_config)
        self.wi_1 = MergedColumnParallelLinear(config.d_model, [config.d_ff],
                                               bias=False,
                                               quant_config=quant_config)
        # Should not run in fp16 unless mixed-precision is used,
        # see https://github.com/huggingface/transformers/issues/20287.
        self.wo = RowParallelLinear(config.d_ff,
                                    config.d_model,
                                    bias=False,
                                    quant_config=quant_config)
        self.act = get_act_fn(config.dense_act_fn)

    def forward(self, hidden_states) -> torch.Tensor:
        hidden_gelu = self.act(self.wi_0(hidden_states)[0])
        hidden_linear, _ = self.wi_1(hidden_states)
        hidden_states = hidden_gelu * hidden_linear
        hidden_states, _ = self.wo(hidden_states)
        return hidden_states
class T5LayerFF(nn.Module):

    def __init__(self,
                 config: T5Config,
                 quant_config: Optional[QuantizationConfig] = None):
        super().__init__()
        if config.is_gated_act:
            self.DenseReluDense = T5DenseGatedActDense(
                config, quant_config=quant_config)
        else:
            self.DenseReluDense = T5DenseActDense(config,
                                                  quant_config=quant_config)
        self.layer_norm = RMSNorm(config.d_model,
                                  eps=config.layer_norm_epsilon)

    def forward(self, hidden_states) -> torch.Tensor:
        forwarded_states = self.layer_norm.forward_native(hidden_states)
        forwarded_states = self.DenseReluDense(forwarded_states)
        hidden_states = hidden_states + forwarded_states
        return hidden_states
# T5 has attn_bias and does not use softmax scaling
class T5MultiHeadAttention(nn.Module):

    def __init__(self) -> None:
        super().__init__()

    def forward(self, q, k, v, attn_bias=None):
        b, _, n, c = q.shape
        attn = torch.einsum('binc,bjnc->bnij', q, k)
        if attn_bias is not None:
            attn += attn_bias
        attn = F.softmax(attn.float(), dim=-1).type_as(attn)
        x = torch.einsum('bnij,bjnc->binc', attn, v)
        x = x.reshape(b, -1, n * c)
        return x
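The einsum notation above computes per-head attention with tensors laid out as (batch, sequence, heads, channels) and without a softmax scale, matching T5. A quick standalone shape check on random tensors (purely illustrative sizes):

import torch
import torch.nn.functional as F

b, i, j, n, c = 2, 5, 5, 4, 8   # batch, query len, key len, heads, head dim
q = torch.randn(b, i, n, c)
k = torch.randn(b, j, n, c)
v = torch.randn(b, j, n, c)

attn = torch.einsum('binc,bjnc->bnij', q, k)          # (b, n, i, j) logits
attn = F.softmax(attn.float(), dim=-1).type_as(attn)
out = torch.einsum('bnij,bjnc->binc', attn, v)        # (b, i, n, c)
out = out.reshape(b, -1, n * c)                       # (b, i, n*c)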
class T5Attention(nn.Module):

    def __init__(self,
                 config: T5Config,
                 attn_type: str,
                 has_relative_attention_bias=False,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__()
        self.attn_type = attn_type
        # Cross-attention has no relative pos encoding anyway
        self.is_decoder = attn_type == AttentionType.DECODER
        self.has_relative_attention_bias = has_relative_attention_bias
        self.relative_attention_num_buckets = \
            config.relative_attention_num_buckets
        self.relative_attention_max_distance = \
            config.relative_attention_max_distance
        self.d_model = config.d_model
        self.key_value_proj_dim = config.d_kv
        self.total_num_heads = self.total_num_kv_heads = config.num_heads
        # Partition heads across multiple tensor parallel GPUs.
        tp_world_size = get_tensor_model_parallel_world_size()
        assert config.num_heads % tp_world_size == 0
        self.n_heads = config.num_heads // tp_world_size
        self.inner_dim = self.n_heads * self.key_value_proj_dim

        # No GQA in t5.
        # self.n_kv_heads = self.n_heads

        self.qkv_proj = QKVParallelLinear(
            self.d_model,
            self.d_model // self.total_num_heads,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

        self.attn = T5MultiHeadAttention()

        if self.has_relative_attention_bias:
            self.relative_attention_bias = \
                VocabParallelEmbedding(self.relative_attention_num_buckets,
                                       self.total_num_heads,
                                       org_num_embeddings=self.relative_attention_num_buckets,
                                       padding_size=self.relative_attention_num_buckets,
                                       quant_config=quant_config)

        self.o = RowParallelLinear(
            self.d_model,
            self.d_model,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

    @staticmethod
    def _relative_position_bucket(relative_position,
                                  bidirectional=True,
                                  num_buckets=32,
                                  max_distance=128) -> torch.Tensor:
        """
        Adapted from Mesh Tensorflow:
        https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593

        Translate relative position to a bucket number for relative attention.
        The relative position is defined as memory_position - query_position,
        i.e. the distance in tokens from the attending position to the
        attended-to position. If bidirectional=False, then positive relative
        positions are invalid. We use smaller buckets for small absolute
        relative_position and larger buckets for larger absolute
        relative_positions. All relative positions >=max_distance map to the
        same bucket. All relative positions <=-max_distance map to the same
        bucket. This should allow for more graceful generalization to longer
        sequences than the model has been trained on

        Args:
            relative_position: an int32 Tensor
            bidirectional: a boolean - whether the attention is bidirectional
            num_buckets: an integer
            max_distance: an integer

        Returns:
            a Tensor with the same shape as relative_position, containing int32
            values in the range [0, num_buckets)
        """  # noqa: E501
        relative_buckets = 0
        if bidirectional:
            num_buckets //= 2
            relative_buckets += (relative_position
                                 > 0).to(torch.long) * num_buckets
            relative_position = torch.abs(relative_position)
        else:
            relative_position = -torch.min(
                relative_position, torch.zeros_like(relative_position))
        # now relative_position is in the range [0, inf)

        # half of the buckets are for exact increments in positions
        max_exact = num_buckets // 2
        is_small = relative_position < max_exact

        # The other half of the buckets are for logarithmically bigger bins
        # in positions up to max_distance
        relative_position_if_large = max_exact + (
            torch.log(relative_position.float() / max_exact) /
            math.log(max_distance / max_exact) *
            (num_buckets - max_exact)).to(torch.long)
        relative_position_if_large = torch.min(
            relative_position_if_large,
            torch.full_like(relative_position_if_large, num_buckets - 1))

        relative_buckets += torch.where(is_small, relative_position,
                                        relative_position_if_large)
        return relative_buckets

    def compute_bias(self, query_length, key_length,
                     device=None) -> torch.Tensor:
        """Compute binned relative position bias"""
        if device is None:
            device = self.relative_attention_bias.weight.device
        context_position = torch.arange(query_length,
                                        dtype=torch.long,
                                        device=device)[:, None]
        memory_position = torch.arange(key_length,
                                       dtype=torch.long,
                                       device=device)[None, :]
        # max_seq_len, nh
        relative_position = memory_position - context_position
        relative_position_bucket = self._relative_position_bucket(
            relative_position,  # shape (query_length, key_length)
            bidirectional=(not self.is_decoder),
            num_buckets=self.relative_attention_num_buckets,
            max_distance=self.relative_attention_max_distance,
        )
        values = self.relative_attention_bias(relative_position_bucket)
        # shape (query_length, key_length, num_heads)
        x = values.permute([2, 0, 1]).unsqueeze(0)
        # shape (1, num_heads, query_length, key_length)
        return x

    def forward(
        self,
        hidden_states: torch.Tensor,  # (num_tokens, d_model)
        attention_mask: torch.Tensor,
        attn_metadata: Optional[AttentionMetadata] = None,
    ) -> torch.Tensor:
        bs, seq_len, _ = hidden_states.shape
        num_seqs = bs
        n, c = self.n_heads, self.d_model // self.total_num_heads
        qkv, _ = self.qkv_proj(hidden_states)
        # Projection of 'own' hidden state (self-attention). No GQA here.
        q, k, v = qkv.split(self.inner_dim, dim=-1)
        q = q.reshape(bs, seq_len, n, c)
        k = k.reshape(bs, seq_len, n, c)
        v = v.reshape(bs, seq_len, n, c)
        assert attn_metadata is not None
        attn_bias = attn_metadata.attn_bias
        # Not compatible with CP here (as all encoder-decoder models),
        # as it assumes homogeneous batch (prefills or decodes).
        if self.has_relative_attention_bias:
            # Self-attention. Compute T5 relative positional encoding.
            # The bias term is computed on longest sequence in batch. Biases
            # for shorter sequences are slices of the longest.
            assert self.attn_type == AttentionType.ENCODER
            attn_bias = self.compute_bias(seq_len,
                                          seq_len).repeat(num_seqs, 1, 1, 1)
            attn_metadata.attn_bias = attn_bias
        else:
            # Encoder/Decoder Self-Attention Layer, attn bias already cached.
            assert attn_bias is not None

        if attention_mask is not None:
            attention_mask = attention_mask.view(
                bs, 1, 1,
                -1) if attention_mask.ndim == 2 else attention_mask.unsqueeze(1)
            attn_bias.masked_fill_(attention_mask == 0,
                                   torch.finfo(q.dtype).min)

        if get_tensor_model_parallel_world_size() > 1:
            rank = get_tensor_model_parallel_rank()
            attn_bias = attn_bias[:, rank * self.n_heads:(rank + 1) *
                                  self.n_heads, :, :]

        attn_output = self.attn(q, k, v, attn_bias)
        output, _ = self.o(attn_output)
        return output
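To see what the bucketing in _relative_position_bucket produces, here is a small standalone call on a hand-built relative-position grid (it assumes the T5Attention class above is importable; the sizes are illustrative):

import torch

query_len, key_len = 4, 4
context_position = torch.arange(query_len)[:, None]
memory_position = torch.arange(key_len)[None, :]
relative_position = memory_position - context_position   # (4, 4), values in [-3, 3]

buckets = T5Attention._relative_position_bucket(
    relative_position, bidirectional=True, num_buckets=32, max_distance=128)
# buckets has the same shape as relative_position; small distances keep their
# own bucket, while larger ones share logarithmically spaced buckets, and
# positive (future) offsets land in a separate half of the bucket range.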
class T5LayerSelfAttention(nn.Module):

    def __init__(
        self,
        config,
        has_relative_attention_bias=False,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.SelfAttention = T5Attention(
            config,
            AttentionType.DECODER
            if "decoder" in prefix else AttentionType.ENCODER,
            has_relative_attention_bias=has_relative_attention_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.SelfAttention")
        self.layer_norm = RMSNorm(config.d_model,
                                  eps=config.layer_norm_epsilon)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        attn_metadata: Optional[AttentionMetadata] = None,
    ) -> torch.Tensor:
        normed_hidden_states = self.layer_norm.forward_native(hidden_states)
        attention_output = self.SelfAttention(
            hidden_states=normed_hidden_states,
            attention_mask=attention_mask,
            attn_metadata=attn_metadata,
        )
        hidden_states = hidden_states + attention_output
        return hidden_states
class T5LayerCrossAttention(nn.Module):

    def __init__(self,
                 config,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__()
        self.EncDecAttention = T5Attention(config,
                                           AttentionType.ENCODER_DECODER,
                                           has_relative_attention_bias=False,
                                           quant_config=quant_config,
                                           prefix=f"{prefix}.EncDecAttention")
        self.layer_norm = RMSNorm(config.d_model,
                                  eps=config.layer_norm_epsilon)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attn_metadata: Optional[AttentionMetadata] = None,
    ) -> torch.Tensor:
        normed_hidden_states = self.layer_norm.forward_native(hidden_states)
        attention_output = self.EncDecAttention(
            hidden_states=normed_hidden_states,
            attn_metadata=attn_metadata,
        )
        hidden_states = hidden_states + attention_output
        return hidden_states
class T5Block(nn.Module):

    def __init__(self,
                 config: T5Config,
                 is_decoder: bool,
                 has_relative_attention_bias=False,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__()
        self.is_decoder = is_decoder
        self.layer = nn.ModuleList()
        self.layer.append(
            T5LayerSelfAttention(
                config,
                has_relative_attention_bias=has_relative_attention_bias,
                quant_config=quant_config,
                prefix=f"{prefix}.self_attn"))

        if self.is_decoder:
            self.layer.append(
                T5LayerCrossAttention(config,
                                      quant_config=quant_config,
                                      prefix=f"{prefix}.cross_attn"))

        self.layer.append(T5LayerFF(config, quant_config=quant_config))

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        attn_metadata: Optional[AttentionMetadata] = None,
    ) -> torch.Tensor:
        hidden_states = self.layer[0](hidden_states=hidden_states,
                                      attention_mask=attention_mask,
                                      attn_metadata=attn_metadata)
        if self.is_decoder:
            hidden_states = self.layer[1](hidden_states=hidden_states,
                                          attn_metadata=attn_metadata)
            # Apply Feed Forward layer
            hidden_states = self.layer[2](hidden_states)
        else:
            hidden_states = self.layer[1](hidden_states)
        return hidden_states
class T5Stack(nn.Module):

    def __init__(self,
                 config: T5Config,
                 is_decoder: bool,
                 n_layers: int,
                 embed_tokens=None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = "",
                 is_umt5: bool = False):
        super().__init__()
        self.embed_tokens = embed_tokens
        self.is_umt5 = is_umt5
        if is_umt5:
            self.block = nn.ModuleList([
                T5Block(config,
                        is_decoder=is_decoder,
                        has_relative_attention_bias=True,
                        quant_config=quant_config,
                        prefix=f"{prefix}.blocks.{i}") for i in range(n_layers)
            ])
        else:
            # Only the first block has relative positional encoding.
            self.block = nn.ModuleList([
                T5Block(config,
                        is_decoder=is_decoder,
                        has_relative_attention_bias=i == 0,
                        quant_config=quant_config,
                        prefix=f"{prefix}.blocks.{i}") for i in range(n_layers)
            ])
        self.final_layer_norm = RMSNorm(config.d_model,
                                        eps=config.layer_norm_epsilon)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        hidden_states = self.embed_tokens(input_ids)
        for idx, block in enumerate(self.block):
            hidden_states = block(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                attn_metadata=attn_metadata,
            )
        hidden_states = self.final_layer_norm.forward_native(hidden_states)
        return hidden_states
class T5EncoderModel(TextEncoder):

    def __init__(self, config: T5Config, prefix: str = ""):
        super().__init__(config)
        quant_config = None

        self.shared = VocabParallelEmbedding(
            config.vocab_size,
            config.d_model,
            org_num_embeddings=config.vocab_size)

        self.encoder = T5Stack(config,
                               False,
                               config.num_layers,
                               self.shared,
                               quant_config=quant_config,
                               prefix=f"{prefix}.encoder",
                               is_umt5=False)

    def get_input_embeddings(self):
        return self.shared

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        position_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        **kwargs,
    ) -> BaseEncoderOutput:
        attn_metadata = AttentionMetadata(None)
        hidden_states = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            attn_metadata=attn_metadata,
        )
        return BaseEncoderOutput(last_hidden_state=hidden_states)

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q", "q"),
            (".qkv_proj", ".k", "k"),
            (".qkv_proj", ".v", "v"),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()
        for name, loaded_weight in weights:
            loaded = False
            if "decoder" in name or "lm_head" in name:
                continue
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                loaded = True
                break
            if not loaded:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params
class UMT5EncoderModel(TextEncoder):

    def __init__(self, config: T5Config, prefix: str = ""):
        super().__init__(config)
        quant_config = None

        self.shared = VocabParallelEmbedding(
            config.vocab_size,
            config.d_model,
            org_num_embeddings=config.vocab_size)

        self.encoder = T5Stack(config,
                               False,
                               config.num_layers,
                               self.shared,
                               quant_config=quant_config,
                               prefix=f"{prefix}.encoder",
                               is_umt5=True)

    def get_input_embeddings(self):
        return self.shared

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        position_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        **kwargs,
    ) -> BaseEncoderOutput:
        attn_metadata = AttentionMetadata(None)
        hidden_states = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            attn_metadata=attn_metadata,
        )
        return BaseEncoderOutput(
            last_hidden_state=hidden_states,
            attention_mask=attention_mask,
        )

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q", "q"),
            (".qkv_proj", ".k", "k"),
            (".qkv_proj", ".v", "v"),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()
        for name, loaded_weight in weights:
            loaded = False
            if "decoder" in name or "lm_head" in name:
                continue
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                loaded = True
                break
            if not loaded:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params
FastVideo-main/fastvideo/v1/models/encoders/vision.py
new file mode 100644
# SPDX-License-Identifier: Apache-2.0
# Adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/model_executor/models/vision.py
from abc import ABC, abstractmethod
from typing import Generic, Optional, TypeVar, Union

import torch
from transformers import PretrainedConfig

from fastvideo.v1.logger import init_logger

logger = init_logger(__name__)

_C = TypeVar("_C", bound=PretrainedConfig)
class VisionEncoderInfo(ABC, Generic[_C]):

    def __init__(self, vision_config: _C) -> None:
        super().__init__()
        self.vision_config = vision_config

    @abstractmethod
    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        raise NotImplementedError

    @abstractmethod
    def get_max_image_tokens(self) -> int:
        raise NotImplementedError

    @abstractmethod
    def get_image_size(self) -> int:
        raise NotImplementedError

    @abstractmethod
    def get_patch_size(self) -> int:
        raise NotImplementedError

    @abstractmethod
    def get_patch_grid_length(self) -> int:
        raise NotImplementedError
def resolve_visual_encoder_outputs(
    encoder_outputs: Union[torch.Tensor, list[torch.Tensor]],
    feature_sample_layers: Optional[list[int]],
    post_layer_norm: Optional[torch.nn.LayerNorm],
    max_possible_layers: int,
) -> torch.Tensor:
    """Given the outputs of a visual encoder module that may correspond to the
    output of the last layer, or a list of hidden states to be stacked,
    handle post normalization and resolve it into a single output tensor.

    Args:
        encoder_outputs: Output of encoder's last layer or all hidden states.
        feature_sample_layers: Optional layer indices to grab from the encoder
            outputs; if provided, encoder outputs must be a list.
        post_layer_norm: Post norm to apply to the output of the encoder.
        max_possible_layers: Total layers in the fully loaded visual encoder.
    """
    if feature_sample_layers is None:
        if post_layer_norm is not None:
            return post_layer_norm(encoder_outputs)
        return encoder_outputs

    # Get the hidden states corresponding to the layer indices.
    # Negative values are relative to the full visual encoder,
    # so offset them depending on how many layers were loaded.
    # NOTE: this assumes that encoder_outputs is a list containing
    # the inputs to the visual encoder, followed by the hidden states
    # of each layer.
    num_loaded_layers = len(encoder_outputs) - 1
    offset = max_possible_layers - num_loaded_layers
    hs_pool = [
        encoder_outputs[layer_idx]
        if layer_idx >= 0 else encoder_outputs[layer_idx + offset]
        for layer_idx in feature_sample_layers
    ]

    # Apply post-norm on the final hidden state if we are using it
    uses_last_layer = feature_sample_layers[-1] in (len(hs_pool) - 1, -1)
    if post_layer_norm is not None and uses_last_layer:
        hs_pool[-1] = post_layer_norm(encoder_outputs)
    return torch.cat(hs_pool, dim=-1)
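One way to read the index handling above: a negative feature layer index is defined relative to the fully loaded encoder, so when trailing layers were skipped it is shifted by the number of missing layers before indexing the hidden-state list. A rough sketch of only that arithmetic (the layer counts below are assumed for illustration):

max_possible_layers = 24       # depth of the full vision tower
num_loaded_layers = 23         # last layer skipped via num_hidden_layers_override
offset = max_possible_layers - num_loaded_layers   # 1

# encoder_outputs = [embeddings, layer_1, ..., layer_23] -> 24 entries
feature_sample_layers = [-2]
resolved = [idx if idx >= 0 else idx + offset for idx in feature_sample_layers]
# [-2] -> [-1]: the "-2nd" layer of the full tower is the last loaded layer.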
FastVideo-main/fastvideo/v1/models/hf_transformer_utils.py
new file mode 100644
# SPDX-License-Identifier: Apache-2.0
# Adapted from SGLang: https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/hf_transformers_utils.py
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for Huggingface Transformers."""
import contextlib
import json
import os
from pathlib import Path
from typing import Any, Dict, Optional, Type, Union

from huggingface_hub import snapshot_download
from transformers import AutoConfig, PretrainedConfig
from transformers.models.auto.modeling_auto import (
    MODEL_FOR_CAUSAL_LM_MAPPING_NAMES)

_CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
    # ChatGLMConfig.model_type: ChatGLMConfig,
    # DbrxConfig.model_type: DbrxConfig,
    # ExaoneConfig.model_type: ExaoneConfig,
    # Qwen2_5_VLConfig.model_type: Qwen2_5_VLConfig,
}

for name, cls in _CONFIG_REGISTRY.items():
    with contextlib.suppress(ValueError):
        AutoConfig.register(name, cls)
def download_from_hf(model_path: str):
    if os.path.exists(model_path):
        return model_path

    return snapshot_download(model_path,
                             allow_patterns=["*.json", "*.bin", "*.model"])
def get_hf_config(
    model: str,
    trust_remote_code: bool,
    revision: Optional[str] = None,
    model_override_args: Optional[dict] = None,
    **kwargs,
):
    is_gguf = check_gguf_file(model)
    if is_gguf:
        raise NotImplementedError("GGUF models are not supported.")

    config = AutoConfig.from_pretrained(model,
                                        trust_remote_code=trust_remote_code,
                                        revision=revision,
                                        **kwargs)
    if config.model_type in _CONFIG_REGISTRY:
        config_class = _CONFIG_REGISTRY[config.model_type]
        config = config_class.from_pretrained(model, revision=revision)
        # NOTE(HandH1998): Qwen2VL requires `_name_or_path` attribute in `config`.
        config._name_or_path = model
    if model_override_args:
        config.update(model_override_args)

    # Special architecture mapping check for GGUF models
    if is_gguf:
        if config.model_type not in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
            raise RuntimeError(
                f"Can't get gguf config for {config.model_type}.")
        model_type = MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[config.model_type]
        config.update({"architectures": [model_type]})

    return config
def get_diffusers_config(
    model: str,
    fastvideo_args: Optional[dict] = None,
) -> Dict[str, Any]:
    """Gets a configuration for the given diffusers model.

    Args:
        model: The model name or path.
        fastvideo_args: Optional inference arguments to override in the config.

    Returns:
        The loaded configuration.
    """
    config_name = "config.json"
    if "scheduler" in model:
        config_name = "scheduler_config.json"

    # Check if the model path exists
    if os.path.exists(model):
        config_file = os.path.join(model, config_name)
        if os.path.exists(config_file):
            try:
                # Load the config directly from the file
                with open(config_file) as f:
                    config_dict: Dict[str, Any] = json.load(f)

                # TODO(will): apply any overrides from inference args

                return config_dict
            except Exception as e:
                raise RuntimeError(
                    f"Failed to load diffusers config from {config_file}: {e}"
                ) from e
        raise RuntimeError(f"Config file not found at {config_file}")
    else:
        raise RuntimeError(f"Diffusers config file not found at {model}")
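A quick usage sketch, assuming a locally downloaded diffusers-style pipeline directory (the paths below are hypothetical, and the printed key is only an example of what a diffusers config.json commonly contains):

# hypothetical local layout: <pipeline>/transformer/config.json,
#                            <pipeline>/scheduler/scheduler_config.json
transformer_cfg = get_diffusers_config("/models/some-pipeline/transformer")
scheduler_cfg = get_diffusers_config("/models/some-pipeline/scheduler")

print(transformer_cfg.get("_class_name"))  # e.g. the registered transformer class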
# Models don't use the same configuration key for determining the maximum
# context length. Store them here so we can sanely check them.
# NOTE: The ordering here is important. Some models have two of these and we
# have a preference for which value gets used.
CONTEXT_LENGTH_KEYS = [
    "max_sequence_length",
    "seq_length",
    "max_seq_len",
    "model_max_length",
    "max_position_embeddings",
]
def attach_additional_stop_token_ids(tokenizer):
    # Special handling for stop token <|eom_id|> generated by llama 3 tool use.
    if "<|eom_id|>" in tokenizer.get_added_vocab():
        tokenizer.additional_stop_token_ids = set(
            [tokenizer.get_added_vocab()["<|eom_id|>"]])
    else:
        tokenizer.additional_stop_token_ids = None
def check_gguf_file(model: Union[str, os.PathLike]) -> bool:
    """Check if the file is a GGUF model."""
    model = Path(model)
    if not model.is_file():
        return False
    elif model.suffix == ".gguf":
        return True

    with open(model, "rb") as f:
        header = f.read(4)

    return header == b"GGUF"
FastVideo-main/fastvideo/v1/models/loader/__init__.py
new file mode 100644
FastVideo-main/fastvideo/v1/models/loader/component_loader.py
new file mode 100644
# SPDX-License-Identifier: Apache-2.0
import dataclasses
import glob
import json
import os
import time
from abc import ABC, abstractmethod
from typing import Any, Generator, Iterable, List, Optional, Tuple, cast

import torch
import torch.nn as nn
from safetensors.torch import load_file as safetensors_load_file
from transformers import AutoImageProcessor, AutoTokenizer
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME

from fastvideo.v1.fastvideo_args import FastVideoArgs
from fastvideo.v1.logger import init_logger
from fastvideo.v1.models.hf_transformer_utils import get_diffusers_config
from fastvideo.v1.models.loader.fsdp_load import load_fsdp_model
from fastvideo.v1.models.loader.utils import set_default_torch_dtype
from fastvideo.v1.models.loader.weight_utils import (
    filter_duplicate_safetensors_files, filter_files_not_needed_for_inference,
    pt_weights_iterator, safetensors_weights_iterator)
from fastvideo.v1.models.registry import ModelRegistry
from fastvideo.v1.utils import PRECISION_TO_TYPE

logger = init_logger(__name__)
class ComponentLoader(ABC):
    """Base class for loading a specific type of model component."""

    def __init__(self, device=None) -> None:
        self.device = device

    @abstractmethod
    def load(self, model_path: str, architecture: str,
             fastvideo_args: FastVideoArgs):
        """
        Load the component based on the model path, architecture, and inference args.

        Args:
            model_path: Path to the component model
            architecture: Architecture of the component model
            fastvideo_args: Inference arguments

        Returns:
            The loaded component
        """
        raise NotImplementedError

    @classmethod
    def for_module_type(cls, module_type: str,
                        transformers_or_diffusers: str) -> 'ComponentLoader':
        """
        Factory method to create a component loader for a specific module type.

        Args:
            module_type: Type of module (e.g., "vae", "text_encoder", "transformer", "scheduler")
            transformers_or_diffusers: Whether the module is from transformers or diffusers

        Returns:
            A component loader for the specified module type
        """
        # Map of module types to their loader classes and expected library
        module_loaders = {
            "scheduler": (SchedulerLoader, "diffusers"),
            "transformer": (TransformerLoader, "diffusers"),
            "vae": (VAELoader, "diffusers"),
            "text_encoder": (TextEncoderLoader, "transformers"),
            "text_encoder_2": (TextEncoderLoader, "transformers"),
            "tokenizer": (TokenizerLoader, "transformers"),
            "tokenizer_2": (TokenizerLoader, "transformers"),
            "image_processor": (ImageProcessorLoader, "transformers"),
            "image_encoder": (ImageEncoderLoader, "transformers"),
        }

        if module_type in module_loaders:
            loader_cls, expected_library = module_loaders[module_type]
            # Assert that the library matches what's expected for this module type
            assert transformers_or_diffusers == expected_library, \
                f"{module_type} must be loaded from {expected_library}, got {transformers_or_diffusers}"
            return loader_cls()

        # For unknown module types, use a generic loader
        logger.warning(
            "No specific loader found for module type: %s. Using generic loader.",
            module_type)
        return GenericComponentLoader(transformers_or_diffusers)
class
TextEncoderLoader
(
ComponentLoader
):
"""Loader for text encoders."""
@
dataclasses
.
dataclass
class
Source
:
"""A source for weights."""
model_or_path
:
str
"""The model ID or path."""
prefix
:
str
=
""
"""A prefix to prepend to all weights."""
fall_back_to_pt
:
bool
=
True
"""Whether .pt weights can be used."""
allow_patterns_overrides
:
Optional
[
list
[
str
]]
=
None
"""If defined, weights will load exclusively using these patterns."""
counter_before_loading_weights
:
float
=
0.0
counter_after_loading_weights
:
float
=
0.0
def
_prepare_weights
(
self
,
model_name_or_path
:
str
,
fall_back_to_pt
:
bool
,
allow_patterns_overrides
:
Optional
[
list
[
str
]],
)
->
Tuple
[
str
,
List
[
str
],
bool
]:
"""Prepare weights for the model.
If the model is not local, it will be downloaded."""
# model_name_or_path = (self._maybe_download_from_modelscope(
# model_name_or_path, revision) or model_name_or_path)
is_local
=
os
.
path
.
isdir
(
model_name_or_path
)
assert
is_local
,
"Model path must be a local directory"
use_safetensors
=
False
index_file
=
SAFE_WEIGHTS_INDEX_NAME
allow_patterns
=
[
"*.safetensors"
,
"*.bin"
]
if
fall_back_to_pt
:
allow_patterns
+=
[
"*.pt"
]
if
allow_patterns_overrides
is
not
None
:
allow_patterns
=
allow_patterns_overrides
hf_folder
=
model_name_or_path
hf_weights_files
:
List
[
str
]
=
[]
for
pattern
in
allow_patterns
:
hf_weights_files
+=
glob
.
glob
(
os
.
path
.
join
(
hf_folder
,
pattern
))
if
len
(
hf_weights_files
)
>
0
:
if
pattern
==
"*.safetensors"
:
use_safetensors
=
True
break
if
use_safetensors
:
hf_weights_files
=
filter_duplicate_safetensors_files
(
hf_weights_files
,
hf_folder
,
index_file
)
else
:
hf_weights_files
=
filter_files_not_needed_for_inference
(
hf_weights_files
)
if
len
(
hf_weights_files
)
==
0
:
raise
RuntimeError
(
f
"Cannot find any model weights with `
{
model_name_or_path
}
`"
)
return
hf_folder
,
hf_weights_files
,
use_safetensors
def
_get_weights_iterator
(
self
,
source
:
"Source"
)
->
Generator
[
Tuple
[
str
,
torch
.
Tensor
],
None
,
None
]:
"""Get an iterator for the model weights based on the load format."""
hf_folder
,
hf_weights_files
,
use_safetensors
=
self
.
_prepare_weights
(
source
.
model_or_path
,
source
.
fall_back_to_pt
,
source
.
allow_patterns_overrides
)
if
use_safetensors
:
weights_iterator
=
safetensors_weights_iterator
(
hf_weights_files
)
else
:
weights_iterator
=
pt_weights_iterator
(
hf_weights_files
)
if
self
.
counter_before_loading_weights
==
0.0
:
self
.
counter_before_loading_weights
=
time
.
perf_counter
()
# Apply the prefix.
return
((
source
.
prefix
+
name
,
tensor
)
for
(
name
,
tensor
)
in
weights_iterator
)
def
_get_all_weights
(
self
,
model_config
:
Any
,
model
:
nn
.
Module
,
)
->
Generator
[
Tuple
[
str
,
torch
.
Tensor
],
None
,
None
]:
primary_weights
=
TextEncoderLoader
.
Source
(
model_config
.
model
,
prefix
=
""
,
fall_back_to_pt
=
getattr
(
model
,
"fall_back_to_pt_during_load"
,
True
),
allow_patterns_overrides
=
getattr
(
model
,
"allow_patterns_overrides"
,
None
),
)
yield
from
self
.
_get_weights_iterator
(
primary_weights
)
secondary_weights
=
cast
(
Iterable
[
TextEncoderLoader
.
Source
],
getattr
(
model
,
"secondary_weights"
,
()),
)
for
source
in
secondary_weights
:
yield
from
self
.
_get_weights_iterator
(
source
)
def
load
(
self
,
model_path
:
str
,
architecture
:
str
,
fastvideo_args
:
FastVideoArgs
):
"""Load the text encoders based on the model path, architecture, and inference args."""
# model_config: PretrainedConfig = get_hf_config(
# model=model_path,
# trust_remote_code=fastvideo_args.trust_remote_code,
# revision=fastvideo_args.revision,
# model_override_args=None,
# )
with
open
(
os
.
path
.
join
(
model_path
,
"config.json"
))
as
f
:
model_config
=
json
.
load
(
f
)
model_config
.
pop
(
"_name_or_path"
,
None
)
model_config
.
pop
(
"transformers_version"
,
None
)
model_config
.
pop
(
"model_type"
,
None
)
model_config
.
pop
(
"tokenizer_class"
,
None
)
model_config
.
pop
(
"torch_dtype"
,
None
)
logger
.
info
(
"HF Model config: %s"
,
model_config
)
# @TODO(Wei): Better way to handle this?
try
:
encoder_config
=
fastvideo_args
.
text_encoder_configs
[
0
]
encoder_config
.
update_model_arch
(
model_config
)
encoder_precision
=
fastvideo_args
.
text_encoder_precisions
[
0
]
except
Exception
:
encoder_config
=
fastvideo_args
.
text_encoder_configs
[
1
]
encoder_config
.
update_model_arch
(
model_config
)
encoder_precision
=
fastvideo_args
.
text_encoder_precisions
[
1
]
target_device
=
torch
.
device
(
fastvideo_args
.
device_str
)
# TODO(will): add support for other dtypes
return
self
.
load_model
(
model_path
,
encoder_config
,
target_device
,
encoder_precision
)
def
load_model
(
self
,
model_path
:
str
,
model_config
,
target_device
:
torch
.
device
,
dtype
:
str
=
"fp16"
):
with
set_default_torch_dtype
(
PRECISION_TO_TYPE
[
dtype
]):
with
target_device
:
architectures
=
getattr
(
model_config
,
"architectures"
,
[])
model_cls
,
_
=
ModelRegistry
.
resolve_model_cls
(
architectures
)
model
=
model_cls
(
model_config
)
weights_to_load
=
{
name
for
name
,
_
in
model
.
named_parameters
()}
model_config
.
model
=
model_path
loaded_weights
=
model
.
load_weights
(
self
.
_get_all_weights
(
model_config
,
model
))
self
.
counter_after_loading_weights
=
time
.
perf_counter
()
logger
.
info
(
"Loading weights took %.2f seconds"
,
self
.
counter_after_loading_weights
-
self
.
counter_before_loading_weights
)
# We only enable strict check for non-quantized models
# that have loaded weights tracking currently.
# if loaded_weights is not None:
weights_not_loaded
=
weights_to_load
-
loaded_weights
if
weights_not_loaded
:
raise
ValueError
(
"Following weights were not initialized from "
f
"checkpoint:
{
weights_not_loaded
}
"
)
# TODO(will): add support for training/finetune
return
model
.
eval
()
class
ImageEncoderLoader
(
TextEncoderLoader
):
def
load
(
self
,
model_path
:
str
,
architecture
:
str
,
fastvideo_args
:
FastVideoArgs
):
"""Load the text encoders based on the model path, architecture, and inference args."""
# model_config: PretrainedConfig = get_hf_config(
# model=model_path,
# trust_remote_code=fastvideo_args.trust_remote_code,
# revision=fastvideo_args.revision,
# model_override_args=None,
# )
with
open
(
os
.
path
.
join
(
model_path
,
"config.json"
))
as
f
:
model_config
=
json
.
load
(
f
)
model_config
.
pop
(
"_name_or_path"
,
None
)
model_config
.
pop
(
"transformers_version"
,
None
)
model_config
.
pop
(
"torch_dtype"
,
None
)
model_config
.
pop
(
"model_type"
,
None
)
logger
.
info
(
"HF Model config: %s"
,
model_config
)
encoder_config
=
fastvideo_args
.
image_encoder_config
encoder_config
.
update_model_arch
(
model_config
)
target_device
=
torch
.
device
(
fastvideo_args
.
device_str
)
# TODO(will): add support for other dtypes
return
self
.
load_model
(
model_path
,
encoder_config
,
target_device
,
fastvideo_args
.
image_encoder_precision
)
class
ImageProcessorLoader
(
ComponentLoader
):
"""Loader for image processor."""
def
load
(
self
,
model_path
:
str
,
architecture
:
str
,
fastvideo_args
:
FastVideoArgs
):
"""Load the image processor based on the model path, architecture, and inference args."""
logger
.
info
(
"Loading image processor from %s"
,
model_path
)
image_processor
=
AutoImageProcessor
.
from_pretrained
(
model_path
,
)
logger
.
info
(
"Loaded image processor: %s"
,
image_processor
.
__class__
.
__name__
)
return
image_processor
class
TokenizerLoader
(
ComponentLoader
):
"""Loader for tokenizers."""
def
load
(
self
,
model_path
:
str
,
architecture
:
str
,
fastvideo_args
:
FastVideoArgs
):
"""Load the tokenizer based on the model path, architecture, and inference args."""
logger
.
info
(
"Loading tokenizer from %s"
,
model_path
)
tokenizer
=
AutoTokenizer
.
from_pretrained
(
model_path
,
# "<path to model>/tokenizer"
# in v0, this was same string as encoder_name "ClipTextModel"
# TODO(will): pass these tokenizer kwargs from inference args? Maybe
# other method of config?
padding_size
=
'right'
,
)
logger
.
info
(
"Loaded tokenizer: %s"
,
tokenizer
.
__class__
.
__name__
)
return
tokenizer
class
VAELoader
(
ComponentLoader
):
"""Loader for VAE."""
def
load
(
self
,
model_path
:
str
,
architecture
:
str
,
fastvideo_args
:
FastVideoArgs
):
"""Load the VAE based on the model path, architecture, and inference args."""
# TODO(will): move this to a constants file
config
=
get_diffusers_config
(
model
=
model_path
)
class_name
=
config
.
pop
(
"_class_name"
)
assert
class_name
is
not
None
,
"Model config does not contain a _class_name attribute. Only diffusers format is supported."
config
.
pop
(
"_diffusers_version"
)
vae_config
=
fastvideo_args
.
vae_config
vae_config
.
update_model_arch
(
config
)
vae_cls
,
_
=
ModelRegistry
.
resolve_model_cls
(
class_name
)
vae
=
vae_cls
(
vae_config
).
to
(
fastvideo_args
.
device
)
# Find all safetensors files
safetensors_list
=
glob
.
glob
(
os
.
path
.
join
(
str
(
model_path
),
"*.safetensors"
))
# TODO(PY)
assert
len
(
safetensors_list
)
==
1
,
f
"Found
{
len
(
safetensors_list
)
}
safetensors files in
{
model_path
}
"
loaded
=
safetensors_load_file
(
safetensors_list
[
0
])
vae
.
load_state_dict
(
loaded
,
strict
=
False
)
# We might only load encoder or decoder
dtype
=
PRECISION_TO_TYPE
[
fastvideo_args
.
vae_precision
]
vae
=
vae
.
eval
().
to
(
dtype
)
return
vae
class
TransformerLoader
(
ComponentLoader
):
"""Loader for transformer."""
def
load
(
self
,
model_path
:
str
,
architecture
:
str
,
fastvideo_args
:
FastVideoArgs
):
"""Load the transformer based on the model path, architecture, and inference args."""
config
=
get_diffusers_config
(
model
=
model_path
)
cls_name
=
config
.
pop
(
"_class_name"
)
if
cls_name
is
None
:
raise
ValueError
(
"Model config does not contain a _class_name attribute. "
"Only diffusers format is supported."
)
config
.
pop
(
"_diffusers_version"
)
# Config from Diffusers supersedes fastvideo's model config
dit_config
=
fastvideo_args
.
dit_config
dit_config
.
update_model_arch
(
config
)
model_cls
,
_
=
ModelRegistry
.
resolve_model_cls
(
cls_name
)
# Find all safetensors files
safetensors_list
=
glob
.
glob
(
os
.
path
.
join
(
str
(
model_path
),
"*.safetensors"
))
if
not
safetensors_list
:
raise
ValueError
(
f
"No safetensors files found in
{
model_path
}
"
)
logger
.
info
(
"Loading model from %s safetensors files in %s"
,
len
(
safetensors_list
),
model_path
)
# initialize_sequence_parallel_group(fastvideo_args.sp_size)
default_dtype
=
PRECISION_TO_TYPE
[
fastvideo_args
.
precision
]
# Load the model using FSDP loader
logger
.
info
(
"Loading model from %s"
,
cls_name
)
model
=
load_fsdp_model
(
model_cls
=
model_cls
,
init_params
=
{
"config"
:
dit_config
},
weight_dir_list
=
safetensors_list
,
device
=
fastvideo_args
.
device
,
cpu_offload
=
fastvideo_args
.
use_cpu_offload
,
default_dtype
=
default_dtype
)
total_params
=
sum
(
p
.
numel
()
for
p
in
model
.
parameters
())
logger
.
info
(
"Loaded model with %.2fB parameters"
,
total_params
/
1e9
)
dtypes
=
set
(
param
.
dtype
for
param
in
model
.
parameters
())
if
len
(
dtypes
)
>
1
:
model
=
model
.
to
(
default_dtype
)
model
=
model
.
eval
()
return
model
class
SchedulerLoader
(
ComponentLoader
):
"""Loader for scheduler."""
def
load
(
self
,
model_path
:
str
,
architecture
:
str
,
fastvideo_args
:
FastVideoArgs
):
"""Load the scheduler based on the model path, architecture, and inference args."""
config
=
get_diffusers_config
(
model
=
model_path
)
class_name
=
config
.
pop
(
"_class_name"
)
assert
class_name
is
not
None
,
"Model config does not contain a _class_name attribute. Only diffusers format is supported."
config
.
pop
(
"_diffusers_version"
)
scheduler_cls
,
_
=
ModelRegistry
.
resolve_model_cls
(
class_name
)
scheduler
=
scheduler_cls
(
**
config
)
if
fastvideo_args
.
flow_shift
is
not
None
:
scheduler
.
set_shift
(
fastvideo_args
.
flow_shift
)
return
scheduler
class
GenericComponentLoader
(
ComponentLoader
):
"""Generic loader for components that don't have a specific loader."""
def
__init__
(
self
,
library
=
"transformers"
)
->
None
:
super
().
__init__
()
self
.
library
=
library
def
load
(
self
,
model_path
:
str
,
architecture
:
str
,
fastvideo_args
:
FastVideoArgs
):
"""Load a generic component based on the model path, architecture, and inference args."""
logger
.
warning
(
"Using generic loader for %s with library %s"
,
model_path
,
self
.
library
)
if
self
.
library
==
"transformers"
:
from
transformers
import
AutoModel
model
=
AutoModel
.
from_pretrained
(
model_path
,
trust_remote_code
=
fastvideo_args
.
trust_remote_code
,
revision
=
fastvideo_args
.
revision
,
)
logger
.
info
(
"Loaded generic transformers model: %s"
,
model
.
__class__
.
__name__
)
return
model
elif
self
.
library
==
"diffusers"
:
logger
.
warning
(
"Generic loading for diffusers components is not fully implemented"
)
model_config
=
get_diffusers_config
(
model
=
model_path
)
logger
.
info
(
"Diffusers Model config: %s"
,
model_config
)
# This is a placeholder - in a real implementation, you'd need to handle this properly
return
None
else
:
raise
ValueError
(
f
"Unsupported library:
{
self
.
library
}
"
)
class
PipelineComponentLoader
:
"""
Utility class for loading pipeline components.
This replaces the chain of if-else statements in load_pipeline_module.
"""
@
staticmethod
def
load_module
(
module_name
:
str
,
component_model_path
:
str
,
transformers_or_diffusers
:
str
,
architecture
:
str
,
fastvideo_args
:
FastVideoArgs
):
"""
Load a pipeline module.
Args:
module_name: Name of the module (e.g., "vae", "text_encoder", "transformer", "scheduler")
component_model_path: Path to the component model
transformers_or_diffusers: Whether the module is from transformers or diffusers
architecture: Architecture of the component model
fastvideo_args: Inference arguments
Returns:
The loaded module
"""
logger
.
info
(
"Loading %s using %s from %s"
,
module_name
,
transformers_or_diffusers
,
component_model_path
,
)
# Get the appropriate loader for this module type
loader
=
ComponentLoader
.
for_module_type
(
module_name
,
transformers_or_diffusers
)
# Load the module
return
loader
.
load
(
component_model_path
,
architecture
,
fastvideo_args
)
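For reference, a small sketch of how the factory above dispatches on module type (illustrative, not part of the commit); it assumes the fastvideo package and its dependencies are importable, and the module-type strings are the keys of the mapping in for_module_type.

# Illustrative usage sketch of the loader factory.
from fastvideo.v1.models.loader.component_loader import ComponentLoader

# The factory enforces which library each module type must come from.
vae_loader = ComponentLoader.for_module_type("vae", "diffusers")
tokenizer_loader = ComponentLoader.for_module_type("tokenizer", "transformers")
print(type(vae_loader).__name__, type(tokenizer_loader).__name__)

# An unknown module type falls back to GenericComponentLoader with a warning.
extra_loader = ComponentLoader.for_module_type("feature_extractor", "transformers")
print(type(extra_loader).__name__)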
FastVideo-main/fastvideo/v1/models/loader/fsdp_load.py
0 → 100644
# SPDX-License-Identifier: Apache-2.0
# Adapted from torchtune
# Copyright 2024 The TorchTune Authors.
# Copyright 2025 The FastVideo Authors.
import contextlib
import re
from collections import defaultdict
from itertools import chain
from typing import (Any, Callable, DefaultDict, Dict, Generator, Hashable,
                    List, Optional, Tuple, Type)

import torch
from torch import nn
from torch.distributed import DeviceMesh, init_device_mesh
from torch.distributed._composable.fsdp import CPUOffloadPolicy, fully_shard
from torch.distributed._tensor import distribute_tensor
from torch.nn.modules.module import _IncompatibleKeys

from fastvideo.v1.distributed.parallel_state import (
    get_sequence_model_parallel_world_size)
from fastvideo.v1.models.loader.weight_utils import safetensors_weights_iterator


# TODO(PY): move this to utils elsewhere
@contextlib.contextmanager
def set_default_dtype(dtype: torch.dtype) -> Generator[None, None, None]:
    """
    Context manager to set torch's default dtype.

    Args:
        dtype (torch.dtype): The desired default dtype inside the context manager.

    Returns:
        ContextManager: context manager for setting default dtype.

    Example:
        >>> with set_default_dtype(torch.bfloat16):
        >>>     x = torch.tensor([1.0, 2.0, 3.0])
        >>>     x.dtype
        torch.bfloat16
    """
    old_dtype = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    try:
        yield
    finally:
        torch.set_default_dtype(old_dtype)


def get_param_names_mapping(
        mapping_dict: Dict[str, str]) -> Callable[[str], tuple[str, Any, Any]]:
    """
    Creates a mapping function that transforms parameter names using regex patterns.

    Args:
        mapping_dict (Dict[str, str]): Dictionary mapping regex patterns to replacement patterns

    Returns:
        Callable[[str], tuple[str, Any, Any]]: A function that maps parameter names from
            source to target format, along with an optional merge index and the total
            number of parameters to merge.
    """

    def mapping_fn(name: str) -> tuple[str, Any, Any]:
        # Try to match and transform the name using the regex patterns in mapping_dict
        for pattern, replacement in mapping_dict.items():
            match = re.match(pattern, name)
            if match:
                merge_index = None
                total_splitted_params = None
                if isinstance(replacement, tuple):
                    merge_index = replacement[1]
                    total_splitted_params = replacement[2]
                    replacement = replacement[0]
                name = re.sub(pattern, replacement, name)
                return name, merge_index, total_splitted_params

        # If no pattern matches, return the original name
        return name, None, None

    return mapping_fn


# TODO(PY): add compile option
def load_fsdp_model(
    model_cls: Type[nn.Module],
    init_params: Dict[str, Any],
    weight_dir_list: List[str],
    device: torch.device,
    cpu_offload: bool = False,
    default_dtype: Optional[torch.dtype] = torch.bfloat16,
) -> torch.nn.Module:
    with set_default_dtype(default_dtype), torch.device("meta"):
        model = model_cls(**init_params)
    device_mesh = init_device_mesh(
        "cuda",
        mesh_shape=(get_sequence_model_parallel_world_size(), ),
        mesh_dim_names=("dp", ),
    )
    shard_model(model,
                cpu_offload=cpu_offload,
                reshard_after_forward=True,
                dp_mesh=device_mesh["dp"])
    weight_iterator = safetensors_weights_iterator(weight_dir_list)
    param_names_mapping_fn = get_param_names_mapping(
        model._param_names_mapping)
    load_fsdp_model_from_full_model_state_dict(
        model,
        weight_iterator,
        device,
        strict=True,
        cpu_offload=cpu_offload,
        param_names_mapping=param_names_mapping_fn,
    )
    for n, p in chain(model.named_parameters(), model.named_buffers()):
        if p.is_meta:
            raise RuntimeError(
                f"Unexpected param or buffer {n} on meta device.")
    for p in model.parameters():
        p.requires_grad = False
    return model


def shard_model(
    model,
    *,
    cpu_offload: bool,
    reshard_after_forward: bool = True,
    dp_mesh: Optional[DeviceMesh] = None,
) -> None:
    """
    Utility to shard a model with FSDP using the PyTorch Distributed fully_shard API.

    This method iterates over the model's named modules from the bottom up and shards
    every module that satisfies any of the conditions in ``model._fsdp_shard_conditions``.
    Each condition takes the module name (relative to root) and the module itself and
    returns True if FSDP should shard that module.

    Args:
        model (TransformerDecoder): Model to shard with FSDP.
        cpu_offload (bool): If set to True, FSDP will offload parameters, gradients, and optimizer
            states to CPU.
        reshard_after_forward (bool): Whether to reshard parameters and buffers after
            the forward pass. Setting this to True corresponds to the FULL_SHARD sharding strategy
            from FSDP1, while setting it to False corresponds to the SHARD_GRAD_OP sharding strategy.
        dp_mesh (Optional[DeviceMesh]): Device mesh to use for FSDP sharding under multiple parallelism.
            Default to None.

    Raises:
        ValueError: If no layer modules were sharded, indicating that no shard condition was triggered.
    """
    fsdp_kwargs = {
        "reshard_after_forward": reshard_after_forward,
        "mesh": dp_mesh,
    }
    if cpu_offload:
        fsdp_kwargs["offload_policy"] = CPUOffloadPolicy()

    # Shard the model with FSDP, iterating in reverse to start with
    # lowest-level modules first
    num_layers_sharded = 0
    for n, m in reversed(list(model.named_modules())):
        if any([
                shard_condition(n, m)
                for shard_condition in model._fsdp_shard_conditions
        ]):
            fully_shard(m, **fsdp_kwargs)
            num_layers_sharded += 1

    if num_layers_sharded == 0:
        raise ValueError(
            "No layer modules were sharded. Please check if shard conditions are working as expected."
        )

    # Finally shard the entire model to account for any stragglers
    fully_shard(model, **fsdp_kwargs)


# TODO(PY): device mesh for cfg parallel
def load_fsdp_model_from_full_model_state_dict(
    model: torch.nn.Module,
    full_sd_iterator: Generator[Tuple[str, torch.Tensor], None, None],
    device: torch.device,
    strict: bool = False,
    cpu_offload: bool = False,
    param_names_mapping: Optional[Callable[[str], tuple[str, Any,
                                                        Any]]] = None,
) -> _IncompatibleKeys:
    """
    Converting full state dict into a sharded state dict
    and loading it into FSDP model

    Args:
        model (FSDPModule): Model to generate fully qualified names for cpu_state_dict
        full_sd_iterator (Generator): an iterator yielding (param_name, tensor) pairs
        device (torch.device): device used to move full state dict tensors
        strict (bool): flag to check if to load the model in strict mode
        cpu_offload (bool): flag to check if offload to CPU is enabled
        param_names_mapping (Optional[Callable[[str], str]]): a function that maps full param name to sharded param name

    Returns:
        ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields:
            * **missing_keys** is a list of str containing the missing keys
            * **unexpected_keys** is a list of str containing the unexpected keys

    Raises:
        NotImplementedError: If got FSDP with more than 1D.
    """
    meta_sharded_sd = model.state_dict()
    sharded_sd = {}
    to_merge_params: DefaultDict[Hashable, Dict[Any, Any]] = defaultdict(dict)
    for source_param_name, full_tensor in full_sd_iterator:
        assert param_names_mapping is not None
        target_param_name, merge_index, num_params_to_merge = param_names_mapping(
            source_param_name)
        if merge_index is not None:
            to_merge_params[target_param_name][merge_index] = full_tensor
            if len(to_merge_params[target_param_name]) == num_params_to_merge:
                # cat at dim=0 according to the merge_index order
                sorted_tensors = [
                    to_merge_params[target_param_name][i]
                    for i in range(num_params_to_merge)
                ]
                full_tensor = torch.cat(sorted_tensors, dim=0)
                del to_merge_params[target_param_name]
            else:
                continue

        sharded_meta_param = meta_sharded_sd.get(target_param_name)
        if sharded_meta_param is None:
            raise ValueError(
                f"Parameter {source_param_name} --> {target_param_name} not found in meta sharded state dict"
            )
        full_tensor = full_tensor.to(sharded_meta_param.dtype).to(device)
        if not hasattr(sharded_meta_param, "device_mesh"):
            # In cases where parts of the model aren't sharded, some parameters will be plain tensors
            sharded_tensor = full_tensor
        else:
            sharded_tensor = distribute_tensor(
                full_tensor,
                sharded_meta_param.device_mesh,
                sharded_meta_param.placements,
            )
        if cpu_offload:
            sharded_tensor = sharded_tensor.cpu()
        sharded_sd[target_param_name] = nn.Parameter(sharded_tensor)

    # choose `assign=True` since we cannot call `copy_` on meta tensor
    return model.load_state_dict(sharded_sd, strict=strict, assign=True)
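The mapping function above is easier to follow with a toy mapping. The sketch below is illustrative only; the regex patterns and parameter names are hypothetical, not a real FastVideo checkpoint mapping. Plain string replacements rename a weight, while tuple values (replacement, merge_index, num_to_merge) mark weights that the FSDP loader later concatenates into one fused parameter.

# Illustrative usage sketch; the mapping entries are made up.
from fastvideo.v1.models.loader.fsdp_load import get_param_names_mapping

mapping = {
    r"^blocks\.(\d+)\.attn\.q\.(.*)$": (r"blocks.\1.attn.qkv.\2", 0, 3),
    r"^blocks\.(\d+)\.attn\.k\.(.*)$": (r"blocks.\1.attn.qkv.\2", 1, 3),
    r"^blocks\.(\d+)\.attn\.v\.(.*)$": (r"blocks.\1.attn.qkv.\2", 2, 3),
    r"^final_norm\.(.*)$": r"norm_out.\1",
}
mapping_fn = get_param_names_mapping(mapping)

print(mapping_fn("blocks.3.attn.k.weight"))  # ('blocks.3.attn.qkv.weight', 1, 3)
print(mapping_fn("final_norm.weight"))       # ('norm_out.weight', None, None)
print(mapping_fn("unmapped.bias"))           # ('unmapped.bias', None, None)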
FastVideo-main/fastvideo/v1/models/loader/utils.py
0 → 100644
# SPDX-License-Identifier: Apache-2.0
"""Utilities for selecting and loading models."""
import contextlib

import torch

from fastvideo.v1.logger import init_logger

logger = init_logger(__name__)


@contextlib.contextmanager
def set_default_torch_dtype(dtype: torch.dtype):
    """Sets the default torch dtype to the given dtype."""
    old_dtype = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    yield
    torch.set_default_dtype(old_dtype)
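A short usage sketch of the context manager above (illustrative, not part of the commit): modules constructed inside the block pick up the requested default floating-point dtype, and the previous default is restored on exit.

# Illustrative usage sketch.
import torch

from fastvideo.v1.models.loader.utils import set_default_torch_dtype

with set_default_torch_dtype(torch.float16):
    layer = torch.nn.Linear(8, 8)  # parameters are created in fp16
print(layer.weight.dtype)          # torch.float16
print(torch.get_default_dtype())   # restored afterwards, typically torch.float32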
FastVideo-main/fastvideo/v1/models/loader/weight_utils.py
0 → 100644
# SPDX-License-Identifier: Apache-2.0
# Adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/model_executor/model_loader/weight_utils.py
"""Utilities for downloading and initializing model weights."""
import fnmatch
import hashlib
import json
import os
import tempfile
import time
from collections import defaultdict
from pathlib import Path
from typing import Generator, List, Optional, Tuple, Union

import filelock
import huggingface_hub.constants
import torch
from huggingface_hub import HfFileSystem, hf_hub_download, snapshot_download
from safetensors.torch import safe_open
from tqdm.auto import tqdm

from fastvideo.v1.logger import init_logger

logger = init_logger(__name__)

# use system-level temp directory for file locks, so that multiple users
# can share the same lock without error.
# lock files in the temp directory will be automatically deleted when the
# system reboots, so users will not complain about annoying lock files
temp_dir = tempfile.gettempdir()


def enable_hf_transfer() -> None:
    """automatically activates hf_transfer"""
    if "HF_HUB_ENABLE_HF_TRANSFER" not in os.environ:
        try:
            # enable hf hub transfer if available
            import hf_transfer  # type: ignore # noqa
            huggingface_hub.constants.HF_HUB_ENABLE_HF_TRANSFER = True
        except ImportError:
            pass


enable_hf_transfer()


class DisabledTqdm(tqdm):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs, disable=True)


def get_lock(model_name_or_path: Union[str, Path],
             cache_dir: Optional[str] = None):
    lock_dir = cache_dir or temp_dir
    model_name_or_path = str(model_name_or_path)
    os.makedirs(os.path.dirname(lock_dir), exist_ok=True)
    model_name = model_name_or_path.replace("/", "-")
    hash_name = hashlib.sha256(model_name.encode()).hexdigest()
    # add hash to avoid conflict with old users' lock files
    lock_file_name = hash_name + model_name + ".lock"
    # mode 0o666 is required for the filelock to be shared across users
    lock = filelock.FileLock(os.path.join(lock_dir, lock_file_name),
                             mode=0o666)
    return lock


def _shared_pointers(tensors):
    ptrs = defaultdict(list)
    for k, v in tensors.items():
        ptrs[v.data_ptr()].append(k)
    failing = []
    for _, names in ptrs.items():
        if len(names) > 1:
            failing.append(names)
    return failing


def download_weights_from_hf(
    model_name_or_path: str,
    cache_dir: Optional[str],
    allow_patterns: List[str],
    revision: Optional[str] = None,
    ignore_patterns: Optional[Union[str, List[str]]] = None,
) -> str:
    """Download model weights from Hugging Face Hub.

    Args:
        model_name_or_path (str): The model name or path.
        cache_dir (Optional[str]): The cache directory to store the model
            weights. If None, will use HF defaults.
        allow_patterns (List[str]): The allowed patterns for the
            weight files. Files matched by any of the patterns will be
            downloaded.
        revision (Optional[str]): The revision of the model.
        ignore_patterns (Optional[Union[str, List[str]]]): The patterns to
            filter out the weight files. Files matched by any of the patterns
            will be ignored.

    Returns:
        str: The path to the downloaded model weights.
    """
    local_only = huggingface_hub.constants.HF_HUB_OFFLINE
    if not local_only:
        # Before we download we look at what is available:
        fs = HfFileSystem()
        file_list = fs.ls(model_name_or_path, detail=False, revision=revision)

        # depending on what is available we download different things
        for pattern in allow_patterns:
            matching = fnmatch.filter(file_list, pattern)
            if len(matching) > 0:
                allow_patterns = [pattern]
                break

    logger.info("Using model weights format %s", allow_patterns)
    # Use file lock to prevent multiple processes from
    # downloading the same model weights at the same time.
    with get_lock(model_name_or_path, cache_dir):
        start_time = time.perf_counter()
        hf_folder: str = snapshot_download(
            model_name_or_path,
            allow_patterns=allow_patterns,
            ignore_patterns=ignore_patterns,
            cache_dir=cache_dir,
            tqdm_class=DisabledTqdm,
            revision=revision,
            local_files_only=local_only,
        )
        time_taken = time.perf_counter() - start_time
        if time_taken > 0.5:
            logger.info("Time spent downloading weights for %s: %.6f seconds",
                        model_name_or_path, time_taken)
    return hf_folder


def download_safetensors_index_file_from_hf(
    model_name_or_path: str,
    index_file: str,
    cache_dir: Optional[str],
    revision: Optional[str] = None,
) -> None:
    """Download hf safetensors index file from Hugging Face Hub.

    Args:
        model_name_or_path (str): The model name or path.
        index_file (str): The safetensors index file name.
        cache_dir (Optional[str]): The cache directory to store the model
            weights. If None, will use HF defaults.
        revision (Optional[str]): The revision of the model.
    """
    # Use file lock to prevent multiple processes from
    # downloading the same model weights at the same time.
    with get_lock(model_name_or_path, cache_dir):
        try:
            # Download the safetensors index file.
            hf_hub_download(
                repo_id=model_name_or_path,
                filename=index_file,
                cache_dir=cache_dir,
                revision=revision,
                local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
            )
        # If file not found on remote or locally, we should not fail since
        # only some models will have index_file.
        except huggingface_hub.utils.EntryNotFoundError:
            logger.info("No %s found in remote.", index_file)
        except huggingface_hub.utils.LocalEntryNotFoundError:
            logger.info("No %s found in local cache.", index_file)


# For models like Mistral-7B-v0.3, there are both sharded
# safetensors files and a consolidated safetensors file.
# Passing both of these to the weight loader functionality breaks.
# So, we use the index_file to
# look up which safetensors files should be used.
def filter_duplicate_safetensors_files(hf_weights_files: List[str],
                                       hf_folder: str,
                                       index_file: str) -> List[str]:
    # model.safetensors.index.json is a mapping from keys in the
    # torch state_dict to safetensors file holding that weight.
    index_file_name = os.path.join(hf_folder, index_file)
    if not os.path.isfile(index_file_name):
        return hf_weights_files

    # Iterate through the weight_map (weight_name: safetensors files)
    # to identify weights that we should use.
    with open(index_file_name) as f:
        weight_map = json.load(f)["weight_map"]
    weight_files_in_index = set()
    for weight_name in weight_map:
        weight_files_in_index.add(
            os.path.join(hf_folder, weight_map[weight_name]))
    # Filter out any fields that are not found in the index file.
    hf_weights_files = [
        f for f in hf_weights_files if f in weight_files_in_index
    ]
    return hf_weights_files


def filter_files_not_needed_for_inference(
        hf_weights_files: List[str]) -> List[str]:
    """
    Exclude files that are not needed for inference.

    See https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/trainer.py#L227-L233
    """
    blacklist = [
        "training_args.bin",
        "optimizer.bin",
        "optimizer.pt",
        "scheduler.pt",
        "scaler.pt",
    ]
    hf_weights_files = [
        f for f in hf_weights_files
        if not any(f.endswith(x) for x in blacklist)
    ]
    return hf_weights_files


# explicitly use pure text format, with a newline at the end
# this makes it impossible to see the animation in the progress bar
# but will avoid messing up with ray or multiprocessing, which wraps
# each line of output with some prefix.
_BAR_FORMAT = ("{desc}: {percentage:3.0f}% Completed | {n_fmt}/{total_fmt} "
               "[{elapsed}<{remaining}, {rate_fmt}]\n")  # noqa: E501


def safetensors_weights_iterator(
    hf_weights_files: List[str]
) -> Generator[Tuple[str, torch.Tensor], None, None]:
    """Iterate over the weights in the model safetensor files."""
    enable_tqdm = not torch.distributed.is_initialized(
    ) or torch.distributed.get_rank() == 0
    for st_file in tqdm(
            hf_weights_files,
            desc="Loading safetensors checkpoint shards",
            disable=not enable_tqdm,
            bar_format=_BAR_FORMAT,
    ):
        with safe_open(st_file, framework="pt") as f:
            for name in f.keys():  # noqa: SIM118
                param = f.get_tensor(name)
                yield name, param


def pt_weights_iterator(
    hf_weights_files: List[str]
) -> Generator[Tuple[str, torch.Tensor], None, None]:
    """Iterate over the weights in the model bin/pt files."""
    enable_tqdm = not torch.distributed.is_initialized(
    ) or torch.distributed.get_rank() == 0
    for bin_file in tqdm(
            hf_weights_files,
            desc="Loading pt checkpoint shards",
            disable=not enable_tqdm,
            bar_format=_BAR_FORMAT,
    ):
        state = torch.load(bin_file, map_location="cpu", weights_only=True)
        yield from state.items()
        del state


def default_weight_loader(param: torch.Tensor,
                          loaded_weight: torch.Tensor) -> None:
    """Default weight loader."""
    try:
        if param.numel() == 1 and loaded_weight.numel() == 1:
            # Sometimes scalar values aren't considered tensors with shapes
            # so if both param and loaded_weight are a scalar,
            # "broadcast" instead of copy
            param.data.fill_(loaded_weight.item())
        else:
            assert param.size() == loaded_weight.size(), (
                f"Attempted to load weight ({loaded_weight.size()}) "
                f"into parameter ({param.size()})")
            param.data.copy_(loaded_weight)
    except Exception:
        # NOTE: This exception is added for the purpose of setting breakpoint to
        # debug weight loading issues.
        raise


def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]:
    """Remap the name of FP8 k/v_scale parameters.

    This function handles the remapping of FP8 k/v_scale parameter names.
    It detects if the given name ends with a suffix and attempts to remap
    it to the expected name format in the model. If the remapped name is not
    found in the params_dict, a warning is printed and None is returned.

    Args:
        name (str): The original loaded checkpoint parameter name.
        params_dict (dict): Dictionary containing the model's named parameters.

    Returns:
        str: The remapped parameter name if successful, or the original name
            if no remapping is needed.
        None: If the remapped name is not found in params_dict.
    """
    if name.endswith(".kv_scale"):
        logger.warning_once(
            "DEPRECATED. Found kv_scale in the checkpoint. "
            "This format is deprecated in favor of separate k_scale and "
            "v_scale tensors and will be removed in a future release. "
            "Functionally, we will remap kv_scale to k_scale and duplicate "
            "k_scale to v_scale")
        # NOTE: we remap the deprecated kv_scale to k_scale
        remapped_name = name.replace(".kv_scale", ".attn.k_scale")
        if remapped_name not in params_dict:
            logger.warning_once(
                f"Found kv_scale in the checkpoint (e.g. {name}), "
                "but not found the expected name in the model "
                f"(e.g. {remapped_name}). kv_scale is "
                "not loaded.")
            return None
        return remapped_name

    possible_scale_names = [".k_scale", ".v_scale"]
    modelopt_scale_names = [
        ".self_attn.k_proj.k_scale", ".self_attn.v_proj.v_scale"
    ]
    for scale_name in possible_scale_names:
        if name.endswith(scale_name):
            if any(mo_scale_name in name
                   for mo_scale_name in modelopt_scale_names):
                remapped_name = name.replace(
                    f".self_attn.{scale_name[1]}_proj{scale_name}",
                    f".self_attn.attn{scale_name}")
            else:
                remapped_name = name.replace(scale_name, f".attn{scale_name}")
            if remapped_name not in params_dict:
                logger.warning_once(
                    f"Found {scale_name} in the checkpoint (e.g. {name}), "
                    "but not found the expected name in the model "
                    f"(e.g. {remapped_name}). {scale_name} is "
                    "not loaded.")
                return None
            return remapped_name

    # If there were no matches, return the untouched param name
    return name
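A minimal sketch of consuming the safetensors iterator above (illustrative, not part of the commit); the glob path is a placeholder for a local checkpoint directory, and the iterator streams (name, tensor) pairs shard by shard rather than materialising the full state dict.

# Illustrative usage sketch; the checkpoint path is a placeholder.
import glob

from fastvideo.v1.models.loader.weight_utils import safetensors_weights_iterator

shards = sorted(glob.glob("/path/to/checkpoint/*.safetensors"))
for name, tensor in safetensors_weights_iterator(shards):
    print(name, tuple(tensor.shape), tensor.dtype)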
FastVideo-main/fastvideo/v1/models/parameter.py
0 → 100644
# SPDX-License-Identifier: Apache-2.0
# Adapted from: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/model_executor/parameter.py
from fractions import Fraction
from typing import Any, Callable, Tuple, Union

import torch
from torch.nn import Parameter

from fastvideo.v1.distributed import get_tensor_model_parallel_rank
from fastvideo.v1.logger import init_logger
from fastvideo.v1.models.utils import _make_synced_weight_loader

logger = init_logger(__name__)


class BasevLLMParameter(Parameter):
    """
    Base parameter for vLLM linear layers. Extends the torch.nn.parameter
    by taking in a linear weight loader. Will copy the loaded weight
    into the parameter when the provided weight loader is called.
    """

    def __new__(cls, data: torch.Tensor, **kwargs):
        return super().__new__(cls, data=data, requires_grad=False)

    def __init__(self, data: torch.Tensor, weight_loader: Callable):
        """
        Initialize the BasevLLMParameter

        :param data: torch tensor with the parameter data
        :param weight_loader: weight loader callable

        :returns: a torch.nn.parameter
        """

        # During weight loading, we often do something like:
        # narrowed_tensor = param.data.narrow(0, offset, len)
        # narrowed_tensor.copy_(real_weight)
        # expecting narrowed_tensor and param.data to share the same storage.
        # However, on TPUs, narrowed_tensor will lazily propagate to the base
        # tensor, which is param.data, leading to the redundant memory usage.
        # This sometimes causes OOM errors during model loading. To avoid this,
        # we sync the param tensor after its weight loader is called.
        from fastvideo.v1.platforms import current_platform
        if current_platform.is_tpu():
            weight_loader = _make_synced_weight_loader(weight_loader)

        self._weight_loader = weight_loader

    @property
    def weight_loader(self):
        return self._weight_loader

    def _is_1d_and_scalar(self, loaded_weight: torch.Tensor):
        cond1 = self.data.ndim == 1 and self.data.numel() == 1
        cond2 = loaded_weight.ndim == 0 and loaded_weight.numel() == 1
        return (cond1 and cond2)

    def _assert_and_load(self, loaded_weight: torch.Tensor) -> None:
        assert (self.data.shape == loaded_weight.shape
                or self._is_1d_and_scalar(loaded_weight))
        self.data.copy_(loaded_weight)

    def load_column_parallel_weight(self, loaded_weight: torch.Tensor) -> None:
        self._assert_and_load(loaded_weight)

    def load_row_parallel_weight(self, loaded_weight: torch.Tensor) -> None:
        self._assert_and_load(loaded_weight)

    def load_merged_column_weight(self, loaded_weight: torch.Tensor,
                                  **kwargs) -> None:
        self._assert_and_load(loaded_weight)

    def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs) -> None:
        self._assert_and_load(loaded_weight)


class _ColumnvLLMParameter(BasevLLMParameter):
    """
    Private class defining weight loading functionality
    (load_merged_column_weight, load_qkv_weight)
    for parameters being loaded into linear layers with column
    parallelism. This includes QKV and MLP layers which are
    not already fused on disk. Requires an output dimension
    to be defined. Called within the weight loader of
    each of the column parallel linear layers.
    """

    def __init__(self, output_dim: int, **kwargs):
        self._output_dim = output_dim
        super().__init__(**kwargs)

    @property
    def output_dim(self):
        return self._output_dim

    def load_column_parallel_weight(self, loaded_weight: torch.Tensor) -> None:
        tp_rank = get_tensor_model_parallel_rank()
        shard_size = self.data.shape[self.output_dim]
        loaded_weight = loaded_weight.narrow(self.output_dim,
                                             tp_rank * shard_size, shard_size)
        assert self.data.shape == loaded_weight.shape
        self.data.copy_(loaded_weight)

    def load_merged_column_weight(self, loaded_weight: torch.Tensor,
                                  **kwargs) -> None:
        shard_offset = kwargs.get("shard_offset")
        shard_size = kwargs.get("shard_size")
        if shard_offset is None or shard_size is None:
            raise ValueError("shard_offset and shard_size must be provided")

        if isinstance(
                self,
            (PackedColumnParameter,
             PackedvLLMParameter)) and self.packed_dim == self.output_dim:
            shard_size, shard_offset = self.adjust_shard_indexes_for_packing(
                shard_offset=shard_offset, shard_size=shard_size)

        param_data = self.data
        tp_rank = get_tensor_model_parallel_rank()
        param_data = param_data.narrow(self.output_dim, shard_offset,
                                       shard_size)
        loaded_weight = loaded_weight.narrow(self.output_dim,
                                             tp_rank * shard_size, shard_size)
        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)

    def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs) -> None:
        shard_offset = kwargs.get("shard_offset")
        shard_size = kwargs.get("shard_size")
        shard_id = kwargs.get("shard_id")
        num_heads = kwargs.get("num_heads")
        assert shard_offset is not None
        assert shard_size is not None
        assert shard_id is not None
        assert num_heads is not None

        if isinstance(
                self,
            (PackedColumnParameter,
             PackedvLLMParameter)) and self.output_dim == self.packed_dim:
            shard_size, shard_offset = self.adjust_shard_indexes_for_packing(
                shard_offset=shard_offset, shard_size=shard_size)

        param_data = self.data
        tp_rank = get_tensor_model_parallel_rank()
        shard_id = tp_rank if shard_id == "q" else tp_rank // num_heads
        param_data = param_data.narrow(self.output_dim, shard_offset,
                                       shard_size)
        loaded_weight = loaded_weight.narrow(self.output_dim,
                                             shard_id * shard_size, shard_size)

        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)


class RowvLLMParameter(BasevLLMParameter):
    """
    Parameter class defining weight_loading functionality
    (load_row_parallel_weight) for parameters being loaded
    into linear layers with row parallel functionality.
    Requires an input_dim to be defined.
    """

    def __init__(self, input_dim: int, **kwargs):
        self._input_dim = input_dim
        super().__init__(**kwargs)

    @property
    def input_dim(self):
        return self._input_dim

    def load_row_parallel_weight(self, loaded_weight: torch.Tensor) -> None:
        tp_rank = get_tensor_model_parallel_rank()
        shard_size = self.data.shape[self.input_dim]
        loaded_weight = loaded_weight.narrow(self.input_dim,
                                             tp_rank * shard_size, shard_size)

        if len(loaded_weight.shape) == 0:
            loaded_weight = loaded_weight.reshape(1)

        assert self.data.shape == loaded_weight.shape
        self.data.copy_(loaded_weight)


class ModelWeightParameter(_ColumnvLLMParameter, RowvLLMParameter):
    """
    Parameter class for linear layer weights. Uses both column and
    row parallelism.
    """
    pass


class GroupQuantScaleParameter(_ColumnvLLMParameter, RowvLLMParameter):
    """
    Parameter class for weight scales loaded for weights with
    grouped quantization. Uses both column and row parallelism.
    """
    pass


class ChannelQuantScaleParameter(_ColumnvLLMParameter):
    """
    Parameter class for weight scales loaded for weights with
    channel-wise quantization. Equivalent to _ColumnvLLMParameter.
    """
    pass


class PerTensorScaleParameter(BasevLLMParameter):
    """
    Parameter class for scales where the number of scales is
    equivalent to the number of logical matrices in fused linear
    layers (e.g. for QKV, there are 3 scales loaded from disk).
    This is relevant to weights with per-tensor quantization.
    Adds functionality to map the scalers to a shard during
    weight loading.

    Note: additional parameter manipulation may be handled
    for each quantization config specifically, within
    process_weights_after_loading
    """

    def __init__(self, **kwargs):
        self.qkv_idxs = {"q": 0, "k": 1, "v": 2}
        super().__init__(**kwargs)

    def _shard_id_as_int(self, shard_id: Union[str, int]) -> int:
        if isinstance(shard_id, int):
            return shard_id

        # if not int, assume shard_id for qkv
        # map to int and return
        assert isinstance(shard_id, str)
        assert shard_id in self.qkv_idxs
        return self.qkv_idxs[shard_id]

    # For row parallel layers, no sharding needed
    # load weight into parameter as is
    def load_row_parallel_weight(self, *args, **kwargs) -> None:
        super().load_row_parallel_weight(*args, **kwargs)

    def load_merged_column_weight(self, *args, **kwargs) -> None:
        self._load_into_shard_id(*args, **kwargs)

    def load_qkv_weight(self, *args, **kwargs) -> None:
        self._load_into_shard_id(*args, **kwargs)

    def load_column_parallel_weight(self, *args, **kwargs) -> None:
        super().load_row_parallel_weight(*args, **kwargs)

    def _load_into_shard_id(self, loaded_weight: torch.Tensor,
                            shard_id: Union[str, int], **kwargs):
        """
        Slice the parameter data based on the shard id for
        loading.
        """
        param_data = self.data
        shard_id = self._shard_id_as_int(shard_id)

        # AutoFP8 scales do not have a shape
        # compressed-tensors scales do have a shape
        if len(loaded_weight.shape) != 0:
            assert loaded_weight.shape[0] == 1
            loaded_weight = loaded_weight[0]

        param_data = param_data[shard_id]
        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)


class PackedColumnParameter(_ColumnvLLMParameter):
    """
    Parameter for model parameters which are packed on disk
    and support column parallelism only. See PackedvLLMParameter
    for more details on the packed properties.
    """

    def __init__(self, packed_factor: Union[int, Fraction], packed_dim: int,
                 **kwargs):
        self._packed_factor = packed_factor
        self._packed_dim = packed_dim
        super().__init__(**kwargs)

    @property
    def packed_dim(self):
        return self._packed_dim

    @property
    def packed_factor(self):
        return self._packed_factor

    def adjust_shard_indexes_for_packing(self, shard_size,
                                         shard_offset) -> Tuple[Any, Any]:
        return _adjust_shard_indexes_for_packing(
            shard_size=shard_size,
            shard_offset=shard_offset,
            packed_factor=self.packed_factor)


class PackedvLLMParameter(ModelWeightParameter):
    """
    Parameter for model weights which are packed on disk.
    Example: GPTQ Marlin weights are int4 or int8, packed into int32.
    Extends the ModelWeightParameter to take in the
    packed factor, the packed dimension, and optionally, marlin
    tile size for marlin kernels. Adjusts the shard_size and
    shard_offset for fused linear layers model weight loading
    by accounting for packing and optionally, marlin tile size.
    """

    def __init__(self, packed_factor: Union[int, Fraction], packed_dim: int,
                 **kwargs):
        self._packed_factor = packed_factor
        self._packed_dim = packed_dim
        super().__init__(**kwargs)

    @property
    def packed_dim(self):
        return self._packed_dim

    @property
    def packed_factor(self):
        return self._packed_factor

    def adjust_shard_indexes_for_packing(self, shard_size, shard_offset):
        return _adjust_shard_indexes_for_packing(
            shard_size=shard_size,
            shard_offset=shard_offset,
            packed_factor=self.packed_factor)


class BlockQuantScaleParameter(_ColumnvLLMParameter, RowvLLMParameter):
    """
    Parameter class for weight scales loaded for weights with
    block-wise quantization. Uses both column and row parallelism.
    """
    pass


def permute_param_layout_(param: BasevLLMParameter, input_dim: int,
                          output_dim: int, **kwargs) -> BasevLLMParameter:
    """
    Permute a parameter's layout to the specified input and output dimensions,
    useful for forcing the parameter into a known layout, for example, if I need
    a packed (quantized) weight matrix to be in the layout
        {input_dim = 0, output_dim = 1, packed_dim = 0}
    then I can call:
        permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
    to ensure x is in the correct layout (permuting it to the correct layout if
    required, asserting if it cannot get it to the correct layout)
    """

    curr_input_dim = getattr(param, "input_dim", None)
    curr_output_dim = getattr(param, "output_dim", None)

    if curr_input_dim is None or curr_output_dim is None:
        assert param.data.dim() == 2,\
            "permute_param_layout_ only supports 2D parameters when either "\
            "input_dim or output_dim is not set"

    # if one of the dimensions is not set, set it to the opposite of the other
    #  we can only do this since we asserted the parameter is 2D above
    if curr_input_dim is None:
        assert curr_output_dim is not None,\
            "either input or output dim must be set"
        curr_input_dim = (curr_output_dim + 1) % 2
    if curr_output_dim is None:
        assert curr_input_dim is not None,\
            "either input or output dim must be set"
        curr_output_dim = (curr_input_dim + 1) % 2

    # create permutation from the current layout to the layout with
    # self.input_dim at input_dim and self.output_dim at output_dim preserving
    # other dimensions
    perm = [
        i for i in range(param.data.dim())
        if i not in [curr_input_dim, curr_output_dim]
    ]
    perm.insert(input_dim, curr_input_dim)
    perm.insert(output_dim, curr_output_dim)

    if "packed_dim" in kwargs:
        assert hasattr(param, "packed_dim") and\
            param.packed_dim == perm[kwargs["packed_dim"]],\
            "permute_param_layout_ currently doesn't support repacking"

    param.data = param.data.permute(*perm)
    if hasattr(param, "_input_dim"):
        param._input_dim = input_dim
    if hasattr(param, "_output_dim"):
        param._output_dim = output_dim
    if "packed_dim" in kwargs and hasattr(param, "_packed_dim"):
        param._packed_dim = kwargs["packed_dim"]

    return param


def _adjust_shard_indexes_for_packing(shard_size, shard_offset,
                                      packed_factor) -> Tuple[Any, Any]:
    shard_size = shard_size // packed_factor
    shard_offset = shard_offset // packed_factor
    return shard_size, shard_offset
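The column-parallel loaders above fill a fused parameter shard by shard with narrow(). The plain-torch sketch below (not FastVideo API) shows that addressing for a fused QKV weight on a single rank; all sizes are made up for illustration, and tp_rank is taken to be 0 so the loaded tensor needs no extra narrowing.

# Illustrative sketch of the shard addressing used by load_qkv_weight /
# load_merged_column_weight, written with plain torch tensors.
import torch

hidden = 8  # hypothetical hidden size; the fused QKV weight is (3 * hidden, hidden)
fused_qkv = torch.empty(3 * hidden, hidden)
checkpoint = {
    "q": torch.randn(hidden, hidden),
    "k": torch.randn(hidden, hidden),
    "v": torch.randn(hidden, hidden),
}

output_dim = 0
for i, shard_id in enumerate(("q", "k", "v")):
    shard_offset, shard_size = i * hidden, hidden
    # Mirrors param_data.narrow(output_dim, shard_offset, shard_size).copy_(loaded_weight)
    fused_qkv.narrow(output_dim, shard_offset,
                     shard_size).copy_(checkpoint[shard_id])

assert torch.equal(fused_qkv[:hidden], checkpoint["q"])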
FastVideo-main/fastvideo/v1/models/registry.py
0 → 100644
# SPDX-License-Identifier: Apache-2.0
# Adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/model_executor/models/registry.py
import
importlib
import
os
import
pickle
import
subprocess
import
sys
import
tempfile
from
abc
import
ABC
,
abstractmethod
from
dataclasses
import
dataclass
,
field
from
functools
import
lru_cache
from
typing
import
(
AbstractSet
,
Callable
,
Dict
,
List
,
NoReturn
,
Optional
,
Tuple
,
Type
,
TypeVar
,
Union
,
cast
)
import
cloudpickle
from
torch
import
nn
from
fastvideo.v1.logger
import
logger
# huggingface class name: (component_name, fastvideo module name, fastvideo class name)
_TEXT_TO_VIDEO_DIT_MODELS
=
{
"HunyuanVideoTransformer3DModel"
:
(
"dits"
,
"hunyuanvideo"
,
"HunyuanVideoTransformer3DModel"
),
"WanTransformer3DModel"
:
(
"dits"
,
"wanvideo"
,
"WanTransformer3DModel"
),
}
_IMAGE_TO_VIDEO_DIT_MODELS
=
{
# "HunyuanVideoTransformer3DModel": ("dits", "hunyuanvideo", "HunyuanVideoDiT"),
"WanTransformer3DModel"
:
(
"dits"
,
"wanvideo"
,
"WanTransformer3DModel"
),
}
_TEXT_ENCODER_MODELS
=
{
"CLIPTextModel"
:
(
"encoders"
,
"clip"
,
"CLIPTextModel"
),
"LlamaModel"
:
(
"encoders"
,
"llama"
,
"LlamaModel"
),
"UMT5EncoderModel"
:
(
"encoders"
,
"t5"
,
"UMT5EncoderModel"
),
}
_IMAGE_ENCODER_MODELS
:
dict
[
str
,
tuple
]
=
{
# "HunyuanVideoTransformer3DModel": ("image_encoder", "hunyuanvideo", "HunyuanVideoImageEncoder"),
"CLIPVisionModelWithProjection"
:
(
"encoders"
,
"clip"
,
"CLIPVisionModel"
),
}
_VAE_MODELS
=
{
"AutoencoderKLHunyuanVideo"
:
(
"vaes"
,
"hunyuanvae"
,
"AutoencoderKLHunyuanVideo"
),
"AutoencoderKLWan"
:
(
"vaes"
,
"wanvae"
,
"AutoencoderKLWan"
),
}
_SCHEDULERS
=
{
"FlowMatchEulerDiscreteScheduler"
:
(
"schedulers"
,
"scheduling_flow_match_euler_discrete"
,
"FlowMatchDiscreteScheduler"
),
"UniPCMultistepScheduler"
:
(
"schedulers"
,
"scheduling_unipc_multistep"
,
"UniPCMultistepScheduler"
),
}
_FAST_VIDEO_MODELS
=
{
**
_TEXT_TO_VIDEO_DIT_MODELS
,
**
_IMAGE_TO_VIDEO_DIT_MODELS
,
**
_TEXT_ENCODER_MODELS
,
**
_IMAGE_ENCODER_MODELS
,
**
_VAE_MODELS
,
**
_SCHEDULERS
,
}
_SUBPROCESS_COMMAND
=
[
sys
.
executable
,
"-m"
,
"fastvideo.v1.models.dits.registry"
]
_T
=
TypeVar
(
"_T"
)
@
dataclass
(
frozen
=
True
)
class
_ModelInfo
:
architecture
:
str
@
staticmethod
def
from_model_cls
(
model
:
Type
[
nn
.
Module
])
->
"_ModelInfo"
:
return
_ModelInfo
(
architecture
=
model
.
__name__
,
)
class
_BaseRegisteredModel
(
ABC
):
@
abstractmethod
def
inspect_model_cls
(
self
)
->
_ModelInfo
:
raise
NotImplementedError
@
abstractmethod
def
load_model_cls
(
self
)
->
Type
[
nn
.
Module
]:
raise
NotImplementedError
@
dataclass
(
frozen
=
True
)
class
_RegisteredModel
(
_BaseRegisteredModel
):
"""
Represents a model that has already been imported in the main process.
"""
interfaces
:
_ModelInfo
model_cls
:
Type
[
nn
.
Module
]
@
staticmethod
def
from_model_cls
(
model_cls
:
Type
[
nn
.
Module
]):
return
_RegisteredModel
(
interfaces
=
_ModelInfo
.
from_model_cls
(
model_cls
),
model_cls
=
model_cls
,
)
def
inspect_model_cls
(
self
)
->
_ModelInfo
:
return
self
.
interfaces
def
load_model_cls
(
self
)
->
Type
[
nn
.
Module
]:
return
self
.
model_cls
def
_run_in_subprocess
(
fn
:
Callable
[[],
_T
])
->
_T
:
# NOTE: We use a temporary directory instead of a temporary file to avoid
# issues like https://stackoverflow.com/questions/23212435/permission-denied-to-write-to-my-temporary-file
with
tempfile
.
TemporaryDirectory
()
as
tempdir
:
output_filepath
=
os
.
path
.
join
(
tempdir
,
"registry_output.tmp"
)
# `cloudpickle` allows pickling lambda functions directly
input_bytes
=
cloudpickle
.
dumps
((
fn
,
output_filepath
))
# cannot use `sys.executable __file__` here because the script
# contains relative imports
returned
=
subprocess
.
run
(
_SUBPROCESS_COMMAND
,
input
=
input_bytes
,
capture_output
=
True
)
# check if the subprocess is successful
try
:
returned
.
check_returncode
()
except
Exception
as
e
:
# wrap raised exception to provide more information
raise
RuntimeError
(
f
"Error raised in subprocess:
\n
"
f
"
{
returned
.
stderr
.
decode
()
}
"
)
from
e
with
open
(
output_filepath
,
"rb"
)
as
f
:
return
cast
(
_T
,
pickle
.
load
(
f
))
@dataclass(frozen=True)
class _LazyRegisteredModel(_BaseRegisteredModel):
    """
    Represents a model that has not been imported in the main process.
    """
    module_name: str
    component_name: str
    class_name: str

    # Performed in another process to avoid initializing CUDA
    def inspect_model_cls(self) -> _ModelInfo:
        return _run_in_subprocess(
            lambda: _ModelInfo.from_model_cls(self.load_model_cls()))

    def load_model_cls(self) -> Type[nn.Module]:
        mod = importlib.import_module(self.module_name)
        return cast(Type[nn.Module], getattr(mod, self.class_name))


@lru_cache(maxsize=128)
def _try_load_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> Optional[Type[nn.Module]]:
    from fastvideo.v1.platforms import current_platform
    current_platform.verify_model_arch(model_arch)

    try:
        return model.load_model_cls()
    except Exception:
        logger.exception("Error in loading model architecture '%s'",
                         model_arch)
        return None


@lru_cache(maxsize=128)
def _try_inspect_model_cls(
    model_arch: str,
    model: _BaseRegisteredModel,
) -> Optional[_ModelInfo]:
    try:
        return model.inspect_model_cls()
    except Exception:
        logger.exception("Error in inspecting model architecture '%s'",
                         model_arch)
        return None


@dataclass
class _ModelRegistry:
    # Keyed by model_arch
    models: Dict[str, _BaseRegisteredModel] = field(default_factory=dict)

    def get_supported_archs(self) -> AbstractSet[str]:
        return self.models.keys()

    def register_model(
        self,
        model_arch: str,
        model_cls: Union[Type[nn.Module], str],
    ) -> None:
        """
        Register an external model to be used in FastVideo.

        :code:`model_cls` can be either:

        - A :class:`torch.nn.Module` class directly referencing the model.
        - A string in the format :code:`<module>:<class>` which can be used to
          lazily import the model. This is useful to avoid initializing CUDA
          when importing the model and thus the related error
          :code:`RuntimeError: Cannot re-initialize CUDA in forked subprocess`.
        """
        if model_arch in self.models:
            logger.warning(
                "Model architecture %s is already registered, and will be "
                "overwritten by the new model class %s.", model_arch,
                model_cls)

        if isinstance(model_cls, str):
            split_str = model_cls.split(":")
            if len(split_str) != 2:
                msg = "Expected a string in the format `<module>:<class>`"
                raise ValueError(msg)

            model = _LazyRegisteredModel(*split_str)
        else:
            model = _RegisteredModel.from_model_cls(model_cls)

        self.models[model_arch] = model
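A hedged usage sketch of the two accepted `model_cls` forms; the architecture name, module path, and `MyVideoDiT` class are invented for illustration, and the eager form assumes `_ModelInfo.from_model_cls` only introspects class attributes.

# Illustrative only; "MyVideoDiT" and "my_pkg.models.my_dit" are placeholders.
import torch.nn as nn


class MyVideoDiT(nn.Module):
    pass


registry = _ModelRegistry()

# Eager form: pass the nn.Module subclass itself.
registry.register_model("MyVideoDiT", MyVideoDiT)

# Lazy "<module>:<class>" form: the module is imported only when the
# architecture is resolved, keeping CUDA uninitialized at registration time.
# registry.register_model("MyVideoDiT", "my_pkg.models.my_dit:MyVideoDiT")

print(registry.get_supported_archs())  # dict_keys(['MyVideoDiT'])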
    def _raise_for_unsupported(self, architectures: List[str]) -> NoReturn:
        all_supported_archs = self.get_supported_archs()

        if any(arch in all_supported_archs for arch in architectures):
            raise ValueError(
                f"Model architectures {architectures} failed "
                "to be inspected. Please check the logs for more details.")

        raise ValueError(
            f"Model architectures {architectures} are not supported for now. "
            f"Supported architectures: {all_supported_archs}")

    def _try_load_model_cls(self,
                            model_arch: str) -> Optional[Type[nn.Module]]:
        if model_arch not in self.models:
            return None

        return _try_load_model_cls(model_arch, self.models[model_arch])

    def _try_inspect_model_cls(self, model_arch: str) -> Optional[_ModelInfo]:
        if model_arch not in self.models:
            return None

        return _try_inspect_model_cls(model_arch, self.models[model_arch])

    def _normalize_archs(
        self,
        architectures: Union[str, List[str]],
    ) -> List[str]:
        if isinstance(architectures, str):
            architectures = [architectures]
        if not architectures:
            logger.warning("No model architectures are specified")

        normalized_arch = []
        for model in architectures:
            if model not in self.models:
                model = "TransformersModel"
            normalized_arch.append(model)
        return normalized_arch

    def inspect_model_cls(
        self,
        architectures: Union[str, List[str]],
    ) -> Tuple[_ModelInfo, str]:
        architectures = self._normalize_archs(architectures)

        for arch in architectures:
            model_info = self._try_inspect_model_cls(arch)
            if model_info is not None:
                return (model_info, arch)

        return self._raise_for_unsupported(architectures)

    def resolve_model_cls(
        self,
        architectures: Union[str, List[str]],
    ) -> Tuple[Type[nn.Module], str]:
        architectures = self._normalize_archs(architectures)

        for arch in architectures:
            model_cls = self._try_load_model_cls(arch)
            if model_cls is not None:
                return (model_cls, arch)

        return self._raise_for_unsupported(architectures)


ModelRegistry = _ModelRegistry({
    model_arch:
    _LazyRegisteredModel(
        module_name=f"fastvideo.v1.models.{component_name}.{mod_relname}",
        component_name=component_name,
        class_name=cls_name,
    )
    for model_arch, (component_name, mod_relname,
                     cls_name) in _FAST_VIDEO_MODELS.items()
})
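For orientation, a hedged sketch of how the global registry is typically consulted; the architecture string below is a placeholder (real keys come from `_FAST_VIDEO_MODELS`), so the guarded branch does not execute as written.

# Hedged sketch: resolve a registered class from an architecture name as it
# would appear in a checkpoint's config. "SomeRegisteredArchitecture" is a
# placeholder, not a real key.
arch_from_config = "SomeRegisteredArchitecture"
print(sorted(ModelRegistry.get_supported_archs()))
if arch_from_config in ModelRegistry.get_supported_archs():
    model_cls, resolved_arch = ModelRegistry.resolve_model_cls(arch_from_config)
    model = model_cls(...)  # construct with the checkpoint's config/weights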
FastVideo-main/fastvideo/v1/models/schedulers/base.py
0 → 100644
View file @
c07946d8
# SPDX-License-Identifier: Apache-2.0

from abc import ABC, abstractmethod
from typing import Optional, Tuple, Union

import torch
from diffusers.utils import BaseOutput


class BaseScheduler(ABC):
    timesteps: torch.Tensor
    order: int

    def __init__(self, *args, **kwargs) -> None:
        # Check if subclass has defined all required properties
        required_attributes = ['timesteps', 'order']
        for attr in required_attributes:
            if not hasattr(self, attr):
                raise AttributeError(
                    f"Subclasses of BaseScheduler must define '{attr}' property"
                )

    @abstractmethod
    def set_shift(self, shift: float) -> None:
        pass

    @abstractmethod
    def set_timesteps(self, *args, **kwargs) -> None:
        pass

    @abstractmethod
    def scale_model_input(self,
                          sample: torch.Tensor,
                          timestep: Optional[int] = None) -> torch.Tensor:
        pass

    @abstractmethod
    def step(
        self,
        model_output: torch.Tensor,
        timestep: Union[int, torch.Tensor],
        sample: torch.Tensor,
        return_dict: bool = True,
    ) -> Union[BaseOutput, Tuple]:
        pass
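The contract is small enough to illustrate with a toy subclass; this is a hedged sketch of a trivial identity scheduler invented for illustration, not a class that exists in FastVideo.

# Hedged sketch of a minimal BaseScheduler subclass (illustrative only).
import torch


class _IdentityScheduler(BaseScheduler):

    def __init__(self) -> None:
        self.timesteps = torch.arange(0, 10, dtype=torch.float32)
        self.order = 1
        super().__init__()  # validates that timesteps/order are defined

    def set_shift(self, shift: float) -> None:
        pass  # this toy schedule has no shift

    def set_timesteps(self, num_inference_steps: int, **kwargs) -> None:
        self.timesteps = torch.arange(0, num_inference_steps,
                                      dtype=torch.float32)

    def scale_model_input(self, sample, timestep=None):
        return sample  # no input scaling

    def step(self, model_output, timestep, sample, return_dict=True):
        return (sample, )  # identity update, returned as a tuple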
FastVideo-main/fastvideo/v1/models/schedulers/scheduling_flow_match_euler_discrete.py
0 → 100644
View file @
c07946d8
# SPDX-License-Identifier: Apache-2.0
# Copyright 2024 Stability AI, Katherine Crowson and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
#
# Modified from diffusers==0.29.2
#
# ==============================================================================

from dataclasses import dataclass
from typing import Any, Optional, Tuple, Union

import torch
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.schedulers.scheduling_utils import SchedulerMixin
from diffusers.utils import BaseOutput, logging

from fastvideo.v1.models.schedulers.base import BaseScheduler

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


@dataclass
class FlowMatchDiscreteSchedulerOutput(BaseOutput):
    """
    Output class for the scheduler's `step` function output.

    Args:
        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
            denoising loop.
    """

    prev_sample: torch.FloatTensor


class FlowMatchDiscreteScheduler(SchedulerMixin, ConfigMixin, BaseScheduler):
    """
    Euler scheduler.

    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
    methods the library implements for all schedulers such as loading and saving.

    Args:
        num_train_timesteps (`int`, defaults to 1000):
            The number of diffusion steps to train the model.
        timestep_spacing (`str`, defaults to `"linspace"`):
            The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
            Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
        shift (`float`, defaults to 1.0):
            The shift value for the timestep schedule.
        reverse (`bool`, defaults to `True`):
            Whether to reverse the timestep schedule.
    """

    _compatibles: list[Any] = []
    order = 1

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        shift: float = 1.0,
        reverse: bool = True,
        solver: str = "euler",
        n_tokens: Optional[int] = None,
        **kwargs,
    ):
        sigmas = torch.linspace(1, 0, num_train_timesteps + 1)

        if not reverse:
            sigmas = sigmas.flip(0)

        self.sigmas = sigmas
        # the value fed to model
        self.timesteps = (sigmas[:-1] * num_train_timesteps).to(
            dtype=torch.float32)

        self._step_index: int | None = None
        self._begin_index = 0

        self.supported_solver = ["euler"]
        if solver not in self.supported_solver:
            raise ValueError(
                f"Solver {solver} not supported. "
                f"Supported solvers: {self.supported_solver}")

        BaseScheduler.__init__(self)
    @property
    def step_index(self):
        """
        The index counter for current timestep. It will increase 1 after each scheduler step.
        """
        return self._step_index

    @property
    def begin_index(self):
        """
        The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
        """
        return self._begin_index

    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
    def set_begin_index(self, begin_index: int = 0):
        """
        Sets the begin index for the scheduler. This function should be run from pipeline before the inference.

        Args:
            begin_index (`int`):
                The begin index for the scheduler.
        """
        self._begin_index = begin_index

    def _sigma_to_t(self, sigma):
        return sigma * self.config.num_train_timesteps

    def set_timesteps(
        self,
        num_inference_steps: int,
        device: Union[str, torch.device] = None,
        n_tokens: int = 0,
    ):
        """
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).

        Args:
            num_inference_steps (`int`):
                The number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, *optional*):
                The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
            n_tokens (`int`, *optional*):
                Number of tokens in the input sequence.
        """
        self.num_inference_steps = num_inference_steps

        sigmas = torch.linspace(1, 0, num_inference_steps + 1)
        sigmas = self.sd3_time_shift(sigmas)

        if not self.config.reverse:
            sigmas = 1 - sigmas

        self.sigmas = sigmas
        self.timesteps = (sigmas[:-1] * self.config.num_train_timesteps).to(
            dtype=torch.float32, device=device)

        # Reset step index
        self._step_index = None

    def index_for_timestep(self, timestep, schedule_timesteps=None) -> int:
        if schedule_timesteps is None:
            schedule_timesteps = self.timesteps

        indices = (schedule_timesteps == timestep).nonzero()

        # The sigma index that is taken for the **very** first `step`
        # is always the second index (or the last index if there is only 1)
        # This way we can ensure we don't accidentally skip a sigma in
        # case we start in the middle of the denoising schedule (e.g. for image-to-image)
        pos = 1 if len(indices) > 1 else 0
        idx: int = indices[pos].item()

        return idx

    def set_shift(self, shift: float) -> None:
        self.config.shift = shift

    def _init_step_index(self, timestep) -> None:
        if self.begin_index is None:
            if isinstance(timestep, torch.Tensor):
                timestep = timestep.to(self.timesteps.device)
            self._step_index = self.index_for_timestep(timestep)
        else:
            self._step_index = self._begin_index

    def scale_model_input(self,
                          sample: torch.Tensor,
                          timestep: Optional[int] = None) -> torch.Tensor:
        return sample

    def sd3_time_shift(self, t: torch.Tensor):
        return (self.config.shift * t) / (1 + (self.config.shift - 1) * t)
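Written out, `sd3_time_shift` applies the SD3-style warp to each normalized sigma t in [0, 1] using the configured shift value s:

\[
  \sigma_{\text{shift}}(t) \;=\; \frac{s\,t}{1 + (s - 1)\,t}
\]

The warp fixes the endpoints (t = 0 and t = 1 map to themselves) and reduces to the identity for s = 1; for s > 1 intermediate sigmas are pushed upward, so a larger share of the inference steps is spent at high noise levels.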
    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: Union[float, torch.FloatTensor],
        sample: torch.FloatTensor,
        return_dict: bool = True,
        **kwargs,
    ) -> Union[FlowMatchDiscreteSchedulerOutput, Tuple]:
        """
        Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            model_output (`torch.FloatTensor`):
                The direct output from learned diffusion model.
            timestep (`float`):
                The current discrete timestep in the diffusion chain.
            sample (`torch.FloatTensor`):
                A current instance of a sample created by the diffusion process.
            generator (`torch.Generator`, *optional*):
                A random number generator.
            n_tokens (`int`, *optional*):
                Number of tokens in the input sequence.
            return_dict (`bool`):
                Whether or not to return a [`FlowMatchDiscreteSchedulerOutput`] or tuple.

        Returns:
            [`FlowMatchDiscreteSchedulerOutput`] or `tuple`:
                If return_dict is `True`, [`FlowMatchDiscreteSchedulerOutput`] is returned, otherwise a tuple is
                returned where the first element is the sample tensor.
        """
        if isinstance(timestep, (int, torch.IntTensor, torch.LongTensor)):
            raise ValueError(
                ("Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
                 " `FlowMatchDiscreteScheduler.step()` is not supported. Make sure to pass"
                 " one of the `scheduler.timesteps` as a timestep."), )

        if self.step_index is None:
            self._init_step_index(timestep)

        # Upcast to avoid precision issues when computing prev_sample
        sample = sample.to(torch.float32)

        assert self.step_index is not None
        dt = self.sigmas[self.step_index + 1] - self.sigmas[self.step_index]

        if self.config.solver == "euler":
            prev_sample = sample + model_output.to(torch.float32) * dt
        else:
            raise ValueError(
                f"Solver {self.config.solver} not supported. "
                f"Supported solvers: {self.supported_solver}")

        # upon completion increase step index by one
        assert self._step_index is not None
        self._step_index += 1

        if not return_dict:
            return (prev_sample, )

        return FlowMatchDiscreteSchedulerOutput(prev_sample=prev_sample)

    def __len__(self):
        return self.config.num_train_timesteps
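Putting the pieces together, a hedged sketch of the Euler sampling loop this scheduler supports; the zeroed `model_output`, the latent shape, and the shift value are stand-ins for the real DiT forward pass and pipeline configuration.

# Hedged sketch of a flow-matching Euler sampling loop (illustrative only).
import torch

scheduler = FlowMatchDiscreteScheduler(shift=7.0)  # shift value is an example
scheduler.set_timesteps(num_inference_steps=50, device="cpu")

latents = torch.randn(1, 16, 8, 8)  # toy latent shape

for t in scheduler.timesteps:
    # The real pipeline would call the video DiT here to predict the velocity;
    # a zero tensor stands in for that prediction.
    model_output = torch.zeros_like(latents)
    latents = scheduler.step(model_output, t, latents).prev_sample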