OpenDAS / ColossalAI · Commits

Commit 820ea4d0, authored Nov 02, 2022 by oahzxl
parent 86f2a314

    align evoformer
Showing 6 changed files with 67 additions and 192 deletions:
    chunk_codegen.py        +21  -122
    chunk_codegen_run.py    +34  -63
    evoformer/evoformer.py  +6   -1
    evoformer/kernel.py     +1   -1
    evoformer/msa.py        +1   -1
    evoformer/triangle.py   +4   -4
chunk_codegen.py (view file @ 820ea4d0)
 import colossalai
 import torch
+import copy
 from typing import List, Callable, Any, Tuple, Dict, Iterable
 try:
 ...
@@ -17,75 +18,19 @@ else:
 __all__ = ['python_code_with_activation_checkpoint']

-def _gen_saved_tensors_hooks():
-    """
-    Generate saved tensors hooks
-    """
-    pack_hook = """def pack_hook_input(self, x):
-    if getattr(x, "offload", False):
-        return (x.device, x.cpu())
-    else:
-        return x
-
-def pack_hook_no_input(self, x):
-    if getattr(x, "offload", True):
-        return (x.device, x.cpu())
-    else:
-        return x
-"""
-
-    unpack_hook = """def unpack_hook(self, packed):
-    if isinstance(packed, tuple):
-        device, tensor = packed
-        return tensor.to(device)
-    else:
-        return packed
-"""
-
-    return pack_hook, unpack_hook
-
-
-def _gen_loop_5(to_keep):
-    context = "chunk_result = []\nfor gen_loop_idx in range(4):\n"
-    context += "    chunk_tensor = " + to_keep + "[gen_loop_idx, :]\n"
+def _gen_loop_start(to_keep, chunk_size=2):
+    context = "chunk_result = []; chunk_size = %d\nfor gen_loop_idx in range(0, %s.shape[0], chunk_size):\n" % (chunk_size, to_keep[0])
+    context += "    chunk_tensor = " + to_keep + "[gen_loop_idx:gen_loop_idx + chunk_size, :]\n"
     return context


-def _gen_loop_5_final(final_name, to_keep):
+def _gen_loop_end(final_name, to_keep):
     context = "    chunk_result.append(" + final_name + ")\n"
     context += "chunk_result = torch.cat(chunk_result, dim=0); " + to_keep[0] + " = None\n"
     context += final_name + " = chunk_result; chunk_result = None\n"
     return context


-def _gen_save_tensors_hooks_context(offload_input=True) -> str:
-    """Generate customized saved_tensors_hooks
-
-    Args:
-        offload_input (bool, optional): whether we need offload input, if offload_input=False,
-        we will use self.pack_hook_no_input instead. Defaults to True.
-
-    Returns:
-        str: generated context
-    """
-    if offload_input:
-        context = "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_input, self.unpack_hook):\n"
-    else:
-        context = "with torch.autograd.graph.saved_tensors_hooks(self.pack_hook_no_input, self.unpack_hook):\n"
-    return context
-
-
-def _gen_save_on_cpu_context():
-    """
-    Generate save on cpu context
-    """
-    context = "with torch.autograd.graph.save_on_cpu(pin_memory=True):\n"
-    return context
-
-
 def _find_input_and_output_nodes(nodes: List[Node]):
     """
     Find the input and output node names which are not found in the given list of nodes.
 ...
@@ -112,49 +57,6 @@ def _find_input_and_output_nodes(nodes: List[Node]):
     return input_nodes, output_nodes


-def _find_ckpt_regions(nodes: List[Node]):
-    """
-    Find the checkpoint regions given a list of consecutive nodes. The outputs will be list
-    of tuples, each tuple is in the form of (start_index, end_index).
-    """
-    ckpt_nodes = []
-    ckpt_regions = []
-    start = -1
-    end = -1
-    current_region = None
-
-    for idx, node in enumerate(nodes):
-        if hasattr(node, 'activation_checkpoint'):
-            act_ckpt_label = node.activation_checkpoint
-
-            # this activation checkpoint label is not set yet
-            # meaning this is the first node of the activation ckpt region
-            if current_region is None:
-                current_region = act_ckpt_label
-                start = idx
-
-            # if activation checkpoint has changed
-            # we restart the tracking
-            # e.g. node ckpt states = [ckpt1, ckpt2, ckpt2, ckpt2]
-            if act_ckpt_label != current_region:
-                assert start != -1
-                ckpt_regions.append((start, idx - 1))
-                current_region = act_ckpt_label
-                start = idx
-                end = -1
-
-        elif current_region is not None and not hasattr(node, 'activation_checkpoint'):
-            # used to check the case below
-            # node ckpt states = [ckpt, ckpt, non-ckpt]
-            end = idx - 1
-            assert start != -1 and end != -1
-            ckpt_regions.append((start, end))
-            start = end = -1
-            current_region = None
-
-        else:
-            pass
-
-    return ckpt_regions


 def _find_offload_regions(nodes: List[Node]):
     """This function is to find the offload regions
     In pofo algorithm, during annotation, we will annotate the offload region with the
 ...
@@ -400,12 +302,9 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_v
         emit_node_func: function to emit node
         delete_unused_value_func: function to remove the unused value
     """
-    ckpt_regions = _find_nested_ckpt_regions(nodes, 0)
-    start_idx = [item[0] for item in ckpt_regions]
-    end_idx = [item[1] for item in ckpt_regions]

     # find the offload regions
-    chunk_regions, chunk_labels = _find_offload_regions(nodes)
+    chunk_regions = [(1, 4)]
     chunk_starts = [item[0] for item in chunk_regions]
     chunk_ends = [item[1] for item in chunk_regions]
     chunk_inputs = []
 ...
@@ -424,7 +323,7 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_v
     # this flag is to prevent repeated insert of save tensors
     # hooks definition in ckpt_func
     node_idx = 0
-    to_keep = []
+    chunk_var = []
     while node_idx < len(node_list):
         # break if we finish the processing all the nodes
         if node_idx >= len(node_list):
 ...
@@ -435,28 +334,30 @@ def emit_code_with_chunk(body, ckpt_func, nodes, emit_node_func, delete_unused_v
         node = node_list[node_idx]

         if node_idx in chunk_starts:
-            # save chunk input var, dont delete it
-            to_keep.extend(node.args[0].name)
             within_chunk_region = True
+            # save chunk input var, dont delete it
+            chunk_var.append(node.args[0].name)

             # add for loop
-            body.append(_gen_loop_5(to_keep[0]))
+            body.append(_gen_loop_start(chunk_var[0]))
+            # change first node's input to new chunked var
+            node_args = list(node.args)
+            node_args[0] = 'chunk_tensor'

         if within_chunk_region:
             emit_node_func(node, body)
+            # replace input var with chunk var
+            if node_idx in chunk_starts:
+                body[-1] = body[-1].replace("(" + chunk_var[0] + ")", '(chunk_tensor)')
             body[-1] = '    ' + body[-1]
-            delete_unused_value_func(node, body, to_keep)
+            delete_unused_value_func(node, body, chunk_var)
         else:
             emit_node_func(node, body)
             if node_idx not in chunk_inputs:
-                delete_unused_value_func(node, body, to_keep)
+                delete_unused_value_func(node, body, chunk_var)

         if node_idx in chunk_ends:
-            body.append(_gen_loop_5_final(node.name, to_keep))
-            to_keep = []
+            body.append(_gen_loop_end(node.name, chunk_var))
+            chunk_var = []
             within_chunk_region = False

         node_idx += 1
 ...
@@ -580,9 +481,7 @@ if CODEGEN_AVAILABLE:
                 body.append('\n')
                 return
             nodes_to_delete = user_to_last_uses.get(user, [])
-            for n in nodes_to_delete:
-                if n.name in to_keep:
-                    nodes_to_delete.remove(n)
+            nodes_to_delete = [i for i in nodes_to_delete if i.name not in to_keep]
             if len(nodes_to_delete):
                 to_delete_str = ' = '.join([repr(n) for n in nodes_to_delete] + ['None'])
                 body.append(f';  {to_delete_str}\n')
 ...
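Two details in this file are worth calling out. The last hunk replaces an in-place nodes_to_delete.remove(n) inside a for loop over the same list, which can skip elements, with an equivalent list comprehension. The new _gen_loop_start / _gen_loop_end helpers build a source string that slices the chunk region's input along dim 0, runs the emitted node lines once per slice, and concatenates the partial results. Below is a minimal, self-contained sketch of the code shape those strings aim to produce; the linear, node, and out names are invented for illustration and are not from the traced graph in this commit.

import torch

# Stand-in "emitted node line" (a Linear layer) run under the chunked loop that
# _gen_loop_start/_gen_loop_end describe. Illustrative sketch only.
linear = torch.nn.Linear(8, 8)
node = torch.randn(6, 8)

chunk_result = []; chunk_size = 2
for gen_loop_idx in range(0, node.shape[0], chunk_size):
    chunk_tensor = node[gen_loop_idx:gen_loop_idx + chunk_size, :]
    out = linear(chunk_tensor)    # emitted node line, its first input rewritten to chunk_tensor
    chunk_result.append(out)      # appended by the _gen_loop_end string
chunk_result = torch.cat(chunk_result, dim=0); node = None
out = chunk_result; chunk_result = None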
chunk_codegen_run.py (view file @ 820ea4d0)
 ...
@@ -9,60 +9,39 @@ import colossalai
 from colossalai.utils import free_port
 from colossalai.core import global_context as gpc
 from colossalai.fx.graph_module import ColoGraphModule
-try:
-    from chunk_codegen import ChunkCodeGen
-    with_codegen = True
-except:
-    # fall back to older pytorch version
-    from chunk_codegen import python_code_with_activation_checkpoint
-    with_codegen = False
+from evoformer.evoformer import evoformer_base
+from chunk_codegen import ChunkCodeGen
+with_codegen = True
-class MyNet(torch.nn.Module):
-
-    def __init__(self) -> None:
-        super().__init__()
-        self.linear0 = torch.nn.Linear(4, 4)
-        self.linear1 = torch.nn.Linear(4, 4)
-        self.linear2 = torch.nn.Linear(4, 4)
-        self.linear3 = torch.nn.Linear(4, 4)
-        self.linear4 = torch.nn.Linear(4, 4)
-        self.linear5 = torch.nn.Linear(4, 4)
-        self.linear6 = torch.nn.Linear(4, 4)
-
-    def forward(self, x):
-        x = self.linear0(x)
-        x = self.linear1(x)
-        x = self.linear2(x)
-        x = self.linear3(x)
-        x = self.linear4(x)
-        x = self.linear5(x)
-        x = self.linear6(x)
-        return x
 def _is_all_gradient_close(m: torch.nn.Module, gm: GraphModule) -> bool:
     for m_p, gm_p in zip(m.parameters(), gm.parameters()):
-        if not torch.allclose(m_p.grad, gm_p.grad):
+        if m_p.grad is not None and not torch.allclose(m_p.grad, gm_p.grad):
             return False
     return True


-def _test_fwd_and_bwd(model: torch.nn.Module, gm: ColoGraphModule, data: torch.Tensor):
+def _is_all_param_close(m: torch.nn.Module, gm: GraphModule) -> bool:
+    for m_p, gm_p in zip(m.parameters(), gm.parameters()):
+        if m_p.grad is not None and not torch.allclose(m_p.data, gm_p.data):
+            return False
+    return True
+
+
+def _test_fwd_and_bwd(model: torch.nn.Module, gm: ColoGraphModule, node, pair):
     # test forward
-    non_fx_out = model(data)
-    fx_out = gm(data)
-    print(non_fx_out.shape, fx_out.shape)
-    assert torch.equal(non_fx_out, fx_out), "fx_out doesn't comply with original output"
+    non_fx_out = model(node.clone(), pair.clone())
+    fx_out = gm(node.clone(), pair.clone())
+    assert torch.equal(non_fx_out[0], fx_out[0]), "fx_out doesn't comply with original output"
+    assert torch.equal(non_fx_out[1], fx_out[1]), "fx_out doesn't comply with original output"

     # test barckward
-    loss0 = non_fx_out.sum()
-    loss0.backward()
-    loss1 = fx_out.sum()
-    loss1.backward()
-    assert _is_all_gradient_close(model, gm), "gm doesn't have the same gradient as original one"
+    # loss0 = non_fx_out[0].sum() + non_fx_out[1].sum()
+    # loss0.backward()
+    # loss1 = fx_out[0].sum() + fx_out[1].sum()
+    # loss1.backward()
+    # assert _is_all_param_close(model, gm)
+    # assert _is_all_gradient_close(model, gm), "gm doesn't have the same gradient as original one"

 def _run_offload_codegen(rank):
 ...
@@ -70,30 +49,22 @@ def _run_offload_codegen(rank):
     colossalai.launch(config={}, rank=rank, world_size=1, host='localhost', port=free_port(), backend='nccl')

     # build model and input
-    model = MyNet().cuda()
-    data = torch.rand(4, 4).cuda()
+    model = evoformer_base().cuda()
+    node = torch.randn(1, 16, 32, 256).cuda()
+    pair = torch.randn(1, 32, 32, 128).cuda()

     # trace the module and replace codegen
     tracer = ColoTracer(trace_act_ckpt=True)
     graph = tracer.trace(model)
-    codegen = ChunkCodeGen()
-    graph.set_codegen(codegen)
-
-    # annotate the activation offload part
-    # also annotate the activation_checkpoint so we could test both types
-    # of input offload
-    for node in graph.nodes:
-        if node.name == "linear0":
-            setattr(node, "activation_offload", [0, True, False])
-        if node.name == "linear1":
-            setattr(node, "activation_offload", [0, True, False])
-        # if node.name == "linear2":
-        #     setattr(node, "activation_offload", [1, True, True])
-        # if node.name == "linear4":
-        #     setattr(node, "activation_offload", [2, False, True])
-        # if node.name == "linear5":
-        #     setattr(node, "activation_checkpoint", [0])
-        #     setattr(node, "activation_offload", True)
+    # codegen = ChunkCodeGen()
+    # graph.set_codegen(codegen)
+
+    # annotate the chunk part
+    # for node in graph.nodes:
+    #     if node.name == "linear0":
+    #         setattr(node, "activation_offload", [0, True, False])
+    #     if node.name == "linear1":
+    #         setattr(node, "activation_offload", [0, True, False])

     gm = ColoGraphModule(copy.deepcopy(model), graph)
     gm.recompile()
 ...
@@ -102,7 +73,7 @@ def _run_offload_codegen(rank):
     code = graph.python_code("self").src
     print(code)
-    _test_fwd_and_bwd(model, gm, data)
+    _test_fwd_and_bwd(model, gm, node, pair)
     gpc.destroy()
 ...
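Condensed, the updated script drives the evoformer model through colossalai's FX tracer and checks forward parity between the eager module and the traced GraphModule; the chunk codegen itself is left commented out in this commit. A sketch under those assumptions follows (the ColoTracer import is not visible in the shown hunks and is inferred; a CUDA device and a colossalai install are assumed):

import copy
import torch
from colossalai.fx import ColoTracer              # inferred import, not shown in this diff
from colossalai.fx.graph_module import ColoGraphModule
from evoformer.evoformer import evoformer_base

model = evoformer_base().cuda()
node = torch.randn(1, 16, 32, 256).cuda()
pair = torch.randn(1, 32, 32, 128).cuda()

# trace the module; ChunkCodeGen is not set in this commit
graph = ColoTracer(trace_act_ckpt=True).trace(model)
gm = ColoGraphModule(copy.deepcopy(model), graph)
gm.recompile()

# forward parity check, mirroring _test_fwd_and_bwd
non_fx_out = model(node.clone(), pair.clone())
fx_out = gm(node.clone(), pair.clone())
assert torch.equal(non_fx_out[0], fx_out[0]) and torch.equal(non_fx_out[1], fx_out[1])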
evoformer/evoformer.py (view file @ 820ea4d0)
 ...
@@ -28,7 +28,7 @@ class Evoformer(nn.Module):
         super(Evoformer, self).__init__()
         self.blocks = nn.ModuleList()
-        for _ in range(3):
+        for _ in range(1):
             self.blocks.append(EvoformerBlock(d_node, d_pair))

     def forward(self, node, pair):
 ...
@@ -36,6 +36,11 @@ class Evoformer(nn.Module):
             node, pair = b(node, pair)
         return node, pair


+def evoformer_tiny():
+    return Evoformer(d_node=64, d_pair=32)
+
+
 def evoformer_base():
     return Evoformer(d_node=256, d_pair=128)
 ...
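The new evoformer_tiny factory sits next to evoformer_base and simply picks smaller widths, and the module list now holds a single block. Following the input-shape pattern used in chunk_codegen_run.py (last dim equal to d_node for node and d_pair for pair), a hypothetical usage sketch:

import torch
from evoformer.evoformer import evoformer_tiny

model = evoformer_tiny()              # Evoformer(d_node=64, d_pair=32), one EvoformerBlock
node = torch.randn(1, 16, 32, 64)     # (..., d_node)
pair = torch.randn(1, 32, 32, 32)     # (..., d_pair)
node, pair = model(node, pair)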
evoformer/kernel.py (view file @ 820ea4d0)
 ...
@@ -8,7 +8,7 @@ def bias_sigmod_ele(y, bias, z):
 def bias_dropout_add(x: torch.Tensor, bias: torch.Tensor, dropmask: torch.Tensor,
                      residual: torch.Tensor, prob: float) -> torch.Tensor:
-    out = (x + bias) * F.dropout(dropmask, p=prob, training=True)
+    out = (x + bias) * F.dropout(dropmask, p=prob, training=False)
     out = residual + out
     return out
 ...
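With training=False, F.dropout returns its input unchanged, so the all-ones dropmask is neither zeroed nor rescaled and bias_dropout_add becomes deterministic, which the equality checks in chunk_codegen_run.py rely on. A standalone sketch of that identity (shapes are illustrative):

import torch
import torch.nn.functional as F

x, bias, residual = torch.randn(2, 4), torch.randn(4), torch.randn(2, 4)
dropmask = torch.ones_like(x)

# training=False makes F.dropout the identity, so this reduces to (x + bias) + residual
out = (x + bias) * F.dropout(dropmask, p=0.25, training=False) + residual
assert torch.allclose(out, x + bias + residual)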
evoformer/msa.py (view file @ 820ea4d0)
 ...
@@ -45,7 +45,7 @@ class MSARowAttentionWithPairBias(nn.Module):
         # b = rearrange(b, 'b q k h -> b h q k')
         M = self.attention(M, b)
-        dropout_mask = torch.ones_like(M[:, 0:1, :, :], device=M.device, dtype=M.dtype)
+        dropout_mask = torch.ones_like(M[:, 0:1, :, :]).to(M.device).to(M.dtype)
         return bias_dropout_add(M, self.out_bias, dropout_mask, M_raw, prob=self.p_drop)
 ...
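The mask construction changes from the device=/dtype= keyword form of torch.ones_like to chained .to() calls (the same change is applied throughout evoformer/triangle.py below). For an ordinary tensor the two spellings give the same result, as this quick check illustrates (shapes are made up):

import torch

M = torch.randn(1, 4, 8, 16)
a = torch.ones_like(M[:, 0:1, :, :], device=M.device, dtype=M.dtype)
b = torch.ones_like(M[:, 0:1, :, :]).to(M.device).to(M.dtype)
assert torch.equal(a, b) and a.dtype == b.dtype and a.device == b.device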
evoformer/triangle.py (view file @ 820ea4d0)
 ...
@@ -51,7 +51,7 @@ class TriangleMultiplicationOutgoing(nn.Module):
         ab = torch.einsum('bikd,bjkd->bijd', left_proj_act, right_proj_act)
         ab = self.output_projection(self.layernorm2(ab))
-        dropout_mask = torch.ones_like(Z[:, 0:1, :, :], device=Z.device, dtype=Z.dtype)
+        dropout_mask = torch.ones_like(Z[:, 0:1, :, :]).to(Z.device).to(Z.dtype)
         return bias_ele_dropout_residual(ab,
                                          self.output_bias,
                                          g,
 ...
@@ -97,7 +97,7 @@ class TriangleMultiplicationIncoming(nn.Module):
         ab = torch.einsum('bkid,bkjd->bijd', left_proj_act, right_proj_act)
         ab = self.output_projection(self.layernorm2(ab))
-        dropout_mask = torch.ones_like(Z[:, 0:1, :, :], device=Z.device, dtype=Z.dtype)
+        dropout_mask = torch.ones_like(Z[:, 0:1, :, :]).to(Z.device).to(Z.dtype)
         return bias_ele_dropout_residual(ab,
                                          self.output_bias,
                                          g,
 ...
@@ -134,7 +134,7 @@ class TriangleAttentionStartingNode(nn.Module):
         Z = self.attention(Z, b)
-        dropout_mask = torch.ones_like(Z[:, 0:1, :, :], device=Z.device, dtype=Z.dtype)
+        dropout_mask = torch.ones_like(Z[:, 0:1, :, :]).to(Z.device).to(Z.dtype)
         return bias_dropout_add(Z, self.out_bias, dropout_mask, Z_raw, prob=self.p_drop)
 ...
@@ -168,7 +168,7 @@ class TriangleAttentionEndingNode(nn.Module):
         Z = self.attention(Z, b)
         Z = Z.transpose(-2, -3)
-        dropout_mask = torch.ones_like(Z[:, :, 0:1, :], device=Z.device, dtype=Z.dtype)
+        dropout_mask = torch.ones_like(Z[:, :, 0:1, :]).to(Z.device).to(Z.dtype)
         return bias_dropout_add(Z, self.out_bias, dropout_mask, Z_raw, prob=self.p_drop)
 ...