Unverified commit 27612051 authored by jberchtold-nvidia, committed by GitHub

[JAX] Support logical partitioning axes in TE Flax modules (#1772)



* [JAX] Update flax module param initialization to support logical partitioning axes
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* Fix ffn1 intermediate result being replicated
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* Lint
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* Add documentation and assert when logical_axes=None
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* Fix bias in LayerNormMLP flax module
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* Fix layer tests to not use nn_partitioning and instead use nn.with_logical_axes
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

---------
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
parent 7c474236
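
In short, this commit replaces `nn_partitioning.param_with_axes(...)` with plain `self.param(...)` whose initializer is wrapped in `nn.with_logical_partitioning`, and replaces `nn_partitioning.with_sharding_constraint` with `nn.with_logical_constraint`. A minimal sketch of the new pattern (the `MyDense` module, its shapes, and its axis names are illustrative, not part of this diff):

```python
import jax
import jax.numpy as jnp
from flax import linen as nn


class MyDense(nn.Module):
    """Toy module showing the param-creation pattern adopted by this commit."""

    features: int = 16

    @nn.compact
    def __call__(self, x):
        # Wrapping the initializer attaches logical axis names to the param.
        # Flax stores the value as nn.Partitioned metadata in the variable
        # tree but hands the module the raw array, so the call site is
        # otherwise unchanged.
        kernel = self.param(
            "kernel",
            nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("embed", "mlp")),
            (x.shape[-1], self.features),
            jnp.float32,
        )
        return x @ kernel


variables = MyDense().init(jax.random.PRNGKey(0), jnp.ones((4, 8)))
```

With this pattern, logical names such as "embed" and "mlp" are resolved to mesh axes by whatever `nn.logical_axis_rules` are in effect at run time, which is also why the layer tests below move off the `nn_partitioning` helpers.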
@@ -609,7 +609,7 @@ class TestEncoder(unittest.TestCase):
     def test_te_delayed_scaling_fp8(self):
         """Test Transformer Engine with DelayedScaling FP8"""
         result = self.exec(True, "DelayedScaling")
-        assert result[0] < 0.505 and result[1] > 0.754
+        assert result[0] < 0.505 and result[1] > 0.753

     @unittest.skipIf(
         not is_fp8_supported(), "Device compute capability 9.0+ is required for CurrentScaling FP8"
@@ -13,7 +13,6 @@ import jax
 import jax.numpy as jnp
 import numpy as np
 from flax import linen as nn
-from flax.linen import partitioning as nn_partitioning
 from flax.linen.attention import combine_masks
 from jax import lax, vmap
 from jax import nn as jax_nn
@@ -316,16 +315,22 @@ class DenseGeneral(nn.Module):
         kernel_shape = tuple(inputs.shape[ax] for ax in axis) + features
         kernel_param_shape = (np.prod([inputs.shape[ax] for ax in axis]), np.prod(features))
-        kernel = nn_partitioning.param_with_axes(
-            "kernel", self.kernel_init, kernel_param_shape, self.dtype, axes=self.kernel_axes
+        kernel = self.param(
+            "kernel",
+            nn.with_logical_partitioning(self.kernel_init, self.kernel_axes),
+            kernel_param_shape,
+            self.dtype,
         )
         kernel = jnp.asarray(kernel, input_dtype)
         kernel = jnp.reshape(kernel, kernel_shape)

         if self.use_bias:
-            bias = nn_partitioning.param_with_axes(
-                "bias", self.bias_init, self.features, self.dtype, axes=self.bias_axes
+            bias = self.param(
+                "bias",
+                nn.with_logical_partitioning(self.bias_init, self.bias_axes),
+                self.features,
+                self.dtype,
             )
             bias = bias.astype(input_dtype)
         else:
@@ -422,9 +427,9 @@ class MlpBlock(nn.Module):
         )  # Broadcast along length.
         if self.transpose_batch_sequence:
-            x = nn_partitioning.with_sharding_constraint(x, ("length", "batch", "mlp"))
+            x = nn.with_logical_constraint(x, ("length", "batch", "mlp"))
         else:
-            x = nn_partitioning.with_sharding_constraint(x, ("batch", "length", "mlp"))
+            x = nn.with_logical_constraint(x, ("batch", "length", "mlp"))
         output = DenseGeneral(
             inputs.shape[-1],
             dtype=self.dtype,
@@ -688,21 +693,13 @@ class MultiHeadAttention(nn.Module):
         value = value.reshape((*value.shape[:2], self.num_gqa_groups, self.head_dim))

         if self.transpose_batch_sequence:
-            query = nn_partitioning.with_sharding_constraint(
-                query, ("length", "batch", "heads", "kv")
-            )
-            key = nn_partitioning.with_sharding_constraint(key, ("length", "batch", "heads", "kv"))
-            value = nn_partitioning.with_sharding_constraint(
-                value, ("length", "batch", "heads", "kv")
-            )
+            query = nn.with_logical_constraint(query, ("length", "batch", "heads", "kv"))
+            key = nn.with_logical_constraint(key, ("length", "batch", "heads", "kv"))
+            value = nn.with_logical_constraint(value, ("length", "batch", "heads", "kv"))
         else:
-            query = nn_partitioning.with_sharding_constraint(
-                query, ("batch", "length", "heads", "kv")
-            )
-            key = nn_partitioning.with_sharding_constraint(key, ("batch", "length", "heads", "kv"))
-            value = nn_partitioning.with_sharding_constraint(
-                value, ("batch", "length", "heads", "kv")
-            )
+            query = nn.with_logical_constraint(query, ("batch", "length", "heads", "kv"))
+            key = nn.with_logical_constraint(key, ("batch", "length", "heads", "kv"))
+            value = nn.with_logical_constraint(value, ("batch", "length", "heads", "kv"))

         if decode:
             # Detect if we're initializing by absence of existing cache data.
@@ -809,9 +806,9 @@ class MultiHeadAttention(nn.Module):
         x = x.reshape((x.shape[0], x.shape[1], x.shape[2] * x.shape[3]))

         if self.transpose_batch_sequence:
-            x = nn_partitioning.with_sharding_constraint(x, ("length", "batch", "joined_kv"))
+            x = nn.with_logical_constraint(x, ("length", "batch", "joined_kv"))
         else:
-            x = nn_partitioning.with_sharding_constraint(x, ("batch", "length", "joined_kv"))
+            x = nn.with_logical_constraint(x, ("batch", "length", "joined_kv"))

         # Back to the original inputs dimensions.
@@ -857,8 +854,11 @@ class LayerNorm(nn.Module):
         input_dtype = x.dtype
         features = x.shape[-1]
-        scale = nn_partitioning.param_with_axes(
-            "scale", self.scale_init, (features,), self.dtype, axes=("embed",)
+        scale = self.param(
+            "scale",
+            nn.with_logical_partitioning(self.scale_init, ("embed",)),
+            (features,),
+            self.dtype,
         )
         x_ = x.astype(jnp.float32)
         if self.layernorm_type == "layernorm":
@@ -866,8 +866,11 @@ class LayerNorm(nn.Module):
             var = jnp.mean(jnp.square(x_ - mean), axis=-1, keepdims=True)
             y = (x_ - mean) * lax.rsqrt(var + self.epsilon)
-            bias = nn_partitioning.param_with_axes(
-                "ln_bias", self.bias_init, (features,), self.dtype, axes=("embed",)
+            bias = self.param(
+                "ln_bias",
+                nn.with_logical_partitioning(self.bias_init, ("embed",)),
+                (features,),
+                self.dtype,
             )
             bias = jnp.asarray(bias, input_dtype)
@@ -976,12 +979,11 @@ class RelativePositionBiases(nn.Module):
             num_buckets=self.num_buckets,
             max_distance=self.max_distance,
         )
-        relative_attention_bias = nn_partitioning.param_with_axes(
+        relative_attention_bias = self.param(
             "rel_embedding",
-            self.embedding_init,
+            nn.with_logical_partitioning(self.embedding_init, ("heads", "relpos_buckets")),
             (self.num_heads, self.num_buckets),
             jnp.float32,
-            axes=("heads", "relpos_buckets"),
         )
         relative_attention_bias = jnp.asarray(relative_attention_bias, self.dtype)
@@ -1559,14 +1561,16 @@ def sync_params_values(dst, src, transformations, sep="/"):
     """
     src_values = {}
     for key, value in jax.tree_util.tree_leaves_with_path(src):
-        normalized_key = sep.join(x.key for x in key)
+        # Only select DictKey(key="...") entries, skip GetAttr(name="...") entries at the end of the tree path
+        normalized_key = sep.join(x.key for x in key if hasattr(x, "key"))
         src_values[normalized_key] = value

     flatten_dst, dst_tree_def = jax.tree_util.tree_flatten_with_path(dst)
     synced_dst_values = []
     for key, value in flatten_dst:
-        normalized_key = sep.join(x.key for x in key)
+        # Only select DictKey(key="...") entries, skip GetAttr(name="...") entries at the end of the tree path
+        normalized_key = sep.join(x.key for x in key if hasattr(x, "key"))
         if normalized_key in transformations:
             corresponding_src_key = transformations[normalized_key]
         else:
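
The `hasattr(x, "key")` filter matters because logically partitioned params are stored as `nn.Partitioned` leaves, so, as the new comment above notes, their tree paths gain a trailing attribute entry rather than a dict entry. A small illustration (the tree contents are made up):

```python
import jax
import jax.numpy as jnp
from flax import linen as nn

tree = {"dense": {"kernel": nn.Partitioned(jnp.zeros((2, 4)), names=("embed", "mlp"))}}
for path, leaf in jax.tree_util.tree_leaves_with_path(tree):
    # path is (DictKey(key='dense'), DictKey(key='kernel'), GetAttrKey(name='value'));
    # the trailing GetAttrKey has no .key attribute, so the filter drops it.
    print("/".join(p.key for p in path if hasattr(p, "key")))  # dense/kernel
```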
@@ -49,6 +49,7 @@ def dense(
     """
     # Remove when tex.quantize() can handle quantizer=None
     if quantizer_set == noop_quantizer_set:
+        x = with_sharding_constraint_by_logical_axes(x, input_axes)
         output = tex.gemm(x, kernel, contracting_dims)

         if bias is not None:
             bias_new_shape = (1,) * (output.ndim - bias.ndim) + bias.shape
@@ -11,7 +11,6 @@ from typing import Any, Callable, Iterable, List, Sequence, Tuple, Union
 import numpy as np
 import jax.numpy as jnp
 from flax import linen as nn
-from flax.linen import partitioning as nn_partitioning
 from jax import lax
 from jax import random as jax_random
 from jax.ad_checkpoint import checkpoint_name
@@ -65,6 +64,7 @@ def _obtain_default_layernorm_scale_init_if_need(original_init, zero_centered_gamma):
 def _create_layernorm_parameters(
+    module,
     norm_type,
     shape,
     scale_init,
@@ -74,13 +74,21 @@
     input_dtype,
     dtype,
 ):
-    scale = nn_partitioning.param_with_axes("scale", scale_init, shape, dtype, axes=scale_axes)
-    scale = scale.astype(input_dtype)
+    scale = module.param(
+        "scale",
+        nn.with_logical_partitioning(scale_init, scale_axes),
+        shape,
+        dtype,
+    ).astype(input_dtype)

     norm_type = canonicalize_norm_type(norm_type)
     if norm_type == "layernorm":
-        bias = nn_partitioning.param_with_axes("ln_bias", bias_init, shape, dtype, axes=bias_axes)
-        bias = jnp.asarray(bias, input_dtype)
+        bias = module.param(
+            "ln_bias",
+            nn.with_logical_partitioning(bias_init, bias_axes),
+            shape,
+            dtype,
+        ).astype(input_dtype)
     else:
         assert norm_type == "rmsnorm"
         bias = None
@@ -308,6 +316,7 @@ class LayerNorm(nn.Module):  # pylint: disable=too-few-public-methods
         features = x.shape[-1]
         scale, ln_bias = _create_layernorm_parameters(
+            self,
             self.layernorm_type,
             (features,),
             self.scale_init,
@@ -467,16 +476,22 @@ class DenseGeneral(TransformerEngineBase):
                 "Expected len(kernel_shape) to match len(kernel_axes),"
                 f"got kernel_shape {kernel_shape} and kernel_axes {self.kernel_axes}"
             )
-        kernel = nn_partitioning.param_with_axes(
-            "kernel", self.kernel_init, kernel_shape, self.dtype, axes=self.kernel_axes
+        kernel = self.param(
+            "kernel",
+            nn.with_logical_partitioning(self.kernel_init, self.kernel_axes),
+            kernel_shape,
+            self.dtype,
         )
         if not QuantizeConfig.is_fp8_enabled():
             kernel = kernel.astype(input_dtype)

         if self.use_bias:
-            bias = nn_partitioning.param_with_axes(
-                "bias", self.bias_init, features, self.dtype, axes=self.bias_axes
+            bias = self.param(
+                "bias",
+                nn.with_logical_partitioning(self.bias_init, self.bias_axes),
+                features,
+                self.dtype,
             ).astype(input_dtype)
         else:
             bias = None
@@ -499,25 +514,21 @@ class DenseGeneral(TransformerEngineBase):
                 self.low_rank_adaptation_dim,
             )
             lora_a_kernel_axes = (None,) * len(lora_a_kernel_shape)
-            lora_a_kernel = nn_partitioning.param_with_axes(
+            lora_a_kernel = self.param(
                 "lora_a_kernel",
-                self.kernel_init,
+                nn.with_logical_partitioning(self.kernel_init, lora_a_kernel_axes),
                 lora_a_kernel_shape,
                 self.dtype,
-                axes=lora_a_kernel_axes,
-            )
-            lora_a_kernel = lora_a_kernel.astype(input_dtype)
+            ).astype(input_dtype)

             lora_b_kernel_shape = (*features[:-1], self.low_rank_adaptation_dim, features[-1])
             lora_b_kernel_axes = (None,) * len(lora_b_kernel_shape)
-            lora_b_kernel = nn_partitioning.param_with_axes(
+            lora_b_kernel = self.param(
                 "lora_b_kernel",
-                nn.initializers.zeros,
+                nn.with_logical_partitioning(nn.initializers.zeros, lora_b_kernel_axes),
                 lora_b_kernel_shape,
                 self.dtype,
-                axes=lora_b_kernel_axes,
-            )
-            lora_b_kernel = lora_b_kernel.astype(input_dtype)
+            ).astype(input_dtype)

             y += _apply_low_rank_adaptation(
                 inputs, axis, features, lora_a_kernel, lora_b_kernel, self.low_rank_adaptation_alpha
@@ -695,6 +706,7 @@ class LayerNormDenseGeneral(TransformerEngineBase):
         inputs = with_sharding_constraint_by_logical_axes(inputs, self.layernorm_input_axes)

         features = inputs.shape[-1]
         scale, ln_bias = _create_layernorm_parameters(
+            self,
             self.layernorm_type,
             (features,),
             self.scale_init,
@@ -730,8 +742,11 @@ class LayerNormDenseGeneral(TransformerEngineBase):
         axis = _normalize_axes(axis, y.ndim)

         kernel_shape = (np.prod([inputs.shape[ax] for ax in axis]),) + features
-        kernel = nn_partitioning.param_with_axes(
-            "kernel", self.kernel_init, kernel_shape, self.dtype, axes=self.kernel_axes
+        kernel = self.param(
+            "kernel",
+            nn.with_logical_partitioning(self.kernel_init, self.kernel_axes),
+            kernel_shape,
+            self.dtype,
         )
         if not QuantizeConfig.is_fp8_enabled():
             kernel = kernel.astype(input_dtype)
@@ -770,25 +785,21 @@ class LayerNormDenseGeneral(TransformerEngineBase):
                 self.low_rank_adaptation_dim,
             )
             lora_a_kernel_axes = (None,) * len(lora_a_kernel_shape)
-            lora_a_kernel = nn_partitioning.param_with_axes(
+            lora_a_kernel = self.param(
                 "lora_a_kernel",
-                self.kernel_init,
+                nn.with_logical_partitioning(self.kernel_init, lora_a_kernel_axes),
                 lora_a_kernel_shape,
                 self.dtype,
-                axes=lora_a_kernel_axes,
-            )
-            lora_a_kernel = lora_a_kernel.astype(input_dtype)
+            ).astype(input_dtype)

             lora_b_kernel_shape = (*features[:-1], self.low_rank_adaptation_dim, features[-1])
             lora_b_kernel_axes = (None,) * len(lora_b_kernel_shape)
-            lora_b_kernel = nn_partitioning.param_with_axes(
+            lora_b_kernel = self.param(
                 "lora_b_kernel",
-                nn.initializers.zeros,
+                nn.with_logical_partitioning(nn.initializers.zeros, lora_b_kernel_axes),
                 lora_b_kernel_shape,
                 self.dtype,
-                axes=lora_b_kernel_axes,
-            )
-            lora_b_kernel = lora_b_kernel.astype(input_dtype)
+            ).astype(input_dtype)

             z += _apply_low_rank_adaptation(
                 y, axis, features, lora_a_kernel, lora_b_kernel, self.low_rank_adaptation_alpha
@@ -796,8 +807,11 @@ class LayerNormDenseGeneral(TransformerEngineBase):
         bias = None
         if self.use_bias:
-            bias = nn_partitioning.param_with_axes(
-                "bias", self.bias_init, features, self.dtype, axes=self.bias_axes
+            bias = self.param(
+                "bias",
+                nn.with_logical_partitioning(self.bias_init, self.bias_axes),
+                features,
+                self.dtype,
             ).astype(input_dtype)

         if bias is not None:
@@ -1028,6 +1042,7 @@ class LayerNormMLP(TransformerEngineBase):
         features = inputs.shape[-1]

         scale, ln_bias = _create_layernorm_parameters(
+            self,
             self.layernorm_type,
             (features,),
             self.scale_init,
@@ -1067,14 +1082,13 @@ class LayerNormMLP(TransformerEngineBase):
         axis = _canonicalize_tuple(self.axis)
         axis = _normalize_axes(axis, y.ndim)

         kernel_1_each_shape = (np.prod([inputs.shape[ax] for ax in axis]), self.intermediate_dim)
-        kernel_1 = nn_partitioning.param_with_axes(
+        kernel_1 = self.param(
             "wi_kernel",
-            kernel_1_init,
+            nn.with_logical_partitioning(kernel_1_init, self.kernel_axes_1),
             num_activations,
             -2,
             kernel_1_each_shape,
             self.dtype,
-            axes=self.kernel_axes_1,
         )
         if not QuantizeConfig.is_fp8_enabled():
@@ -1083,12 +1097,11 @@ class LayerNormMLP(TransformerEngineBase):
         hidden_size = inputs.shape[-1]
         hidden_size_tuple = _canonicalize_tuple(hidden_size)
         kernel_2_shape = (self.intermediate_dim,) + hidden_size_tuple
-        kernel_2 = nn_partitioning.param_with_axes(
+        kernel_2 = self.param(
             "wo_kernel",
-            self.kernel_init,
+            nn.with_logical_partitioning(self.kernel_init, self.kernel_axes_2),
             kernel_2_shape,
             self.dtype,
-            axes=self.kernel_axes_2,
         )
         if not QuantizeConfig.is_fp8_enabled():
             kernel_2 = kernel_2.astype(input_dtype)
@@ -1097,21 +1110,19 @@
         if self.use_bias:
             bias_1_shape = (num_activations, self.intermediate_dim)
-            bias_1 = nn_partitioning.param_with_axes(
+            bias_1 = self.param(
                 "wi_bias",
-                self.bias_init,
+                nn.with_logical_partitioning(self.bias_init, self.bias_axes_1),
                 bias_1_shape,
                 self.dtype,
-                axes=self.bias_axes_1,
             ).astype(input_dtype)

             bias_2_shape = (hidden_size,)
-            bias_2 = nn_partitioning.param_with_axes(
+            bias_2 = self.param(
                 "wo_bias",
-                self.bias_init,
+                nn.with_logical_partitioning(self.bias_init, self.bias_axes_2),
                 bias_2_shape,
                 self.dtype,
-                axes=self.bias_axes_2,
             ).astype(input_dtype)
         else:
             bias_1 = None
@@ -1168,9 +1179,13 @@ class LayerNormMLP(TransformerEngineBase):
                 kernel_axes=self.kernel_axes_1,
                 quantizer_set=ffn1_quantizer_set,
             )
+            if self.dot_1_input_axes is not None and self.kernel_axes_1 is not None:
                 dot_1_output_axes = (
                     *get_non_contracting_logical_axes(y.ndim, self.dot_1_input_axes, axis),
-                    *get_non_contracting_logical_axes(kernel_1.ndim, self.kernel_axes_1, contract_ind),
+                    *get_non_contracting_logical_axes(
+                        kernel_1.ndim, self.kernel_axes_1, contract_ind
+                    ),
                 )
                 x = with_sharding_constraint_by_logical_axes(x, dot_1_output_axes)
@@ -1180,16 +1195,14 @@
                 self.low_rank_adaptation_dim,
             )
             wi_lora_a_kernel_axes = (None,) * (len(wi_lora_a_kernel_each_shape) + 1)
-            wi_lora_a_kernel = nn_partitioning.param_with_axes(
+            wi_lora_a_kernel = self.param(
                 "wi_lora_a_kernel",
-                kernel_1_init,
+                nn.with_logical_partitioning(kernel_1_init, wi_lora_a_kernel_axes),
                 num_activations,
                 -2,
                 wi_lora_a_kernel_each_shape,
                 self.dtype,
-                axes=wi_lora_a_kernel_axes,
-            )
-            wi_lora_a_kernel = wi_lora_a_kernel.astype(input_dtype)
+            ).astype(input_dtype)

             wi_lora_b_kernel_shape = (
                 num_activations,
@@ -1197,14 +1210,12 @@
                 self.intermediate_dim,
             )
             wi_lora_b_kernel_axes = (None,) * len(wi_lora_b_kernel_shape)
-            wi_lora_b_kernel = nn_partitioning.param_with_axes(
+            wi_lora_b_kernel = self.param(
                 "wi_lora_b_kernel",
-                nn.initializers.zeros,
+                nn.with_logical_partitioning(nn.initializers.zeros, wi_lora_b_kernel_axes),
                 wi_lora_b_kernel_shape,
                 self.dtype,
-                axes=wi_lora_b_kernel_axes,
-            )
-            wi_lora_b_kernel = wi_lora_b_kernel.astype(input_dtype)
+            ).astype(input_dtype)

             x += _apply_low_rank_adaptation(
                 y,
@@ -1253,25 +1264,21 @@
         if self.enable_low_rank_adaptation:
             wo_lora_a_kernel_shape = (self.intermediate_dim, self.low_rank_adaptation_dim)
             wo_lora_a_kernel_axes = (None,) * len(wo_lora_a_kernel_shape)
-            wo_lora_a_kernel = nn_partitioning.param_with_axes(
+            wo_lora_a_kernel = self.param(
                 "wo_lora_a_kernel",
-                self.kernel_init,
+                nn.with_logical_partitioning(self.kernel_init, wo_lora_a_kernel_axes),
                 wo_lora_a_kernel_shape,
                 self.dtype,
-                axes=wo_lora_a_kernel_axes,
-            )
-            wo_lora_a_kernel = wo_lora_a_kernel.astype(input_dtype)
+            ).astype(input_dtype)

             wo_lora_b_kernel_shape = (self.low_rank_adaptation_dim, hidden_size)
             wo_lora_b_kernel_axes = (None,) * len(wo_lora_b_kernel_shape)
-            wo_lora_b_kernel = nn_partitioning.param_with_axes(
+            wo_lora_b_kernel = self.param(
                 "wo_lora_b_kernel",
-                nn.initializers.zeros,
+                nn.with_logical_partitioning(nn.initializers.zeros, wo_lora_b_kernel_axes),
                 wo_lora_b_kernel_shape,
                 self.dtype,
-                axes=wo_lora_b_kernel_axes,
-            )
-            wo_lora_b_kernel = wo_lora_b_kernel.astype(input_dtype)
+            ).astype(input_dtype)

             out += _apply_low_rank_adaptation(
                 z,
@@ -15,7 +15,6 @@ import jax
 import jax.numpy as jnp
 import numpy as np
 from flax import linen as nn
-from flax.linen import partitioning as nn_partitioning
 from flax.linen.attention import combine_masks
 from jax import nn as jax_nn
 from jax import random as jax_random
@@ -1503,12 +1502,11 @@ class RelativePositionBiases(nn.Module):  # pylint: disable=too-few-public-methods
         rp_bucket += np.where(rpb_is_small, negative_rp, rpb_val_if_large)

         # Compute relative attention bias
-        relative_attention_bias = nn_partitioning.param_with_axes(
+        relative_attention_bias = self.param(
             "rel_embedding",
-            self.embedding_init,
+            nn.with_logical_partitioning(self.embedding_init, self.embedding_axes),
             (self.num_attention_heads, self.num_buckets),
             self.dtype,
-            axes=self.embedding_axes,
         )
         relative_attention_bias = jnp.asarray(relative_attention_bias, self.dtype)
@@ -275,6 +275,7 @@ def _layernorm_mlp_fwd_rule(
         (x_contracting_dims, k_contracting_dims),
     )

+    if dot_1_input_axes is not None and kernel_1_axes is not None:
         dot_1_output_axes = (
             *get_non_contracting_logical_axes(x.ndim, dot_1_input_axes, x_contracting_dims),
             *get_non_contracting_logical_axes(kernel_1.ndim, kernel_1_axes, k_contracting_dims),
@@ -303,12 +304,6 @@ def _layernorm_mlp_fwd_rule(
         (x_contracting_dims, k_contracting_dims),
     )

-    dot_2_output_axes = (
-        *get_non_contracting_logical_axes(x.ndim, dot_2_input_axes, x_contracting_dims),
-        *get_non_contracting_logical_axes(kernel_2.ndim, None, k_contracting_dims),
-    )
-    dot_2_output = with_sharding_constraint_by_logical_axes(dot_2_output, dot_2_output_axes)
-
     if use_bias_2:
         bias_2_shape = bias_2.shape
         bias_2_new_shape = (1,) * (dot_2_output.ndim - bias_2.ndim) + bias_2_shape
@@ -13,7 +13,7 @@ import os
 from contextlib import contextmanager
 from dataclasses import dataclass
 from enum import Enum
-from typing import Callable
+from typing import Callable, Optional
 from jax.interpreters import pxla
 import jax
 import jax.numpy as jnp
@@ -112,9 +112,22 @@ def with_sharding_constraint(x: jnp.array, pspec: PartitionSpec):
     return jax.lax.with_sharding_constraint(x, pspec)


-def with_sharding_constraint_by_logical_axes(x: jnp.array, logical_axis_names: tuple | list):
+def with_sharding_constraint_by_logical_axes(
+    x: jnp.array, logical_axis_names: Optional[tuple | list]
+):
     """
     A wrapper function to jax.lax.with_sharding_constraint to accept logical axes.
+
+    If logical_axis_names is None, no sharding constraint is applied.
+    If logical_axis_names is (None, None, ...), a sharding constraint is applied
+    and the tensor is replicated across all devices.
+
+    Args:
+        x: Input tensor to apply the sharding constraint to
+        logical_axis_names: Logical axis names used to build the constraint
+
+    Returns:
+        The tensor with the sharding constraint applied, or the original tensor
+        if no logical axes were provided.
     """
     if not logical_axis_names:
         return x
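
A quick sketch of the cases the new docstring distinguishes (this assumes the function above is in scope together with an active device mesh and flax logical-axis rules; the axis names are illustrative):

```python
import jax.numpy as jnp

x = jnp.ones((8, 128, 512))

# None: no constraint at all; the compiler may choose any sharding.
a = with_sharding_constraint_by_logical_axes(x, None)

# (None, None, None): a constraint IS applied, mapping every dimension to
# no mesh axis, which forces full replication across devices.
b = with_sharding_constraint_by_logical_axes(x, (None, None, None))

# Named axes: resolved to mesh axes via the active logical-axis rules.
c = with_sharding_constraint_by_logical_axes(x, ("batch", "length", "embed"))
```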
@@ -321,7 +334,9 @@ class ShardingType(Enum):
     DP_TP_ROW = (MajorShardingType.DPTP, "dp_tp_row")


-def get_non_contracting_logical_axes(ndim, logical_axes, contracting_dims):
+def get_non_contracting_logical_axes(
+    ndim, logical_axes: tuple[Optional[str]], contracting_dims
+) -> tuple[Optional[str]]:
     """Get logical axes for non-contracting dimensions.

     Args:
@@ -332,11 +347,8 @@ def get_non_contracting_logical_axes(ndim, logical_axes, contracting_dims):
     Returns:
         Tuple of logical axes for non-contracting dimensions.
     """
-    if not logical_axes:
-        logical_axes = (None,) * ndim
-    elif len(logical_axes) < ndim:
-        logical_axes = logical_axes + (None,) * (ndim - len(logical_axes))
-    assert len(logical_axes) == ndim
+    assert logical_axes is not None, "Logical axes must be a tuple and cannot be None."
+    assert len(logical_axes) == ndim, "Logical axes must match the number of dimensions."

     non_contracting_dims = [i for i in range(ndim) if i not in contracting_dims]
     non_contracting_logical_axes = tuple(logical_axes[i] for i in non_contracting_dims)
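
To make the helper's behavior concrete, a worked example based on the FFN1 usage above (the axis names are illustrative):

```python
x_axes = ("batch", "length", "embed")  # activation logical axes
kernel_axes = ("embed", "act", "mlp")  # kernel logical axes

# x contracts dim 2 against the kernel's dim 0, so the dot output keeps
# the remaining (non-contracting) axes of each operand, in order:
out_axes = (
    *get_non_contracting_logical_axes(3, x_axes, (2,)),       # ("batch", "length")
    *get_non_contracting_logical_axes(3, kernel_axes, (0,)),  # ("act", "mlp")
)
assert out_axes == ("batch", "length", "act", "mlp")
```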