Commit c1a1c04e authored by wenjh

Merge nv_main(2.10) to main


Signed-off-by: wenjh <wenjh@sugon.com>
parents e698a0a7 66aed3ae
......@@ -13,13 +13,12 @@ import warnings
import torch
# import transformer_engine_torch as tex
import transformer_engine_torch as tex
from transformer_engine_torch import DType as TE_DType
from ..quantized_tensor import QuantizedTensorStorage
from ...quantized_tensor import QuantizedTensorStorage, Quantizer
# from ...constants import TE_DType as torch_to_transformer_engine_dtype
from ..quantized_tensor import Quantizer
from ...constants import TE_DType as torch_to_transformer_engine_dtype
from ...utils import _empty_tensor
......@@ -46,34 +45,7 @@ class _FromNVFP4Func(torch.autograd.Function):
# Dequantize row-wise data
if tensor._rowwise_data is not None:
### TODO(tmoon): Debug dequantize kernel and remove unfused impl
# return tex.dequantize(tensor, torch_to_transformer_engine_dtype[dtype])
# Tensor properties
shape = list(tensor._rowwise_data.size())
shape[-1] *= 2
device = tensor._rowwise_data.device
# Convert FP4E2M1 values to FP32
data = tensor._rowwise_data.view(torch.uint8).to(torch.int32)
data = torch.stack((data & 0x0F, data >> 4), dim=-1).reshape(shape)
data = _fp4_e2m1_vals(device, dtype=torch.float32)[data]
data = data.to(torch.float32).contiguous()
# Convert FP8E4M3 block scales to FP32
block_scales = tensor._rowwise_scale_inv
block_scales = block_scales.reshape(-1, block_scales.size(-1))
block_scales = block_scales[: math.prod(shape[:-1]), : shape[-1] // 16]
block_scales = block_scales.view(torch.float8_e4m3fn).to(torch.float32)
# Convert amax to FP32 tensor scale
tensor_scale = tensor._amax_rowwise / (6.0 * 448.0) # Scale by FP4E2M1 and FP8E4M3 max
# Apply scales
block_data = data.view(-1, 16)
block_data *= tensor_scale.view(()) * block_scales.reshape(-1, 1)
return data.to(dtype)
return tex.dequantize(tensor, torch_to_transformer_engine_dtype[dtype])
if tensor._columnwise_data is not None:
raise NotImplementedError("Dequantizing column-wise NVFP4 data is not implemented yet!")
......
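For reference, the removed unfused path above fully documents the NVFP4 layout: two FP4 E2M1 values packed per byte (low nibble first), one FP8 E4M3 scale per 16-element block, and a per-tensor scale of amax / (6.0 * 448.0). A minimal eager sketch of that decode in plain PyTorch, with a hypothetical helper name (this is not the tex.dequantize kernel):

import torch

# The 16 FP4 E2M1 code points (the sign is the high bit of the nibble).
_FP4_E2M1_VALS = torch.tensor(
    [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
     -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0]
)

def dequantize_nvfp4_row(packed: torch.Tensor, block_scales: torch.Tensor,
                         amax: torch.Tensor) -> torch.Tensor:
    """Hypothetical eager decode of one row of packed NVFP4 data.

    packed:       uint8 tensor, two FP4 E2M1 values per byte.
    block_scales: float32 tensor, one scale per 16 decoded values.
    amax:         scalar tensor holding the row-wise absolute maximum.
    """
    codes = packed.to(torch.long)
    # Unpack the low nibble, then the high nibble, as in the unfused path above.
    codes = torch.stack((codes & 0x0F, codes >> 4), dim=-1).reshape(-1)
    vals = _FP4_E2M1_VALS[codes]
    tensor_scale = amax / (6.0 * 448.0)  # FP4 E2M1 max times FP8 E4M3 max
    return (vals.view(-1, 16) * (tensor_scale * block_scales.view(-1, 1))).reshape(-1)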
......@@ -4,18 +4,18 @@
"""Helper functions for using fp8 tensors as weights"""
import os
from typing import Optional, Union
from typing import Optional, Union, List
import torch
from torch.utils.cpp_extension import IS_HIP_EXTENSION
import transformer_engine_torch as tex
from transformer_engine_torch import multi_tensor_scale, multi_tensor_compute_scale_and_scale_inv
from .quantized_tensor import QuantizedTensor, Quantizer, QuantizedTensorStorage
from ..quantized_tensor import QuantizedTensor, Quantizer, QuantizedTensorStorage
from .float8_tensor import Float8Tensor, Float8Quantizer, Float8CurrentScalingQuantizer
from .mxfp8_tensor import MXFP8Tensor, MXFP8Quantizer
from .float8_blockwise_tensor import Float8BlockwiseQTensor, Float8BlockQuantizer
from ..optimizers.multi_tensor_apply import multi_tensor_applier
from ..utils import is_non_tn_fp8_gemm_supported
def replace_raw_data(tensor: QuantizedTensor, new_raw_data: torch.Tensor):
......@@ -48,7 +48,12 @@ def replace_raw_data(tensor: QuantizedTensor, new_raw_data: torch.Tensor):
def cast_master_weights_to_fp8(
model_weights, master_weights, start_offsets, group, fsdp_shard_model_weights=None
model_weights,
master_weights,
start_offsets,
group,
fsdp_shard_model_weights=None,
manual_post_all_gather_processing=False,
):
r"""Helper function to cast master weights to FP8 primary weights.
......@@ -69,6 +74,11 @@ def cast_master_weights_to_fp8(
fsdp_shard_model_weights : list of FSDP shard model weights. If None, it means that the model weights are
not sharded. Otherwise, it means that the model weights are sharded and we get
target model weights data storage using the FSDP shard model weights.
manual_post_all_gather_processing: bool, default = `False`.
If False, post-processing is triggered automatically during the next forward pass.
If True, the timing of calling `post_all_gather_processing` is left to the user.
Note that users must call `post_all_gather_processing` if this is set to True;
otherwise the weights will not be updated correctly.
"""
......@@ -129,21 +139,18 @@ def cast_master_weights_to_fp8(
f"cast_master_weights_to_fp8 for {type(quantizer)} is not supported yet"
)
extra_args = [group, use_fsdp_shard_model_weights, manual_post_all_gather_processing]
if len(delayed_scaling_params) > 0:
_cast_master_weights_to_fp8_delayed_scaling(
delayed_scaling_params, group, use_fsdp_shard_model_weights
)
_cast_master_weights_to_fp8_delayed_scaling(delayed_scaling_params, *extra_args)
if len(current_scaling_params) > 0:
_cast_master_weights_to_fp8_current_scaling(
current_scaling_params, group, use_fsdp_shard_model_weights
)
_cast_master_weights_to_fp8_current_scaling(current_scaling_params, *extra_args)
if len(blockwise_scaling_params) > 0:
_cast_master_weights_to_fp8_blockwise_scaling(
blockwise_scaling_params, group, use_fsdp_shard_model_weights
)
_cast_master_weights_to_fp8_blockwise_scaling(blockwise_scaling_params, *extra_args)
def _cast_master_weights_to_fp8_delayed_scaling(params, group, use_fsdp_shard_model_weights=False):
def _cast_master_weights_to_fp8_delayed_scaling(
params, group, use_fsdp_shard_model_weights=False, manual_post_all_gather_processing=False
):
r"""Helper function to cast master weights to FP8 primary weights for delayed scaling.
Parameters
......@@ -160,11 +167,12 @@ def _cast_master_weights_to_fp8_delayed_scaling(params, group, use_fsdp_shard_mo
amaxes, scales, scale_invs = [], [], []
for model_weight, master_weight, start_offset, shard_model_weight_raw in params:
# Reset transpose cache for all model weights.
# We cannot create transpose cache here because users (like megatron) may want to overlap
# the all-gather of model weights and forward process, so the model weight is not updated
# currently.
model_weight._reset_caches()
if not manual_post_all_gather_processing:
# Reset transpose cache for all model weights.
# We cannot create the transpose cache here because users (like Megatron) may want
# to overlap the all-gather of model weights with the forward pass, so the model
# weight has not been updated yet at this point.
model_weight._reset_caches()
quantizer = model_weight._get_quantizer()
......@@ -225,7 +233,9 @@ def _cast_master_weights_to_fp8_delayed_scaling(params, group, use_fsdp_shard_mo
)
def _cast_master_weights_to_fp8_current_scaling(params, group, use_fsdp_shard_model_weights=False):
def _cast_master_weights_to_fp8_current_scaling(
params, group, use_fsdp_shard_model_weights=False, manual_post_all_gather_processing=False
):
r"""Helper function to cast master weights to FP8 primary weights for current scaling.
Parameters
......@@ -305,11 +315,12 @@ def _cast_master_weights_to_fp8_current_scaling(params, group, use_fsdp_shard_mo
for (model_weight, master_weight, start_offset, model_weight_fragment), scale in zip(
params, scales
):
# Reset transpose cache for all model weights.
# We cannot create transpose cache here because users (like megatron) may want to overlap
# the all-gather of model weights and forward process, so the model weight is not updated
# currently.
model_weight._reset_caches()
if not manual_post_all_gather_processing:
# Reset transpose cache for all model weights.
# We cannot create the transpose cache here because users (like Megatron) may want
# to overlap the all-gather of model weights with the forward pass, so the model
# weight has not been updated yet at this point.
model_weight._reset_caches()
# If master weight is None, it means that the master weight of the current model weight
# is in other DP ranks.
......@@ -336,7 +347,7 @@ def _cast_master_weights_to_fp8_current_scaling(params, group, use_fsdp_shard_mo
def _cast_master_weights_to_fp8_blockwise_scaling(
params, group, use_fsdp_shard_model_weights=False
params, group, use_fsdp_shard_model_weights=False, manual_post_all_gather_processing=False
):
r"""Helper function to cast master weights to FP8 primary weights for blockwise scaling.
......@@ -437,11 +448,12 @@ def _cast_master_weights_to_fp8_blockwise_scaling(
for (model_weight, master_weight, start_offset, model_weight_fragment), scale in zip(
params, scales
):
# Clear columnwise data for all model weights.
# We cannot create columnwise data here because users (like megatron) may want to overlap
# the all-gather of model weights and forward process, so the model weight is not updated
# at this moment.
model_weight.update_usage(rowwise_usage=True, columnwise_usage=False)
if not manual_post_all_gather_processing:
# Clear columnwise data for all model weights.
# We cannot create columnwise data here because users (like Megatron) may want
# to overlap the all-gather of model weights with the forward pass, so the model
# weight has not been updated yet at this point.
model_weight.update_usage(rowwise_usage=True, columnwise_usage=False)
# If master weight is None, it means that the master weight of the current model weight
# is in other DP ranks.
......@@ -459,18 +471,35 @@ def _cast_master_weights_to_fp8_blockwise_scaling(
)
def is_experimental(x: Optional[Union[Quantizer, QuantizedTensorStorage]] = None) -> bool:
"""Check if an environment or object is using experimental Kitchen middleware.
def post_all_gather_processing(model_weights: Union[torch.Tensor, List[torch.Tensor]]):
"""
Post-processing after all-gather for weights in distributed optimizer.
- Float8Tensor: may need to create a transposed view to match the backend GEMM.
- Float8BlockwiseQTensor: create column-wise storage.
- Plain PyTorch tensor: no-op.
"""
if not isinstance(model_weights, list):
model_weights = [model_weights]
for model_weight in model_weights:
if isinstance(model_weight, Float8Tensor):
# Delayed scaling and per-tensor current scaling: if backend does not support
# non-transposed FP8 GEMM, pre-create the transpose.
if not is_non_tn_fp8_gemm_supported():
model_weight._create_transpose()
elif isinstance(model_weight, Float8BlockwiseQTensor):
# Blockwise scaling: create column-wise storage.
model_weight._create_columnwise()
elif isinstance(model_weight, QuantizedTensor):
raise ValueError(f"post_processing for {type(model_weight)} is not supported")
def is_custom(x: Optional[Union[Quantizer, QuantizedTensorStorage]] = None) -> bool:
"""Check if an object is custom.
Returns False if x is None or a torch.Tensor.
"""
# Detect if the environment is experimental
if x is None:
return int(os.getenv("QAT_PARAMS", "0")) > 0
# Detect if the object is experimental
if isinstance(x, torch.Tensor):
if x is None or isinstance(x, torch.Tensor):
return False
if not isinstance(x, (Quantizer, QuantizedTensorStorage)):
raise AssertionError("Object must be a Quantizer or QuantizedTensorStorage instance")
return hasattr(x, "experimental") and x.experimental
return hasattr(x, "custom") and x.custom
......@@ -176,7 +176,12 @@ class TransformerLayer(torch.nn.Module):
activation : str, default = 'gelu'
Type of activation used in MLP block.
Options are: 'gelu', 'geglu', 'qgelu', 'qgeglu', 'relu', 'reglu', 'srelu', 'sreglu',
'silu', and 'swiglu'.
'silu', 'swiglu', and 'clamped_swiglu'.
activation_params : Optional[dict], default = `None`
Additional parameters for the activation function.
At the moment, this is only used for the 'clamped_swiglu' activation, which
supports the 'limit' and 'alpha' parameters. You can set these as
`activation_params={'limit': 7.0, 'alpha': 1.702}`.
device : Union[torch.device, str], default = "cuda"
The device on which the parameters of the model will be allocated. It is the user's
responsibility to ensure all parameters are moved to the GPU before running the
......@@ -310,6 +315,7 @@ class TransformerLayer(torch.nn.Module):
ub_bulk_wgrad: bool = True,
bias: bool = True,
activation: str = "gelu",
activation_params: Optional[dict] = None,
normalization: str = "LayerNorm",
device: Union[torch.device, str] = "cuda",
attn_input_format: str = "sbhd",
......@@ -475,6 +481,7 @@ class TransformerLayer(torch.nn.Module):
ub_overlap_rs=ub_overlap_rs,
ub_overlap_ag=ub_overlap_ag,
activation=activation,
activation_params=activation_params,
normalization=normalization,
device=device,
name=name + ".layernorm_mlp" if name is not None else None,
......
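For illustration, constructing a layer with the new activation and its parameters (a minimal sketch; the sizes are arbitrary):

import transformer_engine.pytorch as te

layer = te.TransformerLayer(
    hidden_size=1024,
    ffn_hidden_size=4096,
    num_attention_heads=16,
    activation="clamped_swiglu",
    activation_params={"limit": 7.0, "alpha": 1.702},  # values from the docstring
)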
......@@ -2,4 +2,4 @@
#
# See LICENSE for license information.
"""Kernels written with OpenAI Triton."""
"""PyTorch wrappers for Triton kernels."""
......@@ -2,7 +2,7 @@
#
# See LICENSE for license information.
"""Efficient Cross Entropy kernels written with OpenAI Triton."""
"""PyTorch wrapper functions for Cross Entropy Triton kernels."""
from typing import Union
from functools import reduce
......@@ -13,257 +13,17 @@ import torch
import torch.distributed as dist
import triton
import triton.language as tl
@triton.jit
def online_softmax_kernel(
X_ptr,
X_stride,
Y_ptr,
Y_stride,
m_d_X_y_ptr,
m_d_X_y_stride,
rank,
n_cols,
BLOCK_SIZE: tl.constexpr,
):
"""
This kernel computes the m/d components on this TP rank for the online softmax.
Parameters:
X_ptr: Pointer to input tensor.
X_stride (int): The stride of the input tensor.
Y_ptr: Pointer to target tensor.
Y_stride (int): The stride of the target tensor.
m_d_X_y_ptr: Pointer to m/d/X_y tensor.
m_d_X_y_stride (int): The stride of the m/d/X_y tensor.
rank (int): The rank of this device in the TP group.
n_cols (int): The number of columns in the input tensor.
BLOCK_SIZE (int): The block size for Triton operations.
"""
program_id = tl.program_id(0).to(tl.int64)
# locate the start index
X_ptr += program_id * X_stride
# Load Y_ptr
Y_ptr += program_id * Y_stride
y = tl.load(Y_ptr)
vocab_start_idx = rank * n_cols
vocab_end_idx = (rank + 1) * n_cols
if y >= vocab_start_idx:
if y < vocab_end_idx:
X_y = tl.load(X_ptr + y - vocab_start_idx).to(tl.float32)
else:
X_y = float("-inf")
else:
X_y = float("-inf")
m_d_X_y_ptr += program_id * m_d_X_y_stride * 3
# 3. [Online softmax] first pass: find max + sum
m = float("-inf") # m is the max value. use the notation from the paper
d = 0.0 # d is the sum. use the notation from the paper
for i in range(0, n_cols, BLOCK_SIZE):
X_offsets = i + tl.arange(0, BLOCK_SIZE)
X_block = tl.load(X_ptr + X_offsets, mask=X_offsets < n_cols, other=float("-inf")).to(
tl.float32
)
block_max = tl.max(X_block)
m_new = tl.maximum(m, block_max)
d = d * tl.exp(m - m_new) + tl.sum(tl.exp(X_block - m_new))
m = m_new
tl.store(m_d_X_y_ptr, m)
tl.store(m_d_X_y_ptr + m_d_X_y_stride, d)
tl.store(m_d_X_y_ptr + (2 * m_d_X_y_stride), X_y)
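The m/d recurrence in the loop above is the standard online-softmax update. A single-row eager reference of the same pass, useful for checking the kernel numerically (a sketch; the block size is arbitrary):

import math
import torch

def online_softmax_stats(x: torch.Tensor, block_size: int = 1024):
    """Running max m and rescaled running sum d, as in the kernel's loop."""
    m, d = float("-inf"), 0.0
    for i in range(0, x.numel(), block_size):
        block = x[i : i + block_size].float()
        m_new = max(m, block.max().item())
        # Rescale the accumulated sum to the new max before adding this block.
        d = d * math.exp(m - m_new) + torch.exp(block - m_new).sum().item()
        m = m_new
    return m, d  # afterwards, torch.exp(x - m).sum() == d up to rounding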
@triton.jit
def cross_entropy_kernel(
X_ptr,
X_stride,
Y_ptr,
Y_stride,
loss_ptr,
loss_stride,
m_d_X_y_ptr,
m_d_X_y_stride,
rank,
world_size,
ignore_idx,
n_cols,
n_non_ignore,
reduce_loss: tl.constexpr,
label_smoothing: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
"""
This kernel computes both cross entropy loss and the gradient of the input.
Parameters:
X_ptr: Pointer to input tensor.
X_stride (int): The stride of the input tensor.
Y_ptr: Pointer to target tensor.
Y_stride (int): The stride of the target tensor.
loss_ptr: Pointer to tensor to store the loss.
loss_stride (int): The stride of the loss tensor.
m_d_X_y_ptr: Pointer to m/d/X_y tensor.
m_d_X_y_stride: The stride of m/d/X_y tensor.
rank (int): The rank of this device in the TP group.
world_size (int): The size of world involved in this distributed loss calculation.
ignore_idx (int): Token index to be ignored in the loss and gradient calculation.
n_cols (int): The number of columns in the input tensor.
n_non_ignore (int): The number of non-ignored elements in the batch.
reduce_loss (bool): Whether the caller mean-reduces the loss; changes the gradient scaling (see below).
label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
BLOCK_SIZE (int): The block size for Triton operations.
"""
program_id = tl.program_id(0).to(tl.int64)
# locate the start index
X_ptr += program_id * X_stride
# Load Y_ptr
Y_ptr += program_id * Y_stride
y = tl.load(Y_ptr)
if y == ignore_idx:
# set all X_ptr as 0
for i in range(0, n_cols, BLOCK_SIZE):
X_offsets = i + tl.arange(0, BLOCK_SIZE)
tl.store(X_ptr + X_offsets, 0.0, mask=X_offsets < n_cols)
return
loss_ptr += program_id * loss_stride
m_d_X_y_ptr += program_id * 3 * m_d_X_y_stride
# Need to reduce the m/d/X_y values from other TP ranks
m = tl.load(m_d_X_y_ptr)
d = tl.load(m_d_X_y_ptr + m_d_X_y_stride)
ori_X_y = tl.load(m_d_X_y_ptr + (2 * m_d_X_y_stride))
for i in range(1, world_size):
offset = i * 3 * n_non_ignore * m_d_X_y_stride
access_ptr = m_d_X_y_ptr + offset
m_new = tl.load(access_ptr)
d_new = tl.load(access_ptr + m_d_X_y_stride)
X_y_new = tl.load(access_ptr + (2 * m_d_X_y_stride))
d = d * tl.exp(m - tl.maximum(m, m_new)) + d_new * tl.exp(m_new - tl.maximum(m, m_new))
m = tl.maximum(m, m_new)
ori_X_y = tl.maximum(ori_X_y, X_y_new)
# Label smoothing is a general case of normal cross entropy
scaled_x_sum = 0.0
eps = label_smoothing / (n_cols * world_size)
# 4. [Online softmax] second pass: calculate the gradients
# dx_y = (softmax(x_y) - 1) / N
# dx_i = softmax(x_i) / N, i != y
# N is the number of non ignored elements in the batch
# For label smoothing:
# dx_i = (softmax(x_y) - label_smoothing / V) / N, V = n_cols, i != y
# dx_y = (softmax(x_y) - label_smoothing / V - (1 - label_smoothing)) / N
# = dx_i - (1 - label_smoothing) / N
for i in range(0, n_cols, BLOCK_SIZE):
X_offsets = i + tl.arange(0, BLOCK_SIZE)
X_block = tl.load(X_ptr + X_offsets, mask=X_offsets < n_cols, other=float("-inf"))
grad_dtype = X_block.dtype
X_block = X_block.to(tl.float32)
if label_smoothing > 0:
# scale X beforehand to avoid overflow
scaled_x_sum += tl.sum(tl.where(X_offsets < n_cols, -eps * X_block, 0.0))
# Scale gradients based on reduction mode
# For reduce_loss=True: PyTorch will scale by 1/n_rows, so we need to scale by n_rows/n_non_ignore
# For reduce_loss=False: No additional scaling from PyTorch, so we don't scale here
if reduce_loss:
X_block = (tl.exp(X_block - m) / d - eps) / (n_non_ignore)
else:
X_block = tl.exp(X_block - m) / d - eps
tl.store(X_ptr + X_offsets, X_block.to(grad_dtype), mask=X_offsets < n_cols)
# We need tl.debug_barrier() to ensure the new result of X_ptr is written
tl.debug_barrier()
# 5. Calculate the loss
# loss = -log(softmax(X_y)), where
# log(softmax(X_y)) = log(e^(X_y - max(X)) / sum(e^(X - max(X)))) = (X_y - max(X)) - log(sum(e^(X - max(X))))
loss = -(ori_X_y - m - tl.log(d))
# Original loss = H(q, p); with label smoothing regularization it becomes H(q', p), where eps = label_smoothing / V
# H(q', p) = (1 - label_smoothing) * H(q, p) + label_smoothing * H(u, p)
# = (1 - label_smoothing) * H(q, p) + eps * sum(logsoftmax(x_i))
# By using m (global max of xi) and d (sum of e^(xi-m)), we can simplify as:
# = (1 - label_smoothing) * H(q, p) + (-sum(x_i * eps) + label_smoothing * (m + logd))
# Refer to H(q', p) in section 7 of the paper: https://arxiv.org/pdf/1512.00567
if label_smoothing > 0:
smooth_loss = scaled_x_sum + label_smoothing * (m + tl.log(d))
loss = loss * (1 - label_smoothing) + smooth_loss
# 6. Specially handle the i == y case, where `dx_y = (softmax(x_y) - eps - (1 - label_smoothing)) / N`
vocab_start_idx = rank * n_cols
vocab_end_idx = (rank + 1) * n_cols
if y >= vocab_start_idx:
if y < vocab_end_idx:
X_y = tl.load(X_ptr + y - vocab_start_idx)
# Apply the same conditional scaling logic for the target token
if reduce_loss:
X_y += -(1 - label_smoothing) / (n_non_ignore)
else:
X_y += -(1 - label_smoothing)
tl.store(X_ptr + y - vocab_start_idx, X_y)
tl.store(loss_ptr, loss)
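The gradient and label-smoothing formulas in the comments above can be checked against a one-row eager reference (a sketch for the reduce_loss=True case; n_non_ignore is N in the comments):

import torch

def reference_ce_row(x: torch.Tensor, y: int, label_smoothing: float = 0.0,
                     n_non_ignore: int = 1):
    """Eager single-row reference for the loss/gradient comments above."""
    logp = torch.log_softmax(x.float(), dim=0)
    eps = label_smoothing / x.numel()
    # dx_i = (softmax(x_i) - eps) / N, with an extra -(1 - label_smoothing) / N at i == y
    grad = (torch.softmax(x.float(), dim=0) - eps) / n_non_ignore
    grad[y] -= (1.0 - label_smoothing) / n_non_ignore
    # H(q', p) = (1 - label_smoothing) * H(q, p) + label_smoothing * H(u, p)
    loss = (1.0 - label_smoothing) * (-logp[y]) + label_smoothing * (-logp.mean())
    return loss, grad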
from transformer_engine.common.triton.cross_entropy import (
online_softmax_kernel,
cross_entropy_kernel,
element_mul_kernel,
)
# The optimal maximum block size depends on your hardware, your kernel, and your dtype
MAX_FUSED_SIZE = 65536 // 2
@triton.jit
def element_mul_kernel(
X_ptr,
X_stride,
grad_output_ptr,
grad_output_stride,
n_cols,
BLOCK_SIZE: tl.constexpr,
):
"""
This kernel multiplies each element of the tensor pointed to by X_ptr with the value pointed to by grad_output_ptr.
The multiplication is performed in-place on the tensor pointed to by X_ptr.
Parameters:
X_ptr: Pointer to the input tensor.
X_stride (int): The stride of the input tensor.
grad_output_ptr: Pointer to the gradient output value.
grad_output_stride (int): The stride of the gradient output tensor.
n_cols (int): The number of columns in the input tensor.
BLOCK_SIZE (int): The block size for Triton operations.
"""
# Get the program ID and convert it to int64 to avoid overflow
program_id = tl.program_id(0).to(tl.int64)
# Locate the start index
X_ptr += program_id * X_stride
# Load the gradient output value
grad_output_ptr += program_id * grad_output_stride
grad_output = tl.load(grad_output_ptr)
# Perform the element-wise multiplication
for i in range(0, n_cols, BLOCK_SIZE):
X_offsets = i + tl.arange(0, BLOCK_SIZE)
X_block = tl.load(X_ptr + X_offsets, mask=X_offsets < n_cols)
tl.store(X_ptr + X_offsets, X_block * grad_output, mask=X_offsets < n_cols)
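Functionally the kernel above is an in-place row-wise scale; an eager equivalent for a 2-D X with one gradient value per row (a sketch):

import torch

def element_mul_reference(X: torch.Tensor, grad_output: torch.Tensor) -> None:
    """In-place row-wise scale, matching the kernel's behavior."""
    X.mul_(grad_output.view(-1, 1))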
def cross_entropy_forward(
_input: torch.Tensor,
target: torch.Tensor,
......
......@@ -2,63 +2,12 @@
#
# See LICENSE for license information.
"""NVFP4 padding kernels
TODO(ksivamani): Documentation
"""
"""PyTorch wrapper functions for padding Triton kernels."""
import torch
import triton
import triton.language as tl
@triton.autotune(
configs=[
triton.Config({"BLOCK_M": 128, "BLOCK_N": 128}, num_warps=4, num_stages=2),
triton.Config({"BLOCK_M": 128, "BLOCK_N": 256}, num_warps=4, num_stages=2),
triton.Config({"BLOCK_M": 256, "BLOCK_N": 128}, num_warps=8, num_stages=2),
triton.Config({"BLOCK_M": 128, "BLOCK_N": 256}, num_warps=8, num_stages=1),
],
key=["out_dim0", "out_dim1"],
)
@triton.jit
def zero_pad_kernel(
inp_ptr,
out_ptr,
in_dim0: tl.constexpr,
in_dim1: tl.constexpr,
out_dim0: tl.constexpr,
out_dim1: tl.constexpr,
in_s0,
in_s1,
out_s0,
out_s1,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
):
"""Pads a tensor assuming it's a columnwise scaling inverse."""
# tile over OUTPUT coordinates
pid_m = tl.program_id(0)
pid_n = tl.program_id(1)
offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) # output rows
offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) # output cols
om = offs_m[:, None]
on = offs_n[None, :]
# edge masking for output
out_mask = (om < out_dim0) & (on < out_dim1)
# valid input region is simply top-left (no offsets)
in_mask = (om < in_dim0) & (on < in_dim1)
# load valid input, else zero (masked load touches memory only where True)
x = tl.load(inp_ptr + om * in_s0 + on * in_s1, mask=in_mask, other=0)
# store to output (only within bounds of the output tile)
tl.store(out_ptr + om * out_s0 + on * out_s1, x, mask=out_mask)
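An eager equivalent of the removed pad kernel, for reference (the Triton version fuses the zero-fill and copy into one pass; a sketch):

import torch
import torch.nn.functional as F

def zero_pad_reference(inp: torch.Tensor, out_dim0: int, out_dim1: int) -> torch.Tensor:
    """Zero-pad a 2-D tensor on the bottom/right to (out_dim0, out_dim1)."""
    return F.pad(inp, (0, out_dim1 - inp.size(1), 0, out_dim0 - inp.size(0)))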
from transformer_engine.common.triton.pad import zero_pad_kernel
def pad_columnwise_scale_inv(inp: torch.Tensor) -> torch.Tensor:
......
......@@ -2,192 +2,23 @@
#
# See LICENSE for license information.
"""Permutation kernels written with OpenAI Triton."""
"""PyTorch wrapper functions for Permutation Triton kernels."""
from typing import Union
import torch
import triton
import triton.language as tl
from triton.language import core
from triton.language.standard import _log2
# The following three argsort-related kernels are adapted from
# the issue https://github.com/triton-lang/triton/issues/3698
@triton.jit
def _compare_and_swap(x, indices, flip, i: tl.constexpr, n_dims: tl.constexpr):
n_outer: tl.constexpr = x.numel >> n_dims
shape: tl.constexpr = [n_outer * (2**i), 2, 2 ** (n_dims - i - 1)]
y = tl.reshape(x, shape)
z = tl.reshape(indices, shape)
mask = tl.arange(0, 2)[None, :, None]
l_value = tl.reshape(tl.broadcast_to(tl.sum(y * (1 - mask), 1)[:, None, :], shape), x.shape).to(
x.dtype
)
r_value = tl.reshape(tl.broadcast_to(tl.sum(y * mask, 1)[:, None, :], shape), x.shape).to(
x.dtype
)
l_indice = tl.reshape(tl.broadcast_to(tl.sum(z * (1 - mask), 1)[:, None, :], shape), x.shape)
r_indice = tl.reshape(tl.broadcast_to(tl.sum(z * mask, 1)[:, None, :], shape), x.shape)
idtype = core.get_int_dtype(bitwidth=x.dtype.primitive_bitwidth, signed=True)
il_value = l_value.to(idtype, bitcast=True)
ir_value = r_value.to(idtype, bitcast=True)
ix = x.to(idtype, bitcast=True)
flag1 = tl.where(((l_value > r_value) ^ flip) != 0, il_value ^ ir_value, tl.zeros_like(ix))
ret = ix ^ flag1
flag2 = tl.where(((l_value > r_value) ^ flip) != 0, l_indice ^ r_indice, tl.zeros_like(ix))
ind = indices ^ flag2
return ret.to(x.dtype, bitcast=True), ind
@triton.jit
def _bitonic_merge(x, indices, stage: tl.constexpr, order: tl.constexpr, n_dims: tl.constexpr):
n_outer: tl.constexpr = x.numel >> n_dims
tl.static_assert(stage <= n_dims)
"""
order_type 0 == ascending
order_type 1 == descending
order_type 2 == alternating
"""
if order == 2:
shape: tl.constexpr = [n_outer * (2 ** (n_dims - 1 - stage)), 2, 2**stage]
flip = tl.reshape(tl.broadcast_to(tl.arange(0, 2)[None, :, None], shape), x.shape)
else:
flip = tl.full(x.shape, value=order, dtype=tl.int32)
for i in tl.static_range(stage):
x, indices = _compare_and_swap(x, indices, flip, i + (n_dims - stage), n_dims)
return x, indices
@triton.jit
def _argsort(x, indices, n_dims: tl.constexpr):
for i in tl.static_range(1, n_dims + 1):
x, indices = _bitonic_merge(x, indices, i, 2 if i < n_dims else 1, n_dims)
return x, indices
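The network above is a bitonic sort over power-of-two tiles that carries indices alongside the values; per the order notes, the final merge uses order 1 (descending), so valid row ids sort ahead of the -1 padding in pass 3 below. An eager statement of that contract (a sketch):

import torch

def argsort_reference(x: torch.Tensor):
    """What _argsort computes per tile: values sorted descending, plus their indices."""
    return torch.sort(x, descending=True)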
@triton.jit
def _row_id_map_pass_1_kernel(
# pointers
routing_map_ptr,
row_id_map_ptr,
workspace_ptr,
# sizes
num_tokens,
# strides
stride_routing_map_token,
stride_routing_map_expert,
stride_row_id_map_token,
stride_row_id_map_expert,
# metas
BLOCK_SIZE: tl.constexpr,
):
pid_m = tl.program_id(0)
pid_n = tl.program_id(1)
offset = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
expert_token_mask = tl.load(
routing_map_ptr + pid_m * stride_routing_map_expert + offset * stride_routing_map_token,
mask=(offset < num_tokens),
other=0,
).to(tl.int32)
row_id_within_token_block = tl.cumsum(expert_token_mask) * expert_token_mask
tl.store(
row_id_map_ptr + pid_m * stride_row_id_map_expert + offset * stride_row_id_map_token,
row_id_within_token_block,
mask=offset < num_tokens,
)
n_tokens_per_block = tl.sum(expert_token_mask)
tl.store(workspace_ptr + pid_m * tl.cdiv(num_tokens, BLOCK_SIZE) + pid_n, n_tokens_per_block)
@triton.jit
def _row_id_map_pass_2_kernel(
# pointers
row_id_map_ptr,
workspace_ptr,
# sizes
num_tokens,
# strides
stride_row_id_map_token,
stride_row_id_map_expert,
# metas
WORKSPACE_LOAD_WIDTH: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
pid_m = tl.program_id(0)
pid_n = tl.program_id(1)
chunk_idx = pid_m * tl.cdiv(num_tokens, BLOCK_SIZE) + pid_n
offset = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
row_id_within_token_block = tl.load(
row_id_map_ptr + pid_m * stride_row_id_map_expert + offset * stride_row_id_map_token,
mask=(offset < num_tokens),
other=0,
)
workspace_off = tl.arange(0, WORKSPACE_LOAD_WIDTH)
n_tokens_per_chunk = tl.load(workspace_ptr + workspace_off, mask=workspace_off < chunk_idx)
row_id = tl.where(
row_id_within_token_block == 0,
-1,
row_id_within_token_block + tl.sum(n_tokens_per_chunk) - 1,
)
tl.store(
row_id_map_ptr + pid_m * stride_row_id_map_expert + offset * stride_row_id_map_token,
row_id,
mask=(offset < num_tokens),
)
@triton.jit
def _row_id_map_pass_3_kernel(
# pointers
row_id_map_ptr,
# sizes
num_experts: tl.constexpr,
# strides
stride_row_id_map_token,
stride_row_id_map_expert,
# metas
LOAD_SIZE: tl.constexpr,
):
pid = tl.program_id(0)
n_dims: tl.constexpr = _log2(LOAD_SIZE)
off = tl.arange(0, LOAD_SIZE)
row_id_map = tl.load(
row_id_map_ptr + pid * stride_row_id_map_token + stride_row_id_map_expert * off,
mask=off < num_experts,
other=-1,
)
n_routed = tl.sum(tl.where(row_id_map != -1, 1, 0))
indices = off
sorted_map, indices = _argsort(row_id_map, indices, n_dims=n_dims)
tl.store(
row_id_map_ptr + pid * stride_row_id_map_token + off * stride_row_id_map_expert,
sorted_map,
mask=off < n_routed,
)
tl.store(
row_id_map_ptr
+ pid * stride_row_id_map_token
+ (num_experts + off) * stride_row_id_map_expert,
indices,
mask=off < n_routed,
)
tl.store(
row_id_map_ptr + pid * stride_row_id_map_token + num_experts * 2 * stride_row_id_map_expert,
n_routed,
)
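Passes 1 and 2 together amount to an expert-major cumulative sum over the routing mask; pass 3 then compacts and argsorts each token's entries in place. An eager reference for the first two passes (a sketch; the full map also appends the sorted expert indices and the routed count per token):

import torch

def row_id_map_reference(routing_map: torch.Tensor) -> torch.Tensor:
    """Destination row per (expert, token); -1 where the token is not routed.

    routing_map: bool tensor of shape [num_tokens, num_experts].
    """
    mask = routing_map.t().long()  # [num_experts, num_tokens]
    row_id = torch.cumsum(mask.reshape(-1), 0).reshape(mask.shape)
    return torch.where(mask.bool(), row_id - 1, torch.full_like(row_id, -1))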
from transformer_engine.common.triton.permutation import (
_row_id_map_pass_1_kernel,
_row_id_map_pass_2_kernel,
_row_id_map_pass_3_kernel,
_permute_kernel,
_unpermute_kernel,
_unpermute_bwd_with_merging_probs_kernel,
_make_chunk_sort_map_kernel,
_sort_chunks_by_map_kernel,
)
def make_row_id_map(
......@@ -287,103 +118,6 @@ def make_row_id_map(
return row_id_map
@triton.jit
def _permute_kernel(
# pointers
input_ptr,
output_ptr,
row_id_map_ptr,
probs_ptr,
scale_ptr,
permuted_probs_ptr,
permuted_scale_ptr,
# sizes
num_experts: tl.constexpr,
hidden_size: tl.constexpr,
scale_hidden_dim,
# strides
stride_row_id_map_token,
stride_row_id_map_expert,
stride_input_token,
stride_input_hidden,
stride_output_token,
stride_output_hidden,
stride_probs_token,
stride_probs_expert,
stride_scale_token,
stride_scale_hidden,
stride_permuted_probs_token,
stride_permuted_scale_token,
stride_permuted_scale_hidden,
# metas
PERMUTE_PROBS: tl.constexpr,
PERMUTE_SCALE: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
pid_t = tl.program_id(0)
pid_h = tl.program_id(1)
cur_off = pid_h * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = cur_off < hidden_size
src_row = pid_t.to(tl.int64)
input_off = src_row * stride_input_token + cur_off * stride_input_hidden
inp = tl.load(input_ptr + input_off, mask=mask)
if PERMUTE_SCALE:
mask_scale = cur_off < scale_hidden_dim
scale_off = pid_t * stride_scale_token + cur_off * stride_scale_hidden
scale = tl.load(scale_ptr + scale_off, mask=mask_scale)
n_routed = tl.load(
row_id_map_ptr
+ pid_t * stride_row_id_map_token
+ num_experts * 2 * stride_row_id_map_expert
)
for idx in tl.range(n_routed):
dst_row = tl.load(
row_id_map_ptr + pid_t * stride_row_id_map_token + idx * stride_row_id_map_expert
).to(tl.int64)
output_off = dst_row * stride_output_token + cur_off * stride_output_hidden
if PERMUTE_SCALE:
permuted_scale_off = (
dst_row * stride_permuted_scale_token + cur_off * stride_permuted_scale_hidden
)
tl.store(permuted_scale_ptr + permuted_scale_off, scale, mask=mask_scale)
if PERMUTE_PROBS:
expert_idx = tl.load(
row_id_map_ptr
+ pid_t * stride_row_id_map_token
+ (num_experts + idx) * stride_row_id_map_expert
)
prob_off = pid_t * stride_probs_token + expert_idx * stride_probs_expert
prob = tl.load(probs_ptr + prob_off)
if pid_h == 0:
permuted_prob_off = dst_row * stride_permuted_probs_token
tl.store(permuted_probs_ptr + permuted_prob_off, prob)
if prob == 0.0:
# for routing_map padding
# dst_row != -1 and prob == 0.0 means that this slot is padded
tl.store(output_ptr + output_off, 0.0, mask=mask)
else:
tl.store(output_ptr + output_off, inp, mask=mask)
else:
tl.store(output_ptr + output_off, inp, mask=mask)
try:
_permute_kernel = triton.autotune(
configs=[
triton.Config({"BLOCK_SIZE": 64}),
triton.Config({"BLOCK_SIZE": 128}),
triton.Config({"BLOCK_SIZE": 256}),
triton.Config({"BLOCK_SIZE": 512}),
triton.Config({"BLOCK_SIZE": 1024}),
triton.Config({"BLOCK_SIZE": 2048}),
triton.Config({"BLOCK_SIZE": 4096}),
],
key=["hidden_size"],
)(_permute_kernel)
except RuntimeError:
pass
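An eager reference for the gather that the permute kernel performs (probs and scales omitted; reuses row_id_map_reference from the sketch above):

import torch

def permute_reference(inp: torch.Tensor, routing_map: torch.Tensor) -> torch.Tensor:
    """Copy each token to the destination row of every expert it is routed to."""
    row_id_map = row_id_map_reference(routing_map)
    out = inp.new_zeros(int(routing_map.sum()), inp.size(1))
    for expert in range(routing_map.size(1)):
        for token in range(routing_map.size(0)):
            dst = int(row_id_map[expert, token])
            if dst >= 0:
                out[dst] = inp[token]
    return out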
def permute_with_mask_map(
inp: torch.Tensor,
row_id_map: torch.Tensor,
......@@ -463,116 +197,6 @@ def permute_with_mask_map(
return output, permuted_scale, permuted_probs
@triton.jit
def _unpermute_kernel(
# pointers
input_ptr,
output_ptr,
row_id_map_ptr,
merging_probs_ptr,
permuted_probs_ptr,
unpermuted_probs_ptr,
# sizes
num_experts: tl.constexpr,
hidden_size: tl.constexpr,
# strides
stride_row_id_map_token,
stride_row_id_map_expert,
stride_input_token,
stride_input_hidden,
stride_output_token,
stride_output_hidden,
stride_merging_probs_token,
stride_merging_probs_expert,
stride_permuted_probs_token,
stride_unpermuted_probs_token,
stride_unpermuted_probs_expert,
# metas
PROBS_LOAD_WIDTH: tl.constexpr,
WITH_MERGING_PROBS: tl.constexpr,
PERMUTE_PROBS: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
data_type = input_ptr.dtype.element_ty
compute_type = tl.float32
pid_t = tl.program_id(0)
pid_h = tl.program_id(1)
current_offset = pid_h * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = current_offset < hidden_size
if PERMUTE_PROBS:
# write 0.0 to the unpermuted probs entries for experts that are not routed
if pid_h == 0:
map_load_off = tl.arange(0, PROBS_LOAD_WIDTH)
unpermuted_prob_off = (
pid_t * stride_unpermuted_probs_token
+ stride_unpermuted_probs_expert * map_load_off
)
tl.store(
unpermuted_probs_ptr + unpermuted_prob_off, 0.0, mask=map_load_off < num_experts
)
accumulator = tl.zeros((BLOCK_SIZE,), dtype=compute_type)
n_routed = tl.load(
row_id_map_ptr
+ pid_t * stride_row_id_map_token
+ num_experts * 2 * stride_row_id_map_expert
)
for idx in tl.range(n_routed):
src_row = tl.load(
row_id_map_ptr + pid_t * stride_row_id_map_token + idx * stride_row_id_map_expert
).to(tl.int64)
input_off = src_row * stride_input_token + current_offset * stride_input_hidden
inp = tl.load(input_ptr + input_off, mask=mask)
inp = inp.to(compute_type)
if WITH_MERGING_PROBS:
expert_idx = tl.load(
row_id_map_ptr
+ pid_t * stride_row_id_map_token
+ (num_experts + idx) * stride_row_id_map_expert
)
merging_prob_off = (
pid_t * stride_merging_probs_token + expert_idx * stride_merging_probs_expert
)
merging_prob = tl.load(merging_probs_ptr + merging_prob_off).to(compute_type)
inp *= merging_prob
accumulator += inp
if PERMUTE_PROBS:
if pid_h == 0:
expert_idx = tl.load(
row_id_map_ptr
+ pid_t * stride_row_id_map_token
+ (num_experts + idx) * stride_row_id_map_expert
)
unpermuted_prob_off = (
pid_t * stride_unpermuted_probs_token
+ expert_idx * stride_unpermuted_probs_expert
)
permuted_prob_off = src_row * stride_permuted_probs_token
prob = tl.load(permuted_probs_ptr + permuted_prob_off)
tl.store(unpermuted_probs_ptr + unpermuted_prob_off, prob)
accumulator = accumulator.to(data_type)
dst_row = pid_t.to(tl.int64)
output_off = dst_row * stride_output_token + current_offset * stride_output_hidden
tl.store(output_ptr + output_off, accumulator, mask=mask)
try:
_unpermute_kernel = triton.autotune(
configs=[
triton.Config({"BLOCK_SIZE": 64}),
triton.Config({"BLOCK_SIZE": 128}),
triton.Config({"BLOCK_SIZE": 256}),
triton.Config({"BLOCK_SIZE": 512}),
triton.Config({"BLOCK_SIZE": 1024}),
triton.Config({"BLOCK_SIZE": 2048}),
triton.Config({"BLOCK_SIZE": 4096}),
],
key=["hidden_size"],
)(_unpermute_kernel)
except RuntimeError:
pass
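The unpermute kernel accumulates in fp32; an eager reference with merging probs (reuses row_id_map_reference from the earlier sketch):

import torch

def unpermute_reference(permuted: torch.Tensor, routing_map: torch.Tensor,
                        merging_probs: torch.Tensor) -> torch.Tensor:
    """Each output token is the probability-weighted sum of its routed copies."""
    row_id_map = row_id_map_reference(routing_map)
    out = torch.zeros(routing_map.size(0), permuted.size(1), dtype=torch.float32)
    for expert in range(routing_map.size(1)):
        for token in range(routing_map.size(0)):
            src = int(row_id_map[expert, token])
            if src >= 0:
                out[token] += merging_probs[token, expert].float() * permuted[src].float()
    return out.to(permuted.dtype)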
def unpermute_with_mask_map(
inp: torch.Tensor,
row_id_map: torch.Tensor,
......@@ -639,110 +263,6 @@ def unpermute_with_mask_map(
return output, unpermuted_probs
@triton.jit
def _unpermute_bwd_with_merging_probs_kernel(
# pointers
fwd_output_grad_ptr,
fwd_input_grad_ptr,
fwd_input_ptr,
merging_probs_ptr,
merging_probs_grad_ptr,
row_id_map_ptr,
# sizes
num_experts: tl.constexpr,
hidden_size: tl.constexpr,
# strides
stride_row_id_map_token,
stride_row_id_map_expert,
stride_fwd_output_grad_token,
stride_fwd_output_grad_hidden,
stride_fwd_input_grad_token,
stride_fwd_input_grad_hidden,
stride_fwd_input_token,
stride_fwd_input_hidden,
stride_merging_probs_token,
stride_merging_probs_expert,
stride_merging_probs_grad_token,
stride_merging_probs_grad_expert,
# metas
PROBS_LOAD_WIDTH: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
data_type = fwd_output_grad_ptr.dtype.element_ty
compute_type = tl.float32
pid = tl.program_id(0)
map_load_off = tl.arange(0, PROBS_LOAD_WIDTH)
token_probs_grad_off = (
pid * stride_merging_probs_grad_token + stride_merging_probs_grad_expert * map_load_off
)
tl.store(merging_probs_grad_ptr + token_probs_grad_off, 0.0, mask=map_load_off < num_experts)
n_routed = tl.load(
row_id_map_ptr + pid * stride_row_id_map_token + num_experts * 2 * stride_row_id_map_expert
)
for idx in tl.range(n_routed):
dst_row = tl.load(
row_id_map_ptr + pid * stride_row_id_map_token + idx * stride_row_id_map_expert
).to(tl.int64)
expert_idx = tl.load(
row_id_map_ptr
+ pid * stride_row_id_map_token
+ (num_experts + idx) * stride_row_id_map_expert
)
prob_grad_accum = tl.zeros((BLOCK_SIZE,), dtype=compute_type)
current_start = 0
while current_start < hidden_size:
current_offset = current_start + tl.arange(0, BLOCK_SIZE)
mask = current_offset < hidden_size
src_row = pid.to(tl.int64)
input_off = (
src_row * stride_fwd_output_grad_token
+ current_offset * stride_fwd_output_grad_hidden
)
inp = tl.load(fwd_output_grad_ptr + input_off, mask=mask)
inp = inp.to(compute_type)
merging_prob_off = (
pid * stride_merging_probs_token + expert_idx * stride_merging_probs_expert
)
merging_prob = tl.load(merging_probs_ptr + merging_prob_off).to(compute_type)
output = inp * merging_prob
output = output.to(data_type)
output_off = (
dst_row * stride_fwd_input_grad_token
+ current_offset * stride_fwd_input_grad_hidden
)
tl.store(fwd_input_grad_ptr + output_off, output, mask=mask)
fwd_input_off = (
dst_row * stride_fwd_input_token + current_offset * stride_fwd_input_hidden
)
fwd_input = tl.load(fwd_input_ptr + fwd_input_off, mask=mask)
prob_grad_accum += fwd_input.to(compute_type) * inp
current_start += BLOCK_SIZE
probs_grad = tl.sum(prob_grad_accum).to(merging_probs_grad_ptr.dtype.element_ty)
probs_grad_off = (
pid * stride_merging_probs_grad_token + expert_idx * stride_merging_probs_grad_expert
)
tl.store(merging_probs_grad_ptr + probs_grad_off, probs_grad)
try:
_unpermute_bwd_with_merging_probs_kernel = triton.autotune(
configs=[
triton.Config({"BLOCK_SIZE": 64}),
triton.Config({"BLOCK_SIZE": 128}),
triton.Config({"BLOCK_SIZE": 256}),
triton.Config({"BLOCK_SIZE": 512}),
triton.Config({"BLOCK_SIZE": 1024}),
triton.Config({"BLOCK_SIZE": 2048}),
triton.Config({"BLOCK_SIZE": 4096}),
],
key=["hidden_size"],
)(_unpermute_bwd_with_merging_probs_kernel)
except RuntimeError:
pass
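The backward kernel above produces two gradients: the permuted activation grad (the incoming grad scaled by the merging prob) and the per-(token, expert) prob grad (a dot product over the hidden dimension). An eager reference (reuses row_id_map_reference from the earlier sketch):

import torch

def unpermute_bwd_reference(grad_out: torch.Tensor, fwd_input: torch.Tensor,
                            routing_map: torch.Tensor, merging_probs: torch.Tensor):
    """grad_in[row] = prob * grad_out[token]; probs_grad[token, e] = <fwd_input[row], grad_out[token]>."""
    row_id_map = row_id_map_reference(routing_map)
    grad_in = torch.zeros_like(fwd_input)
    probs_grad = torch.zeros_like(merging_probs)
    for expert in range(routing_map.size(1)):
        for token in range(routing_map.size(0)):
            row = int(row_id_map[expert, token])
            if row >= 0:
                grad_in[row] = merging_probs[token, expert] * grad_out[token]
                probs_grad[token, expert] = (
                    fwd_input[row].float() * grad_out[token].float()
                ).sum()
    return grad_in, probs_grad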
def unpermute_with_mask_map_bwd_with_merging_probs(
fwd_output_grad: torch.Tensor,
row_id_map: torch.Tensor,
......@@ -808,47 +328,6 @@ def unpermute_with_mask_map_bwd_with_merging_probs(
return act_grad, merging_probs_grad
@triton.jit
def _make_chunk_sort_map_kernel(
# pointers
split_sizes_ptr,
sorted_indices_ptr,
dst_rows_ptr,
# sizes
num_splits: tl.constexpr,
# metas
IDX_LOAD_WIDTH: tl.constexpr,
):
pid = tl.program_id(0)
load_split_offset = tl.arange(0, IDX_LOAD_WIDTH)
sorted_indices = tl.load(
sorted_indices_ptr + load_split_offset, mask=load_split_offset < num_splits
)
# get chunk idx of the current token in the input tensor
input_split_sizes = tl.load(
split_sizes_ptr + load_split_offset, mask=load_split_offset < num_splits, other=0
).to(tl.int32)
input_split_sizes_cumsum = tl.cumsum(input_split_sizes)
input_split_sizes_mask = tl.where(input_split_sizes_cumsum <= pid, 1, 0)
input_chunk_idx = tl.sum(input_split_sizes_mask)
input_split_sizes_presum = tl.sum(input_split_sizes * input_split_sizes_mask)
in_chunk_offset = pid - input_split_sizes_presum
# get chunk idx of the current token in the output tensor
output_chunk_mask = tl.where(sorted_indices == input_chunk_idx, 1, 0)
output_chunk_idx = tl.argmax(output_chunk_mask, axis=-1)
# make row_id_map
output_split_sizes = tl.load(
split_sizes_ptr + sorted_indices, mask=load_split_offset < num_splits
).to(tl.int32)
output_pre_split_sizes = tl.where(load_split_offset < output_chunk_idx, output_split_sizes, 0)
dst_row = tl.sum(output_pre_split_sizes) + in_chunk_offset
tl.store(dst_rows_ptr + pid, dst_row)
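The kernel above derives, for each input row, its source chunk, its offset within that chunk, and the start of that chunk's new position. An eager reference (a sketch):

import torch

def make_chunk_sort_map_reference(split_sizes: torch.Tensor,
                                  sorted_indices: torch.Tensor) -> torch.Tensor:
    """dst row per input row when chunk sorted_indices[i] becomes output chunk i."""
    sizes = split_sizes.long()
    in_starts = torch.cumsum(sizes, 0) - sizes  # exclusive prefix sum of input chunks
    out_sizes = sizes[sorted_indices]
    out_starts = torch.cumsum(out_sizes, 0) - out_sizes  # same, for output chunks
    dst = torch.empty(int(sizes.sum()), dtype=torch.long)
    for out_pos, chunk in enumerate(sorted_indices.tolist()):
        n, s = int(sizes[chunk]), int(in_starts[chunk])
        dst[s : s + n] = int(out_starts[out_pos]) + torch.arange(n)
    return dst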
def make_chunk_sort_map(
split_sizes: torch.Tensor,
sorted_indices: torch.Tensor,
......@@ -881,67 +360,6 @@ def make_chunk_sort_map(
return row_id_map
@triton.jit
def _sort_chunks_by_map_kernel(
# pointers
input_ptr,
output_ptr,
row_id_map_ptr,
probs_ptr,
permuted_probs_ptr,
# sizes
hidden_size: tl.constexpr,
# strides
stride_input_token,
stride_input_hidden,
stride_output_token,
stride_output_hidden,
stride_probs_token,
stride_permuted_probs_token,
# metas
PERMUTE_PROBS: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
FORWARD: tl.constexpr,
):
pid_t = tl.program_id(0)
pid_h = tl.program_id(1)
if FORWARD:
src_row = pid_t.to(tl.int64)
dst_row = tl.load(row_id_map_ptr + pid_t).to(tl.int64)
else:
src_row = tl.load(row_id_map_ptr + pid_t).to(tl.int64)
dst_row = pid_t.to(tl.int64)
current_offset = pid_h * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = current_offset < hidden_size
input_offsets = src_row * stride_input_token + current_offset * stride_input_hidden
output_offsets = dst_row * stride_output_token + current_offset * stride_output_hidden
inp = tl.load(input_ptr + input_offsets, mask=mask)
tl.store(output_ptr + output_offsets, inp, mask=mask)
if PERMUTE_PROBS:
if pid_h == 0:
prob_off = src_row * stride_probs_token
prob = tl.load(probs_ptr + prob_off)
permuted_prob_off = dst_row * stride_permuted_probs_token
tl.store(permuted_probs_ptr + permuted_prob_off, prob)
try:
_sort_chunks_by_map_kernel = triton.autotune(
configs=[
triton.Config({"BLOCK_SIZE": 64}),
triton.Config({"BLOCK_SIZE": 128}),
triton.Config({"BLOCK_SIZE": 256}),
triton.Config({"BLOCK_SIZE": 512}),
triton.Config({"BLOCK_SIZE": 1024}),
triton.Config({"BLOCK_SIZE": 2048}),
triton.Config({"BLOCK_SIZE": 4096}),
],
key=["hidden_size"],
)(_sort_chunks_by_map_kernel)
except RuntimeError:
pass
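In the forward direction the kernel above is a pure scatter by the precomputed map; an eager equivalent (a sketch):

import torch

def sort_chunks_by_map_reference(inp: torch.Tensor, dst_rows: torch.Tensor) -> torch.Tensor:
    """Forward direction: out[dst_rows[i]] = inp[i]."""
    out = torch.empty_like(inp)
    out[dst_rows.long()] = inp
    return out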
def sort_chunks_by_map(
inp: torch.Tensor,
row_id_map: torch.Tensor,
......
......@@ -12,7 +12,7 @@ import numpy as np
import torch
from . import torch_version
from .tensor.quantized_tensor import Quantizer
from .quantized_tensor import Quantizer
from torch.utils.cpp_extension import IS_HIP_EXTENSION
__all__ = ["get_device_compute_capability", "get_cudnn_version", "is_bf16_available"]
......