Unverified Commit 1975ace4 authored by Phuong Nguyen, committed by GitHub

[JAX] Bug Fix: Softmax FFIs with correct Encapsulates (#1375)



* softmax custom calls with correct encapsulates

* rm jax deprecated features

---------
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
parent 1ae81903
@@ -8,7 +8,7 @@ from functools import reduce, partial
 import jax
 import jax.numpy as jnp
-from jax import core, dtypes
+from jax import dtypes
 from jax.interpreters.mlir import ir
 from jax.sharding import PartitionSpec, NamedSharding
 from jax.extend import ffi
@@ -98,7 +98,7 @@ class ActLuPrimitive(BasePrimitive):
         assert x_shape[-2] == 2 or x_shape[-2] == 1
         hidden_size = x_shape[-1]
         batch_shapes = x_shape[:-2]
-        out_aval = core.raise_to_shaped(x_aval)
+        out_aval = x_aval
         out_shape = (batch_shapes) + (hidden_size,)
         out_aval = out_aval.update(shape=out_shape, dtype=dtype)
@@ -225,7 +225,7 @@ class DActLuPrimitive(BasePrimitive):
         i_hidden_size = dz_aval.shape[-1]
         g_hidden_size = x_aval.shape[-1]
         assert i_hidden_size == g_hidden_size
-        out_aval = core.raise_to_shaped(x_aval)
+        out_aval = x_aval
         return out_aval
...
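The hunks above drop jax.core.raise_to_shaped, which newer JAX releases have removed: an abstract-eval rule can use the incoming ShapedArray directly and derive new avals with .update(). A minimal sketch of that pattern, using a hypothetical class name and output dtype (not TE's actual code):

import jax.numpy as jnp

class MyActLuPrimitive:  # hypothetical stand-in for the ActLuPrimitive above
    @staticmethod
    def abstract(x_aval, *, out_dtype=jnp.bfloat16):
        # Input is (..., gate_dim, hidden) with gate_dim of 1 or 2; the
        # activation collapses the gate dimension.
        assert x_aval.shape[-2] in (1, 2)
        out_shape = x_aval.shape[:-2] + (x_aval.shape[-1],)
        # No raise_to_shaped needed: x_aval is already a ShapedArray, and
        # .update() returns a copy with the new shape and dtype.
        return x_aval.update(shape=out_shape, dtype=out_dtype)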
@@ -7,7 +7,7 @@ import re
 from abc import ABCMeta, abstractmethod
 from functools import partial
-from jax import core
+from jax.extend import core
 from jax.interpreters import xla, mlir
 from jax.experimental.custom_partitioning import custom_partitioning
 from jax._src.interpreters import batching
...
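jax.core is deprecated for third-party extensions, so the primitive base module now imports core from jax.extend. A hedged sketch of a primitive definition against that namespace (the primitive name and rules below are illustrative, not TE's registration code):

from functools import partial

from jax.extend import core          # replaces `from jax import core`
from jax.interpreters import xla

my_prim = core.Primitive("my_custom_op")   # hypothetical primitive name
my_prim.multiple_results = False
# Concrete evaluation falls back to the compiled primitive.
my_prim.def_impl(partial(xla.apply_primitive, my_prim))
# Shape/dtype inference simply passes the input aval through.
my_prim.def_abstract_eval(lambda x_aval: x_aval)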
@@ -9,7 +9,7 @@ import warnings
 import jax
 import jax.numpy as jnp
-from jax import core, dtypes
+from jax import dtypes
 from jax.interpreters import mlir
 from jax.interpreters.mlir import ir
 from jax.sharding import PartitionSpec, NamedSharding
@@ -74,7 +74,7 @@ class LayerNormFwdPrimitive(BasePrimitive):
         mu_rsigama_dtype = jnp.float32
-        out_aval = core.raise_to_shaped(x_aval)
+        out_aval = x_aval
         mu_aval = rsigma_aval = out_aval.update(shape=out_aval.shape[:-1], dtype=mu_rsigama_dtype)
         assert gamma_aval.size == beta_aval.size
@@ -361,8 +361,8 @@ class LayerNormBwdPrimitive(BasePrimitive):
         assert mu_aval.shape == rsigma_aval.shape == x_aval.shape[:-1]
         assert mu_dtype == rsigma_dtype == jnp.float32
-        dx_aval = core.raise_to_shaped(dz_aval)
-        dgamma_aval = dbeta_aval = core.raise_to_shaped(gamma_aval)
+        dx_aval = dz_aval
+        dgamma_aval = dbeta_aval = gamma_aval
         (wkspace_info,) = transformer_engine_jax.get_layernorm_bwd_workspace_sizes(
             x_aval.size // gamma_aval.size,  # batch size
@@ -589,7 +589,7 @@ class RmsNormFwdPrimitive(BasePrimitive):
         rsigama_dtype = jnp.float32
-        out_aval = core.raise_to_shaped(x_aval)
+        out_aval = x_aval
         rsigma_aval = out_aval.update(shape=out_aval.shape[:-1], dtype=rsigama_dtype)
         hidden_size = gamma_aval.size
@@ -783,8 +783,8 @@ class RmsNormBwdPrimitive(BasePrimitive):
         assert rsigma_aval.shape == x_aval.shape[:-1]
         assert rsigma_dtype == jnp.float32
-        dx_aval = core.raise_to_shaped(dz_aval)
-        dgamma_aval = core.raise_to_shaped(gamma_aval)
+        dx_aval = dz_aval
+        dgamma_aval = gamma_aval
         (wkspace_info,) = transformer_engine_jax.get_layernorm_bwd_workspace_sizes(
             x_aval.size // gamma_aval.size,  # batch size
...
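The normalization primitives follow the same pattern: gradient avals reuse the incoming ShapedArrays, while the per-row statistics drop the hidden dimension via .update(). A minimal sketch of the backward aval derivation with illustrative names (the real rule also allocates workspace avals from get_layernorm_bwd_workspace_sizes):

import jax.numpy as jnp

def layernorm_bwd_abstract(dz_aval, x_aval, mu_aval, rsigma_aval, gamma_aval):
    # Statistics are per-row float32 values, one per token.
    assert mu_aval.shape == rsigma_aval.shape == x_aval.shape[:-1]
    assert mu_aval.dtype == rsigma_aval.dtype == jnp.float32
    dx_aval = dz_aval                        # same shape/dtype as the cotangent
    dgamma_aval = dbeta_aval = gamma_aval    # weight grads match the weights
    return dx_aval, dgamma_aval, dbeta_aval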
@@ -9,7 +9,7 @@ import warnings
 import jax
 import jax.numpy as jnp
-from jax import core, dtypes
+from jax import dtypes
 from jax.interpreters.mlir import ir
 from jax.sharding import PartitionSpec, NamedSharding
 from jax.extend import ffi
@@ -126,7 +126,7 @@ class SoftmaxPrimitive(BasePrimitive):
         assert k_seqlen <= SoftmaxPrimitive.max_k_seqlen_supported
         assert q_seqlen > 1
-        out_aval = core.raise_to_shaped(logits_aval)
+        out_aval = logits_aval
         return out_aval
     @staticmethod
@@ -237,7 +237,7 @@ class SoftmaxPrimitive(BasePrimitive):
         assert dz_aval.shape == softmax_out_aval.shape
-        dx_aval = core.raise_to_shaped(dz_aval)
+        dx_aval = dz_aval
         return dx_aval
     @staticmethod
@@ -578,7 +578,7 @@ class ScaledMaskedSoftmaxFwdPrimitive(SoftmaxPrimitive):
         assert mask_shape[-2] == q_seqlen
         assert mask_shape[-1] == k_seqlen
-        out_aval = core.raise_to_shaped(logits_aval)
+        out_aval = logits_aval
         return out_aval
     @staticmethod
...
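On the Python side, these softmax primitives lower to the te_scaled_*_softmax_*_ffi targets registered in the C++ hunk below. A hedged sketch of how such a lowering rule can be built with jax.extend.ffi (the exact wiring inside TE may differ; the registration line is only indicative):

from jax.extend import ffi
from jax.interpreters import mlir

# ffi.ffi_lowering emits an MLIR custom call to the named XLA FFI target;
# operands and attributes are forwarded by the generated rule.
scaled_softmax_fwd_lowering = ffi.ffi_lowering("te_scaled_softmax_forward_ffi")

# A primitive would then register the rule for the GPU backend, e.g.:
# mlir.register_lowering(ScaledSoftmaxFwdPrimitive.inner_primitive,
#                        scaled_softmax_fwd_lowering, platform="cuda")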
@@ -61,26 +61,23 @@ pybind11::dict Registrations() {
   dict["te_act_lu_ffi"] = EncapsulateFFI(ActLuHandler);
   dict["te_act_lu_fp8_ffi"] = EncapsulateFFI(ActLuFP8Handler);
   dict["te_dact_lu_ffi"] = EncapsulateFFI(DActLuHandler);
-  dict["te_dact_lu_dbias_cast_transpose_ffi"] =
-      EncapsulateFunction(DActLuDBiasCastTransposeHandler);
-  dict["te_dgated_act_lu_cast_transpose_ffi"] =
-      EncapsulateFunction(DGatedActLuCastTransposeHandler);
+  dict["te_dact_lu_dbias_cast_transpose_ffi"] = EncapsulateFFI(DActLuDBiasCastTransposeHandler);
+  dict["te_dgated_act_lu_cast_transpose_ffi"] = EncapsulateFFI(DGatedActLuCastTransposeHandler);
   // Quantization
   dict["te_quantize_ffi"] = EncapsulateFFI(QuantizeHandler);
   dict["te_dequantize_ffi"] = EncapsulateFFI(DequantizeHandler);
   // Softmax
-  dict["te_scaled_softmax_forward_ffi"] = EncapsulateFunction(ScaledSoftmaxForwardHandler);
-  dict["te_scaled_softmax_backward_ffi"] = EncapsulateFunction(ScaledSoftmaxBackwardHandler);
-  dict["te_scaled_masked_softmax_forward_ffi"] =
-      EncapsulateFunction(ScaledMaskedSoftmaxForwardHandler);
-  dict["te_scaled_masked_softmax_backward_ffi"] =
-      EncapsulateFunction(ScaledMaskedSoftmaxBackwardHandler);
-  dict["te_scaled_upper_triang_masked_softmax_forward_ffi"] =
-      EncapsulateFunction(ScaledUpperTriangMaskedSoftmaxForwardHandler);
-  dict["te_scaled_upper_triang_masked_softmax_backward_ffi"] =
-      EncapsulateFunction(ScaledUpperTriangMaskedSoftmaxBackwardHandler);
+  dict["te_scaled_softmax_forward_ffi"] = EncapsulateFFI(ScaledSoftmaxForwardHandler);
+  dict["te_scaled_softmax_backward_ffi"] = EncapsulateFFI(ScaledSoftmaxBackwardHandler);
+  dict["te_scaled_masked_softmax_forward_ffi"] = EncapsulateFFI(ScaledMaskedSoftmaxForwardHandler);
+  dict["te_scaled_masked_softmax_backward_ffi"] =
+      EncapsulateFFI(ScaledMaskedSoftmaxBackwardHandler);
+  dict["te_scaled_upper_triang_masked_softmax_forward_ffi"] =
+      EncapsulateFFI(ScaledUpperTriangMaskedSoftmaxForwardHandler);
+  dict["te_scaled_upper_triang_masked_softmax_backward_ffi"] =
+      EncapsulateFFI(ScaledUpperTriangMaskedSoftmaxBackwardHandler);
   // Normalization
   dict["te_layernorm_forward_ffi"] =
...
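This hunk is the fix named in the title: the softmax entries (and two activation entries) were still wrapped with EncapsulateFunction, which yields a legacy custom-call capsule, while the *_ffi targets are consumed on the Python side as typed XLA FFI handlers, so they need the EncapsulateFFI capsule. A hedged sketch of that Python-side registration, assuming the pybind11 module is imported as transformer_engine_jax and exposes the dict above as registrations() (both names are assumptions for illustration):

from jax.extend import ffi
import transformer_engine_jax  # assumed import name for the pybind11 module

# Capsules produced by EncapsulateFFI must be registered through the typed
# XLA FFI API (api_version=1); a legacy EncapsulateFunction capsule would not
# match that calling convention at runtime.
for name, capsule in transformer_engine_jax.registrations().items():
    ffi.register_ffi_target(name, capsule, platform="CUDA", api_version=1)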