ComfyUI · Commit b8e58a93
authored Jul 06, 2024 by comfyanonymous

    Cleanup T5 code a bit.

parent 80c45909
Showing 3 changed files with 23 additions and 14 deletions:

    comfy/t5.py                  +21  -14
    comfy/t5_config_base.json     +1   -0
    comfy/t5_config_xxl.json      +1   -0
comfy/t5.py  (view file @ b8e58a93)

@@ -13,29 +13,36 @@ class T5LayerNorm(torch.nn.Module):
         x = x * torch.rsqrt(variance + self.variance_epsilon)
         return self.weight.to(device=x.device, dtype=x.dtype) * x

+activations = {
+    "gelu_pytorch_tanh": lambda a: torch.nn.functional.gelu(a, approximate="tanh"),
+    "relu": torch.nn.functional.relu,
+}
+
 class T5DenseActDense(torch.nn.Module):
-    def __init__(self, model_dim, ff_dim, dtype, device, operations):
+    def __init__(self, model_dim, ff_dim, ff_activation, dtype, device, operations):
         super().__init__()
         self.wi = operations.Linear(model_dim, ff_dim, bias=False, dtype=dtype, device=device)
         self.wo = operations.Linear(ff_dim, model_dim, bias=False, dtype=dtype, device=device)
         # self.dropout = nn.Dropout(config.dropout_rate)
+        self.act = activations[ff_activation]

     def forward(self, x):
-        x = torch.nn.functional.relu(self.wi(x))
+        x = self.act(self.wi(x))
         # x = self.dropout(x)
         x = self.wo(x)
         return x

 class T5DenseGatedActDense(torch.nn.Module):
-    def __init__(self, model_dim, ff_dim, dtype, device, operations):
+    def __init__(self, model_dim, ff_dim, ff_activation, dtype, device, operations):
         super().__init__()
         self.wi_0 = operations.Linear(model_dim, ff_dim, bias=False, dtype=dtype, device=device)
         self.wi_1 = operations.Linear(model_dim, ff_dim, bias=False, dtype=dtype, device=device)
         self.wo = operations.Linear(ff_dim, model_dim, bias=False, dtype=dtype, device=device)
         # self.dropout = nn.Dropout(config.dropout_rate)
+        self.act = activations[ff_activation]

     def forward(self, x):
-        hidden_gelu = torch.nn.functional.gelu(self.wi_0(x), approximate="tanh")
+        hidden_gelu = self.act(self.wi_0(x))
         hidden_linear = self.wi_1(x)
         x = hidden_gelu * hidden_linear
         # x = self.dropout(x)

@@ -43,12 +50,12 @@ class T5DenseGatedActDense(torch.nn.Module):
         return x

 class T5LayerFF(torch.nn.Module):
-    def __init__(self, model_dim, ff_dim, ff_activation, dtype, device, operations):
+    def __init__(self, model_dim, ff_dim, ff_activation, gated_act, dtype, device, operations):
         super().__init__()
-        if ff_activation == "gelu_pytorch_tanh":
-            self.DenseReluDense = T5DenseGatedActDense(model_dim, ff_dim, dtype, device, operations)
-        elif ff_activation == "relu":
-            self.DenseReluDense = T5DenseActDense(model_dim, ff_dim, dtype, device, operations)
+        if gated_act:
+            self.DenseReluDense = T5DenseGatedActDense(model_dim, ff_dim, ff_activation, dtype, device, operations)
+        else:
+            self.DenseReluDense = T5DenseActDense(model_dim, ff_dim, ff_activation, dtype, device, operations)

         self.layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device, operations=operations)
         # self.dropout = nn.Dropout(config.dropout_rate)

@@ -171,11 +178,11 @@ class T5LayerSelfAttention(torch.nn.Module):
         return x, past_bias

 class T5Block(torch.nn.Module):
-    def __init__(self, model_dim, inner_dim, ff_dim, ff_activation, num_heads, relative_attention_bias, dtype, device, operations):
+    def __init__(self, model_dim, inner_dim, ff_dim, ff_activation, gated_act, num_heads, relative_attention_bias, dtype, device, operations):
         super().__init__()
         self.layer = torch.nn.ModuleList()
         self.layer.append(T5LayerSelfAttention(model_dim, inner_dim, ff_dim, num_heads, relative_attention_bias, dtype, device, operations))
-        self.layer.append(T5LayerFF(model_dim, ff_dim, ff_activation, dtype, device, operations))
+        self.layer.append(T5LayerFF(model_dim, ff_dim, ff_activation, gated_act, dtype, device, operations))

     def forward(self, x, mask=None, past_bias=None, optimized_attention=None):
         x, past_bias = self.layer[0](x, mask, past_bias, optimized_attention)

@@ -183,11 +190,11 @@ class T5Block(torch.nn.Module):
         return x, past_bias

 class T5Stack(torch.nn.Module):
-    def __init__(self, num_layers, model_dim, inner_dim, ff_dim, ff_activation, num_heads, dtype, device, operations):
+    def __init__(self, num_layers, model_dim, inner_dim, ff_dim, ff_activation, gated_act, num_heads, dtype, device, operations):
         super().__init__()

         self.block = torch.nn.ModuleList(
-            [T5Block(model_dim, inner_dim, ff_dim, ff_activation, num_heads, relative_attention_bias=(i == 0), dtype=dtype, device=device, operations=operations) for i in range(num_layers)]
+            [T5Block(model_dim, inner_dim, ff_dim, ff_activation, gated_act, num_heads, relative_attention_bias=(i == 0), dtype=dtype, device=device, operations=operations) for i in range(num_layers)]
         )
         self.final_layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device, operations=operations)
         # self.dropout = nn.Dropout(config.dropout_rate)

@@ -216,7 +223,7 @@ class T5(torch.nn.Module):
         self.num_layers = config_dict["num_layers"]
         model_dim = config_dict["d_model"]

-        self.encoder = T5Stack(self.num_layers, model_dim, model_dim, config_dict["d_ff"], config_dict["dense_act_fn"], config_dict["num_heads"], dtype, device, operations)
+        self.encoder = T5Stack(self.num_layers, model_dim, model_dim, config_dict["d_ff"], config_dict["dense_act_fn"], config_dict["is_gated_act"], config_dict["num_heads"], dtype, device, operations)
         self.dtype = dtype
         self.shared = torch.nn.Embedding(config_dict["vocab_size"], model_dim, device=device)
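The refactor keeps the feed-forward math unchanged and only moves the activation choice into the new lookup table. Below is a minimal standalone sketch of the two variants the diff touches, using plain torch.nn.Linear layers and toy dimensions instead of the repository's operations.Linear wrapper; it is an illustration under those assumptions, not code from the commit.

import torch

# Activation lookup table mirroring the one added in this commit.
activations = {
    "gelu_pytorch_tanh": lambda a: torch.nn.functional.gelu(a, approximate="tanh"),
    "relu": torch.nn.functional.relu,
}

def dense_act(x, wi, wo, ff_activation):
    # Non-gated FF (as in T5DenseActDense): act(wi(x)) -> wo
    return wo(activations[ff_activation](wi(x)))

def dense_gated_act(x, wi_0, wi_1, wo, ff_activation):
    # Gated FF (as in T5DenseGatedActDense): act(wi_0(x)) * wi_1(x) -> wo
    return wo(activations[ff_activation](wi_0(x)) * wi_1(x))

# Toy dimensions for illustration only.
model_dim, ff_dim = 8, 16
x = torch.randn(2, model_dim)

# "relu" + non-gated corresponds to the base config below,
# "gelu_pytorch_tanh" + gated to the xxl config.
wi = torch.nn.Linear(model_dim, ff_dim, bias=False)
wo = torch.nn.Linear(ff_dim, model_dim, bias=False)
print(dense_act(x, wi, wo, "relu").shape)  # torch.Size([2, 8])

wi_0 = torch.nn.Linear(model_dim, ff_dim, bias=False)
wi_1 = torch.nn.Linear(model_dim, ff_dim, bias=False)
print(dense_gated_act(x, wi_0, wi_1, wo, "gelu_pytorch_tanh").shape)  # torch.Size([2, 8])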
comfy/t5_config_base.json  (view file @ b8e58a93)

@@ -8,6 +8,7 @@
     "dense_act_fn": "relu",
     "initializer_factor": 1.0,
     "is_encoder_decoder": true,
+    "is_gated_act": false,
     "layer_norm_epsilon": 1e-06,
     "model_type": "t5",
     "num_decoder_layers": 12,
comfy/t5_config_xxl.json  (view file @ b8e58a93)

@@ -8,6 +8,7 @@
     "dense_act_fn": "gelu_pytorch_tanh",
     "initializer_factor": 1.0,
     "is_encoder_decoder": true,
+    "is_gated_act": true,
     "layer_norm_epsilon": 1e-06,
     "model_type": "t5",
     "num_decoder_layers": 24,
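For illustration, a minimal sketch of how the new "is_gated_act" flag pairs with the existing "dense_act_fn" name to pick the feed-forward variant, mirroring the T5LayerFF change above. The config dicts are trimmed to the relevant keys, and select_ff is a hypothetical helper, not part of the repository.

# Hypothetical helper; the real selection happens inside T5LayerFF.__init__.
base_cfg = {"dense_act_fn": "relu", "is_gated_act": False}
xxl_cfg = {"dense_act_fn": "gelu_pytorch_tanh", "is_gated_act": True}

def select_ff(config_dict):
    # Gated -> T5DenseGatedActDense, otherwise -> T5DenseActDense,
    # with the activation name passed through in both cases.
    variant = "T5DenseGatedActDense" if config_dict["is_gated_act"] else "T5DenseActDense"
    return variant, config_dict["dense_act_fn"]

print(select_ff(base_cfg))  # ('T5DenseActDense', 'relu')
print(select_ff(xxl_cfg))   # ('T5DenseGatedActDense', 'gelu_pytorch_tanh')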