"CPPLINT.cfg" did not exist on "64a8dc900840e89ffd17e1536b377e3c32f26d93"
Unverified commit 5986342a authored by Phuong Nguyen, committed by GitHub

[JAX] Splitting cpp_extensions.py (#899)



* Split cpp_extensions.py into submodules; renamed mlp.py to layernorm_mlp.py and fused_attn.py to attention.py
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* fixed import in tests
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

---------
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
parent b5a7c9f9
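For reference, a minimal sketch of how call sites migrate (module and symbol names taken from the diffs below):

    # before
    from transformer_engine.jax.cpp_extensions import cast_transpose
    from transformer_engine.jax.mlp import fused_layernorm_fp8_mlp
    from transformer_engine.jax.fused_attn import fused_attn_qkvpacked

    # after
    from transformer_engine.jax import cpp_extensions as tex    # e.g. tex.cast_transpose(...)
    from transformer_engine.jax.layernorm_mlp import fused_layernorm_fp8_mlp
    from transformer_engine.jax.attention import fused_attn_qkvpacked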
@@ -15,13 +15,25 @@ from jax import jit, value_and_grad
 from flax import linen as nn
 from utils import assert_allclose
-from transformer_engine.jax.dot import type_safe_dot_general, dequantize, quantize
-from transformer_engine.jax.fp8 import FP8MetaPackage, FP8Helper
-from transformer_engine.jax.fp8 import is_fp8_available
-from transformer_engine.jax.layernorm import layernorm, layernorm_fp8_dot
-from transformer_engine.jax.mlp import activation_lu, fused_layernorm_fp8_mlp
-from transformer_engine.jax.cpp_extensions import act_lu_fp8, dact_lu_dbias_cast_transpose
-from transformer_engine.jax.cpp_extensions import dgated_act_lu_cast_transpose
+from transformer_engine.jax.dot import (
+    type_safe_dot_general,
+    dequantize,
+    quantize
+)
+from transformer_engine.jax.fp8 import (
+    FP8MetaPackage,
+    FP8Helper,
+    is_fp8_available
+)
+from transformer_engine.jax.layernorm import (
+    layernorm,
+    layernorm_fp8_dot
+)
+from transformer_engine.jax.layernorm_mlp import (
+    activation_lu,
+    fused_layernorm_fp8_mlp
+)
+from transformer_engine.jax import cpp_extensions as tex

 GEMM_CASES = [
     (256, 256, 512),
@@ -429,7 +441,7 @@ class TestActivationLuFP8(TestActivationLu):
             return output

         def _prim_func_fwd(x, _x_t, _dbias, _amax):
-            activation_lu_out, _ = act_lu_fp8(x, amax, scale, scale_inv,
+            activation_lu_out, _ = tex.act_lu_fp8(x, amax, scale, scale_inv,
                                               FP8Helper.FWD_DTYPE, activation_type)
             activation_lu_out = dequantize(activation_lu_out, x.dtype, scale_inv)
             ctx = (x)
@@ -439,12 +451,12 @@ class TestActivationLuFP8(TestActivationLu):
             x = ctx
             if len(self.activation_type) > 1:    #gated, no bias
                 dactivation_lu, dactivation_lu_trans, amax_out = \
-                    dgated_act_lu_cast_transpose(g, x, amax, scale, scale_inv,
+                    tex.dgated_act_lu_cast_transpose(g, x, amax, scale, scale_inv,
                                                  FP8Helper.BWD_DTYPE, -1, activation_type)
                 dbias = jnp.empty(x.shape[-1], x.dtype)
             else:    #not gated, with bias
                 dactivation_lu, dactivation_lu_trans, dbias, amax_out = \
-                    dact_lu_dbias_cast_transpose(g, x, amax, scale, scale_inv, FP8Helper.BWD_DTYPE,
+                    tex.dact_lu_dbias_cast_transpose(g, x, amax, scale, scale_inv, FP8Helper.BWD_DTYPE,
                                                  -1, -2, self.activation_type)
             dactivation_lu = dequantize(dactivation_lu, x.dtype, scale_inv)
             dactivation_lu_trans = dequantize(dactivation_lu_trans, x.dtype, scale_inv)
......
@@ -9,15 +9,27 @@ import jax.numpy as jnp
 import numpy as np
 from flax.linen import dot_product_attention
 from jax import random
-from jax.sharding import Mesh, NamedSharding, PartitionSpec
-from distributed_test_base import generate_configs, generate_collectives_count
-from distributed_test_base import compare_ops
+from jax.sharding import (
+    Mesh,
+    NamedSharding,
+    PartitionSpec
+)
+from distributed_test_base import (
+    generate_configs,
+    generate_collectives_count,
+    compare_ops
+)
 from utils import make_causal_mask, make_self_mask
 from transformer_engine.jax import fp8_autocast
-from transformer_engine.jax.fused_attn import is_fused_attn_kernel_available
-from transformer_engine.jax.fused_attn import fused_attn_qkvpacked, fused_attn_kvpacked
-from transformer_engine.jax.fused_attn import AttnBiasType, AttnMaskType, QKVLayout
+from transformer_engine.jax.attention import (
+    is_fused_attn_kernel_available,
+    fused_attn_qkvpacked,
+    fused_attn_kvpacked,
+    AttnBiasType,
+    AttnMaskType,
+    QKVLayout
+)

 DTYPES = [jnp.float16, jnp.bfloat16]
......
@@ -13,12 +13,20 @@ from transformer_engine.jax.fp8 import FP8MetaPackage, FP8Helper
 from transformer_engine.jax.fp8 import is_fp8_available
 from transformer_engine.jax import fp8_autocast
 from transformer_engine.jax.flax import LayerNormMLP
-from transformer_engine.jax.mlp import fused_layernorm_fp8_mlp
-from transformer_engine.jax.sharding import HIDDEN_AXES, HIDDEN_TP_AXES, \
-    BATCH_AXES, SEQLEN_TP_AXES, SEQLEN_AXES, \
-    W_NO_SHARD_AXES, W_FSDP_AXES, W_TP_AXES, W_JOINED_AXES
+from transformer_engine.jax.layernorm_mlp import fused_layernorm_fp8_mlp
+from transformer_engine.jax.sharding import (
+    HIDDEN_AXES, HIDDEN_TP_AXES,
+    BATCH_AXES,
+    SEQLEN_TP_AXES, SEQLEN_AXES,
+    W_NO_SHARD_AXES, W_FSDP_AXES, W_TP_AXES, W_JOINED_AXES
+)
 from transformer_engine.jax.sharding import MeshResource
-from utils import assert_allclose, assert_tree_like_allclose, is_devices_enough
+from utils import (
+    assert_allclose,
+    assert_tree_like_allclose,
+    is_devices_enough
+)

 is_fp8_supported, reason = is_fp8_available()
 DTYPES = [jnp.bfloat16, jnp.float16]
......
@@ -18,8 +18,14 @@ from jax import Array
 from jax import value_and_grad, jit
 from jax.typing import ArrayLike, DTypeLike

-from transformer_engine.jax.fused_attn import AttnBiasType, AttnMaskType, QKVLayout
-from transformer_engine.jax.fused_attn import fused_attn_qkvpacked, fused_attn_kvpacked, fused_attn
+from transformer_engine.jax.attention import (
+    AttnBiasType,
+    AttnMaskType,
+    QKVLayout,
+    fused_attn_qkvpacked,
+    fused_attn_kvpacked,
+    fused_attn
+)
 from transformer_engine.jax.cpp_extensions import FusedAttnHelper
 from transformer_engine.transformer_engine_jax import NVTE_Fused_Attn_Backend
......
@@ -16,7 +16,7 @@ from jax.typing import DTypeLike
 from utils import assert_allclose

-from transformer_engine.jax.softmax import is_softmax_kernel_available
+from transformer_engine.jax.cpp_extensions import is_softmax_kernel_available
 from transformer_engine.jax.softmax import SoftmaxType, softmax
......
@@ -13,10 +13,7 @@ from transformer_engine.transformer_engine_jax import NVTE_Bias_Type
 from transformer_engine.transformer_engine_jax import NVTE_Mask_Type
 from transformer_engine.transformer_engine_jax import NVTE_QKV_Layout

-from .cpp_extensions import FusedAttnHelper
-from .cpp_extensions import fused_attn_fwd_kvpacked, fused_attn_bwd_kvpacked
-from .cpp_extensions import fused_attn_fwd_qkvpacked, fused_attn_bwd_qkvpacked
-from .cpp_extensions import fused_attn_fwd, fused_attn_bwd
+from . import cpp_extensions as tex


 class AttnBiasType(Enum):
@@ -75,7 +72,7 @@ def is_fused_attn_kernel_available(q_dtype, kv_dtype, qkv_layout, attn_bias_type
     """
     To check whether the fused attention kernel is supported
     """
-    return FusedAttnHelper(q_dtype, kv_dtype, qkv_layout.value, attn_bias_type.value,
+    return tex.FusedAttnHelper(q_dtype, kv_dtype, qkv_layout.value, attn_bias_type.value,
                            attn_mask_type.value, dropout_probability, q_num_heads, kv_num_heads,
                            q_max_seqlen, kv_max_seqlen, head_dim).is_fused_attn_kernel_available()
@@ -123,7 +120,7 @@ def _fused_attn_fwd_qkvpacked_rule(qkv: jnp.ndarray, bias: jnp.ndarray | None, m
         assert mask is not None
         mask = jnp.logical_not(mask)
         actual_seqlen = jnp.sum(mask, axis=-2, dtype=jnp.int32)[..., 0, 0]    # shape = (b,)
-    output, softmax_aux, rng_state = fused_attn_fwd_qkvpacked(
+    output, softmax_aux, rng_state = tex.fused_attn_fwd_qkvpacked(
         qkv,
         bias,
         actual_seqlen,
@@ -143,7 +140,7 @@ def _fused_attn_bwd_qkvpacked_rule(attn_bias_type, attn_mask_type, scaling_facto
                                    dropout_probability, is_training, ctx, dz):
     qkv, bias, softmax_aux, rng_state, output, actual_seqlen = ctx

-    grad_qkv, grad_bias = fused_attn_bwd_qkvpacked(qkv,
+    grad_qkv, grad_bias = tex.fused_attn_bwd_qkvpacked(qkv,
                                                    bias,
                                                    softmax_aux,
                                                    rng_state,
@@ -216,7 +213,7 @@ def _fused_attn_fwd_kvpacked_rule(q, kv, bias, mask, seed, attn_bias_type, attn_
         # When mask is causal, the actual seqlen is not the last row, use max to find it
         kv_actual_seqlen = jnp.max(jnp.sum(mask, axis=-1, dtype=jnp.int32), axis=(-1, -2))

-    output, softmax_aux, rng_state = fused_attn_fwd_kvpacked(
+    output, softmax_aux, rng_state = tex.fused_attn_fwd_kvpacked(
         q,
         kv,
         bias,
@@ -238,7 +235,7 @@ def _fused_attn_bwd_kvpacked_rule(attn_bias_type, attn_mask_type, scaling_factor
                                   dropout_probability, is_training, ctx, dz):
     q, kv, bias, softmax_aux, rng_state, output, q_actual_seqlen, kv_actual_seqlen = ctx

-    grad_q, grad_kv, grad_bias = fused_attn_bwd_kvpacked(q,
+    grad_q, grad_kv, grad_bias = tex.fused_attn_bwd_kvpacked(q,
                                                          kv,
                                                          bias,
                                                          softmax_aux,
@@ -312,7 +309,7 @@ def _fused_attn_fwd_rule(q, k, v, bias, mask, seed, attn_bias_type, attn_mask_ty
         # When mask is causal, the actual seqlen is not the last row, use max to find it
         kv_actual_seqlen = jnp.max(jnp.sum(mask, axis=-1, dtype=jnp.int32), axis=(-1, -2))

-    output, softmax_aux, rng_state = fused_attn_fwd(q,
+    output, softmax_aux, rng_state = tex.fused_attn_fwd(q,
                                                     k,
                                                     v,
                                                     bias,
@@ -335,7 +332,7 @@ def _fused_attn_bwd_rule(attn_bias_type, attn_mask_type, scaling_factor, dropout
                          is_training, ctx, dz):
     q, k, v, bias, softmax_aux, rng_state, output, q_actual_seqlen, kv_actual_seqlen = ctx

-    grad_q, grad_k, grad_v, grad_bias = fused_attn_bwd(q,
+    grad_q, grad_k, grad_v, grad_bias = tex.fused_attn_bwd(q,
                                                        k,
                                                        v,
                                                        bias,
......
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Python interface for c++ extensions"""
from .activation import *
from .attention import *
from .normalization import *
from .quantization import *
from .softmax import *
from .transpose import *
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""JAX/TE base custom ops"""
from abc import ABCMeta, abstractmethod
from functools import partial
from jax import core
from jax.interpreters import xla, mlir
from jax.experimental.custom_partitioning import custom_partitioning
from jax._src.interpreters import batching
from jax._src import dispatch
class BasePrimitive(metaclass=ABCMeta):
    """
    jax primitive
    """

    @staticmethod
    @abstractmethod
    def abstract():
        """
        to describe computing graph
        """
        return NotImplemented

    @classmethod
    def outer_abstract(cls, *args, **kwargs):
        """
        optional abstract wrapper to eliminate workspace tensors
        """
        return cls.abstract(*args, **kwargs)

    @staticmethod
    @abstractmethod
    def lowering():
        """
        to describe MLIR
        """
        return NotImplemented

    @staticmethod
    @abstractmethod
    def impl():
        """
        to describe implementation
        """
        return NotImplemented

    @staticmethod
    @abstractmethod
    def batcher():
        """
        to describe batch rules for vmap
        """
        return NotImplemented

    @staticmethod
    @abstractmethod
    def infer_sharding_from_operands():
        """
        to describe infer_sharding_from_operands for custom_partitioning
        """
        return NotImplemented

    @staticmethod
    @abstractmethod
    def partition():
        """
        to describe partition for custom_partitioning
        """
        return NotImplemented
def register_primitive(cls):
    """
    register jax primitive
    """

    def name_of_wrapper_p():
        return cls.name + "_wrapper"
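    # Inner primitive: abstract evaluation comes from cls.abstract and the CUDA
    # lowering from cls.lowering; the concrete impl methods are expected to bind it.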
    inner_p = core.Primitive(cls.name)
    dispatch.prim_requires_devices_during_lowering.add(inner_p)
    inner_p.multiple_results = cls.multiple_results
    inner_p.def_impl(partial(xla.apply_primitive, inner_p))
    inner_p.def_abstract_eval(cls.abstract)
    mlir.register_lowering(inner_p, cls.lowering, platform='cuda')
    cls.inner_primitive = inner_p
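    # Outer ("wrapper") primitive: the public entry point. It uses the
    # workspace-free outer_abstract, gets vmap support from cls.batcher, and is
    # lowered through custom_partitioning so sharding follows
    # cls.infer_sharding_from_operands and cls.partition.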
    outer_p = core.Primitive(name_of_wrapper_p())
    dispatch.prim_requires_devices_during_lowering.add(outer_p)
    outer_p.multiple_results = cls.multiple_results
    outer_p.def_impl(cls.impl)
    outer_p.def_abstract_eval(cls.outer_abstract)
    batching.primitive_batchers[outer_p] = cls.batcher
    outer_p_lower = custom_partitioning(cls.impl, static_argnums=cls.impl_static_args)
    outer_p_lower.def_partition(infer_sharding_from_operands=cls.infer_sharding_from_operands,
                                partition=cls.partition)
    mlir.register_lowering(outer_p,
                           mlir.lower_fun(outer_p_lower, multiple_results=cls.multiple_results))
    cls.outer_primitive = outer_p
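# A concrete op subclasses BasePrimitive, supplies the class attributes read by
# register_primitive (name, multiple_results, impl_static_args) along with the
# abstract/lowering/impl/batcher/sharding methods above, and is then wired up
# with register_primitive(ThatPrimitive), typically once at module import time.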
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""JAX/TE custom call"""
from dataclasses import dataclass
from jax.lib import xla_client
from jax.interpreters import mlir
from transformer_engine import transformer_engine_jax
try:
    from jaxlib.hlo_helpers import custom_call
except ImportError:
    # Newer JAX changed its API. But we want to support a few JAX
    # versions, so we still need this import.
    pass
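# Register every Transformer Engine custom-call target with XLA so the CUDA
# lowerings emitted by the primitives can resolve them by name.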
for _name, _value in transformer_engine_jax.registrations().items():
    xla_client.register_custom_call_target(_name, _value, platform="CUDA")
@dataclass
class CustomCallArgsWrapper:
    """
    wrapper of XLA custom call args
    """

    def __init__(self,
                 output_types,
                 operands,
                 operand_shapes,
                 operand_specific_layouts=None,
                 output_specific_layouts=None):
        self.output_types = output_types
        self.operands = operands
        self.operand_layouts = CustomCallArgsWrapper.generate_layouts(operand_shapes,
                                                                      operand_specific_layouts)
        output_shapes = [x.shape for x in output_types]
        self.output_layouts = CustomCallArgsWrapper.generate_layouts(output_shapes,
                                                                     output_specific_layouts)
    @staticmethod
    def generate_layouts(shapes, specific_layouts):
        """
        setup layouts for XLA custom call
        """

        def default_layout(shape):
            return range(len(shape) - 1, -1, -1)

        if specific_layouts is None:
            specific_layouts = {}

        layouts = []
        for idx, shape in enumerate(shapes):
            if idx in specific_layouts:
                layouts.append(specific_layouts[idx])
            else:
                layouts.append(default_layout(shape))
        return layouts
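# custom_caller dispatches to whichever custom-call API the installed JAX
# provides: newer releases expose mlir.custom_call, older ones fall back to the
# jaxlib.hlo_helpers.custom_call imported above.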
def custom_caller(name, args, opaque, has_side_effect, **kwargs):
    """
    XLA custom call wrapper
    """
    if hasattr(mlir, "custom_call"):
        out = mlir.custom_call(name,
                               result_types=args.output_types,
                               operands=args.operands,
                               operand_layouts=args.operand_layouts,
                               result_layouts=args.output_layouts,
                               backend_config=opaque,
                               has_side_effect=has_side_effect,
                               **kwargs).results
    else:
        # Need to disable one pylint error, as the second function
        # parameter name changed recently in JAX. Otherwise we won't be
        # compatible with multiple JAX versions.
        out = custom_call(    # pylint: disable=too-many-function-args
            name,
            args.output_types,
            operands=args.operands,
            operand_layouts=args.operand_layouts,
            result_layouts=args.output_layouts,
            backend_config=opaque,
            has_side_effect=has_side_effect,
            **kwargs)
    return out
@@ -8,7 +8,7 @@ from functools import partial
 import jax
 import jax.numpy as jnp

-from .cpp_extensions import cast_transpose
+from . import cpp_extensions as tex
 from .fp8 import FP8Helper, FP8MetaPackage

 Precision = jax.lax.Precision
@@ -148,7 +148,7 @@ def _fp8_dot_bwd_rule(fwd_dtype, bwd_dtype, contracting_dims, ctx, grad):    # p
     grad_scale_inv = scale_inv_list[FP8MetaPackage.GRAD_IDX]

     casted_grad, casted_grad_t, updated_grad_amax = \
-        cast_transpose(grad, grad_amax, grad_scale, grad_scale_inv,
+        tex.cast_transpose(grad, grad_amax, grad_scale, grad_scale_inv,
                        bwd_dtype, static_axis_boundary=-1,
                        transpose_axis_boundary=min(lhs_contracting_dims))
......