chenpangpang / ComfyUI · Commits

Commit a9ac56fc
Authored Jul 26, 2024 by comfyanonymous
Own BertModel implementation that works with lowvram.
Parent: 25b51b1a

Showing 2 changed files, with 142 additions and 44 deletions:

    comfy/text_encoders/bert.py    +139   -0
    comfy/text_encoders/hydit.py     +3  -44
comfy/text_encoders/bert.py (new file, mode 100644)
import torch
from comfy.ldm.modules.attention import optimized_attention_for_device


class BertAttention(torch.nn.Module):
    def __init__(self, embed_dim, heads, dtype, device, operations):
        super().__init__()

        self.heads = heads
        self.query = operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
        self.key = operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)
        self.value = operations.Linear(embed_dim, embed_dim, bias=True, dtype=dtype, device=device)

    def forward(self, x, mask=None, optimized_attention=None):
        q = self.query(x)
        k = self.key(x)
        v = self.value(x)

        out = optimized_attention(q, k, v, self.heads, mask)
        return out


class BertOutput(torch.nn.Module):
    def __init__(self, input_dim, output_dim, layer_norm_eps, dtype, device, operations):
        super().__init__()
        self.dense = operations.Linear(input_dim, output_dim, dtype=dtype, device=device)
        self.LayerNorm = operations.LayerNorm(output_dim, eps=layer_norm_eps, dtype=dtype, device=device)
        # self.dropout = nn.Dropout(0.0)

    def forward(self, x, y):
        x = self.dense(x)
        # hidden_states = self.dropout(hidden_states)
        x = self.LayerNorm(x + y)
        return x


class BertAttentionBlock(torch.nn.Module):
    def __init__(self, embed_dim, heads, layer_norm_eps, dtype, device, operations):
        super().__init__()
        self.self = BertAttention(embed_dim, heads, dtype, device, operations)
        self.output = BertOutput(embed_dim, embed_dim, layer_norm_eps, dtype, device, operations)

    def forward(self, x, mask, optimized_attention):
        y = self.self(x, mask, optimized_attention)
        return self.output(y, x)


class BertIntermediate(torch.nn.Module):
    def __init__(self, embed_dim, intermediate_dim, dtype, device, operations):
        super().__init__()
        self.dense = operations.Linear(embed_dim, intermediate_dim, dtype=dtype, device=device)

    def forward(self, x):
        x = self.dense(x)
        return torch.nn.functional.gelu(x)


class BertBlock(torch.nn.Module):
    def __init__(self, embed_dim, intermediate_dim, heads, layer_norm_eps, dtype, device, operations):
        super().__init__()
        self.attention = BertAttentionBlock(embed_dim, heads, layer_norm_eps, dtype, device, operations)
        self.intermediate = BertIntermediate(embed_dim, intermediate_dim, dtype, device, operations)
        self.output = BertOutput(intermediate_dim, embed_dim, layer_norm_eps, dtype, device, operations)

    def forward(self, x, mask, optimized_attention):
        x = self.attention(x, mask, optimized_attention)
        y = self.intermediate(x)
        return self.output(y, x)


class BertEncoder(torch.nn.Module):
    def __init__(self, num_layers, embed_dim, intermediate_dim, heads, layer_norm_eps, dtype, device, operations):
        super().__init__()
        self.layer = torch.nn.ModuleList([BertBlock(embed_dim, intermediate_dim, heads, layer_norm_eps, dtype, device, operations) for i in range(num_layers)])

    def forward(self, x, mask=None, intermediate_output=None):
        optimized_attention = optimized_attention_for_device(x.device, mask=mask is not None, small_input=True)

        if intermediate_output is not None:
            if intermediate_output < 0:
                intermediate_output = len(self.layer) + intermediate_output

        intermediate = None
        for i, l in enumerate(self.layer):
            x = l(x, mask, optimized_attention)
            if i == intermediate_output:
                intermediate = x.clone()
        return x, intermediate


class BertEmbeddings(torch.nn.Module):
    def __init__(self, vocab_size, max_position_embeddings, type_vocab_size, pad_token_id, embed_dim, layer_norm_eps, dtype, device, operations):
        super().__init__()
        self.word_embeddings = torch.nn.Embedding(vocab_size, embed_dim, padding_idx=pad_token_id, dtype=dtype, device=device)
        self.position_embeddings = torch.nn.Embedding(max_position_embeddings, embed_dim, dtype=dtype, device=device)
        self.token_type_embeddings = torch.nn.Embedding(type_vocab_size, embed_dim, dtype=dtype, device=device)

        self.LayerNorm = operations.LayerNorm(embed_dim, eps=layer_norm_eps, dtype=dtype, device=device)

    def forward(self, input_tokens, token_type_ids=None):
        x = self.word_embeddings(input_tokens)
        x += self.position_embeddings.weight[:x.shape[1]]
        if token_type_ids is not None:
            x += self.token_type_embeddings(token_type_ids)
        else:
            x += self.token_type_embeddings.weight[0]
        x = self.LayerNorm(x)
        return x


class BertModel_(torch.nn.Module):
    def __init__(self, config_dict, dtype, device, operations):
        super().__init__()
        embed_dim = config_dict["hidden_size"]
        layer_norm_eps = config_dict["layer_norm_eps"]

        self.embeddings = BertEmbeddings(config_dict["vocab_size"], config_dict["max_position_embeddings"], config_dict["type_vocab_size"], config_dict["pad_token_id"], embed_dim, layer_norm_eps, dtype, device, operations)
        self.encoder = BertEncoder(config_dict["num_hidden_layers"], embed_dim, config_dict["intermediate_size"], config_dict["num_attention_heads"], layer_norm_eps, dtype, device, operations)

    def forward(self, input_tokens, attention_mask=None, intermediate_output=None, final_layer_norm_intermediate=True):
        x = self.embeddings(input_tokens)
        mask = None
        if attention_mask is not None:
            mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1])
            mask = mask.masked_fill(mask.to(torch.bool), float("-inf"))

        x, i = self.encoder(x, mask, intermediate_output)
        return x, i


class BertModel(torch.nn.Module):
    def __init__(self, config_dict, dtype, device, operations):
        super().__init__()
        self.bert = BertModel_(config_dict, dtype, device, operations)
        self.num_layers = config_dict["num_hidden_layers"]

    def get_input_embeddings(self):
        return self.bert.embeddings.word_embeddings

    def set_input_embeddings(self, embeddings):
        self.bert.embeddings.word_embeddings = embeddings

    def forward(self, *args, **kwargs):
        return self.bert(*args, **kwargs)
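For orientation, here is a minimal usage sketch (not part of the commit) showing how this standalone BertModel can be instantiated and run. It assumes a ComfyUI checkout on the Python path; the config keys mirror the Hugging Face-style BERT config that hydit_clip.json supplies to BertModel_, but the values below are illustrative placeholders, and plain torch.nn stands in for ComfyUI's custom `operations` namespace since only Linear and LayerNorm constructors are needed here.

import torch
from comfy.text_encoders.bert import BertModel

# Illustrative config values (placeholders, not the real hydit_clip.json numbers).
config = {
    "hidden_size": 768,
    "intermediate_size": 3072,
    "num_hidden_layers": 12,
    "num_attention_heads": 12,
    "layer_norm_eps": 1e-12,
    "vocab_size": 30522,
    "max_position_embeddings": 512,
    "type_vocab_size": 2,
    "pad_token_id": 0,
}

# torch.nn works as the `operations` namespace because the model only calls
# operations.Linear and operations.LayerNorm.
model = BertModel(config, dtype=torch.float32, device="cpu", operations=torch.nn)

tokens = torch.randint(0, config["vocab_size"], (1, 16))
hidden, intermediate = model(tokens, intermediate_output=-2)
print(hidden.shape)        # torch.Size([1, 16, 768])
print(intermediate.shape)  # output of the second-to-last layer, same shape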
comfy/text_encoders/hydit.py
  from comfy import sd1_clip
- from transformers import T5TokenizerFast, BertTokenizer, BertModel, modeling_utils, BertConfig
+ from transformers import BertTokenizer
  from .spiece_tokenizer import SPieceTokenizer
+ from .bert import BertModel
  import comfy.text_encoders.t5
  import os
  import torch
- import contextlib
-
- @contextlib.contextmanager
- def use_comfy_ops(ops, device=None, dtype=None):
-     old_torch_nn_linear = torch.nn.Linear
-     force_device = device
-     force_dtype = dtype
-     def linear_with_dtype(in_features: int, out_features: int, bias: bool = True, device=None, dtype=None):
-         if force_device is not None:
-             device = force_device
-         if force_dtype is not None:
-             dtype = force_dtype
-         return ops.Linear(in_features, out_features, bias=bias, device=device, dtype=dtype)
-     torch.nn.Linear = linear_with_dtype
-     try:
-         yield
-     finally:
-         torch.nn.Linear = old_torch_nn_linear
-
- class RobertaWrapper(torch.nn.Module):
-     def __init__(self, config_dict, dtype, device, operations):
-         super().__init__()
-         config = BertConfig(**config_dict)
-         with use_comfy_ops(operations, device, dtype):
-             with modeling_utils.no_init_weights():
-                 self.bert = BertModel(config, add_pooling_layer=False)
-         self.num_layers = config.num_hidden_layers
-
-     def get_input_embeddings(self):
-         return self.bert.get_input_embeddings()
-
-     def set_input_embeddings(self, value):
-         return self.bert.set_input_embeddings(value)
-
-     def forward(self, input_tokens, attention_mask=None, intermediate_output=None, final_layer_norm_intermediate=True):
-         intermediate = None
-         out = self.bert(input_ids=input_tokens, output_hidden_states=intermediate_output is not None, attention_mask=attention_mask)
-         return out.last_hidden_state, intermediate, out.pooler_output

  class HyditBertModel(sd1_clip.SDClipModel):
      def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None):
          textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "hydit_clip.json")
-         super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 101, "end": 102, "pad": 0}, model_class=RobertaWrapper, enable_attention_masks=True, return_attention_masks=True)
+         super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 101, "end": 102, "pad": 0}, model_class=BertModel, enable_attention_masks=True, return_attention_masks=True)

  class HyditBertTokenizer(sd1_clip.SDTokenizer):
      def __init__(self, embedding_directory=None, tokenizer_data={}):
  ...
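The effect of dropping RobertaWrapper for the hand-rolled BertModel is that every Linear and LayerNorm now goes through the injected `operations` namespace instead of a monkey-patched transformers model, which appears to be what makes it compatible with lowvram mode. As a rough illustration of that pattern only (a simplified sketch, not ComfyUI's actual comfy.ops code), such a namespace can keep weights wherever they were loaded and cast them to the activations' device and dtype at call time:

import torch


class lazy_cast_ops:
    # Sketch of an operations namespace whose layers cast their weights to the
    # input's device/dtype only when called, so the full model does not have to
    # be resident on the GPU. Illustrative; not ComfyUI's real implementation.
    class Linear(torch.nn.Linear):
        def forward(self, x):
            weight = self.weight.to(device=x.device, dtype=x.dtype)
            bias = self.bias.to(device=x.device, dtype=x.dtype) if self.bias is not None else None
            return torch.nn.functional.linear(x, weight, bias)

    class LayerNorm(torch.nn.LayerNorm):
        def forward(self, x):
            weight = self.weight.to(device=x.device, dtype=x.dtype)
            bias = self.bias.to(device=x.device, dtype=x.dtype)
            return torch.nn.functional.layer_norm(x, self.normalized_shape, weight, bias, self.eps)


# Hypothetical usage with the new model:
#   model = BertModel(config, dtype=torch.float16, device="cpu", operations=lazy_cast_ops)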