OpenDAS / TransformerEngine · Commits

Commit a207db1d
Authored Apr 01, 2025 by yuguo

    Merge branch 'main' of https://github.com/NVIDIA/TransformerEngine

Parents: fbee8990, 69365f88

Showing 20 changed files with 3693 additions and 4298 deletions (+3693 −4298)
transformer_engine/common/recipe/current_scaling.cu        +4    −3
transformer_engine/jax/__init__.py                         +24   −11
transformer_engine/jax/activation.py                       +98   −0
transformer_engine/jax/attention.py                        +57   −0
transformer_engine/jax/cpp_extensions/__init__.py          +1    −1
transformer_engine/jax/cpp_extensions/activation.py        +952  −340
transformer_engine/jax/cpp_extensions/attention.py         +73   −180
transformer_engine/jax/cpp_extensions/base.py              +13   −0
transformer_engine/jax/cpp_extensions/custom_call.py       +0    −121
transformer_engine/jax/cpp_extensions/gemm.py              +516  −0
transformer_engine/jax/cpp_extensions/misc.py              +89   −14
transformer_engine/jax/cpp_extensions/normalization.py     +937  −1220
transformer_engine/jax/cpp_extensions/quantization.py      +542  −116
transformer_engine/jax/cpp_extensions/softmax.py           +95   −202
transformer_engine/jax/cpp_extensions/transpose.py         +0    −1270
transformer_engine/jax/csrc/extensions.h                   +28   −210
transformer_engine/jax/csrc/extensions/activation.cpp      +235  −537
transformer_engine/jax/csrc/extensions/attention.cpp       +0    −72
transformer_engine/jax/csrc/extensions/cublas.cpp          +23   −0
transformer_engine/jax/csrc/extensions/ffi.cpp             +6    −1
transformer_engine/common/recipe/current_scaling.cu

...
@@ -201,8 +201,9 @@ void nvte_compute_scale_from_amax(NVTETensor output_, const NVTEQuantizationConf
       max_fp8 = Quantized_Limits<DType>::max_norm;);
   // Update scale
-  compute_scale_from_amax_kernel<<<1, 1>>>(
-      reinterpret_cast<const float *>(output.amax.dptr),
-      reinterpret_cast<float *>(output.scale.dptr), max_fp8, config.force_pow_2_scales, config.amax_epsilon);
+  compute_scale_from_amax_kernel<<<1, 1, 0, stream>>>(
+      reinterpret_cast<const float *>(output.amax.dptr),
+      reinterpret_cast<float *>(output.scale.dptr), max_fp8, config.force_pow_2_scales, config.amax_epsilon);
+  NVTE_CHECK_CUDA(cudaGetLastError());
 }
transformer_engine/jax/__init__.py

 # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # See LICENSE for license information.
-"""Transformer Engine bindings for JAX"""
+"""Transformer Engine bindings for JAX.
+
+This module provides JAX bindings for NVIDIA's Transformer Engine, enabling
+high-performance transformer operations with mixed precision and quantization
+support. It includes implementations of key transformer components like attention,
+linear layers, and layer normalization, optimized for NVIDIA GPUs.
+
+The module exports various transformer operations and utilities:
+- Attention mechanisms (self-attention, cross-attention)
+- Linear transformations with optional quantization
+- Layer normalization operations
+- Activation functions
+- Softmax operations
+- Sharding utilities for distributed training
+
+All operations are designed to work seamlessly with JAX's functional programming
+model and support automatic differentiation.
+"""

 # pylint: disable=wrong-import-position,wrong-import-order
+import sys
+import logging
+import importlib
+import importlib.util
 import ctypes
 from importlib.metadata import version
-import sys

 from transformer_engine.common import get_te_path, is_package_installed
 from transformer_engine.common import _get_sys_extension

+_logger = logging.getLogger(__name__)


 def _load_library():
     """Load shared library with Transformer Engine C extensions"""
...
@@ -41,7 +55,7 @@ def _load_library():
     if is_package_installed("transformer-engine-cu12"):
         if not is_package_installed(module_name):
-            _logger.info(
+            logging.info(
                 "Could not find package %s. Install transformer-engine using "
                 "'pip3 install transformer-engine[jax]==VERSION'",
                 module_name,
...
@@ -67,8 +81,10 @@ def _load_library():
 _load_library()
 from . import flax
-from .fp8 import fp8_autocast, update_collections, get_delayed_scaling
-from .fp8 import NVTE_FP8_COLLECTION_NAME
+from . import quantize
+from .quantize import fp8_autocast
 from .sharding import MeshResource
 from .sharding import MajorShardingType, ShardingResource, ShardingType
...
@@ -85,10 +101,7 @@ ShardingResource = deprecate_wrapper(
 )

 __all__ = [
-    "NVTE_FP8_COLLECTION_NAME",
     "fp8_autocast",
-    "update_collections",
-    "get_delayed_scaling",
     "MeshResource",
     "MajorShardingType",
     "ShardingResource",
...
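The re-exported `fp8_autocast` (now sourced from `.quantize`) is the entry point most users touch. A minimal usage sketch, assuming a recent TE build and an FP8-capable GPU; the recipe and usage here are illustrative, not prescriptive:

import transformer_engine.jax as te
from transformer_engine.common import recipe

fp8_recipe = recipe.DelayedScaling()  # amax history + per-tensor scales

with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    # TE layers applied inside the context run their matmuls in FP8;
    # outside the context they fall back to the parameter/input dtype.
    ...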
transformer_engine/jax/activation.py  (new file, 0 → 100644)

# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Activation functions for Transformer Engine in JAX.

This module provides optimized activation functions with quantization support.
"""
from typing import Sequence, Union, Callable, Optional
from functools import partial

import jax
import jax.numpy as jnp

from . import cpp_extensions as tex
from .quantize.tensor import ScaledTensor
from .quantize.quantizer import Quantizer


def activation(
    x: jnp.ndarray,
    activation_type: Sequence[Union[str, Callable]],
    quantizer: Optional[Quantizer] = None,
) -> Union[jnp.ndarray, ScaledTensor]:
    """Apply activation functions to input tensor with optional quantization.

    This function applies a sequence of activation functions to the input tensor.
    It supports string-based activation types (e.g., 'relu', 'gelu', ('gelu', 'linear')).

    Args:
        x: Input tensor to apply activations to
        activation_type: Sequence of activation functions
        quantizer: Optional quantizer for quantizing the output

    Returns:
        Activated output tensor
    """
    assert x.shape[-1] % len(activation_type) == 0
    output = _activation(x, activation_type, quantizer)
    return output


@partial(jax.custom_vjp, nondiff_argnums=(1,))
def _activation(x, activation_type, quantizer):
    """Internal implementation of activation with custom VJP.

    This function implements the core activation logic with support for
    custom vector-Jacobian product (VJP) for automatic differentiation.

    Args:
        x: Input tensor
        activation_type: Sequence of activation functions
        quantizer: Optional quantizer

    Returns:
        Activated tensor
    """
    _output, _ = _activation_fwd_rule(x, activation_type, quantizer)
    return _output


def _activation_fwd_rule(x, activation_type, quantizer):
    """Forward pass rule for activation function.

    Args:
        x: Input tensor
        activation_type: Sequence of activation functions
        quantizer: Optional quantizer

    Returns:
        Tuple of (output, context) for backward pass
    """
    fwd_output = tex.act_lu(x, activation_type, quantizer)
    if isinstance(fwd_output, ScaledTensor):
        fwd_output = fwd_output.dequantize()
    return fwd_output, (x, quantizer)


def _activation_bwd_rule(activation_type, ctx, g):
    """Backward pass rule for activation function.

    Args:
        activation_type: Sequence of activation functions
        ctx: Context from forward pass
        g: Gradient from upstream

    Returns:
        Gradient with respect to input
    """
    (x, _) = ctx
    assert x.dtype == g.dtype
    dx = tex.dact_lu(g, x, activation_type)
    dx = jnp.reshape(dx, x.shape)
    return (dx, None)


_activation.defvjp(_activation_fwd_rule, _activation_bwd_rule)
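To see the custom-VJP plumbing above in action, a small sketch (unquantized path, so `quantizer=None` flows through both rules):

import jax
import jax.numpy as jnp
from transformer_engine.jax.activation import activation

x = jnp.ones((16, 256), jnp.bfloat16)

def loss_fn(inp):
    # forward goes through tex.act_lu via _activation_fwd_rule
    return activation(inp, ("gelu",)).sum()

# backward is routed through _activation_bwd_rule -> tex.dact_lu
dx = jax.grad(loss_fn)(x)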
transformer_engine/jax/attention.py

...
@@ -378,6 +378,44 @@ def _mask_to_seqlens_offset(mask, max_segments_per_seq):
     return q_seqlen, q_offset, kv_seqlen, kv_offset


+def _fast_causal_adjust_seqlen_and_offsets(
+    segment_pos_q, q_len, q_offset, segment_pos_kv, kv_len, kv_offset
+):
+    # The assumption is that for any segment tokens respect causal ordering except at the ends
+    # of the segment. This allows us to tweak the length and offset by only looking at the start
+    # and end tokens between segments.
+    is_active_segment = jnp.logical_and(q_len > 0, kv_len > 0)
+
+    q_seq_id_start = jnp.take(segment_pos_q, q_offset[..., :-1], fill_value=-1)
+    kv_seq_id_start = jnp.take(segment_pos_kv, kv_offset[..., :-1], fill_value=-1)
+    skip_start_token = jnp.logical_and(
+        kv_seq_id_start > q_seq_id_start, is_active_segment
+    ).astype(jnp.int32)
+    q_len -= skip_start_token
+    q_offset += jnp.insert(skip_start_token, skip_start_token.shape[-1], 0, axis=-1)
+
+    q_seq_id_end = jnp.take(segment_pos_q, q_offset[..., 1:] - 1, fill_value=-1)
+    kv_seq_id_end = jnp.take(segment_pos_kv, kv_offset[..., 1:] - 1, fill_value=-1)
+    skip_end_token = jnp.logical_and(
+        kv_seq_id_end > q_seq_id_end, is_active_segment
+    ).astype(jnp.int32)
+    kv_len -= skip_end_token
+
+    return q_len, kv_len, q_offset, kv_offset
+
+
+def _segment_ids_pos_to_seqlens_offsets_fast_causal_path(
+    segment_ids_q, segment_ids_kv, segment_pos_q, segment_pos_kv, max_segments_per_seq
+):
+    q_len, q_offset = _get_seqlens_and_offsets(segment_ids_q, max_segments_per_seq)
+    kv_len, kv_offset = _get_seqlens_and_offsets(segment_ids_kv, max_segments_per_seq)
+    return _fast_causal_adjust_seqlen_and_offsets(
+        segment_pos_q, q_len, q_offset, segment_pos_kv, kv_len, kv_offset
+    )
+
+
 def _segment_ids_pos_to_seqlens_offsets(
     segment_ids_q,
     segment_ids_kv,
...
@@ -387,6 +425,25 @@ def _segment_ids_pos_to_seqlens_offsets(
     window_size,
     max_segments_per_seq,
 ):
+    # TODO(mgoldfarb-nvidia): Consider an opt-in for arbitrary masking if needed here.
+    # Computing the full mask is expensive due to quadratic expansion of Q * KV masking.
+
+    # Assumptions for cudnn causal mask correctness.
+    # 1. Segments are monotonic [4 4 4 0 0 5 5 5 6 6 0 0]
+    # 2. No intra-segment padding, only inter-segment padding allowed
+    # 3. Only start or end token within a segment may violate the causal order relationship
+    #      1 5 9 0 4 8 10 0 4 8
+    #    0 x         x
+    #    4 x x x     x x
+    #    8 x x x x   x x x x
+    #
+    # This fast path avoids expanding the mask to Q * KV matrix and instead allows us to
+    # examine only O(Q+KV) elements.
+    if attn_mask_type.is_causal() and window_size is None or window_size == (-1, -1):
+        return _segment_ids_pos_to_seqlens_offsets_fast_causal_path(
+            segment_ids_q, segment_ids_kv, segment_pos_q, segment_pos_kv, max_segments_per_seq
+        )
+
     # (1 = attend, 0 = masked)
     segment_mask = make_attention_mask(
         segment_ids_q,
...
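A toy, self-contained illustration of the segment-id convention the fast path relies on; `segment_lengths_and_offsets` below is a hypothetical stand-in for the module's `_get_seqlens_and_offsets`, shown only to make the O(Q+KV) idea concrete:

import jax.numpy as jnp

def segment_lengths_and_offsets(segment_ids, max_segments):
    # Lengths and start offsets of each non-zero run in a monotonic id vector,
    # computed with O(N) elementwise scans (no Q x KV mask materialized).
    n = segment_ids.shape[0]
    prev = jnp.concatenate([jnp.array([-1]), segment_ids[:-1]])
    nxt = jnp.concatenate([segment_ids[1:], jnp.array([-1])])
    is_start = (segment_ids != prev) & (segment_ids != 0)
    is_end = (segment_ids != nxt) & (segment_ids != 0)
    starts = jnp.nonzero(is_start, size=max_segments, fill_value=n)[0]
    ends = jnp.nonzero(is_end, size=max_segments, fill_value=n - 1)[0]
    lengths = jnp.where(starts < n, ends - starts + 1, 0)
    return lengths, starts

ids = jnp.array([4, 4, 4, 0, 0, 5, 5, 5, 6, 6, 0, 0])  # the monotonic example above
print(segment_lengths_and_offsets(ids, max_segments=4))  # lengths [3 3 2 0], offsets [0 5 8 12]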
transformer_engine/jax/cpp_extensions/__init__.py

...
@@ -7,4 +7,4 @@ from .attention import *
 from .normalization import *
 from .quantization import *
 from .softmax import *
-from .transpose import *
+from .gemm import *
transformer_engine/jax/cpp_extensions/activation.py

...
@@ -2,7 +2,7 @@
 #
 # See LICENSE for license information.
 """JAX/TE custom ops for activation"""
-from typing import Tuple, Sequence, Union, Callable
+from typing import Sequence, Union, Callable, Optional, Tuple
 import operator
 from functools import reduce, partial

 from packaging import version
...
@@ -10,31 +10,38 @@ from packaging import version
 import jax
 import jax.numpy as jnp
 from jax import dtypes
-from jax.interpreters.mlir import ir
-from jax.sharding import PartitionSpec, NamedSharding
+from jax.sharding import PartitionSpec

 import transformer_engine_jax
 from transformer_engine_jax import NVTE_Activation_Type

 from .base import BasePrimitive, register_primitive
-from .custom_call import custom_caller, CustomCallArgsWrapper
 from .misc import (
-    check_valid_batch_dims,
     jax_dtype_to_te_dtype,
-    jax_dtype_to_ir_dtype,
+    te_dtype_to_jax_dtype,
     get_padded_spec,
-    is_ffi_enabled,
+    check_valid_batch_dims,
+    multidim_transpose,
+    try_apply_delayed_scaling_2x_war,
+    should_apply_1x_fused_dbias_war_for_arch_l_100,
+    NamedSharding,
 )
-from .quantization import _jax_cast_fp8
-from ..sharding import all_reduce_max_along_all_axes_except_PP
+from .quantization import _jax_quantize_dbias, _jax_dbias, quantize_dbias
+from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp
+from ..quantize import ScaledTensor, ScaledTensorFactory
+from ..quantize import (
+    Quantizer,
+    QuantizeAxis,
+    DelayedScaleQuantizer,
+    ScalingMode,
+)

 if version.parse(jax.__version__) >= version.parse("0.5.0"):
     from jax import ffi  # pylint: disable=ungrouped-imports
 else:
     from jax.extend import ffi  # pylint: disable=ungrouped-imports

-__all__ = ["act_lu", "dact_lu", "act_lu_fp8"]
+__all__ = ["act_lu", "dact_lu", "quantize_dact_dbias"]

 ActivationEnum = {
...
@@ -66,448 +73,1053 @@ def _convert_to_activation_function(fn_or_string):
     raise ValueError(f"Unsupported {fn_or_string} to an activation function")


-def _jax_act_lu(inputs, activation_type):
-    """
-    JAX native activation implementation
-    """
-    x = jnp.split(inputs, len(activation_type), axis=-2)
-    acts = []
-    for idx, act_fn in enumerate(activation_type):
-        x_i = _convert_to_activation_function(act_fn)(x[idx])
-        acts.append(x_i)
-    x = reduce(operator.mul, acts)
-    x = jnp.squeeze(x, axis=-2)
-    return x
-
-
 class ActLuPrimitive(BasePrimitive):
     """
-    Activation Forward Primitive
+    ActLu Primitive
     """

-    name = "te_act_lu"
-    multiple_results = False
+    name = "te_act_lu_ffi"
+    multiple_results = True
+    impl_static_args = (2, 3, 4, 5, 6, 7, 8, 9)
+    # out_dtype, act_enum, act_len, scaling_mode, is_2x, scale_dtype, scale_shapes, is_outer
     inner_primitive = None
     outer_primitive = None
-    impl_static_args = (1,)

     @staticmethod
-    def abstract(x_aval, *, act_enum):  # pylint: disable=unused-argument
+    def abstract(
+        x_aval,
+        scale_aval,
+        *,
+        out_dtype,
+        act_enum,
+        act_len,
+        scaling_mode,
+        is_2x,
+        scale_dtype,
+        scale_shapes,
+        is_outer,
+    ):
         """
-        act_lu abstract
+        te_act_lu_p abstract
         """
+        del act_enum, act_len, scale_shapes
         dtype = dtypes.canonicalize_dtype(x_aval.dtype)
         assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
-        x_shape = x_aval.shape
-        assert x_shape[-2] == 2 or x_shape[-2] == 1
-        hidden_size = x_shape[-1]
-        batch_shapes = x_shape[:-2]
-        out_aval = x_aval
-        out_shape = (batch_shapes) + (hidden_size,)
-        out_aval = out_aval.update(shape=out_shape, dtype=dtype)
+        assert scale_aval is None or scale_aval.dtype == jnp.float32
+        out_shape = (*x_aval.shape[:-2], 1, x_aval.shape[-1])
+        out_aval = x_aval.update(shape=out_shape, dtype=out_dtype)
+        updated_amax_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32)
+        rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
+            scaling_mode
+        ).get_scale_shape_2x(out_shape[:-2] + (out_shape[-1],), is_padded=not is_outer)
+        if len(rowwise_scale_inv_shape) > 1:
+            rowwise_scale_inv_shape = (
+                rowwise_scale_inv_shape[:-1] + (1,) + rowwise_scale_inv_shape[-1:]
+            )
+        if len(colwise_scale_inv_shape) > 1:
+            colwise_scale_inv_shape = (
+                colwise_scale_inv_shape[:-1] + (1,) + colwise_scale_inv_shape[-1:]
+            )
+        scale_inv_aval = jax.core.ShapedArray(shape=rowwise_scale_inv_shape, dtype=scale_dtype)
+        colwise_out_aval = jax.core.ShapedArray(shape=(1,), dtype=out_dtype)
+        colwise_scale_inv_aval = jax.core.ShapedArray(shape=(1,), dtype=scale_dtype)
+        if is_2x:
+            colwise_out_aval = jax.core.ShapedArray(shape=out_shape, dtype=out_dtype)
+            colwise_scale_inv_aval = jax.core.ShapedArray(
+                shape=colwise_scale_inv_shape, dtype=scale_dtype
+            )

-        return out_aval
+        return out_aval, colwise_out_aval, scale_inv_aval, colwise_scale_inv_aval, updated_amax_aval

     @staticmethod
-    def lowering(ctx, x, *, act_enum):
+    def lowering(
+        ctx,
+        x,
+        scale,
+        *,
+        out_dtype,
+        act_enum,
+        act_len,
+        scaling_mode,
+        is_2x,
+        scale_dtype,
+        scale_shapes,
+        is_outer,
+    ):
         """
-        act_lu lowering rules
+        te_gated_act_lu_p lowering rules
         """
-        (x_aval,) = ctx.avals_in
+        del out_dtype, scale_dtype, scale_shapes, act_len, is_outer
+        x_aval, scale_aval = ctx.avals_in
         assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
-        if is_ffi_enabled():
-            name = "te_act_lu_ffi"
-            out = ffi.ffi_lowering(name)(ctx, x, act_enum=act_enum)
-        else:
-            ir_x_type = ir.RankedTensorType(x.type)
-            ir_x_shape = ir_x_type.shape
-            out_shape = ir_x_shape[:-2] + [ir_x_shape[-1]]
-            out_types = [
-                ir.RankedTensorType.get(out_shape, ir_x_type.element_type),
-            ]
-            operands = [x]
-            operand_shapes = [ir_x_shape]
-            args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
-            hidden_size = ir_x_shape[-1]
-            batch_size = reduce(operator.mul, ir_x_shape[:-2])
-            in_dtype = jax_dtype_to_te_dtype(x_aval.dtype)
-            opaque = transformer_engine_jax.pack_common_descriptor(
-                (batch_size, hidden_size), in_dtype, in_dtype, act_enum
-            )
-            out = custom_caller(ActLuPrimitive.name, args, opaque, False)
+        assert scale_aval is None or scale_aval.dtype == jnp.float32
+        out = ffi.ffi_lowering(ActLuPrimitive.name)(
+            ctx, x, scale, act_enum=act_enum, scaling_mode=scaling_mode, is_2x=is_2x
+        )
         return out

     @staticmethod
-    def impl(x, act_enum):
+    def impl(
+        x,
+        scale,
+        out_dtype,
+        act_enum,
+        act_len,
+        scaling_mode,
+        is_2x,
+        scale_dtype,
+        scale_shapes,
+        is_outer,
+    ):
         """
         to describe implementation
         """
+        del is_outer
         assert ActLuPrimitive.inner_primitive is not None
-        out = ActLuPrimitive.inner_primitive.bind(x, act_enum=act_enum)
-        return out
+        out, colwise_out, scale_inv, colwise_scale_inv, updated_amax = (
+            ActLuPrimitive.inner_primitive.bind(
+                x,
+                scale,
+                out_dtype=out_dtype,
+                act_enum=act_enum,
+                act_len=act_len,
+                scaling_mode=scaling_mode,
+                is_2x=is_2x,
+                scale_dtype=scale_dtype,
+                scale_shapes=scale_shapes,
+                is_outer=False,
+            )
+        )
+        rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
+            scaling_mode
+        ).get_scale_shape_2x(out.shape[:-2] + (out.shape[-1],), is_padded=False)
+        if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value:
+            rowwise_scale_inv_shape = (
+                rowwise_scale_inv_shape[:-1] + (1,) + rowwise_scale_inv_shape[-1:]
+            )
+            if is_2x:
+                colwise_scale_inv_shape = (
+                    colwise_scale_inv_shape[:-1] + (1,) + colwise_scale_inv_shape[-1:]
+                )
+            scale_inv = jax.lax.slice(
+                scale_inv, [0] * len(rowwise_scale_inv_shape), rowwise_scale_inv_shape
+            )
+            if is_2x:
+                colwise_scale_inv = jax.lax.slice(
+                    colwise_scale_inv, [0] * len(colwise_scale_inv_shape), colwise_scale_inv_shape
+                )
+        return out, colwise_out, scale_inv, colwise_scale_inv, updated_amax

     @staticmethod
-    def batcher(batched_args, batch_dims, *, act_enum):
+    def batcher(
+        batched_args,
+        batch_dims,
+        *,
+        out_dtype,
+        act_enum,
+        act_len,
+        scaling_mode,
+        is_2x,
+        scale_dtype,
+        scale_shapes,
+        is_outer,
+    ):
         """
-        act_lu batcher
+        to describe batch rules for vmap
         """
+        del act_len, is_outer
         check_valid_batch_dims(batch_dims)
         assert ActLuPrimitive.outer_primitive is not None
-        (inputs,) = batched_args
-        (inputs_bdim,) = batch_dims
+        x, scale = batched_args
+        x_bdim, scale_bdim = batch_dims
+        amax_bdim = scale_bdim

-        out_bdims = inputs_bdim
-        return ActLuPrimitive.outer_primitive.bind(inputs, act_enum=act_enum), out_bdims
+        out_bdims = x_bdim, x_bdim, scale_bdim, scale_bdim, amax_bdim
+        return (
+            ActLuPrimitive.outer_primitive.bind(
+                x,
+                scale,
+                out_dtype=out_dtype,
+                act_enum=act_enum,
+                scaling_mode=scaling_mode,
+                is_2x=is_2x,
+                scale_dtype=scale_dtype,
+                scale_shapes=scale_shapes,
+            ),
+            out_bdims,
+        )

     @staticmethod
-    def infer_sharding_from_operands(act_enum, mesh, arg_infos, result_infos):
-        """
-        act_lu infer_sharding_from_operands
-        """
-        del result_infos, act_enum  # Unused.
-        x_spec = get_padded_spec(arg_infos[0])
-        out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-2], x_spec[-1]))
-        return out_sharding
+    def infer_sharding_from_operands(
+        out_dtype,
+        act_enum,
+        act_len,
+        scaling_mode,
+        is_2x,
+        scale_dtype,
+        scale_shapes,
+        is_outer,
+        mesh,
+        arg_infos,
+        result_infos,
+    ):
+        del (
+            out_dtype,
+            result_infos,
+            act_enum,
+            scale_dtype,
+            scale_shapes,
+            act_len,
+            is_outer,
+        )  # Unused.
+        x_spec = get_padded_spec(arg_infos[0])
+        out_spec = (*x_spec[:-2], None, x_spec[-2])
+        out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec), desc="ActLuPrimitive.out")
+        if is_2x:
+            if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value:
+                colwise_out_spec = multidim_transpose(out_spec)
+            else:
+                colwise_out_spec = out_spec
+        else:
+            colwise_out_spec = (None,)
+        colwise_out_sharding = NamedSharding(
+            mesh, PartitionSpec(*colwise_out_spec), desc="ActLuPrimitive.colwise_out"
+        )
+        scale_inv_sharding = NamedSharding(
+            mesh, PartitionSpec(*get_padded_spec(arg_infos[1])), desc="ActLuPrimitive.scale_inv"
+        )
+        amax_sharding = scale_inv_sharding.duplicate_with_new_description("ActLuPrimitive.amax")
+        if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value:
+            scale_inv_sharding = NamedSharding(
+                mesh, PartitionSpec(*out_spec), desc="ActLuPrimitive.scale_inv"
+            )
+        colwise_scale_inv_sharding = scale_inv_sharding.duplicate_with_new_description(
+            "ActLuPrimitive.colwise_scale_inv"
+        )
+        return (
+            out_sharding,
+            colwise_out_sharding,
+            scale_inv_sharding,
+            colwise_scale_inv_sharding,
+            amax_sharding,
+        )

     @staticmethod
-    def partition(act_enum, mesh, arg_infos, result_infos):
-        """
-        act_lu partitioning
-        """
-        del result_infos
-        x_spec = get_padded_spec(arg_infos[0])
-        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
-        out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-2], x_spec[-1]))
-
-        def sharded_impl(x):
-            return ActLuPrimitive.impl(x, act_enum=act_enum)
-
-        return mesh, sharded_impl, out_sharding, arg_shardings
+    def partition(
+        out_dtype,
+        act_enum,
+        act_len,
+        scaling_mode,
+        is_2x,
+        scale_dtype,
+        scale_shapes,
+        is_outer,
+        mesh,
+        arg_infos,
+        result_infos,
+    ):
+        del result_infos, is_outer  # Unused.
+        x_spec = get_padded_spec(arg_infos[0])
+        out_spec = (*x_spec[:-1], x_spec[-1])
+        if act_len == 2 and x_spec[-1] is None:
+            # Ensure last axis is partitioned and not the gating axis
+            out_spec = (*x_spec[:-2], None, x_spec[-2])
+        out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec), desc="ActLuPrimitive.out")
+        if is_2x:
+            if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value:
+                colwise_out_spec = multidim_transpose(out_spec)
+            else:
+                colwise_out_spec = out_spec
+        else:
+            colwise_out_spec = (None,)
+        colwise_out_sharding = NamedSharding(
+            mesh, PartitionSpec(*colwise_out_spec), desc="ActLuPrimitive.colwise_out"
+        )
+        scale_inv_sharding = NamedSharding(
+            mesh, PartitionSpec(*get_padded_spec(arg_infos[1])), desc="ActLuPrimitive.scale_inv"
+        )
+        amax_sharding = scale_inv_sharding.duplicate_with_new_description("ActLuPrimitive.amax")
+        if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value:
+            scale_inv_sharding = NamedSharding(
+                mesh, PartitionSpec(*out_spec), desc="ActLuPrimitive.scale_inv"
+            )
+        colwise_scale_inv_sharding = scale_inv_sharding.duplicate_with_new_description(
+            "ActLuPrimitive.colwise_scale_inv"
+        )
+        arg_shardings = list(arg_i.sharding for arg_i in arg_infos)
+        arg_shardings[0] = NamedSharding(mesh, PartitionSpec(*out_spec))
+        arg_shardings = tuple(arg_shardings)
+        out_shardings = (
+            out_sharding,
+            colwise_out_sharding,
+            scale_inv_sharding,
+            colwise_scale_inv_sharding,
+            amax_sharding,
+        )
+
+        def sharded_impl(x, scale):
+            local_x, local_colwise_x, local_scale_inv, local_colwise_scale_inv, local_amax = (
+                ActLuPrimitive.impl(
+                    x,
+                    scale,
+                    out_dtype=out_dtype,
+                    act_enum=act_enum,
+                    act_len=act_len,
+                    scaling_mode=scaling_mode,
+                    is_2x=is_2x,
+                    scale_dtype=scale_dtype,
+                    scale_shapes=scale_shapes,
+                    is_outer=True,
+                )
+            )
+            if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value:
+                global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh)
+            else:
+                global_updated_amax = local_amax
+
+            return (
+                local_x,
+                local_colwise_x,
+                local_scale_inv,
+                local_colwise_scale_inv,
+                global_updated_amax,
+            )
+
+        return mesh, sharded_impl, out_shardings, arg_shardings


-def act_lu(inputs: jnp.ndarray, activation_type: Sequence[Union[str, Callable]]) -> jnp.ndarray:
-    """
-    act_lu wrapper
-    Return act_lu(inputs)
-    Input shape: (N, 1, H) for non-gated activations
-                 (N, 2, H) for gated activations
-    """
-    if not ActLuPrimitive.enabled():
-        return _jax_act_lu(inputs, activation_type)
-    act_type_id = ActivationEnum[activation_type].value
-    return ActLuPrimitive.outer_primitive.bind(inputs, act_enum=act_type_id)
-
-
 register_primitive(ActLuPrimitive)
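The extra outputs the rewritten primitive now returns (scale-inverse tensors plus an updated amax) are delayed-scaling bookkeeping. A pure-jnp sketch of that bookkeeping under the standard E4M3 delayed-scaling scheme (illustrative only; TE's actual kernels fuse this with the activation itself):

import jax.numpy as jnp

E4M3_MAX = 448.0  # max normal value of float8_e4m3fn

def delayed_scaling_quantize(x, prev_scale):
    q = (x * prev_scale).astype(jnp.float8_e4m3fn)  # rowwise quantized data
    scale_inv = 1.0 / prev_scale                    # stored for dequantize
    updated_amax = jnp.max(jnp.abs(x))              # fed back to the quantizer
    next_scale = E4M3_MAX / jnp.maximum(updated_amax, 1e-12)
    return q, scale_inv, updated_amax, next_scale

x = jnp.linspace(-3.0, 3.0, 8, dtype=jnp.float32)
q, s_inv, amax, _ = delayed_scaling_quantize(x, prev_scale=jnp.float32(64.0))
x_approx = q.astype(jnp.float32) * s_inv  # dequantized round-trip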
-class DActLuPrimitive(BasePrimitive):
+class DActLuDBiasQuantizePrimitive(BasePrimitive):
     """
-    Dgated ActLu Primitive
+    DActLu DBias Cast Transpose Primitive
     """

-    name = "te_dact_lu"
-    multiple_results = False
+    name = "te_dact_dbias_quantize_ffi"
+    multiple_results = True
+    # out_dtype, scaling_mode, is_2x, scale_dtype, scale_shapes, is_dbias, act_enum, act_len, is_outer
+    impl_static_args = (3, 4, 5, 6, 7, 8, 9, 10, 11)
     inner_primitive = None
     outer_primitive = None
-    impl_static_args = (2,)

     @staticmethod
-    def abstract(dz_aval, x_aval, *, act_enum):  # pylint: disable=unused-argument
+    def abstract(
+        dz_aval,
+        x_aval,
+        scale_aval,
+        *,
+        out_dtype,
+        scaling_mode,
+        is_2x,
+        scale_dtype,
+        scale_shapes,
+        is_dbias,
+        act_enum,
+        act_len,
+        is_outer,
+    ):
         """
-        dact_lu abstract
+        te_dact_dbias_quantize_p abstract
         """
+        del act_enum, scale_shapes
         dtype = dtypes.canonicalize_dtype(dz_aval.dtype)
         assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
         assert x_aval.dtype == dtype
         for axis in range(len(dz_aval.shape) - 1):
             assert dz_aval.shape[axis] == x_aval.shape[axis]
-        assert x_aval.shape[-2] == 2 or x_aval.shape[-2] == 1
-        i_hidden_size = dz_aval.shape[-1]
-        g_hidden_size = x_aval.shape[-1]
-        assert i_hidden_size == g_hidden_size
-        out_aval = x_aval
-        return out_aval
+        assert scale_aval.dtype == jnp.float32
+        ir_hidden_size = dz_aval.shape[-1]
+        gi_hidden_size = x_aval.shape[-1]
+        assert act_len * ir_hidden_size == gi_hidden_size
+        out_shape = x_aval.shape
+        out_aval = x_aval.update(shape=out_shape, dtype=out_dtype)
+        updated_amax_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32)
+        rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
+            scaling_mode
+        ).get_scale_shape_2x(x_aval.shape, is_padded=not is_outer)
+        scale_inv_aval = jax.core.ShapedArray(shape=rowwise_scale_inv_shape, dtype=scale_dtype)
+        colwise_out_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32)
+        colwise_scale_inv_aval = jax.core.ShapedArray(shape=(1,), dtype=scale_dtype)
+        dbias_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32)
+        wkspace_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32)
+        if is_2x:
+            # Don't transpose output for MXFP8
+            if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value:
+                t_shape = out_shape
+            else:
+                t_shape = multidim_transpose(out_shape)
+            colwise_out_aval = x_aval.update(shape=t_shape, dtype=out_dtype)
+            colwise_scale_inv_aval = jax.core.ShapedArray(
+                shape=colwise_scale_inv_shape, dtype=scale_dtype
+            )
+        if is_dbias:
+            dbias_shape = gi_hidden_size
+            dbias_aval = x_aval.update(shape=dbias_shape, dtype=dtype)
+            (wkspace_info,) = transformer_engine_jax.get_dact_dbias_quantize_workspace_sizes(
+                x_aval.size // gi_hidden_size,
+                gi_hidden_size,
+                jax_dtype_to_te_dtype(x_aval.dtype),
+                jax_dtype_to_te_dtype(out_dtype),
+                scaling_mode,
+                is_2x,
+            )
+            wkspace_aval = x_aval.update(
+                shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1])
+            )
+        return (
+            out_aval,
+            colwise_out_aval,
+            scale_inv_aval,
+            colwise_scale_inv_aval,
+            updated_amax_aval,
+            dbias_aval,
+            wkspace_aval,
+        )
+
+    @staticmethod
+    def outer_abstract(*args, **kwargs):
+        """
+        te_dact_dbias_quantize_p outer abstract
+        """
+        (out, colwise_out, scale_inv, colwise_scale_inv, updated_amax, dbias, _) = (
+            DActLuDBiasQuantizePrimitive.abstract(*args, **kwargs)
+        )
+        return out, colwise_out, scale_inv, colwise_scale_inv, updated_amax, dbias

     @staticmethod
-    def lowering(ctx, dz, x, *, act_enum):
-        """
-        dact_lu lowering rules
-        """
-        in_aval, gi_aval = ctx.avals_in
-        assert in_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
-        assert gi_aval.dtype == in_aval.dtype
-        if is_ffi_enabled():
-            name = "te_dact_lu_ffi"
-            out = ffi.ffi_lowering(name)(ctx, dz, x, act_enum=act_enum)
-        else:
-            ir_in_type = ir.RankedTensorType(dz.type)
-            ir_in_shape = ir_in_type.shape
-            gi_type = ir.RankedTensorType(x.type)
-            gi_shape = gi_type.shape
-            # assert ir_in_shape == gi_shape
-            for axis in range(len(ir_in_shape) - 1):
-                assert ir_in_shape[axis] == gi_shape[axis]
-            ir_batch_size = reduce(operator.mul, ir_in_shape[:-1])
-            i_hidden_size = ir_in_shape[-1]
-            g_hidden_size = gi_shape[-1]
-            assert i_hidden_size == g_hidden_size
-            out_dtype = ir_in_type.element_type
-            out_shape = gi_shape
-            out_types = [
-                ir.RankedTensorType.get(out_shape, out_dtype),
-            ]
-            operands = [dz, x]
-            operand_shapes = [ir_in_shape, gi_shape]
-            args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
-            in_dtype = jax_dtype_to_te_dtype(in_aval.dtype)
-            opaque = transformer_engine_jax.pack_common_descriptor(
-                (ir_batch_size, i_hidden_size), in_dtype, in_dtype, act_enum
-            )
-            out = custom_caller(DActLuPrimitive.name, args, opaque, False)
-        return out
+    def lowering(
+        ctx,
+        dz,
+        x,
+        scale,
+        *,
+        out_dtype,
+        scaling_mode,
+        is_2x,
+        scale_dtype,
+        scale_shapes,
+        is_dbias,
+        act_enum,
+        act_len,
+        is_outer,
+    ):
+        """
+        te_dact_dbias_quantize_p lowering rules
+        """
+        del out_dtype, scale_dtype, scale_shapes, act_len, is_outer
+        dz_aval, x_aval, scale_aval = ctx.avals_in
+        assert dz_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
+        assert x_aval.dtype == dz_aval.dtype
+        assert scale_aval.dtype == jnp.float32
+        return ffi.ffi_lowering(DActLuDBiasQuantizePrimitive.name)(
+            ctx,
+            dz,
+            x,
+            scale,
+            scaling_mode=scaling_mode,
+            is_2x=is_2x,
+            is_dbias=is_dbias,
+            act_enum=int(act_enum),
+        )

     @staticmethod
-    def impl(dz, x, act_enum):
-        """
-        dact_lu implementation
-        """
-        assert DActLuPrimitive.inner_primitive is not None
-        dx = DActLuPrimitive.inner_primitive.bind(dz, x, act_enum=act_enum)
-        return dx
+    def impl(
+        dz,
+        x,
+        scale,
+        out_dtype,
+        scaling_mode,
+        is_2x,
+        scale_dtype,
+        scale_shapes,
+        is_dbias,
+        act_enum,
+        act_len,
+        is_outer,
+    ):
+        """
+        te_dact_dbias_quantize_p impl
+        """
+        del is_outer
+        assert DActLuDBiasQuantizePrimitive.inner_primitive is not None
+        (out, colwise_out, scale_inv, colwise_scale_inv, updated_amax, dbias, _) = (
+            DActLuDBiasQuantizePrimitive.inner_primitive.bind(
+                dz,
+                x,
+                scale,
+                out_dtype=out_dtype,
+                scaling_mode=scaling_mode,
+                is_2x=is_2x,
+                scale_dtype=scale_dtype,
+                scale_shapes=scale_shapes,
+                is_dbias=is_dbias,
+                act_enum=act_enum,
+                act_len=act_len,
+                is_outer=False,
+            )
+        )
+        rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
+            scaling_mode
+        ).get_scale_shape_2x(x.shape, is_padded=False)
+        if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value:
+            scale_inv = jax.lax.slice(
+                scale_inv, [0] * len(rowwise_scale_inv_shape), rowwise_scale_inv_shape
+            )
+            if is_2x:
+                colwise_scale_inv = jax.lax.slice(
+                    colwise_scale_inv, [0] * len(colwise_scale_inv_shape), colwise_scale_inv_shape
+                )
+        return (out, colwise_out, scale_inv, colwise_scale_inv, updated_amax, dbias)  # Exclude wkspace

     @staticmethod
-    def batcher(batched_args, batch_dims, *, act_enum):
-        """
-        dact_lu batcher
-        """
-        check_valid_batch_dims(batch_dims)
-        assert DActLuPrimitive.outer_primitive is not None
-        dz, x = batched_args
-        _, x_bdim = batch_dims
-        out_bdims = x_bdim
-        return DActLuPrimitive.outer_primitive.bind(dz, x, act_enum=act_enum), out_bdims
+    def batcher(
+        batched_args,
+        batch_dims,
+        *,
+        out_dtype,
+        scaling_mode,
+        is_2x,
+        scale_dtype,
+        scale_shapes,
+        is_dbias,
+        act_enum,
+        act_len,
+        is_outer,
+    ):
+        """
+        to describe batch rules for vmap
+        """
+        del is_outer
+        check_valid_batch_dims(batch_dims)
+        assert DActLuDBiasQuantizePrimitive.outer_primitive is not None
+        dz, x, scale = batched_args
+        _, x_bdim, scale_bdim = batch_dims
+        out_bdims = (
+            x_bdim,      # rowwise output
+            scale_bdim,  # rowwise scale_inv
+            x_bdim,      # colwise output
+            scale_bdim,  # colwise scale_inv
+            scale_bdim,  # amax
+            x_bdim,      # dbias
+        )
+        return (
+            DActLuDBiasQuantizePrimitive.outer_primitive.bind(
+                dz,
+                x,
+                scale,
+                out_dtype=out_dtype,
+                scaling_mode=scaling_mode,
+                is_2x=is_2x,
+                scale_dtype=scale_dtype,
+                scale_shapes=scale_shapes,
+                is_dbias=is_dbias,
+                act_enum=act_enum,
+                act_len=act_len,
+            ),
+            out_bdims,
+        )

     @staticmethod
-    def infer_sharding_from_operands(act_enum, mesh, arg_infos, result_infos):
-        """
-        dact_lu infer_sharding_from_operands
-        """
-        del result_infos, act_enum  # Unused.
-        act_lu_out_spec = get_padded_spec(arg_infos[1])
-        dx_sharding = NamedSharding(mesh, PartitionSpec(*act_lu_out_spec))
-        return dx_sharding
+    def infer_sharding_from_operands(
+        out_dtype,
+        scaling_mode,
+        is_2x,
+        scale_dtype,
+        scale_shapes,
+        is_dbias,
+        act_enum,
+        act_len,
+        is_outer,
+        mesh,
+        arg_infos,
+        result_infos,
+    ):
+        del out_dtype, result_infos, act_enum
+        del scale_dtype, scale_shapes, is_dbias, act_len, is_outer
+        x_spec = get_padded_spec(arg_infos[1])
+        out_sharding = NamedSharding(
+            mesh, PartitionSpec(*x_spec), desc="DActLuDBiasQuantizePrimitive.out"
+        )
+        if is_2x:
+            if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value:
+                colwise_x_spec = multidim_transpose(x_spec)
+            else:
+                colwise_x_spec = x_spec
+        else:
+            colwise_x_spec = (None,)
+        colwise_out_sharding = NamedSharding(
+            mesh, PartitionSpec(*colwise_x_spec), desc="DActLuDBiasQuantizePrimitive.colwise_out"
+        )
+        dbias_shaprding = NamedSharding(
+            mesh,
+            PartitionSpec(x_spec[-1]),
+            desc="DActLuDBiasQuantizePrimitive.dbias",
+        )
+        scale_inv_sharding = NamedSharding(
+            mesh, PartitionSpec(None), desc="DActLuDBiasQuantizePrimitive.scale_inv"
+        )
+        amax_sharding = NamedSharding(
+            mesh, PartitionSpec(None), desc="DActLuDBiasQuantizePrimitive.amax"
+        )
+        if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value:
+            scale_inv_sharding = NamedSharding(
+                mesh, PartitionSpec(*x_spec), desc="DActLuDBiasQuantizePrimitive.scale_inv"
+            )
+        colwise_scale_inv_sharding = scale_inv_sharding.duplicate_with_new_description(
+            "DActLuDBiasQuantizePrimitive.colwise_scale_inv"
+        )
+        return (
+            out_sharding,
+            colwise_out_sharding,
+            scale_inv_sharding,
+            colwise_scale_inv_sharding,
+            amax_sharding,
+            dbias_shaprding,
+        )

     @staticmethod
-    def partition(act_enum, mesh, arg_infos, result_infos):
-        """
-        dact_lu partition
-        """
-        del result_infos
-        dx_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
-        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
-        out_shardings = dx_sharding
-
-        def sharded_impl(dz, x):
-            return DActLuPrimitive.impl(dz, x, act_enum=act_enum)
-
-        return mesh, sharded_impl, out_shardings, arg_shardings
-
-
-register_primitive(DActLuPrimitive)
+    def partition(
+        out_dtype,
+        scaling_mode,
+        is_2x,
+        scale_dtype,
+        scale_shapes,
+        is_dbias,
+        act_enum,
+        act_len,
+        is_outer,
+        mesh,
+        arg_infos,
+        result_infos,
+    ):
+        del result_infos, is_outer
+        x_spec = get_padded_spec(arg_infos[1])
+        out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec), desc="out")
+        if is_2x:
+            if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value:
+                colwise_x_spec = multidim_transpose(x_spec)
+            else:
+                colwise_x_spec = x_spec
+        else:
+            colwise_x_spec = (None,)
+        colwise_out_sharding = NamedSharding(
+            mesh, PartitionSpec(*colwise_x_spec), desc="DActLuDBiasQuantizePrimitive.colwise_out"
+        )
+        dbias_shaprding = NamedSharding(
+            mesh,
+            PartitionSpec(x_spec[-1]),
+            desc="DActLuDBiasQuantizePrimitive.dbias",
+        )
+        scale_inv_sharding = NamedSharding(
+            mesh, PartitionSpec(None), desc="DActLuDBiasQuantizePrimitive.scale_inv"
+        )
+        amax_sharding = NamedSharding(
+            mesh, PartitionSpec(None), desc="DActLuDBiasQuantizePrimitive.amax"
+        )
+        if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value:
+            scale_inv_sharding = NamedSharding(
+                mesh, PartitionSpec(*x_spec), desc="DActLuDBiasQuantizePrimitive.scale_inv"
+            )
+        colwise_scale_inv_sharding = scale_inv_sharding.duplicate_with_new_description(
+            "DActLuDBiasQuantizePrimitive.colwise_scale_inv"
+        )
+        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
+        arg_shardings = (
+            arg_shardings[1],
+            arg_shardings[1],
+            *arg_shardings[2:],
+        )  # dz and x are the same
+        out_shardings = (
+            out_sharding,
+            colwise_out_sharding,
+            scale_inv_sharding,
+            colwise_scale_inv_sharding,
+            amax_sharding,
+            dbias_shaprding,
+        )
+
+        def sharded_impl(dz, x, scale):
+            (out, colwise_out, scale_inv, colwise_scale_inv, local_amax, local_dbias) = (
+                DActLuDBiasQuantizePrimitive.impl(
+                    dz,
+                    x,
+                    scale,
+                    out_dtype=out_dtype,
+                    scaling_mode=scaling_mode,
+                    is_2x=is_2x,
+                    scale_dtype=scale_dtype,
+                    scale_shapes=scale_shapes,
+                    is_dbias=is_dbias,
+                    act_enum=act_enum,
+                    act_len=act_len,
+                    is_outer=True,
+                )
+            )
+            if is_dbias:
+                global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias, mesh)
+            else:
+                global_dbias = local_dbias
+            if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value:
+                global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh)
+            else:
+                global_updated_amax = local_amax
+            return out, colwise_out, scale_inv, colwise_scale_inv, global_updated_amax, global_dbias
+
+        return mesh, sharded_impl, out_shardings, arg_shardings
+
+
+register_primitive(DActLuDBiasQuantizePrimitive)


-def dact_lu(
-    inputs: jnp.ndarray, act_lu_inputs: jnp.ndarray, activation_type: Sequence[Union[str, Callable]]
-) -> jnp.ndarray:
-    """
-    dact_lu fusion wrapper
-    Return dgated_act_lu(inputs)
-    """
-    if not DActLuPrimitive.enabled():
-        _, vjp_func = jax.vjp(partial(_jax_act_lu, activation_type=activation_type), act_lu_inputs)
-        return vjp_func(inputs)[0]
-    act_type_id = ActivationEnum[activation_type].value
-    return DActLuPrimitive.outer_primitive.bind(inputs, act_lu_inputs, act_enum=act_type_id)
-
-
-class ActLuFp8Primitive(BasePrimitive):
-    """
-    ActLu FP8 Primitive
-    """
-
-    name = "te_act_lu_fp8"
-    multiple_results = True
-    impl_static_args = (4, 5)  # out_dtype, act_enum
-    inner_primitive = None
-    outer_primitive = None
-
-    @staticmethod
-    def abstract(
-        x_aval, amax_aval, scale_aval, scale_inv_aval, *, out_dtype, act_enum
-    ):  # pylint: disable=unused-argument
-        """
-        te_act_lu_p abstract
-        """
-        dtype = dtypes.canonicalize_dtype(x_aval.dtype)
-        # Currently only support casting to E4M3 only in C side.
-        assert out_dtype == jnp.float8_e4m3fn
-        assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
-        assert amax_aval.dtype == jnp.float32
-        assert scale_aval.dtype == jnp.float32
-        assert scale_inv_aval.dtype == jnp.float32
-        assert x_aval.shape[-2] == 1 or x_aval.shape[-2] == 2
-        hidden_size = x_aval.shape[-1]
-        batch_shape = x_aval.shape[:-2]
-        out_shape = (batch_shape) + (hidden_size,)
-        out_aval = x_aval.update(shape=out_shape, dtype=out_dtype)
-        updated_amax_aval = amax_aval.update(shape=amax_aval.shape, dtype=amax_aval.dtype)
-        return out_aval, updated_amax_aval
-
-    @staticmethod
-    def lowering(ctx, x, amax, scale, scale_inv, *, out_dtype, act_enum):
-        """
-        te_gated_act_lu_p lowering rules
-        """
-        x_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in
-        assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
-        assert amax_aval.dtype == jnp.float32
-        assert scale_aval.dtype == jnp.float32
-        assert scale_inv_aval.dtype == jnp.float32
-        if is_ffi_enabled():
-            name = "te_act_lu_fp8_ffi"
-            out = ffi.ffi_lowering(name, operand_output_aliases={1: 1})(
-                ctx, x, amax, scale, scale_inv, act_enum=act_enum
-            )
-        else:
-            ir_x_type = ir.RankedTensorType(x.type)
-            ir_x_shape = ir_x_type.shape
-            ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype)
-            ir_amax_type = ir.RankedTensorType(amax.type)
-            ir_amax_dtype = ir_amax_type.element_type
-            ir_amax_shape = ir_amax_type.shape
-            ir_scale_shape = ir_amax_shape
-            ir_scale_inv_shape = ir_amax_shape
-            hidden_size = ir_x_shape[-1]
-            batch_shape = ir_x_shape[:-2]
-            batch_size = reduce(operator.mul, batch_shape)
-            out_shape = batch_shape + [hidden_size]
-            out_types = [
-                ir.RankedTensorType.get(out_shape, ir_out_dtype),
-                ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype),
-            ]
-            operands = [x, amax, scale, scale_inv]
-            operand_shapes = [ir_x_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape]
-            args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
-            opaque = transformer_engine_jax.pack_common_descriptor(
-                (batch_size, hidden_size),
-                jax_dtype_to_te_dtype(x_aval.dtype),
-                jax_dtype_to_te_dtype(out_dtype),
-                act_enum,
-            )
-            out = custom_caller(
-                ActLuFp8Primitive.name, args, opaque, False, operand_output_aliases={1: 1}
-            )
-        return out
-
-    @staticmethod
-    def impl(x, amax, scale, scale_inv, out_dtype, act_enum):
-        """
-        to describe implementation
-        """
-        assert ActLuFp8Primitive.inner_primitive is not None
-        out, updated_amax = ActLuFp8Primitive.inner_primitive.bind(
-            x, amax, scale, scale_inv, out_dtype=out_dtype, act_enum=act_enum
-        )
-        return out, updated_amax
-
-    @staticmethod
-    def batcher(batched_args, batch_dims, *, out_dtype, act_enum):
-        """
-        to describe batch rules for vmap
-        """
-        check_valid_batch_dims(batch_dims)
-        assert ActLuFp8Primitive.outer_primitive is not None
-        x, amax, scale, scale_inv = batched_args
-        x_bdim, amax_bdim, _, _ = batch_dims
-        out_bdims = x_bdim, amax_bdim
-        return (
-            ActLuFp8Primitive.outer_primitive.bind(
-                x, amax, scale, scale_inv, out_dtype=out_dtype, act_enum=act_enum
-            ),
-            out_bdims,
-        )
-
-    @staticmethod
-    def infer_sharding_from_operands(out_dtype, act_enum, mesh, arg_infos, result_infos):
-        del out_dtype, result_infos, act_enum
-        x_spec = get_padded_spec(arg_infos[0])
-        out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-2], x_spec[-1]))
-        amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
-        return (out_sharding, amax_sharding)
-
-    @staticmethod
-    def partition(out_dtype, act_enum, mesh, arg_infos, result_infos):
-        del result_infos
-        x_spec = get_padded_spec(arg_infos[0])
-        out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-2], x_spec[-1]))
-        amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
-        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
-        out_shardings = (out_sharding, amax_sharding)
-
-        def sharded_impl(x, amax, scale, scale_inv):
-            local_x, local_amax = ActLuFp8Primitive.impl(
-                x, amax, scale, scale_inv, out_dtype=out_dtype, act_enum=act_enum
-            )
-            global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh)
-            return local_x, global_updated_amax
-
-        return mesh, sharded_impl, out_shardings, arg_shardings
-
-
-register_primitive(ActLuFp8Primitive)
-
-
-def act_lu_fp8(
-    x: jnp.ndarray,
-    amax: jnp.ndarray,
-    scale: jnp.ndarray,
-    scale_inv: jnp.ndarray,
-    out_dtype: jnp.dtype,
-    activation_type: Sequence[Union[str, Callable]],
-) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
-    """
-    act wrapper
-    Return FP8(act_lu(x))
-    Input shape: (N, 1, H) for non-gated activations
-                 (N, 2, H) for gated activations
-    """
-    if not ActLuFp8Primitive.enabled():
-        act_lu_output = _jax_act_lu(x, activation_type)
-        casted_output, updated_amax = _jax_cast_fp8(act_lu_output, scale, amax, out_dtype)
-        return casted_output, updated_amax
-    act_type_id = ActivationEnum[activation_type].value
-    return ActLuFp8Primitive.outer_primitive.bind(
-        x, amax, scale, scale_inv, out_dtype=out_dtype, act_enum=act_type_id
-    )
+def _jax_act_lu(inputs, activation_type, quantizer=None) -> Union[jnp.ndarray, ScaledTensor]:
+    """
+    JAX native activation implementation
+    """
+    x = jnp.split(inputs, len(activation_type), axis=-1)
+    acts = []
+    for idx, act_fn in enumerate(activation_type):
+        x_i = _convert_to_activation_function(act_fn)(x[idx])
+        acts.append(x_i)
+    x = reduce(operator.mul, acts)
+    if quantizer:
+        return quantizer.quantize(x)
+    return x
+
+
+def _jax_quantize_dact_dbias(
+    dz: jnp.ndarray,
+    x: jnp.ndarray,
+    activation_type: Sequence[Union[str, Callable]],
+    is_dbias: bool = True,
+    quantizer: Optional[Quantizer] = None,
+):
+    """
+    JAX implementation of dact_lu and dbias with optional quantization
+    """
+    _, vjp_func = jax.vjp(
+        partial(_jax_act_lu, activation_type=activation_type), x.astype(jnp.float32)
+    )
+    (dx,) = vjp_func(dz.astype(jnp.float32))
+    dbias = None
+    if is_dbias:
+        dbias = _jax_dbias(dx).astype(x.dtype)
+    if quantizer is not None:
+        dx = quantizer.quantize(dx, dq_dtype=x.dtype)
+    else:
+        dx = dx.astype(x.dtype)
+    return dx, dbias
+
+
+def act_lu(
+    x: jnp.ndarray,
+    activation_type: Sequence[Union[str, Callable]],
+    quantizer: Optional[Quantizer] = None,
+) -> Union[jnp.ndarray, ScaledTensor]:
+    """Activation with optional quantization.
+
+    Args:
+        x: Input tensor to be processed.
+        activation_type: Type of activation function to apply.
+        quantizer: Optional quantizer for FP8 quantization of the output.
+
+    Returns:
+        If quantizer is None:
+            The activated input tensor with the same dtype as input.
+        If quantizer is provided:
+            A ScaledTensor containing the quantized activated input.
+    """
+    act_type_id = ActivationEnum[activation_type].value
+    if not ActLuPrimitive.enabled():
+        return _jax_act_lu(x, activation_type, quantizer)
+    # TE/common does not support colwise-only quantization yet
+    if quantizer is not None and quantizer.q_axis == QuantizeAxis.COLWISE:
+        return _jax_act_lu(x, activation_type, quantizer)
+    # TE/common does not support 2x quantization for DelayedScaling yet
+    war_output = try_apply_delayed_scaling_2x_war(
+        f=act_lu, x=x, activation_type=activation_type, quantizer=quantizer
+    )
+    if war_output is not None:
+        return war_output
+
+    scale = jnp.empty((1,), jnp.float32)
+    output_shape = (*x.shape[:-1], x.shape[-1] // len(activation_type))
+
+    if quantizer is None:
+        x = x.reshape((-1, len(activation_type), x.shape[-1] // len(activation_type)))
+        out, _, _, _, _ = ActLuPrimitive.outer_primitive.bind(
+            x,
+            scale,
+            out_dtype=x.dtype,
+            act_enum=act_type_id,
+            act_len=len(activation_type),
+            scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value,
+            is_2x=False,
+            scale_dtype=jnp.float32,
+            scale_shapes=((), ()),
+            is_outer=True,
+        )
+        out = out.reshape(output_shape)
+        return out
+
+    if isinstance(quantizer, DelayedScaleQuantizer):
+        scale = quantizer.scale
+
+    x = x.reshape((*x.shape[:-1], len(activation_type), x.shape[-1] // len(activation_type)))
+    (
+        rowwise_casted_output,
+        colwise_casted_output,
+        rowwise_scale_inv,
+        colwise_scale_inv,
+        updated_amax,
+    ) = ActLuPrimitive.outer_primitive.bind(
+        x,
+        scale,
+        out_dtype=quantizer.q_dtype,
+        act_enum=act_type_id,
+        act_len=len(activation_type),
+        scaling_mode=quantizer.scaling_mode.value,
+        is_2x=quantizer.is_2x2x(),
+        scale_dtype=quantizer.get_scale_dtype(),
+        scale_shapes=quantizer.get_scale_shapes(output_shape),
+        is_outer=True,
+    )
+    rowwise_casted_output = rowwise_casted_output.reshape(output_shape)
+    if len(rowwise_scale_inv.shape) > 1:
+        rowwise_scale_inv = jnp.squeeze(rowwise_scale_inv, axis=-2)  # Remove act axis
+    if quantizer.q_axis in (QuantizeAxis.COLWISE, QuantizeAxis.ROWWISE_COLWISE):
+        colwise_output_shape = output_shape
+        if quantizer.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING:
+            colwise_output_shape = multidim_transpose(output_shape)
+        colwise_casted_output = colwise_casted_output.reshape(colwise_output_shape)
+        if len(colwise_scale_inv.shape) > 1:
+            colwise_scale_inv = jnp.squeeze(colwise_scale_inv, axis=-2)  # Remove act axis
+    quantizer.update(updated_amax)
+    return ScaledTensorFactory.create(
+        data=rowwise_casted_output,
+        scale_inv=rowwise_scale_inv,
+        colwise_data=colwise_casted_output,
+        colwise_scale_inv=colwise_scale_inv,
+        scaling_mode=quantizer.scaling_mode,
+        dq_dtype=x.dtype,
+        q_axis=quantizer.q_axis,
+        layout=quantizer.get_layout(),
+    )
+
+
+def quantize_dact_dbias(
+    dz: jnp.ndarray,
+    x: jnp.ndarray,
+    activation_type: Sequence[Union[str, Callable]] = ("gelu",),
+    is_dbias: bool = True,
+    quantizer: Optional[Quantizer] = None,
+) -> Tuple[ScaledTensor, jnp.ndarray]:
+    """Compute gradients of activation and bias with optional quantization.
+
+    Args:
+        dz: Gradient of the output with respect to the activation output.
+        x: Input tensor that was processed by the forward pass.
+            Shape: (..., ACT_DIM * K) where ACT_DIM is 1 for non-gated activations and 2 for gated activations
+        activation_type: Type of activation function used in the forward pass. Defaults to ("gelu",).
+        is_dbias: If True, compute bias gradient. Defaults to True.
+        quantizer: Optional quantizer for FP8 quantization of the output.
+
+    Returns:
+        Tuple[ScaledTensor, jnp.ndarray]: A tuple containing:
+            - The gradient of the activation with respect to the input.
+            - The gradient of the activation with respect to the bias.
+    """
+    if not DActLuDBiasQuantizePrimitive.enabled():
+        return _jax_quantize_dact_dbias(dz, x, activation_type, is_dbias, quantizer)
+    # TE/common does not support colwise-only quantization yet
+    if quantizer is not None and quantizer.q_axis == QuantizeAxis.COLWISE:
+        return _jax_quantize_dact_dbias(dz, x, activation_type, is_dbias, quantizer)
+    # TE/common does not support 1x dact_dbias_quantize on arch < 100 yet
+    if should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias=is_dbias, quantizer=quantizer):
+        out, _ = quantize_dact_dbias(
+            dz=dz, x=x, activation_type=activation_type, is_dbias=False, quantizer=None
+        )
+        return quantize_dbias(out, is_dbias=True, quantizer=quantizer)
+
+    is_gated = len(activation_type) == 2
+    # TE/common does not support DelayedScaling2x for gated-act yet
+    if is_gated:
+        war_output = try_apply_delayed_scaling_2x_war(
+            f=quantize_dact_dbias,
+            dz=dz,
+            x=x,
+            activation_type=activation_type,
+            is_dbias=is_dbias,
+            quantizer=quantizer,
+        )
+        if war_output is not None:
+            return war_output
+
+    scale = jnp.empty((), jnp.float32)
+    act_type_id = ActivationEnum[activation_type]
+
+    if quantizer is None:
+        output, _, _, _, _, _ = DActLuDBiasQuantizePrimitive.outer_primitive.bind(
+            dz,
+            x,
+            scale,
+            # outputs float32 for dbias accumulation
+            out_dtype=(jnp.float32 if is_dbias else x.dtype),
+            # default value for no scaling, TE/common ignore this value when scale is unset
+            scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value,
+            is_2x=False,  # unused
+            scale_dtype=jnp.float32,  # unused
+            scale_shapes=((), ()),  # unused
+            is_dbias=False,
+            act_enum=act_type_id,
+            act_len=len(activation_type),
+            is_outer=True,
+        )
+        dbias = None
+        if is_dbias:
+            dbias = _jax_dbias(output).astype(x.dtype)
+        return output.astype(x.dtype), dbias
+
+    if isinstance(quantizer, DelayedScaleQuantizer):
+        scale = quantizer.scale
+
+    # TE/common dact_dbias_quantize does not support gated act yet
+    if is_dbias and is_gated:
+        dgated = dact_lu(
+            dz.astype(jnp.float32), x.astype(jnp.float32), activation_type=activation_type
+        )
+        # TODO(Jeremy): Debug - TE's quantize_dbias produced nans in this case for distributed layernorm_mlp tests
+        if quantizer.scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING:
+            out, dbias = _jax_quantize_dbias(dgated, quantizer=quantizer, dq_dtype=x.dtype)
+        else:
+            out, dbias = quantize_dbias(
+                dgated,
+                quantizer=quantizer,
+                is_dbias=True,
+                dq_dtype=x.dtype,
+            )
+        return out, dbias
+
+    out_shape = x.shape
+    (
+        rowwise_casted_output,
+        colwise_casted_output,
+        rowwise_scale_inv,
+        colwise_scale_inv,
+        updated_amax,
+        dbias,
+    ) = DActLuDBiasQuantizePrimitive.outer_primitive.bind(
+        dz,
+        x,
+        scale,
+        out_dtype=quantizer.q_dtype,
+        scaling_mode=quantizer.scaling_mode.value,
+        is_2x=quantizer.is_2x2x(),
+        scale_dtype=quantizer.get_scale_dtype(),
+        scale_shapes=quantizer.get_scale_shapes(out_shape),
+        is_dbias=is_dbias,
+        act_enum=act_type_id,
+        act_len=len(activation_type),
+        is_outer=True,
+    )
+    # For DelayedScaling transpose, the scale buffer is shared for both rowwise and colwise
+    if quantizer.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING and quantizer.is_2x2x():
+        colwise_scale_inv = rowwise_scale_inv
+    quantizer.update(updated_amax)
+    out = ScaledTensorFactory.create(
+        data=rowwise_casted_output,
+        scale_inv=rowwise_scale_inv,
+        colwise_data=colwise_casted_output,
+        colwise_scale_inv=colwise_scale_inv,
+        scaling_mode=quantizer.scaling_mode,
+        dq_dtype=x.dtype,
+        q_axis=quantizer.q_axis,
+        layout=quantizer.get_layout(),
+    )
+    return out, dbias
+
+
+def dact_lu(
+    dz: jnp.ndarray,
+    x: jnp.ndarray,
+    activation_type: Sequence[Union[str, Callable]],
+    quantizer: Optional[Quantizer] = None,
+) -> Union[jnp.ndarray, ScaledTensor]:
+    """
+    Backward pass for activation with optional quantization.
+
+    Args:
+        dz: Gradient tensor from upstream.
+        x: Input tensor that was used in forward pass.
+        activation_type: Type of activation function that was applied.
+        quantizer: Optional quantizer for FP8 quantization of the output gradient.
+
+    Returns:
+        The gradient of the activation with respect to the input.
+    """
+    output, _ = quantize_dact_dbias(
+        dz=dz,
+        x=x,
+        activation_type=activation_type,
+        is_dbias=False,
+        quantizer=quantizer,
+    )
+    return output
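Finally, a sketch of the unquantized path through the new `quantize_dact_dbias` wrapper (shapes follow the `(..., ACT_DIM * K)` convention from its docstring; `quantizer=None` exercises the float fallback):

import jax.numpy as jnp
from transformer_engine.jax.cpp_extensions.activation import quantize_dact_dbias

x = jnp.ones((8, 128), jnp.bfloat16)   # non-gated: ACT_DIM == 1
dz = jnp.ones((8, 128), jnp.bfloat16)  # same hidden size as the activation output
dx, dbias = quantize_dact_dbias(dz, x, activation_type=("gelu",), is_dbias=True)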
transformer_engine/jax/cpp_extensions/attention.py
View file @
a207db1d
...
...
@@ -13,8 +13,6 @@ from packaging import version
import
jax
import
jax.numpy
as
jnp
from
jax
import
dtypes
,
lax
from
jax.interpreters
import
mlir
from
jax.interpreters.mlir
import
ir
from
jax.sharding
import
PartitionSpec
,
NamedSharding
import
transformer_engine_jax
...
...
@@ -29,14 +27,12 @@ from transformer_engine.jax.attention import (
)
from
.base
import
BasePrimitive
,
register_primitive
from
.custom_call
import
custom_caller
,
CustomCallArgsWrapper
from
.misc
import
(
check_valid_batch_dims
,
jax_dtype_to_te_dtype
,
te_dtype_to_jax_dtype
,
get_padded_spec
,
get_cudnn_version
,
is_ffi_enabled
,
)
from
..sharding
import
(
global_mesh_resource
,
...
...
@@ -227,7 +223,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
     Fused Attention Forward Primitive
     """

-    name = "te_fused_attn_forward"
+    name = "te_fused_attn_forward_ffi"
     multiple_results = True
     impl_static_args = (13,)
     inner_primitive = None
...
@@ -400,9 +396,7 @@ class FusedAttnFwdPrimitive(BasePrimitive):
         *bias_batch_shape, bias_heads, _, _ = bias_aval.shape
         bias_batch = reduce(operator.mul, bias_batch_shape)

-        if is_ffi_enabled():
-            name = "te_fused_attn_forward_ffi"
-            out = ffi.ffi_lowering(name)(
+        return ffi.ffi_lowering(FusedAttnFwdPrimitive.name)(
             ctx,
             q,
             k,
...
@@ -436,54 +430,6 @@ class FusedAttnFwdPrimitive(BasePrimitive):
             window_size_left=config.window_size[0],
             window_size_right=config.window_size[1],
         )
-        else:
-            operands = [
-                q,
-                k,
-                v,
-                bias,
-                seed,
-                q_cu_seqlen,
-                kv_cu_seqlen,
-                q_seq_offsets,
-                k_seq_offsets,
-            ]
-            operand_shapes = map(lambda x: x.type.shape, operands)
-            out_types = [
-                ir.RankedTensorType.get(output.shape, mlir.dtype_to_ir_type(output.dtype))
-                for output in ctx.avals_out
-            ]
-            args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
-            wkspace_aval = ctx.avals_out[-1]
-            opaque = transformer_engine_jax.pack_fused_attn_descriptor(
-                input_batch,
-                bias_batch,
-                q_max_seqlen,
-                kv_max_seqlen,
-                attn_heads,
-                num_gqa_groups,
-                bias_heads,
-                head_dim,
-                config.max_segments_per_seq,
-                wkspace_aval.size,
-                config.scaling_factor,
-                config.dropout_probability,
-                config.attn_bias_type,
-                config.attn_mask_type,
-                config.qkv_layout,
-                jax_dtype_to_te_dtype(q_aval.dtype),
-                jax_dtype_to_te_dtype(wkspace_aval.dtype),
-                config.is_training,
-                not FusedAttnHelper.is_non_deterministic_allowed(),
-                config.window_size[0],
-                config.window_size[1],
-            )
-            out = custom_caller(FusedAttnFwdPrimitive.name, args, opaque, has_side_effect=False)
-
-        return out

     @staticmethod
     def impl(
...
@@ -681,7 +627,7 @@ class FusedAttnBwdPrimitive(BasePrimitive):
     Fused Attention Backward Primitive
     """

-    name = "te_fused_attn_backward"
+    name = "te_fused_attn_backward_ffi"
     multiple_results = True
     impl_static_args = (16,)
     inner_primitive = None
...
@@ -813,9 +759,7 @@ class FusedAttnBwdPrimitive(BasePrimitive):
         *bias_batch_shape, bias_heads, _, _ = bias_aval.shape
         bias_batch = reduce(operator.mul, bias_batch_shape)

-        if is_ffi_enabled():
-            name = "te_fused_attn_backward_ffi"
-            out = ffi.ffi_lowering(name)(
+        return ffi.ffi_lowering(FusedAttnBwdPrimitive.name)(
             ctx,
             q,
             k,
...
@@ -852,57 +796,6 @@ class FusedAttnBwdPrimitive(BasePrimitive):
             window_size_left=config.window_size[0],
             window_size_right=config.window_size[1],
         )
-        else:
-            operands = [
-                q,
-                k,
-                v,
-                bias,
-                softmax_aux,
-                rng_state,
-                output,
-                doutput,
-                q_cu_seqlen,
-                kv_cu_seqlen,
-                q_seq_offsets,
-                k_seq_offsets,
-            ]
-            operand_shapes = map(lambda x: x.type.shape, operands)
-            out_types = [
-                ir.RankedTensorType.get(output.shape, mlir.dtype_to_ir_type(output.dtype))
-                for output in ctx.avals_out
-            ]
-            args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
-            wkspace_aval = ctx.avals_out[-1]
-            opaque = transformer_engine_jax.pack_fused_attn_descriptor(
-                input_batch,
-                bias_batch,
-                q_max_seqlen,
-                kv_max_seqlen,
-                attn_heads,
-                num_gqa_groups,
-                bias_heads,
-                head_dim,
-                config.max_segments_per_seq,
-                wkspace_aval.size,
-                config.scaling_factor,
-                config.dropout_probability,
-                config.attn_bias_type,
-                config.attn_mask_type,
-                config.qkv_layout,
-                jax_dtype_to_te_dtype(q_aval.dtype),
-                jax_dtype_to_te_dtype(wkspace_aval.dtype),
-                config.is_training,
-                not FusedAttnHelper.is_non_deterministic_allowed(),
-                config.window_size[0],
-                config.window_size[1],
-            )
-            out = custom_caller(FusedAttnBwdPrimitive.name, args, opaque, has_side_effect=False)
-
-        return out

     @staticmethod
     def impl(
...
transformer_engine/jax/cpp_extensions/base.py
...
@@ -6,6 +6,7 @@ import os
 import re
 from abc import ABCMeta, abstractmethod
 from functools import partial
+from packaging import version

 from jax.extend import core
 from jax.interpreters import xla, mlir
...
@@ -13,6 +14,14 @@ from jax.experimental.custom_partitioning import custom_partitioning
 from jax._src.interpreters import batching
 from jax._src import dispatch
+import jax

 import transformer_engine_jax

+if version.parse(jax.__version__) >= version.parse("0.5.0"):
+    from jax import ffi  # pylint: disable=ungrouped-imports
+else:
+    from jax.extend import ffi  # pylint: disable=ungrouped-imports


 class BasePrimitive(metaclass=ABCMeta):
     """
...
@@ -120,3 +129,7 @@ def register_primitive(cls):
         outer_p, mlir.lower_fun(outer_p_lower, multiple_results=cls.multiple_results)
     )
     cls.outer_primitive = outer_p
+
+
+for _name, _value in transformer_engine_jax.registrations().items():
+    ffi.register_ffi_target(_name, _value, platform="CUDA")
transformer_engine/jax/cpp_extensions/custom_call.py (deleted, 100644 → 0)
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""JAX/TE custom call"""
from dataclasses import dataclass
from enum import IntEnum

from packaging import version

import jax
from jax.interpreters import mlir

import transformer_engine_jax

from .misc import is_ffi_enabled

if version.parse(jax.__version__) >= version.parse("0.5.0"):
    from jax import ffi  # pylint: disable=ungrouped-imports
else:
    from jax.extend import ffi  # pylint: disable=ungrouped-imports

try:
    from jaxlib.hlo_helpers import custom_call
except ImportError:
    # Newer JAX changed its API. But we want to support a few JAX
    # versions, so we still need this import.
    pass


class CustomCallAPIVersion(IntEnum):
    """Enum for selecting between old and new custom call registration API"""

    OPAQUE = 0
    FFI = 1


for _name, _value in transformer_engine_jax.registrations().items():
    if _name.endswith("_ffi"):
        if is_ffi_enabled():
            ffi.register_ffi_target(
                _name, _value, platform="CUDA", api_version=CustomCallAPIVersion.FFI.value
            )
    else:
        ffi.register_ffi_target(
            _name, _value, platform="CUDA", api_version=CustomCallAPIVersion.OPAQUE.value
        )


@dataclass
class CustomCallArgsWrapper:
    """
    wrapper of XLA custom call args
    """

    def __init__(
        self,
        output_types,
        operands,
        operand_shapes,
        operand_specific_layouts=None,
        output_specific_layouts=None,
    ):
        self.output_types = output_types
        self.operands = operands
        self.operand_layouts = CustomCallArgsWrapper.generate_layouts(
            operand_shapes, operand_specific_layouts
        )
        output_shapes = [x.shape for x in output_types]
        self.output_layouts = CustomCallArgsWrapper.generate_layouts(
            output_shapes, output_specific_layouts
        )

    @staticmethod
    def generate_layouts(shapes, specific_layouts):
        """
        setup layouts for XLA custom call
        """

        def default_layout(shape):
            return range(len(shape) - 1, -1, -1)

        if specific_layouts is None:
            specific_layouts = {}
        layouts = []
        for idx, shape in enumerate(shapes):
            if idx in specific_layouts:
                layouts.append(specific_layouts[idx])
            else:
                layouts.append(default_layout(shape))
        return layouts


def custom_caller(name, args, opaque, has_side_effect, **kwargs):
    """
    XLA custom call wrapper
    """
    if hasattr(mlir, "custom_call"):
        out = mlir.custom_call(
            name,
            result_types=args.output_types,
            operands=args.operands,
            operand_layouts=args.operand_layouts,
            result_layouts=args.output_layouts,
            backend_config=opaque,
            has_side_effect=has_side_effect,
            **kwargs,
        ).results
    else:
        # Need to disable one pylint error as the second function
        # parameter name changed recently in JAX. Otherwise we won't be
        # compatible with multiple JAX versions.
        out = custom_call(  # pylint: disable=too-many-function-args
            name,
            args.output_types,
            operands=args.operands,
            operand_layouts=args.operand_layouts,
            result_layouts=args.output_layouts,
            backend_config=opaque,
            has_side_effect=has_side_effect,
            **kwargs,
        )
    return out
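Before its removal, the layout helper above defaulted every operand to a row-major (descending) dimension order unless an explicit override was supplied. A small illustration of that rule, assuming the class is still importable:

    shapes = [(2, 3, 4), (5,)]
    layouts = CustomCallArgsWrapper.generate_layouts(shapes, {0: [0, 1, 2]})
    # layouts[0] is the override [0, 1, 2]; layouts[1] is range(0, -1, -1),
    # i.e. the default descending (row-major) order for a rank-1 operand.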
transformer_engine/jax/cpp_extensions/gemm.py (new file, 0 → 100644)
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""JAX te modules"""

from typing import Tuple, Sequence, Union, Dict, List
from functools import partial, reduce
import operator

from transformer_engine_jax import get_device_compute_capability

import jax
import jax.numpy as jnp

from .base import BasePrimitive, register_primitive
from ..quantize import (
    ScaledTensor,
    ScalingMode,
    Quantizer,
    QuantizeConfig,
    noop_quantizer_set,
)


__all__ = ["gemm", "grouped_gemm"]


num_cublas_streams = 4


def get_cublas_workspace_size_bytes() -> int:
    """Return 32 MiB if using hopper, 4 MiB for all other architectures."""
    if get_device_compute_capability(0) >= 90:
        return 33_554_432
    return 4_194_304


class GroupedGemmPrimitive(BasePrimitive):
    """
    Primitive for grouped GEMM
    """

    name = "te_grouped_gemm_ffi"
    multiple_results = True
    impl_static_args = (6, 7, 8, 9)
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(
        lhs_contig_aval,
        lhs_scale_contig_aval,
        rhs_contig_aval,
        rhs_scale_contig_aval,
        bias_contig_aval,
        dim_list_aval,
        *,
        num_gemms,
        scaling_mode,
        out_dtype,
        out_flat_size,
    ):
        del lhs_contig_aval, lhs_scale_contig_aval
        del rhs_contig_aval, rhs_scale_contig_aval
        del bias_contig_aval, dim_list_aval
        del num_gemms, scaling_mode
        out_flat_aval = jax.core.ShapedArray(shape=(out_flat_size,), dtype=out_dtype)
        wkspace_size = get_cublas_workspace_size_bytes() * num_cublas_streams
        wkspace_aval = jax.core.ShapedArray(shape=(wkspace_size,), dtype=jnp.uint8)
        return (out_flat_aval, wkspace_aval)

    @staticmethod
    def outer_abstract(*args, **kwargs):
        (out_aval, _) = GroupedGemmPrimitive.abstract(*args, **kwargs)
        return out_aval

    @staticmethod
    def lowering(
        ctx,
        lhs_contig,
        lhs_scale_inv_contig,
        rhs_contig,
        rhs_scale_inv_contig,
        bias_contig,
        dim_list,
        *,
        num_gemms,
        scaling_mode,
        out_dtype,
        out_flat_size,
    ) -> jnp.ndarray:
        del out_dtype, out_flat_size
        return jax.ffi.ffi_lowering(GroupedGemmPrimitive.name)(
            ctx,
            lhs_contig,
            lhs_scale_inv_contig,
            rhs_contig,
            rhs_scale_inv_contig,
            bias_contig,
            dim_list,
            num_gemms=num_gemms,
            scaling_mode=int(scaling_mode),
        )

    @staticmethod
    def impl(
        lhs_contig,
        lhs_scale_inv_contig,
        rhs_contig,
        rhs_scale_inv_contig,
        bias_contig,
        dim_list,
        num_gemms,
        scaling_mode,
        out_dtype,
        out_flat_size,
    ) -> jnp.ndarray:
        assert GroupedGemmPrimitive.inner_primitive is not None
        out = GroupedGemmPrimitive.inner_primitive.bind(
            lhs_contig,
            lhs_scale_inv_contig,
            rhs_contig,
            rhs_scale_inv_contig,
            bias_contig,
            dim_list,
            num_gemms=num_gemms,
            scaling_mode=scaling_mode.value,
            out_dtype=out_dtype,
            out_flat_size=out_flat_size,
        )
        return out[0]  # out is [out_flat, wkspace], only return out_flat


register_primitive(GroupedGemmPrimitive)
def _shape_normalization(x, dimension_numbers, already_transposed: bool = False):
    orig_order = list(range(x.ndim))
    contracting_dims, batch_dims = dimension_numbers
    contracting_order = [d for d in orig_order if d in contracting_dims]
    batch_order = [d for d in orig_order if d in batch_dims]
    non_contracting_order = [
        d for d in orig_order if d not in contracting_dims and d not in batch_dims
    ]
    batch_shape = [x.shape[d] for d in batch_order]
    rows_shape = [x.shape[d] for d in non_contracting_order]
    cols_shape = [x.shape[d] for d in contracting_order]
    new_order = batch_order + non_contracting_order + contracting_order
    rows, cols, batches = (
        reduce(operator.mul, rows_shape, 1),
        reduce(operator.mul, cols_shape, 1),
        reduce(operator.mul, batch_shape, 1),
    )
    # Remove this transpose when non-TN dot is supported
    if not already_transposed:
        t = jnp.transpose(x, new_order)
    else:
        t = x
    return jnp.reshape(t, (batches, rows, cols))
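A small shape check (illustrative sizes, not taken from the source) of what _shape_normalization produces: batch dims first, then the flattened non-contracting dims, then the flattened contracting dims:

    import jax.numpy as jnp

    x = jnp.zeros((2, 3, 4, 5))                  # [B, D1, D2, K]
    out = _shape_normalization(x, ((3,), (0,)))  # contract on K, batch on B
    assert out.shape == (2, 12, 5)               # (batches, 3 * 4, contracting)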
def _calculate_remaining_shape(shape, contracting_dims):
    return tuple(shape[dim] for dim in range(len(shape)) if dim not in contracting_dims)


def _dequantize(x, scale_inv, dq_dtype):
    return x.astype(dq_dtype) * scale_inv.astype(dq_dtype)
# Apply jit to guarantee correctness of FP8 GEMM.
@partial(jax.jit, static_argnums=(2, 3, 4))
def __jitted_jax_gemm_delayed_scaling_fp8(lhs, rhs, lhs_dn, rhs_dn, precision):
    # Need to hard-code the dequantize here instead of calling lhs.dequantize() for pattern matching
    lhs_dq = _dequantize(lhs.data, lhs.scale_inv, lhs.dq_dtype)
    rhs_dq = _dequantize(rhs.data, rhs.scale_inv, rhs.dq_dtype)

    # Reshape + Transpose
    # [..., M, K] -> [B, M, K]
    # [..., K, M] -> [B, M, K]
    lhs_3d = _shape_normalization(lhs_dq, lhs_dn, lhs.layout == "N")
    rhs_3d = _shape_normalization(rhs_dq, rhs_dn, rhs.layout == "T")

    # _shape_normalization ensures contracting_dims=2 and batch_dims=0
    dim_nums = (((2,), (2,)), ((0,), (0,)))
    out_3d = jax.lax.dot_general(
        lhs_3d, rhs_3d, dim_nums, precision=precision, preferred_element_type=lhs.dq_dtype
    )
    return out_3d
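The quantize-dequantize (QDQ) pattern that XLA rewrites into an FP8 GEMM can be reproduced by hand. A hedged sketch with per-tensor scaling; the float8 dtype and the 448.0 E4M3 max are assumptions about the JAX build, not values taken from this file:

    import jax
    import jax.numpy as jnp

    x = jax.random.normal(jax.random.PRNGKey(0), (16, 32))
    w = jax.random.normal(jax.random.PRNGKey(1), (32, 8))

    scale = 448.0 / jnp.max(jnp.abs(x))            # fit x into the E4M3 range
    x_q = (x * scale).astype(jnp.float8_e4m3fn)    # quantize
    x_dq = x_q.astype(jnp.float32) / scale         # dequantize with scale_inv

    out = jax.lax.dot_general(x_dq, w, (((1,), (0,)), ((), ())),
                              preferred_element_type=jnp.float32)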
def _jax_gemm_delayed_scaling_fp8(
    lhs: ScaledTensor, rhs: ScaledTensor, dim_nums: Tuple[Tuple[Sequence[int], Sequence[int]]]
):
    """FP8 GEMM for XLA pattern match"""
    assert (
        rhs.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING
    ), "rhs does not have delayed tensor scaling mode"

    (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dim_nums
    if lhs.layout == "T":
        lhs_contract = tuple((lhs.data.ndim - 1 - i) % lhs.data.ndim for i in lhs_contract)
    if rhs.layout == "T":
        rhs_contract = tuple((rhs.data.ndim - 1 - i) % rhs.data.ndim for i in rhs_contract)

    lhs_dn = (lhs_contract, lhs_batch)
    rhs_dn = (rhs_contract, rhs_batch)

    lhs_remain_shape = _calculate_remaining_shape(lhs.data.shape, lhs_contract)
    rhs_remain_shape = _calculate_remaining_shape(rhs.data.shape, rhs_contract)

    precision = (
        jax.lax.Precision.HIGHEST if QuantizeConfig.FP8_2X_ACC_FPROP else jax.lax.Precision.DEFAULT
    )
    out_3d = __jitted_jax_gemm_delayed_scaling_fp8(lhs, rhs, lhs_dn, rhs_dn, precision)

    # Reshape [B, M, N] -> [..., M, N]
    out = out_3d.reshape(*lhs_remain_shape, *rhs_remain_shape)
    return out
def _jax_gemm_mxfp8_1d(
    lhs: ScaledTensor, rhs: ScaledTensor, dim_nums: Tuple[Tuple[Sequence[int], Sequence[int]]]
):
    """
    JAX GEMM for MXFP8 via scaled_matmul
    """
    assert (
        rhs.scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING
    ), "rhs does not have MXFP8 1D scaling mode"

    from jax._src.cudnn.scaled_matmul_stablehlo import scaled_matmul_wrapper

    (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dim_nums

    expected_lhs_is_colwise = lhs_contract[-1] != lhs.data.ndim - 1
    expected_rhs_is_colwise = rhs_contract[-1] != rhs.data.ndim - 1
    assert lhs.is_colwise is expected_lhs_is_colwise, (
        f"LHS with unexpected quantize dimension.\nExpect is_colwise={expected_lhs_is_colwise}, got"
        f" {lhs.is_colwise}"
    )
    assert rhs.is_colwise is expected_rhs_is_colwise, (
        f"RHS with unexpected quantize dimension.\nExpect is_colwise={expected_rhs_is_colwise}, got"
        f" {rhs.is_colwise}"
    )

    # Reshape + Transpose (if needed)
    # [..., M, K] -> [1, reduce(..., M), K]
    # [..., K, M] -> [1, reduce(..., M), K]
    lhs_3d = _shape_normalization(lhs.data, (lhs_contract, lhs_batch))
    rhs_3d = _shape_normalization(rhs.data, (rhs_contract, rhs_batch))
    lhs_scale_3d = _shape_normalization(lhs.scale_inv, (lhs_contract, lhs_batch))
    rhs_scale_3d = _shape_normalization(rhs.scale_inv, (rhs_contract, rhs_batch))

    # Slice out the padding as scaled_matmul does not support padded scales yet
    lhs_scale_3d = jnp.asarray(lhs_scale_3d[:, : lhs_3d.shape[1], : int(lhs_3d.shape[2] / 32)])
    rhs_scale_3d = jnp.asarray(rhs_scale_3d[:, : rhs_3d.shape[1], : int(rhs_3d.shape[2] / 32)])

    # JAX scaled_matmul only supports NT now (TN-gemm)
    #   * Expected shape:
    #   * lhs_data  (B, M, K)       * rhs_data  (B, N, K)
    #   * lhs_scale (B, M, K_block) * rhs_scale (B, N, K_block)
    out_3d = scaled_matmul_wrapper(
        lhs_3d, rhs_3d, lhs_scale_3d, rhs_scale_3d, preferred_element_type=lhs.dq_dtype
    )

    # Reshape [1, reduce(..., M), N] -> [..., M, N]
    lhs_remain_shape = tuple(
        lhs.data.shape[dim] for dim in range(len(lhs.data.shape)) if dim not in lhs_contract
    )
    rhs_remain_shape = tuple(
        rhs.data.shape[dim] for dim in range(len(rhs.data.shape)) if dim not in rhs_contract
    )
    out = out_3d.reshape(*lhs_remain_shape, *rhs_remain_shape)
    return out
def _jax_gemm(
    lhs: Union[jnp.ndarray, ScaledTensor],
    rhs: Union[jnp.ndarray, ScaledTensor],
    contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (0,)),
    quantizer_set: Dict["str", Quantizer] = noop_quantizer_set,
) -> jnp.ndarray:
    """
    FP8 GEMM via JAX
    """
    dim_nums = (contracting_dims, ((), ()))

    def _jax_gemm_fp8_impl(lhs, rhs):
        if lhs.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING:
            return _jax_gemm_delayed_scaling_fp8(lhs, rhs, dim_nums)

        if lhs.scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING:
            return _jax_gemm_mxfp8_1d(lhs, rhs, dim_nums)

        raise NotImplementedError(f"Unsupported ScalingMode: {lhs.scaling_mode}")

    if isinstance(lhs, ScaledTensor) and isinstance(rhs, ScaledTensor):
        return _jax_gemm_fp8_impl(lhs, rhs)

    if not isinstance(lhs, ScaledTensor) and not isinstance(rhs, ScaledTensor):
        if quantizer_set != noop_quantizer_set:
            assert type(quantizer_set.x) is type(quantizer_set.kernel)
            (((lhs_contract_dim,), (rhs_contract_dim,)), _) = dim_nums
            lhs_is_rowwise = lhs_contract_dim == lhs.ndim - 1
            rhs_is_rowwise = rhs_contract_dim == rhs.ndim - 1
            # Call JAX quantization so that XLA can do pattern matching (QDQ --> FP8 gemm)
            lhs_q = quantizer_set.x.quantize(
                lhs,
                is_rowwise=lhs_is_rowwise,
                is_colwise=not lhs_is_rowwise,
            )
            rhs_q = quantizer_set.kernel.quantize(
                rhs,
                is_rowwise=rhs_is_rowwise,
                is_colwise=not rhs_is_rowwise,
            )
            return _jax_gemm_fp8_impl(lhs_q, rhs_q)

    if (
        isinstance(lhs, jnp.ndarray)
        and isinstance(rhs, jnp.ndarray)
        and quantizer_set == noop_quantizer_set
    ):
        return jax.lax.dot_general(lhs, rhs, dim_nums, preferred_element_type=lhs.dtype)

    raise NotImplementedError("Not supporting multiplication of ScaledTensor and jnp.array")
def gemm(
    lhs: Union[jnp.ndarray, ScaledTensor],
    rhs: Union[jnp.ndarray, ScaledTensor],
    contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (0,)),
    quantizer_set: Dict["str", Quantizer] = noop_quantizer_set,
) -> jnp.ndarray:
    """General matrix multiplication with optional quantization.

    Args:
        lhs: First input matrix.
        rhs: Second input matrix.
        contracting_dims: Tuple of two sequences representing the contracting dimensions.
            The first sequence represents the contracting dimensions of the first matrix,
            and the second sequence represents the contracting dimensions of the second matrix.
        quantizer_set: Set of quantizers for FP8 quantization of the output.
            If None, no quantization is applied and the output has the same dtype as the inputs.

    Returns:
        If quantizer_set is None:
            The matrix multiplication result.
            Shape: (M, N)
            Dtype: Same as input dtype
        If quantizer_set is provided:
            A ScaledTensor containing the quantized matrix multiplication result.
    """
    return _jax_gemm(lhs, rhs, contracting_dims, quantizer_set)
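Usage sketch: with the default noop quantizer set, gemm reduces to jax.lax.dot_general over the given contracting dimensions (shapes below are illustrative):

    import jax.numpy as jnp

    a = jnp.ones((4, 8), dtype=jnp.bfloat16)
    b = jnp.ones((8, 16), dtype=jnp.bfloat16)
    out = gemm(a, b, contracting_dims=((1,), (0,)))  # equivalent to a @ b here
    assert out.shape == (4, 16)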
def swizzled_scale(scales):
    """Swizzle the scale tensor for FP8 GEMM"""
    assert scales.ndim == 2
    rows, cols = scales.shape
    scales = scales.reshape(rows // 128, 4, 32, cols // 4, 4)
    scales = jnp.transpose(scales, (0, 3, 2, 1, 4))
    return scales
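A quick shape check for the swizzle (assumed sizes; rows must be divisible by 128 and cols by 4 for the reshape to hold):

    import jax.numpy as jnp

    scales = jnp.zeros((256, 8), dtype=jnp.float32)
    out = swizzled_scale(scales)
    assert out.shape == (2, 2, 32, 4, 4)  # (rows // 128, cols // 4, 32, 4, 4)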
def grouped_gemm(
    lhs_list: List[Union[jnp.ndarray, ScaledTensor]],
    rhs_list: List[Union[jnp.ndarray, ScaledTensor]],
    contracting_dims_list: List[Tuple[Sequence[int], Sequence[int]]],
    bias_list: List[jnp.ndarray] = None,
) -> List[jnp.ndarray]:
    """Grouped GEMM for multiple pairs of tensors."""
    assert (
        len(lhs_list) == len(rhs_list) == len(contracting_dims_list)
    ), "lhs_list, rhs_list, contracting_dims_list must have the same length"

    # Flatten inputs and save their shapes
    num_gemms = len(lhs_list)
    out_flat_size = 0
    dims = []
    lhs_contig_ = []
    rhs_contig_ = []
    lhs_scale_inv_contig_ = []
    rhs_scale_inv_contig_ = []
    bias_contig_ = []
    out_offsets = []
    remain_shape_list = []
    for i in range(num_gemms):
        lhs = lhs_list[i]
        rhs = rhs_list[i]
        contracting_dims = contracting_dims_list[i]
        dim_nums = (contracting_dims, ((), ()))
        if isinstance(lhs, ScaledTensor) and isinstance(rhs, ScaledTensor):
            scaling_mode = lhs.scaling_mode
            lhs_shape = lhs.data.shape
            rhs_shape = rhs.data.shape
            out_dtype = lhs.dq_dtype
            # For ScaledTensors and NVTE_DELAYED_TENSOR_SCALING, need to handle internal layout
            if lhs.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING:
                assert not (
                    lhs.data.dtype == jnp.float8_e5m2 and rhs.data.dtype == jnp.float8_e5m2
                ), "FP8 GEMM does not support E5M2 * E5M2"
                ((lhs_contract_dim,), (rhs_contract_dim,)) = contracting_dims
                if lhs.layout == "T":
                    lhs_contract_dim = (lhs_contract_dim - 1) % lhs.data.ndim
                if rhs.layout == "T":
                    rhs_contract_dim = (rhs_contract_dim - 1) % rhs.data.ndim
                dim_nums = ((lhs_contract_dim,), (rhs_contract_dim,)), ((), ())
        else:
            # For jnp.ndarray, only consider contracting_dims, layout is always NN
            scaling_mode = ScalingMode.NVTE_NO_SCALING
            lhs_shape = lhs.shape
            rhs_shape = rhs.shape
            out_dtype = lhs.dtype

        (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dim_nums
        lhs_dn = (lhs_contract, lhs_batch)
        rhs_dn = (rhs_contract, rhs_batch)

        lhs_remain_shape = _calculate_remaining_shape(lhs_shape, lhs_contract)
        rhs_remain_shape = _calculate_remaining_shape(rhs_shape, rhs_contract)

        if scaling_mode == ScalingMode.NVTE_NO_SCALING:
            lhs_3d = _shape_normalization(lhs, lhs_dn)
            rhs_3d = _shape_normalization(rhs, rhs_dn)
        elif scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING:
            lhs_3d = _shape_normalization(lhs.data, lhs_dn, lhs.layout == "N")
            rhs_3d = _shape_normalization(rhs.data, rhs_dn, rhs.layout == "T")
        elif scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING:
            lhs_3d = _shape_normalization(lhs.data, lhs_dn)
            rhs_3d = _shape_normalization(rhs.data, rhs_dn)
            lhs_scale_inv = _shape_normalization(lhs.scale_inv, lhs_dn)
            rhs_scale_inv = _shape_normalization(rhs.scale_inv, rhs_dn)
            lhs_scale_inv = swizzled_scale(lhs_scale_inv.squeeze())
            rhs_scale_inv = swizzled_scale(rhs_scale_inv.squeeze())
        else:
            raise NotImplementedError(f"Unsupported ScalingMode: {scaling_mode}")

        # Note: if _shape_normalization() is updated to support non-TN, need to update here
        # already_transposed doesn't matter for the output shape
        # x.shape = [B, D1, D2]
        #   contracting_dims = (2,)    --> output.shape = [1, B * D1, D2]
        #   contracting_dims = (0, 1,) --> output.shape = [1, D2, B * D1]
        # x.shape = [D1, D2]
        #   contracting_dims = (1,)    --> output.shape = [1, D1, D2]
        #   contracting_dims = (0,)    --> output.shape = [1, D2, D1]
        bm = lhs_remain_shape[0]
        bn = rhs_remain_shape[0]
        kl = lhs_3d.shape[-1]
        kr = rhs_3d.shape[-1]
        remain_shape_list.append(((bm,), (bn,)))
        assert kl == kr, f"lhs_3d.shape[-1] ({kl}) != rhs_3d.shape[-1] ({kr})"
        k = kl

        if (bm % 16 != 0) or (bn % 16 != 0) or (k % 16 != 0):
            print(f"grouped_gemm input pair {i} has invalid problem shape for lowering: ")
            print(
                f"m = {bm}, n = {bn}, k = {k}; cuBLAS requires the problem shapes being multiples"
                " of 16"
            )
        assert bm % 16 == 0 and bn % 16 == 0 and k % 16 == 0

        dims.append((bm, bn, k))
        lhs_contig_.append(lhs_3d.reshape(-1))
        rhs_contig_.append(rhs_3d.reshape(-1))
        if scaling_mode == ScalingMode.NVTE_NO_SCALING:
            lhs_scale_inv_contig_.append(jnp.ones(1, dtype=jnp.float32))
            rhs_scale_inv_contig_.append(jnp.ones(1, dtype=jnp.float32))
        if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING:
            lhs_scale_inv_contig_.append(lhs.scale_inv.reshape(-1))
            rhs_scale_inv_contig_.append(rhs.scale_inv.reshape(-1))
        if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING:
            lhs_scale_inv_contig_.append(lhs_scale_inv.reshape(-1))
            rhs_scale_inv_contig_.append(rhs_scale_inv.reshape(-1))
        if bias_list is not None:
            bias_contig_.append(bias_list[i].reshape(-1))
        out_flat_size += bm * bn
        out_offsets.append(out_flat_size)

    lhs_contig = jnp.concatenate(lhs_contig_)
    rhs_contig = jnp.concatenate(rhs_contig_)
    lhs_scale_inv_contig = jnp.concatenate(lhs_scale_inv_contig_)
    rhs_scale_inv_contig = jnp.concatenate(rhs_scale_inv_contig_)
    bias_contig = jnp.empty(0) if bias_list is None else jnp.concatenate(bias_contig_)
    dim_list = jnp.array(dims, dtype=jnp.int32)

    # Perform batched GEMM on flattened inputs
    out_contig = GroupedGemmPrimitive.outer_primitive.bind(
        lhs_contig,
        lhs_scale_inv_contig,
        rhs_contig,
        rhs_scale_inv_contig,
        bias_contig,
        dim_list,
        num_gemms=num_gemms,
        scaling_mode=scaling_mode,
        out_dtype=out_dtype,
        out_flat_size=out_flat_size,
    )

    # Split the output back into tensors
    out_offsets = jnp.array(out_offsets)
    out_flat_list = jnp.split(out_contig, out_offsets[:-1])
    out_tensors = []
    for out_flat, (lhs_remain_shape, rhs_remain_shape) in zip(out_flat_list, remain_shape_list):
        out_tensors.append(out_flat.reshape(*lhs_remain_shape, *rhs_remain_shape))
    return out_tensors
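A pure-JAX reference (assumption: plain jnp inputs, i.e. the NVTE_NO_SCALING path) of what grouped_gemm computes per pair; every m, n, k below is a multiple of 16 to satisfy the cuBLAS constraint asserted above:

    import jax
    import jax.numpy as jnp

    lhs_list = [jnp.ones((16, 32)), jnp.ones((32, 32))]
    rhs_list = [jnp.ones((32, 16)), jnp.ones((32, 64))]
    cdims = [((1,), (0,)), ((1,), (0,))]

    ref = [
        jax.lax.dot_general(a, b, (cd, ((), ())))
        for a, b, cd in zip(lhs_list, rhs_list, cdims)
    ]
    # grouped_gemm(lhs_list, rhs_list, cdims) is expected to match ref pairwise.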
transformer_engine/jax/cpp_extensions/misc.py
...
@@ -11,14 +11,17 @@ from packaging.version import Version as PkgVersion
 import numpy as np
-import jax.numpy as jnp
 import jax
 from jax import dtypes
+import jax.numpy as jnp
 from jax.interpreters.mlir import dtype_to_ir_type

-from transformer_engine_jax import DType as TEDType
+import transformer_engine_jax

 from ..sharding import get_padded_spec as te_get_padded_spec
+from ..quantize import ScalingMode, ScaledTensorFactory, QuantizeAxis
+
+TEDType = transformer_engine_jax.DType


 def te_dtype_to_jax_dtype(te_dtype):
...
@@ -104,7 +107,7 @@ def normalize_axis_boundary(axis, ndim):
     return axis if axis >= 0 else ndim + axis


-def multidim_transpose(shape, static_axis_boundary, transpose_axis_boundary):
+def multidim_transpose(shape, static_axis_boundary=-1, transpose_axis_boundary=-1):
     """
     te_cast_transpose_p multi-dims transpose
...
@@ -158,17 +161,6 @@ def jax_version_meet_requirement(version: str):
     return jax_version >= jax_version_required


-def is_ffi_enabled():
-    """
-    Helper function checking if XLA Custom Call with FFI is enabled
-    """
-    is_supported = jax_version_meet_requirement("0.4.35")
-    # New APIs with FFI are enabled by default
-    is_enabled = int(os.getenv("NVTE_JAX_WITH_FFI", "1"))
-    assert is_enabled in (0, 1), "Invalid NVTE_JAX_WITH_FFI value"
-    return is_supported and is_enabled


 def get_xla_flag(flag: str, default=None, cast=str):
     """
     Returns the value of a flag/option in XLA_FLAGS environment variable if present or returns the default value.
...
@@ -189,3 +181,86 @@ def get_xla_flag(flag: str, default=None, cast=str):
         if name == flag:
             return True
     return default
def should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias: bool = False, quantizer=None):
    """
    Fused dbias is not supported for arch < 100 for 1x quantization, so we need to apply
    a workaround to calculate dbias separately. This function checks if the workaround
    should be applied.
    """
    arch_l_100 = False
    for local_gpu_id in range(len(jax.local_devices())):
        if transformer_engine_jax.get_device_compute_capability(local_gpu_id) < 100:
            arch_l_100 = True
            break
    return (
        quantizer is not None
        and quantizer.q_axis == QuantizeAxis.ROWWISE
        and arch_l_100
        and is_dbias
    )


def try_apply_delayed_scaling_2x_war(f, *args, quantizer=None, **kwargs):
    """
    Applies a workaround for delayed scaling 2x. It can be used when the TE common kernels
    do not yet support 2x delayed scaling. It calls the given function 'f' with the quantizer
    set to 1x and derives the colwise output by transposing the result. If 'f' returns a
    tuple, the first output must be the only ScaledTensor output.

    @param f: function to call
    @param args: positional arguments to pass to 'f'
    @param quantizer: quantizer to use
    @param kwargs: keyword arguments to pass to 'f'
    @return: the output of 'f' with the colwise output calculated
    """
    should_apply_war = (
        quantizer is not None
        and quantizer.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING
        and quantizer.is_2x2x()
    )
    if not should_apply_war:
        return None

    # 2x is not supported by TE kernels for delayed scaling,
    # so revert to 1x and transpose in JAX
    quantizer.q_axis = QuantizeAxis.ROWWISE
    rowwise = f(*args, **kwargs, quantizer=quantizer)
    other_outputs = None
    if isinstance(rowwise, tuple):
        other_outputs = rowwise[1:]
        rowwise = rowwise[0]
    quantizer.q_axis = QuantizeAxis.ROWWISE_COLWISE
    colwise_data = jnp.transpose(rowwise.data, (-1, *range(rowwise.data.ndim - 1)))
    output_2x = ScaledTensorFactory.create(
        data=rowwise.data,
        scale_inv=rowwise.scale_inv,
        colwise_data=colwise_data,
        colwise_scale_inv=rowwise.scale_inv,
        scaling_mode=quantizer.scaling_mode,
        dq_dtype=rowwise.dq_dtype,
        q_axis=QuantizeAxis.ROWWISE_COLWISE,
        layout=quantizer.get_layout(),
    )
    if other_outputs is not None:
        return (output_2x,) + other_outputs
    return output_2x
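The heart of the workaround is deriving the colwise tensor in JAX by rotating the last axis to the front, rather than asking the kernel for 2x output. The transpose rule in isolation:

    import jax.numpy as jnp

    rowwise = jnp.arange(24).reshape(2, 3, 4)
    colwise = jnp.transpose(rowwise, (-1, *range(rowwise.ndim - 1)))
    assert colwise.shape == (4, 2, 3)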
class NamedSharding(jax.sharding.NamedSharding):
    """
    Wrapper around jax.sharding.NamedSharding that adds a string description field
    as metadata for easier debugging.
    """

    def __init__(self, *args, desc: str = None, **kwargs):
        super().__init__(*args, **kwargs)
        self.desc = desc

    def __repr__(self):
        return f"NamedSharding({self.mesh}, {self.spec}, desc={self.desc})"

    def duplicate_with_new_description(self, desc: str):
        """
        Create a new NamedSharding with the same mesh and spec but with a new description.
        """
        return NamedSharding(self.mesh, self.spec, desc=desc)
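A usage sketch for the described sharding (the single-device mesh and axis name are illustrative assumptions):

    import numpy as np
    import jax
    from jax.sharding import PartitionSpec

    mesh = jax.sharding.Mesh(np.array(jax.devices()[:1]), ("dp",))
    s = NamedSharding(mesh, PartitionSpec("dp"), desc="example.activation")
    print(s)  # repr includes desc=example.activation
    s2 = s.duplicate_with_new_description("example.gradient")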
transformer_engine/jax/cpp_extensions/normalization.py
...
@@ -2,33 +2,38 @@
 #
 # See LICENSE for license information.
 """JAX/TE custom ops for normalization"""
-import operator
 import os
 import warnings
-from functools import partial, reduce, cache
+import operator
+from functools import partial, cache, reduce
 from typing import Optional, Union

+from packaging import version
+
 import jax
 import jax.numpy as jnp
 from jax import dtypes
-from jax.interpreters import mlir
-from jax.interpreters.mlir import ir
-from jax.sharding import PartitionSpec, NamedSharding
+from jax.sharding import PartitionSpec

 import transformer_engine_jax
 from transformer_engine_jax import NVTE_Norm_Type

 from .base import BasePrimitive, register_primitive
-from .custom_call import custom_caller, CustomCallArgsWrapper
 from .misc import (
     get_padded_spec,
     check_valid_batch_dims,
     jax_dtype_to_te_dtype,
-    jax_dtype_to_ir_dtype,
     te_dtype_to_jax_dtype,
-    is_ffi_enabled,
+    NamedSharding,
 )
-from .quantization import _jax_cast_fp8
 from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp
-from ..quantize import ScaledTensor, ScaledTensorFactory
+from ..quantize import (
+    Quantizer,
+    QuantizeAxis,
+    DelayedScaleQuantizer,
+    ScalingMode,
+)

 if version.parse(jax.__version__) >= version.parse("0.5.0"):
     from jax import ffi  # pylint: disable=ungrouped-imports
...
@@ -41,8 +46,8 @@ __all__ = [
     "layernorm_bwd",
     "rmsnorm_fwd",
     "rmsnorm_bwd",
-    "layernorm_fwd_fp8",
-    "rmsnorm_fwd_fp8",
+    "normalization_fwd",
+    "normalization_bwd",
 ]
...
@@ -58,325 +63,520 @@ def get_backward_sm_margin():
     return int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0"))
# ---- removed: LayerNormFwdPrimitive and the JAX-native FP8 norm helpers ----

class LayerNormFwdPrimitive(BasePrimitive):
    """
    Layer Normalization Forward Primitive
    """

    name = "te_layernorm_forward"
    multiple_results = True
    impl_static_args = (3, 4)  # zero_centered_gamma, epsilon
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(x_aval, gamma_aval, beta_aval, **kwargs):
        """
        LayerNorm fwd inner primitive abstract
        """
        x_dtype = dtypes.canonicalize_dtype(x_aval.dtype)
        assert x_dtype in [jnp.float32, jnp.float16, jnp.bfloat16]

        mu_rsigama_dtype = jnp.float32

        out_aval = x_aval
        mu_aval = rsigma_aval = out_aval.update(shape=out_aval.shape[:-1], dtype=mu_rsigama_dtype)

        assert gamma_aval.size == beta_aval.size
        hidden_size = gamma_aval.size
        assert x_aval.size % hidden_size == 0

        (wkspace_info,) = transformer_engine_jax.get_layernorm_fwd_workspace_sizes(
            x_aval.size // hidden_size,  # batch size
            hidden_size,
            jax_dtype_to_te_dtype(x_aval.dtype),  # in te_dtype
            jax_dtype_to_te_dtype(gamma_aval.dtype),  # weight te_dtype
            jax_dtype_to_te_dtype(x_aval.dtype),  # out te_dtype (same as input for Fp16/Bf16)
            True,
            kwargs["zero_centered_gamma"],
            kwargs["epsilon"],
            get_forward_sm_margin(),
        )
        wkspace_aval = x_aval.update(
            shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1])
        )

        return out_aval, mu_aval, rsigma_aval, wkspace_aval

    @staticmethod
    def outer_abstract(*args, **kwargs):
        """
        LayerNorm fwd outer primitive abstract
        """
        out_aval, mu_aval, rsigma_aval, _ = LayerNormFwdPrimitive.abstract(*args, **kwargs)
        return out_aval, mu_aval, rsigma_aval

    @staticmethod
    def lowering(ctx, x, gamma, beta, *, zero_centered_gamma, epsilon):
        """
        LayerNorm fwd lowering rules
        """
        x_aval, gamma_aval, beta_aval = ctx.avals_in
        assert gamma_aval.dtype == beta_aval.dtype
        x_type = ir.RankedTensorType(x.type)
        x_shape = x_type.shape
        g_type = ir.RankedTensorType(gamma.type)
        g_shape = g_type.shape
        b_type = ir.RankedTensorType(beta.type)
        b_shape = b_type.shape
        assert g_type == b_type
        assert g_shape == b_shape

        if is_ffi_enabled():
            name = "te_layernorm_forward_ffi"
            sm_margin = get_forward_sm_margin()
            out = ffi.ffi_lowering(name)(
                ctx,
                x,
                gamma,
                beta,
                zero_centered_gamma=zero_centered_gamma,
                eps=epsilon,
                sm_margin=sm_margin,
            )
        else:
            # Output shape is same as the input shape, but the output type is same as the weight type.
            # See ln_api.cpp
            output_type = g_type.element_type
            ir_mu_dtype = ir.F32Type.get()
            ir_rsigma_dtype = ir.F32Type.get()
            out_shape = x_shape
            hidden_size = reduce(operator.mul, g_shape)
            batch_shape = out_shape[:-1]
            batch_size = reduce(operator.mul, x_shape) // hidden_size
            wkspace_aval = ctx.avals_out[-1]

            out_types = [
                ir.RankedTensorType.get(out_shape, output_type),
                ir.RankedTensorType.get(batch_shape, ir_mu_dtype),
                ir.RankedTensorType.get(batch_shape, ir_rsigma_dtype),
                ir.RankedTensorType.get(
                    wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype)
                ),
            ]
            operands = [x, gamma, beta]
            operand_shapes = [x_shape, g_shape, b_shape]
            args = CustomCallArgsWrapper(out_types, operands, operand_shapes)

            sm_margin = get_forward_sm_margin()
            opaque = transformer_engine_jax.pack_norm_descriptor(
                batch_size,
                hidden_size,
                wkspace_aval.size,
                jax_dtype_to_te_dtype(x_aval.dtype),
                jax_dtype_to_te_dtype(gamma_aval.dtype),
                jax_dtype_to_te_dtype(wkspace_aval.dtype),
                zero_centered_gamma,
                epsilon,
                sm_margin,
            )
            out = custom_caller(LayerNormFwdPrimitive.name, args, opaque, False)
        return out

    @staticmethod
    def impl(x, gamma, beta, zero_centered_gamma, epsilon):
        """
        to describe implementation
        """
        assert LayerNormFwdPrimitive.inner_primitive is not None
        out, mu, rsigma, _ = LayerNormFwdPrimitive.inner_primitive.bind(
            x, gamma, beta, zero_centered_gamma=zero_centered_gamma, epsilon=epsilon
        )
        return out, mu, rsigma

    @staticmethod
    def batcher(batched_args, batch_dims, *, zero_centered_gamma, epsilon):
        """
        to describe batch rules for vmap
        """
        check_valid_batch_dims(batch_dims)
        assert LayerNormFwdPrimitive.outer_primitive is not None
        x, gamma, beta = batched_args
        x_bdim, _, _ = batch_dims
        out_bdims = x_bdim, x_bdim, x_bdim
        return (
            LayerNormFwdPrimitive.outer_primitive.bind(
                x, gamma, beta, zero_centered_gamma=zero_centered_gamma, epsilon=epsilon
            ),
            out_bdims,
        )

    @staticmethod
    def infer_sharding_from_operands(zero_centered_gamma, epsilon, mesh, arg_infos, result_infos):
        del zero_centered_gamma, epsilon, result_infos
        x_spec = get_padded_spec(arg_infos[0])
        if x_spec[-1] is not None:
            warnings.warn(
                f"Does not support to shard hidden dim in {LayerNormFwdPrimitive.name}! "
                "Force to not shard the hidden dim, which might introduce extra collective ops, "
                "and hurt performance."
            )
        out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None))
        mu_sharding = rsigma_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1]))
        return (out_sharding, mu_sharding, rsigma_sharding)

    @staticmethod
    def partition(zero_centered_gamma, epsilon, mesh, arg_infos, result_infos):
        del result_infos
        x_spec, g_spec, b_spec = map(get_padded_spec, arg_infos)
        if x_spec[-1] is not None:
            warnings.warn(
                f"Does not support to shard hidden dim in {LayerNormFwdPrimitive.name}! "
                "Force to not shard the hidden dim, which might introduce extra collective ops, "
                "and hurt performance."
            )
        if g_spec[-1] is not None:
            warnings.warn(
                f"{LayerNormFwdPrimitive.name} does not support sharding of parameter gamma "
                "Enforcing no sharding of parameters hidden dim! "
            )
        if b_spec[-1] is not None:
            warnings.warn(
                f"{LayerNormFwdPrimitive.name} does not support sharding of parameter beta "
                "Enforcing no sharding of parameters hidden dim! "
            )
        x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None))
        g_sharding = NamedSharding(mesh, PartitionSpec(None))
        b_sharding = NamedSharding(mesh, PartitionSpec(None))
        out_sharding = x_sharding
        mu_sharding = rsigma_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1]))
        arg_shardings = (x_sharding, g_sharding, b_sharding)
        out_shardings = (out_sharding, mu_sharding, rsigma_sharding)
        impl = partial(
            LayerNormFwdPrimitive.impl, zero_centered_gamma=zero_centered_gamma, epsilon=epsilon
        )
        return mesh, impl, out_shardings, arg_shardings


register_primitive(LayerNormFwdPrimitive)


def _jax_layernorm(x, gamma, beta, zero_centered_gamma, eps):
    """
    JAX native layernorm implementation
    """
    x_ = jnp.asarray(x, jnp.float32)
    mean = jnp.mean(x_, axis=-1, keepdims=True)
    var = jnp.mean(jnp.square(x_ - mean), axis=-1, keepdims=True)
    normed_input = (x_ - mean) * jax.lax.rsqrt(var + eps)
    if zero_centered_gamma:
        gamma += 1.0
    return jnp.asarray(normed_input * gamma + beta).astype(x.dtype)


def _jax_rmsnorm(x, gamma, zero_centered_gamma, eps):
    """
    JAX native rmsnorm implementation
    """
    x_ = jnp.asarray(x, jnp.float32)
    var = jnp.mean(jnp.square(x_), axis=-1, keepdims=True)
    normed_input = x_ * jax.lax.rsqrt(var + eps)
    if zero_centered_gamma:
        gamma += 1.0
    return jnp.asarray(normed_input * gamma).astype(x.dtype)


def _jax_layernorm_fp8(x, gamma, beta, scale, amax, out_dtype, zero_centered_gamma, eps):
    """
    JAX native layernorm fp8 implementation
    """
    x_ = jnp.asarray(x, jnp.float32)
    mean = jnp.mean(x_, axis=-1, keepdims=True)
    var = jnp.mean(jnp.square(x_ - mean), axis=-1, keepdims=True)
    rsigma = jax.lax.rsqrt(var + eps)
    normed_input = (x_ - mean) * rsigma
    if zero_centered_gamma:
        gamma += 1.0
    output = normed_input * gamma + beta
    casted_output, updated_amax = _jax_cast_fp8(output, scale, amax, out_dtype=out_dtype)
    return casted_output, jnp.squeeze(mean, axis=-1), jnp.squeeze(rsigma, axis=-1), updated_amax


def _jax_rmsnorm_fp8(x, gamma, scale, amax, out_dtype, zero_centered_gamma, eps):
    """
    JAX native rmsnorm fp8 implementation
    """
    x_ = jnp.asarray(x, jnp.float32)
    var = jnp.mean(jnp.square(x_), axis=-1, keepdims=True)
    rsigma = jax.lax.rsqrt(var + eps)
    normed_input = x_ * rsigma
    if zero_centered_gamma:
        gamma += 1.0
    output = normed_input * gamma
    casted_output, updated_amax = _jax_cast_fp8(output, scale, amax, out_dtype=out_dtype)
    return casted_output, jnp.squeeze(rsigma, axis=-1), updated_amax


def layernorm_fwd(
    x: jnp.ndarray, gamma: jnp.ndarray, beta: jnp.ndarray, zero_centered_gamma: bool, epsilon: float
):
    """
    Wrapper for TE layernorm fwd
    """
    if not LayerNormFwdPrimitive.enabled():
        x_ = jnp.asarray(x, jnp.float32)
        mu = jnp.mean(x_, axis=-1, keepdims=True)
        rsigma = jax.lax.rsqrt(
            jnp.mean(jnp.square(x_ - mu), axis=-1, keepdims=True) + epsilon
        )
        return (
            _jax_layernorm(x, gamma, beta, zero_centered_gamma, epsilon),
            jnp.squeeze(mu, axis=-1),
            jnp.squeeze(rsigma, axis=-1),
        )
    return LayerNormFwdPrimitive.outer_primitive.bind(
        x, gamma, beta, zero_centered_gamma=zero_centered_gamma, epsilon=epsilon
    )

# ---- added: NormFwdPrimitive (unified LayerNorm/RMSNorm forward) ----

class NormFwdPrimitive(BasePrimitive):
    """
    Layer Normalization Forward FP8 Primitive
    """

    name = "te_norm_forward_ffi"
    multiple_results = True
    impl_static_args = (4, 5, 6, 7, 8, 9, 10, 11, 12)
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(
        x_aval,
        scale_aval,
        gamma_aval,
        beta_aval,
        *,
        norm_type,
        zero_centered_gamma,
        epsilon,
        out_dtype,
        scaling_mode,
        is_2x,
        scale_dtype,
        scale_shapes,
        is_outer,
    ):
        """
        LayerNorm fwd inner primitive abstract
        """
        del scale_shapes
        x_dtype = dtypes.canonicalize_dtype(x_aval.dtype)
        assert x_dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert scale_aval is None or scale_aval.dtype == jnp.float32
        mu_rsigama_dtype = jnp.float32

        if norm_type == NVTE_Norm_Type.LayerNorm:
            assert gamma_aval.size == beta_aval.size

        (wkspace_info,) = transformer_engine_jax.get_norm_fwd_workspace_sizes(
            x_aval.size // gamma_aval.size,  # batch size
            gamma_aval.size,  # hidden size
            jax_dtype_to_te_dtype(x_aval.dtype),  # itype
            jax_dtype_to_te_dtype(gamma_aval.dtype),  # wtype
            jax_dtype_to_te_dtype(out_dtype),
            norm_type,
            scaling_mode.value,
            zero_centered_gamma,
            epsilon,
            get_forward_sm_margin(),
            is_2x,
        )

        out_aval = x_aval.update(shape=x_aval.shape, dtype=out_dtype)
        mu_aval = rsigma_aval = out_aval.update(shape=out_aval.shape[:-1], dtype=mu_rsigama_dtype)
        if norm_type == NVTE_Norm_Type.RMSNorm:
            mu_aval = mu_aval.update(shape=(1,))

        rowwise_scale_inv_shape, colwise_scale_inv_shape = scaling_mode.get_scale_shape_2x(
            x_aval.shape, is_padded=not is_outer
        )
        scale_inv_aval = jax.core.ShapedArray(shape=rowwise_scale_inv_shape, dtype=scale_dtype)
        colwise_scale_inv_aval = jax.core.ShapedArray(
            shape=colwise_scale_inv_shape, dtype=scale_dtype
        )
        colwise_out_aval = jax.core.ShapedArray(
            shape=x_aval.shape if is_2x else (1,), dtype=out_dtype
        )
        updated_amax_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32)
        wkspace_aval = x_aval.update(
            shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1])
        )

        return (
            out_aval,
            colwise_out_aval,
            scale_inv_aval,
            colwise_scale_inv_aval,
            updated_amax_aval,
            mu_aval,
            rsigma_aval,
            wkspace_aval,
        )

    @staticmethod
    def outer_abstract(*args, **kwargs):
        """
        LayerNorm fwd outer primitive abstract
        """
        (
            out_aval,
            colwise_out_aval,
            scale_inv_aval,
            colwise_scale_inv_aval,
            updated_amax_aval,
            mu_aval,
            rsigma_aval,
            _,
        ) = NormFwdPrimitive.abstract(*args, **kwargs)
        return (
            out_aval,
            colwise_out_aval,
            scale_inv_aval,
            colwise_scale_inv_aval,
            updated_amax_aval,
            mu_aval,
            rsigma_aval,
        )

    @staticmethod
    def lowering(
        ctx,
        x,
        scale,
        gamma,
        beta,
        *,
        norm_type,
        zero_centered_gamma,
        epsilon,
        out_dtype,
        scaling_mode,
        is_2x,
        scale_dtype,
        scale_shapes,
        is_outer,
    ):
        """
        LayerNorm fwd lowering rules
        """
        del out_dtype, scale_dtype, scale_shapes, is_outer
        x_aval, scale_aval, gamma_aval, beta_aval = ctx.avals_in
        assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert scale_aval is None or scale_aval.dtype == jnp.float32
        if norm_type == NVTE_Norm_Type.LayerNorm:
            assert gamma_aval.dtype == beta_aval.dtype
        sm_margin = get_forward_sm_margin()
        return ffi.ffi_lowering(NormFwdPrimitive.name)(
            ctx,
            x,
            scale,
            gamma,
            beta,
            norm_type=norm_type.value,
            zero_centered_gamma=zero_centered_gamma,
            epsilon=epsilon,
            sm_margin=sm_margin,
            scaling_mode=scaling_mode.value,
            is_2x=is_2x,
        )

    @staticmethod
    def impl(
        x,
        scale,
        gamma,
        beta,
        norm_type,
        zero_centered_gamma,
        epsilon,
        out_dtype,
        scaling_mode,
        is_2x,
        scale_dtype,
        scale_shapes,
        is_outer,
    ):
        """
        to describe implementation
        """
        del is_outer
        assert NormFwdPrimitive.inner_primitive is not None
        (
            out,
            colwise_out,
            scale_inv,
            colwise_scale_inv,
            updated_amax,
            mu,
            rsigma,
            _,
        ) = NormFwdPrimitive.inner_primitive.bind(
            x,
            scale,
            gamma,
            beta,
            norm_type=norm_type,
            zero_centered_gamma=zero_centered_gamma,
            epsilon=epsilon,
            out_dtype=out_dtype,
            scaling_mode=scaling_mode,
            is_2x=is_2x,
            scale_dtype=scale_dtype,
            scale_shapes=scale_shapes,
            is_outer=False,
        )
        rowwise_scale_inv_shape, colwise_scale_inv_shape = scaling_mode.get_scale_shape_2x(
            x.shape, is_padded=False
        )
        if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING:
            scale_inv = scale_inv.flatten()[
                : reduce(operator.mul, rowwise_scale_inv_shape)
            ].reshape(rowwise_scale_inv_shape)
            if is_2x:
                colwise_scale_inv = colwise_scale_inv.flatten()[
                    : reduce(operator.mul, colwise_scale_inv_shape)
                ].reshape(colwise_scale_inv_shape)
        return (
            out,
            colwise_out,
            scale_inv,
            colwise_scale_inv,
            updated_amax,
            mu,
            rsigma,
        )  # Exclude wkspace

    @staticmethod
    def batcher(
        batched_args,
        batch_dims,
        *,
        norm_type,
        zero_centered_gamma,
        epsilon,
        out_dtype,
        scaling_mode,
        is_2x,
        scale_dtype,
        scale_shapes,
        is_outer,
    ):
        """
        to describe batch rules for vmap
        """
        del is_outer
        check_valid_batch_dims(batch_dims)
        assert NormFwdPrimitive.outer_primitive is not None
        x, scale, gamma, beta = batched_args
        x_bdim, scale_bdim, _, _ = batch_dims
        out_bdims = (
            x_bdim,  # rowwise output
            scale_bdim,  # rowwise scale_inv
            x_bdim,  # colwise output
            scale_bdim,  # colwise scale_inv
            scale_bdim,  # amax
            x_bdim,  # mu
            x_bdim,  # rsigma
        )
        return (
            NormFwdPrimitive.outer_primitive.bind(
                x,
                scale,
                gamma,
                beta,
                norm_type=norm_type,
                zero_centered_gamma=zero_centered_gamma,
                epsilon=epsilon,
                out_dtype=out_dtype,
                scaling_mode=scaling_mode,
                is_2x=is_2x,
                scale_dtype=scale_dtype,
                scale_shapes=scale_shapes,
            ),
            out_bdims,
        )

    @staticmethod
    def infer_sharding_from_operands(
        norm_type,
        zero_centered_gamma,
        epsilon,
        out_dtype,
        scaling_mode,
        is_2x,
        scale_dtype,
        scale_shapes,
        is_outer,
        mesh,
        arg_infos,
        result_infos,
    ):
        del zero_centered_gamma, epsilon, out_dtype, result_infos
        del scale_dtype, scale_shapes, is_outer
        x_spec = get_padded_spec(arg_infos[0])
        if x_spec[-1] is not None:
            warnings.warn(
                f"Does not support to shard hidden dim in {NormFwdPrimitive.name}! "
                "Force to not shard the hidden dim, which might introduce extra collective ops, "
                "and hurt performance."
            )
        out_sharding = NamedSharding(
            mesh, PartitionSpec(*x_spec[:-1], None), desc="NormFwdPrimitive.out"
        )
        if is_2x:
            colwise_out_sharding = out_sharding.duplicate_with_new_description(
                "NormFwdPrimitive.colwise_out"
            )
        else:
            colwise_out_sharding = NamedSharding(
                mesh, PartitionSpec(None), desc="NormFwdPrimitive.colwise_out"
            )
        rsigma_sharding = NamedSharding(
            mesh, PartitionSpec(*x_spec[:-1]), desc="NormFwdPrimitive.rsigma"
        )
        mu_sharding = rsigma_sharding.duplicate_with_new_description("NormFwdPrimitive.mu")
        if norm_type == NVTE_Norm_Type.RMSNorm:
            mu_sharding = NamedSharding(mesh, PartitionSpec(None), desc="NormFwdPrimitive.mu")
        scale_inv_sharding = NamedSharding(
            mesh, PartitionSpec(*get_padded_spec(arg_infos[1])), desc="NormFwdPrimitive.scale_inv"
        )
        if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING:
            scale_inv_sharding = NamedSharding(
                mesh, PartitionSpec(*x_spec), desc="NormFwdPrimitive.scale_inv"
            )
        amax_sharding = NamedSharding(mesh, PartitionSpec(None), desc="NormFwdPrimitive.amax")
        output = (
            out_sharding,
            colwise_out_sharding,
            scale_inv_sharding,  # rowwise
            scale_inv_sharding,  # colwise
            amax_sharding,
            mu_sharding,
            rsigma_sharding,
        )
        return output

    @staticmethod
    def partition(
        norm_type,
        zero_centered_gamma,
        epsilon,
        out_dtype,
        scaling_mode,
        is_2x,
        scale_dtype,
        scale_shapes,
        is_outer,
        mesh,
        arg_infos,
        result_infos,
    ):
        del result_infos, is_outer
        x_spec = get_padded_spec(arg_infos[0])
        g_spec = get_padded_spec(arg_infos[2])
        b_spec = get_padded_spec(arg_infos[3])
        if x_spec[-1] is not None:
            warnings.warn(
                f"Does not support to shard hidden dim in {NormFwdPrimitive.name}! "
                "Force to not shard the hidden dim, which might introduce extra collective ops, "
                "and hurt performance."
            )
        if g_spec[-1] is not None:
            warnings.warn(
                f"{NormFwdPrimitive.name} does not support sharding of parameter gamma "
                "Enforcing no sharding of parameters hidden dim! "
            )
        if b_spec[-1] is not None:
            warnings.warn(
                f"{NormFwdPrimitive.name} does not support sharding of parameter beta "
                "Enforcing no sharding of parameters hidden dim! "
            )
        x_sharding = NamedSharding(
            mesh, PartitionSpec(*x_spec[:-1], None), desc="NormFwdPrimitive.x"
        )
        g_sharding = NamedSharding(mesh, PartitionSpec(None), desc="NormFwdPrimitive.gamma")
        b_sharding = NamedSharding(mesh, PartitionSpec(None), desc="NormFwdPrimitive.beta")
        out_sharding = x_sharding.duplicate_with_new_description("NormFwdPrimitive.out")
        if is_2x:
            colwise_out_sharding = out_sharding.duplicate_with_new_description(
                "NormFwdPrimitive.colwise_out"
            )
        else:
            colwise_out_sharding = NamedSharding(
                mesh, PartitionSpec(None), desc="NormFwdPrimitive.colwise_out"
            )
        rsigma_sharding = NamedSharding(
            mesh,
            PartitionSpec(*get_padded_spec(arg_infos[0])[:-1]),
            desc="NormFwdPrimitive.rsigma",
        )
        mu_sharding = rsigma_sharding.duplicate_with_new_description("NormFwdPrimitive.mu")
        if norm_type == NVTE_Norm_Type.RMSNorm:
            mu_sharding = NamedSharding(mesh, PartitionSpec(None), desc="NormFwdPrimitive.mu")
        scale_sharding = NamedSharding(
            mesh, PartitionSpec(*get_padded_spec(arg_infos[1])), desc="NormFwdPrimitive.scale"
        )
        scale_inv_sharding = scale_sharding.duplicate_with_new_description(
            "NormFwdPrimitive.scale_inv"
        )
        amax_sharding = NamedSharding(mesh, PartitionSpec(None), desc="NormFwdPrimitive.amax")
        if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING:
            scale_inv_sharding = NamedSharding(
                mesh, PartitionSpec(*x_spec), desc="NormFwdPrimitive.scale_inv"
            )
        arg_shardings = (x_sharding, scale_sharding, g_sharding, b_sharding)
        out_shardings = (
            out_sharding,
            colwise_out_sharding,
            scale_inv_sharding,  # rowwise
            scale_inv_sharding,  # colwise
            amax_sharding,
            mu_sharding,
            rsigma_sharding,
        )

        def sharded_impl(x, scale, gamma, beta):
            # expect tp and dp giving same shape, or tp being same shape as global
            (
                local_x,
                local_colwise_x,
                local_scale_inv,
                local_colwise_scale_inv,
                local_amax,
                local_mu,
                local_rsigma,
            ) = NormFwdPrimitive.impl(
                x,
                scale,
                gamma,
                beta,
                norm_type=norm_type,
                zero_centered_gamma=zero_centered_gamma,
                epsilon=epsilon,
                out_dtype=out_dtype,
                scaling_mode=scaling_mode,
                is_2x=is_2x,
                scale_dtype=scale_dtype,
                scale_shapes=scale_shapes,
                is_outer=True,
            )
            if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING:
                global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh)
            else:
                global_updated_amax = local_amax
            return (
                local_x,
                local_colwise_x,
                local_scale_inv,
                local_colwise_scale_inv,
                global_updated_amax,
                local_mu,
                local_rsigma,
            )

        return mesh, sharded_impl, out_shardings, arg_shardings


register_primitive(NormFwdPrimitive)
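The un-padding of scale_inv in NormFwdPrimitive.impl is just a flatten, slice, and reshape; in isolation (shapes are illustrative assumptions, not the kernel's real padded sizes):

    import operator
    from functools import reduce
    import jax.numpy as jnp

    padded = jnp.zeros(16)                 # flattened, padded scale_inv
    unpadded_shape = (3, 4)                # 12 of the 16 entries are meaningful
    trimmed = padded.flatten()[: reduce(operator.mul, unpadded_shape)].reshape(unpadded_shape)
    assert trimmed.shape == (3, 4)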
class
Layer
NormBwdPrimitive
(
BasePrimitive
):
class
NormBwdPrimitive
(
BasePrimitive
):
"""
Layer Normalization Backward Primitive
"""
name
=
"te_
layer
norm_backward"
name
=
"te_norm_backward
_ffi
"
multiple_results
=
True
impl_static_args
=
(
5
,
6
)
# zero_centered_gamma
, epsilon
impl_static_args
=
(
5
,
6
)
#
norm_type,
zero_centered_gamma
inner_primitive
=
None
outer_primitive
=
None
@
staticmethod
def
abstract
(
dz_aval
,
x_aval
,
mu_aval
,
rsigma_aval
,
gamma_aval
,
**
kwargs
):
def
abstract
(
dz_aval
,
x_aval
,
mu_aval
,
rsigma_aval
,
gamma_aval
,
norm_type
,
zero_centered_gamma
):
"""
Layernorm
bwd inner primitive abstract
bwd inner primitive abstract
"""
w_dtype
=
dtypes
.
canonicalize_dtype
(
gamma_aval
.
dtype
)
mu_dtype
=
dtypes
.
canonicalize_dtype
(
mu_aval
.
dtype
)
rsigma_dtype
=
dtypes
.
canonicalize_dtype
(
rsigma_aval
.
dtype
)
assert
dtypes
.
canonicalize_dtype
(
dz_aval
.
dtype
)
==
w_dtype
assert
dz_aval
.
shape
==
x_aval
.
shape
if
norm_type
==
NVTE_Norm_Type
.
LayerNorm
:
mu_dtype
=
dtypes
.
canonicalize_dtype
(
mu_aval
.
dtype
)
assert
mu_aval
.
shape
==
rsigma_aval
.
shape
==
x_aval
.
shape
[:
-
1
]
assert
mu_dtype
==
rsigma_dtype
==
jnp
.
float32
dx_aval
=
dz_aval
dgamma_aval
=
dbeta_aval
=
gamma_aval
if
norm_type
!=
NVTE_Norm_Type
.
LayerNorm
:
dbeta_aval
=
dbeta_aval
.
update
(
shape
=
(
1
,))
(
wkspace_info
,)
=
transformer_engine_jax
.
get_
layer
norm_bwd_workspace_sizes
(
(
wkspace_info
,)
=
transformer_engine_jax
.
get_norm_bwd_workspace_sizes
(
x_aval
.
size
//
gamma_aval
.
size
,
# batch size
gamma_aval
.
size
,
# hidden size
jax_dtype_to_te_dtype
(
x_aval
.
dtype
),
# input te_dtype
jax_dtype_to_te_dtype
(
gamma_aval
.
dtype
),
# weight te_dtype
True
,
kwargs
[
"zero_centered_gamma"
],
kwargs
[
"epsilon"
],
norm_type
,
zero_centered_gamma
,
get_backward_sm_margin
(),
)
wkspace_aval
=
dx_aval
.
update
(
...
...
@@ -395,17 +595,14 @@ class LayerNormBwdPrimitive(BasePrimitive):
"""
LayerNorm bwd outer primitive abstract
"""
dx_aval
,
dgamma_aval
,
dbeta_aval
,
_
=
Layer
NormBwdPrimitive
.
abstract
(
*
args
,
**
kwargs
)
dx_aval
,
dgamma_aval
,
dbeta_aval
,
_
=
NormBwdPrimitive
.
abstract
(
*
args
,
**
kwargs
)
return
dx_aval
,
dgamma_aval
,
dbeta_aval
@
staticmethod
def
lowering
(
ctx
,
dz
,
x
,
mu
,
rsigma
,
gamma
,
*
,
zero_centered_gamma
,
epsilon
):
def
lowering
(
ctx
,
dz
,
x
,
mu
,
rsigma
,
gamma
,
*
,
norm_type
,
zero_centered_gamma
):
"""
Layernorm
bwd lowering rules
bwd lowering rules
"""
_
,
x_aval
,
_
,
_
,
gamma_aval
=
ctx
.
avals_in
x_type
=
ir
.
RankedTensorType
(
x
.
type
)
x_shape
=
x_type
.
shape
g_type
=
ir
.
RankedTensorType
(
gamma
.
type
)
g_shape
=
g_type
.
shape
b_type
=
ir
.
RankedTensorType
(
gamma
.
type
)
...
...
@@ -413,1124 +610,644 @@ class LayerNormBwdPrimitive(BasePrimitive):
assert
g_type
==
b_type
assert
g_shape
==
b_shape
if
is_ffi_enabled
():
name
=
"te_layernorm_backward_ffi"
sm_margin
=
get_backward_sm_margin
()
out
=
ffi
.
ffi_lowering
(
name
)(
return
ffi
.
ffi_lowering
(
NormBwdPrimitive
.
name
)(
ctx
,
dz
,
x
,
mu
,
rsigma
,
gamma
,
norm_type
=
norm_type
.
value
,
zero_centered_gamma
=
zero_centered_gamma
,
eps
=
epsilon
,
sm_margin
=
sm_margin
,
)
else
:
dz_shape
=
ir
.
RankedTensorType
(
dz
.
type
).
shape
mu_shape
=
ir
.
RankedTensorType
(
mu
.
type
).
shape
rsigma_shape
=
ir
.
RankedTensorType
(
rsigma
.
type
).
shape
hidden_size
=
reduce
(
operator
.
mul
,
g_shape
)
batch_size
=
reduce
(
operator
.
mul
,
x_shape
)
//
hidden_size
out_types
=
[
ir
.
RankedTensorType
.
get
(
output
.
shape
,
mlir
.
dtype_to_ir_type
(
output
.
dtype
))
for
output
in
ctx
.
avals_out
]
operands
=
[
dz
,
mu
,
rsigma
,
x
,
gamma
]
operand_shapes
=
[
dz_shape
,
mu_shape
,
rsigma_shape
,
x_shape
,
g_shape
]
args
=
CustomCallArgsWrapper
(
out_types
,
operands
,
operand_shapes
)
sm_margin
=
get_backward_sm_margin
()
wkspace_aval
=
ctx
.
avals_out
[
-
1
]
opaque
=
transformer_engine_jax
.
pack_norm_descriptor
(
batch_size
,
hidden_size
,
wkspace_aval
.
size
,
jax_dtype_to_te_dtype
(
x_aval
.
dtype
),
jax_dtype_to_te_dtype
(
gamma_aval
.
dtype
),
jax_dtype_to_te_dtype
(
wkspace_aval
.
dtype
),
zero_centered_gamma
,
epsilon
,
sm_margin
,
)
out
=
custom_caller
(
LayerNormBwdPrimitive
.
name
,
args
,
opaque
,
False
)
return
out
@
staticmethod
def
impl
(
dz
,
x
,
mu
,
rsigma
,
gamma
,
zero_centered_gamma
,
epsilon
):
assert
Layer
NormBwdPrimitive
.
inner_primitive
is
not
None
dx
,
dgamma
,
dbeta
,
_
=
Layer
NormBwdPrimitive
.
inner_primitive
.
bind
(
dz
,
x
,
mu
,
rsigma
,
gamma
,
zero_centered_gamma
=
zero_centered_gamma
,
epsilon
=
epsilon
def
impl
(
dz
,
x
,
mu
,
rsigma
,
gamma
,
norm_type
,
zero_centered_gamma
):
assert
NormBwdPrimitive
.
inner_primitive
is
not
None
dx
,
dgamma
,
dbeta
,
_
=
NormBwdPrimitive
.
inner_primitive
.
bind
(
dz
,
x
,
mu
,
rsigma
,
gamma
,
norm_type
=
norm_type
,
zero_centered_gamma
=
zero_centered_gamma
)
return
dx
,
dgamma
,
dbeta
@
staticmethod
def
batcher
(
batched_args
,
batch_dims
,
*
,
zero_centered_gamma
,
epsilon
):
def
batcher
(
batched_args
,
batch_dims
,
*
,
norm_type
,
zero_centered_gamma
):
check_valid_batch_dims
(
batch_dims
)
assert
Layer
NormBwdPrimitive
.
outer_primitive
is
not
None
assert
NormBwdPrimitive
.
outer_primitive
is
not
None
dz
,
x
,
mu
,
rsigma
,
gamma
=
batched_args
_
,
x_bdim
,
_
,
_
,
gamma_bdim
=
batch_dims
out_bdims
=
x_bdim
,
gamma_bdim
,
gamma_bdim
return
(
LayerNormBwdPrimitive
.
outer_primitive
.
bind
(
dz
,
x
,
mu
,
rsigma
,
gamma
,
zero_centered_gamma
=
zero_centered_gamma
,
epsilon
=
epsilon
NormBwdPrimitive
.
outer_primitive
.
bind
(
dz
,
x
,
mu
,
rsigma
,
gamma
,
norm_type
=
norm_type
,
zero_centered_gamma
=
zero_centered_gamma
,
),
out_bdims
,
)
    # --- removed ---
    @staticmethod
    def infer_sharding_from_operands(zero_centered_gamma, epsilon, mesh, arg_infos, result_infos):
        del zero_centered_gamma, epsilon, result_infos
        x_spec = get_padded_spec(arg_infos[1])
        if x_spec[-1] is not None:
            warnings.warn(
                f"Does not support to shard hidden dim in {LayerNormBwdPrimitive.name}! "
                "Force to not shard the hidden dim, which might introduce extra collective ops, "
                "and hurt performance."
            )
        g_b_spec = get_padded_spec(arg_infos[4])
        if g_b_spec[-1] is not None:
            warnings.warn(
                f"{LayerNormBwdPrimitive.name} does not support sharding of gradients "
                "of gamma and beta of Layernorm "
                "Enforcing no sharding of parameters hidden dim! "
            )
        dx_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None))
        dgamma_sharding = dbeta_sharding = NamedSharding(mesh, PartitionSpec(None))
        return dx_sharding, dgamma_sharding, dbeta_sharding

    # +++ added +++
    @staticmethod
    def infer_sharding_from_operands(norm_type, zero_centered_gamma, mesh, arg_infos, result_infos):
        del norm_type, zero_centered_gamma, result_infos
        x_spec = get_padded_spec(arg_infos[1])
        if x_spec[-1] is not None:
            warnings.warn(
                f"Does not support to shard hidden dim in {NormBwdPrimitive.name}! "
                "Force to not shard the hidden dim, which might introduce extra collective ops, "
                "and hurt performance."
            )
        g_b_spec = get_padded_spec(arg_infos[4])
        if g_b_spec[-1] is not None:
            warnings.warn(
                f"{NormBwdPrimitive.name} does not support sharding of gradients "
                "of gamma and beta of "
                "Enforcing no sharding of parameters hidden dim! "
            )
        dx_sharding = NamedSharding(
            mesh, PartitionSpec(*x_spec[:-1], None), desc="NormBwdPrimitive.dx"
        )
        dgamma_sharding = dbeta_sharding = NamedSharding(
            mesh, PartitionSpec(None), desc="NormBwdPrimitive.dgamma"
        )
        return dx_sharding, dgamma_sharding, dbeta_sharding
    # --- removed: partition(zero_centered_gamma, epsilon, ...) on LayerNormBwdPrimitive ---
    # +++ added: partition(norm_type, zero_centered_gamma, ...) on NormBwdPrimitive +++
    @staticmethod
    def partition(norm_type, zero_centered_gamma, mesh, arg_infos, result_infos):
        del result_infos
        x_spec = get_padded_spec(arg_infos[1])
        if x_spec[-1] is not None:
            warnings.warn(
                f"Does not support to shard hidden dim in {NormBwdPrimitive.name}! "
                "Force to not shard the hidden dim, which might introduce extra collective ops, "
                "and hurt performance."
            )
        g_b_spec = get_padded_spec(arg_infos[4])
        if g_b_spec[-1] is not None:
            warnings.warn(
                f"{NormBwdPrimitive.name} does not support sharding of gradients "
                "of gamma and beta of "
                "Enforcing no sharding of parameters hidden dim! "
            )
        dx_sharding = NamedSharding(
            mesh, PartitionSpec(*x_spec[:-1], None), desc="NormBwdPrimitive.dx"
        )
        dgamma_sharding = dbeta_sharding = NamedSharding(
            mesh, PartitionSpec(None), desc="NormBwdPrimitive.dgamma"
        )
        out_shardings = dx_sharding, dgamma_sharding, dbeta_sharding
        x_shardings = (dx_sharding,) * 2  # dz and x should have the same sharding.
        # (removed) mu_shardings = (NamedSharding(mesh, PartitionSpec(*x_spec[:-1])),) * 2
        # (removed) arg_shardings = (*x_shardings, *mu_shardings, NamedSharding(mesh, PartitionSpec(None)))
        rsigma_sharding = NamedSharding(
            mesh, PartitionSpec(*x_spec[:-1]), desc="NormBwdPrimitive.rsigma"
        )
        mu_sharding = rsigma_sharding.duplicate_with_new_description("NormBwdPrimitive.mu")
        if norm_type == NVTE_Norm_Type.RMSNorm:
            mu_sharding = NamedSharding(mesh, PartitionSpec(None), desc="NormBwdPrimitive.mu")
        arg_shardings = (
            *x_shardings,
            mu_sharding,
            rsigma_sharding,
            NamedSharding(mesh, PartitionSpec(None), desc="NormBwdPrimitive.gamma"),
        )

        def sharded_impl(dz, x, mu, rsigma, gamma):
            # (removed) local_dx, local_dgamma, local_dbeta = LayerNormBwdPrimitive.impl(
            #     dz, x, mu, rsigma, gamma, zero_centered_gamma=zero_centered_gamma, epsilon=epsilon)
            local_dx, local_dgamma, local_dbeta = NormBwdPrimitive.impl(
                dz,
                x,
                mu,
                rsigma,
                gamma,
                norm_type=norm_type,
                zero_centered_gamma=zero_centered_gamma,
            )
            global_dgamma = all_reduce_sum_along_dp_fsdp(local_dgamma, mesh)
            if norm_type == NVTE_Norm_Type.LayerNorm:
                global_dbeta = all_reduce_sum_along_dp_fsdp(local_dbeta, mesh)
            else:
                global_dbeta = local_dbeta
            return local_dx, global_dgamma, global_dbeta

        return mesh, sharded_impl, out_shardings, arg_shardings
# --- removed ---
register_primitive(LayerNormBwdPrimitive)
# +++ added +++
register_primitive(NormBwdPrimitive)
# --- removed ---
def layernorm_bwd(
    dz: jnp.ndarray,
    x: jnp.ndarray,
    mu: jnp.ndarray,
    rsigma: jnp.ndarray,
    gamma: jnp.ndarray,
    beta: jnp.ndarray,
    zero_centered_gamma: bool,
    epsilon: float,
):
    """
    Wrapper for TE layernorm bwd
    """
    if not LayerNormBwdPrimitive.enabled():
        _, vjp_func = jax.vjp(
            partial(_jax_layernorm, zero_centered_gamma=zero_centered_gamma, eps=epsilon),
            x,
            gamma,
            beta,
        )
        return vjp_func(dz)
    return LayerNormBwdPrimitive.outer_primitive.bind(
        dz, x, mu, rsigma, gamma, zero_centered_gamma=zero_centered_gamma, epsilon=epsilon
    )


# +++ added +++
def _jax_layernorm(x, gamma, beta, zero_centered_gamma, epsilon, quantizer=None):
    """
    JAX native layernorm implementation
    """
    x_ = jnp.asarray(x, jnp.float32)
    mean = jnp.mean(x_, axis=-1, keepdims=True)
    var = jnp.mean(jnp.square(x_ - mean), axis=-1, keepdims=True)
    rsigma = jax.lax.rsqrt(var + epsilon)
    normed_input = (x_ - mean) * rsigma
    if zero_centered_gamma:
        gamma += 1.0
    output = normed_input * gamma + beta
    if quantizer:
        ln_out = quantizer.quantize(output, dq_dtype=x.dtype)
    else:
        ln_out = jnp.asarray(output).astype(x.dtype)
    return ln_out, jnp.squeeze(mean, axis=-1), jnp.squeeze(rsigma, axis=-1)
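For intuition, the math in `_jax_layernorm` can be sanity-checked with a few lines of pure JAX; a minimal sketch (the helper name and tolerances here are illustrative, not part of TE):

    import jax
    import jax.numpy as jnp

    def ref_layernorm(x, gamma, beta, eps=1e-6):
        # y = (x - mean) * rsqrt(var + eps) * gamma + beta, over the last axis
        mean = jnp.mean(x, axis=-1, keepdims=True)
        var = jnp.mean(jnp.square(x - mean), axis=-1, keepdims=True)
        return (x - mean) * jax.lax.rsqrt(var + eps) * gamma + beta

    x = jax.random.normal(jax.random.PRNGKey(0), (4, 128), dtype=jnp.float32)
    y = ref_layernorm(x, jnp.ones((128,)), jnp.zeros((128,)))
    # Each row should come out with ~zero mean and ~unit variance.
    assert jnp.allclose(jnp.mean(y, axis=-1), 0.0, atol=1e-5)
    assert jnp.allclose(jnp.var(y, axis=-1), 1.0, atol=1e-2)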
# --- removed ---
class RmsNormFwdPrimitive(BasePrimitive):
    """
    RMS Normalization Forward Primitive
    """

    name = "te_rmsnorm_forward"
    multiple_results = True
    impl_static_args = (2,)  # epsilon
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(x_aval, gamma_aval, **kwargs):
        """
        RMSNorm fwd inner primitive abstract
        """
        x_dtype = dtypes.canonicalize_dtype(x_aval.dtype)
        assert x_dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        rsigama_dtype = jnp.float32
        out_aval = x_aval
        rsigma_aval = out_aval.update(shape=out_aval.shape[:-1], dtype=rsigama_dtype)
        hidden_size = gamma_aval.size
        assert x_aval.size % hidden_size == 0
        (wkspace_info,) = transformer_engine_jax.get_layernorm_fwd_workspace_sizes(
            x_aval.size // hidden_size,  # batch size
            hidden_size,
            jax_dtype_to_te_dtype(x_aval.dtype),  # in te_dtype
            jax_dtype_to_te_dtype(gamma_aval.dtype),  # weight te_dtype
            jax_dtype_to_te_dtype(x_aval.dtype),  # out te_dtype (same as input for Fp16/Bf16)
            False,
            False,
            kwargs["epsilon"],
            get_forward_sm_margin(),
        )
        wkspace_aval = out_aval.update(
            shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1])
        )
        return out_aval, rsigma_aval, wkspace_aval

    @staticmethod
    def outer_abstract(*args, **kwargs):
        """
        RMSNorm fwd outer primitive abstract
        """
        out_aval, rsigma_aval, _ = RmsNormFwdPrimitive.abstract(*args, **kwargs)
        return out_aval, rsigma_aval

    @staticmethod
    def lowering(ctx, x, gamma, *, epsilon):
        """
        RMSNorm fwd lowering rules
        """
        if is_ffi_enabled():
            name = "te_rmsnorm_forward_ffi"
            sm_margin = get_forward_sm_margin()
            zero_centered_gamma = False  # RMSNorm doesn't support zero_centered_gamma
            out = ffi.ffi_lowering(name)(
                ctx,
                x,
                gamma,
                zero_centered_gamma=zero_centered_gamma,
                eps=epsilon,
                sm_margin=sm_margin,
            )
        else:
            x_aval, gamma_aval = ctx.avals_in
            x_type = ir.RankedTensorType(x.type)
            x_shape = x_type.shape
            g_type = ir.RankedTensorType(gamma.type)
            g_shape = g_type.shape
            rsigma_element_type = ir.F32Type.get()
            out_shape = x_shape
            hidden_size = reduce(operator.mul, g_shape)
            batch_shape = out_shape[:-1]
            batch_size = reduce(operator.mul, x_shape) // hidden_size
            wkspace_aval = ctx.avals_out[-1]
            out_types = [
                ir.RankedTensorType.get(out_shape, x_type.element_type),
                ir.RankedTensorType.get(batch_shape, rsigma_element_type),
                ir.RankedTensorType.get(
                    wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype)
                ),
            ]
            operands = [x, gamma]
            operand_shapes = [x_shape, g_shape]
            args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
            sm_margin = get_forward_sm_margin()
            opaque = transformer_engine_jax.pack_norm_descriptor(
                batch_size,
                hidden_size,
                wkspace_aval.size,
                jax_dtype_to_te_dtype(x_aval.dtype),
                jax_dtype_to_te_dtype(gamma_aval.dtype),
                jax_dtype_to_te_dtype(wkspace_aval.dtype),
                False,  # RMSNorm doesn't support zero_centered_gamma
                epsilon,
                sm_margin,
            )
            out = custom_caller(RmsNormFwdPrimitive.name, args, opaque, False)
        return out

    @staticmethod
    def impl(x, gamma, epsilon):
        """
        to describe implementation
        """
        assert RmsNormFwdPrimitive.inner_primitive is not None
        out, rsigma, _ = RmsNormFwdPrimitive.inner_primitive.bind(x, gamma, epsilon=epsilon)
        return out, rsigma

    @staticmethod
    def batcher(batched_args, batch_dims, *, epsilon):
        """
        to describe batch rules for vmap
        """
        check_valid_batch_dims(batch_dims)
        assert RmsNormFwdPrimitive.outer_primitive is not None
        x, gamma = batched_args
        x_bdim, _ = batch_dims
        out_bdims = x_bdim, x_bdim
        return RmsNormFwdPrimitive.outer_primitive.bind(x, gamma, epsilon=epsilon), out_bdims

    @staticmethod
    def infer_sharding_from_operands(epsilon, mesh, arg_infos, result_infos):
        del epsilon, result_infos
        x_spec = get_padded_spec(arg_infos[0])
        if x_spec[-1] is not None:
            warnings.warn(
                f"Does not support to shard hidden dim in {RmsNormFwdPrimitive.name}! "
                "Force to not shard the hidden dim, which might introduce extra collective ops, "
                "and hurt performance."
            )
        out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None))
        rsigma_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1]))
        return (out_sharding, rsigma_sharding)

    @staticmethod
    def partition(epsilon, mesh, arg_infos, result_infos):
        del result_infos
        x_spec, g_spec = map(get_padded_spec, arg_infos)
        if x_spec[-1] is not None:
            warnings.warn(
                f"Does not support to shard hidden dim in {RmsNormFwdPrimitive.name}! "
                "Force to not shard the hidden dim, which might introduce extra collective ops, "
                "and hurt performance."
            )
        if g_spec[-1] is not None:
            warnings.warn(
                f"{RmsNormFwdPrimitive.name} does not support sharding of parameter gamma "
                "Enforcing no sharding of parameters hidden dim! "
            )
        x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None))
        g_sharding = NamedSharding(mesh, PartitionSpec(None))
        out_sharding = x_sharding
        rsigma_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1]))
        arg_shardings = (x_sharding, g_sharding)
        out_shardings = (out_sharding, rsigma_sharding)
        impl = partial(RmsNormFwdPrimitive.impl, epsilon=epsilon)
        return mesh, impl, out_shardings, arg_shardings


register_primitive(RmsNormFwdPrimitive)
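The `batcher` rule above is what lets `jax.vmap` map over a leading batch axis of the primitive. A minimal pure-JAX stand-in showing the behavior being replicated (the `rmsnorm_ref` helper is hypothetical, not TE code):

    import jax
    import jax.numpy as jnp

    def rmsnorm_ref(x, gamma, eps=1e-6):
        # RMSNorm over the last axis: x * rsqrt(mean(x^2) + eps) * gamma
        return x * jax.lax.rsqrt(jnp.mean(jnp.square(x), axis=-1, keepdims=True) + eps) * gamma

    # vmap over a leading batch axis; for the TE primitive this role is played
    # by the batcher rule, which forwards to outer_primitive.bind.
    xs = jnp.ones((8, 4, 64))
    gamma = jnp.ones((64,))
    ys = jax.vmap(lambda x: rmsnorm_ref(x, gamma))(xs)
    print(ys.shape)  # (8, 4, 64)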
# --- removed ---
def rmsnorm_fwd(x: jnp.ndarray, gamma: jnp.ndarray, epsilon: float):
    """
    Wrapper for TE rmsnorm fwd
    """
    if not RmsNormFwdPrimitive.enabled():
        x_ = jnp.asarray(x, jnp.float32)
        rsigma = jax.lax.rsqrt(jnp.mean(jnp.square(x_), axis=-1, keepdims=True) + epsilon)
        return _jax_rmsnorm(x, gamma, zero_centered_gamma=False, eps=epsilon), jnp.squeeze(
            rsigma, axis=-1
        )
    return RmsNormFwdPrimitive.outer_primitive.bind(x, gamma, epsilon=epsilon)
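When the primitive is disabled, the wrapper falls back to a native computation that returns the same `(output, rsigma)` pair. A self-contained sketch of that convention (pure JAX, illustrative only; `rmsnorm_fwd_ref` is not a TE name):

    import jax
    import jax.numpy as jnp

    def rmsnorm_fwd_ref(x, gamma, epsilon):
        # Compute in fp32 as the fallback above does, return the normalized
        # output plus rsigma with the reduced hidden dim squeezed away.
        x32 = jnp.asarray(x, jnp.float32)
        rsigma = jax.lax.rsqrt(jnp.mean(jnp.square(x32), axis=-1, keepdims=True) + epsilon)
        out = (x32 * rsigma * gamma).astype(x.dtype)
        return out, jnp.squeeze(rsigma, axis=-1)

    out, rsigma = rmsnorm_fwd_ref(jnp.ones((2, 16), jnp.bfloat16), jnp.ones((16,)), 1e-6)
    print(out.dtype, rsigma.shape)  # bfloat16 (2,)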
# --- removed ---
class RmsNormBwdPrimitive(BasePrimitive):
    """
    RMS Normalization Backward Primitive
    """

    name = "te_rmsnorm_backward"
    multiple_results = True
    impl_static_args = (4,)  # epsilon
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(dz_aval, x_aval, rsigma_aval, gamma_aval, **kwargs):
        """
        RMSNorm bwd inner primitive abstract
        """
        w_dtype = dtypes.canonicalize_dtype(gamma_aval.dtype)
        rsigma_dtype = dtypes.canonicalize_dtype(rsigma_aval.dtype)
        assert dtypes.canonicalize_dtype(dz_aval.dtype) == w_dtype
        assert dz_aval.shape == x_aval.shape
        assert rsigma_aval.shape == x_aval.shape[:-1]
        assert rsigma_dtype == jnp.float32
        dx_aval = dz_aval
        dgamma_aval = gamma_aval
        (wkspace_info,) = transformer_engine_jax.get_layernorm_bwd_workspace_sizes(
            x_aval.size // gamma_aval.size,  # batch size
            gamma_aval.size,  # hidden size
            jax_dtype_to_te_dtype(x_aval.dtype),  # in te_dtype
            jax_dtype_to_te_dtype(gamma_aval.dtype),  # weight te_dtype
            False,
            False,
            kwargs["epsilon"],
            get_backward_sm_margin(),
        )
        wkspace_aval = dx_aval.update(
            shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1])
        )
        return dx_aval, dgamma_aval, wkspace_aval

    @staticmethod
    def outer_abstract(*args, **kwargs):
        """
        RMSNorm bwd outer primitive abstract
        """
        dx_aval, dgamma_aval, _ = RmsNormBwdPrimitive.abstract(*args, **kwargs)
        return dx_aval, dgamma_aval

    @staticmethod
    def lowering(ctx, dz, x, rsigma, gamma, *, epsilon):
        """
        RMSNorm bwd lowering rules
        """
        if is_ffi_enabled():
            name = "te_rmsnorm_backward_ffi"
            sm_margin = get_backward_sm_margin()
            zero_centered_gamma = False  # RMSNorm doesn't support zero_centered_gamma
            out = ffi.ffi_lowering(name)(
                ctx,
                dz,
                x,
                rsigma,
                gamma,
                zero_centered_gamma=zero_centered_gamma,
                eps=epsilon,
                sm_margin=sm_margin,
            )
        else:
            _, x_aval, _, gamma_aval = ctx.avals_in
            x_type = ir.RankedTensorType(x.type)
            x_shape = x_type.shape
            g_type = ir.RankedTensorType(gamma.type)
            g_shape = g_type.shape
            dz_shape = ir.RankedTensorType(dz.type).shape
            rsigma_shape = ir.RankedTensorType(rsigma.type).shape
            hidden_size = reduce(operator.mul, g_shape)
            batch_size = reduce(operator.mul, x_shape) // hidden_size
            wkspace_aval = ctx.avals_out[-1]
            out_types = [
                ir.RankedTensorType.get(x_shape, x_type.element_type),
                ir.RankedTensorType.get(g_shape, g_type.element_type),
                ir.RankedTensorType.get(
                    wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype)
                ),
            ]
            operands = [dz, rsigma, x, gamma]
            operand_shapes = [dz_shape, rsigma_shape, x_shape, g_shape]
            args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
            sm_margin = get_backward_sm_margin()
            opaque = transformer_engine_jax.pack_norm_descriptor(
                batch_size,
                hidden_size,
                wkspace_aval.size,
                jax_dtype_to_te_dtype(x_aval.dtype),
                jax_dtype_to_te_dtype(gamma_aval.dtype),
                jax_dtype_to_te_dtype(wkspace_aval.dtype),
                False,  # RMSNorm doesn't support zero_centered_gamma
                epsilon,
                sm_margin,
            )
            out = custom_caller(RmsNormBwdPrimitive.name, args, opaque, False)
        return out

    @staticmethod
    def impl(dz, x, rsigma, gamma, epsilon):
        assert RmsNormBwdPrimitive.inner_primitive is not None
        dx, dgamma, _ = RmsNormBwdPrimitive.inner_primitive.bind(
            dz, x, rsigma, gamma, epsilon=epsilon
        )
        return dx, dgamma

    @staticmethod
    def batcher(batched_args, batch_dims, *, epsilon):
        check_valid_batch_dims(batch_dims)
        assert RmsNormBwdPrimitive.outer_primitive is not None
        dz, x, rsigma, gamma = batched_args
        _, x_bdim, _, gamma_bdim = batch_dims
        out_bdims = x_bdim, gamma_bdim
        return (
            RmsNormBwdPrimitive.outer_primitive.bind(dz, x, rsigma, gamma, epsilon=epsilon),
            out_bdims,
        )

    @staticmethod
    def infer_sharding_from_operands(epsilon, mesh, arg_infos, result_infos):
        del epsilon, result_infos
        x_spec = get_padded_spec(arg_infos[1])
        if x_spec[-1] is not None:
            warnings.warn(
                f"Does not support to shard hidden dim in {RmsNormBwdPrimitive.name}! "
                "Force to not shard the hidden dim, which might introduce extra collective ops, "
                "and hurt performance."
            )
        g_spec = get_padded_spec(arg_infos[3])
        if g_spec[-1] is not None:
            warnings.warn(
                f"{RmsNormBwdPrimitive.name} does not support sharding of parameter gamma "
                "Enforcing no sharding of parameters hidden dim! "
            )
        dx_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None))
        dgamma_sharding = NamedSharding(mesh, PartitionSpec(None))
        return dx_sharding, dgamma_sharding

    @staticmethod
    def partition(epsilon, mesh, arg_infos, result_infos):
        del result_infos
        x_spec = get_padded_spec(arg_infos[1])
        if x_spec[-1] is not None:
            warnings.warn(
                f"Does not support to shard hidden dim in {RmsNormBwdPrimitive.name}! "
                "Force to not shard the hidden dim, which might introduce extra collective ops, "
                "and hurt performance."
            )
        g_spec = get_padded_spec(arg_infos[3])
        if g_spec[-1] is not None:
            warnings.warn(
                f"{RmsNormBwdPrimitive.name} does not support sharding of parameter gamma "
                "Enforcing no sharding of parameters hidden dim! "
            )
        dx_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None))
        dgamma_sharding = NamedSharding(mesh, PartitionSpec(None))
        out_shardings = dx_sharding, dgamma_sharding
        x_shardings = (dx_sharding,) * 2  # dz and x should have the same sharding.
        rsigma_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1]))
        arg_shardings = (*x_shardings, rsigma_sharding, NamedSharding(mesh, PartitionSpec(None)))

        def sharded_impl(dz, x, rsigma, gamma):
            local_dx, local_dgamma = RmsNormBwdPrimitive.impl(
                dz, x, rsigma, gamma, epsilon=epsilon
            )
            global_dgamma = all_reduce_sum_along_dp_fsdp(local_dgamma, mesh)
            return local_dx, global_dgamma

        return mesh, sharded_impl, out_shardings, arg_shardings


register_primitive(RmsNormBwdPrimitive)
# --- removed ---
def rmsnorm_bwd(
    dz: jnp.ndarray, x: jnp.ndarray, rsigma: jnp.ndarray, gamma: jnp.ndarray, epsilon: float
):
    """
    Wrapper for TE layernorm bwd
    """
    if not RmsNormBwdPrimitive.enabled():
        _, vjp_func = jax.vjp(
            partial(_jax_rmsnorm, zero_centered_gamma=False, eps=epsilon), x, gamma
        )
        return vjp_func(dz)
    return RmsNormBwdPrimitive.outer_primitive.bind(dz, x, rsigma, gamma, epsilon=epsilon)
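A forward/backward pair like `rmsnorm_fwd`/`rmsnorm_bwd` is normally stitched together with `jax.custom_vjp`. A minimal, self-contained sketch with pure-JAX stand-ins (the `my_rmsnorm*` names are hypothetical; TE's actual wiring lives elsewhere in the package):

    import jax
    import jax.numpy as jnp

    @jax.custom_vjp
    def my_rmsnorm(x, gamma):
        rsigma = jax.lax.rsqrt(jnp.mean(jnp.square(x), axis=-1, keepdims=True) + 1e-6)
        return x * rsigma * gamma

    def my_rmsnorm_fwd(x, gamma):
        rsigma = jax.lax.rsqrt(jnp.mean(jnp.square(x), axis=-1, keepdims=True) + 1e-6)
        return x * rsigma * gamma, (x, gamma, rsigma)  # rsigma saved as residual

    def my_rmsnorm_bwd(res, dz):
        x, gamma, rsigma = res
        normed = x * rsigma
        dgamma = jnp.sum(dz * normed, axis=tuple(range(x.ndim - 1)))
        dnormed = dz * gamma
        # d/dx of x * rsqrt(mean(x^2) + eps), derived by hand
        dx = rsigma * (dnormed - normed * jnp.mean(dnormed * normed, axis=-1, keepdims=True))
        return dx, dgamma

    my_rmsnorm.defvjp(my_rmsnorm_fwd, my_rmsnorm_bwd)

    # Cross-check the hand-written VJP against autodiff of the plain function:
    x = jax.random.normal(jax.random.PRNGKey(1), (3, 8))
    g = jnp.ones((8,))
    got = jax.grad(lambda x: my_rmsnorm(x, g).sum())(x)
    want = jax.grad(
        lambda x: (x * jax.lax.rsqrt(jnp.mean(jnp.square(x), -1, keepdims=True) + 1e-6) * g).sum()
    )(x)
    assert jnp.allclose(got, want, atol=1e-5)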
# --- removed ---
class LayerNormFwdFp8Primitive(BasePrimitive):
    """
    Layer Normalization Forward FP8 Primitive
    """

    name = "te_layernorm_forward_fp8"
    multiple_results = True
    impl_static_args = (6, 7, 8)  # out_type, zero_centered_gamma, epsilon
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(
        x_aval,
        gamma_aval,
        beta_aval,
        amax_aval,
        scale_aval,
        scale_inv_aval,
        *,
        out_dtype,
        zero_centered_gamma,
        epsilon,
    ):
        """
        LayerNorm fwd (fp8 out) inner primitive abstract
        """
        x_dtype = dtypes.canonicalize_dtype(x_aval.dtype)
        assert x_dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert amax_aval.dtype == jnp.float32
        assert scale_aval.dtype == jnp.float32
        assert scale_inv_aval.dtype == jnp.float32
        mu_rsigama_dtype = jnp.float32
        assert gamma_aval.size == beta_aval.size
        (wkspace_info,) = transformer_engine_jax.get_layernorm_fwd_workspace_sizes(
            x_aval.size // gamma_aval.size,  # batch size
            gamma_aval.size,  # hidden size
            jax_dtype_to_te_dtype(x_aval.dtype),  # in type
            jax_dtype_to_te_dtype(gamma_aval.dtype),  # weight type
            jax_dtype_to_te_dtype(out_dtype),
            True,
            zero_centered_gamma,
            epsilon,
            get_forward_sm_margin(),
        )
        out_aval = x_aval.update(shape=x_aval.shape, dtype=out_dtype)
        mu_aval = rsigma_aval = out_aval.update(
            shape=out_aval.shape[:-1], dtype=mu_rsigama_dtype
        )
        updated_amax_aval = amax_aval.update(shape=amax_aval.shape, dtype=amax_aval.dtype)
        wkspace_aval = x_aval.update(
            shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1])
        )
        return out_aval, mu_aval, rsigma_aval, updated_amax_aval, wkspace_aval

    @staticmethod
    def outer_abstract(*args, **kwargs):
        """
        LayerNorm fwd (fp8 out) outer primitive abstract
        """
        out_aval, mu_aval, rsigma_aval, updated_amax_aval, _ = LayerNormFwdFp8Primitive.abstract(
            *args, **kwargs
        )
        return out_aval, mu_aval, rsigma_aval, updated_amax_aval

    @staticmethod
    def lowering(
        ctx, x, gamma, beta, amax, scale, scale_inv, *, out_dtype, zero_centered_gamma, epsilon
    ):
        """
        LayerNorm fwd (fp8 out) lowering rules
        """
        x_aval, gamma_aval, beta_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in
        # Currently only support casting to E4M3 only in C side.
        assert out_dtype == jnp.float8_e4m3fn
        assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert gamma_aval.dtype == beta_aval.dtype
        assert amax_aval.dtype == jnp.float32
        assert scale_aval.dtype == jnp.float32
        assert scale_inv_aval.dtype == jnp.float32
        x_type = ir.RankedTensorType(x.type)
        x_shape = x_type.shape
        g_type = ir.RankedTensorType(gamma.type)
        g_shape = g_type.shape
        b_type = ir.RankedTensorType(beta.type)
        b_shape = b_type.shape
        assert g_type == b_type
        assert g_shape == b_shape
        if is_ffi_enabled():
            name = "te_layernorm_forward_fp8_ffi"
            sm_margin = get_forward_sm_margin()
            out = ffi.ffi_lowering(name, operand_output_aliases={3: 3})(
                ctx,
                x,
                gamma,
                beta,
                amax,
                scale,
                scale_inv,
                zero_centered_gamma=zero_centered_gamma,
                eps=epsilon,
                sm_margin=sm_margin,
            )
        else:
            ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype)
            ir_mu_dtype = ir.F32Type.get()
            ir_rsigma_dtype = ir.F32Type.get()
            ir_amax_type = ir.RankedTensorType(amax.type)
            ir_amax_dtype = ir_amax_type.element_type
            ir_amax_shape = ir_amax_type.shape
            ir_scale_shape = ir_amax_shape
            ir_scale_inv_shape = ir_amax_shape
            out_shape = x_shape
            hidden_size = reduce(operator.mul, g_shape)
            batch_shape = out_shape[:-1]
            batch_size = reduce(operator.mul, x_shape) // hidden_size
            wkspace_aval = ctx.avals_out[-1]
            out_types = [
                ir.RankedTensorType.get(out_shape, ir_out_dtype),
                ir.RankedTensorType.get(batch_shape, ir_mu_dtype),
                ir.RankedTensorType.get(batch_shape, ir_rsigma_dtype),
                ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype),
                ir.RankedTensorType.get(
                    wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype)
                ),
            ]
            operands = [x, gamma, beta, amax, scale, scale_inv]
            operand_shapes = [
                x_shape,
                g_shape,
                b_shape,
                ir_amax_shape,
                ir_scale_shape,
                ir_scale_inv_shape,
            ]
            args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
            sm_margin = get_forward_sm_margin()
            opaque = transformer_engine_jax.pack_norm_descriptor(
                batch_size,
                hidden_size,
                wkspace_aval.size,
                jax_dtype_to_te_dtype(x_aval.dtype),
                jax_dtype_to_te_dtype(gamma_aval.dtype),
                jax_dtype_to_te_dtype(wkspace_aval.dtype),
                zero_centered_gamma,
                epsilon,
                sm_margin,
            )
            out = custom_caller(
                LayerNormFwdFp8Primitive.name, args, opaque, False, operand_output_aliases={3: 3}
            )
        return out


# +++ added +++
def _jax_rmsnorm(x, gamma, zero_centered_gamma, epsilon, quantizer=None):
    """
    JAX native rmsnorm implementation
    """
    x_ = jnp.asarray(x, jnp.float32)
    var = jnp.mean(jnp.square(x_), axis=-1, keepdims=True)
    rsigma = jax.lax.rsqrt(var + epsilon)
    normed_input = x_ * rsigma
    if zero_centered_gamma:
        gamma += 1.0
    output = normed_input * gamma
    if quantizer:
        ln_out = quantizer.quantize(output, dq_dtype=x.dtype)
    else:
        ln_out = jnp.asarray(output).astype(x.dtype)
    return ln_out, jnp.squeeze(rsigma, axis=-1)
    # --- removed ---
    @staticmethod
    def impl(x, gamma, beta, amax, scale, scale_inv, out_dtype, zero_centered_gamma, epsilon):
        """
        to describe implementation
        """
        assert LayerNormFwdFp8Primitive.inner_primitive is not None
        out, mu, rsigma, updated_amax, _ = LayerNormFwdFp8Primitive.inner_primitive.bind(
            x,
            gamma,
            beta,
            amax,
            scale,
            scale_inv,
            out_dtype=out_dtype,
            zero_centered_gamma=zero_centered_gamma,
            epsilon=epsilon,
        )
        return out, mu, rsigma, updated_amax

    @staticmethod
    def batcher(batched_args, batch_dims, *, out_dtype, zero_centered_gamma, epsilon):
        """
        to describe batch rules for vmap
        """
        check_valid_batch_dims(batch_dims)
        assert LayerNormFwdFp8Primitive.outer_primitive is not None
        x, gamma, beta, amax, scale, scale_inv = batched_args
        x_bdim, _, _, amax_bdim, _, _ = batch_dims
        out_bdims = x_bdim, x_bdim, x_bdim, amax_bdim
        return (
            LayerNormFwdFp8Primitive.outer_primitive.bind(
                x,
                gamma,
                beta,
                amax,
                scale,
                scale_inv,
                out_dtype=out_dtype,
                zero_centered_gamma=zero_centered_gamma,
                epsilon=epsilon,
            ),
            out_bdims,
        )

    @staticmethod
    def infer_sharding_from_operands(
        out_dtype, zero_centered_gamma, epsilon, mesh, arg_infos, result_infos
    ):
        del out_dtype, zero_centered_gamma, epsilon, result_infos
        x_spec = get_padded_spec(arg_infos[0])
        if x_spec[-1] is not None:
            warnings.warn(
                f"Does not support to shard hidden dim in {LayerNormFwdPrimitive.name}! "
                "Force to not shard the hidden dim, which might introduce extra collective ops, "
                "and hurt performance."
            )
        out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None))
        mu_sharding = rsigma_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1]))
        amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[3])))
        return (out_sharding, mu_sharding, rsigma_sharding, amax_sharding)


# +++ added +++
def layernorm_fwd(
    x: jnp.ndarray,
    gamma: jnp.ndarray,
    beta: jnp.ndarray,
    zero_centered_gamma: bool,
    epsilon: float,
    quantizer: Optional[Quantizer],
) -> tuple[Union[jnp.ndarray, ScaledTensor], jnp.ndarray, jnp.ndarray]:
    """Layer normalization forward pass with optional quantization.

    Args:
        x: Input tensor to be normalized.
            Shape: (..., K) where K is the hidden size.
        gamma: Scale parameter for normalization.
            Shape: (K,)
        beta: Bias parameter for normalization.
            Shape: (K,)
        zero_centered_gamma: If True, gamma is zero-centered.
        epsilon: Small constant for numerical stability.
        quantizer: Optional quantizer for FP8 quantization of the output.

    Returns:
        A tuple containing:
        - If quantizer is None:
            The normalized input tensor. Shape: (..., K)
          If quantizer is provided:
            A ScaledTensor containing the quantized normalized input.
        - Mean of the input tensor. Shape: (..., 1)
        - Reciprocal of the standard deviation of the input tensor. Shape: (..., 1)
    """
    if not NormFwdPrimitive.enabled():
        return _jax_layernorm(x, gamma, beta, zero_centered_gamma, epsilon, quantizer)
    # TE/common does not support normalization with colwise only quantization yet
    if quantizer is not None and quantizer.q_axis == QuantizeAxis.COLWISE:
        return _jax_layernorm(x, gamma, beta, zero_centered_gamma, epsilon, quantizer)
    scale = (
        quantizer.scale
        if isinstance(quantizer, DelayedScaleQuantizer)
        else jnp.ones((1,), dtype=jnp.float32)
    )
    if quantizer is None:
        output, _, _, _, _, mu, rsigma = NormFwdPrimitive.outer_primitive.bind(
            x,
            scale,
            gamma,
            beta,
            norm_type=NVTE_Norm_Type.LayerNorm,
            zero_centered_gamma=zero_centered_gamma,
            epsilon=epsilon,
            out_dtype=x.dtype,
            scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING,
            is_2x=False,
            scale_dtype=jnp.float32,
            scale_shapes=((1,), (1,)),
            is_outer=True,
        )
        return output, mu, rsigma
    is_2x2x = quantizer.is_2x2x()
    # TE/common normalization doesn't support 2x delayed scaling
    if quantizer.is_2x2x() and quantizer.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING:
        is_2x2x = False
    (
        rowwise_casted_output,
        colwise_casted_output,
        rowwise_scale_inv,
        colwise_scale_inv,
        updated_amax,
        mu,
        rsigma,
    ) = NormFwdPrimitive.outer_primitive.bind(
        x,
        scale,
        gamma,
        beta,
        norm_type=NVTE_Norm_Type.LayerNorm,
        zero_centered_gamma=zero_centered_gamma,
        epsilon=epsilon,
        out_dtype=quantizer.q_dtype,
        scaling_mode=quantizer.scaling_mode,
        is_2x=is_2x2x,
        scale_dtype=quantizer.get_scale_dtype(),
        scale_shapes=quantizer.get_scale_shapes(x.shape),
        is_outer=True,
    )
    quantizer.update(updated_amax)
    # TE/common Norm doesn't support 2x delayed scaling so do 1x then JAX transpose
    if quantizer.is_2x2x() and quantizer.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING:
        colwise_casted_output = jnp.transpose(
            rowwise_casted_output, (-1, *range(rowwise_casted_output.ndim - 1))
        )
        colwise_scale_inv = rowwise_scale_inv
    # cuDNN MXFP8 Norm does not support padding but we enforced padded scale inputs for nvte APIs.
    # So here we need to slice out the zero tail and reshape it to the unpadded scale shape.
    # The ScaledTensorFactory takes care of padding when creating the ScaledTensor
    if quantizer.scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING:
        rowwise_unpadded_shape, colwise_unpadded_shape = quantizer.get_scale_shapes(
            x.shape, is_padded=False
        )
        rowwise_scale_inv = rowwise_scale_inv.flatten()[
            : reduce(operator.mul, rowwise_unpadded_shape)
        ].reshape(rowwise_unpadded_shape)
        colwise_scale_inv = colwise_scale_inv.flatten()[
            : reduce(operator.mul, colwise_unpadded_shape)
        ].reshape(colwise_unpadded_shape)
    scaled_tensor = ScaledTensorFactory.create(
        data=rowwise_casted_output,
        scale_inv=rowwise_scale_inv,
        colwise_data=colwise_casted_output,
        colwise_scale_inv=colwise_scale_inv,
        scaling_mode=quantizer.scaling_mode,
        dq_dtype=x.dtype,
        q_axis=quantizer.q_axis,
        layout=quantizer.get_layout(),
    )
    return scaled_tensor, mu, rsigma
    # --- removed ---
    @staticmethod
    def partition(out_dtype, zero_centered_gamma, epsilon, mesh, arg_infos, result_infos):
        del result_infos
        x_spec = get_padded_spec(arg_infos[0])
        g_spec = get_padded_spec(arg_infos[1])
        b_spec = get_padded_spec(arg_infos[2])
        if x_spec[-1] is not None:
            warnings.warn(
                f"Does not support to shard hidden dim in {LayerNormFwdFp8Primitive.name}! "
                "Force to not shard the hidden dim, which might introduce extra collective ops, "
                "and hurt performance."
            )
        if g_spec[-1] is not None:
            warnings.warn(
                f"{LayerNormFwdFp8Primitive.name} does not support sharding of parameter gamma "
                "Enforcing no sharding of parameters hidden dim! "
            )
        if b_spec[-1] is not None:
            warnings.warn(
                f"{LayerNormFwdFp8Primitive.name} does not support sharding of parameter beta "
                "Enforcing no sharding of parameters hidden dim! "
            )
        x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None))
        g_sharding = NamedSharding(mesh, PartitionSpec(None))
        b_sharding = NamedSharding(mesh, PartitionSpec(None))
        out_sharding = x_sharding
        mu_sharding = rsigma_sharding = NamedSharding(
            mesh, PartitionSpec(*get_padded_spec(arg_infos[0])[:-1])
        )
        amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[3])))
        fp8_meta_sharding = amax_sharding
        arg_shardings = (x_sharding, g_sharding, b_sharding) + (fp8_meta_sharding,) * 3
        out_shardings = (out_sharding, mu_sharding, rsigma_sharding, amax_sharding)

        def sharded_impl(x, gamma, beta, amax, scale, scale_inv):
            local_x, local_mu, local_rsigma, local_amax = LayerNormFwdFp8Primitive.impl(
                x,
                gamma,
                beta,
                amax,
                scale,
                scale_inv,
                out_dtype=out_dtype,
                zero_centered_gamma=zero_centered_gamma,
                epsilon=epsilon,
            )
            global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh)
            return local_x, local_mu, local_rsigma, global_updated_amax

        return mesh, sharded_impl, out_shardings, arg_shardings


# +++ added +++
def layernorm_bwd(
    dz: jnp.ndarray,
    x: jnp.ndarray,
    mu: jnp.ndarray,
    rsigma: jnp.ndarray,
    gamma: jnp.ndarray,
    beta: jnp.ndarray,
    zero_centered_gamma: bool,
    epsilon: float,
):
    """Layer normalization backward pass.

    Args:
        dz: Gradient of the output with respect to the normalized output.
            Shape: (..., K) where K is the hidden size.
        x: Input tensor that was normalized in the forward pass.
            Shape: (..., K)
        mu: Mean of the input tensor from the forward pass.
            Shape: (..., 1)
        rsigma: Reciprocal of the standard deviation from the forward pass.
            Shape: (..., 1)
        gamma: Scale parameter for normalization.
            Shape: (K,)
        beta: Bias parameter for normalization.
            Shape: (K,)
        zero_centered_gamma: If True, gamma is zero-centered.
        epsilon: Small constant for numerical stability.

    Returns:
        A tuple containing:
        - Gradient of the input tensor.
            Shape: (..., K)
        - Gradient of the scale parameter (gamma).
            Shape: (K,)
        - Gradient of the bias parameter (beta).
            Shape: (K,)
    """
    if not NormBwdPrimitive.enabled():
        _, vjp_func = jax.vjp(
            partial(_jax_layernorm, zero_centered_gamma=zero_centered_gamma, epsilon=epsilon),
            x,
            gamma,
            beta,
        )
        mu_empty = jnp.zeros(mu.shape, mu.dtype)
        rsigma_empty = jnp.zeros(rsigma.shape, rsigma.dtype)
        return vjp_func((dz, mu_empty, rsigma_empty))
    return NormBwdPrimitive.outer_primitive.bind(
        dz,
        x,
        mu,
        rsigma,
        gamma,
        norm_type=NVTE_Norm_Type.LayerNorm,
        zero_centered_gamma=zero_centered_gamma,
        epsilon=epsilon,
    )
# --- removed ---
register_primitive(LayerNormFwdFp8Primitive)
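Under the new API, a caller keeps `mu`/`rsigma` from the forward pass and feeds them back to the backward pass. A hedged usage sketch, assuming the `layernorm_fwd`/`layernorm_bwd` wrappers defined above are in scope (the import path is hypothetical):

    import jax.numpy as jnp
    # Hypothetical import; these wrappers are defined in this module:
    # from transformer_engine.jax.cpp_extensions.normalization import layernorm_fwd, layernorm_bwd

    x = jnp.ones((4, 128), jnp.bfloat16)
    gamma = jnp.ones((128,), jnp.bfloat16)
    beta = jnp.zeros((128,), jnp.bfloat16)

    # Forward pass keeps mu/rsigma as residuals (quantizer=None -> plain output).
    out, mu, rsigma = layernorm_fwd(
        x, gamma, beta, zero_centered_gamma=False, epsilon=1e-6, quantizer=None
    )
    # Backward pass consumes the residuals and returns (dx, dgamma, dbeta).
    dx, dgamma, dbeta = layernorm_bwd(
        jnp.ones_like(out), x, mu, rsigma, gamma, beta,
        zero_centered_gamma=False, epsilon=1e-6,
    )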
# --- removed ---
def layernorm_fwd_fp8(
    x: jnp.ndarray,
    gamma: jnp.ndarray,
    beta: jnp.ndarray,
    amax: jnp.ndarray,
    scale: jnp.ndarray,
    scale_inv: jnp.ndarray,
    out_dtype: jnp.dtype,
    zero_centered_gamma: bool,
    epsilon: float,
):
    """
    Wrapper for TE layernorm fwd (fp8 out)
    """
    if not LayerNormFwdFp8Primitive.enabled():
        return _jax_layernorm_fp8(
            x, gamma, beta, scale, amax,
            out_dtype=out_dtype, zero_centered_gamma=zero_centered_gamma, eps=epsilon,
        )
    return LayerNormFwdFp8Primitive.outer_primitive.bind(
        x, gamma, beta, amax, scale, scale_inv,
        out_dtype=out_dtype, zero_centered_gamma=zero_centered_gamma, epsilon=epsilon,
    )


# --- removed ---
class RmsNormFwdFp8Primitive(BasePrimitive):
    """
    RMS Normalization Forward FP8 Primitive
    """

    name = "te_rmsnorm_forward_fp8"
    multiple_results = True
    impl_static_args = (5, 6)  # out_dtype, epsilon
    inner_primitive = None
    outer_primitive = None


# +++ added +++
def rmsnorm_fwd(
    x: jnp.ndarray,
    gamma: jnp.ndarray,
    zero_centered_gamma: bool,
    epsilon: float,
    quantizer: Optional[Quantizer],
) -> tuple[Union[jnp.ndarray, ScaledTensor], jnp.ndarray]:
    """Root mean square normalization forward pass with optional quantization.

    Args:
        x: Input tensor to be normalized.
            Shape: (..., K) where K is the hidden size.
        gamma: Scale parameter for normalization.
            Shape: (K,)
        zero_centered_gamma: If True, gamma is zero-centered.
        epsilon: Small constant for numerical stability.
        quantizer: Optional quantizer for FP8 quantization of the output.

    Returns:
        A tuple containing:
        - If quantizer is None:
            The normalized input tensor.
            Shape: (..., K)
          If quantizer is provided:
            A ScaledTensor containing the quantized normalized input.
        - Reciprocal of the root mean square of the input tensor.
            Shape: (..., 1)
    """
    if not NormFwdPrimitive.enabled():
        return _jax_rmsnorm(x, gamma, zero_centered_gamma, epsilon, quantizer)
    # TE/common does not support normalization with colwise only quantization yet
    if quantizer is not None and quantizer.q_axis == QuantizeAxis.COLWISE:
        return _jax_rmsnorm(x, gamma, zero_centered_gamma, epsilon, quantizer)
    scale = (
        quantizer.scale
        if isinstance(quantizer, DelayedScaleQuantizer)
        else jnp.ones((1,), dtype=jnp.float32)
    )
    beta = jnp.ones((1,), dtype=jnp.float32)
    if quantizer is None:
        output, _, _, _, _, _, rsigma = NormFwdPrimitive.outer_primitive.bind(
            x,
            scale,
            gamma,
            beta,
            norm_type=NVTE_Norm_Type.RMSNorm,
            zero_centered_gamma=zero_centered_gamma,
            epsilon=epsilon,
            out_dtype=x.dtype,
            scaling_mode=ScalingMode.NVTE_DELAYED_TENSOR_SCALING,
            is_2x=False,
            scale_dtype=jnp.float32,
            scale_shapes=((), ()),
            is_outer=True,
        )
        return output, rsigma
    is_2x2x = quantizer.is_2x2x()
    # TE/common normalization doesn't support 2x delayed scaling
    if quantizer.is_2x2x() and quantizer.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING:
        is_2x2x = False
    (
        rowwise_casted_output,
        colwise_casted_output,
        rowwise_scale_inv,
        colwise_scale_inv,
        updated_amax,
        _,
        rsigma,
    ) = NormFwdPrimitive.outer_primitive.bind(
        x,
        scale,
        gamma,
        beta,
        norm_type=NVTE_Norm_Type.RMSNorm,
        zero_centered_gamma=zero_centered_gamma,
        epsilon=epsilon,
        out_dtype=quantizer.q_dtype,
        scaling_mode=quantizer.scaling_mode,
        is_2x=is_2x2x,
        scale_dtype=quantizer.get_scale_dtype(),
        scale_shapes=quantizer.get_scale_shapes(x.shape),
        is_outer=True,
    )
    quantizer.update(updated_amax)
    # TE/common Norm doesn't support 2x delayed scaling so do 1x then JAX transpose
    if quantizer.is_2x2x() and quantizer.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING:
        colwise_casted_output = jnp.transpose(
            rowwise_casted_output, (-1, *range(rowwise_casted_output.ndim - 1))
        )
        colwise_scale_inv = rowwise_scale_inv
    # cuDNN MXFP8 Norm does not support padding but we enforced padded scale inputs for nvte APIs.
    # So here we need to slice out the zero tail and reshape it to the unpadded scale shape.
    # The ScaledTensorFactory takes care of padding when creating the ScaledTensor
    if quantizer.scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING:
        rowwise_unpadded_shape, colwise_unpadded_shape = quantizer.get_scale_shapes(
            x.shape, is_padded=False
        )
        rowwise_scale_inv = rowwise_scale_inv.flatten()[
            : reduce(operator.mul, rowwise_unpadded_shape)
        ].reshape(rowwise_unpadded_shape)
        colwise_scale_inv = colwise_scale_inv.flatten()[
            : reduce(operator.mul, colwise_unpadded_shape)
        ].reshape(colwise_unpadded_shape)
    scaled_tensor = ScaledTensorFactory.create(
        data=rowwise_casted_output,
        scale_inv=rowwise_scale_inv,
        colwise_data=colwise_casted_output,
        colwise_scale_inv=colwise_scale_inv,
        scaling_mode=quantizer.scaling_mode,
        dq_dtype=x.dtype,
        q_axis=quantizer.q_axis,
        layout=quantizer.get_layout(),
    )
    return scaled_tensor, rsigma
    # --- removed ---
    @staticmethod
    def abstract(x_aval, gamma_aval, amax_aval, scale_aval, scale_inv_aval, out_dtype, epsilon):
        """
        RMSNorm fwd (fp8 out) inner primitive abstract
        """
        x_dtype = dtypes.canonicalize_dtype(x_aval.dtype)
        assert x_dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert amax_aval.dtype == jnp.float32
        assert scale_aval.dtype == jnp.float32
        assert scale_inv_aval.dtype == jnp.float32
        hidden_size = gamma_aval.size
        assert x_aval.size % hidden_size == 0
        rsigama_dtype = jnp.float32
        (wkspace_info,) = transformer_engine_jax.get_layernorm_fwd_workspace_sizes(
            x_aval.size // hidden_size,  # batch_size
            hidden_size,
            jax_dtype_to_te_dtype(x_aval.dtype),  # in te_dtype
            jax_dtype_to_te_dtype(gamma_aval.dtype),  # weight te_dtype
            jax_dtype_to_te_dtype(out_dtype),  # out te_dtype
            False,
            False,
            epsilon,
            get_forward_sm_margin(),
        )
        out_aval = x_aval.update(shape=x_aval.shape, dtype=out_dtype)
        rsigma_aval = out_aval.update(shape=out_aval.shape[:-1], dtype=rsigama_dtype)
        amax_aval = out_aval.update(shape=amax_aval.shape, dtype=amax_aval.dtype)
        wkspace_aval = x_aval.update(
            shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1])
        )
        return out_aval, rsigma_aval, amax_aval, wkspace_aval

    @staticmethod
    def outer_abstract(*args, **kwargs):
        """
        RMSNorm fwd (fp8 out) outer primitive abstract
        """
        out_aval, rsigma_aval, amax_aval, _ = RmsNormFwdFp8Primitive.abstract(*args, **kwargs)
        return out_aval, rsigma_aval, amax_aval

    @staticmethod
    def lowering(ctx, x, gamma, amax, scale, scale_inv, *, out_dtype, epsilon):
        """
        RMSNorm fwd (fp8 out) lowering rules
        """
        # Currently only support casting to E4M3 only in C side.
        assert out_dtype == jnp.float8_e4m3fn
        if is_ffi_enabled():
            name = "te_rmsnorm_forward_fp8_ffi"
            sm_margin = get_forward_sm_margin()
            zero_centered_gamma = False  # RMSNorm doesn't support zero_centered_gamma
            out = ffi.ffi_lowering(name, operand_output_aliases={2: 2})(
                ctx,
                x,
                gamma,
                amax,
                scale,
                scale_inv,
                zero_centered_gamma=zero_centered_gamma,
                eps=epsilon,
                sm_margin=sm_margin,
            )
        else:
            x_aval, gamma_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in
            assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
            assert amax_aval.dtype == jnp.float32
            assert scale_aval.dtype == jnp.float32
            assert scale_inv_aval.dtype == jnp.float32
            x_type = ir.RankedTensorType(x.type)
            x_shape = x_type.shape
            g_type = ir.RankedTensorType(gamma.type)
            g_shape = g_type.shape
            ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype)
            ir_rsigma_dtype = ir.F32Type.get()
            ir_amax_type = ir.RankedTensorType(amax.type)
            ir_amax_dtype = ir_amax_type.element_type
            ir_amax_shape = ir_amax_type.shape
            ir_scale_shape = ir_amax_shape
            ir_scale_inv_shape = ir_amax_shape
            out_shape = x_shape
            hidden_size = reduce(operator.mul, g_shape)
            batch_shape = out_shape[:-1]
            batch_size = reduce(operator.mul, x_shape) // hidden_size
            wkspace_aval = ctx.avals_out[-1]
            out_types = [
                ir.RankedTensorType.get(out_shape, ir_out_dtype),
                ir.RankedTensorType.get(batch_shape, ir_rsigma_dtype),
                ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype),
                ir.RankedTensorType.get(
                    wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype)
                ),
            ]
            operands = [x, gamma, amax, scale, scale_inv]
            operand_shapes = [x_shape, g_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape]
            args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
            sm_margin = get_forward_sm_margin()
            opaque = transformer_engine_jax.pack_norm_descriptor(
                batch_size,
                hidden_size,
                wkspace_aval.size,
                jax_dtype_to_te_dtype(x_aval.dtype),
                jax_dtype_to_te_dtype(gamma_aval.dtype),
                jax_dtype_to_te_dtype(wkspace_aval.dtype),
                False,  # RMSNorm doesn't support zero_centered_gamma
                epsilon,
                sm_margin,
            )
            out = custom_caller(
                RmsNormFwdFp8Primitive.name, args, opaque, False, operand_output_aliases={2: 2}
            )
        return out


# +++ added +++
def rmsnorm_bwd(
    dz: jnp.ndarray,
    x: jnp.ndarray,
    rsigma: jnp.ndarray,
    gamma: jnp.ndarray,
    zero_centered_gamma: bool,
    epsilon: float,
):
    """Root mean square normalization backward pass.

    Args:
        dz: Gradient of the output with respect to the normalized output.
            Shape: (..., K) where K is the hidden size.
        x: Input tensor that was normalized in the forward pass.
            Shape: (..., K)
        rsigma: Reciprocal of the root mean square from the forward pass.
            Shape: (..., 1)
        gamma: Scale parameter for normalization.
            Shape: (K,)
        zero_centered_gamma: If True, gamma is zero-centered.
        epsilon: Small constant for numerical stability.

    Returns:
        A tuple containing:
        - Gradient of the input tensor.
            Shape: (..., K)
        - Gradient of the scale parameter (gamma).
            Shape: (K,)
    """
    if not NormBwdPrimitive.enabled():
        _, vjp_func = jax.vjp(
            partial(_jax_rmsnorm, zero_centered_gamma=zero_centered_gamma, epsilon=epsilon),
            x,
            gamma,
        )
        rsigma_empty = jnp.zeros(rsigma.shape, rsigma.dtype)
        return vjp_func((dz, rsigma_empty))
    mu = jnp.empty(())
    dx, dgamma, _ = NormBwdPrimitive.outer_primitive.bind(
        dz,
        x,
        mu,
        rsigma,
        gamma,
        norm_type=NVTE_Norm_Type.RMSNorm,
        zero_centered_gamma=zero_centered_gamma,
    )
    return (dx, dgamma)
    # --- removed ---
    @staticmethod
    def impl(x, gamma, amax, scale, scale_inv, out_dtype, epsilon):
        """
        to describe implementation
        """
        assert RmsNormFwdFp8Primitive.inner_primitive is not None
        out, rsigma, amax, _ = RmsNormFwdFp8Primitive.inner_primitive.bind(
            x, gamma, amax, scale, scale_inv, out_dtype=out_dtype, epsilon=epsilon
        )
        return out, rsigma, amax

    @staticmethod
    def batcher(batched_args, batch_dims, *, out_dtype, epsilon):
        """
        to describe batch rules for vmap
        """
        check_valid_batch_dims(batch_dims)
        assert RmsNormFwdFp8Primitive.outer_primitive is not None
        x, gamma, amax, scale, scale_inv = batched_args
        x_bdim, _, amax_bdim, _, _ = batch_dims
        out_bdims = x_bdim, x_bdim, amax_bdim
        return (
            RmsNormFwdFp8Primitive.outer_primitive.bind(
                x, gamma, amax, scale, scale_inv, out_dtype=out_dtype, epsilon=epsilon
            ),
            out_bdims,
        )

    @staticmethod
    def infer_sharding_from_operands(out_dtype, epsilon, mesh, arg_infos, result_infos):
        del out_dtype, epsilon, result_infos
        x_spec = get_padded_spec(arg_infos[0])
        if x_spec[-1] is not None:
            warnings.warn(
                f"Does not support to shard hidden dim in {RmsNormFwdFp8Primitive.name}! "
                "Force to not shard the hidden dim, which might introduce extra collective ops, "
                "and hurt performance."
            )
        out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None))
        rsigma_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1]))
        amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[2])))
        return (out_sharding, rsigma_sharding, amax_sharding)

    @staticmethod
    def partition(out_dtype, epsilon, mesh, arg_infos, result_infos):
        del result_infos
        x_spec = get_padded_spec(arg_infos[0])
        g_spec = get_padded_spec(arg_infos[1])
        if x_spec[-1] is not None:
            warnings.warn(
                f"Does not support to shard hidden dim in {RmsNormFwdFp8Primitive.name}! "
                "Force to not shard the hidden dim, which might introduce extra collective ops, "
                "and hurt performance."
            )
        if g_spec[-1] is not None:
            warnings.warn(
                f"{RmsNormFwdFp8Primitive.name} does not support sharding of parameter gamma "
                "Enforcing no sharding of parameters hidden dim! "
            )
        x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None))
        g_sharding = NamedSharding(mesh, PartitionSpec(None))
        out_sharding = x_sharding
        rsigma_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[0])[:-1]))
        amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[2])))
        fp8_meta_sharding = amax_sharding
        arg_shardings = (x_sharding, g_sharding) + (fp8_meta_sharding,) * 3
        out_shardings = (out_sharding, rsigma_sharding, amax_sharding)

        def sharded_impl(x, gamma, amax, scale, scale_inv):
            local_x, local_rsigma, local_amax = RmsNormFwdFp8Primitive.impl(
                x, gamma, amax, scale, scale_inv, out_dtype=out_dtype, epsilon=epsilon
            )
            global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh)
            return local_x, local_rsigma, global_updated_amax

        return mesh, sharded_impl, out_shardings, arg_shardings
# +++ added +++
def normalization_fwd(
    x: jnp.ndarray,
    gamma: jnp.ndarray,
    beta: jnp.ndarray,
    zero_centered_gamma: bool,
    epsilon: float,
    norm_type: str,
    quantizer: Optional[Quantizer],
):
    """Common wrapper for normalization forward pass.

    Args:
        x: Input tensor to be normalized.
            Shape: (..., K) where K is the hidden size.
        gamma: Scale parameter for normalization.
            Shape: (K,)
        beta: Bias parameter for normalization.
            Shape: (K,)
        zero_centered_gamma: If True, gamma is zero-centered.
        epsilon: Small constant for numerical stability.
        norm_type: Type of normalization to apply. Must be one of:
            - 'layernorm': Layer normalization
            - 'rmsnorm': Root mean square normalization
        quantizer: Optional quantizer for FP8 quantization of the output.

    Returns:
        A tuple containing:
        - If quantizer is None:
            The normalized input tensor.
            Shape: (..., K)
          If quantizer is provided:
            A ScaledTensor containing the quantized normalized input.
        - Mean of the input tensor (None for RMSNorm).
            Shape: (..., 1)
        - Reciprocal of the standard deviation (or root mean square for RMSNorm).
            Shape: (..., 1)

    Note:
        zero_centered_gamma is not supported if norm_type is 'rmsnorm'.
    """
    if norm_type == "layernorm":
        output, mu, rsigma = layernorm_fwd(x, gamma, beta, zero_centered_gamma, epsilon, quantizer)
    elif norm_type == "rmsnorm":
        assert (
            not zero_centered_gamma
        ), "zero_centered_gamma is not supported if norm_type is 'rmsnorm'"
        output, rsigma = rmsnorm_fwd(x, gamma, zero_centered_gamma, epsilon, quantizer)
        mu = None
    else:
        raise ValueError(f"{norm_type=} is not supported.")
    return output, mu, rsigma


# --- removed ---
register_primitive(RmsNormFwdFp8Primitive)


# --- removed ---
def rmsnorm_fwd_fp8(
    x: jnp.ndarray,
    gamma: jnp.ndarray,
    amax: jnp.ndarray,
    scale: jnp.ndarray,
    scale_inv: jnp.ndarray,
    out_dtype: jnp.dtype,
    epsilon: float,
):
    """
    Wrapper for TE rmsnorm fwd (fp8 out)
    """
    if not RmsNormFwdFp8Primitive.enabled():
        return _jax_rmsnorm_fp8(
            x, gamma, scale, amax, out_dtype=out_dtype, zero_centered_gamma=False, eps=epsilon
        )
    return RmsNormFwdFp8Primitive.outer_primitive.bind(
        x, gamma, amax, scale, scale_inv, out_dtype=out_dtype, epsilon=epsilon
    )


# +++ added +++
def normalization_bwd(
    dz: jnp.ndarray,
    x: jnp.ndarray,
    mu: jnp.ndarray,
    rsigma: jnp.ndarray,
    gamma: jnp.ndarray,
    beta: jnp.ndarray,
    zero_centered_gamma: bool,
    epsilon: float,
    norm_type: str,
):
    """Common wrapper for normalization backward pass.

    Args:
        dz: Gradient of the output with respect to the normalized output.
            Shape: (..., K) where K is the hidden size.
        x: Input tensor that was normalized in the forward pass.
            Shape: (..., K)
        mu: Mean of the input tensor from the forward pass (None for RMSNorm).
            Shape: (..., 1)
        rsigma: Reciprocal of the standard deviation (or root mean square) from the forward pass.
            Shape: (..., 1)
        gamma: Scale parameter for normalization.
            Shape: (K,)
        beta: Bias parameter for normalization.
            Shape: (K,)
        zero_centered_gamma: If True, gamma is zero-centered.
        epsilon: Small constant for numerical stability.
        norm_type: Type of normalization used in the forward pass. Must be one of:
            - 'layernorm': Layer normalization
            - 'rmsnorm': Root mean square normalization

    Returns:
        A tuple containing:
        - Gradient of the input tensor.
            Shape: (..., K)
        - Gradient of the scale parameter (gamma).
            Shape: (K,)
        - Gradient of the bias parameter (beta) (None for RMSNorm).
            Shape: (K,)

    Note:
        zero_centered_gamma is not supported if norm_type is 'rmsnorm'.
    """
    if norm_type == "layernorm":
        dx, dgamma, dbeta = layernorm_bwd(
            dz, x, mu, rsigma, gamma, beta, zero_centered_gamma, epsilon
        )
    elif norm_type == "rmsnorm":
        assert (
            not zero_centered_gamma
        ), "zero_centered_gamma is not supported if norm_type is 'rmsnorm'"
        dx, dgamma = rmsnorm_bwd(dz, x, rsigma, gamma, zero_centered_gamma, epsilon)
        dbeta = None
    else:
        raise ValueError(f"{norm_type=} is not supported.")
    return dx, dgamma, dbeta
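One way these wrappers pair up under autodiff is via `jax.custom_vjp`; a minimal sketch under the assumption that `normalization_fwd`/`normalization_bwd` are in scope as defined above (TE's real VJP wiring may differ, and the `_norm*` names are hypothetical):

    import jax
    import jax.numpy as jnp
    from functools import partial

    # norm_type / zero_centered_gamma / epsilon are treated as non-differentiable.
    @partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5))
    def _norm(x, gamma, beta, zero_centered_gamma, epsilon, norm_type):
        out, _, _ = normalization_fwd(
            x, gamma, beta, zero_centered_gamma, epsilon, norm_type, quantizer=None
        )
        return out

    def _norm_fwd(x, gamma, beta, zero_centered_gamma, epsilon, norm_type):
        out, mu, rsigma = normalization_fwd(
            x, gamma, beta, zero_centered_gamma, epsilon, norm_type, quantizer=None
        )
        return out, (x, mu, rsigma, gamma, beta)  # residuals for the backward pass

    def _norm_bwd(zero_centered_gamma, epsilon, norm_type, res, dz):
        x, mu, rsigma, gamma, beta = res
        dx, dgamma, dbeta = normalization_bwd(
            dz, x, mu, rsigma, gamma, beta, zero_centered_gamma, epsilon, norm_type
        )
        if dbeta is None:  # rmsnorm has no beta gradient
            dbeta = jnp.zeros_like(beta)
        return dx, dgamma, dbeta

    _norm.defvjp(_norm_fwd, _norm_bwd)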
transformer_engine/jax/cpp_extensions/quantization.py  View file @ a207db1d
...
@@ -2,28 +2,29 @@
 #
 # See LICENSE for license information.
 """JAX/TE custom ops for quantization"""
-from typing import Tuple
+from typing import Tuple, Optional
 from packaging import version

 import jax
 import jax.numpy as jnp
 from jax import dtypes
 from jax.interpreters.mlir import ir
-from jax.sharding import PartitionSpec, NamedSharding
+from jax.sharding import PartitionSpec

 import transformer_engine_jax
 from transformer_engine_jax import DType as TEDType

 from .base import BasePrimitive, register_primitive
-from .custom_call import custom_caller, CustomCallArgsWrapper
 from .misc import (
     get_padded_spec,
     check_valid_batch_dims,
     te_dtype_to_jax_dtype,
     jax_dtype_to_te_dtype,
     jax_dtype_to_ir_dtype,
     is_ffi_enabled,
     multidim_transpose,
     should_apply_1x_fused_dbias_war_for_arch_l_100,
     NamedSharding,
 )
-from ..sharding import all_reduce_max_along_all_axes_except_PP
+from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp
+from ..quantize import ScaledTensor2x, ScaledTensor, ScaledTensorFactory
+from ..quantize import Quantizer, QuantizeAxis, DelayedScaleQuantizer, ScalingMode

 if version.parse(jax.__version__) >= version.parse("0.5.0"):
     from jax import ffi  # pylint: disable=ungrouped-imports
...
@@ -31,166 +32,591 @@ else:
     from jax.extend import ffi  # pylint: disable=ungrouped-imports
-__all__ = ["cast_fp8"]
+__all__ = ["quantize", "quantize_dbias"]


def _jax_quantize(x, scale, q_dtype):
    """
    Quantize with scale
    """
    compute_dtype = scale.dtype
    dtype_max = (jnp.finfo(q_dtype).max).astype(compute_dtype)
    scaled_x = x.astype(compute_dtype) * scale
    clipped_scaled_x = jnp.clip(scaled_x, -dtype_max, dtype_max)
    return clipped_scaled_x.astype(q_dtype)


def _jax_cast_fp8(inputs, scale, amax, out_dtype):
    """
    JAX native fp8 casting implementation
    """
    casted_output = _jax_quantize(inputs, scale, q_dtype=out_dtype)
    updated_amax = jax.lax.max(amax, jnp.max(jnp.abs(inputs)).astype(amax.dtype))
    return casted_output, updated_amax
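A quick round-trip through `_jax_quantize` illustrates the scale/clip/cast behavior (requires a JAX build that exposes `jnp.float8_e4m3fn`; the values are illustrative):

    import jax.numpy as jnp

    x = jnp.linspace(-2.0, 2.0, 8, dtype=jnp.float32)
    scale = jnp.float32(16.0)
    # Quantize: clip(x * scale) to the fp8 range, then cast to fp8.
    q = _jax_quantize(x, scale, q_dtype=jnp.float8_e4m3fn)
    # Dequantize with the reciprocal scale and inspect the round-trip error.
    x_back = q.astype(jnp.float32) / scale
    print(jnp.max(jnp.abs(x - x_back)))  # small, bounded by the fp8 step size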
class
CastFP8Primitive
(
BasePrimitive
):
"""
Cast Primitive
"""
name
=
"te_quantize"
name
=
"te_dbias_quantize_ffi"
multiple_results
=
True
impl_static_args
=
(
4
,)
impl_static_args
=
(
2
,
3
,
4
,
5
,
6
,
7
,
8
,
)
# out_dtype, scaling_mode, q_axis, scale_dtype, scale_shapes, is_dbias, is_outer
inner_primitive
=
None
outer_primitive
=
None
@
staticmethod
def
abstract
(
x_aval
,
amax_aval
,
scale_aval
,
scale_inv_aval
,
*
,
out_dtype
):
def
abstract
(
x_aval
,
scale_aval
,
*
,
out_dtype
,
scaling_mode
,
q_axis
,
scale_dtype
,
scale_shapes
,
is_dbias
,
is_outer
,
):
"""
te_
cast
abstract
te_
dbias_quantize_p
abstract
"""
del
scale_shapes
dtype
=
dtypes
.
canonicalize_dtype
(
x_aval
.
dtype
)
assert
dtype
in
[
jnp
.
float32
,
jnp
.
float16
,
jnp
.
bfloat16
]
assert
amax_aval
.
dtype
==
jnp
.
float32
assert
scale_aval
.
dtype
==
jnp
.
float32
assert
scale_inv_aval
.
dtype
==
jnp
.
float32
assert
scale_aval
is
None
or
scale_aval
.
dtype
==
jnp
.
float32
casted_x_aval
=
x_aval
.
update
(
shape
=
x_aval
.
shape
,
dtype
=
out_dtype
)
updated_amax_aval
=
amax_aval
.
update
(
shape
=
amax_aval
.
shape
,
dtype
=
amax_aval
.
dtype
)
rowwise_out_aval
=
jax
.
core
.
ShapedArray
(
shape
=
(
1
,),
dtype
=
out_dtype
)
return
casted_x_aval
,
updated_amax_aval
if
q_axis
in
(
QuantizeAxis
.
ROWWISE
.
value
,
QuantizeAxis
.
ROWWISE_COLWISE
.
value
):
rowwise_out_aval
=
x_aval
.
update
(
shape
=
x_aval
.
shape
,
dtype
=
out_dtype
)
@
staticmethod
def
lowering
(
ctx
,
x
,
amax
,
scale
,
scale_inv
,
*
,
out_dtype
):
"""
te_cast lowering rules
"""
x_aval
,
amax_aval
,
scale_aval
,
scale_inv_aval
=
ctx
.
avals_in
assert
x_aval
.
dtype
in
[
jnp
.
float32
,
jnp
.
float16
,
jnp
.
bfloat16
]
assert
amax_aval
.
dtype
==
jnp
.
float32
assert
scale_aval
.
dtype
==
jnp
.
float32
assert
scale_inv_aval
.
dtype
==
jnp
.
float32
if
is_ffi_enabled
():
name
=
"te_quantize_ffi"
out
=
ffi
.
ffi_lowering
(
name
,
operand_output_aliases
=
{
1
:
1
})(
ctx
,
x
,
amax
,
scale
,
scale_inv
)
else
:
ir_x_type
=
ir
.
RankedTensorType
(
x
.
type
)
ir_x_shape
=
ir_x_type
.
shape
ir_out_dtype
=
jax_dtype_to_ir_dtype
(
out_dtype
)
ir_amax_type
=
ir
.
RankedTensorType
(
amax
.
type
)
ir_amax_dtype
=
ir_amax_type
.
element_type
ir_amax_shape
=
ir_amax_type
.
shape
ir_scale_shape
=
ir_amax_shape
ir_scale_inv_shape
=
ir_amax_shape
updated_amax_aval
=
jax
.
core
.
ShapedArray
(
shape
=
(
1
,),
dtype
=
jnp
.
float32
)
rowwise_scale_inv_shape
,
colwise_scale_inv_shape
=
ScalingMode
(
scaling_mode
).
get_scale_shape_2x
(
x_aval
.
shape
,
is_padded
=
not
is_outer
)
out_types
=
[
ir
.
RankedTensorType
.
get
(
ir_x_shape
,
ir_out_dtype
),
ir
.
RankedTensorType
.
get
(
ir_amax_shape
,
ir_amax_dtype
),
]
operands
=
[
x
,
amax
,
scale
,
scale_inv
]
operand_shapes
=
[
ir_x_shape
,
ir_amax_shape
,
ir_scale_shape
,
ir_scale_inv_shape
]
args
=
CustomCallArgsWrapper
(
out_types
,
operands
,
operand_shapes
)
scale_inv_aval
=
jax
.
core
.
ShapedArray
(
shape
=
rowwise_scale_inv_shape
,
dtype
=
scale_dtype
)
opaque
=
transformer_engine_jax
.
pack_common_descriptor
(
ir_x_shape
,
jax_dtype_to_te_dtype
(
x_aval
.
dtype
),
jax_dtype_to_te_dtype
(
out_dtype
)
colwise_out_aval
=
jax
.
core
.
ShapedArray
(
shape
=
(
1
,),
dtype
=
out_dtype
)
colwise_scale_inv_aval
=
jax
.
core
.
ShapedArray
(
shape
=
(
1
,),
dtype
=
scale_dtype
)
dbias_aval
=
jax
.
core
.
ShapedArray
(
shape
=
(
1
,),
dtype
=
jnp
.
float32
)
wkspace_aval
=
jax
.
core
.
ShapedArray
(
shape
=
(
1
,),
dtype
=
jnp
.
float32
)
if
q_axis
in
(
QuantizeAxis
.
COLWISE
.
value
,
QuantizeAxis
.
ROWWISE_COLWISE
.
value
):
t_shape
=
multidim_transpose
(
x_aval
.
shape
)
if
scaling_mode
==
ScalingMode
.
NVTE_MXFP8_1D_SCALING
.
value
:
# Don't transpose output for MXFP8
t_shape
=
x_aval
.
shape
colwise_out_aval
=
x_aval
.
update
(
shape
=
t_shape
,
dtype
=
out_dtype
)
colwise_scale_inv_aval
=
jax
.
core
.
ShapedArray
(
shape
=
colwise_scale_inv_shape
,
dtype
=
scale_dtype
)
out
=
custom_caller
(
CastFP8Primitive
.
name
,
args
,
opaque
,
False
,
operand_output_aliases
=
{
1
:
1
}
if
is_dbias
:
gi_hidden_size
=
x_aval
.
shape
[
-
1
]
dbias_shape
=
(
gi_hidden_size
,)
dbias_aval
=
x_aval
.
update
(
shape
=
dbias_shape
,
dtype
=
dtype
)
(
wkspace_info
,)
=
transformer_engine_jax
.
get_dbias_quantize_workspace_sizes
(
x_aval
.
size
//
gi_hidden_size
,
gi_hidden_size
,
jax_dtype_to_te_dtype
(
x_aval
.
dtype
),
jax_dtype_to_te_dtype
(
out_dtype
),
)
wkspace_aval
=
x_aval
.
update
(
shape
=
wkspace_info
[
0
],
dtype
=
te_dtype_to_jax_dtype
(
wkspace_info
[
1
])
)
return
out
return
(
rowwise_out_aval
,
colwise_out_aval
,
scale_inv_aval
,
colwise_scale_inv_aval
,
updated_amax_aval
,
dbias_aval
,
wkspace_aval
,
)
@
staticmethod
def
impl
(
x
,
amax
,
scale
,
scale_inv
,
out_dtype
):
def
outer_abstract
(
*
args
,
**
kwargs
):
"""
te_
cast implementation
te_
dbias_quantize_p outer primitive abstract
"""
assert
CastFP8Primitive
.
inner_primitive
is
not
None
casted_x
,
updated_amax
=
CastFP8Primitive
.
inner_primitive
.
bind
(
x
,
amax
,
scale
,
scale_inv
,
out_dtype
=
out_dtype
(
out
,
colwise_out
,
scale_inv
,
colwise_scale_inv
,
updated_amax
,
dbias
,
_
,
)
=
DBiasQuantizePrimitive
.
abstract
(
*
args
,
**
kwargs
)
return
out
,
colwise_out
,
scale_inv
,
colwise_scale_inv
,
updated_amax
,
dbias
@
staticmethod
def
lowering
(
ctx
,
x
,
scale
,
*
,
out_dtype
,
scaling_mode
,
q_axis
,
scale_dtype
,
scale_shapes
,
is_dbias
,
is_outer
,
):
"""
te_dbias_quantize_p lowering rules
"""
del
out_dtype
,
scale_dtype
,
scale_shapes
,
is_outer
x_aval
,
scale_aval
=
ctx
.
avals_in
assert
x_aval
.
dtype
in
[
jnp
.
float32
,
jnp
.
float16
,
jnp
.
bfloat16
]
assert
scale_aval
.
dtype
==
jnp
.
float32
return
ffi
.
ffi_lowering
(
DBiasQuantizePrimitive
.
name
)(
ctx
,
x
,
scale
,
scaling_mode
=
scaling_mode
,
q_axis
=
q_axis
,
is_dbias
=
is_dbias
,
)
        return casted_x, updated_amax

    @staticmethod
    def batcher(batched_args, batch_dims, *, out_dtype):
        check_valid_batch_dims(batch_dims)
        assert CastFP8Primitive.outer_primitive is not None
    def impl(
        x,
        scale,
        out_dtype,
        scaling_mode,
        q_axis,
        scale_dtype,
        scale_shapes,
        is_dbias,
        is_outer,
    ):
        """
        te_dbias_quantize_p implementation
        """
        del is_outer
        assert DBiasQuantizePrimitive.inner_primitive is not None
        (
            out,
            colwise_out,
            scale_inv,
            colwise_scale_inv,
            updated_amax,
            dbias,
            _,
        ) = DBiasQuantizePrimitive.inner_primitive.bind(
            x,
            scale,
            out_dtype=out_dtype,
            scaling_mode=scaling_mode,
            q_axis=q_axis,
            scale_dtype=scale_dtype,
            scale_shapes=scale_shapes,
            is_dbias=is_dbias,
            is_outer=False,
        )
        rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode(
            scaling_mode
        ).get_scale_shape_2x(x.shape, is_padded=False)
        if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value:
            if q_axis in (QuantizeAxis.ROWWISE.value, QuantizeAxis.ROWWISE_COLWISE.value):
                scale_inv = jax.lax.slice(
                    scale_inv, [0] * len(rowwise_scale_inv_shape), rowwise_scale_inv_shape
                )
            if q_axis in (QuantizeAxis.COLWISE.value, QuantizeAxis.ROWWISE_COLWISE.value):
                colwise_scale_inv = jax.lax.slice(
                    colwise_scale_inv,
                    [0] * len(colwise_scale_inv_shape),
                    colwise_scale_inv_shape,
                )
        return (
            out,
            colwise_out,
            scale_inv,
            colwise_scale_inv,
            updated_amax,
            dbias,
        )  # Exclude wkspace
        x, amax, scale, scale_inv = batched_args
        x_bdim, amax_bdim, *_ = batch_dims

    @staticmethod
    def batcher(
        batched_args,
        batch_dims,
        *,
        out_dtype,
        scaling_mode,
        q_axis,
        scale_dtype,
        scale_shapes,
        is_dbias,
        is_outer,
    ):
        """
        to describe batch rules for vmap
        """
        del is_outer
        check_valid_batch_dims(batch_dims)
        assert DBiasQuantizePrimitive.outer_primitive is not None
        x, scale = batched_args
        x_bdim, scale_bdim = batch_dims
        amax_bdim = scale_bdim
        out_bdims = x_bdim, amax_bdim
        out_bdims = x_bdim, x_bdim, scale_bdim, scale_bdim, amax_bdim, x_bdim
        return (
            CastFP8Primitive.outer_primitive.bind(x, amax, scale, scale_inv, out_dtype=out_dtype),
            DBiasQuantizePrimitive.outer_primitive.bind(
                x,
                scale,
                out_dtype=out_dtype,
                scaling_mode=scaling_mode,
                q_axis=q_axis,
                scale_dtype=scale_dtype,
                scale_shapes=scale_shapes,
                is_dbias=is_dbias,
            ),
            out_bdims,
        )
    @staticmethod
    def infer_sharding_from_operands(out_dtype, mesh, arg_infos, result_infos):
        del out_dtype, result_infos
    def infer_sharding_from_operands(
        out_dtype,
        scaling_mode,
        q_axis,
        scale_dtype,
        scale_shapes,
        is_dbias,
        is_outer,
        mesh,
        arg_infos,
        result_infos,
    ):
        del (out_dtype, result_infos, scale_dtype, scale_shapes, is_dbias, is_outer)  # Unused.
        x_spec = get_padded_spec(arg_infos[0])
        casted_x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
        amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
        return (casted_x_sharding, amax_sharding)
        out_sharding = NamedSharding(
            mesh,
            PartitionSpec(*x_spec[:-1], x_spec[-1]),
            desc="DBiasQuantizePrimitive.out_sharding",
        )
        if q_axis in (QuantizeAxis.COLWISE.value, QuantizeAxis.ROWWISE_COLWISE.value):
            if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value:
                colwise_out_spec = multidim_transpose(x_spec)
            else:
                colwise_out_spec = x_spec
        else:
            colwise_out_spec = (None,)
        colwise_out_sharding = NamedSharding(
            mesh,
            PartitionSpec(*colwise_out_spec),
            desc="DBiasQuantizePrimitive.colwise_out_sharding",
        )
        scale_inv_sharding = NamedSharding(
            mesh,
            PartitionSpec(*get_padded_spec(arg_infos[1])),
            desc="DBiasQuantizePrimitive.scale_inv",
        )
        amax_sharding = scale_inv_sharding.duplicate_with_new_description(
            desc="DBiasQuantizePrimitive.amax_sharding"
        )
        if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value:
            scale_inv_sharding = NamedSharding(
                mesh, PartitionSpec(*x_spec), desc="DBiasQuantizePrimitive.scale_inv"
            )
        colwise_scale_inv_sharding = scale_inv_sharding.duplicate_with_new_description(
            "DBiasQuantizePrimitive.colwise_scale_inv"
        )
        dbias_sharding = NamedSharding(
            mesh,
            PartitionSpec(x_spec[-1]),
            desc="DBiasQuantizePrimitive.dbias_sharding",
        )
        return (
            out_sharding,
            colwise_out_sharding,
            scale_inv_sharding,
            colwise_scale_inv_sharding,
            amax_sharding,
            dbias_sharding,
        )
    @staticmethod
    def partition(out_dtype, mesh, arg_infos, result_infos):
        del result_infos
    def partition(
        out_dtype,
        scaling_mode,
        q_axis,
        scale_dtype,
        scale_shapes,
        is_dbias,
        is_outer,
        mesh,
        arg_infos,
        result_infos,
    ):
        del result_infos, is_outer
        x_spec = get_padded_spec(arg_infos[0])
        casted_x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
        amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
        out_sharding = NamedSharding(
            mesh,
            PartitionSpec(*x_spec[:-1], x_spec[-1]),
            desc="DBiasQuantizePrimitive.out_sharding",
        )
        if q_axis in (QuantizeAxis.COLWISE.value, QuantizeAxis.ROWWISE_COLWISE.value):
            if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value:
                colwise_out_spec = multidim_transpose(x_spec)
            else:
                colwise_out_spec = x_spec
        else:
            colwise_out_spec = (None,)
        colwise_out_sharding = NamedSharding(
            mesh,
            PartitionSpec(*colwise_out_spec),
            desc="DBiasQuantizePrimitive.colwise_out_sharding",
        )
        scale_inv_sharding = NamedSharding(
            mesh,
            PartitionSpec(*get_padded_spec(arg_infos[1])),
            desc="DBiasQuantizePrimitive.scale_inv",
        )
        amax_sharding = scale_inv_sharding.duplicate_with_new_description(
            desc="DBiasQuantizePrimitive.amax_sharding"
        )
        if scaling_mode == ScalingMode.NVTE_MXFP8_1D_SCALING.value:
            scale_inv_sharding = NamedSharding(
                mesh, PartitionSpec(*x_spec), desc="DBiasQuantizePrimitive.scale_inv"
            )
        colwise_scale_inv_sharding = scale_inv_sharding.duplicate_with_new_description(
            "DBiasQuantizePrimitive.colwise_scale_inv"
        )
        dbias_sharding = NamedSharding(
            mesh,
            PartitionSpec(x_spec[-1]),
            desc="DBiasQuantizePrimitive.dbias_sharding",
        )
        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
        out_shardings = (casted_x_sharding, amax_sharding)
        out_shardings = (
            out_sharding,
            colwise_out_sharding,
            scale_inv_sharding,
            colwise_scale_inv_sharding,
            amax_sharding,
            dbias_sharding,
        )

        def sharded_impl(x, amax, scale, scale_inv):
            local_cx, local_updated_amax = CastFP8Primitive.impl(
                x, amax, scale, scale_inv, out_dtype=out_dtype
        def sharded_impl(x, scale):
            (
                local_x,
                local_colwise_x,
                local_scale_inv,
                local_colwise_scale_inv,
                local_amax,
                local_dbias,
            ) = DBiasQuantizePrimitive.impl(
                x,
                scale,
                out_dtype=out_dtype,
                scaling_mode=scaling_mode,
                q_axis=q_axis,
                scale_dtype=scale_dtype,
                scale_shapes=scale_shapes,
                is_dbias=is_dbias,
                is_outer=True,
            )
            global_updated_amax = all_reduce_max_along_all_axes_except_PP(
                local_updated_amax, mesh
            )
            return local_cx, global_updated_amax
            if scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING.value:
                global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh)
            else:
                global_updated_amax = local_amax
            if is_dbias:
                global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias, mesh)
            else:
                global_dbias = local_dbias
            return (
                local_x,
                local_colwise_x,
                local_scale_inv,
                local_colwise_scale_inv,
                global_updated_amax,
                global_dbias,
            )

        return mesh, sharded_impl, out_shardings, arg_shardings
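Note that sharded_impl reduces the two per-shard statistics differently: amax with a max and dbias with a sum. A minimal single-process sketch of why those are the right reductions, simulating data-parallel shards with jnp.split (names here are illustrative, not part of this module):

import jax.numpy as jnp

x = jnp.arange(-6.0, 6.0).reshape(4, 3)   # pretend global tensor
shards = jnp.split(x, 2, axis=0)           # two "data-parallel" shards

# amax: max of the local maxima equals the global amax
local_amax = [jnp.max(jnp.abs(s)) for s in shards]
assert jnp.max(jnp.stack(local_amax)) == jnp.max(jnp.abs(x))

# dbias: sum of the local sums over the leading axes equals the global dbias
local_dbias = [jnp.sum(s, axis=0) for s in shards]
assert jnp.allclose(sum(local_dbias), jnp.sum(x, axis=0))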
register_primitive(CastFP8Primitive)
register_primitive(DBiasQuantizePrimitive)
def cast_fp8(
def _jax_quantize(x, quantizer: Quantizer = None, dq_dtype: Optional[jnp.dtype] = None):
    if quantizer is None:
        return x
    return quantizer.quantize(x, dq_dtype=dq_dtype)


def _jax_dbias(dx: jnp.ndarray):
    dbias = jnp.sum(
        dx,
        axis=tuple(range(dx.ndim - 1)),
        keepdims=False,
    )
    dbias = dbias.ravel()  # C++ function returns an 1D array for dbias
    return dbias


def _jax_quantize_dbias(
    x,
    quantizer: Quantizer = None,
    dq_dtype: Optional[jnp.dtype] = None,
):
    if quantizer is None:
        return x, None
    return quantizer.quantize(x, dq_dtype=dq_dtype), _jax_dbias(x)


def _jax_dbias(
    dx: jnp.ndarray,
):
    dbias = jnp.sum(
        dx.astype(jnp.float32),
        axis=tuple(range(dx.ndim - 1)),
        keepdims=False,
    )
    dbias = dbias.ravel()  # C++ function returns an 1D array for dbias
    return dbias.astype(dx.dtype)
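The second _jax_dbias variant accumulates in float32 before casting back, which matters when the incoming gradient is a low-precision type. A small illustration of the difference (values are illustrative; low-precision sums typically stall once the accumulator dwarfs each addend):

import jax.numpy as jnp

dx = jnp.full((4096, 8), 0.01, dtype=jnp.bfloat16)

naive = jnp.sum(dx, axis=0)                                        # accumulates in bfloat16
accurate = jnp.sum(dx.astype(jnp.float32), axis=0).astype(dx.dtype)

# The true column sum is 40.96; the bfloat16 accumulation usually drifts
# much further from it than the float32 path does.
print(naive[0], accurate[0])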
def _quantize_impl(
    x: jnp.ndarray,
    amax: jnp.ndarray,
    scale: jnp.ndarray,
    scale_inv: jnp.ndarray,
    out_dtype: TEDType,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    quantizer: Quantizer,
    is_dbias: bool = False,
    dq_dtype: Optional[jnp.dtype] = None,
) -> Tuple[ScaledTensor2x, jnp.ndarray]:
    """
    Cast wrapper
    Return FP8 tensor
    """
    if not CastFP8Primitive.enabled():
        return _jax_cast_fp8(x, scale, amax, out_dtype=out_dtype)
    return CastFP8Primitive.outer_primitive.bind(x, amax, scale, scale_inv, out_dtype=out_dtype)
    assert (dq_dtype is None) or (
        quantizer is not None
    ), "quantizer must be provided if dq_dtype is provided"
    if not DBiasQuantizePrimitive.enabled():
        if is_dbias:
            return _jax_quantize_dbias(
                x,
                quantizer=quantizer,
                dq_dtype=dq_dtype,
            )
        return _jax_quantize(x, quantizer=quantizer, dq_dtype=dq_dtype), None

    # TE/common doesn't support colwise only quantization yet
    if quantizer is not None and quantizer.q_axis == QuantizeAxis.COLWISE:
        if is_dbias:
            return _jax_quantize_dbias(
                x,
                quantizer=quantizer,
                dq_dtype=dq_dtype,
            )
        return _jax_quantize(x, quantizer=quantizer, dq_dtype=dq_dtype), None

    scale = jnp.empty((), jnp.float32)

    # TE/common dbias_quantize does not support 1x on arch < 100
    if should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias=is_dbias, quantizer=quantizer):
        out, _ = _quantize_impl(
            x=x,
            is_dbias=False,
            quantizer=quantizer,
            dq_dtype=dq_dtype,
        )
        dbias = _jax_dbias(x)
        return out, dbias

    if quantizer is None:
        if is_dbias:
            return x, _jax_dbias(x)
        return x, None

    if isinstance(quantizer, DelayedScaleQuantizer):
        scale = quantizer.scale

    (
        rowwise_casted_output,
        colwise_casted_output,
        rowwise_scale_inv,
        colwise_scale_inv,
        updated_amax,
        dbias,
    ) = DBiasQuantizePrimitive.outer_primitive.bind(
        x,
        scale,
        out_dtype=quantizer.q_dtype,
        scaling_mode=quantizer.scaling_mode.value,
        q_axis=quantizer.q_axis.value,
        scale_dtype=quantizer.get_scale_dtype(),
        scale_shapes=quantizer.get_scale_shapes(x.shape),
        is_dbias=is_dbias,
        is_outer=True,
    )

    # For DelayedScaling2x, the scale buffer is shared between rowwise and colwise
    if (
        quantizer.scaling_mode == ScalingMode.NVTE_DELAYED_TENSOR_SCALING
        and quantizer.is_2x2x()
    ):
        colwise_scale_inv = rowwise_scale_inv

    quantizer.update(updated_amax)

    out = ScaledTensorFactory.create(
        data=rowwise_casted_output,
        scale_inv=rowwise_scale_inv,
        colwise_data=colwise_casted_output,
        colwise_scale_inv=colwise_scale_inv,
        scaling_mode=quantizer.scaling_mode,
        dq_dtype=dq_dtype if dq_dtype is not None else x.dtype,
        q_axis=quantizer.q_axis,
        layout=quantizer.get_layout(),
    )
    return out, dbias
# TODO(Phuong): do not expose dq_dtype to users
def quantize(
    x: jnp.ndarray,
    quantizer: Quantizer,
    dq_dtype: Optional[jnp.dtype] = None,
) -> Tuple[ScaledTensor]:
    """Quantize input tensor according to the quantizer.

    Args:
        x: Input tensor to be quantized.
            Shape: (..., K) where K is the hidden size.
        quantizer: Quantizer for FP8 quantization of the output.
        dq_dtype: Optional dtype for dequantization.
            If None, uses the same dtype as the input tensor.

    Returns:
        A ScaledTensor containing the quantized input tensor.
    """
    out, _ = _quantize_impl(
        x,
        quantizer=quantizer,
        dq_dtype=dq_dtype,
    )
    return out
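For orientation, the per-tensor scaling math that this wrapper ultimately performs can be sketched in plain JAX. This is a reference of the current-scaling recipe only, not the library's code path; the helper name and the hard-coded E4M3 maximum (448) are assumptions for illustration:

import jax.numpy as jnp

def reference_quantize(x: jnp.ndarray):
    """Per-tensor current-scaling reference: scale = fp8_max / amax(|x|)."""
    fp8_max = 448.0  # max normal value of float8_e4m3fn
    amax = jnp.max(jnp.abs(x)).astype(jnp.float32)
    scale = fp8_max / jnp.maximum(amax, 1e-12)
    q = (x.astype(jnp.float32) * scale).astype(jnp.float8_e4m3fn)
    return q, 1.0 / scale  # quantized data plus the scale_inv used to dequantize

q, scale_inv = reference_quantize(jnp.ones((8, 16), jnp.bfloat16))
x_back = q.astype(jnp.float32) * scale_inv  # approximate dequantization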
# TODO(Phuong): do not expose dq_dtype to users
def quantize_dbias(
    dz: jnp.ndarray,
    quantizer: Quantizer,
    is_dbias: bool = True,
    dq_dtype: Optional[jnp.dtype] = None,
) -> Tuple[ScaledTensor2x, jnp.ndarray]:
    """Quantize input tensor and compute bias gradient.

    Args:
        dz: Input tensor to be quantized and used for bias gradient computation.
            Shape: (..., K) where K is the hidden size.
        quantizer: Quantizer for FP8 quantization of the output.
        is_dbias: If True, compute bias gradient. Defaults to True.
        dq_dtype: Optional dtype for dequantization.
            If None, uses the same dtype as the input tensor.

    Returns:
        A tuple containing:
        - A ScaledTensor containing the quantized input tensor.
          The ScaledTensor includes both the quantized data and scaling factors.
        - The bias gradient tensor.
          Shape: (K,) or empty if is_dbias is False.
    """
    return _quantize_impl(
        dz,
        quantizer=quantizer,
        is_dbias=is_dbias,
        dq_dtype=dq_dtype,
    )
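The dbias half of the fusion is just the float32-accumulated reduction shown in _jax_dbias above: every axis except the trailing hidden one is summed away. A minimal self-contained sketch of that shape contract (the helper name is illustrative):

import jax.numpy as jnp

def reference_dbias(dz: jnp.ndarray):
    # Reduce all leading axes, accumulate in float32, return a 1D result
    # matching the C++ convention documented above.
    dbias = jnp.sum(dz.astype(jnp.float32), axis=tuple(range(dz.ndim - 1)))
    return dbias.ravel().astype(dz.dtype)

dz = jnp.ones((2, 4, 8), jnp.bfloat16)
assert reference_dbias(dz).shape == (8,)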
transformer_engine/jax/cpp_extensions/softmax.py
View file @ a207db1d
...
@@ -11,14 +11,10 @@ from packaging import version
import jax
import jax.numpy as jnp
from jax import dtypes
from jax.interpreters.mlir import ir
from jax.sharding import PartitionSpec, NamedSharding

import transformer_engine_jax

from .base import BasePrimitive, register_primitive
from .custom_call import custom_caller, CustomCallArgsWrapper
from .misc import get_padded_spec, check_valid_batch_dims, jax_dtype_to_te_dtype, is_ffi_enabled
from .misc import get_padded_spec, check_valid_batch_dims
from ..softmax import SoftmaxType

if version.parse(jax.__version__) >= version.parse("0.5.0"):
...
@@ -38,30 +34,6 @@ __all__ = [
]


def _jax_scaled_softmax(logits: jnp.ndarray, scale_factor: float):
    return jax.nn.softmax(scale_factor * logits)


def _jax_scaled_masked_softmax(logits: jnp.ndarray, mask: jnp.ndarray, scale_factor: float):
    if mask is not None:
        logits += jax.lax.select(
            mask > 0,
            jnp.full(mask.shape, -1e10).astype(logits.dtype),
            jnp.full(mask.shape, 0.0).astype(logits.dtype),
        )
    return jax.nn.softmax(logits * scale_factor)


def _jax_scaled_upper_triang_masked_softmax(logits: jnp.ndarray, scale_factor: float):
    mask = 1 - jnp.tril(jnp.ones_like(logits))
    logits += jax.lax.select(
        mask > 0,
        jnp.full(mask.shape, -1e10).astype(logits.dtype),
        jnp.full(mask.shape, 0.0).astype(logits.dtype),
    )
    return jax.nn.softmax(logits * scale_factor)
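These pure-JAX references treat mask == 1 as "suppress this position" by adding a large negative constant before the softmax. A tiny usage sketch of the masked variant defined above (shapes follow the [...Batch, Head, Q_Seqlen, K_Seqlen] layout assumed elsewhere in this file):

import jax.numpy as jnp

logits = jnp.zeros((1, 1, 2, 2), jnp.bfloat16)
mask = jnp.array([[[[0, 1], [0, 0]]]])  # suppress key position 1 in row 0
probs = _jax_scaled_masked_softmax(logits, mask, 1.0)
# Row 0 now places (almost) all probability mass on the unmasked position.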
def is_softmax_kernel_available(
    softmax_type: SoftmaxType,
    batch: int,
...
@@ -139,38 +111,7 @@ class SoftmaxPrimitive(BasePrimitive):
        """
        softmax_forward lowering rules
        """
        if is_ffi_enabled():
            ffi_name = name + "_ffi"
            out = ffi.ffi_lowering(ffi_name)(ctx, logits, scale_factor=scale_factor)
        else:
            (i_aval,) = ctx.avals_in
            i_type = ir.RankedTensorType(logits.type)
            i_shape = i_type.shape
            # Assume [...Batch, Head, Q_Seqlen, K_Seqlen]
            batch = reduce(operator.mul, i_shape[:-3])
            pad_batch = batch
            heads = i_shape[-3]
            q_seqlen = i_shape[-2]
            k_seqlen = i_shape[-1]
            out_types = [ir.RankedTensorType.get(i_shape, i_type.element_type)]
            operands = [logits]
            operand_shapes = [i_shape]
            args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
            opaque = transformer_engine_jax.pack_softmax_descriptor(
                batch,
                pad_batch,
                heads,
                q_seqlen,
                k_seqlen,
                jax_dtype_to_te_dtype(i_aval.dtype),
                scale_factor,
            )
            out = custom_caller(name, args, opaque, False)
        return out
        return ffi.ffi_lowering(name)(ctx, logits, scale_factor=scale_factor)

    @staticmethod
    def forward_impl(primitive, logits, scale_factor):
...
@@ -250,43 +191,7 @@ class SoftmaxPrimitive(BasePrimitive):
        """
        softmax_backward lowering rules
        """
        if is_ffi_enabled():
            ffi_name = name + "_ffi"
            out = ffi.ffi_lowering(ffi_name)(ctx, dz, softmax_out, scale_factor=scale_factor)
        else:
            dz_aval, _ = ctx.avals_in
            dz_type = ir.RankedTensorType(dz.type)
            dz_shape = dz_type.shape
            # Assume [...Batch, Head, Q_Seqlen, K_Seqlen]
            batch = reduce(operator.mul, dz_shape[:-3])
            pad_batch = batch  # unused
            heads = dz_shape[-3]
            q_seqlen = dz_shape[-2]
            k_seqlen = dz_shape[-1]
            softmax_out_type = ir.RankedTensorType(softmax_out.type)
            softmax_out_shape = softmax_out_type.shape
            out_types = [ir.RankedTensorType.get(dz_shape, dz_type.element_type)]
            operands = [dz, softmax_out]
            operand_shapes = [dz_shape, softmax_out_shape]
            args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
            opaque = transformer_engine_jax.pack_softmax_descriptor(
                batch,
                pad_batch,
                heads,
                q_seqlen,
                k_seqlen,
                jax_dtype_to_te_dtype(dz_aval.dtype),
                scale_factor,
            )
            out = custom_caller(name, args, opaque, False)
        return out
        return ffi.ffi_lowering(name)(ctx, dz, softmax_out, scale_factor=scale_factor)

    @staticmethod
    def backward_impl(primitive, dz, softmax_out, scale_factor):
...
@@ -356,7 +261,7 @@ class ScaledSoftmaxFwdPrimitive(SoftmaxPrimitive):
    Scaled Softmax Fwd Primitive
    """

    name = "te_scaled_softmax_forward"
    name = "te_scaled_softmax_forward_ffi"
    multiple_results = False
    impl_static_args = (1,)  # scale_factor
    inner_primitive = None
...
@@ -429,22 +334,12 @@ class ScaledSoftmaxFwdPrimitive(SoftmaxPrimitive):
register_primitive(ScaledSoftmaxFwdPrimitive)


def scaled_softmax_fwd(logits: jnp.ndarray, scale_factor: float) -> jnp.ndarray:
    """
    scaled_softmax_forward wrapper
    Return FP16/BF16 tensor
    """
    if not ScaledSoftmaxFwdPrimitive.enabled():
        return _jax_scaled_softmax(logits, scale_factor)
    return ScaledSoftmaxFwdPrimitive.outer_primitive.bind(logits, scale_factor=scale_factor)


class ScaledSoftmaxBwdPrimitive(SoftmaxPrimitive):
    """
    Scaled Softmax Bwd Primitive
    """

    name = "te_scaled_softmax_backward"
    name = "te_scaled_softmax_backward_ffi"
    multiple_results = False
    impl_static_args = (2,)  # scale_factor
    inner_primitive = None
...
@@ -530,7 +425,7 @@ class ScaledMaskedSoftmaxFwdPrimitive(SoftmaxPrimitive):
    Scaled Masked Softmax Fwd Primitive
    """

    name = "te_scaled_masked_softmax_forward"
    name = "te_scaled_masked_softmax_forward_ffi"
    multiple_results = False
    impl_static_args = (2,)  # scale_factor
    inner_primitive = None
...
@@ -591,42 +486,10 @@ class ScaledMaskedSoftmaxFwdPrimitive(SoftmaxPrimitive):
        """
        te_scaled_masked_softmax_forward lowering rules
        """
        if is_ffi_enabled():
            ffi_name = "te_scaled_masked_softmax_forward_ffi"
            out = ffi.ffi_lowering(ffi_name)(ctx, logits, mask, scale_factor=scale_factor)
        else:
            logits_aval, _ = ctx.avals_in
            i_type = ir.RankedTensorType(logits.type)
            i_shape = i_type.shape
            # Assume [...Batch, Head, Q_Seqlen, K_Seqlen]
            batch = reduce(operator.mul, i_shape[:-3])
            heads = i_shape[-3]
            q_seqlen = i_shape[-2]
            k_seqlen = i_shape[-1]
            mask_type = ir.RankedTensorType(mask.type)
            mask_shape = mask_type.shape
            pad_batch = reduce(operator.mul, mask_shape[:-3])
            out_types = [ir.RankedTensorType.get(i_shape, i_type.element_type)]
            operands = [logits, mask]
            operand_shapes = [i_shape, mask_shape]
            args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
            opaque = transformer_engine_jax.pack_softmax_descriptor(
                batch,
                pad_batch,
                heads,
                q_seqlen,
                k_seqlen,
                jax_dtype_to_te_dtype(logits_aval.dtype),
                scale_factor,
        return ffi.ffi_lowering(ScaledMaskedSoftmaxFwdPrimitive.name)(
            ctx, logits, mask, scale_factor=scale_factor
        )
            out = custom_caller(ScaledMaskedSoftmaxFwdPrimitive.name, args, opaque, False)
        return out

    @staticmethod
    def impl(logits, mask, scale_factor):
        assert ScaledMaskedSoftmaxFwdPrimitive.inner_primitive is not None
...
@@ -666,26 +529,12 @@ class ScaledMaskedSoftmaxFwdPrimitive(SoftmaxPrimitive):
register_primitive(ScaledMaskedSoftmaxFwdPrimitive)


def scaled_masked_softmax_fwd(
    logits: jnp.ndarray, mask: jnp.ndarray, scale_factor: float
) -> jnp.ndarray:
    """
    scaled_masked_softmax_forward wrapper
    Return FP16/BF16 tensor
    """
    if not ScaledMaskedSoftmaxFwdPrimitive.enabled():
        return _jax_scaled_masked_softmax(logits, mask, scale_factor)
    return ScaledMaskedSoftmaxFwdPrimitive.outer_primitive.bind(
        logits, mask, scale_factor=scale_factor
    )


class ScaledMaskedSoftmaxBwdPrimitive(SoftmaxPrimitive):
    """
    Scaled Masked Softmax Bwd Primitive
    """

    name = "te_scaled_masked_softmax_backward"
    name = "te_scaled_masked_softmax_backward_ffi"
    multiple_results = False
    impl_static_args = (2,)  # scale_factor
    inner_primitive = None
...
@@ -712,12 +561,10 @@ class ScaledMaskedSoftmaxBwdPrimitive(SoftmaxPrimitive):
        """
        te_scaled_upper_triang_masked_backward lowering rules
        """
        out = SoftmaxPrimitive.backward_lowering(
        return SoftmaxPrimitive.backward_lowering(
            ScaledMaskedSoftmaxBwdPrimitive.name, ctx, dz, softmax_out, scale_factor=scale_factor
        )
        return out

    @staticmethod
    def impl(dz, softmax_out, scale_factor):
        return SoftmaxPrimitive.backward_impl(
...
@@ -753,33 +600,12 @@ class ScaledMaskedSoftmaxBwdPrimitive(SoftmaxPrimitive):
register_primitive(ScaledMaskedSoftmaxBwdPrimitive)


def scaled_masked_softmax_bwd(
    dz: jnp.ndarray,
    softmax_out: jnp.ndarray,
    logits: jnp.ndarray,
    mask: jnp.ndarray,
    scale_factor: float,
) -> jnp.ndarray:
    """
    scaled_masked_backward wrapper
    Return FP16/BF16 tensor
    """
    if not ScaledMaskedSoftmaxBwdPrimitive.enabled():
        _, vjp_func = jax.vjp(
            partial(_jax_scaled_masked_softmax, scale_factor=scale_factor), logits, mask
        )
        return vjp_func(dz)[0]
    return ScaledMaskedSoftmaxBwdPrimitive.outer_primitive.bind(
        dz, softmax_out, scale_factor=scale_factor
    )


class ScaledUpperTriangMaskedSoftmaxFwdPrimitive(SoftmaxPrimitive):
    """
    Scaled Upper Triang Masked Softmax Fwd Primitive
    """

    name = "te_scaled_upper_triang_masked_softmax_forward"
    name = "te_scaled_upper_triang_masked_softmax_forward_ffi"
    multiple_results = False
    impl_static_args = (1,)  # scale_factor
    inner_primitive = None
...
@@ -860,24 +686,12 @@ class ScaledUpperTriangMaskedSoftmaxFwdPrimitive(SoftmaxPrimitive):
register_primitive(ScaledUpperTriangMaskedSoftmaxFwdPrimitive)


def scaled_upper_triang_masked_softmax_fwd(logits: jnp.ndarray, scale_factor: float) -> jnp.ndarray:
    """
    scaled_upper_triang_masked_softmax_forward wrapper
    Return FP16/BF16 tensor
    """
    if not ScaledUpperTriangMaskedSoftmaxFwdPrimitive.enabled():
        return _jax_scaled_upper_triang_masked_softmax(logits, scale_factor)
    return ScaledUpperTriangMaskedSoftmaxFwdPrimitive.outer_primitive.bind(
        logits, scale_factor=scale_factor
    )


class ScaledUpperTriangMaskedSoftmaxBwdPrimitive(SoftmaxPrimitive):
    """
    Scaled Upper Triang Masked Softmax Bwd Primitive
    """

    name = "te_scaled_upper_triang_masked_softmax_backward"
    name = "te_scaled_upper_triang_masked_softmax_backward_ffi"
    multiple_results = False
    impl_static_args = (2,)  # scale_factor
    inner_primitive = None
...
@@ -904,7 +718,7 @@ class ScaledUpperTriangMaskedSoftmaxBwdPrimitive(SoftmaxPrimitive):
        """
        te_scaled_upper_triang_masked_backward lowering rules
        """
        out = SoftmaxPrimitive.backward_lowering(
        return SoftmaxPrimitive.backward_lowering(
            ScaledUpperTriangMaskedSoftmaxBwdPrimitive.name,
            ctx,
            dz,
...
@@ -912,8 +726,6 @@ class ScaledUpperTriangMaskedSoftmaxBwdPrimitive(SoftmaxPrimitive):
            scale_factor=scale_factor,
        )
        return out

    @staticmethod
    def impl(dz, softmax_out, scale_factor):
        return SoftmaxPrimitive.backward_impl(
...
@@ -953,6 +765,87 @@ class ScaledUpperTriangMaskedSoftmaxBwdPrimitive(SoftmaxPrimitive):
register_primitive(ScaledUpperTriangMaskedSoftmaxBwdPrimitive)


def _jax_scaled_softmax(logits: jnp.ndarray, scale_factor: float):
    return jax.nn.softmax(scale_factor * logits)


def _jax_scaled_masked_softmax(logits: jnp.ndarray, mask: jnp.ndarray, scale_factor: float):
    if mask is not None:
        logits += jax.lax.select(
            mask > 0,
            jnp.full(mask.shape, -1e10).astype(logits.dtype),
            jnp.full(mask.shape, 0.0).astype(logits.dtype),
        )
    return jax.nn.softmax(logits * scale_factor)


def _jax_scaled_upper_triang_masked_softmax(logits: jnp.ndarray, scale_factor: float):
    mask = 1 - jnp.tril(jnp.ones_like(logits))
    logits += jax.lax.select(
        mask > 0,
        jnp.full(mask.shape, -1e10).astype(logits.dtype),
        jnp.full(mask.shape, 0.0).astype(logits.dtype),
    )
    return jax.nn.softmax(logits * scale_factor)


def scaled_softmax_fwd(logits: jnp.ndarray, scale_factor: float) -> jnp.ndarray:
    """
    scaled_softmax_forward wrapper
    Return FP16/BF16 tensor
    """
    if not ScaledSoftmaxFwdPrimitive.enabled():
        return _jax_scaled_softmax(logits, scale_factor)
    return ScaledSoftmaxFwdPrimitive.outer_primitive.bind(logits, scale_factor=scale_factor)


def scaled_masked_softmax_fwd(
    logits: jnp.ndarray, mask: jnp.ndarray, scale_factor: float
) -> jnp.ndarray:
    """
    scaled_masked_softmax_forward wrapper
    Return FP16/BF16 tensor
    """
    if not ScaledMaskedSoftmaxFwdPrimitive.enabled():
        return _jax_scaled_masked_softmax(logits, mask, scale_factor)
    return ScaledMaskedSoftmaxFwdPrimitive.outer_primitive.bind(
        logits, mask, scale_factor=scale_factor
    )


def scaled_masked_softmax_bwd(
    dz: jnp.ndarray,
    softmax_out: jnp.ndarray,
    logits: jnp.ndarray,
    mask: jnp.ndarray,
    scale_factor: float,
) -> jnp.ndarray:
    """
    scaled_masked_backward wrapper
    Return FP16/BF16 tensor
    """
    if not ScaledMaskedSoftmaxBwdPrimitive.enabled():
        _, vjp_func = jax.vjp(
            partial(_jax_scaled_masked_softmax, scale_factor=scale_factor), logits, mask
        )
        return vjp_func(dz)[0]
    return ScaledMaskedSoftmaxBwdPrimitive.outer_primitive.bind(
        dz, softmax_out, scale_factor=scale_factor
    )


def scaled_upper_triang_masked_softmax_fwd(logits: jnp.ndarray, scale_factor: float) -> jnp.ndarray:
    """
    scaled_upper_triang_masked_softmax_forward wrapper
    Return FP16/BF16 tensor
    """
    if not ScaledUpperTriangMaskedSoftmaxFwdPrimitive.enabled():
        return _jax_scaled_upper_triang_masked_softmax(logits, scale_factor)
    return ScaledUpperTriangMaskedSoftmaxFwdPrimitive.outer_primitive.bind(
        logits, scale_factor=scale_factor
    )


def scaled_upper_triang_masked_softmax_bwd(
    dz: jnp.ndarray, softmax_out: jnp.ndarray, logits: jnp.ndarray, scale_factor: float
) -> jnp.ndarray:
...
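The backward wrappers above fall back to jax.vjp around the pure-JAX forward when the TE kernel is disabled. A standalone sketch of that pattern, using jax.nn.softmax directly rather than this module's helpers:

from functools import partial
import jax
import jax.numpy as jnp

def fwd(logits, scale_factor):
    return jax.nn.softmax(scale_factor * logits)

logits = jnp.ones((2, 4))
dz = jnp.ones_like(logits)
_, vjp_func = jax.vjp(partial(fwd, scale_factor=2.0), logits)
(dlogits,) = vjp_func(dz)  # same contract as the scaled softmax bwd fallback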
transformer_engine/jax/cpp_extensions/transpose.py
deleted 100644 → 0
View file @ fbee8990
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""JAX/TE custom ops for transpose"""
import operator
from functools import partial, reduce
from typing import Tuple, Sequence, Union, Callable

from packaging import version
import jax
import jax.numpy as jnp
from jax import dtypes
from jax.interpreters.mlir import ir
from jax.sharding import PartitionSpec, NamedSharding

import transformer_engine_jax
from transformer_engine_jax import DType as TEDType

from .base import BasePrimitive, register_primitive
from .custom_call import custom_caller, CustomCallArgsWrapper
from .misc import (
    check_valid_batch_dims,
    jax_dtype_to_te_dtype,
    jax_dtype_to_ir_dtype,
    te_dtype_to_jax_dtype,
    get_padded_spec,
    multidim_transpose,
    normalize_axis_boundary,
    is_ffi_enabled,
)
from .activation import ActivationEnum
from .activation import _jax_act_lu
from .quantization import _jax_cast_fp8
from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp

if version.parse(jax.__version__) >= version.parse("0.5.0"):
    from jax import ffi  # pylint: disable=ungrouped-imports
else:
    from jax.extend import ffi  # pylint: disable=ungrouped-imports

__all__ = [
    "transpose",
    "cast_transpose",
    "dbias_cast_transpose",
    "dact_lu_dbias_cast_transpose",
    "dgated_act_lu_cast_transpose",
]


def _jax_transpose(inputs, static_axis_boundary, transpose_axis_boundary):
    """
    JAX native transpose implementation
    """
    axes = multidim_transpose(range(inputs.ndim), static_axis_boundary, transpose_axis_boundary)
    return jnp.transpose(inputs, axes=axes)


def _jax_cast_transpose(
    inputs, scale, amax, out_dtype, static_axis_boundary, transpose_axis_boundary
):
    """
    JAX native cast_transpose implementation
    """
    casted_output, updated_amax = _jax_cast_fp8(inputs, scale, amax, out_dtype=out_dtype)
    casted_transposed_output = _jax_transpose(
        casted_output, static_axis_boundary, transpose_axis_boundary
    )
    return casted_output, casted_transposed_output, updated_amax


def _jax_dbias_cast_transpose(
    dz, amax, scale, out_dtype, static_axis_boundary, transpose_axis_boundary
):
    """
    JAX native dbias_cast_transpose implementation
    """
    casted_dz, cast_transposed_dz, updated_amax = _jax_cast_transpose(
        dz,
        scale,
        amax,
        out_dtype=out_dtype,
        static_axis_boundary=static_axis_boundary,
        transpose_axis_boundary=transpose_axis_boundary,
    )
    dbias = jnp.sum(
        dz,
        axis=tuple(
            range(
                transpose_axis_boundary
                if transpose_axis_boundary > 0
                else transpose_axis_boundary + dz.ndim
            )
        ),
        keepdims=False,
    )
    dbias = dbias.ravel()  # C++ function returns an 1D array for dbias
    return casted_dz, cast_transposed_dz, dbias, updated_amax
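The axis permutation that _jax_transpose delegates to multidim_transpose keeps the static prefix fixed and moves the trailing axis group in front of the middle group. A hypothetical stand-in for the real helper in .misc (named _illustrative_multidim_transpose to make the assumption explicit):

import jax.numpy as jnp

def _illustrative_multidim_transpose(axes, static_axis_boundary, transpose_axis_boundary):
    # Keep the static prefix, then place the group at [transpose_axis_boundary:]
    # ahead of the group before it.
    axes = list(axes)
    static = axes[: static_axis_boundary + 1] if static_axis_boundary >= 0 else []
    rest = axes[len(static):]
    b = transpose_axis_boundary - len(static)
    return (*static, *rest[b:], *rest[:b])

x = jnp.zeros((2, 3, 4))
perm = _illustrative_multidim_transpose(range(x.ndim), -1, 2)
assert jnp.transpose(x, perm).shape == (4, 2, 3)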
class TransposePrimitive(BasePrimitive):
    """
    Transpose Primitive
    """

    name = "te_transpose"
    multiple_results = False
    impl_static_args = (1, 2)
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(x_aval, *, static_axis_boundary, transpose_axis_boundary):
        """
        _transpose abstract
        """
        transposed_x_shape = multidim_transpose(
            x_aval.shape, static_axis_boundary, transpose_axis_boundary
        )
        xt_aval = x_aval.update(shape=transposed_x_shape, dtype=x_aval.dtype)
        return xt_aval

    @staticmethod
    def lowering(ctx, x, *, static_axis_boundary, transpose_axis_boundary):
        """
        _transpose cuda lowering
        """
        x_aval = ctx.avals_in[0]
        assert x_aval.dtype in [
            jnp.float32,
            jnp.float16,
            jnp.bfloat16,
            jnp.float8_e4m3fn,
            jnp.float8_e5m2,
        ]
        if is_ffi_enabled():
            name = "te_transpose_ffi"
            out = ffi.ffi_lowering(name)(ctx, x, transpose_axis=transpose_axis_boundary)
        else:
            ir_x_type = ir.RankedTensorType(x.type)
            ir_x_shape = ir_x_type.shape
            ir_out_dtype = jax_dtype_to_ir_dtype(x_aval.dtype)
            if static_axis_boundary >= 0:
                for i in range(static_axis_boundary + 1):
                    assert ir_x_shape[i] == 1
            transposed_x_shape = multidim_transpose(
                ir_x_shape, static_axis_boundary, transpose_axis_boundary
            )
            out_types = [ir.RankedTensorType.get(transposed_x_shape, ir_out_dtype)]
            operands = [x]
            operand_shapes = [ir_x_shape]
            args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
            te_dtype = jax_dtype_to_te_dtype(x_aval.dtype)
            contracted_x_shape = (
                reduce(operator.mul, ir_x_shape[:transpose_axis_boundary]),
                reduce(operator.mul, ir_x_shape[transpose_axis_boundary:]),
            )
            opaque = transformer_engine_jax.pack_common_descriptor(
                contracted_x_shape, te_dtype, te_dtype
            )
            out = custom_caller(TransposePrimitive.name, args, opaque, False)
        return out

    @staticmethod
    def impl(x, static_axis_boundary, transpose_axis_boundary):
        """
        tcast_transpose implementation
        """
        assert TransposePrimitive.inner_primitive is not None
        transposed_x = TransposePrimitive.inner_primitive.bind(
            x,
            static_axis_boundary=static_axis_boundary,
            transpose_axis_boundary=transpose_axis_boundary,
        )
        return transposed_x

    @staticmethod
    def batcher(batched_args, batch_dims, *, static_axis_boundary, transpose_axis_boundary):
        check_valid_batch_dims(batch_dims)
        assert TransposePrimitive.outer_primitive is not None
        assert static_axis_boundary < 0
        (x,) = batched_args
        (x_bdim,) = batch_dims
        # Minus batch dim.
        transpose_axis_boundary = normalize_axis_boundary(transpose_axis_boundary, x.ndim - 1)
        transpose_axis_boundary += 1  # Plus batch dim
        out_bdims = x_bdim
        return (
            TransposePrimitive.outer_primitive.bind(
                x, static_axis_boundary=x_bdim, transpose_axis_boundary=transpose_axis_boundary
            ),
            out_bdims,
        )

    @staticmethod
    def infer_sharding_from_operands(
        static_axis_boundary, transpose_axis_boundary, mesh, arg_infos, result_infos
    ):
        del result_infos
        x_spec = get_padded_spec(arg_infos[0])
        xt_spec = multidim_transpose(x_spec, static_axis_boundary, transpose_axis_boundary)
        transposed_x_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec))
        return transposed_x_sharding

    @staticmethod
    def partition(static_axis_boundary, transpose_axis_boundary, mesh, arg_infos, result_infos):
        del result_infos
        x_spec = get_padded_spec(arg_infos[0])
        xt_spec = multidim_transpose(x_spec, static_axis_boundary, transpose_axis_boundary)
        transposed_x_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec))
        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
        out_shardings = transposed_x_sharding

        impl = partial(
            TransposePrimitive.impl,
            static_axis_boundary=static_axis_boundary,
            transpose_axis_boundary=transpose_axis_boundary,
        )

        return mesh, impl, out_shardings, arg_shardings


register_primitive(TransposePrimitive)
def transpose(
    x: jnp.ndarray, static_axis_boundary: int, transpose_axis_boundary: int
) -> jnp.ndarray:
    """
    transpose wrapper
    """
    if not TransposePrimitive.enabled():
        return _jax_transpose(x, static_axis_boundary, transpose_axis_boundary)
    return TransposePrimitive.outer_primitive.bind(
        x,
        static_axis_boundary=static_axis_boundary,
        transpose_axis_boundary=transpose_axis_boundary,
    )
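A short usage sketch of this wrapper, assuming the boundary semantics illustrated earlier (no static axes, and the group at axes [2:] moved to the front):

import jax.numpy as jnp

x = jnp.zeros((2, 3, 4), jnp.bfloat16)
xt = transpose(x, static_axis_boundary=-1, transpose_axis_boundary=2)
# Expected shape: (4, 2, 3) under the boundary semantics sketched above.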
class CastTransposePrimitive(BasePrimitive):
    """
    Cast Transpose Primitive
    """

    name = "te_cast_transpose"
    multiple_results = True
    impl_static_args = (4, 5, 6)
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(
        x_aval,
        amax_aval,
        scale_aval,
        scale_inv_aval,
        *,
        out_dtype,
        static_axis_boundary,
        transpose_axis_boundary,
    ):
        """
        te_cast_transpose_p abstract
        """
        dtype = dtypes.canonicalize_dtype(x_aval.dtype)
        assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert amax_aval.dtype == jnp.float32
        assert scale_aval.dtype == jnp.float32
        assert scale_inv_aval.dtype == jnp.float32
        transposed_x_shape = multidim_transpose(
            x_aval.shape, static_axis_boundary, transpose_axis_boundary
        )
        casted_x_aval = x_aval.update(shape=x_aval.shape, dtype=out_dtype)
        casted_xt_aval = x_aval.update(shape=transposed_x_shape, dtype=out_dtype)
        updated_amax_aval = amax_aval.update(shape=amax_aval.shape, dtype=amax_aval.dtype)
        return casted_x_aval, casted_xt_aval, updated_amax_aval

    @staticmethod
    def lowering(
        ctx, x, amax, scale, scale_inv, *, out_dtype, static_axis_boundary, transpose_axis_boundary
    ):
        """
        te_cast_transpose_p lowering rules
        """
        x_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in
        assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert amax_aval.dtype == jnp.float32
        assert scale_aval.dtype == jnp.float32
        assert scale_inv_aval.dtype == jnp.float32
        if is_ffi_enabled():
            name = "te_cast_transpose_ffi"
            out = ffi.ffi_lowering(name, operand_output_aliases={1: 2})(
                ctx, x, amax, scale, scale_inv, transpose_axis=transpose_axis_boundary
            )
        else:
            ir_x_type = ir.RankedTensorType(x.type)
            ir_x_shape = ir_x_type.shape
            if static_axis_boundary >= 0:
                for i in range(static_axis_boundary + 1):
                    assert ir_x_shape[i] == 1
            ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype)
            ir_amax_type = ir.RankedTensorType(amax.type)
            ir_amax_dtype = ir_amax_type.element_type
            ir_amax_shape = ir_amax_type.shape
            ir_scale_shape = ir_amax_shape
            ir_scale_inv_shape = ir_amax_shape
            transposed_x_shape = multidim_transpose(
                ir_x_shape, static_axis_boundary, transpose_axis_boundary
            )
            out_types = [
                ir.RankedTensorType.get(ir_x_shape, ir_out_dtype),
                ir.RankedTensorType.get(transposed_x_shape, ir_out_dtype),
                ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype),
            ]
            operands = [x, amax, scale, scale_inv]
            operand_shapes = [ir_x_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape]
            args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
            contracted_x_shape = (
                reduce(operator.mul, ir_x_shape[:transpose_axis_boundary]),
                reduce(operator.mul, ir_x_shape[transpose_axis_boundary:]),
            )
            opaque = transformer_engine_jax.pack_common_descriptor(
                contracted_x_shape,
                jax_dtype_to_te_dtype(x_aval.dtype),
                jax_dtype_to_te_dtype(out_dtype),
            )
            out = custom_caller(
                CastTransposePrimitive.name, args, opaque, False, operand_output_aliases={1: 2}
            )
        return out

    @staticmethod
    def impl(x, amax, scale, scale_inv, out_dtype, static_axis_boundary, transpose_axis_boundary):
        """
        te_cast_transpose implementation
        """
        assert CastTransposePrimitive.inner_primitive is not None
        casted_x, casted_transposed_x, updated_amax = CastTransposePrimitive.inner_primitive.bind(
            x,
            amax,
            scale,
            scale_inv,
            out_dtype=out_dtype,
            static_axis_boundary=static_axis_boundary,
            transpose_axis_boundary=transpose_axis_boundary,
        )
        return casted_x, casted_transposed_x, updated_amax

    @staticmethod
    def batcher(
        batched_args, batch_dims, *, out_dtype, static_axis_boundary, transpose_axis_boundary
    ):
        check_valid_batch_dims(batch_dims)
        assert CastTransposePrimitive.outer_primitive is not None
        assert static_axis_boundary < 0
        x, amax, scale, scale_inv = batched_args
        x_bdim, amax_bdim, *_ = batch_dims
        # Minus batch dim.
        transpose_axis_boundary = normalize_axis_boundary(transpose_axis_boundary, x.ndim - 1)
        transpose_axis_boundary += 1  # Plus batch dim
        out_bdims = x_bdim, x_bdim, amax_bdim
        return (
            CastTransposePrimitive.outer_primitive.bind(
                x,
                amax,
                scale,
                scale_inv,
                out_dtype=out_dtype,
                static_axis_boundary=x_bdim,
                transpose_axis_boundary=transpose_axis_boundary,
            ),
            out_bdims,
        )

    @staticmethod
    def infer_sharding_from_operands(
        out_dtype, static_axis_boundary, transpose_axis_boundary, mesh, arg_infos, result_infos
    ):
        del out_dtype, result_infos
        x_spec = get_padded_spec(arg_infos[0])
        casted_x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
        xt_spec = multidim_transpose(x_spec, static_axis_boundary, transpose_axis_boundary)
        casted_transposed_x_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec))
        amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
        return (casted_x_sharding, casted_transposed_x_sharding, amax_sharding)

    @staticmethod
    def partition(
        out_dtype, static_axis_boundary, transpose_axis_boundary, mesh, arg_infos, result_infos
    ):
        del result_infos
        x_spec = get_padded_spec(arg_infos[0])
        casted_x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
        xt_spec = multidim_transpose(x_spec, static_axis_boundary, transpose_axis_boundary)
        casted_transposed_x_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec))
        amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
        out_shardings = (casted_x_sharding, casted_transposed_x_sharding, amax_sharding)

        def sharded_impl(x, amax, scale, scale_inv):
            local_cx, local_cxt, local_updated_amax = CastTransposePrimitive.impl(
                x,
                amax,
                scale,
                scale_inv,
                out_dtype=out_dtype,
                static_axis_boundary=static_axis_boundary,
                transpose_axis_boundary=transpose_axis_boundary,
            )
            global_updated_amax = all_reduce_max_along_all_axes_except_PP(
                local_updated_amax, mesh
            )
            return local_cx, local_cxt, global_updated_amax

        return mesh, sharded_impl, out_shardings, arg_shardings


register_primitive(CastTransposePrimitive)
def cast_transpose(
    x: jnp.ndarray,
    amax: jnp.ndarray,
    scale: jnp.ndarray,
    scale_inv: jnp.ndarray,
    out_dtype: jnp.dtype,
    static_axis_boundary: int,
    transpose_axis_boundary: int,
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """
    cast transpose wrapper
    Return two tensors, FP8(inputs) and FP8(inputs.T), which are scaled by `scale`
    """
    if not CastTransposePrimitive.enabled():
        return _jax_cast_transpose(
            x, scale, amax, out_dtype, static_axis_boundary, transpose_axis_boundary
        )
    return CastTransposePrimitive.outer_primitive.bind(
        x,
        amax,
        scale,
        scale_inv,
        out_dtype=out_dtype,
        static_axis_boundary=static_axis_boundary,
        transpose_axis_boundary=transpose_axis_boundary,
    )
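A usage sketch of this wrapper. Treating amax/scale/scale_inv as shape-(1,) float32 buffers is an assumption for illustration, matching their single-scalar use elsewhere in these extensions:

import jax.numpy as jnp

x = jnp.ones((16, 32), jnp.bfloat16)
amax = jnp.zeros((1,), jnp.float32)
scale = jnp.ones((1,), jnp.float32)
scale_inv = jnp.ones((1,), jnp.float32)

casted, casted_t, new_amax = cast_transpose(
    x, amax, scale, scale_inv,
    out_dtype=jnp.float8_e4m3fn,
    static_axis_boundary=-1,
    transpose_axis_boundary=1,
)
# Expected: casted.shape == (16, 32), casted_t.shape == (32, 16),
# and new_amax updated with max(|x|).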
class DBiasCastTransposePrimitive(BasePrimitive):
    """
    DBias Cast Transpose Primitive
    """

    name = "te_dbias_cast_transpose"
    multiple_results = True
    # out_dtype, static_axis_boundary, transpose_axis_boundary
    impl_static_args = (4, 5, 6)
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(
        dz_aval,
        amax_aval,
        scale_aval,
        scale_inv_aval,
        *,
        out_dtype,
        static_axis_boundary,
        transpose_axis_boundary,
    ):
        """
        te_dbias_cast_transpose_p abstract
        """
        dtype = dtypes.canonicalize_dtype(dz_aval.dtype)
        assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert amax_aval.dtype == jnp.float32
        assert scale_aval.dtype == jnp.float32
        assert scale_inv_aval.dtype == jnp.float32
        gi_hidden_size = reduce(operator.mul, dz_aval.shape[transpose_axis_boundary:])
        t_shape = multidim_transpose(dz_aval.shape, static_axis_boundary, transpose_axis_boundary)
        out = dz_aval.update(shape=dz_aval.shape, dtype=out_dtype)
        t_out = dz_aval.update(shape=t_shape, dtype=out_dtype)
        dbias_shape = (*dz_aval.shape[: static_axis_boundary + 1], gi_hidden_size)
        dbias = dz_aval.update(shape=dbias_shape, dtype=dtype)
        updated_amax_aval = amax_aval.update(shape=amax_aval.shape, dtype=amax_aval.dtype)
        (wkspace_info,) = transformer_engine_jax.get_dbias_ct_workspace_sizes(
            dz_aval.size // gi_hidden_size,
            gi_hidden_size,
            jax_dtype_to_te_dtype(dz_aval.dtype),
            jax_dtype_to_te_dtype(out_dtype),
        )
        wkspace_aval = dz_aval.update(
            shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1])
        )
        return out, t_out, dbias, updated_amax_aval, wkspace_aval

    @staticmethod
    def outer_abstract(*args, **kwargs):
        """
        te_dbias_cast_transpose_p outer abstract
        """
        out, t_out, dbias, updated_amax_aval, _ = DBiasCastTransposePrimitive.abstract(
            *args, **kwargs
        )
        return out, t_out, dbias, updated_amax_aval

    @staticmethod
    def lowering(
        ctx, dz, amax, scale, scale_inv, *, out_dtype, static_axis_boundary, transpose_axis_boundary
    ):
        """
        te_dbias_cast_transpose_p lowering rules
        """
        dz_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in
        assert dz_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert amax_aval.dtype == jnp.float32
        assert scale_aval.dtype == jnp.float32
        assert scale_inv_aval.dtype == jnp.float32
        if is_ffi_enabled():
            name = "te_dbias_cast_transpose_ffi"
            out = ffi.ffi_lowering(name, operand_output_aliases={1: 3})(
                ctx, dz, amax, scale, scale_inv, transpose_axis=transpose_axis_boundary
            )
        else:
            ir_dz_type = ir.RankedTensorType(dz.type)
            ir_dz_shape = ir_dz_type.shape
            batch_size = reduce(operator.mul, ir_dz_shape[:transpose_axis_boundary])
            ir_hidden_size = reduce(operator.mul, ir_dz_shape[transpose_axis_boundary:])
            contracted_dz_shape = (batch_size, ir_hidden_size)
            ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype)
            ir_amax_type = ir.RankedTensorType(amax.type)
            ir_amax_dtype = ir_amax_type.element_type
            ir_amax_shape = ir_amax_type.shape
            ir_scale_shape = ir_amax_shape
            ir_scale_inv_shape = ir_amax_shape
            transposed_dz_shape = multidim_transpose(
                ir_dz_shape, static_axis_boundary, transpose_axis_boundary
            )
            dbias_shape = (*ir_dz_shape[: static_axis_boundary + 1], ir_hidden_size)
            wkspace_aval = ctx.avals_out[-1]
            out_types = [
                ir.RankedTensorType.get(ir_dz_shape, ir_out_dtype),
                ir.RankedTensorType.get(transposed_dz_shape, ir_out_dtype),
                ir.RankedTensorType.get(dbias_shape, ir_dz_type.element_type),
                ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype),
                ir.RankedTensorType.get(
                    wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype)
                ),
            ]
            operands = [dz, amax, scale, scale_inv]
            operand_shapes = [ir_dz_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape]
            args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
            opaque = transformer_engine_jax.pack_common_wk_descriptor(
                contracted_dz_shape,
                wkspace_aval.shape,
                jax_dtype_to_te_dtype(dz_aval.dtype),
                jax_dtype_to_te_dtype(out_dtype),
                jax_dtype_to_te_dtype(wkspace_aval.dtype),
            )
            out = custom_caller(
                DBiasCastTransposePrimitive.name, args, opaque, False, operand_output_aliases={1: 3}
            )
        return out

    @staticmethod
    def impl(dz, amax, scale, scale_inv, out_dtype, static_axis_boundary, transpose_axis_boundary):
        """
        to describe implementation
        """
        assert DBiasCastTransposePrimitive.inner_primitive is not None
        out, t_out, dbias, updated_amax, _ = DBiasCastTransposePrimitive.inner_primitive.bind(
            dz,
            amax,
            scale,
            scale_inv,
            out_dtype=out_dtype,
            static_axis_boundary=static_axis_boundary,
            transpose_axis_boundary=transpose_axis_boundary,
        )
        return out, t_out, dbias, updated_amax

    @staticmethod
    def batcher(
        batched_args, batch_dims, *, out_dtype, static_axis_boundary, transpose_axis_boundary
    ):
        """
        to describe batch rules for vmap
        """
        del static_axis_boundary
        check_valid_batch_dims(batch_dims)
        assert DBiasCastTransposePrimitive.outer_primitive is not None
        dz, amax, scale, scale_inv = batched_args
        dz_bdim, amax_bdim, _, _ = batch_dims
        # Minus batch dim.
        transpose_axis_boundary = normalize_axis_boundary(transpose_axis_boundary, dz.ndim - 1)
        transpose_axis_boundary += 1  # Plus batch dim
        out_bdims = dz_bdim, dz_bdim, dz_bdim, amax_bdim
        return (
            DBiasCastTransposePrimitive.outer_primitive.bind(
                dz,
                amax,
                scale,
                scale_inv,
                out_dtype=out_dtype,
                static_axis_boundary=dz_bdim,
                transpose_axis_boundary=transpose_axis_boundary,
            ),
            out_bdims,
        )

    @staticmethod
    def infer_sharding_from_operands(
        out_dtype, static_axis_boundary, transpose_axis_boundary, mesh, arg_infos, result_infos
    ):
        del out_dtype, result_infos
        x_spec = get_padded_spec(arg_infos[0])
        out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
        xt_spec = multidim_transpose(x_spec, static_axis_boundary, transpose_axis_boundary)
        transposed_out_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec))
        dbias_sharding = NamedSharding(
            mesh, PartitionSpec(*x_spec[: static_axis_boundary + 1], x_spec[-1])
        )
        amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
        return (out_sharding, transposed_out_sharding, dbias_sharding, amax_sharding)

    @staticmethod
    def partition(
        out_dtype, static_axis_boundary, transpose_axis_boundary, mesh, arg_infos, result_infos
    ):
        del result_infos
        x_spec = get_padded_spec(arg_infos[0])
        casted_x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
        xt_spec = multidim_transpose(x_spec, static_axis_boundary, transpose_axis_boundary)
        casted_transposed_x_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec))
        dbias_sharding = NamedSharding(
            mesh, PartitionSpec(*x_spec[: static_axis_boundary + 1], x_spec[-1])
        )
        amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
        out_shardings = (
            casted_x_sharding,
            casted_transposed_x_sharding,
            dbias_sharding,
            amax_sharding,
        )

        def sharded_impl(dz, amax, scale, scale_inv):
            local_out, local_t_out, local_dbias, local_amax = DBiasCastTransposePrimitive.impl(
                dz,
                amax,
                scale,
                scale_inv,
                out_dtype=out_dtype,
                static_axis_boundary=static_axis_boundary,
                transpose_axis_boundary=transpose_axis_boundary,
            )
            global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias, mesh)
            global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh)
            return local_out, local_t_out, global_dbias, global_updated_amax

        return mesh, sharded_impl, out_shardings, arg_shardings


register_primitive(DBiasCastTransposePrimitive)
def dbias_cast_transpose(
    dz: jnp.ndarray,
    amax: jnp.ndarray,
    scale: jnp.ndarray,
    scale_inv: jnp.ndarray,
    out_dtype: TEDType,
    static_axis_boundary: int,
    transpose_axis_boundary: int = -1,
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """
    cast transpose dbias partial fusion wrapper
    Return FP8(inputs), dbias
    """
    if static_axis_boundary < 0:
        static_axis_boundary = -1  # means no static axes

    if not DBiasCastTransposePrimitive.enabled():
        return _jax_dbias_cast_transpose(
            dz, amax, scale, out_dtype, static_axis_boundary, transpose_axis_boundary
        )

    return DBiasCastTransposePrimitive.outer_primitive.bind(
        dz,
        amax,
        scale,
        scale_inv,
        out_dtype=out_dtype,
        static_axis_boundary=static_axis_boundary,
        transpose_axis_boundary=transpose_axis_boundary,
    )
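A usage sketch, assuming the same shape-(1,) statistics buffers as in the cast_transpose example; the primitive's outer abstract returns four outputs (cast, transposed cast, dbias, updated amax) even though the wrapper's return annotation lists three:

import jax.numpy as jnp

dz = jnp.ones((16, 32), jnp.bfloat16)
amax = jnp.zeros((1,), jnp.float32)
scale = jnp.ones((1,), jnp.float32)
scale_inv = jnp.ones((1,), jnp.float32)

casted_dz, casted_dz_t, dbias, new_amax = dbias_cast_transpose(
    dz, amax, scale, scale_inv,
    out_dtype=jnp.float8_e4m3fn,
    static_axis_boundary=-1,
    transpose_axis_boundary=1,
)
# dbias is expected to have shape (32,): dz reduced over the leading axes.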
class
DActLuDBiasCastTransposePrimitive
(
BasePrimitive
):
"""
DActLu DBias Cast Transpose Primitive
"""
name
=
"te_dact_lu_dbias_cast_transpose"
multiple_results
=
True
# out_dtype, static_axis_boundary, act_enum
impl_static_args
=
(
5
,
6
,
7
)
inner_primitive
=
None
outer_primitive
=
None
@
staticmethod
def
abstract
(
dz_aval
,
x_aval
,
amax_aval
,
scale_aval
,
scale_inv_aval
,
*
,
out_dtype
,
static_axis_boundary
,
act_enum
):
# pylint: disable=unused-argument
"""
te_dact_lu_dbais_cast_transpose_p abstract
"""
dtype
=
dtypes
.
canonicalize_dtype
(
dz_aval
.
dtype
)
assert
dtype
in
[
jnp
.
float32
,
jnp
.
float16
,
jnp
.
bfloat16
]
assert
x_aval
.
dtype
==
dtype
assert
amax_aval
.
dtype
==
jnp
.
float32
assert
scale_aval
.
dtype
==
jnp
.
float32
assert
scale_inv_aval
.
dtype
==
jnp
.
float32
ir_hidden_szie
=
dz_aval
.
shape
[
-
1
]
gi_hidden_size
=
x_aval
.
shape
[
-
1
]
assert
ir_hidden_szie
==
gi_hidden_size
t_shape
=
multidim_transpose
(
x_aval
.
shape
,
static_axis_boundary
,
-
2
)
out
=
dz_aval
.
update
(
shape
=
x_aval
.
shape
,
dtype
=
out_dtype
)
t_out
=
dz_aval
.
update
(
shape
=
t_shape
,
dtype
=
out_dtype
)
dbias_shape
=
(
*
x_aval
.
shape
[:
static_axis_boundary
+
1
],
gi_hidden_size
)
dbias
=
dz_aval
.
update
(
shape
=
dbias_shape
,
dtype
=
dtype
)
updated_amax_aval
=
amax_aval
.
update
(
shape
=
amax_aval
.
shape
,
dtype
=
amax_aval
.
dtype
)
(
wkspace_info
,)
=
transformer_engine_jax
.
get_dact_dbias_ct_workspace_sizes
(
x_aval
.
size
//
gi_hidden_size
,
gi_hidden_size
,
jax_dtype_to_te_dtype
(
x_aval
.
dtype
),
jax_dtype_to_te_dtype
(
out_dtype
),
)
wkspace_aval
=
x_aval
.
update
(
shape
=
wkspace_info
[
0
],
dtype
=
te_dtype_to_jax_dtype
(
wkspace_info
[
1
])
)
return
out
,
t_out
,
dbias
,
updated_amax_aval
,
wkspace_aval
@
staticmethod
def
outer_abstract
(
*
args
,
**
kwargs
):
"""
te_dact_lu_dbais_cast_transpose_p outer abstract
"""
out
,
t_out
,
dbias
,
updated_amax_aval
,
_
=
DActLuDBiasCastTransposePrimitive
.
abstract
(
*
args
,
**
kwargs
)
return
out
,
t_out
,
dbias
,
updated_amax_aval
@
staticmethod
def
lowering
(
ctx
,
dz
,
x
,
amax
,
scale
,
scale_inv
,
*
,
out_dtype
,
static_axis_boundary
,
act_enum
):
"""
te_dgated_act_lu_cast_transpose_p lowering rules
"""
dz_aval
,
x_aval
,
amax_aval
,
scale_aval
,
scale_inv_aval
=
ctx
.
avals_in
assert
dz_aval
.
dtype
in
[
jnp
.
float32
,
jnp
.
float16
,
jnp
.
bfloat16
]
assert
x_aval
.
dtype
==
dz_aval
.
dtype
assert
amax_aval
.
dtype
==
jnp
.
float32
assert
scale_aval
.
dtype
==
jnp
.
float32
assert
scale_inv_aval
.
dtype
==
jnp
.
float32
if
is_ffi_enabled
():
name
=
"te_dact_lu_dbias_cast_transpose_ffi"
out
=
ffi
.
ffi_lowering
(
name
,
operand_output_aliases
=
{
2
:
3
})(
ctx
,
dz
,
x
,
amax
,
scale
,
scale_inv
,
act_enum
=
int
(
act_enum
)
)
else
:
ir_dz_type
=
ir
.
RankedTensorType
(
dz
.
type
)
ir_dz_shape
=
ir_dz_type
.
shape
x_type
=
ir
.
RankedTensorType
(
x
.
type
)
x_shape
=
x_type
.
shape
dz_batch_szie
=
reduce
(
operator
.
mul
,
ir_dz_shape
[:
-
1
])
x_batch_size
=
reduce
(
operator
.
mul
,
x_shape
[:
-
2
])
assert
dz_batch_szie
==
x_batch_size
ir_hidden_szie
=
ir_dz_shape
[
-
1
]
contracted_x_shape
=
(
x_batch_size
,
ir_hidden_szie
)
ir_out_dtype
=
jax_dtype_to_ir_dtype
(
out_dtype
)
ir_amax_type
=
ir
.
RankedTensorType
(
amax
.
type
)
ir_amax_dtype
=
ir_amax_type
.
element_type
ir_amax_shape
=
ir_amax_type
.
shape
ir_scale_shape
=
ir_amax_shape
ir_scale_inv_shape
=
ir_amax_shape
transposed_x_shape
=
multidim_transpose
(
x_shape
,
static_axis_boundary
,
-
2
)
dbias_shape
=
(
*
x_shape
[:
static_axis_boundary
+
1
],
ir_hidden_szie
)
wkspace_aval
=
ctx
.
avals_out
[
-
1
]
out_types
=
[
ir
.
RankedTensorType
.
get
(
x_shape
,
ir_out_dtype
),
ir
.
RankedTensorType
.
get
(
transposed_x_shape
,
ir_out_dtype
),
ir
.
RankedTensorType
.
get
(
dbias_shape
,
ir_dz_type
.
element_type
),
ir
.
RankedTensorType
.
get
(
ir_amax_shape
,
ir_amax_dtype
),
ir
.
RankedTensorType
.
get
(
wkspace_aval
.
shape
,
jax_dtype_to_ir_dtype
(
wkspace_aval
.
dtype
)
),
]
operands
=
[
dz
,
x
,
amax
,
scale
,
scale_inv
]
operand_shapes
=
[
ir_dz_shape
,
x_shape
,
ir_amax_shape
,
ir_scale_shape
,
ir_scale_inv_shape
,
]
args
=
CustomCallArgsWrapper
(
out_types
,
operands
,
operand_shapes
)
opaque
=
transformer_engine_jax
.
pack_common_wk_descriptor
(
contracted_x_shape
,
wkspace_aval
.
shape
,
jax_dtype_to_te_dtype
(
dz_aval
.
dtype
),
jax_dtype_to_te_dtype
(
out_dtype
),
jax_dtype_to_te_dtype
(
wkspace_aval
.
dtype
),
act_enum
,
)
out
=
custom_caller
(
DActLuDBiasCastTransposePrimitive
.
name
,
args
,
opaque
,
False
,
operand_output_aliases
=
{
2
:
3
},
)
return
out
@
staticmethod
def
impl
(
dz
,
x
,
amax
,
scale
,
scale_inv
,
out_dtype
,
static_axis_boundary
,
act_enum
,
):
"""
to describe implementation
"""
assert
DActLuDBiasCastTransposePrimitive
.
inner_primitive
is
not
None
out
,
t_out
,
dbias
,
updated_amax
,
_
=
DActLuDBiasCastTransposePrimitive
.
inner_primitive
.
bind
(
dz
,
x
,
amax
,
scale
,
scale_inv
,
out_dtype
=
out_dtype
,
static_axis_boundary
=
static_axis_boundary
,
act_enum
=
act_enum
,
)
return
out
,
t_out
,
dbias
,
updated_amax
@
staticmethod
def
batcher
(
batched_args
,
batch_dims
,
*
,
out_dtype
,
static_axis_boundary
,
act_enum
):
"""
to describe batch rules for vmap
"""
del
static_axis_boundary
check_valid_batch_dims
(
batch_dims
)
assert
DActLuDBiasCastTransposePrimitive
.
outer_primitive
is
not
None
dz
,
x
,
amax
,
scale
,
scale_inv
=
batched_args
x_bdim
,
_
,
amax_bdim
,
_
,
_
=
batch_dims
out_bdims
=
x_bdim
,
x_bdim
,
x_bdim
,
amax_bdim
return
(
DActLuDBiasCastTransposePrimitive
.
outer_primitive
.
bind
(
dz
,
x
,
amax
,
scale
,
scale_inv
,
out_dtype
=
out_dtype
,
static_axis_boundary
=
x_bdim
,
act_enum
=
act_enum
,
),
out_bdims
,
)
@
staticmethod
def
infer_sharding_from_operands
(
out_dtype
,
static_axis_boundary
,
act_enum
,
mesh
,
arg_infos
,
result_infos
,
):
del
out_dtype
,
result_infos
,
act_enum
x_spec
=
get_padded_spec
(
arg_infos
[
1
])
out_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
*
x_spec
))
xt_spec
=
multidim_transpose
(
x_spec
,
static_axis_boundary
,
-
2
)
tranposed_out_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
*
xt_spec
))
dbias_shaprding
=
NamedSharding
(
mesh
,
PartitionSpec
(
*
x_spec
[:
static_axis_boundary
+
1
],
x_spec
[
-
1
])
)
amax_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
*
get_padded_spec
(
arg_infos
[
2
])))
return
(
out_sharding
,
tranposed_out_sharding
,
dbias_shaprding
,
amax_sharding
)
    @staticmethod
    def partition(
        out_dtype,
        static_axis_boundary,
        act_enum,
        mesh,
        arg_infos,
        result_infos,
    ):
        del result_infos
        x_spec = get_padded_spec(arg_infos[1])
        casted_x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
        xt_spec = multidim_transpose(x_spec, static_axis_boundary, -2)
        casted_transposed_x_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec))
        dbias_sharding = NamedSharding(
            mesh, PartitionSpec(*x_spec[: static_axis_boundary + 1], x_spec[-1])
        )
        amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[2])))
        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
        out_shardings = (
            casted_x_sharding,
            casted_transposed_x_sharding,
            dbias_sharding,
            amax_sharding,
        )

        def sharded_impl(dz, x, amax, scale, scale_inv):
            local_out, local_t_out, local_dbias, local_amax = (
                DActLuDBiasCastTransposePrimitive.impl(
                    dz,
                    x,
                    amax,
                    scale,
                    scale_inv,
                    out_dtype=out_dtype,
                    static_axis_boundary=static_axis_boundary,
                    act_enum=act_enum,
                )
            )
            global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias, mesh)
            global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh)
            return local_out, local_t_out, global_dbias, global_updated_amax

        return mesh, sharded_impl, out_shardings, arg_shardings


register_primitive(DActLuDBiasCastTransposePrimitive)
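The partition rule above keeps the casted gradient and its transpose local to each shard, but all-reduces dbias (sum over DP/FSDP) and amax (max over all non-pipeline axes). A runnable toy of that combine step, assuming a data-parallel axis named "dp"; names and shapes are illustrative, not the TE API:

```python
import jax
import jax.numpy as jnp

def local_step(dx):
    local_dbias = dx.sum(axis=0)                               # per-shard dbias
    local_amax = jnp.abs(dx).max()                             # per-shard amax
    global_dbias = jax.lax.psum(local_dbias, axis_name="dp")   # sum across shards
    global_amax = jax.lax.pmax(local_amax, axis_name="dp")     # max across shards
    return global_dbias, global_amax

n_dev = jax.local_device_count()
dx = jnp.ones((n_dev, 4, 8))
dbias, amax = jax.pmap(local_step, axis_name="dp")(dx)
print(dbias.shape, amax.shape)  # (n_dev, 8) and (n_dev,)
```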
def dact_lu_dbias_cast_transpose(
    dz: jnp.ndarray,
    x: jnp.ndarray,
    amax: jnp.ndarray,
    scale: jnp.ndarray,
    scale_inv: jnp.ndarray,
    out_dtype: TEDType,
    static_axis_boundary: int,
    activation_type: Sequence[Union[str, Callable]] = ("gelu",),
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """
    Fused dact_lu + dbias + cast transpose wrapper.
    Returns FP8(dact_lu(dz, x)), its transpose, and dbias.
    Only non-gated activation types are supported.
    """
    if static_axis_boundary < 0:
        static_axis_boundary = -1  # means no static axes

    if not DActLuDBiasCastTransposePrimitive.enabled():
        _, vjp_func = jax.vjp(partial(_jax_act_lu, activation_type=activation_type), x)
        (dx,) = vjp_func(dz)
        transpose_axis_boundary = -2
        return _jax_dbias_cast_transpose(
            dx, amax, scale, out_dtype, static_axis_boundary, transpose_axis_boundary
        )

    act_type_id = ActivationEnum[activation_type]
    return DActLuDBiasCastTransposePrimitive.outer_primitive.bind(
        dz,
        x,
        amax,
        scale,
        scale_inv,
        out_dtype=out_dtype,
        static_axis_boundary=static_axis_boundary,
        act_enum=act_type_id,
    )
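When the custom primitive is disabled, the wrapper falls back to plain JAX as above. A minimal, self-contained sketch of that fallback for a non-gated GELU; `act` stands in for the module-internal `_jax_act_lu`, and the shapes/dtypes are illustrative only:

```python
import jax
import jax.numpy as jnp

def act(x):
    # Stand-in for _jax_act_lu with activation_type=("gelu",).
    return jax.nn.gelu(x)

x = jnp.ones((4, 8), dtype=jnp.bfloat16)   # activation input
dz = jnp.ones((4, 8), dtype=jnp.bfloat16)  # incoming gradient

# jax.vjp gives the activation backward without a custom kernel ...
_, vjp_func = jax.vjp(act, x)
(dx,) = vjp_func(dz)

# ... after which dbias is a sum over all non-hidden axes, mirroring
# what the fused kernel produces.
dbias = dx.astype(jnp.float32).sum(axis=tuple(range(dx.ndim - 1)))
```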
class DgatedActLuCastTransposePrimitive(BasePrimitive):
    """
    Dgated ActLu Cast Transpose Primitive
    """

    name = "te_dgated_act_lu_cast_transpose"
    multiple_results = True
    impl_static_args = (5, 6, 7)  # out_dtype, static_axis_boundary, act_enum
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(
        dz_aval,
        x_aval,
        amax_aval,
        scale_aval,
        scale_inv_aval,
        *,
        out_dtype,
        static_axis_boundary,
        act_enum,
    ):  # pylint: disable=unused-argument
        """
        te_dgated_act_lu_cast_transpose_p abstract
        """
        dtype = dtypes.canonicalize_dtype(dz_aval.dtype)
        assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert x_aval.dtype == dtype
        assert x_aval.shape[-2] == 2  # Linear + GeLU
        assert amax_aval.dtype == jnp.float32
        assert scale_aval.dtype == jnp.float32
        assert scale_inv_aval.dtype == jnp.float32
        ir_hidden_size = dz_aval.shape[-1]
        gi_hidden_size = x_aval.shape[-1]
        assert ir_hidden_size == gi_hidden_size
        t_shape = multidim_transpose(x_aval.shape, static_axis_boundary, -2)
        out = dz_aval.update(shape=x_aval.shape, dtype=out_dtype)
        t_out = dz_aval.update(shape=t_shape, dtype=out_dtype)
        updated_amax_aval = amax_aval.update(shape=amax_aval.shape, dtype=amax_aval.dtype)
        return out, t_out, updated_amax_aval
    @staticmethod
    def lowering(
        ctx, dz, x, amax, scale, scale_inv, *, out_dtype, static_axis_boundary, act_enum
    ):
        """
        te_dgated_act_lu_cast_transpose_p lowering rules
        """
        dz_aval, x_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in
        assert dz_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert x_aval.dtype == dz_aval.dtype
        assert amax_aval.dtype == jnp.float32
        assert scale_aval.dtype == jnp.float32
        assert scale_inv_aval.dtype == jnp.float32
        if is_ffi_enabled():
            name = "te_dgated_act_lu_cast_transpose_ffi"
            out = ffi.ffi_lowering(name, operand_output_aliases={2: 2})(
                ctx, dz, x, amax, scale, scale_inv, act_enum=int(act_enum)
            )
        else:
            ir_dz_type = ir.RankedTensorType(dz.type)
            ir_dz_shape = ir_dz_type.shape
            x_type = ir.RankedTensorType(x.type)
            x_shape = x_type.shape
            dz_batch_size = reduce(operator.mul, ir_dz_shape[:-1])
            x_batch_size = reduce(operator.mul, x_shape[:-2])
            assert dz_batch_size == x_batch_size
            assert x_shape[-2] == 2  # Linear + GeLU
            ir_hidden_size = ir_dz_shape[-1]
            gi_hidden_size = x_shape[-1]
            assert ir_hidden_size == gi_hidden_size
            ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype)
            ir_amax_type = ir.RankedTensorType(amax.type)
            ir_amax_dtype = ir_amax_type.element_type
            ir_amax_shape = ir_amax_type.shape
            ir_scale_shape = ir_amax_shape
            ir_scale_inv_shape = ir_amax_shape
            transposed_x_shape = multidim_transpose(x_shape, static_axis_boundary, -2)
            out_types = [
                ir.RankedTensorType.get(x_shape, ir_out_dtype),
                ir.RankedTensorType.get(transposed_x_shape, ir_out_dtype),
                ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype),
            ]
            operands = [dz, x, amax, scale, scale_inv]
            operand_shapes = [
                ir_dz_shape,
                x_shape,
                ir_amax_shape,
                ir_scale_shape,
                ir_scale_inv_shape,
            ]
            args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
            contracted_x_shape = (x_batch_size, x_shape[-1])
            opaque = transformer_engine_jax.pack_common_descriptor(
                contracted_x_shape,
                jax_dtype_to_te_dtype(dz_aval.dtype),
                jax_dtype_to_te_dtype(out_dtype),
                act_enum,
            )
            out = custom_caller(
                DgatedActLuCastTransposePrimitive.name,
                args,
                opaque,
                False,
                operand_output_aliases={2: 2},
            )
        return out
    @staticmethod
    def impl(dz, x, amax, scale, scale_inv, out_dtype, static_axis_boundary, act_enum):
        """
        Implementation of the fused dgated_act_lu + cast transpose primitive.
        """
        assert DgatedActLuCastTransposePrimitive.inner_primitive is not None
        out, t_out, updated_amax = DgatedActLuCastTransposePrimitive.inner_primitive.bind(
            dz,
            x,
            amax,
            scale,
            scale_inv,
            out_dtype=out_dtype,
            static_axis_boundary=static_axis_boundary,
            act_enum=act_enum,
        )
        return out, t_out, updated_amax
    @staticmethod
    def batcher(batched_args, batch_dims, *, out_dtype, static_axis_boundary, act_enum):
        """
        Batch rules for vmap.
        """
        del static_axis_boundary
        check_valid_batch_dims(batch_dims)
        assert DgatedActLuCastTransposePrimitive.outer_primitive is not None
        dz, x, amax, scale, scale_inv = batched_args
        x_bdim, _, amax_bdim, _, _ = batch_dims

        out_bdims = x_bdim, x_bdim, amax_bdim
        return (
            DgatedActLuCastTransposePrimitive.outer_primitive.bind(
                dz,
                x,
                amax,
                scale,
                scale_inv,
                out_dtype=out_dtype,
                static_axis_boundary=x_bdim,
                act_enum=act_enum,
            ),
            out_bdims,
        )
    @staticmethod
    def infer_sharding_from_operands(
        out_dtype, static_axis_boundary, act_enum, mesh, arg_infos, result_infos
    ):
        del out_dtype, result_infos, act_enum
        x_spec = get_padded_spec(arg_infos[1])
        out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
        xt_spec = multidim_transpose(x_spec, static_axis_boundary, -2)
        transposed_out_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec))
        amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[2])))
        return (out_sharding, transposed_out_sharding, amax_sharding)
    @staticmethod
    def partition(out_dtype, static_axis_boundary, act_enum, mesh, arg_infos, result_infos):
        del result_infos
        x_spec = get_padded_spec(arg_infos[1])
        casted_x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec))
        xt_spec = multidim_transpose(x_spec, static_axis_boundary, -2)
        casted_transposed_x_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec))
        amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[2])))
        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
        out_shardings = (casted_x_sharding, casted_transposed_x_sharding, amax_sharding)

        def sharded_impl(dz, x, amax, scale, scale_inv):
            local_out, local_t_out, local_amax = DgatedActLuCastTransposePrimitive.impl(
                dz,
                x,
                amax,
                scale,
                scale_inv,
                out_dtype=out_dtype,
                static_axis_boundary=static_axis_boundary,
                act_enum=act_enum,
            )
            global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh)
            return local_out, local_t_out, global_updated_amax

        return mesh, sharded_impl, out_shardings, arg_shardings


register_primitive(DgatedActLuCastTransposePrimitive)
def dgated_act_lu_cast_transpose(
    dz: jnp.ndarray,
    x: jnp.ndarray,
    amax: jnp.ndarray,
    scale: jnp.ndarray,
    scale_inv: jnp.ndarray,
    out_dtype: TEDType,
    static_axis_boundary: int,
    activation_type: Sequence[Union[str, Callable]],
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """
    Fused dgated_act_lu + cast transpose wrapper.
    Returns FP8(dgated_act_lu(dz, x)) and its transpose.
    """
    act_type_id = ActivationEnum[activation_type]
    if not DgatedActLuCastTransposePrimitive.enabled():
        _, vjp_func = jax.vjp(partial(_jax_act_lu, activation_type=activation_type), x)
        (dx,) = vjp_func(dz)
        return _jax_cast_transpose(
            dx,
            scale,
            amax,
            out_dtype=out_dtype,
            static_axis_boundary=static_axis_boundary,
            transpose_axis_boundary=-2,
        )
    return DgatedActLuCastTransposePrimitive.outer_primitive.bind(
        dz,
        x,
        amax,
        scale,
        scale_inv,
        out_dtype=out_dtype,
        static_axis_boundary=static_axis_boundary,
        act_enum=act_type_id,
    )
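For the gated path, the activation input carries the gate on a size-2 axis at position -2 (the `x_aval.shape[-2] == 2` assertion above). A pure-JAX illustration of that layout and its VJP; `gated_gelu` is a local stand-in, not the TE kernel itself:

```python
import jax
import jax.numpy as jnp

def gated_gelu(x):
    # x: [..., 2, hidden] -> linear branch * GeLU(gate branch)
    linear, gate = x[..., 0, :], x[..., 1, :]
    return linear * jax.nn.gelu(gate)

x = jnp.ones((4, 2, 8), dtype=jnp.float32)
dz = jnp.ones((4, 8), dtype=jnp.float32)

_, vjp_func = jax.vjp(gated_gelu, x)
(dx,) = vjp_func(dz)      # dx keeps the same [..., 2, hidden] layout as x
print(dx.shape)           # (4, 2, 8)
```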
transformer_engine/jax/csrc/extensions.h View file @ a207db1d
@@ -13,6 +13,7 @@
#include <cudnn.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <transformer_engine/normalization.h>
#include <transformer_engine/transformer_engine.h>
#include <cassert>
...
@@ -33,226 +34,42 @@
namespace transformer_engine {
namespace jax {

// Phuong: These 3 functions need to stay in the header file for compilation purpose
// 1.
inline bool use_fp8(DType type) {
  return type == DType::kFloat8E4M3 || type == DType::kFloat8E5M2;
}

// 2.
template <typename T>
pybind11::bytes PackOpaque(const T &descriptor) {
  auto str = std::string(reinterpret_cast<const char *>(&descriptor), sizeof(T));
  return pybind11::bytes(str);
}

// 3.
template <typename T>
const T *UnpackOpaque(const char *opaque, size_t opaque_len) {
  if (opaque_len != sizeof(T)) {
    throw std::runtime_error("Invalid opaque object size");
  }
  return reinterpret_cast<const T *>(opaque);
}
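PackOpaque/UnpackOpaque simply reinterpret a trivially-copyable descriptor as raw bytes and back, guarded by a size check. The same round-trip sketched in Python with the struct module; the two-field layout here is illustrative, not the real descriptor:

```python
import struct

# Two size_t fields, like a minimal {rows, cols} descriptor ("N" = native size_t).
DESC_FMT = "@NN"

def pack_opaque(rows: int, cols: int) -> bytes:
    return struct.pack(DESC_FMT, rows, cols)

def unpack_opaque(opaque: bytes):
    if len(opaque) != struct.calcsize(DESC_FMT):
        raise ValueError("Invalid opaque object size")  # mirrors UnpackOpaque
    return struct.unpack(DESC_FMT, opaque)

opaque = pack_opaque(128, 512)
assert unpack_opaque(opaque) == (128, 512)
```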
// Packing
struct CustomCallCommonDescriptor {
  Shape shape;
  DType in_dtype;
  DType out_dtype;
  size_t act_enum;
};

pybind11::bytes PackCustomCallCommonDescriptor(const std::vector<size_t> &shape, DType in_dtype,
                                               DType out_dtype, size_t act_enum = 0);

struct CustomCallCommonWkDescriptor {
  Shape shape;
  Shape wkshape;
  DType in_dtype;
  DType out_dtype;
  DType wk_dtype;
  size_t act_enum;
};

pybind11::bytes PackCustomCallCommonWkDescriptor(const std::vector<size_t> &shape,
                                                 const std::vector<size_t> &wkshape,
                                                 DType in_dtype, DType out_dtype, DType wk_dtype,
                                                 size_t act_enum = 0);

struct CustomCallNormDescriptor {
  size_t batch_size;
  size_t hidden_size;
  size_t wkspace_size;
  DType x_dtype;
  DType w_dtype;
  DType wkspace_dtype;
  bool zero_centered_gamma;
  float eps;
  int sm_margin;
};

pybind11::bytes PackCustomCallNormDescriptor(size_t batch_size, size_t hidden_size,
                                             size_t wkspace_size, DType x_dtype, DType w_dtype,
                                             DType wkspace_dtype, bool zero_centered_gamma,
                                             float eps, int sm_margin);
struct SoftmaxDescriptor {
  size_t batch_size;
  size_t padding_size;
  size_t head_dim;
  size_t q_seqlen;
  size_t k_seqlen;
  DType dtype;
  float scale_factor;
};

pybind11::bytes PackCustomCallSoftmaxDescriptor(size_t batch_size, size_t padding_size,
                                                size_t head_dim, size_t q_seqlen,
                                                size_t k_seqlen, DType dtype,
                                                float scale_factor);

struct CustomCallFusedAttnDescriptor {
  size_t input_batch;
  size_t bias_batch;
  size_t q_max_seqlen;
  size_t kv_max_seqlen;
  size_t attn_heads;
  size_t num_gqa_groups;
  size_t bias_heads;
  size_t head_dim;
  size_t max_segments_per_seq;
  size_t wkspace_size;
  float scaling_factor;
  float dropout_probability;
  NVTE_Bias_Type bias_type;
  NVTE_Mask_Type mask_type;
  NVTE_QKV_Layout qkv_layout;
  DType dtype;
  DType wkspace_dtype;
  bool is_training;
  bool deterministic;
  int64_t window_size_left;
  int64_t window_size_right;
};

pybind11::bytes PackCustomCallFusedAttnDescriptor(
    size_t input_batch, size_t batch_size, size_t q_max_seqlen, size_t kv_max_seqlen,
    size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
    size_t max_segments_per_seq, size_t wkspace_size, float scaling_factor,
    float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type,
    NVTE_QKV_Layout qkv_layout, DType dtype, DType wkspace_dtype, bool is_training,
    bool deterministic, int64_t window_size_left, int64_t window_size_right);
// Transpose
void Transpose(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
XLA_FFI_DECLARE_HANDLER_SYMBOL(TransposeHandler);

void CastTranspose(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
XLA_FFI_DECLARE_HANDLER_SYMBOL(CastTransposeHandler);

pybind11::tuple GetDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hidden_size,
                                                    DType in_dtype, DType out_dtype);

void DBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
                        size_t opaque_len);
XLA_FFI_DECLARE_HANDLER_SYMBOL(DBiasCastTransposeHandler);

// Activation
size_t get_activation_len(NVTE_Activation_Type activation_enum);

void ActLu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
XLA_FFI_DECLARE_HANDLER_SYMBOL(ActLuHandler);

void ActLuFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
XLA_FFI_DECLARE_HANDLER_SYMBOL(ActLuFP8Handler);

void DActLu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
XLA_FFI_DECLARE_HANDLER_SYMBOL(DActLuHandler);

pybind11::tuple GetDActDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hidden_size,
                                                        DType in_dtype, DType out_dtype);

void DActLuDBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
                              size_t opaque_len);
XLA_FFI_DECLARE_HANDLER_SYMBOL(DActLuDBiasCastTransposeHandler);

void DGatedActLuCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
                              size_t opaque_len);
XLA_FFI_DECLARE_HANDLER_SYMBOL(DGatedActLuCastTransposeHandler);

// Normalization
XLA_FFI_DECLARE_HANDLER_SYMBOL(NormForwardHandler);

pybind11::tuple GetLayerNormForwardWorkspaceSizes(size_t batch_size, size_t hidden_size,
                                                  DType in_dtype, DType w_dtype, DType out_dtype,
                                                  bool is_layer_norm, bool zero_centered_gamma,
                                                  float eps, int sm_margin);

void LayerNormForward(cudaStream_t stream, void **buffers, const char *opaque,
                      size_t opaque_len);
XLA_FFI_DECLARE_HANDLER_SYMBOL(LayerNormForwardHandler);

void LayerNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque,
                         size_t opaque_len);
XLA_FFI_DECLARE_HANDLER_SYMBOL(LayerNormForwardFP8Handler);

pybind11::tuple GetLayerNormBackwardWorkspaceSizes(size_t batch_size, size_t hidden_size,
                                                   DType in_dtype, DType w_dtype,
                                                   bool is_layer_norm, bool zero_centered_gamma,
                                                   float eps, int sm_margin);
XLA_FFI_DECLARE_HANDLER_SYMBOL(NormBackwardHandler);

void LayerNormBackward(cudaStream_t stream, void **buffers, const char *opaque,
                       size_t opaque_len);

pybind11::tuple GetNormForwardWorkspaceSizes(size_t batch_size, size_t hidden_size,
                                             DType in_dtype, DType w_dtype, DType out_dtype,
                                             NVTE_Norm_Type norm_type, int scaling_mode,
                                             bool zero_centered_gamma, float epsilon,
                                             int sm_margin, bool is_training);
XLA_FFI_DECLARE_HANDLER_SYMBOL(LayerNormBackwardHandler);

void RMSNormForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
XLA_FFI_DECLARE_HANDLER_SYMBOL(RMSNormForwardHandler);

void RMSNormForwardFP8(cudaStream_t stream, void **buffers, const char *opaque,
                       size_t opaque_len);
XLA_FFI_DECLARE_HANDLER_SYMBOL(RMSNormForwardFP8Handler);

void RMSNormBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
XLA_FFI_DECLARE_HANDLER_SYMBOL(RMSNormBackwardHandler);

pybind11::tuple GetNormBackwardWorkspaceSizes(size_t batch_size, size_t hidden_size,
                                              DType in_dtype, DType w_dtype,
                                              NVTE_Norm_Type norm_type,
                                              bool zero_centered_gamma, int sm_margin);
// Quantization
void Quantize(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
XLA_FFI_DECLARE_HANDLER_SYMBOL(QuantizeHandler);

void Dequantize(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len);
XLA_FFI_DECLARE_HANDLER_SYMBOL(DBiasQuantizeHandler);
XLA_FFI_DECLARE_HANDLER_SYMBOL(DequantizeHandler);

// Softmax
void ScaledSoftmaxForward(cudaStream_t stream, void **buffers, const char *opaque,
                          std::size_t opaque_len);
void ScaledSoftmaxBackward(cudaStream_t stream, void **buffers, const char *opaque,
                           std::size_t opaque_len);
void ScaledMaskedSoftmaxForward(cudaStream_t stream, void **buffers, const char *opaque,
                                std::size_t opaque_len);
void ScaledMaskedSoftmaxBackward(cudaStream_t stream, void **buffers, const char *opaque,
                                 std::size_t opaque_len);

pybind11::tuple GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size,
                                               DType in_dtype, DType out_dtype);

void ScaledUpperTriangMaskedSoftmaxForward(cudaStream_t stream, void **buffers,
                                           const char *opaque, std::size_t opaque_len);
XLA_FFI_DECLARE_HANDLER_SYMBOL(DActLuDBiasQuantizeHandler);

void ScaledUpperTriangMaskedSoftmaxBackward(cudaStream_t stream, void **buffers,
                                            const char *opaque, std::size_t opaque_len);

pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size,
                                                   DType in_dtype, DType out_dtype,
                                                   int scaling_mode, bool is_2x);

// Softmax
XLA_FFI_DECLARE_HANDLER_SYMBOL(ScaledSoftmaxForwardHandler);
XLA_FFI_DECLARE_HANDLER_SYMBOL(ScaledSoftmaxBackwardHandler);
...
@@ -266,9 +83,9 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(ScaledUpperTriangMaskedSoftmaxForwardHandler);
XLA_FFI_DECLARE_HANDLER_SYMBOL(ScaledUpperTriangMaskedSoftmaxBackwardHandler);

// Attention
XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedAttnForwardHandler);
// Cudnn helpers
XLA_FFI_DECLARE_HANDLER_SYMBOL(CudnnHandleInitHandler);
XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedAttnBackwardHandler);

NVTE_Fused_Attn_Backend GetFusedAttnBackend(DType q_dtype, DType kv_dtype,
                                            NVTE_QKV_Layout qkv_layout,
                                            NVTE_Bias_Type bias_type,
...
@@ -285,10 +102,6 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
    NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training,
    size_t max_segments_per_seq, int64_t window_size_left, int64_t window_size_right);

void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque,
                      size_t opaque_len);
XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedAttnForwardHandler);

pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
    size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen,
    size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim,
...
@@ -297,9 +110,14 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes(
    bool deterministic, size_t max_segments_per_seq, int64_t window_size_left,
    int64_t window_size_right);

void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
                       size_t opaque_len);

// Grouped GEMM
XLA_FFI_DECLARE_HANDLER_SYMBOL(GroupedGemmHandler);
XLA_FFI_DECLARE_HANDLER_SYMBOL(FusedAttnBackwardHandler);

// Cudnn helpers
XLA_FFI_DECLARE_HANDLER_SYMBOL(CudnnHandleInitHandler);
// CuBLAS helpers
XLA_FFI_DECLARE_HANDLER_SYMBOL(CublasHandleInitHandler);

}  // namespace jax
}  // namespace transformer_engine
transformer_engine/jax/csrc/extensions/activation.cpp View file @ a207db1d
@@ -5,328 +5,136 @@
 ************************************************************************/
#include "transformer_engine/activation.h"

#include <cuda_runtime.h>

#include "extensions.h"
#include "transformer_engine/cast.h"
#include "transformer_engine/transpose.h"
#include "xla/ffi/api/c_api.h"
namespace transformer_engine {
namespace jax {

// TODO: We won't need this function anymore when we move to the new XLA custom calls
size_t get_activation_len(NVTE_Activation_Type activation_enum) {
  switch (activation_enum) {
    case NVTE_Activation_Type::GELU:
      return 1;
    case NVTE_Activation_Type::GEGLU:
      return 2;
    case NVTE_Activation_Type::SILU:
      return 1;
    case NVTE_Activation_Type::SWIGLU:
      return 2;
    case NVTE_Activation_Type::RELU:
      return 1;
    case NVTE_Activation_Type::REGLU:
      return 2;
    case NVTE_Activation_Type::QGELU:
      return 1;
    case NVTE_Activation_Type::QGEGLU:
      return 2;
    case NVTE_Activation_Type::SRELU:
      return 1;
    case NVTE_Activation_Type::SREGLU:
      return 2;
    default:
      NVTE_ERROR("Unsupported ActivationEnum");
      break;
      return -1;
  }
}
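The mapping that get_activation_len encodes, restated as a small Python table (a sketch, not part of the bindings): gated activations consume two input rows (value and gate), plain ones consume one.

```python
# Rows of the activation input consumed per output row.
ACTIVATION_LEN = {
    "gelu": 1, "geglu": 2,
    "silu": 1, "swiglu": 2,
    "relu": 1, "reglu": 2,
    "qgelu": 1, "qgeglu": 2,
    "srelu": 1, "sreglu": 2,
}
assert ACTIVATION_LEN["swiglu"] == 2  # gated: value row + gate row
```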
void ActLuImpl(void *input, size_t m, size_t n, DType in_dtype, DType out_dtype, float *scale,
               cudaStream_t stream, float *scale_inverse, float *amax, void *output,
               NVTE_Activation_Type act_enum, size_t act_len) {
  auto input_shape = std::vector<size_t>{m, n * act_len};
  auto output_shape = std::vector<size_t>{m, n};
  auto input_tensor = TensorWrapper(input, input_shape, static_cast<DType>(in_dtype));
  auto output_tensor = TensorWrapper(output, output_shape, static_cast<DType>(out_dtype), amax,
                                     scale, scale_inverse);
  switch (act_enum) {
    case NVTE_Activation_Type::GELU:
      nvte_gelu(input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::GEGLU:
      nvte_geglu(input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::SILU:
      nvte_silu(input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::SWIGLU:
      nvte_swiglu(input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::RELU:
      nvte_relu(input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::REGLU:
      nvte_reglu(input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::QGELU:
      nvte_qgelu(input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::QGEGLU:
      nvte_qgeglu(input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::SRELU:
      nvte_srelu(input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::SREGLU:
      nvte_sreglu(input_tensor.data(), output_tensor.data(), stream);
      break;
    default:
      NVTE_ERROR("Unsupported ActivationEnum");
      break;
  }
}
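ActLuImpl flattens every leading dimension into m and folds the gate dimension into the input row, so a [..., act_len, hidden] input becomes (m, n * act_len) while the output stays (m, n). A NumPy sketch of that shape convention, with illustrative sizes:

```python
import numpy as np

batch, seqlen, hidden = 2, 3, 8
act_len = 2                      # 2 for gated activations, 1 otherwise

x = np.zeros((batch, seqlen, act_len, hidden))
m = batch * seqlen               # product of all leading dims
n = hidden

kernel_input_shape = (m, n * act_len)   # what the kernel sees
kernel_output_shape = (m, n)
assert x.reshape(kernel_input_shape).shape == (6, 16)
```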
void ActLu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
  auto *input = buffers[0];
  auto *output = buffers[1];

  const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
  auto m = desc.shape.dims[0];
  auto n = desc.shape.dims[1];
  auto act_enum = static_cast<NVTE_Activation_Type>(desc.act_enum);
  auto act_len = get_activation_len(act_enum);

  ActLuImpl(input, m, n, desc.in_dtype, desc.out_dtype, nullptr, stream, nullptr, nullptr,
            output, act_enum, act_len);
}
Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Result_Type output_buf,
                    int64_t act_enum) {
  auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type());
  auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type());

  auto *input = input_buf.untyped_data();
  auto *output = output_buf->untyped_data();

  auto input_dims = input_buf.dimensions();
  auto m = product(input_dims, 0, input_dims.size() - 2);
  auto n = input_dims.back();
  auto act_len = input_dims.end()[-2];
  auto act_type = static_cast<NVTE_Activation_Type>(act_enum);

  ActLuImpl(input, m, n, in_dtype, out_dtype, nullptr, stream, nullptr, nullptr, output,
            act_type, act_len);

  return ffi_with_cuda_error_check();
}

namespace {

bool is_gated(NVTE_Activation_Type act_type) {
  return act_type == NVTE_Activation_Type::GEGLU || act_type == NVTE_Activation_Type::SWIGLU ||
         act_type == NVTE_Activation_Type::REGLU || act_type == NVTE_Activation_Type::QGEGLU ||
         act_type == NVTE_Activation_Type::SREGLU;
}

}  // namespace
XLA_FFI_DEFINE_HANDLER_SYMBOL(ActLuHandler, ActLuFFI,
                              FFI::Bind()
                                  .Ctx<FFI_Stream_Type>()  // stream
                                  .Arg<Buffer_Type>()      // input
                                  .Ret<Buffer_Type>()      // output
                                  .Attr<int64_t>("act_enum"),
                              FFI_CudaGraph_Traits);
void ActLuFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
  auto *input = buffers[0];
  float *amax = reinterpret_cast<float *>(buffers[1]);
  float *scale = reinterpret_cast<float *>(buffers[2]);
  float *scale_inv = reinterpret_cast<float *>(buffers[3]);
  auto *output = buffers[4];
  float *amax_out = reinterpret_cast<float *>(buffers[5]);
  NVTE_CHECK(amax == amax_out, "amax not bound to amax_out in TE/JAX ActLuFP8 primitive.");

  const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
  if (!use_fp8(desc.out_dtype)) {
    scale = nullptr;
    scale_inv = nullptr;
    amax_out = nullptr;
  }
  auto m = desc.shape.dims[0];
  auto n = desc.shape.dims[1];
  auto act_enum = static_cast<NVTE_Activation_Type>(desc.act_enum);
  auto act_len = get_activation_len(act_enum);

  ActLuImpl(input, m, n, desc.in_dtype, desc.out_dtype, scale, stream, scale_inv, amax_out,
            output, act_enum, act_len);
}
Error_Type ActLuFP8FFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type amax_buf,
                       Buffer_Type scale_buf, Buffer_Type scale_inv_buf, Result_Type output_buf,
                       Result_Type amax_out_buf, int64_t act_enum) {
  auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type());
  auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type());

  auto *input = input_buf.untyped_data();
  float *amax = reinterpret_cast<float *>(amax_buf.untyped_data());
  float *scale = reinterpret_cast<float *>(scale_buf.untyped_data());
  float *scale_inv = reinterpret_cast<float *>(scale_inv_buf.untyped_data());
  auto *output = output_buf->untyped_data();
  float *amax_out = reinterpret_cast<float *>(amax_out_buf->untyped_data());
  NVTE_CHECK(amax == amax_out, "amax not bound to amax_out in TE/JAX ActLuFP8 primitive.");
  if (!use_fp8(out_dtype)) {
    scale = nullptr;
    scale_inv = nullptr;
    amax_out = nullptr;
  }

  auto input_dims = input_buf.dimensions();
  auto m = product(input_dims, 0, input_dims.size() - 2);
  auto n = input_dims.back();
  auto act_len = input_dims.end()[-2];
  auto act_type = static_cast<NVTE_Activation_Type>(act_enum);

  ActLuImpl(input, m, n, in_dtype, out_dtype, scale, stream, scale_inv, amax_out, output,
            act_type, act_len);

  return ffi_with_cuda_error_check();
}

XLA_FFI_DEFINE_HANDLER_SYMBOL(ActLuFP8Handler, ActLuFP8FFI,
                              FFI::Bind()
                                  .Ctx<FFI_Stream_Type>()  // stream
                                  .Arg<Buffer_Type>()      // input
                                  .Arg<Buffer_Type>()      // amax
                                  .Arg<Buffer_Type>()      // scale
                                  .Arg<Buffer_Type>()      // scale_inv
                                  .Ret<Buffer_Type>()      // output
                                  .Ret<Buffer_Type>()      // amax_out
                                  .Attr<int64_t>("act_enum"),
                              FFI_CudaGraph_Traits);

void DActLu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
  auto *input = buffers[0];
  auto *act_input = buffers[1];
  auto *output = buffers[2];

  const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
  auto m = desc.shape.dims[0];
  auto n = desc.shape.dims[1];
  auto act_enum = static_cast<NVTE_Activation_Type>(desc.act_enum);
  auto act_len = get_activation_len(act_enum);

  auto input_shape = std::vector<size_t>{m, n};
  auto act_input_shape = std::vector<size_t>{m, n * act_len};
  auto output_shape = std::vector<size_t>{m, n * act_len};

  auto input_tensor = TensorWrapper(input, input_shape, desc.in_dtype);
  auto act_input_tensor = TensorWrapper(act_input, act_input_shape, desc.in_dtype);
  auto output_tensor = TensorWrapper(output, output_shape, desc.out_dtype);

  switch (act_enum) {
    case NVTE_Activation_Type::GELU:
      nvte_dgelu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::GEGLU:
      nvte_dgeglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::SILU:
      nvte_dsilu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::SWIGLU:
      nvte_dswiglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::RELU:
      nvte_drelu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::REGLU:
      nvte_dreglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::QGELU:
      nvte_dqgelu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::QGEGLU:
      nvte_dqgeglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::SRELU:
      nvte_dsrelu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::SREGLU:
      nvte_dsreglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
      break;
    default:
      NVTE_ERROR("Unsupported ActivationEnum");
      break;
  }
}

Error_Type DActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type act_input_buf,
                     Result_Type output_buf, int64_t act_enum) {
  auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type());
  auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type());

  auto *input = input_buf.untyped_data();
  auto *act_input = act_input_buf.untyped_data();
  auto *output = output_buf->untyped_data();

  auto act_input_dims = act_input_buf.dimensions();
  auto m = static_cast<size_t>(product(act_input_dims, 0, act_input_dims.size() - 2));
  auto n = static_cast<size_t>(act_input_dims.back());
  auto act_len = act_input_dims.end()[-2];

  auto input_shape = std::vector<size_t>{m, n};
  auto act_input_shape = std::vector<size_t>{m, n * act_len};
  auto output_shape = std::vector<size_t>{m, n * act_len};

  auto input_tensor = TensorWrapper(input, input_shape, static_cast<DType>(in_dtype));
  auto act_input_tensor =
      TensorWrapper(act_input, act_input_shape, static_cast<DType>(in_dtype));
  auto output_tensor = TensorWrapper(output, output_shape, static_cast<DType>(out_dtype));

  auto act_type = static_cast<NVTE_Activation_Type>(act_enum);
  switch (act_type) {
    case NVTE_Activation_Type::GELU:
      nvte_dgelu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::GEGLU:
      nvte_dgeglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::SILU:
      nvte_dsilu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::SWIGLU:
      nvte_dswiglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::RELU:
      nvte_drelu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::REGLU:
      nvte_dreglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::QGELU:
      nvte_dqgelu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::QGEGLU:
      nvte_dqgeglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::SRELU:
      nvte_dsrelu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::SREGLU:
      nvte_dsreglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
      break;
    default:
      NVTE_ERROR("Unsupported ActivationEnum");
      break;
  }
  return ffi_with_cuda_error_check();
}

Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scale_buf,
                    Result_Type output_buf, Result_Type colwise_output_buf,
                    Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf,
                    Result_Type amax_buf, int64_t act_enum, int64_t scaling_mode_enum,
                    bool is_2x_int) {
  auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type());
  auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type());

  auto *input = input_buf.untyped_data();
  float *scale = reinterpret_cast<float *>(scale_buf.untyped_data());
  auto *output = output_buf->untyped_data();
  auto *colwise_output = colwise_output_buf->untyped_data();
  float *amax = reinterpret_cast<float *>(amax_buf->untyped_data());

  auto input_dims = input_buf.dimensions();
  auto m = product(input_dims, 0, input_dims.size() - 2);
  auto n = input_dims.back();
  auto act_len = input_dims[input_dims.size() - 2];
  auto act_type = static_cast<NVTE_Activation_Type>(act_enum);
  auto scaling_mode = static_cast<NVTEScalingMode>(scaling_mode_enum);
  auto is_2x = static_cast<bool>(is_2x_int);

  auto input_shape = std::vector<size_t>{m, act_len * n};
  auto output_shape = std::vector<size_t>{m, n};
  auto input_tensor = TensorWrapper(input, input_shape, static_cast<DType>(in_dtype));
  auto output_tensor = TensorWrapper(scaling_mode);
  output_tensor.set_rowwise_data(output, static_cast<DType>(out_dtype), output_shape);

  if (is_fp8_dtype(out_dtype)) {
    output_tensor.set_rowwise_scale_inv(
        scale_inv_buf->untyped_data(),
        convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()),
        std::vector<size_t>{
            product(scale_inv_buf->dimensions(), 0, scale_inv_buf->dimensions().size() - 1),
            scale_inv_buf->dimensions().back()});
  }

  if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING && is_fp8_dtype(out_dtype)) {
    NVTE_CHECK(scale != nullptr, "scale must be provided for delayed tensor scaling");
    NVTE_CHECK(amax != nullptr, "amax must be provided for delayed tensor scaling");
    cudaMemsetAsync(amax, 0, sizeof(float), stream);
    output_tensor.set_scale(scale, DType::kFloat32, std::vector<size_t>{1});
    output_tensor.set_amax(amax, DType::kFloat32, std::vector<size_t>{1});
  }

  if (is_2x) {
    output_tensor.set_columnwise_data(colwise_output, static_cast<DType>(out_dtype),
                                      output_shape);
    output_tensor.set_columnwise_scale_inv(
        colwise_scale_inv_buf->untyped_data(),
        convert_ffi_datatype_to_te_dtype(colwise_scale_inv_buf->element_type()),
        std::vector<size_t>{product(colwise_scale_inv_buf->dimensions(), 0,
                                    colwise_scale_inv_buf->dimensions().size() - 1),
                            colwise_scale_inv_buf->dimensions().back()});
  }

  switch (act_type) {
    case NVTE_Activation_Type::GELU:
      nvte_gelu(input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::GEGLU:
      nvte_geglu(input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::SILU:
      nvte_silu(input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::SWIGLU:
      nvte_swiglu(input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::RELU:
      nvte_relu(input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::REGLU:
      nvte_reglu(input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::QGELU:
      nvte_qgelu(input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::QGEGLU:
      nvte_qgeglu(input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::SRELU:
      nvte_srelu(input_tensor.data(), output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::SREGLU:
      nvte_sreglu(input_tensor.data(), output_tensor.data(), stream);
      break;
    default:
      NVTE_ERROR("Unsupported ActivationEnum");
      break;
  }
  return ffi_with_cuda_error_check();
}

XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuHandler, DActLuFFI,
                              FFI::Bind()
                                  .Ctx<FFI_Stream_Type>()  // stream
                                  .Arg<Buffer_Type>()      // input
                                  .Arg<Buffer_Type>()      // act_input
                                  .Ret<Buffer_Type>()      // output
                                  .Attr<int64_t>("act_enum"),
                              FFI_CudaGraph_Traits);

XLA_FFI_DEFINE_HANDLER_SYMBOL(ActLuHandler, ActLuFFI,
                              FFI::Bind()
                                  .Ctx<FFI_Stream_Type>()  // stream
                                  .Arg<Buffer_Type>()      // input
                                  .Arg<Buffer_Type>()      // scale
                                  .Ret<Buffer_Type>()      // output
                                  .Ret<Buffer_Type>()      // colwise output
                                  .Ret<Buffer_Type>()      // scale_inv
                                  .Ret<Buffer_Type>()      // scale_inv colwise
                                  .Ret<Buffer_Type>()      // amax
                                  .Attr<int64_t>("act_enum")
                                  .Attr<int64_t>("scaling_mode")
                                  .Attr<bool>("is_2x"),
                              FFI_CudaGraph_Traits);
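The delayed-scaling branch above zeroes amax before the launch and lets the kernel record the new absolute maximum; the next step's scale is then derived from it. A simplified single-step view in NumPy (the real recipe keeps an amax history; fp8_max = 448 corresponds to E4M3, and the epsilon guard is an added assumption for clarity):

```python
import numpy as np

def quantize_step(x, scale, fp8_max=448.0):
    amax = 0.0                                  # cudaMemsetAsync(amax, 0, ...)
    amax = max(amax, float(np.abs(x).max()))    # kernel-side amax update
    q = np.clip(x * scale, -fp8_max, fp8_max)   # scaled, saturated values
    next_scale = fp8_max / max(amax, 1e-12)     # scale derived for the next step
    return q, 1.0 / scale, amax, next_scale     # data, scale_inv, statistics

q, scale_inv, amax, next_scale = quantize_step(np.random.randn(4, 4), scale=1.0)
```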
pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size,
                                                   DType in_dtype, DType out_dtype,
                                                   int scaling_mode, bool is_2x) {
  auto input_shape = std::vector<size_t>{batch_size, hidden_size};
  auto dact_input_shape = std::vector<size_t>{batch_size, hidden_size};
  auto output_shape = std::vector<size_t>{batch_size, hidden_size};
...
@@ -344,13 +152,34 @@
  auto input_tensor = TensorWrapper(reinterpret_cast<void *>(&temp), input_shape, in_dtype);
  auto dact_input_tensor =
      TensorWrapper(reinterpret_cast<void *>(&temp), dact_input_shape, in_dtype);
  auto dbias_tensor = TensorWrapper(reinterpret_cast<void *>(&temp), dbias_shape, in_dtype);
  auto output_tensor = TensorWrapper(static_cast<NVTEScalingMode>(scaling_mode));
  output_tensor.set_rowwise_data(reinterpret_cast<void *>(&temp), out_dtype, output_shape);
  // Only the pointers will be checked for scale_inv, thus the shapes do not matter
  if (is_fp8_dtype(out_dtype)) {
    output_tensor.set_rowwise_scale_inv(reinterpret_cast<void *>(&temp), DType::kFloat32,
                                        std::vector<size_t>{1});
  }
  if (is_2x) {
    output_tensor.set_columnwise_data(reinterpret_cast<void *>(&temp), out_dtype,
                                      output_trans_shape);
    // Only the pointers will be checked for scale_inv, thus the shapes do not matter
    if (is_fp8_dtype(out_dtype)) {
      output_tensor.set_columnwise_scale_inv(reinterpret_cast<void *>(&temp), DType::kFloat32,
                                             std::vector<size_t>{1});
    }
  }
  if (is_fp8_dtype(out_dtype) && scaling_mode == NVTEScalingMode::NVTE_DELAYED_TENSOR_SCALING) {
    output_tensor.set_amax(reinterpret_cast<void *>(&temp), DType::kFloat32,
                           std::vector<size_t>{1});
    output_tensor.set_scale(reinterpret_cast<void *>(&temp), DType::kFloat32,
                            std::vector<size_t>{1});
  }

  TensorWrapper dummy_workspace;

  // For now, all dbias_dact(-s) have the same workspace size
  nvte_quantize_dbias_dgelu(input_tensor.data(), dact_input_tensor.data(), output_tensor.data(),
                            dbias_tensor.data(), dummy_workspace.data(), nullptr);
...
@@ -359,101 +188,26 @@
  return pybind11::make_tuple(std::make_pair(work_shape, dummy_workspace.dtype()));
}
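The workspace-size query follows the usual two-phase pattern: call the fused kernel once with a dummy workspace so it reports the bytes it needs, then allocate and call again. A runnable toy of the same pattern; all names are illustrative:

```python
import numpy as np

class Workspace:
    def __init__(self):
        self.size = 0
        self.buf = None

def fused_kernel(x, workspace):
    needed = x.size * x.itemsize          # toy size requirement
    if workspace.buf is None or workspace.buf.nbytes < needed:
        workspace.size = needed           # query phase: just report the size
        return None
    return x * 2                          # execute phase

x = np.ones((4, 8), dtype=np.float32)
ws = Workspace()
fused_kernel(x, ws)                       # phase 1: ws.size is now set
ws.buf = np.empty(ws.size, dtype=np.uint8)
out = fused_kernel(x, ws)                 # phase 2: the real computation
```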
void DActLuDBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
                              size_t opaque_len) {
  auto *input = buffers[0];
  auto *act_input = buffers[1];
  float *amax = reinterpret_cast<float *>(buffers[2]);
  float *scale = reinterpret_cast<float *>(buffers[3]);
  float *scale_inv = reinterpret_cast<float *>(buffers[4]);
  auto *output = buffers[5];
  auto *output_trans = buffers[6];
  auto *dbias = buffers[7];
  float *amax_out = reinterpret_cast<float *>(buffers[8]);
  void *workspace_ptr = buffers[9];
  const auto &desc = *UnpackOpaque<CustomCallCommonWkDescriptor>(opaque, opaque_len);
  NVTE_CHECK(amax == amax_out,
             "amax not bound to amax_out in TE/JAX DActLuDBiasCastTranspose primitive.");
  if (!use_fp8(desc.out_dtype)) {
    scale = nullptr;
    scale_inv = nullptr;
    amax_out = nullptr;
  }
  auto m = desc.shape.dims[0];
  auto n = desc.shape.dims[1];
  auto act_enum = static_cast<NVTE_Activation_Type>(desc.act_enum);

  auto input_shape = std::vector<size_t>{m, n};
  auto act_input_shape = std::vector<size_t>{m, n};
  auto output_shape = std::vector<size_t>{m, n};
  auto output_trans_shape = std::vector<size_t>{n, m};
  auto dbias_shape = std::vector<size_t>{n};

  auto input_tensor = TensorWrapper(input, input_shape, desc.in_dtype);
  auto act_input_tensor = TensorWrapper(act_input, act_input_shape, desc.in_dtype);
  auto output_tensor =
      TensorWrapper(output, output_shape, desc.out_dtype, amax_out, scale, scale_inv);
  output_tensor.set_columnwise_data(output_trans, desc.out_dtype, output_trans_shape);
  output_tensor.set_columnwise_scale_inv(scale_inv, DType::kFloat32, std::vector<size_t>{1});
  auto dbias_tensor = TensorWrapper(dbias, dbias_shape, desc.in_dtype);
  auto workspace = TensorWrapper(workspace_ptr, desc.wkshape.to_vector(), desc.wk_dtype);

  switch (act_enum) {
    case NVTE_Activation_Type::GELU:
      nvte_quantize_dbias_dgelu(input_tensor.data(), act_input_tensor.data(),
                                output_tensor.data(), dbias_tensor.data(), workspace.data(),
                                stream);
      break;
    case NVTE_Activation_Type::SILU:
      nvte_quantize_dbias_dsilu(input_tensor.data(), act_input_tensor.data(),
                                output_tensor.data(), dbias_tensor.data(), workspace.data(),
                                stream);
      break;
    case NVTE_Activation_Type::RELU:
      nvte_quantize_dbias_drelu(input_tensor.data(), act_input_tensor.data(),
                                output_tensor.data(), dbias_tensor.data(), workspace.data(),
                                stream);
      break;
    case NVTE_Activation_Type::QGELU:
      nvte_quantize_dbias_dqgelu(input_tensor.data(), act_input_tensor.data(),
                                 output_tensor.data(), dbias_tensor.data(), workspace.data(),
                                 stream);
      break;
    case NVTE_Activation_Type::SRELU:
      nvte_quantize_dbias_dsrelu(input_tensor.data(), act_input_tensor.data(),
                                 output_tensor.data(), dbias_tensor.data(), workspace.data(),
                                 stream);
      break;
    default:
      NVTE_ERROR("Unsupported ActivationEnum");
      break;
  }
}
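In tensor terms, the fused kernel above produces the quantized gradient, its columnwise (transposed) copy, and the per-hidden-unit bias gradient. A NumPy sketch of those three outputs, with a float16 astype standing in for the FP8 cast (real FP8 needs the scale/amax handling shown above):

```python
import numpy as np

dact = np.random.randn(6, 8).astype(np.float32)   # (m, n) activation gradient
out = dact.astype(np.float16)                      # stand-in for the FP8 cast
out_t = out.T.copy()                               # (n, m) columnwise copy
dbias = dact.sum(axis=0)                           # (n,) bias gradient
```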
Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf,
                                  Buffer_Type act_input_buf, Buffer_Type scale_buf,
                                  Result_Type output_buf, Result_Type output_trans_buf,
                                  Result_Type scale_inv_buf, Result_Type trans_scale_inv_buf,
                                  Result_Type amax_out_buf, Result_Type dbias_buf,
                                  Result_Type workspace_buf, int64_t scaling_mode_enum,
                                  bool is_2x, bool is_dbias, int64_t act_enum) {
  auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type());
  auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type());
  auto workspace_dtype = convert_ffi_datatype_to_te_dtype(workspace_buf->element_type());

  auto *input = input_buf.untyped_data();
  auto *act_input = act_input_buf.untyped_data();
  auto scaling_mode = static_cast<NVTEScalingMode>(scaling_mode_enum);
  auto *output = output_buf->untyped_data();
  auto *output_trans = output_trans_buf->untyped_data();
  auto *dbias = dbias_buf->untyped_data();
  void *workspace = workspace_buf->untyped_data();

  auto input_dims = input_buf.dimensions();
  auto act_input_dims = act_input_buf.dimensions();
...
  // m = x_batch_size = reduce(operator.mul, x_shape[:-2]), x_shape == act_input_dims
  // n = ir_dz_shape[-1], ir_dz_shape == input_dims
  auto m = product(act_input_dims, 0, act_input_dims.size() - 1);
  // 'n' will be 2x the size of input_dims.back() if the dactivation is dgated
  auto n = act_input_dims.back();
  auto input_shape = std::vector<size_t>{m, input_dims.back()};
  auto act_input_shape = std::vector<size_t>{m, n};
  auto output_shape = std::vector<size_t>{m, n};
  auto output_trans_shape = std::vector<size_t>{m, n};
  auto dbias_shape = std::vector<size_t>{n};
  std::vector<size_t> workspace_shape(workspace_dims.begin(), workspace_dims.end());

  auto input_tensor = TensorWrapper(input, input_shape, in_dtype);
  auto act_input_tensor = TensorWrapper(act_input, act_input_shape, in_dtype);
  auto output_tensor = TensorWrapper(scaling_mode);
  output_tensor.set_rowwise_data(output, out_dtype, output_shape);
  if (is_fp8_dtype(out_dtype)) {
    output_tensor.set_rowwise_scale_inv(
        scale_inv_buf->untyped_data(),
        convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()),
        std::vector<size_t>{
            product(scale_inv_buf->dimensions(), 0, scale_inv_buf->dimensions().size() - 1),
            scale_inv_buf->dimensions().back()});
    if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) {
      float *scale = reinterpret_cast<float *>(scale_buf.untyped_data());
      float *amax_out = reinterpret_cast<float *>(amax_out_buf->untyped_data());
      NVTE_CHECK(scale != nullptr, "scale must be provided for delayed tensor scaling");
      NVTE_CHECK(amax_out != nullptr, "amax must be provided for delayed tensor scaling");
      cudaMemsetAsync(amax_out, 0, sizeof(float), stream);
      output_tensor.set_scale(scale, DType::kFloat32, std::vector<size_t>{1});
      output_tensor.set_amax(amax_out, DType::kFloat32, std::vector<size_t>{1});
    }
  }
  if (is_2x) {
    output_tensor.set_columnwise_data(output_trans, out_dtype, output_trans_shape);
    if (is_fp8_dtype(out_dtype)) {
      // For 2x delayed scaling, the scale buffer is shared between rowwise and columnwise scaling
      auto &colwise_scale_inv_buf =
          (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) ? scale_inv_buf : trans_scale_inv_buf;
      output_tensor.set_columnwise_scale_inv(
          colwise_scale_inv_buf->untyped_data(),
          convert_ffi_datatype_to_te_dtype(colwise_scale_inv_buf->element_type()),
          std::vector<size_t>{product(colwise_scale_inv_buf->dimensions(), 0,
                                      colwise_scale_inv_buf->dimensions().size() - 1),
                              colwise_scale_inv_buf->dimensions().back()});
    }
  }
  auto dbias_tensor = TensorWrapper(dbias, dbias_shape, in_dtype);
  auto workspace_tensor = TensorWrapper(workspace, workspace_shape, workspace_dtype);

  auto act_type = static_cast<NVTE_Activation_Type>(act_enum);
  // fused_dgated_dbias is not available, so we use dact_lu + quantize_dbias in Python instead
  NVTE_CHECK(!(is_gated(act_type) && is_dbias), "Unsupported DGatedActedDBias Fusion!");
  NVTE_CHECK(!(scaling_mode == NVTEScalingMode::NVTE_DELAYED_TENSOR_SCALING && is_2x &&
               is_gated(act_type)),
             "TE/common does not support delayed scaling for 2x with gated activations.");

  if (is_dbias) {
    switch (act_type) {
      case NVTE_Activation_Type::GELU:
        nvte_quantize_dbias_dgelu(input_tensor.data(), act_input_tensor.data(),
                                  output_tensor.data(), dbias_tensor.data(),
                                  workspace_tensor.data(), stream);
        break;
      case NVTE_Activation_Type::SILU:
        nvte_quantize_dbias_dsilu(input_tensor.data(), act_input_tensor.data(),
                                  output_tensor.data(), dbias_tensor.data(),
                                  workspace_tensor.data(), stream);
        break;
      case NVTE_Activation_Type::RELU:
        nvte_quantize_dbias_drelu(input_tensor.data(), act_input_tensor.data(),
                                  output_tensor.data(), dbias_tensor.data(),
                                  workspace_tensor.data(), stream);
        break;
      case NVTE_Activation_Type::QGELU:
        nvte_quantize_dbias_dqgelu(input_tensor.data(), act_input_tensor.data(),
                                   output_tensor.data(), dbias_tensor.data(),
                                   workspace_tensor.data(), stream);
        break;
      case NVTE_Activation_Type::SRELU:
        nvte_quantize_dbias_dsrelu(input_tensor.data(), act_input_tensor.data(),
                                   output_tensor.data(), dbias_tensor.data(),
                                   workspace_tensor.data(), stream);
        break;
      default:
        NVTE_ERROR("Unsupported ActivationEnum = ", act_enum, " with dbias = True");
        break;
    }
  } else {
    switch (act_type) {
      case NVTE_Activation_Type::GELU:
        nvte_dgelu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
        break;
      case NVTE_Activation_Type::SILU:
        nvte_dsilu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
        break;
      case NVTE_Activation_Type::RELU:
        nvte_drelu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
        break;
      case NVTE_Activation_Type::QGELU:
        nvte_dqgelu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
        break;
      case NVTE_Activation_Type::SRELU:
        nvte_dsrelu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
        break;
      case NVTE_Activation_Type::GEGLU:
        nvte_dgeglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
        break;
      case NVTE_Activation_Type::SWIGLU:
        nvte_dswiglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
        break;
      case NVTE_Activation_Type::REGLU:
        nvte_dreglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
        break;
      case NVTE_Activation_Type::QGEGLU:
        nvte_dqgeglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
        break;
      case NVTE_Activation_Type::SREGLU:
        nvte_dsreglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream);
        break;
      default:
        NVTE_ERROR("Unsupported ActivationEnum");
        break;
    }
  }
  return ffi_with_cuda_error_check();
}
XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuDBiasCastTransposeHandler, DActLuDBiasCastTransposeFFI,
                              FFI::Bind()
                                  .Ctx<FFI_Stream_Type>()  // stream
                                  .Arg<Buffer_Type>()      // input
                                  .Arg<Buffer_Type>()      // act_input
                                  .Arg<Buffer_Type>()      // amax
                                  .Arg<Buffer_Type>()      // scale
                                  .Arg<Buffer_Type>()      // scale_inv
                                  .Ret<Buffer_Type>()      // output
                                  .Ret<Buffer_Type>()      // output_trans
                                  .Ret<Buffer_Type>()      // dbias
                                  .Ret<Buffer_Type>()      // amax_out
                                  .Ret<Buffer_Type>()      // workspace
                                  .Attr<int64_t>("act_enum"),
                              FFI_CudaGraph_Traits);
void DGatedActLuCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
                              size_t opaque_len) {
  auto *input = buffers[0];
  auto *act_input = buffers[1];
  float *amax = reinterpret_cast<float *>(buffers[2]);
  float *scale = reinterpret_cast<float *>(buffers[3]);
  float *scale_inv = reinterpret_cast<float *>(buffers[4]);
  auto *output = buffers[5];
  auto *output_trans = buffers[6];
  float *amax_out = reinterpret_cast<float *>(buffers[7]);
  const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
  NVTE_CHECK(amax == amax_out,
             "amax not bound to amax_out in TE/JAX DGatedActLuCastTranspose primitive.");
  if (!use_fp8(desc.out_dtype)) {
    scale = nullptr;
    scale_inv = nullptr;
    amax_out = nullptr;
  }
  auto m = desc.shape.dims[0];
  auto n = desc.shape.dims[1];
  auto act_enum = static_cast<NVTE_Activation_Type>(desc.act_enum);

  auto input_shape = desc.shape.to_vector();
  auto act_input_shape = std::vector<size_t>{m, n * 2};
  auto output_shape = std::vector<size_t>{m, n * 2};
  auto output_trans_shape = std::vector<size_t>{n * 2, m};

  auto input_tensor = TensorWrapper(input, input_shape, desc.in_dtype);
  auto act_input_tensor = TensorWrapper(act_input, act_input_shape, desc.in_dtype);
  auto output_tensor =
      TensorWrapper(output, output_shape, desc.out_dtype, amax_out, scale, scale_inv);
  output_tensor.set_columnwise_data(output_trans, desc.out_dtype, output_trans_shape);
  output_tensor.set_columnwise_scale_inv(scale_inv, DType::kFloat32, std::vector<size_t>{1});

  switch (act_enum) {
    case NVTE_Activation_Type::GEGLU:
      nvte_dgeglu_cast_transpose(input_tensor.data(), act_input_tensor.data(),
                                 output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::SWIGLU:
      nvte_dswiglu_cast_transpose(input_tensor.data(), act_input_tensor.data(),
                                  output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::REGLU:
      nvte_dreglu_cast_transpose(input_tensor.data(), act_input_tensor.data(),
                                 output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::QGEGLU:
      nvte_dqgeglu_cast_transpose(input_tensor.data(), act_input_tensor.data(),
                                  output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::SREGLU:
      nvte_dsreglu_cast_transpose(input_tensor.data(), act_input_tensor.data(),
                                  output_tensor.data(), stream);
      break;
    default:
      NVTE_ERROR("Unsupported ActivationEnum");
      break;
  }
}
Error_Type DGatedActLuCastTransposeFFI(cudaStream_t stream, Buffer_Type input_buf,
                                       Buffer_Type act_input_buf, Buffer_Type amax_buf,
                                       Buffer_Type scale_buf, Buffer_Type scale_inv_buf,
                                       Result_Type output_buf, Result_Type output_trans_buf,
                                       Result_Type amax_out_buf, int64_t act_enum) {
  auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type());
  auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type());

  auto *input = input_buf.untyped_data();
  auto *act_input = act_input_buf.untyped_data();
  float *amax = reinterpret_cast<float *>(amax_buf.untyped_data());
  float *scale = reinterpret_cast<float *>(scale_buf.untyped_data());
  float *scale_inv = reinterpret_cast<float *>(scale_inv_buf.untyped_data());
  auto *output = output_buf->untyped_data();
  auto *output_trans = output_trans_buf->untyped_data();
  float *amax_out = reinterpret_cast<float *>(amax_out_buf->untyped_data());
  NVTE_CHECK(amax == amax_out,
             "amax not bound to amax_out in TE/JAX DGatedActLuCastTranspose primitive.");
  if (!use_fp8(out_dtype)) {
    scale = nullptr;
    scale_inv = nullptr;
    amax_out = nullptr;
  }

  auto input_dims = input_buf.dimensions();
  auto act_input_dims = act_input_buf.dimensions();
  auto act_input_ranks = act_input_dims.size();
  auto m = product(act_input_dims, 0, act_input_ranks - 2);
  auto n = product(act_input_dims, act_input_ranks - 1, act_input_ranks);
  auto input_shape = std::vector<size_t>{m, n};
  auto act_input_shape = std::vector<size_t>{m, n * 2};
  auto output_shape = std::vector<size_t>{m, n * 2};
  auto output_trans_shape = std::vector<size_t>{n * 2, m};

  auto input_tensor = TensorWrapper(input, input_shape, in_dtype);
  auto act_input_tensor = TensorWrapper(act_input, act_input_shape, in_dtype);
  auto output_tensor =
      TensorWrapper(output, output_shape, out_dtype, amax_out, scale, scale_inv);
  output_tensor.set_columnwise_data(output_trans, out_dtype, output_trans_shape);
  output_tensor.set_columnwise_scale_inv(scale_inv, DType::kFloat32, std::vector<size_t>{1});

  auto act_type = static_cast<NVTE_Activation_Type>(act_enum);
  switch (act_type) {
    case NVTE_Activation_Type::GEGLU:
      nvte_dgeglu_cast_transpose(input_tensor.data(), act_input_tensor.data(),
                                 output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::SWIGLU:
      nvte_dswiglu_cast_transpose(input_tensor.data(), act_input_tensor.data(),
                                  output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::REGLU:
      nvte_dreglu_cast_transpose(input_tensor.data(), act_input_tensor.data(),
                                 output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::QGEGLU:
      nvte_dqgeglu_cast_transpose(input_tensor.data(), act_input_tensor.data(),
                                  output_tensor.data(), stream);
      break;
    case NVTE_Activation_Type::SREGLU:
      nvte_dsreglu_cast_transpose(input_tensor.data(), act_input_tensor.data(),
                                  output_tensor.data(), stream);
      break;
    default:
      NVTE_ERROR("Unsupported ActivationEnum");
      break;
  }
  return ffi_with_cuda_error_check();
}
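The columnwise ("transposed") copy and its scale_inv ride along inside the same TensorWrapper via `set_columnwise_data` / `set_columnwise_scale_inv`, so one kernel launch fills both layouts. The shape flattening collapses an N-D gated-activation input of shape (..., 2, hidden) into 2-D `{m, 2n}`. A small sketch of that arithmetic, assuming `product(dims, start, end)` multiplies the dimensions in the half-open range `[start, end)`, which matches how it is invoked above:

```cpp
#include <cstddef>
#include <cstdio>
#include <vector>

// Assumed semantics of the product() helper used above: multiply dims in [start, end).
std::size_t product(const std::vector<std::size_t> &dims, std::size_t start, std::size_t end) {
  std::size_t p = 1;
  for (std::size_t i = start; i < end; ++i) p *= dims[i];
  return p;
}

int main() {
  // A gated-activation input carries gate and value halves: (..., 2, hidden).
  std::vector<std::size_t> act_input_dims = {8, 128, 2, 1024};
  std::size_t rank = act_input_dims.size();
  std::size_t m = product(act_input_dims, 0, rank - 2);     // 8 * 128 = 1024 rows
  std::size_t n = product(act_input_dims, rank - 1, rank);  // 1024 (last dim only)
  std::printf("input {%zu, %zu}, act_input/output {%zu, %zu}, transpose {%zu, %zu}\n",
              m, n, m, 2 * n, 2 * n, m);
  return 0;
}
```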
XLA_FFI_DEFINE_HANDLER_SYMBOL(DGatedActLuCastTransposeHandler, DGatedActLuCastTransposeFFI,
                              FFI::Bind()
                                  .Ctx<FFI_Stream_Type>()  // stream
                                  .Arg<Buffer_Type>()      // input
                                  .Arg<Buffer_Type>()      // act_input
                                  .Arg<Buffer_Type>()      // amax
                                  .Arg<Buffer_Type>()      // scale
                                  .Arg<Buffer_Type>()      // scale_inv
                                  .Ret<Buffer_Type>()      // output
                                  .Ret<Buffer_Type>()      // output_trans
                                  .Ret<Buffer_Type>()      // amax_out
                                  .Attr<int64_t>("act_enum"),
                              FFI_CudaGraph_Traits);

XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuDBiasQuantizeHandler, DActLuDBiasQuantizeFFI,
                              FFI::Bind()
                                  .Ctx<FFI_Stream_Type>()  // stream
                                  .Arg<Buffer_Type>()      // input
                                  .Arg<Buffer_Type>()      // act input
                                  .Arg<Buffer_Type>()      // scale
                                  .Ret<Buffer_Type>()      // output
                                  .Ret<Buffer_Type>()      // colwise output
                                  .Ret<Buffer_Type>()      // scale_inv
                                  .Ret<Buffer_Type>()      // scale_inv colwise
                                  .Ret<Buffer_Type>()      // amax
                                  .Ret<Buffer_Type>()      // dbias
                                  .Ret<Buffer_Type>()      // wkspace
                                  .Attr<int64_t>("scaling_mode")
                                  .Attr<bool>("is_2x")
                                  .Attr<bool>("is_dbias")
                                  .Attr<int64_t>("act_enum"),
                              FFI_CudaGraph_Traits);
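In the XLA FFI API, the `Bind()` chain must mirror the callback's parameter order exactly: one `Ctx`, then every `Arg`, then every `Ret`, then the named `Attr`s. A hypothetical handler sketching that correspondence, reusing the TE aliases this file already relies on (`Error_Type`, `Buffer_Type`, `Result_Type`, `FFI`, `FFI_Stream_Type`, assumed to come from `extensions.h`); `ExampleQuantizeFFI` does not exist in TE:

```cpp
// Hypothetical example; parameter order matches the Bind() chain below.
Error_Type ExampleQuantizeFFI(cudaStream_t stream,     // .Ctx<FFI_Stream_Type>()
                              Buffer_Type input_buf,   // .Arg<Buffer_Type>()
                              Buffer_Type scale_buf,   // .Arg<Buffer_Type>()
                              Result_Type output_buf,  // .Ret<Buffer_Type>()
                              Result_Type amax_buf,    // .Ret<Buffer_Type>()
                              int64_t act_enum) {      // .Attr<int64_t>("act_enum")
  // ... launch work on `stream`, writing into output_buf / amax_buf ...
  return ffi_with_cuda_error_check();
}

XLA_FFI_DEFINE_HANDLER_SYMBOL(ExampleQuantizeHandler, ExampleQuantizeFFI,
                              FFI::Bind()
                                  .Ctx<FFI_Stream_Type>()  // stream
                                  .Arg<Buffer_Type>()      // input
                                  .Arg<Buffer_Type>()      // scale
                                  .Ret<Buffer_Type>()      // output
                                  .Ret<Buffer_Type>()      // amax
                                  .Attr<int64_t>("act_enum"));
```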
}  // namespace jax
}  // namespace transformer_engine
transformer_engine/jax/csrc/extensions/attention.cpp
...
...
@@ -301,39 +301,6 @@ static void FusedAttnForwardImpl(
  nvte_tensor_pack_destroy(&aux_output_tensors);
}

void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque,
                      size_t opaque_len) {
  const CustomCallFusedAttnDescriptor &descriptor =
      *UnpackOpaque<CustomCallFusedAttnDescriptor>(opaque, opaque_len);

  auto is_ragged = nvte_get_qkv_format(descriptor.qkv_layout) == NVTE_QKV_Format::NVTE_THD;

  /* Input buffers from XLA */
  void *q = buffers[0];
  void *k = buffers[1];
  void *v = buffers[2];
  void *bias = buffers[3];
  void *seed = buffers[4];
  void *q_cu_seqlens = buffers[5];
  void *kv_cu_seqlens = buffers[6];
  void *q_seq_offsets = is_ragged ? buffers[7] : nullptr;
  void *k_seq_offsets = is_ragged ? buffers[8] : nullptr;

  /* Output buffers from XLA */
  void *output = buffers[9];
  void *softmax_aux = buffers[10];
  void *rng_state = buffers[11];
  void *workspace = buffers[12];

  FusedAttnForwardImpl(
      stream, q, k, v, bias, seed, q_cu_seqlens, kv_cu_seqlens, q_seq_offsets, k_seq_offsets,
      output, softmax_aux, rng_state, workspace, descriptor.input_batch, descriptor.bias_batch,
      descriptor.q_max_seqlen, descriptor.kv_max_seqlen, descriptor.attn_heads,
      descriptor.num_gqa_groups, descriptor.bias_heads, descriptor.head_dim,
      descriptor.max_segments_per_seq, descriptor.wkspace_size, descriptor.scaling_factor,
      descriptor.dropout_probability, descriptor.bias_type, descriptor.mask_type,
      descriptor.qkv_layout, descriptor.dtype, descriptor.wkspace_dtype, descriptor.is_training,
      descriptor.deterministic, descriptor.window_size_left, descriptor.window_size_right);
}
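The legacy custom-call path shown here carries all scalar configuration in an opaque byte string: the host packs a POD descriptor, and the callback recovers it with `UnpackOpaque`. A hedged, self-contained sketch of that round trip (the real `UnpackOpaque` and `CustomCallFusedAttnDescriptor` may differ; `DescriptorExample` is a stand-in):

```cpp
#include <cassert>
#include <cstddef>
#include <string>

// Sketch: size-checked reinterpretation of the opaque buffer as a POD struct.
template <typename T>
const T *UnpackOpaque(const char *opaque, std::size_t opaque_len) {
  assert(opaque_len == sizeof(T) && "opaque buffer size mismatch");
  return reinterpret_cast<const T *>(opaque);
}

struct DescriptorExample {  // stand-in for CustomCallFusedAttnDescriptor
  std::size_t input_batch;
  float scaling_factor;
  bool is_training;
};

int main() {
  DescriptorExample d{4, 0.125f, true};
  // Host side: serialize the descriptor into the opaque string XLA passes through.
  std::string opaque(reinterpret_cast<const char *>(&d), sizeof(d));
  // Device-callback side: recover it, exactly as FusedAttnForward does above.
  const DescriptorExample &back =
      *UnpackOpaque<DescriptorExample>(opaque.data(), opaque.size());
  assert(back.input_batch == 4 && back.is_training);
  return 0;
}
```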
#define FUSED_ATTN_FFI_GET_ATTRS \
size_t input_batch = get_attr_value<int64_t>(attrs, "input_batch"); \
size_t bias_batch = get_attr_value<int64_t>(attrs, "bias_batch"); \
...
...
@@ -608,45 +575,6 @@ static void FusedAttnBackwardImpl(
  nvte_tensor_pack_destroy(&aux_input_tensors);
}

void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque,
                       size_t opaque_len) {
  const CustomCallFusedAttnDescriptor &descriptor =
      *UnpackOpaque<CustomCallFusedAttnDescriptor>(opaque, opaque_len);

  auto qkv_layout = descriptor.qkv_layout;
  auto is_ragged = nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD;

  /* Input buffers from XLA */
  void *q = buffers[0];
  void *k = buffers[1];
  void *v = buffers[2];
  void *bias = buffers[3];
  void *softmax_aux = buffers[4];
  void *rng_state = buffers[5];
  void *output = buffers[6];
  void *doutput = buffers[7];
  void *q_cu_seqlens = buffers[8];
  void *kv_cu_seqlens = buffers[9];
  void *q_seq_offsets = is_ragged ? buffers[10] : nullptr;
  void *k_seq_offsets = is_ragged ? buffers[11] : nullptr;

  /* Output buffers from XLA */
  void *dq = buffers[12];
  void *dk = buffers[13];
  void *dv = buffers[14];
  void *dbias = buffers[15];
  void *workspace = buffers[16];

  FusedAttnBackwardImpl(
      stream, q, k, v, bias, softmax_aux, rng_state, output, doutput, q_cu_seqlens,
      kv_cu_seqlens, q_seq_offsets, k_seq_offsets, dq, dk, dv, dbias, workspace,
      descriptor.input_batch, descriptor.bias_batch, descriptor.q_max_seqlen,
      descriptor.kv_max_seqlen, descriptor.attn_heads, descriptor.num_gqa_groups,
      descriptor.bias_heads, descriptor.head_dim, descriptor.max_segments_per_seq,
      descriptor.wkspace_size, descriptor.scaling_factor, descriptor.dropout_probability,
      descriptor.bias_type, descriptor.mask_type, descriptor.qkv_layout, descriptor.dtype,
      descriptor.wkspace_dtype, descriptor.is_training, descriptor.deterministic,
      descriptor.window_size_left, descriptor.window_size_right);
}
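Both wrappers depend on a fixed positional convention: XLA concatenates all operands and results into one `void**` array, and the two `*_seq_offsets` slots only carry real data for the ragged THD layout. A short sketch of that convention, with slot numbers taken from `FusedAttnBackward` above:

```cpp
#include <cstddef>

// Slot indices from FusedAttnBackward above: inputs occupy 0-11, outputs 12-16.
enum BwdSlot : std::size_t {
  kQ = 0, kK, kV, kBias, kSoftmaxAux, kRngState, kOutput, kDOutput,
  kQCuSeqlens, kKvCuSeqlens, kQSeqOffsets, kKSeqOffsets,
  kDq, kDk, kDv, kDbias, kWorkspace
};

// Offsets are only meaningful for the ragged THD layout; otherwise pass nullptr
// so the implementation can distinguish the two cases.
inline void *seq_offsets_or_null(void **buffers, std::size_t slot, bool is_ragged) {
  return is_ragged ? buffers[slot] : nullptr;
}
```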
Error_Type FusedAttnBackwardFFI(cudaStream_t stream, Buffer_Type q_buf, Buffer_Type k_buf,
                                Buffer_Type v_buf, Buffer_Type bias_buf,
                                Buffer_Type softmax_aux_buf, Buffer_Type rng_state_buf,
...
...
transformer_engine/jax/csrc/extensions/cublas.cpp
0 → 100644
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "extensions.h"
#include "transformer_engine/gemm.h"
#include "xla/ffi/api/c_api.h"
namespace transformer_engine {
namespace jax {

Error_Type CublasHandleInitFFI(Variadic_Buffer_Type args, Variadic_Result_Type rets,
                               Dictionary attrs) {
  nvte_cublas_handle_init();
  return ffi_with_cuda_error_check();
}

XLA_FFI_DEFINE_HANDLER_SYMBOL(CublasHandleInitHandler, CublasHandleInitFFI,
                              FFI::Bind<FFI_Prepare>().RemainingArgs().RemainingRets().Attrs());

}  // namespace jax
}  // namespace transformer_engine
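Binding with `FFI_Prepare` (plus `RemainingArgs`/`RemainingRets`, so any call shape is accepted) runs this handler during XLA's prepare stage, before execution, paying the cost of cuBLAS handle creation once up front rather than on the first GEMM. A hedged sketch of the lazy, thread-safe initialization such a hook typically guards; the real `nvte_cublas_handle_init` may manage handles per device or per thread and differ in detail:

```cpp
#include <cublas_v2.h>
#include <mutex>

// One process-wide handle for this sketch; creation is expensive, so do it once.
cublasHandle_t get_cublas_handle() {
  static cublasHandle_t handle = nullptr;
  static std::once_flag once;
  std::call_once(once, [] { cublasCreate(&handle); });
  return handle;
}
```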
transformer_engine/jax/csrc/extensions/ffi.cpp
...
...
@@ -13,8 +13,9 @@ namespace jax {
// For XLA_FFI_DataType Enum Reference: https://github.com/openxla/xla/blob/d054e8366c4e8807726961feeb28b1cdba681888/xla/ffi/api/c_api.h#L163-L186
DType convert_ffi_datatype_to_te_dtype(const xla::ffi::DataType &type) {
  switch (type) {
    // Using this for E8M0
    case xla::ffi::DataType::U8:
      return DType::kFloat8E8M0;
      break;
    case xla::ffi::DataType::S32:
      return DType::kInt32;
...
...
@@ -37,8 +38,12 @@ DType convert_ffi_datatype_to_te_dtype(const xla::ffi::DataType &type) {
    case xla::ffi::DataType::F8E4M3FN:
      return DType::kFloat8E4M3;
      break;
    // case xla::ffi::DataType::F8E8M0FNU:
    //   return DType::kFloat8E8M0;
    //   break;
    default:
      auto type_num = static_cast<XLA_FFI_DataType>(type);
      if (type_num == 33) return DType::kFloat8E8M0;
      NVTE_ERROR("TE does not support conversion of XLA_FFI_DataType %d",
                 static_cast<int>(type_num));
      break;
...
...
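The default branch above works around XLA headers that may predate the `F8E8M0FNU` enumerator: instead of naming the case (left commented out), it compares the raw enum value 33. A self-contained sketch of that pattern; the enum and its named values below are illustrative placeholders, and only the constant 33 comes from the hunk above:

```cpp
#include <cstdio>

// Illustrative wire-format enum; the named values here are placeholders, not XLA's.
enum class WireDType : int { F32 = 0, F8E4M3FN = 1 /* newer entries may be absent */ };

const char *to_te_dtype_name(WireDType type) {
  switch (type) {
    case WireDType::F32:      return "kFloat32";
    case WireDType::F8E4M3FN: return "kFloat8E4M3";
    default: {
      int type_num = static_cast<int>(type);
      if (type_num == 33) return "kFloat8E8M0";  // raw value compared in the hunk above
      std::fprintf(stderr, "unsupported dtype %d\n", type_num);
      return "unsupported";
    }
  }
}
```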