Unverified Commit 221fedc2 authored by Phuong Nguyen's avatar Phuong Nguyen Committed by GitHub
Browse files

[JAX] Exclude GroupedGemm APIs for TE 2.3 (#1737)



* exclude GroupedGemm APIs
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>

---------
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>
parent dac098d8
...@@ -39,7 +39,7 @@ from transformer_engine.jax.quantize import ( ...@@ -39,7 +39,7 @@ from transformer_engine.jax.quantize import (
) )
from transformer_engine.jax.quantize import helper from transformer_engine.jax.quantize import helper
from transformer_engine.jax.activation import activation from transformer_engine.jax.activation import activation
from transformer_engine.jax.dense import dense, grouped_dense from transformer_engine.jax.dense import dense
from transformer_engine.jax.layernorm_dense import layernorm_dense from transformer_engine.jax.layernorm_dense import layernorm_dense
from transformer_engine.jax.quantize import ScaledTensor1x, ScaledTensor2x from transformer_engine.jax.quantize import ScaledTensor1x, ScaledTensor2x
...@@ -1138,7 +1138,7 @@ fwd_bwd_dtypes = [ ...@@ -1138,7 +1138,7 @@ fwd_bwd_dtypes = [
[jnp.float8_e5m2, jnp.float8_e4m3fn], [jnp.float8_e5m2, jnp.float8_e4m3fn],
] ]
"""
@pytest_parametrize_wrapper( @pytest_parametrize_wrapper(
"shape_list", [[(512, 128, 256), (256, 128, 256), (256, 128, 128), (512, 256, 128)]] "shape_list", [[(512, 128, 256), (256, 128, 256), (256, 128, 128), (512, 256, 128)]]
) )
...@@ -1353,3 +1353,4 @@ class TestGroupedDense: ...@@ -1353,3 +1353,4 @@ class TestGroupedDense:
assert_allclose(primitive_dgrad_list[i], ref_dgrad_list[i], dtype=allclose_dtype) assert_allclose(primitive_dgrad_list[i], ref_dgrad_list[i], dtype=allclose_dtype)
assert_allclose(primitive_wgrad_list[i], ref_wgrad_list[i], dtype=allclose_dtype) assert_allclose(primitive_wgrad_list[i], ref_wgrad_list[i], dtype=allclose_dtype)
assert_allclose(primitive_dbias_list[i], ref_dbias_list[i], dtype=allclose_dtype) assert_allclose(primitive_dbias_list[i], ref_dbias_list[i], dtype=allclose_dtype)
"""
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
# See LICENSE for license information. # See LICENSE for license information.
"""JAX te modules""" """JAX te modules"""
from typing import Tuple, Sequence, Union, Dict, List from typing import Tuple, Sequence, Union, Dict
from functools import partial, reduce from functools import partial, reduce
import operator import operator
import jax import jax
...@@ -21,7 +21,7 @@ from ..quantize import ( ...@@ -21,7 +21,7 @@ from ..quantize import (
) )
__all__ = ["gemm", "grouped_gemm"] __all__ = ["gemm"]
num_cublas_streams = 4 num_cublas_streams = 4
...@@ -338,8 +338,9 @@ def gemm( ...@@ -338,8 +338,9 @@ def gemm(
return _jax_gemm(lhs, rhs, contracting_dims, quantizer_set) return _jax_gemm(lhs, rhs, contracting_dims, quantizer_set)
"""
def swizzled_scale(scales): def swizzled_scale(scales):
"""Swizzle the scale tensor for FP8 GEMM""" # Swizzle the scale tensor for FP8 GEMM
assert scales.ndim == 2 assert scales.ndim == 2
rows, cols = scales.shape rows, cols = scales.shape
scales = scales.reshape(rows // 128, 4, 32, cols // 4, 4) scales = scales.reshape(rows // 128, 4, 32, cols // 4, 4)
...@@ -354,7 +355,7 @@ def grouped_gemm( ...@@ -354,7 +355,7 @@ def grouped_gemm(
contracting_dims_list: List[Tuple[Sequence[int], Sequence[int]]], contracting_dims_list: List[Tuple[Sequence[int], Sequence[int]]],
bias_list: List[jnp.ndarray] = None, bias_list: List[jnp.ndarray] = None,
) -> List[jnp.ndarray]: ) -> List[jnp.ndarray]:
"""Grouped GEMM for multiple pairs of tensors.""" # Grouped GEMM for multiple pairs of tensors.
assert ( assert (
len(lhs_list) == len(rhs_list) == len(contracting_dims_list) len(lhs_list) == len(rhs_list) == len(contracting_dims_list)
), "lhs_list, rhs_list, contracting_dims_list must have the same length" ), "lhs_list, rhs_list, contracting_dims_list must have the same length"
...@@ -463,3 +464,4 @@ def grouped_gemm( ...@@ -463,3 +464,4 @@ def grouped_gemm(
) )
return out_list return out_list
"""
...@@ -183,6 +183,7 @@ def _dense_bwd_rule( ...@@ -183,6 +183,7 @@ def _dense_bwd_rule(
_dense.defvjp(_dense_fwd_rule, _dense_bwd_rule) _dense.defvjp(_dense_fwd_rule, _dense_bwd_rule)
"""
def grouped_dense( def grouped_dense(
x_list, x_list,
kernel_list, kernel_list,
...@@ -190,10 +191,8 @@ def grouped_dense( ...@@ -190,10 +191,8 @@ def grouped_dense(
contracting_dims_list, contracting_dims_list,
quantizer_set_list=None, quantizer_set_list=None,
): ):
""" # Perform grouped_dense layer transformation with optional quantization.
Perform grouped_dense layer transformation with optional quantization.
"""
output_list = _grouped_dense( output_list = _grouped_dense(
x_list, kernel_list, bias_list, contracting_dims_list, quantizer_set_list x_list, kernel_list, bias_list, contracting_dims_list, quantizer_set_list
) )
...@@ -315,3 +314,4 @@ def _grouped_dense_bwd_rule(contracting_dims_list, ctx, grad_list): ...@@ -315,3 +314,4 @@ def _grouped_dense_bwd_rule(contracting_dims_list, ctx, grad_list):
_grouped_dense.defvjp(_grouped_dense_fwd_rule, _grouped_dense_bwd_rule) _grouped_dense.defvjp(_grouped_dense_fwd_rule, _grouped_dense_bwd_rule)
"""
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment