"...models/git@developer.sourcefind.cn:OpenDAS/openpcdet.git" did not exist on "69326cef7a4c780524c090ff1e6739ec906ea79d"
Commit eb608db9 authored by Christina Floristean

Minor refactoring of ds kernel integration

parent 0a6230a3
@@ -14,15 +14,16 @@
 # limitations under the License.
 import importlib
 import math
-from typing import Optional, Callable, List, Tuple, Sequence
+from typing import Optional, Callable, List, Tuple
 
 import numpy as np
 
 deepspeed_is_installed = importlib.util.find_spec("deepspeed") is not None
+ds4s_is_installed = deepspeed_is_installed and importlib.util.find_spec("deepspeed.ops.deepspeed4science") is not None
 if deepspeed_is_installed:
     import deepspeed
-    if importlib.util.find_spec("deepspeed.ops.deepspeed4science") is not None:
-        from deepspeed.ops.deepspeed4science import DS4Sci_EvoformerAttention
 
+if ds4s_is_installed:
+    from deepspeed.ops.deepspeed4science import DS4Sci_EvoformerAttention
 fa_is_installed = importlib.util.find_spec("flash_attn") is not None
 if fa_is_installed:
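
Note: this hunk hoists the deepspeed4science availability check into a module-level ds4s_is_installed flag instead of probing inside the deepspeed import block. A minimal standalone sketch of the same optional-dependency pattern, with a hypothetical package name:

import importlib.util

# find_spec() returns None when a package is absent, so the probe can run at
# import time without raising ModuleNotFoundError.
fast_ops_is_installed = importlib.util.find_spec("fast_ops") is not None  # "fast_ops" is hypothetical
if fast_ops_is_installed:
    import fast_ops  # imported only when the probe succeeded
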
@@ -504,39 +505,7 @@ class Attention(nn.Module):
                     "If use_deepspeed_evo_attention is True, you may only "
                     "provide up to two bias terms"
                 )
-            orig_shape = q.shape
-            no_batch_dims = len(orig_shape[:-3])
-            if no_batch_dims > 2:
-                raise ValueError(
-                    f"Q is of shape {list(orig_shape)} but must be "
-                    "of shape [B, N, Q/K, H, C_hidden] if "
-                    "use_deepspeed_evo_attention is True."
-                )
-
-            # Bypass asserts for bias shapes in DS4Sci_EvoformerAttention()
-            # by adding batch and N_seq dims if needed.
-            if no_batch_dims < 2:
-                addl_dims = (1,) * (2 - no_batch_dims)
-                q = q.view(*(addl_dims + q.shape))
-                k = k.view(*(addl_dims + k.shape))
-                v = v.view(*(addl_dims + v.shape))
-                biases = [b.view(*(addl_dims + b.shape)) for b in biases]
-
-            # DeepSpeed attn. kernel requires inputs to be type bf16 or fp16
-            # Cast to bf16 so kernel can be used during inference
-            orig_dtype = q.dtype
-            if orig_dtype not in [torch.bfloat16, torch.float16]:
-                o = DS4Sci_EvoformerAttention(q.to(dtype=torch.bfloat16),
-                                              k.to(dtype=torch.bfloat16),
-                                              v.to(dtype=torch.bfloat16),
-                                              [b.to(dtype=torch.bfloat16) for b in biases])
-                o = o.to(dtype=orig_dtype)
-            else:
-                o = DS4Sci_EvoformerAttention(q, k, v, biases)
-
-            o = o.view(orig_shape)
+            o = _deepspeed_evo_attn(q, k, v, biases)
         elif use_lma:
             biases = [
                 b.expand(b.shape[:-2] + (q_x.shape[-2],) + (kv_x.shape[-2],))
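
The logic removed here reappears in the new _deepspeed_evo_attn helper below; only the call site remains in Attention.forward. As a standalone sketch of the cast-and-restore pattern the helper preserves (wrapper name and kernel argument are hypothetical):

import torch

def run_half_precision_kernel(x, kernel):
    # Hypothetical wrapper: a kernel that only accepts bf16/fp16 can still be
    # called from fp32 code by casting down, running, and casting back up.
    orig_dtype = x.dtype
    if orig_dtype not in (torch.bfloat16, torch.float16):
        return kernel(x.to(dtype=torch.bfloat16)).to(dtype=orig_dtype)
    return kernel(x)
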
@@ -644,6 +613,67 @@ class GlobalAttention(nn.Module):
         return m
 
 
+@torch.jit.ignore
+def _deepspeed_evo_attn(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    biases: List[torch.Tensor],
+):
+    """
+    Compute attention using the DeepSpeed DS4Sci_EvoformerAttention kernel.
+
+    Args:
+        q:
+            [*, Q, H, C_hidden] query data
+        k:
+            [*, K, H, C_hidden] key data
+        v:
+            [*, V, H, C_hidden] value data
+        biases:
+            List of biases that broadcast to [*, H, Q, K]
+    Returns:
+        [*, Q, H, C_hidden] attention output
+    """
+    if not ds4s_is_installed:
+        raise ValueError(
+            "_deepspeed_evo_attn requires that DeepSpeed be installed "
+            "and that the deepspeed.ops.deepspeed4science package exists"
+        )
+
+    def reshape_dims(x):
+        no_batch_dims = len(x.shape[:-3])
+        if no_batch_dims < 2:
+            return x.reshape(*((1,) * (2 - no_batch_dims) + x.shape))
+        if no_batch_dims > 2:
+            return x.reshape(*((x.shape[0], -1) + x.shape[-3:]))
+        return x
+
+    # Reshape tensors to match expected input shape [B, N, Q/K, H, C_hidden]
+    # for DS4Sci_EvoformerAttention() by adding or flattening batch dims as needed.
+    orig_shape = q.shape
+    if len(orig_shape[:-3]) != 2:
+        q = reshape_dims(q)
+        k = reshape_dims(k)
+        v = reshape_dims(v)
+        biases = [reshape_dims(b) for b in biases]
+
+    # DeepSpeed attn. kernel requires inputs to be type bf16 or fp16.
+    # Cast to bf16 so the kernel can be used during inference.
+    orig_dtype = q.dtype
+    if orig_dtype not in [torch.bfloat16, torch.float16]:
+        o = DS4Sci_EvoformerAttention(q.to(dtype=torch.bfloat16),
+                                      k.to(dtype=torch.bfloat16),
+                                      v.to(dtype=torch.bfloat16),
+                                      [b.to(dtype=torch.bfloat16) for b in biases])
+        o = o.to(dtype=orig_dtype)
+    else:
+        o = DS4Sci_EvoformerAttention(q, k, v, biases)
+
+    o = o.reshape(orig_shape)
+    return o
+
+
 def _lma(
     q: torch.Tensor,
     k: torch.Tensor,
......
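
For illustration, a hedged sketch of how the new helper might be called; the shapes, device, and bias layouts below are assumptions chosen to match the docstring, not part of this diff:

import torch

# Hypothetical Evoformer-style inputs: batch 1, N_seq 128, N_res 256,
# 8 heads, 32 hidden channels per head -> [B, N, Q/K, H, C_hidden].
q = torch.randn(1, 128, 256, 8, 32, device="cuda")
k = torch.randn(1, 128, 256, 8, 32, device="cuda")
v = torch.randn(1, 128, 256, 8, 32, device="cuda")

# Up to two bias terms, each broadcastable to [B, N, H, Q, K].
mask_bias = torch.zeros(1, 128, 1, 1, 256, device="cuda")
pair_bias = torch.zeros(1, 1, 8, 256, 256, device="cuda")

# The helper adds or flattens batch dims, casts fp32 inputs to bf16 for the
# kernel, and restores the original shape and dtype on the way out.
o = _deepspeed_evo_attn(q, k, v, [mask_bias, pair_bias])
assert o.shape == q.shape and o.dtype == q.dtype
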