"docs/vscode:/vscode.git/clone" did not exist on "d06980dfa7e8dcc1738656beb46d3735c86faa21"
Unverified commit 0db0f4d2, authored by Phuong Nguyen, committed by GitHub
Browse files

[JAX] Fix for GEMM + fuse bias + AllReduce (#2230)



* do not fuse bias in the output all-reduce case + unit tests
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* norm to reduce dgamma along tpsp as well
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* clean up tests
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* fix test_distributed_layernorm byte counts
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* increase tolerances for jax_gemm
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

---------
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
parent 7e45be73
...@@ -17,14 +17,6 @@ from utils import assert_allclose, is_devices_enough ...@@ -17,14 +17,6 @@ from utils import assert_allclose, is_devices_enough
def generate_configs(): def generate_configs():
configs = [] configs = []
if is_devices_enough(2):
configs.append(
pytest.param(2, (2,), ("dp",), MeshResource(dp_resource="dp"), id="n2_dp2_tp1")
)
configs.append(
pytest.param(2, (2,), ("tpsp",), MeshResource(tpsp_resource="tpsp"), id="n2_dp1_tp2")
)
if is_devices_enough(4): if is_devices_enough(4):
configs.append( configs.append(
pytest.param( pytest.param(
...@@ -32,10 +24,17 @@ def generate_configs(): ...@@ -32,10 +24,17 @@ def generate_configs():
(2, 2), (2, 2),
("dp", "tpsp"), ("dp", "tpsp"),
MeshResource(dp_resource="dp", tpsp_resource="tpsp"), MeshResource(dp_resource="dp", tpsp_resource="tpsp"),
id=f"n4_dp2_tp2", id="n4_dp2_tp2",
) )
) )
if is_devices_enough(2):
configs.append(
pytest.param(2, (2,), ("dp",), MeshResource(dp_resource="dp"), id="n2_dp2_tp1")
)
configs.append(
pytest.param(2, (2,), ("tpsp",), MeshResource(tpsp_resource="tpsp"), id="n2_dp1_tp2"),
)
return configs return configs
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import unittest
import jax
import jax.numpy as jnp
import numpy as np
from jax import random
from jax.sharding import Mesh, NamedSharding, PartitionSpec
from functools import partial
from distributed_test_base import generate_configs
from utils import assert_allclose, pytest_parametrize_wrapper
import transformer_engine.jax.cpp_extensions as tex
from transformer_engine.jax import fp8_autocast
from transformer_engine.jax.dense import dense
DTYPES = [jnp.bfloat16]
GEMM_INPUT_SHAPES = [[256, 128, 256]] # [batch, seq_len, hidden_in]
WEIGHT_SHAPES = [[256, 256]] # [hidden_in, hidden_out]
def _generate_inputs(input_shape, weight_shape, dtype):
"""Generate test inputs for GEMM operations"""
_, _, hidden_in = input_shape
hidden_in_w, hidden_out = weight_shape
assert hidden_in == hidden_in_w, f"Dimension mismatch: {hidden_in} != {hidden_in_w}"
bias_shape = (hidden_out,)
# Generate random inputs
x = random.normal(random.PRNGKey(1124), input_shape, dtype=dtype)
weight = random.normal(random.PRNGKey(2248), weight_shape, dtype=dtype) / jnp.sqrt(hidden_in_w)
bias = random.normal(random.PRNGKey(3372), bias_shape, dtype=dtype) / jnp.sqrt(hidden_out)
return x, weight, bias
def _get_sharding_for_gemm(mesh, mesh_resource, partition_layout="rowwise"):
"""Get sharding patterns for GEMM inputs and outputs"""
dp_axis = mesh_resource.dp_resource
tp_axis = mesh_resource.tpsp_resource
if partition_layout == "colwise":
x_spec = PartitionSpec(dp_axis, None, None)
weight_spec = PartitionSpec(None, tp_axis)
bias_spec = PartitionSpec(tp_axis)
output_spec = PartitionSpec(dp_axis, None, tp_axis)
elif partition_layout == "rowwise":
x_spec = PartitionSpec(dp_axis, None, tp_axis)
weight_spec = PartitionSpec(tp_axis, None)
bias_spec = PartitionSpec(None)
output_spec = PartitionSpec(dp_axis, None, None)
else:
raise ValueError(f"Invalid partition: {partition_layout}")
x_sharding = NamedSharding(mesh, x_spec)
weight_sharding = NamedSharding(mesh, weight_spec)
bias_sharding = NamedSharding(mesh, bias_spec)
output_sharding = NamedSharding(mesh, output_spec)
return x_sharding, weight_sharding, bias_sharding, output_sharding
@partial(jax.jit, static_argnames=("contracting_dims", "output_sharding"))
def _jitted_gemm(x, weight, bias, contracting_dims, output_sharding):
    """Run the TE GEMM with fused bias under jit, optionally pinning the
    output to `output_sharding` via a sharding constraint."""
    result = tex.gemm(
        x,
        weight,
        bias=bias,
        contracting_dims=contracting_dims,
        fuse_bias=True,
    )
    if output_sharding is None:
        return result
    return jax.lax.with_sharding_constraint(result, output_sharding)
# TODO(Phuong):
# 1. Add supported recipes after FP4 is added
# 2. Add communication type/byte checks
class TestDistributedDense:
    """Test distributed GEMM without collective operations vs JAX dot"""

    @pytest_parametrize_wrapper(
        "device_count,mesh_shape,mesh_axes,mesh_resource",
        generate_configs(),
    )
    @pytest_parametrize_wrapper("dtype", DTYPES)
    @pytest_parametrize_wrapper("input_shape", GEMM_INPUT_SHAPES)
    @pytest_parametrize_wrapper("weight_shape", WEIGHT_SHAPES)
    @pytest_parametrize_wrapper("partition", ["rowwise", "colwise"])
    def test_distributed_gemm(
        self,
        device_count,
        mesh_shape,
        mesh_axes,
        mesh_resource,
        dtype,
        input_shape,
        weight_shape,
        partition,
    ):
        """Test TE GEMM against JAX dot with bf16 dtype

        Forward-only check: runs the TE GEMM (fused bias) and a reference
        `jax.lax.dot_general` + bias on identically sharded inputs, then
        asserts both the output sharding and the gathered values agree.
        """
        # Build a device mesh of the requested shape from the first
        # `device_count` available devices.
        devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
        mesh = Mesh(devices, mesh_axes)
        # Generate inputs (deterministic — fixed PRNG keys in the helper)
        x, weight, bias = _generate_inputs(input_shape, weight_shape, dtype)
        # Get sharding patterns for the chosen rowwise/colwise partition
        x_sharding, weight_sharding, bias_sharding, output_sharding = _get_sharding_for_gemm(
            mesh, mesh_resource, partition_layout=partition
        )
        # Shard inputs
        x_sharded = jax.device_put(x, x_sharding)
        weight_sharded = jax.device_put(weight, weight_sharding)
        bias_sharded = jax.device_put(bias, bias_sharding)
        contracting_dims = ((2,), (0,))  # Contract on hidden_in dimension
        with mesh, fp8_autocast(enabled=False, mesh_resource=mesh_resource):
            # TE GEMM result
            te_result = _jitted_gemm(
                x_sharded,
                weight_sharded,
                bias_sharded,
                contracting_dims=contracting_dims,
                output_sharding=output_sharding,
            )
            # JAX dot reference result
            jax_result = (
                jax.lax.dot_general(
                    x_sharded, weight_sharded, dimension_numbers=(contracting_dims, ((), ()))
                )
                + bias_sharded
            )
            # Both implementations must land on the same output sharding.
            assert te_result.sharding == jax_result.sharding
            # Ensure computation is complete
            jax.block_until_ready(te_result)
            jax.block_until_ready(jax_result)
            # Gather (fully replicate) results for element-wise comparison
            gathered_te = jax.lax.with_sharding_constraint(
                te_result, NamedSharding(mesh, PartitionSpec(None))
            )
            gathered_jax = jax.lax.with_sharding_constraint(
                jax_result, NamedSharding(mesh, PartitionSpec(None))
            )
            # Compare results (tolerances chosen by dtype inside assert_allclose)
            assert_allclose(gathered_te, gathered_jax, dtype=dtype)

    def _te_sum_dense(self, x, weight, bias, contracting_dims):
        """TE GEMM function for gradient testing"""
        # Scalar loss so value_and_grad yields gradients for all operands.
        return jnp.sum(dense(x, weight, bias=bias, contracting_dims=contracting_dims))

    def _jax_sum_dense(self, x, weight, bias, contracting_dims):
        """JAX dot function for gradient testing"""
        # Reference scalar loss mirroring _te_sum_dense with plain dot_general.
        result = (
            jax.lax.dot_general(x, weight, dimension_numbers=(contracting_dims, ((), ()))) + bias
        )
        return jnp.sum(result)

    @pytest_parametrize_wrapper(
        "device_count,mesh_shape,mesh_axes,mesh_resource",
        generate_configs(),
    )
    @pytest_parametrize_wrapper("dtype", DTYPES)
    @pytest_parametrize_wrapper("input_shape", GEMM_INPUT_SHAPES)
    @pytest_parametrize_wrapper("weight_shape", WEIGHT_SHAPES)
    @pytest_parametrize_wrapper("partition", ["rowwise", "colwise"])
    def test_te_distributed_dense_grad(
        self,
        device_count,
        mesh_shape,
        mesh_axes,
        mesh_resource,
        dtype,
        input_shape,
        weight_shape,
        partition,
    ):
        """Test TE GEMM gradients against JAX dot gradients"""
        devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
        mesh = Mesh(devices, mesh_axes)
        # Generate inputs
        x, weight, bias = _generate_inputs(input_shape, weight_shape, dtype)
        # Get sharding patterns
        x_sharding, weight_sharding, bias_sharding, output_sharding = _get_sharding_for_gemm(
            mesh, mesh_resource, partition_layout=partition
        )
        x_sharded = jax.device_put(x, x_sharding)
        weight_sharded = jax.device_put(weight, weight_sharding)
        bias_sharded = jax.device_put(bias, bias_sharding)
        contracting_dims = ((2,), (0,))
        with mesh, fp8_autocast(enabled=False, mesh_resource=mesh_resource):
            # Test gradients w.r.t. all inputs
            # contracting_dims is a hashable tuple, so it can be a static arg.
            te_grad_func = jax.jit(
                jax.value_and_grad(self._te_sum_dense, argnums=(0, 1, 2)),
                static_argnames=("contracting_dims",),
            )
            jax_grad_func = jax.jit(
                jax.value_and_grad(self._jax_sum_dense, argnums=(0, 1, 2)),
                static_argnames=("contracting_dims",),
            )
            te_val, te_grads = te_grad_func(
                x_sharded, weight_sharded, bias_sharded, contracting_dims
            )
            jax_val, jax_grads = jax_grad_func(
                x_sharded, weight_sharded, bias_sharded, contracting_dims
            )
            # Compare forward pass
            assert_allclose(te_val, jax_val, dtype=dtype)
            # Compare gradients for (x, weight, bias) in order
            for i, (te_grad, jax_grad) in enumerate(zip(te_grads, jax_grads)):
                # Compare only the concrete (non-None) mesh axes of each spec,
                # since trailing/implicit None entries may differ in length.
                te_grad_spec = tuple(i for i in te_grad.sharding.spec if i is not None)
                jax_grad_spec = tuple(i for i in jax_grad.sharding.spec if i is not None)
                assert te_grad_spec == jax_grad_spec, f"Gradient sharding mismatch at te_grads[{i}]"
                # Fully replicate both gradients before value comparison.
                gathered_te_grad = jax.lax.with_sharding_constraint(
                    te_grad, NamedSharding(mesh, PartitionSpec(None))
                )
                gathered_jax_grad = jax.lax.with_sharding_constraint(
                    jax_grad, NamedSharding(mesh, PartitionSpec(None))
                )
                assert_allclose(
                    gathered_te_grad,
                    gathered_jax_grad,
                    dtype=dtype,
                    err_msg=f"Gradient mismatch for argument {i}",
                )
if __name__ == "__main__":
    # NOTE(review): the test class above is pytest-style (parametrize wrappers,
    # no unittest.TestCase base), so unittest.main() will not collect it —
    # presumably this file is run via pytest; confirm this guard is intentional.
    unittest.main()
...@@ -66,18 +66,19 @@ class TestDistributedLayernorm: ...@@ -66,18 +66,19 @@ class TestDistributedLayernorm:
self, mesh_resource, ln_type, shape, dtype, mesh_axes, fp8_recipe self, mesh_resource, ln_type, shape, dtype, mesh_axes, fp8_recipe
): ):
jax_dtype = jax.dtypes.canonicalize_dtype(dtype) jax_dtype = jax.dtypes.canonicalize_dtype(dtype)
# TODO(Phuong) is_dp_enabled = dp mesh axis size > 1
is_dp_enabled = mesh_resource.dp_resource is not None is_dp_enabled = mesh_resource.dp_resource is not None
is_tpsp_enabled = mesh_resource.tpsp_resource is not None
assert ln_type in ["layernorm", "rmsnorm"] assert ln_type in ["layernorm", "rmsnorm"]
all_reduce_loss_bytes = 4 # 1 * FP32 # loss, 1 FP32
# for loss, dgamma and dbeta allreduce_total_bytes = 4 if is_dp_enabled else 0
# TODO(Jeremy): debug this check because layernorm should always have 2x weights regardless of dp # dgamma and dbeta
weight_count = 2 if (ln_type == "layernorm" and "dp" in mesh_axes) else 1 weight_count = 2 if ln_type == "layernorm" else 1
allreduce_total_bytes = ( allreduce_total_bytes += weight_count * shape[-1] * jax_dtype.itemsize
all_reduce_loss_bytes + weight_count * shape[-1] * jax_dtype.itemsize
)
other_bytes = 0
return generate_collectives_count( return generate_collectives_count(
allreduce=allreduce_total_bytes * int(is_dp_enabled), allgather=0, other=other_bytes allreduce=allreduce_total_bytes * int(is_dp_enabled or is_tpsp_enabled),
allgather=0,
other=0,
) )
@pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs()) @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
......
...@@ -48,7 +48,7 @@ if is_mxfp8_supported: ...@@ -48,7 +48,7 @@ if is_mxfp8_supported:
SUPPORTED_RECIPES.append(pytest.param(recipe.MXFP8BlockScaling(), id="MXFP8BlockScaling")) SUPPORTED_RECIPES.append(pytest.param(recipe.MXFP8BlockScaling(), id="MXFP8BlockScaling"))
DTYPES = [jnp.bfloat16, jnp.float16] DTYPES = [jnp.bfloat16, jnp.float16]
INPUT_SHAPE = [[4, 64, 128]] # [batch, seqlen, hidden_in] INPUT_SHAPE = [[4, 128, 256]] # [batch, seqlen, hidden_in]
LAYERNORM_INPUT_AXES = (BATCH_AXES, SEQLEN_TP_AXES, HIDDEN_AXES) LAYERNORM_INPUT_AXES = (BATCH_AXES, SEQLEN_TP_AXES, HIDDEN_AXES)
DOT_1_INPUT_AXES = (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES) DOT_1_INPUT_AXES = (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES)
...@@ -59,19 +59,47 @@ LN_SCALE_AXES = (W_NO_SHARD_AXES,) ...@@ -59,19 +59,47 @@ LN_SCALE_AXES = (W_NO_SHARD_AXES,)
LN_BIAS_AXES = (W_NO_SHARD_AXES,) LN_BIAS_AXES = (W_NO_SHARD_AXES,)
BIAS_1_AXES = (W_JOINED_AXES, W_TP_AXES) BIAS_1_AXES = (W_JOINED_AXES, W_TP_AXES)
BIAS_2_AXES = (W_NO_SHARD_AXES,) BIAS_2_AXES = (W_NO_SHARD_AXES,)
INTERMEDIATE = 64 INTERMEDIATE = 256
# Only test with FSDP and TPSP as DP is not used # Only test with FSDP and TPSP as DP is not used
def generate_fsdp_and_tpsp_configs(): def generate_fsdp_and_tpsp_configs():
configs = [] configs = []
if is_devices_enough(4):
configs.append(
pytest.param(
[
4,
(2, 2),
("fsdp", "tpsp"),
MeshResource(fsdp_resource="fsdp", tpsp_resource="tpsp"),
],
id="fsdp2_tpsp2",
)
)
if is_devices_enough(2): if is_devices_enough(2):
configs.append( configs.append(
[2, (1, 2), ("fsdp", "tpsp"), MeshResource(fsdp_resource="fsdp", tpsp_resource="tpsp")] pytest.param(
[
2,
(1, 2),
("fsdp", "tpsp"),
MeshResource(fsdp_resource="fsdp", tpsp_resource="tpsp"),
],
id="fsdp1_tpsp2",
)
) )
if is_devices_enough(4):
configs.append( configs.append(
[4, (2, 2), ("fsdp", "tpsp"), MeshResource(fsdp_resource="fsdp", tpsp_resource="tpsp")] pytest.param(
[
2,
(2, 1),
("fsdp", "tpsp"),
MeshResource(fsdp_resource="fsdp", tpsp_resource="tpsp"),
],
id="fsdp2_tpsp1",
),
) )
return configs return configs
...@@ -229,9 +257,6 @@ class TestDistributedLayernormMLP: ...@@ -229,9 +257,6 @@ class TestDistributedLayernormMLP:
fwd_test_type = dtype if fp8_recipe is None else jnp.float8_e4m3fn fwd_test_type = dtype if fp8_recipe is None else jnp.float8_e4m3fn
bwd_test_type = dtype if fp8_recipe is None else jnp.float8_e5m2 bwd_test_type = dtype if fp8_recipe is None else jnp.float8_e5m2
if fwd_test_type == jnp.float16 and use_bias:
assert_allclose(multi_fwd, single_fwd, dtype=fwd_test_type, atol=0.04, rtol=1.5)
else:
assert_allclose(multi_fwd, single_fwd, dtype=fwd_test_type) assert_allclose(multi_fwd, single_fwd, dtype=fwd_test_type)
for i in range(len(inputs)): for i in range(len(inputs)):
...@@ -381,6 +406,7 @@ class TestDistributedLayernormMLP: ...@@ -381,6 +406,7 @@ class TestDistributedLayernormMLP:
assert_tree_like_allclose(params_sharded["params"], params_single["params"]) assert_tree_like_allclose(params_sharded["params"], params_single["params"])
assert_allclose(ln_out_sharded, ln_out_single, dtype=dtype) assert_allclose(ln_out_sharded, ln_out_single, dtype=dtype)
# TODO(Phuong): check if these tols updates are still needed
atol = None atol = None
rtol = None rtol = None
l40_tolerance_update = ( l40_tolerance_update = (
...@@ -404,9 +430,10 @@ class TestDistributedLayernormMLP: ...@@ -404,9 +430,10 @@ class TestDistributedLayernormMLP:
# within tolerance to the float32 ground truth. # within tolerance to the float32 ground truth.
jax_triton_gemm_precision_tolerance_update = ( jax_triton_gemm_precision_tolerance_update = (
with_jax_gemm with_jax_gemm
and isinstance(fp8_recipe, recipe.Float8CurrentScaling) and fp8_recipe is not None
and dtype == jnp.bfloat16 and (fp8_recipe.delayed() or fp8_recipe.float8_current_scaling())
and activation_type == ("gelu", "linear") and dtype in (jnp.bfloat16, jnp.float16)
and activation_type == ("gelu", "linear"),
) )
if jax_triton_gemm_precision_tolerance_update: if jax_triton_gemm_precision_tolerance_update:
atol = 0.08 atol = 0.08
......
...@@ -451,23 +451,19 @@ class GemmPrimitive(BasePrimitive): ...@@ -451,23 +451,19 @@ class GemmPrimitive(BasePrimitive):
output = jax.core.ShapedArray(shape=overlap_out_shape, dtype=out_dtype) output = jax.core.ShapedArray(shape=overlap_out_shape, dtype=out_dtype)
# Validate bias # Validate bias
bias_shape = (0,)
bias_dtype = out_dtype
if fuse_bias: if fuse_bias:
expected_bias_size = reduce(operator.mul, rhs_non_contracting_shape) assert bias.shape == tuple(rhs_non_contracting_shape), (
if not grad:
assert bias.size == expected_bias_size, (
"cuBLAS GEMM bias tensor has incorrect shape, " "cuBLAS GEMM bias tensor has incorrect shape, "
f"expected ({expected_bias_size}, ) but found {bias.shape}." f"expected ({tuple(rhs_non_contracting_shape)}, ) but found {bias.shape}."
) )
assert bias.dtype == out_dtype, ( assert bias.dtype == out_dtype, (
"cuBLAS GEMM bias tensor has incorrect data type, " "cuBLAS GEMM bias tensor has incorrect data type, "
f"expected {bias_dtype} but found {bias.dtype}." f"expected {out_dtype} but found {bias.dtype}."
) )
bias_shape = bias.shape # WAR: allocate dbias regardless of fuse_bias so that the sharding propagation works as we
else: # change the fuse_bias value in the sharded_impl
bias_shape = rhs_non_contracting_shape dbias_shape = bias.shape if grad else (0,)
bias_grad = jax.core.ShapedArray(shape=bias_shape, dtype=bias_dtype) bias_grad = jax.core.ShapedArray(shape=dbias_shape, dtype=bias.dtype)
# Validate pre-GeLU # Validate pre-GeLU
pre_gelu_shape = (0,) pre_gelu_shape = (0,)
...@@ -548,7 +544,7 @@ class GemmPrimitive(BasePrimitive): ...@@ -548,7 +544,7 @@ class GemmPrimitive(BasePrimitive):
} }
operand_output_aliases = {} operand_output_aliases = {}
if fuse_bias and not grad: if grad:
operand_output_aliases.update({4: 1}) # bias <-> bias_grad operand_output_aliases.update({4: 1}) # bias <-> bias_grad
if fuse_gelu and grad: if fuse_gelu and grad:
operand_output_aliases.update({5: 2}) # gelu_input <-> pre_gelu_out operand_output_aliases.update({5: 2}) # gelu_input <-> pre_gelu_out
...@@ -927,7 +923,6 @@ class GemmPrimitive(BasePrimitive): ...@@ -927,7 +923,6 @@ class GemmPrimitive(BasePrimitive):
del ( del (
out_dtype, out_dtype,
scaling_mode, scaling_mode,
grad,
use_split_accumulator, use_split_accumulator,
result_infos, result_infos,
is_outer, is_outer,
...@@ -941,8 +936,8 @@ class GemmPrimitive(BasePrimitive): ...@@ -941,8 +936,8 @@ class GemmPrimitive(BasePrimitive):
) )
out_sharding = NamedSharding(mesh, PartitionSpec(*out_specs)) out_sharding = NamedSharding(mesh, PartitionSpec(*out_specs))
# Discard bias gradient spec if there is no bias fusion # Discard dbias gradient spec if there is no bias and grad fusion
if not fuse_bias: if not (fuse_bias and grad):
dbias_specs = (None,) dbias_specs = (None,)
dbias_sharding = NamedSharding(mesh, PartitionSpec(*dbias_specs)) dbias_sharding = NamedSharding(mesh, PartitionSpec(*dbias_specs))
...@@ -1008,8 +1003,8 @@ class GemmPrimitive(BasePrimitive): ...@@ -1008,8 +1003,8 @@ class GemmPrimitive(BasePrimitive):
# Assemble output shardings # Assemble output shardings
out_shardings = [NamedSharding(mesh, PartitionSpec(*out_specs))] out_shardings = [NamedSharding(mesh, PartitionSpec(*out_specs))]
# Discard bias gradient spec if there is no bias fusion # Discard bias gradient spec if there is no bias and grad fusion
if not fuse_bias: if not (fuse_bias and grad):
dbias_specs = (None,) dbias_specs = (None,)
out_shardings.append(NamedSharding(mesh, PartitionSpec(*dbias_specs))) out_shardings.append(NamedSharding(mesh, PartitionSpec(*dbias_specs)))
...@@ -1019,6 +1014,8 @@ class GemmPrimitive(BasePrimitive): ...@@ -1019,6 +1014,8 @@ class GemmPrimitive(BasePrimitive):
out_shardings.append(NamedSharding(mesh, PartitionSpec(*pre_gelu_specs))) out_shardings.append(NamedSharding(mesh, PartitionSpec(*pre_gelu_specs)))
def _sharded_impl(lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input): def _sharded_impl(lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input):
# We should not fuse bias in the output reduction case
sharded_fuse_bias = fuse_bias and reduce_spec is None
outputs = GemmPrimitive.impl( outputs = GemmPrimitive.impl(
lhs, lhs,
lhs_scale_inv, lhs_scale_inv,
...@@ -1029,7 +1026,7 @@ class GemmPrimitive(BasePrimitive): ...@@ -1029,7 +1026,7 @@ class GemmPrimitive(BasePrimitive):
out_dtype=out_dtype, out_dtype=out_dtype,
contracting_dims=contracting_dims, contracting_dims=contracting_dims,
scaling_mode=scaling_mode, scaling_mode=scaling_mode,
fuse_bias=fuse_bias, fuse_bias=sharded_fuse_bias,
fuse_gelu=fuse_gelu, fuse_gelu=fuse_gelu,
grad=grad, grad=grad,
use_split_accumulator=use_split_accumulator, use_split_accumulator=use_split_accumulator,
...@@ -1039,14 +1036,18 @@ class GemmPrimitive(BasePrimitive): ...@@ -1039,14 +1036,18 @@ class GemmPrimitive(BasePrimitive):
collective_op=collective_op, collective_op=collective_op,
) )
if reduce_spec is not None and not collective_op.is_reduce_scatter: if reduce_spec is not None:
if not collective_op.is_reduce_scatter:
if is_all_reduce_in_float32(): # For unittest only if is_all_reduce_in_float32(): # For unittest only
outputs[0] = jax.lax.psum(outputs[0].astype(jnp.float32), reduce_spec).astype( outputs[0] = jax.lax.psum(
out_dtype outputs[0].astype(jnp.float32), reduce_spec
) ).astype(out_dtype)
else: else:
outputs[0] = jax.lax.psum(outputs[0], reduce_spec) outputs[0] = jax.lax.psum(outputs[0], reduce_spec)
if fuse_bias: # TODO(Phuong): rename fuse_bias to has_bias
outputs[0] += bias
return outputs return outputs
return mesh, _sharded_impl, out_shardings, arg_shardings return mesh, _sharded_impl, out_shardings, arg_shardings
...@@ -1068,7 +1069,7 @@ class GemmPrimitive(BasePrimitive): ...@@ -1068,7 +1069,7 @@ class GemmPrimitive(BasePrimitive):
operand_types, operand_types,
result_types, result_types,
): ):
del out_dtype, grad, use_split_accumulator del out_dtype, use_split_accumulator
del mesh, result_types, transpose_batch_sequence, sequence_dim, is_outer del mesh, result_types, transpose_batch_sequence, sequence_dim, is_outer
if not collective_op.is_none: if not collective_op.is_none:
...@@ -1079,12 +1080,6 @@ class GemmPrimitive(BasePrimitive): ...@@ -1079,12 +1080,6 @@ class GemmPrimitive(BasePrimitive):
prefix = "Gemm_" prefix = "Gemm_"
warnings.warn(
"Known issues with TE GemmPrimitives when Shardy propagation is enabled. For now,"
" please turn off Shardy by exporting the environment variable"
" 'JAX_USE_SHARDY_PARTITIONER=0' if you experience any problems."
)
def _generate_operand_rules(name, ndim, cdims): def _generate_operand_rules(name, ndim, cdims):
specs = [] specs = []
ldims = tuple(i for i in range(ndim) if i not in cdims) ldims = tuple(i for i in range(ndim) if i not in cdims)
...@@ -1118,7 +1113,8 @@ class GemmPrimitive(BasePrimitive): ...@@ -1118,7 +1113,8 @@ class GemmPrimitive(BasePrimitive):
rhs_non_cspec = tuple(rhs_specs[i] for i in range(operand_ndims[1]) if i not in rhs_cdims) rhs_non_cspec = tuple(rhs_specs[i] for i in range(operand_ndims[1]) if i not in rhs_cdims)
out_spec = (*lhs_non_cspec, *rhs_non_cspec) out_spec = (*lhs_non_cspec, *rhs_non_cspec)
bias_spec = rhs_non_cspec if fuse_bias else ("…4",) bias_spec = rhs_non_cspec if fuse_bias else ("…4",)
gelu_spec = out_spec if fuse_gelu else ("…5",) dbias_spec = bias_spec if grad else ("…5")
gelu_spec = out_spec if fuse_gelu else ("…6",)
return SdyShardingRule( return SdyShardingRule(
operand_mappings=( operand_mappings=(
...@@ -1131,7 +1127,7 @@ class GemmPrimitive(BasePrimitive): ...@@ -1131,7 +1127,7 @@ class GemmPrimitive(BasePrimitive):
), ),
result_mappings=( result_mappings=(
out_spec, out_spec,
bias_spec, dbias_spec,
gelu_spec, gelu_spec,
), ),
) )
...@@ -1161,6 +1157,13 @@ def _te_gemm( ...@@ -1161,6 +1157,13 @@ def _te_gemm(
collective_op: CollectiveOp = CollectiveOp.NONE, collective_op: CollectiveOp = CollectiveOp.NONE,
) -> Tuple[jax.Array, ...]: ) -> Tuple[jax.Array, ...]:
if grad or fuse_gelu:
warnings.warn(
"GEMM + fused grad or fused gelu is not well tested and will be deprecated in the"
" future",
DeprecationWarning,
)
# Prepare non-quantized GEMM operands # Prepare non-quantized GEMM operands
lhs_data = lhs lhs_data = lhs
rhs_data = rhs rhs_data = rhs
...@@ -1228,7 +1231,7 @@ def _te_gemm( ...@@ -1228,7 +1231,7 @@ def _te_gemm(
grad=grad, grad=grad,
use_split_accumulator=use_split_accumulator, use_split_accumulator=use_split_accumulator,
transpose_batch_sequence=transpose_batch_sequence, transpose_batch_sequence=transpose_batch_sequence,
sequence_dim=-1, sequence_dim=-1, # Dummy value and will be set in the primitive
is_outer=True, is_outer=True,
collective_op=collective_op, collective_op=collective_op,
) )
...@@ -1618,6 +1621,7 @@ def gemm( ...@@ -1618,6 +1621,7 @@ def gemm(
rhs_quantizer = quantizer_set.kernel rhs_quantizer = quantizer_set.kernel
# Fall back on a native JAX implementation when the custom call to cuBLAS GEMM is disabled # Fall back on a native JAX implementation when the custom call to cuBLAS GEMM is disabled
# TODO(Phuong): fuse_bias -> has_bias and has_bias = bias is not None
fuse_bias = kwargs.get("fuse_bias", False) fuse_bias = kwargs.get("fuse_bias", False)
fuse_gelu = kwargs.get("fuse_gelu", False) fuse_gelu = kwargs.get("fuse_gelu", False)
if not GemmPrimitive.enabled(): if not GemmPrimitive.enabled():
......
...@@ -28,7 +28,7 @@ from .misc import ( ...@@ -28,7 +28,7 @@ from .misc import (
get_cudnn_version, get_cudnn_version,
) )
from .quantization import _quantize_dbias_impl, AmaxScope from .quantization import _quantize_dbias_impl, AmaxScope
from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp_tpsp
from ..quantize import ScaledTensor, ScaledTensorFactory, NoScaleTensor from ..quantize import ScaledTensor, ScaledTensorFactory, NoScaleTensor
from ..quantize import ( from ..quantize import (
Quantizer, Quantizer,
...@@ -801,9 +801,9 @@ class NormBwdPrimitive(BasePrimitive): ...@@ -801,9 +801,9 @@ class NormBwdPrimitive(BasePrimitive):
norm_type=norm_type, norm_type=norm_type,
zero_centered_gamma=zero_centered_gamma, zero_centered_gamma=zero_centered_gamma,
) )
global_dgamma = all_reduce_sum_along_dp_fsdp(local_dgamma, mesh) global_dgamma = all_reduce_sum_along_dp_fsdp_tpsp(local_dgamma, mesh)
if norm_type == NVTE_Norm_Type.LayerNorm: if norm_type == NVTE_Norm_Type.LayerNorm:
global_dbeta = all_reduce_sum_along_dp_fsdp(local_dbeta, mesh) global_dbeta = all_reduce_sum_along_dp_fsdp_tpsp(local_dbeta, mesh)
else: else:
global_dbeta = local_dbeta global_dbeta = local_dbeta
return local_dx, global_dgamma, global_dbeta return local_dx, global_dgamma, global_dbeta
......
...@@ -158,18 +158,18 @@ Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_i ...@@ -158,18 +158,18 @@ Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_i
// Bias input to forward pass or bias gradient output from backward pass // Bias input to forward pass or bias gradient output from backward pass
void *bias_ptr = nullptr; void *bias_ptr = nullptr;
std::vector<size_t> bias_shape = {0}; size_t bias_size = 0;
DType bias_dtype = out_dtype; DType bias_dtype = out_dtype;
if (fuse_bias) { if (fuse_bias) {
if (!grad) { if (grad) {
NVTE_CHECK(bias_grad->untyped_data() == bias.untyped_data(), NVTE_CHECK(bias_grad->untyped_data() == bias.untyped_data(),
"Missing operand-output aliasing in GemmPrimitive: bias <-> bias_grad"); "Missing operand-output aliasing in GemmPrimitive: bias <-> bias_grad");
} }
bias_ptr = bias_grad->untyped_data(); bias_ptr = bias.untyped_data();
bias_shape.at(0) = bias_grad->dimensions().front(); bias_size = product(bias.dimensions());
bias_dtype = convert_ffi_datatype_to_te_dtype(bias_grad->element_type()); bias_dtype = convert_ffi_datatype_to_te_dtype(bias.element_type());
} }
auto bias_ = TensorWrapper(bias_ptr, bias_shape, bias_dtype); auto bias_ = TensorWrapper(bias_ptr, std::vector<size_t>{bias_size}, bias_dtype);
// Pre-GeLU output from forward pass or input to backward pass // Pre-GeLU output from forward pass or input to backward pass
void *pre_gelu_ptr = nullptr; void *pre_gelu_ptr = nullptr;
...@@ -202,6 +202,8 @@ Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_i ...@@ -202,6 +202,8 @@ Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_i
"cuBLAS GEMM output buffer size is incorrect, expected ", out_.numel(), " elements ", "cuBLAS GEMM output buffer size is incorrect, expected ", out_.numel(), " elements ",
to_string_like(out_shape), " but got ", output->element_count(), " elements ", to_string_like(out_shape), " but got ", output->element_count(), " elements ",
to_string_like(output->dimensions())); to_string_like(output->dimensions()));
NVTE_CHECK(!fuse_bias || bias_size == out_shape[1], "bias_size=", bias_size,
", out_shape[1]=", out_shape[1]);
nvte_cublas_gemm(rhs_.data(), lhs_.data(), out_.data(), bias_.data(), pre_gelu_.data(), nvte_cublas_gemm(rhs_.data(), lhs_.data(), out_.data(), bias_.data(), pre_gelu_.data(),
rhs_transposed, lhs_transposed, grad, workspace_.data(), false, rhs_transposed, lhs_transposed, grad, workspace_.data(), false,
...@@ -220,6 +222,8 @@ Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_i ...@@ -220,6 +222,8 @@ Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_i
buffer_shape[1] = out_shape[1]; buffer_shape[1] = out_shape[1];
out_shape[0] = out_shape[0] / comm_handler.tp_size; out_shape[0] = out_shape[0] / comm_handler.tp_size;
} }
NVTE_CHECK(!fuse_bias || bias_size == out_shape[1], "bias_size=", bias_size,
", out_shape[1]=", out_shape[1]);
auto executor = CollectiveGemmPlanRegistry::getInstance().get_executor( auto executor = CollectiveGemmPlanRegistry::getInstance().get_executor(
buffer_shape, buffer_dtype, collective_op); buffer_shape, buffer_dtype, collective_op);
if (collective_op == JAXX_Collective_Op::REDUCE_SCATTER) { if (collective_op == JAXX_Collective_Op::REDUCE_SCATTER) {
......
...@@ -365,6 +365,21 @@ def all_reduce_sum_along_dp_fsdp(x: jnp.array, mesh: jax.sharding.Mesh): ...@@ -365,6 +365,21 @@ def all_reduce_sum_along_dp_fsdp(x: jnp.array, mesh: jax.sharding.Mesh):
return lax_paral_op(x, jax.lax.psum, global_mesh_resource().fsdp_resource, mesh) return lax_paral_op(x, jax.lax.psum, global_mesh_resource().fsdp_resource, mesh)
def all_reduce_sum_along_dp_fsdp_tpsp(x: jnp.array, mesh: jax.sharding.Mesh):
    """Perform an all-reduce sum along the tensor-sequence-parallel (tpsp),
    data-parallel (dp), and fully-sharded data-parallel (fsdp) mesh axes.

    Args:
        x: Input tensor to reduce
        mesh: JAX mesh for distributed computation

    Returns:
        Tensor summed over the tpsp, dp, and fsdp axes of the global mesh resource
    """
    # Reduce over each axis in turn; presumably lax_paral_op is a no-op when the
    # corresponding mesh resource is None — confirm against its definition.
    x = lax_paral_op(x, jax.lax.psum, global_mesh_resource().tpsp_resource, mesh)
    x = lax_paral_op(x, jax.lax.psum, global_mesh_resource().dp_resource, mesh)
    return lax_paral_op(x, jax.lax.psum, global_mesh_resource().fsdp_resource, mesh)
def all_reduce_max_along_all_axes_except_PP(x: jnp.array, mesh: jax.sharding.Mesh): def all_reduce_max_along_all_axes_except_PP(x: jnp.array, mesh: jax.sharding.Mesh):
"""Perform all-reduce max operation along all axes except pipeline parallelism. """Perform all-reduce max operation along all axes except pipeline parallelism.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment