"...git@developer.sourcefind.cn:OpenDAS/TransformerEngine.git" did not exist on "06947e87b5511f8ad69ccd00286de9227f0fad24"
Unverified Commit 3ff0b8d4 authored by Teddy Do's avatar Teddy Do Committed by GitHub
Browse files

Change Flax MHA to DPA to remove the duplicated QKV projection step (#2429)


Signed-off-by: default avatartdophung <tdophung@nvidia.com>
parent df39a7c2
...@@ -53,7 +53,7 @@ ...@@ -53,7 +53,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 1, "execution_count": 26,
"id": "d5284a38", "id": "d5284a38",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
...@@ -67,7 +67,7 @@ ...@@ -67,7 +67,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 2, "execution_count": 27,
"id": "a4d1cfdc", "id": "a4d1cfdc",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
...@@ -116,19 +116,33 @@ ...@@ -116,19 +116,33 @@
" qkv = qkv.reshape(qkv.shape[0], qkv.shape[1], self.num_attention_heads, 3 * self.kv_channels)\n", " qkv = qkv.reshape(qkv.shape[0], qkv.shape[1], self.num_attention_heads, 3 * self.kv_channels)\n",
" q, k, v = jnp.split(qkv, 3, axis=3)\n", " q, k, v = jnp.split(qkv, 3, axis=3)\n",
" \n", " \n",
" # Reshape to [batch, seq_len, num_heads * head_dim] for Flax MultiHeadDotProductAttention\n", " # q, k, v now have shape [batch, seq_len, num_heads, kv_channels]\n",
" q_reshaped = q.reshape(q.shape[0], q.shape[1], self.hidden_size)\n", " # which is the correct format for dot_product_attention\n",
" k_reshaped = k.reshape(k.shape[0], k.shape[1], self.hidden_size)\n",
" v_reshaped = v.reshape(v.shape[0], v.shape[1], self.hidden_size)\n",
" \n", " \n",
" # Attention using Flax's MultiHeadDotProductAttention\n", " # Apply dot product attention\n",
" attention = nn.MultiHeadDotProductAttention(\n", " # Note: dot_product_attention expects mask to be broadcastable to \n",
" num_heads=self.num_attention_heads,\n", " # [batch, num_heads, q_length, kv_length], but attention_mask from \n",
" qkv_features=self.kv_channels,\n", " # nn.make_causal_mask has shape [batch, 1, seq_len, seq_len]\n",
" \n",
" # Generate dropout RNG key when needed (not deterministic and dropout_rate > 0)\n",
" dropout_rng = None\n",
" if not deterministic and self.attention_dropout > 0:\n",
" dropout_rng = self.make_rng('dropout')\n",
" \n",
" x = nn.dot_product_attention(\n",
" query=q,\n",
" key=k,\n",
" value=v,\n",
" mask=attention_mask,\n",
" dropout_rng=dropout_rng,\n",
" dropout_rate=self.attention_dropout,\n", " dropout_rate=self.attention_dropout,\n",
" deterministic=deterministic,\n",
" broadcast_dropout=True,\n",
" )\n", " )\n",
" x = attention(q_reshaped, k_reshaped, v_reshaped, mask=attention_mask, deterministic=deterministic)\n", " \n",
"\n", " # Reshape output from [batch, seq_len, num_heads, kv_channels] to [batch, seq_len, hidden_size]\n",
" x = x.reshape(x.shape[0], x.shape[1], self.hidden_size)\n",
" \n",
" x = res + x\n", " x = res + x\n",
" \n", " \n",
" # Second residual connection\n", " # Second residual connection\n",
...@@ -157,7 +171,7 @@ ...@@ -157,7 +171,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 3, "execution_count": 28,
"id": "8b44649d", "id": "8b44649d",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
...@@ -178,7 +192,7 @@ ...@@ -178,7 +192,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 4, "execution_count": 29,
"id": "e44ed26d", "id": "e44ed26d",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
...@@ -187,7 +201,7 @@ ...@@ -187,7 +201,7 @@
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"Pure Flax FlaxTransformerLayer initialized successfully!\n", "Pure Flax FlaxTransformerLayer initialized successfully!\n",
"Parameter shapes: {'params': {'Dense_0': {'bias': (12288,), 'kernel': (4096, 12288)}, 'FlaxMLP_0': {'Dense_0': {'bias': (16384,), 'kernel': (4096, 16384)}, 'Dense_1': {'bias': (4096,), 'kernel': (16384, 4096)}}, 'LayerNorm_0': {'bias': (4096,), 'scale': (4096,)}, 'LayerNorm_1': {'bias': (4096,), 'scale': (4096,)}, 'MultiHeadDotProductAttention_0': {'key': {'bias': (32, 4), 'kernel': (4096, 32, 4)}, 'out': {'bias': (4096,), 'kernel': (32, 4, 4096)}, 'query': {'bias': (32, 4), 'kernel': (4096, 32, 4)}, 'value': {'bias': (32, 4), 'kernel': (4096, 32, 4)}}}}\n" "Parameter shapes: {'params': {'Dense_0': {'bias': (12288,), 'kernel': (4096, 12288)}, 'FlaxMLP_0': {'Dense_0': {'bias': (16384,), 'kernel': (4096, 16384)}, 'Dense_1': {'bias': (4096,), 'kernel': (16384, 4096)}}, 'LayerNorm_0': {'bias': (4096,), 'scale': (4096,)}, 'LayerNorm_1': {'bias': (4096,), 'scale': (4096,)}}}\n"
] ]
} }
], ],
...@@ -208,7 +222,7 @@ ...@@ -208,7 +222,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 5, "execution_count": 30,
"id": "de91af7a", "id": "de91af7a",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
...@@ -234,7 +248,7 @@ ...@@ -234,7 +248,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 6, "execution_count": 31,
"id": "037bc8d9", "id": "037bc8d9",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
...@@ -242,7 +256,7 @@ ...@@ -242,7 +256,7 @@
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"Mean time: 17.708301544189453 ms\n" "Mean time: 18.546080589294434 ms\n"
] ]
} }
], ],
...@@ -290,7 +304,7 @@ ...@@ -290,7 +304,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 7, "execution_count": 32,
"id": "bed20d6b", "id": "bed20d6b",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
...@@ -309,7 +323,7 @@ ...@@ -309,7 +323,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 8, "execution_count": 33,
"id": "56105579", "id": "56105579",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
...@@ -414,7 +428,7 @@ ...@@ -414,7 +428,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 9, "execution_count": 34,
"id": "4b67511f", "id": "4b67511f",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
...@@ -422,7 +436,7 @@ ...@@ -422,7 +436,7 @@
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"Mean time: 16.505107879638672 ms\n" "Mean time: 16.375374794006348 ms\n"
] ]
} }
], ],
...@@ -456,15 +470,39 @@ ...@@ -456,15 +470,39 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 10, "execution_count": 35,
"id": "5146cd99", "id": "5146cd99",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/code/github/TransformerEngine/transformer_engine/jax/flax/transformer.py:634: UserWarning: transpose_batch_sequence defaults to False in DotProductAttention starting TransformerEngine v2.10\n",
" warnings.warn(\n",
"/code/github/TransformerEngine/transformer_engine/jax/flax/transformer.py:742: UserWarning: Fused attention is not enabled because there is no available kernel.\n",
"Fall back to the unfused attention.\n",
"Please try to update the cuDNN and TE to the latest version.\n",
"self.dtype=<class 'jax.numpy.float32'>\n",
"qkv_layout=<QKVLayout.BSHD_BSHD_BSHD: <NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD: 9>>\n",
"attn_bias_type=<AttnBiasType.NO_BIAS: <NVTE_Bias_Type.NVTE_NO_BIAS: 0>>\n",
"attn_mask_type=<AttnMaskType.CAUSAL_MASK: <NVTE_Mask_Type.NVTE_CAUSAL_MASK: 2>>\n",
"self.attention_dropout=0.1\n",
"self.num_attention_heads=32\n",
"self.num_gqa_groups=32\n",
"seqlen_q=2048\n",
"seqlen_kv=2048\n",
"head_dim_qk=128\n",
"head_dim_v=128\n",
"\n",
" warnings.warn(\n"
]
},
{ {
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"Mean time: 12.80329704284668 ms\n" "Mean time: 12.403340339660645 ms\n"
] ]
} }
], ],
...@@ -515,7 +553,7 @@ ...@@ -515,7 +553,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 11, "execution_count": 36,
"id": "c2eee376", "id": "c2eee376",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
...@@ -527,7 +565,7 @@ ...@@ -527,7 +565,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 12, "execution_count": 37,
"id": "de96827c", "id": "de96827c",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
...@@ -535,7 +573,7 @@ ...@@ -535,7 +573,7 @@
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"Mean time: 9.615030288696289 ms\n" "Mean time: 9.396424293518066 ms\n"
] ]
} }
], ],
...@@ -588,7 +626,7 @@ ...@@ -588,7 +626,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 13, "execution_count": 38,
"id": "11203785", "id": "11203785",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
...@@ -659,7 +697,7 @@ ...@@ -659,7 +697,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 15, "execution_count": 39,
"id": "6b0c705e", "id": "6b0c705e",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
...@@ -667,7 +705,7 @@ ...@@ -667,7 +705,7 @@
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"Mean time: 9.331779479980469 ms\n" "Mean time: 9.145426750183105 ms\n"
] ]
} }
], ],
...@@ -704,10 +742,33 @@ ...@@ -704,10 +742,33 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 16, "execution_count": 40,
"id": "b2aaa8ef", "id": "b2aaa8ef",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/code/github/TransformerEngine/transformer_engine/jax/flax/transformer.py:742: UserWarning: Fused attention is not enabled because there is no available kernel.\n",
"Fall back to the unfused attention.\n",
"Please try to update the cuDNN and TE to the latest version.\n",
"self.dtype=<class 'jax.numpy.float32'>\n",
"qkv_layout=<QKVLayout.BS3HD: <NVTE_QKV_Layout.NVTE_BS3HD: 5>>\n",
"attn_bias_type=<AttnBiasType.NO_BIAS: <NVTE_Bias_Type.NVTE_NO_BIAS: 0>>\n",
"attn_mask_type=<AttnMaskType.CAUSAL_MASK: <NVTE_Mask_Type.NVTE_CAUSAL_MASK: 2>>\n",
"self.attention_dropout=0.1\n",
"self.num_attention_heads=32\n",
"self.num_gqa_groups=32\n",
"seqlen_q=2048\n",
"seqlen_kv=2048\n",
"head_dim_qk=128\n",
"head_dim_v=128\n",
"\n",
" warnings.warn(\n"
]
}
],
"source": [ "source": [
"\n", "\n",
"te_transformer = te_flax.TransformerLayer(\n", "te_transformer = te_flax.TransformerLayer(\n",
...@@ -731,7 +792,7 @@ ...@@ -731,7 +792,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 17, "execution_count": 41,
"id": "b9cdbf22", "id": "b9cdbf22",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
...@@ -739,7 +800,7 @@ ...@@ -739,7 +800,7 @@
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"Mean time: 9.23741340637207 ms\n" "Mean time: 9.020795822143555 ms\n"
] ]
} }
], ],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment