Commit 0d874a4e authored by wenjh's avatar wenjh
Browse files

Merge branch 'nv_main' of v2.12

parents a68e5f87 dfdd3820
/************************************************************************* /*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* *
* See LICENSE for license information. * See LICENSE for license information.
************************************************************************/ ************************************************************************/
......
/************************************************************************* /*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* *
* See LICENSE for license information. * See LICENSE for license information.
************************************************************************/ ************************************************************************/
......
/************************************************************************* /*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* *
* See LICENSE for license information. * See LICENSE for license information.
************************************************************************/ ************************************************************************/
...@@ -129,7 +129,13 @@ void TeNormalizationPlan<KernelParamsType>::_build() { ...@@ -129,7 +129,13 @@ void TeNormalizationPlan<KernelParamsType>::_build() {
template <typename KernelParamsType> template <typename KernelParamsType>
std::vector<size_t> TeNormalizationPlan<KernelParamsType>::getWorkspaceShape() const { std::vector<size_t> TeNormalizationPlan<KernelParamsType>::getWorkspaceShape() const {
return {_launch_params.getTotalWorkspaceBytes(_is_layernorm)}; size_t workspace_size = _launch_params.getTotalWorkspaceBytes(_is_layernorm);
if (workspace_size == 0) {
// Workspace size must not be zero since that corresponds to a
// workspace size query
workspace_size = 1;
}
return {workspace_size};
} }
template <typename KernelParamsType> template <typename KernelParamsType>
...@@ -418,9 +424,15 @@ void CudnnNormalizationPlan::_build() { ...@@ -418,9 +424,15 @@ void CudnnNormalizationPlan::_build() {
std::vector<size_t> CudnnNormalizationPlan::getWorkspaceShape() const { std::vector<size_t> CudnnNormalizationPlan::getWorkspaceShape() const {
#ifdef USE_ROCM #ifdef USE_ROCM
assert(false); assert(false);
return {0}; return {1};
#else #else
return {static_cast<size_t>(_graph.get_workspace_size())}; size_t workspace_size = _graph.get_workspace_size();
if (workspace_size == 0) {
// Workspace size must not be zero since that corresponds to a
// workspace size query
workspace_size = 1;
}
return {workspace_size};
#endif #endif
} }
......
/************************************************************************* /*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* *
* See LICENSE for license information. * See LICENSE for license information.
************************************************************************/ ************************************************************************/
......
/************************************************************************* /*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* *
* See LICENSE for license information. * See LICENSE for license information.
************************************************************************/ ************************************************************************/
......
/************************************************************************* /*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* *
* See LICENSE for license information. * See LICENSE for license information.
************************************************************************/ ************************************************************************/
...@@ -27,10 +27,15 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size ...@@ -27,10 +27,15 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size
const float epsilon, Tensor* z, Tensor* mu, Tensor* rsigma, Tensor* workspace, const float epsilon, Tensor* z, Tensor* mu, Tensor* rsigma, Tensor* workspace,
const int multiprocessorCount, const bool zero_centered_gamma, const int multiprocessorCount, const bool zero_centered_gamma,
cudaStream_t stream) { cudaStream_t stream) {
// Check for unsupported configurations
if (is_fp8_dtype(z->data.dtype) && !is_delayed_tensor_scaling(z->scaling_mode) && if (is_fp8_dtype(z->data.dtype) && !is_delayed_tensor_scaling(z->scaling_mode) &&
!is_mxfp8_scaling(z->scaling_mode)) { !is_mxfp8_scaling(z->scaling_mode)) {
NVTE_ERROR("Not implemented scaling mode: " + to_string(z->scaling_mode) + "."); NVTE_ERROR("Not implemented scaling mode: " + to_string(z->scaling_mode) + ".");
} }
if (is_mxfp8_scaling(z->scaling_mode)) {
NVTE_CHECK(!z->with_gemm_swizzled_scales,
"MXFP8 output must have scales in compact format, not swizzled for GEMM.");
}
NVTE_CHECK(x.data.shape.size() == 2, "x must be 2D tensor."); NVTE_CHECK(x.data.shape.size() == 2, "x must be 2D tensor.");
NVTE_CHECK(gamma.data.shape == beta.data.shape, "Gamma and Beta must have the same shape."); NVTE_CHECK(gamma.data.shape == beta.data.shape, "Gamma and Beta must have the same shape.");
...@@ -51,7 +56,7 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size ...@@ -51,7 +56,7 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size
"RSigma must be 1D tensor with shape (x.shape[0],)."); "RSigma must be 1D tensor with shape (x.shape[0],).");
NVTE_CHECK(rsigma->data.dtype == DType::kFloat32, "RSigma must be a float32 tensor."); NVTE_CHECK(rsigma->data.dtype == DType::kFloat32, "RSigma must be a float32 tensor.");
if (!workspace->data.shape.empty()) { if (workspace->data.numel() != 0) {
CheckInputTensor(x, "x"); CheckInputTensor(x, "x");
CheckInputTensor(gamma, "gamma"); CheckInputTensor(gamma, "gamma");
CheckInputTensor(beta, "beta"); CheckInputTensor(beta, "beta");
...@@ -101,7 +106,7 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size ...@@ -101,7 +106,7 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size
multiprocessorCount, zero_centered_gamma, is_aligned, z->scaling_mode, training, multiprocessorCount, zero_centered_gamma, is_aligned, z->scaling_mode, training,
gamma_in_weight_dtype); gamma_in_weight_dtype);
if (workspace->data.shape.empty()) { if (workspace->data.numel() == 0) {
workspace->data.shape = plan->getWorkspaceShape(); workspace->data.shape = plan->getWorkspaceShape();
workspace->data.dtype = DType::kByte; workspace->data.dtype = DType::kByte;
return; return;
...@@ -153,7 +158,7 @@ void layernorm_bwd(const Tensor& dz, const Tensor& x, const Tensor& mu, const Te ...@@ -153,7 +158,7 @@ void layernorm_bwd(const Tensor& dz, const Tensor& x, const Tensor& mu, const Te
NVTE_CHECK(dbeta->data.shape == gamma.data.shape); NVTE_CHECK(dbeta->data.shape == gamma.data.shape);
NVTE_CHECK(dbeta->data.dtype == gamma.data.dtype); NVTE_CHECK(dbeta->data.dtype == gamma.data.dtype);
if (!workspace->data.shape.empty()) { if (workspace->data.numel() != 0) {
CheckInputTensor(dz, "dz"); CheckInputTensor(dz, "dz");
CheckInputTensor(x, "x"); CheckInputTensor(x, "x");
CheckInputTensor(mu, "mu"); CheckInputTensor(mu, "mu");
...@@ -186,7 +191,7 @@ void layernorm_bwd(const Tensor& dz, const Tensor& x, const Tensor& mu, const Te ...@@ -186,7 +191,7 @@ void layernorm_bwd(const Tensor& dz, const Tensor& x, const Tensor& mu, const Te
multiprocessorCount, zero_centered_gamma, is_aligned, NVTE_DELAYED_TENSOR_SCALING, true, multiprocessorCount, zero_centered_gamma, is_aligned, NVTE_DELAYED_TENSOR_SCALING, true,
gamma_in_weight_dtype); gamma_in_weight_dtype);
if (workspace->data.shape.empty()) { if (workspace->data.numel() == 0) {
workspace->data.shape = plan->getWorkspaceShape(); workspace->data.shape = plan->getWorkspaceShape();
workspace->data.dtype = DType::kByte; workspace->data.dtype = DType::kByte;
return; return;
......
/************************************************************************* /*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* *
* See LICENSE for license information. * See LICENSE for license information.
************************************************************************/ ************************************************************************/
......
/************************************************************************* /*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* *
* See LICENSE for license information. * See LICENSE for license information.
************************************************************************/ ************************************************************************/
......
/************************************************************************* /*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* *
* See LICENSE for license information. * See LICENSE for license information.
************************************************************************/ ************************************************************************/
......
/************************************************************************* /*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* *
* See LICENSE for license information. * See LICENSE for license information.
************************************************************************/ ************************************************************************/
......
/************************************************************************* /*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* *
* See LICENSE for license information. * See LICENSE for license information.
************************************************************************/ ************************************************************************/
...@@ -23,10 +23,15 @@ using namespace normalization; ...@@ -23,10 +23,15 @@ using namespace normalization;
void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tensor *z, void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tensor *z,
Tensor *rsigma, Tensor *workspace, const int multiprocessorCount, Tensor *rsigma, Tensor *workspace, const int multiprocessorCount,
const bool zero_centered_gamma, cudaStream_t stream) { const bool zero_centered_gamma, cudaStream_t stream) {
// Check for unsupported configurations
if (is_fp8_dtype(z->data.dtype) && !is_delayed_tensor_scaling(z->scaling_mode) && if (is_fp8_dtype(z->data.dtype) && !is_delayed_tensor_scaling(z->scaling_mode) &&
!is_mxfp8_scaling(z->scaling_mode)) { !is_mxfp8_scaling(z->scaling_mode)) {
NVTE_ERROR("Not implemented scaling mode: " + to_string(z->scaling_mode) + "."); NVTE_ERROR("Not implemented scaling mode: " + to_string(z->scaling_mode) + ".");
} }
if (is_mxfp8_scaling(z->scaling_mode)) {
NVTE_CHECK(!z->with_gemm_swizzled_scales,
"MXFP8 output must have scales in compact format, not swizzled for GEMM.");
}
NVTE_CHECK(x.data.shape.size() == 2, "x must be 2D tensor."); NVTE_CHECK(x.data.shape.size() == 2, "x must be 2D tensor.");
...@@ -39,7 +44,7 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens ...@@ -39,7 +44,7 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens
"RSigma must be 1D tensor with shape (x.shape[0],)."); "RSigma must be 1D tensor with shape (x.shape[0],).");
NVTE_CHECK(rsigma->data.dtype == DType::kFloat32, "RSigma must be a float32 tensor."); NVTE_CHECK(rsigma->data.dtype == DType::kFloat32, "RSigma must be a float32 tensor.");
if (!workspace->data.shape.empty()) { if (workspace->data.numel() != 0) {
CheckInputTensor(x, "x"); CheckInputTensor(x, "x");
CheckInputTensor(gamma, "gamma"); CheckInputTensor(gamma, "gamma");
...@@ -86,7 +91,7 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens ...@@ -86,7 +91,7 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens
multiprocessorCount, zero_centered_gamma, is_aligned, z->scaling_mode, training, multiprocessorCount, zero_centered_gamma, is_aligned, z->scaling_mode, training,
gamma_in_weight_dtype); gamma_in_weight_dtype);
if (workspace->data.shape.empty()) { if (workspace->data.numel() == 0) {
workspace->data.shape = plan->getWorkspaceShape(); workspace->data.shape = plan->getWorkspaceShape();
workspace->data.dtype = DType::kByte; workspace->data.dtype = DType::kByte;
return; return;
...@@ -132,7 +137,7 @@ void rmsnorm_bwd(const Tensor &dz, const Tensor &x, const Tensor &rsigma, const ...@@ -132,7 +137,7 @@ void rmsnorm_bwd(const Tensor &dz, const Tensor &x, const Tensor &rsigma, const
NVTE_CHECK(dgamma->data.shape == gamma.data.shape); NVTE_CHECK(dgamma->data.shape == gamma.data.shape);
NVTE_CHECK(dgamma->data.dtype == gamma.data.dtype); NVTE_CHECK(dgamma->data.dtype == gamma.data.dtype);
if (!workspace->data.shape.empty()) { if (workspace->data.numel() != 0) {
CheckInputTensor(dz, "dz"); CheckInputTensor(dz, "dz");
CheckInputTensor(x, "x"); CheckInputTensor(x, "x");
CheckInputTensor(rsigma, "rsigma"); CheckInputTensor(rsigma, "rsigma");
...@@ -163,7 +168,7 @@ void rmsnorm_bwd(const Tensor &dz, const Tensor &x, const Tensor &rsigma, const ...@@ -163,7 +168,7 @@ void rmsnorm_bwd(const Tensor &dz, const Tensor &x, const Tensor &rsigma, const
multiprocessorCount, zero_centered_gamma, is_aligned, NVTE_DELAYED_TENSOR_SCALING, true, multiprocessorCount, zero_centered_gamma, is_aligned, NVTE_DELAYED_TENSOR_SCALING, true,
gamma_in_weight_dtype); gamma_in_weight_dtype);
if (workspace->data.shape.empty()) { if (workspace->data.numel() == 0) {
workspace->data.shape = plan->getWorkspaceShape(); workspace->data.shape = plan->getWorkspaceShape();
workspace->data.dtype = DType::kByte; workspace->data.dtype = DType::kByte;
return; return;
...@@ -198,7 +203,7 @@ void rmsnorm_bwd_add(const Tensor &dz, const Tensor &x, const Tensor &add, const ...@@ -198,7 +203,7 @@ void rmsnorm_bwd_add(const Tensor &dz, const Tensor &x, const Tensor &add, const
NVTE_CHECK(dgamma->data.shape == gamma.data.shape); NVTE_CHECK(dgamma->data.shape == gamma.data.shape);
NVTE_CHECK(dgamma->data.dtype == gamma.data.dtype); NVTE_CHECK(dgamma->data.dtype == gamma.data.dtype);
if (!workspace->data.shape.empty()) { if (workspace->data.numel() != 0) {
CheckInputTensor(dz, "dz"); CheckInputTensor(dz, "dz");
CheckInputTensor(x, "x"); CheckInputTensor(x, "x");
CheckInputTensor(add, "add"); CheckInputTensor(add, "add");
...@@ -229,7 +234,7 @@ void rmsnorm_bwd_add(const Tensor &dz, const Tensor &x, const Tensor &add, const ...@@ -229,7 +234,7 @@ void rmsnorm_bwd_add(const Tensor &dz, const Tensor &x, const Tensor &add, const
multiprocessorCount, zero_centered_gamma, is_aligned, NVTE_DELAYED_TENSOR_SCALING, true, multiprocessorCount, zero_centered_gamma, is_aligned, NVTE_DELAYED_TENSOR_SCALING, true,
gamma_in_weight_dtype); gamma_in_weight_dtype);
if (workspace->data.shape.empty()) { if (workspace->data.numel() == 0) {
workspace->data.shape = plan->getWorkspaceShape(); workspace->data.shape = plan->getWorkspaceShape();
workspace->data.dtype = DType::kByte; workspace->data.dtype = DType::kByte;
return; return;
......
/************************************************************************* /*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* *
* See LICENSE for license information. * See LICENSE for license information.
************************************************************************/ ************************************************************************/
......
/************************************************************************* /*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* *
* See LICENSE for license information. * See LICENSE for license information.
************************************************************************/ ************************************************************************/
......
/************************************************************************* /*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* *
* See LICENSE for license information. * See LICENSE for license information.
************************************************************************/ ************************************************************************/
......
/************************************************************************* /*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* *
* See LICENSE for license information. * See LICENSE for license information.
************************************************************************/ ************************************************************************/
......
########################################################################## ##########################################################################
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# #
# See LICENSE for license information. # See LICENSE for license information.
########################################################################## ##########################################################################
......
/************************************************************************* /*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* *
* See LICENSE for license information. * See LICENSE for license information.
************************************************************************/ ************************************************************************/
......
/************************************************************************* /*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* *
* See LICENSE for license information. * See LICENSE for license information.
************************************************************************/ ************************************************************************/
......
/************************************************************************* /*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* *
* See LICENSE for license information. * See LICENSE for license information.
************************************************************************/ ************************************************************************/
......
/************************************************************************* /*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* *
* See LICENSE for license information. * See LICENSE for license information.
************************************************************************/ ************************************************************************/
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment