Unverified commit 29b0c9ca authored by zlsh80826, committed by GitHub

[JAX] Fix unfused GQA performance (#643)



* Fix unfused GQA perf
Signed-off-by: Reese Wang <rewang@nvidia.com>

* Remove WAR for Check failed: reduction_kind.has_value()
Signed-off-by: Reese Wang <rewang@nvidia.com>

---------
Signed-off-by: Reese Wang <rewang@nvidia.com>
parent e2803b16
@@ -4,9 +4,6 @@
 set -xe
 
-# WAR(rewang) for the "Check failed: reduction_kind.has_value()"
-export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_enable_xla_runtime_executable=true"
-
 : ${TE_PATH:=/opt/transformerengine}
 pytest -Wignore -v $TE_PATH/tests/jax/test_distributed_*
@@ -14,7 +14,5 @@ pytest -Wignore -v $TE_PATH/examples/jax/mnist
 # Make encoder tests to have run-to-run deterministic to have the stable CI results
 export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops"
-# WAR(rewang) for the "Check failed: reduction_kind.has_value()"
-export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_enable_xla_runtime_executable=true"
 pytest -Wignore -v $TE_PATH/examples/jax/encoder --ignore=$TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py
 pytest -Wignore -v $TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py
@@ -207,17 +207,29 @@ def core_attention(query: Array,
     key = key.astype(jnp.float32)
 
     h_q, h_kv = query.shape[-2], key.shape[-2]
-    assert (h_q % h_kv == 0) and (h_q >= h_kv)
-    group_size = h_q // h_kv
-    grouped_query = query.reshape((*query.shape[:2], h_kv, group_size, query.shape[-1]))
+    # The generated GQA kernels are slower than normal MHA kernels even when h_q == h_kv.
+    # Therefore, we have to maintain two code paths.
+    is_gqa = (h_q != h_kv)
+
+    if is_gqa:
+        assert (h_q % h_kv == 0) and (h_q >= h_kv)
+        group_size = h_q // h_kv
+        grouped_query = query.reshape((*query.shape[:2], h_kv, group_size, query.shape[-1]))
 
     if transpose_batch_sequence:
-        attn_weights = jnp.einsum('qbhgd,kbhd->bhgqk', grouped_query, key)
+        if is_gqa:
+            attn_weights = jnp.einsum('qbhgd,kbhd->bhgqk', grouped_query, key)
+        else:
+            attn_weights = jnp.einsum('qbhd,kbhd->bhqk', query, key)
     else:
-        attn_weights = jnp.einsum('bqhgd,bkhd->bhgqk', grouped_query, key)
+        if is_gqa:
+            attn_weights = jnp.einsum('bqhgd,bkhd->bhgqk', grouped_query, key)
+        else:
+            attn_weights = jnp.einsum('bqhd,bkhd->bhqk', query, key)
 
     attn_weights = checkpoint_name(attn_weights, 'logits')
 
-    b, h, g, q, k = attn_weights_with_groups_shape = attn_weights.shape
-    attn_weights_without_groups_shape = (b, h * g, q, k)
-    attn_weights = attn_weights.reshape(attn_weights_without_groups_shape)
+    if is_gqa:
+        b, h, g, q, k = attn_weights_with_groups_shape = attn_weights.shape
+        attn_weights_without_groups_shape = (b, h * g, q, k)
+        attn_weights = attn_weights.reshape(attn_weights_without_groups_shape)
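Note (editor's sketch, not part of the commit): the `'bqhgd,bkhd->bhgqk'` contraction above is the ordinary attention-score einsum with each KV head shared by a group of query heads. A minimal, self-contained check with toy shapes of our choosing:

```python
import jax.numpy as jnp
import numpy as np

# Toy shapes (hypothetical, not from the commit): batch=2, seq=4,
# h_q=8 query heads, h_kv=2 KV heads, head dim d=16, 'bqhd' layout.
b, s, h_q, h_kv, d = 2, 4, 8, 2, 16
group_size = h_q // h_kv  # 4 query heads share each KV head

rng = np.random.default_rng(0)
query = jnp.asarray(rng.normal(size=(b, s, h_q, d)), dtype=jnp.float32)
key = jnp.asarray(rng.normal(size=(b, s, h_kv, d)), dtype=jnp.float32)

# GQA path from the diff: fold the query heads into (h_kv, group_size) and
# contract each group against its shared KV head in a single einsum.
grouped_query = query.reshape(b, s, h_kv, group_size, d)
gqa_weights = jnp.einsum('bqhgd,bkhd->bhgqk', grouped_query, key)

# Reference MHA computation with each KV head repeated group_size times.
mha_weights = jnp.einsum('bqhd,bkhd->bhqk', query,
                         jnp.repeat(key, group_size, axis=2))

# Flattening (h_kv, group_size) back to h_q reproduces the MHA result.
assert jnp.allclose(gqa_weights.reshape(b, h_q, s, s), mha_weights, atol=1e-5)
```

When h_q == h_kv the group size is 1 and the grouped kernel computes the same values as plain MHA; per the comment in the diff, the plain path is kept anyway because the generated GQA kernels are slower.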
@@ -237,6 +249,7 @@ def core_attention(query: Array,
     attn_weights = Softmax(softmax_type=softmax_type,
                            scale_factor=fused_scale_factor)(attn_weights, mask, bias).astype(dtype)
 
-    attn_weights = attn_weights.reshape(attn_weights_with_groups_shape)
+    if is_gqa:
+        attn_weights = attn_weights.reshape(attn_weights_with_groups_shape)
 
     if not deterministic and dropout_rate > 0.:
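A side note on the reshape pair around the softmax (our reading of the code, not stated in the commit): the softmax, mask, and bias are applied on the usual 4D (batch, heads, q, k) layout, so the grouped 5D weights are flattened over (h_kv, group) first and restored afterwards. Because only leading axes are merged and the softmax acts on the last (k) axis, the round trip is exact, as this small sketch illustrates:

```python
import jax
import jax.numpy as jnp

# Hypothetical grouped weights: (batch, kv_heads, group, q_len, k_len).
w = jax.random.normal(jax.random.PRNGKey(0), (2, 2, 4, 5, 5))
b, h, g, q, k = with_groups_shape = w.shape

# Merge (kv_heads, group) into one heads axis for a 4D softmax, then restore.
flat = jax.nn.softmax(w.reshape(b, h * g, q, k), axis=-1)
restored = flat.reshape(with_groups_shape)

# Softmax over the k axis commutes with merging the leading head axes.
assert jnp.allclose(restored, jax.nn.softmax(w, axis=-1))
```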
@@ -248,9 +261,13 @@ def core_attention(query: Array,
         attn_weights = attn_weights * multiplier
 
     if transpose_batch_sequence:
-        return jnp.einsum('bhgqk,kbhd->qbhgd', attn_weights, value).reshape(query.shape)
+        if is_gqa:
+            return jnp.einsum('bhgqk,kbhd->qbhgd', attn_weights, value).reshape(query.shape)
+        return jnp.einsum('bhqk,kbhd->qbhd', attn_weights, value)
 
-    return jnp.einsum('bhgqk,bkhd->bqhgd', attn_weights, value).reshape(query.shape)
+    if is_gqa:
+        return jnp.einsum('bhgqk,bkhd->bqhgd', attn_weights, value).reshape(query.shape)
+    return jnp.einsum('bhqk,bkhd->bqhd', attn_weights, value)
 
 
 dynamic_vector_slice_in_dim = vmap(lax.dynamic_slice_in_dim, in_axes=(None, 0, None, None))
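Finally, a hedged sketch of the value contraction on the GQA path, again with made-up shapes: each KV head's values serve its whole group of query heads, and the trailing reshape folds (h_kv, group) back into h_q so the output matches query.shape.

```python
import jax.numpy as jnp
import numpy as np

# Made-up shapes: batch=2, seq=4, 2 KV heads x group of 4 = 8 query heads.
b, s, h_kv, group, d = 2, 4, 2, 4, 16
h_q = h_kv * group

rng = np.random.default_rng(1)
attn_weights = jnp.asarray(rng.random((b, h_kv, group, s, s)), dtype=jnp.float32)
value = jnp.asarray(rng.normal(size=(b, s, h_kv, d)), dtype=jnp.float32)

# Batch-major output path from the diff: contract over k, broadcasting each
# KV head's values across its group, then fold (h_kv, group) back into h_q.
context = jnp.einsum('bhgqk,bkhd->bqhgd', attn_weights, value)
context = context.reshape(b, s, h_q, d)
assert context.shape == (2, 4, 8, 16)  # same layout as the original query
```

The sequence-major branch (`'bhgqk,kbhd->qbhgd'`) differs only in which axis carries the sequence dimension.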
...