OpenDAS / FastFold · Commits

Commit b14e47f4, authored Apr 26, 2023 by zhuwenwen
Merge branch 'main' of https://github.com/hpcaitech/FastFold
Parents: 490cb6f5, 05681304
Pipeline #234 failed with stages in 0 seconds
Changes: 188 · Pipelines: 1
Showing 20 changed files with 3944 additions and 0 deletions (+3944 / -0).
fastfold/habana/fastnn/ops.py  (+313 / -0)
fastfold/habana/fastnn/triangle.py  (+228 / -0)
fastfold/habana/inject_habana.py  (+403 / -0)
fastfold/model/__init__.py  (+0 / -0)
fastfold/model/fastnn/__init__.py  (+13 / -0)
fastfold/model/fastnn/embedders.py  (+381 / -0)
fastfold/model/fastnn/embedders_multimer.py  (+387 / -0)
fastfold/model/fastnn/evoformer.py  (+332 / -0)
fastfold/model/fastnn/initializer.py  (+29 / -0)
fastfold/model/fastnn/kernel/__init__.py  (+13 / -0)
fastfold/model/fastnn/kernel/attention_core.py  (+53 / -0)
fastfold/model/fastnn/kernel/cuda_native/__init__.py  (+0 / -0)
fastfold/model/fastnn/kernel/cuda_native/csrc/compat.h  (+11 / -0)
fastfold/model/fastnn/kernel/cuda_native/csrc/layer_norm_cuda.cpp  (+125 / -0)
fastfold/model/fastnn/kernel/cuda_native/csrc/layer_norm_cuda_kernel.cu  (+500 / -0)
fastfold/model/fastnn/kernel/cuda_native/csrc/softmax_cuda.cpp  (+27 / -0)
fastfold/model/fastnn/kernel/cuda_native/csrc/softmax_cuda_kernel.cu  (+839 / -0)
fastfold/model/fastnn/kernel/cuda_native/csrc/type_shim.h  (+233 / -0)
fastfold/model/fastnn/kernel/cuda_native/layer_norm.py  (+35 / -0)
fastfold/model/fastnn/kernel/cuda_native/softmax.py  (+22 / -0)
fastfold/habana/fastnn/ops.py  (new file, mode 100755)
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch.nn import LayerNorm

from fastfold.habana.distributed import gather, scatter

from .initializer import glorot_uniform_af
from .kernel import bias_sigmod_ele
from fastfold.habana.distributed import gather, scatter
from fastfold.habana.fastnn.custom_op import fused_softmax, fused_softmax_bias

CHUNK_SIZE = None
DEBUG = False


def set_chunk_size(chunk_size):
    global CHUNK_SIZE
    CHUNK_SIZE = chunk_size


def get_chunk_size():
    global CHUNK_SIZE
    return CHUNK_SIZE
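The module-level CHUNK_SIZE switch is what the chunked loops in the attention and outer-product modules below read to bound peak memory along the parallel axis. A minimal usage sketch (the values are illustrative, not taken from this commit):

# Hypothetical driver code: lower the chunk size to trade peak memory for
# extra kernel launches; None restores single-pass execution.
from fastfold.habana.fastnn.ops import set_chunk_size, get_chunk_size

set_chunk_size(64)             # chunked loops below now step their axis by 64
assert get_chunk_size() == 64

set_chunk_size(None)           # disables chunking: each loop runs in one step
print(get_chunk_size())        # -> None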
class DropoutRowwise(nn.Module):

    def __init__(self, p):
        super(DropoutRowwise, self).__init__()
        self.p = p
        self.dropout = nn.Dropout(p=p)

    def forward(self, x):
        dropout_mask = torch.ones_like(x[:, 0:1, :, :])
        dropout_mask = self.dropout(dropout_mask)
        return dropout_mask * x


class DropoutColumnwise(nn.Module):

    def __init__(self, p):
        super(DropoutColumnwise, self).__init__()
        self.p = p
        self.dropout = nn.Dropout(p=p)

    def forward(self, x):
        dropout_mask = torch.ones_like(x[:, :, 0:1, :])
        dropout_mask = self.dropout(dropout_mask)
        return dropout_mask * x


class Transition(nn.Module):

    def __init__(self, d, n=4):
        super(Transition, self).__init__()
        self.norm = LayerNorm(d)
        self.linear1 = Linear(d, n * d, initializer='relu')
        self.linear2 = Linear(n * d, d, initializer='zeros')

    def forward(self, src):
        x = self.norm(src)
        x = self.linear2(F.relu(self.linear1(x)))
        return src + x


class OutProductMean(nn.Module):

    def __init__(self, n_feat=64, n_feat_out=128, n_feat_proj=32):
        super(OutProductMean, self).__init__()

        self.layernormM = LayerNorm(n_feat)
        self.linear_a = Linear(n_feat, n_feat_proj)
        self.linear_b = Linear(n_feat, n_feat_proj)

        self.o_linear = Linear(n_feat_proj * n_feat_proj,
                               n_feat_out,
                               initializer='zero',
                               use_bias=True)

    def forward(self, M, M_mask, Z_raw):
        Z = torch.empty_like(Z_raw)
        M = self.layernormM(M)
        left_act = self.linear_a(M)
        right_act = self.linear_b(M)

        right_act_all = gather(right_act, dim=2)

        M_mask = M_mask.unsqueeze(-1)
        M_mask_col = scatter(M_mask, dim=2)
        left_act = M_mask_col * left_act
        right_act_all = M_mask * right_act_all

        norm = torch.einsum('...ab,...ad->...bd',
                            M_mask_col.squeeze(-1).squeeze(0),
                            M_mask.squeeze(-1).squeeze(0)).unsqueeze(-1).unsqueeze(0)

        para_dim = left_act.shape[2]
        chunk_size = CHUNK_SIZE
        if CHUNK_SIZE == None:
            chunk_size = para_dim

        out = []
        for ax in range(0, para_dim, chunk_size):
            left_act_part = left_act[:, :, ax:ax + chunk_size, :]

            # O = torch.einsum('sid,sje->ijde', left_act_part.squeeze(0), right_act_all.squeeze(0))
            # O = rearrange(O, 'i j d e -> i j (d e)')

            left_shape = left_act_part.shape
            right_shape = right_act_all.shape
            left_act_part = left_act_part.reshape(left_shape[0], left_shape[1],
                                                  left_shape[2] * left_shape[3])
            right_act_all = right_act_all.reshape(right_shape[0], right_shape[1],
                                                  right_shape[2] * right_shape[3])
            # O = torch.einsum('...ab,...ad->...bd', left_act_part.squeeze(0), right_act_all.squeeze(0))
            O = torch.matmul(left_act_part.squeeze(0).transpose(1, 0), right_act_all.squeeze(0))
            O = O.reshape(left_shape[2], left_shape[3], right_shape[2],
                          right_shape[3]).transpose(-2, -3)
            O = O.reshape(O.shape[0], O.shape[1], O.shape[2] * O.shape[3])
            O = O.unsqueeze(0)

            out.append(self.o_linear(O))

        Z = torch.cat(out, dim=1)

        Z /= (1e-3 + norm)

        return Z + Z_raw
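The chunked matmul in OutProductMean.forward replaces the commented-out outer-product einsum. A standalone sketch (illustrative shapes, plain PyTorch only) showing that the reshape/matmul/transpose sequence reproduces that einsum:

# Standalone sketch: the reshape + matmul used in OutProductMean.forward
# matches the commented-out outer-product einsum (shapes are illustrative).
import torch

s, i, j, d = 5, 3, 4, 2              # sequences, left residues, right residues, proj dim
left = torch.randn(1, s, i, d)        # like left_act_part
right = torch.randn(1, s, j, d)       # like right_act_all

# Reference: outer product over the projection dims, contracted over s.
ref = torch.einsum('sid,sje->ijde', left.squeeze(0), right.squeeze(0))
ref = ref.reshape(i, j, d * d)

# Matmul path: flatten (residue, channel), contract over s, unflatten, swap.
lf = left.reshape(1, s, i * d).squeeze(0)         # [s, i*d]
rf = right.reshape(1, s, j * d).squeeze(0)        # [s, j*d]
out = torch.matmul(lf.transpose(1, 0), rf)        # [i*d, j*d]
out = out.reshape(i, d, j, d).transpose(-2, -3)   # [i, j, d, d]
out = out.reshape(i, j, d * d)

print(torch.allclose(ref, out, atol=1e-5))        # expected: True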
class Linear(nn.Linear):
    """
    A Linear layer with built-in nonstandard initializations. Called just
    like torch.nn.Linear.
    Implements the initializers in 1.11.4, plus some additional ones found
    in the code.
    """

    def __init__(
        self,
        feature_in: int,
        feature_out: int,
        initializer: str = 'linear',
        use_bias: bool = True,
        bias_init: float = 0.,
    ):
        super(Linear, self).__init__(feature_in, feature_out, bias=use_bias)

        self.use_bias = use_bias
        if initializer == 'linear':
            glorot_uniform_af(self.weight, gain=1.0)
        elif initializer == 'relu':
            glorot_uniform_af(self.weight, gain=2.0)
        elif initializer == 'zeros':
            nn.init.zeros_(self.weight)
        if self.use_bias:
            with torch.no_grad():
                self.bias.fill_(bias_init)
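Note that the switch above only recognizes 'linear', 'relu' and 'zeros'; the 'zero' string passed by several callers in this file falls through to PyTorch's default initialization. A small sketch of how the recognized initializers behave (assumes this commit's package is importable):

# Sketch of the initializer switch; values printed are what the code above implies.
import torch
from fastfold.habana.fastnn.ops import Linear

lin_relu = Linear(8, 16, initializer='relu')                  # glorot-uniform weights, gain 2.0
lin_gate = Linear(8, 16, initializer='zeros', bias_init=1.)   # zero weights, bias filled with 1.0
lin_nobias = Linear(8, 16, initializer='linear', use_bias=False)

print(torch.count_nonzero(lin_gate.weight))   # expected: tensor(0)
print(lin_gate.bias.unique())                 # expected: tensor([1.])
print(lin_nobias.bias)                        # expected: None (bias disabled)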
class SelfAttention(nn.Module):
    """
    Multi-Head SelfAttention dealing with [batch_size1, batch_size2, len, dim] tensors
    """

    def __init__(self, qkv_dim, c, n_head, out_dim, gating=True, last_bias_fuse=False):
        super(SelfAttention, self).__init__()
        self.qkv_dim = qkv_dim
        self.c = c
        self.n_head = n_head
        self.out_dim = out_dim
        self.gating = gating
        self.last_bias_fuse = last_bias_fuse

        self.scaling = self.c**(-0.5)

        self.to_qkv = Linear(qkv_dim, 3 * n_head * c, initializer='linear', use_bias=False)
        # self.to_q = Linear(qkv_dim, n_head * c, initializer='linear', use_bias=False)
        # self.to_k = Linear(qkv_dim, n_head * c, initializer='linear', use_bias=False)
        # self.to_v = Linear(qkv_dim, n_head * c, initializer='linear', use_bias=False)

        if gating:
            self.gating_bias = nn.parameter.Parameter(data=torch.ones((n_head * c,)))
            self.gating_linear = Linear(qkv_dim, n_head * c, initializer='zero', use_bias=False)

        self.o_linear = Linear(n_head * c,
                               out_dim,
                               initializer='zero',
                               use_bias=(not last_bias_fuse))

    def forward(self, in_data, mask, nonbatched_bias=None):
        """
        :param in_data: [batch_size1, batch_size2, len_qkv, qkv_dim]
        :param bias: None or [batch_size1, batch_size2, n_head, len_q, len_kv]
        :param nonbatched_bias: None or [batch_size1, n_head, len_q, len_kv]
        """

        para_dim = in_data.shape[1]
        chunk_size = CHUNK_SIZE
        if CHUNK_SIZE == None:
            chunk_size = para_dim

        output = []
        for ax in range(0, para_dim, chunk_size):

            in_data_part = in_data[:, ax:ax + chunk_size, :, :]
            mask_part = mask[:, ax:ax + chunk_size, :]

            qkv = self.to_qkv(in_data_part).chunk(3, dim=-1)
            q, k, v = map(lambda t: rearrange(t, 'b1 b2 n (h d) -> b1 b2 h n d', h=self.n_head), qkv)

            # q = self.to_q(in_data_part)
            # k = self.to_k(in_data_part)
            # v = self.to_v(in_data_part)

            # q, k, v = map(lambda t: rearrange(t, 'b1 b2 n (h d) -> b1 b2 h n d', h=self.n_head),
            #               [q, k, v])

            q = q * self.scaling

            logits = torch.matmul(q, k.transpose(-1, -2))

            # logits += (1e9 * (mask_part - 1))[..., :, None, None, :]
            # if nonbatched_bias is not None:
            #     logits += nonbatched_bias.unsqueeze(1)
            # weights = torch.softmax(logits, dim=-1)

            mask00 = (1e9 * (mask_part - 1))[..., :, None, None, :]
            if nonbatched_bias is not None:
                weights = fused_softmax_bias(logits, mask00, nonbatched_bias.unsqueeze(1), -1)
            else:
                weights = fused_softmax(logits, mask00, -1)

            weighted_avg = torch.matmul(weights, v)
            weighted_avg = rearrange(weighted_avg, 'b1 b2 h n d -> b1 b2 n (h d)')

            if self.gating:
                gate_values = self.gating_linear(in_data_part)
                weighted_avg = bias_sigmod_ele(gate_values, self.gating_bias, weighted_avg)

            output.append(self.o_linear(weighted_avg))

        output = torch.cat(output, dim=1)

        return output
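A shape-level usage sketch for SelfAttention; the dimensions are illustrative, and actually running it requires the Habana build of the fused_softmax / fused_softmax_bias ops imported at the top of this file:

# Hypothetical call pattern (shapes are assumptions, not values from this commit).
import torch
from fastfold.habana.fastnn.ops import SelfAttention

attn = SelfAttention(qkv_dim=64, c=8, n_head=4, out_dim=64, gating=True)

b1, b2, n = 1, 16, 32                      # [batch, rows, residues]
x = torch.randn(b1, b2, n, 64)
mask = torch.ones(b1, b2, n)               # 1 = keep, 0 = mask out
bias = torch.randn(b1, 4, n, n)            # optional nonbatched_bias, one map per head

out = attn(x, mask, nonbatched_bias=bias)  # -> [b1, b2, n, 64]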
class GlobalAttention(nn.Module):
    """
    Multi-Head SelfAttention dealing with [batch_size1, batch_size2, len, dim] tensors
    """

    def __init__(self, qkv_dim, c, n_head, out_dim):
        super(GlobalAttention, self).__init__()
        self.qkv_dim = qkv_dim
        self.c = c
        self.n_head = n_head
        self.out_dim = out_dim

        self.scaling = self.c**(-0.5)

        self.eps = 1e-10
        self.inf = 1e9

        self.to_q = Linear(qkv_dim, c * self.n_head, use_bias=False)
        self.to_kv = Linear(qkv_dim, 2 * c, initializer="linear", use_bias=False)

        self.gating_bias = nn.parameter.Parameter(data=torch.ones((n_head * c,)))
        self.gating_linear = Linear(qkv_dim, n_head * c, initializer="zero", use_bias=False)

        self.o_linear = Linear(n_head * c, out_dim, initializer="zero")

    def forward(self, m, mask):
        para_dim = m.shape[1]
        chunk_size = CHUNK_SIZE
        if CHUNK_SIZE == None:
            chunk_size = para_dim

        output = []
        for ax in range(0, para_dim, chunk_size):
            m_part = m[:, ax:ax + chunk_size, :, :]
            mask_part = mask[:, ax:ax + chunk_size, :]

            q = torch.sum(m_part * mask_part.unsqueeze(-1), dim=-2) / (
                torch.sum(mask_part, dim=-1)[..., None] + self.eps)

            q = self.to_q(q)
            q = q.view(q.shape[:-1] + (self.n_head, -1))

            k, v = self.to_kv(m_part).chunk(2, dim=-1)

            logits = torch.matmul(q, k.transpose(-1, -2))
            # logits += (1e9 * (mask_part - 1))[..., :, None, None, :]
            weights = torch.softmax(logits, dim=-1)

            weighted_avg = torch.matmul(weights, v)
            weighted_avg = rearrange(weighted_avg, "b1 b2 h d -> b1 b2 (h d)")

            gate_values = self.gating_linear(m_part)
            weighted_avg = bias_sigmod_ele(gate_values, self.gating_bias,
                                           weighted_avg.unsqueeze(-2))

            output.append(self.o_linear(weighted_avg))

        m = torch.cat(output, dim=1)

        return m
fastfold/habana/fastnn/triangle.py  (new file, mode 100644)
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch.nn import LayerNorm

from fastfold.habana.distributed import col_to_row, gather, row_to_col, scatter

from .kernel import bias_dropout_add, bias_ele_dropout_residual
from .ops import Linear, SelfAttention, Transition


def permute_final_dims(tensor, inds):
    zero_index = -1 * len(inds)
    first_inds = list(range(len(tensor.shape[:zero_index])))
    return tensor.permute(first_inds + [zero_index + i for i in inds])
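A quick standalone check of permute_final_dims (the helper is copied here so the snippet runs without the habana imports): inds index only the final len(inds) dimensions, and leading batch dimensions are left in place.

# Toy demonstration of permute_final_dims semantics.
import torch

def permute_final_dims(tensor, inds):
    zero_index = -1 * len(inds)
    first_inds = list(range(len(tensor.shape[:zero_index])))
    return tensor.permute(first_inds + [zero_index + i for i in inds])

x = torch.randn(2, 3, 4, 5)                    # [batch, i, j, d]
print(permute_final_dims(x, (2, 0, 1)).shape)  # -> torch.Size([2, 5, 3, 4])
print(permute_final_dims(x, (2, 1, 0)).shape)  # -> torch.Size([2, 5, 4, 3])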
class TriangleMultiplicationOutgoing(nn.Module):

    def __init__(self, d_pair, p_drop, c=128):
        super(TriangleMultiplicationOutgoing, self).__init__()
        self.d_pair = d_pair
        self.c = c

        self.layernorm1 = LayerNorm(d_pair)
        self.left_projection = Linear(d_pair, c)
        self.right_projection = Linear(d_pair, c)
        self.left_gate = Linear(d_pair, c, initializer='zeros', bias_init=1.)
        self.right_gate = Linear(d_pair, c, initializer='zeros', bias_init=1.)

        self.output_gate = Linear(d_pair, d_pair, initializer='zeros', bias_init=1.)
        self.layernorm2 = LayerNorm(c)
        self.output_projection = Linear(d_pair, d_pair, initializer='zeros', use_bias=False)
        self.output_bias = nn.parameter.Parameter(data=torch.zeros((d_pair,)), requires_grad=True)

        self.p_drop = p_drop

    def forward(self, Z_raw, Z_mask):
        Z = self.layernorm1(Z_raw)
        left_proj_act = self.left_projection(Z)
        right_proj_act = self.right_projection(Z)

        left_proj_act = Z_mask.unsqueeze(-1) * left_proj_act
        right_proj_act = Z_mask.unsqueeze(-1) * right_proj_act

        left_proj_act *= torch.sigmoid(self.left_gate(Z))
        right_proj_act *= torch.sigmoid(self.right_gate(Z))

        right_proj_act = gather(right_proj_act.contiguous(), dim=1)

        g = torch.sigmoid(self.output_gate(Z))

        p = torch.matmul(
            permute_final_dims(left_proj_act, (2, 0, 1)),
            permute_final_dims(right_proj_act, (2, 1, 0)),
        )
        ab = permute_final_dims(p, (1, 2, 0))

        # ab = torch.einsum('bikd,bjkd->bijd', left_proj_act, right_proj_act)

        ab = self.output_projection(self.layernorm2(ab))
        dropout_mask = torch.ones_like(Z[:, 0:1, :, :], device=Z.device, dtype=Z.dtype)
        return bias_ele_dropout_residual(ab,
                                         self.output_bias,
                                         g,
                                         dropout_mask,
                                         Z_raw,
                                         prob=self.p_drop,
                                         training=self.training)
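On a single device (no gather), the permute-and-matmul path above is equivalent to the commented-out einsum 'bikd,bjkd->bijd'. A standalone sketch with illustrative shapes:

# Standalone equivalence check for the triangle-multiplication contraction.
import torch

def permute_final_dims(tensor, inds):
    zero_index = -1 * len(inds)
    first_inds = list(range(len(tensor.shape[:zero_index])))
    return tensor.permute(first_inds + [zero_index + i for i in inds])

b, n, c = 1, 6, 3
left = torch.randn(b, n, n, c)     # left_proj_act:  [b, i, k, d]
right = torch.randn(b, n, n, c)    # right_proj_act: [b, j, k, d]

ref = torch.einsum('bikd,bjkd->bijd', left, right)

p = torch.matmul(
    permute_final_dims(left, (2, 0, 1)),    # [b, d, i, k]
    permute_final_dims(right, (2, 1, 0)),   # [b, d, k, j]
)
ab = permute_final_dims(p, (1, 2, 0))       # [b, i, j, d]

print(torch.allclose(ref, ab, atol=1e-5))   # expected: True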
class TriangleMultiplicationIncoming(nn.Module):

    def __init__(self, d_pair, p_drop, c=128):
        super(TriangleMultiplicationIncoming, self).__init__()
        self.d_pair = d_pair
        self.c = c

        self.layernorm1 = LayerNorm(d_pair)
        self.left_projection = Linear(d_pair, c)
        self.right_projection = Linear(d_pair, c)
        self.left_gate = Linear(d_pair, c, initializer='zeros', bias_init=1.)
        self.right_gate = Linear(d_pair, c, initializer='zeros', bias_init=1.)

        self.output_gate = Linear(d_pair, d_pair, initializer='zeros', bias_init=1.)
        self.layernorm2 = LayerNorm(c)
        self.output_projection = Linear(d_pair, d_pair, initializer='zeros', use_bias=False)
        self.output_bias = nn.parameter.Parameter(data=torch.zeros((d_pair,)), requires_grad=True)

        self.p_drop = p_drop

    def forward(self, Z_raw, Z_mask):
        Z = self.layernorm1(Z_raw)
        left_proj_act = self.left_projection(Z)
        right_proj_act = self.right_projection(Z)

        left_proj_act = Z_mask.unsqueeze(-1) * left_proj_act
        right_proj_act = Z_mask.unsqueeze(-1) * right_proj_act

        left_proj_act *= torch.sigmoid(self.left_gate(Z))
        right_proj_act *= torch.sigmoid(self.right_gate(Z))

        left_proj_act = gather(left_proj_act.contiguous(), dim=2)

        g = torch.sigmoid(self.output_gate(Z))

        p = torch.matmul(
            permute_final_dims(left_proj_act, (2, 1, 0)),
            permute_final_dims(right_proj_act, (2, 0, 1)),
        )
        ab = permute_final_dims(p, (1, 2, 0))

        # ab = torch.einsum('bkid,bkjd->bijd', left_proj_act, right_proj_act)

        ab = self.output_projection(self.layernorm2(ab))
        dropout_mask = torch.ones_like(Z[:, 0:1, :, :], device=Z.device, dtype=Z.dtype)
        return bias_ele_dropout_residual(ab,
                                         self.output_bias,
                                         g,
                                         dropout_mask,
                                         Z_raw,
                                         prob=self.p_drop,
                                         training=self.training)


class TriangleAttentionStartingNode(nn.Module):

    def __init__(self, d_pair, p_drop, c=32, n_head=4):
        super(TriangleAttentionStartingNode, self).__init__()
        self.d_pair = d_pair
        self.c = c
        self.n_head = n_head
        self.p_drop = p_drop

        self.layernorm1 = LayerNorm(d_pair)
        _init_weights = torch.nn.init.normal_(torch.zeros([n_head, d_pair]),
                                              std=1.0 / math.sqrt(d_pair))
        self.linear_b_weights = nn.parameter.Parameter(data=_init_weights)
        self.attention = SelfAttention(qkv_dim=d_pair,
                                       c=c,
                                       n_head=n_head,
                                       out_dim=d_pair,
                                       gating=True,
                                       last_bias_fuse=True)

        self.out_bias = nn.parameter.Parameter(data=torch.zeros((d_pair,)), requires_grad=True)

    def forward(self, Z_raw, Z_mask):
        Z = self.layernorm1(Z_raw)
        b = F.linear(Z, self.linear_b_weights)
        b = gather(b, dim=1)
        b = rearrange(b, 'b q k h -> b h q k')

        Z = self.attention(Z, Z_mask, b)

        dropout_mask = torch.ones_like(Z[:, 0:1, :, :], device=Z.device, dtype=Z.dtype)
        return bias_dropout_add(Z,
                                self.out_bias,
                                dropout_mask,
                                Z_raw,
                                prob=self.p_drop,
                                training=self.training)


class TriangleAttentionEndingNode(nn.Module):

    def __init__(self, d_pair, p_drop, c=32, n_head=4):
        super(TriangleAttentionEndingNode, self).__init__()
        self.d_pair = d_pair
        self.c = c
        self.n_head = n_head
        self.p_drop = p_drop

        self.layernorm1 = LayerNorm(d_pair)
        _init_weights = torch.nn.init.normal_(torch.zeros([n_head, d_pair]),
                                              std=1.0 / math.sqrt(d_pair))
        self.linear_b_weights = nn.parameter.Parameter(data=_init_weights)
        self.attention = SelfAttention(qkv_dim=d_pair,
                                       c=c,
                                       n_head=n_head,
                                       out_dim=d_pair,
                                       gating=True,
                                       last_bias_fuse=True)

        self.out_bias = nn.parameter.Parameter(data=torch.zeros((d_pair,)), requires_grad=True)

    def forward(self, Z_raw, Z_mask):
        Z = Z_raw.transpose(-2, -3)
        Z_mask = Z_mask.transpose(-1, -2)

        Z = self.layernorm1(Z)
        b = F.linear(Z, self.linear_b_weights)
        b = gather(b, dim=1)
        b = rearrange(b, 'b q k h -> b h q k')

        Z = self.attention(Z, Z_mask, b)

        Z = Z.transpose(-2, -3)
        dropout_mask = torch.ones_like(Z[:, :, 0:1, :], device=Z.device, dtype=Z.dtype)
        return bias_dropout_add(Z,
                                self.out_bias,
                                dropout_mask,
                                Z_raw,
                                prob=self.p_drop,
                                training=self.training)


class PairStack(nn.Module):

    def __init__(self, d_pair, p_drop=0.25):
        super(PairStack, self).__init__()

        self.TriangleMultiplicationOutgoing = TriangleMultiplicationOutgoing(d_pair, p_drop=p_drop)
        self.TriangleMultiplicationIncoming = TriangleMultiplicationIncoming(d_pair, p_drop=p_drop)
        self.TriangleAttentionStartingNode = TriangleAttentionStartingNode(d_pair, p_drop=p_drop)
        self.TriangleAttentionEndingNode = TriangleAttentionEndingNode(d_pair, p_drop=p_drop)
        self.PairTransition = Transition(d=d_pair)

    def forward(self, pair, pair_mask):
        pair_mask_row = scatter(pair_mask, dim=1)
        pair_mask_col = scatter(pair_mask, dim=2)
        pair = self.TriangleMultiplicationOutgoing(pair, pair_mask_row)
        pair = row_to_col(pair)
        pair = self.TriangleMultiplicationIncoming(pair, pair_mask_col)
        pair = col_to_row(pair)
        pair = self.TriangleAttentionStartingNode(pair, pair_mask_row)
        pair = row_to_col(pair)
        pair = self.TriangleAttentionEndingNode(pair, pair_mask_col)
        pair = self.PairTransition(pair)
        pair = col_to_row(pair)
        return pair
fastfold/habana/inject_habana.py  (new file, mode 100644)
# Copyright 2022 BioMap (Beijing) Intelligence Technology Limited
# Copyright 2022 HPC-AI Technology Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch

from fastfold.habana.fastnn import EvoformerStack, ExtraMSAStack

#from fastfold.model.fastnn.embedders import TemplateEmbedder
#from fastfold.model.fastnn.embedders_multimer import TemplateEmbedderMultimer
#from fastfold.model.fastnn.ops import RecyclingEmbedder, InputEmbedder


def copy_layernorm(model_fast, model_ori):
    model_fast.weight.copy_(model_ori.weight)
    model_fast.bias.copy_(model_ori.bias)


def copy_linear(model_fast, model_ori):
    model_fast.weight.copy_(model_ori.weight)
    if model_fast.use_bias:
        model_fast.bias.copy_(model_ori.bias)


def copy_native_linear(model_fast, model_ori):
    model_fast.weight.copy_(model_ori.weight)
    try:
        model_fast.bias.copy_(model_ori.bias)
    except:
        pass


def copy_kv_linear(model_fast, ori_k, ori_v):
    model_fast.weight.copy_(torch.cat((ori_k.weight, ori_v.weight), dim=0))


def copy_qkv_linear(model_fast, ori_q, ori_k, ori_v):
    model_fast.weight.copy_(torch.cat((ori_q.weight, ori_k.weight, ori_v.weight), dim=0))


def copy_attention(model_fast, model_ori):
    copy_qkv_linear(model_fast.to_qkv, model_ori.linear_q, model_ori.linear_k, model_ori.linear_v)
    copy_linear(model_fast.gating_linear, model_ori.linear_g)
    copy_linear(model_fast.o_linear, model_ori.linear_o)
    try:
        model_fast.gating_bias.copy_(model_ori.linear_g.bias)
    except:
        print("no gating_bias need copy")


def copy_left_right(model_fast, ori_left, ori_right):
    model_fast.weight.copy_(torch.cat((ori_left.weight, ori_right.weight), dim=0))
    model_fast.bias.copy_(torch.cat((ori_left.bias, ori_right.bias), dim=0))


def copy_transition(model_fast, model_ori):
    copy_layernorm(model_fast.norm, model_ori.layer_norm)
    copy_linear(model_fast.linear1, model_ori.linear_1)
    copy_linear(model_fast.linear2, model_ori.linear_2)


def copy_triangle(model_fast, model_ori):
    copy_layernorm(model_fast.layernorm1, model_ori.layer_norm_in)
    copy_layernorm(model_fast.layernorm2, model_ori.layer_norm_out)
    copy_linear(model_fast.output_gate, model_ori.linear_g)
    copy_linear(model_fast.output_projection, model_ori.linear_z)
    model_fast.output_bias.copy_(model_ori.linear_z.bias)

    copy_linear(model_fast.left_projection, model_ori.linear_a_p)
    copy_linear(model_fast.right_projection, model_ori.linear_b_p)
    copy_linear(model_fast.left_gate, model_ori.linear_a_g)
    copy_linear(model_fast.right_gate, model_ori.linear_b_g)


def copy_triangle_att(model_fast, model_ori):
    copy_layernorm(model_fast.layernorm1, model_ori.layer_norm)
    model_fast.linear_b_weights = model_ori.linear.weight
    copy_attention(model_fast.attention, model_ori.mha)
    model_fast.out_bias.copy_(model_ori.mha.linear_o.bias)


def copy_native_att(model_fast, model_ori):
    copy_native_linear(model_fast.linear_q, model_ori.linear_q)
    copy_native_linear(model_fast.linear_k, model_ori.linear_k)
    copy_native_linear(model_fast.linear_v, model_ori.linear_v)
    copy_native_linear(model_fast.linear_o, model_ori.linear_o)
    if model_ori.gating:
        copy_native_linear(model_fast.linear_g, model_ori.linear_g)


def copy_evoformer_para(block_fast, block_ori):
    # msa_stack
    # MSARowAttentionWithPairBias
    copy_layernorm(block_fast.msa.MSARowAttentionWithPairBias.layernormM,
                   block_ori.msa_att_row.layer_norm_m)
    copy_layernorm(block_fast.msa.MSARowAttentionWithPairBias.layernormZ,
                   block_ori.msa_att_row.layer_norm_z)
    copy_attention(block_fast.msa.MSARowAttentionWithPairBias.attention,
                   block_ori.msa_att_row.mha)
    block_fast.msa.MSARowAttentionWithPairBias.linear_b_weights.copy_(
        block_ori.msa_att_row.linear_z.weight)
    block_fast.msa.MSARowAttentionWithPairBias.out_bias.copy_(
        block_ori.msa_att_row.mha.linear_o.bias)

    # MSAColumnAttention
    copy_layernorm(block_fast.msa.MSAColumnAttention.layernormM,
                   block_ori.msa_att_col._msa_att.layer_norm_m)
    copy_attention(block_fast.msa.MSAColumnAttention.attention,
                   block_ori.msa_att_col._msa_att.mha)

    # MSATransition
    copy_transition(block_fast.msa.MSATransition, block_ori.core.msa_transition)

    # communication
    copy_layernorm(block_fast.communication.layernormM,
                   block_ori.core.outer_product_mean.layer_norm)
    copy_linear(block_fast.communication.linear_a, block_ori.core.outer_product_mean.linear_1)
    copy_linear(block_fast.communication.linear_b, block_ori.core.outer_product_mean.linear_2)
    copy_linear(block_fast.communication.o_linear, block_ori.core.outer_product_mean.linear_out)

    # pair_stack
    # TriangleMultiplicationOutgoing
    copy_triangle(block_fast.pair.TriangleMultiplicationOutgoing, block_ori.core.tri_mul_out)
    # TriangleMultiplicationIncoming
    copy_triangle(block_fast.pair.TriangleMultiplicationIncoming, block_ori.core.tri_mul_in)
    # TriangleAttentionStartingNode
    copy_triangle_att(block_fast.pair.TriangleAttentionStartingNode, block_ori.core.tri_att_start)
    copy_triangle_att(block_fast.pair.TriangleAttentionEndingNode, block_ori.core.tri_att_end)
    copy_transition(block_fast.pair.PairTransition, block_ori.core.pair_transition)


def copy_global_attention(model_fast, model_ori):
    copy_linear(model_fast.to_q, model_ori.linear_q)
    copy_kv_linear(model_fast.to_kv, model_ori.linear_k, model_ori.linear_v)
    copy_linear(model_fast.gating_linear, model_ori.linear_g)
    copy_linear(model_fast.o_linear, model_ori.linear_o)
    try:
        model_fast.gating_bias.copy_(model_ori.linear_g.bias)
    except:
        print("no gating_bias need copy")


def copy_extra_msa_para(block_fast, block_ori):
    # msa_stack
    # MSARowAttentionWithPairBias
    copy_layernorm(
        block_fast.msa_stack.MSARowAttentionWithPairBias.layernormM,
        block_ori.msa_att_row.layer_norm_m,
    )
    copy_layernorm(
        block_fast.msa_stack.MSARowAttentionWithPairBias.layernormZ,
        block_ori.msa_att_row.layer_norm_z,
    )
    copy_attention(
        block_fast.msa_stack.MSARowAttentionWithPairBias.attention,
        block_ori.msa_att_row.mha,
    )
    block_fast.msa_stack.MSARowAttentionWithPairBias.linear_b_weights.copy_(
        block_ori.msa_att_row.linear_z.weight)
    block_fast.msa_stack.MSARowAttentionWithPairBias.out_bias.copy_(
        block_ori.msa_att_row.mha.linear_o.bias)

    # MSAColumnAttention
    copy_layernorm(
        block_fast.msa_stack.MSAColumnAttention.layernormM,
        block_ori.msa_att_col.layer_norm_m,
    )
    copy_global_attention(
        block_fast.msa_stack.MSAColumnAttention.global_attention,
        block_ori.msa_att_col.global_attention,
    )

    # MSATransition
    copy_transition(block_fast.msa_stack.MSATransition, block_ori.core.msa_transition)

    # communication
    comm_model = (
        block_ori.core.outer_product_mean
        # if not block_ori.is_multimer else block_ori.outer_product_mean
    )
    copy_layernorm(block_fast.communication.layernormM, comm_model.layer_norm)
    copy_linear(block_fast.communication.linear_a, comm_model.linear_1)
    copy_linear(block_fast.communication.linear_b, comm_model.linear_2)
    copy_linear(block_fast.communication.o_linear, comm_model.linear_out)

    # pair_stack
    # TriangleMultiplicationOutgoing
    copy_triangle(block_fast.pair_stack.TriangleMultiplicationOutgoing, block_ori.core.tri_mul_out)
    # TriangleMultiplicationIncoming
    copy_triangle(block_fast.pair_stack.TriangleMultiplicationIncoming, block_ori.core.tri_mul_in)
    # TriangleAttentionStartingNode
    copy_triangle_att(
        block_fast.pair_stack.TriangleAttentionStartingNode,
        block_ori.core.tri_att_start,
    )
    copy_triangle_att(block_fast.pair_stack.TriangleAttentionEndingNode, block_ori.core.tri_att_end)
    copy_transition(block_fast.pair_stack.PairTransition, block_ori.core.pair_transition)


def copy_template_pair_stack_para(block_fast, block_ori):
    # TriangleMultiplicationOutgoing
    copy_triangle(block_fast.TriangleMultiplicationOutgoing, block_ori.tri_mul_out)
    # TriangleMultiplicationIncoming
    copy_triangle(block_fast.TriangleMultiplicationIncoming, block_ori.tri_mul_in)
    # TriangleAttentionStartingNode
    copy_triangle_att(block_fast.TriangleAttentionStartingNode, block_ori.tri_att_start)
    copy_triangle_att(block_fast.TriangleAttentionEndingNode, block_ori.tri_att_end)
    copy_transition(block_fast.PairTransition, block_ori.pair_transition)


def copy_template_pair_block_para(fast_module, target_module):
    with torch.no_grad():
        for ori_block, fast_block in zip(target_module.blocks, fast_module.blocks):
            copy_template_pair_stack_para(fast_block, ori_block)
            if ori_block.training == False:
                fast_block.eval()


def copy_template_para(block_fast, block_ori):
    # TemplateAngleEmbedder
    copy_linear(block_fast.template_angle_embedder.linear_1,
                block_ori.template_angle_embedder.linear_1)
    copy_linear(block_fast.template_angle_embedder.linear_2,
                block_ori.template_angle_embedder.linear_2)
    # TemplatePairEmbedder
    copy_linear(block_fast.template_pair_embedder.linear, block_ori.template_pair_embedder.linear)
    # TemplatePairStack
    copy_template_pair_block_para(block_fast.template_pair_stack, block_ori.template_pair_stack)
    copy_layernorm(block_fast.template_pair_stack.layer_norm,
                   block_ori.template_pair_stack.layer_norm)
    # TemplatePointwiseAttention
    copy_native_att(block_fast.template_pointwise_att.mha, block_ori.template_pointwise_att.mha)


def copy_template_multimer_para(block_fast, block_ori):
    # TemplatePairEmbedderMultimer
    copy_linear(block_fast.template_pair_embedder.dgram_linear,
                block_ori.template_pair_embedder.dgram_linear)
    copy_linear(block_fast.template_pair_embedder.aatype_linear_1,
                block_ori.template_pair_embedder.aatype_linear_1)
    copy_linear(block_fast.template_pair_embedder.aatype_linear_2,
                block_ori.template_pair_embedder.aatype_linear_2)
    copy_layernorm(block_fast.template_pair_embedder.query_embedding_layer_norm,
                   block_ori.template_pair_embedder.query_embedding_layer_norm)
    copy_linear(block_fast.template_pair_embedder.query_embedding_linear,
                block_ori.template_pair_embedder.query_embedding_linear)
    copy_linear(block_fast.template_pair_embedder.pseudo_beta_mask_linear,
                block_ori.template_pair_embedder.pseudo_beta_mask_linear)
    copy_linear(block_fast.template_pair_embedder.x_linear, block_ori.template_pair_embedder.x_linear)
    copy_linear(block_fast.template_pair_embedder.y_linear, block_ori.template_pair_embedder.y_linear)
    copy_linear(block_fast.template_pair_embedder.z_linear, block_ori.template_pair_embedder.z_linear)
    copy_linear(block_fast.template_pair_embedder.backbone_mask_linear,
                block_ori.template_pair_embedder.backbone_mask_linear)
    # TemplateSingleEmbedderMultimer
    copy_linear(block_fast.template_single_embedder.template_single_embedder,
                block_ori.template_single_embedder.template_single_embedder)
    copy_linear(block_fast.template_single_embedder.template_projector,
                block_ori.template_single_embedder.template_projector)
    # TemplatePairStack
    copy_template_pair_block_para(block_fast.template_pair_stack, block_ori.template_pair_stack)
    copy_layernorm(block_fast.template_pair_stack.layer_norm,
                   block_ori.template_pair_stack.layer_norm)
    # linear_t
    copy_linear(block_fast.linear_t, block_ori.linear_t)


def inject_evoformer(model):
    with torch.no_grad():
        target_module = model.evoformer
        fast_module = EvoformerStack(
            c_m=target_module.blocks[0].msa_att_row.c_in,
            c_z=target_module.blocks[0].msa_att_row.c_z,
            c_s=target_module.linear.out_features,
            no_blocks=len(target_module.blocks),
            blocks_per_ckpt=target_module.blocks_per_ckpt,
            clear_cache_between_blocks=target_module.clear_cache_between_blocks,
            is_multimer=target_module.blocks[0].is_multimer,
        )
        for target_block, fast_block in zip(target_module.blocks, fast_module.blocks):
            copy_evoformer_para(fast_block, target_block)
        if target_module.training == False:
            fast_module.eval()
        copy_linear(fast_module.linear, target_module.linear)

    model.evoformer = fast_module


def inject_extramsa(model):
    with torch.no_grad():
        target_module = model.extra_msa_stack
        fast_module = ExtraMSAStack(
            c_m=target_module.blocks[0].msa_att_row.c_in,
            c_z=target_module.blocks[0].msa_att_row.c_z,
            no_blocks=len(target_module.blocks),
            blocks_per_ckpt=1,
            clear_cache_between_blocks=target_module.clear_cache_between_blocks,
            is_multimer=target_module.blocks[0].is_multimer,
        )
        for target_block, fast_block in zip(target_module.blocks, fast_module.blocks):
            copy_extra_msa_para(fast_block, target_block)
        if target_module.training == False:
            fast_module.eval()

    model.extra_msa_stack = fast_module


def inject_template(model):
    with torch.no_grad():
        if model.evoformer.blocks[0].is_multimer:
            target_module = model.template_embedder
            fast_module = TemplateEmbedderMultimer(config=model.template_embedder.config)
            copy_template_multimer_para(fast_module, target_module)
            if target_module.training == False:
                fast_module.eval()
            model.template_embedder = fast_module
        else:
            target_module = model.template_embedder
            fast_module = TemplateEmbedder(config=model.template_embedder.config)
            copy_template_para(fast_module, target_module)
            if target_module.training == False:
                fast_module.eval()
            model.template_embedder = fast_module


def inject_embedder(model):
    if model.evoformer.blocks[0].is_multimer:
        return

    # recycle embedder
    with torch.no_grad():
        target_module = model.recycling_embedder
        fast_module = RecyclingEmbedder(
            c_m=target_module.c_m,
            c_z=target_module.c_z,
            min_bin=target_module.min_bin,
            max_bin=target_module.max_bin,
            no_bins=target_module.no_bins,
            inf=target_module.inf)
        copy_native_linear(fast_module.linear, target_module.linear)
        copy_layernorm(fast_module.layer_norm_m, target_module.layer_norm_m)
        copy_layernorm(fast_module.layer_norm_z, target_module.layer_norm_z)
        if target_module.training == False:
            fast_module.eval()
    model.recycling_embedder = fast_module

    # input embedder
    with torch.no_grad():
        target_module = model.input_embedder
        fast_module = InputEmbedder(
            tf_dim=target_module.tf_dim,
            msa_dim=target_module.msa_dim,
            c_z=target_module.c_z,
            c_m=target_module.c_m,
            relpos_k=target_module.relpos_k,
        )
        copy_linear(fast_module.linear_tf_z_i, target_module.linear_tf_z_i)
        copy_linear(fast_module.linear_tf_z_j, target_module.linear_tf_z_j)
        copy_linear(fast_module.linear_tf_m, target_module.linear_tf_m)
        copy_linear(fast_module.linear_msa_m, target_module.linear_msa_m)
        copy_linear(fast_module.linear_relpos, target_module.linear_relpos)
        if target_module.training == False:
            fast_module.eval()
    model.input_embedder = fast_module


def inject_habana(model):
    inject_evoformer(model)
    inject_extramsa(model)
    #inject_template(model)
    #inject_embedder(model)
    return model
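A hedged usage sketch for inject_habana: build the reference model, then swap its Evoformer and extra-MSA stacks for the Habana fastnn versions. The surrounding names (model_config, AlphaFold and their import paths) are assumptions about the rest of the FastFold package, not something defined in this file:

# Assumed surrounding API; only inject_habana comes from this file.
import torch

from fastfold.config import model_config
from fastfold.model.hub import AlphaFold
from fastfold.habana.inject_habana import inject_habana

config = model_config("model_1")
model = AlphaFold(config)
model.eval()

with torch.no_grad():
    model = inject_habana(model)  # copies weights into EvoformerStack / ExtraMSAStack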
fastfold/model/__init__.py  (new file, mode 100644, empty)
fastfold/model/fastnn/__init__.py  (new file, mode 100644)
from .msa import MSACore, ExtraMSACore, ExtraMSABlock, ExtraMSAStack
from .ops import OutProductMean, set_chunk_size
from .triangle import PairCore
from .evoformer import Evoformer, EvoformerStack
from .template import TemplatePairBlock, TemplatePairStack

__all__ = [
    'MSACore', 'OutProductMean', 'PairCore', 'set_chunk_size', 'TemplatePairBlock',
    'TemplatePairStack', 'ExtraMSACore', 'ExtraMSABlock', 'ExtraMSAStack', 'Evoformer',
    'EvoformerStack',
]
fastfold/model/fastnn/embedders.py  (new file, mode 100644)
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
from typing import Tuple
from functools import partial

from fastfold.utils.feats import (
    build_template_angle_feat,
    build_template_pair_feat,
)
from fastfold.model.fastnn.ops import Linear
from fastfold.utils.tensor_utils import one_hot
from fastfold.model.fastnn.template import (
    TemplatePairStack,
    TemplatePointwiseAttention,
)
from fastfold.utils.tensor_utils import one_hot, tensor_tree_map, dict_multimap


class InputEmbedder(nn.Module):
    """
    Embeds a subset of the input features.
    Implements Algorithms 3 (InputEmbedder) and 4 (relpos).
    """

    def __init__(
        self,
        tf_dim: int,
        msa_dim: int,
        c_z: int,
        c_m: int,
        relpos_k: int,
        **kwargs,
    ):
        """
        Args:
            tf_dim:
                Final dimension of the target features
            msa_dim:
                Final dimension of the MSA features
            c_z:
                Pair embedding dimension
            c_m:
                MSA embedding dimension
            relpos_k:
                Window size used in relative positional encoding
        """
        super(InputEmbedder, self).__init__()

        self.tf_dim = tf_dim
        self.msa_dim = msa_dim

        self.c_z = c_z
        self.c_m = c_m

        self.linear_tf_z_i = Linear(tf_dim, c_z)
        self.linear_tf_z_j = Linear(tf_dim, c_z)
        self.linear_tf_m = Linear(tf_dim, c_m)
        self.linear_msa_m = Linear(msa_dim, c_m)

        # RPE stuff
        self.relpos_k = relpos_k
        self.no_bins = 2 * relpos_k + 1
        self.linear_relpos = Linear(self.no_bins, c_z)

    def relpos(self, ri: torch.Tensor):
        """
        Computes relative positional encodings
        Implements Algorithm 4.
        Args:
            ri:
                "residue_index" features of shape [*, N]
        """
        d = ri[..., None] - ri[..., None, :]
        boundaries = torch.arange(start=-self.relpos_k, end=self.relpos_k + 1, device=d.device)
        oh = one_hot(d, boundaries).type(ri.dtype)
        return self.linear_relpos(oh)

    def forward(
        self,
        tf: torch.Tensor,
        ri: torch.Tensor,
        msa: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            tf:
                "target_feat" features of shape [*, N_res, tf_dim]
            ri:
                "residue_index" features of shape [*, N_res]
            msa:
                "msa_feat" features of shape [*, N_clust, N_res, msa_dim]
        Returns:
            msa_emb:
                [*, N_clust, N_res, C_m] MSA embedding
            pair_emb:
                [*, N_res, N_res, C_z] pair embedding
        """
        # [*, N_res, c_z]
        tf_emb_i = self.linear_tf_z_i(tf)
        tf_emb_j = self.linear_tf_z_j(tf)

        # [*, N_res, N_res, c_z]
        pair_emb = self.relpos(ri.type(tf_emb_i.dtype))
        pair_emb += tf_emb_i[..., None, :] + tf_emb_j[..., None, :, :]

        # [*, N_clust, N_res, c_m]
        n_clust = msa.shape[-3]
        tf_m = (
            self.linear_tf_m(tf)
            .unsqueeze(-3)
            .expand(((-1,) * len(tf.shape[:-2]) + (n_clust, -1, -1)))
        )
        msa_emb = self.linear_msa_m(msa) + tf_m

        return msa_emb, pair_emb
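relpos clips the signed residue-index difference to [-relpos_k, relpos_k] and one-hot encodes it over 2 * relpos_k + 1 bins before linear_relpos. The sketch below emulates that binning with plain torch ops for relpos_k = 2; it illustrates the intended encoding rather than calling the one_hot helper imported above:

# Numeric sketch of the relpos binning (relpos_k = 2, so 5 bins).
import torch

relpos_k = 2
ri = torch.tensor([0, 1, 5])                     # residue_index
d = ri[..., None] - ri[..., None, :]             # pairwise offsets, [N, N]
bins = d.clamp(-relpos_k, relpos_k) + relpos_k   # nearest boundary in [-k, k] -> index 0..2k
oh = torch.nn.functional.one_hot(bins, 2 * relpos_k + 1).float()

print(d)
# tensor([[ 0, -1, -5],
#         [ 1,  0, -4],
#         [ 5,  4,  0]])
print(bins)
# tensor([[2, 1, 0],
#         [3, 2, 0],
#         [4, 4, 2]])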
class TemplateEmbedder(nn.Module):

    def __init__(self, config):
        super(TemplateEmbedder, self).__init__()

        self.config = config
        self.template_angle_embedder = TemplateAngleEmbedder(
            **config["template_angle_embedder"],
        )
        self.template_pair_embedder = TemplatePairEmbedder(
            **config["template_pair_embedder"],
        )
        self.template_pair_stack = TemplatePairStack(
            **config["template_pair_stack"],
        )
        self.template_pointwise_att = TemplatePointwiseAttention(
            **config["template_pointwise_attention"],
        )

    def forward(self, batch, z, pair_mask, templ_dim, chunk_size, _mask_trans=True, inplace=False):
        # Embed the templates one at a time (with a poor man's vmap)
        template_embeds = []
        n_templ = batch["template_aatype"].shape[templ_dim]
        if isinstance(chunk_size, int) and 1 <= chunk_size <= 4:
            t = torch.empty((n_templ, z.shape[0], z.shape[1], 64), dtype=z.dtype, device='cpu')
        else:
            t = torch.empty((n_templ, z.shape[0], z.shape[1], 64), dtype=z.dtype, device=z.device)
        for i in range(n_templ):
            idx = batch["template_aatype"].new_tensor(i)
            single_template_feats = tensor_tree_map(
                lambda t: torch.index_select(t, templ_dim, idx),
                batch,
            )

            single_template_embeds = {}
            if self.config.embed_angles:
                template_angle_feat = build_template_angle_feat(single_template_feats,)

                # [*, S_t, N, C_m]
                a = self.template_angle_embedder(template_angle_feat)

                single_template_embeds["angle"] = a

            # [*, S_t, N, N, C_t]
            tt = build_template_pair_feat(
                single_template_feats,
                use_unit_vector=self.config.use_unit_vector,
                inf=self.config.inf,
                chunk=chunk_size,
                eps=self.config.eps,
                **self.config.distogram,
            ).to(z.dtype).to(z.device)
            tt = self.template_pair_embedder(tt)

            # single_template_embeds.update({"pair": t})

            template_embeds.append(single_template_embeds)

            # [*, S_t, N, N, C_z]
            if inplace:
                tt = [tt]
                t[i] = self.template_pair_stack.inplace(
                    tt,
                    pair_mask.unsqueeze(-3).to(dtype=z.dtype),
                    chunk_size=chunk_size,
                    _mask_trans=_mask_trans,
                )[0].to(t.device)
            else:
                t[i] = self.template_pair_stack(
                    tt,
                    pair_mask.unsqueeze(-3).to(dtype=z.dtype),
                    chunk_size=chunk_size,
                    _mask_trans=_mask_trans,
                ).to(t.device)
            del tt, single_template_feats

        template_embeds = dict_multimap(
            partial(torch.cat, dim=templ_dim),
            template_embeds,
        )

        # [*, N, N, C_z]
        if inplace:
            z = self.template_pointwise_att.inplace(
                t,
                z,
                template_mask=batch["template_mask"].to(dtype=z.dtype),
                chunk_size=chunk_size * 256 if chunk_size is not None else chunk_size,
            )
        else:
            z = self.template_pointwise_att(
                t,
                z,
                template_mask=batch["template_mask"].to(dtype=z.dtype),
                chunk_size=chunk_size * 256 if chunk_size is not None else chunk_size,
            )

        ret = {}

        ret["template_pair_embedding"] = z

        if self.config.embed_angles:
            ret["template_single_embedding"] = template_embeds["angle"]

        return ret


class TemplateAngleEmbedder(nn.Module):
    """
    Embeds the "template_angle_feat" feature.
    Implements Algorithm 2, line 7.
    """

    def __init__(
        self,
        c_in: int,
        c_out: int,
        **kwargs,
    ):
        """
        Args:
            c_in:
                Final dimension of "template_angle_feat"
            c_out:
                Output channel dimension
        """
        super(TemplateAngleEmbedder, self).__init__()

        self.c_out = c_out
        self.c_in = c_in

        self.linear_1 = Linear(self.c_in, self.c_out, initializer="relu")
        self.relu = nn.ReLU()
        self.linear_2 = Linear(self.c_out, self.c_out, initializer="relu")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: [*, N_templ, N_res, c_in] "template_angle_feat" features
        Returns:
            x: [*, N_templ, N_res, C_out] embedding
        """
        x = self.linear_1(x)
        x = self.relu(x)
        x = self.linear_2(x)
        return x


class TemplatePairEmbedder(nn.Module):
    """
    Embeds "template_pair_feat" features.
    Implements Algorithm 2, line 9.
    """

    def __init__(
        self,
        c_in: int,
        c_out: int,
        **kwargs,
    ):
        """
        Args:
            c_in:
            c_out:
                Output channel dimension
        """
        super(TemplatePairEmbedder, self).__init__()

        self.c_in = c_in
        self.c_out = c_out

        # Despite there being no relu nearby, the source uses that initializer
        self.linear = Linear(self.c_in, self.c_out, initializer="relu")

    def forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        """
        Args:
            x:
                [*, C_in] input tensor
        Returns:
            [*, C_out] output tensor
        """
        x = self.linear(x)
        return x


class ExtraMSAEmbedder(nn.Module):
    """
    Embeds unclustered MSA sequences.
    Implements Algorithm 2, line 15
    """

    def __init__(
        self,
        c_in: int,
        c_out: int,
        **kwargs,
    ):
        """
        Args:
            c_in:
                Input channel dimension
            c_out:
                Output channel dimension
        """
        super(ExtraMSAEmbedder, self).__init__()

        self.c_in = c_in
        self.c_out = c_out

        self.linear = Linear(self.c_in, self.c_out)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x:
                [*, N_extra_seq, N_res, C_in] "extra_msa_feat" features
        Returns:
            [*, N_extra_seq, N_res, C_out] embedding
        """
        x = self.linear(x)
        return x
fastfold/model/fastnn/embedders_multimer.py  (new file, mode 100644)
from functools import partial
import torch
import torch.nn as nn
from typing import Tuple, Dict

from fastfold.utils import all_atom_multimer
from fastfold.utils.feats import dgram_from_positions
from fastfold.model.fastnn.ops import Linear, LayerNorm
from fastfold.model.fastnn.template import (
    TemplatePairStack,
    TemplatePointwiseAttention,
)
from fastfold.utils import geometry
from fastfold.utils.tensor_utils import one_hot, tensor_tree_map, dict_multimap


class InputEmbedderMultimer(nn.Module):
    """
    Embeds a subset of the input features.
    Implements Algorithms 3 (InputEmbedder) and 4 (relpos).
    """

    def __init__(
        self,
        tf_dim: int,
        msa_dim: int,
        c_z: int,
        c_m: int,
        max_relative_idx: int,
        use_chain_relative: bool,
        max_relative_chain: int,
        **kwargs,
    ):
        """
        Args:
            tf_dim:
                Final dimension of the target features
            msa_dim:
                Final dimension of the MSA features
            c_z:
                Pair embedding dimension
            c_m:
                MSA embedding dimension
            relpos_k:
                Window size used in relative positional encoding
        """
        super(InputEmbedderMultimer, self).__init__()

        self.tf_dim = tf_dim
        self.msa_dim = msa_dim

        self.c_z = c_z
        self.c_m = c_m

        self.linear_tf_z_i = Linear(tf_dim, c_z)
        self.linear_tf_z_j = Linear(tf_dim, c_z)
        self.linear_tf_m = Linear(tf_dim, c_m)
        self.linear_msa_m = Linear(msa_dim, c_m)

        # RPE stuff
        self.max_relative_idx = max_relative_idx
        self.use_chain_relative = use_chain_relative
        self.max_relative_chain = max_relative_chain
        if self.use_chain_relative:
            self.no_bins = 2 * max_relative_idx + 2 + 1 + 2 * max_relative_chain + 2
        else:
            self.no_bins = 2 * max_relative_idx + 1
        self.linear_relpos = Linear(self.no_bins, c_z)

    def relpos(self, batch: Dict[str, torch.Tensor]):
        pos = batch["residue_index"]
        asym_id = batch["asym_id"]
        asym_id_same = asym_id[..., None] == asym_id[..., None, :]
        offset = pos[..., None] - pos[..., None, :]

        clipped_offset = torch.clamp(offset + self.max_relative_idx, 0, 2 * self.max_relative_idx)

        rel_feats = []
        if self.use_chain_relative:
            final_offset = torch.where(
                asym_id_same,
                clipped_offset,
                (2 * self.max_relative_idx + 1) * torch.ones_like(clipped_offset),
            )

            rel_pos = torch.nn.functional.one_hot(
                final_offset,
                2 * self.max_relative_idx + 2,
            )

            rel_feats.append(rel_pos)

            entity_id = batch["entity_id"]
            entity_id_same = entity_id[..., None] == entity_id[..., None, :]
            rel_feats.append(entity_id_same[..., None])

            sym_id = batch["sym_id"]
            rel_sym_id = sym_id[..., None] - sym_id[..., None, :]

            max_rel_chain = self.max_relative_chain

            clipped_rel_chain = torch.clamp(
                rel_sym_id + max_rel_chain,
                0,
                2 * max_rel_chain,
            )

            final_rel_chain = torch.where(
                entity_id_same,
                clipped_rel_chain,
                (2 * max_rel_chain + 1) * torch.ones_like(clipped_rel_chain),
            )

            rel_chain = torch.nn.functional.one_hot(
                final_rel_chain.long(),
                2 * max_rel_chain + 2,
            )

            rel_feats.append(rel_chain)
        else:
            rel_pos = torch.nn.functional.one_hot(
                clipped_offset,
                2 * self.max_relative_idx + 1,
            )
            rel_feats.append(rel_pos)

        rel_feat = torch.cat(rel_feats, dim=-1).to(self.linear_relpos.weight.dtype)

        return self.linear_relpos(rel_feat)

    def forward(self, batch: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
        tf = batch["target_feat"]
        msa = batch["msa_feat"]

        # [*, N_res, c_z]
        tf_emb_i = self.linear_tf_z_i(tf)
        tf_emb_j = self.linear_tf_z_j(tf)

        # [*, N_res, N_res, c_z]
        pair_emb = tf_emb_i[..., None, :] + tf_emb_j[..., None, :, :]
        pair_emb = pair_emb + self.relpos(batch)

        # [*, N_clust, N_res, c_m]
        n_clust = msa.shape[-3]
        tf_m = (
            self.linear_tf_m(tf)
            .unsqueeze(-3)
            .expand(((-1,) * len(tf.shape[:-2]) + (n_clust, -1, -1)))
        )
        msa_emb = self.linear_msa_m(msa) + tf_m

        return msa_emb, pair_emb
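With use_chain_relative, the relative-position feature concatenates the clipped residue offset (2 * max_relative_idx + 2 classes), a same-entity flag, and the clipped chain offset (2 * max_relative_chain + 2 classes), which is where the no_bins expression in __init__ comes from. Worked arithmetic for the usual multimer settings (the concrete values 32 and 2 are assumed defaults, not read from this diff):

# no_bins bookkeeping for InputEmbedderMultimer (default values are assumptions).
max_relative_idx, max_relative_chain = 32, 2

rel_pos_classes = 2 * max_relative_idx + 2      # 66: offsets -32..32 plus "different chain"
entity_flag = 1                                 # same-entity indicator
rel_chain_classes = 2 * max_relative_chain + 2  # 6: chain offsets -2..2 plus "different entity"

no_bins = rel_pos_classes + entity_flag + rel_chain_classes
print(no_bins)                                  # -> 73, the in_features of linear_relpos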
class TemplatePairEmbedderMultimer(nn.Module):

    def __init__(
        self,
        c_z: int,
        c_out: int,
        c_dgram: int,
        c_aatype: int,
    ):
        super().__init__()
        self.dgram_linear = Linear(c_dgram, c_out)
        self.aatype_linear_1 = Linear(c_aatype, c_out)
        self.aatype_linear_2 = Linear(c_aatype, c_out)
        self.query_embedding_layer_norm = LayerNorm(c_z)
        self.query_embedding_linear = Linear(c_z, c_out)
        self.pseudo_beta_mask_linear = Linear(1, c_out)
        self.x_linear = Linear(1, c_out)
        self.y_linear = Linear(1, c_out)
        self.z_linear = Linear(1, c_out)
        self.backbone_mask_linear = Linear(1, c_out)

    def forward(
        self,
        template_dgram: torch.Tensor,
        aatype_one_hot: torch.Tensor,
        query_embedding: torch.Tensor,
        pseudo_beta_mask: torch.Tensor,
        backbone_mask: torch.Tensor,
        multichain_mask_2d: torch.Tensor,
        unit_vector: geometry.Vec3Array,
    ) -> torch.Tensor:
        act = 0.

        pseudo_beta_mask_2d = (pseudo_beta_mask[..., None] * pseudo_beta_mask[..., None, :])
        pseudo_beta_mask_2d *= multichain_mask_2d
        template_dgram *= pseudo_beta_mask_2d[..., None]
        act += self.dgram_linear(template_dgram)
        act += self.pseudo_beta_mask_linear(pseudo_beta_mask_2d[..., None])

        aatype_one_hot = aatype_one_hot.to(template_dgram.dtype)
        act += self.aatype_linear_1(aatype_one_hot[..., None, :, :])
        act += self.aatype_linear_2(aatype_one_hot[..., None, :])

        backbone_mask_2d = (backbone_mask[..., None] * backbone_mask[..., None, :])
        backbone_mask_2d *= multichain_mask_2d
        x, y, z = [coord * backbone_mask_2d for coord in unit_vector]
        act += self.x_linear(x[..., None])
        act += self.y_linear(y[..., None])
        act += self.z_linear(z[..., None])

        act += self.backbone_mask_linear(backbone_mask_2d[..., None])

        query_embedding = self.query_embedding_layer_norm(query_embedding)
        act += self.query_embedding_linear(query_embedding)

        return act


class TemplateSingleEmbedderMultimer(nn.Module):

    def __init__(
        self,
        c_in: int,
        c_m: int,
    ):
        super().__init__()
        self.template_single_embedder = Linear(c_in, c_m)
        self.template_projector = Linear(c_m, c_m)

    def forward(
        self,
        batch,
        atom_pos,
        aatype_one_hot,
    ):
        out = {}

        template_chi_angles, template_chi_mask = (
            all_atom_multimer.compute_chi_angles(
                atom_pos,
                batch["template_all_atom_mask"],
                batch["template_aatype"],
            )
        )

        template_features = torch.cat(
            [
                aatype_one_hot,
                torch.sin(template_chi_angles) * template_chi_mask,
                torch.cos(template_chi_angles) * template_chi_mask,
                template_chi_mask,
            ],
            dim=-1,
        )

        template_mask = template_chi_mask[..., 0]

        template_features = self.template_single_embedder(template_features)
        template_features = torch.nn.functional.relu(template_features)
        template_features = self.template_projector(template_features,)

        out["template_single_embedding"] = (template_features)
        out["template_mask"] = template_mask

        return out


class TemplateEmbedderMultimer(nn.Module):

    def __init__(self, config):
        super(TemplateEmbedderMultimer, self).__init__()

        self.config = config
        self.template_pair_embedder = TemplatePairEmbedderMultimer(
            **config["template_pair_embedder"],
        )
        self.template_single_embedder = TemplateSingleEmbedderMultimer(
            **config["template_single_embedder"],
        )
        self.template_pair_stack = TemplatePairStack(
            **config["template_pair_stack"],
        )

        self.linear_t = Linear(config.c_t, config.c_z)

    def forward(self, batch, z, padding_mask_2d, templ_dim, chunk_size, multichain_mask_2d, inplace):
        template_embeds = []
        n_templ = batch["template_aatype"].shape[templ_dim]
        template_pair_embeddings = torch.zeros((z.shape[0], z.shape[1], 64),
                                               dtype=z.dtype,
                                               device=z.device)
        for i in range(n_templ):
            idx = batch["template_aatype"].new_tensor(i)
            single_template_feats = tensor_tree_map(
                lambda t: torch.index_select(t, templ_dim, idx),
                batch,
            )

            single_template_embeds = {}
            template_positions, pseudo_beta_mask = (
                single_template_feats["template_pseudo_beta"],
                single_template_feats["template_pseudo_beta_mask"],
            )
            template_dgram = dgram_from_positions(
                template_positions,
                inf=self.config.inf,
                **self.config.distogram,
            )
            aatype_one_hot = torch.nn.functional.one_hot(
                single_template_feats["template_aatype"],
                22,
            )

            raw_atom_pos = single_template_feats["template_all_atom_positions"]

            atom_pos = geometry.Vec3Array.from_array(raw_atom_pos)
            rigid, backbone_mask = all_atom_multimer.make_backbone_affine(
                atom_pos,
                single_template_feats["template_all_atom_mask"],
                single_template_feats["template_aatype"],
            )
            points = rigid.translation
            rigid_vec = rigid[..., None].inverse().apply_to_point(points)
            unit_vector = rigid_vec.normalized()
            pair_embedding = self.template_pair_embedder(
                template_dgram,
                aatype_one_hot,
                z,
                pseudo_beta_mask,
                backbone_mask,
                multichain_mask_2d,
                unit_vector,
            )
            if not inplace:
                # [*, S_t, N, N, C_z]
                template_pair_embeddings = template_pair_embeddings + self.template_pair_stack(
                    pair_embedding,
                    padding_mask_2d.unsqueeze(-3).to(dtype=z.dtype),
                    chunk_size=chunk_size,
                    _mask_trans=False,
                ).squeeze(0)
            else:
                # [*, S_t, N, N, C_z]
                template_pair_embeddings += self.template_pair_stack.inplace(
                    [pair_embedding],
                    padding_mask_2d.unsqueeze(-3).to(dtype=z.dtype),
                    chunk_size=chunk_size,
                    _mask_trans=False,
                )[0].squeeze(0)

            single_template_embeds.update(
                self.template_single_embedder(
                    single_template_feats,
                    atom_pos,
                    aatype_one_hot,
                )
            )
            template_embeds.append(single_template_embeds)

        template_embeds = dict_multimap(
            partial(torch.cat, dim=templ_dim),
            template_embeds,
        )

        # [*, N, N, C_z]
        template_pair_embeddings = template_pair_embeddings / n_templ
        template_pair_embeddings = torch.nn.functional.relu(template_pair_embeddings)
        template_pair_embeddings = self.linear_t(template_pair_embeddings)
        template_embeds["template_pair_embedding"] = template_pair_embeddings

        return template_embeds
fastfold/model/fastnn/evoformer.py  (new file, mode 100644)
from typing import Optional, Tuple
from functools import partial

import torch
import torch.nn as nn
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc

from fastfold.model.fastnn import MSACore, OutProductMean, PairCore
from fastfold.model.fastnn.ops import Linear
from fastfold.distributed.comm import gather, scatter, col_to_row
from fastfold.distributed.comm_async import All_to_All_Async, All_to_All_Async_Opp
from fastfold.utils.checkpointing import checkpoint_blocks


class Evoformer(nn.Module):

    def __init__(self, c_m: int, c_z: int, first_block: bool, last_block: bool, is_multimer: bool = False):
        super(Evoformer, self).__init__()

        self.first_block = first_block
        self.last_block = last_block

        self.msa = MSACore(c_m, c_z, p_drop=0.15)
        self.communication = OutProductMean(n_feat=c_m, n_feat_out=c_z, n_feat_proj=32)
        self.pair = PairCore(d_pair=c_z)
        self.is_multimer = is_multimer

    def forward(
        self,
        m: torch.Tensor,
        z: torch.Tensor,
        msa_mask: torch.Tensor,
        pair_mask: torch.Tensor,
        chunk_size: Optional[int] = None,
        _mask_trans: bool = True,
    ) -> Tuple[torch.Tensor, torch.Tensor]:

        dap_size = gpc.get_world_size(ParallelMode.TENSOR)

        seq_length = pair_mask.size(-1)
        padding_size = (int(seq_length / dap_size) + 1) * dap_size - seq_length

        if self.first_block:
            m = m.unsqueeze(0)
            z = z.unsqueeze(0)

            m = torch.nn.functional.pad(m, (0, 0, 0, padding_size))
            z = torch.nn.functional.pad(z, (0, 0, 0, padding_size, 0, padding_size))

            if self.is_multimer:
                m = scatter(m, dim=2)
            else:
                m = scatter(m, dim=1)
            z = scatter(z, dim=1)

        msa_mask = msa_mask.unsqueeze(0)
        pair_mask = pair_mask.unsqueeze(0)
        msa_mask = torch.nn.functional.pad(msa_mask, (0, padding_size))
        pair_mask = torch.nn.functional.pad(pair_mask, (0, padding_size, 0, padding_size))

        if not self.is_multimer:
            m = self.msa(m, z, msa_mask)
            z = self.communication(m, msa_mask, z)
            m, work = All_to_All_Async.apply(m, 1, 2)
            z = self.pair(z, pair_mask)
            m = All_to_All_Async_Opp.apply(m, work, 1, 2)
        else:
            z = self.communication(m, msa_mask, z)
            z_ori = z
            m, work = All_to_All_Async.apply(m, 1, 2)
            z = self.pair(z, pair_mask)
            m = All_to_All_Async_Opp.apply(m, work, 1, 2)
            m = self.msa(m, z_ori, msa_mask)

        if self.last_block:
            m = m.squeeze(0)
            z = z.squeeze(0)

            if self.is_multimer:
                m = gather(m, dim=1)
            else:
                m = gather(m, dim=0)
            z = gather(z, dim=0)

            m = m[:, :-padding_size, :]
            z = z[:-padding_size, :-padding_size, :]

        return m, z

    def inplace(
        self,
        m: torch.Tensor,
        z: torch.Tensor,
        msa_mask: torch.Tensor,
        pair_mask: torch.Tensor,
        chunk_size: Optional[int] = None,
        _mask_trans: bool = True,
    ) -> Tuple[torch.Tensor, torch.Tensor]:

        dap_size = gpc.get_world_size(ParallelMode.TENSOR)

        seq_length = pair_mask.size(-1)
        padding_size = (int(seq_length / dap_size) + 1) * dap_size - seq_length

        if self.first_block:
            m[0] = m[0].unsqueeze(0)
            z[0] = z[0].unsqueeze(0)

            m[0] = torch.nn.functional.pad(m[0], (0, 0, 0, padding_size))
            z[0] = torch.nn.functional.pad(z[0], (0, 0, 0, padding_size, 0, padding_size))

            if self.is_multimer:
                m[0] = scatter(m[0], dim=2)
            else:
                m[0] = scatter(m[0], dim=1)
            z[0] = scatter(z[0], dim=1)

        msa_mask = msa_mask.unsqueeze(0)
        pair_mask = pair_mask.unsqueeze(0)
        msa_mask = torch.nn.functional.pad(msa_mask, (0, padding_size))
        pair_mask = torch.nn.functional.pad(pair_mask, (0, padding_size, 0, padding_size))

        if not self.is_multimer:
            m[0] = self.msa(m[0], z[0], msa_mask)
            z = self.communication.inplace(m[0], msa_mask, z)
            m[0], work = All_to_All_Async.apply(m[0], 1, 2)
            z = self.pair.inplace(z, pair_mask)
            m[0] = All_to_All_Async_Opp.apply(m[0], work, 1, 2)
        else:
            z = self.communication.inplace(m[0], msa_mask, z)
            m[0] = col_to_row(m[0])
            m[0] = self.msa(m[0], z[0], msa_mask)
            z = self.pair.inplace(z, pair_mask)

        if self.last_block:
            m[0] = m[0].squeeze(0)
            z[0] = z[0].squeeze(0)

            if self.is_multimer:
                m[0] = gather(m[0], dim=1)
            else:
                m[0] = gather(m[0], dim=0)
            z[0] = gather(z[0], dim=0)

            m[0] = m[0][:, :-padding_size, :]
            z[0] = z[0][:-padding_size, :-padding_size, :]

        return m, z
class
EvoformerStack
(
nn
.
Module
):
"""
Main Evoformer trunk.
Implements Algorithm 6.
"""
def
__init__
(
self
,
c_m
:
int
,
c_z
:
int
,
c_s
:
int
,
no_blocks
:
int
,
blocks_per_ckpt
:
int
,
clear_cache_between_blocks
:
bool
=
False
,
is_multimer
:
bool
=
False
,
**
kwargs
,
):
"""
Args:
c_m:
MSA channel dimension
c_z:
Pair channel dimension
c_hidden_msa_att:
Hidden dimension in MSA attention
c_hidden_opm:
Hidden dimension in outer product mean module
c_hidden_mul:
Hidden dimension in multiplicative updates
c_hidden_pair_att:
Hidden dimension in triangular attention
c_s:
Channel dimension of the output "single" embedding
no_heads_msa:
Number of heads used for MSA attention
no_heads_pair:
Number of heads used for pair attention
no_blocks:
Number of Evoformer blocks in the stack
transition_n:
Factor by which to multiply c_m to obtain the MSATransition
hidden dimension
msa_dropout:
Dropout rate for MSA activations
pair_dropout:
Dropout used for pair activations
blocks_per_ckpt:
Number of Evoformer blocks in each activation checkpoint
clear_cache_between_blocks:
Whether to clear CUDA's GPU memory cache between blocks of the
stack. Slows down each block but can reduce fragmentation
"""
super
(
EvoformerStack
,
self
).
__init__
()
self
.
blocks_per_ckpt
=
blocks_per_ckpt
self
.
clear_cache_between_blocks
=
clear_cache_between_blocks
self
.
blocks
=
nn
.
ModuleList
()
for
block_id
in
range
(
no_blocks
):
block
=
Evoformer
(
c_m
=
c_m
,
c_z
=
c_z
,
first_block
=
(
block_id
==
0
),
last_block
=
(
block_id
==
no_blocks
-
1
),
is_multimer
=
is_multimer
,
)
self
.
blocks
.
append
(
block
)
self
.
linear
=
Linear
(
c_m
,
c_s
)
def
forward
(
self
,
m
:
torch
.
Tensor
,
z
:
torch
.
Tensor
,
msa_mask
:
torch
.
Tensor
,
pair_mask
:
torch
.
Tensor
,
chunk_size
:
int
,
_mask_trans
:
bool
=
True
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
Optional
[
torch
.
Tensor
]]:
"""
Args:
m:
[*, N_seq, N_res, C_m] MSA embedding
z:
[*, N_res, N_res, C_z] pair embedding
msa_mask:
[*, N_seq, N_res] MSA mask
pair_mask:
[*, N_res, N_res] pair mask
Returns:
m:
[*, N_seq, N_res, C_m] MSA embedding
z:
[*, N_res, N_res, C_z] pair embedding
s:
[*, N_res, C_s] single embedding (or None if extra MSA stack)
"""
blocks
=
[
partial
(
b
,
msa_mask
=
msa_mask
,
pair_mask
=
pair_mask
,
chunk_size
=
chunk_size
,
_mask_trans
=
_mask_trans
,
)
for
b
in
self
.
blocks
]
if
(
self
.
clear_cache_between_blocks
):
def
block_with_cache_clear
(
block
,
*
args
):
torch
.
cuda
.
empty_cache
()
return
block
(
*
args
)
blocks
=
[
partial
(
block_with_cache_clear
,
b
)
for
b
in
blocks
]
m
,
z
=
checkpoint_blocks
(
blocks
,
args
=
(
m
,
z
),
blocks_per_ckpt
=
self
.
blocks_per_ckpt
if
self
.
training
else
None
,
)
s
=
self
.
linear
(
m
[...,
0
,
:,
:])
return
m
,
z
,
s
def
inplace
(
self
,
m
:
torch
.
Tensor
,
z
:
torch
.
Tensor
,
msa_mask
:
torch
.
Tensor
,
pair_mask
:
torch
.
Tensor
,
chunk_size
:
int
,
_mask_trans
:
bool
=
True
,
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
Optional
[
torch
.
Tensor
]]:
"""
Args:
m:
[*, N_seq, N_res, C_m] MSA embedding
z:
[*, N_res, N_res, C_z] pair embedding
msa_mask:
[*, N_seq, N_res] MSA mask
pair_mask:
[*, N_res, N_res] pair mask
Returns:
m:
[*, N_seq, N_res, C_m] MSA embedding
z:
[*, N_res, N_res, C_z] pair embedding
s:
[*, N_res, C_s] single embedding (or None if extra MSA stack)
"""
blocks
=
[
partial
(
b
.
inplace
,
msa_mask
=
msa_mask
,
pair_mask
=
pair_mask
,
chunk_size
=
chunk_size
,
_mask_trans
=
_mask_trans
,
)
for
b
in
self
.
blocks
]
if
(
self
.
clear_cache_between_blocks
):
def
block_with_cache_clear
(
block
,
*
args
):
torch
.
cuda
.
empty_cache
()
return
block
(
*
args
)
blocks
=
[
partial
(
block_with_cache_clear
,
b
)
for
b
in
blocks
]
m
,
z
=
checkpoint_blocks
(
blocks
,
args
=
(
m
,
z
),
blocks_per_ckpt
=
self
.
blocks_per_ckpt
if
self
.
training
else
None
,
)
s
=
self
.
linear
(
m
[
0
][...,
0
,
:,
:])
return
m
,
z
,
s
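As a usage sketch only: the stack expects an MSA embedding m of shape [N_seq, N_res, c_m], a pair embedding z of shape [N_res, N_res, c_z] and the matching masks, and it relies on the tensor-parallel context queried through gpc having been set up already. FastFold's distributed initialization is not part of this file, so the init step mentioned below, the concrete sizes, and the eager/eval usage are assumptions for illustration, not the project's documented entry point.

import torch
from fastfold.model.fastnn.evoformer import EvoformerStack

# Assumption: the ColossalAI/DAP tensor-parallel group that
# gpc.get_world_size(ParallelMode.TENSOR) reads has already been initialized
# by the project's own setup code (not shown on this page).
n_seq, n_res, c_m, c_z, c_s = 64, 128, 256, 128, 384

stack = EvoformerStack(c_m=c_m, c_z=c_z, c_s=c_s, no_blocks=4,
                       blocks_per_ckpt=1).eval().cuda()

m = torch.randn(n_seq, n_res, c_m, device="cuda")
z = torch.randn(n_res, n_res, c_z, device="cuda")
msa_mask = torch.ones(n_seq, n_res, device="cuda")
pair_mask = torch.ones(n_res, n_res, device="cuda")

with torch.no_grad():
    m, z, s = stack(m, z, msa_mask, pair_mask, chunk_size=None)
# m: [N_seq, N_res, c_m], z: [N_res, N_res, c_z], s: [N_res, c_s]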
fastfold/model/fastnn/initializer.py
0 → 100644
View file @
b14e47f4
import math

import numpy as np
import torch.nn as nn


def glorot_uniform_af(x, gain=1.0):
    """
    initialize tensors the same as xavier_initializer in PyTorch, but the dimensions are different:
    In PyTorch:
        [feature_out, feature_in, n_head ...]
    In Jax:
        [... n_head, feature_in, feature_out]
    However, there is a feature in original Alphafold2 code that they use the Jax version initializer
    to initialize tensors like:
        [feature_in, n_head, feature_out]
    In this function, we keep this feature to initialize [feature_in, n_head, ..., feature_out] tensors
    """
    fan_in, fan_out = x.shape[-2:]
    if len(x.shape) > 2:
        receptive_field_size = np.prod(x.shape[:-2])
        fan_in *= receptive_field_size
        fan_out *= receptive_field_size
    std = gain * math.sqrt(2.0 / float(fan_in + fan_out))
    dev = math.sqrt(3.0) * std  # Calculate uniform bounds from standard deviation

    nn.init.uniform_(x, -dev, dev)

    return x
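A quick illustration of the layout convention described in the docstring: the last two dimensions are treated as (fan_in, fan_out), and any leading dimensions are folded into both fans before the uniform bound is computed. The weight shape below is an arbitrary example.

import torch
from fastfold.model.fastnn.initializer import glorot_uniform_af

# A [feature_in, n_head, feature_out]-shaped weight, as used by the
# AlphaFold-style attention projections mentioned in the docstring.
w = torch.empty(64, 8, 32)
glorot_uniform_af(w, gain=1.0)

# Values are drawn uniformly from [-dev, dev] with
# dev = sqrt(3) * gain * sqrt(2 / (fan_in + fan_out)),
# where the leading dimension (64) is folded into both fans.
dev = (3.0 ** 0.5) * (2.0 / float(64 * 8 + 64 * 32)) ** 0.5
print(bool(w.abs().max() <= dev))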
fastfold/model/fastnn/kernel/__init__.py
0 → 100644
View file @
b14e47f4
from .jit.fused_ops import bias_dropout_add, bias_sigmod_ele, bias_ele_dropout_residual
from .layer_norm import FusedLayerNorm as LayerNorm
from .softmax import fused_softmax
from .attention_core import fused_attention_core

__all__ = [
    "bias_dropout_add",
    "bias_sigmod_ele",
    "bias_ele_dropout_residual",
    "LayerNorm",
    "fused_softmax",
    "fused_attention_core",
]
\ No newline at end of file
fastfold/model/fastnn/kernel/attention_core.py
0 → 100644
View file @
b14e47f4
import math
import logging

import torch
from einops import rearrange

_triton_available = True

if _triton_available:
    try:
        from .triton.attention_core import attention_core_triton_kernel_wrapper
    except ImportError:
        logging.warning("Triton is not available, fallback to old kernel.")
        _triton_available = False


def _torch_attention_core(q, k, v, mask, bias):
    scaling = 1. / math.sqrt(q.size(-1))
    q = q * scaling
    logits = torch.matmul(q, k.transpose(-1, -2))
    logits += bias
    logits += (1e20 * (mask - 1))[..., :, None, None, :]
    weights = torch.nn.functional.softmax(logits.float(), -1).to(dtype=q.dtype)
    weighted_avg = torch.matmul(weights, v)
    weighted_avg = rearrange(weighted_avg, 'b1 b2 h n d -> b1 b2 n (h d)')
    return weighted_avg


class FusedAttenionCoreFunc(torch.autograd.Function):

    @staticmethod
    def forward(ctx, q, k, v, mask=None, bias=None):
        if _triton_available:
            o = attention_core_triton_kernel_wrapper(q, k, v, mask, bias)
        else:
            o = _torch_attention_core(q, k, v, mask, bias)

        # ctx.save_for_backward(q, k, v, o, L, m, mask, bias)
        # ctx.BLOCK = BLOCK
        # ctx.grid = grid
        # ctx.sm_scale = sm_scale
        # ctx.BLOCK_DMODEL = Lk

        return o


fused_attention_core = FusedAttenionCoreFunc.apply
\ No newline at end of file
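A minimal shape check for the PyTorch fallback path above. The [b1, b2, h, n, d] layout and the mask broadcast follow from the rearrange pattern and the (mask - 1) term in _torch_attention_core; the concrete sizes and the bias layout are arbitrary choices for the example.

import torch
from fastfold.model.fastnn.kernel.attention_core import _torch_attention_core

b1, b2, h, n, d = 1, 4, 8, 16, 32
q = torch.randn(b1, b2, h, n, d)
k = torch.randn(b1, b2, h, n, d)
v = torch.randn(b1, b2, h, n, d)
mask = torch.ones(b1, b2, n)        # 1 = keep, 0 = masked (pushed to -1e20)
bias = torch.zeros(b1, 1, h, n, n)  # broadcast over the b2 dimension

out = _torch_attention_core(q, k, v, mask, bias)
print(out.shape)  # torch.Size([1, 4, 16, 256]) -- heads folded into the last dim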
fastfold/model/fastnn/kernel/cuda_native/__init__.py
0 → 100644
View file @
b14e47f4
fastfold/model/fastnn/kernel/cuda_native/csrc/compat.h
0 → 100644
View file @
b14e47f4
// modified from https://github.com/NVIDIA/apex/blob/master/csrc/compat.h
#ifndef TORCH_CHECK
#define TORCH_CHECK AT_CHECK
#endif
#ifdef VERSION_GE_1_3
#define DATA_PTR data_ptr
#else
#define DATA_PTR data
#endif
\ No newline at end of file
fastfold/model/fastnn/kernel/cuda_native/csrc/layer_norm_cuda.cpp
0 → 100644
View file @
b14e47f4
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <cassert>
#include <vector>
#include "compat.h"

void compute_n1_n2(at::Tensor input, at::IntArrayRef normalized_shape, int &n1, int &n2) {
    int idiff = input.ndimension() - normalized_shape.size();
    n2 = 1;
    for (int i = 0; i < (int)normalized_shape.size(); ++i) {
        assert(input.sizes()[i + idiff] == normalized_shape[i]);
        n2 *= normalized_shape[i];
    }
    n1 = 1;
    for (int i = 0; i < idiff; ++i) {
        n1 *= input.sizes()[i];
    }
}

void check_args(at::IntArrayRef normalized_shape, at::Tensor gamma, at::Tensor beta) {
    TORCH_CHECK(!gamma.defined() || gamma.sizes().equals(normalized_shape));
    TORCH_CHECK(!beta.defined() || beta.sizes().equals(normalized_shape));
}

void check_args(at::Tensor input, at::IntArrayRef normalized_shape, int &n1, int &n2) {
    int64_t normalized_ndim = normalized_shape.size();

    if (normalized_ndim < 1) {
        std::stringstream ss;
        ss << "Expected normalized_shape to be at least 1-dimensional, i.e., "
           << "containing at least one element, but got normalized_shape=" << normalized_shape;
        throw std::runtime_error(ss.str());
    }

    auto input_shape = input.sizes();
    auto input_ndim = input.dim();

    if (input_ndim < normalized_ndim ||
        !input_shape.slice(input_ndim - normalized_ndim).equals(normalized_shape)) {
        std::stringstream ss;
        ss << "Given normalized_shape=" << normalized_shape
           << ", expected input with shape [*";
        for (auto size : normalized_shape) {
            ss << ", " << size;
        }
        ss << "], but got input of size" << input_shape;
        throw std::runtime_error(ss.str());
    }

    compute_n1_n2(input, normalized_shape, n1, n2);
}

void check_args(at::Tensor input, at::IntArrayRef normalized_shape, at::Tensor gamma,
                at::Tensor beta, int &n1, int &n2) {
    check_args(input, normalized_shape, n1, n2);
    check_args(normalized_shape, gamma, beta);
}

void cuda_layer_norm(at::Tensor *output, at::Tensor *mean, at::Tensor *invvar, at::Tensor *input,
                     int n1, int n2, at::IntArrayRef normalized_shape, at::Tensor *gamma,
                     at::Tensor *beta, double epsilon);

#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
    CHECK_CUDA(x);     \
    CHECK_CONTIGUOUS(x)

std::vector<at::Tensor> layer_norm_affine(at::Tensor input, at::IntArrayRef normalized_shape,
                                          at::Tensor gamma, at::Tensor beta, double epsilon) {
    CHECK_INPUT(input);
    CHECK_INPUT(gamma);
    CHECK_INPUT(beta);
    int n1, n2;
    check_args(input, normalized_shape, gamma, beta, n1, n2);

    const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
    at::Tensor output = at::empty_like(input, gamma.options().dtype(gamma.scalar_type()));
    at::Tensor mean = at::empty({n1}, input.options().dtype(at::ScalarType::Float));
    at::Tensor invvar = at::empty_like(mean);

    cuda_layer_norm(&output, &mean, &invvar, &input, n1, n2, normalized_shape, &gamma, &beta,
                    epsilon);

    return {output, mean, invvar};
}

void cuda_layer_norm_gradient(at::Tensor *dout, at::Tensor *mean, at::Tensor *invvar,
                              at::Tensor *input, int n1, int n2,
                              at::IntArrayRef normalized_shape, at::Tensor *gamma,
                              at::Tensor *beta, double epsilon, at::Tensor *grad_input,
                              at::Tensor *grad_gamma, at::Tensor *grad_beta);

std::vector<at::Tensor> layer_norm_gradient_affine(at::Tensor dout, at::Tensor mean,
                                                   at::Tensor invvar, at::Tensor input,
                                                   at::IntArrayRef normalized_shape,
                                                   at::Tensor gamma, at::Tensor beta,
                                                   double epsilon) {
    CHECK_INPUT(dout);
    CHECK_INPUT(mean);
    CHECK_INPUT(invvar);
    CHECK_INPUT(input);
    CHECK_INPUT(gamma);
    CHECK_INPUT(beta);
    int n1, n2;
    check_args(input, normalized_shape, gamma, beta, n1, n2);

    const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
    at::Tensor grad_input = at::empty_like(input);
    at::Tensor grad_gamma = at::empty_like(gamma);
    at::Tensor grad_beta = at::empty_like(beta);

    cuda_layer_norm_gradient(&dout, &mean, &invvar, &input, n1, n2, normalized_shape, &gamma,
                             &beta, epsilon, &grad_input, &grad_gamma, &grad_beta);

    return {grad_input, grad_gamma, grad_beta};
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward_affine", &layer_norm_affine, "LayerNorm forward (CUDA)");
    m.def("backward_affine", &layer_norm_gradient_affine, "LayerNorm backward (CUDA)");
}
\ No newline at end of file
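The companion Python wrapper (layer_norm.py) is not reproduced on this page, so the following is only a sketch of how a binding like the one above is typically driven from Python. The JIT load() step, the extension name fastfold_layer_norm_cuda, and the source paths are assumptions for illustration; FastFold's own packaging may compile and expose this extension differently.

import torch
from torch.utils.cpp_extension import load

# Hypothetical JIT build of the sources on this page.
layernorm_cuda = load(
    name="fastfold_layer_norm_cuda",
    sources=["layer_norm_cuda.cpp", "layer_norm_cuda_kernel.cu"],
)

x = torch.randn(8, 256, device="cuda", dtype=torch.float16)
gamma = torch.ones(256, device="cuda", dtype=torch.float16)
beta = torch.zeros(256, device="cuda", dtype=torch.float16)

# forward_affine returns (output, per-row mean, per-row inverse stddev);
# mean/invvar stay in fp32 so the backward pass can reuse them.
out, mean, invvar = layernorm_cuda.forward_affine(x, [256], gamma, beta, 1e-5)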
fastfold/model/fastnn/kernel/cuda_native/csrc/layer_norm_cuda_kernel.cu
0 → 100644
View file @
b14e47f4
// part of code modified from https://github.com/NVIDIA/apex
#include <cooperative_groups.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <torch/extension.h>
#include <THC/THCDeviceUtils.cuh>
#include "ATen/ATen.h"
#include "ATen/AccumulateType.h"
#include "ATen/cuda/CUDAContext.h"
#include "compat.h"
#include "type_shim.h"
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)
inline
__device__
void
WelfordOnline
(
float
val
,
float
*
mean
,
float
*
m2
,
float
*
count
)
{
*
count
+=
1
;
float
delta1
=
val
-
*
mean
;
*
mean
+=
delta1
/
(
*
count
);
float
delta2
=
val
-
*
mean
;
*
m2
+=
delta1
*
delta2
;
}
inline
__device__
void
WelfordOnline
(
float
b_mean
,
float
b_m2
,
float
b_count
,
float
*
mean
,
float
*
m2
,
float
*
count
)
{
if
(
b_count
==
0
)
{
return
;
}
float
new_count
=
*
count
+
b_count
;
float
nb_n
=
b_count
/
new_count
;
float
delta
=
b_mean
-
*
mean
;
*
mean
+=
delta
*
nb_n
;
*
m2
+=
b_m2
+
delta
*
delta
*
(
*
count
)
*
nb_n
;
*
count
=
new_count
;
}
__inline__
__device__
void
WelfordWarpAllReduce
(
float
thread_mean
,
float
thread_m2
,
float
thread_count
,
float
*
mean
,
float
*
m2
,
float
*
count
)
{
*
mean
=
thread_mean
;
*
m2
=
thread_m2
;
*
count
=
thread_count
;
for
(
int
mask
=
1
;
mask
<
32
;
mask
*=
2
)
{
float
b_mean
=
__shfl_down_sync
(
0xffffffff
,
*
mean
,
mask
);
float
b_m2
=
__shfl_down_sync
(
0xffffffff
,
*
m2
,
mask
);
float
b_count
=
__shfl_down_sync
(
0xffffffff
,
*
count
,
mask
);
WelfordOnline
(
b_mean
,
b_m2
,
b_count
,
mean
,
m2
,
count
);
}
*
mean
=
__shfl_sync
(
0xffffffff
,
*
mean
,
0
,
32
);
*
m2
=
__shfl_sync
(
0xffffffff
,
*
m2
,
0
,
32
);
*
count
=
__shfl_sync
(
0xffffffff
,
*
count
,
0
,
32
);
}
template
<
typename
T
>
__global__
void
fastfold_layernorm
(
T
*
input
,
T
*
output
,
T
*
gamma
,
T
*
beta
,
float
*
mean
,
float
*
invvar
,
int
rows
,
int
cols
,
double
epsilon
)
{
int
threadidx_x
=
threadIdx
.
x
/
32
;
int
threadidx_y
=
threadIdx
.
x
%
32
;
int
row_offset
=
blockIdx
.
x
*
4
+
threadidx_x
;
int
cols_per_thread
=
(
cols
+
31
)
/
32
;
int
cols_this_thread
=
cols_per_thread
;
int
last_y
=
(
cols
/
cols_per_thread
);
if
(
threadidx_y
==
last_y
)
{
cols_this_thread
=
cols
-
cols_per_thread
*
last_y
;
}
else
if
(
threadidx_y
>
last_y
)
{
cols_this_thread
=
0
;
}
int
lane_id
=
threadidx_y
;
if
(
row_offset
<
rows
)
{
float
buf
[
32
];
float
thread_mean
=
0.
f
;
float
thread_m2
=
0.
f
;
float
thread_count
=
0.
f
;
float
warp_mean
;
float
warp_m2
;
float
warp_count
;
T
*
row_input
=
input
+
row_offset
*
cols
;
T
*
row_output
=
output
+
row_offset
*
cols
;
#pragma unroll
for
(
int
i
=
0
;
i
<
cols_this_thread
;
i
++
)
{
buf
[
i
]
=
static_cast
<
float
>
(
row_input
[
lane_id
*
cols_per_thread
+
i
]);
}
#pragma unroll
for
(
int
i
=
0
;
i
<
cols_this_thread
;
i
++
)
{
WelfordOnline
(
buf
[
i
],
&
thread_mean
,
&
thread_m2
,
&
thread_count
);
}
WelfordWarpAllReduce
(
thread_mean
,
thread_m2
,
thread_count
,
&
warp_mean
,
&
warp_m2
,
&
warp_count
);
float
row_mean
=
warp_mean
;
float
row_variance
=
max
(
warp_m2
/
warp_count
,
0.
f
);
float
row_inv_var
=
rsqrt
(
row_variance
+
epsilon
);
if
(
lane_id
==
0
)
{
mean
[
row_offset
]
=
row_mean
;
invvar
[
row_offset
]
=
row_inv_var
;
}
#pragma unroll
for
(
int
i
=
0
;
i
<
cols_this_thread
;
++
i
)
{
buf
[
i
]
=
(
buf
[
i
]
-
row_mean
)
*
row_inv_var
;
}
#pragma unroll
for
(
int
i
=
0
;
i
<
cols_this_thread
;
++
i
)
{
row_output
[
lane_id
*
cols_per_thread
+
i
]
=
static_cast
<
T
>
(
buf
[
i
])
*
gamma
[
lane_id
*
cols_per_thread
+
i
]
+
beta
[
lane_id
*
cols_per_thread
+
i
];
}
}
}
void
cuda_layer_norm
(
at
::
Tensor
*
output
,
at
::
Tensor
*
mean
,
at
::
Tensor
*
invvar
,
at
::
Tensor
*
input
,
int
rows
,
int
cols
,
at
::
IntArrayRef
normalized_shape
,
at
::
Tensor
*
gamma
,
at
::
Tensor
*
beta
,
double
epsilon
)
{
int
grid
=
(
rows
+
3
)
/
4
;
dim3
block
(
128
);
if
(
output
->
dtype
()
==
torch
::
kFloat32
)
{
fastfold_layernorm
<
float
><<<
grid
,
block
>>>
(
(
float
*
)
input
->
data_ptr
(),
(
float
*
)
output
->
data_ptr
(),
(
float
*
)
gamma
->
data_ptr
(),
(
float
*
)
beta
->
data_ptr
(),
(
float
*
)
mean
->
data_ptr
(),
(
float
*
)
invvar
->
data_ptr
(),
rows
,
cols
,
epsilon
);
}
else
if
(
output
->
dtype
()
==
torch
::
kFloat16
)
{
fastfold_layernorm
<
at
::
Half
><<<
grid
,
block
>>>
(
(
at
::
Half
*
)
input
->
data_ptr
(),
(
at
::
Half
*
)
output
->
data_ptr
(),
(
at
::
Half
*
)
gamma
->
data_ptr
(),
(
at
::
Half
*
)
beta
->
data_ptr
(),
(
float
*
)
mean
->
data_ptr
(),
(
float
*
)
invvar
->
data_ptr
(),
rows
,
cols
,
epsilon
);
}
else
if
(
output
->
dtype
()
==
torch
::
kBFloat16
)
{
fastfold_layernorm
<
at
::
BFloat16
><<<
grid
,
block
>>>
(
(
at
::
BFloat16
*
)
input
->
data_ptr
(),
(
at
::
BFloat16
*
)
output
->
data_ptr
(),
(
at
::
BFloat16
*
)
gamma
->
data_ptr
(),
(
at
::
BFloat16
*
)
beta
->
data_ptr
(),
(
float
*
)
mean
->
data_ptr
(),
(
float
*
)
invvar
->
data_ptr
(),
rows
,
cols
,
epsilon
);
}
}
template
<
typename
T
>
struct
SharedMemory
;
template
<
>
struct
SharedMemory
<
float
>
{
__device__
float
*
getPointer
()
{
extern
__shared__
float
s_float
[];
return
s_float
;
}
};
template
<
typename
T
,
typename
U
,
typename
V
>
__device__
void
cuLoadWriteStridedInputs
(
const
int
i1_block
,
const
int
thr_load_row_off
,
const
int
thr_load_col_off
,
const
int
i2_off
,
const
int
row_stride
,
U
*
warp_buf1
,
U
*
warp_buf2
,
const
T
*
input
,
const
V
*
dout
,
const
int
i1_end
,
const
int
n2
,
const
U
*
__restrict__
mean
,
const
U
*
__restrict__
invvar
)
{
int
i1
=
i1_block
+
thr_load_row_off
;
if
(
i1
<
i1_end
)
{
U
curr_mean
=
mean
[
i1
];
U
curr_invvar
=
invvar
[
i1
];
for
(
int
k
=
0
;
k
<
blockDim
.
y
;
++
k
)
{
int
i2
=
i2_off
+
k
;
int
load_idx
=
i1
*
n2
+
i2
;
int
write_idx
=
thr_load_row_off
*
row_stride
+
thr_load_col_off
+
k
;
if
(
i2
<
n2
)
{
U
curr_input
=
static_cast
<
U
>
(
input
[
load_idx
]);
U
curr_dout
=
static_cast
<
U
>
(
dout
[
load_idx
]);
warp_buf1
[
write_idx
]
=
curr_dout
;
warp_buf2
[
write_idx
]
=
curr_dout
*
(
curr_input
-
curr_mean
)
*
curr_invvar
;
}
else
{
warp_buf1
[
write_idx
]
=
U
(
0
);
warp_buf2
[
write_idx
]
=
U
(
0
);
}
}
}
else
{
for
(
int
k
=
0
;
k
<
blockDim
.
y
;
++
k
)
{
int
write_idx
=
thr_load_row_off
*
row_stride
+
thr_load_col_off
+
k
;
warp_buf1
[
write_idx
]
=
U
(
0
);
warp_buf2
[
write_idx
]
=
U
(
0
);
}
}
}
template
<
typename
T
,
typename
U
,
typename
V
>
__device__
void
cuLoadAddStridedInputs
(
const
int
i1_block
,
const
int
thr_load_row_off
,
const
int
thr_load_col_off
,
const
int
i2_off
,
const
int
row_stride
,
U
*
warp_buf1
,
U
*
warp_buf2
,
const
T
*
input
,
const
V
*
dout
,
const
int
i1_end
,
const
int
n2
,
const
U
*
__restrict__
mean
,
const
U
*
__restrict__
invvar
)
{
int
i1
=
i1_block
+
thr_load_row_off
;
if
(
i1
<
i1_end
)
{
U
curr_mean
=
mean
[
i1
];
U
curr_invvar
=
invvar
[
i1
];
for
(
int
k
=
0
;
k
<
blockDim
.
y
;
++
k
)
{
int
i2
=
i2_off
+
k
;
int
load_idx
=
i1
*
n2
+
i2
;
int
write_idx
=
thr_load_row_off
*
row_stride
+
thr_load_col_off
+
k
;
if
(
i2
<
n2
)
{
U
curr_input
=
static_cast
<
U
>
(
input
[
load_idx
]);
U
curr_dout
=
static_cast
<
U
>
(
dout
[
load_idx
]);
warp_buf1
[
write_idx
]
+=
curr_dout
;
warp_buf2
[
write_idx
]
+=
curr_dout
*
(
curr_input
-
curr_mean
)
*
curr_invvar
;
}
}
}
}
template
<
typename
T
,
typename
U
,
typename
V
>
__global__
void
cuComputePartGradGammaBeta
(
const
V
*
__restrict__
dout
,
const
T
*
__restrict__
input
,
const
int
n1
,
const
int
n2
,
const
U
*
__restrict__
mean
,
const
U
*
__restrict__
invvar
,
U
epsilon
,
U
*
part_grad_gamma
,
U
*
part_grad_beta
)
{
const
int
numsegs_n1
=
(
n1
+
blockDim
.
y
*
blockDim
.
y
-
1
)
/
(
blockDim
.
y
*
blockDim
.
y
);
const
int
segs_per_block
=
(
numsegs_n1
+
gridDim
.
y
-
1
)
/
gridDim
.
y
;
const
int
i1_beg
=
blockIdx
.
y
*
segs_per_block
*
blockDim
.
y
*
blockDim
.
y
;
const
int
i1_beg_plus_one
=
(
blockIdx
.
y
+
1
)
*
segs_per_block
*
blockDim
.
y
*
blockDim
.
y
;
const
int
i1_end
=
i1_beg_plus_one
<
n1
?
i1_beg_plus_one
:
n1
;
const
int
row_stride
=
blockDim
.
x
+
1
;
const
int
thr_load_col_off
=
(
threadIdx
.
x
*
blockDim
.
y
)
&
(
blockDim
.
x
-
1
);
const
int
thr_load_row_off
=
(
threadIdx
.
x
*
blockDim
.
y
)
/
blockDim
.
x
+
threadIdx
.
y
*
blockDim
.
y
;
const
int
i2_off
=
blockIdx
.
x
*
blockDim
.
x
+
thr_load_col_off
;
SharedMemory
<
U
>
shared
;
U
*
buf
=
shared
.
getPointer
();
// buf has at least blockDim.x * blockDim.y * blockDim.y +
// (blockDim.y - 1)*(blockDim.x/blockDim.y) elements
U
*
warp_buf1
=
(
U
*
)
buf
;
U
*
warp_buf2
=
warp_buf1
+
blockDim
.
y
*
blockDim
.
y
*
row_stride
;
// compute partial sums from strided inputs
// do this to increase number of loads in flight
cuLoadWriteStridedInputs
(
i1_beg
,
thr_load_row_off
,
thr_load_col_off
,
i2_off
,
row_stride
,
warp_buf1
,
warp_buf2
,
input
,
dout
,
i1_end
,
n2
,
mean
,
invvar
);
for
(
int
i1_block
=
i1_beg
+
blockDim
.
y
*
blockDim
.
y
;
i1_block
<
i1_end
;
i1_block
+=
blockDim
.
y
*
blockDim
.
y
)
{
cuLoadAddStridedInputs
(
i1_block
,
thr_load_row_off
,
thr_load_col_off
,
i2_off
,
row_stride
,
warp_buf1
,
warp_buf2
,
input
,
dout
,
i1_end
,
n2
,
mean
,
invvar
);
}
__syncthreads
();
// inter-warp reductions
// sum within each warp
U
acc1
=
U
(
0
);
U
acc2
=
U
(
0
);
for
(
int
k
=
0
;
k
<
blockDim
.
y
;
++
k
)
{
int
row1
=
threadIdx
.
y
+
k
*
blockDim
.
y
;
int
idx1
=
row1
*
row_stride
+
threadIdx
.
x
;
acc1
+=
warp_buf1
[
idx1
];
acc2
+=
warp_buf2
[
idx1
];
}
warp_buf1
[
threadIdx
.
y
*
row_stride
+
threadIdx
.
x
]
=
acc1
;
warp_buf2
[
threadIdx
.
y
*
row_stride
+
threadIdx
.
x
]
=
acc2
;
__syncthreads
();
// sum all warps
for
(
int
offset
=
blockDim
.
y
/
2
;
offset
>
1
;
offset
/=
2
)
{
if
(
threadIdx
.
y
<
offset
)
{
int
row1
=
threadIdx
.
y
;
int
row2
=
threadIdx
.
y
+
offset
;
int
idx1
=
row1
*
row_stride
+
threadIdx
.
x
;
int
idx2
=
row2
*
row_stride
+
threadIdx
.
x
;
warp_buf1
[
idx1
]
+=
warp_buf1
[
idx2
];
warp_buf2
[
idx1
]
+=
warp_buf2
[
idx2
];
}
__syncthreads
();
}
int
i2
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
threadIdx
.
y
==
0
&&
i2
<
n2
)
{
int
row1
=
threadIdx
.
y
;
int
row2
=
threadIdx
.
y
+
1
;
int
idx1
=
row1
*
row_stride
+
threadIdx
.
x
;
int
idx2
=
row2
*
row_stride
+
threadIdx
.
x
;
part_grad_beta
[
blockIdx
.
y
*
n2
+
i2
]
=
warp_buf1
[
idx1
]
+
warp_buf1
[
idx2
];
part_grad_gamma
[
blockIdx
.
y
*
n2
+
i2
]
=
warp_buf2
[
idx1
]
+
warp_buf2
[
idx2
];
}
}
template
<
typename
U
,
typename
V
>
__global__
void
cuComputeGradGammaBeta
(
const
U
*
part_grad_gamma
,
const
U
*
part_grad_beta
,
const
int
part_size
,
const
int
n1
,
const
int
n2
,
V
*
grad_gamma
,
V
*
grad_beta
)
{
// sum partial gradients for gamma and beta
SharedMemory
<
U
>
shared
;
U
*
buf
=
shared
.
getPointer
();
int
i2
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
i2
<
n2
)
{
// each warp does sequential reductions until reduced part_size is num_warps
int
num_warp_reductions
=
part_size
/
blockDim
.
y
;
U
sum_gamma
=
U
(
0
);
U
sum_beta
=
U
(
0
);
const
U
*
part_grad_gamma_ptr
=
part_grad_gamma
+
threadIdx
.
y
*
num_warp_reductions
*
n2
+
i2
;
const
U
*
part_grad_beta_ptr
=
part_grad_beta
+
threadIdx
.
y
*
num_warp_reductions
*
n2
+
i2
;
for
(
int
warp_offset
=
0
;
warp_offset
<
num_warp_reductions
;
++
warp_offset
)
{
sum_gamma
+=
part_grad_gamma_ptr
[
warp_offset
*
n2
];
sum_beta
+=
part_grad_beta_ptr
[
warp_offset
*
n2
];
}
// inter-warp reductions
const
int
nbsize3
=
blockDim
.
x
*
blockDim
.
y
/
2
;
for
(
int
offset
=
blockDim
.
y
/
2
;
offset
>=
1
;
offset
/=
2
)
{
// top half write to shared memory
if
(
threadIdx
.
y
>=
offset
&&
threadIdx
.
y
<
2
*
offset
)
{
const
int
write_idx
=
(
threadIdx
.
y
-
offset
)
*
blockDim
.
x
+
threadIdx
.
x
;
buf
[
write_idx
]
=
sum_gamma
;
buf
[
write_idx
+
nbsize3
]
=
sum_beta
;
}
__syncthreads
();
// bottom half sums
if
(
threadIdx
.
y
<
offset
)
{
const
int
read_idx
=
threadIdx
.
y
*
blockDim
.
x
+
threadIdx
.
x
;
sum_gamma
+=
buf
[
read_idx
];
sum_beta
+=
buf
[
read_idx
+
nbsize3
];
}
__syncthreads
();
}
// write out fully summed gradients
if
(
threadIdx
.
y
==
0
)
{
grad_gamma
[
i2
]
=
sum_gamma
;
grad_beta
[
i2
]
=
sum_beta
;
}
}
}
template
<
typename
T
,
typename
U
,
typename
V
>
__global__
void
cuComputeGradInput
(
const
V
*
__restrict__
dout
,
const
T
*
__restrict__
input
,
const
int
n1
,
const
int
n2
,
const
U
*
__restrict__
mean
,
const
U
*
__restrict__
invvar
,
U
epsilon
,
const
V
*
gamma
,
T
*
grad_input
)
{
for
(
auto
i1
=
blockIdx
.
y
;
i1
<
n1
;
i1
+=
gridDim
.
y
)
{
U
sum_loss1
=
U
(
0
);
U
sum_loss2
=
U
(
0
);
const
U
c_mean
=
mean
[
i1
];
const
U
c_invvar
=
invvar
[
i1
];
const
T
*
k_input
=
input
+
i1
*
n2
;
const
V
*
k_dout
=
dout
+
i1
*
n2
;
const
int
numx
=
blockDim
.
x
*
blockDim
.
y
;
const
int
thrx
=
threadIdx
.
x
+
threadIdx
.
y
*
blockDim
.
x
;
if
(
gamma
!=
NULL
)
{
int
l
=
4
*
thrx
;
for
(;
l
+
3
<
n2
;
l
+=
4
*
numx
)
{
for
(
int
k
=
0
;
k
<
4
;
++
k
)
{
const
U
c_h
=
static_cast
<
U
>
(
k_input
[
l
+
k
]);
const
U
c_loss
=
static_cast
<
U
>
(
k_dout
[
l
+
k
]);
sum_loss1
+=
c_loss
*
gamma
[
l
+
k
];
sum_loss2
+=
c_loss
*
gamma
[
l
+
k
]
*
(
c_h
-
c_mean
)
*
c_invvar
;
}
}
for
(;
l
<
n2
;
++
l
)
{
const
U
c_h
=
static_cast
<
U
>
(
k_input
[
l
]);
const
U
c_loss
=
static_cast
<
U
>
(
k_dout
[
l
]);
sum_loss1
+=
c_loss
*
gamma
[
l
];
sum_loss2
+=
c_loss
*
gamma
[
l
]
*
(
c_h
-
c_mean
)
*
c_invvar
;
}
}
else
{
int
l
=
4
*
thrx
;
for
(;
l
+
3
<
n2
;
l
+=
4
*
numx
)
{
for
(
int
k
=
0
;
k
<
4
;
++
k
)
{
const
U
c_h
=
static_cast
<
U
>
(
k_input
[
l
+
k
]);
const
U
c_loss
=
static_cast
<
U
>
(
k_dout
[
l
+
k
]);
sum_loss1
+=
c_loss
;
sum_loss2
+=
c_loss
*
(
c_h
-
c_mean
)
*
c_invvar
;
}
}
for
(;
l
<
n2
;
++
l
)
{
const
U
c_h
=
static_cast
<
U
>
(
k_input
[
l
]);
const
U
c_loss
=
static_cast
<
U
>
(
k_dout
[
l
]);
sum_loss1
+=
c_loss
;
sum_loss2
+=
c_loss
*
(
c_h
-
c_mean
)
*
c_invvar
;
}
}
// intra-warp reductions
for
(
int
mask
=
blockDim
.
x
/
2
;
mask
>
0
;
mask
/=
2
)
{
sum_loss1
+=
WARP_SHFL_XOR
(
sum_loss1
,
mask
);
sum_loss2
+=
WARP_SHFL_XOR
(
sum_loss2
,
mask
);
}
// inter-warp reductions
if
(
blockDim
.
y
>
1
)
{
SharedMemory
<
U
>
shared
;
U
*
buf
=
shared
.
getPointer
();
for
(
int
offset
=
blockDim
.
y
/
2
;
offset
>
0
;
offset
/=
2
)
{
// upper half of warps write to shared
if
(
threadIdx
.
y
>=
offset
&&
threadIdx
.
y
<
2
*
offset
)
{
const
int
wrt_i
=
(
threadIdx
.
y
-
offset
)
*
blockDim
.
x
+
threadIdx
.
x
;
buf
[
2
*
wrt_i
]
=
sum_loss1
;
buf
[
2
*
wrt_i
+
1
]
=
sum_loss2
;
}
__syncthreads
();
// lower half merges
if
(
threadIdx
.
y
<
offset
)
{
const
int
read_i
=
threadIdx
.
y
*
blockDim
.
x
+
threadIdx
.
x
;
sum_loss1
+=
buf
[
2
*
read_i
];
sum_loss2
+=
buf
[
2
*
read_i
+
1
];
}
__syncthreads
();
}
if
(
threadIdx
.
y
==
0
)
{
buf
[
2
*
threadIdx
.
x
]
=
sum_loss1
;
buf
[
2
*
threadIdx
.
x
+
1
]
=
sum_loss2
;
}
__syncthreads
();
if
(
threadIdx
.
y
!=
0
)
{
sum_loss1
=
buf
[
2
*
threadIdx
.
x
];
sum_loss2
=
buf
[
2
*
threadIdx
.
x
+
1
];
}
}
// all threads now have the two sums over l
U
fH
=
(
U
)
n2
;
U
term1
=
(
U
(
1
)
/
fH
)
*
c_invvar
;
T
*
k_grad_input
=
grad_input
+
i1
*
n2
;
if
(
gamma
!=
NULL
)
{
for
(
int
l
=
thrx
;
l
<
n2
;
l
+=
numx
)
{
const
U
c_h
=
static_cast
<
U
>
(
k_input
[
l
]);
const
U
c_loss
=
static_cast
<
U
>
(
k_dout
[
l
]);
U
f_grad_input
=
fH
*
c_loss
*
gamma
[
l
];
f_grad_input
-=
sum_loss1
;
f_grad_input
-=
(
c_h
-
c_mean
)
*
c_invvar
*
sum_loss2
;
f_grad_input
*=
term1
;
k_grad_input
[
l
]
=
static_cast
<
T
>
(
f_grad_input
);
}
}
else
{
for
(
int
l
=
thrx
;
l
<
n2
;
l
+=
numx
)
{
const
U
c_h
=
static_cast
<
U
>
(
k_input
[
l
]);
const
U
c_loss
=
static_cast
<
U
>
(
k_dout
[
l
]);
U
f_grad_input
=
fH
*
c_loss
;
f_grad_input
-=
sum_loss1
;
f_grad_input
-=
(
c_h
-
c_mean
)
*
c_invvar
*
sum_loss2
;
f_grad_input
*=
term1
;
k_grad_input
[
l
]
=
static_cast
<
T
>
(
f_grad_input
);
}
}
}
}
template
<
typename
T
,
typename
U
,
typename
V
>
void
HostLayerNormGradient
(
const
V
*
dout
,
const
U
*
mean
,
const
U
*
invvar
,
at
::
Tensor
*
input
,
int
n1
,
int
n2
,
const
V
*
gamma
,
const
V
*
beta
,
double
epsilon
,
T
*
grad_input
,
V
*
grad_gamma
,
V
*
grad_beta
)
{
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
().
stream
();
if
(
gamma
!=
NULL
&&
beta
!=
NULL
)
{
// compute grad_gamma(j) and grad_beta(j)
const
int
part_size
=
16
;
const
dim3
threads2
(
32
,
4
,
1
);
const
dim3
blocks2
((
n2
+
threads2
.
x
-
1
)
/
threads2
.
x
,
part_size
,
1
);
const
int
nshared2_a
=
2
*
sizeof
(
U
)
*
threads2
.
y
*
threads2
.
y
*
(
threads2
.
x
+
1
);
const
int
nshared2_b
=
threads2
.
x
*
threads2
.
y
*
sizeof
(
U
);
const
int
nshared2
=
nshared2_a
>
nshared2_b
?
nshared2_a
:
nshared2_b
;
at
::
Tensor
part_grad_gamma
=
at
::
empty
({
part_size
,
n2
},
input
->
options
().
dtype
(
at
::
ScalarType
::
Float
));
at
::
Tensor
part_grad_beta
=
at
::
empty_like
(
part_grad_gamma
);
cuComputePartGradGammaBeta
<<<
blocks2
,
threads2
,
nshared2
,
stream
>>>
(
dout
,
input
->
DATA_PTR
<
T
>
(),
n1
,
n2
,
mean
,
invvar
,
U
(
epsilon
),
part_grad_gamma
.
DATA_PTR
<
U
>
(),
part_grad_beta
.
DATA_PTR
<
U
>
());
const
dim3
threads3
(
32
,
8
,
1
);
const
dim3
blocks3
((
n2
+
threads2
.
x
-
1
)
/
threads2
.
x
,
1
,
1
);
const
int
nshared3
=
threads3
.
x
*
threads3
.
y
*
sizeof
(
U
);
cuComputeGradGammaBeta
<<<
blocks3
,
threads3
,
nshared3
,
stream
>>>
(
part_grad_gamma
.
DATA_PTR
<
U
>
(),
part_grad_beta
.
DATA_PTR
<
U
>
(),
part_size
,
n1
,
n2
,
grad_gamma
,
grad_beta
);
}
// compute grad_input
const
uint64_t
maxGridY
=
at
::
cuda
::
getCurrentDeviceProperties
()
->
maxGridSize
[
1
];
const
dim3
blocks1
(
1
,
std
::
min
((
uint64_t
)
n1
,
maxGridY
),
1
);
const
dim3
threads1
(
32
,
4
,
1
);
int
nshared
=
threads1
.
y
>
1
?
threads1
.
y
*
threads1
.
x
*
sizeof
(
U
)
:
0
;
cuComputeGradInput
<<<
blocks1
,
threads1
,
nshared
,
stream
>>>
(
dout
,
input
->
DATA_PTR
<
T
>
(),
n1
,
n2
,
mean
,
invvar
,
U
(
epsilon
),
gamma
,
grad_input
);
}
void
cuda_layer_norm_gradient
(
at
::
Tensor
*
dout
,
at
::
Tensor
*
mean
,
at
::
Tensor
*
invvar
,
at
::
Tensor
*
input
,
int
n1
,
int
n2
,
at
::
IntArrayRef
normalized_shape
,
at
::
Tensor
*
gamma
,
at
::
Tensor
*
beta
,
double
epsilon
,
at
::
Tensor
*
grad_input
,
at
::
Tensor
*
grad_gamma
,
at
::
Tensor
*
grad_beta
)
{
using
namespace
at
;
DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES
(
input
->
scalar_type
(),
gamma
->
scalar_type
(),
"cuda_layer_norm_gradient_kernel"
,
HostLayerNormGradient
(
dout
->
DATA_PTR
<
scalar_t_out
>
(),
mean
->
DATA_PTR
<
float
>
(),
invvar
->
DATA_PTR
<
float
>
(),
input
,
n1
,
n2
,
// TMJ pass NULL argument for gamma, beta, grad_gamma and grad_beta
// if gamma Tensor is NULL on input.
gamma
!=
NULL
?
gamma
->
DATA_PTR
<
scalar_t_out
>
()
:
NULL
,
gamma
!=
NULL
?
beta
->
DATA_PTR
<
scalar_t_out
>
()
:
NULL
,
epsilon
,
grad_input
->
DATA_PTR
<
scalar_t_in
>
(),
gamma
!=
NULL
?
grad_gamma
->
DATA_PTR
<
scalar_t_out
>
()
:
NULL
,
gamma
!=
NULL
?
grad_beta
->
DATA_PTR
<
scalar_t_out
>
()
:
NULL
);)
}
\ No newline at end of file
fastfold/model/fastnn/kernel/cuda_native/csrc/softmax_cuda.cpp
0 → 100644
View file @
b14e47f4
#include <torch/extension.h>

at::Tensor softmax(at::Tensor input, long long rows, long long cols);

at::Tensor softmax_gradient(at::Tensor d_output, at::Tensor output, long long rows,
                            long long cols);

at::Tensor fused_mask_softmax_forward(at::Tensor input, at::Tensor mask, long long rows,
                                      long long cols);

at::Tensor fused_mask_softmax_backward(at::Tensor d_output, at::Tensor input, at::Tensor mask,
                                       long long rows, long long cols);

at::Tensor fused_mask_bias_softmax_forward(at::Tensor input, at::Tensor mask, at::Tensor bias,
                                           long long rows, long long cols);

at::Tensor fused_mask_bias_softmax_backward(at::Tensor d_output, at::Tensor input, at::Tensor mask,
                                            at::Tensor bias, long long rows, long long cols);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("forward", &softmax, "Softmax forward (CUDA)");
    m.def("backward", &softmax_gradient, "Softmax backward (CUDA)");
    m.def("fused_mask_softmax_forward", &fused_mask_softmax_forward, "Softmax forward (CUDA)");
    m.def("fused_mask_softmax_backward", &fused_mask_softmax_backward, "Softmax forward (CUDA)");
    m.def("fused_mask_bias_softmax_forward", &fused_mask_bias_softmax_forward,
          "Softmax forward (CUDA)");
    m.def("fused_mask_bias_softmax_backward", &fused_mask_bias_softmax_backward,
          "Softmax forward (CUDA)");
}
\ No newline at end of file
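Again as a sketch only (the Python wrapper softmax.py is not shown on this page): the kernels treat the input as a 2-D problem, flattening everything except the last dimension into rows and reducing over cols. The JIT build step and the extension name below are assumptions for illustration.

import torch
from torch.utils.cpp_extension import load

# Hypothetical JIT build; FastFold's own packaging may differ.
softmax_cuda = load(
    name="fastfold_softmax_cuda",
    sources=["softmax_cuda.cpp", "softmax_cuda_kernel.cu"],
)

# Attention-style logits: [batch, chunk, heads, queries, keys].
logits = torch.randn(1, 8, 4, 64, 64, device="cuda", dtype=torch.float16).contiguous()
rows = logits.numel() // logits.shape[-1]
cols = logits.shape[-1]

probs = softmax_cuda.forward(logits, rows, cols)
# Each row should now sum to 1 (up to fp16 rounding).
print(bool(torch.allclose(probs.float().sum(-1),
                          torch.ones(1, device="cuda"), atol=1e-2)))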
fastfold/model/fastnn/kernel/cuda_native/csrc/softmax_cuda_kernel.cu
0 → 100644
View file @
b14e47f4
#include <c10/cuda/CUDAGuard.h>
#include <math_constants.h>
#include <torch/extension.h>
#include <cub/cub.cuh>
#include <iostream>
#include "ATen/ATen.h"
#include "ATen/cuda/CUDAContext.h"
#include "compat.h"
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)
__inline__
__device__
float
WarpAllReduceMax
(
float
val
)
{
for
(
int
mask
=
1
;
mask
<
32
;
mask
*=
2
)
{
val
=
max
(
val
,
__shfl_xor_sync
(
0xffffffff
,
val
,
mask
));
}
return
val
;
}
__inline__
__device__
float
WarpAllReduceSum
(
float
val
)
{
for
(
int
mask
=
1
;
mask
<
32
;
mask
*=
2
)
{
val
+=
__shfl_xor_sync
(
0xffffffff
,
val
,
mask
);
}
return
val
;
}
inline
cudaError_t
GetNumBlocks
(
int64_t
block_size
,
int64_t
max_blocks
,
int64_t
waves
,
int
*
num_blocks
)
{
int
dev
;
{
cudaError_t
err
=
cudaGetDevice
(
&
dev
);
if
(
err
!=
cudaSuccess
)
{
return
err
;
}
}
int
sm_count
;
{
cudaError_t
err
=
cudaDeviceGetAttribute
(
&
sm_count
,
cudaDevAttrMultiProcessorCount
,
dev
);
if
(
err
!=
cudaSuccess
)
{
return
err
;
}
}
int
tpm
;
{
cudaError_t
err
=
cudaDeviceGetAttribute
(
&
tpm
,
cudaDevAttrMaxThreadsPerMultiProcessor
,
dev
);
if
(
err
!=
cudaSuccess
)
{
return
err
;
}
}
*
num_blocks
=
std
::
max
<
int
>
(
1
,
std
::
min
<
int64_t
>
(
max_blocks
,
sm_count
*
tpm
/
block_size
*
waves
));
return
cudaSuccess
;
}
template
<
typename
T
>
struct
SumOp
{
__device__
__forceinline__
T
operator
()(
const
T
&
a
,
const
T
&
b
)
const
{
return
a
+
b
;
}
};
template
<
typename
T
>
struct
MaxOp
{
__device__
__forceinline__
T
operator
()(
const
T
&
a
,
const
T
&
b
)
const
{
return
max
(
a
,
b
);
}
};
template
<
template
<
typename
>
class
ReductionOp
,
typename
T
,
int
block_size
>
__inline__
__device__
T
BlockAllReduce
(
T
val
)
{
typedef
cub
::
BlockReduce
<
T
,
block_size
>
BlockReduce
;
__shared__
typename
BlockReduce
::
TempStorage
temp_storage
;
__shared__
T
result_broadcast
;
T
result
=
BlockReduce
(
temp_storage
).
Reduce
(
val
,
ReductionOp
<
T
>
());
if
(
threadIdx
.
x
==
0
)
{
result_broadcast
=
result
;
}
__syncthreads
();
return
result_broadcast
;
}
////////////////
template
<
typename
T
,
int
cols_per_thread
>
__global__
void
fastfold_softmax
(
T
*
input
,
T
*
output
,
long
long
rows
,
long
long
cols
)
{
int
threadidx_x
=
threadIdx
.
x
/
32
;
int
threadidx_y
=
threadIdx
.
x
%
32
;
long
long
row_offset
=
(
long
long
)
blockIdx
.
x
*
4
+
threadidx_x
;
float
buf
[
cols_per_thread
];
int
lane_id
=
threadidx_y
;
if
(
row_offset
<
rows
)
{
T
*
row_input
=
input
+
row_offset
*
cols
;
T
*
row_output
=
output
+
row_offset
*
cols
;
float
thread_max
=
-
1
*
CUDART_INF_F
;
#pragma unroll
for
(
int
i
=
0
;
i
<
cols_per_thread
;
i
++
)
{
if
(
lane_id
*
cols_per_thread
+
i
<
cols
)
{
buf
[
i
]
=
static_cast
<
T
>
(
row_input
[
lane_id
*
cols_per_thread
+
i
]);
}
else
{
buf
[
i
]
=
-
1
*
CUDART_INF_F
;
}
}
#pragma unroll
for
(
int
i
=
0
;
i
<
cols_per_thread
;
i
++
)
{
thread_max
=
max
(
thread_max
,
buf
[
i
]);
}
float
warp_max
=
WarpAllReduceMax
(
thread_max
);
float
thread_sum
=
0.
f
;
#pragma unroll
for
(
int
i
=
0
;
i
<
cols_per_thread
;
++
i
)
{
buf
[
i
]
=
__expf
(
buf
[
i
]
-
warp_max
);
thread_sum
+=
buf
[
i
];
}
float
warp_sum
=
WarpAllReduceSum
(
thread_sum
);
#pragma unroll
for
(
int
i
=
0
;
i
<
cols_per_thread
;
++
i
)
{
if
(
lane_id
*
cols_per_thread
+
i
<
cols
)
{
row_output
[
lane_id
*
cols_per_thread
+
i
]
=
static_cast
<
T
>
(
__fdividef
(
buf
[
i
],
warp_sum
));
}
}
}
}
template
<
typename
T
,
int
block_size
>
__global__
void
fastfold_softmax_sm
(
T
*
input
,
T
*
output
,
long
long
rows
,
long
long
cols
)
{
extern
__shared__
__align__
(
sizeof
(
double
))
unsigned
char
shared_buf
[];
auto
*
buf
=
reinterpret_cast
<
float
*>
(
shared_buf
);
const
int
tid
=
threadIdx
.
x
;
for
(
int64_t
row
=
blockIdx
.
x
;
row
<
rows
;
row
+=
gridDim
.
x
)
{
float
thread_max
=
-
1
*
CUDART_INF_F
;
for
(
int
id
=
tid
;
id
<
cols
;
id
+=
block_size
)
{
buf
[
id
]
=
static_cast
<
T
>
(
input
[
row
*
cols
+
id
]);
thread_max
=
max
(
thread_max
,
buf
[
id
]);
}
const
float
row_max
=
BlockAllReduce
<
MaxOp
,
float
,
block_size
>
(
thread_max
);
float
thread_sum
=
0
;
for
(
int
id
=
tid
;
id
<
cols
;
id
+=
block_size
)
{
buf
[
id
]
=
__expf
(
buf
[
id
]
-
row_max
);
thread_sum
+=
buf
[
id
];
}
const
float
row_sum
=
BlockAllReduce
<
SumOp
,
float
,
block_size
>
(
thread_sum
);
for
(
int
id
=
tid
;
id
<
cols
;
id
+=
block_size
)
{
output
[
row
*
cols
+
id
]
=
static_cast
<
T
>
(
buf
[
id
]
/
row_sum
);
}
}
}
at
::
Tensor
softmax
(
at
::
Tensor
input
,
long
long
rows
,
long
long
cols
)
{
CHECK_INPUT
(
input
);
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
input
));
at
::
Tensor
output
=
at
::
empty_like
(
input
);
int
grid
=
(
rows
+
3
)
/
4
;
dim3
block
(
128
);
if
(
cols
<=
32
)
{
if
(
input
.
dtype
()
==
torch
::
kFloat32
)
{
fastfold_softmax
<
float
,
1
><<<
grid
,
block
>>>
((
float
*
)
input
.
data_ptr
(),
(
float
*
)
output
.
data_ptr
(),
rows
,
cols
);
}
else
if
(
input
.
dtype
()
==
torch
::
kFloat16
)
{
fastfold_softmax
<
at
::
Half
,
1
><<<
grid
,
block
>>>
(
(
at
::
Half
*
)
input
.
data_ptr
(),
(
at
::
Half
*
)
output
.
data_ptr
(),
rows
,
cols
);
}
else
if
(
input
.
dtype
()
==
torch
::
kBFloat16
)
{
fastfold_softmax
<
at
::
BFloat16
,
1
><<<
grid
,
block
>>>
(
(
at
::
BFloat16
*
)
input
.
data_ptr
(),
(
at
::
BFloat16
*
)
output
.
data_ptr
(),
rows
,
cols
);
}
}
#define COLS_CASE(col_per_thread) \
else if (cols <= col_per_thread * 32) { \
if (input.dtype() == torch::kFloat32) { \
fastfold_softmax<float, col_per_thread><<<grid, block>>>( \
(float *)input.data_ptr(), (float *)output.data_ptr(), rows, cols); \
} else if (input.dtype() == torch::kFloat16) { \
fastfold_softmax<at::Half, col_per_thread><<<grid, block>>>( \
(at::Half *)input.data_ptr(), (at::Half *)output.data_ptr(), rows, cols); \
} else if (input.dtype() == torch::kBFloat16) { \
fastfold_softmax<at::BFloat16, col_per_thread><<<grid, block>>>( \
(at::BFloat16 *)input.data_ptr(), (at::BFloat16 *)output.data_ptr(), rows, cols); \
} \
}
COLS_CASE
(
2
)
COLS_CASE
(
3
)
COLS_CASE
(
4
)
COLS_CASE
(
5
)
COLS_CASE
(
6
)
COLS_CASE
(
7
)
COLS_CASE
(
8
)
COLS_CASE
(
9
)
COLS_CASE
(
10
)
COLS_CASE
(
11
)
COLS_CASE
(
12
)
COLS_CASE
(
13
)
COLS_CASE
(
14
)
COLS_CASE
(
15
)
COLS_CASE
(
16
)
COLS_CASE
(
17
)
COLS_CASE
(
18
)
COLS_CASE
(
19
)
COLS_CASE
(
20
)
COLS_CASE
(
21
)
COLS_CASE
(
22
)
COLS_CASE
(
23
)
COLS_CASE
(
24
)
COLS_CASE
(
25
)
COLS_CASE
(
26
)
COLS_CASE
(
27
)
COLS_CASE
(
28
)
COLS_CASE
(
29
)
COLS_CASE
(
30
)
COLS_CASE
(
31
)
COLS_CASE
(
32
)
#undef COLS_CASE
else
{
int
grid_dim
;
constexpr
int
waves
=
32
;
GetNumBlocks
(
128
,
rows
,
waves
,
&
grid_dim
);
dim3
block
(
128
);
const
size_t
smem
=
cols
*
sizeof
(
float
);
if
(
input
.
dtype
()
==
torch
::
kFloat32
)
{
fastfold_softmax_sm
<
float
,
128
><<<
grid_dim
,
block
,
smem
>>>
(
(
float
*
)
input
.
data_ptr
(),
(
float
*
)
output
.
data_ptr
(),
rows
,
cols
);
}
else
if
(
input
.
dtype
()
==
torch
::
kFloat16
)
{
fastfold_softmax_sm
<
at
::
Half
,
128
><<<
grid_dim
,
block
,
smem
>>>
(
(
at
::
Half
*
)
input
.
data_ptr
(),
(
at
::
Half
*
)
output
.
data_ptr
(),
rows
,
cols
);
}
else
if
(
input
.
dtype
()
==
torch
::
kBFloat16
)
{
fastfold_softmax_sm
<
at
::
BFloat16
,
128
><<<
grid_dim
,
block
,
smem
>>>
(
(
at
::
BFloat16
*
)
input
.
data_ptr
(),
(
at
::
BFloat16
*
)
output
.
data_ptr
(),
rows
,
cols
);
}
}
return
output
;
}
template
<
typename
T
>
__global__
void
fastfold_softmax_grad
(
T
*
d_output
,
T
*
output
,
T
*
d_input
,
long
long
rows
,
long
long
cols
)
{
int
threadidx_x
=
threadIdx
.
x
/
32
;
int
threadidx_y
=
threadIdx
.
x
%
32
;
long
long
row_offset
=
(
long
long
)
blockIdx
.
x
*
4
+
threadidx_x
;
int
cols_per_thread
=
(
cols
+
31
)
/
32
;
int
cols_this_thread
=
cols_per_thread
;
int
last_y
=
(
cols
/
cols_per_thread
);
if
(
threadidx_y
==
last_y
)
{
cols_this_thread
=
cols
-
cols_per_thread
*
last_y
;
}
else
if
(
threadidx_y
>
last_y
)
{
cols_this_thread
=
0
;
}
float
y_buf
[
32
];
float
dy_buf
[
32
];
int
lane_id
=
threadidx_y
;
if
(
row_offset
<
rows
)
{
T
*
row_d_output
=
d_output
+
row_offset
*
cols
;
T
*
row_output
=
output
+
row_offset
*
cols
;
T
*
row_d_input
=
d_input
+
row_offset
*
cols
;
float
thread_max
=
-
1
*
CUDART_INF_F
;
#pragma unroll
for
(
int
i
=
0
;
i
<
cols_this_thread
;
i
++
)
{
if
(
lane_id
*
cols_per_thread
+
i
<
cols
)
{
y_buf
[
i
]
=
static_cast
<
T
>
(
row_output
[
lane_id
*
cols_per_thread
+
i
]);
dy_buf
[
i
]
=
static_cast
<
T
>
(
row_d_output
[
lane_id
*
cols_per_thread
+
i
]);
}
}
float
thread_sum
=
0.
f
;
#pragma unroll
for
(
int
i
=
0
;
i
<
cols_this_thread
;
i
++
)
{
if
(
lane_id
*
cols_per_thread
+
i
<
cols
)
{
thread_sum
+=
y_buf
[
i
]
*
dy_buf
[
i
];
}
}
float
warp_sum
=
WarpAllReduceSum
(
thread_sum
);
#pragma unroll
for
(
int
i
=
0
;
i
<
cols_this_thread
;
++
i
)
{
if
(
lane_id
*
cols_per_thread
+
i
<
cols
)
{
row_d_input
[
lane_id
*
cols_per_thread
+
i
]
=
static_cast
<
T
>
((
dy_buf
[
i
]
-
warp_sum
)
*
y_buf
[
i
]);
}
}
}
}
at
::
Tensor
softmax_gradient
(
at
::
Tensor
d_output
,
at
::
Tensor
output
,
long
long
rows
,
long
long
cols
)
{
CHECK_INPUT
(
output
);
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
output
));
at
::
Tensor
grad_input
=
at
::
empty_like
(
output
);
int
grid
=
(
rows
+
3
)
/
4
;
dim3
block
(
128
);
if
(
output
.
dtype
()
==
torch
::
kFloat32
)
{
fastfold_softmax_grad
<
float
><<<
grid
,
block
>>>
((
float
*
)
d_output
.
data_ptr
(),
(
float
*
)
output
.
data_ptr
(),
(
float
*
)
grad_input
.
data_ptr
(),
rows
,
cols
);
}
else
if
(
output
.
dtype
()
==
torch
::
kFloat16
)
{
fastfold_softmax_grad
<
at
::
Half
>
<<<
grid
,
block
>>>
((
at
::
Half
*
)
d_output
.
data_ptr
(),
(
at
::
Half
*
)
output
.
data_ptr
(),
(
at
::
Half
*
)
grad_input
.
data_ptr
(),
rows
,
cols
);
}
else
if
(
output
.
dtype
()
==
torch
::
kBFloat16
)
{
fastfold_softmax_grad
<
at
::
BFloat16
><<<
grid
,
block
>>>
(
(
at
::
BFloat16
*
)
d_output
.
data_ptr
(),
(
at
::
BFloat16
*
)
output
.
data_ptr
(),
(
at
::
BFloat16
*
)
grad_input
.
data_ptr
(),
rows
,
cols
);
}
return
grad_input
;
}
////////////////
template
<
typename
T
,
int
cols_per_thread
>
__global__
void
fastfold_softmax_mask
(
T
*
input
,
T
*
mask
,
T
*
output
,
long
long
rows
,
long
long
cols
,
int
head
)
{
int
threadidx_x
=
threadIdx
.
x
/
32
;
int
threadidx_y
=
threadIdx
.
x
%
32
;
long
long
row_offset
=
(
long
long
)
blockIdx
.
x
*
4
+
threadidx_x
;
float
buf
[
cols_per_thread
];
int
lane_id
=
threadidx_y
;
T
*
row_input
=
input
+
row_offset
*
cols
;
T
*
row_output
=
output
+
row_offset
*
cols
;
T
*
mask_ptr
=
mask
+
((
row_offset
/
(
head
*
cols
))
*
cols
);
#pragma unroll
for
(
int
i
=
0
;
i
<
cols_per_thread
;
i
++
)
{
if
(
lane_id
*
cols_per_thread
+
i
<
cols
)
{
if
(
mask_ptr
[
lane_id
*
cols_per_thread
+
i
]
==
0
)
{
buf
[
i
]
=
-
1
*
1e9
;
}
else
{
buf
[
i
]
=
static_cast
<
T
>
(
row_input
[
lane_id
*
cols_per_thread
+
i
]);
}
}
else
{
buf
[
i
]
=
-
1
*
CUDART_INF_F
;
}
}
float
thread_max
=
-
1
*
CUDART_INF_F
;
#pragma unroll
for
(
int
i
=
0
;
i
<
cols_per_thread
;
i
++
)
{
thread_max
=
max
(
thread_max
,
buf
[
i
]);
}
float
warp_max
=
WarpAllReduceMax
(
thread_max
);
float
thread_sum
=
0.
f
;
#pragma unroll
for
(
int
i
=
0
;
i
<
cols_per_thread
;
++
i
)
{
buf
[
i
]
=
__expf
(
buf
[
i
]
-
warp_max
);
thread_sum
+=
buf
[
i
];
}
float
warp_sum
=
WarpAllReduceSum
(
thread_sum
);
#pragma unroll
for
(
int
i
=
0
;
i
<
cols_per_thread
;
++
i
)
{
if
(
lane_id
*
cols_per_thread
+
i
<
cols
)
{
row_output
[
lane_id
*
cols_per_thread
+
i
]
=
static_cast
<
T
>
(
__fdividef
(
buf
[
i
],
warp_sum
));
}
}
}
template
<
typename
T
,
int
block_size
>
__global__
void
fastfold_softmax_mask_sm
(
T
*
input
,
T
*
mask
,
T
*
output
,
long
long
rows
,
long
long
cols
,
int
head
)
{
extern
__shared__
__align__
(
sizeof
(
double
))
unsigned
char
shared_buf
[];
auto
*
buf
=
reinterpret_cast
<
float
*>
(
shared_buf
);
const
int
tid
=
threadIdx
.
x
;
for
(
int64_t
row
=
blockIdx
.
x
;
row
<
rows
;
row
+=
gridDim
.
x
)
{
T
*
mask_ptr
=
mask
+
((
row
/
(
head
*
cols
))
*
cols
);
float
thread_max
=
-
1
*
CUDART_INF_F
;
for
(
int
id
=
tid
;
id
<
cols
;
id
+=
block_size
)
{
if
(
mask_ptr
[
id
]
==
0
)
{
buf
[
id
]
=
-
1
*
1e9
;
}
else
{
buf
[
id
]
=
input
[
row
*
cols
+
id
];
}
thread_max
=
max
(
thread_max
,
buf
[
id
]);
}
const
float
row_max
=
BlockAllReduce
<
MaxOp
,
float
,
block_size
>
(
thread_max
);
float
thread_sum
=
0
;
for
(
int
id
=
tid
;
id
<
cols
;
id
+=
block_size
)
{
buf
[
id
]
=
__expf
(
buf
[
id
]
-
row_max
);
thread_sum
+=
buf
[
id
];
}
const
float
row_sum
=
BlockAllReduce
<
SumOp
,
float
,
block_size
>
(
thread_sum
);
for
(
int
id
=
tid
;
id
<
cols
;
id
+=
block_size
)
{
output
[
row
*
cols
+
id
]
=
buf
[
id
]
/
row_sum
;
}
}
}
at
::
Tensor
fused_mask_softmax_forward
(
at
::
Tensor
input
,
at
::
Tensor
mask
,
long
long
rows
,
long
long
cols
)
{
CHECK_INPUT
(
input
);
CHECK_INPUT
(
mask
);
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
input
));
int
head
=
input
.
sizes
()[
2
];
// at::Tensor output = at::empty_like(input);
int
grid
=
(
rows
+
3
)
/
4
;
dim3
block
(
128
);
if
(
cols
<=
32
)
{
if
(
input
.
dtype
()
==
torch
::
kFloat32
)
{
fastfold_softmax_mask
<
float
,
1
>
<<<
grid
,
block
>>>
((
float
*
)
input
.
data_ptr
(),
(
float
*
)
mask
.
data_ptr
(),
(
float
*
)
input
.
data_ptr
(),
rows
,
cols
,
head
);
}
else
if
(
input
.
dtype
()
==
torch
::
kFloat16
)
{
fastfold_softmax_mask
<
at
::
Half
,
1
>
<<<
grid
,
block
>>>
((
at
::
Half
*
)
input
.
data_ptr
(),
(
at
::
Half
*
)
mask
.
data_ptr
(),
(
at
::
Half
*
)
input
.
data_ptr
(),
rows
,
cols
,
head
);
}
else
if
(
input
.
dtype
()
==
torch
::
kBFloat16
)
{
fastfold_softmax_mask
<
at
::
BFloat16
,
1
>
<<<
grid
,
block
>>>
((
at
::
BFloat16
*
)
input
.
data_ptr
(),
(
at
::
BFloat16
*
)
mask
.
data_ptr
(),
(
at
::
BFloat16
*
)
input
.
data_ptr
(),
rows
,
cols
,
head
);
}
}
#define COLS_CASE(col_per_thread) \
else if (cols <= col_per_thread * 32) { \
if (input.dtype() == torch::kFloat32) { \
fastfold_softmax_mask<float, col_per_thread> \
<<<grid, block>>>((float *)input.data_ptr(), (float *)mask.data_ptr(), \
(float *)input.data_ptr(), rows, cols, head); \
} else if (input.dtype() == torch::kFloat16) { \
fastfold_softmax_mask<at::Half, col_per_thread> \
<<<grid, block>>>((at::Half *)input.data_ptr(), (at::Half *)mask.data_ptr(), \
(at::Half *)input.data_ptr(), rows, cols, head); \
} else if (input.dtype() == torch::kBFloat16) { \
fastfold_softmax_mask<at::BFloat16, col_per_thread><<<grid, block>>>( \
(at::BFloat16 *)input.data_ptr(), (at::BFloat16 *)mask.data_ptr(), \
(at::BFloat16 *)input.data_ptr(), rows, cols, head); \
} \
}
COLS_CASE
(
2
)
COLS_CASE
(
3
)
COLS_CASE
(
4
)
COLS_CASE
(
5
)
COLS_CASE
(
6
)
COLS_CASE
(
7
)
COLS_CASE
(
8
)
COLS_CASE
(
9
)
COLS_CASE
(
10
)
COLS_CASE
(
11
)
COLS_CASE
(
12
)
COLS_CASE
(
13
)
COLS_CASE
(
14
)
COLS_CASE
(
15
)
COLS_CASE
(
16
)
COLS_CASE
(
17
)
COLS_CASE
(
18
)
COLS_CASE
(
19
)
COLS_CASE
(
20
)
COLS_CASE
(
21
)
COLS_CASE
(
22
)
COLS_CASE
(
23
)
COLS_CASE
(
24
)
COLS_CASE
(
25
)
COLS_CASE
(
26
)
COLS_CASE
(
27
)
COLS_CASE
(
28
)
COLS_CASE
(
29
)
COLS_CASE
(
30
)
COLS_CASE
(
31
)
COLS_CASE
(
32
)
#undef COLS_CASE
else
{
int
grid_dim
;
constexpr
int
waves
=
32
;
GetNumBlocks
(
128
,
rows
,
waves
,
&
grid_dim
);
dim3
block
(
128
);
const
size_t
smem
=
cols
*
sizeof
(
float
);
if
(
input
.
dtype
()
==
torch
::
kFloat32
)
{
fastfold_softmax_mask_sm
<
float
,
128
>
<<<
grid
,
block
,
smem
>>>
((
float
*
)
input
.
data_ptr
(),
(
float
*
)
mask
.
data_ptr
(),
(
float
*
)
input
.
data_ptr
(),
rows
,
cols
,
head
);
}
else
if
(
input
.
dtype
()
==
torch
::
kFloat16
)
{
fastfold_softmax_mask_sm
<
at
::
Half
,
128
>
<<<
grid
,
block
,
smem
>>>
((
at
::
Half
*
)
input
.
data_ptr
(),
(
at
::
Half
*
)
mask
.
data_ptr
(),
(
at
::
Half
*
)
input
.
data_ptr
(),
rows
,
cols
,
head
);
}
else
if
(
input
.
dtype
()
==
torch
::
kBFloat16
)
{
fastfold_softmax_mask_sm
<
at
::
BFloat16
,
128
><<<
grid
,
block
,
smem
>>>
(
(
at
::
BFloat16
*
)
input
.
data_ptr
(),
(
at
::
BFloat16
*
)
mask
.
data_ptr
(),
(
at
::
BFloat16
*
)
input
.
data_ptr
(),
rows
,
cols
,
head
);
}
}
return
input
;
}
template
<
typename
T
>
__global__
void
fastfold_softmax_mask_grad
(
T
*
d_output
,
T
*
output
,
T
*
d_input
,
T
*
mask
,
long
long
rows
,
long
long
cols
,
int
head
)
{
int
threadidx_x
=
threadIdx
.
x
/
32
;
int
threadidx_y
=
threadIdx
.
x
%
32
;
long
long
row_offset
=
(
long
long
)
blockIdx
.
x
*
4
+
threadidx_x
;
int
cols_per_thread
=
(
cols
+
31
)
/
32
;
int
cols_this_thread
=
cols_per_thread
;
int
last_y
=
(
cols
/
cols_per_thread
);
if
(
threadidx_y
==
last_y
)
{
cols_this_thread
=
cols
-
cols_per_thread
*
last_y
;
}
else
if
(
threadidx_y
>
last_y
)
{
cols_this_thread
=
0
;
}
float
y_buf
[
32
];
float
dy_buf
[
32
];
int
lane_id
=
threadidx_y
;
if
(
row_offset
<
rows
)
{
T
*
row_d_output
=
d_output
+
row_offset
*
cols
;
T
*
row_output
=
output
+
row_offset
*
cols
;
T
*
row_d_input
=
d_input
+
row_offset
*
cols
;
T
*
mask_ptr
=
mask
+
((
row_offset
/
(
head
*
cols
))
*
cols
);
float
thread_max
=
-
1
*
CUDART_INF_F
;
#pragma unroll
for
(
int
i
=
0
;
i
<
cols_this_thread
;
i
++
)
{
if
(
lane_id
*
cols_per_thread
+
i
<
cols
)
{
y_buf
[
i
]
=
static_cast
<
T
>
(
row_output
[
lane_id
*
cols_per_thread
+
i
]);
dy_buf
[
i
]
=
static_cast
<
T
>
(
row_d_output
[
lane_id
*
cols_per_thread
+
i
]);
}
}
float
thread_sum
=
0.
f
;
#pragma unroll
for
(
int
i
=
0
;
i
<
cols_this_thread
;
i
++
)
{
if
(
lane_id
*
cols_per_thread
+
i
<
cols
)
{
thread_sum
+=
y_buf
[
i
]
*
dy_buf
[
i
];
}
}
float
warp_sum
=
WarpAllReduceSum
(
thread_sum
);
#pragma unroll
for
(
int
i
=
0
;
i
<
cols_this_thread
;
++
i
)
{
if
(
lane_id
*
cols_per_thread
+
i
<
cols
)
{
if
(
mask_ptr
[
lane_id
*
cols_per_thread
+
i
]
!=
0
)
{
row_d_input
[
lane_id
*
cols_per_thread
+
i
]
=
static_cast
<
T
>
((
dy_buf
[
i
]
-
warp_sum
)
*
y_buf
[
i
]);
}
else
{
row_d_input
[
lane_id
*
cols_per_thread
+
i
]
=
0
;
}
}
}
}
}
at
::
Tensor
fused_mask_softmax_backward
(
at
::
Tensor
d_output
,
at
::
Tensor
output
,
at
::
Tensor
mask
,
long
long
rows
,
long
long
cols
)
{
CHECK_INPUT
(
output
);
CHECK_INPUT
(
mask
);
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
mask
));
int
head
=
output
.
sizes
()[
2
];
at
::
Tensor
grad_input
=
at
::
empty_like
(
output
);
int
grid
=
(
rows
+
3
)
/
4
;
dim3
block
(
128
);
if
(
output
.
dtype
()
==
torch
::
kFloat32
)
{
fastfold_softmax_mask_grad
<
float
><<<
grid
,
block
>>>
(
(
float
*
)
d_output
.
data_ptr
(),
(
float
*
)
output
.
data_ptr
(),
(
float
*
)
grad_input
.
data_ptr
(),
(
float
*
)
mask
.
data_ptr
(),
rows
,
cols
,
head
);
}
else
if
(
output
.
dtype
()
==
torch
::
kFloat16
)
{
fastfold_softmax_mask_grad
<
at
::
Half
><<<
grid
,
block
>>>
(
(
at
::
Half
*
)
d_output
.
data_ptr
(),
(
at
::
Half
*
)
output
.
data_ptr
(),
(
at
::
Half
*
)
grad_input
.
data_ptr
(),
(
at
::
Half
*
)
mask
.
data_ptr
(),
rows
,
cols
,
head
);
}
else
if
(
output
.
dtype
()
==
torch
::
kBFloat16
)
{
fastfold_softmax_mask_grad
<
at
::
BFloat16
><<<
grid
,
block
>>>
(
(
at
::
BFloat16
*
)
d_output
.
data_ptr
(),
(
at
::
BFloat16
*
)
output
.
data_ptr
(),
(
at
::
BFloat16
*
)
grad_input
.
data_ptr
(),
(
at
::
BFloat16
*
)
mask
.
data_ptr
(),
rows
,
cols
,
head
);
}
return
grad_input
;
}
////////////////
template <typename T, int cols_per_thread>
__global__ void fastfold_softmax_mask_bias(T *input, T *mask, T *bias, T *output, long long rows,
                                           long long cols, int head) {
    int threadidx_x = threadIdx.x / 32;
    int threadidx_y = threadIdx.x % 32;
    long long row_offset = (long long)blockIdx.x * 4 + threadidx_x;

    float buf[cols_per_thread];

    int lane_id = threadidx_y;

    T *row_input = input + row_offset * cols;
    T *row_output = output + row_offset * cols;
    T *mask_ptr = mask + ((row_offset / (head * cols)) * cols);
    T *bias_ptr = bias + ((row_offset % (head * cols)) * cols);

#pragma unroll
    for (int i = 0; i < cols_per_thread; i++) {
        if (lane_id * cols_per_thread + i < cols) {
            if (mask_ptr[lane_id * cols_per_thread + i] == 0) {
                buf[i] = -1 * 10e9;
            } else {
                buf[i] = static_cast<T>(row_input[lane_id * cols_per_thread + i]) +
                         static_cast<T>(bias_ptr[lane_id * cols_per_thread + i]);
            }
        } else {
            buf[i] = -1 * CUDART_INF_F;
        }
    }

    float thread_max = -1 * CUDART_INF_F;
#pragma unroll
    for (int i = 0; i < cols_per_thread; i++) {
        thread_max = max(thread_max, buf[i]);
    }

    float warp_max = WarpAllReduceMax(thread_max);

    float thread_sum = 0.f;
#pragma unroll
    for (int i = 0; i < cols_per_thread; ++i) {
        buf[i] = __expf(buf[i] - warp_max);
        thread_sum += buf[i];
    }

    float warp_sum = WarpAllReduceSum(thread_sum);
#pragma unroll
    for (int i = 0; i < cols_per_thread; ++i) {
        if (lane_id * cols_per_thread + i < cols) {
            row_output[lane_id * cols_per_thread + i] =
                static_cast<T>(__fdividef(buf[i], warp_sum));
        }
    }
}
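// The kernel above relies on warp-wide all-reduce helpers (WarpAllReduceMax /
// WarpAllReduceSum) defined earlier in this file, outside this excerpt. A minimal
// sketch of what such helpers typically look like, using butterfly shuffles so
// every lane ends up holding the reduced value (the *Sketch names are illustrative,
// not the shipped implementation):
template <typename T>
__inline__ __device__ T WarpAllReduceMaxSketch(T val) {
    for (int offset = 16; offset > 0; offset >>= 1) {
        val = max(val, __shfl_xor_sync(0xffffffff, val, offset));
    }
    return val;
}

template <typename T>
__inline__ __device__ T WarpAllReduceSumSketch(T val) {
    for (int offset = 16; offset > 0; offset >>= 1) {
        val += __shfl_xor_sync(0xffffffff, val, offset);
    }
    return val;
}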
template <typename T, int block_size>
__global__ void fastfold_softmax_mask_bias_sm(T *input, T *mask, T *bias, T *output,
                                              long long rows, long long cols, int head) {
    extern __shared__ __align__(sizeof(double)) unsigned char shared_buf[];
    auto *buf = reinterpret_cast<float *>(shared_buf);
    const int tid = threadIdx.x;

    for (int64_t row = blockIdx.x; row < rows; row += gridDim.x) {
        T *mask_ptr = mask + ((row / (head * cols)) * cols);
        T *bias_ptr = bias + ((row % (head * cols)) * cols);

        float thread_max = -1 * CUDART_INF_F;
        for (int id = tid; id < cols; id += block_size) {
            if (mask_ptr[id] == 0) {
                buf[id] = -1 * 1e9;
            } else {
                buf[id] = input[row * cols + id] + bias_ptr[id];
            }
            thread_max = max(thread_max, buf[id]);
        }
        const float row_max = BlockAllReduce<MaxOp, float, block_size>(thread_max);

        float thread_sum = 0;
        for (int id = tid; id < cols; id += block_size) {
            buf[id] = __expf(buf[id] - row_max);
            thread_sum += buf[id];
        }
        const float row_sum = BlockAllReduce<SumOp, float, block_size>(thread_sum);

        for (int id = tid; id < cols; id += block_size) {
            output[row * cols + id] = buf[id] / row_sum;
        }
    }
}
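// The shared-memory variant above relies on a block-wide all-reduce
// (BlockAllReduce<MaxOp/SumOp, float, block_size>) defined earlier in this file,
// outside this excerpt. A minimal sketch of that pattern with cub, broadcasting the
// result through shared memory so every thread observes it (the *Sketch names are
// illustrative, not the shipped implementation; the sketch assumes <cub/cub.cuh>
// is available):
#include <cub/cub.cuh>

template <typename T>
struct MaxOpSketch {
    __device__ __forceinline__ T operator()(const T &a, const T &b) const { return max(a, b); }
};

template <typename T>
struct SumOpSketch {
    __device__ __forceinline__ T operator()(const T &a, const T &b) const { return a + b; }
};

template <template <typename> class ReductionOp, typename T, int block_size>
__inline__ __device__ T BlockAllReduceSketch(T val) {
    typedef cub::BlockReduce<T, block_size> BlockReduce;
    __shared__ typename BlockReduce::TempStorage temp_storage;
    __shared__ T result_broadcast;
    T result = BlockReduce(temp_storage).Reduce(val, ReductionOp<T>());
    if (threadIdx.x == 0) {
        result_broadcast = result;
    }
    __syncthreads();
    return result_broadcast;
}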
at::Tensor fused_mask_bias_softmax_forward(at::Tensor input, at::Tensor mask, at::Tensor bias,
                                           long long rows, long long cols) {
    CHECK_INPUT(input);
    CHECK_INPUT(mask);
    CHECK_INPUT(bias);
    const at::cuda::OptionalCUDAGuard device_guard(device_of(input));

    int head = input.sizes()[2];
    // at::Tensor output = at::empty_like(input);

    int grid = (rows + 3) / 4;
    dim3 block(128);

    if (cols <= 32) {
        if (input.dtype() == torch::kFloat32) {
            fastfold_softmax_mask_bias<float, 1><<<grid, block>>>(
                (float *)input.data_ptr(), (float *)mask.data_ptr(), (float *)bias.data_ptr(),
                (float *)input.data_ptr(), rows, cols, head);
        } else if (input.dtype() == torch::kFloat16) {
            fastfold_softmax_mask_bias<at::Half, 1><<<grid, block>>>(
                (at::Half *)input.data_ptr(), (at::Half *)mask.data_ptr(),
                (at::Half *)bias.data_ptr(), (at::Half *)input.data_ptr(), rows, cols, head);
        } else if (input.dtype() == torch::kBFloat16) {
            fastfold_softmax_mask_bias<at::BFloat16, 1><<<grid, block>>>(
                (at::BFloat16 *)input.data_ptr(), (at::BFloat16 *)mask.data_ptr(),
                (at::BFloat16 *)bias.data_ptr(), (at::BFloat16 *)input.data_ptr(), rows, cols,
                head);
        }
    }
#define COLS_CASE(col_per_thread) \
else if (cols <= col_per_thread * 32) { \
if (input.dtype() == torch::kFloat32) { \
fastfold_softmax_mask_bias<float, col_per_thread><<<grid, block>>>( \
(float *)input.data_ptr(), (float *)mask.data_ptr(), (float *)bias.data_ptr(), \
(float *)input.data_ptr(), rows, cols, head); \
} else if (input.dtype() == torch::kFloat16) { \
fastfold_softmax_mask_bias<at::Half, col_per_thread><<<grid, block>>>( \
(at::Half *)input.data_ptr(), (at::Half *)mask.data_ptr(), \
(at::Half *)bias.data_ptr(), (at::Half *)input.data_ptr(), rows, cols, head); \
} else if (input.dtype() == torch::kBFloat16) { \
fastfold_softmax_mask_bias<at::BFloat16, col_per_thread><<<grid, block>>>( \
(at::BFloat16 *)input.data_ptr(), (at::BFloat16 *)mask.data_ptr(), \
(at::BFloat16 *)bias.data_ptr(), (at::BFloat16 *)input.data_ptr(), rows, cols, \
head); \
} \
}
    COLS_CASE(2)
    COLS_CASE(3)
    COLS_CASE(4)
    COLS_CASE(5)
    COLS_CASE(6)
    COLS_CASE(7)
    COLS_CASE(8)
    COLS_CASE(9)
    COLS_CASE(10)
    COLS_CASE(11)
    COLS_CASE(12)
    COLS_CASE(13)
    COLS_CASE(14)
    COLS_CASE(15)
    COLS_CASE(16)
    COLS_CASE(17)
    COLS_CASE(18)
    COLS_CASE(19)
    COLS_CASE(20)
    COLS_CASE(21)
    COLS_CASE(22)
    COLS_CASE(23)
    COLS_CASE(24)
    COLS_CASE(25)
    COLS_CASE(26)
    COLS_CASE(27)
    COLS_CASE(28)
    COLS_CASE(29)
    COLS_CASE(30)
    COLS_CASE(31)
    COLS_CASE(32)
#undef COLS_CASE
    else {
        int grid_dim;
        constexpr int waves = 32;
        GetNumBlocks(128, rows, waves, &grid_dim);
        dim3 block(128);
        const size_t smem = cols * sizeof(float);
        if (input.dtype() == torch::kFloat32) {
            fastfold_softmax_mask_bias_sm<float, 128><<<grid, block, smem>>>(
                (float *)input.data_ptr(), (float *)mask.data_ptr(), (float *)bias.data_ptr(),
                (float *)input.data_ptr(), rows, cols, head);
        } else if (input.dtype() == torch::kFloat16) {
            fastfold_softmax_mask_bias_sm<at::Half, 128><<<grid, block, smem>>>(
                (at::Half *)input.data_ptr(), (at::Half *)mask.data_ptr(),
                (at::Half *)bias.data_ptr(), (at::Half *)input.data_ptr(), rows, cols, head);
        } else if (input.dtype() == torch::kBFloat16) {
            fastfold_softmax_mask_bias_sm<at::BFloat16, 128><<<grid, block, smem>>>(
                (at::BFloat16 *)input.data_ptr(), (at::BFloat16 *)mask.data_ptr(),
                (at::BFloat16 *)bias.data_ptr(), (at::BFloat16 *)input.data_ptr(), rows, cols,
                head);
        }
    }

    return input;
}
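// Hypothetical caller sketch (fused_softmax_example is an invented name; only
// fused_mask_bias_softmax_forward comes from this file). The launcher expects
// `cols` to be the softmax (last) dimension and `rows` to be the product of all
// leading dimensions, and reads the head count from input.sizes()[2]; the exact
// broadcast layout expected of `mask` and `bias` follows from the pointer
// arithmetic in the kernels above and is not restated here.
static at::Tensor fused_softmax_example(at::Tensor logits, at::Tensor mask, at::Tensor bias) {
    const long long cols = logits.size(-1);
    const long long rows = logits.numel() / cols;
    // The forward path writes its result back into the first argument and returns it.
    return fused_mask_bias_softmax_forward(logits.contiguous(), mask.contiguous(),
                                           bias.contiguous(), rows, cols);
}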
at::Tensor fused_mask_bias_softmax_backward(at::Tensor d_output, at::Tensor output,
                                            at::Tensor mask, at::Tensor bias, long long rows,
                                            long long cols) {
    CHECK_INPUT(output);
    CHECK_INPUT(mask);
    const at::cuda::OptionalCUDAGuard device_guard(device_of(mask));

    int head = output.sizes()[2];
    at::Tensor grad_input = at::empty_like(output);

    int grid = (rows + 3) / 4;
    dim3 block(128);

    if (output.dtype() == torch::kFloat32) {
        fastfold_softmax_mask_grad<float><<<grid, block>>>(
            (float *)d_output.data_ptr(), (float *)output.data_ptr(),
            (float *)grad_input.data_ptr(), (float *)mask.data_ptr(), rows, cols, head);
    } else if (output.dtype() == torch::kFloat16) {
        fastfold_softmax_mask_grad<at::Half><<<grid, block>>>(
            (at::Half *)d_output.data_ptr(), (at::Half *)output.data_ptr(),
            (at::Half *)grad_input.data_ptr(), (at::Half *)mask.data_ptr(), rows, cols, head);
    } else if (output.dtype() == torch::kBFloat16) {
        fastfold_softmax_mask_grad<at::BFloat16><<<grid, block>>>(
            (at::BFloat16 *)d_output.data_ptr(), (at::BFloat16 *)output.data_ptr(),
            (at::BFloat16 *)grad_input.data_ptr(), (at::BFloat16 *)mask.data_ptr(), rows, cols,
            head);
    }

    return grad_input;
}
fastfold/model/fastnn/kernel/cuda_native/csrc/type_shim.h
0 → 100644
View file @
b14e47f4
// modified from https://github.com/NVIDIA/apex
#include <ATen/ATen.h>
#include "compat.h"
#define DISPATCH_HALF_AND_BFLOAT(TYPE, NAME, ...) \
switch (TYPE) { \
case at::ScalarType::Half: { \
using scalar_t = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: { \
using scalar_t = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \
switch (TYPEIN) { \
case at::ScalarType::Float: { \
using scalar_t_in = float; \
switch (TYPEOUT) { \
case at::ScalarType::Float: { \
using scalar_t_out = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: { \
using scalar_t_out = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: { \
using scalar_t_out = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPEOUT), "'"); \
} \
break; \
} \
case at::ScalarType::Half: { \
using scalar_t_in = at::Half; \
using scalar_t_out = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::BFloat16: { \
using scalar_t_in = at::BFloat16; \
using scalar_t_out = at::BFloat16; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPEIN), "'"); \
}
// Forward/backward compatibility hack around
// https://github.com/pytorch/pytorch/commit/3aeb78079bcd68282fe9117088e138b77318e288
// pending more future-proof guidance from upstream.
// struct TypeShim
// {
// const at::Type& payload;
// TypeShim(const at::Type& type) : payload(type) {}
// // Enable trivial conversion to a const at::Type& for pre-3aeb78
// operator const at::Type&(){ return payload; };
// // Enable dispatch switch statements to take *this directly for post-3aeb78
// //operator at::ScalarType(){ return payload.; };
// };
#define DISPATCH_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \
switch (TYPE) { \
case at::ScalarType::Float: { \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: { \
using scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_FLOAT_HALF_AND_BYTE(TYPE, LEVEL, NAME, ...) \
switch (TYPE) { \
case at::ScalarType::Float: { \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: { \
using scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Byte: { \
using scalar_t_##LEVEL = uint8_t; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_DOUBLE_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \
switch (TYPE) { \
case at::ScalarType::Double: { \
using scalar_t_##LEVEL = double; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Float: { \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: { \
using scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_DOUBLE_AND_FLOAT(TYPE, LEVEL, NAME, ...) \
switch (TYPE) { \
case at::ScalarType::Double: { \
using scalar_t_##LEVEL = double; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Float: { \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
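// Usage sketch for the dispatch macros above (scale_kernel and scale_half_or_bfloat
// are invented for illustration; only the macro itself comes from this header, and
// the sketch assumes compilation as CUDA). The macro binds `scalar_t` to the
// runtime dtype so one launcher covers both half and bfloat16 tensors:
template <typename T>
__global__ void scale_kernel(const T *in, T *out, float alpha, int64_t n) {
    int64_t i = (int64_t)blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) out[i] = static_cast<T>(static_cast<float>(in[i]) * alpha);
}

void scale_half_or_bfloat(at::Tensor in, at::Tensor out, float alpha) {
    const int64_t n = in.numel();
    const int threads = 256;
    const int blocks = (int)((n + threads - 1) / threads);
    DISPATCH_HALF_AND_BFLOAT(in.scalar_type(), scale_half_or_bfloat,
                             scale_kernel<scalar_t><<<blocks, threads>>>(
                                 (const scalar_t *)in.data_ptr(), (scalar_t *)out.data_ptr(),
                                 alpha, n););
}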
template <typename T>
__device__ __forceinline__ T reduce_block_into_lanes(
    T *x, T val, int lanes = 1,
    bool share_result = false)  // lanes is intended to be <= 32.
{
    int tid = threadIdx.x + threadIdx.y * blockDim.x;
    int blockSize = blockDim.x * blockDim.y;  // blockSize is intended to be a multiple of 32.

    if (blockSize >= 64) {
        x[tid] = val;
        __syncthreads();
    }

#pragma unroll
    for (int i = (blockSize >> 1); i >= 64; i >>= 1) {
        if (tid < i) x[tid] = x[tid] + x[tid + i];
        __syncthreads();
    }

    T final;

    if (tid < 32) {
        if (blockSize >= 64)
            final = x[tid] + x[tid + 32];
        else
            final = val;
        // __SYNCWARP();

#pragma unroll
        for (int i = 16; i >= lanes; i >>= 1)
            final = final + __shfl_down_sync(0xffffffff, final, i);
    }

    if (share_result) {
        if (tid < lanes) x[tid] = final;  // EpilogueOp
        // Make sure the smem result is visible to all warps.
        __syncthreads();
    }

    return final;
}
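// Semantics of reduce_block_into_lanes, for reference: partial values are first
// combined through shared memory down to 64 entries, then warp shuffles fold the
// remainder. For lanes == 1, thread 0 ends up holding the block-wide sum; for
// lanes > 1, each of the first `lanes` threads holds one interleaved partial.
// With share_result == true, the result is also written to x[0..lanes-1] and made
// visible to the whole block by the trailing __syncthreads(). The _max_op variant
// below applies the same scheme to the maximum absolute value instead of the sum.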
template <typename T>
__device__ __forceinline__ T reduce_block_into_lanes_max_op(
    T *x, T val, int lanes = 1,
    bool share_result = false)  // lanes is intended to be <= 32.
{
    int tid = threadIdx.x + threadIdx.y * blockDim.x;
    int blockSize = blockDim.x * blockDim.y;  // blockSize is intended to be a multiple of 32.

    if (blockSize >= 64) {
        x[tid] = val;
        __syncthreads();
    }

#pragma unroll
    for (int i = (blockSize >> 1); i >= 64; i >>= 1) {
        if (tid < i) x[tid] = fmaxf(fabsf(x[tid]), fabsf(x[tid + i]));
        __syncthreads();
    }

    T final;

    if (tid < 32) {
        if (blockSize >= 64)
            final = fmaxf(fabsf(x[tid]), fabsf(x[tid + 32]));
        else
            final = val;
        // __SYNCWARP();

#pragma unroll
        for (int i = 16; i >= lanes; i >>= 1)
            final = fmaxf(fabsf(final), fabsf(__shfl_down_sync(0xffffffff, final, i)));
    }

    if (share_result) {
        if (tid < lanes) x[tid] = final;  // EpilogueOp
        // Make sure the smem result is visible to all warps.
        __syncthreads();
    }

    return final;
}
\ No newline at end of file
fastfold/model/fastnn/kernel/cuda_native/layer_norm.py
0 → 100644
View file @
b14e47f4
import importlib

import torch

fastfold_layer_norm_cuda = importlib.import_module("fastfold_layer_norm_cuda")


class FusedLayerNormAffineFunction(torch.autograd.Function):

    @staticmethod
    def forward(ctx, input, weight, bias, normalized_shape, eps):
        ctx.normalized_shape = normalized_shape
        ctx.eps = eps
        input_ = input.contiguous()
        weight_ = weight.contiguous()
        bias_ = bias.contiguous()
        output, mean, invvar = fastfold_layer_norm_cuda.forward_affine(
            input_, ctx.normalized_shape, weight_, bias_, ctx.eps)
        ctx.save_for_backward(input_, weight_, bias_, mean, invvar)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input_, weight_, bias_, mean, invvar = ctx.saved_tensors
        grad_input = grad_weight = grad_bias = None
        grad_input, grad_weight, grad_bias \
            = fastfold_layer_norm_cuda.backward_affine(grad_output.contiguous(), mean, invvar,
                                                       input_, ctx.normalized_shape, weight_,
                                                       bias_, ctx.eps)
        return grad_input, grad_weight, grad_bias, None, None
fastfold/model/fastnn/kernel/cuda_native/softmax.py
0 → 100644
View file @
b14e47f4
import importlib

fastfold_softmax_cuda = importlib.import_module("fastfold_softmax_cuda")


def softmax_cuda_kernel_wrapper(input_, mask_, bias_, rows, cols):
    if bias_ is not None:
        output = fastfold_softmax_cuda.fused_mask_bias_softmax_forward(input_, mask_, bias_, rows, cols)
    elif mask_ is not None:
        output = fastfold_softmax_cuda.fused_mask_softmax_forward(input_, mask_, rows, cols)
    else:
        output = fastfold_softmax_cuda.forward(input_, rows, cols)

    return output


def softmax_grad_cuda_kernel_wrapper(grad_output, output, mask_, rows, cols):
    if mask_ is not None:
        grad_input = fastfold_softmax_cuda.fused_mask_softmax_backward(grad_output, output, mask_, rows, cols)
    else:
        grad_input = fastfold_softmax_cuda.backward(grad_output, output, rows, cols)

    return grad_input