Unverified Commit ed42b5ac authored by Phuong Nguyen's avatar Phuong Nguyen Committed by GitHub
Browse files

[JAX] Remove `dot_1_output_axes` usage in LayerNormMLP (#2029)



* remove dot1_output_axes
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>

---------
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>
parent 7101f4be
...@@ -31,7 +31,6 @@ from ..cpp_extensions import ( ...@@ -31,7 +31,6 @@ from ..cpp_extensions import (
jax_scaled_upper_triang_masked_softmax, jax_scaled_upper_triang_masked_softmax,
) )
from ..quantize import QuantizerFactory, QuantizeConfig, QuantizeMeta, QuantizeMetaSet, ScalingMode from ..quantize import QuantizerFactory, QuantizeConfig, QuantizeMeta, QuantizeMetaSet, ScalingMode
from ..sharding import get_non_contracting_logical_axes
PRNGKey = Any PRNGKey = Any
Shape = Tuple[int, ...] Shape = Tuple[int, ...]
...@@ -1206,15 +1205,6 @@ class LayerNormMLP(TransformerEngineBase): ...@@ -1206,15 +1205,6 @@ class LayerNormMLP(TransformerEngineBase):
quantizer_set=ffn1_quantizer_set, quantizer_set=ffn1_quantizer_set,
) )
if self.dot_1_input_axes is not None and self.kernel_axes_1 is not None:
dot_1_output_axes = (
*get_non_contracting_logical_axes(y.ndim, self.dot_1_input_axes, axis),
*get_non_contracting_logical_axes(
kernel_1.ndim, self.kernel_axes_1, contract_ind
),
)
x = with_sharding_constraint_by_logical_axes(x, dot_1_output_axes)
if self.enable_low_rank_adaptation: if self.enable_low_rank_adaptation:
wi_lora_a_kernel_each_shape = ( wi_lora_a_kernel_each_shape = (
kernel_1_each_shape[: len(axis)], kernel_1_each_shape[: len(axis)],
......
...@@ -30,7 +30,6 @@ from .quantize import ( ...@@ -30,7 +30,6 @@ from .quantize import (
TensorUsage, TensorUsage,
) )
from .sharding import ( from .sharding import (
get_non_contracting_logical_axes,
get_sequence_parallel_dim, get_sequence_parallel_dim,
) )
...@@ -259,7 +258,7 @@ def _layernorm_mlp_fwd_rule( ...@@ -259,7 +258,7 @@ def _layernorm_mlp_fwd_rule(
Returns: Returns:
Tuple of (output, context) for automatic differentiation Tuple of (output, context) for automatic differentiation
""" """
del kernel_2_axes del kernel_1_axes, kernel_2_axes
ffn1_quantizer_set, ffn2_quantizer_set = quantizer_sets ffn1_quantizer_set, ffn2_quantizer_set = quantizer_sets
...@@ -318,13 +317,6 @@ def _layernorm_mlp_fwd_rule( ...@@ -318,13 +317,6 @@ def _layernorm_mlp_fwd_rule(
fuse_bias=use_bias_1 if not tex.gemm_uses_jax_dot() else False, fuse_bias=use_bias_1 if not tex.gemm_uses_jax_dot() else False,
) )
if dot_1_input_axes is not None and kernel_1_axes is not None:
dot_1_output_axes = (
*get_non_contracting_logical_axes(x.ndim, dot_1_input_axes, x_contracting_dims),
*get_non_contracting_logical_axes(kernel_1.ndim, kernel_1_axes, k_contracting_dims),
)
dot_1_output = with_sharding_constraint_by_logical_axes(dot_1_output, dot_1_output_axes)
if use_bias_1 and tex.gemm_uses_jax_dot(): if use_bias_1 and tex.gemm_uses_jax_dot():
bias_1_shape = bias_1.shape bias_1_shape = bias_1.shape
bias_1_new_shape = (1,) * (dot_1_output.ndim - bias_1.ndim) + bias_1_shape bias_1_new_shape = (1,) * (dot_1_output.ndim - bias_1.ndim) + bias_1_shape
......
...@@ -427,24 +427,3 @@ class ShardingType(Enum): ...@@ -427,24 +427,3 @@ class ShardingType(Enum):
TP_ROW = (MajorShardingType.TP, "tp_row") TP_ROW = (MajorShardingType.TP, "tp_row")
DP_TP_COL = (MajorShardingType.DPTP, "dp_tp_col") DP_TP_COL = (MajorShardingType.DPTP, "dp_tp_col")
DP_TP_ROW = (MajorShardingType.DPTP, "dp_tp_row") DP_TP_ROW = (MajorShardingType.DPTP, "dp_tp_row")
def get_non_contracting_logical_axes(
    ndim, logical_axes: tuple[Optional[str]], contracting_dims
) -> tuple[Optional[str]]:
    """Return the logical axis names of the dimensions that are not contracted.

    Args:
        ndim: Number of dimensions in the tensor.
        logical_axes: Tuple of logical axis names (or None), one per dimension.
        contracting_dims: Collection of dimension indices being contracted.

    Returns:
        Tuple of logical axes for the non-contracting dimensions, in order.
    """
    assert logical_axes is not None, "Logical axes must be a tuple and cannot be None."
    assert len(logical_axes) == ndim, "Logical axes must match the number of dimensions."
    # Keep each axis name whose dimension index is not contracted, preserving order.
    return tuple(
        axis_name
        for dim, axis_name in enumerate(logical_axes)
        if dim not in contracting_dims
    )
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment