Unverified Commit 221fedc2 authored by Phuong Nguyen's avatar Phuong Nguyen Committed by GitHub
Browse files

[JAX] Exclude GroupedGemm APIs for TE 2.3 (#1737)



* exclude GroupedGemm APIs
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>

---------
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>
parent dac098d8
......@@ -39,7 +39,7 @@ from transformer_engine.jax.quantize import (
)
from transformer_engine.jax.quantize import helper
from transformer_engine.jax.activation import activation
from transformer_engine.jax.dense import dense, grouped_dense
from transformer_engine.jax.dense import dense
from transformer_engine.jax.layernorm_dense import layernorm_dense
from transformer_engine.jax.quantize import ScaledTensor1x, ScaledTensor2x
......@@ -1138,7 +1138,7 @@ fwd_bwd_dtypes = [
[jnp.float8_e5m2, jnp.float8_e4m3fn],
]
"""
@pytest_parametrize_wrapper(
"shape_list", [[(512, 128, 256), (256, 128, 256), (256, 128, 128), (512, 256, 128)]]
)
......@@ -1353,3 +1353,4 @@ class TestGroupedDense:
assert_allclose(primitive_dgrad_list[i], ref_dgrad_list[i], dtype=allclose_dtype)
assert_allclose(primitive_wgrad_list[i], ref_wgrad_list[i], dtype=allclose_dtype)
assert_allclose(primitive_dbias_list[i], ref_dbias_list[i], dtype=allclose_dtype)
"""
......@@ -3,7 +3,7 @@
# See LICENSE for license information.
"""JAX te modules"""
from typing import Tuple, Sequence, Union, Dict, List
from typing import Tuple, Sequence, Union, Dict
from functools import partial, reduce
import operator
import jax
......@@ -21,7 +21,7 @@ from ..quantize import (
)
__all__ = ["gemm", "grouped_gemm"]
__all__ = ["gemm"]
num_cublas_streams = 4
......@@ -338,8 +338,9 @@ def gemm(
return _jax_gemm(lhs, rhs, contracting_dims, quantizer_set)
"""
def swizzled_scale(scales):
"""Swizzle the scale tensor for FP8 GEMM"""
# Swizzle the scale tensor for FP8 GEMM
assert scales.ndim == 2
rows, cols = scales.shape
scales = scales.reshape(rows // 128, 4, 32, cols // 4, 4)
......@@ -354,7 +355,7 @@ def grouped_gemm(
contracting_dims_list: List[Tuple[Sequence[int], Sequence[int]]],
bias_list: List[jnp.ndarray] = None,
) -> List[jnp.ndarray]:
"""Grouped GEMM for multiple pairs of tensors."""
# Grouped GEMM for multiple pairs of tensors.
assert (
len(lhs_list) == len(rhs_list) == len(contracting_dims_list)
), "lhs_list, rhs_list, contracting_dims_list must have the same length"
......@@ -463,3 +464,4 @@ def grouped_gemm(
)
return out_list
"""
......@@ -183,6 +183,7 @@ def _dense_bwd_rule(
_dense.defvjp(_dense_fwd_rule, _dense_bwd_rule)
"""
def grouped_dense(
x_list,
kernel_list,
......@@ -190,10 +191,8 @@ def grouped_dense(
contracting_dims_list,
quantizer_set_list=None,
):
"""
Perform grouped_dense layer transformation with optional quantization.
# Perform grouped_dense layer transformation with optional quantization.
"""
output_list = _grouped_dense(
x_list, kernel_list, bias_list, contracting_dims_list, quantizer_set_list
)
......@@ -315,3 +314,4 @@ def _grouped_dense_bwd_rule(contracting_dims_list, ctx, grad_list):
_grouped_dense.defvjp(_grouped_dense_fwd_rule, _grouped_dense_bwd_rule)
"""
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment