Commit 0d874a4e authored by wenjh

Merge branch 'nv_main' of v2.12

parents a68e5f87 dfdd3820
-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
@@ -37,7 +37,6 @@ from transformer_engine.pytorch import (
from transformer_engine.common import recipe
import transformer_engine_torch as tex
from transformer_engine.pytorch.cpp_extensions import general_gemm
-from transformer_engine.pytorch.module.base import get_workspace
from transformer_engine.pytorch.tensor.utils import replace_raw_data
from utils import ModelConfig
@@ -539,6 +538,7 @@ def test_sanity_grouped_linear(
@pytest.mark.parametrize("activation", all_activations)
@pytest.mark.parametrize("normalization", all_normalizations)
@pytest.mark.parametrize("microbatching", all_boolean)
+@pytest.mark.parametrize("checkpoint", all_boolean)
def test_sanity_layernorm_mlp(
dtype,
fp8_recipe,
@@ -549,6 +549,7 @@ def test_sanity_layernorm_mlp(
activation,
normalization,
microbatching,
+checkpoint,
):
config = model_configs[model]
@@ -579,6 +580,7 @@ def test_sanity_layernorm_mlp(
normalization=normalization,
params_dtype=dtype,
device="cuda",
+checkpoint=checkpoint,
)
_test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad, microbatching)
@@ -961,7 +963,7 @@ def test_sanity_gemm_with_unalignment(N, offset, datatype):
inp = torch.reshape(scratchpad[offset:-offset], (N, N))
weight = torch.reshape(scratchpad[offset * 2 :], (N, N))
-_ = general_gemm(A=weight, B=inp, workspace=get_workspace())
+_ = general_gemm(A=weight, B=inp)
torch.cuda.synchronize()
@@ -985,7 +987,6 @@ def test_sanity_fp8_gemm_with_unalignment(N, datatype):
general_gemm(
weight_fp8,
inp_fp8,
-get_workspace(),
outp_type,
bias=None,
use_split_accumulator=False,
...
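The new `checkpoint` axis above only forwards a boolean into the LayerNormMLP constructor, so the sanity test now exercises both the checkpointed and non-checkpointed forward/backward paths. A minimal stand-alone sketch of the same pattern (the hidden sizes and the smoke-test body are illustrative, not taken from the test file):

import pytest
import torch
import transformer_engine.pytorch as te

@pytest.mark.parametrize("checkpoint", [False, True])
def test_layernorm_mlp_checkpoint_smoke(checkpoint):
    # Build the block with activation checkpointing toggled on or off,
    # mirroring the parametrization added in the hunk above.
    block = te.LayerNormMLP(
        hidden_size=128,
        ffn_hidden_size=512,
        params_dtype=torch.float32,
        device="cuda",
        checkpoint=checkpoint,
    )
    x = torch.randn(4, 128, device="cuda", requires_grad=True)
    y = block(x)
    y.sum().backward()  # exercises the recomputation path when checkpoint=True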
-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
...
-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
@@ -8,6 +8,7 @@ import logging
import os
from contextlib import contextmanager
from typing import Optional, Tuple, Dict, Any, List
+from packaging.version import Version as PkgVersion
import torch
@@ -210,6 +211,7 @@ class ModelConfig:
max_ctx_len: int = None,
num_layers: int = 1,
eps: float = 1e-5,
+num_splits=1,
):
self.batch_size = batch_size
self.max_seqlen_q = max_seqlen_q
@@ -239,6 +241,7 @@ class ModelConfig:
self.max_ctx_len = max_ctx_len
self.num_layers = num_layers
self.eps = eps
+self.num_splits = num_splits
@contextmanager
@@ -321,6 +324,9 @@ def get_available_attention_backends(
inference_params=inference_params,
softmax_type=config.softmax_type,
return_max_logit=config.return_max_logit,
+# allow all backends to pass so they can be used for testing;
+# check for FA3 availability later
+num_splits=1,
)
(
use_flash_attention,
@@ -330,6 +336,10 @@ def get_available_attention_backends(
use_unfused_attention,
available_backends,
) = get_attention_backend(attention_params)
+# Check if FA3 is an available backend when num_splits != 1
+if available_backends[0]:
+    if config.num_splits != 1 and not flash_attention_backend > PkgVersion("3.0.0b"):
+        available_backends[0] = False
# Set attention.py _attention_backends var using return value
# from get_attention_backend()
_attention_backends["use_flash_attention"] = use_flash_attention
...
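The FA3 gate above compares the reported flash-attention backend version against a 3.x pre-release marker via packaging's Version. A small self-contained illustration of that comparison (the helper name and version strings are examples, not values taken from the diff):

from packaging.version import Version as PkgVersion

def supports_num_splits(flash_attention_backend, num_splits):
    # num_splits != 1 is only honored when a FlashAttention 3 backend
    # (version > 3.0.0b) is available; otherwise the backend is ruled out.
    if num_splits == 1:
        return True
    return flash_attention_backend is not None and flash_attention_backend > PkgVersion("3.0.0b")

assert supports_num_splits(PkgVersion("3.0.0b1"), 2) is True
assert supports_num_splits(PkgVersion("2.7.3"), 2) is False
assert supports_num_splits(PkgVersion("2.7.3"), 1) is True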
-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
...
-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
@@ -183,7 +183,6 @@ if(USE_CUDA)
list(APPEND transformer_engine_cuda_sources
common.cu
multi_tensor/adam.cu
-multi_tensor/compute_scale.cu
multi_tensor/l2norm.cu
multi_tensor/scale.cu
multi_tensor/sgd.cu
@@ -225,15 +224,20 @@ if(USE_CUDA)
comm_gemm_overlap/userbuffers/userbuffers.cu)
list(APPEND transformer_engine_cuda_arch_specific_sources
-gemm/cutlass_grouped_gemm.cu
-cast/cast.cu
activation/gelu.cu
activation/relu.cu
activation/swiglu.cu
-transpose/quantize_transpose_square_blockwise.cu
+cast/cast.cu
-transpose/quantize_transpose_vector_blockwise_fp4.cu
+gemm/cutlass_grouped_gemm.cu
+hadamard_transform/group_hadamard_transform.cu
hadamard_transform/hadamard_transform.cu
-hadamard_transform/hadamard_transform_cast_fusion.cu)
+hadamard_transform/hadamard_transform_cast_fusion.cu
+hadamard_transform/group_hadamard_transform_cast_fusion.cu
+hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu
+multi_tensor/compute_scale.cu
+recipe/mxfp8_scaling.cu
+transpose/quantize_transpose_square_blockwise.cu
+transpose/quantize_transpose_vector_blockwise_fp4.cu)
# Compiling the files with the worst compilation time first to hopefully overlap
# better with the faster-compiling cpp files
@@ -281,13 +285,42 @@ if(USE_CUDA)
endif()
add_library(transformer_engine SHARED ${transformer_engine_SOURCES})
+target_include_directories(transformer_engine PUBLIC
+"${CMAKE_CURRENT_SOURCE_DIR}/include")
-# CUTLASS kernels require SM90a and cause hang in debug build
+# Grouped GEMM kernels require SM90a
set_property(
SOURCE gemm/cutlass_grouped_gemm.cu
APPEND
PROPERTY
-COMPILE_OPTIONS "--generate-code=arch=compute_90a,code=sm_90a;-g0")
+COMPILE_OPTIONS "--generate-code=arch=compute_90a,code=sm_90a")
+# CUTLASS kernels could cause hang in debug build
+set(CUTLASS_KERNEL_SOURCES
+gemm/cutlass_grouped_gemm.cu
+hadamard_transform/group_hadamard_transform_cast_fusion.cu
+hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu
+hadamard_transform/hadamard_transform_cast_fusion.cu)
+set_property(
+SOURCE ${CUTLASS_KERNEL_SOURCES}
+APPEND
+PROPERTY
+COMPILE_OPTIONS "-g0;-dopt=on")
+# Configure dependencies
+target_link_libraries(transformer_engine PUBLIC
+CUDA::cublas
+CUDA::cudart
+CUDNN::cudnn_all)
+target_include_directories(transformer_engine PRIVATE
+${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
+target_include_directories(transformer_engine SYSTEM PRIVATE
+${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}/cccl)
+target_include_directories(transformer_engine PRIVATE "${CUDNN_FRONTEND_INCLUDE_DIR}")
+target_include_directories(transformer_engine PRIVATE
+${CUTLASS_INCLUDE_DIR}
+${CUTLASS_TOOLS_INCLUDE_DIR})
else()
list(APPEND transformer_engine_cpp_sources
cudnn_utils.cpp
@@ -308,7 +341,6 @@ else()
list(APPEND transformer_engine_cuda_sources
common.cu
multi_tensor/adam.cu
-multi_tensor/compute_scale.cu
multi_tensor/l2norm.cu
multi_tensor/scale.cu
multi_tensor/sgd.cu
@@ -348,10 +380,12 @@ else()
comm_gemm_overlap/userbuffers/userbuffers.cu)
list(APPEND transformer_engine_cuda_arch_specific_sources
-cast/cast.cu
activation/gelu.cu
activation/relu.cu
activation/swiglu.cu
+cast/cast.cu
+multi_tensor/compute_scale.cu
+recipe/mxfp8_scaling.cu
transpose/quantize_transpose_square_blockwise.cu
transpose/quantize_transpose_vector_blockwise_fp4.cu)
@@ -398,27 +432,9 @@ else()
message(STATUS "nvte hipified sources: ${te_hip_sources}")
add_library(transformer_engine SHARED ${te_hip_sources})
-endif()
+target_include_directories(transformer_engine PUBLIC
+"${CMAKE_CURRENT_SOURCE_DIR}/include")
-target_include_directories(transformer_engine PUBLIC
-"${CMAKE_CURRENT_SOURCE_DIR}/include")
-if (USE_CUDA)
-# Configure dependencies
-target_link_libraries(transformer_engine PUBLIC
-CUDA::cublas
-CUDA::cudart
-CUDNN::cudnn_all)
-target_include_directories(transformer_engine PRIVATE
-${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
-target_include_directories(transformer_engine SYSTEM PRIVATE
-${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}/cccl)
-target_include_directories(transformer_engine PRIVATE "${CUDNN_FRONTEND_INCLUDE_DIR}")
-target_include_directories(transformer_engine PRIVATE
-${CUTLASS_INCLUDE_DIR}
-${CUTLASS_TOOLS_INCLUDE_DIR})
-else()
target_include_directories(transformer_engine PUBLIC "${CMAKE_CURRENT_SOURCE_DIR}")
# Aotriton is currently unsupported
set(AotritonAndCk_fused_attn "unsupported")
@@ -441,7 +457,6 @@ else()
target_link_libraries(transformer_engine PUBLIC ${transformer_engine_LINKER_LIBS})
endif()
# Compiling Userbuffers with native MPI bootstrapping requires linking against MPI
option(NVTE_UB_WITH_MPI "Bootstrap Userbuffers with MPI" OFF)
if (NVTE_UB_WITH_MPI)
...
-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
@@ -236,31 +236,6 @@ def _get_sys_extension() -> str:
raise RuntimeError(f"Unsupported operating system ({system})")
-@functools.lru_cache(maxsize=None)
-def _load_nvidia_cuda_library(lib_name: str):
-"""
-Attempts to load shared object file installed via pip.
-`lib_name`: Name of package as found in the `nvidia` dir in python environment.
-"""
-so_paths = glob.glob(
-os.path.join(
-sysconfig.get_path("purelib"),
-f"nvidia/{lib_name}/lib/lib*{_get_sys_extension()}.*[0-9]",
-)
-)
-path_found = len(so_paths) > 0
-ctypes_handles = []
-if path_found:
-for so_path in so_paths:
-ctypes_handles.append(ctypes.CDLL(so_path, mode=ctypes.RTLD_GLOBAL))
-return path_found, ctypes_handles
@functools.lru_cache(maxsize=None)
def _nvidia_cudart_include_dir() -> str:
"""Returns the include directory for cuda_runtime.h if exists in python environment."""
@@ -280,102 +255,102 @@ def _nvidia_cudart_include_dir() -> str:
@functools.lru_cache(maxsize=None)
-def _load_cudnn():
-"""Load CUDNN shared library."""
-# Attempt to locate cuDNN in CUDNN_HOME or CUDNN_PATH, if either is set
-cudnn_home = os.environ.get("CUDNN_HOME") or os.environ.get("CUDNN_PATH")
-if cudnn_home:
-libs = glob.glob(f"{cudnn_home}/**/libcudnn{_get_sys_extension()}*", recursive=True)
-libs.sort(reverse=True, key=os.path.basename)
-if libs:
-return ctypes.CDLL(libs[0], mode=ctypes.RTLD_GLOBAL)
-# Attempt to locate cuDNN in CUDA_HOME, CUDA_PATH or /usr/local/cuda
-cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH") or "/usr/local/cuda"
-libs = glob.glob(f"{cuda_home}/**/libcudnn{_get_sys_extension()}*", recursive=True)
-libs.sort(reverse=True, key=os.path.basename)
-if libs:
-return ctypes.CDLL(libs[0], mode=ctypes.RTLD_GLOBAL)
-# Attempt to locate cuDNN in Python dist-packages
-found, handle = _load_nvidia_cuda_library("cudnn")
-if found:
-return handle
-if not IS_HIP_EXTENSION:
-# Attempt to locate libcudnn via ldconfig
-libs = subprocess.check_output(["ldconfig", "-p"])
-libs = libs.decode("utf-8").split("\n")
-sos = []
-for lib in libs:
-if "libcudnn" in lib and "=>" in lib:
-sos.append(lib.split(">")[1].strip())
-if sos:
-return ctypes.CDLL(sos[0], mode=ctypes.RTLD_GLOBAL)
-# If all else fails, assume that it is in LD_LIBRARY_PATH and error out otherwise
-return ctypes.CDLL(f"libcudnn{_get_sys_extension()}", mode=ctypes.RTLD_GLOBAL)
+def _load_cuda_library_from_python(lib_name: str, strict: bool = False):
+"""
+Attempts to load shared object file installed via python packages.
+`lib_name` : Name of package as found in the `nvidia` dir in python environment.
+`strict` : If set to `True`, throw an error if lib is not found.
+"""
+ext = _get_sys_extension()
+nvidia_dir = os.path.join(sysconfig.get_path("purelib"), "nvidia")
+# PyPI packages provided by nvidia libs exist
+# in 4 possible locations inside `nvidia`.
+# Check by order of priority.
+path_found = False
+if os.path.isdir(os.path.join(nvidia_dir, "cu13", lib_name)):
+so_paths = glob.glob(os.path.join(nvidia_dir, "cu13", lib_name, f"lib/lib*{ext}.*[0-9]"))
+path_found = len(so_paths) > 0
+if not path_found and os.path.isdir(os.path.join(nvidia_dir, "cu13")):
+so_paths = glob.glob(os.path.join(nvidia_dir, "cu13", f"lib/lib{lib_name}*{ext}.*[0-9]"))
+path_found = len(so_paths) > 0
+if not path_found and os.path.isdir(os.path.join(nvidia_dir, lib_name)):
+so_paths = glob.glob(os.path.join(nvidia_dir, lib_name, f"lib/lib*{ext}.*[0-9]"))
+path_found = len(so_paths) > 0
+if not path_found:
+so_paths = glob.glob(os.path.join(nvidia_dir, f"cuda_{lib_name}", f"lib/lib*{ext}.*[0-9]"))
+path_found = len(so_paths) > 0
+ctypes_handles = []
+if path_found:
+for so_path in so_paths:
+ctypes_handles.append(ctypes.CDLL(so_path, mode=ctypes.RTLD_GLOBAL))
+if strict and not path_found:
+raise RuntimeError(f"{lib_name} shared object not found.")
+return path_found, ctypes_handles
@functools.lru_cache(maxsize=None)
-def _load_nvrtc():
-"""Load NVRTC shared library."""
-# Attempt to locate NVRTC in CUDA_HOME, CUDA_PATH or /usr/local/cuda
-cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH") or "/usr/local/cuda"
-libs = glob.glob(f"{cuda_home}/**/libnvrtc{_get_sys_extension()}*", recursive=True)
-libs = list(filter(lambda x: not ("stub" in x or "libnvrtc-builtins" in x), libs))
-libs.sort(reverse=True, key=os.path.basename)
-if libs:
-return ctypes.CDLL(libs[0], mode=ctypes.RTLD_GLOBAL)
-# Attempt to locate NVRTC in Python dist-packages
-found, handle = _load_nvidia_cuda_library("cuda_nvrtc")
-if found:
-return handle
-# Attempt to locate NVRTC via ldconfig
-libs = subprocess.check_output(["ldconfig", "-p"])
-libs = libs.decode("utf-8").split("\n")
-sos = []
-for lib in libs:
-if "libnvrtc" in lib and "=>" in lib:
-sos.append(lib.split(">")[1].strip())
-if sos:
-return ctypes.CDLL(sos[0], mode=ctypes.RTLD_GLOBAL)
-# If all else fails, assume that it is in LD_LIBRARY_PATH and error out otherwise
-return ctypes.CDLL(f"libnvrtc{_get_sys_extension()}", mode=ctypes.RTLD_GLOBAL)
+def _load_cuda_library_from_system(lib_name: str):
+"""
+Attempts to load shared object file installed via system/cuda-toolkit.
+`lib_name`: Name of library to load without extension or `lib` prefix.
+"""
+# Where to look for the shared lib in decreasing order of preference.
+paths = (
+os.environ.get(f"{lib_name.upper()}_HOME"),
+os.environ.get(f"{lib_name.upper()}_PATH"),
+os.environ.get("CUDA_HOME"),
+os.environ.get("CUDA_PATH"),
+"/usr/local/cuda",
+)
+for path in paths:
+if path is None:
+continue
+libs = glob.glob(f"{path}/**/lib{lib_name}{_get_sys_extension()}*", recursive=True)
+libs = [lib for lib in libs if "stub" not in lib]
+libs.sort(reverse=True, key=os.path.basename)
+if libs:
+return True, ctypes.CDLL(libs[0], mode=ctypes.RTLD_GLOBAL)
+# Search in LD_LIBRARY_PATH.
+try:
+_lib_handle = ctypes.CDLL(f"lib{lib_name}{_get_sys_extension()}", mode=ctypes.RTLD_GLOBAL)
+return True, _lib_handle
+except OSError:
+return False, None
@functools.lru_cache(maxsize=None)
-def _load_curand():
-"""Load cuRAND shared library."""
-# Attempt to locate cuRAND in CUDA_HOME, CUDA_PATH or /usr/local/cuda
-cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH") or "/usr/local/cuda"
-libs = glob.glob(f"{cuda_home}/**/libcurand{_get_sys_extension()}*", recursive=True)
-libs = list(filter(lambda x: not ("stub" in x), libs))
-libs.sort(reverse=True, key=os.path.basename)
-if libs:
-return ctypes.CDLL(libs[0], mode=ctypes.RTLD_GLOBAL)
-# Attempt to locate cuRAND in Python dist-packages
-found, handle = _load_nvidia_cuda_library("curand")
-if found:
-return handle
-# Attempt to locate cuRAND via ldconfig
-libs = subprocess.check_output(["ldconfig", "-p"])
-libs = libs.decode("utf-8").split("\n")
-sos = []
-for lib in libs:
-if "libcurand" in lib and "=>" in lib:
-sos.append(lib.split(">")[1].strip())
-if sos:
-return ctypes.CDLL(sos[0], mode=ctypes.RTLD_GLOBAL)
-# If all else fails, assume that it is in LD_LIBRARY_PATH and error out otherwise
-return ctypes.CDLL(f"libcurand{_get_sys_extension()}", mode=ctypes.RTLD_GLOBAL)
+def _load_cuda_library(lib_name: str):
+"""
+Load given shared library.
+Prioritize loading from system/toolkit
+before checking python packages.
+"""
+# Attempt to locate library in system.
+found, handle = _load_cuda_library_from_system(lib_name)
+if found:
+return True, handle
+# Attempt to locate library in Python dist-packages.
+found, handle = _load_cuda_library_from_python(lib_name)
+if found:
+return False, handle
+raise RuntimeError(f"{lib_name} shared object not found.")
@functools.lru_cache(maxsize=None)
@@ -387,11 +362,22 @@ def _load_core_library():
if "NVTE_PROJECT_BUILDING" not in os.environ or bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))):
try:
sanity_checks_for_pypi_installation()
-_CUDNN_LIB_CTYPES = _load_cudnn()
-_NVRTC_LIB_CTYPES = _load_nvrtc()
-_CURAND_LIB_CTYPES = _load_curand()
-_CUBLAS_LIB_CTYPES = _load_nvidia_cuda_library("cublas")
-_CUDART_LIB_CTYPES = _load_nvidia_cuda_library("cuda_runtime")
+# `_load_cuda_library` is used for packages that must be loaded
+# during runtime. Both system and pypi packages are searched
+# and an error is thrown if not found.
+_, _CUDNN_LIB_CTYPES = _load_cuda_library("cudnn")
+system_nvrtc, _NVRTC_LIB_CTYPES = _load_cuda_library("nvrtc")
+system_curand, _CURAND_LIB_CTYPES = _load_cuda_library("curand")
+# This additional step is necessary to be able to install TE wheels
+# and import TE (without any guards) in an environment where the cuda
+# toolkit might be absent without being guarded
+load_libs_for_no_ctk = not system_nvrtc and not system_curand
+if load_libs_for_no_ctk:
+_CUBLAS_LIB_CTYPES = _load_cuda_library_from_python("cublas", strict=True)
+_CUDART_LIB_CTYPES = _load_cuda_library_from_python("cudart", strict=True)
+_CUDNN_ALL_LIB_CTYPES = _load_cuda_library_from_python("cudnn", strict=True)
# Needed to find the correct headers for NVRTC kernels.
if not os.getenv("NVTE_CUDA_INCLUDE_DIR") and _nvidia_cudart_include_dir():
...
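The restructured loaders above follow one pattern: try the system/toolkit locations first, then fall back to the shared objects shipped in the `nvidia` Python packages, and only fail when neither is present. A condensed, self-contained sketch of that search order (paths and package layout are illustrative and Linux-only; the real helpers in the diff handle more cases, such as the `cu13` layout and platform-specific extensions):

import ctypes
import glob
import os
import sysconfig

def load_shared_lib(lib_name: str):
    # 1) System / CUDA toolkit locations, newest-looking file first.
    for root in (os.environ.get("CUDA_HOME"), "/usr/local/cuda"):
        if not root:
            continue
        libs = sorted(glob.glob(f"{root}/**/lib{lib_name}.so*", recursive=True), reverse=True)
        if libs:
            return ctypes.CDLL(libs[0], mode=ctypes.RTLD_GLOBAL)
    # 2) Wheels installed under site-packages/nvidia/<pkg>/lib.
    nvidia_dir = os.path.join(sysconfig.get_path("purelib"), "nvidia")
    libs = glob.glob(os.path.join(nvidia_dir, "*", "lib", f"lib{lib_name}.so*"))
    if libs:
        return ctypes.CDLL(libs[0], mode=ctypes.RTLD_GLOBAL)
    # 3) Last resort: whatever the dynamic loader can find via LD_LIBRARY_PATH.
    return ctypes.CDLL(f"lib{lib_name}.so", mode=ctypes.RTLD_GLOBAL)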
/*************************************************************************
- * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/
...
/*************************************************************************
- * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/
...
/*************************************************************************
- * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/
...
/*************************************************************************
- * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/
...
/*************************************************************************
- * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/
@@ -102,3 +102,18 @@ void nvte_multi_tensor_quantize(const NVTETensor *inputs, NVTETensor *outputs,
NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream, detail::get_compute_stream_event(s)));
}
}
+// Group quantize assumes contiguous inputs and outputs in memory allocation
+// TODO (zhongbo): find a better way to make it a more generalized API
+void nvte_group_nvfp4_quantize_with_amax(const NVTETensor input, NVTETensor *outputs,
+const size_t *split_sections, const size_t num_tensors,
+const NVTEQuantizationConfig quant_config,
+cudaStream_t stream) {
+NVTE_API_CALL(nvte_group_nvfp4_quantize_with_amax);
+using namespace transformer_engine;
+constexpr bool IS_ACT = false;
+dispatch::group_quantize_fwd_helper<IS_ACT, Empty, nullptr>(input, outputs, split_sections,
+num_tensors, quant_config, stream);
+}
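The comment on the new entry point stresses that the grouped quantize expects one contiguous allocation that is logically split into `num_tensors` sections along the first dimension. That layout can be pictured with plain torch slicing (the tensor size and section lengths below are arbitrary examples, not values from the API):

import torch

# One contiguous buffer holding all group members back to back.
full_input = torch.randn(32, 256, dtype=torch.bfloat16)
split_sections = [6, 10, 16]  # rows contributed by each group member

# Views into the shared allocation; no copies are made, matching the
# "contiguous inputs" assumption of the grouped quantize entry point.
views = torch.split(full_input, split_sections, dim=0)
assert sum(split_sections) == full_input.shape[0]
assert all(v.data_ptr() >= full_input.data_ptr() for v in views)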
/*************************************************************************
- * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/
...
/*************************************************************************
- * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/
@@ -27,9 +27,9 @@ inline void dequantize_helper(const Tensor &input, Tensor *output, cudaStream_t
switch (input.scaling_mode) {
case NVTE_DELAYED_TENSOR_SCALING: {
-NVTE_CHECK(is_fp8_dtype(input.data.dtype) || is_int8_dtype(input.data.dtype), "Input must have FP8 or INT8 type.");
+NVTE_CHECK(is_fp8_dtype(input.dtype()) || is_int8_dtype(input.dtype()), "Input must have FP8 type.");
-NVTE_CHECK(!is_fp8_dtype(output->data.dtype) && !is_int8_dtype(output->data.dtype), "Output must be in higher precision.");
+NVTE_CHECK(!is_fp8_dtype(output->dtype()) && !is_int8_dtype(output->dtype()), "Output must be in higher precision.");
-NVTE_CHECK(output->data.shape == input.data.shape, "Input and output shapes need to match.");
+NVTE_CHECK(output->shape() == input.shape(), "Input and output shapes need to match.");
fp8::dequantize(input, output, stream);
break;
}
...
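For orientation, delayed-tensor-scaling dequantization is an elementwise upcast followed by multiplication with the stored inverse scale; the checks above only enforce that the input is a narrow type (FP8 or INT8), the output is a wider type, and the shapes match. A tiny numerical illustration of the same round trip using INT8 in torch (the values and the symmetric-scaling choice are illustrative):

import torch

x_hp = torch.randn(4, 8)                      # higher-precision reference
amax = x_hp.abs().max()
scale = 127.0 / amax                          # map the observed amax onto the INT8 range
x_q = torch.clamp((x_hp * scale).round(), -127, 127).to(torch.int8)
scale_inv = 1.0 / scale

x_deq = x_q.to(torch.float32) * scale_inv     # dequantize: upcast, then apply the inverse scale
assert x_deq.shape == x_hp.shape              # shapes must match, as the NVTE_CHECKs require
assert (x_deq - x_hp).abs().max() <= scale_inv / 2 + 1e-6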
/*************************************************************************
- * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/
@@ -14,6 +14,7 @@
#include <transformer_engine/transformer_engine.h>
#include "../../common.h"
+#include "../../transpose/transpose.h"
#include "../../utils.cuh"
#include "../fp8/gated_fp8.cuh"
#include "../mxfp8/gated_mxfp8.cuh"
@@ -53,6 +54,20 @@ void quantize_gated_fwd_helper(const NVTETensor nvte_input, NVTETensor nvte_outp
} else {
fp8::cast_gated_fwd<ParamOP, ActOP>(input, output, p, stream);
}
+if (is_fp8_dtype(output->dtype()) && output->has_columnwise_data()) {
+// FP8 kernel only populates row-wise data, so perform
+// transpose separately if needed
+Tensor transpose_in, transpose_out, dummy;
+transpose_in.scaling_mode = NVTE_DELAYED_TENSOR_SCALING;
+transpose_in.data.dptr = output->data.dptr;
+transpose_in.data.shape = {output->flat_first_dim(), output->flat_last_dim()};
+transpose_in.data.dtype = output->data.dtype;
+transpose_out.scaling_mode = NVTE_DELAYED_TENSOR_SCALING;
+transpose_out.data.dptr = output->columnwise_data.dptr;
+transpose_out.data.shape = {output->flat_last_dim(), output->flat_first_dim()};
+transpose_out.data.dtype = output->data.dtype;
+detail::transpose(transpose_in, /*noop=*/dummy, &transpose_out, stream);
+}
break;
}
case NVTE_MXFP8_1D_SCALING: {
@@ -98,8 +113,8 @@ void quantize_gated_bwd_helper(const NVTETensor nvte_grad, const NVTETensor nvte
const size_t rows = gated_input.flat_first_dim();
const size_t cols = gated_input.flat_last_dim() / 2;
-NVTE_CHECK(!is_fp8_dtype(grad.data.dtype), "Grad input must be in higher precision.");
+NVTE_CHECK(!is_fp8_dtype(grad.dtype()), "Grad input must be in higher precision.");
-NVTE_CHECK(grad.data.dtype == gated_input.data.dtype, "Types of both inputs must match.");
+NVTE_CHECK(grad.dtype() == gated_input.dtype(), "Types of both inputs must match.");
NVTE_CHECK(grad.flat_first_dim() == rows,
"Wrong Grad shape. Expected first dimension (after flattening) [", rows, ", *], got [",
@@ -116,9 +131,9 @@ void quantize_gated_bwd_helper(const NVTETensor nvte_grad, const NVTETensor nvte
NVTE_CHECK(output->flat_last_dim() == cols * 2,
"Wrong output shape. Expected (after flattening) [*, ", cols * 2, "], got [",
output->flat_first_dim(), ", ", output->flat_last_dim(), "].");
-NVTE_CHECK(gated_input.data.shape == output->data.shape,
+NVTE_CHECK(gated_input.shape() == output->shape(),
-"Gated input and output shapes must match. Input shape: ", gated_input.data.shape,
+"Gated input and output shapes must match. Input shape: ", gated_input.shape(),
-", output shape: ", output->data.shape, ".");
+", output shape: ", output->shape(), ".");
switch (output->scaling_mode) {
case NVTE_DELAYED_TENSOR_SCALING: {
@@ -129,6 +144,20 @@ void quantize_gated_bwd_helper(const NVTETensor nvte_grad, const NVTETensor nvte
} else {
fp8::cast_gated_bwd<ParamOP, ActOP, DActOP>(gated_input, grad, output, p, stream);
}
+if (is_fp8_dtype(output->dtype()) && output->has_columnwise_data()) {
+// FP8 kernel only populates row-wise data, so perform
+// transpose separately if needed
+Tensor transpose_in, transpose_out, dummy;
+transpose_in.scaling_mode = NVTE_DELAYED_TENSOR_SCALING;
+transpose_in.data.dptr = output->data.dptr;
+transpose_in.data.shape = {output->flat_first_dim(), output->flat_last_dim()};
+transpose_in.data.dtype = output->data.dtype;
+transpose_out.scaling_mode = NVTE_DELAYED_TENSOR_SCALING;
+transpose_out.data.dptr = output->columnwise_data.dptr;
+transpose_out.data.shape = {output->flat_last_dim(), output->flat_first_dim()};
+transpose_out.data.dtype = output->data.dtype;
+detail::transpose(transpose_in, /*noop=*/dummy, &transpose_out, stream);
+}
break;
}
case NVTE_MXFP8_1D_SCALING: {
...
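The fallback added above covers FP8 delayed-scaling outputs that also request column-wise (transposed) data: the gated FP8 kernel fills only the row-wise buffer, and a separate transpose then produces the column-wise buffer from it. The data movement is equivalent to this small torch sketch (shapes are examples and INT8 stands in for the FP8 payload; the real code drives `detail::transpose` on raw buffers):

import torch

rows, cols = 128, 256
rowwise = (torch.randn(rows, cols) * 16).to(torch.int8)   # stand-in for the quantized row-wise buffer

# Column-wise ("transposed") copy of the same quantized values, materialized as a
# [cols, rows] buffer so a GEMM can consume it without transposing at run time.
columnwise = rowwise.t().contiguous()

assert columnwise.shape == (cols, rows)
assert torch.equal(columnwise[3, 7], rowwise[7, 3])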
/*************************************************************************
- * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/
@@ -19,6 +19,7 @@
#include "../core/common.cuh"
#include "../fp8/quantize_fp8.cuh"
#include "../mxfp8/quantize_mxfp8.cuh"
+#include "../nvfp4/group_quantize_transpose_nvfp4.cuh"
#include "../nvfp4/quantize_nvfp4.cuh"
#include "../nvfp4/quantize_transpose_nvfp4.cuh"
@@ -154,17 +155,10 @@ void quantize_fwd_helper(const NVTETensor input, NVTETensor output,
FP8BlockwiseRowwiseOption rowwise_option = FP8BlockwiseRowwiseOption::NONE;
FP8BlockwiseColumnwiseOption columnwise_option = FP8BlockwiseColumnwiseOption::NONE;
if (output_tensor->has_data()) {
-bool rowwise_compact = (quant_config_cpp.float8_block_scale_tensor_format ==
-Float8BlockScaleTensorFormat::COMPACT);
-rowwise_option = rowwise_compact ? FP8BlockwiseRowwiseOption::ROWWISE_COMPACT
-: FP8BlockwiseRowwiseOption::ROWWISE_GEMM_READY;
+rowwise_option = FP8BlockwiseRowwiseOption::ROWWISE_GEMM_READY;
}
if (output_tensor->has_columnwise_data()) {
-bool columnwise_compact = (quant_config_cpp.float8_block_scale_tensor_format ==
-Float8BlockScaleTensorFormat::COMPACT);
-columnwise_option = columnwise_compact
-? FP8BlockwiseColumnwiseOption::COLUMNWISE_COMPACT
-: FP8BlockwiseColumnwiseOption::COLUMNWISE_GEMM_READY;
+columnwise_option = FP8BlockwiseColumnwiseOption::COLUMNWISE_GEMM_READY;
}
quantize_transpose_vector_blockwise(
input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv,
@@ -307,17 +301,10 @@ void quantize_bwd_helper(const NVTETensor grad, const NVTETensor input, NVTETens
FP8BlockwiseRowwiseOption rowwise_option = FP8BlockwiseRowwiseOption::NONE;
FP8BlockwiseColumnwiseOption columnwise_option = FP8BlockwiseColumnwiseOption::NONE;
if (output_tensor->has_data()) {
-bool rowwise_compact = (quant_config_cpp.float8_block_scale_tensor_format ==
-Float8BlockScaleTensorFormat::COMPACT);
-rowwise_option = rowwise_compact ? FP8BlockwiseRowwiseOption::ROWWISE_COMPACT
-: FP8BlockwiseRowwiseOption::ROWWISE_GEMM_READY;
+rowwise_option = FP8BlockwiseRowwiseOption::ROWWISE_GEMM_READY;
}
if (output_tensor->has_columnwise_data()) {
-bool columnwise_compact = (quant_config_cpp.float8_block_scale_tensor_format ==
-Float8BlockScaleTensorFormat::COMPACT);
-columnwise_option = columnwise_compact
-? FP8BlockwiseColumnwiseOption::COLUMNWISE_COMPACT
-: FP8BlockwiseColumnwiseOption::COLUMNWISE_GEMM_READY;
+columnwise_option = FP8BlockwiseColumnwiseOption::COLUMNWISE_GEMM_READY;
}
quantize_transpose_vector_blockwise(
grad_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv,
@@ -330,6 +317,70 @@ void quantize_bwd_helper(const NVTETensor grad, const NVTETensor input, NVTETens
}
}
+template <bool IS_ACT, typename ParamOP, float (*OP)(float, const ParamOP &)>
+void group_quantize_fwd_helper(const NVTETensor input, NVTETensor *outputs,
+const size_t *split_sections, const size_t num_tensors,
+const NVTEQuantizationConfig quant_config, cudaStream_t stream) {
+using namespace detail;
+const Tensor *input_tensor = convertNVTETensorCheck(input);
+std::vector<Tensor *> output_tensors;
+for (size_t i = 0; i < num_tensors; ++i) {
+output_tensors.push_back(convertNVTETensorCheck(outputs[i]));
+}
+// Quantization config
+QuantizationConfig quant_config_cpp;
+if (quant_config != nullptr) {
+quant_config_cpp = *reinterpret_cast<QuantizationConfig *>(quant_config);
+}
+// Noop flag
+Tensor dummy_tensor;
+Tensor *noop_tensor = &dummy_tensor;
+if (quant_config_cpp.noop_tensor != nullptr) {
+noop_tensor = convertNVTETensorCheck(quant_config_cpp.noop_tensor);
+}
+// Check for unsupported options
+if (quant_config_cpp.stochastic_rounding) {
+NVTE_CHECK(output_tensors[0]->scaling_mode == NVTE_NVFP4_1D_SCALING,
+"Stochastic rounding is only supported for NVFP4 quantization.");
+}
+// Take the scaling mode of the first output tensor
+auto scaling_mode = output_tensors[0]->scaling_mode;
+// Dispatch to quantization kernel depending on data format
+switch (scaling_mode) {
+case NVTE_NVFP4_1D_SCALING: {
+NVTE_CHECK(!IS_ACT, "IS_ACT is not supported by FWD NVTE_NVFP4_1D_SCALING");
+// Check tensors
+CheckNoopTensor(*noop_tensor, "cast_noop");
+CheckInputTensor(*input_tensor, "input");
+// Skip checking output tensor list
+// output list here is allowed to have empty tensor
+// Choose kernel
+int32_t rows = input_tensor->flat_first_dim();
+int32_t cols = input_tensor->flat_last_dim();
+auto dtype = input_tensor->dtype();
+NVTE_CHECK(!quant_config_cpp.nvfp4_2d_quantization,
+"2D quantization is not supported for group quantize.");
+// Launch NVFP4 group quantize kernel
+nvfp4::group_quantize_transpose</*use_2d_quantization*/ false>(
+*input_tensor, noop_tensor, output_tensors, split_sections, num_tensors,
+&quant_config_cpp, stream);
+break;
+}
+default:
+NVTE_ERROR("Not implemented scaling mode: " + to_string(scaling_mode) + ".");
+}
+}
} // namespace dispatch
} // namespace transformer_engine
...
/*************************************************************************
- * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/
...
/*************************************************************************
- * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/
...
/*************************************************************************
- * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/
...
/*************************************************************************
- * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/
@@ -234,8 +234,7 @@ inline void dequantize(const Tensor &input, Tensor *output, cudaStream_t stream)
bool use_colwise_scaling = input.has_columnwise_data();
checkCuDriverContext(stream);
-const auto &input_shape = input.data.shape;
+NVTE_CHECK(input.dim() >= 2, "Input must have at least 2 dimensions.");
-NVTE_CHECK(input_shape.size() >= 2, "Input must have at least 2 dimensions.");
if (use_rowwise_scaling) {
NVTE_CHECK(input.has_data(), "Cannot dequantize tensor without rowwise data.");
@@ -247,8 +246,9 @@ inline void dequantize(const Tensor &input, Tensor *output, cudaStream_t stream)
NVTE_CHECK(is_fp8_dtype(input.columnwise_data.dtype), "Input must have FP8 type.");
}
+NVTE_CHECK(!input.with_gemm_swizzled_scales, "Input must have scales in compact format.");
NVTE_CHECK(!is_fp8_dtype(output->data.dtype), "Output must be in higher precision.");
-NVTE_CHECK(output->data.shape == input.data.shape, "Input and output shapes need to match.");
+NVTE_CHECK(output->shape() == input.shape(), "Input and output shapes need to match.");
// TODO: Make more general
const size_t scale_dim_X_rowwise = use_rowwise_scaling ? 32 : 1;
...
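The dequantize path above accepts MXFP8 scales only in the compact layout: one E8M0 scale per 32 contiguous elements along the scaled direction, stored as an ordinary row-major matrix rather than in the GEMM-swizzled order. A short sketch of the scale-tensor shapes this implies (tensor sizes are illustrative and edge padding for non-multiple sizes is ignored here):

rows, cols = 256, 4096
SCALE_DIM = 32  # elements covered by one E8M0 scale, matching scale_dim_X_rowwise above

rowwise_scales_shape = (rows, cols // SCALE_DIM)   # scaling along the last dimension
colwise_scales_shape = (rows // SCALE_DIM, cols)   # scaling along the first dimension
print(rowwise_scales_shape, colwise_scales_shape)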
/*************************************************************************
- * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/
@@ -22,6 +22,7 @@
#include "../../util/math.h"
#include "../../util/ptx.cuh"
#include "../../utils.cuh"
+#include "swizzle.cuh"
namespace transformer_engine {
namespace dispatch {
@@ -54,7 +55,8 @@ constexpr size_t THREADS_PER_BANK = TOTAL_BANKS_WIDTH / SCALE_DIM_X; // 4 = 128
#ifndef __HIP_PLATFORM_AMD__
template <bool IS_BWD, typename ParamOP, float (*ActOP)(float, const ParamOP &),
float (*DActOP)(float, const ParamOP &), typename IType, typename OType,
-bool ROWWISE_SCALING, bool COLWISE_SCALING, size_t THREADS_PER_CHUNK>
+bool ROWWISE_SCALING, bool COLWISE_SCALING, bool WITH_GEMM_SWIZZLED_SCALES,
+size_t THREADS_PER_CHUNK>
__global__ void __launch_bounds__(THREADS_PER_CHUNK)
quantize_gated_mxfp8_kernel(const __grid_constant__ CUtensorMap tensor_map_grad,
const __grid_constant__ CUtensorMap tensor_map_input_act,
@@ -71,6 +73,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
using IType2 = typename ptx::FPx2<IType>;
using OType2 = typename ptx::FPx2<OType>;
+using transformer_engine::dispatch::mxfp8::swizzle::gemm_swizzled_scale_idx;
constexpr size_t STAGES = CHUNK_DIM_Y / BUFF_DIM_Y;
static_assert(STAGES >= 1);
@@ -358,14 +362,17 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
// 2. Compute E8M0 scaling factor
const e8m0_t biased_exponent_act =
ptx::float_to_e8m0(thread_amax_act * Quantized_Limits<OType>::max_norm_rcp);
const size_t global_scales_offset_Y = scales_offset_Y_colwise + stage;
const size_t global_scales_offset_X = scales_offset_X_colwise;
-const size_t scale_idx =
-global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X;
+size_t scale_idx;
+if constexpr (WITH_GEMM_SWIZZLED_SCALES) {
+scale_idx = gemm_swizzled_scale_idx(global_scales_offset_X, global_scales_offset_Y,
+DIVUP(rows, static_cast<size_t>(128)));
+} else {
+scale_idx = global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X;
+}
const bool row_out_of_bounds_colwise = (row_base_colwise + stage_offset_Y) >= rows;
const bool out_of_bounds_colwise = row_out_of_bounds_colwise || col_out_of_bounds_colwise;
if (tid_Y_colwise == 0 && (!out_of_bounds_colwise)) {
scales_colwise[scale_idx] = biased_exponent_act;
}
@@ -377,8 +384,14 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
const e8m0_t biased_exponent_gate =
ptx::float_to_e8m0(thread_amax_gate * Quantized_Limits<OType>::max_norm_rcp);
-// const size_t scale_idx_gate = scale_idx + scale_stride_colwise / 2;
-const size_t scale_idx_gate = scale_idx + gate_scale_idx_offset_colwise;
+size_t scale_idx_gate;
+if constexpr (WITH_GEMM_SWIZZLED_SCALES) {
+scale_idx_gate = gemm_swizzled_scale_idx(
+global_scales_offset_X + gate_scale_idx_offset_colwise, global_scales_offset_Y,
+DIVUP(rows, static_cast<size_t>(128)));
+} else {
+scale_idx_gate = scale_idx + gate_scale_idx_offset_colwise;
+}
if (tid_Y_colwise == 0 && (!out_of_bounds_colwise)) {
scales_colwise[scale_idx_gate] = biased_exponent_gate;
}
@@ -560,7 +573,14 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
ptx::float_to_e8m0(thread_amax_act * Quantized_Limits<OType>::max_norm_rcp);
const size_t stage_scales_offset_Y = scales_offset_Y_rowwise + stage_offset_Y;
const size_t stage_scales_offset_X = scales_offset_X_rowwise;
-const size_t scale_idx = stage_scales_offset_Y * scale_stride_rowwise + stage_scales_offset_X;
+size_t scale_idx;
+if constexpr (WITH_GEMM_SWIZZLED_SCALES) {
+const size_t output_cols = (IS_BWD ? 2 : 1) * cols;
+scale_idx = gemm_swizzled_scale_idx(stage_scales_offset_Y, stage_scales_offset_X,
+DIVUP(output_cols, static_cast<size_t>(128)));
+} else {
+scale_idx = stage_scales_offset_Y * scale_stride_rowwise + stage_scales_offset_X;
+}
const bool row_out_of_bounds_rowwise = (row_base_rowwise + stage_offset_Y) >= rows;
const bool out_of_bounds_rowwise = row_out_of_bounds_rowwise || col_out_of_bounds_rowwise;
if (!out_of_bounds_rowwise) {
@@ -576,7 +596,16 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
if constexpr (IS_BWD) {
const e8m0_t biased_exponent_gate =
ptx::float_to_e8m0(thread_amax_gate * Quantized_Limits<OType>::max_norm_rcp);
-const size_t scale_idx_gate = scale_idx + gate_scale_idx_offset_rowwise;
+size_t scale_idx_gate;
+if constexpr (WITH_GEMM_SWIZZLED_SCALES) {
+const size_t output_cols = (IS_BWD ? 2 : 1) * cols;
+scale_idx_gate = gemm_swizzled_scale_idx(
+stage_scales_offset_Y, stage_scales_offset_X + gate_scale_idx_offset_rowwise,
+DIVUP(output_cols, static_cast<size_t>(128)));
+} else {
+scale_idx_gate = scale_idx + gate_scale_idx_offset_rowwise;
+}
if (!out_of_bounds_rowwise) {
scales_rowwise[scale_idx_gate] = biased_exponent_gate;
}
@@ -670,7 +699,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
parity ^= 1;
destroy_barriers<STAGES>(mbar, is_master_thread);
#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
-}
+} // NOLINT(readability/fn_size)
#endif // __HIP_PLATFORM_AMD__
} // namespace gated_kernel
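Both swizzled branches above pass DIVUP(<dim>, 128) to gemm_swizzled_scale_idx, i.e. the number of 128-wide tiles the swizzled scale layout is indexed over. A one-line Python equivalent of that rounding (the values are arbitrary; the swizzle order itself is not reproduced here):

def divup(a: int, b: int) -> int:
    # Ceiling division, equivalent to the DIVUP macro used in the kernel above.
    return (a + b - 1) // b

rows, cols = 300, 1000
print(divup(rows, 128))  # 3: number of 128-row tiles indexed for the column-wise scales
print(divup(cols, 128))  # 8: number of 128-column tiles indexed for the row-wise scales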
@@ -686,6 +715,7 @@ void quantize_gated(const Tensor &gated_input, const Tensor &grad, Tensor *outpu
const bool USE_ROWWISE_SCALING = output->has_data();
const bool USE_COLWISE_SCALING = output->has_columnwise_data();
+const bool with_gemm_swizzled_scales = output->with_gemm_swizzled_scales;
if (USE_ROWWISE_SCALING) {
NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated.");
@@ -729,113 +759,140 @@ void quantize_gated(const Tensor &gated_input, const Tensor &grad, Tensor *outpu
gated_input.dtype(), IType,
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
output->dtype(), OType,
+TRANSFORMER_ENGINE_SWITCH_CONDITION(
+with_gemm_swizzled_scales, WITH_GEMM_SWIZZLED_SCALES,
alignas(64) CUtensorMap tensor_map_grad{};
alignas(64) CUtensorMap tensor_map_input_act{};
alignas(64) CUtensorMap tensor_map_input_gate{};
alignas(64) CUtensorMap tensor_map_output_act_rowwise{};
alignas(64) CUtensorMap tensor_map_output_gate_rowwise{};
alignas(64) CUtensorMap tensor_map_output_act_colwise{};
alignas(64) CUtensorMap tensor_map_output_gate_colwise{};
constexpr size_t input_type_bit_size = TypeInfo<IType>::size;
constexpr size_t output_type_bit_size = TypeInfo<OType>::size;
if constexpr (IS_BWD) {
create_2D_tensor_map(tensor_map_grad, grad.data, rows, cols, BUFF_DIM_Y, BUFF_DIM_X,
cols, 0, input_type_bit_size);
}
const uint32_t tensor_stride_elems = output_cols;
create_2D_tensor_map(tensor_map_input_act, gated_input.data, rows, cols, BUFF_DIM_Y,
BUFF_DIM_X, cols * 2, 0, input_type_bit_size);
create_2D_tensor_map(tensor_map_input_gate, gated_input.data, rows, cols, BUFF_DIM_Y,
BUFF_DIM_X, cols * 2, cols, input_type_bit_size);
if (USE_ROWWISE_SCALING) {
create_2D_tensor_map(tensor_map_output_act_rowwise, output->data, rows, cols,
BUFF_DIM_Y, BUFF_DIM_X, tensor_stride_elems, 0,
output_type_bit_size);
create_2D_tensor_map(tensor_map_output_gate_rowwise, output->data, rows, cols,
BUFF_DIM_Y, BUFF_DIM_X, tensor_stride_elems, cols,
output_type_bit_size);
}
if (USE_COLWISE_SCALING) {
-create_2D_tensor_map(tensor_map_output_act_colwise, output->columnwise_data, rows, cols,
-BUFF_DIM_Y, BUFF_DIM_X, tensor_stride_elems, 0,
+create_2D_tensor_map(tensor_map_output_act_colwise, output->columnwise_data, rows,
+cols, BUFF_DIM_Y, BUFF_DIM_X, tensor_stride_elems, 0,
output_type_bit_size);
create_2D_tensor_map(tensor_map_output_gate_colwise, output->columnwise_data, rows,
cols, BUFF_DIM_Y, BUFF_DIM_X, tensor_stride_elems, cols,
output_type_bit_size);
}
const size_t buff_elems_total = BUFFS_NUM * BUFF_DIM_Y * BUFF_DIM_X;
const size_t input_buff_size = (buff_elems_total * input_type_bit_size) / 8;
const size_t output_buff_size = (buff_elems_total * output_type_bit_size) / 8;
const size_t buff_size_aligned_in =
DIVUP_TO_MULTIPLE(input_buff_size, TMA_SHMEM_ALIGNMENT);
const size_t buff_size_aligned_out =
DIVUP_TO_MULTIPLE(output_buff_size, TMA_SHMEM_ALIGNMENT);
const size_t grad_mem = (IS_BWD ? buff_size_aligned_in : 0);
const size_t in_act_mem = buff_size_aligned_in;
const size_t in_gate_mem = buff_size_aligned_in;
const size_t in_mem = grad_mem + in_act_mem + in_gate_mem;
const size_t out_act_mem = buff_size_aligned_out;
const size_t out_gate_mem = (IS_BWD ? buff_size_aligned_out : 0);
size_t out_mem = out_act_mem + out_gate_mem;
if (USE_ROWWISE_SCALING && USE_COLWISE_SCALING) { out_mem *= 2; }
const size_t shmem_size = in_mem + out_mem + TMA_SHMEM_ALIGNMENT;
-switch (scaling_type) {
-case ScalingType::ROWWISE: {
-auto kernel =
-quantize_gated_mxfp8_kernel<IS_BWD, ParamOP, ActOP, DActOP, IType, OType, true,
-false, THREADS_PER_CHUNK_NON_COLWISE>;
-NVTE_CHECK_CUDA(cudaFuncSetAttribute(
-kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size));
-kernel<<<grid, block_size, shmem_size, stream>>>(
-tensor_map_grad, tensor_map_input_act, tensor_map_input_gate,
-tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise,
-tensor_map_output_act_colwise, tensor_map_output_gate_colwise, scales_rowwise_ptr,
+// Zero out swizzled scales if padding is needed
+/// TODO (tmoon) Handle this within the cast kernel
+if (with_gemm_swizzled_scales) {
+constexpr size_t TILE_DIM_X = 128;  // Tile dim in data buffer
+constexpr size_t TILE_DIM_Y = 128;
+if (cols % TILE_DIM_X != 0 || rows % TILE_DIM_Y != 0) {
+if (USE_ROWWISE_SCALING) {
+NVTE_CHECK_CUDA(cudaMemsetAsync(output->scale_inv.dptr, 0,
+output->scale_inv.buffer_size_bytes(), stream));
+}
+if (USE_COLWISE_SCALING) {
+NVTE_CHECK_CUDA(
scales_colwise_ptr, rows, cols, scale_stride_rowwise, scale_stride_colwise, p); cudaMemsetAsync(output->columnwise_scale_inv.dptr, 0,
break; output->columnwise_scale_inv.buffer_size_bytes(), stream));
} }
case ScalingType::COLWISE: { }
auto kernel = }
quantize_gated_mxfp8_kernel<IS_BWD, ParamOP, ActOP, DActOP, IType, OType, false,
true, THREADS_PER_CHUNK_COLWISE>; switch (scaling_type) {
NVTE_CHECK_CUDA(cudaFuncSetAttribute( case ScalingType::ROWWISE: {
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size)); auto kernel =
quantize_gated_mxfp8_kernel<IS_BWD, ParamOP, ActOP, DActOP, IType, OType,
kernel<<<grid, block_size, shmem_size, stream>>>( true, false, WITH_GEMM_SWIZZLED_SCALES,
tensor_map_grad, tensor_map_input_act, tensor_map_input_gate, THREADS_PER_CHUNK_NON_COLWISE>;
tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise, NVTE_CHECK_CUDA(cudaFuncSetAttribute(
tensor_map_output_act_colwise, tensor_map_output_gate_colwise, scales_rowwise_ptr, kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size));
scales_colwise_ptr, rows, cols, scale_stride_rowwise, scale_stride_colwise, p);
break; kernel<<<grid, block_size, shmem_size, stream>>>(
} tensor_map_grad, tensor_map_input_act, tensor_map_input_gate,
case ScalingType::BIDIMENSIONAL: { tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise,
auto kernel = tensor_map_output_act_colwise, tensor_map_output_gate_colwise,
quantize_gated_mxfp8_kernel<IS_BWD, ParamOP, ActOP, DActOP, IType, OType, true, scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise,
true, THREADS_PER_CHUNK_NON_COLWISE>; scale_stride_colwise, p);
NVTE_CHECK_CUDA(cudaFuncSetAttribute( break;
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size)); }
case ScalingType::COLWISE: {
kernel<<<grid, block_size, shmem_size, stream>>>( auto kernel =
tensor_map_grad, tensor_map_input_act, tensor_map_input_gate, quantize_gated_mxfp8_kernel<IS_BWD, ParamOP, ActOP, DActOP, IType, OType,
tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise, false, true, WITH_GEMM_SWIZZLED_SCALES,
tensor_map_output_act_colwise, tensor_map_output_gate_colwise, scales_rowwise_ptr, THREADS_PER_CHUNK_COLWISE>;
scales_colwise_ptr, rows, cols, scale_stride_rowwise, scale_stride_colwise, p); NVTE_CHECK_CUDA(cudaFuncSetAttribute(
break; kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size));
}
} NVTE_CHECK_CUDA(cudaGetLastError());); // NOLINT(*) kernel<<<grid, block_size, shmem_size, stream>>>(
); // NOLINT(*) tensor_map_grad, tensor_map_input_act, tensor_map_input_gate,
tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise,
tensor_map_output_act_colwise, tensor_map_output_gate_colwise,
scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise,
scale_stride_colwise, p);
break;
}
case ScalingType::BIDIMENSIONAL: {
auto kernel =
quantize_gated_mxfp8_kernel<IS_BWD, ParamOP, ActOP, DActOP, IType, OType,
true, true, WITH_GEMM_SWIZZLED_SCALES,
THREADS_PER_CHUNK_NON_COLWISE>;
NVTE_CHECK_CUDA(cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size));
kernel<<<grid, block_size, shmem_size, stream>>>(
tensor_map_grad, tensor_map_input_act, tensor_map_input_gate,
tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise,
tensor_map_output_act_colwise, tensor_map_output_gate_colwise,
scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise,
scale_stride_colwise, p);
break;
}
} NVTE_CHECK_CUDA(cudaGetLastError());); // NOLINT(*)
); // NOLINT(*)
); // NOLINT(*)
#endif #endif
} }
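Note on the dispatch pattern used in this hunk: TRANSFORMER_ENGINE_SWITCH_CONDITION turns the runtime flag with_gemm_swizzled_scales into the compile-time constant WITH_GEMM_SWIZZLED_SCALES, which is then passed as a template argument to quantize_gated_mxfp8_kernel. The macro's definition is not part of this diff; the snippet below is only a minimal sketch of that runtime-to-compile-time dispatch idea, with a hypothetical stand-in macro (SWITCH_CONDITION_SKETCH) and a kernel stub, not the Transformer Engine implementation.

// Minimal sketch (not the actual Transformer Engine macro) of lifting a runtime bool
// into a compile-time constant so a kernel can be specialized on it as a template argument.
#include <cstdio>

#define SWITCH_CONDITION_SKETCH(condition, CONST_NAME, ...) \
  do {                                                      \
    if (condition) {                                        \
      constexpr bool CONST_NAME = true;                     \
      __VA_ARGS__                                           \
    } else {                                                \
      constexpr bool CONST_NAME = false;                    \
      __VA_ARGS__                                           \
    }                                                       \
  } while (0)

// Stand-in for quantize_gated_mxfp8_kernel: the swizzle flag is resolved at compile time,
// so each instantiation contains only one of the two branches.
template <bool WITH_GEMM_SWIZZLED_SCALES>
void launch_kernel_stub() {
  if constexpr (WITH_GEMM_SWIZZLED_SCALES) {
    std::printf("kernel specialized for GEMM-swizzled scale layout\n");
  } else {
    std::printf("kernel specialized for linear scale layout\n");
  }
}

int main() {
  const bool with_gemm_swizzled_scales = true;  // would come from the output tensor at runtime
  SWITCH_CONDITION_SKETCH(with_gemm_swizzled_scales, WITH_GEMM_SWIZZLED_SCALES,
                          launch_kernel_stub<WITH_GEMM_SWIZZLED_SCALES>(););
  return 0;
}

The benefit of this pattern is that the branch on the scale layout is paid once on the host; the device code sees a constant and the compiler drops the unused path in each instantiation.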
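The context lines above also show how the dynamic shared-memory requirement is computed before the launch. The host-side sketch below replays that arithmetic with assumed values for BUFFS_NUM, BUFF_DIM_Y, BUFF_DIM_X and TMA_SHMEM_ALIGNMENT (the real constants are defined elsewhere in the file and are not visible in this hunk); it only illustrates why the code opts in to a larger limit via cudaFuncSetAttribute when the total exceeds the 48 KB default.

// Host-side sketch of the shared-memory sizing done before the kernel launch.
// All constant values here are assumptions for illustration, not the library's values.
#include <cstddef>
#include <cstdio>

constexpr size_t DIVUP_TO_MULTIPLE(size_t x, size_t m) { return ((x + m - 1) / m) * m; }

int main() {
  constexpr size_t TMA_SHMEM_ALIGNMENT = 128;          // assumed TMA buffer alignment in bytes
  constexpr size_t BUFFS_NUM = 2;                      // assumed double buffering
  constexpr size_t BUFF_DIM_Y = 32, BUFF_DIM_X = 128;  // assumed tile shape
  constexpr size_t input_type_bit_size = 16;           // e.g. bf16 input
  constexpr size_t output_type_bit_size = 8;           // e.g. fp8 output
  constexpr bool IS_BWD = true;
  constexpr bool USE_ROWWISE_SCALING = true, USE_COLWISE_SCALING = true;

  const size_t buff_elems_total = BUFFS_NUM * BUFF_DIM_Y * BUFF_DIM_X;
  const size_t input_buff_size = (buff_elems_total * input_type_bit_size) / 8;
  const size_t output_buff_size = (buff_elems_total * output_type_bit_size) / 8;
  const size_t buff_size_aligned_in = DIVUP_TO_MULTIPLE(input_buff_size, TMA_SHMEM_ALIGNMENT);
  const size_t buff_size_aligned_out = DIVUP_TO_MULTIPLE(output_buff_size, TMA_SHMEM_ALIGNMENT);

  const size_t grad_mem = IS_BWD ? buff_size_aligned_in : 0;
  const size_t in_mem = grad_mem + 2 * buff_size_aligned_in;     // grad + act + gate inputs
  size_t out_mem = buff_size_aligned_out + (IS_BWD ? buff_size_aligned_out : 0);
  if (USE_ROWWISE_SCALING && USE_COLWISE_SCALING) out_mem *= 2;  // rowwise and colwise outputs

  const size_t shmem_size = in_mem + out_mem + TMA_SHMEM_ALIGNMENT;
  std::printf("dynamic shared memory: %zu bytes\n", shmem_size);
  // With these assumed numbers the total lands well above the 48 KB default, which is why
  // the launch path first calls
  // cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size).
  return 0;
}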