OpenDAS / OpenFold · Commits
"...git@developer.sourcefind.cn:OpenDAS/openpcdet.git" did not exist on "c23c42084594e27147620bb0ef124fe24ca36061"
Commit eb608db9 authored Sep 13, 2023 by Christina Floristean

    Minor refactoring of ds kernel integration

Parent: 0a6230a3

Showing 1 changed file with 66 additions and 36 deletions:

openfold/model/primitives.py (+66, -36)
@@ -14,14 +14,15 @@
 # limitations under the License.
 import importlib
 import math
-from typing import Optional, Callable, List, Tuple, Sequence
+from typing import Optional, Callable, List, Tuple
 import numpy as np

 deepspeed_is_installed = importlib.util.find_spec("deepspeed") is not None
+ds4s_is_installed = deepspeed_is_installed and importlib.util.find_spec("deepspeed.ops.deepspeed4science") is not None
 if deepspeed_is_installed:
     import deepspeed

-if importlib.util.find_spec("deepspeed.ops.deepspeed4science") is not None:
+if ds4s_is_installed:
     from deepspeed.ops.deepspeed4science import DS4Sci_EvoformerAttention

 fa_is_installed = importlib.util.find_spec("flash_attn") is not None
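The guard variables above use `importlib.util.find_spec` to probe for optional dependencies without importing them. A minimal standalone sketch of the same idiom (the probed package names here are illustrative, not part of the commit):

import importlib.util

# find_spec returns None when a module cannot be located, so the
# comparison yields a plain bool without importing the package.
numpy_is_installed = importlib.util.find_spec("numpy") is not None

# Probing a dotted submodule imports its parent package as a side
# effect, which is presumably why the commit short-circuits the
# "deepspeed.ops.deepspeed4science" probe behind deepspeed_is_installed.
numpy_linalg_is_installed = (
    numpy_is_installed
    and importlib.util.find_spec("numpy.linalg") is not None
)
print(numpy_is_installed, numpy_linalg_is_installed)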
@@ -504,39 +505,7 @@ class Attention(nn.Module):
                     "If use_deepspeed_evo_attention is True, you may only "
                     "provide up to two bias terms"
                 )
-            orig_shape = q.shape
-            no_batch_dims = len(orig_shape[:-3])
-            if no_batch_dims > 2:
-                raise ValueError(
-                    f"Q is of shape {list(orig_shape)} but must be "
-                    "of shape [B, N, Q/K, H, C_hidden] if "
-                    "use_deepspeed_evo_attention is True."
-                )
-            # Bypass asserts for bias shapes in DS4Sci_EvoformerAttention()
-            # by adding batch and N_seq dims if needed.
-            if no_batch_dims < 2:
-                addl_dims = (1,) * (2 - no_batch_dims)
-                q = q.view(*(addl_dims + q.shape))
-                k = k.view(*(addl_dims + k.shape))
-                v = v.view(*(addl_dims + v.shape))
-                biases = [b.view(*(addl_dims + b.shape)) for b in biases]
-            # DeepSpeed attn. kernel requires inputs to be type bf16 or fp16
-            # Cast to bf16 so kernel can be used during inference
-            orig_dtype = q.dtype
-            if orig_dtype not in [torch.bfloat16, torch.float16]:
-                o = DS4Sci_EvoformerAttention(
-                    q.to(dtype=torch.bfloat16),
-                    k.to(dtype=torch.bfloat16),
-                    v.to(dtype=torch.bfloat16),
-                    [b.to(dtype=torch.bfloat16) for b in biases],
-                )
-                o = o.to(dtype=orig_dtype)
-            else:
-                o = DS4Sci_EvoformerAttention(q, k, v, biases)
-            o = o.view(orig_shape)
+            o = _deepspeed_evo_attn(q, k, v, biases)
         elif use_lma:
             biases = [
                 b.expand(b.shape[:-2] + (q_x.shape[-2],) + (kv_x.shape[-2],))
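After this change, the DeepSpeed branch of the forward pass simply dispatches to the new helper; the shape and dtype handling that used to live inline is applied inside `_deepspeed_evo_attn` (added below). A hypothetical usage sketch of this code path, assuming the `Attention(c_q, c_k, c_v, c_hidden, no_heads)` constructor from elsewhere in primitives.py and a CUDA machine with DeepSpeed's DS4Sci kernels installed:

import torch
from openfold.model.primitives import Attention

# Constructor arguments are assumed from the rest of primitives.py; only
# the use_deepspeed_evo_attention flag and the biases list are confirmed
# by this diff.
attn = Attention(c_q=64, c_k=64, c_v=64, c_hidden=32, no_heads=8).cuda()

q_x = torch.randn(1, 128, 256, 64, device="cuda")   # [B, N, Q, c_q]
kv_x = torch.randn(1, 128, 256, 64, device="cuda")  # [B, N, K, c_k]
# A bias that broadcasts to [*, H, Q, K]; at most two bias terms may be
# passed on this code path (see the ValueError above).
bias = torch.randn(1, 128, 8, 256, 256, device="cuda")

o = attn(q_x, kv_x, biases=[bias], use_deepspeed_evo_attention=True)
print(o.shape)  # expected: torch.Size([1, 128, 256, 64])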
@@ -644,6 +613,67 @@ class GlobalAttention(nn.Module):
         return m


+@torch.jit.ignore
+def _deepspeed_evo_attn(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    biases: List[torch.Tensor],
+):
+    """
+    Compute attention using the DeepSpeed DS4Sci_EvoformerAttention kernel.
+
+    Args:
+        q:
+            [*, Q, H, C_hidden] query data
+        k:
+            [*, K, H, C_hidden] key data
+        v:
+            [*, V, H, C_hidden] value data
+        biases:
+            List of biases that broadcast to [*, H, Q, K]
+    """
+    if not ds4s_is_installed:
+        raise ValueError(
+            "_deepspeed_evo_attn requires that DeepSpeed be installed "
+            "and that the deepspeed.ops.deepspeed4science package exists"
+        )
+
+    def reshape_dims(x):
+        no_batch_dims = len(x.shape[:-3])
+        if no_batch_dims < 2:
+            return x.reshape(*((1,) * (2 - no_batch_dims) + x.shape))
+        if no_batch_dims > 2:
+            return x.reshape(*((x.shape[0], -1) + x.shape[-3:]))
+        return x
+
+    # Reshape tensors to match expected input shape [B, N, Q/K, H, C_hidden]
+    # for DS4Sci_EvoformerAttention() by adding or flattening batch dims as needed.
+    orig_shape = q.shape
+    if len(orig_shape[:-3]) != 2:
+        q = reshape_dims(q)
+        k = reshape_dims(k)
+        v = reshape_dims(v)
+        biases = [reshape_dims(b) for b in biases]
+
+    # DeepSpeed attn. kernel requires inputs to be type bf16 or fp16
+    # Cast to bf16 so kernel can be used during inference
+    orig_dtype = q.dtype
+    if orig_dtype not in [torch.bfloat16, torch.float16]:
+        o = DS4Sci_EvoformerAttention(
+            q.to(dtype=torch.bfloat16),
+            k.to(dtype=torch.bfloat16),
+            v.to(dtype=torch.bfloat16),
+            [b.to(dtype=torch.bfloat16) for b in biases],
+        )
+        o = o.to(dtype=orig_dtype)
+    else:
+        o = DS4Sci_EvoformerAttention(q, k, v, biases)
+
+    o = o.reshape(orig_shape)
+    return o
+
+
 def _lma(
     q: torch.Tensor,
     k: torch.Tensor,
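The `reshape_dims` helper normalizes any number of leading batch dimensions to exactly the two ([B, N]) that the kernel expects, padding with singleton dims or flattening extras, and `o.reshape(orig_shape)` restores them afterwards. The bf16 round-trip (`.to(torch.bfloat16)` in, `.to(orig_dtype)` out) is what lets fp32 inference use the kernel, at the cost of reduced precision inside the attention computation. A small self-contained check of the reshaping behavior (the helper is copied verbatim from the diff; the tensor sizes are arbitrary):

import torch

def reshape_dims(x):
    # Pad with leading singleton dims (fewer than two batch dims) or
    # flatten all trailing batch dims into one (more than two), so the
    # result always has shape [B, N, Q/K, H, C_hidden].
    no_batch_dims = len(x.shape[:-3])
    if no_batch_dims < 2:
        return x.reshape(*((1,) * (2 - no_batch_dims) + x.shape))
    if no_batch_dims > 2:
        return x.reshape(*((x.shape[0], -1) + x.shape[-3:]))
    return x

q = torch.randn(128, 64, 8, 32)        # one batch dim: [N, Q, H, C_hidden]
print(reshape_dims(q).shape)           # torch.Size([1, 128, 64, 8, 32])

q = torch.randn(4, 2, 128, 64, 8, 32)  # three batch dims
print(reshape_dims(q).shape)           # torch.Size([4, 256, 64, 8, 32])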