Commit 0d874a4e authored by wenjh's avatar wenjh
Browse files

Merge branch 'nv_main' of v2.12

parents a68e5f87 dfdd3820
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
......@@ -14,8 +14,10 @@
#include "../util/rtc.h"
#include "../util/string.h"
#include "../utils.cuh"
#include "./transpose.h"
namespace transformer_engine {
namespace detail {
namespace {
......@@ -203,7 +205,8 @@ void transpose(const Tensor &input, const Tensor &noop, Tensor *output_, cudaStr
NVTE_CHECK(input.data.dptr != nullptr, "Input is not allocated.");
NVTE_CHECK(output.data.dptr != nullptr, "Output is not allocated.");
NVTE_CHECK(input.data.dtype == output.data.dtype, "Input and output type must match.");
NVTE_CHECK(input.data.dtype == output.data.dtype, "Input (dtype=", to_string(input.data.dtype),
") and output (dtype=", to_string(output.data.dtype), ") do not match.");
if (noop.data.dptr != nullptr) {
NVTE_CHECK(noop.numel() == 1, "Expected 1 element, ", "but found ", noop.numel(), ".");
......@@ -289,19 +292,20 @@ void transpose(const Tensor &input, const Tensor &noop, Tensor *output_, cudaStr
}); // NOLINT(*)
}
} // namespace detail
} // namespace transformer_engine
/* Public C API entry point: transpose `input` into `output` on `stream`.
 *
 * No no-op flag tensor is supplied here: a default-constructed Tensor
 * (null data pointer) is passed, which makes the implementation skip its
 * noop-flag validation and perform the transpose unconditionally.
 *
 * Note: the merge residue had left both the old unqualified `transpose(...)`
 * call and the new `detail::transpose(...)` call in the body, which would
 * have executed the transpose twice; only the namespaced call is kept.
 */
void nvte_transpose(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
  NVTE_API_CALL(nvte_transpose);
  using namespace transformer_engine;
  // Empty noop tensor => "no flag provided" (dptr == nullptr branch).
  auto noop = Tensor();
  detail::transpose(*convertNVTETensorCheck(input), noop, convertNVTETensor(output), stream);
}
/* Public C API entry point: transpose `input` into `output` on `stream`,
 * with a caller-supplied `noop` flag tensor.
 *
 * The implementation checks that `noop`, when its data pointer is non-null,
 * holds exactly one element; the flag's runtime semantics live in
 * detail::transpose.
 *
 * Note: the merge residue had left both the old unqualified `transpose(...)`
 * call and the new `detail::transpose(...)` call in the body, which would
 * have executed the transpose twice; only the namespaced call is kept.
 */
void nvte_transpose_with_noop(const NVTETensor input, const NVTETensor noop, NVTETensor output,
                              cudaStream_t stream) {
  NVTE_API_CALL(nvte_transpose_with_noop);
  using namespace transformer_engine;
  detail::transpose(*convertNVTETensorCheck(input), *convertNVTETensorCheck(noop),
                    convertNVTETensor(output), stream);
}
/*************************************************************************
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#ifndef TRANSFORMER_ENGINE_COMMON_TRANSPOSE_TRANSPOSE_H_
#define TRANSFORMER_ENGINE_COMMON_TRANSPOSE_TRANSPOSE_H_
#include "../common.h"
namespace transformer_engine {
namespace detail {
/* Transposes `input` into `*output_` on `stream`.
 *
 * Preconditions enforced by the implementation:
 *  - input and output data pointers must be allocated (non-null);
 *  - input and output dtypes must match;
 *  - if noop.data.dptr is non-null, noop must hold exactly one element.
 * A default-constructed `noop` (null dptr) disables the noop-flag path.
 * NOTE(review): the exact runtime semantics of a set noop flag (presumably
 * skipping the kernel) are defined in the .cu implementation — confirm there.
 */
void transpose(const Tensor &input, const Tensor &noop, Tensor *output_, cudaStream_t stream);
}  // namespace detail
}  // namespace transformer_engine
#endif  // TRANSFORMER_ENGINE_COMMON_TRANSPOSE_TRANSPOSE_H_
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
......@@ -392,8 +392,6 @@ void populate_transpose_dbias_workspace_config(const Tensor &input, /*cast*/
workspace->data.dtype);
const size_t required_size =
get_buffer_size_bytes(num_rows_partial_dbias, row_length, DType::kFloat32);
NVTE_CHECK(!workspace->data.shape.empty(), "Invalid workspace dims (expected (",
num_rows_partial_dbias, ",", row_length, "), found ())");
NVTE_CHECK(workspace_size >= required_size, "Invalid workspace (expected dims=(",
num_rows_partial_dbias, ",", row_length, "), dtype=", to_string(DType::kFloat32),
"; found dims=", workspace->data.shape,
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
......@@ -18,6 +18,8 @@ def online_softmax_kernel(
m_d_X_y_stride,
rank,
n_cols,
ignore_idx,
n_non_ignore,
BLOCK_SIZE: tl.constexpr,
):
"""
......@@ -32,6 +34,8 @@ def online_softmax_kernel(
m_d_X_y_stride (int): The stride of the m/d/X_y tensor.
rank (int): The rank of this device in the TP group.
n_cols (int): The number of columns in the input tensor.
ignore_idx (int): The index to ignore for loss calculation.
n_non_ignore: The number of non-ignored elements in the batch.
BLOCK_SIZE (int): The block size for Triton operations.
"""
......@@ -44,6 +48,9 @@ def online_softmax_kernel(
Y_ptr += program_id * Y_stride
y = tl.load(Y_ptr)
if y != ignore_idx:
tl.atomic_add(n_non_ignore, 1)
vocab_start_idx = rank * n_cols
vocab_end_idx = (rank + 1) * n_cols
if y >= vocab_start_idx:
......@@ -89,6 +96,7 @@ def cross_entropy_kernel(
world_size,
ignore_idx,
n_cols,
n_rows,
n_non_ignore,
reduce_loss: tl.constexpr,
label_smoothing: tl.constexpr,
......@@ -110,12 +118,14 @@ def cross_entropy_kernel(
world_size (int): The size of world involved in this distributed loss calculation.
ignore_idx (int): Tokens to be ignored for loss and gradient calculation.
n_cols (int): The number of columns in the input tensor.
n_non_ignore (int): The number of non-ignored elements in the batch.
n_rows (int): The number of rows in the batch (B * SQ), used for buffer indexing.
n_non_ignore: The number of non-ignored elements in the batch.
label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
BLOCK_SIZE (int): The block size for Triton operations.
"""
program_id = tl.program_id(0).to(tl.int64)
n_non_ignore = tl.load(n_non_ignore)
# locate the start index
X_ptr += program_id * X_stride
......@@ -140,7 +150,7 @@ def cross_entropy_kernel(
ori_X_y = tl.load(m_d_X_y_ptr + (2 * m_d_X_y_stride))
for i in range(1, world_size):
offset = i * 3 * n_non_ignore * m_d_X_y_stride
offset = i * 3 * n_rows * m_d_X_y_stride
access_ptr = m_d_X_y_ptr + offset
m_new = tl.load(access_ptr)
d_new = tl.load(access_ptr + m_d_X_y_stride)
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
......@@ -81,10 +81,8 @@ def _argsort(x, indices, n_dims: tl.constexpr):
@triton.jit
def _row_id_map_pass_1_kernel(
# pointers
# input pointers
routing_map_ptr,
row_id_map_ptr,
workspace_ptr,
# sizes
num_tokens,
# strides
......@@ -92,6 +90,9 @@ def _row_id_map_pass_1_kernel(
stride_routing_map_expert,
stride_row_id_map_token,
stride_row_id_map_expert,
# output pointers
row_id_map_ptr,
workspace_ptr,
# metas
BLOCK_SIZE: tl.constexpr,
):
......@@ -155,12 +156,11 @@ def _row_id_map_pass_2_kernel(
def _row_id_map_pass_3_kernel(
# pointers
row_id_map_ptr,
# sizes
num_experts: tl.constexpr,
# strides
stride_row_id_map_token,
stride_row_id_map_expert,
# metas
num_experts: tl.constexpr,
LOAD_SIZE: tl.constexpr,
):
pid = tl.program_id(0)
......@@ -194,18 +194,22 @@ def _row_id_map_pass_3_kernel(
@triton.jit
def _permute_kernel(
# pointers
# input pointers
input_ptr,
output_ptr,
row_id_map_ptr,
probs_ptr,
scale_ptr,
permuted_probs_ptr,
permuted_scale_ptr,
pad_offsets_ptr,
# Pre-allocated output buffers for JAX input_output_aliases.
# These are aliased to output_ptr/permuted_probs_ptr in JAX, so they point to the same memory.
# In PyTorch, pass the same tensors as output_ptr/permuted_probs_ptr.
output_buf_ptr, # pylint: disable=unused-argument
permuted_probs_buf_ptr, # pylint: disable=unused-argument
# sizes
num_experts: tl.constexpr,
hidden_size: tl.constexpr,
scale_hidden_dim,
num_tokens, # pylint: disable=unused-argument
num_out_tokens, # pylint: disable=unused-argument
# strides
stride_row_id_map_token,
stride_row_id_map_expert,
......@@ -220,15 +224,28 @@ def _permute_kernel(
stride_permuted_probs_token,
stride_permuted_scale_token,
stride_permuted_scale_hidden,
# output pointers
output_ptr,
permuted_probs_ptr,
# metas
num_experts: tl.constexpr,
hidden_size: tl.constexpr,
PERMUTE_PROBS: tl.constexpr,
PERMUTE_SCALE: tl.constexpr,
FUSION_PAD: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
# Note: When FUSION_PAD=True, output buffers should be pre-zeroed by the caller
# to ensure padding positions contain zeros.
# PyTorch: Use torch.zeros() for output buffer allocation
# JAX: Pre-zeroed buffers should be passed (when input_output_aliases works)
expert_idx = 0
pid_t = tl.program_id(0)
pid_h = tl.program_id(1)
cur_off = pid_h * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = cur_off < hidden_size
src_row = pid_t.to(tl.int64)
input_off = src_row * stride_input_token + cur_off * stride_input_hidden
inp = tl.load(input_ptr + input_off, mask=mask)
......@@ -245,6 +262,15 @@ def _permute_kernel(
dst_row = tl.load(
row_id_map_ptr + pid_t * stride_row_id_map_token + idx * stride_row_id_map_expert
).to(tl.int64)
if FUSION_PAD or PERMUTE_PROBS:
expert_idx = tl.load(
row_id_map_ptr
+ pid_t * stride_row_id_map_token
+ (num_experts + idx) * stride_row_id_map_expert
)
if FUSION_PAD:
pad_off = tl.load(pad_offsets_ptr + expert_idx)
dst_row = dst_row + pad_off
output_off = dst_row * stride_output_token + cur_off * stride_output_hidden
if PERMUTE_SCALE:
permuted_scale_off = (
......@@ -252,11 +278,6 @@ def _permute_kernel(
)
tl.store(permuted_scale_ptr + permuted_scale_off, scale, mask=mask_scale)
if PERMUTE_PROBS:
expert_idx = tl.load(
row_id_map_ptr
+ pid_t * stride_row_id_map_token
+ (num_experts + idx) * stride_row_id_map_expert
)
prob_off = pid_t * stride_probs_token + expert_idx * stride_probs_expert
prob = tl.load(probs_ptr + prob_off)
if pid_h == 0:
......@@ -291,16 +312,16 @@ except RuntimeError:
@triton.jit
def _unpermute_kernel(
# pointers
# input pointers
input_ptr,
output_ptr,
row_id_map_ptr,
merging_probs_ptr,
permuted_probs_ptr,
unpermuted_probs_ptr,
# sizes
num_experts: tl.constexpr,
hidden_size: tl.constexpr,
pad_offsets_ptr,
# Dummy parameters for JAX input_output_aliases compatibility (matches _permute_kernel signature pattern)
# These are unused in the unpermute kernel but maintain consistency with the permute kernel.
output_buf_ptr, # pylint: disable=unused-argument
unpermuted_probs_buf_ptr, # pylint: disable=unused-argument
# strides
stride_row_id_map_token,
stride_row_id_map_expert,
......@@ -313,14 +334,21 @@ def _unpermute_kernel(
stride_permuted_probs_token,
stride_unpermuted_probs_token,
stride_unpermuted_probs_expert,
# output pointers
output_ptr,
unpermuted_probs_ptr,
# metas
num_experts: tl.constexpr,
hidden_size: tl.constexpr,
PROBS_LOAD_WIDTH: tl.constexpr,
WITH_MERGING_PROBS: tl.constexpr,
PERMUTE_PROBS: tl.constexpr,
FUSION_UNPAD: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
data_type = input_ptr.dtype.element_ty
compute_type = tl.float32
expert_idx = 0
pid_t = tl.program_id(0)
pid_h = tl.program_id(1)
......@@ -347,15 +375,19 @@ def _unpermute_kernel(
src_row = tl.load(
row_id_map_ptr + pid_t * stride_row_id_map_token + idx * stride_row_id_map_expert
).to(tl.int64)
input_off = src_row * stride_input_token + current_offset * stride_input_hidden
inp = tl.load(input_ptr + input_off, mask=mask)
inp = inp.to(compute_type)
if WITH_MERGING_PROBS:
if FUSION_UNPAD or WITH_MERGING_PROBS:
expert_idx = tl.load(
row_id_map_ptr
+ pid_t * stride_row_id_map_token
+ (num_experts + idx) * stride_row_id_map_expert
)
if FUSION_UNPAD:
pad_off = tl.load(pad_offsets_ptr + expert_idx)
src_row = src_row + pad_off
input_off = src_row * stride_input_token + current_offset * stride_input_hidden
inp = tl.load(input_ptr + input_off, mask=mask)
inp = inp.to(compute_type)
if WITH_MERGING_PROBS:
merging_prob_off = (
pid_t * stride_merging_probs_token + expert_idx * stride_merging_probs_expert
)
......@@ -401,16 +433,12 @@ except RuntimeError:
@triton.jit
def _unpermute_bwd_with_merging_probs_kernel(
# pointers
# input pointers
fwd_output_grad_ptr,
fwd_input_grad_ptr,
fwd_input_ptr,
merging_probs_ptr,
merging_probs_grad_ptr,
row_id_map_ptr,
# sizes
num_experts: tl.constexpr,
hidden_size: tl.constexpr,
pad_offsets_ptr,
# strides
stride_row_id_map_token,
stride_row_id_map_expert,
......@@ -424,8 +452,14 @@ def _unpermute_bwd_with_merging_probs_kernel(
stride_merging_probs_expert,
stride_merging_probs_grad_token,
stride_merging_probs_grad_expert,
# output pointers
fwd_input_grad_ptr,
merging_probs_grad_ptr,
# metas
num_experts: tl.constexpr,
hidden_size: tl.constexpr,
PROBS_LOAD_WIDTH: tl.constexpr,
FUSION_UNPAD: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
data_type = fwd_output_grad_ptr.dtype.element_ty
......@@ -449,6 +483,9 @@ def _unpermute_bwd_with_merging_probs_kernel(
+ pid * stride_row_id_map_token
+ (num_experts + idx) * stride_row_id_map_expert
)
if FUSION_UNPAD:
pad_off = tl.load(pad_offsets_ptr + expert_idx)
dst_row = dst_row + pad_off
prob_grad_accum = tl.zeros((BLOCK_SIZE,), dtype=compute_type)
current_start = 0
while current_start < hidden_size:
......@@ -546,14 +583,10 @@ def _make_chunk_sort_map_kernel(
@triton.jit
def _sort_chunks_by_map_kernel(
# pointers
# input pointers
input_ptr,
output_ptr,
row_id_map_ptr,
probs_ptr,
permuted_probs_ptr,
# sizes
hidden_size: tl.constexpr,
# strides
stride_input_token,
stride_input_hidden,
......@@ -561,7 +594,11 @@ def _sort_chunks_by_map_kernel(
stride_output_hidden,
stride_probs_token,
stride_permuted_probs_token,
# output pointers
output_ptr,
permuted_probs_ptr,
# metas
hidden_size: tl.constexpr,
PERMUTE_PROBS: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
FORWARD: tl.constexpr,
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment