Unverified Commit 62a57dd4 authored by Ming-Xu Huang, committed by GitHub

FP8 AllGather in FP8 GroupedGEMM + Fix Stream Usage Issue. (#2086)



* FP8 AllGather in FP8 GroupedGEMM

1. Support current-scaling FP8 quantization with a given amax.
2. Support FP8 all-gather (AG) in the forward pass and BF16 reduce-scatter (RS) in the backward pass.
3. The workflow is all-reduce amax -> FP8 quantize -> FP8 all-gather -> FP8 GroupedGEMM.
Signed-off-by: Ming Huang <mingh@nvidia.com>
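A minimal sketch of that data flow (illustrative only; the helper below is not part of this PR's API, and TE keeps the operands in FP8 end to end rather than dequantizing):

    import jax
    import jax.numpy as jnp

    # Runs inside shard_map/pmap over the FSDP axis `axis_name`.
    def fp8_allgather_matmul_sketch(x, w_shard, axis_name="fsdp"):
        amax = jax.lax.pmax(jnp.max(jnp.abs(w_shard)), axis_name)               # all-reduce amax
        scale = 448.0 / amax                                                    # 448 = float8_e4m3fn max
        w_fp8 = (w_shard * scale).astype(jnp.float8_e4m3fn)                     # FP8 quantize (current scaling)
        w_fp8_full = jax.lax.all_gather(w_fp8, axis_name, axis=0, tiled=True)   # FP8 all-gather
        return x @ (w_fp8_full.astype(jnp.bfloat16) / scale)                    # GEMM (dequantized here for clarity)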

* Slightly refactor
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Adding documentation for the new args.
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Adding unit tests.
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Adding license.
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Move unit tests to L1.
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Move quantizer store/reset into the FP8-only path.
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Adding all-layout support for Blackwell+.
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Adopt the feedback from code review.
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Fixed the wrong stream used by D2D copies in the groupedGEMM FFI.
Signed-off-by: Ming Huang <mingh@nvidia.com>

---------
Signed-off-by: Ming Huang <mingh@nvidia.com>
Co-authored-by: Phuong Nguyen <phuonguyen@nvidia.com>
parent 8dba2963
...@@ -9,3 +9,4 @@ set -xe
mkdir -p "$XML_LOG_DIR"
NVTE_JAX_UNITTEST_LEVEL="L1" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest.xml $TE_PATH/tests/jax/test_distributed_*
SCRIPT_NAME=test_multi_process_distributed_grouped_gemm.py bash $TE_PATH/tests/jax/multi_process_launch.sh
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
#!/bin/bash
SCRIPT_NAME="${SCRIPT_NAME:-test.py}"
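# Usage (as wired into the L1 CI script above): this launcher starts one Python process
# per visible GPU, passing each the coordinator address, its process index, and the
# total process count, e.g.:
#   SCRIPT_NAME=test_multi_process_distributed_grouped_gemm.py bash multi_process_launch.sh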
XLA_BASE_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true
--xla_gpu_enable_command_buffer=''"
export XLA_FLAGS="${XLA_BASE_FLAGS}"
NUM_PROC=$(nvidia-smi --query-gpu=count --format=csv,noheader | head -n 1)

for ((i=1; i<NUM_PROC; i++))
do
    CUDA_VISIBLE_DEVICES=$i python $SCRIPT_NAME 127.0.0.1:12345 $i $NUM_PROC > /dev/null 2>&1 &
done

CUDA_VISIBLE_DEVICES=0 python $SCRIPT_NAME 127.0.0.1:12345 0 $NUM_PROC
wait
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
from functools import partial
import sys

import jax
import jax.numpy as jnp
from jax.experimental import shard_map
from jax.sharding import NamedSharding, PartitionSpec

from transformer_engine.jax.dense import grouped_dense as te_grouped_dense
from transformer_engine.jax.quantize import (
    QuantizerFactory,
    ScalingMode,
)

from utils import assert_allclose
N_GROUP = 8
MESH_AXIS_NAME = "fsdp"
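# Compares grouped_dense with an FSDP-sharded FP8 kernel (kernel_fsdp_info set: FP8
# all-gather of the kernel in the forward pass, BF16 reduce-scatter of the weight
# gradient in the backward pass) against a reference run with a fully replicated kernel.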
def test_grouped_gemm_fp8_allgather(data_shapes, kernel_fsdp_axis):
assert kernel_fsdp_axis in [1, 2]
x_shape, w_shape = data_shapes
x_sharding = NamedSharding(mesh, PartitionSpec(None, MESH_AXIS_NAME, None, None, None))
w_sharding = (
NamedSharding(mesh, PartitionSpec(None, None, MESH_AXIS_NAME))
if kernel_fsdp_axis == 2
else NamedSharding(mesh, PartitionSpec(None, MESH_AXIS_NAME, None))
)
w_no_sharding = NamedSharding(mesh, PartitionSpec(None, None, None))
def init_data():
x_key = jax.random.PRNGKey(0)
w_key = jax.random.PRNGKey(1)
x = jax.random.normal(x_key, shape=(N_GROUP, *x_shape), dtype=jnp.bfloat16)
w = jax.random.normal(w_key, shape=(N_GROUP, *w_shape), dtype=jnp.bfloat16)
w_amax = jnp.max(jnp.abs(w), axis=range(1, w.ndim))
return x, w, w, w_amax
def test_func(outter_x, outter_w, outter_w_amax):
in_specs = (x_sharding.spec, w_sharding.spec, None)
out_specs = x_sharding.spec
@partial(
shard_map.shard_map,
mesh=mesh,
in_specs=in_specs,
out_specs=out_specs,
check_rep=False,
)
def sharded_group_gemm(x, w, w_amax):
group_size = x.shape[0]
x_reshaped = x.reshape(-1, x.shape[-1])
n_groups = jnp.full(group_size, x_reshaped.shape[0] // group_size)
quantizer_set = QuantizerFactory.create_set(
scaling_mode=ScalingMode.CURRENT_TENSOR_SCALING,
fwd_dtype=jnp.float8_e4m3fn,
bwd_dtype=jnp.float8_e5m2,
is_2x2x=True,
n_groups=group_size,
)
output = te_grouped_dense(
x_reshaped,
w,
n_groups,
kernel_amax=w_amax,
quantizer_set=quantizer_set,
kernel_fsdp_info=(MESH_AXIS_NAME, kernel_fsdp_axis),
)
output = output.reshape(*x.shape[:-1], -1)
return output
def run(x, w, w_amax):
output = sharded_group_gemm(x, w, w_amax)
return output
output, vjp_fn = jax.vjp(run, outter_x, outter_w, outter_w_amax)
dx, dw, _ = vjp_fn(output)
return output, dx, dw
def ref_func(outter_x, outter_w):
in_specs = (x_sharding.spec, w_no_sharding.spec)
out_specs = x_sharding.spec
@partial(
shard_map.shard_map,
mesh=mesh,
in_specs=in_specs,
out_specs=out_specs,
check_rep=False,
)
def sharded_group_gemm(x, w):
group_size = x.shape[0]
x_reshaped = x.reshape(-1, x.shape[-1])
n_groups = jnp.full(group_size, x_reshaped.shape[0] // group_size)
quantizer_set = QuantizerFactory.create_set(
scaling_mode=ScalingMode.CURRENT_TENSOR_SCALING,
fwd_dtype=jnp.float8_e4m3fn,
bwd_dtype=jnp.float8_e5m2,
is_2x2x=True,
n_groups=group_size,
)
output = te_grouped_dense(x_reshaped, w, n_groups, quantizer_set=quantizer_set)
output = output.reshape(*x.shape[:-1], -1)
return output
def run(x, w):
output = sharded_group_gemm(x, w)
return output
output, vjp_fn = jax.vjp(run, outter_x, outter_w)
dx, dw = vjp_fn(output)
return output, dx, dw
init_func = jax.jit(init_data, out_shardings=(x_sharding, w_sharding, w_no_sharding, None))
x, w, w_global, w_amax = init_func()
o_sharding = x_sharding
test_func_jitted = jax.jit(
test_func,
in_shardings=(x_sharding, w_sharding, None),
out_shardings=(o_sharding, x_sharding, w_sharding),
)
ref_func_jitted = jax.jit(
ref_func,
in_shardings=(x_sharding, w_no_sharding),
out_shardings=(o_sharding, x_sharding, w_no_sharding),
)
out, dx, dw = test_func_jitted(x, w, w_amax)
ref_out, ref_dx, ref_dw = ref_func_jitted(x, w_global)
assert_allclose(out, ref_out, dtype=jnp.float8_e4m3fn)
assert_allclose(dx, ref_dx, dtype=jnp.float8_e5m2)
assert_allclose(dw, ref_dw, dtype=jnp.float8_e5m2)
if __name__ == "__main__":
coord_addr = sys.argv[1]
proc_id = int(sys.argv[2])
num_procs = int(sys.argv[3])
jax.distributed.initialize(
coordinator_address=coord_addr, num_processes=num_procs, process_id=proc_id
)
mesh = jax.make_mesh((num_procs,), (MESH_AXIS_NAME,))
with mesh:
data_shapes = [((4, 16, 128, 7168), (7168, 2048))]
for data_shape in data_shapes:
for kernel_fsdp_axis in [1, 2]:
test_grouped_gemm_fp8_allgather(data_shape, kernel_fsdp_axis)
...@@ -931,6 +931,7 @@ def grouped_quantize(
    x: jnp.ndarray,
    quantizer: GroupedQuantizer,
    group_sizes: jnp.ndarray = None,
    amax: jnp.ndarray = None,
    flatten_axis: int = -1,
) -> GroupedScaledTensor1x:
    """Quantize a tensor in grouped manner.
...@@ -943,6 +944,7 @@ def grouped_quantize(
        x: Input tensor to quantize
        quantizer: The quantizer to use for quantization
        group_sizes: Array of ints containing the size of each group (default: None)
        amax: The amax of x; if None, it is computed from x (default: None)
        flatten_axis: The axis along which the tensor could be flattened to 2D (default: -1)

    Returns:
...@@ -985,6 +987,9 @@ def grouped_quantize(
            scale = scale.at[i].set(quantizer_i.scale[0])

    if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING:
        if amax is not None:
            row_amax = amax
        else:
            row_amax = jnp.max(jnp.abs(x), axis=range(group_axis + 1, x.ndim))
        segment_ids = jnp.repeat(
            jnp.arange(n_groups), group_sizes, total_repeat_length=x.shape[group_axis]
......
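For reference, current-tensor-scaling quantization derives its scale directly from amax, so passing a pre-reduced amax (for example one all-reduced across the FSDP axis) makes every rank produce identically scaled FP8 data. A minimal per-tensor sketch of that relationship (not the grouped implementation above):

    import jax.numpy as jnp

    def quantize_current_scaling(x, amax=None, fp8_max=448.0):  # 448 = float8_e4m3fn max
        if amax is None:
            amax = jnp.max(jnp.abs(x))    # fall back to the local amax
        scale = fp8_max / amax            # current scaling: scale comes from this tensor's amax
        x_fp8 = (x * scale).astype(jnp.float8_e4m3fn)
        return x_fp8, 1.0 / scale         # scale_inv kept for dequantization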
...@@ -285,18 +285,17 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type
   size_t out_dtype_bytes = te_dtype_bytes(out_dtype);

   if (is_tensor_scaling) {
-    cudaStream_t stream_0 = nvte_get_compute_stream(0);
     size_t dpitch = tensor_scaling_sinv_aligment;
     size_t spitch = lhs_sinv_dtype_bytes;
     size_t width = lhs_sinv_dtype_bytes;
     size_t height = lhs_sinv_size;
     cudaMemcpy2DAsync(lhs_scatter_aligned_ptr, dpitch, lhs_sinv_ptr, spitch, width, height,
-                      cudaMemcpyDeviceToDevice, stream_0);
+                      cudaMemcpyDeviceToDevice, stream);
     spitch = rhs_sinv_dtype_bytes;
     width = rhs_sinv_dtype_bytes;
     height = rhs_sinv_size;
     cudaMemcpy2DAsync(rhs_scatter_aligned_ptr, dpitch, rhs_sinv_ptr, spitch, width, height,
-                      cudaMemcpyDeviceToDevice, stream_0);
+                      cudaMemcpyDeviceToDevice, stream);
     lhs_sinv_ptr = lhs_scatter_aligned_ptr;
     rhs_sinv_ptr = rhs_scatter_aligned_ptr;
   }
......
...@@ -16,13 +16,45 @@ import jax.numpy as jnp
from . import cpp_extensions as tex
from .quantize import (
    ScaledTensorFactory,
    ScalingMode,
    QuantizeLayout,
    QuantizerSet,
    noop_quantizer_set,
    with_sharding_constraint_by_logical_axes,
    is_fp8_gemm_with_all_layouts_supported,
    TensorUsage,
)
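# Helper collectives for FSDP-sharded kernels (see the definitions below):
# _all_gather_kernel gathers the kernel shards along the sharded axis so the forward
# grouped GEMM sees the full weight, while _psum_scatter_kernel reduce-scatters the
# weight gradient back to the original shard layout in the backward pass.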
def _all_gather_kernel(kernel, mesh_axis, axis_idx):
assert mesh_axis is not None
assert 0 < axis_idx < len(kernel.shape)
    # TODO(Ming Huang): Add a conditional branch for with/without shard_map (shmap).
kernel_shape = kernel.shape
kernel_whole_shape = (*kernel_shape[:axis_idx], -1, *kernel_shape[axis_idx + 1 :])
global_kernel = jax.lax.all_gather(kernel, mesh_axis, axis=axis_idx)
global_kernel = global_kernel.reshape(*kernel_whole_shape)
return global_kernel
def _psum_scatter_kernel(kernel, scattered_kernel_shape, mesh_axis, axis_idx):
assert mesh_axis is not None
assert 0 < axis_idx < len(scattered_kernel_shape)
    # TODO(Ming Huang): Add a conditional branch for with/without shard_map (shmap).
kernel = kernel.reshape(
*scattered_kernel_shape[:axis_idx],
-1,
scattered_kernel_shape[axis_idx],
*scattered_kernel_shape[axis_idx + 1 :],
)
kernel = jax.lax.psum_scatter(kernel, mesh_axis, scatter_dimension=axis_idx)
kernel = kernel.reshape(scattered_kernel_shape)
return kernel
def dense(
    x: jnp.ndarray,
    kernel: jnp.ndarray,
...@@ -253,10 +285,12 @@ def grouped_dense(
    group_sizes: jnp.ndarray,
    contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (1,)),
    bias: jnp.ndarray = None,
    kernel_amax: jnp.ndarray = None,
    precision: jax.lax.Precision = jax.lax.Precision.DEFAULT,
    preferred_element_type: jnp.dtype = None,
    group_offset: jnp.array = None,
    quantizer_set: QuantizerSet = noop_quantizer_set,
    kernel_fsdp_info: Tuple[str, int] = (None, -1),
):
    """
    Perform grouped dense (linear) layer transformation with optional quantization.
...@@ -268,10 +302,15 @@ def grouped_dense(
        contracting_dims: Tuple of sequences specifying which dimensions to contract
                          (currently only supports ((1,), (1,)))
        bias: Bias tensor of shape (G, N)
        kernel_amax: The amax values of the weight matrix, of shape (G,)
        precision: JAX precision for the GEMM operation
        preferred_element_type: Preferred data type for the output tensor
        group_offset: 1D array containing offsets for each group (not yet implemented)
        quantizer_set: Set of quantizers for FP8 quantization of the input and output
        kernel_fsdp_info: A tuple (str, int) with FSDP information for the weight matrix.
                          The first element is the FSDP mesh axis, and the second is the
                          dimension along which the weight is sharded.

    Returns:
        A jnp.ndarray containing the result of the grouped linear operation
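For illustration, a minimal call under shard_map with the kernel FSDP-sharded along dimension 1 of its (G, K, N) layout, using QuantizerFactory from transformer_engine.jax.quantize as in the new unit test (the mesh axis name, num_groups, and the w_local/w_amax variables below are placeholders, not part of this API):

    # Inside shard_map over a mesh axis named "fsdp"; w_local is this rank's kernel
    # shard of shape (G, K // fsdp_size, N) and w_amax the per-group amax of shape (G,).
    quantizer_set = QuantizerFactory.create_set(
        scaling_mode=ScalingMode.CURRENT_TENSOR_SCALING,
        fwd_dtype=jnp.float8_e4m3fn,
        bwd_dtype=jnp.float8_e5m2,
        is_2x2x=True,
        n_groups=num_groups,
    )
    out = grouped_dense(
        x, w_local, group_sizes,
        kernel_amax=w_amax,
        quantizer_set=quantizer_set,
        kernel_fsdp_info=("fsdp", 1),  # FP8 all-gather of the kernel along dim 1 in fwd
    )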
...@@ -282,25 +321,29 @@ def grouped_dense(
        group_sizes,
        contracting_dims,
        bias,
        kernel_amax,
        precision,
        preferred_element_type,
        group_offset,
        quantizer_set,
        kernel_fsdp_info,
    )
    return output

@partial(jax.custom_vjp, nondiff_argnums=(3, 6, 7, 8, 10))
def _grouped_dense(
    x,
    kernel,
    group_sizes,
    contracting_dims,
    bias,
    kernel_amax,
    precision,
    preferred_element_type,
    group_offset,
    quantizer_set,
    kernel_fsdp_info,
):
    output, _ = _grouped_dense_fwd_rule(
        x,
...@@ -308,10 +351,12 @@ def _grouped_dense(
        group_sizes,
        contracting_dims,
        bias,
        kernel_amax,
        precision,
        preferred_element_type,
        group_offset,
        quantizer_set,
        kernel_fsdp_info,
    )
    return output
...@@ -322,21 +367,31 @@ def _grouped_dense_fwd_rule(
    group_sizes,
    contracting_dims,
    bias,
    kernel_amax,
    precision,
    preferred_element_type,
    group_offset,
    quantizer_set,
    kernel_fsdp_info,
):
    use_bias = bias is not None
    is_noop_quantizer_set = quantizer_set == noop_quantizer_set

    kernel_fsdp_mesh_axis, kernel_fsdp_axis_idx = kernel_fsdp_info
    kernel_fsdp_enabled = kernel_fsdp_mesh_axis is not None
    if is_noop_quantizer_set:
        grouped_gemm_x = x
        grouped_gemm_kernel = kernel
        ctx_x = x
        ctx_kernel = kernel
        flatten_axis_k = None

        if kernel_fsdp_enabled:
            kernel = _all_gather_kernel(kernel, kernel_fsdp_mesh_axis, kernel_fsdp_axis_idx)
    else:
        original_quantizer_set_kernel_q_layout = quantizer_set.kernel.q_layout

        x_contracting_dims, k_contracting_dims = contracting_dims
        flatten_axis_x = -len(x_contracting_dims)
        flatten_axis_k = len(k_contracting_dims) - len(kernel.shape) + 1  # +1 for G axis
...@@ -352,10 +407,24 @@ def _grouped_dense_fwd_rule(
        )
        casted_x = tex.grouped_quantize(
            x,
            quantizer_set.x,
            group_sizes,
            flatten_axis=flatten_axis_x,
        )
ctx_kernel_usage = TensorUsage.RHS_TRANS
if kernel_fsdp_enabled:
assert quantizer_set.kernel.scaling_mode in [
ScalingMode.CURRENT_TENSOR_SCALING,
ScalingMode.DELAYED_TENSOR_SCALING,
]
# Perform `cast` only
ctx_kernel_usage = TensorUsage.LHS
quantizer_set.kernel.q_layout = QuantizeLayout.ROWWISE
        casted_kernel = tex.grouped_quantize(
            kernel, quantizer_set.kernel, amax=kernel_amax, flatten_axis=flatten_axis_k
        )
        contracting_dims = (x_contracting_dims, k_contracting_dims)
...@@ -363,9 +432,51 @@ def _grouped_dense_fwd_rule(
        # rowwise_casted_x.original_shape == (M, K)
        # colwise_casted_kernel.original_shape == (G, N, K)
        grouped_gemm_x = casted_x.get_tensor(usage=TensorUsage.LHS)
        ctx_x = casted_x.get_tensor(usage=TensorUsage.LHS_TRANS)
        ctx_kernel = casted_kernel.get_tensor(usage=ctx_kernel_usage)
if kernel_fsdp_enabled:
ctx_kernel_in_original_shape = ctx_kernel.data.reshape(ctx_kernel.original_shape)
global_ctx_kernel_data = _all_gather_kernel(
ctx_kernel_in_original_shape, kernel_fsdp_mesh_axis, kernel_fsdp_axis_idx
)
kernel_shape = global_ctx_kernel_data.shape
ctx_kernel = ScaledTensorFactory.create_1x(
global_ctx_kernel_data.reshape(-1),
ctx_kernel.scale_inv,
ctx_kernel.scaling_mode,
dq_dtype=ctx_kernel.dq_dtype,
is_colwise=False,
data_layout="N",
flatten_axis=ctx_kernel.flatten_axis,
group_sizes=ctx_kernel.group_sizes,
original_shape=kernel_shape,
group_axis=ctx_kernel.group_axis,
)
if is_fp8_gemm_with_all_layouts_supported():
grouped_gemm_kernel = ctx_kernel
else:
grouped_gemm_kernel_data = global_ctx_kernel_data.transpose(0, 2, 1)
grouped_gemm_kernel = ScaledTensorFactory.create_1x(
grouped_gemm_kernel_data.reshape(-1),
ctx_kernel.scale_inv,
ctx_kernel.scaling_mode,
dq_dtype=ctx_kernel.dq_dtype,
is_colwise=True,
data_layout="T",
flatten_axis=ctx_kernel.flatten_axis,
group_sizes=ctx_kernel.group_sizes,
original_shape=kernel_shape,
group_axis=ctx_kernel.group_axis,
)
else:
grouped_gemm_kernel = casted_kernel.get_tensor(usage=TensorUsage.RHS)
        # Reset quantizer_set.kernel.q_layout so the output PyTree structure matches the
        # given quantizer_set. This is needed especially when kernel_fsdp_enabled == True
        # and FP8 is enabled.
quantizer_set.kernel.q_layout = original_quantizer_set_kernel_q_layout
    output = tex.grouped_gemm(
        grouped_gemm_x,
...@@ -393,7 +504,7 @@ def _grouped_dense_fwd_rule(
def _grouped_dense_bwd_rule(
    contracting_dims, precision, preferred_element_type, group_offset, kernel_fsdp_info, ctx, grad
):
    fwd_x_contracting_dims, fwd_k_contracting_dims = contracting_dims
...@@ -474,11 +585,17 @@ def _grouped_dense_bwd_rule(
        preferred_element_type=preferred_element_type,
        group_offset=group_offset,
    )
kernel_fsdp_mesh_axis, kernel_fsdp_axis_idx = kernel_fsdp_info
if kernel_fsdp_mesh_axis is not None:
wgrad = _psum_scatter_kernel(
wgrad, kernel_shape, kernel_fsdp_mesh_axis, kernel_fsdp_axis_idx
)
    group_sizes_grad = None
    dbias = tex.grouped_dbias(grad, group_sizes) if use_bias else None
    dkernel_amax = None

    return dgrad, wgrad, group_sizes_grad, dbias, dkernel_amax, quantizer_set


_grouped_dense.defvjp(_grouped_dense_fwd_rule, _grouped_dense_bwd_rule)