Unverified Commit 5ea83432 authored by Teddy Do's avatar Teddy Do Committed by GitHub
Browse files

Move Triton to common (#2359)



* move triton to common and change paths
Signed-off-by: tdophung <tdophung@nvidia.com>

* Formatting
Signed-off-by: tdophung <tdophung@nvidia.com>

---------
Signed-off-by: tdophung <tdophung@nvidia.com>
parent 3454f84d
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Kernels written with OpenAI Triton."""
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Efficient Cross Entropy kernels written with OpenAI Triton."""
import triton
import triton.language as tl
@triton.jit
def online_softmax_kernel(
    X_ptr,
    X_stride,
    Y_ptr,
    Y_stride,
    m_d_X_y_ptr,
    m_d_X_y_stride,
    rank,
    n_cols,
    BLOCK_SIZE: tl.constexpr,
):
    """
    This kernel computes the m/d components on this TP rank for the online softmax.

    Each program handles one row of the local logits shard and writes three
    values per row: the local running max (m), the local sum of exp(x - m) (d),
    and the target-class logit (X_y, or -inf when the target class lives on a
    different TP rank's vocabulary shard).

    Parameters:
        X_ptr: Pointer to input (logits) tensor.
        X_stride (int): The stride of the input tensor.
        Y_ptr: Pointer to target tensor.
        Y_stride (int): The stride of the target tensor.
        m_d_X_y_ptr: Pointer to m/d/X_y tensor.
        m_d_X_y_stride (int): The stride of the m/d/X_y tensor.
        rank (int): The rank of this device in the TP group.
        n_cols (int): The number of columns in the input tensor.
        BLOCK_SIZE (int): The block size for Triton operations.
    """
    # int64 program id avoids pointer-arithmetic overflow for large row*stride.
    program_id = tl.program_id(0).to(tl.int64)
    # locate the start index of this row
    X_ptr += program_id * X_stride
    # Load the target class index for this row
    Y_ptr += program_id * Y_stride
    y = tl.load(Y_ptr)
    # This rank owns vocabulary ids in [vocab_start_idx, vocab_end_idx).
    vocab_start_idx = rank * n_cols
    vocab_end_idx = (rank + 1) * n_cols
    if y >= vocab_start_idx:
        if y < vocab_end_idx:
            # Target lives on this shard: read its logit.
            X_y = tl.load(X_ptr + y - vocab_start_idx).to(tl.float32)
        else:
            X_y = float("-inf")
    else:
        # Target belongs to another rank; -inf is the identity for the
        # cross-rank max reduction performed later.
        X_y = float("-inf")
    # Each row stores a triple (m, d, X_y), hence the factor of 3.
    m_d_X_y_ptr += program_id * m_d_X_y_stride * 3
    # 3. [Online softmax] first pass: find max + sum
    m = float("-inf")  # m is the max value. use the notation from the paper
    d = 0.0  # d is the sum. use the notation from the paper
    for i in range(0, n_cols, BLOCK_SIZE):
        X_offsets = i + tl.arange(0, BLOCK_SIZE)
        X_block = tl.load(X_ptr + X_offsets, mask=X_offsets < n_cols, other=float("-inf")).to(
            tl.float32
        )
        block_max = tl.max(X_block)
        m_new = tl.maximum(m, block_max)
        # Rescale the running sum to the new max before adding this block.
        d = d * tl.exp(m - m_new) + tl.sum(tl.exp(X_block - m_new))
        m = m_new
    # Store the (m, d, X_y) triple for this row.
    tl.store(m_d_X_y_ptr, m)
    tl.store(m_d_X_y_ptr + m_d_X_y_stride, d)
    tl.store(m_d_X_y_ptr + (2 * m_d_X_y_stride), X_y)
@triton.jit
def cross_entropy_kernel(
    X_ptr,
    X_stride,
    Y_ptr,
    Y_stride,
    loss_ptr,
    loss_stride,
    m_d_X_y_ptr,
    m_d_X_y_stride,
    rank,
    world_size,
    ignore_idx,
    n_cols,
    n_non_ignore,
    reduce_loss: tl.constexpr,
    label_smoothing: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """
    This kernel computes both cross entropy loss and the gradient of the input.

    It consumes the per-rank (m, d, X_y) triples produced by
    `online_softmax_kernel` (already gathered from all TP ranks into one
    buffer) and reduces them into global softmax statistics before computing
    the loss and overwriting X in place with its gradient.

    Parameters:
        X_ptr: Pointer to input tensor; overwritten in place with the gradient.
        X_stride (int): The stride of the input tensor.
        Y_ptr: Pointer to target tensor.
        Y_stride (int): The stride of the target tensor.
        loss_ptr: Pointer to tensor to store the loss.
        loss_stride (int): The stride of the loss tensor.
        m_d_X_y_ptr: Pointer to m/d/X_y tensor.
        m_d_X_y_stride: The stride of m/d/X_y tensor.
        rank (int): The rank of this device in the TP group.
        world_size (int): The size of world involved in this distributed loss calculation.
        ignore_idx (int): Tokens to be ignored for loss and gradient calculation.
        n_cols (int): The number of columns in the input tensor.
        n_non_ignore (int): The number of non-ignored elements in the batch.
        reduce_loss (bool): Whether the caller averages the loss; controls
            whether gradients are pre-scaled by 1/n_non_ignore here.
        label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
        BLOCK_SIZE (int): The block size for Triton operations.
    """
    program_id = tl.program_id(0).to(tl.int64)
    # locate the start index
    X_ptr += program_id * X_stride
    # Load the target class for this row
    Y_ptr += program_id * Y_stride
    y = tl.load(Y_ptr)
    if y == ignore_idx:
        # Ignored token: zero the gradient row and skip loss computation.
        for i in range(0, n_cols, BLOCK_SIZE):
            X_offsets = i + tl.arange(0, BLOCK_SIZE)
            tl.store(X_ptr + X_offsets, 0.0, mask=X_offsets < n_cols)
        return
    loss_ptr += program_id * loss_stride
    m_d_X_y_ptr += program_id * 3 * m_d_X_y_stride
    # Need to reduce the m/d/X_y values from other TP ranks
    m = tl.load(m_d_X_y_ptr)
    d = tl.load(m_d_X_y_ptr + m_d_X_y_stride)
    ori_X_y = tl.load(m_d_X_y_ptr + (2 * m_d_X_y_stride))
    for i in range(1, world_size):
        # Rank i's triples sit n_non_ignore rows further along the buffer.
        offset = i * 3 * n_non_ignore * m_d_X_y_stride
        access_ptr = m_d_X_y_ptr + offset
        m_new = tl.load(access_ptr)
        d_new = tl.load(access_ptr + m_d_X_y_stride)
        X_y_new = tl.load(access_ptr + (2 * m_d_X_y_stride))
        # Merge two (max, sum) pairs by rescaling both sums to the joint max.
        d = d * tl.exp(m - tl.maximum(m, m_new)) + d_new * tl.exp(m_new - tl.maximum(m, m_new))
        m = tl.maximum(m, m_new)
        # Only the owning rank stored a finite X_y; the max picks it out.
        ori_X_y = tl.maximum(ori_X_y, X_y_new)
    # Label smoothing is a general case of normal cross entropy
    scaled_x_sum = 0.0
    eps = label_smoothing / (n_cols * world_size)
    # 4. [Online softmax] second pass: calculate the gradients
    # dx_y = (softmax(x_y) - 1) / N
    # dx_i = softmax(x_i) / N, i != y
    # N is the number of non ignored elements in the batch
    # For label smoothing:
    # dx_i = (softmax(x_y) - label_smoothing / V) / N, V = n_cols, i != y
    # dx_y = (softmax(x_y) - label_smoothing / V - (1 - label_smoothing)) / N
    #      = dx_i - (1 - label_smoothing) / N
    for i in range(0, n_cols, BLOCK_SIZE):
        X_offsets = i + tl.arange(0, BLOCK_SIZE)
        X_block = tl.load(X_ptr + X_offsets, mask=X_offsets < n_cols, other=float("-inf"))
        grad_dtype = X_block.dtype
        X_block = X_block.to(tl.float32)
        if label_smoothing > 0:
            # scale X beforehand to avoid overflow
            scaled_x_sum += tl.sum(tl.where(X_offsets < n_cols, -eps * X_block, 0.0))
        # Scale gradients based on reduction mode
        # For reduce_loss=True: PyTorch will scale by 1/n_rows, so we need to scale by n_rows/n_non_ignore
        # For reduce_loss=False: No additional scaling from PyTorch, so we don't scale here
        if reduce_loss:
            X_block = (tl.exp(X_block - m) / d - eps) / (n_non_ignore)
        else:
            X_block = tl.exp(X_block - m) / d - eps
        tl.store(X_ptr + X_offsets, X_block.to(grad_dtype), mask=X_offsets < n_cols)
    # We need tl.debug_barrier() to ensure the new result of X_ptr is written
    tl.debug_barrier()
    # 5. Calculate the loss
    # loss = log (softmax(X_y)) = log ((e ^ (X_y - max(X)) / sum(e ^ (X - max(X))))
    #      = (X_y - max(X)) - log(sum(e ^ (X - max(X))))
    loss = -(ori_X_y - m - tl.log(d))
    # Original loss = H(q, p), with label smoothing regularization = H(q', p) and (label_smoothing / V) = eps
    # H(q', p) = (1 - label_smoothing) * H(q, p) + label_smoothing * H(u, p)
    #          = (1 - label_smoothing) * H(q, p) + eps * sum(logsoftmax(x_i))
    # By using m (global max of xi) and d (sum of e^(xi-m)), we can simplify as:
    #          = (1 - label_smoothing) * H(q, p) + (-sum(x_i * eps) + label_smoothing * (m + logd))
    # Refer to H(q', p) in section 7 of the paper: https://arxiv.org/pdf/1512.00567
    if label_smoothing > 0:
        smooth_loss = scaled_x_sum + label_smoothing * (m + tl.log(d))
        loss = loss * (1 - label_smoothing) + smooth_loss
    # 6. Specially handle the i==y case where `dx_y = (softmax(x_y) - (1 - label_smoothing)) / N`
    vocab_start_idx = rank * n_cols
    vocab_end_idx = (rank + 1) * n_cols
    if y >= vocab_start_idx:
        if y < vocab_end_idx:
            X_y = tl.load(X_ptr + y - vocab_start_idx)
            # Apply the same conditional scaling logic for the target token
            if reduce_loss:
                X_y += -(1 - label_smoothing) / (n_non_ignore)
            else:
                X_y += -(1 - label_smoothing)
            tl.store(X_ptr + y - vocab_start_idx, X_y)
    tl.store(loss_ptr, loss)
@triton.jit
def element_mul_kernel(
    X_ptr,
    X_stride,
    grad_output_ptr,
    grad_output_stride,
    n_cols,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Scale one row of the tensor at X_ptr in place by the scalar value stored
    at the corresponding position of grad_output_ptr.

    Parameters:
        X_ptr: Pointer to the input tensor (modified in place).
        X_stride (int): The stride of the input tensor.
        grad_output_ptr: Pointer to the gradient output values.
        grad_output_stride (int): The stride of the gradient output tensor.
        n_cols (int): The number of columns in the input tensor.
        BLOCK_SIZE (int): The block size for Triton operations.
    """
    # int64 row index avoids pointer-arithmetic overflow on large tensors.
    row = tl.program_id(0).to(tl.int64)
    row_ptr = X_ptr + row * X_stride
    # Scalar multiplier for this row.
    grad_out = tl.load(grad_output_ptr + row * grad_output_stride)
    # Walk the row one BLOCK_SIZE chunk at a time and scale it in place.
    for start in range(0, n_cols, BLOCK_SIZE):
        cols = start + tl.arange(0, BLOCK_SIZE)
        in_bounds = cols < n_cols
        chunk = tl.load(row_ptr + cols, mask=in_bounds)
        tl.store(row_ptr + cols, chunk * grad_out, mask=in_bounds)
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Efficient NVFP4 padding kernels written with OpenAI Triton .
TODO(ksivamani): Documentation
"""
import triton
import triton.language as tl
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 128}, num_warps=4, num_stages=2),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 256}, num_warps=4, num_stages=2),
        triton.Config({"BLOCK_M": 256, "BLOCK_N": 128}, num_warps=8, num_stages=2),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 256}, num_warps=8, num_stages=1),
    ],
    key=["out_dim0", "out_dim1"],
)
@triton.jit
def zero_pad_kernel(
    inp_ptr,
    out_ptr,
    in_dim0: tl.constexpr,
    in_dim1: tl.constexpr,
    out_dim0: tl.constexpr,
    out_dim1: tl.constexpr,
    in_s0,
    in_s1,
    out_s0,
    out_s1,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Copy a 2D tensor into the top-left corner of a larger output tensor,
    zero-filling the rest (used to pad a columnwise scaling inverse)."""
    # Each program covers one BLOCK_M x BLOCK_N tile of OUTPUT coordinates.
    tile_r = tl.program_id(0)
    tile_c = tl.program_id(1)
    rows = (tile_r * BLOCK_M + tl.arange(0, BLOCK_M))[:, None]  # output rows
    cols = (tile_c * BLOCK_N + tl.arange(0, BLOCK_N))[None, :]  # output cols
    # Stay within the bounds of the output tensor.
    write_mask = (rows < out_dim0) & (cols < out_dim1)
    # Only the top-left in_dim0 x in_dim1 region has source data (no offsets).
    read_mask = (rows < in_dim0) & (cols < in_dim1)
    # Masked load yields 0 outside the valid input region.
    vals = tl.load(inp_ptr + rows * in_s0 + cols * in_s1, mask=read_mask, other=0)
    tl.store(out_ptr + rows * out_s0 + cols * out_s1, vals, mask=write_mask)
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Efficient Permutation kernels written with OpenAI Triton."""
import triton
import triton.language as tl
from triton.language import core
from triton.language.standard import _log2
from packaging import version
# The following three argsort related kernels are adapted from
# the issue https://github.com/triton-lang/triton/issues/3698
# Helper mapping a bitwidth to a Triton integer dtype; used for bitcast-based
# XOR swaps in the argsort kernels below.  On Triton >= 3.5.0 it is wrapped
# with `constexpr_function` so it remains callable from @triton.jit code —
# NOTE(review): presumably tracking an upstream API change; confirm against
# the Triton 3.5 release notes.
get_int_dtype = core.get_int_dtype
if version.parse(triton.__version__) >= version.parse("3.5.0"):
    get_int_dtype = triton.constexpr_function(get_int_dtype)
@triton.jit
def _compare_and_swap(x, indices, flip, i: tl.constexpr, n_dims: tl.constexpr):
    """One compare-and-swap round of a bitonic sort.

    Compares partner elements at distance 2**(n_dims - i - 1) along the sorted
    axis and exchanges both values and their `indices` wherever
    (left > right) XOR flip holds.  Swaps are done branchlessly by XOR-ing
    integer bit patterns.
    """
    n_outer: tl.constexpr = x.numel >> n_dims
    # Middle axis of size 2 separates each pair's left/right partners.
    shape: tl.constexpr = [n_outer * (2**i), 2, 2 ** (n_dims - i - 1)]
    y = tl.reshape(x, shape)
    z = tl.reshape(indices, shape)
    # Selects the "right" element of each partner pair along the middle axis.
    mask = tl.arange(0, 2)[None, :, None]
    # Broadcast each pair's left/right value (and index) to both slots.
    l_value = tl.reshape(tl.broadcast_to(tl.sum(y * (1 - mask), 1)[:, None, :], shape), x.shape).to(
        x.dtype
    )
    r_value = tl.reshape(tl.broadcast_to(tl.sum(y * mask, 1)[:, None, :], shape), x.shape).to(
        x.dtype
    )
    l_indice = tl.reshape(tl.broadcast_to(tl.sum(z * (1 - mask), 1)[:, None, :], shape), x.shape)
    r_indice = tl.reshape(tl.broadcast_to(tl.sum(z * mask, 1)[:, None, :], shape), x.shape)
    # Bitcast to integers of the same width so the XOR swap works for any dtype.
    idtype = get_int_dtype(bitwidth=x.dtype.primitive_bitwidth, signed=True)
    il_value = l_value.to(idtype, bitcast=True)
    ir_value = r_value.to(idtype, bitcast=True)
    ix = x.to(idtype, bitcast=True)
    # x ^ (l ^ r) swaps l and r exactly where the predicate fires (else ^ 0).
    flag1 = tl.where(((l_value > r_value) ^ flip) != 0, il_value ^ ir_value, tl.zeros_like(ix))
    ret = ix ^ flag1
    flag2 = tl.where(((l_value > r_value) ^ flip) != 0, l_indice ^ r_indice, tl.zeros_like(ix))
    ind = indices ^ flag2
    return ret.to(x.dtype, bitcast=True), ind
@triton.jit
def _bitonic_merge(x, indices, stage: tl.constexpr, order: tl.constexpr, n_dims: tl.constexpr):
    """Merge step of the bitonic sort: runs `stage` compare-and-swap rounds,
    carrying `indices` through the same swaps.  See order_type legend below.
    """
    n_outer: tl.constexpr = x.numel >> n_dims
    tl.static_assert(stage <= n_dims)
    """
    order_type 0 == ascending
    order_type 1 == descending
    order_type 2 == alternating
    """
    if order == 2:
        # Alternate sort direction per group of 2**stage elements to build
        # the bitonic sequence for the next stage.
        shape: tl.constexpr = [n_outer * (2 ** (n_dims - 1 - stage)), 2, 2**stage]
        flip = tl.reshape(tl.broadcast_to(tl.arange(0, 2)[None, :, None], shape), x.shape)
    else:
        # Uniform direction: 0 = ascending, 1 = descending.
        flip = tl.full(x.shape, value=order, dtype=tl.int32)
    for i in tl.static_range(stage):
        x, indices = _compare_and_swap(x, indices, flip, i + (n_dims - stage), n_dims)
    return x, indices
@triton.jit
def _argsort(x, indices, n_dims: tl.constexpr):
    """Bitonic argsort over the last `n_dims` log2-dimensions of `x`.

    Returns the sorted values and the correspondingly permuted `indices`.
    The final merge uses order 1 (descending), so the overall result is in
    descending order.
    """
    for i in tl.static_range(1, n_dims + 1):
        x, indices = _bitonic_merge(x, indices, i, 2 if i < n_dims else 1, n_dims)
    return x, indices
@triton.jit
def _row_id_map_pass_1_kernel(
    # pointers
    routing_map_ptr,
    row_id_map_ptr,
    workspace_ptr,
    # sizes
    num_tokens,
    # strides
    stride_routing_map_token,
    stride_routing_map_expert,
    stride_row_id_map_token,
    stride_row_id_map_expert,
    # metas
    BLOCK_SIZE: tl.constexpr,
):
    """Pass 1 of row-id-map construction.

    For one (expert, token-block) tile: store each routed token's 1-based
    position within the block into the row-id map (0 for unrouted tokens), and
    store the block's routed-token count into the workspace, one entry per
    (expert, block) chunk.
    """
    pid_m = tl.program_id(0)  # expert index
    pid_n = tl.program_id(1)  # token-block index
    offset = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    # 0/1 routing decisions of this block's tokens for expert pid_m.
    expert_token_mask = tl.load(
        routing_map_ptr + pid_m * stride_routing_map_expert + offset * stride_routing_map_token,
        mask=(offset < num_tokens),
        other=0,
    ).to(tl.int32)
    # Cumulative sum yields 1-based ranks; multiplying by the mask zeroes
    # the positions of unrouted tokens.
    row_id_within_token_block = tl.cumsum(expert_token_mask) * expert_token_mask
    tl.store(
        row_id_map_ptr + pid_m * stride_row_id_map_expert + offset * stride_row_id_map_token,
        row_id_within_token_block,
        mask=offset < num_tokens,
    )
    # Per-chunk routed count for the prefix sums computed in pass 2.
    n_tokens_per_block = tl.sum(expert_token_mask)
    tl.store(workspace_ptr + pid_m * tl.cdiv(num_tokens, BLOCK_SIZE) + pid_n, n_tokens_per_block)
@triton.jit
def _row_id_map_pass_2_kernel(
    # pointers
    row_id_map_ptr,
    workspace_ptr,
    # sizes
    num_tokens,
    # strides
    stride_row_id_map_token,
    stride_row_id_map_expert,
    # metas
    WORKSPACE_LOAD_WIDTH: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """Pass 2 of row-id-map construction: per-block ids -> global row ids.

    For one (expert, token-block) tile, adds the total routed-token count of
    all preceding chunks (from the workspace filled by pass 1) to each
    within-block row id, and writes -1 for tokens not routed to this expert.
    """
    pid_m = tl.program_id(0)  # expert index
    pid_n = tl.program_id(1)  # token-block index
    # Linear index of this chunk in the (expert, block) workspace layout.
    chunk_idx = pid_m * tl.cdiv(num_tokens, BLOCK_SIZE) + pid_n
    offset = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    row_id_within_token_block = tl.load(
        row_id_map_ptr + pid_m * stride_row_id_map_expert + offset * stride_row_id_map_token,
        mask=(offset < num_tokens),
        other=0,
    )
    workspace_off = tl.arange(0, WORKSPACE_LOAD_WIDTH)
    # Bug fix: without `other`, masked-off lanes of tl.load are undefined and
    # they feed the tl.sum below — force them to 0 so the prefix sum is exact.
    n_tokens_per_chunk = tl.load(
        workspace_ptr + workspace_off, mask=workspace_off < chunk_idx, other=0
    )
    row_id = tl.where(
        row_id_within_token_block == 0,
        -1,  # token not routed to this expert
        # 1-based within-block id + tokens routed in all preceding chunks - 1.
        row_id_within_token_block + tl.sum(n_tokens_per_chunk) - 1,
    )
    tl.store(
        row_id_map_ptr + pid_m * stride_row_id_map_expert + offset * stride_row_id_map_token,
        row_id,
        mask=(offset < num_tokens),
    )
@triton.jit
def _row_id_map_pass_3_kernel(
    # pointers
    row_id_map_ptr,
    # sizes
    num_experts: tl.constexpr,
    # strides
    stride_row_id_map_token,
    stride_row_id_map_expert,
    # metas
    LOAD_SIZE: tl.constexpr,
):
    """Pass 3 of row-id-map construction: compact each token's row in place.

    Loads the token's destination rows across all experts (-1 = not routed),
    argsorts them descending so valid rows come first, then rewrites the row as
    [sorted destination rows | originating expert indices | routed count],
    with the count stored at offset 2 * num_experts.
    """
    pid = tl.program_id(0)  # token index
    n_dims: tl.constexpr = _log2(LOAD_SIZE)
    off = tl.arange(0, LOAD_SIZE)
    row_id_map = tl.load(
        row_id_map_ptr + pid * stride_row_id_map_token + stride_row_id_map_expert * off,
        mask=off < num_experts,
        other=-1,
    )
    # Number of experts this token is routed to.
    n_routed = tl.sum(tl.where(row_id_map != -1, 1, 0))
    indices = off
    # Descending sort pushes valid (>= 0) row ids to the front; `indices`
    # tracks which expert slot produced each row id.
    sorted_map, indices = _argsort(row_id_map, indices, n_dims=n_dims)
    tl.store(
        row_id_map_ptr + pid * stride_row_id_map_token + off * stride_row_id_map_expert,
        sorted_map,
        mask=off < n_routed,
    )
    tl.store(
        row_id_map_ptr
        + pid * stride_row_id_map_token
        + (num_experts + off) * stride_row_id_map_expert,
        indices,
        mask=off < n_routed,
    )
    tl.store(
        row_id_map_ptr + pid * stride_row_id_map_token + num_experts * 2 * stride_row_id_map_expert,
        n_routed,
    )
@triton.jit
def _permute_kernel(
    # pointers
    input_ptr,
    output_ptr,
    row_id_map_ptr,
    probs_ptr,
    scale_ptr,
    permuted_probs_ptr,
    permuted_scale_ptr,
    # sizes
    num_experts: tl.constexpr,
    hidden_size: tl.constexpr,
    scale_hidden_dim,
    # strides
    stride_row_id_map_token,
    stride_row_id_map_expert,
    stride_input_token,
    stride_input_hidden,
    stride_output_token,
    stride_output_hidden,
    stride_probs_token,
    stride_probs_expert,
    stride_scale_token,
    stride_scale_hidden,
    stride_permuted_probs_token,
    stride_permuted_scale_token,
    stride_permuted_scale_hidden,
    # metas
    PERMUTE_PROBS: tl.constexpr,
    PERMUTE_SCALE: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """Scatter each token's hidden chunk to every expert row it is routed to.

    The row-id-map layout per token is [0, num_experts): destination rows,
    [num_experts, 2*num_experts): expert indices, and the routed count at
    offset 2*num_experts (built by the pass 1-3 kernels above).  Optionally
    also permutes per-token scales and probabilities; slots padded in the
    routing map (prob == 0.0) have their output zero-filled.
    """
    pid_t = tl.program_id(0)  # token index
    pid_h = tl.program_id(1)  # hidden-dim chunk index
    cur_off = pid_h * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = cur_off < hidden_size
    src_row = pid_t.to(tl.int64)
    input_off = src_row * stride_input_token + cur_off * stride_input_hidden
    inp = tl.load(input_ptr + input_off, mask=mask)
    if PERMUTE_SCALE:
        mask_scale = cur_off < scale_hidden_dim
        scale_off = pid_t * stride_scale_token + cur_off * stride_scale_hidden
        scale = tl.load(scale_ptr + scale_off, mask=mask_scale)
    # Number of experts this token is routed to.
    n_routed = tl.load(
        row_id_map_ptr
        + pid_t * stride_row_id_map_token
        + num_experts * 2 * stride_row_id_map_expert
    )
    for idx in tl.range(n_routed):
        dst_row = tl.load(
            row_id_map_ptr + pid_t * stride_row_id_map_token + idx * stride_row_id_map_expert
        ).to(tl.int64)
        output_off = dst_row * stride_output_token + cur_off * stride_output_hidden
        if PERMUTE_SCALE:
            permuted_scale_off = (
                dst_row * stride_permuted_scale_token + cur_off * stride_permuted_scale_hidden
            )
            tl.store(permuted_scale_ptr + permuted_scale_off, scale, mask=mask_scale)
        if PERMUTE_PROBS:
            expert_idx = tl.load(
                row_id_map_ptr
                + pid_t * stride_row_id_map_token
                + (num_experts + idx) * stride_row_id_map_expert
            )
            prob_off = pid_t * stride_probs_token + expert_idx * stride_probs_expert
            prob = tl.load(probs_ptr + prob_off)
            if pid_h == 0:
                # Only one hidden-chunk program writes the scalar prob.
                permuted_prob_off = dst_row * stride_permuted_probs_token
                tl.store(permuted_probs_ptr + permuted_prob_off, prob)
            if prob == 0.0:
                # for routing_map padding
                # dst_row != -1 and prob == 0.0 means that this slot is padded
                tl.store(output_ptr + output_off, 0.0, mask=mask)
            else:
                tl.store(output_ptr + output_off, inp, mask=mask)
        else:
            tl.store(output_ptr + output_off, inp, mask=mask)
# Wrap the kernel with an autotuner over BLOCK_SIZE; if autotune setup fails
# (RuntimeError), keep the raw kernel as a best-effort fallback.
try:
    _permute_kernel = triton.autotune(
        configs=[
            triton.Config({"BLOCK_SIZE": block_size})
            for block_size in (64, 128, 256, 512, 1024, 2048, 4096)
        ],
        key=["hidden_size"],
    )(_permute_kernel)
except RuntimeError:
    pass
@triton.jit
def _unpermute_kernel(
    # pointers
    input_ptr,
    output_ptr,
    row_id_map_ptr,
    merging_probs_ptr,
    permuted_probs_ptr,
    unpermuted_probs_ptr,
    # sizes
    num_experts: tl.constexpr,
    hidden_size: tl.constexpr,
    # strides
    stride_row_id_map_token,
    stride_row_id_map_expert,
    stride_input_token,
    stride_input_hidden,
    stride_output_token,
    stride_output_hidden,
    stride_merging_probs_token,
    stride_merging_probs_expert,
    stride_permuted_probs_token,
    stride_unpermuted_probs_token,
    stride_unpermuted_probs_expert,
    # metas
    PROBS_LOAD_WIDTH: tl.constexpr,
    WITH_MERGING_PROBS: tl.constexpr,
    PERMUTE_PROBS: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """Un-permute expert rows back to token order.

    For token `pid_t` and hidden-dim chunk `pid_h`, sums the expert rows the
    token was routed to (optionally weighted by per-expert merging probs) and
    writes the result to the token's output row.  Accumulation is in fp32
    regardless of the data type.  When PERMUTE_PROBS is set, permuted probs
    are scattered back to (token, expert) layout as well.
    """
    data_type = input_ptr.dtype.element_ty
    compute_type = tl.float32
    pid_t = tl.program_id(0)  # token index
    pid_h = tl.program_id(1)  # hidden-dim chunk index
    current_offset = pid_h * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = current_offset < hidden_size
    if PERMUTE_PROBS:
        # write 0.0 to probs_grad that are not routed
        if pid_h == 0:
            map_load_off = tl.arange(0, PROBS_LOAD_WIDTH)
            unpermuted_prob_off = (
                pid_t * stride_unpermuted_probs_token
                + stride_unpermuted_probs_expert * map_load_off
            )
            tl.store(
                unpermuted_probs_ptr + unpermuted_prob_off, 0.0, mask=map_load_off < num_experts
            )
    accumulator = tl.zeros((BLOCK_SIZE,), dtype=compute_type)
    # Number of experts this token is routed to (see row-id-map layout).
    n_routed = tl.load(
        row_id_map_ptr
        + pid_t * stride_row_id_map_token
        + num_experts * 2 * stride_row_id_map_expert
    )
    for idx in tl.range(n_routed):
        src_row = tl.load(
            row_id_map_ptr + pid_t * stride_row_id_map_token + idx * stride_row_id_map_expert
        ).to(tl.int64)
        input_off = src_row * stride_input_token + current_offset * stride_input_hidden
        inp = tl.load(input_ptr + input_off, mask=mask)
        inp = inp.to(compute_type)
        if WITH_MERGING_PROBS:
            expert_idx = tl.load(
                row_id_map_ptr
                + pid_t * stride_row_id_map_token
                + (num_experts + idx) * stride_row_id_map_expert
            )
            merging_prob_off = (
                pid_t * stride_merging_probs_token + expert_idx * stride_merging_probs_expert
            )
            merging_prob = tl.load(merging_probs_ptr + merging_prob_off).to(compute_type)
            inp *= merging_prob
        accumulator += inp
        if PERMUTE_PROBS:
            if pid_h == 0:
                expert_idx = tl.load(
                    row_id_map_ptr
                    + pid_t * stride_row_id_map_token
                    + (num_experts + idx) * stride_row_id_map_expert
                )
                unpermuted_prob_off = (
                    pid_t * stride_unpermuted_probs_token
                    + expert_idx * stride_unpermuted_probs_expert
                )
                permuted_prob_off = src_row * stride_permuted_probs_token
                prob = tl.load(permuted_probs_ptr + permuted_prob_off)
                tl.store(unpermuted_probs_ptr + unpermuted_prob_off, prob)
    # Masked-off accumulator lanes may be garbage, but the final store is
    # masked with the same predicate, so they are never written.
    accumulator = accumulator.to(data_type)
    dst_row = pid_t.to(tl.int64)
    output_off = dst_row * stride_output_token + current_offset * stride_output_hidden
    tl.store(output_ptr + output_off, accumulator, mask=mask)
# Wrap the kernel with an autotuner over BLOCK_SIZE; if autotune setup fails
# (RuntimeError), keep the raw kernel as a best-effort fallback.
try:
    _unpermute_kernel = triton.autotune(
        configs=[
            triton.Config({"BLOCK_SIZE": block_size})
            for block_size in (64, 128, 256, 512, 1024, 2048, 4096)
        ],
        key=["hidden_size"],
    )(_unpermute_kernel)
except RuntimeError:
    pass
@triton.jit
def _unpermute_bwd_with_merging_probs_kernel(
    # pointers
    fwd_output_grad_ptr,
    fwd_input_grad_ptr,
    fwd_input_ptr,
    merging_probs_ptr,
    merging_probs_grad_ptr,
    row_id_map_ptr,
    # sizes
    num_experts: tl.constexpr,
    hidden_size: tl.constexpr,
    # strides
    stride_row_id_map_token,
    stride_row_id_map_expert,
    stride_fwd_output_grad_token,
    stride_fwd_output_grad_hidden,
    stride_fwd_input_grad_token,
    stride_fwd_input_grad_hidden,
    stride_fwd_input_token,
    stride_fwd_input_hidden,
    stride_merging_probs_token,
    stride_merging_probs_expert,
    stride_merging_probs_grad_token,
    stride_merging_probs_grad_expert,
    # metas
    PROBS_LOAD_WIDTH: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """Backward of the merging-prob-weighted unpermute.

    One program per token.  For each expert row the token was routed to:
      * input grad = output grad * merging_prob, scattered to the expert row;
      * merging-prob grad = sum over hidden dim of fwd_input * output grad,
        accumulated in fp32 over BLOCK_SIZE chunks.
    Experts the token is not routed to keep a 0.0 merging-prob grad.
    """
    data_type = fwd_output_grad_ptr.dtype.element_ty
    compute_type = tl.float32
    pid = tl.program_id(0)  # token index
    # Zero-initialize this token's merging-prob grads for all experts.
    map_load_off = tl.arange(0, PROBS_LOAD_WIDTH)
    token_probs_grad_off = (
        pid * stride_merging_probs_grad_token + stride_merging_probs_grad_expert * map_load_off
    )
    tl.store(merging_probs_grad_ptr + token_probs_grad_off, 0.0, mask=map_load_off < num_experts)
    # Number of experts this token is routed to (see row-id-map layout).
    n_routed = tl.load(
        row_id_map_ptr + pid * stride_row_id_map_token + num_experts * 2 * stride_row_id_map_expert
    )
    for idx in tl.range(n_routed):
        dst_row = tl.load(
            row_id_map_ptr + pid * stride_row_id_map_token + idx * stride_row_id_map_expert
        ).to(tl.int64)
        expert_idx = tl.load(
            row_id_map_ptr
            + pid * stride_row_id_map_token
            + (num_experts + idx) * stride_row_id_map_expert
        )
        prob_grad_accum = tl.zeros((BLOCK_SIZE,), dtype=compute_type)
        # Walk the hidden dimension in BLOCK_SIZE chunks.
        current_start = 0
        while current_start < hidden_size:
            current_offset = current_start + tl.arange(0, BLOCK_SIZE)
            mask = current_offset < hidden_size
            src_row = pid.to(tl.int64)
            input_off = (
                src_row * stride_fwd_output_grad_token
                + current_offset * stride_fwd_output_grad_hidden
            )
            # Bug fix: these loads feed prob_grad_accum, which is reduced with
            # tl.sum below; without `other=0.0` the masked-off lanes are
            # undefined and would pollute the merging-prob gradient.
            inp = tl.load(fwd_output_grad_ptr + input_off, mask=mask, other=0.0)
            inp = inp.to(compute_type)
            merging_prob_off = (
                pid * stride_merging_probs_token + expert_idx * stride_merging_probs_expert
            )
            merging_prob = tl.load(merging_probs_ptr + merging_prob_off).to(compute_type)
            # d(input) = d(output) * merging_prob, scattered to the expert row.
            output = inp * merging_prob
            output = output.to(data_type)
            output_off = (
                dst_row * stride_fwd_input_grad_token
                + current_offset * stride_fwd_input_grad_hidden
            )
            tl.store(fwd_input_grad_ptr + output_off, output, mask=mask)
            fwd_input_off = (
                dst_row * stride_fwd_input_token + current_offset * stride_fwd_input_hidden
            )
            fwd_input = tl.load(fwd_input_ptr + fwd_input_off, mask=mask, other=0.0)
            # Accumulate fwd_input . d(output) for the merging-prob gradient.
            prob_grad_accum += fwd_input.to(compute_type) * inp
            current_start += BLOCK_SIZE
        probs_grad = tl.sum(prob_grad_accum).to(merging_probs_grad_ptr.dtype.element_ty)
        probs_grad_off = (
            pid * stride_merging_probs_grad_token + expert_idx * stride_merging_probs_grad_expert
        )
        tl.store(merging_probs_grad_ptr + probs_grad_off, probs_grad)
# Wrap the kernel with an autotuner over BLOCK_SIZE; if autotune setup fails
# (RuntimeError), keep the raw kernel as a best-effort fallback.
try:
    _unpermute_bwd_with_merging_probs_kernel = triton.autotune(
        configs=[
            triton.Config({"BLOCK_SIZE": block_size})
            for block_size in (64, 128, 256, 512, 1024, 2048, 4096)
        ],
        key=["hidden_size"],
    )(_unpermute_bwd_with_merging_probs_kernel)
except RuntimeError:
    pass
@triton.jit
def _make_chunk_sort_map_kernel(
    # pointers
    split_sizes_ptr,
    sorted_indices_ptr,
    dst_rows_ptr,
    # sizes
    num_splits: tl.constexpr,
    # metas
    IDX_LOAD_WIDTH: tl.constexpr,
):
    """Build a per-token destination-row map for reordering chunks.

    One program per token (pid = source row).  Using the chunk split sizes and
    the desired chunk order (`sorted_indices`), computes where this token's
    row lands once the chunks are rearranged and stores it in `dst_rows_ptr`.
    """
    pid = tl.program_id(0)  # source row (token) index
    load_split_offset = tl.arange(0, IDX_LOAD_WIDTH)
    sorted_indices = tl.load(
        sorted_indices_ptr + load_split_offset, mask=load_split_offset < num_splits
    )
    # get chunk idx of the current token in the input tensor
    input_split_sizes = tl.load(
        split_sizes_ptr + load_split_offset, mask=load_split_offset < num_splits, other=0
    ).to(tl.int32)
    input_split_sizes_cumsum = tl.cumsum(input_split_sizes)
    # Count chunks whose cumulative size is <= pid: index of containing chunk.
    input_split_sizes_mask = tl.where(input_split_sizes_cumsum <= pid, 1, 0)
    input_chunk_idx = tl.sum(input_split_sizes_mask)
    # Tokens in all preceding chunks -> offset of this token within its chunk.
    input_split_sizes_presum = tl.sum(input_split_sizes * input_split_sizes_mask)
    in_chunk_offset = pid - input_split_sizes_presum
    # get chunk idx of the current token in the output tensor
    output_chunk_mask = tl.where(sorted_indices == input_chunk_idx, 1, 0)
    output_chunk_idx = tl.argmax(output_chunk_mask, axis=-1)
    # make row_id_map: sum sizes of output chunks placed before ours.
    output_split_sizes = tl.load(
        split_sizes_ptr + sorted_indices, mask=load_split_offset < num_splits
    ).to(tl.int32)
    output_pre_split_sizes = tl.where(load_split_offset < output_chunk_idx, output_split_sizes, 0)
    dst_row = tl.sum(output_pre_split_sizes) + in_chunk_offset
    tl.store(dst_rows_ptr + pid, dst_row)
@triton.jit
def _sort_chunks_by_map_kernel(
    # pointers
    input_ptr,
    output_ptr,
    row_id_map_ptr,
    probs_ptr,
    permuted_probs_ptr,
    # sizes
    hidden_size: tl.constexpr,
    # strides
    stride_input_token,
    stride_input_hidden,
    stride_output_token,
    stride_output_hidden,
    stride_probs_token,
    stride_permuted_probs_token,
    # metas
    PERMUTE_PROBS: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    FORWARD: tl.constexpr,
):
    """Copy rows according to a precomputed row map (chunk reordering).

    FORWARD=True scatters row `pid_t` to `row_id_map[pid_t]`; FORWARD=False
    gathers instead, so applying both directions round-trips.  When
    PERMUTE_PROBS is set, the per-token scalar prob moves the same way
    (written only by the pid_h == 0 program).
    """
    pid_t = tl.program_id(0)  # token index
    pid_h = tl.program_id(1)  # hidden-dim chunk index
    if FORWARD:
        src_row = pid_t.to(tl.int64)
        dst_row = tl.load(row_id_map_ptr + pid_t).to(tl.int64)
    else:
        src_row = tl.load(row_id_map_ptr + pid_t).to(tl.int64)
        dst_row = pid_t.to(tl.int64)
    current_offset = pid_h * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = current_offset < hidden_size
    input_offsets = src_row * stride_input_token + current_offset * stride_input_hidden
    output_offsets = dst_row * stride_output_token + current_offset * stride_output_hidden
    inp = tl.load(input_ptr + input_offsets, mask=mask)
    tl.store(output_ptr + output_offsets, inp, mask=mask)
    if PERMUTE_PROBS:
        if pid_h == 0:
            prob_off = src_row * stride_probs_token
            prob = tl.load(probs_ptr + prob_off)
            permuted_prob_off = dst_row * stride_permuted_probs_token
            tl.store(permuted_probs_ptr + permuted_prob_off, prob)
# Wrap the kernel with an autotuner over BLOCK_SIZE; if autotune setup fails
# (RuntimeError), keep the raw kernel as a best-effort fallback.
try:
    _sort_chunks_by_map_kernel = triton.autotune(
        configs=[
            triton.Config({"BLOCK_SIZE": block_size})
            for block_size in (64, 128, 256, 512, 1024, 2048, 4096)
        ],
        key=["hidden_size"],
    )(_sort_chunks_by_map_kernel)
except RuntimeError:
    pass
......@@ -29,12 +29,14 @@ except ImportError:
import transformer_engine_torch as tex
from transformer_engine.pytorch.triton.pad import pad_columnwise_scale_inv
from . import torch_version
from .utils import (
is_non_tn_fp8_gemm_supported,
safely_set_viewless_tensor_data,
needs_quantized_gemm,
)
from .constants import dist_group_type
from .quantization import FP8GlobalStateManager, autocast
from .tensor.float8_tensor import Float8Quantizer, Float8Tensor, Float8CurrentScalingQuantizer
......@@ -46,7 +48,6 @@ from .tensor.storage.float8_tensor_storage import Float8TensorStorage
from .tensor.storage.mxfp8_tensor_storage import MXFP8TensorStorage
from .tensor.storage.nvfp4_tensor_storage import NVFP4TensorStorage
from .tensor.storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage
from .triton.pad import pad_columnwise_scale_inv
from ..debug.pytorch.debug_quantization import DebugQuantizedTensor, DebugQuantizer
......
......@@ -2,4 +2,4 @@
#
# See LICENSE for license information.
"""Kernels written with OpenAI Triton."""
"""PyTorch wrappers for Triton kernels."""
......@@ -2,7 +2,7 @@
#
# See LICENSE for license information.
"""Efficient Cross Entropy kernels written with OpenAI Triton."""
"""PyTorch wrapper functions for Cross Entropy Triton kernels."""
from typing import Union
from functools import reduce
......@@ -12,257 +12,17 @@ import torch
import torch.distributed as dist
import triton
import triton.language as tl
@triton.jit
def online_softmax_kernel(
    X_ptr,
    X_stride,
    Y_ptr,
    Y_stride,
    m_d_X_y_ptr,
    m_d_X_y_stride,
    rank,
    n_cols,
    BLOCK_SIZE: tl.constexpr,
):
    """
    This kernel computes the m/d components on this TP rank for the online softmax.

    Each program handles one row of the local logits shard and writes three
    values per row: the local running max (m), the local sum of exp(x - m) (d),
    and the target-class logit (X_y, or -inf when the target class lives on a
    different TP rank's vocabulary shard).

    Parameters:
        X_ptr: Pointer to input (logits) tensor.
        X_stride (int): The stride of the input tensor.
        Y_ptr: Pointer to target tensor.
        Y_stride (int): The stride of the target tensor.
        m_d_X_y_ptr: Pointer to m/d/X_y tensor.
        m_d_X_y_stride (int): The stride of the m/d/X_y tensor.
        rank (int): The rank of this device in the TP group.
        n_cols (int): The number of columns in the input tensor.
        BLOCK_SIZE (int): The block size for Triton operations.
    """
    # int64 program id avoids pointer-arithmetic overflow for large row*stride.
    program_id = tl.program_id(0).to(tl.int64)
    # locate the start index of this row
    X_ptr += program_id * X_stride
    # Load the target class index for this row
    Y_ptr += program_id * Y_stride
    y = tl.load(Y_ptr)
    # This rank owns vocabulary ids in [vocab_start_idx, vocab_end_idx).
    vocab_start_idx = rank * n_cols
    vocab_end_idx = (rank + 1) * n_cols
    if y >= vocab_start_idx:
        if y < vocab_end_idx:
            # Target lives on this shard: read its logit.
            X_y = tl.load(X_ptr + y - vocab_start_idx).to(tl.float32)
        else:
            X_y = float("-inf")
    else:
        # Target belongs to another rank; -inf is the identity for the
        # cross-rank max reduction performed later.
        X_y = float("-inf")
    # Each row stores a triple (m, d, X_y), hence the factor of 3.
    m_d_X_y_ptr += program_id * m_d_X_y_stride * 3
    # 3. [Online softmax] first pass: find max + sum
    m = float("-inf")  # m is the max value. use the notation from the paper
    d = 0.0  # d is the sum. use the notation from the paper
    for i in range(0, n_cols, BLOCK_SIZE):
        X_offsets = i + tl.arange(0, BLOCK_SIZE)
        X_block = tl.load(X_ptr + X_offsets, mask=X_offsets < n_cols, other=float("-inf")).to(
            tl.float32
        )
        block_max = tl.max(X_block)
        m_new = tl.maximum(m, block_max)
        # Rescale the running sum to the new max before adding this block.
        d = d * tl.exp(m - m_new) + tl.sum(tl.exp(X_block - m_new))
        m = m_new
    # Store the (m, d, X_y) triple for this row.
    tl.store(m_d_X_y_ptr, m)
    tl.store(m_d_X_y_ptr + m_d_X_y_stride, d)
    tl.store(m_d_X_y_ptr + (2 * m_d_X_y_stride), X_y)
@triton.jit
def cross_entropy_kernel(
    X_ptr,
    X_stride,
    Y_ptr,
    Y_stride,
    loss_ptr,
    loss_stride,
    m_d_X_y_ptr,
    m_d_X_y_stride,
    rank,
    world_size,
    ignore_idx,
    n_cols,
    n_non_ignore,
    reduce_loss: tl.constexpr,
    label_smoothing: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """
    This kernel computes both cross entropy loss and the gradient of the input.

    X is overwritten in place with its gradient. The per-rank (m, d, X_y)
    partial softmax statistics (see ``online_softmax_kernel``) are first
    reduced across all TP ranks to obtain the global max/denominator.

    Parameters:
    X_ptr: Pointer to input tensor.
    X_stride (int): The stride of the input tensor.
    Y_ptr: Pointer to target tensor.
    Y_stride (int): The stride of the target tensor.
    loss_ptr: Pointer to tensor to store the loss.
    loss_stride (int): The stride of the loss tensor.
    m_d_X_y_ptr: Pointer to m/d/X_y tensor.
    m_d_X_y_stride: The stride of m/d/X_y tensor.
    rank (int): The rank of this device in the TP group.
    world_size (int): The size of world involved in this distributed loss calculation.
    ignore_idx (int): Tokens to be ignored for loss and gradient calculation.
    n_cols (int): The number of columns in the input tensor.
    n_non_ignore (int): The number of non-ignored elements in the batch.
    reduce_loss (bool): Whether the loss will be mean-reduced downstream;
        controls whether gradients are pre-divided by n_non_ignore.
    label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
    BLOCK_SIZE (int): The block size for Triton operations.
    """
    program_id = tl.program_id(0).to(tl.int64)
    # locate the start index
    X_ptr += program_id * X_stride
    # Load Y_ptr
    Y_ptr += program_id * Y_stride
    y = tl.load(Y_ptr)
    if y == ignore_idx:
        # Ignored token: zero out the whole gradient row and skip the loss.
        # set all X_ptr as 0
        for i in range(0, n_cols, BLOCK_SIZE):
            X_offsets = i + tl.arange(0, BLOCK_SIZE)
            tl.store(X_ptr + X_offsets, 0.0, mask=X_offsets < n_cols)
        return
    loss_ptr += program_id * loss_stride
    m_d_X_y_ptr += program_id * 3 * m_d_X_y_stride
    # Need to reduce the m/d/X_y values from other TP ranks
    m = tl.load(m_d_X_y_ptr)
    d = tl.load(m_d_X_y_ptr + m_d_X_y_stride)
    ori_X_y = tl.load(m_d_X_y_ptr + (2 * m_d_X_y_stride))
    for i in range(1, world_size):
        # NOTE(review): assumes the gathered buffer places rank i's statistics
        # at a per-rank offset of 3 * n_non_ignore * stride — confirm the
        # all-gather layout in the caller.
        offset = i * 3 * n_non_ignore * m_d_X_y_stride
        access_ptr = m_d_X_y_ptr + offset
        m_new = tl.load(access_ptr)
        d_new = tl.load(access_ptr + m_d_X_y_stride)
        X_y_new = tl.load(access_ptr + (2 * m_d_X_y_stride))
        # Merge two (m, d) pairs using the online-softmax combination rule.
        d = d * tl.exp(m - tl.maximum(m, m_new)) + d_new * tl.exp(m_new - tl.maximum(m, m_new))
        m = tl.maximum(m, m_new)
        # Only the rank owning the target stored a finite X_y; others are -inf,
        # so max() selects the real target logit.
        ori_X_y = tl.maximum(ori_X_y, X_y_new)
    # Label smoothing is a general case of normal cross entropy
    scaled_x_sum = 0.0
    eps = label_smoothing / (n_cols * world_size)
    # 4. [Online softmax] second pass: calculate the gradients
    # dx_y = (softmax(x_y) - 1) / N
    # dx_i = softmax(x_i) / N, i != y
    # N is the number of non ignored elements in the batch
    # For label smoothing:
    # dx_i = (softmax(x_y) - label_smoothing / V) / N, V = n_cols, i != y
    # dx_y = (softmax(x_y) - label_smoothing / V - (1 - label_smoothing)) / N
    #      = dx_i - (1 - label_smoothing) / N
    for i in range(0, n_cols, BLOCK_SIZE):
        X_offsets = i + tl.arange(0, BLOCK_SIZE)
        X_block = tl.load(X_ptr + X_offsets, mask=X_offsets < n_cols, other=float("-inf"))
        # Remember the storage dtype so the gradient is written back in-kind.
        grad_dtype = X_block.dtype
        X_block = X_block.to(tl.float32)
        if label_smoothing > 0:
            # scale X beforehand to avoid overflow
            scaled_x_sum += tl.sum(tl.where(X_offsets < n_cols, -eps * X_block, 0.0))
        # Scale gradients based on reduction mode
        # For reduce_loss=True: PyTorch will scale by 1/n_rows, so we need to scale by n_rows/n_non_ignore
        # For reduce_loss=False: No additional scaling from PyTorch, so we don't scale here
        if reduce_loss:
            X_block = (tl.exp(X_block - m) / d - eps) / (n_non_ignore)
        else:
            X_block = tl.exp(X_block - m) / d - eps
        tl.store(X_ptr + X_offsets, X_block.to(grad_dtype), mask=X_offsets < n_cols)
    # We need tl.debug_barrier() to ensure the new result of X_ptr is written
    tl.debug_barrier()
    # 5. Calculate the loss
    # loss = log (softmax(X_y)) = log ((e ^ (X_y - max(X)) / sum(e ^ (X - max(X))))
    #      = (X_y - max(X)) - log(sum(e ^ (X - max(X))))
    loss = -(ori_X_y - m - tl.log(d))
    # Orginal loss = H(q, p), with label smoothing regularization = H(q', p) and (label_smoothing / V) = eps
    # H(q', p) = (1 - label_smoothing) * H(q, p) + label_smoothing * H(u, p)
    #          = (1 - label_smoothing) * H(q, p) + eps * sum(logsoftmax(x_i))
    # By using m (global max of xi) and d (sum of e^(xi-m)), we can simplify as:
    #          = (1 - label_smoothing) * H(q, p) + (-sum(x_i * eps) + label_smoothing * (m + logd))
    # Refer to H(q', p) in section 7 of the paper: https://arxiv.org/pdf/1512.00567
    if label_smoothing > 0:
        smooth_loss = scaled_x_sum + label_smoothing * (m + tl.log(d))
        loss = loss * (1 - label_smoothing) + smooth_loss
    # 6. Specially handle the i==y case where `dx_y = (softmax(x_y) - (1 - label_smoothing) / N`
    vocab_start_idx = rank * n_cols
    vocab_end_idx = (rank + 1) * n_cols
    if y >= vocab_start_idx:
        if y < vocab_end_idx:
            X_y = tl.load(X_ptr + y - vocab_start_idx)
            # Apply the same conditional scaling logic for the target token
            if reduce_loss:
                X_y += -(1 - label_smoothing) / (n_non_ignore)
            else:
                X_y += -(1 - label_smoothing)
            tl.store(X_ptr + y - vocab_start_idx, X_y)
    tl.store(loss_ptr, loss)
from transformer_engine.common.triton.cross_entropy import (
online_softmax_kernel,
cross_entropy_kernel,
element_mul_kernel,
)
# The optimal maximum block size depends on your hardware, your kernel, and your dtype
# Upper bound (32768 elements) for the fused block size used by the wrappers below.
MAX_FUSED_SIZE = 65536 // 2
@triton.jit
def element_mul_kernel(
    X_ptr,
    X_stride,
    grad_output_ptr,
    grad_output_stride,
    n_cols,
    BLOCK_SIZE: tl.constexpr,
):
    """
    This function multiplies each element of the tensor pointed by X_ptr with the value pointed by grad_output_ptr.
    The multiplication is performed in-place on the tensor pointed by X_ptr.

    Parameters:
    X_ptr: Pointer to the input tensor.
    X_stride (int): The stride of the input tensor.
    grad_output_ptr: Pointer to the gradient output value.
    grad_output_stride (int): The stride of the gradient output tensor.
    n_cols (int): The number of columns in the input tensor.
    BLOCK_SIZE (int): The block size for Triton operations.
    """
    # Get the program ID and convert it to int64 to avoid overflow
    program_id = tl.program_id(0).to(tl.int64)
    # Locate the start index
    X_ptr += program_id * X_stride
    # Load the gradient output value
    grad_output_ptr += program_id * grad_output_stride
    grad_output = tl.load(grad_output_ptr)
    # Perform the element-wise multiplication
    for i in range(0, n_cols, BLOCK_SIZE):
        X_offsets = i + tl.arange(0, BLOCK_SIZE)
        X_block = tl.load(X_ptr + X_offsets, mask=X_offsets < n_cols)
        tl.store(X_ptr + X_offsets, X_block * grad_output, mask=X_offsets < n_cols)
def cross_entropy_forward(
_input: torch.Tensor,
target: torch.Tensor,
......
......@@ -2,63 +2,12 @@
#
# See LICENSE for license information.
"""NVFP4 padding kernels
TODO(ksivamani): Documentation
"""
"""PyTorch wrapper functions for padding Triton kernels."""
import torch
import triton
import triton.language as tl
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 128}, num_warps=4, num_stages=2),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 256}, num_warps=4, num_stages=2),
        triton.Config({"BLOCK_M": 256, "BLOCK_N": 128}, num_warps=8, num_stages=2),
        triton.Config({"BLOCK_M": 128, "BLOCK_N": 256}, num_warps=8, num_stages=1),
    ],
    key=["out_dim0", "out_dim1"],
)
@triton.jit
def zero_pad_kernel(
    inp_ptr,
    out_ptr,
    in_dim0: tl.constexpr,
    in_dim1: tl.constexpr,
    out_dim0: tl.constexpr,
    out_dim1: tl.constexpr,
    in_s0,
    in_s1,
    out_s0,
    out_s1,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Pads a tensor assuming it's a columnwise scaling inverse.

    Copies the (in_dim0, in_dim1) input into the top-left corner of the
    (out_dim0, out_dim1) output and fills the remaining cells with zero.
    """
    # tile over OUTPUT coordinates
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)  # output rows
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)  # output cols
    om = offs_m[:, None]
    on = offs_n[None, :]
    # edge masking for output
    out_mask = (om < out_dim0) & (on < out_dim1)
    # valid input region is simply top-left (no offsets)
    in_mask = (om < in_dim0) & (on < in_dim1)
    # load valid input, else zero (masked load touches memory only where True)
    x = tl.load(inp_ptr + om * in_s0 + on * in_s1, mask=in_mask, other=0)
    # store to output (only within bounds of the output tile)
    tl.store(out_ptr + om * out_s0 + on * out_s1, x, mask=out_mask)
from transformer_engine.common.triton.pad import zero_pad_kernel
def pad_columnwise_scale_inv(inp: torch.Tensor) -> torch.Tensor:
......
......@@ -2,197 +2,23 @@
#
# See LICENSE for license information.
"""Permutation kernels written with OpenAI Triton."""
"""PyTorch wrapper functions for Permutation Triton kernels."""
from typing import Union
import torch
import triton
import triton.language as tl
from triton.language import core
from triton.language.standard import _log2
from packaging import version
# The following three argsort related kernels are adapted from
# the issue https://github.com/triton-lang/triton/issues/3698
get_int_dtype = core.get_int_dtype
# NOTE(review): on Triton >= 3.5.0 the helper is wrapped with
# triton.constexpr_function — presumably required so it can be called at
# kernel compile time in newer Triton versions; confirm against Triton 3.5
# release notes.
if version.parse(triton.__version__) >= version.parse("3.5.0"):
    get_int_dtype = triton.constexpr_function(get_int_dtype)
@triton.jit
def _compare_and_swap(x, indices, flip, i: tl.constexpr, n_dims: tl.constexpr):
    """One compare-and-swap stage of the bitonic sort over `x`, carrying
    `indices` along so the caller obtains an argsort.

    `flip` selects ascending (0) vs descending (1) order per element; `i` is
    the stage index, `n_dims` is log2 of the sorted span length. Returns the
    (x, indices) pair after swapping. Values are swapped branchlessly via
    integer XOR on their bit patterns.
    """
    n_outer: tl.constexpr = x.numel >> n_dims
    shape: tl.constexpr = [n_outer * (2**i), 2, 2 ** (n_dims - i - 1)]
    y = tl.reshape(x, shape)
    z = tl.reshape(indices, shape)
    # mask picks the left (0) or right (1) partner of each pair along axis 1.
    mask = tl.arange(0, 2)[None, :, None]
    l_value = tl.reshape(tl.broadcast_to(tl.sum(y * (1 - mask), 1)[:, None, :], shape), x.shape).to(
        x.dtype
    )
    r_value = tl.reshape(tl.broadcast_to(tl.sum(y * mask, 1)[:, None, :], shape), x.shape).to(
        x.dtype
    )
    l_indice = tl.reshape(tl.broadcast_to(tl.sum(z * (1 - mask), 1)[:, None, :], shape), x.shape)
    r_indice = tl.reshape(tl.broadcast_to(tl.sum(z * mask, 1)[:, None, :], shape), x.shape)
    # Reinterpret values as same-width signed ints so XOR-swap is exact.
    idtype = get_int_dtype(bitwidth=x.dtype.primitive_bitwidth, signed=True)
    il_value = l_value.to(idtype, bitcast=True)
    ir_value = r_value.to(idtype, bitcast=True)
    ix = x.to(idtype, bitcast=True)
    # XOR with (l ^ r) swaps each out-of-order pair; XOR with 0 leaves it.
    flag1 = tl.where(((l_value > r_value) ^ flip) != 0, il_value ^ ir_value, tl.zeros_like(ix))
    ret = ix ^ flag1
    flag2 = tl.where(((l_value > r_value) ^ flip) != 0, l_indice ^ r_indice, tl.zeros_like(ix))
    ind = indices ^ flag2
    return ret.to(x.dtype, bitcast=True), ind
@triton.jit
def _bitonic_merge(x, indices, stage: tl.constexpr, order: tl.constexpr, n_dims: tl.constexpr):
    """Merge step of the bitonic argsort: applies `stage` compare-and-swap
    passes over (x, indices) in the requested order (see order types below).
    """
    n_outer: tl.constexpr = x.numel >> n_dims
    tl.static_assert(stage <= n_dims)
    """
    order_type 0 == ascending
    order_type 1 == descending
    order_type 2 == alternating
    """
    if order == 2:
        # Alternating order: flip flag alternates per group of 2**stage elements.
        shape: tl.constexpr = [n_outer * (2 ** (n_dims - 1 - stage)), 2, 2**stage]
        flip = tl.reshape(tl.broadcast_to(tl.arange(0, 2)[None, :, None], shape), x.shape)
    else:
        flip = tl.full(x.shape, value=order, dtype=tl.int32)
    for i in tl.static_range(stage):
        x, indices = _compare_and_swap(x, indices, flip, i + (n_dims - stage), n_dims)
    return x, indices
@triton.jit
def _argsort(x, indices, n_dims: tl.constexpr):
    """Bitonic argsort of `x` (length 2**n_dims) in ascending order; returns
    the sorted values and the permuted `indices`. Adapted from
    https://github.com/triton-lang/triton/issues/3698.
    """
    for i in tl.static_range(1, n_dims + 1):
        # Intermediate stages sort alternating (order=2) to build bitonic runs;
        # the final stage (i == n_dims) merges descending (order=1).
        x, indices = _bitonic_merge(x, indices, i, 2 if i < n_dims else 1, n_dims)
    return x, indices
@triton.jit
def _row_id_map_pass_1_kernel(
    # pointers
    routing_map_ptr,
    row_id_map_ptr,
    workspace_ptr,
    # sizes
    num_tokens,
    # strides
    stride_routing_map_token,
    stride_routing_map_expert,
    stride_row_id_map_token,
    stride_row_id_map_expert,
    # metas
    BLOCK_SIZE: tl.constexpr,
):
    """Pass 1 of row-id-map construction: per (expert, token-block) program,
    compute each routed token's 1-based rank within its block via cumsum of
    the routing mask, and record the block's routed-token count in workspace
    for the cross-block prefix sum done in pass 2.
    """
    pid_m = tl.program_id(0)  # expert index
    pid_n = tl.program_id(1)  # token-block index
    offset = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    expert_token_mask = tl.load(
        routing_map_ptr + pid_m * stride_routing_map_expert + offset * stride_routing_map_token,
        mask=(offset < num_tokens),
        other=0,
    ).to(tl.int32)
    # Multiplying by the mask zeroes the entries of unrouted tokens.
    row_id_within_token_block = tl.cumsum(expert_token_mask) * expert_token_mask
    tl.store(
        row_id_map_ptr + pid_m * stride_row_id_map_expert + offset * stride_row_id_map_token,
        row_id_within_token_block,
        mask=offset < num_tokens,
    )
    n_tokens_per_block = tl.sum(expert_token_mask)
    tl.store(workspace_ptr + pid_m * tl.cdiv(num_tokens, BLOCK_SIZE) + pid_n, n_tokens_per_block)
@triton.jit
def _row_id_map_pass_2_kernel(
    # pointers
    row_id_map_ptr,
    workspace_ptr,
    # sizes
    num_tokens,
    # strides
    stride_row_id_map_token,
    stride_row_id_map_expert,
    # metas
    WORKSPACE_LOAD_WIDTH: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """Pass 2 of row-id-map construction: turn each token's within-block rank
    (from pass 1) into a global destination row by adding the total routed
    count of all preceding (expert, block) chunks; unrouted slots become -1.
    """
    pid_m = tl.program_id(0)  # expert index
    pid_n = tl.program_id(1)  # token-block index
    # Linear index of this chunk in the workspace written by pass 1.
    chunk_idx = pid_m * tl.cdiv(num_tokens, BLOCK_SIZE) + pid_n
    offset = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    row_id_within_token_block = tl.load(
        row_id_map_ptr + pid_m * stride_row_id_map_expert + offset * stride_row_id_map_token,
        mask=(offset < num_tokens),
        other=0,
    )
    workspace_off = tl.arange(0, WORKSPACE_LOAD_WIDTH)
    # Masked load keeps only the counts of chunks strictly before this one.
    n_tokens_per_chunk = tl.load(workspace_ptr + workspace_off, mask=workspace_off < chunk_idx)
    # -1 (rank 0) marks "token not routed to this expert"; otherwise convert
    # the 1-based within-block rank to a 0-based global row index.
    row_id = tl.where(
        row_id_within_token_block == 0,
        -1,
        row_id_within_token_block + tl.sum(n_tokens_per_chunk) - 1,
    )
    tl.store(
        row_id_map_ptr + pid_m * stride_row_id_map_expert + offset * stride_row_id_map_token,
        row_id,
        mask=(offset < num_tokens),
    )
@triton.jit
def _row_id_map_pass_3_kernel(
    # pointers
    row_id_map_ptr,
    # sizes
    num_experts: tl.constexpr,
    # strides
    stride_row_id_map_token,
    stride_row_id_map_expert,
    # metas
    LOAD_SIZE: tl.constexpr,
):
    """Pass 3 of row-id-map construction: per token, compact the per-expert
    destination rows. Sorts the row's entries descending (so valid rows come
    first, -1 entries last) and rewrites the row layout as:
    [0, n_routed)              -> sorted destination rows,
    [num_experts, +n_routed)   -> expert index of each sorted row,
    [2 * num_experts]          -> n_routed count.
    """
    pid = tl.program_id(0)  # token index
    n_dims: tl.constexpr = _log2(LOAD_SIZE)
    off = tl.arange(0, LOAD_SIZE)
    row_id_map = tl.load(
        row_id_map_ptr + pid * stride_row_id_map_token + stride_row_id_map_expert * off,
        mask=off < num_experts,
        other=-1,
    )
    n_routed = tl.sum(tl.where(row_id_map != -1, 1, 0))
    indices = off
    # _argsort's final descending merge pushes all -1 entries to the tail.
    sorted_map, indices = _argsort(row_id_map, indices, n_dims=n_dims)
    tl.store(
        row_id_map_ptr + pid * stride_row_id_map_token + off * stride_row_id_map_expert,
        sorted_map,
        mask=off < n_routed,
    )
    tl.store(
        row_id_map_ptr
        + pid * stride_row_id_map_token
        + (num_experts + off) * stride_row_id_map_expert,
        indices,
        mask=off < n_routed,
    )
    tl.store(
        row_id_map_ptr + pid * stride_row_id_map_token + num_experts * 2 * stride_row_id_map_expert,
        n_routed,
    )
from transformer_engine.common.triton.permutation import (
_row_id_map_pass_1_kernel,
_row_id_map_pass_2_kernel,
_row_id_map_pass_3_kernel,
_permute_kernel,
_unpermute_kernel,
_unpermute_bwd_with_merging_probs_kernel,
_make_chunk_sort_map_kernel,
_sort_chunks_by_map_kernel,
)
def make_row_id_map(
......@@ -292,103 +118,6 @@ def make_row_id_map(
return row_id_map
@triton.jit
def _permute_kernel(
    # pointers
    input_ptr,
    output_ptr,
    row_id_map_ptr,
    probs_ptr,
    scale_ptr,
    permuted_probs_ptr,
    permuted_scale_ptr,
    # sizes
    num_experts: tl.constexpr,
    hidden_size: tl.constexpr,
    scale_hidden_dim,
    # strides
    stride_row_id_map_token,
    stride_row_id_map_expert,
    stride_input_token,
    stride_input_hidden,
    stride_output_token,
    stride_output_hidden,
    stride_probs_token,
    stride_probs_expert,
    stride_scale_token,
    stride_scale_hidden,
    stride_permuted_probs_token,
    stride_permuted_scale_token,
    stride_permuted_scale_hidden,
    # metas
    PERMUTE_PROBS: tl.constexpr,
    PERMUTE_SCALE: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """Scatter each token's hidden-state chunk to every destination row listed
    in its (compacted) row_id_map entry, optionally carrying along per-expert
    probabilities and per-token scales. One program per (token, hidden-block).
    """
    pid_t = tl.program_id(0)  # token index
    pid_h = tl.program_id(1)  # hidden-dimension block index
    cur_off = pid_h * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = cur_off < hidden_size
    src_row = pid_t.to(tl.int64)
    input_off = src_row * stride_input_token + cur_off * stride_input_hidden
    inp = tl.load(input_ptr + input_off, mask=mask)
    if PERMUTE_SCALE:
        mask_scale = cur_off < scale_hidden_dim
        scale_off = pid_t * stride_scale_token + cur_off * stride_scale_hidden
        scale = tl.load(scale_ptr + scale_off, mask=mask_scale)
    # Slot 2*num_experts of the row holds the routed-destination count
    # (written by _row_id_map_pass_3_kernel).
    n_routed = tl.load(
        row_id_map_ptr
        + pid_t * stride_row_id_map_token
        + num_experts * 2 * stride_row_id_map_expert
    )
    for idx in tl.range(n_routed):
        dst_row = tl.load(
            row_id_map_ptr + pid_t * stride_row_id_map_token + idx * stride_row_id_map_expert
        ).to(tl.int64)
        output_off = dst_row * stride_output_token + cur_off * stride_output_hidden
        if PERMUTE_SCALE:
            permuted_scale_off = (
                dst_row * stride_permuted_scale_token + cur_off * stride_permuted_scale_hidden
            )
            tl.store(permuted_scale_ptr + permuted_scale_off, scale, mask=mask_scale)
        if PERMUTE_PROBS:
            # Slots [num_experts, 2*num_experts) hold the expert index for
            # each compacted destination.
            expert_idx = tl.load(
                row_id_map_ptr
                + pid_t * stride_row_id_map_token
                + (num_experts + idx) * stride_row_id_map_expert
            )
            prob_off = pid_t * stride_probs_token + expert_idx * stride_probs_expert
            prob = tl.load(probs_ptr + prob_off)
            # Only the first hidden-block program writes the scalar prob.
            if pid_h == 0:
                permuted_prob_off = dst_row * stride_permuted_probs_token
                tl.store(permuted_probs_ptr + permuted_prob_off, prob)
            if prob == 0.0:
                # for routing_map padding
                # dst_row != -1 and prob == 0.0 means that this slot is padded
                tl.store(output_ptr + output_off, 0.0, mask=mask)
            else:
                tl.store(output_ptr + output_off, inp, mask=mask)
        else:
            tl.store(output_ptr + output_off, inp, mask=mask)
# Wrap the kernel with autotuning over BLOCK_SIZE, keyed on hidden_size.
# NOTE(review): the RuntimeError guard presumably covers Triton versions that
# reject re-decorating an already-jitted function — confirm which versions.
try:
    _permute_kernel = triton.autotune(
        configs=[
            triton.Config({"BLOCK_SIZE": 64}),
            triton.Config({"BLOCK_SIZE": 128}),
            triton.Config({"BLOCK_SIZE": 256}),
            triton.Config({"BLOCK_SIZE": 512}),
            triton.Config({"BLOCK_SIZE": 1024}),
            triton.Config({"BLOCK_SIZE": 2048}),
            triton.Config({"BLOCK_SIZE": 4096}),
        ],
        key=["hidden_size"],
    )(_permute_kernel)
except RuntimeError:
    pass
def permute_with_mask_map(
inp: torch.Tensor,
row_id_map: torch.Tensor,
......@@ -468,116 +197,6 @@ def permute_with_mask_map(
return output, permuted_scale, permuted_probs
@triton.jit
def _unpermute_kernel(
    # pointers
    input_ptr,
    output_ptr,
    row_id_map_ptr,
    merging_probs_ptr,
    permuted_probs_ptr,
    unpermuted_probs_ptr,
    # sizes
    num_experts: tl.constexpr,
    hidden_size: tl.constexpr,
    # strides
    stride_row_id_map_token,
    stride_row_id_map_expert,
    stride_input_token,
    stride_input_hidden,
    stride_output_token,
    stride_output_hidden,
    stride_merging_probs_token,
    stride_merging_probs_expert,
    stride_permuted_probs_token,
    stride_unpermuted_probs_token,
    stride_unpermuted_probs_expert,
    # metas
    PROBS_LOAD_WIDTH: tl.constexpr,
    WITH_MERGING_PROBS: tl.constexpr,
    PERMUTE_PROBS: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """Gather the permuted rows belonging to each token and sum them back into
    a single output row (optionally weighted by merging probs), accumulating in
    fp32. Optionally scatters permuted probs back to (token, expert) layout.
    One program per (token, hidden-block).
    """
    data_type = input_ptr.dtype.element_ty
    compute_type = tl.float32  # accumulate in fp32 regardless of storage dtype
    pid_t = tl.program_id(0)  # token index
    pid_h = tl.program_id(1)  # hidden-dimension block index
    current_offset = pid_h * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = current_offset < hidden_size
    if PERMUTE_PROBS:
        # write 0.0 to probs_grad that are not routed
        if pid_h == 0:
            map_load_off = tl.arange(0, PROBS_LOAD_WIDTH)
            unpermuted_prob_off = (
                pid_t * stride_unpermuted_probs_token
                + stride_unpermuted_probs_expert * map_load_off
            )
            tl.store(
                unpermuted_probs_ptr + unpermuted_prob_off, 0.0, mask=map_load_off < num_experts
            )
    accumulator = tl.zeros((BLOCK_SIZE,), dtype=compute_type)
    # Slot 2*num_experts holds the routed-destination count for this token.
    n_routed = tl.load(
        row_id_map_ptr
        + pid_t * stride_row_id_map_token
        + num_experts * 2 * stride_row_id_map_expert
    )
    for idx in tl.range(n_routed):
        src_row = tl.load(
            row_id_map_ptr + pid_t * stride_row_id_map_token + idx * stride_row_id_map_expert
        ).to(tl.int64)
        input_off = src_row * stride_input_token + current_offset * stride_input_hidden
        inp = tl.load(input_ptr + input_off, mask=mask)
        inp = inp.to(compute_type)
        if WITH_MERGING_PROBS:
            # Slots [num_experts, 2*num_experts) hold the expert index of each
            # compacted destination.
            expert_idx = tl.load(
                row_id_map_ptr
                + pid_t * stride_row_id_map_token
                + (num_experts + idx) * stride_row_id_map_expert
            )
            merging_prob_off = (
                pid_t * stride_merging_probs_token + expert_idx * stride_merging_probs_expert
            )
            merging_prob = tl.load(merging_probs_ptr + merging_prob_off).to(compute_type)
            inp *= merging_prob
        accumulator += inp
        if PERMUTE_PROBS:
            # Only the first hidden-block program writes the scalar prob.
            if pid_h == 0:
                expert_idx = tl.load(
                    row_id_map_ptr
                    + pid_t * stride_row_id_map_token
                    + (num_experts + idx) * stride_row_id_map_expert
                )
                unpermuted_prob_off = (
                    pid_t * stride_unpermuted_probs_token
                    + expert_idx * stride_unpermuted_probs_expert
                )
                permuted_prob_off = src_row * stride_permuted_probs_token
                prob = tl.load(permuted_probs_ptr + permuted_prob_off)
                tl.store(unpermuted_probs_ptr + unpermuted_prob_off, prob)
    accumulator = accumulator.to(data_type)
    dst_row = pid_t.to(tl.int64)
    output_off = dst_row * stride_output_token + current_offset * stride_output_hidden
    tl.store(output_ptr + output_off, accumulator, mask=mask)
# Wrap the kernel with autotuning over BLOCK_SIZE, keyed on hidden_size.
# NOTE(review): the RuntimeError guard presumably covers Triton versions that
# reject re-decorating an already-jitted function — confirm which versions.
try:
    _unpermute_kernel = triton.autotune(
        configs=[
            triton.Config({"BLOCK_SIZE": 64}),
            triton.Config({"BLOCK_SIZE": 128}),
            triton.Config({"BLOCK_SIZE": 256}),
            triton.Config({"BLOCK_SIZE": 512}),
            triton.Config({"BLOCK_SIZE": 1024}),
            triton.Config({"BLOCK_SIZE": 2048}),
            triton.Config({"BLOCK_SIZE": 4096}),
        ],
        key=["hidden_size"],
    )(_unpermute_kernel)
except RuntimeError:
    pass
def unpermute_with_mask_map(
inp: torch.Tensor,
row_id_map: torch.Tensor,
......@@ -644,110 +263,6 @@ def unpermute_with_mask_map(
return output, unpermuted_probs
@triton.jit
def _unpermute_bwd_with_merging_probs_kernel(
    # pointers
    fwd_output_grad_ptr,
    fwd_input_grad_ptr,
    fwd_input_ptr,
    merging_probs_ptr,
    merging_probs_grad_ptr,
    row_id_map_ptr,
    # sizes
    num_experts: tl.constexpr,
    hidden_size: tl.constexpr,
    # strides
    stride_row_id_map_token,
    stride_row_id_map_expert,
    stride_fwd_output_grad_token,
    stride_fwd_output_grad_hidden,
    stride_fwd_input_grad_token,
    stride_fwd_input_grad_hidden,
    stride_fwd_input_token,
    stride_fwd_input_hidden,
    stride_merging_probs_token,
    stride_merging_probs_expert,
    stride_merging_probs_grad_token,
    stride_merging_probs_grad_expert,
    # metas
    PROBS_LOAD_WIDTH: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """Backward of prob-weighted unpermute. For each of the token's routed
    destinations: scatter grad = fwd_output_grad * merging_prob back to the
    permuted layout, and reduce dot(fwd_input, fwd_output_grad) over the hidden
    dimension into the merging-prob gradient. One program per token; the hidden
    dimension is walked in BLOCK_SIZE chunks inside the kernel.
    """
    data_type = fwd_output_grad_ptr.dtype.element_ty
    compute_type = tl.float32  # accumulate in fp32 regardless of storage dtype
    pid = tl.program_id(0)  # token index
    # Zero the whole merging-prob-grad row first; unrouted experts stay 0.
    map_load_off = tl.arange(0, PROBS_LOAD_WIDTH)
    token_probs_grad_off = (
        pid * stride_merging_probs_grad_token + stride_merging_probs_grad_expert * map_load_off
    )
    tl.store(merging_probs_grad_ptr + token_probs_grad_off, 0.0, mask=map_load_off < num_experts)
    # Slot 2*num_experts holds the routed-destination count for this token.
    n_routed = tl.load(
        row_id_map_ptr + pid * stride_row_id_map_token + num_experts * 2 * stride_row_id_map_expert
    )
    for idx in tl.range(n_routed):
        dst_row = tl.load(
            row_id_map_ptr + pid * stride_row_id_map_token + idx * stride_row_id_map_expert
        ).to(tl.int64)
        # Slots [num_experts, 2*num_experts) hold the expert index of each
        # compacted destination.
        expert_idx = tl.load(
            row_id_map_ptr
            + pid * stride_row_id_map_token
            + (num_experts + idx) * stride_row_id_map_expert
        )
        prob_grad_accum = tl.zeros((BLOCK_SIZE,), dtype=compute_type)
        current_start = 0
        while current_start < hidden_size:
            current_offset = current_start + tl.arange(0, BLOCK_SIZE)
            mask = current_offset < hidden_size
            src_row = pid.to(tl.int64)
            input_off = (
                src_row * stride_fwd_output_grad_token
                + current_offset * stride_fwd_output_grad_hidden
            )
            inp = tl.load(fwd_output_grad_ptr + input_off, mask=mask)
            inp = inp.to(compute_type)
            merging_prob_off = (
                pid * stride_merging_probs_token + expert_idx * stride_merging_probs_expert
            )
            merging_prob = tl.load(merging_probs_ptr + merging_prob_off).to(compute_type)
            # d(input)/d(permuted input) = merging_prob.
            output = inp * merging_prob
            output = output.to(data_type)
            output_off = (
                dst_row * stride_fwd_input_grad_token
                + current_offset * stride_fwd_input_grad_hidden
            )
            tl.store(fwd_input_grad_ptr + output_off, output, mask=mask)
            fwd_input_off = (
                dst_row * stride_fwd_input_token + current_offset * stride_fwd_input_hidden
            )
            fwd_input = tl.load(fwd_input_ptr + fwd_input_off, mask=mask)
            # d(loss)/d(merging_prob) accumulates fwd_input . fwd_output_grad.
            prob_grad_accum += fwd_input.to(compute_type) * inp
            current_start += BLOCK_SIZE
        probs_grad = tl.sum(prob_grad_accum).to(merging_probs_grad_ptr.dtype.element_ty)
        probs_grad_off = (
            pid * stride_merging_probs_grad_token + expert_idx * stride_merging_probs_grad_expert
        )
        tl.store(merging_probs_grad_ptr + probs_grad_off, probs_grad)
# Wrap the kernel with autotuning over BLOCK_SIZE, keyed on hidden_size.
# NOTE(review): the RuntimeError guard presumably covers Triton versions that
# reject re-decorating an already-jitted function — confirm which versions.
try:
    _unpermute_bwd_with_merging_probs_kernel = triton.autotune(
        configs=[
            triton.Config({"BLOCK_SIZE": 64}),
            triton.Config({"BLOCK_SIZE": 128}),
            triton.Config({"BLOCK_SIZE": 256}),
            triton.Config({"BLOCK_SIZE": 512}),
            triton.Config({"BLOCK_SIZE": 1024}),
            triton.Config({"BLOCK_SIZE": 2048}),
            triton.Config({"BLOCK_SIZE": 4096}),
        ],
        key=["hidden_size"],
    )(_unpermute_bwd_with_merging_probs_kernel)
except RuntimeError:
    pass
def unpermute_with_mask_map_bwd_with_merging_probs(
fwd_output_grad: torch.Tensor,
row_id_map: torch.Tensor,
......@@ -813,47 +328,6 @@ def unpermute_with_mask_map_bwd_with_merging_probs(
return act_grad, merging_probs_grad
@triton.jit
def _make_chunk_sort_map_kernel(
    # pointers
    split_sizes_ptr,
    sorted_indices_ptr,
    dst_rows_ptr,
    # sizes
    num_splits: tl.constexpr,
    # metas
    IDX_LOAD_WIDTH: tl.constexpr,
):
    """For each row (one program per row), compute its destination row after
    the input's chunks (of sizes split_sizes) are reordered into the order
    given by sorted_indices, and write it to dst_rows.
    """
    pid = tl.program_id(0)  # source row index
    load_split_offset = tl.arange(0, IDX_LOAD_WIDTH)
    sorted_indices = tl.load(
        sorted_indices_ptr + load_split_offset, mask=load_split_offset < num_splits
    )
    # get chunk idx of the current token in the input tensor
    input_split_sizes = tl.load(
        split_sizes_ptr + load_split_offset, mask=load_split_offset < num_splits, other=0
    ).to(tl.int32)
    input_split_sizes_cumsum = tl.cumsum(input_split_sizes)
    # Count how many chunk boundaries lie at or before this row.
    input_split_sizes_mask = tl.where(input_split_sizes_cumsum <= pid, 1, 0)
    input_chunk_idx = tl.sum(input_split_sizes_mask)
    input_split_sizes_presum = tl.sum(input_split_sizes * input_split_sizes_mask)
    # Row's offset within its own chunk.
    in_chunk_offset = pid - input_split_sizes_presum
    # get chunk idx of the current token in the output tensor
    output_chunk_mask = tl.where(sorted_indices == input_chunk_idx, 1, 0)
    output_chunk_idx = tl.argmax(output_chunk_mask, axis=-1)
    # make row_id_map
    # output_split_sizes[j] = size of the chunk placed at output position j.
    output_split_sizes = tl.load(
        split_sizes_ptr + sorted_indices, mask=load_split_offset < num_splits
    ).to(tl.int32)
    # Sum the sizes of all chunks that precede ours in the output order.
    output_pre_split_sizes = tl.where(load_split_offset < output_chunk_idx, output_split_sizes, 0)
    dst_row = tl.sum(output_pre_split_sizes) + in_chunk_offset
    tl.store(dst_rows_ptr + pid, dst_row)
def make_chunk_sort_map(
split_sizes: torch.Tensor,
sorted_indices: torch.Tensor,
......@@ -886,67 +360,6 @@ def make_chunk_sort_map(
return row_id_map
@triton.jit
def _sort_chunks_by_map_kernel(
    # pointers
    input_ptr,
    output_ptr,
    row_id_map_ptr,
    probs_ptr,
    permuted_probs_ptr,
    # sizes
    hidden_size: tl.constexpr,
    # strides
    stride_input_token,
    stride_input_hidden,
    stride_output_token,
    stride_output_hidden,
    stride_probs_token,
    stride_permuted_probs_token,
    # metas
    PERMUTE_PROBS: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    FORWARD: tl.constexpr,
):
    """Copy rows according to row_id_map (built by _make_chunk_sort_map_kernel),
    optionally moving a per-row prob alongside. FORWARD scatters (map gives the
    destination of row pid); backward gathers (map gives the source of row pid),
    i.e. applies the inverse permutation. One program per (row, hidden-block).
    """
    pid_t = tl.program_id(0)  # row index
    pid_h = tl.program_id(1)  # hidden-dimension block index
    if FORWARD:
        src_row = pid_t.to(tl.int64)
        dst_row = tl.load(row_id_map_ptr + pid_t).to(tl.int64)
    else:
        src_row = tl.load(row_id_map_ptr + pid_t).to(tl.int64)
        dst_row = pid_t.to(tl.int64)
    current_offset = pid_h * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = current_offset < hidden_size
    input_offsets = src_row * stride_input_token + current_offset * stride_input_hidden
    output_offsets = dst_row * stride_output_token + current_offset * stride_output_hidden
    inp = tl.load(input_ptr + input_offsets, mask=mask)
    tl.store(output_ptr + output_offsets, inp, mask=mask)
    if PERMUTE_PROBS:
        # Only the first hidden-block program moves the scalar prob.
        if pid_h == 0:
            prob_off = src_row * stride_probs_token
            prob = tl.load(probs_ptr + prob_off)
            permuted_prob_off = dst_row * stride_permuted_probs_token
            tl.store(permuted_probs_ptr + permuted_prob_off, prob)
# Wrap the kernel with autotuning over BLOCK_SIZE, keyed on hidden_size.
# NOTE(review): the RuntimeError guard presumably covers Triton versions that
# reject re-decorating an already-jitted function — confirm which versions.
try:
    _sort_chunks_by_map_kernel = triton.autotune(
        configs=[
            triton.Config({"BLOCK_SIZE": 64}),
            triton.Config({"BLOCK_SIZE": 128}),
            triton.Config({"BLOCK_SIZE": 256}),
            triton.Config({"BLOCK_SIZE": 512}),
            triton.Config({"BLOCK_SIZE": 1024}),
            triton.Config({"BLOCK_SIZE": 2048}),
            triton.Config({"BLOCK_SIZE": 4096}),
        ],
        key=["hidden_size"],
    )(_sort_chunks_by_map_kernel)
except RuntimeError:
    pass
def sort_chunks_by_map(
inp: torch.Tensor,
row_id_map: torch.Tensor,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment