Unverified Commit 5986342a authored by Phuong Nguyen, committed by GitHub

[JAX] Splitting cpp_extensions.py (#899)



* Split cpp_extensions.py; renamed mlp.py and fused_attn.py
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* Fixed imports in tests
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

---------
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
parent b5a7c9f9
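The refactor is mechanical across every touched module: the long per-symbol import lists from cpp_extensions collapse into a single namespace import, and each call site gains a tex. qualifier. A minimal sketch of the pattern, using the layernorm call from the hunks below:

```python
# Before: every helper imported by name, so the import list grows with each new kernel.
from .cpp_extensions import layernorm_fwd
output, mu, rsigma = layernorm_fwd(x, gamma, beta, zero_centered_gamma, epsilon)

# After: one namespace import; call sites are qualified with tex.
from . import cpp_extensions as tex
output, mu, rsigma = tex.layernorm_fwd(x, gamma, beta, zero_centered_gamma, epsilon)
```

The namespace form also lets cpp_extensions reorganize internally (the point of the split) without touching its callers' import lines again.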
@@ -9,9 +9,7 @@ from typing import List, Tuple
import jax
import jax.numpy as jnp
-from .cpp_extensions import cast_fp8, cast_transpose, transpose
-from .cpp_extensions import rmsnorm_fwd, rmsnorm_fwd_fp8, rmsnorm_bwd
-from .cpp_extensions import layernorm_fwd, layernorm_fwd_fp8, layernorm_bwd
+from . import cpp_extensions as tex
from .dot import fp8_dot_impl, get_precision_of_fp8_dot
from .fp8 import FP8Helper, FP8MetaPackage
from .sharding import with_sharding_constraint_by_logical_axes
@@ -64,11 +62,11 @@ def _layernorm_fwd_rule(x,
epsilon: float = 1e-6):
layernorm_type = canonicalize_layernorm_type(layernorm_type)
if layernorm_type == 'layernorm':
-output, mu, rsigma = layernorm_fwd(x, gamma, beta, zero_centered_gamma, epsilon)
+output, mu, rsigma = tex.layernorm_fwd(x, gamma, beta, zero_centered_gamma, epsilon)
elif layernorm_type == 'rmsnorm':
assert not zero_centered_gamma, "zero_centered_gamma is not supported " \
"if layernorm_type is 'rmsnorm'"
-output, rsigma = rmsnorm_fwd(x, gamma, epsilon)
+output, rsigma = tex.rmsnorm_fwd(x, gamma, epsilon)
mu = None
else:
raise ValueError(f"{layernorm_type=} is not supported.")
@@ -78,7 +76,7 @@ def _layernorm_fwd_rule(x,
def _layernorm_bwd_rule(layernorm_type, zero_centered_gamma, epsilon, ctx, dz):
x, mu, rsigma, gamma = ctx
if layernorm_type == 'layernorm':
-dx, dgamma, dbeta = layernorm_bwd(dz,
+dx, dgamma, dbeta = tex.layernorm_bwd(dz,
x,
mu,
rsigma,
@@ -88,7 +86,7 @@ def _layernorm_bwd_rule(layernorm_type, zero_centered_gamma, epsilon, ctx, dz):
elif layernorm_type == 'rmsnorm':
assert not zero_centered_gamma, "zero_centered_gamma is not supported " \
"if layernorm_type is 'rmsnorm'"
-dx, dgamma = rmsnorm_bwd(dz, x, rsigma, gamma, epsilon=epsilon)
+dx, dgamma = tex.rmsnorm_bwd(dz, x, rsigma, gamma, epsilon=epsilon)
dbeta = None
else:
raise ValueError(f"{layernorm_type=} is not supported.")
@@ -175,7 +173,7 @@ def _layernorm_fp8_dot_fwd_rule(
x = with_sharding_constraint_by_logical_axes(x, layernorm_input_axes)
if layernorm_type == 'layernorm':
-ln_out, mu, rsigma, updated_x_amax = layernorm_fwd_fp8(
+ln_out, mu, rsigma, updated_x_amax = tex.layernorm_fwd_fp8(
x,
gamma,
beta,
@@ -188,7 +186,7 @@ def _layernorm_fp8_dot_fwd_rule(
else:
assert not zero_centered_gamma, "zero_centered_gamma is not supported " \
"if layernorm_type is 'rmsnorm'"
-ln_out, rsigma, updated_x_amax = rmsnorm_fwd_fp8(x,
+ln_out, rsigma, updated_x_amax = tex.rmsnorm_fwd_fp8(x,
gamma,
x_amax,
x_scale,
@@ -207,7 +205,7 @@ def _layernorm_fp8_dot_fwd_rule(
# Note (Ming Huang): Use cast only to let XLA handle the transpose, avoiding an
# unnecessary copy that would break FP8 GEMM pattern matching.
casted_kernel, updated_kernel_amax = \
-cast_fp8(kernel, kernel_amax, kernel_scale, kernel_scale_inv, fwd_dtype)
+tex.cast_fp8(kernel, kernel_amax, kernel_scale, kernel_scale_inv, fwd_dtype)
ln_out = with_sharding_constraint_by_logical_axes(ln_out, dot_input_axes)
@@ -238,14 +236,14 @@ def _layernorm_fp8_dot_bwd_rule(
x_shape, kernel_shape, mu, rsigma, x, gamma, \
x_contracting_dims, k_contracting_dims, maybe_fp32_to_fm32 = ctx
-ln_out_t = transpose(ln_out_, static_axis_boundary=-1, transpose_axis_boundary=-1)
+ln_out_t = tex.transpose(ln_out_, static_axis_boundary=-1, transpose_axis_boundary=-1)
grad_amax = amax_list[FP8MetaPackage.GRAD_IDX][0:1]
grad_scale = scale_list[FP8MetaPackage.GRAD_IDX]
grad_scale_inv = scale_inv_list[FP8MetaPackage.GRAD_IDX]
casted_grad, casted_grad_t, updated_grad_amax = \
-cast_transpose(grad, grad_amax, grad_scale, grad_scale_inv, bwd_dtype,
+tex.cast_transpose(grad, grad_amax, grad_scale, grad_scale_inv, bwd_dtype,
static_axis_boundary=-1, transpose_axis_boundary=min(x_contracting_dims))
xt_constracting_dim = tuple(range(len(x_contracting_dims), len(x_shape)))
@@ -265,7 +263,7 @@ def _layernorm_fp8_dot_bwd_rule(
dgrad = with_sharding_constraint_by_logical_axes(dgrad, layernorm_input_axes)
if layernorm_type == 'layernorm':
-dx, dgamma, dbeta = layernorm_bwd(dgrad,
+dx, dgamma, dbeta = tex.layernorm_bwd(dgrad,
x,
mu,
rsigma,
@@ -275,7 +273,7 @@ def _layernorm_fp8_dot_bwd_rule(
else:
assert not zero_centered_gamma, "zero_centered_gamma is not supported " \
"if layernorm_type is 'rmsnorm'"
-dx, dgamma = rmsnorm_bwd(dgrad, x, rsigma, gamma, epsilon=epsilon)
+dx, dgamma = tex.rmsnorm_bwd(dgrad, x, rsigma, gamma, epsilon=epsilon)
dbeta = None
amax_list[FP8MetaPackage.INPUT_IDX] = \
......
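For orientation, the _layernorm_fwd_rule/_layernorm_bwd_rule pair above follows JAX's custom-VJP convention: the static arguments (layernorm_type, zero_centered_gamma, epsilon) are non-differentiable, the fwd rule returns the output plus a ctx tuple of residuals, and the bwd rule receives the statics, ctx, and the cotangent dz. A self-contained sketch of that wiring, with a pure-JAX RMSNorm standing in for the fused tex kernels (all names here are illustrative, not the repository's):

```python
from functools import partial
import jax
import jax.numpy as jnp

@partial(jax.custom_vjp, nondiff_argnums=(2,))
def rmsnorm(x, gamma, epsilon):
    output, _ = _rmsnorm_fwd_rule(x, gamma, epsilon)
    return output

def _rmsnorm_fwd_rule(x, gamma, epsilon):
    # rsigma = 1/sqrt(mean(x^2) + eps), the same residual the rules above stash.
    rsigma = jax.lax.rsqrt(jnp.mean(x * x, axis=-1, keepdims=True) + epsilon)
    output = x * rsigma * gamma
    return output, (x, gamma, rsigma)           # ctx consumed by the bwd rule

def _rmsnorm_bwd_rule(epsilon, ctx, dz):        # non-diff args come first
    x, gamma, rsigma = ctx
    g = dz * gamma
    mean_gx = jnp.mean(g * x, axis=-1, keepdims=True)
    dx = rsigma * (g - x * rsigma * rsigma * mean_gx)
    dgamma = jnp.sum(dz * x * rsigma, axis=tuple(range(x.ndim - 1)))
    return dx, dgamma

rmsnorm.defvjp(_rmsnorm_fwd_rule, _rmsnorm_bwd_rule)

# Gradients now flow through the hand-written bwd rule:
x, gamma = jnp.ones((4, 8)), jnp.ones((8,))
dx, dgamma = jax.grad(lambda x, g: rmsnorm(x, g, 1e-6).sum(), argnums=(0, 1))(x, gamma)
```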
@@ -10,11 +10,7 @@ import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name
-from .cpp_extensions import cast_fp8, transpose, cast_transpose, dbias_cast_transpose
-from .cpp_extensions import act_lu, act_lu_fp8, dact_lu
-from .cpp_extensions import dact_lu_dbias_cast_transpose, dgated_act_lu_cast_transpose
-from .cpp_extensions import rmsnorm_fwd_fp8, rmsnorm_bwd
-from .cpp_extensions import layernorm_fwd_fp8, layernorm_bwd
+from . import cpp_extensions as tex
from .dot import fp8_dot_impl, get_precision_of_fp8_dot, quantize
from .layernorm import canonicalize_layernorm_type
from .fp8 import FP8Helper, FP8MetaPackage
@@ -40,7 +36,7 @@ def _activation_lu(x: jnp.ndarray, activation_type: Sequence[Union[str, Callable
def _activation_lu_fwd_rule(x, activation_type):
-fwd_output = act_lu(x, activation_type)
+fwd_output = tex.act_lu(x, activation_type)
return fwd_output, (x,)
@@ -48,7 +44,7 @@ def _activation_lu_bwd_rule(activation_type, ctx, g):
x, = ctx
assert x.dtype == g.dtype
-dx = dact_lu(g, x, activation_type)
+dx = tex.dact_lu(g, x, activation_type)
dx = jnp.reshape(dx, x.shape)
return (dx,)
@@ -186,7 +182,7 @@ def _fused_layernorm_fp8_mlp_fwd_rule(
x = with_sharding_constraint_by_logical_axes(x, layernorm_input_axes)
if layernorm_type == 'layernorm':
-ln_out, mu, rsigma, updated_x_amax = layernorm_fwd_fp8(
+ln_out, mu, rsigma, updated_x_amax = tex.layernorm_fwd_fp8(
x,
gamma,
beta,
@@ -199,7 +195,7 @@ def _fused_layernorm_fp8_mlp_fwd_rule(
else:
assert not zero_centered_gamma, "zero_centered_gamma is not supported " \
"if layernorm_type is 'rmsnorm'"
-ln_out, rsigma, updated_x_amax = rmsnorm_fwd_fp8(x,
+ln_out, rsigma, updated_x_amax = tex.rmsnorm_fwd_fp8(x,
gamma,
x_amax,
x_scale,
@@ -217,7 +213,7 @@ def _fused_layernorm_fp8_mlp_fwd_rule(
# Note (Ming Huang): Use cast only to let XLA handle the transpose, avoiding an
# unnecessary copy that would break FP8 GEMM pattern matching.
casted_kernel_1, updated_kernel_1_amax = \
-cast_fp8(kernel_1, kernel_1_amax, kernel_1_scale, kernel_1_scale_inv, fwd_dtype)
+tex.cast_fp8(kernel_1, kernel_1_amax, kernel_1_scale, kernel_1_scale_inv, fwd_dtype)
ln_out = with_sharding_constraint_by_logical_axes(ln_out, dot_1_input_axes)
@@ -239,7 +235,7 @@ def _fused_layernorm_fp8_mlp_fwd_rule(
# (batch..., hidden_in) -> (batch..., hidden)
casted_activation_lu_out, updated_activation_lu_amax = \
-act_lu_fp8(dot_1_output, activation_lu_out_amax, activation_lu_out_scale,
+tex.act_lu_fp8(dot_1_output, activation_lu_out_amax, activation_lu_out_scale,
activation_lu_out_scale_inv, fwd_dtype, activation_type)
casted_activation_lu_out = with_sharding_constraint_by_logical_axes(
@@ -304,20 +300,20 @@ def _fused_layernorm_fp8_mlp_bwd_rule(
grad = with_sharding_constraint_by_logical_axes(grad, dot_1_input_axes)
if use_bias:
casted_grad, casted_grad_t, dbias_2, updated_grad_amax = \
-dbias_cast_transpose(grad, grad_amax, grad_scale,
+tex.dbias_cast_transpose(grad, grad_amax, grad_scale,
grad_scale_inv, bwd_dtype,
static_axis_boundary=-1,
transpose_axis_boundary=-1)
dbias_2 = jnp.reshape(dbias_2, bias_2_shape)
else:
casted_grad, casted_grad_t, updated_grad_amax = \
-cast_transpose(grad, grad_amax, grad_scale,
+tex.cast_transpose(grad, grad_amax, grad_scale,
grad_scale_inv, bwd_dtype,
static_axis_boundary=-1,
transpose_axis_boundary=-1)
dbias_2 = None
-casted_activation_lu_out_t = transpose(casted_activation_lu_out,
+casted_activation_lu_out_t = tex.transpose(casted_activation_lu_out,
static_axis_boundary=-1,
transpose_axis_boundary=-1)
@@ -341,9 +337,9 @@ def _fused_layernorm_fp8_mlp_bwd_rule(
if len(activation_type) > 1: # if gated
if use_bias:
-dactivation_lu = dact_lu(dgrad_2, dot_1_output, activation_type)
+dactivation_lu = tex.dact_lu(dgrad_2, dot_1_output, activation_type)
casted_dactivation_lu, casted_dactivation_lu_t, dbias_1, updated_dactivation_lu_amax = \
-dbias_cast_transpose(
+tex.dbias_cast_transpose(
dactivation_lu,
dactivation_lu_amax,
dactivation_lu_scale,
@@ -354,7 +350,7 @@ def _fused_layernorm_fp8_mlp_bwd_rule(
dbias_1 = jnp.reshape(dbias_1, bias_1_shape)
else:
casted_dactivation_lu, casted_dactivation_lu_t, updated_dactivation_lu_amax = \
-dgated_act_lu_cast_transpose(
+tex.dgated_act_lu_cast_transpose(
dgrad_2,
dot_1_output,
dactivation_lu_amax,
@@ -367,7 +363,7 @@ def _fused_layernorm_fp8_mlp_bwd_rule(
else:
if use_bias:
casted_dactivation_lu, casted_dactivation_lu_t, dbias_1, updated_dactivation_lu_amax=\
-dact_lu_dbias_cast_transpose(
+tex.dact_lu_dbias_cast_transpose(
dgrad_2,
dot_1_output,
dactivation_lu_amax,
@@ -379,9 +375,9 @@ def _fused_layernorm_fp8_mlp_bwd_rule(
activation_type=activation_type)
dbias_1 = jnp.reshape(dbias_1, bias_1_shape)
else:
-dactivation_lu = dact_lu(dgrad_2, dot_1_output, activation_type)
+dactivation_lu = tex.dact_lu(dgrad_2, dot_1_output, activation_type)
casted_dactivation_lu, casted_dactivation_lu_t, updated_dactivation_lu_amax = \
-cast_transpose(
+tex.cast_transpose(
dactivation_lu,
dactivation_lu_amax,
dactivation_lu_scale,
@@ -391,7 +387,7 @@ def _fused_layernorm_fp8_mlp_bwd_rule(
transpose_axis_boundary=-2)
dbias_1 = None
-ln_out_t = transpose(ln_out, static_axis_boundary=-1, transpose_axis_boundary=-1)
+ln_out_t = tex.transpose(ln_out, static_axis_boundary=-1, transpose_axis_boundary=-1)
# (hidden, batch...) x (hidden, batch...)
gemm1_x_scale_inv = scale_inv_list_1[FP8MetaPackage.INPUT_IDX]
@@ -410,7 +406,7 @@ def _fused_layernorm_fp8_mlp_bwd_rule(
dgrad_1 = with_sharding_constraint_by_logical_axes(dgrad_1, layernorm_input_axes)
if layernorm_type == 'layernorm':
-dx, dgamma, dbeta = layernorm_bwd(dgrad_1,
+dx, dgamma, dbeta = tex.layernorm_bwd(dgrad_1,
x,
mu,
rsigma,
@@ -420,7 +416,7 @@ def _fused_layernorm_fp8_mlp_bwd_rule(
else:
assert not zero_centered_gamma, "zero_centered_gamma is not supported " \
"if layernorm_type is 'rmsnorm'"
-dx, dgamma = rmsnorm_bwd(dgrad_1, x, rsigma, gamma, epsilon=epsilon)
+dx, dgamma = tex.rmsnorm_bwd(dgrad_1, x, rsigma, gamma, epsilon=epsilon)
dbeta = None
amax_list_1[FP8MetaPackage.INPUT_IDX] = \
......
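The amax/scale bookkeeping threaded through this file is the standard FP8 delayed-scaling recipe: each cast quantizes with the scale derived from previous steps' amax history and returns an updated amax that is written back into amax_list for the next step. A pure-JAX stand-in for the core of what tex.cast_fp8 computes (illustrative only: the real primitive is a fused kernel with a richer signature, and float8_e4m3fn requires a recent JAX build):

```python
import jax.numpy as jnp

E4M3_MAX = 448.0    # largest finite float8_e4m3fn magnitude

def cast_fp8_reference(x, scale):
    """Quantize x with the current scale; report this step's amax for the history."""
    y = jnp.clip(x * scale, -E4M3_MAX, E4M3_MAX).astype(jnp.float8_e4m3fn)
    updated_amax = jnp.max(jnp.abs(x)).reshape((1,))   # written back into amax_list
    return y, updated_amax
```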
@@ -18,7 +18,7 @@ from ..flax.transformer import DotProductAttention as flax_DotProductAttention
from ..flax.transformer import MultiHeadAttention as flax_MultiHeadAttention
from ..flax.transformer import RelativePositionBiases as flax_RelativePositionBiases
from ..flax.transformer import TransformerLayer as flax_TransformerLayer
-from ..fused_attn import AttnBiasType, AttnMaskType
+from ..attention import AttnBiasType, AttnMaskType
class RelativePositionBiases(TransformerEngineBaseLayer):
......
@@ -9,15 +9,7 @@ from typing import Optional
import jax
import jax.numpy as jnp
-from .cpp_extensions import scaled_softmax_fwd
-from .cpp_extensions import scaled_softmax_bwd
-from .cpp_extensions import scaled_masked_softmax_fwd
-from .cpp_extensions import scaled_masked_softmax_bwd
-from .cpp_extensions import scaled_upper_triang_masked_softmax_fwd
-from .cpp_extensions import scaled_upper_triang_masked_softmax_bwd
-from .cpp_extensions import ScaledSoftmaxFwdPrimitive
-from .cpp_extensions import ScaledMaskedSoftmaxFwdPrimitive
-from .cpp_extensions import ScaledUpperTriangMaskedSoftmaxFwdPrimitive
+from . import cpp_extensions as tex
class SoftmaxType(Enum):
@@ -27,22 +19,6 @@ class SoftmaxType(Enum):
SCALED_UPPER_TRIANG_MASKED = "scaled_upper_triang_masked"
-def is_softmax_kernel_available(softmax_type: SoftmaxType, batch: int, heads: int, q_seqlen: int,
-k_seqlen: int, dtype: jnp.dtype):
-"""check softmax available"""
-if softmax_type is SoftmaxType.SCALED:
-return ScaledSoftmaxFwdPrimitive.is_kernel_available(batch, heads, q_seqlen, k_seqlen,
-dtype)
-if softmax_type is SoftmaxType.SCALED_MASKED:
-return ScaledMaskedSoftmaxFwdPrimitive.is_kernel_available(batch, heads, q_seqlen, k_seqlen,
-dtype)
-if softmax_type is SoftmaxType.SCALED_UPPER_TRIANG_MASKED:
-return ScaledUpperTriangMaskedSoftmaxFwdPrimitive.is_kernel_available(
-batch, heads, q_seqlen, k_seqlen, dtype)
-raise NotImplementedError
def softmax(logits: jnp.ndarray,
mask: Optional[jnp.ndarray] = None,
scale_factor: Optional[float] = 1.0,
@@ -64,11 +40,11 @@ def _softmax(logits, mask, scale_factor, softmax_type):
def _softmax_fwd_rule(logits, mask, scale_factor, softmax_type):
if softmax_type is SoftmaxType.SCALED_MASKED:
assert mask is not None
-output = scaled_masked_softmax_fwd(logits, mask, scale_factor)
+output = tex.scaled_masked_softmax_fwd(logits, mask, scale_factor)
elif softmax_type is SoftmaxType.SCALED_UPPER_TRIANG_MASKED:
-output = scaled_upper_triang_masked_softmax_fwd(logits, scale_factor)
+output = tex.scaled_upper_triang_masked_softmax_fwd(logits, scale_factor)
else:
-output = scaled_softmax_fwd(logits, scale_factor)
+output = tex.scaled_softmax_fwd(logits, scale_factor)
return output, (output,)
@@ -77,11 +53,11 @@ def _softmax_bwd_rule(scale_factor, softmax_type, ctx, dz):
softmax_output, = ctx
if softmax_type is SoftmaxType.SCALED_MASKED:
-dgrad = scaled_masked_softmax_bwd(dz, softmax_output, scale_factor)
+dgrad = tex.scaled_masked_softmax_bwd(dz, softmax_output, scale_factor)
elif softmax_type is SoftmaxType.SCALED_UPPER_TRIANG_MASKED:
-dgrad = scaled_upper_triang_masked_softmax_bwd(dz, softmax_output, scale_factor)
+dgrad = tex.scaled_upper_triang_masked_softmax_bwd(dz, softmax_output, scale_factor)
else:
-dgrad = scaled_softmax_bwd(dz, softmax_output, scale_factor)
+dgrad = tex.scaled_softmax_bwd(dz, softmax_output, scale_factor)
return (dgrad, None)
......
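Semantically, the fused kernels dispatched above compute an ordinary scaled (optionally masked) softmax; the fusion is purely a performance optimization. A pure-JAX reference sketch, where the mask convention (True meaning "drop this position") is an assumption here, not taken from the source:

```python
import jax
import jax.numpy as jnp

def scaled_masked_softmax_reference(logits, mask, scale_factor):
    x = logits * scale_factor
    if mask is not None:
        x = jnp.where(mask, -1e10, x)   # assumed convention: True = masked out
    return jax.nn.softmax(x, axis=-1)

def scaled_upper_triang_masked_softmax_reference(logits, scale_factor):
    # Causal variant: query position q attends only to key positions k <= q.
    causal = jnp.tril(jnp.ones(logits.shape[-2:], dtype=bool))
    return scaled_masked_softmax_reference(logits, ~causal, scale_factor)
```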