chenpangpang / ComfyUI · Commits

Commit b85216a3, authored Jul 31, 2024 by comfyanonymous
Parent: 82cae45d

Lower T5 memory usage by a few hundred MB.
Showing 3 changed files, with 32 additions and 16 deletions (+32 -16):

    comfy/ldm/hydit/models.py    +1  -1
    comfy/ops.py                 +23 -8
    comfy/text_encoders/t5.py    +8  -7
comfy/ldm/hydit/models.py
@@ -355,7 +355,7 @@ class HunYuanDiT(nn.Module):
         if self.use_style_cond:
             if style is None:
                 style = torch.zeros((extra_vec.shape[0],), device=x.device, dtype=torch.int)
-            style_embedding = self.style_embedder(style)
+            style_embedding = self.style_embedder(style, out_dtype=x.dtype)
             extra_vec = torch.cat([extra_vec, style_embedding], dim=1)

         # Concatenate all extra vectors
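The only call-site change here is passing out_dtype=x.dtype to the style embedder. With the reworked Embedding op in comfy/ops.py below, a fp16/bf16 embedding table is no longer upcast as a whole; only the gathered rows are converted to the requested dtype. A minimal standalone sketch of that pattern, using plain PyTorch rather than ComfyUI's actual op classes:

import torch

class CastEmbedding(torch.nn.Embedding):
    """Toy stand-in for an Embedding op that accepts out_dtype (illustration only)."""
    def forward(self, input, out_dtype=None):
        out = torch.nn.functional.embedding(input, self.weight)  # lookup in the stored dtype
        return out if out_dtype is None else out.to(dtype=out_dtype)

emb = CastEmbedding(32, 8, dtype=torch.float16)   # the table itself stays fp16
tokens = torch.tensor([1, 2, 3])
y = emb(tokens, out_dtype=torch.float32)          # only the 3 gathered rows become fp32
print(emb.weight.dtype, y.dtype)                  # torch.float16 torch.float32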
comfy/ops.py
@@ -19,17 +19,27 @@
 import torch
 import comfy.model_management

+def cast_to(weight, dtype=None, device=None, non_blocking=False):
+    return weight.to(device=device, dtype=dtype, non_blocking=non_blocking)
+
 def cast_to_input(weight, input, non_blocking=False):
-    return weight.to(device=input.device, dtype=input.dtype, non_blocking=non_blocking)
+    return cast_to(weight, input.dtype, input.device, non_blocking=non_blocking)

-def cast_bias_weight(s, input):
+def cast_bias_weight(s, input=None, dtype=None, device=None):
+    if input is not None:
+        if dtype is None:
+            dtype = input.dtype
+        if device is None:
+            device = input.device
+
     bias = None
-    non_blocking = comfy.model_management.device_should_use_non_blocking(input.device)
+    non_blocking = comfy.model_management.device_should_use_non_blocking(device)
     if s.bias is not None:
-        bias = cast_to_input(s.bias, input, non_blocking=non_blocking)
+        bias = cast_to(s.bias, dtype, device, non_blocking=non_blocking)
         if s.bias_function is not None:
             bias = s.bias_function(bias)
-    weight = cast_to_input(s.weight, input, non_blocking=non_blocking)
+    weight = cast_to(s.weight, dtype, device, non_blocking=non_blocking)
     if s.weight_function is not None:
         weight = s.weight_function(weight)
     return weight, bias

@@ -176,14 +186,19 @@ class disable_weight_init:
             self.bias = None
             return None

-        def forward_comfy_cast_weights(self, input):
-            weight, bias = cast_bias_weight(self, input)
-            return torch.nn.functional.embedding(input, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse)
+        def forward_comfy_cast_weights(self, input, out_dtype=None):
+            output_dtype = out_dtype
+            if self.weight.dtype == torch.float16 or self.weight.dtype == torch.bfloat16:
+                out_dtype = None
+            weight, bias = cast_bias_weight(self, device=input.device, dtype=out_dtype)
+            return torch.nn.functional.embedding(input, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse).to(dtype=output_dtype)

         def forward(self, *args, **kwargs):
             if self.comfy_cast_weights:
                 return self.forward_comfy_cast_weights(*args, **kwargs)
             else:
+                if "out_dtype" in kwargs:
+                    kwargs.pop("out_dtype")
                 return super().forward(*args, **kwargs)

         @classmethod
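The net effect of the comfy/ops.py changes: cast_to_input becomes a thin wrapper over the new cast_to, and cast_bias_weight can now be driven either by an activation tensor or by explicit dtype/device keywords (passing dtype=None keeps the stored dtype, which is how the Embedding path above avoids materializing an upcast copy of the table). A simplified, self-contained sketch of that calling convention; the real helper also handles non_blocking transfers and the optional weight_function/bias_function hooks:

import torch

def cast_to(weight, dtype=None, device=None, non_blocking=False):
    return weight.to(device=device, dtype=dtype, non_blocking=non_blocking)

def cast_bias_weight(s, input=None, dtype=None, device=None):
    # When an activation tensor is given, inherit its dtype/device unless overridden.
    if input is not None:
        if dtype is None:
            dtype = input.dtype
        if device is None:
            device = input.device
    weight = cast_to(s.weight, dtype, device)
    bias = cast_to(s.bias, dtype, device) if s.bias is not None else None
    return weight, bias

lin = torch.nn.Linear(4, 4, dtype=torch.float16)
x = torch.randn(2, 4)                                        # fp32 activations
w1, _ = cast_bias_weight(lin, x)                             # follows the activation -> fp32
w2, _ = cast_bias_weight(lin, device=x.device, dtype=None)   # keeps the stored fp16
print(w1.dtype, w2.dtype)                                    # torch.float32 torch.float16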
comfy/text_encoders/t5.py
 import torch
 import math
 from comfy.ldm.modules.attention import optimized_attention_for_device
+import comfy.ops

 class T5LayerNorm(torch.nn.Module):
     def __init__(self, hidden_size, eps=1e-6, dtype=None, device=None, operations=None):
@@ -11,7 +12,7 @@ class T5LayerNorm(torch.nn.Module):
     def forward(self, x):
         variance = x.pow(2).mean(-1, keepdim=True)
         x = x * torch.rsqrt(variance + self.variance_epsilon)
-        return self.weight.to(device=x.device, dtype=x.dtype) * x
+        return comfy.ops.cast_to_input(self.weight, x) * x

 activations = {
     "gelu_pytorch_tanh": lambda a: torch.nn.functional.gelu(a, approximate="tanh"),
@@ -82,7 +83,7 @@ class T5Attention(torch.nn.Module):
         if relative_attention_bias:
             self.relative_attention_num_buckets = 32
             self.relative_attention_max_distance = 128
-            self.relative_attention_bias = torch.nn.Embedding(self.relative_attention_num_buckets, self.num_heads, device=device)
+            self.relative_attention_bias = operations.Embedding(self.relative_attention_num_buckets, self.num_heads, device=device, dtype=dtype)

     @staticmethod
     def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
@@ -132,7 +133,7 @@ class T5Attention(torch.nn.Module):
         relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
         return relative_buckets

-    def compute_bias(self, query_length, key_length, device):
+    def compute_bias(self, query_length, key_length, device, dtype):
         """Compute binned relative position bias"""
         context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
         memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
@@ -143,7 +144,7 @@ class T5Attention(torch.nn.Module):
             num_buckets=self.relative_attention_num_buckets,
             max_distance=self.relative_attention_max_distance,
         )
-        values = self.relative_attention_bias(relative_position_bucket)  # shape (query_length, key_length, num_heads)
+        values = self.relative_attention_bias(relative_position_bucket, out_dtype=dtype)  # shape (query_length, key_length, num_heads)
         values = values.permute([2, 0, 1]).unsqueeze(0)  # shape (1, num_heads, query_length, key_length)
         return values
@@ -152,7 +153,7 @@ class T5Attention(torch.nn.Module):
         k = self.k(x)
         v = self.v(x)
         if self.relative_attention_bias is not None:
-            past_bias = self.compute_bias(x.shape[1], x.shape[1], x.device)
+            past_bias = self.compute_bias(x.shape[1], x.shape[1], x.device, x.dtype)

         if past_bias is not None:
             if mask is not None:
@@ -225,7 +226,7 @@ class T5(torch.nn.Module):
         self.encoder = T5Stack(self.num_layers, model_dim, model_dim, config_dict["d_ff"], config_dict["dense_act_fn"], config_dict["is_gated_act"], config_dict["num_heads"], config_dict["model_type"] != "umt5", dtype, device, operations)
         self.dtype = dtype
-        self.shared = torch.nn.Embedding(config_dict["vocab_size"], model_dim, device=device)
+        self.shared = operations.Embedding(config_dict["vocab_size"], model_dim, device=device, dtype=dtype)

     def get_input_embeddings(self):
         return self.shared
@@ -234,5 +235,5 @@ class T5(torch.nn.Module):
         self.shared = embeddings

     def forward(self, input_ids, *args, **kwargs):
-        x = self.shared(input_ids)
+        x = self.shared(input_ids, out_dtype=kwargs.get("dtype", torch.float32))
         return self.encoder(x, *args, **kwargs)
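For scale, the shared token embedding is where most of the savings comes from once the table is kept in its stored precision instead of being upcast. A back-of-the-envelope check of the commit message, assuming the T5-XXL configuration commonly used with ComfyUI (vocab_size 32128, model_dim 4096; these values are an assumption for illustration, not taken from the diff):

vocab_size, model_dim = 32128, 4096        # assumed T5-XXL config, not part of this commit
fp32 = vocab_size * model_dim * 4          # bytes if the table is upcast to float32
fp16 = vocab_size * model_dim * 2          # bytes if it stays in float16/bfloat16
print(f"fp32 table: {fp32 / 1e6:.0f} MB, fp16 table: {fp16 / 1e6:.0f} MB, saved: {(fp32 - fp16) / 1e6:.0f} MB")
# fp32 table: 526 MB, fp16 table: 263 MB, saved: 263 MB -- consistent with "a few hundred MB"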