Unverified Commit 3ff0b8d4 authored by Teddy Do's avatar Teddy Do Committed by GitHub
Browse files

Change Flax MHA to DPA to remove the duplicated QKV projection step (#2429)


Signed-off-by: default avatartdophung <tdophung@nvidia.com>
parent df39a7c2
......@@ -53,7 +53,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 26,
"id": "d5284a38",
"metadata": {},
"outputs": [],
......@@ -67,7 +67,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 27,
"id": "a4d1cfdc",
"metadata": {},
"outputs": [],
......@@ -116,19 +116,33 @@
" qkv = qkv.reshape(qkv.shape[0], qkv.shape[1], self.num_attention_heads, 3 * self.kv_channels)\n",
" q, k, v = jnp.split(qkv, 3, axis=3)\n",
" \n",
" # Reshape to [batch, seq_len, num_heads * head_dim] for Flax MultiHeadDotProductAttention\n",
" q_reshaped = q.reshape(q.shape[0], q.shape[1], self.hidden_size)\n",
" k_reshaped = k.reshape(k.shape[0], k.shape[1], self.hidden_size)\n",
" v_reshaped = v.reshape(v.shape[0], v.shape[1], self.hidden_size)\n",
" # q, k, v now have shape [batch, seq_len, num_heads, kv_channels]\n",
" # which is the correct format for dot_product_attention\n",
" \n",
" # Attention using Flax's MultiHeadDotProductAttention\n",
" attention = nn.MultiHeadDotProductAttention(\n",
" num_heads=self.num_attention_heads,\n",
" qkv_features=self.kv_channels,\n",
" # Apply dot product attention\n",
" # Note: dot_product_attention expects mask to be broadcastable to \n",
" # [batch, num_heads, q_length, kv_length], but attention_mask from \n",
" # nn.make_causal_mask has shape [batch, 1, seq_len, seq_len]\n",
" \n",
" # Generate dropout RNG key when needed (not deterministic and dropout_rate > 0)\n",
" dropout_rng = None\n",
" if not deterministic and self.attention_dropout > 0:\n",
" dropout_rng = self.make_rng('dropout')\n",
" \n",
" x = nn.dot_product_attention(\n",
" query=q,\n",
" key=k,\n",
" value=v,\n",
" mask=attention_mask,\n",
" dropout_rng=dropout_rng,\n",
" dropout_rate=self.attention_dropout,\n",
" deterministic=deterministic,\n",
" broadcast_dropout=True,\n",
" )\n",
" x = attention(q_reshaped, k_reshaped, v_reshaped, mask=attention_mask, deterministic=deterministic)\n",
"\n",
" \n",
" # Reshape output from [batch, seq_len, num_heads, kv_channels] to [batch, seq_len, hidden_size]\n",
" x = x.reshape(x.shape[0], x.shape[1], self.hidden_size)\n",
" \n",
" x = res + x\n",
" \n",
" # Second residual connection\n",
......@@ -157,7 +171,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 28,
"id": "8b44649d",
"metadata": {},
"outputs": [],
......@@ -178,7 +192,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 29,
"id": "e44ed26d",
"metadata": {},
"outputs": [
......@@ -187,7 +201,7 @@
"output_type": "stream",
"text": [
"Pure Flax FlaxTransformerLayer initialized successfully!\n",
"Parameter shapes: {'params': {'Dense_0': {'bias': (12288,), 'kernel': (4096, 12288)}, 'FlaxMLP_0': {'Dense_0': {'bias': (16384,), 'kernel': (4096, 16384)}, 'Dense_1': {'bias': (4096,), 'kernel': (16384, 4096)}}, 'LayerNorm_0': {'bias': (4096,), 'scale': (4096,)}, 'LayerNorm_1': {'bias': (4096,), 'scale': (4096,)}, 'MultiHeadDotProductAttention_0': {'key': {'bias': (32, 4), 'kernel': (4096, 32, 4)}, 'out': {'bias': (4096,), 'kernel': (32, 4, 4096)}, 'query': {'bias': (32, 4), 'kernel': (4096, 32, 4)}, 'value': {'bias': (32, 4), 'kernel': (4096, 32, 4)}}}}\n"
"Parameter shapes: {'params': {'Dense_0': {'bias': (12288,), 'kernel': (4096, 12288)}, 'FlaxMLP_0': {'Dense_0': {'bias': (16384,), 'kernel': (4096, 16384)}, 'Dense_1': {'bias': (4096,), 'kernel': (16384, 4096)}}, 'LayerNorm_0': {'bias': (4096,), 'scale': (4096,)}, 'LayerNorm_1': {'bias': (4096,), 'scale': (4096,)}}}\n"
]
}
],
......@@ -208,7 +222,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 30,
"id": "de91af7a",
"metadata": {},
"outputs": [
......@@ -234,7 +248,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 31,
"id": "037bc8d9",
"metadata": {},
"outputs": [
......@@ -242,7 +256,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Mean time: 17.708301544189453 ms\n"
"Mean time: 18.546080589294434 ms\n"
]
}
],
......@@ -290,7 +304,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 32,
"id": "bed20d6b",
"metadata": {},
"outputs": [],
......@@ -309,7 +323,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 33,
"id": "56105579",
"metadata": {},
"outputs": [],
......@@ -414,7 +428,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 34,
"id": "4b67511f",
"metadata": {},
"outputs": [
......@@ -422,7 +436,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Mean time: 16.505107879638672 ms\n"
"Mean time: 16.375374794006348 ms\n"
]
}
],
......@@ -456,15 +470,39 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 35,
"id": "5146cd99",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/code/github/TransformerEngine/transformer_engine/jax/flax/transformer.py:634: UserWarning: transpose_batch_sequence defaults to False in DotProductAttention starting TransformerEngine v2.10\n",
" warnings.warn(\n",
"/code/github/TransformerEngine/transformer_engine/jax/flax/transformer.py:742: UserWarning: Fused attention is not enabled because there is no available kernel.\n",
"Fall back to the unfused attention.\n",
"Please try to update the cuDNN and TE to the latest version.\n",
"self.dtype=<class 'jax.numpy.float32'>\n",
"qkv_layout=<QKVLayout.BSHD_BSHD_BSHD: <NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD: 9>>\n",
"attn_bias_type=<AttnBiasType.NO_BIAS: <NVTE_Bias_Type.NVTE_NO_BIAS: 0>>\n",
"attn_mask_type=<AttnMaskType.CAUSAL_MASK: <NVTE_Mask_Type.NVTE_CAUSAL_MASK: 2>>\n",
"self.attention_dropout=0.1\n",
"self.num_attention_heads=32\n",
"self.num_gqa_groups=32\n",
"seqlen_q=2048\n",
"seqlen_kv=2048\n",
"head_dim_qk=128\n",
"head_dim_v=128\n",
"\n",
" warnings.warn(\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Mean time: 12.80329704284668 ms\n"
"Mean time: 12.403340339660645 ms\n"
]
}
],
......@@ -515,7 +553,7 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 36,
"id": "c2eee376",
"metadata": {},
"outputs": [],
......@@ -527,7 +565,7 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 37,
"id": "de96827c",
"metadata": {},
"outputs": [
......@@ -535,7 +573,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Mean time: 9.615030288696289 ms\n"
"Mean time: 9.396424293518066 ms\n"
]
}
],
......@@ -588,7 +626,7 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 38,
"id": "11203785",
"metadata": {},
"outputs": [],
......@@ -659,7 +697,7 @@
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": 39,
"id": "6b0c705e",
"metadata": {},
"outputs": [
......@@ -667,7 +705,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Mean time: 9.331779479980469 ms\n"
"Mean time: 9.145426750183105 ms\n"
]
}
],
......@@ -704,10 +742,33 @@
},
{
"cell_type": "code",
"execution_count": 16,
"execution_count": 40,
"id": "b2aaa8ef",
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/code/github/TransformerEngine/transformer_engine/jax/flax/transformer.py:742: UserWarning: Fused attention is not enabled because there is no available kernel.\n",
"Fall back to the unfused attention.\n",
"Please try to update the cuDNN and TE to the latest version.\n",
"self.dtype=<class 'jax.numpy.float32'>\n",
"qkv_layout=<QKVLayout.BS3HD: <NVTE_QKV_Layout.NVTE_BS3HD: 5>>\n",
"attn_bias_type=<AttnBiasType.NO_BIAS: <NVTE_Bias_Type.NVTE_NO_BIAS: 0>>\n",
"attn_mask_type=<AttnMaskType.CAUSAL_MASK: <NVTE_Mask_Type.NVTE_CAUSAL_MASK: 2>>\n",
"self.attention_dropout=0.1\n",
"self.num_attention_heads=32\n",
"self.num_gqa_groups=32\n",
"seqlen_q=2048\n",
"seqlen_kv=2048\n",
"head_dim_qk=128\n",
"head_dim_v=128\n",
"\n",
" warnings.warn(\n"
]
}
],
"source": [
"\n",
"te_transformer = te_flax.TransformerLayer(\n",
......@@ -731,7 +792,7 @@
},
{
"cell_type": "code",
"execution_count": 17,
"execution_count": 41,
"id": "b9cdbf22",
"metadata": {},
"outputs": [
......@@ -739,7 +800,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Mean time: 9.23741340637207 ms\n"
"Mean time: 9.020795822143555 ms\n"
]
}
],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment