Unverified Commit 3ff0b8d4 authored by Teddy Do's avatar Teddy Do Committed by GitHub
Browse files

Change Flax MHA to DPA to remove the duplicated QKV projection step (#2429)


Signed-off-by: tdophung <tdophung@nvidia.com>
parent df39a7c2
...@@ -53,7 +53,7 @@ ...@@ -53,7 +53,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 1, "execution_count": 26,
"id": "d5284a38", "id": "d5284a38",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
...@@ -67,7 +67,7 @@ ...@@ -67,7 +67,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 2, "execution_count": 27,
"id": "a4d1cfdc", "id": "a4d1cfdc",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
...@@ -116,19 +116,33 @@ ...@@ -116,19 +116,33 @@
" qkv = qkv.reshape(qkv.shape[0], qkv.shape[1], self.num_attention_heads, 3 * self.kv_channels)\n", " qkv = qkv.reshape(qkv.shape[0], qkv.shape[1], self.num_attention_heads, 3 * self.kv_channels)\n",
" q, k, v = jnp.split(qkv, 3, axis=3)\n", " q, k, v = jnp.split(qkv, 3, axis=3)\n",
" \n", " \n",
" # Reshape to [batch, seq_len, num_heads * head_dim] for Flax MultiHeadDotProductAttention\n", " # q, k, v now have shape [batch, seq_len, num_heads, kv_channels]\n",
" q_reshaped = q.reshape(q.shape[0], q.shape[1], self.hidden_size)\n", " # which is the correct format for dot_product_attention\n",
" k_reshaped = k.reshape(k.shape[0], k.shape[1], self.hidden_size)\n",
" v_reshaped = v.reshape(v.shape[0], v.shape[1], self.hidden_size)\n",
" \n", " \n",
" # Attention using Flax's MultiHeadDotProductAttention\n", " # Apply dot product attention\n",
" attention = nn.MultiHeadDotProductAttention(\n", " # Note: dot_product_attention expects mask to be broadcastable to \n",
" num_heads=self.num_attention_heads,\n", " # [batch, num_heads, q_length, kv_length], but attention_mask from \n",
" qkv_features=self.kv_channels,\n", " # nn.make_causal_mask has shape [batch, 1, seq_len, seq_len]\n",
" \n",
" # Generate dropout RNG key when needed (not deterministic and dropout_rate > 0)\n",
" dropout_rng = None\n",
" if not deterministic and self.attention_dropout > 0:\n",
" dropout_rng = self.make_rng('dropout')\n",
" \n",
" x = nn.dot_product_attention(\n",
" query=q,\n",
" key=k,\n",
" value=v,\n",
" mask=attention_mask,\n",
" dropout_rng=dropout_rng,\n",
" dropout_rate=self.attention_dropout,\n", " dropout_rate=self.attention_dropout,\n",
" deterministic=deterministic,\n",
" broadcast_dropout=True,\n",
" )\n", " )\n",
" x = attention(q_reshaped, k_reshaped, v_reshaped, mask=attention_mask, deterministic=deterministic)\n", " \n",
"\n", " # Reshape output from [batch, seq_len, num_heads, kv_channels] to [batch, seq_len, hidden_size]\n",
" x = x.reshape(x.shape[0], x.shape[1], self.hidden_size)\n",
" \n",
" x = res + x\n", " x = res + x\n",
" \n", " \n",
" # Second residual connection\n", " # Second residual connection\n",
...@@ -157,7 +171,7 @@ ...@@ -157,7 +171,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 3, "execution_count": 28,
"id": "8b44649d", "id": "8b44649d",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
...@@ -178,7 +192,7 @@ ...@@ -178,7 +192,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 4, "execution_count": 29,
"id": "e44ed26d", "id": "e44ed26d",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
...@@ -187,7 +201,7 @@ ...@@ -187,7 +201,7 @@
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"Pure Flax FlaxTransformerLayer initialized successfully!\n", "Pure Flax FlaxTransformerLayer initialized successfully!\n",
"Parameter shapes: {'params': {'Dense_0': {'bias': (12288,), 'kernel': (4096, 12288)}, 'FlaxMLP_0': {'Dense_0': {'bias': (16384,), 'kernel': (4096, 16384)}, 'Dense_1': {'bias': (4096,), 'kernel': (16384, 4096)}}, 'LayerNorm_0': {'bias': (4096,), 'scale': (4096,)}, 'LayerNorm_1': {'bias': (4096,), 'scale': (4096,)}, 'MultiHeadDotProductAttention_0': {'key': {'bias': (32, 4), 'kernel': (4096, 32, 4)}, 'out': {'bias': (4096,), 'kernel': (32, 4, 4096)}, 'query': {'bias': (32, 4), 'kernel': (4096, 32, 4)}, 'value': {'bias': (32, 4), 'kernel': (4096, 32, 4)}}}}\n" "Parameter shapes: {'params': {'Dense_0': {'bias': (12288,), 'kernel': (4096, 12288)}, 'FlaxMLP_0': {'Dense_0': {'bias': (16384,), 'kernel': (4096, 16384)}, 'Dense_1': {'bias': (4096,), 'kernel': (16384, 4096)}}, 'LayerNorm_0': {'bias': (4096,), 'scale': (4096,)}, 'LayerNorm_1': {'bias': (4096,), 'scale': (4096,)}}}\n"
] ]
} }
], ],
...@@ -208,7 +222,7 @@ ...@@ -208,7 +222,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 5, "execution_count": 30,
"id": "de91af7a", "id": "de91af7a",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
...@@ -234,7 +248,7 @@ ...@@ -234,7 +248,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 6, "execution_count": 31,
"id": "037bc8d9", "id": "037bc8d9",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
...@@ -242,7 +256,7 @@ ...@@ -242,7 +256,7 @@
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"Mean time: 17.708301544189453 ms\n" "Mean time: 18.546080589294434 ms\n"
] ]
} }
], ],
...@@ -290,7 +304,7 @@ ...@@ -290,7 +304,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 7, "execution_count": 32,
"id": "bed20d6b", "id": "bed20d6b",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
...@@ -309,7 +323,7 @@ ...@@ -309,7 +323,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 8, "execution_count": 33,
"id": "56105579", "id": "56105579",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
...@@ -414,7 +428,7 @@ ...@@ -414,7 +428,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 9, "execution_count": 34,
"id": "4b67511f", "id": "4b67511f",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
...@@ -422,7 +436,7 @@ ...@@ -422,7 +436,7 @@
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"Mean time: 16.505107879638672 ms\n" "Mean time: 16.375374794006348 ms\n"
] ]
} }
], ],
...@@ -456,15 +470,39 @@ ...@@ -456,15 +470,39 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 10, "execution_count": 35,
"id": "5146cd99", "id": "5146cd99",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/code/github/TransformerEngine/transformer_engine/jax/flax/transformer.py:634: UserWarning: transpose_batch_sequence defaults to False in DotProductAttention starting TransformerEngine v2.10\n",
" warnings.warn(\n",
"/code/github/TransformerEngine/transformer_engine/jax/flax/transformer.py:742: UserWarning: Fused attention is not enabled because there is no available kernel.\n",
"Fall back to the unfused attention.\n",
"Please try to update the cuDNN and TE to the latest version.\n",
"self.dtype=<class 'jax.numpy.float32'>\n",
"qkv_layout=<QKVLayout.BSHD_BSHD_BSHD: <NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD: 9>>\n",
"attn_bias_type=<AttnBiasType.NO_BIAS: <NVTE_Bias_Type.NVTE_NO_BIAS: 0>>\n",
"attn_mask_type=<AttnMaskType.CAUSAL_MASK: <NVTE_Mask_Type.NVTE_CAUSAL_MASK: 2>>\n",
"self.attention_dropout=0.1\n",
"self.num_attention_heads=32\n",
"self.num_gqa_groups=32\n",
"seqlen_q=2048\n",
"seqlen_kv=2048\n",
"head_dim_qk=128\n",
"head_dim_v=128\n",
"\n",
" warnings.warn(\n"
]
},
{ {
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"Mean time: 12.80329704284668 ms\n" "Mean time: 12.403340339660645 ms\n"
] ]
} }
], ],
...@@ -515,7 +553,7 @@ ...@@ -515,7 +553,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 11, "execution_count": 36,
"id": "c2eee376", "id": "c2eee376",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
...@@ -527,7 +565,7 @@ ...@@ -527,7 +565,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 12, "execution_count": 37,
"id": "de96827c", "id": "de96827c",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
...@@ -535,7 +573,7 @@ ...@@ -535,7 +573,7 @@
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"Mean time: 9.615030288696289 ms\n" "Mean time: 9.396424293518066 ms\n"
] ]
} }
], ],
...@@ -588,7 +626,7 @@ ...@@ -588,7 +626,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 13, "execution_count": 38,
"id": "11203785", "id": "11203785",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
...@@ -659,7 +697,7 @@ ...@@ -659,7 +697,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 15, "execution_count": 39,
"id": "6b0c705e", "id": "6b0c705e",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
...@@ -667,7 +705,7 @@ ...@@ -667,7 +705,7 @@
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"Mean time: 9.331779479980469 ms\n" "Mean time: 9.145426750183105 ms\n"
] ]
} }
], ],
...@@ -704,10 +742,33 @@ ...@@ -704,10 +742,33 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 16, "execution_count": 40,
"id": "b2aaa8ef", "id": "b2aaa8ef",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/code/github/TransformerEngine/transformer_engine/jax/flax/transformer.py:742: UserWarning: Fused attention is not enabled because there is no available kernel.\n",
"Fall back to the unfused attention.\n",
"Please try to update the cuDNN and TE to the latest version.\n",
"self.dtype=<class 'jax.numpy.float32'>\n",
"qkv_layout=<QKVLayout.BS3HD: <NVTE_QKV_Layout.NVTE_BS3HD: 5>>\n",
"attn_bias_type=<AttnBiasType.NO_BIAS: <NVTE_Bias_Type.NVTE_NO_BIAS: 0>>\n",
"attn_mask_type=<AttnMaskType.CAUSAL_MASK: <NVTE_Mask_Type.NVTE_CAUSAL_MASK: 2>>\n",
"self.attention_dropout=0.1\n",
"self.num_attention_heads=32\n",
"self.num_gqa_groups=32\n",
"seqlen_q=2048\n",
"seqlen_kv=2048\n",
"head_dim_qk=128\n",
"head_dim_v=128\n",
"\n",
" warnings.warn(\n"
]
}
],
"source": [ "source": [
"\n", "\n",
"te_transformer = te_flax.TransformerLayer(\n", "te_transformer = te_flax.TransformerLayer(\n",
...@@ -731,7 +792,7 @@ ...@@ -731,7 +792,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 17, "execution_count": 41,
"id": "b9cdbf22", "id": "b9cdbf22",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
...@@ -739,7 +800,7 @@ ...@@ -739,7 +800,7 @@
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"Mean time: 9.23741340637207 ms\n" "Mean time: 9.020795822143555 ms\n"
] ]
} }
], ],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment