Commit 0d874a4e authored by wenjh

Merge branch 'nv_main' of v2.12

parents a68e5f87 dfdd3820
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
......@@ -37,7 +37,6 @@ from transformer_engine.pytorch import (
from transformer_engine.common import recipe
import transformer_engine_torch as tex
from transformer_engine.pytorch.cpp_extensions import general_gemm
from transformer_engine.pytorch.module.base import get_workspace
from transformer_engine.pytorch.tensor.utils import replace_raw_data
from utils import ModelConfig
......@@ -539,6 +538,7 @@ def test_sanity_grouped_linear(
@pytest.mark.parametrize("activation", all_activations)
@pytest.mark.parametrize("normalization", all_normalizations)
@pytest.mark.parametrize("microbatching", all_boolean)
@pytest.mark.parametrize("checkpoint", all_boolean)
def test_sanity_layernorm_mlp(
dtype,
fp8_recipe,
......@@ -549,6 +549,7 @@ def test_sanity_layernorm_mlp(
activation,
normalization,
microbatching,
checkpoint,
):
config = model_configs[model]
......@@ -579,6 +580,7 @@ def test_sanity_layernorm_mlp(
normalization=normalization,
params_dtype=dtype,
device="cuda",
checkpoint=checkpoint,
)
_test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad, microbatching)
......@@ -961,7 +963,7 @@ def test_sanity_gemm_with_unalignment(N, offset, datatype):
inp = torch.reshape(scratchpad[offset:-offset], (N, N))
weight = torch.reshape(scratchpad[offset * 2 :], (N, N))
_ = general_gemm(A=weight, B=inp, workspace=get_workspace())
_ = general_gemm(A=weight, B=inp)
torch.cuda.synchronize()
......@@ -985,7 +987,6 @@ def test_sanity_fp8_gemm_with_unalignment(N, datatype):
general_gemm(
weight_fp8,
inp_fp8,
get_workspace(),
outp_type,
bias=None,
use_split_accumulator=False,
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
......@@ -8,6 +8,7 @@ import logging
import os
from contextlib import contextmanager
from typing import Optional, Tuple, Dict, Any, List
from packaging.version import Version as PkgVersion
import torch
......@@ -210,6 +211,7 @@ class ModelConfig:
max_ctx_len: int = None,
num_layers: int = 1,
eps: float = 1e-5,
num_splits=1,
):
self.batch_size = batch_size
self.max_seqlen_q = max_seqlen_q
......@@ -239,6 +241,7 @@ class ModelConfig:
self.max_ctx_len = max_ctx_len
self.num_layers = num_layers
self.eps = eps
self.num_splits = num_splits
@contextmanager
......@@ -321,6 +324,9 @@ def get_available_attention_backends(
inference_params=inference_params,
softmax_type=config.softmax_type,
return_max_logit=config.return_max_logit,
# allow all backends to pass so they can be used for testing;
# check for FA3 availability later
num_splits=1,
)
(
use_flash_attention,
......@@ -330,6 +336,10 @@ def get_available_attention_backends(
use_unfused_attention,
available_backends,
) = get_attention_backend(attention_params)
# Check if FA3 is an available backend when num_splits != 1
if available_backends[0]:
if config.num_splits != 1 and not flash_attention_backend > PkgVersion("3.0.0b"):
available_backends[0] = False
# Set attention.py _attention_backends var using return value
# from get_attention_backend()
_attention_backends["use_flash_attention"] = use_flash_attention
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
......@@ -183,7 +183,6 @@ if(USE_CUDA)
list(APPEND transformer_engine_cuda_sources
common.cu
multi_tensor/adam.cu
multi_tensor/compute_scale.cu
multi_tensor/l2norm.cu
multi_tensor/scale.cu
multi_tensor/sgd.cu
......@@ -225,15 +224,20 @@ if(USE_CUDA)
comm_gemm_overlap/userbuffers/userbuffers.cu)
list(APPEND transformer_engine_cuda_arch_specific_sources
gemm/cutlass_grouped_gemm.cu
cast/cast.cu
activation/gelu.cu
activation/relu.cu
activation/swiglu.cu
transpose/quantize_transpose_square_blockwise.cu
transpose/quantize_transpose_vector_blockwise_fp4.cu
cast/cast.cu
gemm/cutlass_grouped_gemm.cu
hadamard_transform/group_hadamard_transform.cu
hadamard_transform/hadamard_transform.cu
hadamard_transform/hadamard_transform_cast_fusion.cu)
hadamard_transform/hadamard_transform_cast_fusion.cu
hadamard_transform/group_hadamard_transform_cast_fusion.cu
hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu
multi_tensor/compute_scale.cu
recipe/mxfp8_scaling.cu
transpose/quantize_transpose_square_blockwise.cu
transpose/quantize_transpose_vector_blockwise_fp4.cu)
# Compiling the files with the worst compilation time first to hopefully overlap
# better with the faster-compiling cpp files
......@@ -281,13 +285,42 @@ if(USE_CUDA)
endif()
add_library(transformer_engine SHARED ${transformer_engine_SOURCES})
target_include_directories(transformer_engine PUBLIC
"${CMAKE_CURRENT_SOURCE_DIR}/include")
# CUTLASS kernels require SM90a and cause hang in debug build
# Grouped GEMM kernels require SM90a
set_property(
SOURCE gemm/cutlass_grouped_gemm.cu
APPEND
PROPERTY
COMPILE_OPTIONS "--generate-code=arch=compute_90a,code=sm_90a;-g0")
COMPILE_OPTIONS "--generate-code=arch=compute_90a,code=sm_90a")
# CUTLASS kernels could cause hang in debug build
set(CUTLASS_KERNEL_SOURCES
gemm/cutlass_grouped_gemm.cu
hadamard_transform/group_hadamard_transform_cast_fusion.cu
hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu
hadamard_transform/hadamard_transform_cast_fusion.cu)
set_property(
SOURCE ${CUTLASS_KERNEL_SOURCES}
APPEND
PROPERTY
COMPILE_OPTIONS "-g0;-dopt=on")
# Configure dependencies
target_link_libraries(transformer_engine PUBLIC
CUDA::cublas
CUDA::cudart
CUDNN::cudnn_all)
target_include_directories(transformer_engine PRIVATE
${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
target_include_directories(transformer_engine SYSTEM PRIVATE
${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}/cccl)
target_include_directories(transformer_engine PRIVATE "${CUDNN_FRONTEND_INCLUDE_DIR}")
target_include_directories(transformer_engine PRIVATE
${CUTLASS_INCLUDE_DIR}
${CUTLASS_TOOLS_INCLUDE_DIR})
else()
list(APPEND transformer_engine_cpp_sources
cudnn_utils.cpp
......@@ -308,7 +341,6 @@ else()
list(APPEND transformer_engine_cuda_sources
common.cu
multi_tensor/adam.cu
multi_tensor/compute_scale.cu
multi_tensor/l2norm.cu
multi_tensor/scale.cu
multi_tensor/sgd.cu
......@@ -348,10 +380,12 @@ else()
comm_gemm_overlap/userbuffers/userbuffers.cu)
list(APPEND transformer_engine_cuda_arch_specific_sources
cast/cast.cu
activation/gelu.cu
activation/relu.cu
activation/swiglu.cu
cast/cast.cu
multi_tensor/compute_scale.cu
recipe/mxfp8_scaling.cu
transpose/quantize_transpose_square_blockwise.cu
transpose/quantize_transpose_vector_blockwise_fp4.cu)
......@@ -398,27 +432,9 @@ else()
message(STATUS "nvte hipified sources: ${te_hip_sources}")
add_library(transformer_engine SHARED ${te_hip_sources})
endif()
target_include_directories(transformer_engine PUBLIC
"${CMAKE_CURRENT_SOURCE_DIR}/include")
target_include_directories(transformer_engine PUBLIC
"${CMAKE_CURRENT_SOURCE_DIR}/include")
if (USE_CUDA)
# Configure dependencies
target_link_libraries(transformer_engine PUBLIC
CUDA::cublas
CUDA::cudart
CUDNN::cudnn_all)
target_include_directories(transformer_engine PRIVATE
${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
target_include_directories(transformer_engine SYSTEM PRIVATE
${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}/cccl)
target_include_directories(transformer_engine PRIVATE "${CUDNN_FRONTEND_INCLUDE_DIR}")
target_include_directories(transformer_engine PRIVATE
${CUTLASS_INCLUDE_DIR}
${CUTLASS_TOOLS_INCLUDE_DIR})
else()
target_include_directories(transformer_engine PUBLIC "${CMAKE_CURRENT_SOURCE_DIR}")
# Aotriton is currently unsupported
set(AotritonAndCk_fused_attn "unsupported")
......@@ -441,7 +457,6 @@ else()
target_link_libraries(transformer_engine PUBLIC ${transformer_engine_LINKER_LIBS})
endif()
# Compiling Userbuffers with native MPI bootstrapping requires linking against MPI
option(NVTE_UB_WITH_MPI "Bootstrap Userbuffers with MPI" OFF)
if (NVTE_UB_WITH_MPI)
......
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
......@@ -236,31 +236,6 @@ def _get_sys_extension() -> str:
raise RuntimeError(f"Unsupported operating system ({system})")
@functools.lru_cache(maxsize=None)
def _load_nvidia_cuda_library(lib_name: str):
"""
Attempts to load shared object file installed via pip.
`lib_name`: Name of package as found in the `nvidia` dir in python environment.
"""
so_paths = glob.glob(
os.path.join(
sysconfig.get_path("purelib"),
f"nvidia/{lib_name}/lib/lib*{_get_sys_extension()}.*[0-9]",
)
)
path_found = len(so_paths) > 0
ctypes_handles = []
if path_found:
for so_path in so_paths:
ctypes_handles.append(ctypes.CDLL(so_path, mode=ctypes.RTLD_GLOBAL))
return path_found, ctypes_handles
@functools.lru_cache(maxsize=None)
def _nvidia_cudart_include_dir() -> str:
"""Returns the include directory for cuda_runtime.h if exists in python environment."""
......@@ -280,102 +255,102 @@ def _nvidia_cudart_include_dir() -> str:
@functools.lru_cache(maxsize=None)
def _load_cudnn():
"""Load CUDNN shared library."""
def _load_cuda_library_from_python(lib_name: str, strict: bool = False):
"""
Attempts to load shared object file installed via python packages.
# Attempt to locate cuDNN in CUDNN_HOME or CUDNN_PATH, if either is set
cudnn_home = os.environ.get("CUDNN_HOME") or os.environ.get("CUDNN_PATH")
if cudnn_home:
libs = glob.glob(f"{cudnn_home}/**/libcudnn{_get_sys_extension()}*", recursive=True)
libs.sort(reverse=True, key=os.path.basename)
if libs:
return ctypes.CDLL(libs[0], mode=ctypes.RTLD_GLOBAL)
`lib_name` : Name of package as found in the `nvidia` dir in python environment.
`strict` : If set to `True`, throw an error if lib is not found.
"""
# Attempt to locate cuDNN in CUDA_HOME, CUDA_PATH or /usr/local/cuda
cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH") or "/usr/local/cuda"
libs = glob.glob(f"{cuda_home}/**/libcudnn{_get_sys_extension()}*", recursive=True)
libs.sort(reverse=True, key=os.path.basename)
if libs:
return ctypes.CDLL(libs[0], mode=ctypes.RTLD_GLOBAL)
ext = _get_sys_extension()
nvidia_dir = os.path.join(sysconfig.get_path("purelib"), "nvidia")
# Attempt to locate cuDNN in Python dist-packages
found, handle = _load_nvidia_cuda_library("cudnn")
if found:
return handle
# Shared libraries provided by NVIDIA PyPI packages can exist
# in 4 possible locations inside `nvidia`.
# Check them in order of priority.
path_found = False
if os.path.isdir(os.path.join(nvidia_dir, "cu13", lib_name)):
so_paths = glob.glob(os.path.join(nvidia_dir, "cu13", lib_name, f"lib/lib*{ext}.*[0-9]"))
path_found = len(so_paths) > 0
if not path_found and os.path.isdir(os.path.join(nvidia_dir, "cu13")):
so_paths = glob.glob(os.path.join(nvidia_dir, "cu13", f"lib/lib{lib_name}*{ext}.*[0-9]"))
path_found = len(so_paths) > 0
if not path_found and os.path.isdir(os.path.join(nvidia_dir, lib_name)):
so_paths = glob.glob(os.path.join(nvidia_dir, lib_name, f"lib/lib*{ext}.*[0-9]"))
path_found = len(so_paths) > 0
if not IS_HIP_EXTENSION:
# Attempt to locate libcudnn via ldconfig
libs = subprocess.check_output(["ldconfig", "-p"])
libs = libs.decode("utf-8").split("\n")
sos = []
for lib in libs:
if "libcudnn" in lib and "=>" in lib:
sos.append(lib.split(">")[1].strip())
if sos:
return ctypes.CDLL(sos[0], mode=ctypes.RTLD_GLOBAL)
if not path_found:
so_paths = glob.glob(os.path.join(nvidia_dir, f"cuda_{lib_name}", f"lib/lib*{ext}.*[0-9]"))
path_found = len(so_paths) > 0
# If all else fails, assume that it is in LD_LIBRARY_PATH and error out otherwise
return ctypes.CDLL(f"libcudnn{_get_sys_extension()}", mode=ctypes.RTLD_GLOBAL)
ctypes_handles = []
if path_found:
for so_path in so_paths:
ctypes_handles.append(ctypes.CDLL(so_path, mode=ctypes.RTLD_GLOBAL))
if strict and not path_found:
raise RuntimeError(f"{lib_name} shared object not found.")
return path_found, ctypes_handles
@functools.lru_cache(maxsize=None)
def _load_nvrtc():
"""Load NVRTC shared library."""
# Attempt to locate NVRTC in CUDA_HOME, CUDA_PATH or /usr/local/cuda
cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH") or "/usr/local/cuda"
libs = glob.glob(f"{cuda_home}/**/libnvrtc{_get_sys_extension()}*", recursive=True)
libs = list(filter(lambda x: not ("stub" in x or "libnvrtc-builtins" in x), libs))
libs.sort(reverse=True, key=os.path.basename)
if libs:
return ctypes.CDLL(libs[0], mode=ctypes.RTLD_GLOBAL)
# Attempt to locate NVRTC in Python dist-packages
found, handle = _load_nvidia_cuda_library("cuda_nvrtc")
if found:
return handle
def _load_cuda_library_from_system(lib_name: str):
"""
Attempts to load shared object file installed via system/cuda-toolkit.
`lib_name`: Name of library to load without extension or `lib` prefix.
"""
# Attempt to locate NVRTC via ldconfig
libs = subprocess.check_output(["ldconfig", "-p"])
libs = libs.decode("utf-8").split("\n")
sos = []
for lib in libs:
if "libnvrtc" in lib and "=>" in lib:
sos.append(lib.split(">")[1].strip())
if sos:
return ctypes.CDLL(sos[0], mode=ctypes.RTLD_GLOBAL)
# Where to look for the shared lib in decreasing order of preference.
paths = (
os.environ.get(f"{lib_name.upper()}_HOME"),
os.environ.get(f"{lib_name.upper()}_PATH"),
os.environ.get("CUDA_HOME"),
os.environ.get("CUDA_PATH"),
"/usr/local/cuda",
)
for path in paths:
if path is None:
continue
libs = glob.glob(f"{path}/**/lib{lib_name}{_get_sys_extension()}*", recursive=True)
libs = [lib for lib in libs if "stub" not in lib]
libs.sort(reverse=True, key=os.path.basename)
if libs:
return True, ctypes.CDLL(libs[0], mode=ctypes.RTLD_GLOBAL)
# If all else fails, assume that it is in LD_LIBRARY_PATH and error out otherwise
return ctypes.CDLL(f"libnvrtc{_get_sys_extension()}", mode=ctypes.RTLD_GLOBAL)
# Search in LD_LIBRARY_PATH.
try:
_lib_handle = ctypes.CDLL(f"lib{lib_name}{_get_sys_extension()}", mode=ctypes.RTLD_GLOBAL)
return True, _lib_handle
except OSError:
return False, None
@functools.lru_cache(maxsize=None)
def _load_curand():
"""Load cuRAND shared library."""
# Attempt to locate cuRAND in CUDA_HOME, CUDA_PATH or /usr/local/cuda
cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH") or "/usr/local/cuda"
libs = glob.glob(f"{cuda_home}/**/libcurand{_get_sys_extension()}*", recursive=True)
libs = list(filter(lambda x: not ("stub" in x), libs))
libs.sort(reverse=True, key=os.path.basename)
if libs:
return ctypes.CDLL(libs[0], mode=ctypes.RTLD_GLOBAL)
# Attempt to locate cuRAND in Python dist-packages
found, handle = _load_nvidia_cuda_library("curand")
def _load_cuda_library(lib_name: str):
"""
Load given shared library.
Prioritize loading from system/toolkit
before checking python packages.
"""
# Attempt to locate library in system.
found, handle = _load_cuda_library_from_system(lib_name)
if found:
return handle
return True, handle
# Attempt to locate cuRAND via ldconfig
libs = subprocess.check_output(["ldconfig", "-p"])
libs = libs.decode("utf-8").split("\n")
sos = []
for lib in libs:
if "libcurand" in lib and "=>" in lib:
sos.append(lib.split(">")[1].strip())
if sos:
return ctypes.CDLL(sos[0], mode=ctypes.RTLD_GLOBAL)
# Attempt to locate library in Python dist-packages.
found, handle = _load_cuda_library_from_python(lib_name)
if found:
return False, handle
# If all else fails, assume that it is in LD_LIBRARY_PATH and error out otherwise
return ctypes.CDLL(f"libcurand{_get_sys_extension()}", mode=ctypes.RTLD_GLOBAL)
raise RuntimeError(f"{lib_name} shared object not found.")
@functools.lru_cache(maxsize=None)
......@@ -387,11 +362,22 @@ def _load_core_library():
if "NVTE_PROJECT_BUILDING" not in os.environ or bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))):
try:
sanity_checks_for_pypi_installation()
_CUDNN_LIB_CTYPES = _load_cudnn()
_NVRTC_LIB_CTYPES = _load_nvrtc()
_CURAND_LIB_CTYPES = _load_curand()
_CUBLAS_LIB_CTYPES = _load_nvidia_cuda_library("cublas")
_CUDART_LIB_CTYPES = _load_nvidia_cuda_library("cuda_runtime")
# `_load_cuda_library` is used for libraries that must be loaded
# at runtime. Both system and PyPI packages are searched,
# and an error is thrown if neither is found.
_, _CUDNN_LIB_CTYPES = _load_cuda_library("cudnn")
system_nvrtc, _NVRTC_LIB_CTYPES = _load_cuda_library("nvrtc")
system_curand, _CURAND_LIB_CTYPES = _load_cuda_library("curand")
# This additional step is necessary to be able to install TE wheels
# and import TE (without any guards) in an environment where the CUDA
# toolkit might be absent.
load_libs_for_no_ctk = not system_nvrtc and not system_curand
if load_libs_for_no_ctk:
_CUBLAS_LIB_CTYPES = _load_cuda_library_from_python("cublas", strict=True)
_CUDART_LIB_CTYPES = _load_cuda_library_from_python("cudart", strict=True)
_CUDNN_ALL_LIB_CTYPES = _load_cuda_library_from_python("cudnn", strict=True)
# Needed to find the correct headers for NVRTC kernels.
if not os.getenv("NVTE_CUDA_INCLUDE_DIR") and _nvidia_cudart_include_dir():
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
......@@ -102,3 +102,18 @@ void nvte_multi_tensor_quantize(const NVTETensor *inputs, NVTETensor *outputs,
NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream, detail::get_compute_stream_event(s)));
}
}
// Group quantize assumes contiguous inputs and outputs in memory allocation
// TODO (zhongbo): find a better way to make it a more generalized API
void nvte_group_nvfp4_quantize_with_amax(const NVTETensor input, NVTETensor *outputs,
const size_t *split_sections, const size_t num_tensors,
const NVTEQuantizationConfig quant_config,
cudaStream_t stream) {
NVTE_API_CALL(nvte_group_nvfp4_quantize_with_amax);
using namespace transformer_engine;
constexpr bool IS_ACT = false;
dispatch::group_quantize_fwd_helper<IS_ACT, Empty, nullptr>(input, outputs, split_sections,
num_tensors, quant_config, stream);
}
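
Since the new group API assumes the split tensors share one contiguous allocation, a minimal hedged C++ usage sketch is given below. The tensor handles (`input`, `out0`, `out1`), the split sizes, the config handle, and the declaring header are assumptions made for illustration; they are not names taken from this commit.

// Hypothetical usage sketch only. Assumes `input` wraps one contiguous
// [rows, cols] buffer, `out0`/`out1` are pre-allocated NVFP4 output tensors
// whose first dimensions sum to `rows`, and `quant_config` was created
// beforehand; the declaring header is assumed to be cast.h.
#include <cuda_runtime.h>
#include <transformer_engine/cast.h>
void group_quantize_example(NVTETensor input, NVTETensor out0, NVTETensor out1,
                            NVTEQuantizationConfig quant_config, cudaStream_t stream) {
  constexpr size_t num_tensors = 2;
  NVTETensor outputs[num_tensors] = {out0, out1};
  // Rows of each split section, laid out contiguously within the input buffer
  const size_t split_sections[num_tensors] = {1024, 2048};
  nvte_group_nvfp4_quantize_with_amax(input, outputs, split_sections, num_tensors,
                                      quant_config, stream);
}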
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
......@@ -27,9 +27,9 @@ inline void dequantize_helper(const Tensor &input, Tensor *output, cudaStream_t
switch (input.scaling_mode) {
case NVTE_DELAYED_TENSOR_SCALING: {
NVTE_CHECK(is_fp8_dtype(input.data.dtype) || is_int8_dtype(input.data.dtype), "Input must have FP8 or INT8 type.");
NVTE_CHECK(!is_fp8_dtype(output->data.dtype) && !is_int8_dtype(output->data.dtype), "Output must be in higher precision.");
NVTE_CHECK(output->data.shape == input.data.shape, "Input and output shapes need to match.");
NVTE_CHECK(is_fp8_dtype(input.dtype()) || is_int8_dtype(input.dtype()), "Input must have FP8 or INT8 type.");
NVTE_CHECK(!is_fp8_dtype(output->dtype()) && !is_int8_dtype(output->dtype()), "Output must be in higher precision.");
NVTE_CHECK(output->shape() == input.shape(), "Input and output shapes need to match.");
fp8::dequantize(input, output, stream);
break;
}
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
......@@ -14,6 +14,7 @@
#include <transformer_engine/transformer_engine.h>
#include "../../common.h"
#include "../../transpose/transpose.h"
#include "../../utils.cuh"
#include "../fp8/gated_fp8.cuh"
#include "../mxfp8/gated_mxfp8.cuh"
......@@ -53,6 +54,20 @@ void quantize_gated_fwd_helper(const NVTETensor nvte_input, NVTETensor nvte_outp
} else {
fp8::cast_gated_fwd<ParamOP, ActOP>(input, output, p, stream);
}
if (is_fp8_dtype(output->dtype()) && output->has_columnwise_data()) {
// FP8 kernel only populates row-wise data, so perform
// transpose separately if needed
Tensor transpose_in, transpose_out, dummy;
transpose_in.scaling_mode = NVTE_DELAYED_TENSOR_SCALING;
transpose_in.data.dptr = output->data.dptr;
transpose_in.data.shape = {output->flat_first_dim(), output->flat_last_dim()};
transpose_in.data.dtype = output->data.dtype;
transpose_out.scaling_mode = NVTE_DELAYED_TENSOR_SCALING;
transpose_out.data.dptr = output->columnwise_data.dptr;
transpose_out.data.shape = {output->flat_last_dim(), output->flat_first_dim()};
transpose_out.data.dtype = output->data.dtype;
detail::transpose(transpose_in, /*noop=*/dummy, &transpose_out, stream);
}
break;
}
case NVTE_MXFP8_1D_SCALING: {
......@@ -98,8 +113,8 @@ void quantize_gated_bwd_helper(const NVTETensor nvte_grad, const NVTETensor nvte
const size_t rows = gated_input.flat_first_dim();
const size_t cols = gated_input.flat_last_dim() / 2;
NVTE_CHECK(!is_fp8_dtype(grad.data.dtype), "Grad input must be in higher precision.");
NVTE_CHECK(grad.data.dtype == gated_input.data.dtype, "Types of both inputs must match.");
NVTE_CHECK(!is_fp8_dtype(grad.dtype()), "Grad input must be in higher precision.");
NVTE_CHECK(grad.dtype() == gated_input.dtype(), "Types of both inputs must match.");
NVTE_CHECK(grad.flat_first_dim() == rows,
"Wrong Grad shape. Expected first dimension (after flattening) [", rows, ", *], got [",
......@@ -116,9 +131,9 @@ void quantize_gated_bwd_helper(const NVTETensor nvte_grad, const NVTETensor nvte
NVTE_CHECK(output->flat_last_dim() == cols * 2,
"Wrong output shape. Expected (after flattening) [*, ", cols * 2, "], got [",
output->flat_first_dim(), ", ", output->flat_last_dim(), "].");
NVTE_CHECK(gated_input.data.shape == output->data.shape,
"Gated input and output shapes must match. Input shape: ", gated_input.data.shape,
", output shape: ", output->data.shape, ".");
NVTE_CHECK(gated_input.shape() == output->shape(),
"Gated input and output shapes must match. Input shape: ", gated_input.shape(),
", output shape: ", output->shape(), ".");
switch (output->scaling_mode) {
case NVTE_DELAYED_TENSOR_SCALING: {
......@@ -129,6 +144,20 @@ void quantize_gated_bwd_helper(const NVTETensor nvte_grad, const NVTETensor nvte
} else {
fp8::cast_gated_bwd<ParamOP, ActOP, DActOP>(gated_input, grad, output, p, stream);
}
if (is_fp8_dtype(output->dtype()) && output->has_columnwise_data()) {
// FP8 kernel only populates row-wise data, so perform
// transpose separately if needed
Tensor transpose_in, transpose_out, dummy;
transpose_in.scaling_mode = NVTE_DELAYED_TENSOR_SCALING;
transpose_in.data.dptr = output->data.dptr;
transpose_in.data.shape = {output->flat_first_dim(), output->flat_last_dim()};
transpose_in.data.dtype = output->data.dtype;
transpose_out.scaling_mode = NVTE_DELAYED_TENSOR_SCALING;
transpose_out.data.dptr = output->columnwise_data.dptr;
transpose_out.data.shape = {output->flat_last_dim(), output->flat_first_dim()};
transpose_out.data.dtype = output->data.dtype;
detail::transpose(transpose_in, /*noop=*/dummy, &transpose_out, stream);
}
break;
}
case NVTE_MXFP8_1D_SCALING: {
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
......@@ -19,6 +19,7 @@
#include "../core/common.cuh"
#include "../fp8/quantize_fp8.cuh"
#include "../mxfp8/quantize_mxfp8.cuh"
#include "../nvfp4/group_quantize_transpose_nvfp4.cuh"
#include "../nvfp4/quantize_nvfp4.cuh"
#include "../nvfp4/quantize_transpose_nvfp4.cuh"
......@@ -154,17 +155,10 @@ void quantize_fwd_helper(const NVTETensor input, NVTETensor output,
FP8BlockwiseRowwiseOption rowwise_option = FP8BlockwiseRowwiseOption::NONE;
FP8BlockwiseColumnwiseOption columnwise_option = FP8BlockwiseColumnwiseOption::NONE;
if (output_tensor->has_data()) {
bool rowwise_compact = (quant_config_cpp.float8_block_scale_tensor_format ==
Float8BlockScaleTensorFormat::COMPACT);
rowwise_option = rowwise_compact ? FP8BlockwiseRowwiseOption::ROWWISE_COMPACT
: FP8BlockwiseRowwiseOption::ROWWISE_GEMM_READY;
rowwise_option = FP8BlockwiseRowwiseOption::ROWWISE_GEMM_READY;
}
if (output_tensor->has_columnwise_data()) {
bool columnwise_compact = (quant_config_cpp.float8_block_scale_tensor_format ==
Float8BlockScaleTensorFormat::COMPACT);
columnwise_option = columnwise_compact
? FP8BlockwiseColumnwiseOption::COLUMNWISE_COMPACT
: FP8BlockwiseColumnwiseOption::COLUMNWISE_GEMM_READY;
columnwise_option = FP8BlockwiseColumnwiseOption::COLUMNWISE_GEMM_READY;
}
quantize_transpose_vector_blockwise(
input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv,
......@@ -307,17 +301,10 @@ void quantize_bwd_helper(const NVTETensor grad, const NVTETensor input, NVTETens
FP8BlockwiseRowwiseOption rowwise_option = FP8BlockwiseRowwiseOption::NONE;
FP8BlockwiseColumnwiseOption columnwise_option = FP8BlockwiseColumnwiseOption::NONE;
if (output_tensor->has_data()) {
bool rowwise_compact = (quant_config_cpp.float8_block_scale_tensor_format ==
Float8BlockScaleTensorFormat::COMPACT);
rowwise_option = rowwise_compact ? FP8BlockwiseRowwiseOption::ROWWISE_COMPACT
: FP8BlockwiseRowwiseOption::ROWWISE_GEMM_READY;
rowwise_option = FP8BlockwiseRowwiseOption::ROWWISE_GEMM_READY;
}
if (output_tensor->has_columnwise_data()) {
bool columnwise_compact = (quant_config_cpp.float8_block_scale_tensor_format ==
Float8BlockScaleTensorFormat::COMPACT);
columnwise_option = columnwise_compact
? FP8BlockwiseColumnwiseOption::COLUMNWISE_COMPACT
: FP8BlockwiseColumnwiseOption::COLUMNWISE_GEMM_READY;
columnwise_option = FP8BlockwiseColumnwiseOption::COLUMNWISE_GEMM_READY;
}
quantize_transpose_vector_blockwise(
grad_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv,
......@@ -330,6 +317,70 @@ void quantize_bwd_helper(const NVTETensor grad, const NVTETensor input, NVTETens
}
}
template <bool IS_ACT, typename ParamOP, float (*OP)(float, const ParamOP &)>
void group_quantize_fwd_helper(const NVTETensor input, NVTETensor *outputs,
const size_t *split_sections, const size_t num_tensors,
const NVTEQuantizationConfig quant_config, cudaStream_t stream) {
using namespace detail;
const Tensor *input_tensor = convertNVTETensorCheck(input);
std::vector<Tensor *> output_tensors;
for (size_t i = 0; i < num_tensors; ++i) {
output_tensors.push_back(convertNVTETensorCheck(outputs[i]));
}
// Quantization config
QuantizationConfig quant_config_cpp;
if (quant_config != nullptr) {
quant_config_cpp = *reinterpret_cast<QuantizationConfig *>(quant_config);
}
// Noop flag
Tensor dummy_tensor;
Tensor *noop_tensor = &dummy_tensor;
if (quant_config_cpp.noop_tensor != nullptr) {
noop_tensor = convertNVTETensorCheck(quant_config_cpp.noop_tensor);
}
// Check for unsupported options
if (quant_config_cpp.stochastic_rounding) {
NVTE_CHECK(output_tensors[0]->scaling_mode == NVTE_NVFP4_1D_SCALING,
"Stochastic rounding is only supported for NVFP4 quantization.");
}
// Take the scaling mode of the first output tensor
auto scaling_mode = output_tensors[0]->scaling_mode;
// Dispatch to quantization kernel depending on data format
switch (scaling_mode) {
case NVTE_NVFP4_1D_SCALING: {
NVTE_CHECK(!IS_ACT, "IS_ACT is not supported by FWD NVTE_NVFP4_1D_SCALING");
// Check tensors
CheckNoopTensor(*noop_tensor, "cast_noop");
CheckInputTensor(*input_tensor, "input");
// Skip checking the output tensor list;
// the output list is allowed to contain empty tensors
// Choose kernel
int32_t rows = input_tensor->flat_first_dim();
int32_t cols = input_tensor->flat_last_dim();
auto dtype = input_tensor->dtype();
NVTE_CHECK(!quant_config_cpp.nvfp4_2d_quantization,
"2D quantization is not supported for group quantize.");
// Launch NVFP4 group quantize kernel
nvfp4::group_quantize_transpose</*use_2d_quantization*/ false>(
*input_tensor, noop_tensor, output_tensors, split_sections, num_tensors,
&quant_config_cpp, stream);
break;
}
default:
NVTE_ERROR("Not implemented scaling mode: " + to_string(scaling_mode) + ".");
}
}
} // namespace dispatch
} // namespace transformer_engine
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
......@@ -234,8 +234,7 @@ inline void dequantize(const Tensor &input, Tensor *output, cudaStream_t stream)
bool use_colwise_scaling = input.has_columnwise_data();
checkCuDriverContext(stream);
const auto &input_shape = input.data.shape;
NVTE_CHECK(input_shape.size() >= 2, "Input must have at least 2 dimensions.");
NVTE_CHECK(input.dim() >= 2, "Input must have at least 2 dimensions.");
if (use_rowwise_scaling) {
NVTE_CHECK(input.has_data(), "Cannot dequantize tensor without rowwise data.");
......@@ -247,8 +246,9 @@ inline void dequantize(const Tensor &input, Tensor *output, cudaStream_t stream)
NVTE_CHECK(is_fp8_dtype(input.columnwise_data.dtype), "Input must have FP8 type.");
}
NVTE_CHECK(!input.with_gemm_swizzled_scales, "Input must have scales in compact format.");
NVTE_CHECK(!is_fp8_dtype(output->data.dtype), "Output must be in higher precision.");
NVTE_CHECK(output->data.shape == input.data.shape, "Input and output shapes need to match.");
NVTE_CHECK(output->shape() == input.shape(), "Input and output shapes need to match.");
// TODO: Make more general
const size_t scale_dim_X_rowwise = use_rowwise_scaling ? 32 : 1;
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
......@@ -22,6 +22,7 @@
#include "../../util/math.h"
#include "../../util/ptx.cuh"
#include "../../utils.cuh"
#include "swizzle.cuh"
namespace transformer_engine {
namespace dispatch {
......@@ -54,7 +55,8 @@ constexpr size_t THREADS_PER_BANK = TOTAL_BANKS_WIDTH / SCALE_DIM_X; // 4 = 128
#ifndef __HIP_PLATFORM_AMD__
template <bool IS_BWD, typename ParamOP, float (*ActOP)(float, const ParamOP &),
float (*DActOP)(float, const ParamOP &), typename IType, typename OType,
bool ROWWISE_SCALING, bool COLWISE_SCALING, size_t THREADS_PER_CHUNK>
bool ROWWISE_SCALING, bool COLWISE_SCALING, bool WITH_GEMM_SWIZZLED_SCALES,
size_t THREADS_PER_CHUNK>
__global__ void __launch_bounds__(THREADS_PER_CHUNK)
quantize_gated_mxfp8_kernel(const __grid_constant__ CUtensorMap tensor_map_grad,
const __grid_constant__ CUtensorMap tensor_map_input_act,
......@@ -71,6 +73,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
using IType2 = typename ptx::FPx2<IType>;
using OType2 = typename ptx::FPx2<OType>;
using transformer_engine::dispatch::mxfp8::swizzle::gemm_swizzled_scale_idx;
constexpr size_t STAGES = CHUNK_DIM_Y / BUFF_DIM_Y;
static_assert(STAGES >= 1);
......@@ -358,14 +362,17 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
// 2. Compute E8M0 scaling factor
const e8m0_t biased_exponent_act =
ptx::float_to_e8m0(thread_amax_act * Quantized_Limits<OType>::max_norm_rcp);
const size_t global_scales_offset_Y = scales_offset_Y_colwise + stage;
const size_t global_scales_offset_X = scales_offset_X_colwise;
const size_t scale_idx =
global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X;
size_t scale_idx;
if constexpr (WITH_GEMM_SWIZZLED_SCALES) {
scale_idx = gemm_swizzled_scale_idx(global_scales_offset_X, global_scales_offset_Y,
DIVUP(rows, static_cast<size_t>(128)));
} else {
scale_idx = global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X;
}
const bool row_out_of_bounds_colwise = (row_base_colwise + stage_offset_Y) >= rows;
const bool out_of_bounds_colwise = row_out_of_bounds_colwise || col_out_of_bounds_colwise;
if (tid_Y_colwise == 0 && (!out_of_bounds_colwise)) {
scales_colwise[scale_idx] = biased_exponent_act;
}
......@@ -377,8 +384,14 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
const e8m0_t biased_exponent_gate =
ptx::float_to_e8m0(thread_amax_gate * Quantized_Limits<OType>::max_norm_rcp);
// const size_t scale_idx_gate = scale_idx + scale_stride_colwise / 2;
const size_t scale_idx_gate = scale_idx + gate_scale_idx_offset_colwise;
size_t scale_idx_gate;
if constexpr (WITH_GEMM_SWIZZLED_SCALES) {
scale_idx_gate = gemm_swizzled_scale_idx(
global_scales_offset_X + gate_scale_idx_offset_colwise, global_scales_offset_Y,
DIVUP(rows, static_cast<size_t>(128)));
} else {
scale_idx_gate = scale_idx + gate_scale_idx_offset_colwise;
}
if (tid_Y_colwise == 0 && (!out_of_bounds_colwise)) {
scales_colwise[scale_idx_gate] = biased_exponent_gate;
}
......@@ -560,7 +573,14 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
ptx::float_to_e8m0(thread_amax_act * Quantized_Limits<OType>::max_norm_rcp);
const size_t stage_scales_offset_Y = scales_offset_Y_rowwise + stage_offset_Y;
const size_t stage_scales_offset_X = scales_offset_X_rowwise;
const size_t scale_idx = stage_scales_offset_Y * scale_stride_rowwise + stage_scales_offset_X;
size_t scale_idx;
if constexpr (WITH_GEMM_SWIZZLED_SCALES) {
const size_t output_cols = (IS_BWD ? 2 : 1) * cols;
scale_idx = gemm_swizzled_scale_idx(stage_scales_offset_Y, stage_scales_offset_X,
DIVUP(output_cols, static_cast<size_t>(128)));
} else {
scale_idx = stage_scales_offset_Y * scale_stride_rowwise + stage_scales_offset_X;
}
const bool row_out_of_bounds_rowwise = (row_base_rowwise + stage_offset_Y) >= rows;
const bool out_of_bounds_rowwise = row_out_of_bounds_rowwise || col_out_of_bounds_rowwise;
if (!out_of_bounds_rowwise) {
......@@ -576,7 +596,16 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
if constexpr (IS_BWD) {
const e8m0_t biased_exponent_gate =
ptx::float_to_e8m0(thread_amax_gate * Quantized_Limits<OType>::max_norm_rcp);
const size_t scale_idx_gate = scale_idx + gate_scale_idx_offset_rowwise;
size_t scale_idx_gate;
if constexpr (WITH_GEMM_SWIZZLED_SCALES) {
const size_t output_cols = (IS_BWD ? 2 : 1) * cols;
scale_idx_gate = gemm_swizzled_scale_idx(
stage_scales_offset_Y, stage_scales_offset_X + gate_scale_idx_offset_rowwise,
DIVUP(output_cols, static_cast<size_t>(128)));
} else {
scale_idx_gate = scale_idx + gate_scale_idx_offset_rowwise;
}
if (!out_of_bounds_rowwise) {
scales_rowwise[scale_idx_gate] = biased_exponent_gate;
}
......@@ -670,7 +699,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
parity ^= 1;
destroy_barriers<STAGES>(mbar, is_master_thread);
#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
} // NOLINT(readability/fn_size)
#endif // __HIP_PLATFORM_AMD__
} // namespace gated_kernel
......@@ -686,6 +715,7 @@ void quantize_gated(const Tensor &gated_input, const Tensor &grad, Tensor *outpu
const bool USE_ROWWISE_SCALING = output->has_data();
const bool USE_COLWISE_SCALING = output->has_columnwise_data();
const bool with_gemm_swizzled_scales = output->with_gemm_swizzled_scales;
if (USE_ROWWISE_SCALING) {
NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated.");
......@@ -729,113 +759,140 @@ void quantize_gated(const Tensor &gated_input, const Tensor &grad, Tensor *outpu
gated_input.dtype(), IType,
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
output->dtype(), OType,
TRANSFORMER_ENGINE_SWITCH_CONDITION(
with_gemm_swizzled_scales, WITH_GEMM_SWIZZLED_SCALES,
alignas(64) CUtensorMap tensor_map_grad{};
alignas(64) CUtensorMap tensor_map_input_act{};
alignas(64) CUtensorMap tensor_map_input_gate{};
alignas(64) CUtensorMap tensor_map_output_act_rowwise{};
alignas(64) CUtensorMap tensor_map_output_gate_rowwise{};
alignas(64) CUtensorMap tensor_map_output_act_colwise{};
alignas(64) CUtensorMap tensor_map_output_gate_colwise{};
alignas(64) CUtensorMap tensor_map_grad{};
alignas(64) CUtensorMap tensor_map_input_act{};
alignas(64) CUtensorMap tensor_map_input_gate{};
alignas(64) CUtensorMap tensor_map_output_act_rowwise{};
alignas(64) CUtensorMap tensor_map_output_gate_rowwise{};
alignas(64) CUtensorMap tensor_map_output_act_colwise{};
alignas(64) CUtensorMap tensor_map_output_gate_colwise{};
constexpr size_t input_type_bit_size = TypeInfo<IType>::size;
constexpr size_t output_type_bit_size = TypeInfo<OType>::size;
constexpr size_t input_type_bit_size = TypeInfo<IType>::size;
constexpr size_t output_type_bit_size = TypeInfo<OType>::size;
if constexpr (IS_BWD) {
create_2D_tensor_map(tensor_map_grad, grad.data, rows, cols, BUFF_DIM_Y, BUFF_DIM_X,
cols, 0, input_type_bit_size);
}
if constexpr (IS_BWD) {
create_2D_tensor_map(tensor_map_grad, grad.data, rows, cols, BUFF_DIM_Y, BUFF_DIM_X,
cols, 0, input_type_bit_size);
}
const uint32_t tensor_stride_elems = output_cols;
create_2D_tensor_map(tensor_map_input_act, gated_input.data, rows, cols, BUFF_DIM_Y,
BUFF_DIM_X, cols * 2, 0, input_type_bit_size);
create_2D_tensor_map(tensor_map_input_gate, gated_input.data, rows, cols, BUFF_DIM_Y,
BUFF_DIM_X, cols * 2, cols, input_type_bit_size);
if (USE_ROWWISE_SCALING) {
create_2D_tensor_map(tensor_map_output_act_rowwise, output->data, rows, cols,
BUFF_DIM_Y, BUFF_DIM_X, tensor_stride_elems, 0,
output_type_bit_size);
create_2D_tensor_map(tensor_map_output_gate_rowwise, output->data, rows, cols,
BUFF_DIM_Y, BUFF_DIM_X, tensor_stride_elems, cols,
output_type_bit_size);
}
const uint32_t tensor_stride_elems = output_cols;
create_2D_tensor_map(tensor_map_input_act, gated_input.data, rows, cols, BUFF_DIM_Y,
BUFF_DIM_X, cols * 2, 0, input_type_bit_size);
create_2D_tensor_map(tensor_map_input_gate, gated_input.data, rows, cols, BUFF_DIM_Y,
BUFF_DIM_X, cols * 2, cols, input_type_bit_size);
if (USE_ROWWISE_SCALING) {
create_2D_tensor_map(tensor_map_output_act_rowwise, output->data, rows, cols,
BUFF_DIM_Y, BUFF_DIM_X, tensor_stride_elems, 0,
output_type_bit_size);
create_2D_tensor_map(tensor_map_output_gate_rowwise, output->data, rows, cols,
BUFF_DIM_Y, BUFF_DIM_X, tensor_stride_elems, cols,
output_type_bit_size);
}
if (USE_COLWISE_SCALING) {
create_2D_tensor_map(tensor_map_output_act_colwise, output->columnwise_data, rows, cols,
BUFF_DIM_Y, BUFF_DIM_X, tensor_stride_elems, 0,
output_type_bit_size);
create_2D_tensor_map(tensor_map_output_gate_colwise, output->columnwise_data, rows,
cols, BUFF_DIM_Y, BUFF_DIM_X, tensor_stride_elems, cols,
output_type_bit_size);
}
if (USE_COLWISE_SCALING) {
create_2D_tensor_map(tensor_map_output_act_colwise, output->columnwise_data, rows,
cols, BUFF_DIM_Y, BUFF_DIM_X, tensor_stride_elems, 0,
output_type_bit_size);
create_2D_tensor_map(tensor_map_output_gate_colwise, output->columnwise_data, rows,
cols, BUFF_DIM_Y, BUFF_DIM_X, tensor_stride_elems, cols,
output_type_bit_size);
}
const size_t buff_elems_total = BUFFS_NUM * BUFF_DIM_Y * BUFF_DIM_X;
const size_t input_buff_size = (buff_elems_total * input_type_bit_size) / 8;
const size_t output_buff_size = (buff_elems_total * output_type_bit_size) / 8;
const size_t buff_size_aligned_in =
DIVUP_TO_MULTIPLE(input_buff_size, TMA_SHMEM_ALIGNMENT);
const size_t buff_size_aligned_out =
DIVUP_TO_MULTIPLE(output_buff_size, TMA_SHMEM_ALIGNMENT);
const size_t grad_mem = (IS_BWD ? buff_size_aligned_in : 0);
const size_t in_act_mem = buff_size_aligned_in;
const size_t in_gate_mem = buff_size_aligned_in;
const size_t in_mem = grad_mem + in_act_mem + in_gate_mem;
const size_t out_act_mem = buff_size_aligned_out;
const size_t out_gate_mem = (IS_BWD ? buff_size_aligned_out : 0);
size_t out_mem = out_act_mem + out_gate_mem;
if (USE_ROWWISE_SCALING && USE_COLWISE_SCALING) { out_mem *= 2; }
const size_t shmem_size = in_mem + out_mem + TMA_SHMEM_ALIGNMENT;
switch (scaling_type) {
case ScalingType::ROWWISE: {
auto kernel =
quantize_gated_mxfp8_kernel<IS_BWD, ParamOP, ActOP, DActOP, IType, OType, true,
false, THREADS_PER_CHUNK_NON_COLWISE>;
NVTE_CHECK_CUDA(cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size));
kernel<<<grid, block_size, shmem_size, stream>>>(
tensor_map_grad, tensor_map_input_act, tensor_map_input_gate,
tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise,
tensor_map_output_act_colwise, tensor_map_output_gate_colwise, scales_rowwise_ptr,
scales_colwise_ptr, rows, cols, scale_stride_rowwise, scale_stride_colwise, p);
break;
}
case ScalingType::COLWISE: {
auto kernel =
quantize_gated_mxfp8_kernel<IS_BWD, ParamOP, ActOP, DActOP, IType, OType, false,
true, THREADS_PER_CHUNK_COLWISE>;
NVTE_CHECK_CUDA(cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size));
kernel<<<grid, block_size, shmem_size, stream>>>(
tensor_map_grad, tensor_map_input_act, tensor_map_input_gate,
tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise,
tensor_map_output_act_colwise, tensor_map_output_gate_colwise, scales_rowwise_ptr,
scales_colwise_ptr, rows, cols, scale_stride_rowwise, scale_stride_colwise, p);
break;
}
case ScalingType::BIDIMENSIONAL: {
auto kernel =
quantize_gated_mxfp8_kernel<IS_BWD, ParamOP, ActOP, DActOP, IType, OType, true,
true, THREADS_PER_CHUNK_NON_COLWISE>;
NVTE_CHECK_CUDA(cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size));
kernel<<<grid, block_size, shmem_size, stream>>>(
tensor_map_grad, tensor_map_input_act, tensor_map_input_gate,
tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise,
tensor_map_output_act_colwise, tensor_map_output_gate_colwise, scales_rowwise_ptr,
scales_colwise_ptr, rows, cols, scale_stride_rowwise, scale_stride_colwise, p);
break;
}
} NVTE_CHECK_CUDA(cudaGetLastError());); // NOLINT(*)
); // NOLINT(*)
const size_t buff_elems_total = BUFFS_NUM * BUFF_DIM_Y * BUFF_DIM_X;
const size_t input_buff_size = (buff_elems_total * input_type_bit_size) / 8;
const size_t output_buff_size = (buff_elems_total * output_type_bit_size) / 8;
const size_t buff_size_aligned_in =
DIVUP_TO_MULTIPLE(input_buff_size, TMA_SHMEM_ALIGNMENT);
const size_t buff_size_aligned_out =
DIVUP_TO_MULTIPLE(output_buff_size, TMA_SHMEM_ALIGNMENT);
const size_t grad_mem = (IS_BWD ? buff_size_aligned_in : 0);
const size_t in_act_mem = buff_size_aligned_in;
const size_t in_gate_mem = buff_size_aligned_in;
const size_t in_mem = grad_mem + in_act_mem + in_gate_mem;
const size_t out_act_mem = buff_size_aligned_out;
const size_t out_gate_mem = (IS_BWD ? buff_size_aligned_out : 0);
size_t out_mem = out_act_mem + out_gate_mem;
if (USE_ROWWISE_SCALING && USE_COLWISE_SCALING) { out_mem *= 2; }
const size_t shmem_size = in_mem + out_mem + TMA_SHMEM_ALIGNMENT;
// Zero out swizzled scales if padding is needed
/// TODO (tmoon) Handle this within the cast kernel
if (with_gemm_swizzled_scales) {
constexpr size_t TILE_DIM_X = 128; // Tile dim in data buffer
constexpr size_t TILE_DIM_Y = 128;
if (cols % TILE_DIM_X != 0 || rows % TILE_DIM_Y != 0) {
if (USE_ROWWISE_SCALING) {
NVTE_CHECK_CUDA(cudaMemsetAsync(output->scale_inv.dptr, 0,
output->scale_inv.buffer_size_bytes(), stream));
}
if (USE_COLWISE_SCALING) {
NVTE_CHECK_CUDA(
cudaMemsetAsync(output->columnwise_scale_inv.dptr, 0,
output->columnwise_scale_inv.buffer_size_bytes(), stream));
}
}
}
switch (scaling_type) {
case ScalingType::ROWWISE: {
auto kernel =
quantize_gated_mxfp8_kernel<IS_BWD, ParamOP, ActOP, DActOP, IType, OType,
true, false, WITH_GEMM_SWIZZLED_SCALES,
THREADS_PER_CHUNK_NON_COLWISE>;
NVTE_CHECK_CUDA(cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size));
kernel<<<grid, block_size, shmem_size, stream>>>(
tensor_map_grad, tensor_map_input_act, tensor_map_input_gate,
tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise,
tensor_map_output_act_colwise, tensor_map_output_gate_colwise,
scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise,
scale_stride_colwise, p);
break;
}
case ScalingType::COLWISE: {
auto kernel =
quantize_gated_mxfp8_kernel<IS_BWD, ParamOP, ActOP, DActOP, IType, OType,
false, true, WITH_GEMM_SWIZZLED_SCALES,
THREADS_PER_CHUNK_COLWISE>;
NVTE_CHECK_CUDA(cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size));
kernel<<<grid, block_size, shmem_size, stream>>>(
tensor_map_grad, tensor_map_input_act, tensor_map_input_gate,
tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise,
tensor_map_output_act_colwise, tensor_map_output_gate_colwise,
scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise,
scale_stride_colwise, p);
break;
}
case ScalingType::BIDIMENSIONAL: {
auto kernel =
quantize_gated_mxfp8_kernel<IS_BWD, ParamOP, ActOP, DActOP, IType, OType,
true, true, WITH_GEMM_SWIZZLED_SCALES,
THREADS_PER_CHUNK_NON_COLWISE>;
NVTE_CHECK_CUDA(cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size));
kernel<<<grid, block_size, shmem_size, stream>>>(
tensor_map_grad, tensor_map_input_act, tensor_map_input_gate,
tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise,
tensor_map_output_act_colwise, tensor_map_output_gate_colwise,
scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise,
scale_stride_colwise, p);
break;
}
} NVTE_CHECK_CUDA(cudaGetLastError());); // NOLINT(*)
); // NOLINT(*)
); // NOLINT(*)
#endif
}
......