Commit 7fb12cf5 authored by Christina Floristean's avatar Christina Floristean
Browse files

Update setup script and refactor qkv prep

parent 54d414e4
......@@ -379,7 +379,6 @@ class Attention(nn.Module):
def _prep_qkv(self,
q_x: torch.Tensor,
kv_x: torch.Tensor,
transpose_qkv_dims: bool = True,
apply_scale: bool = True
) -> Tuple[
torch.Tensor, torch.Tensor, torch.Tensor
......@@ -394,11 +393,10 @@ class Attention(nn.Module):
k = k.view(k.shape[:-1] + (self.no_heads, -1))
v = v.view(v.shape[:-1] + (self.no_heads, -1))
if transpose_qkv_dims:
# [*, H, Q/K, C_hidden]
q = q.transpose(-2, -3)
k = k.transpose(-2, -3)
v = v.transpose(-2, -3)
# [*, H, Q/K, C_hidden]
q = q.transpose(-2, -3)
k = k.transpose(-2, -3)
v = v.transpose(-2, -3)
if apply_scale:
q /= math.sqrt(self.c_hidden)
......@@ -486,10 +484,8 @@ class Attention(nn.Module):
if biases is None:
biases = []
# DeepSpeed attention kernel expects Q/K/V of shape [*, Q/K, H, C_hidden]
# All other attention modules expect Q/K/V of shape [*, H, Q/K, C_hidden]
# DeepSpeed attention kernel applies scaling internally
q, k, v = self._prep_qkv(q_x, kv_x,
transpose_qkv_dims=not use_deepspeed_evo_attention,
apply_scale=not use_deepspeed_evo_attention)
if is_fp16_enabled():
......@@ -629,11 +625,11 @@ def _deepspeed_evo_attn(
Args:
q:
[*, Q, H, C_hidden] query data
[*, H, Q, C_hidden] query data
k:
[*, K, H, C_hidden] key data
[*, H, K, C_hidden] key data
v:
[*, V, H, C_hidden] value data
[*, H, V, C_hidden] value data
biases:
List of biases that broadcast to [*, H, Q, K]
"""
......@@ -652,6 +648,11 @@ def _deepspeed_evo_attn(
return x.reshape(*((x.shape[0], -1) + x.shape[-3:]))
return x
# [*, Q/K, H, C_hidden]
q = q.transpose(-2, -3)
k = k.transpose(-2, -3)
v = v.transpose(-2, -3)
# Reshape tensors to match expected input shape [B, N, Q/K, H, C_hidden]
# for DS4Sci_EvoformerAttention() by adding or flattening batch dims as needed.
orig_shape = q.shape
......
......@@ -13,9 +13,7 @@ gunzip -c tests/test_data/sample_feats.pickle.gz > tests/test_data/sample_feats.
python setup.py install
export LD_LIBRARY_PATH=$CONDA_PREFIX/lib:$LD_LIBRARY_PATH
echo "Attempting to download CUTLASS, required for Deepspeed Evoformer attention kernel"
echo "Download CUTLASS, required for Deepspeed Evoformer attention kernel"
git clone https://github.com/NVIDIA/cutlass --depth 1
conda env config vars set CUTLASS_PATH=$PWD/cutlass
......@@ -24,3 +22,6 @@ conda env config vars set KMP_AFFINITY=none
# Reactivate env so that the above environment variables take effect
conda activate $CONDA_PREFIX
export LIBRARY_PATH=$CONDA_PREFIX/lib:$LIBRARY_PATH
export LD_LIBRARY_PATH=$CONDA_PREFIX/lib:$LD_LIBRARY_PATH
......@@ -155,7 +155,7 @@ class TestDeepSpeedKernel(unittest.TestCase):
n_seq = 18
c_m_shape = (consts.c_m,)
c_z_shape = (consts.c_z,)
eps = 2e-2
eps = 5e-2
activations = {
"msa": torch.rand(n_seq, n_res, consts.c_m, device='cuda', dtype=dtype),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment