OpenDAS / TransformerEngine / Commits

Commit c520cba3, authored Mar 20, 2025 by yuguo
Parent: 5b6ef054

    [DCU] Preliminary adaptation

Changes: 79 files in total; showing 20 changed files with 1215 additions and 119 deletions (+1215 −119).
build_tools/build_ext.py                             +18   -8
build_tools/pytorch.py                               +60   -34
build_tools/te_version.py                            +25   -1
build_tools/utils.py                                 +98   -0
hipify_custom_map.json                               +15   -0
setup.py                                             +10   -1
tests/cpp/CMakeLists.txt                             +54   -11
tests/cpp/operator/CMakeLists.txt                    +68   -28
tests/cpp/operator/test_cublaslt_gemm.cu             +582  -0
tests/cpp/operator/test_normalization.cu             +4    -0
tests/cpp/test_common.cu                             +31   -0
tests/cpp/test_common.h                              +14   -2
tests/pytorch/distributed/run_numerics.py            +1    -0
tests/pytorch/distributed/test_fusible_ops.py        +2    -2
tests/pytorch/fused_attn/test_fused_attn.py          +51   -29
tests/pytorch/fused_attn/test_fused_attn_with_cp.py  +3    -2
tests/pytorch/test_cuda_graphs.py                    +13   -1
tests/pytorch/test_fusible_ops.py                    +12   -0
tests/pytorch/test_gemm_autotune.py                  +153  -0
tests/pytorch/test_numerics.py                       +1    -0
build_tools/build_ext.py (view file @ c520cba3)
@@ -19,6 +19,8 @@ from typing import List, Optional, Type

 import setuptools

 from .utils import (
+    rocm_build,
+    rocm_path,
     cmake_bin,
     debug_build_enabled,
     found_ninja,

@@ -155,26 +157,34 @@ def get_build_ext(extension_cls: Type[setuptools.Extension], install_so_in_wheel
                 ext.extra_compile_args[target] = []

         # Define new _compile method that redirects to NVCC for .cu and .cuh files.
+        # Also redirect .hip files to HIPCC
         original_compile_fn = self.compiler._compile
-        self.compiler.src_extensions += [".cu", ".cuh"]
+        self.compiler.src_extensions += [".cu", ".cuh", ".hip"]

         def _compile_fn(obj, src, ext, cc_args, extra_postargs, pp_opts) -> None:
             # Copy before we make any modifications.
             cflags = copy.deepcopy(extra_postargs)
             original_compiler = self.compiler.compiler_so
             try:
-                _, nvcc_bin = cuda_path()
+                if rocm_build():
+                    _, nvcc_bin = rocm_path()
+                else:
+                    _, nvcc_bin = cuda_path()
                 original_compiler = self.compiler.compiler_so
-                if os.path.splitext(src)[1] in [".cu", ".cuh"]:
+                if os.path.splitext(src)[1] in [".cu", ".cuh", ".hip"]:
                     self.compiler.set_executable("compiler_so", str(nvcc_bin))
                     if isinstance(cflags, dict):
                         cflags = cflags["nvcc"]

                     # Add -fPIC if not already specified
                     if not any("-fPIC" in flag for flag in cflags):
-                        cflags.extend(["--compiler-options", "'-fPIC'"])
+                        if rocm_build():
+                            cflags.append("-fPIC")
+                        else:
+                            cflags.extend(["--compiler-options", "'-fPIC'"])

-                    # Forward unknown options
-                    if not any("--forward-unknown-opts" in flag for flag in cflags):
-                        cflags.append("--forward-unknown-opts")
+                    if not rocm_build():
+                        # Forward unknown options
+                        if not any("--forward-unknown-opts" in flag for flag in cflags):
+                            cflags.append("--forward-unknown-opts")
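For readers unfamiliar with this hook: distutils exposes a per-file _compile method on the compiler object, and the build monkey-patches it so device sources are handed to the device compiler while everything else keeps the host toolchain. A minimal sketch of the pattern, assuming a pick_device_compiler() helper in place of the rocm_build()/rocm_path()/cuda_path() plumbing above (names other than _compile, src_extensions, compiler_so, and set_executable are illustrative):

    import copy
    import os

    def patch_compile(compiler, pick_device_compiler):
        """Route .cu/.cuh/.hip sources to the device compiler for this build."""
        original_compile_fn = compiler._compile
        compiler.src_extensions += [".cu", ".cuh", ".hip"]

        def _compile_fn(obj, src, ext, cc_args, extra_postargs, pp_opts) -> None:
            # Copy before modifying; extra_postargs may be shared across files.
            cflags = copy.deepcopy(extra_postargs)
            original_compiler_so = compiler.compiler_so
            try:
                if os.path.splitext(src)[1] in (".cu", ".cuh", ".hip"):
                    # Swap the executable for this translation unit only.
                    compiler.set_executable("compiler_so", str(pick_device_compiler()))
                    if isinstance(cflags, dict):
                        cflags = cflags["nvcc"]
                original_compile_fn(obj, src, ext, cc_args, cflags, pp_opts)
            finally:
                # Restore the host compiler no matter what happened above.
                compiler.compiler_so = original_compiler_so

        compiler._compile = _compile_fn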
build_tools/pytorch.py (view file @ c520cba3)
@@ -9,6 +9,8 @@ from pathlib import Path

 import setuptools

 from .utils import (
+    rocm_build,
+    hipify,
     all_files_in_dir,
     cuda_archs,
     cuda_version,

@@ -37,11 +39,27 @@ def setup_pytorch_extension(
         csrc_header_files,
     ]

+    if rocm_build():
+        current_file_path = Path(__file__).parent.resolve()
+        base_dir = current_file_path.parent
+        sources = hipify(base_dir, csrc_source_files, sources, include_dirs)
+
     # Compiler flags
     cxx_flags = [
         "-O3",
         "-fvisibility=hidden",
     ]
-    nvcc_flags = [
-        "-O3",
-        "-U__CUDA_NO_HALF_OPERATORS__",
+    if rocm_build():
+        nvcc_flags = [
+            "-O3",
+            "-U__HIP_NO_HALF_OPERATORS__",
+            "-U__HIP_NO_HALF_CONVERSIONS__",
+            "-U__HIP_NO_BFLOAT16_OPERATORS__",
+            "-U__HIP_NO_BFLOAT16_CONVERSIONS__",
+            "-U__HIP_NO_BFLOAT162_OPERATORS__",
+            "-U__HIP_NO_BFLOAT162_CONVERSIONS__",
+        ]
+    else:
+        nvcc_flags = [
+            "-O3",
+            "-U__CUDA_NO_HALF_OPERATORS__",

@@ -54,13 +72,14 @@ def setup_pytorch_extension(
             "--expt-extended-lambda",
             "--use_fast_math",
         ]

-    # Version-dependent CUDA options
-    cuda_architectures = cuda_archs()
-    if "70" in cuda_architectures:
-        nvcc_flags.extend(["-gencode", "arch=compute_70,code=sm_70"])
+    if rocm_build():
+        ##TODO: Figure out which hipcc version starts to support this parallel compilation
+        nvcc_flags.extend(["-parallel-jobs=4"])
+    else:
+        cuda_architectures = cuda_archs()
+        if "70" in cuda_architectures:
+            nvcc_flags.extend(["-gencode", "arch=compute_70,code=sm_70"])

+    # Version-dependent CUDA options
     try:
         version = cuda_version()
     except FileNotFoundError:

@@ -80,6 +99,9 @@ def setup_pytorch_extension(
             continue  # Already handled
         nvcc_flags.extend(["-gencode", f"arch=compute_{arch},code=sm_{arch}"])

+    # Libraries
+    library_dirs = []
+    libraries = []
     if bool(int(os.getenv("NVTE_UB_WITH_MPI", "0"))):
         assert (
             os.getenv("MPI_HOME") is not None

@@ -88,6 +110,8 @@ def setup_pytorch_extension(
         include_dirs.append(mpi_path / "include")
         cxx_flags.append("-DNVTE_UB_WITH_MPI")
         nvcc_flags.append("-DNVTE_UB_WITH_MPI")
+        library_dirs.append(mpi_path / "lib")
+        libraries.append("mpi")

     # Construct PyTorch CUDA extension
     sources = [str(path) for path in sources]

@@ -102,4 +126,6 @@ def setup_pytorch_extension(
             "cxx": cxx_flags,
             "nvcc": nvcc_flags,
         },
+        libraries=[str(lib) for lib in libraries],
+        library_dirs=[str(lib_dir) for lib_dir in library_dirs],
     )
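A note on the flag families: the -U__HIP_NO_HALF_OPERATORS__ / -U__HIP_NO_BFLOAT16_* group plays the same role on ROCm that -U__CUDA_NO_HALF_OPERATORS__ plays on CUDA. Both un-define the header guards that would otherwise hide the half and bfloat16 arithmetic operator overloads in the vendor headers, which the C++ sources rely on.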
build_tools/te_version.py (view file @ c520cba3)
@@ -7,6 +7,28 @@ import os

 from pathlib import Path
 import subprocess

+DAS_VERSION = "1.6"
+
+def abi_value():
+    try:
+        return (
+            subprocess.check_output(
+                "echo '#include <string>' | gcc -x c++ -E -dM - | fgrep _GLIBCXX_USE_CXX11_ABI",
+                shell=True,
+            )
+            .decode('ascii')
+            .strip()[-1]
+        )
+    except Exception:
+        return abiUNKNOWN
+
+def dtk_version_value():
+    try:
+        dtk_path = os.getenv('ROCM_PATH')
+        dtk_version_path = os.path.join(dtk_path, '.info', "version-dev")
+        with open(dtk_version_path, 'r', encoding='utf-8') as file:
+            lines = file.readlines()
+        dtk_version = "dtk" + lines[0][:].replace(".", "")
+        return dtk_version
+    except Exception:
+        return UNKNOWN
+
 def te_version() -> str:
     """Transformer Engine version string

@@ -33,5 +55,7 @@ def te_version() -> str:
             pass
         else:
             commit = output.stdout.strip()
-            version += f"+{commit}"
+            version += "+das" + DAS_VERSION + f".git{commit}" + ".abi" + str(abi_value()) + "." + str(dtk_version_value())
+    else:
+        version += "+das" + DAS_VERSION + f".opt1" + "." + str(dtk_version_value())
     return version
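Two details worth noting here. First, the exception fallbacks return the bare names abiUNKNOWN and UNKNOWN, which are not defined anywhere in this diff, so string literals were presumably intended. Second, the composed local version concatenates the DAS release, git commit, C++ ABI flag, and DTK version. A minimal sketch of the resulting string, assuming string-literal fallbacks and illustrative commit/ABI/DTK values:

    DAS_VERSION = "1.6"

    def das_local_version(base, commit, abi, dtk):
        # Mirrors the suffix logic of te_version() above: dev builds embed
        # the git commit; release builds get an ".opt1" marker instead.
        if commit is not None:
            return base + "+das" + DAS_VERSION + f".git{commit}" + ".abi" + abi + "." + dtk
        return base + "+das" + DAS_VERSION + ".opt1" + "." + dtk

    # Illustrative values only:
    print(das_local_version("1.11.0", "c520cba3", "1", "dtk2404"))
    # -> 1.11.0+das1.6.gitc520cba3.abi1.dtk2404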
build_tools/utils.py (view file @ c520cba3)
@@ -161,6 +161,44 @@ def found_pybind11() -> bool:
     return False

+@functools.lru_cache(maxsize=None)
+def rocm_build() -> bool:
+    """ ROCm build should be performed if:
+        - It is configured with NVTE_USE_ROCM=1 env
+        OR:
+        - HIP compiler is found and CUDA one is not
+    """
+    if bool(int(os.getenv("NVTE_USE_ROCM", "0"))):
+        return True
+    try:
+        cuda_path()
+        return False
+    except FileNotFoundError:
+        pass
+    _, hipcc_bin = rocm_path()
+    return hipcc_bin.is_file()
+
+@functools.lru_cache(maxsize=None)
+def rocm_path() -> Tuple[str, str]:
+    """ROCm root path and HIPCC binary path as a tuple"""
+    """If ROCm installation is not specified, use default /opt/dtk path"""
+    if os.getenv("ROCM_PATH"):
+        rocm_home = Path(os.getenv("ROCM_PATH"))
+        hipcc_bin = rocm_home / "bin" / "hipcc"
+    if hipcc_bin is None:
+        hipcc_bin = shutil.which("hipcc")
+        if hipcc_bin is not None:
+            hipcc_bin = Path(hipcc_bin)
+            rocm_home = hipcc_bin.parent.parent
+    if hipcc_bin is None:
+        rocm_home = Path("/opt/dtk/")
+        hipcc_bin = rocm_home / "bin" / "hipcc"
+    return rocm_home, hipcc_bin
+
 @functools.lru_cache(maxsize=None)
 def cuda_path() -> Tuple[str, str]:
     """CUDA root path and NVCC binary path as a tuple.

@@ -228,6 +266,9 @@ def get_frameworks() -> List[str]:
             _frameworks.extend(arg.replace("--framework=", "").split(","))
             sys.argv.remove(arg)

+    if rocm_build():
+        _requested_frameworks = [framework.lower() for framework in _frameworks]
+
     # Detect installed frameworks if not explicitly specified
     if not _frameworks:
         try:

@@ -255,6 +296,28 @@ def get_frameworks() -> List[str]:
         if framework not in supported_frameworks:
             raise ValueError(f"Transformer Engine does not support framework={framework}")

+    if rocm_build():
+        _unsupported_frameworks = []
+        if "pytorch" in _frameworks:
+            try:
+                from torch.utils.cpp_extension import IS_HIP_EXTENSION
+            except ImportError:
+                IS_HIP_EXTENSION = False
+            if not IS_HIP_EXTENSION:
+                if "pytorch" in _requested_frameworks:
+                    _unsupported_frameworks.append("pytorch")
+                _frameworks.remove("pytorch")
+        if "jax" in _frameworks:
+            if not any(re.match(r'jax-rocm\d+-plugin', d.metadata['Name'])
+                       for d in importlib.metadata.distributions()):
+                try:
+                    import jaxlib.rocm  # pre JAX 0.4.30 way
+                except ImportError:
+                    if "jax" in _requested_frameworks:
+                        _unsupported_frameworks.append("jax")
+                    _frameworks.remove("jax")
+        if _unsupported_frameworks:
+            raise ValueError(f"ROCm is not supported by requested frameworks: {_unsupported_frameworks}")
+
     return _frameworks

@@ -293,6 +356,41 @@ def copy_common_headers(
     shutil.copy(path, new_path)

+def hipify(base_dir, src_dir, sources, include_dirs):
+    hipify_path = base_dir / "3rdparty" / "hipify_torch"
+    cwd = os.getcwd()
+    os.chdir(hipify_path)
+    from hipify_torch.hipify_python import hipify as do_hipify
+    os.chdir(cwd)
+
+    hipify_result = do_hipify(
+        project_directory=src_dir,
+        output_directory=src_dir,
+        includes=["*"],
+        ignores=["*/amd_detail/*", "*/aotriton/*", "*/ck_fused_attn/*"],
+        header_include_dirs=include_dirs,
+        custom_map_list=base_dir / "hipify_custom_map.json",
+        extra_files=[],
+        is_pytorch_extension=True,
+        hipify_extra_files_only=False,
+        show_detailed=False,
+    )
+
+    # Because hipify output_directory == project_directory
+    # Original sources list may contain previous hipifying results that ends up with duplicated entries
+    # Keep unique entries only
+    hipified_sources = set()
+    for fname in sources:
+        fname = os.path.abspath(str(fname))
+        if fname in hipify_result:
+            file_result = hipify_result[fname]
+            if file_result.hipified_path is not None:
+                fname = hipify_result[fname].hipified_path
+        # setup() arguments must *always* be /-separated paths relative to the setup.py directory,
+        # *never* absolute paths
+        hipified_sources.add(os.path.relpath(fname, cwd))
+
+    return list(hipified_sources)
+
 def install_and_import(package):
     """Install a package via pip (if not already installed) and import into globals."""
     main_package = package.split("[")[0]
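The precedence in rocm_build() is: an explicit NVTE_USE_ROCM=1 wins, a discoverable CUDA toolkit then vetoes ROCm, and only after that is hipcc probed. (Note that rocm_path() as committed only assigns hipcc_bin when ROCM_PATH is set, so the following `if hipcc_bin is None` check can hit an unbound name; initializing hipcc_bin = None up front is presumably intended.) A compact sketch of the decision order, with illustrative names:

    def choose_backend(nvte_use_rocm: bool, cuda_found: bool, hipcc_found: bool) -> str:
        # Mirrors rocm_build(): env override first, CUDA veto second, hipcc probe last.
        if nvte_use_rocm:
            return "rocm"
        if cuda_found:
            return "cuda"
        return "rocm" if hipcc_found else "none"

    assert choose_backend(True, True, False) == "rocm"   # env override beats a CUDA install
    assert choose_backend(False, True, True) == "cuda"   # CUDA wins when both toolchains exist
    assert choose_backend(False, False, True) == "rocm"  # fall back to hipcc autodetection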
hipify_custom_map.json (new file, mode 100644, view file @ c520cba3)
{
    "custom_map": {
        "<cuda_bf16.h>": "<hip/hip_bf16.h>",
        "<cuda_fp8.h>": "\"amd_detail/hip_float8.h\"",
        "CUfunc_cache": "hipFuncCache_t",
        "<nvtx3/nvToolsExt.h>": "<roctracer/roctx.h>",
        "cudaLaunchKernelExC": "hipLaunchKernelExC",
        "cudaLaunchConfig_t": "hipLaunchConfig_t",
        "cudaLaunchAttributeClusterDimension": "hipLaunchAttributeClusterDimension",
        "cudaLaunchAttributeCooperative": "hipLaunchAttributeCooperative",
        "cudaLaunchAttribute": "hipLaunchAttribute"
    }
}
\ No newline at end of file
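hipify_torch applies these entries as extra source-text substitutions on top of its built-in CUDA-to-HIP rules, which is how headers and launch-API identifiers with no automatic mapping get translated. A toy sketch of the substitution step, assuming plain longest-key-first replacement (illustrative, not hipify_torch's actual implementation):

    import json

    def apply_custom_map(source: str, map_path: str) -> str:
        with open(map_path, encoding="utf-8") as f:
            custom_map = json.load(f)["custom_map"]
        # Longest keys first so cudaLaunchAttributeCooperative is rewritten
        # before its cudaLaunchAttribute prefix can match.
        for cuda_name in sorted(custom_map, key=len, reverse=True):
            source = source.replace(cuda_name, custom_map[cuda_name])
        return source

    print(apply_custom_map('#include <cuda_bf16.h>\ncudaLaunchAttribute attr;',
                           "hipify_custom_map.json"))
    # -> #include <hip/hip_bf16.h>
    #    hipLaunchAttribute attr;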
setup.py (view file @ c520cba3)
@@ -16,6 +16,7 @@ from wheel.bdist_wheel import bdist_wheel

 from build_tools.build_ext import CMakeExtension, get_build_ext
 from build_tools.te_version import te_version
 from build_tools.utils import (
+    rocm_build,
     cuda_archs,
     found_cmake,
     found_ninja,

@@ -57,6 +58,9 @@ class TimedBdist(bdist_wheel):

 def setup_common_extension() -> CMakeExtension:
     """Setup CMake extension for common library"""
-    cmake_flags = ["-DCMAKE_CUDA_ARCHITECTURES={}".format(archs)]
+    if rocm_build():
+        cmake_flags = []
+    else:
+        cmake_flags = ["-DCMAKE_CUDA_ARCHITECTURES={}".format(archs)]
     if bool(int(os.getenv("NVTE_UB_WITH_MPI", "0"))):
         assert (

@@ -69,6 +73,11 @@ def setup_common_extension() -> CMakeExtension:
     # Project directory root
     root_path = Path(__file__).resolve().parent

+    if rocm_build():
+        if os.getenv("NVTE_USE_HIPBLASLT") is not None:
+            cmake_flags.append("-DUSE_HIPBLASLT=ON")
+        if os.getenv("NVTE_USE_ROCBLAS") is not None:
+            cmake_flags.append("-DUSE_ROCBLAS=ON")
+
     return CMakeExtension(
         name="transformer_engine",
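Taken together, these toggles make the backend selectable at install time: NVTE_USE_ROCM=1 forces a ROCm build, NVTE_USE_HIPBLASLT or NVTE_USE_ROCBLAS selects the GEMM backend for the common library, and NVTE_ROCM_ARCH (consumed by the test CMake below) narrows the offload architectures. An invocation along the lines of `NVTE_USE_ROCM=1 NVTE_USE_HIPBLASLT=1 pip install -v .` is the expected workflow, though that exact command line is illustrative rather than taken from this commit.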
tests/cpp/CMakeLists.txt (view file @ c520cba3)
@@ -4,20 +4,59 @@

 cmake_minimum_required(VERSION 3.18)

-if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
-  if(CUDAToolkit_VERSION VERSION_GREATER_EQUAL 12.8)
-    set(CMAKE_CUDA_ARCHITECTURES 75 80 89 90 100 120)
-  else()
-    set(CMAKE_CUDA_ARCHITECTURES 75 80 89 90)
-  endif()
-endif()
+option(USE_CUDA "Use CUDA" ON)
+option(USE_ROCM "Use ROCm" OFF)
+
+if(((EXISTS "/opt/dtk/") OR (EXISTS $ENV{ROCM_PATH})) AND NOT (EXISTS "/bin/nvcc"))
+  message("hcu detected.")
+  set(USE_ROCM ON)
+  set(USE_CUDA OFF)
+  # Add HIP to the CMAKE Module Path
+  # set(CMAKE_MODULE_PATH ${HIP_PATH}/cmake ${CMAKE_MODULE_PATH})
+  # Disable Asserts In Code (Can't use asserts on HIP stack.)
+  add_definitions(-DNDEBUG)
+  add_definitions(-DUSE_ROCM)
+  if(NOT DEFINED ENV{NVTE_ROCM_ARCH})
+    SET(CMAKE_HIP_ARCHITECTURES gfx906;gfx926;gfx928;gfx936)
+  else()
+    SET(CMAKE_HIP_ARCHITECTURES $ENV{NVTE_ROCM_ARCH})
+  endif()
+else()
+  if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
+    if(CUDAToolkit_VERSION VERSION_GREATER_EQUAL 12.8)
+      set(CMAKE_CUDA_ARCHITECTURES 75 80 89 90 100 120)
+    else()
+      set(CMAKE_CUDA_ARCHITECTURES 75 80 89 90)
+    endif()
+  endif()
+endif()
+
+set(message_line "-------------------------------------------------------------")
+message("${message_line}")
+message(STATUS "USE_CUDA ${USE_CUDA}")
+message(STATUS "USE_ROCM ${USE_ROCM}")

-set(CMAKE_CXX_STANDARD 17)
-set(CMAKE_CUDA_STANDARD 17)
-set(CMAKE_CUDA_STANDARD_REQUIRED ON)
-
-project(transformer_engine_tests LANGUAGES CUDA CXX)
+if(USE_CUDA)
+  set(CMAKE_CXX_STANDARD 17)
+  set(CMAKE_CUDA_STANDARD 17)
+  set(CMAKE_CUDA_STANDARD_REQUIRED ON)
+  project(transformer_engine_tests LANGUAGES CUDA CXX)
+else()
+  set(CMAKE_CXX_STANDARD 17)
+  project(transformer_engine_tests LANGUAGES HIP CXX)
+  # Ask hcc to generate device code during compilation so we can use
+  # host linker to link.
+  set(HIP_HCC_FLAGS "${HIP_HCC_FLAGS} -fno-gpu-rdc -Wno-defaulted-function-deleted")
+  foreach(rocm_arch ${CMAKE_HIP_ARCHITECTURES})
+    # if CMAKE_CXX_FLAGS has --offload-arch set already, better to rm first
+    set(HIP_HCC_FLAGS "${HIP_HCC_FLAGS} --offload-arch=${rocm_arch}")
+  endforeach()
+  set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${HIP_HCC_FLAGS}")
+endif()

 add_subdirectory(../../3rdparty/googletest ${PROJECT_BINARY_DIR}/googletest)

@@ -37,8 +76,12 @@ include_directories(../../transformer_engine/common/include)
 include_directories(../../transformer_engine/common)
 include_directories(${CMAKE_SOURCE_DIR})

-find_package(CUDAToolkit REQUIRED)
-include(${CMAKE_SOURCE_DIR}/../../3rdparty/cudnn-frontend/cmake/cuDNN.cmake)
+if(USE_CUDA)
+  find_package(CUDAToolkit REQUIRED)
+  include(${CMAKE_SOURCE_DIR}/../../3rdparty/cudnn-frontend/cmake/cuDNN.cmake)
+else()
+  find_package(hip REQUIRED)
+endif()

 add_subdirectory(operator)
 add_subdirectory(util)
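The detection heuristic is intentionally blunt: an existing /opt/dtk or $ROCM_PATH install with no /bin/nvcc flips the test build to HIP ("hcu detected"); /opt/dtk is the default prefix of the DTK toolkit for the DCU hardware named in the commit title. Setting NVTE_ROCM_ARCH overrides the default gfx906;gfx926;gfx928;gfx936 offload list.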
tests/cpp/operator/CMakeLists.txt (view file @ c520cba3)
@@ -2,7 +2,7 @@
 #
 # See LICENSE for license information.

-add_executable(test_operator
+list(APPEND test_cuda_sources
     test_cast.cu
     test_cast_current_scaling.cu
     test_cast_dbias.cu

@@ -26,12 +26,52 @@ add_executable(test_operator
     test_causal_softmax.cu
     test_swizzle.cu
     ../test_common.cu)
+if(USE_ROCM)
+  list(APPEND test_cuda_sources test_cublaslt_gemm.cu)
+endif()
+
+if(USE_CUDA)
+  add_executable(test_operator ${test_cuda_sources})
+else()
+  message("${message_line}")
+  message(STATUS "CMAKE_CURRENT_SOURCE_DIR: ${CMAKE_CURRENT_SOURCE_DIR}")
+  message(STATUS "PROJECT_SOURCE_DIR: ${PROJECT_SOURCE_DIR}")
+
+  set(TE ${CMAKE_CURRENT_SOURCE_DIR}/../../..)
+  set(THIRDPARTY ${TE}/3rdparty)
+  list(APPEND CMAKE_MODULE_PATH "${THIRDPARTY}/hipify_torch/cmake")
+  include(Hipify)
+  message(STATUS "CMAKE_MODULE_PATH: ${CMAKE_MODULE_PATH}")
+
+  file(REAL_PATH ../../../transformer_engine/common/include header_include_dir1)
+  file(REAL_PATH ../../../transformer_engine/common header_include_dir2)
+  set(header_include_dir ${header_include_dir1} ${header_include_dir2})
+  message(STATUS "CUDA_SOURCE_DIR: ${PROJECT_SOURCE_DIR}")
+  message(STATUS "HEADER_INCLUDE_DIR: ${header_include_dir}")
+
+  set(cuda_source_dir ${PROJECT_SOURCE_DIR})
+  hipify(CUDA_SOURCE_DIR ${cuda_source_dir}
+         HEADER_INCLUDE_DIR ${header_include_dir}
+         CUSTOM_MAP_FILE "${TE}/hipify_custom_map.json")
+
+  get_hipified_list("${test_cuda_sources}" test_hip_sources)
+  message("${message_line}")
+  message(STATUS "nvte tests hipified sources: ${test_hip_sources}")
+  add_executable(test_operator ${test_hip_sources})
+endif()

 find_package(OpenMP REQUIRED)
-list(APPEND test_operator_LINKER_LIBS CUDA::cudart GTest::gtest_main ${TE_LIB} CUDA::nvrtc CUDNN::cudnn)
-target_link_libraries(test_operator PUBLIC ${test_operator_LINKER_LIBS} OpenMP::OpenMP_CXX)
-target_compile_options(test_operator PRIVATE -O2 -fopenmp)
+if(USE_CUDA)
+  list(APPEND test_operator_LINKER_LIBS CUDA::cudart GTest::gtest_main ${TE_LIB} CUDA::nvrtc CUDNN::cudnn)
+  target_link_libraries(test_operator PUBLIC ${test_operator_LINKER_LIBS} OpenMP::OpenMP_CXX)
+  target_compile_options(test_operator PRIVATE -O2 -fopenmp)
+else()
+  target_link_libraries(test_operator PUBLIC hip::host hip::device GTest::gtest_main ${TE_LIB} OpenMP::OpenMP_CXX)
+  target_compile_options(test_operator PRIVATE -O2 -fopenmp)
+endif()

 include(GoogleTest)
 gtest_discover_tests(test_operator DISCOVERY_TIMEOUT 600)
tests/cpp/operator/test_cublaslt_gemm.cu (new file, mode 100644, view file @ c520cba3)
/*************************************************************************
 * Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 *
 * License for AMD contributions = MIT. See LICENSE for more information
 ************************************************************************/

#include <transformer_engine/gemm.h>
#include <transformer_engine/transformer_engine.h>
#include <gtest/gtest.h>
#include <cuda_runtime.h>
#include <cuda_bf16.h>
#include <memory>
#include <iostream>
#include <iomanip>
#include <random>
#include <cstring>
#include <cmath>
#include "../test_common.h"

using namespace transformer_engine;
using namespace test;

namespace {

// m, k, n
std::vector<std::tuple<size_t, size_t, size_t>> test_case_sizes = {
    {2304, 768, 4096},
    {768, 768, 4096},
    {768, 3072, 4096},
    {229, 541, 541},    // primes
    {71, 71, 3571},     // primes
    {29, 29, 17389},    // primes
};

// A, B, Bias, Gelu, D
// Bias type choose as bf16 in use_fp8, D_type otherwise
// Gelu type the same as Bias_Type
// {DType::kFloat32, DType::kFloat32, DType::kFloat32, DType::kFloat32, DType::kFloat32},
// {DType::kFloat16, DType::kFloat16, DType::kFloat16, DType::kFloat16, DType::kFloat16},
// {DType::kBFloat16, DType::kBFloat16, DType::kBFloat16, DType::kBFloat16, DType::kBFloat16},
// {DType::kFloat8E4M3, DType::kFloat8E4M3, DType::kBFloat16, DType::kBFloat16, DType::kFloat32},
// {DType::kFloat8E4M3, DType::kFloat8E4M3, DType::kBFloat16, DType::kBFloat16, DType::kFloat16},
// {DType::kFloat8E4M3, DType::kFloat8E4M3, DType::kBFloat16, DType::kBFloat16, DType::kBFloat16},
// {DType::kFloat8E4M3, DType::kFloat8E4M3, DType::kBFloat16, DType::kBFloat16, DType::kFloat8E4M3},
// {DType::kFloat8E4M3, DType::kFloat8E4M3, DType::kBFloat16, DType::kBFloat16, DType::kFloat8E5M2},
// {DType::kFloat8E4M3, DType::kFloat8E5M2, DType::kBFloat16, DType::kBFloat16, DType::kFloat32},
// {DType::kFloat8E4M3, DType::kFloat8E5M2, DType::kBFloat16, DType::kBFloat16, DType::kFloat16},
// {DType::kFloat8E4M3, DType::kFloat8E5M2, DType::kBFloat16, DType::kBFloat16, DType::kBFloat16},
// {DType::kFloat8E4M3, DType::kFloat8E5M2, DType::kBFloat16, DType::kBFloat16, DType::kFloat8E4M3},
// {DType::kFloat8E4M3, DType::kFloat8E5M2, DType::kBFloat16, DType::kBFloat16, DType::kFloat8E5M2},
// {DType::kFloat8E5M2, DType::kFloat8E4M3, DType::kBFloat16, DType::kBFloat16, DType::kFloat32},
// {DType::kFloat8E5M2, DType::kFloat8E4M3, DType::kBFloat16, DType::kBFloat16, DType::kFloat16},
// {DType::kFloat8E5M2, DType::kFloat8E4M3, DType::kBFloat16, DType::kBFloat16, DType::kBFloat16},
// {DType::kFloat8E5M2, DType::kFloat8E4M3, DType::kBFloat16, DType::kBFloat16, DType::kFloat8E4M3},
// {DType::kFloat8E5M2, DType::kFloat8E4M3, DType::kBFloat16, DType::kBFloat16, DType::kFloat8E5M2},

}  // namespace

// <A_type, B_type, Bias_Type, Gelu_Type, D_type>, <m, k, n>
class GEMMTestSuite
    : public ::testing::TestWithParam<
          std::tuple<std::tuple<size_t, size_t, size_t>, bool, bool>> {};

float ref_gelu(float x) {
  float cdf = 0.5f * (1.0f + tanhf((0.7978845608028654f * (x + 0.044715f * x * x * x))));
  return x * cdf;
}

template <typename A_Type, typename B_Type, typename Bias_Type,
          typename Gelu_Type, typename D_Type>
void compute_ref(const A_Type* a_data, const B_Type* b_data,
                 const float a_scale_inv, const float b_scale_inv,
                 const Bias_Type* bias_data,  // bias is of dim m
                 const float d_scale,
                 size_t m, size_t k, size_t n,
                 D_Type* ref_d_data, float* ref_d_amax, Gelu_Type* ref_gelu_data) {
  *ref_d_amax = 0;
  for (size_t ii = 0; ii < m; ii++) {
    for (size_t jj = 0; jj < n; jj++) {
      float val = 0;
      for (size_t kk = 0; kk < k; kk++) {
        val += a_scale_inv * b_scale_inv *
               ((float)a_data[ii + kk * m]) * ((float)b_data[kk + jj * k]);
      }
      if (bias_data) {
        val += (float)bias_data[ii];
      }
      if (ref_gelu_data) {
        ref_gelu_data[ii + jj * m] = (Gelu_Type)(val);
        val = ref_gelu(val);
      }
      ref_d_data[ii + jj * m] = (D_Type)(val * d_scale);

      // update ref_d_amax if in fp8
      DType dtype = TypeInfo<D_Type>::dtype;
      if (isFp8Type(dtype)) {
        *ref_d_amax = std::max<float>(*ref_d_amax, std::fabs(val));
      }
    }
  }
}

template <typename A_Type, typename B_Type, typename Bias_Type,
          typename Gelu_Type, typename D_Type>
void performTest(bool use_bias, bool use_gelu,
                 const size_t m, const size_t k, const size_t n) {
  DType atype = TypeInfo<A_Type>::dtype;
  DType btype = TypeInfo<B_Type>::dtype;
  DType bias_type = TypeInfo<Bias_Type>::dtype;
  DType gelu_type = TypeInfo<Gelu_Type>::dtype;
  DType dtype = TypeInfo<D_Type>::dtype;

  // pytorch tensor storage is row-major while cublas/rocblas is column-major
  Tensor A({k, m}, atype);
  Tensor B({n, k}, btype);
  Tensor D({n, m}, dtype);
  Tensor bias;
  if (use_bias) {
    bias = Tensor({m}, bias_type);
  }
  Tensor pre_gelu_out;
  if (use_gelu) {
    pre_gelu_out = Tensor({n, m}, gelu_type);
  }

  // initialize the data and scale inv of A, B
  fillUniform(&A);
  fillUniform(&B);
  if (use_bias) {
    fillUniform(&bias);
  }
  // initialize the scale of D
  if (isFp8Type(dtype)) {
    setRandomScale(&D);
  }

  bool transa = false;
  bool transb = false;
  bool grad = false;
  bool accumulate = false;

  cudaDeviceProp prop;
  cudaGetDeviceProperties(&prop, 0);
#ifdef __HIP_PLATFORM_AMD__
  if ((isFp8Type(atype) || isFp8Type(btype)) && !(prop.major == 9 && prop.minor >= 4)) {
    GTEST_SKIP() << "FP8 is not supported on this HW";
  }
#endif

  Tensor Workspace({33554432}, DType::kByte);

  // perform the gemm in GPU
  nvte_cublas_gemm(A.data(), B.data(), D.data(), bias.data(), pre_gelu_out.data(),
                   transa, transb, grad, Workspace.data(), accumulate, false,
                   prop.multiProcessorCount,
                   0);  // default stream

  // copy the output results from GPU memory to CPU memory
  D.to_cpu();
  if (use_gelu) {
    pre_gelu_out.to_cpu();
  }

  // perform the gemm in CPU
  std::unique_ptr<D_Type[]> ref_D = std::make_unique<D_Type[]>(m * n);
  std::unique_ptr<Gelu_Type[]> ref_pre_gelu_out;
  if (use_gelu) {
    ref_pre_gelu_out = std::make_unique<Gelu_Type[]>(m * n);
  }
  float ref_amax_d;
  compute_ref<A_Type, B_Type, Bias_Type, Gelu_Type, D_Type>(
      A.cpu_dptr<A_Type>(), B.cpu_dptr<B_Type>(), A.scale_inv(), B.scale_inv(),
      use_bias ? bias.cpu_dptr<Bias_Type>() : nullptr, D.scale(), m, k, n,
      ref_D.get(), &ref_amax_d, use_gelu ? ref_pre_gelu_out.get() : nullptr);

  // check if error message happens in running
  cudaDeviceSynchronize();
  auto err = cudaGetLastError();
  ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);

  // compare results
  auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
  if (isFp8Type(dtype)) {
    compareResults("D_amax", D.amax(), ref_amax_d, atol_amax, rtol_amax);
  }
  auto [atol, rtol] = getTolerances(dtype);
  // relax for certain prime number gemm
  if (dtype == DType::kFloat32) {
    atol = 1e-5;
  }
  compareResults("D", D, ref_D.get(), atol, rtol);
  if (use_gelu) {
    auto [atol, rtol] = getTolerances(gelu_type);
    // relax for certain prime number gemm
    if (dtype == DType::kFloat32) {
      atol = 5e-6;
    }
    compareResults("gelu", pre_gelu_out, ref_pre_gelu_out.get(), atol, rtol);
  }
}

using fp32 = float;
using fp8 = fp8e4m3;
using bf8 = fp8e5m2;

TEST_P(GEMMTestSuite, Testfp32xfp32xfp32xfp32xfp32) {
  using namespace transformer_engine;
  using namespace test;

  const size_t m = std::get<0>(std::get<0>(GetParam()));
  const size_t k = std::get<1>(std::get<0>(GetParam()));
  const size_t n = std::get<2>(std::get<0>(GetParam()));
  const bool use_bias = std::get<1>(GetParam());
  const bool use_gelu = std::get<2>(GetParam());

  using A_Type = fp32;
  using B_Type = fp32;
  using Bias_Type = fp32;
  using Gelu_Type = fp32;
  using D_Type = fp32;

  performTest<A_Type, B_Type, Bias_Type, Gelu_Type, D_Type>(use_bias, use_gelu, m, k, n);
}

// The file defines 17 further TEST_P cases with the identical body, differing
// only in the five type aliases: Testfp16xfp16xfp16xfp16xfp16 and
// Testbf16xbf16xbf16xbf16xbf16 (uniform fp16/bf16), plus the FP8 families
// Testfp8xfp8xbf16xbf16x{fp32,fp16,bf16,fp8,bf8},
// Testfp8xbf8xbf16xbf16x{fp32,fp16,bf16,fp8,bf8}, and
// Testbf8xfp8xbf16xbf16x{fp32,fp16,bf16,fp8,bf8}, covering every combination
// in the comment block near the top of the file.

INSTANTIATE_TEST_SUITE_P(
    OperatorTest,
    GEMMTestSuite,
    ::testing::Combine(
        ::testing::ValuesIn(test_case_sizes),
        ::testing::Values(false, true),    // use bias
        ::testing::Values(false, true)),   // use_gelu
    [](const testing::TestParamInfo<GEMMTestSuite::ParamType>& info) {
      std::string name =
          std::to_string(std::get<0>(std::get<0>(info.param))) + "X" +
          std::to_string(std::get<1>(std::get<0>(info.param))) + "X" +
          std::to_string(std::get<2>(std::get<0>(info.param))) + "X" +
          std::to_string(std::get<1>(info.param)) + "X" +
          std::to_string(std::get<2>(info.param));
      return name;
    });
\ No newline at end of file
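The reference path deliberately runs the whole accumulation in fp32, applies the FP8 dequantization scales before accumulating and the output scale after the epilogue, and takes amax before d_scale. A NumPy sketch of the same computation, mirroring compute_ref's column-major indexing (illustrative; array shapes follow performTest's allocations):

    import numpy as np

    def compute_ref(a, b, a_scale_inv, b_scale_inv, bias, d_scale, m, k, n, gelu=False):
        # performTest allocates A as {k, m} and B as {n, k} row-major so that
        # the GEMM library sees column-major m*k and k*n operands.
        A = a.reshape(k, m).astype(np.float32).T        # logical (m, k)
        B = b.reshape(n, k).astype(np.float32).T        # logical (k, n)
        val = a_scale_inv * b_scale_inv * (A @ B)       # dequantized fp32 accumulation
        if bias is not None:
            val = val + bias.astype(np.float32)[:, None]  # bias is of dim m
        pre_gelu = val.copy() if gelu else None
        if gelu:  # tanh-approximated GELU, as in ref_gelu() above
            val = 0.5 * val * (1.0 + np.tanh(0.7978845608028654 * (val + 0.044715 * val**3)))
        amax = float(np.abs(val).max())                 # amax is taken before d_scale
        return val * d_scale, amax, pre_gelu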
tests/cpp/operator/test_normalization.cu (view file @ c520cba3)
@@ -55,7 +55,11 @@ void compute_ref_stats(NormType norm_type,
         current = static_cast<compute_t>(data[i * H + j]);
         sum_sq += (current - m) * (current - m);
       }
+#ifdef __HIP_PLATFORM_AMD__
+      rsigma[i] = 1.0 / sqrtf((sum_sq / H) + epsilon);
+#else
       rsigma[i] = rsqrtf((sum_sq / H) + epsilon);
+#endif
     }
   }
 }
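The switch from rsqrtf to 1.0/sqrtf in this CPU reference presumably works around rsqrtf not being host-callable in the hipcc toolchain; numerically, 1.0f/sqrtf(x) is the portable equivalent of CUDA's rsqrtf(x) for a reference computation.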
tests/cpp/test_common.cu (view file @ c520cba3)
@@ -481,8 +481,13 @@ void compareResults_sequential(const std::string &name, const Tensor &test,
   const T *test_data = rowwise ? test.rowwise_cpu_dptr<T>() : test.columnwise_cpu_dptr<T>();
   const T *ref_data = reinterpret_cast<const T*>(ref);
   for (size_t i = 0; i < N; ++i) {
+#ifndef __HIP_PLATFORM_AMD__
     double t = static_cast<double>(test_data[i]);
     double r = static_cast<double>(ref_data[i]);
+#else
+    double t = static_cast<double>(static_cast<float>(test_data[i]));
+    double r = static_cast<double>(static_cast<float>(ref_data[i]));
+#endif
     bool mismatch = fabs(t - r) > atol && (r == 0 || fabs((t - r) / r) > rtol);
     /* For Float32 the floating point comparison is enough to error out */
     bool assertion = mismatch && test.dtype() == DType::kFloat32;

@@ -492,9 +497,19 @@ void compareResults_sequential(const std::string &name, const Tensor &test,
       const double mean = (t + r) / 2;
       const double mean_p = mean >= 0 ? mean * (1 + 1e-6) : mean * (1 - 1e-6);
       const double mean_m = mean >= 0 ? mean * (1 - 1e-6) : mean * (1 + 1e-6);
+#ifndef __HIP_PLATFORM_AMD__
       const double cast_mean_p = static_cast<double>(static_cast<T>(mean_p));
       const double cast_mean_m = static_cast<double>(static_cast<T>(mean_m));
+#else
+      const double cast_mean_p = static_cast<double>(static_cast<float>(static_cast<T>(static_cast<float>(mean_p))));
+      const double cast_mean_m = static_cast<double>(static_cast<float>(static_cast<T>(static_cast<float>(mean_m))));
+#endif
+#ifdef __HIP_PLATFORM_AMD__
+      assertion = !(cast_mean_m == std::min<double>(t, r) && cast_mean_p == std::max<double>(t, r));
+#else
       assertion = !(cast_mean_m == std::min(t, r) && cast_mean_p == std::max(t, r));
+#endif
     }
     std::string direction = rowwise ? "rowwise" : "columnwise";
     ASSERT_FALSE(assertion) << "Error in tensor " << name << " in "

@@ -518,8 +533,14 @@ static size_t getFirstMismatchIdx(const DType data_type, const T* test_data, con
       continue;
     }
+#ifndef __HIP_PLATFORM_AMD__
     double t = static_cast<double>(test_data[i]);
     double r = static_cast<double>(ref_data[i]);
+#else
+    double t = static_cast<double>(static_cast<float>(test_data[i]));
+    double r = static_cast<double>(static_cast<float>(ref_data[i]));
+#endif
     bool mismatch = fabs(t - r) > atol && (r == 0 || fabs((t - r) / r) > rtol);
     /* For Float32 the floating point comparison is enough to error out */

@@ -530,9 +551,19 @@ static size_t getFirstMismatchIdx(const DType data_type, const T* test_data, con
       const double mean = (t + r) / 2;
       const double mean_p = mean >= 0 ? mean * (1 + 1e-6) : mean * (1 - 1e-6);
       const double mean_m = mean >= 0 ? mean * (1 - 1e-6) : mean * (1 + 1e-6);
+#ifndef __HIP_PLATFORM_AMD__
       const double cast_mean_p = static_cast<double>(static_cast<T>(mean_p));
       const double cast_mean_m = static_cast<double>(static_cast<T>(mean_m));
+#else
+      const double cast_mean_p = static_cast<double>(static_cast<float>(static_cast<T>(static_cast<float>(mean_p))));
+      const double cast_mean_m = static_cast<double>(static_cast<float>(static_cast<T>(static_cast<float>(mean_m))));
+#endif
+#ifdef __HIP_PLATFORM_AMD__
+      assertion = !(cast_mean_m == std::min<double>(t, r) && cast_mean_p == std::max<double>(t, r));
+#else
      assertion = !(cast_mean_m == std::min(t, r) && cast_mean_p == std::max(t, r));
+#endif
     }
     if (assertion && i < first_mismatch_idx) {
       first_mismatch_idx = i;
tests/cpp/test_common.h (view file @ c520cba3)
@@ -11,10 +11,16 @@

 #include <array>
 #include <random>

+#include <cuda_runtime_api.h>
+#ifndef USE_ROCM
 #include <cuda_bf16.h>
-#include <cuda_fp16.h>
 #include <cuda_fp8.h>
-#include <cuda_runtime_api.h>
+#else
+#include <hip/hip_bf16.h>
+#include "amd_detail/hip_float8.h"
+#endif
+#include <cuda_fp16.h>

 #include <transformer_engine/transformer_engine.h>
 #include "util/logging.h"

@@ -50,9 +56,15 @@ using int32 = int32_t;
 using int64 = int64_t;
 using fp32 = float;
 using fp16 = half;
+#ifndef USE_ROCM
 using bf16 = nv_bfloat16;
 using fp8e4m3 = __nv_fp8_e4m3;
 using fp8e5m2 = __nv_fp8_e5m2;
+#else
+using bf16 = __hip_bfloat16;
+using fp8e4m3 = hip_f8<hip_f8_type::fp8>;
+using fp8e5m2 = hip_f8<hip_f8_type::bf8>;
+#endif
 using fp8e8m0 = uint8_t;

 template <typename T>
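Aliasing bf16, fp8e4m3, and fp8e5m2 to __hip_bfloat16, hip_f8<hip_f8_type::fp8>, and hip_f8<hip_f8_type::bf8> is what lets the rest of the C++ test suite compile unchanged on ROCm: every test is written against these aliases rather than the vendor types.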
tests/pytorch/distributed/run_numerics.py (view file @ c520cba3)
@@ -24,6 +24,7 @@ from transformer_engine.common.recipe import (
 )
 from transformer_engine.pytorch.tensor.float8_tensor import Float8CurrentScalingQuantizer
 from run_layer_with_overlap import _compare_tensors
+from torch.utils.cpp_extension import IS_HIP_EXTENSION

 SEQ_LEN, BATCH_SIZE = 16, 16
 HIDDEN_SIZE = 64
tests/pytorch/distributed/test_fusible_ops.py (view file @ c520cba3)
@@ -27,7 +27,7 @@ import transformer_engine.pytorch.ops as te_ops
 from transformer_engine.pytorch.ops._common import is_float8_tensor
 from transformer_engine.pytorch.utils import is_bf16_compatible
 import transformer_engine_torch as tex
+from torch.utils.cpp_extension import IS_HIP_EXTENSION

 # Check what quantization schemes are supported
 fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()

@@ -687,7 +687,7 @@ def _test_fp8_scale_update(
     """Expected absmax and FP8 scale"""
     amax = ref.abs().amax()
     max_val = {
-        "forward": 448.0,
+        "forward": 448.0 if not IS_HIP_EXTENSION else 240.0,
         "backward": 57344.0,
     }[stage]
     scale = (max_val / amax) / (2**margin)
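The 448 vs 240 split reflects the two E4M3 encodings in the wild: CUDA's OCP E4M3 has a maximum finite value of 448, while the FNUZ E4M3 variant used on AMD's FP8-capable GPUs tops out at 240, so the expected scale differs by the same ratio. A quick check of the scale-update arithmetic (illustrative):

    def expected_fp8_scale(amax: float, margin: int, rocm: bool) -> float:
        # Forward tensors use the E4M3 max of the platform's FP8 flavor:
        # OCP E4M3 (CUDA) -> 448.0, FNUZ E4M3 (ROCm) -> 240.0.
        max_val = 240.0 if rocm else 448.0
        return (max_val / amax) / (2 ** margin)

    assert expected_fp8_scale(amax=2.0, margin=0, rocm=False) == 224.0
    assert expected_fp8_scale(amax=2.0, margin=0, rocm=True) == 120.0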
tests/pytorch/fused_attn/test_fused_attn.py (view file @ c520cba3)
...
@@ -12,6 +12,7 @@ from contextlib import contextmanager
...
@@ -12,6 +12,7 @@ from contextlib import contextmanager
import
pytest
import
pytest
import
torch
import
torch
from
torch.utils.cpp_extension
import
IS_HIP_EXTENSION
from
transformer_engine.common
import
recipe
from
transformer_engine.common
import
recipe
from
transformer_engine.pytorch
import
TransformerLayer
,
fp8_autocast
,
fp8_model_init
from
transformer_engine.pytorch
import
TransformerLayer
,
fp8_autocast
,
fp8_model_init
...
@@ -387,8 +388,24 @@ def test_dpa_checkpoint(dtype, model_configs, model):
...
@@ -387,8 +388,24 @@ def test_dpa_checkpoint(dtype, model_configs, model):
"""Test DotProductAttention module with checkpointing"""
"""Test DotProductAttention module with checkpointing"""
test_dot_product_attention
(
dtype
,
model_configs
,
model
,
True
,
True
,
None
,
False
,
False
)
test_dot_product_attention
(
dtype
,
model_configs
,
model
,
True
,
True
,
None
,
False
,
False
)
if
IS_HIP_EXTENSION
:
model_configs_mla
=
{
model_configs_mla
=
{
# test: b, h, hg, dqk, sq, skv, p, mask, bias # attn , backend
"mla_1_0"
:
ModelConfig
(
8
,
16
,
16
,
64
,
128
,
128
,
0.0
,
"no_mask"
,
"no_bias"
,
head_dim_v
=
128
),
# self , 0
"mla_1_1"
:
ModelConfig
(
4
,
16
,
16
,
64
,
128
,
256
,
0.0
,
"no_mask"
,
"no_bias"
,
head_dim_v
=
128
),
# cross, 0
"mla_2_0"
:
ModelConfig
(
2
,
24
,
24
,
128
,
2048
,
2048
,
0.0
,
"causal"
,
"no_bias"
,
head_dim_v
=
64
),
# self , 1
"mla_2_1"
:
ModelConfig
(
1
,
24
,
24
,
128
,
2048
,
4096
,
0.0
,
"causal"
,
"no_bias"
,
head_dim_v
=
64
),
# cross, 1
}
else
:
model_configs_mla
=
{
# test: b, h, hg, dqk, sq, skv, p, mask, bias # attn , backend
# test: b, h, hg, dqk, sq, skv, p, mask, bias # attn , backend
"mla_1_0"
:
ModelConfig
(
"mla_1_0"
:
ModelConfig
(
8
,
16
,
16
,
64
,
128
,
128
,
0.0
,
"no_mask"
,
"no_bias"
,
head_dim_v
=
128
8
,
16
,
16
,
64
,
128
,
128
,
0.0
,
"no_mask"
,
"no_bias"
,
head_dim_v
=
128
...
@@ -408,10 +425,10 @@ model_configs_mla = {
...
@@ -408,10 +425,10 @@ model_configs_mla = {
"mla_3_1"
:
ModelConfig
(
"mla_3_1"
:
ModelConfig
(
8
,
16
,
16
,
256
,
1
,
2048
,
0.0
,
"no_mask"
,
"no_bias"
,
head_dim_v
=
128
8
,
16
,
16
,
256
,
1
,
2048
,
0.0
,
"no_mask"
,
"no_bias"
,
head_dim_v
=
128
),
# inference
),
# inference
}
}
@
pytest
.
mark
.
skipif
(
get_cudnn_version
()
<
(
8
,
9
,
1
),
reason
=
"cuDNN 8.9.1+ is required."
)
@
pytest
.
mark
.
skipif
(
not
IS_HIP_EXTENSION
and
get_cudnn_version
()
<
(
8
,
9
,
1
),
reason
=
"cuDNN 8.9.1+ is required."
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
param_types
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
param_types
)
@
pytest
.
mark
.
parametrize
(
"model_configs"
,
[
model_configs_mla
])
@
pytest
.
mark
.
parametrize
(
"model_configs"
,
[
model_configs_mla
])
@
pytest
.
mark
.
parametrize
(
"model"
,
model_configs_mla
.
keys
())
@
pytest
.
mark
.
parametrize
(
"model"
,
model_configs_mla
.
keys
())
...
@@ -592,7 +609,7 @@ model_configs_swa = {
...
@@ -592,7 +609,7 @@ model_configs_swa = {
}
}
@
pytest
.
mark
.
skipif
(
not
FlashAttentionUtils
.
v2_3_plus
,
reason
=
"Flash-attn 2.3+ is required."
)
@
pytest
.
mark
.
skipif
((
not
IS_HIP_EXTENSION
)
and
(
not
FlashAttentionUtils
.
v2_3_plus
)
,
reason
=
"Flash-attn 2.3+ is required."
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
param_types_lean
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
param_types_lean
)
@
pytest
.
mark
.
parametrize
(
"model_configs"
,
[
model_configs_swa
])
@
pytest
.
mark
.
parametrize
(
"model_configs"
,
[
model_configs_swa
])
@
pytest
.
mark
.
parametrize
(
"model"
,
model_configs_swa
.
keys
())
@
pytest
.
mark
.
parametrize
(
"model"
,
model_configs_swa
.
keys
())
...
@@ -614,7 +631,7 @@ model_configs_alibi_slopes = {
}

-@pytest.mark.skipif(not FlashAttentionUtils.v2_3_plus, reason="Flash-attn 2.3+ is required.")
+@pytest.mark.skipif((not IS_HIP_EXTENSION) and (not FlashAttentionUtils.v2_3_plus), reason="Flash-attn 2.3+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_alibi_slopes])
@pytest.mark.parametrize("model", model_configs_alibi_slopes.keys())
...
@@ -1130,11 +1147,16 @@ def test_transformer_layer(
        tols = dict(atol=5e-2, rtol=5e-2)
    workspace_opt = True

+    qkv_layout = "sbh3d" if fused_qkv_params else "sb3hd"
+    # override the qkv_layout in MQA/GQA mode in ROCm TE
+    if IS_HIP_EXTENSION and model_configs[model].num_gqa_groups != model_configs[model].num_heads:
+        qkv_layout = "sbhd_sbhd_sbhd"
+
    # Test backend availability
    available_backends, fused_attn_backends = _get_attention_backends(
        config,
        qkv_dtype=dtype,
-        qkv_layout="sbh3d" if fused_qkv_params else "sb3hd",
+        qkv_layout=qkv_layout,
    )
    flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
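The override above is needed because packed QKV layouts assume equal Q and KV head counts; with grouped-query attention they differ, so the three tensors cannot share one packed buffer. A minimal sketch (sizes are assumptions for illustration):

# Why a packed "sb3hd" layout cannot hold GQA tensors: Q and KV disagree
# in the head dimension, so they cannot be stacked into one buffer.
import torch

s, b, h, hg, d = 128, 2, 16, 4, 64  # assumed sizes; hg < h means GQA
q = torch.empty(s, b, h, d)
k = torch.empty(s, b, hg, d)
v = torch.empty(s, b, hg, d)
# torch.stack([q, k, v], dim=2) raises RuntimeError here (16 vs 4 heads),
# hence the separate-tensor layout "sbhd_sbhd_sbhd" on the ROCm path.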
...
@@ -1434,7 +1456,7 @@ def _error(a, b, name_a, name_b, atol, rtol, rmse_tol):
        )
    )

+@pytest.mark.skipif(IS_HIP_EXTENSION, reason="FP8 Fused attention is not supported on ROCm")
@pytest.mark.skipif(get_cudnn_version() < (9, 2, 1), reason="cuDNN 9.2.1+ is required.")
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="FP8 tests require Hopper+.")
...
@@ -1641,7 +1663,7 @@ def _run_mha_fp8_vs_f16(dtype, config, fp8_mha, qkv_format, input_layernorm, RoP
        return out, param_names, tuple(x.grad for x in params)
    return out, param_names, tuple(None for x in params)

+@pytest.mark.skipif(IS_HIP_EXTENSION, reason="FP8 Fused attention is not supported on ROCm")
@pytest.mark.skipif(get_cudnn_version() < (9, 2, 1), reason="cuDNN 9.2.1+ is required.")
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="FP8 tests require Hopper+.")
...
@@ -1900,7 +1922,7 @@ cudnn_frontend_version = int(os.getenv("NVTE_FUSED_ATTN_FE_VER", "1"))
models_v0 = ["fp8_1", "fp8_2", "fp8_5", "fp8_6"]
models_v1 = ["fp8_3", "fp8_4", "fp8_7", "fp8_8"]

+@pytest.mark.skipif(IS_HIP_EXTENSION, reason="FP8 Fused attention is not supported on ROCm")
@pytest.mark.skipif(
    (
        get_cudnn_version() < (8, 9, 3)
...
tests/pytorch/fused_attn/test_fused_attn_with_cp.py
...
@@ -13,6 +13,7 @@ from transformer_engine.pytorch.utils import (
)
from transformer_engine.pytorch.dot_product_attention.utils import FlashAttentionUtils
from test_fused_attn import ModelConfig
+from torch.utils.cpp_extension import IS_HIP_EXTENSION

model_configs_flash_attn = {
    # test: b, h, hg, d, sq, skv, p, mask, bias
...
@@ -51,7 +52,7 @@ def get_bash_arguments(num_gpus_per_node, **kwargs):
@pytest.mark.skipif(not FlashAttentionUtils.v2_plus, reason="Flash-attn 2.0+ is required.")
-@pytest.mark.skipif(get_device_compute_capability() < (8, 0), reason="CP tests require sm80+.")
+@pytest.mark.skipif(not IS_HIP_EXTENSION and get_device_compute_capability() < (8, 0), reason="CP tests require sm80+.")
@pytest.mark.parametrize("dtype", ["bf16", "fp16"])
@pytest.mark.parametrize("model", model_configs_flash_attn.keys())
@pytest.mark.parametrize("qkv_format", ["bshd", "sbhd", "thd"])
...
@@ -111,7 +112,7 @@ model_configs_fused_attn = {
@pytest.mark.skipif(get_cudnn_version() < (8, 9, 7), reason="cuDNN 8.9.7+ is required.")
-@pytest.mark.skipif(get_device_compute_capability() < (8, 0), reason="CP tests require sm80+.")
+@pytest.mark.skipif(IS_HIP_EXTENSION or get_device_compute_capability() < (8, 0), reason="DTK does not support fused attn for now; CP tests require sm80+.")
@pytest.mark.parametrize("dtype", ["bf16", "fp16", "fp8"])
@pytest.mark.parametrize("model", model_configs_fused_attn.keys())
@pytest.mark.parametrize("qkv_format", ["bshd", "sbhd", "thd"])
...
tests/pytorch/test_cuda_graphs.py
...
@@ -23,7 +23,10 @@ from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.utils import is_bf16_compatible
import transformer_engine.pytorch.ops as te_ops
from transformer_engine.common import recipe
+from torch.utils.cpp_extension import IS_HIP_EXTENSION
+if IS_HIP_EXTENSION:
+    import os
+    from functools import cache

# Check if FP8 is supported.
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
...
@@ -73,6 +76,15 @@ def reset_global_fp8_state():
    yield
    FP8GlobalStateManager.reset()

+if IS_HIP_EXTENSION:
+    @cache
+    def use_hipblaslt() -> bool:
+        return os.getenv("NVTE_USE_HIPBLASLT") is not None or os.getenv("NVTE_USE_ROCBLAS") is None
+
+    @pytest.fixture(autouse=True)
+    def skip_rocblas():
+        if not use_hipblaslt():
+            pytest.skip("CUDA graph capture not supported with rocBLAS path")
+
def assert_all_equal(l1: List[torch.Tensor], l2: List[torch.Tensor], names=None) -> bool:
    """Check that two lists of tensors match exactly."""
...
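Note the default encoded by `use_hipblaslt()` above: hipBLASLt is used unless `NVTE_USE_ROCBLAS` is set, and setting `NVTE_USE_HIPBLASLT` wins even when both are set. A self-contained sketch of that truth table (re-implemented here for illustration, not the TE API):

# Standalone restatement of the selection logic in use_hipblaslt() above.
def backend(hipblaslt_set: bool, rocblas_set: bool) -> str:
    return "hipBLASLt" if hipblaslt_set or not rocblas_set else "rocBLAS"

assert backend(False, False) == "hipBLASLt"  # neither var set: default
assert backend(True, False) == "hipBLASLt"
assert backend(False, True) == "rocBLAS"     # only NVTE_USE_ROCBLAS set
assert backend(True, True) == "hipBLASLt"    # NVTE_USE_HIPBLASLT wins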
tests/pytorch/test_fusible_ops.py
...
@@ -10,6 +10,7 @@ from typing import Optional
import pytest
import torch
+from torch.utils.cpp_extension import IS_HIP_EXTENSION

import transformer_engine
import transformer_engine.common.recipe
...
@@ -27,6 +28,14 @@ from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor, Float8
from transformer_engine.pytorch.utils import is_bf16_compatible
import transformer_engine_torch as tex

+if IS_HIP_EXTENSION:
+    import os
+    from functools import cache
+
+    @cache
+    def use_hipblaslt() -> bool:
+        return os.getenv("NVTE_USE_HIPBLASLT") is not None or os.getenv("NVTE_USE_ROCBLAS") is None
+
# Check if FP8 is supported
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
...
@@ -770,6 +779,9 @@ class TestBasicOps:
            pytest.skip("MXFP8 output is not supported with MXFP8 GEMMs")
        if quantization == "mxfp8" and quantized_grad_input:
            pytest.skip("MXFP8 grad input is not supported with MXFP8 GEMMs")
+        if (IS_HIP_EXTENSION and not use_hipblaslt() and accumulate_into_main_grad
+                and dtype != torch.float32 and not quantized_compute):
+            pytest.skip("Parameters combination is not supported by ROCBLAS")

        # Random data
        x_ref, x_test = make_reference_and_test_tensors(
...
tests/pytorch/test_gemm_autotune.py (new file, mode 100644)
# Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
# License for AMD contributions = MIT. See LICENSE for more information
import os, sys
import copy
import pytest
import tempfile
import shutil
import subprocess
import csv
import warnings

import torch
from torch.utils.cpp_extension import IS_HIP_EXTENSION

from transformer_engine.pytorch.cpp_extensions import gemm
from transformer_engine.pytorch.module.base import get_workspace


def use_hipblaslt():
    return os.getenv("NVTE_USE_HIPBLASLT") is not None or os.getenv("NVTE_USE_ROCBLAS") is None


storage_fname = "te_algo"


def dump_storage(fname):
    print("========")
    with open(fname, "r") as ifile:
        for row in ifile:
            print(row)
    print("========")


def analyse_storage(fname):
    with open(fname, "r") as ifile:
        reader = csv.DictReader(ifile)
        next(reader)
        head = reader.fieldnames
        assert (
            "m" in head
            and "algo_id" in head
            and "ws_min" in head
            and "ws_max" in head
            and "aidx" in head
        ), "Invalid CSV format"
        return head


def read_storage(fname):
    data = []
    with open(fname, "r") as ifile:
        reader = csv.DictReader(ifile)
        for row in reader:
            data.append(row)
    return data


def write_storage(fname, head, data):
    with open(fname, "w") as ofile:
        writer = csv.DictWriter(ofile, fieldnames=head, lineterminator="\n")
        writer.writeheader()
        writer.writerows(data)


@pytest.mark.skipif(not use_hipblaslt(), reason="Autotune requires hipBLASLt")
@pytest.mark.skipif(not IS_HIP_EXTENSION, reason="Autotune requires ROCm TE")
def test_gemm_autotune():
    storage_dir = tempfile.mkdtemp()
    fname = storage_dir + "/" + storage_fname
    script = os.path.abspath(__file__)
    try:
        os.environ["TE_HIPBLASLT_ALGO_LOAD"] = fname
        os.environ["TE_HIPBLASLT_ALGO_SAVE"] = fname
        run_args = ["python", script, "--run"]

        # Initial algo creation
        subprocess.run(run_args)
        head = analyse_storage(fname)
        algos = read_storage(fname)
        assert len(algos) == 1, "Expected 1 cached record"
        algo0 = copy.copy(algos[0])

        ofile = fname + ".1"
        os.environ["TE_HIPBLASLT_ALGO_SAVE"] = ofile

        # Unused cache entries
        algos[0]["m"] = "999" + algos[0]["m"]  # fake record for a different shape
        write_storage(fname, head, algos)
        subprocess.run(run_args)
        algos = read_storage(ofile)
        assert len(algos) == 2, "Expected 2 cached records"
        assert algo0 == algos[1], "Invalid algo"

        # Adjust workspace size
        ws_max = int(algo0["ws_max"])
        if ws_max > 0:
            algos = [copy.copy(algo0)]
            algos[0]["ws_max"] = str(ws_max - 1)  # decreasing the WS range should restore the size
            ws_min = int(algos[0]["ws_min"])
            if ws_max - ws_min > 1:
                ws_min = ws_min + 1
                algos[0]["ws_min"] = str(ws_min)
            write_storage(fname, head, algos)
            subprocess.run(run_args)
            algos = read_storage(ofile)
            assert len(algos) == 1, "Expected 1 cached record"
            assert (str(ws_min), str(ws_max)) == (algos[0]["ws_min"], algos[0]["ws_max"]), "Invalid WS size"
        else:
            warnings.warn("Cached algo workspace size is 0")

        # Modify algo index
        algo_index = int(algo0["aidx"])
        algos = [copy.copy(algo0)]
        algos[0]["aidx"] = str(algo_index + 1)
        write_storage(fname, head, algos)
        subprocess.run(run_args)
        algos = read_storage(ofile)
        assert len(algos) == 1, "Expected 1 cached record"
        assert (algo0["aidx"], algo0["algo_id"]) == (algos[0]["aidx"], algos[0]["algo_id"]), "Invalid algo IDX"

        # Configure the autotune range so the currently cached algo is out of it,
        # and cache a new value
        os.environ["TE_HIPBLASLT_ALGO_LOAD"] = ""
        os.environ["TE_HIPBLASLT_ALGO_SAVE"] = fname
        os.environ["TE_HIPBLASLT_ALGO_SELECTION"] = str(algo_index + 1)
        subprocess.run(run_args)
        algos = read_storage(fname)
        assert len(algos) == 1, "Expected 1 cached record"
        algo1 = copy.copy(algos[0])
        assert algo0["algo_id"] != algo1["algo_id"], "Unexpected algo ID"

        # Restore the autotune range beginning; the new algo should still be used
        os.environ["TE_HIPBLASLT_ALGO_LOAD"] = fname
        del os.environ["TE_HIPBLASLT_ALGO_SELECTION"]
        subprocess.run(run_args)
        algos = read_storage(fname)
        assert len(algos) == 1, "Expected 1 cached record"
        assert algo1 == algos[0], "Invalid algo ID"
    finally:
        shutil.rmtree(storage_dir)


def run_gemm():
    N = 32
    datatype = torch.float16
    inp = torch.randn((N, N), device="cuda", dtype=datatype)
    _, _, _ = gemm(A=inp, B=inp, dtype=datatype, workspace=get_workspace())


if __name__ == "__main__":
    if sys.argv[1] == "--run":
        run_gemm()
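The new test drives the hipBLASLt algo cache entirely through environment variables and a CSV file whose header must include the columns m, algo_id, ws_min, ws_max, and aidx. As a hedged standalone sketch of the same knobs outside the test (the workload script name is a placeholder):

# Hypothetical reuse of the env vars exercised by test_gemm_autotune;
# "my_gemm_workload.py" stands in for any script that runs TE GEMMs.
import os
import subprocess

cache_file = "/tmp/te_algo"  # CSV cache with m, algo_id, ws_min, ws_max, aidx columns
env = dict(
    os.environ,
    TE_HIPBLASLT_ALGO_LOAD=cache_file,  # reuse algos tuned in earlier runs
    TE_HIPBLASLT_ALGO_SAVE=cache_file,  # persist algos tuned in this run
)
subprocess.run(["python", "my_gemm_workload.py"], env=env, check=True)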
tests/pytorch/test_numerics.py
...
@@ -12,6 +12,7 @@ import random
import torch
import torch.nn as nn
from torch.nn import Parameter
+from torch.utils.cpp_extension import IS_HIP_EXTENSION

from transformer_engine.pytorch.fp8 import (
    FP8GlobalStateManager,
...