"git@developer.sourcefind.cn:OpenDAS/dynamo.git" did not exist on "d346782c7c4bcbbb5ad0fdf0deeace42506ce887"
Commit eb608db9 authored by Christina Floristean

Minor refactoring of ds kernel integration

parent 0a6230a3
@@ -14,15 +14,16 @@
 # limitations under the License.
 import importlib
 import math
-from typing import Optional, Callable, List, Tuple, Sequence
+from typing import Optional, Callable, List, Tuple

 import numpy as np

 deepspeed_is_installed = importlib.util.find_spec("deepspeed") is not None
+ds4s_is_installed = deepspeed_is_installed and importlib.util.find_spec("deepspeed.ops.deepspeed4science") is not None
 if deepspeed_is_installed:
     import deepspeed
-    if importlib.util.find_spec("deepspeed.ops.deepspeed4science") is not None:
+    if ds4s_is_installed:
         from deepspeed.ops.deepspeed4science import DS4Sci_EvoformerAttention

 fa_is_installed = importlib.util.find_spec("flash_attn") is not None
 if fa_is_installed:
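
This hunk hoists the deepspeed4science availability check into a module-level `ds4s_is_installed` flag, so the same test gates both the import and the new helper introduced below. A minimal standalone sketch of this feature-detection pattern (illustrative, not a copy of the file):

```python
# Detect optional dependencies without importing them eagerly.
# The module names match the ones used in the commit; the snippet
# itself is a sketch rather than the file's exact code.
import importlib.util

deepspeed_is_installed = importlib.util.find_spec("deepspeed") is not None
ds4s_is_installed = (
    deepspeed_is_installed
    # Only probe the submodule if deepspeed itself resolves.
    and importlib.util.find_spec("deepspeed.ops.deepspeed4science") is not None
)

if ds4s_is_installed:
    from deepspeed.ops.deepspeed4science import DS4Sci_EvoformerAttention
```

Checking the flag once at import time keeps the per-call path cheap and lets the helper raise a clear error when the kernel is unavailable.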
@@ -504,39 +505,7 @@ class Attention(nn.Module):
                     "If use_deepspeed_evo_attention is True, you may only "
                     "provide up to two bias terms"
                 )
-            orig_shape = q.shape
-            no_batch_dims = len(orig_shape[:-3])
-            if no_batch_dims > 2:
-                raise ValueError(
-                    f"Q is of shape {list(orig_shape)} but must be "
-                    "of shape [B, N, Q/K, H, C_hidden] if "
-                    "use_deepspeed_evo_attention is True."
-                )
-
-            # Bypass asserts for bias shapes in DS4Sci_EvoformerAttention()
-            # by adding batch and N_seq dims if needed.
-            if no_batch_dims < 2:
-                addl_dims = (1,) * (2 - no_batch_dims)
-                q = q.view(*(addl_dims + q.shape))
-                k = k.view(*(addl_dims + k.shape))
-                v = v.view(*(addl_dims + v.shape))
-                biases = [b.view(*(addl_dims + b.shape)) for b in biases]
-
-            # DeepSpeed attn. kernel requires inputs to be type bf16 or fp16
-            # Cast to bf16 so kernel can be used during inference
-            orig_dtype = q.dtype
-            if orig_dtype not in [torch.bfloat16, torch.float16]:
-                o = DS4Sci_EvoformerAttention(q.to(dtype=torch.bfloat16),
-                                              k.to(dtype=torch.bfloat16),
-                                              v.to(dtype=torch.bfloat16),
-                                              [b.to(dtype=torch.bfloat16) for b in biases])
-                o = o.to(dtype=orig_dtype)
-            else:
-                o = DS4Sci_EvoformerAttention(q, k, v, biases)
-            o = o.view(orig_shape)
+            o = _deepspeed_evo_attn(q, k, v, biases)
         elif use_lma:
             biases = [
                 b.expand(b.shape[:-2] + (q_x.shape[-2],) + (kv_x.shape[-2],))
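
With the inline kernel logic extracted, the forward path reduces to a guard on the bias count plus a single call to `_deepspeed_evo_attn`. In Evoformer-style attention, the two permitted bias terms are typically a mask bias and a pair bias, both broadcastable to [*, H, Q, K]. The sketch below shows shapes a caller might construct; all names and dimensions are illustrative assumptions, not taken from this commit:

```python
import torch

# Hypothetical dimensions: batch, MSA rows, residues, heads, head width.
B, N_seq, N_res, H, C_hidden = 1, 8, 64, 4, 16

# Mask bias: large negative values at masked key positions, broadcasting
# over the head and query dims -> [B, N_seq, 1, 1, N_res].
mask = torch.ones(B, N_seq, N_res)
mask_bias = (1e9 * (mask - 1))[..., None, None, :]

# Pair bias: per-head bias over residue pairs, broadcasting over the
# N_seq dim -> [B, 1, H, N_res, N_res].
pair_bias = torch.randn(B, 1, H, N_res, N_res)

# At most two bias terms are allowed on the DeepSpeed path.
biases = [mask_bias, pair_bias]
```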
@@ -644,6 +613,67 @@ class GlobalAttention(nn.Module):
         return m

+@torch.jit.ignore
+def _deepspeed_evo_attn(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    biases: List[torch.Tensor],
+):
+    """
+    Compute attention using the DeepSpeed DS4Sci_EvoformerAttention kernel.
+
+    Args:
+        q:
+            [*, Q, H, C_hidden] query data
+        k:
+            [*, K, H, C_hidden] key data
+        v:
+            [*, V, H, C_hidden] value data
+        biases:
+            List of biases that broadcast to [*, H, Q, K]
+
+    Returns:
+        [*, Q, H, C_hidden] attention output
+    """
+    if not ds4s_is_installed:
+        raise ValueError(
+            "_deepspeed_evo_attn requires that DeepSpeed be installed "
+            "and that the deepspeed.ops.deepspeed4science package exists"
+        )
+
+    def reshape_dims(x):
+        no_batch_dims = len(x.shape[:-3])
+        if no_batch_dims < 2:
+            return x.reshape(*((1,) * (2 - no_batch_dims) + x.shape))
+        if no_batch_dims > 2:
+            return x.reshape(*((x.shape[0], -1) + x.shape[-3:]))
+        return x
+
+    # Reshape tensors to match the expected input shape [B, N, Q/K, H, C_hidden]
+    # for DS4Sci_EvoformerAttention() by adding or flattening batch dims as needed.
+    orig_shape = q.shape
+    if len(orig_shape[:-3]) != 2:
+        q = reshape_dims(q)
+        k = reshape_dims(k)
+        v = reshape_dims(v)
+        biases = [reshape_dims(b) for b in biases]
+
+    # The DeepSpeed attn. kernel requires inputs to be of type bf16 or fp16.
+    # Cast to bf16 so the kernel can be used during inference.
+    orig_dtype = q.dtype
+    if orig_dtype not in [torch.bfloat16, torch.float16]:
+        o = DS4Sci_EvoformerAttention(q.to(dtype=torch.bfloat16),
+                                      k.to(dtype=torch.bfloat16),
+                                      v.to(dtype=torch.bfloat16),
+                                      [b.to(dtype=torch.bfloat16) for b in biases])
+        o = o.to(dtype=orig_dtype)
+    else:
+        o = DS4Sci_EvoformerAttention(q, k, v, biases)
+
+    o = o.reshape(orig_shape)
+    return o

 def _lma(
     q: torch.Tensor,
     k: torch.Tensor,
...
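
The new helper normalizes inputs to the kernel's expected [B, N, Q/K, H, C_hidden] layout: tensors with fewer than two leading batch dims gain singleton dims, and tensors with more have the extra dims flattened into N, with the caller's original shape restored on the way out. A small, kernel-free sketch of just the reshape behavior (the helper mirrors the one above; the demo tensors are assumptions):

```python
import torch

def reshape_dims(x):
    # Pad to, or flatten into, exactly two leading batch dims,
    # leaving the trailing [Q/K, H, C_hidden] dims untouched.
    no_batch_dims = len(x.shape[:-3])
    if no_batch_dims < 2:
        return x.reshape(*((1,) * (2 - no_batch_dims) + tuple(x.shape)))
    if no_batch_dims > 2:
        return x.reshape(*((x.shape[0], -1) + tuple(x.shape[-3:])))
    return x

q3 = torch.randn(64, 4, 16)           # [Q, H, C_hidden]: no batch dims
q5 = torch.randn(2, 3, 5, 64, 4, 16)  # three leading batch dims

print(reshape_dims(q3).shape)  # torch.Size([1, 1, 64, 4, 16])
print(reshape_dims(q5).shape)  # torch.Size([2, 15, 64, 4, 16])
```

Because `orig_shape` is saved before any reshaping and the output is restored with `o.reshape(orig_shape)`, callers see the same layout they passed in regardless of how many batch dims they used.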