Commit 0d874a4e authored by wenjh's avatar wenjh
Browse files

Merge branch 'nv_main' of v2.12

parents a68e5f87 dfdd3820
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
......@@ -129,7 +129,13 @@ void TeNormalizationPlan<KernelParamsType>::_build() {
template <typename KernelParamsType>
// Returns the workspace shape (in bytes) required by this TE normalization plan.
// The size is clamped to a minimum of 1 byte: a zero-sized workspace tensor is
// reserved as the sentinel for a "workspace size query" call, so a plan that
// genuinely needs no scratch space must still report a non-zero size.
std::vector<size_t> TeNormalizationPlan<KernelParamsType>::getWorkspaceShape() const {
  size_t workspace_size = _launch_params.getTotalWorkspaceBytes(_is_layernorm);
  if (workspace_size == 0) {
    // Workspace size must not be zero since that corresponds to a
    // workspace size query
    workspace_size = 1;
  }
  return {workspace_size};
}
template <typename KernelParamsType>
......@@ -418,9 +424,15 @@ void CudnnNormalizationPlan::_build() {
// Returns the workspace shape (in bytes) required by the cuDNN graph backend.
// The size is clamped to a minimum of 1 byte: a zero-sized workspace tensor is
// reserved as the sentinel for a "workspace size query" call.
// On ROCm builds this backend is unsupported; assert and return the minimal
// non-zero size so the sentinel convention is still respected.
std::vector<size_t> CudnnNormalizationPlan::getWorkspaceShape() const {
#ifdef USE_ROCM
  assert(false);
  return {1};
#else
  size_t workspace_size = _graph.get_workspace_size();
  if (workspace_size == 0) {
    // Workspace size must not be zero since that corresponds to a
    // workspace size query
    workspace_size = 1;
  }
  return {workspace_size};
#endif
}
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
......@@ -27,10 +27,15 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size
const float epsilon, Tensor* z, Tensor* mu, Tensor* rsigma, Tensor* workspace,
const int multiprocessorCount, const bool zero_centered_gamma,
cudaStream_t stream) {
// Check for unsupported configurations
if (is_fp8_dtype(z->data.dtype) && !is_delayed_tensor_scaling(z->scaling_mode) &&
!is_mxfp8_scaling(z->scaling_mode)) {
NVTE_ERROR("Not implemented scaling mode: " + to_string(z->scaling_mode) + ".");
}
if (is_mxfp8_scaling(z->scaling_mode)) {
NVTE_CHECK(!z->with_gemm_swizzled_scales,
"MXFP8 output must have scales in compact format, not swizzled for GEMM.");
}
NVTE_CHECK(x.data.shape.size() == 2, "x must be 2D tensor.");
NVTE_CHECK(gamma.data.shape == beta.data.shape, "Gamma and Beta must have the same shape.");
......@@ -51,7 +56,7 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size
"RSigma must be 1D tensor with shape (x.shape[0],).");
NVTE_CHECK(rsigma->data.dtype == DType::kFloat32, "RSigma must be a float32 tensor.");
if (!workspace->data.shape.empty()) {
if (workspace->data.numel() != 0) {
CheckInputTensor(x, "x");
CheckInputTensor(gamma, "gamma");
CheckInputTensor(beta, "beta");
......@@ -101,7 +106,7 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size
multiprocessorCount, zero_centered_gamma, is_aligned, z->scaling_mode, training,
gamma_in_weight_dtype);
if (workspace->data.shape.empty()) {
if (workspace->data.numel() == 0) {
workspace->data.shape = plan->getWorkspaceShape();
workspace->data.dtype = DType::kByte;
return;
......@@ -153,7 +158,7 @@ void layernorm_bwd(const Tensor& dz, const Tensor& x, const Tensor& mu, const Te
NVTE_CHECK(dbeta->data.shape == gamma.data.shape);
NVTE_CHECK(dbeta->data.dtype == gamma.data.dtype);
if (!workspace->data.shape.empty()) {
if (workspace->data.numel() != 0) {
CheckInputTensor(dz, "dz");
CheckInputTensor(x, "x");
CheckInputTensor(mu, "mu");
......@@ -186,7 +191,7 @@ void layernorm_bwd(const Tensor& dz, const Tensor& x, const Tensor& mu, const Te
multiprocessorCount, zero_centered_gamma, is_aligned, NVTE_DELAYED_TENSOR_SCALING, true,
gamma_in_weight_dtype);
if (workspace->data.shape.empty()) {
if (workspace->data.numel() == 0) {
workspace->data.shape = plan->getWorkspaceShape();
workspace->data.dtype = DType::kByte;
return;
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
......@@ -23,10 +23,15 @@ using namespace normalization;
void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tensor *z,
Tensor *rsigma, Tensor *workspace, const int multiprocessorCount,
const bool zero_centered_gamma, cudaStream_t stream) {
// Check for unsupported configurations
if (is_fp8_dtype(z->data.dtype) && !is_delayed_tensor_scaling(z->scaling_mode) &&
!is_mxfp8_scaling(z->scaling_mode)) {
NVTE_ERROR("Not implemented scaling mode: " + to_string(z->scaling_mode) + ".");
}
if (is_mxfp8_scaling(z->scaling_mode)) {
NVTE_CHECK(!z->with_gemm_swizzled_scales,
"MXFP8 output must have scales in compact format, not swizzled for GEMM.");
}
NVTE_CHECK(x.data.shape.size() == 2, "x must be 2D tensor.");
......@@ -39,7 +44,7 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens
"RSigma must be 1D tensor with shape (x.shape[0],).");
NVTE_CHECK(rsigma->data.dtype == DType::kFloat32, "RSigma must be a float32 tensor.");
if (!workspace->data.shape.empty()) {
if (workspace->data.numel() != 0) {
CheckInputTensor(x, "x");
CheckInputTensor(gamma, "gamma");
......@@ -86,7 +91,7 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens
multiprocessorCount, zero_centered_gamma, is_aligned, z->scaling_mode, training,
gamma_in_weight_dtype);
if (workspace->data.shape.empty()) {
if (workspace->data.numel() == 0) {
workspace->data.shape = plan->getWorkspaceShape();
workspace->data.dtype = DType::kByte;
return;
......@@ -132,7 +137,7 @@ void rmsnorm_bwd(const Tensor &dz, const Tensor &x, const Tensor &rsigma, const
NVTE_CHECK(dgamma->data.shape == gamma.data.shape);
NVTE_CHECK(dgamma->data.dtype == gamma.data.dtype);
if (!workspace->data.shape.empty()) {
if (workspace->data.numel() != 0) {
CheckInputTensor(dz, "dz");
CheckInputTensor(x, "x");
CheckInputTensor(rsigma, "rsigma");
......@@ -163,7 +168,7 @@ void rmsnorm_bwd(const Tensor &dz, const Tensor &x, const Tensor &rsigma, const
multiprocessorCount, zero_centered_gamma, is_aligned, NVTE_DELAYED_TENSOR_SCALING, true,
gamma_in_weight_dtype);
if (workspace->data.shape.empty()) {
if (workspace->data.numel() == 0) {
workspace->data.shape = plan->getWorkspaceShape();
workspace->data.dtype = DType::kByte;
return;
......@@ -198,7 +203,7 @@ void rmsnorm_bwd_add(const Tensor &dz, const Tensor &x, const Tensor &add, const
NVTE_CHECK(dgamma->data.shape == gamma.data.shape);
NVTE_CHECK(dgamma->data.dtype == gamma.data.dtype);
if (!workspace->data.shape.empty()) {
if (workspace->data.numel() != 0) {
CheckInputTensor(dz, "dz");
CheckInputTensor(x, "x");
CheckInputTensor(add, "add");
......@@ -229,7 +234,7 @@ void rmsnorm_bwd_add(const Tensor &dz, const Tensor &x, const Tensor &add, const
multiprocessorCount, zero_centered_gamma, is_aligned, NVTE_DELAYED_TENSOR_SCALING, true,
gamma_in_weight_dtype);
if (workspace->data.shape.empty()) {
if (workspace->data.numel() == 0) {
workspace->data.shape = plan->getWorkspaceShape();
workspace->data.dtype = DType::kByte;
return;
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
......
##########################################################################
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
##########################################################################
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment