Unverified Commit 0db0f4d2 authored by Phuong Nguyen, committed by GitHub

[JAX] Fix for GEMM + fuse bias + AllReduce (#2230)



* do not fuse bias for the output all-reduce case + unit tests
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* norm to reduce dgamma along tpsp as well
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* clean up tests
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* fix test_distributed_layernorm byte counts
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* increase tols for jax_gemm
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

---------
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
parent 7e45be73
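A minimal sketch of the issue this commit fixes (illustrative function names, assuming execution under a `shard_map`/`pmap` with a "tp" mesh axis; this is not code from the diff): when the GEMM contracting dimension is sharded, each tensor-parallel rank produces a partial product. If the bias is fused into every partial GEMM, the subsequent all-reduce sums the bias `tp_size` times; the fix is to skip bias fusion whenever the output still needs a reduction and add the bias once afterwards.

```python
# Sketch: bias fusion vs. the output all-reduce (names are illustrative).
import jax
import jax.numpy as jnp

def gemm_allreduce_fused_bias(x_shard, w_shard, bias, axis_name="tp"):
    # Bias fused into every rank's partial GEMM: after the psum the bias is
    # counted tp_size times, i.e. result == x @ w + tp_size * bias.
    partial = jnp.dot(x_shard, w_shard) + bias
    return jax.lax.psum(partial, axis_name)

def gemm_allreduce_unfused_bias(x_shard, w_shard, bias, axis_name="tp"):
    # The fix: compute the partial GEMM without the bias, then add the bias
    # exactly once after the all-reduce.
    partial = jnp.dot(x_shard, w_shard)
    return jax.lax.psum(partial, axis_name) + bias
```

This mirrors what `_sharded_impl` does in the diff below: `sharded_fuse_bias = fuse_bias and reduce_spec is None`, followed by `outputs[0] += bias` after the psum.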
......@@ -17,14 +17,6 @@ from utils import assert_allclose, is_devices_enough
def generate_configs():
configs = []
if is_devices_enough(2):
configs.append(
pytest.param(2, (2,), ("dp",), MeshResource(dp_resource="dp"), id="n2_dp2_tp1")
)
configs.append(
pytest.param(2, (2,), ("tpsp",), MeshResource(tpsp_resource="tpsp"), id="n2_dp1_tp2")
)
if is_devices_enough(4):
configs.append(
pytest.param(
......@@ -32,10 +24,17 @@ def generate_configs():
(2, 2),
("dp", "tpsp"),
MeshResource(dp_resource="dp", tpsp_resource="tpsp"),
id=f"n4_dp2_tp2",
id="n4_dp2_tp2",
)
)
if is_devices_enough(2):
configs.append(
pytest.param(2, (2,), ("dp",), MeshResource(dp_resource="dp"), id="n2_dp2_tp1")
)
configs.append(
pytest.param(2, (2,), ("tpsp",), MeshResource(tpsp_resource="tpsp"), id="n2_dp1_tp2"),
)
return configs
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import unittest
import jax
import jax.numpy as jnp
import numpy as np
from jax import random
from jax.sharding import Mesh, NamedSharding, PartitionSpec
from functools import partial
from distributed_test_base import generate_configs
from utils import assert_allclose, pytest_parametrize_wrapper
import transformer_engine.jax.cpp_extensions as tex
from transformer_engine.jax import fp8_autocast
from transformer_engine.jax.dense import dense
DTYPES = [jnp.bfloat16]
GEMM_INPUT_SHAPES = [[256, 128, 256]] # [batch, seq_len, hidden_in]
WEIGHT_SHAPES = [[256, 256]] # [hidden_in, hidden_out]
def _generate_inputs(input_shape, weight_shape, dtype):
"""Generate test inputs for GEMM operations"""
_, _, hidden_in = input_shape
hidden_in_w, hidden_out = weight_shape
assert hidden_in == hidden_in_w, f"Dimension mismatch: {hidden_in} != {hidden_in_w}"
bias_shape = (hidden_out,)
# Generate random inputs
x = random.normal(random.PRNGKey(1124), input_shape, dtype=dtype)
weight = random.normal(random.PRNGKey(2248), weight_shape, dtype=dtype) / jnp.sqrt(hidden_in_w)
bias = random.normal(random.PRNGKey(3372), bias_shape, dtype=dtype) / jnp.sqrt(hidden_out)
return x, weight, bias
def _get_sharding_for_gemm(mesh, mesh_resource, partition_layout="rowwise"):
"""Get sharding patterns for GEMM inputs and outputs"""
dp_axis = mesh_resource.dp_resource
tp_axis = mesh_resource.tpsp_resource
if partition_layout == "colwise":
x_spec = PartitionSpec(dp_axis, None, None)
weight_spec = PartitionSpec(None, tp_axis)
bias_spec = PartitionSpec(tp_axis)
output_spec = PartitionSpec(dp_axis, None, tp_axis)
elif partition_layout == "rowwise":
x_spec = PartitionSpec(dp_axis, None, tp_axis)
weight_spec = PartitionSpec(tp_axis, None)
bias_spec = PartitionSpec(None)
output_spec = PartitionSpec(dp_axis, None, None)
else:
raise ValueError(f"Invalid partition: {partition_layout}")
x_sharding = NamedSharding(mesh, x_spec)
weight_sharding = NamedSharding(mesh, weight_spec)
bias_sharding = NamedSharding(mesh, bias_spec)
output_sharding = NamedSharding(mesh, output_spec)
return x_sharding, weight_sharding, bias_sharding, output_sharding
@partial(jax.jit, static_argnames=("contracting_dims", "output_sharding"))
def _jitted_gemm(x, weight, bias, contracting_dims, output_sharding):
output = tex.gemm(
x,
weight,
bias=bias,
contracting_dims=contracting_dims,
fuse_bias=True,
)
if output_sharding is not None:
output = jax.lax.with_sharding_constraint(output, output_sharding)
return output
# TODO(Phuong):
# 1. Add supported recipes after FP4 is added
# 2. Add communication type/byte checks
class TestDistributedDense:
"""Test distributed GEMM without collective operations vs JAX dot"""
@pytest_parametrize_wrapper(
"device_count,mesh_shape,mesh_axes,mesh_resource",
generate_configs(),
)
@pytest_parametrize_wrapper("dtype", DTYPES)
@pytest_parametrize_wrapper("input_shape", GEMM_INPUT_SHAPES)
@pytest_parametrize_wrapper("weight_shape", WEIGHT_SHAPES)
@pytest_parametrize_wrapper("partition", ["rowwise", "colwise"])
def test_distributed_gemm(
self,
device_count,
mesh_shape,
mesh_axes,
mesh_resource,
dtype,
input_shape,
weight_shape,
partition,
):
"""Test TE GEMM against JAX dot with bf16 dtype"""
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = Mesh(devices, mesh_axes)
# Generate inputs
x, weight, bias = _generate_inputs(input_shape, weight_shape, dtype)
# Get sharding patterns
x_sharding, weight_sharding, bias_sharding, output_sharding = _get_sharding_for_gemm(
mesh, mesh_resource, partition_layout=partition
)
# Shard inputs
x_sharded = jax.device_put(x, x_sharding)
weight_sharded = jax.device_put(weight, weight_sharding)
bias_sharded = jax.device_put(bias, bias_sharding)
contracting_dims = ((2,), (0,)) # Contract on hidden_in dimension
with mesh, fp8_autocast(enabled=False, mesh_resource=mesh_resource):
# TE GEMM result
te_result = _jitted_gemm(
x_sharded,
weight_sharded,
bias_sharded,
contracting_dims=contracting_dims,
output_sharding=output_sharding,
)
# JAX dot reference result
jax_result = (
jax.lax.dot_general(
x_sharded, weight_sharded, dimension_numbers=(contracting_dims, ((), ()))
)
+ bias_sharded
)
assert te_result.sharding == jax_result.sharding
# Ensure computation is complete
jax.block_until_ready(te_result)
jax.block_until_ready(jax_result)
# Gather results for comparison
gathered_te = jax.lax.with_sharding_constraint(
te_result, NamedSharding(mesh, PartitionSpec(None))
)
gathered_jax = jax.lax.with_sharding_constraint(
jax_result, NamedSharding(mesh, PartitionSpec(None))
)
# Compare results
assert_allclose(gathered_te, gathered_jax, dtype=dtype)
def _te_sum_dense(self, x, weight, bias, contracting_dims):
"""TE GEMM function for gradient testing"""
return jnp.sum(dense(x, weight, bias=bias, contracting_dims=contracting_dims))
def _jax_sum_dense(self, x, weight, bias, contracting_dims):
"""JAX dot function for gradient testing"""
result = (
jax.lax.dot_general(x, weight, dimension_numbers=(contracting_dims, ((), ()))) + bias
)
return jnp.sum(result)
@pytest_parametrize_wrapper(
"device_count,mesh_shape,mesh_axes,mesh_resource",
generate_configs(),
)
@pytest_parametrize_wrapper("dtype", DTYPES)
@pytest_parametrize_wrapper("input_shape", GEMM_INPUT_SHAPES)
@pytest_parametrize_wrapper("weight_shape", WEIGHT_SHAPES)
@pytest_parametrize_wrapper("partition", ["rowwise", "colwise"])
def test_te_distributed_dense_grad(
self,
device_count,
mesh_shape,
mesh_axes,
mesh_resource,
dtype,
input_shape,
weight_shape,
partition,
):
"""Test TE GEMM gradients against JAX dot gradients"""
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = Mesh(devices, mesh_axes)
# Generate inputs
x, weight, bias = _generate_inputs(input_shape, weight_shape, dtype)
# Get sharding patterns
x_sharding, weight_sharding, bias_sharding, output_sharding = _get_sharding_for_gemm(
mesh, mesh_resource, partition_layout=partition
)
x_sharded = jax.device_put(x, x_sharding)
weight_sharded = jax.device_put(weight, weight_sharding)
bias_sharded = jax.device_put(bias, bias_sharding)
contracting_dims = ((2,), (0,))
with mesh, fp8_autocast(enabled=False, mesh_resource=mesh_resource):
# Test gradients w.r.t. all inputs
te_grad_func = jax.jit(
jax.value_and_grad(self._te_sum_dense, argnums=(0, 1, 2)),
static_argnames=("contracting_dims",),
)
jax_grad_func = jax.jit(
jax.value_and_grad(self._jax_sum_dense, argnums=(0, 1, 2)),
static_argnames=("contracting_dims",),
)
te_val, te_grads = te_grad_func(
x_sharded, weight_sharded, bias_sharded, contracting_dims
)
jax_val, jax_grads = jax_grad_func(
x_sharded, weight_sharded, bias_sharded, contracting_dims
)
# Compare forward pass
assert_allclose(te_val, jax_val, dtype=dtype)
# Compare gradients
for i, (te_grad, jax_grad) in enumerate(zip(te_grads, jax_grads)):
te_grad_spec = tuple(i for i in te_grad.sharding.spec if i is not None)
jax_grad_spec = tuple(i for i in jax_grad.sharding.spec if i is not None)
assert te_grad_spec == jax_grad_spec, f"Gradient sharding mismatch at te_grads[{i}]"
gathered_te_grad = jax.lax.with_sharding_constraint(
te_grad, NamedSharding(mesh, PartitionSpec(None))
)
gathered_jax_grad = jax.lax.with_sharding_constraint(
jax_grad, NamedSharding(mesh, PartitionSpec(None))
)
assert_allclose(
gathered_te_grad,
gathered_jax_grad,
dtype=dtype,
err_msg=f"Gradient mismatch for argument {i}",
)
if __name__ == "__main__":
unittest.main()
......@@ -66,18 +66,19 @@ class TestDistributedLayernorm:
self, mesh_resource, ln_type, shape, dtype, mesh_axes, fp8_recipe
):
jax_dtype = jax.dtypes.canonicalize_dtype(dtype)
# TODO(Phuong) is_dp_enabled = dp mesh axis size > 1
is_dp_enabled = mesh_resource.dp_resource is not None
is_tpsp_enabled = mesh_resource.tpsp_resource is not None
assert ln_type in ["layernorm", "rmsnorm"]
all_reduce_loss_bytes = 4 # 1 * FP32
# for loss, dgamma and dbeta
# TODO(Jeremy): debug this check because layernorm should always have 2x weights regardless of dp
weight_count = 2 if (ln_type == "layernorm" and "dp" in mesh_axes) else 1
allreduce_total_bytes = (
all_reduce_loss_bytes + weight_count * shape[-1] * jax_dtype.itemsize
)
other_bytes = 0
# loss, 1 FP32
allreduce_total_bytes = 4 if is_dp_enabled else 0
# dgamma and dbeta
weight_count = 2 if ln_type == "layernorm" else 1
allreduce_total_bytes += weight_count * shape[-1] * jax_dtype.itemsize
return generate_collectives_count(
allreduce=allreduce_total_bytes * int(is_dp_enabled), allgather=0, other=other_bytes
allreduce=allreduce_total_bytes * int(is_dp_enabled or is_tpsp_enabled),
allgather=0,
other=0,
)
@pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
......
......@@ -48,7 +48,7 @@ if is_mxfp8_supported:
SUPPORTED_RECIPES.append(pytest.param(recipe.MXFP8BlockScaling(), id="MXFP8BlockScaling"))
DTYPES = [jnp.bfloat16, jnp.float16]
INPUT_SHAPE = [[4, 64, 128]] # [batch, seqlen, hidden_in]
INPUT_SHAPE = [[4, 128, 256]] # [batch, seqlen, hidden_in]
LAYERNORM_INPUT_AXES = (BATCH_AXES, SEQLEN_TP_AXES, HIDDEN_AXES)
DOT_1_INPUT_AXES = (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES)
......@@ -59,19 +59,47 @@ LN_SCALE_AXES = (W_NO_SHARD_AXES,)
LN_BIAS_AXES = (W_NO_SHARD_AXES,)
BIAS_1_AXES = (W_JOINED_AXES, W_TP_AXES)
BIAS_2_AXES = (W_NO_SHARD_AXES,)
INTERMEDIATE = 64
INTERMEDIATE = 256
# Only test with FSDP and TPSP as DP is not used
def generate_fsdp_and_tpsp_configs():
configs = []
if is_devices_enough(4):
configs.append(
pytest.param(
[
4,
(2, 2),
("fsdp", "tpsp"),
MeshResource(fsdp_resource="fsdp", tpsp_resource="tpsp"),
],
id="fsdp2_tpsp2",
)
)
if is_devices_enough(2):
configs.append(
[2, (1, 2), ("fsdp", "tpsp"), MeshResource(fsdp_resource="fsdp", tpsp_resource="tpsp")]
pytest.param(
[
2,
(1, 2),
("fsdp", "tpsp"),
MeshResource(fsdp_resource="fsdp", tpsp_resource="tpsp"),
],
id="fsdp1_tpsp2",
)
)
if is_devices_enough(4):
configs.append(
[4, (2, 2), ("fsdp", "tpsp"), MeshResource(fsdp_resource="fsdp", tpsp_resource="tpsp")]
pytest.param(
[
2,
(2, 1),
("fsdp", "tpsp"),
MeshResource(fsdp_resource="fsdp", tpsp_resource="tpsp"),
],
id="fsdp2_tpsp1",
),
)
return configs
......@@ -229,10 +257,7 @@ class TestDistributedLayernormMLP:
fwd_test_type = dtype if fp8_recipe is None else jnp.float8_e4m3fn
bwd_test_type = dtype if fp8_recipe is None else jnp.float8_e5m2
if fwd_test_type == jnp.float16 and use_bias:
assert_allclose(multi_fwd, single_fwd, dtype=fwd_test_type, atol=0.04, rtol=1.5)
else:
assert_allclose(multi_fwd, single_fwd, dtype=fwd_test_type)
assert_allclose(multi_fwd, single_fwd, dtype=fwd_test_type)
for i in range(len(inputs)):
if multi_grads[i] is not None:
......@@ -381,6 +406,7 @@ class TestDistributedLayernormMLP:
assert_tree_like_allclose(params_sharded["params"], params_single["params"])
assert_allclose(ln_out_sharded, ln_out_single, dtype=dtype)
# TODO(Phuong): check if these tols updates are still needed
atol = None
rtol = None
l40_tolerance_update = (
......@@ -404,9 +430,10 @@ class TestDistributedLayernormMLP:
# within tolerance to the float32 ground truth.
jax_triton_gemm_precision_tolerance_update = (
with_jax_gemm
and isinstance(fp8_recipe, recipe.Float8CurrentScaling)
and dtype == jnp.bfloat16
and activation_type == ("gelu", "linear")
and fp8_recipe is not None
and (fp8_recipe.delayed() or fp8_recipe.float8_current_scaling())
and dtype in (jnp.bfloat16, jnp.float16)
and activation_type == ("gelu", "linear")
)
if jax_triton_gemm_precision_tolerance_update:
atol = 0.08
......
......@@ -451,23 +451,19 @@ class GemmPrimitive(BasePrimitive):
output = jax.core.ShapedArray(shape=overlap_out_shape, dtype=out_dtype)
# Validate bias
bias_shape = (0,)
bias_dtype = out_dtype
if fuse_bias:
expected_bias_size = reduce(operator.mul, rhs_non_contracting_shape)
if not grad:
assert bias.size == expected_bias_size, (
"cuBLAS GEMM bias tensor has incorrect shape, "
f"expected ({expected_bias_size}, ) but found {bias.shape}."
)
assert bias.dtype == out_dtype, (
"cuBLAS GEMM bias tensor has incorrect data type, "
f"expected {bias_dtype} but found {bias.dtype}."
)
bias_shape = bias.shape
else:
bias_shape = rhs_non_contracting_shape
bias_grad = jax.core.ShapedArray(shape=bias_shape, dtype=bias_dtype)
assert bias.shape == tuple(rhs_non_contracting_shape), (
"cuBLAS GEMM bias tensor has incorrect shape, "
f"expected ({tuple(rhs_non_contracting_shape)}, ) but found {bias.shape}."
)
assert bias.dtype == out_dtype, (
"cuBLAS GEMM bias tensor has incorrect data type, "
f"expected {out_dtype} but found {bias.dtype}."
)
# WAR: allocate dbias regardless of fuse_bias so that the sharding propagation works as we
# change the fuse_bias value in the sharded_impl
dbias_shape = bias.shape if grad else (0,)
bias_grad = jax.core.ShapedArray(shape=dbias_shape, dtype=bias.dtype)
# Validate pre-GeLU
pre_gelu_shape = (0,)
......@@ -548,7 +544,7 @@ class GemmPrimitive(BasePrimitive):
}
operand_output_aliases = {}
if fuse_bias and not grad:
if grad:
operand_output_aliases.update({4: 1}) # bias <-> bias_grad
if fuse_gelu and grad:
operand_output_aliases.update({5: 2}) # gelu_input <-> pre_gelu_out
......@@ -927,7 +923,6 @@ class GemmPrimitive(BasePrimitive):
del (
out_dtype,
scaling_mode,
grad,
use_split_accumulator,
result_infos,
is_outer,
......@@ -941,8 +936,8 @@ class GemmPrimitive(BasePrimitive):
)
out_sharding = NamedSharding(mesh, PartitionSpec(*out_specs))
# Discard bias gradient spec if there is no bias fusion
if not fuse_bias:
# Discard dbias gradient spec if there is no bias and grad fusion
if not (fuse_bias and grad):
dbias_specs = (None,)
dbias_sharding = NamedSharding(mesh, PartitionSpec(*dbias_specs))
......@@ -1008,8 +1003,8 @@ class GemmPrimitive(BasePrimitive):
# Assemble output shardings
out_shardings = [NamedSharding(mesh, PartitionSpec(*out_specs))]
# Discard bias gradient spec if there is no bias fusion
if not fuse_bias:
# Discard bias gradient spec if there is no bias and grad fusion
if not (fuse_bias and grad):
dbias_specs = (None,)
out_shardings.append(NamedSharding(mesh, PartitionSpec(*dbias_specs)))
......@@ -1019,6 +1014,8 @@ class GemmPrimitive(BasePrimitive):
out_shardings.append(NamedSharding(mesh, PartitionSpec(*pre_gelu_specs)))
def _sharded_impl(lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input):
# We should not fuse bias in the output reduction case
sharded_fuse_bias = fuse_bias and reduce_spec is None
outputs = GemmPrimitive.impl(
lhs,
lhs_scale_inv,
......@@ -1029,7 +1026,7 @@ class GemmPrimitive(BasePrimitive):
out_dtype=out_dtype,
contracting_dims=contracting_dims,
scaling_mode=scaling_mode,
fuse_bias=fuse_bias,
fuse_bias=sharded_fuse_bias,
fuse_gelu=fuse_gelu,
grad=grad,
use_split_accumulator=use_split_accumulator,
......@@ -1039,13 +1036,17 @@ class GemmPrimitive(BasePrimitive):
collective_op=collective_op,
)
if reduce_spec is not None and not collective_op.is_reduce_scatter:
if is_all_reduce_in_float32(): # For unittest only
outputs[0] = jax.lax.psum(outputs[0].astype(jnp.float32), reduce_spec).astype(
out_dtype
)
else:
outputs[0] = jax.lax.psum(outputs[0], reduce_spec)
if reduce_spec is not None:
if not collective_op.is_reduce_scatter:
if is_all_reduce_in_float32(): # For unittest only
outputs[0] = jax.lax.psum(
outputs[0].astype(jnp.float32), reduce_spec
).astype(out_dtype)
else:
outputs[0] = jax.lax.psum(outputs[0], reduce_spec)
if fuse_bias: # TODO(Phuong): rename fuse_bias to has_bias
outputs[0] += bias
return outputs
......@@ -1068,7 +1069,7 @@ class GemmPrimitive(BasePrimitive):
operand_types,
result_types,
):
del out_dtype, grad, use_split_accumulator
del out_dtype, use_split_accumulator
del mesh, result_types, transpose_batch_sequence, sequence_dim, is_outer
if not collective_op.is_none:
......@@ -1079,12 +1080,6 @@ class GemmPrimitive(BasePrimitive):
prefix = "Gemm_"
warnings.warn(
"Known issues with TE GemmPrimitives when Shardy propagation is enabled. For now,"
" please turn off Shardy by exporting the environment variable"
" 'JAX_USE_SHARDY_PARTITIONER=0' if you experience any problems."
)
def _generate_operand_rules(name, ndim, cdims):
specs = []
ldims = tuple(i for i in range(ndim) if i not in cdims)
......@@ -1118,7 +1113,8 @@ class GemmPrimitive(BasePrimitive):
rhs_non_cspec = tuple(rhs_specs[i] for i in range(operand_ndims[1]) if i not in rhs_cdims)
out_spec = (*lhs_non_cspec, *rhs_non_cspec)
bias_spec = rhs_non_cspec if fuse_bias else ("…4",)
gelu_spec = out_spec if fuse_gelu else ("…5",)
dbias_spec = bias_spec if grad else ("…5",)
gelu_spec = out_spec if fuse_gelu else ("…6",)
return SdyShardingRule(
operand_mappings=(
......@@ -1131,7 +1127,7 @@ class GemmPrimitive(BasePrimitive):
),
result_mappings=(
out_spec,
bias_spec,
dbias_spec,
gelu_spec,
),
)
......@@ -1161,6 +1157,13 @@ def _te_gemm(
collective_op: CollectiveOp = CollectiveOp.NONE,
) -> Tuple[jax.Array, ...]:
if grad or fuse_gelu:
warnings.warn(
"GEMM + fused grad or fused gelu is not well tested and will be deprecated in the"
" future",
DeprecationWarning,
)
# Prepare non-quantized GEMM operands
lhs_data = lhs
rhs_data = rhs
......@@ -1228,7 +1231,7 @@ def _te_gemm(
grad=grad,
use_split_accumulator=use_split_accumulator,
transpose_batch_sequence=transpose_batch_sequence,
sequence_dim=-1,
sequence_dim=-1, # Dummy value and will be set in the primitive
is_outer=True,
collective_op=collective_op,
)
......@@ -1618,6 +1621,7 @@ def gemm(
rhs_quantizer = quantizer_set.kernel
# Fall back on a native JAX implementation when the custom call to cuBLAS GEMM is disabled
# TODO(Phuong): fuse_bias -> has_bias and has_bias = bias is not None
fuse_bias = kwargs.get("fuse_bias", False)
fuse_gelu = kwargs.get("fuse_gelu", False)
if not GemmPrimitive.enabled():
......
......@@ -28,7 +28,7 @@ from .misc import (
get_cudnn_version,
)
from .quantization import _quantize_dbias_impl, AmaxScope
from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp
from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp_tpsp
from ..quantize import ScaledTensor, ScaledTensorFactory, NoScaleTensor
from ..quantize import (
Quantizer,
......@@ -801,9 +801,9 @@ class NormBwdPrimitive(BasePrimitive):
norm_type=norm_type,
zero_centered_gamma=zero_centered_gamma,
)
global_dgamma = all_reduce_sum_along_dp_fsdp(local_dgamma, mesh)
global_dgamma = all_reduce_sum_along_dp_fsdp_tpsp(local_dgamma, mesh)
if norm_type == NVTE_Norm_Type.LayerNorm:
global_dbeta = all_reduce_sum_along_dp_fsdp(local_dbeta, mesh)
global_dbeta = all_reduce_sum_along_dp_fsdp_tpsp(local_dbeta, mesh)
else:
global_dbeta = local_dbeta
return local_dx, global_dgamma, global_dbeta
......
......@@ -158,18 +158,18 @@ Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_i
// Bias input to forward pass or bias gradient output from backward pass
void *bias_ptr = nullptr;
std::vector<size_t> bias_shape = {0};
size_t bias_size = 0;
DType bias_dtype = out_dtype;
if (fuse_bias) {
if (!grad) {
if (grad) {
NVTE_CHECK(bias_grad->untyped_data() == bias.untyped_data(),
"Missing operand-output aliasing in GemmPrimitive: bias <-> bias_grad");
}
bias_ptr = bias_grad->untyped_data();
bias_shape.at(0) = bias_grad->dimensions().front();
bias_dtype = convert_ffi_datatype_to_te_dtype(bias_grad->element_type());
bias_ptr = bias.untyped_data();
bias_size = product(bias.dimensions());
bias_dtype = convert_ffi_datatype_to_te_dtype(bias.element_type());
}
auto bias_ = TensorWrapper(bias_ptr, bias_shape, bias_dtype);
auto bias_ = TensorWrapper(bias_ptr, std::vector<size_t>{bias_size}, bias_dtype);
// Pre-GeLU output from forward pass or input to backward pass
void *pre_gelu_ptr = nullptr;
......@@ -202,6 +202,8 @@ Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_i
"cuBLAS GEMM output buffer size is incorrect, expected ", out_.numel(), " elements ",
to_string_like(out_shape), " but got ", output->element_count(), " elements ",
to_string_like(output->dimensions()));
NVTE_CHECK(!fuse_bias || bias_size == out_shape[1], "bias_size=", bias_size,
", out_shape[1]=", out_shape[1]);
nvte_cublas_gemm(rhs_.data(), lhs_.data(), out_.data(), bias_.data(), pre_gelu_.data(),
rhs_transposed, lhs_transposed, grad, workspace_.data(), false,
......@@ -220,6 +222,8 @@ Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_i
buffer_shape[1] = out_shape[1];
out_shape[0] = out_shape[0] / comm_handler.tp_size;
}
NVTE_CHECK(!fuse_bias || bias_size == out_shape[1], "bias_size=", bias_size,
", out_shape[1]=", out_shape[1]);
auto executor = CollectiveGemmPlanRegistry::getInstance().get_executor(
buffer_shape, buffer_dtype, collective_op);
if (collective_op == JAXX_Collective_Op::REDUCE_SCATTER) {
......
......@@ -365,6 +365,21 @@ def all_reduce_sum_along_dp_fsdp(x: jnp.array, mesh: jax.sharding.Mesh):
return lax_paral_op(x, jax.lax.psum, global_mesh_resource().fsdp_resource, mesh)
def all_reduce_sum_along_dp_fsdp_tpsp(x: jnp.array, mesh: jax.sharding.Mesh):
"""Perform all-reduce sum operation along data parallelism and sequence parallelism axes.
Args:
x: Input tensor to reduce
mesh: JAX mesh for distributed computation
Returns:
Reduced tensor
"""
x = lax_paral_op(x, jax.lax.psum, global_mesh_resource().tpsp_resource, mesh)
x = lax_paral_op(x, jax.lax.psum, global_mesh_resource().dp_resource, mesh)
return lax_paral_op(x, jax.lax.psum, global_mesh_resource().fsdp_resource, mesh)
def all_reduce_max_along_all_axes_except_PP(x: jnp.array, mesh: jax.sharding.Mesh):
"""Perform all-reduce max operation along all axes except pipeline parallelism.
......