Commit 0d874a4e authored by wenjh's avatar wenjh
Browse files

Merge branch 'nv_main' of v2.12

parents a68e5f87 dfdd3820
/************************************************************************* /*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* *
* See LICENSE for license information. * See LICENSE for license information.
************************************************************************/ ************************************************************************/
......
/************************************************************************* /*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* *
* See LICENSE for license information. * See LICENSE for license information.
************************************************************************/ ************************************************************************/
......
/************************************************************************* /*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* *
* See LICENSE for license information. * See LICENSE for license information.
************************************************************************/ ************************************************************************/
...@@ -14,8 +14,10 @@ ...@@ -14,8 +14,10 @@
#include "../util/rtc.h" #include "../util/rtc.h"
#include "../util/string.h" #include "../util/string.h"
#include "../utils.cuh" #include "../utils.cuh"
#include "./transpose.h"
namespace transformer_engine { namespace transformer_engine {
namespace detail {
namespace { namespace {
...@@ -203,7 +205,8 @@ void transpose(const Tensor &input, const Tensor &noop, Tensor *output_, cudaStr ...@@ -203,7 +205,8 @@ void transpose(const Tensor &input, const Tensor &noop, Tensor *output_, cudaStr
NVTE_CHECK(input.data.dptr != nullptr, "Input is not allocated."); NVTE_CHECK(input.data.dptr != nullptr, "Input is not allocated.");
NVTE_CHECK(output.data.dptr != nullptr, "Output is not allocated."); NVTE_CHECK(output.data.dptr != nullptr, "Output is not allocated.");
NVTE_CHECK(input.data.dtype == output.data.dtype, "Input and output type must match."); NVTE_CHECK(input.data.dtype == output.data.dtype, "Input (dtype=", to_string(input.data.dtype),
") and output (dtype=", to_string(output.data.dtype), ") do not match.");
if (noop.data.dptr != nullptr) { if (noop.data.dptr != nullptr) {
NVTE_CHECK(noop.numel() == 1, "Expected 1 element, ", "but found ", noop.numel(), "."); NVTE_CHECK(noop.numel() == 1, "Expected 1 element, ", "but found ", noop.numel(), ".");
...@@ -289,19 +292,20 @@ void transpose(const Tensor &input, const Tensor &noop, Tensor *output_, cudaStr ...@@ -289,19 +292,20 @@ void transpose(const Tensor &input, const Tensor &noop, Tensor *output_, cudaStr
}); // NOLINT(*) }); // NOLINT(*)
} }
} // namespace detail
} // namespace transformer_engine } // namespace transformer_engine
void nvte_transpose(const NVTETensor input, NVTETensor output, cudaStream_t stream) { void nvte_transpose(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_transpose); NVTE_API_CALL(nvte_transpose);
using namespace transformer_engine; using namespace transformer_engine;
auto noop = Tensor(); auto noop = Tensor();
transpose(*convertNVTETensorCheck(input), noop, convertNVTETensor(output), stream); detail::transpose(*convertNVTETensorCheck(input), noop, convertNVTETensor(output), stream);
} }
void nvte_transpose_with_noop(const NVTETensor input, const NVTETensor noop, NVTETensor output, void nvte_transpose_with_noop(const NVTETensor input, const NVTETensor noop, NVTETensor output,
cudaStream_t stream) { cudaStream_t stream) {
NVTE_API_CALL(nvte_transpose_with_noop); NVTE_API_CALL(nvte_transpose_with_noop);
using namespace transformer_engine; using namespace transformer_engine;
transpose(*convertNVTETensorCheck(input), *convertNVTETensorCheck(noop), detail::transpose(*convertNVTETensorCheck(input), *convertNVTETensorCheck(noop),
convertNVTETensor(output), stream); convertNVTETensor(output), stream);
} }
/*************************************************************************
 * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

/*! \file transpose.h
 *  \brief Internal declaration of the 2D tensor-transpose entry point so that
 *         other translation units in this library can call it directly
 *         (the public C API wrappers nvte_transpose / nvte_transpose_with_noop
 *         route through this function).
 */

#ifndef TRANSFORMER_ENGINE_COMMON_TRANSPOSE_TRANSPOSE_H_
#define TRANSFORMER_ENGINE_COMMON_TRANSPOSE_TRANSPOSE_H_

#include "../common.h"

namespace transformer_engine {
namespace detail {

/*! \brief Transpose \p input into \p *output_ asynchronously on \p stream.
 *
 *  \param[in]  input    Source tensor; must be allocated, and its dtype must
 *                       match the output's (the implementation checks both).
 *  \param[in]  noop     Optional flag tensor. If its data pointer is non-null
 *                       it must hold exactly one element; pass a
 *                       default-constructed Tensor to disable the no-op path.
 *                       NOTE(review): presumably a non-zero flag value makes
 *                       the kernel skip the transpose at runtime — confirm
 *                       against the kernel in the .cu implementation.
 *  \param[out] output_  Destination tensor; must be allocated by the caller.
 *  \param[in]  stream   CUDA stream the kernel is launched on.
 */
void transpose(const Tensor &input, const Tensor &noop, Tensor *output_, cudaStream_t stream);

}  // namespace detail
}  // namespace transformer_engine

#endif  // TRANSFORMER_ENGINE_COMMON_TRANSPOSE_TRANSPOSE_H_
/************************************************************************* /*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* *
* See LICENSE for license information. * See LICENSE for license information.
************************************************************************/ ************************************************************************/
...@@ -392,8 +392,6 @@ void populate_transpose_dbias_workspace_config(const Tensor &input, /*cast*/ ...@@ -392,8 +392,6 @@ void populate_transpose_dbias_workspace_config(const Tensor &input, /*cast*/
workspace->data.dtype); workspace->data.dtype);
const size_t required_size = const size_t required_size =
get_buffer_size_bytes(num_rows_partial_dbias, row_length, DType::kFloat32); get_buffer_size_bytes(num_rows_partial_dbias, row_length, DType::kFloat32);
NVTE_CHECK(!workspace->data.shape.empty(), "Invalid workspace dims (expected (",
num_rows_partial_dbias, ",", row_length, "), found ())");
NVTE_CHECK(workspace_size >= required_size, "Invalid workspace (expected dims=(", NVTE_CHECK(workspace_size >= required_size, "Invalid workspace (expected dims=(",
num_rows_partial_dbias, ",", row_length, "), dtype=", to_string(DType::kFloat32), num_rows_partial_dbias, ",", row_length, "), dtype=", to_string(DType::kFloat32),
"; found dims=", workspace->data.shape, "; found dims=", workspace->data.shape,
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# #
# See LICENSE for license information. # See LICENSE for license information.
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# #
# See LICENSE for license information. # See LICENSE for license information.
...@@ -18,6 +18,8 @@ def online_softmax_kernel( ...@@ -18,6 +18,8 @@ def online_softmax_kernel(
m_d_X_y_stride, m_d_X_y_stride,
rank, rank,
n_cols, n_cols,
ignore_idx,
n_non_ignore,
BLOCK_SIZE: tl.constexpr, BLOCK_SIZE: tl.constexpr,
): ):
""" """
...@@ -32,6 +34,8 @@ def online_softmax_kernel( ...@@ -32,6 +34,8 @@ def online_softmax_kernel(
m_d_X_y_stride (int): The stride of the m/d/X_y tensor. m_d_X_y_stride (int): The stride of the m/d/X_y tensor.
rank (int): The rank of this device in the TP group. rank (int): The rank of this device in the TP group.
n_cols (int): The number of columns in the input tensor. n_cols (int): The number of columns in the input tensor.
ignore_idx (int): The index to ignore for loss calculation.
n_non_ignore: The number of non-ignored elements in the batch.
BLOCK_SIZE (int): The block size for Triton operations. BLOCK_SIZE (int): The block size for Triton operations.
""" """
...@@ -44,6 +48,9 @@ def online_softmax_kernel( ...@@ -44,6 +48,9 @@ def online_softmax_kernel(
Y_ptr += program_id * Y_stride Y_ptr += program_id * Y_stride
y = tl.load(Y_ptr) y = tl.load(Y_ptr)
if y != ignore_idx:
tl.atomic_add(n_non_ignore, 1)
vocab_start_idx = rank * n_cols vocab_start_idx = rank * n_cols
vocab_end_idx = (rank + 1) * n_cols vocab_end_idx = (rank + 1) * n_cols
if y >= vocab_start_idx: if y >= vocab_start_idx:
...@@ -89,6 +96,7 @@ def cross_entropy_kernel( ...@@ -89,6 +96,7 @@ def cross_entropy_kernel(
world_size, world_size,
ignore_idx, ignore_idx,
n_cols, n_cols,
n_rows,
n_non_ignore, n_non_ignore,
reduce_loss: tl.constexpr, reduce_loss: tl.constexpr,
label_smoothing: tl.constexpr, label_smoothing: tl.constexpr,
...@@ -110,12 +118,14 @@ def cross_entropy_kernel( ...@@ -110,12 +118,14 @@ def cross_entropy_kernel(
world_size (int): The size of world involved in this distributed loss calculation. world_size (int): The size of world involved in this distributed loss calculation.
ignore_idx (int): Tokens to be ignored for loss and gradient calculation. ignore_idx (int): Tokens to be ignored for loss and gradient calculation.
n_cols (int): The number of columns in the input tensor. n_cols (int): The number of columns in the input tensor.
n_non_ignore (int): The number of non-ignored elements in the batch. n_rows (int): The number of rows in the batch (B * SQ), used for buffer indexing.
n_non_ignore: The number of non-ignored elements in the batch.
label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing. label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
BLOCK_SIZE (int): The block size for Triton operations. BLOCK_SIZE (int): The block size for Triton operations.
""" """
program_id = tl.program_id(0).to(tl.int64) program_id = tl.program_id(0).to(tl.int64)
n_non_ignore = tl.load(n_non_ignore)
# locate the start index # locate the start index
X_ptr += program_id * X_stride X_ptr += program_id * X_stride
...@@ -140,7 +150,7 @@ def cross_entropy_kernel( ...@@ -140,7 +150,7 @@ def cross_entropy_kernel(
ori_X_y = tl.load(m_d_X_y_ptr + (2 * m_d_X_y_stride)) ori_X_y = tl.load(m_d_X_y_ptr + (2 * m_d_X_y_stride))
for i in range(1, world_size): for i in range(1, world_size):
offset = i * 3 * n_non_ignore * m_d_X_y_stride offset = i * 3 * n_rows * m_d_X_y_stride
access_ptr = m_d_X_y_ptr + offset access_ptr = m_d_X_y_ptr + offset
m_new = tl.load(access_ptr) m_new = tl.load(access_ptr)
d_new = tl.load(access_ptr + m_d_X_y_stride) d_new = tl.load(access_ptr + m_d_X_y_stride)
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# #
# See LICENSE for license information. # See LICENSE for license information.
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# #
# See LICENSE for license information. # See LICENSE for license information.
...@@ -81,10 +81,8 @@ def _argsort(x, indices, n_dims: tl.constexpr): ...@@ -81,10 +81,8 @@ def _argsort(x, indices, n_dims: tl.constexpr):
@triton.jit @triton.jit
def _row_id_map_pass_1_kernel( def _row_id_map_pass_1_kernel(
# pointers # input pointers
routing_map_ptr, routing_map_ptr,
row_id_map_ptr,
workspace_ptr,
# sizes # sizes
num_tokens, num_tokens,
# strides # strides
...@@ -92,6 +90,9 @@ def _row_id_map_pass_1_kernel( ...@@ -92,6 +90,9 @@ def _row_id_map_pass_1_kernel(
stride_routing_map_expert, stride_routing_map_expert,
stride_row_id_map_token, stride_row_id_map_token,
stride_row_id_map_expert, stride_row_id_map_expert,
# output pointers
row_id_map_ptr,
workspace_ptr,
# metas # metas
BLOCK_SIZE: tl.constexpr, BLOCK_SIZE: tl.constexpr,
): ):
...@@ -155,12 +156,11 @@ def _row_id_map_pass_2_kernel( ...@@ -155,12 +156,11 @@ def _row_id_map_pass_2_kernel(
def _row_id_map_pass_3_kernel( def _row_id_map_pass_3_kernel(
# pointers # pointers
row_id_map_ptr, row_id_map_ptr,
# sizes
num_experts: tl.constexpr,
# strides # strides
stride_row_id_map_token, stride_row_id_map_token,
stride_row_id_map_expert, stride_row_id_map_expert,
# metas # metas
num_experts: tl.constexpr,
LOAD_SIZE: tl.constexpr, LOAD_SIZE: tl.constexpr,
): ):
pid = tl.program_id(0) pid = tl.program_id(0)
...@@ -194,18 +194,22 @@ def _row_id_map_pass_3_kernel( ...@@ -194,18 +194,22 @@ def _row_id_map_pass_3_kernel(
@triton.jit @triton.jit
def _permute_kernel( def _permute_kernel(
# pointers # input pointers
input_ptr, input_ptr,
output_ptr,
row_id_map_ptr, row_id_map_ptr,
probs_ptr, probs_ptr,
scale_ptr, scale_ptr,
permuted_probs_ptr,
permuted_scale_ptr, permuted_scale_ptr,
pad_offsets_ptr,
# Pre-allocated output buffers for JAX input_output_aliases.
# These are aliased to output_ptr/permuted_probs_ptr in JAX, so they point to the same memory.
# In PyTorch, pass the same tensors as output_ptr/permuted_probs_ptr.
output_buf_ptr, # pylint: disable=unused-argument
permuted_probs_buf_ptr, # pylint: disable=unused-argument
# sizes # sizes
num_experts: tl.constexpr,
hidden_size: tl.constexpr,
scale_hidden_dim, scale_hidden_dim,
num_tokens, # pylint: disable=unused-argument
num_out_tokens, # pylint: disable=unused-argument
# strides # strides
stride_row_id_map_token, stride_row_id_map_token,
stride_row_id_map_expert, stride_row_id_map_expert,
...@@ -220,15 +224,28 @@ def _permute_kernel( ...@@ -220,15 +224,28 @@ def _permute_kernel(
stride_permuted_probs_token, stride_permuted_probs_token,
stride_permuted_scale_token, stride_permuted_scale_token,
stride_permuted_scale_hidden, stride_permuted_scale_hidden,
# output pointers
output_ptr,
permuted_probs_ptr,
# metas # metas
num_experts: tl.constexpr,
hidden_size: tl.constexpr,
PERMUTE_PROBS: tl.constexpr, PERMUTE_PROBS: tl.constexpr,
PERMUTE_SCALE: tl.constexpr, PERMUTE_SCALE: tl.constexpr,
FUSION_PAD: tl.constexpr,
BLOCK_SIZE: tl.constexpr, BLOCK_SIZE: tl.constexpr,
): ):
# Note: When FUSION_PAD=True, output buffers should be pre-zeroed by the caller
# to ensure padding positions contain zeros.
# PyTorch: Use torch.zeros() for output buffer allocation
# JAX: Pre-zeroed buffers should be passed (when input_output_aliases works)
expert_idx = 0
pid_t = tl.program_id(0) pid_t = tl.program_id(0)
pid_h = tl.program_id(1) pid_h = tl.program_id(1)
cur_off = pid_h * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) cur_off = pid_h * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = cur_off < hidden_size mask = cur_off < hidden_size
src_row = pid_t.to(tl.int64) src_row = pid_t.to(tl.int64)
input_off = src_row * stride_input_token + cur_off * stride_input_hidden input_off = src_row * stride_input_token + cur_off * stride_input_hidden
inp = tl.load(input_ptr + input_off, mask=mask) inp = tl.load(input_ptr + input_off, mask=mask)
...@@ -245,6 +262,15 @@ def _permute_kernel( ...@@ -245,6 +262,15 @@ def _permute_kernel(
dst_row = tl.load( dst_row = tl.load(
row_id_map_ptr + pid_t * stride_row_id_map_token + idx * stride_row_id_map_expert row_id_map_ptr + pid_t * stride_row_id_map_token + idx * stride_row_id_map_expert
).to(tl.int64) ).to(tl.int64)
if FUSION_PAD or PERMUTE_PROBS:
expert_idx = tl.load(
row_id_map_ptr
+ pid_t * stride_row_id_map_token
+ (num_experts + idx) * stride_row_id_map_expert
)
if FUSION_PAD:
pad_off = tl.load(pad_offsets_ptr + expert_idx)
dst_row = dst_row + pad_off
output_off = dst_row * stride_output_token + cur_off * stride_output_hidden output_off = dst_row * stride_output_token + cur_off * stride_output_hidden
if PERMUTE_SCALE: if PERMUTE_SCALE:
permuted_scale_off = ( permuted_scale_off = (
...@@ -252,11 +278,6 @@ def _permute_kernel( ...@@ -252,11 +278,6 @@ def _permute_kernel(
) )
tl.store(permuted_scale_ptr + permuted_scale_off, scale, mask=mask_scale) tl.store(permuted_scale_ptr + permuted_scale_off, scale, mask=mask_scale)
if PERMUTE_PROBS: if PERMUTE_PROBS:
expert_idx = tl.load(
row_id_map_ptr
+ pid_t * stride_row_id_map_token
+ (num_experts + idx) * stride_row_id_map_expert
)
prob_off = pid_t * stride_probs_token + expert_idx * stride_probs_expert prob_off = pid_t * stride_probs_token + expert_idx * stride_probs_expert
prob = tl.load(probs_ptr + prob_off) prob = tl.load(probs_ptr + prob_off)
if pid_h == 0: if pid_h == 0:
...@@ -291,16 +312,16 @@ except RuntimeError: ...@@ -291,16 +312,16 @@ except RuntimeError:
@triton.jit @triton.jit
def _unpermute_kernel( def _unpermute_kernel(
# pointers # input pointers
input_ptr, input_ptr,
output_ptr,
row_id_map_ptr, row_id_map_ptr,
merging_probs_ptr, merging_probs_ptr,
permuted_probs_ptr, permuted_probs_ptr,
unpermuted_probs_ptr, pad_offsets_ptr,
# sizes # Dummy parameters for JAX input_output_aliases compatibility (matches _permute_kernel signature pattern)
num_experts: tl.constexpr, # These are unused in the unpermute kernel but maintain consistency with the permute kernel.
hidden_size: tl.constexpr, output_buf_ptr, # pylint: disable=unused-argument
unpermuted_probs_buf_ptr, # pylint: disable=unused-argument
# strides # strides
stride_row_id_map_token, stride_row_id_map_token,
stride_row_id_map_expert, stride_row_id_map_expert,
...@@ -313,14 +334,21 @@ def _unpermute_kernel( ...@@ -313,14 +334,21 @@ def _unpermute_kernel(
stride_permuted_probs_token, stride_permuted_probs_token,
stride_unpermuted_probs_token, stride_unpermuted_probs_token,
stride_unpermuted_probs_expert, stride_unpermuted_probs_expert,
# output pointers
output_ptr,
unpermuted_probs_ptr,
# metas # metas
num_experts: tl.constexpr,
hidden_size: tl.constexpr,
PROBS_LOAD_WIDTH: tl.constexpr, PROBS_LOAD_WIDTH: tl.constexpr,
WITH_MERGING_PROBS: tl.constexpr, WITH_MERGING_PROBS: tl.constexpr,
PERMUTE_PROBS: tl.constexpr, PERMUTE_PROBS: tl.constexpr,
FUSION_UNPAD: tl.constexpr,
BLOCK_SIZE: tl.constexpr, BLOCK_SIZE: tl.constexpr,
): ):
data_type = input_ptr.dtype.element_ty data_type = input_ptr.dtype.element_ty
compute_type = tl.float32 compute_type = tl.float32
expert_idx = 0
pid_t = tl.program_id(0) pid_t = tl.program_id(0)
pid_h = tl.program_id(1) pid_h = tl.program_id(1)
...@@ -347,15 +375,19 @@ def _unpermute_kernel( ...@@ -347,15 +375,19 @@ def _unpermute_kernel(
src_row = tl.load( src_row = tl.load(
row_id_map_ptr + pid_t * stride_row_id_map_token + idx * stride_row_id_map_expert row_id_map_ptr + pid_t * stride_row_id_map_token + idx * stride_row_id_map_expert
).to(tl.int64) ).to(tl.int64)
input_off = src_row * stride_input_token + current_offset * stride_input_hidden if FUSION_UNPAD or WITH_MERGING_PROBS:
inp = tl.load(input_ptr + input_off, mask=mask)
inp = inp.to(compute_type)
if WITH_MERGING_PROBS:
expert_idx = tl.load( expert_idx = tl.load(
row_id_map_ptr row_id_map_ptr
+ pid_t * stride_row_id_map_token + pid_t * stride_row_id_map_token
+ (num_experts + idx) * stride_row_id_map_expert + (num_experts + idx) * stride_row_id_map_expert
) )
if FUSION_UNPAD:
pad_off = tl.load(pad_offsets_ptr + expert_idx)
src_row = src_row + pad_off
input_off = src_row * stride_input_token + current_offset * stride_input_hidden
inp = tl.load(input_ptr + input_off, mask=mask)
inp = inp.to(compute_type)
if WITH_MERGING_PROBS:
merging_prob_off = ( merging_prob_off = (
pid_t * stride_merging_probs_token + expert_idx * stride_merging_probs_expert pid_t * stride_merging_probs_token + expert_idx * stride_merging_probs_expert
) )
...@@ -401,16 +433,12 @@ except RuntimeError: ...@@ -401,16 +433,12 @@ except RuntimeError:
@triton.jit @triton.jit
def _unpermute_bwd_with_merging_probs_kernel( def _unpermute_bwd_with_merging_probs_kernel(
# pointers # input pointers
fwd_output_grad_ptr, fwd_output_grad_ptr,
fwd_input_grad_ptr,
fwd_input_ptr, fwd_input_ptr,
merging_probs_ptr, merging_probs_ptr,
merging_probs_grad_ptr,
row_id_map_ptr, row_id_map_ptr,
# sizes pad_offsets_ptr,
num_experts: tl.constexpr,
hidden_size: tl.constexpr,
# strides # strides
stride_row_id_map_token, stride_row_id_map_token,
stride_row_id_map_expert, stride_row_id_map_expert,
...@@ -424,8 +452,14 @@ def _unpermute_bwd_with_merging_probs_kernel( ...@@ -424,8 +452,14 @@ def _unpermute_bwd_with_merging_probs_kernel(
stride_merging_probs_expert, stride_merging_probs_expert,
stride_merging_probs_grad_token, stride_merging_probs_grad_token,
stride_merging_probs_grad_expert, stride_merging_probs_grad_expert,
# output pointers
fwd_input_grad_ptr,
merging_probs_grad_ptr,
# metas # metas
num_experts: tl.constexpr,
hidden_size: tl.constexpr,
PROBS_LOAD_WIDTH: tl.constexpr, PROBS_LOAD_WIDTH: tl.constexpr,
FUSION_UNPAD: tl.constexpr,
BLOCK_SIZE: tl.constexpr, BLOCK_SIZE: tl.constexpr,
): ):
data_type = fwd_output_grad_ptr.dtype.element_ty data_type = fwd_output_grad_ptr.dtype.element_ty
...@@ -449,6 +483,9 @@ def _unpermute_bwd_with_merging_probs_kernel( ...@@ -449,6 +483,9 @@ def _unpermute_bwd_with_merging_probs_kernel(
+ pid * stride_row_id_map_token + pid * stride_row_id_map_token
+ (num_experts + idx) * stride_row_id_map_expert + (num_experts + idx) * stride_row_id_map_expert
) )
if FUSION_UNPAD:
pad_off = tl.load(pad_offsets_ptr + expert_idx)
dst_row = dst_row + pad_off
prob_grad_accum = tl.zeros((BLOCK_SIZE,), dtype=compute_type) prob_grad_accum = tl.zeros((BLOCK_SIZE,), dtype=compute_type)
current_start = 0 current_start = 0
while current_start < hidden_size: while current_start < hidden_size:
...@@ -546,14 +583,10 @@ def _make_chunk_sort_map_kernel( ...@@ -546,14 +583,10 @@ def _make_chunk_sort_map_kernel(
@triton.jit @triton.jit
def _sort_chunks_by_map_kernel( def _sort_chunks_by_map_kernel(
# pointers # input pointers
input_ptr, input_ptr,
output_ptr,
row_id_map_ptr, row_id_map_ptr,
probs_ptr, probs_ptr,
permuted_probs_ptr,
# sizes
hidden_size: tl.constexpr,
# strides # strides
stride_input_token, stride_input_token,
stride_input_hidden, stride_input_hidden,
...@@ -561,7 +594,11 @@ def _sort_chunks_by_map_kernel( ...@@ -561,7 +594,11 @@ def _sort_chunks_by_map_kernel(
stride_output_hidden, stride_output_hidden,
stride_probs_token, stride_probs_token,
stride_permuted_probs_token, stride_permuted_probs_token,
# output pointers
output_ptr,
permuted_probs_ptr,
# metas # metas
hidden_size: tl.constexpr,
PERMUTE_PROBS: tl.constexpr, PERMUTE_PROBS: tl.constexpr,
BLOCK_SIZE: tl.constexpr, BLOCK_SIZE: tl.constexpr,
FORWARD: tl.constexpr, FORWARD: tl.constexpr,
......
/************************************************************************* /*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* *
* See LICENSE for license information. * See LICENSE for license information.
************************************************************************/ ************************************************************************/
......
/************************************************************************* /*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* *
* See LICENSE for license information. * See LICENSE for license information.
************************************************************************/ ************************************************************************/
......
/************************************************************************* /*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* *
* See LICENSE for license information. * See LICENSE for license information.
************************************************************************/ ************************************************************************/
......
/************************************************************************* /*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* *
* See LICENSE for license information. * See LICENSE for license information.
************************************************************************/ ************************************************************************/
......
/************************************************************************* /*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* *
* See LICENSE for license information. * See LICENSE for license information.
************************************************************************/ ************************************************************************/
......
/************************************************************************* /*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* *
* See LICENSE for license information. * See LICENSE for license information.
************************************************************************/ ************************************************************************/
......
/************************************************************************* /*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* *
* See LICENSE for license information. * See LICENSE for license information.
************************************************************************/ ************************************************************************/
......
/************************************************************************* /*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* *
* See LICENSE for license information. * See LICENSE for license information.
************************************************************************/ ************************************************************************/
......
/************************************************************************* /*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* *
* See LICENSE for license information. * See LICENSE for license information.
************************************************************************/ ************************************************************************/
......
/************************************************************************* /*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* *
* See LICENSE for license information. * See LICENSE for license information.
************************************************************************/ ************************************************************************/
......
/************************************************************************* /*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* *
* See LICENSE for license information. * See LICENSE for license information.
************************************************************************/ ************************************************************************/
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment