OpenDAS / TransformerEngine / Commits / 0d874a4e

Commit 0d874a4e, authored Mar 03, 2026 by wenjh

    Merge branch 'nv_main' of v2.12

Parents: a68e5f87, dfdd3820
Changes: 640 files in the full commit; the 20 files shown on this page account for 468 additions and 304 deletions (+468, -304).
Changed files shown on this page:

  tests/pytorch/test_sanity.py                                    +5    -4
  tests/pytorch/test_sanity_import.py                             +1    -1
  tests/pytorch/utils.py                                          +11   -1
  transformer_engine/__init__.py                                  +1    -1
  transformer_engine/common/CMakeLists.txt                        +47   -32
  transformer_engine/common/__init__.py                           +97   -111
  transformer_engine/common/activation/activation_template.h     +1    -1
  transformer_engine/common/activation/gelu.cu                    +1    -1
  transformer_engine/common/activation/relu.cu                    +1    -1
  transformer_engine/common/activation/swiglu.cu                  +1    -1
  transformer_engine/common/cast/cast.cu                          +16   -1
  transformer_engine/common/cast/core/common.cuh                  +1    -1
  transformer_engine/common/cast/dispatch/dequantize.cuh          +4    -4
  transformer_engine/common/cast/dispatch/gated.cuh               +35   -6
  transformer_engine/common/cast/dispatch/quantize.cuh            +70   -19
  transformer_engine/common/cast/fp8/dequantize_fp8.cuh           +1    -1
  transformer_engine/common/cast/fp8/gated_fp8.cuh                +1    -1
  transformer_engine/common/cast/fp8/quantize_fp8.cuh             +1    -1
  transformer_engine/common/cast/mxfp8/dequantize_mxfp8.cuh       +4    -4
  transformer_engine/common/cast/mxfp8/gated_mxfp8.cuh            +169  -112
tests/pytorch/test_sanity.py

```diff
-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # See LICENSE for license information.

@@ -37,7 +37,6 @@ from transformer_engine.pytorch import (
 from transformer_engine.common import recipe
 import transformer_engine_torch as tex
 from transformer_engine.pytorch.cpp_extensions import general_gemm
-from transformer_engine.pytorch.module.base import get_workspace
 from transformer_engine.pytorch.tensor.utils import replace_raw_data
 from utils import ModelConfig

@@ -539,6 +538,7 @@ def test_sanity_grouped_linear(
 @pytest.mark.parametrize("activation", all_activations)
 @pytest.mark.parametrize("normalization", all_normalizations)
 @pytest.mark.parametrize("microbatching", all_boolean)
+@pytest.mark.parametrize("checkpoint", all_boolean)
 def test_sanity_layernorm_mlp(
     dtype,
     fp8_recipe,

@@ -549,6 +549,7 @@ def test_sanity_layernorm_mlp(
     activation,
     normalization,
     microbatching,
+    checkpoint,
 ):
     config = model_configs[model]

@@ -579,6 +580,7 @@ def test_sanity_layernorm_mlp(
         normalization=normalization,
         params_dtype=dtype,
         device="cuda",
+        checkpoint=checkpoint,
     )
     _test_sanity_common(block, dtype, config, fp8_recipe, skip_wgrad, skip_dgrad, microbatching)

@@ -961,7 +963,7 @@ def test_sanity_gemm_with_unalignment(N, offset, datatype):
     inp = torch.reshape(scratchpad[offset:-offset], (N, N))
     weight = torch.reshape(scratchpad[offset * 2 :], (N, N))
-    _ = general_gemm(A=weight, B=inp, workspace=get_workspace())
+    _ = general_gemm(A=weight, B=inp)
     torch.cuda.synchronize()

@@ -985,7 +987,6 @@ def test_sanity_fp8_gemm_with_unalignment(N, datatype):
     general_gemm(
         weight_fp8,
         inp_fp8,
-        get_workspace(),
         outp_type,
         bias=None,
         use_split_accumulator=False,
```
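Since the new `checkpoint` flag joins three other stacked `parametrize` decorators, it doubles the generated test matrix rather than adding one case. A self-contained sketch of that mechanic (illustrative test, not taken from the suite):

```python
import pytest

all_boolean = [True, False]

# Stacked parametrize decorators produce the cross-product of their values,
# so one extra boolean decorator doubles the number of generated tests.
@pytest.mark.parametrize("microbatching", all_boolean)
@pytest.mark.parametrize("checkpoint", all_boolean)
def test_cross_product(microbatching, checkpoint):
    # pytest generates 4 cases: (T, T), (T, F), (F, T), (F, F)
    assert isinstance(microbatching, bool) and isinstance(checkpoint, bool)
```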
tests/pytorch/test_sanity_import.py

```diff
-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # See LICENSE for license information.
```
tests/pytorch/utils.py

```diff
-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # See LICENSE for license information.

@@ -8,6 +8,7 @@ import logging
 import os
 from contextlib import contextmanager
 from typing import Optional, Tuple, Dict, Any, List
+from packaging.version import Version as PkgVersion

 import torch

@@ -210,6 +211,7 @@ class ModelConfig:
         max_ctx_len: int = None,
         num_layers: int = 1,
         eps: float = 1e-5,
+        num_splits=1,
     ):
         self.batch_size = batch_size
         self.max_seqlen_q = max_seqlen_q

@@ -239,6 +241,7 @@ class ModelConfig:
         self.max_ctx_len = max_ctx_len
         self.num_layers = num_layers
         self.eps = eps
+        self.num_splits = num_splits

 @contextmanager

@@ -321,6 +324,9 @@ def get_available_attention_backends(
         inference_params=inference_params,
         softmax_type=config.softmax_type,
         return_max_logit=config.return_max_logit,
+        # allow all backends to pass so they can be used for testing;
+        # check for FA3 availability later
+        num_splits=1,
     )
     (
         use_flash_attention,

@@ -330,6 +336,10 @@ def get_available_attention_backends(
         use_unfused_attention,
         available_backends,
     ) = get_attention_backend(attention_params)
+    # Check if FA3 is an available backend when num_splits != 1
+    if available_backends[0]:
+        if config.num_splits != 1 and not flash_attention_backend > PkgVersion("3.0.0b"):
+            available_backends[0] = False
     # Set attention.py _attention_backends var using return value
     # from get_attention_backend()
     _attention_backends["use_flash_attention"] = use_flash_attention
```
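The FA3 gate leans on packaging's version semantics, where a pre-release like `3.0.0b` sorts below the final `3.0.0`, so the comparison admits any FlashAttention 3 build including betas. A quick sketch (the version strings are examples, not values from the suite):

```python
from packaging.version import Version as PkgVersion

# Pre-release versions compare below the corresponding final release.
assert not (PkgVersion("2.7.4") > PkgVersion("3.0.0b"))  # FA2 build: rejected
assert PkgVersion("3.0.0b1") > PkgVersion("3.0.0b")      # FA3 beta 1: accepted
assert PkgVersion("3.0.0") > PkgVersion("3.0.0b")        # FA3 final: accepted
```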
transformer_engine/__init__.py

```diff
-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # See LICENSE for license information.
```
transformer_engine/common/CMakeLists.txt

```diff
-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # See LICENSE for license information.

@@ -183,7 +183,6 @@ if(USE_CUDA)
   list(APPEND transformer_engine_cuda_sources
        common.cu
        multi_tensor/adam.cu
-       multi_tensor/compute_scale.cu
        multi_tensor/l2norm.cu
        multi_tensor/scale.cu
        multi_tensor/sgd.cu

@@ -225,15 +224,20 @@ if(USE_CUDA)
        comm_gemm_overlap/userbuffers/userbuffers.cu)
   list(APPEND transformer_engine_cuda_arch_specific_sources
-       gemm/cutlass_grouped_gemm.cu
-       cast/cast.cu
        activation/gelu.cu
        activation/relu.cu
        activation/swiglu.cu
-       transpose/quantize_transpose_square_blockwise.cu
-       transpose/quantize_transpose_vector_blockwise_fp4.cu
+       cast/cast.cu
+       gemm/cutlass_grouped_gemm.cu
        hadamard_transform/group_hadamard_transform.cu
        hadamard_transform/hadamard_transform.cu
-       hadamard_transform/hadamard_transform_cast_fusion.cu)
+       hadamard_transform/hadamard_transform_cast_fusion.cu
+       hadamard_transform/group_hadamard_transform_cast_fusion.cu
+       hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu
+       multi_tensor/compute_scale.cu
+       recipe/mxfp8_scaling.cu
+       transpose/quantize_transpose_square_blockwise.cu
+       transpose/quantize_transpose_vector_blockwise_fp4.cu)
   # Compiling the files with the worst compilation time first to hopefully overlap
   # better with the faster-compiling cpp files

@@ -281,13 +285,42 @@ if(USE_CUDA)
   endif()
   add_library(transformer_engine SHARED ${transformer_engine_SOURCES})
   target_include_directories(transformer_engine PUBLIC
                              "${CMAKE_CURRENT_SOURCE_DIR}/include")

-  # CUTLASS kernels require SM90a and cause hang in debug build
+  # Grouped GEMM kernels require SM90a
   set_property(
     SOURCE gemm/cutlass_grouped_gemm.cu
     APPEND
     PROPERTY
-    COMPILE_OPTIONS "--generate-code=arch=compute_90a,code=sm_90a;-g0")
+    COMPILE_OPTIONS "--generate-code=arch=compute_90a,code=sm_90a")
+
+  # CUTLASS kernels could cause hang in debug build
+  set(CUTLASS_KERNEL_SOURCES
+      gemm/cutlass_grouped_gemm.cu
+      hadamard_transform/group_hadamard_transform_cast_fusion.cu
+      hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu
+      hadamard_transform/hadamard_transform_cast_fusion.cu)
+  set_property(
+    SOURCE ${CUTLASS_KERNEL_SOURCES}
+    APPEND
+    PROPERTY
+    COMPILE_OPTIONS "-g0;-dopt=on")
+
+  # Configure dependencies
+  target_link_libraries(transformer_engine PUBLIC
+                        CUDA::cublas
+                        CUDA::cudart
+                        CUDNN::cudnn_all)
+  target_include_directories(transformer_engine PRIVATE
+                             ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
+  target_include_directories(transformer_engine SYSTEM PRIVATE
+                             ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}/cccl)
+  target_include_directories(transformer_engine PRIVATE
+                             "${CUDNN_FRONTEND_INCLUDE_DIR}")
+  target_include_directories(transformer_engine PRIVATE
+                             ${CUTLASS_INCLUDE_DIR} ${CUTLASS_TOOLS_INCLUDE_DIR})
 else()
   list(APPEND transformer_engine_cpp_sources
        cudnn_utils.cpp

@@ -308,7 +341,6 @@ else()
   list(APPEND transformer_engine_cuda_sources
        common.cu
        multi_tensor/adam.cu
-       multi_tensor/compute_scale.cu
        multi_tensor/l2norm.cu
        multi_tensor/scale.cu
        multi_tensor/sgd.cu

@@ -348,10 +380,12 @@ else()
        comm_gemm_overlap/userbuffers/userbuffers.cu)
   list(APPEND transformer_engine_cuda_arch_specific_sources
-       cast/cast.cu
        activation/gelu.cu
        activation/relu.cu
        activation/swiglu.cu
+       cast/cast.cu
+       multi_tensor/compute_scale.cu
+       recipe/mxfp8_scaling.cu
        transpose/quantize_transpose_square_blockwise.cu
        transpose/quantize_transpose_vector_blockwise_fp4.cu)

@@ -398,27 +432,9 @@ else()
     message(STATUS "nvte hipified sources: ${te_hip_sources}")
     add_library(transformer_engine SHARED ${te_hip_sources})
   endif()

 target_include_directories(transformer_engine PUBLIC
                            "${CMAKE_CURRENT_SOURCE_DIR}/include")

 if(USE_CUDA)
-  # Configure dependencies
-  target_link_libraries(transformer_engine PUBLIC
-                        CUDA::cublas
-                        CUDA::cudart
-                        CUDNN::cudnn_all)
-  target_include_directories(transformer_engine PRIVATE
-                             ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
-  target_include_directories(transformer_engine SYSTEM PRIVATE
-                             ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}/cccl)
-  target_include_directories(transformer_engine PRIVATE
-                             "${CUDNN_FRONTEND_INCLUDE_DIR}")
-  target_include_directories(transformer_engine PRIVATE
-                             ${CUTLASS_INCLUDE_DIR} ${CUTLASS_TOOLS_INCLUDE_DIR})
 else()
   target_include_directories(transformer_engine PUBLIC
                              "${CMAKE_CURRENT_SOURCE_DIR}")
   # Aotriton is currently unsupported
   set(AotritonAndCk_fused_attn "unsupported")

@@ -441,7 +457,6 @@ else()
   target_link_libraries(transformer_engine PUBLIC ${transformer_engine_LINKER_LIBS})
 endif()

 # Compiling Userbuffers with native MPI bootstrapping requires linking against MPI
 option(NVTE_UB_WITH_MPI "Bootstrap Userbuffers with MPI" OFF)
 if(NVTE_UB_WITH_MPI)
```
transformer_engine/common/__init__.py

```diff
-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # See LICENSE for license information.

@@ -236,31 +236,6 @@ def _get_sys_extension() -> str:
     raise RuntimeError(f"Unsupported operating system ({system})")


-@functools.lru_cache(maxsize=None)
-def _load_nvidia_cuda_library(lib_name: str):
-    """
-    Attempts to load shared object file installed via pip.
-
-    `lib_name`: Name of package as found in the `nvidia` dir in python environment.
-    """
-    so_paths = glob.glob(
-        os.path.join(
-            sysconfig.get_path("purelib"),
-            f"nvidia/{lib_name}/lib/lib*{_get_sys_extension()}.*[0-9]",
-        )
-    )
-
-    path_found = len(so_paths) > 0
-    ctypes_handles = []
-
-    if path_found:
-        for so_path in so_paths:
-            ctypes_handles.append(ctypes.CDLL(so_path, mode=ctypes.RTLD_GLOBAL))
-
-    return path_found, ctypes_handles
-
-
 @functools.lru_cache(maxsize=None)
 def _nvidia_cudart_include_dir() -> str:
     """Returns the include directory for cuda_runtime.h if exists in python environment."""

@@ -280,102 +255,102 @@ def _nvidia_cudart_include_dir() -> str:
 @functools.lru_cache(maxsize=None)
-def _load_cudnn():
-    """Load CUDNN shared library."""
-    # Attempt to locate cuDNN in CUDNN_HOME or CUDNN_PATH, if either is set
-    cudnn_home = os.environ.get("CUDNN_HOME") or os.environ.get("CUDNN_PATH")
-    if cudnn_home:
-        libs = glob.glob(f"{cudnn_home}/**/libcudnn{_get_sys_extension()}*", recursive=True)
-        libs.sort(reverse=True, key=os.path.basename)
-        if libs:
-            return ctypes.CDLL(libs[0], mode=ctypes.RTLD_GLOBAL)
-
-    # Attempt to locate cuDNN in CUDA_HOME, CUDA_PATH or /usr/local/cuda
-    cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH") or "/usr/local/cuda"
-    libs = glob.glob(f"{cuda_home}/**/libcudnn{_get_sys_extension()}*", recursive=True)
-    libs.sort(reverse=True, key=os.path.basename)
-    if libs:
-        return ctypes.CDLL(libs[0], mode=ctypes.RTLD_GLOBAL)
-
-    # Attempt to locate cuDNN in Python dist-packages
-    found, handle = _load_nvidia_cuda_library("cudnn")
-    if found:
-        return handle
-
-    if not IS_HIP_EXTENSION:
-        # Attempt to locate libcudnn via ldconfig
-        libs = subprocess.check_output(["ldconfig", "-p"])
-        libs = libs.decode("utf-8").split("\n")
-        sos = []
-        for lib in libs:
-            if "libcudnn" in lib and "=>" in lib:
-                sos.append(lib.split(">")[1].strip())
-        if sos:
-            return ctypes.CDLL(sos[0], mode=ctypes.RTLD_GLOBAL)
-
-    # If all else fails, assume that it is in LD_LIBRARY_PATH and error out otherwise
-    return ctypes.CDLL(f"libcudnn{_get_sys_extension()}", mode=ctypes.RTLD_GLOBAL)
-
-
-@functools.lru_cache(maxsize=None)
-def _load_nvrtc():
-    """Load NVRTC shared library."""
-    # Attempt to locate NVRTC in CUDA_HOME, CUDA_PATH or /usr/local/cuda
-    cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH") or "/usr/local/cuda"
-    libs = glob.glob(f"{cuda_home}/**/libnvrtc{_get_sys_extension()}*", recursive=True)
-    libs = list(filter(lambda x: not ("stub" in x or "libnvrtc-builtins" in x), libs))
-    libs.sort(reverse=True, key=os.path.basename)
-    if libs:
-        return ctypes.CDLL(libs[0], mode=ctypes.RTLD_GLOBAL)
-
-    # Attempt to locate NVRTC in Python dist-packages
-    found, handle = _load_nvidia_cuda_library("cuda_nvrtc")
-    if found:
-        return handle
-
-    # Attempt to locate NVRTC via ldconfig
-    libs = subprocess.check_output(["ldconfig", "-p"])
-    libs = libs.decode("utf-8").split("\n")
-    sos = []
-    for lib in libs:
-        if "libnvrtc" in lib and "=>" in lib:
-            sos.append(lib.split(">")[1].strip())
-    if sos:
-        return ctypes.CDLL(sos[0], mode=ctypes.RTLD_GLOBAL)
-
-    # If all else fails, assume that it is in LD_LIBRARY_PATH and error out otherwise
-    return ctypes.CDLL(f"libnvrtc{_get_sys_extension()}", mode=ctypes.RTLD_GLOBAL)
-
-
-@functools.lru_cache(maxsize=None)
-def _load_curand():
-    """Load cuRAND shared library."""
-    # Attempt to locate cuRAND in CUDA_HOME, CUDA_PATH or /usr/local/cuda
-    cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH") or "/usr/local/cuda"
-    libs = glob.glob(f"{cuda_home}/**/libcurand{_get_sys_extension()}*", recursive=True)
-    libs = list(filter(lambda x: not ("stub" in x), libs))
-    libs.sort(reverse=True, key=os.path.basename)
-    if libs:
-        return ctypes.CDLL(libs[0], mode=ctypes.RTLD_GLOBAL)
-
-    # Attempt to locate cuRAND in Python dist-packages
-    found, handle = _load_nvidia_cuda_library("curand")
-    if found:
-        return handle
-
-    # Attempt to locate cuRAND via ldconfig
-    libs = subprocess.check_output(["ldconfig", "-p"])
-    libs = libs.decode("utf-8").split("\n")
-    sos = []
-    for lib in libs:
-        if "libcurand" in lib and "=>" in lib:
-            sos.append(lib.split(">")[1].strip())
-    if sos:
-        return ctypes.CDLL(sos[0], mode=ctypes.RTLD_GLOBAL)
-
-    # If all else fails, assume that it is in LD_LIBRARY_PATH and error out otherwise
-    return ctypes.CDLL(f"libcurand{_get_sys_extension()}", mode=ctypes.RTLD_GLOBAL)
+def _load_cuda_library_from_python(lib_name: str, strict: bool = False):
+    """
+    Attempts to load shared object file installed via python packages.
+
+    `lib_name` : Name of package as found in the `nvidia` dir in python environment.
+    `strict`   : If set to `True`, throw an error if lib is not found.
+    """
+    ext = _get_sys_extension()
+    nvidia_dir = os.path.join(sysconfig.get_path("purelib"), "nvidia")
+
+    # PyPI packages provided by nvidia libs exist
+    # in 4 possible locations inside `nvidia`.
+    # Check by order of priority.
+    path_found = False
+    if os.path.isdir(os.path.join(nvidia_dir, "cu13", lib_name)):
+        so_paths = glob.glob(os.path.join(nvidia_dir, "cu13", lib_name, f"lib/lib*{ext}.*[0-9]"))
+        path_found = len(so_paths) > 0
+    if not path_found and os.path.isdir(os.path.join(nvidia_dir, "cu13")):
+        so_paths = glob.glob(os.path.join(nvidia_dir, "cu13", f"lib/lib{lib_name}*{ext}.*[0-9]"))
+        path_found = len(so_paths) > 0
+    if not path_found and os.path.isdir(os.path.join(nvidia_dir, lib_name)):
+        so_paths = glob.glob(os.path.join(nvidia_dir, lib_name, f"lib/lib*{ext}.*[0-9]"))
+        path_found = len(so_paths) > 0
+    if not path_found:
+        so_paths = glob.glob(os.path.join(nvidia_dir, f"cuda_{lib_name}", f"lib/lib*{ext}.*[0-9]"))
+        path_found = len(so_paths) > 0
+
+    ctypes_handles = []
+    if path_found:
+        for so_path in so_paths:
+            ctypes_handles.append(ctypes.CDLL(so_path, mode=ctypes.RTLD_GLOBAL))
+    if strict and not path_found:
+        raise RuntimeError(f"{lib_name} shared object not found.")
+    return path_found, ctypes_handles
+
+
+def _load_cuda_library_from_system(lib_name: str):
+    """
+    Attempts to load shared object file installed via system/cuda-toolkit.
+
+    `lib_name`: Name of library to load without extension or `lib` prefix.
+    """
+    # Where to look for the shared lib in decreasing order of preference.
+    paths = (
+        os.environ.get(f"{lib_name.upper()}_HOME"),
+        os.environ.get(f"{lib_name.upper()}_PATH"),
+        os.environ.get("CUDA_HOME"),
+        os.environ.get("CUDA_PATH"),
+        "/usr/local/cuda",
+    )
+    for path in paths:
+        if path is None:
+            continue
+        libs = glob.glob(f"{path}/**/lib{lib_name}{_get_sys_extension()}*", recursive=True)
+        libs = [lib for lib in libs if "stub" not in lib]
+        libs.sort(reverse=True, key=os.path.basename)
+        if libs:
+            return True, ctypes.CDLL(libs[0], mode=ctypes.RTLD_GLOBAL)
+
+    # Search in LD_LIBRARY_PATH.
+    try:
+        _lib_handle = ctypes.CDLL(f"lib{lib_name}{_get_sys_extension()}", mode=ctypes.RTLD_GLOBAL)
+        return True, _lib_handle
+    except OSError:
+        return False, None
+
+
+def _load_cuda_library(lib_name: str):
+    """
+    Load given shared library.
+
+    Prioritize loading from system/toolkit before checking python packages.
+    """
+    # Attempt to locate library in system.
+    found, handle = _load_cuda_library_from_system(lib_name)
+    if found:
+        return True, handle
+
+    # Attempt to locate library in Python dist-packages.
+    found, handle = _load_cuda_library_from_python(lib_name)
+    if found:
+        return False, handle
+
+    raise RuntimeError(f"{lib_name} shared object not found.")


 @functools.lru_cache(maxsize=None)

@@ -387,11 +362,22 @@ def _load_core_library():
     if "NVTE_PROJECT_BUILDING" not in os.environ or bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))):
         try:
             sanity_checks_for_pypi_installation()
-            _CUDNN_LIB_CTYPES = _load_cudnn()
-            _NVRTC_LIB_CTYPES = _load_nvrtc()
-            _CURAND_LIB_CTYPES = _load_curand()
-            _CUBLAS_LIB_CTYPES = _load_nvidia_cuda_library("cublas")
-            _CUDART_LIB_CTYPES = _load_nvidia_cuda_library("cuda_runtime")
+            # `_load_cuda_library` is used for packages that must be loaded
+            # during runtime. Both system and pypi packages are searched
+            # and an error is thrown if not found.
+            _, _CUDNN_LIB_CTYPES = _load_cuda_library("cudnn")
+            system_nvrtc, _NVRTC_LIB_CTYPES = _load_cuda_library("nvrtc")
+            system_curand, _CURAND_LIB_CTYPES = _load_cuda_library("curand")
+
+            # This additional step is necessary to be able to install TE wheels
+            # and import TE (without any guards) in an environment where the cuda
+            # toolkit might be absent
+            load_libs_for_no_ctk = not system_nvrtc and not system_curand
+            if load_libs_for_no_ctk:
+                _CUBLAS_LIB_CTYPES = _load_cuda_library_from_python("cublas", strict=True)
+                _CUDART_LIB_CTYPES = _load_cuda_library_from_python("cudart", strict=True)
+                _CUDNN_ALL_LIB_CTYPES = _load_cuda_library_from_python("cudnn", strict=True)

     # Needed to find the correct headers for NVRTC kernels.
     if not os.getenv("NVTE_CUDA_INCLUDE_DIR") and _nvidia_cudart_include_dir():
```
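The net effect of the refactor is a single system-first, pip-second lookup per library instead of three bespoke loaders. A condensed standalone sketch of that order (simplified Linux-only logic; the helper names `find_in_system`/`load_cuda_library` are mine, not the module's):

```python
import ctypes
import glob
import os

def find_in_system(lib_name: str):
    """Search toolkit-style locations (env vars, /usr/local/cuda) for the lib."""
    for root in (os.environ.get("CUDA_HOME"), os.environ.get("CUDA_PATH"), "/usr/local/cuda"):
        if root is None:
            continue
        hits = sorted(glob.glob(f"{root}/**/lib{lib_name}.so*", recursive=True), reverse=True)
        hits = [h for h in hits if "stub" not in h]
        if hits:
            return ctypes.CDLL(hits[0], mode=ctypes.RTLD_GLOBAL)
    return None

def load_cuda_library(lib_name: str):
    """System/toolkit first; a pip-installed copy would be the fallback."""
    handle = find_in_system(lib_name)
    if handle is not None:
        return True, handle  # True: loaded from the system
    # ... fall back to site-packages/nvidia/<pkg>/lib, as the real code does
    raise RuntimeError(f"{lib_name} shared object not found.")
```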
transformer_engine/common/activation/activation_template.h
transformer_engine/common/activation/gelu.cu
transformer_engine/common/activation/relu.cu
transformer_engine/common/activation/swiglu.cu

The same one-line header change in each of these files:

```diff
 /*************************************************************************
- * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  *
  * See LICENSE for license information.
  ************************************************************************/
```
transformer_engine/common/cast/cast.cu

```diff
 /*************************************************************************
- * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  *
  * See LICENSE for license information.
  ************************************************************************/

@@ -102,3 +102,18 @@ void nvte_multi_tensor_quantize(const NVTETensor *inputs, NVTETensor *outputs,
     NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream, detail::get_compute_stream_event(s)));
   }
 }
+
+// Group quantize assumes contiguous inputs and outputs in memory allocation
+// TODO (zhongbo): find a better way to make it a more generalized API
+void nvte_group_nvfp4_quantize_with_amax(const NVTETensor input, NVTETensor *outputs,
+                                         const size_t *split_sections, const size_t num_tensors,
+                                         const NVTEQuantizationConfig quant_config,
+                                         cudaStream_t stream) {
+  NVTE_API_CALL(nvte_group_nvfp4_quantize_with_amax);
+  using namespace transformer_engine;
+  constexpr bool IS_ACT = false;
+  dispatch::group_quantize_fwd_helper<IS_ACT, Empty, nullptr>(input, outputs, split_sections,
+                                                              num_tensors, quant_config, stream);
+}
```
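The new entry point assumes one contiguous input buffer that is logically split into `num_tensors` pieces by `split_sections`. A rough sketch of that caller-side contract with PyTorch (shapes and names here are illustrative, not part of the API):

```python
import torch

# One contiguous activation buffer holding three logical tensors, stacked
# along the first dimension: the layout the grouped quantize assumes.
split_sections = (1024, 512, 2048)
hidden = 4096
buf = torch.randn(sum(split_sections), hidden, dtype=torch.bfloat16)

# Each split would be quantized into its own output tensor by the grouped
# kernel; torch.split only illustrates the partitioning, not the quantization.
views = torch.split(buf, split_sections, dim=0)
assert [v.shape[0] for v in views] == list(split_sections)
```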
transformer_engine/common/cast/core/common.cuh

```diff
 /*************************************************************************
- * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  *
  * See LICENSE for license information.
  ************************************************************************/
```
transformer_engine/common/cast/dispatch/dequantize.cuh

```diff
 /*************************************************************************
- * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  *
  * See LICENSE for license information.
  ************************************************************************/

@@ -27,9 +27,9 @@ inline void dequantize_helper(const Tensor &input, Tensor *output, cudaStream_t
   switch (input.scaling_mode) {
     case NVTE_DELAYED_TENSOR_SCALING: {
-      NVTE_CHECK(is_fp8_dtype(input.data.dtype) || is_int8_dtype(input.data.dtype),
-                 "Input must have FP8 or INT8 type.");
-      NVTE_CHECK(!is_fp8_dtype(output->data.dtype) && !is_int8_dtype(output->data.dtype),
-                 "Output must be in higher precision.");
-      NVTE_CHECK(output->data.shape == input.data.shape, "Input and output shapes need to match.");
+      NVTE_CHECK(is_fp8_dtype(input.dtype()) || is_int8_dtype(input.dtype()),
+                 "Input must have FP8 type.");
+      NVTE_CHECK(!is_fp8_dtype(output->dtype()) && !is_int8_dtype(output->dtype()),
+                 "Output must be in higher precision.");
+      NVTE_CHECK(output->shape() == input.shape(), "Input and output shapes need to match.");
       fp8::dequantize(input, output, stream);
       break;
     }
```
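For the delayed-tensor-scaling path this helper dispatches to, dequantization amounts to an upcast times the stored inverse scale. A small PyTorch sketch of that round trip (assumes a build with float8 dtypes; tolerances are loose to cover FP8 rounding):

```python
import torch

x = torch.randn(4, 4)
scale = 448.0 / x.abs().max()                 # map amax onto E4M3's max normal (448)
x_fp8 = (x * scale).to(torch.float8_e4m3fn)   # quantize
scale_inv = 1.0 / scale
x_back = x_fp8.to(torch.float32) * scale_inv  # dequantize into higher precision
assert torch.allclose(x, x_back, atol=0.1, rtol=0.1)
```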
transformer_engine/common/cast/dispatch/gated.cuh

```diff
 /*************************************************************************
- * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  *
  * See LICENSE for license information.
  ************************************************************************/

@@ -14,6 +14,7 @@
 #include <transformer_engine/transformer_engine.h>

 #include "../../common.h"
+#include "../../transpose/transpose.h"
 #include "../../utils.cuh"
 #include "../fp8/gated_fp8.cuh"
 #include "../mxfp8/gated_mxfp8.cuh"

@@ -53,6 +54,20 @@ void quantize_gated_fwd_helper(const NVTETensor nvte_input, NVTETensor nvte_outp
       } else {
         fp8::cast_gated_fwd<ParamOP, ActOP>(input, output, p, stream);
       }
+      if (is_fp8_dtype(output->dtype()) && output->has_columnwise_data()) {
+        // FP8 kernel only populates row-wise data, so perform
+        // transpose separately if needed
+        Tensor transpose_in, transpose_out, dummy;
+        transpose_in.scaling_mode = NVTE_DELAYED_TENSOR_SCALING;
+        transpose_in.data.dptr = output->data.dptr;
+        transpose_in.data.shape = {output->flat_first_dim(), output->flat_last_dim()};
+        transpose_in.data.dtype = output->data.dtype;
+        transpose_out.scaling_mode = NVTE_DELAYED_TENSOR_SCALING;
+        transpose_out.data.dptr = output->columnwise_data.dptr;
+        transpose_out.data.shape = {output->flat_last_dim(), output->flat_first_dim()};
+        transpose_out.data.dtype = output->data.dtype;
+        detail::transpose(transpose_in, /*noop=*/dummy, &transpose_out, stream);
+      }
       break;
     }
     case NVTE_MXFP8_1D_SCALING: {

@@ -98,8 +113,8 @@ void quantize_gated_bwd_helper(const NVTETensor nvte_grad, const NVTETensor nvte
   const size_t rows = gated_input.flat_first_dim();
   const size_t cols = gated_input.flat_last_dim() / 2;

-  NVTE_CHECK(!is_fp8_dtype(grad.data.dtype), "Grad input must be in higher precision.");
-  NVTE_CHECK(grad.data.dtype == gated_input.data.dtype, "Types of both inputs must match.");
+  NVTE_CHECK(!is_fp8_dtype(grad.dtype()), "Grad input must be in higher precision.");
+  NVTE_CHECK(grad.dtype() == gated_input.dtype(), "Types of both inputs must match.");
   NVTE_CHECK(grad.flat_first_dim() == rows,
              "Wrong Grad shape. Expected first dimension (after flattening) [", rows,
              ", *], got [",

@@ -116,9 +131,9 @@ void quantize_gated_bwd_helper(const NVTETensor nvte_grad, const NVTETensor nvte
   NVTE_CHECK(output->flat_last_dim() == cols * 2,
              "Wrong output shape. Expected (after flattening) [*, ", cols * 2, "], got [",
              output->flat_first_dim(), ", ", output->flat_last_dim(), "].");
-  NVTE_CHECK(gated_input.data.shape == output->data.shape,
-             "Gated input and output shapes must match. Input shape: ", gated_input.data.shape,
-             ", output shape: ", output->data.shape, ".");
+  NVTE_CHECK(gated_input.shape() == output->shape(),
+             "Gated input and output shapes must match. Input shape: ", gated_input.shape(),
+             ", output shape: ", output->shape(), ".");

   switch (output->scaling_mode) {
     case NVTE_DELAYED_TENSOR_SCALING: {

@@ -129,6 +144,20 @@ void quantize_gated_bwd_helper(const NVTETensor nvte_grad, const NVTETensor nvte
       } else {
         fp8::cast_gated_bwd<ParamOP, ActOP, DActOP>(gated_input, grad, output, p, stream);
       }
+      if (is_fp8_dtype(output->dtype()) && output->has_columnwise_data()) {
+        // FP8 kernel only populates row-wise data, so perform
+        // transpose separately if needed
+        Tensor transpose_in, transpose_out, dummy;
+        transpose_in.scaling_mode = NVTE_DELAYED_TENSOR_SCALING;
+        transpose_in.data.dptr = output->data.dptr;
+        transpose_in.data.shape = {output->flat_first_dim(), output->flat_last_dim()};
+        transpose_in.data.dtype = output->data.dtype;
+        transpose_out.scaling_mode = NVTE_DELAYED_TENSOR_SCALING;
+        transpose_out.data.dptr = output->columnwise_data.dptr;
+        transpose_out.data.shape = {output->flat_last_dim(), output->flat_first_dim()};
+        transpose_out.data.dtype = output->data.dtype;
+        detail::transpose(transpose_in, /*noop=*/dummy, &transpose_out, stream);
+      }
       break;
     }
     case NVTE_MXFP8_1D_SCALING: {
```
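The added fallback exploits a property of per-tensor (delayed) scaling: the column-wise FP8 buffer is byte-for-byte the transpose of the row-wise one, since a single scale covers every element, so no requantization is needed. A PyTorch sketch of that equivalence (assumes float8 dtype and copy support in the installed build):

```python
import torch

rowwise = torch.randn(128, 256).to(torch.float8_e4m3fn)  # kernel's row-major output

# Under per-tensor scaling the columnwise buffer is just the transposed bytes;
# the one shared scale applies unchanged to every element.
columnwise = rowwise.t().contiguous()
assert torch.equal(columnwise.to(torch.float32), rowwise.to(torch.float32).t())
```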
transformer_engine/common/cast/dispatch/quantize.cuh

```diff
 /*************************************************************************
- * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  *
  * See LICENSE for license information.
  ************************************************************************/

@@ -19,6 +19,7 @@
 #include "../core/common.cuh"
 #include "../fp8/quantize_fp8.cuh"
 #include "../mxfp8/quantize_mxfp8.cuh"
+#include "../nvfp4/group_quantize_transpose_nvfp4.cuh"
 #include "../nvfp4/quantize_nvfp4.cuh"
 #include "../nvfp4/quantize_transpose_nvfp4.cuh"

@@ -154,17 +155,10 @@ void quantize_fwd_helper(const NVTETensor input, NVTETensor output,
       FP8BlockwiseRowwiseOption rowwise_option = FP8BlockwiseRowwiseOption::NONE;
       FP8BlockwiseColumnwiseOption columnwise_option = FP8BlockwiseColumnwiseOption::NONE;
       if (output_tensor->has_data()) {
-        bool rowwise_compact = (quant_config_cpp.float8_block_scale_tensor_format ==
-                                Float8BlockScaleTensorFormat::COMPACT);
-        rowwise_option = rowwise_compact ? FP8BlockwiseRowwiseOption::ROWWISE_COMPACT
-                                         : FP8BlockwiseRowwiseOption::ROWWISE_GEMM_READY;
+        rowwise_option = FP8BlockwiseRowwiseOption::ROWWISE_GEMM_READY;
       }
       if (output_tensor->has_columnwise_data()) {
-        bool columnwise_compact = (quant_config_cpp.float8_block_scale_tensor_format ==
-                                   Float8BlockScaleTensorFormat::COMPACT);
-        columnwise_option = columnwise_compact
-                                ? FP8BlockwiseColumnwiseOption::COLUMNWISE_COMPACT
-                                : FP8BlockwiseColumnwiseOption::COLUMNWISE_GEMM_READY;
+        columnwise_option = FP8BlockwiseColumnwiseOption::COLUMNWISE_GEMM_READY;
       }
       quantize_transpose_vector_blockwise(input_tensor->data, output_tensor->scale_inv,
                                           output_tensor->columnwise_scale_inv,

@@ -307,17 +301,10 @@ void quantize_bwd_helper(const NVTETensor grad, const NVTETensor input, NVTETens
       FP8BlockwiseRowwiseOption rowwise_option = FP8BlockwiseRowwiseOption::NONE;
       FP8BlockwiseColumnwiseOption columnwise_option = FP8BlockwiseColumnwiseOption::NONE;
       if (output_tensor->has_data()) {
-        bool rowwise_compact = (quant_config_cpp.float8_block_scale_tensor_format ==
-                                Float8BlockScaleTensorFormat::COMPACT);
-        rowwise_option = rowwise_compact ? FP8BlockwiseRowwiseOption::ROWWISE_COMPACT
-                                         : FP8BlockwiseRowwiseOption::ROWWISE_GEMM_READY;
+        rowwise_option = FP8BlockwiseRowwiseOption::ROWWISE_GEMM_READY;
       }
       if (output_tensor->has_columnwise_data()) {
-        bool columnwise_compact = (quant_config_cpp.float8_block_scale_tensor_format ==
-                                   Float8BlockScaleTensorFormat::COMPACT);
-        columnwise_option = columnwise_compact
-                                ? FP8BlockwiseColumnwiseOption::COLUMNWISE_COMPACT
-                                : FP8BlockwiseColumnwiseOption::COLUMNWISE_GEMM_READY;
+        columnwise_option = FP8BlockwiseColumnwiseOption::COLUMNWISE_GEMM_READY;
       }
       quantize_transpose_vector_blockwise(grad_tensor->data, output_tensor->scale_inv,
                                           output_tensor->columnwise_scale_inv,

@@ -330,6 +317,70 @@ void quantize_bwd_helper(const NVTETensor grad, const NVTETensor input, NVTETens
   }
 }

+template <bool IS_ACT, typename ParamOP, float (*OP)(float, const ParamOP &)>
+void group_quantize_fwd_helper(const NVTETensor input, NVTETensor *outputs,
+                               const size_t *split_sections, const size_t num_tensors,
+                               const NVTEQuantizationConfig quant_config, cudaStream_t stream) {
+  using namespace detail;
+  const Tensor *input_tensor = convertNVTETensorCheck(input);
+  std::vector<Tensor *> output_tensors;
+  for (size_t i = 0; i < num_tensors; ++i) {
+    output_tensors.push_back(convertNVTETensorCheck(outputs[i]));
+  }
+
+  // Quantization config
+  QuantizationConfig quant_config_cpp;
+  if (quant_config != nullptr) {
+    quant_config_cpp = *reinterpret_cast<QuantizationConfig *>(quant_config);
+  }
+
+  // Noop flag
+  Tensor dummy_tensor;
+  Tensor *noop_tensor = &dummy_tensor;
+  if (quant_config_cpp.noop_tensor != nullptr) {
+    noop_tensor = convertNVTETensorCheck(quant_config_cpp.noop_tensor);
+  }
+
+  // Check for unsupported options
+  if (quant_config_cpp.stochastic_rounding) {
+    NVTE_CHECK(output_tensors[0]->scaling_mode == NVTE_NVFP4_1D_SCALING,
+               "Stochastic rounding is only supported for NVFP4 quantization.");
+  }
+
+  // Take the scaling mode of the first output tensor
+  auto scaling_mode = output_tensors[0]->scaling_mode;
+
+  // Dispatch to quantization kernel depending on data format
+  switch (scaling_mode) {
+    case NVTE_NVFP4_1D_SCALING: {
+      NVTE_CHECK(!IS_ACT, "IS_ACT is not supported by FWD NVTE_NVFP4_1D_SCALING");
+      // Check tensors
+      CheckNoopTensor(*noop_tensor, "cast_noop");
+      CheckInputTensor(*input_tensor, "input");
+      // Skip checking output tensor list
+      // output list here is allowed to have empty tensor
+
+      // Choose kernel
+      int32_t rows = input_tensor->flat_first_dim();
+      int32_t cols = input_tensor->flat_last_dim();
+      auto dtype = input_tensor->dtype();
+      NVTE_CHECK(!quant_config_cpp.nvfp4_2d_quantization,
+                 "2D quantization is not supported for group quantize.");
+
+      // Launch NVFP4 group quantize kernel
+      nvfp4::group_quantize_transpose</*use_2d_quantization*/ false>(
+          *input_tensor, noop_tensor, output_tensors, split_sections, num_tensors,
+          &quant_config_cpp, stream);
+      break;
+    }
+    default:
+      NVTE_ERROR("Not implemented scaling mode: " + to_string(scaling_mode) + ".");
+  }
+}
+
 }  // namespace dispatch
 }  // namespace transformer_engine
```
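Both quantize helpers now hard-code the GEMM-ready blockwise scale layout, so the COMPACT branch of the old code is gone and selection reduces to which output buffers exist. A minimal Python sketch of the remaining logic (function and flag names are mine, mirroring the C++ enums):

```python
# has_data / has_columnwise_data mirror the C++ Tensor accessors.
def pick_blockwise_options(has_data: bool, has_columnwise_data: bool):
    rowwise_option = "NONE"
    columnwise_option = "NONE"
    if has_data:
        rowwise_option = "ROWWISE_GEMM_READY"
    if has_columnwise_data:
        columnwise_option = "COLUMNWISE_GEMM_READY"
    return rowwise_option, columnwise_option

assert pick_blockwise_options(True, False) == ("ROWWISE_GEMM_READY", "NONE")
assert pick_blockwise_options(True, True) == ("ROWWISE_GEMM_READY", "COLUMNWISE_GEMM_READY")
```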
transformer_engine/common/cast/fp8/dequantize_fp8.cuh
transformer_engine/common/cast/fp8/gated_fp8.cuh
transformer_engine/common/cast/fp8/quantize_fp8.cuh

The same one-line header change in each of these files:

```diff
 /*************************************************************************
- * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  *
  * See LICENSE for license information.
  ************************************************************************/
```
transformer_engine/common/cast/mxfp8/dequantize_mxfp8.cuh

```diff
 /*************************************************************************
- * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  *
  * See LICENSE for license information.
  ************************************************************************/

@@ -234,8 +234,7 @@ inline void dequantize(const Tensor &input, Tensor *output, cudaStream_t stream)
   bool use_colwise_scaling = input.has_columnwise_data();
   checkCuDriverContext(stream);

-  const auto &input_shape = input.data.shape;
-  NVTE_CHECK(input_shape.size() >= 2, "Input must have at least 2 dimensions.");
+  NVTE_CHECK(input.dim() >= 2, "Input must have at least 2 dimensions.");

   if (use_rowwise_scaling) {
     NVTE_CHECK(input.has_data(), "Cannot dequantize tensor without rowwise data.");

@@ -247,8 +246,9 @@ inline void dequantize(const Tensor &input, Tensor *output, cudaStream_t stream)
     NVTE_CHECK(is_fp8_dtype(input.columnwise_data.dtype), "Input must have FP8 type.");
   }
+  NVTE_CHECK(!input.with_gemm_swizzled_scales, "Input must have scales in compact format.");

   NVTE_CHECK(!is_fp8_dtype(output->data.dtype), "Output must be in higher precision.");
-  NVTE_CHECK(output->data.shape == input.data.shape, "Input and output shapes need to match.");
+  NVTE_CHECK(output->shape() == input.shape(), "Input and output shapes need to match.");

   // TODO: Make more general
   const size_t scale_dim_X_rowwise = use_rowwise_scaling ? 32 : 1;
```
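In contrast to the per-tensor path, the MXFP8 path applies one E8M0 scale per 32 consecutive row elements (the `scale_dim_X_rowwise = 32` visible above). A NumPy sketch of that blockwise math, with invented shapes:

```python
import numpy as np

rows, cols, block = 4, 128, 32
q = np.random.randn(rows, cols).astype(np.float32)  # stand-in for decoded FP8 values
exp = np.random.randint(120, 130, size=(rows, cols // block), dtype=np.uint8)

# E8M0 stores only a biased exponent; the scale it encodes is 2**(e - 127).
scales = np.exp2(exp.astype(np.float32) - 127.0)

# Broadcast one scale over each 32-element block along the row.
out = (q.reshape(rows, cols // block, block) * scales[:, :, None]).reshape(rows, cols)
```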
transformer_engine/common/cast/mxfp8/gated_mxfp8.cuh

```diff
 /*************************************************************************
- * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  *
  * See LICENSE for license information.
  ************************************************************************/

@@ -22,6 +22,7 @@
 #include "../../util/math.h"
 #include "../../util/ptx.cuh"
 #include "../../utils.cuh"
+#include "swizzle.cuh"

 namespace transformer_engine {
 namespace dispatch {

@@ -54,7 +55,8 @@ constexpr size_t THREADS_PER_BANK = TOTAL_BANKS_WIDTH / SCALE_DIM_X;  // 4 = 128
 #ifndef __HIP_PLATFORM_AMD__
 template <bool IS_BWD, typename ParamOP, float (*ActOP)(float, const ParamOP &),
           float (*DActOP)(float, const ParamOP &), typename IType, typename OType,
-          bool ROWWISE_SCALING, bool COLWISE_SCALING, size_t THREADS_PER_CHUNK>
+          bool ROWWISE_SCALING, bool COLWISE_SCALING, bool WITH_GEMM_SWIZZLED_SCALES,
+          size_t THREADS_PER_CHUNK>
 __global__ void __launch_bounds__(THREADS_PER_CHUNK)
     quantize_gated_mxfp8_kernel(const __grid_constant__ CUtensorMap tensor_map_grad,
                                 const __grid_constant__ CUtensorMap tensor_map_input_act,

@@ -71,6 +73,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
   using IType2 = typename ptx::FPx2<IType>;
   using OType2 = typename ptx::FPx2<OType>;
+  using transformer_engine::dispatch::mxfp8::swizzle::gemm_swizzled_scale_idx;
+
   constexpr size_t STAGES = CHUNK_DIM_Y / BUFF_DIM_Y;
   static_assert(STAGES >= 1);

@@ -358,14 +362,17 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
       // 2. Compute E8M0 scaling factor
       const e8m0_t biased_exponent_act =
           ptx::float_to_e8m0(thread_amax_act * Quantized_Limits<OType>::max_norm_rcp);

       const size_t global_scales_offset_Y = scales_offset_Y_colwise + stage;
       const size_t global_scales_offset_X = scales_offset_X_colwise;
-      const size_t scale_idx =
-          global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X;
+      size_t scale_idx;
+      if constexpr (WITH_GEMM_SWIZZLED_SCALES) {
+        scale_idx = gemm_swizzled_scale_idx(global_scales_offset_X, global_scales_offset_Y,
+                                            DIVUP(rows, static_cast<size_t>(128)));
+      } else {
+        scale_idx = global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X;
+      }
       const bool row_out_of_bounds_colwise = (row_base_colwise + stage_offset_Y) >= rows;
       const bool out_of_bounds_colwise = row_out_of_bounds_colwise || col_out_of_bounds_colwise;
       if (tid_Y_colwise == 0 && (!out_of_bounds_colwise)) {
         scales_colwise[scale_idx] = biased_exponent_act;
       }

@@ -377,8 +384,14 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
         const e8m0_t biased_exponent_gate =
             ptx::float_to_e8m0(thread_amax_gate * Quantized_Limits<OType>::max_norm_rcp);
         // const size_t scale_idx_gate = scale_idx + scale_stride_colwise / 2;
-        const size_t scale_idx_gate = scale_idx + gate_scale_idx_offset_colwise;
+        size_t scale_idx_gate;
+        if constexpr (WITH_GEMM_SWIZZLED_SCALES) {
+          scale_idx_gate = gemm_swizzled_scale_idx(
+              global_scales_offset_X + gate_scale_idx_offset_colwise, global_scales_offset_Y,
+              DIVUP(rows, static_cast<size_t>(128)));
+        } else {
+          scale_idx_gate = scale_idx + gate_scale_idx_offset_colwise;
+        }
         if (tid_Y_colwise == 0 && (!out_of_bounds_colwise)) {
           scales_colwise[scale_idx_gate] = biased_exponent_gate;
         }

@@ -560,7 +573,14 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
           ptx::float_to_e8m0(thread_amax_act * Quantized_Limits<OType>::max_norm_rcp);
       const size_t stage_scales_offset_Y = scales_offset_Y_rowwise + stage_offset_Y;
       const size_t stage_scales_offset_X = scales_offset_X_rowwise;
-      const size_t scale_idx =
-          stage_scales_offset_Y * scale_stride_rowwise + stage_scales_offset_X;
+      size_t scale_idx;
+      if constexpr (WITH_GEMM_SWIZZLED_SCALES) {
+        const size_t output_cols = (IS_BWD ? 2 : 1) * cols;
+        scale_idx = gemm_swizzled_scale_idx(stage_scales_offset_Y, stage_scales_offset_X,
+                                            DIVUP(output_cols, static_cast<size_t>(128)));
+      } else {
+        scale_idx = stage_scales_offset_Y * scale_stride_rowwise + stage_scales_offset_X;
+      }
       const bool row_out_of_bounds_rowwise = (row_base_rowwise + stage_offset_Y) >= rows;
       const bool out_of_bounds_rowwise = row_out_of_bounds_rowwise || col_out_of_bounds_rowwise;
       if (!out_of_bounds_rowwise) {

@@ -576,7 +596,16 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
       if constexpr (IS_BWD) {
         const e8m0_t biased_exponent_gate =
             ptx::float_to_e8m0(thread_amax_gate * Quantized_Limits<OType>::max_norm_rcp);
-        const size_t scale_idx_gate = scale_idx + gate_scale_idx_offset_rowwise;
+        size_t scale_idx_gate;
+        if constexpr (WITH_GEMM_SWIZZLED_SCALES) {
+          const size_t output_cols = (IS_BWD ? 2 : 1) * cols;
+          scale_idx_gate = gemm_swizzled_scale_idx(
+              stage_scales_offset_Y, stage_scales_offset_X + gate_scale_idx_offset_rowwise,
+              DIVUP(output_cols, static_cast<size_t>(128)));
+        } else {
+          scale_idx_gate = scale_idx + gate_scale_idx_offset_rowwise;
+        }
         if (!out_of_bounds_rowwise) {
           scales_rowwise[scale_idx_gate] = biased_exponent_gate;
         }

@@ -670,7 +699,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
     parity ^= 1;

     destroy_barriers<STAGES>(mbar, is_master_thread);
 #endif  // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
-}
+}  // NOLINT(readability/fn_size)
 #endif  // __HIP_PLATFORM_AMD__
 }  // namespace gated_kernel

@@ -686,6 +715,7 @@ void quantize_gated(const Tensor &gated_input, const Tensor &grad, Tensor *outpu
   const bool USE_ROWWISE_SCALING = output->has_data();
   const bool USE_COLWISE_SCALING = output->has_columnwise_data();
+  const bool with_gemm_swizzled_scales = output->with_gemm_swizzled_scales;

   if (USE_ROWWISE_SCALING) {
     NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated.");

@@ -729,113 +759,140 @@ void quantize_gated(const Tensor &gated_input, const Tensor &grad, Tensor *outpu
       gated_input.dtype(), IType,
       TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
           output->dtype(), OType,
+          TRANSFORMER_ENGINE_SWITCH_CONDITION(
+              with_gemm_swizzled_scales, WITH_GEMM_SWIZZLED_SCALES,

           alignas(64) CUtensorMap tensor_map_grad{};
           alignas(64) CUtensorMap tensor_map_input_act{};
           alignas(64) CUtensorMap tensor_map_input_gate{};
           alignas(64) CUtensorMap tensor_map_output_act_rowwise{};
           alignas(64) CUtensorMap tensor_map_output_gate_rowwise{};
           alignas(64) CUtensorMap tensor_map_output_act_colwise{};
           alignas(64) CUtensorMap tensor_map_output_gate_colwise{};

           constexpr size_t input_type_bit_size = TypeInfo<IType>::size;
           constexpr size_t output_type_bit_size = TypeInfo<OType>::size;

           if constexpr (IS_BWD) {
             create_2D_tensor_map(tensor_map_grad, grad.data, rows, cols, BUFF_DIM_Y, BUFF_DIM_X,
                                  cols, 0, input_type_bit_size);
           }

           const uint32_t tensor_stride_elems = output_cols;
           create_2D_tensor_map(tensor_map_input_act, gated_input.data, rows, cols, BUFF_DIM_Y,
                                BUFF_DIM_X, cols * 2, 0, input_type_bit_size);
           create_2D_tensor_map(tensor_map_input_gate, gated_input.data, rows, cols, BUFF_DIM_Y,
                                BUFF_DIM_X, cols * 2, cols, input_type_bit_size);

           if (USE_ROWWISE_SCALING) {
             create_2D_tensor_map(tensor_map_output_act_rowwise, output->data, rows, cols,
                                  BUFF_DIM_Y, BUFF_DIM_X, tensor_stride_elems, 0,
                                  output_type_bit_size);
             create_2D_tensor_map(tensor_map_output_gate_rowwise, output->data, rows, cols,
                                  BUFF_DIM_Y, BUFF_DIM_X, tensor_stride_elems, cols,
                                  output_type_bit_size);
           }

           if (USE_COLWISE_SCALING) {
             create_2D_tensor_map(tensor_map_output_act_colwise, output->columnwise_data, rows,
                                  cols, BUFF_DIM_Y, BUFF_DIM_X, tensor_stride_elems, 0,
                                  output_type_bit_size);
             create_2D_tensor_map(tensor_map_output_gate_colwise, output->columnwise_data, rows,
                                  cols, BUFF_DIM_Y, BUFF_DIM_X, tensor_stride_elems, cols,
                                  output_type_bit_size);
           }

           const size_t buff_elems_total = BUFFS_NUM * BUFF_DIM_Y * BUFF_DIM_X;
           const size_t input_buff_size = (buff_elems_total * input_type_bit_size) / 8;
           const size_t output_buff_size = (buff_elems_total * output_type_bit_size) / 8;
           const size_t buff_size_aligned_in =
               DIVUP_TO_MULTIPLE(input_buff_size, TMA_SHMEM_ALIGNMENT);
           const size_t buff_size_aligned_out =
               DIVUP_TO_MULTIPLE(output_buff_size, TMA_SHMEM_ALIGNMENT);

           const size_t grad_mem = (IS_BWD ? buff_size_aligned_in : 0);
           const size_t in_act_mem = buff_size_aligned_in;
           const size_t in_gate_mem = buff_size_aligned_in;
           const size_t in_mem = grad_mem + in_act_mem + in_gate_mem;

           const size_t out_act_mem = buff_size_aligned_out;
           const size_t out_gate_mem = (IS_BWD ? buff_size_aligned_out : 0);
           size_t out_mem = out_act_mem + out_gate_mem;
           if (USE_ROWWISE_SCALING && USE_COLWISE_SCALING) { out_mem *= 2; }

           const size_t shmem_size = in_mem + out_mem + TMA_SHMEM_ALIGNMENT;

+          // Zero out swizzled scales if padding is needed
+          /// TODO (tmoon) Handle this within the cast kernel
+          if (with_gemm_swizzled_scales) {
+            constexpr size_t TILE_DIM_X = 128;  // Tile dim in data buffer
+            constexpr size_t TILE_DIM_Y = 128;
+            if (cols % TILE_DIM_X != 0 || rows % TILE_DIM_Y != 0) {
+              if (USE_ROWWISE_SCALING) {
+                NVTE_CHECK_CUDA(cudaMemsetAsync(output->scale_inv.dptr, 0,
+                                                output->scale_inv.buffer_size_bytes(), stream));
+              }
+              if (USE_COLWISE_SCALING) {
+                NVTE_CHECK_CUDA(cudaMemsetAsync(output->columnwise_scale_inv.dptr, 0,
+                                                output->columnwise_scale_inv.buffer_size_bytes(),
+                                                stream));
+              }
+            }
+          }
+
           switch (scaling_type) {
             case ScalingType::ROWWISE: {
               auto kernel =
                   quantize_gated_mxfp8_kernel<IS_BWD, ParamOP, ActOP, DActOP, IType, OType,
-                                              true, false, THREADS_PER_CHUNK_NON_COLWISE>;
+                                              true, false, WITH_GEMM_SWIZZLED_SCALES,
+                                              THREADS_PER_CHUNK_NON_COLWISE>;
               NVTE_CHECK_CUDA(cudaFuncSetAttribute(
                   kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size));
               kernel<<<grid, block_size, shmem_size, stream>>>(
                   tensor_map_grad, tensor_map_input_act, tensor_map_input_gate,
                   tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise,
                   tensor_map_output_act_colwise, tensor_map_output_gate_colwise,
                   scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise,
                   scale_stride_colwise, p);
               break;
             }
             case ScalingType::COLWISE: {
               auto kernel =
                   quantize_gated_mxfp8_kernel<IS_BWD, ParamOP, ActOP, DActOP, IType, OType,
-                                              false, true, THREADS_PER_CHUNK_COLWISE>;
+                                              false, true, WITH_GEMM_SWIZZLED_SCALES,
+                                              THREADS_PER_CHUNK_COLWISE>;
               NVTE_CHECK_CUDA(cudaFuncSetAttribute(
                   kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size));
               kernel<<<grid, block_size, shmem_size, stream>>>(
                   tensor_map_grad, tensor_map_input_act, tensor_map_input_gate,
                   tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise,
                   tensor_map_output_act_colwise, tensor_map_output_gate_colwise,
                   scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise,
                   scale_stride_colwise, p);
               break;
             }
             case ScalingType::BIDIMENSIONAL: {
               auto kernel =
                   quantize_gated_mxfp8_kernel<IS_BWD, ParamOP, ActOP, DActOP, IType, OType,
-                                              true, true, THREADS_PER_CHUNK_NON_COLWISE>;
+                                              true, true, WITH_GEMM_SWIZZLED_SCALES,
+                                              THREADS_PER_CHUNK_NON_COLWISE>;
               NVTE_CHECK_CUDA(cudaFuncSetAttribute(
                   kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size));
               kernel<<<grid, block_size, shmem_size, stream>>>(
                   tensor_map_grad, tensor_map_input_act, tensor_map_input_gate,
                   tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise,
                   tensor_map_output_act_colwise, tensor_map_output_gate_colwise,
                   scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise,
                   scale_stride_colwise, p);
               break;
             }
           }
           NVTE_CHECK_CUDA(cudaGetLastError()););  // NOLINT(*)
-      );  // NOLINT(*)
+          );  // NOLINT(*)
+      );  // NOLINT(*)
 #endif
 }
```
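The scale computation feeding the new index logic is unchanged: each block stores the E8M0 exponent of `amax * max_norm_rcp`, where `max_norm` is the FP8 type's largest normal (448 for E4M3). A scalar Python sketch, assuming E4M3 outputs and round-up exponent selection (the PTX conversion's exact rounding may differ):

```python
import math

E4M3_MAX_NORM = 448.0

def float_to_e8m0(x: float) -> int:
    """Biased power-of-two exponent (E8M0), rounded up so the block stays in range."""
    if x == 0.0:
        return 0
    return min(254, max(0, math.ceil(math.log2(x)) + 127))

def block_scale(amax: float) -> int:
    # Mirrors float_to_e8m0(thread_amax * max_norm_rcp) in the kernel.
    return float_to_e8m0(amax / E4M3_MAX_NORM)

print(block_scale(448.0))  # 127 -> scale 2**0
print(block_scale(896.0))  # 128 -> scale 2**1
```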