Unverified Commit 2a86df2b authored by zlsh80826's avatar zlsh80826 Committed by GitHub
Browse files

[JAX] Canonicalize the dtype for the better user experience (#480)



canonicalize the dtype for the better user experience
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>
parent 1afb6256
...@@ -65,6 +65,7 @@ def jax_dtype_to_te_dtype(jax_dtype): ...@@ -65,6 +65,7 @@ def jax_dtype_to_te_dtype(jax_dtype):
""" """
convert jax dtype to TE dtype convert jax dtype to TE dtype
""" """
jax_dtype = dtypes.canonicalize_dtype(jax_dtype)
if jax_dtype == jnp.float32: if jax_dtype == jnp.float32:
return TEDType.kFloat32 return TEDType.kFloat32
if jax_dtype == jnp.float16: if jax_dtype == jnp.float16:
...@@ -1626,6 +1627,7 @@ class ScaledSoftmaxFwdPrimitive(SoftmaxPrimitive): ...@@ -1626,6 +1627,7 @@ class ScaledSoftmaxFwdPrimitive(SoftmaxPrimitive):
"""Check Softmax kernel availability based on size""" """Check Softmax kernel availability based on size"""
attn_batches = batch * heads attn_batches = batch * heads
dtype = dtypes.canonicalize_dtype(dtype)
if (dtype in [jnp.float16, jnp.bfloat16] if (dtype in [jnp.float16, jnp.bfloat16]
and 16 < k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported and 16 < k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported
# k_seqlen must be 16 ~ 4096 # k_seqlen must be 16 ~ 4096
...@@ -1757,6 +1759,7 @@ class ScaledMaskedSoftmaxFwdPrimitive(SoftmaxPrimitive): ...@@ -1757,6 +1759,7 @@ class ScaledMaskedSoftmaxFwdPrimitive(SoftmaxPrimitive):
"""Check Softmax kernel availability based on size""" """Check Softmax kernel availability based on size"""
attn_batches = batch * heads attn_batches = batch * heads
dtype = dtypes.canonicalize_dtype(dtype)
if (dtype in [jnp.float16, jnp.bfloat16] if (dtype in [jnp.float16, jnp.bfloat16]
and 16 < k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported and 16 < k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported
# k_seqlen must be 16 ~ 4096 # k_seqlen must be 16 ~ 4096
...@@ -1908,6 +1911,7 @@ class ScaledUpperTriangMaskedSoftmaxFwdPrimitive(SoftmaxPrimitive): ...@@ -1908,6 +1911,7 @@ class ScaledUpperTriangMaskedSoftmaxFwdPrimitive(SoftmaxPrimitive):
"""Check Softmax kernel availability based on size""" """Check Softmax kernel availability based on size"""
attn_batches = batch * heads attn_batches = batch * heads
dtype = dtypes.canonicalize_dtype(dtype)
if (dtype in [jnp.float16, jnp.bfloat16] if (dtype in [jnp.float16, jnp.bfloat16]
and 16 < k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported and 16 < k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported
# k_seqlen must be 16 ~ 4096 # k_seqlen must be 16 ~ 4096
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment