"git@developer.sourcefind.cn:OpenDAS/dynamo.git" did not exist on "d57847b2a95ad4f92e816178feaed72dc06d1059"
Commit 7fb12cf5 authored by Christina Floristean's avatar Christina Floristean
Browse files

Update setup script and refactor qkv prep

parent 54d414e4
...@@ -379,7 +379,6 @@ class Attention(nn.Module): ...@@ -379,7 +379,6 @@ class Attention(nn.Module):
def _prep_qkv(self, def _prep_qkv(self,
q_x: torch.Tensor, q_x: torch.Tensor,
kv_x: torch.Tensor, kv_x: torch.Tensor,
transpose_qkv_dims: bool = True,
apply_scale: bool = True apply_scale: bool = True
) -> Tuple[ ) -> Tuple[
torch.Tensor, torch.Tensor, torch.Tensor torch.Tensor, torch.Tensor, torch.Tensor
...@@ -394,11 +393,10 @@ class Attention(nn.Module): ...@@ -394,11 +393,10 @@ class Attention(nn.Module):
k = k.view(k.shape[:-1] + (self.no_heads, -1)) k = k.view(k.shape[:-1] + (self.no_heads, -1))
v = v.view(v.shape[:-1] + (self.no_heads, -1)) v = v.view(v.shape[:-1] + (self.no_heads, -1))
if transpose_qkv_dims: # [*, H, Q/K, C_hidden]
# [*, H, Q/K, C_hidden] q = q.transpose(-2, -3)
q = q.transpose(-2, -3) k = k.transpose(-2, -3)
k = k.transpose(-2, -3) v = v.transpose(-2, -3)
v = v.transpose(-2, -3)
if apply_scale: if apply_scale:
q /= math.sqrt(self.c_hidden) q /= math.sqrt(self.c_hidden)
...@@ -486,10 +484,8 @@ class Attention(nn.Module): ...@@ -486,10 +484,8 @@ class Attention(nn.Module):
if biases is None: if biases is None:
biases = [] biases = []
# DeepSpeed attention kernel expects Q/K/V of shape [*, Q/K, H, C_hidden] # DeepSpeed attention kernel applies scaling internally
# All other attention modules expect Q/K/V of shape [*, H, Q/K, C_hidden]
q, k, v = self._prep_qkv(q_x, kv_x, q, k, v = self._prep_qkv(q_x, kv_x,
transpose_qkv_dims=not use_deepspeed_evo_attention,
apply_scale=not use_deepspeed_evo_attention) apply_scale=not use_deepspeed_evo_attention)
if is_fp16_enabled(): if is_fp16_enabled():
...@@ -629,11 +625,11 @@ def _deepspeed_evo_attn( ...@@ -629,11 +625,11 @@ def _deepspeed_evo_attn(
Args: Args:
q: q:
[*, Q, H, C_hidden] query data [*, H, Q, C_hidden] query data
k: k:
[*, K, H, C_hidden] key data [*, H, K, C_hidden] key data
v: v:
[*, V, H, C_hidden] value data [*, H, V, C_hidden] value data
biases: biases:
List of biases that broadcast to [*, H, Q, K] List of biases that broadcast to [*, H, Q, K]
""" """
...@@ -652,6 +648,11 @@ def _deepspeed_evo_attn( ...@@ -652,6 +648,11 @@ def _deepspeed_evo_attn(
return x.reshape(*((x.shape[0], -1) + x.shape[-3:])) return x.reshape(*((x.shape[0], -1) + x.shape[-3:]))
return x return x
# [*, Q/K, H, C_hidden]
q = q.transpose(-2, -3)
k = k.transpose(-2, -3)
v = v.transpose(-2, -3)
# Reshape tensors to match expected input shape [B, N, Q/K, H, C_hidden] # Reshape tensors to match expected input shape [B, N, Q/K, H, C_hidden]
# for DS4Sci_EvoformerAttention() by adding or flattening batch dims as needed. # for DS4Sci_EvoformerAttention() by adding or flattening batch dims as needed.
orig_shape = q.shape orig_shape = q.shape
......
...@@ -13,9 +13,7 @@ gunzip -c tests/test_data/sample_feats.pickle.gz > tests/test_data/sample_feats. ...@@ -13,9 +13,7 @@ gunzip -c tests/test_data/sample_feats.pickle.gz > tests/test_data/sample_feats.
python setup.py install python setup.py install
export LD_LIBRARY_PATH=$CONDA_PREFIX/lib:$LD_LIBRARY_PATH echo "Download CUTLASS, required for Deepspeed Evoformer attention kernel"
echo "Attempting to download CUTLASS, required for Deepspeed Evoformer attention kernel"
git clone https://github.com/NVIDIA/cutlass --depth 1 git clone https://github.com/NVIDIA/cutlass --depth 1
conda env config vars set CUTLASS_PATH=$PWD/cutlass conda env config vars set CUTLASS_PATH=$PWD/cutlass
...@@ -24,3 +22,6 @@ conda env config vars set KMP_AFFINITY=none ...@@ -24,3 +22,6 @@ conda env config vars set KMP_AFFINITY=none
# Reactivate env so that the above environment variables take effect # Reactivate env so that the above environment variables take effect
conda activate $CONDA_PREFIX conda activate $CONDA_PREFIX
export LIBRARY_PATH=$CONDA_PREFIX/lib:$LIBRARY_PATH
export LD_LIBRARY_PATH=$CONDA_PREFIX/lib:$LD_LIBRARY_PATH
...@@ -155,7 +155,7 @@ class TestDeepSpeedKernel(unittest.TestCase): ...@@ -155,7 +155,7 @@ class TestDeepSpeedKernel(unittest.TestCase):
n_seq = 18 n_seq = 18
c_m_shape = (consts.c_m,) c_m_shape = (consts.c_m,)
c_z_shape = (consts.c_z,) c_z_shape = (consts.c_z,)
eps = 2e-2 eps = 5e-2
activations = { activations = {
"msa": torch.rand(n_seq, n_res, consts.c_m, device='cuda', dtype=dtype), "msa": torch.rand(n_seq, n_res, consts.c_m, device='cuda', dtype=dtype),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment