Commit 2b05e121 authored by yuguo

Merge commit 'a69692ac' of https://github.com/NVIDIA/TransformerEngine
parents 0fd441c2 a69692ac
@@ -153,28 +153,28 @@ def _dense_bwd_rule(
    # GEMM NT
    # k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel.ndim
    g_contracting_dim = tuple(
        range(grad.ndim - len(kernel_shape) + len(fwd_k_contracting_dims), grad.ndim)
    )
    # k_non_contracting_dims
    k_contracting_dim = tuple(
        dim for dim in range(len(kernel_shape)) if dim not in fwd_k_contracting_dims
    )
    dgrad = tex.gemm(
        casted_grad.get_rowwise_tensor(),
        rowwise_casted_kernel,
        (g_contracting_dim, k_contracting_dim),
    )
    dgrad = with_sharding_constraint_by_logical_axes(dgrad, input_axes)

    # GEMM TN
    # x_non_contracting_dims
    g_contracting_dim = x_contracting_dim = tuple(
        range(0, len(x_shape) - len(fwd_x_contracting_dims))
    )
    wgrad = tex.gemm(
        colwise_casted_x, casted_grad.get_colwise_tensor(), (x_contracting_dim, g_contracting_dim)
    )
    wgrad = with_sharding_constraint_by_logical_axes(wgrad, kernel_axes)
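# Editor's check (not part of this commit): the contracting dims above for plain
# 2D shapes x: (M, K), kernel: (K, N), grad: (M, N), forward dims ((1,), (0,)).
grad_ndim, x_shape, kernel_shape = 2, (8, 16), (16, 32)
fwd_x_contracting_dims, fwd_k_contracting_dims = (1,), (0,)
g_dim = tuple(range(grad_ndim - len(kernel_shape) + len(fwd_k_contracting_dims), grad_ndim))
k_dim = tuple(d for d in range(len(kernel_shape)) if d not in fwd_k_contracting_dims)
assert (g_dim, k_dim) == ((1,), (1,))  # dgrad: contract grad's N with kernel's N
x_dim = tuple(range(0, len(x_shape) - len(fwd_x_contracting_dims)))
assert x_dim == (0,)                   # wgrad: contract the shared M axis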
@@ -184,135 +184,240 @@ def _dense_bwd_rule(
_dense.defvjp(_dense_fwd_rule, _dense_bwd_rule)


def grouped_dense(
    x: jnp.ndarray,
    kernel: jnp.ndarray,
    group_sizes: jnp.ndarray,
    contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (1,)),
    bias: jnp.ndarray = None,
    precision: jax.lax.Precision = jax.lax.Precision.DEFAULT,
    preferred_element_type: jnp.dtype = None,
    group_offset: jnp.array = None,
    quantizer_set: QuantizerSet = noop_quantizer_set,
):
    """
    Perform grouped dense (linear) layer transformation with optional quantization.

    Args:
        x: Input tensor of shape (M, K)
        kernel: Weight matrix of shape (G, K, N)
        group_sizes: 1D array of shape (G,) specifying the size of each group
        contracting_dims: Tuple of sequences specifying which dimensions to contract
            (currently only supports ((1,), (1,)))
        bias: Bias tensor of shape (G, N)
        precision: JAX precision for the GEMM operation
        preferred_element_type: Preferred data type for the output tensor
        group_offset: 1D array containing offsets for each group (not yet implemented)
        quantizer_set: Set of quantizers for FP8 quantization of the input and output

    Returns:
        A jnp.ndarray containing the result of the grouped linear operation
    """
    output = _grouped_dense(
        x,
        kernel,
        group_sizes,
        contracting_dims,
        bias,
        precision,
        preferred_element_type,
        group_offset,
        quantizer_set,
    )
    return output
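# Usage sketch (editor's example, not part of this commit); shapes follow the
# docstring above, and the group sizes are illustrative.
import jax.numpy as jnp

x = jnp.ones((12, 64), dtype=jnp.bfloat16)           # (M, K), M = 12 tokens
kernel = jnp.ones((3, 64, 128), dtype=jnp.bfloat16)  # (G, K, N), G = 3 groups
group_sizes = jnp.array([5, 4, 3])                   # rows per group, sums to M
out = grouped_dense(x, kernel, group_sizes)          # -> (12, 128)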
@partial(jax.custom_vjp, nondiff_argnums=(3, 5, 6, 7))
def _grouped_dense(
    x,
    kernel,
    group_sizes,
    contracting_dims,
    bias,
    precision,
    preferred_element_type,
    group_offset,
    quantizer_set,
):
    output, _ = _grouped_dense_fwd_rule(
        x,
        kernel,
        group_sizes,
        contracting_dims,
        bias,
        precision,
        preferred_element_type,
        group_offset,
        quantizer_set,
    )
    return output


def _grouped_dense_fwd_rule(
    x,
    kernel,
    group_sizes,
    contracting_dims,
    bias,
    precision,
    preferred_element_type,
    group_offset,
    quantizer_set,
):
    use_bias = bias is not None
    is_noop_quantizer_set = quantizer_set == noop_quantizer_set

    if is_noop_quantizer_set:
        grouped_gemm_x = x
        grouped_gemm_kernel = kernel
        ctx_x = x
        ctx_kernel = kernel
        flatten_axis_k = None
    else:
        x_contracting_dims, k_contracting_dims = contracting_dims
        flatten_axis_x = -len(x_contracting_dims)
        flatten_axis_k = len(k_contracting_dims) - len(kernel.shape) + 1  # +1 for the G axis

        assert x.ndim == 2, "Grouped dense expects a 2D input tensor of shape (M, K)"
        assert kernel.ndim == 3, "Grouped dense expects a 3D kernel tensor of shape (G, K, N)"

        # Expected k_contracting_dims == (1,); it needs a tweak for the grouped_gemm FP8 extra transpose
        # TODO(Hua): Do we have a better way for this? What if is_gemm_with_all_layouts_supported()?
        assert x_contracting_dims == (1,) and k_contracting_dims == (1,), (
            "grouped_dense for FP8 can only handle x_contracting_dims=(1,) "
            "and k_contracting_dims=(1,) for now, "
            f"got {x_contracting_dims=} and {k_contracting_dims=}"
        )
        k_contracting_dims = (0,)

        casted_x = tex.grouped_quantize(
            x, quantizer_set.x, group_sizes, flatten_axis=flatten_axis_x
        )
        casted_kernel = tex.grouped_quantize(
            kernel, quantizer_set.kernel, flatten_axis=flatten_axis_k
        )
        contracting_dims = (x_contracting_dims, k_contracting_dims)

        # For x_contracting_dims == (1,) and k_contracting_dims == (1,), we should have
        #   rowwise_casted_x.original_shape == (M, K)
        #   colwise_casted_kernel.original_shape == (G, N, K)
        grouped_gemm_x = casted_x.get_rowwise_tensor()
        grouped_gemm_kernel = casted_kernel.get_colwise_tensor()

        # TODO(Hua): Shall we give a warning/error if not quantizer_set.x.is_2x2x()?
        ctx_x = casted_x.get_colwise_tensor() if quantizer_set.x.is_2x2x() else None
        ctx_kernel = casted_kernel.get_rowwise_tensor() if quantizer_set.kernel.is_2x2x() else None

    output = tex.grouped_gemm(
        grouped_gemm_x,
        grouped_gemm_kernel,
        group_sizes,
        contracting_dims,
        bias,
        precision,
        preferred_element_type,
        group_offset,
    )

    ctx = (
        group_sizes,
        ctx_x,
        ctx_kernel,
        x.shape,
        kernel.shape,
        use_bias,
        is_noop_quantizer_set,
        quantizer_set,
        flatten_axis_k,
    )
    return output, ctx
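# Reference semantics (editor's sketch; an assumption about tex.grouped_gemm on
# the unquantized path): each row group of x is multiplied by its own kernel slice.
import jax.numpy as jnp

def grouped_dense_reference(x, kernel, group_sizes):
    outs, start = [], 0
    for g, size in enumerate(group_sizes.tolist()):
        outs.append(x[start : start + size] @ kernel[g])  # (size, K) @ (K, N)
        start += size
    return jnp.concatenate(outs, axis=0)  # (M, N)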
def _grouped_dense_bwd_rule(
    contracting_dims, precision, preferred_element_type, group_offset, ctx, grad
):
    fwd_x_contracting_dims, fwd_k_contracting_dims = contracting_dims
    (
        group_sizes,
        ctx_x,
        ctx_kernel,
        x_shape,
        kernel_shape,
        use_bias,
        is_noop_quantizer_set,
        quantizer_set,
        flatten_axis_k,
    ) = ctx

    if is_noop_quantizer_set:
        # GEMM NT
        # The 1 in range is for excluding the group dimension (shall we use the
        # hardcoded results below?)
        # g_contracting_dim = (1,)
        # k_contracting_dim = (2,)
        g_contracting_dim = tuple(
            range(1 + grad.ndim - len(kernel_shape) + len(fwd_k_contracting_dims), grad.ndim)
        )
        k_contracting_dim = tuple(
            dim for dim in range(1, len(kernel_shape)) if dim not in fwd_k_contracting_dims
        )
        dgrad_contracting_dims = (g_contracting_dim, k_contracting_dim)
        dgrad_grad = grad
        dgrad_kernel_T = ctx_kernel

        # GEMM TN
        # g_contracting_dim = (0,)
        # x_contracting_dim = (0,)
        g_contracting_dim = x_contracting_dim = tuple(
            range(0, len(x_shape) - len(fwd_x_contracting_dims))
        )
        wgrad_contracting_dims = (x_contracting_dim, g_contracting_dim)
        wgrad_x_T = ctx_x
        wgrad_grad = grad
    else:
        casted_grad = tex.grouped_quantize(
            grad, quantizer_set.dgrad, group_sizes, flatten_axis=flatten_axis_k
        )
        # For x_contracting_dims == (1,) and k_contracting_dims == (1,), we need to use
        # g_contracting_dim = (1,) and k_contracting_dim = (2,) to make it work after the
        # extra transpose for FP8 in grouped_gemm
        # TODO(Hua): Do we have a better way for this? What if is_gemm_with_all_layouts_supported()?
        g_contracting_dim = (1,)
        k_contracting_dim = (2,)
        dgrad_contracting_dims = (g_contracting_dim, k_contracting_dim)
        dgrad_grad = casted_grad.get_rowwise_tensor()
        dgrad_kernel_T = ctx_kernel

        # We need to use g_contracting_dim = (0,) and x_contracting_dim = (0,) to make it work
        # after the extra transpose for FP8 in grouped_gemm
        # TODO(Hua): Do we have a better way for this? What if is_gemm_with_all_layouts_supported()?
        g_contracting_dim = (0,)
        x_contracting_dim = (0,)
        wgrad_contracting_dims = (x_contracting_dim, g_contracting_dim)
        wgrad_x_T = ctx_x
        wgrad_grad = casted_grad.get_colwise_tensor()

    dgrad = tex.grouped_gemm(
        dgrad_grad,
        dgrad_kernel_T,
        group_sizes,
        dgrad_contracting_dims,
        precision=precision,
        preferred_element_type=preferred_element_type,
        group_offset=group_offset,
    )

    wgrad = tex.grouped_gemm(
        wgrad_x_T,
        wgrad_grad,
        group_sizes,
        wgrad_contracting_dims,
        precision=precision,
        preferred_element_type=preferred_element_type,
        group_offset=group_offset,
    )

    group_sizes_grad = None
    dbias = tex.grouped_dbias(grad, group_sizes) if use_bias else None

    return dgrad, wgrad, group_sizes_grad, dbias, quantizer_set


_grouped_dense.defvjp(_grouped_dense_fwd_rule, _grouped_dense_bwd_rule)
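# Editor's check (not part of this commit): for grad: (M, N) and kernel: (G, K, N)
# with forward contracting dims x=(1,), k=(1,), the "+1 for the G axis" formulas
# in the noop branch reproduce the hardcoded dims quoted in the comments.
grad_ndim, kernel_shape = 2, (4, 16, 32)  # kernel is (G, K, N)
fwd_k_contracting_dims = (1,)
g_dim = tuple(range(1 + grad_ndim - len(kernel_shape) + len(fwd_k_contracting_dims), grad_ndim))
k_dim = tuple(d for d in range(1, len(kernel_shape)) if d not in fwd_k_contracting_dims)
assert (g_dim, k_dim) == ((1,), (2,))  # contract N; keep G and K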
"""
@@ -594,8 +594,16 @@ class DotProductAttention(nn.Module):  # pylint: disable=too-few-public-methods
            seqlen_kv = seqlen_q
        else:
            seqlen_kv = key.shape[sequence_dim]

        if qkv_layout.is_separate():
            head_dim_qk = query.shape[-1]
            head_dim_v = value.shape[-1]
        else:
            head_dim_qk = self.head_dim
            head_dim_v = self.head_dim

        has_fused_attn_kernel = is_fused_attn_kernel_available(
            # This needs to be fixed: TE-Jax has historically correlated training mode with deterministic mode.
            not deterministic,
            self.dtype,
            self.dtype,
            qkv_layout,
@@ -606,7 +614,8 @@ class DotProductAttention(nn.Module):  # pylint: disable=too-few-public-methods
            self.num_gqa_groups,
            seqlen_q,
            seqlen_kv,
            head_dim_qk,
            head_dim_v,
            self.window_size,
        )
@@ -619,7 +628,7 @@ class DotProductAttention(nn.Module):  # pylint: disable=too-few-public-methods
                "Please try to update the cuDNN and TE to the latest version.\n"
                f"{self.dtype=}\n{qkv_layout=}\n{attn_bias_type=}\n{attn_mask_type=}\n"
                f"{self.attention_dropout=}\n{self.num_attention_heads=}\n"
                f"{self.num_gqa_groups=}\n{seqlen_q=}\n{seqlen_kv=}\n{head_dim_qk=}\n{head_dim_v=}\n"
            )

        dropout_rng = None
@@ -627,7 +636,7 @@ class DotProductAttention(nn.Module):  # pylint: disable=too-few-public-methods
            dropout_rng = self.make_rng(self.dropout_rng_name)

        if self.scale_factor is None:
            scale_factor = 1.0 / sqrt(head_dim_qk)
        else:
            scale_factor = self.scale_factor
        del self.scale_factor
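# Editor's note (illustrative, not part of this commit): with a separate QKV
# layout the query/key depth can differ from the value depth (e.g. 192 vs 128);
# the softmax scale above follows the query/key depth only.
from math import sqrt
head_dim_qk, head_dim_v = 192, 128
scale_factor = 1.0 / sqrt(head_dim_qk)  # ~0.0722; head_dim_v does not enter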
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
[build-system]
requires = ["setuptools>=61.0", "pybind11[global]", "pip", "jax[cuda12]", "flax>=0.7.1"]
# Use legacy backend to import local packages in setup.py
build-backend = "setuptools.build_meta:__legacy__"
@@ -7,24 +7,54 @@ Dequantization utilities for TE/JAX.
This module provides utilities for dequantizing tensors that have been quantized
using various scaling modes, including delayed scaling and block scaling.
"""
import math
from dataclasses import dataclass
from abc import ABC, abstractmethod

import jax
import jax.numpy as jnp

from .scaling_modes import ScalingMode

__all__ = ["ScalingModeToDequantizerMap"]


@dataclass
class Dequantizer(ABC):
    """
    Base Dequantizer Class
    """

    @staticmethod
    @abstractmethod
    def _dequantize_func(data, scale_inv, dq_dtype, **kwargs):
        pass

    @staticmethod
    @abstractmethod
    def dequantize(scaled_tensor):
        """Dequantize the given tensor to higher precision."""


class TensorScaleDequantizer(Dequantizer):
    """
    TensorScaling Dequantizer Class

    This class provides static methods for dequantizing tensors that have been
    quantized using different tensor scaling modes. It supports both delayed scaling
    and current scaling modes.
    """

    @staticmethod
    def _dequantize_func(data, scale_inv, dq_dtype, **kwargs):
        del kwargs
        return jnp.asarray(
            data.astype(jnp.float32) * scale_inv.astype(jnp.float32),
            dq_dtype,
        )

    @staticmethod
    def dequantize(scaled_tensor):
        """Dequantize a tensor using delayed scaling.

        This function dequantizes a tensor that was quantized using delayed scaling
@@ -36,36 +66,48 @@ class Dequantizer:
        Returns:
            The dequantized tensor in the specified data type
        """
        return TensorScaleDequantizer._dequantize_func(
            scaled_tensor.data, scaled_tensor.scale_inv, scaled_tensor.dq_dtype
        )
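# Numeric sketch (editor's example): tensor-scale dequantization is a single
# multiply by the stored inverse scale, followed by a cast to dq_dtype.
import jax.numpy as jnp

data = jnp.array([[64.0, -32.0]], dtype=jnp.float8_e4m3fn)  # quantized payload
scale_inv = jnp.array([0.25], dtype=jnp.float32)            # 1 / scale
dq = jnp.asarray(data.astype(jnp.float32) * scale_inv, jnp.bfloat16)  # [[16, -8]]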
class BlockScaleDequantizer(Dequantizer):
    """BlockScaling Dequantizer Class.

    This class provides static methods for dequantizing tensors that have been
    quantized using block scaling modes.
    """

    @staticmethod
    def _dequantize_func(data, scale_inv, dq_dtype, scaling_mode, is_colwise, flatten_axis):
        """Dequantize a tensor using block scaling.

        Args:
            data: The quantized tensor data
            scale_inv: The inverse scaling factors
            dq_dtype: The data type for dequantized values
            scaling_mode: The scaling mode used for quantization
            is_colwise: Whether the scaling is column-wise
            flatten_axis: The axis along which the tensor could be flattened to 2D

        Returns:
            The dequantized tensor
        """
        data = data.astype(jnp.float32)
        scale_inv = scale_inv.view(jnp.uint8).astype(jnp.float32)
        data_shape = data.shape
        flatten_axis = len(data_shape) + flatten_axis if flatten_axis < 0 else flatten_axis
        assert (
            0 < flatten_axis < len(data_shape)
        ), f"flatten_axis {flatten_axis} is out of bounds for shape {data_shape}"

        scale_shape = scaling_mode.get_scale_shape(
            data_shape, is_colwise, is_padded=False, flatten_axis=flatten_axis
        )
        scale_inv = jax.lax.slice(
            scale_inv, [0] * len(scale_shape), scale_shape
        )  # slice out the padding

        data = data.reshape(
            *data_shape[: flatten_axis - 1],
@@ -76,31 +118,106 @@ class Dequantizer:
            int(data_shape[-1] / scale_shape[-1]),
        )

        scale_inv = jnp.expand_dims(scale_inv, axis=(flatten_axis + 2 - 2, -1))

        # E8M0 has no sign bit: scale bytes below 127 encode negative exponents (scales < 1).
        return jnp.asarray(data * jnp.power(2, scale_inv - 127), dq_dtype).reshape(data_shape)
    @staticmethod
    def dequantize(scaled_tensor):
        """Dequantize a tensor using block scaling.

        Args:
            scaled_tensor: The quantized tensor to dequantize

        Returns:
            The dequantized tensor
        """
        return BlockScaleDequantizer._dequantize_func(
            scaled_tensor.data,
            scaled_tensor.scale_inv,
            scaled_tensor.dq_dtype,
            scaled_tensor.scaling_mode,
            scaled_tensor.is_colwise,
            scaled_tensor.flatten_axis,
        )


ScalingModeToDequantizerMap = {
    ScalingMode.DELAYED_TENSOR_SCALING: TensorScaleDequantizer,
    ScalingMode.CURRENT_TENSOR_SCALING: TensorScaleDequantizer,
    ScalingMode.MXFP8_1D_SCALING: BlockScaleDequantizer,
}
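# E8M0 decode sketch (editor's example): a scale byte s encodes 2**(s - 127),
# so 124 -> 0.125, 127 -> 1.0, 130 -> 8.0; the format has no sign bit.
import jax.numpy as jnp

s = jnp.array([124, 127, 130], dtype=jnp.uint8)
factors = jnp.power(2.0, s.astype(jnp.float32) - 127)  # [0.125, 1.0, 8.0]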
@staticmethod
def _grouped_dequantize(grouped_scaled_tensor):
    """Dequantize a grouped tensor.

    Args:
        grouped_scaled_tensor: The grouped scaled tensor to dequantize

    Returns:
        List of dequantized tensors, one per group
    """
    data = grouped_scaled_tensor.data
    scale_inv = grouped_scaled_tensor.scale_inv
    group_sizes = grouped_scaled_tensor.group_sizes
    flatten_axis = grouped_scaled_tensor.flatten_axis
    scaling_mode = grouped_scaled_tensor.scaling_mode
    original_shape = grouped_scaled_tensor.original_shape
    group_axis = grouped_scaled_tensor.group_axis

    flatten_axis = len(original_shape) + flatten_axis if flatten_axis < 0 else flatten_axis

    output = []
    non_group_shape = tuple(
        original_shape[i] for i in range(len(original_shape)) if i != group_axis
    )
    matrix_sizes = group_sizes * math.prod(non_group_shape)
    data = jnp.split(data, jnp.cumulative_sum(matrix_sizes)[:-1])

    scale_inv_ptr = 0
    for i, data_i in enumerate(data):
        data_shape_i = (
            *original_shape[:group_axis],
            group_sizes[i],
            *original_shape[group_axis + 1 :],
        )
        assert math.prod(data_shape_i) == data_i.size, (
            f"math.prod({data_shape_i}) = {math.prod(data_shape_i)} which is not equal to"
            f" {data_i.size}"
        )
        scale_shape_i = scaling_mode.get_scale_shape(
            data_shape_i,
            grouped_scaled_tensor.is_colwise,
            is_padded=True,
            flatten_axis=flatten_axis,
        )
        scale_shape_i_size = math.prod(scale_shape_i)
        scale_inv_i = scale_inv[scale_inv_ptr : scale_inv_ptr + scale_shape_i_size]
        dequantizer_type = ScalingModeToDequantizerMap.get(grouped_scaled_tensor.scaling_mode)
        if len(data_i) == 0:
            out_i = []
        else:
            out_i = dequantizer_type._dequantize_func(
                data_i.reshape(data_shape_i),
                scale_inv_i.reshape(scale_shape_i),
                grouped_scaled_tensor.dq_dtype,
                scaling_mode=grouped_scaled_tensor.scaling_mode,
                is_colwise=grouped_scaled_tensor.is_colwise,
                flatten_axis=grouped_scaled_tensor.flatten_axis,
            )
        output.append(out_i)
        scale_inv_ptr += scale_shape_i_size

    return output


Dequantizer.grouped_dequantize = _grouped_dequantize
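# Editor's sketch of the split above: group_sizes scaled by the element count of
# the non-group dims gives per-group flat lengths; cumulative_sum marks the cuts.
import math
import jax.numpy as jnp

group_sizes = jnp.array([2, 0, 3])
non_group_numel = 4                           # math.prod(non_group_shape)
matrix_sizes = group_sizes * non_group_numel  # [8, 0, 12] elements per group
data = jnp.arange(20.0)
chunks = jnp.split(data, jnp.cumulative_sum(matrix_sizes)[:-1])
assert [c.size for c in chunks] == [8, 0, 12]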
@@ -9,7 +9,8 @@ This module provides classes and utilities for quantizing tensors in JAX.
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from functools import partial
from typing import Union, Optional, Tuple
import warnings

import jax
import jax.numpy as jnp
@@ -17,7 +18,7 @@ from jax.tree_util import register_pytree_node_class
from transformer_engine_jax import QuantizeLayout

from .scaling_modes import ScalingMode
from .tensor import ScaledTensor, ScaledTensor1x, ScaledTensor2x, ScaledTensorFactory
from .helper import (
    QuantizeConfig,
    AmaxComputeAlgo,
@@ -30,6 +31,7 @@ __all__ = [
    "CurrentScaleQuantizer",
    "DelayedScaleQuantizer",
    "BlockScaleQuantizer",
    "GroupedQuantizer",
    "QuantizerFactory",
    "noop_quantizer_set",
    "compute_scale_from_amax",
@@ -74,6 +76,7 @@ class Quantizer(ABC):
    q_dtype: jnp.dtype
    scaling_mode: ScalingMode
    q_layout: QuantizeLayout
    data_layout: str

    def tree_flatten(self):
        """Flatten the quantizer for JAX tree operations.
@@ -82,7 +85,7 @@ class Quantizer(ABC):
            Tuple of (children, aux_data) for tree operations
        """
        children = ()
        aux_data = (self.q_dtype, self.scaling_mode, self.q_layout, self.data_layout)
        return (children, aux_data)

    @classmethod
@@ -110,13 +113,22 @@ class Quantizer(ABC):
        """
        return self.q_layout == QuantizeLayout.ROWWISE_COLWISE

    def get_data_layout(self) -> str:
        """Get the data layout string.

        Returns:
            The data layout in string format

        Raises:
            ValueError: If quantization axis is invalid
        """
        if self.q_layout == QuantizeLayout.ROWWISE_COLWISE:
            return self.data_layout
        if self.q_layout == QuantizeLayout.ROWWISE:
            return self.data_layout[0]
        if self.q_layout == QuantizeLayout.COLWISE:
            return self.data_layout[1]
        raise ValueError(f"Invalid q_layout: {self.q_layout}")
    @abstractmethod
    def _quantize_func(self, x, is_colwise=False, dq_dtype=None, flatten_axis=-1) -> ScaledTensor1x:
@@ -132,7 +144,9 @@ class Quantizer(ABC):
            A ScaledTensor1x containing the quantized data
        """

    def quantize(
        self, x, is_rowwise=False, is_colwise=False, dq_dtype=None, flatten_axis=-1, **kwargs
    ) -> ScaledTensor:
        """Quantize a tensor using the internal _quantize_func().

        Args:
@@ -145,6 +159,7 @@ class Quantizer(ABC):
        Returns:
            A ScaledTensor1x or ScaledTensor2x containing the quantized data
        """
        del kwargs
        if (is_rowwise and is_colwise) or self.is_2x2x():
            rowwise_tensor = self._quantize_func(x, dq_dtype=dq_dtype, flatten_axis=flatten_axis)
            colwise_tensor = self._quantize_func(
@@ -159,7 +174,7 @@ class Quantizer(ABC):
        return self._quantize_func(x, dq_dtype=dq_dtype, flatten_axis=flatten_axis)

    def get_scale_shapes(self, data_shape, is_padded=True, flatten_axis=-1, **kwargs):
        """Get shapes for scale tensors.

        Args:
@@ -169,6 +184,7 @@ class Quantizer(ABC):
        Returns:
            Tuple of (rowwise_scale_shape, colwise_scale_shape)
        """
        del kwargs
        return self.scaling_mode.get_scale_shape_2x(data_shape, is_padded, flatten_axis)

    def get_scale_dtype(self):
@@ -194,24 +210,7 @@ class CurrentScaleQuantizer(Quantizer):
    scaling_mode: ScalingMode = ScalingMode.CURRENT_TENSOR_SCALING
    q_layout: QuantizeLayout = QuantizeLayout.ROWWISE_COLWISE
    data_layout: str = "NT"

    def _quantize_func(
        self, x: jnp.ndarray, is_colwise=False, dq_dtype=None, flatten_axis=-1
@@ -230,16 +229,11 @@ class CurrentScaleQuantizer(Quantizer):
        compute_dtype = jnp.float32
        dtype_max = (jnp.finfo(self.q_dtype).max).astype(compute_dtype)
        amax = jnp.max(jnp.abs(x)).reshape((1,))
        fp8_max = jnp.astype(jnp.finfo(self.q_dtype).max, jnp.float32)
        scale = (fp8_max / amax) / (2**QuantizeConfig.MARGIN)
        scaled_x = x.astype(compute_dtype) * scale
        clipped_scaled_x = jnp.clip(scaled_x, -dtype_max, dtype_max).astype(self.q_dtype)
        scale_inv = 1.0 / scale
        return ScaledTensorFactory.create_1x(
@@ -295,6 +289,7 @@ class CurrentScaleQuantizer(Quantizer):
            data_layout="T",
            flatten_axis=flatten_axis,
        )

        if is_colwise and is_rowwise:
            return ScaledTensor2x(rowwise_tensor, colwise_tensor)
        if is_colwise:
@@ -332,7 +327,7 @@ class DelayedScaleQuantizer(CurrentScaleQuantizer):
            Tuple of (children, aux_data) for tree operations
        """
        children = (self.scale, self.amax_history)
        aux_data = (self.q_dtype, self.scaling_mode, self.q_layout, self.data_layout)
        return (children, aux_data)

    def _quantize_func(
@@ -447,16 +442,7 @@ class BlockScaleQuantizer(Quantizer):
    scaling_mode: ScalingMode = ScalingMode.MXFP8_1D_SCALING
    q_layout: QuantizeLayout = QuantizeLayout.ROWWISE_COLWISE
    data_layout: str = "NN"

    def _quantize_func(self, x, is_colwise=False, dq_dtype=None, flatten_axis=-1) -> ScaledTensor1x:
        """Quantize function helper for block scaling FP8.
@@ -591,6 +577,189 @@ class QuantizerSet:
        return cls(*aux_data, *children)
@register_pytree_node_class
@dataclass
class GroupedQuantizer(Quantizer):
    """Quantizer for grouped arrays.

    This class extends Quantizer to support quantization of arrays in a grouped manner,
    where elements are grouped along a specified axis and then quantized separately.

    Attributes:
        data_layout: The data layout specification
        n_groups: Number of groups for quantization
        quantizers: Tuple of quantizers, one per group
    """

    data_layout: str = None
    n_groups: int = 1
    quantizers: Tuple[Quantizer] = field(default_factory=lambda: (None,))

    def tree_flatten(self):
        """Flatten the quantizer for JAX tree operations.

        Returns:
            Tuple of (children, aux_data) for tree operations
        """
        children = (self.quantizers,)
        aux_data = (self.q_dtype, self.scaling_mode, self.q_layout, self.data_layout, self.n_groups)
        return (children, aux_data)

    def __post_init__(self):
        if self.quantizers[0] is None:
            self.quantizers = QuantizerFactory.create(
                self.n_groups, self.scaling_mode, self.q_dtype, self.q_layout
            )
        self.data_layout = self.quantizers[0].data_layout

    def _create_grouped_tensor_from_tensor_list(
        self, tensor_list, group_sizes, original_shape, group_axis, mode
    ):
        # mode 0 = concatenate, mode 1 = add
        # TODO(Ming Huang): Consider applying an Enum for mode.
        assert mode in [0, 1]
        grouped_data = (
            [] if mode == 0 else jnp.zeros(tensor_list[0].data.shape, tensor_list[0].data.dtype)
        )
        grouped_scale_inv = []
        for tensor in tensor_list:
            if mode == 0:
                grouped_data.append(tensor.data.flatten())
            else:
                grouped_data += tensor.data
            grouped_scale_inv.append(tensor.scale_inv.flatten())

        grouped_data = jnp.concatenate(grouped_data) if mode == 0 else grouped_data.flatten()
        grouped_scale_inv = jnp.concatenate(grouped_scale_inv)

        return ScaledTensorFactory.create_1x(
            grouped_data,
            grouped_scale_inv,
            self.scaling_mode,
            tensor_list[0].dq_dtype,
            tensor_list[0].is_colwise,
            tensor_list[0].data_layout,
            tensor_list[0].flatten_axis,
            group_sizes=group_sizes,
            original_shape=original_shape,
            group_axis=group_axis,
        )

    def _quantize_func(self, *args, **kwargs):
        pass

    def quantize(
        self,
        x,
        is_rowwise: bool = None,
        is_colwise: bool = None,
        dq_dtype=None,
        flatten_axis=-1,
        group_sizes=None,
        group_axis=0,
    ):
        """Quantize a tensor in a grouped manner.

        Expected input shape: [M, K] or [G, K, N].
        Splits into x.shape[group_axis] groups if group_sizes is not given.

        Args:
            x: Input tensor to quantize
            is_rowwise: Whether to use row-wise quantization
            is_colwise: Whether to use column-wise quantization
            dq_dtype: Data type for dequantized values
            flatten_axis: The axis along which the tensor could be flattened to 2D (default: -1)
            group_sizes: Array of ints containing the size of each group (default: None)
            group_axis: The axis along which grouping is performed (default: 0)

        Returns:
            A ScaledTensor1x or ScaledTensor2x containing the quantized data
        """
        assert group_axis == 0, "Only group_axis == 0 is supported now!"
        dq_dtype = dq_dtype if dq_dtype is not None else x.dtype
        if flatten_axis < 0:
            flatten_axis += x.ndim
        assert 0 < flatten_axis < x.ndim, "flatten_axis is out of bounds!"

        is_rowwise = (
            is_rowwise
            if is_rowwise is not None
            else (self.q_layout == QuantizeLayout.ROWWISE or self.is_2x2x())
        )
        is_colwise = (
            is_colwise
            if is_colwise is not None
            else (self.q_layout == QuantizeLayout.COLWISE or self.is_2x2x())
        )
        assert is_rowwise or is_colwise, "No quantization layout is specified"

        original_shape = x.shape

        if group_sizes is not None:
            assert not is_colwise, "Not yet implemented!"
            assert group_sizes.ndim == 1, (
                "GroupedQuantizer only supports 1D group_sizes, got group_sizes.ndim ="
                f" {group_sizes.ndim}"
            )
            _zeros = partial(jax.lax.full_like, fill_value=0)

            x_iota = jax.lax.broadcasted_iota(group_sizes.dtype, x.shape, 0)
            group_ends = jnp.cumulative_sum(group_sizes)
            group_starts = jax.lax.concatenate(
                [_zeros(group_sizes)[:1], group_ends[:-1]],
                dimension=0,
            )
            x_zero = _zeros(x)

            tensor_list = []
            for i in range(len(group_sizes)):
                mask = jax.lax.bitwise_and(group_starts[i] <= x_iota, x_iota < group_ends[i])
                x_selected = jax.lax.select(mask, x, x_zero)
                tensor = self.quantizers[i].quantize(
                    x_selected, is_rowwise, is_colwise, dq_dtype, flatten_axis
                )
                tensor_list.append(tensor)
            combine_mode = 1  # add
        else:
            group_sizes = jnp.ones(x.shape[group_axis], dtype=jnp.int32)
            x = jnp.split(x, x.shape[group_axis], axis=group_axis)
            tensor_list = []
            for i in range(len(group_sizes)):
                tensor = self.quantizers[i].quantize(
                    x[i], is_rowwise, is_colwise, dq_dtype, flatten_axis
                )
                tensor_list.append(tensor)
            combine_mode = 0  # concatenate

        grouped_rowwise_tensor = grouped_colwise_tensor = None
        if is_rowwise:
            rowwise_tensor_list = [tensor.get_rowwise_tensor() for tensor in tensor_list]
            grouped_rowwise_tensor = self._create_grouped_tensor_from_tensor_list(
                rowwise_tensor_list, group_sizes, original_shape, group_axis, combine_mode
            )
        if is_colwise:
            colwise_tensor_list = [tensor.get_colwise_tensor() for tensor in tensor_list]
            grouped_colwise_tensor = self._create_grouped_tensor_from_tensor_list(
                colwise_tensor_list, group_sizes, original_shape, group_axis, combine_mode
            )

        if is_colwise and is_rowwise:
            return ScaledTensor2x(grouped_rowwise_tensor, grouped_colwise_tensor)
        if is_colwise:
            return grouped_colwise_tensor
        return grouped_rowwise_tensor

    def get_scale_shapes(self, data_shape, is_padded=True, flatten_axis=-1, group_sizes=None):
        assert group_sizes is not None, "Empty group_sizes was given!"
        return self.scaling_mode.get_grouped_scale_shape_2x(
            data_shape, group_sizes, is_padded, flatten_axis
        )
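# Editor's sketch of the masking used in GroupedQuantizer.quantize above:
# broadcasted_iota labels each row, and rows outside [start, end) of a group
# are zeroed before that group's quantizer runs.
import jax
import jax.numpy as jnp

x = jnp.arange(12.0).reshape(6, 2)
group_sizes = jnp.array([2, 4])
x_iota = jax.lax.broadcasted_iota(group_sizes.dtype, x.shape, 0)  # row indices
group_ends = jnp.cumulative_sum(group_sizes)
mask_g1 = (group_ends[0] <= x_iota) & (x_iota < group_ends[1])    # rows 2..5
x_g1 = jnp.where(mask_g1, x, 0.0)                                 # group 1 rows only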
@dataclass
class QuantizerFactory:
    """Factory class for creating quantizers.
@@ -611,6 +780,7 @@ class QuantizerFactory:
        scaling_mode: ScalingMode = None,
        q_dtype: jnp.dtype = None,
        q_layout: QuantizeLayout = None,
        n_groups: int = None,
        **kwargs,
    ) -> Quantizer:
        """Create one or more quantizers with specified parameters.
@@ -621,6 +791,7 @@ class QuantizerFactory:
            q_dtype: Quantization data type
            q_layout: Quantization axis
            flatten_axis: The quantization axis for the tensor
            n_groups: Number of sub-quantizers if a GroupedQuantizer is requested
            **kwargs: Additional arguments for quantizer initialization

        Returns:
@@ -628,13 +799,21 @@ class QuantizerFactory:
        """
        # (Phuong): add this assert back when NVTE_NO_SCALING is fully implemented
        assert isinstance(scaling_mode, ScalingMode), "Invalid scaling_mode type"
        if n_groups:
            if n_quantizers != 1:
                warnings.warn(
                    "Using more than one GroupedQuantizer for a grouped input is not recommended"
                )
            quantizer_type = GroupedQuantizer
            kwargs["n_groups"] = n_groups
        else:
            quantizer_type = QuantizerFactory.quantizer_type_map.get(scaling_mode)
        if scaling_mode == ScalingMode.NO_SCALING:
            quantizers = [None] * n_quantizers
        else:
            quantizers = []
            for _ in range(n_quantizers):
                quantizers.append(
                    quantizer_type(
                        q_dtype=q_dtype, scaling_mode=scaling_mode, q_layout=q_layout, **kwargs
@@ -643,7 +822,9 @@ class QuantizerFactory:
        return quantizers[0] if len(quantizers) == 1 else tuple(quantizers)
    @staticmethod
    def _create_set(
        scaling_mode, fwd_dtype, bwd_dtype, is_2x2x, n_groups, **kwargs
    ) -> QuantizerSet:
        """Create a set of quantizers for forward and backward passes.

        Args:
@@ -651,6 +832,7 @@ class QuantizerFactory:
            fwd_dtype: Data type for forward pass
            bwd_dtype: Data type for backward pass
            is_2x2x: Whether to use 2x2x quantization
            n_groups: Number of groups for grouped quantization
            **kwargs: Additional arguments for quantizer initialization

        Returns:
@@ -680,11 +862,13 @@ class QuantizerFactory:
        else:
            args_x = args_kernel = args_grad = {}

        q_x = QuantizerFactory.create(1, scaling_mode, fwd_dtype, q_layout_x, n_groups, **args_x)
        q_kernel = QuantizerFactory.create(
            1, scaling_mode, fwd_dtype, q_layout_kernel, n_groups, **args_kernel
        )
        q_dgrad = QuantizerFactory.create(
            1, scaling_mode, bwd_dtype, q_layout_dgrad, n_groups, **args_grad
        )
        return QuantizerSet(x=q_x, kernel=q_kernel, dgrad=q_dgrad)

    @staticmethod
@@ -694,6 +878,7 @@ class QuantizerFactory:
        fwd_dtype: jnp.dtype = None,
        bwd_dtype: jnp.dtype = None,
        is_2x2x: bool = None,
        n_groups: int = None,
        **kwargs,
    ) -> tuple[Union[tuple[Quantizer], None]]:
        """Create one or more sets of quantizers.
@@ -704,6 +889,7 @@ class QuantizerFactory:
            fwd_dtype: Data type for forward pass, default is QuantizeConfig.FWD_DTYPE
            bwd_dtype: Data type for backward pass, default is QuantizeConfig.BWD_DTYPE
            is_2x2x: Whether to use 2x2x quantization, default is QuantizeConfig.IF_QUANTIZE_2X
            n_groups: Number of groups for grouped quantization
            **kwargs: Additional arguments for quantizer initialization

        Returns:
@@ -717,7 +903,9 @@ class QuantizerFactory:
        q_set = []
        for _ in range(n_quantizer_sets):
            q_set.append(
                QuantizerFactory._create_set(
                    scaling_mode, fwd_dtype, bwd_dtype, is_2x2x, n_groups, **kwargs
                )
            )

        return q_set[0] if len(q_set) == 1 else tuple(q_set)
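# Usage sketch (editor's example): requesting a single grouped quantizer; passing
# n_groups routes creation to GroupedQuantizer, which holds one sub-quantizer
# per group (here 8 BlockScaleQuantizers).
quantizer = QuantizerFactory.create(
    1,
    ScalingMode.MXFP8_1D_SCALING,
    jnp.float8_e4m3fn,
    QuantizeLayout.ROWWISE_COLWISE,
    n_groups=8,
)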
@@ -15,6 +15,7 @@ from enum import Enum
from typing import Tuple, Dict
from functools import reduce
import operator

import numpy as np

from jax.experimental.custom_partitioning import CompoundFactor
from jax.tree_util import register_pytree_node_class
@@ -26,6 +27,11 @@ from transformer_engine_jax import JAXX_Scaling_Mode
__all__ = ["QuantizeShardyRules", "ScalingMode"]


def DIVUP(a, b):
    """Divide a by b and round up."""
    return -(a // -b)
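# Editor's check: -(a // -b) is integer ceiling division, avoiding floats.
assert DIVUP(7, 4) == 2 and DIVUP(8, 4) == 2 and DIVUP(9, 4) == 3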
@dataclass
class QuantizeShardyRules:
    """Information necessary to shard scale tensors with Shardy.
@@ -74,7 +80,26 @@ class ScalingModeMetadataImpl(ABC):
            data_shape: The shape of the tensor being quantized
            is_colwise: Whether the scaling is column-wise
            is_padded: Whether to return padded shape
            flatten_axis: The axis along which the tensor could be flattened to 2D (default: -1)

        Returns:
            The shape for scale tensors
        """

    @abstractmethod
    def get_grouped_scale_shape(
        self, data_shape, n_groups, group_axis, is_colwise, is_padded=True, flatten_axis=-1
    ) -> Tuple[int]:
        """Get the shape for grouped scale tensors in this mode.

        Args:
            data_shape: Original shape of the data tensor
            n_groups: Number of groups in grouped quantization
            group_axis: The axis along which grouping is performed
            is_colwise: Whether to use column-wise scaling
            is_padded: Whether to use padded shapes
            flatten_axis: The axis along which the tensor could be flattened to 2D (default: -1)

        Returns:
            The shape for scale tensors
        """
@@ -127,9 +152,29 @@ class CurrentScalingModeMetadataImpl(ScalingModeMetadataImpl):
        Returns:
            The shape for scale tensors - (1,)
        """
        del is_colwise
        if np.prod(data_shape) == 0:
            return (0,)
        return (1,)

    def get_grouped_scale_shape(
        self, data_shape, n_groups, group_axis, is_colwise, is_padded=True, flatten_axis=-1
    ) -> Tuple[int]:
        """Get the shape for grouped scale tensors in this mode.

        Args:
            data_shape: Original shape of the data tensor
            n_groups: Number of groups in grouped quantization
            group_axis: The axis along which grouping is performed
            is_colwise: Whether to use column-wise scaling
            is_padded: Whether to use padded shapes
            flatten_axis: The axis along which the tensor could be flattened to 2D (default: -1)

        Returns:
            The shape for scale tensors - (n_groups,)
        """
        del data_shape, group_axis, is_colwise
        assert isinstance(n_groups, int)
        return (n_groups,)
    def get_shardy_sharding_rules(
        self, input_rank, unique_var, flatten_axis
    ) -> QuantizeShardyRules:
@@ -276,6 +321,77 @@ class BlockScalingModeMetadataImpl(ScalingModeMetadataImpl):
        return (*first_dim_scale_shape, *last_dim_scale_shape)

    def get_grouped_scale_shape(
        self, data_shape, n_groups, group_axis, is_colwise, is_padded=True, flatten_axis=-1
    ) -> Tuple[int]:
        """Get the shape for grouped scale tensors in this mode.

        If padded, the estimated maximal possible shape for the grouped scale tensor is
        returned instead.

        Args:
            data_shape: Original shape of the data tensor
            n_groups: Number of groups in grouped quantization
            group_axis: The axis along which grouping is performed
            is_colwise: Whether to use column-wise scaling
            is_padded: Whether to use padded shapes
            flatten_axis: The axis along which the tensor could be flattened to 2D (default: -1)

        Returns:
            The shape for scale tensors
        """
        assert isinstance(n_groups, int)
        block_alignment = self._block_alignment if is_padded else (1, 1)

        if is_colwise:
            block_y, block_x = self._block_dims
            alignment_y, alignment_x = block_alignment
        else:
            block_x, block_y = self._block_dims
            alignment_x, alignment_y = block_alignment

        if flatten_axis < 0:
            flatten_axis = len(data_shape) + flatten_axis
        assert (
            0 < flatten_axis < len(data_shape)
        ), f"flatten_axis {flatten_axis} is out of bounds for shape {data_shape}"

        assert data_shape[flatten_axis - 1] % block_x == 0, (
            f"Data shape {data_shape} should be divisible by block_x {block_x} in axis"
            f" {flatten_axis - 1}"
        )
        assert (
            data_shape[-1] % block_y == 0
        ), f"Data shape {data_shape} should be divisible by block_y {block_y} in axis -1"

        flattened_first_dim = reduce(operator.mul, data_shape[:flatten_axis], 1)
        flattened_last_dim = reduce(operator.mul, data_shape[flatten_axis:], 1)

        assert flattened_first_dim % block_x == 0, (
            f"Flattened first dim - multiplication of axes={tuple(range(0, flatten_axis))} of shape"
            f" {data_shape} - should be divisible by block_x {block_x}"
        )
        assert flattened_last_dim % block_y == 0, (
            "Flattened last dim - multiplication of"
            f" axes={tuple(range(flatten_axis, len(data_shape)))} of shape {data_shape} - should be"
            f" divisible by block_y {block_y}"
        )

        n_block_x = int(flattened_first_dim // block_x)
        n_block_y = int(flattened_last_dim // block_y)

        # Given a scale shape of [M, N], G groups, and padding alignment (128, 4), the
        # worst case is (G - 1) groups with 1 row each and 1 group with (M - G + 1) rows:
        #   max_padded_rows = (G - 1) * 128 + DIVUP(M - G + 1, 128) * 128
        #   max_padded_cols = DIVUP(N, 4) * 4
        #   max_scale_size  = max_padded_rows * max_padded_cols
        if is_padded:
            n_block_x = (n_groups - 1) * alignment_x + DIVUP(
                n_block_x - n_groups + 1, alignment_x
            ) * alignment_x
            n_block_y = DIVUP(n_block_y, alignment_y) * alignment_y

        return (n_block_x * n_block_y,)
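# Editor's worked example for the worst case above: n_block_x = 256 scale rows,
# n_block_y = 10 scale columns, n_groups = 4, alignment (128, 4):
#   rows = (4 - 1) * 128 + DIVUP(256 - 4 + 1, 128) * 128 = 384 + 256 = 640
#   cols = DIVUP(10, 4) * 4 = 12, so the padded buffer holds 640 * 12 = 7680 scales
assert (4 - 1) * 128 + DIVUP(256 - 4 + 1, 128) * 128 == 640
assert DIVUP(10, 4) * 4 == 12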
    def get_shardy_sharding_rules(
        self, input_rank, unique_var, flatten_axis
    ) -> QuantizeShardyRules:
@@ -404,6 +520,61 @@ class ScalingMode(Enum):
        """
        return self._get_impl().get_shardy_sharding_rules(input_rank, unique_var, flatten_axis)

    def get_grouped_scale_shape_2x(
        self, data_shape, n_groups, group_axis, is_padded=True, flatten_axis=-1
    ) -> Tuple[Tuple[int]]:
        """Get grouped scale shapes for both row-wise and column-wise scaling.

        Args:
            data_shape: Shape of the data tensor
            n_groups: Number of groups for grouped quantization
            group_axis: The axis along which grouping is performed
            is_padded: Whether to use padded shapes
            flatten_axis: The axis along which the tensor could be flattened to 2D (default: -1)

        Returns:
            Tuple of (rowwise_scale_shape, colwise_scale_shape)
        """
        rowwise_scale_shape = self.get_grouped_scale_shape(
            data_shape,
            n_groups,
            group_axis,
            is_colwise=False,
            is_padded=is_padded,
            flatten_axis=flatten_axis,
        )
        colwise_scale_shape = self.get_grouped_scale_shape(
            data_shape,
            n_groups,
            group_axis,
            is_colwise=True,
            is_padded=is_padded,
            flatten_axis=flatten_axis,
        )
        return (rowwise_scale_shape, colwise_scale_shape)

    def get_grouped_scale_shape(
        self, data_shape, n_groups, group_axis, is_colwise, is_padded=True, flatten_axis=-1
    ) -> Tuple[int]:
        """Get the grouped scale shape for one scaling direction.

        Args:
            data_shape: Shape of the data tensor
            n_groups: Number of groups for grouped quantization
            group_axis: The axis along which grouping is performed
            is_colwise: Whether to use column-wise scaling
            is_padded: Whether to use padded shapes
            flatten_axis: The axis along which the tensor could be flattened to 2D (default: -1)

        Returns:
            The shape for scale tensors
        """
        return self._get_impl().get_grouped_scale_shape(
            data_shape,
            n_groups,
            group_axis,
            is_colwise=is_colwise,
            is_padded=is_padded,
            flatten_axis=flatten_axis,
        )

    def is_tensor_scaling(self) -> bool:
        """Check if this scaling mode is per-tensor scaling.
...@@ -18,7 +18,7 @@ from jax.tree_util import register_pytree_node_class ...@@ -18,7 +18,7 @@ from jax.tree_util import register_pytree_node_class
from transformer_engine_jax import QuantizeLayout from transformer_engine_jax import QuantizeLayout
from .scaling_modes import ScalingMode from .scaling_modes import ScalingMode
from .dequantizer import Dequantizer from .dequantizer import ScalingModeToDequantizerMap
from ..sharding import ( from ..sharding import (
with_sharding_constraint_by_logical_axes as original_with_sharding_constraint_by_logical_axes, with_sharding_constraint_by_logical_axes as original_with_sharding_constraint_by_logical_axes,
) )
...@@ -27,6 +27,7 @@ __all__ = [ ...@@ -27,6 +27,7 @@ __all__ = [
"ScaledTensor", "ScaledTensor",
"ScaledTensor1x", "ScaledTensor1x",
"ScaledTensor2x", "ScaledTensor2x",
"GroupedScaledTensor1x",
"ScaledTensorFactory", "ScaledTensorFactory",
"with_sharding_constraint_by_logical_axes", "with_sharding_constraint_by_logical_axes",
] ]
...@@ -122,7 +123,7 @@ class ScaledTensor1x(ScaledTensor): ...@@ -122,7 +123,7 @@ class ScaledTensor1x(ScaledTensor):
_dq_func: Callable _dq_func: Callable
is_colwise: bool is_colwise: bool
data_layout: str data_layout: str
flatten_axis: int = -1 flatten_axis: int
def __post_init__(self): def __post_init__(self):
"""Validates and adjusts the scale_inv shape after initialization. """Validates and adjusts the scale_inv shape after initialization.
...@@ -130,22 +131,16 @@ class ScaledTensor1x(ScaledTensor): ...@@ -130,22 +131,16 @@ class ScaledTensor1x(ScaledTensor):
Ensures the scale_inv shape matches the expected shape based on the scaling mode Ensures the scale_inv shape matches the expected shape based on the scaling mode
and quantization direction. Pads the scale_inv if necessary. and quantization direction. Pads the scale_inv if necessary.
""" """
flatten_axis = ( assert self.flatten_axis > 0
len(self.data.shape) + self.flatten_axis if self.flatten_axis < 0 else self.flatten_axis
)
assert ( assert (
0 < flatten_axis < len(self.data.shape) 0 < self.flatten_axis < len(self.data.shape)
), f"flatten_axis {flatten_axis} is out of bounds for shape {self.data.shape}" ), f"flatten_axis {self.flatten_axis} is out of bounds for shape {self.data.shape}"
if self.data_layout == "T":
flatten_axis = self.data.ndim - flatten_axis
self.flatten_axis = flatten_axis
expected_scale_shape = self.scaling_mode.get_scale_shape( expected_scale_shape = self.scaling_mode.get_scale_shape(
self.data.shape, self.is_colwise, is_padded=True, flatten_axis=flatten_axis self.data.shape, self.is_colwise, is_padded=True, flatten_axis=self.flatten_axis
) )
expected_unpadded_scale_shape = self.scaling_mode.get_scale_shape( expected_unpadded_scale_shape = self.scaling_mode.get_scale_shape(
self.data.shape, self.is_colwise, is_padded=False, flatten_axis=flatten_axis self.data.shape, self.is_colwise, is_padded=False, flatten_axis=self.flatten_axis
) )
if self.scale_inv.shape != expected_scale_shape: if self.scale_inv.shape != expected_scale_shape:
assert self.scale_inv.shape == expected_unpadded_scale_shape, ( assert self.scale_inv.shape == expected_unpadded_scale_shape, (
...@@ -229,8 +224,12 @@ class ScaledTensor1x(ScaledTensor): ...@@ -229,8 +224,12 @@ class ScaledTensor1x(ScaledTensor):
# axis_names were given for N layout, so needs to be transpose for T layout # axis_names were given for N layout, so needs to be transpose for T layout
if self.data_layout == "T": if self.data_layout == "T":
assert self.flatten_axis > 0 assert self.flatten_axis > 0
flatten_axis = -self.flatten_axis assert len(logical_axis_names) == self.data.ndim
axis_names = (*logical_axis_names[flatten_axis:], *logical_axis_names[:flatten_axis]) flatten_axis = self.data.ndim - self.flatten_axis
axis_names = (
*logical_axis_names[flatten_axis:],
*logical_axis_names[:flatten_axis],
)
else: else:
axis_names = logical_axis_names axis_names = logical_axis_names
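For intuition, a worked example of the rotation above (hypothetical axis names; data.ndim = 3, self.flatten_axis = 2, layout "T"):

logical_axis_names = ("batch", "seq", "hidden")
flatten_axis = 3 - 2  # self.data.ndim - self.flatten_axis
axis_names = (*logical_axis_names[1:], *logical_axis_names[:1])
# -> ("seq", "hidden", "batch"): the names now describe the transposed data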
...@@ -254,6 +253,98 @@ class ScaledTensor1x(ScaledTensor): ...@@ -254,6 +253,98 @@ class ScaledTensor1x(ScaledTensor):
) )
@register_pytree_node_class
@dataclass
class GroupedScaledTensor1x(ScaledTensor1x):
"""Grouped Quantizer for an array.
This class extends ScaledTensor1x to support quantization of an array in grouped manner,
where elements are grouped along a specified axis.
Attributes:
group_sizes: Array containing the size of each group
original_shape: The original shape of the tensor before grouping
group_axis: The axis along which grouping is performed (default: 0)
"""
group_sizes: jnp.ndarray
original_shape: Tuple
group_axis: int
def __init__(
self,
data,
scale_inv,
group_sizes,
scaling_mode,
dq_dtype,
_dq_func,
is_colwise,
data_layout,
flatten_axis,
original_shape,
group_axis=0,
):
self.flatten_axis = flatten_axis
self.group_sizes = group_sizes
self.original_shape = original_shape
self.group_axis = group_axis
super().__init__(
data, scale_inv, scaling_mode, dq_dtype, _dq_func, is_colwise, data_layout, flatten_axis
)
def __post_init__(self):
assert self.scale_inv.ndim == 1, "Only support flattened scale_inv"
assert self.data.ndim == 1, "Only support flattened data"
assert self.group_axis >= 0
assert self.flatten_axis > 0
data_ndim = len(self.original_shape)
assert (
0 < self.flatten_axis < data_ndim
), f"flatten_axis {self.flatten_axis} is out of bounds for data.ndim = {data_ndim}"
assert (
0 <= self.group_axis < data_ndim
), f"group_axis {self.group_axis} is out of bounds for shape {self.original_shape}"
expected_scale_shape = self.scaling_mode.get_grouped_scale_shape(
self.original_shape,
self.group_sizes.size,
self.group_axis,
self.is_colwise,
is_padded=True,
flatten_axis=self.flatten_axis,
)
assert self.scale_inv.shape == expected_scale_shape, (
f"Unexpected scale_inv shape! \nExpect {expected_scale_shape} for padded"
f" scale_inv, got {self.scale_inv.shape}"
)
def tree_flatten(self):
"""Flattens the tensor for JAX tree operations.
Returns:
A tuple containing (children, aux_data) for tree operations
"""
children = (self.data, self.scale_inv, self.group_sizes)
aux_data = (
self.scaling_mode,
self.dq_dtype,
self._dq_func,
self.is_colwise,
self.data_layout,
self.flatten_axis,
self.original_shape,
self.group_axis,
)
return (children, aux_data)
def apply_sharding_constraint_by_logical_axes(self, logical_axis_names: Tuple[str, ...]):
raise NotImplementedError
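Since GroupedScaledTensor1x is registered as a pytree node, it round-trips through JAX tree utilities; a minimal sketch, assuming a grouped tensor `t` already built via ScaledTensorFactory.create_1x with group_sizes:

import jax

leaves, treedef = jax.tree_util.tree_flatten(t)
# leaves == [t.data, t.scale_inv, t.group_sizes]; scaling_mode, dq_dtype,
# _dq_func, is_colwise, data_layout, flatten_axis, original_shape and
# group_axis ride along as static aux_data
t2 = jax.tree_util.tree_unflatten(treedef, leaves)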
@register_pytree_node_class @register_pytree_node_class
@dataclass @dataclass
class ScaledTensor2x(ScaledTensor): class ScaledTensor2x(ScaledTensor):
...@@ -342,6 +433,9 @@ class ScaledTensorFactory: ...@@ -342,6 +433,9 @@ class ScaledTensorFactory:
is_colwise=False, is_colwise=False,
data_layout="N", data_layout="N",
flatten_axis=-1, flatten_axis=-1,
group_sizes=None,
original_shape=None,
group_axis=0,
): ):
"""Creates a single-scale quantized tensor. """Creates a single-scale quantized tensor.
...@@ -353,13 +447,67 @@ class ScaledTensorFactory: ...@@ -353,13 +447,67 @@ class ScaledTensorFactory:
is_colwise: Whether to use column-wise quantization (default: False) is_colwise: Whether to use column-wise quantization (default: False)
data_layout: The data_layout specification (default: "N") data_layout: The data_layout specification (default: "N")
flatten_axis: The quantization axis for the tensor flatten_axis: The quantization axis for the tensor
group_sizes: Array of ints containing the size of each group (default: None)
original_shape: The original shape of the tensor before grouping (default: None)
group_axis: The axis along which grouping is performed (default: 0)
Returns: Returns:
A ScaledTensor1x instance A ScaledTensor1x or GroupedScaledTensor1x instance depending on whether group_sizes is provided
""" """
dq_func = Dequantizer.funcs.get(scaling_mode) dequantizer = ScalingModeToDequantizerMap.get(scaling_mode)
if group_sizes is not None:
assert (
original_shape is not None
), "original_shape is not given for GroupedScaledTensor1x"
flatten_axis = len(original_shape) + flatten_axis if flatten_axis < 0 else flatten_axis
# Handling attrs of transposed tensors
group_axis = len(original_shape) + group_axis if group_axis < 0 else group_axis
if data_layout == "T":
if original_shape[0] == group_sizes.size:
original_shape = (
original_shape[0],
*original_shape[flatten_axis:],
*original_shape[1:flatten_axis],
)
flatten_axis = len(original_shape) - flatten_axis + 1
else:
original_shape = (
*original_shape[flatten_axis:],
*original_shape[:flatten_axis],
)
group_axis = flatten_axis
flatten_axis = len(original_shape) - flatten_axis
return GroupedScaledTensor1x(
data=data,
scale_inv=scale_inv,
scaling_mode=scaling_mode,
dq_dtype=dq_dtype,
_dq_func=dequantizer.grouped_dequantize,
is_colwise=is_colwise,
data_layout=data_layout,
flatten_axis=flatten_axis,
group_sizes=group_sizes,
original_shape=original_shape,
group_axis=group_axis,
)
# Handling attrs of transposed tensors
flatten_axis = data.ndim + flatten_axis if flatten_axis < 0 else flatten_axis
if data_layout == "T":
flatten_axis = data.ndim - flatten_axis
return ScaledTensor1x( return ScaledTensor1x(
data, scale_inv, scaling_mode, dq_dtype, dq_func, is_colwise, data_layout, flatten_axis data,
scale_inv,
scaling_mode,
dq_dtype,
dequantizer.dequantize,
is_colwise,
data_layout,
flatten_axis,
) )
@staticmethod @staticmethod
...@@ -372,6 +520,9 @@ class ScaledTensorFactory: ...@@ -372,6 +520,9 @@ class ScaledTensorFactory:
dq_dtype=jnp.bfloat16, dq_dtype=jnp.bfloat16,
data_layout="NN", data_layout="NN",
flatten_axis=-1, flatten_axis=-1,
group_sizes=None,
original_shape=None,
group_axis=0,
): ):
"""Creates a double-scale quantized tensor. """Creates a double-scale quantized tensor.
...@@ -384,30 +535,37 @@ class ScaledTensorFactory: ...@@ -384,30 +535,37 @@ class ScaledTensorFactory:
dq_dtype: The data type for dequantized values (default: bfloat16) dq_dtype: The data type for dequantized values (default: bfloat16)
data_layout: The data_layout specification (default: "NN") data_layout: The data_layout specification (default: "NN")
flatten_axis: The quantization axis for the tensor flatten_axis: The quantization axis for the tensor
group_sizes: Array containing the size of each group (default: None)
original_shape: The original shape of the tensor before grouping (default: None)
group_axis: The axis along which grouping is performed (default: 0)
Returns: Returns:
A ScaledTensor2x instance A ScaledTensor2x instance
""" """
dq_func = Dequantizer.funcs.get(scaling_mode) assert len(data_layout) == 2, f"Expect 2 layouts, got {data_layout}"
rowwise_tensor = ScaledTensor1x( rowwise_tensor = ScaledTensorFactory.create_1x(
data, data,
scale_inv, scale_inv,
scaling_mode, scaling_mode,
dq_dtype, dq_dtype,
dq_func,
is_colwise=False, is_colwise=False,
data_layout=data_layout[0], data_layout=data_layout[0],
flatten_axis=flatten_axis, flatten_axis=flatten_axis,
group_sizes=group_sizes,
original_shape=original_shape,
group_axis=group_axis,
) )
colwise_tensor = ScaledTensor1x( colwise_tensor = ScaledTensorFactory.create_1x(
colwise_data, colwise_data,
colwise_scale_inv, colwise_scale_inv,
scaling_mode, scaling_mode,
dq_dtype, dq_dtype,
dq_func,
is_colwise=True, is_colwise=True,
data_layout=data_layout[1], data_layout=data_layout[1],
flatten_axis=flatten_axis, flatten_axis=flatten_axis,
group_sizes=group_sizes,
original_shape=original_shape,
group_axis=group_axis,
) )
return ScaledTensor2x(rowwise_tensor, colwise_tensor) return ScaledTensor2x(rowwise_tensor, colwise_tensor)
...@@ -422,6 +580,9 @@ class ScaledTensorFactory: ...@@ -422,6 +580,9 @@ class ScaledTensorFactory:
data_layout: str = "NN", data_layout: str = "NN",
q_layout: QuantizeLayout = QuantizeLayout.ROWWISE, q_layout: QuantizeLayout = QuantizeLayout.ROWWISE,
flatten_axis: int = -1, flatten_axis: int = -1,
group_sizes: jnp.ndarray = None,
original_shape: Tuple[int] = None,
group_axis: int = 0,
): ):
"""Creates a scaled tensor based on the quantization axis. """Creates a scaled tensor based on the quantization axis.
...@@ -434,6 +595,10 @@ class ScaledTensorFactory: ...@@ -434,6 +595,10 @@ class ScaledTensorFactory:
dq_dtype: The data type for dequantized values (default: bfloat16) dq_dtype: The data type for dequantized values (default: bfloat16)
data_layout: The data_layout specification (default: "NN") data_layout: The data_layout specification (default: "NN")
q_layout: The quantization axis (default: ROWWISE) q_layout: The quantization axis (default: ROWWISE)
flatten_axis: The axis along which the tensor could be flattened to 2D (default: -1)
group_sizes: Array containing the size of each group (default: None)
original_shape: The original shape of the tensor before grouping (default: None)
group_axis: The axis along which grouping is performed (default: 0)
Returns: Returns:
Either a ScaledTensor1x or ScaledTensor2x instance depending on q_layout Either a ScaledTensor1x or ScaledTensor2x instance depending on q_layout
...@@ -448,9 +613,26 @@ class ScaledTensorFactory: ...@@ -448,9 +613,26 @@ class ScaledTensorFactory:
dq_dtype, dq_dtype,
data_layout=data_layout, data_layout=data_layout,
flatten_axis=flatten_axis, flatten_axis=flatten_axis,
group_sizes=group_sizes,
original_shape=original_shape,
group_axis=group_axis,
) )
is_colwise = q_layout == QuantizeLayout.COLWISE is_colwise = q_layout == QuantizeLayout.COLWISE
if is_colwise:
return ScaledTensorFactory.create_1x(
colwise_data,
colwise_scale_inv,
scaling_mode,
dq_dtype,
is_colwise=is_colwise,
data_layout=data_layout[0],
flatten_axis=flatten_axis,
group_sizes=group_sizes,
original_shape=original_shape,
group_axis=group_axis,
)
return ScaledTensorFactory.create_1x( return ScaledTensorFactory.create_1x(
data, data,
scale_inv, scale_inv,
...@@ -459,6 +641,9 @@ class ScaledTensorFactory: ...@@ -459,6 +641,9 @@ class ScaledTensorFactory:
is_colwise=is_colwise, is_colwise=is_colwise,
data_layout=data_layout[0], data_layout=data_layout[0],
flatten_axis=flatten_axis, flatten_axis=flatten_axis,
group_sizes=group_sizes,
original_shape=original_shape,
group_axis=group_axis,
) )
...@@ -472,6 +657,9 @@ def with_sharding_constraint_by_logical_axes(x, logical_axis_names: Tuple[str, . ...@@ -472,6 +657,9 @@ def with_sharding_constraint_by_logical_axes(x, logical_axis_names: Tuple[str, .
Returns: Returns:
The tensor with applied sharding constraints The tensor with applied sharding constraints
""" """
if isinstance(x, GroupedScaledTensor1x):
raise NotImplementedError
if isinstance(x, ScaledTensor): if isinstance(x, ScaledTensor):
return x.apply_sharding_constraint_by_logical_axes(logical_axis_names) return x.apply_sharding_constraint_by_logical_axes(logical_axis_names)
......
...@@ -44,11 +44,10 @@ if bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))) or os.path.isdir(build_tools_ ...@@ -44,11 +44,10 @@ if bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))) or os.path.isdir(build_tools_
from build_tools.build_ext import get_build_ext from build_tools.build_ext import get_build_ext
from build_tools.utils import copy_common_headers, install_and_import from build_tools.utils import copy_common_headers
from build_tools.te_version import te_version from build_tools.te_version import te_version
from build_tools.jax import setup_jax_extension from build_tools.jax import setup_jax_extension, install_requirements, test_requirements
install_and_import("pybind11")
from pybind11.setup_helpers import build_ext as BuildExtension from pybind11.setup_helpers import build_ext as BuildExtension
os.environ["NVTE_PROJECT_BUILDING"] = "1" os.environ["NVTE_PROJECT_BUILDING"] = "1"
...@@ -101,19 +100,8 @@ if __name__ == "__main__": ...@@ -101,19 +100,8 @@ if __name__ == "__main__":
description="Transformer acceleration library - Jax Lib", description="Transformer acceleration library - Jax Lib",
ext_modules=ext_modules, ext_modules=ext_modules,
cmdclass={"build_ext": CMakeBuildExtension}, cmdclass={"build_ext": CMakeBuildExtension},
setup_requires=[ install_requires=install_requirements(),
"jax[cuda12]", tests_require=test_requirements(),
"flax>=0.7.1",
"nvidia-cuda-runtime-cu12",
"nvidia-cublas-cu12",
"nvidia-cudnn-cu12",
"nvidia-cuda-cccl-cu12",
"nvidia-cuda-nvcc-cu12",
"nvidia-nvtx-cu12",
"nvidia-cuda-nvrtc-cu12",
],
install_requires=["jax", "flax>=0.7.1"],
tests_require=["numpy"],
) )
if any(x in sys.argv for x in (".", "sdist", "bdist_wheel")): if any(x in sys.argv for x in (".", "sdist", "bdist_wheel")):
shutil.rmtree(common_headers_dir) shutil.rmtree(common_headers_dir)
......
...@@ -18,6 +18,7 @@ from jax.interpreters import pxla ...@@ -18,6 +18,7 @@ from jax.interpreters import pxla
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
from jax.sharding import PartitionSpec from jax.sharding import PartitionSpec
import numpy as np
_PXLA_THREAD_RESOURCES = pxla.thread_resources _PXLA_THREAD_RESOURCES = pxla.thread_resources
...@@ -201,6 +202,31 @@ def get_mesh_axis_rank(axis: str, mesh=None): ...@@ -201,6 +202,31 @@ def get_mesh_axis_rank(axis: str, mesh=None):
return jax.lax.axis_index(axis_name) return jax.lax.axis_index(axis_name)
def get_mesh_axis_rank_host(axis, mesh) -> int:
"""
Same as get_mesh_axis_rank(), but returns a host value instead of a
traced device value.
"""
if axis not in mesh.axis_names:
raise ValueError(f"Axis {axis} not found in mesh axis names: {mesh.axis_names}")
axis_index = mesh.axis_names.index(axis)
# mesh.devices is an ndarray of Device objects
devices = mesh.devices
local_device = jax.devices()[jax.process_index()] # Pick one device on this host
# Find index of local_device in mesh.devices
coords = np.argwhere(devices == local_device)
if coords.size == 0:
raise ValueError(f"Local device {local_device} not found in mesh.devices.")
coords = tuple(coords[0]) # Coordinates in the mesh array
# Get the mesh rank along the specified axis
rank = coords[axis_index]
return int(rank)
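A usage sketch, assuming a single-axis mesh over all visible devices:

import numpy as np
import jax
from jax.sharding import Mesh

devices = np.array(jax.devices())           # 1-D grid of Device objects
mesh = Mesh(devices, axis_names=("dp",))
rank = get_mesh_axis_rank_host("dp", mesh)  # plain Python int, usable on host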
@dataclass @dataclass
class MeshResource: class MeshResource:
"""A data container for managing mesh resources in distributed training. """A data container for managing mesh resources in distributed training.
......
...@@ -217,7 +217,12 @@ class UnfusedDotProductAttention(torch.nn.Module): ...@@ -217,7 +217,12 @@ class UnfusedDotProductAttention(torch.nn.Module):
if "padding" in attn_mask_type and attention_mask is None: if "padding" in attn_mask_type and attention_mask is None:
attention_mask = dpa_utils.get_padding_mask( attention_mask = dpa_utils.get_padding_mask(
batch_size, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv batch_size,
cu_seqlens_q,
cu_seqlens_kv,
max_seqlen_q,
max_seqlen_kv,
self.attention_type,
) )
attn_mask_type, attention_mask, actual_seqlens_q, actual_seqlens_kv = ( attn_mask_type, attention_mask, actual_seqlens_q, actual_seqlens_kv = (
dpa_utils.get_full_mask( dpa_utils.get_full_mask(
......
...@@ -461,6 +461,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -461,6 +461,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
): ):
# pylint: disable=missing-function-docstring # pylint: disable=missing-function-docstring
nvtx_range_push("transformer_engine.AttnFuncWithCPAndKVP2P.forward") nvtx_range_push("transformer_engine.AttnFuncWithCPAndKVP2P.forward")
enable_mla = k.shape[-1] != v.shape[-1]
if softmax_scale is None: if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5) softmax_scale = q.shape[-1] ** (-0.5)
...@@ -498,6 +499,9 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -498,6 +499,9 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
cu_seqlens_q_half, cu_seqlens_kv_half = None, None cu_seqlens_q_half, cu_seqlens_kv_half = None, None
if qkv_format in ["bshd", "sbhd"]: if qkv_format in ["bshd", "sbhd"]:
seq_dim = qkv_format.index("s") seq_dim = qkv_format.index("s")
if enable_mla:
qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format
else:
qkv_layout = qkv_format + "_" + qkv_format[:-2] + "2" + qkv_format[-2:] qkv_layout = qkv_format + "_" + qkv_format[:-2] + "2" + qkv_format[-2:]
cu_seqlens_q_padded, cu_seqlens_kv_padded = None, None cu_seqlens_q_padded, cu_seqlens_kv_padded = None, None
if use_fused_attention: if use_fused_attention:
...@@ -676,9 +680,16 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -676,9 +680,16 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
fwd_results_correction_done = torch.cuda.Event() fwd_results_correction_done = torch.cuda.Event()
p2p_comm_buffers = [None for _ in range(cp_size)] p2p_comm_buffers = [None for _ in range(cp_size)]
if qkv_format in ["bshd", "sbhd"]: if enable_mla:
# If MLA, the shapes of k and v do not match, so we flatten them
# and split them after receiving.
k_shape = k.shape
k_numel = k.numel()
v_shape = v.shape
p2p_comm_buffers[0] = torch.cat((k.view(-1), v.view(-1)), dim=-1)
elif qkv_format in ["bshd", "sbhd"]:
p2p_comm_buffers[0] = torch.cat((k.unsqueeze(-3), v.unsqueeze(-3)), dim=-3) p2p_comm_buffers[0] = torch.cat((k.unsqueeze(-3), v.unsqueeze(-3)), dim=-3)
else: else: # qkv_format == "thd"
p2p_comm_buffers[0] = torch.cat((k.unsqueeze(0), v.unsqueeze(0)), dim=0) p2p_comm_buffers[0] = torch.cat((k.unsqueeze(0), v.unsqueeze(0)), dim=0)
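A self-contained sketch of the flatten/split round-trip, with hypothetical MLA head dims:

import torch

b, s, h, hn_k, hn_v = 2, 8, 4, 192, 128          # MLA: K and V head dims differ
k = torch.randn(b, s, h, hn_k)
v = torch.randn(b, s, h, hn_v)
buf = torch.cat((k.view(-1), v.view(-1)))        # one flat buffer for P2P
k2 = buf[: k.numel()].view(k.shape)              # split back after receiving
v2 = buf[k.numel() :].view(v.shape)
assert torch.equal(k, k2) and torch.equal(v, v2)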
send_recv_reqs = [[], []] send_recv_reqs = [[], []]
...@@ -707,6 +718,10 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -707,6 +718,10 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
else: else:
# KV exchange is in BF16/FP16, cast received KV in each step # KV exchange is in BF16/FP16, cast received KV in each step
kv_inputs[i % 2] = QKV_quantizer(p2p_comm_buffers[i])._data kv_inputs[i % 2] = QKV_quantizer(p2p_comm_buffers[i])._data
if enable_mla:
# If MLA, k and v are flattened, so split them after receiving.
k_part = kv_inputs[i % 2][:k_numel].view(*k_shape)
v_part = kv_inputs[i % 2][k_numel:].view(*v_shape)
if causal: if causal:
if i == 0: if i == 0:
if pad_between_seqs: if pad_between_seqs:
...@@ -725,6 +740,11 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -725,6 +740,11 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
if qkv_format == "bshd": if qkv_format == "bshd":
# [b, 2, sq//2, np, hn] -> [b, sq, np, hn] # [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
q_inputs[i % 2] = q.view(q.shape[0], -1, *q.shape[-2:]) q_inputs[i % 2] = q.view(q.shape[0], -1, *q.shape[-2:])
if enable_mla:
# [b, 2, sk//2, np, hn] -> [b, sk, np, hn]
k_part = k_part.view(k_part.shape[0], -1, *k_part.shape[-2:])
v_part = v_part.view(v_part.shape[0], -1, *v_part.shape[-2:])
else:
# [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn] # [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn]
kv_inputs[i % 2] = kv_inputs[i % 2].view( kv_inputs[i % 2] = kv_inputs[i % 2].view(
k.shape[0], -1, 2, *k.shape[-2:] k.shape[0], -1, 2, *k.shape[-2:]
...@@ -732,6 +752,11 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -732,6 +752,11 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
elif qkv_format == "sbhd": elif qkv_format == "sbhd":
# [2, sq//2, b, np, hn] -> [sq, b, np, hn] # [2, sq//2, b, np, hn] -> [sq, b, np, hn]
q_inputs[i % 2] = q.view(-1, *q.shape[-3:]) q_inputs[i % 2] = q.view(-1, *q.shape[-3:])
if enable_mla:
# [2, sk//2, b, np, hn] -> [sk, b, np, hn]
k_part = k_part.view(-1, *k_part.shape[2:])
v_part = v_part.view(-1, *v_part.shape[2:])
else:
# [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn] # [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn]
kv_inputs[i % 2] = kv_inputs[i % 2].view( kv_inputs[i % 2] = kv_inputs[i % 2].view(
-1, k.shape[2], 2, *k.shape[-2:] -1, k.shape[2], 2, *k.shape[-2:]
...@@ -750,6 +775,9 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -750,6 +775,9 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
).contiguous() ).contiguous()
q_part = q_inputs[i % 2] q_part = q_inputs[i % 2]
if not enable_mla:
# If MHA/GQA, split the packed KV into k_part and v_part.
# Otherwise (MLA), k_part and v_part have already been split.
k_part = ( k_part = (
kv_inputs[i % 2][..., 0, :, :] kv_inputs[i % 2][..., 0, :, :]
if qkv_format in ["bshd", "sbhd"] if qkv_format in ["bshd", "sbhd"]
...@@ -810,6 +838,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -810,6 +838,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
max_seqlen_q=max_seqlen_q, max_seqlen_q=max_seqlen_q,
max_seqlen_kv=max_seqlen_kv, max_seqlen_kv=max_seqlen_kv,
) )
# Need to add MLA support once Flash Attention supports MLA
fa_outputs = flash_attn_fwd( fa_outputs = flash_attn_fwd(
q_inputs[i % 2], q_inputs[i % 2],
( (
...@@ -858,26 +887,50 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -858,26 +887,50 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
if qkv_format == "bshd": if qkv_format == "bshd":
# [b, 2, sq//2, np, hn] -> [b, sq, np, hn] # [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
q_inputs[i % 2] = q.view(q.shape[0], -1, *q.shape[-2:]) q_inputs[i % 2] = q.view(q.shape[0], -1, *q.shape[-2:])
if enable_mla:
# [b, 2, sk//2, np, hn] -> [b, sk//2, np, hn]
k_part = k_part[:, 0, ...]
v_part = v_part[:, 0, ...]
else:
# [b, 2, sk//2, 2, np, hn] -> [b, sk//2, 2, np, hn] # [b, 2, sk//2, 2, np, hn] -> [b, sk//2, 2, np, hn]
kv_inputs[i % 2] = kv_inputs[i % 2][:, 0, ...] kv_inputs[i % 2] = kv_inputs[i % 2][:, 0, ...]
elif qkv_format == "sbhd": elif qkv_format == "sbhd":
# [2, sq//2, b, np, hn] -> [sq, b, np, hn] # [2, sq//2, b, np, hn] -> [sq, b, np, hn]
q_inputs[i % 2] = q.view(-1, *q.shape[-3:]) q_inputs[i % 2] = q.view(-1, *q.shape[-3:])
if enable_mla:
# [2, sk//2, b, np, hn] -> [sk//2, b, np, hn]
k_part = k_part[0]
v_part = v_part[0]
else:
# [2, sk//2, b, 2, np, hn] -> [sk//2, b, 2, np, hn] # [2, sk//2, b, 2, np, hn] -> [sk//2, b, 2, np, hn]
kv_inputs[i % 2] = kv_inputs[i % 2][0] kv_inputs[i % 2] = kv_inputs[i % 2][0]
elif qkv_format == "thd": elif qkv_format == "thd":
q_inputs[i % 2] = q q_inputs[i % 2] = q
if enable_mla:
# [t, np, hn] -> [t/2, np, hn]
k_part = tex.thd_read_half_tensor(
k_part, cu_seqlens_kv_padded, 0
)
v_part = tex.thd_read_half_tensor(
v_part, cu_seqlens_kv_padded, 0
)
else:
# [2, t, np, hn] -> [2, t/2, np, hn] # [2, t, np, hn] -> [2, t/2, np, hn]
kv_inputs[i % 2] = tex.thd_read_half_tensor( kv_inputs[i % 2] = tex.thd_read_half_tensor(
kv_inputs[i % 2], cu_seqlens_kv_padded, 0 kv_inputs[i % 2], cu_seqlens_kv_padded, 0
) )
if use_fused_attention: if use_fused_attention:
if enable_mla:
k_part = k_part.contiguous()
v_part = v_part.contiguous()
else:
kv_inputs[i % 2] = kv_inputs[i % 2].contiguous() kv_inputs[i % 2] = kv_inputs[i % 2].contiguous()
if attn_bias is not None: if attn_bias is not None:
idx = (rank - i) % cp_size idx = (rank - i) % cp_size
attn_bias_inputs[i % 2] = attn_bias[..., idx, :].contiguous() attn_bias_inputs[i % 2] = attn_bias[..., idx, :].contiguous()
q_part = q_inputs[i % 2] q_part = q_inputs[i % 2]
if not enable_mla:
k_part = ( k_part = (
kv_inputs[i % 2][..., 0, :, :] kv_inputs[i % 2][..., 0, :, :]
if qkv_format in ["bshd", "sbhd"] if qkv_format in ["bshd", "sbhd"]
...@@ -948,6 +1001,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -948,6 +1001,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
elif fa_utils.v2_7_0_plus: elif fa_utils.v2_7_0_plus:
fa_forward_kwargs["window_size_left"] = -1 fa_forward_kwargs["window_size_left"] = -1
fa_forward_kwargs["window_size_right"] = -1 fa_forward_kwargs["window_size_right"] = -1
# Need to add MLA support once Flash Attention supports MLA
fa_outputs = flash_attn_fwd( fa_outputs = flash_attn_fwd(
q_inputs[i % 2], q_inputs[i % 2],
( (
...@@ -996,6 +1050,11 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -996,6 +1050,11 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
if qkv_format == "bshd": if qkv_format == "bshd":
# [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn]
q_inputs[i % 2] = q[:, 1, ...] q_inputs[i % 2] = q[:, 1, ...]
if enable_mla:
# [b, 2, sk//2, np, hn] -> [b, sk, np, hn]
k_part = k_part.view(k_part.shape[0], -1, *k_part.shape[-2:])
v_part = v_part.view(v_part.shape[0], -1, *v_part.shape[-2:])
else:
# [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn] # [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn]
kv_inputs[i % 2] = kv_inputs[i % 2].view( kv_inputs[i % 2] = kv_inputs[i % 2].view(
k.shape[0], -1, 2, *k.shape[-2:] k.shape[0], -1, 2, *k.shape[-2:]
...@@ -1003,6 +1062,11 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -1003,6 +1062,11 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
elif qkv_format == "sbhd": elif qkv_format == "sbhd":
# [2, sq//2, b, np, hn] -> [sq//2, b, np, hn] # [2, sq//2, b, np, hn] -> [sq//2, b, np, hn]
q_inputs[i % 2] = q[1] q_inputs[i % 2] = q[1]
if enable_mla:
# [2, sk//2, b, np, hn] -> [sk, b, np, hn]
k_part = k_part.view(-1, *k_part.shape[2:])
v_part = v_part.view(-1, *v_part.shape[2:])
else:
# [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn] # [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn]
kv_inputs[i % 2] = kv_inputs[i % 2].view( kv_inputs[i % 2] = kv_inputs[i % 2].view(
-1, k.shape[2], 2, *k.shape[-2:] -1, k.shape[2], 2, *k.shape[-2:]
...@@ -1025,6 +1089,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -1025,6 +1089,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
).contiguous() ).contiguous()
q_part = q_inputs[i % 2] q_part = q_inputs[i % 2]
if not enable_mla:
k_part = ( k_part = (
kv_inputs[i % 2][..., 0, :, :] kv_inputs[i % 2][..., 0, :, :]
if qkv_format in ["bshd", "sbhd"] if qkv_format in ["bshd", "sbhd"]
...@@ -1095,6 +1160,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -1095,6 +1160,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
elif fa_utils.v2_7_0_plus: elif fa_utils.v2_7_0_plus:
fa_forward_kwargs["window_size_left"] = -1 fa_forward_kwargs["window_size_left"] = -1
fa_forward_kwargs["window_size_right"] = -1 fa_forward_kwargs["window_size_right"] = -1
# Need to add MLA support once Flash Attention supports MLA
fa_outputs = flash_attn_fwd( fa_outputs = flash_attn_fwd(
q_inputs[i % 2], q_inputs[i % 2],
( (
...@@ -1152,6 +1218,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -1152,6 +1218,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
).contiguous() ).contiguous()
q_part = q q_part = q
if not enable_mla:
k_part = ( k_part = (
kv_inputs[i % 2][..., 0, :, :] kv_inputs[i % 2][..., 0, :, :]
if qkv_format in ["bshd", "sbhd"] if qkv_format in ["bshd", "sbhd"]
...@@ -1211,6 +1278,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -1211,6 +1278,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
max_seqlen_q=max_seqlen_q, max_seqlen_q=max_seqlen_q,
max_seqlen_kv=max_seqlen_kv, max_seqlen_kv=max_seqlen_kv,
) )
# Need to add MLA support once Flash Attention supports MLA
fa_outputs = flash_attn_fwd( fa_outputs = flash_attn_fwd(
q, q,
( (
...@@ -1257,7 +1325,15 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -1257,7 +1325,15 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
if i == 1: if i == 1:
softmax_lse = torch.clone(softmax_lse_per_step[0]) softmax_lse = torch.clone(softmax_lse_per_step[0])
if qkv_format == "thd": if qkv_format == "thd":
out = torch.zeros_like(q if not fp8 else out_per_step[0]).view(q.shape) if enable_mla:
out = torch.zeros_like(v if not fp8 else out_per_step[0]).view(
v_shape
)
else:
# MHA or GQA
out = torch.zeros_like(q if not fp8 else out_per_step[0]).view(
q.shape
)
elif (i - 1) <= rank or not causal: elif (i - 1) <= rank or not causal:
flash_attn_fwd_softmax_lse_correction( flash_attn_fwd_softmax_lse_correction(
softmax_lse, softmax_lse_per_step[i - 1] softmax_lse, softmax_lse_per_step[i - 1]
...@@ -1295,6 +1371,9 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -1295,6 +1371,9 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
softmax_lse_per_step[0], softmax_lse_per_step[0],
seq_dim, seq_dim,
) )
if enable_mla:
out = out.view(v_shape)
else:
out = out.view(q.shape) out = out.view(q.shape)
else: else:
flash_attn_fwd_out_correction( flash_attn_fwd_out_correction(
...@@ -1417,6 +1496,12 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -1417,6 +1496,12 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
ctx.is_output_fp8 = is_output_fp8 ctx.is_output_fp8 = is_output_fp8
ctx.use_flash_attn_3 = use_flash_attn_3 ctx.use_flash_attn_3 = use_flash_attn_3
ctx.enable_mla = enable_mla
if enable_mla:
ctx.k_numel = k_numel
ctx.k_shape = k_shape
ctx.v_shape = v_shape
ctx.qkv_dtype = qkv_dtype ctx.qkv_dtype = qkv_dtype
ctx.dQKV_quantizer = dQKV_quantizer ctx.dQKV_quantizer = dQKV_quantizer
ctx.dQKV_CP_quantizer = dQKV_CP_quantizer ctx.dQKV_CP_quantizer = dQKV_CP_quantizer
...@@ -1466,6 +1551,9 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -1466,6 +1551,9 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
seq_dim = None seq_dim = None
if ctx.qkv_format in ["bshd", "sbhd"]: if ctx.qkv_format in ["bshd", "sbhd"]:
seq_dim = ctx.qkv_format.index("s") seq_dim = ctx.qkv_format.index("s")
if ctx.enable_mla:
qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format + "_" + ctx.qkv_format
else:
qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format[:-2] + "2" + ctx.qkv_format[-2:] qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format[:-2] + "2" + ctx.qkv_format[-2:]
else: else:
qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format + "_" + ctx.qkv_format qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format + "_" + ctx.qkv_format
...@@ -1595,6 +1683,11 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -1595,6 +1683,11 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
) )
dout = dout.dequantize(dtype=dout_dtype) dout = dout.dequantize(dtype=dout_dtype)
if ctx.enable_mla:
out = out.view(*ctx.v_shape)
dout = dout.view(*ctx.v_shape)
else:
# MHA or GQA
out = out.view(*q.shape) out = out.view(*q.shape)
dout = dout.view(*q.shape) dout = dout.view(*q.shape)
send_recv_reqs = [] send_recv_reqs = []
...@@ -1672,6 +1765,9 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -1672,6 +1765,9 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
kv = p2p_comm_buffers[i % 2][0] kv = p2p_comm_buffers[i % 2][0]
q_, kv_, out_, dout_ = None, None, None, None q_, kv_, out_, dout_ = None, None, None, None
dq_, dk_, dv_ = None, None, None dq_, dk_, dv_ = None, None, None
if ctx.enable_mla:
k_part = kv[: ctx.k_numel].view(*ctx.k_shape)
v_part = kv[ctx.k_numel :].view(*ctx.v_shape)
# In reversed order of fwd # In reversed order of fwd
if causal: if causal:
if i == (cp_size - 1): if i == (cp_size - 1):
...@@ -1680,11 +1776,21 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -1680,11 +1776,21 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
q_, out_, dout_ = [ q_, out_, dout_ = [
x.view(x.shape[0], -1, *x.shape[-2:]) for x in [q, out, dout] x.view(x.shape[0], -1, *x.shape[-2:]) for x in [q, out, dout]
] ]
if ctx.enable_mla:
# [b, 2, sk//2, np, hn] -> [b, sk, np, hn]
k_part = k_part.view(k_part.shape[0], -1, *k_part.shape[-2:])
v_part = v_part.view(v_part.shape[0], -1, *v_part.shape[-2:])
else:
# [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn] # [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn]
kv_ = kv.view(kv.shape[0], -1, *kv.shape[-3:]) kv_ = kv.view(kv.shape[0], -1, *kv.shape[-3:])
elif ctx.qkv_format == "sbhd": elif ctx.qkv_format == "sbhd":
# [2, sq//2, b, np, hn] -> [sq, b, np, hn] # [2, sq//2, b, np, hn] -> [sq, b, np, hn]
q_, out_, dout_ = [x.view(-1, *x.shape[-3:]) for x in [q, out, dout]] q_, out_, dout_ = [x.view(-1, *x.shape[-3:]) for x in [q, out, dout]]
if ctx.enable_mla:
# [2, sk//2, b, np, hn] -> [sk, b, np, hn]
k_part = k_part.view(-1, *k_part.shape[-3:])
v_part = v_part.view(-1, *v_part.shape[-3:])
else:
# [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn] # [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn]
kv_ = kv.view(-1, *kv.shape[-4:]) kv_ = kv.view(-1, *kv.shape[-4:])
elif ctx.qkv_format == "thd": elif ctx.qkv_format == "thd":
...@@ -1701,8 +1807,13 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -1701,8 +1807,13 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
if attn_dbias is not None: if attn_dbias is not None:
aux_ctx_tensors += [attn_biases[cp_size - i - 1]] aux_ctx_tensors += [attn_biases[cp_size - i - 1]]
q_part = q_ q_part = q_
k_part = kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0] if not ctx.enable_mla:
v_part = kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1] k_part = (
kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0]
)
v_part = (
kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1]
)
out_part = out_ out_part = out_
dout_part = dout_ dout_part = dout_
...@@ -1784,6 +1895,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -1784,6 +1895,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
fa_backward_kwargs["window_size_right"] = 0 fa_backward_kwargs["window_size_right"] = 0
if not ctx.use_flash_attn_3: if not ctx.use_flash_attn_3:
fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1] fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1]
# Need to add MLA support once Flash Attention supports MLA
flash_attn_bwd( flash_attn_bwd(
dout_, dout_,
q_, q_,
...@@ -1801,18 +1913,37 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -1801,18 +1913,37 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
q_, out_, dout_ = [ q_, out_, dout_ = [
x.view(x.shape[0], -1, *x.shape[-2:]) for x in [q, out, dout] x.view(x.shape[0], -1, *x.shape[-2:]) for x in [q, out, dout]
] ]
if ctx.enable_mla:
# [b, 2, sk//2, np, hn] -> [b, sk//2, np, hn]
k_part = k_part[:, 0]
v_part = v_part[:, 0]
else:
# [b, 2, sk//2, 2, np, hn] -> [b, sk//2, 2, np, hn] # [b, 2, sk//2, 2, np, hn] -> [b, sk//2, 2, np, hn]
kv_ = kv[:, 0] kv_ = kv[:, 0]
elif ctx.qkv_format == "sbhd": elif ctx.qkv_format == "sbhd":
# [2, sq//2, b, np, hn] -> [sq, b, np, hn] # [2, sq//2, b, np, hn] -> [sq, b, np, hn]
q_, out_, dout_ = [x.view(-1, *x.shape[-3:]) for x in [q, out, dout]] q_, out_, dout_ = [x.view(-1, *x.shape[-3:]) for x in [q, out, dout]]
if ctx.enable_mla:
# [2, sk//2, b, np, hn] -> [sk//2, b, np, hn]
k_part = k_part[0]
v_part = v_part[0]
else:
# [2, sk//2, b, 2, np, hn] -> [sk//2, b, 2, np, hn] # [2, sk//2, b, 2, np, hn] -> [sk//2, b, 2, np, hn]
kv_ = kv[0] kv_ = kv[0]
elif ctx.qkv_format == "thd": elif ctx.qkv_format == "thd":
q_, out_, dout_ = q, out, dout q_, out_, dout_ = q, out, dout
if ctx.enable_mla:
# [t, np, hn] -> [t/2, np, hn]
k_part = tex.thd_read_half_tensor(k_part, cu_seqlens_kv_padded, 0)
v_part = tex.thd_read_half_tensor(v_part, cu_seqlens_kv_padded, 0)
else:
# [2, t, np, hn] -> [2, t/2, np, hn] # [2, t, np, hn] -> [2, t/2, np, hn]
kv_ = tex.thd_read_half_tensor(kv, cu_seqlens_kv_padded, 0) kv_ = tex.thd_read_half_tensor(kv, cu_seqlens_kv_padded, 0)
if ctx.use_fused_attention: if ctx.use_fused_attention:
if ctx.enable_mla:
k_part = k_part.contiguous()
v_part = v_part.contiguous()
else:
kv_ = kv_.contiguous() kv_ = kv_.contiguous()
if ctx.fp8: if ctx.fp8:
aux_ctx_tensors = [ aux_ctx_tensors = [
...@@ -1825,8 +1956,13 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -1825,8 +1956,13 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
if attn_dbias is not None: if attn_dbias is not None:
aux_ctx_tensors += [attn_biases[cp_size - i - 1]] aux_ctx_tensors += [attn_biases[cp_size - i - 1]]
q_part = q_ q_part = q_
k_part = kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0] if not ctx.enable_mla:
v_part = kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1] k_part = (
kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0]
)
v_part = (
kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1]
)
out_part = out_ out_part = out_
dout_part = dout_ dout_part = dout_
...@@ -1910,6 +2046,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -1910,6 +2046,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
fa_backward_kwargs["window_size_right"] = -1 fa_backward_kwargs["window_size_right"] = -1
if not ctx.use_flash_attn_3: if not ctx.use_flash_attn_3:
fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1] fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1]
# Need to add MLA support once Flash Attention supports MLA
flash_attn_bwd( flash_attn_bwd(
dout_, dout_,
q_, q_,
...@@ -1925,11 +2062,21 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -1925,11 +2062,21 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
if ctx.qkv_format == "bshd": if ctx.qkv_format == "bshd":
# [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn]
q_, out_, dout_ = q[:, 1], out[:, 1], dout[:, 1] q_, out_, dout_ = q[:, 1], out[:, 1], dout[:, 1]
if ctx.enable_mla:
# [b, 2, sk//2, np, hn] -> [b, sk, np, hn]
k_part = k_part.view(k_part.shape[0], -1, *k_part.shape[-2:])
v_part = v_part.view(v_part.shape[0], -1, *v_part.shape[-2:])
else:
# [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn] # [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn]
kv_ = kv.view(kv.shape[0], -1, *kv.shape[-3:]) kv_ = kv.view(kv.shape[0], -1, *kv.shape[-3:])
elif ctx.qkv_format == "sbhd": elif ctx.qkv_format == "sbhd":
# [2, sq//2, b, np, hn] -> [sq//2, b, np, hn] # [2, sq//2, b, np, hn] -> [sq//2, b, np, hn]
q_, out_, dout_ = q[1], out[1], dout[1] q_, out_, dout_ = q[1], out[1], dout[1]
if ctx.enable_mla:
# [2, sk//2, b, np, hn] -> [sk, b, np, hn]
k_part = k_part.view(-1, *k_part.shape[-3:])
v_part = v_part.view(-1, *v_part.shape[-3:])
else:
# [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn] # [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn]
kv_ = kv.view(-1, *kv.shape[-4:]) kv_ = kv.view(-1, *kv.shape[-4:])
elif ctx.qkv_format == "thd": elif ctx.qkv_format == "thd":
...@@ -1953,8 +2100,13 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -1953,8 +2100,13 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
aux_ctx_tensors += [attn_biases[cp_size - i - 1]] aux_ctx_tensors += [attn_biases[cp_size - i - 1]]
q_part = q_ q_part = q_
k_part = kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0] if not ctx.enable_mla:
v_part = kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1] k_part = (
kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0]
)
v_part = (
kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1]
)
out_part = out_ out_part = out_
dout_part = dout_ dout_part = dout_
...@@ -2038,6 +2190,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -2038,6 +2190,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
fa_backward_kwargs["window_size_right"] = -1 fa_backward_kwargs["window_size_right"] = -1
if not ctx.use_flash_attn_3: if not ctx.use_flash_attn_3:
fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1] fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1]
# Need to add MLA support once Flash Attention supports MLA
flash_attn_bwd( flash_attn_bwd(
dout_, dout_,
q_, q_,
...@@ -2058,6 +2211,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -2058,6 +2211,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
if attn_dbias is not None: if attn_dbias is not None:
aux_ctx_tensors += [attn_biases[cp_size - i - 1]] aux_ctx_tensors += [attn_biases[cp_size - i - 1]]
q_part = q q_part = q
if not ctx.enable_mla:
k_part = kv[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv[0] k_part = kv[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv[0]
v_part = kv[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv[1] v_part = kv[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv[1]
out_part = out out_part = out
...@@ -2133,6 +2287,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -2133,6 +2287,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
fa_backward_kwargs["window_size_right"] = -1 fa_backward_kwargs["window_size_right"] = -1
if not ctx.use_flash_attn_3: if not ctx.use_flash_attn_3:
fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1] fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1]
# Need to add MLA support once Flash Attention supports MLA
flash_attn_bwd( flash_attn_bwd(
dout, dout,
q, q,
...@@ -2225,15 +2380,18 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -2225,15 +2380,18 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
else: else:
dkv = p2p_comm_buffers[(i + 1) % 2][1] dkv = p2p_comm_buffers[(i + 1) % 2][1]
if ctx.use_fused_attention: if ctx.use_fused_attention:
if ctx.qkv_format in ["bshd", "sbhd"]: if ctx.enable_mla:
dkv_ = None
elif ctx.qkv_format in ["bshd", "sbhd"]:
dkv_ = combine_tensors([dk_, dv_], -2) dkv_ = combine_tensors([dk_, dv_], -2)
elif ctx.qkv_format == "thd": elif ctx.qkv_format == "thd":
dkv_ = torch.cat( dkv_ = torch.cat(
(dk_.unsqueeze(0), dv_.unsqueeze(0)), dim=0 (dk_.unsqueeze(0), dv_.unsqueeze(0)), dim=0
) # pylint: disable=used-before-assignment ) # pylint: disable=used-before-assignment
if ctx.qkv_format in ["bshd", "sbhd"]: if not ctx.enable_mla and ctx.qkv_format in ["bshd", "sbhd"]:
# [b, 2, sk//2, 2, np, hn] -> [2, b, 2, sk//2, np, hn] or # [b, 2, sk//2, 2, np, hn] -> [2, b, 2, sk//2, np, hn] or
# [2, sk//2, b, 2, np, hn] -> [2, 2, sk//2, b, np, hn] # [2, sk//2, b, 2, np, hn] -> [2, 2, sk//2, b, np, hn]
# dkv is a buffer, so we do not need to transpose it, but only need to reshape it.
dkv = dkv.view(2, *dkv.shape[0:-3], *dkv.shape[-2:]) dkv = dkv.view(2, *dkv.shape[0:-3], *dkv.shape[-2:])
dkv_ = dkv_.movedim(-3, 0) dkv_ = dkv_.movedim(-3, 0)
if causal and (i < (cp_size - rank - 1) or i == (cp_size - 1)): if causal and (i < (cp_size - rank - 1) or i == (cp_size - 1)):
...@@ -2241,7 +2399,101 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -2241,7 +2399,101 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
# [2, sk, b, np, hn] -> [2, 2, sk//2, b, np, hn] # [2, sk, b, np, hn] -> [2, 2, sk//2, b, np, hn]
dkv_ = dkv_.view(*dkv.shape) dkv_ = dkv_.view(*dkv.shape)
if ctx.enable_mla:
# [b, 2, sk//2, np, hn] or
# [2, sk//2, b, np, hn]
dk = dkv[: ctx.k_numel].view(*ctx.k_shape)
dv = dkv[ctx.k_numel :].view(*ctx.v_shape)
if causal and (i < (cp_size - rank - 1) or i == (cp_size - 1)):
dk_ = dk_.view(*ctx.k_shape)
dv_ = dv_.view(*ctx.v_shape)
if ctx.fp8:
# enable_mla and fp8
if causal and i >= (cp_size - rank - 1) and i != (cp_size - 1):
if ctx.qkv_format == "bshd":
dk[:, 0, ...].copy_(dk_)
dk[:, 1, ...].fill_(0)
dv[:, 0, ...].copy_(dv_)
dv[:, 1, ...].fill_(0)
elif ctx.qkv_format == "sbhd":
dk[0].copy_(dk_)
dk[1].fill_(0)
dv[0].copy_(dv_)
dv[1].fill_(0)
else:
dk.copy_(dk_)
dv.copy_(dv_)
elif causal:
# enable_mla and not fp8 and causal
if i == (cp_size - 1):
if rank == 0:
if ctx.qkv_format == "bshd":
dk[:, 0, ...].add_(dk_[:, 0, ...])
dk[:, 1, ...].copy_(dk_[:, 1, ...])
dv[:, 0, ...].add_(dv_[:, 0, ...])
dv[:, 1, ...].copy_(dv_[:, 1, ...])
elif ctx.qkv_format == "sbhd":
dk[0, ...].add_(dk_[0, ...])
dk[1, ...].copy_(dk_[1, ...])
dv[0, ...].add_(dv_[0, ...])
dv[1, ...].copy_(dv_[1, ...])
elif ctx.qkv_format == "thd":
tex.thd_grad_correction(
dk, dk_, cu_seqlens_kv_padded, "add", "copy"
)
tex.thd_grad_correction(
dv, dv_, cu_seqlens_kv_padded, "add", "copy"
)
else:
dk.add_(dk_)
dv.add_(dv_)
elif i >= (cp_size - rank - 1):
if i == 0 and rank == (cp_size - 1):
if ctx.qkv_format == "bshd":
dk[:, 0, ...].copy_(dk_)
dv[:, 0, ...].copy_(dv_)
elif ctx.qkv_format == "sbhd":
dk[0, ...].copy_(dk_)
dv[0, ...].copy_(dv_)
elif ctx.qkv_format == "thd":
tex.thd_grad_correction(
dk, dk_, cu_seqlens_kv_padded, "copy", "none"
)
tex.thd_grad_correction(
dv, dv_, cu_seqlens_kv_padded, "copy", "none"
)
else:
if ctx.qkv_format == "bshd":
dk[:, 0, ...].add_(dk_)
dv[:, 0, ...].add_(dv_)
elif ctx.qkv_format == "sbhd":
dk[0, ...].add_(dk_)
dv[0, ...].add_(dv_)
elif ctx.qkv_format == "thd":
tex.thd_grad_correction(
dk, dk_, cu_seqlens_kv_padded, "add", "none"
)
tex.thd_grad_correction(
dv, dv_, cu_seqlens_kv_padded, "add", "none"
)
elif i > 0:
dk.add_(dk_)
dv.add_(dv_)
else: # i == 0
dk.copy_(dk_)
dv.copy_(dv_)
else:
# enable_mla and not fp8 and not causal
if i == 0:
dk.copy_(dk_)
dv.copy_(dv_)
else: # i > 0
dk.add_(dk_)
dv.add_(dv_)
else:
if ctx.fp8: if ctx.fp8:
# fp8
if causal and i >= (cp_size - rank - 1) and i != (cp_size - 1): if causal and i >= (cp_size - rank - 1) and i != (cp_size - 1):
if ctx.qkv_format == "bshd": if ctx.qkv_format == "bshd":
dkv[:, :, 0, ...].copy_(dkv_) dkv[:, :, 0, ...].copy_(dkv_)
...@@ -2252,6 +2504,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -2252,6 +2504,7 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
else: else:
dkv.copy_(dkv_) dkv.copy_(dkv_)
elif causal: elif causal:
# not fp8 and causal
if i == (cp_size - 1): if i == (cp_size - 1):
if rank == 0: if rank == 0:
if ctx.qkv_format == "bshd": if ctx.qkv_format == "bshd":
...@@ -2261,7 +2514,9 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -2261,7 +2514,9 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
dkv[:, 0, ...].add_(dkv_[:, 0, ...]) dkv[:, 0, ...].add_(dkv_[:, 0, ...])
dkv[:, 1, ...].copy_(dkv_[:, 1, ...]) dkv[:, 1, ...].copy_(dkv_[:, 1, ...])
elif ctx.qkv_format == "thd": elif ctx.qkv_format == "thd":
tex.thd_grad_correction(dkv, dkv_, cu_seqlens_kv_padded, "add", "copy") tex.thd_grad_correction(
dkv, dkv_, cu_seqlens_kv_padded, "add", "copy"
)
else: else:
dkv.add_(dkv_) dkv.add_(dkv_)
elif i >= (cp_size - rank - 1): elif i >= (cp_size - rank - 1):
...@@ -2271,35 +2526,54 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -2271,35 +2526,54 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
elif ctx.qkv_format == "sbhd": elif ctx.qkv_format == "sbhd":
dkv[:, 0, ...].copy_(dkv_) dkv[:, 0, ...].copy_(dkv_)
elif ctx.qkv_format == "thd": elif ctx.qkv_format == "thd":
tex.thd_grad_correction(dkv, dkv_, cu_seqlens_kv_padded, "copy", "none") tex.thd_grad_correction(
dkv, dkv_, cu_seqlens_kv_padded, "copy", "none"
)
else: else:
if ctx.qkv_format == "bshd": if ctx.qkv_format == "bshd":
dkv[:, :, 0, ...].add_(dkv_) dkv[:, :, 0, ...].add_(dkv_)
elif ctx.qkv_format == "sbhd": elif ctx.qkv_format == "sbhd":
dkv[:, 0, ...].add_(dkv_) dkv[:, 0, ...].add_(dkv_)
elif ctx.qkv_format == "thd": elif ctx.qkv_format == "thd":
tex.thd_grad_correction(dkv, dkv_, cu_seqlens_kv_padded, "add", "none") tex.thd_grad_correction(
dkv, dkv_, cu_seqlens_kv_padded, "add", "none"
)
elif i > 0: elif i > 0:
dkv.add_(dkv_) dkv.add_(dkv_)
else: else: # i == 0
dkv.copy_(dkv_) dkv.copy_(dkv_)
else: else:
# not fp8 and not causal
if i == 0: if i == 0:
dkv.copy_(dkv_) dkv.copy_(dkv_)
else: else: # i > 0
dkv.add_(dkv_) dkv.add_(dkv_)
if ctx.fp8 and ctx.use_fused_attention: if ctx.fp8 and ctx.use_fused_attention:
amax_cp_bwd = amax_per_step.amax(dim=1) amax_cp_bwd = amax_per_step.amax(dim=1)
ctx.dP_quantizer.amax.copy_(amax_cp_bwd[0]) ctx.dP_quantizer.amax.copy_(amax_cp_bwd[0])
ctx.dQKV_CP_quantizer.amax.copy_(amax_cp_bwd[1]) ctx.dQKV_CP_quantizer.amax.copy_(amax_cp_bwd[1])
dq = ctx.dQKV_CP_quantizer.create_tensor_from_data(
dq_fp8, fake_dtype=torch.float32, internal=True
)
if ctx.enable_mla:
# [cp, b, 2, sk//2, np, hn] or [cp, 2, sk//2, b, np, hn]
dk_fp8 = dkv_fp8[: ctx.k_numel].view(cp_size, *ctx.k_shape)
dv_fp8 = dkv_fp8[ctx.k_numel :].view(cp_size, *ctx.v_shape)
dk = ctx.dQKV_CP_quantizer.create_tensor_from_data(
dk_fp8, fake_dtype=torch.float32, internal=True
)
dv = ctx.dQKV_CP_quantizer.create_tensor_from_data(
dv_fp8, fake_dtype=torch.float32, internal=True
)
dq, dk, dv = [x.dequantize(dtype=torch.float32) for x in [dq, dk, dv]]
dq, dk, dv = [x.sum(dim=0).to(dout_dtype) for x in [dq, dk, dv]]
else:
if ctx.qkv_format in ["bshd", "sbhd"]: if ctx.qkv_format in ["bshd", "sbhd"]:
# [cp, b, 2, sk//2, 2, np, hn] -> [cp, 2, b, 2, sk//2, np, hn] or # [cp, b, 2, sk//2, 2, np, hn] -> [cp, 2, b, 2, sk//2, np, hn] or
# [cp, 2, sk//2, b, 2, np, hn] -> [cp, 2, 2, sk//2, b, np, hn] # [cp, 2, sk//2, b, 2, np, hn] -> [cp, 2, 2, sk//2, b, np, hn]
dkv_fp8 = dkv_fp8.view(cp_size, 2, *dkv_fp8.shape[1:-3], *dkv_fp8.shape[-2:]) dkv_fp8 = dkv_fp8.view(cp_size, 2, *dkv_fp8.shape[1:-3], *dkv_fp8.shape[-2:])
dq = ctx.dQKV_CP_quantizer.create_tensor_from_data(
dq_fp8, fake_dtype=torch.float32, internal=True
)
dkv = ctx.dQKV_CP_quantizer.create_tensor_from_data( dkv = ctx.dQKV_CP_quantizer.create_tensor_from_data(
dkv_fp8, fake_dtype=torch.float32, internal=True dkv_fp8, fake_dtype=torch.float32, internal=True
) )
...@@ -2310,21 +2584,39 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function): ...@@ -2310,21 +2584,39 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
if ctx.qkv_format == "bshd": if ctx.qkv_format == "bshd":
# [b, 2, sq//2, np, hn] -> [b, sq, np, hn] # [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
dq = dq.view(dq.shape[0], -1, *dq.shape[-2:]) dq = dq.view(dq.shape[0], -1, *dq.shape[-2:])
if ctx.enable_mla:
# [b, 2, sk//2, np, hn] -> [b, sk, np, hn]
dk = dk.view(dk.shape[0], -1, *dk.shape[-2:])
dv = dv.view(dv.shape[0], -1, *dv.shape[-2:])
else:
# [2, b, 2, sk//2, np, hn] -> [2, b, sk, np, hn] # [2, b, 2, sk//2, np, hn] -> [2, b, sk, np, hn]
dkv = dkv.view(*dkv.shape[0:2], -1, *dkv.shape[-2:]) dkv = dkv.view(*dkv.shape[0:2], -1, *dkv.shape[-2:])
elif ctx.qkv_format == "sbhd": elif ctx.qkv_format == "sbhd":
# [2, sq//2, b, np, hn] -> [sq, b, np, hn] # [2, sq//2, b, np, hn] -> [sq, b, np, hn]
dq = dq.view(-1, *dq.shape[-3:]) dq = dq.view(-1, *dq.shape[-3:])
if ctx.enable_mla:
# [2, sk//2, b, np, hn] -> [sk, b, np, hn]
dk = dk.view(-1, *dk.shape[-3:])
dv = dv.view(-1, *dv.shape[-3:])
else:
# [2, 2, sk//2, b, np, hn] -> [2, sk, b, np, hn] # [2, 2, sk//2, b, np, hn] -> [2, sk, b, np, hn]
dkv = dkv.view(dkv.shape[0], -1, *dkv.shape[-3:]) dkv = dkv.view(dkv.shape[0], -1, *dkv.shape[-3:])
if ctx.qkv_format == "thd" and not ctx.use_fused_attention: if ctx.qkv_format == "thd" and not ctx.use_fused_attention:
dq[cu_seqlens_q_padded[-1] :].fill_(0) dq[cu_seqlens_q_padded[-1] :].fill_(0)
if ctx.enable_mla:
dk[cu_seqlens_kv_padded[-1] :].fill_(0)
dv[cu_seqlens_kv_padded[-1] :].fill_(0)
else:
dkv[:, cu_seqlens_kv_padded[-1] :].fill_(0) dkv[:, cu_seqlens_kv_padded[-1] :].fill_(0)
if ctx.fp8 and ctx.is_input_fp8: if ctx.fp8 and ctx.is_input_fp8:
assert torch.uint8 not in [dq.dtype, dkv.dtype] assert torch.uint8 not in [dq.dtype, dkv.dtype]
if ctx.enable_mla:
dq, dk, dv = [ctx.dQKV_quantizer(x)._data for x in [dq, dk, dv]]
else:
dq, dkv = [ctx.dQKV_quantizer(x)._data for x in [dq, dkv]] dq, dkv = [ctx.dQKV_quantizer(x)._data for x in [dq, dkv]]
if not ctx.enable_mla:
dk, dv = dkv[0], dkv[1] dk, dv = dkv[0], dkv[1]
if cp_size_a2a > 1: if cp_size_a2a > 1:
...@@ -3484,7 +3776,64 @@ def attn_forward_func_with_cp( ...@@ -3484,7 +3776,64 @@ def attn_forward_func_with_cp(
use_flash_attn_3=False, use_flash_attn_3=False,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
Attention implementation with context parallelism. Attention implementation with context parallelism (CP). CP partitions tensors along the sequence
dimension, and by reducing the memory and computational pressure on each GPU it enables long-context
LLMs in a distributed fashion. Transformer Engine's PyTorch CP implementation currently utilizes
the DualChunkSwap strategy to ensure load balancing across CP ranks. It is applied to all `attn_mask_type`s
and all `qkv_format`s, and it requires sequence lengths to be, or be padded to be, divisible by
(cp_size * 2). It also requires tokens to be re-ordered before entering this function.
For qkv_format = {'bshd', 'sbhd'}, the token re-ordering is illustrated below, for an example
use case of s = 12, attn_mask_type = 'causal', and cp_size = 2. seq_pos indicates each token's position
in its corresponding sequence.
GPU0 | GPU1 GPU0 | GPU1
seq_pos | 0 1 2 3 4 5 | 6 7 8 9 10 11 seq_pos | 0 1 2 9 10 11 | 3 4 5 6 7 8
---------------------------|----------------- ---------------------------|------------------
0 | 1, 0, 0, 0, 0, 0,| 0, 0, 0, 0, 0, 0 0 | 1, 0, 0, 0, 0, 0,| 0, 0, 0, 0, 0, 0,
G 1 | 1, 1, 0, 0, 0, 0,| 0, 0, 0, 0, 0, 0 G 1 | 1, 1, 0, 0, 0, 0,| 0, 0, 0, 0, 0, 0,
P 2 | 1, 1, 1, 0, 0, 0,| 0, 0, 0, 0, 0, 0 P 2 | 1, 1, 1, 0, 0, 0,| 0, 0, 0, 0, 0, 0,
U 3 | 1, 1, 1, 1, 0, 0,| 0, 0, 0, 0, 0, 0 U 9 | 1, 1, 1, 1, 0, 0,| 1, 1, 1, 1, 1, 1,
0 4 | 1, 1, 1, 1, 1, 0,| 0, 0, 0, 0, 0, 0 -> 0 10 | 1, 1, 1, 1, 1, 0,| 1, 1, 1, 1, 1, 1,
5 | 1, 1, 1, 1, 1, 1,| 0, 0, 0, 0, 0, 0 11 | 1, 1, 1, 1, 1, 1,| 1, 1, 1, 1, 1, 1,
---------------------------|----------------- ---------------------------|------------------
6 | 1, 1, 1, 1, 1, 1,| 1, 0, 0, 0, 0, 0 3 | 1, 1, 1, 0, 0, 0,| 1, 0, 0, 0, 0, 0,
G 7 | 1, 1, 1, 1, 1, 1,| 1, 1, 0, 0, 0, 0 G 4 | 1, 1, 1, 0, 0, 0,| 1, 1, 0, 0, 0, 0,
P 8 | 1, 1, 1, 1, 1, 1,| 1, 1, 1, 0, 0, 0, P 5 | 1, 1, 1, 0, 0, 0,| 1, 1, 1, 0, 0, 0,
U 9 | 1, 1, 1, 1, 1, 1,| 1, 1, 1, 1, 0, 0, U 6 | 1, 1, 1, 0, 0, 0,| 1, 1, 1, 1, 0, 0,
1 10 | 1, 1, 1, 1, 1, 1,| 1, 1, 1, 1, 1, 0, 1 7 | 1, 1, 1, 0, 0, 0,| 1, 1, 1, 1, 1, 0,
11 | 1, 1, 1, 1, 1, 1,| 1, 1, 1, 1, 1, 1, 8 | 1, 1, 1, 0, 0, 0,| 1, 1, 1, 1, 1, 1,
For qkv_format = 'thd', multiple sequences may be packed into the batch, and they may be of different
lengths. DualChunkSwap divides each sequence into (cp_size * 2) chunks and distributes 2 chunks of
every sequence onto a CP rank. The token matrix transformation is shown as follows, for an example of
batch_size = 2, seq_ids = [0, 1], seq_lens = [8, 4], t = 12, attn_mask_type = 'padding_causal', and
cp_size = 2.
GPU0 | GPU1 GPU0 | GPU1
seq_id | 0 0 0 0 0 0 | 0 0 1 1 1 1 seq_id | 0 0 0 0 1 1 | 0 0 0 0 1 1
seq_pos | 0 1 2 3 4 5 | 6 7 0 1 2 3 seq_pos | 0 1 6 7 0 3 | 2 3 4 5 1 2
---------------------------|----------------- ---------------------------|------------------
0 0 | 1, 0, 0, 0, 0, 0,| 0, 0, 0, 0, 0, 0 0 0 | 1, 0, 0, 0, 0, 0,| 0, 0, 0, 0, 0, 0,
G 0 1 | 1, 1, 0, 0, 0, 0,| 0, 0, 0, 0, 0, 0 G 0 1 | 1, 1, 0, 0, 0, 0,| 0, 0, 0, 0, 0, 0,
P 0 2 | 1, 1, 1, 0, 0, 0,| 0, 0, 0, 0, 0, 0 P 0 6 | 1, 1, 1, 0, 0, 0,| 1, 1, 1, 1, 0, 0,
U 0 3 | 1, 1, 1, 1, 0, 0,| 0, 0, 0, 0, 0, 0 U 0 7 | 1, 1, 1, 1, 0, 0,| 1, 1, 1, 1, 0, 0,
0 0 4 | 1, 1, 1, 1, 1, 0,| 0, 0, 0, 0, 0, 0 -> 0 1 0 | 0, 0, 0, 0, 2, 0,| 0, 0, 0, 0, 0, 0,
0 5 | 1, 1, 1, 1, 1, 1,| 0, 0, 0, 0, 0, 0 1 3 | 0, 0, 0, 0, 2, 2,| 0, 0, 0, 0, 2, 2,
---------------------------|----------------- ---------------------------|------------------
0 6 | 1, 1, 1, 1, 1, 1,| 1, 0, 0, 0, 0, 0 0 2 | 1, 1, 0, 0, 0, 0,| 1, 0, 0, 0, 0, 0,
G 0 7 | 1, 1, 1, 1, 1, 1,| 1, 1, 0, 0, 0, 0 G 0 3 | 1, 1, 0, 0, 0, 0,| 1, 1, 0, 0, 0, 0,
P 1 0 | 0, 0, 0, 0, 0, 0,| 0, 0, 2, 0, 0, 0 P 0 4 | 1, 1, 0, 0, 0, 0,| 1, 1, 1, 0, 0, 0,
U 1 1 | 0, 0, 0, 0, 0, 0,| 0, 0, 2, 2, 0, 0 U 0 5 | 1, 1, 0, 0, 0, 0,| 1, 1, 1, 1, 0, 0,
1 1 2 | 0, 0, 0, 0, 0, 0,| 0, 0, 2, 2, 2, 0 1 1 1 | 0, 0, 0, 0, 2, 0,| 0, 0, 0, 0, 2, 0,
1 3 | 0, 0, 0, 0, 0, 0,| 0, 0, 2, 2, 2, 2 1 2 | 0, 0, 0, 0, 2, 0,| 0, 0, 0, 0, 2, 2,
When all transformer layers in a model share the same CP configuration (cp_group, cp_global_ranks,
cp_comm_type and cp_stream), token re-ordering can take place in the dataloader, i.e. only once for
all the layers. An example of the re-ordering code is `get_batch_on_this_cp_rank
<https://github.com/NVIDIA/Megatron-LM/blob/d6eb60b5ea1efca47401c0be97f456fbe3a55bcd/megatron/core/utils.py#L1725>`_
in Megatron-LM.
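As a minimal sketch (an illustration under the divisibility assumption above, not the
Megatron-LM implementation), the per-rank re-ordering of a [s, ...] tensor looks like:

import torch

def reorder_for_cp_rank(x: torch.Tensor, cp_rank: int, cp_size: int) -> torch.Tensor:
    # Split the sequence dim into 2*cp_size chunks; each rank keeps chunk
    # cp_rank and its mirror (2*cp_size - 1 - cp_rank) so causal attention
    # work is balanced across ranks.
    chunks = x.chunk(2 * cp_size, dim=0)
    return torch.cat([chunks[cp_rank], chunks[2 * cp_size - 1 - cp_rank]], dim=0)

# With s = 12 and cp_size = 2: rank 0 keeps positions 0-2 and 9-11, and
# rank 1 keeps positions 3-8, matching the tables above.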
""" """
if cp_comm_type == "a2a+p2p": if cp_comm_type == "a2a+p2p":
...@@ -3527,6 +3876,12 @@ def attn_forward_func_with_cp( ...@@ -3527,6 +3876,12 @@ def attn_forward_func_with_cp(
"all_gather", "all_gather",
], "The context parallel running configs cannot support sliding window attetnion!" ], "The context parallel running configs cannot support sliding window attetnion!"
enable_mla = k.shape[-1] != v.shape[-1]
assert not enable_mla or cp_comm_type in [
"p2p",
"a2a+p2p",
], "The context parallel running configs cannot support MLA!"
args = [ args = [
is_training, is_training,
q, q,
......
...@@ -624,11 +624,6 @@ def get_attention_backend( ...@@ -624,11 +624,6 @@ def get_attention_backend(
" bias for THD format" " bias for THD format"
) )
use_fused_attention = False use_fused_attention = False
elif head_dim_qk != head_dim_v:
logger.debug(
"Disabling FusedAttention as it does not support context parallelism with MLA"
)
use_fused_attention = False
# Filter: Attention mask # Filter: Attention mask
# attn_mask_type | attention_mask | supported backends # attn_mask_type | attention_mask | supported backends
...@@ -782,6 +777,7 @@ def get_attention_backend( ...@@ -782,6 +777,7 @@ def get_attention_backend(
q_type = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) q_type = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True)
kv_type = q_type kv_type = q_type
fused_attention_backend = tex.get_fused_attn_backend( fused_attention_backend = tex.get_fused_attn_backend(
is_training,
q_type, q_type,
kv_type, kv_type,
QKVLayout[qkv_layout], QKVLayout[qkv_layout],
...@@ -962,15 +958,23 @@ def get_attention_backend( ...@@ -962,15 +958,23 @@ def get_attention_backend(
@torch.no_grad() @torch.no_grad()
def get_padding_mask( def get_padding_mask(
batch_size: int, batch_size: int,
cu_seqlens_q: torch.Tensor, cu_seqlens_q: torch.Tensor = None,
cu_seqlens_kv: torch.Tensor, cu_seqlens_kv: torch.Tensor = None,
max_seqlen_q: int, max_seqlen_q: int = None,
max_seqlen_kv: int, max_seqlen_kv: int = None,
attention_type: str = "self",
): ):
"""Convert cu_seqlens to attention_mask""" """Convert cu_seqlens to attention_mask"""
assert (
cu_seqlens_q is not None and max_seqlen_q is not None
), "cu_seqlens_q and max_seqlen_q are required for self-attention and cross-attention"
seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1] seqlens_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]
attention_mask_q = torch.Tensor([]).to(dtype=torch.bool) attention_mask_q = torch.Tensor([]).to(dtype=torch.bool)
if attention_type == "cross":
assert (
cu_seqlens_kv is not None and max_seqlen_kv is not None
), "cu_seqlens_kv and max_seqlen_kv are required for cross-attention"
seqlens_kv = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]
attention_mask_kv = torch.Tensor([]).to(dtype=torch.bool) attention_mask_kv = torch.Tensor([]).to(dtype=torch.bool)
for i in range(batch_size): for i in range(batch_size):
attention_mask_q = torch.cat( attention_mask_q = torch.cat(
...@@ -984,6 +988,7 @@ def get_padding_mask( ...@@ -984,6 +988,7 @@ def get_padding_mask(
], ],
dim=0, dim=0,
) )
if attention_type == "cross":
attention_mask_kv = torch.cat( attention_mask_kv = torch.cat(
[ [
attention_mask_kv, attention_mask_kv,
...@@ -995,8 +1000,12 @@ def get_padding_mask( ...@@ -995,8 +1000,12 @@ def get_padding_mask(
], ],
dim=0, dim=0,
) )
attention_mask_q = attention_mask_q.to(device="cuda")
if attention_type == "self":
attention_mask = attention_mask_q
else:
attention_mask = ( attention_mask = (
attention_mask_q.to(device="cuda"), attention_mask_q,
attention_mask_kv.to(device="cuda"), attention_mask_kv.to(device="cuda"),
) )
return attention_mask return attention_mask
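For clarity, hypothetical invocations under the new signature (sequence lengths invented for illustration); self-attention returns a single query mask, while cross-attention returns a (query, key/value) tuple:

    import torch

    cu_seqlens_q = torch.tensor([0, 3, 8], dtype=torch.int32)  # two sequences: lengths 3 and 5
    self_mask = get_padding_mask(batch_size=2, cu_seqlens_q=cu_seqlens_q, max_seqlen_q=5)
    mask_q, mask_kv = get_padding_mask(
        batch_size=2,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_kv=torch.tensor([0, 4, 10], dtype=torch.int32),
        max_seqlen_q=5,
        max_seqlen_kv=6,
        attention_type="cross",
    )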
......
...@@ -12,6 +12,7 @@ from transformer_engine.pytorch.fp8 import FP8GlobalStateManager ...@@ -12,6 +12,7 @@ from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.float8_tensor import Float8Tensor from transformer_engine.pytorch.float8_tensor import Float8Tensor
from transformer_engine.pytorch.module.base import TransformerEngineBaseModule from transformer_engine.pytorch.module.base import TransformerEngineBaseModule
from transformer_engine.pytorch.module import LayerNormLinear, Linear from transformer_engine.pytorch.module import LayerNormLinear, Linear
from transformer_engine.pytorch.ops.basic.l2normalization import L2Normalization
from transformer_engine.pytorch.utils import ( from transformer_engine.pytorch.utils import (
SplitAlongDim, SplitAlongDim,
divide, divide,
...@@ -174,6 +175,22 @@ class MultiheadAttention(torch.nn.Module): ...@@ -174,6 +175,22 @@ class MultiheadAttention(torch.nn.Module):
parameter for query-key-value. This enables optimizations such as QKV parameter for query-key-value. This enables optimizations such as QKV
fusion without concatenations/splits and also enables the argument fusion without concatenations/splits and also enables the argument
`fuse_wgrad_accumulation`. `fuse_wgrad_accumulation`.
use_qk_norm: bool, default = `False`
if set to `True`, L2 normalization is applied to query and key tensors
after RoPE (if applicable) but before attention computation.
This follows the Llama4 approach for QK normalization to improve
training stability and model performance.
qk_norm_eps: float, default = 1e-6
epsilon value for L2 normalization of query and key tensors.
Only used when `use_qk_norm` is True.
seq_length: Optional[int], default = `None`
sequence length of input samples. Needed for JIT Warmup, a technique where JIT-fused
functions are warmed up before training to ensure the same kernels are used for
the forward propagation and activation recompute phases.
micro_batch_size: Optional[int], default = `None`
batch size per training step. Needed for JIT Warmup, a technique where JIT-fused
functions are warmed up before training to ensure the same kernels are
used for the forward propagation and activation recompute phases.
""" """
def __init__( def __init__(
...@@ -214,6 +231,10 @@ class MultiheadAttention(torch.nn.Module): ...@@ -214,6 +231,10 @@ class MultiheadAttention(torch.nn.Module):
device: Union[torch.device, str] = "cuda", device: Union[torch.device, str] = "cuda",
qkv_format: str = "sbhd", qkv_format: str = "sbhd",
name: str = None, name: str = None,
use_qk_norm: bool = False,
qk_norm_eps: float = 1e-6,
seq_length: Optional[int] = None,
micro_batch_size: Optional[int] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -267,6 +288,7 @@ class MultiheadAttention(torch.nn.Module): ...@@ -267,6 +288,7 @@ class MultiheadAttention(torch.nn.Module):
self.hidden_size_kv = self.hidden_size_per_attention_head * self.num_gqa_groups self.hidden_size_kv = self.hidden_size_per_attention_head * self.num_gqa_groups
self.name = name self.name = name
self.use_qk_norm = use_qk_norm
common_gemm_kwargs = { common_gemm_kwargs = {
"fuse_wgrad_accumulation": fuse_wgrad_accumulation, "fuse_wgrad_accumulation": fuse_wgrad_accumulation,
...@@ -278,6 +300,14 @@ class MultiheadAttention(torch.nn.Module): ...@@ -278,6 +300,14 @@ class MultiheadAttention(torch.nn.Module):
"device": device, "device": device,
} }
# Initialize L2 normalization modules for query and key if enabled
if self.use_qk_norm:
self.qk_norm = L2Normalization(
eps=qk_norm_eps,
seq_length=seq_length,
micro_batch_size=micro_batch_size,
)
qkv_parallel_mode = "column" if set_parallel_mode else None qkv_parallel_mode = "column" if set_parallel_mode else None
if self.attention_type == "self": if self.attention_type == "self":
...@@ -482,6 +512,8 @@ class MultiheadAttention(torch.nn.Module): ...@@ -482,6 +512,8 @@ class MultiheadAttention(torch.nn.Module):
alibi_slopes: Optional[torch.Tensor] = None, alibi_slopes: Optional[torch.Tensor] = None,
cu_seqlens_q: Optional[torch.Tensor] = None, cu_seqlens_q: Optional[torch.Tensor] = None,
cu_seqlens_kv: Optional[torch.Tensor] = None, cu_seqlens_kv: Optional[torch.Tensor] = None,
cu_seqlens_q_padded: Optional[torch.Tensor] = None,
cu_seqlens_kv_padded: Optional[torch.Tensor] = None,
max_seqlen_q: Optional[int] = None, max_seqlen_q: Optional[int] = None,
max_seqlen_kv: Optional[int] = None, max_seqlen_kv: Optional[int] = None,
fast_zero_fill: bool = True, fast_zero_fill: bool = True,
...@@ -556,6 +588,12 @@ class MultiheadAttention(torch.nn.Module): ...@@ -556,6 +588,12 @@ class MultiheadAttention(torch.nn.Module):
cu_seqlens_kv: Optional[torch.Tensor], default = `None` cu_seqlens_kv: Optional[torch.Tensor], default = `None`
Cumulative sum of sequence lengths (without offset) in a batch for `key_layer` Cumulative sum of sequence lengths (without offset) in a batch for `key_layer`
and `value_layer`, with shape [batch_size + 1] and dtype torch.int32. and `value_layer`, with shape [batch_size + 1] and dtype torch.int32.
cu_seqlens_q_padded: Optional[torch.Tensor], default = `None`
Cumulative sum of sequence lengths (including padding offsets) in a batch for
`query_layer`, with shape [batch_size + 1] and dtype torch.int32; see the example below.
cu_seqlens_kv_padded: Optional[torch.Tensor], default = `None`
Cumulative sum of sequence lengths (including padding offsets) in a batch for
`key_layer` and `value_layer`, with shape [batch_size + 1] and dtype torch.int32.
max_seqlen_q: Optional[int], default = `None` max_seqlen_q: Optional[int], default = `None`
Maximum sequence length in `query_layer`. Maximum sequence length in `query_layer`.
Calculated from `cu_seqlens_q` if not provided. Calculated from `cu_seqlens_q` if not provided.
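To make the padded variants concrete, a hypothetical THD example with two sequences of lengths 3 and 5, each padded to a multiple of 4 (values invented for illustration):

    import torch

    cu_seqlens_q = torch.tensor([0, 3, 8], dtype=torch.int32)          # without padding offsets
    cu_seqlens_q_padded = torch.tensor([0, 4, 12], dtype=torch.int32)  # including padding offsets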
...@@ -714,6 +752,18 @@ class MultiheadAttention(torch.nn.Module): ...@@ -714,6 +752,18 @@ class MultiheadAttention(torch.nn.Module):
for x in (key_layer, value_layer) for x in (key_layer, value_layer)
) )
if self.qkv_format == "thd":
key_layer, value_layer = (
x.reshape(x.size(0), -1, self.hidden_size_per_attention_head)
for x in (key_layer, value_layer)
)
else:
# key, value: -> [sq, b, ng, hn]
key_layer, value_layer = (
x.reshape(x.size(0), x.size(1), -1, self.hidden_size_per_attention_head)
for x in (key_layer, value_layer)
)
# Attention head [sq, b, h] --> [sq, b, hp] # Attention head [sq, b, h] --> [sq, b, hp]
if self.input_layernorm: if self.input_layernorm:
layernorm_query_outputs = self.layernorm_query( layernorm_query_outputs = self.layernorm_query(
...@@ -792,6 +842,14 @@ class MultiheadAttention(torch.nn.Module): ...@@ -792,6 +842,14 @@ class MultiheadAttention(torch.nn.Module):
interleaved=self.rotary_pos_interleaved, interleaved=self.rotary_pos_interleaved,
) )
# ===========================
# Apply L2 normalization to query and key tensors
# ===========================
if self.use_qk_norm:
query_layer = self.qk_norm(query_layer)
key_layer = self.qk_norm(key_layer)
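A minimal functional sketch of the L2 normalization applied here, assuming normalization over the last (head) dimension; the actual `L2Normalization` module additionally supports the JIT warmup arguments shown above:

    import torch

    def l2_normalize(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
        # Rescale each head vector to (approximately) unit L2 norm; eps guards zero vectors.
        return x * torch.rsqrt(x.pow(2).sum(dim=-1, keepdim=True) + eps)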
# =========================== # ===========================
# Core attention computation # Core attention computation
# =========================== # ===========================
...@@ -803,6 +861,8 @@ class MultiheadAttention(torch.nn.Module): ...@@ -803,6 +861,8 @@ class MultiheadAttention(torch.nn.Module):
qkv_format=self.qkv_format, qkv_format=self.qkv_format,
cu_seqlens_q=cu_seqlens_q, cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv, cu_seqlens_kv=cu_seqlens_kv,
cu_seqlens_q_padded=cu_seqlens_q_padded,
cu_seqlens_kv_padded=cu_seqlens_kv_padded,
max_seqlen_q=max_seqlen_q, max_seqlen_q=max_seqlen_q,
max_seqlen_kv=max_seqlen_kv, max_seqlen_kv=max_seqlen_kv,
attention_mask=attention_mask, attention_mask=attention_mask,
......
...@@ -140,6 +140,14 @@ def general_gemm( ...@@ -140,6 +140,14 @@ def general_gemm(
# There is no use_split_accumulator == False # There is no use_split_accumulator == False
# implementation for Float8BlockwiseQTensorBase GEMM # implementation for Float8BlockwiseQTensorBase GEMM
use_split_accumulator = True use_split_accumulator = True
# Check that data format is supported
if (
A._data_format != tex.Float8BlockScaleTensorFormat.GEMM_READY
or B._data_format != tex.Float8BlockScaleTensorFormat.GEMM_READY
):
raise RuntimeError("GEMM with Float8BlockwiseQTensor requires GEMM_READY format")
args = ( args = (
A, A,
transa, # transa transa, # transa
......
...@@ -253,13 +253,21 @@ class SynchronizedGroupOffloadHandler(OffloadHandler): ...@@ -253,13 +253,21 @@ class SynchronizedGroupOffloadHandler(OffloadHandler):
return state return state
@staticmethod @staticmethod
def reload(state, non_blocking=None): def reload(state, non_blocking=None, copy_buffer=None):
"""Reload.""" """Reload."""
dev, cpu_backup = state dev, cpu_backup = state
if non_blocking is None: if non_blocking is None:
non_blocking = cpu_backup.is_pinned() non_blocking = cpu_backup.is_pinned()
if copy_buffer is None:
return cpu_backup.to(dev, non_blocking=non_blocking) return cpu_backup.to(dev, non_blocking=non_blocking)
assert cpu_backup.size() == copy_buffer.size(), "Can't copy two buffers of different sizes!"
copy_buffer.copy_(cpu_backup, non_blocking=non_blocking)
return copy_buffer
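A hypothetical use of the new copy_buffer path, reloading into a pre-allocated device buffer instead of allocating a fresh tensor (variable names invented for illustration):

    import torch

    dev, cpu_backup = state  # as produced by the offload path
    buf = torch.empty(cpu_backup.size(), dtype=cpu_backup.dtype, device=dev)
    reloaded = SynchronizedGroupOffloadHandler.reload(state, non_blocking=True, copy_buffer=buf)
    assert reloaded is buf  # the provided buffer is filled and returned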
def tensor_push(self, tensor: torch.Tensor, **kwargs): def tensor_push(self, tensor: torch.Tensor, **kwargs):
"""Tensor push.""" """Tensor push."""
# obtain a unique tensor tag # obtain a unique tensor tag
...@@ -300,6 +308,7 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler): ...@@ -300,6 +308,7 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler):
num_offload_group, # must be <= actual number of groups (number of commits) num_offload_group, # must be <= actual number of groups (number of commits)
num_model_group, num_model_group,
tensor_need_offloading_checker=(lambda t: True), tensor_need_offloading_checker=(lambda t: True),
double_buffering=False,
debug=False, debug=False,
) -> None: ) -> None:
super().__init__( super().__init__(
...@@ -314,11 +323,17 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler): ...@@ -314,11 +323,17 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler):
# Data structure to hold the FP8/MXFP8 tensor objects # Data structure to hold the FP8/MXFP8 tensor objects
self.fp8_tensor_object_map = {} self.fp8_tensor_object_map = {}
self.float8_transpose_cache_valid = {} self.float8_transpose_cache_valid = {}
self.dereferencing_list = []
# Tracking the number of layers offloaded # Tracking the number of layers offloaded
self.offloaded_group_count = 0 self.offloaded_group_count = 0
# Core data structure that decides the window for offloading # Core data structure that decides the window for offloading
self.layer_window_map = {} self.layer_window_map = {}
# Data structures for double-buffered reloading
self.double_buffering = double_buffering
self.reload_double_buffer = [[], []]
self.double_buffer_created = False
# Logic to make offloading load balance across computation # Logic to make offloading load balance across computation
# for optimal CPU/GPU interconnect usage # for optimal CPU/GPU interconnect usage
constant = 0 constant = 0
...@@ -360,6 +375,12 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler): ...@@ -360,6 +375,12 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler):
self.tensor_tag_to_state[tensor_tag] = [] self.tensor_tag_to_state[tensor_tag] = []
self.tensor_tag_to_buf[tensor_tag] = [] self.tensor_tag_to_buf[tensor_tag] = []
# Added support for de-duplicating FP8 param tensors
for value in self.fp8_tensor_object_map.values():
if tensor is value:
self.dereferencing_list.append(tensor_tag)
break
self.fp8_tensor_object_map[tensor_tag] = tensor self.fp8_tensor_object_map[tensor_tag] = tensor
if isinstance(tensor, Float8Tensor): if isinstance(tensor, Float8Tensor):
self.float8_transpose_cache_valid[tensor_tag] = getattr( self.float8_transpose_cache_valid[tensor_tag] = getattr(
...@@ -398,11 +419,18 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler): ...@@ -398,11 +419,18 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler):
# Handling the quantized tensor case specially here # Handling the quantized tensor case specially here
if isinstance(tensor, list): if isinstance(tensor, list):
# If it's a duplicated tensor, we don't need to write it back
# locally, as it has already been written back once
if tensor_tag in self.dereferencing_list:
self.dereferencing_list.remove(tensor_tag)
else:
self.fp8_tensor_object_map[tensor_tag].restore_from_saved(tensor) self.fp8_tensor_object_map[tensor_tag].restore_from_saved(tensor)
tensor = self.fp8_tensor_object_map.pop(tensor_tag) tensor = self.fp8_tensor_object_map.pop(tensor_tag)
self.tensor_tag_to_buf.pop(tensor_tag, None) if self.double_buffering:
tensor.do_not_clear = True
self.tensor_tag_to_buf.pop(tensor_tag, None)
# the tensor should have been copied back in on_group_commit_backward() # the tensor should have been copied back in on_group_commit_backward()
# which invokes bulk_reload_group. # which invokes bulk_reload_group.
assert not isinstance(tensor, tuple) assert not isinstance(tensor, tuple)
...@@ -454,6 +482,20 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler): ...@@ -454,6 +482,20 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler):
# the first compute completion # the first compute completion
if current_group == 0: if current_group == 0:
self.d2h_stream.wait_stream(torch.cuda.current_stream()) self.d2h_stream.wait_stream(torch.cuda.current_stream())
if not self.double_buffer_created:
# Creating the first copy of double buffer for tensors that are offloaded
for tensor_tag, buf in self.tensor_tag_to_buf.items():
if isinstance(buf, list):
for b in buf:
self.reload_double_buffer[0].append(
torch.empty_like(b) if self.double_buffering else None
)
else:
self.reload_double_buffer[0].append(
torch.empty_like(buf) if self.double_buffering else None
)
self.bulk_offload_group(current_group) self.bulk_offload_group(current_group)
# Window map data structure helps us synchronize based on number # Window map data structure helps us synchronize based on number
...@@ -483,6 +525,15 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler): ...@@ -483,6 +525,15 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler):
# Increment the offload group count to keep track # Increment the offload group count to keep track
self.offloaded_group_count += 1 self.offloaded_group_count += 1
if not self.double_buffer_created:
# Creating second copy of double buffer for tensors that are offloaded
if current_group == (self.num_layers - 1):
for buf in self.reload_double_buffer[0]:
self.reload_double_buffer[1].append(
torch.empty_like(buf) if self.double_buffering else None
)
self.double_buffer_created = True
def on_group_commit_forward(self): def on_group_commit_forward(self):
"""This function will cause host device synchronization""" """This function will cause host device synchronization"""
# handle synchronization events # handle synchronization events
...@@ -494,28 +545,49 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler): ...@@ -494,28 +545,49 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler):
"""Bulk reload group.""" """Bulk reload group."""
assert group_to_reload < self.num_offload_group assert group_to_reload < self.num_offload_group
buffer_idx = 0
# Alternate between the two reload buffers based on group parity
double_buffer_idx = group_to_reload % 2
with torch.cuda.stream(self.h2d_stream): with torch.cuda.stream(self.h2d_stream):
# move back tensors # move back tensors
for tensor_label, state in self.tensor_tag_to_state.items(): for tensor_label, state in self.tensor_tag_to_state.items():
group_id, _ = tensor_label group_id, _ = tensor_label
if group_id == group_to_reload: if group_id == group_to_reload:
if isinstance(state, tuple): if isinstance(state, tuple):
recovered_tensor = SynchronizedGroupOffloadHandler.reload(state) recovered_tensor = SynchronizedGroupOffloadHandler.reload(
state, True, self.reload_double_buffer[double_buffer_idx][buffer_idx]
)
buffer_idx = buffer_idx + 1
self.tensor_tag_to_state[tensor_label] = recovered_tensor self.tensor_tag_to_state[tensor_label] = recovered_tensor
elif isinstance(state, list): elif isinstance(state, list):
tensor_list = [] tensor_list = []
for state_tuple in state: for state_tuple in state:
if isinstance(state_tuple, tuple): if isinstance(state_tuple, tuple):
tensor_list.append( tensor_list.append(
SynchronizedGroupOffloadHandler.reload(state_tuple) SynchronizedGroupOffloadHandler.reload(
state_tuple,
True,
self.reload_double_buffer[double_buffer_idx][buffer_idx],
) )
)
buffer_idx = buffer_idx + 1
else: else:
tensor_list.append(state_tuple) tensor_list.append(state_tuple)
_ = self.fp8_tensor_object_map[tensor_label].restore_from_saved(tensor_list)
# No need to write the duplicated tensor back
# to the same location again; this check ensures that
if tensor_label in self.dereferencing_list:
self.dereferencing_list.remove(tensor_label)
else:
_ = self.fp8_tensor_object_map[tensor_label].restore_from_saved(
tensor_list
)
if isinstance(self.fp8_tensor_object_map[tensor_label], Float8Tensor): if isinstance(self.fp8_tensor_object_map[tensor_label], Float8Tensor):
self.fp8_tensor_object_map[tensor_label]._transpose_invalid = ( self.fp8_tensor_object_map[tensor_label]._transpose_invalid = (
self.float8_transpose_cache_valid.pop(tensor_label) self.float8_transpose_cache_valid.pop(tensor_label)
) )
self.tensor_tag_to_state[tensor_label] = self.fp8_tensor_object_map.pop( self.tensor_tag_to_state[tensor_label] = self.fp8_tensor_object_map.pop(
tensor_label tensor_label
) )
...@@ -552,6 +624,7 @@ def get_cpu_offload_context( ...@@ -552,6 +624,7 @@ def get_cpu_offload_context(
model_layers: int = 1, model_layers: int = 1,
offload_activations: bool = True, offload_activations: bool = True,
offload_weights: bool = False, offload_weights: bool = False,
double_buffering: bool = False,
): ):
""" """
This function returns the CPU Offload context and the synchronizer function that needs to be This function returns the CPU Offload context and the synchronizer function that needs to be
...@@ -580,6 +653,8 @@ def get_cpu_offload_context( ...@@ -580,6 +653,8 @@ def get_cpu_offload_context(
When set to `True`, offloads the activations for the TE layer. When set to `True`, offloads the activations for the TE layer.
offload_weights: bool, default = `False` offload_weights: bool, default = `False`
When set to `True`, offloads the weights for the TE layer. When set to `True`, offloads the weights for the TE layer.
double_buffering: bool, default = `False`
When set to `True`, uses double buffering for offloading.
""" """
...@@ -611,6 +686,7 @@ def get_cpu_offload_context( ...@@ -611,6 +686,7 @@ def get_cpu_offload_context(
num_offload_group=num_layers, num_offload_group=num_layers,
num_model_group=model_layers, num_model_group=model_layers,
tensor_need_offloading_checker=tensor_need_offloading_checker, tensor_need_offloading_checker=tensor_need_offloading_checker,
double_buffering=double_buffering,
) )
def group_prefetch_offload_commit_async(tensor): def group_prefetch_offload_commit_async(tensor):
......
...@@ -20,6 +20,20 @@ std::vector<size_t> getTensorShape(at::Tensor t) { ...@@ -20,6 +20,20 @@ std::vector<size_t> getTensorShape(at::Tensor t) {
return shape; return shape;
} }
NVTEShape convertTorchShape(const c10::IntArrayRef torch_shape) {
NVTEShape ret;
ret.ndim = torch_shape.size();
constexpr int max_dimensions = sizeof(ret.data) / sizeof(size_t);
NVTE_CHECK(ret.ndim <= max_dimensions,
"Torch tensor has too many dimensions. Max supported: ", max_dimensions, " and got ",
ret.ndim, ".");
for (size_t i = 0; i < ret.ndim; ++i) {
const auto& v = torch_shape[i];
ret.data[i] = static_cast<size_t>(v);
}
return ret;
}
std::unique_ptr<Quantizer> convert_quantizer(py::handle quantizer) { std::unique_ptr<Quantizer> convert_quantizer(py::handle quantizer) {
init_extension(); init_extension();
if (quantizer.is_none()) { if (quantizer.is_none()) {
......
...@@ -34,6 +34,7 @@ ...@@ -34,6 +34,7 @@
#include <transformer_engine/fused_attn.h> #include <transformer_engine/fused_attn.h>
#include <transformer_engine/fused_rope.h> #include <transformer_engine/fused_rope.h>
#include <transformer_engine/gemm.h> #include <transformer_engine/gemm.h>
#include <transformer_engine/multi_stream.h>
#include <transformer_engine/multi_tensor.h> #include <transformer_engine/multi_tensor.h>
#include <transformer_engine/normalization.h> #include <transformer_engine/normalization.h>
#include <transformer_engine/padding.h> #include <transformer_engine/padding.h>
...@@ -177,6 +178,8 @@ class Float8BlockQuantizer : public Quantizer { ...@@ -177,6 +178,8 @@ class Float8BlockQuantizer : public Quantizer {
bool force_pow_2_scales = false; bool force_pow_2_scales = false;
// Amax within quantization tile has a floor of epsilon. // Amax within quantization tile has a floor of epsilon.
float amax_epsilon = 0.0; float amax_epsilon = 0.0;
// Whether quantized tensor will be used in an all-gather
bool all_gather_usage = false;
private: private:
int block_scaling_dim = 2; int block_scaling_dim = 2;
...@@ -222,21 +225,23 @@ std::vector<size_t> getTensorShape(at::Tensor t); ...@@ -222,21 +225,23 @@ std::vector<size_t> getTensorShape(at::Tensor t);
transformer_engine::DType getTransformerEngineFP8Type(bool e4m3_if_hybrid, transformer_engine::DType getTransformerEngineFP8Type(bool e4m3_if_hybrid,
const std::string& fp8_recipe); const std::string& fp8_recipe);
inline size_t typeToSize(transformer_engine::DType t) { inline size_t typeToNumBits(transformer_engine::DType t) {
switch (t) { switch (t) {
case transformer_engine::DType::kInt64: case transformer_engine::DType::kInt64:
return 8; return 64;
case transformer_engine::DType::kInt32: case transformer_engine::DType::kInt32:
case transformer_engine::DType::kFloat32: case transformer_engine::DType::kFloat32:
return 4; return 32;
case transformer_engine::DType::kInt16: case transformer_engine::DType::kInt16:
case transformer_engine::DType::kFloat16: case transformer_engine::DType::kFloat16:
case transformer_engine::DType::kBFloat16: case transformer_engine::DType::kBFloat16:
return 2; return 16;
case transformer_engine::DType::kByte: case transformer_engine::DType::kByte:
case transformer_engine::DType::kFloat8E4M3: case transformer_engine::DType::kFloat8E4M3:
case transformer_engine::DType::kFloat8E5M2: case transformer_engine::DType::kFloat8E5M2:
return 1; return 8;
case transformer_engine::DType::kFloat4E2M1:
return 4;
default: default:
NVTE_ERROR("Invalid type"); NVTE_ERROR("Invalid type");
} }
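The switch from byte counts to bit counts matters because 4-bit types such as kFloat4E2M1 have no whole-byte element size. A sketch (illustrative Python, not part of the extension) of deriving a byte footprint from bit widths:

    def bytes_for(num_elements: int, bits_per_element: int) -> int:
        # Round up to whole bytes so sub-byte types don't truncate to zero.
        return (num_elements * bits_per_element + 7) // 8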
...@@ -355,6 +360,7 @@ std::vector<size_t> convertShape(const NVTEShape& shape); ...@@ -355,6 +360,7 @@ std::vector<size_t> convertShape(const NVTEShape& shape);
int roundup(const int value, const int multiple); int roundup(const int value, const int multiple);
NVTEShape convertTorchShape(const c10::IntArrayRef torch_shape);
} // namespace transformer_engine::pytorch } // namespace transformer_engine::pytorch
namespace std { namespace std {
......
...@@ -35,13 +35,11 @@ std::tuple<at::Tensor, at::Tensor> moe_unpermute_bwd(at::Tensor input_bwd, at::T ...@@ -35,13 +35,11 @@ std::tuple<at::Tensor, at::Tensor> moe_unpermute_bwd(at::Tensor input_bwd, at::T
* Attention * Attention
**************************************************************************************************/ **************************************************************************************************/
NVTE_Fused_Attn_Backend get_fused_attn_backend(const DType q_dtype, const DType kv_dtype, NVTE_Fused_Attn_Backend get_fused_attn_backend(
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, bool is_training, const DType q_dtype, const DType kv_dtype, NVTE_QKV_Layout qkv_layout,
NVTE_Mask_Type attn_mask_type, float p_dropout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, float p_dropout, size_t num_attn_heads,
size_t num_attn_heads, size_t num_gqa_groups, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk,
size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_v, int64_t window_size_left, int64_t window_size_right);
size_t head_dim_qk, size_t head_dim_v,
int64_t window_size_left, int64_t window_size_right);
std::vector<py::object> fused_attn_fwd( std::vector<py::object> fused_attn_fwd(
size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float p_dropout, size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float p_dropout,
...@@ -450,6 +448,8 @@ class CommOverlap : torch::CustomClassHolder, public transformer_engine::CommOve ...@@ -450,6 +448,8 @@ class CommOverlap : torch::CustomClassHolder, public transformer_engine::CommOve
at::Tensor get_buffer(bool local_chunk = false, at::Tensor get_buffer(bool local_chunk = false,
std::optional<std::vector<int64_t>> shape = std::nullopt); std::optional<std::vector<int64_t>> shape = std::nullopt);
at::Stream get_communication_stream();
}; // CommOverlap }; // CommOverlap
class CommOverlapP2P : torch::CustomClassHolder, public transformer_engine::CommOverlapP2PBase { class CommOverlapP2P : torch::CustomClassHolder, public transformer_engine::CommOverlapP2PBase {
...@@ -469,6 +469,8 @@ class CommOverlapP2P : torch::CustomClassHolder, public transformer_engine::Comm ...@@ -469,6 +469,8 @@ class CommOverlapP2P : torch::CustomClassHolder, public transformer_engine::Comm
at::Tensor get_buffer(bool local_chunk = false, at::Tensor get_buffer(bool local_chunk = false,
std::optional<std::vector<int64_t>> shape = std::nullopt); std::optional<std::vector<int64_t>> shape = std::nullopt);
at::Stream get_communication_stream();
}; // CommOverlapP2P }; // CommOverlapP2P
#endif // TRANSFORMER_ENGINE_PYTORCH_CSRC_EXTENSIONS_H_ #endif // TRANSFORMER_ENGINE_PYTORCH_CSRC_EXTENSIONS_H_
...@@ -4,8 +4,8 @@ ...@@ -4,8 +4,8 @@
* See LICENSE for license information. * See LICENSE for license information.
************************************************************************/ ************************************************************************/
#include "../extensions.h"
#include "common.h" #include "common.h"
#include "extensions.h"
#include "pybind.h" #include "pybind.h"
namespace transformer_engine::pytorch { namespace transformer_engine::pytorch {
......
...@@ -4,8 +4,8 @@ ...@@ -4,8 +4,8 @@
* See LICENSE for license information. * See LICENSE for license information.
************************************************************************/ ************************************************************************/
#include "../extensions.h"
#include "common.h" #include "common.h"
#include "extensions.h"
namespace transformer_engine::pytorch { namespace transformer_engine::pytorch {
......