Unverified commit 27612051 authored by jberchtold-nvidia, committed by GitHub

[JAX] Support logical partitioning axes in TE Flax modules (#1772)



* [JAX] Update flax module param initialization to support logical partitioning axes
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* Fix ffn1 intermediate result being replicated
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* Lint
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* Add documentation and assert when logical_axes=None
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* Fix bias in LayerNormMLP flax module
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* Fix layer tests to not use nn_partitioning and instead use nn.with_logical_axes
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

---------
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
parent 7c474236
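The heart of the change is swapping `nn_partitioning.param_with_axes(...)` for Flax's native `self.param(...)` combined with `nn.with_logical_partitioning(...)`, so logical axis names travel with the parameter as metadata instead of through a side API. A minimal sketch of the new pattern (the module, axis names, and helper calls below are illustrative, not TE code):

```python
# Minimal sketch of the pattern adopted by this PR (illustrative names, not TE code).
import jax
import jax.numpy as jnp
from flax import linen as nn


class TinyDense(nn.Module):  # hypothetical stand-in for a TE Flax module
    features: int

    @nn.compact
    def __call__(self, x):
        # Before: nn_partitioning.param_with_axes("kernel", init, shape, dtype, axes=...)
        # After: self.param with a logically partitioned initializer; the axis names
        # ("embed", "mlp") are stored as metadata on the created parameter.
        kernel = self.param(
            "kernel",
            nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("embed", "mlp")),
            (x.shape[-1], self.features),
            jnp.float32,
        )
        return x @ kernel


variables = TinyDense(features=8).init(jax.random.PRNGKey(0), jnp.ones((4, 16)))
# The stored variables carry the logical names as metadata, which sharding utilities
# such as nn.get_partition_spec / nn.logical_to_mesh can consume.
```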
@@ -609,7 +609,7 @@ class TestEncoder(unittest.TestCase):
     def test_te_delayed_scaling_fp8(self):
         """Test Transformer Engine with DelayedScaling FP8"""
         result = self.exec(True, "DelayedScaling")
-        assert result[0] < 0.505 and result[1] > 0.754
+        assert result[0] < 0.505 and result[1] > 0.753

     @unittest.skipIf(
         not is_fp8_supported(), "Device compute capability 9.0+ is required for CurrentScaling FP8"
...
@@ -13,7 +13,6 @@ import jax
 import jax.numpy as jnp
 import numpy as np
 from flax import linen as nn
-from flax.linen import partitioning as nn_partitioning
 from flax.linen.attention import combine_masks
 from jax import lax, vmap
 from jax import nn as jax_nn
@@ -316,16 +315,22 @@ class DenseGeneral(nn.Module):
         kernel_shape = tuple(inputs.shape[ax] for ax in axis) + features
         kernel_param_shape = (np.prod([inputs.shape[ax] for ax in axis]), np.prod(features))
-        kernel = nn_partitioning.param_with_axes(
-            "kernel", self.kernel_init, kernel_param_shape, self.dtype, axes=self.kernel_axes
+        kernel = self.param(
+            "kernel",
+            nn.with_logical_partitioning(self.kernel_init, self.kernel_axes),
+            kernel_param_shape,
+            self.dtype,
         )
         kernel = jnp.asarray(kernel, input_dtype)
         kernel = jnp.reshape(kernel, kernel_shape)

         if self.use_bias:
-            bias = nn_partitioning.param_with_axes(
-                "bias", self.bias_init, self.features, self.dtype, axes=self.bias_axes
+            bias = self.param(
+                "bias",
+                nn.with_logical_partitioning(self.bias_init, self.bias_axes),
+                self.features,
+                self.dtype,
             )
             bias = bias.astype(input_dtype)
         else:
@@ -422,9 +427,9 @@ class MlpBlock(nn.Module):
         )  # Broadcast along length.
         if self.transpose_batch_sequence:
-            x = nn_partitioning.with_sharding_constraint(x, ("length", "batch", "mlp"))
+            x = nn.with_logical_constraint(x, ("length", "batch", "mlp"))
         else:
-            x = nn_partitioning.with_sharding_constraint(x, ("batch", "length", "mlp"))
+            x = nn.with_logical_constraint(x, ("batch", "length", "mlp"))
         output = DenseGeneral(
             inputs.shape[-1],
             dtype=self.dtype,
@@ -688,21 +693,13 @@ class MultiHeadAttention(nn.Module):
         value = value.reshape((*value.shape[:2], self.num_gqa_groups, self.head_dim))

         if self.transpose_batch_sequence:
-            query = nn_partitioning.with_sharding_constraint(
-                query, ("length", "batch", "heads", "kv")
-            )
-            key = nn_partitioning.with_sharding_constraint(key, ("length", "batch", "heads", "kv"))
-            value = nn_partitioning.with_sharding_constraint(
-                value, ("length", "batch", "heads", "kv")
-            )
+            query = nn.with_logical_constraint(query, ("length", "batch", "heads", "kv"))
+            key = nn.with_logical_constraint(key, ("length", "batch", "heads", "kv"))
+            value = nn.with_logical_constraint(value, ("length", "batch", "heads", "kv"))
         else:
-            query = nn_partitioning.with_sharding_constraint(
-                query, ("batch", "length", "heads", "kv")
-            )
-            key = nn_partitioning.with_sharding_constraint(key, ("batch", "length", "heads", "kv"))
-            value = nn_partitioning.with_sharding_constraint(
-                value, ("batch", "length", "heads", "kv")
-            )
+            query = nn.with_logical_constraint(query, ("batch", "length", "heads", "kv"))
+            key = nn.with_logical_constraint(key, ("batch", "length", "heads", "kv"))
+            value = nn.with_logical_constraint(value, ("batch", "length", "heads", "kv"))

         if decode:
             # Detect if we're initializing by absence of existing cache data.
@@ -809,9 +806,9 @@ class MultiHeadAttention(nn.Module):
         x = x.reshape((x.shape[0], x.shape[1], x.shape[2] * x.shape[3]))

         if self.transpose_batch_sequence:
-            x = nn_partitioning.with_sharding_constraint(x, ("length", "batch", "joined_kv"))
+            x = nn.with_logical_constraint(x, ("length", "batch", "joined_kv"))
         else:
-            x = nn_partitioning.with_sharding_constraint(x, ("batch", "length", "joined_kv"))
+            x = nn.with_logical_constraint(x, ("batch", "length", "joined_kv"))

         # Back to the original inputs dimensions.
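The `nn.with_logical_constraint` calls above only take effect when logical-to-mesh rules are in scope at trace time; otherwise they are effectively no-ops, which keeps the modules mesh-agnostic. A small sketch of how the logical names resolve (the rule set is an example, not TE's defaults):

```python
# Illustrative only: the logical -> mesh mapping is supplied by the caller.
from flax import linen as nn

rules = (("batch", "data"), ("length", None), ("heads", "model"), ("kv", None))

# Resolve the logical names used by the attention module into a PartitionSpec.
spec = nn.logical_to_mesh_axes(("batch", "length", "heads", "kv"), rules)
# spec == PartitionSpec("data", None, "model", None)

# At call time the constraints are applied under a device mesh and these rules:
#   with mesh, nn.logical_axis_rules(rules):
#       y = model.apply(variables, x)
```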
@@ -857,8 +854,11 @@ class LayerNorm(nn.Module):
         input_dtype = x.dtype
         features = x.shape[-1]
-        scale = nn_partitioning.param_with_axes(
-            "scale", self.scale_init, (features,), self.dtype, axes=("embed",)
+        scale = self.param(
+            "scale",
+            nn.with_logical_partitioning(self.scale_init, ("embed",)),
+            (features,),
+            self.dtype,
         )

         x_ = x.astype(jnp.float32)
         if self.layernorm_type == "layernorm":
@@ -866,8 +866,11 @@ class LayerNorm(nn.Module):
             var = jnp.mean(jnp.square(x_ - mean), axis=-1, keepdims=True)
             y = (x_ - mean) * lax.rsqrt(var + self.epsilon)
-            bias = nn_partitioning.param_with_axes(
-                "ln_bias", self.bias_init, (features,), self.dtype, axes=("embed",)
+            bias = self.param(
+                "ln_bias",
+                nn.with_logical_partitioning(self.bias_init, ("embed",)),
+                (features,),
+                self.dtype,
             )
             bias = jnp.asarray(bias, input_dtype)
@@ -976,12 +979,11 @@ class RelativePositionBiases(nn.Module):
             num_buckets=self.num_buckets,
             max_distance=self.max_distance,
         )
-        relative_attention_bias = nn_partitioning.param_with_axes(
+        relative_attention_bias = self.param(
             "rel_embedding",
-            self.embedding_init,
+            nn.with_logical_partitioning(self.embedding_init, ("heads", "relpos_buckets")),
             (self.num_heads, self.num_buckets),
             jnp.float32,
-            axes=("heads", "relpos_buckets"),
         )

         relative_attention_bias = jnp.asarray(relative_attention_bias, self.dtype)
@@ -1559,14 +1561,16 @@ def sync_params_values(dst, src, transformations, sep="/"):
     """
     src_values = {}
     for key, value in jax.tree_util.tree_leaves_with_path(src):
-        normalized_key = sep.join(x.key for x in key)
+        # Only select DictKey(key="...") entries, skip GetAttr(name="...") entries at the end of the tree path
+        normalized_key = sep.join(x.key for x in key if hasattr(x, "key"))
         src_values[normalized_key] = value

     flatten_dst, dst_tree_def = jax.tree_util.tree_flatten_with_path(dst)
     synced_dst_values = []
     for key, value in flatten_dst:
-        normalized_key = sep.join(x.key for x in key)
+        # Only select DictKey(key="...") entries, skip GetAttr(name="...") entries at the end of the tree path
+        normalized_key = sep.join(x.key for x in key if hasattr(x, "key"))
         if normalized_key in transformations:
             corresponding_src_key = transformations[normalized_key]
         else:
...
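The added `hasattr(x, "key")` filter is needed because logically partitioned parameters are stored in metadata boxes, so (on recent Flax versions) tree paths gain a trailing attribute entry such as `.value` that has no `.key`. A small sketch of the difference; the `Partitioned` wrapper below stands in for the metadata Flax attaches and is illustrative, not the exact class TE produces:

```python
import jax
import jax.numpy as jnp
from flax.core import meta

# A params tree whose leaf is wrapped in axis metadata, similar to what
# nn.with_logical_partitioning-based initializers produce.
params = {"dense": {"kernel": meta.Partitioned(jnp.ones((2, 2)), names=("embed", "mlp"))}}

for path, leaf in jax.tree_util.tree_leaves_with_path(params):
    raw = "/".join(str(p) for p in path)                            # includes a trailing ".value"
    keys_only = "/".join(p.key for p in path if hasattr(p, "key"))  # "dense/kernel"
    print(raw, "->", keys_only)
```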
@@ -49,6 +49,7 @@ def dense(
     """
     # Remove when tex.quantize() can handle quantizer=None
     if quantizer_set == noop_quantizer_set:
+        x = with_sharding_constraint_by_logical_axes(x, input_axes)
         output = tex.gemm(x, kernel, contracting_dims)

         if bias is not None:
             bias_new_shape = (1,) * (output.ndim - bias.ndim) + bias.shape
...
@@ -11,7 +11,6 @@ from typing import Any, Callable, Iterable, List, Sequence, Tuple, Union
 import numpy as np
 import jax.numpy as jnp
 from flax import linen as nn
-from flax.linen import partitioning as nn_partitioning
 from jax import lax
 from jax import random as jax_random
 from jax.ad_checkpoint import checkpoint_name
@@ -65,6 +64,7 @@ def _obtain_default_layernorm_scale_init_if_need(original_init, zero_centered_ga
 def _create_layernorm_parameters(
+    module,
     norm_type,
     shape,
     scale_init,
@@ -74,13 +74,21 @@ def _create_layernorm_parameters(
     input_dtype,
     dtype,
 ):
-    scale = nn_partitioning.param_with_axes("scale", scale_init, shape, dtype, axes=scale_axes)
-    scale = scale.astype(input_dtype)
+    scale = module.param(
+        "scale",
+        nn.with_logical_partitioning(scale_init, scale_axes),
+        shape,
+        dtype,
+    ).astype(input_dtype)

     norm_type = canonicalize_norm_type(norm_type)
     if norm_type == "layernorm":
-        bias = nn_partitioning.param_with_axes("ln_bias", bias_init, shape, dtype, axes=bias_axes)
-        bias = jnp.asarray(bias, input_dtype)
+        bias = module.param(
+            "ln_bias",
+            nn.with_logical_partitioning(bias_init, bias_axes),
+            shape,
+            dtype,
+        ).astype(input_dtype)
     else:
         assert norm_type == "rmsnorm"
         bias = None
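Passing the calling `module` into the helper is what allows a free function to create parameters at all: only a `Module` has a variable scope, so `module.param(...)` registers the weights under the caller. A minimal sketch of the pattern (names are illustrative, not TE's):

```python
import jax
import jax.numpy as jnp
from flax import linen as nn


def make_scale(module: nn.Module, features: int):
    # The helper has no scope of its own; it creates "scale" on the module passed in.
    return module.param(
        "scale",
        nn.with_logical_partitioning(nn.initializers.ones, ("embed",)),
        (features,),
        jnp.float32,
    )


class TinyNorm(nn.Module):  # hypothetical caller
    @nn.compact
    def __call__(self, x):
        scale = make_scale(self, x.shape[-1])
        return x * scale


variables = TinyNorm().init(jax.random.PRNGKey(0), jnp.ones((2, 4)))
# variables["params"]["scale"] lives under TinyNorm even though the helper created it.
```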
@@ -308,6 +316,7 @@ class LayerNorm(nn.Module):  # pylint: disable=too-few-public-methods
         features = x.shape[-1]
         scale, ln_bias = _create_layernorm_parameters(
+            self,
             self.layernorm_type,
             (features,),
             self.scale_init,
@@ -467,16 +476,22 @@ class DenseGeneral(TransformerEngineBase):
             "Expected len(kernel_shape) to match len(kernel_axes),"
             f"got kernel_shape {kernel_shape} and kernel_axes {self.kernel_axes}"
         )
-        kernel = nn_partitioning.param_with_axes(
-            "kernel", self.kernel_init, kernel_shape, self.dtype, axes=self.kernel_axes
+        kernel = self.param(
+            "kernel",
+            nn.with_logical_partitioning(self.kernel_init, self.kernel_axes),
+            kernel_shape,
+            self.dtype,
         )
         if not QuantizeConfig.is_fp8_enabled():
             kernel = kernel.astype(input_dtype)

         if self.use_bias:
-            bias = nn_partitioning.param_with_axes(
-                "bias", self.bias_init, features, self.dtype, axes=self.bias_axes
+            bias = self.param(
+                "bias",
+                nn.with_logical_partitioning(self.bias_init, self.bias_axes),
+                features,
+                self.dtype,
             ).astype(input_dtype)
         else:
             bias = None
@@ -499,25 +514,21 @@ class DenseGeneral(TransformerEngineBase):
                 self.low_rank_adaptation_dim,
             )
             lora_a_kernel_axes = (None,) * len(lora_a_kernel_shape)
-            lora_a_kernel = nn_partitioning.param_with_axes(
+            lora_a_kernel = self.param(
                 "lora_a_kernel",
-                self.kernel_init,
+                nn.with_logical_partitioning(self.kernel_init, lora_a_kernel_axes),
                 lora_a_kernel_shape,
                 self.dtype,
-                axes=lora_a_kernel_axes,
-            )
-            lora_a_kernel = lora_a_kernel.astype(input_dtype)
+            ).astype(input_dtype)

             lora_b_kernel_shape = (*features[:-1], self.low_rank_adaptation_dim, features[-1])
             lora_b_kernel_axes = (None,) * len(lora_b_kernel_shape)
-            lora_b_kernel = nn_partitioning.param_with_axes(
+            lora_b_kernel = self.param(
                 "lora_b_kernel",
-                nn.initializers.zeros,
+                nn.with_logical_partitioning(nn.initializers.zeros, lora_b_kernel_axes),
                 lora_b_kernel_shape,
                 self.dtype,
-                axes=lora_b_kernel_axes,
-            )
-            lora_b_kernel = lora_b_kernel.astype(input_dtype)
+            ).astype(input_dtype)

             y += _apply_low_rank_adaptation(
                 inputs, axis, features, lora_a_kernel, lora_b_kernel, self.low_rank_adaptation_alpha
@@ -695,6 +706,7 @@ class LayerNormDenseGeneral(TransformerEngineBase):
         inputs = with_sharding_constraint_by_logical_axes(inputs, self.layernorm_input_axes)
         features = inputs.shape[-1]
         scale, ln_bias = _create_layernorm_parameters(
+            self,
             self.layernorm_type,
             (features,),
             self.scale_init,
@@ -730,8 +742,11 @@ class LayerNormDenseGeneral(TransformerEngineBase):
         axis = _normalize_axes(axis, y.ndim)

         kernel_shape = (np.prod([inputs.shape[ax] for ax in axis]),) + features
-        kernel = nn_partitioning.param_with_axes(
-            "kernel", self.kernel_init, kernel_shape, self.dtype, axes=self.kernel_axes
+        kernel = self.param(
+            "kernel",
+            nn.with_logical_partitioning(self.kernel_init, self.kernel_axes),
+            kernel_shape,
+            self.dtype,
         )
         if not QuantizeConfig.is_fp8_enabled():
             kernel = kernel.astype(input_dtype)
@@ -770,25 +785,21 @@ class LayerNormDenseGeneral(TransformerEngineBase):
                 self.low_rank_adaptation_dim,
             )
             lora_a_kernel_axes = (None,) * len(lora_a_kernel_shape)
-            lora_a_kernel = nn_partitioning.param_with_axes(
+            lora_a_kernel = self.param(
                 "lora_a_kernel",
-                self.kernel_init,
+                nn.with_logical_partitioning(self.kernel_init, lora_a_kernel_axes),
                 lora_a_kernel_shape,
                 self.dtype,
-                axes=lora_a_kernel_axes,
-            )
-            lora_a_kernel = lora_a_kernel.astype(input_dtype)
+            ).astype(input_dtype)

             lora_b_kernel_shape = (*features[:-1], self.low_rank_adaptation_dim, features[-1])
             lora_b_kernel_axes = (None,) * len(lora_b_kernel_shape)
-            lora_b_kernel = nn_partitioning.param_with_axes(
+            lora_b_kernel = self.param(
                 "lora_b_kernel",
-                nn.initializers.zeros,
+                nn.with_logical_partitioning(nn.initializers.zeros, lora_b_kernel_axes),
                 lora_b_kernel_shape,
                 self.dtype,
-                axes=lora_b_kernel_axes,
-            )
-            lora_b_kernel = lora_b_kernel.astype(input_dtype)
+            ).astype(input_dtype)

             z += _apply_low_rank_adaptation(
                 y, axis, features, lora_a_kernel, lora_b_kernel, self.low_rank_adaptation_alpha
@@ -796,8 +807,11 @@ class LayerNormDenseGeneral(TransformerEngineBase):
         bias = None
         if self.use_bias:
-            bias = nn_partitioning.param_with_axes(
-                "bias", self.bias_init, features, self.dtype, axes=self.bias_axes
+            bias = self.param(
+                "bias",
+                nn.with_logical_partitioning(self.bias_init, self.bias_axes),
+                features,
+                self.dtype,
             ).astype(input_dtype)

         if bias is not None:
@@ -1028,6 +1042,7 @@ class LayerNormMLP(TransformerEngineBase):
         features = inputs.shape[-1]
         scale, ln_bias = _create_layernorm_parameters(
+            self,
             self.layernorm_type,
             (features,),
             self.scale_init,
@@ -1067,14 +1082,13 @@ class LayerNormMLP(TransformerEngineBase):
         axis = _canonicalize_tuple(self.axis)
         axis = _normalize_axes(axis, y.ndim)

         kernel_1_each_shape = (np.prod([inputs.shape[ax] for ax in axis]), self.intermediate_dim)
-        kernel_1 = nn_partitioning.param_with_axes(
+        kernel_1 = self.param(
             "wi_kernel",
-            kernel_1_init,
+            nn.with_logical_partitioning(kernel_1_init, self.kernel_axes_1),
             num_activations,
             -2,
             kernel_1_each_shape,
             self.dtype,
-            axes=self.kernel_axes_1,
         )
         if not QuantizeConfig.is_fp8_enabled():
@@ -1083,12 +1097,11 @@ class LayerNormMLP(TransformerEngineBase):
         hidden_size = inputs.shape[-1]
         hidden_size_tuple = _canonicalize_tuple(hidden_size)
         kernel_2_shape = (self.intermediate_dim,) + hidden_size_tuple
-        kernel_2 = nn_partitioning.param_with_axes(
+        kernel_2 = self.param(
             "wo_kernel",
-            self.kernel_init,
+            nn.with_logical_partitioning(self.kernel_init, self.kernel_axes_2),
             kernel_2_shape,
             self.dtype,
-            axes=self.kernel_axes_2,
         )
         if not QuantizeConfig.is_fp8_enabled():
             kernel_2 = kernel_2.astype(input_dtype)
@@ -1097,21 +1110,19 @@ class LayerNormMLP(TransformerEngineBase):
         if self.use_bias:
             bias_1_shape = (num_activations, self.intermediate_dim)
-            bias_1 = nn_partitioning.param_with_axes(
+            bias_1 = self.param(
                 "wi_bias",
-                self.bias_init,
+                nn.with_logical_partitioning(self.bias_init, self.bias_axes_1),
                 bias_1_shape,
                 self.dtype,
-                axes=self.bias_axes_1,
             ).astype(input_dtype)

             bias_2_shape = (hidden_size,)
-            bias_2 = nn_partitioning.param_with_axes(
+            bias_2 = self.param(
                 "wo_bias",
-                self.bias_init,
+                nn.with_logical_partitioning(self.bias_init, self.bias_axes_2),
                 bias_2_shape,
                 self.dtype,
-                axes=self.bias_axes_2,
             ).astype(input_dtype)
         else:
             bias_1 = None
@@ -1168,9 +1179,13 @@ class LayerNormMLP(TransformerEngineBase):
                 kernel_axes=self.kernel_axes_1,
                 quantizer_set=ffn1_quantizer_set,
             )
+            if self.dot_1_input_axes is not None and self.kernel_axes_1 is not None:
                 dot_1_output_axes = (
                     *get_non_contracting_logical_axes(y.ndim, self.dot_1_input_axes, axis),
-                    *get_non_contracting_logical_axes(kernel_1.ndim, self.kernel_axes_1, contract_ind),
+                    *get_non_contracting_logical_axes(
+                        kernel_1.ndim, self.kernel_axes_1, contract_ind
+                    ),
                 )
                 x = with_sharding_constraint_by_logical_axes(x, dot_1_output_axes)
@@ -1180,16 +1195,14 @@ class LayerNormMLP(TransformerEngineBase):
                 self.low_rank_adaptation_dim,
             )
             wi_lora_a_kernel_axes = (None,) * len(wi_lora_a_kernel_each_shape + 1)
-            wi_lora_a_kernel = nn_partitioning.param_with_axes(
+            wi_lora_a_kernel = self.param(
                 "wi_lora_a_kernel",
-                kernel_1_init,
+                nn.with_logical_partitioning(kernel_1_init, wi_lora_a_kernel_axes),
                 num_activations,
                 -2,
                 wi_lora_a_kernel_each_shape,
                 self.dtype,
-                axes=wi_lora_a_kernel_axes,
-            )
-            wi_lora_a_kernel = wi_lora_a_kernel.astype(input_dtype)
+            ).astype(input_dtype)

             wi_lora_b_kernel_shape = (
                 num_activations,
@@ -1197,14 +1210,12 @@ class LayerNormMLP(TransformerEngineBase):
                 self.intermediate_dim,
             )
             wi_lora_b_kernel_axes = (None,) * len(wi_lora_b_kernel_shape)
-            wi_lora_b_kernel = nn_partitioning.param_with_axes(
+            wi_lora_b_kernel = self.param(
                 "wi_lora_b_kernel",
-                nn.initializers.zeros,
+                nn.with_logical_partitioning(nn.initializers.zeros, wi_lora_b_kernel_axes),
                 wi_lora_b_kernel_shape,
                 self.dtype,
-                axes=wi_lora_b_kernel_axes,
-            )
-            wi_lora_b_kernel = wi_lora_b_kernel.astype(input_dtype)
+            ).astype(input_dtype)

             x += _apply_low_rank_adaptation(
                 y,
@@ -1253,25 +1264,21 @@ class LayerNormMLP(TransformerEngineBase):
         if self.enable_low_rank_adaptation:
             wo_lora_a_kernel_shape = (self.intermediate_dim, self.low_rank_adaptation_dim)
             wo_lora_a_kernel_axes = (None,) * len(wo_lora_a_kernel_shape)
-            wo_lora_a_kernel = nn_partitioning.param_with_axes(
+            wo_lora_a_kernel = self.param(
                 "wo_lora_a_kernel",
-                self.kernel_init,
+                nn.with_logical_partitioning(self.kernel_init, wo_lora_a_kernel_axes),
                 wo_lora_a_kernel_shape,
                 self.dtype,
-                axes=wo_lora_a_kernel_axes,
-            )
-            wo_lora_a_kernel = wo_lora_a_kernel.astype(input_dtype)
+            ).astype(input_dtype)

             wo_lora_b_kernel_shape = (self.low_rank_adaptation_dim, hidden_size)
             wo_lora_b_kernel_axes = (None,) * len(wo_lora_b_kernel_shape)
-            wo_lora_b_kernel = nn_partitioning.param_with_axes(
+            wo_lora_b_kernel = self.param(
                 "wo_lora_b_kernel",
-                nn.initializers.zeros,
+                nn.with_logical_partitioning(nn.initializers.zeros, wo_lora_b_kernel_axes),
                 wo_lora_b_kernel_shape,
                 self.dtype,
-                axes=wo_lora_b_kernel_axes,
-            )
-            wo_lora_b_kernel = wo_lora_b_kernel.astype(input_dtype)
+            ).astype(input_dtype)

             out += _apply_low_rank_adaptation(
                 z,
...
@@ -15,7 +15,6 @@ import jax
 import jax.numpy as jnp
 import numpy as np
 from flax import linen as nn
-from flax.linen import partitioning as nn_partitioning
 from flax.linen.attention import combine_masks
 from jax import nn as jax_nn
 from jax import random as jax_random
@@ -1503,12 +1502,11 @@ class RelativePositionBiases(nn.Module):  # pylint: disable=too-few-public-metho
         rp_bucket += np.where(rpb_is_small, negative_rp, rpb_val_if_large)

         # Compute relative attention bias
-        relative_attention_bias = nn_partitioning.param_with_axes(
+        relative_attention_bias = self.param(
             "rel_embedding",
-            self.embedding_init,
+            nn.with_logical_partitioning(self.embedding_init, self.embedding_axes),
             (self.num_attention_heads, self.num_buckets),
             self.dtype,
-            axes=self.embedding_axes,
         )

         relative_attention_bias = jnp.asarray(relative_attention_bias, self.dtype)
...
@@ -275,6 +275,7 @@ def _layernorm_mlp_fwd_rule(
         (x_contracting_dims, k_contracting_dims),
     )

+    if dot_1_input_axes is not None and kernel_1_axes is not None:
         dot_1_output_axes = (
             *get_non_contracting_logical_axes(x.ndim, dot_1_input_axes, x_contracting_dims),
             *get_non_contracting_logical_axes(kernel_1.ndim, kernel_1_axes, k_contracting_dims),
@@ -303,12 +304,6 @@ def _layernorm_mlp_fwd_rule(
         (x_contracting_dims, k_contracting_dims),
     )

-    dot_2_output_axes = (
-        *get_non_contracting_logical_axes(x.ndim, dot_2_input_axes, x_contracting_dims),
-        *get_non_contracting_logical_axes(kernel_2.ndim, None, k_contracting_dims),
-    )
-    dot_2_output = with_sharding_constraint_by_logical_axes(dot_2_output, dot_2_output_axes)
-
     if use_bias_2:
         bias_2_shape = bias_2.shape
         bias_2_new_shape = (1,) * (dot_2_output.ndim - bias_2.ndim) + bias_2_shape
...
@@ -13,7 +13,7 @@ import os
 from contextlib import contextmanager
 from dataclasses import dataclass
 from enum import Enum
-from typing import Callable
+from typing import Callable, Optional
 from jax.interpreters import pxla
 import jax
 import jax.numpy as jnp
@@ -112,9 +112,22 @@ def with_sharding_constraint(x: jnp.array, pspec: PartitionSpec):
     return jax.lax.with_sharding_constraint(x, pspec)


-def with_sharding_constraint_by_logical_axes(x: jnp.array, logical_axis_names: tuple | list):
+def with_sharding_constraint_by_logical_axes(
+    x: jnp.array, logical_axis_names: Optional[tuple | list]
+):
     """
     A wrapper function to jax.lax.with_sharding_constraint to accept logical axes.
+
+    If logical_axis_names = None, this means no sharding constraint is applied.
+    If logical_axis_names = (None, None, ...), this means a sharding constraint is applied and the tensor is replicated across all devices.
+
+    Args:
+        x: Input tensor to apply sharding constraint
+        logical_axis_names: Logical axis names to apply sharding constraint
+
+    Returns:
+        Tensor with sharding constraint applied, or the original tensor if no logical axes are provided.
     """
     if not logical_axis_names:
         return x
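The new docstring distinguishes `None` (skip the constraint entirely) from a tuple of `None`s (constrain, but to full replication). A standalone sketch of that contract follows; the rule table and function body are illustrative, not TE's implementation:

```python
from typing import Optional

import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec

RULES = {"batch": "data", "hidden": None}  # hypothetical logical -> mesh axis mapping


def constrain_by_logical_axes(x: jnp.ndarray, logical_axis_names: Optional[tuple]):
    if not logical_axis_names:
        return x  # None or empty: no constraint is applied at all
    assert len(logical_axis_names) == x.ndim
    pspec = PartitionSpec(*(RULES.get(name) for name in logical_axis_names))
    # A tuple of Nones yields PartitionSpec(None, ...): constrained, fully replicated.
    return jax.lax.with_sharding_constraint(x, pspec)
```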
@@ -321,7 +334,9 @@ class ShardingType(Enum):
     DP_TP_ROW = (MajorShardingType.DPTP, "dp_tp_row")


-def get_non_contracting_logical_axes(ndim, logical_axes, contracting_dims):
+def get_non_contracting_logical_axes(
+    ndim, logical_axes: tuple[Optional[str]], contracting_dims
+) -> tuple[Optional[str]]:
     """Get logical axes for non-contracting dimensions.

     Args:
@@ -332,11 +347,8 @@ def get_non_contracting_logical_axes(ndim, logical_axes, contracting_dims):
     Returns:
         Tuple of logical axes for non-contracting dimensions.
     """
-    if not logical_axes:
-        logical_axes = (None,) * ndim
-    elif len(logical_axes) < ndim:
-        logical_axes = logical_axes + (None,) * (ndim - len(logical_axes))
-
-    assert len(logical_axes) == ndim
+    assert logical_axes is not None, "Logical axes must be a tuple and cannot be None."
+    assert len(logical_axes) == ndim, "Logical axes must match the number of dimensions."

     non_contracting_dims = [i for i in range(ndim) if i not in contracting_dims]
     non_contracting_logical_axes = tuple(logical_axes[i] for i in non_contracting_dims)
...
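After this change `get_non_contracting_logical_axes` requires a full-length tuple (one entry per dimension) and simply drops the contracting dimensions; this is what produces `dot_1_output_axes` for the FFN1 output constraint above. A pure-Python sketch of that behavior (axis names are illustrative):

```python
def non_contracting_axes(ndim, logical_axes, contracting_dims):
    # Mirrors the documented contract: logical_axes must cover every dimension.
    assert logical_axes is not None and len(logical_axes) == ndim
    return tuple(logical_axes[i] for i in range(ndim) if i not in contracting_dims)


# e.g. an FFN1-style dot: activation ("batch", "length", "hidden") contracted on dim 2,
# kernel (None, "mlp") contracted on dim 0 -> output axes ("batch", "length", "mlp").
assert non_contracting_axes(3, ("batch", "length", "hidden"), (2,)) == ("batch", "length")
assert non_contracting_axes(2, (None, "mlp"), (0,)) == ("mlp",)
```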