diff --git a/benchmarks/attention/benchmark_attention.py b/benchmarks/attention/benchmark_attention.py
index 7b64ea44cfa53459e475e622fa5afe8071f54d8b..e5df485eda1d8492a06c7860db21773e43951819 100644
--- a/benchmarks/attention/benchmark_attention.py
+++ b/benchmarks/attention/benchmark_attention.py
@@ -14,7 +14,7 @@ from tests.pytorch.fused_attn.test_fused_attn import (
_is_flash_attention_supported,
_is_fused_attention_supported,
_is_unfused_attention_supported,
- _run_dot_product_attention
+ _run_dot_product_attention,
)
pd.set_option("display.precision", 4)
@@ -28,7 +28,7 @@ ckpt_attn = False
# workspace optimization path for cuDNN attention
workspace_opt = True
# QKV memory layout
-qkv_layout = 'bshd_bshd_bshd'
+qkv_layout = "bshd_bshd_bshd"
# sliding window attention
swa = False
# padding between sequences for qkv_format=thd
@@ -38,16 +38,17 @@ is_training = True
model_configs = {
# test: b, h, hg, d, sq, skv, p, mask, bias
- "test_0": ModelConfig(2, 16, 16, 64, 512, 512, 0.0, "no_mask", "no_bias"), # short seq
- "test_1": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "causal", "no_bias"), # longer seq, mask
- "test_2": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "causal", "post_scale_bias"), # bias
- "test_3": ModelConfig(2, 32, 4, 128, 8192, 8192, 0.0, "causal", "no_bias"), # GQA
+ "test_0": ModelConfig(2, 16, 16, 64, 512, 512, 0.0, "no_mask", "no_bias"), # short seq
+ "test_1": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "causal", "no_bias"), # longer seq, mask
+ "test_2": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "causal", "post_scale_bias"), # bias
+ "test_3": ModelConfig(2, 32, 4, 128, 8192, 8192, 0.0, "causal", "no_bias"), # GQA
}
+
def benchmark_dot_product_attention(model, fused_attn_supported, flash_attn_supported):
config = model_configs[model]
if dtype == torch.bfloat16:
- tols = dict(atol=2.5e-2, rtol=2.5e-2)
+ tols = dict(atol=2.5e-2, rtol=2.5e-2)
else:
tols = dict(atol=5e-3, rtol=5e-3)
@@ -57,17 +58,31 @@ def benchmark_dot_product_attention(model, fused_attn_supported, flash_attn_supp
for i in range(warmup_iters):
if fused_attn_supported:
fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention(
- dtype, config, "FusedAttention",
- ckpt_attn, qkv_layout, workspace_opt, swa, pad_between_seqs, is_training,
+ dtype,
+ config,
+ "FusedAttention",
+ ckpt_attn,
+ qkv_layout,
+ workspace_opt,
+ swa,
+ pad_between_seqs,
+ is_training,
)
if flash_attn_supported:
flash_attn_fwd, flash_attn_bwd = _run_dot_product_attention(
- dtype, config, "FlashAttention",
- ckpt_attn, qkv_layout, workspace_opt, swa, pad_between_seqs, is_training,
+ dtype,
+ config,
+ "FlashAttention",
+ ckpt_attn,
+ qkv_layout,
+ workspace_opt,
+ swa,
+ pad_between_seqs,
+ is_training,
)
if fused_attn_supported and flash_attn_supported:
torch.testing.assert_close(fused_attn_fwd, flash_attn_fwd, **tols)
- for i,_ in enumerate(flash_attn_bwd):
+ for i, _ in enumerate(flash_attn_bwd):
torch.testing.assert_close(fused_attn_bwd[i], flash_attn_bwd[i], **tols)
torch.cuda.cudart().cudaProfilerStart()
@@ -76,8 +91,15 @@ def benchmark_dot_product_attention(model, fused_attn_supported, flash_attn_supp
if fused_attn_supported:
for i in range(num_iters):
fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention(
- dtype, config, "FusedAttention",
- ckpt_attn, qkv_layout, workspace_opt, swa, pad_between_seqs, is_training,
+ dtype,
+ config,
+ "FusedAttention",
+ ckpt_attn,
+ qkv_layout,
+ workspace_opt,
+ swa,
+ pad_between_seqs,
+ is_training,
)
torch.cuda.synchronize()
fused_attn_time = time.time() - fused_attn_start if fused_attn_supported else 0
@@ -87,81 +109,113 @@ def benchmark_dot_product_attention(model, fused_attn_supported, flash_attn_supp
if flash_attn_supported:
for i in range(num_iters):
flash_attn_fwd, flash_attn_bwd = _run_dot_product_attention(
- dtype, config, "FlashAttention",
- ckpt_attn, qkv_layout, workspace_opt, swa, pad_between_seqs, is_training,
+ dtype,
+ config,
+ "FlashAttention",
+ ckpt_attn,
+ qkv_layout,
+ workspace_opt,
+ swa,
+ pad_between_seqs,
+ is_training,
)
torch.cuda.synchronize()
flash_attn_time = time.time() - flash_attn_start if flash_attn_supported else 0
- df = pd.read_csv('times.csv')
- df = pd.concat([
- df,
- pd.DataFrame(
- [[fused_attn_time*1e3/num_iters, 0, 0, 0,
- flash_attn_time*1e3/num_iters, 0, 0, 0, 0]], columns=df.columns)],
- ignore_index=True
- )
- df.to_csv('times.csv',index=False)
+ df = pd.read_csv("times.csv")
+ df = pd.concat(
+ [
+ df,
+ pd.DataFrame(
+ [
+ [
+ fused_attn_time * 1e3 / num_iters,
+ 0,
+ 0,
+ 0,
+ flash_attn_time * 1e3 / num_iters,
+ 0,
+ 0,
+ 0,
+ 0,
+ ]
+ ],
+ columns=df.columns,
+ ),
+ ],
+ ignore_index=True,
+ )
+ df.to_csv("times.csv", index=False)
torch.cuda.cudart().cudaProfilerStop()
+
def parse_results(per_cudnn, per_flash, model):
- filename = f'prof_{model}_cuda_gpu_trace.csv'
- df = pd.read_csv(os.path.join('./',filename))
- df_times = pd.read_csv('times.csv')
- row = len(df_times.index)-1
-
+ filename = f"prof_{model}_cuda_gpu_trace.csv"
+ df = pd.read_csv(os.path.join("./", filename))
+ df_times = pd.read_csv("times.csv")
+ row = len(df_times.index) - 1
+
if per_cudnn > 0:
- t_cudnn_all = df[df['Name'].str.contains('cudnn')]['Duration (ns)'].to_numpy()
+ t_cudnn_all = df[df["Name"].str.contains("cudnn")]["Duration (ns)"].to_numpy()
t_cudnn_all = t_cudnn_all.reshape(-1, per_cudnn)
t_cudnn_avg = np.average(t_cudnn_all, axis=0)
- df_times.loc[row, 'FusedAttention Kernels (fwd)'] = t_cudnn_avg[0]/1e6
- df_times.loc[row, 'FusedAttention Kernels (bwd)'] = t_cudnn_avg[1:4].sum()/1e6
- df_times.loc[row, 'FusedAttention Kernels (fwd+bwd)'] = t_cudnn_avg.sum()/1e6
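+    # Durations from the nsys trace are in ns; /1e6 converts to ms. Kernel 0 is fwd, kernels 1-3 are bwd.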
+ df_times.loc[row, "FusedAttention Kernels (fwd)"] = t_cudnn_avg[0] / 1e6
+ df_times.loc[row, "FusedAttention Kernels (bwd)"] = t_cudnn_avg[1:4].sum() / 1e6
+ df_times.loc[row, "FusedAttention Kernels (fwd+bwd)"] = t_cudnn_avg.sum() / 1e6
if per_flash > 0:
- t_flash_all = df[df['Name'].str.contains('void flash')]['Duration (ns)'].to_numpy()
+ t_flash_all = df[df["Name"].str.contains("void flash")]["Duration (ns)"].to_numpy()
t_flash_all = t_flash_all.reshape(-1, per_flash)
t_flash_avg = np.average(t_flash_all, axis=0)
- df_times.loc[row, 'FlashAttention Kernels (fwd)'] = t_flash_avg[0]/1e6
- df_times.loc[row, 'FlashAttention Kernels (bwd)'] = t_flash_avg[1:4].sum()/1e6
- df_times.loc[row, 'FlashAttention Kernels (fwd+bwd)'] = t_flash_avg.sum()/1e6
+ df_times.loc[row, "FlashAttention Kernels (fwd)"] = t_flash_avg[0] / 1e6
+ df_times.loc[row, "FlashAttention Kernels (bwd)"] = t_flash_avg[1:4].sum() / 1e6
+ df_times.loc[row, "FlashAttention Kernels (fwd+bwd)"] = t_flash_avg.sum() / 1e6
if per_cudnn > 0 and per_flash > 0:
- df_times.loc[row, 'Fused vs Flash Kernels Speedup (fwd+bwd)'] = \
- df_times.loc[row, 'FlashAttention Kernels (fwd+bwd)'] / \
- df_times.loc[row, 'FusedAttention Kernels (fwd+bwd)']
- df_times.to_csv('times.csv',index=False)
+ df_times.loc[row, "Fused vs Flash Kernels Speedup (fwd+bwd)"] = (
+ df_times.loc[row, "FlashAttention Kernels (fwd+bwd)"]
+ / df_times.loc[row, "FusedAttention Kernels (fwd+bwd)"]
+ )
+ df_times.to_csv("times.csv", index=False)
+
def main():
times = pd.DataFrame(
- columns=[
- 'FusedAttention Module',
- 'FusedAttention Kernels (fwd)',
- 'FusedAttention Kernels (bwd)',
- 'FusedAttention Kernels (fwd+bwd)',
- 'FlashAttention Module',
- 'FlashAttention Kernels (fwd)',
- 'FlashAttention Kernels (bwd)',
- 'FlashAttention Kernels (fwd+bwd)',
- 'Fused vs Flash Kernels Speedup (fwd+bwd)',
- ])
- times.to_csv('times.csv',index=False)
+ columns=[
+ "FusedAttention Module",
+ "FusedAttention Kernels (fwd)",
+ "FusedAttention Kernels (bwd)",
+ "FusedAttention Kernels (fwd+bwd)",
+ "FlashAttention Module",
+ "FlashAttention Kernels (fwd)",
+ "FlashAttention Kernels (bwd)",
+ "FlashAttention Kernels (fwd+bwd)",
+ "Fused vs Flash Kernels Speedup (fwd+bwd)",
+ ]
+ )
+ times.to_csv("times.csv", index=False)
device_id = torch.cuda.current_device()
device_properties = torch.cuda.get_device_properties(device_id)
- print(f"Device {device_id}: "
+ print(
+ f"Device {device_id}: "
f"{device_properties.name} GPU, "
f"sm{device_properties.major}{device_properties.minor} compute capability, "
- f"{device_properties.total_memory/1024**3:.1f}GB memory")
+ f"{device_properties.total_memory/1024**3:.1f}GB memory"
+ )
for model in model_configs.keys():
config = model_configs[model]
fused_attn_supported, fused_attn_backend = _is_fused_attention_supported(
- config, dtype, qkv_layout=qkv_layout,
+ config,
+ dtype,
+ qkv_layout=qkv_layout,
)
fused_attn_supported = fused_attn_supported and not swa
flash_attn_supported = _is_flash_attention_supported(config)
- print(f'Running {model} with {"cuDNN attention" if fused_attn_supported else ""}'
- f'{" and flash-attention" if flash_attn_supported else ""}...')
+ print(
+ f'Running {model} with {"cuDNN attention" if fused_attn_supported else ""}'
+ f'{" and flash-attention" if flash_attn_supported else ""}...'
+ )
prof_cmd = [
"nsys",
@@ -175,8 +229,8 @@ def main():
f""" "import benchmark_attention;""",
f"""benchmark_attention.benchmark_dot_product_attention("""
f"""'{model}', {fused_attn_supported}, {flash_attn_supported})" """,
- ]
- prof_cmd = ' '.join(prof_cmd)
+ ]
+ prof_cmd = " ".join(prof_cmd)
subprocess.call(prof_cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, shell=True)
stats_cmd = [
"nsys",
@@ -190,17 +244,17 @@ def main():
"--force-export=true",
f"--output=prof_{model}",
f"prof_{model}.nsys-rep",
- ]
+ ]
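+        # Expected number of attention kernels per iteration in the nsys trace (consumed by parse_results)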
if fused_attn_supported:
num_kernels_cudnn = 4
- if config.attn_bias_type == 'post_scale_bias':
- num_kernels_cudnn = num_kernels_cudnn+1
+ if config.attn_bias_type == "post_scale_bias":
+ num_kernels_cudnn = num_kernels_cudnn + 1
if config.num_heads != config.num_gqa_groups:
- num_kernels_cudnn = num_kernels_cudnn+2
+ num_kernels_cudnn = num_kernels_cudnn + 2
else:
num_kernels_cudnn = 0
num_kernels_flash = 4 if flash_attn_supported else 0
- stats_cmd = ' '.join(stats_cmd)
+ stats_cmd = " ".join(stats_cmd)
subprocess.call(stats_cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, shell=True)
parse_cmd = [
"python",
@@ -208,18 +262,23 @@ def main():
f""" "import benchmark_attention;""",
f"""benchmark_attention.parse_results("""
f"""{num_kernels_cudnn}, {num_kernels_flash}, '{model}')" """,
- ]
- parse_cmd = ' '.join(parse_cmd)
+ ]
+ parse_cmd = " ".join(parse_cmd)
subprocess.call(parse_cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, shell=True)
- df_times = pd.read_csv('times.csv')
+ df_times = pd.read_csv("times.csv")
df_times.index = list(model_configs.keys())
- a=df_times[['FusedAttention Kernels (fwd+bwd)',
- 'FlashAttention Kernels (fwd+bwd)',
- 'Fused vs Flash Kernels Speedup (fwd+bwd)']]
- a.columns = ['cuDNN fwd+bwd (ms)', 'flash-attn fwd+bwd (ms)', 'cuDNN vs flash speedup']
+ a = df_times[
+ [
+ "FusedAttention Kernels (fwd+bwd)",
+ "FlashAttention Kernels (fwd+bwd)",
+ "Fused vs Flash Kernels Speedup (fwd+bwd)",
+ ]
+ ]
+ a.columns = ["cuDNN fwd+bwd (ms)", "flash-attn fwd+bwd (ms)", "cuDNN vs flash speedup"]
print()
print(a)
+
if __name__ == "__main__":
main()
diff --git a/build_tools/build_ext.py b/build_tools/build_ext.py
index 170a8da843cc833f9848d2e8c5f3d1ef51d65582..73414864cb900cc60b77acadab5971b0dee54ddd 100644
--- a/build_tools/build_ext.py
+++ b/build_tools/build_ext.py
@@ -64,6 +64,7 @@ class CMakeExtension(setuptools.Extension):
configure_command.append("-GNinja")
import pybind11
+
pybind11_dir = Path(pybind11.__file__).resolve().parent
pybind11_dir = pybind11_dir / "share" / "cmake" / "pybind11"
configure_command.append(f"-Dpybind11_DIR={pybind11_dir}")
@@ -130,6 +131,7 @@ def get_build_ext(extension_cls: Type[setuptools.Extension]):
else:
# Only during release sdist build.
import transformer_engine
+
search_paths = list(Path(transformer_engine.__path__[0]).iterdir())
del transformer_engine
@@ -142,8 +144,9 @@ def get_build_ext(extension_cls: Type[setuptools.Extension]):
# Figure out stub file path
module_name = paddle_ext.name
- assert module_name.endswith("_pd_"), \
- "Expected Paddle extension module to end with '_pd_'"
+ assert module_name.endswith(
+ "_pd_"
+ ), "Expected Paddle extension module to end with '_pd_'"
stub_name = module_name[:-4] # remove '_pd_'
stub_path = os.path.join(self.build_lib, "transformer_engine", stub_name + ".py")
Path(stub_path).parent.mkdir(exist_ok=True, parents=True)
@@ -158,6 +161,7 @@ def get_build_ext(extension_cls: Type[setuptools.Extension]):
# Write stub file
print(f"Writing Paddle stub for {lib_name} into file {stub_path}")
from paddle.utils.cpp_extension.extension_utils import custom_write_stub
+
custom_write_stub(lib_name, stub_path)
# Ensure that binaries are not in global package space.
@@ -182,13 +186,14 @@ def get_build_ext(extension_cls: Type[setuptools.Extension]):
# extra_compile_args is a dict.
for ext in self.extensions:
if isinstance(ext.extra_compile_args, dict):
- for target in ['cxx', 'nvcc']:
+ for target in ["cxx", "nvcc"]:
if target not in ext.extra_compile_args.keys():
ext.extra_compile_args[target] = []
# Define new _compile method that redirects to NVCC for .cu and .cuh files.
original_compile_fn = self.compiler._compile
- self.compiler.src_extensions += ['.cu', '.cuh']
+ self.compiler.src_extensions += [".cu", ".cuh"]
+
def _compile_fn(obj, src, ext, cc_args, extra_postargs, pp_opts) -> None:
# Copy before we make any modifications.
cflags = copy.deepcopy(extra_postargs)
@@ -197,31 +202,31 @@ def get_build_ext(extension_cls: Type[setuptools.Extension]):
_, nvcc_bin = cuda_path()
original_compiler = self.compiler.compiler_so
- if os.path.splitext(src)[1] in ['.cu', '.cuh']:
- self.compiler.set_executable('compiler_so', str(nvcc_bin))
+ if os.path.splitext(src)[1] in [".cu", ".cuh"]:
+ self.compiler.set_executable("compiler_so", str(nvcc_bin))
if isinstance(cflags, dict):
- cflags = cflags['nvcc']
+ cflags = cflags["nvcc"]
# Add -fPIC if not already specified
- if not any('-fPIC' in flag for flag in cflags):
- cflags.extend(['--compiler-options', "'-fPIC'"])
+ if not any("-fPIC" in flag for flag in cflags):
+ cflags.extend(["--compiler-options", "'-fPIC'"])
# Forward unknown options
- if not any('--forward-unknown-opts' in flag for flag in cflags):
- cflags.append('--forward-unknown-opts')
+ if not any("--forward-unknown-opts" in flag for flag in cflags):
+ cflags.append("--forward-unknown-opts")
elif isinstance(cflags, dict):
- cflags = cflags['cxx']
+ cflags = cflags["cxx"]
# Append -std=c++17 if not already in flags
- if not any(flag.startswith('-std=') for flag in cflags):
- cflags.append('-std=c++17')
+ if not any(flag.startswith("-std=") for flag in cflags):
+ cflags.append("-std=c++17")
return original_compile_fn(obj, src, ext, cc_args, cflags, pp_opts)
finally:
# Put the original compiler back in place.
- self.compiler.set_executable('compiler_so', original_compiler)
+ self.compiler.set_executable("compiler_so", original_compiler)
self.compiler._compile = _compile_fn
diff --git a/build_tools/jax.py b/build_tools/jax.py
index 21248640cc2288c6e6d13d0e74a8c9948e3d8a43..496bf056e8270676d69f4ef706d5e9a37bcfaa4b 100644
--- a/build_tools/jax.py
+++ b/build_tools/jax.py
@@ -36,8 +36,8 @@ def setup_jax_extension(
]
# Compile flags
- cxx_flags = [ "-O3" ]
- nvcc_flags = [ "-O3" ]
+ cxx_flags = ["-O3"]
+ nvcc_flags = ["-O3"]
# Define TE/JAX as a Pybind11Extension
from pybind11.setup_helpers import Pybind11Extension
@@ -47,9 +47,9 @@ def setup_jax_extension(
def _add_cflags(self, flags: List[str]) -> None:
if isinstance(self.extra_compile_args, dict):
- cxx_flags = self.extra_compile_args.pop('cxx', [])
+ cxx_flags = self.extra_compile_args.pop("cxx", [])
cxx_flags += flags
- self.extra_compile_args['cxx'] = cxx_flags
+ self.extra_compile_args["cxx"] = cxx_flags
else:
self.extra_compile_args[:0] = flags
@@ -57,8 +57,5 @@ def setup_jax_extension(
"transformer_engine_jax",
sources=[str(path) for path in sources],
include_dirs=[str(path) for path in include_dirs],
- extra_compile_args={
- "cxx": cxx_flags,
- "nvcc": nvcc_flags
- },
+ extra_compile_args={"cxx": cxx_flags, "nvcc": nvcc_flags},
)
diff --git a/build_tools/pytorch.py b/build_tools/pytorch.py
index 655b7dbf26896df608ec900c4165d3fcb7762a18..0b44cfb372f17058bbddbf85adbd0d32b1423471 100644
--- a/build_tools/pytorch.py
+++ b/build_tools/pytorch.py
@@ -76,11 +76,12 @@ def setup_pytorch_extension(
# Libraries -- PyTorch CUDAExtension links to libcudart.so but not to libcuda.so
cuda_home, _ = cuda_path()
- library_dirs = [ cuda_home / "compat" / "lib" ]
- libraries = [ "cuda" ]
+ library_dirs = [cuda_home / "compat" / "lib"]
+ libraries = ["cuda"]
if os.getenv("UB_MPI_BOOTSTRAP"):
- assert os.getenv("MPI_HOME") is not None, \
- "MPI_HOME must be set when compiling with UB_MPI_BOOTSTRAP=1"
+ assert (
+ os.getenv("MPI_HOME") is not None
+ ), "MPI_HOME must be set when compiling with UB_MPI_BOOTSTRAP=1"
mpi_home = Path(os.getenv("MPI_HOME"))
include_dirs.append(mpi_home / "include")
cxx_flags.append("-DUB_MPI_BOOTSTRAP")
@@ -95,12 +96,12 @@ def setup_pytorch_extension(
return CUDAExtension(
name="transformer_engine_torch",
- sources=[ str(src) for src in sources ],
- include_dirs=[ str(inc) for inc in include_dirs ],
+ sources=[str(src) for src in sources],
+ include_dirs=[str(inc) for inc in include_dirs],
extra_compile_args={
"cxx": cxx_flags,
"nvcc": nvcc_flags,
},
- libraries=[ str(lib) for lib in libraries ],
- library_dirs=[ str(lib_dir) for lib_dir in library_dirs ],
+ libraries=[str(lib) for lib in libraries],
+ library_dirs=[str(lib_dir) for lib_dir in library_dirs],
)
diff --git a/build_tools/te_version.py b/build_tools/te_version.py
index 9ebc0ccdc470c5ac4d999af8dd42668bddfb48f9..b40fb260146d35bf0c1728ac38ac826b51fa8969 100644
--- a/build_tools/te_version.py
+++ b/build_tools/te_version.py
@@ -18,11 +18,12 @@ def te_version() -> str:
root_path = Path(__file__).resolve().parent
with open(root_path / "VERSION.txt", "r") as f:
version = f.readline().strip()
- if (not int(os.getenv("NVTE_NO_LOCAL_VERSION", "0"))
- and not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0")))):
+ if not int(os.getenv("NVTE_NO_LOCAL_VERSION", "0")) and not bool(
+ int(os.getenv("NVTE_RELEASE_BUILD", "0"))
+ ):
try:
output = subprocess.run(
- ["git", "rev-parse" , "--short", "HEAD"],
+ ["git", "rev-parse", "--short", "HEAD"],
capture_output=True,
cwd=root_path,
check=True,
diff --git a/build_tools/utils.py b/build_tools/utils.py
index b601ec137e8cf29e01636f8cda95d57f4a8dac46..036cb1eac5e28e089add81bcac531f4288828e26 100644
--- a/build_tools/utils.py
+++ b/build_tools/utils.py
@@ -174,7 +174,7 @@ def cuda_version() -> Tuple[int, ...]:
universal_newlines=True,
)
match = re.search(r"release\s*([\d.]+)", output.stdout)
- version = match.group(1).split('.')
+ version = match.group(1).split(".")
return tuple(int(v) for v in version)
@@ -224,9 +224,7 @@ def get_frameworks() -> List[str]:
_frameworks = [framework.lower() for framework in _frameworks]
for framework in _frameworks:
if framework not in supported_frameworks:
- raise ValueError(
- f"Transformer Engine does not support framework={framework}"
- )
+ raise ValueError(f"Transformer Engine does not support framework={framework}")
return _frameworks
@@ -242,8 +240,8 @@ def package_files(directory):
def copy_common_headers(te_src, dst):
headers = te_src / "common"
- for file_path in glob.glob(os.path.join(str(headers), "**", '*.h'), recursive=True):
- new_path = os.path.join(dst, file_path[len(str(te_src)) + 1:])
+ for file_path in glob.glob(os.path.join(str(headers), "**", "*.h"), recursive=True):
+ new_path = os.path.join(dst, file_path[len(str(te_src)) + 1 :])
Path(new_path).parent.mkdir(exist_ok=True, parents=True)
shutil.copy(file_path, new_path)
@@ -251,9 +249,10 @@ def copy_common_headers(te_src, dst):
def install_and_import(package):
"""Install a package via pip (if not already installed) and import into globals."""
import importlib
+
try:
importlib.import_module(package)
except ImportError:
- subprocess.check_call([sys.executable, '-m', 'pip', 'install', package])
+ subprocess.check_call([sys.executable, "-m", "pip", "install", package])
finally:
globals()[package] = importlib.import_module(package)
diff --git a/docs/conf.py b/docs/conf.py
index 77b59eb8e0851135c8387c0193d37bcacbd0af43..695546a9bad0b2cb97457f18beb02e6f55c10d03 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -28,21 +28,26 @@ if current_year == release_year:
else:
copyright_year = str(release_year) + "-" + str(current_year)
-project = u'Transformer Engine'
-copyright = u'{}, NVIDIA CORPORATION & AFFILIATES. All rights reserved.'.format(copyright_year)
-author = u'NVIDIA CORPORATION & AFFILIATES'
+project = "Transformer Engine"
+copyright = "{}, NVIDIA CORPORATION & AFFILIATES. All rights reserved.".format(copyright_year)
+author = "NVIDIA CORPORATION & AFFILIATES"
git_sha = os.getenv("GIT_SHA")
if not git_sha:
try:
- git_sha = subprocess.check_output(["git", "log", "--pretty=format:'%h'", "-n1"]).decode('ascii').replace("'","").strip()
+ git_sha = (
+ subprocess.check_output(["git", "log", "--pretty=format:'%h'", "-n1"])
+ .decode("ascii")
+ .replace("'", "")
+ .strip()
+ )
except:
- git_sha = u'0000000'
+ git_sha = "0000000"
git_sha = git_sha[:7] if len(git_sha) > 7 else git_sha
-version = str(te_version + u"-" + git_sha)
+version = str(te_version + "-" + git_sha)
release = te_version
# hack: version is used for html creation, so put the version picker
@@ -51,58 +56,60 @@ option_on = " selected"
option_off = ""
release_opt = option_on
option_nr = 0
-version = version + """
+version = (
+ version
+ + """
Version select: """.format(option_nr, release_opt)
+""".format(
+ option_nr, release_opt
+ )
+)
# -- General configuration ---------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
extensions = [
- 'sphinx.ext.autodoc',
- 'sphinx.ext.mathjax',
- 'sphinx.ext.napoleon',
- 'sphinx.ext.ifconfig',
- 'nbsphinx',
- 'breathe',
- 'autoapi.extension',
+ "sphinx.ext.autodoc",
+ "sphinx.ext.mathjax",
+ "sphinx.ext.napoleon",
+ "sphinx.ext.ifconfig",
+ "nbsphinx",
+ "breathe",
+ "autoapi.extension",
]
-templates_path = ['_templates']
-exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
+templates_path = ["_templates"]
+exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
-source_suffix = '.rst'
+source_suffix = ".rst"
-master_doc = 'index'
-
-pygments_style = 'sphinx'
+master_doc = "index"
+pygments_style = "sphinx"
# -- Options for HTML output -------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output
-html_theme = 'sphinx_rtd_theme'
+html_theme = "sphinx_rtd_theme"
html_theme_path = [sphinx_rtd_theme.get_html_theme_path()]
-html_static_path = ['_static']
+html_static_path = ["_static"]
html_show_sphinx = False
html_css_files = [
- 'css/nvidia_font.css',
- 'css/nvidia_footer.css',
+ "css/nvidia_font.css",
+ "css/nvidia_footer.css",
]
-html_theme_options = {
- 'display_version': True,
- 'collapse_navigation': False,
- 'logo_only': False
-}
+html_theme_options = {"display_version": True, "collapse_navigation": False, "logo_only": False}
-napoleon_custom_sections = [('Parallelism parameters', 'params_style'),
- ('Optimization parameters', 'params_style'),
- ('Values', 'params_style')]
+napoleon_custom_sections = [
+ ("Parallelism parameters", "params_style"),
+ ("Optimization parameters", "params_style"),
+ ("Values", "params_style"),
+]
breathe_projects = {"TransformerEngine": os.path.abspath("doxygen/xml/")}
breathe_default_project = "TransformerEngine"
diff --git a/docs/examples/attention/arbitrary_mask_to_post_scale_bias.py b/docs/examples/attention/arbitrary_mask_to_post_scale_bias.py
index 33852286d84efc91b007f7ded161d19971b32ab8..cd8ab85ba245ef5173c84dd4ac5235b152d4a867 100644
--- a/docs/examples/attention/arbitrary_mask_to_post_scale_bias.py
+++ b/docs/examples/attention/arbitrary_mask_to_post_scale_bias.py
@@ -4,8 +4,8 @@
import os
import torch
-from typing import Tuple
-from tests.pytorch.fused_attn.test_fused_attn import ModelConfig
+from typing import Tuple
+from tests.pytorch.fused_attn.test_fused_attn import ModelConfig
from transformer_engine.pytorch.distributed import _set_cuda_rng_state
from transformer_engine.pytorch.attention import DotProductAttention
@@ -18,87 +18,105 @@ _cuda_rng_state = torch.cuda.get_rng_state()
_NVTE_DEBUG = int(os.getenv("NVTE_DEBUG", "0"))
+
def reset_rng_states() -> None:
"""Revert back to initial RNG state"""
torch.set_rng_state(_cpu_rng_state)
_set_cuda_rng_state(_cuda_rng_state)
+
def _run_dot_product_attention(
- dtype: torch.dtype,
- config: ModelConfig,
- qkv_layout: str,
- ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
+ dtype: torch.dtype,
+ config: ModelConfig,
+ qkv_layout: str,
+) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
"""Run DotProductAttention module with one forward pass and one backward pass"""
reset_rng_states()
- seqlens_q = torch.full([config.batch_size], config.max_seqlen_q,
- dtype=torch.int32, device="cuda")
- seqlens_kv = torch.full([config.batch_size], config.max_seqlen_kv,
- dtype=torch.int32, device="cuda")
- inp = torch.randn([config.batch_size, config.max_seqlen_q, 3, config.num_heads, config.head_dim],
- dtype=dtype, device="cuda")
- q = inp[:,:,0,:,:]
- k = inp[:,:,1,:,:]
- v = inp[:,:,2,:,:]
+ seqlens_q = torch.full(
+ [config.batch_size], config.max_seqlen_q, dtype=torch.int32, device="cuda"
+ )
+ seqlens_kv = torch.full(
+ [config.batch_size], config.max_seqlen_kv, dtype=torch.int32, device="cuda"
+ )
+ inp = torch.randn(
+ [config.batch_size, config.max_seqlen_q, 3, config.num_heads, config.head_dim],
+ dtype=dtype,
+ device="cuda",
+ )
+ q = inp[:, :, 0, :, :]
+ k = inp[:, :, 1, :, :]
+ v = inp[:, :, 2, :, :]
q.requires_grad = True
k.requires_grad = True
v.requires_grad = True
- out_grad = torch.randn([config.batch_size, config.max_seqlen_q, config.num_heads * config.head_dim],
- dtype=dtype, device="cuda")
+ out_grad = torch.randn(
+ [config.batch_size, config.max_seqlen_q, config.num_heads * config.head_dim],
+ dtype=dtype,
+ device="cuda",
+ )
# Create attention mask / bias
attention_mask = None
bias = None
if config.attn_mask_type == "arbitrary":
- attention_mask = torch.randint(-10,10,
- [config.batch_size, config.num_heads, config.max_seqlen_q, config.max_seqlen_kv]).to(
- dtype=torch.bool, device="cuda")
+ attention_mask = torch.randint(
+ -10,
+ 10,
+ [config.batch_size, config.num_heads, config.max_seqlen_q, config.max_seqlen_kv],
+ ).to(dtype=torch.bool, device="cuda")
if config.attn_bias_type == "post_scale_bias":
# convert mask to bias
- attention_mask = torch.randint(-10,10,
- [config.batch_size, config.num_heads, config.max_seqlen_q, config.max_seqlen_kv]).to(
- dtype=torch.bool, device="cuda")
+ attention_mask = torch.randint(
+ -10,
+ 10,
+ [config.batch_size, config.num_heads, config.max_seqlen_q, config.max_seqlen_kv],
+ ).to(dtype=torch.bool, device="cuda")
bias = attention_mask.clone()
- neginf = -2**50 if dtype == torch.bfloat16 else -2**15
- bias = torch.where(bias==0, 0, neginf).to(dtype=dtype, device='cuda')
+ neginf = -(2**50) if dtype == torch.bfloat16 else -(2**15)
+ bias = torch.where(bias == 0, 0, neginf).to(dtype=dtype, device="cuda")
bias.requires_grad = False
attention_mask = None
- block = (
- DotProductAttention(
- config.num_heads,
- config.head_dim,
- num_gqa_groups=config.num_gqa_groups,
- qkv_format='bshd',
- attention_dropout=config.dropout_p,
- sequence_parallel=False,
- tp_size=1,
- get_rng_state_tracker=None,
- tp_group=None,
- layer_number=1,
- ).to(dtype=dtype, device="cuda")
- )
+ block = DotProductAttention(
+ config.num_heads,
+ config.head_dim,
+ num_gqa_groups=config.num_gqa_groups,
+ qkv_format="bshd",
+ attention_dropout=config.dropout_p,
+ sequence_parallel=False,
+ tp_size=1,
+ get_rng_state_tracker=None,
+ tp_group=None,
+ layer_number=1,
+ ).to(dtype=dtype, device="cuda")
# Run a forward and backward pass
out = None
if config.attn_mask_type == "arbitrary":
- out = block(q, k, v,
- attention_mask=attention_mask, # attention_mask
- qkv_format='bshd',
- attn_mask_type=config.attn_mask_type, # 'arbitrary'
- core_attention_bias_type=config.attn_bias_type, # 'no_bias'
- core_attention_bias=bias, # None
- )
+ out = block(
+ q,
+ k,
+ v,
+ attention_mask=attention_mask, # attention_mask
+ qkv_format="bshd",
+ attn_mask_type=config.attn_mask_type, # 'arbitrary'
+ core_attention_bias_type=config.attn_bias_type, # 'no_bias'
+ core_attention_bias=bias, # None
+ )
out.backward(out_grad)
if config.attn_bias_type == "post_scale_bias":
- out = block(q, k, v,
- attention_mask=attention_mask, # None
- qkv_format='bshd',
- attn_mask_type=config.attn_mask_type, # no_mask
- core_attention_bias_type=config.attn_bias_type, # 'post_scale_bias'
- core_attention_bias=bias, # bias
- )
+ out = block(
+ q,
+ k,
+ v,
+ attention_mask=attention_mask, # None
+ qkv_format="bshd",
+ attn_mask_type=config.attn_mask_type, # no_mask
+ core_attention_bias_type=config.attn_bias_type, # 'post_scale_bias'
+ core_attention_bias=bias, # bias
+ )
out.backward(out_grad)
return out, (q.grad, k.grad, v.grad)
@@ -107,19 +125,19 @@ def _run_dot_product_attention(
dtype = torch.bfloat16
model_configs = {
# test: b, h, hg, d, sq, skv, p, mask, bias
- "test_mask": ModelConfig(4, 16, 16, 64, 2048, 2048, 0.0, "arbitrary", "no_bias"),
- "test_bias": ModelConfig(4, 16, 16, 64, 2048, 2048, 0.0, "no_mask", "post_scale_bias"),
+ "test_mask": ModelConfig(4, 16, 16, 64, 2048, 2048, 0.0, "arbitrary", "no_bias"),
+ "test_bias": ModelConfig(4, 16, 16, 64, 2048, 2048, 0.0, "no_mask", "post_scale_bias"),
}
-print('Run with post_scale_bias:')
+print("Run with post_scale_bias:")
config = model_configs["test_bias"]
-fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention(dtype, config, 'bs3hd')
+fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention(dtype, config, "bs3hd")
-print('Run with arbitrary mask:')
+print("Run with arbitrary mask:")
config = model_configs["test_mask"]
-unfused_attn_fwd, unfused_attn_bwd = _run_dot_product_attention(dtype, config, 'bs3hd')
+unfused_attn_fwd, unfused_attn_bwd = _run_dot_product_attention(dtype, config, "bs3hd")
torch.testing.assert_close(unfused_attn_fwd, fused_attn_fwd, atol=2.5e-2, rtol=2.5e-2)
for i in range(3):
torch.testing.assert_close(unfused_attn_bwd[i], fused_attn_bwd[i], atol=2.5e-2, rtol=2.5e-2)
-print('Test passed!')
+print("Test passed!")
diff --git a/docs/examples/attention/example_attention.py b/docs/examples/attention/example_attention.py
index 5eac38f99bd1a92b4b1f82896d1689f8cb31a4a5..2ed73034179e3603a88100beaeb94cfbade42908 100644
--- a/docs/examples/attention/example_attention.py
+++ b/docs/examples/attention/example_attention.py
@@ -14,7 +14,7 @@ from tests.pytorch.fused_attn.test_fused_attn import (
_is_flash_attention_supported,
_is_fused_attention_supported,
_is_unfused_attention_supported,
- _run_dot_product_attention
+ _run_dot_product_attention,
)
# data type
@@ -26,7 +26,7 @@ ckpt_attn = False
# workspace optimization path for cuDNN attention
workspace_opt = True
# QKV memory layout
-qkv_layout = 'bshd_bshd_bshd'
+qkv_layout = "bshd_bshd_bshd"
# sliding window attention
swa = False
# padding between sequences for qkv_format=thd
@@ -36,12 +36,13 @@ is_training = True
model_configs = {
# test: b, h, hg, d, sq, skv, p, mask, bias
- "test_0": ModelConfig(2, 16, 16, 64, 512, 512, 0.0, "no_mask", "no_bias"), # short seq
- "test_1": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "causal", "no_bias"), # longer seq, mask
- "test_2": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "causal", "post_scale_bias"), # bias
- "test_3": ModelConfig(2, 32, 4, 128, 8192, 8192, 0.0, "causal", "no_bias"), # GQA
+ "test_0": ModelConfig(2, 16, 16, 64, 512, 512, 0.0, "no_mask", "no_bias"), # short seq
+ "test_1": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "causal", "no_bias"), # longer seq, mask
+ "test_2": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "causal", "post_scale_bias"), # bias
+ "test_3": ModelConfig(2, 32, 4, 128, 8192, 8192, 0.0, "causal", "no_bias"), # GQA
}
+
def example_attention(model, fused_attn_supported, flash_attn_supported):
config = model_configs[model]
if dtype == torch.bfloat16:
@@ -51,40 +52,58 @@ def example_attention(model, fused_attn_supported, flash_attn_supported):
if fused_attn_supported:
print()
- print('Run cuDNN attention...')
+ print("Run cuDNN attention...")
fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention(
- dtype, config, "FusedAttention",
- ckpt_attn, qkv_layout, workspace_opt, swa, pad_between_seqs, is_training,
+ dtype,
+ config,
+ "FusedAttention",
+ ckpt_attn,
+ qkv_layout,
+ workspace_opt,
+ swa,
+ pad_between_seqs,
+ is_training,
)
if flash_attn_supported:
print()
- print('Run flash-attention...')
+ print("Run flash-attention...")
flash_attn_fwd, flash_attn_bwd = _run_dot_product_attention(
- dtype, config, "FlashAttention",
- ckpt_attn, qkv_layout, workspace_opt, swa, pad_between_seqs, is_training,
+ dtype,
+ config,
+ "FlashAttention",
+ ckpt_attn,
+ qkv_layout,
+ workspace_opt,
+ swa,
+ pad_between_seqs,
+ is_training,
)
if fused_attn_supported and flash_attn_supported:
torch.testing.assert_close(fused_attn_fwd, flash_attn_fwd, **tols)
- for i,_ in enumerate(flash_attn_bwd):
+ for i, _ in enumerate(flash_attn_bwd):
torch.testing.assert_close(fused_attn_bwd[i], flash_attn_bwd[i], **tols)
print()
- print('Test passed.')
+ print("Test passed.")
+
def main():
- models = ['test_0']
+ models = ["test_0"]
for model in models:
config = model_configs[model]
fused_attn_supported, fused_attn_backend = _is_fused_attention_supported(
- config, dtype, qkv_layout=qkv_layout,
+ config,
+ dtype,
+ qkv_layout=qkv_layout,
)
fused_attn_supported = fused_attn_supported and not swa
flash_attn_supported = _is_flash_attention_supported(config)
example_attention(model, fused_attn_supported, flash_attn_supported)
+
if __name__ == "__main__":
main()
diff --git a/docs/examples/quickstart_utils.py b/docs/examples/quickstart_utils.py
index 9532dd27fabf16b4b0d85de5c4f8ca1454b3ffec..0582efd52e629613acdd89778ccbbb0a1ac6cdd4 100644
--- a/docs/examples/quickstart_utils.py
+++ b/docs/examples/quickstart_utils.py
@@ -8,14 +8,15 @@ import torch
import transformer_engine.pytorch as te
from transformer_engine.pytorch.fp8 import DelayedScaling, dist_group_type
+
def speedometer(
- module: torch.nn.Module,
- input: torch.Tensor,
- output_grad: torch.Tensor,
- forward_kwargs: dict = {},
- fp8_autocast_kwargs: Optional[dict] = None,
- timing_iters: int = 50,
- warmup_iters: int = 50,
+ module: torch.nn.Module,
+ input: torch.Tensor,
+ output_grad: torch.Tensor,
+ forward_kwargs: dict = {},
+ fp8_autocast_kwargs: Optional[dict] = None,
+ timing_iters: int = 50,
+ warmup_iters: int = 50,
) -> None:
"""Measure average run time for a PyTorch module
@@ -24,7 +25,7 @@ def speedometer(
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
if fp8_autocast_kwargs is None:
- fp8_autocast_kwargs = { "enabled": False }
+ fp8_autocast_kwargs = {"enabled": False}
# Warmup runs
torch.cuda.synchronize()
@@ -51,11 +52,12 @@ class DotProductAttention(torch.nn.Module):
Built with plain PyTorch modules.
"""
+
def __init__(
- self,
- num_attention_heads: int,
- kv_channels: int,
- attention_dropout: float,
+ self,
+ num_attention_heads: int,
+ kv_channels: int,
+ attention_dropout: float,
) -> None:
super().__init__()
self.projection_size = kv_channels * num_attention_heads
@@ -63,21 +65,17 @@ class DotProductAttention(torch.nn.Module):
self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
self.dropout = torch.nn.Dropout(attention_dropout)
- def masked_softmax(
- self,
- inp: torch.Tensor,
- mask: Optional[torch.Tensor]
- ) -> torch.Tensor:
+ def masked_softmax(self, inp: torch.Tensor, mask: Optional[torch.Tensor]) -> torch.Tensor:
if mask is not None:
inp.masked_fill_(mask, -10000.0)
return torch.nn.Softmax(dim=-1)(inp)
def forward(
- self,
- query: torch.Tensor,
- key: torch.Tensor,
- value: torch.Tensor,
- attention_mask: Optional[torch.Tensor] = None,
+ self,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
b = query.size(1)
np = query.size(2)
@@ -90,7 +88,9 @@ class DotProductAttention(torch.nn.Module):
# [sk, b, np, hn] -> [sk, b * np, hn]
key = key.view(sk, b * np, -1)
- bmm1 = torch.bmm(query.transpose(0, 1), key.transpose(0, 1).transpose(1, 2)) / self.norm_factor
+ bmm1 = (
+ torch.bmm(query.transpose(0, 1), key.transpose(0, 1).transpose(1, 2)) / self.norm_factor
+ )
# change view to [b, np, sq, sk]
attention_scores = bmm1.view(b, np, sq, sk)
@@ -126,10 +126,11 @@ class BasicMLP(torch.nn.Module):
Built with plain PyTorch modules.
"""
+
def __init__(
- self,
- hidden_size: int,
- ffn_hidden_size: int,
+ self,
+ hidden_size: int,
+ ffn_hidden_size: int,
) -> None:
super().__init__()
self.linear1 = torch.nn.Linear(hidden_size, ffn_hidden_size, bias=True)
@@ -137,7 +138,7 @@ class BasicMLP(torch.nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.linear1(x)
- x = torch.nn.functional.gelu(x, approximate='tanh')
+ x = torch.nn.functional.gelu(x, approximate="tanh")
x = self.linear2(x)
return x
@@ -148,7 +149,7 @@ def share_parameters_with_basic_te_model(te_model, basic_model):
Parameter values are copied from pure PyTorch implementation.
"""
- te_model.ln1.weight= basic_model.ln1.weight
+ te_model.ln1.weight = basic_model.ln1.weight
te_model.ln1.bias = basic_model.ln1.bias
te_model.qkv_projection.weight = basic_model.qkv_projection.weight
te_model.qkv_projection.bias = basic_model.qkv_projection.bias
@@ -202,14 +203,15 @@ def share_parameters_with_transformerlayer_te_model(te_model, basic_model):
te_model.layernorm_mlp.fc2_bias = basic_model.mlp.linear2.bias
-def cast_to_representable(inp, scale = 1., fp8_format='e4m3'):
+def cast_to_representable(inp, scale=1.0, fp8_format="e4m3"):
import transformer_engine.pytorch.cpp_extensions as texcpp
import transformer_engine_torch as tex
from transformer_engine.pytorch.constants import TE_DType
- fp8_type = tex.DType.kFloat8E4M3 if fp8_format == 'e4m3' else tex.DType.kFloat8E5M2
+
+ fp8_type = tex.DType.kFloat8E4M3 if fp8_format == "e4m3" else tex.DType.kFloat8E5M2
input_type = TE_DType[inp.dtype]
meta = tex.FP8TensorMeta()
- meta.scale = torch.ones(1,dtype=torch.float32, device="cuda") * scale
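+    # scale and scale_inv are reciprocals, so the FP8 cast and its dequantization use a consistent scale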
+ meta.scale = torch.ones(1, dtype=torch.float32, device="cuda") * scale
meta.scale_inv = torch.ones(1, dtype=torch.float32, device="cuda") / scale
meta.amax_history = torch.zeros(1, 1, dtype=torch.float32, device="cuda")
ret = texcpp.cast_to_fp8(inp, meta, tex.FP8FwdTensors.GEMM1_INPUT, fp8_type)
diff --git a/docs/examples/te_llama/te_llama.py b/docs/examples/te_llama/te_llama.py
index 307507ad1d341b46bcec2a078d882d338e748a5c..cb384aa10c41cd747e298e5b16f1d51b788516bb 100644
--- a/docs/examples/te_llama/te_llama.py
+++ b/docs/examples/te_llama/te_llama.py
@@ -15,11 +15,17 @@ from transformer_engine.pytorch.attention import RotaryPositionEmbedding
from transformer_engine.pytorch.fp8 import fp8_model_init
import transformers
-from transformers.models.llama.modeling_llama import LlamaModel, LlamaForCausalLM, LlamaRMSNorm, LlamaConfig
+from transformers.models.llama.modeling_llama import (
+ LlamaModel,
+ LlamaForCausalLM,
+ LlamaRMSNorm,
+ LlamaConfig,
+)
from transformers.modeling_utils import _add_variant, load_state_dict, _load_state_dict_into_model
from transformers.utils import WEIGHTS_INDEX_NAME
from transformers.utils.hub import get_checkpoint_shard_files
+
@contextmanager
def replace_decoder(te_decoder_cls):
"""
@@ -43,6 +49,7 @@ class TELlamaDecoderLayer(te.pytorch.TransformerLayer):
args: positional args (for compatibility with `LlamaDecoderLayer`)
kwargs: keyword args (for compatibility with `LlamaDecoderLayer`)
"""
+
def __init__(self, config, *args, **kwargs):
super().__init__(
hidden_size=config.hidden_size,
@@ -56,22 +63,22 @@ class TELlamaDecoderLayer(te.pytorch.TransformerLayer):
normalization="RMSNorm",
activation="swiglu",
attn_input_format="bshd",
- num_gqa_groups=config.num_key_value_heads
+ num_gqa_groups=config.num_key_value_heads,
)
- te_rope = RotaryPositionEmbedding(config.hidden_size//config.num_attention_heads)
+ te_rope = RotaryPositionEmbedding(config.hidden_size // config.num_attention_heads)
self.te_rope_emb = te_rope(max_seq_len=config.max_position_embeddings).cuda()
- def forward(self,
- hidden_states,
- *args,
- attention_mask,
- **kwargs):
+ def forward(self, hidden_states, *args, attention_mask, **kwargs):
"""
Custom forward to make sure we only pass relevant arguments to the
forward pass of the `TransformerLayer`. Also, make sure the output
format matches the output of the HF's `LlamaDecoderLayer`.
"""
- return (super().forward(hidden_states, attention_mask=attention_mask, rotary_pos_emb=self.te_rope_emb),)
+ return (
+ super().forward(
+ hidden_states, attention_mask=attention_mask, rotary_pos_emb=self.te_rope_emb
+ ),
+ )
class TELlamaForCausalLM:
@@ -95,21 +102,29 @@ class TELlamaForCausalLM:
Custom method adapted from `from_pretrained` method in HuggingFace
Transformers repo: https://github.com/huggingface/transformers/blob/f497f564bb76697edab09184a252fc1b1a326d1e/src/transformers/modeling_utils.py#L2579
"""
- vanilla_model = cls(config).to(kwargs['torch_dtype'])
+ vanilla_model = cls(config).to(kwargs["torch_dtype"])
is_local = os.path.isdir(pretrained_model_name_or_path)
subfolder = ""
variant = None
if os.path.isfile(
- os.path.join(pretrained_model_name_or_path, subfolder, _add_variant("model.safetensors.index.json", variant))
- ):
+ os.path.join(
+ pretrained_model_name_or_path,
+ subfolder,
+ _add_variant("model.safetensors.index.json", variant),
+ )
+ ):
# Load from a sharded PyTorch checkpoint
archive_file = os.path.join(
- pretrained_model_name_or_path, subfolder, _add_variant("model.safetensors.index.json", variant)
+ pretrained_model_name_or_path,
+ subfolder,
+ _add_variant("model.safetensors.index.json", variant),
)
is_sharded = True
elif os.path.isfile(
- os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_INDEX_NAME, variant))
- ):
+ os.path.join(
+ pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_INDEX_NAME, variant)
+ )
+ ):
# Load from a sharded PyTorch checkpoint
archive_file = os.path.join(
pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_INDEX_NAME, variant)
@@ -118,10 +133,9 @@ class TELlamaForCausalLM:
else:
raise AssertionError("Only sharded PyTorch ckpt format supported at the moment")
-
resolved_archive_file, sharded_metadata = get_checkpoint_shard_files(
- pretrained_model_name_or_path,
- archive_file,
+ pretrained_model_name_or_path,
+ archive_file,
)
# If the checkpoint is not sharded, it's a trivial sharding case
@@ -142,48 +156,63 @@ class TELlamaForCausalLM:
return vanilla_model
+
def replace_params(hf_state_dict, te_state_dict, config):
# collect all layer prefixes to update
all_layer_prefixes = set()
for param_key in hf_state_dict.keys():
- layer_prefix_pat = 'model.layers.\d+.'
+ layer_prefix_pat = "model.layers.\d+."
m = re.match(layer_prefix_pat, param_key)
if m is not None:
all_layer_prefixes.add(m.group())
-
-
for layer_prefix in all_layer_prefixes:
        # When loading weights into models with fewer layers, skip the copy
        # if the corresponding layer doesn't exist in the HF model
-        if layer_prefix + 'input_layernorm.weight' in hf_state_dict:
-            te_state_dict[layer_prefix + 'self_attention.layernorm_qkv.layer_norm_weight'].data[:] = hf_state_dict[layer_prefix + 'input_layernorm.weight'].data[:]
-
-        if layer_prefix + 'self_attn.q_proj.weight' in hf_state_dict:
-            te_state_dict[layer_prefix + 'self_attention.layernorm_qkv.query_weight'].data[:] = hf_state_dict[layer_prefix + 'self_attn.q_proj.weight'].data[:]
-
-        if layer_prefix + 'self_attn.k_proj.weight' in hf_state_dict:
-            te_state_dict[layer_prefix + 'self_attention.layernorm_qkv.key_weight'].data[:] = hf_state_dict[layer_prefix + 'self_attn.k_proj.weight'].data[:]
-
-        if layer_prefix + 'self_attn.v_proj.weight' in hf_state_dict:
-            te_state_dict[layer_prefix + 'self_attention.layernorm_qkv.value_weight'].data[:] = hf_state_dict[layer_prefix + 'self_attn.v_proj.weight'].data[:]
-
-        if layer_prefix + 'self_attn.o_proj.weight' in hf_state_dict:
-            te_state_dict[layer_prefix + 'self_attention.proj.weight'].data[:] = hf_state_dict[layer_prefix + 'self_attn.o_proj.weight'].data[:]
-
-        if layer_prefix + 'post_attention_layernorm.weight' in hf_state_dict:
-            te_state_dict[layer_prefix + 'layernorm_mlp.layer_norm_weight'].data[:] = hf_state_dict[layer_prefix + 'post_attention_layernorm.weight'].data[:]
-
+        if layer_prefix + "input_layernorm.weight" in hf_state_dict:
+            te_state_dict[layer_prefix + "self_attention.layernorm_qkv.layer_norm_weight"].data[
+                :
+            ] = hf_state_dict[layer_prefix + "input_layernorm.weight"].data[:]
+
+        if layer_prefix + "self_attn.q_proj.weight" in hf_state_dict:
+            te_state_dict[layer_prefix + "self_attention.layernorm_qkv.query_weight"].data[:] = (
+                hf_state_dict[layer_prefix + "self_attn.q_proj.weight"].data[:]
+            )
+
+        if layer_prefix + "self_attn.k_proj.weight" in hf_state_dict:
+            te_state_dict[layer_prefix + "self_attention.layernorm_qkv.key_weight"].data[:] = (
+                hf_state_dict[layer_prefix + "self_attn.k_proj.weight"].data[:]
+            )
+
+        if layer_prefix + "self_attn.v_proj.weight" in hf_state_dict:
+            te_state_dict[layer_prefix + "self_attention.layernorm_qkv.value_weight"].data[:] = (
+                hf_state_dict[layer_prefix + "self_attn.v_proj.weight"].data[:]
+            )
+
+        if layer_prefix + "self_attn.o_proj.weight" in hf_state_dict:
+            te_state_dict[layer_prefix + "self_attention.proj.weight"].data[:] = hf_state_dict[
+                layer_prefix + "self_attn.o_proj.weight"
+            ].data[:]
+
+        if layer_prefix + "post_attention_layernorm.weight" in hf_state_dict:
+            te_state_dict[layer_prefix + "layernorm_mlp.layer_norm_weight"].data[:] = hf_state_dict[
+                layer_prefix + "post_attention_layernorm.weight"
+            ].data[:]
+
        # It may happen that gate_proj.weight and up_proj.weight are in different files, so we need to
# load them separately.
- if layer_prefix + 'mlp.gate_proj.weight' in hf_state_dict:
- te_state_dict[layer_prefix + 'layernorm_mlp.fc1_weight'].data[:config.intermediate_size] = \
- hf_state_dict[layer_prefix + 'mlp.gate_proj.weight'].data
-
- if layer_prefix + 'mlp.up_proj.weight' in hf_state_dict:
- te_state_dict[layer_prefix + 'layernorm_mlp.fc1_weight'].data[config.intermediate_size:] = \
- hf_state_dict[layer_prefix + 'mlp.up_proj.weight'].data
-
- if layer_prefix + 'mlp.down_proj.weight' in hf_state_dict:
- te_state_dict[layer_prefix + 'layernorm_mlp.fc2_weight'].data[:] = hf_state_dict[layer_prefix + 'mlp.down_proj.weight'].data[:]
- return all_layer_prefixes
\ No newline at end of file
+ if layer_prefix + "mlp.gate_proj.weight" in hf_state_dict:
+ te_state_dict[layer_prefix + "layernorm_mlp.fc1_weight"].data[
+ : config.intermediate_size
+ ] = hf_state_dict[layer_prefix + "mlp.gate_proj.weight"].data
+
+ if layer_prefix + "mlp.up_proj.weight" in hf_state_dict:
+ te_state_dict[layer_prefix + "layernorm_mlp.fc1_weight"].data[
+ config.intermediate_size :
+ ] = hf_state_dict[layer_prefix + "mlp.up_proj.weight"].data
+
+ if layer_prefix + "mlp.down_proj.weight" in hf_state_dict:
+ te_state_dict[layer_prefix + "layernorm_mlp.fc2_weight"].data[:] = hf_state_dict[
+ layer_prefix + "mlp.down_proj.weight"
+ ].data[:]
+ return all_layer_prefixes
diff --git a/docs/examples/te_llama/utils.py b/docs/examples/te_llama/utils.py
index 71d2aa2e2ef4deedfcb7d2bcc797283ffd80887b..b6b3683d4c397008ad7273818fe62af7c4df153a 100644
--- a/docs/examples/te_llama/utils.py
+++ b/docs/examples/te_llama/utils.py
@@ -10,29 +10,36 @@ import torch
from torch.optim import AdamW
from torch.utils.data import DataLoader
-from transformers import AutoModelForCausalLM, AutoTokenizer, get_linear_schedule_with_warmup, AutoConfig
+from transformers import (
+ AutoModelForCausalLM,
+ AutoTokenizer,
+ get_linear_schedule_with_warmup,
+ AutoConfig,
+)
from transformers import DataCollatorForLanguageModeling
from datasets import load_dataset
from accelerate import Accelerator
from accelerate.utils.dataclasses import FP8RecipeKwargs
+
class HyperParameters:
def __init__(self):
self.mixed_precision = "bf16"
- #self.model_name = "" # <== Add model weight location here
+ # self.model_name = "" # <== Add model weight location here
self.dataset_name = "timdettmers/openassistant-guanaco"
self.dataset_text_field = "text"
self.learning_rate = 1.41e-5
self.batch_size = 8
self.max_seq_length = 256
self.gradient_accumulation_steps = 1
- self.num_warmup_steps=5
- self.num_training_steps=10
-
+ self.num_warmup_steps = 5
+ self.num_training_steps = 10
+
hyperparams = HyperParameters()
-def get_dataloaders(accelerator:Accelerator, hyperparams):
+
+def get_dataloaders(accelerator: Accelerator, hyperparams):
dataset = load_dataset(hyperparams.dataset_name, split="train")
tokenizer = AutoTokenizer.from_pretrained(hyperparams.model_name)
if getattr(tokenizer, "pad_token", None) is None:
@@ -45,16 +52,12 @@ def get_dataloaders(accelerator:Accelerator, hyperparams):
padding=False,
max_length=hyperparams.max_seq_length,
return_overflowing_tokens=False,
- return_length=False
+ return_length=False,
)
return {"input_ids": outputs["input_ids"], "attention_mask": outputs["attention_mask"]}
with accelerator.main_process_first():
- dataset = dataset.map(
- tokenize,
- batched=True,
- remove_columns=dataset.column_names
- )
+ dataset = dataset.map(tokenize, batched=True, remove_columns=dataset.column_names)
    # Simply pad to a multiple of 16 for both FP8 and BF16 precision
pad_to_multiple_of = 16
@@ -72,6 +75,7 @@ def get_dataloaders(accelerator:Accelerator, hyperparams):
train_dataloader = DataLoader(dataset, **dataloader_params)
return train_dataloader
+
def init_baseline_model(hyperparams):
# Init the model
config = AutoConfig.from_pretrained(hyperparams.model_name)
@@ -84,42 +88,47 @@ def init_baseline_model(hyperparams):
)
model = model.cuda()
    # Needed when using TELlamaForCausalLM; added here for a 1:1 comparison
- model.config.use_cache=False
+ model.config.use_cache = False
return model
+
def init_te_llama_model(hyperparams):
# Init the model
from te_llama import TELlamaForCausalLM
+
config = AutoConfig.from_pretrained(hyperparams.model_name)
config._attn_implementation = "flash_attention_2"
model = TELlamaForCausalLM.from_pretrained_local(
- hyperparams.model_name,
- config=config,
- torch_dtype=torch.bfloat16,
+ hyperparams.model_name,
+ config=config,
+ torch_dtype=torch.bfloat16,
)
model = model.cuda()
    # Needed when using TELlamaForCausalLM
- model.config.use_cache=False
+ model.config.use_cache = False
return model
+
def wrap_with_accelerator(model, hyperparams):
# Create FP8 kwarg handler if required
- fp8_kwarg_handler = [FP8RecipeKwargs(backend="te")] if hyperparams.mixed_precision == "fp8" else None
+ fp8_kwarg_handler = (
+ [FP8RecipeKwargs(backend="te")] if hyperparams.mixed_precision == "fp8" else None
+ )
# Init HF accelerator that's used for training
accelerator = Accelerator(
log_with="wandb",
gradient_accumulation_steps=hyperparams.gradient_accumulation_steps,
mixed_precision=hyperparams.mixed_precision,
- kwargs_handlers=fp8_kwarg_handler
+ kwargs_handlers=fp8_kwarg_handler,
)
- #accelerator.print(f'State: {accelerator.state}')
+ # accelerator.print(f'State: {accelerator.state}')
train_dataloader = get_dataloaders(accelerator, hyperparams)
# Wrap model, optimizer/scheduler, dataloaders in accelerate
- optimizer = AdamW(params = model.parameters(), lr=hyperparams.learning_rate, fused=True)
+ optimizer = AdamW(params=model.parameters(), lr=hyperparams.learning_rate, fused=True)
lr_scheduler = get_linear_schedule_with_warmup(
optimizer=optimizer,
num_warmup_steps=100,
@@ -131,6 +140,7 @@ def wrap_with_accelerator(model, hyperparams):
return accelerator, model, optimizer, train_dataloader, lr_scheduler
+
def finetune_model(model, hyperparams, accelerator, train_dataloader, optimizer, lr_scheduler):
model.train()
total_loss = 0
@@ -170,7 +180,11 @@ def finetune_model(model, hyperparams, accelerator, train_dataloader, optimizer,
end.record()
accelerator.end_training()
- print(f"{hyperparams.num_training_steps} finetuning steps complete!\nAverage time taken per step: {(start.elapsed_time(end)/hyperparams.num_training_steps):.0f} milliseconds")
+ print(
+ f"{hyperparams.num_training_steps} finetuning steps complete!\nAverage time taken per step:"
+ f" {(start.elapsed_time(end)/hyperparams.num_training_steps):.0f} milliseconds"
+ )
+
def restart_jupyter_notebook():
# Try restarting the Jupyter kernel
@@ -179,18 +193,23 @@ def restart_jupyter_notebook():
# Check whether the device memory has been flushed
if torch.cuda.memory_allocated() != 0:
import warnings
+
warnings.warn("The device memory hasn't been flushed, trying with a second method!")
# Try restarting the Jupyter kernel another way
# Restart the kernel
from IPython.core.display import HTML
+
HTML("")
if torch.cuda.memory_allocated() != 0:
- print("The device memory hasn't been flushed, try manually restarting the Jupyter kernel!")
+ print(
+ "The device memory hasn't been flushed, try manually restarting the Jupyter kernel!"
+ )
# Suppress the warnings
if not sys.warnoptions:
import warnings
+
warnings.simplefilter("ignore")
torch.set_warn_always(False)
diff --git a/examples/jax/encoder/test_model_parallel_encoder.py b/examples/jax/encoder/test_model_parallel_encoder.py
index b02b3eddd90b2eb2de4ea40d8d46e7ec68a4172a..716d543d5b9e7f68d3cc3f0745376d24545d05b1 100644
--- a/examples/jax/encoder/test_model_parallel_encoder.py
+++ b/examples/jax/encoder/test_model_parallel_encoder.py
@@ -22,18 +22,19 @@ from jax.experimental.pjit import pjit
import transformer_engine.jax as te
import transformer_engine.jax.flax as te_flax
-DEVICE_DP_AXIS = 'data'
-DEVICE_TP_AXIS = 'model'
-NAMED_BROADCAST_AXIS = 'my_broadcast_axis'
-NAMED_TP_AXIS = 'my_tp_axis'
-PARAMS_KEY = 'params'
-PARAMS_AXES_KEY = PARAMS_KEY + '_axes'
-DROPOUT_KEY = 'dropout'
-INPUT_KEY = 'input_rng'
+DEVICE_DP_AXIS = "data"
+DEVICE_TP_AXIS = "model"
+NAMED_BROADCAST_AXIS = "my_broadcast_axis"
+NAMED_TP_AXIS = "my_tp_axis"
+PARAMS_KEY = "params"
+PARAMS_AXES_KEY = PARAMS_KEY + "_axes"
+DROPOUT_KEY = "dropout"
+INPUT_KEY = "input_rng"
class Net(nn.Module):
"""NLP Encoder"""
+
num_embed: int
enable_seq_paral: bool
@@ -41,36 +42,43 @@ class Net(nn.Module):
def __call__(self, x, mask, disable_dropout=False):
x = nn.Embed(num_embeddings=self.num_embed, features=256, dtype=jnp.bfloat16)(x)
- te_Encoder = partial(te_flax.TransformerLayer,
- hidden_size=256,
- mlp_hidden_size=1024,
- num_attention_heads=8,
- hidden_dropout=0.1,
- attention_dropout=0.1,
- dropout_rng_name=DROPOUT_KEY,
- layer_type=te_flax.TransformerLayerType.ENCODER,
- self_attn_mask_type='padding',
- enable_relative_embedding=False,
- enable_sequence_parallel=self.enable_seq_paral,
- dtype=jnp.bfloat16)
+ te_Encoder = partial(
+ te_flax.TransformerLayer,
+ hidden_size=256,
+ mlp_hidden_size=1024,
+ num_attention_heads=8,
+ hidden_dropout=0.1,
+ attention_dropout=0.1,
+ dropout_rng_name=DROPOUT_KEY,
+ layer_type=te_flax.TransformerLayerType.ENCODER,
+ self_attn_mask_type="padding",
+ enable_relative_embedding=False,
+ enable_sequence_parallel=self.enable_seq_paral,
+ dtype=jnp.bfloat16,
+ )
x = te_Encoder()(x, attention_mask=mask, deterministic=disable_dropout)
x = x.reshape(x.shape[0], -1)
if self.enable_seq_paral:
            # Trigger all-gather to collect a complete tensor along the sequence on each device.
- x = jax.lax.with_sharding_constraint(x,
- jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None))
-
- x = te_flax.DenseGeneral(features=256,
- kernel_axes=(NAMED_BROADCAST_AXIS, NAMED_TP_AXIS),
- bias_axes=(NAMED_TP_AXIS,),
- dtype=jnp.bfloat16)(x)
-
- x = te_flax.DenseGeneral(features=256,
- kernel_axes=(NAMED_TP_AXIS, NAMED_BROADCAST_AXIS),
- bias_axes=(NAMED_BROADCAST_AXIS,),
- dtype=jnp.bfloat16)(x)
+ x = jax.lax.with_sharding_constraint(
+ x, jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None)
+ )
+
+ x = te_flax.DenseGeneral(
+ features=256,
+ kernel_axes=(NAMED_BROADCAST_AXIS, NAMED_TP_AXIS),
+ bias_axes=(NAMED_TP_AXIS,),
+ dtype=jnp.bfloat16,
+ )(x)
+
+ x = te_flax.DenseGeneral(
+ features=256,
+ kernel_axes=(NAMED_TP_AXIS, NAMED_BROADCAST_AXIS),
+ bias_axes=(NAMED_BROADCAST_AXIS,),
+ dtype=jnp.bfloat16,
+ )(x)
x = nn.Dense(features=2, dtype=jnp.bfloat16)(x)
return x
@@ -98,20 +106,21 @@ def train_step(state, inputs, masks, labels, var_collect, rngs):
def train_epoch(state, train_ds, batch_size, rngs, var_collect, train_fn):
"""Train for a single epoch."""
- train_ds_size = len(train_ds['sentence'])
+ train_ds_size = len(train_ds["sentence"])
steps_per_epoch = train_ds_size // batch_size
perms = jax.random.permutation(rngs[INPUT_KEY], train_ds_size)
- perms = perms[:steps_per_epoch * batch_size] # skip incomplete batch
+ perms = perms[: steps_per_epoch * batch_size] # skip incomplete batch
perms = perms.reshape((steps_per_epoch, batch_size))
epoch_loss = []
epoch_accuracy = []
for perm in perms:
- batch_inputs = train_ds['sentence'][perm, ...]
- batch_masks = train_ds['mask'][perm, ...]
- batch_labels = train_ds['label'][perm, ...]
- state, loss, accuracy, var_collect = train_fn(state, batch_inputs, batch_masks,
- batch_labels, var_collect, rngs)
+ batch_inputs = train_ds["sentence"][perm, ...]
+ batch_masks = train_ds["mask"][perm, ...]
+ batch_labels = train_ds["label"][perm, ...]
+ state, loss, accuracy, var_collect = train_fn(
+ state, batch_inputs, batch_masks, batch_labels, var_collect, rngs
+ )
epoch_loss.append(loss)
epoch_accuracy.append(accuracy)
@@ -137,7 +146,7 @@ def eval_step(state, inputs, masks, labels, var_collect):
def eval_model(state, test_ds, batch_size, var_collect, eval_fn):
"""Evaluation loop."""
- test_ds_size = len(test_ds['sentence'])
+ test_ds_size = len(test_ds["sentence"])
num_steps = test_ds_size // batch_size
valid_size = num_steps * batch_size
all_loss = []
@@ -145,9 +154,9 @@ def eval_model(state, test_ds, batch_size, var_collect, eval_fn):
for batch_start in range(0, valid_size, batch_size):
batch_end = batch_start + batch_size
- batch_inputs = test_ds['sentence'][batch_start:batch_end]
- batch_masks = test_ds['mask'][batch_start:batch_end]
- batch_labels = test_ds['label'][batch_start:batch_end]
+ batch_inputs = test_ds["sentence"][batch_start:batch_end]
+ batch_masks = test_ds["mask"][batch_start:batch_end]
+ batch_labels = test_ds["label"][batch_start:batch_end]
loss, accuracy = eval_fn(state, batch_inputs, batch_masks, batch_labels, var_collect)
all_loss.append(loss)
all_accuracy.append(accuracy)
@@ -159,12 +168,12 @@ def eval_model(state, test_ds, batch_size, var_collect, eval_fn):
def data_preprocess(dataset, vocab, word_id, max_seq_len):
"""Convert tokens to numbers."""
- nltk.download('punkt')
- dataset_size = len(dataset['sentence'])
+ nltk.download("punkt")
+ dataset_size = len(dataset["sentence"])
output = np.zeros((dataset_size, max_seq_len), dtype=np.int32)
mask_3d = np.ones((dataset_size, max_seq_len, max_seq_len), dtype=np.uint8)
- for j, sentence in enumerate(dataset['sentence']):
+ for j, sentence in enumerate(dataset["sentence"]):
tokens = nltk.word_tokenize(sentence)
tensor = output[j]
@@ -184,9 +193,9 @@ def data_preprocess(dataset, vocab, word_id, max_seq_len):
mask_2d[:seq_len, :seq_len] = 0
new_dataset = {
- 'sentence': output,
- 'label': dataset['label'].astype(np.float32),
- 'mask': mask_3d.reshape((dataset_size, 1, max_seq_len, max_seq_len))
+ "sentence": output,
+ "label": dataset["label"].astype(np.float32),
+ "mask": mask_3d.reshape((dataset_size, 1, max_seq_len, max_seq_len)),
}
return new_dataset, vocab, word_id
@@ -196,12 +205,12 @@ def get_datasets(max_seq_len):
vocab = {}
word_id = 0
- train_ds = load_dataset('glue', 'cola', split='train')
- train_ds.set_format(type='np')
+ train_ds = load_dataset("glue", "cola", split="train")
+ train_ds.set_format(type="np")
train_ds, vocab, word_id = data_preprocess(train_ds, vocab, word_id, max_seq_len)
- test_ds = load_dataset('glue', 'cola', split='validation')
- test_ds.set_format(type='np')
+ test_ds = load_dataset("glue", "cola", split="validation")
+ test_ds.set_format(type="np")
test_ds, vocab, word_id = data_preprocess(test_ds, vocab, word_id, max_seq_len)
return train_ds, test_ds, word_id
@@ -210,7 +219,8 @@ def check_fp8(state, var_collect, inputs, masks, labels):
"Check if model includes FP8."
rngs = {DROPOUT_KEY: jax.random.PRNGKey(0)}
assert "fp8_" in str(
- jax.make_jaxpr(train_step)(state, inputs, masks, labels, var_collect, rngs))
+ jax.make_jaxpr(train_step)(state, inputs, masks, labels, var_collect, rngs)
+ )
def get_params_pspec(sharding_rules, abs_var_collect):
@@ -255,8 +265,9 @@ def train_and_evaluate(args):
num_gpu_tp = 1
assert args.batch_size % num_gpu_dp == 0, f"Batch size needs to be multiple of {num_gpu_dp}"
- assert args.test_batch_size % num_gpu_dp == 0, \
- f"Test batch size needs to be multiple of {num_gpu_dp}"
+ assert (
+ args.test_batch_size % num_gpu_dp == 0
+ ), f"Test batch size needs to be multiple of {num_gpu_dp}"
device_mesh = mesh_utils.create_device_mesh((num_gpu_dp, num_gpu_tp))
with jax.sharding.Mesh(devices=device_mesh, axis_names=(DEVICE_DP_AXIS, DEVICE_TP_AXIS)):
@@ -270,9 +281,9 @@ def train_and_evaluate(args):
mask_shape = [args.batch_size, 1, args.max_seq_len, args.max_seq_len]
label_shape = [args.batch_size]
- with te.fp8_autocast(args.use_fp8,
- mesh_resource=te.MeshResource(DEVICE_DP_AXIS, DEVICE_TP_AXIS, None,
- None)):
+ with te.fp8_autocast(
+ args.use_fp8, mesh_resource=te.MeshResource(DEVICE_DP_AXIS, DEVICE_TP_AXIS, None, None)
+ ):
encoder = Net(num_embed, args.enable_sp)
inputs = jnp.zeros(input_shape, dtype=jnp.int32)
masks = jnp.zeros(mask_shape, dtype=jnp.uint8)
@@ -285,18 +296,21 @@ def train_and_evaluate(args):
masks_pspec = jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None, None, None)
in_shardings = (None, inputs_pspec, masks_pspec)
- out_shardings = {key: params_pspec if key is PARAMS_KEY else None \
- for key in abs_var_collect}
+ out_shardings = {
+ key: params_pspec if key is PARAMS_KEY else None for key in abs_var_collect
+ }
pjit_encoder_init = pjit(encoder.init, in_shardings, out_shardings)
var_collect = pjit_encoder_init(init_rngs, inputs, masks)
optimizer = optax.adamw(args.lr)
var_collect, params = flax.core.pop(var_collect, PARAMS_KEY)
- state = train_state.TrainState.create(apply_fn=encoder.apply,
- params=params,
- tx=optimizer)
+ state = train_state.TrainState.create(
+ apply_fn=encoder.apply, params=params, tx=optimizer
+ )
state_pspec = get_state_pspec(state, params_pspec)
- labels_pspec = jax.sharding.PartitionSpec(DEVICE_DP_AXIS,)
+ labels_pspec = jax.sharding.PartitionSpec(
+ DEVICE_DP_AXIS,
+ )
in_shardings = (state_pspec, inputs_pspec, masks_pspec, labels_pspec, None, None)
out_shardings = (state_pspec, None, None, None)
@@ -323,16 +337,20 @@ def train_and_evaluate(args):
rngs = {INPUT_KEY: input_rng, DROPOUT_KEY: dropout_rng}
state, train_loss, train_accuracy, var_collect = train_epoch(
- state, train_ds, args.batch_size, rngs, var_collect, pjit_train_step)
+ state, train_ds, args.batch_size, rngs, var_collect, pjit_train_step
+ )
- test_loss, test_accuracy = eval_model(state, test_ds, args.test_batch_size,
- var_collect, pjit_eval_step)
+ test_loss, test_accuracy = eval_model(
+ state, test_ds, args.test_batch_size, var_collect, pjit_eval_step
+ )
- print(f"Epoch: {epoch:>2} "
- f"Train Loss: {train_loss:.6f} "
- f"Train Accuracy: {train_accuracy:.6f} "
- f"Test Loss: {test_loss:.6f} "
- f"Test Accuracy: {test_accuracy:.6f} ")
+ print(
+ f"Epoch: {epoch:>2} "
+ f"Train Loss: {train_loss:.6f} "
+ f"Train Accuracy: {train_accuracy:.6f} "
+ f"Test Loss: {test_loss:.6f} "
+ f"Test Accuracy: {test_accuracy:.6f} "
+ )
return [train_loss, train_accuracy, test_loss, test_accuracy]
@@ -382,14 +400,15 @@ def encoder_parser(args):
help="quickly check a single pass",
)
parser.add_argument("--seed", type=int, default=0, metavar="S", help="random seed (default: 0)")
- parser.add_argument("--use-fp8",
- action="store_true",
- default=False,
- help="Use FP8 for inference and training without recalibration")
- parser.add_argument("--enable-sp",
- action="store_true",
- default=False,
- help="Enable sequence parallelism.")
+ parser.add_argument(
+ "--use-fp8",
+ action="store_true",
+ default=False,
+ help="Use FP8 for inference and training without recalibration",
+ )
+ parser.add_argument(
+ "--enable-sp", action="store_true", default=False, help="Enable sequence parallelism."
+ )
return parser.parse_args(args)
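
Note on the file above: the model-parallel encoder builds a 2-D device mesh whose "data" axis carries data parallelism and whose "model" axis carries tensor parallelism, and every PartitionSpec names one mesh axis per tensor dimension. The following is a minimal, self-contained sketch of that pattern, not part of the patch; the toy array stands in for activations, and the "model" axis is kept at size 1 so the sketch runs on any device count.

import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils

# All visible devices along "data"; trivial "model" axis so this runs anywhere.
devices = mesh_utils.create_device_mesh((jax.device_count(), 1))
mesh = jax.sharding.Mesh(devices, axis_names=("data", "model"))

# Shard the batch dimension over "data" and replicate features,
# the same pattern as inputs_pspec in the example.
pspec = jax.sharding.PartitionSpec("data", None)
sharding = jax.sharding.NamedSharding(mesh, pspec)

# Batch scales with device count so the shard sizes always divide evenly.
x = jax.device_put(jnp.zeros((8 * jax.device_count(), 256), dtype=jnp.bfloat16), sharding)
print(x.sharding)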
diff --git a/examples/jax/encoder/test_multigpu_encoder.py b/examples/jax/encoder/test_multigpu_encoder.py
index f7c1779ca15f85885f12e8f201f0c9a76daa3490..c6223ed5bb782d4c7327d840b8b30d290d13ed0d 100644
--- a/examples/jax/encoder/test_multigpu_encoder.py
+++ b/examples/jax/encoder/test_multigpu_encoder.py
@@ -22,32 +22,35 @@ from jax.experimental.pjit import pjit
import transformer_engine.jax as te
import transformer_engine.jax.flax as te_flax
-DEVICE_DP_AXIS = 'data'
-PARAMS_KEY = 'params'
-PARAMS_AXES_KEY = PARAMS_KEY + '_axes'
-DROPOUT_KEY = 'dropout'
-INPUT_KEY = 'input_rng'
+DEVICE_DP_AXIS = "data"
+PARAMS_KEY = "params"
+PARAMS_AXES_KEY = PARAMS_KEY + "_axes"
+DROPOUT_KEY = "dropout"
+INPUT_KEY = "input_rng"
class Net(nn.Module):
"""NLP Encoder"""
+
num_embed: int
@nn.compact
def __call__(self, x, mask, disable_dropout=False):
x = nn.Embed(num_embeddings=self.num_embed, features=256, dtype=jnp.bfloat16)(x)
- te_Encoder = partial(te_flax.TransformerLayer,
- hidden_size=256,
- mlp_hidden_size=1024,
- num_attention_heads=8,
- hidden_dropout=0.1,
- attention_dropout=0.1,
- dropout_rng_name=DROPOUT_KEY,
- layer_type=te_flax.TransformerLayerType.ENCODER,
- self_attn_mask_type='padding',
- enable_relative_embedding=False,
- dtype=jnp.bfloat16)
+ te_Encoder = partial(
+ te_flax.TransformerLayer,
+ hidden_size=256,
+ mlp_hidden_size=1024,
+ num_attention_heads=8,
+ hidden_dropout=0.1,
+ attention_dropout=0.1,
+ dropout_rng_name=DROPOUT_KEY,
+ layer_type=te_flax.TransformerLayerType.ENCODER,
+ self_attn_mask_type="padding",
+ enable_relative_embedding=False,
+ dtype=jnp.bfloat16,
+ )
x = te_Encoder()(x, attention_mask=mask, deterministic=disable_dropout)
x = x.reshape(x.shape[0], -1)
@@ -82,20 +85,21 @@ def train_step(state, inputs, masks, labels, var_collect, rngs):
def train_epoch(state, train_ds, batch_size, rngs, var_collect, train_fn):
"""Train for a single epoch."""
- train_ds_size = len(train_ds['sentence'])
+ train_ds_size = len(train_ds["sentence"])
steps_per_epoch = train_ds_size // batch_size
perms = jax.random.permutation(rngs[INPUT_KEY], train_ds_size)
- perms = perms[:steps_per_epoch * batch_size] # skip incomplete batch
+ perms = perms[: steps_per_epoch * batch_size] # skip incomplete batch
perms = perms.reshape((steps_per_epoch, batch_size))
epoch_loss = []
epoch_accuracy = []
for perm in perms:
- batch_inputs = train_ds['sentence'][perm, ...]
- batch_masks = train_ds['mask'][perm, ...]
- batch_labels = train_ds['label'][perm, ...]
- state, loss, accuracy, var_collect = train_fn(state, batch_inputs, batch_masks,
- batch_labels, var_collect, rngs)
+ batch_inputs = train_ds["sentence"][perm, ...]
+ batch_masks = train_ds["mask"][perm, ...]
+ batch_labels = train_ds["label"][perm, ...]
+ state, loss, accuracy, var_collect = train_fn(
+ state, batch_inputs, batch_masks, batch_labels, var_collect, rngs
+ )
epoch_loss.append(loss)
epoch_accuracy.append(accuracy)
@@ -121,7 +125,7 @@ def eval_step(state, inputs, masks, labels, var_collect):
def eval_model(state, test_ds, batch_size, var_collect, eval_fn):
"""Evaluation loop."""
- test_ds_size = len(test_ds['sentence'])
+ test_ds_size = len(test_ds["sentence"])
num_steps = test_ds_size // batch_size
valid_size = num_steps * batch_size
all_loss = []
@@ -129,9 +133,9 @@ def eval_model(state, test_ds, batch_size, var_collect, eval_fn):
for batch_start in range(0, valid_size, batch_size):
batch_end = batch_start + batch_size
- batch_inputs = test_ds['sentence'][batch_start:batch_end]
- batch_masks = test_ds['mask'][batch_start:batch_end]
- batch_labels = test_ds['label'][batch_start:batch_end]
+ batch_inputs = test_ds["sentence"][batch_start:batch_end]
+ batch_masks = test_ds["mask"][batch_start:batch_end]
+ batch_labels = test_ds["label"][batch_start:batch_end]
loss, accuracy = eval_fn(state, batch_inputs, batch_masks, batch_labels, var_collect)
all_loss.append(loss)
all_accuracy.append(accuracy)
@@ -143,12 +147,12 @@ def eval_model(state, test_ds, batch_size, var_collect, eval_fn):
def data_preprocess(dataset, vocab, word_id, max_seq_len):
"""Convert tokens to numbers."""
- nltk.download('punkt')
- dataset_size = len(dataset['sentence'])
+ nltk.download("punkt")
+ dataset_size = len(dataset["sentence"])
output = np.zeros((dataset_size, max_seq_len), dtype=np.int32)
mask_3d = np.ones((dataset_size, max_seq_len, max_seq_len), dtype=np.uint8)
- for j, sentence in enumerate(dataset['sentence']):
+ for j, sentence in enumerate(dataset["sentence"]):
tokens = nltk.word_tokenize(sentence)
tensor = output[j]
@@ -168,9 +172,9 @@ def data_preprocess(dataset, vocab, word_id, max_seq_len):
mask_2d[:seq_len, :seq_len] = 0
new_dataset = {
- 'sentence': output,
- 'label': dataset['label'].astype(np.float32),
- 'mask': mask_3d.reshape((dataset_size, 1, max_seq_len, max_seq_len))
+ "sentence": output,
+ "label": dataset["label"].astype(np.float32),
+ "mask": mask_3d.reshape((dataset_size, 1, max_seq_len, max_seq_len)),
}
return new_dataset, vocab, word_id
@@ -180,12 +184,12 @@ def get_datasets(max_seq_len):
vocab = {}
word_id = 0
- train_ds = load_dataset('glue', 'cola', split='train')
- train_ds.set_format(type='np')
+ train_ds = load_dataset("glue", "cola", split="train")
+ train_ds.set_format(type="np")
train_ds, vocab, word_id = data_preprocess(train_ds, vocab, word_id, max_seq_len)
- test_ds = load_dataset('glue', 'cola', split='validation')
- test_ds.set_format(type='np')
+ test_ds = load_dataset("glue", "cola", split="validation")
+ test_ds.set_format(type="np")
test_ds, vocab, word_id = data_preprocess(test_ds, vocab, word_id, max_seq_len)
return train_ds, test_ds, word_id
@@ -194,7 +198,8 @@ def check_fp8(state, var_collect, inputs, masks, labels):
"Check if model includes FP8."
rngs = {DROPOUT_KEY: jax.random.PRNGKey(0)}
assert "fp8_" in str(
- jax.make_jaxpr(train_step)(state, inputs, masks, labels, var_collect, rngs))
+ jax.make_jaxpr(train_step)(state, inputs, masks, labels, var_collect, rngs)
+ )
def get_params_pspec(sharding_rules, abs_var_collect):
@@ -232,8 +237,7 @@ def train_and_evaluate(args):
num_gpu = jax.local_device_count()
assert args.batch_size % num_gpu == 0, f"Batch size needs to be multiple of {num_gpu}"
- assert args.test_batch_size % num_gpu == 0, \
- f"Test batch size needs to be multiple of {num_gpu}"
+ assert args.test_batch_size % num_gpu == 0, f"Test batch size needs to be multiple of {num_gpu}"
device_mesh = mesh_utils.create_device_mesh((num_gpu,))
with jax.sharding.Mesh(devices=device_mesh, axis_names=(DEVICE_DP_AXIS,)):
@@ -247,8 +251,9 @@ def train_and_evaluate(args):
mask_shape = [args.batch_size, 1, args.max_seq_len, args.max_seq_len]
label_shape = [args.batch_size]
- with te.fp8_autocast(args.use_fp8,
- mesh_resource=te.MeshResource(DEVICE_DP_AXIS, None, None, None)):
+ with te.fp8_autocast(
+ args.use_fp8, mesh_resource=te.MeshResource(DEVICE_DP_AXIS, None, None, None)
+ ):
encoder = Net(num_embed)
inputs = jnp.zeros(input_shape, dtype=jnp.int32)
masks = jnp.zeros(mask_shape, dtype=jnp.uint8)
@@ -260,18 +265,21 @@ def train_and_evaluate(args):
masks_pspec = jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None, None, None)
in_shardings = (None, inputs_pspec, masks_pspec)
- out_shardings = {key: params_pspec if key is PARAMS_KEY else None \
- for key in abs_var_collect}
+ out_shardings = {
+ key: params_pspec if key is PARAMS_KEY else None for key in abs_var_collect
+ }
pjit_encoder_init = pjit(encoder.init, in_shardings, out_shardings)
var_collect = pjit_encoder_init(init_rngs, inputs, masks)
optimizer = optax.adamw(args.lr)
var_collect, params = flax.core.pop(var_collect, PARAMS_KEY)
- state = train_state.TrainState.create(apply_fn=encoder.apply,
- params=params,
- tx=optimizer)
+ state = train_state.TrainState.create(
+ apply_fn=encoder.apply, params=params, tx=optimizer
+ )
state_pspec = get_state_pspec(state, params_pspec)
- labels_pspec = jax.sharding.PartitionSpec(DEVICE_DP_AXIS,)
+ labels_pspec = jax.sharding.PartitionSpec(
+ DEVICE_DP_AXIS,
+ )
in_shardings = (state_pspec, inputs_pspec, masks_pspec, labels_pspec, None, None)
out_shardings = (state_pspec, None, None, None)
@@ -298,16 +306,20 @@ def train_and_evaluate(args):
rngs = {INPUT_KEY: input_rng, DROPOUT_KEY: dropout_rng}
state, train_loss, train_accuracy, var_collect = train_epoch(
- state, train_ds, args.batch_size, rngs, var_collect, pjit_train_step)
+ state, train_ds, args.batch_size, rngs, var_collect, pjit_train_step
+ )
- test_loss, test_accuracy = eval_model(state, test_ds, args.test_batch_size,
- var_collect, pjit_eval_step)
+ test_loss, test_accuracy = eval_model(
+ state, test_ds, args.test_batch_size, var_collect, pjit_eval_step
+ )
- print(f"Epoch: {epoch:>2} "
- f"Train Loss: {train_loss:.6f} "
- f"Train Accuracy: {train_accuracy:.6f} "
- f"Test Loss: {test_loss:.6f} "
- f"Test Accuracy: {test_accuracy:.6f} ")
+ print(
+ f"Epoch: {epoch:>2} "
+ f"Train Loss: {train_loss:.6f} "
+ f"Train Accuracy: {train_accuracy:.6f} "
+ f"Test Loss: {test_loss:.6f} "
+ f"Test Accuracy: {test_accuracy:.6f} "
+ )
return [train_loss, train_accuracy, test_loss, test_accuracy]
@@ -357,10 +369,12 @@ def encoder_parser(args):
help="quickly check a single pass",
)
parser.add_argument("--seed", type=int, default=0, metavar="S", help="random seed (default: 0)")
- parser.add_argument("--use-fp8",
- action="store_true",
- default=False,
- help="Use FP8 for inference and training without recalibration")
+ parser.add_argument(
+ "--use-fp8",
+ action="store_true",
+ default=False,
+ help="Use FP8 for inference and training without recalibration",
+ )
return parser.parse_args(args)
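
Note: the train_epoch functions in these encoder examples share one batching idiom: permute the sample indices, drop the incomplete tail batch, and reshape the index vector to (steps_per_epoch, batch_size). A toy sketch of just that idiom, not from the patch; the arange array stands in for a dataset column.

import jax
import numpy as np

rng = jax.random.PRNGKey(0)
dataset_size, batch_size = 10, 4
steps_per_epoch = dataset_size // batch_size        # 2 full batches
perms = jax.random.permutation(rng, dataset_size)   # shuffled indices
perms = perms[: steps_per_epoch * batch_size]       # skip incomplete batch
perms = perms.reshape((steps_per_epoch, batch_size))

data = np.arange(dataset_size)
for perm in perms:
    batch = data[np.asarray(perm)]                  # gather one batch
    print(batch.shape)                              # (4,)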
diff --git a/examples/jax/encoder/test_multiprocessing_encoder.py b/examples/jax/encoder/test_multiprocessing_encoder.py
index 1330d5c84500ad903802fe7fb3d99f0d69edbfe0..c9620aa2be65ccf77b697164e70831d1d5bf3b9e 100644
--- a/examples/jax/encoder/test_multiprocessing_encoder.py
+++ b/examples/jax/encoder/test_multiprocessing_encoder.py
@@ -24,49 +24,56 @@ from jax.experimental.pjit import pjit
import transformer_engine.jax as te
import transformer_engine.jax.flax as te_flax
-os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
-DEVICE_DP_AXIS = 'data'
-DEVICE_TP_AXIS = 'model'
-NAMED_BROADCAST_AXIS = 'my_broadcast_axis'
-NAMED_TP_AXIS = 'my_tp_axis'
-PARAMS_KEY = 'params'
-PARAMS_AXES_KEY = PARAMS_KEY + '_axes'
-DROPOUT_KEY = 'dropout'
-INPUT_KEY = 'input_rng'
+os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
+DEVICE_DP_AXIS = "data"
+DEVICE_TP_AXIS = "model"
+NAMED_BROADCAST_AXIS = "my_broadcast_axis"
+NAMED_TP_AXIS = "my_tp_axis"
+PARAMS_KEY = "params"
+PARAMS_AXES_KEY = PARAMS_KEY + "_axes"
+DROPOUT_KEY = "dropout"
+INPUT_KEY = "input_rng"
class Net(nn.Module):
"""NLP Encoder"""
+
num_embed: int
@nn.compact
def __call__(self, x, mask, disable_dropout=False):
x = nn.Embed(num_embeddings=self.num_embed, features=256, dtype=jnp.bfloat16)(x)
- te_Encoder = partial(te_flax.TransformerLayer,
- hidden_size=256,
- mlp_hidden_size=1024,
- num_attention_heads=8,
- hidden_dropout=0.1,
- attention_dropout=0.1,
- dropout_rng_name=DROPOUT_KEY,
- layer_type=te_flax.TransformerLayerType.ENCODER,
- self_attn_mask_type='padding',
- enable_relative_embedding=False,
- dtype=jnp.bfloat16)
+ te_Encoder = partial(
+ te_flax.TransformerLayer,
+ hidden_size=256,
+ mlp_hidden_size=1024,
+ num_attention_heads=8,
+ hidden_dropout=0.1,
+ attention_dropout=0.1,
+ dropout_rng_name=DROPOUT_KEY,
+ layer_type=te_flax.TransformerLayerType.ENCODER,
+ self_attn_mask_type="padding",
+ enable_relative_embedding=False,
+ dtype=jnp.bfloat16,
+ )
x = te_Encoder()(x, attention_mask=mask, deterministic=disable_dropout)
x = x.reshape(x.shape[0], -1)
- x = te_flax.DenseGeneral(features=256,
- kernel_axes=(NAMED_BROADCAST_AXIS, NAMED_TP_AXIS),
- bias_axes=(NAMED_TP_AXIS,),
- dtype=jnp.bfloat16)(x)
+ x = te_flax.DenseGeneral(
+ features=256,
+ kernel_axes=(NAMED_BROADCAST_AXIS, NAMED_TP_AXIS),
+ bias_axes=(NAMED_TP_AXIS,),
+ dtype=jnp.bfloat16,
+ )(x)
- x = te_flax.DenseGeneral(features=256,
- kernel_axes=(NAMED_TP_AXIS, NAMED_BROADCAST_AXIS),
- bias_axes=(NAMED_BROADCAST_AXIS,),
- dtype=jnp.bfloat16)(x)
+ x = te_flax.DenseGeneral(
+ features=256,
+ kernel_axes=(NAMED_TP_AXIS, NAMED_BROADCAST_AXIS),
+ bias_axes=(NAMED_BROADCAST_AXIS,),
+ dtype=jnp.bfloat16,
+ )(x)
x = nn.Dense(features=2, dtype=jnp.bfloat16)(x)
return x
@@ -90,8 +97,9 @@ def shard_array_wrapper(dataset, batch_size, mesh, pspec, enable_partition=False
(dp_size, tp_size) = mesh.device_ids.shape
valid_input_size, global_batch_size, num_steps, tp_group_id = valid_shard_size(
- total_input_size, batch_size, dp_size, tp_size)
- inputs = inputs[:valid_input_size] # skip incomplete batch
+ total_input_size, batch_size, dp_size, tp_size
+ )
+ inputs = inputs[:valid_input_size] # skip incomplete batch
single_input_shape = inputs.shape[1:]
global_input_shape = (global_batch_size, *single_input_shape)
@@ -124,25 +132,39 @@ def train_step(state, inputs, masks, labels, var_collect, rngs):
return state, loss, accuracy, var_collect
-def train_epoch(state, train_ds, batch_size, rngs, var_collect, train_fn, mesh, inputs_pspec,
- masks_pspec, labels_pspec):
+def train_epoch(
+ state,
+ train_ds,
+ batch_size,
+ rngs,
+ var_collect,
+ train_fn,
+ mesh,
+ inputs_pspec,
+ masks_pspec,
+ labels_pspec,
+):
"""Train for a single epoch."""
- total_batch_size = len(train_ds['sentence'])
+ total_batch_size = len(train_ds["sentence"])
(dp_size, tp_size) = mesh.device_ids.shape
- valid_size, _, num_steps, tp_group_id = valid_shard_size(total_batch_size, batch_size, dp_size,
- tp_size)
+ valid_size, _, num_steps, tp_group_id = valid_shard_size(
+ total_batch_size, batch_size, dp_size, tp_size
+ )
perms = jax.random.permutation(rngs[INPUT_KEY], valid_size)
perms = perms.reshape(dp_size, num_steps, batch_size)
perms = perms[tp_group_id]
global_input_shape, input_named_sharding, sentence = shard_array_wrapper(
- train_ds['sentence'], batch_size, mesh, inputs_pspec)
+ train_ds["sentence"], batch_size, mesh, inputs_pspec
+ )
global_mask_shape, mask_named_sharding, mask = shard_array_wrapper(
- train_ds['mask'], batch_size, mesh, masks_pspec)
+ train_ds["mask"], batch_size, mesh, masks_pspec
+ )
global_label_shape, label_named_sharding, label = shard_array_wrapper(
- train_ds['label'], batch_size, mesh, labels_pspec)
+ train_ds["label"], batch_size, mesh, labels_pspec
+ )
epoch_loss = []
epoch_accuracy = []
@@ -152,15 +174,19 @@ def train_epoch(state, train_ds, batch_size, rngs, var_collect, train_fn, mesh,
batch_mask = mask[perm, ...]
batch_label = label[perm, ...]
- shard_input = jax.make_array_from_single_device_arrays(global_input_shape,
- input_named_sharding, [batch_input])
- shard_mask = jax.make_array_from_single_device_arrays(global_mask_shape,
- mask_named_sharding, [batch_mask])
- shard_label = jax.make_array_from_single_device_arrays(global_label_shape,
- label_named_sharding, [batch_label])
-
- state, loss, accuracy, var_collect = train_fn(state, shard_input, shard_mask, shard_label,
- var_collect, rngs)
+ shard_input = jax.make_array_from_single_device_arrays(
+ global_input_shape, input_named_sharding, [batch_input]
+ )
+ shard_mask = jax.make_array_from_single_device_arrays(
+ global_mask_shape, mask_named_sharding, [batch_mask]
+ )
+ shard_label = jax.make_array_from_single_device_arrays(
+ global_label_shape, label_named_sharding, [batch_label]
+ )
+
+ state, loss, accuracy, var_collect = train_fn(
+ state, shard_input, shard_mask, shard_label, var_collect, rngs
+ )
epoch_loss.append(loss)
epoch_accuracy.append(accuracy)
@@ -184,36 +210,34 @@ def eval_step(state, inputs, masks, labels, var_collect):
return loss, accuracy
-def eval_model(state, test_ds, batch_size, var_collect, eval_fn, mesh, inputs_pspec, masks_pspec,
- labels_pspec):
+def eval_model(
+ state, test_ds, batch_size, var_collect, eval_fn, mesh, inputs_pspec, masks_pspec, labels_pspec
+):
"""Evaluation loop."""
- global_input_shape, input_named_sharding, sentence = shard_array_wrapper(test_ds['sentence'],
- batch_size,
- mesh,
- inputs_pspec,
- enable_partition=True)
- global_mask_shape, mask_named_sharding, mask = shard_array_wrapper(test_ds['mask'],
- batch_size,
- mesh,
- masks_pspec,
- enable_partition=True)
- global_label_shape, label_named_sharding, label = shard_array_wrapper(test_ds['label'],
- batch_size,
- mesh,
- labels_pspec,
- enable_partition=True)
+ global_input_shape, input_named_sharding, sentence = shard_array_wrapper(
+ test_ds["sentence"], batch_size, mesh, inputs_pspec, enable_partition=True
+ )
+ global_mask_shape, mask_named_sharding, mask = shard_array_wrapper(
+ test_ds["mask"], batch_size, mesh, masks_pspec, enable_partition=True
+ )
+ global_label_shape, label_named_sharding, label = shard_array_wrapper(
+ test_ds["label"], batch_size, mesh, labels_pspec, enable_partition=True
+ )
all_loss = []
all_accuracy = []
for batch_input, batch_mask, batch_label in zip(sentence, mask, label):
- shard_input = jax.make_array_from_single_device_arrays(global_input_shape,
- input_named_sharding, [batch_input])
- shard_mask = jax.make_array_from_single_device_arrays(global_mask_shape,
- mask_named_sharding, [batch_mask])
- shard_label = jax.make_array_from_single_device_arrays(global_label_shape,
- label_named_sharding, [batch_label])
+ shard_input = jax.make_array_from_single_device_arrays(
+ global_input_shape, input_named_sharding, [batch_input]
+ )
+ shard_mask = jax.make_array_from_single_device_arrays(
+ global_mask_shape, mask_named_sharding, [batch_mask]
+ )
+ shard_label = jax.make_array_from_single_device_arrays(
+ global_label_shape, label_named_sharding, [batch_label]
+ )
loss, accuracy = eval_fn(state, shard_input, shard_mask, shard_label, var_collect)
all_loss.append(loss)
@@ -226,12 +250,12 @@ def eval_model(state, test_ds, batch_size, var_collect, eval_fn, mesh, inputs_ps
def data_preprocess(dataset, vocab, word_id, max_seq_len):
"""Convert tokens to numbers."""
- nltk.download('punkt')
- dataset_size = len(dataset['sentence'])
+ nltk.download("punkt")
+ dataset_size = len(dataset["sentence"])
output = np.zeros((dataset_size, max_seq_len), dtype=np.int32)
mask_3d = np.ones((dataset_size, max_seq_len, max_seq_len), dtype=np.uint8)
- for j, sentence in enumerate(dataset['sentence']):
+ for j, sentence in enumerate(dataset["sentence"]):
tokens = nltk.word_tokenize(sentence)
tensor = output[j]
@@ -251,9 +275,9 @@ def data_preprocess(dataset, vocab, word_id, max_seq_len):
mask_2d[:seq_len, :seq_len] = 0
new_dataset = {
- 'sentence': output,
- 'label': dataset['label'].astype(np.float32),
- 'mask': mask_3d.reshape((dataset_size, 1, max_seq_len, max_seq_len))
+ "sentence": output,
+ "label": dataset["label"].astype(np.float32),
+ "mask": mask_3d.reshape((dataset_size, 1, max_seq_len, max_seq_len)),
}
return new_dataset, vocab, word_id
@@ -263,12 +287,12 @@ def get_datasets(max_seq_len):
vocab = {}
word_id = 0
- train_ds = load_dataset('glue', 'cola', split='train')
- train_ds.set_format(type='np')
+ train_ds = load_dataset("glue", "cola", split="train")
+ train_ds.set_format(type="np")
train_ds, vocab, word_id = data_preprocess(train_ds, vocab, word_id, max_seq_len)
- test_ds = load_dataset('glue', 'cola', split='validation')
- test_ds.set_format(type='np')
+ test_ds = load_dataset("glue", "cola", split="validation")
+ test_ds.set_format(type="np")
test_ds, vocab, word_id = data_preprocess(test_ds, vocab, word_id, max_seq_len)
return train_ds, test_ds, word_id
@@ -277,7 +301,8 @@ def check_fp8(state, var_collect, inputs, masks, labels):
"Check if model includes FP8."
rngs = {DROPOUT_KEY: jax.random.PRNGKey(0)}
assert "fp8_" in str(
- jax.make_jaxpr(train_step)(state, inputs, masks, labels, var_collect, rngs))
+ jax.make_jaxpr(train_step)(state, inputs, masks, labels, var_collect, rngs)
+ )
def get_params_pspec(sharding_rules, abs_var_collect):
@@ -313,10 +338,12 @@ def train_and_evaluate(args):
print(args)
train_ds, test_ds, num_embed = get_datasets(args.max_seq_len)
- jax.distributed.initialize(coordinator_address=args.coordinator_address,
- num_processes=args.num_process,
- process_id=args.process_id,
- local_device_ids=args.process_id)
+ jax.distributed.initialize(
+ coordinator_address=args.coordinator_address,
+ num_processes=args.num_process,
+ process_id=args.process_id,
+ local_device_ids=args.process_id,
+ )
assert jax.local_device_count() == 1, "1 GPU per process"
num_gpu_tp = 2
@@ -328,12 +355,14 @@ def train_and_evaluate(args):
num_gpu_tp = 1
assert args.batch_size % num_gpu_dp == 0, f"Batch size needs to be multiple of {num_gpu_dp}"
- assert args.test_batch_size % num_gpu_dp == 0, \
- f"Test batch size needs to be multiple of {num_gpu_dp}"
+ assert (
+ args.test_batch_size % num_gpu_dp == 0
+ ), f"Test batch size needs to be multiple of {num_gpu_dp}"
device_mesh = mesh_utils.create_device_mesh((num_gpu_dp, num_gpu_tp))
- with jax.sharding.Mesh(devices=device_mesh,
- axis_names=(DEVICE_DP_AXIS, DEVICE_TP_AXIS)) as shard_mesh:
+ with jax.sharding.Mesh(
+ devices=device_mesh, axis_names=(DEVICE_DP_AXIS, DEVICE_TP_AXIS)
+ ) as shard_mesh:
rng = jax.random.PRNGKey(args.seed)
rng, params_rng = jax.random.split(rng)
@@ -344,9 +373,9 @@ def train_and_evaluate(args):
mask_shape = [args.batch_size, 1, args.max_seq_len, args.max_seq_len]
label_shape = [args.batch_size]
- with te.fp8_autocast(args.use_fp8,
- mesh_resource=te.MeshResource(DEVICE_DP_AXIS, DEVICE_TP_AXIS, None,
- None)):
+ with te.fp8_autocast(
+ args.use_fp8, mesh_resource=te.MeshResource(DEVICE_DP_AXIS, DEVICE_TP_AXIS, None, None)
+ ):
encoder = Net(num_embed)
inputs = jnp.zeros(input_shape, dtype=jnp.int32)
masks = jnp.zeros(mask_shape, dtype=jnp.uint8)
@@ -359,18 +388,21 @@ def train_and_evaluate(args):
masks_pspec = jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None, None, None)
in_shardings = (None, inputs_pspec, masks_pspec)
- out_shardings = {key: params_pspec if key is PARAMS_KEY else None \
- for key in abs_var_collect}
+ out_shardings = {
+ key: params_pspec if key is PARAMS_KEY else None for key in abs_var_collect
+ }
pjit_encoder_init = pjit(encoder.init, in_shardings, out_shardings)
var_collect = pjit_encoder_init(init_rngs, inputs, masks)
optimizer = optax.adamw(args.lr)
var_collect, params = flax.core.pop(var_collect, PARAMS_KEY)
- state = train_state.TrainState.create(apply_fn=encoder.apply,
- params=params,
- tx=optimizer)
+ state = train_state.TrainState.create(
+ apply_fn=encoder.apply, params=params, tx=optimizer
+ )
state_pspec = get_state_pspec(state, params_pspec)
- labels_pspec = jax.sharding.PartitionSpec(DEVICE_DP_AXIS,)
+ labels_pspec = jax.sharding.PartitionSpec(
+ DEVICE_DP_AXIS,
+ )
in_shardings = (state_pspec, inputs_pspec, masks_pspec, labels_pspec, None, None)
out_shardings = (state_pspec, None, None, None)
@@ -396,18 +428,37 @@ def train_and_evaluate(args):
rngs = {INPUT_KEY: input_rng, DROPOUT_KEY: dropout_rng}
state, train_loss, train_accuracy, var_collect = train_epoch(
- state, train_ds, args.batch_size, rngs, var_collect, pjit_train_step,
- shard_mesh, inputs_pspec, masks_pspec, labels_pspec)
-
- test_loss, test_accuracy = eval_model(state, test_ds, args.test_batch_size,
- var_collect, pjit_eval_step, shard_mesh,
- inputs_pspec, masks_pspec, labels_pspec)
+ state,
+ train_ds,
+ args.batch_size,
+ rngs,
+ var_collect,
+ pjit_train_step,
+ shard_mesh,
+ inputs_pspec,
+ masks_pspec,
+ labels_pspec,
+ )
+
+ test_loss, test_accuracy = eval_model(
+ state,
+ test_ds,
+ args.test_batch_size,
+ var_collect,
+ pjit_eval_step,
+ shard_mesh,
+ inputs_pspec,
+ masks_pspec,
+ labels_pspec,
+ )
if args.process_id == 0:
- print(f"Epoch: {epoch:>2} "
- f"Train Loss: {train_loss:.6f} "
- f"Train Accuracy: {train_accuracy:.6f} "
- f"Test Loss: {test_loss:.6f} "
- f"Test Accuracy: {test_accuracy:.6f} ")
+ print(
+ f"Epoch: {epoch:>2} "
+ f"Train Loss: {train_loss:.6f} "
+ f"Train Accuracy: {train_accuracy:.6f} "
+ f"Test Loss: {test_loss:.6f} "
+ f"Test Accuracy: {test_accuracy:.6f} "
+ )
jax.distributed.shutdown()
return [train_loss, train_accuracy, test_loss, test_accuracy]
@@ -458,24 +509,31 @@ def encoder_parser(args):
help="quickly check a single pass",
)
parser.add_argument("--seed", type=int, default=0, metavar="S", help="random seed (default: 0)")
- parser.add_argument("--use-fp8",
- action="store_true",
- default=False,
- help="Use FP8 for inference and training without recalibration")
- parser.add_argument("--coordinator-address",
- type=str,
- default="127.0.0.1:1234",
- help="the IP address of process 0 and a port on \
- which that process should launch a coordinator service \
- (default: 127.0.0.1:1234)")
- parser.add_argument("--num-process",
- type=int,
- default=1,
- help="number of processes (default: 1)")
- parser.add_argument("--process-id",
- type=int,
- default=0,
- help="the ID number of the current process (default: 0)")
+ parser.add_argument(
+ "--use-fp8",
+ action="store_true",
+ default=False,
+ help="Use FP8 for inference and training without recalibration",
+ )
+ parser.add_argument(
+ "--coordinator-address",
+ type=str,
+ default="127.0.0.1:1234",
+ help=(
+ "the IP address of process 0 and a port on which that"
+ " process should launch a coordinator service (default:"
+ " 127.0.0.1:1234)"
+ ),
+ )
+ parser.add_argument(
+ "--num-process", type=int, default=1, help="number of processes (default: 1)"
+ )
+ parser.add_argument(
+ "--process-id",
+ type=int,
+ default=0,
+ help="the ID number of the current process (default: 0)",
+ )
return parser.parse_args(args)
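
Note: in the multiprocessing example each process owns one local shard, and jax.make_array_from_single_device_arrays stitches those shards into a logically global array; that is the piece shard_array_wrapper builds. Below is a single-process sketch of the same call, not from the patch; with one process every device is addressable, so the "global" array simply has device_count shards.

import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils

mesh = jax.sharding.Mesh(
    mesh_utils.create_device_mesh((jax.device_count(),)), axis_names=("data",)
)
sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec("data"))

per_device_rows = 8
global_shape = (per_device_rows * jax.device_count(), 4)
# One single-device shard per mesh device, placed where the sharding expects it.
shards = [
    jax.device_put(jnp.full((per_device_rows, 4), i, dtype=jnp.float32), d)
    for i, d in enumerate(mesh.devices.flatten())
]
global_array = jax.make_array_from_single_device_arrays(global_shape, sharding, shards)
print(global_array.shape)  # (8 * num_devices, 4)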
diff --git a/examples/jax/encoder/test_single_gpu_encoder.py b/examples/jax/encoder/test_single_gpu_encoder.py
index b89243792550f5cb6e0724a0dca2e0363321ca03..674f7de81526157d4e2fa228421bb0f39b2af105 100644
--- a/examples/jax/encoder/test_single_gpu_encoder.py
+++ b/examples/jax/encoder/test_single_gpu_encoder.py
@@ -19,30 +19,33 @@ from flax.training import train_state
import transformer_engine.jax as te
import transformer_engine.jax.flax as te_flax
-PARAMS_KEY = 'params'
-DROPOUT_KEY = 'dropout'
-INPUT_KEY = 'input_rng'
+PARAMS_KEY = "params"
+DROPOUT_KEY = "dropout"
+INPUT_KEY = "input_rng"
class Net(nn.Module):
"""NLP Encoder"""
+
num_embed: int
@nn.compact
def __call__(self, x, mask, disable_dropout=False):
x = nn.Embed(num_embeddings=self.num_embed, features=256, dtype=jnp.bfloat16)(x)
- te_Encoder = partial(te_flax.TransformerLayer,
- hidden_size=256,
- mlp_hidden_size=1024,
- num_attention_heads=8,
- hidden_dropout=0.1,
- attention_dropout=0.1,
- dropout_rng_name=DROPOUT_KEY,
- layer_type=te_flax.TransformerLayerType.ENCODER,
- self_attn_mask_type='padding',
- enable_relative_embedding=False,
- dtype=jnp.bfloat16)
+ te_Encoder = partial(
+ te_flax.TransformerLayer,
+ hidden_size=256,
+ mlp_hidden_size=1024,
+ num_attention_heads=8,
+ hidden_dropout=0.1,
+ attention_dropout=0.1,
+ dropout_rng_name=DROPOUT_KEY,
+ layer_type=te_flax.TransformerLayerType.ENCODER,
+ self_attn_mask_type="padding",
+ enable_relative_embedding=False,
+ dtype=jnp.bfloat16,
+ )
x = te_Encoder()(x, attention_mask=mask, deterministic=disable_dropout)
x = x.reshape(x.shape[0], -1)
@@ -78,20 +81,21 @@ def train_step(state, inputs, masks, labels, var_collect, rngs):
def train_epoch(state, train_ds, batch_size, rngs, var_collect):
"""Train for a single epoch."""
- train_ds_size = len(train_ds['sentence'])
+ train_ds_size = len(train_ds["sentence"])
steps_per_epoch = train_ds_size // batch_size
perms = jax.random.permutation(rngs[INPUT_KEY], train_ds_size)
- perms = perms[:steps_per_epoch * batch_size] # skip incomplete batch
+ perms = perms[: steps_per_epoch * batch_size] # skip incomplete batch
perms = perms.reshape((steps_per_epoch, batch_size))
epoch_loss = []
epoch_accuracy = []
for perm in perms:
- batch_inputs = train_ds['sentence'][perm, ...]
- batch_masks = train_ds['mask'][perm, ...]
- batch_labels = train_ds['label'][perm, ...]
- state, loss, accuracy, var_collect = train_step(state, batch_inputs, batch_masks,
- batch_labels, var_collect, rngs)
+ batch_inputs = train_ds["sentence"][perm, ...]
+ batch_masks = train_ds["mask"][perm, ...]
+ batch_labels = train_ds["label"][perm, ...]
+ state, loss, accuracy, var_collect = train_step(
+ state, batch_inputs, batch_masks, batch_labels, var_collect, rngs
+ )
epoch_loss.append(loss)
epoch_accuracy.append(accuracy)
@@ -118,7 +122,7 @@ def eval_step(state, inputs, masks, labels, var_collect):
def eval_model(state, test_ds, batch_size, var_collect):
"""Evaluation loop."""
- test_ds_size = len(test_ds['sentence'])
+ test_ds_size = len(test_ds["sentence"])
num_steps = test_ds_size // batch_size
valid_size = num_steps * batch_size
all_loss = []
@@ -126,9 +130,9 @@ def eval_model(state, test_ds, batch_size, var_collect):
for batch_start in range(0, valid_size, batch_size):
batch_end = batch_start + batch_size
- batch_inputs = test_ds['sentence'][batch_start:batch_end]
- batch_masks = test_ds['mask'][batch_start:batch_end]
- batch_labels = test_ds['label'][batch_start:batch_end]
+ batch_inputs = test_ds["sentence"][batch_start:batch_end]
+ batch_masks = test_ds["mask"][batch_start:batch_end]
+ batch_labels = test_ds["label"][batch_start:batch_end]
loss, accuracy = eval_step(state, batch_inputs, batch_masks, batch_labels, var_collect)
all_loss.append(loss)
all_accuracy.append(accuracy)
@@ -140,12 +144,12 @@ def eval_model(state, test_ds, batch_size, var_collect):
def data_preprocess(dataset, vocab, word_id, max_seq_len):
"""Convert tokens to numbers."""
- nltk.download('punkt')
- dataset_size = len(dataset['sentence'])
+ nltk.download("punkt")
+ dataset_size = len(dataset["sentence"])
output = np.zeros((dataset_size, max_seq_len), dtype=np.int32)
mask_3d = np.ones((dataset_size, max_seq_len, max_seq_len), dtype=np.uint8)
- for j, sentence in enumerate(dataset['sentence']):
+ for j, sentence in enumerate(dataset["sentence"]):
tokens = nltk.word_tokenize(sentence)
tensor = output[j]
@@ -165,9 +169,9 @@ def data_preprocess(dataset, vocab, word_id, max_seq_len):
mask_2d[:seq_len, :seq_len] = 0
new_dataset = {
- 'sentence': output,
- 'label': dataset['label'].astype(np.float32),
- 'mask': mask_3d.reshape((dataset_size, 1, max_seq_len, max_seq_len))
+ "sentence": output,
+ "label": dataset["label"].astype(np.float32),
+ "mask": mask_3d.reshape((dataset_size, 1, max_seq_len, max_seq_len)),
}
return new_dataset, vocab, word_id
@@ -177,12 +181,12 @@ def get_datasets(max_seq_len):
vocab = {}
word_id = 0
- train_ds = load_dataset('glue', 'cola', split='train')
- train_ds.set_format(type='np')
+ train_ds = load_dataset("glue", "cola", split="train")
+ train_ds.set_format(type="np")
train_ds, vocab, word_id = data_preprocess(train_ds, vocab, word_id, max_seq_len)
- test_ds = load_dataset('glue', 'cola', split='validation')
- test_ds.set_format(type='np')
+ test_ds = load_dataset("glue", "cola", split="validation")
+ test_ds.set_format(type="np")
test_ds, vocab, word_id = data_preprocess(test_ds, vocab, word_id, max_seq_len)
return train_ds, test_ds, word_id
@@ -191,7 +195,8 @@ def check_fp8(state, var_collect, inputs, masks, labels):
"Check if model includes FP8."
rngs = {DROPOUT_KEY: jax.random.PRNGKey(0)}
assert "fp8_" in str(
- jax.make_jaxpr(train_step)(state, inputs, masks, labels, var_collect, rngs))
+ jax.make_jaxpr(train_step)(state, inputs, masks, labels, var_collect, rngs)
+ )
def train_and_evaluate(args):
@@ -214,9 +219,9 @@ def train_and_evaluate(args):
masks = jnp.zeros(mask_shape, dtype=jnp.uint8)
var_collect = encoder.init(init_rngs, inputs, masks)
tx = optax.adamw(args.lr)
- state = train_state.TrainState.create(apply_fn=encoder.apply,
- params=var_collect[PARAMS_KEY],
- tx=tx)
+ state = train_state.TrainState.create(
+ apply_fn=encoder.apply, params=var_collect[PARAMS_KEY], tx=tx
+ )
if args.use_fp8:
labels = jnp.zeros(label_shape, dtype=jnp.bfloat16)
@@ -235,15 +240,18 @@ def train_and_evaluate(args):
rngs = {INPUT_KEY: input_rng, DROPOUT_KEY: dropout_rng}
state, train_loss, train_accuracy, var_collect = train_epoch(
- state, train_ds, args.batch_size, rngs, var_collect)
+ state, train_ds, args.batch_size, rngs, var_collect
+ )
test_loss, test_accuracy = eval_model(state, test_ds, args.test_batch_size, var_collect)
- print(f"Epoch: {epoch:>2} "
- f"Train Loss: {train_loss:.6f} "
- f"Train Accuracy: {train_accuracy:.6f} "
- f"Test Loss: {test_loss:.6f} "
- f"Test Accuracy: {test_accuracy:.6f} ")
+ print(
+ f"Epoch: {epoch:>2} "
+ f"Train Loss: {train_loss:.6f} "
+ f"Train Accuracy: {train_accuracy:.6f} "
+ f"Test Loss: {test_loss:.6f} "
+ f"Test Accuracy: {test_accuracy:.6f} "
+ )
return [train_loss, train_accuracy, test_loss, test_accuracy]
@@ -293,10 +301,12 @@ def encoder_parser(args):
help="quickly check a single pass",
)
parser.add_argument("--seed", type=int, default=0, metavar="S", help="random seed (default: 0)")
- parser.add_argument("--use-fp8",
- action="store_true",
- default=False,
- help="Use FP8 for inference and training without recalibration")
+ parser.add_argument(
+ "--use-fp8",
+ action="store_true",
+ default=False,
+ help="Use FP8 for inference and training without recalibration",
+ )
return parser.parse_args(args)
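
Note: check_fp8 in these files relies on a generic JAX trick: jax.make_jaxpr traces a function into its jaxpr without executing it, and the jaxpr's string form can be grepped for expected primitives ("fp8_" in the examples). A toy sketch of the trick on a function whose jaxpr must contain "sin", not from the patch.

import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) * 2.0

jaxpr = jax.make_jaxpr(f)(jnp.ones(3))   # trace without executing
assert "sin" in str(jaxpr)               # same grep-the-jaxpr idea as check_fp8
print(jaxpr)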
diff --git a/examples/jax/mnist/test_single_gpu_mnist.py b/examples/jax/mnist/test_single_gpu_mnist.py
index ae74a66337205f400dcb372424766a331f5bfa3a..ff431a261b4495db2f99a0e3afd18c5eebd16db4 100644
--- a/examples/jax/mnist/test_single_gpu_mnist.py
+++ b/examples/jax/mnist/test_single_gpu_mnist.py
@@ -1,7 +1,7 @@
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
-""" MNIST training on single GPU"""
+"""MNIST training on single GPU"""
import argparse
import unittest
from functools import partial
@@ -20,13 +20,14 @@ import transformer_engine.jax.flax as te_flax
IMAGE_H = 28
IMAGE_W = 28
IMAGE_C = 1
-PARAMS_KEY = 'params'
-DROPOUT_KEY = 'dropout'
-INPUT_KEY = 'input_rng'
+PARAMS_KEY = "params"
+DROPOUT_KEY = "dropout"
+INPUT_KEY = "input_rng"
class Net(nn.Module):
"""CNN model for MNIST."""
+
use_te: bool = False
@nn.compact
@@ -83,17 +84,17 @@ def update_model(state, grads):
def train_epoch(state, train_ds, batch_size, rngs, var_collect):
"""Train for a single epoch."""
- train_ds_size = len(train_ds['image'])
+ train_ds_size = len(train_ds["image"])
steps_per_epoch = train_ds_size // batch_size
perms = jax.random.permutation(rngs[INPUT_KEY], train_ds_size)
- perms = perms[:steps_per_epoch * batch_size] # skip incomplete batch
+ perms = perms[: steps_per_epoch * batch_size] # skip incomplete batch
perms = perms.reshape((steps_per_epoch, batch_size))
epoch_loss = []
epoch_accuracy = []
for perm in perms:
- batch_images = train_ds['image'][perm, ...]
- batch_labels = train_ds['label'][perm, ...]
+ batch_images = train_ds["image"][perm, ...]
+ batch_labels = train_ds["label"][perm, ...]
grads, loss, accuracy = apply_model(state, batch_images, batch_labels, var_collect, rngs)
state, var_collect = update_model(state, grads)
epoch_loss.append(loss)
@@ -106,7 +107,7 @@ def train_epoch(state, train_ds, batch_size, rngs, var_collect):
def eval_model(state, test_ds, batch_size, var_collect):
"""Evaluation loop."""
- test_ds_size = len(test_ds['image'])
+ test_ds_size = len(test_ds["image"])
num_steps = test_ds_size // batch_size
valid_size = num_steps * batch_size
all_loss = []
@@ -114,8 +115,8 @@ def eval_model(state, test_ds, batch_size, var_collect):
for batch_start in range(0, valid_size, batch_size):
batch_end = batch_start + batch_size
- batch_images = test_ds['image'][batch_start:batch_end]
- batch_labels = test_ds['label'][batch_start:batch_end]
+ batch_images = test_ds["image"][batch_start:batch_end]
+ batch_labels = test_ds["label"][batch_start:batch_end]
_, loss, accuracy = apply_model(state, batch_images, batch_labels, var_collect)
all_loss.append(loss)
all_accuracy.append(accuracy)
@@ -127,21 +128,21 @@ def eval_model(state, test_ds, batch_size, var_collect):
def get_datasets():
"""Load MNIST train and test datasets into memory."""
- train_ds = load_dataset('mnist', split='train')
- train_ds.set_format(type='np')
- batch_size = train_ds['image'].shape[0]
+ train_ds = load_dataset("mnist", split="train")
+ train_ds.set_format(type="np")
+ batch_size = train_ds["image"].shape[0]
shape = (batch_size, IMAGE_H, IMAGE_W, IMAGE_C)
new_train_ds = {
- 'image': train_ds['image'].astype(np.float32).reshape(shape) / 255.,
- 'label': train_ds['label']
+ "image": train_ds["image"].astype(np.float32).reshape(shape) / 255.0,
+ "label": train_ds["label"],
}
- test_ds = load_dataset('mnist', split='test')
- test_ds.set_format(type='np')
- batch_size = test_ds['image'].shape[0]
+ test_ds = load_dataset("mnist", split="test")
+ test_ds.set_format(type="np")
+ batch_size = test_ds["image"].shape[0]
shape = (batch_size, IMAGE_H, IMAGE_W, IMAGE_C)
new_test_ds = {
- 'image': test_ds['image'].astype(np.float32).reshape(shape) / 255.,
- 'label': test_ds['label']
+ "image": test_ds["image"].astype(np.float32).reshape(shape) / 255.0,
+ "label": test_ds["label"],
}
return new_train_ds, new_test_ds
@@ -149,8 +150,13 @@ def get_datasets():
def check_fp8(state, var_collect, input_shape, label_shape):
"Check if model includes FP8."
assert "f8_" in str(
- jax.make_jaxpr(apply_model)(state, jnp.empty(input_shape, dtype=jnp.bfloat16),
- jnp.empty(label_shape, dtype=jnp.bfloat16), var_collect))
+ jax.make_jaxpr(apply_model)(
+ state,
+ jnp.empty(input_shape, dtype=jnp.bfloat16),
+ jnp.empty(label_shape, dtype=jnp.bfloat16),
+ var_collect,
+ )
+ )
def train_and_evaluate(args):
@@ -173,17 +179,21 @@ def train_and_evaluate(args):
cnn = Net(args.use_te)
var_collect = cnn.init(init_rngs, jnp.empty(input_shape, dtype=jnp.bfloat16))
tx = optax.sgd(args.lr, args.momentum)
- state = train_state.TrainState.create(apply_fn=cnn.apply,
- params=var_collect[PARAMS_KEY],
- tx=tx)
+ state = train_state.TrainState.create(
+ apply_fn=cnn.apply, params=var_collect[PARAMS_KEY], tx=tx
+ )
if args.use_fp8:
check_fp8(state, var_collect, input_shape, label_shape)
if args.dry_run:
- apply_model(state, jnp.empty(input_shape, dtype=jnp.bfloat16),
- jnp.empty(label_shape, dtype=jnp.bfloat16), var_collect,
- {DROPOUT_KEY: dropout_rng})
+ apply_model(
+ state,
+ jnp.empty(input_shape, dtype=jnp.bfloat16),
+ jnp.empty(label_shape, dtype=jnp.bfloat16),
+ var_collect,
+ {DROPOUT_KEY: dropout_rng},
+ )
print("PASSED")
return None
@@ -193,14 +203,17 @@ def train_and_evaluate(args):
rngs = {INPUT_KEY: input_rng, DROPOUT_KEY: dropout_rng}
state, train_loss, train_accuracy, var_collect = train_epoch(
- state, train_ds, args.batch_size, rngs, var_collect)
+ state, train_ds, args.batch_size, rngs, var_collect
+ )
test_loss, test_accuracy = eval_model(state, test_ds, args.test_batch_size, var_collect)
- print(f"Epoch: {epoch:>2} "
- f"Train Loss: {train_loss:.6f} "
- f"Train Accuracy: {train_accuracy:.6f} "
- f"Test Loss: {test_loss:.6f} "
- f"Test Accuracy: {test_accuracy:.6f} ")
+ print(
+ f"Epoch: {epoch:>2} "
+ f"Train Loss: {train_loss:.6f} "
+ f"Train Accuracy: {train_accuracy:.6f} "
+ f"Test Loss: {test_loss:.6f} "
+ f"Test Accuracy: {test_accuracy:.6f} "
+ )
return [train_loss, train_accuracy, test_loss, test_accuracy]
@@ -250,15 +263,18 @@ def mnist_parser(args):
help="quickly check a single pass",
)
parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)")
- parser.add_argument("--use-fp8",
- action="store_true",
- default=False,
- help="Use FP8 for inference and training without recalibration. " \
- "It also enables Transformer Engine implicitly.")
- parser.add_argument("--use-te",
- action="store_true",
- default=False,
- help="Use Transformer Engine")
+ parser.add_argument(
+ "--use-fp8",
+ action="store_true",
+ default=False,
+ help=(
+ "Use FP8 for inference and training without recalibration. "
+ "It also enables Transformer Engine implicitly."
+ ),
+ )
+ parser.add_argument(
+ "--use-te", action="store_true", default=False, help="Use Transformer Engine"
+ )
return parser.parse_args(args)
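
Note: get_datasets above normalizes the raw uint8 MNIST images into NHWC float32 tensors in [0, 1]. A toy sketch of that reshape-and-scale step, not from the patch; the random array stands in for the dataset's "image" column.

import numpy as np

IMAGE_H, IMAGE_W, IMAGE_C = 28, 28, 1
raw = np.random.randint(0, 256, size=(16, IMAGE_H, IMAGE_W), dtype=np.uint8)
shape = (raw.shape[0], IMAGE_H, IMAGE_W, IMAGE_C)
images = raw.astype(np.float32).reshape(shape) / 255.0
print(images.dtype, images.shape, float(images.max()) <= 1.0)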
diff --git a/examples/paddle/mnist/test_single_gpu_mnist.py b/examples/paddle/mnist/test_single_gpu_mnist.py
index c1eacf39dac99c18bc38aebe53b878d5c87c378f..de5c9e9b6cca759dcdaaa07b7667bf3be0d54ef5 100644
--- a/examples/paddle/mnist/test_single_gpu_mnist.py
+++ b/examples/paddle/mnist/test_single_gpu_mnist.py
@@ -59,7 +59,9 @@ def train(args, model, train_loader, optimizer, epoch, use_fp8):
model.train()
losses = []
for batch_id, (data, labels) in enumerate(train_loader):
- with paddle.amp.auto_cast(dtype='bfloat16', level='O2'): # pylint: disable=not-context-manager
+ with paddle.amp.auto_cast(
+ dtype="bfloat16", level="O2"
+ ): # pylint: disable=not-context-manager
with te.fp8_autocast(enabled=use_fp8):
outputs = model(data)
loss = F.cross_entropy(outputs, labels)
@@ -70,10 +72,12 @@ def train(args, model, train_loader, optimizer, epoch, use_fp8):
optimizer.clear_gradients()
if batch_id % args.log_interval == 0:
- print(f"Train Epoch: {epoch} "
- f"[{batch_id * len(data)}/{len(train_loader.dataset)} "
- f"({100. * batch_id / len(train_loader):.0f}%)]\t"
- f"Loss: {loss.item():.6f}")
+ print(
+ f"Train Epoch: {epoch} "
+ f"[{batch_id * len(data)}/{len(train_loader.dataset)} "
+ f"({100. * batch_id / len(train_loader):.0f}%)]\t"
+ f"Loss: {loss.item():.6f}"
+ )
if args.dry_run:
return loss.item()
avg_loss = sum(losses) / len(losses)
@@ -89,7 +93,9 @@ def evaluate(model, test_loader, epoch, use_fp8):
with paddle.no_grad():
for data, labels in test_loader:
- with paddle.amp.auto_cast(dtype='bfloat16', level='O2'): # pylint: disable=not-context-manager
+ with paddle.amp.auto_cast(
+ dtype="bfloat16", level="O2"
+ ): # pylint: disable=not-context-manager
with te.fp8_autocast(enabled=use_fp8):
outputs = model(data)
acc = metric.compute(outputs, labels)
@@ -104,7 +110,9 @@ def calibrate(model, test_loader):
with paddle.no_grad():
for data, _ in test_loader:
- with paddle.amp.auto_cast(dtype='bfloat16', level='O2'): # pylint: disable=not-context-manager
+ with paddle.amp.auto_cast(
+ dtype="bfloat16", level="O2"
+ ): # pylint: disable=not-context-manager
with te.fp8_autocast(enabled=False, calibrating=True):
_ = model(data)
@@ -160,20 +168,27 @@ def mnist_parser(args):
metavar="N",
help="how many batches to wait before logging training status",
)
- parser.add_argument("--use-fp8",
- action="store_true",
- default=False,
- help="Use FP8 for inference and training without recalibration. " \
- "It also enables Transformer Engine implicitly.")
- parser.add_argument("--use-fp8-infer",
- action="store_true",
- default=False,
- help="Use FP8 for inference only. If not using FP8 for training, "
- "calibration is performed for FP8 infernece.")
- parser.add_argument("--use-te",
- action="store_true",
- default=False,
- help="Use Transformer Engine")
+ parser.add_argument(
+ "--use-fp8",
+ action="store_true",
+ default=False,
+ help=(
+ "Use FP8 for inference and training without recalibration. "
+ "It also enables Transformer Engine implicitly."
+ ),
+ )
+ parser.add_argument(
+ "--use-fp8-infer",
+ action="store_true",
+ default=False,
+ help=(
+ "Use FP8 for inference only. If not using FP8 for training, "
+ "calibration is performed for FP8 infernece."
+ ),
+ )
+ parser.add_argument(
+ "--use-te", action="store_true", default=False, help="Use Transformer Engine"
+ )
args = parser.parse_args(args)
return args
@@ -185,9 +200,9 @@ def train_and_evaluate(args):
paddle.seed(args.seed)
# Load MNIST dataset
- transform = Normalize(mean=[127.5], std=[127.5], data_format='CHW')
- train_dataset = MNIST(mode='train', transform=transform)
- val_dataset = MNIST(mode='test', transform=transform)
+ transform = Normalize(mean=[127.5], std=[127.5], data_format="CHW")
+ train_dataset = MNIST(mode="train", transform=transform)
+ val_dataset = MNIST(mode="test", transform=transform)
# Define data loaders
train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
@@ -198,7 +213,7 @@ def train_and_evaluate(args):
optimizer = paddle.optimizer.Adam(learning_rate=args.lr, parameters=model.parameters())
# Cast model to BF16
- model = paddle.amp.decorate(models=model, level='O2', dtype='bfloat16')
+ model = paddle.amp.decorate(models=model, level="O2", dtype="bfloat16")
for epoch in range(1, args.epochs + 1):
loss = train(args, model, train_loader, optimizer, epoch, args.use_fp8)
@@ -209,7 +224,7 @@ def train_and_evaluate(args):
if args.save_model or args.use_fp8_infer:
paddle.save(model.state_dict(), "mnist_cnn.pdparams")
- print('Eval with reloaded checkpoint : fp8=' + str(args.use_fp8))
+ print("Eval with reloaded checkpoint : fp8=" + str(args.use_fp8))
weights = paddle.load("mnist_cnn.pdparams")
model.set_state_dict(weights)
acc = evaluate(model, val_loader, 0, args.use_fp8)
@@ -235,8 +250,10 @@ class TestMNIST(unittest.TestCase):
assert actual[0] < desired_traing_loss
assert actual[1] > desired_test_accuracy
- @unittest.skipIf(paddle.device.cuda.get_device_capability() < (8, 0),
- "BF16 MNIST example requires Ampere+ GPU")
+ @unittest.skipIf(
+ paddle.device.cuda.get_device_capability() < (8, 0),
+ "BF16 MNIST example requires Ampere+ GPU",
+ )
def test_te_bf16(self):
"""Test Transformer Engine with BF16"""
self.args.use_te = True
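
Note: the Paddle example's FP8-inference path is a two-pass flow: calibrate() runs forward passes under te.fp8_autocast(enabled=False, calibrating=True) to record FP8 scaling statistics while still computing in BF16, and a later pass runs with FP8 enabled using those scales. A condensed sketch of that flow, assuming the example's transformer_engine.paddle import and that model and loader already exist; the context-manager calls mirror the ones in the diff above.

import paddle
import transformer_engine.paddle as te

def calibrate_then_infer(model, loader):
    model.eval()
    with paddle.no_grad():
        # Pass 1: collect amax/scale statistics without doing FP8 math.
        for data, _ in loader:
            with paddle.amp.auto_cast(dtype="bfloat16", level="O2"):
                with te.fp8_autocast(enabled=False, calibrating=True):
                    model(data)
        # Pass 2: FP8 inference using the calibrated scales.
        for data, _ in loader:
            with paddle.amp.auto_cast(dtype="bfloat16", level="O2"):
                with te.fp8_autocast(enabled=True):
                    return model(data)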
diff --git a/examples/pytorch/comm_gemm_overlap/ln_mlp_with_overlap.py b/examples/pytorch/comm_gemm_overlap/ln_mlp_with_overlap.py
index f0905d76975e9d4de6ffdb1faf4cb91018634b37..619dbaf9d736bbe185bd7a978c869fe145611874 100644
--- a/examples/pytorch/comm_gemm_overlap/ln_mlp_with_overlap.py
+++ b/examples/pytorch/comm_gemm_overlap/ln_mlp_with_overlap.py
@@ -15,62 +15,77 @@ import torch.distributed as dist
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Format, DelayedScaling
+
def parse_args(argv=None, namespace=None):
parser = argparse.ArgumentParser(
- description="Test a te.LayerNormMLP module with GEMM+comm overlap via Userbuffers.")
- parser.add_argument('-i', "--num-iters", type=int, default=5,
- help="Number of dummy 'training' iterations.")
- parser.add_argument('-b', "--batch-size", type=int, default=2,
- help="Input batch size.")
- parser.add_argument('-s', "--seq-length", type=int, default=2048,
- help="Input sequence length.")
- parser.add_argument('-n', "--num-heads", type=int, default=64,
- help="Number of attention heads.")
- parser.add_argument('-d', "--head-dim", type=int, default=128,
- help="Dimension of each attention head.")
- parser.add_argument("--mlp-expansion-factor", type=int, default=4,
- help="MLP block intermediate size as a factor of hidden dimension.")
- parser.add_argument("--seed", type=int, default=1234,
- help="RNG seed.")
- parser.add_argument("--fp8", action="store_true", default=False,
- help="Enables the te.fp8_autocast() context.")
- parser.add_argument("--no-comm-overlap", action="store_true", default=False,
- help="Disable the comm+GEMM overlap.")
- parser.add_argument('-v', "--verbose", action="store_true", default=False)
+ description="Test a te.LayerNormMLP module with GEMM+comm overlap via Userbuffers."
+ )
+ parser.add_argument(
+ "-i", "--num-iters", type=int, default=5, help="Number of dummy 'training' iterations."
+ )
+ parser.add_argument("-b", "--batch-size", type=int, default=2, help="Input batch size.")
+ parser.add_argument("-s", "--seq-length", type=int, default=2048, help="Input sequence length.")
+ parser.add_argument(
+ "-n", "--num-heads", type=int, default=64, help="Number of attention heads."
+ )
+ parser.add_argument(
+ "-d", "--head-dim", type=int, default=128, help="Dimension of each attention head."
+ )
+ parser.add_argument(
+ "--mlp-expansion-factor",
+ type=int,
+ default=4,
+ help="MLP block intermediate size as a factor of hidden dimension.",
+ )
+ parser.add_argument("--seed", type=int, default=1234, help="RNG seed.")
+ parser.add_argument(
+ "--fp8", action="store_true", default=False, help="Enables the te.fp8_autocast() context."
+ )
+ parser.add_argument(
+ "--no-comm-overlap",
+ action="store_true",
+ default=False,
+ help="Disable the comm+GEMM overlap.",
+ )
+ parser.add_argument("-v", "--verbose", action="store_true", default=False)
return parser.parse_args(argv, namespace)
+
def train(opts):
WORLD_RANK = int(os.getenv("RANK"))
WORLD_SIZE = int(os.getenv("WORLD_SIZE"))
- def dist_print(msg, end='\n', all_ranks=False):
+
+ def dist_print(msg, end="\n", all_ranks=False):
if WORLD_RANK == 0 or all_ranks:
print(f"[RANK-{WORLD_RANK}] {msg}", end=end)
# Seed RNG
torch.cuda.set_device(WORLD_RANK)
- torch.manual_seed(opts.seed+WORLD_RANK)
- torch.cuda.manual_seed(opts.seed+WORLD_RANK)
+ torch.manual_seed(opts.seed + WORLD_RANK)
+ torch.cuda.manual_seed(opts.seed + WORLD_RANK)
# Initialize torch.distributed global process group and get TP group
- dist.init_process_group(backend="nccl",
- rank=WORLD_RANK,
- world_size=WORLD_SIZE,
- device_id=torch.device(f'cuda:{WORLD_RANK}'))
+ dist.init_process_group(
+ backend="nccl",
+ rank=WORLD_RANK,
+ world_size=WORLD_SIZE,
+ device_id=torch.device(f"cuda:{WORLD_RANK}"),
+ )
tp_group = dist.new_group(backend="nccl")
tp_size = dist.get_world_size(tp_group)
# Initialize userbuffers
ag_cfg = { # Ring-exchange All-Gather overlap for fc1_fprop and fc2_dgrad
- 'method': 'ring_exchange',
- 'num_splits' : 8,
- 'num_sm' : 1,
- 'set_sm_margin' : False,
+ "method": "ring_exchange",
+ "num_splits": 8,
+ "num_sm": 1,
+ "set_sm_margin": False,
}
rs_cfg = { # Reduce-scatter overlap for fc1_dgrad and fc2_fprop
- 'method': 'ring_exchange',
- 'num_splits' : 4,
- 'num_sm' : 1,
- 'set_sm_margin' : True,
+ "method": "ring_exchange",
+ "num_splits": 4,
+ "num_sm": 1,
+ "set_sm_margin": True,
}
hidden_size = opts.num_heads * opts.head_dim
batched_size = opts.seq_length * opts.batch_size
@@ -78,30 +93,31 @@ def train(opts):
te.initialize_ub(
[batched_size, hidden_size],
tp_group,
- use_fp8 = opts.fp8,
- dtype = torch.bfloat16,
- ub_cfgs = {
- 'fc1_fprop': ag_cfg,
- 'fc1_dgrad': rs_cfg,
- 'fc2_fprop': rs_cfg,
- 'fc2_dgrad': ag_cfg,
+ use_fp8=opts.fp8,
+ dtype=torch.bfloat16,
+ ub_cfgs={
+ "fc1_fprop": ag_cfg,
+ "fc1_dgrad": rs_cfg,
+ "fc2_fprop": rs_cfg,
+ "fc2_dgrad": ag_cfg,
},
)
# Construct the LayerNormMLP with tensor/sequence parallelism and comm+GEMM overlap flags
model = te.LayerNormMLP(
- hidden_size, opts.mlp_expansion_factor * hidden_size,
- params_dtype = torch.bfloat16,
- device = 'cuda',
- tp_group = tp_group,
- tp_size = tp_size,
- set_parallel_mode = True,
- sequence_parallel = True, # this is required for comm+GEMM overlap
- seq_length = opts.seq_length,
- micro_batch_size = opts.batch_size,
- ub_overlap_rs_dgrad = not opts.no_comm_overlap,
- ub_overlap_rs = not opts.no_comm_overlap,
- ub_overlap_ag = not opts.no_comm_overlap,
+ hidden_size,
+ opts.mlp_expansion_factor * hidden_size,
+ params_dtype=torch.bfloat16,
+ device="cuda",
+ tp_group=tp_group,
+ tp_size=tp_size,
+ set_parallel_mode=True,
+ sequence_parallel=True, # this is required for comm+GEMM overlap
+ seq_length=opts.seq_length,
+ micro_batch_size=opts.batch_size,
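+ # Userbuffers comm+GEMM overlap for the MLP's internal collectives, unless disabled via --no-comm-overlap.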
+ ub_overlap_rs_dgrad=not opts.no_comm_overlap,
+ ub_overlap_rs=not opts.no_comm_overlap,
+ ub_overlap_ag=not opts.no_comm_overlap,
)
# Initialize optimizer with model parameters
@@ -109,16 +125,19 @@ def train(opts):
# Fp8 recipe setup
fp8_format = Format.HYBRID
- fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=32,
- amax_compute_algo="max")
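+ # HYBRID format: E4M3 for forward-pass tensors, E5M2 for backward-pass gradients.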
+ fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=32, amax_compute_algo="max")
# Start dummy "training" iterations
for i in range(opts.num_iters):
dist_print(f"Iter {i+1}", all_ranks=opts.verbose)
dist_print("|-- Generate random input batch", all_ranks=opts.verbose)
- x = torch.rand((opts.seq_length // tp_size, opts.batch_size, hidden_size),
- dtype=torch.bfloat16, device='cuda', requires_grad=True)
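+ # Sequence-parallel input: each rank owns a seq_length // tp_size slice of the sequence.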
+ x = torch.rand(
+ (opts.seq_length // tp_size, opts.batch_size, hidden_size),
+ dtype=torch.bfloat16,
+ device="cuda",
+ requires_grad=True,
+ )
dist_print("|-- Forward pass", all_ranks=opts.verbose)
with te.fp8_autocast(enabled=opts.fp8, fp8_recipe=fp8_recipe, fp8_group=tp_group):
@@ -135,17 +154,15 @@ def train(opts):
te.destroy_ub()
dist.destroy_process_group()
+
if __name__ == "__main__":
if "TORCHELASTIC_RUN_ID" in os.environ.keys():
args = parse_args()
train(args)
else:
subprocess.run(
- [
- 'torchrun', f'--nproc-per-node={torch.cuda.device_count()}',
- *sys.argv
- ],
+ ["torchrun", f"--nproc-per-node={torch.cuda.device_count()}", *sys.argv],
env=os.environ,
- check=True
+ check=True,
)
os._exit(0)
diff --git a/examples/pytorch/fsdp/fsdp.py b/examples/pytorch/fsdp/fsdp.py
index e78c1b1582c95d59b866d469f7cf0c80dabff762..cf0a75c3365e7fb9aab693adda85bbb015de667a 100644
--- a/examples/pytorch/fsdp/fsdp.py
+++ b/examples/pytorch/fsdp/fsdp.py
@@ -14,7 +14,7 @@ from torch.distributed.fsdp import FullyShardedDataParallel, MixedPrecision
from torch.distributed.fsdp.wrap import always_wrap_policy, transformer_auto_wrap_policy
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
apply_activation_checkpointing,
- checkpoint_wrapper
+ checkpoint_wrapper,
)
import transformer_engine.pytorch as te
@@ -29,46 +29,56 @@ rng_seed = 1234
torch.manual_seed(rng_seed)
torch.cuda.manual_seed(rng_seed)
CUDA_RNG_STATES_TRACKER = te.distributed.CudaRNGStatesTracker()
-CUDA_RNG_STATES_TRACKER.add('model-parallel-rng', rng_seed)
+CUDA_RNG_STATES_TRACKER.add("model-parallel-rng", rng_seed)
+
+
def get_cuda_rng_tracker():
return CUDA_RNG_STATES_TRACKER
+
def apply_fsdp_checkpointing(model, blocks):
"""apply activation checkpointing to model
returns None as model is updated directly
"""
- wrapper = lambda m: checkpoint_wrapper(m,
- checkpoint_fn=te.distributed.checkpoint,
- use_reentrant=False,
- get_rng_state_tracker=get_cuda_rng_tracker)
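+ # Recompute this module's activations in the backward pass via TE's non-reentrant checkpoint.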
+ wrapper = lambda m: checkpoint_wrapper(
+ m,
+ checkpoint_fn=te.distributed.checkpoint,
+ use_reentrant=False,
+ get_rng_state_tracker=get_cuda_rng_tracker,
+ )
check_fn = lambda submodule: isinstance(submodule, blocks)
apply_activation_checkpointing(model, checkpoint_wrapper_fn=wrapper, check_fn=check_fn)
+
def lowercase(s):
return str(s).lower()
+
def torch_dtype(d):
typemap = {
- 'fp32' : torch.float32,
- 'float32' : torch.float32,
- 'fp16' : torch.float16,
- 'float16' : torch.float16,
- 'bf16' : torch.bfloat16,
- 'bfloat16' : torch.bfloat16
+ "fp32": torch.float32,
+ "float32": torch.float32,
+ "fp16": torch.float16,
+ "float16": torch.float16,
+ "bf16": torch.bfloat16,
+ "bfloat16": torch.bfloat16,
}
if lowercase(d) not in typemap.keys():
raise TypeError
return typemap[lowercase(d)]
+
te_layer_map = {
- 'linear': te.Linear,
- 'layernorm': te.LayerNorm,
- 'rmsnorm': te.RMSNorm,
- 'layernormlinear': te.LayerNormLinear,
- 'layernormmlp': te.LayerNormMLP,
- 'multiheadattention': te.MultiheadAttention,
- 'transformerlayer': te.TransformerLayer
+ "linear": te.Linear,
+ "layernorm": te.LayerNorm,
+ "rmsnorm": te.RMSNorm,
+ "layernormlinear": te.LayerNormLinear,
+ "layernormmlp": te.LayerNormMLP,
+ "multiheadattention": te.MultiheadAttention,
+ "transformerlayer": te.TransformerLayer,
}
+
+
def te_layer(l):
if l is not None:
if lowercase(l) not in te_layer_map.keys():
@@ -76,74 +86,120 @@ def te_layer(l):
return te_layer_map[lowercase(l)]
return None
+
def get_layer_args(opts):
hidden_size = opts.num_heads * opts.head_dim
- layer_args = (hidden_size, )
+ layer_args = (hidden_size,)
layer_kwargs = {
- 'params_dtype': opts.dtype,
- 'device': 'cuda' if opts.no_defer_init else 'meta',
- 'get_rng_state_tracker': get_cuda_rng_tracker,
+ "params_dtype": opts.dtype,
+ "device": "cuda" if opts.no_defer_init else "meta",
+ "get_rng_state_tracker": get_cuda_rng_tracker,
}
if opts.layer_type in [te.Linear, te.LayerNormLinear, te.LayerNormMLP]:
ffn_hidden_size = 3 * hidden_size if opts.num_layers == 1 else hidden_size
- layer_args += (ffn_hidden_size, )
- layer_kwargs['bias'] = True
+ layer_args += (ffn_hidden_size,)
+ layer_kwargs["bias"] = True
if opts.layer_type == te.LayerNormMLP:
- layer_kwargs['seq_length'] = opts.seq_length
+ layer_kwargs["seq_length"] = opts.seq_length
elif opts.layer_type == te.MultiheadAttention:
- layer_args += (opts.num_heads, )
- layer_kwargs['fuse_qkv_params'] = True
- layer_kwargs['input_layernorm'] = True
+ layer_args += (opts.num_heads,)
+ layer_kwargs["fuse_qkv_params"] = True
+ layer_kwargs["input_layernorm"] = True
elif opts.layer_type == te.TransformerLayer:
layer_args += (3 * hidden_size, opts.num_heads)
- layer_kwargs['fuse_qkv_params'] = True
- layer_kwargs['seq_length'] = opts.seq_length
+ layer_kwargs["fuse_qkv_params"] = True
+ layer_kwargs["seq_length"] = opts.seq_length
return layer_args, layer_kwargs
+
def parse_fsdp_args():
- parser = argparse.ArgumentParser(description="Run Transformer Engine modules with the " +
- "torch.distributed.fsdp.FullyShardedDataParallel strategy.")
- parser.add_argument('-v', "--verbose", action="store_true", default=False,
- help="Print out information from all GPUs instead of only the root GPU-0.")
- parser.add_argument('-b', "--batch-size", type=int, default=32,
- help="Input batch size.")
- parser.add_argument('-s', "--seq-length", type=int, default=1048,
- help="Input sequence length.")
- parser.add_argument('-n', "--num-heads", type=int, default=16,
- help="Number of attention heads.")
- parser.add_argument('-d', "--head-dim", type=int, default=128,
- help="Dimension of each attention head (number of KV channels).")
- parser.add_argument('-i', "--num-iters", type=int, default=5,
- help="Number of dummy 'training' iterations.")
- parser.add_argument('-k', "--num-layers", type=int, default=3,
- help="Number of modules chained together with nn.Sequential.")
- parser.add_argument("--layer-type", type=te_layer, default=te.TransformerLayer,
- choices=list(te_layer_map.values()),
- help="TE module type used to construct the test model.")
- parser.add_argument("--seed", type=int, default=1234,
- help="PyTorch RNG seed.")
- parser.add_argument("--profile-memory", action="store_true",
- help="Enable memory profiling via torch.profiler.profile().")
- parser.add_argument("--profile-name", type=str, default=None,
- help="File path for memory profiling.")
- parser.add_argument("--checkpoint-layer", type=te_layer, default=None,
- help="Recompute activations of the selected layer during the backward " + \
- "pass instead of saving.")
- parser.add_argument("--no-fp8", action="store_true", default=False,
- help="Disables the te.fp8_autocast() context.")
- parser.add_argument("--no-defer-init", action="store_true",
- help="Defer module parameter initialization until after FSDP sharding.")
- parser.add_argument("--no-te-fsdp", action="store_true",
- help="Disable sharding of intermediate/activation tensors in TE modules.")
- parser.add_argument("--dtype", type=torch_dtype, default=torch.bfloat16,
- help="Data type for input tensor and Transformer Engine module parameters.")
+ parser = argparse.ArgumentParser(
+ description="Run Transformer Engine modules with the "
+ + "torch.distributed.fsdp.FullyShardedDataParallel strategy."
+ )
+ parser.add_argument(
+ "-v",
+ "--verbose",
+ action="store_true",
+ default=False,
+ help="Print out information from all GPUs instead of only the root GPU-0.",
+ )
+ parser.add_argument("-b", "--batch-size", type=int, default=32, help="Input batch size.")
+ parser.add_argument("-s", "--seq-length", type=int, default=1048, help="Input sequence length.")
+ parser.add_argument(
+ "-n", "--num-heads", type=int, default=16, help="Number of attention heads."
+ )
+ parser.add_argument(
+ "-d",
+ "--head-dim",
+ type=int,
+ default=128,
+ help="Dimension of each attention head (number of KV channels).",
+ )
+ parser.add_argument(
+ "-i", "--num-iters", type=int, default=5, help="Number of dummy 'training' iterations."
+ )
+ parser.add_argument(
+ "-k",
+ "--num-layers",
+ type=int,
+ default=3,
+ help="Number of modules chained together with nn.Sequential.",
+ )
+ parser.add_argument(
+ "--layer-type",
+ type=te_layer,
+ default=te.TransformerLayer,
+ choices=list(te_layer_map.values()),
+ help="TE module type used to construct the test model.",
+ )
+ parser.add_argument("--seed", type=int, default=1234, help="PyTorch RNG seed.")
+ parser.add_argument(
+ "--profile-memory",
+ action="store_true",
+ help="Enable memory profiling via torch.profiler.profile().",
+ )
+ parser.add_argument(
+ "--profile-name", type=str, default=None, help="File path for memory profiling."
+ )
+ parser.add_argument(
+ "--checkpoint-layer",
+ type=te_layer,
+ default=None,
+ help="Recompute activations of the selected layer during the backward "
+ + "pass instead of saving.",
+ )
+ parser.add_argument(
+ "--no-fp8",
+ action="store_true",
+ default=False,
+ help="Disables the te.fp8_autocast() context.",
+ )
+ parser.add_argument(
+ "--no-defer-init",
+ action="store_true",
+ help="Defer module parameter initialization until after FSDP sharding.",
+ )
+ parser.add_argument(
+ "--no-te-fsdp",
+ action="store_true",
+ help="Disable sharding of intermediate/activation tensors in TE modules.",
+ )
+ parser.add_argument(
+ "--dtype",
+ type=torch_dtype,
+ default=torch.bfloat16,
+ help="Data type for input tensor and Transformer Engine module parameters.",
+ )
return parser.parse_args()
+
def dist_print(text, all_ranks=False, no_new_line=False):
if LOCAL_RANK == 0 or all_ranks:
- end = '' if no_new_line else '\n'
+ end = "" if no_new_line else "\n"
print(f"[GPU-{LOCAL_RANK}] " + text, end=end)
+
def train(opts):
# Initialize torch.distributed global process group
dist.init_process_group(backend="nccl")
@@ -157,7 +213,7 @@ def train(opts):
te_layer_list = []
for i in range(opts.num_layers):
if opts.layer_type in [te.MultiheadAttention, te.TransformerLayer]:
- layer_kwargs['layer_number'] = i+1
+ layer_kwargs["layer_number"] = i + 1
te_layer_list.append(opts.layer_type(*layer_args, **layer_kwargs))
te_model = nn.Sequential(*te_layer_list)
else:
@@ -171,20 +227,23 @@ def train(opts):
# Wrap the model with FSDP
# NOTE: The TE model itself has no inherent parallelism. FSDP shards model parameters and
# controls all communication.
- all_gpus = dist.new_group(backend='nccl')
+ all_gpus = dist.new_group(backend="nccl")
fsdp_wrap_policy = always_wrap_policy
if opts.layer_type == te.TransformerLayer:
# NOTE: FSDP causes illegal memory access without this special policy for Transformers
- fsdp_wrap_policy = partial(transformer_auto_wrap_policy,
- transformer_layer_cls={te.TransformerLayer})
- te_model = FullyShardedDataParallel(te_model,
- process_group=all_gpus,
- use_orig_params=True,
- mixed_precision=MixedPrecision(
- param_dtype=opts.dtype,
- reduce_dtype=torch.float32,
- ),
- auto_wrap_policy=fsdp_wrap_policy)
+ fsdp_wrap_policy = partial(
+ transformer_auto_wrap_policy, transformer_layer_cls={te.TransformerLayer}
+ )
+ te_model = FullyShardedDataParallel(
+ te_model,
+ process_group=all_gpus,
+ use_orig_params=True,
+ mixed_precision=MixedPrecision(
+ param_dtype=opts.dtype,
+ reduce_dtype=torch.float32,
+ ),
+ auto_wrap_policy=fsdp_wrap_policy,
+ )
if opts.checkpoint_layer is not None:
# Recompute the activations of the selected layer during the backward pass instead of
@@ -218,8 +277,13 @@ def train(opts):
for i in range(opts.num_iters):
# Generate a random input batch
- x = torch.rand(opts.seq_length, opts.batch_size, opts.num_heads*opts.head_dim,
- dtype=opts.dtype, device='cuda')
+ x = torch.rand(
+ opts.seq_length,
+ opts.batch_size,
+ opts.num_heads * opts.head_dim,
+ dtype=opts.dtype,
+ device="cuda",
+ )
# fp8_autocast needs to be given the FSDP process group for amax reductions
with te.fp8_autocast(enabled=not opts.no_fp8, fp8_recipe=fp8_recipe, fp8_group=all_gpus):
y = te_model(x)
@@ -230,7 +294,6 @@ def train(opts):
optim.zero_grad(set_to_none=True)
del x
-
if opts.profile_memory:
torch.cuda.memory._dump_snapshot(f"gpu{LOCAL_RANK}_{opts.profile_name}.pickle")
torch.cuda.memory._record_memory_history(enabled=None)
@@ -238,7 +301,7 @@ def train(opts):
end.record()
torch.cuda.synchronize()
peak_mem = torch.cuda.max_memory_allocated()
- train_time = start.elapsed_time(end)/1000.
+ train_time = start.elapsed_time(end) / 1000.0
dist_print(f"Training Time: {train_time}s")
dist_print(f"Avg. Iter. Time: {train_time / opts.num_iters}s")
dist_print(f"Peak Memory Use: {peak_mem * 1e-6}MBs")
diff --git a/examples/pytorch/mnist/main.py b/examples/pytorch/mnist/main.py
index d8bfcfd4808791f9ff35af311b27162a6f0605d9..2a003f0a0d464b961c15f5a63ef0d07624e50126 100644
--- a/examples/pytorch/mnist/main.py
+++ b/examples/pytorch/mnist/main.py
@@ -79,6 +79,7 @@ def calibrate(model, device, test_loader, fp8):
with te.fp8_autocast(enabled=fp8, calibrating=True):
output = model(data)
+
def test(model, device, test_loader, use_fp8):
"""Testing function."""
model.eval()
@@ -89,12 +90,8 @@ def test(model, device, test_loader, use_fp8):
data, target = data.to(device), target.to(device)
with te.fp8_autocast(enabled=use_fp8):
output = model(data)
- test_loss += F.nll_loss(
- output, target, reduction="sum"
- ).item() # sum up batch loss
- pred = output.argmax(
- dim=1, keepdim=True
- ) # get the index of the max log-probability
+ test_loss += F.nll_loss(output, target, reduction="sum").item() # sum up batch loss
+ pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability
correct += pred.eq(target.view_as(pred)).sum().item()
test_loss /= len(test_loader.dataset)
@@ -150,9 +147,7 @@ def main():
default=False,
help="quickly check a single pass",
)
- parser.add_argument(
- "--seed", type=int, default=1, metavar="S", help="random seed (default: 1)"
- )
+ parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)")
parser.add_argument(
"--log-interval",
type=int,
@@ -167,7 +162,10 @@ def main():
help="For Saving the current Model",
)
parser.add_argument(
- "--use-fp8", action="store_true", default=False, help="Use FP8 for inference and training without recalibration"
+ "--use-fp8",
+ action="store_true",
+ default=False,
+ help="Use FP8 for inference and training without recalibration",
)
parser.add_argument(
"--use-fp8-infer", action="store_true", default=False, help="Use FP8 inference only"
@@ -215,7 +213,7 @@ def main():
if args.save_model or args.use_fp8_infer:
torch.save(model.state_dict(), "mnist_cnn.pt")
- print('Eval with reloaded checkpoint : fp8='+str(args.use_fp8_infer))
+ print("Eval with reloaded checkpoint : fp8=" + str(args.use_fp8_infer))
weights = torch.load("mnist_cnn.pt")
model.load_state_dict(weights)
test(model, device, test_loader, args.use_fp8_infer)
diff --git a/qa/L0_license/copyright_checker.py b/qa/L0_license/copyright_checker.py
index 310610fac6f6ae1197c32c6e821fa14428e3125e..46a3a6d4fe9005005df43d74c5502d470a885261 100644
--- a/qa/L0_license/copyright_checker.py
+++ b/qa/L0_license/copyright_checker.py
@@ -18,21 +18,26 @@ path = sys.argv[1]
config_path = os.path.dirname(os.path.realpath(__file__)) + "/config.json"
+
class bcolors:
- OKGREEN = '\033[92m'
- WARNING = '\033[93m'
- FAIL = '\033[91m'
- ENDC = '\033[0m'
+ OKGREEN = "\033[92m"
+ WARNING = "\033[93m"
+ FAIL = "\033[91m"
+ ENDC = "\033[0m"
+
def print_ok(msg):
print(f"{bcolors.OKGREEN}{msg}{bcolors.ENDC}")
+
def print_fail(msg):
print(f"{bcolors.FAIL}{msg}{bcolors.ENDC}")
+
def print_warn(msg):
print(f"{bcolors.WARNING}{msg}{bcolors.ENDC}")
+
with open(config_path, "r") as f:
c = json.load(f)
current_year = datetime.date.today().year
@@ -41,7 +46,7 @@ with open(config_path, "r") as f:
else:
year_string = str(c["initial_year"]) + "-" + str(current_year)
copyright_string = c["copyright"].replace("<YEAR>", year_string)
- license = c["license"].split('\n')
+ license = c["license"].split("\n")
excludes = c["exclude"]
root_path = os.path.abspath(path)
copyright_only = c["copyright_only"]
@@ -49,36 +54,42 @@ with open(config_path, "r") as f:
has_gitignore = os.path.exists(root_path + "/.gitignore")
+
def strip_star_slash(s):
ret = s
- if ret.startswith('*'):
+ if ret.startswith("*"):
ret = ret[1:]
- if ret.endswith('/'):
+ if ret.endswith("/"):
ret = ret[:-1]
return ret
+
if has_gitignore:
with open(root_path + "/.gitignore", "r") as f:
for line in f.readlines():
excludes.append(strip_star_slash(line.strip()))
+
def get_file_type(path):
- ext = {"c": ["c", "cpp", "cu", "h", "cuh"],
- "py": ["py"],
- "rst": ["rst"],
- "txt": ["txt"],
- "cfg": ["cfg"],
- "sh": ["sh"],
- "md": ["md"],
- }
+ ext = {
+ "c": ["c", "cpp", "cu", "h", "cuh"],
+ "py": ["py"],
+ "rst": ["rst"],
+ "txt": ["txt"],
+ "cfg": ["cfg"],
+ "sh": ["sh"],
+ "md": ["md"],
+ }
tmp = path.split(".")
for filetype, ext_list in ext.items():
if tmp[-1] in ext_list:
return filetype
return "unknown"
+
success = True
+
def check_file(path):
global success
N = 10
@@ -127,9 +138,10 @@ def check_file(path):
if copyright_found and license_found:
print_ok("OK")
+
for root, dirs, files in os.walk(root_path):
print(f"Entering {root}")
- hidden = [d for d in dirs if d.startswith('.')] + [f for f in files if f.startswith('.')]
+ hidden = [d for d in dirs if d.startswith(".")] + [f for f in files if f.startswith(".")]
all_excludes = excludes + hidden
to_remove = []
for d in dirs:
diff --git a/setup.py b/setup.py
index d6493447633ca243e1de4be0b4a6e351d734839d..d2cc91d65a16357dd0bd54eb6467d5cad2eeaa2b 100644
--- a/setup.py
+++ b/setup.py
@@ -27,12 +27,13 @@ current_file_path = Path(__file__).parent.resolve()
from setuptools.command.build_ext import build_ext as BuildExtension
+
if "pytorch" in frameworks:
from torch.utils.cpp_extension import BuildExtension
elif "paddle" in frameworks:
from paddle.utils.cpp_extension import BuildExtension
elif "jax" in frameworks:
- install_and_import('pybind11')
+ install_and_import("pybind11")
from pybind11.setup_helpers import build_ext as BuildExtension
@@ -86,34 +87,45 @@ if __name__ == "__main__":
if not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))):
if "pytorch" in frameworks:
from build_tools.pytorch import setup_pytorch_extension
+
ext_modules.append(
setup_pytorch_extension(
"transformer_engine/pytorch/csrc",
current_file_path / "transformer_engine" / "pytorch" / "csrc",
- current_file_path / "transformer_engine"))
+ current_file_path / "transformer_engine",
+ )
+ )
if "jax" in frameworks:
from build_tools.jax import setup_jax_extension
+
ext_modules.append(
setup_jax_extension(
"transformer_engine/jax/csrc",
current_file_path / "transformer_engine" / "jax" / "csrc",
- current_file_path / "transformer_engine"))
+ current_file_path / "transformer_engine",
+ )
+ )
if "paddle" in frameworks:
from build_tools.paddle import setup_paddle_extension
+
ext_modules.append(
setup_paddle_extension(
"transformer_engine/paddle/csrc",
current_file_path / "transformer_engine" / "paddle" / "csrc",
- current_file_path / "transformer_engine"))
+ current_file_path / "transformer_engine",
+ )
+ )
# Configure package
setuptools.setup(
name="transformer_engine",
version=__version__,
packages=setuptools.find_packages(
- include=["transformer_engine",
- "transformer_engine.*",
- "transformer_engine/build_tools"],
+ include=[
+ "transformer_engine",
+ "transformer_engine.*",
+ "transformer_engine/build_tools",
+ ],
),
extras_require={
"test": test_requires,
@@ -125,5 +137,5 @@ if __name__ == "__main__":
install_requires=install_requires,
license_files=("LICENSE",),
include_package_data=True,
- package_data={"": ["VERSION.txt"]}
+ package_data={"": ["VERSION.txt"]},
)
diff --git a/tests/cpp/operator/test_causal_softmax.cu b/tests/cpp/operator/test_causal_softmax.cu
index 1d7d70e5722b527c46e938efe27a9aac23d0fb93..640434674b40e6aa48d27cc27a5ba59611372699 100644
--- a/tests/cpp/operator/test_causal_softmax.cu
+++ b/tests/cpp/operator/test_causal_softmax.cu
@@ -132,7 +132,7 @@ void compute_bwd_ref(
for (int b = 0; b < batches; ++b) {
for (int h = 0; h < heads; ++h) {
size_t offset = b * batch_size + h * head_size;
- compute_single_head_bwd(grad_out + offset, grad_in + offset, softmax_in + offset,
+ compute_single_head_bwd(grad_out + offset, grad_in + offset, softmax_in + offset,
buff + offset, scaling_factor, batches, heads, rows, cols);
}
}
diff --git a/tests/jax/conftest.py b/tests/jax/conftest.py
index 5f1aaa4c397751c4b1909c90dd3c272d2472ce00..55494c42d631245c1aedc532bbc23024b2326319 100644
--- a/tests/jax/conftest.py
+++ b/tests/jax/conftest.py
@@ -6,7 +6,7 @@ import jax
import pytest
-@pytest.fixture(autouse=True, scope='function')
+@pytest.fixture(autouse=True, scope="function")
def clear_live_arrays():
"""
Clear all live arrays to keep the resource clean
diff --git a/tests/jax/distributed_test_base.py b/tests/jax/distributed_test_base.py
index 1360af19fadc955664f05210f3b41eddd98ffd3b..3a7fe333785a77c3e6dcf016bac331e9f0424d32 100644
--- a/tests/jax/distributed_test_base.py
+++ b/tests/jax/distributed_test_base.py
@@ -16,15 +16,15 @@ from utils import assert_allclose, is_devices_enough
def generate_configs():
configs = []
if is_devices_enough(2):
- configs.append([2, (2,), ('dp'), MeshResource(dp_resource='dp')])
- configs.append([2, (2,), ('tp'), MeshResource(tp_resource='tp')])
-
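+ # Each config entry: (device_count, mesh_shape, mesh_axis_names, MeshResource).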
+ configs.append([2, (2,), "dp", MeshResource(dp_resource="dp")])
+ configs.append([2, (2,), "tp", MeshResource(tp_resource="tp")])
+
if is_devices_enough(4):
TP_size = 2
DP_size = 2
configs.append(
- [4, (DP_size, TP_size), ('dp', 'tp'),
- MeshResource(dp_resource='dp', tp_resource='tp')])
+ [4, (DP_size, TP_size), ("dp", "tp"), MeshResource(dp_resource="dp", tp_resource="tp")]
+ )
return configs
@@ -46,7 +46,7 @@ def assert_equal_collectives(target_hlo, coll_count_ref):
bytes_count = 0
def get_bytes_per_txt(t):
- '''
+ """
The pattern of t would be like:
'f32[]',
'(f32[1024]{0}',
@@ -54,24 +54,24 @@ def assert_equal_collectives(target_hlo, coll_count_ref):
'f8E4M3FN[1024]{0}',
'i32[1024]{0}',
'bf16[1024,1024]{0}'
- '''
- match = re.search(r'(i|f)(\d+).*\[([0-9,]*)\]', t)
+ """
+ match = re.search(r"(i|f)(\d+).*\[([0-9,]*)\]", t)
_, bits_of_type, shape = match.groups()
bytes_of_type = int(bits_of_type) // 8
- if shape == '':
+ if shape == "":
num_of_elements = 1
else:
- num_of_elements = reduce(operator.mul, map(int, shape.split(',')))
+ num_of_elements = reduce(operator.mul, map(int, shape.split(",")))
return bytes_of_type * num_of_elements
# ['xxx-start', '=', '(bf16[xxx]', 'bf16[xxx])', 'xxx-start(', ...]
- if '(' in hlo_text[2]:
+ if "(" in hlo_text[2]:
for txt in hlo_text[2:]:
bytes_count += get_bytes_per_txt(txt)
- if ')' in txt:
+ if ")" in txt:
break
- else: # ['xxx-start', '=', 'fp32[]', 'xxx-start(', ...]
+ else: # ['xxx-start', '=', 'fp32[]', 'xxx-start(', ...]
bytes_count = get_bytes_per_txt(hlo_text[2])
return bytes_count
@@ -91,21 +91,24 @@ def assert_equal_collectives(target_hlo, coll_count_ref):
return result
target_result = count_collectives(target_splitted_hlo)
- assert target_result == coll_count_ref, \
- f"Expected collective count is {coll_count_ref}, but got {target_result}."
-
-
-def compare_ops(target_func,
- ref_func,
- inputs,
- coll_count_ref,
- *,
- grad_args=None,
- metric_fwd_dtype=None,
- metric_bwd_dtype=None,
- in_shardings=_UNSPECIFIED,
- out_shardings=_UNSPECIFIED,
- **kwargs):
+ assert (
+ target_result == coll_count_ref
+ ), f"Expected collective count is {coll_count_ref}, but got {target_result}."
+
+
+def compare_ops(
+ target_func,
+ ref_func,
+ inputs,
+ coll_count_ref,
+ *,
+ grad_args=None,
+ metric_fwd_dtype=None,
+ metric_bwd_dtype=None,
+ in_shardings=_UNSPECIFIED,
+ out_shardings=_UNSPECIFIED,
+ **kwargs,
+):
assert len(inputs) >= 1
if metric_fwd_dtype is None:
diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py
index 1e0b424e2fc92b5cfe9e6a454329ce4f27d285cb..8664a03f8d6af2b0c92f7e7ac9ff420e11a8526a 100644
--- a/tests/jax/test_custom_call_compute.py
+++ b/tests/jax/test_custom_call_compute.py
@@ -15,24 +15,10 @@ from jax import jit, value_and_grad
from flax import linen as nn
from utils import assert_allclose
-from transformer_engine.jax.dot import (
- type_safe_dot_general,
- dequantize,
- quantize
-)
-from transformer_engine.jax.fp8 import (
- FP8MetaPackage,
- FP8Helper,
- is_fp8_available
-)
-from transformer_engine.jax.layernorm import (
- layernorm,
- layernorm_fp8_dot
-)
-from transformer_engine.jax.layernorm_mlp import (
- activation_lu,
- fused_layernorm_fp8_mlp
-)
+from transformer_engine.jax.dot import type_safe_dot_general, dequantize, quantize
+from transformer_engine.jax.fp8 import FP8MetaPackage, FP8Helper, is_fp8_available
+from transformer_engine.jax.layernorm import layernorm, layernorm_fp8_dot
+from transformer_engine.jax.layernorm_mlp import activation_lu, fused_layernorm_fp8_mlp
from transformer_engine.jax import cpp_extensions as tex
GEMM_CASES = [
@@ -50,11 +36,11 @@ is_fp8_supported, reason = is_fp8_available()
def _convert_to_activation_function(fn_or_string):
"""Convert a string to an activation function."""
- if fn_or_string == 'linear':
+ if fn_or_string == "linear":
return lambda x: x
- if fn_or_string == 'quick_gelu':
+ if fn_or_string == "quick_gelu":
return lambda x: nn.gelu(x, approximate=True)
- if fn_or_string == 'squared_relu':
+ if fn_or_string == "squared_relu":
return lambda x: functools.reduce(operator.mul, [nn.relu(x), nn.relu(x)])
if isinstance(fn_or_string, str):
return getattr(nn, fn_or_string)
@@ -93,7 +79,7 @@ class TestFP8Dot:
assert_allclose(z, x, dtype=jnp.float8_e4m3fn)
- @pytest.mark.parametrize('m,n,k', GEMM_CASES)
+ @pytest.mark.parametrize("m,n,k", GEMM_CASES)
def test_forward_bf16(self, m, n, k):
key = jax.random.PRNGKey(0)
subkeys = jax.random.split(key, 2)
@@ -106,7 +92,7 @@ class TestFP8Dot:
assert_allclose(primitive_out, ref_out, dtype=jnp.bfloat16)
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
- @pytest.mark.parametrize('m,n,k', GEMM_CASES)
+ @pytest.mark.parametrize("m,n,k", GEMM_CASES)
def test_forward_fp8_randint(self, m, n, k):
key = jax.random.PRNGKey(0)
subkeys = jax.random.split(key, 2)
@@ -135,7 +121,7 @@ class TestFP8Dot:
assert_allclose(primitive_out, ref_out, dtype=FP8Helper.FWD_DTYPE)
- @pytest.mark.parametrize('m,n,k', GEMM_CASES)
+ @pytest.mark.parametrize("m,n,k", GEMM_CASES)
def test_grad_bf16(self, m, n, k):
key = jax.random.PRNGKey(0)
subkeys = jax.random.split(key, 2)
@@ -161,7 +147,7 @@ class TestFP8Dot:
assert_allclose(primitive_b_grad, ref_b_grad, dtype=jnp.bfloat16)
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
- @pytest.mark.parametrize('m,n,k', GEMM_CASES)
+ @pytest.mark.parametrize("m,n,k", GEMM_CASES)
def test_grad_fp8_dot(self, m, n, k):
key = jax.random.PRNGKey(0)
subkeys = jax.random.split(key, 2)
@@ -192,33 +178,38 @@ class TestFP8Dot:
ref_out, (ref_a_grad, ref_b_grad) = value_n_grad_ref_func(a, b)
for _ in range(3):
- primitive_out, (primitive_a_grad, primitive_b_grad, amax_list,
- scale_list) = value_n_grad_primitive_func(a, b, amax_list, scale_list)
+ primitive_out, (primitive_a_grad, primitive_b_grad, amax_list, scale_list) = (
+ value_n_grad_primitive_func(a, b, amax_list, scale_list)
+ )
assert_allclose(primitive_out, ref_out, dtype=FP8Helper.FWD_DTYPE)
assert_allclose(primitive_a_grad, ref_a_grad, dtype=FP8Helper.BWD_DTYPE)
assert_allclose(primitive_b_grad, ref_b_grad, dtype=FP8Helper.BWD_DTYPE)
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
- @pytest.mark.parametrize('m,n,k', [(256, 128, 512),
- (16384, 1024, 2816),
- (16384, 2816, 1024),
- (16384, 1024, 1024)])
- @pytest.mark.parametrize('activation_type', [('gelu', ),
- ('gelu', 'linear'),
- ('silu', ),
- ('silu', 'linear'),
- ('relu',),
- ('relu', 'linear'),
- ('quick_gelu',),
- ('quick_gelu', 'linear'),
- ('squared_relu',),
- ('squared_relu', 'linear')])
- @pytest.mark.parametrize('use_bias', [True, False])
- def test_grad_fused_layernorm_fp8_mlp(self, m, n, k, activation_type: Sequence[Union[str,
- Callable]],
- use_bias: bool):
- """ N/a """
+ @pytest.mark.parametrize(
+ "m,n,k", [(256, 128, 512), (16384, 1024, 2816), (16384, 2816, 1024), (16384, 1024, 1024)]
+ )
+ @pytest.mark.parametrize(
+ "activation_type",
+ [
+ ("gelu",),
+ ("gelu", "linear"),
+ ("silu",),
+ ("silu", "linear"),
+ ("relu",),
+ ("relu", "linear"),
+ ("quick_gelu",),
+ ("quick_gelu", "linear"),
+ ("squared_relu",),
+ ("squared_relu", "linear"),
+ ],
+ )
+ @pytest.mark.parametrize("use_bias", [True, False])
+ def test_grad_fused_layernorm_fp8_mlp(
+ self, m, n, k, activation_type: Sequence[Union[str, Callable]], use_bias: bool
+ ):
+ """N/a"""
key = jax.random.PRNGKey(0)
subkeys = jax.random.split(key, 6)
@@ -233,8 +224,9 @@ class TestFP8Dot:
b1 = None
b2 = None
- def primitive_func(x, ln_s, y, z, w, v, amax_list_1, amax_list_2, scale_list_1,
- scale_list_2):
+ def primitive_func(
+ x, ln_s, y, z, w, v, amax_list_1, amax_list_2, scale_list_1, scale_list_2
+ ):
# x is input tensor, matrix 2d
# y, z are weights, matrix 2d
# out = ((x * y) + w) * z + v
@@ -255,18 +247,31 @@ class TestFP8Dot:
scale_list_2[2],
)
return jnp.mean(
- fused_layernorm_fp8_mlp(x,
- ln_s,
- None, [y, z], [w, v], [fp8_meta_pkg_1, fp8_meta_pkg_2],
- "rmsnorm",
- activation_type=activation_type,
- use_bias=use_bias))
-
- def layernorm_fp8_mlp_ref(x: jnp.ndarray, ln_scale: jnp.ndarray, kernel_1: jnp.ndarray,
- kernel_2: jnp.ndarray, bias_1: jnp.ndarray, bias_2: jnp.ndarray,
- amax_list_1: List[jnp.ndarray], amax_list_2: List[jnp.ndarray],
- scale_list_1: List[jnp.ndarray],
- scale_list_2: List[jnp.ndarray]) -> jnp.ndarray:
+ fused_layernorm_fp8_mlp(
+ x,
+ ln_s,
+ None,
+ [y, z],
+ [w, v],
+ [fp8_meta_pkg_1, fp8_meta_pkg_2],
+ "rmsnorm",
+ activation_type=activation_type,
+ use_bias=use_bias,
+ )
+ )
+
+ def layernorm_fp8_mlp_ref(
+ x: jnp.ndarray,
+ ln_scale: jnp.ndarray,
+ kernel_1: jnp.ndarray,
+ kernel_2: jnp.ndarray,
+ bias_1: jnp.ndarray,
+ bias_2: jnp.ndarray,
+ amax_list_1: List[jnp.ndarray],
+ amax_list_2: List[jnp.ndarray],
+ scale_list_1: List[jnp.ndarray],
+ scale_list_2: List[jnp.ndarray],
+ ) -> jnp.ndarray:
x = jnp.asarray(x, jnp.float32)
mean2 = jnp.mean(jax.lax.square(x), axis=-1, keepdims=True)
@@ -315,11 +320,14 @@ class TestFP8Dot:
def ref_func(x, ln_s, y, z, w, v, amax_list_1, amax_list_2, scale_list_1, scale_list_2):
return jnp.mean(
- layernorm_fp8_mlp_ref(x, ln_s, y, z, w, v, amax_list_1, amax_list_2, scale_list_1,
- scale_list_2))
+ layernorm_fp8_mlp_ref(
+ x, ln_s, y, z, w, v, amax_list_1, amax_list_2, scale_list_1, scale_list_2
+ )
+ )
value_n_grad_primitive_func = jit(
- value_and_grad(primitive_func, (0, 1, 2, 3, 4, 5, 6, 7, 8, 9)))
+ value_and_grad(primitive_func, (0, 1, 2, 3, 4, 5, 6, 7, 8, 9))
+ )
value_n_grad_ref_func = jit(value_and_grad(ref_func, (0, 1, 2, 3, 4, 5, 6, 7, 8, 9)))
_, amax_list_1, scale_list_1 = TestFP8Dot._generate_fp8_meta()
@@ -339,40 +347,87 @@ class TestFP8Dot:
# Convert str to index as str is not a valid type for JAX JIT
for _ in range(3):
- ref_out, (ref_a_grad, ref_s_grad, ref_k1_grad, ref_k2_grad, ref_b1_grad, ref_b2_grad,
- ref_amax_list_1, ref_amax_list_2, ref_scale_list_1,
- ref_scale_list_2) = value_n_grad_ref_func(a, s, k1, k2, b1, b2,
- ref_amax_list_1, ref_amax_list_2,
- ref_scale_list_1, ref_scale_list_2)
+ ref_out, (
+ ref_a_grad,
+ ref_s_grad,
+ ref_k1_grad,
+ ref_k2_grad,
+ ref_b1_grad,
+ ref_b2_grad,
+ ref_amax_list_1,
+ ref_amax_list_2,
+ ref_scale_list_1,
+ ref_scale_list_2,
+ ) = value_n_grad_ref_func(
+ a,
+ s,
+ k1,
+ k2,
+ b1,
+ b2,
+ ref_amax_list_1,
+ ref_amax_list_2,
+ ref_scale_list_1,
+ ref_scale_list_2,
+ )
for _ in range(3):
- primitive_out, (primitive_a_grad, primitive_s_grad, primitive_k1_grad,
- primitive_k2_grad, primitive_b1_grad, primitive_b2_grad,
- primitive_amax_list_1, primitive_amax_list_2, primitive_scale_list_1,
- primitive_scale_list_2) = value_n_grad_primitive_func(
- a, s, k1, k2, b1, b2, primitive_amax_list_1, primitive_amax_list_2,
- primitive_scale_list_1, primitive_scale_list_2)
+ primitive_out, (
+ primitive_a_grad,
+ primitive_s_grad,
+ primitive_k1_grad,
+ primitive_k2_grad,
+ primitive_b1_grad,
+ primitive_b2_grad,
+ primitive_amax_list_1,
+ primitive_amax_list_2,
+ primitive_scale_list_1,
+ primitive_scale_list_2,
+ ) = value_n_grad_primitive_func(
+ a,
+ s,
+ k1,
+ k2,
+ b1,
+ b2,
+ primitive_amax_list_1,
+ primitive_amax_list_2,
+ primitive_scale_list_1,
+ primitive_scale_list_2,
+ )
assert_allclose(primitive_out, ref_out, dtype=FP8Helper.FWD_DTYPE)
- assert_allclose(jnp.asarray(primitive_a_grad, np.float32),
- jnp.asarray(ref_a_grad, np.float32),
- dtype=FP8Helper.BWD_DTYPE)
- assert_allclose(jnp.asarray(primitive_k1_grad, np.float32),
- jnp.asarray(ref_k1_grad, np.float32),
- dtype=FP8Helper.BWD_DTYPE)
- assert_allclose(jnp.asarray(primitive_s_grad, np.float32),
- jnp.asarray(ref_s_grad, np.float32),
- dtype=FP8Helper.BWD_DTYPE)
- assert_allclose(jnp.asarray(primitive_k2_grad, np.float32),
- jnp.asarray(ref_k2_grad, np.float32),
- dtype=FP8Helper.BWD_DTYPE)
+ assert_allclose(
+ jnp.asarray(primitive_a_grad, np.float32),
+ jnp.asarray(ref_a_grad, np.float32),
+ dtype=FP8Helper.BWD_DTYPE,
+ )
+ assert_allclose(
+ jnp.asarray(primitive_k1_grad, np.float32),
+ jnp.asarray(ref_k1_grad, np.float32),
+ dtype=FP8Helper.BWD_DTYPE,
+ )
+ assert_allclose(
+ jnp.asarray(primitive_s_grad, np.float32),
+ jnp.asarray(ref_s_grad, np.float32),
+ dtype=FP8Helper.BWD_DTYPE,
+ )
+ assert_allclose(
+ jnp.asarray(primitive_k2_grad, np.float32),
+ jnp.asarray(ref_k2_grad, np.float32),
+ dtype=FP8Helper.BWD_DTYPE,
+ )
if use_bias:
- assert_allclose(jnp.asarray(primitive_b2_grad, np.float32),
- jnp.asarray(ref_b2_grad, np.float32),
- dtype=FP8Helper.BWD_DTYPE)
- assert_allclose(jnp.asarray(primitive_b1_grad, np.float32),
- jnp.asarray(ref_b1_grad, np.float32),
- dtype=FP8Helper.BWD_DTYPE)
+ assert_allclose(
+ jnp.asarray(primitive_b2_grad, np.float32),
+ jnp.asarray(ref_b2_grad, np.float32),
+ dtype=FP8Helper.BWD_DTYPE,
+ )
+ assert_allclose(
+ jnp.asarray(primitive_b1_grad, np.float32),
+ jnp.asarray(ref_b1_grad, np.float32),
+ dtype=FP8Helper.BWD_DTYPE,
+ )
@pytest.fixture(name="random_inputs")
@@ -402,17 +457,22 @@ class TestActivationLu:
def primitive_func(self, inputs):
return jnp.mean(activation_lu(inputs, activation_type=self.activation_type))
- @pytest.mark.parametrize('shape', [(32, 1, 64), (64, 1, 256)])
- @pytest.mark.parametrize('activation_type', [('gelu',),
- ('gelu', 'linear'),
- ('silu',),
- ('silu', 'linear'),
- ('relu',),
- ('relu', 'linear'),
- ('quick_gelu',),
- ('quick_gelu', 'linear'),
- ('squared_relu',),
- ('squared_relu', 'linear') ])
+ @pytest.mark.parametrize("shape", [(32, 1, 64), (64, 1, 256)])
+ @pytest.mark.parametrize(
+ "activation_type",
+ [
+ ("gelu",),
+ ("gelu", "linear"),
+ ("silu",),
+ ("silu", "linear"),
+ ("relu",),
+ ("relu", "linear"),
+ ("quick_gelu",),
+ ("quick_gelu", "linear"),
+ ("squared_relu",),
+ ("squared_relu", "linear"),
+ ],
+ )
def test_activation_lu(self, random_inputs, activation_type):
x = random_inputs
x = jnp.repeat(x, len(activation_type), axis=1)
@@ -441,23 +501,34 @@ class TestActivationLuFP8(TestActivationLu):
return output
def _prim_func_fwd(x, _x_t, _dbias, _amax):
- activation_lu_out, _ = tex.act_lu_fp8(x, amax, scale, scale_inv,
- FP8Helper.FWD_DTYPE, activation_type)
+ activation_lu_out, _ = tex.act_lu_fp8(
+ x, amax, scale, scale_inv, FP8Helper.FWD_DTYPE, activation_type
+ )
activation_lu_out = dequantize(activation_lu_out, x.dtype, scale_inv)
- ctx = (x)
+ ctx = x
return activation_lu_out, ctx
def _prim_func_bwd(ctx, g):
x = ctx
- if len(self.activation_type) > 1: #gated, no bias
- dactivation_lu, dactivation_lu_trans, amax_out = \
- tex.dgated_act_lu_cast_transpose(g, x, amax, scale, scale_inv,
- FP8Helper.BWD_DTYPE, -1, activation_type)
+ if len(self.activation_type) > 1: # gated, no bias
+ dactivation_lu, dactivation_lu_trans, amax_out = tex.dgated_act_lu_cast_transpose(
+ g, x, amax, scale, scale_inv, FP8Helper.BWD_DTYPE, -1, activation_type
+ )
dbias = jnp.empty(x.shape[-1], x.dtype)
- else: #not gated, with bias
- dactivation_lu, dactivation_lu_trans, dbias, amax_out = \
- tex.dact_lu_dbias_cast_transpose(g, x, amax, scale, scale_inv, FP8Helper.BWD_DTYPE,
- -1, -2, self.activation_type)
+ else: # not gated, with bias
+ dactivation_lu, dactivation_lu_trans, dbias, amax_out = (
+ tex.dact_lu_dbias_cast_transpose(
+ g,
+ x,
+ amax,
+ scale,
+ scale_inv,
+ FP8Helper.BWD_DTYPE,
+ -1,
+ -2,
+ self.activation_type,
+ )
+ )
dactivation_lu = dequantize(dactivation_lu, x.dtype, scale_inv)
dactivation_lu_trans = dequantize(dactivation_lu_trans, x.dtype, scale_inv)
ctx = (dactivation_lu, dactivation_lu_trans, dbias, amax_out)
@@ -468,23 +539,28 @@ class TestActivationLuFP8(TestActivationLu):
dx_trans_no_use = jnp.empty([x.shape[i] for i in self.transpose_indices], dtype=x.dtype)
dbias_no_use = jnp.empty(x.shape[-1], dtype=x.dtype)
amax_no_use = jnp.zeros(1, jnp.float32)
- value_n_grad_primitive_func = value_and_grad(lambda a, b, c, d:
- jnp.mean(_prim_func(a, b, c, d)), (0, 1, 2, 3))
+ value_n_grad_primitive_func = value_and_grad(
+ lambda a, b, c, d: jnp.mean(_prim_func(a, b, c, d)), (0, 1, 2, 3)
+ )
return value_n_grad_primitive_func(x, dx_trans_no_use, dbias_no_use, amax_no_use)
-
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
- @pytest.mark.parametrize('shape', [(32, 1, 64), (64, 1, 256)])
- @pytest.mark.parametrize('activation_type', [('gelu',),
- ('gelu', 'linear'),
- ('silu',),
- ('silu', 'linear'),
- ('relu',),
- ('relu', 'linear'),
- ('quick_gelu',),
- ('quick_gelu', 'linear'),
- ('squared_relu',),
- ('squared_relu', 'linear') ])
+ @pytest.mark.parametrize("shape", [(32, 1, 64), (64, 1, 256)])
+ @pytest.mark.parametrize(
+ "activation_type",
+ [
+ ("gelu",),
+ ("gelu", "linear"),
+ ("silu",),
+ ("silu", "linear"),
+ ("relu",),
+ ("relu", "linear"),
+ ("quick_gelu",),
+ ("quick_gelu", "linear"),
+ ("squared_relu",),
+ ("squared_relu", "linear"),
+ ],
+ )
def test_activation_lu(self, random_inputs, activation_type):
self.amax = jnp.zeros(1, jnp.float32)
self.scale = jnp.ones(1, jnp.float32)
@@ -500,12 +576,14 @@ class TestActivationLuFP8(TestActivationLu):
assert_allclose(prim_out, ref_out, dtype=FP8Helper.FWD_DTYPE)
assert_allclose(amax, jnp.amax(jnp.abs(ref_grad)), rtol=1e-2)
- if 'linear' not in activation_type:
+ if "linear" not in activation_type:
assert_allclose(dbias, jnp.sum(ref_grad, axis=(i for i in range(x.ndim - 1))))
assert_allclose(prim_grad, ref_grad, dtype=FP8Helper.BWD_DTYPE)
- assert_allclose(prim_grad_trans,
- jnp.transpose(ref_grad, self.transpose_indices),
- dtype=FP8Helper.BWD_DTYPE)
+ assert_allclose(
+ prim_grad_trans,
+ jnp.transpose(ref_grad, self.transpose_indices),
+ dtype=FP8Helper.BWD_DTYPE,
+ )
class TestNorm:
@@ -536,34 +614,38 @@ class TestNorm:
"""
x_ = jnp.asarray(x, jnp.float32)
if bias is None:
- mean = 0.
+ mean = 0.0
else:
mean = jnp.mean(x_, axis=-1, keepdims=True)
var = jnp.mean(jnp.square(x_ - mean), axis=-1, keepdims=True)
normed_input = (x_ - mean) * jax.lax.rsqrt(var + eps)
if zero_centered_gamma:
- scale += 1.
+ scale += 1.0
if bias is None:
- bias = 0.
+ bias = 0.0
return jnp.asarray(normed_input * scale + bias).astype(x.dtype)
- @pytest.mark.parametrize('n, hidden', LN_CASES)
- @pytest.mark.parametrize('dtype', DTYPES)
- @pytest.mark.parametrize('ln_type', ['layernorm', 'rmsnorm'])
- @pytest.mark.parametrize('zero_centered_gamma', [False, True])
- @pytest.mark.parametrize('epsilon', [1e-2, 1e-6])
- def test_layernorm_forward_backward(self, n, hidden, ln_type, zero_centered_gamma, epsilon,
- dtype):
+ @pytest.mark.parametrize("n, hidden", LN_CASES)
+ @pytest.mark.parametrize("dtype", DTYPES)
+ @pytest.mark.parametrize("ln_type", ["layernorm", "rmsnorm"])
+ @pytest.mark.parametrize("zero_centered_gamma", [False, True])
+ @pytest.mark.parametrize("epsilon", [1e-2, 1e-6])
+ def test_layernorm_forward_backward(
+ self, n, hidden, ln_type, zero_centered_gamma, epsilon, dtype
+ ):
"""
Test transformer_engine.jax.layernorm.layernorm
"""
expect_assert = False
- if ln_type == 'rmsnorm' and zero_centered_gamma:
+ if ln_type == "rmsnorm" and zero_centered_gamma:
# zero_centered_gamma is not supported for rmsnorm, expect an assertion.
expect_assert = True
- with pytest.raises(AssertionError, match=r".*zero_centered_gamma is not supported.*"
- ) if expect_assert else nullcontext():
+ with (
+ pytest.raises(AssertionError, match=r".*zero_centered_gamma is not supported.*")
+ if expect_assert
+ else nullcontext()
+ ):
key = jax.random.PRNGKey(0)
subkeys = jax.random.split(key, 3)
@@ -571,7 +653,7 @@ class TestNorm:
gamma_range = (-1, 1) if zero_centered_gamma else (0, 2)
gamma = jax.random.uniform(subkeys[1], (hidden,), jnp.float32, *gamma_range)
gamma = jnp.asarray(gamma, dtype)
- if ln_type == 'layernorm':
+ if ln_type == "layernorm":
beta = jax.random.uniform(subkeys[2], (hidden,), jnp.float32, -1, 1)
beta = jnp.asarray(beta, dtype)
else:
@@ -585,19 +667,27 @@ class TestNorm:
jitted_primitive = jit(
value_and_grad(
lambda x, gamma, beta: compute_loss(
- layernorm(x, gamma, beta, ln_type, zero_centered_gamma, epsilon)),
- (0, 1, 2)))
+ layernorm(x, gamma, beta, ln_type, zero_centered_gamma, epsilon)
+ ),
+ (0, 1, 2),
+ )
+ )
jitted_reference = jit(
value_and_grad(
lambda x, gamma, beta: compute_loss(
- self.reference_layernorm(x, gamma, beta, zero_centered_gamma, epsilon)),
- (0, 1, 2)))
+ self.reference_layernorm(x, gamma, beta, zero_centered_gamma, epsilon)
+ ),
+ (0, 1, 2),
+ )
+ )
- primitive_out, (primitive_dx, primitive_dgamma,
- primitive_dbeta) = jitted_primitive(x, gamma, beta)
- reference_out, (reference_dx, reference_dgamma,
- reference_dbeta) = jitted_reference(x, gamma, beta)
+ primitive_out, (primitive_dx, primitive_dgamma, primitive_dbeta) = jitted_primitive(
+ x, gamma, beta
+ )
+ reference_out, (reference_dx, reference_dgamma, reference_dbeta) = jitted_reference(
+ x, gamma, beta
+ )
assert_allclose(primitive_out, reference_out, dtype=dtype)
assert_allclose(primitive_dx, reference_dx, dtype=dtype)
@@ -606,21 +696,24 @@ class TestNorm:
assert_allclose(primitive_dbeta, reference_dbeta, dtype=dtype)
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
- @pytest.mark.parametrize('m,n,k', GEMM_CASES)
- @pytest.mark.parametrize('ln_type', ['layernorm', 'rmsnorm'])
- @pytest.mark.parametrize('zero_centered_gamma', [True, False])
- @pytest.mark.parametrize('epsilon', [1e-2, 1e-6])
+ @pytest.mark.parametrize("m,n,k", GEMM_CASES)
+ @pytest.mark.parametrize("ln_type", ["layernorm", "rmsnorm"])
+ @pytest.mark.parametrize("zero_centered_gamma", [True, False])
+ @pytest.mark.parametrize("epsilon", [1e-2, 1e-6])
def test_ln_fp8_dot_forward_backward(self, m, n, k, ln_type, zero_centered_gamma, epsilon):
"""
Test transformer_engine.jax.layernorm.layernorm_fp8_dot
"""
expect_assert = False
- if ln_type == 'rmsnorm' and zero_centered_gamma:
+ if ln_type == "rmsnorm" and zero_centered_gamma:
# zero_centered_gamma is not supported for rmsnorm, expect an assertion.
expect_assert = True
- with pytest.raises(AssertionError, match=r".*zero_centered_gamma is not supported.*"
- ) if expect_assert else nullcontext():
+ with (
+ pytest.raises(AssertionError, match=r".*zero_centered_gamma is not supported.*")
+ if expect_assert
+ else nullcontext()
+ ):
key = jax.random.PRNGKey(0)
subkeys = jax.random.split(key, 4)
@@ -628,7 +721,7 @@ class TestNorm:
b = jax.random.normal(subkeys[1], (k, n)).astype(jnp.bfloat16)
gamma = jax.random.normal(subkeys[2], (k,)).astype(jnp.bfloat16)
- if ln_type == 'layernorm':
+ if ln_type == "layernorm":
beta = jax.random.normal(subkeys[3], (k,)).astype(jnp.bfloat16)
else:
beta = None
@@ -644,8 +737,9 @@ class TestNorm:
amax_list_1[2],
scale_list_1[2],
)
- primitive_out = layernorm_fp8_dot(x, y, gamma, beta, fp8_meta_pkg, ln_type,
- zero_centered_gamma)
+ primitive_out = layernorm_fp8_dot(
+ x, y, gamma, beta, fp8_meta_pkg, ln_type, zero_centered_gamma
+ )
return jnp.mean(primitive_out)
def ref_func(x, y, gamma, beta, zero_centered_gamma):
@@ -655,14 +749,19 @@ class TestNorm:
value_n_grad_primitive_func = value_and_grad(primitive_func, range(6))
value_n_grad_ref_func = value_and_grad(ref_func, (0, 1, 2, 3))
- ref_out, (ref_a_grad, ref_b_grad, ref_gamma_grad,
- ref_beta_grad) = value_n_grad_ref_func(a, b, gamma, beta, zero_centered_gamma)
+ ref_out, (ref_a_grad, ref_b_grad, ref_gamma_grad, ref_beta_grad) = (
+ value_n_grad_ref_func(a, b, gamma, beta, zero_centered_gamma)
+ )
for _ in range(3):
- primitive_out, (primitive_a_grad, primitive_b_grad, primitive_gamma_grad,
- primitive_beta_grad, amax_list_1,
- scale_list_1) = value_n_grad_primitive_func(
- a, b, gamma, beta, amax_list_1, scale_list_1)
+ primitive_out, (
+ primitive_a_grad,
+ primitive_b_grad,
+ primitive_gamma_grad,
+ primitive_beta_grad,
+ amax_list_1,
+ scale_list_1,
+ ) = value_n_grad_primitive_func(a, b, gamma, beta, amax_list_1, scale_list_1)
assert_allclose(primitive_out, ref_out, dtype=FP8Helper.FWD_DTYPE)
assert_allclose(primitive_a_grad, ref_a_grad, dtype=FP8Helper.BWD_DTYPE)
diff --git a/tests/jax/test_distributed_fused_attn.py b/tests/jax/test_distributed_fused_attn.py
index 3efee867982c19117c09fb56ea306ddd8841a0c6..8a6c5792ff517f77800f8ef35e5daa760c7cb599 100644
--- a/tests/jax/test_distributed_fused_attn.py
+++ b/tests/jax/test_distributed_fused_attn.py
@@ -9,16 +9,8 @@ import jax.numpy as jnp
import numpy as np
from flax.linen import dot_product_attention
from jax import random
-from jax.sharding import (
- Mesh,
- NamedSharding,
- PartitionSpec
-)
-from distributed_test_base import (
- generate_configs,
- generate_collectives_count,
- compare_ops
-)
+from jax.sharding import Mesh, NamedSharding, PartitionSpec
+from distributed_test_base import generate_configs, generate_collectives_count, compare_ops
from utils import make_causal_mask, make_self_mask
from transformer_engine.jax import fp8_autocast
from transformer_engine.jax.attention import (
@@ -27,7 +19,7 @@ from transformer_engine.jax.attention import (
fused_attn_kvpacked,
AttnBiasType,
AttnMaskType,
- QKVLayout
+ QKVLayout,
)
@@ -36,8 +28,9 @@ DTYPES = [jnp.float16, jnp.bfloat16]
class TestDistributedSelfAttn:
- def generate_collectives_count_ref(self, mesh_shape, mesh_axes, mesh_resource, with_bias, shape,
- dtype):
+ def generate_collectives_count_ref(
+ self, mesh_shape, mesh_axes, mesh_resource, with_bias, shape, dtype
+ ):
jax_dtype = jax.dtypes.canonicalize_dtype(dtype)
_, seqlen, _, heads, _ = shape
is_dp_enabled = mesh_resource.dp_resource is not None
@@ -46,7 +39,7 @@ class TestDistributedSelfAttn:
idx = mesh_axes.index(mesh_resource.tp_resource)
tp_size = mesh_shape[idx]
- all_reduce_loss_bytes = 4 # 1 * FP32
+ all_reduce_loss_bytes = 4 # 1 * FP32
bias_bytes = int(with_bias) * (heads // tp_size) * seqlen * seqlen * jax_dtype.itemsize
allreduce_total_bytes = all_reduce_loss_bytes + (bias_bytes * is_dp_enabled)
# for loss and dbias
@@ -57,8 +50,11 @@ class TestDistributedSelfAttn:
qkv = random.normal(random.PRNGKey(1124), shape, dtype=dtype)
- bias = random.normal(random.PRNGKey(1125), (1, heads, seqlen, seqlen), dtype) \
- if with_bias else None
+ bias = (
+ random.normal(random.PRNGKey(1125), (1, heads, seqlen, seqlen), dtype)
+ if with_bias
+ else None
+ )
mask = None
if attn_mask_type == AttnMaskType.PADDING_MASK:
@@ -66,47 +62,76 @@ class TestDistributedSelfAttn:
elif attn_mask_type == AttnMaskType.CAUSAL_MASK:
mask = make_self_mask(batch, seqlen)
- qkv_pspec = PartitionSpec(mesh_resource.dp_resource, None, None, mesh_resource.tp_resource,
- None)
- bias_pspec = PartitionSpec(None, mesh_resource.tp_resource, None, None) \
- if with_bias else None
- mask_pspec = PartitionSpec(mesh_resource.dp_resource, None, None, None) \
- if attn_mask_type != AttnMaskType.NO_MASK else None
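+ # Shard the batch dim over dp_resource and the heads dim over tp_resource; other dims stay replicated.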
+ qkv_pspec = PartitionSpec(
+ mesh_resource.dp_resource, None, None, mesh_resource.tp_resource, None
+ )
+ bias_pspec = (
+ PartitionSpec(None, mesh_resource.tp_resource, None, None) if with_bias else None
+ )
+ mask_pspec = (
+ PartitionSpec(mesh_resource.dp_resource, None, None, None)
+ if attn_mask_type != AttnMaskType.NO_MASK
+ else None
+ )
return (qkv, bias, mask), (qkv_pspec, bias_pspec, mask_pspec)
- @pytest.mark.parametrize('device_count,mesh_shape,mesh_axes,mesh_resource', generate_configs())
- @pytest.mark.parametrize('data_shape', [[32, 512, 3, 12, 64], [32, 1024, 3, 16, 128]])
+ @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
+ @pytest.mark.parametrize("data_shape", [[32, 512, 3, 12, 64], [32, 1024, 3, 16, 128]])
@pytest.mark.parametrize(
- 'attn_bias_type',
- [AttnBiasType.NO_BIAS, AttnBiasType.PRE_SCALE_BIAS, AttnBiasType.POST_SCALE_BIAS])
- @pytest.mark.parametrize('attn_mask_type',
- [AttnMaskType.PADDING_MASK, AttnMaskType.CAUSAL_MASK])
- @pytest.mark.parametrize('dtype', DTYPES)
- def test_self_attn(self, device_count, mesh_shape, mesh_axes, mesh_resource, data_shape,
- attn_bias_type, attn_mask_type, dtype):
+ "attn_bias_type",
+ [AttnBiasType.NO_BIAS, AttnBiasType.PRE_SCALE_BIAS, AttnBiasType.POST_SCALE_BIAS],
+ )
+ @pytest.mark.parametrize(
+ "attn_mask_type", [AttnMaskType.PADDING_MASK, AttnMaskType.CAUSAL_MASK]
+ )
+ @pytest.mark.parametrize("dtype", DTYPES)
+ def test_self_attn(
+ self,
+ device_count,
+ mesh_shape,
+ mesh_axes,
+ mesh_resource,
+ data_shape,
+ attn_bias_type,
+ attn_mask_type,
+ dtype,
+ ):
dropout_prob = 0.0
is_training = True
scaling_factor = 1.0
_, seqlen, _, num_head, hidden = data_shape
- if not is_fused_attn_kernel_available(dtype, dtype, QKVLayout.BS3HD, attn_bias_type,
- attn_mask_type, dropout_prob, num_head, num_head,
- seqlen, seqlen, hidden):
+ if not is_fused_attn_kernel_available(
+ dtype,
+ dtype,
+ QKVLayout.BS3HD,
+ attn_bias_type,
+ attn_mask_type,
+ dropout_prob,
+ num_head,
+ num_head,
+ seqlen,
+ seqlen,
+ hidden,
+ ):
pytest.skip(f"No FusedAttn backwend found")
def target_func(qkv, bias, mask):
return jnp.mean(
- fused_attn_qkvpacked(qkv,
- bias,
- mask,
- None,
- attn_bias_type=attn_bias_type,
- attn_mask_type=attn_mask_type,
- scaling_factor=scaling_factor,
- dropout_probability=dropout_prob,
- is_training=is_training))
+ fused_attn_qkvpacked(
+ qkv,
+ bias,
+ mask,
+ None,
+ attn_bias_type=attn_bias_type,
+ attn_mask_type=attn_mask_type,
+ scaling_factor=scaling_factor,
+ dropout_probability=dropout_prob,
+ is_training=is_training,
+ )
+ )
def ref_func(qkv, bias, mask):
query, key, value = jnp.split(qkv, [1, 2], axis=-3)
@@ -114,52 +139,59 @@ class TestDistributedSelfAttn:
key = jnp.squeeze(key)
value = jnp.squeeze(value)
- output = dot_product_attention(query,
- key,
- value,
- bias=bias,
- mask=mask,
- deterministic=is_training,
- dropout_rate=dropout_prob,
- dropout_rng=None,
- dtype=jnp.float32)
+ output = dot_product_attention(
+ query,
+ key,
+ value,
+ bias=bias,
+ mask=mask,
+ deterministic=is_training,
+ dropout_rate=dropout_prob,
+ dropout_rng=None,
+ dtype=jnp.float32,
+ )
return jnp.mean(output).astype(dtype)
with_bias = attn_bias_type != AttnBiasType.NO_BIAS
- (qkv, bias, mask), (qkv_pspec, bias_pspec, mask_pspec) = \
- self.generate_inputs(data_shape, mesh_resource, with_bias,
- attn_mask_type, dtype)
- collective_count_ref = self.generate_collectives_count_ref(mesh_shape, mesh_axes,
- mesh_resource, with_bias,
- data_shape, dtype)
+ (qkv, bias, mask), (qkv_pspec, bias_pspec, mask_pspec) = self.generate_inputs(
+ data_shape, mesh_resource, with_bias, attn_mask_type, dtype
+ )
+ collective_count_ref = self.generate_collectives_count_ref(
+ mesh_shape, mesh_axes, mesh_resource, with_bias, data_shape, dtype
+ )
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = Mesh(devices, mesh_axes)
with mesh, fp8_autocast(mesh_resource=mesh_resource):
qkv_ = jax.device_put(qkv, NamedSharding(mesh, qkv_pspec))
- bias_ = jax.device_put(bias, NamedSharding(mesh, bias_pspec)) \
- if bias is not None else bias
- mask_ = jax.device_put(mask, NamedSharding(mesh, mask_pspec)) \
- if mask is not None else mask
+ bias_ = (
+ jax.device_put(bias, NamedSharding(mesh, bias_pspec)) if bias is not None else bias
+ )
+ mask_ = (
+ jax.device_put(mask, NamedSharding(mesh, mask_pspec)) if mask is not None else mask
+ )
grad_args = (0, 1) if with_bias else (0,)
out_grad_shardings = (qkv_pspec, bias_pspec) if with_bias else (qkv_pspec,)
- compare_ops(target_func,
- ref_func, [qkv_, bias_, mask_],
- collective_count_ref,
- grad_args=grad_args,
- metric_fwd_dtype=dtype,
- metric_bwd_dtype=dtype,
- in_shardings=(qkv_pspec, bias_pspec, mask_pspec),
- out_shardings=(None, out_grad_shardings))
+ compare_ops(
+ target_func,
+ ref_func,
+ [qkv_, bias_, mask_],
+ collective_count_ref,
+ grad_args=grad_args,
+ metric_fwd_dtype=dtype,
+ metric_bwd_dtype=dtype,
+ in_shardings=(qkv_pspec, bias_pspec, mask_pspec),
+ out_shardings=(None, out_grad_shardings),
+ )
class TestDistributedCrossAttn:
def generate_collectives_count_ref(self):
# for loss
- all_reduce_loss_bytes = 4 # 1 * FP32
+ all_reduce_loss_bytes = 4 # 1 * FP32
return generate_collectives_count(allreduce=all_reduce_loss_bytes, allgather=0, other=0)
def generate_inputs(self, shape, mesh_resource, attn_mask_type, dtype):
@@ -176,20 +208,26 @@ class TestDistributedCrossAttn:
q_pspec = PartitionSpec(mesh_resource.dp_resource, None, mesh_resource.tp_resource, None)
- kv_pspec = PartitionSpec(mesh_resource.dp_resource, None, None, mesh_resource.tp_resource,
- None)
- mask_pspec = PartitionSpec(mesh_resource.dp_resource, None, None, None) \
- if attn_mask_type != AttnMaskType.NO_MASK else None
+ kv_pspec = PartitionSpec(
+ mesh_resource.dp_resource, None, None, mesh_resource.tp_resource, None
+ )
+ mask_pspec = (
+ PartitionSpec(mesh_resource.dp_resource, None, None, None)
+ if attn_mask_type != AttnMaskType.NO_MASK
+ else None
+ )
return (q, kv, mask), (q_pspec, kv_pspec, mask_pspec)
- @pytest.mark.parametrize('device_count,mesh_shape,mesh_axes,mesh_resource', generate_configs())
- @pytest.mark.parametrize('data_shape', [[32, 128, 12, 64], [32, 512, 16, 64]])
- @pytest.mark.parametrize('attn_mask_type',
- [AttnMaskType.PADDING_MASK, AttnMaskType.CAUSAL_MASK])
- @pytest.mark.parametrize('dtype', DTYPES)
- def test_cross_attn(self, device_count, mesh_shape, mesh_axes, mesh_resource, data_shape,
- attn_mask_type, dtype):
+ @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
+ @pytest.mark.parametrize("data_shape", [[32, 128, 12, 64], [32, 512, 16, 64]])
+ @pytest.mark.parametrize(
+ "attn_mask_type", [AttnMaskType.PADDING_MASK, AttnMaskType.CAUSAL_MASK]
+ )
+ @pytest.mark.parametrize("dtype", DTYPES)
+ def test_cross_attn(
+ self, device_count, mesh_shape, mesh_axes, mesh_resource, data_shape, attn_mask_type, dtype
+ ):
attn_bias_type = AttnBiasType.NO_BIAS
dropout_prob = 0.0
is_training = True
@@ -197,23 +235,36 @@ class TestDistributedCrossAttn:
_, seqlen, num_head, hidden = data_shape
- if not is_fused_attn_kernel_available(dtype, dtype, QKVLayout.BSHD_BS2HD, attn_bias_type,
- attn_mask_type, dropout_prob, num_head, num_head,
- seqlen, seqlen, hidden):
+ if not is_fused_attn_kernel_available(
+ dtype,
+ dtype,
+ QKVLayout.BSHD_BS2HD,
+ attn_bias_type,
+ attn_mask_type,
+ dropout_prob,
+ num_head,
+ num_head,
+ seqlen,
+ seqlen,
+ hidden,
+ ):
pytest.skip(f"No FusedAttn backwend found")
def target_func(q, kv, mask):
return jnp.mean(
- fused_attn_kvpacked(q,
- kv,
- None,
- mask,
- None,
- attn_bias_type=attn_bias_type,
- attn_mask_type=attn_mask_type,
- scaling_factor=scaling_factor,
- dropout_probability=dropout_prob,
- is_training=is_training))
+ fused_attn_kvpacked(
+ q,
+ kv,
+ None,
+ mask,
+ None,
+ attn_bias_type=attn_bias_type,
+ attn_mask_type=attn_mask_type,
+ scaling_factor=scaling_factor,
+ dropout_probability=dropout_prob,
+ is_training=is_training,
+ )
+ )
def ref_func(query, kv, mask):
key, value = jnp.split(kv, [1], axis=-3)
@@ -221,34 +272,41 @@ class TestDistributedCrossAttn:
key = jnp.squeeze(key)
value = jnp.squeeze(value)
- output = dot_product_attention(query,
- key,
- value,
- bias=None,
- mask=mask,
- deterministic=is_training,
- dropout_rate=dropout_prob,
- dropout_rng=None,
- dtype=jnp.float32)
+ output = dot_product_attention(
+ query,
+ key,
+ value,
+ bias=None,
+ mask=mask,
+ deterministic=is_training,
+ dropout_rate=dropout_prob,
+ dropout_rng=None,
+ dtype=jnp.float32,
+ )
return jnp.mean(output).astype(dtype)
- (q, kv, mask), (q_pspec, kv_pspec, mask_pspec) = \
- self.generate_inputs(data_shape, mesh_resource, attn_mask_type, dtype)
+ (q, kv, mask), (q_pspec, kv_pspec, mask_pspec) = self.generate_inputs(
+ data_shape, mesh_resource, attn_mask_type, dtype
+ )
collective_count_ref = self.generate_collectives_count_ref()
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = Mesh(devices, mesh_axes)
with mesh, fp8_autocast(mesh_resource=mesh_resource):
q_ = jax.device_put(q, NamedSharding(mesh, q_pspec))
kv_ = jax.device_put(kv, NamedSharding(mesh, kv_pspec))
- mask_ = jax.device_put(mask, NamedSharding(mesh, mask_pspec)) \
- if mask is not None else mask
-
- compare_ops(target_func,
- ref_func, [q_, kv_, mask_],
- collective_count_ref,
- grad_args=(0, 1),
- metric_fwd_dtype=dtype,
- metric_bwd_dtype=dtype,
- in_shardings=(q_pspec, kv_pspec, mask_pspec),
- out_shardings=(None, (q_pspec, kv_pspec)))
+ mask_ = (
+ jax.device_put(mask, NamedSharding(mesh, mask_pspec)) if mask is not None else mask
+ )
+
+ compare_ops(
+ target_func,
+ ref_func,
+ [q_, kv_, mask_],
+ collective_count_ref,
+ grad_args=(0, 1),
+ metric_fwd_dtype=dtype,
+ metric_bwd_dtype=dtype,
+ in_shardings=(q_pspec, kv_pspec, mask_pspec),
+ out_shardings=(None, (q_pspec, kv_pspec)),
+ )
diff --git a/tests/jax/test_distributed_layernorm.py b/tests/jax/test_distributed_layernorm.py
index 3aa5b9ae203a9d9e37d313706bbd92896aa86834..f0dd56feaaa737f7498a6af5ac9cec3332a05177 100644
--- a/tests/jax/test_distributed_layernorm.py
+++ b/tests/jax/test_distributed_layernorm.py
@@ -35,31 +35,44 @@ class TestDistributedLayernorm:
else:
raise NotImplementedError
- g_pspec = b_pspec = PartitionSpec(mesh_resource.dp_resource) if shard_weights else PartitionSpec(None)
+ g_pspec = b_pspec = (
+ PartitionSpec(mesh_resource.dp_resource) if shard_weights else PartitionSpec(None)
+ )
return (x, gamma, beta), (x_pspec, g_pspec, b_pspec)
def generate_collectives_count_ref(self, mesh_resource, ln_type, shape, dtype):
jax_dtype = jax.dtypes.canonicalize_dtype(dtype)
is_dp_enabled = mesh_resource.dp_resource is not None
- assert ln_type in ['layernorm', 'rmsnorm']
- all_reduce_loss_bytes = 4 # 1 * FP32
+ assert ln_type in ["layernorm", "rmsnorm"]
+ all_reduce_loss_bytes = 4 # 1 * FP32
# for loss, dgamma and dbeta
- weight_count = 2 if ln_type == 'layernorm' else 1
- allreduce_total_bytes = all_reduce_loss_bytes + weight_count * shape[-1] * jax_dtype.itemsize
- return generate_collectives_count(allreduce=allreduce_total_bytes * int(is_dp_enabled),
- allgather=0,
- other=0)
-
- @pytest.mark.parametrize('device_count,mesh_shape,mesh_axes,mesh_resource', generate_configs())
- @pytest.mark.parametrize('data_shape', [[32, 128, 1024], [32, 1024]])
- @pytest.mark.parametrize('dtype', DTYPES)
- @pytest.mark.parametrize('zero_centered_gamma', [False, True])
- @pytest.mark.parametrize('shard_weights', [False, True])
- def test_layernorm(self, device_count, mesh_shape, mesh_axes, mesh_resource, data_shape, dtype,
- zero_centered_gamma, shard_weights):
+ weight_count = 2 if ln_type == "layernorm" else 1
+ allreduce_total_bytes = (
+ all_reduce_loss_bytes + weight_count * shape[-1] * jax_dtype.itemsize
+ )
+ return generate_collectives_count(
+ allreduce=allreduce_total_bytes * int(is_dp_enabled), allgather=0, other=0
+ )
+
+ @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
+ @pytest.mark.parametrize("data_shape", [[32, 128, 1024], [32, 1024]])
+ @pytest.mark.parametrize("dtype", DTYPES)
+ @pytest.mark.parametrize("zero_centered_gamma", [False, True])
+ @pytest.mark.parametrize("shard_weights", [False, True])
+ def test_layernorm(
+ self,
+ device_count,
+ mesh_shape,
+ mesh_axes,
+ mesh_resource,
+ data_shape,
+ dtype,
+ zero_centered_gamma,
+ shard_weights,
+ ):
epsilon = 1e-6
- ln_type = 'layernorm'
+ ln_type = "layernorm"
def target_func(x, gamma, beta):
return jnp.mean(layernorm(x, gamma, beta, ln_type, zero_centered_gamma, epsilon))
@@ -75,10 +88,12 @@ class TestDistributedLayernorm:
output = jnp.asarray(normed_input * gamma + beta).astype(x.dtype)
return jnp.mean(output)
- (x, gamma, beta), (x_pspec, g_pspec, b_pspec) = \
- self.generate_inputs(data_shape, mesh_resource, dtype, shard_weights)
- collective_count_ref = self.generate_collectives_count_ref(mesh_resource, ln_type,
- data_shape, dtype)
+ (x, gamma, beta), (x_pspec, g_pspec, b_pspec) = self.generate_inputs(
+ data_shape, mesh_resource, dtype, shard_weights
+ )
+ collective_count_ref = self.generate_collectives_count_ref(
+ mesh_resource, ln_type, data_shape, dtype
+ )
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = Mesh(devices, mesh_axes)
with mesh, fp8_autocast(mesh_resource=mesh_resource):
@@ -88,20 +103,25 @@ class TestDistributedLayernorm:
with warnings.catch_warnings(record=True) as warns:
try:
- compare_ops(target_func,
- ref_func, [x_, gamma_, beta_],
- collective_count_ref,
- grad_args=(0, 1, 2),
- metric_fwd_dtype=dtype,
- metric_bwd_dtype=dtype,
- in_shardings=(x_pspec, g_pspec, b_pspec),
- out_shardings=(None, (x_pspec, g_pspec, b_pspec)))
+ compare_ops(
+ target_func,
+ ref_func,
+ [x_, gamma_, beta_],
+ collective_count_ref,
+ grad_args=(0, 1, 2),
+ metric_fwd_dtype=dtype,
+ metric_bwd_dtype=dtype,
+ in_shardings=(x_pspec, g_pspec, b_pspec),
+ out_shardings=(None, (x_pspec, g_pspec, b_pspec)),
+ )
except AssertionError as err:
# Layernorm should still produce the correct numerical result with
# gamma/beta sharded. However, the collective count may not be the same
# when XLA is forced to unshard gamma and/or beta. We can catch
# and ignore that specific error here.
- if (g_pspec[-1] is None and b_pspec[-1] is None) or "Expected collective count" not in str(err):
+ if (
+ g_pspec[-1] is None and b_pspec[-1] is None
+ ) or "Expected collective count" not in str(err):
raise err
finally:
for w in warns:
@@ -110,13 +130,15 @@ class TestDistributedLayernorm:
"unsupported sharding of gamma and/or beta"
)
- @pytest.mark.parametrize('device_count,mesh_shape,mesh_axes,mesh_resource', generate_configs())
- @pytest.mark.parametrize('data_shape', [[32, 128, 1024], [32, 1024]])
- @pytest.mark.parametrize('dtype', DTYPES)
- @pytest.mark.parametrize('shard_weights', [False, True])
- def test_rmsnorm(self, device_count, mesh_shape, mesh_axes, mesh_resource, data_shape, dtype, shard_weights):
+ @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
+ @pytest.mark.parametrize("data_shape", [[32, 128, 1024], [32, 1024]])
+ @pytest.mark.parametrize("dtype", DTYPES)
+ @pytest.mark.parametrize("shard_weights", [False, True])
+ def test_rmsnorm(
+ self, device_count, mesh_shape, mesh_axes, mesh_resource, data_shape, dtype, shard_weights
+ ):
epsilon = 1e-6
- ln_type = 'rmsnorm'
+ ln_type = "rmsnorm"
def target_func(x, gamma):
return jnp.mean(layernorm(x, gamma, None, ln_type, False, epsilon))
@@ -128,10 +150,12 @@ class TestDistributedLayernorm:
output = y * gamma
return jnp.mean(output)
- (x, gamma, _), (x_pspec, g_pspec, _) = \
- self.generate_inputs(data_shape, mesh_resource, dtype, shard_weights)
- collective_count_ref = self.generate_collectives_count_ref(mesh_resource, ln_type,
- data_shape, dtype)
+ (x, gamma, _), (x_pspec, g_pspec, _) = self.generate_inputs(
+ data_shape, mesh_resource, dtype, shard_weights
+ )
+ collective_count_ref = self.generate_collectives_count_ref(
+ mesh_resource, ln_type, data_shape, dtype
+ )
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = Mesh(devices, mesh_axes)
with mesh, fp8_autocast(mesh_resource=mesh_resource):
@@ -140,14 +164,17 @@ class TestDistributedLayernorm:
with warnings.catch_warnings(record=True) as warns:
try:
- compare_ops(target_func,
- ref_func, [x_, gamma_],
- collective_count_ref,
- grad_args=(0, 1),
- metric_fwd_dtype=dtype,
- metric_bwd_dtype=dtype,
- in_shardings=(x_pspec, g_pspec),
- out_shardings=(None, (x_pspec, g_pspec)))
+ compare_ops(
+ target_func,
+ ref_func,
+ [x_, gamma_],
+ collective_count_ref,
+ grad_args=(0, 1),
+ metric_fwd_dtype=dtype,
+ metric_bwd_dtype=dtype,
+ in_shardings=(x_pspec, g_pspec),
+ out_shardings=(None, (x_pspec, g_pspec)),
+ )
except AssertionError as err:
# RmsNorm should still produce the correct numerical result with
# gamma/beta sharded. However, the collective count may not be the same
diff --git a/tests/jax/test_distributed_layernorm_mlp.py b/tests/jax/test_distributed_layernorm_mlp.py
index e61b013b8fd7aa3178953b83ad72d18ddb8ac906..38f7ec0d49f9408f08971cde0a57db460a329a7d 100644
--- a/tests/jax/test_distributed_layernorm_mlp.py
+++ b/tests/jax/test_distributed_layernorm_mlp.py
@@ -15,22 +15,23 @@ from transformer_engine.jax import fp8_autocast
from transformer_engine.jax.flax import LayerNormMLP
from transformer_engine.jax.layernorm_mlp import fused_layernorm_fp8_mlp
from transformer_engine.jax.sharding import (
- HIDDEN_AXES, HIDDEN_TP_AXES,
+ HIDDEN_AXES,
+ HIDDEN_TP_AXES,
BATCH_AXES,
- SEQLEN_TP_AXES, SEQLEN_AXES,
- W_NO_SHARD_AXES, W_FSDP_AXES, W_TP_AXES, W_JOINED_AXES
+ SEQLEN_TP_AXES,
+ SEQLEN_AXES,
+ W_NO_SHARD_AXES,
+ W_FSDP_AXES,
+ W_TP_AXES,
+ W_JOINED_AXES,
)
from transformer_engine.jax.sharding import MeshResource
-from utils import (
- assert_allclose,
- assert_tree_like_allclose,
- is_devices_enough
-)
+from utils import assert_allclose, assert_tree_like_allclose, is_devices_enough
is_fp8_supported, reason = is_fp8_available()
DTYPES = [jnp.bfloat16, jnp.float16]
-INPUT_SHAPE = [[64, 128, 32]] # [batch, seqlen, hidden_in]
+INPUT_SHAPE = [[64, 128, 32]] # [batch, seqlen, hidden_in]
LAYERNORM_INPUT_AXES = (BATCH_AXES, SEQLEN_TP_AXES, HIDDEN_AXES)
DOT_1_INPUT_AXES = (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES)
@@ -43,13 +44,13 @@ def generate_fsdp_and_tp_configs():
configs = []
if is_devices_enough(2):
configs.append(
- [2, (1, 2), ('fsdp', 'tp'),
- MeshResource(fsdp_resource='fsdp', tp_resource='tp')])
+ [2, (1, 2), ("fsdp", "tp"), MeshResource(fsdp_resource="fsdp", tp_resource="tp")]
+ )
if is_devices_enough(4):
configs.append(
- [4, (2, 2), ('fsdp', 'tp'),
- MeshResource(fsdp_resource='fsdp', tp_resource='tp')])
+ [4, (2, 2), ("fsdp", "tp"), MeshResource(fsdp_resource="fsdp", tp_resource="tp")]
+ )
return configs
@@ -64,10 +65,12 @@ class TestDistributedLayernormMLP:
x = jax.random.normal(subkeys[0], (batch, seqlen, hidden_in), dtype)
gamma = jax.random.normal(subkeys[5], (hidden_in,), dtype=dtype)
- k1 = jax.random.normal(subkeys[1], (hidden_in, len(activation_type), INTERMEDIATE),
- dtype) / jnp.sqrt(hidden_in)
- k2 = jax.random.normal(subkeys[2],
- (INTERMEDIATE, hidden_out), dtype) / jnp.sqrt(INTERMEDIATE)
+ k1 = jax.random.normal(
+ subkeys[1], (hidden_in, len(activation_type), INTERMEDIATE), dtype
+ ) / jnp.sqrt(hidden_in)
+ k2 = jax.random.normal(subkeys[2], (INTERMEDIATE, hidden_out), dtype) / jnp.sqrt(
+ INTERMEDIATE
+ )
if use_bias:
b1 = jax.random.normal(subkeys[3], (len(activation_type), INTERMEDIATE), dtype)
b2 = jax.random.normal(subkeys[4], (hidden_out,), dtype)
@@ -90,15 +93,27 @@ class TestDistributedLayernormMLP:
scale_list_1: List[jnp.ndarray],
scale_list_2: List[jnp.ndarray],
layernorm_type: str = "rmsnorm",
- activation_type: Sequence[Union[str, Callable]] = ('gelu',),
+ activation_type: Sequence[Union[str, Callable]] = ("gelu",),
use_bias: bool = True,
multi_gpus: bool = False,
) -> jnp.ndarray:
- fp8_meta_pkg1 = FP8MetaPackage(amax_list_1[0], scale_list_1[0], amax_list_1[1],
- scale_list_1[1], amax_list_1[2], scale_list_1[2])
- fp8_meta_pkg2 = FP8MetaPackage(amax_list_2[0], scale_list_2[0], amax_list_2[1],
- scale_list_2[1], amax_list_2[2], scale_list_2[2])
+ fp8_meta_pkg1 = FP8MetaPackage(
+ amax_list_1[0],
+ scale_list_1[0],
+ amax_list_1[1],
+ scale_list_1[1],
+ amax_list_1[2],
+ scale_list_1[2],
+ )
+ fp8_meta_pkg2 = FP8MetaPackage(
+ amax_list_2[0],
+ scale_list_2[0],
+ amax_list_2[1],
+ scale_list_2[1],
+ amax_list_2[2],
+ scale_list_2[2],
+ )
if multi_gpus:
layernorm_input_axes = LAYERNORM_INPUT_AXES
@@ -111,60 +126,68 @@ class TestDistributedLayernormMLP:
# out = ((x * kernel_1) + bias_1) * kernel_2 + bias_2
return jnp.mean(
- fused_layernorm_fp8_mlp(x,
- ln_scale,
- None, [kernel_1, kernel_2], [bias_1, bias_2],
- [fp8_meta_pkg1, fp8_meta_pkg2],
- layernorm_type,
- layernorm_input_axes=layernorm_input_axes,
- dot_1_input_axes=dot_1_input_axes,
- dot_2_input_axes=dot_2_input_axes,
- activation_type=activation_type,
- use_bias=use_bias))
+ fused_layernorm_fp8_mlp(
+ x,
+ ln_scale,
+ None,
+ [kernel_1, kernel_2],
+ [bias_1, bias_2],
+ [fp8_meta_pkg1, fp8_meta_pkg2],
+ layernorm_type,
+ layernorm_input_axes=layernorm_input_axes,
+ dot_1_input_axes=dot_1_input_axes,
+ dot_2_input_axes=dot_2_input_axes,
+ activation_type=activation_type,
+ use_bias=use_bias,
+ )
+ )
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
- @pytest.mark.parametrize('mesh_config', generate_fsdp_and_tp_configs())
- @pytest.mark.parametrize('input_shape', INPUT_SHAPE)
- @pytest.mark.parametrize('activation_type', [("gelu",), ('gelu', 'linear')])
- @pytest.mark.parametrize('dtype', DTYPES)
- @pytest.mark.parametrize('use_bias', [True, False])
- def test_layernorm_fp8_mlp_primitive(self, mesh_config, activation_type, use_bias, input_shape,
- dtype):
+ @pytest.mark.parametrize("mesh_config", generate_fsdp_and_tp_configs())
+ @pytest.mark.parametrize("input_shape", INPUT_SHAPE)
+ @pytest.mark.parametrize("activation_type", [("gelu",), ("gelu", "linear")])
+ @pytest.mark.parametrize("dtype", DTYPES)
+ @pytest.mark.parametrize("use_bias", [True, False])
+ def test_layernorm_fp8_mlp_primitive(
+ self, mesh_config, activation_type, use_bias, input_shape, dtype
+ ):
device_count, mesh_shape, mesh_axes, mesh_resource = mesh_config
- layernorm_type = 'rmsnorm'
+ layernorm_type = "rmsnorm"
fp8_amax_list_1 = [
jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32),
jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32),
- jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32)
+ jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32),
]
fp8_amax_list_2 = [
jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32),
jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32),
- jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32)
+ jnp.zeros((FP8Helper.AMAX_HISTORY_LEN,), jnp.float32),
]
fp8_scale_list_1 = [
jnp.ones((1,), jnp.float32),
jnp.ones((1,), jnp.float32),
- jnp.ones((1,), jnp.float32)
+ jnp.ones((1,), jnp.float32),
]
fp8_scale_list_2 = [
jnp.ones((1,), jnp.float32),
jnp.ones((1,), jnp.float32),
- jnp.ones((1,), jnp.float32)
+ jnp.ones((1,), jnp.float32),
]
- inputs = [x, gamma, k1, k2, b1, b2] = \
- self.generate_inputs(input_shape, activation_type, use_bias, dtype)
+ inputs = [x, gamma, k1, k2, b1, b2] = self.generate_inputs(
+ input_shape, activation_type, use_bias, dtype
+ )
inputs = [*inputs, fp8_amax_list_1, fp8_amax_list_2, fp8_scale_list_1, fp8_scale_list_2]
static_inputs = [layernorm_type, activation_type, use_bias]
- value_and_grad_func = jax.value_and_grad(self.layernorm_fp8_mlp_prim_func,
- argnums=range(len(inputs)))
+ value_and_grad_func = jax.value_and_grad(
+ self.layernorm_fp8_mlp_prim_func, argnums=range(len(inputs))
+ )
# Single GPU
- single_jitter = jax.jit(value_and_grad_func,
- static_argnums=range(len(inputs),
- len(static_inputs) + len(inputs)))
+ single_jitter = jax.jit(
+ value_and_grad_func, static_argnums=range(len(inputs), len(static_inputs) + len(inputs))
+ )
with fp8_autocast(enabled=True):
single_fwd, single_grads = single_jitter(*inputs, *static_inputs)
@@ -172,12 +195,12 @@ class TestDistributedLayernormMLP:
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = Mesh(devices, mesh_axes)
with mesh, fp8_autocast(enabled=True, mesh_resource=mesh_resource):
- k1_sharding = NamedSharding(mesh, PartitionSpec('fsdp', None, 'tp'))
- k2_sharding = NamedSharding(mesh, PartitionSpec('tp', 'fsdp'))
+ k1_sharding = NamedSharding(mesh, PartitionSpec("fsdp", None, "tp"))
+ k2_sharding = NamedSharding(mesh, PartitionSpec("tp", "fsdp"))
k1_ = jax.device_put(k1, k1_sharding)
k2_ = jax.device_put(k2, k2_sharding)
if use_bias:
- b1_sharding = NamedSharding(mesh, PartitionSpec(None, 'tp'))
+ b1_sharding = NamedSharding(mesh, PartitionSpec(None, "tp"))
b1_ = jax.device_put(b1, b1_sharding)
else:
b1_sharding = b1_ = None
@@ -186,17 +209,29 @@ class TestDistributedLayernormMLP:
# Position ref for sharding pspec lists
# x, gamma, k1, k2, b1,
# b2, fp8_max, fp8_metas_amax, fp8_metas_scale, fp8_metas_scale_inv
- in_shardings = (None, None, k1_sharding, k2_sharding, b1_sharding, None, None, None,
- None, None)
- out_shardings = (None, (None, None, k1_sharding, k2_sharding, b1_sharding, None, None,
- None, None, None))
-
- multi_jitter = jax.jit(value_and_grad_func,
- in_shardings=in_shardings,
- out_shardings=out_shardings,
- static_argnums=range(len(multi_inputs),
- len(static_inputs) + len(multi_inputs) +
- 1)) # +1 for multi_gpus
+ in_shardings = (
+ None,
+ None,
+ k1_sharding,
+ k2_sharding,
+ b1_sharding,
+ None,
+ None,
+ None,
+ None,
+ None,
+ )
+ out_shardings = (
+ None,
+ (None, None, k1_sharding, k2_sharding, b1_sharding, None, None, None, None, None),
+ )
+
+ multi_jitter = jax.jit(
+ value_and_grad_func,
+ in_shardings=in_shardings,
+ out_shardings=out_shardings,
+ static_argnums=range(len(multi_inputs), len(static_inputs) + len(multi_inputs) + 1),
+ ) # +1 for multi_gpus
multi_fwd, multi_grads = multi_jitter(*multi_inputs, *static_inputs, True)
@@ -206,97 +241,96 @@ class TestDistributedLayernormMLP:
if isinstance(multi_grads[i], list):
assert isinstance(single_grads[i], list)
for m_grad, s_grad in zip(multi_grads[i], single_grads[i]):
- assert_allclose(m_grad,
- s_grad,
- dtype=dtype,
- err_msg=f'multi_grads[{i}] is not close')
+ assert_allclose(
+ m_grad, s_grad, dtype=dtype, err_msg=f"multi_grads[{i}] is not close"
+ )
else:
- assert_allclose(multi_grads[i],
- single_grads[i],
- dtype=dtype,
- err_msg=f'multi_grads[{i}] is not close')
-
- def _test_layernorm_mlp(self, mesh_config, activation_type, use_bias, input_shape, dtype,
- use_fp8):
+ assert_allclose(
+ multi_grads[i],
+ single_grads[i],
+ dtype=dtype,
+ err_msg=f"multi_grads[{i}] is not close",
+ )
+
+ def _test_layernorm_mlp(
+ self, mesh_config, activation_type, use_bias, input_shape, dtype, use_fp8
+ ):
batch, seqlen, hidden_in = input_shape
- layernorm_type = 'rmsnorm'
+ layernorm_type = "rmsnorm"
rng = jax.random.PRNGKey(0)
subkeys = jax.random.split(rng, 2)
x = jax.random.normal(subkeys[0], (batch, seqlen, hidden_in), dtype)
- init_rngs = {'params': subkeys[1]}
+ init_rngs = {"params": subkeys[1]}
# Single GPU
with fp8_autocast(enabled=use_fp8):
ln_mlp_single = LayerNormMLP(
layernorm_type=layernorm_type,
- transpose_batch_sequence=False, # input: [batch, seqlen, hidden]
+ transpose_batch_sequence=False, # input: [batch, seqlen, hidden]
intermediate_dim=INTERMEDIATE,
activations=activation_type,
dtype=dtype,
use_bias=use_bias,
)
params_single = ln_mlp_single.init(init_rngs, x)
- mlp_out_single, ln_out_single = ln_mlp_single.apply(params_single,
- x,
- deterministic=True)
+ mlp_out_single, ln_out_single = ln_mlp_single.apply(
+ params_single, x, deterministic=True
+ )
# Multi GPUs
device_count, mesh_shape, mesh_axes, mesh_resource = mesh_config
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = Mesh(devices, mesh_axes)
with mesh, fp8_autocast(enabled=use_fp8, mesh_resource=mesh_resource):
- ln_mlp_sharded = LayerNormMLP(layernorm_type=layernorm_type,
- transpose_batch_sequence=False,
- intermediate_dim=INTERMEDIATE,
- activations=activation_type,
- dtype=dtype,
- scale_axes=(W_NO_SHARD_AXES,),
- ln_bias_axes=(W_NO_SHARD_AXES,),
- kernel_axes_1=(W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES),
- kernel_axes_2=(W_TP_AXES, W_FSDP_AXES),
- use_bias=use_bias,
- bias_axes_1=(W_JOINED_AXES, W_TP_AXES),
- bias_axes_2=(W_NO_SHARD_AXES,),
- layernorm_input_axes=LAYERNORM_INPUT_AXES,
- dot_1_input_axes=DOT_1_INPUT_AXES,
- dot_2_input_axes=DOT_2_INPUT_AXES,
- name='mlp')
+ ln_mlp_sharded = LayerNormMLP(
+ layernorm_type=layernorm_type,
+ transpose_batch_sequence=False,
+ intermediate_dim=INTERMEDIATE,
+ activations=activation_type,
+ dtype=dtype,
+ scale_axes=(W_NO_SHARD_AXES,),
+ ln_bias_axes=(W_NO_SHARD_AXES,),
+ kernel_axes_1=(W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES),
+ kernel_axes_2=(W_TP_AXES, W_FSDP_AXES),
+ use_bias=use_bias,
+ bias_axes_1=(W_JOINED_AXES, W_TP_AXES),
+ bias_axes_2=(W_NO_SHARD_AXES,),
+ layernorm_input_axes=LAYERNORM_INPUT_AXES,
+ dot_1_input_axes=DOT_1_INPUT_AXES,
+ dot_2_input_axes=DOT_2_INPUT_AXES,
+ name="mlp",
+ )
params_sharded = ln_mlp_sharded.init(init_rngs, x)
- mlp_out_sharded, ln_out_sharded = ln_mlp_sharded.apply(params_sharded,
- x,
- deterministic=True)
+ mlp_out_sharded, ln_out_sharded = ln_mlp_sharded.apply(
+ params_sharded, x, deterministic=True
+ )
# Make sure params values are the same
- assert_tree_like_allclose(params_sharded['params'], params_single['params'])
+ assert_tree_like_allclose(params_sharded["params"], params_single["params"])
assert_allclose(ln_out_sharded, ln_out_single, dtype=dtype)
assert_allclose(mlp_out_sharded, mlp_out_single, dtype=dtype)
- @pytest.mark.parametrize('input_shape', INPUT_SHAPE)
- @pytest.mark.parametrize('mesh_config', generate_fsdp_and_tp_configs())
- @pytest.mark.parametrize('activation_type', [("gelu",), ('silu', 'linear'), ('gelu', 'gelu')])
- @pytest.mark.parametrize('dtype', DTYPES)
- @pytest.mark.parametrize('use_bias', [True, False])
+ @pytest.mark.parametrize("input_shape", INPUT_SHAPE)
+ @pytest.mark.parametrize("mesh_config", generate_fsdp_and_tp_configs())
+ @pytest.mark.parametrize("activation_type", [("gelu",), ("silu", "linear"), ("gelu", "gelu")])
+ @pytest.mark.parametrize("dtype", DTYPES)
+ @pytest.mark.parametrize("use_bias", [True, False])
def test_layernorm_mlp_layer(self, mesh_config, activation_type, use_bias, input_shape, dtype):
- self._test_layernorm_mlp(mesh_config,
- activation_type,
- use_bias,
- input_shape,
- dtype,
- use_fp8=False)
+ self._test_layernorm_mlp(
+ mesh_config, activation_type, use_bias, input_shape, dtype, use_fp8=False
+ )
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
- @pytest.mark.parametrize('mesh_config', generate_fsdp_and_tp_configs())
- @pytest.mark.parametrize('activation_type', [("gelu",), ('gelu', 'linear'), ('gelu', 'gelu')])
- @pytest.mark.parametrize('use_bias', [True, False])
- @pytest.mark.parametrize('input_shape', INPUT_SHAPE)
- @pytest.mark.parametrize('dtype', DTYPES)
- def test_layernorm_fp8_mlp_layer(self, mesh_config, activation_type, use_bias, input_shape,
- dtype):
- self._test_layernorm_mlp(mesh_config,
- activation_type,
- use_bias,
- input_shape,
- dtype,
- use_fp8=True)
+ @pytest.mark.parametrize("mesh_config", generate_fsdp_and_tp_configs())
+ @pytest.mark.parametrize("activation_type", [("gelu",), ("gelu", "linear"), ("gelu", "gelu")])
+ @pytest.mark.parametrize("use_bias", [True, False])
+ @pytest.mark.parametrize("input_shape", INPUT_SHAPE)
+ @pytest.mark.parametrize("dtype", DTYPES)
+ def test_layernorm_fp8_mlp_layer(
+ self, mesh_config, activation_type, use_bias, input_shape, dtype
+ ):
+ self._test_layernorm_mlp(
+ mesh_config, activation_type, use_bias, input_shape, dtype, use_fp8=True
+ )
diff --git a/tests/jax/test_distributed_softmax.py b/tests/jax/test_distributed_softmax.py
index daa3a9111dd9cd77f88971873a4f4bc341cb7d4c..0ed6b84fd525e8c9c1518410179ea3ed97645976 100644
--- a/tests/jax/test_distributed_softmax.py
+++ b/tests/jax/test_distributed_softmax.py
@@ -25,7 +25,7 @@ class TestDistributedSoftmax:
def generate_collectives_count_ref(self):
# for loss
- all_reduce_loss_bytes = 4 # 1 * FP32
+ all_reduce_loss_bytes = 4 # 1 * FP32
return generate_collectives_count(allreduce=all_reduce_loss_bytes, allgather=0, other=0)
def generate_inputs(self, shape, mesh_resource, softmax_type, dtype, bad_sharding):
@@ -38,49 +38,65 @@ class TestDistributedSoftmax:
mask = make_self_mask(batch, sqelen)
if not bad_sharding:
- x_pspec = PartitionSpec(mesh_resource.dp_resource, mesh_resource.tp_resource,
- None, None)
+ x_pspec = PartitionSpec(
+ mesh_resource.dp_resource, mesh_resource.tp_resource, None, None
+ )
else:
- x_pspec = PartitionSpec(mesh_resource.dp_resource, None,
- None, mesh_resource.tp_resource)
+ x_pspec = PartitionSpec(
+ mesh_resource.dp_resource, None, None, mesh_resource.tp_resource
+ )
mask_pspec = PartitionSpec(mesh_resource.dp_resource, None, None, None)
return (x, mask), (x_pspec, mask_pspec)
@staticmethod
def target_func(x, mask, scale_factor=1.0, softmax_type=SoftmaxType.SCALED):
- return jnp.mean(softmax(x, mask, scale_factor=scale_factor, softmax_type=softmax_type))
+ return jnp.mean(softmax(x, mask, scale_factor=scale_factor, softmax_type=softmax_type))
@staticmethod
def ref_func(x, mask, scale_factor=1.0, dtype=jnp.float16):
bias = None
if mask is not None:
- bias = jax.lax.select(mask > 0,
- jnp.full(mask.shape, -1e10).astype(dtype),
- jnp.full(mask.shape, 0.).astype(dtype))
+ bias = jax.lax.select(
+ mask > 0,
+ jnp.full(mask.shape, -1e10).astype(dtype),
+ jnp.full(mask.shape, 0.0).astype(dtype),
+ )
if bias is not None:
x = x + bias.astype(dtype)
output = jax.nn.softmax(x * scale_factor)
return jnp.mean(output)
- @pytest.mark.parametrize('device_count,mesh_shape,mesh_axes,mesh_resource', generate_configs())
- @pytest.mark.parametrize('data_shape', [[32, 12, 128, 128], [64, 16, 1024, 1024]])
+ @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs())
+ @pytest.mark.parametrize("data_shape", [[32, 12, 128, 128], [64, 16, 1024, 1024]])
@pytest.mark.parametrize(
- 'softmax_type',
- [SoftmaxType.SCALED, SoftmaxType.SCALED_MASKED, SoftmaxType.SCALED_UPPER_TRIANG_MASKED])
- @pytest.mark.parametrize('scale_factor', [1.0, 3.0])
- @pytest.mark.parametrize('dtype', DTYPES)
- @pytest.mark.parametrize('bad_sharding', [False, True])
- def test_softmax(self, device_count, mesh_shape, mesh_axes, mesh_resource, data_shape,
- softmax_type, scale_factor, dtype, bad_sharding):
-
- target_func = partial(self.target_func,
- scale_factor=scale_factor,
- softmax_type=softmax_type)
+ "softmax_type",
+ [SoftmaxType.SCALED, SoftmaxType.SCALED_MASKED, SoftmaxType.SCALED_UPPER_TRIANG_MASKED],
+ )
+ @pytest.mark.parametrize("scale_factor", [1.0, 3.0])
+ @pytest.mark.parametrize("dtype", DTYPES)
+ @pytest.mark.parametrize("bad_sharding", [False, True])
+ def test_softmax(
+ self,
+ device_count,
+ mesh_shape,
+ mesh_axes,
+ mesh_resource,
+ data_shape,
+ softmax_type,
+ scale_factor,
+ dtype,
+ bad_sharding,
+ ):
+
+ target_func = partial(
+ self.target_func, scale_factor=scale_factor, softmax_type=softmax_type
+ )
ref_func = partial(self.ref_func, scale_factor=scale_factor, dtype=dtype)
- (x, mask), (x_pspec, mask_pspec) = \
- self.generate_inputs(data_shape, mesh_resource, softmax_type, dtype, bad_sharding)
+ (x, mask), (x_pspec, mask_pspec) = self.generate_inputs(
+ data_shape, mesh_resource, softmax_type, dtype, bad_sharding
+ )
collective_count_ref = self.generate_collectives_count_ref()
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = Mesh(devices, mesh_axes)
@@ -90,14 +106,17 @@ class TestDistributedSoftmax:
with warnings.catch_warnings(record=True) as warns:
try:
- compare_ops(target_func,
- ref_func, [x_, mask_],
- collective_count_ref,
- grad_args=(0,),
- metric_fwd_dtype=dtype,
- metric_bwd_dtype=dtype,
- in_shardings=(x_pspec, mask_pspec),
- out_shardings=(None, (x_pspec,)))
+ compare_ops(
+ target_func,
+ ref_func,
+ [x_, mask_],
+ collective_count_ref,
+ grad_args=(0,),
+ metric_fwd_dtype=dtype,
+ metric_bwd_dtype=dtype,
+ in_shardings=(x_pspec, mask_pspec),
+ out_shardings=(None, (x_pspec,)),
+ )
except AssertionError as err:
# Softmax should still produce the correct numerical result with
# bad sharding. However, the collective count may not be the same
diff --git a/tests/jax/test_functions.py b/tests/jax/test_functions.py
index aaa6be77acf589b035c77f0b8c7a77939c9f7c46..d6da307fd3d9400d05965a890c89018b72579828 100644
--- a/tests/jax/test_functions.py
+++ b/tests/jax/test_functions.py
@@ -20,12 +20,14 @@ class TestLoRA:
out = jnp.einsum(pattern, x, la, lb)
return out * scale
- @pytest.mark.parametrize('shape', [(32, 1024), (32, 128, 1024)])
- @pytest.mark.parametrize('dtype', [jnp.float32, jnp.bfloat16])
- @pytest.mark.parametrize('axis_features_pattern', [((-1,), (1024,), '...h,hr,rk->...k'),
- ((-1,), (3, 1024), '...h,hkr,krz->...kz')])
- @pytest.mark.parametrize('rank', [32, 16])
- @pytest.mark.parametrize('alpha', [None, 4, 8])
+ @pytest.mark.parametrize("shape", [(32, 1024), (32, 128, 1024)])
+ @pytest.mark.parametrize("dtype", [jnp.float32, jnp.bfloat16])
+ @pytest.mark.parametrize(
+ "axis_features_pattern",
+ [((-1,), (1024,), "...h,hr,rk->...k"), ((-1,), (3, 1024), "...h,hkr,krz->...kz")],
+ )
+ @pytest.mark.parametrize("rank", [32, 16])
+ @pytest.mark.parametrize("alpha", [None, 4, 8])
def test_lora(self, shape, dtype, axis_features_pattern, rank, alpha):
axis, features, pattern = axis_features_pattern
axis = _normalize_axes(axis, len(shape))
@@ -49,16 +51,20 @@ class TestLoRA:
assert_allclose(out_target, out_ref, dtype=dtype)
- @pytest.mark.parametrize('scope_ref_assert',
- [('none', LoRAScope(False, False, False), False),
- ('all', LoRAScope(True, True, True), False),
- ('qkv_proj', LoRAScope(True, False, False), False),
- ('output_proj', LoRAScope(False, True, False), False),
- ('mlp', LoRAScope(False, False, True), False),
- ('exclude_qkv_proj', LoRAScope(False, True, True), False),
- ('exclude_output_proj', LoRAScope(True, False, True), False),
- ('exclude_mlp', LoRAScope(True, True, False), False),
- ('messing_up', LoRAScope(), True)])
+ @pytest.mark.parametrize(
+ "scope_ref_assert",
+ [
+ ("none", LoRAScope(False, False, False), False),
+ ("all", LoRAScope(True, True, True), False),
+ ("qkv_proj", LoRAScope(True, False, False), False),
+ ("output_proj", LoRAScope(False, True, False), False),
+ ("mlp", LoRAScope(False, False, True), False),
+ ("exclude_qkv_proj", LoRAScope(False, True, True), False),
+ ("exclude_output_proj", LoRAScope(True, False, True), False),
+ ("exclude_mlp", LoRAScope(True, True, False), False),
+ ("messing_up", LoRAScope(), True),
+ ],
+ )
def test_lora_scope_generator(self, scope_ref_assert):
scope, reference, need_assert = scope_ref_assert
try:
diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py
index 3bb7520ca1a995b5c26f0985950e1e4ce89c4ede..3390c36426f13f843953bafc3f007ebe05a4f11f 100644
--- a/tests/jax/test_fused_attn.py
+++ b/tests/jax/test_fused_attn.py
@@ -24,7 +24,7 @@ from transformer_engine.jax.attention import (
QKVLayout,
fused_attn_qkvpacked,
fused_attn_kvpacked,
- fused_attn
+ fused_attn,
)
from transformer_engine.jax.cpp_extensions import FusedAttnHelper
@@ -33,7 +33,7 @@ from transformer_engine.transformer_engine_jax import NVTE_Fused_Attn_Backend
from utils import assert_allclose
-@pytest.fixture(autouse=True, scope='module')
+@pytest.fixture(autouse=True, scope="module")
def init():
"""
WAR for CUDA uninitialized error
@@ -43,10 +43,18 @@ def init():
yield
-def general_dot_product_attention(query: ArrayLike, key: ArrayLike, value: ArrayLike,
- bias: ArrayLike, mask: ArrayLike, deterministic: bool,
- scale_factor: float, dropout_rate: float, dropout_rng: ArrayLike,
- dtype: DTypeLike) -> Array:
+def general_dot_product_attention(
+ query: ArrayLike,
+ key: ArrayLike,
+ value: ArrayLike,
+ bias: ArrayLike,
+ mask: ArrayLike,
+ deterministic: bool,
+ scale_factor: float,
+ dropout_rate: float,
+ dropout_rng: ArrayLike,
+ dtype: DTypeLike,
+) -> Array:
"""
Similar to flax.linen.dot_product_attention but with GQA support
"""
@@ -59,7 +67,7 @@ def general_dot_product_attention(query: ArrayLike, key: ArrayLike, value: Array
num_groups = h_q // h_kv
grouped_query = jnp.reshape(query, (b, s_q, h_kv, num_groups, d))
# logits with shape (b, h_kv, num_groups, s_q, s_kv)
- logits = scale_factor * jnp.einsum('...qhgd,...khd->...hgqk', grouped_query, key)
+ logits = scale_factor * jnp.einsum("...qhgd,...khd->...hgqk", grouped_query, key)
if bias is not None:
# reshape logits without groups
@@ -76,13 +84,13 @@ def general_dot_product_attention(query: ArrayLike, key: ArrayLike, value: Array
softmax_out = jax.nn.softmax(logits).astype(dtype)
- if not deterministic and dropout_rate > 0.:
+ if not deterministic and dropout_rate > 0.0:
keep_prob = 1.0 - dropout_rate
keep = jax.random.bernoulli(dropout_rng, keep_prob, softmax_out.shape)
multiplier = keep.astype(dtype) / jnp.asarray(keep_prob, dtype=dtype)
softmax_out = softmax_out * multiplier
- context = jnp.einsum('...hgqk,...khd->...qhgd', softmax_out, value)
+ context = jnp.einsum("...hgqk,...khd->...qhgd", softmax_out, value)
context = jnp.reshape(context, query.shape)
return context
@@ -105,6 +113,7 @@ def make_decoder_mask(q_tokens: ArrayLike, kv_tokens: ArrayLike) -> Array:
inv_padding_mask = make_attention_mask(q_tokens > 0, kv_tokens > 0)
return combine_masks(inv_causal_mask, inv_padding_mask)
+
def make_mask(q_token: ArrayLike, kv_token: ArrayLike, attn_mask_type: AttnMaskType) -> Array:
"""
Create attention mask based on mask type. A `True` value in the mask means
@@ -118,23 +127,26 @@ def make_mask(q_token: ArrayLike, kv_token: ArrayLike, attn_mask_type: AttnMaskT
mask = jnp.logical_not(inv_mask)
return mask
+
def jax_dpa(query, key, value, bias, q_token, kv_token, dropout_rng, **kwargs):
"""
JAX native dot product attention implementation
"""
- attn_mask_type = kwargs['attn_mask_type']
+ attn_mask_type = kwargs["attn_mask_type"]
mask = make_mask(q_token, kv_token, attn_mask_type)
- output = general_dot_product_attention(query,
- key,
- value,
- bias=bias,
- mask=mask,
- deterministic=not kwargs['is_training'],
- scale_factor=kwargs['scaling_factor'],
- dropout_rate=kwargs['dropout_probability'],
- dropout_rng=dropout_rng,
- dtype=jnp.float32)
+ output = general_dot_product_attention(
+ query,
+ key,
+ value,
+ bias=bias,
+ mask=mask,
+ deterministic=not kwargs["is_training"],
+ scale_factor=kwargs["scaling_factor"],
+ dropout_rate=kwargs["dropout_probability"],
+ dropout_rng=dropout_rng,
+ dtype=jnp.float32,
+ )
return output.astype(query.dtype)
@@ -142,10 +154,10 @@ def customcall_fused_dpa(query, key, value, bias, q_token, kv_token, dropout_rng
"""
TE customcall dot product attention implementation
"""
- attn_mask_type = kwargs['attn_mask_type']
+ attn_mask_type = kwargs["attn_mask_type"]
mask = make_mask(q_token, kv_token, attn_mask_type)
- qkv_layout = kwargs.pop('qkv_layout')
+ qkv_layout = kwargs.pop("qkv_layout")
match qkv_layout:
case QKVLayout.BS3HD:
query, key, value = map(partial(jnp.expand_dims, axis=-3), [query, key, value])
@@ -154,11 +166,13 @@ def customcall_fused_dpa(query, key, value, bias, q_token, kv_token, dropout_rng
case QKVLayout.BSHD_BS2HD:
key, value = map(partial(jnp.expand_dims, axis=-3), [key, value])
kv = jnp.concatenate((key, value), axis=-3)
- return fused_attn_kvpacked(query, kv, bias, mask, dropout_rng,
- **kwargs).astype(query.dtype)
+ return fused_attn_kvpacked(query, kv, bias, mask, dropout_rng, **kwargs).astype(
+ query.dtype
+ )
case QKVLayout.BSHD_BSHD_BSHD:
- return fused_attn(query, key, value, bias, mask, dropout_rng,
- **kwargs).astype(query.dtype)
+ return fused_attn(query, key, value, bias, mask, dropout_rng, **kwargs).astype(
+ query.dtype
+ )
class BiasShape(Enum):
@@ -166,10 +180,10 @@ class BiasShape(Enum):
Enum class to represent the different bias shapes used in the fused attention.
"""
- BIAS_1HSS = '1HSS'
- BIAS_B1SS = 'B1SS'
- BIAS_BHSS = 'BHSS'
- BIAS_11SS = '11SS'
+ BIAS_1HSS = "1HSS"
+ BIAS_B1SS = "B1SS"
+ BIAS_BHSS = "BHSS"
+ BIAS_11SS = "11SS"
@dataclass
@@ -177,6 +191,7 @@ class FusedAttnRunner:
"""
Fused attention runner
"""
+
batch_size: int
max_seqlen_q: int
max_seqlen_kv: int
@@ -198,21 +213,33 @@ class FusedAttnRunner:
if self.qkv_layout == QKVLayout.BS3HD and self.max_seqlen_q != self.max_seqlen_kv:
pytest.skip("BS3HD layout requires max_seqlen_q and max_seqlen_kv to be equal.")
- self.backend = FusedAttnHelper(self.dtype, self.dtype, self.qkv_layout.value,
- self.attn_bias_type.value, self.attn_mask_type.value,
- self.dropout_prob, self.num_heads_q, self.num_heads_kv,
- self.max_seqlen_q, self.max_seqlen_kv,
- self.head_dim).get_fused_attn_backend()
+ self.backend = FusedAttnHelper(
+ self.dtype,
+ self.dtype,
+ self.qkv_layout.value,
+ self.attn_bias_type.value,
+ self.attn_mask_type.value,
+ self.dropout_prob,
+ self.num_heads_q,
+ self.num_heads_kv,
+ self.max_seqlen_q,
+ self.max_seqlen_kv,
+ self.head_dim,
+ ).get_fused_attn_backend()
if self.backend == NVTE_Fused_Attn_Backend.NVTE_No_Backend:
pytest.skip("Unsupported inputs combination or device compute capability.")
if self.attn_bias_type == AttnBiasType.POST_SCALE_BIAS:
if self.attn_mask_type not in [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]:
- pytest.skip("B1SS, BHSS and 11SS bias shapes are only supported for "
- "AttnMaskType.NO_MASK and AttnMaskType.CAUSAL_MASK.")
+ pytest.skip(
+ "B1SS, BHSS and 11SS bias shapes are only supported for "
+ "AttnMaskType.NO_MASK and AttnMaskType.CAUSAL_MASK."
+ )
elif self.backend != NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen:
- pytest.skip("B1SS, BHSS and 11SS bias shapes are only supported for "
- "the F16_arbitrary_seqlen backend.")
+ pytest.skip(
+ "B1SS, BHSS and 11SS bias shapes are only supported for "
+ "the F16_arbitrary_seqlen backend."
+ )
def _setup_inputs(self):
self._check_configs()
@@ -235,24 +262,25 @@ class FusedAttnRunner:
else:
pytest.fail(f"PyTest attempted to use an unrecognized bias_layout = {self.bias_shape}!")
- self.q = jax.random.uniform(q_key, q_shape, self.dtype, -1.)
- self.k = jax.random.uniform(k_key, k_shape, self.dtype, -1.)
- self.v = jax.random.uniform(v_key, v_shape, self.dtype, -1.)
+ self.q = jax.random.uniform(q_key, q_shape, self.dtype, -1.0)
+ self.k = jax.random.uniform(k_key, k_shape, self.dtype, -1.0)
+ self.v = jax.random.uniform(v_key, v_shape, self.dtype, -1.0)
if self.attn_bias_type != AttnBiasType.NO_BIAS:
if self.bias_shape == BiasShape.BIAS_1HSS:
- self.bias = jax.random.uniform(bias_key, bias_shape, self.dtype, -1.)
+ self.bias = jax.random.uniform(bias_key, bias_shape, self.dtype, -1.0)
else:
# [b, 1, s, s], [b, h, s, s] and [1, 1, s, s] bias shapes are workarounds for
# an arbitrary mask where (True/False -> 0/-Inf)
- cudnn_neg_inf = -2.**27. if self.dtype == jnp.bfloat16 else -2.**15.
+ cudnn_neg_inf = -(2.0**27.0) if self.dtype == jnp.bfloat16 else -(2.0**15.0)
self.bias = jnp.full(bias_shape, cudnn_neg_inf, dtype=self.dtype)
max_id = min(self.max_seqlen_q, self.max_seqlen_kv)
- seq_id_size = max_id * 5 // 128 # 5 ids per interval of 128 sequences
+ seq_id_size = max_id * 5 // 128 # 5 ids per interval of 128 sequences
seq_id = jax.random.randint(bias_key, (int(seq_id_size),), 0, max_id).tolist()
for i in range(1, len(seq_id)):
- self.bias = \
- self.bias.at[:, :, seq_id[i-1]:seq_id[i], seq_id[i-1]:seq_id[i]].set(0.)
+ self.bias = self.bias.at[
+ :, :, seq_id[i - 1] : seq_id[i], seq_id[i - 1] : seq_id[i]
+ ].set(0.0)
else:
self.bias = None
@@ -271,7 +299,7 @@ class FusedAttnRunner:
self.valid_len_kv, self.token_kv = gen_valid(self.batch_size, self.max_seqlen_kv, pad_ratio)
self.dropout_rng = dropout_key if self.dropout_prob > 0 else None
- self.scaling_factor = 1. / sqrt(self.head_dim)
+ self.scaling_factor = 1.0 / sqrt(self.head_dim)
def test_forward(self):
"""
@@ -281,19 +309,19 @@ class FusedAttnRunner:
args = [self.q, self.k, self.v, self.bias, self.token_q, self.token_kv, self.dropout_rng]
kwargs = {
- 'attn_bias_type': self.attn_bias_type,
- 'attn_mask_type': self.attn_mask_type,
- 'scaling_factor': self.scaling_factor,
- 'dropout_probability': self.dropout_prob,
- 'is_training': self.is_training,
- 'qkv_layout': self.qkv_layout,
+ "attn_bias_type": self.attn_bias_type,
+ "attn_mask_type": self.attn_mask_type,
+ "scaling_factor": self.scaling_factor,
+ "dropout_probability": self.dropout_prob,
+ "is_training": self.is_training,
+ "qkv_layout": self.qkv_layout,
}
# Convert the outputs to float32 for the elementwise comparison
primitive_out = customcall_fused_dpa(*args, **kwargs).astype(jnp.float32)
reference_out = jax_dpa(*args, **kwargs).astype(jnp.float32)
- if self.is_training and self.dropout_prob > 0.:
+ if self.is_training and self.dropout_prob > 0.0:
return
primitive_valid, primitive_invalid = jnp.split(primitive_out, (self.valid_len_q,), axis=1)
@@ -322,12 +350,12 @@ class FusedAttnRunner:
args = [self.q, self.k, self.v, self.bias, self.token_q, self.token_kv, self.dropout_rng]
kwargs = {
- 'attn_bias_type': self.attn_bias_type,
- 'attn_mask_type': self.attn_mask_type,
- 'scaling_factor': self.scaling_factor,
- 'dropout_probability': self.dropout_prob,
- 'is_training': self.is_training,
- 'qkv_layout': self.qkv_layout,
+ "attn_bias_type": self.attn_bias_type,
+ "attn_mask_type": self.attn_mask_type,
+ "scaling_factor": self.scaling_factor,
+ "dropout_probability": self.dropout_prob,
+ "is_training": self.is_training,
+ "qkv_layout": self.qkv_layout,
}
# We can compute dBias only for the [1, h, s, s] layout
@@ -336,12 +364,18 @@ class FusedAttnRunner:
# Summing the results in FP16/BF16 may cause overflow, so use FP32 for the summation
jitted_primitive = jit(
value_and_grad(
- lambda q, k, v, bias, *args: grad_func(customcall_fused_dpa, q, k, v, bias, *args,
- **kwargs), arg_nums))
+ lambda q, k, v, bias, *args: grad_func(
+ customcall_fused_dpa, q, k, v, bias, *args, **kwargs
+ ),
+ arg_nums,
+ )
+ )
jitted_reference = jit(
value_and_grad(
lambda q, k, v, bias, *args: grad_func(jax_dpa, q, k, v, bias, *args, **kwargs),
- arg_nums))
+ arg_nums,
+ )
+ )
primitive_out, primitive_dgrad = jitted_primitive(*args)
reference_out, reference_dgrad = jitted_reference(*args)
@@ -350,9 +384,9 @@ class FusedAttnRunner:
if self.dropout_prob > 0.0:
return
- assert_allclose(primitive_out.astype(jnp.float32),
- reference_out.astype(jnp.float32),
- dtype=self.dtype)
+ assert_allclose(
+ primitive_out.astype(jnp.float32), reference_out.astype(jnp.float32), dtype=self.dtype
+ )
def check_dqkv(primitive, reference, valid_len):
primitive_valid, primitive_invalid = jnp.split(primitive, (valid_len,), axis=1)
@@ -374,81 +408,158 @@ class FusedAttnRunner:
primitive_dbias = jnp.float32(primitive_dgrad[3])
reference_dbias = jnp.float32(reference_dgrad[3])
- assert_allclose(primitive_dbias[..., self.valid_len_q:, self.valid_len_kv:],
- jnp.zeros_like(primitive_dbias[..., self.valid_len_q:,
- self.valid_len_kv:]),
- dtype=self.dtype)
+ assert_allclose(
+ primitive_dbias[..., self.valid_len_q :, self.valid_len_kv :],
+ jnp.zeros_like(primitive_dbias[..., self.valid_len_q :, self.valid_len_kv :]),
+ dtype=self.dtype,
+ )
# dbias padded part
- assert_allclose(primitive_dbias[..., self.valid_len_q:, self.valid_len_kv:],
- reference_dbias[..., self.valid_len_q:, self.valid_len_kv:],
- dtype=self.dtype)
+ assert_allclose(
+ primitive_dbias[..., self.valid_len_q :, self.valid_len_kv :],
+ reference_dbias[..., self.valid_len_q :, self.valid_len_kv :],
+ dtype=self.dtype,
+ )
# dbias valid part
- assert_allclose(primitive_dbias[..., :self.valid_len_q, :self.valid_len_kv],
- reference_dbias[..., :self.valid_len_q, :self.valid_len_kv],
- dtype=self.dtype)
-
-
-@pytest.mark.parametrize('attn_bias_type, bias_shape', [
- pytest.param(AttnBiasType.NO_BIAS, None, id='NO_BIAS'),
- pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape.BIAS_1HSS, id='POST_SCALE_BIAS-1HSS'),
- pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape.BIAS_B1SS, id='POST_SCALE_BIAS-B1SS'),
- pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape.BIAS_BHSS, id='POST_SCALE_BIAS-BHSS'),
- pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape.BIAS_11SS, id='POST_SCALE_BIAS-11SS'),
-])
-@pytest.mark.parametrize('attn_mask_type', [
- pytest.param(AttnMaskType.NO_MASK, id='NO_MASK'),
- pytest.param(AttnMaskType.PADDING_MASK, id='PADDING'),
- pytest.param(AttnMaskType.CAUSAL_MASK, id='CAUSAL'),
- pytest.param(AttnMaskType.PADDING_CAUSAL_MASK, id='PADDING_CAUSAL'),
-])
-@pytest.mark.parametrize('qkv_layout', [
- pytest.param(QKVLayout.BS3HD, id='QKV_PACKED'),
- pytest.param(QKVLayout.BSHD_BS2HD, id='KV_PACKED'),
- pytest.param(QKVLayout.BSHD_BSHD_BSHD, id='SEPARATE'),
-])
-@pytest.mark.parametrize('dtype', [
- pytest.param(jnp.bfloat16, id="BF16"),
- pytest.param(jnp.float16, id="FP16"),
-])
-@pytest.mark.parametrize('b, s_q, s_kv, h_q, h_kv, d', [
- pytest.param(32, 128, 128, 16, 16, 64, id='32-128-128-16-16-64-SELF'),
- pytest.param(4, 2048, 2048, 12, 12, 64, id='4-2048-2048-12-12-64-SELF'),
- pytest.param(32, 512, 128, 16, 16, 64, id='32-512-128-16-16-64-CROSS'),
- pytest.param(4, 2048, 1024, 12, 12, 64, id='4-2048-1048-12-12-64-CROSS'),
- pytest.param(32, 128, 128, 16, 8, 64, id='32-128-128-16-8-64-GQA'),
- pytest.param(4, 2048, 2048, 12, 6, 64, id='4-2048-2048-12-6-64-GQA'),
-])
-@pytest.mark.parametrize('dropout_prob', [
- pytest.param(0.0, id="DROP_0.0"),
- pytest.param(0.1, id="DROP_0.1"),
-])
+ assert_allclose(
+ primitive_dbias[..., : self.valid_len_q, : self.valid_len_kv],
+ reference_dbias[..., : self.valid_len_q, : self.valid_len_kv],
+ dtype=self.dtype,
+ )
+
+
+@pytest.mark.parametrize(
+ "attn_bias_type, bias_shape",
+ [
+ pytest.param(AttnBiasType.NO_BIAS, None, id="NO_BIAS"),
+ pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape.BIAS_1HSS, id="POST_SCALE_BIAS-1HSS"),
+ pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape.BIAS_B1SS, id="POST_SCALE_BIAS-B1SS"),
+ pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape.BIAS_BHSS, id="POST_SCALE_BIAS-BHSS"),
+ pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape.BIAS_11SS, id="POST_SCALE_BIAS-11SS"),
+ ],
+)
+@pytest.mark.parametrize(
+ "attn_mask_type",
+ [
+ pytest.param(AttnMaskType.NO_MASK, id="NO_MASK"),
+ pytest.param(AttnMaskType.PADDING_MASK, id="PADDING"),
+ pytest.param(AttnMaskType.CAUSAL_MASK, id="CAUSAL"),
+ pytest.param(AttnMaskType.PADDING_CAUSAL_MASK, id="PADDING_CAUSAL"),
+ ],
+)
+@pytest.mark.parametrize(
+ "qkv_layout",
+ [
+ pytest.param(QKVLayout.BS3HD, id="QKV_PACKED"),
+ pytest.param(QKVLayout.BSHD_BS2HD, id="KV_PACKED"),
+ pytest.param(QKVLayout.BSHD_BSHD_BSHD, id="SEPARATE"),
+ ],
+)
+@pytest.mark.parametrize(
+ "dtype",
+ [
+ pytest.param(jnp.bfloat16, id="BF16"),
+ pytest.param(jnp.float16, id="FP16"),
+ ],
+)
+@pytest.mark.parametrize(
+ "b, s_q, s_kv, h_q, h_kv, d",
+ [
+ pytest.param(32, 128, 128, 16, 16, 64, id="32-128-128-16-16-64-SELF"),
+ pytest.param(4, 2048, 2048, 12, 12, 64, id="4-2048-2048-12-12-64-SELF"),
+ pytest.param(32, 512, 128, 16, 16, 64, id="32-512-128-16-16-64-CROSS"),
+ pytest.param(4, 2048, 1024, 12, 12, 64, id="4-2048-1024-12-12-64-CROSS"),
+ pytest.param(32, 128, 128, 16, 8, 64, id="32-128-128-16-8-64-GQA"),
+ pytest.param(4, 2048, 2048, 12, 6, 64, id="4-2048-2048-12-6-64-GQA"),
+ ],
+)
+@pytest.mark.parametrize(
+ "dropout_prob",
+ [
+ pytest.param(0.0, id="DROP_0.0"),
+ pytest.param(0.1, id="DROP_0.1"),
+ ],
+)
class TestFusedAttn:
"""
Fused attention tester
"""
@staticmethod
- @pytest.mark.parametrize('is_training', [
- pytest.param(True, id='TRAINING'),
- pytest.param(False, id='INFERENCE'),
- ])
- def test_forward(b, s_q, s_kv, h_q, h_kv, d, attn_bias_type, attn_mask_type, dropout_prob,
- dtype, is_training, qkv_layout, bias_shape):
+ @pytest.mark.parametrize(
+ "is_training",
+ [
+ pytest.param(True, id="TRAINING"),
+ pytest.param(False, id="INFERENCE"),
+ ],
+ )
+ def test_forward(
+ b,
+ s_q,
+ s_kv,
+ h_q,
+ h_kv,
+ d,
+ attn_bias_type,
+ attn_mask_type,
+ dropout_prob,
+ dtype,
+ is_training,
+ qkv_layout,
+ bias_shape,
+ ):
"""
Test forward with parameterized configs
"""
- runner = FusedAttnRunner(b, s_q, s_kv, h_q, h_kv, d, attn_bias_type, attn_mask_type,
- dropout_prob, dtype, is_training, qkv_layout, bias_shape)
+ runner = FusedAttnRunner(
+ b,
+ s_q,
+ s_kv,
+ h_q,
+ h_kv,
+ d,
+ attn_bias_type,
+ attn_mask_type,
+ dropout_prob,
+ dtype,
+ is_training,
+ qkv_layout,
+ bias_shape,
+ )
runner.test_forward()
@staticmethod
- def test_backward(b, s_q, s_kv, h_q, h_kv, d, attn_bias_type, attn_mask_type, dropout_prob,
- dtype, qkv_layout, bias_shape):
+ def test_backward(
+ b,
+ s_q,
+ s_kv,
+ h_q,
+ h_kv,
+ d,
+ attn_bias_type,
+ attn_mask_type,
+ dropout_prob,
+ dtype,
+ qkv_layout,
+ bias_shape,
+ ):
"""
Test backward with parameterized configs
"""
- runner = FusedAttnRunner(b, s_q, s_kv, h_q, h_kv, d, attn_bias_type, attn_mask_type,
- dropout_prob, dtype, True, qkv_layout, bias_shape)
+ runner = FusedAttnRunner(
+ b,
+ s_q,
+ s_kv,
+ h_q,
+ h_kv,
+ d,
+ attn_bias_type,
+ attn_mask_type,
+ dropout_prob,
+ dtype,
+ True,
+ qkv_layout,
+ bias_shape,
+ )
runner.test_backward()
diff --git a/tests/jax/test_helper.py b/tests/jax/test_helper.py
index 90f558744368281c5ffa6e87f1de054e951d301f..3a0add0a38cfdaf4736349963681dec8d4163b79 100644
--- a/tests/jax/test_helper.py
+++ b/tests/jax/test_helper.py
@@ -27,21 +27,28 @@ class TestFP8Helper(unittest.TestCase):
fp8_format = FP8Format.E4M3
amax_history_len = 10
- FP8Helper.initialize(margin=margin,
- fp8_format=fp8_format,
- amax_history_len=amax_history_len)
+ FP8Helper.initialize(
+ margin=margin, fp8_format=fp8_format, amax_history_len=amax_history_len
+ )
self.assertEqual(
- FP8Helper.MARGIN, margin, f"FP8Helper.MARGIN initialization failed, should be {margin}"
- f" but got {FP8Helper.MARGIN}.")
+ FP8Helper.MARGIN,
+ margin,
+ f"FP8Helper.MARGIN initialization failed, should be {margin}"
+ f" but got {FP8Helper.MARGIN}.",
+ )
self.assertEqual(
- FP8Helper.FP8_FORMAT, fp8_format,
+ FP8Helper.FP8_FORMAT,
+ fp8_format,
f"FP8Helper.FP8_FORMAT initialization failed, should be {fp8_format}"
- f" but got {FP8Helper.FP8_FORMAT}.")
+ f" but got {FP8Helper.FP8_FORMAT}.",
+ )
self.assertEqual(
- FP8Helper.AMAX_HISTORY_LEN, amax_history_len,
+ FP8Helper.AMAX_HISTORY_LEN,
+ amax_history_len,
f"FP8Helper.AMAX_HISTORY_LEN initialization failed, should be {amax_history_len}"
- f" but got {FP8Helper.AMAX_HISTORY_LEN}.")
+ f" but got {FP8Helper.AMAX_HISTORY_LEN}.",
+ )
FP8Helper.finalize()
@@ -77,7 +84,7 @@ class TestFP8Functions(unittest.TestCase):
@unittest.skipIf(not is_fp8_supported, reason=reason)
def test_fp8_autocast(self):
- FP8Helper.finalize() # Ensure the testing not affect by previous tests.
+ FP8Helper.finalize()  # Ensure the test is not affected by previous tests.
self._check_defult_state()
with fp8_autocast(enabled=False, fp8_recipe=DelayedScaling()):
@@ -102,21 +109,21 @@ class TestFP8Functions(unittest.TestCase):
@unittest.skipIf(not is_fp8_supported, reason=reason)
def test_fp8_autocast_with_sharding_resource(self):
- FP8Helper.finalize() # Ensure the testing not affect by previous tests.
+ FP8Helper.finalize()  # Ensure the test is not affected by previous tests.
self._check_defult_state()
ds = DelayedScaling(margin=5.0, fp8_format=FP8Format.E4M3, amax_history_len=1)
mesh_s = (
(MeshResource(None, None)),
- (MeshResource('dp', None)),
- (MeshResource(None, 'tp')),
- (MeshResource('dp', 'tp')),
+ (MeshResource("dp", None)),
+ (MeshResource(None, "tp")),
+ (MeshResource("dp", "tp")),
)
# TODO (Ming Huang): Support multi-GPU testing. # pylint: disable=fixme
mesh_shape = (1, 1)
devices = np.asarray(jax.devices()[:1]).reshape(*mesh_shape)
- with jax.sharding.Mesh(devices, ('dp', 'tp')):
+ with jax.sharding.Mesh(devices, ("dp", "tp")):
for sr in mesh_s:
with fp8_autocast(enabled=True, fp8_recipe=ds, mesh_resource=sr):
self.assertTrue(FP8Helper.is_fp8_enabled())
diff --git a/tests/jax/test_layer.py b/tests/jax/test_layer.py
index b8f52c91a2d86fe7f16e8a375a0fd4efdb532dc5..fa04382d598c6054f07d39075658ab5d3f70faa0 100644
--- a/tests/jax/test_layer.py
+++ b/tests/jax/test_layer.py
@@ -22,7 +22,7 @@ from transformer_engine.jax.fp8 import FP8Helper, is_fp8_available
is_fp8_supported, reason = is_fp8_available()
-@pytest.fixture(autouse=True, scope='function')
+@pytest.fixture(autouse=True, scope="function")
def enable_fused_attn():
"""Enable fused attention"""
os.environ["NVTE_FUSED_ATTN"] = "1"
@@ -30,9 +30,9 @@ def enable_fused_attn():
del os.environ["NVTE_FUSED_ATTN"]
-DATA_SHAPE = [ # (batch, seqlen, emb_dim)
- pytest.param((32, 128, 1024), id='32-128-1024'),
- pytest.param((32, 512, 1024), id='32-512-1024'),
+DATA_SHAPE = [ # (batch, seqlen, emb_dim)
+ pytest.param((32, 128, 1024), id="32-128-1024"),
+ pytest.param((32, 512, 1024), id="32-512-1024"),
]
DTYPE = [jnp.float32, jnp.bfloat16]
FP8_FORMATS = [Format.E4M3, Format.HYBRID]
@@ -69,123 +69,138 @@ BASE_ATTRS = {
_KEY_OF_ATTENTION_DROPOUT: 0,
_KEY_OF_INTERMEDIATE_DROPOUT: 0,
_KEY_OF_SELF_ATTN_MASK_TYPE: "padding_causal",
- _KEY_OF_LAYERNORM_TYPE: 'layernorm',
+ _KEY_OF_LAYERNORM_TYPE: "layernorm",
}
-ATTRS = [{}, {
- _KEY_OF_LAYERNORM_TYPE: 'rmsnorm',
-}, {
- _KEY_OF_ZERO_CENTERED_GAMMA: True,
- _KEY_OF_LAYERNORM_EPS: 1e-2,
-}, {
- _KEY_OF_LAYERNORM_TYPE: 'rmsnorm',
- _KEY_OF_RESIDUAL_POST_LAYERNORM: True
-}, {
- _KEY_OF_LAYERNORM_TYPE: 'rmsnorm',
- _KEY_OF_OUTPUT_LAYERNORM: True
-}, {
- _KEY_OF_LAYERNORM_TYPE: 'rmsnorm',
- _KEY_OF_RESIDUAL_POST_LAYERNORM: True,
- _KEY_OF_OUTPUT_LAYERNORM: True
-}, {
- _KEY_OF_LAYERNORM_TYPE: 'rmsnorm',
- _KEY_OF_DROP_PATH: 0.1
-}, {
- _KEY_OF_LAYERNORM_TYPE: 'rmsnorm',
- _KEY_OF_FUSE_QKV_PARAMS: False
-}, {
- _KEY_OF_LAYERNORM_TYPE: 'rmsnorm',
- _KEY_OF_MLP_ACTIVATIONS: ('gelu', 'linear'),
-}, {
- _KEY_OF_SCALE_ATTN_LOGITS: True,
- _KEY_OF_LAYERNORM_TYPE: 'rmsnorm',
- _KEY_OF_HIDDEN_DROPOUT: 0.8,
- _KEY_OF_INTERMEDIATE_DROPOUT: 0.5,
- _KEY_OF_MLP_ACTIVATIONS: ('gelu', 'linear'),
- _KEY_OF_USE_BIAS: True,
-}, {
- _KEY_OF_TRANSPOSE_BS: False,
- _KEY_OF_SCALE_ATTN_LOGITS: True,
- _KEY_OF_LAYERNORM_TYPE: 'rmsnorm',
- _KEY_OF_MLP_ACTIVATIONS: ('gelu', 'linear'),
-}, {
- _KEY_OF_NUM_HEADS: 8,
- _KEY_OF_NUM_GQA_GROUPS: 4,
- _KEY_OF_TRANSPOSE_BS: False,
- _KEY_OF_SCALE_ATTN_LOGITS: True,
- _KEY_OF_MLP_ACTIVATIONS: ('gelu',),
- _KEY_OF_USE_BIAS: True,
-}, {
- _KEY_OF_LAYERNORM_TYPE: 'rmsnorm',
- _KEY_OF_MLP_ACTIVATIONS: (('silu', 'linear')),
-}, {
- _KEY_OF_SCALE_ATTN_LOGITS: True,
- _KEY_OF_LAYERNORM_TYPE: 'rmsnorm',
- _KEY_OF_HIDDEN_DROPOUT: 0.8,
- _KEY_OF_INTERMEDIATE_DROPOUT: 0.5,
- _KEY_OF_MLP_ACTIVATIONS: (('silu', 'linear')),
- _KEY_OF_USE_BIAS: True,
-}, {
- _KEY_OF_TRANSPOSE_BS: False,
- _KEY_OF_SCALE_ATTN_LOGITS: True,
- _KEY_OF_LAYERNORM_TYPE: 'rmsnorm',
- _KEY_OF_MLP_ACTIVATIONS: (('silu', 'linear')),
-}, {
- _KEY_OF_NUM_HEADS: 8,
- _KEY_OF_NUM_GQA_GROUPS: 4,
- _KEY_OF_TRANSPOSE_BS: False,
- _KEY_OF_SCALE_ATTN_LOGITS: True,
- _KEY_OF_LAYERNORM_TYPE: 'layernorm',
- _KEY_OF_MLP_ACTIVATIONS: (('silu',)),
- _KEY_OF_USE_BIAS: True,
-}, {
- _KEY_OF_TRANSPOSE_BS: False,
- _KEY_OF_LAYERNORM_TYPE: 'rmsnorm',
- _KEY_OF_NUM_GQA_GROUPS: 1,
- _KEY_OF_ENABLE_ROPE: True,
- _KEY_OF_ROPE_GROUP_METHOD: "consecutive",
- _KEY_OF_FLOAT32_ATTENTION_LOGITS: True,
-}, {
- _KEY_OF_TRANSPOSE_BS: True,
- _KEY_OF_ENABLE_ROPE: True,
- _KEY_OF_ROPE_GROUP_METHOD: "consecutive",
- _KEY_OF_USE_BIAS: True,
-}, {
- _KEY_OF_TRANSPOSE_BS: False,
- _KEY_OF_LAYERNORM_TYPE: 'layernorm',
- _KEY_OF_NUM_GQA_GROUPS: 2,
- _KEY_OF_ENABLE_ROPE: True,
- _KEY_OF_ROPE_GROUP_METHOD: "alternate",
- _KEY_OF_USE_BIAS: True,
- _KEY_OF_FLOAT32_ATTENTION_LOGITS: True,
-}, {
- _KEY_OF_TRANSPOSE_BS: True,
- _KEY_OF_LAYERNORM_TYPE: 'rmsnorm',
- _KEY_OF_ENABLE_ROPE: True,
- _KEY_OF_ROPE_GROUP_METHOD: "alternate",
- _KEY_OF_USE_BIAS: True,
-}, {
- _KEY_OF_HIDDEN_DROPOUT: 0.3,
- _KEY_OF_HIDDEN_DROPOUT_DIMS: (0,),
- _KEY_OF_INTERMEDIATE_DROPOUT: 0.5,
- _KEY_OF_INTERMEDIATE_DROPOUT_DIMS: (1,),
-}, {
- _KEY_OF_SELF_ATTN_MASK_TYPE: "padding",
- _KEY_OF_USE_BIAS: True,
-}, {
- _KEY_OF_RELATIVE_EMBEDDING: False,
- _KEY_OF_SELF_ATTN_BIAS_TYPE: "no_bias",
-}, {
- _KEY_OF_ATTENTION_DROPOUT: 0.3,
-}, {
- _KEY_OF_MLP_ACTIVATIONS: (('relu', 'relu')),
-}]
+ATTRS = [
+ {},
+ {
+ _KEY_OF_LAYERNORM_TYPE: "rmsnorm",
+ },
+ {
+ _KEY_OF_ZERO_CENTERED_GAMMA: True,
+ _KEY_OF_LAYERNORM_EPS: 1e-2,
+ },
+ {_KEY_OF_LAYERNORM_TYPE: "rmsnorm", _KEY_OF_RESIDUAL_POST_LAYERNORM: True},
+ {_KEY_OF_LAYERNORM_TYPE: "rmsnorm", _KEY_OF_OUTPUT_LAYERNORM: True},
+ {
+ _KEY_OF_LAYERNORM_TYPE: "rmsnorm",
+ _KEY_OF_RESIDUAL_POST_LAYERNORM: True,
+ _KEY_OF_OUTPUT_LAYERNORM: True,
+ },
+ {_KEY_OF_LAYERNORM_TYPE: "rmsnorm", _KEY_OF_DROP_PATH: 0.1},
+ {_KEY_OF_LAYERNORM_TYPE: "rmsnorm", _KEY_OF_FUSE_QKV_PARAMS: False},
+ {
+ _KEY_OF_LAYERNORM_TYPE: "rmsnorm",
+ _KEY_OF_MLP_ACTIVATIONS: ("gelu", "linear"),
+ },
+ {
+ _KEY_OF_SCALE_ATTN_LOGITS: True,
+ _KEY_OF_LAYERNORM_TYPE: "rmsnorm",
+ _KEY_OF_HIDDEN_DROPOUT: 0.8,
+ _KEY_OF_INTERMEDIATE_DROPOUT: 0.5,
+ _KEY_OF_MLP_ACTIVATIONS: ("gelu", "linear"),
+ _KEY_OF_USE_BIAS: True,
+ },
+ {
+ _KEY_OF_TRANSPOSE_BS: False,
+ _KEY_OF_SCALE_ATTN_LOGITS: True,
+ _KEY_OF_LAYERNORM_TYPE: "rmsnorm",
+ _KEY_OF_MLP_ACTIVATIONS: ("gelu", "linear"),
+ },
+ {
+ _KEY_OF_NUM_HEADS: 8,
+ _KEY_OF_NUM_GQA_GROUPS: 4,
+ _KEY_OF_TRANSPOSE_BS: False,
+ _KEY_OF_SCALE_ATTN_LOGITS: True,
+ _KEY_OF_MLP_ACTIVATIONS: ("gelu",),
+ _KEY_OF_USE_BIAS: True,
+ },
+ {
+ _KEY_OF_LAYERNORM_TYPE: "rmsnorm",
+ _KEY_OF_MLP_ACTIVATIONS: (("silu", "linear")),
+ },
+ {
+ _KEY_OF_SCALE_ATTN_LOGITS: True,
+ _KEY_OF_LAYERNORM_TYPE: "rmsnorm",
+ _KEY_OF_HIDDEN_DROPOUT: 0.8,
+ _KEY_OF_INTERMEDIATE_DROPOUT: 0.5,
+ _KEY_OF_MLP_ACTIVATIONS: (("silu", "linear")),
+ _KEY_OF_USE_BIAS: True,
+ },
+ {
+ _KEY_OF_TRANSPOSE_BS: False,
+ _KEY_OF_SCALE_ATTN_LOGITS: True,
+ _KEY_OF_LAYERNORM_TYPE: "rmsnorm",
+ _KEY_OF_MLP_ACTIVATIONS: (("silu", "linear")),
+ },
+ {
+ _KEY_OF_NUM_HEADS: 8,
+ _KEY_OF_NUM_GQA_GROUPS: 4,
+ _KEY_OF_TRANSPOSE_BS: False,
+ _KEY_OF_SCALE_ATTN_LOGITS: True,
+ _KEY_OF_LAYERNORM_TYPE: "layernorm",
+ _KEY_OF_MLP_ACTIVATIONS: (("silu",)),
+ _KEY_OF_USE_BIAS: True,
+ },
+ {
+ _KEY_OF_TRANSPOSE_BS: False,
+ _KEY_OF_LAYERNORM_TYPE: "rmsnorm",
+ _KEY_OF_NUM_GQA_GROUPS: 1,
+ _KEY_OF_ENABLE_ROPE: True,
+ _KEY_OF_ROPE_GROUP_METHOD: "consecutive",
+ _KEY_OF_FLOAT32_ATTENTION_LOGITS: True,
+ },
+ {
+ _KEY_OF_TRANSPOSE_BS: True,
+ _KEY_OF_ENABLE_ROPE: True,
+ _KEY_OF_ROPE_GROUP_METHOD: "consecutive",
+ _KEY_OF_USE_BIAS: True,
+ },
+ {
+ _KEY_OF_TRANSPOSE_BS: False,
+ _KEY_OF_LAYERNORM_TYPE: "layernorm",
+ _KEY_OF_NUM_GQA_GROUPS: 2,
+ _KEY_OF_ENABLE_ROPE: True,
+ _KEY_OF_ROPE_GROUP_METHOD: "alternate",
+ _KEY_OF_USE_BIAS: True,
+ _KEY_OF_FLOAT32_ATTENTION_LOGITS: True,
+ },
+ {
+ _KEY_OF_TRANSPOSE_BS: True,
+ _KEY_OF_LAYERNORM_TYPE: "rmsnorm",
+ _KEY_OF_ENABLE_ROPE: True,
+ _KEY_OF_ROPE_GROUP_METHOD: "alternate",
+ _KEY_OF_USE_BIAS: True,
+ },
+ {
+ _KEY_OF_HIDDEN_DROPOUT: 0.3,
+ _KEY_OF_HIDDEN_DROPOUT_DIMS: (0,),
+ _KEY_OF_INTERMEDIATE_DROPOUT: 0.5,
+ _KEY_OF_INTERMEDIATE_DROPOUT_DIMS: (1,),
+ },
+ {
+ _KEY_OF_SELF_ATTN_MASK_TYPE: "padding",
+ _KEY_OF_USE_BIAS: True,
+ },
+ {
+ _KEY_OF_RELATIVE_EMBEDDING: False,
+ _KEY_OF_SELF_ATTN_BIAS_TYPE: "no_bias",
+ },
+ {
+ _KEY_OF_ATTENTION_DROPOUT: 0.3,
+ },
+ {
+ _KEY_OF_MLP_ACTIVATIONS: (("relu", "relu")),
+ },
+]
ATTRS = [{**BASE_ATTRS, **attr} for attr in ATTRS]
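Aside: the merge on the line above relies on standard dict-unpacking semantics, where later keys override earlier ones, so each entry in ATTRS only needs to list its deviations from BASE_ATTRS. A minimal sketch of the same pattern (keys below are hypothetical stand-ins for the real `_KEY_OF_*` constants):

```python
# Minimal sketch of the {**base, **override} merge used above.
base = {"layernorm_type": "layernorm", "use_bias": False}
overrides = [{}, {"layernorm_type": "rmsnorm"}, {"use_bias": True}]
merged = [{**base, **o} for o in overrides]

assert merged[0] == base  # empty override keeps all defaults
assert merged[1]["layernorm_type"] == "rmsnorm"  # later key wins
assert merged[2] == {"layernorm_type": "layernorm", "use_bias": True}
```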
class BaseRunner:
"""Base runner to define forward and backward tests"""
+
layer_type: TransformerLayerType = None
reference_layer: flax.linen.Module = None
transformations: Dict[str, str] = None
@@ -194,24 +209,24 @@ class BaseRunner:
self.attrs = attrs
self._generate_test_rngs()
        # Disable fused attention for attention dropout because the dropout implementations differ
- if attrs.get(_KEY_OF_ATTENTION_DROPOUT, False) and os.getenv('NVTE_FUSED_ATTN'):
- os.environ['NVTE_FUSED_ATTN'] = "0"
+ if attrs.get(_KEY_OF_ATTENTION_DROPOUT, False) and os.getenv("NVTE_FUSED_ATTN"):
+ os.environ["NVTE_FUSED_ATTN"] = "0"
def _generate_test_rngs(self):
root_rng = jax.random.PRNGKey(0)
params_rng, init_dropout_rng, apply_dropout_rng = jax.random.split(root_rng, 3)
- self.init_rng = {'params': params_rng, 'dropout': init_dropout_rng}
- self.apply_rng = {'dropout': apply_dropout_rng}
+ self.init_rng = {"params": params_rng, "dropout": init_dropout_rng}
+ self.apply_rng = {"dropout": apply_dropout_rng}
def _generate_layer(self, layer_cls, diff_inputs, no_diff_inputs):
layer = layer_cls()
variables = layer.init(self.init_rng, *diff_inputs, *no_diff_inputs)
- others, params = flax.core.pop(variables, 'params')
+ others, params = flax.core.pop(variables, "params")
del variables
return layer, params, others
def _loss_fn(self, diff_xs, no_diff_xs, params, others, model):
- variables = {'params': params, **others}
+ variables = {"params": params, **others}
output = model.apply(variables, *diff_xs, *no_diff_xs, rngs=self.apply_rng)
return jnp.mean(output, dtype=jnp.float32).astype(output.dtype)
@@ -259,15 +274,18 @@ class BaseRunner:
)
_, fp8_meta_grad = flax.core.pop(tmp_grad[0], FP8Helper.FP8_COLLECTION_NAME)
test_others = FP8Helper.update_collections(
- {FP8Helper.FP8_COLLECTION_NAME: fp8_meta_grad}, test_others)
+ {FP8Helper.FP8_COLLECTION_NAME: fp8_meta_grad}, test_others
+ )
del tmp_grad, fp8_meta_grad
grad_fn = jax.value_and_grad(self._loss_fn, argnums=(0, 2), has_aux=False)
- ref_out, (ref_dgrads, ref_wgrads) = grad_fn(inputs, ref_masks, ref_params, ref_others,
- ref_layer)
- test_out, (test_dgrads, test_wgrads) = grad_fn(inputs, test_masks, test_params, test_others,
- test_layer)
+ ref_out, (ref_dgrads, ref_wgrads) = grad_fn(
+ inputs, ref_masks, ref_params, ref_others, ref_layer
+ )
+ test_out, (test_dgrads, test_wgrads) = grad_fn(
+ inputs, test_masks, test_params, test_others, test_layer
+ )
assert_allclose(ref_out, test_out, rtol=rtol, atol=atol)
assert_tree_like_allclose(ref_dgrads, test_dgrads, rtol=rtol, atol=atol)
@@ -278,19 +296,20 @@ class BaseRunner:
class EncoderRunner(BaseRunner):
"""Encoder runner implementations"""
+
layer_type = TransformerLayerType.ENCODER
reference_layer = RefEncoderLayer
transformations = {
- 'attention/qkv/scale': 'pre_attention_layer_norm/scale',
- 'attention/qkv/ln_bias': 'pre_attention_layer_norm/ln_bias',
- 'attention/query/scale': 'pre_attention_layer_norm/scale',
- 'attention/query/ln_bias': 'pre_attention_layer_norm/ln_bias',
- 'mlp/wi_kernel': 'mlp/wi/kernel',
- 'mlp/wi_bias': 'mlp/wi/bias',
- 'mlp/wo_kernel': 'mlp/wo/kernel',
- 'mlp/wo_bias': 'mlp/wo/bias',
- 'mlp/scale': 'pre_mlp_layer_norm/scale',
- 'mlp/ln_bias': 'pre_mlp_layer_norm/ln_bias',
+ "attention/qkv/scale": "pre_attention_layer_norm/scale",
+ "attention/qkv/ln_bias": "pre_attention_layer_norm/ln_bias",
+ "attention/query/scale": "pre_attention_layer_norm/scale",
+ "attention/query/ln_bias": "pre_attention_layer_norm/ln_bias",
+ "mlp/wi_kernel": "mlp/wi/kernel",
+ "mlp/wi_bias": "mlp/wi/bias",
+ "mlp/wo_kernel": "mlp/wo/kernel",
+ "mlp/wo_bias": "mlp/wo/bias",
+ "mlp/scale": "pre_mlp_layer_norm/scale",
+ "mlp/ln_bias": "pre_mlp_layer_norm/ln_bias",
}
def generate_inputs(self, data_shape, dtype):
@@ -307,13 +326,13 @@ class EncoderRunner(BaseRunner):
padded_mask = jnp.zeros((batch, 1, seqlen, seqlen), dtype=jnp.uint8)
causal_mask = jnp.triu(jnp.ones((batch, 1, seqlen, seqlen), dtype=jnp.uint8), k=1)
- if self.attrs[_KEY_OF_SELF_ATTN_MASK_TYPE] in ['casual', 'padding_causal']:
+        if self.attrs[_KEY_OF_SELF_ATTN_MASK_TYPE] in ["causal", "padding_causal"]:
mask = causal_mask
else:
mask = padded_mask
ref_masks = (1 - mask,)
- test_masks = (None, mask) # The second arg of Transformer is encoded tokens.
+ test_masks = (None, mask) # The second arg of Transformer is encoded tokens.
return inputs, (ref_masks, test_masks)
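For context, the mask convention here: `jnp.triu(..., k=1)` sets entries strictly above the diagonal to 1, i.e. positions a token must not attend to, and the reference layer receives the complement. A small self-contained sketch (assumes jax is installed):

```python
# Sketch of the causal-mask convention used in generate_inputs above.
import jax.numpy as jnp

seqlen = 4
causal = jnp.triu(jnp.ones((seqlen, seqlen), dtype=jnp.uint8), k=1)
# causal[i, j] == 1 exactly when j > i: token i may not attend to future token j.
assert bool(causal[0, 3]) and not bool(causal[3, 0])
# The reference layer uses the flipped convention (1 = attendable).
ref_mask = 1 - causal
```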
@@ -322,23 +341,24 @@ class DecoderRunner(BaseRunner):
"""
Decoder runner implementations
"""
+
layer_type = TransformerLayerType.DECODER
reference_layer = RefDecoderLayer
transformations = {
- 'encoder_decoder_attention/qkv/scale': 'pre_cross_attention_layer_norm/scale',
- 'encoder_decoder_attention/qkv/ln_bias': 'pre_cross_attention_layer_norm/ln_bias',
- 'encoder_decoder_attention/query/scale': 'pre_cross_attention_layer_norm/scale',
- 'encoder_decoder_attention/query/ln_bias': 'pre_cross_attention_layer_norm/ln_bias',
- 'self_attention/qkv/scale': 'pre_self_attention_layer_norm/scale',
- 'self_attention/qkv/ln_bias': 'pre_self_attention_layer_norm/ln_bias',
- 'self_attention/query/scale': 'pre_self_attention_layer_norm/scale',
- 'self_attention/query/ln_bias': 'pre_self_attention_layer_norm/ln_bias',
- 'mlp/wi_kernel': 'mlp/wi/kernel',
- 'mlp/wi_bias': 'mlp/wi/bias',
- 'mlp/wo_kernel': 'mlp/wo/kernel',
- 'mlp/wo_bias': 'mlp/wo/bias',
- 'mlp/scale': 'pre_mlp_layer_norm/scale',
- 'mlp/ln_bias': 'pre_mlp_layer_norm/ln_bias',
+ "encoder_decoder_attention/qkv/scale": "pre_cross_attention_layer_norm/scale",
+ "encoder_decoder_attention/qkv/ln_bias": "pre_cross_attention_layer_norm/ln_bias",
+ "encoder_decoder_attention/query/scale": "pre_cross_attention_layer_norm/scale",
+ "encoder_decoder_attention/query/ln_bias": "pre_cross_attention_layer_norm/ln_bias",
+ "self_attention/qkv/scale": "pre_self_attention_layer_norm/scale",
+ "self_attention/qkv/ln_bias": "pre_self_attention_layer_norm/ln_bias",
+ "self_attention/query/scale": "pre_self_attention_layer_norm/scale",
+ "self_attention/query/ln_bias": "pre_self_attention_layer_norm/ln_bias",
+ "mlp/wi_kernel": "mlp/wi/kernel",
+ "mlp/wi_bias": "mlp/wi/bias",
+ "mlp/wo_kernel": "mlp/wo/kernel",
+ "mlp/wo_bias": "mlp/wo/bias",
+ "mlp/scale": "pre_mlp_layer_norm/scale",
+ "mlp/ln_bias": "pre_mlp_layer_norm/ln_bias",
}
def generate_inputs(self, data_shape, dtype):
@@ -352,12 +372,14 @@ class DecoderRunner(BaseRunner):
data_rng = jax.random.PRNGKey(0)
data_rng_0, data_rng_1 = jax.random.split(data_rng, 2)
- inputs = (jax.random.normal(data_rng_0, data_shape,
- dtype), jax.random.normal(data_rng_1, data_shape, dtype))
+ inputs = (
+ jax.random.normal(data_rng_0, data_shape, dtype),
+ jax.random.normal(data_rng_1, data_shape, dtype),
+ )
padded_mask = jnp.zeros((batch, 1, seqlen, seqlen), dtype=jnp.uint8)
causal_mask = jnp.triu(jnp.ones((batch, 1, seqlen, seqlen), dtype=jnp.uint8), k=1)
- if self.attrs[_KEY_OF_SELF_ATTN_MASK_TYPE] in ['casual', 'padding_causal']:
+        if self.attrs[_KEY_OF_SELF_ATTN_MASK_TYPE] in ["causal", "padding_causal"]:
self_mask = causal_mask
else:
self_mask = padded_mask
@@ -368,27 +390,28 @@ class DecoderRunner(BaseRunner):
return inputs, (ref_masks, test_masks)
-@pytest.mark.parametrize('data_shape', DATA_SHAPE)
-@pytest.mark.parametrize('dtype', DTYPE)
-@pytest.mark.parametrize('attrs', ATTRS)
-class BaseTester():
+@pytest.mark.parametrize("data_shape", DATA_SHAPE)
+@pytest.mark.parametrize("dtype", DTYPE)
+@pytest.mark.parametrize("attrs", ATTRS)
+class BaseTester:
"""
Pytest interface to invoke the runner
"""
+
runner = BaseRunner
def test_forward(self, data_shape, dtype, attrs):
"""Test normal datatype forward"""
- FP8Helper.finalize() # Ensure FP8 disabled.
+ FP8Helper.finalize() # Ensure FP8 disabled.
self.runner(attrs).test_forward(data_shape, dtype, rtol=1e-5, atol=7e-5)
def test_backward(self, data_shape, dtype, attrs):
"""Test normal datatype backward"""
- FP8Helper.finalize() # Ensure FP8 disabled.
+ FP8Helper.finalize() # Ensure FP8 disabled.
self.runner(attrs).test_backward(data_shape, dtype, rtol=1e-5, atol=7e-5)
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
- @pytest.mark.parametrize('fp8_format', FP8_FORMATS)
+ @pytest.mark.parametrize("fp8_format", FP8_FORMATS)
def test_forward_with_fp8(self, data_shape, dtype, attrs, fp8_format):
"""Test forward with fp8 enabled"""
FP8Helper.initialize(fp8_format=fp8_format)
@@ -396,7 +419,7 @@ class BaseTester():
FP8Helper.finalize()
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
- @pytest.mark.parametrize('fp8_format', FP8_FORMATS)
+ @pytest.mark.parametrize("fp8_format", FP8_FORMATS)
def test_backward_with_fp8(self, data_shape, dtype, attrs, fp8_format):
"""Test backward with fp8 enabled"""
FP8Helper.initialize(fp8_format=fp8_format)
@@ -408,6 +431,7 @@ class TestEncoderLayer(BaseTester):
"""
Test transformer_engine.jax.flax.TransformerLayer(layer_type=Encoder)
"""
+
runner = EncoderRunner
@@ -415,4 +439,5 @@ class TestDecoderLayer(BaseTester):
"""
Test transformer_engine.jax.flax.TransformerLayer(layer_type=Decoder)
"""
+
runner = DecoderRunner
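The autouse fixture at the top of this file follows the usual set-then-cleanup pattern for environment variables; a self-contained sketch of the same idiom (the variable name matches the one above, everything else is illustrative):

```python
# Sketch of the autouse env-var fixture idiom used in this test module.
import os
import pytest

@pytest.fixture(autouse=True, scope="function")
def _fused_attn_env():
    os.environ["NVTE_FUSED_ATTN"] = "1"  # applied before every test in the module
    yield
    del os.environ["NVTE_FUSED_ATTN"]  # cleaned up even if the test fails
```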
diff --git a/tests/jax/test_praxis_layers.py b/tests/jax/test_praxis_layers.py
index 453a8393ed5323fb8f631a393d81d8eecf574ab8..92a6c80028a2be38b8cfb5bb71ce86e9160276d5 100644
--- a/tests/jax/test_praxis_layers.py
+++ b/tests/jax/test_praxis_layers.py
@@ -37,13 +37,13 @@ from transformer_engine.jax.softmax import SoftmaxType
is_fp8_supported, reason = is_fp8_available()
-DATA_SHAPE = [(32, 128, 512), (32, 512, 512)] # (B, S, H)
+DATA_SHAPE = [(32, 128, 512), (32, 512, 512)] # (B, S, H)
DTYPE = [jnp.float32, jnp.bfloat16]
ENABLE_FP8 = [False, True]
FP8_FORMATS = [Format.E4M3, Format.HYBRID]
-@pytest.fixture(autouse=True, scope='module')
+@pytest.fixture(autouse=True, scope="module")
def enable_fused_attn():
"""
Enable fused attn for hopper+ arch.
@@ -58,19 +58,16 @@ def enable_fused_attn():
def compare_dict(ref_fd, test_fd, rtol=1e-05, atol=1e-08):
for key in ref_fd:
- assert key in test_fd, \
- f"{key} not found in test dict {test_fd}"
- assert isinstance(test_fd[key], type(ref_fd[key])), \
- f"The data type is not match between ref and test " \
- f" Dict on {key=}"
+ assert key in test_fd, f"{key} not found in test dict {test_fd}"
+ assert isinstance(
+ test_fd[key], type(ref_fd[key])
+ ), f"The data type is not match between ref and test Dict on {key=}"
if isinstance(ref_fd[key], Dict):
compare_dict(ref_fd[key], test_fd[key], rtol, atol)
else:
- assert_allclose(ref_fd[key],
- test_fd[key],
- rtol=rtol,
- atol=atol,
- err_msg=f"{key=} is not close")
+ assert_allclose(
+ ref_fd[key], test_fd[key], rtol=rtol, atol=atol, err_msg=f"{key=} is not close"
+ )
class TestLayer:
@@ -105,9 +102,10 @@ class TestLayer:
lyr_name = self.get_layer_name()
- if 'params' in flax_variables:
- synced_praxis_variables['params'][lyr_name]['cld'] = \
- flax.core.unfreeze(flax_variables['params'])
+ if "params" in flax_variables:
+ synced_praxis_variables["params"][lyr_name]["cld"] = flax.core.unfreeze(
+ flax_variables["params"]
+ )
return synced_praxis_variables, flax_variables
@@ -116,23 +114,19 @@ class TestLayer:
lyr_name = self.get_layer_name()
- if 'params' in synced_praxis_grads:
- synced_praxis_grads['params'] = \
- synced_praxis_grads['params'][lyr_name]['cld']
+ if "params" in synced_praxis_grads:
+ synced_praxis_grads["params"] = synced_praxis_grads["params"][lyr_name]["cld"]
if FP8Helper.is_fp8_enabled():
- synced_praxis_grads[FP8Helper.FP8_COLLECTION_NAME] = \
- synced_praxis_grads[FP8Helper.FP8_COLLECTION_NAME][lyr_name]['cld']
+ synced_praxis_grads[FP8Helper.FP8_COLLECTION_NAME] = synced_praxis_grads[
+ FP8Helper.FP8_COLLECTION_NAME
+ ][lyr_name]["cld"]
return synced_praxis_grads, flax.core.unfreeze(flax_wgrads)
- def forward_backward_runner(self,
- data_shape,
- dtype,
- praxis_p,
- flax_cls,
- rtol=1e-05,
- atol=1e-08):
+ def forward_backward_runner(
+ self, data_shape, dtype, praxis_p, flax_cls, rtol=1e-05, atol=1e-08
+ ):
init_key = jax.random.PRNGKey(seed=1234)
test_inputs = self.input_getter(data_shape, dtype)
@@ -148,28 +142,33 @@ class TestLayer:
if "params_axes" in flax_variables:
flax_variables, _ = flax.core.pop(flax_variables, "params_axes")
if FP8Helper.is_fp8_enabled():
- flax_variables, _ = flax.core.pop(flax_variables,
- FP8Helper.FP8_COLLECTION_NAME + "_axes")
+ flax_variables, _ = flax.core.pop(
+ flax_variables, FP8Helper.FP8_COLLECTION_NAME + "_axes"
+ )
praxis_variables, flax_variables = self.sync_variables(praxis_variables, flax_variables)
iter_times = 5 if FP8Helper.is_fp8_enabled() else 1
for _ in range(iter_times):
- praxis_loss, praxis_wgrads, praxis_dgrad = \
- TestLayer.loss_and_grads(praxis_layer, praxis_variables, *test_inputs)
- flax_loss, flax_wgrads, flax_dgrad = \
- TestLayer.loss_and_grads(flax_layer, flax_variables, *test_inputs)
+ praxis_loss, praxis_wgrads, praxis_dgrad = TestLayer.loss_and_grads(
+ praxis_layer, praxis_variables, *test_inputs
+ )
+ flax_loss, flax_wgrads, flax_dgrad = TestLayer.loss_and_grads(
+ flax_layer, flax_variables, *test_inputs
+ )
if FP8Helper.is_fp8_enabled():
- praxis_wgrads.pop('params')
+ praxis_wgrads.pop("params")
praxis_variables = update_collections(praxis_wgrads, praxis_variables)
- flax_wgrads, _ = flax.core.pop(flax_wgrads, 'params')
+ flax_wgrads, _ = flax.core.pop(flax_wgrads, "params")
flax_variables = update_collections(flax_wgrads, flax_variables)
- praxis_loss, praxis_wgrads, praxis_dgrad = \
- TestLayer.loss_and_grads(praxis_layer, praxis_variables, *test_inputs)
- flax_loss, flax_wgrads, flax_dgrad = \
- TestLayer.loss_and_grads(flax_layer, flax_variables, *test_inputs)
+ praxis_loss, praxis_wgrads, praxis_dgrad = TestLayer.loss_and_grads(
+ praxis_layer, praxis_variables, *test_inputs
+ )
+ flax_loss, flax_wgrads, flax_dgrad = TestLayer.loss_and_grads(
+ flax_layer, flax_variables, *test_inputs
+ )
assert_allclose(praxis_loss, flax_loss, rtol=rtol, atol=atol)
assert_allclose(praxis_dgrad, flax_dgrad, rtol=rtol, atol=atol)
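The `loss_and_grads` helper referenced here follows the standard `jax.value_and_grad` pattern with multiple `argnums`. A toy version of that pattern (the helper body below is my assumption for illustration, not the file's actual implementation):

```python
# Hypothetical sketch of a loss_and_grads-style helper via jax.value_and_grad.
import jax
import jax.numpy as jnp

def loss_and_grads(apply_fn, variables, x):
    def loss_fn(v, inp):
        return jnp.mean(apply_fn(v, inp))
    # argnums=(0, 1): differentiate w.r.t. the variables (wgrads) and the input (dgrad).
    loss, (wgrads, dgrad) = jax.value_and_grad(loss_fn, argnums=(0, 1))(variables, x)
    return loss, wgrads, dgrad

apply_fn = lambda v, inp: v["w"] * inp
loss, wgrads, dgrad = loss_and_grads(apply_fn, {"w": jnp.float32(2.0)}, jnp.arange(4.0))
```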
@@ -179,18 +178,13 @@ class TestLayer:
class LayerNormAttr:
- LN_TYPE = 'layernorm_type'
- ZERO_CEN = 'zero_centered_gamma'
- ATTRS = [{
- LN_TYPE: "layernorm",
- ZERO_CEN: False
- }, {
- LN_TYPE: "layernorm",
- ZERO_CEN: True
- }, {
- LN_TYPE: "rmsnorm",
- ZERO_CEN: False
- }]
+ LN_TYPE = "layernorm_type"
+ ZERO_CEN = "zero_centered_gamma"
+ ATTRS = [
+ {LN_TYPE: "layernorm", ZERO_CEN: False},
+ {LN_TYPE: "layernorm", ZERO_CEN: True},
+ {LN_TYPE: "rmsnorm", ZERO_CEN: False},
+ ]
class TestLayerNorm(TestLayer):
@@ -200,7 +194,7 @@ class TestLayerNorm(TestLayer):
return (jax.random.normal(data_key, shape, dtype),)
def get_layer_name(self):
- return 'layer_norm'
+ return "layer_norm"
def generate_praxis_p_and_flax_cls(self, dtype, attrs):
layernorm_type = attrs[LayerNormAttr.LN_TYPE]
@@ -209,63 +203,59 @@ class TestLayerNorm(TestLayer):
bias_init = WeightInit.Constant(0.0)
transpose_batch_sequence = False
- praxis_p = pax_fiddle.Config(LayerNorm,
- name='layer_norm',
- dtype=dtype,
- layernorm_type=layernorm_type,
- zero_centered_gamma=zero_centered_gamma,
- scale_init=scale_init,
- bias_init=bias_init,
- transpose_batch_sequence=transpose_batch_sequence)
- flax_cls = partial(flax_LayerNorm,
- layernorm_type=layernorm_type,
- zero_centered_gamma=zero_centered_gamma,
- scale_init=scale_init,
- bias_init=TransformerEngineBaseLayer.generate_params_init(
- "ln_bias", bias_init),
- dtype=dtype,
- transpose_batch_sequence=transpose_batch_sequence)
+ praxis_p = pax_fiddle.Config(
+ LayerNorm,
+ name="layer_norm",
+ dtype=dtype,
+ layernorm_type=layernorm_type,
+ zero_centered_gamma=zero_centered_gamma,
+ scale_init=scale_init,
+ bias_init=bias_init,
+ transpose_batch_sequence=transpose_batch_sequence,
+ )
+ flax_cls = partial(
+ flax_LayerNorm,
+ layernorm_type=layernorm_type,
+ zero_centered_gamma=zero_centered_gamma,
+ scale_init=scale_init,
+ bias_init=TransformerEngineBaseLayer.generate_params_init("ln_bias", bias_init),
+ dtype=dtype,
+ transpose_batch_sequence=transpose_batch_sequence,
+ )
return praxis_p, flax_cls
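`partial` here pre-binds the module's constructor arguments so the runner can later instantiate it with a uniform zero-argument call, mirroring `pax_fiddle.Config` on the praxis side. The same idiom with a stock Flax module standing in for the transformer_engine one:

```python
# Illustration of the functools.partial idiom used for flax_cls.
from functools import partial
import flax.linen as nn

flax_cls = partial(nn.Dense, features=16, use_bias=False)
layer = flax_cls()  # equivalent to nn.Dense(features=16, use_bias=False)
```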
- @pytest.mark.parametrize('data_shape', DATA_SHAPE)
- @pytest.mark.parametrize('dtype', DTYPE)
- @pytest.mark.parametrize('attrs', LayerNormAttr.ATTRS)
+ @pytest.mark.parametrize("data_shape", DATA_SHAPE)
+ @pytest.mark.parametrize("dtype", DTYPE)
+ @pytest.mark.parametrize("attrs", LayerNormAttr.ATTRS)
def test_forward_backward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol)
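Stacked `pytest.mark.parametrize` decorators, as used throughout these testers, collect the Cartesian product of their arguments; for example:

```python
# Illustrative only: stacked parametrize decorators multiply out.
import pytest

@pytest.mark.parametrize("dtype_name", ["float32", "bfloat16"])
@pytest.mark.parametrize("shape", [(32, 128, 512), (32, 512, 512)])
def test_product(dtype_name, shape):
    # Collected 2 * 2 = 4 times, once per (shape, dtype_name) combination.
    assert dtype_name in ("float32", "bfloat16")
```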
class FusedSoftmaxAttr:
- SCALE_FACTOR = 'scale_factor'
- ST_TYPE = 'softmax_type'
- ATTRS = [{
- SCALE_FACTOR: 0.0,
- ST_TYPE: SoftmaxType.SCALED
- }, {
- SCALE_FACTOR: 0.0,
- ST_TYPE: SoftmaxType.SCALED_MASKED
- }, {
- SCALE_FACTOR: 0.0,
- ST_TYPE: SoftmaxType.SCALED_UPPER_TRIANG_MASKED
- }]
+ SCALE_FACTOR = "scale_factor"
+ ST_TYPE = "softmax_type"
+ ATTRS = [
+ {SCALE_FACTOR: 0.0, ST_TYPE: SoftmaxType.SCALED},
+ {SCALE_FACTOR: 0.0, ST_TYPE: SoftmaxType.SCALED_MASKED},
+ {SCALE_FACTOR: 0.0, ST_TYPE: SoftmaxType.SCALED_UPPER_TRIANG_MASKED},
+ ]
class TestFusedSoftmax(TestLayer):
def input_getter(self, shape, dtype):
data_key = jax.random.PRNGKey(seed=1234)
- return jax.random.normal(data_key, shape, dtype), \
- jnp.ones(shape, dtype=jnp.uint8) # Masks
+ return jax.random.normal(data_key, shape, dtype), jnp.ones(shape, dtype=jnp.uint8) # Masks
def generate_praxis_p_and_flax_cls(self, dtype, attrs):
scale_factor = attrs[FusedSoftmaxAttr.SCALE_FACTOR]
softmax_type = attrs[FusedSoftmaxAttr.ST_TYPE]
- praxis_p = pax_fiddle.Config(FusedSoftmax,
- name='fused_softmax',
- scale_factor=scale_factor,
- softmax_type=softmax_type)
+ praxis_p = pax_fiddle.Config(
+ FusedSoftmax, name="fused_softmax", scale_factor=scale_factor, softmax_type=softmax_type
+ )
flax_cls = partial(Softmax, scale_factor=scale_factor, softmax_type=softmax_type)
return praxis_p, flax_cls
@@ -276,34 +266,28 @@ class TestFusedSoftmax(TestLayer):
def sync_wgrads(self, praxis_wgrads, flax_wgrads):
return praxis_wgrads, flax_wgrads
- @pytest.mark.parametrize('data_shape', [(32, 1, 128, 128), (32, 1, 512, 128)])
- @pytest.mark.parametrize('dtype', DTYPE)
- @pytest.mark.parametrize('attrs', FusedSoftmaxAttr.ATTRS)
+ @pytest.mark.parametrize("data_shape", [(32, 1, 128, 128), (32, 1, 512, 128)])
+ @pytest.mark.parametrize("dtype", DTYPE)
+ @pytest.mark.parametrize("attrs", FusedSoftmaxAttr.ATTRS)
def test_forward_backward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
- if (attrs[FusedSoftmaxAttr.ST_TYPE] == SoftmaxType.SCALED_UPPER_TRIANG_MASKED) and \
- (data_shape[-2] != data_shape[-1]):
- pass # Skip, due to not support
+ if (attrs[FusedSoftmaxAttr.ST_TYPE] == SoftmaxType.SCALED_UPPER_TRIANG_MASKED) and (
+ data_shape[-2] != data_shape[-1]
+ ):
+            pass  # Skip: not supported for non-square shapes
else:
praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol)
class LinearAttr:
- FEATURE = 'features'
- USE_BIAS = 'use_bias'
- ATTRS = [{
- FEATURE: 512,
- USE_BIAS: False
- }, {
- FEATURE: 512,
- USE_BIAS: True
- }, {
- FEATURE: 1024,
- USE_BIAS: False
- }, {
- FEATURE: 1024,
- USE_BIAS: True
- }]
+ FEATURE = "features"
+ USE_BIAS = "use_bias"
+ ATTRS = [
+ {FEATURE: 512, USE_BIAS: False},
+ {FEATURE: 512, USE_BIAS: True},
+ {FEATURE: 1024, USE_BIAS: False},
+ {FEATURE: 1024, USE_BIAS: True},
+ ]
class TestLinear(TestLayer):
@@ -313,7 +297,7 @@ class TestLinear(TestLayer):
return (jax.random.normal(data_key, shape, dtype),)
def get_layer_name(self):
- return 'linear'
+ return "linear"
def generate_praxis_p_and_flax_cls(self, dtype, attrs):
out_features = attrs[LinearAttr.FEATURE]
@@ -323,15 +307,17 @@ class TestLinear(TestLayer):
axis = -1
transpose_batch_sequence = False
- praxis_p = pax_fiddle.Config(Linear,
- name='linear',
- dtype=dtype,
- out_features=out_features,
- params_init=kernel_init,
- use_bias=use_bias,
- bias_init=bias_init,
- axis=axis,
- transpose_batch_sequence=transpose_batch_sequence)
+ praxis_p = pax_fiddle.Config(
+ Linear,
+ name="linear",
+ dtype=dtype,
+ out_features=out_features,
+ params_init=kernel_init,
+ use_bias=use_bias,
+ bias_init=bias_init,
+ axis=axis,
+ transpose_batch_sequence=transpose_batch_sequence,
+ )
flax_cls = partial(
DenseGeneral,
features=out_features,
@@ -340,29 +326,26 @@ class TestLinear(TestLayer):
bias_init=TransformerEngineBaseLayer.generate_params_init("bias", bias_init),
axis=axis,
dtype=dtype,
- transpose_batch_sequence=transpose_batch_sequence)
+ transpose_batch_sequence=transpose_batch_sequence,
+ )
return praxis_p, flax_cls
- @pytest.mark.parametrize('data_shape', DATA_SHAPE)
- @pytest.mark.parametrize('dtype', DTYPE)
- @pytest.mark.parametrize('attrs', LinearAttr.ATTRS)
+ @pytest.mark.parametrize("data_shape", DATA_SHAPE)
+ @pytest.mark.parametrize("dtype", DTYPE)
+ @pytest.mark.parametrize("attrs", LinearAttr.ATTRS)
def test_forward_backward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol)
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
- @pytest.mark.parametrize('data_shape', DATA_SHAPE)
- @pytest.mark.parametrize('dtype', DTYPE)
- @pytest.mark.parametrize('attrs', LinearAttr.ATTRS)
- @pytest.mark.parametrize('fp8_format', FP8_FORMATS)
- def test_forward_backward_fp8(self,
- data_shape,
- dtype,
- attrs,
- fp8_format,
- rtol=1e-05,
- atol=1e-08):
+ @pytest.mark.parametrize("data_shape", DATA_SHAPE)
+ @pytest.mark.parametrize("dtype", DTYPE)
+ @pytest.mark.parametrize("attrs", LinearAttr.ATTRS)
+ @pytest.mark.parametrize("fp8_format", FP8_FORMATS)
+ def test_forward_backward_fp8(
+ self, data_shape, dtype, attrs, fp8_format, rtol=1e-05, atol=1e-08
+ ):
ds = DelayedScaling(fp8_format=fp8_format)
with fp8_autocast(enabled=True, fp8_recipe=ds):
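The FP8 tests wrap the runner in the `fp8_autocast` context with a `DelayedScaling` recipe; a minimal sketch of that pattern (the import paths below are my assumption and may differ from this file's):

```python
# Sketch of the DelayedScaling + fp8_autocast pattern (assumed import paths).
from transformer_engine.common.recipe import DelayedScaling, Format
from transformer_engine.jax import fp8_autocast

ds = DelayedScaling(fp8_format=Format.HYBRID)
with fp8_autocast(enabled=True, fp8_recipe=ds):
    pass  # forward/backward executed here run eligible GEMMs in FP8
```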
@@ -371,54 +354,20 @@ class TestLinear(TestLayer):
class LayerNormLinearAttr:
- FEATURE = 'features'
- USE_BIAS = 'use_bias'
- ENABLE_LN = 'enable_layernorm'
- LN_TYPE = 'layernorm_type'
- ZERO_CEN = 'zero_centered_gamma'
- ATTRS = [{
- FEATURE: 512,
- USE_BIAS: True,
- ENABLE_LN: True,
- LN_TYPE: 'layernorm',
- ZERO_CEN: False
- }, {
- FEATURE: 512,
- USE_BIAS: True,
- ENABLE_LN: True,
- LN_TYPE: 'layernorm',
- ZERO_CEN: False
- }, {
- FEATURE: 512,
- USE_BIAS: True,
- ENABLE_LN: True,
- LN_TYPE: 'layernorm',
- ZERO_CEN: True
- }, {
- FEATURE: 512,
- USE_BIAS: True,
- ENABLE_LN: True,
- LN_TYPE: 'layernorm',
- ZERO_CEN: True
- }, {
- FEATURE: 512,
- USE_BIAS: True,
- ENABLE_LN: True,
- LN_TYPE: 'rmsnorm',
- ZERO_CEN: False
- }, {
- FEATURE: 512,
- USE_BIAS: True,
- ENABLE_LN: True,
- LN_TYPE: 'rmsnorm',
- ZERO_CEN: False
- }, {
- FEATURE: 512,
- USE_BIAS: True,
- ENABLE_LN: False,
- LN_TYPE: 'layernorm',
- ZERO_CEN: False
- }]
+ FEATURE = "features"
+ USE_BIAS = "use_bias"
+ ENABLE_LN = "enable_layernorm"
+ LN_TYPE = "layernorm_type"
+ ZERO_CEN = "zero_centered_gamma"
+ ATTRS = [
+ {FEATURE: 512, USE_BIAS: True, ENABLE_LN: True, LN_TYPE: "layernorm", ZERO_CEN: False},
+ {FEATURE: 512, USE_BIAS: True, ENABLE_LN: True, LN_TYPE: "layernorm", ZERO_CEN: False},
+ {FEATURE: 512, USE_BIAS: True, ENABLE_LN: True, LN_TYPE: "layernorm", ZERO_CEN: True},
+ {FEATURE: 512, USE_BIAS: True, ENABLE_LN: True, LN_TYPE: "layernorm", ZERO_CEN: True},
+ {FEATURE: 512, USE_BIAS: True, ENABLE_LN: True, LN_TYPE: "rmsnorm", ZERO_CEN: False},
+ {FEATURE: 512, USE_BIAS: True, ENABLE_LN: True, LN_TYPE: "rmsnorm", ZERO_CEN: False},
+ {FEATURE: 512, USE_BIAS: True, ENABLE_LN: False, LN_TYPE: "layernorm", ZERO_CEN: False},
+ ]
class TestLayerNormLinear(TestLayer):
@@ -428,7 +377,7 @@ class TestLayerNormLinear(TestLayer):
return (jax.random.normal(data_key, shape, dtype),)
def get_layer_name(self):
- return 'ln_linear'
+ return "ln_linear"
def generate_praxis_p_and_flax_cls(self, dtype, attrs):
out_features = attrs[LayerNormLinearAttr.FEATURE]
@@ -441,18 +390,20 @@ class TestLayerNormLinear(TestLayer):
axis = -1
transpose_batch_sequence = False
- praxis_p = pax_fiddle.Config(LayerNormLinear,
- name='ln_linear',
- dtype=dtype,
- out_features=out_features,
- enable_layernorm=enable_layernorm,
- layernorm_type=layernorm_type,
- zero_centered_gamma=zero_centered_gamma,
- params_init=kernel_init,
- use_bias=use_bias,
- bias_init=bias_init,
- axis=axis,
- transpose_batch_sequence=transpose_batch_sequence)
+ praxis_p = pax_fiddle.Config(
+ LayerNormLinear,
+ name="ln_linear",
+ dtype=dtype,
+ out_features=out_features,
+ enable_layernorm=enable_layernorm,
+ layernorm_type=layernorm_type,
+ zero_centered_gamma=zero_centered_gamma,
+ params_init=kernel_init,
+ use_bias=use_bias,
+ bias_init=bias_init,
+ axis=axis,
+ transpose_batch_sequence=transpose_batch_sequence,
+ )
flax_cls = partial(
LayerNormDenseGeneral,
features=out_features,
@@ -464,29 +415,26 @@ class TestLayerNormLinear(TestLayer):
bias_init=TransformerEngineBaseLayer.generate_params_init("bias", bias_init),
axis=axis,
dtype=dtype,
- transpose_batch_sequence=transpose_batch_sequence)
+ transpose_batch_sequence=transpose_batch_sequence,
+ )
return praxis_p, flax_cls
- @pytest.mark.parametrize('data_shape', DATA_SHAPE)
- @pytest.mark.parametrize('dtype', DTYPE)
- @pytest.mark.parametrize('attrs', LayerNormLinearAttr.ATTRS)
+ @pytest.mark.parametrize("data_shape", DATA_SHAPE)
+ @pytest.mark.parametrize("dtype", DTYPE)
+ @pytest.mark.parametrize("attrs", LayerNormLinearAttr.ATTRS)
def test_forward_backward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol)
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
- @pytest.mark.parametrize('data_shape', DATA_SHAPE)
- @pytest.mark.parametrize('dtype', DTYPE)
- @pytest.mark.parametrize('attrs', LayerNormLinearAttr.ATTRS)
- @pytest.mark.parametrize('fp8_format', FP8_FORMATS)
- def test_forward_backward_fp8(self,
- data_shape,
- dtype,
- attrs,
- fp8_format,
- rtol=1e-05,
- atol=1e-08):
+ @pytest.mark.parametrize("data_shape", DATA_SHAPE)
+ @pytest.mark.parametrize("dtype", DTYPE)
+ @pytest.mark.parametrize("attrs", LayerNormLinearAttr.ATTRS)
+ @pytest.mark.parametrize("fp8_format", FP8_FORMATS)
+ def test_forward_backward_fp8(
+ self, data_shape, dtype, attrs, fp8_format, rtol=1e-05, atol=1e-08
+ ):
ds = DelayedScaling(fp8_format=fp8_format)
with fp8_autocast(enabled=True, fp8_recipe=ds):
@@ -495,62 +443,70 @@ class TestLayerNormLinear(TestLayer):
class LayerNormMLPAttr:
- INTERMEDIATE_DIM = 'intermediate_dim'
- USE_BIAS = 'use_bias'
- ENABLE_LN = 'enable_layernorm'
- LN_TYPE = 'layernorm_type'
- ZERO_CEN = 'zero_centered_gamma'
- ACTIVATION = 'activations'
- ATTRS = [{
- INTERMEDIATE_DIM: 2048,
- USE_BIAS: True,
- ENABLE_LN: True,
- LN_TYPE: 'layernorm',
- ZERO_CEN: False,
- ACTIVATION: ('relu',)
- }, {
- INTERMEDIATE_DIM: 2048,
- USE_BIAS: True,
- ENABLE_LN: True,
- LN_TYPE: 'layernorm',
- ZERO_CEN: True,
- ACTIVATION: ('relu',)
- }, {
- INTERMEDIATE_DIM: 2048,
- USE_BIAS: True,
- ENABLE_LN: True,
- LN_TYPE: 'rmsnorm',
- ZERO_CEN: False,
- ACTIVATION: ('relu',)
- }, {
- INTERMEDIATE_DIM: 2048,
- USE_BIAS: True,
- ENABLE_LN: True,
- LN_TYPE: 'rmsnorm',
- ZERO_CEN: False,
- ACTIVATION: ('gelu', 'linear')
- }, {
- INTERMEDIATE_DIM: 2048,
- USE_BIAS: False,
- ENABLE_LN: True,
- LN_TYPE: 'rmsnorm',
- ZERO_CEN: False,
- ACTIVATION: ('gelu', 'linear')
- }, {
- INTERMEDIATE_DIM: 2048,
- USE_BIAS: True,
- ENABLE_LN: True,
- LN_TYPE: 'rmsnorm',
- ZERO_CEN: False,
- ACTIVATION: ('silu', 'linear')
- }, {
- INTERMEDIATE_DIM: 2048,
- USE_BIAS: False,
- ENABLE_LN: True,
- LN_TYPE: 'rmsnorm',
- ZERO_CEN: False,
- ACTIVATION: ('silu', 'linear')
- }]
+ INTERMEDIATE_DIM = "intermediate_dim"
+ USE_BIAS = "use_bias"
+ ENABLE_LN = "enable_layernorm"
+ LN_TYPE = "layernorm_type"
+ ZERO_CEN = "zero_centered_gamma"
+ ACTIVATION = "activations"
+ ATTRS = [
+ {
+ INTERMEDIATE_DIM: 2048,
+ USE_BIAS: True,
+ ENABLE_LN: True,
+ LN_TYPE: "layernorm",
+ ZERO_CEN: False,
+ ACTIVATION: ("relu",),
+ },
+ {
+ INTERMEDIATE_DIM: 2048,
+ USE_BIAS: True,
+ ENABLE_LN: True,
+ LN_TYPE: "layernorm",
+ ZERO_CEN: True,
+ ACTIVATION: ("relu",),
+ },
+ {
+ INTERMEDIATE_DIM: 2048,
+ USE_BIAS: True,
+ ENABLE_LN: True,
+ LN_TYPE: "rmsnorm",
+ ZERO_CEN: False,
+ ACTIVATION: ("relu",),
+ },
+ {
+ INTERMEDIATE_DIM: 2048,
+ USE_BIAS: True,
+ ENABLE_LN: True,
+ LN_TYPE: "rmsnorm",
+ ZERO_CEN: False,
+ ACTIVATION: ("gelu", "linear"),
+ },
+ {
+ INTERMEDIATE_DIM: 2048,
+ USE_BIAS: False,
+ ENABLE_LN: True,
+ LN_TYPE: "rmsnorm",
+ ZERO_CEN: False,
+ ACTIVATION: ("gelu", "linear"),
+ },
+ {
+ INTERMEDIATE_DIM: 2048,
+ USE_BIAS: True,
+ ENABLE_LN: True,
+ LN_TYPE: "rmsnorm",
+ ZERO_CEN: False,
+ ACTIVATION: ("silu", "linear"),
+ },
+ {
+ INTERMEDIATE_DIM: 2048,
+ USE_BIAS: False,
+ ENABLE_LN: True,
+ LN_TYPE: "rmsnorm",
+ ZERO_CEN: False,
+ ACTIVATION: ("silu", "linear"),
+ },
+ ]
class TestLayerNormMLP(TestLayer):
@@ -560,7 +516,7 @@ class TestLayerNormMLP(TestLayer):
return (jax.random.normal(data_key, shape, dtype),)
def get_layer_name(self):
- return 'ln_mlp'
+ return "ln_mlp"
def generate_praxis_p_and_flax_cls(self, dtype, attrs):
intermediate_dim = attrs[LayerNormMLPAttr.INTERMEDIATE_DIM]
@@ -574,20 +530,22 @@ class TestLayerNormMLP(TestLayer):
axis = -1
transpose_batch_sequence = False
- praxis_p = pax_fiddle.Config(LayerNormMLP,
- name='ln_mlp',
- dtype=dtype,
- intermediate_dim=intermediate_dim,
- enable_layernorm=enable_layernorm,
- layernorm_type=layernorm_type,
- zero_centered_gamma=zero_centered_gamma,
- params_init=kernel_init,
- use_bias=use_bias,
- bias_init=bias_init,
- activations=activations,
- intermediate_dropout_rate=0.0,
- axis=axis,
- transpose_batch_sequence=transpose_batch_sequence)
+ praxis_p = pax_fiddle.Config(
+ LayerNormMLP,
+ name="ln_mlp",
+ dtype=dtype,
+ intermediate_dim=intermediate_dim,
+ enable_layernorm=enable_layernorm,
+ layernorm_type=layernorm_type,
+ zero_centered_gamma=zero_centered_gamma,
+ params_init=kernel_init,
+ use_bias=use_bias,
+ bias_init=bias_init,
+ activations=activations,
+ intermediate_dropout_rate=0.0,
+ axis=axis,
+ transpose_batch_sequence=transpose_batch_sequence,
+ )
flax_cls = partial(
flax_LayerNormMLP,
intermediate_dim=intermediate_dim,
@@ -601,29 +559,26 @@ class TestLayerNormMLP(TestLayer):
intermediate_dropout_rate=0.0,
axis=axis,
dtype=dtype,
- transpose_batch_sequence=transpose_batch_sequence)
+ transpose_batch_sequence=transpose_batch_sequence,
+ )
return praxis_p, flax_cls
- @pytest.mark.parametrize('data_shape', DATA_SHAPE)
- @pytest.mark.parametrize('dtype', DTYPE)
- @pytest.mark.parametrize('attrs', LayerNormMLPAttr.ATTRS)
+ @pytest.mark.parametrize("data_shape", DATA_SHAPE)
+ @pytest.mark.parametrize("dtype", DTYPE)
+ @pytest.mark.parametrize("attrs", LayerNormMLPAttr.ATTRS)
def test_forward_backward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol)
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
- @pytest.mark.parametrize('data_shape', DATA_SHAPE)
- @pytest.mark.parametrize('dtype', DTYPE)
- @pytest.mark.parametrize('attrs', LayerNormMLPAttr.ATTRS)
- @pytest.mark.parametrize('fp8_format', FP8_FORMATS)
- def test_forward_backward_fp8(self,
- data_shape,
- dtype,
- attrs,
- fp8_format,
- rtol=1e-05,
- atol=1e-08):
+ @pytest.mark.parametrize("data_shape", DATA_SHAPE)
+ @pytest.mark.parametrize("dtype", DTYPE)
+ @pytest.mark.parametrize("attrs", LayerNormMLPAttr.ATTRS)
+ @pytest.mark.parametrize("fp8_format", FP8_FORMATS)
+ def test_forward_backward_fp8(
+ self, data_shape, dtype, attrs, fp8_format, rtol=1e-05, atol=1e-08
+ ):
ds = DelayedScaling(fp8_format=fp8_format)
with fp8_autocast(enabled=True, fp8_recipe=ds):
@@ -634,35 +589,40 @@ class TestLayerNormMLP(TestLayer):
class TestRelativePositionBias(TestLayer):
def get_layer_name(self):
- return 'relative_position_bias'
+ return "relative_position_bias"
def generate_praxis_p_and_flax_cls(self, dtype, attrs):
num_buckets = 32
max_distance = 128
num_attention_heads = 64
- rb_stddev = (num_attention_heads * num_buckets)**-0.5
+ rb_stddev = (num_attention_heads * num_buckets) ** -0.5
embedding_init = WeightInit.Gaussian(rb_stddev)
- praxis_p = pax_fiddle.Config(RelativePositionBiases,
- name='relative_position_bias',
- dtype=dtype,
- num_buckets=num_buckets,
- max_distance=max_distance,
- num_attention_heads=num_attention_heads,
- embedding_init=embedding_init)
- flax_cls = partial(flax_RelativePositionBiases,
- num_buckets=num_buckets,
- max_distance=max_distance,
- num_attention_heads=num_attention_heads,
- embedding_init=TransformerEngineBaseLayer.generate_params_init(
- "rel_embedding", embedding_init),
- dtype=dtype)
+ praxis_p = pax_fiddle.Config(
+ RelativePositionBiases,
+ name="relative_position_bias",
+ dtype=dtype,
+ num_buckets=num_buckets,
+ max_distance=max_distance,
+ num_attention_heads=num_attention_heads,
+ embedding_init=embedding_init,
+ )
+ flax_cls = partial(
+ flax_RelativePositionBiases,
+ num_buckets=num_buckets,
+ max_distance=max_distance,
+ num_attention_heads=num_attention_heads,
+ embedding_init=TransformerEngineBaseLayer.generate_params_init(
+ "rel_embedding", embedding_init
+ ),
+ dtype=dtype,
+ )
return praxis_p, flax_cls
- @pytest.mark.parametrize('data_shape', DATA_SHAPE)
- @pytest.mark.parametrize('dtype', DTYPE)
- @pytest.mark.parametrize('attrs', [{}])
+ @pytest.mark.parametrize("data_shape", DATA_SHAPE)
+ @pytest.mark.parametrize("dtype", DTYPE)
+ @pytest.mark.parametrize("attrs", [{}])
def test_forward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
@@ -678,53 +638,64 @@ class TestRelativePositionBias(TestLayer):
if "params_axes" in flax_variables:
flax_variables, _ = flax.core.pop(flax_variables, "params_axes")
if FP8Helper.is_fp8_enabled():
- flax_variables, _ = flax.core.pop(flax_variables,
- FP8Helper.FP8_COLLECTION_NAME + "_axes")
+ flax_variables, _ = flax.core.pop(
+ flax_variables, FP8Helper.FP8_COLLECTION_NAME + "_axes"
+ )
praxis_variables, flax_variables = self.sync_variables(praxis_variables, flax_variables)
- praxis_loss= \
- TestLayer.loss(praxis_variables, *test_input, module=praxis_layer, mean_out=False)
- flax_loss = \
- TestLayer.loss(flax_variables, *test_input, module=flax_layer, mean_out=False)
+ praxis_loss = TestLayer.loss(
+ praxis_variables, *test_input, module=praxis_layer, mean_out=False
+ )
+ flax_loss = TestLayer.loss(
+ flax_variables, *test_input, module=flax_layer, mean_out=False
+ )
assert_allclose(praxis_loss, flax_loss, rtol=rtol, atol=atol)
class DotProductAttnAttr:
- ATTN_MASK_TYPE = 'attn_mask_type'
- NUM_GQA_GROUPS = 'num_gqa_groups'
- TRANSPOSE_BS = 'transpose_batch_sequence'
- SCALE_FACTOR = 'scale_factor'
- ATTRS = [{
- ATTN_MASK_TYPE: 'padding',
- TRANSPOSE_BS: True,
- SCALE_FACTOR: 0.125,
- }, {
- ATTN_MASK_TYPE: 'padding_causal',
- TRANSPOSE_BS: True,
- SCALE_FACTOR: 0.125,
- }, {
- ATTN_MASK_TYPE: 'causal',
- TRANSPOSE_BS: True,
- SCALE_FACTOR: 0.125,
- }, {
- ATTN_MASK_TYPE: 'padding',
- TRANSPOSE_BS: False,
- SCALE_FACTOR: 0.125,
- }, {
- ATTN_MASK_TYPE: 'padding_causal',
- TRANSPOSE_BS: False,
- SCALE_FACTOR: 2.,
- }, {
- ATTN_MASK_TYPE: 'causal',
- TRANSPOSE_BS: False,
- SCALE_FACTOR: 1.,
- }, {
- ATTN_MASK_TYPE: 'no_mask',
- TRANSPOSE_BS: False,
- SCALE_FACTOR: 1.,
- }]
+ ATTN_MASK_TYPE = "attn_mask_type"
+ NUM_GQA_GROUPS = "num_gqa_groups"
+ TRANSPOSE_BS = "transpose_batch_sequence"
+ SCALE_FACTOR = "scale_factor"
+ ATTRS = [
+ {
+ ATTN_MASK_TYPE: "padding",
+ TRANSPOSE_BS: True,
+ SCALE_FACTOR: 0.125,
+ },
+ {
+ ATTN_MASK_TYPE: "padding_causal",
+ TRANSPOSE_BS: True,
+ SCALE_FACTOR: 0.125,
+ },
+ {
+ ATTN_MASK_TYPE: "causal",
+ TRANSPOSE_BS: True,
+ SCALE_FACTOR: 0.125,
+ },
+ {
+ ATTN_MASK_TYPE: "padding",
+ TRANSPOSE_BS: False,
+ SCALE_FACTOR: 0.125,
+ },
+ {
+ ATTN_MASK_TYPE: "padding_causal",
+ TRANSPOSE_BS: False,
+ SCALE_FACTOR: 2.0,
+ },
+ {
+ ATTN_MASK_TYPE: "causal",
+ TRANSPOSE_BS: False,
+ SCALE_FACTOR: 1.0,
+ },
+ {
+ ATTN_MASK_TYPE: "no_mask",
+ TRANSPOSE_BS: False,
+ SCALE_FACTOR: 1.0,
+ },
+ ]
class TestDotProductAttn(TestLayer):
@@ -737,11 +708,12 @@ class TestDotProductAttn(TestLayer):
shape = (shape[1], shape[0]) + shape[2:]
mask = jnp.zeros((b, 1, s, s), dtype=jnp.uint8)
return [
- *map(partial(jax.random.normal, shape=shape, dtype=dtype), [q_key, k_key, v_key]), mask
+ *map(partial(jax.random.normal, shape=shape, dtype=dtype), [q_key, k_key, v_key]),
+ mask,
]
def get_layer_name(self):
- return 'dot_product_attn'
+ return "dot_product_attn"
def generate_praxis_p_and_flax_cls(self, dtype, attrs):
head_dim = 64
@@ -750,27 +722,31 @@ class TestDotProductAttn(TestLayer):
attn_mask_type = attrs[DotProductAttnAttr.ATTN_MASK_TYPE]
transpose_batch_sequence = attrs[DotProductAttnAttr.TRANSPOSE_BS]
- praxis_p = pax_fiddle.Config(DotProductAttention,
- name='mha',
- dtype=dtype,
- head_dim=head_dim,
- num_attention_heads=num_attention_heads,
- num_gqa_groups=num_gqa_groups,
- attn_mask_type=attn_mask_type,
- transpose_batch_sequence=transpose_batch_sequence)
- flax_cls = partial(flax_DotProductAttention,
- dtype=dtype,
- head_dim=head_dim,
- num_attention_heads=num_attention_heads,
- num_gqa_groups=num_gqa_groups,
- attn_mask_type=attn_mask_type,
- transpose_batch_sequence=transpose_batch_sequence)
+ praxis_p = pax_fiddle.Config(
+ DotProductAttention,
+ name="mha",
+ dtype=dtype,
+ head_dim=head_dim,
+ num_attention_heads=num_attention_heads,
+ num_gqa_groups=num_gqa_groups,
+ attn_mask_type=attn_mask_type,
+ transpose_batch_sequence=transpose_batch_sequence,
+ )
+ flax_cls = partial(
+ flax_DotProductAttention,
+ dtype=dtype,
+ head_dim=head_dim,
+ num_attention_heads=num_attention_heads,
+ num_gqa_groups=num_gqa_groups,
+ attn_mask_type=attn_mask_type,
+ transpose_batch_sequence=transpose_batch_sequence,
+ )
return praxis_p, flax_cls
- @pytest.mark.parametrize('data_shape', [(32, 128, 16, 64)])
- @pytest.mark.parametrize('dtype', DTYPE)
- @pytest.mark.parametrize('attrs', DotProductAttnAttr.ATTRS)
+ @pytest.mark.parametrize("data_shape", [(32, 128, 16, 64)])
+ @pytest.mark.parametrize("dtype", DTYPE)
+ @pytest.mark.parametrize("attrs", DotProductAttnAttr.ATTRS)
def test_forward_backward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
self.attrs = attrs
praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
@@ -778,113 +754,125 @@ class TestDotProductAttn(TestLayer):
class MultiHeadAttnAttr:
- USE_BIAS = 'use_bias'
- LN_TYPE = 'layernorm_type'
- ATTN_MASK_TYPE = 'attn_mask_type'
- ZERO_CEN = 'zero_centered_gamma'
- NUM_ATTN_HEADS = 'num_attention_heads'
- NUM_GQA_GROUPS = 'num_gqa_groups'
- TRANSPOSE_BS = 'transpose_batch_sequence'
- ENABLE_ROPE = 'enable_rotary_pos_emb'
- ROPE_GROUP_METHOD = 'rotary_pos_emb_group_method'
- LORA_SCOPE = 'low_rank_adaptation_scope'
- ATTRS = [{
- USE_BIAS: True,
- LN_TYPE: 'layernorm',
- ZERO_CEN: False,
- ENABLE_ROPE: False,
- ROPE_GROUP_METHOD: 'consecutive',
- ATTN_MASK_TYPE: 'padding',
- TRANSPOSE_BS: True,
- }, {
- USE_BIAS: True,
- LN_TYPE: 'layernorm',
- ZERO_CEN: True,
- ENABLE_ROPE: False,
- ROPE_GROUP_METHOD: 'consecutive',
- ATTN_MASK_TYPE: 'padding',
- TRANSPOSE_BS: False,
- }, {
- USE_BIAS: True,
- LN_TYPE: 'rmsnorm',
- ZERO_CEN: False,
- ENABLE_ROPE: False,
- ROPE_GROUP_METHOD: 'consecutive',
- ATTN_MASK_TYPE: 'padding',
- TRANSPOSE_BS: True,
- }, {
- USE_BIAS: True,
- LN_TYPE: 'layernorm',
- ZERO_CEN: False,
- ENABLE_ROPE: False,
- ROPE_GROUP_METHOD: 'consecutive',
- ATTN_MASK_TYPE: 'causal',
- TRANSPOSE_BS: False,
- }, {
- USE_BIAS: True,
- LN_TYPE: 'layernorm',
- ZERO_CEN: True,
- ENABLE_ROPE: False,
- ROPE_GROUP_METHOD: 'consecutive',
- ATTN_MASK_TYPE: 'causal',
- TRANSPOSE_BS: True,
- }, {
- USE_BIAS: True,
- LN_TYPE: 'rmsnorm',
- ZERO_CEN: False,
- ENABLE_ROPE: False,
- ROPE_GROUP_METHOD: 'consecutive',
- ATTN_MASK_TYPE: 'causal',
- TRANSPOSE_BS: False,
- }, {
- USE_BIAS: True,
- LN_TYPE: 'rmsnorm',
- ZERO_CEN: False,
- ENABLE_ROPE: False,
- ROPE_GROUP_METHOD: 'consecutive',
- NUM_ATTN_HEADS: 8,
- NUM_GQA_GROUPS: 4,
- ATTN_MASK_TYPE: 'causal',
- TRANSPOSE_BS: True,
- }, {
- USE_BIAS: True,
- LN_TYPE: 'rmsnorm',
- ZERO_CEN: False,
- ENABLE_ROPE: True,
- ROPE_GROUP_METHOD: 'consecutive',
- NUM_ATTN_HEADS: 8,
- NUM_GQA_GROUPS: 4,
- ATTN_MASK_TYPE: 'causal',
- TRANSPOSE_BS: False,
- }, {
- USE_BIAS: True,
- LN_TYPE: 'rmsnorm',
- ZERO_CEN: False,
- ENABLE_ROPE: True,
- ROPE_GROUP_METHOD: 'alternate',
- NUM_ATTN_HEADS: 8,
- NUM_GQA_GROUPS: 4,
- ATTN_MASK_TYPE: 'causal',
- TRANSPOSE_BS: True,
- }, {
- USE_BIAS: True,
- LN_TYPE: 'layernorm',
- ZERO_CEN: False,
- ENABLE_ROPE: False,
- ROPE_GROUP_METHOD: 'consecutive',
- ATTN_MASK_TYPE: 'padding',
- LORA_SCOPE: 'all',
- TRANSPOSE_BS: False,
- }, {
- USE_BIAS: True,
- LN_TYPE: 'layernorm',
- ZERO_CEN: False,
- ENABLE_ROPE: False,
- ROPE_GROUP_METHOD: 'consecutive',
- ATTN_MASK_TYPE: 'causal',
- LORA_SCOPE: 'all',
- TRANSPOSE_BS: True,
- }]
+ USE_BIAS = "use_bias"
+ LN_TYPE = "layernorm_type"
+ ATTN_MASK_TYPE = "attn_mask_type"
+ ZERO_CEN = "zero_centered_gamma"
+ NUM_ATTN_HEADS = "num_attention_heads"
+ NUM_GQA_GROUPS = "num_gqa_groups"
+ TRANSPOSE_BS = "transpose_batch_sequence"
+ ENABLE_ROPE = "enable_rotary_pos_emb"
+ ROPE_GROUP_METHOD = "rotary_pos_emb_group_method"
+ LORA_SCOPE = "low_rank_adaptation_scope"
+ ATTRS = [
+ {
+ USE_BIAS: True,
+ LN_TYPE: "layernorm",
+ ZERO_CEN: False,
+ ENABLE_ROPE: False,
+ ROPE_GROUP_METHOD: "consecutive",
+ ATTN_MASK_TYPE: "padding",
+ TRANSPOSE_BS: True,
+ },
+ {
+ USE_BIAS: True,
+ LN_TYPE: "layernorm",
+ ZERO_CEN: True,
+ ENABLE_ROPE: False,
+ ROPE_GROUP_METHOD: "consecutive",
+ ATTN_MASK_TYPE: "padding",
+ TRANSPOSE_BS: False,
+ },
+ {
+ USE_BIAS: True,
+ LN_TYPE: "rmsnorm",
+ ZERO_CEN: False,
+ ENABLE_ROPE: False,
+ ROPE_GROUP_METHOD: "consecutive",
+ ATTN_MASK_TYPE: "padding",
+ TRANSPOSE_BS: True,
+ },
+ {
+ USE_BIAS: True,
+ LN_TYPE: "layernorm",
+ ZERO_CEN: False,
+ ENABLE_ROPE: False,
+ ROPE_GROUP_METHOD: "consecutive",
+ ATTN_MASK_TYPE: "causal",
+ TRANSPOSE_BS: False,
+ },
+ {
+ USE_BIAS: True,
+ LN_TYPE: "layernorm",
+ ZERO_CEN: True,
+ ENABLE_ROPE: False,
+ ROPE_GROUP_METHOD: "consecutive",
+ ATTN_MASK_TYPE: "causal",
+ TRANSPOSE_BS: True,
+ },
+ {
+ USE_BIAS: True,
+ LN_TYPE: "rmsnorm",
+ ZERO_CEN: False,
+ ENABLE_ROPE: False,
+ ROPE_GROUP_METHOD: "consecutive",
+ ATTN_MASK_TYPE: "causal",
+ TRANSPOSE_BS: False,
+ },
+ {
+ USE_BIAS: True,
+ LN_TYPE: "rmsnorm",
+ ZERO_CEN: False,
+ ENABLE_ROPE: False,
+ ROPE_GROUP_METHOD: "consecutive",
+ NUM_ATTN_HEADS: 8,
+ NUM_GQA_GROUPS: 4,
+ ATTN_MASK_TYPE: "causal",
+ TRANSPOSE_BS: True,
+ },
+ {
+ USE_BIAS: True,
+ LN_TYPE: "rmsnorm",
+ ZERO_CEN: False,
+ ENABLE_ROPE: True,
+ ROPE_GROUP_METHOD: "consecutive",
+ NUM_ATTN_HEADS: 8,
+ NUM_GQA_GROUPS: 4,
+ ATTN_MASK_TYPE: "causal",
+ TRANSPOSE_BS: False,
+ },
+ {
+ USE_BIAS: True,
+ LN_TYPE: "rmsnorm",
+ ZERO_CEN: False,
+ ENABLE_ROPE: True,
+ ROPE_GROUP_METHOD: "alternate",
+ NUM_ATTN_HEADS: 8,
+ NUM_GQA_GROUPS: 4,
+ ATTN_MASK_TYPE: "causal",
+ TRANSPOSE_BS: True,
+ },
+ {
+ USE_BIAS: True,
+ LN_TYPE: "layernorm",
+ ZERO_CEN: False,
+ ENABLE_ROPE: False,
+ ROPE_GROUP_METHOD: "consecutive",
+ ATTN_MASK_TYPE: "padding",
+ LORA_SCOPE: "all",
+ TRANSPOSE_BS: False,
+ },
+ {
+ USE_BIAS: True,
+ LN_TYPE: "layernorm",
+ ZERO_CEN: False,
+ ENABLE_ROPE: False,
+ ROPE_GROUP_METHOD: "consecutive",
+ ATTN_MASK_TYPE: "causal",
+ LORA_SCOPE: "all",
+ TRANSPOSE_BS: True,
+ },
+ ]
class TestMultiHeadAttn(TestLayer):
@@ -899,13 +887,16 @@ class TestMultiHeadAttn(TestLayer):
return [*map(partial(jax.random.normal, shape=shape, dtype=dtype), [q_key, kv_key]), mask]
def get_layer_name(self):
- return 'multi_head_attn'
+ return "multi_head_attn"
def generate_praxis_p_and_flax_cls(self, dtype, attrs):
head_dim = 64
num_attention_heads = 16
- num_gqa_groups = attrs[MultiHeadAttnAttr.NUM_GQA_GROUPS] \
- if MultiHeadAttnAttr.NUM_GQA_GROUPS in attrs else None
+ num_gqa_groups = (
+ attrs[MultiHeadAttnAttr.NUM_GQA_GROUPS]
+ if MultiHeadAttnAttr.NUM_GQA_GROUPS in attrs
+ else None
+ )
layernorm_type = attrs[MultiHeadAttnAttr.LN_TYPE]
zero_centered_gamma = attrs[MultiHeadAttnAttr.ZERO_CEN]
kernel_init = WeightInit.Gaussian(1.0)
@@ -916,35 +907,37 @@ class TestMultiHeadAttn(TestLayer):
attn_mask_type = attrs[MultiHeadAttnAttr.ATTN_MASK_TYPE]
enable_rotary_pos_emb = attrs[MultiHeadAttnAttr.ENABLE_ROPE]
rotary_pos_emb_group_method = attrs[MultiHeadAttnAttr.ROPE_GROUP_METHOD]
- low_rank_adaptation_scope = attrs.get(MultiHeadAttnAttr.LORA_SCOPE, 'none')
+ low_rank_adaptation_scope = attrs.get(MultiHeadAttnAttr.LORA_SCOPE, "none")
fuse_qkv_params = True
transpose_batch_sequence = attrs[MultiHeadAttnAttr.TRANSPOSE_BS]
scale_attn_logits = False
scaled_query_init = True
float32_logits = False
- praxis_p = pax_fiddle.Config(MultiHeadAttention,
- name='mha',
- dtype=dtype,
- head_dim=head_dim,
- num_attention_heads=num_attention_heads,
- num_gqa_groups=num_gqa_groups,
- layernorm_type=layernorm_type,
- zero_centered_gamma=zero_centered_gamma,
- params_init=kernel_init,
- use_bias=use_bias,
- bias_init=bias_init,
- return_layernorm_output=return_layernorm_output,
- input_layernorm=input_layernorm,
- attn_mask_type=attn_mask_type,
- enable_rotary_pos_emb=enable_rotary_pos_emb,
- rotary_pos_emb_group_method=rotary_pos_emb_group_method,
- low_rank_adaptation_scope=low_rank_adaptation_scope,
- fuse_qkv_params=fuse_qkv_params,
- transpose_batch_sequence=transpose_batch_sequence,
- scale_attn_logits=scale_attn_logits,
- scaled_query_init=scaled_query_init,
- float32_logits=float32_logits)
+ praxis_p = pax_fiddle.Config(
+ MultiHeadAttention,
+ name="mha",
+ dtype=dtype,
+ head_dim=head_dim,
+ num_attention_heads=num_attention_heads,
+ num_gqa_groups=num_gqa_groups,
+ layernorm_type=layernorm_type,
+ zero_centered_gamma=zero_centered_gamma,
+ params_init=kernel_init,
+ use_bias=use_bias,
+ bias_init=bias_init,
+ return_layernorm_output=return_layernorm_output,
+ input_layernorm=input_layernorm,
+ attn_mask_type=attn_mask_type,
+ enable_rotary_pos_emb=enable_rotary_pos_emb,
+ rotary_pos_emb_group_method=rotary_pos_emb_group_method,
+ low_rank_adaptation_scope=low_rank_adaptation_scope,
+ fuse_qkv_params=fuse_qkv_params,
+ transpose_batch_sequence=transpose_batch_sequence,
+ scale_attn_logits=scale_attn_logits,
+ scaled_query_init=scaled_query_init,
+ float32_logits=float32_logits,
+ )
flax_cls = partial(
flax_MultiHeadAttention,
dtype=dtype,
@@ -966,30 +959,27 @@ class TestMultiHeadAttn(TestLayer):
transpose_batch_sequence=transpose_batch_sequence,
scale_attn_logits=scale_attn_logits,
scaled_query_init=scaled_query_init,
- float32_logits=float32_logits)
+ float32_logits=float32_logits,
+ )
return praxis_p, flax_cls
- @pytest.mark.parametrize('data_shape', DATA_SHAPE)
- @pytest.mark.parametrize('dtype', DTYPE)
- @pytest.mark.parametrize('attrs', MultiHeadAttnAttr.ATTRS)
+ @pytest.mark.parametrize("data_shape", DATA_SHAPE)
+ @pytest.mark.parametrize("dtype", DTYPE)
+ @pytest.mark.parametrize("attrs", MultiHeadAttnAttr.ATTRS)
def test_forward_backward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
self.attrs = attrs
praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol)
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
- @pytest.mark.parametrize('data_shape', DATA_SHAPE)
- @pytest.mark.parametrize('dtype', DTYPE)
- @pytest.mark.parametrize('attrs', MultiHeadAttnAttr.ATTRS)
- @pytest.mark.parametrize('fp8_format', FP8_FORMATS)
- def test_forward_backward_fp8(self,
- data_shape,
- dtype,
- attrs,
- fp8_format,
- rtol=1e-05,
- atol=1e-08):
+ @pytest.mark.parametrize("data_shape", DATA_SHAPE)
+ @pytest.mark.parametrize("dtype", DTYPE)
+ @pytest.mark.parametrize("attrs", MultiHeadAttnAttr.ATTRS)
+ @pytest.mark.parametrize("fp8_format", FP8_FORMATS)
+ def test_forward_backward_fp8(
+ self, data_shape, dtype, attrs, fp8_format, rtol=1e-05, atol=1e-08
+ ):
self.attrs = attrs
ds = DelayedScaling(fp8_format=fp8_format)
with fp8_autocast(enabled=True, fp8_recipe=ds):
@@ -998,252 +988,279 @@ class TestMultiHeadAttn(TestLayer):
class TransformerLayerAttr:
- USE_BIAS = 'use_bias'
- LN_TYPE = 'layernorm_type'
- ACTIVATION = 'activations'
- LYR_TYPE = 'layer_type'
- ZERO_CEN = 'zero_centered_gamma'
- TRANSPOSE_BS = 'transpose_batch_sequence'
- ENABLE_ROPE = 'enable_rotary_pos_emb'
- ROPE_GROUP_METHOD = 'rotary_pos_emb_group_method'
- LORA_SCOPE = 'low_rank_adaptation_scope'
- ATTRS = [{
- USE_BIAS: True,
- LN_TYPE: 'layernorm',
- ZERO_CEN: False,
- ACTIVATION: ('relu',),
- LYR_TYPE: TransformerLayerType.ENCODER,
- ENABLE_ROPE: False,
- ROPE_GROUP_METHOD: 'consecutive',
- TRANSPOSE_BS: True
- }, {
- USE_BIAS: True,
- LN_TYPE: 'layernorm',
- ZERO_CEN: False,
- ACTIVATION: ('relu',),
- LYR_TYPE: TransformerLayerType.ENCODER,
- ENABLE_ROPE: False,
- ROPE_GROUP_METHOD: 'consecutive',
- TRANSPOSE_BS: False
- }, {
- USE_BIAS: True,
- LN_TYPE: 'layernorm',
- ZERO_CEN: True,
- ACTIVATION: ('relu',),
- LYR_TYPE: TransformerLayerType.ENCODER,
- ENABLE_ROPE: False,
- ROPE_GROUP_METHOD: 'consecutive',
- TRANSPOSE_BS: True
- }, {
- USE_BIAS: True,
- LN_TYPE: 'layernorm',
- ZERO_CEN: True,
- ACTIVATION: ('relu',),
- LYR_TYPE: TransformerLayerType.ENCODER,
- ENABLE_ROPE: False,
- ROPE_GROUP_METHOD: 'consecutive',
- TRANSPOSE_BS: False
- }, {
- USE_BIAS: True,
- LN_TYPE: 'rmsnorm',
- ZERO_CEN: False,
- ACTIVATION: ('relu',),
- LYR_TYPE: TransformerLayerType.ENCODER,
- ENABLE_ROPE: False,
- ROPE_GROUP_METHOD: 'consecutive',
- TRANSPOSE_BS: True
- }, {
- USE_BIAS: True,
- LN_TYPE: 'rmsnorm',
- ZERO_CEN: False,
- ACTIVATION: ('relu',),
- LYR_TYPE: TransformerLayerType.ENCODER,
- ENABLE_ROPE: False,
- ROPE_GROUP_METHOD: 'consecutive',
- TRANSPOSE_BS: False
- }, {
- USE_BIAS: True,
- LN_TYPE: 'layernorm',
- ZERO_CEN: True,
- ACTIVATION: ('relu',),
- LYR_TYPE: TransformerLayerType.DECODER,
- ENABLE_ROPE: False,
- ROPE_GROUP_METHOD: 'consecutive',
- TRANSPOSE_BS: True
- }, {
- USE_BIAS: True,
- LN_TYPE: 'layernorm',
- ZERO_CEN: True,
- ACTIVATION: ('relu',),
- LYR_TYPE: TransformerLayerType.DECODER,
- ENABLE_ROPE: False,
- ROPE_GROUP_METHOD: 'consecutive',
- TRANSPOSE_BS: False
- }, {
- USE_BIAS: True,
- LN_TYPE: 'layernorm',
- ZERO_CEN: False,
- ACTIVATION: ('relu',),
- LYR_TYPE: TransformerLayerType.DECODER,
- ENABLE_ROPE: False,
- ROPE_GROUP_METHOD: 'consecutive',
- TRANSPOSE_BS: True
- }, {
- USE_BIAS: True,
- LN_TYPE: 'layernorm',
- ZERO_CEN: False,
- ACTIVATION: ('relu',),
- LYR_TYPE: TransformerLayerType.DECODER,
- ENABLE_ROPE: False,
- ROPE_GROUP_METHOD: 'consecutive',
- TRANSPOSE_BS: False
- }, {
- USE_BIAS: True,
- LN_TYPE: 'rmsnorm',
- ZERO_CEN: False,
- ACTIVATION: ('relu',),
- LYR_TYPE: TransformerLayerType.DECODER,
- ENABLE_ROPE: False,
- ROPE_GROUP_METHOD: 'consecutive',
- TRANSPOSE_BS: True
- }, {
- USE_BIAS: True,
- LN_TYPE: 'rmsnorm',
- ZERO_CEN: False,
- ACTIVATION: ('relu',),
- LYR_TYPE: TransformerLayerType.DECODER,
- ENABLE_ROPE: False,
- ROPE_GROUP_METHOD: 'consecutive',
- TRANSPOSE_BS: False
- }, {
- USE_BIAS: True,
- LN_TYPE: 'layernorm',
- ZERO_CEN: False,
- ACTIVATION: ('gelu', 'linear'),
- LYR_TYPE: TransformerLayerType.ENCODER,
- ENABLE_ROPE: False,
- ROPE_GROUP_METHOD: 'consecutive',
- TRANSPOSE_BS: True
- }, {
- USE_BIAS: True,
- LN_TYPE: 'layernorm',
- ZERO_CEN: False,
- ACTIVATION: ('gelu', 'linear'),
- LYR_TYPE: TransformerLayerType.ENCODER,
- ENABLE_ROPE: False,
- ROPE_GROUP_METHOD: 'consecutive',
- TRANSPOSE_BS: False
- }, {
- USE_BIAS: True,
- LN_TYPE: 'rmsnorm',
- ZERO_CEN: False,
- ACTIVATION: ('gelu', 'linear'),
- LYR_TYPE: TransformerLayerType.ENCODER,
- ENABLE_ROPE: False,
- ROPE_GROUP_METHOD: 'consecutive',
- TRANSPOSE_BS: True
- }, {
- USE_BIAS: True,
- LN_TYPE: 'rmsnorm',
- ZERO_CEN: False,
- ACTIVATION: ('gelu', 'linear'),
- LYR_TYPE: TransformerLayerType.ENCODER,
- ENABLE_ROPE: False,
- ROPE_GROUP_METHOD: 'consecutive',
- TRANSPOSE_BS: False
- }, {
- USE_BIAS: True,
- LN_TYPE: 'layernorm',
- ZERO_CEN: False,
- ACTIVATION: ('gelu',),
- LYR_TYPE: TransformerLayerType.ENCODER,
- ENABLE_ROPE: False,
- ROPE_GROUP_METHOD: 'consecutive',
- TRANSPOSE_BS: False,
- LORA_SCOPE: 'all'
- }, {
- USE_BIAS: True,
- LN_TYPE: 'layernorm',
- ZERO_CEN: False,
- ACTIVATION: ('gelu', 'linear'),
- LYR_TYPE: TransformerLayerType.DECODER,
- ENABLE_ROPE: False,
- ROPE_GROUP_METHOD: 'consecutive',
- TRANSPOSE_BS: True
- }, {
- USE_BIAS: True,
- LN_TYPE: 'layernorm',
- ZERO_CEN: False,
- ACTIVATION: ('gelu', 'linear'),
- LYR_TYPE: TransformerLayerType.DECODER,
- ENABLE_ROPE: False,
- ROPE_GROUP_METHOD: 'consecutive',
- TRANSPOSE_BS: False
- }, {
- USE_BIAS: True,
- LN_TYPE: 'rmsnorm',
- ZERO_CEN: False,
- ACTIVATION: ('gelu', 'linear'),
- LYR_TYPE: TransformerLayerType.DECODER,
- ENABLE_ROPE: False,
- ROPE_GROUP_METHOD: 'consecutive',
- TRANSPOSE_BS: True
- }, {
- USE_BIAS: True,
- LN_TYPE: 'rmsnorm',
- ZERO_CEN: False,
- ACTIVATION: ('gelu', 'linear'),
- LYR_TYPE: TransformerLayerType.DECODER,
- ENABLE_ROPE: False,
- ROPE_GROUP_METHOD: 'consecutive',
- TRANSPOSE_BS: False
- }, {
- USE_BIAS: True,
- LN_TYPE: 'layernorm',
- ZERO_CEN: True,
- ACTIVATION: ('gelu',),
- LYR_TYPE: TransformerLayerType.ENCODER,
- ENABLE_ROPE: True,
- ROPE_GROUP_METHOD: 'alternate',
- TRANSPOSE_BS: False
- }, {
- USE_BIAS: True,
- LN_TYPE: 'layernorm',
- ZERO_CEN: True,
- ACTIVATION: ('gelu',),
- LYR_TYPE: TransformerLayerType.DECODER,
- ENABLE_ROPE: True,
- ROPE_GROUP_METHOD: 'alternate',
- TRANSPOSE_BS: False
- }, {
- USE_BIAS: True,
- LN_TYPE: 'layernorm',
- ZERO_CEN: True,
- ACTIVATION: ('gelu',),
- LYR_TYPE: TransformerLayerType.ENCODER,
- ENABLE_ROPE: True,
- ROPE_GROUP_METHOD: 'consecutive',
- TRANSPOSE_BS: False
- }, {
- USE_BIAS: True,
- LN_TYPE: 'layernorm',
- ZERO_CEN: True,
- ACTIVATION: ('gelu',),
- LYR_TYPE: TransformerLayerType.DECODER,
- ENABLE_ROPE: True,
- ROPE_GROUP_METHOD: 'consecutive',
- TRANSPOSE_BS: False
- }, {
- USE_BIAS: True,
- LN_TYPE: 'layernorm',
- ZERO_CEN: False,
- ACTIVATION: ('gelu',),
- LYR_TYPE: TransformerLayerType.DECODER,
- ENABLE_ROPE: False,
- ROPE_GROUP_METHOD: 'consecutive',
- TRANSPOSE_BS: False,
- LORA_SCOPE: 'all'
- }]
+ USE_BIAS = "use_bias"
+ LN_TYPE = "layernorm_type"
+ ACTIVATION = "activations"
+ LYR_TYPE = "layer_type"
+ ZERO_CEN = "zero_centered_gamma"
+ TRANSPOSE_BS = "transpose_batch_sequence"
+ ENABLE_ROPE = "enable_rotary_pos_emb"
+ ROPE_GROUP_METHOD = "rotary_pos_emb_group_method"
+ LORA_SCOPE = "low_rank_adaptation_scope"
+ ATTRS = [
+ {
+ USE_BIAS: True,
+ LN_TYPE: "layernorm",
+ ZERO_CEN: False,
+ ACTIVATION: ("relu",),
+ LYR_TYPE: TransformerLayerType.ENCODER,
+ ENABLE_ROPE: False,
+ ROPE_GROUP_METHOD: "consecutive",
+ TRANSPOSE_BS: True,
+ },
+ {
+ USE_BIAS: True,
+ LN_TYPE: "layernorm",
+ ZERO_CEN: False,
+ ACTIVATION: ("relu",),
+ LYR_TYPE: TransformerLayerType.ENCODER,
+ ENABLE_ROPE: False,
+ ROPE_GROUP_METHOD: "consecutive",
+ TRANSPOSE_BS: False,
+ },
+ {
+ USE_BIAS: True,
+ LN_TYPE: "layernorm",
+ ZERO_CEN: True,
+ ACTIVATION: ("relu",),
+ LYR_TYPE: TransformerLayerType.ENCODER,
+ ENABLE_ROPE: False,
+ ROPE_GROUP_METHOD: "consecutive",
+ TRANSPOSE_BS: True,
+ },
+ {
+ USE_BIAS: True,
+ LN_TYPE: "layernorm",
+ ZERO_CEN: True,
+ ACTIVATION: ("relu",),
+ LYR_TYPE: TransformerLayerType.ENCODER,
+ ENABLE_ROPE: False,
+ ROPE_GROUP_METHOD: "consecutive",
+ TRANSPOSE_BS: False,
+ },
+ {
+ USE_BIAS: True,
+ LN_TYPE: "rmsnorm",
+ ZERO_CEN: False,
+ ACTIVATION: ("relu",),
+ LYR_TYPE: TransformerLayerType.ENCODER,
+ ENABLE_ROPE: False,
+ ROPE_GROUP_METHOD: "consecutive",
+ TRANSPOSE_BS: True,
+ },
+ {
+ USE_BIAS: True,
+ LN_TYPE: "rmsnorm",
+ ZERO_CEN: False,
+ ACTIVATION: ("relu",),
+ LYR_TYPE: TransformerLayerType.ENCODER,
+ ENABLE_ROPE: False,
+ ROPE_GROUP_METHOD: "consecutive",
+ TRANSPOSE_BS: False,
+ },
+ {
+ USE_BIAS: True,
+ LN_TYPE: "layernorm",
+ ZERO_CEN: True,
+ ACTIVATION: ("relu",),
+ LYR_TYPE: TransformerLayerType.DECODER,
+ ENABLE_ROPE: False,
+ ROPE_GROUP_METHOD: "consecutive",
+ TRANSPOSE_BS: True,
+ },
+ {
+ USE_BIAS: True,
+ LN_TYPE: "layernorm",
+ ZERO_CEN: True,
+ ACTIVATION: ("relu",),
+ LYR_TYPE: TransformerLayerType.DECODER,
+ ENABLE_ROPE: False,
+ ROPE_GROUP_METHOD: "consecutive",
+ TRANSPOSE_BS: False,
+ },
+ {
+ USE_BIAS: True,
+ LN_TYPE: "layernorm",
+ ZERO_CEN: False,
+ ACTIVATION: ("relu",),
+ LYR_TYPE: TransformerLayerType.DECODER,
+ ENABLE_ROPE: False,
+ ROPE_GROUP_METHOD: "consecutive",
+ TRANSPOSE_BS: True,
+ },
+ {
+ USE_BIAS: True,
+ LN_TYPE: "layernorm",
+ ZERO_CEN: False,
+ ACTIVATION: ("relu",),
+ LYR_TYPE: TransformerLayerType.DECODER,
+ ENABLE_ROPE: False,
+ ROPE_GROUP_METHOD: "consecutive",
+ TRANSPOSE_BS: False,
+ },
+ {
+ USE_BIAS: True,
+ LN_TYPE: "rmsnorm",
+ ZERO_CEN: False,
+ ACTIVATION: ("relu",),
+ LYR_TYPE: TransformerLayerType.DECODER,
+ ENABLE_ROPE: False,
+ ROPE_GROUP_METHOD: "consecutive",
+ TRANSPOSE_BS: True,
+ },
+ {
+ USE_BIAS: True,
+ LN_TYPE: "rmsnorm",
+ ZERO_CEN: False,
+ ACTIVATION: ("relu",),
+ LYR_TYPE: TransformerLayerType.DECODER,
+ ENABLE_ROPE: False,
+ ROPE_GROUP_METHOD: "consecutive",
+ TRANSPOSE_BS: False,
+ },
+ {
+ USE_BIAS: True,
+ LN_TYPE: "layernorm",
+ ZERO_CEN: False,
+ ACTIVATION: ("gelu", "linear"),
+ LYR_TYPE: TransformerLayerType.ENCODER,
+ ENABLE_ROPE: False,
+ ROPE_GROUP_METHOD: "consecutive",
+ TRANSPOSE_BS: True,
+ },
+ {
+ USE_BIAS: True,
+ LN_TYPE: "layernorm",
+ ZERO_CEN: False,
+ ACTIVATION: ("gelu", "linear"),
+ LYR_TYPE: TransformerLayerType.ENCODER,
+ ENABLE_ROPE: False,
+ ROPE_GROUP_METHOD: "consecutive",
+ TRANSPOSE_BS: False,
+ },
+ {
+ USE_BIAS: True,
+ LN_TYPE: "rmsnorm",
+ ZERO_CEN: False,
+ ACTIVATION: ("gelu", "linear"),
+ LYR_TYPE: TransformerLayerType.ENCODER,
+ ENABLE_ROPE: False,
+ ROPE_GROUP_METHOD: "consecutive",
+ TRANSPOSE_BS: True,
+ },
+ {
+ USE_BIAS: True,
+ LN_TYPE: "rmsnorm",
+ ZERO_CEN: False,
+ ACTIVATION: ("gelu", "linear"),
+ LYR_TYPE: TransformerLayerType.ENCODER,
+ ENABLE_ROPE: False,
+ ROPE_GROUP_METHOD: "consecutive",
+ TRANSPOSE_BS: False,
+ },
+ {
+ USE_BIAS: True,
+ LN_TYPE: "layernorm",
+ ZERO_CEN: False,
+ ACTIVATION: ("gelu",),
+ LYR_TYPE: TransformerLayerType.ENCODER,
+ ENABLE_ROPE: False,
+ ROPE_GROUP_METHOD: "consecutive",
+ TRANSPOSE_BS: False,
+ LORA_SCOPE: "all",
+ },
+ {
+ USE_BIAS: True,
+ LN_TYPE: "layernorm",
+ ZERO_CEN: False,
+ ACTIVATION: ("gelu", "linear"),
+ LYR_TYPE: TransformerLayerType.DECODER,
+ ENABLE_ROPE: False,
+ ROPE_GROUP_METHOD: "consecutive",
+ TRANSPOSE_BS: True,
+ },
+ {
+ USE_BIAS: True,
+ LN_TYPE: "layernorm",
+ ZERO_CEN: False,
+ ACTIVATION: ("gelu", "linear"),
+ LYR_TYPE: TransformerLayerType.DECODER,
+ ENABLE_ROPE: False,
+ ROPE_GROUP_METHOD: "consecutive",
+ TRANSPOSE_BS: False,
+ },
+ {
+ USE_BIAS: True,
+ LN_TYPE: "rmsnorm",
+ ZERO_CEN: False,
+ ACTIVATION: ("gelu", "linear"),
+ LYR_TYPE: TransformerLayerType.DECODER,
+ ENABLE_ROPE: False,
+ ROPE_GROUP_METHOD: "consecutive",
+ TRANSPOSE_BS: True,
+ },
+ {
+ USE_BIAS: True,
+ LN_TYPE: "rmsnorm",
+ ZERO_CEN: False,
+ ACTIVATION: ("gelu", "linear"),
+ LYR_TYPE: TransformerLayerType.DECODER,
+ ENABLE_ROPE: False,
+ ROPE_GROUP_METHOD: "consecutive",
+ TRANSPOSE_BS: False,
+ },
+ {
+ USE_BIAS: True,
+ LN_TYPE: "layernorm",
+ ZERO_CEN: True,
+ ACTIVATION: ("gelu",),
+ LYR_TYPE: TransformerLayerType.ENCODER,
+ ENABLE_ROPE: True,
+ ROPE_GROUP_METHOD: "alternate",
+ TRANSPOSE_BS: False,
+ },
+ {
+ USE_BIAS: True,
+ LN_TYPE: "layernorm",
+ ZERO_CEN: True,
+ ACTIVATION: ("gelu",),
+ LYR_TYPE: TransformerLayerType.DECODER,
+ ENABLE_ROPE: True,
+ ROPE_GROUP_METHOD: "alternate",
+ TRANSPOSE_BS: False,
+ },
+ {
+ USE_BIAS: True,
+ LN_TYPE: "layernorm",
+ ZERO_CEN: True,
+ ACTIVATION: ("gelu",),
+ LYR_TYPE: TransformerLayerType.ENCODER,
+ ENABLE_ROPE: True,
+ ROPE_GROUP_METHOD: "consecutive",
+ TRANSPOSE_BS: False,
+ },
+ {
+ USE_BIAS: True,
+ LN_TYPE: "layernorm",
+ ZERO_CEN: True,
+ ACTIVATION: ("gelu",),
+ LYR_TYPE: TransformerLayerType.DECODER,
+ ENABLE_ROPE: True,
+ ROPE_GROUP_METHOD: "consecutive",
+ TRANSPOSE_BS: False,
+ },
+ {
+ USE_BIAS: True,
+ LN_TYPE: "layernorm",
+ ZERO_CEN: False,
+ ACTIVATION: ("gelu",),
+ LYR_TYPE: TransformerLayerType.DECODER,
+ ENABLE_ROPE: False,
+ ROPE_GROUP_METHOD: "consecutive",
+ TRANSPOSE_BS: False,
+ LORA_SCOPE: "all",
+ },
+ ]
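
Note: each dict in ATTRS above becomes one test case via @pytest.mark.parametrize("attrs", TransformerLayerAttr.ATTRS), and optional keys such as LORA_SCOPE are read back with attrs.get(..., "none"). A minimal standalone sketch of that pattern (hypothetical keys, not the full test):

    import pytest

    USE_BIAS = "use_bias"
    LORA_SCOPE = "low_rank_adaptation_scope"
    ATTRS = [
        {USE_BIAS: True},
        {USE_BIAS: True, LORA_SCOPE: "all"},
    ]

    @pytest.mark.parametrize("attrs", ATTRS)
    def test_attr_lookup(attrs):
        # Optional keys fall back to a default, mirroring
        # attrs.get(TransformerLayerAttr.LORA_SCOPE, "none") in the tests.
        assert attrs.get(LORA_SCOPE, "none") in ("none", "all")
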
class TestTransformer(TestLayer):
@@ -1256,11 +1273,13 @@ class TestTransformer(TestLayer):
shape = (shape[1], shape[0]) + shape[2:]
mask = jnp.zeros((b, 1, s, s), dtype=jnp.uint8)
return [
- *map(partial(jax.random.normal, shape=shape, dtype=dtype), [q_key, kv_key]), mask, mask
+ *map(partial(jax.random.normal, shape=shape, dtype=dtype), [q_key, kv_key]),
+ mask,
+ mask,
]
def get_layer_name(self):
- return 'transformerlayer'
+ return "transformerlayer"
def generate_praxis_p_and_flax_cls(self, dtype, attrs):
hidden_size = 512
@@ -1277,97 +1296,102 @@ class TestTransformer(TestLayer):
layer_type = attrs[TransformerLayerAttr.LYR_TYPE]
enable_rotary_pos_emb = attrs[TransformerLayerAttr.ENABLE_ROPE]
rotary_pos_emb_group_method = attrs[TransformerLayerAttr.ROPE_GROUP_METHOD]
- low_rank_adaptation_scope = attrs.get(TransformerLayerAttr.LORA_SCOPE, 'none')
+ low_rank_adaptation_scope = attrs.get(TransformerLayerAttr.LORA_SCOPE, "none")
enable_relative_embedding = True
- relative_embedding = pax_fiddle.Config(RelativePositionBiases,
- dtype=dtype,
- num_attention_heads=num_attention_heads)
+ relative_embedding = pax_fiddle.Config(
+ RelativePositionBiases, dtype=dtype, num_attention_heads=num_attention_heads
+ )
drop_path = 0.0
transpose_batch_sequence = attrs[TransformerLayerAttr.TRANSPOSE_BS]
rel_embedding_init = RelativePositionBiases.generate_embedding_init(
- relative_embedding.embedding_init, relative_embedding.num_attention_heads,
- relative_embedding.num_buckets)
+ relative_embedding.embedding_init,
+ relative_embedding.num_attention_heads,
+ relative_embedding.num_buckets,
+ )
relative_embedding_flax_module = flax_RelativePositionBiases(
num_buckets=relative_embedding.num_buckets,
max_distance=relative_embedding.max_distance,
num_attention_heads=relative_embedding.num_attention_heads,
embedding_init=TransformerEngineBaseLayer.generate_params_init(
- "rel_embedding", rel_embedding_init),
+ "rel_embedding", rel_embedding_init
+ ),
embedding_axes=relative_embedding.embedding_axes,
- dtype=relative_embedding.dtype)
-
- praxis_p = pax_fiddle.Config(TransformerLayer,
- name='transformer_layer',
- params_init=kernel_init,
- dtype=dtype,
- hidden_size=hidden_size,
- mlp_hidden_size=mlp_hidden_size,
- num_attention_heads=num_attention_heads,
- layernorm_type=layernorm_type,
- hidden_dropout=hidden_dropout,
- attention_dropout=attention_dropout,
- intermediate_dropout=intermediate_dropout,
- mlp_activations=mlp_activations,
- use_bias=use_bias,
- bias_init=bias_init,
- layer_type=layer_type,
- enable_relative_embedding=enable_relative_embedding,
- enable_rotary_pos_emb=enable_rotary_pos_emb,
- rotary_pos_emb_group_method=rotary_pos_emb_group_method,
- low_rank_adaptation_scope=low_rank_adaptation_scope,
- relative_embedding=relative_embedding,
- drop_path=drop_path,
- transpose_batch_sequence=transpose_batch_sequence)
- flax_cls = partial(flax_TransformerLayer,
- dtype=dtype,
- hidden_size=hidden_size,
- mlp_hidden_size=mlp_hidden_size,
- num_attention_heads=num_attention_heads,
- layernorm_type=layernorm_type,
- hidden_dropout=hidden_dropout,
- attention_dropout=attention_dropout,
- intermediate_dropout=intermediate_dropout,
- mlp_activations=mlp_activations,
- mha_kernel_init=TransformerEngineBaseLayer.generate_params_init(
- "mha_kernel", kernel_init),
- mlp_kernel_init=TransformerEngineBaseLayer.generate_params_init(
- "mlp_kernel", kernel_init),
- use_bias=use_bias,
- bias_init=TransformerEngineBaseLayer.generate_params_init(
- "bias", bias_init),
- layer_type=layer_type,
- enable_rotary_pos_emb=enable_rotary_pos_emb,
- rotary_pos_emb_group_method=rotary_pos_emb_group_method,
- enable_relative_embedding=enable_relative_embedding,
- relative_embedding=relative_embedding_flax_module,
- low_rank_adaptation_scope=low_rank_adaptation_scope,
- drop_path=drop_path,
- transpose_batch_sequence=transpose_batch_sequence)
+ dtype=relative_embedding.dtype,
+ )
+
+ praxis_p = pax_fiddle.Config(
+ TransformerLayer,
+ name="transformer_layer",
+ params_init=kernel_init,
+ dtype=dtype,
+ hidden_size=hidden_size,
+ mlp_hidden_size=mlp_hidden_size,
+ num_attention_heads=num_attention_heads,
+ layernorm_type=layernorm_type,
+ hidden_dropout=hidden_dropout,
+ attention_dropout=attention_dropout,
+ intermediate_dropout=intermediate_dropout,
+ mlp_activations=mlp_activations,
+ use_bias=use_bias,
+ bias_init=bias_init,
+ layer_type=layer_type,
+ enable_relative_embedding=enable_relative_embedding,
+ enable_rotary_pos_emb=enable_rotary_pos_emb,
+ rotary_pos_emb_group_method=rotary_pos_emb_group_method,
+ low_rank_adaptation_scope=low_rank_adaptation_scope,
+ relative_embedding=relative_embedding,
+ drop_path=drop_path,
+ transpose_batch_sequence=transpose_batch_sequence,
+ )
+ flax_cls = partial(
+ flax_TransformerLayer,
+ dtype=dtype,
+ hidden_size=hidden_size,
+ mlp_hidden_size=mlp_hidden_size,
+ num_attention_heads=num_attention_heads,
+ layernorm_type=layernorm_type,
+ hidden_dropout=hidden_dropout,
+ attention_dropout=attention_dropout,
+ intermediate_dropout=intermediate_dropout,
+ mlp_activations=mlp_activations,
+ mha_kernel_init=TransformerEngineBaseLayer.generate_params_init(
+ "mha_kernel", kernel_init
+ ),
+ mlp_kernel_init=TransformerEngineBaseLayer.generate_params_init(
+ "mlp_kernel", kernel_init
+ ),
+ use_bias=use_bias,
+ bias_init=TransformerEngineBaseLayer.generate_params_init("bias", bias_init),
+ layer_type=layer_type,
+ enable_rotary_pos_emb=enable_rotary_pos_emb,
+ rotary_pos_emb_group_method=rotary_pos_emb_group_method,
+ enable_relative_embedding=enable_relative_embedding,
+ relative_embedding=relative_embedding_flax_module,
+ low_rank_adaptation_scope=low_rank_adaptation_scope,
+ drop_path=drop_path,
+ transpose_batch_sequence=transpose_batch_sequence,
+ )
return praxis_p, flax_cls
- @pytest.mark.parametrize('data_shape', DATA_SHAPE)
- @pytest.mark.parametrize('dtype', DTYPE)
- @pytest.mark.parametrize('attrs', TransformerLayerAttr.ATTRS)
+ @pytest.mark.parametrize("data_shape", DATA_SHAPE)
+ @pytest.mark.parametrize("dtype", DTYPE)
+ @pytest.mark.parametrize("attrs", TransformerLayerAttr.ATTRS)
def test_forward_backward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
self.attrs = attrs
praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol)
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
- @pytest.mark.parametrize('data_shape', DATA_SHAPE)
- @pytest.mark.parametrize('dtype', DTYPE)
- @pytest.mark.parametrize('attrs', TransformerLayerAttr.ATTRS)
- @pytest.mark.parametrize('fp8_format', FP8_FORMATS)
- def test_forward_backward_fp8(self,
- data_shape,
- dtype,
- attrs,
- fp8_format,
- rtol=1e-05,
- atol=1e-08):
+ @pytest.mark.parametrize("data_shape", DATA_SHAPE)
+ @pytest.mark.parametrize("dtype", DTYPE)
+ @pytest.mark.parametrize("attrs", TransformerLayerAttr.ATTRS)
+ @pytest.mark.parametrize("fp8_format", FP8_FORMATS)
+ def test_forward_backward_fp8(
+ self, data_shape, dtype, attrs, fp8_format, rtol=1e-05, atol=1e-08
+ ):
self.attrs = attrs
ds = DelayedScaling(fp8_format=fp8_format)
with fp8_autocast(enabled=True, fp8_recipe=ds):
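
The FP8 test above wraps the runner in fp8_autocast with a DelayedScaling recipe. A hedged sketch of that context-manager pattern, assuming the usual Transformer Engine JAX import paths:

    from transformer_engine.common.recipe import DelayedScaling, Format
    from transformer_engine.jax import fp8_autocast

    ds = DelayedScaling(fp8_format=Format.HYBRID)
    with fp8_autocast(enabled=True, fp8_recipe=ds):
        # Build and run the layer here; eligible matmuls inside the
        # context can execute in FP8 on supported hardware.
        pass
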
diff --git a/tests/jax/test_sanity_import.py b/tests/jax/test_sanity_import.py
index cc07e2d63cf699b3f9d0c3a661cf2fa2f634d0a1..f47c2eb411cfffd1559f835281d1e8192a5c7ae6 100644
--- a/tests/jax/test_sanity_import.py
+++ b/tests/jax/test_sanity_import.py
@@ -3,4 +3,5 @@
# See LICENSE for license information.
import transformer_engine.jax
+
print("OK")
diff --git a/tests/jax/test_sharding.py b/tests/jax/test_sharding.py
index 484728ed71bca2333f212c271a96656e373983ec..4581cdc39ee5503fe89a50ee7e7cf72d7187f03a 100644
--- a/tests/jax/test_sharding.py
+++ b/tests/jax/test_sharding.py
@@ -8,25 +8,25 @@ from transformer_engine.jax.flax import extend_logical_axis_rules
from transformer_engine.jax.sharding import global_shard_guard, MeshResource
LOGICAL_RULES = [
- [(('a1', None), ('a2', 'ma2')), False],
- [(('a1', None), ('a2', 'ma2'), ('a3', ('ma31', 'ma32'))), True],
- [(('a1', None), ('a2', 'ma2'), ('a3', 'ma31'), ('a3', 'ma32')), False],
- [(('a1', None), ('a2', 'ma2'), ('batch', 'batch_1200234')), True],
- [(('a1', None), ('a2', 'ma2'), ('a2', 'ma1'), ('batch', 'model'), ('batch', 'data')), True],
+ [(("a1", None), ("a2", "ma2")), False],
+ [(("a1", None), ("a2", "ma2"), ("a3", ("ma31", "ma32"))), True],
+ [(("a1", None), ("a2", "ma2"), ("a3", "ma31"), ("a3", "ma32")), False],
+ [(("a1", None), ("a2", "ma2"), ("batch", "batch_1200234")), True],
+ [(("a1", None), ("a2", "ma2"), ("a2", "ma1"), ("batch", "model"), ("batch", "data")), True],
]
MeshS = [
MeshResource(),
- MeshResource('data', None),
- MeshResource(None, 'model'),
- MeshResource('data', 'model')
+ MeshResource("data", None),
+ MeshResource(None, "model"),
+ MeshResource("data", "model"),
]
class TestShardingSideAPI:
- @pytest.mark.parametrize('base_rules,need_assert', LOGICAL_RULES)
- @pytest.mark.parametrize('sr', MeshS)
+ @pytest.mark.parametrize("base_rules,need_assert", LOGICAL_RULES)
+ @pytest.mark.parametrize("sr", MeshS)
def test_extend_logical_axis_rules(self, base_rules, need_assert, sr):
with global_shard_guard(sr):
try:
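
For context, the guard pattern this test exercises looks roughly as follows; the positional signature of extend_logical_axis_rules is assumed from the parametrization, and rule sets flagged need_assert are expected to fail validation inside the guard:

    from transformer_engine.jax.flax import extend_logical_axis_rules
    from transformer_engine.jax.sharding import global_shard_guard, MeshResource

    base_rules = (("a1", None), ("a2", "ma2"))
    with global_shard_guard(MeshResource("data", "model")):
        # The active MeshResource is visible to the rule extension.
        rules = extend_logical_axis_rules(base_rules)
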
diff --git a/tests/jax/test_softmax.py b/tests/jax/test_softmax.py
index 3df268cbb1c380efa1c1d502162cd2ffa68c5609..0cff5955fa4f3f71a5948379b51a2ae381d5d844 100644
--- a/tests/jax/test_softmax.py
+++ b/tests/jax/test_softmax.py
@@ -43,6 +43,7 @@ class SoftmaxRunner:
"""
Softmax runner
"""
+
batch_size: int
max_seqlen_q: int
max_seqlen_kv: int
@@ -57,14 +58,22 @@ class SoftmaxRunner:
Jax softmax as the reference
"""
if mask is not None:
- logits += lax.select(mask > 0,
- jnp.full(mask.shape, -1e10).astype(logits.dtype),
- jnp.full(mask.shape, 0.).astype(logits.dtype))
+ logits += lax.select(
+ mask > 0,
+ jnp.full(mask.shape, -1e10).astype(logits.dtype),
+ jnp.full(mask.shape, 0.0).astype(logits.dtype),
+ )
return nn.softmax(logits * scale_factor)
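
A standalone sketch of the reference above: masked positions receive a large negative bias so they vanish after the scaled softmax.

    import jax.numpy as jnp
    from jax import lax
    from flax import linen as nn

    logits = jnp.zeros((1, 1, 2, 2), dtype=jnp.float32)
    mask = jnp.array([[[[0, 1], [0, 0]]]], dtype=jnp.uint8)  # 1 = masked out
    biased = logits + lax.select(
        mask > 0,
        jnp.full(mask.shape, -1e10).astype(logits.dtype),
        jnp.full(mask.shape, 0.0).astype(logits.dtype),
    )
    probs = nn.softmax(biased * 0.125)  # masked entry gets ~0 probability
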
def _is_support(self):
- return is_softmax_kernel_available(self.softmax_type, self.batch_size, self.num_heads,
- self.max_seqlen_q, self.max_seqlen_kv, self.dtype)
+ return is_softmax_kernel_available(
+ self.softmax_type,
+ self.batch_size,
+ self.num_heads,
+ self.max_seqlen_q,
+ self.max_seqlen_kv,
+ self.dtype,
+ )
def _setup_inputs(self):
key = jax.random.PRNGKey(0)
@@ -73,7 +82,7 @@ class SoftmaxRunner:
logits_shape = (self.batch_size, self.num_heads, self.max_seqlen_q, self.max_seqlen_kv)
mask_shape = (self.batch_size, 1, self.max_seqlen_q, self.max_seqlen_kv)
- self.logits = jax.random.uniform(logits_key, logits_shape, self.dtype, -1.)
+ self.logits = jax.random.uniform(logits_key, logits_shape, self.dtype, -1.0)
match self.softmax_type:
case SoftmaxType.SCALED:
@@ -81,7 +90,7 @@ class SoftmaxRunner:
case SoftmaxType.SCALED_MASKED:
self.mask = jax.random.bernoulli(mask_key, shape=mask_shape).astype(jnp.uint8)
case SoftmaxType.SCALED_UPPER_TRIANG_MASKED:
- self.mask = (1. - jnp.tril(jnp.ones_like(self.logits))).astype(jnp.uint8)
+ self.mask = (1.0 - jnp.tril(jnp.ones_like(self.logits))).astype(jnp.uint8)
case _:
raise ValueError(f"Unknown {self.softmax_type=}")
@@ -108,18 +117,24 @@ class SoftmaxRunner:
args = [self.logits, self.mask]
kwargs = {
- 'scale_factor': self.scale_factor,
- 'softmax_type': self.softmax_type,
+ "scale_factor": self.scale_factor,
+ "softmax_type": self.softmax_type,
}
# Use FP16/BF16 to sum the results may cause overflow, use FP32 for the summation
jitted_primitive = jit(
- value_and_grad(lambda logits, *args: grad_func(softmax, self.logits, *args, **kwargs),
- (0,)))
+ value_and_grad(
+ lambda logits, *args: grad_func(softmax, self.logits, *args, **kwargs), (0,)
+ )
+ )
jitted_reference = jit(
value_and_grad(
- lambda logits, *args: grad_func(__class__.reference_softmax, self.logits, *args, **
- kwargs), (0,)))
+ lambda logits, *args: grad_func(
+ __class__.reference_softmax, self.logits, *args, **kwargs
+ ),
+ (0,),
+ )
+ )
primitive_out, (primitive_grad_logits,) = jitted_primitive(*args)
reference_out, (reference_grad_logits,) = jitted_reference(*args)
@@ -128,21 +143,30 @@ class SoftmaxRunner:
assert_allclose(primitive_grad_logits, reference_grad_logits, dtype=self.dtype)
-@pytest.mark.parametrize('b, s_q, s_kv, h', [
- pytest.param(8, 16, 16, 16, id='8-16-16-16'),
- pytest.param(8, 512, 512, 16, id='8-512-512-16'),
- pytest.param(2, 8, 16384, 8, id='2-8-16384-8')
-])
-@pytest.mark.parametrize('scale_factor', [0.125])
-@pytest.mark.parametrize('softmax_type', [
- pytest.param(SoftmaxType.SCALED, id='SCALED'),
- pytest.param(SoftmaxType.SCALED_MASKED, id='SCALED_MASKED'),
- pytest.param(SoftmaxType.SCALED_UPPER_TRIANG_MASKED, id='SCALED_UPPER_TRIANG_MASKED')
-])
-@pytest.mark.parametrize('dtype', [
- pytest.param(jnp.bfloat16, id="BF16"),
- pytest.param(jnp.float16, id="FP16"),
-])
+@pytest.mark.parametrize(
+ "b, s_q, s_kv, h",
+ [
+ pytest.param(8, 16, 16, 16, id="8-16-16-16"),
+ pytest.param(8, 512, 512, 16, id="8-512-512-16"),
+ pytest.param(2, 8, 16384, 8, id="2-8-16384-8"),
+ ],
+)
+@pytest.mark.parametrize("scale_factor", [0.125])
+@pytest.mark.parametrize(
+ "softmax_type",
+ [
+ pytest.param(SoftmaxType.SCALED, id="SCALED"),
+ pytest.param(SoftmaxType.SCALED_MASKED, id="SCALED_MASKED"),
+ pytest.param(SoftmaxType.SCALED_UPPER_TRIANG_MASKED, id="SCALED_UPPER_TRIANG_MASKED"),
+ ],
+)
+@pytest.mark.parametrize(
+ "dtype",
+ [
+ pytest.param(jnp.bfloat16, id="BF16"),
+ pytest.param(jnp.float16, id="FP16"),
+ ],
+)
class TestSoftmax:
"""
Test transformer_engine.jax.softmax.softmax
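
The runner above compares primitive and reference through jit(value_and_grad(..., (0,))); with argnums given as a tuple, the gradient comes back as a one-element tuple. A minimal sketch of that pattern:

    import jax
    import jax.numpy as jnp
    from jax import jit, value_and_grad

    def loss(logits, scale):
        # Sum in FP32, echoing the overflow note in the runner.
        return jnp.sum(jax.nn.softmax(logits * scale).astype(jnp.float32))

    jitted = jit(value_and_grad(loss, (0,)))
    out, (grad_logits,) = jitted(jnp.ones((2, 4)), 0.125)
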
diff --git a/tests/jax/utils.py b/tests/jax/utils.py
index ee933005d05a927d3c5f53386c9e37d4cb5c476f..798c2a82bae7136952bb0233a1f716ec20676c63 100644
--- a/tests/jax/utils.py
+++ b/tests/jax/utils.py
@@ -24,8 +24,9 @@ PRNGKey = Any
Shape = Tuple[int, ...]
DType = jnp.dtype
Array = Any
-PrecisionLike = Union[None, str, lax.Precision, Tuple[str, str], Tuple[lax.Precision,
- lax.Precision]]
+PrecisionLike = Union[
+ None, str, lax.Precision, Tuple[str, str], Tuple[lax.Precision, lax.Precision]
+]
Initializer = Callable[[PRNGKey, Shape, DType], Array]
@@ -56,7 +57,7 @@ def _canonicalize_tuple(x):
def _convert_to_activation_function(fn_or_string: Union[str, Callable]) -> Callable:
"""Convert a string to an activation function."""
- if fn_or_string == 'linear':
+ if fn_or_string == "linear":
return lambda x: x
if isinstance(fn_or_string, str):
return getattr(nn, fn_or_string)
@@ -68,17 +69,18 @@ def _convert_to_activation_function(fn_or_string: Union[str, Callable]) -> Calla
def combine_biases(*masks: Optional[Array]):
"""Combine attention biases.
- Args:
- *masks: set of attention bias arguments to combine, some can be None.
+ Args:
+ *masks: set of attention bias arguments to combine, some can be None.
- Returns:
- Combined mask, reduced by summation, returns None if no masks given.
- """
+ Returns:
+ Combined mask, reduced by summation, returns None if no masks given.
+ """
masks = [m for m in masks if m is not None]
if not masks:
return None
- assert all(map(lambda x: x.ndim == masks[0].ndim,
- masks)), (f'masks must have same rank: {tuple(map(lambda x: x.ndim, masks))}')
+ assert all(
+ map(lambda x: x.ndim == masks[0].ndim, masks)
+ ), f"masks must have same rank: {tuple(map(lambda x: x.ndim, masks))}"
mask, *other_masks = masks
for other_mask in other_masks:
mask = mask + other_mask
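
Usage sketch for combine_biases as defined above: None entries are filtered out and the survivors are summed, so an absent bias never contributes.

    import jax.numpy as jnp

    rel_bias = jnp.ones((1, 8, 4, 4))
    mask_bias = jnp.zeros((1, 8, 4, 4))
    combined = combine_biases(None, rel_bias, mask_bias)  # equals rel_bias
    assert combine_biases(None, None) is None
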
@@ -88,7 +90,7 @@ def combine_biases(*masks: Optional[Array]):
class DotProductAttention(nn.Module):
transpose_batch_sequence: bool = True
scale_attn_logits: bool = True
- dropout_rate: float = 0.
+ dropout_rate: float = 0.0
dtype: DType = jnp.float32
float32_logits: bool = False
"""Computes dot-product attention given query, key, and value.
@@ -105,12 +107,14 @@ class DotProductAttention(nn.Module):
"""
@nn.compact
- def __call__(self,
- query: Array,
- key: Array,
- value: Array,
- bias: Optional[Array] = None,
- deterministic: bool = False):
+ def __call__(
+ self,
+ query: Array,
+ key: Array,
+ value: Array,
+ bias: Optional[Array] = None,
+ deterministic: bool = False,
+ ):
"""
Args:
query: queries for calculating attention with shape of `[batch, q_length,
@@ -127,14 +131,15 @@ class DotProductAttention(nn.Module):
Returns:
Output of shape `[batch, length, num_heads, v_depth_per_head]`.
"""
- assert key.ndim == query.ndim == value.ndim, 'q, k, v must have same rank.'
+ assert key.ndim == query.ndim == value.ndim, "q, k, v must have same rank."
batch_dim = 1 if self.transpose_batch_sequence else 0
- assert query.shape[batch_dim] == key.shape[batch_dim] == value.shape[batch_dim], (
- 'q, k, v batch dims must match.')
+ assert (
+ query.shape[batch_dim] == key.shape[batch_dim] == value.shape[batch_dim]
+ ), "q, k, v batch dims must match."
sequence_dim = 0 if self.transpose_batch_sequence else 1
- assert key.shape[sequence_dim] == value.shape[sequence_dim], 'k, v lengths must match.'
- assert key.shape[-2] == value.shape[-2], 'k, v num_heads must match.'
- assert query.shape[-1] == key.shape[-1], 'q, k head_dim must match.'
+ assert key.shape[sequence_dim] == value.shape[sequence_dim], "k, v lengths must match."
+ assert key.shape[-2] == value.shape[-2], "k, v num_heads must match."
+ assert query.shape[-1] == key.shape[-1], "q, k head_dim must match."
if self.scale_attn_logits:
head_dim = query.shape[-1]
@@ -153,9 +158,9 @@ class DotProductAttention(nn.Module):
grouped_query = query.reshape((*query.shape[:2], h_kv, group_size, query.shape[-1]))
if self.transpose_batch_sequence:
- attn_weights = jnp.einsum('qbhgd,kbhd->bhgqk', grouped_query, key)
+ attn_weights = jnp.einsum("qbhgd,kbhd->bhgqk", grouped_query, key)
else:
- attn_weights = jnp.einsum('bqhgd,bkhd->bhgqk', grouped_query, key)
+ attn_weights = jnp.einsum("bqhgd,bkhd->bhgqk", grouped_query, key)
# reshape back to normal DPA shape for bias/softmax/dropout
b, h, g, q, k = attn_weights_with_groups_shape = attn_weights.shape
@@ -170,37 +175,37 @@ class DotProductAttention(nn.Module):
attn_weights = jax_nn.softmax(attn_weights).astype(self.dtype)
# Apply attention dropout.
- if not deterministic and self.dropout_rate > 0.:
+ if not deterministic and self.dropout_rate > 0.0:
keep_prob = 1.0 - self.dropout_rate
# T5 broadcasts along the "length" dim, but unclear which one that
# corresponds to in positional dimensions here, assuming query dim.
dropout_shape = list(attn_weights.shape)
- dropout_rng = self.make_rng('dropout')
+ dropout_rng = self.make_rng("dropout")
keep = jax_random.bernoulli(dropout_rng, keep_prob, dropout_shape)
- multiplier = (keep.astype(attn_weights.dtype) /
- jnp.asarray(keep_prob, dtype=self.dtype))
+ multiplier = keep.astype(attn_weights.dtype) / jnp.asarray(keep_prob, dtype=self.dtype)
attn_weights = attn_weights * multiplier
attn_weights = attn_weights.reshape(attn_weights_with_groups_shape)
# Take the linear combination of `value`.
if self.transpose_batch_sequence:
- return jnp.einsum('bhgqk,kbhd->qbhgd', attn_weights, value).reshape(query.shape)
+ return jnp.einsum("bhgqk,kbhd->qbhgd", attn_weights, value).reshape(query.shape)
- return jnp.einsum('bhgqk,bkhd->bqhgd', attn_weights, value).reshape(query.shape)
+ return jnp.einsum("bhgqk,bkhd->bqhgd", attn_weights, value).reshape(query.shape)
class DenseGeneral(nn.Module):
"""A linear transformation with flexible axes and FP8 support.
- Attributes:
- features: tuple with numbers of output features.
- axis: tuple with axes to apply the transformation on.
- dtype: the dtype of the computation (default: float32).
- kernel_init: initializer function for the weight matrix.
- use_bias: whether to add a bias to the output (default: False).
- bias_init: initializer function for the bias vector.
+ Attributes:
+ features: tuple with numbers of output features.
+ axis: tuple with axes to apply the transformation on.
+ dtype: the dtype of the computation (default: float32).
+ kernel_init: initializer function for the weight matrix.
+ use_bias: whether to add a bias to the output (default: False).
+ bias_init: initializer function for the bias vector.
"""
+
features: Union[Iterable[int], int]
axis: Union[Iterable[int], int] = -1
dtype: DType = jnp.float32
@@ -212,7 +217,7 @@ class DenseGeneral(nn.Module):
def __post_init__(self):
if self.kernel_init is None:
- self.kernel_init = nn.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
+ self.kernel_init = nn.initializers.variance_scaling(1.0, "fan_in", "truncated_normal")
super().__post_init__()
@nn.compact
@@ -233,21 +238,17 @@ class DenseGeneral(nn.Module):
kernel_shape = tuple(inputs.shape[ax] for ax in axis) + features
kernel_param_shape = (np.prod([inputs.shape[ax] for ax in axis]), np.prod(features))
- kernel = nn_partitioning.param_with_axes('kernel',
- self.kernel_init,
- kernel_param_shape,
- jnp.float32,
- axes=self.kernel_axes)
+ kernel = nn_partitioning.param_with_axes(
+ "kernel", self.kernel_init, kernel_param_shape, jnp.float32, axes=self.kernel_axes
+ )
kernel = jnp.asarray(kernel, self.dtype)
kernel = jnp.reshape(kernel, kernel_shape)
if self.use_bias:
- bias = nn_partitioning.param_with_axes('bias',
- self.bias_init,
- self.features,
- jnp.float32,
- axes=self.bias_axes)
+ bias = nn_partitioning.param_with_axes(
+ "bias", self.bias_init, self.features, jnp.float32, axes=self.bias_axes
+ )
bias = bias.astype(self.dtype)
else:
bias = None
@@ -264,18 +265,19 @@ class DenseGeneral(nn.Module):
class MlpBlock(nn.Module):
"""Transformer MLP / feed-forward block.
- Attributes:
- intermediate_dim: Shared dimension of hidden layers.
- activations: Type of activations for each layer. Each element is either
- 'linear', a string function name in flax.linen, or a function.
- kernel_init: Kernel function, passed to the dense layers.
- deterministic: Whether the dropout layers should be deterministic.
- intermediate_dropout_rate: Dropout rate used after the intermediate layers.
- dtype: Type for the dense layer.
- """
+ Attributes:
+ intermediate_dim: Shared dimension of hidden layers.
+ activations: Type of activations for each layer. Each element is either
+ 'linear', a string function name in flax.linen, or a function.
+ kernel_init: Kernel function, passed to the dense layers.
+ deterministic: Whether the dropout layers should be deterministic.
+ intermediate_dropout_rate: Dropout rate used after the intermediate layers.
+ dtype: Type for the dense layer.
+ """
+
transpose_batch_sequence: bool
intermediate_dim: int = 2048
- activations: Sequence[Union[str, Callable]] = ('relu',)
+ activations: Sequence[Union[str, Callable]] = ("relu",)
kernel_init: Initializer = None
intermediate_dropout_rate: float = 0.1
intermediate_dropout_dims: Sequence[int] = ()
@@ -285,7 +287,7 @@ class MlpBlock(nn.Module):
def __post_init__(self):
if self.kernel_init is None:
- self.kernel_init = nn.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
+ self.kernel_init = nn.initializers.variance_scaling(1.0, "fan_in", "truncated_normal")
super().__post_init__()
@nn.compact
@@ -296,49 +298,57 @@ class MlpBlock(nn.Module):
activations = []
if self.fuse_wi:
- dense_name = 'wi'
+ dense_name = "wi"
num_activations = len(self.activations)
- x = DenseGeneral(self.intermediate_dim * num_activations,
- dtype=self.dtype,
- kernel_init=self.kernel_init,
- kernel_axes=('embed', 'mlp'),
- use_bias=self.use_bias,
- bias_axes=('mlp'),
- name=dense_name)(inputs)
+ x = DenseGeneral(
+ self.intermediate_dim * num_activations,
+ dtype=self.dtype,
+ kernel_init=self.kernel_init,
+ kernel_axes=("embed", "mlp"),
+ use_bias=self.use_bias,
+ bias_axes="mlp",
+ name=dense_name,
+ )(inputs)
x = jnp.split(x, num_activations, axis=-1)
for idx, act_fn in enumerate(self.activations):
x_i = _convert_to_activation_function(act_fn)(x[idx])
activations.append(x_i)
else:
for idx, act_fn in enumerate(self.activations):
- dense_name = 'wi' if len(self.activations) == 1 else f'wi_{idx}'
- x = DenseGeneral(self.intermediate_dim,
- dtype=self.dtype,
- kernel_init=self.kernel_init,
- kernel_axes=('embed', 'mlp'),
- use_bias=self.use_bias,
- bias_axes=('mlp'),
- name=dense_name)(inputs)
+ dense_name = "wi" if len(self.activations) == 1 else f"wi_{idx}"
+ x = DenseGeneral(
+ self.intermediate_dim,
+ dtype=self.dtype,
+ kernel_init=self.kernel_init,
+ kernel_axes=("embed", "mlp"),
+ use_bias=self.use_bias,
+ bias_axes="mlp",
+ name=dense_name,
+ )(inputs)
x = _convert_to_activation_function(act_fn)(x)
activations.append(x)
# Take elementwise product of above intermediate activations.
x = functools.reduce(operator.mul, activations)
# Apply dropout and final dense output projection.
- x = nn.Dropout(rate=self.intermediate_dropout_rate,
- broadcast_dims=self.intermediate_dropout_dims)(
- x, deterministic=deterministic) # Broadcast along length.
+ x = nn.Dropout(
+ rate=self.intermediate_dropout_rate, broadcast_dims=self.intermediate_dropout_dims
+ )(
+ x, deterministic=deterministic
+ ) # Broadcast along length.
if self.transpose_batch_sequence:
- x = nn_partitioning.with_sharding_constraint(x, ('length', 'batch', 'mlp'))
+ x = nn_partitioning.with_sharding_constraint(x, ("length", "batch", "mlp"))
else:
- x = nn_partitioning.with_sharding_constraint(x, ('batch', 'length', 'mlp'))
- output = DenseGeneral(inputs.shape[-1],
- dtype=self.dtype,
- kernel_init=self.kernel_init,
- kernel_axes=('mlp', 'embed'),
- use_bias=self.use_bias,
- bias_axes=('embed'),
- name='wo')(x)
+ x = nn_partitioning.with_sharding_constraint(x, ("batch", "length", "mlp"))
+ output = DenseGeneral(
+ inputs.shape[-1],
+ dtype=self.dtype,
+ kernel_init=self.kernel_init,
+ kernel_axes=("mlp", "embed"),
+ use_bias=self.use_bias,
+ bias_axes="embed",
+ name="wo",
+ )(x)
return output
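
A shape sketch of the fused-wi path above: a single projection of width intermediate_dim * num_activations, followed by jnp.split into one chunk per activation.

    import jax.numpy as jnp

    intermediate_dim, num_activations = 4, 2
    x = jnp.ones((1, 3, intermediate_dim * num_activations))
    chunks = jnp.split(x, num_activations, axis=-1)
    assert all(c.shape == (1, 3, intermediate_dim) for c in chunks)
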
@@ -351,7 +361,7 @@ def apply_rotary_pos_emb_alternate(
embedding_dim = inputs.shape[-1]
half_embedding_dim = embedding_dim // 2
fraction = 2 * jnp.arange(0, half_embedding_dim) / embedding_dim
- timescale = min_timescale * (max_timescale / min_timescale)**fraction
+ timescale = min_timescale * (max_timescale / min_timescale) ** fraction
timescale = jnp.expand_dims(timescale, axis=tuple(range(inputs.ndim - 1)))
position = jnp.expand_dims(position, axis=tuple(range(2, inputs.ndim)))
sinusoid_inp = position / timescale
@@ -386,7 +396,7 @@ def apply_rotary_pos_emb_consecutive(
inputs_shifted_left,
)
fraction = jnp.repeat(fraction, 2)
- timescale = min_timescale * (max_timescale / min_timescale)**fraction
+ timescale = min_timescale * (max_timescale / min_timescale) ** fraction
position = jnp.expand_dims(position, axis=tuple(range(2, inputs.ndim)))
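
Both RoPE variants share the geometric timescale ladder reformatted above; a numeric sketch of the alternate variant's version:

    import jax.numpy as jnp

    embedding_dim = 8
    min_timescale, max_timescale = 1.0, 10000.0
    fraction = 2 * jnp.arange(0, embedding_dim // 2) / embedding_dim
    timescale = min_timescale * (max_timescale / min_timescale) ** fraction
    # timescale[0] == 1.0; entries grow geometrically toward max_timescale.
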
@@ -415,89 +425,96 @@ class MultiHeadAttention(nn.Module):
kernel_init: initializer for the kernel of the Dense layers.
float32_logits: bool, if True then compute logits in float32 to avoid
numerical issues with bfloat16.
- """
+ """
num_heads: int = 8
num_gqa_groups: int | None = None
head_dim: int = 64
transpose_batch_sequence: bool = True
dtype: DType = jnp.float32
- dropout_rate: float = 0.
+ dropout_rate: float = 0.0
kernel_init: Initializer = None
- float32_logits: bool = False # computes logits in float32 for stability.
+ float32_logits: bool = False # computes logits in float32 for stability.
scale_attn_logits: bool = False
scaled_query_init: bool = True
enable_rotary_pos_emb: bool = False
- rotary_pos_emb_group_method: str = 'consecutive'
+ rotary_pos_emb_group_method: str = "consecutive"
fuse_qkv: bool = True
use_bias: bool = False
def __post_init__(self):
if self.kernel_init is None:
- self.kernel_init = nn.initializers.variance_scaling(1.0, 'fan_in', 'normal')
+ self.kernel_init = nn.initializers.variance_scaling(1.0, "fan_in", "normal")
if self.num_gqa_groups is None:
self.num_gqa_groups = self.num_attention_heads
super().__post_init__()
@nn.compact
- def __call__(self,
- inputs_q: Array,
- inputs_kv: Array,
- mask: Optional[Array] = None,
- bias: Optional[Array] = None,
- *,
- decode: bool = False,
- deterministic: bool = False) -> Array:
+ def __call__(
+ self,
+ inputs_q: Array,
+ inputs_kv: Array,
+ mask: Optional[Array] = None,
+ bias: Optional[Array] = None,
+ *,
+ decode: bool = False,
+ deterministic: bool = False,
+ ) -> Array:
"""Applies multi-head dot product attention on the input data.
- Projects the inputs into multi-headed query, key, and value vectors,
- applies dot-product attention and project the results to an output vector.
+ Projects the inputs into multi-headed query, key, and value vectors,
+ applies dot-product attention and project the results to an output vector.
- There are two modes: decoding and non-decoding (e.g., training). The mode is
- determined by `decode` argument. For decoding, this method is called twice,
- first to initialize the cache and then for an actual decoding process. The
- two calls are differentiated by the presence of 'cached_key' in the variable
- dict. In the cache initialization stage, the cache variables are initialized
- as zeros and will be filled in the subsequent decoding process.
+ There are two modes: decoding and non-decoding (e.g., training). The mode is
+ determined by `decode` argument. For decoding, this method is called twice,
+ first to initialize the cache and then for an actual decoding process. The
+ two calls are differentiated by the presence of 'cached_key' in the variable
+ dict. In the cache initialization stage, the cache variables are initialized
+ as zeros and will be filled in the subsequent decoding process.
- In the cache initialization call, `inputs_q` has a shape [batch, length,
- q_features] and `inputs_kv`: [batch, length, kv_features]. During the
- incremental decoding stage, query, key and value all have the shape [batch,
- 1, qkv_features] corresponding to a single step.
+ In the cache initialization call, `inputs_q` has a shape [batch, length,
+ q_features] and `inputs_kv`: [batch, length, kv_features]. During the
+ incremental decoding stage, query, key and value all have the shape [batch,
+ 1, qkv_features] corresponding to a single step.
- Args:
- inputs_q: input queries of shape `[batch, q_length, q_features]`.
- inputs_kv: key/values of shape `[batch, kv_length, kv_features]`.
- mask: attention mask of shape `[batch, num_heads, q_length, kv_length]`.
- bias: attention bias of shape `[batch, num_heads, q_length, kv_length]`.
- decode: Whether to prepare and use an autoregressive cache.
- deterministic: Disables dropout if set to True.
+ Args:
+ inputs_q: input queries of shape `[batch, q_length, q_features]`.
+ inputs_kv: key/values of shape `[batch, kv_length, kv_features]`.
+ mask: attention mask of shape `[batch, num_heads, q_length, kv_length]`.
+ bias: attention bias of shape `[batch, num_heads, q_length, kv_length]`.
+ decode: Whether to prepare and use an autoregressive cache.
+ deterministic: Disables dropout if set to True.
- Returns:
- output of shape `[batch, length, q_features]`.
- """
- q_projection = functools.partial(DenseGeneral,
- axis=-1,
- features=self.num_heads * self.head_dim,
- kernel_axes=('embed', 'joined_kv'),
- use_bias=self.use_bias,
- bias_axes=('joined_kv'),
- dtype=self.dtype)
-
- kv_projection = functools.partial(DenseGeneral,
- axis=-1,
- features=self.num_gqa_groups * self.head_dim,
- kernel_axes=('embed', 'joined_kv'),
- use_bias=self.use_bias,
- bias_axes=('joined_kv'),
- dtype=self.dtype)
+ Returns:
+ output of shape `[batch, length, q_features]`.
+ """
+ q_projection = functools.partial(
+ DenseGeneral,
+ axis=-1,
+ features=self.num_heads * self.head_dim,
+ kernel_axes=("embed", "joined_kv"),
+ use_bias=self.use_bias,
+ bias_axes="joined_kv",
+ dtype=self.dtype,
+ )
+
+ kv_projection = functools.partial(
+ DenseGeneral,
+ axis=-1,
+ features=self.num_gqa_groups * self.head_dim,
+ kernel_axes=("embed", "joined_kv"),
+ use_bias=self.use_bias,
+ bias_axes="joined_kv",
+ dtype=self.dtype,
+ )
# NOTE: T5 does not explicitly rescale the attention logits by
# 1/sqrt(depth_kq)! This is folded into the initializers of the
# linear transformations, which is equivalent under Adafactor
depth_scaling = jnp.sqrt(self.head_dim).astype(self.dtype)
- query_init = lambda *args: self.kernel_init(*args) / (depth_scaling
- if self.scaled_query_init else 1.0)
+ query_init = lambda *args: self.kernel_init(*args) / (
+ depth_scaling if self.scaled_query_init else 1.0
+ )
# Project inputs_q to multi-headed q/k/v
# dimensions are then [batch, length, num_heads, head_dim]
@@ -515,39 +532,45 @@ class MultiHeadAttention(nn.Module):
return jnp.concatenate([q_kernel, k_kernel, v_kernel], axis=-1, dtype=dtype)
- is_self_attn = (inputs_q is inputs_kv)
- is_gqa = (self.num_heads != self.num_gqa_groups)
- is_qkvpack = (is_self_attn and not is_gqa)
+ is_self_attn = inputs_q is inputs_kv
+ is_gqa = self.num_heads != self.num_gqa_groups
+ is_qkvpack = is_self_attn and not is_gqa
if self.fuse_qkv:
if is_qkvpack:
- qkv_proj = DenseGeneral(axis=-1,
- features=self.num_heads * self.head_dim * 3,
- kernel_axes=('embed', 'joined_kv'),
- kernel_init=qkv_init,
- use_bias=self.use_bias,
- bias_axes=('joined_kv'),
- name='qkv',
- dtype=self.dtype)(inputs_kv)
+ qkv_proj = DenseGeneral(
+ axis=-1,
+ features=self.num_heads * self.head_dim * 3,
+ kernel_axes=("embed", "joined_kv"),
+ kernel_init=qkv_init,
+ use_bias=self.use_bias,
+ bias_axes="joined_kv",
+ name="qkv",
+ dtype=self.dtype,
+ )(inputs_kv)
query, key, value = jnp.split(
- qkv_proj, [self.num_heads * self.head_dim, self.num_heads * self.head_dim * 2],
- axis=-1)
+ qkv_proj,
+ [self.num_heads * self.head_dim, self.num_heads * self.head_dim * 2],
+ axis=-1,
+ )
else:
- query = q_projection(kernel_init=query_init, name='query')(inputs_q)
-
- kv_proj = DenseGeneral(axis=-1,
- features=self.num_gqa_groups * self.head_dim * 2,
- kernel_axes=('embed', 'joined_kv'),
- kernel_init=self.kernel_init,
- use_bias=self.use_bias,
- bias_axes=('joined_kv'),
- name='kv',
- dtype=self.dtype)(inputs_kv)
+ query = q_projection(kernel_init=query_init, name="query")(inputs_q)
+
+ kv_proj = DenseGeneral(
+ axis=-1,
+ features=self.num_gqa_groups * self.head_dim * 2,
+ kernel_axes=("embed", "joined_kv"),
+ kernel_init=self.kernel_init,
+ use_bias=self.use_bias,
+ bias_axes="joined_kv",
+ name="kv",
+ dtype=self.dtype,
+ )(inputs_kv)
key, value = jnp.split(kv_proj, [self.num_gqa_groups * self.head_dim], axis=-1)
else:
- query = q_projection(kernel_init=query_init, name='query')(inputs_q)
- key = kv_projection(kernel_init=self.kernel_init, name='key')(inputs_kv)
- value = kv_projection(kernel_init=self.kernel_init, name='value')(inputs_kv)
+ query = q_projection(kernel_init=query_init, name="query")(inputs_q)
+ key = kv_projection(kernel_init=self.kernel_init, name="key")(inputs_kv)
+ value = kv_projection(kernel_init=self.kernel_init, name="value")(inputs_kv)
if self.enable_rotary_pos_emb:
batch_dim = 1 if self.transpose_batch_sequence else 0
@@ -556,7 +579,7 @@ class MultiHeadAttention(nn.Module):
q_position = jnp.expand_dims(jnp.arange(query.shape[seq_dim]), axis=batch_dim)
k_position = jnp.expand_dims(jnp.arange(query.shape[seq_dim]), axis=batch_dim)
- if self.rotary_pos_emb_group_method == 'alternate':
+ if self.rotary_pos_emb_group_method == "alternate":
apply_rotary_pos_emb = apply_rotary_pos_emb_alternate
else:
apply_rotary_pos_emb = apply_rotary_pos_emb_consecutive
@@ -571,33 +594,40 @@ class MultiHeadAttention(nn.Module):
value = value.reshape((*value.shape[:2], self.num_gqa_groups, self.head_dim))
if self.transpose_batch_sequence:
- query = nn_partitioning.with_sharding_constraint(query,
- ('length', 'batch', 'heads', 'kv'))
- key = nn_partitioning.with_sharding_constraint(key, ('length', 'batch', 'heads', 'kv'))
- value = nn_partitioning.with_sharding_constraint(value,
- ('length', 'batch', 'heads', 'kv'))
+ query = nn_partitioning.with_sharding_constraint(
+ query, ("length", "batch", "heads", "kv")
+ )
+ key = nn_partitioning.with_sharding_constraint(key, ("length", "batch", "heads", "kv"))
+ value = nn_partitioning.with_sharding_constraint(
+ value, ("length", "batch", "heads", "kv")
+ )
else:
- query = nn_partitioning.with_sharding_constraint(query,
- ('batch', 'length', 'heads', 'kv'))
- key = nn_partitioning.with_sharding_constraint(key, ('batch', 'length', 'heads', 'kv'))
- value = nn_partitioning.with_sharding_constraint(value,
- ('batch', 'length', 'heads', 'kv'))
+ query = nn_partitioning.with_sharding_constraint(
+ query, ("batch", "length", "heads", "kv")
+ )
+ key = nn_partitioning.with_sharding_constraint(key, ("batch", "length", "heads", "kv"))
+ value = nn_partitioning.with_sharding_constraint(
+ value, ("batch", "length", "heads", "kv")
+ )
if decode:
# Detect if we're initializing by absence of existing cache data.
- is_initialized = self.has_variable('cache', 'cached_key')
+ is_initialized = self.has_variable("cache", "cached_key")
# The key and value have dimension [batch, length, num_heads, head_dim],
# but we cache them as [batch, num_heads, head_dim, length] as a TPU
# fusion optimization. This also enables the "scatter via one-hot
# broadcast" trick, which means we do a one-hot broadcast instead of a
# scatter/gather operations, resulting in a 3-4x speedup in practice.
swap_dims = lambda x: x[:-3] + tuple(x[i] for i in [-2, -1, -3])
- cached_key = self.variable('cache', 'cached_key', jnp.zeros, swap_dims(key.shape),
- key.dtype)
- cached_value = self.variable('cache', 'cached_value', jnp.zeros, swap_dims(value.shape),
- value.dtype)
- cache_index = self.variable('cache', 'cache_index',
- lambda: jnp.array(0, dtype=jnp.int32))
+ cached_key = self.variable(
+ "cache", "cached_key", jnp.zeros, swap_dims(key.shape), key.dtype
+ )
+ cached_value = self.variable(
+ "cache", "cached_value", jnp.zeros, swap_dims(value.shape), value.dtype
+ )
+ cache_index = self.variable(
+ "cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32)
+ )
if is_initialized:
batch, num_heads, head_dim, length = cached_key.value.shape
# During fast autoregressive decoding, we feed one position at a time,
@@ -606,8 +636,9 @@ class MultiHeadAttention(nn.Module):
expected_shape = (batch, 1, num_heads, head_dim)
if expected_shape != query.shape:
raise ValueError(
- 'Autoregressive cache shape error, '
- f"expected query shape {expected_shape} instead got {query.shape}.")
+ "Autoregressive cache shape error, "
+ f"expected query shape {expected_shape} instead got {query.shape}."
+ )
# Create a OHE of the current index. NOTE: the index is increased below.
cur_index = cache_index.value
@@ -638,11 +669,13 @@ class MultiHeadAttention(nn.Module):
jnp.logical_not(mask),
jnp.broadcast_to(
jnp.arange(length) <= cur_index,
- # (1, 1, length) represent (head dim, query length, key length)
- # query length is 1 because during decoding we deal with one
- # index.
- # The same mask is applied to all batch elements and heads.
- (batch, 1, 1, length)))
+ # (1, 1, length) represent (head dim, query length, key length)
+ # query length is 1 because during decoding we deal with one
+ # index.
+ # The same mask is applied to all batch elements and heads.
+ (batch, 1, 1, length),
+ ),
+ )
# Grab the correct relative attention bias during decoding. This is
# only required during single step decoding.
@@ -650,15 +683,18 @@ class MultiHeadAttention(nn.Module):
# The bias is a full attention matrix, but during decoding we only
# have to take a slice of it.
# This is equivalent to bias[..., cur_index:cur_index+1, :].
- bias = dynamic_vector_slice_in_dim(jnp.squeeze(bias, axis=0),
- jnp.reshape(cur_index, (-1)), 1, -2)
+ bias = dynamic_vector_slice_in_dim(
+ jnp.squeeze(bias, axis=0), jnp.reshape(cur_index, (-1)), 1, -2
+ )
# Convert the boolean attention mask to an attention bias.
if mask is not None:
# attention mask in the form of attention bias
- attention_bias = lax.select(mask > 0,
- jnp.full(mask.shape, 0.).astype(self.dtype),
- jnp.full(mask.shape, -1e10).astype(self.dtype))
+ attention_bias = lax.select(
+ mask > 0,
+ jnp.full(mask.shape, 0.0).astype(self.dtype),
+ jnp.full(mask.shape, -1e10).astype(self.dtype),
+ )
else:
attention_bias = None
@@ -667,41 +703,41 @@ class MultiHeadAttention(nn.Module):
attention_bias = combine_biases(attention_bias, bias)
# Apply attention.
- x = DotProductAttention(transpose_batch_sequence=self.transpose_batch_sequence,
- scale_attn_logits=self.scale_attn_logits,
- dropout_rate=self.dropout_rate,
- dtype=self.dtype,
- float32_logits=self.float32_logits)(query,
- key,
- value,
- bias=attention_bias,
- deterministic=deterministic)
+ x = DotProductAttention(
+ transpose_batch_sequence=self.transpose_batch_sequence,
+ scale_attn_logits=self.scale_attn_logits,
+ dropout_rate=self.dropout_rate,
+ dtype=self.dtype,
+ float32_logits=self.float32_logits,
+ )(query, key, value, bias=attention_bias, deterministic=deterministic)
x = x.reshape((x.shape[0], x.shape[1], x.shape[2] * x.shape[3]))
if self.transpose_batch_sequence:
- x = nn_partitioning.with_sharding_constraint(x, ('length', 'batch', 'joined_kv'))
+ x = nn_partitioning.with_sharding_constraint(x, ("length", "batch", "joined_kv"))
else:
- x = nn_partitioning.with_sharding_constraint(x, ('batch', 'length', 'joined_kv'))
+ x = nn_partitioning.with_sharding_constraint(x, ("batch", "length", "joined_kv"))
# Back to the original inputs dimensions.
out = DenseGeneral(
- features=inputs_q.shape[-1], # output dim is set to the input dim.
+ features=inputs_q.shape[-1], # output dim is set to the input dim.
axis=-1,
kernel_init=self.kernel_init,
- kernel_axes=('joined_kv', 'embed'),
+ kernel_axes=("joined_kv", "embed"),
use_bias=self.use_bias,
- bias_axes=('embed'),
+ bias_axes="embed",
dtype=self.dtype,
- name='out')(x)
+ name="out",
+ )(x)
return out
class LayerNorm(nn.Module):
"""T5 Layer normalization operating on the last axis of the input data."""
+
epsilon: float = 1e-6
dtype: Any = jnp.float32
- layernorm_type: str = 'layernorm'
+ layernorm_type: str = "layernorm"
zero_centered_gamma: bool = False
scale_init: Initializer = None
bias_init: Initializer = nn.initializers.zeros
@@ -721,29 +757,27 @@ class LayerNorm(nn.Module):
x = jnp.asarray(x, jnp.float32)
features = x.shape[-1]
- scale = nn_partitioning.param_with_axes('scale',
- self.scale_init, (features,),
- jnp.float32,
- axes=('embed',))
+ scale = nn_partitioning.param_with_axes(
+ "scale", self.scale_init, (features,), jnp.float32, axes=("embed",)
+ )
scale = jnp.asarray(scale, self.dtype)
- if self.layernorm_type == 'layernorm':
+ if self.layernorm_type == "layernorm":
mean = jnp.mean(x, axis=-1, keepdims=True)
var = jnp.mean(jnp.square(x - mean), axis=-1, keepdims=True)
y = (x - mean) * lax.rsqrt(var + self.epsilon)
- bias = nn_partitioning.param_with_axes('ln_bias',
- self.bias_init, (features,),
- jnp.float32,
- axes=('embed',))
+ bias = nn_partitioning.param_with_axes(
+ "ln_bias", self.bias_init, (features,), jnp.float32, axes=("embed",)
+ )
bias = jnp.asarray(bias, self.dtype)
if not self.zero_centered_gamma:
z = y * scale + bias
else:
- z = y * (scale + 1.) + bias
+ z = y * (scale + 1.0) + bias
else:
- assert self.layernorm_type == 'rmsnorm'
+ assert self.layernorm_type == "rmsnorm"
assert not self.zero_centered_gamma
mean2 = jnp.mean(lax.square(x), axis=-1, keepdims=True)
y = x * lax.rsqrt(mean2 + self.epsilon)
@@ -755,16 +789,17 @@ class LayerNorm(nn.Module):
class RelativePositionBiases(nn.Module):
"""Adds T5-style relative positional embeddings to the attention logits.
- Attributes:
- num_buckets: Number of buckets to bucket distances between key and query
- positions into.
- max_distance: Maximum distance before everything is lumped into the last
- distance bucket.
- num_heads: Number of heads in the attention layer. Each head will get a
- different relative position weighting.
- dtype: Type of arrays through this module.
- embedding_init: initializer for relative embedding table.
- """
+ Attributes:
+ num_buckets: Number of buckets to bucket distances between key and query
+ positions into.
+ max_distance: Maximum distance before everything is lumped into the last
+ distance bucket.
+ num_heads: Number of heads in the attention layer. Each head will get a
+ different relative position weighting.
+ dtype: Type of arrays through this module.
+ embedding_init: initializer for relative embedding table.
+ """
+
num_buckets: int
max_distance: int
num_heads: int
@@ -772,33 +807,32 @@ class RelativePositionBiases(nn.Module):
embedding_init: Callable[..., Array] = nn.linear.default_embed_init
@staticmethod
- def _relative_position_bucket(relative_position,
- bidirectional=True,
- num_buckets=32,
- max_distance=128):
+ def _relative_position_bucket(
+ relative_position, bidirectional=True, num_buckets=32, max_distance=128
+ ):
"""Translate relative position to a bucket number for relative attention.
- The relative position is defined as memory_position - query_position, i.e.
- the distance in tokens from the attending position to the attended-to
- position. If bidirectional=False, then positive relative positions are
- invalid.
- We use smaller buckets for small absolute relative_position and larger
- buckets for larger absolute relative_positions. All relative
- positions >=max_distance map to the same bucket. All relative
- positions <=-max_distance map to the same bucket. This should allow for
- more graceful generalization to longer sequences than the model has been
- trained on.
+ The relative position is defined as memory_position - query_position, i.e.
+ the distance in tokens from the attending position to the attended-to
+ position. If bidirectional=False, then positive relative positions are
+ invalid.
+ We use smaller buckets for small absolute relative_position and larger
+ buckets for larger absolute relative_positions. All relative
+ positions >=max_distance map to the same bucket. All relative
+ positions <=-max_distance map to the same bucket. This should allow for
+ more graceful generalization to longer sequences than the model has been
+ trained on.
- Args:
- relative_position: an int32 array
- bidirectional: a boolean - whether the attention is bidirectional
- num_buckets: an integer
- max_distance: an integer
+ Args:
+ relative_position: an int32 array
+ bidirectional: a boolean - whether the attention is bidirectional
+ num_buckets: an integer
+ max_distance: an integer
- Returns:
- a Tensor with the same shape as relative_position, containing int32
- values in the range [0, num_buckets)
- """
+ Returns:
+ a Tensor with the same shape as relative_position, containing int32
+ values in the range [0, num_buckets)
+ """
ret = 0
n = -relative_position
if bidirectional:
@@ -811,8 +845,10 @@ class RelativePositionBiases(nn.Module):
max_exact = num_buckets // 2
is_small = n < max_exact
val_if_large = max_exact + (
- np.log(n.astype(np.float32) / max_exact + np.finfo(np.float32).eps) /
- np.log(max_distance / max_exact) * (num_buckets - max_exact)).astype(np.int32)
+ np.log(n.astype(np.float32) / max_exact + np.finfo(np.float32).eps)
+ / np.log(max_distance / max_exact)
+ * (num_buckets - max_exact)
+ ).astype(np.int32)
val_if_large = np.minimum(val_if_large, num_buckets - 1)
ret += np.where(is_small, n, val_if_large)
return ret
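
A tiny numeric check of the bucketing above with its defaults (bidirectional=True, num_buckets=32, max_distance=128): small offsets get exact buckets, large ones share log-spaced buckets, and everything stays in [0, 32).

    import numpy as np

    rel = np.array([0, 1, -1, 64, -512], dtype=np.int32)
    buckets = RelativePositionBiases._relative_position_bucket(rel)
    assert buckets.min() >= 0 and buckets.max() < 32
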
@@ -821,27 +857,31 @@ class RelativePositionBiases(nn.Module):
def __call__(self, qlen, klen, bidirectional=True):
"""Produce relative position embedding attention biases.
- Args:
- qlen: attention query length.
- klen: attention key length.
- bidirectional: whether to allow positive memory-query relative position
- embeddings.
+ Args:
+ qlen: attention query length.
+ klen: attention key length.
+ bidirectional: whether to allow positive memory-query relative position
+ embeddings.
- Returns:
- output: `(1, len, q_len, k_len)` attention bias
- """
+ Returns:
+ output: `(1, len, q_len, k_len)` attention bias
+ """
context_position = np.arange(qlen, dtype=jnp.int32)[:, None]
memory_position = np.arange(klen, dtype=jnp.int32)[None, :]
- relative_position = memory_position - context_position # shape (qlen, klen)
- rp_bucket = self._relative_position_bucket(relative_position,
- bidirectional=bidirectional,
- num_buckets=self.num_buckets,
- max_distance=self.max_distance)
+ relative_position = memory_position - context_position # shape (qlen, klen)
+ rp_bucket = self._relative_position_bucket(
+ relative_position,
+ bidirectional=bidirectional,
+ num_buckets=self.num_buckets,
+ max_distance=self.max_distance,
+ )
relative_attention_bias = nn_partitioning.param_with_axes(
- 'rel_embedding',
- self.embedding_init, (self.num_heads, self.num_buckets),
+ "rel_embedding",
+ self.embedding_init,
+ (self.num_heads, self.num_buckets),
jnp.float32,
- axes=('heads', 'relpos_buckets'))
+ axes=("heads", "relpos_buckets"),
+ )
relative_attention_bias = jnp.asarray(relative_attention_bias, self.dtype)
# Instead of using a slow gather, we create a leading-dimension one-hot
@@ -855,9 +895,8 @@ class RelativePositionBiases(nn.Module):
values = lax.dot_general(
relative_attention_bias,
rp_bucket_one_hot,
- (
- ((1,), (0,)), # rhs, lhs contracting dims
- ((), ()))) # no batched dims
+ (((1,), (0,)), ((), ())), # rhs, lhs contracting dims
+ ) # no batched dims
# Add a singleton batch dimension.
# --> shape (1, num_heads, qlen, klen)
return values[jnp.newaxis, ...]
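
The dot_general above implements the one-hot gather trick: contracting the (heads, buckets) table with a one-hot (buckets, qlen, klen) tensor selects table[:, rp_bucket] without a gather. A small equivalence check:

    import jax.numpy as jnp
    from jax import lax
    from jax import nn as jax_nn

    table = jnp.arange(12.0).reshape(3, 4)          # (heads, buckets)
    rp_bucket = jnp.array([[0, 2], [1, 3]])         # (qlen, klen)
    one_hot = jax_nn.one_hot(rp_bucket, 4, axis=0)  # (buckets, qlen, klen)
    values = lax.dot_general(table, one_hot, (((1,), (0,)), ((), ())))
    assert (values == table[:, rp_bucket]).all()
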
@@ -865,6 +904,7 @@ class RelativePositionBiases(nn.Module):
class EncoderLayer(nn.Module):
"""Transformer encoder layer."""
+
enable_relative_embedding: bool = True
relative_embedding: nn.Module = None
num_attention_heads: int = 8
@@ -880,17 +920,17 @@ class EncoderLayer(nn.Module):
scale_attn_logits: bool = False
scaled_query_init: bool = True
mlp_dim: int = 2048
- mlp_activations: Sequence[str] = ('relu',)
+ mlp_activations: Sequence[str] = ("relu",)
use_bias: bool = False
dtype: Any = jnp.float32
apply_residual_connection_post_layernorm: bool = False
- layernorm_type: str = 'layernorm'
+ layernorm_type: str = "layernorm"
layernorm_epsilon: float = 1e-6
zero_centered_gamma: bool = False
output_layernorm: bool = False
drop_path: float = 0.0
enable_rotary_pos_emb: bool = False
- rotary_pos_emb_group_method: str = 'consecutive'
+ rotary_pos_emb_group_method: str = "consecutive"
fuse_qkv_params: bool = True
fuse_mlp_wi: bool = True
self_attn_bias_type: Any = None
@@ -903,20 +943,21 @@ class EncoderLayer(nn.Module):
@nn.compact
def __call__(self, inputs, encoder_mask=None, deterministic=False):
- del self.self_attn_mask_type # dummy, just align to TE's impl
+ del self.self_attn_mask_type # dummy, just align to TE's impl
# Relative position embedding as attention biases.
sequence_dim = 0 if self.transpose_batch_sequence else 1
batch_dim = 1 - sequence_dim
if self.enable_relative_embedding:
if self.relative_embedding is None:
- rel_emb = RelativePositionBiases(num_buckets=32,
- max_distance=128,
- num_heads=self.num_attention_heads,
- dtype=self.dtype,
- embedding_init=nn.initializers.variance_scaling(
- 1.0, 'fan_avg', 'uniform'),
- name='relpos_bias')
+ rel_emb = RelativePositionBiases(
+ num_buckets=32,
+ max_distance=128,
+ num_heads=self.num_attention_heads,
+ dtype=self.dtype,
+ embedding_init=nn.initializers.variance_scaling(1.0, "fan_avg", "uniform"),
+ name="relpos_bias",
+ )
else:
rel_emb = self.relative_embedding
encoder_bias = rel_emb(inputs.shape[sequence_dim], inputs.shape[sequence_dim], True)
@@ -928,11 +969,13 @@ class EncoderLayer(nn.Module):
if not self.output_layernorm:
# Attention block.
- x = LayerNorm(layernorm_type=self.layernorm_type,
- epsilon=self.layernorm_epsilon,
- zero_centered_gamma=self.zero_centered_gamma,
- dtype=self.dtype,
- name="pre_attention_layer_norm")(inputs)
+ x = LayerNorm(
+ layernorm_type=self.layernorm_type,
+ epsilon=self.layernorm_epsilon,
+ zero_centered_gamma=self.zero_centered_gamma,
+ dtype=self.dtype,
+ name="pre_attention_layer_norm",
+ )(inputs)
if self.apply_residual_connection_post_layernorm:
residual = x
@@ -940,39 +983,41 @@ class EncoderLayer(nn.Module):
x = inputs
# [batch, length, emb_dim] -> [batch, length, emb_dim]
- x = MultiHeadAttention(num_heads=self.num_attention_heads,
- num_gqa_groups=self.num_gqa_groups,
- dtype=self.dtype,
- head_dim=self.head_dim,
- transpose_batch_sequence=self.transpose_batch_sequence,
- dropout_rate=self.attention_dropout,
- float32_logits=self.float32_attention_logits,
- scale_attn_logits=self.scale_attn_logits,
- scaled_query_init=self.scaled_query_init,
- fuse_qkv=self.fuse_qkv_params,
- enable_rotary_pos_emb=self.enable_rotary_pos_emb,
- rotary_pos_emb_group_method=self.rotary_pos_emb_group_method,
- use_bias=self.use_bias,
- name='attention')(x,
- x,
- encoder_mask,
- encoder_bias,
- deterministic=deterministic)
- x = nn.Dropout(rate=self.hidden_dropout,
- broadcast_dims=self.hidden_dropout_dims)(x, deterministic=deterministic)
+ x = MultiHeadAttention(
+ num_heads=self.num_attention_heads,
+ num_gqa_groups=self.num_gqa_groups,
+ dtype=self.dtype,
+ head_dim=self.head_dim,
+ transpose_batch_sequence=self.transpose_batch_sequence,
+ dropout_rate=self.attention_dropout,
+ float32_logits=self.float32_attention_logits,
+ scale_attn_logits=self.scale_attn_logits,
+ scaled_query_init=self.scaled_query_init,
+ fuse_qkv=self.fuse_qkv_params,
+ enable_rotary_pos_emb=self.enable_rotary_pos_emb,
+ rotary_pos_emb_group_method=self.rotary_pos_emb_group_method,
+ use_bias=self.use_bias,
+ name="attention",
+ )(x, x, encoder_mask, encoder_bias, deterministic=deterministic)
+ x = nn.Dropout(rate=self.hidden_dropout, broadcast_dims=self.hidden_dropout_dims)(
+ x, deterministic=deterministic
+ )
if self.drop_path > 0.0:
drop_path_shape = _generate_drop_path_shape(x.shape, batch_dim)
- x = nn.Dropout(rate=self.drop_path,
- broadcast_dims=drop_path_shape)(x, deterministic=deterministic)
+ x = nn.Dropout(rate=self.drop_path, broadcast_dims=drop_path_shape)(
+ x, deterministic=deterministic
+ )
x = x + residual
# MLP block.
residual = x
- y = LayerNorm(layernorm_type=self.layernorm_type,
- epsilon=self.layernorm_epsilon,
- zero_centered_gamma=self.zero_centered_gamma,
- dtype=self.dtype,
- name='pre_mlp_layer_norm')(x)
+ y = LayerNorm(
+ layernorm_type=self.layernorm_type,
+ epsilon=self.layernorm_epsilon,
+ zero_centered_gamma=self.zero_centered_gamma,
+ dtype=self.dtype,
+ name="pre_mlp_layer_norm",
+ )(x)
if self.apply_residual_connection_post_layernorm:
residual = y
@@ -987,27 +1032,32 @@ class EncoderLayer(nn.Module):
use_bias=self.use_bias,
dtype=self.dtype,
fuse_wi=self.fuse_mlp_wi,
- name='mlp',
+ name="mlp",
)(y, deterministic=deterministic)
- y = nn.Dropout(rate=self.hidden_dropout,
- broadcast_dims=self.hidden_dropout_dims)(y, deterministic=deterministic)
+ y = nn.Dropout(rate=self.hidden_dropout, broadcast_dims=self.hidden_dropout_dims)(
+ y, deterministic=deterministic
+ )
if self.drop_path > 0.0:
drop_path_shape = _generate_drop_path_shape(y.shape, batch_dim)
- y = nn.Dropout(rate=self.drop_path,
- broadcast_dims=drop_path_shape)(y, deterministic=deterministic)
+ y = nn.Dropout(rate=self.drop_path, broadcast_dims=drop_path_shape)(
+ y, deterministic=deterministic
+ )
y = y + residual
if self.output_layernorm:
- y = LayerNorm(layernorm_type=self.layernorm_type,
- epsilon=self.layernorm_epsilon,
- zero_centered_gamma=self.zero_centered_gamma,
- dtype=self.dtype,
- name="output_layernorm")(y)
+ y = LayerNorm(
+ layernorm_type=self.layernorm_type,
+ epsilon=self.layernorm_epsilon,
+ zero_centered_gamma=self.zero_centered_gamma,
+ dtype=self.dtype,
+ name="output_layernorm",
+ )(y)
return y
class DecoderLayer(nn.Module):
"""Transformer decoder layer that attends to the encoder."""
+
enable_relative_embedding: bool = True
relative_embedding: nn.Module = None
num_attention_heads: int = 8
@@ -1023,17 +1073,17 @@ class DecoderLayer(nn.Module):
scale_attn_logits: bool = False
scaled_query_init: bool = True
mlp_dim: int = 2048
- mlp_activations: Sequence[str] = ('relu',)
+ mlp_activations: Sequence[str] = ("relu",)
use_bias: bool = False
dtype: Any = jnp.float32
apply_residual_connection_post_layernorm: bool = False
output_layernorm: bool = False
- layernorm_type: str = 'layernorm'
+ layernorm_type: str = "layernorm"
layernorm_epsilon: float = 1e-6
zero_centered_gamma: bool = False
drop_path: float = 0.0
enable_rotary_pos_emb: bool = False
- rotary_pos_emb_group_method: str = 'consecutive'
+ rotary_pos_emb_group_method: str = "consecutive"
fuse_qkv_params: bool = True
fuse_mlp_wi: bool = True
self_attn_bias_type: Any = None
@@ -1045,15 +1095,17 @@ class DecoderLayer(nn.Module):
super().__post_init__()
@nn.compact
- def __call__(self,
- inputs,
- encoded,
- decoder_mask=None,
- encoder_decoder_mask=None,
- deterministic=False,
- decode=False,
- max_decode_length=None):
- del self.self_attn_mask_type # dummy, just align to TE's impl
+ def __call__(
+ self,
+ inputs,
+ encoded,
+ decoder_mask=None,
+ encoder_decoder_mask=None,
+ deterministic=False,
+ decode=False,
+ max_decode_length=None,
+ ):
+ del self.self_attn_mask_type # dummy, just align to TE's impl
# Relative position embedding as attention biases.
sequence_dim = 0 if self.transpose_batch_sequence else 1
batch_dim = 1 - sequence_dim
@@ -1061,13 +1113,14 @@ class DecoderLayer(nn.Module):
if self.enable_relative_embedding:
l = max_decode_length if decode and max_decode_length else inputs.shape[sequence_dim]
if self.relative_embedding is None:
- rel_emb = RelativePositionBiases(num_buckets=32,
- max_distance=128,
- num_heads=self.num_attention_heads,
- dtype=self.dtype,
- embedding_init=nn.initializers.variance_scaling(
- 1.0, 'fan_avg', 'uniform'),
- name='relpos_bias')
+ rel_emb = RelativePositionBiases(
+ num_buckets=32,
+ max_distance=128,
+ num_heads=self.num_attention_heads,
+ dtype=self.dtype,
+ embedding_init=nn.initializers.variance_scaling(1.0, "fan_avg", "uniform"),
+ name="relpos_bias",
+ )
else:
rel_emb = self.relative_embedding
decoder_bias = rel_emb(l, l, False)
@@ -1079,11 +1132,13 @@ class DecoderLayer(nn.Module):
if not self.output_layernorm:
# Attention block.
- x = LayerNorm(layernorm_type=self.layernorm_type,
- epsilon=self.layernorm_epsilon,
- zero_centered_gamma=self.zero_centered_gamma,
- dtype=self.dtype,
- name="pre_self_attention_layer_norm")(inputs)
+ x = LayerNorm(
+ layernorm_type=self.layernorm_type,
+ epsilon=self.layernorm_epsilon,
+ zero_centered_gamma=self.zero_centered_gamma,
+ dtype=self.dtype,
+ name="pre_self_attention_layer_norm",
+ )(inputs)
if self.apply_residual_connection_post_layernorm:
residual = x
@@ -1091,71 +1146,74 @@ class DecoderLayer(nn.Module):
x = inputs
# Self-attention block
- x = MultiHeadAttention(num_heads=self.num_attention_heads,
- num_gqa_groups=self.num_gqa_groups,
- dtype=self.dtype,
- head_dim=self.head_dim,
- transpose_batch_sequence=self.transpose_batch_sequence,
- dropout_rate=self.attention_dropout,
- float32_logits=self.float32_attention_logits,
- scale_attn_logits=self.scale_attn_logits,
- scaled_query_init=self.scaled_query_init,
- enable_rotary_pos_emb=self.enable_rotary_pos_emb,
- rotary_pos_emb_group_method=self.rotary_pos_emb_group_method,
- fuse_qkv=self.fuse_qkv_params,
- use_bias=self.use_bias,
- name='self_attention')(x,
- x,
- decoder_mask,
- decoder_bias,
- deterministic=deterministic,
- decode=decode)
- x = nn.Dropout(rate=self.hidden_dropout,
- broadcast_dims=self.hidden_dropout_dims)(x, deterministic=deterministic)
+ x = MultiHeadAttention(
+ num_heads=self.num_attention_heads,
+ num_gqa_groups=self.num_gqa_groups,
+ dtype=self.dtype,
+ head_dim=self.head_dim,
+ transpose_batch_sequence=self.transpose_batch_sequence,
+ dropout_rate=self.attention_dropout,
+ float32_logits=self.float32_attention_logits,
+ scale_attn_logits=self.scale_attn_logits,
+ scaled_query_init=self.scaled_query_init,
+ enable_rotary_pos_emb=self.enable_rotary_pos_emb,
+ rotary_pos_emb_group_method=self.rotary_pos_emb_group_method,
+ fuse_qkv=self.fuse_qkv_params,
+ use_bias=self.use_bias,
+ name="self_attention",
+ )(x, x, decoder_mask, decoder_bias, deterministic=deterministic, decode=decode)
+ x = nn.Dropout(rate=self.hidden_dropout, broadcast_dims=self.hidden_dropout_dims)(
+ x, deterministic=deterministic
+ )
if self.drop_path > 0.0:
drop_path_shape = _generate_drop_path_shape(x.shape, batch_dim)
- x = nn.Dropout(rate=self.drop_path,
- broadcast_dims=drop_path_shape)(x, deterministic=deterministic)
+ x = nn.Dropout(rate=self.drop_path, broadcast_dims=drop_path_shape)(
+ x, deterministic=deterministic
+ )
x = x + residual
# Encoder-Decoder block.
residual = x
- y = LayerNorm(layernorm_type=self.layernorm_type,
- epsilon=self.layernorm_epsilon,
- zero_centered_gamma=self.zero_centered_gamma,
- dtype=self.dtype,
- name='pre_cross_attention_layer_norm')(x)
+ y = LayerNorm(
+ layernorm_type=self.layernorm_type,
+ epsilon=self.layernorm_epsilon,
+ zero_centered_gamma=self.zero_centered_gamma,
+ dtype=self.dtype,
+ name="pre_cross_attention_layer_norm",
+ )(x)
if self.apply_residual_connection_post_layernorm:
residual = y
- y = MultiHeadAttention(num_heads=self.num_attention_heads,
- num_gqa_groups=self.num_gqa_groups,
- dtype=self.dtype,
- head_dim=self.head_dim,
- transpose_batch_sequence=self.transpose_batch_sequence,
- dropout_rate=self.attention_dropout,
- float32_logits=self.float32_attention_logits,
- scale_attn_logits=self.scale_attn_logits,
- scaled_query_init=self.scaled_query_init,
- enable_rotary_pos_emb=self.enable_rotary_pos_emb,
- rotary_pos_emb_group_method=self.rotary_pos_emb_group_method,
- fuse_qkv=self.fuse_qkv_params,
- use_bias=self.use_bias,
- name='encoder_decoder_attention')(y,
- encoded,
- encoder_decoder_mask,
- deterministic=deterministic)
- y = nn.Dropout(rate=self.hidden_dropout,
- broadcast_dims=self.hidden_dropout_dims)(y, deterministic=deterministic)
+ y = MultiHeadAttention(
+ num_heads=self.num_attention_heads,
+ num_gqa_groups=self.num_gqa_groups,
+ dtype=self.dtype,
+ head_dim=self.head_dim,
+ transpose_batch_sequence=self.transpose_batch_sequence,
+ dropout_rate=self.attention_dropout,
+ float32_logits=self.float32_attention_logits,
+ scale_attn_logits=self.scale_attn_logits,
+ scaled_query_init=self.scaled_query_init,
+ enable_rotary_pos_emb=self.enable_rotary_pos_emb,
+ rotary_pos_emb_group_method=self.rotary_pos_emb_group_method,
+ fuse_qkv=self.fuse_qkv_params,
+ use_bias=self.use_bias,
+ name="encoder_decoder_attention",
+ )(y, encoded, encoder_decoder_mask, deterministic=deterministic)
+ y = nn.Dropout(rate=self.hidden_dropout, broadcast_dims=self.hidden_dropout_dims)(
+ y, deterministic=deterministic
+ )
y = y + residual
# MLP block.
residual = y
- z = LayerNorm(layernorm_type=self.layernorm_type,
- epsilon=self.layernorm_epsilon,
- zero_centered_gamma=self.zero_centered_gamma,
- dtype=self.dtype,
- name='pre_mlp_layer_norm')(y)
+ z = LayerNorm(
+ layernorm_type=self.layernorm_type,
+ epsilon=self.layernorm_epsilon,
+ zero_centered_gamma=self.zero_centered_gamma,
+ dtype=self.dtype,
+ name="pre_mlp_layer_norm",
+ )(y)
if self.apply_residual_connection_post_layernorm:
residual = z
z = MlpBlock(
@@ -1167,22 +1225,26 @@ class DecoderLayer(nn.Module):
use_bias=self.use_bias,
dtype=self.dtype,
fuse_wi=self.fuse_mlp_wi,
- name='mlp',
+ name="mlp",
)(z, deterministic=deterministic)
- z = nn.Dropout(rate=self.hidden_dropout,
- broadcast_dims=self.hidden_dropout_dims)(z, deterministic=deterministic)
+ z = nn.Dropout(rate=self.hidden_dropout, broadcast_dims=self.hidden_dropout_dims)(
+ z, deterministic=deterministic
+ )
if self.drop_path > 0.0:
drop_path_shape = _generate_drop_path_shape(z.shape, batch_dim)
- z = nn.Dropout(rate=self.drop_path,
- broadcast_dims=drop_path_shape)(z, deterministic=deterministic)
+ z = nn.Dropout(rate=self.drop_path, broadcast_dims=drop_path_shape)(
+ z, deterministic=deterministic
+ )
z = z + residual
if self.output_layernorm:
- z = LayerNorm(layernorm_type=self.layernorm_type,
- epsilon=self.layernorm_epsilon,
- zero_centered_gamma=self.zero_centered_gamma,
- dtype=self.dtype,
- name="output_layernorm")(z)
+ z = LayerNorm(
+ layernorm_type=self.layernorm_type,
+ epsilon=self.layernorm_epsilon,
+ zero_centered_gamma=self.zero_centered_gamma,
+ dtype=self.dtype,
+ name="output_layernorm",
+ )(z)
return z
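
Both layers implement drop path (stochastic depth) by reusing nn.Dropout with broadcast_dims: a single Bernoulli draw per sample is broadcast over what _generate_drop_path_shape presumably marks as every non-batch dimension, so an entire residual branch is zeroed (and survivors rescaled) per example. A hedged sketch of the same idea, independent of Flax:

import jax
import jax.numpy as jnp

def drop_path(x, rate, key, batch_dim=0):
    shape = [1] * x.ndim
    shape[batch_dim] = x.shape[batch_dim]          # one keep/drop decision per sample
    keep = jax.random.bernoulli(key, 1.0 - rate, shape)
    return jnp.where(keep, x / (1.0 - rate), 0.0)  # inverted-dropout rescaling

x = jnp.ones((4, 8, 16))
y = drop_path(x, rate=0.5, key=jax.random.PRNGKey(0))  # each row is all-0 or all-2
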
@@ -1261,15 +1323,18 @@ def assert_tree_like_allclose(expected, actual, rtol=1e-05, atol=1e-08):
flatten_expected, _ = jax.tree_util.tree_flatten_with_path(expected)
flatten_actual, _ = jax.tree_util.tree_flatten_with_path(actual)
- for (expected_path, expected_value), (actual_path,
- actual_value) in zip(flatten_expected, flatten_actual):
+ for (expected_path, expected_value), (actual_path, actual_value) in zip(
+ flatten_expected, flatten_actual
+ ):
assert expected_path == actual_path
key_str = jax.tree_util.keystr(expected_path)
- assert_allclose(expected_value,
- actual_value,
- rtol=rtol,
- atol=atol,
- err_msg=f'Value of expected{key_str} and actual{key_str} is not close')
+ assert_allclose(
+ expected_value,
+ actual_value,
+ rtol=rtol,
+ atol=atol,
+ err_msg=f"Value of expected{key_str} and actual{key_str} is not close",
+ )
def dtype_tols(
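
assert_tree_like_allclose walks two pytrees in lockstep via tree_flatten_with_path, so a mismatch is reported together with the offending key path. A minimal usage sketch with made-up trees:

import jax
import numpy as np

expected = {"w": np.ones((2, 2)), "b": np.zeros(2)}
actual = {"w": np.ones((2, 2)) + 1e-7, "b": np.zeros(2)}
pairs = zip(
    jax.tree_util.tree_flatten_with_path(expected)[0],
    jax.tree_util.tree_flatten_with_path(actual)[0],
)
for (path_e, leaf_e), (path_a, leaf_a) in pairs:
    assert path_e == path_a
    np.testing.assert_allclose(
        leaf_e, leaf_a, rtol=1e-5, atol=1e-8,
        err_msg=f"mismatch at {jax.tree_util.keystr(path_e)}",
    )
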
@@ -1323,7 +1388,7 @@ def dtype_tols(
)
-def sync_params_values(dst, src, transformations, sep='/'):
+def sync_params_values(dst, src, transformations, sep="/"):
"""
     This function will reconstruct a tree with dst's tree_def/shape and src's values.
transformations is a map that records the key mappings between dst and src.
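
A hedged sketch of the key-mapping idea the docstring describes; the flat-dict representation and all names here are illustrative, not the actual sync_params_values implementation:

def remap_flat_params(dst, src, transformations, sep="/"):
    # Keep dst's keys/structure but take every value from src, renaming
    # dst keys into src keys via the transformations map where needed.
    return {key: src[transformations.get(key, key)] for key in dst}

dst = {"encoder/attention/kernel": None}
src = {"enc/attn/w": [[1.0]]}
print(remap_flat_params(dst, src, {"encoder/attention/kernel": "enc/attn/w"}))
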
diff --git a/tests/paddle/dist_launcher.py b/tests/paddle/dist_launcher.py
index eb8e8cb6627812a7e7ab2638f3f8f9b5f2de5bda..8c417b19303b6deb43bd7d6a4bc2a7ba80af962a 100644
--- a/tests/paddle/dist_launcher.py
+++ b/tests/paddle/dist_launcher.py
@@ -21,15 +21,15 @@ from paddle.distributed.utils.launch_utils import (
watch_local_trainers,
)
-__all__ = ['TestDistributed']
+__all__ = ["TestDistributed"]
def get_cluster_from_args(selected_gpus):
"""Get node information from selected GPUs"""
- cluster_node_ips = '127.0.0.1'
- node_ip = '127.0.0.1'
+ cluster_node_ips = "127.0.0.1"
+ node_ip = "127.0.0.1"
- node_ips = [x.strip() for x in cluster_node_ips.split(',')]
+ node_ips = [x.strip() for x in cluster_node_ips.split(",")]
node_ips.index(node_ip)
@@ -47,7 +47,7 @@ def get_cluster_from_args(selected_gpus):
def get_gpus(selected_gpus):
"""Get selected GPU string"""
- selected_gpus = [x.strip() for x in selected_gpus.split(',')]
+ selected_gpus = [x.strip() for x in selected_gpus.split(",")]
return selected_gpus
@@ -86,7 +86,7 @@ def start_local_trainers(
print(f"trainer proc env:{current_env}")
- if os.getenv('WITH_COVERAGE', 'OFF') == 'ON':
+ if os.getenv("WITH_COVERAGE", "OFF") == "ON":
cmd = "python -m coverage run --branch -p " + training_script
else:
cmd = "python -u " + training_script
@@ -95,7 +95,9 @@ def start_local_trainers(
fn = None
- proc = subprocess.Popen(cmd.split(" ") + training_script_args, env=current_env) # pylint: disable=consider-using-with
+ proc = subprocess.Popen(
+ cmd.split(" ") + training_script_args, env=current_env
+ ) # pylint: disable=consider-using-with
tp = TrainerProc()
tp.proc = proc
@@ -117,10 +119,10 @@ class TestDistributed(unittest.TestCase):
allocator_strategy="auto_growth",
):
"""Run target file in subprocesses"""
- if (not core.is_compiled_with_cuda() or core.get_cuda_device_count() == 0):
+ if not core.is_compiled_with_cuda() or core.get_cuda_device_count() == 0:
return
- selected_gpus = get_gpus('0,1')
+ selected_gpus = get_gpus("0,1")
cluster = None
pod = None
diff --git a/tests/paddle/parallel_tests/amax_reduction.py b/tests/paddle/parallel_tests/amax_reduction.py
index 093eb50263e76b6e40671125ad9fc4bb5d9ca387..c4605f121e89002e1c69a1872883c062f0949b18 100644
--- a/tests/paddle/parallel_tests/amax_reduction.py
+++ b/tests/paddle/parallel_tests/amax_reduction.py
@@ -27,7 +27,7 @@ class TestAmaxReduction(unittest.TestCase):
def setUp(self):
self.data_parallel_size = 2
self.init_dist_env()
- self.global_dtype = 'bfloat16'
+ self.global_dtype = "bfloat16"
paddle.set_default_dtype(self.global_dtype)
def init_dist_env(self):
@@ -83,5 +83,5 @@ class TestAmaxReduction(unittest.TestCase):
assert_allclose_across_ranks(layer2.fp8_meta["scaling_bwd"].scale_inv)
-if __name__ == '__main__':
+if __name__ == "__main__":
unittest.main()
diff --git a/tests/paddle/parallel_tests/attention_tp.py b/tests/paddle/parallel_tests/attention_tp.py
index 31ce75f4c9429a3e5a7b87155ad193272b766234..e145f20b3901c24a1ff88b1191eb40c2ad120c9b 100644
--- a/tests/paddle/parallel_tests/attention_tp.py
+++ b/tests/paddle/parallel_tests/attention_tp.py
@@ -44,8 +44,8 @@ class TestAttentionTp(unittest.TestCase):
self.num_heads = 16
self.q_seqlen = 128
self.kv_seqlen = 128
- self.mask_type = 'padding'
- self.global_dtype = 'bfloat16'
+ self.mask_type = "padding"
+ self.global_dtype = "bfloat16"
self.rtol = 5e-3
self.atol = 5e-3
self.eps = 1e-3
@@ -56,7 +56,7 @@ class TestAttentionTp(unittest.TestCase):
inp, mask = inp_list
if sequence_parallel:
split_size = inp.shape[0] // self.world_size
- input_parallel = inp[split_size * self.rank:split_size * (self.rank + 1), :]
+ input_parallel = inp[split_size * self.rank : split_size * (self.rank + 1), :]
else:
input_parallel = inp
with te.fp8_autocast(enabled=fp8_enabled):
@@ -80,18 +80,20 @@ class TestAttentionTp(unittest.TestCase):
self.num_heads,
)
common_kwargs = {
- 'layernorm_epsilon': self.eps,
- 'attention_dropout': 0.0,
- 'attn_mask_type': self.mask_type,
- 'attention_type': 'self',
+ "layernorm_epsilon": self.eps,
+ "attention_dropout": 0.0,
+ "attn_mask_type": self.mask_type,
+ "attention_type": "self",
"tp_group": self.tp_group,
"input_layernorm": True,
}
- layer_tp = te.MultiHeadAttention(*common_args,
- **common_kwargs,
- set_parallel_mode=True,
- sequence_parallel=self.sequence_parallel)
+ layer_tp = te.MultiHeadAttention(
+ *common_args,
+ **common_kwargs,
+ set_parallel_mode=True,
+ sequence_parallel=self.sequence_parallel,
+ )
layer_single = te.MultiHeadAttention(*common_args, **common_kwargs, set_parallel_mode=False)
def _get_total_weight(local_weight, tp_group, axis, interleave=False):
@@ -102,12 +104,15 @@ class TestAttentionTp(unittest.TestCase):
# Due to the interleaved qkv layout, need to concat on num_head
# dimension for column parallel linear in MultiHeadAttention layer
assert axis == 0
- assert [3 * self.hidden_size // self.world_size,
- self.hidden_size] == partial_weight.shape
+ assert [
+ 3 * self.hidden_size // self.world_size,
+ self.hidden_size,
+ ] == partial_weight.shape
local_num_head = self.num_heads // self.world_size
for idx, _ in enumerate(total_weight):
total_weight[idx] = total_weight[idx].reshape(
- [3, local_num_head, -1, self.hidden_size])
+ [3, local_num_head, -1, self.hidden_size]
+ )
total_weight = paddle.concat(total_weight, axis=1).reshape([-1, self.hidden_size])
else:
total_weight = paddle.concat(total_weight, axis=axis)
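
The interleaved-QKV handling above can be checked in plain NumPy: each tensor-parallel rank holds a [3 * hidden // world, hidden] shard with q/k/v interleaved per local head, so shards must be reshaped and concatenated on the head axis, not axis 0, to recover the full weight. Illustrative shapes:

import numpy as np

hidden, heads, world = 8, 4, 2
head_dim, local_heads = hidden // heads, heads // world
full = np.arange(3 * heads * head_dim * hidden).reshape(3, heads, head_dim, hidden)
shards = np.split(full, world, axis=1)  # per rank: (3, local_heads, head_dim, hidden)
restored = np.concatenate(
    [s.reshape(3, local_heads, -1, hidden) for s in shards], axis=1
).reshape(-1, hidden)
assert np.array_equal(restored, full.reshape(-1, hidden))
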
@@ -123,42 +128,47 @@ class TestAttentionTp(unittest.TestCase):
weight_dst = _get_weight(layer_dst, weight_names)
if partition_mode is None:
total_weight = weight_src
- elif partition_mode == 'column':
- total_weight = _get_total_weight(weight_src,
- tp_group=self.tp_group,
- axis=0,
- interleave=interleave)
- elif partition_mode == 'row':
+ elif partition_mode == "column":
+ total_weight = _get_total_weight(
+ weight_src, tp_group=self.tp_group, axis=0, interleave=interleave
+ )
+ elif partition_mode == "row":
total_weight = _get_total_weight(weight_src, tp_group=self.tp_group, axis=1)
else:
raise ValueError(f"Partition Mode {partition_mode} is not supported.")
- assert weight_dst.shape == total_weight.shape, \
- f"Shapes of src:{total_weight.shape} and dst:{weight_dst.shape} do not match."
+ assert (
+ weight_dst.shape == total_weight.shape
+ ), f"Shapes of src:{total_weight.shape} and dst:{weight_dst.shape} do not match."
weight_dst.copy_(total_weight, True)
- copy_weight(layer_tp, layer_single, None, ['layernorm_qkv', 'ln_weight'])
- copy_weight(layer_tp, layer_single, 'column', ['layernorm_qkv', 'weight'], interleave=True)
- copy_weight(layer_tp, layer_single, 'row', ['proj', 'weight'])
+ copy_weight(layer_tp, layer_single, None, ["layernorm_qkv", "ln_weight"])
+ copy_weight(layer_tp, layer_single, "column", ["layernorm_qkv", "weight"], interleave=True)
+ copy_weight(layer_tp, layer_single, "row", ["proj", "weight"])
if self.sequence_parallel:
register_sequence_parallel_allreduce_hooks(layer_tp, accumulation_steps=1)
optimizer_tp = paddle.optimizer.SGD(learning_rate=0.01, parameters=layer_tp.parameters())
- optimizer_single = paddle.optimizer.SGD(learning_rate=0.01,
- parameters=layer_single.parameters())
+ optimizer_single = paddle.optimizer.SGD(
+ learning_rate=0.01, parameters=layer_single.parameters()
+ )
layer_tp = fleet.distributed_model(layer_tp)
optimizer_tp = fleet.distributed_optimizer(optimizer_tp)
for _ in range(5):
- inp = paddle.uniform([self.batch_size, self.q_seqlen, self.hidden_size],
- self.global_dtype)
- mask = paddle.zeros(shape=(self.batch_size, 1, self.q_seqlen, self.kv_seqlen),
- dtype='bool')
- loss_tp, out_tp = self._train_one_step(layer_tp, [inp, mask], optimizer_tp, self.fp8,
- self.sequence_parallel)
- loss_single, out_single = self._train_one_step(layer_single, [inp, mask],
- optimizer_single, self.fp8)
+ inp = paddle.uniform(
+ [self.batch_size, self.q_seqlen, self.hidden_size], self.global_dtype
+ )
+ mask = paddle.zeros(
+ shape=(self.batch_size, 1, self.q_seqlen, self.kv_seqlen), dtype="bool"
+ )
+ loss_tp, out_tp = self._train_one_step(
+ layer_tp, [inp, mask], optimizer_tp, self.fp8, self.sequence_parallel
+ )
+ loss_single, out_single = self._train_one_step(
+ layer_single, [inp, mask], optimizer_single, self.fp8
+ )
assert_allclose(out_tp, out_single, rtol=self.rtol, atol=self.atol)
assert_allclose(loss_tp, loss_single, rtol=self.rtol, atol=self.atol)
@@ -173,8 +183,8 @@ class TestAttentionTpFp8(TestAttentionTp):
self.num_heads = 16
self.q_seqlen = 128
self.kv_seqlen = 128
- self.mask_type = 'padding'
- self.global_dtype = 'bfloat16'
+ self.mask_type = "padding"
+ self.global_dtype = "bfloat16"
self.rtol = 5e-2
self.atol = 5e-2
self.eps = 1e-3
@@ -192,8 +202,8 @@ class TestAttentionSp(TestAttentionTp):
self.num_heads = 16
self.q_seqlen = 128
self.kv_seqlen = 128
- self.mask_type = 'padding'
- self.global_dtype = 'bfloat16'
+ self.mask_type = "padding"
+ self.global_dtype = "bfloat16"
self.rtol = 5e-3
self.atol = 5e-3
self.eps = 1e-3
@@ -211,8 +221,8 @@ class TestAttentionSpFp8(TestAttentionTp):
self.num_heads = 16
self.q_seqlen = 128
self.kv_seqlen = 128
- self.mask_type = 'padding'
- self.global_dtype = 'bfloat16'
+ self.mask_type = "padding"
+ self.global_dtype = "bfloat16"
self.rtol = 5e-2
self.atol = 1e-1
self.eps = 1e-3
@@ -220,5 +230,5 @@ class TestAttentionSpFp8(TestAttentionTp):
self.sequence_parallel = True
-if __name__ == '__main__':
+if __name__ == "__main__":
unittest.main()
diff --git a/tests/paddle/parallel_tests/group_sharding.py b/tests/paddle/parallel_tests/group_sharding.py
index 3489b040a50efa7bba7ffe3a505a1a8d9a2fb94e..11060be38e0cf7c6442a80d6e35037dce9e33a7b 100644
--- a/tests/paddle/parallel_tests/group_sharding.py
+++ b/tests/paddle/parallel_tests/group_sharding.py
@@ -8,7 +8,8 @@ import unittest
import paddle
from paddle.distributed import fleet
from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer import (
- DygraphShardingOptimizer,)
+ DygraphShardingOptimizer,
+)
from utils import assert_allclose, set_random_seed
import transformer_engine.paddle as te
@@ -25,7 +26,7 @@ class TestGroupSharding(unittest.TestCase):
def set_attr(self):
"""Set test configs"""
self.sharding_degree = 2
- self.global_dtype = 'float32'
+ self.global_dtype = "float32"
self.rtol = 1e-5
self.atol = 1e-5
self.batch_size = 16
@@ -57,11 +58,12 @@ class TestGroupSharding(unittest.TestCase):
optimizer = paddle.optimizer.AdamW(learning_rate=0.01, parameters=model.parameters())
group = fleet.get_hybrid_communicate_group().get_sharding_parallel_group()
- class ShardingLevel: # pylint: disable=too-few-public-methods,
+ class ShardingLevel: # pylint: disable=too-few-public-methods,
"""Paddle sharding options"""
- kStage1 = 'os'
- kStage2 = 'os_g'
- kStage3 = 'p_g_os'
+
+ kStage1 = "os"
+ kStage2 = "os_g"
+ kStage3 = "p_g_os"
level = ShardingLevel.kStage3 if stage == 3 else ShardingLevel.kStage2
model, optimizer, _ = paddle.distributed.sharding.group_sharded_parallel(
@@ -104,8 +106,9 @@ class TestGroupSharding(unittest.TestCase):
loss_pd = train_one_step(model_pd, inp, optimizer_pd)
assert_allclose(loss_te, loss_pd, rtol=self.rtol, atol=self.atol)
- assert len(optimizer_te.state_dict()) == 4, \
- "Expect each rank to hold 4 optimizer state entries."
+ assert (
+ len(optimizer_te.state_dict()) == 4
+ ), "Expect each rank to hold 4 optimizer state entries."
def test_group_sharding_stage2(self):
"""Tests group sharding training"""
@@ -141,8 +144,9 @@ class TestGroupSharding(unittest.TestCase):
loss_pd = train_one_step(model_pd, inp, optimizer_pd)
assert_allclose(loss_te, loss_pd, rtol=self.rtol, atol=self.atol)
- assert len(optimizer_te.state_dict()) == 4, \
- "Expect each rank to hold 4 optimizer state entries."
+ assert (
+ len(optimizer_te.state_dict()) == 4
+ ), "Expect each rank to hold 4 optimizer state entries."
def test_group_sharding_stage3(self):
"""Tests group sharding training"""
@@ -174,11 +178,11 @@ class TestGroupSharding(unittest.TestCase):
assert_allclose(loss_te, loss_pd, rtol=self.rtol, atol=self.atol)
for name, value in optimizer_te.state_dict().items():
- if name.endswith('w_0_moment1_0'):
- assert value.numel() == \
- self.in_channels * self.out_channels // self.sharding_degree, \
- "Expect optimizer state to be sharded across trainers."
+ if name.endswith("w_0_moment1_0"):
+ assert (
+ value.numel() == self.in_channels * self.out_channels // self.sharding_degree
+ ), "Expect optimizer state to be sharded across trainers."
-if __name__ == '__main__':
+if __name__ == "__main__":
unittest.main()
diff --git a/tests/paddle/parallel_tests/layernorm_linear_tp.py b/tests/paddle/parallel_tests/layernorm_linear_tp.py
index 93142a571d5adb39e6b99286df04043c2efa0332..02295a71da46cacfb0a8bf94c826d1ae17262c39 100644
--- a/tests/paddle/parallel_tests/layernorm_linear_tp.py
+++ b/tests/paddle/parallel_tests/layernorm_linear_tp.py
@@ -42,22 +42,22 @@ class TestLayerNormLinearTp(unittest.TestCase):
self.batch_size = 16
self.in_features = 32
self.out_features = 64
- self.global_dtype = 'bfloat16'
+ self.global_dtype = "bfloat16"
self.rtol = 1e-3
self.atol = 1e-3
self.eps = 1e-3
self.fp8 = False
self.sequence_parallel = False
- def _train_one_step(self, layer, inp, optimizer, split_input='none', gather_output=False):
+ def _train_one_step(self, layer, inp, optimizer, split_input="none", gather_output=False):
inp = paddle.to_tensor(inp, stop_gradient=True)
- assert split_input in ['none', 'column', 'row']
- if split_input == 'column':
+ assert split_input in ["none", "column", "row"]
+ if split_input == "column":
split_size = inp.shape[1] // self.world_size
- input_parallel = inp[:, split_size * self.rank:split_size * (self.rank + 1)]
- elif split_input == 'row':
+ input_parallel = inp[:, split_size * self.rank : split_size * (self.rank + 1)]
+ elif split_input == "row":
split_size = inp.shape[0] // self.world_size
- input_parallel = inp[split_size * self.rank:split_size * (self.rank + 1), :]
+ input_parallel = inp[split_size * self.rank : split_size * (self.rank + 1), :]
else:
input_parallel = inp
input_parallel.stop_gradient = False
@@ -70,12 +70,12 @@ class TestLayerNormLinearTp(unittest.TestCase):
loss.backward()
optimizer.step()
optimizer.clear_grad()
- if split_input != 'none':
+ if split_input != "none":
grad_input = []
paddle.distributed.all_gather(grad_input, input_parallel.grad, group=self.tp_group)
- if split_input == 'column':
+ if split_input == "column":
grad_input = paddle.concat(grad_input, axis=1)
- elif split_input == 'row':
+ elif split_input == "row":
grad_input = paddle.concat(grad_input, axis=0)
else:
grad_input = input_parallel.grad
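
The split/gather pattern in _train_one_step round-trips exactly: each rank takes a contiguous slice of the input, and concatenating the all-gathered gradient shards on the same axis reconstructs the full tensor. Simulated here in NumPy with the collective replaced by a plain list:

import numpy as np

world_size = 2
inp = np.arange(12.0).reshape(4, 3)
split = inp.shape[0] // world_size
shards = [inp[split * r : split * (r + 1), :] for r in range(world_size)]
assert np.array_equal(np.concatenate(shards, axis=0), inp)
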
@@ -88,14 +88,14 @@ class TestLayerNormLinearTp(unittest.TestCase):
self.in_features,
self.out_features,
eps=self.eps,
- parallel_mode='column',
+ parallel_mode="column",
sequence_parallel=self.sequence_parallel,
)
layer_pd = te.LayerNormLinear(
self.in_features,
self.out_features,
eps=self.eps,
- backend='paddle',
+ backend="paddle",
)
# Get total weight
total_weight = []
@@ -104,8 +104,9 @@ class TestLayerNormLinearTp(unittest.TestCase):
total_weight = paddle.concat(total_weight, axis=0)
layer_pd.weight.copy_(total_weight.T, True)
- assert_shape(layer_te.weight,
- [self.out_features // self.model_parallel_size, self.in_features])
+ assert_shape(
+ layer_te.weight, [self.out_features // self.model_parallel_size, self.in_features]
+ )
assert_shape(layer_te.bias, [self.out_features // self.model_parallel_size])
optimizer_te = paddle.optimizer.SGD(learning_rate=0.001, parameters=layer_te.parameters())
@@ -121,8 +122,9 @@ class TestLayerNormLinearTp(unittest.TestCase):
layer_te,
inp,
optimizer_te,
- split_input='row' if self.sequence_parallel else 'none',
- gather_output=True)
+ split_input="row" if self.sequence_parallel else "none",
+ gather_output=True,
+ )
loss_ref, grad_input_ref = self._train_one_step(layer_pd, inp, optimizer_pd)
assert_allclose(loss_tp, loss_ref, rtol=self.rtol, atol=self.atol)
assert_allclose(grad_input, grad_input_ref, rtol=self.rtol, atol=self.atol)
@@ -136,7 +138,7 @@ class TestLayerNormLinearTpFp8(TestLayerNormLinearTp):
self.batch_size = 16
self.in_features = 32
self.out_features = 64
- self.global_dtype = 'bfloat16'
+ self.global_dtype = "bfloat16"
self.rtol = 1e-2
self.atol = 1e-2
self.eps = 1e-3
@@ -152,7 +154,7 @@ class TestLayerNormLinearSp(TestLayerNormLinearTp):
self.batch_size = 16
self.in_features = 32
self.out_features = 64
- self.global_dtype = 'bfloat16'
+ self.global_dtype = "bfloat16"
self.rtol = 1e-3
self.atol = 1e-3
self.eps = 1e-3
@@ -168,7 +170,7 @@ class TestLayerNormLinearSpFp8(TestLayerNormLinearTp):
self.batch_size = 16
self.in_features = 32
self.out_features = 64
- self.global_dtype = 'bfloat16'
+ self.global_dtype = "bfloat16"
self.rtol = 1e-2
self.atol = 1e-2
self.eps = 1e-3
@@ -176,5 +178,5 @@ class TestLayerNormLinearSpFp8(TestLayerNormLinearTp):
self.sequence_parallel = True
-if __name__ == '__main__':
+if __name__ == "__main__":
unittest.main()
diff --git a/tests/paddle/parallel_tests/layernorm_mlp_tp.py b/tests/paddle/parallel_tests/layernorm_mlp_tp.py
index e5b391a2f0e3fb89f6700101996ee711a99f10ed..f23cfb9e3f33f4a8c4ce589a49b82166b0520b63 100644
--- a/tests/paddle/parallel_tests/layernorm_mlp_tp.py
+++ b/tests/paddle/parallel_tests/layernorm_mlp_tp.py
@@ -42,22 +42,22 @@ class TestLayerNormMLPTp(unittest.TestCase):
self.batch_size = 16
self.hidden_size = 32
self.ffn_hidden_size = 64
- self.global_dtype = 'bfloat16'
+ self.global_dtype = "bfloat16"
self.rtol = 1e-3
self.atol = 1e-3
self.eps = 1e-3
self.fp8 = False
self.sequence_parallel = False
- def _train_one_step(self, layer, inp, optimizer, split_input='none', gather_output=False):
+ def _train_one_step(self, layer, inp, optimizer, split_input="none", gather_output=False):
inp = paddle.to_tensor(inp, stop_gradient=True)
- assert split_input in ['none', 'column', 'row']
- if split_input == 'column':
+ assert split_input in ["none", "column", "row"]
+ if split_input == "column":
split_size = inp.shape[1] // self.world_size
- input_parallel = inp[:, split_size * self.rank:split_size * (self.rank + 1)]
- elif split_input == 'row':
+ input_parallel = inp[:, split_size * self.rank : split_size * (self.rank + 1)]
+ elif split_input == "row":
split_size = inp.shape[0] // self.world_size
- input_parallel = inp[split_size * self.rank:split_size * (self.rank + 1), :]
+ input_parallel = inp[split_size * self.rank : split_size * (self.rank + 1), :]
else:
input_parallel = inp
input_parallel.stop_gradient = False
@@ -71,12 +71,12 @@ class TestLayerNormMLPTp(unittest.TestCase):
loss.backward()
optimizer.step()
optimizer.clear_grad()
- if split_input != 'none':
+ if split_input != "none":
grad_input = []
paddle.distributed.all_gather(grad_input, input_parallel.grad, group=self.tp_group)
- if split_input == 'column':
+ if split_input == "column":
grad_input = paddle.concat(grad_input, axis=1)
- elif split_input == 'row':
+ elif split_input == "row":
grad_input = paddle.concat(grad_input, axis=0)
else:
grad_input = input_parallel.grad
@@ -97,7 +97,7 @@ class TestLayerNormMLPTp(unittest.TestCase):
ffn_hidden_size=self.ffn_hidden_size,
eps=self.eps,
set_parallel_mode=False,
- backend='paddle',
+ backend="paddle",
)
def _get_total_weight(local_weight, tp_group, axis):
@@ -113,11 +113,15 @@ class TestLayerNormMLPTp(unittest.TestCase):
layer_pd.fc1_weight.copy_(total_fc1_weight.T, True)
layer_pd.fc2_weight.copy_(total_fc2_weight.T, True)
- assert_shape(layer_te.fc1_weight,
- [self.ffn_hidden_size // self.model_parallel_size, self.hidden_size])
+ assert_shape(
+ layer_te.fc1_weight,
+ [self.ffn_hidden_size // self.model_parallel_size, self.hidden_size],
+ )
assert_shape(layer_te.fc1_bias, [self.ffn_hidden_size // self.model_parallel_size])
- assert_shape(layer_te.fc2_weight,
- [self.hidden_size, self.ffn_hidden_size // self.model_parallel_size])
+ assert_shape(
+ layer_te.fc2_weight,
+ [self.hidden_size, self.ffn_hidden_size // self.model_parallel_size],
+ )
assert_shape(layer_te.fc2_bias, [self.hidden_size])
optimizer_te = paddle.optimizer.SGD(learning_rate=0.001, parameters=layer_te.parameters())
@@ -133,8 +137,9 @@ class TestLayerNormMLPTp(unittest.TestCase):
layer_te,
inp,
optimizer_te,
- split_input='row' if self.sequence_parallel else 'none',
- gather_output=self.sequence_parallel)
+ split_input="row" if self.sequence_parallel else "none",
+ gather_output=self.sequence_parallel,
+ )
loss_ref, grad_input_ref = self._train_one_step(layer_pd, inp, optimizer_pd)
assert_allclose(loss_tp, loss_ref, rtol=self.rtol, atol=self.atol)
assert_allclose(grad_input, grad_input_ref, rtol=self.rtol, atol=self.atol)
@@ -148,7 +153,7 @@ class TestLayerNormMLPTpFp8(TestLayerNormMLPTp):
self.batch_size = 16
self.hidden_size = 32
self.ffn_hidden_size = 64
- self.global_dtype = 'bfloat16'
+ self.global_dtype = "bfloat16"
self.rtol = 1e-2
self.atol = 1e-2
self.eps = 1e-3
@@ -164,7 +169,7 @@ class TestLayerNormMLPSp(TestLayerNormMLPTp):
self.batch_size = 16
self.hidden_size = 32
self.ffn_hidden_size = 64
- self.global_dtype = 'bfloat16'
+ self.global_dtype = "bfloat16"
self.rtol = 1e-3
self.atol = 1e-3
self.eps = 1e-3
@@ -180,7 +185,7 @@ class TestLayerNormMLPSpFp8(TestLayerNormMLPTp):
self.batch_size = 16
self.hidden_size = 32
self.ffn_hidden_size = 64
- self.global_dtype = 'bfloat16'
+ self.global_dtype = "bfloat16"
self.rtol = 1e-2
self.atol = 1e-2
self.eps = 1e-3
@@ -188,5 +193,5 @@ class TestLayerNormMLPSpFp8(TestLayerNormMLPTp):
self.sequence_parallel = True
-if __name__ == '__main__':
+if __name__ == "__main__":
unittest.main()
diff --git a/tests/paddle/parallel_tests/linear_pp.py b/tests/paddle/parallel_tests/linear_pp.py
index 26db9dd97701e141653d46947122b5d1bc9130cc..0e7e90611e31612bdb7cb5847a3b13c329ec48b2 100644
--- a/tests/paddle/parallel_tests/linear_pp.py
+++ b/tests/paddle/parallel_tests/linear_pp.py
@@ -23,14 +23,14 @@ class TELinear(te.Linear):
"""To pass is_first_microbatch"""
def __init__(self, *args, **kwargs):
- assert 'accumulate_steps' in kwargs
- self.accumulate_steps = kwargs['accumulate_steps']
- del kwargs['accumulate_steps']
+ assert "accumulate_steps" in kwargs
+ self.accumulate_steps = kwargs["accumulate_steps"]
+ del kwargs["accumulate_steps"]
self._micro_batch_id = 0
super().__init__(*args, **kwargs)
def forward(self, *args, **kwargs):
- kwargs['is_first_microbatch'] = (self._micro_batch_id % self.accumulate_steps) == 0
+ kwargs["is_first_microbatch"] = (self._micro_batch_id % self.accumulate_steps) == 0
if paddle.is_grad_enabled() and self.training:
self._micro_batch_id += 1
return super().forward(*args, **kwargs)
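
The wrapper above just counts training forward passes so that is_first_microbatch is True exactly once per accumulation window, letting the layer cache work (such as weight casts) across the rest of the window. The counting logic in isolation (a sketch, not the Paddle class):

class MicrobatchCounter:
    def __init__(self, accumulate_steps):
        self.accumulate_steps = accumulate_steps
        self._micro_batch_id = 0

    def step(self):
        # True on the first micro-batch of every accumulation window.
        is_first = (self._micro_batch_id % self.accumulate_steps) == 0
        self._micro_batch_id += 1
        return is_first

ctr = MicrobatchCounter(accumulate_steps=4)
assert [ctr.step() for _ in range(8)] == [True, False, False, False] * 2
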
@@ -39,14 +39,16 @@ class TELinear(te.Linear):
class TEPipelineModel(PipelineLayer):
"""Model for pipeline parallel test"""
- def __init__(self,
- in_features,
- hidden_features,
- weight_attrs,
- use_te=True,
- use_fp8=False,
- accumulate_steps=1,
- **kwargs):
+ def __init__(
+ self,
+ in_features,
+ hidden_features,
+ weight_attrs,
+ use_te=True,
+ use_fp8=False,
+ accumulate_steps=1,
+ **kwargs,
+ ):
self.in_features = in_features
self.hidden_features = hidden_features
self.fp8 = use_fp8
@@ -56,19 +58,23 @@ class TEPipelineModel(PipelineLayer):
Linear = TELinear if use_te else paddle.nn.Linear
extra_kwargs = {}
if use_te:
- extra_kwargs['accumulate_steps'] = accumulate_steps
+ extra_kwargs["accumulate_steps"] = accumulate_steps
model_desc = [
- LayerDesc(Linear,
- self.in_features,
- self.hidden_features,
- weight_attr=weight_attrs[0],
- **extra_kwargs),
- LayerDesc(Linear,
- self.hidden_features,
- self.in_features,
- weight_attr=weight_attrs[1],
- **extra_kwargs),
+ LayerDesc(
+ Linear,
+ self.in_features,
+ self.hidden_features,
+ weight_attr=weight_attrs[0],
+ **extra_kwargs,
+ ),
+ LayerDesc(
+ Linear,
+ self.hidden_features,
+ self.in_features,
+ weight_attr=weight_attrs[1],
+ **extra_kwargs,
+ ),
]
super().__init__(layers=model_desc, loss_fn=paddle.nn.CrossEntropyLoss(), **kwargs)
@@ -129,7 +135,7 @@ class TestLinearPipelineParallel(unittest.TestCase):
self.micro_batch_size = 16
self.in_features = 32
self.hidden_features = 64
- self.global_dtype = 'float32'
+ self.global_dtype = "float32"
self.rtol = 1e-5
self.atol = 1e-5
self.iter = 10
@@ -164,16 +170,18 @@ class TestLinearPipelineParallel(unittest.TestCase):
# Check if model is split across ranks as expected
for name, sublayer in pipe_model.named_sublayers():
- if name in ('_loss_fn', 'shared_layers'):
+ if name in ("_loss_fn", "shared_layers"):
continue
if self.rank == 0:
- assert tuple(sublayer.weight.shape) == weight1_np.T.shape, \
- f"Shape does not match, expect: {weight1_np.T.shape}, " \
+ assert tuple(sublayer.weight.shape) == weight1_np.T.shape, (
+ f"Shape does not match, expect: {weight1_np.T.shape}, "
f"actual: {tuple(sublayer.weight.shape)}"
+ )
elif self.rank == 1:
- assert tuple(sublayer.weight.shape) == weight2_np.T.shape, \
- f"Shape does not match, expect: {weight2_np.T.shape}, " \
+ assert tuple(sublayer.weight.shape) == weight2_np.T.shape, (
+ f"Shape does not match, expect: {weight2_np.T.shape}, "
f"actual: {tuple(sublayer.weight.shape)}"
+ )
standalone_model = StandaloneModel(
self.in_features,
@@ -182,8 +190,9 @@ class TestLinearPipelineParallel(unittest.TestCase):
)
optimizer_te = paddle.optimizer.SGD(learning_rate=0.1, parameters=pipe_model.parameters())
- optimizer_pd = paddle.optimizer.SGD(learning_rate=0.1,
- parameters=standalone_model.parameters())
+ optimizer_pd = paddle.optimizer.SGD(
+ learning_rate=0.1, parameters=standalone_model.parameters()
+ )
pipe_model = fleet.distributed_model(pipe_model)
optimizer_te = fleet.distributed_optimizer(optimizer_te)
@@ -196,8 +205,9 @@ class TestLinearPipelineParallel(unittest.TestCase):
return loss
for i in range(self.iter):
- inp = paddle.to_tensor(np.random.normal(size=[self.batch_size, self.in_features]),
- dtype=self.global_dtype)
+ inp = paddle.to_tensor(
+ np.random.normal(size=[self.batch_size, self.in_features]), dtype=self.global_dtype
+ )
label = paddle.to_tensor(np.random.randint(self.in_features, size=[self.batch_size, 1]))
loss_te = pipe_model.train_batch([inp, label], optimizer_te)
loss_pd = train_one_step(standalone_model, [inp, label], optimizer_pd)
@@ -214,12 +224,12 @@ class TestLinearPipelineParallelFP8(TestLinearPipelineParallel):
self.micro_batch_size = 16
self.in_features = 32
self.hidden_features = 64
- self.global_dtype = 'float32'
+ self.global_dtype = "float32"
self.rtol = 5e-2
self.atol = 5e-2
self.iter = 10
self.fp8 = True
-if __name__ == '__main__':
+if __name__ == "__main__":
unittest.main()
diff --git a/tests/paddle/parallel_tests/linear_tp.py b/tests/paddle/parallel_tests/linear_tp.py
index ebdabd06bc0c73f5007a5a789fbcfbd318fc443d..4a49474a37f34589a6187d4032635b86fdad85b8 100644
--- a/tests/paddle/parallel_tests/linear_tp.py
+++ b/tests/paddle/parallel_tests/linear_tp.py
@@ -42,21 +42,21 @@ class TestLinearTp(unittest.TestCase):
self.batch_size = 16
self.in_features = 32
self.out_features = 64
- self.global_dtype = 'bfloat16'
+ self.global_dtype = "bfloat16"
self.rtol = 1e-3
self.atol = 1e-3
self.fp8 = False
self.sequence_parallel = False
- def _train_one_step(self, layer, inp, optimizer, split_input='none', gather_output=False):
+ def _train_one_step(self, layer, inp, optimizer, split_input="none", gather_output=False):
inp = paddle.to_tensor(inp, stop_gradient=True)
- assert split_input in ['none', 'column', 'row']
- if split_input == 'column':
+ assert split_input in ["none", "column", "row"]
+ if split_input == "column":
split_size = inp.shape[1] // self.world_size
- input_parallel = inp[:, split_size * self.rank:split_size * (self.rank + 1)]
- elif split_input == 'row':
+ input_parallel = inp[:, split_size * self.rank : split_size * (self.rank + 1)]
+ elif split_input == "row":
split_size = inp.shape[0] // self.world_size
- input_parallel = inp[split_size * self.rank:split_size * (self.rank + 1), :]
+ input_parallel = inp[split_size * self.rank : split_size * (self.rank + 1), :]
else:
input_parallel = inp
input_parallel.stop_gradient = False
@@ -69,12 +69,12 @@ class TestLinearTp(unittest.TestCase):
loss.backward()
optimizer.step()
optimizer.clear_grad()
- if split_input != 'none':
+ if split_input != "none":
grad_input = []
paddle.distributed.all_gather(grad_input, input_parallel.grad, group=self.tp_group)
- if split_input == 'column':
+ if split_input == "column":
grad_input = paddle.concat(grad_input, axis=1)
- elif split_input == 'row':
+ elif split_input == "row":
grad_input = paddle.concat(grad_input, axis=0)
else:
grad_input = input_parallel.grad
@@ -86,13 +86,13 @@ class TestLinearTp(unittest.TestCase):
layer_te = te.Linear(
self.in_features,
self.out_features,
- parallel_mode='column',
+ parallel_mode="column",
sequence_parallel=self.sequence_parallel,
)
layer_pd = te.Linear(
self.in_features,
self.out_features,
- backend='paddle',
+ backend="paddle",
)
# Get total weight
total_weight = []
@@ -101,8 +101,9 @@ class TestLinearTp(unittest.TestCase):
total_weight = paddle.concat(total_weight, axis=0)
layer_pd.weight.copy_(total_weight.T, True)
- assert_shape(layer_te.weight,
- [self.out_features // self.model_parallel_size, self.in_features])
+ assert_shape(
+ layer_te.weight, [self.out_features // self.model_parallel_size, self.in_features]
+ )
assert_shape(layer_te.bias, [self.out_features // self.model_parallel_size])
optimizer_te = paddle.optimizer.SGD(learning_rate=0.001, parameters=layer_te.parameters())
@@ -118,8 +119,9 @@ class TestLinearTp(unittest.TestCase):
layer_te,
inp,
optimizer_te,
- split_input='row' if self.sequence_parallel else 'none',
- gather_output=True)
+ split_input="row" if self.sequence_parallel else "none",
+ gather_output=True,
+ )
loss_ref, grad_input_ref = self._train_one_step(layer_pd, inp, optimizer_pd)
assert_allclose(loss_tp, loss_ref, rtol=self.rtol, atol=self.atol)
assert_allclose(grad_input, grad_input_ref, rtol=self.rtol, atol=self.atol)
@@ -130,13 +132,13 @@ class TestLinearTp(unittest.TestCase):
layer_te = te.Linear(
self.in_features,
self.out_features,
- parallel_mode='row',
+ parallel_mode="row",
sequence_parallel=self.sequence_parallel,
)
layer_pd = te.Linear(
self.in_features,
self.out_features,
- backend='paddle',
+ backend="paddle",
)
# Get total weight
total_weight = []
@@ -145,8 +147,9 @@ class TestLinearTp(unittest.TestCase):
total_weight = paddle.concat(total_weight, axis=1)
layer_pd.weight.copy_(total_weight.T, True)
- assert_shape(layer_te.weight,
- [self.out_features, self.in_features // self.model_parallel_size])
+ assert_shape(
+ layer_te.weight, [self.out_features, self.in_features // self.model_parallel_size]
+ )
assert_shape(layer_te.bias, [self.out_features])
optimizer_te = paddle.optimizer.SGD(learning_rate=0.001, parameters=layer_te.parameters())
@@ -158,11 +161,13 @@ class TestLinearTp(unittest.TestCase):
for _ in range(5):
inp = paddle.uniform([self.batch_size, self.in_features], self.global_dtype)
with te.fp8_autocast(enabled=self.fp8):
- loss_tp, grad_input = self._train_one_step(layer_te,
- inp,
- optimizer_te,
- split_input='column',
- gather_output=self.sequence_parallel)
+ loss_tp, grad_input = self._train_one_step(
+ layer_te,
+ inp,
+ optimizer_te,
+ split_input="column",
+ gather_output=self.sequence_parallel,
+ )
loss_ref, grad_input_ref = self._train_one_step(layer_pd, inp, optimizer_pd)
assert_allclose(loss_tp, loss_ref, rtol=self.rtol, atol=self.atol)
assert_allclose(grad_input, grad_input_ref, rtol=self.rtol, atol=self.atol)
@@ -176,7 +181,7 @@ class TestLinearTpFP8(TestLinearTp):
self.batch_size = 16
self.in_features = 32
self.out_features = 64
- self.global_dtype = 'bfloat16'
+ self.global_dtype = "bfloat16"
self.rtol = 1e-2
self.atol = 1e-2
self.fp8 = True
@@ -191,7 +196,7 @@ class TestLinearSp(TestLinearTp):
self.batch_size = 16
self.in_features = 32
self.out_features = 64
- self.global_dtype = 'bfloat16'
+ self.global_dtype = "bfloat16"
self.rtol = 1e-3
self.atol = 1e-3
self.fp8 = False
@@ -206,12 +211,12 @@ class TestLinearSpFP8(TestLinearTp):
self.batch_size = 16
self.in_features = 32
self.out_features = 64
- self.global_dtype = 'bfloat16'
+ self.global_dtype = "bfloat16"
self.rtol = 1e-2
self.atol = 1e-2
self.fp8 = True
self.sequence_parallel = True
-if __name__ == '__main__':
+if __name__ == "__main__":
unittest.main()
diff --git a/tests/paddle/parallel_tests/transformer_tp.py b/tests/paddle/parallel_tests/transformer_tp.py
index 2853eda4d7277743b419ee6ec99bc73fd233afd6..5506be042ff11e09f7706bd13879faa898ac5c0b 100644
--- a/tests/paddle/parallel_tests/transformer_tp.py
+++ b/tests/paddle/parallel_tests/transformer_tp.py
@@ -45,9 +45,9 @@ class TestTransformerTp(unittest.TestCase):
self.ffn_hidden_size = 4096
self.q_seqlen = 128
self.kv_seqlen = 128
- self.mask_type = 'padding'
- self.layer_type = 'encoder'
- self.global_dtype = 'bfloat16'
+ self.mask_type = "padding"
+ self.layer_type = "encoder"
+ self.global_dtype = "bfloat16"
self.rtol = 5e-2
self.atol = 5e-2
self.eps = 1e-3
@@ -58,7 +58,7 @@ class TestTransformerTp(unittest.TestCase):
inp, mask = inp_list
if sequence_parallel:
split_size = inp.shape[0] // self.world_size
- input_parallel = inp[split_size * self.rank:split_size * (self.rank + 1), :]
+ input_parallel = inp[split_size * self.rank : split_size * (self.rank + 1), :]
else:
input_parallel = inp
with te.fp8_autocast(enabled=fp8_enabled):
@@ -83,16 +83,18 @@ class TestTransformerTp(unittest.TestCase):
self.num_heads,
]
common_kwargs = {
- 'layernorm_epsilon': self.eps,
- 'hidden_dropout': 0.0,
- 'attention_dropout': 0.0,
- 'self_attn_mask_type': self.mask_type,
- 'layer_type': self.layer_type,
+ "layernorm_epsilon": self.eps,
+ "hidden_dropout": 0.0,
+ "attention_dropout": 0.0,
+ "self_attn_mask_type": self.mask_type,
+ "layer_type": self.layer_type,
}
- layer_tp = te.TransformerLayer(*common_args,
- **common_kwargs,
- set_parallel_mode=True,
- sequence_parallel=self.sequence_parallel)
+ layer_tp = te.TransformerLayer(
+ *common_args,
+ **common_kwargs,
+ set_parallel_mode=True,
+ sequence_parallel=self.sequence_parallel,
+ )
layer_single = te.TransformerLayer(*common_args, **common_kwargs, set_parallel_mode=False)
def _get_total_weight(local_weight, tp_group, axis, interleave=False):
@@ -103,12 +105,15 @@ class TestTransformerTp(unittest.TestCase):
# Due to the interleaved qkv layout, need to concat on num_head
# dimension for column parallel linear in MultiHeadAttention layer
assert axis == 0
- assert [3 * self.hidden_size // self.world_size,
- self.hidden_size] == partial_weight.shape
+ assert [
+ 3 * self.hidden_size // self.world_size,
+ self.hidden_size,
+ ] == partial_weight.shape
local_num_head = self.num_heads // self.world_size
for idx, _ in enumerate(total_weight):
total_weight[idx] = total_weight[idx].reshape(
- [3, local_num_head, -1, self.hidden_size])
+ [3, local_num_head, -1, self.hidden_size]
+ )
total_weight = paddle.concat(total_weight, axis=1).reshape([-1, self.hidden_size])
else:
total_weight = paddle.concat(total_weight, axis=axis)
@@ -124,48 +129,56 @@ class TestTransformerTp(unittest.TestCase):
weight_dst = _get_weight(layer_dst, weight_names)
if partition_mode is None:
total_weight = weight_src
- elif partition_mode == 'column':
- total_weight = _get_total_weight(weight_src,
- tp_group=self.tp_group,
- axis=0,
- interleave=interleave)
- elif partition_mode == 'row':
+ elif partition_mode == "column":
+ total_weight = _get_total_weight(
+ weight_src, tp_group=self.tp_group, axis=0, interleave=interleave
+ )
+ elif partition_mode == "row":
total_weight = _get_total_weight(weight_src, tp_group=self.tp_group, axis=1)
else:
raise ValueError(f"Partition Mode {partition_mode} is not supported.")
- assert weight_dst.shape == total_weight.shape, \
- f"Shapes of src:{total_weight.shape} and dst:{weight_dst.shape} do not match."
+ assert (
+ weight_dst.shape == total_weight.shape
+ ), f"Shapes of src:{total_weight.shape} and dst:{weight_dst.shape} do not match."
weight_dst.copy_(total_weight, True)
- copy_weight(layer_tp, layer_single, None, ['self_attention', 'layernorm_qkv', 'ln_weight'])
- copy_weight(layer_tp,
- layer_single,
- 'column', ['self_attention', 'layernorm_qkv', 'weight'],
- interleave=True)
- copy_weight(layer_tp, layer_single, 'row', ['self_attention', 'proj', 'weight'])
- copy_weight(layer_tp, layer_single, None, ['layernorm_mlp', 'ln_weight'])
- copy_weight(layer_tp, layer_single, 'column', ['layernorm_mlp', 'fc1_weight'])
- copy_weight(layer_tp, layer_single, 'row', ['layernorm_mlp', 'fc2_weight'])
+ copy_weight(layer_tp, layer_single, None, ["self_attention", "layernorm_qkv", "ln_weight"])
+ copy_weight(
+ layer_tp,
+ layer_single,
+ "column",
+ ["self_attention", "layernorm_qkv", "weight"],
+ interleave=True,
+ )
+ copy_weight(layer_tp, layer_single, "row", ["self_attention", "proj", "weight"])
+ copy_weight(layer_tp, layer_single, None, ["layernorm_mlp", "ln_weight"])
+ copy_weight(layer_tp, layer_single, "column", ["layernorm_mlp", "fc1_weight"])
+ copy_weight(layer_tp, layer_single, "row", ["layernorm_mlp", "fc2_weight"])
if self.sequence_parallel:
register_sequence_parallel_allreduce_hooks(layer_tp, accumulation_steps=1)
optimizer_tp = paddle.optimizer.SGD(learning_rate=0.01, parameters=layer_tp.parameters())
- optimizer_single = paddle.optimizer.SGD(learning_rate=0.01,
- parameters=layer_single.parameters())
+ optimizer_single = paddle.optimizer.SGD(
+ learning_rate=0.01, parameters=layer_single.parameters()
+ )
layer_tp = fleet.distributed_model(layer_tp)
optimizer_tp = fleet.distributed_optimizer(optimizer_tp)
for _ in range(5):
- inp = paddle.uniform([self.batch_size, self.q_seqlen, self.hidden_size],
- self.global_dtype)
- mask = paddle.zeros(shape=(self.batch_size, 1, self.q_seqlen, self.kv_seqlen),
- dtype='bool')
- loss_tp, out_tp = self._train_one_step(layer_tp, [inp, mask], optimizer_tp, self.fp8,
- self.sequence_parallel)
- loss_single, out_single = self._train_one_step(layer_single, [inp, mask],
- optimizer_single, self.fp8)
+ inp = paddle.uniform(
+ [self.batch_size, self.q_seqlen, self.hidden_size], self.global_dtype
+ )
+ mask = paddle.zeros(
+ shape=(self.batch_size, 1, self.q_seqlen, self.kv_seqlen), dtype="bool"
+ )
+ loss_tp, out_tp = self._train_one_step(
+ layer_tp, [inp, mask], optimizer_tp, self.fp8, self.sequence_parallel
+ )
+ loss_single, out_single = self._train_one_step(
+ layer_single, [inp, mask], optimizer_single, self.fp8
+ )
assert_allclose(out_tp, out_single, rtol=self.rtol, atol=self.atol)
assert_allclose(loss_tp, loss_single, rtol=self.rtol, atol=self.atol)
@@ -181,9 +194,9 @@ class TestTransformerTpFp8(TestTransformerTp):
self.ffn_hidden_size = 4096
self.q_seqlen = 128
self.kv_seqlen = 128
- self.mask_type = 'padding'
- self.layer_type = 'encoder'
- self.global_dtype = 'bfloat16'
+ self.mask_type = "padding"
+ self.layer_type = "encoder"
+ self.global_dtype = "bfloat16"
self.rtol = 5e-2
self.atol = 0.5
self.eps = 1e-3
@@ -202,9 +215,9 @@ class TestTransformerSp(TestTransformerTp):
self.ffn_hidden_size = 4096
self.q_seqlen = 128
self.kv_seqlen = 128
- self.mask_type = 'padding'
- self.layer_type = 'encoder'
- self.global_dtype = 'bfloat16'
+ self.mask_type = "padding"
+ self.layer_type = "encoder"
+ self.global_dtype = "bfloat16"
self.rtol = 5e-2
self.atol = 5e-2
self.eps = 1e-3
@@ -223,9 +236,9 @@ class TestTransformerSpFp8(TestTransformerSp):
self.ffn_hidden_size = 4096
self.q_seqlen = 128
self.kv_seqlen = 128
- self.mask_type = 'padding'
- self.layer_type = 'encoder'
- self.global_dtype = 'bfloat16'
+ self.mask_type = "padding"
+ self.layer_type = "encoder"
+ self.global_dtype = "bfloat16"
self.rtol = 5e-2
self.atol = 0.5
self.eps = 1e-3
@@ -233,5 +246,5 @@ class TestTransformerSpFp8(TestTransformerSp):
self.sequence_parallel = True
-if __name__ == '__main__':
+if __name__ == "__main__":
unittest.main()
diff --git a/tests/paddle/recompute_tests/recompute_transformer_encoder.py b/tests/paddle/recompute_tests/recompute_transformer_encoder.py
index 8acf23befd609706b48c52c6148acf16528705f5..56d0c2453558025e81f9e7cb7ba1083be25b18e9 100644
--- a/tests/paddle/recompute_tests/recompute_transformer_encoder.py
+++ b/tests/paddle/recompute_tests/recompute_transformer_encoder.py
@@ -37,14 +37,17 @@ def main():
enable_recompute = int(sys.argv[1])
use_reentrant = int(sys.argv[2])
- layers = paddle.nn.LayerList([
- te.TransformerLayer(
- hidden_size,
- ffn_hidden_size,
- num_heads,
- layer_type='encoder',
- ) for _ in range(num_layers)
- ])
+ layers = paddle.nn.LayerList(
+ [
+ te.TransformerLayer(
+ hidden_size,
+ ffn_hidden_size,
+ num_heads,
+ layer_type="encoder",
+ )
+ for _ in range(num_layers)
+ ]
+ )
model = Net(layers)
optimizer = paddle.optimizer.AdamW(learning_rate=0.001, parameters=model.parameters())
@@ -52,7 +55,7 @@ def main():
for _ in range(10):
inp = paddle.uniform([batch_size, q_seqlen, hidden_size])
inp.stop_gradient = False
- mask = paddle.zeros(shape=(batch_size, 1, q_seqlen, kv_seqlen), dtype='bool')
+ mask = paddle.zeros(shape=(batch_size, 1, q_seqlen, kv_seqlen), dtype="bool")
with te.fp8_autocast(enabled=True):
out = model(inp, mask, enable_recompute, use_reentrant)
loss = out.mean()
diff --git a/tests/paddle/test_install.py b/tests/paddle/test_install.py
index b6e1e4673f36b6b40934e143d26958bab636962e..686771ec09e8e7a80fc5b22d4373389429d957c9 100644
--- a/tests/paddle/test_install.py
+++ b/tests/paddle/test_install.py
@@ -8,4 +8,4 @@ def test_import():
"""
Test if Paddle extension can be imported normally
"""
- import transformer_engine.paddle # pylint: disable=unused-import
+ import transformer_engine.paddle # pylint: disable=unused-import
diff --git a/tests/paddle/test_layers.py b/tests/paddle/test_layers.py
index 1aeafe030ace36981a8c8b5f4520a26d83ef5557..6a985d7e86675476266f1466164199d784b02f54 100644
--- a/tests/paddle/test_layers.py
+++ b/tests/paddle/test_layers.py
@@ -26,14 +26,14 @@ def setup():
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
-@pytest.mark.parametrize('use_fp8', [True, False])
+@pytest.mark.parametrize("use_fp8", [True, False])
def test_checkpoint(use_fp8):
"""Test checkpoint save / load"""
bs = 16
in_features = 16
out_features = 32
file_name = "model.pdparams"
- input_tensor = paddle.uniform(shape=(bs, in_features), dtype='float32')
+ input_tensor = paddle.uniform(shape=(bs, in_features), dtype="float32")
model = te.Linear(in_features, out_features)
model_loaded = te.Linear(in_features, out_features)
# Populate amax_history
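The `# Populate amax_history` step matters because an FP8 checkpoint is only interesting once the scaling state holds real values. A short self-contained sketch of the populate-save-load cycle this test exercises, using only calls visible in this diff plus standard `paddle.save`/`paddle.load`:

    import paddle
    import transformer_engine.paddle as te

    model = te.Linear(16, 32)
    x = paddle.uniform(shape=(16, 16), dtype="float32")
    for _ in range(4):  # a few FP8 steps so amax_history holds real statistics
        with te.fp8_autocast(enabled=True):
            _ = model(x)

    # The test's save/load round trip implies the FP8 scaling state
    # travels with the ordinary state dict.
    paddle.save(model.state_dict(), "model.pdparams")
    model_loaded = te.Linear(16, 32)
    model_loaded.set_state_dict(paddle.load("model.pdparams"))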
@@ -91,15 +91,18 @@ class TestLinear:
"""
@staticmethod
- @pytest.mark.skipif(paddle.device.cuda.get_device_capability() < (8, 0),
- reason="BF16 Linear requires Ampere+ GPU")
- @pytest.mark.parametrize('bs,in_features,out_features', LINEAR_CASES)
- @pytest.mark.parametrize('has_bias,no_dbias', [[True, False], [True, True], [False, False]])
- @pytest.mark.parametrize('no_dgrad', [True, False])
- @pytest.mark.parametrize('no_wgrad', [True, False])
- @pytest.mark.parametrize('activation_dtype', ['bfloat16', 'float32'])
- def test_linear_bf16(bs, in_features, out_features, has_bias, no_dbias, no_dgrad, no_wgrad,
- activation_dtype):
+ @pytest.mark.skipif(
+ paddle.device.cuda.get_device_capability() < (8, 0),
+ reason="BF16 Linear requires Ampere+ GPU",
+ )
+ @pytest.mark.parametrize("bs,in_features,out_features", LINEAR_CASES)
+ @pytest.mark.parametrize("has_bias,no_dbias", [[True, False], [True, True], [False, False]])
+ @pytest.mark.parametrize("no_dgrad", [True, False])
+ @pytest.mark.parametrize("no_wgrad", [True, False])
+ @pytest.mark.parametrize("activation_dtype", ["bfloat16", "float32"])
+ def test_linear_bf16(
+ bs, in_features, out_features, has_bias, no_dbias, no_dgrad, no_wgrad, activation_dtype
+ ):
"""
Test BF16 Linear
"""
@@ -112,10 +115,9 @@ class TestLinear:
paddle.set_default_dtype(activation_dtype)
layer_te = te.Linear(in_features, out_features, bias_attr=None if has_bias else False)
- layer_pd = te.Linear(in_features,
- out_features,
- bias_attr=None if has_bias else False,
- backend='paddle')
+ layer_pd = te.Linear(
+ in_features, out_features, bias_attr=None if has_bias else False, backend="paddle"
+ )
layer_pd.weight.copy_(layer_te.weight.T, True)
if has_bias:
layer_pd.bias.copy_(layer_te.bias, True)
@@ -139,15 +141,25 @@ class TestLinear:
@staticmethod
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
- @pytest.mark.parametrize('bs,in_features,out_features', LINEAR_CASES)
- @pytest.mark.parametrize('has_bias,no_dbias', [[True, False], [True, True], [False, False]])
- @pytest.mark.parametrize('no_dgrad', [True, False])
- @pytest.mark.parametrize('no_wgrad', [True, False])
- @pytest.mark.parametrize('fp8_wgrad', [True, False])
- @pytest.mark.parametrize('do_calibration', [True, False])
- @pytest.mark.parametrize('activation_dtype', ['bfloat16', 'float32'])
- def test_linear_fp8(bs, in_features, out_features, has_bias, no_dbias, no_dgrad, no_wgrad,
- fp8_wgrad, do_calibration, activation_dtype):
+ @pytest.mark.parametrize("bs,in_features,out_features", LINEAR_CASES)
+ @pytest.mark.parametrize("has_bias,no_dbias", [[True, False], [True, True], [False, False]])
+ @pytest.mark.parametrize("no_dgrad", [True, False])
+ @pytest.mark.parametrize("no_wgrad", [True, False])
+ @pytest.mark.parametrize("fp8_wgrad", [True, False])
+ @pytest.mark.parametrize("do_calibration", [True, False])
+ @pytest.mark.parametrize("activation_dtype", ["bfloat16", "float32"])
+ def test_linear_fp8(
+ bs,
+ in_features,
+ out_features,
+ has_bias,
+ no_dbias,
+ no_dgrad,
+ no_wgrad,
+ fp8_wgrad,
+ do_calibration,
+ activation_dtype,
+ ):
"""
Test FP8 Linear
"""
@@ -170,7 +182,7 @@ class TestLinear:
in_features=in_features,
out_features=out_features,
bias_attr=None if has_bias else False,
- backend='paddle',
+ backend="paddle",
)
layer_pd.weight.copy_(layer_te.weight.T, True)
if has_bias:
@@ -182,8 +194,9 @@ class TestLinear:
layer_te.bias.stop_gradient = no_dbias
layer_pd.bias.stop_gradient = no_dbias
- with fp8_autocast(enabled=not do_calibration, calibrating=do_calibration,
- fp8_recipe=recipe):
+ with fp8_autocast(
+ enabled=not do_calibration, calibrating=do_calibration, fp8_recipe=recipe
+ ):
out_ref, grad_input_ref = calc_output_and_grad(layer_pd, input_tensor, grad_out)
out, grad_input = calc_output_and_grad(layer_te, input_tensor, grad_out)
@@ -199,9 +212,9 @@ class TestLinear:
@staticmethod
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
- @pytest.mark.parametrize('bs,in_features,out_features', LINEAR_CASES)
- @pytest.mark.parametrize('activation_dtype', ['bfloat16'])
- @pytest.mark.parametrize('num_microbatch', [8])
+ @pytest.mark.parametrize("bs,in_features,out_features", LINEAR_CASES)
+ @pytest.mark.parametrize("activation_dtype", ["bfloat16"])
+ @pytest.mark.parametrize("num_microbatch", [8])
def test_linear_fp8_microbatch(bs, in_features, out_features, activation_dtype, num_microbatch):
"""
-        Test FP8 Linear
+        Test FP8 Linear with weight caching over microbatches
@@ -236,17 +249,16 @@ class TestLinear:
out_ref.backward(grad_out)
assert_allclose(out, out_ref, rtol=rtol, atol=atol)
- assert_allclose(layer_cached.weight.grad,
- layer_normal.weight.grad,
- rtol=rtol,
- atol=atol)
+ assert_allclose(
+ layer_cached.weight.grad, layer_normal.weight.grad, rtol=rtol, atol=atol
+ )
-@pytest.mark.parametrize('bs,hidden_size', NORM_CASES)
-@pytest.mark.parametrize('has_bias,no_dbias', [[True, False], [True, True], [False, False]])
-@pytest.mark.parametrize('no_dgrad', [True, False])
-@pytest.mark.parametrize('no_wgrad', [True, False])
-@pytest.mark.parametrize('activation_dtype', ['bfloat16', 'float32'])
+@pytest.mark.parametrize("bs,hidden_size", NORM_CASES)
+@pytest.mark.parametrize("has_bias,no_dbias", [[True, False], [True, True], [False, False]])
+@pytest.mark.parametrize("no_dgrad", [True, False])
+@pytest.mark.parametrize("no_wgrad", [True, False])
+@pytest.mark.parametrize("activation_dtype", ["bfloat16", "float32"])
def test_layernorm_bf16(bs, hidden_size, has_bias, no_dbias, no_dgrad, no_wgrad, activation_dtype):
"""
Test BF16 LayerNorm
@@ -261,10 +273,9 @@ def test_layernorm_bf16(bs, hidden_size, has_bias, no_dbias, no_dgrad, no_wgrad,
paddle.set_default_dtype(activation_dtype)
layer_te = te.LayerNorm(hidden_size=hidden_size, eps=eps, bias_attr=None if has_bias else False)
- layer_pd = te.LayerNorm(hidden_size=hidden_size,
- eps=eps,
- bias_attr=None if has_bias else False,
- backend='paddle')
+ layer_pd = te.LayerNorm(
+ hidden_size=hidden_size, eps=eps, bias_attr=None if has_bias else False, backend="paddle"
+ )
layer_pd.weight.copy_(layer_te.weight, True)
if has_bias:
layer_pd.bias.copy_(layer_te.bias, True)
@@ -293,17 +304,29 @@ class TestLayerNormLinear:
"""
@staticmethod
- @pytest.mark.skipif(paddle.device.cuda.get_device_capability() < (8, 0),
- reason="BF16 Linear requires Ampere+ GPU")
- @pytest.mark.parametrize('bs,in_features,out_features', LINEAR_CASES)
- @pytest.mark.parametrize('has_bias,no_dbias', [[True, False], [True, True], [False, False]])
- @pytest.mark.parametrize('no_dgrad', [True, False])
- @pytest.mark.parametrize('no_wgrad', [True, False])
- @pytest.mark.parametrize('return_ln_out', [True, False])
- @pytest.mark.parametrize('activation_dtype', ['bfloat16', 'float32'])
- @pytest.mark.parametrize('normalization', ['RMSNorm', 'LayerNorm'])
- def test_layernorm_linear_bf16(bs, in_features, out_features, has_bias, no_dbias, no_dgrad,
- no_wgrad, return_ln_out, activation_dtype, normalization):
+ @pytest.mark.skipif(
+ paddle.device.cuda.get_device_capability() < (8, 0),
+ reason="BF16 Linear requires Ampere+ GPU",
+ )
+ @pytest.mark.parametrize("bs,in_features,out_features", LINEAR_CASES)
+ @pytest.mark.parametrize("has_bias,no_dbias", [[True, False], [True, True], [False, False]])
+ @pytest.mark.parametrize("no_dgrad", [True, False])
+ @pytest.mark.parametrize("no_wgrad", [True, False])
+ @pytest.mark.parametrize("return_ln_out", [True, False])
+ @pytest.mark.parametrize("activation_dtype", ["bfloat16", "float32"])
+ @pytest.mark.parametrize("normalization", ["RMSNorm", "LayerNorm"])
+ def test_layernorm_linear_bf16(
+ bs,
+ in_features,
+ out_features,
+ has_bias,
+ no_dbias,
+ no_dgrad,
+ no_wgrad,
+ return_ln_out,
+ activation_dtype,
+ normalization,
+ ):
"""
Test BF16 LayerNormLinear Layer
"""
@@ -315,7 +338,7 @@ class TestLayerNormLinear:
input_tensor.stop_gradient = no_dgrad
grad_out = paddle.uniform(shape=(bs, out_features), dtype=activation_dtype)
eps = 1e-3
- has_ln_bias = normalization == 'LayerNorm'
+ has_ln_bias = normalization == "LayerNorm"
layer_te = te.LayerNormLinear(
in_features=in_features,
@@ -333,7 +356,7 @@ class TestLayerNormLinear:
normalization=normalization,
bias_attr=None if has_bias else False,
return_layernorm_output=return_ln_out,
- backend='paddle',
+ backend="paddle",
)
layer_pd.ln_weight.copy_(layer_te.ln_weight, True)
@@ -355,11 +378,11 @@ class TestLayerNormLinear:
layer_pd.bias.stop_gradient = no_dbias
out_ref, ln_out_ref, grad_input_ref = calc_output_and_grad_ln_out(
- layer_pd, input_tensor, grad_out, return_ln_out=return_ln_out)
- out, ln_out, grad_input = calc_output_and_grad_ln_out(layer_te,
- input_tensor,
- grad_out,
- return_ln_out=return_ln_out)
+ layer_pd, input_tensor, grad_out, return_ln_out=return_ln_out
+ )
+ out, ln_out, grad_input = calc_output_and_grad_ln_out(
+ layer_te, input_tensor, grad_out, return_ln_out=return_ln_out
+ )
assert_allclose(out, out_ref, rtol=rtol, atol=atol)
if not no_dgrad:
@@ -377,18 +400,29 @@ class TestLayerNormLinear:
@staticmethod
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
- @pytest.mark.parametrize('bs,in_features,out_features', LINEAR_CASES)
- @pytest.mark.parametrize('has_bias,no_dbias', [[True, False], [True, True], [False, False]])
- @pytest.mark.parametrize('no_dgrad', [True, False])
- @pytest.mark.parametrize('no_wgrad', [True, False])
- @pytest.mark.parametrize('fp8_wgrad', [True, False])
- @pytest.mark.parametrize('do_calibration', [True, False])
- @pytest.mark.parametrize('return_ln_out', [True, False])
- @pytest.mark.parametrize('activation_dtype', ['bfloat16', 'float32'])
- @pytest.mark.parametrize('normalization', ['RMSNorm', 'LayerNorm'])
- def test_layernorm_linear_fp8(bs, in_features, out_features, has_bias, no_dbias, no_dgrad,
- no_wgrad, fp8_wgrad, do_calibration, return_ln_out,
- activation_dtype, normalization):
+ @pytest.mark.parametrize("bs,in_features,out_features", LINEAR_CASES)
+ @pytest.mark.parametrize("has_bias,no_dbias", [[True, False], [True, True], [False, False]])
+ @pytest.mark.parametrize("no_dgrad", [True, False])
+ @pytest.mark.parametrize("no_wgrad", [True, False])
+ @pytest.mark.parametrize("fp8_wgrad", [True, False])
+ @pytest.mark.parametrize("do_calibration", [True, False])
+ @pytest.mark.parametrize("return_ln_out", [True, False])
+ @pytest.mark.parametrize("activation_dtype", ["bfloat16", "float32"])
+ @pytest.mark.parametrize("normalization", ["RMSNorm", "LayerNorm"])
+ def test_layernorm_linear_fp8(
+ bs,
+ in_features,
+ out_features,
+ has_bias,
+ no_dbias,
+ no_dgrad,
+ no_wgrad,
+ fp8_wgrad,
+ do_calibration,
+ return_ln_out,
+ activation_dtype,
+ normalization,
+ ):
"""
Test FP8 LayerNormLinear Layer
"""
@@ -400,7 +434,7 @@ class TestLayerNormLinear:
input_tensor.stop_gradient = no_dgrad
grad_out = paddle.uniform(shape=(bs, out_features), dtype=activation_dtype)
eps = 1e-3
- has_ln_bias = normalization == 'LayerNorm'
+ has_ln_bias = normalization == "LayerNorm"
recipe = DelayedScaling(override_linear_precision=(False, False, not fp8_wgrad))
@@ -420,7 +454,7 @@ class TestLayerNormLinear:
normalization=normalization,
bias_attr=None if has_bias else False,
return_layernorm_output=return_ln_out,
- backend='paddle',
+ backend="paddle",
)
layer_pd.ln_weight.copy_(layer_te.ln_weight, True)
@@ -441,14 +475,15 @@ class TestLayerNormLinear:
layer_te.bias.stop_gradient = no_dbias
layer_pd.bias.stop_gradient = no_dbias
- with fp8_autocast(enabled=not do_calibration, calibrating=do_calibration,
- fp8_recipe=recipe):
+ with fp8_autocast(
+ enabled=not do_calibration, calibrating=do_calibration, fp8_recipe=recipe
+ ):
out_ref, ln_out_ref, grad_input_ref = calc_output_and_grad_ln_out(
- layer_pd, input_tensor, grad_out, return_ln_out=return_ln_out)
- out, ln_out, grad_input = calc_output_and_grad_ln_out(layer_te,
- input_tensor,
- grad_out,
- return_ln_out=return_ln_out)
+ layer_pd, input_tensor, grad_out, return_ln_out=return_ln_out
+ )
+ out, ln_out, grad_input = calc_output_and_grad_ln_out(
+ layer_te, input_tensor, grad_out, return_ln_out=return_ln_out
+ )
assert_allclose(out, out_ref, rtol=rtol, atol=atol)
if not no_dgrad:
@@ -468,11 +503,12 @@ class TestLayerNormLinear:
@staticmethod
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
- @pytest.mark.parametrize('bs,in_features,out_features', LINEAR_CASES)
- @pytest.mark.parametrize('activation_dtype', ['bfloat16'])
- @pytest.mark.parametrize('num_microbatch', [8])
- def test_layernorm_linear_fp8_microbatch(bs, in_features, out_features, activation_dtype,
- num_microbatch):
+ @pytest.mark.parametrize("bs,in_features,out_features", LINEAR_CASES)
+ @pytest.mark.parametrize("activation_dtype", ["bfloat16"])
+ @pytest.mark.parametrize("num_microbatch", [8])
+ def test_layernorm_linear_fp8_microbatch(
+ bs, in_features, out_features, activation_dtype, num_microbatch
+ ):
"""
-        Test FP8 LayerNormLinear Layer
+        Test FP8 LayerNormLinear Layer with weight caching over microbatches
"""
@@ -513,14 +549,12 @@ class TestLayerNormLinear:
out_ref.backward(grad_out)
assert_allclose(out, out_ref, rtol=rtol, atol=atol)
- assert_allclose(layer_cached.weight.grad,
- layer_normal.weight.grad,
- rtol=rtol,
- atol=atol)
- assert_allclose(layer_cached.ln_weight.grad,
- layer_normal.ln_weight.grad,
- rtol=rtol,
- atol=atol)
+ assert_allclose(
+ layer_cached.weight.grad, layer_normal.weight.grad, rtol=rtol, atol=atol
+ )
+ assert_allclose(
+ layer_cached.ln_weight.grad, layer_normal.ln_weight.grad, rtol=rtol, atol=atol
+ )
class TestLayerNormMLP:
@@ -529,19 +563,31 @@ class TestLayerNormMLP:
"""
@staticmethod
- @pytest.mark.skipif(paddle.device.cuda.get_device_capability() < (8, 0),
- reason="BF16 Linear requires Ampere+ GPU")
- @pytest.mark.parametrize('bs,hidden_size,ffn_hidden_size', LINEAR_CASES)
- @pytest.mark.parametrize('has_bias,no_dbias', [[True, False], [True, True], [False, False]])
- @pytest.mark.parametrize('no_dgrad', [True, False])
- @pytest.mark.parametrize('no_wgrad', [True, False])
- @pytest.mark.parametrize('return_ln_out', [True, False])
- @pytest.mark.parametrize('activation_dtype', ['bfloat16', 'float32'])
- @pytest.mark.parametrize('normalization', ['RMSNorm', 'LayerNorm'])
- @pytest.mark.parametrize('activation', ['gelu', 'swiglu'])
- def test_layernorm_mlp_bf16(bs, hidden_size, ffn_hidden_size, has_bias, no_dbias, no_dgrad,
- no_wgrad, return_ln_out, activation_dtype, normalization,
- activation):
+ @pytest.mark.skipif(
+ paddle.device.cuda.get_device_capability() < (8, 0),
+ reason="BF16 Linear requires Ampere+ GPU",
+ )
+ @pytest.mark.parametrize("bs,hidden_size,ffn_hidden_size", LINEAR_CASES)
+ @pytest.mark.parametrize("has_bias,no_dbias", [[True, False], [True, True], [False, False]])
+ @pytest.mark.parametrize("no_dgrad", [True, False])
+ @pytest.mark.parametrize("no_wgrad", [True, False])
+ @pytest.mark.parametrize("return_ln_out", [True, False])
+ @pytest.mark.parametrize("activation_dtype", ["bfloat16", "float32"])
+ @pytest.mark.parametrize("normalization", ["RMSNorm", "LayerNorm"])
+ @pytest.mark.parametrize("activation", ["gelu", "swiglu"])
+ def test_layernorm_mlp_bf16(
+ bs,
+ hidden_size,
+ ffn_hidden_size,
+ has_bias,
+ no_dbias,
+ no_dgrad,
+ no_wgrad,
+ return_ln_out,
+ activation_dtype,
+ normalization,
+ activation,
+ ):
"""
-        Tests for TestLayerNormMLP layer
+        Test BF16 LayerNormMLP Layer
"""
@@ -553,7 +599,7 @@ class TestLayerNormMLP:
input_tensor.stop_gradient = no_dgrad
grad_out = paddle.uniform(shape=(bs, hidden_size), dtype=activation_dtype)
eps = 1e-3
- has_ln_bias = normalization == 'LayerNorm'
+ has_ln_bias = normalization == "LayerNorm"
layer_te = te.LayerNormMLP(
hidden_size=hidden_size,
@@ -572,7 +618,7 @@ class TestLayerNormMLP:
activation=activation,
bias_attr=None if has_bias else False,
return_layernorm_output=return_ln_out,
- backend='paddle',
+ backend="paddle",
)
layer_pd.ln_weight.copy_(layer_te.ln_weight, True)
if has_ln_bias:
@@ -599,55 +645,63 @@ class TestLayerNormMLP:
layer_pd.fc2_bias.stop_gradient = no_dbias
out_ref, ln_out_ref, grad_input_ref = calc_output_and_grad_ln_out(
- layer_pd, input_tensor, grad_out, return_ln_out=return_ln_out)
- out, ln_out, grad_input = calc_output_and_grad_ln_out(layer_te,
- input_tensor,
- grad_out,
- return_ln_out=return_ln_out)
+ layer_pd, input_tensor, grad_out, return_ln_out=return_ln_out
+ )
+ out, ln_out, grad_input = calc_output_and_grad_ln_out(
+ layer_te, input_tensor, grad_out, return_ln_out=return_ln_out
+ )
assert_allclose(out, out_ref, rtol=rtol, atol=atol)
if not no_dgrad:
assert_allclose(grad_input, grad_input_ref, rtol=rtol, atol=atol)
if not no_wgrad:
assert_allclose(layer_te.ln_weight.grad, layer_pd.ln_weight.grad, rtol=rtol, atol=atol)
- assert_allclose(layer_te.fc1_weight.grad,
- layer_pd.fc1_weight.grad.T,
- rtol=rtol,
- atol=atol)
- assert_allclose(layer_te.fc2_weight.grad,
- layer_pd.fc2_weight.grad.T,
- rtol=rtol,
- atol=atol)
+ assert_allclose(
+ layer_te.fc1_weight.grad, layer_pd.fc1_weight.grad.T, rtol=rtol, atol=atol
+ )
+ assert_allclose(
+ layer_te.fc2_weight.grad, layer_pd.fc2_weight.grad.T, rtol=rtol, atol=atol
+ )
if not no_dbias:
if has_ln_bias:
assert_allclose(layer_te.ln_bias.grad, layer_pd.ln_bias.grad, rtol=rtol, atol=atol)
if has_bias:
- assert_allclose(layer_te.fc1_bias.grad,
- layer_pd.fc1_bias.grad,
- rtol=rtol,
- atol=atol)
- assert_allclose(layer_te.fc2_bias.grad,
- layer_pd.fc2_bias.grad,
- rtol=rtol,
- atol=atol)
+ assert_allclose(
+ layer_te.fc1_bias.grad, layer_pd.fc1_bias.grad, rtol=rtol, atol=atol
+ )
+ assert_allclose(
+ layer_te.fc2_bias.grad, layer_pd.fc2_bias.grad, rtol=rtol, atol=atol
+ )
if return_ln_out:
assert_allclose(ln_out, ln_out_ref, rtol=rtol, atol=atol)
@staticmethod
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
- @pytest.mark.parametrize('bs,hidden_size,ffn_hidden_size', LINEAR_CASES)
- @pytest.mark.parametrize('has_bias,no_dbias', [[True, False], [True, True], [False, False]])
- @pytest.mark.parametrize('no_dgrad', [True, False])
- @pytest.mark.parametrize('no_wgrad', [True, False])
- @pytest.mark.parametrize('fp8_wgrad', [True, False])
- @pytest.mark.parametrize('do_calibration', [True, False])
- @pytest.mark.parametrize('return_ln_out', [True, False])
- @pytest.mark.parametrize('activation_dtype', ['bfloat16', 'float32'])
- @pytest.mark.parametrize('normalization', ['RMSNorm', 'LayerNorm'])
- @pytest.mark.parametrize('activation', ['gelu', 'swiglu'])
- def test_layernorm_mlp_fp8(bs, hidden_size, ffn_hidden_size, has_bias, no_dbias, no_dgrad,
- no_wgrad, fp8_wgrad, do_calibration, return_ln_out, activation_dtype,
- normalization, activation):
+ @pytest.mark.parametrize("bs,hidden_size,ffn_hidden_size", LINEAR_CASES)
+ @pytest.mark.parametrize("has_bias,no_dbias", [[True, False], [True, True], [False, False]])
+ @pytest.mark.parametrize("no_dgrad", [True, False])
+ @pytest.mark.parametrize("no_wgrad", [True, False])
+ @pytest.mark.parametrize("fp8_wgrad", [True, False])
+ @pytest.mark.parametrize("do_calibration", [True, False])
+ @pytest.mark.parametrize("return_ln_out", [True, False])
+ @pytest.mark.parametrize("activation_dtype", ["bfloat16", "float32"])
+ @pytest.mark.parametrize("normalization", ["RMSNorm", "LayerNorm"])
+ @pytest.mark.parametrize("activation", ["gelu", "swiglu"])
+ def test_layernorm_mlp_fp8(
+ bs,
+ hidden_size,
+ ffn_hidden_size,
+ has_bias,
+ no_dbias,
+ no_dgrad,
+ no_wgrad,
+ fp8_wgrad,
+ do_calibration,
+ return_ln_out,
+ activation_dtype,
+ normalization,
+ activation,
+ ):
"""
Test FP8 LayerNormMLP Layer
"""
@@ -659,7 +713,7 @@ class TestLayerNormMLP:
input_tensor.stop_gradient = no_dgrad
grad_out = paddle.uniform(shape=(bs, hidden_size), dtype=activation_dtype)
eps = 1e-3
- has_ln_bias = normalization == 'LayerNorm'
+ has_ln_bias = normalization == "LayerNorm"
recipe = DelayedScaling(override_linear_precision=(False, False, not fp8_wgrad))
@@ -681,7 +735,7 @@ class TestLayerNormMLP:
activation=activation,
bias_attr=None if has_bias else False,
return_layernorm_output=return_ln_out,
- backend='paddle',
+ backend="paddle",
)
layer_pd.ln_weight.copy_(layer_te.ln_weight, True)
if has_ln_bias:
@@ -707,40 +761,37 @@ class TestLayerNormMLP:
layer_pd.fc1_bias.stop_gradient = no_dbias
layer_pd.fc2_bias.stop_gradient = no_dbias
- with fp8_autocast(enabled=not do_calibration, calibrating=do_calibration,
- fp8_recipe=recipe):
+ with fp8_autocast(
+ enabled=not do_calibration, calibrating=do_calibration, fp8_recipe=recipe
+ ):
out_ref, ln_out_ref, grad_input_ref = calc_output_and_grad_ln_out(
- layer_pd, input_tensor, grad_out, return_ln_out=return_ln_out)
- out, ln_out, grad_input = calc_output_and_grad_ln_out(layer_te,
- input_tensor,
- grad_out,
- return_ln_out=return_ln_out)
+ layer_pd, input_tensor, grad_out, return_ln_out=return_ln_out
+ )
+ out, ln_out, grad_input = calc_output_and_grad_ln_out(
+ layer_te, input_tensor, grad_out, return_ln_out=return_ln_out
+ )
assert_allclose(out, out_ref, rtol=rtol, atol=atol)
if not no_dgrad:
assert_allclose(grad_input, grad_input_ref, rtol=rtol, atol=atol)
if not no_wgrad:
assert_allclose(layer_te.ln_weight.grad, layer_pd.ln_weight.grad, rtol=rtol, atol=atol)
- assert_allclose(layer_te.fc1_weight.grad,
- layer_pd.fc1_weight.grad.T,
- rtol=rtol,
- atol=atol)
- assert_allclose(layer_te.fc2_weight.grad,
- layer_pd.fc2_weight.grad.T,
- rtol=rtol,
- atol=atol)
+ assert_allclose(
+ layer_te.fc1_weight.grad, layer_pd.fc1_weight.grad.T, rtol=rtol, atol=atol
+ )
+ assert_allclose(
+ layer_te.fc2_weight.grad, layer_pd.fc2_weight.grad.T, rtol=rtol, atol=atol
+ )
if not no_dbias:
if has_ln_bias:
assert_allclose(layer_te.ln_bias.grad, layer_pd.ln_bias.grad, rtol=rtol, atol=atol)
if has_bias:
- assert_allclose(layer_te.fc1_bias.grad,
- layer_pd.fc1_bias.grad,
- rtol=rtol,
- atol=atol)
- assert_allclose(layer_te.fc2_bias.grad,
- layer_pd.fc2_bias.grad,
- rtol=rtol,
- atol=atol)
+ assert_allclose(
+ layer_te.fc1_bias.grad, layer_pd.fc1_bias.grad, rtol=rtol, atol=atol
+ )
+ assert_allclose(
+ layer_te.fc2_bias.grad, layer_pd.fc2_bias.grad, rtol=rtol, atol=atol
+ )
if return_ln_out:
assert_allclose(ln_out, ln_out_ref, rtol=rtol, atol=atol)
@@ -749,11 +800,12 @@ class TestLayerNormMLP:
@staticmethod
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
- @pytest.mark.parametrize('bs,hidden_size,ffn_hidden_size', LINEAR_CASES)
- @pytest.mark.parametrize('activation_dtype', ['bfloat16'])
- @pytest.mark.parametrize('num_microbatch', [8])
- def test_layernorm_mlp_fp8_microbatch(bs, hidden_size, ffn_hidden_size, activation_dtype,
- num_microbatch):
+ @pytest.mark.parametrize("bs,hidden_size,ffn_hidden_size", LINEAR_CASES)
+ @pytest.mark.parametrize("activation_dtype", ["bfloat16"])
+ @pytest.mark.parametrize("num_microbatch", [8])
+ def test_layernorm_mlp_fp8_microbatch(
+ bs, hidden_size, ffn_hidden_size, activation_dtype, num_microbatch
+ ):
"""
-        Test FP8 LayerNormMLP Layer
+        Test FP8 LayerNormMLP Layer with weight caching over microbatches
"""
@@ -803,28 +855,26 @@ class TestLayerNormMLP:
out_ref.backward(grad_out)
assert_allclose(out, out_ref, rtol=rtol, atol=atol)
- assert_allclose(layer_cached.ln_weight.grad,
- layer_normal.ln_weight.grad,
- rtol=rtol,
- atol=atol)
- assert_allclose(layer_cached.fc1_weight.grad,
- layer_normal.fc1_weight.grad,
- rtol=rtol,
- atol=atol)
- assert_allclose(layer_cached.fc2_weight.grad,
- layer_normal.fc2_weight.grad,
- rtol=rtol,
- atol=atol)
-
-
-@pytest.mark.parametrize('bs', [1, 2])
-@pytest.mark.parametrize('hidden_size, num_heads', [[1024, 16]])
-@pytest.mark.parametrize('q_seqlen, kv_seqlen', [[1024, 1024]])
-@pytest.mark.parametrize('attn_type', ['self', 'cross'])
-@pytest.mark.parametrize('mask_type', ['causal', 'padding'])
-@pytest.mark.parametrize('math_dtype', ['bfloat16', 'float16'])
-def test_dot_product_attention(bs, hidden_size, num_heads, q_seqlen, kv_seqlen, attn_type,
- mask_type, math_dtype):
+ assert_allclose(
+ layer_cached.ln_weight.grad, layer_normal.ln_weight.grad, rtol=rtol, atol=atol
+ )
+ assert_allclose(
+ layer_cached.fc1_weight.grad, layer_normal.fc1_weight.grad, rtol=rtol, atol=atol
+ )
+ assert_allclose(
+ layer_cached.fc2_weight.grad, layer_normal.fc2_weight.grad, rtol=rtol, atol=atol
+ )
+
+
+@pytest.mark.parametrize("bs", [1, 2])
+@pytest.mark.parametrize("hidden_size, num_heads", [[1024, 16]])
+@pytest.mark.parametrize("q_seqlen, kv_seqlen", [[1024, 1024]])
+@pytest.mark.parametrize("attn_type", ["self", "cross"])
+@pytest.mark.parametrize("mask_type", ["causal", "padding"])
+@pytest.mark.parametrize("math_dtype", ["bfloat16", "float16"])
+def test_dot_product_attention(
+ bs, hidden_size, num_heads, q_seqlen, kv_seqlen, attn_type, mask_type, math_dtype
+):
"""
Test DotProductAttention Layer
"""
@@ -835,53 +885,64 @@ def test_dot_product_attention(bs, hidden_size, num_heads, q_seqlen, kv_seqlen,
# Skip if cuDNN fused attention is not supported
if not is_fused_attention_supported(
- num_heads=num_heads,
- num_gqa_groups=num_heads,
- q_seqlen=q_seqlen,
- kv_seqlen=kv_seqlen,
- head_size=head_size,
- dtype=math_dtype,
- dropout=0.0,
- qkv_layout="bshd_bshd_bshd",
- bias_type="no_bias",
- mask_type=mask_type,
+ num_heads=num_heads,
+ num_gqa_groups=num_heads,
+ q_seqlen=q_seqlen,
+ kv_seqlen=kv_seqlen,
+ head_size=head_size,
+ dtype=math_dtype,
+ dropout=0.0,
+ qkv_layout="bshd_bshd_bshd",
+ bias_type="no_bias",
+ mask_type=mask_type,
):
pytest.skip("cuDNN fused attention is not supported")
- attn_q_input = paddle.normal(mean=0.0, std=0.02,
- shape=(bs, q_seqlen, num_heads, head_size)).astype(math_dtype)
- attn_k_input = paddle.normal(mean=0.0, std=0.02,
- shape=(bs, kv_seqlen, num_heads, head_size)).astype(math_dtype)
- attn_v_input = paddle.normal(mean=0.0, std=0.02,
- shape=(bs, kv_seqlen, num_heads, head_size)).astype(math_dtype)
-
- q_actual_seqlen = paddle.randint(low=20, high=q_seqlen, shape=(bs,), dtype='int32')
- kv_actual_seqlen = paddle.randint(low=20, high=kv_seqlen, shape=(bs,),
- dtype='int32') if attn_type == 'cross' else q_actual_seqlen
- attn_mask = paddle.ones(shape=(bs, 1, q_seqlen, kv_seqlen), dtype='bool')
-
- grad_out = paddle.normal(mean=0.0, std=0.02,
- shape=(bs, q_seqlen, num_heads, head_size)).astype('float32')
+ attn_q_input = paddle.normal(
+ mean=0.0, std=0.02, shape=(bs, q_seqlen, num_heads, head_size)
+ ).astype(math_dtype)
+ attn_k_input = paddle.normal(
+ mean=0.0, std=0.02, shape=(bs, kv_seqlen, num_heads, head_size)
+ ).astype(math_dtype)
+ attn_v_input = paddle.normal(
+ mean=0.0, std=0.02, shape=(bs, kv_seqlen, num_heads, head_size)
+ ).astype(math_dtype)
+
+ q_actual_seqlen = paddle.randint(low=20, high=q_seqlen, shape=(bs,), dtype="int32")
+ kv_actual_seqlen = (
+ paddle.randint(low=20, high=kv_seqlen, shape=(bs,), dtype="int32")
+ if attn_type == "cross"
+ else q_actual_seqlen
+ )
+ attn_mask = paddle.ones(shape=(bs, 1, q_seqlen, kv_seqlen), dtype="bool")
+
+ grad_out = paddle.normal(mean=0.0, std=0.02, shape=(bs, q_seqlen, num_heads, head_size)).astype(
+ "float32"
+ )
for i in range(0, bs):
- grad_out[i, q_actual_seqlen[i]:, :, :] = 0
+ grad_out[i, q_actual_seqlen[i] :, :, :] = 0
grad_out = grad_out.astype(math_dtype)
for i in range(0, bs):
- attn_mask[i, 0, 0:q_actual_seqlen[i], 0:kv_actual_seqlen[i]] = False
+ attn_mask[i, 0, 0 : q_actual_seqlen[i], 0 : kv_actual_seqlen[i]] = False
head_size = hidden_size // num_heads
- layer_te = te.DotProductAttention(num_heads,
- head_size,
- attention_dropout=0.0,
- attn_mask_type=mask_type,
- attention_type=attn_type,
- backend='transformer_engine')
- layer_pd = te.DotProductAttention(num_heads,
- head_size,
- attention_dropout=0.0,
- attn_mask_type=mask_type,
- attention_type=attn_type,
- backend='paddle')
+ layer_te = te.DotProductAttention(
+ num_heads,
+ head_size,
+ attention_dropout=0.0,
+ attn_mask_type=mask_type,
+ attention_type=attn_type,
+ backend="transformer_engine",
+ )
+ layer_pd = te.DotProductAttention(
+ num_heads,
+ head_size,
+ attention_dropout=0.0,
+ attn_mask_type=mask_type,
+ attention_type=attn_type,
+ backend="paddle",
+ )
def calc_attn_output_and_grad(layer, q, k, v, mask, dout):
_q = paddle.to_tensor(q, stop_gradient=False)
@@ -892,23 +953,29 @@ def test_dot_product_attention(bs, hidden_size, num_heads, q_seqlen, kv_seqlen,
out.backward(dout)
return out, _q.grad, _k.grad, _v.grad
- out, q_grad, k_grad, v_grad = calc_attn_output_and_grad(layer_te, attn_q_input, attn_k_input,
- attn_v_input, attn_mask, grad_out)
+ out, q_grad, k_grad, v_grad = calc_attn_output_and_grad(
+ layer_te, attn_q_input, attn_k_input, attn_v_input, attn_mask, grad_out
+ )
out_ref, q_grad_ref, k_grad_ref, v_grad_ref = calc_attn_output_and_grad(
- layer_pd, attn_q_input, attn_k_input, attn_v_input, attn_mask, grad_out)
+ layer_pd, attn_q_input, attn_k_input, attn_v_input, attn_mask, grad_out
+ )
valid_out_ref = paddle.full_like(out_ref, 0)
for i in range(0, bs):
- valid_out_ref[i, 0:q_actual_seqlen[i], :, :] = out_ref[i, 0:q_actual_seqlen[i], :, :]
+ valid_out_ref[i, 0 : q_actual_seqlen[i], :, :] = out_ref[i, 0 : q_actual_seqlen[i], :, :]
valid_q_grad_ref = paddle.full_like(q_grad_ref, 0)
valid_k_grad_ref = paddle.full_like(k_grad_ref, 0)
valid_v_grad_ref = paddle.full_like(v_grad_ref, 0)
for i in range(0, bs):
- valid_q_grad_ref[i, 0:q_actual_seqlen[i], :, :] = q_grad_ref[i, 0:q_actual_seqlen[i], :, :]
- valid_k_grad_ref[i, 0:kv_actual_seqlen[i], :, :] = k_grad_ref[i,
- 0:kv_actual_seqlen[i], :, :]
- valid_v_grad_ref[i, 0:kv_actual_seqlen[i], :, :] = v_grad_ref[i,
- 0:kv_actual_seqlen[i], :, :]
+ valid_q_grad_ref[i, 0 : q_actual_seqlen[i], :, :] = q_grad_ref[
+ i, 0 : q_actual_seqlen[i], :, :
+ ]
+ valid_k_grad_ref[i, 0 : kv_actual_seqlen[i], :, :] = k_grad_ref[
+ i, 0 : kv_actual_seqlen[i], :, :
+ ]
+ valid_v_grad_ref[i, 0 : kv_actual_seqlen[i], :, :] = v_grad_ref[
+ i, 0 : kv_actual_seqlen[i], :, :
+ ]
assert_allclose(out, valid_out_ref, rtol=rtol, atol=atol)
assert_allclose(q_grad, valid_q_grad_ref, rtol=rtol, atol=atol)
@@ -916,21 +983,34 @@ def test_dot_product_attention(bs, hidden_size, num_heads, q_seqlen, kv_seqlen,
assert_allclose(v_grad, valid_v_grad_ref, rtol=rtol, atol=atol)
-@pytest.mark.parametrize('bs', [1, 2])
-@pytest.mark.parametrize('num_gqa_groups', [1, 2, 4])
-@pytest.mark.parametrize('hidden_size, num_heads, ffn_hidden_size', [[256, 4, 1024]])
-@pytest.mark.parametrize('q_seqlen, kv_seqlen', [[1024, 1024]])
-@pytest.mark.parametrize('has_bias, no_dbias', [[False, True], [True, True], [True, False]])
-@pytest.mark.parametrize('no_wgrad', [True, False])
-@pytest.mark.parametrize('mask_type', ['causal', 'padding'])
-@pytest.mark.parametrize('math_dtype', ['bfloat16', 'float16'])
-@pytest.mark.parametrize('output_layernorm', [True, False])
-@pytest.mark.parametrize('return_layernorm_output', [True, False])
-@pytest.mark.parametrize('normalization', ['RMSNorm', 'LayerNorm'])
-def test_transformer_encoder_layer(bs, hidden_size, num_heads, num_gqa_groups, ffn_hidden_size,
- has_bias, no_dbias, no_wgrad, q_seqlen, kv_seqlen, mask_type,
- math_dtype, output_layernorm, return_layernorm_output,
- normalization):
+@pytest.mark.parametrize("bs", [1, 2])
+@pytest.mark.parametrize("num_gqa_groups", [1, 2, 4])
+@pytest.mark.parametrize("hidden_size, num_heads, ffn_hidden_size", [[256, 4, 1024]])
+@pytest.mark.parametrize("q_seqlen, kv_seqlen", [[1024, 1024]])
+@pytest.mark.parametrize("has_bias, no_dbias", [[False, True], [True, True], [True, False]])
+@pytest.mark.parametrize("no_wgrad", [True, False])
+@pytest.mark.parametrize("mask_type", ["causal", "padding"])
+@pytest.mark.parametrize("math_dtype", ["bfloat16", "float16"])
+@pytest.mark.parametrize("output_layernorm", [True, False])
+@pytest.mark.parametrize("return_layernorm_output", [True, False])
+@pytest.mark.parametrize("normalization", ["RMSNorm", "LayerNorm"])
+def test_transformer_encoder_layer(
+ bs,
+ hidden_size,
+ num_heads,
+ num_gqa_groups,
+ ffn_hidden_size,
+ has_bias,
+ no_dbias,
+ no_wgrad,
+ q_seqlen,
+ kv_seqlen,
+ mask_type,
+ math_dtype,
+ output_layernorm,
+ return_layernorm_output,
+ normalization,
+):
"""
Test Transformer Encoder Layer
"""
@@ -938,68 +1018,73 @@ def test_transformer_encoder_layer(bs, hidden_size, num_heads, num_gqa_groups, f
rtol = 5e-2
atol = 5e-2
eps = 1e-3
- has_ln_bias = normalization == 'LayerNorm'
+ has_ln_bias = normalization == "LayerNorm"
# Skip if cuDNN fused attention is not supported
if not is_fused_attention_supported(
- num_heads=num_heads,
- num_gqa_groups=num_gqa_groups,
- q_seqlen=q_seqlen,
- kv_seqlen=kv_seqlen,
- head_size=hidden_size // num_heads,
- dtype=math_dtype,
- dropout=0.0,
- qkv_layout="bshd_bshd_bshd",
- bias_type="no_bias",
- mask_type=mask_type,
+ num_heads=num_heads,
+ num_gqa_groups=num_gqa_groups,
+ q_seqlen=q_seqlen,
+ kv_seqlen=kv_seqlen,
+ head_size=hidden_size // num_heads,
+ dtype=math_dtype,
+ dropout=0.0,
+ qkv_layout="bshd_bshd_bshd",
+ bias_type="no_bias",
+ mask_type=mask_type,
):
pytest.skip("cuDNN fused attention is not supported")
encoder_input = paddle.uniform(shape=(bs, q_seqlen, hidden_size), dtype=math_dtype)
- q_actual_seqlen = paddle.ones(shape=(bs,), dtype='int32') * q_seqlen
+ q_actual_seqlen = paddle.ones(shape=(bs,), dtype="int32") * q_seqlen
kv_actual_seqlen = q_actual_seqlen
- attn_mask = paddle.ones(shape=(bs, 1, q_seqlen, kv_seqlen), dtype='bool')
+ attn_mask = paddle.ones(shape=(bs, 1, q_seqlen, kv_seqlen), dtype="bool")
- grad_out = paddle.normal(mean=0.0, std=0.02,
- shape=(bs, q_seqlen, hidden_size)).astype('float32')
+ grad_out = paddle.normal(mean=0.0, std=0.02, shape=(bs, q_seqlen, hidden_size)).astype(
+ "float32"
+ )
for i in range(0, bs):
- grad_out[i, q_actual_seqlen[i]:, :] = 0
+ grad_out[i, q_actual_seqlen[i] :, :] = 0
grad_out = grad_out.astype(math_dtype)
for i in range(0, bs):
- attn_mask[i, 0, 0:q_actual_seqlen[i], 0:kv_actual_seqlen[i]] = False
-
- layer_te = te.TransformerLayer(hidden_size,
- ffn_hidden_size,
- num_heads,
- num_gqa_groups=num_gqa_groups,
- layernorm_epsilon=eps,
- hidden_dropout=0.0,
- attention_dropout=0.0,
- weight_attr=None,
- bias_attr=None if has_bias else False,
- self_attn_mask_type=mask_type,
- apply_residual_connection_post_layernorm=return_layernorm_output,
- output_layernorm=output_layernorm,
- layer_type='encoder',
- normalization=normalization,
- backend='transformer_engine')
- layer_pd = te.TransformerLayer(hidden_size,
- ffn_hidden_size,
- num_heads,
- num_gqa_groups=num_gqa_groups,
- layernorm_epsilon=eps,
- hidden_dropout=0.0,
- attention_dropout=0.0,
- weight_attr=None,
- bias_attr=None if has_bias else False,
- self_attn_mask_type=mask_type,
- apply_residual_connection_post_layernorm=return_layernorm_output,
- output_layernorm=output_layernorm,
- layer_type='encoder',
- normalization=normalization,
- backend='paddle')
+ attn_mask[i, 0, 0 : q_actual_seqlen[i], 0 : kv_actual_seqlen[i]] = False
+
+ layer_te = te.TransformerLayer(
+ hidden_size,
+ ffn_hidden_size,
+ num_heads,
+ num_gqa_groups=num_gqa_groups,
+ layernorm_epsilon=eps,
+ hidden_dropout=0.0,
+ attention_dropout=0.0,
+ weight_attr=None,
+ bias_attr=None if has_bias else False,
+ self_attn_mask_type=mask_type,
+ apply_residual_connection_post_layernorm=return_layernorm_output,
+ output_layernorm=output_layernorm,
+ layer_type="encoder",
+ normalization=normalization,
+ backend="transformer_engine",
+ )
+ layer_pd = te.TransformerLayer(
+ hidden_size,
+ ffn_hidden_size,
+ num_heads,
+ num_gqa_groups=num_gqa_groups,
+ layernorm_epsilon=eps,
+ hidden_dropout=0.0,
+ attention_dropout=0.0,
+ weight_attr=None,
+ bias_attr=None if has_bias else False,
+ self_attn_mask_type=mask_type,
+ apply_residual_connection_post_layernorm=return_layernorm_output,
+ output_layernorm=output_layernorm,
+ layer_type="encoder",
+ normalization=normalization,
+ backend="paddle",
+ )
# MultiHeadAttention params
if output_layernorm:
@@ -1012,21 +1097,25 @@ def test_transformer_encoder_layer(bs, hidden_size, num_heads, num_gqa_groups, f
layer_te.self_attention.qkv.bias.stop_gradient = no_dbias
else:
layer_pd.self_attention.layernorm_qkv.ln_weight.copy_(
- layer_te.self_attention.layernorm_qkv.ln_weight, True)
+ layer_te.self_attention.layernorm_qkv.ln_weight, True
+ )
layer_pd.self_attention.layernorm_qkv.weight.copy_(
- layer_te.self_attention.layernorm_qkv.weight.T, True)
+ layer_te.self_attention.layernorm_qkv.weight.T, True
+ )
layer_pd.self_attention.layernorm_qkv.ln_weight.stop_gradient = no_wgrad
layer_pd.self_attention.layernorm_qkv.weight.stop_gradient = no_wgrad
layer_te.self_attention.layernorm_qkv.ln_weight.stop_gradient = no_wgrad
layer_te.self_attention.layernorm_qkv.weight.stop_gradient = no_wgrad
if has_ln_bias:
layer_pd.self_attention.layernorm_qkv.ln_bias.copy_(
- layer_te.self_attention.layernorm_qkv.ln_bias, True)
+ layer_te.self_attention.layernorm_qkv.ln_bias, True
+ )
layer_pd.self_attention.layernorm_qkv.ln_bias.stop_gradient = no_dbias
layer_te.self_attention.layernorm_qkv.ln_bias.stop_gradient = no_dbias
if has_bias:
layer_pd.self_attention.layernorm_qkv.bias.copy_(
- layer_te.self_attention.layernorm_qkv.bias, True)
+ layer_te.self_attention.layernorm_qkv.bias, True
+ )
layer_pd.self_attention.layernorm_qkv.bias.stop_gradient = no_dbias
layer_te.self_attention.layernorm_qkv.bias.stop_gradient = no_dbias
@@ -1074,52 +1163,75 @@ def test_transformer_encoder_layer(bs, hidden_size, num_heads, num_gqa_groups, f
out.backward(dout)
return out, _encoder_input.grad
- out_ref, grad_input_ref = calc_transformer_output_and_grad(layer_pd, encoder_input, attn_mask,
- grad_out)
+ out_ref, grad_input_ref = calc_transformer_output_and_grad(
+ layer_pd, encoder_input, attn_mask, grad_out
+ )
out, grad_input = calc_transformer_output_and_grad(layer_te, encoder_input, attn_mask, grad_out)
assert_allclose(out, out_ref, rtol=rtol, atol=atol)
assert_allclose(grad_input, grad_input_ref, rtol=rtol, atol=atol)
if not no_wgrad:
if output_layernorm:
- assert_allclose(layer_te.self_attention.qkv.weight.grad,
- layer_pd.self_attention.qkv.weight.grad.T,
- rtol=rtol,
- atol=atol)
+ assert_allclose(
+ layer_te.self_attention.qkv.weight.grad,
+ layer_pd.self_attention.qkv.weight.grad.T,
+ rtol=rtol,
+ atol=atol,
+ )
else:
- assert_allclose(layer_te.self_attention.layernorm_qkv.weight.grad,
- layer_pd.self_attention.layernorm_qkv.weight.grad.T,
- rtol=rtol,
- atol=atol)
+ assert_allclose(
+ layer_te.self_attention.layernorm_qkv.weight.grad,
+ layer_pd.self_attention.layernorm_qkv.weight.grad.T,
+ rtol=rtol,
+ atol=atol,
+ )
if not no_dbias:
if output_layernorm:
- assert_allclose(layer_te.self_attention.qkv.bias.grad,
- layer_pd.self_attention.qkv.bias.grad,
- rtol=0.01,
- atol=0.5)
+ assert_allclose(
+ layer_te.self_attention.qkv.bias.grad,
+ layer_pd.self_attention.qkv.bias.grad,
+ rtol=0.01,
+ atol=0.5,
+ )
else:
- assert_allclose(layer_te.self_attention.layernorm_qkv.bias.grad,
- layer_pd.self_attention.layernorm_qkv.bias.grad,
- rtol=0.01,
- atol=0.5)
-
-
-@pytest.mark.parametrize('bs', [1, 2])
-@pytest.mark.parametrize('num_gqa_groups', [1, 2, 4])
-@pytest.mark.parametrize('hidden_size, num_heads, ffn_hidden_size', [[256, 4, 1024]])
-@pytest.mark.parametrize('q_seqlen, kv_seqlen', [[1024, 1024]])
-@pytest.mark.parametrize('has_bias, no_dbias', [[False, True], [True, True], [True, False]])
-@pytest.mark.parametrize('no_wgrad', [True, False])
-@pytest.mark.parametrize('mask_type', ['causal', 'padding'])
-@pytest.mark.parametrize('math_dtype', ['bfloat16', 'float16'])
-@pytest.mark.parametrize('output_layernorm', [True, False])
-@pytest.mark.parametrize('return_layernorm_output', [True, False])
-@pytest.mark.parametrize('recompute_core_attention', [True, False])
-@pytest.mark.parametrize('normalization', ['RMSNorm', 'LayerNorm'])
-def test_transformer_decoder_layer(bs, hidden_size, num_heads, num_gqa_groups, ffn_hidden_size,
- has_bias, no_dbias, no_wgrad, q_seqlen, kv_seqlen, mask_type,
- math_dtype, output_layernorm, return_layernorm_output,
- recompute_core_attention, normalization):
+ assert_allclose(
+ layer_te.self_attention.layernorm_qkv.bias.grad,
+ layer_pd.self_attention.layernorm_qkv.bias.grad,
+ rtol=0.01,
+ atol=0.5,
+ )
+
+
+@pytest.mark.parametrize("bs", [1, 2])
+@pytest.mark.parametrize("num_gqa_groups", [1, 2, 4])
+@pytest.mark.parametrize("hidden_size, num_heads, ffn_hidden_size", [[256, 4, 1024]])
+@pytest.mark.parametrize("q_seqlen, kv_seqlen", [[1024, 1024]])
+@pytest.mark.parametrize("has_bias, no_dbias", [[False, True], [True, True], [True, False]])
+@pytest.mark.parametrize("no_wgrad", [True, False])
+@pytest.mark.parametrize("mask_type", ["causal", "padding"])
+@pytest.mark.parametrize("math_dtype", ["bfloat16", "float16"])
+@pytest.mark.parametrize("output_layernorm", [True, False])
+@pytest.mark.parametrize("return_layernorm_output", [True, False])
+@pytest.mark.parametrize("recompute_core_attention", [True, False])
+@pytest.mark.parametrize("normalization", ["RMSNorm", "LayerNorm"])
+def test_transformer_decoder_layer(
+ bs,
+ hidden_size,
+ num_heads,
+ num_gqa_groups,
+ ffn_hidden_size,
+ has_bias,
+ no_dbias,
+ no_wgrad,
+ q_seqlen,
+ kv_seqlen,
+ mask_type,
+ math_dtype,
+ output_layernorm,
+ return_layernorm_output,
+ recompute_core_attention,
+ normalization,
+):
"""
Test Transformer Decoder Layer
"""
@@ -1127,34 +1239,37 @@ def test_transformer_decoder_layer(bs, hidden_size, num_heads, num_gqa_groups, f
rtol = 5e-2
atol = 6e-2
eps = 1e-3
- has_ln_bias = normalization == 'LayerNorm'
+ has_ln_bias = normalization == "LayerNorm"
# Skip if cuDNN fused attention is not supported
if not is_fused_attention_supported(
- num_heads=num_heads,
- num_gqa_groups=num_gqa_groups,
- q_seqlen=q_seqlen,
- kv_seqlen=kv_seqlen,
- head_size=hidden_size // num_heads,
- dtype=math_dtype,
- dropout=0.0,
- qkv_layout="bshd_bshd_bshd",
- bias_type="no_bias",
- mask_type=mask_type,
+ num_heads=num_heads,
+ num_gqa_groups=num_gqa_groups,
+ q_seqlen=q_seqlen,
+ kv_seqlen=kv_seqlen,
+ head_size=hidden_size // num_heads,
+ dtype=math_dtype,
+ dropout=0.0,
+ qkv_layout="bshd_bshd_bshd",
+ bias_type="no_bias",
+ mask_type=mask_type,
):
pytest.skip("cuDNN fused attention is not supported")
- encoder_input = paddle.normal(mean=0.0, std=0.1,
- shape=(bs, q_seqlen, hidden_size)).astype(math_dtype)
- encoder_output = paddle.normal(mean=0.0, std=0.1,
- shape=(bs, kv_seqlen, hidden_size)).astype(math_dtype)
+ encoder_input = paddle.normal(mean=0.0, std=0.1, shape=(bs, q_seqlen, hidden_size)).astype(
+ math_dtype
+ )
+ encoder_output = paddle.normal(mean=0.0, std=0.1, shape=(bs, kv_seqlen, hidden_size)).astype(
+ math_dtype
+ )
- q_actual_seqlen = paddle.ones(shape=(bs,), dtype='int32') * q_seqlen
+ q_actual_seqlen = paddle.ones(shape=(bs,), dtype="int32") * q_seqlen
kv_actual_seqlen = q_actual_seqlen
- attn_mask = paddle.ones(shape=(bs, 1, q_seqlen, kv_seqlen), dtype='bool')
+ attn_mask = paddle.ones(shape=(bs, 1, q_seqlen, kv_seqlen), dtype="bool")
- grad_out = paddle.normal(mean=0.0, std=0.01,
- shape=(bs, q_seqlen, hidden_size)).astype('float32')
+ grad_out = paddle.normal(mean=0.0, std=0.01, shape=(bs, q_seqlen, hidden_size)).astype(
+ "float32"
+ )
# rounding to avoid numerical issues
encoder_input = paddle.round(encoder_input * 1000) / 1000
@@ -1162,42 +1277,46 @@ def test_transformer_decoder_layer(bs, hidden_size, num_heads, num_gqa_groups, f
grad_out = paddle.round(grad_out * 1000) / 1000
for i in range(0, bs):
- grad_out[i, q_actual_seqlen[i]:, :] = 0
+ grad_out[i, q_actual_seqlen[i] :, :] = 0
grad_out = grad_out.astype(math_dtype)
for i in range(0, bs):
- attn_mask[i, 0, 0:q_actual_seqlen[i], 0:kv_actual_seqlen[i]] = False
-
- layer_te = te.TransformerLayer(hidden_size,
- ffn_hidden_size,
- num_heads,
- num_gqa_groups=num_gqa_groups,
- layernorm_epsilon=eps,
- hidden_dropout=0.0,
- attention_dropout=0.0,
- weight_attr=None,
- bias_attr=None if has_bias else False,
- self_attn_mask_type=mask_type,
- apply_residual_connection_post_layernorm=return_layernorm_output,
- output_layernorm=output_layernorm,
- layer_type='decoder',
- normalization=normalization,
- backend='transformer_engine')
- layer_pd = te.TransformerLayer(hidden_size,
- ffn_hidden_size,
- num_heads,
- num_gqa_groups=num_gqa_groups,
- layernorm_epsilon=eps,
- hidden_dropout=0.0,
- attention_dropout=0.0,
- weight_attr=None,
- bias_attr=None if has_bias else False,
- self_attn_mask_type=mask_type,
- apply_residual_connection_post_layernorm=return_layernorm_output,
- output_layernorm=output_layernorm,
- layer_type='decoder',
- normalization=normalization,
- backend='paddle')
+ attn_mask[i, 0, 0 : q_actual_seqlen[i], 0 : kv_actual_seqlen[i]] = False
+
+ layer_te = te.TransformerLayer(
+ hidden_size,
+ ffn_hidden_size,
+ num_heads,
+ num_gqa_groups=num_gqa_groups,
+ layernorm_epsilon=eps,
+ hidden_dropout=0.0,
+ attention_dropout=0.0,
+ weight_attr=None,
+ bias_attr=None if has_bias else False,
+ self_attn_mask_type=mask_type,
+ apply_residual_connection_post_layernorm=return_layernorm_output,
+ output_layernorm=output_layernorm,
+ layer_type="decoder",
+ normalization=normalization,
+ backend="transformer_engine",
+ )
+ layer_pd = te.TransformerLayer(
+ hidden_size,
+ ffn_hidden_size,
+ num_heads,
+ num_gqa_groups=num_gqa_groups,
+ layernorm_epsilon=eps,
+ hidden_dropout=0.0,
+ attention_dropout=0.0,
+ weight_attr=None,
+ bias_attr=None if has_bias else False,
+ self_attn_mask_type=mask_type,
+ apply_residual_connection_post_layernorm=return_layernorm_output,
+ output_layernorm=output_layernorm,
+ layer_type="decoder",
+ normalization=normalization,
+ backend="paddle",
+ )
# MultiHeadAttention params - self attn
if output_layernorm:
@@ -1210,21 +1329,25 @@ def test_transformer_decoder_layer(bs, hidden_size, num_heads, num_gqa_groups, f
layer_te.self_attention.qkv.bias.stop_gradient = no_dbias
else:
layer_pd.self_attention.layernorm_qkv.ln_weight.copy_(
- layer_te.self_attention.layernorm_qkv.ln_weight, True)
+ layer_te.self_attention.layernorm_qkv.ln_weight, True
+ )
layer_pd.self_attention.layernorm_qkv.weight.copy_(
- layer_te.self_attention.layernorm_qkv.weight.T, True)
+ layer_te.self_attention.layernorm_qkv.weight.T, True
+ )
layer_pd.self_attention.layernorm_qkv.ln_weight.stop_gradient = no_wgrad
layer_pd.self_attention.layernorm_qkv.weight.stop_gradient = no_wgrad
layer_te.self_attention.layernorm_qkv.ln_weight.stop_gradient = no_wgrad
layer_te.self_attention.layernorm_qkv.weight.stop_gradient = no_wgrad
if has_ln_bias:
layer_pd.self_attention.layernorm_qkv.ln_bias.copy_(
- layer_te.self_attention.layernorm_qkv.ln_bias, True)
+ layer_te.self_attention.layernorm_qkv.ln_bias, True
+ )
layer_pd.self_attention.layernorm_qkv.ln_bias.stop_gradient = no_dbias
layer_te.self_attention.layernorm_qkv.ln_bias.stop_gradient = no_dbias
if has_bias:
layer_pd.self_attention.layernorm_qkv.bias.copy_(
- layer_te.self_attention.layernorm_qkv.bias, True)
+ layer_te.self_attention.layernorm_qkv.bias, True
+ )
layer_pd.self_attention.layernorm_qkv.bias.stop_gradient = no_dbias
layer_te.self_attention.layernorm_qkv.bias.stop_gradient = no_dbias
@@ -1238,26 +1361,31 @@ def test_transformer_decoder_layer(bs, hidden_size, num_heads, num_gqa_groups, f
# MultiHeadAttention params - cross attn
layer_pd.inter_attention.layernorm_query.ln_weight.copy_(
- layer_te.inter_attention.layernorm_query.ln_weight, True)
+ layer_te.inter_attention.layernorm_query.ln_weight, True
+ )
layer_pd.inter_attention.layernorm_query.weight.copy_(
- layer_te.inter_attention.layernorm_query.weight.T, True)
+ layer_te.inter_attention.layernorm_query.weight.T, True
+ )
layer_pd.inter_attention.layernorm_query.ln_weight.stop_gradient = no_wgrad
layer_pd.inter_attention.layernorm_query.weight.stop_gradient = no_wgrad
layer_te.inter_attention.layernorm_query.ln_weight.stop_gradient = no_wgrad
layer_te.inter_attention.layernorm_query.weight.stop_gradient = no_wgrad
if has_ln_bias:
layer_pd.inter_attention.layernorm_query.ln_bias.copy_(
- layer_te.inter_attention.layernorm_query.ln_bias, True)
+ layer_te.inter_attention.layernorm_query.ln_bias, True
+ )
layer_pd.inter_attention.layernorm_query.ln_bias.stop_gradient = no_dbias
layer_te.inter_attention.layernorm_query.ln_bias.stop_gradient = no_dbias
if has_bias:
layer_pd.inter_attention.layernorm_query.bias.copy_(
- layer_te.inter_attention.layernorm_query.bias, True)
+ layer_te.inter_attention.layernorm_query.bias, True
+ )
layer_pd.inter_attention.layernorm_query.bias.stop_gradient = no_dbias
layer_te.inter_attention.layernorm_query.bias.stop_gradient = no_dbias
- layer_pd.inter_attention.key_value.weight.copy_(layer_te.inter_attention.key_value.weight.T,
- True)
+ layer_pd.inter_attention.key_value.weight.copy_(
+ layer_te.inter_attention.key_value.weight.T, True
+ )
layer_pd.inter_attention.key_value.weight.stop_gradient = no_wgrad
layer_te.inter_attention.key_value.weight.stop_gradient = no_wgrad
layer_pd.inter_attention.proj.weight.copy_(layer_te.inter_attention.proj.weight.T, True)
@@ -1301,25 +1429,30 @@ def test_transformer_decoder_layer(bs, hidden_size, num_heads, num_gqa_groups, f
layer_te.layernorm.weight.stop_gradient = no_wgrad
layer_te.layernorm.bias.stop_gradient = no_dbias
- def calc_transformer_output_and_grad(layer,
- encoder_input,
- mask,
- encoder_output,
- enc_dec_attn_mask,
- dout,
- recompute_core_attention=False):
+ def calc_transformer_output_and_grad(
+ layer,
+ encoder_input,
+ mask,
+ encoder_output,
+ enc_dec_attn_mask,
+ dout,
+ recompute_core_attention=False,
+ ):
_encoder_input = paddle.to_tensor(encoder_input, stop_gradient=False)
_encoder_output = paddle.to_tensor(encoder_output, stop_gradient=False)
- out = layer(_encoder_input,
- mask,
- _encoder_output,
- enc_dec_attn_mask,
- recompute_core_attention=recompute_core_attention)
+ out = layer(
+ _encoder_input,
+ mask,
+ _encoder_output,
+ enc_dec_attn_mask,
+ recompute_core_attention=recompute_core_attention,
+ )
out.backward(dout)
return out, _encoder_input.grad, _encoder_output.grad
out_ref, grad_encoder_input_ref, grad_encoder_output_ref = calc_transformer_output_and_grad(
- layer_pd, encoder_input, attn_mask, encoder_output, attn_mask, grad_out)
+ layer_pd, encoder_input, attn_mask, encoder_output, attn_mask, grad_out
+ )
out, grad_encoder_input, grad_encoder_output = calc_transformer_output_and_grad(
layer_te,
encoder_input,
@@ -1327,52 +1460,74 @@ def test_transformer_decoder_layer(bs, hidden_size, num_heads, num_gqa_groups, f
encoder_output,
attn_mask,
grad_out,
- recompute_core_attention=recompute_core_attention)
+ recompute_core_attention=recompute_core_attention,
+ )
assert_allclose(out, out_ref, rtol=rtol, atol=atol)
assert_allclose(grad_encoder_input, grad_encoder_input_ref, rtol=rtol, atol=atol)
assert_allclose(grad_encoder_output, grad_encoder_output_ref, rtol=rtol, atol=atol)
if not no_wgrad:
if output_layernorm:
- assert_allclose(layer_te.self_attention.qkv.weight.grad,
- layer_pd.self_attention.qkv.weight.grad.T,
- rtol=rtol,
- atol=atol)
+ assert_allclose(
+ layer_te.self_attention.qkv.weight.grad,
+ layer_pd.self_attention.qkv.weight.grad.T,
+ rtol=rtol,
+ atol=atol,
+ )
else:
- assert_allclose(layer_te.self_attention.layernorm_qkv.weight.grad,
- layer_pd.self_attention.layernorm_qkv.weight.grad.T,
- rtol=rtol,
- atol=atol)
- assert_allclose(layer_te.inter_attention.layernorm_query.weight.grad,
- layer_pd.inter_attention.layernorm_query.weight.grad.T,
- rtol=rtol,
- atol=atol)
+ assert_allclose(
+ layer_te.self_attention.layernorm_qkv.weight.grad,
+ layer_pd.self_attention.layernorm_qkv.weight.grad.T,
+ rtol=rtol,
+ atol=atol,
+ )
+ assert_allclose(
+ layer_te.inter_attention.layernorm_query.weight.grad,
+ layer_pd.inter_attention.layernorm_query.weight.grad.T,
+ rtol=rtol,
+ atol=atol,
+ )
if not no_dbias:
if output_layernorm:
- assert_allclose(layer_te.self_attention.qkv.bias.grad,
- layer_pd.self_attention.qkv.bias.grad,
- rtol=0.5,
- atol=0.6)
+ assert_allclose(
+ layer_te.self_attention.qkv.bias.grad,
+ layer_pd.self_attention.qkv.bias.grad,
+ rtol=0.5,
+ atol=0.6,
+ )
else:
- assert_allclose(layer_te.self_attention.layernorm_qkv.bias.grad,
- layer_pd.self_attention.layernorm_qkv.bias.grad,
- rtol=0.01,
- atol=0.5)
- assert_allclose(layer_te.inter_attention.layernorm_query.bias.grad,
- layer_pd.inter_attention.layernorm_query.bias.grad,
- rtol=rtol,
- atol=atol)
+ assert_allclose(
+ layer_te.self_attention.layernorm_qkv.bias.grad,
+ layer_pd.self_attention.layernorm_qkv.bias.grad,
+ rtol=0.01,
+ atol=0.5,
+ )
+ assert_allclose(
+ layer_te.inter_attention.layernorm_query.bias.grad,
+ layer_pd.inter_attention.layernorm_query.bias.grad,
+ rtol=rtol,
+ atol=atol,
+ )
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
-@pytest.mark.parametrize('bs', [8])
-@pytest.mark.parametrize('hidden_size, num_heads, ffn_hidden_size', [[1024, 16, 4096]])
-@pytest.mark.parametrize('q_seqlen, kv_seqlen', [[128, 128]])
-@pytest.mark.parametrize('mask_type', ['causal'])
-@pytest.mark.parametrize('math_dtype', ['bfloat16'])
-@pytest.mark.parametrize('num_microbatch', [8])
-def test_transformer_encoder_layer_microbatch(bs, hidden_size, num_heads, ffn_hidden_size, q_seqlen,
- kv_seqlen, mask_type, math_dtype, num_microbatch):
+@pytest.mark.parametrize("bs", [8])
+@pytest.mark.parametrize("hidden_size, num_heads, ffn_hidden_size", [[1024, 16, 4096]])
+@pytest.mark.parametrize("q_seqlen, kv_seqlen", [[128, 128]])
+@pytest.mark.parametrize("mask_type", ["causal"])
+@pytest.mark.parametrize("math_dtype", ["bfloat16"])
+@pytest.mark.parametrize("num_microbatch", [8])
+def test_transformer_encoder_layer_microbatch(
+ bs,
+ hidden_size,
+ num_heads,
+ ffn_hidden_size,
+ q_seqlen,
+ kv_seqlen,
+ mask_type,
+ math_dtype,
+ num_microbatch,
+):
"""
Test Transformer Encoder Layer with FP8 weight caching
"""
@@ -1383,48 +1538,56 @@ def test_transformer_encoder_layer_microbatch(bs, hidden_size, num_heads, ffn_hi
# Skip if cuDNN fused attention is not supported
if not is_fused_attention_supported(
- num_heads=num_heads,
- num_gqa_groups=num_heads,
- q_seqlen=q_seqlen,
- kv_seqlen=kv_seqlen,
- head_size=hidden_size // num_heads,
- dtype=math_dtype,
- dropout=0.0,
- qkv_layout="bs3hd",
- bias_type="no_bias",
- mask_type=mask_type,
+ num_heads=num_heads,
+ num_gqa_groups=num_heads,
+ q_seqlen=q_seqlen,
+ kv_seqlen=kv_seqlen,
+ head_size=hidden_size // num_heads,
+ dtype=math_dtype,
+ dropout=0.0,
+ qkv_layout="bs3hd",
+ bias_type="no_bias",
+ mask_type=mask_type,
):
pytest.skip("cuDNN fused attention is not supported")
- layer_cached = te.TransformerLayer(hidden_size,
- ffn_hidden_size,
- num_heads,
- layernorm_epsilon=eps,
- hidden_dropout=0.0,
- attention_dropout=0.0,
- weight_attr=None,
- bias_attr=None,
- self_attn_mask_type=mask_type,
- layer_type='encoder')
- layer_normal = te.TransformerLayer(hidden_size,
- ffn_hidden_size,
- num_heads,
- layernorm_epsilon=eps,
- hidden_dropout=0.0,
- attention_dropout=0.0,
- weight_attr=None,
- bias_attr=None,
- self_attn_mask_type=mask_type,
- layer_type='encoder')
+ layer_cached = te.TransformerLayer(
+ hidden_size,
+ ffn_hidden_size,
+ num_heads,
+ layernorm_epsilon=eps,
+ hidden_dropout=0.0,
+ attention_dropout=0.0,
+ weight_attr=None,
+ bias_attr=None,
+ self_attn_mask_type=mask_type,
+ layer_type="encoder",
+ )
+ layer_normal = te.TransformerLayer(
+ hidden_size,
+ ffn_hidden_size,
+ num_heads,
+ layernorm_epsilon=eps,
+ hidden_dropout=0.0,
+ attention_dropout=0.0,
+ weight_attr=None,
+ bias_attr=None,
+ self_attn_mask_type=mask_type,
+ layer_type="encoder",
+ )
layer_normal.self_attention.layernorm_qkv.ln_weight.copy_(
- layer_cached.self_attention.layernorm_qkv.ln_weight, True)
+ layer_cached.self_attention.layernorm_qkv.ln_weight, True
+ )
layer_normal.self_attention.layernorm_qkv.ln_bias.copy_(
- layer_cached.self_attention.layernorm_qkv.ln_bias, True)
+ layer_cached.self_attention.layernorm_qkv.ln_bias, True
+ )
layer_normal.self_attention.layernorm_qkv.weight.copy_(
- layer_cached.self_attention.layernorm_qkv.weight, True)
+ layer_cached.self_attention.layernorm_qkv.weight, True
+ )
layer_normal.self_attention.layernorm_qkv.bias.copy_(
- layer_cached.self_attention.layernorm_qkv.bias, True)
+ layer_cached.self_attention.layernorm_qkv.bias, True
+ )
layer_normal.self_attention.proj.weight.copy_(layer_cached.self_attention.proj.weight, True)
layer_normal.self_attention.proj.bias.copy_(layer_cached.self_attention.proj.bias, True)
@@ -1442,18 +1605,19 @@ def test_transformer_encoder_layer_microbatch(bs, hidden_size, num_heads, ffn_hi
def generate_input():
encoder_input = paddle.uniform(shape=(bs, q_seqlen, hidden_size), dtype=math_dtype)
- q_actual_seqlen = paddle.ones(shape=(bs,), dtype='int32') * q_seqlen
+ q_actual_seqlen = paddle.ones(shape=(bs,), dtype="int32") * q_seqlen
kv_actual_seqlen = q_actual_seqlen
- attn_mask = paddle.ones(shape=(bs, 1, q_seqlen, kv_seqlen), dtype='bool')
+ attn_mask = paddle.ones(shape=(bs, 1, q_seqlen, kv_seqlen), dtype="bool")
- grad_out = paddle.normal(mean=0.0, std=0.02,
- shape=(bs, q_seqlen, hidden_size)).astype('float32')
+ grad_out = paddle.normal(mean=0.0, std=0.02, shape=(bs, q_seqlen, hidden_size)).astype(
+ "float32"
+ )
for i in range(0, bs):
- grad_out[i, q_actual_seqlen[i]:, :] = 0
+ grad_out[i, q_actual_seqlen[i] :, :] = 0
grad_out = grad_out.astype(math_dtype)
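+ # attn_mask starts as all True; clearing the valid region to False means True marks padded (masked) positions.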
for i in range(0, bs):
- attn_mask[i, 0, 0:q_actual_seqlen[i], 0:kv_actual_seqlen[i]] = False
+ attn_mask[i, 0, 0 : q_actual_seqlen[i], 0 : kv_actual_seqlen[i]] = False
return encoder_input, attn_mask, grad_out
@@ -1477,7 +1641,9 @@ def test_transformer_encoder_layer_microbatch(bs, hidden_size, num_heads, ffn_hi
out_ref.backward(grad_out)
assert_allclose(out, out_ref, rtol=rtol, atol=atol)
- assert_allclose(layer_cached.self_attention.layernorm_qkv.weight.grad,
- layer_normal.self_attention.layernorm_qkv.weight.grad,
- rtol=rtol,
- atol=atol)
+ assert_allclose(
+ layer_cached.self_attention.layernorm_qkv.weight.grad,
+ layer_normal.self_attention.layernorm_qkv.weight.grad,
+ rtol=rtol,
+ atol=atol,
+ )
diff --git a/tests/paddle/test_master_grad.py b/tests/paddle/test_master_grad.py
index f5f67cba4e9fb647ac85d15938b49fbc70fc829e..4e029cf8ddd9a9fea1a6b03a06ffd2709587d653 100644
--- a/tests/paddle/test_master_grad.py
+++ b/tests/paddle/test_master_grad.py
@@ -16,7 +16,7 @@ is_fp8_supported, reason = is_fp8_available()
def create_optimizer(model, use_pure_bf16, use_main_grad):
- '''Create optimizer'''
+ """Create optimizer"""
if use_main_grad:
assert use_pure_bf16
model = mix_precision_utils.MixPrecisionLayer(model, dtype="bfloat16")
@@ -32,7 +32,7 @@ def create_optimizer(model, use_pure_bf16, use_main_grad):
class Net(paddle.nn.Layer):
- '''Network use for main_grad testing'''
+ """Network use for main_grad testing"""
def __init__(self, fuse_wgrad_accumulation):
super().__init__()
@@ -40,7 +40,7 @@ class Net(paddle.nn.Layer):
4096,
16384,
32,
- layer_type='encoder',
+ layer_type="encoder",
fuse_wgrad_accumulation=fuse_wgrad_accumulation,
)
@@ -50,7 +50,7 @@ class Net(paddle.nn.Layer):
def train(enable_master_grad, fuse_wgrad_accumulation=False):
- '''Train function'''
+ """Train function"""
paddle.seed(10)
accumulate_steps = 4
@@ -64,7 +64,7 @@ def train(enable_master_grad, fuse_wgrad_accumulation=False):
loss_list = []
for step_id in range(16):
- inp = paddle.uniform([2, 1024, 4096], dtype='float32')
+ inp = paddle.uniform([2, 1024, 4096], dtype="float32")
inp.stop_gradient = False
with te.fp8_autocast(enabled=True):
out = model(inp)
@@ -82,8 +82,8 @@ def train(enable_master_grad, fuse_wgrad_accumulation=False):
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
def test_master_grad():
- '''Test main_grad'''
- paddle.set_default_dtype('float32')
+ """Test main_grad"""
+ paddle.set_default_dtype("float32")
loss1 = train(enable_master_grad=False)
loss2 = train(enable_master_grad=True)
loss3 = train(enable_master_grad=True, fuse_wgrad_accumulation=True)
diff --git a/tests/paddle/test_operators.py b/tests/paddle/test_operators.py
index bb6a73eaecf66a0ab1595b4c5506ca5b9cd563ce..b3b856077513d2a39cc05133e4f4049678e5f70f 100644
--- a/tests/paddle/test_operators.py
+++ b/tests/paddle/test_operators.py
@@ -56,8 +56,13 @@ from transformer_engine.paddle.fp8 import is_fp8_available
from transformer_engine.paddle.constants import FP8FwdTensors
from transformer_engine.common.recipe import DelayedScaling
-GEMM_CASES = [(256, 256, 512), (32, 32, 32), (16384, 1024, 2816), (16384, 2816, 1024),
- (16384, 1024, 1024)]
+GEMM_CASES = [
+ (256, 256, 512),
+ (32, 32, 32),
+ (16384, 1024, 2816),
+ (16384, 2816, 1024),
+ (16384, 1024, 1024),
+]
is_fp8_supported, reason = is_fp8_available()
SELF_ATTN_CASES = [(2, 512, 12, 64)]
@@ -74,13 +79,13 @@ def setup():
yield
-@pytest.mark.parametrize('fp8_dtype', [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2])
-@pytest.mark.parametrize('inplace', [True, False])
+@pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2])
+@pytest.mark.parametrize("inplace", [True, False])
def test_quantize_dequantize(fp8_dtype, inplace):
"""
Test cast_to_fp8 and cast_from_fp8
"""
- a = paddle.rand(shape=(32, 32), dtype='float32')
+ a = paddle.rand(shape=(32, 32), dtype="float32")
# Init fp8_meta
fp8_meta = create_fp8_meta()
a_fp8 = paddle.zeros(shape=a.shape, dtype=paddle.uint8) if inplace else None
@@ -99,7 +104,7 @@ def copy_bits_from_float_to_uint16(f):
"""
Copy bits
"""
- return struct.unpack('<I', struct.pack('<f', f))[0] >> 16
+ return struct.unpack("<I", struct.pack("<f", f))[0] >> 16
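+ # (bfloat16 is the upper 16 bits of a float32's bit pattern, hence pack-as-float32 then shift right by 16)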
def convert_float_to_uint16(float_list):
@@ -124,95 +129,106 @@ class TestTranspose:
"""
Test BF16 transpose
"""
- a = paddle.rand(shape=(16, 32), dtype='bfloat16')
+ a = paddle.rand(shape=(16, 32), dtype="bfloat16")
a_transposed = transpose(a, otype=tex.DType.kBFloat16)
assert_allclose(a_transposed, a.T)
@staticmethod
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
- @pytest.mark.parametrize('fp8_dtype', [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2])
+ @pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2])
def test_transpose_fp8(fp8_dtype):
"""
Test FP8 transpose
"""
min_val = -8
max_val = 8
- a = paddle.cast(paddle.randint(min_val, max_val, shape=(16, 32)), 'float32')
+ a = paddle.cast(paddle.randint(min_val, max_val, shape=(16, 32)), "float32")
fp8_meta = create_fp8_meta()
a_fp8 = cast_to_fp8(a, fp8_meta, FP8FwdTensors.GEMM1_INPUT, otype=fp8_dtype)
a_fp8_transposed = transpose(a_fp8, otype=fp8_dtype)
- a_transposed = cast_from_fp8(a_fp8_transposed,
- fp8_meta,
- FP8FwdTensors.GEMM1_INPUT,
- itype=fp8_dtype,
- otype=tex.DType.kFloat32)
+ a_transposed = cast_from_fp8(
+ a_fp8_transposed,
+ fp8_meta,
+ FP8FwdTensors.GEMM1_INPUT,
+ itype=fp8_dtype,
+ otype=tex.DType.kFloat32,
+ )
assert_allclose(a_transposed, a.T)
@staticmethod
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
- @pytest.mark.parametrize('fp8_dtype', [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2])
- @pytest.mark.parametrize('inplace', [True, False])
+ @pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2])
+ @pytest.mark.parametrize("inplace", [True, False])
def test_cast_transpose(fp8_dtype, inplace):
"""
Test cast_transpose
"""
min_val = -8
max_val = 8
- a = paddle.cast(paddle.randint(min_val, max_val, shape=(16, 32)), 'float32')
+ a = paddle.cast(paddle.randint(min_val, max_val, shape=(16, 32)), "float32")
fp8_meta = create_fp8_meta()
a_fp8_casted, a_fp8_transposed = None, None
if inplace:
a_fp8_casted = paddle.zeros(shape=a.shape, dtype=paddle.uint8)
a_fp8_transposed = paddle.zeros(shape=a.T.shape, dtype=paddle.uint8)
- a_fp8_casted, a_fp8_transposed = cast_transpose(a,
- fp8_meta,
- FP8FwdTensors.GEMM1_INPUT,
- otype=fp8_dtype,
- cast_out=a_fp8_casted,
- transpose_out=a_fp8_transposed)
-
- a_transposed = cast_from_fp8(a_fp8_transposed,
- fp8_meta,
- FP8FwdTensors.GEMM1_INPUT,
- itype=fp8_dtype,
- otype=tex.DType.kFloat32)
-
- a_casted = cast_from_fp8(a_fp8_casted,
- fp8_meta,
- FP8FwdTensors.GEMM1_INPUT,
- itype=fp8_dtype,
- otype=tex.DType.kFloat32)
+ a_fp8_casted, a_fp8_transposed = cast_transpose(
+ a,
+ fp8_meta,
+ FP8FwdTensors.GEMM1_INPUT,
+ otype=fp8_dtype,
+ cast_out=a_fp8_casted,
+ transpose_out=a_fp8_transposed,
+ )
+
+ a_transposed = cast_from_fp8(
+ a_fp8_transposed,
+ fp8_meta,
+ FP8FwdTensors.GEMM1_INPUT,
+ itype=fp8_dtype,
+ otype=tex.DType.kFloat32,
+ )
+
+ a_casted = cast_from_fp8(
+ a_fp8_casted,
+ fp8_meta,
+ FP8FwdTensors.GEMM1_INPUT,
+ itype=fp8_dtype,
+ otype=tex.DType.kFloat32,
+ )
assert_allclose(a_casted, a)
assert_allclose(a_transposed, a.T)
@staticmethod
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
- @pytest.mark.parametrize('fp8_dtype', [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2])
+ @pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2])
def test_cast_transpose_bgrad(fp8_dtype):
"""
Test cast_transpose_bgrad
"""
min_val = -8
max_val = 8
- a = paddle.cast(paddle.randint(min_val, max_val, shape=(16, 32)), 'float32')
+ a = paddle.cast(paddle.randint(min_val, max_val, shape=(16, 32)), "float32")
fp8_meta = create_fp8_meta()
- bgrad, a_fp8_casted, a_fp8_transposed = cast_transpose_bgrad(a,
- fp8_meta,
- FP8FwdTensors.GEMM1_INPUT,
- otype=fp8_dtype)
-
- a_transposed = cast_from_fp8(a_fp8_transposed,
- fp8_meta,
- FP8FwdTensors.GEMM1_INPUT,
- itype=fp8_dtype,
- otype=tex.DType.kFloat32)
-
- a_casted = cast_from_fp8(a_fp8_casted,
- fp8_meta,
- FP8FwdTensors.GEMM1_INPUT,
- itype=fp8_dtype,
- otype=tex.DType.kFloat32)
+ bgrad, a_fp8_casted, a_fp8_transposed = cast_transpose_bgrad(
+ a, fp8_meta, FP8FwdTensors.GEMM1_INPUT, otype=fp8_dtype
+ )
+
+ a_transposed = cast_from_fp8(
+ a_fp8_transposed,
+ fp8_meta,
+ FP8FwdTensors.GEMM1_INPUT,
+ itype=fp8_dtype,
+ otype=tex.DType.kFloat32,
+ )
+
+ a_casted = cast_from_fp8(
+ a_fp8_casted,
+ fp8_meta,
+ FP8FwdTensors.GEMM1_INPUT,
+ itype=fp8_dtype,
+ otype=tex.DType.kFloat32,
+ )
assert_allclose(a_casted, a)
assert_allclose(a_transposed, a.T)
@@ -229,7 +245,7 @@ class TestActivation:
"""
Test BF16 GELU Forward
"""
- a = paddle.rand(shape=(16, 32), dtype='bfloat16') * 2 - 1
+ a = paddle.rand(shape=(16, 32), dtype="bfloat16") * 2 - 1
gelu_out = te_gelu(a, otype=tex.DType.kBFloat16)
gelu_ref = paddle.nn.GELU()(a)
@@ -237,21 +253,23 @@ class TestActivation:
@staticmethod
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
- @pytest.mark.parametrize('fp8_dtype', [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2])
+ @pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2])
def test_gelu_fp8(fp8_dtype):
"""
Test FP8 GELU Forward
"""
- a = paddle.rand(shape=(16, 32), dtype='float32') * 2 - 1
+ a = paddle.rand(shape=(16, 32), dtype="float32") * 2 - 1
fp8_meta = create_fp8_meta()
gelu_out_fp8 = gelu_fp8(a, fp8_meta, FP8FwdTensors.GEMM1_INPUT, otype=fp8_dtype)
- gelu_out = cast_from_fp8(gelu_out_fp8,
- fp8_meta,
- FP8FwdTensors.GEMM1_INPUT,
- itype=fp8_dtype,
- otype=tex.DType.kFloat32)
+ gelu_out = cast_from_fp8(
+ gelu_out_fp8,
+ fp8_meta,
+ FP8FwdTensors.GEMM1_INPUT,
+ itype=fp8_dtype,
+ otype=tex.DType.kFloat32,
+ )
gelu_ref = paddle.nn.GELU()(a)
@@ -259,36 +277,38 @@ class TestActivation:
@staticmethod
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
- @pytest.mark.parametrize('fp8_dtype', [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2])
+ @pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2])
def test_gelu_bwd_fp8(fp8_dtype):
"""
Test FP8 GELU Backward
"""
# y = GELU(x), calculate ref
- x = paddle.rand(shape=(16, 32), dtype='float32') * 2 - 1
+ x = paddle.rand(shape=(16, 32), dtype="float32") * 2 - 1
x.stop_gradient = False
y = paddle.nn.GELU()(x)
- y_grad = paddle.rand(shape=(16, 32), dtype='float32') * 2 - 1
+ y_grad = paddle.rand(shape=(16, 32), dtype="float32") * 2 - 1
paddle.autograd.backward([y], [y_grad], True)
# calculate fp8
fp8_meta = create_fp8_meta()
- x_grad_fp8, x_grad_t_fp8, dbias = dgelu_cast_transpose_bgrad_fp8(y_grad,
- x,
- fp8_meta,
- FP8FwdTensors.GEMM1_INPUT,
- otype=fp8_dtype)
-
- x_grad = cast_from_fp8(x_grad_fp8,
- fp8_meta,
- FP8FwdTensors.GEMM1_INPUT,
- itype=fp8_dtype,
- otype=tex.DType.kFloat32)
-
- x_grad_t = cast_from_fp8(x_grad_t_fp8,
- fp8_meta,
- FP8FwdTensors.GEMM1_INPUT,
- itype=fp8_dtype,
- otype=tex.DType.kFloat32)
+ x_grad_fp8, x_grad_t_fp8, dbias = dgelu_cast_transpose_bgrad_fp8(
+ y_grad, x, fp8_meta, FP8FwdTensors.GEMM1_INPUT, otype=fp8_dtype
+ )
+
+ x_grad = cast_from_fp8(
+ x_grad_fp8,
+ fp8_meta,
+ FP8FwdTensors.GEMM1_INPUT,
+ itype=fp8_dtype,
+ otype=tex.DType.kFloat32,
+ )
+
+ x_grad_t = cast_from_fp8(
+ x_grad_t_fp8,
+ fp8_meta,
+ FP8FwdTensors.GEMM1_INPUT,
+ itype=fp8_dtype,
+ otype=tex.DType.kFloat32,
+ )
assert_allclose(x_grad, x.grad, rtol=0.1, atol=0.01)
assert_allclose(x_grad_t, x.grad.T, rtol=0.1, atol=0.01)
@@ -299,7 +319,7 @@ class TestActivation:
"""
Test BF16 SwiGLU Forward
"""
- a = paddle.rand(shape=(16, 32), dtype='bfloat16') * 2 - 1
+ a = paddle.rand(shape=(16, 32), dtype="bfloat16") * 2 - 1
swiglu_out = swiglu(a, otype=tex.DType.kBFloat16)
swiglu_ref = swiglu_pd(a)
@@ -307,21 +327,23 @@ class TestActivation:
@staticmethod
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
- @pytest.mark.parametrize('fp8_dtype', [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2])
+ @pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2])
def test_swiglu_fp8(fp8_dtype):
"""
Test FP8 SwiGLU Forward
"""
- a = paddle.rand(shape=(16, 32), dtype='float32') * 2 - 1
+ a = paddle.rand(shape=(16, 32), dtype="float32") * 2 - 1
fp8_meta = create_fp8_meta()
swiglu_out_fp8 = swiglu_fp8(a, fp8_meta, FP8FwdTensors.GEMM1_INPUT, otype=fp8_dtype)
- swiglu_out = cast_from_fp8(swiglu_out_fp8,
- fp8_meta,
- FP8FwdTensors.GEMM1_INPUT,
- itype=fp8_dtype,
- otype=tex.DType.kFloat32)
+ swiglu_out = cast_from_fp8(
+ swiglu_out_fp8,
+ fp8_meta,
+ FP8FwdTensors.GEMM1_INPUT,
+ itype=fp8_dtype,
+ otype=tex.DType.kFloat32,
+ )
swiglu_ref = swiglu_pd(a)
@@ -333,10 +355,10 @@ class TestActivation:
Test SwiGLU Backward
"""
# y = SwiGLU(x), calculate ref
- x = paddle.rand(shape=(16, 32), dtype='bfloat16') * 2 - 1
+ x = paddle.rand(shape=(16, 32), dtype="bfloat16") * 2 - 1
x.stop_gradient = False
y = swiglu_pd(x)
- y_grad = paddle.rand(shape=(16, 16), dtype='bfloat16') * 2 - 1
+ y_grad = paddle.rand(shape=(16, 16), dtype="bfloat16") * 2 - 1
paddle.autograd.backward([y], [y_grad], True)
# calculate using the fused dswiglu kernel
x_grad = dswiglu(y_grad, x, otype=tex.DType.kBFloat16)
@@ -350,17 +372,18 @@ class TestGemm:
"""
@staticmethod
- @pytest.mark.skipif(paddle.device.cuda.get_device_capability() < (8, 0),
- reason="BF16 GEMM requires Ampere+ GPU")
- @pytest.mark.parametrize('m,n,k', GEMM_CASES)
+ @pytest.mark.skipif(
+ paddle.device.cuda.get_device_capability() < (8, 0), reason="BF16 GEMM requires Ampere+ GPU"
+ )
+ @pytest.mark.parametrize("m,n,k", GEMM_CASES)
def test_bf16(m, n, k):
"""
Test "TN" BF16 GEMM
"""
- a = paddle.rand(shape=(m, k), dtype='bfloat16')
- b = paddle.rand(shape=(n, k), dtype='bfloat16')
+ a = paddle.rand(shape=(m, k), dtype="bfloat16")
+ b = paddle.rand(shape=(n, k), dtype="bfloat16")
- workspace = paddle.zeros(shape=[33_554_432], dtype='uint8')
+ workspace = paddle.zeros(shape=[33_554_432], dtype="uint8")
ref_out = paddle.matmul(a, b.T)
# CublasLt inside tex.te_gemm assumes inputs are column major.
@@ -368,37 +391,51 @@ class TestGemm:
# transpose of X.
# Here we perform "TN" GEMM in column major, i.e., b@a^T = C^T,
# which is equivalent to a@b^T = C in row major.
- actual_out, _, _ = gemm(b, a, paddle.bfloat16, workspace, False, None, False, False, "TN",
- None, None, False)
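+ # The positional args after workspace appear to map to
+ # (gelu, gelu_input, grad, accumulate, layout, out, bias, use_bias).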
+ actual_out, _, _ = gemm(
+ b, a, paddle.bfloat16, workspace, False, None, False, False, "TN", None, None, False
+ )
assert_allclose(actual_out, ref_out, rtol=1.6e-2, atol=1e-5)
@staticmethod
- @pytest.mark.skipif(paddle.device.cuda.get_device_capability() < (8, 0),
- reason="BF16 GEMM requires Ampere+ GPU")
- @pytest.mark.parametrize('m,n,k', GEMM_CASES)
+ @pytest.mark.skipif(
+ paddle.device.cuda.get_device_capability() < (8, 0), reason="BF16 GEMM requires Ampere+ GPU"
+ )
+ @pytest.mark.parametrize("m,n,k", GEMM_CASES)
def test_bf16_inplace(m, n, k):
"""
Test "TN" BF16 GEMM, with accumulate=True
"""
min_val = -16
max_val = 16
- a = paddle.rand(shape=(m, k), dtype='bfloat16')
- b = paddle.rand(shape=(n, k), dtype='bfloat16')
- c = paddle.cast(paddle.randint(min_val, max_val, shape=(m, n)), 'bfloat16')
- workspace = paddle.zeros(shape=[33_554_432], dtype='uint8')
+ a = paddle.rand(shape=(m, k), dtype="bfloat16")
+ b = paddle.rand(shape=(n, k), dtype="bfloat16")
+ c = paddle.cast(paddle.randint(min_val, max_val, shape=(m, n)), "bfloat16")
+ workspace = paddle.zeros(shape=[33_554_432], dtype="uint8")
ref_out = c + paddle.matmul(a, b.T)
actual_out = paddle.clone(c)
- _, _, _ = gemm(b, a, paddle.bfloat16, workspace, False, None, False, True, "TN", actual_out,
- None, False)
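+ # With accumulate=True the GEMM adds into the existing values of actual_out,
+ # matching ref_out = c + a @ b.T.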
+ _, _, _ = gemm(
+ b,
+ a,
+ paddle.bfloat16,
+ workspace,
+ False,
+ None,
+ False,
+ True,
+ "TN",
+ actual_out,
+ None,
+ False,
+ )
assert_allclose(actual_out, ref_out, rtol=5e-2, atol=5e-2)
@staticmethod
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
- @pytest.mark.parametrize('m,n,k', GEMM_CASES)
+ @pytest.mark.parametrize("m,n,k", GEMM_CASES)
def test_fp8_randint(m, n, k):
"""
Test "TN" FP8 GEMM
@@ -409,17 +446,26 @@ class TestGemm:
out_dtype = paddle.float32
fp8_meta = create_fp8_meta(num_gemms=1)
- a = paddle.cast(paddle.randint(min_val, max_val, shape=(m, k)), 'float32')
+ a = paddle.cast(paddle.randint(min_val, max_val, shape=(m, k)), "float32")
a_casted = cast_to_fp8(a, fp8_meta, FP8FwdTensors.GEMM1_INPUT, otype=fp8_dtype)
- b = paddle.cast(paddle.randint(min_val, max_val, shape=(n, k)), 'float32')
+ b = paddle.cast(paddle.randint(min_val, max_val, shape=(n, k)), "float32")
b_casted = cast_to_fp8(b, fp8_meta, FP8FwdTensors.GEMM1_WEIGHT, otype=fp8_dtype)
- workspace = paddle.zeros(shape=[33_554_432], dtype='uint8')
+ workspace = paddle.zeros(shape=[33_554_432], dtype="uint8")
ref_out = paddle.matmul(a, b.T)
- actual_out, _ = fp8_gemm(b_casted, fp8_meta.scale_inv, FP8FwdTensors.GEMM1_WEIGHT,
- fp8_dtype, a_casted, fp8_meta.scale_inv, FP8FwdTensors.GEMM1_INPUT,
- fp8_dtype, out_dtype, workspace)
+ actual_out, _ = fp8_gemm(
+ b_casted,
+ fp8_meta.scale_inv,
+ FP8FwdTensors.GEMM1_WEIGHT,
+ fp8_dtype,
+ a_casted,
+ fp8_meta.scale_inv,
+ FP8FwdTensors.GEMM1_INPUT,
+ fp8_dtype,
+ out_dtype,
+ workspace,
+ )
assert_allclose(actual_out, ref_out)
@@ -434,14 +480,12 @@ class TestLayerNorm:
"""
Calculate reference using paddle layer_norm op
"""
- y = paddle.nn.functional.layer_norm(x=x,
- normalized_shape=x.shape[1:],
- weight=gamma,
- bias=beta,
- epsilon=eps)
+ y = paddle.nn.functional.layer_norm(
+ x=x, normalized_shape=x.shape[1:], weight=gamma, bias=beta, epsilon=eps
+ )
mean = paddle.mean(x, axis=-1)
var = paddle.var(x, axis=-1)
- inv_var = paddle.sqrt(1. / var)
+ inv_var = paddle.sqrt(1.0 / var)
return y, mean, inv_var
@staticmethod
@@ -453,11 +497,9 @@ class TestLayerNorm:
gamma.stop_gradient = False
beta.stop_gradient = False
- y = paddle.nn.functional.layer_norm(x=x,
- normalized_shape=x.shape[1:],
- weight=gamma,
- bias=beta,
- epsilon=eps)
+ y = paddle.nn.functional.layer_norm(
+ x=x, normalized_shape=x.shape[1:], weight=gamma, bias=beta, epsilon=eps
+ )
paddle.autograd.backward([y], [dy], True)
@@ -469,9 +511,9 @@ class TestLayerNorm:
"""
N, H = (16, 32)
eps = 1e-3
- x = paddle.uniform(shape=(N, H), dtype='bfloat16')
- gamma = paddle.uniform(shape=(H,), dtype='bfloat16')
- beta = paddle.uniform(shape=(H,), dtype='bfloat16')
+ x = paddle.uniform(shape=(N, H), dtype="bfloat16")
+ gamma = paddle.uniform(shape=(H,), dtype="bfloat16")
+ beta = paddle.uniform(shape=(H,), dtype="bfloat16")
y, mu, rsigma = layernorm_fwd(x, gamma, beta, eps, tex.DType.kBFloat16)
@@ -490,9 +532,9 @@ class TestLayerNorm:
N, H = (16, 32)
eps = 1e-3
- x = paddle.uniform(shape=(N, H), dtype='float32')
- gamma = paddle.uniform(shape=(H,), dtype='float32')
- beta = paddle.uniform(shape=(H,), dtype='float32')
+ x = paddle.uniform(shape=(N, H), dtype="float32")
+ gamma = paddle.uniform(shape=(H,), dtype="float32")
+ beta = paddle.uniform(shape=(H,), dtype="float32")
fp8_tensor = FP8FwdTensors.GEMM1_INPUT
fp8_meta = create_fp8_meta()
@@ -513,10 +555,10 @@ class TestLayerNorm:
"""
N, H = (16, 32)
eps = 1e-3
- x = paddle.uniform(shape=(N, H), dtype='bfloat16')
- dy = paddle.uniform(shape=(N, H), dtype='bfloat16')
- gamma = paddle.uniform(shape=(H,), dtype='bfloat16')
- beta = paddle.uniform(shape=(H,), dtype='bfloat16')
+ x = paddle.uniform(shape=(N, H), dtype="bfloat16")
+ dy = paddle.uniform(shape=(N, H), dtype="bfloat16")
+ gamma = paddle.uniform(shape=(H,), dtype="bfloat16")
+ beta = paddle.uniform(shape=(H,), dtype="bfloat16")
dx_ref, dgamma_ref, dbeta_ref = self.calc_bwd_ref(x, eps, gamma, beta, dy)
@@ -563,8 +605,8 @@ class TestRMSNorm:
"""
N, H = (16, 32)
eps = 1e-3
- x = paddle.uniform(shape=(N, H), dtype='bfloat16')
- gamma = paddle.uniform(shape=(H,), dtype='bfloat16')
+ x = paddle.uniform(shape=(N, H), dtype="bfloat16")
+ gamma = paddle.uniform(shape=(H,), dtype="bfloat16")
y, _ = rmsnorm_fwd(x, gamma, eps, tex.DType.kBFloat16)
@@ -581,8 +623,8 @@ class TestRMSNorm:
N, H = (16, 32)
eps = 1e-3
- x = paddle.uniform(shape=(N, H), dtype='float32')
- gamma = paddle.uniform(shape=(H,), dtype='float32')
+ x = paddle.uniform(shape=(N, H), dtype="float32")
+ gamma = paddle.uniform(shape=(H,), dtype="float32")
fp8_tensor = FP8FwdTensors.GEMM1_INPUT
fp8_meta = create_fp8_meta()
@@ -602,9 +644,9 @@ class TestRMSNorm:
"""
N, H = (16, 32)
eps = 1e-3
- x = paddle.uniform(shape=(N, H), dtype='bfloat16')
- dy = paddle.uniform(shape=(N, H), dtype='bfloat16')
- gamma = paddle.uniform(shape=(H,), dtype='bfloat16')
+ x = paddle.uniform(shape=(N, H), dtype="bfloat16")
+ dy = paddle.uniform(shape=(N, H), dtype="bfloat16")
+ gamma = paddle.uniform(shape=(H,), dtype="bfloat16")
dx_ref, dgamma_ref = self.calc_bwd_ref(x, eps, gamma, dy)
@@ -620,7 +662,7 @@ class TestFusedAttn:
Test fused attention operators
"""
- def set_input(self, b, s_q, s_kv, h, d, dtype, attn_mode='self_attn', is_causal_masking=False):
+ def set_input(self, b, s_q, s_kv, h, d, dtype, attn_mode="self_attn", is_causal_masking=False):
"""
set test input
"""
@@ -682,10 +724,10 @@ class TestFusedAttn:
assert attn_mode == "self_attn", "causal masking is only supported for self attention"
for i in range(0, self.batch_size):
for j in range(self.q_actual_seqlen[i]):
- self.attn_mask[i, :, j, :j + 1] = 0
+ self.attn_mask[i, :, j, : j + 1] = 0
else:
for i in range(0, self.batch_size):
- self.attn_mask[i, :, :self.q_actual_seqlen[i], :self.kv_actual_seqlen[i]] = 0
+ self.attn_mask[i, :, : self.q_actual_seqlen[i], : self.kv_actual_seqlen[i]] = 0
dout = _random((self.batch_size, self.q_seqlen, self.num_heads, self.head_size))
self.dout = paddle.to_tensor(dout, dtype=self.dtype)
@@ -696,9 +738,9 @@ class TestFusedAttn:
k_tensor = paddle.to_tensor(self.kv, stop_gradient=False)
v_tensor = paddle.to_tensor(self.kv, stop_gradient=False)
- q_out = paddle.transpose(x=q_tensor, perm=[0, 2, 1, 3]) # [b, s, h, d] -> [b, h, s, d]
- k_out = paddle.transpose(x=k_tensor, perm=[0, 2, 1, 3]) # [b, s, h, d] -> [b, h, s, d]
- v_out = paddle.transpose(x=v_tensor, perm=[0, 2, 1, 3]) # [b, s, h, d] -> [b, h, s, d]
+ q_out = paddle.transpose(x=q_tensor, perm=[0, 2, 1, 3]) # [b, s, h, d] -> [b, h, s, d]
+ k_out = paddle.transpose(x=k_tensor, perm=[0, 2, 1, 3]) # [b, s, h, d] -> [b, h, s, d]
+ v_out = paddle.transpose(x=v_tensor, perm=[0, 2, 1, 3]) # [b, s, h, d] -> [b, h, s, d]
qk_out = paddle.matmul(
x=q_out * self.scaling_factor,
@@ -707,10 +749,10 @@ class TestFusedAttn:
transpose_y=True,
)
- attn_mask = paddle.to_tensor(self.attn_mask, stop_gradient=True).cast('bool')
+ attn_mask = paddle.to_tensor(self.attn_mask, stop_gradient=True).cast("bool")
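+ # Masked (True) positions get a large negative logit so softmax sends them to ~0.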
attn_mask_vals = paddle.full(qk_out.shape, -1e4, qk_out.dtype)
attn_mask_out = paddle.where(attn_mask, attn_mask_vals, qk_out)
- attn_mask_out = paddle.cast(attn_mask_out, 'float32')
+ attn_mask_out = paddle.cast(attn_mask_out, "float32")
softmax_out = F.softmax(attn_mask_out)
softmax_out = paddle.cast(softmax_out, self.dtype)
@@ -725,7 +767,7 @@ class TestFusedAttn:
else:
qkv_out = paddle.matmul(softmax_out, v_out)
- out = paddle.transpose(qkv_out, perm=[0, 2, 1, 3]) # [b, h, s, d] -> [b, s, h, d]
+ out = paddle.transpose(qkv_out, perm=[0, 2, 1, 3]) # [b, h, s, d] -> [b, s, h, d]
paddle.autograd.backward(
[out],
@@ -738,17 +780,17 @@ class TestFusedAttn:
paddle.disable_static(place=paddle.CUDAPlace(0))
if self.attn_mode == "self_attn":
- qkv = np.stack([self.q, self.kv, self.kv], axis=2) # [b, s, 3, h, d]
+ qkv = np.stack([self.q, self.kv, self.kv], axis=2) # [b, s, 3, h, d]
qkv_tensor = paddle.to_tensor(qkv, stop_gradient=False)
else:
q_tensor = paddle.to_tensor(self.q, stop_gradient=False)
- kv = np.stack([self.kv, self.kv], axis=2) # [b, s, 2, h, d]
+ kv = np.stack([self.kv, self.kv], axis=2) # [b, s, 2, h, d]
kv_tensor = paddle.to_tensor(kv, stop_gradient=False)
q_cu_seqlen_tensor = paddle.to_tensor(self.q_cu_seqlen, dtype="int32", stop_gradient=True)
kv_cu_seqlen_tensor = paddle.to_tensor(self.kv_cu_seqlen, dtype="int32", stop_gradient=True)
- qkv_layout = ("bs3hd" if self.attn_mode == "self_attn" else "bshd_bs2hd")
+ qkv_layout = "bs3hd" if self.attn_mode == "self_attn" else "bshd_bs2hd"
fused_attention_backend = get_fused_attention_backend(
num_heads=self.num_heads,
num_gqa_groups=self.num_heads,
@@ -764,7 +806,7 @@ class TestFusedAttn:
qkv_dtype = tex.DType.kBFloat16 if self.dtype == "bfloat16" else tex.DType.kFloat16
out, softmax_aux_tensor, q_grad, k_grad, v_grad = None, None, None, None, None
- if self.attn_mode == 'self_attn':
+ if self.attn_mode == "self_attn":
out, softmax_aux_tensor, rng_state = fused_attn_fwd_qkvpacked(
qkv_tensor,
q_cu_seqlen_tensor,
@@ -776,7 +818,8 @@ class TestFusedAttn:
attn_scale=self.scaling_factor,
dropout=self.dropout_prob,
set_zero=False,
- attn_mask_type="causal" if self.is_causal_masking else "padding")
+ attn_mask_type="causal" if self.is_causal_masking else "padding",
+ )
dqkv, _ = fused_attn_bwd_qkvpacked(
qkv_tensor,
q_cu_seqlen_tensor,
@@ -790,11 +833,12 @@ class TestFusedAttn:
attn_scale=self.scaling_factor,
dropout=self.dropout_prob,
set_zero=False,
- attn_mask_type="causal" if self.is_causal_masking else "padding")
+ attn_mask_type="causal" if self.is_causal_masking else "padding",
+ )
q_grad = dqkv[:, :, 0, :, :]
k_grad = dqkv[:, :, 1, :, :]
v_grad = dqkv[:, :, 2, :, :]
- else: # attn_mode == 'cross_attn'
+ else: # attn_mode == 'cross_attn'
out, softmax_aux_tensor, rng_state = fused_attn_fwd_kvpacked(
q_tensor,
kv_tensor,
@@ -808,22 +852,25 @@ class TestFusedAttn:
Bias=None,
attn_scale=self.scaling_factor,
dropout=self.dropout_prob,
- set_zero=False)
- dq, dkv, _ = fused_attn_bwd_kvpacked(q_tensor,
- kv_tensor,
- q_cu_seqlen_tensor,
- kv_cu_seqlen_tensor,
- rng_state,
- out,
- self.dout,
- softmax_aux_tensor,
- fused_attention_backend=fused_attention_backend,
- max_seqlen_q=self.q_seqlen,
- max_seqlen_kv=self.kv_seqlen,
- qkv_dtype=qkv_dtype,
- attn_scale=self.scaling_factor,
- dropout=self.dropout_prob,
- set_zero=False)
+ set_zero=False,
+ )
+ dq, dkv, _ = fused_attn_bwd_kvpacked(
+ q_tensor,
+ kv_tensor,
+ q_cu_seqlen_tensor,
+ kv_cu_seqlen_tensor,
+ rng_state,
+ out,
+ self.dout,
+ softmax_aux_tensor,
+ fused_attention_backend=fused_attention_backend,
+ max_seqlen_q=self.q_seqlen,
+ max_seqlen_kv=self.kv_seqlen,
+ qkv_dtype=qkv_dtype,
+ attn_scale=self.scaling_factor,
+ dropout=self.dropout_prob,
+ set_zero=False,
+ )
q_grad = dq
k_grad = dkv[:, :, 0, :, :]
v_grad = dkv[:, :, 1, :, :]
@@ -871,7 +918,8 @@ class TestFusedAttn:
dropout=self.dropout_prob,
set_zero=False,
qkv_layout=qkv_layout,
- attn_mask_type="causal" if self.is_causal_masking else "padding")
+ attn_mask_type="causal" if self.is_causal_masking else "padding",
+ )
dq, dk, dv, _ = fused_attn_bwd(
q_tensor,
k_tensor,
@@ -890,28 +938,29 @@ class TestFusedAttn:
dropout=self.dropout_prob,
set_zero=False,
qkv_layout=qkv_layout,
- attn_mask_type="causal" if self.is_causal_masking else "padding")
+ attn_mask_type="causal" if self.is_causal_masking else "padding",
+ )
return out, dq, dk, dv
- @pytest.mark.parametrize('b, s, h, d', SELF_ATTN_CASES)
- @pytest.mark.parametrize('dtype', ['float16', 'bfloat16'])
- @pytest.mark.parametrize('is_causal_masking', [True, False])
+ @pytest.mark.parametrize("b, s, h, d", SELF_ATTN_CASES)
+ @pytest.mark.parametrize("dtype", ["float16", "bfloat16"])
+ @pytest.mark.parametrize("is_causal_masking", [True, False])
def test_self_attn_forward_backward(self, b, s, h, d, dtype, is_causal_masking):
"""
test self attention forward + backward
"""
if not is_fused_attention_supported(
- num_heads=h,
- num_gqa_groups=h,
- q_seqlen=s,
- kv_seqlen=s,
- head_size=d,
- dtype=dtype,
- dropout=0.0,
- qkv_layout="bs3hd",
- bias_type="no_bias",
- mask_type="causal" if is_causal_masking else "padding",
+ num_heads=h,
+ num_gqa_groups=h,
+ q_seqlen=s,
+ kv_seqlen=s,
+ head_size=d,
+ dtype=dtype,
+ dropout=0.0,
+ qkv_layout="bs3hd",
+ bias_type="no_bias",
+ mask_type="causal" if is_causal_masking else "padding",
):
pytest.skip("cuDNN fused attention is not supported")
self.set_input(b, s, s, h, d, dtype, "self_attn", is_causal_masking)
@@ -922,23 +971,23 @@ class TestFusedAttn:
assert_allclose(k_grad_ref, k_grad, rtol=1e-3, atol=1e-2)
assert_allclose(v_grad_ref, v_grad, rtol=1e-3, atol=1e-2)
- @pytest.mark.parametrize('b, s_q, s_kv, h, d', CROSS_ATTN_CASES)
- @pytest.mark.parametrize('dtype', ['float16', 'bfloat16'])
+ @pytest.mark.parametrize("b, s_q, s_kv, h, d", CROSS_ATTN_CASES)
+ @pytest.mark.parametrize("dtype", ["float16", "bfloat16"])
def test_cross_attn_forward_backward(self, b, s_q, s_kv, h, d, dtype):
"""
test cross attention forward + backward
"""
if not is_fused_attention_supported(
- num_heads=h,
- num_gqa_groups=h,
- q_seqlen=s_q,
- kv_seqlen=s_kv,
- head_size=d,
- dtype=dtype,
- dropout=0.0,
- qkv_layout="bshd_bs2hd",
- bias_type="no_bias",
- mask_type="padding",
+ num_heads=h,
+ num_gqa_groups=h,
+ q_seqlen=s_q,
+ kv_seqlen=s_kv,
+ head_size=d,
+ dtype=dtype,
+ dropout=0.0,
+ qkv_layout="bshd_bs2hd",
+ bias_type="no_bias",
+ mask_type="padding",
):
pytest.skip("cuDNN fused attention is not supported")
self.set_input(b, s_q, s_kv, h, d, dtype, "cross_attn")
@@ -949,24 +998,24 @@ class TestFusedAttn:
assert_allclose(k_grad_ref, k_grad, rtol=1e-3, atol=1e-2)
assert_allclose(v_grad_ref, v_grad, rtol=1e-3, atol=1e-2)
- @pytest.mark.parametrize('b, s, h, d', FLASH_ATTN_CASES)
- @pytest.mark.parametrize('dtype', ['float16', 'bfloat16'])
- @pytest.mark.parametrize('is_causal_masking', [True])
+ @pytest.mark.parametrize("b, s, h, d", FLASH_ATTN_CASES)
+ @pytest.mark.parametrize("dtype", ["float16", "bfloat16"])
+ @pytest.mark.parametrize("is_causal_masking", [True])
def test_flash_attn_forward_backward(self, b, s, h, d, dtype, is_causal_masking):
"""
test flash attention forward + backward
"""
if not is_fused_attention_supported(
- num_heads=h,
- num_gqa_groups=h,
- q_seqlen=s,
- kv_seqlen=s,
- head_size=d,
- dtype=dtype,
- dropout=0.0,
- qkv_layout="bs3hd",
- bias_type="no_bias",
- mask_type="causal" if is_causal_masking else "padding",
+ num_heads=h,
+ num_gqa_groups=h,
+ q_seqlen=s,
+ kv_seqlen=s,
+ head_size=d,
+ dtype=dtype,
+ dropout=0.0,
+ qkv_layout="bs3hd",
+ bias_type="no_bias",
+ mask_type="causal" if is_causal_masking else "padding",
):
pytest.skip("cuDNN fused attention is not supported")
self.set_input(b, s, s, h, d, dtype, "self_attn", is_causal_masking)
@@ -977,25 +1026,26 @@ class TestFusedAttn:
assert_allclose(k_grad_ref, k_grad, rtol=1e-3, atol=1e-2)
assert_allclose(v_grad_ref, v_grad, rtol=1e-3, atol=1e-2)
- @pytest.mark.parametrize('b, s, h, d', FLASH_ATTN_CASES)
- @pytest.mark.parametrize('dtype', ['float16', 'bfloat16'])
- @pytest.mark.parametrize('is_causal_masking', [False, True])
- def test_fused_attn_with_separate_qkv_forward_backward(self, b, s, h, d, dtype,
- is_causal_masking):
+ @pytest.mark.parametrize("b, s, h, d", FLASH_ATTN_CASES)
+ @pytest.mark.parametrize("dtype", ["float16", "bfloat16"])
+ @pytest.mark.parametrize("is_causal_masking", [False, True])
+ def test_fused_attn_with_separate_qkv_forward_backward(
+ self, b, s, h, d, dtype, is_causal_masking
+ ):
"""
test flash attention forward + backward with separate qkv inputs
"""
if not is_fused_attention_supported(
- num_heads=h,
- num_gqa_groups=h,
- q_seqlen=s,
- kv_seqlen=s,
- head_size=d,
- dtype=dtype,
- dropout=0.0,
- qkv_layout="bshd_bshd_bshd",
- bias_type="no_bias",
- mask_type="causal" if is_causal_masking else "padding",
+ num_heads=h,
+ num_gqa_groups=h,
+ q_seqlen=s,
+ kv_seqlen=s,
+ head_size=d,
+ dtype=dtype,
+ dropout=0.0,
+ qkv_layout="bshd_bshd_bshd",
+ bias_type="no_bias",
+ mask_type="causal" if is_causal_masking else "padding",
):
pytest.skip("cuDNN fused attention is not supported")
self.set_input(b, s, s, h, d, dtype, "self_attn", is_causal_masking)
@@ -1013,7 +1063,7 @@ class TestSoftmax:
"""
@staticmethod
- @pytest.mark.parametrize('dtype', ['float16', 'bfloat16'])
+ @pytest.mark.parametrize("dtype", ["float16", "bfloat16"])
def test_scaled_softmax_fwd_bwd(dtype):
"""test scaled softmax"""
B, H, S = (16, 4, 32)
@@ -1034,7 +1084,7 @@ class TestSoftmax:
assert_allclose(dx_ref, dx, rtol=1e-4, atol=1e-3)
@staticmethod
- @pytest.mark.parametrize('dtype', ['float16', 'bfloat16'])
+ @pytest.mark.parametrize("dtype", ["float16", "bfloat16"])
def test_scaled_masked_softmax_fwd_bwd(dtype):
"""test scaled masked softmax"""
B, H, S = (16, 4, 32)
@@ -1058,7 +1108,7 @@ class TestSoftmax:
assert_allclose(dx_ref, dx, rtol=1e-4, atol=1e-3)
@staticmethod
- @pytest.mark.parametrize('dtype', ['float16', 'bfloat16'])
+ @pytest.mark.parametrize("dtype", ["float16", "bfloat16"])
def test_scaled_upper_triang_masked_softmax_fwd_bwd(dtype):
"""test scaled upper triang masked softmax"""
B, S = (16, 32)
@@ -1068,7 +1118,7 @@ class TestSoftmax:
x.stop_gradient = False
dy = paddle.uniform(shape=(B, S, S), dtype=dtype)
- mask = paddle.ones((S, S), dtype='int32')
+ mask = paddle.ones((S, S), dtype="int32")
col_beg, col_end = 1, S
for row in range(0, S):
mask[row, col_beg:col_end] = 0
@@ -1087,7 +1137,7 @@ class TestSoftmax:
assert_allclose(dx_ref, dx, rtol=1e-4, atol=5e-3)
-@pytest.mark.parametrize('update_weight_scale_inv', [True, False])
+@pytest.mark.parametrize("update_weight_scale_inv", [True, False])
def test_amax_and_scale_update(update_weight_scale_inv):
"""Test update_scale"""
num_gemm = 6
@@ -1097,11 +1147,11 @@ def test_amax_and_scale_update(update_weight_scale_inv):
fp8_max = recipe.fp8_format.value.max_fwd
non_weight_mask = paddle.to_tensor([True, False] * (num_gemm // 2))
- amax_history_tensor = paddle.rand(shape=[history_len, num_gemm], dtype='float32')
+ amax_history_tensor = paddle.rand(shape=[history_len, num_gemm], dtype="float32")
rolled_history_ref = paddle.roll(amax_history_tensor, -1, axis=0)
rolled_history_ref[0] = 0.0
amax_tensor = paddle.max(amax_history_tensor, axis=0)
- scale_tensor = paddle.ones(shape=[num_gemm], dtype='float32')
+ scale_tensor = paddle.ones(shape=[num_gemm], dtype="float32")
def calc_ref(amax, scale, fp8_max, margin=0):
"""Calculate reference scale"""
@@ -1110,12 +1160,12 @@ def test_amax_and_scale_update(update_weight_scale_inv):
sf = paddle.where(paddle.isfinite(amax), sf, scale)
return sf
- scale_ref = calc_ref(amax_tensor, scale_tensor, fp8_max, 0.)
+ scale_ref = calc_ref(amax_tensor, scale_tensor, fp8_max, 0.0)
if update_weight_scale_inv:
- scale_inv_ref = 1. / scale_ref
+ scale_inv_ref = 1.0 / scale_ref
else:
scale_inv_ref = paddle.zeros_like(scale_tensor)
- scale_inv_ref = paddle.where(non_weight_mask, 1. / scale_ref, scale_inv_ref)
+ scale_inv_ref = paddle.where(non_weight_mask, 1.0 / scale_ref, scale_inv_ref)
# Placeholder
scale_actual = paddle.zeros_like(scale_tensor)
@@ -1123,13 +1173,15 @@ def test_amax_and_scale_update(update_weight_scale_inv):
if update_weight_scale_inv:
non_weight_mask = paddle.empty([0])
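+ # Passing an empty mask presumably signals the kernel to update scale_inv for every tensor, weights included.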
- tex.amax_and_scale_update_inplace(_amax_history=amax_history_tensor,
- _scale=scale_actual,
- _scale_inv=scale_inv_actual,
- non_weight_mask=non_weight_mask,
- fp8_dtype=int(fp8_dtype),
- margin=0.,
- amax_compute="max")
+ tex.amax_and_scale_update_inplace(
+ _amax_history=amax_history_tensor,
+ _scale=scale_actual,
+ _scale_inv=scale_inv_actual,
+ non_weight_mask=non_weight_mask,
+ fp8_dtype=int(fp8_dtype),
+ margin=0.0,
+ amax_compute="max",
+ )
assert_allclose(scale_actual, scale_ref, rtol=1e-7, atol=1e-7)
assert_allclose(scale_inv_actual, scale_inv_ref, rtol=1e-7, atol=1e-7)
@@ -1141,8 +1193,8 @@ def test_update_latest_history():
num_gemm = 6
history_len = 1024
- amax_history_tensor = paddle.rand(shape=[history_len, num_gemm], dtype='float32')
- amax = paddle.rand(shape=[num_gemm], dtype='float32')
+ amax_history_tensor = paddle.rand(shape=[history_len, num_gemm], dtype="float32")
+ amax = paddle.rand(shape=[num_gemm], dtype="float32")
tex.update_latest_amax_history_inplace(_history=amax_history_tensor, amax=amax)
diff --git a/tests/paddle/test_parallel.py b/tests/paddle/test_parallel.py
index 9c8f399e3e01283b64b070a382b7cc3b739bbfaa..f07d56d44bd9539ac228fe2046483b5bc158b002 100644
--- a/tests/paddle/test_parallel.py
+++ b/tests/paddle/test_parallel.py
@@ -22,7 +22,7 @@ class TestParallelLinear(TestDistributed):
@unittest.skipIf(not gpu_has_fp8, reason)
def test_linear_tp(self):
"""Tests linear with tensor parallel in BF16"""
- self.run_2gpu(str(test_root / 'parallel_tests' / 'linear_tp.py'))
+ self.run_2gpu(str(test_root / "parallel_tests" / "linear_tp.py"))
class TestParallelLayerNormLinear(TestDistributed):
@@ -32,7 +32,7 @@ class TestParallelLayerNormLinear(TestDistributed):
@unittest.skipIf(not gpu_has_fp8, reason)
def test_layernorm_linear_tp(self):
"""Tests layernorm_linear with tensor parallel in BF16"""
- self.run_2gpu(str(test_root / 'parallel_tests' / 'layernorm_linear_tp.py'))
+ self.run_2gpu(str(test_root / "parallel_tests" / "layernorm_linear_tp.py"))
class TestParallelLayerNormMLP(TestDistributed):
@@ -42,7 +42,7 @@ class TestParallelLayerNormMLP(TestDistributed):
@unittest.skipIf(not gpu_has_fp8, reason)
def test_layernorm_mlp_tp(self):
"""Tests layernorm_mlp with tensor parallel in BF16"""
- self.run_2gpu(str(test_root / 'parallel_tests' / 'layernorm_mlp_tp.py'))
+ self.run_2gpu(str(test_root / "parallel_tests" / "layernorm_mlp_tp.py"))
class TestAmaxReduction(TestDistributed):
@@ -52,7 +52,7 @@ class TestAmaxReduction(TestDistributed):
@unittest.skipIf(not gpu_has_fp8, reason)
def test_amax_reduction(self):
"""Tests amax reduction"""
- self.run_2gpu(str(test_root / 'parallel_tests' / 'amax_reduction.py'))
+ self.run_2gpu(str(test_root / "parallel_tests" / "amax_reduction.py"))
class TestPipelineParallel(TestDistributed):
@@ -62,7 +62,7 @@ class TestPipelineParallel(TestDistributed):
@unittest.skipIf(not gpu_has_fp8, reason)
def test_pipeline_parallel(self):
"""Tests pipeline parallel"""
- self.run_2gpu(str(test_root / 'parallel_tests' / 'linear_pp.py'))
+ self.run_2gpu(str(test_root / "parallel_tests" / "linear_pp.py"))
class TestGroupSharding(TestDistributed):
@@ -72,7 +72,7 @@ class TestGroupSharding(TestDistributed):
@unittest.skipIf(not gpu_has_fp8, reason)
def test_group_sharding(self):
"""Tests group sharding"""
- self.run_2gpu(str(test_root / 'parallel_tests' / 'group_sharding.py'))
+ self.run_2gpu(str(test_root / "parallel_tests" / "group_sharding.py"))
class TestParallelAttention(TestDistributed):
@@ -82,7 +82,7 @@ class TestParallelAttention(TestDistributed):
@unittest.skipIf(not gpu_has_fp8, reason)
def test_attention_tp(self):
"""Tests TransMultiHeadAttentionformer Layer with tensor parallel in BF16"""
- self.run_2gpu(str(test_root / 'parallel_tests' / 'attention_tp.py'))
+ self.run_2gpu(str(test_root / "parallel_tests" / "attention_tp.py"))
class TestParallelTransformerLayer(TestDistributed):
@@ -92,8 +92,8 @@ class TestParallelTransformerLayer(TestDistributed):
@unittest.skipIf(not gpu_has_fp8, reason)
def test_transformer_tp(self):
"""Tests Transformer Layer with tensor parallel in BF16"""
- self.run_2gpu(str(test_root / 'parallel_tests' / 'transformer_tp.py'))
+ self.run_2gpu(str(test_root / "parallel_tests" / "transformer_tp.py"))
-if __name__ == '__main__':
+if __name__ == "__main__":
unittest.main()
diff --git a/tests/paddle/test_recompute.py b/tests/paddle/test_recompute.py
index 769921c6f57045d82c63e34e26c1c5b7f048f3fb..02dddad2107483c8528bca6327fc0dcde0de2691 100644
--- a/tests/paddle/test_recompute.py
+++ b/tests/paddle/test_recompute.py
@@ -17,7 +17,7 @@ is_fp8_supported, reason = is_fp8_available()
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
-@pytest.mark.parametrize('use_reentrant', [False, True])
+@pytest.mark.parametrize("use_reentrant", [False, True])
def test_transformer_encoder_recompute(use_reentrant):
"""
Test TransformerLayer encoder recompute
@@ -29,17 +29,17 @@ def test_transformer_encoder_recompute(use_reentrant):
"""Launch training in subprocess and check output"""
try:
cmd = [
- 'python',
- str(test_root / 'recompute_tests' / 'recompute_transformer_encoder.py'),
+ "python",
+ str(test_root / "recompute_tests" / "recompute_transformer_encoder.py"),
str(int(enable_recompute)),
- str(int(use_reentrant))
+ str(int(use_reentrant)),
]
result = subprocess.check_output(cmd, stderr=subprocess.STDOUT, universal_newlines=True)
print(result)
- loss_match = re.search(r'Loss:\s+(-?\d+\.\d+)', result)
- memory_match = re.search(r'Peak memory:\s+(\d+)', result)
+ loss_match = re.search(r"Loss:\s+(-?\d+\.\d+)", result)
+ memory_match = re.search(r"Peak memory:\s+(\d+)", result)
loss_value = float(loss_match.group(1))
memory_value = int(memory_match.group(1))
diff --git a/tests/paddle/test_sanity_import.py b/tests/paddle/test_sanity_import.py
index 8245de4bcf0c028aa233c0f10ff261146334984c..9b38d543da7a687b8c952b110953686025b39065 100644
--- a/tests/paddle/test_sanity_import.py
+++ b/tests/paddle/test_sanity_import.py
@@ -3,4 +3,5 @@
# See LICENSE for license information.
import transformer_engine.paddle
+
print("OK")
diff --git a/tests/paddle/utils.py b/tests/paddle/utils.py
index 8d65711939ab532da5dd35b91027e4faaad32177..572af66ff9979b9dc5a0a66c111a8411532c32d0 100644
--- a/tests/paddle/utils.py
+++ b/tests/paddle/utils.py
@@ -11,7 +11,7 @@ import paddle
from paddle.distributed import fleet
from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker
-import transformer_engine # pylint: disable=unused-import
+import transformer_engine # pylint: disable=unused-import
from transformer_engine.paddle.constants import (
TE_DType,
AttnBiasType,
@@ -19,7 +19,9 @@ from transformer_engine.paddle.constants import (
FusedAttnBackend,
)
from transformer_engine.paddle.fp8 import FP8TensorMeta
-from transformer_engine import transformer_engine_paddle as tex # pylint: disable=wrong-import-order
+from transformer_engine import (
+ transformer_engine_paddle as tex,
+) # pylint: disable=wrong-import-order
def create_fp8_meta(num_gemms=1, amax_history_len=10):
@@ -31,18 +33,14 @@ def create_fp8_meta(num_gemms=1, amax_history_len=10):
return fp8_meta
-def assert_allclose(actual,
- desired,
- rtol=1e-05,
- atol=1e-08,
- equal_nan=True,
- err_msg='',
- verbose=True):
+def assert_allclose(
+ actual, desired, rtol=1e-05, atol=1e-08, equal_nan=True, err_msg="", verbose=True
+):
"""Compare two input paddle tensors"""
if isinstance(actual, paddle.Tensor):
- actual = paddle.cast(actual, 'float32')
+ actual = paddle.cast(actual, "float32")
if isinstance(desired, paddle.Tensor):
- desired = paddle.cast(desired, 'float32')
+ desired = paddle.cast(desired, "float32")
if len(actual.shape) == 0:
actual = actual.item()
desired = desired.item()
@@ -54,8 +52,9 @@ def assert_allclose(actual,
def assert_shape(inp, expected_shape):
"""Assert the shape of input tensor equals to expected shape"""
- assert inp.shape == expected_shape, f"Expected tensor shape: {expected_shape} != " \
- f"actual tensor shape: {inp.shape}"
+ assert (
+ inp.shape == expected_shape
+ ), f"Expected tensor shape: {expected_shape} != actual tensor shape: {inp.shape}"
def is_devices_enough(required):
@@ -91,12 +90,21 @@ def set_random_seed(seed):
np.random.seed(seed + 100 * pp_rank)
seed_offset = seed + 1024 + paddle.distributed.get_world_size()
- global_seed = (seed_offset + pp_rank * (mp_size) + dp_rank * (mp_size * pp_size) +
- sharding_rank * (mp_size * pp_size * dp_size))
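+ # Each (pp, dp, sharding) coordinate derives its own global seed; local_seed below
+ # also folds in mp_rank so tensor-parallel ranks draw different dropout masks.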
+ global_seed = (
+ seed_offset
+ + pp_rank * (mp_size)
+ + dp_rank * (mp_size * pp_size)
+ + sharding_rank * (mp_size * pp_size * dp_size)
+ )
seed_offset += paddle.distributed.get_world_size()
- local_seed = (seed_offset + mp_rank + pp_rank * (mp_size) + dp_rank * (mp_size * pp_size) +
- sharding_rank * (mp_size * pp_size * dp_size))
+ local_seed = (
+ seed_offset
+ + mp_rank
+ + pp_rank * (mp_size)
+ + dp_rank * (mp_size * pp_size)
+ + sharding_rank * (mp_size * pp_size * dp_size)
+ )
tracker = get_rng_state_tracker()
# tracker.reset()
diff --git a/tests/pytorch/distributed/print_logs.py b/tests/pytorch/distributed/print_logs.py
index 20062e064fbad142d239cd2453dd31f190686a53..6c25db49452ba5c74668a613784119f401499a8b 100644
--- a/tests/pytorch/distributed/print_logs.py
+++ b/tests/pytorch/distributed/print_logs.py
@@ -112,13 +112,17 @@ def perf_and_loss_plots():
lm_loss_data.append(lm_data["loss"])
lm_perf_data.append(lm_data["perf"])
save_plot(
- model_config + " loss", legend,
- lm_loss_data, model_config + "_loss.png",
+ model_config + " loss",
+ legend,
+ lm_loss_data,
+ model_config + "_loss.png",
"LM-Loss",
)
save_plot(
model_config + " perf",
- legend, lm_perf_data, model_config + "_perf.png",
+ legend,
+ lm_perf_data,
+ model_config + "_perf.png",
"Time per step (ms)",
)
diff --git a/tests/pytorch/distributed/test_convergence.py b/tests/pytorch/distributed/test_convergence.py
index 2a4c6a7282480be15ba0fdbb324cb424d259305b..fa3ba1e3f38830ec96b56392170d897f5c56da32 100644
--- a/tests/pytorch/distributed/test_convergence.py
+++ b/tests/pytorch/distributed/test_convergence.py
@@ -68,7 +68,9 @@ def get_filename(
config = f"gpt3_{model}_dp{dp}_tp{tp}_pp{pp}_sp{sp}"
config_dir = os.path.join(mlm_log_dir, config)
os.makedirs(config_dir, exist_ok=True)
- fname = f"{'te' if use_te else 'megatron'}" + (f"_fp8_{fp8_recipe}" if fp8_recipe else "") + ".txt"
+ fname = (
+ f"{'te' if use_te else 'megatron'}" + (f"_fp8_{fp8_recipe}" if fp8_recipe else "") + ".txt"
+ )
return os.path.join(config_dir, fname)
@@ -106,4 +108,5 @@ def test_distributed(dtype, fp8_recipe, dp, tp, pp, sp, use_te, model):
TRANSFORMER_IMPL="transformer_engine" if use_te else "local",
**asdict(model_configs[model]),
),
- check=True)
+ check=True,
+ )
diff --git a/tests/pytorch/fused_attn/run_fused_attn_with_cp.py b/tests/pytorch/fused_attn/run_fused_attn_with_cp.py
index c35fc339318c8b1a8ec205170e8d9cf54522e86d..9b9b7686c250019a4b4ac74e1c0c9dd40bbc3a42 100644
--- a/tests/pytorch/fused_attn/run_fused_attn_with_cp.py
+++ b/tests/pytorch/fused_attn/run_fused_attn_with_cp.py
@@ -9,9 +9,10 @@ from transformer_engine.pytorch.attention import DotProductAttention
import transformer_engine_torch as tex
from test_fused_attn_with_cp import model_configs_flash_attn, model_configs_fused_attn
-dtypes={'fp16' : torch.float16, 'bf16' : torch.bfloat16}
+dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16}
-def run_dpa_with_cp(dtype='bf16', model=None, qkv_format='bshd', kernel_backend='FlashAttention'):
+
+def run_dpa_with_cp(dtype="bf16", model=None, qkv_format="bshd", kernel_backend="FlashAttention"):
"""Test DotProductAttention module with context parallelism"""
os.environ["NVTE_FLASH_ATTN"] = "0"
@@ -22,11 +23,13 @@ def run_dpa_with_cp(dtype='bf16', model=None, qkv_format='bshd', kernel_backend=
if kernel_backend == "FusedAttention":
os.environ["NVTE_FUSED_ATTN"] = "1"
config = model_configs_fused_attn[model]
- if qkv_format == 'thd' and (config.num_heads != config.num_gqa_groups or config.attn_bias_type == "post_scale_bias"):
+ if qkv_format == "thd" and (
+ config.num_heads != config.num_gqa_groups or config.attn_bias_type == "post_scale_bias"
+ ):
return
- rank = int(os.getenv('RANK', '0'))
- world_size = int(os.getenv('WORLD_SIZE', '1'))
+ rank = int(os.getenv("RANK", "0"))
+ world_size = int(os.getenv("WORLD_SIZE", "1"))
if dist.is_initialized():
world_size = dist.get_world_size()
@@ -38,51 +41,76 @@ def run_dpa_with_cp(dtype='bf16', model=None, qkv_format='bshd', kernel_backend=
print(f"[INFO] world_size:{world_size}, rank:{rank}")
- dist.init_process_group(backend='nccl', world_size=world_size, rank=rank)
+ dist.init_process_group(backend="nccl", world_size=world_size, rank=rank)
# create flash attn comm group for CP
cp_comm_ranks = range(world_size)
- assert(rank in cp_comm_ranks)
- cp_comm_group = dist.new_group(cp_comm_ranks, backend='nccl')
+ assert rank in cp_comm_ranks
+ cp_comm_group = dist.new_group(cp_comm_ranks, backend="nccl")
- assert config.attn_mask_type in ['causal', 'no_mask'], f"{config.attn_mask_type} is an unsupported attention mask type!"
+ assert config.attn_mask_type in [
+ "causal",
+ "no_mask",
+ ], f"{config.attn_mask_type} is an unsupported attention mask type!"
- if kernel_backend == 'FusedAttention' and qkv_format == 'thd':
- if 'causal' in config.attn_mask_type:
- config.attn_mask_type = 'padding_causal'
+ if kernel_backend == "FusedAttention" and qkv_format == "thd":
+ if "causal" in config.attn_mask_type:
+ config.attn_mask_type = "padding_causal"
else:
- config.attn_mask_type = 'padding'
+ config.attn_mask_type = "padding"
# instantiate core attn module
- core_attn = DotProductAttention(config.num_heads,
- config.head_dim,
- num_gqa_groups=config.num_gqa_groups,
- attention_dropout=config.dropout_p,
- qkv_format=qkv_format,
- attn_mask_type=config.attn_mask_type)
+ core_attn = DotProductAttention(
+ config.num_heads,
+ config.head_dim,
+ num_gqa_groups=config.num_gqa_groups,
+ attention_dropout=config.dropout_p,
+ qkv_format=qkv_format,
+ attn_mask_type=config.attn_mask_type,
+ )
core_attn = core_attn.cuda()
# create flash attn inputs
if qkv_format == "bshd":
q_input_shape = (config.batch_size, config.max_seqlen_q, config.num_heads, config.head_dim)
- kv_input_shape = (config.batch_size, config.max_seqlen_kv, config.num_gqa_groups, config.head_dim)
- attn_output_shape = (config.batch_size, config.max_seqlen_q, config.num_heads*config.head_dim)
+ kv_input_shape = (
+ config.batch_size,
+ config.max_seqlen_kv,
+ config.num_gqa_groups,
+ config.head_dim,
+ )
+ attn_output_shape = (
+ config.batch_size,
+ config.max_seqlen_q,
+ config.num_heads * config.head_dim,
+ )
cu_seqlens_q = None
cu_seqlens_kv = None
elif qkv_format == "sbhd":
q_input_shape = (config.max_seqlen_q, config.batch_size, config.num_heads, config.head_dim)
- kv_input_shape = (config.max_seqlen_kv, config.batch_size, config.num_gqa_groups, config.head_dim)
- attn_output_shape = (config.max_seqlen_q, config.batch_size, config.num_heads*config.head_dim)
+ kv_input_shape = (
+ config.max_seqlen_kv,
+ config.batch_size,
+ config.num_gqa_groups,
+ config.head_dim,
+ )
+ attn_output_shape = (
+ config.max_seqlen_q,
+ config.batch_size,
+ config.num_heads * config.head_dim,
+ )
cu_seqlens_q = None
cu_seqlens_kv = None
elif qkv_format == "thd":
- seqlens_q = torch.randint(world_size * 2, config.max_seqlen_q + 1, [config.batch_size]).to(torch.int32)
+ seqlens_q = torch.randint(world_size * 2, config.max_seqlen_q + 1, [config.batch_size]).to(
+ torch.int32
+ )
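+ # Round each sequence length down to a multiple of 2*world_size so every sequence splits evenly across CP ranks.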
seqlens_q = seqlens_q - seqlens_q % (world_size * 2)
cu_seqlens_q = torch.cat([torch.zeros([1], dtype=torch.int32), seqlens_q.cumsum(0)])
cu_seqlens_kv = cu_seqlens_q
q_input_shape = (cu_seqlens_q[-1], config.num_heads, config.head_dim)
kv_input_shape = (cu_seqlens_kv[-1], config.num_gqa_groups, config.head_dim)
- attn_output_shape = (cu_seqlens_q[-1], config.num_heads*config.head_dim)
+ attn_output_shape = (cu_seqlens_q[-1], config.num_heads * config.head_dim)
cu_seqlens_q = cu_seqlens_q.to(torch.int32).cuda()
cu_seqlens_kv = cu_seqlens_kv.to(torch.int32).cuda()
else:
@@ -111,7 +139,9 @@ def run_dpa_with_cp(dtype='bf16', model=None, qkv_format='bshd', kernel_backend=
for x in [q, k, v]:
x.requires_grad = True
out = core_attn(
- q, k, v,
+ q,
+ k,
+ v,
core_attention_bias_type=config.attn_bias_type,
core_attention_bias=bias,
cu_seqlens_q=cu_seqlens_q,
@@ -120,17 +150,28 @@ def run_dpa_with_cp(dtype='bf16', model=None, qkv_format='bshd', kernel_backend=
out.backward(dout)
# run core_attn with CP
- q_, k_, v_, dout_, *rest = [x.clone().detach() for x in [q, k, v, dout] + ([] if bias is None else [bias])]
+ q_, k_, v_, dout_, *rest = [
+ x.clone().detach() for x in [q, k, v, dout] + ([] if bias is None else [bias])
+ ]
bias_ = rest[0] if len(rest) else None
if qkv_format == "bshd" or qkv_format == "sbhd":
- seq_dim = qkv_format.index('s')
- q_, k_, v_, dout_ = [x.view(*x.shape[:seq_dim], 2*world_size, x.shape[seq_dim]//(2*world_size), *x.shape[(seq_dim+1):]) \
- for x in [q_, k_, v_, dout_]]
- seq_idx = torch.tensor([rank, 2*world_size-rank-1], device=q_.device)
+ seq_dim = qkv_format.index("s")
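+ # View the sequence as 2*world_size chunks; rank i takes chunks i and (2*world_size-1-i)
+ # to balance the causal-attention workload across CP ranks.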
+ q_, k_, v_, dout_ = [
+ x.view(
+ *x.shape[:seq_dim],
+ 2 * world_size,
+ x.shape[seq_dim] // (2 * world_size),
+ *x.shape[(seq_dim + 1) :],
+ )
+ for x in [q_, k_, v_, dout_]
+ ]
+ seq_idx = torch.tensor([rank, 2 * world_size - rank - 1], device=q_.device)
q_, k_, v_, dout_ = [x.index_select(seq_dim, seq_idx) for x in [q_, k_, v_, dout_]]
- q_, k_, v_, dout_ = [x.view(*x.shape[:seq_dim], -1, *x.shape[(seq_dim+2):]) for x in [q_, k_, v_, dout_]]
+ q_, k_, v_, dout_ = [
+ x.view(*x.shape[:seq_dim], -1, *x.shape[(seq_dim + 2) :]) for x in [q_, k_, v_, dout_]
+ ]
elif qkv_format == "thd":
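+ # For thd, tokens are gathered per sequence via partitioned indices rather than sliced as whole chunks.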
- seq_idx_q = tex.thd_get_partitioned_indices(cu_seqlens_q, q_.size(0), world_size, rank)
+ seq_idx_q = tex.thd_get_partitioned_indices(cu_seqlens_q, q_.size(0), world_size, rank)
seq_idx_kv = tex.thd_get_partitioned_indices(cu_seqlens_kv, k_.size(0), world_size, rank)
q_, dout_ = [x.index_select(0, seq_idx_q) for x in [q_, dout_]]
k_, v_ = [x.index_select(0, seq_idx_kv) for x in [k_, v_]]
@@ -140,14 +181,18 @@ def run_dpa_with_cp(dtype='bf16', model=None, qkv_format='bshd', kernel_backend=
assert False, f"{qkv_format} is an unsupported qkv_format!"
q_, k_, v_ = [x.requires_grad_() for x in [q_, k_, v_]]
if bias_ is not None:
- bias_ = bias_.view(*bias_.shape[:-2], 2*world_size, bias_.shape[-2]//(2*world_size), bias_.shape[-1])
+ bias_ = bias_.view(
+ *bias_.shape[:-2], 2 * world_size, bias_.shape[-2] // (2 * world_size), bias_.shape[-1]
+ )
bias_ = bias_.index_select(2, seq_idx)
bias_ = bias_.view(*bias_.shape[:2], -1, bias_.shape[-1])
core_attn.set_context_parallel_group(cp_comm_group, cp_comm_ranks, torch.cuda.Stream())
- max_seqlen_q = config.max_seqlen_q
+ max_seqlen_q = config.max_seqlen_q
max_seqlen_kv = config.max_seqlen_kv
out_ = core_attn(
- q_, k_, v_,
+ q_,
+ k_,
+ v_,
core_attention_bias_type=config.attn_bias_type,
core_attention_bias=bias_,
cu_seqlens_q=cu_seqlens_q,
@@ -158,23 +203,32 @@ def run_dpa_with_cp(dtype='bf16', model=None, qkv_format='bshd', kernel_backend=
out_.backward(dout_)
for x in [out_, q_.grad, k_.grad, v_.grad]:
- assert(torch.all(~torch.isnan(x)))
- assert(torch.all(~torch.isinf(x)))
+ assert torch.all(~torch.isnan(x))
+ assert torch.all(~torch.isinf(x))
# compare results with and without CP
tols = dict(atol=5e-3, rtol=5e-3)
- if dtype == 'bf16':
+ if dtype == "bf16":
if config.num_heads == config.num_gqa_groups:
tols = dict(atol=2.5e-2, rtol=2.5e-2)
else:
tols = dict(atol=3.5e-2, rtol=3.5e-2)
if qkv_format == "bshd" or qkv_format == "sbhd":
- dq, dk, dv, out = [x.view(*x.shape[:seq_dim], 2*world_size, x.shape[seq_dim]//(2*world_size), *x.shape[(seq_dim+1):]) \
- for x in [q.grad, k.grad, v.grad, out]]
+ dq, dk, dv, out = [
+ x.view(
+ *x.shape[:seq_dim],
+ 2 * world_size,
+ x.shape[seq_dim] // (2 * world_size),
+ *x.shape[(seq_dim + 1) :],
+ )
+ for x in [q.grad, k.grad, v.grad, out]
+ ]
dq, dk, dv, out = [x.index_select(seq_dim, seq_idx) for x in [dq, dk, dv, out]]
- dq_, dk_, dv_, out_ = [x.view(*x.shape[:seq_dim], 2, x.shape[seq_dim]//2, *x.shape[(seq_dim+1):]) \
- for x in [q_.grad, k_.grad, v_.grad, out_]]
+ dq_, dk_, dv_, out_ = [
+ x.view(*x.shape[:seq_dim], 2, x.shape[seq_dim] // 2, *x.shape[(seq_dim + 1) :])
+ for x in [q_.grad, k_.grad, v_.grad, out_]
+ ]
elif qkv_format == "thd":
dq, out = [x.index_select(0, seq_idx_q).contiguous().view(-1) for x in [q.grad, out]]
dk, dv = [x.index_select(0, seq_idx_kv).contiguous().view(-1) for x in [k.grad, v.grad]]
@@ -208,9 +262,11 @@ def run_dpa_with_cp(dtype='bf16', model=None, qkv_format='bshd', kernel_backend=
else:
assert False, f"{qkv_format} is an unsupported qkv_format!"
+
def main(**kwargs):
run_dpa_with_cp(**kwargs)
+
if __name__ == "__main__":
- kwargs = dict(arg.split('=') for arg in sys.argv[2:])
+ kwargs = dict(arg.split("=") for arg in sys.argv[2:])
main(**kwargs)
diff --git a/tests/pytorch/fused_attn/test_fused_attn.py b/tests/pytorch/fused_attn/test_fused_attn.py
index ccd82c1d067e404eee79b83b4654275597857284..cca515b63d0b7ba70f737a320b8b65488d9169dc 100644
--- a/tests/pytorch/fused_attn/test_fused_attn.py
+++ b/tests/pytorch/fused_attn/test_fused_attn.py
@@ -91,10 +91,10 @@ class ModelConfig:
self.max_seqlen_q = max_seqlen_q
self.max_seqlen_kv = max_seqlen_kv
self.dropout_p = dropout_p
- self.attn_mask_type = attn_mask_type
- self.attn_bias_type = attn_bias_type
- self.alibi_type = alibi_type
- self.attn_type = "self" if (max_seqlen_q == max_seqlen_kv) else "cross"
+ self.attn_mask_type = attn_mask_type
+ self.attn_bias_type = attn_bias_type
+ self.alibi_type = alibi_type
+ self.attn_type = "self" if (max_seqlen_q == max_seqlen_kv) else "cross"
self.num_layers = num_layers
self.bias_shape = bias_shape
@@ -184,28 +184,29 @@ def _is_flash_attention_supported(config: ModelConfig) -> bool:
return False
return True
+
def _is_unfused_attention_supported(
config: ModelConfig,
qkv_format: str,
- ) -> bool:
+) -> bool:
"""Check if UnfusedDotProductAttention supports a model configuration"""
- if ("padding" in config.attn_mask_type):
+ if "padding" in config.attn_mask_type:
return False
- if ("causal" in config.attn_mask_type and config.attn_type == 'cross'):
+ if "causal" in config.attn_mask_type and config.attn_type == "cross":
return False
- if qkv_format == 'thd':
+ if qkv_format == "thd":
return False
return True
model_configs_base = {
# test: b, h, hg, d, sq, skv, p, mask, bias # attn , backend
- "base_1_0": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "no_mask", "no_bias"), # self , 0
- "base_1_1": ModelConfig(4, 16, 16, 64, 128, 256, 0.0, "no_mask", "no_bias"), # cross, 0
- "base_2_0": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "no_bias"), # self , 1
- "base_2_1": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "no_bias"), # cross, 1
- "base_3_0": ModelConfig(8, 16, 16, 128, 1, 2048, 0.0, "no_mask", "no_bias"), # inference
- "base_3_1": ModelConfig(8, 16, 16, 256, 1, 2048, 0.0, "no_mask", "no_bias"), # inference
+ "base_1_0": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "no_mask", "no_bias"), # self , 0
+ "base_1_1": ModelConfig(4, 16, 16, 64, 128, 256, 0.0, "no_mask", "no_bias"), # cross, 0
+ "base_2_0": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "no_bias"), # self , 1
+ "base_2_1": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "no_bias"), # cross, 1
+ "base_3_0": ModelConfig(8, 16, 16, 128, 1, 2048, 0.0, "no_mask", "no_bias"), # inference
+ "base_3_1": ModelConfig(8, 16, 16, 256, 1, 2048, 0.0, "no_mask", "no_bias"), # inference
}
@@ -221,13 +222,13 @@ def get_swa(seq_q, seq_kv, w=None):
if w is None:
w = torch.randint(0, seq_kv, [2], dtype=torch.int32, device="cuda")
m = torch.ones(seq_q, seq_kv, dtype=torch.bool, device="cuda")
- mu = torch.triu(m, diagonal=seq_kv-seq_q-w[0])
- ml = torch.tril(mu, diagonal=seq_kv-seq_q+w[1])
- ml = ~ ml
+ mu = torch.triu(m, diagonal=seq_kv - seq_q - w[0])
+ ml = torch.tril(mu, diagonal=seq_kv - seq_q + w[1])
+ ml = ~ml
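+    # True = masked; only the band skv - sq - w[0] <= j - i <= skv - sq + w[1] is attended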
return w, ml
-@pytest.mark.skipif(get_cudnn_version() < (8,9,1), reason="cuDNN 8.9.1+ is required.")
+@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("model_configs", [model_configs_base])
@pytest.mark.parametrize("model", model_configs_base.keys())
@@ -236,8 +237,9 @@ def get_swa(seq_q, seq_kv, w=None):
@pytest.mark.parametrize("qkv_layout", [None])
@pytest.mark.parametrize("swa", [False])
@pytest.mark.parametrize("pad_between_seqs", [False])
-def test_dot_product_attention(dtype, model_configs, model, ckpt_attn,
- workspace_opt, qkv_layout, swa, pad_between_seqs):
+def test_dot_product_attention(
+ dtype, model_configs, model, ckpt_attn, workspace_opt, qkv_layout, swa, pad_between_seqs
+):
"""Test DotProductAttention module"""
# Get configs
@@ -251,36 +253,43 @@ def test_dot_product_attention(dtype, model_configs, model, ckpt_attn,
else:
qkv_layout = "sbhd_sb2hd"
if "3" in qkv_layout and config.attn_type == "cross":
- pytest.skip(
- "No need to test this layout for cross attention"
- )
+ pytest.skip("No need to test this layout for cross attention")
# Skip if only unfused backend is supported
- qkv_format = ''.join([i for i in qkv_layout.split('_')[0] if i.isalpha()])
+ qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()])
unfused_attn_supported = _is_unfused_attention_supported(config, qkv_format)
if config.max_seqlen_q <= 512 and config.max_seqlen_kv <= 512:
os.environ["NVTE_FUSED_ATTN_BACKEND"] = "0"
fused_attn_supported, fused_attn_backend = _is_fused_attention_supported(
- config, dtype, qkv_layout=qkv_layout,
+ config,
+ dtype,
+ qkv_layout=qkv_layout,
)
if swa:
fused_attn_supported = False
flash_attn_supported = _is_flash_attention_supported(config)
if (len(fused_attn_backend) + flash_attn_supported + unfused_attn_supported) < 2:
pytest.skip("Less than two backends to compare.")
- if (qkv_format == 'thd' and 'padding' not in config.attn_mask_type):
+ if qkv_format == "thd" and "padding" not in config.attn_mask_type:
pytest.skip("THD layout requires padding/padding_causal mask type.")
# d=256 is supported by cuDNN 9.0+ for inference but not training
- is_training = (config.head_dim <= 128)
+ is_training = config.head_dim <= 128
# UnfusedDotProductAttention backend
if unfused_attn_supported:
if swa:
attn_mask_type = config.attn_mask_type
config.attn_mask_type = "arbitrary"
unfused_attn_fwd, unfused_attn_bwd = _run_dot_product_attention(
- dtype, config, "UnfusedDotProductAttention",
- ckpt_attn, qkv_layout, workspace_opt, swa, pad_between_seqs, is_training,
+ dtype,
+ config,
+ "UnfusedDotProductAttention",
+ ckpt_attn,
+ qkv_layout,
+ workspace_opt,
+ swa,
+ pad_between_seqs,
+ is_training,
)
if swa:
config.attn_mask_type = attn_mask_type
@@ -289,51 +298,79 @@ def test_dot_product_attention(dtype, model_configs, model, ckpt_attn,
if fused_attn_supported:
if len(fused_attn_backend) == 1:
fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention(
- dtype, config, "FusedAttention",
- ckpt_attn, qkv_layout, workspace_opt, swa, pad_between_seqs, is_training,
+ dtype,
+ config,
+ "FusedAttention",
+ ckpt_attn,
+ qkv_layout,
+ workspace_opt,
+ swa,
+ pad_between_seqs,
+ is_training,
)
if len(fused_attn_backend) == 2:
os.environ["NVTE_FUSED_ATTN_BACKEND"] = "0"
fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention(
- dtype, config, "FusedAttention",
- ckpt_attn, qkv_layout, workspace_opt, swa, pad_between_seqs, is_training,
+ dtype,
+ config,
+ "FusedAttention",
+ ckpt_attn,
+ qkv_layout,
+ workspace_opt,
+ swa,
+ pad_between_seqs,
+ is_training,
)
os.environ["NVTE_FUSED_ATTN_BACKEND"] = "1"
fused_attn_fwd_1, fused_attn_bwd_1 = _run_dot_product_attention(
- dtype, config, "FusedAttention",
- ckpt_attn, qkv_layout, workspace_opt, swa, pad_between_seqs, is_training,
+ dtype,
+ config,
+ "FusedAttention",
+ ckpt_attn,
+ qkv_layout,
+ workspace_opt,
+ swa,
+ pad_between_seqs,
+ is_training,
)
# FlashAttention backend
if flash_attn_supported:
flash_attn_fwd, flash_attn_bwd = _run_dot_product_attention(
- dtype, config, "FlashAttention",
- ckpt_attn, qkv_layout, workspace_opt, swa, pad_between_seqs, is_training,
+ dtype,
+ config,
+ "FlashAttention",
+ ckpt_attn,
+ qkv_layout,
+ workspace_opt,
+ swa,
+ pad_between_seqs,
+ is_training,
)
if unfused_attn_supported and fused_attn_supported:
logging.info("[test_dot_product_attention]: unfused attn vs fused attn")
torch.testing.assert_close(fused_attn_fwd, unfused_attn_fwd, **tols)
- for i,_ in enumerate(unfused_attn_bwd):
+ for i, _ in enumerate(unfused_attn_bwd):
torch.testing.assert_close(fused_attn_bwd[i], unfused_attn_bwd[i], **tols)
if unfused_attn_supported and flash_attn_supported:
logging.info("[test_dot_product_attention]: unfused attn vs flash attn")
torch.testing.assert_close(flash_attn_fwd, unfused_attn_fwd, **tols)
- for i,_ in enumerate(flash_attn_bwd):
+ for i, _ in enumerate(flash_attn_bwd):
torch.testing.assert_close(unfused_attn_bwd[i], flash_attn_bwd[i], **tols)
if fused_attn_supported and flash_attn_supported:
logging.info("[test_dot_product_attention]: fused attn vs flash attn")
torch.testing.assert_close(fused_attn_fwd, flash_attn_fwd, **tols)
- for i,_ in enumerate(flash_attn_bwd):
+ for i, _ in enumerate(flash_attn_bwd):
torch.testing.assert_close(fused_attn_bwd[i], flash_attn_bwd[i], **tols)
if fused_attn_supported and len(fused_attn_backend) == 2:
logging.info("[test_dot_product_attention]: fused attn backend 0 vs 1")
torch.testing.assert_close(fused_attn_fwd, fused_attn_fwd_1, **tols)
- for i,_ in enumerate(fused_attn_bwd):
+ for i, _ in enumerate(fused_attn_bwd):
torch.testing.assert_close(fused_attn_bwd[i], fused_attn_bwd_1[i], **tols)
-@pytest.mark.skipif(get_cudnn_version() < (8,9,1), reason="cuDNN 8.9.1+ is required.")
+@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("model_configs", [model_configs_base])
@pytest.mark.parametrize("model", ["base_1_1", "base_2_1"])
@@ -344,22 +381,22 @@ def test_dpa_checkpoint(dtype, model_configs, model):
model_configs_mask = {
# test: b, h, hg, d, sq, skv, p, mask, bias
- "mask_1_0": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "causal", "no_bias"),
- "mask_1_1": ModelConfig(4, 16, 16, 64, 128, 256, 0.0, "causal", "no_bias"),
- "mask_2_0": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "causal", "no_bias"),
- "mask_2_1": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "causal", "no_bias"),
- "mask_3_0": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "padding", "no_bias"),
- "mask_3_1": ModelConfig(4, 16, 16, 64, 128, 256, 0.0, "padding", "no_bias"),
- "mask_4_0": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "padding", "no_bias"),
- "mask_4_1": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "padding", "no_bias"),
- "mask_5_0": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "padding_causal", "no_bias"),
- "mask_5_1": ModelConfig(4, 16, 16, 64, 128, 256, 0.0, "padding_causal", "no_bias"),
+ "mask_1_0": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "causal", "no_bias"),
+ "mask_1_1": ModelConfig(4, 16, 16, 64, 128, 256, 0.0, "causal", "no_bias"),
+ "mask_2_0": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "causal", "no_bias"),
+ "mask_2_1": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "causal", "no_bias"),
+ "mask_3_0": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "padding", "no_bias"),
+ "mask_3_1": ModelConfig(4, 16, 16, 64, 128, 256, 0.0, "padding", "no_bias"),
+ "mask_4_0": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "padding", "no_bias"),
+ "mask_4_1": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "padding", "no_bias"),
+ "mask_5_0": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "padding_causal", "no_bias"),
+ "mask_5_1": ModelConfig(4, 16, 16, 64, 128, 256, 0.0, "padding_causal", "no_bias"),
"mask_6_0": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "padding_causal", "no_bias"),
"mask_6_1": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "no_bias"),
}
-@pytest.mark.skipif(get_cudnn_version() < (8,9,1), reason="cuDNN 8.9.1+ is required.")
+@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_mask])
@pytest.mark.parametrize("model", model_configs_mask.keys())
@@ -370,34 +407,48 @@ def test_dpa_mask(dtype, model_configs, model):
model_configs_bias = {
# test: b, h, hg, d, sq, skv, p, mask, bias
- "bias_1_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "no_mask", "post_scale_bias"),
- "bias_1_1": ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "no_mask", "post_scale_bias"),
- "bias_1_2": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "post_scale_bias"),
- "bias_1_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "post_scale_bias"),
- "bias_1_4": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "alibi"), # skipped
- "bias_1_5": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "alibi"), # skipped
- "bias_2_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "padding", "post_scale_bias"), # skipped
- "bias_2_1": ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "padding", "post_scale_bias"), # skipped
- "bias_2_2": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "padding", "post_scale_bias"), # skipped
- "bias_2_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding", "post_scale_bias"), # skipped
- "bias_2_4": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "padding", "alibi"), # skipped
- "bias_2_5": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding", "alibi"), # skipped
- "bias_3_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "causal", "post_scale_bias"),
- "bias_3_1": ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "causal", "post_scale_bias"),
- "bias_3_2": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "causal", "post_scale_bias"),
- "bias_3_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal", "post_scale_bias"), # skipped
- "bias_3_4": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "causal", "alibi"),
- "bias_3_5": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal", "alibi"), # skipped
- "bias_4_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "padding_causal", "post_scale_bias"), # skipped
- "bias_4_1": ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "padding_causal", "post_scale_bias"), # skipped
- "bias_4_2": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "padding_causal", "post_scale_bias"), # skipped
- "bias_4_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "post_scale_bias"), # skipped
- "bias_4_4": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "padding_causal", "alibi"), # skipped
- "bias_4_5": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "alibi"), # skipped
+ "bias_1_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "no_mask", "post_scale_bias"),
+ "bias_1_1": ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "no_mask", "post_scale_bias"),
+ "bias_1_2": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "post_scale_bias"),
+ "bias_1_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "post_scale_bias"),
+ "bias_1_4": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "alibi"), # skipped
+ "bias_1_5": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "alibi"), # skipped
+ "bias_2_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "padding", "post_scale_bias"), # skipped
+ "bias_2_1": ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "padding", "post_scale_bias"), # skipped
+ "bias_2_2": ModelConfig(
+ 4, 24, 24, 128, 2048, 2048, 0.0, "padding", "post_scale_bias"
+ ), # skipped
+ "bias_2_3": ModelConfig(
+ 2, 24, 24, 128, 2048, 4096, 0.0, "padding", "post_scale_bias"
+ ), # skipped
+ "bias_2_4": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "padding", "alibi"), # skipped
+ "bias_2_5": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding", "alibi"), # skipped
+ "bias_3_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "causal", "post_scale_bias"),
+ "bias_3_1": ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "causal", "post_scale_bias"),
+ "bias_3_2": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "causal", "post_scale_bias"),
+ "bias_3_3": ModelConfig(
+ 2, 24, 24, 128, 2048, 4096, 0.0, "causal", "post_scale_bias"
+ ), # skipped
+ "bias_3_4": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "causal", "alibi"),
+ "bias_3_5": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal", "alibi"), # skipped
+ "bias_4_0": ModelConfig(
+ 4, 16, 16, 64, 128, 128, 0.0, "padding_causal", "post_scale_bias"
+ ), # skipped
+ "bias_4_1": ModelConfig(
+ 2, 16, 16, 64, 128, 256, 0.0, "padding_causal", "post_scale_bias"
+ ), # skipped
+ "bias_4_2": ModelConfig(
+ 4, 24, 24, 128, 2048, 2048, 0.0, "padding_causal", "post_scale_bias"
+ ), # skipped
+ "bias_4_3": ModelConfig(
+ 2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "post_scale_bias"
+ ), # skipped
+ "bias_4_4": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "padding_causal", "alibi"), # skipped
+ "bias_4_5": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "alibi"), # skipped
}
-@pytest.mark.skipif(get_cudnn_version() < (8,9,1), reason="cuDNN 8.9.1+ is required.")
+@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_bias])
@pytest.mark.parametrize("model", model_configs_bias.keys())
@@ -408,23 +459,38 @@ def test_dpa_bias(dtype, model_configs, model):
model_configs_bias_shapes = {
# test: b, h, hg, d, sq, skv, p,
- "bias_1_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0,
- # mask, bias, bias_shape,
- "no_mask", "post_scale_bias", bias_shape='11ss'),
- "bias_1_1": ModelConfig(2, 16, 16, 64, 128, 128, 0.0,
- "no_mask", "post_scale_bias", bias_shape='1hss'),
- "bias_1_2": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0,
- "no_mask", "post_scale_bias", bias_shape='b1ss'),
- "bias_1_3": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0,
- "no_mask", "post_scale_bias", bias_shape='bhss'),
- "bias_1_4": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0,
- "causal", "alibi", bias_shape='1hss', alibi_type='custom'),
- "bias_1_5": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0,
- "causal", "alibi", bias_shape='bhss', alibi_type='custom'),
+ "bias_1_0": ModelConfig(
+ 4,
+ 16,
+ 16,
+ 64,
+ 128,
+ 128,
+ 0.0,
+ # mask, bias, bias_shape,
+ "no_mask",
+ "post_scale_bias",
+ bias_shape="11ss",
+ ),
+ "bias_1_1": ModelConfig(
+ 2, 16, 16, 64, 128, 128, 0.0, "no_mask", "post_scale_bias", bias_shape="1hss"
+ ),
+ "bias_1_2": ModelConfig(
+ 4, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "post_scale_bias", bias_shape="b1ss"
+ ),
+ "bias_1_3": ModelConfig(
+ 2, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "post_scale_bias", bias_shape="bhss"
+ ),
+ "bias_1_4": ModelConfig(
+ 4, 24, 24, 128, 2048, 2048, 0.0, "causal", "alibi", bias_shape="1hss", alibi_type="custom"
+ ),
+ "bias_1_5": ModelConfig(
+ 2, 24, 24, 128, 2048, 2048, 0.0, "causal", "alibi", bias_shape="bhss", alibi_type="custom"
+ ),
}
-@pytest.mark.skipif(get_cudnn_version() < (8,9,1), reason="cuDNN 8.9.1+ is required.")
+@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_bias_shapes])
@pytest.mark.parametrize("model", model_configs_bias_shapes.keys())
@@ -435,10 +501,10 @@ def test_dpa_bias_shapes(dtype, model_configs, model):
model_configs_swa = {
# test: b, h, hg, d, sq, skv, p, mask, bias
- "swa_1_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "no_mask", "no_bias"),
- "swa_1_1": ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "no_mask", "no_bias"),
- "swa_1_2": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "no_bias"),
- "swa_1_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "no_bias"),
+ "swa_1_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "no_mask", "no_bias"),
+ "swa_1_1": ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "no_mask", "no_bias"),
+ "swa_1_2": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "no_bias"),
+ "swa_1_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "no_bias"),
}
@@ -453,10 +519,14 @@ def test_dpa_sliding_window(dtype, model_configs, model):
model_configs_alibi_slopes = {
# test: b, h, hg, d, sq, skv, p, mask, bias, alibi_type
- "alibi_1_0": ModelConfig(2, 16, 16, 64, 128, 128, 0.0, "causal", "alibi", alibi_type="vanilla"),
- "alibi_1_1": ModelConfig(1, 16, 16, 64, 128, 256, 0.0, "causal", "alibi", alibi_type="vanilla"),
- "alibi_2_0": ModelConfig(2, 24, 24, 128, 1024, 1024, 0.0, "causal", "alibi", alibi_type= "custom"),
- "alibi_2_1": ModelConfig(1, 24, 24, 128, 1024, 2048, 0.0, "causal", "alibi", alibi_type= "custom"),
+ "alibi_1_0": ModelConfig(2, 16, 16, 64, 128, 128, 0.0, "causal", "alibi", alibi_type="vanilla"),
+ "alibi_1_1": ModelConfig(1, 16, 16, 64, 128, 256, 0.0, "causal", "alibi", alibi_type="vanilla"),
+ "alibi_2_0": ModelConfig(
+ 2, 24, 24, 128, 1024, 1024, 0.0, "causal", "alibi", alibi_type="custom"
+ ),
+ "alibi_2_1": ModelConfig(
+ 1, 24, 24, 128, 1024, 2048, 0.0, "causal", "alibi", alibi_type="custom"
+ ),
}
@@ -470,27 +540,35 @@ def test_dpa_alibi_slopes(dtype, model_configs, model):
qkv_layouts = [
- 'sb3hd', 'sbh3d', 'sbhd_sb2hd', 'sbhd_sbh2d', 'sbhd_sbhd_sbhd',
- 'bs3hd', 'bsh3d', 'bshd_bs2hd', 'bshd_bsh2d', 'bshd_bshd_bshd',
- ]
+ "sb3hd",
+ "sbh3d",
+ "sbhd_sb2hd",
+ "sbhd_sbh2d",
+ "sbhd_sbhd_sbhd",
+ "bs3hd",
+ "bsh3d",
+ "bshd_bs2hd",
+ "bshd_bsh2d",
+ "bshd_bshd_bshd",
+]
model_configs_layout = {
# test: b, h, hg, d, sq, skv, p, mask, bias
- "layout_0_0": ModelConfig(2, 16, 16, 64, 128, 128, 0.0, "no_mask", "no_bias"),
- "layout_0_1": ModelConfig(2, 16, 16, 64, 128, 128, 0.0, "causal", "post_scale_bias"),
- "layout_0_2": ModelConfig(1, 16, 16, 64, 128, 256, 0.0, "padding", "no_bias"),
- "layout_0_3": ModelConfig(1, 16, 16, 64, 128, 256, 0.0, "padding_causal", "post_scale_bias"),
- "layout_1_0": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "no_bias"),
- "layout_1_1": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "causal", "post_scale_bias"),
- "layout_1_2": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "padding", "no_bias"),
+ "layout_0_0": ModelConfig(2, 16, 16, 64, 128, 128, 0.0, "no_mask", "no_bias"),
+ "layout_0_1": ModelConfig(2, 16, 16, 64, 128, 128, 0.0, "causal", "post_scale_bias"),
+ "layout_0_2": ModelConfig(1, 16, 16, 64, 128, 256, 0.0, "padding", "no_bias"),
+ "layout_0_3": ModelConfig(1, 16, 16, 64, 128, 256, 0.0, "padding_causal", "post_scale_bias"),
+ "layout_1_0": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "no_bias"),
+ "layout_1_1": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "causal", "post_scale_bias"),
+ "layout_1_2": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "padding", "no_bias"),
"layout_1_3": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "post_scale_bias"),
- "layout_2_0": ModelConfig(2, 16, 16, 256, 1, 2048, 0.0, "no_mask", "no_bias"),
- "layout_2_1": ModelConfig(2, 24, 24, 256, 2048, 2048, 0.0, "causal", "post_scale_bias"),
+ "layout_2_0": ModelConfig(2, 16, 16, 256, 1, 2048, 0.0, "no_mask", "no_bias"),
+ "layout_2_1": ModelConfig(2, 24, 24, 256, 2048, 2048, 0.0, "causal", "post_scale_bias"),
}
-@pytest.mark.skipif(get_cudnn_version() < (8,9,5), reason="cuDNN 8.9.5+ is required.")
+@pytest.mark.skipif(get_cudnn_version() < (8, 9, 5), reason="cuDNN 8.9.5+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_layout])
@pytest.mark.parametrize("model", model_configs_layout.keys())
@@ -500,26 +578,28 @@ def test_dpa_qkv_layout(dtype, model_configs, model, qkv_layout):
test_dot_product_attention(dtype, model_configs, model, False, True, qkv_layout, False, False)
-qkv_layouts_thd = ['t3hd', 'th3d', 'thd_t2hd', 'thd_th2d', 'thd_thd_thd']
+qkv_layouts_thd = ["t3hd", "th3d", "thd_t2hd", "thd_th2d", "thd_thd_thd"]
model_configs_layout_thd = {
# test: b, h, hg, d, sq, skv, p, mask, bias
- "layout_0_1": ModelConfig(1, 16, 16, 64, 128, 128, 0.0, "padding", "no_bias"),
- "layout_0_2": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "padding", "no_bias"),
- "layout_0_3": ModelConfig(1, 16, 16, 64, 128, 128, 0.0, "padding_causal", "no_bias"),
- "layout_0_4": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "padding_causal", "no_bias"),
- "layout_1_1": ModelConfig(1, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias"),
- "layout_1_2": ModelConfig(8, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias"),
- "layout_1_3": ModelConfig(1, 16, 16, 64, 2048, 2048, 0.0, "padding_causal", "no_bias"),
- "layout_1_4": ModelConfig(8, 16, 16, 64, 2048, 2048, 0.0, "padding_causal", "no_bias"),
- "layout_2_1": ModelConfig(1, 16, 16, 128, 128, 128, 0.0, "padding", "no_bias"),
- "layout_2_2": ModelConfig(1, 16, 16, 64, 128, 256, 0.0, "padding", "no_bias"),
- "layout_2_3": ModelConfig(1, 16, 16, 128, 2048, 2048, 0.0, "padding_causal", "no_bias"),
- "layout_2_4": ModelConfig(8, 16, 16, 64, 2048, 4096, 0.0, "padding_causal", "no_bias"),
+ "layout_0_1": ModelConfig(1, 16, 16, 64, 128, 128, 0.0, "padding", "no_bias"),
+ "layout_0_2": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "padding", "no_bias"),
+ "layout_0_3": ModelConfig(1, 16, 16, 64, 128, 128, 0.0, "padding_causal", "no_bias"),
+ "layout_0_4": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "padding_causal", "no_bias"),
+ "layout_1_1": ModelConfig(1, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias"),
+ "layout_1_2": ModelConfig(8, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias"),
+ "layout_1_3": ModelConfig(1, 16, 16, 64, 2048, 2048, 0.0, "padding_causal", "no_bias"),
+ "layout_1_4": ModelConfig(8, 16, 16, 64, 2048, 2048, 0.0, "padding_causal", "no_bias"),
+ "layout_2_1": ModelConfig(1, 16, 16, 128, 128, 128, 0.0, "padding", "no_bias"),
+ "layout_2_2": ModelConfig(1, 16, 16, 64, 128, 256, 0.0, "padding", "no_bias"),
+ "layout_2_3": ModelConfig(1, 16, 16, 128, 2048, 2048, 0.0, "padding_causal", "no_bias"),
+ "layout_2_4": ModelConfig(8, 16, 16, 64, 2048, 4096, 0.0, "padding_causal", "no_bias"),
}
-@pytest.mark.skipif(get_cudnn_version() < (9,0,0), reason="cuDNN 9.0.0+ is required.")
-@pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="THD is only supported on Hopper+.")
+@pytest.mark.skipif(get_cudnn_version() < (9, 0, 0), reason="cuDNN 9.0.0+ is required.")
+@pytest.mark.skipif(
+ get_device_compute_capability() < (9, 0), reason="THD is only supported on Hopper+."
+)
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_layout_thd])
@pytest.mark.parametrize("model", model_configs_layout_thd.keys())
@@ -527,24 +607,26 @@ model_configs_layout_thd = {
def test_dpa_qkv_layout_thd(dtype, model_configs, model, qkv_layout):
"""Test DotProductAttention module with different QKV layouts"""
pad_between_seqs = False
- test_dot_product_attention(dtype, model_configs, model, False, True,
- qkv_layout, False, pad_between_seqs)
+ test_dot_product_attention(
+ dtype, model_configs, model, False, True, qkv_layout, False, pad_between_seqs
+ )
pad_between_seqs = True
- test_dot_product_attention(dtype, model_configs, model, False, True,
- qkv_layout, False, pad_between_seqs)
+ test_dot_product_attention(
+ dtype, model_configs, model, False, True, qkv_layout, False, pad_between_seqs
+ )
def _run_dot_product_attention(
- dtype: torch.dtype,
- config: ModelConfig,
- backend: str,
- ckpt_attn: bool,
- qkv_layout: str,
- workspace_opt: bool,
- swa: bool,
- pad_between_seqs: bool,
- is_training: bool,
- ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
+ dtype: torch.dtype,
+ config: ModelConfig,
+ backend: str,
+ ckpt_attn: bool,
+ qkv_layout: str,
+ workspace_opt: bool,
+ swa: bool,
+ pad_between_seqs: bool,
+ is_training: bool,
+) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
"""Run DotProductAttention module with one forward pass and one backward pass"""
     # Set RNG and environment variables
@@ -558,22 +640,27 @@ def _run_dot_product_attention(
os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] = "1" if workspace_opt else "0"
# Create seqlens
- qkv_format = ''.join([i for i in qkv_layout.split('_')[0] if i.isalpha()])
- if "padding" in config.attn_mask_type or qkv_format == 'thd':
- if config.attn_type == 'self':
- seqlens_q = torch.randint(1, config.max_seqlen_q, [config.batch_size],
- dtype=torch.int32, device="cuda")
+ qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()])
+ if "padding" in config.attn_mask_type or qkv_format == "thd":
+ if config.attn_type == "self":
+ seqlens_q = torch.randint(
+ 1, config.max_seqlen_q, [config.batch_size], dtype=torch.int32, device="cuda"
+ )
seqlens_kv = seqlens_q
- if config.attn_type == 'cross':
- seqlens_q = torch.randint(1, config.max_seqlen_q, [config.batch_size],
- dtype=torch.int32, device="cuda")
- seqlens_kv = torch.randint(1, config.max_seqlen_kv, [config.batch_size],
- dtype=torch.int32, device="cuda")
+ if config.attn_type == "cross":
+ seqlens_q = torch.randint(
+ 1, config.max_seqlen_q, [config.batch_size], dtype=torch.int32, device="cuda"
+ )
+ seqlens_kv = torch.randint(
+ 1, config.max_seqlen_kv, [config.batch_size], dtype=torch.int32, device="cuda"
+ )
else:
- seqlens_q = torch.full([config.batch_size], config.max_seqlen_q,
- dtype=torch.int32, device="cuda")
- seqlens_kv = torch.full([config.batch_size], config.max_seqlen_kv,
- dtype=torch.int32, device="cuda")
+ seqlens_q = torch.full(
+ [config.batch_size], config.max_seqlen_q, dtype=torch.int32, device="cuda"
+ )
+ seqlens_kv = torch.full(
+ [config.batch_size], config.max_seqlen_kv, dtype=torch.int32, device="cuda"
+ )
cu_seqlens_q = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda")
cu_seqlens_kv = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda")
cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0)
@@ -586,7 +673,7 @@ def _run_dot_product_attention(
pad_len = [0] * config.batch_size
if pad_between_seqs:
max_pad_len = 3
- pad_len = torch.randint(0, max_pad_len+1, [config.batch_size], device="cuda") #3
+        pad_len = torch.randint(0, max_pad_len + 1, [config.batch_size], device="cuda")  # 0..max_pad_len
seqlens_q_after_pad = seqlens_q + pad_len
seqlens_kv_after_pad = seqlens_kv + pad_len
cu_seqlens_q_after_pad[1:] = torch.cumsum(seqlens_q_after_pad, dim=0)
@@ -595,25 +682,58 @@ def _run_dot_product_attention(
# Create attention mask if padding
attention_mask = None
if "padding" in config.attn_mask_type:
- if config.attn_type == 'self':
+ if config.attn_type == "self":
attention_mask_q = torch.Tensor([]).to(dtype=torch.bool)
for i in range(config.batch_size):
- attention_mask_q = torch.cat([attention_mask_q,
- torch.Tensor([False]*seqlens_q[i] + [True]*(config.max_seqlen_q-seqlens_q[i]))
- .to(dtype=torch.bool).unsqueeze(0).unsqueeze(0).unsqueeze(0)], dim=0)
+ attention_mask_q = torch.cat(
+ [
+ attention_mask_q,
+ torch.Tensor(
+ [False] * seqlens_q[i] + [True] * (config.max_seqlen_q - seqlens_q[i])
+ )
+ .to(dtype=torch.bool)
+ .unsqueeze(0)
+ .unsqueeze(0)
+ .unsqueeze(0),
+ ],
+ dim=0,
+ )
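+            # final mask shape [b, 1, 1, sq]; True marks padded positions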
attention_mask = attention_mask_q.to(device="cuda")
- if config.attn_type == 'cross':
+ if config.attn_type == "cross":
attention_mask_q = torch.Tensor([]).to(dtype=torch.bool)
attention_mask_kv = torch.Tensor([]).to(dtype=torch.bool)
for i in range(config.batch_size):
- attention_mask_q = torch.cat([attention_mask_q,
- torch.Tensor([False]*seqlens_q[i] + [True]*(config.max_seqlen_q-seqlens_q[i]))
- .to(dtype=torch.bool).unsqueeze(0).unsqueeze(0).unsqueeze(0)], dim=0)
- attention_mask_kv = torch.cat([attention_mask_kv, torch.Tensor(
- [False]*seqlens_kv[i] + [True]*(config.max_seqlen_kv-seqlens_kv[i]))
- .to(dtype=torch.bool).unsqueeze(0).unsqueeze(0).unsqueeze(0)], dim=0)
+ attention_mask_q = torch.cat(
+ [
+ attention_mask_q,
+ torch.Tensor(
+ [False] * seqlens_q[i] + [True] * (config.max_seqlen_q - seqlens_q[i])
+ )
+ .to(dtype=torch.bool)
+ .unsqueeze(0)
+ .unsqueeze(0)
+ .unsqueeze(0),
+ ],
+ dim=0,
+ )
+ attention_mask_kv = torch.cat(
+ [
+ attention_mask_kv,
+ torch.Tensor(
+ [False] * seqlens_kv[i]
+ + [True] * (config.max_seqlen_kv - seqlens_kv[i])
+ )
+ .to(dtype=torch.bool)
+ .unsqueeze(0)
+ .unsqueeze(0)
+ .unsqueeze(0),
+ ],
+ dim=0,
+ )
attention_mask = (
- attention_mask_q.to(device="cuda"), attention_mask_kv.to(device="cuda"))
+ attention_mask_q.to(device="cuda"),
+ attention_mask_kv.to(device="cuda"),
+ )
window_size = None
if swa:
window_size, attention_mask = get_swa(config.max_seqlen_q, config.max_seqlen_kv)
@@ -623,62 +743,84 @@ def _run_dot_product_attention(
alibi_slopes = None
if config.attn_bias_type == "alibi" and config.alibi_type == "custom":
if config.bias_shape == "1hss":
- alibi_slopes = torch.randn(
- config.num_heads).abs().to(dtype=torch.float32, device="cuda")
+ alibi_slopes = (
+ torch.randn(config.num_heads).abs().to(dtype=torch.float32, device="cuda")
+ )
if config.bias_shape == "bhss":
- alibi_slopes = torch.randn(
- config.batch_size, config.num_heads).abs().to(dtype=torch.float32, device="cuda")
+ alibi_slopes = (
+ torch.randn(config.batch_size, config.num_heads)
+ .abs()
+ .to(dtype=torch.float32, device="cuda")
+ )
# Create input tensors
dim_to_num = {
- 'b' : config.batch_size,
- 'sq' : config.max_seqlen_q,
- 'skv': config.max_seqlen_kv,
- 'h' : config.num_heads,
- 'hg' : config.num_gqa_groups,
- 'd' : config.head_dim,
- 't' : cu_seqlens_q_after_pad[-1],
- 'tg' : cu_seqlens_kv_after_pad[-1],
- '3' : 3,
- '2' : 2,
- '1' : 1,
- }
+ "b": config.batch_size,
+ "sq": config.max_seqlen_q,
+ "skv": config.max_seqlen_kv,
+ "h": config.num_heads,
+ "hg": config.num_gqa_groups,
+ "d": config.head_dim,
+ "t": cu_seqlens_q_after_pad[-1],
+ "tg": cu_seqlens_kv_after_pad[-1],
+ "3": 3,
+ "2": 2,
+ "1": 1,
+ }
inp = []
inp_orig = []
- for i,layout in enumerate(qkv_layout.split('_')):
- layout = '_'.join(layout)
+ for i, layout in enumerate(qkv_layout.split("_")):
+ layout = "_".join(layout)
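+        # e.g. "sb3hd" -> "s_b_3_h_d", so every axis letter can be mapped through dim_to_num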
if i == 0:
- layout = layout.replace('s', 'sq')
+ layout = layout.replace("s", "sq")
else:
- layout = layout.replace('s', 'skv')
- layout = layout.replace('h', 'hg')
- layout = layout.replace('t', 'tg')
- tensor_shape = [dim_to_num[j] for j in layout.split('_')]
+ layout = layout.replace("s", "skv")
+ layout = layout.replace("h", "hg")
+ layout = layout.replace("t", "tg")
+ tensor_shape = [dim_to_num[j] for j in layout.split("_")]
tensor = 0.1 * torch.randn(tensor_shape, dtype=dtype, device="cuda")
tensor_orig = tensor
- if qkv_format == 'thd' and pad_between_seqs:
- tensor_orig = torch.Tensor([]).to(device="cuda",dtype=dtype)
- if layout in ['t_h_d', 't_3_h_d', 't_h_3_d']:
- for i in range(1, config.batch_size+1):
- valid_range = (cu_seqlens_q_after_pad[i-1], cu_seqlens_q_after_pad[i] - pad_len[i-1])
- pad_range = (cu_seqlens_q_after_pad[i] - pad_len[i-1], cu_seqlens_q_after_pad[i])
- tensor[pad_range[0]:pad_range[1]] = 0.0
- tensor_orig = torch.cat([tensor_orig, tensor[valid_range[0]:valid_range[1]]], dim=0)
- if layout in ['tg_hg_d', 'tg_2_hg_d', 'tg_hg_2_d']:
- for i in range(1, config.batch_size+1):
- valid_range = (cu_seqlens_kv_after_pad[i-1], cu_seqlens_kv_after_pad[i] - pad_len[i-1])
- pad_range = (cu_seqlens_kv_after_pad[i] - pad_len[i-1], cu_seqlens_kv_after_pad[i])
- tensor[pad_range[0]:pad_range[1]] = 0.0
- tensor_orig = torch.cat([tensor_orig, tensor[valid_range[0]:valid_range[1]]], dim=0)
+ if qkv_format == "thd" and pad_between_seqs:
+ tensor_orig = torch.Tensor([]).to(device="cuda", dtype=dtype)
+ if layout in ["t_h_d", "t_3_h_d", "t_h_3_d"]:
+ for i in range(1, config.batch_size + 1):
+ valid_range = (
+ cu_seqlens_q_after_pad[i - 1],
+ cu_seqlens_q_after_pad[i] - pad_len[i - 1],
+ )
+ pad_range = (
+ cu_seqlens_q_after_pad[i] - pad_len[i - 1],
+ cu_seqlens_q_after_pad[i],
+ )
+ tensor[pad_range[0] : pad_range[1]] = 0.0
+ tensor_orig = torch.cat(
+ [tensor_orig, tensor[valid_range[0] : valid_range[1]]], dim=0
+ )
+ if layout in ["tg_hg_d", "tg_2_hg_d", "tg_hg_2_d"]:
+ for i in range(1, config.batch_size + 1):
+ valid_range = (
+ cu_seqlens_kv_after_pad[i - 1],
+ cu_seqlens_kv_after_pad[i] - pad_len[i - 1],
+ )
+ pad_range = (
+ cu_seqlens_kv_after_pad[i] - pad_len[i - 1],
+ cu_seqlens_kv_after_pad[i],
+ )
+ tensor[pad_range[0] : pad_range[1]] = 0.0
+ tensor_orig = torch.cat(
+ [tensor_orig, tensor[valid_range[0] : valid_range[1]]], dim=0
+ )
tensor_count = 1
split_dim = 0
- for dim, l in enumerate(layout.split('_')):
+ for dim, l in enumerate(layout.split("_")):
if l.isdigit():
tensor_count = int(l)
split_dim = dim
break
tensors = torch.split(tensor, 1, dim=split_dim) if split_dim != 0 else [tensor]
- tensors_orig = torch.split(tensor_orig, 1, dim=split_dim) if split_dim != 0 else [tensor_orig]
+ tensors_orig = (
+ torch.split(tensor_orig, 1, dim=split_dim) if split_dim != 0 else [tensor_orig]
+ )
for j in range(tensor_count):
if split_dim != 0:
inp.append(tensors[j].squeeze(split_dim))
@@ -692,73 +834,77 @@ def _run_dot_product_attention(
# Create ragged offsets for q/k/v
seq_offsets_q, seq_offsets_k, seq_offsets_v, seq_offsets_o = None, None, None, None
- qkv_group = ''.join([x for x in qkv_layout if x not in 'bst'])
- if qkv_format == 'thd':
+ qkv_group = "".join([x for x in qkv_layout if x not in "bst"])
+ if qkv_format == "thd":
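+        # offsets are in elements: heads * head_dim * cumulative (padded) sequence lengths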
seq_offsets_o = config.num_heads * config.head_dim * cu_seqlens_q_after_pad
- if qkv_group == 'hd_hd_hd':
+ if qkv_group == "hd_hd_hd":
seq_offsets_q = config.num_heads * config.head_dim * cu_seqlens_q_after_pad
seq_offsets_k = config.num_gqa_groups * config.head_dim * cu_seqlens_kv_after_pad
seq_offsets_v = config.num_gqa_groups * config.head_dim * cu_seqlens_kv_after_pad
- if qkv_group in ['3hd', 'h3d']:
+ if qkv_group in ["3hd", "h3d"]:
seq_offsets_q = config.num_heads * config.head_dim * 3 * cu_seqlens_q_after_pad
seq_offsets_k = config.num_heads * config.head_dim * 3 * cu_seqlens_q_after_pad
seq_offsets_v = config.num_heads * config.head_dim * 3 * cu_seqlens_q_after_pad
- if qkv_group in ['hd_2hd', 'hd_h2d']:
+ if qkv_group in ["hd_2hd", "hd_h2d"]:
seq_offsets_q = config.num_heads * config.head_dim * cu_seqlens_q_after_pad
seq_offsets_k = config.num_gqa_groups * config.head_dim * 2 * cu_seqlens_kv_after_pad
seq_offsets_v = config.num_gqa_groups * config.head_dim * 2 * cu_seqlens_kv_after_pad
# Create output gradient
- qkv_format_kv = '_'.join(qkv_format)
- qkv_format_kv = qkv_format_kv.replace('s', 'sq')
- out_grad_shape = [dim_to_num[i] for i in qkv_format_kv.split('_')]
+ qkv_format_kv = "_".join(qkv_format)
+ qkv_format_kv = qkv_format_kv.replace("s", "sq")
+ out_grad_shape = [dim_to_num[i] for i in qkv_format_kv.split("_")]
out_grad_shape_new = [*out_grad_shape[:-2], out_grad_shape[-2] * out_grad_shape[-1]]
out_grad = 0.001 * torch.randint(0, 200, out_grad_shape_new, dtype=dtype, device="cuda")
out_grad_orig = out_grad
- if qkv_format == 'thd' and pad_between_seqs:
- out_grad_orig = torch.Tensor([]).to(device="cuda",dtype=dtype)
- if qkv_format_kv == 't_h_d':
- for i in range(1, config.batch_size+1):
- valid_range = (cu_seqlens_q_after_pad[i-1], cu_seqlens_q_after_pad[i] - pad_len[i-1])
- pad_range = (cu_seqlens_q_after_pad[i] - pad_len[i-1], cu_seqlens_q_after_pad[i])
- out_grad[pad_range[0]:pad_range[1]] = 0.0
- out_grad_orig = torch.cat([out_grad_orig, out_grad[valid_range[0]:valid_range[1]]], dim=0)
+ if qkv_format == "thd" and pad_between_seqs:
+ out_grad_orig = torch.Tensor([]).to(device="cuda", dtype=dtype)
+ if qkv_format_kv == "t_h_d":
+ for i in range(1, config.batch_size + 1):
+ valid_range = (
+ cu_seqlens_q_after_pad[i - 1],
+ cu_seqlens_q_after_pad[i] - pad_len[i - 1],
+ )
+ pad_range = (cu_seqlens_q_after_pad[i] - pad_len[i - 1], cu_seqlens_q_after_pad[i])
+ out_grad[pad_range[0] : pad_range[1]] = 0.0
+ out_grad_orig = torch.cat(
+ [out_grad_orig, out_grad[valid_range[0] : valid_range[1]]], dim=0
+ )
# Create bias
- if config.attn_bias_type in ['no_bias', 'alibi']:
+ if config.attn_bias_type in ["no_bias", "alibi"]:
bias = None
- if config.attn_bias_type == 'post_scale_bias':
- shape = '_'.join(config.bias_shape)
- shape = shape.replace('_s_s', '_sq_skv')
- tensor_shape = [dim_to_num[j] for j in shape.split('_')]
+ if config.attn_bias_type == "post_scale_bias":
+ shape = "_".join(config.bias_shape)
+ shape = shape.replace("_s_s", "_sq_skv")
+ tensor_shape = [dim_to_num[j] for j in shape.split("_")]
bias = torch.randn(tensor_shape, dtype=dtype, device="cuda")
- if config.bias_shape != '1hss':
+ if config.bias_shape != "1hss":
bias.requires_grad = False
# Create RNG
_DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
_DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed)
+
def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
"""Get cuda rng tracker."""
return _DUMMY_CUDA_RNG_STATE_TRACKER
# Set up model
- block = (
- DotProductAttention(
- config.num_heads,
- config.head_dim,
- num_gqa_groups=config.num_gqa_groups,
- attention_dropout=config.dropout_p,
- qkv_format=qkv_format,
- attn_mask_type=config.attn_mask_type,
- sequence_parallel=False,
- tp_size=1,
- get_rng_state_tracker=get_dummy_cuda_rng_tracker,
- tp_group=None,
- layer_number=1,
- attention_type=config.attn_type,
- ).to(dtype=dtype, device="cuda")
- )
+ block = DotProductAttention(
+ config.num_heads,
+ config.head_dim,
+ num_gqa_groups=config.num_gqa_groups,
+ attention_dropout=config.dropout_p,
+ qkv_format=qkv_format,
+ attn_mask_type=config.attn_mask_type,
+ sequence_parallel=False,
+ tp_size=1,
+ get_rng_state_tracker=get_dummy_cuda_rng_tracker,
+ tp_group=None,
+ layer_number=1,
+ attention_type=config.attn_type,
+ ).to(dtype=dtype, device="cuda")
# Run a forward and backward pass
if backend in ["FlashAttention", "UnfusedDotProductAttention"]:
@@ -771,24 +917,28 @@ def _run_dot_product_attention(
k = inp[1]
v = inp[2]
d_out = out_grad
- out = block(q, k, v,
- window_size=window_size,
- attention_mask=attention_mask,
- qkv_format=qkv_format,
- max_seqlen_q=config.max_seqlen_q,
- max_seqlen_kv=config.max_seqlen_kv,
- cu_seqlens_q=cu_seqlens_q,
- cu_seqlens_kv=cu_seqlens_kv,
- seq_offsets_q=seq_offsets_q,
- seq_offsets_k=seq_offsets_k,
- seq_offsets_v=seq_offsets_v,
- seq_offsets_o=seq_offsets_o,
- attn_mask_type=config.attn_mask_type,
- checkpoint_core_attention=ckpt_attn,
- core_attention_bias_type=config.attn_bias_type,
- core_attention_bias=bias,
- alibi_slopes=alibi_slopes,
- fast_zero_fill=True)
+ out = block(
+ q,
+ k,
+ v,
+ window_size=window_size,
+ attention_mask=attention_mask,
+ qkv_format=qkv_format,
+ max_seqlen_q=config.max_seqlen_q,
+ max_seqlen_kv=config.max_seqlen_kv,
+ cu_seqlens_q=cu_seqlens_q,
+ cu_seqlens_kv=cu_seqlens_kv,
+ seq_offsets_q=seq_offsets_q,
+ seq_offsets_k=seq_offsets_k,
+ seq_offsets_v=seq_offsets_v,
+ seq_offsets_o=seq_offsets_o,
+ attn_mask_type=config.attn_mask_type,
+ checkpoint_core_attention=ckpt_attn,
+ core_attention_bias_type=config.attn_bias_type,
+ core_attention_bias=bias,
+ alibi_slopes=alibi_slopes,
+ fast_zero_fill=True,
+ )
if is_training:
out.backward(d_out)
@@ -798,18 +948,30 @@ def _run_dot_product_attention(
else:
return out, (None, None, None)
if backend == "FusedAttention":
- if qkv_format == 'thd' and pad_between_seqs:
- out_orig = torch.Tensor([]).to(device="cuda",dtype=dtype)
- q_grad_orig = torch.Tensor([]).to(device="cuda",dtype=dtype)
- k_grad_orig = torch.Tensor([]).to(device="cuda",dtype=dtype)
- v_grad_orig = torch.Tensor([]).to(device="cuda",dtype=dtype)
- for i in range(1, config.batch_size+1):
- valid_range_q = (cu_seqlens_q_after_pad[i-1], cu_seqlens_q_after_pad[i] - pad_len[i-1])
- valid_range_kv = (cu_seqlens_kv_after_pad[i-1], cu_seqlens_kv_after_pad[i] - pad_len[i-1])
- out_orig = torch.cat([out_orig, out[valid_range_q[0]:valid_range_q[1]]], dim=0)
- q_grad_orig = torch.cat([q_grad_orig, q.grad[valid_range_q[0]:valid_range_q[1]]], dim=0)
- k_grad_orig = torch.cat([k_grad_orig, k.grad[valid_range_kv[0]:valid_range_kv[1]]], dim=0)
- v_grad_orig = torch.cat([v_grad_orig, v.grad[valid_range_kv[0]:valid_range_kv[1]]], dim=0)
+ if qkv_format == "thd" and pad_between_seqs:
+ out_orig = torch.Tensor([]).to(device="cuda", dtype=dtype)
+ q_grad_orig = torch.Tensor([]).to(device="cuda", dtype=dtype)
+ k_grad_orig = torch.Tensor([]).to(device="cuda", dtype=dtype)
+ v_grad_orig = torch.Tensor([]).to(device="cuda", dtype=dtype)
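+            # gather only the valid (unpadded) token ranges from each sequence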
+ for i in range(1, config.batch_size + 1):
+ valid_range_q = (
+ cu_seqlens_q_after_pad[i - 1],
+ cu_seqlens_q_after_pad[i] - pad_len[i - 1],
+ )
+ valid_range_kv = (
+ cu_seqlens_kv_after_pad[i - 1],
+ cu_seqlens_kv_after_pad[i] - pad_len[i - 1],
+ )
+ out_orig = torch.cat([out_orig, out[valid_range_q[0] : valid_range_q[1]]], dim=0)
+ q_grad_orig = torch.cat(
+ [q_grad_orig, q.grad[valid_range_q[0] : valid_range_q[1]]], dim=0
+ )
+ k_grad_orig = torch.cat(
+ [k_grad_orig, k.grad[valid_range_kv[0] : valid_range_kv[1]]], dim=0
+ )
+ v_grad_orig = torch.cat(
+ [v_grad_orig, v.grad[valid_range_kv[0] : valid_range_kv[1]]], dim=0
+ )
if is_training:
return out_orig, (q_grad_orig, k_grad_orig, v_grad_orig)
else:
@@ -823,18 +985,18 @@ def _run_dot_product_attention(
model_configs_te_layer = {
# test: b, h, hg, d, sq, skv, p, mask, bias
- "te_1_0": ModelConfig(2, 16, 16, 64, 128, 128, 0.0, "no_mask", "post_scale_bias"),
- "te_1_1": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "causal", "post_scale_bias"),
- "te_1_2": ModelConfig(2, 16, 16, 64, 128, 128, 0.0, "padding", "post_scale_bias"),
- "te_2_0": ModelConfig(1, 16, 16, 64, 2048, 2048, 0.0, "causal", "no_bias"),
- "te_2_1": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "no_mask", "no_bias"),
- "te_2_2": ModelConfig(1, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias"),
- "te_3_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "causal", "alibi"),
- "te_3_1": ModelConfig(4, 16, 16, 64, 2048, 2048, 0.0, "causal", "alibi"),
+ "te_1_0": ModelConfig(2, 16, 16, 64, 128, 128, 0.0, "no_mask", "post_scale_bias"),
+ "te_1_1": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "causal", "post_scale_bias"),
+ "te_1_2": ModelConfig(2, 16, 16, 64, 128, 128, 0.0, "padding", "post_scale_bias"),
+ "te_2_0": ModelConfig(1, 16, 16, 64, 2048, 2048, 0.0, "causal", "no_bias"),
+ "te_2_1": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "no_mask", "no_bias"),
+ "te_2_2": ModelConfig(1, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias"),
+ "te_3_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "causal", "alibi"),
+ "te_3_1": ModelConfig(4, 16, 16, 64, 2048, 2048, 0.0, "causal", "alibi"),
}
-@pytest.mark.skipif(get_cudnn_version() < (8,9,1), reason="cuDNN 8.9.1+ is required.")
+@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("model_configs", [model_configs_te_layer])
@pytest.mark.parametrize("model", model_configs_te_layer.keys())
@@ -842,7 +1004,9 @@ model_configs_te_layer = {
@pytest.mark.parametrize("qkv_format", ["sbhd"])
@pytest.mark.parametrize("fused_qkv_params", [False])
@pytest.mark.parametrize("RoPE", [False])
-def test_transformer_layer(dtype, model_configs, model, ckpt_attn, qkv_format, fused_qkv_params, RoPE):
+def test_transformer_layer(
+ dtype, model_configs, model, ckpt_attn, qkv_format, fused_qkv_params, RoPE
+):
"""Test TransformerLayer module"""
# Get configs
@@ -916,7 +1080,7 @@ def test_transformer_layer(dtype, model_configs, model, ckpt_attn, qkv_format, f
torch.testing.assert_close(fused_attn_bwd, flash_attn_bwd, **tols)
-@pytest.mark.skipif(get_cudnn_version() < (8,9,1), reason="cuDNN 8.9.1+ is required.")
+@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_te_layer])
@pytest.mark.parametrize("model", ["te_1_2", "te_2_0"])
@@ -926,22 +1090,24 @@ def test_te_layer_misc(dtype, model_configs, model, qkv_format):
ckpt_attn = True
fused_qkv_params = True
RoPE = True
- test_transformer_layer(dtype, model_configs, model,
- ckpt_attn, qkv_format, fused_qkv_params, RoPE)
+ test_transformer_layer(
+ dtype, model_configs, model, ckpt_attn, qkv_format, fused_qkv_params, RoPE
+ )
-@pytest.mark.skipif(get_cudnn_version() < (8,9,1), reason="cuDNN 8.9.1+ is required.")
+@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
@pytest.mark.parametrize("dtype", param_types_lean)
@pytest.mark.parametrize("model_configs", [model_configs_te_layer])
@pytest.mark.parametrize("model", ["te_2_0", "te_2_1", "te_2_2"])
def test_te_layer_mqa_gqa(dtype, model_configs, model):
"""Test TransformerLayer module with MQA/GQA"""
+
def find_factors(x):
- f = []
- for i in range(2, x + 1):
- if x % i == 0:
- f.append(i)
- return f
+ f = []
+ for i in range(2, x + 1):
+ if x % i == 0:
+ f.append(i)
+ return f
ckpt_attn = True
qkv_format = "bshd"
@@ -951,21 +1117,22 @@ def test_te_layer_mqa_gqa(dtype, model_configs, model):
num_querys_per_gqa_group = find_factors(config.num_heads)
for num_q_per_gqa_group in num_querys_per_gqa_group:
- config.num_gqa_groups=config.num_heads // num_q_per_gqa_group
- test_transformer_layer(dtype, model_configs, model,
- ckpt_attn, qkv_format, fused_qkv_params, RoPE)
+ config.num_gqa_groups = config.num_heads // num_q_per_gqa_group
+ test_transformer_layer(
+ dtype, model_configs, model, ckpt_attn, qkv_format, fused_qkv_params, RoPE
+ )
def _run_transformer_layer(
- dtype: torch.dtype,
- config: ModelConfig,
- backend: str,
- ckpt_attn: bool,
- qkv_format: str,
- workspace_opt: bool,
- fused_qkv_params: bool,
- RoPE: bool,
- ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
+ dtype: torch.dtype,
+ config: ModelConfig,
+ backend: str,
+ ckpt_attn: bool,
+ qkv_format: str,
+ workspace_opt: bool,
+ fused_qkv_params: bool,
+ RoPE: bool,
+) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
"""Run TransformerLayer module with one forward pass and one backward pass"""
# Set RNG and environment variables
@@ -978,29 +1145,47 @@ def _run_transformer_layer(
os.environ["NVTE_FUSED_ATTN"] = "1"
# Create input tensor
- inp = torch.randn(config.max_seqlen_q, config.batch_size, config.hidden_size,
- dtype=dtype, device="cuda", requires_grad = True)
+ inp = torch.randn(
+ config.max_seqlen_q,
+ config.batch_size,
+ config.hidden_size,
+ dtype=dtype,
+ device="cuda",
+ requires_grad=True,
+ )
     # If the format under test is batch-first, transpose the input tensor.
if qkv_format == "bshd":
- inp = inp.transpose(0,1)
+ inp = inp.transpose(0, 1)
# Create seqlens
if "padding" in config.attn_mask_type:
- seqlens_q = torch.randint(1, config.max_seqlen_q, [config.batch_size],
- dtype=torch.int32, device="cuda")
+ seqlens_q = torch.randint(
+ 1, config.max_seqlen_q, [config.batch_size], dtype=torch.int32, device="cuda"
+ )
else:
- seqlens_q = torch.full([config.batch_size], config.max_seqlen_q,
- dtype=torch.int32, device="cuda")
+ seqlens_q = torch.full(
+ [config.batch_size], config.max_seqlen_q, dtype=torch.int32, device="cuda"
+ )
# Create attention mask if padding
attention_mask = None
if "padding" in config.attn_mask_type:
attention_mask_q = torch.Tensor([]).to(dtype=torch.bool)
for i in range(config.batch_size):
- attention_mask_q = torch.cat([attention_mask_q,
- torch.Tensor([False]*seqlens_q[i] + [True]*(config.max_seqlen_q-seqlens_q[i]))
- .to(torch.bool).unsqueeze(0).unsqueeze(0).unsqueeze(0)], dim=0)
+ attention_mask_q = torch.cat(
+ [
+ attention_mask_q,
+ torch.Tensor(
+ [False] * seqlens_q[i] + [True] * (config.max_seqlen_q - seqlens_q[i])
+ )
+ .to(torch.bool)
+ .unsqueeze(0)
+ .unsqueeze(0)
+ .unsqueeze(0),
+ ],
+ dim=0,
+ )
attention_mask = attention_mask_q.to(device="cuda")
sigma = 0.02
@@ -1009,14 +1194,19 @@ def _run_transformer_layer(
layer_number = 1
drop_path_rate = 0.0
- drop_path_rates = [
- rate.item() for rate in torch.linspace(0, drop_path_rate, config.num_layers)]
+ drop_path_rates = [rate.item() for rate in torch.linspace(0, drop_path_rate, config.num_layers)]
# Create bias
bias = None
- if config.attn_bias_type == 'post_scale_bias':
- bias = torch.randn(1, config.num_heads, config.max_seqlen_q, config.max_seqlen_kv,
- dtype=dtype, device="cuda")
+ if config.attn_bias_type == "post_scale_bias":
+ bias = torch.randn(
+ 1,
+ config.num_heads,
+ config.max_seqlen_q,
+ config.max_seqlen_kv,
+ dtype=dtype,
+ device="cuda",
+ )
# Create RoPE
rotary_pos_emb = None
@@ -1025,58 +1215,56 @@ def _run_transformer_layer(
rotary_pos_emb = PE(config.max_seqlen_q).to(device="cuda")
# Set up model
- block = (
- TransformerLayer(
- config.hidden_size,
- 4 * config.hidden_size,
- config.num_heads,
- num_gqa_groups=config.num_gqa_groups,
- layernorm_epsilon=1e-5,
- hidden_dropout=0.0,
- attention_dropout=config.dropout_p,
- init_method=init_method,
- output_layer_init_method=output_layer_init_method,
- layer_number=layer_number,
- kv_channels=config.head_dim,
- self_attn_mask_type=config.attn_mask_type,
- tp_group=None,
- tp_size=1,
- params_dtype=dtype,
- get_rng_state_tracker=None,
- fuse_wgrad_accumulation=False,
- seq_length=config.max_seqlen_q,
- micro_batch_size=config.batch_size,
- sequence_parallel=False,
- apply_residual_connection_post_layernorm=False,
- output_layernorm=False,
- layer_type="encoder",
- drop_path_rate=drop_path_rates[layer_number - 1],
- set_parallel_mode=True,
- fuse_qkv_params=fused_qkv_params,
- zero_centered_gamma=False,
- qkv_weight_interleaved=False,
- ub_tp_comm_overlap=False,
- bias=True,
- attn_input_format=qkv_format,
- )
- .to(dtype=dtype, device="cuda")
- )
+ block = TransformerLayer(
+ config.hidden_size,
+ 4 * config.hidden_size,
+ config.num_heads,
+ num_gqa_groups=config.num_gqa_groups,
+ layernorm_epsilon=1e-5,
+ hidden_dropout=0.0,
+ attention_dropout=config.dropout_p,
+ init_method=init_method,
+ output_layer_init_method=output_layer_init_method,
+ layer_number=layer_number,
+ kv_channels=config.head_dim,
+ self_attn_mask_type=config.attn_mask_type,
+ tp_group=None,
+ tp_size=1,
+ params_dtype=dtype,
+ get_rng_state_tracker=None,
+ fuse_wgrad_accumulation=False,
+ seq_length=config.max_seqlen_q,
+ micro_batch_size=config.batch_size,
+ sequence_parallel=False,
+ apply_residual_connection_post_layernorm=False,
+ output_layernorm=False,
+ layer_type="encoder",
+ drop_path_rate=drop_path_rates[layer_number - 1],
+ set_parallel_mode=True,
+ fuse_qkv_params=fused_qkv_params,
+ zero_centered_gamma=False,
+ qkv_weight_interleaved=False,
+ ub_tp_comm_overlap=False,
+ bias=True,
+ attn_input_format=qkv_format,
+ ).to(dtype=dtype, device="cuda")
# Create ALiBi slopes
alibi_slopes = None
if config.attn_bias_type == "alibi" and config.alibi_type == "custom":
- alibi_slopes = torch.randn(
- config.num_heads).abs().to(dtype=torch.float32, device="cuda")
+ alibi_slopes = torch.randn(config.num_heads).abs().to(dtype=torch.float32, device="cuda")
# Run a forward and backward pass
- out = block(inp,
+ out = block(
+ inp,
attention_mask=attention_mask,
self_attn_mask_type=config.attn_mask_type,
checkpoint_core_attention=False,
rotary_pos_emb=rotary_pos_emb,
core_attention_bias_type=config.attn_bias_type,
core_attention_bias=bias,
- alibi_slopes=alibi_slopes)
+ alibi_slopes=alibi_slopes,
+ )
loss = out.sum()
loss.backward()
@@ -1085,23 +1273,24 @@ def _run_transformer_layer(
model_configs_fp8_vs_f16 = {
# test: b, h, hg, d, sq, skv, p, mask, bias
- "fp8_9" : ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "no_bias"),
- "fp8_10": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "causal", "no_bias"),
+ "fp8_9": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "no_bias"),
+ "fp8_10": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "causal", "no_bias"),
"fp8_11": ModelConfig(2, 24, 12, 128, 2048, 2048, 0.0, "no_mask", "no_bias"),
- "fp8_12": ModelConfig(2, 24, 12, 128, 2048, 2048, 0.0, "causal", "no_bias"),
- "fp8_13": ModelConfig(1, 32, 4, 128, 8192, 8192, 0.0, "no_mask", "no_bias"),
- "fp8_14": ModelConfig(1, 32, 4, 128, 8192, 8192, 0.0, "causal", "no_bias"),
+ "fp8_12": ModelConfig(2, 24, 12, 128, 2048, 2048, 0.0, "causal", "no_bias"),
+ "fp8_13": ModelConfig(1, 32, 4, 128, 8192, 8192, 0.0, "no_mask", "no_bias"),
+ "fp8_14": ModelConfig(1, 32, 4, 128, 8192, 8192, 0.0, "causal", "no_bias"),
}
param_types_fp8_vs_f16 = [torch.float16, torch.bfloat16]
-qkv_layout_fp8_vs_f16 = ['sbh3d', 'bshd_bshd_bshd', 'sbhd_sbhd_sbhd']
-qkv_format_fp8_vs_f16 = ['bshd', 'sbhd']
+qkv_layout_fp8_vs_f16 = ["sbh3d", "bshd_bshd_bshd", "sbhd_sbhd_sbhd"]
+qkv_format_fp8_vs_f16 = ["bshd", "sbhd"]
+
def _rmse(a, b):
- return math.sqrt((torch.pow((a-b), 2)/a.numel()).sum())
+ return math.sqrt((torch.pow((a - b), 2) / a.numel()).sum())
-@pytest.mark.skipif(get_cudnn_version() < (8,9,3), reason="cuDNN 8.9.3+ is required.")
+@pytest.mark.skipif(get_cudnn_version() < (8, 9, 3), reason="cuDNN 8.9.3+ is required.")
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="FP8 tests require Hopper+.")
@pytest.mark.parametrize("dtype", param_types_fp8_vs_f16)
@@ -1118,58 +1307,78 @@ def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd):
logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = True")
fused_attn_fwd_fp8, param_names, fused_attn_bwd_fp8 = _run_mha_fp8_vs_f16(
- dtype, config, True, qkv_format, input_layernorm)
+ dtype, config, True, qkv_format, input_layernorm
+ )
logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = False")
fused_attn_fwd_f16, param_names, fused_attn_bwd_f16 = _run_mha_fp8_vs_f16(
- dtype, config, False, qkv_format, input_layernorm)
+ dtype, config, False, qkv_format, input_layernorm
+ )
tols = dict(atol=5e-1, rtol=5e-1)
rmse_tol = 0.1
fwd_rmse = _rmse(fused_attn_fwd_fp8, fused_attn_fwd_f16)
- fwd_range = max(fused_attn_fwd_fp8.max().item(),
- fused_attn_fwd_f16.max().item()) - min(fused_attn_fwd_fp8.min().item(),
- fused_attn_fwd_f16.min().item())
-
- logging.debug('========== {:^25s} =========='.format('forward output'))
- logging.debug('fused_attn_fwd_fp8 min {:.6f} max {:.6f}'.format(
- fused_attn_fwd_fp8.min().item(),fused_attn_fwd_fp8.max().item()))
- logging.debug('fused_attn_fwd_f16 min {:.6f} max {:.6f}'.format(
- fused_attn_fwd_f16.min().item(), fused_attn_fwd_f16.max().item()))
- logging.debug('fused_attn_fwd RMSE: {:.6f}'.format(fwd_rmse))
+ fwd_range = max(fused_attn_fwd_fp8.max().item(), fused_attn_fwd_f16.max().item()) - min(
+ fused_attn_fwd_fp8.min().item(), fused_attn_fwd_f16.min().item()
+ )
+
+ logging.debug("========== {:^25s} ==========".format("forward output"))
+ logging.debug(
+ "fused_attn_fwd_fp8 min {:.6f} max {:.6f}".format(
+ fused_attn_fwd_fp8.min().item(), fused_attn_fwd_fp8.max().item()
+ )
+ )
+ logging.debug(
+ "fused_attn_fwd_f16 min {:.6f} max {:.6f}".format(
+ fused_attn_fwd_f16.min().item(), fused_attn_fwd_f16.max().item()
+ )
+ )
+ logging.debug("fused_attn_fwd RMSE: {:.6f}".format(fwd_rmse))
try:
torch.testing.assert_close(fused_attn_fwd_fp8, fused_attn_fwd_f16, **tols)
except Exception as e:
logging.debug(e)
- assert(fwd_rmse < rmse_tol * fwd_range
- ), "FWD RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format(
- fwd_rmse, rmse_tol * fwd_range, rmse_tol, fwd_range)
+ assert (
+ fwd_rmse < rmse_tol * fwd_range
+ ), "FWD RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format(
+ fwd_rmse, rmse_tol * fwd_range, rmse_tol, fwd_range
+ )
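+ # Compare only the first gradient entry, i.e. hidden_states.grad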
for i in range(len(param_names[:1])):
bwd_rmse = _rmse(fused_attn_bwd_fp8[i], fused_attn_bwd_f16[i])
- bwd_range = max(fused_attn_bwd_fp8[i].max().item(),
- fused_attn_bwd_f16[i].max().item()) - min(fused_attn_bwd_fp8[i].min().item(),
- fused_attn_bwd_f16[i].min().item())
-
- logging.debug('========== {:^25s} =========='.format(param_names[i]))
- logging.debug('fused_attn_bwd_fp8[{}] min {:.6f} max {:.6f}'.format(i,
- fused_attn_bwd_fp8[i].min().item(), fused_attn_bwd_fp8[i].max().item()))
- logging.debug('fused_attn_bwd_f16[{}] min {:.6f} max {:.6f}'.format(i,
- fused_attn_bwd_f16[i].min().item(), fused_attn_bwd_f16[i].max().item()))
- logging.debug('fused_attn_bwd RMSE[{}]: {:.6f}'.format(i, bwd_rmse))
+ bwd_range = max(
+ fused_attn_bwd_fp8[i].max().item(), fused_attn_bwd_f16[i].max().item()
+ ) - min(fused_attn_bwd_fp8[i].min().item(), fused_attn_bwd_f16[i].min().item())
+
+ logging.debug("========== {:^25s} ==========".format(param_names[i]))
+ logging.debug(
+ "fused_attn_bwd_fp8[{}] min {:.6f} max {:.6f}".format(
+ i, fused_attn_bwd_fp8[i].min().item(), fused_attn_bwd_fp8[i].max().item()
+ )
+ )
+ logging.debug(
+ "fused_attn_bwd_f16[{}] min {:.6f} max {:.6f}".format(
+ i, fused_attn_bwd_f16[i].min().item(), fused_attn_bwd_f16[i].max().item()
+ )
+ )
+ logging.debug("fused_attn_bwd RMSE[{}]: {:.6f}".format(i, bwd_rmse))
try:
torch.testing.assert_close(fused_attn_bwd_fp8[i], fused_attn_bwd_f16[i], **tols)
except Exception as e:
logging.debug(e)
- assert(bwd_rmse < rmse_tol * bwd_range
- ), "BWD RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format(
- bwd_rmse, rmse_tol * bwd_range, rmse_tol, bwd_range)
+ assert (
+ bwd_rmse < rmse_tol * bwd_range
+ ), "BWD RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format(
+ bwd_rmse, rmse_tol * bwd_range, rmse_tol, bwd_range
+ )
+
def _run_mha_fp8_vs_f16(dtype, config, fp8_mha, qkv_format, input_layernorm):
reset_rng_states()
_DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
_DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed)
+
def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
"""Get cuda rng tracker."""
return _DUMMY_CUDA_RNG_STATE_TRACKER
@@ -1184,7 +1393,7 @@ def _run_mha_fp8_vs_f16(dtype, config, fp8_mha, qkv_format, input_layernorm):
)
with fp8_model_init(enabled=fp8_mha):
- mha = (MultiheadAttention(
+ mha = MultiheadAttention(
hidden_size=config.hidden_size,
num_attention_heads=config.num_heads,
kv_channels=config.head_dim,
@@ -1199,34 +1408,35 @@ def _run_mha_fp8_vs_f16(dtype, config, fp8_mha, qkv_format, input_layernorm):
attention_type="self",
qkv_weight_interleaved=True,
qkv_format=qkv_format,
- ).to(dtype=dtype, device="cuda")
- )
+ ).to(dtype=dtype, device="cuda")
- seqlens_q = torch.full([config.batch_size], config.max_seqlen_q,
- dtype=torch.int32, device="cuda")
- seqlens_kv = torch.full([config.batch_size], config.max_seqlen_kv,
- dtype=torch.int32, device="cuda")
+ seqlens_q = torch.full(
+ [config.batch_size], config.max_seqlen_q, dtype=torch.int32, device="cuda"
+ )
+ seqlens_kv = torch.full(
+ [config.batch_size], config.max_seqlen_kv, dtype=torch.int32, device="cuda"
+ )
cu_seqlens_q = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda")
cu_seqlens_kv = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda")
cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0)
cu_seqlens_kv[1:] = torch.cumsum(seqlens_kv, dim=0)
dim_to_num = {
- 'b' : config.batch_size,
- 'sq' : config.max_seqlen_q,
- 'skv': config.max_seqlen_kv,
- 'h' : config.num_heads,
- 'hg' : config.num_gqa_groups,
- 'd' : config.head_dim,
- 't' : cu_seqlens_q[-1],
- 'tg' : cu_seqlens_kv[-1],
- '3' : 3,
- '2' : 2,
- '1' : 1,
- }
- layout = '_'.join(qkv_format)
- layout = layout.replace('s', 'sq')
- tensor_shape = [dim_to_num[j] for j in layout.split('_')]
+ "b": config.batch_size,
+ "sq": config.max_seqlen_q,
+ "skv": config.max_seqlen_kv,
+ "h": config.num_heads,
+ "hg": config.num_gqa_groups,
+ "d": config.head_dim,
+ "t": cu_seqlens_q[-1],
+ "tg": cu_seqlens_kv[-1],
+ "3": 3,
+ "2": 2,
+ "1": 1,
+ }
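+ # Expand the qkv_format string (e.g. "bshd") into per-dim keys; "s" maps to "sq"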
+ layout = "_".join(qkv_format)
+ layout = layout.replace("s", "sq")
+ tensor_shape = [dim_to_num[j] for j in layout.split("_")]
tensor = 0.01 * torch.randint(-100, 100, tensor_shape, dtype=dtype, device="cuda")
hidden_states = tensor.view(*tensor.shape[:-2], -1)
hidden_states.requires_grad = True
@@ -1234,27 +1444,28 @@ def _run_mha_fp8_vs_f16(dtype, config, fp8_mha, qkv_format, input_layernorm):
out_grad = tensor.view(*tensor.shape[:-2], -1)
with fp8_autocast(enabled=fp8_mha, fp8_recipe=fp8_recipe):
- out = mha(hidden_states,
+ out = mha(
+ hidden_states,
attn_mask_type=config.attn_mask_type,
checkpoint_core_attention=False,
core_attention_bias_type=config.attn_bias_type,
is_first_microbatch=None,
- )
+ )
out.backward(out_grad)
param_names = []
- param_names.append('hidden_states.grad')
+ param_names.append("hidden_states.grad")
params = []
params.append(hidden_states)
for name, param in mha.named_parameters():
if param.requires_grad:
- param_names.append(name+'.grad')
+ param_names.append(name + ".grad")
params.append(param)
return out, param_names, tuple(x.grad for x in params)
-@pytest.mark.skipif(get_cudnn_version() < (8,9,3), reason="cuDNN 8.9.3+ is required.")
+@pytest.mark.skipif(get_cudnn_version() < (8, 9, 3), reason="cuDNN 8.9.3+ is required.")
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="FP8 tests require Hopper+.")
@pytest.mark.parametrize("dtype", param_types_fp8_vs_f16)
@@ -1264,62 +1475,75 @@ def _run_mha_fp8_vs_f16(dtype, config, fp8_mha, qkv_format, input_layernorm):
def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd):
config = model_configs_fp8_vs_f16[model]
- if (config.num_heads != config.num_gqa_groups and '3' in qkv_layout):
- pytest.skip("qkv_layout not applicable for MQA/GQA");
+ if config.num_heads != config.num_gqa_groups and "3" in qkv_layout:
+ pytest.skip("qkv_layout not applicable for MQA/GQA")
os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0"
-
logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = True")
- fused_attn_fwd_fp8, fused_attn_bwd_fp8 = _run_dpa_fp8_vs_f16(
- dtype, config, True, qkv_layout)
+ fused_attn_fwd_fp8, fused_attn_bwd_fp8 = _run_dpa_fp8_vs_f16(dtype, config, True, qkv_layout)
logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = False")
- fused_attn_fwd_f16, fused_attn_bwd_f16 = _run_dpa_fp8_vs_f16(
- dtype, config, False, qkv_layout)
+ fused_attn_fwd_f16, fused_attn_bwd_f16 = _run_dpa_fp8_vs_f16(dtype, config, False, qkv_layout)
tols = dict(atol=5e-1, rtol=5e-2)
rmse_tol = 0.1
- bwd_names = ['dq', 'dk', 'dv']
+ bwd_names = ["dq", "dk", "dv"]
fwd_rmse = _rmse(fused_attn_fwd_fp8, fused_attn_fwd_f16)
- fwd_range = max(fused_attn_fwd_fp8.max().item(),
- fused_attn_fwd_f16.max().item()) - min(fused_attn_fwd_fp8.min().item(),
- fused_attn_fwd_f16.min().item())
-
- logging.debug('========== {:^25s} =========='.format('forward output'))
- logging.debug('fused_attn_fwd_fp8 min {:.6f} max {:.6f}'.format(
- fused_attn_fwd_fp8.min().item(),fused_attn_fwd_fp8.max().item()))
- logging.debug('fused_attn_fwd_f16 min {:.6f} max {:.6f}'.format(
- fused_attn_fwd_f16.min().item(), fused_attn_fwd_f16.max().item()))
- logging.debug('fused_attn_fwd RMSE: {:.6f}'.format(fwd_rmse))
+ fwd_range = max(fused_attn_fwd_fp8.max().item(), fused_attn_fwd_f16.max().item()) - min(
+ fused_attn_fwd_fp8.min().item(), fused_attn_fwd_f16.min().item()
+ )
+
+ logging.debug("========== {:^25s} ==========".format("forward output"))
+ logging.debug(
+ "fused_attn_fwd_fp8 min {:.6f} max {:.6f}".format(
+ fused_attn_fwd_fp8.min().item(), fused_attn_fwd_fp8.max().item()
+ )
+ )
+ logging.debug(
+ "fused_attn_fwd_f16 min {:.6f} max {:.6f}".format(
+ fused_attn_fwd_f16.min().item(), fused_attn_fwd_f16.max().item()
+ )
+ )
+ logging.debug("fused_attn_fwd RMSE: {:.6f}".format(fwd_rmse))
try:
torch.testing.assert_close(fused_attn_fwd_fp8, fused_attn_fwd_f16, **tols)
except Exception as e:
logging.debug(e)
- assert(fwd_rmse < rmse_tol * fwd_range
- ), "FWD RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format(
- fwd_rmse, rmse_tol * fwd_range, rmse_tol, fwd_range)
- for i,_ in enumerate(fused_attn_bwd_f16):
+ assert (
+ fwd_rmse < rmse_tol * fwd_range
+ ), "FWD RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format(
+ fwd_rmse, rmse_tol * fwd_range, rmse_tol, fwd_range
+ )
+ for i, _ in enumerate(fused_attn_bwd_f16):
bwd_rmse = _rmse(fused_attn_bwd_fp8[i], fused_attn_bwd_f16[i])
- bwd_range = max(fused_attn_bwd_fp8[i].max().item(),
- fused_attn_bwd_f16[i].max().item()) - min(fused_attn_bwd_fp8[i].min().item(),
- fused_attn_bwd_f16[i].min().item())
-
- logging.debug('========== {:^25s} =========='.format(bwd_names[i]))
- logging.debug('fused_attn_bwd_fp8[{}] min {:.6f} max {:.6f}'.format(i,
- fused_attn_bwd_fp8[i].min().item(), fused_attn_bwd_fp8[i].max().item()))
- logging.debug('fused_attn_bwd_f16[{}] min {:.6f} max {:.6f}'.format(i,
- fused_attn_bwd_f16[i].min().item(), fused_attn_bwd_f16[i].max().item()))
- logging.debug('fused_attn_bwd RMSE[{}]: {:.6f}'.format(i, bwd_rmse))
+ bwd_range = max(
+ fused_attn_bwd_fp8[i].max().item(), fused_attn_bwd_f16[i].max().item()
+ ) - min(fused_attn_bwd_fp8[i].min().item(), fused_attn_bwd_f16[i].min().item())
+
+ logging.debug("========== {:^25s} ==========".format(bwd_names[i]))
+ logging.debug(
+ "fused_attn_bwd_fp8[{}] min {:.6f} max {:.6f}".format(
+ i, fused_attn_bwd_fp8[i].min().item(), fused_attn_bwd_fp8[i].max().item()
+ )
+ )
+ logging.debug(
+ "fused_attn_bwd_f16[{}] min {:.6f} max {:.6f}".format(
+ i, fused_attn_bwd_f16[i].min().item(), fused_attn_bwd_f16[i].max().item()
+ )
+ )
+ logging.debug("fused_attn_bwd RMSE[{}]: {:.6f}".format(i, bwd_rmse))
try:
torch.testing.assert_close(fused_attn_bwd_fp8[i], fused_attn_bwd_f16[i], **tols)
except Exception as e:
logging.debug(e)
- assert(bwd_rmse < rmse_tol * bwd_range
- ), "BWD RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format(
- bwd_rmse, rmse_tol * bwd_range, rmse_tol, bwd_range)
+ assert (
+ bwd_rmse < rmse_tol * bwd_range
+ ), "BWD RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format(
+ bwd_rmse, rmse_tol * bwd_range, rmse_tol, bwd_range
+ )
def _run_dpa_fp8_vs_f16(dtype, config, fp8_dpa, qkv_layout):
@@ -1327,6 +1551,7 @@ def _run_dpa_fp8_vs_f16(dtype, config, fp8_dpa, qkv_layout):
reset_rng_states()
_DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
_DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed)
+
def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
"""Get cuda rng tracker."""
return _DUMMY_CUDA_RNG_STATE_TRACKER
@@ -1339,60 +1564,60 @@ def _run_dpa_fp8_vs_f16(dtype, config, fp8_dpa, qkv_layout):
fp8_dpa=fp8_dpa,
)
- qkv_format = ''.join([i for i in qkv_layout.split('_')[0] if i.isalpha()])
+ qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()])
with fp8_model_init(enabled=fp8_dpa):
- dpa = (
- DotProductAttention(
- config.num_heads,
- config.head_dim,
- num_gqa_groups=config.num_gqa_groups,
- attention_dropout=config.dropout_p,
- sequence_parallel=False,
- tp_size=1,
- get_rng_state_tracker=get_dummy_cuda_rng_tracker,
- tp_group=None,
- layer_number=1,
- attention_type="self",
- qkv_format=qkv_format,
- ).to(dtype=dtype, device="cuda")
- )
+ dpa = DotProductAttention(
+ config.num_heads,
+ config.head_dim,
+ num_gqa_groups=config.num_gqa_groups,
+ attention_dropout=config.dropout_p,
+ sequence_parallel=False,
+ tp_size=1,
+ get_rng_state_tracker=get_dummy_cuda_rng_tracker,
+ tp_group=None,
+ layer_number=1,
+ attention_type="self",
+ qkv_format=qkv_format,
+ ).to(dtype=dtype, device="cuda")
- seqlens_q = torch.full([config.batch_size], config.max_seqlen_q,
- dtype=torch.int32, device="cuda")
- seqlens_kv = torch.full([config.batch_size], config.max_seqlen_kv,
- dtype=torch.int32, device="cuda")
+ seqlens_q = torch.full(
+ [config.batch_size], config.max_seqlen_q, dtype=torch.int32, device="cuda"
+ )
+ seqlens_kv = torch.full(
+ [config.batch_size], config.max_seqlen_kv, dtype=torch.int32, device="cuda"
+ )
cu_seqlens_q = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda")
cu_seqlens_kv = torch.zeros(config.batch_size + 1, dtype=torch.int32, device="cuda")
cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0)
cu_seqlens_kv[1:] = torch.cumsum(seqlens_kv, dim=0)
dim_to_num = {
- 'b' : config.batch_size,
- 'sq' : config.max_seqlen_q,
- 'skv': config.max_seqlen_kv,
- 'h' : config.num_heads,
- 'hg' : config.num_gqa_groups,
- 'd' : config.head_dim,
- 't' : cu_seqlens_q[-1],
- 'tg' : cu_seqlens_kv[-1],
- '3' : 3,
- '2' : 2,
- '1' : 1,
- }
+ "b": config.batch_size,
+ "sq": config.max_seqlen_q,
+ "skv": config.max_seqlen_kv,
+ "h": config.num_heads,
+ "hg": config.num_gqa_groups,
+ "d": config.head_dim,
+ "t": cu_seqlens_q[-1],
+ "tg": cu_seqlens_kv[-1],
+ "3": 3,
+ "2": 2,
+ "1": 1,
+ }
inp = []
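+ # Build q, k, and v separately; the k/v layouts substitute the skv/hg/tg dims below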
- for i,layout in enumerate(qkv_layout.split('_')):
- layout = '_'.join(layout)
+ for i, layout in enumerate(qkv_layout.split("_")):
+ layout = "_".join(layout)
if i == 0:
- layout = layout.replace('s', 'sq')
+ layout = layout.replace("s", "sq")
else:
- layout = layout.replace('s', 'skv')
- layout = layout.replace('h', 'hg')
- layout = layout.replace('t', 'tg')
- tensor_shape = [dim_to_num[j] for j in layout.split('_')]
+ layout = layout.replace("s", "skv")
+ layout = layout.replace("h", "hg")
+ layout = layout.replace("t", "tg")
+ tensor_shape = [dim_to_num[j] for j in layout.split("_")]
tensor = torch.randn(tensor_shape, dtype=dtype, device="cuda")
tensor_count = 1
split_dim = 0
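+ # A digit in the layout (e.g. "3" in "sbh3d") marks the packed dim to split q/k/v from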
- for dim, l in enumerate(layout.split('_')):
+ for dim, l in enumerate(layout.split("_")):
if l.isdigit():
tensor_count = int(l)
split_dim = dim
@@ -1406,14 +1631,17 @@ def _run_dpa_fp8_vs_f16(dtype, config, fp8_dpa, qkv_layout):
for i in range(3):
inp[i].requires_grad = True
- qkv_format_kv = '_'.join(qkv_format)
- qkv_format_kv = qkv_format_kv.replace('s', 'sq')
- out_grad_shape = [dim_to_num[i] for i in qkv_format_kv.split('_')]
+ qkv_format_kv = "_".join(qkv_format)
+ qkv_format_kv = qkv_format_kv.replace("s", "sq")
+ out_grad_shape = [dim_to_num[i] for i in qkv_format_kv.split("_")]
out_grad_shape_new = [*out_grad_shape[:-2], out_grad_shape[-2] * out_grad_shape[-1]]
out_grad = torch.randn(out_grad_shape_new, dtype=dtype, device="cuda")
with fp8_autocast(enabled=fp8_dpa, fp8_recipe=fp8_recipe):
- out = dpa(inp[0], inp[1], inp[2],
+ out = dpa(
+ inp[0],
+ inp[1],
+ inp[2],
qkv_format=qkv_format,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_kv,
@@ -1423,7 +1651,7 @@ def _run_dpa_fp8_vs_f16(dtype, config, fp8_dpa, qkv_layout):
checkpoint_core_attention=False,
core_attention_bias_type=config.attn_bias_type,
is_first_microbatch=True,
- )
+ )
out.backward(out_grad)
return out, (inp[0].grad, inp[1].grad, inp[2].grad)
@@ -1431,22 +1659,22 @@ def _run_dpa_fp8_vs_f16(dtype, config, fp8_dpa, qkv_layout):
model_configs_fp8 = {
# test: b, h, hg, d, sq, skv, p, mask, bias
- "fp8_1": ModelConfig(1, 1, 1, 64, 512, 512, 0.0, "no_mask", "no_bias"),
- "fp8_2": ModelConfig(4, 16, 16, 64, 512, 512, 0.0, "no_mask", "no_bias"),
- "fp8_3": ModelConfig(1, 1, 1, 128, 2048, 2048, 0.0, "no_mask", "no_bias"),
+ "fp8_1": ModelConfig(1, 1, 1, 64, 512, 512, 0.0, "no_mask", "no_bias"),
+ "fp8_2": ModelConfig(4, 16, 16, 64, 512, 512, 0.0, "no_mask", "no_bias"),
+ "fp8_3": ModelConfig(1, 1, 1, 128, 2048, 2048, 0.0, "no_mask", "no_bias"),
"fp8_4": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "no_bias"),
- "fp8_5": ModelConfig(1, 1, 1, 64, 512, 512, 0.0, "causal", "no_bias"),
- "fp8_6": ModelConfig(4, 16, 16, 64, 512, 512, 0.0, "causal", "no_bias"),
- "fp8_7": ModelConfig(1, 1, 1, 128, 2048, 2048, 0.0, "causal", "no_bias"),
- "fp8_8": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "causal", "no_bias"),
+ "fp8_5": ModelConfig(1, 1, 1, 64, 512, 512, 0.0, "causal", "no_bias"),
+ "fp8_6": ModelConfig(4, 16, 16, 64, 512, 512, 0.0, "causal", "no_bias"),
+ "fp8_7": ModelConfig(1, 1, 1, 128, 2048, 2048, 0.0, "causal", "no_bias"),
+ "fp8_8": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "causal", "no_bias"),
}
param_types_fp8 = [torch.float16, torch.bfloat16]
-cudnn_frontend_version = int(os.getenv('NVTE_FUSED_ATTN_FE_VER','1'))
-models_v0 = ['fp8_1', 'fp8_2', 'fp8_5', 'fp8_6']
-models_v1 = ['fp8_3', 'fp8_4', 'fp8_7', 'fp8_8']
+cudnn_frontend_version = int(os.getenv("NVTE_FUSED_ATTN_FE_VER", "1"))
+models_v0 = ["fp8_1", "fp8_2", "fp8_5", "fp8_6"]
+models_v1 = ["fp8_3", "fp8_4", "fp8_7", "fp8_8"]
-@pytest.mark.skipif(get_cudnn_version() < (8,9,3), reason="cuDNN 8.9.3+ is required.")
+@pytest.mark.skipif(get_cudnn_version() < (8, 9, 3), reason="cuDNN 8.9.3+ is required.")
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="FP8 tests require Hopper+.")
@pytest.mark.parametrize("dtype", param_types_fp8)
@@ -1460,50 +1688,62 @@ def test_custom_mha_fp8_vs_f16(dtype, model):
config = model_configs_fp8[model]
- fused_attn_fwd_fp8, fused_attn_bwd_fp8 = _run_custom_mha_fp8(
- dtype, config, "FusedAttention")
- unfused_attn_fwd_f16, unfused_attn_bwd_f16 = _run_ref_mha_f16(
- dtype, config, "UnfusedAttention")
+ fused_attn_fwd_fp8, fused_attn_bwd_fp8 = _run_custom_mha_fp8(dtype, config, "FusedAttention")
+ unfused_attn_fwd_f16, unfused_attn_bwd_f16 = _run_ref_mha_f16(dtype, config, "UnfusedAttention")
tols = dict(atol=5e-1, rtol=5e-1)
rmse_tol = 0.1
fwd_rmse = _rmse(fused_attn_fwd_fp8, unfused_attn_fwd_f16)
- fwd_range = max(fused_attn_fwd_fp8.max().item(),
- unfused_attn_fwd_f16.max().item()) - min(fused_attn_fwd_fp8.min().item(),
- unfused_attn_fwd_f16.min().item())
+ fwd_range = max(fused_attn_fwd_fp8.max().item(), unfused_attn_fwd_f16.max().item()) - min(
+ fused_attn_fwd_fp8.min().item(), unfused_attn_fwd_f16.min().item()
+ )
bwd_rmse = _rmse(fused_attn_bwd_fp8, unfused_attn_bwd_f16)
- bwd_range = max(fused_attn_bwd_fp8.max().item(),
- unfused_attn_bwd_f16.max().item()) - min(fused_attn_bwd_fp8.min().item(),
- unfused_attn_bwd_f16.min().item())
-
- logging.debug('fused_attn_fwd_fp8 min {:.6f} max {:.6f}'.format(
- fused_attn_fwd_fp8.min().item(),fused_attn_fwd_fp8.max().item()))
- logging.debug('unfused_attn_fwd_f16 min {:.6f} max {:.6f}'.format(
- unfused_attn_fwd_f16.min().item(), unfused_attn_fwd_f16.max().item()))
- logging.debug('fused_attn_fwd_fp8 vs unfused_attn_fwd_f16 RMSE: {:.6f}'.format(
- fwd_rmse))
+ bwd_range = max(fused_attn_bwd_fp8.max().item(), unfused_attn_bwd_f16.max().item()) - min(
+ fused_attn_bwd_fp8.min().item(), unfused_attn_bwd_f16.min().item()
+ )
+
+ logging.debug(
+ "fused_attn_fwd_fp8 min {:.6f} max {:.6f}".format(
+ fused_attn_fwd_fp8.min().item(), fused_attn_fwd_fp8.max().item()
+ )
+ )
+ logging.debug(
+ "unfused_attn_fwd_f16 min {:.6f} max {:.6f}".format(
+ unfused_attn_fwd_f16.min().item(), unfused_attn_fwd_f16.max().item()
+ )
+ )
+ logging.debug("fused_attn_fwd_fp8 vs unfused_attn_fwd_f16 RMSE: {:.6f}".format(fwd_rmse))
try:
torch.testing.assert_close(fused_attn_fwd_fp8, unfused_attn_fwd_f16, **tols)
except Exception as e:
logging.debug(e)
- logging.debug('fused_attn_bwd_fp8 min {:.6f} max {:.6f}'.format(
- fused_attn_bwd_fp8.min().item(), fused_attn_bwd_fp8.max().item()))
- logging.debug('unfused_attn_bwd_f16 min {:.6f} max {:.6f}'.format(
- unfused_attn_bwd_f16.min().item(), unfused_attn_bwd_f16.max().item()))
- logging.debug('fused_attn_bwd_fp8 vs unfused_attn_bwd_f16 RMSE: {:.6f}'.format(
- bwd_rmse))
+ logging.debug(
+ "fused_attn_bwd_fp8 min {:.6f} max {:.6f}".format(
+ fused_attn_bwd_fp8.min().item(), fused_attn_bwd_fp8.max().item()
+ )
+ )
+ logging.debug(
+ "unfused_attn_bwd_f16 min {:.6f} max {:.6f}".format(
+ unfused_attn_bwd_f16.min().item(), unfused_attn_bwd_f16.max().item()
+ )
+ )
+ logging.debug("fused_attn_bwd_fp8 vs unfused_attn_bwd_f16 RMSE: {:.6f}".format(bwd_rmse))
try:
torch.testing.assert_close(fused_attn_bwd_fp8, unfused_attn_bwd_f16, **tols)
except Exception as e:
logging.debug(e)
- assert(fwd_rmse < rmse_tol * fwd_range
- ), "FWD RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format(
- fwd_rmse, rmse_tol * fwd_range, rmse_tol, fwd_range)
- assert(bwd_rmse < rmse_tol * bwd_range
- ), "FWD RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format(
- bwd_rmse, rmse_tol * bwd_range, rmse_tol, bwd_range)
+ assert (
+ fwd_rmse < rmse_tol * fwd_range
+ ), "FWD RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format(
+ fwd_rmse, rmse_tol * fwd_range, rmse_tol, fwd_range
+ )
+ assert (
+ bwd_rmse < rmse_tol * bwd_range
+ ), "FWD RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format(
+ bwd_rmse, rmse_tol * bwd_range, rmse_tol, bwd_range
+ )
def _run_custom_mha_fp8(dtype, config, backend):
@@ -1517,18 +1757,25 @@ def _run_custom_mha_fp8(dtype, config, backend):
if backend == "FusedAttention":
os.environ["NVTE_FUSED_ATTN"] = "1"
- inp = 0.0001 * torch.randint(-100, 100,
- (config.batch_size * config.max_seqlen_q, config.num_heads * config.head_dim),
- dtype=dtype, device="cuda", requires_grad=True)
- seqlens = torch.full([config.batch_size], config.max_seqlen_q,
- dtype=torch.int32, device="cuda")
+ inp = 0.0001 * torch.randint(
+ -100,
+ 100,
+ (config.batch_size * config.max_seqlen_q, config.num_heads * config.head_dim),
+ dtype=dtype,
+ device="cuda",
+ requires_grad=True,
+ )
+ seqlens = torch.full([config.batch_size], config.max_seqlen_q, dtype=torch.int32, device="cuda")
cu_seqlens = torch.zeros(config.batch_size + 1, device="cuda", dtype=torch.int32)
cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)
out_grad = 0.01 * torch.randn(
- config.batch_size * config.max_seqlen_q, config.num_heads * config.head_dim,
- dtype=dtype, device="cuda")
- torch.save(out_grad, 'out_grad.pt')
+ config.batch_size * config.max_seqlen_q,
+ config.num_heads * config.head_dim,
+ dtype=dtype,
+ device="cuda",
+ )
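+ # Save the gradient so _run_ref_mha_f16 can replay the identical backward input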
+ torch.save(out_grad, "out_grad.pt")
fp8_recipe = recipe.DelayedScaling(
margin=0,
@@ -1543,10 +1790,13 @@ def _run_custom_mha_fp8(dtype, config, backend):
out.backward(out_grad)
out = torch.load("out.pt")
- dqkv = torch.load('dqkv.pt')
- return (out.view(config.batch_size, config.max_seqlen_q, -1),
- dqkv.view(config.batch_size, config.max_seqlen_q, 3,
- config.num_heads, config.head_dim).contiguous())
+ dqkv = torch.load("dqkv.pt")
+ return (
+ out.view(config.batch_size, config.max_seqlen_q, -1),
+ dqkv.view(
+ config.batch_size, config.max_seqlen_q, 3, config.num_heads, config.head_dim
+ ).contiguous(),
+ )
def _run_ref_mha_f16(dtype, config, backend):
@@ -1560,13 +1810,14 @@ def _run_ref_mha_f16(dtype, config, backend):
if backend == "FusedAttention":
os.environ["NVTE_FUSED_ATTN"] = "1"
- inp = torch.load('qkv.pt').to(device="cuda")
+ inp = torch.load("qkv.pt").to(device="cuda")
inp.requires_grad = True
seqlens = torch.full([config.batch_size], config.max_seqlen_q, dtype=torch.int32, device="cuda")
cu_seqlens = torch.zeros(config.batch_size + 1, device="cuda", dtype=torch.int32)
cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)
- out_grad = torch.load('out_grad.pt').to(device="cuda").view(
- config.batch_size, config.max_seqlen_q, -1)
+ out_grad = (
+ torch.load("out_grad.pt").to(device="cuda").view(config.batch_size, config.max_seqlen_q, -1)
+ )
_DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
_DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed)
@@ -1582,24 +1833,22 @@ def _run_ref_mha_f16(dtype, config, backend):
"""Get cuda rng tracker."""
return _DUMMY_CUDA_RNG_STATE_TRACKER
- block = (
- DotProductAttention(
- config.num_heads,
- config.head_dim,
- attention_dropout=config.dropout_p,
- sequence_parallel=False,
- tp_size=1,
- get_rng_state_tracker=get_dummy_cuda_rng_tracker,
- tp_group=None,
- layer_number=1,
- attention_type="self",
- qkv_format="bshd",
- ).to(dtype=dtype, device="cuda")
- )
-
- q = inp[:,:,0,:,:]
- k = inp[:,:,1,:,:]
- v = inp[:,:,2,:,:]
+ block = DotProductAttention(
+ config.num_heads,
+ config.head_dim,
+ attention_dropout=config.dropout_p,
+ sequence_parallel=False,
+ tp_size=1,
+ get_rng_state_tracker=get_dummy_cuda_rng_tracker,
+ tp_group=None,
+ layer_number=1,
+ attention_type="self",
+ qkv_format="bshd",
+ ).to(dtype=dtype, device="cuda")
+
+ q = inp[:, :, 0, :, :]
+ k = inp[:, :, 1, :, :]
+ v = inp[:, :, 2, :, :]
out = block(q, k, v, attn_mask_type=config.attn_mask_type)
out.backward(out_grad)
@@ -1611,12 +1860,12 @@ _2X_ACC_FPROP = False
_2X_ACC_DGRAD = False
_2X_ACC_WGRAD = False
-META_QKV = tex.FP8FwdTensors.GEMM1_OUTPUT
+META_QKV = tex.FP8FwdTensors.GEMM1_OUTPUT
META_DQKV = tex.FP8BwdTensors.GRAD_OUTPUT1
-META_O = tex.FP8FwdTensors.GEMM2_INPUT
-META_DO = tex.FP8BwdTensors.GRAD_INPUT2
-META_S = tex.FP8FwdTensors.GEMM3_OUTPUT
-META_DP = tex.FP8BwdTensors.GRAD_INPUT3
+META_O = tex.FP8FwdTensors.GEMM2_INPUT
+META_DO = tex.FP8BwdTensors.GRAD_INPUT2
+META_S = tex.FP8FwdTensors.GEMM3_OUTPUT
+META_DP = tex.FP8BwdTensors.GRAD_INPUT3
class _custom_mha_fp8(torch.autograd.Function):
@@ -1682,46 +1931,57 @@ class _custom_mha_fp8(torch.autograd.Function):
D_dtype=fp8_dtype_forward,
)
qkv = qkv.view(-1, 3, h, d)
- qkv_fp16 = ext.cast_from_fp8(qkv, fp8_meta["scaling_fwd"],
- META_QKV, fp8_dtype_forward,
- tex.DType.kFloat16).view(b, max_s, 3, h, d).contiguous()
- torch.save(qkv_fp16, 'qkv.pt')
+ qkv_fp16 = (
+ ext.cast_from_fp8(
+ qkv, fp8_meta["scaling_fwd"], META_QKV, fp8_dtype_forward, tex.DType.kFloat16
+ )
+ .view(b, max_s, 3, h, d)
+ .contiguous()
+ )
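+ # Save the dequantized qkv so the f16 reference run consumes identical inputs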
+ torch.save(qkv_fp16, "qkv.pt")
if cudnn_frontend_version == 1:
- qkv = qkv.view(b, max_s, 3, h, d) # bs3hd
+ qkv = qkv.view(b, max_s, 3, h, d) # bs3hd
# FMHA
out, aux_ctx_tensors, *rest = fused_attn_fwd(
- is_training,
- max_s,
- max_s,
- cu_seqlens,
- cu_seqlens,
- qkv[:,:,0,:,:] if cudnn_frontend_version == 1 else qkv[:,0,:,:],
- qkv[:,:,1,:,:] if cudnn_frontend_version == 1 else qkv[:,1,:,:],
- qkv[:,:,2,:,:] if cudnn_frontend_version == 1 else qkv[:,2,:,:],
- fp8_dtype_forward,
- FusedAttnBackend["FP8"],
- None, None, None, None, None,
- fp8_meta["scaling_fwd"].scale_inv[META_QKV],
- fp8_meta["scaling_fwd"].scale_inv[META_S],
- fp8_meta["scaling_fwd"].scale[META_S],
- fp8_meta["scaling_fwd"].scale[META_O],
- fp8_meta["scaling_fwd"].amax_history[0][META_S],
- fp8_meta["scaling_fwd"].amax_history[0][META_O],
- attn_scale=None,
- dropout=p_dropout,
- fast_zero_fill=fast_zero_fill,
- qkv_layout="bs3hd" if cudnn_frontend_version == 1 else "t3hd",
- attn_bias_type="no_bias",
- attn_mask_type=mask_type if cudnn_frontend_version == 1 else "padding",
- rng_gen=None,
- )
+ is_training,
+ max_s,
+ max_s,
+ cu_seqlens,
+ cu_seqlens,
+ qkv[:, :, 0, :, :] if cudnn_frontend_version == 1 else qkv[:, 0, :, :],
+ qkv[:, :, 1, :, :] if cudnn_frontend_version == 1 else qkv[:, 1, :, :],
+ qkv[:, :, 2, :, :] if cudnn_frontend_version == 1 else qkv[:, 2, :, :],
+ fp8_dtype_forward,
+ FusedAttnBackend["FP8"],
+ None,
+ None,
+ None,
+ None,
+ None,
+ fp8_meta["scaling_fwd"].scale_inv[META_QKV],
+ fp8_meta["scaling_fwd"].scale_inv[META_S],
+ fp8_meta["scaling_fwd"].scale[META_S],
+ fp8_meta["scaling_fwd"].scale[META_O],
+ fp8_meta["scaling_fwd"].amax_history[0][META_S],
+ fp8_meta["scaling_fwd"].amax_history[0][META_O],
+ attn_scale=None,
+ dropout=p_dropout,
+ fast_zero_fill=fast_zero_fill,
+ qkv_layout="bs3hd" if cudnn_frontend_version == 1 else "t3hd",
+ attn_bias_type="no_bias",
+ attn_mask_type=mask_type if cudnn_frontend_version == 1 else "padding",
+ rng_gen=None,
+ )
M, ZInv, philox_unpacked = aux_ctx_tensors
ctx.save_for_backward(
- inp_t_fp8, qkv_weight_t_fp8, workspace,
- qkv, out,
+ inp_t_fp8,
+ qkv_weight_t_fp8,
+ workspace,
+ qkv,
+ out,
fp8_meta["scaling_fwd"].scale,
fp8_meta["scaling_fwd"].scale_inv,
)
@@ -1736,82 +1996,84 @@ class _custom_mha_fp8(torch.autograd.Function):
ctx.mask_type = mask_type
ctx.dtype = inp.dtype
- out = out.view(-1, in_features) # (bs)(hd)
- out_fp16 = ext.cast_from_fp8(out, fp8_meta["scaling_fwd"],
- META_O, fp8_dtype_forward, tex.DType.kFloat16)
- torch.save(out_fp16, 'out.pt') # (bs)(hd)
+ out = out.view(-1, in_features) # (bs)(hd)
+ out_fp16 = ext.cast_from_fp8(
+ out, fp8_meta["scaling_fwd"], META_O, fp8_dtype_forward, tex.DType.kFloat16
+ )
+ torch.save(out_fp16, "out.pt") # (bs)(hd)
return out_fp16
-
@staticmethod
- def backward(
- ctx, grad_output: torch.Tensor
- ) -> Tuple[Union[torch.Tensor, None], ...]:
+ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]:
with torch.cuda.nvtx.range("_DPA"):
(
inp_t_fp8,
qkv_weight_t_fp8,
workspace,
- qkv, out,
+ qkv,
+ out,
fwd_scales,
fwd_scale_inverses,
) = ctx.saved_tensors
- fp8_dtype_forward = fp8.get_fp8_te_dtype(
- ctx.fp8_meta["recipe"], fprop_tensor=True
- )
- fp8_dtype_backward = fp8.get_fp8_te_dtype(
- ctx.fp8_meta["recipe"], fprop_tensor=False
- )
+ fp8_dtype_forward = fp8.get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=True)
+ fp8_dtype_backward = fp8.get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=False)
proj_dgrad = ext.cast_to_fp8(
grad_output, ctx.fp8_meta["scaling_bwd"], META_DO, fp8_dtype_backward
- ) # (bs)(hd)
+ ) # (bs)(hd)
dq, dk, dv, *rest = fused_attn_bwd(
- ctx.max_s,
- ctx.max_s,
- ctx.cu_seqlens,
- ctx.cu_seqlens,
- qkv[:,:,0,:,:] if cudnn_frontend_version == 1 else qkv[:,0,:,:],
- qkv[:,:,1,:,:] if cudnn_frontend_version == 1 else qkv[:,1,:,:],
- qkv[:,:,2,:,:] if cudnn_frontend_version == 1 else qkv[:,2,:,:],
- out,
- proj_dgrad.view_as(out),
- fp8_dtype_forward,
- fp8_dtype_backward,
- ctx.aux_ctx_tensors,
- FusedAttnBackend["FP8"],
- None, None, None, None,
- fwd_scale_inverses[META_QKV], # d_scale_qkv,
- fwd_scale_inverses[META_S], # d_scale_s,
- fwd_scale_inverses[META_O], # d_scale_o,
- ctx.fp8_meta['scaling_bwd'].scale_inv[META_DO], # d_scale_do
- ctx.fp8_meta['scaling_bwd'].scale_inv[META_DP], # d_scale_dp
- fwd_scales[META_S], # q_scale_s
- ctx.fp8_meta['scaling_bwd'].scale[META_DP], # q_scale_dp
- ctx.fp8_meta['scaling_bwd'].scale[META_DQKV], # q_scale_dqkv
- ctx.fp8_meta['scaling_bwd'].amax_history[0][META_DP], # amax_dp
- ctx.fp8_meta['scaling_bwd'].amax_history[0][META_DQKV], # amax_dqkv
- attn_scale=None,
- dropout=ctx.p_dropout,
- fast_zero_fill=ctx.fast_zero_fill,
- qkv_layout="bs3hd" if cudnn_frontend_version == 1 else "t3hd",
- attn_bias_type="no_bias",
- attn_mask_type=ctx.mask_type if cudnn_frontend_version == 1 else "padding",
- )
+ ctx.max_s,
+ ctx.max_s,
+ ctx.cu_seqlens,
+ ctx.cu_seqlens,
+ qkv[:, :, 0, :, :] if cudnn_frontend_version == 1 else qkv[:, 0, :, :],
+ qkv[:, :, 1, :, :] if cudnn_frontend_version == 1 else qkv[:, 1, :, :],
+ qkv[:, :, 2, :, :] if cudnn_frontend_version == 1 else qkv[:, 2, :, :],
+ out,
+ proj_dgrad.view_as(out),
+ fp8_dtype_forward,
+ fp8_dtype_backward,
+ ctx.aux_ctx_tensors,
+ FusedAttnBackend["FP8"],
+ None,
+ None,
+ None,
+ None,
+ fwd_scale_inverses[META_QKV], # d_scale_qkv
+ fwd_scale_inverses[META_S], # d_scale_s
+ fwd_scale_inverses[META_O], # d_scale_o
+ ctx.fp8_meta["scaling_bwd"].scale_inv[META_DO], # d_scale_do
+ ctx.fp8_meta["scaling_bwd"].scale_inv[META_DP], # d_scale_dp
+ fwd_scales[META_S], # q_scale_s
+ ctx.fp8_meta["scaling_bwd"].scale[META_DP], # q_scale_dp
+ ctx.fp8_meta["scaling_bwd"].scale[META_DQKV], # q_scale_dqkv
+ ctx.fp8_meta["scaling_bwd"].amax_history[0][META_DP], # amax_dp
+ ctx.fp8_meta["scaling_bwd"].amax_history[0][META_DQKV], # amax_dqkv
+ attn_scale=None,
+ dropout=ctx.p_dropout,
+ fast_zero_fill=ctx.fast_zero_fill,
+ qkv_layout="bs3hd" if cudnn_frontend_version == 1 else "t3hd",
+ attn_bias_type="no_bias",
+ attn_mask_type=ctx.mask_type if cudnn_frontend_version == 1 else "padding",
+ )
dim = 2 if cudnn_frontend_version == 1 else 1
dqkv = torch.Tensor().to(device=dq.device, dtype=dq.dtype)
dqkv_shape = list(dq.shape)
dqkv_shape.insert(dim, 3)
dqkv_stride = list(dq.stride())
- dqkv_stride.insert(dim, int(dqkv_stride[-3]/3))
- dqkv.set_(dq.untyped_storage(), dq.storage_offset(), dqkv_shape, dqkv_stride) # bs3hd
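+ # dq/dk/dv come back packed in one buffer; alias dq's storage as the interleaved dqkv view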
+ dqkv_stride.insert(dim, int(dqkv_stride[-3] / 3))
+ dqkv.set_(dq.untyped_storage(), dq.storage_offset(), dqkv_shape, dqkv_stride) # bs3hd
- dqkv_c = dqkv.view(-1, 3*ctx.hidden_size)
- dqkv_c_fp16 = ext.cast_from_fp8(dqkv_c,
- ctx.fp8_meta["scaling_bwd"], META_DQKV,
- fp8_dtype_backward, tex.DType.kFloat16)
- torch.save(dqkv_c_fp16, 'dqkv.pt')
+ dqkv_c = dqkv.view(-1, 3 * ctx.hidden_size)
+ dqkv_c_fp16 = ext.cast_from_fp8(
+ dqkv_c,
+ ctx.fp8_meta["scaling_bwd"],
+ META_DQKV,
+ fp8_dtype_backward,
+ tex.DType.kFloat16,
+ )
+ torch.save(dqkv_c_fp16, "dqkv.pt")
qkv_bgrad, dqkv_t = ext.fp8_transpose_bgrad_fused(
dqkv_c,
@@ -1850,7 +2112,8 @@ class _custom_mha_fp8(torch.autograd.Function):
use_split_accumulator=_2X_ACC_WGRAD,
)
- return (qkv_dgrad,
+ return (
+ qkv_dgrad,
qkv_wgrad,
qkv_bgrad,
None,
@@ -1862,14 +2125,12 @@ class _custom_mha_fp8(torch.autograd.Function):
None,
None,
None,
- None)
+ None,
+ )
class Custom_MHA_FP8(TransformerEngineBaseModule):
- def __init__(
- self,
- config,
- params_dtype: torch.dtype = torch.float32):
+ def __init__(self, config, params_dtype: torch.dtype = torch.float32):
super().__init__()
self.p_dropout = config.dropout_p
self.h = config.num_heads
@@ -1901,8 +2162,10 @@ class Custom_MHA_FP8(TransformerEngineBaseModule):
)
def forward(
- self, inp: torch.Tensor,
- cu_seqlens, max_s,
+ self,
+ inp: torch.Tensor,
+ cu_seqlens,
+ max_s,
) -> torch.Tensor:
with self.prepare_forward(inp, None, num_gemms=3) as inp:
out = _custom_mha_fp8.apply(
@@ -1917,5 +2180,6 @@ class Custom_MHA_FP8(TransformerEngineBaseModule):
self.fp8_meta,
self.workspace,
self.training,
- self.mask_type)
+ self.mask_type,
+ )
return out
diff --git a/tests/pytorch/fused_attn/test_fused_attn_with_cp.py b/tests/pytorch/fused_attn/test_fused_attn_with_cp.py
index 928ca27578beb3b0c170ac49fe1c270972e92f12..754416c837029ce774df2b9071bfcfbb06a5cec7 100644
--- a/tests/pytorch/fused_attn/test_fused_attn_with_cp.py
+++ b/tests/pytorch/fused_attn/test_fused_attn_with_cp.py
@@ -14,12 +14,13 @@ from transformer_engine.pytorch.utils import get_device_compute_capability
model_configs_flash_attn = {
# test: b, h, hg, d, sq, skv, p, mask, bias
- "cp_1_0": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias"), # MHA
- "cp_1_1": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # MHA
- "cp_2_0": ModelConfig(2, 12, 1, 128, 4096, 4096, 0.0, "causal", "no_bias"), # GQA
- "cp_2_1": ModelConfig(2, 12, 1, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # GQA
+ "cp_1_0": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias"), # MHA
+ "cp_1_1": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # MHA
+ "cp_2_0": ModelConfig(2, 12, 1, 128, 4096, 4096, 0.0, "causal", "no_bias"), # GQA
+ "cp_2_1": ModelConfig(2, 12, 1, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # GQA
}
+
def get_bash_arguments(**kwargs):
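+ # Build the command line for a 2-GPU distributed launch of the CP test script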
args = ["python", "-m", "torch.distributed.launch", "--nproc-per-node=2"]
te_path = os.getenv("TE_PATH", "/opt/transformerengine")
@@ -29,46 +30,43 @@ def get_bash_arguments(**kwargs):
args.append(f"{k}={v}")
return args
+
@pytest.mark.skipif(not _is_flash_attention_2_available(), reason="Flash-attn 2.0+ is required.")
@pytest.mark.skipif(get_device_compute_capability() < (8, 0), reason="CP tests require sm80+.")
-@pytest.mark.parametrize("dtype", ['bf16', 'fp16'])
+@pytest.mark.parametrize("dtype", ["bf16", "fp16"])
@pytest.mark.parametrize("model", model_configs_flash_attn.keys())
-@pytest.mark.parametrize("qkv_format", ['bshd', 'sbhd', 'thd'])
+@pytest.mark.parametrize("qkv_format", ["bshd", "sbhd", "thd"])
def test_cp_with_flash_attention(dtype, model, qkv_format):
subprocess.run(
get_bash_arguments(
- dtype=dtype,
- model=model,
- qkv_format=qkv_format,
- kernel_backend='FlashAttention'
+ dtype=dtype, model=model, qkv_format=qkv_format, kernel_backend="FlashAttention"
),
- check=True
+ check=True,
)
+
model_configs_fused_attn = {
# test: b, h, hg, d, sq, skv, p, mask, bias
- "cp_1_0": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias"), # MHA
- "cp_1_1": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # MHA
- "cp_1_2": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "causal", "post_scale_bias"), # MHA
- "cp_1_3": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "post_scale_bias"), # MHA
- "cp_2_0": ModelConfig(2, 12, 1, 128, 4096, 4096, 0.0, "causal", "no_bias"), # GQA
- "cp_2_1": ModelConfig(2, 12, 1, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # GQA
- "cp_2_2": ModelConfig(2, 12, 1, 128, 4096, 4096, 0.0, "causal", "post_scale_bias"), # GQA
- "cp_2_3": ModelConfig(2, 12, 1, 128, 4096, 4096, 0.0, "no_mask", "post_scale_bias"), # GQA
+ "cp_1_0": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias"), # MHA
+ "cp_1_1": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # MHA
+ "cp_1_2": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "causal", "post_scale_bias"), # MHA
+ "cp_1_3": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "post_scale_bias"), # MHA
+ "cp_2_0": ModelConfig(2, 12, 1, 128, 4096, 4096, 0.0, "causal", "no_bias"), # GQA
+ "cp_2_1": ModelConfig(2, 12, 1, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # GQA
+ "cp_2_2": ModelConfig(2, 12, 1, 128, 4096, 4096, 0.0, "causal", "post_scale_bias"), # GQA
+ "cp_2_3": ModelConfig(2, 12, 1, 128, 4096, 4096, 0.0, "no_mask", "post_scale_bias"), # GQA
}
-@pytest.mark.skipif(_cudnn_version() < (8,9,7), reason="cuDNN 8.9.7+ is required.")
+
+@pytest.mark.skipif(_cudnn_version() < (8, 9, 7), reason="cuDNN 8.9.7+ is required.")
@pytest.mark.skipif(get_device_compute_capability() < (8, 0), reason="CP tests require sm80+.")
-@pytest.mark.parametrize("dtype", ['bf16', 'fp16'])
+@pytest.mark.parametrize("dtype", ["bf16", "fp16"])
@pytest.mark.parametrize("model", model_configs_fused_attn.keys())
-@pytest.mark.parametrize("qkv_format", ['bshd', 'sbhd', 'thd'])
+@pytest.mark.parametrize("qkv_format", ["bshd", "sbhd", "thd"])
def test_cp_with_fused_attention(dtype, model, qkv_format):
subprocess.run(
get_bash_arguments(
- dtype=dtype,
- model=model,
- qkv_format=qkv_format,
- kernel_backend='FusedAttention'
+ dtype=dtype, model=model, qkv_format=qkv_format, kernel_backend="FusedAttention"
),
- check=True
+ check=True,
)
diff --git a/tests/pytorch/test_cuda_graphs.py b/tests/pytorch/test_cuda_graphs.py
index 71023c32f9ab042a76761c4c5615b20927d5e687..8d3a9dca4f228a64e3b653cd7233f3ff9aa59c23 100644
--- a/tests/pytorch/test_cuda_graphs.py
+++ b/tests/pytorch/test_cuda_graphs.py
@@ -8,8 +8,15 @@ import pytest
import torch
from transformer_engine.pytorch import (
- DotProductAttention, LayerNormLinear, LayerNormMLP, Linear, make_graphed_callables,
- MultiheadAttention, TransformerLayer, fp8_autocast, fp8_model_init,
+ DotProductAttention,
+ LayerNormLinear,
+ LayerNormMLP,
+ Linear,
+ make_graphed_callables,
+ MultiheadAttention,
+ TransformerLayer,
+ fp8_autocast,
+ fp8_model_init,
)
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.utils import is_bf16_compatible
@@ -26,15 +33,18 @@ torch.cuda.manual_seed(seed)
_cpu_rng_state = torch.get_rng_state()
_cuda_rng_state = torch.cuda.get_rng_state()
+
@dataclass
class ModelConfig:
"""Data tensor dimensions within Transformer model"""
+
sequence_length: int
batch_size: int
hidden_size: int
num_heads: int
kv_channels: int
+
model_configs = {"small": ModelConfig(2, 32, 64, 2, 32)}
modules = ["transformer", "layernorm_mlp", "layernorm_linear", "linear", "mha", "dpa"]
@@ -66,7 +76,9 @@ def assert_all_equal(l1: List[torch.Tensor], l2: List[torch.Tensor], names=None)
for i, (t1, t2) in enumerate(zip(l1, l2)):
if not torch.equal(t1, t2):
failed = True
- failed_tensors += f" {names[i]}\n" if names is not None else f" tensor at idx={i}\n"
+ failed_tensors += (
+ f" {names[i]}\n" if names is not None else f" tensor at idx={i}\n"
+ )
assert not failed, "Output mismatches in:\n" + failed_tensors
@@ -157,41 +169,51 @@ def _test_cuda_graphs(
with fp8_model_init(enabled=fp8_params):
# Create modules.
if module == "transformer":
- modules = [TransformerLayer(
- config.hidden_size,
- config.hidden_size,
- config.num_heads,
- hidden_dropout=0.0,
- attention_dropout=0.0,
- fuse_qkv_params=True,
- params_dtype=dtype,
- ) for _ in range(num_layers)]
+ modules = [
+ TransformerLayer(
+ config.hidden_size,
+ config.hidden_size,
+ config.num_heads,
+ hidden_dropout=0.0,
+ attention_dropout=0.0,
+ fuse_qkv_params=True,
+ params_dtype=dtype,
+ )
+ for _ in range(num_layers)
+ ]
elif module == "layernorm_mlp":
- modules = [LayerNormMLP(
- config.hidden_size, config.hidden_size, params_dtype=dtype
- ) for _ in range(num_layers)]
+ modules = [
+ LayerNormMLP(config.hidden_size, config.hidden_size, params_dtype=dtype)
+ for _ in range(num_layers)
+ ]
elif module == "layernorm_linear":
- modules = [LayerNormLinear(
- config.hidden_size, config.hidden_size, params_dtype=dtype
- ) for _ in range(num_layers)]
+ modules = [
+ LayerNormLinear(config.hidden_size, config.hidden_size, params_dtype=dtype)
+ for _ in range(num_layers)
+ ]
elif module == "mha":
- modules = [MultiheadAttention(
- config.hidden_size,
- config.num_heads,
- attention_dropout=0.0,
- params_dtype=dtype,
- fuse_qkv_params=True,
- ) for _ in range(num_layers)]
+ modules = [
+ MultiheadAttention(
+ config.hidden_size,
+ config.num_heads,
+ attention_dropout=0.0,
+ params_dtype=dtype,
+ fuse_qkv_params=True,
+ )
+ for _ in range(num_layers)
+ ]
elif dpa:
assert config.hidden_size % config.num_heads == 0, "Err."
assert num_layers == 1, "Err."
- modules = [DotProductAttention(
- config.num_heads, config.kv_channels, attention_dropout=0.0
- ) for _ in range(num_layers)]
+ modules = [
+ DotProductAttention(config.num_heads, config.kv_channels, attention_dropout=0.0)
+ for _ in range(num_layers)
+ ]
else:
- modules = [Linear(
- config.hidden_size, config.hidden_size, device="cuda", params_dtype=dtype
- ) for _ in range(num_layers)]
+ modules = [
+ Linear(config.hidden_size, config.hidden_size, device="cuda", params_dtype=dtype)
+ for _ in range(num_layers)
+ ]
# Initialize gradient buffers.
for module in modules:
@@ -238,7 +260,7 @@ def _test_cuda_graphs(
with fp8_autocast(enabled=fp8):
kwargs = {}
if fp8_weight_caching:
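+ # Only the first microbatch recasts weights to FP8; later ones reuse the cached cast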
- kwargs["is_first_microbatch"] = (grad_accumulation_step == 0)
+ kwargs["is_first_microbatch"] = grad_accumulation_step == 0
output = model(*inputs, **kwargs)
output.backward(grad_output)
if not dpa:
diff --git a/tests/pytorch/test_deferred_init.py b/tests/pytorch/test_deferred_init.py
index cbc761a27caa084a06b238ec11798d95969ec917..0469a01c5f109ddf6fd4a2cdc8bd93340418f709 100644
--- a/tests/pytorch/test_deferred_init.py
+++ b/tests/pytorch/test_deferred_init.py
@@ -27,33 +27,31 @@ num_heads = 16
head_dim = 64
dtype = torch.bfloat16
+
class TestDeferredInit:
@staticmethod
def get_module_args(module):
hidden_size = num_heads * head_dim
args = (hidden_size,)
- kwargs = {
- 'params_dtype': dtype,
- 'device': 'meta'
- }
+ kwargs = {"params_dtype": dtype, "device": "meta"}
if module in [te.Linear, te.LayerNormLinear, te.LayerNormMLP]:
ffn_hidden_size = 2 * hidden_size
- args += (ffn_hidden_size, )
- kwargs['bias'] = True
+ args += (ffn_hidden_size,)
+ kwargs["bias"] = True
if module == te.LayerNormMLP:
- kwargs['seq_length'] = seq_length
+ kwargs["seq_length"] = seq_length
elif module == te.MultiheadAttention:
- args += (num_heads, )
- kwargs['fuse_qkv_params'] = True
+ args += (num_heads,)
+ kwargs["fuse_qkv_params"] = True
elif module == te.TransformerLayer:
args += (3 * hidden_size, num_heads)
- kwargs['fuse_qkv_params'] = True
- kwargs['seq_length'] = seq_length
+ kwargs["fuse_qkv_params"] = True
+ kwargs["seq_length"] = seq_length
return args, kwargs
- @pytest.mark.parametrize("module_type", _core_modules+_composed_modules)
+ @pytest.mark.parametrize("module_type", _core_modules + _composed_modules)
def test_zero_memory_init(
self,
module_type: torch.nn.Module,
diff --git a/tests/pytorch/test_float8tensor.py b/tests/pytorch/test_float8tensor.py
index cead7767bb2b4d9d21aa05fe712635c1d43d4383..0ea03197714843e78dc03e4a48d64644fc408d36 100644
--- a/tests/pytorch/test_float8tensor.py
+++ b/tests/pytorch/test_float8tensor.py
@@ -26,6 +26,7 @@ _tols: Dict[tex.DType, Dict[str, float]] = {
tex.DType.kFloat8E5M2: dict(rtol=0.25, atol=0.125), # epsilon = 0.125
}
+
def _to_list(x: Union[Iterable, Any]) -> List:
"""Convert to list if iterable, otherwise put in singleton list"""
if isinstance(x, Iterable):
@@ -33,12 +34,14 @@ def _to_list(x: Union[Iterable, Any]) -> List:
else:
return [x]
+
# Types that can be interpreted as tensor dims
DimsType = Union[Iterable[int], int]
# Check if FP8 is supported
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
+
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
class TestFloat8Tensor:
@@ -108,7 +111,7 @@ class TestFloat8Tensor:
def test_quantize_dequantize_scales(self, scale: float) -> None:
self._test_quantize_dequantize(scale=scale)
- @pytest.mark.parametrize("dims", [[], 1, 311, [7,11], [7,5,3], [2,3,5,3]])
+ @pytest.mark.parametrize("dims", [[], 1, 311, [7, 11], [7, 5, 3], [2, 3, 5, 3]])
def test_quantize_dequantize_dims(self, dims: DimsType) -> None:
self._test_quantize_dequantize(dims=dims)
@@ -310,7 +313,7 @@ class TestFloat8Tensor:
def test_serialization(
self,
- dims: DimsType = [2,3,5],
+ dims: DimsType = [2, 3, 5],
fp8_dtype: tex.DType = tex.DType.kFloat8E4M3,
scale: float = 0.5,
dtype: torch.dtype = torch.float32,
diff --git a/tests/pytorch/test_fused_optimizer.py b/tests/pytorch/test_fused_optimizer.py
index 63ef4bdc0b236ccf6126f35fb845e901ec471e09..8a50648391b3cfde2daad72bbbe19e6ccb480811 100644
--- a/tests/pytorch/test_fused_optimizer.py
+++ b/tests/pytorch/test_fused_optimizer.py
@@ -117,9 +117,7 @@ class TestFusedAdam(TestFusedOptimizer):
tensors = []
for size in sizes:
tensors.append(torch.rand(size, dtype=torch.float, device="cuda"))
- ref_param, tst_param, ref_optim, tst_optim = self.gen_param_optim(
- tensors, self.options
- )
+ ref_param, tst_param, ref_optim, tst_optim = self.gen_param_optim(tensors, self.options)
for i in range(self.iters):
self.gen_grad(ref_param, tst_param)
@@ -139,9 +137,7 @@ class TestFusedAdam(TestFusedOptimizer):
}
tensor = torch.rand(nelem, dtype=torch.float, device="cuda")
- ref_param, tst_param, ref_optim, tst_optim = self.gen_param_optim(
- [tensor], adam_option
- )
+ ref_param, tst_param, ref_optim, tst_optim = self.gen_param_optim([tensor], adam_option)
for i in range(self.iters):
self.gen_grad(ref_param, tst_param)
@@ -161,9 +157,7 @@ class TestFusedAdam(TestFusedOptimizer):
}
tensor = torch.rand(nelem, dtype=torch.float, device="cuda")
- ref_param, tst_param, ref_optim, tst_optim = self.gen_param_optim(
- [tensor], adam_option
- )
+ ref_param, tst_param, ref_optim, tst_optim = self.gen_param_optim([tensor], adam_option)
# Add an empty param group which may occur for pipeline parallel p-tuning
tst_optim.add_param_group({"params": []})
@@ -175,10 +169,11 @@ class TestFusedAdam(TestFusedOptimizer):
torch.testing.assert_close(ref_param, tst_param)
+
class TestFusedSGD(TestFusedOptimizer):
def __init__(self, *args, **kwargs):
super(TestFusedSGD, self).__init__(*args, **kwargs)
- self.options = {"lr": .25, "momentum": .125}
+ self.options = {"lr": 0.25, "momentum": 0.125}
self.ref_optim = torch.optim.SGD
self.fused_optim = te.optimizers.FusedSGD
@@ -188,7 +183,7 @@ class TestFusedSGD(TestFusedOptimizer):
def test_half(self):
self.gen_single_type_test(param_type=torch.float16)
- @unittest.skipIf(torch.cuda.device_count()<2, "more than 1 GPU required")
+ @unittest.skipIf(torch.cuda.device_count() < 2, "more than 1 GPU required")
def test_multi_device(self):
devices = ("cuda:0", "cuda:1")
for current_dev, tensor_dev in product(devices, devices):
@@ -452,8 +447,8 @@ class AdamTest(unittest.TestCase):
@largeTensorTest("60GB", "cuda")
def testLargeTensor(self):
- t = torch.zeros(2359332864, dtype=torch.half, device='cuda')
- t2 = torch.zeros(2359332864, dtype=torch.half, device='cuda')
+ t = torch.zeros(2359332864, dtype=torch.half, device="cuda")
+ t2 = torch.zeros(2359332864, dtype=torch.half, device="cuda")
grad = torch.randn_like(t)
t.grad = grad
t2.grad = grad
diff --git a/tests/pytorch/test_fused_rope.py b/tests/pytorch/test_fused_rope.py
index 0fb7597246239e6bbcd98db88d1db92b05a81709..d6ba66cbbcc20578018f7b6fb2323eaa54362613 100644
--- a/tests/pytorch/test_fused_rope.py
+++ b/tests/pytorch/test_fused_rope.py
@@ -26,10 +26,7 @@ def apply_rotary_pos_emb_thd(
"""
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
return torch.cat(
- [
- apply_rotary_pos_emb(x.unsqueeze(1), freqs[: x.size(0)])
- for x in torch.split(t, seqlens)
- ]
+ [apply_rotary_pos_emb(x.unsqueeze(1), freqs[: x.size(0)]) for x in torch.split(t, seqlens)]
).squeeze(1)
@@ -45,6 +42,7 @@ def get_tol(dtype: torch.dtype) -> Dict:
def _overlapping_grad(output: torch.Tensor) -> torch.Tensor:
return output.sum() * 2
+
# Gradient is a full tensor
def _non_overlapping_grad(output: torch.Tensor) -> torch.Tensor:
t = torch.ones_like(output)
@@ -86,9 +84,7 @@ def test_fused_rope(
emb = rotary_pos_emb(seq_length)
# unfused
- output_unfused = apply_rotary_pos_emb(
- t, emb, tensor_format=tensor_format, fused=False
- )
+ output_unfused = apply_rotary_pos_emb(t, emb, tensor_format=tensor_format, fused=False)
loss_unfused = loss_func(output_unfused)
loss_unfused.backward()
grad_unfused = t.grad.detach().clone()
diff --git a/tests/pytorch/test_gqa.py b/tests/pytorch/test_gqa.py
index 91b29c6e266f675756605b8ee9d0c9d573de026d..9f9098891f8dc40924bc2ac938f0c1dbab4aa639 100644
--- a/tests/pytorch/test_gqa.py
+++ b/tests/pytorch/test_gqa.py
@@ -13,23 +13,16 @@ num_heads = 16
head_dim = 64
dtype = torch.bfloat16
num_attn_head = 16
-ffn_hidden_size=1024
+ffn_hidden_size = 1024
+
@pytest.mark.parametrize("kv_channels", [128, 256])
@pytest.mark.parametrize("hidden_size", [128, 256])
@pytest.mark.parametrize("num_gqa_groups", [1, 2, 4, 8, 16])
-def test_gqa(
- kv_channels,
- hidden_size,
- num_gqa_groups
-) -> None:
-
+def test_gqa(kv_channels, hidden_size, num_gqa_groups) -> None:
+
model = te.TransformerLayer(
- hidden_size,
- ffn_hidden_size,
- num_attn_head,
- num_gqa_groups,
- kv_channels=kv_channels
+ hidden_size, ffn_hidden_size, num_attn_head, num_gqa_groups, kv_channels=kv_channels
)
# Run forward pass
@@ -42,10 +35,9 @@ def test_gqa(
assert model.self_attention.layernorm_qkv.query_weight.shape[0] == kv_channels * num_attn_head
assert model.self_attention.layernorm_qkv.query_weight.shape[1] == hidden_size
-
+
assert model.self_attention.layernorm_qkv.value_weight.shape[0] == kv_channels * num_gqa_groups
assert model.self_attention.layernorm_qkv.value_weight.shape[1] == hidden_size
-
+
assert model.self_attention.proj.weight.shape[0] == hidden_size
assert model.self_attention.proj.weight.shape[1] == kv_channels * num_attn_head
-
diff --git a/tests/pytorch/test_jit.py b/tests/pytorch/test_jit.py
index 24288b5ff485aab9a7a88ee837419e2e185f9308..7d69e0371272c4ee6bf64c17211bde428c580012 100644
--- a/tests/pytorch/test_jit.py
+++ b/tests/pytorch/test_jit.py
@@ -11,11 +11,11 @@ import transformer_engine.pytorch as te
# Model names for test_torch_dynamo
_model_factory = {
- "Linear": [(lambda: te.Linear(16, 16)), [16, 16]],
- "LayerNorm": [(lambda: te.LayerNorm(16)), [16, 16]],
- "LayerNormLinear": [(lambda: te.LayerNormLinear(16, 16)), [16, 16]],
- "LayerNormMLP": [(lambda: te.LayerNormMLP(16, 16)), [16, 16]],
- "TransformerLayer": [(lambda: te.TransformerLayer(128, 128, 2)), [4, 1, 128]],
+ "Linear": [(lambda: te.Linear(16, 16)), [16, 16]],
+ "LayerNorm": [(lambda: te.LayerNorm(16)), [16, 16]],
+ "LayerNormLinear": [(lambda: te.LayerNormLinear(16, 16)), [16, 16]],
+ "LayerNormMLP": [(lambda: te.LayerNormMLP(16, 16)), [16, 16]],
+ "TransformerLayer": [(lambda: te.TransformerLayer(128, 128, 2)), [4, 1, 128]],
}
@@ -31,11 +31,11 @@ def test_torch_dynamo(model_name: str):
# Helper function to construct tensor with default options
def make_tensor(
- dims: Tuple[int],
- dtype: torch.dtype = torch.float32,
- device: torch.device = "cuda",
- requires_grad: bool = True,
- **kwargs,
+ dims: Tuple[int],
+ dtype: torch.dtype = torch.float32,
+ device: torch.device = "cuda",
+ requires_grad: bool = True,
+ **kwargs,
):
return torch.zeros(
dims,
diff --git a/tests/pytorch/test_multi_tensor.py b/tests/pytorch/test_multi_tensor.py
index 5b528b67274ee285348c78f579d1bb4902154007..216b200e092945c6a7534c4958bcbd26c3ca8697 100644
--- a/tests/pytorch/test_multi_tensor.py
+++ b/tests/pytorch/test_multi_tensor.py
@@ -28,9 +28,7 @@ appliers = [MultiTensorApply(2048 * 32), MultiTensorApply(333), MultiTensorApply
@pytest.mark.parametrize("in_type", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("out_type", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("inplace", [False, True])
-def test_multi_tensor_scale(
- input_size_pair, applier, repeat, in_type, out_type, inplace
-):
+def test_multi_tensor_scale(input_size_pair, applier, repeat, in_type, out_type, inplace):
if inplace is True and (out_type is not in_type):
pytest.skip("inplace=True and out_type != in_type is not supported.")
elif (in_type == torch.float16 and out_type == torch.bfloat16) or (
@@ -154,9 +152,7 @@ def test_multi_tensor_l2norm(input_size_pair, applier, repeat, in_type, per_tens
in_list += [a.clone().to(in_type), b.clone().to(in_type)]
if per_tensor:
- norm, norm_per_tensor = applier(
- tex.multi_tensor_l2norm, overflow_buf, [in_list], True
- )
+ norm, norm_per_tensor = applier(tex.multi_tensor_l2norm, overflow_buf, [in_list], True)
normab = torch.cat((a.norm().view(1), b.norm().view(1)))
norm_per_tensor = norm_per_tensor.view(-1, 2)
else:
@@ -168,9 +164,7 @@ def test_multi_tensor_l2norm(input_size_pair, applier, repeat, in_type, per_tens
torch.testing.assert_close(norm, reference.broadcast_to(norm.shape))
if per_tensor:
- torch.testing.assert_close(
- norm_per_tensor, normab.broadcast_to(norm_per_tensor.shape)
- )
+ torch.testing.assert_close(norm_per_tensor, normab.broadcast_to(norm_per_tensor.shape))
assert overflow_buf.item() == 0
@@ -179,9 +173,7 @@ def test_multi_tensor_l2norm(input_size_pair, applier, repeat, in_type, per_tens
@pytest.mark.parametrize("repeat", [1, 55])
@pytest.mark.parametrize("in_type", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("per_tensor", [False, True])
-def test_multi_tensor_unscale_l2norm(
- input_size_pair, applier, repeat, in_type, per_tensor
-):
+def test_multi_tensor_unscale_l2norm(input_size_pair, applier, repeat, in_type, per_tensor):
sizea, sizeb = input_size_pair
device = torch.device("cuda")
val = 4.0
@@ -205,9 +197,7 @@ def test_multi_tensor_unscale_l2norm(
inv_scale_cuda,
True,
)
- normab = torch.cat(
- ((a * inv_scale).norm().view(1), (b * inv_scale).norm().view(1))
- )
+ normab = torch.cat(((a * inv_scale).norm().view(1), (b * inv_scale).norm().view(1)))
norm_per_tensor = norm_per_tensor.view(-1, 2)
else:
norm, _ = applier(
@@ -224,7 +214,5 @@ def test_multi_tensor_unscale_l2norm(
torch.testing.assert_close(norm, reference.broadcast_to(norm.shape))
if per_tensor:
- torch.testing.assert_close(
- norm_per_tensor, normab.broadcast_to(norm_per_tensor.shape)
- )
+ torch.testing.assert_close(norm_per_tensor, normab.broadcast_to(norm_per_tensor.shape))
assert overflow_buf.item() == 0
diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py
index 2c9d61b845e3604ce3b3650662a8af03ba5e3cbc..4a3baf8823ca278493b287ad2ad27d9fd95db4b0 100644
--- a/tests/pytorch/test_numerics.py
+++ b/tests/pytorch/test_numerics.py
@@ -20,8 +20,15 @@ from transformer_engine.pytorch.utils import (
is_bf16_compatible,
)
from transformer_engine.pytorch import (
- DotProductAttention, LayerNormLinear, LayerNormMLP, Linear,
- MultiheadAttention, RMSNorm, TransformerLayer, LayerNorm, InferenceParams
+ DotProductAttention,
+ LayerNormLinear,
+ LayerNormMLP,
+ Linear,
+ MultiheadAttention,
+ RMSNorm,
+ TransformerLayer,
+ LayerNorm,
+ InferenceParams,
)
from transformer_engine.pytorch.distributed import checkpoint as te_checkpoint
@@ -106,10 +113,11 @@ def assert_allclose(
if not result:
diff = torch.abs(t1 - t2).flatten()
m = torch.argmax(diff)
- msg = (f"Outputs not close enough in tensor at idx={i}. "
- f"Location of the maximum difference: {m.item()} "
- f"with {t1.flatten()[m].item()} vs {t2.flatten()[m].item()} "
- f"(diff {diff[m].item()})."
+ msg = (
+ f"Outputs not close enough in tensor at idx={i}. "
+ f"Location of the maximum difference: {m.item()} "
+ f"with {t1.flatten()[m].item()} vs {t2.flatten()[m].item()} "
+ f"(diff {diff[m].item()})."
)
raise AssertionError(msg)
@@ -175,9 +183,7 @@ class TorchDotProductAttention(torch.nn.Module):
)
# [sq, b, np, hn] -> [sq, b * np, hn]
- query_layer = query_layer.reshape(
- output_size[2], output_size[0] * output_size[1], -1
- )
+ query_layer = query_layer.reshape(output_size[2], output_size[0] * output_size[1], -1)
# [sk, b, np, hn] -> [sk, b * np, hn]
key_layer = key_layer.reshape(output_size[3], output_size[0] * output_size[1], -1)
@@ -216,14 +222,10 @@ class TorchDotProductAttention(torch.nn.Module):
)
# change view [sk, b * np, hn]
- value_layer = value_layer.reshape(
- value_layer.size(0), output_size[0] * output_size[1], -1
- )
+ value_layer = value_layer.reshape(value_layer.size(0), output_size[0] * output_size[1], -1)
# change view [b * np, sq, sk]
- attention_probs = attention_probs.view(
- output_size[0] * output_size[1], output_size[2], -1
- )
+ attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)
# matmul: [b * np, sq, hn]
context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))
@@ -241,9 +243,7 @@ class TorchDotProductAttention(torch.nn.Module):
class TorchLayerNorm(nn.Module):
- def __init__(self, in_features: int,
- eps: float,
- zero_centered_gamma: bool):
+ def __init__(self, in_features: int, eps: float, zero_centered_gamma: bool):
super().__init__()
self.eps = eps
self.in_features = in_features
@@ -260,10 +260,12 @@ class TorchLayerNorm(nn.Module):
w = w.to(torch.float32)
b = self.bias.to(torch.float32)
inp = x.to(torch.float32)
- out = torch.nn.functional.layer_norm(inp, (self.in_features,), weight=w,
- bias=b, eps=self.eps)
+ out = torch.nn.functional.layer_norm(
+ inp, (self.in_features,), weight=w, bias=b, eps=self.eps
+ )
return out.to(x.dtype)
+
# Adapted from https://github.com/bzhangGo/rmsnorm/blob/c6691f20ec0af4128c8159c903071f7575404295/rmsnorm_torch.py
class TorchRMSNorm(nn.Module):
def __init__(self, in_features, zero_centered_gamma, eps=1e-5):
@@ -278,11 +280,11 @@ class TorchRMSNorm(nn.Module):
self.register_parameter("weight", self.weight)
def forward(self, x):
- norm_x2 = torch.sum(x.float()**2, dim=-1, keepdim=True)
+ norm_x2 = torch.sum(x.float() ** 2, dim=-1, keepdim=True)
d_x = self.in_features
rms_x2 = norm_x2 / d_x + self.eps
- r_rms_x = rms_x2 ** (-1. / 2)
+ r_rms_x = rms_x2 ** (-1.0 / 2)
x_normed = x * r_rms_x
w = self.weight.float()
@@ -292,17 +294,24 @@ class TorchRMSNorm(nn.Module):
class TorchLayerNormLinear(nn.Module):
- def __init__(self, in_features: int, out_features: int,
- eps: float, bias: bool = True,
- normalization: str = "LayerNorm",
- zero_centered_gamma: bool = False):
+ def __init__(
+ self,
+ in_features: int,
+ out_features: int,
+ eps: float,
+ bias: bool = True,
+ normalization: str = "LayerNorm",
+ zero_centered_gamma: bool = False,
+ ):
super().__init__()
if normalization == "LayerNorm":
- self.layernorm = TorchLayerNorm(in_features, eps=eps,
- zero_centered_gamma=zero_centered_gamma)
+ self.layernorm = TorchLayerNorm(
+ in_features, eps=eps, zero_centered_gamma=zero_centered_gamma
+ )
elif normalization == "RMSNorm":
- self.layernorm = TorchRMSNorm(in_features, eps=eps,
- zero_centered_gamma=zero_centered_gamma)
+ self.layernorm = TorchRMSNorm(
+ in_features, eps=eps, zero_centered_gamma=zero_centered_gamma
+ )
else:
raise RuntimeError("Unsupported normalization")
@@ -329,21 +338,26 @@ class TorchMHA(nn.Module):
output = output[0]
return output
+
class TorchQuickGELU(nn.Module):
def forward(self, input: torch.Tensor) -> torch.Tensor:
return input * torch.sigmoid(1.702 * input)
+
class TorchSquaredRELU(nn.Module):
def forward(self, input: torch.Tensor) -> torch.Tensor:
return (input > 0) * input * input
-_supported_act = {'geglu' : nn.GELU(approximate="tanh"),
- 'gelu' : nn.GELU(approximate="tanh"),
- 'reglu' : nn.ReLU(),
- 'relu' : nn.ReLU(),
- 'swiglu' : nn.SiLU(),
- 'qgelu' : TorchQuickGELU(),
- 'srelu' : TorchSquaredRELU()}
+
+_supported_act = {
+ "geglu": nn.GELU(approximate="tanh"),
+ "gelu": nn.GELU(approximate="tanh"),
+ "reglu": nn.ReLU(),
+ "relu": nn.ReLU(),
+ "swiglu": nn.SiLU(),
+ "qgelu": TorchQuickGELU(),
+ "srelu": TorchSquaredRELU(),
+}
class TorchGLU(nn.Module):
@@ -353,26 +367,29 @@ class TorchGLU(nn.Module):
def forward(self, x):
shape = x.size(-1)
- a = x[..., :shape // 2]
- b = x[..., (shape // 2):]
+ a = x[..., : shape // 2]
+ b = x[..., (shape // 2) :]
a = self.act(a)
return a * b
class TorchLayerNormMLP(nn.Module):
- def __init__(self, hidden_size: int, ffn_hidden_size: int,
- eps: float = 1e-5, activation = 'gelu',
- normalization: str = "LayerNorm"):
+ def __init__(
+ self,
+ hidden_size: int,
+ ffn_hidden_size: int,
+ eps: float = 1e-5,
+ activation="gelu",
+ normalization: str = "LayerNorm",
+ ):
super().__init__()
if normalization == "LayerNorm":
- self.ln = TorchLayerNorm(hidden_size, eps=eps,
- zero_centered_gamma=False)
+ self.ln = TorchLayerNorm(hidden_size, eps=eps, zero_centered_gamma=False)
elif normalization == "RMSNorm":
- self.ln = TorchRMSNorm(hidden_size, eps=eps,
- zero_centered_gamma=False)
+ self.ln = TorchRMSNorm(hidden_size, eps=eps, zero_centered_gamma=False)
else:
raise RuntimeError("Unsupported normalization")
- if 'glu' in activation:
+ if "glu" in activation:
fc1_output_features = 2 * ffn_hidden_size
self.gelu = TorchGLU(activation)
else:
@@ -387,7 +404,9 @@ class TorchLayerNormMLP(nn.Module):
class TorchGPT(nn.Module):
- def __init__(self, hidden_size: int, eps: float, num_attention_heads: int, parallel_attention_mlp: bool):
+ def __init__(
+ self, hidden_size: int, eps: float, num_attention_heads: int, parallel_attention_mlp: bool
+ ):
super().__init__()
self.ln = nn.LayerNorm(hidden_size, eps=eps)
self.causal_attn = TorchMHA(hidden_size, num_attention_heads)
@@ -411,7 +430,6 @@ class TorchGPT(nn.Module):
return x
-
def _test_e2e_selective_recompute(bs, dtype, config, fp8, fp8_model_params=False, recompute=False):
reset_rng_states()
FP8GlobalStateManager.reset()
@@ -421,23 +439,21 @@ def _test_e2e_selective_recompute(bs, dtype, config, fp8, fp8_model_params=False
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
with fp8_model_init(enabled=fp8 and fp8_model_params):
- block = (
- TransformerLayer(
- config.hidden_size,
- 4 * config.hidden_size,
- config.num_attention_heads,
- layernorm_epsilon=config.eps,
- init_method=init_method,
- output_layer_init_method=output_layer_init_method,
- hidden_dropout=0.1,
- attention_dropout=0.1,
- kv_channels=config.embed,
- apply_residual_connection_post_layernorm=False,
- output_layernorm=False,
- params_dtype=dtype,
- fuse_qkv_params=True,
- device="cuda",
- )
+ block = TransformerLayer(
+ config.hidden_size,
+ 4 * config.hidden_size,
+ config.num_attention_heads,
+ layernorm_epsilon=config.eps,
+ init_method=init_method,
+ output_layer_init_method=output_layer_init_method,
+ hidden_dropout=0.1,
+ attention_dropout=0.1,
+ kv_channels=config.embed,
+ apply_residual_connection_post_layernorm=False,
+ output_layernorm=False,
+ params_dtype=dtype,
+ fuse_qkv_params=True,
+ device="cuda",
)
te_inp_hidden_states = torch.randn(
@@ -477,8 +493,12 @@ def test_gpt_selective_activation_recompute(dtype, bs, model, fp8, fp8_model_par
config = model_configs[model]
- outputs = _test_e2e_selective_recompute(bs, dtype, config, fp8, fp8_model_params, recompute=False)
- outputs_recompute = _test_e2e_selective_recompute(bs, dtype, config, fp8, fp8_model_params, recompute=True)
+ outputs = _test_e2e_selective_recompute(
+ bs, dtype, config, fp8, fp8_model_params, recompute=False
+ )
+ outputs_recompute = _test_e2e_selective_recompute(
+ bs, dtype, config, fp8, fp8_model_params, recompute=True
+ )
# Check that results match
tols = dtype_tols(dtype)
@@ -496,10 +516,7 @@ def test_gpt_selective_activation_recompute(dtype, bs, model, fp8, fp8_model_par
def _test_e2e_full_recompute(
- bs, dtype, config, fp8,
- fp8_model_params=False,
- recompute=False,
- use_reentrant=True
+ bs, dtype, config, fp8, fp8_model_params=False, recompute=False, use_reentrant=True
):
reset_rng_states()
FP8GlobalStateManager.reset()
@@ -586,10 +603,12 @@ def test_gpt_full_activation_recompute(dtype, bs, model, fp8, fp8_model_params,
# Non-reentrant checkpoint becomes non-deterministic with bias+GELU fusion
os.environ["NVTE_BIAS_GELU_NVFUSION"] = "0"
- outputs, names = _test_e2e_full_recompute(bs, dtype, config, fp8, fp8_model_params,
- recompute=False, use_reentrant=use_reentrant)
- outputs_recompute, _ = _test_e2e_full_recompute(bs, dtype, config, fp8, fp8_model_params,
- recompute=True, use_reentrant=use_reentrant)
+ outputs, names = _test_e2e_full_recompute(
+ bs, dtype, config, fp8, fp8_model_params, recompute=False, use_reentrant=use_reentrant
+ )
+ outputs_recompute, _ = _test_e2e_full_recompute(
+ bs, dtype, config, fp8, fp8_model_params, recompute=True, use_reentrant=use_reentrant
+ )
if not use_reentrant:
# Reset bias+GELU fusion flag to avoid contaminating other tests
@@ -753,22 +772,19 @@ def _test_e2e_gpt_accuracy(block, bs, dtype, config):
def test_gpt_accuracy(dtype, bs, model, parallel_attention_mlp):
config = model_configs[model]
- te_gpt = (
- TransformerLayer(
- hidden_size=config.hidden_size,
- ffn_hidden_size=4 * config.hidden_size,
- num_attention_heads=config.num_attention_heads,
- layernorm_epsilon=config.eps,
- attention_dropout=0.1,
- hidden_dropout=0.1,
- params_dtype=dtype,
- fuse_qkv_params=True,
- qkv_weight_interleaved=False,
- parallel_attention_mlp=parallel_attention_mlp,
- device="cuda",
- )
- .eval()
- )
+ te_gpt = TransformerLayer(
+ hidden_size=config.hidden_size,
+ ffn_hidden_size=4 * config.hidden_size,
+ num_attention_heads=config.num_attention_heads,
+ layernorm_epsilon=config.eps,
+ attention_dropout=0.1,
+ hidden_dropout=0.1,
+ params_dtype=dtype,
+ fuse_qkv_params=True,
+ qkv_weight_interleaved=False,
+ parallel_attention_mlp=parallel_attention_mlp,
+ device="cuda",
+ ).eval()
torch_gpt = (
TorchGPT(
@@ -853,18 +869,15 @@ def _test_mha_accuracy(block, bs, dtype, config, mask_type, te=True):
def test_mha_accuracy(dtype, bs, model, mask_type):
config = model_configs[model]
- te_mha = (
- MultiheadAttention(
- config.hidden_size,
- config.num_attention_heads,
- fuse_qkv_params=True,
- params_dtype=dtype,
- qkv_weight_interleaved=False,
- input_layernorm=False,
- device="cuda",
- )
- .eval()
- )
+ te_mha = MultiheadAttention(
+ config.hidden_size,
+ config.num_attention_heads,
+ fuse_qkv_params=True,
+ params_dtype=dtype,
+ qkv_weight_interleaved=False,
+ input_layernorm=False,
+ device="cuda",
+ ).eval()
torch_mha = (
TorchMHA(
@@ -919,7 +932,9 @@ def _test_granular_accuracy(block, bs, dtype, config):
def _test_dpa_accuracy(block, bs, dtype, config):
reset_rng_states()
- mask = torch.triu(torch.ones(config.seq_len, config.seq_len, dtype=torch.bool, device="cuda"), diagonal=1)
+ mask = torch.triu(
+ torch.ones(config.seq_len, config.seq_len, dtype=torch.bool, device="cuda"), diagonal=1
+ )
query, key, value = [
torch.randn(
(config.seq_len, bs, config.num_attention_heads, config.embed),
@@ -953,7 +968,7 @@ def test_dpa_accuracy(dtype, bs, model):
DotProductAttention(
config.num_attention_heads,
config.embed,
- attention_dropout=0.0, # disable dropout, FU uses rng differently
+ attention_dropout=0.0, # disable dropout, FU uses rng differently
)
.to(dtype=dtype)
.cuda()
@@ -962,7 +977,7 @@ def test_dpa_accuracy(dtype, bs, model):
torch_dpa = (
TorchDotProductAttention(
config.embed,
- 0.0, # dropout
+ 0.0, # dropout
)
.to(dtype=dtype)
.cuda()
@@ -984,27 +999,21 @@ def test_dpa_accuracy(dtype, bs, model):
def test_linear_accuracy(dtype, bs, model):
config = model_configs[model]
- te_linear = (
- Linear(
- config.hidden_size,
- 4 * config.hidden_size,
- bias=True,
- params_dtype=dtype,
- device="cuda",
- )
- .eval()
- )
+ te_linear = Linear(
+ config.hidden_size,
+ 4 * config.hidden_size,
+ bias=True,
+ params_dtype=dtype,
+ device="cuda",
+ ).eval()
- torch_linear = (
- torch.nn.Linear(
- config.hidden_size,
- 4 * config.hidden_size,
- bias=True,
- device="cuda",
- dtype=dtype,
- )
- .eval()
- )
+ torch_linear = torch.nn.Linear(
+ config.hidden_size,
+ 4 * config.hidden_size,
+ bias=True,
+ device="cuda",
+ dtype=dtype,
+ ).eval()
# Share params
with torch.no_grad():
@@ -1029,23 +1038,16 @@ def test_linear_accuracy(dtype, bs, model):
def test_rmsnorm_accuracy(dtype, bs, model, eps, zero_centered_gamma):
config = model_configs[model]
- te_rmsnorm = (
- RMSNorm(
- config.hidden_size,
- eps=eps,
- params_dtype=dtype,
- zero_centered_gamma=zero_centered_gamma,
- device="cuda",
- )
- .eval()
- )
+ te_rmsnorm = RMSNorm(
+ config.hidden_size,
+ eps=eps,
+ params_dtype=dtype,
+ zero_centered_gamma=zero_centered_gamma,
+ device="cuda",
+ ).eval()
torch_rmsnorm = (
- TorchRMSNorm(
- config.hidden_size,
- eps=eps,
- zero_centered_gamma=zero_centered_gamma
- )
+ TorchRMSNorm(config.hidden_size, eps=eps, zero_centered_gamma=zero_centered_gamma)
.to(dtype=dtype)
.cuda()
.eval()
@@ -1059,12 +1061,14 @@ def test_rmsnorm_accuracy(dtype, bs, model, eps, zero_centered_gamma):
torch_outputs = _test_granular_accuracy(torch_rmsnorm, bs, dtype, config)
# Check output.
- atol = {torch.float32 : 1e-7,
- torch.half : 2e-3,
- torch.bfloat16: 2e-2,
+ atol = {
+ torch.float32: 1e-7,
+ torch.half: 2e-3,
+ torch.bfloat16: 2e-2,
}
assert_allclose(te_outputs[0], torch_outputs[0], atol[dtype])
+
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
@@ -1073,23 +1077,16 @@ def test_rmsnorm_accuracy(dtype, bs, model, eps, zero_centered_gamma):
def test_layernorm_accuracy(dtype, bs, model, eps, zero_centered_gamma):
config = model_configs[model]
- te_layernorm = (
- LayerNorm(
- config.hidden_size,
- eps=eps,
- params_dtype=dtype,
- zero_centered_gamma=zero_centered_gamma,
- device="cuda",
- )
- .eval()
- )
+ te_layernorm = LayerNorm(
+ config.hidden_size,
+ eps=eps,
+ params_dtype=dtype,
+ zero_centered_gamma=zero_centered_gamma,
+ device="cuda",
+ ).eval()
torch_layernorm = (
- TorchLayerNorm(
- config.hidden_size,
- eps=eps,
- zero_centered_gamma=zero_centered_gamma
- )
+ TorchLayerNorm(config.hidden_size, eps=eps, zero_centered_gamma=zero_centered_gamma)
.to(dtype=dtype)
.cuda()
.eval()
@@ -1104,9 +1101,10 @@ def test_layernorm_accuracy(dtype, bs, model, eps, zero_centered_gamma):
torch_outputs = _test_granular_accuracy(torch_layernorm, bs, dtype, config)
# Check output.
- atol = {torch.float32 : 1e-7,
- torch.half : 2e-3,
- torch.bfloat16: 2e-2,
+ atol = {
+ torch.float32: 1e-7,
+ torch.half: 2e-3,
+ torch.bfloat16: 2e-2,
}
assert_allclose(te_outputs[0], torch_outputs[0], atol[dtype])
@@ -1119,19 +1117,16 @@ def test_layernorm_accuracy(dtype, bs, model, eps, zero_centered_gamma):
def test_layernorm_linear_accuracy(dtype, bs, model, normalization, zero_centered_gamma):
config = model_configs[model]
- te_ln_linear = (
- LayerNormLinear(
- config.hidden_size,
- 4 * config.hidden_size,
- config.eps,
- bias=True,
- normalization=normalization,
- params_dtype=dtype,
- zero_centered_gamma=zero_centered_gamma,
- device="cuda",
- )
- .eval()
- )
+ te_ln_linear = LayerNormLinear(
+ config.hidden_size,
+ 4 * config.hidden_size,
+ config.eps,
+ bias=True,
+ normalization=normalization,
+ params_dtype=dtype,
+ zero_centered_gamma=zero_centered_gamma,
+ device="cuda",
+ ).eval()
torch_ln_linear = (
TorchLayerNormLinear(
@@ -1159,9 +1154,10 @@ def test_layernorm_linear_accuracy(dtype, bs, model, normalization, zero_centere
torch_outputs = _test_granular_accuracy(torch_ln_linear, bs, dtype, config)
# Check output.
- atol = {torch.float32 : 2.5e-4,
- torch.half : 2e-3,
- torch.bfloat16: 2e-2,
+ atol = {
+ torch.float32: 2.5e-4,
+ torch.half: 2e-3,
+ torch.bfloat16: 2e-2,
}
assert_allclose(te_outputs[0], torch_outputs[0], atol[dtype])
@@ -1174,17 +1170,14 @@ def test_layernorm_linear_accuracy(dtype, bs, model, normalization, zero_centere
def test_layernorm_mlp_accuracy(dtype, bs, model, activation, normalization):
config = model_configs[model]
- te_ln_mlp = (
- LayerNormMLP(
- config.hidden_size,
- 4 * config.hidden_size,
- activation=activation,
- normalization=normalization,
- params_dtype=dtype,
- device="cuda",
- )
- .eval()
- )
+ te_ln_mlp = LayerNormMLP(
+ config.hidden_size,
+ 4 * config.hidden_size,
+ activation=activation,
+ normalization=normalization,
+ params_dtype=dtype,
+ device="cuda",
+ ).eval()
torch_ln_mlp = (
TorchLayerNormMLP(
@@ -1226,8 +1219,10 @@ def _test_gpt_e2e_cuda_graph(block, bs, dtype, config, graph):
optimizer = torch.optim.SGD(block.parameters(), lr=0.1)
# Placeholders used for graph capture.
- static_input = torch.randn(config.seq_len, bs, config.hidden_size, device='cuda', dtype=dtype, requires_grad=True)
- static_target = torch.randn(config.seq_len, bs, config.hidden_size, device='cuda', dtype=dtype)
+ static_input = torch.randn(
+ config.seq_len, bs, config.hidden_size, device="cuda", dtype=dtype, requires_grad=True
+ )
+ static_target = torch.randn(config.seq_len, bs, config.hidden_size, device="cuda", dtype=dtype)
real_input = torch.rand_like(static_input)
real_target = torch.rand_like(static_target)
@@ -1286,22 +1281,20 @@ def test_gpt_cuda_graph(dtype, bs, model):
init_method = init_method_normal(sigma)
output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers)
- block = (
- TransformerLayer(
- config.hidden_size,
- 4 * config.hidden_size,
- config.num_attention_heads,
- layernorm_epsilon=config.eps,
- init_method=init_method,
- output_layer_init_method=output_layer_init_method,
- hidden_dropout=0.1,
- attention_dropout=0.1,
- kv_channels=config.embed,
- params_dtype=dtype,
- apply_residual_connection_post_layernorm=False,
- output_layernorm=False,
- device="cuda",
- )
+ block = TransformerLayer(
+ config.hidden_size,
+ 4 * config.hidden_size,
+ config.num_attention_heads,
+ layernorm_epsilon=config.eps,
+ init_method=init_method,
+ output_layer_init_method=output_layer_init_method,
+ hidden_dropout=0.1,
+ attention_dropout=0.1,
+ kv_channels=config.embed,
+ params_dtype=dtype,
+ apply_residual_connection_post_layernorm=False,
+ output_layernorm=False,
+ device="cuda",
)
graphed_block = copy.deepcopy(block)
@@ -1388,7 +1381,6 @@ def test_gpt_fp8_parameters(dtype, bs, model):
)
-
@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
@@ -1451,7 +1443,7 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model):
requires_grad=True,
)
- x_bshd = x_sbhd.transpose(0,1).contiguous()
+ x_bshd = x_sbhd.transpose(0, 1).contiguous()
# To make sure forward is also identical (just in case some module decides
# to act fancy)
@@ -1466,7 +1458,7 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model):
# Check that results match
torch.testing.assert_close(
y_bshd,
- y_sbhd.transpose(0,1).contiguous(),
+ y_sbhd.transpose(0, 1).contiguous(),
)
@@ -1500,19 +1492,16 @@ def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module,
S_max = S + 2
if module == "TransformerLayer":
- model = (
- TransformerLayer(
- hidden_size=D,
- ffn_hidden_size= 4 * D,
- num_attention_heads=H,
- attn_input_format=input_format,
- layer_number=layer_number,
- attention_dropout = 0.0,
- params_dtype=dtype,
- device="cuda",
- )
- .eval()
- )
+ model = TransformerLayer(
+ hidden_size=D,
+ ffn_hidden_size=4 * D,
+ num_attention_heads=H,
+ attn_input_format=input_format,
+ layer_number=layer_number,
+ attention_dropout=0.0,
+ params_dtype=dtype,
+ device="cuda",
+ ).eval()
else:
model = (
MultiheadAttention(
@@ -1520,7 +1509,7 @@ def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module,
num_attention_heads=H,
qkv_format=input_format,
layer_number=layer_number,
- attention_dropout = 0.0,
+ attention_dropout=0.0,
params_dtype=dtype,
)
.cuda()
@@ -1537,39 +1526,38 @@ def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module,
incremental_output = torch.zeros_like(input)
# Generate output for the entire sequence
- full_output = model(
- hidden_states=input,
- rotary_pos_emb=rotary_freqs if use_RoPE else None)
+ full_output = model(hidden_states=input, rotary_pos_emb=rotary_freqs if use_RoPE else None)
# Incrementally generate outputs using KV-cache
for i in range(S):
if input_format == "sbhd":
- incremental_input = input[i].view(1,B,D)
+ incremental_input = input[i].view(1, B, D)
else:
- incremental_input = input[:, i, :].view(B,1,D)
+ incremental_input = input[:, i, :].view(B, 1, D)
line_output = model(
hidden_states=incremental_input,
inference_params=inference_params,
- rotary_pos_emb=rotary_freqs if use_RoPE else None)
+ rotary_pos_emb=rotary_freqs if use_RoPE else None,
+ )
inference_params.sequence_len_offset += 1
if input_format == "sbhd":
- incremental_output[i] = line_output.view(B,D)
+ incremental_output[i] = line_output.view(B, D)
else:
- incremental_output[:, i, :] = line_output.view(B,D)
+ incremental_output[:, i, :] = line_output.view(B, D)
if module == "TransformerLayer":
atol = {
- torch.float32 : 5e-3,
- torch.half : 5e-3,
+ torch.float32: 5e-3,
+ torch.half: 5e-3,
torch.bfloat16: 5e-2,
}
else:
atol = {
- torch.float32 : 1e-3,
- torch.half : 1e-3,
+ torch.float32: 1e-3,
+ torch.half: 1e-3,
torch.bfloat16: 1e-2,
}
diff --git a/tests/pytorch/test_onnx_export.py b/tests/pytorch/test_onnx_export.py
index 1e9bbea507d8d86dec58a8f3e8cd325d53ee4edf..bdc459cdcc979fdb5daf4391283518aaf4a2e3ca 100644
--- a/tests/pytorch/test_onnx_export.py
+++ b/tests/pytorch/test_onnx_export.py
@@ -32,7 +32,13 @@ from typing import Optional, Union, Tuple, List
import transformer_engine.pytorch as te
from transformer_engine.common import recipe
import transformer_engine_torch as tex
-from transformer_engine.pytorch.cpp_extensions import gemm, fp8_gemm, gelu, cast_to_fp8, cast_from_fp8
+from transformer_engine.pytorch.cpp_extensions import (
+ gemm,
+ fp8_gemm,
+ gelu,
+ cast_to_fp8,
+ cast_from_fp8,
+)
from transformer_engine.pytorch.module.base import get_workspace
import transformer_engine.pytorch.cpp_extensions as texcpp
import transformer_engine.pytorch.softmax as softmax_defs
@@ -50,8 +56,10 @@ if SAVE_TEST_IO:
from polygraphy.comparator import RunResults
# The directory where generated ONNX test models are stored.
-NVTE_TEST_ARTIFACTS_DIR = os.environ.get('NVTE_TEST_ARTIFACTS_DIR')
-NVTE_TEST_ARTIFACTS_DIR = NVTE_TEST_ARTIFACTS_DIR or os.path.join(tempfile.gettempdir(), "./gen_onnx_models")
+NVTE_TEST_ARTIFACTS_DIR = os.environ.get("NVTE_TEST_ARTIFACTS_DIR")
+NVTE_TEST_ARTIFACTS_DIR = NVTE_TEST_ARTIFACTS_DIR or os.path.join(
+ tempfile.gettempdir(), "./gen_onnx_models"
+)
# The directory where this file is stored.
@@ -100,23 +108,21 @@ def do_export(
model: torch.nn.Module,
inp: torch.Tensor,
fname: str,
- use_fp8: bool=True,
- opset: int=OPSET,
- input_names: List[str]=None,
- output_names: List[str]=None,
- dynamic_axes: List[str]=None
+ use_fp8: bool = True,
+ opset: int = OPSET,
+ input_names: List[str] = None,
+ output_names: List[str] = None,
+ dynamic_axes: List[str] = None,
):
"""Export to ONNX"""
fp8_recipe = create_fp8_recipe()
input_names = input_names or ["input"]
output_names = output_names or ["output"]
- with torch.inference_mode(), te.fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe), warnings.catch_warnings():
- warnings.filterwarnings(
- action='ignore',
- category=torch.jit.TracerWarning,
- module=r'.*'
- )
+ with torch.inference_mode(), te.fp8_autocast(
+ enabled=use_fp8, fp8_recipe=fp8_recipe
+ ), warnings.catch_warnings():
+ warnings.filterwarnings(action="ignore", category=torch.jit.TracerWarning, module=r".*")
model.cuda().eval()
os.makedirs(NVTE_TEST_ARTIFACTS_DIR, exist_ok=True)
@@ -138,7 +144,8 @@ def do_export(
input_names=input_names,
output_names=output_names,
do_constant_folding=True,
- operator_export_type=torch.onnx.OperatorExportTypes.ONNX_FALLTHROUGH)
+ operator_export_type=torch.onnx.OperatorExportTypes.ONNX_FALLTHROUGH,
+ )
def to_numpy(tensor):
@@ -154,24 +161,30 @@ def set_layer_scale(module: torch.nn.Module, scale: float, num_gemms: int):
NB_SCALES_PER_GEMM = 3 # One scale per: input, weights, and output GEMM tensors.
nb_total_scales = num_gemms * NB_SCALES_PER_GEMM
module.init_fp8_metadata(num_gemms)
- module.fp8_meta["scaling_fwd"].scale = torch.ones(
- nb_total_scales, dtype=torch.float32, device="cuda") / scale
- module.fp8_meta["scaling_fwd"].scale_inv = torch.ones(
- nb_total_scales, dtype=torch.float32, device="cuda") * scale
+ module.fp8_meta["scaling_fwd"].scale = (
+ torch.ones(nb_total_scales, dtype=torch.float32, device="cuda") / scale
+ )
+ module.fp8_meta["scaling_fwd"].scale_inv = (
+ torch.ones(nb_total_scales, dtype=torch.float32, device="cuda") * scale
+ )
def te_infer(model: torch.nn.Module, inps: Union[Tuple[torch.tensor], torch.tensor], is_fp8: bool):
"""Transformer Engine forward propagation."""
fp8_recipe = create_fp8_recipe()
- with torch.inference_mode(), te.fp8_autocast(enabled=is_fp8, fp8_recipe=fp8_recipe), warnings.catch_warnings():
+ with torch.inference_mode(), te.fp8_autocast(
+ enabled=is_fp8, fp8_recipe=fp8_recipe
+ ), warnings.catch_warnings():
te_outputs = model(*inps if isinstance(inps, tuple) else (inps,))
if not isinstance(te_outputs, tuple):
te_outputs = (te_outputs,)
return te_outputs
-def compare_outputs(onnx_outputs, te_outputs, atol, rtol, max_errors_printed, allow_cnt_errors, fname):
- """ Compare ORT and TE outputs."""
+def compare_outputs(
+ onnx_outputs, te_outputs, atol, rtol, max_errors_printed, allow_cnt_errors, fname
+):
+ """Compare ORT and TE outputs."""
assert len(onnx_outputs) == len(te_outputs)
# Compare ORT and PyTorch outputs.
for onnx_output, te_output in zip(onnx_outputs, te_outputs):
@@ -192,11 +205,15 @@ def compare_outputs(onnx_outputs, te_outputs, atol, rtol, max_errors_printed, al
errors = abs_err[mismatches]
for loc in mismatched_ids[:nb_vals]:
ref = te_output[loc]
- print(f"{onnx_output[loc]} -- {te_output[loc]} err={abs_err[loc]} > {atol + rtol * abs(ref)}")
+ print(
+ f"{onnx_output[loc]} -- {te_output[loc]} err={abs_err[loc]} >"
+ f" {atol + rtol * abs(ref)}"
+ )
print(f"Max error: {np.max(errors)}")
if nb_errors > allow_cnt_errors:
raise ValueError(f"Output validation of {fname} failed with {nb_errors} errors")
+
def serialize_inputs_outputs(
fname: str,
inputs: Union[Tuple[torch.Tensor], torch.Tensor],
@@ -214,10 +231,10 @@ def serialize_inputs_outputs(
inputs = inputs if isinstance(inputs, list) or isinstance(inputs, tuple) else (inputs,)
named_inputs = zip(input_names, inputs)
input_data = [{k: v.cpu() for k, v in named_inputs if v is not None}]
- json_fname = fname[:-len(".onnx")] + "_inputs.json"
+ json_fname = fname[: -len(".onnx")] + "_inputs.json"
save_json(input_data, json_fname, description="custom input data")
- json_fname = fname[:-len(".onnx")] + "_output.json"
+ json_fname = fname[: -len(".onnx")] + "_output.json"
named_outputs = zip(output_names, te_outputs)
output_data = {k: v.detach().cpu() for k, v in named_outputs if v is not None}
custom_outputs = RunResults()
@@ -229,14 +246,14 @@ def validate_result(
fname: str,
inps: Union[Tuple[torch.Tensor], torch.Tensor],
model: torch.nn.Module,
- atol: float=1.e-8, # np.isclose default atol
- rtol: float=1.e-5, # np.isclose default rtol
- max_errors_printed: int=10,
- is_fp8: bool=False,
- allow_cnt_errors: int=0,
- input_names: List[str]=None,
- output_names: List[str]=None,
- te_outputs: List[torch.Tensor]=None,
+ atol: float = 1.0e-8, # np.isclose default atol
+ rtol: float = 1.0e-5, # np.isclose default rtol
+ max_errors_printed: int = 10,
+ is_fp8: bool = False,
+ allow_cnt_errors: int = 0,
+ input_names: List[str] = None,
+ output_names: List[str] = None,
+ te_outputs: List[torch.Tensor] = None,
):
"""Compare the outputs of a Transformer Engine (TE) module vs the outputs of its ONNX
representation using ONNX Runtime (ORT) and ensure they are close.
@@ -262,7 +279,7 @@ def validate_result(
print("registered custom FP8 Q/DQ ops!")
"""Create an ONNX Runtime session for validation."""
- kwargs = {"providers": ['CUDAExecutionProvider', 'CPUExecutionProvider']}
+ kwargs = {"providers": ["CUDAExecutionProvider", "CPUExecutionProvider"]}
if is_fp8:
sess_options = ort.SessionOptions()
load_custom_ops(sess_options)
@@ -288,10 +305,12 @@ def validate_result(
ort_s = create_ort_session(fname, is_fp8)
input_feed = create_ort_input_dict(ort_s, inps)
onnx_outputs = ort_s.run(None, input_feed=input_feed)
- compare_outputs(onnx_outputs, te_outputs, atol, rtol, max_errors_printed, allow_cnt_errors, fname)
+ compare_outputs(
+ onnx_outputs, te_outputs, atol, rtol, max_errors_printed, allow_cnt_errors, fname
+ )
-def create_meta(scale_factor: float, size: int=1):
+def create_meta(scale_factor: float, size: int = 1):
meta = tex.FP8TensorMeta()
meta.amax_history = torch.zeros(1, size, dtype=torch.float32, device="cuda")
meta.scale_inv = torch.ones(size, dtype=torch.float32, device="cuda") / scale_factor
@@ -324,7 +343,9 @@ def get_attn_mask_str(use_mask, attn_mask_type):
return "_mask" if use_mask else "_no-mask"
attn_mask_str = "_arbitrary-no-mask"
attn_mask_str = "_causal-mask" if attn_mask_type == "causal" else attn_mask_str
- attn_mask_str = "_arbitrary-mask" if use_mask and attn_mask_type == "arbitrary" else attn_mask_str
+ attn_mask_str = (
+ "_arbitrary-mask" if use_mask and attn_mask_type == "arbitrary" else attn_mask_str
+ )
return attn_mask_str
@@ -351,17 +372,11 @@ class FP8GemmModule(nn.Module):
self.outp_type = precision
def forward(self, inp, weight):
- inp_fp8 = cast_to_fp8(
- inp,
- self.meta_inp,
- self.fp8_tensor_inp,
- self.inp_type)
+ inp_fp8 = cast_to_fp8(inp, self.meta_inp, self.fp8_tensor_inp, self.inp_type)
weight_fp8 = cast_to_fp8(
- weight,
- self.meta_weight,
- self.fp8_tensor_weight,
- self.weights_type)
+ weight, self.meta_weight, self.fp8_tensor_weight, self.weights_type
+ )
ret, _ = fp8_gemm(
weight_fp8,
@@ -376,9 +391,11 @@ class FP8GemmModule(nn.Module):
get_workspace(),
bias=self.bias,
use_bias=self.use_bias,
- use_split_accumulator=False)
+ use_split_accumulator=False,
+ )
return ret
+
"""
Test cases begin here.
"""
@@ -387,13 +404,17 @@ Tests cases begin here.
@skip_FP8
@pytest.mark.parametrize("scale_factor", [1, 224])
@pytest.mark.parametrize(
- "precision, atol", [
- [torch.float32, 1e-7],
- [torch.float16, 1e-7],
- [torch.bfloat16, 5e-3],
- ["fake-torch.bfloat16", 5e-3],
-])
-def test_export_cast_ops(seed_default_rng, scale_factor: float, atol: float, precision: torch.dtype):
+ "precision, atol",
+ [
+ [torch.float32, 1e-7],
+ [torch.float16, 1e-7],
+ [torch.bfloat16, 5e-3],
+ ["fake-torch.bfloat16", 5e-3],
+ ],
+)
+def test_export_cast_ops(
+ seed_default_rng, scale_factor: float, atol: float, precision: torch.dtype
+):
fake_bf16_io = precision == "fake-torch.bfloat16"
# reset precision to torch.bfloat16 after capturing fake BF16 mode
precision = torch.bfloat16 if precision == "fake-torch.bfloat16" else precision
@@ -408,18 +429,9 @@ def test_export_cast_ops(seed_default_rng, scale_factor: float, atol: float, pre
self.fake_bf16_io = fake_bf16_io
def forward(self, inp):
- ret = cast_to_fp8(
- inp,
- self.meta,
- self.fp8_tensor,
- self.fp8_type)
+ ret = cast_to_fp8(inp, self.meta, self.fp8_tensor, self.fp8_type)
- ret = cast_from_fp8(
- ret,
- self.meta,
- self.fp8_tensor,
- self.fp8_type,
- self.highprec_type)
+ ret = cast_from_fp8(ret, self.meta, self.fp8_tensor, self.fp8_type, self.highprec_type)
if self.fake_bf16_io:
ret = ret.type(torch.float32)
return ret
@@ -427,8 +439,9 @@ def test_export_cast_ops(seed_default_rng, scale_factor: float, atol: float, pre
# Set dimensions (these are arbitrary).
in_features = 64
hidden_size = 256
- inp = torch.randn(hidden_size, in_features, device="cuda",
- dtype=torch.float if fake_bf16_io else precision)
+ inp = torch.randn(
+ hidden_size, in_features, device="cuda", dtype=torch.float if fake_bf16_io else precision
+ )
high_prec_str = dtype2str(precision, fake_bf16_io=fake_bf16_io)
fname = f"te.cast_fp8_{scale_factor}{high_prec_str}.onnx"
model = TestFP8_QDQ(fake_bf16_io)
@@ -439,15 +452,18 @@ def test_export_cast_ops(seed_default_rng, scale_factor: float, atol: float, pre
if fake_bf16_io or precision != torch.bfloat16:
validate_result(fname, inp, model, atol=atol, is_fp8=True, te_outputs=te_outputs)
+
@skip_FP8
@pytest.mark.parametrize("scale_factor", [448])
@pytest.mark.parametrize(
- "precision, atol", [
- [torch.float32, 1e-5],
- [torch.float16, 1e-5],
- [torch.bfloat16, 5e-3],
- ["fake-torch.bfloat16", 5e-3]
-])
+ "precision, atol",
+ [
+ [torch.float32, 1e-5],
+ [torch.float16, 1e-5],
+ [torch.bfloat16, 5e-3],
+ ["fake-torch.bfloat16", 5e-3],
+ ],
+)
def test_export_gelu_fp8(scale_factor: float, precision: torch.dtype, atol: float):
fake_bf16_io = precision == "fake-torch.bfloat16"
# reset precision to torch.bfloat16 after capturing fake BF16 mode
@@ -463,17 +479,8 @@ def test_export_gelu_fp8(scale_factor: float, precision: torch.dtype, atol: floa
self.fake_bf16_io = fake_bf16_io
def forward(self, inp):
- ret = gelu(
- inp,
- self.meta,
- self.fp8_tensor,
- self.fp8_type)
- ret = cast_from_fp8(
- ret,
- self.meta,
- self.fp8_tensor,
- self.fp8_type,
- self.highprec_type)
+ ret = gelu(inp, self.meta, self.fp8_tensor, self.fp8_type)
+ ret = cast_from_fp8(ret, self.meta, self.fp8_tensor, self.fp8_type, self.highprec_type)
if self.fake_bf16_io:
ret = ret.type(torch.float32)
return ret
@@ -481,8 +488,9 @@ def test_export_gelu_fp8(scale_factor: float, precision: torch.dtype, atol: floa
# Set dimensions (these are arbitrary).
in_features = 64
hidden_size = 256
- inp = torch.randn(hidden_size, in_features, device="cuda",
- dtype=torch.float if fake_bf16_io else precision)
+ inp = torch.randn(
+ hidden_size, in_features, device="cuda", dtype=torch.float if fake_bf16_io else precision
+ )
high_prec_str = dtype2str(precision, fake_bf16_io=fake_bf16_io)
fname = f"te.gelu_fp8_{scale_factor}{high_prec_str}.onnx"
model = TestFP8_Gelu(fake_bf16_io)
@@ -490,39 +498,55 @@ def test_export_gelu_fp8(scale_factor: float, precision: torch.dtype, atol: floa
te_outputs = te_infer(model, inp, is_fp8=True)
serialize_inputs_outputs(fname, inp, te_outputs)
if fake_bf16_io or precision != torch.bfloat16:
- validate_result(fname, inp, model, rtol=0, atol=atol, is_fp8=True, allow_cnt_errors=2, te_outputs=te_outputs)
+ validate_result(
+ fname,
+ inp,
+ model,
+ rtol=0,
+ atol=atol,
+ is_fp8=True,
+ allow_cnt_errors=2,
+ te_outputs=te_outputs,
+ )
-@pytest.mark.parametrize("scale_factors",
- [(224, 224,),
-])
@pytest.mark.parametrize(
- "precision, use_fp8, use_bias, use_gelu", [
- (torch.float32, False, False, False),
- (torch.float16, False, False, False),
- (torch.bfloat16, False, False, False),
- (torch.float32, False, True, False),
- (torch.float16, False, True, False),
- (torch.bfloat16, False, True, False),
- (torch.float32, False, True, True),
- (torch.float16, False, True, True),
- (torch.bfloat16, False, True, True),
-
- # For FP8 GEMM GeLU is not used.
- (torch.float32, True, False, False),
- (torch.float16, True, False, False),
- (torch.bfloat16, True, False, False),
- # When enabling bias we must use float16 or bfloat16 (because of kernel limitations)
- (torch.float16, True, True, False),
- (torch.bfloat16, True, True, False),
-])
+ "scale_factors",
+ [
+ (
+ 224,
+ 224,
+ ),
+ ],
+)
+@pytest.mark.parametrize(
+ "precision, use_fp8, use_bias, use_gelu",
+ [
+ (torch.float32, False, False, False),
+ (torch.float16, False, False, False),
+ (torch.bfloat16, False, False, False),
+ (torch.float32, False, True, False),
+ (torch.float16, False, True, False),
+ (torch.bfloat16, False, True, False),
+ (torch.float32, False, True, True),
+ (torch.float16, False, True, True),
+ (torch.bfloat16, False, True, True),
+ # For FP8 GEMM, GeLU is not used.
+ (torch.float32, True, False, False),
+ (torch.float16, True, False, False),
+ (torch.bfloat16, True, False, False),
+ # When enabling bias we must use float16 or bfloat16 (because of kernel limitations)
+ (torch.float16, True, True, False),
+ (torch.bfloat16, True, True, False),
+ ],
+)
def test_export_gemm(
seed_default_rng,
- precision, # Precision of inputs, weights, output and bias
+ precision, # Precision of inputs, weights, output and bias
use_fp8,
use_bias,
use_gelu,
- scale_factors
+ scale_factors,
):
# Skip FP8 tests on non-hopper devices
if use_fp8 and not fp8_available:
@@ -548,21 +572,20 @@ def test_export_gemm(
inp,
outp_type,
get_workspace(),
-
# test bias
bias=self.bias,
use_bias=self.use_bias,
-
# test gelu
gelu=self.gelu,
gelu_input=self.gelu_input,
- grad=False, # only True for backward pass
+ grad=False, # only True for backward pass
accumulate=False,
)
return ret
# If gelu is applied then bias must be added, as defined by TE kernel.
- if use_gelu: assert use_bias
+ if use_gelu:
+ assert use_bias
# Set dimensions (these are arbitrary).
out_features = 128
hidden_size = 256
@@ -574,45 +597,64 @@ def test_export_gemm(
gelu_str = "_gelu" if use_gelu else ""
high_prec_str = dtype2str(precision)
fname = f"te.gemm{fp8_str}{bias_str}{gelu_str}{high_prec_str}.onnx"
- input_names = ['input', 'weight']
+ input_names = ["input", "weight"]
if use_fp8:
- model = FP8GemmModule(precision, use_bias, use_gelu, scale_factors, hidden_size, out_features)
+ model = FP8GemmModule(
+ precision, use_bias, use_gelu, scale_factors, hidden_size, out_features
+ )
do_export(model, (inp, weight), fname, use_fp8, input_names=input_names)
te_outputs = te_infer(model, (inp, weight), is_fp8=use_fp8)
serialize_inputs_outputs(fname, (inp, weight), te_outputs, input_names=input_names)
if precision != torch.bfloat16:
- validate_result(fname, (inp, weight), model, rtol=1e-2, atol=2e-2,
- is_fp8=True, input_names=input_names, te_outputs=te_outputs)
+ validate_result(
+ fname,
+ (inp, weight),
+ model,
+ rtol=1e-2,
+ atol=2e-2,
+ is_fp8=True,
+ input_names=input_names,
+ te_outputs=te_outputs,
+ )
else:
model = Test_GEMM(precision, use_bias, use_gelu)
do_export(model, (inp, weight), fname, use_fp8, input_names=input_names)
te_outputs = te_infer(model, (inp, weight), is_fp8=use_fp8)
serialize_inputs_outputs(fname, (inp, weight), te_outputs, input_names=input_names)
if precision != torch.bfloat16:
- validate_result(fname, (inp, weight), model, rtol=1e-2, atol=2e-2,
- input_names=input_names, te_outputs=te_outputs)
+ validate_result(
+ fname,
+ (inp, weight),
+ model,
+ rtol=1e-2,
+ atol=2e-2,
+ input_names=input_names,
+ te_outputs=te_outputs,
+ )
@pytest.mark.parametrize("scale_factor", [448, 112])
@pytest.mark.parametrize("zero_centered_gamma", [False, True])
@pytest.mark.parametrize(
- "use_fp8, precision, atol", [
- [False, torch.float32, 1e-7],
- [False, torch.float16, 1e-7],
- [False, torch.bfloat16, 1e-7],
- [False, "fake-torch.bfloat16", 1e-7],
- [True, torch.float32, 1e-7],
- [True, torch.float16, 1e-7],
- [True, torch.bfloat16, 1e-2],
- [True, "fake-torch.bfloat16", 1e-2]
-])
+ "use_fp8, precision, atol",
+ [
+ [False, torch.float32, 1e-7],
+ [False, torch.float16, 1e-7],
+ [False, torch.bfloat16, 1e-7],
+ [False, "fake-torch.bfloat16", 1e-7],
+ [True, torch.float32, 1e-7],
+ [True, torch.float16, 1e-7],
+ [True, torch.bfloat16, 1e-2],
+ [True, "fake-torch.bfloat16", 1e-2],
+ ],
+)
def test_export_layernorm(
seed_default_rng,
use_fp8: bool,
scale_factor: float,
precision: torch.dtype,
zero_centered_gamma: bool,
- atol: float
+ atol: float,
):
fake_bf16_io = precision == "fake-torch.bfloat16"
# reset precision to torch.bfloat16 after capturing fake BF16 mode
@@ -628,10 +670,15 @@ def test_export_layernorm(
class Test_Layernorm(nn.Module):
def __init__(self) -> None:
super().__init__()
- eps = 1e-6 # An arbitrary small value
+ eps = 1e-6 # An arbitrary small value
dtype = torch.float if fake_bf16_io else precision
- self.ln = te.LayerNorm(inp_shape[1], eps, params_dtype=dtype,
- zero_centered_gamma=zero_centered_gamma).eval().cuda()
+ self.ln = (
+ te.LayerNorm(
+ inp_shape[1], eps, params_dtype=dtype, zero_centered_gamma=zero_centered_gamma
+ )
+ .eval()
+ .cuda()
+ )
def forward(self, inp):
ret = self.ln(inp)
@@ -641,11 +688,13 @@ def test_export_layernorm(
def __init__(self) -> None:
super().__init__()
normalized_shape = torch.Size(inp.shape[1:])
- self.weight = torch.randn(*normalized_shape, device="cuda",
- dtype=torch.float32 if fake_bf16_io else precision)
- self.bias = torch.zeros(*normalized_shape, device="cuda",
- dtype=torch.float32 if fake_bf16_io else precision)
- self.eps = 1e-6 # An arbitrary small value
+ self.weight = torch.randn(
+ *normalized_shape, device="cuda", dtype=torch.float32 if fake_bf16_io else precision
+ )
+ self.bias = torch.zeros(
+ *normalized_shape, device="cuda", dtype=torch.float32 if fake_bf16_io else precision
+ )
+ self.eps = 1e-6 # An arbitrary small value
self.fp8_tensor = tex.FP8FwdTensors.GEMM1_INPUT
self.meta = create_meta(scale_factor)
@@ -661,14 +710,12 @@ def test_export_layernorm(
self.fp8_tensor,
self.fp8_type,
0,
- zero_centered_gamma)
+ zero_centered_gamma,
+ )
ret = cast_from_fp8(
- ret,
- self.meta,
- self.fp8_tensor,
- self.fp8_type,
- as_te_type(precision))
+ ret, self.meta, self.fp8_tensor, self.fp8_type, as_te_type(precision)
+ )
if fake_bf16_io:
ret = ret.type(torch.float32)
return ret
@@ -683,28 +730,32 @@ def test_export_layernorm(
serialize_inputs_outputs(fname, inp, te_outputs)
if fake_bf16_io or precision != torch.bfloat16:
validate_result(
- fname, inp, model, atol=atol, is_fp8=use_fp8, allow_cnt_errors=3, te_outputs=te_outputs)
+ fname, inp, model, atol=atol, is_fp8=use_fp8, allow_cnt_errors=3, te_outputs=te_outputs
+ )
+
@pytest.mark.parametrize("scale_factor", [448, 112])
@pytest.mark.parametrize("zero_centered_gamma", [False, True])
@pytest.mark.parametrize(
- "use_fp8, precision, atol", [
- [False, torch.float32, 1e-7],
- [False, torch.float16, 1e-7],
- [False, torch.bfloat16, 1e-7],
- [False, "fake-torch.bfloat16", 1e-7],
- [True, torch.float32, 1e-7],
- [True, torch.float16, 1e-7],
- [True, torch.bfloat16, 1e-2],
- [True, "fake-torch.bfloat16", 1e-2]
-])
+ "use_fp8, precision, atol",
+ [
+ [False, torch.float32, 1e-7],
+ [False, torch.float16, 1e-7],
+ [False, torch.bfloat16, 1e-7],
+ [False, "fake-torch.bfloat16", 1e-7],
+ [True, torch.float32, 1e-7],
+ [True, torch.float16, 1e-7],
+ [True, torch.bfloat16, 1e-2],
+ [True, "fake-torch.bfloat16", 1e-2],
+ ],
+)
def test_export_rmsnorm(
seed_default_rng,
use_fp8: bool,
scale_factor: float,
precision: torch.dtype,
zero_centered_gamma: bool,
- atol: float
+ atol: float,
):
fake_bf16_io = precision == "fake-torch.bfloat16"
# reset precision to torch.bfloat16 after capturing fake BF16 mode
@@ -720,10 +771,15 @@ def test_export_rmsnorm(
class Test_RMSnorm(nn.Module):
def __init__(self) -> None:
super().__init__()
- eps = 1e-6 # An arbitrary small value
+ eps = 1e-6 # An arbitrary small value
dtype = torch.float if fake_bf16_io else precision
- self.ln = te.RMSNorm(inp_shape[1], eps, params_dtype=dtype,
- zero_centered_gamma=zero_centered_gamma).eval().cuda()
+ self.ln = (
+ te.RMSNorm(
+ inp_shape[1], eps, params_dtype=dtype, zero_centered_gamma=zero_centered_gamma
+ )
+ .eval()
+ .cuda()
+ )
def forward(self, inp):
ret = self.ln(inp)
@@ -733,9 +789,10 @@ def test_export_rmsnorm(
def __init__(self) -> None:
super().__init__()
normalized_shape = torch.Size(inp.shape[1:])
- self.weight = torch.randn(*normalized_shape, device="cuda",
- dtype=torch.float32 if fake_bf16_io else precision)
- self.eps = 1e-6 # An arbitrary small value
+ self.weight = torch.randn(
+ *normalized_shape, device="cuda", dtype=torch.float32 if fake_bf16_io else precision
+ )
+ self.eps = 1e-6 # An arbitrary small value
self.fp8_tensor = tex.FP8FwdTensors.GEMM1_INPUT
self.meta = create_meta(scale_factor)
@@ -750,14 +807,12 @@ def test_export_rmsnorm(
self.fp8_tensor,
self.fp8_type,
0,
- zero_centered_gamma)
+ zero_centered_gamma,
+ )
ret = cast_from_fp8(
- ret,
- self.meta,
- self.fp8_tensor,
- self.fp8_type,
- as_te_type(precision))
+ ret, self.meta, self.fp8_tensor, self.fp8_type, as_te_type(precision)
+ )
if fake_bf16_io:
ret = ret.type(torch.float32)
return ret
@@ -772,7 +827,8 @@ def test_export_rmsnorm(
serialize_inputs_outputs(fname, inp, te_outputs)
if fake_bf16_io or precision != torch.bfloat16:
validate_result(
- fname, inp, model, atol=atol, is_fp8=use_fp8, allow_cnt_errors=3, te_outputs=te_outputs)
+ fname, inp, model, atol=atol, is_fp8=use_fp8, allow_cnt_errors=3, te_outputs=te_outputs
+ )
@pytest.mark.parametrize("scale_factor", [1])
@@ -780,23 +836,25 @@ def test_export_rmsnorm(
# Returning the bias is a TE fusion optimization we don't care about.
@pytest.mark.parametrize("return_bias", [False])
@pytest.mark.parametrize(
- "precision, use_bias",[
- (torch.float32, False),
- (torch.float32, True),
- (torch.float16, False),
- (torch.float16, True),
- # Todo: cannot configure BF16 when bias is disabled (ORT issue?)
- (torch.bfloat16, False),
- # Todo: cannot configure BF16 when bias is enabled (ORT issue?)
- (torch.bfloat16, True),
-])
+ "precision, use_bias",
+ [
+ (torch.float32, False),
+ (torch.float32, True),
+ (torch.float16, False),
+ (torch.float16, True),
+ # Todo: cannot configure BF16 when bias is disabled (ORT issue?)
+ (torch.bfloat16, False),
+ # Todo: cannot configure BF16 when bias is enabled (ORT issue?)
+ (torch.bfloat16, True),
+ ],
+)
def test_export_linear(
seed_default_rng,
scale_factor: float,
use_fp8: bool,
use_bias: bool,
return_bias: bool,
- precision: torch.dtype
+ precision: torch.dtype,
):
# Skip FP8 tests on non-hopper devices
if use_fp8 and not fp8_available:
@@ -808,20 +866,14 @@ def test_export_linear(
hidden_size = 256
class Test_Linear(nn.Module):
- def __init__(self,
- in_features,
- out_features,
- use_bias,
- return_bias,
- precision
- ):
+ def __init__(self, in_features, out_features, use_bias, return_bias, precision):
super().__init__()
self.linear = te.Linear(
in_features,
out_features,
bias=use_bias,
return_bias=return_bias,
- params_dtype=precision
+ params_dtype=precision,
)
def forward(self, inp):
@@ -834,20 +886,16 @@ def test_export_linear(
high_prec_str = dtype2str(precision)
fname = f"te.linear{fp8_str}{bias_str}{high_prec_str}.onnx"
with te.fp8_autocast(enabled=use_fp8):
- model = Test_Linear(
- in_features,
- out_features,
- use_bias,
- return_bias,
- precision
- ).to(device='cuda')
+ model = Test_Linear(in_features, out_features, use_bias, return_bias, precision).to(
+ device="cuda"
+ )
if use_fp8:
set_layer_scale(model.linear, scale_factor, num_gemms=1)
do_export(model, inp, fname, use_fp8)
te_outputs = te_infer(model, inp, is_fp8=use_fp8)
serialize_inputs_outputs(fname, inp, te_outputs)
- if precision in (torch.bfloat16, ):
+ if precision in (torch.bfloat16,):
return
if not use_fp8:
validate_result(fname, inp, model, atol=1e-3, te_outputs=te_outputs)
@@ -861,14 +909,16 @@ def test_export_linear(
@pytest.mark.parametrize("return_bias", [False])
@pytest.mark.parametrize("return_layernorm_output", [False])
@pytest.mark.parametrize(
- "precision, use_bias",[
- (torch.float32, False),
- (torch.float32, True),
- (torch.float16, True),
- (torch.float16, False),
- (torch.bfloat16, True),
- (torch.bfloat16, False),
-])
+ "precision, use_bias",
+ [
+ (torch.float32, False),
+ (torch.float32, True),
+ (torch.float16, True),
+ (torch.float16, False),
+ (torch.bfloat16, True),
+ (torch.bfloat16, False),
+ ],
+)
@pytest.mark.parametrize("zero_centered_gamma", [False, True])
@pytest.mark.parametrize("normalization", all_normalizations)
def test_export_layernorm_linear(
@@ -907,13 +957,13 @@ def test_export_layernorm_linear(
params_dtype=precision,
zero_centered_gamma=zero_centered_gamma,
normalization=normalization,
- ).to(device='cuda')
+ ).to(device="cuda")
if use_fp8:
set_layer_scale(model, scale_factor, num_gemms=1)
do_export(model, inp, fname, use_fp8)
te_outputs = te_infer(model, inp, is_fp8=use_fp8)
serialize_inputs_outputs(fname, inp, te_outputs)
- if precision in (torch.bfloat16, ):
+ if precision in (torch.bfloat16,):
return
if not use_fp8:
validate_result(fname, inp, model, atol=1e-3, te_outputs=te_outputs)
@@ -927,14 +977,16 @@ def test_export_layernorm_linear(
@pytest.mark.parametrize("return_bias", [False])
@pytest.mark.parametrize("return_layernorm_output", [False])
@pytest.mark.parametrize(
- "precision, use_bias",[
- (torch.float32, False),
- (torch.float32, True),
- (torch.float16, True),
- (torch.float16, False),
- (torch.bfloat16, True),
- (torch.bfloat16, False),
-])
+ "precision, use_bias",
+ [
+ (torch.float32, False),
+ (torch.float32, True),
+ (torch.float16, True),
+ (torch.float16, False),
+ (torch.bfloat16, True),
+ (torch.bfloat16, False),
+ ],
+)
@pytest.mark.parametrize("zero_centered_gamma", [False, True])
@pytest.mark.parametrize("activation", supported_activations)
@pytest.mark.parametrize("normalization", all_normalizations)
@@ -954,7 +1006,6 @@ def test_export_layernorm_mlp(
if use_fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8)
-
# Set dimensions (these are arbitrary).
in_features = 64
out_features = 256
@@ -977,30 +1028,32 @@ def test_export_layernorm_mlp(
zero_centered_gamma=zero_centered_gamma,
activation=activation,
normalization=normalization,
- ).to(device='cuda')
+ ).to(device="cuda")
if use_fp8:
set_layer_scale(model, scale_factor, num_gemms=2)
do_export(model, inp, fname, use_fp8)
te_outputs = te_infer(model, inp, is_fp8=use_fp8)
serialize_inputs_outputs(fname, inp, te_outputs)
- if precision in (torch.bfloat16, ):
+ if precision in (torch.bfloat16,):
return
- atol = 1e-6 if use_fp8 else (5e-1 if activation=="swiglu" else 1e-3)
+ atol = 1e-6 if use_fp8 else (5e-1 if activation == "swiglu" else 1e-3)
validate_result(fname, inp, model, atol=atol, is_fp8=use_fp8, te_outputs=te_outputs)
@skip_FP8
@pytest.mark.parametrize(
- "precision, use_mask, attn_mask_type", [
- (torch.float32, True, "arbitrary"), # calls forward_torch_softmax (apply user mask)
- (torch.float32, False, "no_mask"), # calls forward_torch_softmax (apply no mask)
- (torch.float16, False, "causal"), # calls forward_torch_softmax (apply dynamic onnx mask)
- (torch.float16, True, "arbitrary"), # calls forward_torch_softmax (apply user mask)
- (torch.float16, False, "no_mask"), # calls forward_torch_softmax (apply no mask)
- (torch.bfloat16, False, "causal"), # calls forward_torch_softmax (apply dynamic onnx mask)
- (torch.bfloat16, True, "arbitrary"), # calls forward_torch_softmax (apply user mask)
- (torch.bfloat16, False, "no_mask"), # calls forward_torch_softmax (apply no mask)
-])
+ "precision, use_mask, attn_mask_type",
+ [
+ (torch.float32, True, "arbitrary"), # calls forward_torch_softmax (apply user mask)
+ (torch.float32, False, "no_mask"), # calls forward_torch_softmax (apply no mask)
+ (torch.float16, False, "causal"), # calls forward_torch_softmax (apply dynamic onnx mask)
+ (torch.float16, True, "arbitrary"), # calls forward_torch_softmax (apply user mask)
+ (torch.float16, False, "no_mask"), # calls forward_torch_softmax (apply no mask)
+ (torch.bfloat16, False, "causal"), # calls forward_torch_softmax (apply dynamic onnx mask)
+ (torch.bfloat16, True, "arbitrary"), # calls forward_torch_softmax (apply user mask)
+ (torch.bfloat16, False, "no_mask"), # calls forward_torch_softmax (apply no mask)
+ ],
+)
def test_export_core_attention(
seed_default_rng,
set_max_seq_len,
@@ -1034,40 +1087,42 @@ def test_export_core_attention(
attention_dropout=0.5,
qkv_format=qkv_format,
attn_mask_type=attn_mask_type,
- ).to(device='cuda')
- do_export(model,
- inp,
- fname,
- input_names=input_names,
- use_fp8=True)
+ ).to(device="cuda")
+ do_export(model, inp, fname, input_names=input_names, use_fp8=True)
te_outputs = te_infer(model, inp, is_fp8=True)
serialize_inputs_outputs(fname, inp, te_outputs, input_names=input_names)
- if precision in (torch.bfloat16, ):
+ if precision in (torch.bfloat16,):
return
- validate_result(fname, inp, model, is_fp8=True, atol=1e-2, input_names=input_names, te_outputs=te_outputs)
+ validate_result(
+ fname, inp, model, is_fp8=True, atol=1e-2, input_names=input_names, te_outputs=te_outputs
+ )
test_configs_multihead_attention = [
- #"use_mask, attn_mask_type"
- (False, "no_mask"), # calls ScaledSoftmax
- (True, "arbitrary"), # calls ScaledMaskedSoftmax
+ # "use_mask, attn_mask_type"
+ (False, "no_mask"), # calls ScaledSoftmax
+ (True, "arbitrary"), # calls ScaledMaskedSoftmax
]
test_configs_attention_type = [
- #"input_layernorm, attention_type, fuse_qkv_params"
- (True, "self", True),
- (False, "self", True),
- (True, "self", False),
- (False, "self", False),
- (True, "cross", True),
- (False, "cross", True),
- (True, "cross", False),
- (False, "cross", False),
+ # "input_layernorm, attention_type, fuse_qkv_params"
+ (True, "self", True),
+ (False, "self", True),
+ (True, "self", False),
+ (False, "self", False),
+ (True, "cross", True),
+ (False, "cross", True),
+ (True, "cross", False),
+ (False, "cross", False),
]
+
+
@pytest.mark.parametrize("use_fp8", [False, True])
@pytest.mark.parametrize("use_mask, attn_mask_type", test_configs_multihead_attention)
@pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("return_layernorm_output", [False])
-@pytest.mark.parametrize("input_layernorm, attention_type, fuse_qkv_params", test_configs_attention_type)
+@pytest.mark.parametrize(
+ "input_layernorm, attention_type, fuse_qkv_params", test_configs_attention_type
+)
def test_export_multihead_attention(
seed_default_rng,
set_max_seq_len,
@@ -1078,7 +1133,7 @@ def test_export_multihead_attention(
return_layernorm_output: bool,
input_layernorm: bool,
attention_type: str,
- fuse_qkv_params: bool
+ fuse_qkv_params: bool,
):
# Skip FP8 tests on non-hopper devices
if use_fp8 and not fp8_available:
@@ -1102,17 +1157,23 @@ def test_export_multihead_attention(
output_layer_init_method,
)
- hidden_states_context = torch.randn(sequence_length, batch_size, hidden_size, dtype=precision, device="cuda")
+ hidden_states_context = torch.randn(
+ sequence_length, batch_size, hidden_size, dtype=precision, device="cuda"
+ )
attention_mask = None
if use_mask and attn_mask_type != "causal":
# Generate a random mask with 50% probability for 0 or 1.
- probs = 0.5 * torch.ones(batch_size, 1, sequence_length, sequence_length, device="cuda", dtype=precision)
+ probs = 0.5 * torch.ones(
+ batch_size, 1, sequence_length, sequence_length, device="cuda", dtype=precision
+ )
attention_mask = torch.bernoulli(probs).to("cuda", dtype=torch.bool)
encoder_output = None
if attention_type == "cross":
- encoder_output = torch.randn(sequence_length, batch_size, hidden_size, dtype=precision, device="cuda")
+ encoder_output = torch.randn(
+ sequence_length, batch_size, hidden_size, dtype=precision, device="cuda"
+ )
fp8_str = "_fp8" if use_fp8 else ""
dtype_str = dtype2str(precision)
@@ -1131,49 +1192,98 @@ def test_export_multihead_attention(
attention_type=attention_type,
fuse_qkv_params=fuse_qkv_params,
return_bias=True,
- ).to(device='cuda')
+ ).to(device="cuda")
inp_context = (hidden_states_context, attention_mask, encoder_output)
input_names = ["hidden_states", "attention_mask", "encoder_output"]
- output_names=["attention_output", "attention_bias"]
- do_export(model, inp_context, fname, use_fp8, input_names=input_names, output_names=output_names,
- dynamic_axes={"hidden_states": {0: "seq", 1:"bs"},
- "attention_output": {0: "seq", 1:"bs"}})
+ output_names = ["attention_output", "attention_bias"]
+ do_export(
+ model,
+ inp_context,
+ fname,
+ use_fp8,
+ input_names=input_names,
+ output_names=output_names,
+ dynamic_axes={
+ "hidden_states": {0: "seq", 1: "bs"},
+ "attention_output": {0: "seq", 1: "bs"},
+ },
+ )
te_outputs = te_infer(model, inp_context, is_fp8=use_fp8)
- serialize_inputs_outputs(fname, inp_context, te_outputs, input_names=input_names, output_names=output_names)
- if precision in (torch.bfloat16, ):
+ serialize_inputs_outputs(
+ fname, inp_context, te_outputs, input_names=input_names, output_names=output_names
+ )
+ if precision in (torch.bfloat16,):
return
if not use_fp8:
- validate_result(fname, inp_context, model, atol=1e-3, input_names=input_names,
- output_names=output_names, te_outputs=te_outputs)
+ validate_result(
+ fname,
+ inp_context,
+ model,
+ atol=1e-3,
+ input_names=input_names,
+ output_names=output_names,
+ te_outputs=te_outputs,
+ )
else:
- validate_result(fname, inp_context, model, atol=1e-2, is_fp8=use_fp8,
- input_names=input_names, output_names=output_names, allow_cnt_errors=3,
- te_outputs=te_outputs)
+ validate_result(
+ fname,
+ inp_context,
+ model,
+ atol=1e-2,
+ is_fp8=use_fp8,
+ input_names=input_names,
+ output_names=output_names,
+ allow_cnt_errors=3,
+ te_outputs=te_outputs,
+ )
    # In the GPT generative phase (inference), the input sequence is shorter than the maximum
    # allowed sequence length, and we want to test this condition.
    # Pretend that we're in the generative phase when it makes sense (causal mask and self-attention).
- is_generative_phase = (attn_mask_type == "causal" and attention_type == "self")
+ is_generative_phase = attn_mask_type == "causal" and attention_type == "self"
if is_generative_phase:
seq_len_offset = 8
- hidden_states_generative = torch.randn(sequence_length-seq_len_offset, batch_size, hidden_size, dtype=precision, device="cuda")
+ hidden_states_generative = torch.randn(
+ sequence_length - seq_len_offset,
+ batch_size,
+ hidden_size,
+ dtype=precision,
+ device="cuda",
+ )
inp_generative = (hidden_states_generative, attention_mask, encoder_output)
if not use_fp8:
- validate_result(fname, inp_generative, model, atol=1e-3, input_names=input_names, output_names=output_names)
+ validate_result(
+ fname,
+ inp_generative,
+ model,
+ atol=1e-3,
+ input_names=input_names,
+ output_names=output_names,
+ )
else:
- validate_result(fname, inp_generative, model, atol=1e-2, is_fp8=use_fp8,
- input_names=input_names, output_names=output_names, allow_cnt_errors=3)
-
+ validate_result(
+ fname,
+ inp_generative,
+ model,
+ atol=1e-2,
+ is_fp8=use_fp8,
+ input_names=input_names,
+ output_names=output_names,
+ allow_cnt_errors=3,
+ )
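
The generative-phase branch above only works because do_export marked seq and bs as dynamic axes, so the serialized graph accepts shorter sequences. A minimal sketch of that property, assuming onnxruntime is installed; the file name and shapes are hypothetical, and the mask/encoder inputs are elided:

import numpy as np
import onnxruntime as ort

# Hypothetical file produced by do_export above, with seq/bs exported as dynamic axes.
sess = ort.InferenceSession("te.multihead_attention.onnx", providers=["CPUExecutionProvider"])
short = np.random.randn(64 - 8, 4, 256).astype(np.float32)  # shorter seq than at export time
attention_output = sess.run(None, {"hidden_states": short})[0]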
@pytest.mark.parametrize("use_fp8", [False, True])
@pytest.mark.parametrize("use_mask, attn_mask_type", test_configs_multihead_attention)
-@pytest.mark.parametrize("output_layernorm", [
- #True, # TO DO: handle this
- False
-])
+@pytest.mark.parametrize(
+ "output_layernorm",
+ [
+ # True, # TO DO: handle this
+ False
+ ],
+)
@pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("fuse_qkv_params", [False, True])
@pytest.mark.parametrize("zero_centered_gamma", [False, True])
@@ -1201,12 +1311,16 @@ def test_export_transformer_layer(
ffn_hidden_size = 256
num_attention_heads = 4
- input_tensor = torch.rand(sequence_length, batch_size, hidden_size, dtype=precision, device="cuda")
+ input_tensor = torch.rand(
+ sequence_length, batch_size, hidden_size, dtype=precision, device="cuda"
+ )
input_names = ["input", "attention_mask"]
attention_mask = None
if use_mask and attn_mask_type != "causal":
# Generate a random mask with 50% probability for 0 or 1.
- probs = 0.5 * torch.ones(batch_size, 1, sequence_length, sequence_length, device="cuda", dtype=precision)
+ probs = 0.5 * torch.ones(
+ batch_size, 1, sequence_length, sequence_length, device="cuda", dtype=precision
+ )
attention_mask = torch.bernoulli(probs).to("cuda", dtype=torch.bool)
inp = (input_tensor, attention_mask)
@@ -1225,19 +1339,30 @@ def test_export_transformer_layer(
params_dtype=precision,
fuse_qkv_params=fuse_qkv_params,
zero_centered_gamma=zero_centered_gamma,
- activation=activation).to(device='cuda')
+ activation=activation,
+ ).to(device="cuda")
do_export(model, inp, fname, use_fp8, input_names=input_names)
te_outputs = te_infer(model, inp, is_fp8=use_fp8)
serialize_inputs_outputs(fname, inp, te_outputs, input_names=input_names)
- if precision in (torch.bfloat16, ):
+ if precision in (torch.bfloat16,):
return
- atol = 5e-1 if use_fp8 else (5e-1 if activation=="swiglu" else 1e-3)
- validate_result(fname, inp, model, atol=atol, is_fp8=use_fp8, input_names=input_names, te_outputs=te_outputs)
+ atol = 5e-1 if use_fp8 else (5e-1 if activation == "swiglu" else 1e-3)
+ validate_result(
+ fname, inp, model, atol=atol, is_fp8=use_fp8, input_names=input_names, te_outputs=te_outputs
+ )
@pytest.mark.parametrize("use_fp8", [True])
-@pytest.mark.parametrize("ln_scale_factor", [448*2])
-@pytest.mark.parametrize("gemm_scale_factors", [(224, 224,),])
+@pytest.mark.parametrize("ln_scale_factor", [448 * 2])
+@pytest.mark.parametrize(
+ "gemm_scale_factors",
+ [
+ (
+ 224,
+ 224,
+ ),
+ ],
+)
@pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("zero_centered_gamma", [False, True])
def test_export_gemm_layernorm(
@@ -1246,7 +1371,7 @@ def test_export_gemm_layernorm(
ln_scale_factor: float,
gemm_scale_factors: Tuple[float, float],
precision: torch.dtype,
- zero_centered_gamma: bool
+ zero_centered_gamma: bool,
):
"""This is a regression test for testing that all LN inputs have the same type.
@@ -1260,20 +1385,26 @@ def test_export_gemm_layernorm(
# Skip FP8 tests on non-hopper devices
if use_fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8)
+
class TestFP8_GemmLayernorm(nn.Module):
def __init__(self) -> None:
super().__init__()
normalized_shape = torch.Size(inp.shape[1:])
self.weight = torch.randn(*normalized_shape, dtype=precision, device="cuda")
self.bias = torch.zeros(*normalized_shape, dtype=precision, device="cuda")
- self.eps = 1e-6 # An arbitrary small value
+ self.eps = 1e-6 # An arbitrary small value
self.fp8_tensor = tex.FP8FwdTensors.GEMM1_INPUT
self.meta = create_meta(ln_scale_factor)
self.fp8_type = tex.DType.kFloat8E4M3
self.gemm = FP8GemmModule(
- precision, use_bias=False, gelu=False, scale_factors=gemm_scale_factors,
- hidden_size=hidden_size, out_features=out_features)
+ precision,
+ use_bias=False,
+ gelu=False,
+ scale_factors=gemm_scale_factors,
+ hidden_size=hidden_size,
+ out_features=out_features,
+ )
def forward(self, inp, weight):
x = self.gemm(inp, weight)
@@ -1286,14 +1417,16 @@ def test_export_gemm_layernorm(
self.fp8_tensor,
self.fp8_type,
0,
- zero_centered_gamma)
+ zero_centered_gamma,
+ )
x = cast_from_fp8(
x,
self.meta,
self.fp8_tensor,
self.fp8_type,
- tex.DType.kFloat32 if precision == torch.float32 else tex.DType.kFloat16)
+ tex.DType.kFloat32 if precision == torch.float32 else tex.DType.kFloat16,
+ )
return x
inp = torch.randn(hidden_size, in_features, dtype=precision, device="cuda")
@@ -1302,14 +1435,21 @@ def test_export_gemm_layernorm(
high_prec_str = dtype2str(precision)
fp8_str = f"_fp8" if use_fp8 else ""
fname = f"te.gemm_layernorm{fp8_str}{high_prec_str}.onnx"
- input_names = ['input', 'weight']
+ input_names = ["input", "weight"]
do_export(model, (inp, weight), fname, use_fp8=use_fp8, input_names=input_names)
te_outputs = te_infer(model, (inp, weight), is_fp8=use_fp8)
serialize_inputs_outputs(fname, (inp, weight), te_outputs, input_names=input_names)
- if precision not in (torch.bfloat16, ):
+ if precision not in (torch.bfloat16,):
validate_result(
- fname, (inp, weight), model, atol=5e-2, is_fp8=use_fp8, allow_cnt_errors=2,
- input_names=input_names, te_outputs=te_outputs)
+ fname,
+ (inp, weight),
+ model,
+ atol=5e-2,
+ is_fp8=use_fp8,
+ allow_cnt_errors=2,
+ input_names=input_names,
+ te_outputs=te_outputs,
+ )
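
The forward pass above round-trips the GEMM output through FP8. A standalone sketch of that cast pair, using the create_meta helper from this file and the same cpp_extensions entry points (shapes and scale are arbitrary):

import torch
import transformer_engine_torch as tex
from transformer_engine.pytorch.cpp_extensions import cast_to_fp8, cast_from_fp8

meta = create_meta(448 * 2)  # same ln_scale_factor as the parametrization above
x = torch.randn(32, 64, dtype=torch.float16, device="cuda")
x_fp8 = cast_to_fp8(x, meta, tex.FP8FwdTensors.GEMM1_INPUT, tex.DType.kFloat8E4M3)
x_back = cast_from_fp8(
    x_fp8, meta, tex.FP8FwdTensors.GEMM1_INPUT, tex.DType.kFloat8E4M3, tex.DType.kFloat16
)
# x_back differs from x only by E4M3 quantization error at this scale.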
@skip_FP8
@@ -1357,32 +1497,61 @@ def test_export_gpt_generation(
output_layernorm=output_layernorm,
params_dtype=precision,
fuse_qkv_params=fuse_qkv_params,
- zero_centered_gamma=zero_centered_gamma).to(device='cuda')
+ zero_centered_gamma=zero_centered_gamma,
+ ).to(device="cuda")
# "Context phase": use full input sequence length
input_names = ["input"]
output_names = ["output"]
- input_tensor = torch.rand(sequence_length, batch_size, hidden_size, dtype=precision, device="cuda")
+ input_tensor = torch.rand(
+ sequence_length, batch_size, hidden_size, dtype=precision, device="cuda"
+ )
inp = (input_tensor,)
- do_export(model, inp, fname, use_fp8,
- input_names=input_names, output_names=output_names,
- dynamic_axes={"input": {0: "seq", 1:"bs"},
- "output": {0: "seq", 1:"bs"}, })
+ do_export(
+ model,
+ inp,
+ fname,
+ use_fp8,
+ input_names=input_names,
+ output_names=output_names,
+ dynamic_axes={
+ "input": {0: "seq", 1: "bs"},
+ "output": {0: "seq", 1: "bs"},
+ },
+ )
te_outputs = te_infer(model, inp, is_fp8=use_fp8)
- serialize_inputs_outputs(fname, inp, te_outputs, input_names=input_names, output_names=output_names)
- if precision not in (torch.bfloat16, ):
- validate_result(fname, inp, model, atol=6e-3, is_fp8=use_fp8, input_names=input_names,
- te_outputs=te_outputs)
+ serialize_inputs_outputs(
+ fname, inp, te_outputs, input_names=input_names, output_names=output_names
+ )
+ if precision not in (torch.bfloat16,):
+ validate_result(
+ fname,
+ inp,
+ model,
+ atol=6e-3,
+ is_fp8=use_fp8,
+ input_names=input_names,
+ te_outputs=te_outputs,
+ )
# "Generative phase": use a single input (sequence len=1). For FP8 we need to pad the sequence to mult of 8.
sequence_length = 1 if not use_fp8 else 8
- input_tensor = torch.rand(sequence_length, batch_size, hidden_size, dtype=precision, device="cuda")
+ input_tensor = torch.rand(
+ sequence_length, batch_size, hidden_size, dtype=precision, device="cuda"
+ )
inp = (input_tensor, attention_mask)
te_outputs = te_infer(model, inp, is_fp8=use_fp8)
serialize_inputs_outputs(fname, inp, te_outputs, input_names=input_names)
- if precision not in (torch.bfloat16, ):
- validate_result(fname, inp, model, atol=6e-3, is_fp8=use_fp8, input_names=input_names,
- te_outputs=te_outputs)
+ if precision not in (torch.bfloat16,):
+ validate_result(
+ fname,
+ inp,
+ model,
+ atol=6e-3,
+ is_fp8=use_fp8,
+ input_names=input_names,
+ te_outputs=te_outputs,
+ )
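
The sequence-length choice above ("1 if not use_fp8 else 8") is the minimal case of rounding up to an FP8-friendly multiple of 8; in general:

def pad_seq_len(seq_len: int, multiple: int = 8) -> int:
    # Round up so FP8 kernels see a sequence length divisible by 8.
    return ((seq_len + multiple - 1) // multiple) * multiple

assert pad_seq_len(1) == 8
assert pad_seq_len(8) == 8
assert pad_seq_len(9) == 16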
@pytest.mark.parametrize("enabled", [True, False])
diff --git a/tests/pytorch/test_recipe.py b/tests/pytorch/test_recipe.py
index 5c16da982b87ac073b94a5ed59d70f7b7f9f5eda..8c854b65fbf3a708d1b18c2b47bed188259dd623 100644
--- a/tests/pytorch/test_recipe.py
+++ b/tests/pytorch/test_recipe.py
@@ -19,6 +19,7 @@ from transformer_engine.pytorch.fp8 import (
# Check if FP8 is supported
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
+
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
class TestFP8Recipe:
@@ -95,8 +96,8 @@ class TestFP8Recipe:
ref_amax_backward = amax_history_backward[-1]
else:
raise ValueError(f"{amax_compute_algo=} is not supported")
- ref_scale_forward = (fp8_format.value.max_fwd / ref_amax_forward) / (2 ** margin)
- ref_scale_backward = (fp8_format.value.max_bwd / ref_amax_backward) / (2 ** margin)
+ ref_scale_forward = (fp8_format.value.max_fwd / ref_amax_forward) / (2**margin)
+ ref_scale_backward = (fp8_format.value.max_bwd / ref_amax_backward) / (2**margin)
ref_scale_inv_forward = torch.reciprocal(ref_scale_forward)
update_weight_amax = is_first_microbatch is None or is_first_microbatch
if not update_weight_amax:
@@ -128,8 +129,8 @@ class TestFP8Recipe:
ref_amax_backward = amax_history_backward[-1]
else:
raise ValueError(f"{amax_compute_algo=} is not supported")
- ref_scale_forward = (fp8_format.value.max_fwd / ref_amax_forward) / (2 ** margin)
- ref_scale_backward = (fp8_format.value.max_bwd / ref_amax_backward) / (2 ** margin)
+ ref_scale_forward = (fp8_format.value.max_fwd / ref_amax_forward) / (2**margin)
+ ref_scale_backward = (fp8_format.value.max_bwd / ref_amax_backward) / (2**margin)
ref_scale_inv_forward = torch.reciprocal(ref_scale_forward)
ref_scale_inv_backward = torch.reciprocal(ref_scale_backward)
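
The reference scales above follow the delayed-scaling recipe: scale = (fp8_max / amax) / 2**margin. A worked instance with E4M3, whose forward max is 448:

fp8_max = 448.0  # Format.E4M3 forward max
amax = 7.0       # largest magnitude seen in the amax history
margin = 0
scale = (fp8_max / amax) / (2**margin)  # 64.0: multiply by this before casting to FP8
scale_inv = 1.0 / scale                 # 0.015625: undoes the scaling after the GEMM
assert scale == 64.0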
@@ -180,8 +181,9 @@ class TestFP8Recipe:
scaling_factor_compute_algo = None
if fused_update:
scaling_factor_compute_algo = (
- lambda amax, scale, fp8_max, recipe:
- te.fp8._default_sf_compute(amax, scale, fp8_max, recipe.margin)
+ lambda amax, scale, fp8_max, recipe: te.fp8._default_sf_compute(
+ amax, scale, fp8_max, recipe.margin
+ )
)
recipe = transformer_engine.common.recipe.DelayedScaling(
fp8_format=fp8_format, scaling_factor_compute_algo=scaling_factor_compute_algo
@@ -205,7 +207,9 @@ class TestFP8Recipe:
# test different scenarios
if amax_case == "zero":
- fp8_meta[forward_key].amax_history = torch.tensor([[0]], dtype=torch.float32, device="cuda")
+ fp8_meta[forward_key].amax_history = torch.tensor(
+ [[0]], dtype=torch.float32, device="cuda"
+ )
expected_scale = torch.tensor([1.0], dtype=torch.float32, device="cuda")
elif amax_case == "tiny":
# calculate the minimum amax value that results in a FP32 maximum scale
@@ -254,4 +258,6 @@ class TestFP8Recipe:
)
torch.testing.assert_close(fp8_meta[forward_key].scale, expected_scale)
- torch.testing.assert_close(fp8_meta[forward_key].scale_inv, torch.reciprocal(expected_scale))
+ torch.testing.assert_close(
+ fp8_meta[forward_key].scale_inv, torch.reciprocal(expected_scale)
+ )
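
The "zero" and "tiny" amax cases above guard the two failure modes of that formula: a zero amax would divide to inf, and a tiny amax would overflow the FP32 scale. A sketch of the behavior the test asserts, not TE's actual implementation:

import math

FP32_MAX = 3.4028235e38

def safe_scale(fp8_max: float, amax: float, margin: int) -> float:
    if amax == 0.0 or not math.isfinite(amax):
        return 1.0  # the expected_scale for the "zero" case
    return min((fp8_max / amax) / (2**margin), FP32_MAX)  # clamp for the "tiny" case

assert safe_scale(448.0, 0.0, 0) == 1.0
assert math.isfinite(safe_scale(448.0, 1e-38, 0))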
diff --git a/tests/pytorch/test_sanity.py b/tests/pytorch/test_sanity.py
index 8f5913b70e4010f300498ba47b20daf2b85464db..4c9dea803f9590827049a2c28876ac237c6916dd 100644
--- a/tests/pytorch/test_sanity.py
+++ b/tests/pytorch/test_sanity.py
@@ -30,7 +30,13 @@ from transformer_engine.pytorch import (
)
from transformer_engine.common import recipe
import transformer_engine_torch as tex
-from transformer_engine.pytorch.cpp_extensions import gemm, fp8_gemm, gelu, cast_to_fp8, cast_from_fp8
+from transformer_engine.pytorch.cpp_extensions import (
+ gemm,
+ fp8_gemm,
+ gelu,
+ cast_to_fp8,
+ cast_from_fp8,
+)
from transformer_engine.pytorch.module.base import get_workspace
from test_onnx_export import create_meta
@@ -75,6 +81,7 @@ class ModelConfig:
return False
return True
+
model_configs = {
"126m": ModelConfig(12, 2048, 2, 768, 12),
"small": ModelConfig(2, 32, 2, 64, 2),
@@ -82,7 +89,7 @@ model_configs = {
}
fp8_recipes = [
- None, # Handles non-FP8 case
+ None, # Handles non-FP8 case
recipe.DelayedScaling(margin=0, fp8_format=recipe.Format.E4M3),
recipe.DelayedScaling(margin=0, fp8_format=recipe.Format.HYBRID),
recipe.DelayedScaling(
@@ -126,6 +133,7 @@ batch_sizes_with_zero = [0, 1, 2]
all_activations = ["gelu", "relu", "reglu", "geglu", "swiglu", "srelu"]
all_normalizations = ["LayerNorm", "RMSNorm"]
+
def _disable_wgrads(block):
for p in block.parameters():
p.requires_grad = False
@@ -143,8 +151,17 @@ def _test_sanity_e2e_cuda_graph(block, dtype, config, fp8_recipe, skip_wgrad):
optimizer = torch.optim.SGD(block.parameters(), lr=0.1)
# Placeholders used for capture.
- static_input = torch.randn(config.seq_len, config.batch_size, config.hidden_size, device='cuda', dtype=dtype, requires_grad=True)
- static_target = torch.randn(config.seq_len, config.batch_size, config.hidden_size, device='cuda', dtype=dtype)
+ static_input = torch.randn(
+ config.seq_len,
+ config.batch_size,
+ config.hidden_size,
+ device="cuda",
+ dtype=dtype,
+ requires_grad=True,
+ )
+ static_target = torch.randn(
+ config.seq_len, config.batch_size, config.hidden_size, device="cuda", dtype=dtype
+ )
real_input = torch.rand_like(static_input)
real_target = torch.rand_like(static_target)
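
The static placeholders above exist because CUDA graph replay re-runs kernels against fixed memory addresses: capture records work on the static tensors, and each step copies the real batch into them before replay. A minimal self-contained sketch with PyTorch's graph API (real code, like the test above, should also warm up before capture):

import torch

static_x = torch.randn(8, 8, device="cuda")
static_y = torch.empty_like(static_x)

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    static_y.copy_(static_x * 2)  # addresses of static_x/static_y are baked into the graph

real_x = torch.randn(8, 8, device="cuda")
static_x.copy_(real_x)  # feed fresh data into the captured address
g.replay()
torch.testing.assert_close(static_y, real_x * 2)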
@@ -403,11 +420,7 @@ def test_sanity_normalization_amp(dtype, model, skip_wgrad, skip_dgrad, normaliz
config = model_configs[model]
module = RMSNorm if normalization == "RMSNorm" else LayerNorm
- block = (
- module(config.hidden_size)
- .to(dtype=torch.float32)
- .cuda()
- )
+ block = module(config.hidden_size).to(dtype=torch.float32).cuda()
_test_sanity_normalization_amp(block, dtype, config, skip_wgrad, skip_dgrad)
@@ -418,9 +431,9 @@ def test_sanity_normalization_amp(dtype, model, skip_wgrad, skip_dgrad, normaliz
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("skip_dgrad", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
-def test_sanity_layernorm_linear(dtype, fp8_recipe, model, skip_wgrad,
- zero_centered_gamma, skip_dgrad,
- normalization):
+def test_sanity_layernorm_linear(
+ dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, skip_dgrad, normalization
+):
config = model_configs[model]
if fp8_recipe is not None:
@@ -480,7 +493,7 @@ def test_sanity_linear(dtype, fp8_recipe, model, skip_wgrad, skip_dgrad):
def test_sanity_linear_with_zero_tokens(dtype, bs, model, fp8_recipe, fp8_model_params, use_bias):
config = model_configs[model]
ffn_hidden_size = 4 * config.hidden_size
- num_tokens = bs*config.seq_len
+ num_tokens = bs * config.seq_len
if fp8_recipe is not None:
if not fp8_available:
@@ -490,15 +503,9 @@ def test_sanity_linear_with_zero_tokens(dtype, bs, model, fp8_recipe, fp8_model_
use_fp8 = fp8_recipe is not None
with fp8_model_init(enabled=use_fp8 and fp8_model_params):
- te_linear = (
- Linear(
- config.hidden_size,
- ffn_hidden_size,
- bias=use_bias,
- params_dtype=dtype
- )
- .cuda()
- )
+ te_linear = Linear(
+ config.hidden_size, ffn_hidden_size, bias=use_bias, params_dtype=dtype
+ ).cuda()
inp_hidden_states = torch.randn(
num_tokens, config.hidden_size, dtype=dtype, requires_grad=True
@@ -518,9 +525,9 @@ def test_sanity_linear_with_zero_tokens(dtype, bs, model, fp8_recipe, fp8_model_
@pytest.mark.parametrize("skip_dgrad", all_boolean)
@pytest.mark.parametrize("activation", all_activations)
@pytest.mark.parametrize("normalization", all_normalizations)
-def test_sanity_layernorm_mlp(dtype, fp8_recipe, model, skip_wgrad,
- zero_centered_gamma, skip_dgrad, activation,
- normalization):
+def test_sanity_layernorm_mlp(
+ dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, skip_dgrad, activation, normalization
+):
config = model_configs[model]
if fp8_recipe is not None:
@@ -557,10 +564,18 @@ def test_sanity_layernorm_mlp(dtype, fp8_recipe, model, skip_wgrad,
@pytest.mark.parametrize("normalization", all_normalizations)
@pytest.mark.parametrize("parallel_attention_mlp", all_boolean)
@pytest.mark.parametrize("cpu_offload", all_boolean)
-def test_sanity_gpt(dtype, fp8_recipe, model, skip_wgrad,
- zero_centered_gamma, bias, activation,
- normalization, parallel_attention_mlp,
- cpu_offload):
+def test_sanity_gpt(
+ dtype,
+ fp8_recipe,
+ model,
+ skip_wgrad,
+ zero_centered_gamma,
+ bias,
+ activation,
+ normalization,
+ parallel_attention_mlp,
+ cpu_offload,
+):
config = model_configs[model]
if fp8_recipe is not None:
@@ -625,8 +640,7 @@ def test_sanity_gpt_126m():
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
-def test_sanity_bert(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma,
- normalization):
+def test_sanity_bert(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, normalization):
config = model_configs[model]
if fp8_recipe is not None:
@@ -683,8 +697,7 @@ def test_sanity_bert_126m():
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
-def test_sanity_T5(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma,
- normalization):
+def test_sanity_T5(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, normalization):
config = model_configs[model]
if fp8_recipe is not None:
@@ -845,7 +858,9 @@ def test_sanity_fused_qkv_params(dtype, fp8_recipe, model, skip_wgrad):
@pytest.mark.parametrize("model", ["small"])
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
-def test_sanity_gradient_accumulation_fusion(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma):
+def test_sanity_gradient_accumulation_fusion(
+ dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma
+):
config = model_configs[model]
if fp8_recipe is not None:
@@ -885,8 +900,7 @@ def test_sanity_gradient_accumulation_fusion(dtype, fp8_recipe, model, skip_wgra
@pytest.mark.parametrize("skip_wgrad", all_boolean)
@pytest.mark.parametrize("zero_centered_gamma", all_boolean)
@pytest.mark.parametrize("normalization", all_normalizations)
-def test_gpt_cuda_graph(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma,
- normalization):
+def test_gpt_cuda_graph(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, normalization):
config = model_configs[model]
if fp8_recipe is not None:
@@ -919,9 +933,10 @@ def test_gpt_cuda_graph(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamm
_test_sanity_e2e_cuda_graph(block, dtype, config, fp8_recipe, skip_wgrad)
+
def test_model_multiple_cast():
- a = torch.zeros((16,16), device="cuda")
- m = Linear(16,32)
+ a = torch.zeros((16, 16), device="cuda")
+ m = Linear(16, 32)
y = m(a)
assert y.dtype == torch.float32
@@ -937,15 +952,11 @@ def test_model_multiple_cast():
@pytest.mark.parametrize("offset", [1, 3, 5])
@pytest.mark.parametrize("datatype", param_types)
def test_sanity_gemm_with_unalignment(N, offset, datatype):
- scratchpad = torch.randn(N*N + 2*offset, device="cuda", dtype=datatype)
+ scratchpad = torch.randn(N * N + 2 * offset, device="cuda", dtype=datatype)
inp = torch.reshape(scratchpad[offset:-offset], (N, N))
- weight = torch.reshape(scratchpad[offset*2:], (N, N))
+ weight = torch.reshape(scratchpad[offset * 2 :], (N, N))
- _, _, _ = gemm(
- A=weight,
- B=inp,
- dtype=datatype,
- workspace=get_workspace())
+ _, _, _ = gemm(A=weight, B=inp, dtype=datatype, workspace=get_workspace())
torch.cuda.synchronize()
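
The offset slicing above is what makes the GEMM inputs unaligned: shifting the start of a view by an odd number of elements moves the storage pointer off the 16-byte boundary the fast paths expect. A small sketch:

import torch

scratch = torch.randn(64 * 64 + 2, device="cuda", dtype=torch.float16)
view = scratch[1:-1].reshape(64, 64)  # starts 2 bytes past the (aligned) allocation base
print(scratch.data_ptr() % 16, view.data_ptr() % 16)  # typically prints: 0 2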
@@ -954,38 +965,35 @@ def test_sanity_gemm_with_unalignment(N, offset, datatype):
@pytest.mark.parametrize("datatype", [torch.float16, torch.bfloat16])
def test_sanity_fp8_gemm_with_unalignment(N, datatype):
offset = 16
- scratchpad = torch.randn(N*N + offset, device="cuda", dtype=datatype)
+ scratchpad = torch.randn(N * N + offset, device="cuda", dtype=datatype)
fp8_tensor_inp = tex.FP8FwdTensors.GEMM1_INPUT
fp8_tensor_weight = tex.FP8FwdTensors.GEMM1_WEIGHT
nb_inp_scales, nb_weight_scales = 1, N
- scale_factor = 1.
+ scale_factor = 1.0
meta_inp = create_meta(scale_factor, nb_inp_scales)
meta_weight = create_meta(scale_factor, nb_weight_scales)
inp_type = tex.DType.kFloat8E4M3
weights_type = tex.DType.kFloat8E4M3
outp_type = datatype
- scratchpad_fp8 = cast_to_fp8(
- scratchpad,
- meta_weight,
- fp8_tensor_inp,
- inp_type)
+ scratchpad_fp8 = cast_to_fp8(scratchpad, meta_weight, fp8_tensor_inp, inp_type)
inp_fp8 = torch.reshape(scratchpad_fp8[:-offset], (N, N))
weight_fp8 = torch.reshape(scratchpad_fp8[offset:], (N, N))
_, _ = fp8_gemm(
- weight_fp8,
- meta_weight.scale_inv,
- fp8_tensor_weight,
- inp_type,
- inp_fp8,
- meta_inp.scale_inv,
- fp8_tensor_inp,
- weights_type,
- outp_type,
- get_workspace(),
- bias=None,
- use_bias=False,
- use_split_accumulator=False)
+ weight_fp8,
+ meta_weight.scale_inv,
+ fp8_tensor_weight,
+ inp_type,
+ inp_fp8,
+ meta_inp.scale_inv,
+ fp8_tensor_inp,
+ weights_type,
+ outp_type,
+ get_workspace(),
+ bias=None,
+ use_bias=False,
+ use_split_accumulator=False,
+ )
torch.cuda.synchronize()
diff --git a/tests/pytorch/test_sanity_import.py b/tests/pytorch/test_sanity_import.py
index 4ddd6eb6e2f59755fe80dd85e69943f1bdf9fab4..954d807b7dac6aef81407201536bf59b34e71dca 100644
--- a/tests/pytorch/test_sanity_import.py
+++ b/tests/pytorch/test_sanity_import.py
@@ -3,4 +3,5 @@
# See LICENSE for license information.
import transformer_engine.pytorch
+
print("OK")
diff --git a/tests/pytorch/test_torch_save_load.py b/tests/pytorch/test_torch_save_load.py
index 74f76f7eb82e6218bb02f90cc8a3a6656f974c9a..6af7ede234095b5f382bba81d23cda2e893efabb 100644
--- a/tests/pytorch/test_torch_save_load.py
+++ b/tests/pytorch/test_torch_save_load.py
@@ -28,7 +28,7 @@ from transformer_engine.pytorch.module.base import TransformerEngineBaseModule
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
-def init_meta(size: int=1):
+def init_meta(size: int = 1):
meta = tex.FP8TensorMeta()
meta.scale = torch.ones(size, dtype=torch.float32, device="cuda")
meta.scale_inv = torch.ones(size, dtype=torch.float32, device="cuda")
@@ -65,22 +65,18 @@ def test_export_loaded_checkpoint(scale_fwd, scale_bwd, history_fwd, history_bwd
self.inp_type = tex.DType.kFloat8E4M3
self.weights_type = tex.DType.kFloat8E4M3
self.outp_type = precision
-
+
def get_fp8_weights_scratchpad(self, is_first_microbatch):
- raise RuntimeError("Method get_fp8_weights_scratchpad is dummy and should not be invoked.")
+ raise RuntimeError(
+ "Method get_fp8_weights_scratchpad is dummy and should not be invoked."
+ )
def forward(self, inp, weight):
- inp_fp8 = cast_to_fp8(
- inp,
- self.meta_inp,
- self.fp8_tensor_inp,
- self.inp_type)
+ inp_fp8 = cast_to_fp8(inp, self.meta_inp, self.fp8_tensor_inp, self.inp_type)
weight_fp8 = cast_to_fp8(
- weight,
- self.meta_weight,
- self.fp8_tensor_weight,
- self.weights_type)
+ weight, self.meta_weight, self.fp8_tensor_weight, self.weights_type
+ )
ret = fp8_gemm(
weight_fp8,
@@ -95,20 +91,33 @@ def test_export_loaded_checkpoint(scale_fwd, scale_bwd, history_fwd, history_bwd
get_workspace(),
bias=self.bias,
use_bias=self.use_bias,
- use_split_accumulator=False)
+ use_split_accumulator=False,
+ )
return ret
model_in = Test_TE_Export(precision, True)
with te.fp8_autocast(enabled=True):
model_in.init_fp8_metadata()
# scaling fwd
- model_in.fp8_meta["scaling_fwd"].scale = torch.ones(3, dtype=torch.float32, device="cuda") * scale_fwd
- model_in.fp8_meta["scaling_fwd"].scale_inv = torch.ones(3, dtype=torch.float32, device="cuda") / scale_fwd
- model_in.fp8_meta["scaling_fwd"].amax_history = torch.ones(3, dtype=torch.float32, device="cuda") * history_fwd
+ model_in.fp8_meta["scaling_fwd"].scale = (
+ torch.ones(3, dtype=torch.float32, device="cuda") * scale_fwd
+ )
+ model_in.fp8_meta["scaling_fwd"].scale_inv = (
+ torch.ones(3, dtype=torch.float32, device="cuda") / scale_fwd
+ )
+ model_in.fp8_meta["scaling_fwd"].amax_history = (
+ torch.ones(3, dtype=torch.float32, device="cuda") * history_fwd
+ )
# scaling bwd
- model_in.fp8_meta["scaling_bwd"].scale = torch.ones(2, dtype=torch.float32, device="cuda") * scale_bwd
- model_in.fp8_meta["scaling_bwd"].scale_inv = torch.ones(2, dtype=torch.float32, device="cuda") / scale_bwd
- model_in.fp8_meta["scaling_bwd"].amax_history = torch.ones(2, dtype=torch.float32, device="cuda") * history_bwd
+ model_in.fp8_meta["scaling_bwd"].scale = (
+ torch.ones(2, dtype=torch.float32, device="cuda") * scale_bwd
+ )
+ model_in.fp8_meta["scaling_bwd"].scale_inv = (
+ torch.ones(2, dtype=torch.float32, device="cuda") / scale_bwd
+ )
+ model_in.fp8_meta["scaling_bwd"].amax_history = (
+ torch.ones(2, dtype=torch.float32, device="cuda") * history_bwd
+ )
torch.save(model_in.state_dict(), tmp_filename)
@@ -117,13 +126,27 @@ def test_export_loaded_checkpoint(scale_fwd, scale_bwd, history_fwd, history_bwd
model_out.eval()
# scaling fwd
- assert torch.allclose(model_in.fp8_meta["scaling_fwd"].scale, model_out.fp8_meta["scaling_fwd"].scale)
- assert torch.allclose(model_in.fp8_meta["scaling_fwd"].scale_inv, model_out.fp8_meta["scaling_fwd"].scale_inv)
- assert torch.allclose(model_in.fp8_meta["scaling_fwd"].amax_history, model_out.fp8_meta["scaling_fwd"].amax_history)
+ assert torch.allclose(
+ model_in.fp8_meta["scaling_fwd"].scale, model_out.fp8_meta["scaling_fwd"].scale
+ )
+ assert torch.allclose(
+ model_in.fp8_meta["scaling_fwd"].scale_inv, model_out.fp8_meta["scaling_fwd"].scale_inv
+ )
+ assert torch.allclose(
+ model_in.fp8_meta["scaling_fwd"].amax_history,
+ model_out.fp8_meta["scaling_fwd"].amax_history,
+ )
# scaling bwd
- assert torch.allclose(model_in.fp8_meta["scaling_bwd"].scale, model_out.fp8_meta["scaling_bwd"].scale)
- assert torch.allclose(model_in.fp8_meta["scaling_bwd"].scale_inv, model_out.fp8_meta["scaling_bwd"].scale_inv)
- assert torch.allclose(model_in.fp8_meta["scaling_bwd"].amax_history, model_out.fp8_meta["scaling_bwd"].amax_history)
+ assert torch.allclose(
+ model_in.fp8_meta["scaling_bwd"].scale, model_out.fp8_meta["scaling_bwd"].scale
+ )
+ assert torch.allclose(
+ model_in.fp8_meta["scaling_bwd"].scale_inv, model_out.fp8_meta["scaling_bwd"].scale_inv
+ )
+ assert torch.allclose(
+ model_in.fp8_meta["scaling_bwd"].amax_history,
+ model_out.fp8_meta["scaling_bwd"].amax_history,
+ )
@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
@@ -132,7 +155,7 @@ def test_export_loaded_checkpoint(scale_fwd, scale_bwd, history_fwd, history_bwd
def test_fp8_model_checkpoint(
save_fp8_model: bool,
load_fp8_model: bool,
- dims: Iterable[int] = [32,32],
+ dims: Iterable[int] = [32, 32],
dtype: torch.dtype = torch.float32,
device: Union[torch.device, str] = "cuda",
):
@@ -153,7 +176,7 @@ def test_fp8_model_checkpoint(
with te.fp8_autocast():
y_ref = model(x.detach().clone()).detach().clone()
- fp8_meta_ref = { "scaling_fwd": {}, "scaling_bwd": {} }
+ fp8_meta_ref = {"scaling_fwd": {}, "scaling_bwd": {}}
with te.fp8_autocast(), torch.no_grad():
fp8_meta_fwd = model.fp8_meta["scaling_fwd"]
fp8_meta_bwd = model.fp8_meta["scaling_bwd"]
@@ -168,7 +191,7 @@ def test_fp8_model_checkpoint(
fp8_meta_bwd.scale.copy_(fp8_meta_bwd_ref["scale"])
fp8_meta_bwd.scale_inv.copy_(fp8_meta_bwd_ref["scale_inv"])
del fp8_meta_fwd, fp8_meta_bwd
-
+
# [ This is part of logic that tests save_fp8_model=False and load_fp8_model=True ]
    # The following lines copy the fp8 scale_inv from the model metadata to the weight fp8 tensor.
    # Their sole purpose is to set the weight tensor's scale_inv, and copying it from the metadata is the simplest way.
@@ -226,15 +249,14 @@ def test_fp8_model_checkpoint(
with pytest.raises(AssertionError):
torch.testing.assert_close(y, y_ref, **tols)
-
# [ This is part of logic that tests save_fp8_model=False and load_fp8_model=True ]
- # When save_fp8_model=True, we load a model with weights in high precision,
+ # When save_fp8_model=True, we load a model with weights in high precision,
# which does not include _scale_inv,
- # but has the fp8 scaling factor in the meta data. This scenario can occur
+ # but has the fp8 scaling factor in the meta data. This scenario can occur
# when using te.fp8_autocast(enabled=False, calibrating=True).
#
# In such cases, the default behavior of load_state_dict is incorrect - it loads tensors first,
- # followed by the fp8 metadata. This results in an incorrect _scale_inv for the tensor. This behavior
+ # followed by the fp8 metadata. This results in an incorrect _scale_inv for the tensor. This behavior
# is corrected by overriding the _load_state_dict method from PyTorch in TransformerEngineBaseModule,
# to load the fp8 metadata before loading tensors.
#
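
A toy illustration of the ordering hazard described above, in plain dicts rather than TE code: the fp8 weight derives its scale_inv from whatever metadata is present at the moment the tensor is loaded.

def load(model, state, metadata_first):
    order = ["meta", "weight"] if metadata_first else ["weight", "meta"]
    for key in order:
        if key == "meta":
            model["fp8_meta_scale_inv"] = state["fp8_meta_scale_inv"]
        else:
            # The fp8 weight picks up scale_inv from the metadata visible right now.
            model["weight_scale_inv"] = model["fp8_meta_scale_inv"]
    return model

state = {"fp8_meta_scale_inv": 0.5}
stale = load({"fp8_meta_scale_inv": 1.0}, state, metadata_first=False)
fresh = load({"fp8_meta_scale_inv": 1.0}, state, metadata_first=True)
assert stale["weight_scale_inv"] == 1.0  # stale: metadata arrived too late
assert fresh["weight_scale_inv"] == 0.5  # correct: metadata loaded first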
@@ -262,4 +284,6 @@ def test_fp8_model_checkpoint(
# We need to ensure that the tensor's scale_inv parameter matches its meta data.
# This is crucial to avoid confusion about which value is correct.
meta_index = model.weight._fp8_meta_index
- torch.testing.assert_close(model.weight._scale_inv.item(), fp8_meta_fwd_ref["scale_inv"][meta_index].item())
\ No newline at end of file
+ torch.testing.assert_close(
+ model.weight._scale_inv.item(), fp8_meta_fwd_ref["scale_inv"][meta_index].item()
+ )
diff --git a/transformer_engine/common/activation/activation_template.h b/transformer_engine/common/activation/activation_template.h
index 12e1b37e8fc304c9fd971ad75e4f348a57ba8ce9..fc93705dff4866cb47c8f35ed006be6100563f24 100644
--- a/transformer_engine/common/activation/activation_template.h
+++ b/transformer_engine/common/activation/activation_template.h
@@ -4,74 +4,59 @@
* See LICENSE for license information.
************************************************************************/
-#include <transformer_engine/activation.h>
#include <cuda_runtime.h>
-#include "../util/vectorized_pointwise.h"
-#include "../common.h"
+#include <transformer_engine/activation.h>
+#include "../common.h"
+#include "../util/vectorized_pointwise.h"
namespace transformer_engine {
-template
-void act_fn(const Tensor &input,
- Tensor *output,
- cudaStream_t stream) {
+template
+void act_fn(const Tensor &input, Tensor *output, cudaStream_t stream) {
CheckInputTensor(input, "act_lu_input");
CheckOutputTensor(*output, "act_lu_output");
NVTE_CHECK(input.data.shape == output->data.shape, "Input and output shapes must match.");
const size_t tot_elts = product(input.data.shape);
- TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input.data.dtype, IType,
- TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(output->data.dtype, OType,
- constexpr int nvec = 32 / sizeof(IType);
- VectorizedUnaryKernelLauncher(
- reinterpret_cast(input.data.dptr),
- reinterpret_cast(output->data.dptr),
- reinterpret_cast(output->scale.dptr),
- reinterpret_cast(output->amax.dptr),
- tot_elts,
- {},
- stream);
- ); // NOLINT(*)
- ); // NOLINT(*)
+ TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
+ input.data.dtype, IType,
+ TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
+ output->data.dtype, OType, constexpr int nvec = 32 / sizeof(IType);
+ VectorizedUnaryKernelLauncher(
+ reinterpret_cast(input.data.dptr),
+ reinterpret_cast(output->data.dptr),
+ reinterpret_cast(output->scale.dptr),
+ reinterpret_cast(output->amax.dptr), tot_elts, {},
+ stream);); // NOLINT(*)
+ ); // NOLINT(*)
}
-template
-void dact_fn(const Tensor &grad,
- const Tensor &input,
- Tensor *output,
- cudaStream_t stream) {
+template
+void dact_fn(const Tensor &grad, const Tensor &input, Tensor *output, cudaStream_t stream) {
CheckInputTensor(input, "dact_lu_input");
CheckInputTensor(grad, "dact_lu_input_grad");
CheckOutputTensor(*output, "dact_lu_output");
NVTE_CHECK(input.data.shape == output->data.shape, "Input and output shapes must match.");
- NVTE_CHECK(input.data.dtype == grad.data.dtype,
- "Input and incoming gradient types must match.");
+ NVTE_CHECK(input.data.dtype == grad.data.dtype, "Input and incoming gradient types must match.");
const size_t tot_elts = product(input.data.shape);
- TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input.data.dtype, IType,
- TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(output->data.dtype, OType,
- constexpr int nvec = 32 / sizeof(IType);
- VectorizedUnaryGradKernelLauncher(
- reinterpret_cast(grad.data.dptr),
- reinterpret_cast(input.data.dptr),
- reinterpret_cast(output->data.dptr),
- reinterpret_cast(output->scale.dptr),
- reinterpret_cast(output->amax.dptr),
- tot_elts,
- {},
- stream);
- ); // NOLINT(*)
- ); // NOLINT(*)
+ TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
+ input.data.dtype, IType,
+ TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
+ output->data.dtype, OType, constexpr int nvec = 32 / sizeof(IType);
+ VectorizedUnaryGradKernelLauncher(
+ reinterpret_cast(grad.data.dptr),
+ reinterpret_cast(input.data.dptr),
+ reinterpret_cast(output->data.dptr),
+ reinterpret_cast(output->scale.dptr),
+ reinterpret_cast(output->amax.dptr), tot_elts, {},
+ stream);); // NOLINT(*)
+ ); // NOLINT(*)
}
-template
-void gated_act_fn(const Tensor &input,
- Tensor *output,
- cudaStream_t stream) {
+template
+void gated_act_fn(const Tensor &input, Tensor *output, cudaStream_t stream) {
CheckInputTensor(input, "gated_act_input");
CheckOutputTensor(*output, "gated_act_output");
NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions.");
@@ -81,29 +66,23 @@ void gated_act_fn(const Tensor &input,
NVTE_CHECK(input.data.shape[1] == output->data.shape[1] * 2,
"Input shape[1] must be 2x larger than output shape[1].");
- TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input.data.dtype, IType,
- TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(output->data.dtype, OType,
- constexpr int nvec = 32 / sizeof(IType);
- GatedActivationKernelLauncher(
- reinterpret_cast(input.data.dptr),
- reinterpret_cast(output->data.dptr),
- reinterpret_cast(output->scale.dptr),
- reinterpret_cast(output->amax.dptr),
- output->data.shape[0],
- output->data.shape[1],
- {},
- stream);
- ); // NOLINT(*)
- ); // NOLINT(*)
+ TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
+ input.data.dtype, IType,
+ TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
+ output->data.dtype, OType, constexpr int nvec = 32 / sizeof(IType);
+ GatedActivationKernelLauncher(
+ reinterpret_cast(input.data.dptr),
+ reinterpret_cast(output->data.dptr),
+ reinterpret_cast(output->scale.dptr),
+ reinterpret_cast(output->amax.dptr), output->data.shape[0],
+ output->data.shape[1], {},
+ stream);); // NOLINT(*)
+ ); // NOLINT(*)
}
-template
-void dgated_act_fn(const Tensor &grad,
- const Tensor &input,
- Tensor *output,
- cudaStream_t stream) {
+template
+void dgated_act_fn(const Tensor &grad, const Tensor &input, Tensor *output, cudaStream_t stream) {
CheckInputTensor(grad, "dgated_act_grad");
CheckInputTensor(input, "dgated_act_input");
CheckOutputTensor(*output, "dgated_act_output");
@@ -114,23 +93,19 @@ void dgated_act_fn(const Tensor &grad,
"Output shape[0] must be equal to grad shape[0].");
NVTE_CHECK(output->data.shape[1] == grad.data.shape[1] * 2,
"Output shape[1] must be 2x larger than grad shape[1].");
- NVTE_CHECK(input.data.shape == output->data.shape,
- "Input and output shapes must match.");
+ NVTE_CHECK(input.data.shape == output->data.shape, "Input and output shapes must match.");
- TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input.data.dtype, IType,
- TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(output->data.dtype, OType,
- constexpr int nvec = 32 / sizeof(IType);
- DGatedActivationKernelLauncher(
- reinterpret_cast(grad.data.dptr),
- reinterpret_cast(input.data.dptr),
- reinterpret_cast(output->data.dptr),
- grad.data.shape[0],
- grad.data.shape[1],
- {},
- stream);
- ); // NOLINT(*)
- ); // NOLINT(*)
+ TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
+ input.data.dtype, IType,
+ TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(
+ output->data.dtype, OType, constexpr int nvec = 32 / sizeof(IType);
+ DGatedActivationKernelLauncher(
+ reinterpret_cast(grad.data.dptr),
+ reinterpret_cast(input.data.dptr),
+ reinterpret_cast(output->data.dptr), grad.data.shape[0], grad.data.shape[1],
+ {},
+ stream);); // NOLINT(*)
+ ); // NOLINT(*)
}
} // namespace transformer_engine
-
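
The shape checks above ("Input shape[1] must be 2x larger than output shape[1]") encode the GLU contract: a gated activation consumes [n, 2d] and produces [n, d] by gating one half of the input with the activation of the other. Sketched in Python, with the split order being illustrative:

import torch
import torch.nn.functional as F

x = torch.randn(4, 2 * 8)
a, b = x.chunk(2, dim=1)  # [n, 2d] -> two [n, d] halves
out = F.silu(a) * b       # swiglu-style gating; geglu/reglu swap in gelu/relu
assert out.shape == (4, 8)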
diff --git a/transformer_engine/common/activation/gelu.cu b/transformer_engine/common/activation/gelu.cu
index 9ed0a23cec4e7c68de3fbb79f9da055495cdf3b0..f9cd7b845a83ae90ad84950f03e6f8493824e8ab 100644
--- a/transformer_engine/common/activation/gelu.cu
+++ b/transformer_engine/common/activation/gelu.cu
@@ -3,96 +3,69 @@
*
* See LICENSE for license information.
************************************************************************/
-#include "./activation_template.h"
#include "../util/math.h"
+#include "./activation_template.h"
-
-void nvte_gelu(const NVTETensor input,
- NVTETensor output,
- cudaStream_t stream) {
+void nvte_gelu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_gelu);
using namespace transformer_engine;
act_fn>(*reinterpret_cast(input),
- reinterpret_cast(output),
- stream);
+ reinterpret_cast(output), stream);
}
-void nvte_dgelu(const NVTETensor grad,
- const NVTETensor input,
- NVTETensor output,
+void nvte_dgelu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_dgelu);
using namespace transformer_engine;
dact_fn>(*reinterpret_cast(grad),
*reinterpret_cast(input),
- reinterpret_cast(output),
- stream);
+ reinterpret_cast(output), stream);
}
-void nvte_geglu(const NVTETensor input,
- NVTETensor output,
- cudaStream_t stream) {
+void nvte_geglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_geglu);
using namespace transformer_engine;
gated_act_fn>(*reinterpret_cast(input),
- reinterpret_cast(output),
- stream);
+ reinterpret_cast(output), stream);
}
-void nvte_dgeglu(const NVTETensor grad,
- const NVTETensor input,
- NVTETensor output,
+void nvte_dgeglu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_dgeglu);
using namespace transformer_engine;
dgated_act_fn, dgelu>(
- *reinterpret_cast(grad),
- *reinterpret_cast(input),
- reinterpret_cast(output),
- stream);
+ *reinterpret_cast(grad), *reinterpret_cast(input),
+ reinterpret_cast(output), stream);
}
-void nvte_qgelu(const NVTETensor input,
- NVTETensor output,
- cudaStream_t stream) {
+void nvte_qgelu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_qgelu);
using namespace transformer_engine;
act_fn>(*reinterpret_cast(input),
- reinterpret_cast(output),
- stream);
+ reinterpret_cast(output), stream);
}
-void nvte_dqgelu(const NVTETensor grad,
- const NVTETensor input,
- NVTETensor output,
- cudaStream_t stream) {
+void nvte_dqgelu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
+ cudaStream_t stream) {
NVTE_API_CALL(nvte_dqgelu);
using namespace transformer_engine;
dact_fn>(*reinterpret_cast(grad),
- *reinterpret_cast(input),
- reinterpret_cast(output),
- stream);
+ *reinterpret_cast(input),
+ reinterpret_cast(output), stream);
}
-void nvte_qgeglu(const NVTETensor input,
- NVTETensor output,
- cudaStream_t stream) {
+void nvte_qgeglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_qgeglu);
using namespace transformer_engine;
gated_act_fn>(*reinterpret_cast(input),
- reinterpret_cast(output),
- stream);
+ reinterpret_cast(output), stream);
}
-void nvte_dqgeglu(const NVTETensor grad,
- const NVTETensor input,
- NVTETensor output,
- cudaStream_t stream) {
+void nvte_dqgeglu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
+ cudaStream_t stream) {
NVTE_API_CALL(nvte_dqgeglu);
using namespace transformer_engine;
dgated_act_fn, dqgelu>(
- *reinterpret_cast(grad),
- *reinterpret_cast(input),
- reinterpret_cast(output),
- stream);
+ *reinterpret_cast(grad), *reinterpret_cast(input),
+ reinterpret_cast(output), stream);
}
diff --git a/transformer_engine/common/activation/relu.cu b/transformer_engine/common/activation/relu.cu
index 5d6b02d5edc46b30553a10104cb60bd5c76499be..c18d018a8e0843bcb48c1a83d452c8209c73313e 100644
--- a/transformer_engine/common/activation/relu.cu
+++ b/transformer_engine/common/activation/relu.cu
@@ -4,96 +4,69 @@
* See LICENSE for license information.
************************************************************************/
-#include "./activation_template.h"
#include "../util/math.h"
+#include "./activation_template.h"
-
-void nvte_relu(const NVTETensor input,
- NVTETensor output,
- cudaStream_t stream) {
+void nvte_relu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_relu);
using namespace transformer_engine;
act_fn>(*reinterpret_cast(input),
- reinterpret_cast(output),
- stream);
+ reinterpret_cast(output), stream);
}
-void nvte_drelu(const NVTETensor grad,
- const NVTETensor input,
- NVTETensor output,
+void nvte_drelu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_drelu);
using namespace transformer_engine;
dact_fn>(*reinterpret_cast(grad),
- *reinterpret_cast(input),
- reinterpret_cast(output),
- stream);
+ *reinterpret_cast(input),
+ reinterpret_cast(output), stream);
}
-void nvte_reglu(const NVTETensor input,
- NVTETensor output,
- cudaStream_t stream) {
+void nvte_reglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_reglu);
using namespace transformer_engine;
gated_act_fn>(*reinterpret_cast(input),
- reinterpret_cast(output),
- stream);
+ reinterpret_cast(output), stream);
}
-void nvte_dreglu(const NVTETensor grad,
- const NVTETensor input,
- NVTETensor output,
+void nvte_dreglu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_dreglu);
using namespace transformer_engine;
dgated_act_fn, drelu>(
- *reinterpret_cast(grad),
- *reinterpret_cast(input),
- reinterpret_cast(output),
- stream);
+ *reinterpret_cast(grad), *reinterpret_cast(input),
+ reinterpret_cast(output), stream);
}
-void nvte_srelu(const NVTETensor input,
- NVTETensor output,
- cudaStream_t stream) {
+void nvte_srelu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_srelu);
using namespace transformer_engine;
act_fn>(*reinterpret_cast(input),
- reinterpret_cast(output),
- stream);
+ reinterpret_cast(output), stream);
}
-void nvte_dsrelu(const NVTETensor grad,
- const NVTETensor input,
- NVTETensor output,
- cudaStream_t stream) {
+void nvte_dsrelu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
+ cudaStream_t stream) {
NVTE_API_CALL(nvte_dsrelu);
using namespace transformer_engine;
dact_fn>(*reinterpret_cast(grad),
- *reinterpret_cast(input),
- reinterpret_cast(output),
- stream);
+ *reinterpret_cast(input),
+ reinterpret_cast(output), stream);
}
-void nvte_sreglu(const NVTETensor input,
- NVTETensor output,
- cudaStream_t stream) {
+void nvte_sreglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_sreglu);
using namespace transformer_engine;
gated_act_fn>(*reinterpret_cast(input),
- reinterpret_cast(output),
- stream);
+ reinterpret_cast(output), stream);
}
-void nvte_dsreglu(const NVTETensor grad,
- const NVTETensor input,
- NVTETensor output,
- cudaStream_t stream) {
+void nvte_dsreglu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
+ cudaStream_t stream) {
NVTE_API_CALL(nvte_dsreglu);
using namespace transformer_engine;
dgated_act_fn, dsrelu>(
- *reinterpret_cast(grad),
- *reinterpret_cast(input),
- reinterpret_cast(output),
- stream);
+ *reinterpret_cast(grad), *reinterpret_cast(input),
+ reinterpret_cast(output), stream);
}
diff --git a/transformer_engine/common/activation/swiglu.cu b/transformer_engine/common/activation/swiglu.cu
index 736fb9e4373c9ae9a2a20b70a0f055f823374b6b..c745ffeeb4c5b004260f3e29f4dec87ae69ef742 100644
--- a/transformer_engine/common/activation/swiglu.cu
+++ b/transformer_engine/common/activation/swiglu.cu
@@ -4,51 +4,37 @@
* See LICENSE for license information.
************************************************************************/
-#include "./activation_template.h"
#include "../util/math.h"
+#include "./activation_template.h"
-
-void nvte_silu(const NVTETensor input,
- NVTETensor output,
- cudaStream_t stream) {
+void nvte_silu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_silu);
using namespace transformer_engine;
act_fn>(*reinterpret_cast(input),
- reinterpret_cast(output),
- stream);
+ reinterpret_cast(output), stream);
}
-void nvte_dsilu(const NVTETensor grad,
- const NVTETensor input,
- NVTETensor output,
+void nvte_dsilu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_dsilu);
using namespace transformer_engine;
dact_fn>(*reinterpret_cast(grad),
- *reinterpret_cast(input),
- reinterpret_cast(output),
- stream);
+ *reinterpret_cast(input),
+ reinterpret_cast(output), stream);
}
-void nvte_swiglu(const NVTETensor input,
- NVTETensor output,
- cudaStream_t stream) {
+void nvte_swiglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_swiglu);
using namespace transformer_engine;
gated_act_fn>(*reinterpret_cast(input),
- reinterpret_cast(output),
- stream);
+ reinterpret_cast(output), stream);
}
-void nvte_dswiglu(const NVTETensor grad,
- const NVTETensor input,
- NVTETensor output,
+void nvte_dswiglu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_dswiglu);
using namespace transformer_engine;
dgated_act_fn, dsilu>(
- *reinterpret_cast(grad),
- *reinterpret_cast(input),
- reinterpret_cast(output),
- stream);
+ *reinterpret_cast(grad), *reinterpret_cast(input),
+ reinterpret_cast(output), stream);
}
diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h
index 92193c04ed80f723b11e3eba4e9a05d6114cae6d..42b529f3886aadc13dde8c615155a7a83e2e5505 100644
--- a/transformer_engine/common/common.h
+++ b/transformer_engine/common/common.h
@@ -7,6 +7,12 @@
#ifndef TRANSFORMER_ENGINE_COMMON_COMMON_H_
#define TRANSFORMER_ENGINE_COMMON_COMMON_H_
+#include
+#include
+#include
+#include
+#include
+
#include
#include
#include
@@ -15,12 +21,6 @@
#include
#include
-#include
-#include
-#include
-#include
-
-#include
#include "./nvtx.h"
#include "./util/logging.h"
@@ -31,8 +31,8 @@ struct SimpleTensor {
std::vector shape;
DType dtype;
- SimpleTensor(void *dptr, const std::vector &shape, DType dtype) :
- dptr(dptr), shape(shape), dtype(dtype) {}
+ SimpleTensor(void *dptr, const std::vector &shape, DType dtype)
+ : dptr(dptr), shape(shape), dtype(dtype) {}
SimpleTensor() : SimpleTensor(nullptr, {}, DType::kFloat32) {}
};
@@ -42,15 +42,16 @@ struct Tensor {
SimpleTensor scale;
SimpleTensor scale_inv;
- Tensor() : data(),
- amax(nullptr, {1}, DType::kFloat32),
- scale(nullptr, {1}, DType::kFloat32),
- scale_inv(nullptr, {1}, DType::kFloat32) {}
+ Tensor()
+ : data(),
+ amax(nullptr, {1}, DType::kFloat32),
+ scale(nullptr, {1}, DType::kFloat32),
+ scale_inv(nullptr, {1}, DType::kFloat32) {}
};
template <typename T>
constexpr T DIVUP(const T &x, const T &y) {
- return (((x) + ((y)-1)) / (y));
+ return (((x) + ((y)-1)) / (y));
}
using byte = uint8_t;
@@ -65,8 +66,11 @@ namespace detail {
template <typename T>
constexpr inline const char *type_name() noexcept;
-#define TRANSFORMER_ENGINE_TYPE_NAME(T) \
- template <> inline constexpr const char *type_name() noexcept { return #T; }
+#define TRANSFORMER_ENGINE_TYPE_NAME(T) \
+ template <> \
+ inline constexpr const char *type_name() noexcept { \
+ return #T; \
+ }
TRANSFORMER_ENGINE_TYPE_NAME(uint8_t)
TRANSFORMER_ENGINE_TYPE_NAME(int32_t)
TRANSFORMER_ENGINE_TYPE_NAME(float)
@@ -79,214 +83,167 @@ TRANSFORMER_ENGINE_TYPE_NAME(__nv_fp8_e5m2)
} // namespace detail
template <typename T>
-struct TypeInfo{
- using types = std::tuple<byte, int32, fp32, fp16, bf16, fp8e4m3, fp8e5m2>;
-
- template
- struct Helper {
- constexpr static DType getType() {
- constexpr int i = static_cast(current);
- if (std::is_same::type>::value) {
- return current;
- } else {
- return Helper(i + 1)>::getType();
- }
- }
- };
-
- template
- struct Helper {
- constexpr static DType getType() {
- return DType::kNumTypes;
- }
- };
-
- template
+struct TypeInfo {
+ using types = std::tuple<byte, int32, fp32, fp16, bf16, fp8e4m3, fp8e5m2>;
+
+ template
+ struct Helper {
constexpr static DType getType() {
- return Helper::getType();
+ constexpr int i = static_cast(current);
+ if (std::is_same::type>::value) {
+ return current;
+ } else {
+ return Helper(i + 1)>::getType();
+ }
}
+ };
+
+ template
+ struct Helper {
+ constexpr static DType getType() { return DType::kNumTypes; }
+ };
- constexpr static DType dtype = getType();
- constexpr static size_t size = sizeof(T);
- constexpr static const char *name = detail::type_name();
+ template
+ constexpr static DType getType() {
+ return Helper::getType();
+ }
+
+ constexpr static DType dtype = getType();
+ constexpr static size_t size = sizeof(T);
+ constexpr static const char *name = detail::type_name();
};
#define TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(dtype, type, ...) \
- switch (dtype) { \
- using namespace transformer_engine; \
- case DType::kByte: \
- { \
- using type = unsigned char; \
- {__VA_ARGS__} \
- } \
- break; \
- case DType::kInt32: \
- { \
- using type = float; \
- {__VA_ARGS__} \
- } \
- break; \
- case DType::kFloat32: \
- { \
- using type = float; \
- {__VA_ARGS__} \
- } \
- break; \
- case DType::kFloat16: \
- { \
- using type = fp16; \
- {__VA_ARGS__} \
- } \
- break; \
- case DType::kBFloat16: \
- { \
- using type = bf16; \
- {__VA_ARGS__} \
- } \
- break; \
- case DType::kFloat8E4M3: \
- { \
- using type = fp8e4m3; \
- {__VA_ARGS__} \
- } \
- break; \
- case DType::kFloat8E5M2: \
- { \
- using type = fp8e5m2; \
- {__VA_ARGS__} \
- } \
- break; \
- default: \
- NVTE_ERROR("Invalid type."); \
- }
+ switch (dtype) { \
+ using namespace transformer_engine; \
+ case DType::kByte: { \
+ using type = unsigned char; \
+ { __VA_ARGS__ } \
+ } break; \
+ case DType::kInt32: { \
+ using type = float; \
+ { __VA_ARGS__ } \
+ } break; \
+ case DType::kFloat32: { \
+ using type = float; \
+ { __VA_ARGS__ } \
+ } break; \
+ case DType::kFloat16: { \
+ using type = fp16; \
+ { __VA_ARGS__ } \
+ } break; \
+ case DType::kBFloat16: { \
+ using type = bf16; \
+ { __VA_ARGS__ } \
+ } break; \
+ case DType::kFloat8E4M3: { \
+ using type = fp8e4m3; \
+ { __VA_ARGS__ } \
+ } break; \
+ case DType::kFloat8E5M2: { \
+ using type = fp8e5m2; \
+ { __VA_ARGS__ } \
+ } break; \
+ default: \
+ NVTE_ERROR("Invalid type."); \
+ }
#define TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(dtype, type, ...) \
- switch (dtype) { \
- using namespace transformer_engine; \
- case DType::kFloat32: \
- { \
- using type = float; \
- {__VA_ARGS__} \
- } \
- break; \
- case DType::kFloat16: \
- { \
- using type = fp16; \
- {__VA_ARGS__} \
- } \
- break; \
- case DType::kBFloat16: \
- { \
- using type = bf16; \
- {__VA_ARGS__} \
- } \
- break; \
- case DType::kFloat8E5M2: \
- { \
- using type = fp8e5m2; \
- {__VA_ARGS__} \
- } \
- break; \
- case DType::kFloat8E4M3: \
- { \
- using type = fp8e4m3; \
- {__VA_ARGS__} \
- } \
- break; \
- default: \
- NVTE_ERROR("Invalid type."); \
- }
+ switch (dtype) { \
+ using namespace transformer_engine; \
+ case DType::kFloat32: { \
+ using type = float; \
+ { __VA_ARGS__ } \
+ } break; \
+ case DType::kFloat16: { \
+ using type = fp16; \
+ { __VA_ARGS__ } \
+ } break; \
+ case DType::kBFloat16: { \
+ using type = bf16; \
+ { __VA_ARGS__ } \
+ } break; \
+ case DType::kFloat8E5M2: { \
+ using type = fp8e5m2; \
+ { __VA_ARGS__ } \
+ } break; \
+ case DType::kFloat8E4M3: { \
+ using type = fp8e4m3; \
+ { __VA_ARGS__ } \
+ } break; \
+ default: \
+ NVTE_ERROR("Invalid type."); \
+ }
#define TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(dtype, type, ...) \
- switch (dtype) { \
- using namespace transformer_engine; \
- case DType::kFloat8E5M2: \
- { \
- using type = fp8e5m2; \
- {__VA_ARGS__} \
- } \
- break; \
- case DType::kFloat8E4M3: \
- { \
- using type = fp8e4m3; \
- {__VA_ARGS__} \
- } \
- break; \
- default: \
- NVTE_ERROR("Invalid type."); \
- }
+ switch (dtype) { \
+ using namespace transformer_engine; \
+ case DType::kFloat8E5M2: { \
+ using type = fp8e5m2; \
+ { __VA_ARGS__ } \
+ } break; \
+ case DType::kFloat8E4M3: { \
+ using type = fp8e4m3; \
+ { __VA_ARGS__ } \
+ } break; \
+ default: \
+ NVTE_ERROR("Invalid type."); \
+ }
#define TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(dtype, type, ...) \
- switch (dtype) { \
- using namespace transformer_engine; \
- case DType::kFloat32: \
- { \
- using type = float; \
- {__VA_ARGS__} \
- } \
- break; \
- case DType::kFloat16: \
- { \
- using type = fp16; \
- {__VA_ARGS__} \
- } \
- break; \
- case DType::kBFloat16: \
- { \
- using type = bf16; \
- {__VA_ARGS__} \
- } \
- break; \
- case DType::kFloat8E5M2: \
- case DType::kFloat8E4M3: \
- { \
- NVTE_ERROR("FP8 type not instantiated for input."); \
- } \
- break; \
- default: \
- NVTE_ERROR("Invalid type."); \
- }
-
-#define TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT(dtype, type, ...) \
- switch (dtype) \
- { \
- using namespace transformer_engine; \
- case DType::kFloat16: \
- { \
- using type = fp16; \
- __VA_ARGS__; \
- break; \
- } \
- case DType::kBFloat16: \
- { \
- using type = bf16; \
- __VA_ARGS__; \
- break; \
- } \
- default: \
- NVTE_ERROR("Invalid type for 16 bit."); \
- }
+ switch (dtype) { \
+ using namespace transformer_engine; \
+ case DType::kFloat32: { \
+ using type = float; \
+ { __VA_ARGS__ } \
+ } break; \
+ case DType::kFloat16: { \
+ using type = fp16; \
+ { __VA_ARGS__ } \
+ } break; \
+ case DType::kBFloat16: { \
+ using type = bf16; \
+ { __VA_ARGS__ } \
+ } break; \
+ case DType::kFloat8E5M2: \
+ case DType::kFloat8E4M3: { \
+ NVTE_ERROR("FP8 type not instantiated for input."); \
+ } break; \
+ default: \
+ NVTE_ERROR("Invalid type."); \
+ }
+
+#define TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT(dtype, type, ...) \
+ switch (dtype) { \
+ using namespace transformer_engine; \
+ case DType::kFloat16: { \
+ using type = fp16; \
+ __VA_ARGS__; \
+ break; \
+ } \
+ case DType::kBFloat16: { \
+ using type = bf16; \
+ __VA_ARGS__; \
+ break; \
+ } \
+ default: \
+ NVTE_ERROR("Invalid type for 16 bit."); \
+ }
////////////////////////////////////////////////////////////////////////////////////////////////////
inline size_t product(const std::vector &shape) {
- size_t ret = 1;
- for (const auto &elem : shape) {
- ret *= elem;
- }
- return ret;
+ size_t ret = 1;
+ for (const auto &elem : shape) {
+ ret *= elem;
+ }
+ return ret;
}
inline int log2_ceil(int value) {
- int log2_value = 0;
- while ((1 << log2_value) < value) ++log2_value;
- return log2_value;
+ int log2_value = 0;
+ while ((1 << log2_value) < value) ++log2_value;
+ return log2_value;
}
template
@@ -306,7 +263,7 @@ void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empt
bool is_fp8_dtype(const DType t);
#define NVTE_API_CALL(api_name) \
- transformer_engine::nvtx::NVTXWrapper _ ## api_name ## _nvtx_wrapper(#api_name);
+ transformer_engine::nvtx::NVTXWrapper _##api_name##_nvtx_wrapper(#api_name);
} // namespace transformer_engine
diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp
index a2eab9a70864195bbc2d1ec9f080048eae641233..e5a38793ba260941a0b39c33340696d712eed7ac 100644
--- a/transformer_engine/common/fused_attn/fused_attn.cpp
+++ b/transformer_engine/common/fused_attn/fused_attn.cpp
@@ -5,79 +5,74 @@
************************************************************************/
#include "transformer_engine/fused_attn.h"
+
#include "../common.h"
-#include "utils.h"
-#include "fused_attn_f16_max512_seqlen.h"
-#include "fused_attn_f16_arbitrary_seqlen.h"
-#include "fused_attn_fp8.h"
#include "../util/cuda_runtime.h"
#include "../util/system.h"
+#include "fused_attn_f16_arbitrary_seqlen.h"
+#include "fused_attn_f16_max512_seqlen.h"
+#include "fused_attn_fp8.h"
+#include "utils.h"
// map NVTE_QKV_Layout to NVTE_QKV_Layout_Group
NVTE_QKV_Layout_Group nvte_get_qkv_layout_group(NVTE_QKV_Layout qkv_layout) {
- switch (qkv_layout) {
- case NVTE_QKV_Layout::NVTE_SB3HD:
- case NVTE_QKV_Layout::NVTE_BS3HD:
- case NVTE_QKV_Layout::NVTE_T3HD:
- return NVTE_QKV_Layout_Group::NVTE_3HD;
- case NVTE_QKV_Layout::NVTE_SBH3D:
- case NVTE_QKV_Layout::NVTE_BSH3D:
- case NVTE_QKV_Layout::NVTE_TH3D:
- return NVTE_QKV_Layout_Group::NVTE_H3D;
- case NVTE_QKV_Layout::NVTE_SBHD_SB2HD:
- case NVTE_QKV_Layout::NVTE_BSHD_BS2HD:
- case NVTE_QKV_Layout::NVTE_THD_T2HD:
- return NVTE_QKV_Layout_Group::NVTE_HD_2HD;
- case NVTE_QKV_Layout::NVTE_SBHD_SBH2D:
- case NVTE_QKV_Layout::NVTE_BSHD_BSH2D:
- case NVTE_QKV_Layout::NVTE_THD_TH2D:
- return NVTE_QKV_Layout_Group::NVTE_HD_H2D;
- case NVTE_QKV_Layout::NVTE_SBHD_SBHD_SBHD:
- case NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD:
- case NVTE_QKV_Layout::NVTE_THD_THD_THD:
- return NVTE_QKV_Layout_Group::NVTE_HD_HD_HD;
- default:
- NVTE_ERROR("qkv_layout not supported!");
- }
+ switch (qkv_layout) {
+ case NVTE_QKV_Layout::NVTE_SB3HD:
+ case NVTE_QKV_Layout::NVTE_BS3HD:
+ case NVTE_QKV_Layout::NVTE_T3HD:
+ return NVTE_QKV_Layout_Group::NVTE_3HD;
+ case NVTE_QKV_Layout::NVTE_SBH3D:
+ case NVTE_QKV_Layout::NVTE_BSH3D:
+ case NVTE_QKV_Layout::NVTE_TH3D:
+ return NVTE_QKV_Layout_Group::NVTE_H3D;
+ case NVTE_QKV_Layout::NVTE_SBHD_SB2HD:
+ case NVTE_QKV_Layout::NVTE_BSHD_BS2HD:
+ case NVTE_QKV_Layout::NVTE_THD_T2HD:
+ return NVTE_QKV_Layout_Group::NVTE_HD_2HD;
+ case NVTE_QKV_Layout::NVTE_SBHD_SBH2D:
+ case NVTE_QKV_Layout::NVTE_BSHD_BSH2D:
+ case NVTE_QKV_Layout::NVTE_THD_TH2D:
+ return NVTE_QKV_Layout_Group::NVTE_HD_H2D;
+ case NVTE_QKV_Layout::NVTE_SBHD_SBHD_SBHD:
+ case NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD:
+ case NVTE_QKV_Layout::NVTE_THD_THD_THD:
+ return NVTE_QKV_Layout_Group::NVTE_HD_HD_HD;
+ default:
+ NVTE_ERROR("qkv_layout not supported!");
+ }
}
// map NVTE_QKV_Layout to NVTE_QKV_Format
NVTE_QKV_Format nvte_get_qkv_format(NVTE_QKV_Layout qkv_layout) {
- switch (qkv_layout) {
- case NVTE_QKV_Layout::NVTE_SB3HD:
- case NVTE_QKV_Layout::NVTE_SBH3D:
- case NVTE_QKV_Layout::NVTE_SBHD_SB2HD:
- case NVTE_QKV_Layout::NVTE_SBHD_SBH2D:
- case NVTE_QKV_Layout::NVTE_SBHD_SBHD_SBHD:
- return NVTE_QKV_Format::NVTE_SBHD;
- case NVTE_QKV_Layout::NVTE_BS3HD:
- case NVTE_QKV_Layout::NVTE_BSH3D:
- case NVTE_QKV_Layout::NVTE_BSHD_BS2HD:
- case NVTE_QKV_Layout::NVTE_BSHD_BSH2D:
- case NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD:
- return NVTE_QKV_Format::NVTE_BSHD;
- case NVTE_QKV_Layout::NVTE_T3HD:
- case NVTE_QKV_Layout::NVTE_TH3D:
- case NVTE_QKV_Layout::NVTE_THD_T2HD:
- case NVTE_QKV_Layout::NVTE_THD_TH2D:
- case NVTE_QKV_Layout::NVTE_THD_THD_THD:
- return NVTE_QKV_Format::NVTE_THD;
- default:
- NVTE_ERROR("qkv_layout not supported!");
- }
+ switch (qkv_layout) {
+ case NVTE_QKV_Layout::NVTE_SB3HD:
+ case NVTE_QKV_Layout::NVTE_SBH3D:
+ case NVTE_QKV_Layout::NVTE_SBHD_SB2HD:
+ case NVTE_QKV_Layout::NVTE_SBHD_SBH2D:
+ case NVTE_QKV_Layout::NVTE_SBHD_SBHD_SBHD:
+ return NVTE_QKV_Format::NVTE_SBHD;
+ case NVTE_QKV_Layout::NVTE_BS3HD:
+ case NVTE_QKV_Layout::NVTE_BSH3D:
+ case NVTE_QKV_Layout::NVTE_BSHD_BS2HD:
+ case NVTE_QKV_Layout::NVTE_BSHD_BSH2D:
+ case NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD:
+ return NVTE_QKV_Format::NVTE_BSHD;
+ case NVTE_QKV_Layout::NVTE_T3HD:
+ case NVTE_QKV_Layout::NVTE_TH3D:
+ case NVTE_QKV_Layout::NVTE_THD_T2HD:
+ case NVTE_QKV_Layout::NVTE_THD_TH2D:
+ case NVTE_QKV_Layout::NVTE_THD_THD_THD:
+ return NVTE_QKV_Format::NVTE_THD;
+ default:
+ NVTE_ERROR("qkv_layout not supported!");
+ }
}
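As a usage note for the two mapping functions above: every `NVTE_QKV_Layout` factors into a layout group (how Q, K, and V are packed across tensors) and a format (the dimension order of each tensor), and callers typically query both. A hedged sketch, assuming the public `transformer_engine/fused_attn.h` header and library are available:

```cpp
#include <transformer_engine/fused_attn.h>

#include <cassert>

int main() {
  // "bshd_bs2hd": a separate Q tensor plus an interleaved KV tensor, batch-major.
  NVTE_QKV_Layout layout = NVTE_QKV_Layout::NVTE_BSHD_BS2HD;

  // Packing: Q is its own tensor, K and V share one tensor (HD_2HD group).
  assert(nvte_get_qkv_layout_group(layout) == NVTE_QKV_Layout_Group::NVTE_HD_2HD);

  // Dimension order: batch, sequence, heads, head_dim (BSHD format).
  assert(nvte_get_qkv_format(layout) == NVTE_QKV_Format::NVTE_BSHD);
  return 0;
}
```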
// select a backend for fused attention
NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
- NVTEDType q_dtype,
- NVTEDType kv_dtype,
- NVTE_QKV_Layout qkv_layout,
- NVTE_Bias_Type bias_type,
- NVTE_Mask_Type attn_mask_type,
- float dropout,
- size_t num_attn_heads, size_t num_gqa_groups,
- size_t max_seqlen_q, size_t max_seqlen_kv,
- size_t head_dim) {
+ NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
+ NVTE_Mask_Type attn_mask_type, float dropout, size_t num_attn_heads, size_t num_gqa_groups,
+ size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim) {
using namespace transformer_engine;
NVTE_Fused_Attn_Backend backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend;
const int device_id = cuda::current_device();
@@ -85,96 +80,82 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
NVTE_CHECK(q_dtype == kv_dtype, "Q and KV must have the same data type.");
NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout);
auto cudnn_runtime_version = cudnnGetVersion();
- if (((q_dtype == NVTEDType::kNVTEFloat8E4M3)
- || (q_dtype == NVTEDType::kNVTEFloat8E5M2))
- && (sm_arch_ >= 90)
- && (bias_type == NVTE_Bias_Type::NVTE_NO_BIAS)
- && (
- ((cudnn_runtime_version >= 8900)
- && (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD)
- && (max_seqlen_q == max_seqlen_kv)
- && (max_seqlen_q <= 512)
- && (head_dim == 64)
- && (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK))
- || ((cudnn_runtime_version >= 90100)
- && (max_seqlen_q % 128 == 0)
- && (max_seqlen_kv % 128 == 0)
- && (head_dim == 128)
- && ((qkv_format == NVTE_QKV_Format::NVTE_BSHD)
- || (qkv_format == NVTE_QKV_Format::NVTE_SBHD))
- && ((attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK)
- || (attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK))))) {
+ if (((q_dtype == NVTEDType::kNVTEFloat8E4M3) || (q_dtype == NVTEDType::kNVTEFloat8E5M2)) &&
+ (sm_arch_ >= 90) && (bias_type == NVTE_Bias_Type::NVTE_NO_BIAS) &&
+ (((cudnn_runtime_version >= 8900) && (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) &&
+ (max_seqlen_q == max_seqlen_kv) && (max_seqlen_q <= 512) && (head_dim == 64) &&
+ (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK)) ||
+ ((cudnn_runtime_version >= 90100) && (max_seqlen_q % 128 == 0) &&
+ (max_seqlen_kv % 128 == 0) && (head_dim == 128) &&
+ ((qkv_format == NVTE_QKV_Format::NVTE_BSHD) ||
+ (qkv_format == NVTE_QKV_Format::NVTE_SBHD)) &&
+ ((attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK) ||
+ (attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK))))) {
if (cudnn_runtime_version >= 8900) {
backend = NVTE_Fused_Attn_Backend::NVTE_FP8;
} else {
backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend;
std::cout << "Warning: FP8 fused attention is supported by cuDNN 8.9.0+."
- " Please upgrade your cuDNN version if possible." << std::endl;
+ " Please upgrade your cuDNN version if possible."
+ << std::endl;
}
} else if ((q_dtype == NVTEDType::kNVTEFloat16) || (q_dtype == NVTEDType::kNVTEBFloat16)) {
bool flag_m512 = false;
bool flag_arb = false;
- if ((sm_arch_ == 80 || sm_arch_ == 90)
- && (max_seqlen_q <= 512 && max_seqlen_q % 64 == 0)
- && (max_seqlen_kv <= 512 && max_seqlen_kv % 64 == 0)
- && (head_dim == 64)
- && (num_attn_heads == num_gqa_groups)
- && ((bias_type == NVTE_Bias_Type::NVTE_NO_BIAS)
- || (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS))
- && ((attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK)
- || (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK)
- || (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK
- && max_seqlen_q == max_seqlen_kv)
- || (attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK))
- && ((qkv_layout == NVTE_QKV_Layout::NVTE_SB3HD)
- || (qkv_layout == NVTE_QKV_Layout::NVTE_SBHD_SB2HD)
- || (qkv_layout == NVTE_QKV_Layout::NVTE_BS3HD)
- || (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BS2HD)
- || (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD))) {
+ if ((sm_arch_ == 80 || sm_arch_ == 90) && (max_seqlen_q <= 512 && max_seqlen_q % 64 == 0) &&
+ (max_seqlen_kv <= 512 && max_seqlen_kv % 64 == 0) && (head_dim == 64) &&
+ (num_attn_heads == num_gqa_groups) &&
+ ((bias_type == NVTE_Bias_Type::NVTE_NO_BIAS) ||
+ (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS)) &&
+ ((attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK) ||
+ (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK) ||
+ (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK &&
+ max_seqlen_q == max_seqlen_kv) ||
+ (attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK)) &&
+ ((qkv_layout == NVTE_QKV_Layout::NVTE_SB3HD) ||
+ (qkv_layout == NVTE_QKV_Layout::NVTE_SBHD_SB2HD) ||
+ (qkv_layout == NVTE_QKV_Layout::NVTE_BS3HD) ||
+ (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BS2HD) ||
+ (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD))) {
flag_m512 = true;
}
- if (((cudnn_runtime_version >= 8903 && sm_arch_ >= 80)
- || (cudnn_runtime_version < 8903 && (sm_arch_ == 80 || sm_arch_ == 90)))
- && ((cudnn_runtime_version < 90000 && max_seqlen_q % 64 == 0 && max_seqlen_kv % 64 == 0)
- || (cudnn_runtime_version >= 90000))
- && ((cudnn_runtime_version < 8907 && num_attn_heads == num_gqa_groups)
- || (cudnn_runtime_version >= 8907))
- && ((head_dim <= 128 && head_dim % 8 == 0)
- // TODO (cyang): add is_training to nvte_get_fused_attn_backend
- // d=256 only supported for forward
- || (sm_arch_ >= 90 && cudnn_runtime_version >= 90000
- && head_dim <= 256 && head_dim % 8 == 0))
- && ((cudnn_runtime_version < 8906 && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS)
- || ((cudnn_runtime_version >= 8906)
- && (bias_type == NVTE_Bias_Type::NVTE_NO_BIAS
- || (bias_type == NVTE_Bias_Type::NVTE_ALIBI
- && attn_mask_type != NVTE_Mask_Type::NVTE_NO_MASK
- && attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_MASK
- && sm_arch_ >= 90)
- || (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS
- && sm_arch_ >= 90)))
- || ((cudnn_runtime_version >= 90000)
- && (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS
- && sm_arch_ >= 80)))
- && ((cudnn_runtime_version < 8906 && attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK)
- || ((cudnn_runtime_version >= 8906)
- && (attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK
- || attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK
- || attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK
- || attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK)))
- && (!(cudnn_runtime_version >= 8906
- && (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK
- || attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK)
- && bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS))
- && ((qkv_format == NVTE_QKV_Format::NVTE_SBHD)
- || (sm_arch_ >= 90 && cudnn_runtime_version >= 90100
- && num_attn_heads == num_gqa_groups
- && qkv_format == NVTE_QKV_Format::NVTE_THD)
- || (qkv_format == NVTE_QKV_Format::NVTE_BSHD))) {
+ if (((cudnn_runtime_version >= 8903 && sm_arch_ >= 80) ||
+ (cudnn_runtime_version < 8903 && (sm_arch_ == 80 || sm_arch_ == 90))) &&
+ ((cudnn_runtime_version < 90000 && max_seqlen_q % 64 == 0 && max_seqlen_kv % 64 == 0) ||
+ (cudnn_runtime_version >= 90000)) &&
+ ((cudnn_runtime_version < 8907 && num_attn_heads == num_gqa_groups) ||
+ (cudnn_runtime_version >= 8907)) &&
+ ((head_dim <= 128 && head_dim % 8 == 0)
+ // TODO (cyang): add is_training to nvte_get_fused_attn_backend
+ // d=256 only supported for forward
+ || (sm_arch_ >= 90 && cudnn_runtime_version >= 90000 && head_dim <= 256 &&
+ head_dim % 8 == 0)) &&
+ ((cudnn_runtime_version < 8906 && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS) ||
+ ((cudnn_runtime_version >= 8906) &&
+ (bias_type == NVTE_Bias_Type::NVTE_NO_BIAS ||
+ (bias_type == NVTE_Bias_Type::NVTE_ALIBI &&
+ attn_mask_type != NVTE_Mask_Type::NVTE_NO_MASK &&
+ attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_MASK && sm_arch_ >= 90) ||
+ (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS && sm_arch_ >= 90))) ||
+ ((cudnn_runtime_version >= 90000) &&
+ (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS && sm_arch_ >= 80))) &&
+ ((cudnn_runtime_version < 8906 && attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK) ||
+ ((cudnn_runtime_version >= 8906) &&
+ (attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK ||
+ attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK ||
+ attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK ||
+ attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK))) &&
+ (!(cudnn_runtime_version >= 8906 &&
+ (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK ||
+ attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK) &&
+ bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS)) &&
+ ((qkv_format == NVTE_QKV_Format::NVTE_SBHD) ||
+ (sm_arch_ >= 90 && cudnn_runtime_version >= 90100 && num_attn_heads == num_gqa_groups &&
+ qkv_format == NVTE_QKV_Format::NVTE_THD) ||
+ (qkv_format == NVTE_QKV_Format::NVTE_BSHD))) {
flag_arb = true;
}
- if (((max_seqlen_q > 512) || (max_seqlen_kv > 512))
- && (flag_arb == true)) {
+ if (((max_seqlen_q > 512) || (max_seqlen_kv > 512)) && (flag_arb == true)) {
backend = NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen;
}
if ((max_seqlen_q <= 512) && (max_seqlen_kv <= 512)) {
@@ -185,24 +166,26 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
}
int env_backend = static_cast<int>(backend);
env_backend = transformer_engine::getenv<int>("NVTE_FUSED_ATTN_BACKEND", env_backend);
- if (((env_backend == static_cast<int>(NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen))
- && flag_m512)
- || ((env_backend == static_cast<int>(NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen))
- && flag_arb)) {
- backend = static_cast<NVTE_Fused_Attn_Backend>(env_backend);
+ if (((env_backend == static_cast<int>(NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen)) &&
+ flag_m512) ||
+ ((env_backend == static_cast<int>(NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen)) &&
+ flag_arb)) {
+ backend = static_cast<NVTE_Fused_Attn_Backend>(env_backend);
}
}
- if (cudnn_runtime_version < 8901
- && backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
+ if (cudnn_runtime_version < 8901 &&
+ backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend;
std::cout << "Warning: FP16/BF16 fused attention is supported by cuDNN 8.9.1+."
- " Please upgrade your cuDNN version if possible." << std::endl;
+ " Please upgrade your cuDNN version if possible."
+ << std::endl;
}
- if (cudnn_runtime_version < 8900
- && backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) {
+ if (cudnn_runtime_version < 8900 &&
+ backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) {
backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend;
std::cout << "Warning: FP16/BF16 fused attention is supported by cuDNN 8.9.0+."
- " Please upgrade your cuDNN version if possible." << std::endl;
+ " Please upgrade your cuDNN version if possible."
+ << std::endl;
}
} else {
backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend;
@@ -211,49 +194,40 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
}
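To see the selector end to end, the sketch below queries the backend for one concrete configuration. It is hedged: it assumes Transformer Engine is built and a CUDA device is visible (the function reads the current device's SM architecture), and, per the logic above, an `NVTE_FUSED_ATTN_BACKEND` override is honored only when the requested F16 backend actually supports the given shape.

```cpp
#include <transformer_engine/fused_attn.h>

#include <cstdio>

int main() {
  // BF16, packed QKV (bs3hd), causal mask, no bias, 16 heads of dim 64,
  // seqlen 512: eligible for both F16 backends on sm80/sm90 per the checks above.
  NVTE_Fused_Attn_Backend backend = nvte_get_fused_attn_backend(
      NVTEDType::kNVTEBFloat16, NVTEDType::kNVTEBFloat16, NVTE_QKV_Layout::NVTE_BS3HD,
      NVTE_Bias_Type::NVTE_NO_BIAS, NVTE_Mask_Type::NVTE_CAUSAL_MASK,
      /*dropout=*/0.0f, /*num_attn_heads=*/16, /*num_gqa_groups=*/16,
      /*max_seqlen_q=*/512, /*max_seqlen_kv=*/512, /*head_dim=*/64);
  std::printf("selected backend = %d\n", static_cast<int>(backend));
  // Exporting NVTE_FUSED_ATTN_BACKEND=<integer value of an F16 backend> before
  // this call steers the choice, but only if flag_m512/flag_arb hold for it.
  return 0;
}
```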
// NVTE fused attention FWD with packed QKV
-void nvte_fused_attn_fwd_qkvpacked(
- const NVTETensor QKV,
- const NVTETensor Bias,
- NVTETensor S,
- NVTETensor O,
- NVTETensorPack* Aux_CTX_Tensors,
- const NVTETensor cu_seqlens,
- const NVTETensor seq_offsets_q,
- const NVTETensor seq_offsets_k,
- const NVTETensor seq_offsets_v,
- const NVTETensor seq_offsets_o,
- const NVTETensor rng_state,
- size_t max_seqlen,
- bool is_training, float attn_scale, float dropout,
- NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
- NVTE_Mask_Type attn_mask_type,
- NVTETensor workspace,
- cudaStream_t stream) {
+void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, NVTETensor S,
+ NVTETensor O, NVTETensorPack *Aux_CTX_Tensors,
+ const NVTETensor cu_seqlens, const NVTETensor seq_offsets_q,
+ const NVTETensor seq_offsets_k, const NVTETensor seq_offsets_v,
+ const NVTETensor seq_offsets_o, const NVTETensor rng_state,
+ size_t max_seqlen, bool is_training, float attn_scale,
+ float dropout, NVTE_QKV_Layout qkv_layout,
+ NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
+ NVTETensor workspace, cudaStream_t stream) {
NVTE_API_CALL(nvte_flash_attn_fwd_qkvpacked);
using namespace transformer_engine;
- const Tensor *input_cu_seqlens = reinterpret_cast<const Tensor*>(cu_seqlens);
- const Tensor *input_seq_offsets_q = reinterpret_cast<const Tensor*>(seq_offsets_q);
- const Tensor *input_seq_offsets_k = reinterpret_cast<const Tensor*>(seq_offsets_k);
- const Tensor *input_seq_offsets_v = reinterpret_cast<const Tensor*>(seq_offsets_v);
- const Tensor *input_seq_offsets_o = reinterpret_cast<const Tensor*>(seq_offsets_o);
- const Tensor *input_rng_state = reinterpret_cast<const Tensor*>(rng_state);
- const Tensor *input_QKV = reinterpret_cast<const Tensor*>(QKV);
- const Tensor *input_Bias = reinterpret_cast<const Tensor*>(Bias);
- Tensor *input_output_S = reinterpret_cast<Tensor*>(S);
- Tensor *output_O = reinterpret_cast<Tensor*>(O);
- Tensor *wkspace = reinterpret_cast<Tensor*>(workspace);
+ const Tensor *input_cu_seqlens = reinterpret_cast<const Tensor *>(cu_seqlens);
+ const Tensor *input_seq_offsets_q = reinterpret_cast<const Tensor *>(seq_offsets_q);
+ const Tensor *input_seq_offsets_k = reinterpret_cast<const Tensor *>(seq_offsets_k);
+ const Tensor *input_seq_offsets_v = reinterpret_cast<const Tensor *>(seq_offsets_v);
+ const Tensor *input_seq_offsets_o = reinterpret_cast<const Tensor *>(seq_offsets_o);
+ const Tensor *input_rng_state = reinterpret_cast<const Tensor *>(rng_state);
+ const Tensor *input_QKV = reinterpret_cast<const Tensor *>(QKV);
+ const Tensor *input_Bias = reinterpret_cast<const Tensor *>(Bias);
+ Tensor *input_output_S = reinterpret_cast<Tensor *>(S);
+ Tensor *output_O = reinterpret_cast<Tensor *>(O);
+ Tensor *wkspace = reinterpret_cast<Tensor *>(workspace);
auto ndim = input_QKV->data.shape.size();
size_t b = input_cu_seqlens->data.shape[0] - 1;
size_t h = 0;
NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout);
if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) {
- h = input_QKV->data.shape[ndim - 2];
+ h = input_QKV->data.shape[ndim - 2];
} else if (layout_group == NVTE_QKV_Layout_Group::NVTE_H3D) {
- h = input_QKV->data.shape[ndim - 3];
+ h = input_QKV->data.shape[ndim - 3];
} else {
- NVTE_ERROR("nvte_fused_attn_fwd_qkvpacked only supports H3D and 3HD layouts!");
+ NVTE_ERROR("nvte_fused_attn_fwd_qkvpacked only supports H3D and 3HD layouts!");
}
size_t d = input_QKV->data.shape[ndim - 1];
@@ -261,49 +235,35 @@ void nvte_fused_attn_fwd_qkvpacked(
const NVTEDType QKV_type = static_cast<NVTEDType>(input_QKV->data.dtype);
NVTE_Fused_Attn_Backend fused_attention_backend =
- nvte_get_fused_attn_backend(
- QKV_type, QKV_type,
- qkv_layout, bias_type, attn_mask_type,
- dropout, h, h, max_seqlen, max_seqlen, d);
+ nvte_get_fused_attn_backend(QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type,
+ dropout, h, h, max_seqlen, max_seqlen, d);
if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901)
- fused_attn_max_512_fwd_qkvpacked(
- b, h, max_seqlen, d,
- is_training, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type,
- input_QKV, input_Bias, output_O,
- Aux_CTX_Tensors,
- input_cu_seqlens,
- input_rng_state,
- wkspace, stream, handle);
+ fused_attn_max_512_fwd_qkvpacked(b, h, max_seqlen, d, is_training, attn_scale, dropout,
+ qkv_layout, bias_type, attn_mask_type, input_QKV, input_Bias,
+ output_O, Aux_CTX_Tensors, input_cu_seqlens, input_rng_state,
+ wkspace, stream, handle);
#else
NVTE_ERROR("cuDNN 8.9.1 is required for BF16/FP16 fused attention with max_seqlen<=512. \n");
#endif
} else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) {
#if (CUDNN_VERSION >= 8900)
- fused_attn_arbitrary_seqlen_fwd_qkvpacked(
- b, h, max_seqlen, d,
- is_training, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type,
- input_QKV, input_Bias, output_O,
- Aux_CTX_Tensors,
- input_cu_seqlens,
- input_seq_offsets_q, input_seq_offsets_k, input_seq_offsets_v, input_seq_offsets_o,
- input_rng_state,
- wkspace, stream, handle);
+ fused_attn_arbitrary_seqlen_fwd_qkvpacked(
+ b, h, max_seqlen, d, is_training, attn_scale, dropout, qkv_layout, bias_type,
+ attn_mask_type, input_QKV, input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens,
+ input_seq_offsets_q, input_seq_offsets_k, input_seq_offsets_v, input_seq_offsets_o,
+ input_rng_state, wkspace, stream, handle);
#else
NVTE_ERROR(
- "cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length. \n");
+ "cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length. \n");
#endif
} else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) {
#if (CUDNN_VERSION >= 8900)
- fused_attn_fp8_fwd_qkvpacked(
- b, h, max_seqlen, d,
- is_training, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type,
- input_QKV, input_output_S, output_O,
- Aux_CTX_Tensors,
- input_cu_seqlens,
- input_rng_state,
- wkspace, stream, handle);
+ fused_attn_fp8_fwd_qkvpacked(b, h, max_seqlen, d, is_training, attn_scale, dropout, qkv_layout,
+ bias_type, attn_mask_type, input_QKV, input_output_S, output_O,
+ Aux_CTX_Tensors, input_cu_seqlens, input_rng_state, wkspace,
+ stream, handle);
#else
NVTE_ERROR("cuDNN 8.9.0 is required for FP8 fused attention. \n");
#endif
@@ -313,52 +273,39 @@ void nvte_fused_attn_fwd_qkvpacked(
}
// NVTE fused attention BWD with packed QKV
void nvte_fused_attn_bwd_qkvpacked(
- const NVTETensor QKV,
- const NVTETensor O,
- const NVTETensor dO,
- const NVTETensor S,
- NVTETensor dP,
- const NVTETensorPack* Aux_CTX_Tensors,
- NVTETensor dQKV,
- NVTETensor dBias,
- const NVTETensor cu_seqlens,
- const NVTETensor seq_offsets_q,
- const NVTETensor seq_offsets_k,
- const NVTETensor seq_offsets_v,
- const NVTETensor seq_offsets_o,
- size_t max_seqlen,
- float attn_scale, float dropout,
- NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
- NVTE_Mask_Type attn_mask_type,
- NVTETensor workspace,
- cudaStream_t stream) {
+ const NVTETensor QKV, const NVTETensor O, const NVTETensor dO, const NVTETensor S,
+ NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQKV, NVTETensor dBias,
+ const NVTETensor cu_seqlens, const NVTETensor seq_offsets_q, const NVTETensor seq_offsets_k,
+ const NVTETensor seq_offsets_v, const NVTETensor seq_offsets_o, size_t max_seqlen,
+ float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
+ NVTE_Mask_Type attn_mask_type, NVTETensor workspace, cudaStream_t stream) {
NVTE_API_CALL(nvte_flash_attn_bwd_qkvpacked);
using namespace transformer_engine;
- const Tensor *input_cu_seqlens = reinterpret_cast<const Tensor*>(cu_seqlens);
- const Tensor *input_seq_offsets_q = reinterpret_cast<const Tensor*>(seq_offsets_q);
- const Tensor *input_seq_offsets_k = reinterpret_cast<const Tensor*>(seq_offsets_k);
- const Tensor *input_seq_offsets_v = reinterpret_cast<const Tensor*>(seq_offsets_v);
- const Tensor *input_seq_offsets_o = reinterpret_cast<const Tensor*>(seq_offsets_o);
- const Tensor *input_QKV = reinterpret_cast<const Tensor*>(QKV);
- const Tensor *input_O = reinterpret_cast<const Tensor*>(O);
- const Tensor *input_dO = reinterpret_cast<const Tensor*>(dO);
- const Tensor *input_S = reinterpret_cast<const Tensor*>(S);
- Tensor *input_output_dP = reinterpret_cast<Tensor*>(dP);
- Tensor *output_dQKV = reinterpret_cast<Tensor*>(dQKV);
- Tensor *output_dBias = reinterpret_cast<Tensor*>(dBias);
- Tensor *wkspace = reinterpret_cast<Tensor*>(workspace);
+ const Tensor *input_cu_seqlens = reinterpret_cast<const Tensor *>(cu_seqlens);